References:

1. https://github.com/xbresson/CS6208_2023/blob/main/codes/labs_lecture07/01_vanilla_graph_transformers.ipynb
2. https://github.com/xbresson/CS6208_2023/blob/main/codes/labs_lecture07/03_graph_transformers_regression_exercise.ipynb
3. https://github.com/pgniewko/pytorch_geometric/blob/master/torch_geometric/nn/conv/transformer_conv.py
4. https://arxiv.org/abs/2012.09699
5. https://arxiv.org/abs/1703.04977

TDC:
1. Test the model on this new set: https://practicalcheminformatics.blogspot.com/2023/06/getting-real-with-molecular-property.html

In [1]:
import pandas as pd

import rdkit
import torch
import math
import torch.nn.functional as F
import numpy as np
from torch import nn
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_scatter.composite import scatter_softmax
from torch_scatter.scatter import scatter_add
from torch_geometric.utils import softmax
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.nn.aggr import MultiAggregation
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
import os.path as osp
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.utils import degree
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Turn off majority of RDKit warnings
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

print(f'Rdkit version: {rdkit.__version__}')
print(f'Torch version: {torch.__version__}')

torch.manual_seed(192837465)

#
# Generate example data
#x = torch.randn(6, 3)  # Node features (6 nodes, 3-dimensional features)
#edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5], [1, 0, 2, 1, 3, 2, 5, 4]], dtype=torch.long)  # Edge indices
#edge_attr = torch.randn(8, 2)  # Edge attributes (8 edges, 2-dimensional attributes)
#
#data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

  from .autonotebook import tqdm as notebook_tqdm


Rdkit version: 2023.03.1
Torch version: 1.13.1


<torch._C.Generator at 0x7fb231a374f0>

In [18]:
class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dims, 
                 num_hidden_layers=1,
                 dropout=0.0, act='relu', 
                 act_kwargs=None):
        super(MLP, self).__init__()
    
        if isinstance(hidden_dims, int):
            hidden_dims = [hidden_dims] * num_hidden_layers
        
        hidden_dims = [input_dim] + hidden_dims
        assert len(hidden_dims) - 1 == num_hidden_layers
        
        layers = []
        
        for (i_dim, o_dim) in zip(hidden_dims[:-1], hidden_dims[1:]):
            layers.append(nn.Linear(i_dim, o_dim, bias=True))
            layers.append(activation_resolver(act, **(act_kwargs or {})))
            if dropout > 0:
                layers.append(nn.Dropout(p=dropout))
                
        layers.append(nn.Linear(hidden_dims[-1], output_dim, bias=True))
        
        self.mlp = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.mlp(x)

    
