In [None]:
# Automatically reload imported modules (e.g. the local `logmentations` package)
# before executing each cell, so edits to them are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
#!g1.1
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import pm4py
import tqdm

import warnings
# Silence UserWarnings (e.g. library deprecation notices) to keep cell output readable.
warnings.filterwarnings("ignore", category=UserWarning)

# Seed every RNG used in this notebook and force deterministic cuDNN kernels
# so that runs are reproducible.
SEED = 3407
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import wandb
# Authenticate with Weights & Biases; the training cell logs metrics via wandb.init/wandb.log.
# NOTE(review): presumably uses cached credentials / WANDB_API_KEY or prompts interactively — confirm;
# never hardcode the key here.
wandb.login()

In [None]:
#!g1.1
# Which event log to use; the XES file is expected at ./data/<LOG_TYPE>.xes.
LOG_TYPE = 'bpi12'
log_path = f'./data/{LOG_TYPE}.xes'
event_log = pm4py.objects.log.importer.xes.importer.apply(log_path)

### PM4PY log view

В дальнейшем отфильтруем трейсы с крайне редкими окончаниями

In [None]:
#!g1.1
# Peek at the second event of the first trace, then show a sample trace/event
# and the start/end activity counts of the raw log.
first_event = event_log[0][1]
print("Log datetime:", first_event['time:timestamp'])
print("Converting to ts", int(first_event['time:timestamp'].timestamp()))
print()

# Each view is computed lazily so output order matches evaluation order.
log_views = [
    ("Trace view", lambda: event_log[0]),
    ("Event view", lambda: first_event),
    ("Start activities", lambda: pm4py.get_start_activities(event_log)),
    ("End activities", lambda: pm4py.get_end_activities(event_log)),
]
for title, get_view in log_views:
    print(title)
    print(get_view())
    print()

### Traces lengths distribution

В дальнейшем отфильтруем еще по длине (возьмем только трейсы длиннее 2)

In [None]:
#!g1.1
# Number of events per trace in the raw log.
lens = list(map(len, event_log))

plt.figure(figsize=(12, 6))
ax = sns.histplot(lens, kde=True)
ax.grid()

print(f'Median: {np.median(lens)}')
for q in (75, 90, 95, 99):
    print(f'Percentile {q}: {np.percentile(lens, q=q)}')

### Traces durations distribution

In [None]:
#!g1.1
# Case durations of the raw log, as returned by pm4py.
durs = pm4py.get_all_case_durations(event_log)

plt.figure(figsize=(12, 6))
ax = sns.histplot(durs, kde=True)
ax.grid()

print(f'Median: {np.median(durs)}')
for q in (75, 90, 95, 99):
    print(f'Percentile {q}: {np.percentile(durs, q=q)}')

### Activities frequency

In [None]:
#!g1.1
# Count how often each activity name occurs across all events of the raw log.
activities = {}
for name in (event['concept:name'] for trace in event_log for event in trace):
    activities[name] = activities.get(name, 0) + 1

# Print from rarest to most frequent.
for act_name, freq in sorted(activities.items(), key=lambda item: item[1]):
    print(f"Activity '{act_name}': {freq}")

# Logs cleanup

1. Фильтруем короткие и длинные трейсы (оставляем длиннее 2 и короче 150)
2. Фильтруем короткие по времени трейсы (оставляем длиннее 30 секунд)
3. Фильтруем трейсы, которые заканчиваются на 'A_REGISTERED' или 'W_Wijzigen contractgegevens'
4. Фильтруем трейсы, которые содержат 'W_Wijzigen contractgegevens'
5. Делим на train, val, test по времени начала в хронологическом порядке

In [None]:
#!g1.1
from logmentations.datasets import filter_log

# Apply the log-specific cleanup for LOG_TYPE.
# NOTE(review): presumably implements filtering steps 1-4 from the markdown above
# (step 5, the split, happens in a later cell) — confirm in logmentations.datasets.filter_log.
event_log_filtered = filter_log(event_log, LOG_TYPE)

### Define mapping from activity name to activity id

In [None]:
#!g1.1
# Vocabulary: special tokens get ids 0-2, activities follow in order of first appearance.
act2id = {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2}
id2act = {idx: token for token, idx in act2id.items()}

# Per-activity-id event counts in the filtered log.
freqs = {}

