In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from gnninterpreter import *

In [12]:
import os

import torch
from tqdm.auto import trange

# Cyclicity

In [17]:
# Cyclicity dataset, built with a fixed seed, and its train/validation split.
# NOTE(review): k=10 presumably holds out one part in ten -- confirm against
# train_test_split.
cyclicity = CyclicityDataset(seed=12345)
cyclicity_train, cyclicity_val = cyclicity.train_test_split(k=10)

# The classifier's input/output widths are taken from the dataset's own class
# vocabularies. This model is also given the edge-class count, unlike the
# GCNClassifier-based models used elsewhere in this notebook.
cyclicity_model = NNConvClassifier(
    node_features=len(cyclicity.NODE_CLS),
    edge_features=len(cyclicity.EDGE_CLS),
    num_classes=len(cyclicity.GRAPH_CLS),
    hidden_channels=32,
)

In [21]:
# Train the cyclicity classifier and log progress after every iteration.
# fit_model returns a scalar loss and evaluate_model returns a per-class F1
# mapping (see the saved output); presumably fit_model performs one pass over
# the training split -- confirm.
# The 'Test F1' column reports the held-out split from train_test_split
# (bound to cyclicity_val here).
# NOTE(review): the saved output shows a 1-iteration progress bar, so it was
# not produced by this 128-iteration version of the cell.
for epoch in trange(128):
    train_loss = cyclicity_train.fit_model(cyclicity_model, lr=0.001)
    # Score the training split first, then the held-out split.
    train_f1, val_f1 = (
        split.evaluate_model(cyclicity_model)
        for split in (cyclicity_train, cyclicity_val)
    )
    print(
        f'Epoch: {epoch:03d}, '
        f'Train Loss: {train_loss:.4f}, '
        f'Train F1: {train_f1}, '
        f'Test F1: {val_f1}'
    )

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 000, Train Loss: 0.1018, Train F1: {'red_cyclic': 0.9803921580314636, 'green_cyclic': 0.9671629667282104, 'acyclic': 0.9750848412513733}, Test F1: {'red_cyclic': 0.9597197771072388, 'green_cyclic': 0.9403747916221619, 'acyclic': 0.9483066201210022}, 


In [22]:
# Write the cyclicity model's state_dict to disk.
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists first (a no-op when it is already there).
os.makedirs('ckpts', exist_ok=True)
torch.save(cyclicity_model.state_dict(), 'ckpts/cyclicity.pt')

# Motif

In [9]:
# Motif dataset, built with a fixed seed, and its train/validation split.
# NOTE(review): k=10 presumably holds out one part in ten -- confirm against
# train_test_split.
motif = MotifDataset(seed=12345)
motif_train, motif_val = motif.train_test_split(k=10)

# Three-layer GCN classifier sized from the dataset's node and graph class
# vocabularies; no edge-feature count is passed to this model.
motif_model = GCNClassifier(
    node_features=len(motif.NODE_CLS),
    num_classes=len(motif.GRAPH_CLS),
    hidden_channels=64,
    num_layers=3,
)

In [None]:
# Train the motif classifier for 128 iterations, logging after each one.
# Presumably fit_model performs one pass over the training split and returns
# its loss -- confirm. The 'Test F1' column reports the held-out split from
# train_test_split (bound to motif_val here).
for epoch in trange(128):
    train_loss = motif_train.fit_model(motif_model, lr=0.001)
    # Score the training split first, then the held-out split.
    train_f1, val_f1 = (
        split.evaluate_model(motif_model)
        for split in (motif_train, motif_val)
    )
    print(
        f'Epoch: {epoch:03d}, '
        f'Train Loss: {train_loss:.4f}, '
        f'Train F1: {train_f1}, '
        f'Test F1: {val_f1}'
    )

In [22]:
# Write the motif model's state_dict to disk.
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists first (a no-op when it is already there).
os.makedirs('ckpts', exist_ok=True)
torch.save(motif_model.state_dict(), 'ckpts/motif.pt')

# MUTAG

In [10]:
# MUTAG dataset, built with a fixed seed, and its train/validation split.
# NOTE(review): k=10 presumably holds out one part in ten -- confirm against
# train_test_split.
mutag = MUTAGDataset(seed=12345)
mutag_train, mutag_val = mutag.train_test_split(k=10)

