In [1]:
# Standard
import logging

# Third party
import numpy as np
import rdkit
from rdkit import RDLogger
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.loader import DataLoader

# GT-PyG
from gt_pyg.data.utils import (
    get_tensor_data, 
    get_node_dim, 
    get_edge_dim, 
    get_train_valid_test_data
)
from gt_pyg.nn.model import GraphTransformerNet

# Turn off majority of RDKit warnings
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)


# Set a random seed for a reproducibility purposes
torch.manual_seed(192837465)

# Setup the logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Log the used versions of RDkit and torch
print(f'Numpy version: {np.__version__}')
print(f'Rdkit version: {rdkit.__version__}')
print(f'Torch version: {torch.__version__}')

  from .autonotebook import tqdm as notebook_tqdm


Numpy version: 1.21.6
Rdkit version: 2022.09.5
Torch version: 1.13.1


## Get the ADME@TDC data

**Note**: To use the code below, make sure that the chosen endpoint is a regression task.

In [2]:
from tdc import utils
names = utils.retrieve_benchmark_names('ADMET_Group')
output = "\n".join([f"{index}. {name}" for index, name in enumerate(names, start=1)])
print("Available endpoints:\n\n" + output)

Available endpoints:

1. caco2_wang
2. hia_hou
3. pgp_broccatelli
4. bioavailability_ma
5. lipophilicity_astrazeneca
6. solubility_aqsoldb
7. bbb_martins
8. ppbr_az
9. vdss_lombardo
10. cyp2d6_veith
11. cyp3a4_veith
12. cyp2c9_veith
13. cyp2d6_substrate_carbonmangels
14. cyp3a4_substrate_carbonmangels
15. cyp2c9_substrate_carbonmangels
16. half_life_obach
17. clearance_microsome_az
18. clearance_hepatocyte_az
19. herg
20. ames
21. dili
22. ld50_zhu


Regression endpoints with MAE metric:
1. caco2_wang (Best: 0.285 ± 0.005)
2. lipophilicity_astrazeneca (Best: 0.535 ± 0.012)
3. solubility_aqsoldb (Best: 0.776 ± 0.008)
4. ppbr_az (Best: 9.185 ± 0.000)
5. ld50_zhu (Best: 0.588 ± 0.005)

In [3]:
PE_DIM = 6
(tr, va, te) = get_train_valid_test_data('lipophilicity_astrazeneca', min_num_atoms=0)
tr_dataset = get_tensor_data(tr.Drug.to_list(), tr.Y.to_list(), pe_dim=PE_DIM)
va_dataset = get_tensor_data(va.Drug.to_list(), va.Y.to_list(), pe_dim=PE_DIM)
te_dataset = get_tensor_data(te.Drug.to_list(), te.Y.to_list(), pe_dim=PE_DIM)
NODE_DIM = get_node_dim()
EDGE_DIM = get_edge_dim()

print(f'Number of training examples: {len(tr_dataset)}')
print(f'Number of validation examples: {len(va_dataset)}')
print(f'Number of test examples: {len(te_dataset)}')

train_loader = DataLoader(tr_dataset, batch_size=64)
val_loader = DataLoader(va_dataset, batch_size=64)
test_loader = DataLoader(te_dataset, batch_size=64)

Found local copy...
Loading...
Done!


Number of training examples: 2940
Number of validation examples: 420
Number of test examples: 840


## Train and eval the GT model

### Auxiliary functions

In [4]:
def train(epoch, loss_func):
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        
        # randomly flip sign of eigenvectors
        batch_pe = data.pe * (2 * torch.randint(low=0, high=2, size=(1, PE_DIM)).float() - 1.0)
        (out,_) = model(data.x, data.edge_index, data.edge_attr, batch_pe, data.batch, zero_var=False)
        loss = loss_func(out.squeeze(), data.y)
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(loader, loss_func):
    model.eval()

    total_error = 0
    for data in loader:
        data = data.to(device)
        # randomly flip sign of eigenvectors
        batch_pe = data.pe * (2 * torch.randint(low=0, high=2, size=(1, PE_DIM)).float() - 1.0)
        (out,_) = model(data.x, data.edge_index, data.edge_attr, batch_pe, data.batch)
        total_error += loss_func(out.squeeze(), data.y).item()
    return total_error / len(loader.dataset)

train_loss = nn.L1Loss(reduction='mean')
test_loss = nn.L1Loss(reduction='sum')

