In [31]:
import sys
sys.path.insert(0, '/home/oscarwu/code/chemprop_developing')

import pickle as pkl
import numpy as np
import pandas as pd

In [32]:
# load an example
df_smi = pd.read_csv('esol.csv')
df_smi.iloc[[7]]

Unnamed: 0,smiles,logSolubility
7,CC#N,0.26


In [33]:
smiles = df_smi.iloc[[7]]['smiles'].values.tolist()

In [34]:
smiles

['CC#N']

In [35]:
from chemprop.features.utils import load_features, load_valid_atom_or_bond_features
from chemprop.models.mpn import mask_features_extra_batch, mask_features_extra

In [36]:
bond_features = [load_valid_atom_or_bond_features('esol_wb97xd_bond_RBF_features.pkl', smiles)[7]]

In [37]:
atom_descriptors = [load_valid_atom_or_bond_features('esol_wb97xd_atom_RBF_features.pkl', smiles)[7]]

In [38]:
mol_features = [load_features("esol_wb97xd_molecule_features.csv")[7].reshape(1, -1)]

In [39]:
from chemprop.args import TrainArgs
from chemprop.models.model import MoleculeModel

import torch
from chemprop.utils import load_checkpoint, load_scalers
from chemprop.features import set_extra_bond_fdim

path = "/home/oscarwu/code/chemprop_developing/developing/shap_v1/data/model.pt"
state = torch.load(path, map_location=lambda storage, loc: storage)
train_args = TrainArgs()
train_args.from_dict(vars(state["args"]), skip_unsettable=True)
set_extra_bond_fdim(train_args.bond_features_size)
model = load_checkpoint(path)
scalers = load_scalers(path)
target_scaler, mol_feature_scaler, _, _, _ = scalers

Loading pretrained parameter "encoder.encoder.0.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.0.W_i.weight".
Loading pretrained parameter "encoder.encoder.0.W_h.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.bias".
Loading pretrained parameter "encoder.encoder.0.atom_descriptors_layer.weight".
Loading pretrained parameter "encoder.encoder.0.atom_descriptors_layer.bias".
Loading pretrained parameter "readout.1.weight".
Loading pretrained parameter "readout.1.bias".
Loading pretrained parameter "readout.4.weight".
Loading pretrained parameter "readout.4.bias".
Loading pretrained parameter "readout.7.weight".
Loading pretrained parameter "readout.7.bias".


In [115]:
# Example SHAP Analysis
# Imports
from chemprop.args import TrainArgs
from chemprop.models.model import MoleculeModel

import torch
from chemprop.utils import load_checkpoint, load_scalers
from chemprop.features import set_extra_bond_fdim

from rdkit import Chem
from chemprop.rdkit import make_mol

# Load TrainArgs and MoleculeModel
path = "/home/oscarwu/code/chemprop_developing/developing/shap_v1/data/model.pt"
state = torch.load(path, map_location=lambda storage, loc: storage)
train_args = TrainArgs()
train_args.from_dict(vars(state["args"]), skip_unsettable=True)
set_extra_bond_fdim(train_args.bond_features_size)
model = load_checkpoint(path)
scalers = load_scalers(path)

# Create a MoleculeModel Wrapper
class MoleculeModelWrapper:
    def __init__(self, model, train_args, scalers, smiles, shap, features_batch, 
                 atom_descriptors_batch, atom_features_batch, bond_descriptors_batch, bond_features_batch):
        self.model = model
        self.train_args = train_args
        self.target_scaler, self.mol_feature_scaler, _, _, _ = scalers
        self.smiles = smiles
        self.shap = shap
        if features_batch is not None:
            self.features_batch = [self.mol_feature_scaler.transform(feat) for feat in features_batch]
        else:
            self.features_batch = None      
        self.atom_descriptors_batch = atom_descriptors_batch
        self.atom_features_batch = atom_features_batch
        self.bond_descriptors_batch = bond_descriptors_batch
        self.bond_features_batch = bond_features_batch
        
        self.batch = None
        self.extra_keep_features_batch = None
        self.extra_atom_keep_descriptors_batch = None
        self.extra_bond_keep_descriptors_batch = None
        self.extra_atom_keep_features_batch = None
        self.extra_bond_keep_features_batch = None
        self.chemprop_atom_keep_features = None
        self.chemprop_bond_keep_features = None
        

    def __call__(self, feature_choices):
        if isinstance(feature_choices, np.ndarray):
            if len(feature_choices.shape) == 1:
                feature_choices = feature_choices.reshape(1, -1)
                
        result = []
        
        for feature_choice in feature_choices:
        
            self.batch = [[make_mol(s=smi, keep_h=True, add_h=True, keep_atom_map=True)] for smi in self.smiles]
            self.extra_keep_features_batch = [feature_choice[0:20]] # 20 extra molecular features
            self.extra_atom_keep_descriptors_batch = [feature_choice[20:33]] # 13 atom descriptors
            self.extra_bond_keep_descriptors_batch = None # 0 bond descriptors
            self.extra_atom_keep_features_batch = None # 0 atom features
            self.extra_bond_keep_features_batch = [feature_choice[33:37]] # 4 bond features
            self.chemprop_atom_keep_features = feature_choice[37:45] # 8 chemprop atom features
            self.chemprop_bond_keep_features = feature_choice[45:49] # 4 chemprop bond features
            
            # this is from forward in original MoleculeModel in model.py
            

            output = self.model(batch=self.batch, 
                                features_batch=self.features_batch, 
                                atom_descriptors_batch=self.atom_descriptors_batch, 
                                atom_features_batch=self.atom_features_batch, 
                                bond_descriptors_batch=self.bond_descriptors_batch, 
                                bond_features_batch=self.bond_features_batch, 
                                constraints_batch=None,
                                bond_types_batch=None,
                                shap=self.shap, 
                                extra_keep_features_batch=self.extra_keep_features_batch, 
                                extra_atom_keep_descriptors_batch=self.extra_atom_keep_descriptors_batch, 
                                extra_bond_keep_descriptors_batch=self.extra_bond_keep_descriptors_batch, 
                                extra_atom_keep_features_batch=self.extra_atom_keep_features_batch, 
                                extra_bond_keep_features_batch=self.extra_bond_keep_features_batch, 
                                chemprop_atom_keep_features=self.chemprop_atom_keep_features, 
                                chemprop_bond_keep_features=self.chemprop_bond_keep_features)
        
            xform = self.target_scaler.inverse_transform(output.item()).item()
            xform = np.array(xform, ndmin=2)
            result.append(xform)
        
        return np.array(result).reshape(len(feature_choices), -1)
        

