In [1]:
import os
from shutil import rmtree
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_useSVG=True
IPythonConsole.drawOptions.addAtomIndices=True
from tqdm import tqdm
import numpy as np
import pandas as pd

from serenityff.torsion.tree.dash_tree import DASHTorsionTree
from serenityff.charge.tree.atom_features import AtomFeatures
from serenityff.torsion.tree_develop.tree_constructor import Torsion_tree_constructor

In [2]:
data_file="../data/example_extraction.csv"
sdf_suply="../data/example.sdf" 
data_split = 0.2
out_folder = "./example_tree_out"
if os.path.exists(out_folder):
    rmtree(out_folder)
os.mkdir(out_folder)

In [3]:
mol_supplier = Chem.SDMolSupplier(sdf_suply)
df_test = pd.read_csv(data_file)

In [4]:
tree_constructor = Torsion_tree_constructor(df_path=data_file,
                sdf_suplier=sdf_suply,
                num_layers_to_build=16,
                data_split=data_split,
                verbose=True,
                sanitize=True,
                sanitize_charges=True)

2024-01-09 21:02:18.878920	Initializing Tree_constructor
2024-01-09 21:02:18.880769	Mols imported, starting df import
2024-01-09 21:02:18.888967	Sanitizing
2024-01-09 21:02:18.904964	Check charge sanity


100%|██████████| 20/20 [00:00<00:00, 4298.54it/s]

Number of wrong charged mols: 0 of 20 mols
2024-01-09 21:02:18.914880	df imported, starting data spliting
2024-01-09 21:02:18.915586	Splitting data
2024-01-09 21:02:18.917216	Data split, delete original
2024-01-09 21:02:18.938560	Starting table filling



100%|██████████| 334/334 [00:00<00:00, 2397.17it/s]

2024-01-09 21:02:19.087558	Table filled, starting adjacency matrix creation





Creating Adjacency matrices:


100%|██████████| 20/20 [00:00<00:00, 1506.25it/s]

Number of train mols: 16
Number of test mols: 4





Found 142 torsions in the dataset
Created a dataframe with (142, 9) torsions


In [5]:
tree_constructor.df

Unnamed: 0,atomtype,idx_in_mol,mol_index,node_attentions,truth,h_connectivity,connected_atoms,total_connected_attention,atom_feature
20,O,0,1,"[0.08731412076250168, 0.07575401684894212, 0.0...",-0.484899,-1,"[2, 1, 3, 4]",0.599979,156553111
21,C,1,1,"[0.07533094601506285, 0.07949683829561092, 0.0...",0.000381,-1,"[3, 8, 7, 6]",0.392914,49436500
26,N,2,1,"[0.0760680088577648, 0.07624203235535834, 0.07...",-0.000414,-1,"[8, 7, 6, 5]",0.614799,100276857
25,C,3,1,"[0.07952590797072513, 0.0767484929016264, 0.07...",-0.000025,-1,"[7, 6, 5, 4]",0.387219,49849837
24,C,4,1,"[0.07947201368734608, 0.07641943049244446, 0.0...",0.000456,-1,"[6, 5, 4, 3]",0.357515,49849836
...,...,...,...,...,...,...,...,...,...
406,C,2,19,"[0.03662447512365832, 0.037163860413569255, 0....",-0.490934,-1,"[2, 3, 4, 5]",0.36673,60418214
407,C,3,19,"[0.035285019837321366, 0.03745373971211091, 0....",-0.189262,-1,"[3, 4, 5, 6]",0.369222,151195731
410,O,4,19,"[0.029042437193256325, 0.040045953766984284, 0...",-0.007612,-1,"[6, 5, 7, 8]",0.60137,151194993
409,C,5,19,"[0.019444394729595877, 0.030420347852354694, 0...",0.498968,-1,"[5, 7, 8, 9]",0.340875,58512325


In [6]:
tree_constructor.create_tree_level_0()

Preparing Dataframe:
Number of unique atom features in df: 92
Creating Tree Level 0:


100%|██████████| 92/92 [00:00<00:00, 1788.57it/s]

2024-01-09 21:02:19.500325	Layer 0 done





In [7]:
tree_constructor.build_tree(num_processes=1)

AF=156553111 - Layer 1 done
children layer 1: [node --- lvl: 2, Num=1, Mean=0.0000, std=0.0000, fp=(8, 1, 0, True, 0) (1, 4)]
AF 156553111 done
AF=49436500 - Layer 1 done
children layer 1: [node --- lvl: 2, Num=1, Mean=0.0000, std=0.0000, fp=(6, 3, 0, True, 0) (0, 4)]
AF 49436500 done
AF=100276857 - Layer 1 done
children layer 1: [node --- lvl: 2, Num=1, Mean=0.0000, std=0.0000, fp=(6, 3, 0, True, 1) (3, 4)]
AF 100276857 done
AF=49849837 - Layer 1 done
children layer 1: [node --- lvl: 2, Num=1, Mean=0.0000, std=0.0000, fp=(6, 3, 0, True, 0) (3, 4)]
AF 49849837 done
AF=49849836 - Layer 1 done
children layer 1: [node --- lvl: 2, Num=1, Mean=0.0000, std=0.0000, fp=(7, 2, 0, True, 0) (3, 4)]
AF 49849836 done
AF=100276735 - Layer 1 done
children layer 1: [node --- lvl: 2, Num=1, Mean=0.0000, std=0.0000, fp=(6, 3, 0, True, 1) (0, 4)]
AF 100276735 done
AF=49849715 - Layer 1 done
children layer 1: [node --- lvl: 2, Num=1, Mean=0.0000, std=0.0000, fp=(6, 3, 0, True, 0) (1, 4)]
AF 49849715 done


