In [21]:
from rdkit import Chem
from tqdm import tqdm
from serenityff.charge.tree.dash_tree import DASHTree
from serenityff.charge.tree.atom_features import AtomFeatures, get_connection_info_bond_type
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [2]:
tree = DASHTree(num_processes=1)
# n=1 19.5
# n=2 30.6
# n=3 31.6

Loading DASH tree data


In [3]:
mol_sup = Chem.SDMolSupplier("../data/example.sdf", removeHs=False)

In [4]:
tree.match_new_atom(0, mol_sup[0])

[34, 0, 167331, 167332, 167333, 169504]

In [5]:
tree.get_molecules_partial_charges(mol_sup[4])["charges"]

[-0.743588904134849,
 1.1309775445707058,
 -0.6724881843055797,
 0.2940594235400035,
 -0.23680232625932834,
 -0.16523398145988805,
 -0.23680232625932834,
 0.33431095702447067,
 -0.6338138301045644,
 -0.004846247250981643,
 -0.48499433858143876,
 0.4956579075462079,
 0.1489313589438346,
 0.1489313589438346,
 0.1489313589438346,
 0.47677022884306663]

In [6]:
for mol in tqdm(mol_sup, total=len(mol_sup)):
    tree.get_molecules_partial_charges(mol)

 60%|██████    | 12/20 [00:00<00:00, 30.51it/s]

100%|██████████| 20/20 [00:00<00:00, 26.11it/s]


In [7]:
%timeit [tree.get_molecules_partial_charges(mol) for mol in mol_sup]

374 ms ± 103 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# timing baseline (04.09.2023)                                    991 ms ± 67.2 
# timing with AtomFeature changed to tuples                       874 ms ± 210 ms
# timing with precalculated atom features                         515 ms ± 184 ms
# timing with stored neighbor list                                359 ms ± 55.3 ms

In [9]:
tree.get_molecules_partial_charges(mol_sup[4])

{'charges': [-0.743588904134849,
  1.1309775445707058,
  -0.6724881843055797,
  0.2940594235400035,
  -0.23680232625932834,
  -0.16523398145988805,
  -0.23680232625932834,
  0.33431095702447067,
  -0.6338138301045644,
  -0.004846247250981643,
  -0.48499433858143876,
  0.4956579075462079,
  0.1489313589438346,
  0.1489313589438346,
  0.1489313589438346,
  0.47677022884306663],
 'std': [0.01751708984375,
  0.0034885406494140625,
  0.0262298583984375,
  0.09539794921875,
  0.083984375,
  0.043304443359375,
  0.083984375,
  0.046722412109375,
  0.10308837890625,
  0.0889892578125,
  0.022064208984375,
  0.0082550048828125,
  0.01739501953125,
  0.01739501953125,
  0.01739501953125,
  0.0116729736328125],
 'match_depth': [4, 4, 7, 4, 3, 3, 3, 4, 2, 4, 2, 8, 4, 4, 4, 3]}

In [10]:
num_atoms_per_mol = [mol.GetNumAtoms() for mol in mol_sup]
print(f"Average number of atoms per molecule: {sum(num_atoms_per_mol)/len(num_atoms_per_mol)} (min: {min(num_atoms_per_mol)}, max: {max(num_atoms_per_mol)})")

Average number of atoms per molecule: 21.55 (min: 7, max: 31)


In [11]:
tree.data_storage[1]

Unnamed: 0,level,atom,con,conType,result,stdDeviation,attention,count
0,1,1,-1,-1,0.625488,0.012901,0.596191,10
1,2,70,0,4,0.625488,0.012901,0.379883,10
2,3,65,0,4,0.625488,0.012901,0.170898,10
3,4,70,1,4,0.625488,0.012901,0.003141,10
4,5,55,2,4,0.625488,0.012901,0.002874,10
5,6,33,1,1,0.625488,0.012901,0.00242,10
6,7,1,2,4,0.625488,0.012901,0.002121,10
7,8,34,5,1,0.625488,0.012901,0.002106,10
8,9,33,3,1,0.625488,0.012901,0.002008,10
9,10,34,8,1,0.625488,0.012901,0.001986,10


In [12]:
tree.tree_storage[1]

[(0, 1, -1, -1, 0.5960462093353271, [0, 1]),
 (1, 70, 0, 4, 0.379802405834198, [2]),
 (2, 65, 0, 4, 0.17085446417331696, [3]),
 (3, 70, 1, 4, 0.003141123801469803, [4]),
 (4, 55, 2, 4, 0.002873738994821906, [5]),
 (5, 33, 1, 1, 0.002419679891318083, [6]),
 (6, 1, 2, 4, 0.00212140497751534, [7]),
 (7, 34, 5, 1, 0.0021055617835372686, [8]),
 (8, 33, 3, 1, 0.002008170122280717, [9]),
 (9, 34, 8, 1, 0.0019848933443427086, [10]),
 (10, 31, 6, 1, 0.0019148633582517505, [11]),
 (11, 34, 10, 1, 0.002032126300036907, [12]),
 (12, 34, 10, 1, 0.00198705168440938, [13]),
 (13, 34, 10, 1, 0.0019818239379674196, [14, 15]),
 (14, 27, 4, 4, 0.0018974484410136938, []),
 (15, 26, 4, 4, 0.0018944578478112817, [])]

