In [3]:
from typing import List, Tuple, Union
from itertools import zip_longest
import logging

from rdkit import Chem
import torch
import numpy as np

from chemprop.rdkit import make_mol

class Featurization_parameters:
    """
    Holds every molecule featurization setting as an instance attribute.

    A freshly constructed instance always carries the default values.
    """
    def __init__(self) -> None:
        hybridization = Chem.rdchem.HybridizationType

        # Permitted values for each one-hot encoded atom feature
        self.MAX_ATOMIC_NUM = 100
        self.ATOM_FEATURES = dict(
            atomic_num=[*range(self.MAX_ATOMIC_NUM)],
            degree=[*range(6)],
            formal_charge=[-1, -2, 1, 2, 0],
            chiral_tag=[*range(4)],
            num_Hs=[*range(5)],
            hybridization=[
                hybridization.SP,
                hybridization.SP2,
                hybridization.SP3,
                hybridization.SP3D,
                hybridization.SP3D2,
            ],
        )

        # Bins for path-length distances and for 3D distances
        self.PATH_DISTANCE_BINS = [*range(10)]
        self.THREE_D_DISTANCE_MAX = 20
        self.THREE_D_DISTANCE_STEP = 1
        self.THREE_D_DISTANCE_BINS = [*range(0, self.THREE_D_DISTANCE_MAX + 1, self.THREE_D_DISTANCE_STEP)]

        # Each feature contributes one slot per permitted value plus one slot for uncommon values;
        # the trailing 2 accounts for the IsAromatic flag and the mass entry
        value_slots = sum(map(len, self.ATOM_FEATURES.values()))
        uncommon_slots = len(self.ATOM_FEATURES)
        self.ATOM_FDIM = value_slots + uncommon_slots + 2
        self.EXTRA_ATOM_FDIM = 0
        self.BOND_FDIM = 14
        self.EXTRA_BOND_FDIM = 0

        # Input handling switches, all disabled by default
        self.REACTION_MODE = None
        self.EXPLICIT_H = False
        self.REACTION = False
        self.ADDING_H = False
        self.KEEP_ATOM_MAP = False

# Create a global parameter object for reference throughout this module
PARAMS = Featurization_parameters()


def reset_featurization_parameters(logger: logging.Logger = None) -> None:
    """
    Restores the default featurization settings by swapping in a new parameters instance.

    :param logger: Optional logger; its :code:`debug` method receives the notice. Without one, the notice is printed.
    """
    global PARAMS
    emit = print if logger is None else logger.debug
    emit('Setting molecule featurization parameters to default.')
    PARAMS = Featurization_parameters()


def get_atom_fdim(overwrite_default_atom: bool = False, is_reaction: bool = False) -> int:
    """
    Computes the length of the atom feature vector.

    :param overwrite_default_atom: Whether the default atom descriptors are replaced (their size is then not counted).
    :param is_reaction: Whether :code:`EXTRA_ATOM_FDIM` is counted when :code:`REACTION_MODE` is set.
    :return: The length of the atom feature vector.
    """
    default_dim = 0 if overwrite_default_atom else PARAMS.ATOM_FDIM

    extra_dim = PARAMS.EXTRA_ATOM_FDIM
    if PARAMS.REACTION_MODE:
        # In reaction mode the extra features only count for reaction inputs
        extra_dim = is_reaction * extra_dim

    return default_dim + extra_dim


def set_explicit_h(explicit_h: bool) -> None:
    """
    Stores the explicit-hydrogen flag on the module-level :code:`PARAMS` object.

    :param explicit_h: Boolean whether to keep explicit Hs from input.
    """
    # Read back through is_explicit_h(), which only reports the flag for non-molecule input
    PARAMS.EXPLICIT_H = explicit_h

def set_adding_hs(adding_hs: bool) -> None:
    """
    Stores the add-hydrogens flag on the module-level :code:`PARAMS` object.

    :param adding_hs: Boolean whether to add Hs to the molecule.
    """
    # Read back through is_adding_hs(), which only reports the flag for molecule input
    PARAMS.ADDING_H = adding_hs

def set_keeping_atom_map(keeping_atom_map: bool) -> None:
    """
    Stores the keep-atom-mapping flag on the module-level :code:`PARAMS` object.

    :param keeping_atom_map: Boolean whether to keep the original atom mapping.
    """
    # Read back through is_keeping_atom_map(); that function returns True for non-molecule input
    # regardless of this flag
    PARAMS.KEEP_ATOM_MAP = keeping_atom_map

