In [1]:
# --- Standard libraries
import os.path as osp
import pandas as pd
import numpy as np
import pickle as pkl
# --- Matplotlib
import matplotlib.pyplot as plt
import imageio
# --- PyTorch and PyG
import torch
from torch_geometric.loader import DataLoader
# --- XASNet
from XASNet.data import QM9_XAS
from XASNet.data import save_split
from XASNet.models import XASNet_GNN, XASNet_GAT, XASNet_GraphNet
from XASNet.trainer import GNNTrainer

In [2]:
# --- Load the dataset
# `root` is the path handed to QM9_XAS; raw inputs are looked up in ./datasets/
root = './datasets/mol_dataset.pt'
go_spec = QM9_XAS(
    root=root,
    raw_dir='./datasets/',
    spectra=[],
)

In [3]:
# --- Summarise the dataset as a whole (one item per output line)
print(
    go_spec,
    '------------',
    f'Number of graphs: {len(go_spec)}',
    f'Number of features: {go_spec.num_features}',
    '',
    sep='\n',
)

# --- Summarise the first molecule/graph in the dataset
data = go_spec[0]

print(
    data,
    '------------',
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
    sep='\n',
)

QM9_XAS(319)
------------
Number of graphs: 319
Number of features: 17

Data(x=[28, 17], edge_index=[2, 72], edge_attr=[72, 6], spectrum=[200], idx=[1], smiles='c12[c:2]3[cH:1][cH:23][c:22]4[c:20]1[c:16]1[c:18]([cH:19][c:21]4[OH:24])[CH:17]4[CH:15]([C:14]5([OH:27])[CH:12]1[C:8]16[c:4]2[c:6]([cH:5][cH:3]3)[CH:7]=[CH:9][C:10]1([CH:11]=[CH:13]5)[O:26]6)[O:25]4')
------------
Number of nodes: 28
Number of edges: 72
Average node degree: 2.57
Has isolated nodes: False
Has self loops: False
Is undirected: True


In [4]:
# --- Create the split file for the dataset
# --- (indices for the training, validation and test subsets)
# NOTE(review): the three sizes sum to 317; if the dataset holds more graphs
# the remainder is presumably left out of every subset - confirm in save_split.
split_sizes = {'ntrain': 244, 'nval': 25, 'ntest': 48}

idxs = save_split(
    path='./datasets/xasnet-mol-split.npz',
    ndata=len(go_spec),
    save_split=True,
    shuffle=True,
    print_nsample=True,
    **split_sizes,
)

{'train': 244, 'val': 25, 'test': 48}


In [5]:
# --- Materialise each subset as a list of graphs, in train/val/test order
train_go, val_go, test_go = (
    [go_spec[i] for i in idxs[part]] for part in ('train', 'val', 'test')
)

# --- Wrap every subset in a dataloader.
# Each batch size equals the size requested for that subset when the split
# was created, so every loader yields a single batch.
train_loader = DataLoader(
    train_go,
    batch_size=244,
    shuffle=True,
)
val_loader = DataLoader(
    val_go,
    batch_size=25,
    shuffle=True,
)
test_loader = DataLoader(
    test_go,
    batch_size=48,
    shuffle=False,
)

print(
    f'Training dataset length: {len(train_go)}, compiled in {len(train_loader)} loaders',
    f'Validation dataset length: {len(val_go)}, compiled in {len(val_loader)} loaders',
    f'Test dataset length: {len(test_go)}, compiled in {len(test_loader)} loaders',
    sep='\n',
)

Training dataset length: 244, compiled in 1 loaders
Validation dataset length: 25, compiled in 1 loaders
Test dataset length: 48, compiled in 1 loaders


In [6]:
# --- Save the test subset (the list of graphs, not the dataloader) to a file
torch.save(
    test_go,
    './datasets/test_mol_dataset.pt',
)