In [8]:
print(tree_constructor.root.children[26])
print(tree_constructor.root.children[26].children)
print(tree_constructor.root.children[26].children[1].children)

node --- lvl: 1, Num=1, fp=120326337)
[node --- lvl: 2, Num=1, Mean=-0.0325, std=0.0000, fp=(6, 4, 0, False, 3) (1, 1)]


IndexError: list index out of range

In [None]:
tree_constructor.convert_tree_to_node(tree_folder_path="./example_tree_out")

In [None]:
example_tree = DASHTorsionTree(tree_folder_path="./example_tree_out")

In [None]:
for branch in example_tree.tree_storage:
    if len(example_tree.tree_storage[branch]) > 1:
        print(f"Branch {branch} has {len(example_tree.tree_storage[branch])} nodes")

In [None]:
mol_idx_test = tree_constructor.test_df.mol_index.unique().tolist()

In [None]:
example_tree.tree_storage[6]

In [None]:
matched_node_path = example_tree.match_new_atom(atom=0, mol=mol)

In [None]:
matched_node_path[0]

In [None]:
example_tree.get_atom_properties(matched_node_path=matched_node_path)

In [None]:
example_tree_charges = []
example_ref_charges = []
for mol_idx in mol_idx_test:
    try:
        mol = Chem.SDMolSupplier(sdf_suply, removeHs=False)[mol_idx]
        example_tree_charges.extend(example_tree.get_molecules_partial_charges(mol)["charges"])
        example_ref_charges.extend(tree_constructor.test_df[tree_constructor.test_df.mol_index == mol_idx].truth.values)  
    except:
        print(f"Failed mol with index {mol_idx}")

In [None]:
df = pd.DataFrame({"tree":example_tree_charges, "ref":example_ref_charges})

In [None]:
df

In [None]:
ax=df.plot.scatter(x="ref", y="tree", figsize=(6,6), xlim=(-1,1), ylim=(-1,1))
ax.set_aspect('equal')
ax.plot([-1,1],[-1,1], color="grey", linestyle="--")
ax.set_xlabel("Reference charges [e]")
ax.set_ylabel("Tree charges [e]")
ax.set_title(f"Example tree charge correlation\n RMSE: {np.sqrt(np.mean((df.tree-df.ref)**2)):.3f} e")

It can be seen, that even this very simple example tree can already assign partial charges with a decent accuracy. However, the small tree can not yet assign all molecules and with atoms with atom features not present in the training set the assigned charges will be NaN. In addition to the assigned charges we can also retrive the assignment path, the total attention values and the statistical errors.

## Test an example molecule with the example tree

In [None]:
example_mol = Chem.SDMolSupplier(sdf_suply, removeHs=False)[1]

In [None]:
example_tree.match_new_atom(1, example_mol)

In [None]:
node_path = example_tree.match_new_atom(1, example_mol)
branch_idx = node_path[0]
print(f"Branch index: {branch_idx}")
print(f"Node path: {node_path}")

In [None]:
charges_in_path = []
counts_in_path = []
attention_in_path = []
stdDev_in_path = []
df = example_tree.data_storage[branch_idx]
for atom in node_path[1:]:
    charges_in_path.append(df.iloc[atom]["result"])
    counts_in_path.append(df.iloc[atom]["size"])
    attention_in_path.append(df.iloc[atom]["max_attention"])
    stdDev_in_path.append(df.iloc[atom]["stdDeviation"])

We can also observe, how the assigned partial charge slowly converges to the final value along the node path trough the tree. 

In [None]:
ax = pd.Series(charges_in_path).fillna(0).plot.line(label="Charge [e]")
ax2 = pd.Series(counts_in_path).plot.line(secondary_y=True, ax=ax, color="red", label="Counts in node")
ax.set_xlabel("Tree depth along path")
ax.set_ylabel("Charge [e]")
ax.set_title("Example path in tree")
ax.right_ax.set_ylabel("Counts in node")
ax.legend(handles=[ax.lines[0], ax2.lines[0]])

The above figure shows the number of counts decresing along the path through the tree, and going to 1, since our example dataset is so small, that we basically recover the full querry molecule. With a larger number of molecules in the dataset the averaging would be better, and we could use a certain attention value, collected along the path as stop marker for the matching to avoid overfitting.

In [None]:
ax = pd.Series(attention_in_path).cumsum().fillna(0).plot.line(label="Attention")
ax.set_xlabel("Tree depth along path")
ax.set_ylabel("Attention")
ax.set_title("Cumulative attention along path in the example tree")

The threshold for the total attention can be tuned as meta parameter of the matching process in the tree.  

We can also plot the standard deviation in the nodes along the path, al;though it should be mentioned, that in a small example they are not very meaningfull.

In [None]:
ax = pd.Series(stdDev_in_path).plot.line(label="StdDev [e]")
ax.set_xlabel("Tree depth along path")
ax.set_ylabel("StdDev [e]")
ax.set_title("StdDev along path in the example tree")

In [None]:
example_mol

In [None]:
example_tree.explain_property(mol=example_mol, atom=1)

In this case we selected the C atom of a carboxylic acid group and it can be seen, that this functional group is recovered in the path of the tree. The highest attention is on the double bond to the oxygen atom (atom 1 in the fragment), followed by the OH group (atom 2), the C atom (atom 3) and then the aromatic ring. This roughly corresponds to the chemical knowldge of a carboxylic acid group, but it is quantified by the attention values.

Clean up all the files created by this notebook

In [None]:
rmtree(out_folder)