def set_reaction(reaction: bool, mode: str) -> None:
    """
    Sets whether to use a reaction or molecule as input and adapts feature dimensions.

    When :code:`reaction` is False only the :code:`REACTION` flag is updated; the extra feature
    dimensions and the reaction mode keep whatever values they had before the call.

    :param reaction: Boolean whether to accept reactions as input.
    :param mode: Reaction mode to construct atom and bond feature vectors.

    """
    PARAMS.REACTION = reaction
    if reaction:
        # Extra atom features are the default atom features minus the atomic-number one-hot block
        # (MAX_ATOMIC_NUM choices plus its one slot for uncommon values)
        PARAMS.EXTRA_ATOM_FDIM = PARAMS.ATOM_FDIM - PARAMS.MAX_ATOMIC_NUM - 1
        # Extra bond features have the same size as the default bond features
        PARAMS.EXTRA_BOND_FDIM = PARAMS.BOND_FDIM
        PARAMS.REACTION_MODE = mode
        
def is_explicit_h(is_mol: bool = True) -> bool:
    r"""Returns the explicit-H setting for non-molecule (reaction) input; always False for molecules"""
    return False if is_mol else PARAMS.EXPLICIT_H


def is_adding_hs(is_mol: bool = True) -> bool:
    r"""Returns the add-Hs setting for molecule input; always False for non-molecule (reaction) input"""
    return PARAMS.ADDING_H if is_mol else False


def is_keeping_atom_map(is_mol: bool = True) -> bool:
    r"""Returns the keep-atom-map setting for molecule input; always True for non-molecule (reaction) input"""
    return PARAMS.KEEP_ATOM_MAP if is_mol else True


def is_reaction(is_mol: bool = True) -> bool:
    r"""Returns True only for non-molecule input while the reaction setting is enabled"""
    # Short-circuits for molecule input, so the reaction setting is only consulted otherwise
    return bool(not is_mol and PARAMS.REACTION)


def reaction_mode() -> str:
    r"""Returns the reaction mode stored on the module-level PARAMS object (None by default, until set_reaction stores one)"""
    return PARAMS.REACTION_MODE


def set_extra_atom_fdim(extra: int) -> None:
    """
    Change the dimensionality of the atom feature vector.

    :param extra: New value for :code:`EXTRA_ATOM_FDIM`, which get_atom_fdim adds to the default atom size.
    """
    PARAMS.EXTRA_ATOM_FDIM = extra


def get_bond_fdim(atom_messages: bool = False,
                  overwrite_default_bond: bool = False,
                  overwrite_default_atom: bool = False,
                  is_reaction: bool = False) -> int:
    """
    Computes the length of the bond feature vector.

    :param atom_messages: Whether atom messages are being used. With atom messages the vector
                          holds bond features only; otherwise the atom features are included too.
    :param overwrite_default_bond: Whether the default bond descriptors are replaced (their size is then not counted).
    :param overwrite_default_atom: Whether the default atom descriptors are replaced.
    :param is_reaction: Whether :code:`EXTRA_BOND_FDIM` is counted when :code:`REACTION_MODE` is set.
    :return: The length of the bond feature vector.
    """
    default_dim = 0 if overwrite_default_bond else PARAMS.BOND_FDIM

    extra_dim = PARAMS.EXTRA_BOND_FDIM
    if PARAMS.REACTION_MODE:
        # In reaction mode the extra features only count for reaction inputs
        extra_dim = is_reaction * extra_dim

    if atom_messages:
        atom_dim = 0
    else:
        atom_dim = get_atom_fdim(overwrite_default_atom=overwrite_default_atom, is_reaction=is_reaction)

    return default_dim + extra_dim + atom_dim


def set_extra_bond_fdim(extra: int) -> None:
    """
    Change the dimensionality of the bond feature vector.

    :param extra: New value for :code:`EXTRA_BOND_FDIM`, which get_bond_fdim adds to the default bond size.
    """
    PARAMS.EXTRA_BOND_FDIM = extra


def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
    """
    Builds a one-hot vector that reserves its last slot for values outside the known choices.

    :param value: The value to encode.
    :param choices: The known values, in slot order.
    :return: A list of length :code:`len(choices) + 1` with a single 1: at the position of
             :code:`value` within :code:`choices`, or in the final slot when it is not listed.
    """
    try:
        hot_slot = choices.index(value)
    except ValueError:
        # Unlisted value: fall back to the trailing slot
        hot_slot = -1

    encoding = [0] * (len(choices) + 1)
    encoding[hot_slot] = 1
    return encoding


