# Notebook on Getting Explanations from a trained model


In [104]:
from bondnet.model.training_utils import (
    get_grapher,
    load_model_lightning,
)
from bondnet.data.utils import mol_graph_to_rxn_graph
import dgl

config = {
    "model": {
        "restore": True,
        "restore_path": "../../../tests/model/test_save_load/test.ckpt",  # path to ckpt
    }
}

model_restart = load_model_lightning(config["model"], load_dir="./test_save_load/")

:::RESTORING MODEL FROM EXISTING FILE:::
NB: using GatedGCNConv
:::MODEL LOADED:::


In [121]:
from bondnet.data.datamodule import BondNetLightningDataModule

dataset_loc = "../../../tests/data/testdata/barrier_100.json"

config = {
    "dataset": {
        "data_dir": dataset_loc,
        "target_var": "dG_barrier",
    },
    "model": {
        "extra_features": ["bond_length"],
        "extra_info": [],
        "debug": False,
        "classifier": False,
        "classif_categories": 3,
        "filter_species": [3, 6],
        "filter_outliers": False,
        "filter_sparse_rxns": False,
        "restore": False,
    },
    "optim": {
        "val_size": 0.0,
        "test_size": 0.85,
        "batch_size": 1,
        "num_workers": 1,
    },
}

dm = BondNetLightningDataModule(config)
feat_size, feat_name = dm.prepare_data()
dm.setup(stage="validate")

reading file from: ../../../tests/data/testdata/barrier_100.json
rxn raw len: 100
Program finished in 0.794138582990854 seconds
.............failures.............
reactions len: 100
valid ind len: 100
bond break fail count: 		0
default fail count: 		0
sdf map fail count: 		0
product bond fail count: 	0
about to group and organize
number of grouped reactions: 100
features: 240
labels: 100
molecules: 240
constructing graphs & features....
number of graphs valid: 240
number of graphs: 240


In [282]:
embedding_list = []

nodes = ["atom", "bond", "global"]
embedding_size = model_restart.hparams.embedding_size
batch_size = len(dm.train_ds)

direct_concat_name = model_restart.hparams.set2set_ntypes_direct
gat_out = model_restart.hparams.gated_hidden_size[-1]
readout_out_size = gat_out * 2 + gat_out * 2
readout_out_size += gat_out * len(direct_concat_name)
targets = []
for it, (batched_graph, label) in enumerate(dm.train_dataloader()):
    feats = {nt: batched_graph.nodes[nt].data["feat"] for nt in nodes}

    target = label["value"].view(-1)

    norm_atom = label["norm_atom"]
    norm_bond = label["norm_bond"]
    stdev = label["scaler_stdev"]
    reactions = label["reaction"]
    targets.append(target)
    embeddings = model_restart.feature_before_fc(
        graph=batched_graph,
        feats=feats,
        reactions=reactions,
        norm_atom=norm_atom,
        norm_bond=norm_bond,
    )

    graph_rxn, feats_rxn = mol_graph_to_rxn_graph(
        graph=batched_graph,
        feats=feats,
        reactions=reactions,
        reverse=False,
    )
    # print(reactions)
    print(reactions[0].id)

['10527514179105274']
['637705095163771']
['674486061767449']
['12203892800122039']
['13334949506133350']
['139105113297139106']
['638184940563819']
['142070117901142069']
['141512115712141511']
['127302109520127303']
['645255652164524']
['142259118367142260']
['12535392445125355']
['317061459631705']
['668205595666821']


In [283]:
target

tensor([-0.6612])

In [133]:
pred = model_restart(
    graph=batched_graph,
    feats=feats,
    reactions=reactions,
    norm_atom=norm_atom,
    norm_bond=norm_bond,
)

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256])

In [124]:
pred_loss = model_restart.loss(pred.view(-1), target)
pred_loss.backward()

RuntimeError: Predictions and targets are expected to have the same shape, but got torch.Size([3]) and torch.Size([1]).

In [292]:
import torch

from bondnet.data.utils import mol_graph_to_rxn_graph


