In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from gnninterpreter import *

In [3]:
import torch
from tqdm.auto import trange

# Cyclicity

In [4]:
# Build the seeded cyclicity dataset, split it into train/validation parts,
# and size an NNConv classifier from the dataset's node, edge and graph
# class lists. NOTE(review): meaning of k=10 (e.g. 10-fold) — confirm.
cyclicity = CyclicityDataset(seed=12345)
cyclicity_train, cyclicity_val = cyclicity.train_test_split(k=10)
cyclicity_model = NNConvClassifier(
    node_features=len(cyclicity.NODE_CLS),
    edge_features=len(cyclicity.EDGE_CLS),
    num_classes=len(cyclicity.GRAPH_CLS),
    hidden_channels=32,
)

In [5]:
# Restore pretrained weights. map_location='cpu' lets a checkpoint saved from a
# GPU model load on a CUDA-less kernel; load_state_dict then copies the tensors
# into cyclicity_model's existing parameters.
cyclicity_model.load_state_dict(torch.load('ckpts/cyclicity.pt', map_location='cpu'))

<All keys matched successfully>

In [None]:
# Train the *cyclicity* classifier for 128 epochs (lr=0.001), logging the loss,
# accuracy and per-class F1 on both splits after every epoch.
# (Previously this cell trained `motif_model` by mistake, which also raised
# NameError on a fresh kernel since that model is defined further down.)
for epoch in trange(128):
    train_loss = cyclicity_train.model_fit(cyclicity_model, lr=0.001)
    train_metrics = cyclicity_train.model_evaluate(cyclicity_model)
    val_metrics = cyclicity_val.model_evaluate(cyclicity_model)
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: {train_metrics['f1']}, "
          f"Test F1: {val_metrics['f1']}")

In [None]:
# Disabled save: uncomment to overwrite ckpts/cyclicity.pt with the current
# cyclicity_model weights (presumably after rerunning the training cell above).
# torch.save(cyclicity_model.state_dict(), 'ckpts/cyclicity.pt')

# Motif

In [4]:
# Build the seeded motif dataset, split it into train/validation parts, and
# size a 3-layer, 64-channel GCN classifier from its node and graph class lists.
motif = MotifDataset(seed=12345)
motif_train, motif_val = motif.train_test_split(k=10)
motif_model = GCNClassifier(
    node_features=len(motif.NODE_CLS),
    num_classes=len(motif.GRAPH_CLS),
    hidden_channels=64,
    num_layers=3,
)

In [5]:
# Restore pretrained weights. map_location='cpu' lets a checkpoint saved from a
# GPU model load on a CUDA-less kernel; load_state_dict then copies the tensors
# into motif_model's existing parameters.
motif_model.load_state_dict(torch.load('ckpts/motif.pt', map_location='cpu'))

<All keys matched successfully>

In [8]:
# Fine-tune the motif classifier for 128 epochs (lr=0.001); after each epoch,
# evaluate on the train split then the validation split and log one line.
for epoch in range(128):
    train_loss = motif_train.model_fit(motif_model, lr=0.001)
    train_metrics, val_metrics = [split.model_evaluate(motif_model)
                                  for split in (motif_train, motif_val)]
    print(", ".join([
        f"Epoch: {epoch:03d}",
        f"Train Loss: {train_loss:.4f}",
        f"Train Acc: {train_metrics['acc']:.4f}",
        f"Test Acc: {val_metrics['acc']:.4f}",
        f"Train F1: {train_metrics['f1']}",
        f"Test F1: {val_metrics['f1']}",
    ]))

