In [1]:
import itertools
import random

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio
from rdkit import Chem
from rdkit.Chem import BRICS

import XAIChem

# Use the "seaborn" template as the default styling for every plotly figure in this notebook
pio.templates.default = "seaborn"

In [None]:
# Raw ESOL dataset
data = pd.read_csv("../../data/ESOL/ESOL.csv")

# Per-fragment BRICS attributions. atom_ids are read as lists (presumably JSON
# arrays); convert them to tuples so they are hashable and can serve as a join
# key together with molecule_smiles.
attributions_brics = pd.read_json("../../data/ESOL/attribution_brics_no_mean.json").assign(
    atom_ids=lambda frame: frame["atom_ids"].apply(tuple)
)

In [2]:
def brics_decompose(smiles: str) -> pd.DataFrame:
    """
    Break all BRICS bonds present in a molecule and describe each fragment.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule to decompose.

    Returns
    -------
    pd.DataFrame
        One row per BRICS fragment, with columns:

        - ``molecule_smiles``: the input SMILES, repeated on every row.
        - ``atom_ids``: tuple of the fragment's atom indices, with the dummy
          (``*``) linking atoms removed.
        - ``brics_rdmol``: the fragment as an RDKit molecule.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smiles``.
    """
    rdmol = Chem.MolFromSmiles(smiles)
    if rdmol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here with
        # the offending SMILES instead of an opaque error inside BreakBRICSBonds.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    # Break BRICS bonds; each broken bond is capped with dummy ("*") atoms
    broken_mol = BRICS.BreakBRICSBonds(rdmol)

    # Both calls enumerate the fragments of the same molecule, so the i-th
    # RDKit fragment corresponds to the i-th tuple of atom ids.
    substructures = Chem.GetMolFrags(broken_mol, asMols=True)
    substructures_atom_ids = Chem.GetMolFrags(broken_mol)

    out = []
    for substructure, fragment_atom_ids in zip(substructures, substructures_atom_ids):
        # Remove the atom ids of the dummy linking atoms so only real atoms
        # remain. NOTE(review): this assumes the remaining ids coincide with the
        # parent molecule's atom ids, which the attribution join relies on.
        real_atom_ids = tuple(
            atom_id
            for atom_id in fragment_atom_ids
            if broken_mol.GetAtomWithIdx(atom_id).GetSymbol() != "*"
        )
        out.append([smiles, real_atom_ids, substructure])

    return pd.DataFrame(data=out, columns=["molecule_smiles", "atom_ids", "brics_rdmol"])

In [None]:
# Decompose every unique molecule into BRICS fragments, keeping each fragment's
# RDKit molecule together with its respective parent atom ids. ignore_index
# gives the combined frame a clean 0..n-1 index instead of repeated per-molecule ones.
brics_fragments = pd.concat(
    [brics_decompose(smiles) for smiles in attributions_brics.molecule_smiles.unique()],
    ignore_index=True,
)
brics_fragments.head()

In [None]:
# Join the BRICS fragments with the attribution dataframe, matching on the parent
# molecule SMILES and the fragment's atom ids (left join: fragments without a
# matching attribution row get NaN scores).
# Only the base fragment columns are taken from brics_fragments, and only the
# needed attribution columns from attributions_brics. This keeps the cell
# idempotent: re-running it does not try to join the attribution columns a
# second time, which would raise on overlapping column names.
fragment_columns = ["molecule_smiles", "atom_ids", "brics_rdmol"]
attribution_columns = ["substruct_smiles", "SME", "Shapley_value", "HN_value"]

brics_fragments = brics_fragments[fragment_columns].join(
    attributions_brics.set_index(["molecule_smiles", "atom_ids"])[attribution_columns],
    on=["molecule_smiles", "atom_ids"],
)

In [None]:
def unique_substructures(fragments: pd.DataFrame, condition: str) -> pd.DataFrame:
    """
    Return the rows of ``fragments`` matching the ``query`` string ``condition``,
    keeping only the first row for each distinct ``substruct_smiles``.

    Rows whose attribution is NaN satisfy neither ``>= 0`` nor ``< 0`` and are
    therefore excluded from both the positive and the negative sets.
    """
    return fragments.query(condition).drop_duplicates("substruct_smiles")


# Unique substructures split by the sign of each attribution method
sme_brics_pos = unique_substructures(brics_fragments, "SME >= 0")
sme_brics_neg = unique_substructures(brics_fragments, "SME < 0")
shapley_brics_pos = unique_substructures(brics_fragments, "Shapley_value >= 0")
shapley_brics_neg = unique_substructures(brics_fragments, "Shapley_value < 0")
hn_brics_pos = unique_substructures(brics_fragments, "HN_value >= 0")
hn_brics_neg = unique_substructures(brics_fragments, "HN_value < 0")

print(f"Positive SME attribution: {len(sme_brics_pos)}")
print(f"Negative SME attribution: {len(sme_brics_neg)}")
print()
print(f"Positive Shapley attribution: {len(shapley_brics_pos)}")
print(f"Negative Shapley attribution: {len(shapley_brics_neg)}")
print()
print(f"Positive HN attribution: {len(hn_brics_pos)}")
print(f"Negative HN attribution: {len(hn_brics_neg)}")

In [None]:
# Recombine the fragments with a positive SME attribution into new molecules.
# By default BRICSBuild shuffles its reagents with Python's `random` module
# (scrambleReagents=True), so seed it to make the displayed molecules reproducible.
N_BUILT_MOLECULES = 1000
BRICS_BUILD_SEED = 42

random.seed(BRICS_BUILD_SEED)
builder = BRICS.BRICSBuild(sme_brics_pos.brics_rdmol.to_list())

# islice stops cleanly if fewer than N_BUILT_MOLECULES molecules can be built,
# whereas calling next() on an exhausted generator raises StopIteration.
for built_mol in itertools.islice(builder, N_BUILT_MOLECULES):
    # BRICSBuild products are not sanitized; update valences before rendering
    built_mol.UpdatePropertyCache(strict=False)
    display(built_mol)