Loading pretrained parameter "encoder.encoder.0.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.0.W_i.weight".
Loading pretrained parameter "encoder.encoder.0.W_h.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.bias".
Loading pretrained parameter "encoder.encoder.0.atom_descriptors_layer.weight".
Loading pretrained parameter "encoder.encoder.0.atom_descriptors_layer.bias".
Loading pretrained parameter "readout.1.weight".
Loading pretrained parameter "readout.1.bias".
Loading pretrained parameter "readout.4.weight".
Loading pretrained parameter "readout.4.bias".
Loading pretrained parameter "readout.7.weight".
Loading pretrained parameter "readout.7.bias".


In [150]:
model_wrapper = MoleculeModelWrapper(model=model.eval(), 
                                     train_args=train_args, 
                                     scalers=scalers, 
                                     smiles=smiles, 
                                     shap=True, 
                                     features_batch=mol_features, 
                                     atom_descriptors_batch=atom_descriptors, 
                                     atom_features_batch=None, 
                                     bond_descriptors_batch=None, 
                                     bond_features_batch=bond_features)

In [151]:
feature_choice = np.array([[1]*49])

In [152]:
output = model_wrapper(feature_choice)

In [153]:
output

array([[0.12985924]])

In [154]:
feature_choice = np.array([[1]*49, [0]*49])

In [155]:
output = model_wrapper(feature_choice)

In [156]:
# why this one is not determinstic?
output

array([[ 0.12985924],
       [-2.13154999]])

In [157]:
# SHAP
from shap import PermutationExplainer
from copy import deepcopy

def binary_masker(binary_mask, x):
    masked_x = deepcopy(x)
    masked_x[binary_mask == 0] = 0
    return np.array([masked_x])

In [163]:
explainer = PermutationExplainer(model=model_wrapper, masker=binary_masker)

In [164]:
feature_choice = np.array([[1]*49])

In [165]:
# feature_choice = np.array([[1]*49, [0]*49])

In [166]:
explanation = explainer(feature_choice, max_evals=100)

In [167]:
explanation

.values =
array([[-8.73498138e-02, -2.82137423e-02,  5.66453497e-02,
        -6.20505195e-02,  2.94889429e-02,  4.41313191e-02,
         1.10207176e-01,  9.38828885e-02,  2.98273422e-02,
         2.09957323e-02,  9.64797584e-02,  1.65479839e-01,
         1.28751694e-01,  1.08275224e-01,  1.07254705e-01,
         9.16567172e-02,  1.76093135e-01,  7.43381462e-02,
         1.03414439e-01, -4.23458279e-02,  3.02746431e-02,
         5.32220653e-02,  2.61540791e-02,  2.88479889e-01,
         2.15035788e-01,  2.55079709e-02,  1.02111332e-02,
         1.05198507e-01, -8.75544065e-05,  6.48916754e-02,
         2.85638981e-02,  3.66373425e-03,  2.90563684e-02,
        -1.34587706e-01,  1.11285239e-01,  7.54855111e-03,
         2.96121997e-02,  6.76332731e-02,  5.99750160e-02,
         1.83033260e-02,  3.80254047e-02, -2.40196078e-03,
         5.47026319e-03,  0.00000000e+00, -1.05436543e-03,
        -3.03397985e-02,  0.00000000e+00,  0.00000000e+00,
        -5.19491541e-03]])

.base_values =
arr