# Three-layer GCN classifier sized from the dataset's node and graph class
# vocabularies; no edge-feature count is passed to this model.
mutag_model = GCNClassifier(
    node_features=len(mutag.NODE_CLS),
    num_classes=len(mutag.GRAPH_CLS),
    hidden_channels=64,
    num_layers=3,
)

In [None]:
# Train the MUTAG classifier for 128 iterations, logging after each one.
# Presumably fit_model performs one pass over the training split and returns
# its loss -- confirm. The 'Test F1' column reports the held-out split from
# train_test_split (bound to mutag_val here).
for epoch in trange(128):
    train_loss = mutag_train.fit_model(mutag_model, lr=0.001)
    # Score the training split first, then the held-out split.
    train_f1, val_f1 = (
        split.evaluate_model(mutag_model)
        for split in (mutag_train, mutag_val)
    )
    print(
        f'Epoch: {epoch:03d}, '
        f'Train Loss: {train_loss:.4f}, '
        f'Train F1: {train_f1}, '
        f'Test F1: {val_f1}'
    )

In [22]:
# Write the MUTAG model's state_dict to disk.
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists first (a no-op when it is already there).
os.makedirs('ckpts', exist_ok=True)
torch.save(mutag_model.state_dict(), 'ckpts/mutag.pt')

# Shape

In [23]:
# Shape dataset, built with a fixed seed, and its train/validation split.
# NOTE(review): k=10 presumably holds out one part in ten -- confirm against
# train_test_split.
shape = ShapeDataset(seed=12345)
shape_train, shape_val = shape.train_test_split(k=10)

# Four-layer GCN classifier (one layer more than the motif and MUTAG models),
# sized from the dataset's node and graph class vocabularies.
shape_model = GCNClassifier(
    node_features=len(shape.NODE_CLS),
    num_classes=len(shape.GRAPH_CLS),
    hidden_channels=64,
    num_layers=4,
)

In [28]:
# Train the shape classifier for 8 iterations at lr=0.0001, logging after
# each one. Presumably fit_model performs one pass over the training split
# and returns its loss -- confirm. The 'Test F1' column reports the held-out
# split from train_test_split (bound to shape_val here).
# NOTE(review): the saved output already starts at ~0.98 F1, which suggests
# this cell was run on a model trained earlier (a short low-lr continuation)
# -- confirm the full schedule needed to reproduce the checkpoint.
for epoch in trange(8):
    train_loss = shape_train.fit_model(shape_model, lr=0.0001)
    # Score the training split first, then the held-out split.
    train_f1, val_f1 = (
        split.evaluate_model(shape_model)
        for split in (shape_train, shape_val)
    )
    print(
        f'Epoch: {epoch:03d}, '
        f'Train Loss: {train_loss:.4f}, '
        f'Train F1: {train_f1}, '
        f'Test F1: {val_f1}'
    )

  0%|          | 0/8 [00:00<?, ?it/s]

Epoch: 000, Train Loss: 0.0339, Train F1: {'random': 0.9834139347076416, 'lollipop': 0.9876543283462524, 'wheel': 0.9925821423530579, 'grid': 0.9814377427101135, 'star': 1.0}, Test F1: {'random': 0.9722222089767456, 'lollipop': 0.9783281683921814, 'wheel': 0.9916897416114807, 'grid': 0.9693251252174377, 'star': 1.0}
Epoch: 001, Train Loss: 0.0340, Train F1: {'random': 0.9810671210289001, 'lollipop': 0.9810040593147278, 'wheel': 0.9902642369270325, 'grid': 0.975304365158081, 'star': 1.0}, Test F1: {'random': 0.9726027250289917, 'lollipop': 0.9726443886756897, 'wheel': 0.9917808175086975, 'grid': 0.9743589758872986, 'star': 1.0}
Epoch: 002, Train Loss: 0.0334, Train F1: {'random': 0.97994464635849, 'lollipop': 0.9870041012763977, 'wheel': 0.9906086921691895, 'grid': 0.9790593981742859, 'star': 1.0}, Test F1: {'random': 0.9722222089767456, 'lollipop': 0.9756097793579102, 'wheel': 0.9945054650306702, 'grid': 0.9748427867889404, 'star': 1.0}
Epoch: 003, Train Loss: 0.0346, Train F1: {'rando

In [29]:
# Write the shape model's state_dict to disk.
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists first (a no-op when it is already there).
os.makedirs('ckpts', exist_ok=True)
torch.save(shape_model.state_dict(), 'ckpts/shape.pt')