In [None]:
import torch
import pickle
import numpy as np
import pandas as pd
import os

from os.path import dirname



root_path = dirname(os.getcwd()) + "/HGNN_NA"

pd.set_option("display.max_columns", None)
data_dir = root_path + "/data/datasets/original/"
data_dir_processed = root_path + "/data/datasets/processed/"
data_dir_graphs = root_path + "/data/datasets/graphs/"

print(root_path, data_dir, data_dir_processed, data_dir_graphs, sep="\n")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [2]:
dataset = "BPI_Challenge_2012_A"

In [None]:
tab_all = pd.read_csv(data_dir_processed+dataset+"_processed_all.csv")
print(tab_all.head())
list_activities = list(tab_all["Activity"].unique())

In [4]:
import random

torch.manual_seed(0)
torch.cuda.manual_seed(0)
random.seed(0)
np.random.seed(0)

In [5]:
with open(data_dir_graphs + dataset + "_TRAIN_event_prediction.pkl", "rb") as f:
    X_train, Y_train = pickle.load(f)
with open(data_dir_graphs + dataset + "_VALID_event_prediction.pkl", "rb") as f:
    X_valid, Y_valid = pickle.load(f)
with open(data_dir_graphs + dataset + "_TEST_event_prediction.pkl", "rb") as f:
    X_test, Y_test = pickle.load(f)

In [6]:
from typing_extensions import Self
from torch_geometric.data import Dataset
from torch.utils.data import DataLoader
from torch_geometric.transforms import ToUndirected, NormalizeFeatures

transform = ToUndirected()
t2 = NormalizeFeatures()


class Het_graph_data(Dataset):
    def __init__(self, prefix_graphs, labels) -> Self:
        self.X = prefix_graphs
        self.Y = labels

    # get the number of rows in the dataset
    def __len__(self):
        return len(self.Y)

    # get a row at a particular index in the dataset
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

    @staticmethod
    def collate(batch):
        # print(batch)
        data = [t2(item[0]) for item in batch]
        # data = [transform(item[0]) for item in batch]
        
        Y = [item[1] for item in batch]
        return [data, Y]

In [8]:
train_loader = DataLoader(
    Het_graph_data(X_train + X_valid, Y_train + Y_valid),
    batch_size=128,
    shuffle=True,
    collate_fn=Het_graph_data.collate,
)


valid_loader = DataLoader(
    Het_graph_data(X_valid, Y_valid),
    batch_size=2,
    shuffle=False,
    collate_fn=Het_graph_data.collate,
)

test_loader = DataLoader(
    Het_graph_data(X_test, Y_test),
    batch_size=32,
    shuffle=False,
    collate_fn=Het_graph_data.collate,
)

In [9]:
# A Class to keep track of the metrics of the classification process
class ClassificationMetrics:

  # Constructor takes the number of classes, in our case 20
  def __init__(self, num_classes=20):
    self.num_classes = num_classes
    # Initialize a confusion matrix
    self.C = torch.zeros(num_classes, num_classes)

  # Update the confusion matrix with the new scores
  def add(self, yp, yt):
    # yp: 1D tensor with predictions
    # yt: 1D tensor with ground-truth targets
    yp = yp.to("cpu")
    yt = yt.to("cpu")
    with torch.no_grad(): # We require no computation graph
      self.C+=(yt*self.C.shape[1]+yp).bincount(minlength=self.C.numel()).view(self.C.shape).float()

  def clear(self):
    # We set the confusion matrix to zero
    self.C.zero_()

  # Computes the global accuracy
  def acc(self):
    return self.C.diag().sum().item()/self.C.sum()

  # Computes the class-averaged accuracy
  def mAcc(self):
    return (self.C.diag()/self.C.sum(-1)).mean().item()

  # Computers the class-averaged Intersection over Union
  def mIoU(self):
    return (self.C.diag()/(self.C.sum(0)+self.C.sum(1)-self.C.diag())).mean().item()

  # Returns the confusion matrix
  def confusion_matrix(self):
    return self.C

In [10]:
loaders = {"train": train_loader, "validation" : valid_loader, "test" : test_loader}

In [11]:
node_types, edge_types = X_train[0].metadata()