current_id = len(act2id)  # next free id
for trace in event_log_filtered:
    for event in trace:
        name = event['concept:name']
        if name not in act2id:
            act2id[name] = current_id
            id2act[current_id] = name
            current_id += 1
        act_id = act2id[name]
        freqs[act_id] = freqs.get(act_id, 0) + 1

# Inverse-frequency weights: total_events / (2 * count).
# NOTE(review): the constant 2 looks like a fixed "number of classes"; the usual balanced
# formula would use len(freqs) — confirm intent. `weights` is only printed here.
events_cnt = sum(freqs.values())
weights = {act_id: events_cnt / (2 * cnt) for act_id, cnt in freqs.items()}
print(weights)

In [None]:
#!g1.1
# new lengths
lens = [len(t) for t in event_log_filtered]

plt.figure(figsize=(12, 6))
sns.histplot(lens, kde=True).grid()

print(f'Median: {np.median(lens)}')
print(f'Percentile 75: {np.percentile(lens, q=75)}')
print(f'Percentile 90: {np.percentile(lens, q=90)}')
print(f'Percentile 95: {np.percentile(lens, q=95)}')
print(f'Percentile 99: {np.percentile(lens, q=99)}')

In [None]:
#!g1.1
# new durations
durs = pm4py.get_all_case_durations(event_log_filtered)

plt.figure(figsize=(12, 6))
sns.histplot(durs, kde=True).grid()

print(f'Median: {np.median(durs)}')
print(f'Percentile 75: {np.percentile(durs, q=75)}')
print(f'Percentile 90: {np.percentile(durs, q=90)}')
print(f'Percentile 95: {np.percentile(durs, q=95)}')
print(f'Percentile 99: {np.percentile(durs, q=99)}')

In [None]:
#!g1.1
from logmentations.utils import uniform_kl

# new frequencies, keyed by activity id
activities = {}
for trace in event_log_filtered:
    for event in trace:
        act_id = act2id[event['concept:name']]
        activities[act_id] = activities.get(act_id, 0) + 1

plt.figure(figsize=(12, 6))
print(activities)
act_ids = list(activities.keys())
counts = list(activities.values())
sns.barplot(x=act_ids, y=counts).grid()

# Empirical activity distribution, compared to uniform via uniform_kl.
initial_probs = np.array(counts) / sum(counts)
print(f"Initial uniformed KL: {uniform_kl(initial_probs)}")

In [None]:
#!g1.1
from logmentations.utils import time_aware_data_split

# Split the filtered log into train / val / test.
# NOTE(review): (0.7, 0.1, 0.2) presumably are the split fractions, and per the markdown
# above the split is chronological by trace start time — confirm in time_aware_data_split.
train_log, val_log, test_log = time_aware_data_split(event_log_filtered, (0.7, 0.1, 0.2))

# Log datasets

In [None]:
#!g1.1
from logmentations.datasets import LogsDataset


def _max_gap_seconds(trace) -> float:
    """Largest gap, in seconds, between consecutive events of `trace`."""
    stamps = [event['time:timestamp'].timestamp() for event in trace]
    return float(np.max(np.diff(stamps)))


# Fit the time normalizer on the training split only, like the augmentation fitted
# on `train_log`, so val/test statistics do not leak into preprocessing.
# Traces with fewer than two events have no gaps (np.max of an empty diff raises),
# so they are skipped.
normalizer_value = np.percentile(
    [_max_gap_seconds(t) for t in train_log if len(t) > 1], q=90
)
assert normalizer_value > 0, "time normalizer must be positive"


def time_scaling(time: float) -> float:
    """Scale a time value into model units by dividing by `normalizer_value`.

    Presumably `time` is in seconds, matching how `normalizer_value` is computed.
    """
    return time / normalizer_value


def invert_scaling(time: float) -> float:
    """Inverse of `time_scaling`: map model units back to the original time unit."""
    return time * normalizer_value


train_ds = LogsDataset(train_log, act2id, time_applyer=time_scaling)
val_ds = LogsDataset(val_log, act2id, time_applyer=time_scaling)
test_ds = LogsDataset(test_log, act2id, time_applyer=time_scaling)

print(f'Normalizer value: {normalizer_value}')

# Baseline augmentation

In [None]:
#!g1.1
from logmentations.augmentations import StatisticsAugmentation

# Baseline augmentation, fitted on the training split only so that val/test traces
# do not influence the synthetic traces sampled later via `aug.sample()`.
aug = StatisticsAugmentation(act2id).fit(train_log)