class GTConv(MessagePassing):
    def __init__(self, node_in_dim, hidden_dim, edge_in_dim=None, num_heads=1, dropout=0.0, norm='bn', act='relu'):
        super(GTConv, self).__init__(node_dim=0, aggr='add')
        
        assert hidden_dim % num_heads == 0
        assert (edge_in_dim is None) or (edge_in_dim > 0)
        
        self.WQ = nn.Linear(node_in_dim, hidden_dim, bias=True)
        self.WK = nn.Linear(node_in_dim, hidden_dim, bias=True)
        self.WV = nn.Linear(node_in_dim, hidden_dim, bias=True)
        self.WO = nn.Linear(hidden_dim, node_in_dim, bias=True)
        
        if edge_in_dim is not None:
            assert node_in_dim == edge_in_dim
            self.WE = nn.Linear(edge_in_dim, hidden_dim, bias=True)
            self.WOe = nn.Linear(hidden_dim, edge_in_dim, bias=True)
            self.ffn_e = MLP(input_dim=edge_in_dim,
                             output_dim=edge_in_dim,
                             hidden_dims=hidden_dim,
                             num_hidden_layers=1,
                             dropout=dropout, act=act)
            if norm.lower() in ['bn', 'batchnorm', 'batch_norm']:
                self.norm1e = nn.BatchNorm1d(edge_in_dim)
                self.norm2e = nn.BatchNorm1d(edge_in_dim)
            elif norm.lower() in ['ln', 'layernorm', 'layer_norm']:
                self.norm1e = nn.LayerNorm(edge_in_dim)
                self.norm2e = nn.LayerNorm(edge_in_dim)
        else:
            self.WE = self.register_parameter('WE', None)
            self.WOe = self.register_parameter('WOe', None)
            self.ffn_e = self.register_parameter('ffn_e', None)
            self.norm1e = self.register_parameter('norm1e', None)
            self.norm2e = self.register_parameter('norm2e', None)
        
        if norm.lower() in ['bn', 'batchnorm', 'batch_norm']:
            self.norm1 = nn.BatchNorm1d(node_in_dim)
            self.norm2 = nn.BatchNorm1d(node_in_dim)
        elif norm.lower() in ['ln', 'layernorm', 'layer_norm']:
            self.norm1 = nn.LayerNorm(node_in_dim)
            self.norm2 = nn.LayerNorm(node_in_dim)
            
        self.dropout_layer = nn.Dropout(p=dropout)
            
        self.ffn = MLP(input_dim=node_in_dim,
                       output_dim=node_in_dim,
                       hidden_dims=hidden_dim,
                       num_hidden_layers=1,
                       dropout=dropout, act=act)
        
        self.num_heads = num_heads
        self.node_in_dim = node_in_dim
        self.edge_in_dim = edge_in_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.norm = norm.lower()
        
        self.reset_parameters()
           
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.WQ.weight)
        nn.init.xavier_uniform_(self.WK.weight)
        nn.init.xavier_uniform_(self.WV.weight)
        nn.init.xavier_uniform_(self.WO.weight)
        if self.edge_in_dim is not None:
            nn.init.xavier_uniform_(self.WE.weight)
            nn.init.xavier_uniform_(self.WOe.weight)
        
    
    def forward(self, x, edge_index, edge_attr=None):
        x_ = x
        Q = self.WQ(x).view(-1, self.num_heads, self.hidden_dim // self.num_heads)
        K = self.WK(x).view(-1, self.num_heads, self.hidden_dim // self.num_heads)
        V = self.WV(x).view(-1, self.num_heads, self.hidden_dim // self.num_heads)
        
        out = self.propagate(edge_index, Q=Q, K=K, V=V,
                             edge_attr=edge_attr, size=None)
        out = out.view(-1, self.hidden_dim)
        
        ## NODES
        out = self.dropout_layer(out)
        out = self.WO(out) + x_ # Residual connection
        out = self.norm1(out)
        # FFN-NODES
        ffn_in = out
        out = self.ffn(out)
        out = self.norm2(ffn_in + out)
        
        if self.edge_in_dim is None:
            out_eij = None
        else:
            out_eij = self._eij
            self._eij = None
            out_eij = out_eij.view(-1, self.hidden_dim)

            ## EDGES
            out_eij_ = out_eij
            out_eij = self.dropout_layer(out_eij)
            out_eij = self.WOe(out_eij) + out_eij_ # Residual connection
            out_eij = self.norm1e(out_eij)
            # FFN-EDGES
            ffn_eij_in = out_eij
            out_eij = self.ffn_e(out_eij)
            out_eij = self.norm2e(ffn_eij_in + out_eij)

        return (out, out_eij)
        
        
    def message(self, Q_i, K_j, V_j, index, edge_attr=None):
        if self.WE is not None:
            assert edge_attr is not None
            E = self.WE(edge_attr).view(-1, self.num_heads, self.hidden_dim // self.num_heads)
            K_j = E * K_j
        
        d_k = Q_i.size(-1)
        qijk = (Q_i * K_j).sum(dim=-1) / math.sqrt(d_k)
        self._eij = (Q_i * K_j) / math.sqrt(d_k)
        alpha = softmax(qijk, index) # Log-Sum-Exp trick used. No need for clipping (-5,5)
        
        return alpha.view(-1, self.num_heads, 1) * V_j
    
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.node_in_dim}, '
                f'{self.hidden_dim}, heads={self.num_heads})')
    
    
class GraphTransformerNet(nn.Module):
    def __init__(self, node_dim_in, edge_dim_in=None,
                 hidden_dim=128, norm='bn',
                 num_gt_layers=4, num_heads=8,
                 aggregators=['sum'],
                 act='relu', dropout=0.0):
        super(GraphTransformerNet, self).__init__()
        
        self.node_emb = nn.Linear(node_dim_in, hidden_dim)
        if edge_dim_in:
            self.edge_emb = nn.Linear(edge_dim_in, hidden_dim)
        else:
            self.edge_emb = self.register_parameter('edge_emb', None)
        
        self.gt_layers = nn.ModuleList()
        for _ in range(num_gt_layers):
            self.gt_layers.append(GTConv(node_in_dim=hidden_dim,
                                         hidden_dim=hidden_dim, 
                                         edge_in_dim=hidden_dim,
                                         num_heads=num_heads,
                                         act=act,
                                         dropout=dropout,
                                         norm='bn'))
        
        self.global_pool = MultiAggregation(aggregators, mode='cat')
        
        num_aggrs = len(aggregators)
        self.mu_mlp = MLP(input_dim=num_aggrs * hidden_dim, output_dim=1,
                          hidden_dims=hidden_dim,
                          num_hidden_layers=1, dropout=0.0, act=act)
        self.std_mlp = MLP(input_dim=num_aggrs * hidden_dim, output_dim=1,
                           hidden_dims=hidden_dim,
                           num_hidden_layers=1, dropout=0.0, act=act)
        
        
        self.reset_parameters()
        
        
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.node_emb.weight)
        nn.init.xavier_uniform_(self.edge_emb.weight)
            
            
    def forward(self, x, edge_index, edge_attr, batch, return_std=False):
        x = self.node_emb(x.squeeze())
        edge_attr = self.edge_emb(edge_attr)

        for gt_layer in self.gt_layers:
            (x, edge_attr) = gt_layer(x, edge_index, edge_attr=edge_attr)

        x = self.global_pool(x, batch)
        mu = self.mu_mlp(x)
        log_var = self.std_mlp(x)
        std = torch.exp(0.5 * log_var)
        
        if self.training:
            eps = torch.randn_like(std)
            return mu + std * eps, std
        else:
            return mu, std
        
    def num_parameters(self):   
        trainable_params = filter(lambda p: p.requires_grad, self.parameters())
        count = sum([p.numel() for p in trainable_params])
        return count

In [3]:
#gt = GTConv(node_in_dim=3, hidden_dim=6, num_heads=2, edge_in_dim=2)
#gt

In [4]:
#gt(x, edge_index, edge_attr=edge_attr)

In [5]:
#gt_net = GraphTransformerNet(node_dim_in=3, edge_dim_in=2)

In [6]:
#loader = DataLoader([data], batch_size=128, shuffle=False)
#for data in loader:
#    x = data.x
#    edge_index = data.edge_index
#    edge_attr = data.edge_attr
#    pred= gt_net(x, edge_index, edge_attr, data.batch)
#    print(pred)

## GET ZINC DATA

In [2]:
#path = osp.join('.', 'data', 'ZINC')
#train_dataset = ZINC(path, subset=True, split='train')
#val_dataset = ZINC(path, subset=True, split='val')
#test_dataset = ZINC(path, subset=True, split='test')

#train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
#val_loader = DataLoader(val_dataset, batch_size=128)
#test_loader = DataLoader(test_dataset, batch_size=128)

## Compute the maximum in-degree in the training data.
#max_degree = -1
#for data in train_dataset:
#    d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
#    max_degree = max(max_degree, int(d.max()))
#
## Compute the in-degree histogram tensor
#deg = torch.zeros(max_degree + 1, dtype=torch.long)
#for data in train_dataset:
#    d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
#    deg += torch.bincount(d, minlength=deg.numel())

## Solubility data

In [9]:
from gt_pyg.data.utils import get_tensor_data, get_node_dim, get_edge_dim
biogen_data_file = './data/BioGen/biogen_solubility.csv'
df = pd.read_csv(biogen_data_file)
dataset = get_tensor_data(df.SMILES.to_list(), df.logS.to_list())
NODE_DIM = get_node_dim()
EDGE_DIM = get_edge_dim()

train_loader = DataLoader(dataset, batch_size=128)
val_loader = DataLoader(dataset, batch_size=128)
test_loader = DataLoader(dataset, batch_size=128)

## TRAIN AND EVAL

In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphTransformerNet(node_dim_in=NODE_DIM, edge_dim_in=EDGE_DIM, num_gt_layers=4, hidden_dim=128, 
                           dropout=0.1).to(device)

if int(torch.__version__.split('.')[0]) >= 2:
    model = torch_geometric.compile(model) 

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
                              min_lr=0.00001)


