In [1]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from networks import FractalNet, FractalNetShared, Net, GNN_no_rel, GNN
from subgraph import Graph_to_Subgraph
from train import train_model, get_qm9

In [21]:
# GLOBAL VARIABLES FOR THE EXPERIMENT

# Seed torch's global RNG (all devices) so weight init and DataLoader shuffling
# are repeatable across Restart & Run All. Some CUDA scatter kernels may still be
# nondeterministic, so GPU runs can differ slightly.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 55
batch_size = 32

# data related (x and y)
Z_ONE_HOT_DIM = 5   # width of the node one-hot encoding (presumably atomic number Z) -- TODO confirm in train.py
LABEL_INDEX = 7     # index of the regression target passed to get_qm9 / train_model
EDGE_ATTR_DIM = 4   # number of edge attributes; only the edge-aware GNN cell uses it

# model related
node_features = Z_ONE_HOT_DIM  # node input width is the one-hot width, as in the GNN cell
edge_features = 0              # models other than GNN are built without edge features
hidden_features = 64
out_features = 1               # single scalar target

# TRAINING FRACTAL NET WITH SHARED PARAMETERS

In [9]:
# FractalNet with shared parameters (depth 1), trained on QM9 with the
# Graph_to_Subgraph transform applied to every graph.
# NOTE(review): in the recorded run, train_model raised
# "ValueError: Model name not recognized" for this model_name -- confirm that
# train.py handles 'FractalNetShared' before relying on this experiment.
model_name = 'FractalNetShared'
model = FractalNetShared(
    node_features, edge_features, hidden_features, out_features,
    depth=1, pool='add',
).to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# Scale the LR by 0.7 once the value given to scheduler.step() (presumably the
# validation loss) has not improved for more than 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode='min', factor=0.7, patience=3, verbose=True,
)

train, valid, test = get_qm9(
    "data/qm9", device=device, LABEL_INDEX=LABEL_INDEX, transform=Graph_to_Subgraph(),
)
# Shuffle only the training split; evaluation order stays fixed.
train_loader, valid_loader, test_loader = (
    DataLoader(train, batch_size=batch_size, shuffle=True),
    DataLoader(valid, batch_size=32, shuffle=False),
    DataLoader(test, batch_size=32, shuffle=False),
)

In [10]:
# Train FractalNetShared; the returned dict (train_loss / valid_loss / test_loss)
# feeds the plotting and summary cells at the end of the notebook.
fractalnetshared_results = train_model(
    model, model_name, epochs,
    train_loader, valid_loader, test_loader,
    optimizer, criterion, scheduler,
    device, LABEL_INDEX, Z_ONE_HOT_DIM,
)

Total number of parameters: 25281


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/3125 [00:00<?, ?it/s]

ValueError: Model name not recognized

In [22]:
# FractalNet (depth 3, residual skip connections enabled), trained on QM9 with
# the Graph_to_Subgraph transform applied to every graph.
model_name = 'FractalNet'
model = FractalNet(
    node_features, edge_features, hidden_features, out_features,
    depth=3, pool='add', add_residual_skip=True,
).to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
criterion = torch.nn.MSELoss()
# Scale the LR by 0.7 once the value given to scheduler.step() (presumably the
# validation loss) has not improved for more than 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode='min', factor=0.7, patience=3, verbose=True,
)

train, valid, test = get_qm9(
    "data/qm9", device=device, LABEL_INDEX=LABEL_INDEX, transform=Graph_to_Subgraph(),
)
# Shuffle only the training split; evaluation order stays fixed.
train_loader, valid_loader, test_loader = (
    DataLoader(train, batch_size=batch_size, shuffle=True),
    DataLoader(valid, batch_size=32, shuffle=False),
    DataLoader(test, batch_size=32, shuffle=False),
)

In [None]:
# Train FractalNet; the returned dict (train_loss / valid_loss / test_loss)
# feeds the plotting and summary cells at the end of the notebook.
fractalnet_results = train_model(
    model, model_name, epochs,
    train_loader, valid_loader, test_loader,
    optimizer, criterion, scheduler,
    device, LABEL_INDEX, Z_ONE_HOT_DIM,
)

Total number of parameters: 298433


  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 0, Loss: 0.20341165145688692, Valid Loss: 0.0024187428742489473


  0%|          | 0/3125 [00:00<?, ?it/s]

# TRAINING SAME NET AS FRACTAL BUT WITHOUT SUBNODES

In [9]:
# Ablation: same network family as FractalNet (depth 3) but trained on the raw
# QM9 graphs -- transform=None, so no Graph_to_Subgraph subnodes.
model_name = 'Net'
model = Net(
    node_features, edge_features, hidden_features, out_features,
    depth=3, pool='add',
).to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
criterion = torch.nn.MSELoss()
# Scale the LR by 0.7 once the value given to scheduler.step() (presumably the
# validation loss) has not improved for more than 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode='min', factor=0.7, patience=3, verbose=True,
)