In [None]:
node_types

In [None]:
edge_types

In [50]:
import datetime

from torch import nn
from tqdm.notebook import tqdm
from copy import deepcopy
from typing_extensions import Self
from torch_geometric.nn import SAGEConv, HeteroConv, GATConv, GCN  # Linear, GCNConv
from torch.nn import ModuleList, Module, Sequential, Softmax, Dropout, Linear, ReLU
from torch import mean, stack, concat


class HGNN(Module):

    def __init__(self, hid, out, layers, node_types, nodes_relations) -> Self:  # type: ignore
        super().__init__()

        # List of convolutional layers
        self.convs = ModuleList()
        for _ in range(layers):
            conv = HeteroConv(
                {relation: SAGEConv((-1, -1), hid, normalize=False) for relation in nodes_relations}
                # {('Activity', 'followed_by', 'Activity') : SAGEConv((-1,-1), hid),
                #  ("org:resource", "related_to", 'org:resource') : SAGEConv((-1,-1), hid),
                #  ('time:timestamp', "related_to", "time:timestamp") : SAGEConv((-1,-1), hid)}
                ,
                aggr="sum",
            )
            self.convs.append(conv)

        # print(nodes_relations)
        # Take each node hid representation and apply a linear layer
        # self.linear_nodes = Sequential(Linear(hid, hid),ReLU(), Dropout(p=0.5), Linear(hid, hid), ReLU(), Dropout(p=0.5), Linear(hid, int(hid / 2)), ReLU())

        # Return the softmax with the class probabilities
        # self.fc = Sequential(Linear(int(hid/2)*(len(node_types)), out), ReLU(), Softmax(dim=0))
        self.fc_NA = Sequential(
            Linear(hid * (len(node_types)), 256),
            ReLU(),
            nn.BatchNorm1d(256),
            Dropout(p=0.5),
            Linear(256, 128),
            ReLU(),
            nn.BatchNorm1d(128),
            Linear(128, out),
            # Softmax(dim=1)
            # ReLU()
        )
        self.fc_timestamp = Sequential(
            Linear(hid * (len(node_types)), 256),
            ReLU(),
            nn.BatchNorm1d(256),
            Dropout(p=0.5),
            Linear(256, 128),
            ReLU(),
            nn.BatchNorm1d(128),
            Linear(128, 1),
            
        )
        self.fc_org_resource = Sequential(
            Linear(hid * (len(node_types)), 256),
            ReLU(),
            nn.BatchNorm1d(256),
            Dropout(p=0.5),
            Linear(256, 128),
            ReLU(),
            nn.BatchNorm1d(128),
            Linear(128, 1),
        )

    def forward(self, x, edge_index):
        outs = []
        for i in range(len(x)):
            for conv in self.convs:
                x[i] = conv(x[i], edge_index[i])
                x[i] = {node_type: nodes_features.relu() for node_type, nodes_features in x[i].items()}

            # Node features of each node in the graph
            nodes_features = [x[i][key] for key in x[i].keys()]
            # print(nodes_features)
            

            # Global sum of each node type
            for j in range(len(nodes_features)):
                nodes_features[j] = torch.sum(nodes_features[j], dim=0)

            # Global mean pooling
            # nodes_features = mean(stack(nodes_features), dim=0)
            nodes_features = concat(nodes_features)
            outs.append(nodes_features)

        outs = stack(outs)

        activities = self.fc_NA(outs)
        org_resources = self.fc_org_resource(outs)
        timestamps = self.fc_timestamp(outs)

        return [activities, timestamps, org_resources]  # {key : self.linear(x_dict[key]) for key in x_dict.keys()}, nodes_features

org:resource, related_to, org:resource)={
    edge_attr=[1, 1],
    edge_index=[2, 1],
  },
  (Activity, followed_by, Activity)={
    edge_attr=[1, 1],
    edge_index=[2, 1],
  },
  (time:timestamp, related_to, time:timestamp)

In [None]:
weights = []

In [None]:
cl_train = {key:0 for key in range(10)}
for i,(x,y) in enumerate(train_loader):
    _,classes = torch.max(stack(y), dim=1)
    for c in list(classes):
        try:
            cl_train[c.item()] +=1
        except KeyError:
            cl_train[c.item()] = 1
            