In [None]:
#!g1.1
# Activity ids treated as "rare": only synthetic traces containing at least one of them
# are kept. Ids come from first-appearance order in `act2id`, so they are only meaningful
# for this LOG_TYPE and filtering (an alternative set [21, 22, 25, 26] was tried before).
rare_act_ids = [23, 25]
assert all(act_id in id2act for act_id in rare_act_ids), "unknown activity id in rare_act_ids"
print({act_id: id2act[act_id] for act_id in rare_act_ids})

In [None]:
#!g1.1
# Draw 50000 traces from the fitted augmentation and keep those that contain a rare
# activity; each kept step is rewritten as (activity_id, scaled_time).
synthetic_traces = []
for _ in tqdm.tqdm(range(50000), desc="Sampling traces"):
    seq = aug.sample()
    if not any(step[0] in rare_act_ids for step in seq):
        continue
    for idx, step in enumerate(seq):
        seq[idx] = (step[0], time_scaling(step[1]))
    synthetic_traces.append(seq)

print(len(synthetic_traces))

In [None]:
#!g1.1
# Activity frequencies among the kept synthetic traces; each step is an (act_id, time) pair.
activities = {}
for step in (event for trace in synthetic_traces for event in trace):
    activities[step[0]] = activities.get(step[0], 0) + 1

plt.figure(figsize=(12, 6))
print(activities)
act_ids = list(activities.keys())
counts = list(activities.values())
sns.barplot(x=act_ids, y=counts).grid()

# Empirical activity distribution of the synthetic data, compared to uniform.
final_probs = np.array(counts) / sum(counts)
print(f"Final uniformed KL: {uniform_kl(final_probs)}")

In [None]:
#!g1.1
from copy import deepcopy

# Augment a copy and leave `train_ds` untouched. Re-running this cell then no longer
# appends the synthetic traces to `train_ds` a second time, and `base_train_ds`
# stays the original training data.
base_train_ds = train_ds
aug_train_ds = deepcopy(train_ds)
aug_train_ds.extend_log(synthetic_traces)

In [None]:
#!g1.1
from logmentations.utils import prediction_collate_fn
from logmentations.datasets import SlicedLogsDataset
from logmentations.datasets import LengthAwareSampler

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 128
NUM_WORKERS = 4
# `group_size` passed to LengthAwareSampler (its semantics are defined there).
GROUP_SIZE = BATCH_SIZE * 16
# Pinned host memory only helps host->GPU copies; on CPU-only machines it just warns.
PIN_MEMORY = DEVICE.type == 'cuda'


def _make_length_aware_loader(sliced_ds):
    """Training DataLoader whose batches are produced by LengthAwareSampler.

    DataLoader treats `batch_sampler` as mutually exclusive with
    `batch_size`/`shuffle`/`sampler`/`drop_last`, so none of those are passed.
    """
    return torch.utils.data.DataLoader(
        dataset=sliced_ds,
        batch_sampler=LengthAwareSampler(
            data_len=len(sliced_ds),
            batch_size=BATCH_SIZE,
            group_size=GROUP_SIZE,
        ),
        collate_fn=prediction_collate_fn,
        pin_memory=PIN_MEMORY,
        num_workers=NUM_WORKERS,
    )


def _make_eval_loader(sliced_ds):
    """Evaluation DataLoader with fixed-size batches in dataset order."""
    return torch.utils.data.DataLoader(
        dataset=sliced_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=prediction_collate_fn,
        pin_memory=PIN_MEMORY,
        num_workers=NUM_WORKERS,
    )


sliced_base_train_ds = SlicedLogsDataset(base_train_ds)
base_train_loader = _make_length_aware_loader(sliced_base_train_ds)

sliced_aug_train_ds = SlicedLogsDataset(aug_train_ds)
aug_train_loader = _make_length_aware_loader(sliced_aug_train_ds)

val_loader = _make_eval_loader(SlicedLogsDataset(val_ds))
test_loader = _make_eval_loader(SlicedLogsDataset(test_ds))


# Model training

In [None]:
#!g1.1
from logmentations.models import LstmModel
from logmentations.training import BaseConfig, train_predictive_epoch, eval_predictive_model

