<a href="https://colab.research.google.com/github/tayssirmoussa66/chemical-reaction-prediction/blob/main/Graph_Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from rdkit import Chem
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data, Batch, download_url
import numpy as np
import os.path as osp
from torch_geometric.loader import DataLoader
from tqdm import tqdm

In [None]:
def get_bin_feature(r):
    '''
    Build a fully connected ``edge_index`` for the molecule(s) in ``r``.

    Despite the historical name this is not an adjacency matrix: every
    ordered pair of atoms (including each atom with itself) becomes an edge,
    so ``n`` atoms yield ``n * n`` columns. Node ids are the atom-map numbers
    minus one (``molAtomMapNumber - 1``).

    Args:
        r: SMILES string in which every atom carries an atom-map number
           (RDKit raises if an atom has no ``molAtomMapNumber``).

    Returns:
        LongTensor of shape (2, n_atoms * n_atoms). Row 0 holds source ids,
        row 1 target ids. Columns are ordered with the source atom in the
        outer position and the target atom in the inner position, both in
        RDKit atom order.

    Raises:
        ValueError: if ``r`` cannot be parsed by RDKit.
    '''
    rmol = Chem.MolFromSmiles(r)
    if rmol is None:
        raise ValueError(f"Could not parse SMILES: {r!r}")

    # Atom-map numbers are 1-based; shift them to 0-based node ids.
    node_ids = np.array(
        [atom.GetIntProp('molAtomMapNumber') - 1 for atom in rmol.GetAtoms()],
        dtype=np.int64,
    )
    n_atoms = len(node_ids)

    # Outer loop over sources / inner loop over targets == repeat / tile.
    # Also well-defined for n_atoms == 0, producing an empty (2, 0) tensor.
    src = np.repeat(node_ids, n_atoms)
    dst = np.tile(node_ids, n_atoms)
    return torch.from_numpy(np.stack([src, dst])).to(torch.long)


In [None]:
# Bond order (parsed with float() from an edit string) -> label class index.
# 1.5 matches RDKit's GetBondTypeAsDouble() value for aromatic bonds; 0.0 is
# presumably "no bond / bond broken" -- confirm against the edit data.
bo_to_index = {
    0.0: 0,
    1: 1,
    2: 2,
    3: 3,
    1.5: 4,
}
# Number of bond-order classes (size of the last label axis).
nbos = len(bo_to_index)
# Sentinel value; not referenced in the code visible here.
INVALID_BOND = -1

def get_bond_label(r, edits):
    '''
    Build the bond-change labels for a single reaction.

    Args:
        r: reactant SMILES; only its atom count is used here.
        edits: ';'-separated entries of the form ``"a1-a2-bo"``. ``a1`` and
            ``a2`` are 1-based atom ids (presumably atom-map numbers --
            confirm against the data). ``bo`` is a bond order that must be a
            key of ``bo_to_index``. Empty entries (e.g. from a trailing
            ';') are ignored.

    Returns:
        labels: float array of shape (n_atoms * n_atoms, nbos). Row
            ``i * n_atoms + j`` is the multi-hot bond-order target for the
            atom pair (i, j); the labels are symmetric in (i, j).
        sp_labels: ascending list of ints. These are the flat indices
            ``i * n_atoms * nbos + j * nbos + k`` of every set entry.

    Raises:
        ValueError: if ``r`` cannot be parsed or an entry is malformed.
        KeyError: if a bond order is not a key of ``bo_to_index``.
    '''
    rmol = Chem.MolFromSmiles(r)
    if rmol is None:
        raise ValueError(f"Could not parse SMILES: {r!r}")
    n_atoms = rmol.GetNumAtoms()
    rmap = np.zeros((n_atoms, n_atoms, nbos))

    for s in edits.split(';'):
        if not s:
            continue  # tolerate empty entries such as a trailing ';'
        a1, a2, bo = s.split('-')
        x = min(int(a1), int(a2)) - 1
        y = max(int(a1), int(a2)) - 1
        z = bo_to_index[float(bo)]
        # Bond changes are undirected: mark both (x, y) and (y, x).
        rmap[x, y, z] = rmap[y, x, z] = 1

    # np.flatnonzero walks rmap in C (row-major) order. That is the same
    # order used by the np.reshape below and by tf.reshape. Each index is
    # therefore i * n_atoms * nbos + j * nbos + k, and they come out
    # ascending.
    sp_labels = np.flatnonzero(rmap).tolist()
    labels = rmap.reshape(n_atoms * n_atoms, nbos)
    return labels, sp_labels

