In [1]:
import pandas as pd
import numpy as np
import torch
import torch_geometric
from torch import nn, optim
from torch.nn import functional as Fy
from torch.utils import data
import math
from sklearn.metrics import roc_auc_score
import sys
from os.path import exists
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
from torch_geometric.utils import from_smiles
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import AttentiveFP

In [89]:
# this is the torch geometric from_smiles function
# you can edit the rdkit features included 
# default features are 9 - i will try to augment this

from typing import Any, Dict, List

import torch

import torch_geometric

x_map: Dict[str, List[Any]] = {
    'atomic_num': list(range(0, 119)),
    'chirality': [
        'CHI_UNSPECIFIED', 'CHI_TETRAHEDRAL_CW', 'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER', 'CHI_TETRAHEDRAL', 'CHI_ALLENE', 'CHI_SQUAREPLANAR',
        'CHI_TRIGONALBIPYRAMIDAL', 'CHI_OCTAHEDRAL',
    ],
    'degree': list(range(0, 11)),
    'formal_charge': list(range(-5, 7)),
    'num_hs': list(range(0, 9)),
    'num_radical_electrons': list(range(0, 5)),
    'hybridization': [
        'UNSPECIFIED', 'S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'OTHER',
    ],
    'is_aromatic': [False, True],
    'is_in_ring': [False, True],
    'valence': list(range(0, 8)),
    'num_rings': list(range(0, 10)),
    'implicit_valence': list(range(0, 7)),
    'explicit_valence': list(range(0, 8)),
    'mass': list(range(0, 250)),  # Discretized atomic mass
    'isotope': list(range(0, 10)),
    'ring_sizes': list(range(0, 10)),  # Smallest ring size
}

e_map: Dict[str, List[Any]] = {
    'bond_type': [
        'UNSPECIFIED',
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'QUADRUPLE',
        'QUINTUPLE',
        'HEXTUPLE',
        'ONEANDAHALF',
        'TWOANDAHALF',
        'THREEANDAHALF',
        'FOURANDAHALF',
        'FIVEANDAHALF',
        'AROMATIC',
        'IONIC',
        'HYDROGEN',
        'THREECENTER',
        'DATIVEONE',
        'DATIVE',
        'DATIVEL',
        'DATIVER',
        'OTHER',
        'ZERO',
    ],
    'stereo': [
        'STEREONONE',
        'STEREOANY',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
    ],
    'is_conjugated': [False, True],
}


