In [1]:
# --- Standard libraries
import os.path as osp
import pandas as pd
import numpy as np
import pickle as pkl
# --- Matplotlib
import matplotlib.pyplot as plt
#import imageio
# --- PyTorch and PyG
import torch
from torch.nn import MSELoss
from torch_geometric.loader import DataLoader
from torch_geometric.utils import degree
# --- XASNet
from XASNet.data import QM9_XAS
from XASNet.data import save_split
from XASNet.models import XASNet_GNN, XASNet_GAT, XASNet_GraphNet#, XASNet_NNconv, XASNet_PNA
from XASNet.trainer import GNNTrainer
torch.__version__

'2.3.1+cu121'

In [2]:
# --- Load in the dataset
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only point this at dataset files from a trusted source.
go_spec = torch.load('../datasets/mol_dataset.pt')

In [4]:
# --- Dataset-level summary: repr, number of graphs and number of features
print(go_spec, '------------', sep='\n')
print(
    f'Number of graphs: {len(go_spec)}',
    f'Number of features: {go_spec.num_features}',
    '',
    sep='\n',
)

# --- Graph-level summary for the first molecule/graph in the dataset
data = go_spec[0]
print(data, '------------', sep='\n')
print(
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
    sep='\n',
)

XASDataset_mol(317)
------------
Number of graphs: 317
Number of features: 15

Data(x=[32, 15], edge_index=[2, 76], edge_attr=[76, 5], spectrum=[200], idx=[1], smiles='c12[c:2]3[c:1]([H:32])[c:25]([H:41])[c:24]4[c:22]1[c:17]1[c:19]([c:20]([H:39])[c:23]4[H:40])[C:18]([C:21](=[O:26])[O:27][H:42])=[C:16]([H:38])[C:15]4=[C:14]([H:37])[C:12]([H:36])=[C:11]5[C:9]([O:30][H:44])([C:4]2([H:47])[C:6]([O:31][H:46])([C:5]([C:8](=[O:28])[O:29][H:43])=[C:3]3[H:33])[C:7]([H:34])=[C:10]5[H:35])[C:13]41[H:45]')
------------
Number of nodes: 32
Number of edges: 76
Average node degree: 2.38
Has isolated nodes: False
Has self loops: False
Is undirected: True


In [6]:
# --- Create split file with the dataset
# --- split into train, validation and test datasets
# NOTE(review): ntrain + nval + ntest = 238 + 29 + 48 = 315, while the dataset
# summary printed earlier reports 317 graphs — confirm how save_split treats
# the two samples that are not assigned to any split.
# NOTE(review): shuffle=True with no seed set in this notebook — presumably the
# split differs between runs unless save_split seeds internally; TODO confirm.
idxs = save_split(
    path='./xasnet-mol-split.npz',
    ndata=len(go_spec),
    ntrain=238,
    nval=29,
    ntest=48,
    save_split=True,
    shuffle=True, 
    print_nsample=True
)

{'train': 238, 'val': 29, 'test': 48}


In [7]:
# --- Create variables for each dataset split
train_go = [go_spec[i] for i in idxs['train']]
val_go = [go_spec[i] for i in idxs['val']]
test_go = [go_spec[i] for i in idxs['test']]

# --- Wrap each split in a single-batch (full-batch) dataloader
# The batch size is taken from the split length instead of being hard-coded, so
# the loaders stay full-batch if the split sizes passed to save_split change.
# max(1, ...) keeps the batch size valid even for an empty split.
train_loader = DataLoader(train_go, batch_size=max(1, len(train_go)), shuffle=True)
val_loader = DataLoader(val_go, batch_size=max(1, len(val_go)), shuffle=True)
test_loader = DataLoader(test_go, batch_size=max(1, len(test_go)), shuffle=False)

print(f'Training dataset length: {len(train_go)}, compiled in {len(train_loader)} loaders')
print(f'Validation dataset length: {len(val_go)}, compiled in {len(val_loader)} loaders')
print(f'Test dataset length: {len(test_go)}, compiled in {len(test_loader)} loaders')

Training dataset length: 238, compiled in 1 loaders
Validation dataset length: 29, compiled in 1 loaders
Test dataset length: 48, compiled in 1 loaders


In [8]:
# --- Define cost functions
def RSE_loss(prediction, target, e_min=280.0, e_max=300.0, n_points=200):
    """Relative spectral error between a predicted and a target spectrum.

    Computes ``sqrt(sum(dE * (target - prediction)**2)) / sum(dE * target)``
    where ``dE = (e_max - e_min) / n_points`` is the grid spacing.

    Args:
        prediction: Tensor of predicted intensities.
        target: Tensor of reference intensities, broadcastable with
            ``prediction``.
        e_min: Lower bound of the spectral grid (default 280).
        e_max: Upper bound of the spectral grid (default 300).
        n_points: Number of grid points used to derive the spacing
            (default 200, giving dE = 0.1).

    Returns:
        Scalar tensor with the relative error. The result is inf/nan when
        ``target`` sums to zero, because the denominator is not guarded.
    """
    # Grid spacing; with the defaults this is the original (300 - 280) / 200.
    dE = (e_max - e_min) / n_points
    nom = torch.sum(dE * torch.pow(target - prediction, 2))
    denom = torch.sum(dE * target)
    return torch.sqrt(nom) / denom

