<a href="https://colab.research.google.com/github/tejaspradhan/Graph-Neural-Networks/blob/main/hiv-project/gnn_explainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installations and Imports

In [None]:
from google.colab import drive
# The raw CSVs and the trained checkpoint used below live on Google Drive.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
#@title
# Install rdkit
import sys
import os
import requests
import subprocess
import shutil
from logging import getLogger, StreamHandler, INFO
 
 
logger = getLogger(__name__)
# Re-running this cell must not stack another handler on the same logger
# (each extra handler would print every message again).
if not logger.handlers:
    logger.addHandler(StreamHandler())
logger.setLevel(INFO)


def install(
        chunk_size=4096,
        file_name="Miniconda3-latest-Linux-x86_64.sh",
        url_base="https://repo.continuum.io/miniconda/",
        conda_path=os.path.expanduser(os.path.join("~", "miniconda")),
        rdkit_version=None,
        add_python_path=True,
        force=False):
    """Install RDKit into a private Miniconda and make it importable here.

    ```
    import rdkit_installer
    rdkit_installer.install()
    ```

    Args:
        chunk_size: Bytes per chunk when streaming the Miniconda installer.
        file_name: Installer file name; also the local path it is downloaded
            to (the file is removed once Miniconda has been installed).
        url_base: URL prefix the installer is fetched from.
        conda_path: Miniconda prefix. An existing directory or file at this
            path is deleted before installing.
        rdkit_version: Exact RDKit version to request, or None for the latest.
        add_python_path: Append the Miniconda site-packages dir to sys.path.
        force: Reinstall even when RDKit is already present there.

    Raises:
        requests.exceptions.HTTPError: The installer download failed.
        subprocess.CalledProcessError: The installer or `conda install` failed.
    """
    major_minor = "{0}.{1}".format(*sys.version_info)
    python_path = os.path.join(
        conda_path,
        "lib",
        "python" + major_minor,
        "site-packages",
    )

    if add_python_path and python_path not in sys.path:
        logger.info("add {} to PYTHONPATH".format(python_path))
        sys.path.append(python_path)

    if os.path.isdir(os.path.join(python_path, "rdkit")):
        logger.info("rdkit is already installed")
        if not force:
            return

        logger.info("force re-install")

    url = url_base + file_name
    python_version = "{0}.{1}.{2}".format(*sys.version_info)

    logger.info("python version: {}".format(python_version))

    if os.path.isdir(conda_path):
        logger.warning("remove current miniconda")
        shutil.rmtree(conda_path)
    elif os.path.isfile(conda_path):
        logger.warning("remove {}".format(conda_path))
        os.remove(conda_path)

    logger.info('fetching installer from {}'.format(url))
    # Using the response as a context manager releases the streamed connection.
    with requests.get(url, stream=True) as res:
        res.raise_for_status()
        with open(file_name, 'wb') as f:
            for chunk in res.iter_content(chunk_size):
                f.write(chunk)
    logger.info('done')

    try:
        logger.info('installing miniconda to {}'.format(conda_path))
        subprocess.check_call(["bash", file_name, "-b", "-p", conda_path])
        logger.info('done')
    finally:
        # The installer script is only needed for the step above.
        os.remove(file_name)

    logger.info("installing rdkit")
    subprocess.check_call([
        os.path.join(conda_path, "bin", "conda"),
        "install",
        "--yes",
        "-c", "rdkit",
        # Request the running interpreter's major.minor so conda's
        # site-packages is exactly `python_path` (already on sys.path). A fixed
        # "python==3.7.3" breaks that on any runtime that is not Python 3.7.
        "python={}".format(major_minor),
        "rdkit" if rdkit_version is None else "rdkit=={}".format(rdkit_version)])
    logger.info("done")

    import rdkit
    logger.info("rdkit-{} installation finished!".format(rdkit.__version__))


if __name__ == "__main__":
    # A notebook cell executes with __name__ == "__main__", so running the cell installs.
    install()

