In [1]:
# --- Standard libraries
import os.path as osp
import pandas as pd
import numpy as np
import pickle as pkl
# --- Matplotlib
import matplotlib.pyplot as plt
import imageio
# --- PyTorch and PyG
import torch
from torch_geometric.loader import DataLoader
# --- XASNet
from XASNet.data import QM9_XAS
from XASNet.data import save_split
from XASNet.models import XASNet_GNN, XASNet_GAT, XASNet_GraphNet
from XASNet.trainer import GNNTrainer

In [2]:
# --- Load in the dataset
root = './datasets/mol_dataset.pt'
go_spec = QM9_XAS(root=root,
                  raw_dir='./datasets/',
                  spectra=[])

In [3]:
# --- Print details of the dataset
print(go_spec)
print('------------')
print(f'Number of graphs: {len(go_spec)}')
print(f'Number of features: {go_spec.num_features}')
print('')

# --- Print details of the first molecule/graph in dataset
data = go_spec[0]

print(data)
print('------------')
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

QM9_XAS(319)
------------
Number of graphs: 319
Number of features: 17

Data(x=[28, 17], edge_index=[2, 72], edge_attr=[72, 6], spectrum=[200], idx=[1], smiles='c12[c:2]3[cH:1][cH:23][c:22]4[c:20]1[c:16]1[c:18]([cH:19][c:21]4[OH:24])[CH:17]4[CH:15]([C:14]5([OH:27])[CH:12]1[C:8]16[c:4]2[c:6]([cH:5][cH:3]3)[CH:7]=[CH:9][C:10]1([CH:11]=[CH:13]5)[O:26]6)[O:25]4')
------------
Number of nodes: 28
Number of edges: 72
Average node degree: 2.57
Has isolated nodes: False
Has self loops: False
Is undirected: True


In [4]:
# --- Create spilt file with the dataset
# --- split into test, validation and test datasets
idxs = save_split(
    path='./datasets/xasnet-mol-split.npz',
    ndata=len(go_spec),
    ntrain=244,
    nval=25,
    ntest=48,
    save_split=True,
    shuffle=True, 
    print_nsample=True
)

In [5]:
# --- Create variables for each dataset split
train_go = [go_spec[i] for i in idxs['train']]
val_go = [go_spec[i] for i in idxs['val']]
test_go = [go_spec[i] for i in idxs['test']]

# --- Save datasets splits into dataloaders
train_loader = DataLoader(train_go, batch_size=244, shuffle=True)
val_loader = DataLoader(val_go, batch_size=25, shuffle=True)
test_loader = DataLoader(test_go, batch_size=48, shuffle=False)

print(f'Training dataset length: {len(train_go)}, compiled in {len(train_loader)} loaders')
print(f'Validation dataset length: {len(val_go)}, compiled in {len(val_loader)} loaders')
print(f'Test dataset length: {len(test_go)}, compiled in {len(test_loader)} loaders')

Training dataset length: 244, compiled in 1 loaders
Validation dataset length: 25, compiled in 1 loaders
Test dataset length: 48, compiled in 1 loaders


In [6]:
# --- Save the dataloader to a file
torch.save(test_go, './datasets/test_mol_dataset.pt')

In [7]:
# --- Define cost functions
def RSE_loss(prediction, target):
    dE = (300 - 280) / 200
    nom = torch.sum(dE*torch.pow((target-prediction), 2))
    denom = torch.sum(dE*target)
    return torch.sqrt(nom) / denom 

def RMSE(prediction, target):
    return torch.sqrt(torch.mean((target - prediction)**2))

In [8]:
# --- Set name for ML model
model_name = 'xasnet_model'
# --- Set number of epochs to run
num_epochs = 500
# --- Set the learning rate 
lr = 0.01
# --- Milestones to reduce learning rate in steps 
milestones = np.arange(10, 100, 10).tolist()

##### XASNet GNN

In [9]:
# --- Set device for model to run on
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Create the type of ML model you want to run
xasnet_gnn = XASNet_GNN(
    gnn_name = 'gcn', # model type
    in_channels = [17, 256, 128], # input nodes for each layer
    out_channels = [256, 128, 64], # output nodes for each layer
    num_targets = 200, # nodes for final output
    num_layers = 3, # number of total layers
    heads = 1
).to(device)

