In [None]:
%pip install networkx==2.5
%pip install  dgl -f https://data.dgl.ai/wheels/repo.html
%pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html
!pip install ipython-autotime
%load_ext autotime

In [None]:
# import required packages
from google.colab import drive
import matplotlib.pyplot as plt
import networkx as nx
import gzip, pickle
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
from dgl.data import DGLDataset
import torch as th
import json
from collections import defaultdict
import numpy as np
from numpy import array
from numpy import argmax
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from collections import OrderedDict
from dgl.nn.pytorch import GraphConv
import os
from tqdm import tqdm

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from dgl.dataloading import GraphDataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
time: 6.57 s (started: 2023-10-07 06:04:35 +00:00)


In [None]:
# get data from drive
drive.mount('/content/gdrive')
drive_path = "/content/gdrive/MyDrive/fyp_data/"

# set constant seed for consistency across multiple settings
torch.manual_seed(42)
dgl.seed(42)

device_name = torch.device('cpu')
print("Using {}.".format(device_name))

Mounted at /content/gdrive
Using cpu.
time: 43 s (started: 2023-10-07 06:04:42 +00:00)


In [None]:
node_types = ['argv', 'block', 'address', 'task', 'process_memory', 'file', 'socket', 'pipe', 'iattr', 'link', 'machine', 'path', 'shm']
edge_types = ['wasAssociatedW', 'used', 'wasGeneratedBy', 'wasInformedBy', 'wasDerivedFrom']

node_features_count = len(node_types)
edge_features_count = len(edge_types)

targets = np.array(node_types)
nodelabelEnc = LabelEncoder()
new_target = nodelabelEnc.fit_transform(targets)
nodeencoder = OneHotEncoder(sparse_output=False)
nodeencoder.fit(new_target.reshape(-1, 1))
nodeencoder = nodeencoder

targets = np.array(edge_types)
edgelabelEnc = LabelEncoder()
new_target = edgelabelEnc.fit_transform(targets)
edgeencoder = OneHotEncoder(sparse_output=False)
edgeencoder.fit(new_target.reshape(-1, 1))
edgeencoder = edgeencoder


def one_hot_node(data):
  new_target = nodelabelEnc.transform(np.array(data))
  return torch.from_numpy(nodeencoder.transform(new_target.reshape(-1, 1)))

def one_hot_edge(data):
  new_target = edgelabelEnc.transform(np.array(data))
  return torch.from_numpy(edgeencoder.transform(new_target.reshape(-1, 1)))


class ProvenanceDataset(DGLDataset):
    def __init__(self):
      super().__init__(name='provenance')

    def read_graph(self, file_name):
      graph_raw = json.load(open(file_name,"r"))

      list_of_nodes = defaultdict(set)

      node_types = {}
      edge_types = []

      out_edges = []
      in_edges = []

      for key, value in graph_raw.items():
        node1_type, edge_type, node2_type = key.split("-")
        node1_index, node2_index = value
        for i in node1_index:
          node_types[i] = node1_type
          edge_types.append(edge_type)
        for i in node2_index:
          node_types[i] = node2_type


        out_edges = out_edges + node1_index
        in_edges = in_edges + node2_index

      node_types = OrderedDict(sorted(node_types.items()))
      number_of_edges = list(node_types.keys())[-1] + 1

      for i in range(number_of_edges):
        if i not in node_types:
          node_types[i] = 'task'

      out_edges = th.tensor(out_edges)
      in_edges = th.tensor(in_edges)

      g = dgl.graph((out_edges, in_edges))

      g.ndata['attr'] = (one_hot_node(list(node_types.values())))
      g.edata['attr'] = (one_hot_edge(edge_types))

      g = dgl.add_self_loop(g)

      return g

    def process(self):
        self.graphs = []
        self.labels = []

        # attack graph ID
        for graph_id in tqdm(range(1, 501)):
          g = self.read_graph(f"{drive_path}attack_2/graph{graph_id}.json")
          self.graphs.append(g)
          self.labels.append(1)

        # benign graph ID
        for graph_id in tqdm(range(1, 1001)):
          g = self.read_graph(f"{drive_path}benign_full/graph{graph_id}.json")
          self.graphs.append(g)
          self.labels.append(0)

        # attack graph ID
        for graph_id in tqdm(range(1, 501)):
          g = self.read_graph(f"{drive_path}attack/graph{graph_id}.json")
          self.graphs.append(g)
          self.labels.append(1)

        # Convert the label list to tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)