# NOTE(review): vocab_size=26 presumably equals len(act2id) (3 special tokens + activities
# of the filtered log) and n_features=27 presumably is vocab + 1 time feature; both must be
# revisited if LOG_TYPE or the filtering changes — confirm against LstmModel.
model = LstmModel(
    vocab_size=26, n_features=27,
    emb_size=64, hid_size=128,
    num_layers=3, bidirectional=True,
    predict_time=True
).to(DEVICE)

N_EPOCHS = 100
SAVE_PERIOD = 25
GRAD_CLIP_VALUE = 5.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# Cosine decay over the full run (scheduler.step() is called once per epoch in training).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS)
# Weights of the activity and time loss terms (presumably CE and MAE, matching the
# train_ce / train_mae values logged in training); single source of truth for CONFIG_BASE.
act_weight, time_weight = 1., 0.8

CONFIG_BASE = BaseConfig({
    "n_epochs": N_EPOCHS,
    "save_period": SAVE_PERIOD,
    "optimizer": optimizer,
    "scheduler": scheduler,
    "grad_clip_value": GRAD_CLIP_VALUE,
    "act_weight": act_weight,
    "time_weight": time_weight,
    "device": DEVICE
})

In [None]:
#!g1.1
import tqdm.notebook  # a bare `import tqdm` does not load the `tqdm.notebook` submodule

# Train on the non-augmented data (hence the "origin_data" tag) and log to W&B.
run = wandb.init(
    project="GenModels4PBPM-Prediction",
    entity="serp404",
    tags=["prediction", "origin_data", LOG_TYPE]
)

# Checkpoints go to ./checkpoints/<run name>; create missing parent dirs too
# (os.mkdir would fail if ./checkpoints does not exist yet).
save_path = os.path.join("./checkpoints", run.name)
os.makedirs(save_path, exist_ok=True)

best_f1 = None
try:
    for epoch in tqdm.notebook.tqdm(range(N_EPOCHS), "Training"):
        # Train step
        train_loss, train_ce, train_mae, grad_norm = train_predictive_epoch(
            model, base_train_loader, CONFIG_BASE
        )
        scheduler.step()

        # Validation step
        val_loss, val_ce, val_mae, val_accuracy, val_f1_macro = eval_predictive_model(
            model, val_loader, CONFIG_BASE
        )

        wandb.log(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "train_ce": train_ce,
                "train_mae": train_mae,
                "val_loss": val_loss,
                "val_ce": val_ce,
                "val_mae": val_mae,
                "val_accuracy": val_accuracy,
                "val_f1_macro_score": val_f1_macro,
                "grad_norm": grad_norm,
                "lr": optimizer.param_groups[0]['lr']
            }
        )

        # Keep the checkpoint with the best validation macro-F1 so far.
        if best_f1 is None or val_f1_macro > best_f1:
            torch.save(
                model.state_dict(),
                os.path.join(save_path, "model_best.pth")
            )
            best_f1 = val_f1_macro

        # Periodic snapshots at epochs 0, SAVE_PERIOD, 2 * SAVE_PERIOD, ...
        if epoch % SAVE_PERIOD == 0:
            torch.save(
                model.state_dict(),
                os.path.join(save_path, f"model_e{epoch}.pth")
            )
finally:
    # Close the W&B run even if training fails or is interrupted.
    run.finish()

# Model evaluation

In [None]:
#!g1.1
from logmentations.models import LstmModel
from logmentations.training import eval_prediction_test_metrics

# Test step: rebuild the architecture used in training and load the best checkpoint.
model_best = LstmModel(
    vocab_size=26, n_features=27,
    emb_size=64, hid_size=128,
    num_layers=3, bidirectional=True,
    predict_time=True
).to(DEVICE)

model_best.load_state_dict(torch.load(os.path.join(save_path, "model_best.pth"), map_location=DEVICE))


def time2days(time: float) -> float:
    """Convert a model-scale time value back to days.

    Undoes `time_scaling` (division by `normalizer_value`, which is measured in seconds),
    then converts seconds to days. The original called an undefined `invert_scaling`.
    """
    return time * normalizer_value / 3600 / 24


N_RUNS = 20
loss, ce, inv_mae, accuracy, f1 = eval_prediction_test_metrics(
    model_best, test_loader, CONFIG_BASE,
    time2days=time2days, n_runs=N_RUNS
)

print(f'Loss: {loss}')
print(f'Loss CE: {ce}')
print(f'Loss MAE: {inv_mae}')
print(f'Accuracy: {accuracy}')
print(f'F1-macro: {f1}')

In [None]:
#!g1.1