def atom_features(atom: Chem.rdchem.Atom, functional_groups: List[int] = None) -> List[Union[bool, int, float]]:
    """
    Assembles the feature vector of a single atom.

    :param atom: An RDKit atom, or None to obtain an all-zero vector of the default atom size.
    :param functional_groups: A k-hot vector indicating the functional groups the atom belongs to.
                              It is appended only when an atom is given.
    :return: A list containing the atom features.
    """
    if atom is None:
        return [0] * PARAMS.ATOM_FDIM

    # (raw value, key of its allowed choices) for every one-hot encoded property, in output order
    categorical = (
        (atom.GetAtomicNum() - 1, 'atomic_num'),
        (atom.GetTotalDegree(), 'degree'),
        (atom.GetFormalCharge(), 'formal_charge'),
        (int(atom.GetChiralTag()), 'chiral_tag'),
        (int(atom.GetTotalNumHs()), 'num_Hs'),
        (int(atom.GetHybridization()), 'hybridization'),
    )

    features = []
    for raw_value, key in categorical:
        features.extend(onek_encoding_unk(raw_value, PARAMS.ATOM_FEATURES[key]))

    features.append(1 if atom.GetIsAromatic() else 0)
    features.append(atom.GetMass() * 0.01)  # scaled to about the same range as other features

    if functional_groups is not None:
        features.extend(functional_groups)
    return features

def atom_features_new(atom: Chem.rdchem.Atom, keep_features: List[bool] = None, functional_groups: List[int] = None) -> List[Union[bool, int, float]]:
    """
    Builds a feature vector for an atom, optionally zeroing out whole feature groups.

    The vector has the same layout as the one from :code:`atom_features`. The eight groups are, in
    order: atomic number, degree, formal charge, chiral tag, number of Hs, hybridization,
    aromaticity and mass. A group whose flag is falsy keeps its slots but is filled with zeros.

    :param atom: An RDKit atom, or None to obtain an all-zero vector of the default atom size.
    :param keep_features: Eight flags, one per feature group, indicating which groups to keep.
                          Defaults to keeping every group.
    :param functional_groups: A k-hot vector indicating the functional groups the atom belongs to.
                              It is appended unmasked, and only when an atom is given.
    :return: A list containing the atom features.
    """
    if atom is None:
        return [0] * PARAMS.ATOM_FDIM

    # A None sentinel replaces the former list default, which was one object shared by all calls
    if keep_features is None:
        keep_features = [True] * 8

    choices = PARAMS.ATOM_FEATURES
    # One entry per flag in keep_features, in flag order
    feature_groups = [
        onek_encoding_unk(atom.GetAtomicNum() - 1, choices['atomic_num']),
        onek_encoding_unk(atom.GetTotalDegree(), choices['degree']),
        onek_encoding_unk(atom.GetFormalCharge(), choices['formal_charge']),
        onek_encoding_unk(int(atom.GetChiralTag()), choices['chiral_tag']),
        onek_encoding_unk(int(atom.GetTotalNumHs()), choices['num_Hs']),
        onek_encoding_unk(int(atom.GetHybridization()), choices['hybridization']),
        [1 if atom.GetIsAromatic() else 0],
        [atom.GetMass() * 0.01],  # scaled to about the same range as other features
    ]

    features = []
    for flag_index, group in enumerate(feature_groups):
        # Indexing (not zip) so a too-short keep_features still raises instead of truncating the vector
        if keep_features[flag_index]:
            features.extend(group)
        else:
            features.extend([0] * len(group))

    if functional_groups is not None:
        features.extend(functional_groups)

    return features

def bond_features(bond: Chem.rdchem.Bond) -> List[Union[bool, int, float]]:
    """
    Assembles the feature vector of a single bond.

    :param bond: An RDKit bond, or None to obtain the "no bond" vector (leading 1, zeros elsewhere).
    :return: A list containing the bond features.
    """
    if bond is None:
        return [1] + [0] * (PARAMS.BOND_FDIM - 1)

    bond_type = bond.GetBondType()
    has_type = bond_type is not None
    kinds = Chem.rdchem.BondType

    fbond = [0]  # bond is not None
    fbond.extend(bond_type == kind for kind in (kinds.SINGLE, kinds.DOUBLE, kinds.TRIPLE, kinds.AROMATIC))
    fbond.append(bond.GetIsConjugated() if has_type else 0)
    fbond.append(bond.IsInRing() if has_type else 0)
    fbond.extend(onek_encoding_unk(int(bond.GetStereo()), list(range(6))))
    return fbond

