In [1]:
import sys
# NOTE(review): hard-coded, machine-specific checkout path. It is put first on sys.path,
# presumably so the `chemprop` imports in later cells pick up this development copy —
# adjust it (or install the package) when running on another machine.
sys.path.insert(0, '/home/oscarwu/code/chemprop_developing')

import pickle as pkl
import numpy as np
import pandas as pd

In [2]:
# Load the ESOL table and display the row at positional index 7; the same
# position is used in the following cells to pick the example molecule.
df_smi = pd.read_csv('esol.csv')
df_smi.iloc[[7]]

Unnamed: 0,smiles,logSolubility
7,CC#N,0.26


In [3]:
# SMILES of the example row (position 7), kept as a one-element Python list.
smiles = df_smi['smiles'].iloc[[7]].tolist()

In [4]:
# Display the selected SMILES list as the cell output.
smiles

['CC#N']

In [5]:
from chemprop.features.utils import load_features, load_valid_atom_or_bond_features
from chemprop.models.mpn import mask_features_extra_batch, mask_features_extra

In [6]:
# Load the extra bond features from the pickle and keep only entry 7, wrapped in a
# list so it can be passed to the model as a batch of one.
# NOTE(review): entry 7 is assumed to correspond to row 7 of esol.csv — confirm both files share ordering.
bond_features = [load_valid_atom_or_bond_features('esol_wb97xd_bond_RBF_features.pkl', smiles)[7]]

In [7]:
# Load the atom descriptors from the pickle and keep only entry 7, wrapped in a
# list so it can be passed to the model as a batch of one.
# NOTE(review): entry 7 is assumed to correspond to row 7 of esol.csv — confirm both files share ordering.
atom_descriptors = [load_valid_atom_or_bond_features('esol_wb97xd_atom_RBF_features.pkl', smiles)[7]]

In [8]:
# Load the molecule-level features from the CSV, take entry 7 and reshape it to a
# single row (1, n_features); the one-element list is the batch handed to the wrapper,
# which applies the feature scaler's transform to each element.
# NOTE(review): entry 7 is assumed to correspond to row 7 of esol.csv — confirm both files share ordering.
mol_features = [load_features("esol_wb97xd_molecule_features.csv")[7].reshape(1, -1)]

In [9]:
# Example SHAP Analysis
# Imports
from chemprop.args import TrainArgs
from chemprop.models.model import MoleculeModel

import torch
from chemprop.utils import load_checkpoint, load_scalers
from chemprop.features import set_extra_bond_fdim

from rdkit import Chem
from chemprop.rdkit import make_mol

# Load TrainArgs and MoleculeModel
# NOTE(review): hard-coded absolute checkpoint path; torch.load unpickles the file, so only
# point this at checkpoints from a trusted source.
path = "/home/oscarwu/code/chemprop_developing/developing/shap_v1/data/model.pt"
# The map_location callable returns each storage unchanged, i.e. no remapping to another device.
state = torch.load(path, map_location=lambda storage, loc: storage)
# Rebuild the training arguments from the "args" entry stored in the checkpoint.
train_args = TrainArgs()
train_args.from_dict(vars(state["args"]), skip_unsettable=True)
# Register the extra bond feature size taken from the restored arguments.
# NOTE(review): presumably this must happen before the model is used so bond feature
# dimensions match the checkpoint — confirm against chemprop.features.
set_extra_bond_fdim(train_args.bond_features_size)
# Model weights and the scalers (unpacked as a 5-tuple by MoleculeModelWrapper) come from the same file.
model = load_checkpoint(path)
scalers = load_scalers(path)

