In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from gnn_boundary.datasets import *
from gnn_boundary.models import *

In [3]:
import torch
from tqdm.auto import trange

# Motif

In [13]:
motif = MotifDataset(seed=12345)
motif_train, motif_val = motif.train_test_split(k=10)
motif_model = GCNClassifier(node_features=len(motif.NODE_CLS),
                            num_classes=len(motif.GRAPH_CLS),
                            hidden_channels=6,
                            num_layers=3)

Processing...


Loading graphs:   0%|          | 0/11534 [00:00<?, ?it/s]

Done!


In [41]:
for epoch in trange(128):
    train_loss = motif_train.model_fit(motif_model, lr=0.001)
    train_metrics = motif_train.model_evaluate(motif_model)
    val_metrics = motif_val.model_evaluate(motif_model)
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: {train_metrics['f1']}, "
          f"Test F1: {val_metrics['f1']}")

  0%|          | 0/128 [00:00<?, ?it/s]

Epoch: 000, Train Loss: 0.0700, Train Acc: 0.6204, Test Acc: 0.5167, Train F1: {'EC1': 0.5068492889404297, 'EC2': 0.6111111044883728, 'EC3': 0.7100591659545898, 'EC4': 0.6802030205726624, 'EC5': 0.6011560559272766, 'EC6': 0.5953488349914551}, Test F1: {'EC1': 0.3333333432674408, 'EC2': 0.5263158082962036, 'EC3': 0.8181818127632141, 'EC4': 0.6428571343421936, 'EC5': 0.23529411852359772, 'EC6': 0.375}
Epoch: 001, Train Loss: 0.0626, Train Acc: 0.6204, Test Acc: 0.5167, Train F1: {'EC1': 0.5068492889404297, 'EC2': 0.6111111044883728, 'EC3': 0.7100591659545898, 'EC4': 0.6802030205726624, 'EC5': 0.6011560559272766, 'EC6': 0.5953488349914551}, Test F1: {'EC1': 0.3333333432674408, 'EC2': 0.5263158082962036, 'EC3': 0.8181818127632141, 'EC4': 0.6428571343421936, 'EC5': 0.23529411852359772, 'EC6': 0.375}
Epoch: 002, Train Loss: 0.0554, Train Acc: 0.6204, Test Acc: 0.5167, Train F1: {'EC1': 0.5068492889404297, 'EC2': 0.6111111044883728, 'EC3': 0.7100591659545898, 'EC4': 0.6802030205726624, 'EC5':

KeyboardInterrupt: 

In [22]:
# torch.save(motif_model.state_dict(), 'ckpts/motif.pt')

In [6]:
motif_model.load_state_dict(torch.load('ckpts/motif.pt'))

<All keys matched successfully>

# ENZYMES

In [4]:
enzymes = ENZYMESDataset(seed=12345)
enzymes_train, enzymes_val = enzymes.train_test_split(k=10)
enzymes_model = GCNClassifier(node_features=len(enzymes.NODE_CLS),
                              num_classes=len(enzymes.GRAPH_CLS),
                              hidden_channels=32,
                              num_layers=3)

In [8]:
for epoch in trange(4096):
    train_loss = enzymes_train.model_fit(enzymes_model, lr=0.0001)
    train_metrics = enzymes_train.model_evaluate(enzymes_model)
    val_metrics = enzymes_val.model_evaluate(enzymes_model)
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: {train_metrics['f1']}, "
          f"Test F1: {val_metrics['f1']}")

  0%|          | 0/4096 [00:00<?, ?it/s]

Epoch: 000, Train Loss: 0.9627, Train Acc: 0.6741, Test Acc: 0.5000, Train F1: {'EC1': 0.5714285969734192, 'EC2': 0.659217894077301, 'EC3': 0.7515151500701904, 'EC4': 0.7472527623176575, 'EC5': 0.6737967729568481, 'EC6': 0.6384976506233215}, Test F1: {'EC1': 0.3529411852359772, 'EC2': 0.5263158082962036, 'EC3': 0.8333333134651184, 'EC4': 0.47999998927116394, 'EC5': 0.25, 'EC6': 0.42105263471603394}
Epoch: 001, Train Loss: 0.9737, Train Acc: 0.6778, Test Acc: 0.5000, Train F1: {'EC1': 0.5960264801979065, 'EC2': 0.6630434989929199, 'EC3': 0.7692307829856873, 'EC4': 0.7526881694793701, 'EC5': 0.639053225517273, 'EC6': 0.6346153616905212}, Test F1: {'EC1': 0.3529411852359772, 'EC2': 0.5263158082962036, 'EC3': 0.800000011920929, 'EC4': 0.47999998927116394, 'EC5': 0.2666666805744171, 'EC6': 0.42105263471603394}
Epoch: 002, Train Loss: 0.9631, Train Acc: 0.6778, Test Acc: 0.4333, Train F1: {'EC1': 0.5931034684181213, 'EC2': 0.6666666865348816, 'EC3': 0.7329843044281006, 'EC4': 0.75, 'EC5': 0.

In [9]:
torch.save(enzymes_model.state_dict(), f"ckpts/enzymes.pt")

In [5]:
enzymes_model.load_state_dict(torch.load('ckpts/enzymes.pt'))

<All keys matched successfully>

# COLLAB

In [42]:
collab = CollabDataset(seed=12345)
collab_train, collab_val = collab.train_test_split(k=10)
collab_model = GCNClassifier(node_features=len(collab.NODE_CLS),
                             num_classes=len(collab.GRAPH_CLS),
                             hidden_channels=64,
                             num_layers=5)

In [43]:
for epoch in trange(1024):
    train_loss = collab_train.model_fit(collab_model, lr=0.001)
    train_metrics = collab_train.model_evaluate(collab_model)
    val_metrics = collab_val.model_evaluate(collab_model)
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Train Acc: {train_metrics['acc']:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Train F1: {train_metrics['f1']}, "
          f"Test F1: {val_metrics['f1']}")

  0%|          | 0/1024 [00:00<?, ?it/s]

Epoch: 000, Train Loss: 0.9901, Train Acc: 0.4789, Test Acc: 0.4660, Train F1: {'High Energy': 0.544240415096283, 'Condensed Matter': 0.0, 'Astro': 0.48462414741516113}, Test F1: {'High Energy': 0.5205479264259338, 'Condensed Matter': 0.0, 'Astro': 0.4878048896789551}
Epoch: 001, Train Loss: 0.9488, Train Acc: 0.5676, Test Acc: 0.5480, Train F1: {'High Energy': 0.7063971161842346, 'Condensed Matter': 0.0, 'Astro': 0.25417661666870117}, Test F1: {'High Energy': 0.6852367520332336, 'Condensed Matter': 0.0, 'Astro': 0.27586206793785095}
Epoch: 002, Train Loss: 0.9152, Train Acc: 0.5962, Test Acc: 0.5660, Train F1: {'High Energy': 0.7147119045257568, 'Condensed Matter': 0.0, 'Astro': 0.42067182064056396}, Test F1: {'High Energy': 0.6841338872909546, 'Condensed Matter': 0.0, 'Astro': 0.41025641560554504}
Epoch: 003, Train Loss: 0.7949, Train Acc: 0.6967, Test Acc: 0.6720, Train F1: {'High Energy': 0.7693422436714172, 'Condensed Matter': 0.49306121468544006, 'Astro': 0.6294326186180115}, Tes

In [None]:
torch.save(collab_model.state_dict(), f"ckpts/collab.pt")

In [26]:
# collab_model.load_state_dict(torch.load('ckpts/collab.pt'))

<All keys matched successfully>