def bond_features_new(bond: Chem.rdchem.Bond, keep_features: List[bool] = None) -> List[Union[bool, int, float]]:
    """
    Builds a feature vector for a bond, optionally zeroing out whole feature groups.

    The vector has the same layout as the one from :code:`bond_features`. The four groups are, in
    order: bond type (four slots), conjugation, ring membership and stereo (one-hot). A group whose
    flag is falsy keeps its slots but is filled with zeros. The leading "bond is None" slot is never
    masked.

    :param bond: An RDKit bond, or None to obtain the "no bond" vector (leading 1, zeros elsewhere).
    :param keep_features: Four flags, one per feature group, indicating which groups to keep.
                          Defaults to keeping every group.
    :return: A list containing the bond features.
    """
    if bond is None:
        return [1] + [0] * (PARAMS.BOND_FDIM - 1)

    # A None sentinel replaces the former list default, which was one object shared by all calls
    if keep_features is None:
        keep_features = [True] * 4

    bt = bond.GetBondType()
    # (index into keep_features, feature values) for each group, in output order
    feature_groups = [
        (0, [
            bt == Chem.rdchem.BondType.SINGLE,
            bt == Chem.rdchem.BondType.DOUBLE,
            bt == Chem.rdchem.BondType.TRIPLE,
            bt == Chem.rdchem.BondType.AROMATIC,
        ]),
        (1, [bond.GetIsConjugated() if bt is not None else 0]),
        (2, [bond.IsInRing() if bt is not None else 0]),
        (3, onek_encoding_unk(int(bond.GetStereo()), list(range(6)))),
    ]

    fbond = [0]  # bond is not None
    for flag_index, group in feature_groups:
        if keep_features[flag_index]:
            fbond.extend(group)
        else:
            fbond.extend([0] * len(group))

    return fbond

In [4]:
from pprint import pprint

In [5]:
# test bond feature function

In [6]:
smiles = 'CCO'
mol = Chem.MolFromSmiles(smiles)
bond = mol.GetBondWithIdx(0)

In [7]:
bond_feat_1 = bond_features(bond)

In [8]:
bond_feat_1

[0, True, False, False, False, False, False, 1, 0, 0, 0, 0, 0, 0]

In [9]:
bond_feat_2 = bond_features_new(bond)

In [10]:
bond_feat_2

[0, True, False, False, False, False, False, 1, 0, 0, 0, 0, 0, 0]

In [11]:
bond_feat_1 == bond_feat_2

True

In [12]:
# test atom feature function

In [13]:
smiles = 'CCO'
mol = Chem.MolFromSmiles(smiles)
atom = mol.GetAtomWithIdx(0)

In [14]:
new_features = atom_features_new(atom)

In [15]:
len(new_features)

133

In [16]:
new_features == features

NameError: name 'features' is not defined

In [96]:
new_features_reduced = atom_features_new(atom, keep_features=[1]*8)

In [97]:
new_features_reduced

[0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0.12011]

In [3]:
features = atom_features(atom)

In [7]:
len(features)

133

In [6]:
pprint(features)

[0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0.12011]


In [17]:
import pandas as pd

In [12]:
atom_features= pd.read_csv("atom_types_hc.csv").values[:, 1:]

In [13]:
int(atom_features.shape[0]-1)

9

In [15]:
[9] * 20

[9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]

In [21]:
bond_list = []
for i in range(5):
    for k in range(3):
        if i == k:
            bond_list.append(k)

In [22]:
bond_list

[0, 1, 2]

In [20]:
np.array(bond_list).reshape(1, -1)

array([[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]])

In [1]:
import numpy as np
from shap import links, PermutationExplainer
from copy import deepcopy

# Simple binary masker function
def binary_masker(binary_mask, x):
    """
    Zero the entries of a copy of ``x`` wherever ``binary_mask`` equals 0.

    The input is left untouched; the masked copy is returned wrapped in a leading axis of length 1.
    """
    hidden = binary_mask == 0
    visible_only = deepcopy(x)
    visible_only[hidden] = 0
    return np.array([visible_only])

# Dummy model for illustration purposes
class DummyModel:
    """Toy model for checking the explainer: 5 times the first column plus 3 times the second."""

    def __call__(self, data):
        # Slices (rather than integer indices) keep the column axis, so the result has one column
        first_column = data[:, :1]
        second_column = data[:, 1:2]
        return 5 * first_column + 3 * second_column

# Create the dummy model
model = DummyModel()

# Build a permutation explainer around the toy model, using the custom zero-out masker
explainer = PermutationExplainer(model=model, masker=binary_masker)

# Two rows of five features each; the toy model only reads the first two columns
example_input = np.array([[2, -5, 0, 1, 1], [1, -1, 0, 3, 4]])

# Explain both rows in one call (an earlier attempt wrapped the input in an extra axis)
# explanation = explainer(np.array([example_input]), max_evals=200)
explanation = explainer(example_input, max_evals=200)


print("SHAP values:", explanation.values)


IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


SHAP values: [[ 10. -15.   0.   0.   0.]
 [  5.  -3.   0.   0.   0.]]