print(model)
print(f"Number of params: {model.num_parameters()//1000}k")
def train(epoch):
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        (out,_) = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = (out.squeeze() - data.y).abs().mean()
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(loader):
    model.eval()

    total_error = 0
    for data in loader:
        data = data.to(device)
        (out,_) = model(data.x, data.edge_index, data.edge_attr, data.batch)
        total_error += (out.squeeze() - data.y).abs().sum().item()
    return total_error / len(loader.dataset)


for epoch in range(1, 11):
    loss = train(epoch)
    val_mae = test(val_loader)
    test_mae = test(test_loader)
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')

GraphTransformerNet(
  (node_emb): Linear(in_features=79, out_features=128, bias=True)
  (edge_emb): Linear(in_features=10, out_features=128, bias=True)
  (gt_layers): ModuleList(
    (0): GTConv(128, 128, heads=8)
    (1): GTConv(128, 128, heads=8)
    (2): GTConv(128, 128, heads=8)
    (3): GTConv(128, 128, heads=8)
  )
  (global_pool): MultiAggregation([
    SumAggregation(),
  ], mode=cat)
  (mu_mlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (std_mlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
Number of params: 709k
Epoch: 01, Loss: 1.5729, Val: 1.4888, Test: 1.4888


KeyboardInterrupt: 