# --- Location to save model
path_to_model = osp.join('./best_model,', model_name)

# --- Check if there is an already existing model
if osp.exists(path_to_model):
    xasnet_gnn.load_state_dict(torch.load(path_to_model))
else:
    print('Model is not loaded.')

Model is not loaded.


In [10]:
# --- View the details of the created model
print(xasnet_gnn)
print('----')
print(f' Model will be trained on: {device}')

XASNet_GNN(
  (batch_norms): ModuleList()
  (interaction_layers): ModuleList(
    (0): GCNConv(17, 256)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): GCNConv(256, 128)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): GCNConv(128, 64)
  )
  (dropout): Dropout(p=0.6, inplace=False)
  (out): Linear(in_features=64, out_features=200, bias=True)
)
----
 Model will be trained on: cpu


##### XASNet GAT

In [9]:
# --- Set device for model to run on
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Create the type of ML model you want to run
xasnet_gat = XASNet_GAT(
    node_features_dim=17,
    in_channels=[128, 128, 128, 128],
    out_channels=[128, 128, 128, 400],
    targets=200,
    n_layers=4,
    n_heads=3,
    gat_type='gatv2_custom',
    use_residuals=True,
    use_jk=True
).to(device)

# --- Location to save model
path_to_model = osp.join('./best_model,', model_name)

# --- Check if there is an already existing model
if osp.exists(path_to_model):
    xasnet_gat.load_state_dict(torch.load(path_to_model))
else:
    print('Model is not loaded.')

Model is not loaded.


In [10]:
# --- View the details of the created model
print(xasnet_gat)
print('----')
print(f' Model will be trained on: {device}')