In [7]:
# --- Define cost functions
def RSE_loss(prediction, target, e_min=280.0, e_max=300.0, n_points=200):
    """Relative spectral error between predicted and reference spectra.

    Computes ``sqrt(sum(dE * (target - prediction)**2)) / sum(dE * target)``
    with a constant grid spacing ``dE = (e_max - e_min) / n_points``.

    Args:
        prediction: Tensor of predicted intensities.
        target: Tensor of reference intensities, broadcastable against
            ``prediction``.
        e_min: Lower edge of the energy window (presumably eV - confirm).
        e_max: Upper edge of the energy window, same unit as ``e_min``.
        n_points: Number of grid points the window is divided into.

    Returns:
        A zero-dimensional tensor holding the loss.

    Note:
        The defaults reproduce the previously hard-coded grid
        ``(300 - 280) / 200``. No guard is applied to the denominator: a
        target whose weighted sum is zero yields ``inf``/``nan``.
    """
    # Constant spacing of the energy grid the spectra are sampled on.
    dE = (e_max - e_min) / n_points
    nom = torch.sum(dE * torch.pow(target - prediction, 2))
    denom = torch.sum(dE * target)
    return torch.sqrt(nom) / denom

def RMSE(prediction, target):
    """Root-mean-square error between two tensors.

    Args:
        prediction: Tensor of predicted values.
        target: Tensor of reference values, broadcastable against
            ``prediction``.

    Returns:
        A zero-dimensional tensor: the square root of the mean squared
        difference taken over all elements.
    """
    residual = target - prediction
    mean_squared = residual.pow(2).mean()
    return mean_squared.sqrt()

In [8]:
# --- Training configuration
model_name = 'xasnet_model'  # name given to the ML model
num_epochs = 500             # number of epochs to run
lr = 0.01                    # initial learning rate
# Epochs at which the learning rate is reduced in steps: 10, 20, ..., 90
milestones = list(range(10, 100, 10))

In [9]:
# --- Set device for model to run on
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Create the type of ML model you want to run
xasnet_gnn = XASNet_GNN(
    gnn_name = 'gcn', # model type
    in_channels = [17, 256, 128], # input nodes for each layer
    out_channels = [256, 128, 64], # output nodes for each layer
    num_targets = 200, # nodes for final output
    num_layers = 3, # number of total layers
    heads = 1
).to(device)

# --- Location of a previously saved model
# The directory name used to carry a stray trailing comma ('./best_model,'),
# so the existence check below could never find a file in ./best_model.
# NOTE(review): assumes checkpoints are written to ./best_model/<model_name>
# - confirm against GNNTrainer.
path_to_model = osp.join('./best_model', model_name)

# --- Check if there is an already existing model
if osp.exists(path_to_model):
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine.
    xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('Model is not loaded.')

Model is not loaded.


In [10]:
# --- Show the model architecture and the device it was moved to
print(
    xasnet_gnn,
    '----',
    f' Model will be trained on: {device}',
    sep='\n',
)

XASNet_GNN(
  (batch_norms): ModuleList()
  (interaction_layers): ModuleList(
    (0): GCNConv(17, 256)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): GCNConv(256, 128)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): GCNConv(128, 64)
  )
  (dropout): Dropout(p=0.6, inplace=False)
  (out): Linear(in_features=64, out_features=200, bias=True)
)
----
 Model will be trained on: cpu


In [11]:
# --- Select the model instance that the optimizer and trainer cells below use
chosen_model = xasnet_gnn

In [12]:
# --- Optimizer, candidate loss functions and learning-rate schedule
optimizer = torch.optim.AdamW(
    chosen_model.parameters(),
    lr=lr,
)

# Built-in losses kept available alongside the custom cost functions
loss_fn = torch.nn.L1Loss()
loss_fn2 = torch.nn.MSELoss()

# Multiply the learning rate by 0.8 at every epoch listed in `milestones`
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=milestones,
    gamma=0.8,
)

In [13]:
# --- Create the trainer
trainer = GNNTrainer(
    model=chosen_model,
    model_name=model_name,
    device=device,
    metric_path='./metrics',
)