train, valid, test = get_qm9(
    "data/qm9", device=device, LABEL_INDEX=LABEL_INDEX, transform=None,
)
# Shuffle only the training split; evaluation order stays fixed.
train_loader, valid_loader, test_loader = (
    DataLoader(train, batch_size=batch_size, shuffle=True),
    DataLoader(valid, batch_size=32, shuffle=False),
    DataLoader(test, batch_size=32, shuffle=False),
)

In [10]:
# Train the no-subnode Net; the returned dict (train_loss / valid_loss /
# test_loss) feeds the plotting and summary cells at the end of the notebook.
no_subnode_results = train_model(
    model, model_name, epochs,
    train_loader, valid_loader, test_loader,
    optimizer, criterion, scheduler,
    device, LABEL_INDEX, Z_ONE_HOT_DIM,
)

Total number of parameters: 74945


  0%|          | 0/45 [00:00<?, ?it/s]

  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 0, Loss: 0.39469747448563575, Valid Loss: 0.3496957777288204


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 1, Loss: 0.3404630948036909, Valid Loss: 0.34609778178409456


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 2, Loss: 0.33312930680304764, Valid Loss: 0.33795559165267325


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 3, Loss: 0.3262712341582775, Valid Loss: 0.3275855558034711


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 4, Loss: 0.31916085559368135, Valid Loss: 0.32447425214746317


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 5, Loss: 0.3125726771378517, Valid Loss: 0.31598887183129215


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 6, Loss: 0.30597684620141985, Valid Loss: 0.31275037258339766


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 7, Loss: 0.2992024395740032, Valid Loss: 0.30143371608834296


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 8, Loss: 0.2920136449587345, Valid Loss: 0.2955020200282621


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 9, Loss: 0.28450348773658274, Valid Loss: 0.28576937596352336


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 10, Loss: 0.27638406969070434, Valid Loss: 0.27505965342822547


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 11, Loss: 0.2666834018576145, Valid Loss: 0.2683354552728109


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 12, Loss: 0.25462042869746687, Valid Loss: 0.24546331035110136


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 13, Loss: 0.2336000930544734, Valid Loss: 0.22905728156669453


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 14, Loss: 0.21710871714413166, Valid Loss: 0.21127650131599401


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 15, Loss: 0.20473991336286068, Valid Loss: 0.18992331687111064


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 16, Loss: 0.19690393943965434, Valid Loss: 0.18175645994421202


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 17, Loss: 0.18910795077860357, Valid Loss: 0.18110866698260886


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 18, Loss: 0.18095013382971287, Valid Loss: 0.17924172641893926


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 19, Loss: 0.17486114077806472, Valid Loss: 0.16606051257005133


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 20, Loss: 0.16895723731517792, Valid Loss: 0.17087002115222974


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 21, Loss: 0.16660132495045663, Valid Loss: 0.17235678211806682


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 22, Loss: 0.15833399774730206, Valid Loss: 0.15268774347278638


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 23, Loss: 0.15440453350841998, Valid Loss: 0.1613359309840031


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 24, Loss: 0.15033643167436123, Valid Loss: 0.14380964322592885


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 25, Loss: 0.14679466474831104, Valid Loss: 0.1413365977105146


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 26, Loss: 0.14254861802756785, Valid Loss: 0.15019171750440766


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 27, Loss: 0.13922701494157314, Valid Loss: 0.1495399278240463


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 28, Loss: 0.1353263271576166, Valid Loss: 0.12585684466071592


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 29, Loss: 0.13150895205020904, Valid Loss: 0.13156366802918645


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 30, Loss: 0.12913176690012215, Valid Loss: 0.12195155424431871


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 31, Loss: 0.12637155183672905, Valid Loss: 0.12481225633059446


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 32, Loss: 0.12298367806911469, Valid Loss: 0.12826699773057962


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 33, Loss: 0.11981479392468929, Valid Loss: 0.11831562449566473


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 34, Loss: 0.117449780356884, Valid Loss: 0.11105680902497456


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 35, Loss: 0.11437987644731998, Valid Loss: 0.11098244209021044


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 36, Loss: 0.11104934961080551, Valid Loss: 0.1133606125550053


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 37, Loss: 0.10953255480289459, Valid Loss: 0.10602636422854833


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 38, Loss: 0.10628370450615883, Valid Loss: 0.1047638409410993


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 39, Loss: 0.103815485188663, Valid Loss: 0.10152359197314936


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 40, Loss: 0.1012169891628623, Valid Loss: 0.09708650622028893


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 41, Loss: 0.09972091797560453, Valid Loss: 0.09586216237979195


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 42, Loss: 0.0970341670936346, Valid Loss: 0.10191663833091054


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 43, Loss: 0.09493588941872119, Valid Loss: 0.09884174165729516


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 44, Loss: 0.09320176350146532, Valid Loss: 0.09031768523442288


  0%|          | 0/651 [00:00<?, ?it/s]

