In [1]:
import sys
sys.path.append('/home/samjhall/github/XASNet-XAI/src')

import os
import os.path as osp
import numpy as np
import pickle as pkl
import torch
from torch_geometric.loader import DataLoader


from XASNet.data import QM9_XAS
from XASNet.data import save_split

from XASNet.models import XASNet_GNN, XASNet_GAT, XASNet_GraphNet

from XASNet.trainer import GNNTrainer

In [2]:
# Processed dataset file handed to QM9_XAS as `root`.
root = '/home/samjhall/github/GO_molecule_GNN/processed/train_mol.pt'
# NOTE(review): the effect of `spectra=[]` depends on QM9_XAS — confirm.
dataset_options = dict(raw_dir='./processed/', spectra=[])
go_spec = QM9_XAS(root=root, **dataset_options)

In [3]:
# Dataset-level overview: repr, graph count and node-feature width.
print(
    go_spec,
    '------------',
    f'Number of graphs: {len(go_spec)}',
    f'Number of features: {go_spec.num_features}',
    '',
    sep='\n',
)

# Structural overview of the first molecular graph in the dataset.
data = go_spec[0]
print(data, '------------', sep='\n')
print(
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
    sep='\n',
)

QM9_XAS(319)
------------
Number of graphs: 319
Number of features: 10

Data(x=[30, 10], edge_index=[2, 76], edge_attr=[76, 6], spectrum=[200], idx=[1], smiles='[c:0]12[c:4]3[c:8]4[c:10]5[cH:11][cH:14][c:15]6[c:13]4[c:17]4[c:19]([cH:18][cH:16]6)[cH:20][c:22]([OH:25])[c:23]([c:21]14)[CH2:24][CH:1]1[C:2]2([CH:3]=[CH:5][C:6]32[CH:7]([CH:9]5[C:12](=[O:26])[OH:27])[O:29]2)[O:28]1')
------------
Number of nodes: 30
Number of edges: 76
Average node degree: 2.53
Has isolated nodes: False
Has self loops: False
Is undirected: True


In [4]:
# Create (and save) the train/val/test split.
# Seed the global RNGs first so the shuffled split is reproducible under
# Restart & Run All. This seed also fixes the later DataLoader shuffling and
# weight initialisation.
# NOTE(review): save_split's RNG source is not visible here; numpy and torch are
# both seeded to cover the likely cases — TODO confirm.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# The test size is derived so the three subsets always cover the whole dataset
# (252 + 28 + 39 == 319 for the current data).
n_train, n_val = 252, 28
n_test = len(go_spec) - n_train - n_val
assert n_test > 0, 'ntrain + nval must be smaller than the dataset size'

# NOTE(review): this path is under .../work/... while the dataset root is under
# .../github/... — confirm which checkout is intended.
idxs = save_split(
    path='/home/samjhall/work/GO_molecule_GNN/raw/xasnet-split.npz',
    ndata=len(go_spec),
    ntrain=n_train,
    nval=n_val,
    ntest=n_test,
    save_split=True,
    shuffle=True,
    print_nsample=True,
)

In [5]:
# Materialise the train / validation / test subsets from the split indices.
train_go, val_go, test_go = (
    [go_spec[i] for i in idxs[subset]] for subset in ('train', 'val', 'test')
)

In [6]:
# Data loaders. Only the training loader shuffles. Evaluation order should not
# matter for validation/test, so shuffling there only adds nondeterminism.
# With nval=28 and batch_size=30 the validation set is a single batch.
batch_size = 30
train_loader = DataLoader(train_go, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_go, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_go, batch_size=batch_size, shuffle=False)

In [15]:
# Persist the held-out test graphs for later evaluation. Create the target
# directory first so this does not fail on a fresh checkout.
# NOTE(review): this cell was executed out of order (In [15], after training).
test_set_path = './XASNet-data/test_mol_dataset.pt'
os.makedirs(osp.dirname(test_set_path), exist_ok=True)
torch.save(test_go, test_set_path)

In [7]:
def RSE_loss(prediction, target, e_min=270, e_max=300, n_bins=100):
    """Relative spectral error between predicted and target spectra.

    Computes ``sqrt(sum(dE * (target - prediction)**2)) / sum(dE * target)``
    with a constant step ``dE = (e_max - e_min) / n_bins``. The sums run over
    every element of the inputs, i.e. over the whole batch rather than per
    spectrum.

    Args:
        prediction: predicted spectra tensor (assumed same shape as ``target``).
        target: reference spectra tensor.
        e_min, e_max: bounds of the energy window. The defaults reproduce the
            original hard-coded 270-300 window.
        n_bins: number of energy steps across the window.

    Returns:
        0-dim tensor holding the relative error.

    NOTE(review): the dataset spectra have 200 points (``spectrum=[200]``,
    ``num_targets=200``) while the default ``n_bins`` is 100; confirm which
    grid spacing is intended.
    """
    dE = (e_max - e_min) / n_bins
    numerator = torch.sum(dE * torch.pow(target - prediction, 2))
    denominator = torch.sum(dE * target)
    return torch.sqrt(numerator) / denominator