def from_rdmol(mol: Any, use_3d:bool=False) -> 'torch_geometric.data.Data':
    r"""Converts a :class:`rdkit.Chem.Mol` instance to a
    :class:`torch_geometric.data.Data` instance.

    Args:
        mol (rdkit.Chem.Mol): The :class:`rdkit` molecule.
    """
    from rdkit import Chem

    from torch_geometric.data import Data

    assert isinstance(mol, Chem.Mol)

    xs: List[List[int]] = []
    pos: List[List[float]] = []

    for atom in mol.GetAtoms():
        row: List[int] = []

        # Original features
        row.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
        row.append(x_map['chirality'].index(str(atom.GetChiralTag())))
        row.append(x_map['degree'].index(atom.GetTotalDegree()))
        row.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
        row.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
        row.append(x_map['num_radical_electrons'].index(atom.GetNumRadicalElectrons()))
        row.append(x_map['hybridization'].index(str(atom.GetHybridization())))
        row.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
        row.append(x_map['is_in_ring'].index(atom.IsInRing()))
        
        # New features
        row.append(min(mol.GetRingInfo().NumAtomRings(atom.GetIdx()), 9))
        row.append(x_map['valence'].index(atom.GetTotalValence()))
        row.append(x_map['implicit_valence'].index(atom.GetImplicitValence()))
        row.append(x_map['explicit_valence'].index(atom.GetExplicitValence()))
        row.append(min(x_map['mass'].index(int(atom.GetMass())), 249))
        row.append(min(atom.GetIsotope(), 9))
        
        # Smallest ring size this atom is in (0 if not in ring)
        ring_info = mol.GetRingInfo()
        atom_rings = ring_info.AtomRings()
        ring_sizes = [len(ring) for ring in atom_rings if atom.GetIdx() in ring]
        smallest_ring = min(ring_sizes) if ring_sizes else 0
        row.append(min(smallest_ring, 9))
        
        xs.append(row)
        
        if use_3d:
            conf = mol.GetConformer()
            atom_pos = conf.GetAtomPosition(atom.GetIdx())
            pos.append([atom_pos.x, atom_pos.y, atom_pos.z])
    
    x = torch.tensor(xs, dtype=torch.long).view(-1, 16)


    edge_indices, edge_attrs = [], []
    for bond in mol.GetBonds():  # type: ignore
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()

        e = []
        e.append(e_map['bond_type'].index(str(bond.GetBondType())))
        e.append(e_map['stereo'].index(str(bond.GetStereo())))
        e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

        edge_indices += [[i, j], [j, i]]
        edge_attrs += [e, e]

    edge_index = torch.tensor(edge_indices)
    edge_index = edge_index.t().to(torch.long).view(2, -1)
    edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

    if edge_index.numel() > 0:  # Sort indices.
        perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
        edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

    if use_3d and pos:
        pos_tensor = torch.tensor(pos, dtype=torch.float)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos_tensor)
    else:
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def from_smiles(
    smiles: str,
    with_hydrogen: bool = False,
    kekulize: bool = False,
    use_3d: bool = False,
) -> 'torch_geometric.data.Data':
    r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
    instance.

    Args:
        smiles (str): The SMILES string.
        with_hydrogen (bool, optional): If set to :obj:`True`, will store
            hydrogens in the molecule graph. (default: :obj:`False`)
        kekulize (bool, optional): If set to :obj:`True`, converts aromatic
            bonds to single/double bonds. (default: :obj:`False`)
    """
    from rdkit import Chem, RDLogger
    from rdkit.Chem import AllChem

    RDLogger.DisableLog('rdApp.*')  # type: ignore

    mol = Chem.MolFromSmiles(smiles)

    if mol is None:
        mol = Chem.MolFromSmiles('')
    if with_hydrogen:
        mol = Chem.AddHs(mol)
    if kekulize:
        Chem.Kekulize(mol)
    if use_3d:
        AllChem.EmbedMolecule(mol, randomSeed=42)  # Generate 3D structure
        AllChem.MMFFOptimizeMolecule(mol)  # Optional: optimize geometry

    data = from_rdmol(mol, use_3d=use_3d)
    data.smiles = smiles
    return data


def to_rdmol(
    data: 'torch_geometric.data.Data',
    kekulize: bool = False,
) -> Any:
    """Converts a :class:`torch_geometric.data.Data` instance to a
    :class:`rdkit.Chem.Mol` instance.

    Args:
        data (torch_geometric.data.Data): The molecular graph data.
        kekulize (bool, optional): If set to :obj:`True`, converts aromatic
            bonds to single/double bonds. (default: :obj:`False`)
    """
    from rdkit import Chem

    mol = Chem.RWMol()

    assert data.x is not None
    assert data.num_nodes is not None
    assert data.edge_index is not None
    assert data.edge_attr is not None
    for i in range(data.num_nodes):
        atom = Chem.Atom(int(data.x[i, 0]))
        atom.SetChiralTag(Chem.rdchem.ChiralType.values[int(data.x[i, 1])])
        atom.SetFormalCharge(x_map['formal_charge'][int(data.x[i, 3])])
        atom.SetNumExplicitHs(x_map['num_hs'][int(data.x[i, 4])])
        atom.SetNumRadicalElectrons(x_map['num_radical_electrons'][int(
            data.x[i, 5])])
        atom.SetHybridization(Chem.rdchem.HybridizationType.values[int(
            data.x[i, 6])])
        atom.SetIsAromatic(bool(data.x[i, 7]))
        mol.AddAtom(atom)

    edges = [tuple(i) for i in data.edge_index.t().tolist()]
    visited = set()

    for i in range(len(edges)):
        src, dst = edges[i]
        if tuple(sorted(edges[i])) in visited:
            continue

        bond_type = Chem.BondType.values[int(data.edge_attr[i, 0])]
        mol.AddBond(src, dst, bond_type)

        # Set stereochemistry:
        stereo = Chem.rdchem.BondStereo.values[int(data.edge_attr[i, 1])]
        if stereo != Chem.rdchem.BondStereo.STEREONONE:
            db = mol.GetBondBetweenAtoms(src, dst)
            db.SetStereoAtoms(dst, src)
            db.SetStereo(stereo)

        # Set conjugation:
        is_conjugated = bool(data.edge_attr[i, 2])
        mol.GetBondBetweenAtoms(src, dst).SetIsConjugated(is_conjugated)

        visited.add(tuple(sorted(edges[i])))

    mol = mol.GetMol()

    if kekulize:
        Chem.Kekulize(mol)

    Chem.SanitizeMol(mol)
    Chem.AssignStereochemistry(mol)

    return mol


def to_smiles(
    data: 'torch_geometric.data.Data',
    kekulize: bool = False,
) -> str:
    """Converts a :class:`torch_geometric.data.Data` instance to a SMILES
    string.

    Args:
        data (torch_geometric.data.Data): The molecular graph.
        kekulize (bool, optional): If set to :obj:`True`, converts aromatic
            bonds to single/double bonds. (default: :obj:`False`)
    """
    from rdkit import Chem
    mol = to_rdmol(data, kekulize=kekulize)
    return Chem.MolToSmiles(mol, isomericSmiles=True)

In [90]:
df = pd.read_pickle('mat_bandgap_morgan.pkl')
df = df.drop(['Morgan','confnum','homo','lumo'],axis=1)
mols = [Chem.MolFromSmiles(x) for x in df['smiles']]
print(df)

                                                smiles    gap
0    COC(=O)/C(=C/c1cc(C)c(c2ccc(c3sc(c4cc5c(s4)c(O...  0.235
1                  Cc1csc(c2ccc(c3cc(C)cs3)c3nsnc23)c1  0.262
2    COC(=O)/C(=C/c1cc(C)c(c2cc(C)c(c3ccc(c4sc(c5sc...  0.234
3    C[Si]1(C)c2ccsc2c2sc(c3nc4sc(c5cc6c(s5)c5sc(c7...  0.245
4    COc1c2ccsc2c(OC)c2cc(c3sc(c4scc5c4[C@@H]4C=C[C...  0.300
..                                                 ...    ...
311  Cc1c2ccsc2c(C)c2cc(c3ccc(c4cnc(c5cccs5)c5nsnc4...  0.236
312       CC(=O)c1cc2c(csc2c2cc3c(s2)c(C)c2ccsc2c3C)s1  0.276
313  Cc1cc(c2ccc(N(c3ccccc3)c3ccccc3)cc2)sc1c1cnc(c...  0.258
314  Cc1ccc(C2(c3ccc(C)cc3)c3ccsc3c3cc4c(cc23)c2sc(...  0.255
315  Cc1cc(c2cc3c4nsnc4c(c4cc(C)c(c5cccs5)s4)cc3c3n...  0.231

[316 rows x 2 columns]


In [91]:
# example for one molecule:

smile = df['smiles'][35]
print(smile)

c1csc(c2ccsc2c2cccs2)c1


In [92]:
# this generates a torch Data object for the molecule

g = from_smiles(smile, with_hydrogen=True, kekulize=True, use_3d=True)
print(g)

Data(x=[23, 16], edge_index=[2, 50], edge_attr=[50, 3], pos=[23, 3], smiles='c1csc(c2ccsc2c2cccs2)c1')


In [93]:
# this prints the node feature matrix - so all the atoms in the molecule and some metrics for them
g.x

tensor([[ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [16,  0,  2,  5,  0,  0,  3,  1,  1,  1,  2,  0,  2, 32,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [16,  0,  2,  5,  0,  0,  3,  1,  1,  1,  2,  0,  2, 32,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [ 6,  0,  3,  5,  0,  0,  3,  1,  1,  1,  4,  0,  4, 12,  0,  5],
        [16,  0,  2,  5,  0,  0,  3,  

In [6]:
# this prints the number of nodes:

g.num_nodes

23

In [7]:
g.num_node_features

9

In [94]:
g.edge_index

tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,  6,
          6,  6,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11, 12, 12,
         12, 13, 13, 14, 14, 14, 15, 16, 17, 18, 19, 20, 21, 22],
        [ 1, 14, 15,  0,  2, 16,  1,  3,  2,  4, 14,  3,  5,  8,  4,  6, 17,  5,
          7, 18,  6,  8,  4,  7,  9,  8, 10, 13,  9, 11, 19, 10, 12, 20, 11, 13,
         21,  9, 12,  0,  3, 22,  0,  1,  5,  6, 10, 11, 12, 14]])

In [95]:
g.edge_attr

tensor([[12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  0],
        [12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  0],
        [12,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  1],
        [12,  0,  1],
        [ 1,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  0],
        [12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  0],
        [12,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  1],
        [ 1,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  0],
        [12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  0],
        [12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  0],
        [12,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [12,  0,  1],
        [ 1,  0,  0],
        [ 1,  0,  0],
        [ 1,  0,  0],
        [ 1,  0,  0],
        [ 

In [96]:
# this generates data objects for all the molecules in the dataset

graph_list = []

for i, smile in enumerate(df['smiles']):
    g = from_smiles(smile)
    g.x = g.x.float()
    y = torch.tensor(df['gap'][i],dtype=torch.float).view(1,-1)
    g.y = y
    graph_list.append(g)

In [97]:
type(graph_list[0])

torch_geometric.data.data.Data

In [99]:
graph_list[1]

Data(x=[21, 16], edge_index=[2, 48], edge_attr=[48, 3], smiles='Cc1csc(c2ccc(c3cc(C)cs3)c3nsnc23)c1', y=[1, 1])

In [100]:
# train test split

train_ratio = 0.80
dataset_size = len(graph_list)
train_size = int(train_ratio*dataset_size)
test_size = dataset_size-train_size

generator1 = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(graph_list,[train_size,test_size], generator=generator1)

In [14]:
len(train_dataset)
len(test_dataset)

64

In [101]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

In [76]:
graph_list[0]

Data(x=[66, 16], edge_index=[2, 148], edge_attr=[148, 3], smiles='COC(=O)/C(=C/c1cc(C)c(c2ccc(c3sc(c4cc5c(s4)c(OC)c4cc(c6cc(C)c(c7ccc(c8sc(/C=C(\C#N)/C(=O)OC)cc8C)s7)s6)sc4c5OC)cc3C)s2)s1)/C#N', y=[1, 1])

In [107]:
model = AttentiveFP(in_channels=16,hidden_channels=64,out_channels=1,edge_dim=3,num_layers=3,num_timesteps=2,dropout=0.2)

optimizer = torch.optim.Adam(model.parameters(),lr=1e-3,weight_decay=10**-5)

loss_function = nn.MSELoss()  

In [63]:
print(model)

AttentiveFP(in_channels=11, hidden_channels=64, out_channels=1, edge_dim=3, num_layers=2, num_timesteps=2)


In [108]:
# training and evaluation loops

from sklearn.metrics import r2_score

EPOCHS = 600

for i in range(EPOCHS):
    loss_list_train = []
    pred_list_train = []  
    true_list_train = []  
    
    model.train()
    for data in train_loader:
        optimizer.zero_grad()
    
        output = model(data.x, data.edge_index, data.edge_attr, data.batch)

        loss = loss_function(output, data.y)
        loss.backward()  
        loss_list_train.append(loss.item())
        optimizer.step()  
        
        # Store predictions 
        pred_list_train.extend(output.detach().cpu().numpy().flatten())
        true_list_train.extend(data.y.cpu().numpy().flatten())
    
    # Calculate R² AFTER the loop
    r2_train = r2_score(true_list_train, pred_list_train)
        
    loss_list_test = []
    pred_list_test = []  
    true_list_test = []  
    
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            output = model(data.x, data.edge_index, data.edge_attr, data.batch)

            loss = loss_function(output, data.y)
            loss_list_test.append(loss.item())
     
            # Store predictions - REPLACED
            pred_list_test.extend(output.cpu().numpy().flatten())
            true_list_test.extend(data.y.cpu().numpy().flatten())
    
    # Calculate R² AFTER the loop - NEW
    r2_test = r2_score(true_list_test, pred_list_test)
                
    print(i, "Train Loss: %.4f Train R²: %.4f Test Loss: %.4f Test R²: %.4f"
        % (np.mean(loss_list_train), r2_train, 
           np.mean(loss_list_test), r2_test))

0 Train Loss: 2.2767 Train R²: -3289.6155 Test Loss: 0.1839 Test R²: -293.0222
1 Train Loss: 0.2227 Train R²: -319.8301 Test Loss: 0.0094 Test R²: -14.0415
2 Train Loss: 0.0714 Train R²: -101.1672 Test Loss: 0.0399 Test R²: -62.7964
3 Train Loss: 0.0364 Train R²: -51.2871 Test Loss: 0.0126 Test R²: -19.1459
4 Train Loss: 0.0297 Train R²: -41.4124 Test Loss: 0.0044 Test R²: -6.0579
5 Train Loss: 0.0239 Train R²: -32.8167 Test Loss: 0.0021 Test R²: -2.4265
6 Train Loss: 0.0163 Train R²: -22.2307 Test Loss: 0.0003 Test R²: 0.4441
7 Train Loss: 0.0139 Train R²: -18.6989 Test Loss: 0.0005 Test R²: 0.1459
8 Train Loss: 0.0154 Train R²: -20.9659 Test Loss: 0.0018 Test R²: -1.9295
9 Train Loss: 0.0094 Train R²: -12.3252 Test Loss: 0.0011 Test R²: -0.7659
10 Train Loss: 0.0097 Train R²: -12.8862 Test Loss: 0.0006 Test R²: 0.0249
11 Train Loss: 0.0102 Train R²: -13.4427 Test Loss: 0.0007 Test R²: -0.1511
12 Train Loss: 0.0079 Train R²: -10.2754 Test Loss: 0.0007 Test R²: -0.0983
13 Train Loss: 0