In [38]:
import torch
import sys

In [39]:
# Environment setup. On Google Colab: mount Drive, make the project importable
# and install RDKit / PyTorch Geometric on demand. Elsewhere: assume both are
# already installed and that the notebook runs from the project root.
if 'google.colab' in str(get_ipython()):
    # Drive: mount it so the project code and data stored there are reachable.
    from google.colab import drive
    drive.mount('/content/gdrive')
    # Relative path: presumably relies on the Colab working directory being
    # /content (where the Drive was mounted) — confirm if the CWD is changed.
    sys.path.append('gdrive/MyDrive/mldd')

    # Project dir: base path for the data and model files used later on.
    project_dir = 'gdrive/MyDrive/mldd'

    # RDKit: the conda install below targets /usr/local with python=3.7, so
    # that site-packages directory is added to the import path up front.
    sys.path.append('/usr/local/lib/python3.7/site-packages/')
    try:
        from rdkit import Chem
        from rdkit.Chem.Draw import IPythonConsole
    except ImportError:
        # RDKit missing: install Miniconda into /usr/local, install RDKit from
        # conda-forge, then stop so that the next run of this cell sees it.
        !add-apt-repository ppa:ubuntu-toolchain-r/test
        !apt-get update --fix-missing
        !apt-get dist-upgrade
        !wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
        !chmod +x Miniconda3-latest-Linux-x86_64.sh
        !./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
        !conda config --set always_yes yes --set changeps1 no
        !conda install -q -y -c conda-forge python=3.7 rdkit
        
        print('Stopping RUNTIME. Colaboratory will restart automatically. Please run cell again.')
        exit()

    # Torch Geometric: install only if it is not importable yet.
    try:
        import torch_geometric
    except ImportError:
        # Install PyTorch Geometric and its compiled extensions from the wheel
        # index matching the installed torch and CUDA versions.

        def format_pytorch_version(version):
            # Drop the local build suffix, e.g. '1.10.0+cu111' -> '1.10.0'.
            return version.split('+')[0]

        TORCH_version = torch.__version__
        TORCH = format_pytorch_version(TORCH_version)

        def format_cuda_version(version):
            # e.g. '11.1' -> 'cu111', the tag used in the wheel index URL.
            return 'cu' + version.replace('.', '')

        # NOTE(review): torch.version.cuda is None on CPU-only torch builds, in
        # which case format_cuda_version fails — presumably a GPU runtime is
        # expected here; confirm.
        CUDA_version = torch.version.cuda
        CUDA = format_cuda_version(CUDA_version)

        !pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
        !pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
        !pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
        !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
        !pip install torch-geometric 

else:
    # Local run: RDKit must already be installed; the CWD is the project root.
    from rdkit import Chem
    from rdkit.Chem.Draw import IPythonConsole
    sys.path.append('.')
    project_dir = '.'

In [40]:
# Re-import edited modules (e.g. the project's `mldd` package) before each cell
# runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [41]:
from mldd.data import Featurizer
import numpy as np