In [7]:
DummyModel()(np.array([[1, -1, 0, 3, 4]))

IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

In [5]:
DummyModel()(np.array([[1, -1, 0, 3, 4]])).shape

(1, 1)

In [2]:
DummyModel()(np.array([[2, -5, 0, 1, 1], [1, -1, 0, 3, 4]]))

array([[-5],
       [ 2]])

In [6]:
DummyModel()(np.array([[2, -5, 0, 1, 1], [1, -1, 0, 3, 4]])).shape

(2, 1)

In [81]:
np.array([example_input])

array([[[ 2, -5,  0,  1,  1],
        [ 1, -1,  0,  3,  4]]])

In [63]:
np.array([example_input])

array([[ 2, -5,  0,  1,  1]])

In [69]:
example_input

array([[ 2, -5,  0,  1,  1],
       [ 1, -1,  0,  3,  4]])

In [72]:
example_input.shape

(2, 5)

In [74]:
np.array([2, -5, 0, 1, 1]).shape

(5,)

In [76]:
np.array([np.array([2, -5, 0, 1, 1])]).shape

(1, 5)

In [20]:
# test mask out extra features
def mask_atom_features_extra(atom_features_extra: np.ndarray, keep_features: List[bool],
                             feature_size: int = 50) -> np.ndarray:
    """
    Masks certain features in atom_features_extra based on the keep_features boolean vector.

    The columns are treated as consecutive chunks of :code:`feature_size` columns, one chunk per
    feature. Every chunk whose flag is falsy is set to zero in the returned copy; the input array
    is not modified. Masked entries are assigned zero rather than multiplied by zero, so NaN or
    infinite values in a masked chunk do not survive.

    :param atom_features_extra: A 2D numpy array with shape (num_atoms, feature_length) where
                                feature_length is a multiple of :code:`feature_size`.
    :param keep_features: A boolean vector indicating which features to keep, one flag per chunk.
    :param feature_size: Number of columns occupied by each feature. Defaults to 50.
    :return: A masked 2D numpy array with the same shape and dtype as atom_features_extra.
    :raises ValueError: If :code:`feature_size` is not positive, if feature_length is not a multiple
                        of :code:`feature_size`, or if the number of flags does not match the number
                        of features.
    """
    if feature_size <= 0:
        raise ValueError(f"feature_size must be a positive integer, got {feature_size}.")

    num_atoms, feature_length = atom_features_extra.shape
    if feature_length % feature_size != 0:
        # Without this check the trailing partial chunk would silently stay unmasked
        raise ValueError(f"atom_features_extra has {feature_length} columns, "
                         f"which is not a multiple of the feature size {feature_size}.")

    num_features = feature_length // feature_size
    if len(keep_features) != num_features:
        raise ValueError("Length of keep_features does not match the number of features in atom_features_extra.")

    # Expand the per-feature flags to one flag per column
    keep_columns = np.repeat(np.asarray(keep_features, dtype=bool), feature_size)

    masked_features = atom_features_extra.copy()
    masked_features[:, ~keep_columns] = 0

    return masked_features

In [47]:
# Dummy TrainArgs class for illustration
import sys
sys.path.insert(0, '/home/oscarwu/code/chemprop_developing')
from chemprop.args import TrainArgs

# Example SHAP Analysis with MPN

# Define the TrainArgs
args = TrainArgs()

# Create the MPN model
from chemprop.models.mpn import MPN
mpn_model = MPN(args)

# Example molecule (SMILES string)
smiles_list = ['CCO']

# Define dummy features batch, atom descriptors batch, and bond descriptors batch
# features_batch = [np.random.rand(1, 100) for _ in range(len(smiles_list))]  # Example feature length of 100
# atom_descriptors_batch = [np.random.rand(1, 100) for _ in range(len(smiles_list))]
# bond_descriptors_batch = [np.random.rand(1, 50) for _ in range(len(smiles_list))]

# Define keep_features batches
# extra_keep_features_batch = [[True] * 100 for _ in range(len(smiles_list))]  # Keep all features
# extra_atom_keep_descriptors_batch = [[True] * 100 for _ in range(len(smiles_list))]
# extra_bond_keep_descriptors_batch = [[True] * 50 for _ in range(len(smiles_list))]

# Create a SHAP explainer
def binary_masker(binary_mask, x):
    """
    Zero the entries of a copy of ``x`` wherever ``binary_mask`` equals 0.

    The input is left untouched; the masked copy is returned wrapped in a leading axis of length 1.
    """
    hidden = binary_mask == 0
    visible_only = deepcopy(x)
    visible_only[hidden] = 0
    return np.array([visible_only])

# Define a wrapper for the model to match SHAP's expected input format
class MPNWrapper:
    """Callable adapter that feeds an iterable of SMILES strings to the wrapped MPN model."""

    def __init__(self, model):
        self.model = model

    def __call__(self, data):
        # Convert data to the expected input format for the MPN model
        # NOTE(review): all molecules are nested inside ONE inner list. The recorded run with two
        # SMILES returned a (1, 300) array, so this layout appears to yield a single output row
        # instead of one row per input -- confirm the batch layout the MPN expects.
        batch = [[Chem.MolFromSmiles(smiles) for smiles in data]]
        # Placeholder molecule features: drawn fresh from the RNG on every call, so two calls with
        # the same input do not receive the same features
        features_batch = [np.random.rand(1, 100) for _ in data]
        # Keep-mask with every one of the 100 placeholder features enabled
        extra_keep_features_batch = [[True] * 100 for _ in data]
        output = self.model(batch, features_batch=features_batch, shap=True,
                    extra_keep_features_batch=extra_keep_features_batch,)
        # output = self.model(batch, features_batch=features_batch, shap=True,
        #                     extra_keep_features_batch=extra_keep_features_batch,
        #                     extra_atom_keep_descriptors_batch=extra_atom_keep_descriptors_batch,
        #                     extra_bond_keep_descriptors_batch=extra_bond_keep_descriptors_batch)
        # Detach from the autograd graph and hand back a plain numpy array
        return output.detach().cpu().numpy()

# Create the MPN wrapper
model_wrapper = MPNWrapper(mpn_model)

# Create the PermutationExplainer
explainer = PermutationExplainer(model=model_wrapper, masker=binary_masker)

# Generate example input to explain
example_input = np.array(smiles_list)

# Explain the example input
explanation = explainer(example_input, max_evals=200)

print("SHAP values:", explanation.values)


TypeError: 'numpy.float64' object cannot be interpreted as an integer

In [48]:
smiles_list

['CCO']

In [49]:
model_wrapper(smiles_list)

array([[0.02869339, 0.        , 0.        , 0.08539114, 0.01695965,
        0.17654537, 0.08226229, 0.01316568, 0.01456605, 0.04406174,
        0.01382583, 0.07931754, 0.        , 0.00845726, 0.03411956,
        0.        , 0.        , 0.06085997, 0.04044246, 0.14077239,
        0.        , 0.        , 0.        , 0.05221221, 0.03106297,
        0.0462199 , 0.01961821, 0.06123255, 0.03978665, 0.00291387,
        0.        , 0.0524957 , 0.11667792, 0.        , 0.        ,
        0.05406984, 0.14364138, 0.04035055, 0.00610798, 0.        ,
        0.05862929, 0.        , 0.07890157, 0.02106198, 0.03567052,
        0.0988366 , 0.136213  , 0.0193952 , 0.0699313 , 0.04332287,
        0.        , 0.15324824, 0.01307643, 0.02672162, 0.        ,
        0.0221462 , 0.08094481, 0.00088978, 0.06145951, 0.00661707,
        0.00606289, 0.11338267, 0.        , 0.        , 0.00782273,
        0.        , 0.03817214, 0.02425379, 0.        , 0.13113678,
        0.14146711, 0.01410575, 0.        , 0.01

In [50]:
model_wrapper(smiles_list).shape

(1, 300)

In [38]:
smiles_list = ["CCO"]

In [39]:
smiles_list

['CCO']

In [40]:
MPNWrapper(mpn_model)(smiles_list)

array([[0.00000000e+00, 5.63246310e-02, 8.63388181e-04, 0.00000000e+00,
        6.29064888e-02, 1.17112815e-01, 0.00000000e+00, 9.38053895e-03,
        1.38919607e-01, 0.00000000e+00, 1.44390529e-02, 1.25650868e-01,
        4.69912998e-02, 2.02104468e-02, 6.12197332e-02, 1.72461465e-01,
        7.48388693e-02, 6.07380383e-02, 3.43471766e-04, 8.01870897e-02,
        0.00000000e+00, 0.00000000e+00, 8.22735280e-02, 0.00000000e+00,
        6.35521039e-02, 9.88672953e-03, 4.74283705e-03, 2.70644072e-02,
        0.00000000e+00, 0.00000000e+00, 3.14490497e-02, 1.26274765e-01,
        6.74669677e-03, 8.63672420e-02, 1.04836263e-01, 2.83235144e-02,
        3.67188677e-02, 1.96339726e-01, 3.72810960e-02, 1.80077422e-02,
        5.52019989e-03, 8.28011855e-02, 0.00000000e+00, 1.19934343e-01,
        1.74677949e-02, 0.00000000e+00, 2.29149610e-02, 2.11068001e-02,
        0.00000000e+00, 2.58211941e-02, 0.00000000e+00, 1.70654431e-03,
        5.64640649e-02, 9.87240486e-03, 3.78186479e-02, 0.000000

In [None]:
smiles_list

In [41]:
MPNWrapper(mpn_model)(smiles_list).shape

(1, 300)

In [43]:
smiles_list = ['CCO', "CCC"]


In [44]:
testout = MPNWrapper(mpn_model)(smiles_list)

In [45]:
len(testout)

1

In [46]:
testout.shape

(1, 300)

In [52]:
class MPNWrapper:
    """
    Callable adapter that treats the explainer input as keep-masks for the molecule features.

    The molecules and their feature batch are fixed at construction time; each call only varies
    the keep-mask passed to the model.
    """

    def __init__(self, model, smiles, features_batch):
        self.model = model
        self.smiles = smiles
        self.features_batch = features_batch

    def __call__(self, data):
        # Convert data to the expected input format for the MPN model
        # The molecules come from self.smiles, not from `data`
        batch = [[Chem.MolFromSmiles(smiles) for smiles in self.smiles]]
        # One keep-mask per row of `data`
        extra_keep_features_batch = data.tolist()  # Convert numpy array to list of lists
        # NOTE(review): the recorded explainer run failed with "The model produced 1 output rows
        # when given 10 input rows", so the number of output rows does not follow len(data).
        # NOTE(review): the recorded all-False and all-True masks produced identical outputs, so
        # the mask may not be reaching the encoder -- confirm how extra_keep_features_batch is used.
        output = self.model(batch, features_batch=self.features_batch, shap=True,
                            extra_keep_features_batch=extra_keep_features_batch)
        # Detach from the autograd graph and hand back a plain numpy array
        return output.detach().cpu().numpy()

# Define dummy features batch
features_batch = [np.random.rand(1, 100) for _ in range(len(smiles_list))]  # Example feature length of 100

# Create the MPN wrapper
model_wrapper = MPNWrapper(mpn_model, smiles_list, features_batch)

# Create the PermutationExplainer
explainer = PermutationExplainer(model=model_wrapper, masker=binary_masker)

# Generate example input to explain
example_input = np.array([[True] * 100])  # Example input for keep_features

# Explain the example input
explanation = explainer(example_input, max_evals=400)

print("SHAP values:", explanation.values)

AssertionError: The model produced 1 output rows when given 10 input rows! Check the implementation of the model you provided for errors.

In [56]:
model_wrapper(example_input)

array([[3.56964059e-02, 0.00000000e+00, 0.00000000e+00, 5.23746498e-02,
        0.00000000e+00, 3.79013494e-02, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 4.74891551e-02, 3.12048331e-04, 1.65173598e-03,
        1.96289569e-02, 2.56350115e-02, 5.89957135e-03, 0.00000000e+00,
        1.11671025e-02, 0.00000000e+00, 1.67639051e-02, 1.62893891e-01,
        1.66621469e-02, 9.76122320e-02, 0.00000000e+00, 6.98851841e-03,
        0.00000000e+00, 0.00000000e+00, 3.45444120e-03, 0.00000000e+00,
        9.08087194e-02, 0.00000000e+00, 0.00000000e+00, 6.29136851e-03,
        1.56819134e-03, 0.00000000e+00, 0.00000000e+00, 2.52928287e-02,
        0.00000000e+00, 0.00000000e+00, 1.36682823e-01, 3.29912417e-02,
        4.22598124e-02, 2.80275326e-02, 1.65742654e-02, 5.25233895e-03,
        6.13102615e-02, 2.78039668e-02, 0.00000000e+00, 0.00000000e+00,
        1.15126707e-02, 3.95655073e-02, 1.18175847e-02, 0.00000000e+00,
        2.29185060e-01, 0.00000000e+00, 1.06022917e-01, 1.527071

In [57]:
model_wrapper(example_input).shape

(1, 300)

In [58]:
model_wrapper(np.array([[False] * 100]))

array([[3.56964059e-02, 0.00000000e+00, 0.00000000e+00, 5.23746498e-02,
        0.00000000e+00, 3.79013494e-02, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 4.74891551e-02, 3.12048331e-04, 1.65173598e-03,
        1.96289569e-02, 2.56350115e-02, 5.89957135e-03, 0.00000000e+00,
        1.11671025e-02, 0.00000000e+00, 1.67639051e-02, 1.62893891e-01,
        1.66621469e-02, 9.76122320e-02, 0.00000000e+00, 6.98851841e-03,
        0.00000000e+00, 0.00000000e+00, 3.45444120e-03, 0.00000000e+00,
        9.08087194e-02, 0.00000000e+00, 0.00000000e+00, 6.29136851e-03,
        1.56819134e-03, 0.00000000e+00, 0.00000000e+00, 2.52928287e-02,
        0.00000000e+00, 0.00000000e+00, 1.36682823e-01, 3.29912417e-02,
        4.22598124e-02, 2.80275326e-02, 1.65742654e-02, 5.25233895e-03,
        6.13102615e-02, 2.78039668e-02, 0.00000000e+00, 0.00000000e+00,
        1.15126707e-02, 3.95655073e-02, 1.18175847e-02, 0.00000000e+00,
        2.29185060e-01, 0.00000000e+00, 1.06022917e-01, 1.527071

In [59]:
model_wrapper(np.array([[False] * 100])) == model_wrapper(np.array([[True] * 100]))

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,
      

In [61]:
# Dummy TrainArgs class for illustration
import sys
sys.path.insert(0, '/home/oscarwu/code/chemprop_developing')
from chemprop.args import TrainArgs

# Example SHAP Analysis with MPN

# Define the TrainArgs
args = TrainArgs()

# Create the MPN model
from chemprop.models.mpn import MPN
mpn_model = MPN(args)

class MPNWrapper:
    """
    Callable adapter that treats the explainer input as keep-masks and rescales the model output.

    NOTE(review): ``mol_feature_scaler``, ``mol_feature`` and ``target_scaler`` are not defined in
    this class or passed to it; they are presumably notebook globals -- confirm they exist before
    the class is instantiated and called.
    """

    def __init__(self, model, smiles, features_batch):
        self.model = model
        self.smiles = smiles
        self.features_batch = features_batch
        # Stored but not read anywhere else in this class
        self.mol_feature = mol_feature_scaler.transform(mol_feature)
        

    def __call__(self, data):
        # Convert data to the expected input format for the MPN model
        # The molecules come from self.smiles, not from `data`
        batch = [[Chem.MolFromSmiles(smiles) for smiles in self.smiles]]
        # One keep-mask per row of `data`
        extra_keep_features_batch = data.tolist()  # Convert numpy array to list of lists
        output = self.model(batch, features_batch=self.features_batch, shap=True,
                            extra_keep_features_batch=extra_keep_features_batch)
        
        # NOTE(review): the scaler is applied before the tensor is detached. If inverse_transform
        # returns a numpy array, the .detach() call below would fail -- confirm the return type,
        # or detach and convert first.
        output = target_scaler.inverse_transform(output)
        return output.detach().cpu().numpy()

# Define dummy features batch
features_batch = [np.random.rand(1, 100) for _ in range(len(smiles_list))]  # Example feature length of 100

# Create the MPN wrapper
model_wrapper = MPNWrapper(mpn_model, smiles_list, features_batch)

# Create the PermutationExplainer
explainer = PermutationExplainer(model=model_wrapper, masker=binary_masker)

# Generate example input to explain
example_input = np.array([[True] * 200])  # Example input for keep_features


In [66]:
model_wrapper(np.array([[False] * 500]))

array([[2.28082836e-02, 0.00000000e+00, 0.00000000e+00, 1.38110006e-02,
        3.98440808e-02, 3.11706942e-02, 6.13108762e-02, 0.00000000e+00,
        7.57831261e-02, 0.00000000e+00, 5.32255461e-03, 8.37425813e-02,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        1.06118761e-01, 2.77969390e-02, 0.00000000e+00, 2.00244952e-02,
        1.69296507e-02, 7.81310871e-02, 0.00000000e+00, 0.00000000e+00,
        3.00733801e-02, 7.16054291e-02, 9.83800460e-03, 1.77742854e-01,
        6.53687790e-02, 1.58211438e-03, 4.07660007e-02, 0.00000000e+00,
        1.63837913e-02, 0.00000000e+00, 8.82048830e-02, 4.45606895e-02,
        3.31034325e-02, 2.16282159e-03, 0.00000000e+00, 9.47360974e-03,
        0.00000000e+00, 7.56702498e-02, 3.38848084e-02, 9.77005158e-03,
        1.25909507e-01, 7.63905272e-02, 8.34982377e-03, 6.94973534e-03,
        6.48947954e-02, 4.34623808e-02, 1.75422907e-01, 0.00000000e+00,
        0.00000000e+00, 1.56332050e-02, 1.66949257e-02, 0.000000

In [None]:
# TODO: change MoleculeModel in model.py so that it can be used through a wrapper for SHAP:
# - atom/bond features and descriptors are read directly from the npz file
# - molecule-level features are pre-generated with chemprop
# - apply the target scaling to the model's prediction inside __call__
