In [2]:
import sys

# --- Make the local XASNet sources importable
# A raw string keeps the Windows backslashes literal; in a normal string
# '\g', '\X' and '\s' are invalid escape sequences and trigger a warning on
# recent Python versions.
#sys.path.append('/home/samjhall/github/XASNet-XAI/src')
XASNET_SRC = r'D:\github\XASNet-XAI\src'
# Guard the append so re-running this cell does not add duplicate entries.
if XASNET_SRC not in sys.path:
    sys.path.append(XASNET_SRC)

import os.path as osp
import numpy as np
import pickle as pkl
import torch
from torch_geometric.loader import DataLoader


from XASNet.data import QM9_XAS
from XASNet.data import save_split

from XASNet.models import XASNet_GNN, XASNet_GAT, XASNet_GraphNet

from XASNet.trainer import GNNTrainer

In [9]:
# --- Instantiate the QM9_XAS graph dataset
# `root` names the .pt dataset file and `raw_dir` the folder that holds it.
# NOTE(review): an empty `spectra` list is passed; presumably the spectra are
# already stored in the .pt file - confirm against QM9_XAS.
root = './XASNet-data/mol_dataset.pt'
dataset_kwargs = {
    'root': root,
    'raw_dir': './XASNet-data/',
    'spectra': [],
}
go_spec = QM9_XAS(**dataset_kwargs)

In [10]:
# --- Summarise the dataset as a whole
# One print call per group; sep='\n' puts every item on its own line.
print(
    go_spec,
    '------------',
    f'Number of graphs: {len(go_spec)}',
    f'Number of features: {go_spec.num_features}',
    '',
    sep='\n',
)

# --- Inspect the first molecule/graph in the dataset
data = go_spec[0]

print(
    data,
    '------------',
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
    sep='\n',
)

QM9_XAS(319)
------------
Number of graphs: 319
Number of features: 10

Data(x=[30, 10], edge_index=[2, 76], edge_attr=[76, 6], spectrum=[200], idx=[1], smiles='[c:0]12[c:4]3[c:8]4[c:10]5[cH:11][cH:14][c:15]6[c:13]4[c:17]4[c:19]([cH:18][cH:16]6)[cH:20][c:22]([OH:25])[c:23]([c:21]14)[CH2:24][CH:1]1[C:2]2([CH:3]=[CH:5][C:6]32[CH:7]([CH:9]5[C:12](=[O:26])[OH:27])[O:29]2)[O:28]1')
------------
Number of nodes: 30
Number of edges: 76
Average node degree: 2.53
Has isolated nodes: False
Has self loops: False
Is undirected: True


In [12]:
# --- Create the split file for the dataset
# Split into train, validation and test subsets; the three sizes below add up
# to the 319 graphs reported for this dataset.
split_sizes = {'ntrain': 252, 'nval': 28, 'ntest': 39}
idxs = save_split(
    path='./raw/xasnet-split.npz',
    ndata=len(go_spec),
    save_split=True,
    shuffle=True,
    print_nsample=True,
    **split_sizes,
)

{'train': 252, 'val': 28, 'test': 39}


In [13]:
# --- Create variables for each dataset split
# Each split is materialised as a plain list of graphs picked by saved index.
train_go = [go_spec[i] for i in idxs['train']]
val_go = [go_spec[i] for i in idxs['val']]
test_go = [go_spec[i] for i in idxs['test']]

# --- Wrap the dataset splits in dataloaders
# Only the training loader shuffles. The validation and test loaders keep a
# fixed order so that evaluation is deterministic and the two agree.
batch_size = 30
train_loader = DataLoader(train_go, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_go, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_go, batch_size=batch_size, shuffle=False)

In [15]:
# --- Save the held-out test graphs to a file
# What is written is the list `test_go` itself, not the DataLoader around it.
test_dataset_path = './XASNet-data/test_mol_dataset.pt'
torch.save(test_go, test_dataset_path)