Epoch: 000, Train Loss: 0.0158, Train Acc: 0.9965, Test Acc: 0.9905, Train F1: {'partial': 0.9913957715034485, 'house': 0.9992445111274719, 'house_x': 0.9963099360466003, 'comp_4': 0.9965269565582275, 'comp_5': 0.9992886185646057}, Test F1: {'partial': 0.9755011200904846, 'house': 0.9978213310241699, 'house_x': 0.9912663698196411, 'comp_4': 0.9894291758537292, 'comp_5': 0.9978586435317993}
Epoch: 001, Train Loss: 0.0255, Train Acc: 0.9999, Test Acc: 0.9957, Train F1: {'partial': 0.9997596740722656, 'house': 1.0, 'house_x': 1.0, 'comp_4': 0.9997692108154297, 'comp_5': 1.0}, Test F1: {'partial': 0.9887640476226807, 'house': 0.9978213310241699, 'house_x': 1.0, 'comp_4': 0.9915611743927002, 'comp_5': 1.0}
Epoch: 002, Train Loss: 0.0159, Train Acc: 0.9976, Test Acc: 0.9948, Train F1: {'partial': 0.993959903717041, 'house': 0.9992441534996033, 'house_x': 0.9997548460960388, 'comp_4': 0.9965509176254272, 'comp_5': 0.9985781908035278}, Test F1: {'partial': 0.9863636493682861, 'house': 0.997821

KeyboardInterrupt: 

In [9]:
# Persist the motif classifier's weights, overwriting the checkpoint loaded above.
torch.save(obj=motif_model.state_dict(), f='ckpts/motif.pt')

# MUTAG

In [9]:
# Build the seeded MUTAG dataset, split it into train/validation parts, and
# size a 3-layer, 64-channel GCN classifier from its node and graph class lists.
mutag = MUTAGDataset(seed=12345)
mutag_train, mutag_val = mutag.train_test_split(k=10)
mutag_model = GCNClassifier(
    node_features=len(mutag.NODE_CLS),
    num_classes=len(mutag.GRAPH_CLS),
    hidden_channels=64,
    num_layers=3,
)

In [11]:
# Restore pretrained weights. map_location='cpu' lets a checkpoint saved from a
# GPU model load on a CUDA-less kernel; load_state_dict then copies the tensors
# into mutag_model's existing parameters.
mutag_model.load_state_dict(torch.load('ckpts/mutag.pt', map_location='cpu'))

<All keys matched successfully>

In [None]:
# Train the *MUTAG* classifier for 128 epochs (lr=0.001), logging the loss,
# accuracy and per-class F1 on both splits after every epoch.
# (Previously this cell trained `motif_model` by mistake, so mutag_model was
# never updated and the save cell below persisted unchanged weights.)
for epoch in trange(128):
    train_loss = mutag_train.model_fit(mutag_model, lr=0.001)
    train_metrics = mutag_train.model_evaluate(mutag_model)
    val_metrics = mutag_val.model_evaluate(mutag_model)
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: {train_metrics['f1']}, "
          f"Test F1: {val_metrics['f1']}")

In [None]:
# Persist the MUTAG classifier's weights, overwriting the checkpoint loaded above.
torch.save(obj=mutag_model.state_dict(), f='ckpts/mutag.pt')

# Shape

In [12]:
# Build the seeded shape dataset, split it into train/validation parts, and
# size a 4-layer, 64-channel GCN classifier from its node and graph class lists.
shape = ShapeDataset(seed=12345)
shape_train, shape_val = shape.train_test_split(k=10)
shape_model = GCNClassifier(
    node_features=len(shape.NODE_CLS),
    num_classes=len(shape.GRAPH_CLS),
    hidden_channels=64,
    num_layers=4,
)

In [13]:
# Restore pretrained weights. map_location='cpu' lets a checkpoint saved from a
# GPU model load on a CUDA-less kernel; load_state_dict then copies the tensors
# into shape_model's existing parameters.
shape_model.load_state_dict(torch.load('ckpts/shape.pt', map_location='cpu'))

<All keys matched successfully>

In [None]:
# Train the *shape* classifier for 128 epochs (lr=0.001), logging the loss,
# accuracy and per-class F1 on both splits after every epoch.
# (Previously this cell trained `motif_model` by mistake, so shape_model was
# never updated.)
for epoch in range(128):
    train_loss = shape_train.model_fit(shape_model, lr=0.001)
    train_metrics = shape_train.model_evaluate(shape_model)
    val_metrics = shape_val.model_evaluate(shape_model)
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: {train_metrics['f1']}, "
          f"Test F1: {val_metrics['f1']}")

In [None]:
# NOTE(review): weights are loaded from 'ckpts/shape.pt' but saved to a separate
# 'shape_overfitting' file, so the loaded checkpoint is left untouched —
# confirm this split is intentional.
torch.save(obj=shape_model.state_dict(), f='ckpts/shape_overfitting.pt')