In [13]:
test_mol = mol_sup[4]

In [14]:
for n in test_mol.GetAtomWithIdx(4).GetNeighbors():
    print(n.GetIdx())

3
5
12


In [15]:
%load_ext line_profiler
%lprun -f tree.get_molecules_partial_charges tree.get_molecules_partial_charges(mol)

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


Timer unit: 1e-09 s

Total time: 0.049371 s
File: /home/mlehner/serenityff-charge/serenityff/charge/tree/dash_tree.py
Function: get_molecules_partial_charges at line 248

Line #      Hits         Time  Per Hit   % Time  Line Contents
   248                                               def get_molecules_partial_charges(
   249                                                   self,
   250                                                   mol: Molecule,
   251                                                   norm_method: str = "std_weighted",
   252                                                   max_depth: int = 16,
   253                                                   attention_threshold: float = 10,
   254                                                   attention_incremet_threshold: float = 0,
   255                                                   verbose: bool = False,
   256                                                   default_std_value: float = 0.1,
   257          

In [16]:
%lprun -f tree.match_new_atom tree.match_new_atom(0, mol)

Timer unit: 1e-09 s

Total time: 0.00109272 s
File: /home/mlehner/serenityff-charge/serenityff/charge/tree/dash_tree.py
Function: match_new_atom at line 155

Line #      Hits         Time  Per Hit   % Time  Line Contents
   155                                               def match_new_atom(
   156                                                   self,
   157                                                   atom: int,
   158                                                   mol: Molecule,
   159                                                   max_depth: int = 16,
   160                                                   attention_threshold: float = 10,
   161                                                   attention_incremet_threshold: float = 0,
   162                                               ):
   163                                                   """
   164                                                   Match a atom in a molecule to a DASH tree subgraph. The matchin

In [17]:
%lprun -f tree._init_neighbor_dict tree._init_neighbor_dict(mol)

Timer unit: 1e-09 s

Total time: 0.00326334 s
File: /home/mlehner/serenityff-charge/serenityff/charge/tree/dash_tree.py
Function: _init_neighbor_dict at line 119

Line #      Hits         Time  Per Hit   % Time  Line Contents
   119                                               def _init_neighbor_dict(self, mol: Molecule):
   120         1        960.0    960.0      0.0          neighbor_dict = {}
   121        27     160888.0   5958.8      4.9          for atom_idx, atom in enumerate(mol.GetAtoms()):
   122        27      38958.0   1442.9      1.2              neighbor_dict[atom_idx] = []
   123        52     107340.0   2064.2      3.3              for neighbor in atom.GetNeighbors():
   124        52    1653932.0  31806.4     50.7                  af_with_connection_info = AtomFeatures.atom_features_from_molecule_w_connection_info(
   125        52    1183422.0  22758.1     36.3                      mol, neighbor.GetIdx(), (0, atom_idx)
   126                                         

In [19]:
%lprun -f AtomFeatures.atom_features_from_molecule_w_connection_info AtomFeatures.atom_features_from_molecule_w_connection_info(mol, 1, (0, 0))

Timer unit: 1e-09 s

Total time: 7.6661e-05 s
File: /home/mlehner/serenityff-charge/serenityff/charge/tree/atom_features.py
Function: atom_features_from_molecule_w_connection_info at line 217

Line #      Hits         Time  Per Hit   % Time  Line Contents
   217                                               @staticmethod
   218                                               def atom_features_from_molecule_w_connection_info(
   219                                                   molecule: Molecule, index: int, connected_to: Tuple[Any] = (-1, -1)
   220                                               ) -> int:
   221         1        178.0    178.0      0.2          connected_bond_type = (
   222         1      44666.0  44666.0     58.3              -1 if connected_to[1] == -1 else get_connection_info_bond_type(molecule, int(index), int(connected_to[1]))
   223                                                   )
   224         1      31292.0  31292.0     40.8          key = AtomFeatures.a

In [23]:
%lprun -f get_connection_info_bond_type get_connection_info_bond_type(mol, 1, 0)


Timer unit: 1e-09 s

Total time: 4.021e-05 s
File: /home/mlehner/serenityff-charge/serenityff/charge/tree/atom_features.py
Function: get_connection_info_bond_type at line 15

Line #      Hits         Time  Per Hit   % Time  Line Contents
    15                                           def get_connection_info_bond_type(molecule: Molecule, index: int, connected_to: int) -> int:
    16         1      27801.0  27801.0     69.1      bond = molecule.GetBondBetweenAtoms(int(index), int(connected_to))
    17         1        325.0    325.0      0.8      if bond is None:
    18                                                   return -1
    19         1       3979.0   3979.0      9.9      elif bond.GetIsConjugated():
    20                                                   return 4
    21                                               else:
    22         1       8105.0   8105.0     20.2          return int(bond.GetBondType())