In [None]:
class MoleculeDataset(Dataset):
    """On-disk PyG dataset of reactant graphs with bond-edit labels.

    Each non-blank line of the raw file must contain exactly two
    whitespace-separated fields:

    - a reaction SMILES; only the part before the first ``>`` (the
      reactants) is used;
    - an edit string as parsed by ``get_bond_label``.

    Reaction number ``idx`` is converted to a ``Data`` object and saved as
    ``data_<idx>.pt`` in ``processed_dir``.
    """

    def __init__(self, root, filename, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data).
        filename = Name of the raw reaction file inside raw_dir.
        """
        self.filename = filename
        # Lazily computed number of reactions in the raw file. It must exist
        # before the base-class __init__, which queries processed_file_names.
        self._num_reactions = None
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return self.filename

    @property
    def processed_file_names(self):
        """One ``data_<idx>.pt`` per reaction in the raw file.

        PyG skips ``process()`` when all of these files already exist.
        """
        return [f'data_{i}.pt' for i in range(self._reaction_count())]

    def download(self):
        # Nothing to fetch: the raw file must already be placed in raw_dir.
        pass

    def process(self):
        """Convert every reaction in the raw file into a saved ``Data`` object."""
        reactions = self._iter_raw_reactions()
        for idx, (rxn, edits) in enumerate(tqdm(reactions, total=self.len())):
            # Only the reactant side (before the first '>') is featurized.
            react = rxn.split('>')[0]
            labels, sp_labels = get_bond_label(react, edits)
            edge_index = get_bin_feature(react)
            mol_obj = Chem.MolFromSmiles(react)

            data = Data(x=self._get_node_features(mol_obj),
                        edge_index=edge_index,
                        edge_weight=self._get_edge_weights(mol_obj),
                        y=self._get_labels(labels),
                        z=self._get_sp_labels(sp_labels),
                        smiles=react
                        )
            # The constructor accepts pre_transform, so honour it here.
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))

    def _iter_raw_reactions(self):
        """Yield ``(reaction_smiles, edits)`` for each non-blank raw-file line."""
        raw_path = self.raw_paths[0]
        with open(raw_path, "r") as f:
            for lineno, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue  # blank line, e.g. a trailing newline
                if len(fields) != 2:
                    raise ValueError(
                        f"{raw_path}:{lineno}: expected 2 fields "
                        f"(reaction SMILES, edits), got {len(fields)}")
                yield fields[0], fields[1]

    def _reaction_count(self):
        """Return the number of non-blank raw-file lines (0 if the file is missing).

        The count is cached once the file has been read. ``len()`` can be
        called on every item access, so the file is not re-read each time.
        """
        if self._num_reactions is None:
            raw_path = self.raw_paths[0]
            if not osp.exists(raw_path):
                # No raw file yet: nothing is considered processed, so
                # process() runs and reports the missing file.
                return 0
            with open(raw_path, "r") as f:
                self._num_reactions = sum(1 for line in f if line.split())
        return self._num_reactions

    def _get_node_features(self, mol):
        """
        Return a float32 tensor of shape [n_atoms, 7].

        Row ``k`` holds the features of the atom whose map number is ``k + 1``.
        This matches the node ids used by ``get_bin_feature``.
        """
        all_node_feats = np.zeros((mol.GetNumAtoms(), 7))

        for atom in mol.GetAtoms():
            # NOTE(review): recent RDKit releases deprecate
            # GetExplicitValence/GetImplicitValence; check the installed version.
            node_feats = [
                atom.GetAtomicNum(),        # 1: atomic number
                atom.GetDegree(),           # 2: atom degree
                atom.GetExplicitValence(),  # 3: explicit valence
                atom.GetImplicitValence(),  # 4: implicit valence
                atom.GetFormalCharge(),     # 5: formal charge
                atom.GetIsAromatic(),       # 6: aromaticity (bool -> 0/1)
                atom.IsInRing(),            # 7: in ring (bool -> 0/1)
            ]
            all_node_feats[atom.GetIntProp('molAtomMapNumber') - 1] = node_feats

        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_weights(self, mol):
        """
        Return a dense float64 tensor of shape [n_atoms, n_atoms] of bond orders.

        Entry ``[i, j]`` is ``bond.GetBondTypeAsDouble()`` for the bond
        between the atoms with map numbers ``i + 1`` and ``j + 1``, or 0 when
        they are not bonded. The matrix is symmetric.
        """
        n_atoms = mol.GetNumAtoms()
        edge_weight = np.zeros((n_atoms, n_atoms))

        for bond in mol.GetBonds():
            # Index by atom-map number (not RDKit atom index) so rows and
            # columns line up with the node features and edge_index.
            i = bond.GetBeginAtom().GetIntProp('molAtomMapNumber') - 1
            j = bond.GetEndAtom().GetIntProp('molAtomMapNumber') - 1
            # Bonds are undirected: fill both directions.
            edge_weight[i, j] = edge_weight[j, i] = bond.GetBondTypeAsDouble()

        return torch.tensor(edge_weight)

    def _get_labels(self, e):
        """Wrap the dense label matrix from ``get_bond_label`` as a tensor."""
        return torch.tensor(np.asarray(e))

    def _get_sp_labels(self, e):
        """Wrap the sparse label indices as an int64 tensor (also when empty)."""
        return torch.tensor(np.asarray(e, dtype=np.int64))

    def len(self):
        return self._reaction_count()

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        path = osp.join(self.processed_dir, f'data_{idx}.pt')
        # torch>=2.6 defaults to weights_only=True, which refuses to unpickle
        # a PyG ``Data`` object. These files are written by process() above,
        # so full unpickling is acceptable here. Never point this at
        # untrusted files.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older torch releases without the ``weights_only`` keyword.
            return torch.load(path)



In [None]:
# Training split: raw file is read from <root>/raw, processed graphs go to <root>/processed.
Train_Dataset = MoleculeDataset(
    root="/content/sample_data/Data_train",
    filename="Train.txt",
)

In [None]:
# Test split, stored under its own root directory.
Test_Dataset = MoleculeDataset(
    root="/content/sample_data/Data_test",
    filename="Test.txt",
)

In [None]:
# Shuffle the training graphs so each epoch sees batches in a new order.
Train_loader = DataLoader(
    Train_Dataset,
    batch_size=5000,
    shuffle=True,
)

In [None]:
# Evaluation metrics do not depend on batch order, so keep the test set in a
# fixed order: runs are reproducible and predictions line up with file order.
Test_loader = DataLoader(Test_Dataset, batch_size=5000, shuffle=False)