def _split_batched_output(graph, value, n_type="bond"):
    """
    Split a tensor into `num_graphs` chunks, the size of each chunk equals the
    number of bonds in the graph.

    Returns:
        list of tensor.

    """
    nbonds = tuple(graph.batch_num_nodes(n_type).tolist())
    print(n_type, nbonds)
    split_tensor = torch.split(value, nbonds)
    # print(len(split_tensor))
    return split_tensor


def feature_at_each_layer(model, graph, feats, reactions, norm_atom, norm_bond):
    """
    Get the features at each layer before the final fully-connected layer.

    This is used for feature visualization to see how the model learns.

    Returns:
        dict: (layer_idx, feats), each feats is a list of
    """

    layer_idx = 0

    bond_feats = {}
    atom_feats = {}
    global_feats = {}

    feats = model.embedding(feats)
    bond_feats[layer_idx] = _split_batched_output(graph, feats["bond"], "bond")
    atom_feats[layer_idx] = _split_batched_output(graph, feats["atom"], "atom")
    global_feats[layer_idx] = _split_batched_output(graph, feats["global"], "global")

    layer_idx += 1

    # gated layer
    for layer in model.gated_layers:
        feats = layer(graph, feats, norm_atom, norm_bond)

        # store bond feature of each molecule
        bond_feats[layer_idx] = _split_batched_output(graph, feats["bond"], "bond")

        atom_feats[layer_idx] = _split_batched_output(graph, feats["atom"], "atom")

        global_feats[layer_idx] = _split_batched_output(
            graph, feats["global"], "global"
        )

        layer_idx += 1

    return bond_feats, atom_feats, global_feats


# outputs are reactant then product feats for each node type

bond_feats, atoms_feats, global_feats = feature_at_each_layer(
    model=model_restart,
    graph=batched_graph,  # for a given set of graph feats
    feats=feats,  # and a given set of feats
    reactions=reactions,
    norm_atom=norm_atom,
    norm_bond=norm_bond,
)
unbatched_graph = dgl.unbatch(batched_graph)
# unbatched_graph = dgl.unbatch(batched_graph)
# print(len(unbatched_graph))

bond (10, 10)
atom (10, 10)
global (1, 1)
bond (10, 10)
atom (10, 10)
global (1, 1)
bond (10, 10)
atom (10, 10)
global (1, 1)
2


In [293]:
import numpy as np


def compute_person(features):
    """
    Compute the Pearson correlation coefficient.

    Args:
        ntype (str): the node type of the graph where the features are stored, e.g.
            `atom` and `bond`.
        exclude (list, optional): indices of features to ignore. This is useful to
            exclude features with 0 stdandard deviation. If `None`, nothing to
            exclude. Defaults to None.
    Returns:
        2D array: correlation between features
    """

    features = features.detach().numpy()
    # remove features with 0 standard deviation
    remove_idx = []
    for i in range(features.shape[1]):
        if np.std(features[:, i]) == 0:
            remove_idx.append(i)
    # remove features on remove_idx
    features = np.delete(features, remove_idx, axis=1)
    corr = np.corrcoef(features)
    return corr


compute_corr = compute_person(atoms_feats[0][0])

In [294]:
compute_corr.shape

(10, 10)

In [295]:
12203892800122039
reactions[0].id

['668205595666821']

In [300]:
type(reactions[0])

bondnet.data.reaction_network.ReactionInNetwork

In [305]:
print(
    reactions[0].init_reactants,
    reactions[0].init_products,
)


def mols_from_reaction(id=0):
    product_wrapper = []
    reactant_wrapper = []

    for i in reactions[id].init_reactants:
        print(i)

        reactant_wrapper.append(
            dm.train_ds.dataset.reaction_network.molecule_wrapper[i]
        )
        print(i, dm.train_ds.dataset.reaction_network.molecule_wrapper[i].id)

    for i in reactions[id].init_products:
        # dm.train_ds.dataset.reaction_network.molecule_wrapper[i]
        # print(i)
        product_wrapper.append(dm.train_ds.dataset.reaction_network.molecule_wrapper[i])

        print(i, dm.train_ds.dataset.reaction_network.molecule_wrapper[i].id)

    pmg_reactants = [i.pymatgen_mol for i in reactant_wrapper]
    pmg_products = [i.pymatgen_mol for i in product_wrapper]
    return pmg_reactants, pmg_products


