In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
import IPython.display
from collections import defaultdict
from tqdm import tqdm

from serenityff.charge.tree.dash_tree import DASHTree

In [None]:
# Molecules used for the example lookups further down; explicit H atoms are kept.
dftd4_sdf_path = "./mols_comb_dftd4.sdf"
mol_sup = Chem.SDMolSupplier(
    dftd4_sdf_path,
    removeHs=False,  # keep explicit hydrogens so atom indices follow the file's atom order
)

## Combine trees: add AM1BCC and DFTD4 property columns to the base tree

In [None]:
# Source tree whose per-node tables receive the extra property columns.
base_tree_folder = "./props_tree/"
base_tree = DASHTree(tree_folder_path=base_tree_folder)

In [None]:
# Target tree for the combined data; the following cells fill its data_storage
# from base_tree plus the extra property tables.
# NOTE(review): DASHTree() is built with default arguments and tree_folder_path is
# assigned afterwards instead of being passed to the constructor — presumably so
# nothing is loaded from the (not yet written) output folder; confirm.
# NOTE(review): only data_storage is replaced later, so whatever tree structure the
# default constructor sets up is kept — confirm it matches ./props_tree/.
new_tree = DASHTree()
new_tree.tree_folder_path = "./dashProps_tree/"

In [None]:
# Extra property tables, one HDF5 file per tree key, produced by separate runs.
C6_DIR = "./test_143_c6"
C6_COLUMNS = ["DFTD4:C6", "DFTD4:C6_std", "DFTD4:polarizability", "DFTD4:polarizability_std"]
AM1BCC_DIR = "./test_145_am1bcc"
AM1BCC_COLUMNS = ["AM1BCC", "AM1BCC_std"]


def _read_extra_columns(folder, key, columns, base_index):
    """Load ``columns`` from ``<folder>/<key>.h5`` and check its rows match ``base_index``.

    Args:
        folder: Directory holding one HDF5 file per tree key.
        key: Tree key; also used as the file stem.
        columns: Column names to keep from the stored frame.
        base_index: Index of the base-tree table the columns will be attached to.

    Returns:
        DataFrame with only the selected columns.

    Raises:
        ValueError: if the row labels differ from ``base_index``. ``pd.concat(axis=1)``
            aligns on the index, so a mismatch would otherwise silently yield
            NaN-padded / misaligned rows instead of an error.
    """
    extra = pd.read_hdf(f"{folder}/{key}.h5", key="df", mode="r")[columns]
    same_rows = len(extra) == len(base_index) and extra.index.isin(base_index).all()
    if not same_rows:
        raise ValueError(
            f"{folder}/{key}.h5 has {len(extra)} rows whose labels do not match "
            f"the {len(base_index)} rows of the base tree data for key {key!r}"
        )
    return extra


for key in tqdm(base_tree.data_storage.keys()):
    base_df = base_tree.data_storage[key]
    df_c6 = _read_extra_columns(C6_DIR, key, C6_COLUMNS, base_df.index)
    df_am1bcc = _read_extra_columns(AM1BCC_DIR, key, AM1BCC_COLUMNS, base_df.index)
    # verify_integrity=True raises if a column name occurs in more than one frame
    # instead of silently producing duplicate columns.
    df_all = pd.concat([base_df, df_am1bcc, df_c6], axis=1, verify_integrity=True)
    # Deep copy so the combined tree never shares data with base_tree.
    new_tree.data_storage[key] = df_all.copy(deep=True)

In [None]:
# Old column name -> new column name, applied to every node table of the combined tree.
rename_columns = {
    "atom": "atom_type",
    "con": "con_atom",
    "conType": "con_type",
    "stdDeviation": "std",
    "attention": "max_attention",
    "count": "size",
}

In [None]:
# Apply the rename_columns mapping to every node table of the combined tree.
# A renamed copy is stored instead of renaming in place, so no frame object that
# may be referenced elsewhere is mutated. Keys are snapshotted before the loop
# because each entry is replaced while iterating.
for key in tqdm(list(new_tree.data_storage.keys())):
    new_tree.data_storage[key] = new_tree.data_storage[key].rename(columns=rename_columns)

In [None]:
# Persist the combined tree and its property tables.
# NOTE(review): presumably written to new_tree.tree_folder_path ("./dashProps_tree/") — confirm.
new_tree.save_all_trees_and_data()

## Example property lookups

In [None]:
# Small test molecule: phenol, with explicit hydrogens added.
mol = Chem.AddHs(Chem.MolFromSmiles("c1ccccc1O"))

In [None]:
# Bare last expression: Jupyter displays the hydrogen-augmented phenol molecule.
mol

In [None]:
# Match atom 1 of `mol` against the combined tree; also request the atom indices
# that took part in the match, then show both results.
npath, match_indices = new_tree.match_new_atom(
    mol=mol,
    atom=1,
    return_atom_indices=True,
)
for shown in (npath, match_indices):
    print(shown)

In [None]:
# Look up the stored properties of the tree node matched for atom 1.
matched_path = new_tree.match_new_atom(mol=mol, atom=1)
new_tree.get_atom_properties(matched_node_path=matched_path)

In [None]:
# "DFTD4:C6" property for atom index 1 of `mol`, looked up in the combined tree.
# NOTE(review): the method name suggests NaN entries are skipped during lookup — confirm in DASHTree.
new_tree.get_property_noNAN(mol=mol, atom=1, property_name="DFTD4:C6")

In [None]:
# Example molecule taken from the DFTD4 SDF file (index picked by hand).
EXAMPLE_MOL_INDEX = 2857
mol = mol_sup[EXAMPLE_MOL_INDEX]
# SDMolSupplier yields None for records RDKit cannot parse/sanitize; fail here with a
# clear message instead of an AttributeError in one of the following cells.
if mol is None:
    raise ValueError(f"Could not parse molecule {EXAMPLE_MOL_INDEX} from {dftd4_sdf_path}")

In [None]:
# Bare last expression: Jupyter displays the molecule loaded from the SDF file.
mol

In [None]:
# Molecule-level properties stored on `mol` (e.g. SD tags read from the file) as a dict.
mol.GetPropsAsDict()

In [None]:
# Partial charges from the combined tree for `mol` (the 'charges' entry of the result).
# chg_std_key="std" matches the "stdDeviation" -> "std" column rename applied earlier.
new_tree.get_molecules_partial_charges(mol=mol, chg_std_key="std")['charges']

In [None]:
# "DFTD4:C6" property for every atom of `mol`, in atom-index order.
mol_c6 = [
    new_tree.get_property_noNAN(mol=mol, atom=atom_idx, property_name="DFTD4:C6")
    for atom_idx in range(mol.GetNumAtoms())
]
mol_c6