In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from gnninterpreter import *

In [3]:
import torch
from tqdm.auto import trange

# Cyclicity

In [4]:
# Build the seeded cyclicity dataset and split it into train / validation parts.
# NOTE(review): the meaning of k=10 (fold count vs. split ratio) is defined by
# gnninterpreter's train_test_split — confirm there.
cyclicity = CyclicityDataset(seed=12345)
cyclicity_train, cyclicity_val = cyclicity.train_test_split(k=10)

# NNConv classifier whose input/output sizes come from the dataset's
# node-, edge- and graph-class vocabularies.
cyclicity_model = NNConvClassifier(
    node_features=len(cyclicity.NODE_CLS),
    edge_features=len(cyclicity.EDGE_CLS),
    num_classes=len(cyclicity.GRAPH_CLS),
    hidden_channels=32,
)

In [5]:
# Restore the trained cyclicity weights. map_location='cpu' lets a checkpoint that was
# saved from a GPU session load on a CPU-only machine; load_state_dict then copies the
# tensors into cyclicity_model's existing parameters, wherever those live.
cyclicity_model.load_state_dict(torch.load('ckpts/cyclicity.pt', map_location='cpu'))

<All keys matched successfully>

In [None]:
# Train the cyclicity classifier for 128 epochs, logging train loss plus accuracy and
# per-class F1 on both the train and validation splits after every epoch.
for epoch in trange(128):
    train_loss = cyclicity_train.model_fit(cyclicity_model, lr=0.001)
    train_metrics = cyclicity_train.model_evaluate(cyclicity_model)
    val_metrics = cyclicity_val.model_evaluate(cyclicity_model)
    # model_evaluate's 'f1' entry is a {class_name: score} dict; print each score to
    # 4 decimals (like loss/acc) instead of dumping full-precision float reprs.
    train_f1 = ", ".join(f"{name}={score:.4f}" for name, score in train_metrics['f1'].items())
    val_f1 = ", ".join(f"{name}={score:.4f}" for name, score in val_metrics['f1'].items())
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: [{train_f1}], "
          f"Test F1: [{val_f1}]")

In [None]:
# Left commented out: uncommenting it overwrites ckpts/cyclicity.pt, the same file
# that was restored into cyclicity_model earlier in this section.
# NOTE(review): presumably disabled to protect that checkpoint — confirm before enabling.
# torch.save(cyclicity_model.state_dict(), 'ckpts/cyclicity.pt')

# Motif

In [10]:
# Build the seeded motif dataset and split it into train / validation parts.
# NOTE(review): the meaning of k=10 is defined by gnninterpreter's train_test_split.
motif = MotifDataset(seed=12345)
motif_train, motif_val = motif.train_test_split(k=10)

# 3-layer GCN classifier sized from the dataset's node- and graph-class vocabularies.
motif_model = GCNClassifier(
    node_features=len(motif.NODE_CLS),
    num_classes=len(motif.GRAPH_CLS),
    hidden_channels=64,
    num_layers=3,
)

In [5]:
# Restore the trained motif weights. map_location='cpu' lets a checkpoint that was
# saved from a GPU session load on a CPU-only machine; load_state_dict then copies the
# tensors into motif_model's existing parameters, wherever those live.
motif_model.load_state_dict(torch.load('ckpts/motif.pt', map_location='cpu'))

<All keys matched successfully>

In [12]:
# Train the motif classifier for 64 epochs, logging train loss plus accuracy and
# per-class F1 on both the train and validation splits after every epoch.
# NOTE(review): lr=0.0001 is 10x lower than the other datasets' loops — presumably
# fine-tuning the restored checkpoint; confirm.
# trange (not range) so this loop shows a progress bar like the other training cells.
for epoch in trange(64):
    train_loss = motif_train.model_fit(motif_model, lr=0.0001)
    train_metrics = motif_train.model_evaluate(motif_model)
    val_metrics = motif_val.model_evaluate(motif_model)
    # 'f1' is a {class_name: score} dict; print each score to 4 decimals (like
    # loss/acc) instead of dumping full-precision float reprs.
    train_f1 = ", ".join(f"{name}={score:.4f}" for name, score in train_metrics['f1'].items())
    val_f1 = ", ".join(f"{name}={score:.4f}" for name, score in val_metrics['f1'].items())
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: [{train_f1}], "
          f"Test F1: [{val_f1}]")

Epoch: 000, Train Loss: 0.0043, Train Acc: 0.9995, Test Acc: 0.9957, Train F1: {'partial': 0.9987972378730774, 'house': 0.9994962215423584, 'house_x': 1.0, 'comp_4': 0.9993079304695129, 'comp_5': 1.0}, Test F1: {'partial': 0.9887133240699768, 'house': 0.9956331849098206, 'house_x': 1.0, 'comp_4': 0.9937106966972351, 'comp_5': 1.0}
Epoch: 001, Train Loss: 0.0028, Train Acc: 0.9997, Test Acc: 0.9948, Train F1: {'partial': 0.9992786645889282, 'house': 1.0, 'house_x': 1.0, 'comp_4': 0.9993079304695129, 'comp_5': 1.0}, Test F1: {'partial': 0.9864864945411682, 'house': 0.9934354424476624, 'house_x': 1.0, 'comp_4': 0.9937106966972351, 'comp_5': 1.0}
Epoch: 002, Train Loss: 0.0026, Train Acc: 0.9995, Test Acc: 0.9957, Train F1: {'partial': 0.9987977743148804, 'house': 0.9994962215423584, 'house_x': 1.0, 'comp_4': 0.9993076324462891, 'comp_5': 1.0}, Test F1: {'partial': 0.9887133240699768, 'house': 0.9978213310241699, 'house_x': 1.0, 'comp_4': 0.9915966391563416, 'comp_5': 1.0}
Epoch: 003, Trai