pmg_reactants, pmg_products = mols_from_reaction(id=0)

[23] [24]
23
23 66821
24 66820


In [308]:
pmg_reactants, pmg_products = mols_from_reaction(id=0)

23
23 66821
24 66820


In [307]:
pmg_reactants

[Molecule Summary
 Site: N (1.7704, -0.3109, 0.0006)
 Site: C (0.4056, -0.1228, 0.0414)
 Site: N (-0.4637, -1.1234, -0.0685)
 Site: N (-1.6480, -0.5259, -0.0411)
 Site: N (-1.6232, 0.7754, 0.0793)
 Site: C (-0.3303, 1.0821, 0.1339)
 Site: H (2.3172, 0.4137, 0.4460)
 Site: H (2.0791, -1.2401, 0.2566)
 Site: H (-2.5130, -1.0481, -0.1018)
 Site: H (0.0178, 2.1058, 0.2341)]

In [313]:
# plot pymatgen mol
from ase.visualize.plot import plot_atoms
from ase.visualize import view
from io import StringIO
from ase.io import read

xyz_string = pmg_products[0].to("xyz")
sdf_string = pmg_products[0].to("sdf")
molblock = pmg_products[0].to("mol")
f = StringIO(xyz_string)
atoms = read(f, format="xyz")

# view(atoms, rotation='10x,20y,30z')
view(atoms)

<subprocess.Popen at 0x7f82700fa0d0>

In [314]:
molblock

'\n OpenBabel09072312213D\n\n 10 10  0  0  0  0  0  0  0  0999 V2000\n    1.7693   -0.3176    0.0764 N   0  0  0  0  0  0  0  0  0  0  0  0\n    0.4058   -0.1216    0.0285 C   0  0  0  0  0  0  0  0  0  0  0  0\n   -0.4637   -1.1236   -0.0667 N   0  0  0  0  0  0  0  0  0  0  0  0\n   -1.6482   -0.5274   -0.0237 N   0  0  0  0  0  0  0  0  0  0  0  0\n   -1.6234    0.7745    0.0892 N   0  0  0  0  0  0  0  0  0  0  0  0\n   -0.3302    1.0825    0.1286 C   0  0  0  0  0  0  0  0  0  0  0  0\n    2.3275    0.4727   -0.2170 H   0  0  0  0  0  0  0  0  0  0  0  0\n    2.0881   -1.1880   -0.3295 H   0  0  0  0  0  0  0  0  0  0  0  0\n   -2.5133   -1.0500   -0.0801 H   0  0  0  0  0  0  0  0  0  0  0  0\n    0.0181    2.1070    0.2213 H   0  0  0  0  0  0  0  0  0  0  0  0\n  2  1  1  0  0  0  0\n  2  6  1  0  0  0  0\n  3  4  1  0  0  0  0\n  3  2  2  0  0  0  0\n  4  5  1  0  0  0  0\n  5  6  2  0  0  0  0\n  6 10  1  0  0  0  0\n  7  1  1  0  0  0  0\n  8  1  1  0  0  0  0\n  9  4  1  0 

In [326]:
# convert xyz to mol in rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D


mol = Chem.MolFromMolBlock(molblock)

In [340]:
for ind, atom in enumerate(mol.GetAtoms()):
    corr_i = compute_corr[0][ind]
    atom.SetProp("atomNote", "corr: {:.2f}".format(corr_i))

# show hhydrogens in plot
# mol = Chem.AddHs(mol)
size = (240, 240)
fig = Draw.ShowMol(
    mol, 
    size=size, 
    kekulize=False, 
    wedgeBonds=False, 
    showAtomNumbers=False,
    options=

)

In [244]:
# Gradient Analysis