In [14]:
# --- Train the ML model
# RSE_loss is the loss handed to the trainer here; loss_fn / loss_fn2 are
# not passed in this call.
trainer.train_val(
    train_loader,
    val_loader,
    optimizer,
    RSE_loss,
    scheduler,
    num_epochs,
    write_every=25,
    train_graphnet=False,
)

  1%|          | 3/500 [00:00<00:40, 12.28it/s]

epoch 0 | average train loss = 0.00009  and average validation loss = 0.00207  |learning rate = 0.01000


  6%|▌         | 29/500 [00:01<00:29, 16.20it/s]

epoch 25 | average train loss = 0.00002  and average validation loss = 0.00049  |learning rate = 0.00640


 11%|█         | 53/500 [00:03<00:26, 16.56it/s]

epoch 50 | average train loss = 0.00002  and average validation loss = 0.00043  |learning rate = 0.00328


 15%|█▌        | 77/500 [00:04<00:28, 14.81it/s]

epoch 75 | average train loss = 0.00002  and average validation loss = 0.00045  |learning rate = 0.00210


 21%|██        | 103/500 [00:06<00:25, 15.48it/s]

epoch 100 | average train loss = 0.00002  and average validation loss = 0.00043  |learning rate = 0.00134


 26%|██▌       | 129/500 [00:08<00:24, 15.30it/s]

epoch 125 | average train loss = 0.00001  and average validation loss = 0.00043  |learning rate = 0.00134


 31%|███       | 153/500 [00:09<00:22, 15.19it/s]

epoch 150 | average train loss = 0.00001  and average validation loss = 0.00042  |learning rate = 0.00134


 36%|███▌      | 179/500 [00:11<00:20, 15.64it/s]

epoch 175 | average train loss = 0.00001  and average validation loss = 0.00043  |learning rate = 0.00134


 41%|████      | 203/500 [00:13<00:18, 15.85it/s]

epoch 200 | average train loss = 0.00001  and average validation loss = 0.00041  |learning rate = 0.00134


 46%|████▌     | 229/500 [00:14<00:17, 15.45it/s]

epoch 225 | average train loss = 0.00001  and average validation loss = 0.00042  |learning rate = 0.00134


 51%|█████     | 253/500 [00:16<00:16, 14.69it/s]

epoch 250 | average train loss = 0.00001  and average validation loss = 0.00040  |learning rate = 0.00134


 55%|█████▌    | 277/500 [00:17<00:13, 15.99it/s]

epoch 275 | average train loss = 0.00001  and average validation loss = 0.00041  |learning rate = 0.00134


 61%|██████    | 303/500 [00:19<00:12, 15.76it/s]

epoch 300 | average train loss = 0.00001  and average validation loss = 0.00039  |learning rate = 0.00134


 66%|██████▌   | 329/500 [00:21<00:10, 15.80it/s]

epoch 325 | average train loss = 0.00001  and average validation loss = 0.00040  |learning rate = 0.00134


 71%|███████   | 353/500 [00:22<00:09, 15.59it/s]

epoch 350 | average train loss = 0.00001  and average validation loss = 0.00039  |learning rate = 0.00134


 76%|███████▌  | 379/500 [00:24<00:07, 15.61it/s]

epoch 375 | average train loss = 0.00001  and average validation loss = 0.00039  |learning rate = 0.00134


 81%|████████  | 403/500 [00:25<00:06, 15.75it/s]

epoch 400 | average train loss = 0.00001  and average validation loss = 0.00039  |learning rate = 0.00134


 85%|████████▌ | 427/500 [00:27<00:04, 15.86it/s]

epoch 425 | average train loss = 0.00001  and average validation loss = 0.00039  |learning rate = 0.00134


 91%|█████████ | 453/500 [00:28<00:02, 16.05it/s]

epoch 450 | average train loss = 0.00001  and average validation loss = 0.00038  |learning rate = 0.00134


 96%|█████████▌| 479/500 [00:30<00:01, 15.74it/s]

epoch 475 | average train loss = 0.00001  and average validation loss = 0.00039  |learning rate = 0.00134


100%|██████████| 500/500 [00:32<00:00, 15.60it/s]
