In [308]:
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.datasets import TUDataset
import os.path as osp
import os

# Build the dataset root: the parent of the (symlink-resolved) working
# directory, then one more level up via '../', then the 'datasets' folder.
current_dir = os.getcwd()
parent_dir = osp.dirname(osp.realpath(current_dir))
path = osp.join(parent_dir, '../', 'datasets')

# Load the Mutagenicity TUDataset rooted at `path` and shuffle the graph order.
# NOTE(review): no random seed is set, so the order differs between runs.
mutagenicity = TUDataset(path, name='Mutagenicity')
dataset = mutagenicity.shuffle()
print(len(dataset))

4337


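Create new graphs from the uncommon nodes: for each graph, keep only the nodes whose feature argmax is not one of the common element indices, connect every pair of them, and append the resulting graphs to the original dataset.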

In [303]:
# Feature-column indices (argmax of a node's feature row) treated as "common";
# nodes with any other argmax are the ones kept by create_new_graph.
common_element_indices = {0, 1, 3, 4}


def create_new_graph(data):
    """Build a graph made only of the uncommon nodes of ``data``.

    A node is kept when the argmax of its feature row is not in
    ``common_element_indices``. The kept nodes are connected with one
    directed edge ``(i, j)`` for every pair ``i < j`` (the reverse direction
    is not added), and the graph label ``data.y`` is carried over unchanged.

    Args:
        data: object exposing a 2-D node-feature tensor ``x`` and a label ``y``.

    Returns:
        A ``Data`` object with the selected node features, the pairwise
        ``edge_index`` and the original label, or ``None`` when fewer than
        two nodes are selected (no pair exists to connect).
    """
    rare_rows = [
        row for row in data.x
        if torch.argmax(row).item() not in common_element_indices
    ]

    # Zero or one selected node means there is no edge to create.
    node_count = len(rare_rows)
    if node_count < 2:
        return None

    node_features = torch.stack(rare_rows)
    # Strict upper-triangle indices enumerate all pairs i < j, ordered by i then j.
    pair_index = torch.triu_indices(node_count, node_count, offset=1).contiguous()

    return Data(x=node_features, edge_index=pair_index, y=data.y)

# Build each derived graph exactly once (the generator is consumed lazily) and
# keep only the graphs that could actually be constructed.
candidate_graphs = (create_new_graph(data) for data in dataset)
new_graphs = [graph for graph in candidate_graphs if graph is not None]
print("Created " + str(len(new_graphs)) + " new graphs")

# Original graphs followed by the derived ones.
new_dataset = dataset + new_graphs

# With batch_size=1 the loader length equals the number of graphs.
new_loader = DataLoader(new_dataset, batch_size=1, shuffle=True)

print("New dataset: ", len(new_loader))

Created 677 new graphs
New dataset:  5014


Test with the trained model:

In [304]:
import torch
import torch.nn as nn
import torch.functional as F
import re
import itertools
import sys
sys.path.append("/Users/raffaelepojer/Dev/RBN-GNN/python/")
from gnn.ACR_graph import *

# Instantiate the network with the hyperparameters that match the checkpoint
# loaded below (its file name carries the `16_8_8` and `add` tags).
model = MYACRGnnGraph(
    input_dim=14,
    hidden_dim=[16,8,8],
    num_layers=3,
    mlp_layers=0,
    final_read="add",
    num_classes=2,
    fwd_dp=0.15,
    lin_dp=0.15,
    mlp_dp=0.0
)

criterion = nn.CrossEntropyLoss()

# NOTE(review): absolute, machine-specific path; it must exist on the machine
# running this notebook.
checkpoint_path = "/Users/raffaelepojer/Dev/RBN-GNN/models/Mutagenicity_16_8_8_20230814-204701/exp_33/rbn_acr_graph_Mutagenicity_16_8_8_add.pt"
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# machine; the evaluation below runs on CPU.
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

MYACRGnnGraph(
  (layers): ModuleList(
    (0-2): 3 x MYACRConv()
  )
  (linear): Linear(in_features=8, out_features=2, bias=True)
)

In [305]:
def test(model, data, device="cpu"):
    model.eval()
    tp = 0
    fp = 0
    tn = 0
    fn = 0

    for batch in data:
        batch = batch.to(device)
        with torch.no_grad():
            out = model(batch.x, batch.edge_index, batch.batch)
        target = batch.y

        _, pred = out.max(1)

        tp += ((pred == 1) & (target == 1)).sum().item()
        fp += ((pred == 1) & (target == 0)).sum().item()
        tn += ((pred == 0) & (target == 0)).sum().item()
        fn += ((pred == 0) & (target == 1)).sum().item()

    accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return accuracy, precision, recall, f1_score, tn, tp, fn, fp

In [306]:
# Evaluate the trained model on the derived (uncommon-node) graphs only.
results = test(model, new_graphs)
accuracy, precision, recall, f1_score, tn, tp, fn, fp = results

report_lines = [
    f"Accuracy: {accuracy:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1_score:.4f}",
    f"True Negatives: {tn}",
    f"True Positives: {tp}",
    f"False Negatives: {fn}",
    f"False Positives: {fp}",
]
print("\n".join(report_lines))

Accuracy: 0.5214
Precision: 0.5282
Recall: 0.9361
F1 Score: 0.6754
True Negatives: 16
True Positives: 337
False Negatives: 23
False Positives: 301


In [307]:
# Evaluate the trained model on the original graphs plus the derived ones.
results = test(model, new_dataset)
accuracy, precision, recall, f1_score, tn, tp, fn, fp = results

report_lines = [
    f"Accuracy: {accuracy:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1_score:.4f}",
    f"True Negatives: {tn}",
    f"True Positives: {tp}",
    f"False Negatives: {fn}",
    f"False Positives: {fp}",
]
print("\n".join(report_lines))

Accuracy: 0.7563
Precision: 0.7077
Recall: 0.7970
F1 Score: 0.7497
True Negatives: 1962
True Positives: 1830
False Negatives: 466
False Positives: 756


In [299]:
# Evaluate the trained model on the original Mutagenicity graphs only.
results = test(model, dataset)
accuracy, precision, recall, f1_score, tn, tp, fn, fp = results

report_lines = [
    f"Accuracy: {accuracy:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1_score:.4f}",
    f"True Negatives: {tn}",
    f"True Positives: {tp}",
    f"False Negatives: {fn}",
    f"False Positives: {fp}",
]
print("\n".join(report_lines))

Accuracy: 0.7929
Precision: 0.7664
Recall: 0.7712
F1 Score: 0.7688
True Negatives: 1946
True Positives: 1493
False Negatives: 443
False Positives: 455