### Standard Graph Transformer setup

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphTransformerNet(node_dim_in=NODE_DIM,
                            edge_dim_in=EDGE_DIM,
                            pe_in_dim=PE_DIM,
                            num_gt_layers=4, 
                            hidden_dim=128,
                            num_heads=8,
                            norm='bn',
                            gt_aggregators=['sum'],
                            aggregators=['sum'],
                            dropout=0.1,
                            act='relu').to(device)

if int(torch.__version__.split('.')[0]) >= 2:
    model = torch_geometric.compile(model) 

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001)

print(model)
print(f"Number of params: {model.num_parameters()//1000} k")

best_epoch = 0
best_validation_loss = np.inf
test_set_mae = np.inf
for epoch in range(1, 101):
    tr_loss = train(epoch, loss_func=train_loss)
    va_loss = test(val_loader, loss_func=test_loss)
    te_loss = test(test_loader, loss_func=test_loss)
    scheduler.step(va_loss)
    print(f'Epoch: {epoch:02d}, Loss: {tr_loss:.4f}, Val: {va_loss:.4f}, '
          f'Test: {te_loss:.4f}')
    if va_loss < best_validation_loss:
        best_epoch = epoch
        best_validation_loss = va_loss
        test_set_mae = te_loss
        
print("\nModel's performance on the test set\n"
        "===================================\n"
        f'MAE={test_set_mae}\n'
        f'Epoch={best_epoch}')