In [16]:
# --- Define cost functions
def RSE_loss(prediction, target, e_min=280.0, e_max=300.0, n_points=200):
    """Relative spectral error between a predicted and a target spectrum.

    Computes ``sqrt(sum(dE * (target - prediction)**2)) / sum(dE * target)``
    with a uniform grid spacing ``dE = (e_max - e_min) / n_points``.

    Args:
        prediction: Tensor of predicted intensities.
        target: Tensor of reference intensities, broadcastable with
            ``prediction``.
        e_min: Lower edge of the energy window (default 280.0).
        e_max: Upper edge of the energy window (default 300.0).
        n_points: Number of grid points across the window (default 200).
            The defaults reproduce the previously hard-coded spacing of 0.1.

    Returns:
        A 0-dim tensor holding the loss. The value is inf/nan when
        ``target`` sums to zero, because the denominator is not clamped.

    Raises:
        ValueError: If ``n_points`` is not positive.
    """
    if n_points <= 0:
        raise ValueError('n_points must be a positive number')
    dE = (e_max - e_min) / n_points
    nom = torch.sum(dE * torch.pow(target - prediction, 2))
    denom = torch.sum(dE * target)
    return torch.sqrt(nom) / denom

def RMSE(prediction, target):
    """Root-mean-square error between ``prediction`` and ``target``.

    The mean is taken over every element of the residual, so the result is a
    0-dim tensor.
    """
    residual = target - prediction
    mean_squared = torch.mean(residual ** 2)
    return torch.sqrt(mean_squared)

In [17]:
# --- Training configuration
# Name used for the ML model
model_name = 'xasnet_model'
# Number of epochs to run
num_epochs = 300
# Learning rate
lr = 1e-3
# Epochs at which the learning rate is stepped down: 10, 20, ..., 90
milestones = list(range(10, 100, 10))

In [18]:
# --- Set device for model to run on
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Create the type of ML model you want to run
# Four 'gcn' layers mapping 10 -> 100 -> 200 -> 300 -> 300 channels, followed
# by a 200-value output (one per spectrum point).
xasnet_gnn = XASNet_GNN(
    gnn_name = 'gcn',
    in_channels = [10, 100, 200, 300],
    out_channels = [100, 200, 300, 300],
    num_targets = 200,
    num_layers = 4,
    heads = 1
).to(device)

# --- Resume from previously saved weights when a checkpoint exists
# NOTE(review): directory is './best_model' (no trailing comma) - confirm this
# matches where GNNTrainer writes its checkpoints.
path_to_model = osp.join('./best_model', model_name)

if osp.exists(path_to_model):
    # map_location lets a checkpoint saved on a GPU machine load on a
    # CPU-only one (and vice versa).
    xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('Model is not loaded.')

Model is not loaded.


In [19]:
# --- Display the model architecture (the bare expression is shown as the cell output)
xasnet_gnn

XASNet_GNN(
  (interaction_layers): ModuleList(
    (0): GCNConv(10, 100)
    (1): ReLU(inplace=True)
    (2): GCNConv(100, 200)
    (3): ReLU(inplace=True)
    (4): GCNConv(200, 300)
    (5): ReLU(inplace=True)
    (6): GCNConv(300, 300)
  )
  (dropout): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=300, out_features=200, bias=True)
)

In [20]:
# --- Optimiser, loss modules and learning-rate schedule
optimizer = torch.optim.AdamW(xasnet_gnn.parameters(), lr=lr)

# NOTE(review): the visible training call passes RMSE rather than either of
# these two loss modules.
loss_fn, loss_fn2 = torch.nn.L1Loss(), torch.nn.MSELoss()

# Multiply the learning rate by 0.8 at every epoch listed in `milestones`.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=milestones,
    gamma=0.8,
)

In [21]:
# --- Create the trainer
# Metrics are directed to './metrics' via `metric_path`.
trainer_kwargs = {
    'model': xasnet_gnn,
    'model_name': model_name,
    'device': device,
    'metric_path': './metrics',
}
trainer = GNNTrainer(**trainer_kwargs)

In [22]:
# --- Train the ML model
# NOTE(review): the recorded run of this cell ended with
# "RuntimeError: Tried to instantiate dummy base class Event". PyTorch raises
# that message when a CUDA event is constructed on a build/machine without
# CUDA; presumably GNNTrainer times epochs with torch.cuda.Event - confirm,
# then either run on a CUDA-enabled setup or patch the trainer.
trainer.train_val(
    train_loader,
    val_loader,
    optimizer,
    RMSE,            # loss function used for training and validation
    scheduler,
    num_epochs,
    write_every=25,
    train_graphnet=False,
)

RuntimeError: Tried to instantiate dummy base class Event