In [1]:
# Standard
import logging

# Third party
import numpy as np
import rdkit
from rdkit import RDLogger
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch_geometric
from torch_geometric.loader import DataLoader
import torchmetrics
from torchmetrics import MeanSquaredError
from pytorch_lamb import Lamb

# GT-PyG
from gt_pyg.data.utils import (
    get_tensor_data, 
    get_node_dim,
    get_edge_dim,
    get_molecule_ace_datasets
)
from gt_pyg.nn.model import GraphTransformerNet

# Silence RDKit: only CRITICAL-level messages are let through.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)


# Fix the torch RNG seed for reproducibility.
torch.manual_seed(192837465)

# Root logger, emitting INFO and above.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Record the library versions this notebook was executed with.
for _lib_name, _lib_version in (
    ('Numpy', np.__version__),
    ('Rdkit', rdkit.__version__),
    ('Torch', torch.__version__),
    ('TorchMetrics', torchmetrics.__version__),
):
    print(f'{_lib_name} version: {_lib_version}')

  from .autonotebook import tqdm as notebook_tqdm


Numpy version: 1.21.6
Rdkit version: 2022.09.5
Torch version: 1.13.1
TorchMetrics version: 0.11.4


## Get the MoleculeACE data

**Note**: To use the code below, make sure that the chosen endpoint is a regression task.

In [16]:
# Number of positional-encoding dimensions computed per node.
PE_DIM = 6
# MoleculeACE endpoint; per the note above, it must be a regression task.
# NOTE(review): training/valid fractions presumably split the MoleculeACE
# training portion 80/20 into train/valid (478/120 in the output) — confirm in gt_pyg.
(tr, va, te) = get_molecule_ace_datasets('CHEMBL2034_Ki', min_num_atoms=0,
                                         training_fraction=0.8, valid_fraction=0.2)
tr_dataset = get_tensor_data(tr.SMILES.to_list(), tr.Y.to_list(), pe_dim=PE_DIM)
va_dataset = get_tensor_data(va.SMILES.to_list(), va.Y.to_list(), pe_dim=PE_DIM)
te_dataset = get_tensor_data(te.SMILES.to_list(), te.Y.to_list(), pe_dim=PE_DIM)
NODE_DIM = get_node_dim()
EDGE_DIM = get_edge_dim()

print(f'Number of training examples: {len(tr_dataset)}')
print(f'Number of validation examples: {len(va_dataset)}')
print(f'Number of test examples: {len(te_dataset)}')