# Create a MoleculeModel Wrapper
class MoleculeModelWrapper:
    """Callable adapter that lets a SHAP explainer evaluate the model on feature keep-masks.

    Each call receives one or more flat masks, splits every mask into the
    per-group masks that the model's ``shap`` code path takes as keyword
    arguments, runs the model on the fixed molecule(s) given at construction
    time and returns the predictions after undoing the target scaling.

    The mask layout is defined by the class-level ``*_SLICE`` constants:
    20 extra molecular features, 13 atom descriptors, 4 extra bond features,
    8 chemprop atom features and 4 chemprop bond features (49 entries in total).

    NOTE(review): only the target scaler and the molecule-feature scaler are
    used; the remaining three scalers are discarded, so atom descriptors and
    bond features reach the model exactly as supplied — confirm that this
    matches how the checkpoint was trained.
    """

    # Layout of one flat keep-mask; adjust here when the feature sets change.
    MOL_FEATURES_SLICE = slice(0, 20)        # 20 extra molecular features
    ATOM_DESCRIPTORS_SLICE = slice(20, 33)   # 13 atom descriptors
    BOND_FEATURES_SLICE = slice(33, 37)      # 4 bond features
    CHEMPROP_ATOM_SLICE = slice(37, 45)      # 8 chemprop atom features
    CHEMPROP_BOND_SLICE = slice(45, 49)      # 4 chemprop bond features
    N_MASK_ENTRIES = 49

    def __init__(self, model, train_args, scalers, smiles, shap, features_batch, 
                 atom_descriptors_batch, atom_features_batch, bond_descriptors_batch, bond_features_batch):
        """Store the model, its scalers and the fixed inputs for the explained molecule(s).

        Args:
            model: Callable model; invoked with keyword arguments only.
            train_args: Training arguments restored from the checkpoint (stored, not used here).
            scalers: 5-tuple whose first two items are the target scaler and the
                molecule-feature scaler; the other three are ignored.
            smiles: List of SMILES strings to build the molecule batch from.
            shap: Flag forwarded to the model's ``shap`` argument.
            features_batch: List of molecule-level feature arrays, or ``None``.
                When a feature scaler is available each array is passed through
                its ``transform``; without a scaler the arrays are used as given.
            atom_descriptors_batch: Forwarded to the model unchanged.
            atom_features_batch: Forwarded to the model unchanged.
            bond_descriptors_batch: Forwarded to the model unchanged.
            bond_features_batch: Forwarded to the model unchanged.
        """
        self.model = model
        self.train_args = train_args
        self.target_scaler, self.mol_feature_scaler, _, _, _ = scalers
        self.smiles = smiles
        self.shap = shap
        if features_batch is None:
            self.features_batch = None
        elif self.mol_feature_scaler is None:
            # No feature scaler stored with the checkpoint: use the raw features
            # instead of failing on ``None.transform``.
            self.features_batch = list(features_batch)
        else:
            self.features_batch = [self.mol_feature_scaler.transform(feat) for feat in features_batch]
        self.atom_descriptors_batch = atom_descriptors_batch
        self.atom_features_batch = atom_features_batch
        self.bond_descriptors_batch = bond_descriptors_batch
        self.bond_features_batch = bond_features_batch

        # Per-call state; populated by __call__ and left inspectable afterwards.
        self.batch = None
        self.extra_keep_features_batch = None
        self.extra_atom_keep_descriptors_batch = None
        self.extra_bond_keep_descriptors_batch = None
        self.extra_atom_keep_features_batch = None
        self.extra_bond_keep_features_batch = None
        self.chemprop_atom_keep_features = None
        self.chemprop_bond_keep_features = None

    def __call__(self, feature_choices):
        """Evaluate the model once per keep-mask.

        Args:
            feature_choices: Sequence of masks, each with ``N_MASK_ENTRIES``
                entries, or a single 1-D ``numpy`` array holding one mask.

        Returns:
            ``numpy`` array of shape ``(n_masks, 1)`` with the inverse-transformed
            prediction for each mask; shape ``(0, 1)`` when no masks are given.

        Raises:
            ValueError: If a mask does not have exactly ``N_MASK_ENTRIES`` entries.
        """
        if isinstance(feature_choices, np.ndarray) and feature_choices.ndim == 1:
            feature_choices = feature_choices.reshape(1, -1)

        n_masks = len(feature_choices)
        if n_masks == 0:
            # reshape(0, -1) is ambiguous for an empty array and would raise.
            return np.empty((0, 1))

        # The molecules do not depend on the mask, so parse the SMILES once per call.
        self.batch = [[make_mol(s=smi, keep_h=True, add_h=True, keep_atom_map=True)] for smi in self.smiles]

        result = []

        # Pure inference: skip autograd bookkeeping for the repeated evaluations.
        with torch.no_grad():
            for feature_choice in feature_choices:
                if len(feature_choice) != self.N_MASK_ENTRIES:
                    raise ValueError(
                        f"Expected a mask with {self.N_MASK_ENTRIES} entries, got {len(feature_choice)}."
                    )

                self.extra_keep_features_batch = [feature_choice[self.MOL_FEATURES_SLICE]]
                self.extra_atom_keep_descriptors_batch = [feature_choice[self.ATOM_DESCRIPTORS_SLICE]]
                self.extra_bond_keep_descriptors_batch = None # 0 bond descriptors
                self.extra_atom_keep_features_batch = None # 0 atom features
                self.extra_bond_keep_features_batch = [feature_choice[self.BOND_FEATURES_SLICE]]
                self.chemprop_atom_keep_features = feature_choice[self.CHEMPROP_ATOM_SLICE]
                self.chemprop_bond_keep_features = feature_choice[self.CHEMPROP_BOND_SLICE]

                # this is from forward in original MoleculeModel in model.py
                output = self.model(batch=self.batch, 
                                    features_batch=self.features_batch, 
                                    atom_descriptors_batch=self.atom_descriptors_batch, 
                                    atom_features_batch=self.atom_features_batch, 
                                    bond_descriptors_batch=self.bond_descriptors_batch, 
                                    bond_features_batch=self.bond_features_batch, 
                                    constraints_batch=None,
                                    bond_types_batch=None,
                                    shap=self.shap, 
                                    extra_keep_features_batch=self.extra_keep_features_batch, 
                                    extra_atom_keep_descriptors_batch=self.extra_atom_keep_descriptors_batch, 
                                    extra_bond_keep_descriptors_batch=self.extra_bond_keep_descriptors_batch, 
                                    extra_atom_keep_features_batch=self.extra_atom_keep_features_batch, 
                                    extra_bond_keep_features_batch=self.extra_bond_keep_features_batch, 
                                    chemprop_atom_keep_features=self.chemprop_atom_keep_features, 
                                    chemprop_bond_keep_features=self.chemprop_bond_keep_features)

                # ``item()`` requires a single-element output, i.e. one molecule and one target.
                xform = self.target_scaler.inverse_transform(output.item()).item()
                result.append(np.array(xform, ndmin=2))

        return np.array(result).reshape(n_masks, -1)
        

