In [1]:
import os.path as osp
from math import sqrt

import torch
import torch.nn.functional as F
from rdkit import Chem

from torch_geometric.loader import DataLoader
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn.models import AttentiveFP


class GenAttentiveFeatures(object):
    '''
    AttentiveFP 39 node features generation

    Callable transform: parses ``data.smiles`` with RDKit and overwrites
    ``data.x`` (one 39-dim row per atom), ``data.edge_index`` (2 x E, both
    directions of every bond) and ``data.edge_attr`` (E x 10).

    Atom feature layout (column ranges):
        0-15  element symbol one-hot (last slot = 'other')
        16-21 degree one-hot (0..5)
        22    formal charge
        23    number of radical electrons
        24-29 hybridization one-hot (last slot = 'other')
        30    aromaticity flag
        31-35 total hydrogens one-hot (0..4)
        36    '_ChiralityPossible' property present
        37-38 CIP code one-hot (R, S)
    '''
    # Widths of the fixed-size one-hot groups for degree and hydrogen count.
    _DEGREE_SLOTS = 6
    _HYDROGEN_SLOTS = 5

    def __init__(self):
        # The trailing 'other' entry is the catch-all bucket for elements
        # that are not listed explicitly.
        self.symbols = [
            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br',
            'Te', 'I', 'At', 'other'
        ]

        # Same convention: the last entry collects every other hybridization.
        self.hybridizations = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            'other',
        ]

        self.stereos = [
            Chem.rdchem.BondStereo.STEREONONE,
            Chem.rdchem.BondStereo.STEREOANY,
            Chem.rdchem.BondStereo.STEREOZ,
            Chem.rdchem.BondStereo.STEREOE,
        ]

    @staticmethod
    def _one_hot_or_other(value, choices):
        '''
        One-hot encode ``value`` over ``choices``; a value that is not listed
        is mapped to the last slot (the 'other' bucket) instead of raising.
        '''
        encoding = [0.] * len(choices)
        slot = choices.index(value) if value in choices else len(choices) - 1
        encoding[slot] = 1.
        return encoding

    def _num_atom_features(self):
        '''Width of one atom feature row (39 with the default vocabularies).'''
        return (len(self.symbols) + self._DEGREE_SLOTS + 2 +
                len(self.hybridizations) + 1 + self._HYDROGEN_SLOTS + 1 + 2)

    def _atom_features(self, atom):
        '''Return the feature row (a flat list of numbers) for one RDKit atom.'''
        symbol = self._one_hot_or_other(atom.GetSymbol(), self.symbols)

        # No overflow bucket here: a degree above 5 raises IndexError.
        degree = [0.] * self._DEGREE_SLOTS
        degree[atom.GetDegree()] = 1.

        formal_charge = atom.GetFormalCharge()
        radical_electrons = atom.GetNumRadicalElectrons()
        hybridization = self._one_hot_or_other(atom.GetHybridization(),
                                               self.hybridizations)
        aromaticity = 1. if atom.GetIsAromatic() else 0.

        # No overflow bucket here either: more than 4 hydrogens raises
        # IndexError.
        hydrogens = [0.] * self._HYDROGEN_SLOTS
        hydrogens[atom.GetTotalNumHs()] = 1.

        chirality = 1. if atom.HasProp('_ChiralityPossible') else 0.
        chirality_type = [0.] * 2
        if atom.HasProp('_CIPCode'):
            chirality_type[['R', 'S'].index(atom.GetProp('_CIPCode'))] = 1.

        return (symbol + degree + [formal_charge] + [radical_electrons] +
                hybridization + [aromaticity] + hydrogens + [chirality] +
                chirality_type)

    def _bond_features(self, bond):
        '''Return the 10-entry feature row for one RDKit bond.'''
        bond_type = bond.GetBondType()
        single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
        double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
        triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
        aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
        conjugation = 1. if bond.GetIsConjugated() else 0.
        ring = 1. if bond.IsInRing() else 0.
        stereo = [0.] * len(self.stereos)
        stereo[self.stereos.index(bond.GetStereo())] = 1.

        return [single, double, triple, aromatic, conjugation, ring] + stereo

    def __call__(self, data):
        '''
        Generate AttentiveFP features according to Table 1.

        Args:
            data: object with a ``smiles`` attribute; ``x``, ``edge_index``
                and ``edge_attr`` are (over)written on it.

        Returns:
            The same ``data`` object.

        Raises:
            ValueError: if RDKit cannot parse ``data.smiles``.
        '''
        mol = Chem.MolFromSmiles(data.smiles)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None.
            raise ValueError(
                f'RDKit could not parse SMILES: {data.smiles!r}')

        atom_rows = [self._atom_features(atom) for atom in mol.GetAtoms()]
        if atom_rows:
            data.x = torch.tensor(atom_rows, dtype=torch.float)
        else:
            # Molecule without atoms: keep the feature width consistent.
            data.x = torch.zeros((0, self._num_atom_features()),
                                 dtype=torch.float)

        edge_indices = []
        edge_rows = []
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # Each bond becomes two directed edges that share one feature row.
            edge_indices += [[begin, end], [end, begin]]
            row = self._bond_features(bond)
            edge_rows += [row, row]

        if edge_rows:
            data.edge_index = torch.tensor(
                edge_indices, dtype=torch.long).t().contiguous()
            data.edge_attr = torch.tensor(edge_rows, dtype=torch.float)
        else:
            data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            data.edge_attr = torch.zeros((0, 10), dtype=torch.float)

        return data