Test Loss: 0.0867915347393047


# TRAINING A STANDARD GNN WITHOUT RELATIONAL (EDGE) INFORMATION

In [17]:
# Standard GNN baseline without relational information. edge_features is 0 in
# the config cell, so presumably no edge attributes reach the model.
model_name = 'GNN_no_rel'
# Use node_features (not a literal 5) so the input width follows the config cell.
model = GNN_no_rel(node_features, edge_features, hidden_features, out_features,
                   num_convolution_blocks=3, pooling='add').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
criterion = nn.MSELoss()
# Scale the LR by 0.7 once the value given to scheduler.step() (presumably the
# validation loss) has not improved for more than 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)
train, valid, test = get_qm9("data/qm9",
                             device=device,
                             LABEL_INDEX=LABEL_INDEX,
                             transform=None)
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid, batch_size=32, shuffle=False)
test_loader = DataLoader(test, batch_size=32, shuffle=False)

In [18]:
# Train the GNN_no_rel baseline; the returned dict (train_loss / valid_loss /
# test_loss) feeds the plotting and summary cells at the end of the notebook.
gnn_no_rel_results = train_model(
    model, model_name, epochs,
    train_loader, valid_loader, test_loader,
    optimizer, criterion, scheduler,
    device, LABEL_INDEX, Z_ONE_HOT_DIM,
)

Total number of parameters: 277057


  0%|          | 0/45 [00:00<?, ?it/s]

  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 0, Loss: 0.34073594618916514, Valid Loss: 0.26796615189804246


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 1, Loss: 0.2002595726749301, Valid Loss: 0.11254382863069495


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 2, Loss: 0.052325787158980966, Valid Loss: 0.01320092515430797


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 3, Loss: 0.004784174772752449, Valid Loss: 0.0029446304483434407


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 4, Loss: 0.0017802835443534422, Valid Loss: 0.0011488691687557143


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 5, Loss: 0.0013485616207064596, Valid Loss: 0.0008311322576402814


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 6, Loss: 0.0012334666186670075, Valid Loss: 0.0003099502361021205


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 7, Loss: 0.0010672309357262566, Valid Loss: 0.0017416033840451104


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 8, Loss: 0.0009332941744074923, Valid Loss: 0.0005364063240841742


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 9, Loss: 0.000706238659650553, Valid Loss: 0.0005293432732171099


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch 00011: reducing learning rate of group 0 to 3.5000e-04.
Epoch: 10, Loss: 0.0005822339579527033, Valid Loss: 0.0005092425674254733


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 11, Loss: 0.0003077914858500299, Valid Loss: 0.0008925487304637476


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 12, Loss: 0.00034889882537550876, Valid Loss: 0.0005732798170199636


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 13, Loss: 0.00028307309053852807, Valid Loss: 0.0005240922351795495


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch 00015: reducing learning rate of group 0 to 2.4500e-04.
Epoch: 14, Loss: 0.00031421407085596004, Valid Loss: 0.0004182716977618711


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 15, Loss: 0.0001408623219979927, Valid Loss: 0.0004023277626471715


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 16, Loss: 0.00015678595537177897, Valid Loss: 0.0004508169166948877


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 17, Loss: 0.00015129269527606085, Valid Loss: 0.00037310238281751114


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch 00019: reducing learning rate of group 0 to 1.7150e-04.
Epoch: 18, Loss: 0.00014093932886949915, Valid Loss: 0.0005385101686695594


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 19, Loss: 7.378145212042e-05, Valid Loss: 0.0006095445816242954


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 20, Loss: 5.1224616796316695e-05, Valid Loss: 0.00043007814446944907


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 21, Loss: 8.91770587641804e-05, Valid Loss: 0.0004898761498408273


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch 00023: reducing learning rate of group 0 to 1.2005e-04.
Epoch: 22, Loss: 6.0796163273807904e-05, Valid Loss: 0.0006306070788185479


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 23, Loss: 2.972445246125062e-05, Valid Loss: 0.0004025591450440065


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 24, Loss: 3.784198907251266e-05, Valid Loss: 0.00041785356562161377


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 25, Loss: 3.283295836803518e-05, Valid Loss: 0.0004213635944069074


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch 00027: reducing learning rate of group 0 to 8.4035e-05.
Epoch: 26, Loss: 4.4157627581244016e-05, Valid Loss: 0.000441299996451655


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 27, Loss: 1.4356185552260285e-05, Valid Loss: 0.00042767308757948847


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 28, Loss: 2.1143508784134612e-05, Valid Loss: 0.0004223199410105259


  0%|          | 0/3125 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch: 29, Loss: 1.61127240807582e-05, Valid Loss: 0.00044732314751880724


  0%|          | 0/3125 [00:00<?, ?it/s]