cl_train

In [None]:
cl_val = {key:0 for key in range(10)}
for i,(x,y) in enumerate(valid_loader):
    _,classes = torch.max(stack(y), dim=1)
    for c in list(classes):
        try:
            cl_val[c.item()] +=1
        except KeyError:
            cl_val[c.item()] = 1
            
cl_val

In [None]:
cl_test = {key:0 for key in range(10)}
for i,(x,y) in enumerate(test_loader):
    _,classes = torch.max(stack(y), dim=1)
    for c in list(classes):
        try:
            cl_test[c.item()] +=1
        except KeyError:
            cl_test[c.item()] = 1
            
cl_test

In [None]:
from torch import tensor


s = 0
for i in cl_train:
    s += cl_train[i]
weights = [s/cl_train[k] if cl_train[k] != 0 else 0 for k in cl_train ]

# weights = [0.7,0.7,1,0.7,0.7,0.7,0.7,0.7,0.7,0.7]
weights = tensor(weights, device=device)
weights[0] *= 0.7
weights[1] *= 1
weights[2] *= 0.7
weights[3] *= 0.7
weights[4] *= 1

weights[0] *= 1
weights

In [51]:
model = HGNN(hid=256, out=len(list_activities), layers=4, node_types=node_types, nodes_relations=edge_types)
model.to(device)

