In [None]:
import torch
import pickle
import numpy as np
import pandas as pd
import os

from os.path import dirname



root_path = dirname(os.getcwd()) + "/HGNN_NA"

pd.set_option("display.max_columns", None)
data_dir = root_path + "/data/datasets/original/"
data_dir_processed = root_path + "/data/datasets/processed/"
data_dir_graphs = root_path + "/data/datasets/graphs/"

print(root_path, data_dir, data_dir_processed, data_dir_graphs, sep="\n")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [2]:
dataset = "BPI_Challenge_2012_A"

In [None]:
tab_all = pd.read_csv(data_dir_processed+dataset+"_processed_all.csv")
print(tab_all.head())
list_activities = list(tab_all["Activity"].unique())

In [4]:
import random

torch.manual_seed(0)
torch.cuda.manual_seed(0)
random.seed(0)
np.random.seed(0)

In [5]:
with open(data_dir_graphs + dataset + "_TRAIN.pkl", "rb") as f:
    X_train, Y_train = pickle.load(f)
with open(data_dir_graphs + dataset + "_VALID.pkl", "rb") as f:
    X_valid, Y_valid = pickle.load(f)
with open(data_dir_graphs + dataset + "_TEST.pkl", "rb") as f:
    X_test, Y_test = pickle.load(f)

In [6]:
from typing_extensions import Self
from torch_geometric.data import HeteroData
from torch_geometric.data import Dataset
from torch.utils.data import DataLoader


class Het_graph_data(Dataset):
    def __init__(self, prefix_graphs, labels) -> Self:
        self.X = prefix_graphs
        self.Y = labels

    # get the number of rows in the dataset
    def __len__(self):
        return len(self.Y)

    # get a row at a particular index in the dataset
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

    @staticmethod
    def collate(batch):
        # print(batch)
        data = [item[0] for item in batch]
        Y = [item[1] for item in batch]
        return [data, Y]

In [7]:
train_loader = DataLoader(
    Het_graph_data(X_train, Y_train),
    batch_size=1,
    shuffle=True,
    collate_fn=Het_graph_data.collate,
)


valid_loader = DataLoader(
    Het_graph_data(X_valid, Y_valid),
    batch_size=1,
    shuffle=False,
    collate_fn=Het_graph_data.collate,
)

test_loader = DataLoader(
    Het_graph_data(X_test, Y_test),
    batch_size=1,
    shuffle=False,
    collate_fn=Het_graph_data.collate,
)

In [8]:
# A Class to keep track of the metrics of the classification process
class ClassificationMetrics:

  # Constructor takes the number of classes, in our case 20
  def __init__(self, num_classes=20):
    self.num_classes = num_classes
    # Initialize a confusion matrix
    self.C = torch.zeros(num_classes, num_classes)

  # Update the confusion matrix with the new scores
  def add(self, yp, yt):
    # yp: 1D tensor with predictions
    # yt: 1D tensor with ground-truth targets
    yp = yp.to("cpu")
    yt = yt.to("cpu")
    with torch.no_grad(): # We require no computation graph
      self.C+=(yt*self.C.shape[1]+yp).bincount(minlength=self.C.numel()).view(self.C.shape).float()

  def clear(self):
    # We set the confusion matrix to zero
    self.C.zero_()

  # Computes the global accuracy
  def acc(self):
    return self.C.diag().sum().item()/self.C.sum()

  # Computes the class-averaged accuracy
  def mAcc(self):
    return (self.C.diag()/self.C.sum(-1)).mean().item()

  # Computers the class-averaged Intersection over Union
  def mIoU(self):
    return (self.C.diag()/(self.C.sum(0)+self.C.sum(1)-self.C.diag())).mean().item()

  # Returns the confusion matrix
  def confusion_matrix(self):
    return self.C

In [9]:
loaders = {"train": train_loader, "validation" : valid_loader, "test" : test_loader}

In [10]:
node_types, edge_types = X_train[0].metadata()

In [None]:
X_train[0].x_dict

In [None]:
node_types

In [None]:
edge_types

In [21]:
# from models.models import HGNN
import datetime

from torch import nn
from tqdm.notebook import tqdm


In [43]:
from typing_extensions import Self
from torch_geometric.nn import SAGEConv, HeteroConv, GATConv, Linear, GCNConv
from torch.nn import ModuleList, Module, Sequential, Softmax, Dropout
from torch import mean, stack, sum, concat


class HGNN(Module):

    def __init__(self, hid, out, layers, node_types, nodes_relations) -> Self:  # type: ignore
        super().__init__()

        # List of convolutional layers
        self.convs = ModuleList()
        for _ in range(layers):
            self.convs.append(
                HeteroConv(
                    {relation:SAGEConv((-1, -1), hid) for relation in nodes_relations}
                    ,
                    aggr="mean",
                )
            )
        print(nodes_relations)
        # Take each node hid representation and apply a linear layer
        self.linear_nodes = Linear(hid, hid)

        # Return the softmax with the class probabilities
        self.fc = Sequential(Linear(hid*len(node_types), out), Softmax())

    def forward(self, x_dict, edge_index_dict):
        
        x_dict = x_dict[0]
        edge_index_dict = edge_index_dict[0]
        # Convolutional layers
        for conv in self.convs:
            print("HEYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY")
            x_dict = conv(x_dict, edge_index_dict)
            print("HEYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY")
            x_dict = {key: x.relu() for key, x in x_dict.items()}

        print(x_dict)
        # Node features of each node in the graph
        nodes_features = [
            self.linear_nodes(x_dict[key]).relu() for key in x_dict.keys()
        ]
        print(nodes_features)
        # Global mean of each node type
        for i in range(len(nodes_features)):
            nodes_features[i] = mean(nodes_features[i], dim=0)

        # print(nodes_features)
        # print(concat(nodes_features))
        # Global mean pooling
        #nodes_features = mean(stack(nodes_features), dim=0)
        nodes_features = concat(nodes_features)
        nodes_features = self.fc(nodes_features)

        return nodes_features  # {key : self.linear(x_dict[key]) for key in x_dict.keys()}, nodes_features