class GenAtomFeatures:
    '''
    Placeholder for our own atom featurization (TODO: not implemented yet).

    The class currently defines no attributes or methods.
    '''

In [2]:
from low_inhibitors import LSSInhibitor

def train():
    '''
    Run one optimisation epoch over the current ``train_loader``.

    Reads the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device`` that the dataset loop below rebinds for every dataset.

    Returns:
        Training RMSE: square root of the per-graph weighted mean of the
        batch MSE losses.
    '''
    # Enable dropout for training (it is switched off in ``test``).
    model.train()
    total_loss = total_examples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # NOTE(review): assumes data.y has the same shape as `out` - confirm
        # against the dataset, otherwise mse_loss broadcasts.
        loss = F.mse_loss(out, data.y)
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its size so a smaller last batch does
        # not skew the epoch average.
        total_loss += float(loss) * data.num_graphs
        total_examples += data.num_graphs
    return sqrt(total_loss / total_examples)


@torch.no_grad()
def test(loader):
    '''
    Compute the RMSE of the current module-level ``model`` over ``loader``.

    Args:
        loader: data loader yielding batches to evaluate.

    Returns:
        RMSE over all elements yielded by ``loader``.
    '''
    # Disable dropout so the evaluation is deterministic.
    model.eval()
    mse = []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        mse.append(F.mse_loss(out, data.y, reduction='none').cpu())
    return float(torch.cat(mse, dim=0).mean().sqrt())


# Settings shared by every dataset.
path = '../../tmp/data'
batch_size = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for dataset_name in LSSInhibitor.names.keys():

    print(dataset_name)

    # Use the AttentiveFP node and edge features during the mol-to-graph
    # transformation. The dataset is selected by the loop variable so that
    # every entry of LSSInhibitor.names is actually trained on.
    dataset = LSSInhibitor(path, name=dataset_name,
                           pre_transform=GenAttentiveFeatures()).shuffle()

    # train, valid, test splitting: 20% / 20% / remaining ~60%
    N = len(dataset) // 5
    val_dataset = dataset[:N]
    test_dataset = dataset[N:2 * N]
    train_dataset = dataset[2 * N:]

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Fresh model and optimizer per dataset; in_channels / edge_dim match
    # the 39 node and 10 edge features produced by GenAttentiveFeatures.
    model = AttentiveFP(in_channels=39, hidden_channels=200, out_channels=1,
                        edge_dim=10, num_layers=2, num_timesteps=2,
                        dropout=0.2).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=10**-2.5,
                                 weight_decay=10**-5)

    # NOTE(review): range(1, 50) runs 49 epochs - confirm whether 50 was
    # intended.
    for epoch in range(1, 50):
        train_rmse = train()
        val_rmse = test(val_loader)
        test_rmse = test(test_loader)
        print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '
              f'Test: {test_rmse:.4f}')

mglur2


Downloading https://raw.githubusercontent.com/bidd-group/LSSinhibitors/main/data/mGluR2.csv
Processing...
Done!


Epoch: 001, Loss: 5.2638 Val: 2.7361 Test: 3.1182
Epoch: 002, Loss: 2.8597 Val: 1.0392 Test: 1.1940
Epoch: 003, Loss: 1.1562 Val: 1.0386 Test: 1.2838
Epoch: 004, Loss: 1.0233 Val: 0.9924 Test: 1.0265
Epoch: 005, Loss: 0.9474 Val: 0.6431 Test: 1.2452
Epoch: 006, Loss: 1.1027 Val: 0.3823 Test: 0.8079
Epoch: 007, Loss: 1.0575 Val: 0.5822 Test: 0.9224
Epoch: 008, Loss: 0.9790 Val: 0.7652 Test: 0.5707
Epoch: 009, Loss: 1.2451 Val: 1.3015 Test: 0.9719
Epoch: 010, Loss: 1.1985 Val: 0.6167 Test: 0.6222
Epoch: 011, Loss: 1.1093 Val: 1.1619 Test: 0.9489
Epoch: 012, Loss: 1.0772 Val: 0.5794 Test: 0.6898
Epoch: 013, Loss: 1.0647 Val: 0.8794 Test: 0.8425
Epoch: 014, Loss: 1.0211 Val: 0.9489 Test: 0.3953
Epoch: 015, Loss: 0.8168 Val: 0.6625 Test: 0.7807
Epoch: 016, Loss: 0.8211 Val: 0.4875 Test: 0.8978
Epoch: 017, Loss: 0.9327 Val: 0.5638 Test: 0.5655
Epoch: 018, Loss: 1.0621 Val: 0.6356 Test: 0.6743
Epoch: 019, Loss: 0.6539 Val: 0.4807 Test: 0.7503
Epoch: 020, Loss: 1.1352 Val: 0.9272 Test: 0.5853