GraphTransformerNet(
  (node_emb): Linear(in_features=76, out_features=128, bias=True)
  (edge_emb): Linear(in_features=10, out_features=128, bias=True)
  (pe_emb): Linear(in_features=6, out_features=128, bias=True)
  (gt_layers): ModuleList(
    (0): GTConv(128, 128, heads=8, aggrss: aggrs)
    (1): GTConv(128, 128, heads=8, aggrss: aggrs)
    (2): GTConv(128, 128, heads=8, aggrss: aggrs)
    (3): GTConv(128, 128, heads=8, aggrss: aggrs)
  )
  (global_pool): MultiAggregation([
    SumAggregation(),
  ], mode=cat)
  (mu_mlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (log_var_mlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
Number of params: 710 k
Epoch: 01, Loss: 254.1053, Val: 79.3010, Test: 78.4710
Epoch:

### Slighlty optimized Graph Transformer architecture

1. `gelu` activation is used instead of `relu`
2. Multiaggregator used for global pooling
3. Multiaggregator used for message passing

Number of params 873k instead of 701k

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphTransformerNet(node_dim_in=NODE_DIM,
                            edge_dim_in=EDGE_DIM,
                            pe_in_dim=PE_DIM,
                            num_gt_layers=4, 
                            hidden_dim=128,
                            num_heads=8,
                            norm='ln',
                            gt_aggregators=['sum', 'mean'],
                            aggregators=['sum','mean','max', 'std'],
                            dropout=0.1,
                            act='gelu').to(device)

if int(torch.__version__.split('.')[0]) >= 2:
    model = torch_geometric.compile(model) 

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.00001)

print(model)
print(f"Number of params: {model.num_parameters()//1000} k")


best_epoch = 0
best_validation_loss = np.inf
test_set_mae = np.inf
for epoch in range(1, 201):
    tr_loss = train(epoch, loss_func=train_loss)
    va_loss = test(val_loader, loss_func=test_loss)
    te_loss = test(test_loader, loss_func=test_loss)
    scheduler.step(va_loss)
    print(f'Epoch: {epoch:02d}, Loss: {tr_loss:.4f}, Val: {va_loss:.4f}, '
          f'Test: {te_loss:.4f}')
    if va_loss < best_validation_loss:
        best_epoch = epoch
        best_validation_loss = va_loss
        test_set_mae = te_loss
        
print("\nModel's performance on the test set\n"
        "===================================\n"
        f'MAE={test_set_mae}\n'
        f'Epoch={best_epoch}')

GraphTransformerNet(
  (node_emb): Linear(in_features=76, out_features=128, bias=True)
  (edge_emb): Linear(in_features=10, out_features=128, bias=True)
  (pe_emb): Linear(in_features=6, out_features=128, bias=True)
  (gt_layers): ModuleList(
    (0): GTConv(128, 128, heads=8, aggrs: sum,mean)
    (1): GTConv(128, 128, heads=8, aggrs: sum,mean)
    (2): GTConv(128, 128, heads=8, aggrs: sum,mean)
    (3): GTConv(128, 128, heads=8, aggrs: sum,mean)
  )
  (global_pool): MultiAggregation([
    SumAggregation(),
    MeanAggregation(),
    MaxAggregation(),
    StdAggregation(),
  ], mode=cat)
  (mu_mlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=512, out_features=128, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (log_var_mlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=512, out_features=128, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=128, out_features=

Epoch: 141, Loss: 0.1876, Val: 0.3820, Test: 0.4163
Epoch: 142, Loss: 0.1868, Val: 0.3839, Test: 0.4167
Epoch: 143, Loss: 0.1871, Val: 0.3891, Test: 0.4168
Epoch: 144, Loss: 0.1851, Val: 0.3823, Test: 0.4151
Epoch: 145, Loss: 0.1876, Val: 0.3826, Test: 0.4161
Epoch: 146, Loss: 0.1849, Val: 0.3819, Test: 0.4174
Epoch: 147, Loss: 0.1843, Val: 0.3819, Test: 0.4159
Epoch: 148, Loss: 0.1862, Val: 0.3813, Test: 0.4145
Epoch: 149, Loss: 0.1848, Val: 0.3790, Test: 0.4141
Epoch: 150, Loss: 0.1853, Val: 0.3831, Test: 0.4158
Epoch: 151, Loss: 0.1825, Val: 0.3810, Test: 0.4135
Epoch: 152, Loss: 0.1865, Val: 0.3792, Test: 0.4150
Epoch: 153, Loss: 0.1854, Val: 0.3816, Test: 0.4128
Epoch: 154, Loss: 0.1837, Val: 0.3799, Test: 0.4146
Epoch: 155, Loss: 0.1880, Val: 0.3829, Test: 0.4160
Epoch: 156, Loss: 0.1846, Val: 0.3826, Test: 0.4147
Epoch: 157, Loss: 0.1865, Val: 0.3824, Test: 0.4152
Epoch: 158, Loss: 0.1852, Val: 0.3785, Test: 0.4149
Epoch: 159, Loss: 0.1850, Val: 0.3820, Test: 0.4152
Epoch: 160, 

In [6]:
import weightwatcher as ww

watcher = ww.WeightWatcher(model=model)
details = watcher.analyze(plot=False)
details

PyTorch is available but CUDA is not. Defaulting to SciPy for SVD


Unnamed: 0,layer_id,name,D,M,N,Q,alpha,alpha_weighted,entropy,has_esd,...,rf,sigma,spectral_norm,stable_rank,status,sv_max,warning,weak_rank_loss,xmax,xmin
0,1,Linear,0.067975,76,128,1.684211,3.277394,4.026596,0.880096,True,...,1,0.569348,16.927663,7.08593,success,4.114324,,0,16.927663,2.309725
1,2,Linear,0.169242,10,128,12.8,2.642475,1.919284,0.945221,True,...,1,0.519396,5.325014,4.77769,success,2.307599,,0,5.325014,1.223018
2,11,Linear,0.058695,128,128,1.0,2.613188,3.193995,0.833103,True,...,1,0.28082,16.682456,10.798983,success,4.084416,,0,16.682456,1.720429
3,12,Linear,0.051392,128,128,1.0,2.685112,3.097238,0.845849,True,...,1,0.312917,14.239204,12.49973,success,3.773487,,1,14.239204,1.942548
4,13,Linear,0.089777,128,128,1.0,5.363919,3.915031,0.891117,True,...,1,1.058406,5.368865,23.248182,success,2.317081,,0,5.368865,2.346537
5,14,Linear,0.092734,128,256,2.0,3.811916,3.019194,0.933881,True,...,1,0.551462,6.194998,29.722648,success,2.488975,,0,6.194998,2.269211
6,15,Linear,0.094187,128,128,1.0,2.385702,3.254417,0.802269,True,...,1,0.202125,23.12779,7.721125,success,4.809136,,0,23.12779,1.099728
7,16,Linear,0.089884,128,128,1.0,3.55514,2.975651,0.87851,True,...,1,0.571347,6.870682,22.390432,success,2.621199,,0,6.870682,2.481076
8,19,Linear,0.055797,128,128,1.0,2.512242,1.845517,0.835333,True,...,1,0.259347,5.42762,11.910944,success,2.329725,,0,5.42762,0.591549
9,22,Linear,0.077458,128,128,1.0,2.698165,2.083067,0.833296,True,...,1,0.36205,5.916039,11.57088,success,2.432291,,0,5.916039,0.902335


In [7]:
watcher.get_summary(details)

{'log_norm': 2.012117284692378,
 'alpha': 3.215042345276656,
 'alpha_weighted': 2.504157234497073,
 'log_alpha_norm': 3.0239557866120674,
 'log_spectral_norm': 0.8036157617573815,
 'stable_rank': 18.009767940281925}