XASNet_GAT(
  (pre_layer): LinearLayer(
    (linear): Linear(in_features=17, out_features=128, bias=False)
    (_activation): ReLU(inplace=True)
  )
  (res_block): Residual_block(
    (res_layers): Sequential(
      (0): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
      (1): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
      (2): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
      (3): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
    )
  )
  (gat_layers): ModuleList(
    (0): GATv2LayerCus(128, 128)
    (1): ReLU(inplace=True)
    (2): GATv2LayerCus(384, 128)
    (3): ReLU(inplace=True)
    (4): GATv2LayerCus(384, 128)
    (5): ReLU

##### XASNet GraphNet

In [9]:
# --- Set device for model to run on
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Create the type of ML model you want to run
xasnet_graphnet = XASNet_GraphNet(
    node_dim=17,
    edge_dim=6,
    hidden_channels=512,
    out_channels=200,
    gat_hidd=512,
    gat_out=200,
    n_layers=3,
    n_targets=200
).to(device)

# --- Location to save model
path_to_model = osp.join('./best_model,', model_name)

# --- Check if there is an already existing model
if osp.exists(path_to_model):
    xasnet_graphnet.load_state_dict(torch.load(path_to_model))
else:
    print('Model is not loaded.')

Model is not loaded.


In [10]:
# --- View the details of the created model
print(xasnet_graphnet)
print('----')
print(f' Model will be trained on: {device}')

XASNet_GraphNet(
  (graphnets): ModuleList(
    (0): GraphNetwork(
      (gatencoder): GATEncoder(
        (gats): ModuleList(
          (0): GATv2Conv(17, 512, heads=3)
          (1): ReLU(inplace=True)
          (2): GATv2Conv(1536, 512, heads=3)
          (3): ReLU(inplace=True)
          (4): GATv2Conv(1536, 512, heads=3)
          (5): ReLU(inplace=True)
          (6): GATv2Conv(1536, 200, heads=1)
        )
      )
      (node_model): NodeModel(
        (mlp): Sequential(
          (0): Linear(in_features=229, out_features=512, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=512, out_features=200, bias=True)
          (3): ReLU(inplace=True)
          (4): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
        )
      )
      (edge_model): EdgeModel(
        (mlp): Sequential(
          (0): Linear(in_features=606, out_features=512, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=512, out_features=200, bias=True

##### Train model

In [12]:
chosen_model = xasnet_graphnet

In [13]:
# --- Set additional ML parameters
optimizer = torch.optim.AdamW(chosen_model.parameters(), lr=lr, weight_decay=1e-5) #betas=(0.9, 0.99), eps=1e-08, amsgrad=True)
loss_fn = torch.nn.L1Loss()
loss_fn2 = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 
                                                 milestones=milestones,
                                                 gamma=0.8)

In [14]:
# --- Create trainier
trainer = GNNTrainer(model = chosen_model,
                     model_name = model_name,
                     device = device,
                     metric_path = './metrics')

In [15]:
# --- Train the ML model
trainer.train_val(train_loader, val_loader, optimizer, RSE_loss,
                  scheduler, num_epochs, write_every=25, train_graphnet=True)

  0%|          | 1/500 [00:01<09:01,  1.08s/it]

epoch 0 | average train loss = 0.00432  and average validation loss = 0.11220  |learning rate = 0.01000


  5%|▌         | 26/500 [00:14<04:02,  1.95it/s]

epoch 25 | average train loss = 0.00107  and average validation loss = 0.00888  |learning rate = 0.00640


 10%|█         | 51/500 [00:27<03:55,  1.91it/s]

epoch 50 | average train loss = 0.00022  and average validation loss = 0.00146  |learning rate = 0.00328


 15%|█▌        | 76/500 [00:39<03:12,  2.20it/s]

epoch 75 | average train loss = 0.00008  and average validation loss = 0.00062  |learning rate = 0.00210


 20%|██        | 101/500 [00:50<03:07,  2.13it/s]

epoch 100 | average train loss = 0.00005  and average validation loss = 0.00051  |learning rate = 0.00134


 25%|██▌       | 126/500 [01:01<02:39,  2.34it/s]

epoch 125 | average train loss = 0.00004  and average validation loss = 0.00050  |learning rate = 0.00134


 30%|███       | 151/500 [01:13<02:38,  2.21it/s]

epoch 150 | average train loss = 0.00003  and average validation loss = 0.00051  |learning rate = 0.00134


 35%|███▌      | 176/500 [01:24<02:12,  2.44it/s]

epoch 175 | average train loss = 0.00003  and average validation loss = 0.00048  |learning rate = 0.00134


 40%|████      | 201/500 [01:34<02:03,  2.42it/s]

epoch 200 | average train loss = 0.00002  and average validation loss = 0.00056  |learning rate = 0.00134


 45%|████▌     | 226/500 [01:44<01:53,  2.42it/s]

epoch 225 | average train loss = 0.00002  and average validation loss = 0.00057  |learning rate = 0.00134


 50%|█████     | 251/500 [01:54<01:39,  2.51it/s]

epoch 250 | average train loss = 0.00002  and average validation loss = 0.00051  |learning rate = 0.00134


 55%|█████▌    | 276/500 [02:05<01:39,  2.26it/s]

epoch 275 | average train loss = 0.00002  and average validation loss = 0.00045  |learning rate = 0.00134


 60%|██████    | 301/500 [02:15<01:20,  2.47it/s]

epoch 300 | average train loss = 0.00002  and average validation loss = 0.00049  |learning rate = 0.00134


 65%|██████▌   | 326/500 [02:25<01:09,  2.51it/s]

epoch 325 | average train loss = 0.00002  and average validation loss = 0.00051  |learning rate = 0.00134


 70%|███████   | 351/500 [02:35<01:01,  2.42it/s]

epoch 350 | average train loss = 0.00002  and average validation loss = 0.00049  |learning rate = 0.00134


 75%|███████▌  | 376/500 [02:46<00:52,  2.34it/s]

epoch 375 | average train loss = 0.00002  and average validation loss = 0.00046  |learning rate = 0.00134


 80%|████████  | 401/500 [02:58<00:50,  1.95it/s]

epoch 400 | average train loss = 0.00002  and average validation loss = 0.00049  |learning rate = 0.00134


 85%|████████▌ | 426/500 [03:10<00:36,  2.03it/s]

epoch 425 | average train loss = 0.00002  and average validation loss = 0.00047  |learning rate = 0.00134


 90%|█████████ | 451/500 [03:22<00:23,  2.11it/s]

epoch 450 | average train loss = 0.00002  and average validation loss = 0.00051  |learning rate = 0.00134


 95%|█████████▌| 476/500 [03:34<00:10,  2.18it/s]

epoch 475 | average train loss = 0.00002  and average validation loss = 0.00049  |learning rate = 0.00134


100%|██████████| 500/500 [03:45<00:00,  2.22it/s]