# Shuffle the training set so every epoch sees differently composed
# mini-batches (the order is reproducible through the torch seed set above).
# Evaluation loaders keep a fixed order.
train_loader = DataLoader(tr_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(va_dataset, batch_size=512)
test_loader = DataLoader(te_dataset, batch_size=512)

Number of training examples: 478
Number of validation examples: 120
Number of test examples: 152


### Auxiliary functions

In [11]:
def train(epoch, loss_func):
    """Run one training epoch over ``train_loader`` and return the training RMSE.

    Relies on the notebook-level ``model``, ``optimizer``, ``train_loader``
    and ``device``.

    Args:
        epoch: Current epoch number (unused; kept for call-site compatibility).
        loss_func: Loss used for back-propagation, called as
            ``loss_func(predictions, targets)``.

    Returns:
        RMSE (= MSE ** 0.5) accumulated over all training batches of the
        epoch, as a 0-dim tensor.
    """
    model.train()
    # Metric state must live on the same device as the tensors it is fed.
    train_mse = MeanSquaredError().to(device)

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        # Randomly flip the sign of each eigenvector (PE column) as data
        # augmentation. The mask is created on the PE's device to avoid a
        # CPU/GPU mismatch when training on CUDA.
        signs = 2 * torch.randint(low=0, high=2, size=(1, data.pe.size(-1)),
                                  device=data.pe.device).float() - 1.0
        batch_pe = data.pe * signs
        (out, _) = model(data.x, data.edge_index, data.edge_attr, batch_pe,
                         data.batch, zero_var=False)
        # Flatten rather than squeeze(): a batch holding a single graph would
        # otherwise collapse to a 0-dim tensor. Assumes data.y is 1-D
        # (one target per graph) — TODO confirm against get_tensor_data.
        preds = out.reshape(-1)
        loss = loss_func(preds, data.y)
        loss.backward()
        optimizer.step()

        # The metric only needs values, not the autograd graph.
        train_mse.update(preds.detach(), data.y)

    return train_mse.compute() ** 0.5


@torch.no_grad()
def test(loader):
    """Evaluate the notebook-level ``model`` on ``loader`` and return the RMSE.

    Positional encodings are used exactly as stored. The random sign flip is
    a training-time augmentation only, so evaluation is not stochastic and
    does not consume RNG state that the training loop relies on.

    Args:
        loader: PyG ``DataLoader`` yielding batches with ``x``, ``edge_index``,
            ``edge_attr``, ``pe``, ``batch`` and ``y``.

    Returns:
        RMSE (= MSE ** 0.5) over all batches in ``loader``, as a 0-dim tensor.
    """
    model.eval()
    # Metric state must live on the same device as the tensors it is fed.
    test_mse = MeanSquaredError().to(device)

    for data in loader:
        data = data.to(device)
        (out, _) = model(data.x, data.edge_index, data.edge_attr, data.pe, data.batch)
        # Flatten rather than squeeze() so a single-graph batch keeps shape (1,).
        test_mse.update(out.reshape(-1), data.y)

    return test_mse.compute() ** 0.5

# Training objective: squared error averaged over the batch (default reduction is 'mean').
train_loss = nn.MSELoss()

### Slightly optimized Graph Transformer architecture

1. `gelu` activation is used instead of `relu`
2. Multi-aggregation (sum, mean, max, std) is used for global pooling
3. Multi-aggregation (sum, mean) is used for message passing
4. The LAMB optimizer from this [paper](https://arxiv.org/abs/1904.00962) is used, via the implementation in this [repo](https://github.com/cybertronai/pytorch-lamb)

Number of parameters: 873k (instead of 709k).

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphTransformerNet(node_dim_in=NODE_DIM,
                            edge_dim_in=EDGE_DIM,
                            pe_in_dim=PE_DIM,
                            num_gt_layers=4,
                            hidden_dim=128,
                            num_heads=8,
                            norm='bn',
                            gt_aggregators=['sum', 'mean'],
                            aggregators=['sum', 'mean', 'max', 'std'],
                            dropout=0.1,
                            act='gelu').to(device)

# Compile the model on torch >= 2 (needs the top-level `import torch_geometric`).
if int(torch.__version__.split('.')[0]) >= 2:
    model = torch_geometric.compile(model)

optimizer = Lamb(model.parameters(),
                 lr=0.005,
                 weight_decay=0.05,
                 betas=(.9, .999),
                 adam=False)

# Halve the learning rate when the validation RMSE stops improving.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001)

print(model)
print(f"Number of params: {model.num_parameters()//1000} k")


NUM_EPOCHS = 50

# Model selection uses the validation RMSE only; the test RMSE reported at
# the end is the one observed at the best-validation epoch.
best_epoch = 0
best_validation_loss = np.inf
test_set_rmse = np.inf
for epoch in range(1, NUM_EPOCHS + 1):
    tr_loss = train(epoch, loss_func=train_loss)
    va_loss = test(val_loader)
    te_loss = test(test_loader)
    scheduler.step(va_loss)
    print(f'Epoch: {epoch:02d}, Loss: {tr_loss:.4f}, Val: {va_loss:.4f}, '
          f'Test: {te_loss:.4f}')
    if va_loss < best_validation_loss:
        best_epoch = epoch
        best_validation_loss = va_loss
        test_set_rmse = te_loss

# Format as a number (not a tensor repr) to match the per-epoch log.
print("\nModel's performance on the test set\n"
      "===================================\n"
      f'RMSE={test_set_rmse:.4f}\n'
      f'Epoch={best_epoch}')

GraphTransformerNet(
  (node_emb): Linear(in_features=76, out_features=128, bias=False)
  (edge_emb): Linear(in_features=10, out_features=128, bias=False)
  (pe_emb): Linear(in_features=6, out_features=128, bias=False)
  (gt_layers): ModuleList(
    (0): GTConv(128, 128, heads=8, aggrs: sum,mean)
    (1): GTConv(128, 128, heads=8, aggrs: sum,mean)
    (2): GTConv(128, 128, heads=8, aggrs: sum,mean)
    (3): GTConv(128, 128, heads=8, aggrs: sum,mean)
  )
  (global_pool): MultiAggregation([
    SumAggregation(),
    MeanAggregation(),
    MaxAggregation(),
    StdAggregation(),
  ], mode=cat)
  (mu_mlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=512, out_features=128, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (log_var_mlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=512, out_features=128, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=128, out_featur