In [None]:
def grad_at_each_layer(model, graph, feats, reactions, norm_atom, norm_bond, target):
    """
    Get the features at each layer before the final fully-connected layer.

    This is used for feature visualization to see how the model learns.

    Returns:
        dict: (layer_idx, feats), each feats is a list of
    """

    layer_idx = 0
    bond_grads = {}
    atom_grads = {}
    global_grads = {}

    pred = model(
        graph=graph,
        feats=feats,
        reactions=reactions,
        norm_atom=norm_atom,
        norm_bond=norm_bond,
    )
    dict_grads = {"embedding": {}, "gated": {}}

    model.loss(pred.view(-1), target).backward()
    linear_index = 0
    dict_grads["embedding"]["atom"] = model.embedding.linears["atom"].weight.grad
    dict_grads["embedding"]["bond"] = model.embedding.linears["bond"].weight.grad
    dict_grads["embedding"]["global"] = model.embedding.linears["global"].weight.grad

    # gated layer
    for layer in model.gated_layers:
        for name, param in layer.A.named_parameters():
            dict_grads["gated"]["A_" + name] = param.grad
        for name, param in layer.B.named_parameters():
            dict_grads["gated"]["B_" + name] = param.grad
        for name, param in layer.C.named_parameters():
            dict_grads["gated"]["C_" + name] = param.grad
        for name, param in layer.D.named_parameters():
            dict_grads["gated"]["D_" + name] = param.grad
        for name, param in layer.E.named_parameters():
            dict_grads["gated"]["E_" + name] = param.grad
        for name, param in layer.F.named_parameters():
            dict_grads["gated"]["F_" + name] = param.grad
        for name, param in layer.G.named_parameters():
            dict_grads["gated"]["G_" + name] = param.grad
        for name, param in layer.H.named_parameters():
            dict_grads["gated"]["H_" + name] = param.grad
        for name, param in layer.I.named_parameters():
            dict_grads["gated"]["I_" + name] = param.grad

    return dict_grads


dict_grads = grad_at_each_layer(
    model=model_restart,
    graph=batched_graph,
    feats=feats,
    reactions=reactions,
    norm_atom=norm_atom,
    norm_bond=norm_bond,
    target=target,
)

In [None]:
import numpy as np
from sklearn.preprocessing import MinMaxScaler


# this operates at the node level???
# this is just what nodes are activated
def saliency_map(input_grads):
    node_saliency_map = []
    for n in range(input_grads.shape[0]):  # nth node
        node_grads = input_grads[n, :]
        node_saliency = torch.norm(F.relu(node_grads)).item()
        node_saliency_map.append(node_saliency)
    return node_saliency_map


def grad_cam(final_conv_acts, final_conv_grads):
    # print('grad_cam')
    node_heat_map = []
    alphas = torch.mean(
        final_conv_grads, axis=0
    )  # mean gradient for each feature (512x1)
    for n in range(final_conv_acts.shape[0]):  # nth node
        node_heat = F.relu(alphas @ final_conv_acts[n]).item()
        node_heat_map.append(node_heat)
    return node_heat_map


def ugrad_cam(n_atoms, final_conv_acts, final_conv_grads):
    # print('new_grad_cam')
    node_heat_map = []
    alphas = torch.mean(
        final_conv_grads, axis=0
    )  # mean gradient for each feature (512x1)
    for n in range(final_conv_acts.shape[0]):  # nth node
        node_heat = (alphas @ final_conv_acts[n]).item()
        node_heat_map.append(node_heat)

    node_heat_map = np.array(node_heat_map[:n_atoms]).reshape(-1, 1)
    pos_node_heat_map = (
        MinMaxScaler(feature_range=(0, 1))
        .fit_transform(node_heat_map * (node_heat_map >= 0))
        .reshape(
            -1,
        )
    )
    neg_node_heat_map = (
        MinMaxScaler(feature_range=(-1, 0))
        .fit_transform(node_heat_map * (node_heat_map < 0))
        .reshape(
            -1,
        )
    )
    return pos_node_heat_map + neg_node_heat_map