In [44]:
model = HGNN(hid=64, out=len(list_activities), layers=4, node_types=node_types, nodes_relations=edge_types)
model

[('org:resource', 'related_to', 'org:resource'),
 ('Activity', 'followed_by', 'Activity'),
 ('time:timestamp', 'related_to', 'time:timestamp'),
 ('Activity', 'related_to', 'org:resource'),
 ('Activity', 'related_to', 'lifecycle:transition'),
 ('Activity', 'related_to', 'time:timestamp'),
 ('Activity', 'related_to', 'case:REG_DATE'),
 ('Activity', 'related_to', 'case:AMOUNT_REQ')]


HGNN(
  (convs): ModuleList(
    (0-3): 4 x HeteroConv(num_relations=8)
  )
  (linear_nodes): Linear(64, 64, bias=True)
  (fc): Sequential(
    (0): Linear(384, 10, bias=True)
    (1): Softmax(dim=None)
  )
)

In [None]:
X_train[0].x_dict.items()

In [None]:
X_train[0].edge_items()

In [None]:
from pprint import pprint as print
print("x_dict:")
print({k: v.shape for k, v in X_train[0].x_dict.items()})
print("edge_index_dict:")
print({k: v.shape for k, v in X_train[0].edge_index_dict.items()})


In [31]:
model([X_train[0].x_dict],[X_train[0].edge_index_dict])

'HEYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY'


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)

In [None]:
num_epochs = 1
best_accuracy = 0
early_stop_patience = 10
lr_value = 1e-2

best_model = None

num_runs = 1
running_time = []

metric_tracker = ClassificationMetrics(num_classes=len(list_activities))

for run in range(num_runs):
    
    start = datetime.datetime.now()
    print("Run: {}".format(run + 1))
    
    
    
    model = model.to(device)
    #print("compiling...")
    #model = torch.compile(model, dynamic=True)
    #print("Compiled model")
    
    criterion = nn.CrossEntropyLoss()
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_value) # momentum=0.9, weight_decay=1e-1)

    
    not_improved_count = 0
    
    for epoch in range(num_epochs):
        print(
            "\n-- EPOCH {}/{} -------------------------\n".format(epoch + 1, num_epochs)
        )
        torch.cuda.empty_cache()
        count_train = [0 for _ in range(len(list_activities))]
        count_val = [0 for _ in range(len(list_activities))]
        for state in ["train", "validation"]:
            if state == "train":
                model.train()
                metric_tracker.clear()
                # true_label = []
                # predictions = []
                # scores = []

            else:
                
                # print(metric_tracker.confusion_matrix())
                print(count_train)
                
                print("\tTRAIN | acc: {:.4f} | mAcc: {:.4f} | mIoU: {:.4f}".format(metric_tracker.acc(),
                                                                                   #metric_tracker.mAcc(),
                                                                                   #metric_tracker.mIoU()
                                                                                   0,0))
                
                # y_true = np.concatenate(true_label)
                # y_pred = np.concatenate(predictions)
                # scores = np.concatenate(scores)
                
                # print("TRAIN")
                # print(conta(y_true, 0), conta(y_true, 1))
                # print(conta(y_pred, 0), conta(y_pred, 1))
                
                # print_stats(y_pred, y_true, scores)
                metric_tracker.clear()
                model.eval()
                # true_label = []
                # predictions = []
                # scores = []
            
            running_loss = 0.0
            running_corrects = 0

            for i,(x,y) in tqdm(enumerate(loaders[state])):
                # print("X")
                # print("//"*50)
                # print(x)
                # print("//"*50)
                # print(y)
                x = x[0].to(device)
                print("//"*50)
                #print(x)
                #print(x.edge_index_dict)
                #print(x.x_dict)
                print("//"*50)
                
                y = y[0].to(device)
                
                #x = [[sub_item.to(device=device) for sub_item in item] for item in x]


                #y = torch.tensor([torch.max(yi,0)[1] for yi in y])

                #y = y.to(device)
                
                outputs = model([x.x_dict], [x.edge_index_dict])
                
         
                
                outputs = outputs.to(device)
                
                loss = criterion(outputs, y)

                if state == "train":
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                _, preds = torch.max(outputs, 1)                
                preds = preds.to(device)
                if state == "train":
                    for i in preds:
                        count_train[i] += 1
                else:
                    for i in preds:
                        count_val[i] += 1              
                metric_tracker.add(preds, y)
                
                
        print(count_val)
        print("\tEVAL  | acc: {:.4f} | mAcc: {:.4f} | mIoU: {:.4f}\n".format(metric_tracker.acc(),
                                                                             # metric_tracker.mAcc(),
                                                                             # metric_tracker.mIoU()
                                                                             # )
                                                                            0,0)   )     
                   

        if epoch == 0:# HERE WE KEEP BEST AUC VALUE
            best_accuracy = metric_tracker.acc()
            
        else:
            if metric_tracker.acc() > best_accuracy:
                print("SAVING MODEL..............\n")
                best_accuracy = metric_tracker.acc()
                not_improved_count = 0
            else:
                not_improved_count += 1

        if not_improved_count == early_stop_patience:
            print(
                "Validation performance didn't improve for {} epochs. "
                "Training stops.".format(early_stop_patience)
            )
            break

    running_time.append((datetime.datetime.now() - start).total_seconds())