dataset = ProvenanceDataset()
graph, label = dataset[0]
print(graph, label, graph.device, label.device)

100%|██████████| 500/500 [03:02<00:00,  2.74it/s]
100%|██████████| 1000/1000 [05:29<00:00,  3.03it/s]
100%|██████████| 500/500 [03:00<00:00,  2.76it/s]

Graph(num_nodes=3643, num_edges=20902,
      ndata_schemes={'attr': Scheme(shape=(13,), dtype=torch.float64)}
      edata_schemes={'attr': Scheme(shape=(5,), dtype=torch.float64)}) tensor(1) cpu cpu
time: 11min 33s (started: 2023-10-07 06:05:25 +00:00)





In [None]:
from torch.utils.data.dataloader import default_collate

batch_size = 4
train_set, test_set, validation_set = dgl.data.utils.split_dataset(dataset, [0.8, 0.1, 0.1], shuffle=True, random_state=42)
data_loader = GraphDataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
validation_loader = GraphDataLoader(validation_set, batch_size=1, shuffle=True, pin_memory=True)
test_loader = GraphDataLoader(test_set, batch_size=1, shuffle=True, pin_memory=True)

time: 4.09 ms (started: 2023-10-07 06:28:25 +00:00)


In [12]:
class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes, hidden_layer):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.hidden_layer = hidden_layer
        self.hidden_dim = hidden_dim
        if self.hidden_layer == 2:
          self.conv2 = GraphConv(hidden_dim, hidden_dim)
        if self.hidden_layer == 3:
          self.conv3 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # Use node degree as the initial node feature. For undirected graphs, the in-degree
        # is the same as the out_degree.
        h = g.ndata['attr'].float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        if self.hidden_layer == 2:
          h = F.relu(self.conv2(g, h))
        if self.hidden_layer == 3:
          h = F.relu(self.conv3(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')

        return self.classify(hg)

    def print_params(self):
        print(f"{self.hidden_layer} hidden layers. {self.hidden_dim} hidden dim.")

time: 2.56 ms (started: 2023-10-07 12:18:53 +00:00)


In [13]:
def compute_metrics(preds: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5):
  is_multiclass = labels.max().item() > 1
  if is_multiclass:
      preds = torch.argmax(preds, dim=-1)
      probs = preds.tolist()  # Predicted class not raw probs
  else:
      probs = preds.tolist()
      preds = (preds > threshold).float()

  return {
      'accuracy': accuracy_score(preds, labels),
      'precision': precision_score(preds, labels, average='micro' if is_multiclass else 'binary'),
      'recall': recall_score(preds, labels, average='micro' if is_multiclass else 'binary'),
      'F1 micro': f1_score(preds, labels, average='micro'),
      'F1 macro': f1_score(preds, labels, average='macro'),
      'probs': probs,
      'labels': labels.tolist(),
  }

def eval_model(function_model, dataset_loader, name):
  predictions = []
  labels = []

  with torch.no_grad():
      for iter, (bg, label) in enumerate(dataset_loader):
          prediction = function_model(bg)
          probs_Y = torch.softmax(prediction, 1)
          argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
          predictions.append(argmax_Y)
          labels.append(label)

  argmax_Y = torch.cat(predictions, dim=0)
  test_Y = torch.cat(labels, dim=0)

  print(f"{name}: {accuracy_score(argmax_Y, test_Y)}")

  return

time: 1.34 ms (started: 2023-10-07 12:18:54 +00:00)


In [16]:
models = []

for hidden_layer in [3]:
  for hidden_dim in [128, 256, 512]:
    model = Classifier(13, hidden_dim, 2, hidden_layer)
    models.append(model)
# Create model
# model = Classifier(13, 256, 2)
# use pretrained model
# model.load_state_dict(torch.load(f"{drive_path}model.pth"))


time: 10.5 ms (started: 2023-10-07 12:19:10 +00:00)


In [17]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def compute_metrics(preds: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5):
  is_multiclass = labels.max().item() > 1
  if is_multiclass:
      preds = torch.argmax(preds, dim=-1)
      probs = preds.tolist()  # Predicted class not raw probs
  else:
      probs = preds.tolist()
      preds = (preds > threshold).float()

  return {
      'accuracy': accuracy_score(preds, labels),
      'precision': precision_score(preds, labels, average='micro' if is_multiclass else 'binary'),
      'recall': recall_score(preds, labels, average='micro' if is_multiclass else 'binary'),
      'F1 micro': f1_score(preds, labels, average='micro'),
      'F1 macro': f1_score(preds, labels, average='macro')
  }

#torch.save(model.state_dict(), "model.pth")
def print_metrics(model, loader):
  predictions = []
  labels = []

  with torch.no_grad():
      model.eval()
      for iter, (bg, label) in enumerate(loader):
          prediction = model(bg)
          probs_Y = torch.softmax(prediction, 1)
          argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
          predictions.append(argmax_Y)
          labels.append(label)
      model.train()

  argmax_Y = torch.cat(predictions, dim=0)
  test_Y = torch.cat(labels, dim=0)

  metrics = compute_metrics(argmax_Y, test_Y)
  for metric, value in metrics.items():
    print(f"{metric}: {value}")

time: 1.64 ms (started: 2023-10-07 12:19:16 +00:00)


In [None]:
for model in models:
  model.print_params()
  loss_func = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=0.001)

  # prepares for training
  model.train()
  epoch_losses = []
  best_validation_loss = 100000000000
  epochs_without_gain = 0
  end_training = False

  for epoch in tqdm(range(50)):
      epoch_loss = 0
      for i in enumerate(data_loader):
          iter, (bg, label) = i
          prediction = model(bg)
          loss = loss_func(prediction, label)
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          epoch_loss += loss.detach().item()

      model.eval()

      epoch_loss /= (iter + 1)

      validation_loss = 0
      for iter, (bg, label) in enumerate(validation_loader):
          prediction = model(bg)
          validation_loss += loss_func(prediction, label).detach().item()

      validation_loss /= (iter + 1)

      if validation_loss < best_validation_loss and epoch >= 30:
        epochs_without_gain = 0
        best_validation_loss = validation_loss
        eval_model(model, data_loader, "train")
        eval_model(model, validation_loader, "validation")
        eval_model(model, test_loader, "test")
        torch.save(model.state_dict(), f"{drive_path}model{model.hidden_dim}_{model.hidden_layer}.pth")
        print_metrics(model, test_loader)
      #elif epoch >= 30:
      #  epochs_without_gain += 1
      #  if epochs_without_gain > 5:
      #    end_training = True

      model.train()

      print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
      epoch_losses.append(epoch_loss)

      if end_training:
        break

3 hidden layers. 128 hidden dim.


  2%|▏         | 1/50 [00:38<31:37, 38.73s/it]

Epoch 0, loss 0.6864


  4%|▍         | 2/50 [01:19<31:46, 39.72s/it]

Epoch 1, loss 0.6572


  6%|▌         | 3/50 [01:58<30:57, 39.52s/it]

Epoch 2, loss 0.6469


  8%|▊         | 4/50 [02:38<30:34, 39.89s/it]

Epoch 3, loss 0.6452


 10%|█         | 5/50 [03:18<29:45, 39.69s/it]

Epoch 4, loss 0.6445


 12%|█▏        | 6/50 [03:58<29:11, 39.81s/it]

Epoch 5, loss 0.6426


 14%|█▍        | 7/50 [04:37<28:20, 39.55s/it]

Epoch 6, loss 0.6441


 16%|█▌        | 8/50 [05:17<27:47, 39.70s/it]

Epoch 7, loss 0.6380


 18%|█▊        | 9/50 [05:56<26:56, 39.43s/it]

Epoch 8, loss 0.6365


 20%|██        | 10/50 [06:35<26:21, 39.53s/it]

Epoch 9, loss 0.6343


 22%|██▏       | 11/50 [07:14<25:32, 39.29s/it]

Epoch 10, loss 0.6338


 24%|██▍       | 12/50 [07:54<25:00, 39.48s/it]

Epoch 11, loss 0.6352


 26%|██▌       | 13/50 [08:33<24:12, 39.25s/it]

Epoch 12, loss 0.6315


 28%|██▊       | 14/50 [09:13<23:39, 39.42s/it]

Epoch 13, loss 0.6313


 30%|███       | 15/50 [09:52<22:54, 39.28s/it]

Epoch 14, loss 0.6296


 32%|███▏      | 16/50 [10:31<22:13, 39.23s/it]

Epoch 15, loss 0.6303


 34%|███▍      | 17/50 [11:10<21:39, 39.39s/it]

Epoch 16, loss 0.6298


 36%|███▌      | 18/50 [11:50<21:05, 39.56s/it]

Epoch 17, loss 0.6260


 38%|███▊      | 19/50 [12:30<20:23, 39.46s/it]

Epoch 18, loss 0.6264


 40%|████      | 20/50 [13:08<19:36, 39.22s/it]

Epoch 19, loss 0.6256


 42%|████▏     | 21/50 [13:48<19:04, 39.45s/it]

Epoch 20, loss 0.6228


 44%|████▍     | 22/50 [14:28<18:25, 39.47s/it]

Epoch 21, loss 0.6169


 46%|████▌     | 23/50 [15:08<17:48, 39.59s/it]

Epoch 22, loss 0.6236


 48%|████▊     | 24/50 [15:47<17:05, 39.45s/it]

Epoch 23, loss 0.6208


 50%|█████     | 25/50 [16:27<16:29, 39.59s/it]

Epoch 24, loss 0.6196


 52%|█████▏    | 26/50 [17:06<15:47, 39.47s/it]

Epoch 25, loss 0.6194


 54%|█████▍    | 27/50 [17:45<15:08, 39.51s/it]

Epoch 26, loss 0.6130


 56%|█████▌    | 28/50 [18:24<14:26, 39.37s/it]

Epoch 27, loss 0.6147


 58%|█████▊    | 29/50 [19:04<13:44, 39.28s/it]

Epoch 28, loss 0.6118


 60%|██████    | 30/50 [19:43<13:04, 39.21s/it]

Epoch 29, loss 0.6133
train: 0.679375
validation: 0.64
test: 0.645


 62%|██████▏   | 31/50 [20:48<14:54, 47.06s/it]

accuracy: 0.645
precision: 0.6666666666666666
recall: 0.660377358490566
F1 micro: 0.645
F1 macro: 0.6439228666716819
Epoch 30, loss 0.6142
train: 0.68625
validation: 0.65
test: 0.655


 64%|██████▍   | 32/50 [21:52<15:40, 52.24s/it]

accuracy: 0.655
precision: 0.580952380952381
recall: 0.7093023255813954
F1 micro: 0.655
F1 macro: 0.6542999574137629
Epoch 31, loss 0.6073


 66%|██████▌   | 33/50 [22:31<13:38, 48.16s/it]

Epoch 32, loss 0.6118


 68%|██████▊   | 34/50 [23:10<12:07, 45.49s/it]

Epoch 33, loss 0.6133


 70%|███████   | 35/50 [23:49<10:51, 43.44s/it]

Epoch 34, loss 0.6090


 72%|███████▏  | 36/50 [24:28<09:50, 42.21s/it]

Epoch 35, loss 0.6093


 74%|███████▍  | 37/50 [25:08<08:57, 41.35s/it]

Epoch 36, loss 0.6051
train: 0.68875
validation: 0.675
test: 0.68


 76%|███████▌  | 38/50 [26:12<09:40, 48.40s/it]

accuracy: 0.68
precision: 0.5523809523809524
recall: 0.7733333333333333
F1 micro: 0.68
F1 macro: 0.6767676767676768
Epoch 37, loss 0.6046


 78%|███████▊  | 39/50 [26:51<08:20, 45.48s/it]

Epoch 38, loss 0.6037


 80%|████████  | 40/50 [27:34<07:27, 44.72s/it]

Epoch 39, loss 0.6042


 82%|████████▏ | 41/50 [28:14<06:29, 43.26s/it]

Epoch 40, loss 0.6019


 84%|████████▍ | 42/50 [28:52<05:34, 41.85s/it]

Epoch 41, loss 0.6039
train: 0.694375
validation: 0.66
test: 0.685


 86%|████████▌ | 43/50 [29:56<05:38, 48.38s/it]

accuracy: 0.685
precision: 0.5904761904761905
recall: 0.7560975609756098
F1 micro: 0.685
F1 macro: 0.6836634781953755
Epoch 42, loss 0.6012


 88%|████████▊ | 44/50 [30:35<04:33, 45.63s/it]

Epoch 43, loss 0.6021


 90%|█████████ | 45/50 [31:16<03:40, 44.01s/it]

Epoch 44, loss 0.6017


 92%|█████████▏| 46/50 [31:55<02:50, 42.59s/it]

Epoch 45, loss 0.5968


 94%|█████████▍| 47/50 [32:37<02:07, 42.43s/it]

Epoch 46, loss 0.5985


 96%|█████████▌| 48/50 [33:17<01:23, 41.74s/it]

Epoch 47, loss 0.5961


 98%|█████████▊| 49/50 [33:57<00:41, 41.10s/it]

Epoch 48, loss 0.5978


100%|██████████| 50/50 [34:36<00:00, 41.53s/it]


Epoch 49, loss 0.5940
3 hidden layers. 256 hidden dim.


  2%|▏         | 1/50 [01:42<1:23:23, 102.12s/it]

Epoch 0, loss 0.6831


  4%|▍         | 2/50 [03:24<1:21:57, 102.45s/it]

Epoch 1, loss 0.6604


  6%|▌         | 3/50 [05:07<1:20:12, 102.40s/it]

Epoch 2, loss 0.6532


  8%|▊         | 4/50 [06:47<1:17:56, 101.65s/it]

Epoch 3, loss 0.6425


 10%|█         | 5/50 [08:32<1:16:59, 102.66s/it]

Epoch 4, loss 0.6426


 12%|█▏        | 6/50 [10:13<1:15:03, 102.34s/it]

Epoch 5, loss 0.6397


 14%|█▍        | 7/50 [11:56<1:13:18, 102.30s/it]

Epoch 6, loss 0.6433


 16%|█▌        | 8/50 [13:39<1:11:50, 102.63s/it]

Epoch 7, loss 0.6369


 18%|█▊        | 9/50 [15:20<1:09:51, 102.22s/it]

Epoch 8, loss 0.6347


 20%|██        | 10/50 [17:04<1:08:26, 102.66s/it]

Epoch 9, loss 0.6349


 22%|██▏       | 11/50 [18:47<1:06:51, 102.86s/it]

Epoch 10, loss 0.6329


 24%|██▍       | 12/50 [20:29<1:05:01, 102.66s/it]

Epoch 11, loss 0.6300


 26%|██▌       | 13/50 [22:12<1:03:18, 102.66s/it]

Epoch 12, loss 0.6281


 28%|██▊       | 14/50 [23:54<1:01:28, 102.45s/it]

Epoch 13, loss 0.6225


 30%|███       | 15/50 [25:38<59:57, 102.77s/it]  

Epoch 14, loss 0.6208


 32%|███▏      | 16/50 [27:19<58:00, 102.37s/it]

Epoch 15, loss 0.6186


 34%|███▍      | 17/50 [29:03<56:38, 103.00s/it]

Epoch 16, loss 0.6172


 36%|███▌      | 18/50 [30:45<54:39, 102.50s/it]

Epoch 17, loss 0.6136


 38%|███▊      | 19/50 [32:28<53:08, 102.86s/it]

Epoch 18, loss 0.6137


 40%|████      | 20/50 [34:10<51:14, 102.50s/it]

Epoch 19, loss 0.6100


 42%|████▏     | 21/50 [35:55<49:49, 103.08s/it]

Epoch 20, loss 0.6133


 44%|████▍     | 22/50 [37:36<47:53, 102.63s/it]

Epoch 21, loss 0.6074


 46%|████▌     | 23/50 [39:21<46:26, 103.21s/it]

Epoch 22, loss 0.6100


 48%|████▊     | 24/50 [41:02<44:30, 102.73s/it]

Epoch 23, loss 0.6098


 50%|█████     | 25/50 [42:46<42:57, 103.10s/it]

Epoch 24, loss 0.6056


 52%|█████▏    | 26/50 [44:28<41:02, 102.60s/it]

Epoch 25, loss 0.6063


 54%|█████▍    | 27/50 [46:12<39:28, 102.98s/it]

Epoch 26, loss 0.6055


 56%|█████▌    | 28/50 [47:53<37:35, 102.54s/it]

Epoch 27, loss 0.6028


 58%|█████▊    | 29/50 [49:37<36:00, 102.86s/it]

Epoch 28, loss 0.6037


 60%|██████    | 30/50 [51:20<34:17, 102.85s/it]

Epoch 29, loss 0.6019
train: 0.68125
validation: 0.66
test: 0.655


 62%|██████▏   | 31/50 [53:58<37:49, 119.46s/it]

accuracy: 0.655
precision: 0.7047619047619048
recall: 0.6607142857142857
F1 micro: 0.655
F1 macro: 0.6524892347208582
Epoch 30, loss 0.5996


 64%|██████▍   | 32/50 [55:44<34:40, 115.56s/it]

Epoch 31, loss 0.5958
train: 0.69375
validation: 0.65
test: 0.655


 66%|██████▌   | 33/50 [58:25<36:33, 129.04s/it]

accuracy: 0.655
precision: 0.6952380952380952
recall: 0.6636363636363637
F1 micro: 0.655
F1 macro: 0.6530483972344437
Epoch 32, loss 0.6002
train: 0.6925
validation: 0.685
test: 0.69


 68%|██████▊   | 34/50 [1:01:05<36:53, 138.34s/it]

accuracy: 0.69
precision: 0.5619047619047619
recall: 0.7866666666666666
F1 micro: 0.69
F1 macro: 0.6868686868686869
Epoch 33, loss 0.5939
train: 0.703125
validation: 0.67
test: 0.685


 70%|███████   | 35/50 [1:03:44<36:09, 144.65s/it]

accuracy: 0.685
precision: 0.6190476190476191
recall: 0.7386363636363636
F1 micro: 0.685
F1 macro: 0.6846136517233611
Epoch 34, loss 0.5951


 72%|███████▏  | 36/50 [1:05:28<30:55, 132.54s/it]

Epoch 35, loss 0.5916


 74%|███████▍  | 37/50 [1:07:12<26:52, 124.01s/it]

Epoch 36, loss 0.5899


 76%|███████▌  | 38/50 [1:08:58<23:40, 118.36s/it]

Epoch 37, loss 0.5883
train: 0.713125
validation: 0.69
test: 0.695


 78%|███████▊  | 39/50 [1:11:43<24:15, 132.31s/it]

accuracy: 0.695
precision: 0.6190476190476191
recall: 0.7558139534883721
F1 micro: 0.695
F1 macro: 0.6943811217715875
Epoch 38, loss 0.5842


 80%|████████  | 40/50 [1:13:26<20:37, 123.80s/it]

Epoch 39, loss 0.5830


 82%|████████▏ | 41/50 [1:15:12<17:44, 118.31s/it]

Epoch 40, loss 0.5783


 84%|████████▍ | 42/50 [1:16:56<15:11, 113.90s/it]

Epoch 41, loss 0.5803


 86%|████████▌ | 43/50 [1:18:42<13:01, 111.68s/it]

Epoch 42, loss 0.5753


 88%|████████▊ | 44/50 [1:20:26<10:56, 109.44s/it]

Epoch 43, loss 0.5745


 90%|█████████ | 45/50 [1:22:14<09:04, 108.92s/it]

Epoch 44, loss 0.5742


 92%|█████████▏| 46/50 [1:23:58<07:10, 107.55s/it]

Epoch 45, loss 0.5716


 94%|█████████▍| 47/50 [1:25:43<05:20, 106.67s/it]

Epoch 46, loss 0.5685


 96%|█████████▌| 48/50 [1:27:29<03:32, 106.42s/it]

Epoch 47, loss 0.5662


 98%|█████████▊| 49/50 [1:29:14<01:45, 105.94s/it]

Epoch 48, loss 0.5668


100%|██████████| 50/50 [1:30:58<00:00, 109.18s/it]


Epoch 49, loss 0.5689
3 hidden layers. 512 hidden dim.


  2%|▏         | 1/50 [06:38<5:25:40, 398.79s/it]

Epoch 0, loss 0.6866


  4%|▍         | 2/50 [13:02<5:12:07, 390.16s/it]

Epoch 1, loss 0.6675


  6%|▌         | 3/50 [19:37<5:07:08, 392.10s/it]

Epoch 2, loss 0.6493


  8%|▊         | 4/50 [26:10<5:00:51, 392.42s/it]

Epoch 3, loss 0.6479


 10%|█         | 5/50 [32:46<4:55:21, 393.82s/it]

Epoch 4, loss 0.6420


 12%|█▏        | 6/50 [39:20<4:48:54, 393.96s/it]

Epoch 5, loss 0.6372


 14%|█▍        | 7/50 [45:49<4:41:08, 392.30s/it]

Epoch 6, loss 0.6366


 16%|█▌        | 8/50 [52:28<4:36:02, 394.34s/it]

Epoch 7, loss 0.6298


 18%|█▊        | 9/50 [58:58<4:28:38, 393.14s/it]

Epoch 8, loss 0.6281


 20%|██        | 10/50 [1:05:29<4:21:37, 392.44s/it]

Epoch 9, loss 0.6229


 22%|██▏       | 11/50 [1:12:03<4:15:17, 392.77s/it]

Epoch 10, loss 0.6256


 24%|██▍       | 12/50 [1:18:34<4:08:22, 392.18s/it]

Epoch 11, loss 0.6206


 26%|██▌       | 13/50 [1:25:00<4:00:41, 390.32s/it]

Epoch 12, loss 0.6189


 28%|██▊       | 14/50 [1:31:29<3:53:58, 389.97s/it]

Epoch 13, loss 0.6173


 30%|███       | 15/50 [1:38:00<3:47:41, 390.34s/it]

Epoch 14, loss 0.6154


 32%|███▏      | 16/50 [1:44:29<3:40:54, 389.83s/it]

Epoch 15, loss 0.6159


 34%|███▍      | 17/50 [1:51:03<3:35:13, 391.31s/it]

Epoch 16, loss 0.6147


 36%|███▌      | 18/50 [1:57:40<3:29:31, 392.87s/it]

Epoch 17, loss 0.6113


 38%|███▊      | 19/50 [2:04:07<3:22:08, 391.26s/it]

Epoch 18, loss 0.6146


 40%|████      | 20/50 [2:10:43<3:16:13, 392.46s/it]

Epoch 19, loss 0.6109


 42%|████▏     | 21/50 [2:17:19<3:10:18, 393.75s/it]

Epoch 20, loss 0.6073


 44%|████▍     | 22/50 [2:23:49<3:03:10, 392.50s/it]

Epoch 21, loss 0.6094


 46%|████▌     | 23/50 [2:30:29<2:57:42, 394.89s/it]

Epoch 22, loss 0.6094


 48%|████▊     | 24/50 [2:37:07<2:51:27, 395.67s/it]

Epoch 23, loss 0.6096


 50%|█████     | 25/50 [2:43:34<2:43:45, 393.03s/it]

Epoch 24, loss 0.6061


 52%|█████▏    | 26/50 [2:50:10<2:37:33, 393.91s/it]

Epoch 25, loss 0.6037


 54%|█████▍    | 27/50 [2:56:44<2:31:03, 394.05s/it]

Epoch 26, loss 0.6062


 56%|█████▌    | 28/50 [3:03:21<2:24:46, 394.85s/it]

Epoch 27, loss 0.6036


 58%|█████▊    | 29/50 [3:09:48<2:17:20, 392.41s/it]

Epoch 28, loss 0.5994


 60%|██████    | 30/50 [3:16:18<2:10:34, 391.74s/it]

Epoch 29, loss 0.5996
train: 0.6975
validation: 0.665
test: 0.67


 62%|██████▏   | 31/50 [3:26:14<2:23:30, 453.19s/it]

accuracy: 0.67
precision: 0.6
recall: 0.7241379310344828
F1 micro: 0.67
F1 macro: 0.6694711538461539
Epoch 30, loss 0.6004
train: 0.698125
validation: 0.685
test: 0.69


 64%|██████▍   | 32/50 [3:35:56<2:27:33, 491.86s/it]

accuracy: 0.69
precision: 0.5714285714285714
recall: 0.7792207792207793
F1 micro: 0.69
F1 macro: 0.6874684948079444
Epoch 31, loss 0.5952


In [None]:
for model in models:
  torch.save(model.state_dict(), f"{drive_path}model{model.hidden_dim}_{model.hidden_layer}.pth")

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def compute_metrics(preds: torch.Tensor, labels: torch.Tensor, threshold: float = 0.5):
  is_multiclass = labels.max().item() > 1
  if is_multiclass:
      preds = torch.argmax(preds, dim=-1)
      probs = preds.tolist()  # Predicted class not raw probs
  else:
      probs = preds.tolist()
      preds = (preds > threshold).float()

  return {
      'accuracy': accuracy_score(preds, labels),
      'precision': precision_score(preds, labels, average='micro' if is_multiclass else 'binary'),
      'recall': recall_score(preds, labels, average='micro' if is_multiclass else 'binary'),
      'F1 micro': f1_score(preds, labels, average='micro'),
      'F1 macro': f1_score(preds, labels, average='macro'),
      'probs': probs,
      'labels': labels.tolist(),
  }

for model in models:
  model.eval()
  test_X, test_Y = map(list, zip(*test_set))
  test_bg = dgl.batch(test_X)
  test_Y = torch.tensor(test_Y).float().view(-1, 1)
  pred = model(test_bg)
  probs_Y = torch.softmax(pred, 1)
  argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)

  metrics = compute_metrics(argmax_Y, test_Y)
  for metric, value in metrics.items():
    print(f"{metric}: {value}")