def RMSE(prediction, target):
    """Root-mean-square error over every element of the two tensors.

    Returns a 0-dim tensor equal to ``sqrt(mean((target - prediction) ** 2))``.
    """
    squared_residual = (target - prediction).pow(2)
    return squared_residual.mean().sqrt()

In [8]:
# Run configuration.
model_name = 'xasnet_model'
num_epochs = 300                       # training epochs
lr = 1e-3                              # initial learning rate
milestones = list(range(10, 100, 10))  # epochs at which MultiStepLR decays the LR
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

In [9]:
# Four stacked GCN layers. in_channels[0] must match the node-feature width
# (go_spec.num_features == 10) and num_targets the spectrum length (200).
xasnet_gnn = XASNet_GNN(
    gnn_name='gcn',
    in_channels=[10, 100, 200, 300],
    out_channels=[100, 200, 300, 300],
    num_targets=200,
    num_layers=4,
    heads=1,
).to(device)

# Checkpoint location for previously trained weights.
# NOTE(review): presumably where GNNTrainer stores its best model — confirm.
path_to_model = osp.join('./best_model', model_name)

if osp.exists(path_to_model):
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('Model is not loaded.')

Model is not loaded.


In [10]:
# Display the module tree to check the layer stack matches the constructor arguments.
xasnet_gnn

XASNet_GNN(
  (interaction_layers): ModuleList(
    (0): GCNConv(10, 100)
    (1): ReLU(inplace=True)
    (2): GCNConv(100, 200)
    (3): ReLU(inplace=True)
    (4): GCNConv(200, 300)
    (5): ReLU(inplace=True)
    (6): GCNConv(300, 300)
  )
  (dropout): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=300, out_features=200, bias=True)
)

In [11]:
# Candidate loss modules.
# NOTE(review): the train_val call passes RMSE, not these; they are kept in
# case other cells use them.
loss_fn = torch.nn.L1Loss()
loss_fn2 = torch.nn.MSELoss()

# AdamW with a step-wise LR decay: multiply by 0.8 at each milestone epoch.
optimizer = torch.optim.AdamW(params=xasnet_gnn.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=milestones,
    gamma=0.8,
)

In [12]:
# Trainer wrapping the model.
# NOTE(review): metric_path is presumably where training metrics are written.
trainer_options = {
    'model': xasnet_gnn,
    'model_name': model_name,
    'device': device,
    'metric_path': './metrics',
}
trainer = GNNTrainer(**trainer_options)

In [13]:
# Run training with validation. RMSE is presumably the loss callable (4th
# positional argument). Progress is logged every 25 epochs.
train_val_args = (train_loader, val_loader, optimizer, RMSE, scheduler, num_epochs)
trainer.train_val(*train_val_args, train_graphnet=False, write_every=25)

  1%|          | 3/300 [00:00<00:27, 10.90it/s]

time = 0.00 mins mins
epoch 0 | average train loss = 0.02  and average validation loss = 0.01  |learning rate = 0.00100


 10%|█         | 30/300 [00:01<00:12, 21.57it/s]

time = 0.02 mins mins
epoch 25 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00064


 18%|█▊        | 54/300 [00:02<00:10, 23.22it/s]

time = 0.04 mins mins
epoch 50 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00033


 26%|██▌       | 78/300 [00:03<00:09, 22.50it/s]

time = 0.06 mins mins
epoch 75 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00021


 35%|███▌      | 105/300 [00:04<00:08, 23.52it/s]

time = 0.08 mins mins
epoch 100 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00013


 43%|████▎     | 129/300 [00:05<00:07, 23.11it/s]

time = 0.09 mins mins
epoch 125 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00013


 51%|█████     | 153/300 [00:06<00:06, 24.11it/s]

time = 0.11 mins mins
epoch 150 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00013


 60%|██████    | 180/300 [00:07<00:05, 23.65it/s]

time = 0.13 mins mins
epoch 175 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00013


 68%|██████▊   | 204/300 [00:08<00:04, 23.74it/s]

time = 0.15 mins mins
epoch 200 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00013


 76%|███████▌  | 228/300 [00:09<00:03, 23.63it/s]

time = 0.16 mins mins
epoch 225 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00013


 85%|████████▌ | 255/300 [00:11<00:01, 22.96it/s]

time = 0.18 mins mins
epoch 250 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00013


 93%|█████████▎| 279/300 [00:12<00:00, 23.62it/s]

time = 0.20 mins mins
epoch 275 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00013


100%|██████████| 300/300 [00:13<00:00, 23.01it/s]