Loading pretrained parameter "encoder.encoder.0.cached_zero_vector".
Loading pretrained parameter "encoder.encoder.0.W_i.weight".
Loading pretrained parameter "encoder.encoder.0.W_h.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.weight".
Loading pretrained parameter "encoder.encoder.0.W_o.bias".
Loading pretrained parameter "encoder.encoder.0.atom_descriptors_layer.weight".
Loading pretrained parameter "encoder.encoder.0.atom_descriptors_layer.bias".
Loading pretrained parameter "readout.1.weight".
Loading pretrained parameter "readout.1.bias".
Loading pretrained parameter "readout.4.weight".
Loading pretrained parameter "readout.4.bias".
Loading pretrained parameter "readout.7.weight".
Loading pretrained parameter "readout.7.bias".


In [10]:
# Switch the network to evaluation mode up front; eval() hands back the module itself,
# so passing `model` below is the same object the inline call would have produced.
model.eval()
model_wrapper = MoleculeModelWrapper(
    model=model,
    train_args=train_args,
    scalers=scalers,
    smiles=smiles,
    shap=True,
    # Inputs supplied for the example molecule.
    features_batch=mol_features,
    atom_descriptors_batch=atom_descriptors,
    bond_features_batch=bond_features,
    # No extra atom features or bond descriptors are supplied.
    atom_features_batch=None,
    bond_descriptors_batch=None,
)

In [11]:
# SHAP
from shap import PermutationExplainer
from copy import deepcopy

def binary_masker(binary_mask, x):
    """Zero out the entries of ``x`` that the mask drops and return them as a batch of one.

    Args:
        binary_mask: Array expected to match ``x`` in shape; positions equal
            to 0 (or ``False``) are the ones to suppress.
        x: Array holding the sample being explained; it is left unmodified.

    Returns:
        A new array with a leading axis of length 1 containing the masked copy of ``x``.
    """
    suppressed = binary_mask == 0
    sample = deepcopy(x)  # work on a copy so the caller's array stays intact
    sample[suppressed] = 0
    return np.array([sample])

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [12]:
# Build a permutation-based SHAP explainer that queries model_wrapper and uses
# binary_masker to turn SHAP's masks into the wrapper's keep-mask input.
# NOTE(review): no seed is passed, so attributions may differ between runs —
# confirm whether a fixed seed is wanted for reproducibility.
explainer = PermutationExplainer(model=model_wrapper, masker=binary_masker)

In [13]:
# One sample with all 49 mask entries set to 1, i.e. every feature kept.
feature_choice = np.ones((1, 49), dtype=int)

In [14]:
# Explain the all-ones sample with a budget of 100 model evaluations.
# NOTE(review): SHAP's permutation explainer is documented to need at least
# 2 * n_features + 1 evaluations (99 for 49 features), so 100 presumably allows
# only about one permutation — confirm this budget is sufficient.
explanation = explainer(feature_choice, max_evals=100)

In [15]:
# Display the resulting explanation object (its .values and .base_values are shown in the output).
explanation

.values =
array([[-1.43640646e-01, -6.84858612e-02,  2.85901953e-02,
         1.96509833e-02,  8.47005351e-02,  9.23541205e-02,
         5.97737337e-02,  7.33205969e-02,  9.52663719e-02,
         1.73071858e-02,  1.01696825e-01,  1.56419227e-01,
         7.58900558e-02,  1.34595440e-01,  1.37718379e-01,
         1.20293876e-01,  2.19930482e-02,  1.14377117e-01,
         1.43023928e-01, -6.24533316e-02,  4.34926978e-02,
         5.31633450e-02,  2.57224946e-02,  2.97803444e-01,
         2.05844927e-01,  2.52376350e-02,  4.09241053e-03,
         1.12925137e-01,  2.10068700e-04,  6.32235629e-02,
         2.59969761e-02,  8.41512317e-03,  2.53351515e-02,
        -1.21231080e-01,  1.36803110e-01, -7.95359079e-03,
         4.99466642e-02,  6.72860874e-02,  6.18939734e-02,
         2.60776004e-02,  5.55812698e-03,  1.10736833e-02,
        -7.45870730e-03,  0.00000000e+00, -8.50917455e-04,
        -6.96938026e-02,  0.00000000e+00,  0.00000000e+00,
         1.61033282e-02]])

.base_values =
arr