add /root/miniconda/lib/python3.7/site-packages to PYTHONPATH
python version: 3.7.12
fetching installer from https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
done
installing miniconda to /root/miniconda
done
installing rdkit
done
rdkit-2020.09.1 installation finished!


In [None]:
! pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
! pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
! pip install -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
! pip install -q torch-geometric

[K     |████████████████████████████████| 7.9 MB 12.3 MB/s 
[K     |████████████████████████████████| 3.5 MB 10.8 MB/s 
[K     |████████████████████████████████| 2.3 MB 12.1 MB/s 
[K     |████████████████████████████████| 370 kB 13.2 MB/s 
[K     |████████████████████████████████| 482 kB 48.7 MB/s 
[K     |████████████████████████████████| 41 kB 548 kB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


In [None]:
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch 
import torch.nn.functional as F 
import torch_geometric
from torch_geometric.data import Dataset, Data
from torch_geometric.nn import GATConv, Linear, TopKPooling, GNNExplainer
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.loader import DataLoader
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm import tqdm

## Creating the Graph Dataset 

In [None]:
class HIVDataset(Dataset):
    """PyG dataset built from a CSV of SMILES strings and HIV activity labels.

    Every CSV row (columns ``smiles`` and ``HIV_active``) becomes one ``Data``
    graph that is saved to ``processed_dir`` as ``data_{index}.pt``, or as
    ``data_test_{index}.pt`` when ``test=True``, and is loaded back on access.
    """

    def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data).
        filename = Name of the CSV file inside raw_dir.
        test = Use the ``data_test_`` file prefix, so that a test split sharing
        ``root`` with a training split does not overwrite the training files.
        """
        self.test = test
        self.filename = filename
        super(HIVDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.
            (The download func. is not implemented here)
        """
        return self.filename

    @property
    def processed_file_names(self):
        # Placeholder that process() never writes. Because it is always missing,
        # process() runs on every construction. len() depends on that, because
        # process() is what sets self.data.
        return 'not_implemented.pt'

    def download(self):
        pass

    def _processed_file(self, idx):
        """Path of the saved graph for row ``idx``; shared by process() and get()."""
        prefix = 'data_test' if self.test else 'data'
        return os.path.join(self.processed_dir, f'{prefix}_{idx}.pt')

    def process(self):
        """Featurize every CSV row and save one ``Data`` file per molecule.

        Raises:
            ValueError: RDKit could not parse a row's SMILES string.
        """
        self.data = pd.read_csv(self.raw_paths[0])
        print(self.raw_paths)
        for index, mol in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            smiles = mol["smiles"]
            mol_obj = Chem.MolFromSmiles(smiles)
            if mol_obj is None:
                # MolFromSmiles reports a parse failure by returning None; stop
                # here with the offending row instead of deep inside featurization.
                raise ValueError(f"row {index}: RDKit could not parse SMILES {smiles!r}")
            # Get node features
            node_feats = self._get_node_features(mol_obj)
            # Get edge features
            edge_feats = self._get_edge_features(mol_obj)
            # Get adjacency info
            edge_index = self._get_adjacency_info(mol_obj)
            # Get labels info
            label = self._get_labels(mol["HIV_active"])

            # Create data object
            data = Data(x=node_feats,
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        smiles=smiles
                        )
            torch.save(data, self._processed_file(index))

    def _get_node_features(self, mol):
        """
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]
        """
        all_node_feats = []

        for atom in mol.GetAtoms():
            node_feats = []
            # Feature 1: Atomic number
            node_feats.append(atom.GetAtomicNum())
            # Feature 2: Atom degree
            node_feats.append(atom.GetDegree())
            # Feature 3: Formal charge
            node_feats.append(atom.GetFormalCharge())
            # Feature 4: Hybridization
            node_feats.append(atom.GetHybridization())
            # Feature 5: Aromaticity
            node_feats.append(atom.GetIsAromatic())
            # Feature 6: Total Num Hs
            node_feats.append(atom.GetTotalNumHs())
            # Feature 7: Radical Electrons
            node_feats.append(atom.GetNumRadicalElectrons())
            # Feature 8: In Ring
            node_feats.append(atom.IsInRing())
            # Feature 9: Chirality
            node_feats.append(atom.GetChiralTag())

            # Append node features to matrix
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """
        Return a float tensor of shape [Number of edges, 2]: bond type (as
        double) and whether the bond is in a ring. Each bond appears twice,
        once per direction, in the same order as _get_adjacency_info.
        """
        all_edge_feats = []

        for bond in mol.GetBonds():
            # Feature 1: Bond type (as double); Feature 2: Rings
            edge_feats = [bond.GetBondTypeAsDouble(), bond.IsInRing()]
            # Append edge features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]

        # reshape keeps the 2-column layout, i.e. shape (0, 2), for molecules
        # without bonds; np.asarray([]) on its own would be 1-D.
        all_edge_feats = np.asarray(all_edge_feats, dtype=float).reshape(-1, 2)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        # [E, 2] -> [2, E]. view(-1, 2) also yields (2, 0) for bond-less
        # molecules, and contiguous() avoids returning a transposed view.
        edge_indices = torch.tensor(edge_indices, dtype=torch.long).view(-1, 2)
        return edge_indices.t().contiguous()

    def _get_labels(self, label):
        label = np.asarray([label])
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        # self.data is the DataFrame assigned in process(), which always runs
        # (see processed_file_names).
        return self.data.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        return torch.load(self._processed_file(idx))

In [None]:
DATA_ROOT = '/content/drive/MyDrive/GNN_HIV/data/'
train_dataset = HIVDataset(root=DATA_ROOT, filename='HIV_train_oversampled.csv')
# test=True is required. Both splits share DATA_ROOT (and therefore one
# processed/ directory). Without it, the test molecules are saved as
# data_{i}.pt and overwrite the first training files.
test_dataset = HIVDataset(root=DATA_ROOT, filename='HIV_test.csv', test=True)

Processing...


['/content/drive/MyDrive/GNN_HIV/data/raw/HIV_train_oversampled.csv']


100%|██████████| 71634/71634 [12:16<00:00, 97.28it/s]
Done!
Processing...


['/content/drive/MyDrive/GNN_HIV/data/raw/HIV_test.csv']


100%|██████████| 3999/3999 [00:41<00:00, 97.08it/s]
Done!


## Loading the Model 

In [None]:
class GNN(torch.nn.Module):
    """Graph-level classifier: three GAT + TopK-pooling stages with summed readouts.

    Each stage applies a 3-head GATConv, then a Linear layer that projects the
    concatenated heads back to the embedding size, then TopKPooling. After each
    stage the graph is read out as [global max pool || global mean pool]. The
    three readouts are summed and fed through a two-layer MLP that emits one
    value per graph.
    """

    def __init__(self, feature_size):
        super(GNN, self).__init__()
        n_outputs = 1
        hidden = 1024
        n_heads = 3

        # Attribute names and creation order are fixed on purpose: they define
        # the state_dict keys of the saved checkpoint and the init order.
        self.conv1 = GATConv(feature_size, hidden, heads=n_heads, dropout=0.3)
        self.head_transform1 = Linear(hidden * n_heads, hidden)
        self.pool1 = TopKPooling(hidden, ratio=0.8)

        self.conv2 = GATConv(hidden, hidden, heads=n_heads, dropout=0.3)
        self.head_transform2 = Linear(hidden * n_heads, hidden)
        self.pool2 = TopKPooling(hidden, ratio=0.5)

        self.conv3 = GATConv(hidden, hidden, heads=n_heads, dropout=0.3)
        self.head_transform3 = Linear(hidden * n_heads, hidden)
        self.pool3 = TopKPooling(hidden, ratio=0.2)

        self.linear1 = Linear(hidden * 2, 1024)
        self.linear2 = Linear(1024, n_outputs)

    def forward(self, x, edge_attr, edge_index, batch_index):
        """Return one output per graph in the batch.

        ``edge_attr`` is accepted for call compatibility but never used: the
        GATConv layers are called without it and the pools receive ``None``.
        """
        stages = (
            (self.conv1, self.head_transform1, self.pool1),
            (self.conv2, self.head_transform2, self.pool2),
            (self.conv3, self.head_transform3, self.pool3),
        )
        readouts = []
        for conv, project, pool in stages:
            h = project(conv(x, edge_index))
            x, edge_index, edge_attr, batch_index, _, _ = pool(h, edge_index, None, batch_index)
            readouts.append(torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1))

        # Adds left to right, i.e. (r1 + r2) + r3, the same order as before.
        pooled = sum(readouts[1:], readouts[0])
        hidden = F.dropout(self.linear1(pooled).relu(), p=0.5, training=self.training)
        return self.linear2(hidden)

In [None]:
# The input width is the per-atom feature count of any sample graph.
model = GNN(feature_size=train_dataset[0].x.size(1))

In [None]:
MODEL_PATH = '/content/drive/MyDrive/GNN_HIV/models/model-hiv-15epochs.pt'
# map_location='cpu' lets a checkpoint saved from a CUDA run load on a CPU-only
# runtime; `model` is never moved off the CPU in this notebook anyway.
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))

<All keys matched successfully>

## Generating Explanations

In [None]:
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
# Evaluation should be reproducible, so the test split keeps its file order.
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [None]:
class GraphExplainerAdapter(torch.nn.Module):
    """Expose `GNN` with the calling convention GNNExplainer uses for graphs.

    NOTE(review): assumes the installed torch_geometric's
    ``GNNExplainer.explain_graph`` calls ``model(x=..., edge_index=..., batch=...)``
    and by default expects log-probabilities (``return_type='log_prob'``).
    Confirm against the installed version.

    ``GNN.forward`` takes ``(x, edge_attr, edge_index, batch_index)`` and emits a
    single value per graph. Here that value ``z`` becomes two-class logits
    ``[0, z]``, whose softmax is ``[1 - sigmoid(z), sigmoid(z)]``. The explainer
    therefore has a predicted class to target.
    """

    def __init__(self, gnn):
        super().__init__()
        self.gnn = gnn

    def forward(self, x, edge_index, batch=None, edge_attr=None):
        if batch is None:
            # A lone graph: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        logit = self.gnn(x, edge_attr, edge_index, batch)
        two_class = torch.cat([torch.zeros_like(logit), logit], dim=-1)
        return F.log_softmax(two_class, dim=-1)


# The previous version of this cell called node-level explain_node() on every
# training batch, discarded the results, and raised an AssertionError.
# NOTE(review): the assertion presumably comes from the learned edge mask being
# sized for the input graph, while TopKPooling drops edges before conv2/conv3.
# allow_edge_mask=False therefore learns only a node-feature mask. Verify this
# against the installed torch_geometric.
EXPLAIN_IDX = 0  # which test molecule to explain
molecule = test_dataset[EXPLAIN_IDX]
explainer = GNNExplainer(GraphExplainerAdapter(model), epochs=200, allow_edge_mask=False)
node_feat_mask, edge_mask = explainer.explain_graph(molecule.x, molecule.edge_index)
node_feat_mask


  0%|          | 0/200 [00:00<?, ?it/s][A
Explain node 1:   0%|          | 0/200 [00:00<?, ?it/s][A

AssertionError: ignored

In [None]:
pip install graphlime

Collecting graphlime
  Downloading graphlime-1.2.0.tar.gz (3.3 kB)
Building wheels for collected packages: graphlime
  Building wheel for graphlime (setup.py) ... [?25l[?25hdone
  Created wheel for graphlime: filename=graphlime-1.2.0-py3-none-any.whl size=2617 sha256=a81cad2e73fc13db93ed123dabfabe3008394c3afa2d6bb9d05178cdc5b59232
  Stored in directory: /root/.cache/pip/wheels/33/29/94/9835c557e2def18b58369cda0032935a3263acfa9266aaeb5d
Successfully built graphlime
Installing collected packages: graphlime
Successfully installed graphlime-1.2.0