def RMSE_loss(prediction, target):
    """Root-mean-square error between ``prediction`` and ``target``.

    The squared differences are averaged over every element before the
    square root is taken; the result is a scalar tensor.
    """
    residual = target - prediction
    mean_squared = residual.pow(2).mean()
    return mean_squared.sqrt()

def MSE_loss(prediction, target):
    """Mean squared error between ``prediction`` and ``target``.

    Returns a scalar tensor: the average of the element-wise squared
    differences.
    """
    squared_error = torch.pow(target - prediction, 2)
    return squared_error.mean()

In [9]:
# --- Training length (number of epochs)
num_epochs = 301
# --- Initial learning rate passed to the optimizer
lr = 0.001
# --- Epochs at which the learning rate is reduced in steps
milestones = list(range(100, 900, 200))
print(milestones)

[100, 300, 500, 700]


#### XASNet_GNN

In [10]:
# --- Set name for ML model
model_name = 'xasnet_gnn_model'

# --- Set device for model to run on
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Create the type of ML model you want to run
xasnet_gnn = XASNet_GNN(
    gnn_name = 'gcn', # model type
    in_channels = [15, 256, 128], # input nodes for each layer
    out_channels = [256, 128, 64], # output nodes for each layer
    num_targets = 200, # nodes for final output
    num_layers = 3, # number of total layers
    heads = 0
).to(device)

# --- Location of a previously saved model
# The directory name must be exactly './best_model' (no trailing comma),
# otherwise the existence check below can never find a saved model.
# NOTE(review): presumably this is where GNNTrainer writes checkpoints — confirm.
path_to_model = osp.join('./best_model', model_name)

# --- Check if there is an already existing model
if osp.exists(path_to_model):
    # map_location lets weights saved on a GPU be restored on a CPU-only machine
    xasnet_gnn.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('Model is not loaded.')

Model is not loaded.


In [11]:
# --- Show the architecture of the created model and the device it lives on
print(xasnet_gnn, '----', f' Model will be trained on: {device}', sep='\n')

XASNet_GNN(
  (interaction_layers): ModuleList(
    (0): GCNConv(15, 256)
    (1): ReLU(inplace=True)
    (2): GCNConv(256, 128)
    (3): ReLU(inplace=True)
    (4): GCNConv(128, 64)
  )
  (dropout): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=64, out_features=200, bias=True)
)
----
 Model will be trained on: cuda


#### XASNet_GAT

In [12]:
# --- Set name for ML model
model_name = 'xasnet_gat_model'

# --- Set device for model to run on
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Create the type of ML model you want to run
xasnet_gat = XASNet_GAT(
    node_features_dim=15,
    in_channels=[128, 128, 128, 128],
    out_channels=[128, 128, 128, 400],
    targets=200,
    n_layers=4,
    n_heads=3,
    gat_type='gatv2_custom',
    use_residuals=True,
    use_jk=True
).to(device)

# --- Location of a previously saved model
# The directory name must be exactly './best_model' (no trailing comma),
# otherwise the existence check below can never find a saved model.
# NOTE(review): presumably this is where GNNTrainer writes checkpoints — confirm.
path_to_model = osp.join('./best_model', model_name)

# --- Check if there is an already existing model
if osp.exists(path_to_model):
    # map_location lets weights saved on a GPU be restored on a CPU-only machine
    xasnet_gat.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('Model is not loaded.')

Model is not loaded.


In [13]:
# --- Show the architecture of the created model and the device it lives on
print(xasnet_gat, '----', f' Model will be trained on: {device}', sep='\n')