def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``.

    Returns:
        A list of booleans, one per element of ``allowable_set``, that is True
        exactly where the element equals ``x``.

    Raises:
        ValueError: if ``x`` is not contained in ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise ValueError("input {0} not in allowable set{1}:".format(
        x, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x``, mapping values outside ``allowable_set`` to its last element.

    The last entry of ``allowable_set`` therefore acts as the catch-all bucket.

    Returns:
        A list of booleans, one per element of ``allowable_set``.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [key == candidate for candidate in allowable_set]


class GraphFeaturizer(Featurizer):
    """Turn a dataframe of SMILES strings into molecular graphs plus targets.

    Each molecule becomes ``(nodes, edge_index)``:
      * ``nodes`` has one row of atom features per atom;
      * ``edge_index`` is a ``(2, num_directed_edges)`` integer array in which
        every bond appears in both directions.
    """

    def __call__(self, df):
        """Featurize every row of ``df``.

        Args:
            df: dataframe with a ``smiles`` column and the target column named
                by ``self.y_column`` (presumably set by the ``Featurizer`` base
                class from its ``y_column`` argument — confirm).

        Returns:
            ``(graphs, labels)``: a list of ``(nodes, edge_index)`` tuples from
            :meth:`process_mol` and a numpy array of the targets.

        Raises:
            ValueError: if RDKit cannot parse a SMILES string.
        """
        graphs = []
        labels = []
        for _, row in df.iterrows():
            y = row[self.y_column]
            mol = Chem.MolFromSmiles(row.smiles)
            if mol is None:
                # RDKit signals a parse failure by returning None, which would
                # otherwise surface later as an opaque AttributeError.
                raise ValueError(f'RDKit could not parse SMILES: {row.smiles!r}')
            graphs.append(self.process_mol(mol))
            labels.append(y)
        return graphs, np.array(labels)

    def process_mol(self, mol):
        """Build ``(nodes, edge_index)`` for one RDKit molecule.

        Returns:
            ``nodes``: array with one feature row per atom (see
            :meth:`_atom_features`); ``edge_index``: ``(2, 2 * num_bonds)``
            integer array of bond endpoints.
        """
        edges = []
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            # Add both directions so the resulting graph is symmetric.
            edges.append([begin, end])
            edges.append([end, begin])
        # reshape(-1, 2) keeps a well-formed (0, 2) array for molecules without
        # bonds (e.g. a single atom); a bare np.array([]) would be 1-D and its
        # transpose would not be a valid (2, E) edge index.
        edges = np.array(edges, dtype=np.int64).reshape(-1, 2)

        nodes = np.array([self._atom_features(atom) for atom in mol.GetAtoms()])

        return (nodes, edges.T)

    @staticmethod
    def _atom_features(atom):
        """Return the flat feature list of one RDKit atom.

        The order and length of these features define the model's input
        width, so they must not change for an already-trained model.
        """
        hybridizations = [
            Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2
        ]
        return (
            # Element symbol; anything unlisted falls into 'Unknown'.
            one_of_k_encoding_unk(
                atom.GetSymbol(),
                ['Br', 'C', 'Cl', 'F', 'H', 'I', 'N', 'O', 'P', 'S', 'Unknown']
            )
            # Degree above 10 raises ValueError (strict encoding).
            + one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            # Values above 6 fall into the last bucket (6).
            + one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])
            + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
            # No dedicated 'unknown' slot: any other hybridization type is
            # encoded as the last entry, SP3D2.
            + one_of_k_encoding_unk(atom.GetHybridization(), hybridizations)
            + [atom.GetIsAromatic()]
            + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
        )

In [42]:
from torch_geometric.data import InMemoryDataset, Data