KeyboardInterrupt: 

# TRAINING A GNN WITH EDGE FEATURES

In [12]:
# Create a GNN that also receives edge attributes (EDGE_ATTR_DIM per edge).
# NOTE(review): training this model raised
# "IndexError: too many indices for tensor of dimension 2" in the recorded run.
# The traceback is not shown here -- TODO locate the failing index in train.py / networks.py.
# `device` comes from the config cell; it is not redefined here.
model_name = 'GNN'
model = GNN(n_node_features=Z_ONE_HOT_DIM,
            n_edge_features=EDGE_ATTR_DIM,
            n_hidden=hidden_features,
            n_output=out_features,
            num_convolution_blocks=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# Scale the LR by 0.7 once the value given to scheduler.step() (presumably the
# validation loss) has not improved for more than 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True)
train, valid, test = get_qm9("data/qm9",
                             device=device,
                             LABEL_INDEX=LABEL_INDEX,
                             transform=None)
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid, batch_size=32, shuffle=False)
# Build test_loader here too; otherwise the training cell silently reuses a stale
# test_loader left over from whichever model cell ran before.
test_loader = DataLoader(test, batch_size=32, shuffle=False)

  warn("Using non-standard permutation since permute.pt does not exist.")


In [13]:
# Train the edge-aware GNN. Requires the setup cell above to have defined all
# three loaders. The returned dict (train_loss / valid_loss / test_loss) feeds
# the plotting and summary cells.
gnn_results = train_model(
    model, model_name, epochs,
    train_loader, valid_loader, test_loader,
    optimizer, criterion, scheduler,
    device, LABEL_INDEX, Z_ONE_HOT_DIM,
)

Total number of parameters: 227329


  0%|          | 0/45 [00:00<?, ?it/s]

  0%|          | 0/3125 [00:00<?, ?it/s]

IndexError: too many indices for tensor of dimension 2

# PLOTTING LOSS

In [None]:
# Plot train and validation loss curves of every run that produced results.
import matplotlib.pyplot as plt

# Legend label -> name of the notebook variable holding that run's results dict.
# In the recorded run, several training cells raised or were interrupted
# (FractalNetShared: ValueError, GNN_no_rel: KeyboardInterrupt, GNN: IndexError).
# The corresponding variable may therefore not exist. Look each one up instead of
# referencing it directly, which raises NameError on Restart & Run All.
RUN_RESULT_VARS = {
    'FractalNetShared': 'fractalnetshared_results',
    'FractalNet': 'fractalnet_results',
    'GNN_no_rel': 'gnn_no_rel_results',
    'GNN': 'gnn_results',
    'No Subnodes': 'no_subnode_results',
}
finished_runs = {
    label: globals()[var_name]
    for label, var_name in RUN_RESULT_VARS.items()
    if var_name in globals()
}

for metric, title in (('train_loss', 'Training loss'), ('valid_loss', 'Validation loss')):
    fig, ax = plt.subplots()
    for label, results in finished_runs.items():
        ax.plot(results[metric], label=label)
    # Final losses differ by several orders of magnitude between models
    # (~1e-1 vs ~1e-5), so a log axis keeps every curve readable.
    ax.set_yscale('log')
    ax.set(title=title, xlabel='Epoch', ylabel='Loss')
    if finished_runs:
        ax.legend()
    plt.show()

In [21]:
# Print the final test loss of every run that produced results. Runs whose
# training cell raised or was interrupted have no results variable; they are
# listed as missing instead of causing a NameError.
RUN_RESULT_VARS = {
    'FractalNetShared': 'fractalnetshared_results',
    'FractalNet': 'fractalnet_results',
    'GNN_no_rel': 'gnn_no_rel_results',
    'GNN': 'gnn_results',
    'No Subnodes': 'no_subnode_results',
}
final_test_losses = {
    label: globals()[var_name]['test_loss']
    for label, var_name in RUN_RESULT_VARS.items()
    if var_name in globals()
}
for label, loss in final_test_losses.items():
    print(f'{label} Test Loss: ', loss)

missing_runs = [label for label in RUN_RESULT_VARS if label not in final_test_losses]
if missing_runs:
    print('No results for:', ', '.join(missing_runs))

FractalNet Test Loss:  0.0005467538204020275
GNN_no_rel Test Loss:  1.4065201867546444e-05
No Subnodes Test Loss:  3.3362182585271106e-06
