In [30]:
from collections import defaultdict
from collections.abc import Set

import rdkit
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
IPythonConsole.ipython_useSVG=True  
IPythonConsole.drawOptions.addAtomIndices = True  
IPythonConsole.molSize = 600, 600

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import Embedding, Module, ModuleList

In [31]:
# Tensor dtypes shared by the cells below.
categorical_type = torch.long            # integer indices for categorical features (fed to nn.Embedding)
float_type = torch.float32               # continuous features and adjacency matrices
mask_type = labels_type = torch.float32  # masks and graph labels

In [32]:
# Define the continuous and categorical variable classes.

class ContinuousVariable:
    """A real-valued feature, identified only by its name.

    Instances are used as dictionary keys in feature dicts, so equality and hashing
    are both based on ``name``.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f'<ContinuousVariable: {self.name}>'

    def __eq__(self, other):
        # Only compare against other ContinuousVariables. Returning NotImplemented lets
        # Python fall back to its default handling instead of raising AttributeError on
        # objects that have no ``name`` attribute (e.g. ``var == None``).
        if not isinstance(other, ContinuousVariable):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        # Consistent with __eq__: equal variables have equal names, hence equal hashes.
        return hash(self.name)

class CategoricalVariable:
    """A categorical feature with a fixed, ordered set of allowed values.

    Each value's integer index is its position in ``self.values``. When
    ``add_null_value`` is True, ``None`` is prepended as an extra "null" value, so it
    gets index 0 and all user values shift up by one.

    Parameters
    ----------
    name : hashable
        Name of the variable; part of equality and hashing.
    values : iterable
        Allowed values (any iterable, including a one-shot iterator/generator).
    add_null_value : bool, default True
        Whether to prepend ``None`` as a null value.
    """

    def __init__(self, name, values, add_null_value=True):
        self.name = name
        self.has_null_value = add_null_value
        # Materialise once: ``values`` may be a one-shot iterator, and it is needed
        # both for ``self.values`` and for the index mappings below.
        values = tuple(values)
        if self.has_null_value:
            self.null_value = None
            values = (None,) + values
        self.values = values
        # Build the lookups from the materialised tuple (not the raw argument).
        self.value_to_idx_mapping = {v: i for i, v in enumerate(self.values)}
        self.inv_value_to_idx_mapping = {i: v for v, i in self.value_to_idx_mapping.items()}

        if self.has_null_value:
            self.null_value_idx = self.value_to_idx_mapping[self.null_value]

    def get_null_idx(self):
        """Return the index of the null value.

        Raises:
            RuntimeError: if the variable was created with ``add_null_value=False``.
        """
        if self.has_null_value:
            return self.null_value_idx
        else:
            raise RuntimeError(f"Categorical variable {self.name} has no null value")

    def value_to_idx(self, value):
        """Return the integer index of ``value`` (KeyError if it is not an allowed value)."""
        return self.value_to_idx_mapping[value]

    def idx_to_value(self, idx):
        """Return the value stored at integer index ``idx`` (KeyError if out of range)."""
        return self.inv_value_to_idx_mapping[idx]

    def __len__(self):
        # Number of categories, including the null value when present.
        return len(self.values)

    def __repr__(self):
        return f'<CategoricalVariable: {self.name}>'

    def __eq__(self, other):
        # Compare only against other CategoricalVariables; for anything else defer to
        # Python's default handling instead of raising AttributeError.
        if not isinstance(other, CategoricalVariable):
            return NotImplemented
        return self.name == other.name and self.values == other.values

    def __hash__(self):
        # Consistent with __eq__ (same fields).
        return hash((self.name, self.values))

In [33]:
# Atomic features we are going to work with (node features in general).

# 1. Atom symbols.
# The original entries are kept in their original order so existing category indices
# stay stable. Symbols that were missing are appended at the end: the superheavy
# elements Nh, Mc, Ts, Og and RDKit's dummy/wildcard atom '*' (e.g. from SMILES such
# as '*C'). Without them, value_to_idx raises KeyError for molecules containing them.
ATOM_SYMBOLS = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na',
                'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti',
                'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As',
                'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru',
                'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs',
                'Ba', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl',
                'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Rf', 'Db', 'Sg',
                'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Fl', 'Lv', 'La', 'Ce', 'Pr',
                'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb',
                'Lu', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk',
                'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr',
                'Nh', 'Mc', 'Ts', 'Og', '*']

ATOM_SYMBOLS_FEATURE = CategoricalVariable('atom_symbol', ATOM_SYMBOLS)

# 2. Whether the atom is aromatic.
ATOM_AROMATIC_VALUES = [True, False]
ATOM_AROMATIC_FEATURE = CategoricalVariable('is_aromatic', ATOM_AROMATIC_VALUES)

# 3. Atom explicit valence.
ATOM_EXPLICIT_VALENCE_FEATURE = ContinuousVariable('explicit_valence')

# 4. Atom implicit valence.
ATOM_IMPLICIT_VALENCE_FEATURE = ContinuousVariable('implicit_valence')

# All atom features (here: 2 categorical + 2 continuous).
ATOM_FEATURES = [ATOM_SYMBOLS_FEATURE, ATOM_AROMATIC_FEATURE, ATOM_EXPLICIT_VALENCE_FEATURE, ATOM_IMPLICIT_VALENCE_FEATURE]

In [34]:
# Map one RDKit atom to its {feature variable: value} dict (variables from ATOM_FEATURES).
def get_atom_features(rd_atom):
    """Return the node-feature dict for ``rd_atom``.

    Keys are ATOM_SYMBOLS_FEATURE, ATOM_AROMATIC_FEATURE, ATOM_EXPLICIT_VALENCE_FEATURE
    and ATOM_IMPLICIT_VALENCE_FEATURE; values are the atom symbol string, the aromatic
    flag, and the two valences converted to float.
    """
    symbol = rd_atom.GetSymbol()
    aromatic = rd_atom.GetIsAromatic()
    # Valences become floats because they are stored in continuous (float) features.
    # NOTE(review): newer RDKit releases deprecate GetImplicitValence/GetExplicitValence
    # in favour of GetValence(...) — confirm against the installed version.
    implicit = float(rd_atom.GetImplicitValence())
    explicit = float(rd_atom.GetExplicitValence())

    feature_keys = (ATOM_SYMBOLS_FEATURE, ATOM_AROMATIC_FEATURE,
                    ATOM_EXPLICIT_VALENCE_FEATURE, ATOM_IMPLICIT_VALENCE_FEATURE)
    return dict(zip(feature_keys, (symbol, aromatic, explicit, implicit)))

In [35]:
# Bond features we are going to work with (edge features in general).
# Values are the names of RDKit's enum members, because get_bond_features converts
# each enum with str(). New values are only ever appended so existing indices stay stable.

# 1. Bond type.
BOND_TYPES = ['UNSPECIFIED', 'SINGLE', 'DOUBLE', 'TRIPLE', 'QUADRUPLE',
              'QUINTUPLE', 'HEXTUPLE', 'ONEANDAHALF', 'TWOANDAHALF',
              'THREEANDAHALF', 'FOURANDAHALF', 'FIVEANDAHALF', 'AROMATIC',
              'IONIC', 'HYDROGEN', 'THREECENTER', 'DATIVEONE', 'DATIVE',
              'DATIVEL', 'DATIVER', 'OTHER', 'ZERO']

TYPE_FEATURE = CategoricalVariable('bond_type', BOND_TYPES)

# 2. Bond direction. 'UNKNOWN' (RDKit BondDir.UNKNOWN) was missing, so such bonds
# raised KeyError in value_to_idx.
BOND_DIRECTIONS = ['NONE', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT', 'ENDUPRIGHT', 'EITHERDOUBLE',
                   'UNKNOWN']
DIRECTION_FEATURE = CategoricalVariable('bond_direction', BOND_DIRECTIONS)

# 3. Bond stereo. The two atropisomer values exist in recent RDKit releases; on older
# releases they are simply unused extra categories.
BOND_STEREO = ['STEREONONE', 'STEREOANY', 'STEREOZ', 'STEREOE',
               'STEREOCIS', 'STEREOTRANS',
               'STEREOATROPCW', 'STEREOATROPCCW']
STEREO_FEATURE = CategoricalVariable('bond_stereo', BOND_STEREO)

# 4. Whether the bond is aromatic.
AROMATIC_VALUES = [True, False]
AROMATIC_FEATURE = CategoricalVariable('is_aromatic', AROMATIC_VALUES)

# All bond features (all categorical).
BOND_FEATURES = [TYPE_FEATURE, DIRECTION_FEATURE, AROMATIC_FEATURE, STEREO_FEATURE]

In [36]:
# Map one RDKit bond to its {feature variable: value} dict (variables from BOND_FEATURES).
def get_bond_features(rd_bond):
    """Return the edge-feature dict for ``rd_bond``.

    RDKit enum values (bond type, stereo, direction) are converted with ``str()`` so
    they match the string values of the corresponding categorical variables; the
    aromatic flag is kept as returned by ``GetIsAromatic()``.
    """
    type_name = str(rd_bond.GetBondType())
    stereo_name = str(rd_bond.GetStereo())
    direction_name = str(rd_bond.GetBondDir())
    aromatic = rd_bond.GetIsAromatic()

    features = {TYPE_FEATURE: type_name}
    features[DIRECTION_FEATURE] = direction_name
    features[AROMATIC_FEATURE] = aromatic
    features[STEREO_FEATURE] = stereo_name
    return features

In [37]:
# RDKit molecule -> (atom features, bond features) graph.
def rdmol_to_graph(mol):
    """Convert an RDKit molecule into an ``(atoms, bonds)`` graph.

    Returns
    -------
    atoms : dict
        Maps each atom's integer index (``GetIdx()``) to its ``get_atom_features`` dict.
    bonds : dict
        Maps ``frozenset({begin_idx, end_idx})`` (an undirected edge) to its
        ``get_bond_features`` dict.
    """
    # GetIdx must be *called*: the key has to be the integer atom index (it is used as a
    # tensor row index downstream), not the bound method object.
    atoms = {rd_atom.GetIdx(): get_atom_features(rd_atom) for rd_atom in mol.GetAtoms()}
    bonds = {frozenset((rd_bond.GetBeginAtomIdx(), rd_bond.GetEndAtomIdx())): get_bond_features(rd_bond)
             for rd_bond in mol.GetBonds()}
    return atoms, bonds

In [38]:
# SMILES string -> (atom features, bond features) graph.
def smiles_to_graph(smiles):
    """Parse ``smiles`` with RDKit and convert the molecule with ``rdmol_to_graph``.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smiles`` (``MolFromSmiles`` returns None), instead of
        failing later with an unhelpful AttributeError on ``None``.
    """
    rd_mol = MolFromSmiles(smiles)
    if rd_mol is None:
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    return rdmol_to_graph(rd_mol)

In [39]:
g = smiles_to_graph(smiles='c1ccccc1')  # smoke test: benzene -> (atoms, bonds)

In [40]:
# Graph dataset: given a list of graphs, build the per-graph model inputs:
# node indices, adjacency matrix and list, and node/edge features (categorical and continuous).
class GraphDataset(Dataset):
    """Dataset of graphs with categorical and continuous node/edge features.

    Each element of ``graphs`` is a ``(nodes, edges)`` pair:

    * ``nodes`` maps a node index to a ``{variable: value}`` dict. Node indices are used
      directly as tensor row indices, so they are expected to be ``0 .. n_nodes - 1``.
    * ``edges`` maps a 2-element edge key to a ``{variable: value}`` dict. For a ``Set``
      key (e.g. ``frozenset({u, v})``) edge features are written to both ``[u, v]`` and
      ``[v, u]``; for any other key (e.g. a tuple) only to ``[u, v]``. The adjacency
      matrix and list always treat edges as undirected.
    """

    def __init__(self, *, graphs, labels, node_variables, edge_variables, metadata=None):
        '''
        Create a new graph dataset.

        Parameters
        ----------
        graphs : sequence of (nodes, edges) pairs (see the class docstring).
        labels : sequence with one label per graph, returned unchanged as 'label'.
        node_variables, edge_variables : lists of CategoricalVariable / ContinuousVariable;
            other objects are ignored. The order within each kind fixes the column order
            of the corresponding feature tensors.
        metadata : optional sequence with one entry per graph, returned as 'metadata'.
        '''
        # graphs
        self.graphs = graphs
        self.labels = labels
        assert len(self.graphs) == len(self.labels), "The graphs and labels lists must be the same length"
        self.metadata = metadata
        if self.metadata is not None:
            assert len(self.metadata) == len(self.graphs), "The metadata list needs to be as long as the graphs"

        # node and edge variables, split by kind
        self.node_variables = node_variables
        self.edge_variables = edge_variables
        self.categorical_node_variables = [var for var in self.node_variables if isinstance(var, CategoricalVariable)]
        self.continuous_node_variables = [var for var in self.node_variables if isinstance(var, ContinuousVariable)]
        self.categorical_edge_variables = [var for var in self.edge_variables if isinstance(var, CategoricalVariable)]
        self.continuous_edge_variables = [var for var in self.edge_variables if isinstance(var, ContinuousVariable)]

    def __len__(self):
        """Number of graphs in the dataset."""
        return len(self.graphs)

    def make_continuous_node_features(self, nodes):
        """Return a ``(n_nodes, n_continuous)`` float tensor, or None if there are no continuous node variables.

        Row ``node_idx`` holds that node's values in ``self.continuous_node_variables`` order.
        """
        if len(self.continuous_node_variables) == 0:
            return None
        n_nodes = len(nodes)
        n_features = len(self.continuous_node_variables)
        continuous_node_features = torch.zeros((n_nodes, n_features), dtype=float_type)
        for node_idx, features in nodes.items():
            node_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_node_variables], dtype=float_type)
            continuous_node_features[node_idx] = node_features
        return continuous_node_features

    def make_categorical_node_features(self, nodes):
        """Return a ``(n_nodes, n_categorical)`` integer tensor of category indices, or None if there are no categorical node variables."""
        if len(self.categorical_node_variables) == 0:
            return None
        n_nodes = len(nodes)
        n_features = len(self.categorical_node_variables)
        categorical_node_features = torch.zeros((n_nodes, n_features), dtype=categorical_type)

        for node_idx, features in nodes.items():
            for i, categorical_variable in enumerate(self.categorical_node_variables):
                value = features[categorical_variable]
                value_index = categorical_variable.value_to_idx(value)
                categorical_node_features[node_idx, i] = value_index
        return categorical_node_features

    def make_continuous_edge_features(self, n_nodes, edges):
        """Return a ``(n_nodes, n_nodes, n_continuous)`` float tensor, or None if there are no continuous edge variables.

        Entries for node pairs without an edge stay 0.
        """
        if len(self.continuous_edge_variables) == 0:
            return None
        n_features = len(self.continuous_edge_variables)
        continuous_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=float_type)
        for edge, features in edges.items():
            edge_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_edge_variables], dtype=float_type)
            u, v = edge
            continuous_edge_features[u, v] = edge_features
            if isinstance(edge, Set):  # unordered key => undirected edge, fill both directions
                continuous_edge_features[v, u] = edge_features
        return continuous_edge_features

    def make_categorical_edge_features(self, n_nodes, edges):
        """Return a ``(n_nodes, n_nodes, n_categorical)`` integer tensor of category indices, or None if there are no categorical edge variables.

        Entries for node pairs without an edge stay 0 (the null index for variables
        created with ``add_null_value=True``).
        """
        if len(self.categorical_edge_variables) == 0:
            return None
        n_features = len(self.categorical_edge_variables)
        categorical_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=categorical_type)

        for edge, features in edges.items():
            u, v = edge
            for i, categorical_variable in enumerate(self.categorical_edge_variables):
                value = features[categorical_variable]
                value_index = categorical_variable.value_to_idx(value)
                categorical_edge_features[u, v, i] = value_index
                if isinstance(edge, Set):  # unordered key => undirected edge, fill both directions
                    categorical_edge_features[v, u, i] = value_index
        return categorical_edge_features

    def _make_adjacency(self, n_nodes, edges):
        """Build the symmetric adjacency matrix and the adjacency list for one graph (edges treated as undirected)."""
        adjacency_matrix = torch.zeros((n_nodes, n_nodes), dtype=float_type)
        adjacency_list = defaultdict(list)
        # frozenset keys are only partially ordered (``<`` means "proper subset"), so a
        # plain sorted() does not order them; sort on the sorted endpoint pair instead.
        for edge in sorted(edges.keys(), key=lambda e: tuple(sorted(e))):
            u, v = edge
            adjacency_matrix[u, v] = 1
            adjacency_matrix[v, u] = 1
            # Filled inside the loop so every edge is recorded; a graph with no edges
            # simply gets an empty adjacency_list.
            adjacency_list[u].append(v)
            adjacency_list[v].append(u)
        return adjacency_matrix, adjacency_list

    def __getitem__(self, index):
        """Build the feature record for graph ``index``.

        Returns a dict with keys:
            'nodes': sorted list of node indices;
            'adjacency_matrix': symmetric ``(n_nodes, n_nodes)`` float tensor, 1 where an edge exists;
            'adjacency_list': ``defaultdict(list)`` mapping a node to its neighbours;
            'categorical_node_features' / 'continuous_node_features': per-node tensors,
                or None when there are no variables of that kind;
            'categorical_edge_features' / 'continuous_edge_features':
                ``(n_nodes, n_nodes, n_features)`` tensors, or None likewise;
            'label': the graph's label, unchanged;
            'metadata': the graph's metadata entry (only if metadata was given).
        """
        nodes, edges = self.graphs[index]
        n_nodes = len(nodes)
        # node and edge features for the graph
        continuous_node_features = self.make_continuous_node_features(nodes)
        categorical_node_features = self.make_categorical_node_features(nodes)
        continuous_edge_features = self.make_continuous_edge_features(n_nodes, edges)
        categorical_edge_features = self.make_categorical_edge_features(n_nodes, edges)
        label = self.labels[index]
        adjacency_matrix, adjacency_list = self._make_adjacency(n_nodes, edges)

        data_record = {'nodes': sorted(nodes.keys()),
                       'adjacency_matrix': adjacency_matrix,
                       'adjacency_list': adjacency_list,
                       'categorical_node_features': categorical_node_features,
                       'continuous_node_features': continuous_node_features,
                       'categorical_edge_features': categorical_edge_features,
                       'continuous_edge_features': continuous_edge_features,
                       'label': label}
        # Extra per-graph information, if any was supplied.
        if self.metadata is not None:
            data_record['metadata'] = self.metadata[index]

        return data_record

    def get_node_variables(self):
        """Return the node variables grouped as ``{'continuous': [...], 'categorical': [...]}``."""
        return {'continuous': self.continuous_node_variables,
                'categorical': self.categorical_node_variables}

    def get_edge_variables(self):
        """Return the edge variables grouped as ``{'continuous': [...], 'categorical': [...]}``."""
        return {'continuous': self.continuous_edge_variables,
                'categorical': self.categorical_edge_variables}

In [41]:
def make_molecular_graph_dataset(smiles_records, atom_features=ATOM_FEATURES, bond_features=BOND_FEATURES):
    '''
    Create a new GraphDataset from an iterable of record dictionaries.

    Each record must contain the keys 'smiles' and 'label'. The SMILES string is
    converted to an (atoms, bonds) graph with `smiles_to_graph`, and 'label' becomes the
    graph label. The *whole* record is kept as that graph's metadata entry: 'smiles',
    'label' and any extra keys. It is stored by reference, not copied.

    Parameters
    ----------
    smiles_records : iterable of dict
        Records with at least 'smiles' and 'label' keys (KeyError otherwise).
    atom_features : list, default ATOM_FEATURES
        Node variables passed to GraphDataset.
    bond_features : list, default BOND_FEATURES
        Edge variables passed to GraphDataset.

    Returns
    -------
    GraphDataset
    '''
    graphs = []
    labels = []
    metadata = []
    for smiles_record in smiles_records:
        smiles = smiles_record['smiles']
        label = smiles_record['label']
        graphs.append(smiles_to_graph(smiles))
        labels.append(label)
        metadata.append(smiles_record)
    return GraphDataset(graphs=graphs,
                        labels=labels,
                        node_variables=atom_features,
                        edge_variables=bond_features,
                        metadata=metadata)

In [42]:
# Toy dataset of two molecules: benzene (label 1) and sulfuric acid (label 0).
# Indexing it, e.g. dataset[1], returns the feature record of one graph.
dataset = make_molecular_graph_dataset([
    {'smiles': 'c1ccccc1', 'label': 1},
    {'smiles': 'OS(=O)(=O)O', 'label': 0},
])

In [43]:
# Model sizes.
ffn_dim = 12  # hidden layer size of the MLP
d_model = 8   # dimensionality of every vector in the model

class Embedder(Module):
    """Embed several categorical variables and sum their embeddings.

    One ``Embedding(len(var), embedding_dim)`` is created per variable, in order. For a
    variable with a null value, its null index is used as ``padding_idx``, so that
    category maps to a zero vector that receives no gradient.
    """

    def __init__(self, categorical_variables, embedding_dim):
        super().__init__()
        self.categorical_variables = categorical_variables
        self.embeddings = ModuleList([
            Embedding(len(var), embedding_dim,
                      padding_idx=var.get_null_idx() if var.has_null_value else None)
            for var in self.categorical_variables
        ])

    def forward(self, categorical_features):
        """Sum the per-variable embeddings of ``categorical_features``.

        The last dimension of ``categorical_features`` indexes the variables (column
        ``i`` holds indices for variable ``i``). The result has the input's leading
        dimensions followed by ``embedding_dim``.
        """
        per_variable = [embedding(categorical_features[..., column])
                        for column, embedding in enumerate(self.embeddings)]
        return torch.stack(per_variable, dim=0).sum(dim=0)

In [44]:
# The summed categorical embedding is concatenated with the continuous features,
# so it gets whatever width of d_model the continuous features leave over.
num_continuous_node_variables = len(dataset.continuous_node_variables)
num_categorical_node_variables = len(dataset.categorical_node_variables)
node_embedding_dim = d_model - num_continuous_node_variables

In [45]:
# categorical_node_features = dataset[1]['categorical_node_features']
# node_embedder = Embedder(dataset.categorical_node_variables, node_embedding_dim)
# node_embedder(categorical_node_features)

In [46]:
# Combine embedded categorical features with continuous features.

# Declared with ``class`` (it was previously a ``def``, which made "instantiating" it
# just run a function that defined two local functions and returned None).
class FeatureCombiner(Module):
    """Embed categorical features and concatenate them with continuous features.

    Parameters
    ----------
    categorical_variables : list
        Categorical variables handed to ``Embedder``; their embeddings are summed.
    embedding_dim : int
        Width of the summed categorical embedding.
    """

    def __init__(self, categorical_variables, embedding_dim):
        super().__init__()
        self.categorical_variables = categorical_variables
        self.embedder = Embedder(self.categorical_variables, embedding_dim)

    def forward(self, continuous_features, categorical_features):
        """Return ``[embedded categorical | continuous]`` concatenated on the last dimension.

        Either argument may be None (skipped). NOTE(review): when both are given, they
        presumably share leading dimensions and dtype, as required by ``torch.cat``.

        Raises
        ------
        RuntimeError
            If both arguments are None.
        """
        features = []
        if categorical_features is not None:
            features.append(self.embedder(categorical_features))

        if continuous_features is not None:
            features.append(continuous_features)

        if not features:
            raise RuntimeError('No features to combine')

        # Concatenate along the feature (last) dimension.
        return torch.cat(features, dim=-1)