HGNN(
  (convs): ModuleList(
    (0-3): 4 x HeteroConv(num_relations=8)
  )
  (fc_NA): Sequential(
    (0): Linear(in_features=1536, out_features=64, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
    (6): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Linear(in_features=32, out_features=10, bias=True)
  )
  (fc_timestamp): Sequential(
    (0): Linear(in_features=1536, out_features=64, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
    (6): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Linear(in_features=32, out_features=1, bias=True)
  )
  

In [None]:


with torch.no_grad():
    for i,(x,y) in enumerate(train_loader):
        x = [xx.to(device) for xx in x]
        
        x_edge_index_dicts = [xx.edge_index_dict for xx in x]
        x_dicts = [xx.x_dict for xx in x]
        # print("//"*50)
        #print(x)
        #print(x.edge_index_dict)
        #print(x.x_dict)
        # print("//"*50)
        
        y = [yy.to(device) for yy in y]
        y = stack(y)
        model(x_dicts, x_edge_index_dicts)
        break

In [48]:
from torch import int32, tensor

# model = HGNN(hid=16, out=len(list_activities), layers=2, node_types=node_types, nodes_relations=edge_types)


num_epochs = 20
best_accuracy = 0
early_stop_patience = 10
lr_value = 0.01

best_model = None

num_runs = 1
running_time = []

metric_tracker = ClassificationMetrics(num_classes=len(list_activities))

for run in range(num_runs):

    start = datetime.datetime.now()
    print("Run: {}".format(run + 1))

    # model = model.to(device)

    # print("compiling...")
    # model = torch.compile(model, dynamic=True)
    # print("Compiled model")

    # criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)
    # criterion = nn.CrossEntropyLoss(weight=weights)
    act_loss = nn.CrossEntropyLoss()
    timestamp_loss = nn.L1Loss()
    org_resource_loss = nn.L1Loss()

    # optimizer = torch.optim.Adagrad(model.parameters(), lr=lr_value) # momentum=0.9, weight_decay=1e-1)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_value)

    not_improved_count = 0

    for epoch in range(num_epochs):
        print(
            "\n-- EPOCH {}/{} -------------------------\n".format(epoch + 1, num_epochs)
        )
        torch.cuda.empty_cache()
        count_train = [0 for _ in range(len(list_activities))]
        count_true_train = [0 for _ in range(len(list_activities))]
        count_val = [0 for _ in range(len(list_activities))]
        count_true_val = [0 for _ in range(len(list_activities))]


        avg_mse_timestamp = []
        avg_mse_org_resource = []

        for state in ["train", "test"]:


            if state == "train":
                model.train()
                metric_tracker.clear()
             

            else:

         
                print(
                    "\tTRAIN | NA_acc: {:.4f} | mae_timestamp(s): {:.4f} | mae_org_resource: {:.4f}".format(
                        metric_tracker.acc(),
                        sum(avg_mse_timestamp) / len(avg_mse_timestamp),
                        sum(avg_mse_org_resource) / len(avg_mse_org_resource)
                    )
                )

                avg_mse_timestamp = []
                avg_mse_org_resource = []
                metric_tracker.clear()
                model.eval()
             

            

            for i, (x, y) in tqdm(enumerate(loaders[state])):

                # if i % 1 == 0:
                #     print(metric_tracker.acc())
                #     # print(running_loss/i)
                #     print(metric_tracker.confusion_matrix())
                #     # print(count_train)
                #     # print(count_true_train)
                #     print("//" * 14)
                # print("X")
                # print("//"*50)
                # print(x)
                # print("//"*50)
                # print(y)
                x = [xx.to(device) for xx in x]
                x_edge_index_dicts = [xx.edge_index_dict for xx in x]
                x_dicts = [xx.x_dict for xx in x]
                # print("//"*50)
                # print(x)
                # print(x.edge_index_dict)
                # print(x.x_dict)
                # print("//"*50)

                activities_labels = [yy[0].to(device) for yy in y]
                activities_labels = stack(activities_labels)

                timestamp_labels = [yy[1].to(device) for yy in y]
                timestamp_labels = stack(timestamp_labels)

                org_resource_labels = [yy[2].to(device) for yy in y]
                org_resource_labels = stack(org_resource_labels)

               

                if state == "train":
                    optimizer.zero_grad()
                    
                    _, true_preds = torch.max(activities_labels, 1)
                    
                    outputs_act, outputs_timestamp, outputs_org_resource = model(x_dicts, x_edge_index_dicts)
                    
                    act_loss_step = act_loss(outputs_act, true_preds)
                    
                    
                    timestamp_loss_step = timestamp_loss(outputs_timestamp, outputs_org_resource)
                    avg_mse_timestamp.append(timestamp_loss_step.item())

                    
                    org_resource_loss_step = org_resource_loss(outputs_org_resource, org_resource_labels)
                    avg_mse_org_resource.append(org_resource_loss_step.item())
                    
                    
                    total_loss = act_loss_step + timestamp_loss_step + 3*org_resource_loss_step
                    total_loss.backward()

                    optimizer.step()
                else:
                    with torch.no_grad():
                        _, true_preds = torch.max(activities_labels, 1)
                        
                        
                        outputs_act, outputs_timestamp, outputs_org_resource = model(x_dicts, x_edge_index_dicts)

                        # act_loss_step = act_loss(outputs_act, true_preds)

                        timestamp_loss_step = timestamp_loss(outputs_timestamp, outputs_org_resource)
                        avg_mse_timestamp.append(timestamp_loss_step.item())

                        org_resource_loss_step = org_resource_loss(outputs_org_resource, org_resource_labels)
                        avg_mse_org_resource.append(org_resource_loss_step.item())
                    
                    
                   


                _, preds = torch.max(outputs_act, 1)

                
                
                preds = preds.to(device)
                true_preds = true_preds.to(device)
                
                metric_tracker.add(
                    torch.tensor(preds.to(int32)), torch.tensor(true_preds.to(int32))
                )
                

        
        print(
                    "\tTEST | NA_acc: {:.4f} | mae_timestamp(s): {:.4f} | mae_org_resource: {:.4f}".format(
                        metric_tracker.acc(),
                        sum(avg_mse_timestamp) / len(avg_mse_timestamp),
                        sum(avg_mse_org_resource) / len(avg_mse_org_resource)
                    )
                )

        if epoch == 0:
            best_accuracy = metric_tracker.acc()

        else:
            if metric_tracker.acc() > best_accuracy:
                print("SAVING MODEL..............\n")
                best_accuracy = metric_tracker.acc()
                not_improved_count = 0
            else:
                not_improved_count += 1

        if not_improved_count == early_stop_patience:
            print(
                "Validation performance didn't improve for {} epochs. "
                "Training stops.".format(early_stop_patience)
            )
            break

    running_time.append((datetime.datetime.now() - start).total_seconds())

Run: 1

-- EPOCH 1/20 -------------------------



0it [00:00, ?it/s]

  torch.tensor(preds.to(int32)), torch.tensor(true_preds.to(int32))


	TRAIN | NA_acc: 0.6315 | mae_timestamp(s): 5.6924 | mae_org_resource: 7929.3899


0it [00:00, ?it/s]

	TEST | NA_acc: 0.6248 | mae_timestamp(s): 46.4347 | mae_org_resource: 7807.4701

-- EPOCH 2/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.6362 | mae_timestamp(s): 59.0460 | mae_org_resource: 7515.5884


0it [00:00, ?it/s]

	TEST | NA_acc: 0.6943 | mae_timestamp(s): 87.2106 | mae_org_resource: 7118.8269
SAVING MODEL..............


-- EPOCH 3/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.6723 | mae_timestamp(s): 180.2215 | mae_org_resource: 6580.6769


0it [00:00, ?it/s]

	TEST | NA_acc: 0.7231 | mae_timestamp(s): 155.9060 | mae_org_resource: 6057.8764
SAVING MODEL..............


-- EPOCH 4/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.6759 | mae_timestamp(s): 293.7596 | mae_org_resource: 5101.8325


0it [00:00, ?it/s]

	TEST | NA_acc: 0.7092 | mae_timestamp(s): 280.7663 | mae_org_resource: 3936.1477

-- EPOCH 5/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.6859 | mae_timestamp(s): 419.9242 | mae_org_resource: 3219.9729


0it [00:00, ?it/s]

	TEST | NA_acc: 0.7057 | mae_timestamp(s): 108.6922 | mae_org_resource: 2056.3482

-- EPOCH 6/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.6818 | mae_timestamp(s): 277.3064 | mae_org_resource: 2096.6965


0it [00:00, ?it/s]

	TEST | NA_acc: 0.7272 | mae_timestamp(s): 65.4186 | mae_org_resource: 1813.2149
SAVING MODEL..............


-- EPOCH 7/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.6946 | mae_timestamp(s): 227.0166 | mae_org_resource: 2024.0435


0it [00:00, ?it/s]

	TEST | NA_acc: 0.7305 | mae_timestamp(s): 195.0237 | mae_org_resource: 1812.0186
SAVING MODEL..............


-- EPOCH 8/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.7049 | mae_timestamp(s): 229.5566 | mae_org_resource: 2014.5107


0it [00:00, ?it/s]

	TEST | NA_acc: 0.6964 | mae_timestamp(s): 103.7995 | mae_org_resource: 1837.9572

-- EPOCH 9/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.7119 | mae_timestamp(s): 226.7061 | mae_org_resource: 1994.9011


0it [00:00, ?it/s]

	TEST | NA_acc: 0.7279 | mae_timestamp(s): 39.5226 | mae_org_resource: 1800.0591

-- EPOCH 10/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.7154 | mae_timestamp(s): 249.6753 | mae_org_resource: 1996.6585


0it [00:00, ?it/s]

	TEST | NA_acc: 0.6895 | mae_timestamp(s): 146.0253 | mae_org_resource: 1849.4247

-- EPOCH 11/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.7167 | mae_timestamp(s): 261.4873 | mae_org_resource: 1987.9747


0it [00:00, ?it/s]

	TEST | NA_acc: 0.7644 | mae_timestamp(s): 54.6314 | mae_org_resource: 1834.8758
SAVING MODEL..............


-- EPOCH 12/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.7177 | mae_timestamp(s): 267.3299 | mae_org_resource: 1954.5316


0it [00:00, ?it/s]

	TEST | NA_acc: 0.7057 | mae_timestamp(s): 145.2725 | mae_org_resource: 1836.0791

-- EPOCH 13/20 -------------------------



0it [00:00, ?it/s]

	TRAIN | NA_acc: 0.7180 | mae_timestamp(s): 269.7009 | mae_org_resource: 1953.9154


0it [00:00, ?it/s]

	TEST | NA_acc: 0.7330 | mae_timestamp(s): 112.3218 | mae_org_resource: 1859.6975

-- EPOCH 14/20 -------------------------



0it [00:00, ?it/s]

KeyboardInterrupt: 

In [49]:
torch.save(model.state_dict(),"multi_obj_model")