class GraphDataset(InMemoryDataset):
    """In-memory PyTorch Geometric dataset built from featurized molecules.

    Args:
        X: sequence of ``(nodes, edge_index)`` pairs as produced by
            ``GraphFeaturizer``.
        y: targets aligned with ``X``; each element becomes one graph's ``y``
            (callers in this notebook pass ``y.reshape(-1, 1)``).
        root: directory under which the raw graph list and the collated
            dataset are stored (both files are named ``data.pt``).
        transform, pre_transform: forwarded to ``InMemoryDataset``.
    """

    def __init__(self, X, y, root, transform=None, pre_transform=None):
        self.dataset = (X, y)
        super().__init__(root, transform, pre_transform)

        # The base class only builds the files when they are missing, but the
        # same root is reused for different data (train() always uses
        # 'esol-train'), so both steps are re-run here to overwrite any stale
        # cache left by a previous call.
        self.download()
        self.process()

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Name of the file written by :meth:`download`."""
        return ['data.pt']

    @property
    def processed_file_names(self):
        """Name of the file written by :meth:`process`."""
        return ['data.pt']

    def download(self):
        """Convert ``self.dataset`` into ``Data`` graphs and save them as the raw file.

        Nothing is fetched over the network; the ``download`` hook is reused to
        materialize the in-memory inputs on disk.
        """
        data = []

        for (nodes, edges), target in zip(*self.dataset):
            nodes = torch.tensor(nodes, dtype=torch.float)
            # reshape(2, -1) enforces the (2, num_edges) layout even for a
            # molecule without bonds, whose edge array may be empty and 1-D;
            # otherwise collating it with other graphs' edge_index fails.
            edges = torch.tensor(edges, dtype=torch.long).reshape(2, -1)
            target = torch.tensor(target, dtype=torch.float)

            graph = Data(x=nodes, edge_index=edges, y=target)
            graph.num_nodes = nodes.shape[0]
            data.append(graph)

        torch.save(data, self.raw_paths[0])

    def process(self):
        """Filter/transform the raw graphs, collate them and save the processed file."""
        data_list = torch.load(self.raw_paths[0])

        # __init__ does not accept a pre_filter, so this is the base-class
        # default unless set on the instance by other code.
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [43]:
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

from torch.nn import ReLU
from torch_geometric.nn import GCNConv, Linear, global_mean_pool

class GCNRegressor(torch.nn.Module):
    """Four GCN layers, mean pooling and a linear head producing one value per graph.

    Besides predicting, ``forward`` records what the interpretability code
    reads:
      * ``self.input``: the input node features, with gradients enabled;
      * ``self.final_conv_acts``: the output of the last GCN layer;
      * ``self.final_conv_grads``: the gradient w.r.t. those activations,
        stored by :meth:`activations_hook` during a backward pass.

    Args:
        input: number of node features.
        hidden: width of the three hidden GCN layers.
        output: width of the last GCN layer (input size of the linear head).
    """

    def __init__(self, input, hidden, output):
        super().__init__()

        # Filled in by forward() / the backward hook.
        self.input = None
        self.final_conv_acts = None
        self.final_conv_grads = None

        self.conv1 = GCNConv(input, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.conv3 = GCNConv(hidden, hidden)
        self.conv4 = GCNConv(hidden, output)
        self.linear = Linear(output, 1)
        self.relu = ReLU(inplace=True)

    def activations_hook(self, grad):
        """Backward hook: keep the gradient of the final GCN activations."""
        self.final_conv_grads = grad

    def forward(self, x, edge_index, batch):
        """Predict one value per graph in the batch.

        Args:
            x: node feature tensor (assumed float, ``(num_nodes, input)``).
            edge_index: graph connectivity in PyG ``(2, num_edges)`` format.
            batch: node-to-graph assignment used by ``global_mean_pool``.

        Returns:
            Tensor of shape ``(num_graphs, 1)``.
        """
        # Track gradients w.r.t. the input features (for saliency maps). Only
        # flip the flag when it is off: assigning requires_grad on a non-leaf
        # tensor raises, and a tensor that already requires grad needs nothing.
        if not x.requires_grad:
            x.requires_grad_(True)
        self.input = x

        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)
        x = self.relu(x)
        x = self.conv3(x, edge_index)
        x = self.relu(x)
        # enable_grad keeps the final activations in the autograd graph even
        # when forward() runs under torch.no_grad() (as predict() does).
        with torch.enable_grad():
            self.final_conv_acts = self.conv4(x, edge_index)
        # register_hook raises on a tensor that does not require grad (e.g.
        # when every parameter is frozen), so attach it only when it can fire.
        if self.final_conv_acts.requires_grad:
            self.final_conv_acts.register_hook(self.activations_hook)

        x = global_mean_pool(self.final_conv_acts, batch=batch)
        x = self.linear(x)

        return x

In [44]:
from mldd.data import load_esol

def train(X_train, y_train, X_valid, y_valid, hidden_size=512, epochs=50,
          batch_size=64, learning_rate=0.0001, root='esol-train'):
    """Train a GCNRegressor on featurized molecular graphs.

    Args:
        X_train: list of ``(nodes, edge_index)`` pairs from GraphFeaturizer.
        y_train: numpy array of targets aligned with ``X_train``.
        X_valid, y_valid: currently unused (no early stopping or model
            selection); kept so the cross-validation loop can pass them.
        hidden_size: width of the hidden GCN layers.
        epochs: number of passes over the training set.
        batch_size: graphs per mini-batch.
        learning_rate: Adam learning rate.
        root: directory in which GraphDataset caches the training graphs.

    Returns:
        The trained model (left in training mode).
    """
    # Imported here so the function does not depend on a later cell having run.
    from tqdm.notebook import tqdm, trange

    # Input width = feature count of the first atom of the first molecule.
    num_features = len(X_train[0][0][0])
    model = GCNRegressor(input=num_features, hidden=hidden_size, output=64)
    model.train()

    dataset = GraphDataset(X_train, y_train.reshape(-1, 1), root=root)
    loader = GraphDataLoader(dataset, batch_size=batch_size, shuffle=True)

    # lr must be passed explicitly; otherwise Adam uses its own default (1e-3).
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    loss_fn = torch.nn.MSELoss()
    for _ in trange(1, epochs + 1, leave=False):
        for data in tqdm(loader, leave=False):
            optimizer.zero_grad()
            preds = model(data.x, data.edge_index, data.batch)
            loss = loss_fn(preds, data.y.reshape(-1, 1))
            loss.backward()
            optimizer.step()
    return model


def predict(model, X_test, y_test, batch_size=64, root='esol-test'):
    """Run ``model`` over featurized graphs and return its predictions.

    Args:
        model: a trained GCNRegressor.
        X_test: list of ``(nodes, edge_index)`` pairs from GraphFeaturizer.
        y_test: numpy array of targets; only needed because GraphDataset
            stores a target on every graph.
        batch_size: evaluation batch size (affects speed only, not results).
        root: directory in which GraphDataset caches the graphs.

    Returns:
        numpy array of shape ``(len(X_test), 1)``, in the order of ``X_test``
        (the loader does not shuffle).
    """
    # Imported here so the function does not depend on a later cell having run.
    from tqdm.notebook import tqdm

    dataset = GraphDataset(X_test, y_test.reshape(-1, 1), root=root)
    loader = GraphDataLoader(dataset, batch_size=batch_size, shuffle=False)

    model.eval()
    preds_batches = []
    with torch.no_grad():
        for data in tqdm(loader):
            preds = model(data.x, data.edge_index, data.batch)
            preds_batches.append(preds.cpu().detach().numpy())
    return np.concatenate(preds_batches)

In [45]:
# Per-fold [RMSE, MAE, R2] rows, appended by the cross-validation loop.
scores = []
# ESOL dataframe plus the precomputed cross-validation split indices.
df, fold_indices = load_esol(split_path=f'{project_dir}/data/esol/split.npz')
# Graph featurizer whose regression target is the measured log-solubility column.
featurizer = GraphFeaturizer(y_column='measured log solubility in mols per litre')

In [46]:
import pandas as pd
import numpy as np
import torch
from tqdm.notebook import tqdm, trange

from mldd.metrics import mae, rmse, rocauc, r_squared
from mldd.data import *

# for train_data, valid_data, test_data in cross_validate(df, fold_indices, preprocessing_fn=featurizer):
#     X_train, y_train = train_data
#     X_valid, y_valid = valid_data
#     X_test, y_test = test_data
            
#     # training
#     modelG = train(X_train, y_train, X_valid, y_valid)
    
#     # # evaluation
#     predictions = predict(modelG, X_test, y_test)
    
#     rmse_score = rmse(y_test, predictions.flatten())
#     mae_score = mae(y_test, predictions.flatten())
#     r2_score = r_squared(y_test, predictions.flatten())
#     scores.append([rmse_score, mae_score, r2_score])
    
#     break  # can be removed to get results on all folds
# scores = np.array(scores)
# print('RMSE, MAE, R2 = ' + \
#       ', '.join(f'{mean:.2f}±{std:.3f}' for mean, std in zip(scores.mean(axis=0), scores.std(axis=0))))

In [47]:
# torch.save(modelG, project_dir + '/GCNRegressor.pth')

In [48]:
# Load the whole pickled model object saved by the commented-out cell above.
# torch.load unpickles arbitrary Python objects: only load trusted checkpoints.
# The notebook never moves tensors to a GPU, so map the weights to CPU; this
# also lets a checkpoint saved from a GPU session load on a CPU-only machine.
modelG = torch.load(project_dir + '/GCNRegressor.pth', map_location='cpu')

# Interpretability

In [49]:
from GNNInterpeter import GNNInterpreter
from ipywidgets import interact

mols = df.smiles.to_numpy()

@interact(mol=mols,
          method=['substitution', 'gradcam', 'saliency'],
          replace_atoms_with=['all', 'zero', 'Br', 'C', 'Cl', 'F', 'H', 'I', 'N', 'O', 'P', 'S'],
          replace_atom_alg=['number', 'atom'],
          calculate_atom_weight_alg=['signed', 'absolute'])
def interactive(mol=mols[13],
                method='substitution',
                replace_atoms_with='all',
                replace_atom_alg='number',
                calculate_atom_weight_alg='signed'):
    """Render the importance map of ``mol`` for the settings chosen in the widgets."""
    interpreter = GNNInterpreter(modelG, featurizer)

    def backward_on_target():
        # Backpropagate the squared error between the model's prediction and
        # the measured value for `mol`; the interpreter calls this itself
        # (presumably only the gradient-based methods need it).
        prediction = interpreter.get_original_pred(return_tensor=True)
        measured = df.loc[df.smiles == mol, 'measured log solubility in mols per litre'].iloc[0]
        target = torch.tensor(measured, dtype=torch.float32).reshape((1, 1))
        torch.nn.MSELoss()(prediction, target).backward()

    svg_results = interpreter.get_importance_map_svg(
        mol, method, replace_atoms_with,
        replace_atom_alg, calculate_atom_weight_alg, backward_on_target)
    return svg_results[0]

interactive(children=(Dropdown(description='mol', index=13, options=('OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2…

In [50]:
import svgutils.transform as sg
import os
import shutil

import subprocess

# Dataset rows whose importance maps are exported.
mol_indices = [10, 100, 400, 700, 800]
interpreter = GNNInterpreter(modelG, featurizer)


def b():
    """Backpropagate the squared error of the prediction for the current ``mol``.

    ``mol`` and ``interpreter`` are module-level names read at call time; the
    loop below rebinds ``mol`` before every ``get_importance_map_svg`` call.
    """
    out = interpreter.get_original_pred(return_tensor=True)
    y = df.loc[df.smiles == mol, 'measured log solubility in mols per litre'].iloc[0]
    y = torch.tensor(y, dtype=torch.float32).reshape((1, 1))
    torch.nn.MSELoss()(out, y).backward()


def _export_png(svg_path):
    """Best-effort conversion of ``svg_path`` to a PNG next to it using Inkscape."""
    # splitext swaps only the extension; str.replace('svg', 'png') would also
    # rewrite any 'svg' occurring elsewhere in the path.
    png_path = os.path.splitext(svg_path)[0] + '.png'
    try:
        # Argument list instead of a shell string: no quoting problems with the
        # spaces in method names.
        # NOTE(review): --export-png is the pre-1.0 Inkscape flag; newer
        # versions expect --export-filename — confirm the installed version.
        subprocess.run(['inkscape', svg_path, f'--export-png={png_path}'], check=False)
    except FileNotFoundError:
        print(f'inkscape not found; skipping PNG export of {svg_path}')


for mol_idx in mol_indices:
    mol = df.iloc[mol_idx].smiles

    # Start from an empty output directory for this molecule. A non-directory
    # already at `path` now raises here instead of failing later on save.
    path = f'./experiment/esol/{mol_idx}'
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)

    # Insertion order defines the numeric file-name prefix below. Other
    # variants (e.g. replace_atoms_with='zero' or 'C' with signed weights) can
    # be added the same way.
    figs = {
        'gradcam': interpreter.get_importance_map_svg(mol=mol, method='gradcam', backward_func=b)[1],
        'saliency': interpreter.get_importance_map_svg(mol=mol, method='saliency', backward_func=b)[1],
        'substitution all signed': interpreter.get_importance_map_svg(mol=mol, method='substitution')[1],
        'substitution all absolute': interpreter.get_importance_map_svg(
            mol=mol, method='substitution', calculate_atom_weight_alg='absolute')[1],
        'substitution zero absolute': interpreter.get_importance_map_svg(
            mol=mol, method='substitution', replace_atoms_with='zero',
            calculate_atom_weight_alg='absolute')[1],
    }

    compound_id = df.iloc[mol_idx]["Compound ID"]
    for method_idx, (method_name, fig) in enumerate(figs.items()):
        # Captions at y=470/490 sit near the bottom of the canvas, presumably
        # 500x500 as reported in the Inkscape export log.
        fig.append(sg.TextElement(0, 470, f'Method: {method_name}', size=20))
        fig.append(sg.TextElement(0, 490, f'Molecule: {compound_id}', size=20))
        im_path = f'{path}/{method_idx + 1}_{method_name}.svg'
        fig.save(im_path)
        _export_png(im_path)
        

Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 500 pixels (96 dpi)
Background RRGGBBAA: ffffff00
Area 0:0:500:500 exported to 500 x 