XASNet_GAT(
  (pre_layer): LinearLayer(
    (linear): Linear(in_features=15, out_features=128, bias=False)
    (_activation): ReLU(inplace=True)
  )
  (res_block): Residual_block(
    (res_layers): Sequential(
      (0): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
      (1): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
      (2): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
      (3): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
    )
  )
  (gat_layers): ModuleList(
    (0): GATv2LayerCus(128, 128)
    (1): ReLU(inplace=True)
    (2): GATv2LayerCus(384, 128)
    (3): ReLU(inplace=True)
    (4): GATv2LayerCus(384, 128)
    (5): ReLU

#### XASNet_GraphNet

In [14]:
# --- Set name for ML model
model_name = 'xasnet_graphnet_model'

# --- Set device for model to run on
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Create the type of ML model you want to run
# NOTE(review): the dataset summary earlier in this notebook shows 5 features
# per edge (edge_attr=[76, 5]) while edge_dim is 6 — confirm this is what
# XASNet_GraphNet expects for this dataset.
xasnet_graphnet = XASNet_GraphNet(
    node_dim=15,
    edge_dim=6,
    hidden_channels=512,
    out_channels=200,
    gat_hidd=512,
    gat_out=200,
    n_layers=3,
    n_targets=200
).to(device)

# --- Location of a previously saved model
# The directory name must be exactly './best_model' (no trailing comma),
# otherwise the existence check below can never find a saved model.
# NOTE(review): presumably this is where GNNTrainer writes checkpoints — confirm.
path_to_model = osp.join('./best_model', model_name)

# --- Check if there is an already existing model
if osp.exists(path_to_model):
    # map_location lets weights saved on a GPU be restored on a CPU-only machine
    xasnet_graphnet.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('Model is not loaded.')

Model is not loaded.


In [15]:
# --- Show the architecture of the created model and the device it lives on
print(xasnet_graphnet, '----', f' Model will be trained on: {device}', sep='\n')

XASNet_GraphNet(
  (graphnets): ModuleList(
    (0): GraphNetwork(
      (gatencoder): GATEncoder(
        (gats): ModuleList(
          (0): GATv2Conv(15, 512, heads=3)
          (1): ReLU(inplace=True)
          (2): GATv2Conv(1536, 512, heads=3)
          (3): ReLU(inplace=True)
          (4): GATv2Conv(1536, 512, heads=3)
          (5): ReLU(inplace=True)
          (6): GATv2Conv(1536, 200, heads=1)
        )
      )
      (node_model): NodeModel(
        (mlp): Sequential(
          (0): Linear(in_features=227, out_features=512, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=512, out_features=200, bias=True)
          (3): ReLU(inplace=True)
          (4): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
        )
      )
      (edge_model): EdgeModel(
        (mlp): Sequential(
          (0): Linear(in_features=606, out_features=512, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=512, out_features=200, bias=True

#### Training

In [16]:
# --- Select which of the models built above is trained in the cells below
chosen_model = xasnet_gnn
# NOTE(review): this name differs from 'xasnet_gnn_model', the name used when
# checking for a saved model in the XASNet_GNN cell — presumably weights saved
# under this name are not picked up by that check; confirm this is intended.
model_name = 'xasnet_gnn_model_test'

In [17]:
# --- Optimizer: Adam with the AMSGrad variant enabled and a small weight decay
optimizer = torch.optim.Adam(
    chosen_model.parameters(),
    lr=lr,
    betas=(0.9, 0.99),
    eps=1e-08,
    weight_decay=1e-5,
    amsgrad=True,
)

# --- Scheduler: scale the learning rate by 0.8 at each entry of `milestones`
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=milestones,
    gamma=0.8,
)
# Alternative scheduler, currently disabled:
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=100, min_lr=0.000001)

In [18]:
# --- Create trainer for the chosen model; metrics go under ./metrics/<model_name>
trainer = GNNTrainer(
    model=chosen_model,
    model_name=model_name,
    device=device,
    metric_path=f'./metrics/{model_name}',
)

In [19]:
# --- Train the ML model
# Passes the train/validation loaders, optimizer, MSE_loss, scheduler and the
# epoch count positionally; progress is written every 25 epochs (write_every).
# NOTE(review): the meaning of each positional argument is inferred from the
# variable names — confirm against the GNNTrainer.train_val signature.
# NOTE(review): train_graphnet=False goes with chosen_model = xasnet_gnn;
# presumably it must be True when xasnet_graphnet is selected — TODO confirm.
trainer.train_val(train_loader, val_loader, optimizer, MSE_loss,
                  scheduler, num_epochs, write_every=25, train_graphnet=False)

  1%|          | 3/301 [00:00<00:30,  9.64it/s]

time = 0.00 mins mins
epoch 0 | average train loss = 0.00  and average validation loss = 0.01  |learning rate = 0.00100


 13%|█▎        | 38/301 [00:00<00:04, 62.48it/s]

time = 0.01 mins mins
epoch 25 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00100


 22%|██▏       | 65/301 [00:01<00:03, 77.93it/s]

time = 0.02 mins mins
epoch 50 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00100


 31%|███       | 94/301 [00:01<00:02, 87.08it/s]

time = 0.02 mins mins
epoch 75 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00100


 37%|███▋      | 112/301 [00:01<00:02, 83.32it/s]

time = 0.03 mins mins
epoch 100 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00080


 47%|████▋     | 141/301 [00:02<00:01, 89.51it/s]

time = 0.03 mins mins
epoch 125 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00080


 53%|█████▎    | 161/301 [00:02<00:01, 90.94it/s]

time = 0.04 mins mins
epoch 150 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00080


 63%|██████▎   | 191/301 [00:02<00:01, 91.00it/s]

time = 0.04 mins mins
epoch 175 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00080


 70%|███████   | 211/301 [00:02<00:00, 90.15it/s]

time = 0.05 mins mins
epoch 200 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00080


 80%|████████  | 241/301 [00:03<00:00, 91.04it/s]

time = 0.05 mins mins
epoch 225 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00080


 86%|████████▋ | 260/301 [00:03<00:00, 86.18it/s]

time = 0.06 mins mins
epoch 250 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00080


 96%|█████████▌| 288/301 [00:03<00:00, 88.51it/s]

time = 0.06 mins mins
epoch 275 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00080


100%|██████████| 301/301 [00:03<00:00, 77.71it/s]

time = 0.06 mins mins
epoch 300 | average train loss = 0.00  and average validation loss = 0.00  |learning rate = 0.00064