In [15]:
# Persist the motif weights. This overwrites ckpts/motif.pt, the same file that was
# restored into motif_model earlier in this section.
torch.save(obj=motif_model.state_dict(), f='ckpts/motif.pt')

# MUTAG

In [9]:
# Build the seeded MUTAG dataset and split it into train / validation parts.
# NOTE(review): the meaning of k=10 is defined by gnninterpreter's train_test_split.
mutag = MUTAGDataset(seed=12345)
mutag_train, mutag_val = mutag.train_test_split(k=10)

# 3-layer GCN classifier sized from the dataset's node- and graph-class vocabularies.
mutag_model = GCNClassifier(
    node_features=len(mutag.NODE_CLS),
    num_classes=len(mutag.GRAPH_CLS),
    hidden_channels=64,
    num_layers=3,
)

In [11]:
# Restore the trained MUTAG weights. map_location='cpu' lets a checkpoint that was
# saved from a GPU session load on a CPU-only machine; load_state_dict then copies the
# tensors into mutag_model's existing parameters, wherever those live.
mutag_model.load_state_dict(torch.load('ckpts/mutag.pt', map_location='cpu'))

<All keys matched successfully>

In [None]:
# Train the MUTAG classifier for 128 epochs, logging train loss plus accuracy and
# per-class F1 on both the train and validation splits after every epoch.
for epoch in trange(128):
    train_loss = mutag_train.model_fit(mutag_model, lr=0.001)
    train_metrics = mutag_train.model_evaluate(mutag_model)
    val_metrics = mutag_val.model_evaluate(mutag_model)
    # model_evaluate's 'f1' entry is a {class_name: score} dict; print each score to
    # 4 decimals (like loss/acc) instead of dumping full-precision float reprs.
    train_f1 = ", ".join(f"{name}={score:.4f}" for name, score in train_metrics['f1'].items())
    val_f1 = ", ".join(f"{name}={score:.4f}" for name, score in val_metrics['f1'].items())
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: [{train_f1}], "
          f"Test F1: [{val_f1}]")

In [None]:
# Persist the MUTAG weights. This overwrites ckpts/mutag.pt, the same file that was
# restored into mutag_model earlier in this section.
torch.save(obj=mutag_model.state_dict(), f='ckpts/mutag.pt')

# Shape

In [12]:
# Build the seeded shape dataset and split it into train / validation parts.
# NOTE(review): the meaning of k=10 is defined by gnninterpreter's train_test_split.
shape = ShapeDataset(seed=12345)
shape_train, shape_val = shape.train_test_split(k=10)

# 4-layer GCN classifier sized from the dataset's node- and graph-class vocabularies.
shape_model = GCNClassifier(
    node_features=len(shape.NODE_CLS),
    num_classes=len(shape.GRAPH_CLS),
    hidden_channels=64,
    num_layers=4,
)

In [13]:
# Restore the trained shape weights. map_location='cpu' lets a checkpoint that was
# saved from a GPU session load on a CPU-only machine; load_state_dict then copies the
# tensors into shape_model's existing parameters, wherever those live.
shape_model.load_state_dict(torch.load('ckpts/shape.pt', map_location='cpu'))

<All keys matched successfully>

In [None]:
# Train the shape classifier for 128 epochs, logging train loss plus accuracy and
# per-class F1 on both the train and validation splits after every epoch.
# trange (not range) so this loop shows a progress bar like the other training cells.
for epoch in trange(128):
    train_loss = shape_train.model_fit(shape_model, lr=0.001)
    train_metrics = shape_train.model_evaluate(shape_model)
    val_metrics = shape_val.model_evaluate(shape_model)
    # model_evaluate's 'f1' entry is a {class_name: score} dict; print each score to
    # 4 decimals (like loss/acc) instead of dumping full-precision float reprs.
    train_f1 = ", ".join(f"{name}={score:.4f}" for name, score in train_metrics['f1'].items())
    val_f1 = ", ".join(f"{name}={score:.4f}" for name, score in val_metrics['f1'].items())
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: [{train_f1}], "
          f"Test F1: [{val_f1}]")

In [None]:
# Persist the shape weights under a separate filename, so ckpts/shape.pt (the file
# restored into shape_model earlier in this section) is left untouched.
# NOTE(review): the name suggests this extra training run is expected to overfit — confirm.
torch.save(obj=shape_model.state_dict(), f='ckpts/shape_overfitting.pt')