In [1]:
from rdkit import Chem
from tqdm import tqdm
from serenityff.charge.tree.dash_tree import DASHTree
# line_profiler provides the %lprun magic used by the profiling cells at the
# bottom of this notebook.
%load_ext line_profiler

In [2]:
# Build the DASH tree with num_processes=1; construction prints
# "Loading DASH tree data" (see this cell's output).
tree = DASHTree(num_processes=1)
# NOTE(review): the figures below look like measurements taken with
# num_processes = n, but neither the quantity nor the unit was recorded — confirm.
# n=1 19.5
# n=2 30.6
# n=3 31.6

Loading DASH tree data


In [3]:
# Open the example SD file (path is relative to the notebook's working directory).
# removeHs=False asks RDKit not to strip hydrogens when reading the molecules.
mol_sup = Chem.SDMolSupplier("../data/example.sdf", removeHs=False)

In [4]:
# Run the tree matching for the first molecule in the supplier; the displayed
# result is a list of integers.
# NOTE(review): the first argument is presumably the atom index (0) — confirm
# against DASHTree.match_new_atom.
tree.match_new_atom(0, mol_sup[0])

[34,
 0,
 167331,
 167332,
 167333,
 169504,
 170309,
 170536,
 170727,
 170731,
 170732,
 170734]

In [5]:
# Compute partial charges for the fifth molecule (index 4) and display only the
# "charges" entry of the result, shown below as a list of floats
# (presumably one value per atom — confirm).
tree.get_molecules_partial_charges(mol_sup[4])["charges"]

[-0.5931196730441903,
 0.7851668913312818,
 -0.6431413414085979,
 -0.015186619325404951,
 -0.17372044416049043,
 -0.08634567360691903,
 -0.13202160869163443,
 0.14872200355130397,
 -0.5883753153625102,
 0.1707070137686761,
 -0.42315670330341354,
 0.5095730476572767,
 0.19047164368866124,
 0.16725630351830967,
 0.18996089087651563,
 0.49320958451113567]

In [6]:
# Smoke run over the whole supplier with a progress bar. The returned charges
# are discarded; the cell only exercises the code path and shows throughput.
# After the loop `mol` stays bound to the last molecule, which the profiling
# cells at the bottom of the notebook read.
progress = tqdm(mol_sup, total=len(mol_sup))
for mol in progress:
    tree.get_molecules_partial_charges(mol)

100%|██████████| 20/20 [00:01<00:00, 17.26it/s]


In [15]:
# Time one full pass of partial-charge assignment over every molecule in the supplier.
# NOTE(review): this cell's execution count (15) is higher than the cells that
# follow it, so the recorded timing comes from a later re-run.
%timeit [tree.get_molecules_partial_charges(mol) for mol in mol_sup]

336 ms ± 9.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
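# Timing log for one full pass over the example molecules (%timeit, mean ± std. dev.):
# timing baseline (04.09.2023)                                    891 ms ± 67.2 ms
# timing with AtomFeature changed to tuples                       874 ms ± 210 ms
# timing with precalculated atom features                         515 ms ± 184 ms
# timing with stored neighbor list                                TODO: not recorded (latest %timeit output above: 336 ms ± 9.67 ms — confirm it belongs here)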

In [9]:
# Collect the atom count of every molecule in the supplier, then report the
# mean, minimum and maximum in a single line.
num_atoms_per_mol = [molecule.GetNumAtoms() for molecule in mol_sup]
print(
    f"Average number of atoms per molecule: {sum(num_atoms_per_mol)/len(num_atoms_per_mol)} (min: {min(num_atoms_per_mol)}, max: {max(num_atoms_per_mol)})"
)

Average number of atoms per molecule: 21.55 (min: 7, max: 31)


In [10]:
# Display the stored data table at index 1; per the output its columns are
# level, atom, con, conType, result, stdDeviation, attention and count.
tree.data_storage[1]

Unnamed: 0,level,atom,con,conType,result,stdDeviation,attention,count
0,1,1,-1,-1,0.625488,0.012901,0.596191,10
1,2,70,0,4,0.625488,0.012901,0.379883,10
2,3,65,0,4,0.625488,0.012901,0.170898,10
3,4,70,1,4,0.625488,0.012901,0.003141,10
4,5,55,2,4,0.625488,0.012901,0.002874,10
5,6,33,1,1,0.625488,0.012901,0.00242,10
6,7,1,2,4,0.625488,0.012901,0.002121,10
7,8,34,5,1,0.625488,0.012901,0.002106,10
8,9,33,3,1,0.625488,0.012901,0.002008,10
9,10,34,8,1,0.625488,0.012901,0.001986,10


In [11]:
# Display the node list stored at the same index 1. In the output, fields 2-4
# of each tuple match the atom / con / conType columns of tree.data_storage[1]
# and field 5 is close to its attention column.
# NOTE(review): the trailing list presumably holds child node indices — confirm.
tree.tree_storage[1]

[(0, 1, -1, -1, 0.5960462093353271, [0, 1]),
 (1, 70, 0, 4, 0.379802405834198, [2]),
 (2, 65, 0, 4, 0.17085446417331696, [3]),
 (3, 70, 1, 4, 0.003141123801469803, [4]),
 (4, 55, 2, 4, 0.002873738994821906, [5]),
 (5, 33, 1, 1, 0.002419679891318083, [6]),
 (6, 1, 2, 4, 0.00212140497751534, [7]),
 (7, 34, 5, 1, 0.0021055617835372686, [8]),
 (8, 33, 3, 1, 0.002008170122280717, [9]),
 (9, 34, 8, 1, 0.0019848933443427086, [10]),
 (10, 31, 6, 1, 0.0019148633582517505, [11]),
 (11, 34, 10, 1, 0.002032126300036907, [12]),
 (12, 34, 10, 1, 0.00198705168440938, [13]),
 (13, 34, 10, 1, 0.0019818239379674196, [14, 15]),
 (14, 27, 4, 4, 0.0018974484410136938, []),
 (15, 26, 4, 4, 0.0018944578478112817, [])]

In [12]:
# Keep a handle on the molecule at index 4 (the same index used for the
# partial-charge example earlier) for the neighbour inspection below.
test_mol = mol_sup[4]

In [13]:
# List the indices of the atoms bonded to atom 4 of test_mol, one per line.
center_atom = test_mol.GetAtomWithIdx(4)
for n in center_atom.GetNeighbors():
    print(n.GetIdx())

3
5
12


In [14]:
# Deliberate barrier: raising here halts "Run All" at this point.
# NOTE(review): presumably intended to keep the profiling cells below manual-only — confirm.
raise Exception("Stop here")

Exception: Stop here

In [None]:
%load_ext line_profiler
%lprun -f tree.get_molecules_partial_charges tree.get_molecules_partial_charges(mol)

In [None]:
%lprun -f tree.match_new_atom tree.match_new_atom(0, mol)