In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from gnninterpreter import *

In [3]:
import torch
from tqdm.auto import trange

# Cyclicity

In [4]:
# Cyclicity experiment setup: build the dataset with a fixed seed, split it
# into train/validation parts, and create the classifier.
cyclicity = CyclicityDataset(seed=12345)
# NOTE(review): k=10 is forwarded to train_test_split as-is; its exact meaning
# is defined inside gnninterpreter -- confirm there before changing it.
cyclicity_train, cyclicity_val = cyclicity.train_test_split(k=10)
# Model dimensions are the lengths of the dataset's NODE_CLS / EDGE_CLS /
# GRAPH_CLS tables, so the network always matches the dataset it is built for.
cyclicity_model = NNConvClassifier(
    node_features=len(cyclicity.NODE_CLS),
    edge_features=len(cyclicity.EDGE_CLS),
    num_classes=len(cyclicity.GRAPH_CLS),
    hidden_channels=32,
)

In [5]:
# Restore the saved Cyclicity weights from disk.
# map_location='cpu' remaps tensors that were saved from a GPU, so the
# checkpoint also loads on machines without CUDA; load_state_dict then copies
# the values into the model's existing parameters.
# Kept as the cell's final expression so the key-matching report is displayed.
cyclicity_model.load_state_dict(
    torch.load('ckpts/cyclicity.pt', map_location='cpu')
)

<All keys matched successfully>

In [None]:
# Run 128 training iterations on the Cyclicity model with a progress bar,
# printing the loss and both scores after every iteration.
for epoch in trange(128):
    # fit_model returns the value reported below as the training loss
    # (presumably one pass over the training split -- confirm in gnninterpreter).
    train_loss = cyclicity_train.fit_model(cyclicity_model, lr=0.001)
    # Score after the update, training split first, then the held-out split.
    train_f1, val_f1 = (
        cyclicity_train.evaluate_model(cyclicity_model),
        cyclicity_val.evaluate_model(cyclicity_model),
    )
    # The value labelled 'Test F1' is computed on cyclicity_val.
    print(
        f'Epoch: {epoch:03d}, ',
        f'Train Loss: {train_loss:.4f}, ',
        f'Train F1: {train_f1}, ',
        f'Test F1: {val_f1}',
        sep='',
    )

In [None]:
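# Intentionally disabled: uncomment to overwrite 'ckpts/cyclicity.pt' with the current model weights.
# torch.save(cyclicity_model.state_dict(), 'ckpts/cyclicity.pt')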

# Motif

In [6]:
# Motif experiment setup: build the dataset with a fixed seed, split it into
# train/validation parts, and create the classifier.
motif = MotifDataset(seed=12345)
# NOTE(review): k=10 is forwarded to train_test_split as-is; its exact meaning
# is defined inside gnninterpreter -- confirm there before changing it.
motif_train, motif_val = motif.train_test_split(k=10)
# Input and output sizes are the lengths of the dataset's NODE_CLS and
# GRAPH_CLS tables.
motif_model = GCNClassifier(
    node_features=len(motif.NODE_CLS),
    num_classes=len(motif.GRAPH_CLS),
    hidden_channels=64,
    num_layers=3,
)

In [7]:
# Restore the saved Motif weights from disk.
# map_location='cpu' remaps tensors that were saved from a GPU, so the
# checkpoint also loads on machines without CUDA; load_state_dict then copies
# the values into the model's existing parameters.
# Kept as the cell's final expression so the key-matching report is displayed.
motif_model.load_state_dict(
    torch.load('ckpts/motif.pt', map_location='cpu')
)

<All keys matched successfully>

In [None]:
# Run 128 training iterations on the Motif model, printing the loss and both
# scores after every iteration.
for epoch in range(128):
    # fit_model returns the value reported below as the training loss
    # (presumably one pass over the training split -- confirm in gnninterpreter).
    train_loss = motif_train.fit_model(motif_model, lr=0.001)
    # Score after the update, training split first, then the held-out split.
    train_f1, val_f1 = (
        motif_train.evaluate_model(motif_model),
        motif_val.evaluate_model(motif_model),
    )
    # The value labelled 'Test F1' is computed on motif_val.
    print(
        f'Epoch: {epoch:03d}, ',
        f'Train Loss: {train_loss:.4f}, ',
        f'Train F1: {train_f1}, ',
        f'Test F1: {val_f1}',
        sep='',
    )

In [None]:
# Persist the current Motif model weights. This overwrites 'ckpts/motif.pt',
# the same file the checkpoint-loading cell reads.
torch.save(obj=motif_model.state_dict(), f='ckpts/motif.pt')

# MUTAG

In [9]:
# MUTAG experiment setup: build the dataset with a fixed seed, split it into
# train/validation parts, and create the classifier.
mutag = MUTAGDataset(seed=12345)
# NOTE(review): k=10 is forwarded to train_test_split as-is; its exact meaning
# is defined inside gnninterpreter -- confirm there before changing it.
mutag_train, mutag_val = mutag.train_test_split(k=10)
# Input and output sizes are the lengths of the dataset's NODE_CLS and
# GRAPH_CLS tables.
mutag_model = GCNClassifier(
    node_features=len(mutag.NODE_CLS),
    num_classes=len(mutag.GRAPH_CLS),
    hidden_channels=64,
    num_layers=3,
)

In [11]:
# Restore the saved MUTAG weights from disk.
# map_location='cpu' remaps tensors that were saved from a GPU, so the
# checkpoint also loads on machines without CUDA; load_state_dict then copies
# the values into the model's existing parameters.
# Kept as the cell's final expression so the key-matching report is displayed.
mutag_model.load_state_dict(
    torch.load('ckpts/mutag.pt', map_location='cpu')
)

<All keys matched successfully>

In [None]:
# Run 128 training iterations on the MUTAG model with a progress bar,
# printing the loss and both scores after every iteration.
for epoch in trange(128):
    # fit_model returns the value reported below as the training loss
    # (presumably one pass over the training split -- confirm in gnninterpreter).
    train_loss = mutag_train.fit_model(mutag_model, lr=0.001)
    # Score after the update, training split first, then the held-out split.
    train_f1, val_f1 = (
        mutag_train.evaluate_model(mutag_model),
        mutag_val.evaluate_model(mutag_model),
    )
    # The value labelled 'Test F1' is computed on mutag_val.
    print(
        f'Epoch: {epoch:03d}, ',
        f'Train Loss: {train_loss:.4f}, ',
        f'Train F1: {train_f1}, ',
        f'Test F1: {val_f1}',
        sep='',
    )

In [None]:
# Persist the current MUTAG model weights. This overwrites 'ckpts/mutag.pt',
# the same file the checkpoint-loading cell reads.
torch.save(obj=mutag_model.state_dict(), f='ckpts/mutag.pt')

# Shape

In [12]:
# Shape experiment setup: build the dataset with a fixed seed, split it into
# train/validation parts, and create the classifier.
shape = ShapeDataset(seed=12345)
# NOTE(review): k=10 is forwarded to train_test_split as-is; its exact meaning
# is defined inside gnninterpreter -- confirm there before changing it.
shape_train, shape_val = shape.train_test_split(k=10)
# Input and output sizes are the lengths of the dataset's NODE_CLS and
# GRAPH_CLS tables; this model is configured with num_layers=4.
shape_model = GCNClassifier(
    node_features=len(shape.NODE_CLS),
    num_classes=len(shape.GRAPH_CLS),
    hidden_channels=64,
    num_layers=4,
)

In [13]:
# Restore the saved Shape weights from disk.
# map_location='cpu' remaps tensors that were saved from a GPU, so the
# checkpoint also loads on machines without CUDA; load_state_dict then copies
# the values into the model's existing parameters.
# Kept as the cell's final expression so the key-matching report is displayed.
shape_model.load_state_dict(
    torch.load('ckpts/shape.pt', map_location='cpu')
)

<All keys matched successfully>

In [None]:
# Run 128 training iterations on the Shape model (lr=0.0001), printing the
# loss and both scores after every iteration.
for epoch in range(128):
    # fit_model returns the value reported below as the training loss
    # (presumably one pass over the training split -- confirm in gnninterpreter).
    train_loss = shape_train.fit_model(shape_model, lr=0.0001)
    # Score after the update, training split first, then the held-out split.
    train_f1, val_f1 = (
        shape_train.evaluate_model(shape_model),
        shape_val.evaluate_model(shape_model),
    )
    # The value labelled 'Test F1' is computed on shape_val.
    print(
        f'Epoch: {epoch:03d}, ',
        f'Train Loss: {train_loss:.4f}, ',
        f'Train F1: {train_f1}, ',
        f'Test F1: {val_f1}',
        sep='',
    )

In [None]:
# Persist the current Shape model weights.
# NOTE(review): this writes 'ckpts/shape_overfitting.pt', whereas the
# checkpoint-loading cell reads 'ckpts/shape.pt' -- confirm the separate file
# name is intentional, since the weights saved here are not the ones reloaded.
torch.save(obj=shape_model.state_dict(), f='ckpts/shape_overfitting.pt')