Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/feat/load data from archive #8

Merged
merged 38 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
fb185f5
use correct type annotation (Path not str)
nmaeder Jun 14, 2024
11b7fad
add git keep for data to be downloaded
nmaeder Jun 14, 2024
bf6a0fa
initial commit
nmaeder Jun 14, 2024
6d82fa5
add function skeletons
nmaeder Jun 14, 2024
8b351dc
more functionality added
nmaeder Jun 14, 2024
631ab23
more exceptions for data setup
nmaeder Jun 14, 2024
c3c2835
formatting
nmaeder Jun 14, 2024
3a2cf04
add functionality and tests
nmaeder Jun 14, 2024
a1eee7d
initial commit
nmaeder Jun 14, 2024
338e70d
added separate tests directory with retrieve data tests and testfiles
nmaeder Jun 17, 2024
6489db3
add tree types and tree factory
nmaeder Jun 17, 2024
4b4b9df
formatting
nmaeder Jun 17, 2024
377e12b
open for first test by @MTLehner
nmaeder Jun 18, 2024
89b4f77
cleanup paths and fix tests
nmaeder Jun 18, 2024
e0a8f8c
remove data fetch from setup
nmaeder Jun 18, 2024
01b7ada
remove unused import
nmaeder Jun 18, 2024
7ae4cab
fixed tree import and tree-file download
MTLehner Jun 18, 2024
9cb902f
fix python 3.9 incompatible syntax
MTLehner Jun 18, 2024
6f9c391
update README
MTLehner Jun 18, 2024
f5b0cfe
fix pre-commit
MTLehner Jun 18, 2024
0b25c77
fix typo
MTLehner Jun 18, 2024
1018336
fix table naming
MTLehner Jun 18, 2024
13dc907
remove last parts of versioneer
MTLehner Jun 18, 2024
5d78517
remove last-last parts of versioneer
MTLehner Jun 18, 2024
ebceaea
thoroughly test retrieve_data
nmaeder Jun 19, 2024
14e508e
remove pytorch channels
nmaeder Jun 19, 2024
f1e803a
run new tests too in ci
nmaeder Jun 19, 2024
6115a55
change paths
nmaeder Jun 19, 2024
a4d21d9
add utils tests
nmaeder Jun 19, 2024
d7fddc1
add utils tests
nmaeder Jun 19, 2024
6bbde08
run it on ci
nmaeder Jun 19, 2024
47c25c8
formatting
nmaeder Jun 20, 2024
60ab6e6
first test
nmaeder Jun 20, 2024
735a1e4
fix CI skipif condition
nmaeder Jun 21, 2024
be7b102
add tree_factory tests
nmaeder Jun 21, 2024
9472c36
dont use setup.py for installing
nmaeder Jun 21, 2024
25f8b0c
also add testfiles
nmaeder Jun 21, 2024
4045261
pre-commit run
nmaeder Jun 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
176 changes: 147 additions & 29 deletions serenityff/charge/tree/dash_tree.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,113 @@
import gzip
import io
import os
import pickle
import gzip
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm

# from multiprocessing import Process, Manager
# from numba import njit, objmode, types

# from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

import io
from serenityff.charge.tree.retrieve_data import (
DATA_URL,
data_is_complete,
get_additional_data,
)
from serenityff.charge.utils.exceptions import DataIncompleteError

try:
import IPython.display
except ImportError:

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass
from collections import defaultdict
from enum import Enum, auto

from PIL import Image
from rdkit.Chem.Draw import rdMolDraw2D
from collections import defaultdict

from serenityff.charge.tree.atom_features import AtomFeatures
from serenityff.charge.data import default_dash_tree_path
from serenityff.charge.utils.rdkit_typing import Molecule
from serenityff.charge.tree.atom_features import AtomFeatures
from serenityff.charge.tree.dash_tools import (
init_neighbor_dict,
new_neighbors,
new_neighbors_atomic,
init_neighbor_dict,
)
from serenityff.charge.utils.rdkit_typing import Molecule


class TreeType(Enum):
    """Selector for the set of data columns loaded for a DASH tree.

    Each member is a key of ``COLUMNS``, which lists the column indices
    that are read from the data files for that tree type.
    """

    DEFAULT = auto()
    AM1BCC = auto()
    RESP = auto()
    MULLIKEN = auto()
    CHARGES = auto()
    DUALDESCRIPTORS = auto()
    C6 = auto()
    POLARIZABILITY = auto()
    DFTD4 = auto()
    DIPOLE = auto()
    FULL = auto()


# Column index -> column name in the stored data frames.
COLUMN_DICT = {
    0: "level",
    1: "atom_type",
    2: "con_atom",
    3: "con_type",
    4: "max_attention",
    5: "size",
    6: "result",
    7: "std",
    8: "mulliken",
    9: "resp1",
    10: "resp2",
    11: "AM1BCC",
    12: "AM1BCC_std",
    13: "dual",
    14: "mbis_dipole_strength",
    15: "dipole_bond_1",
    16: "dipole_bond_2",
    17: "dipole_bond_3",
    18: "DFTD4:C6",
    19: "DFTD4:C6_std",
    20: "DFTD4:polarizability",
    21: "DFTD4:polarizability_std",
}


# Tree type -> indices (keys of COLUMN_DICT) of the columns to load.
# Indices 0-5 are shared by every tree type; the remaining indices select
# the property-specific columns.
COLUMNS = {
    TreeType.DEFAULT: [0, 1, 2, 3, 4, 5, 6, 7],
    TreeType.MULLIKEN: [0, 1, 2, 3, 4, 5, 7, 8],
    TreeType.RESP: [0, 1, 2, 3, 4, 5, 7, 9, 10],
    TreeType.AM1BCC: [0, 1, 2, 3, 4, 5, 11, 12],
    TreeType.CHARGES: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    TreeType.DUALDESCRIPTORS: [0, 1, 2, 3, 4, 5, 7, 13],
    TreeType.C6: [0, 1, 2, 3, 4, 5, 18, 19],
    TreeType.POLARIZABILITY: [0, 1, 2, 3, 4, 5, 20, 21],
    # Previously missing, so COLUMNS[TreeType.DFTD4] raised KeyError.
    # NOTE(review): assumed to be the C6 and polarizability columns
    # combined -- confirm against the data files.
    TreeType.DFTD4: [0, 1, 2, 3, 4, 5, 18, 19, 20, 21],
    # 14 is "mbis_dipole_strength"; the original listed 4 a second time here,
    # which duplicated "max_attention" and never loaded the dipole strength.
    TreeType.DIPOLE: [0, 1, 2, 3, 4, 5, 7, 14, 15, 16, 17],
    TreeType.FULL: list(COLUMN_DICT),
}


class DASHTree:
def __init__(
self,
tree_folder_path: str = default_dash_tree_path,
tree_folder_path: Path = default_dash_tree_path,
preload: bool = True,
verbose: bool = True,
num_processes: int = 1,
default_value_column: str = "result",
default_std_column: str = "std",
tree_type: TreeType = TreeType.DEFAULT,
) -> None:
"""
Class to handle DASH trees and data

Parameters
----------
tree_folder_path : str
tree_folder_path : Path
Path to folder containing DASH trees and data
preload : bool
If True, load all trees and data into memory, if False, load on demand
Expand All @@ -61,8 +123,26 @@ def __init__(
self.tree_storage = {}
self.data_storage = {}
self.atom_feature_type = AtomFeatures
self.default_value_column = default_value_column
self.default_std_column = default_std_column
if preload:
self.load_all_trees_and_data()
self.load_all_trees_and_data(tree_type=tree_type)

@property
def default_value_column(self) -> str:
return self._default_value_column

@property
def default_std_column(self) -> str:
return self._default_std_column

@default_value_column.setter
def default_value_column(self, value: str) -> None:
self._default_value_column = value

@default_std_column.setter
def default_std_column(self, value: str) -> None:
self._default_std_column = value

########################################
# Tree import/export functions
Expand All @@ -71,7 +151,10 @@ def __init__(
# tree file format:
# int(id_counter), int(atom_type), int(con_atom), int(con_type), float(oldNode.attention), []children

def load_all_trees_and_data(self) -> None:
def load_all_trees_and_data(
self,
tree_type: TreeType = TreeType.DEFAULT,
) -> None:
"""
Load all trees and data from the tree_folder_path, expects files named after the atom feature key and
the file extension .gz for the tree and .h5 for the data
Expand All @@ -82,10 +165,25 @@ def load_all_trees_and_data(self) -> None:
print("Loading DASH tree data")
# # import all files
# if True: # self.num_processes <= 1:
success = data_is_complete(self.tree_folder_path)
if not success:
success = get_additional_data()
if not success:
raise DataIncompleteError(
"The Additional Data was not downloaded and extracted Correctly. "
"Please make sure to be connected to the internet for the DASH tree to collect all the "
f"additional data from the ETHZ Research Collection {DATA_URL}."
)

for i in range(self.atom_feature_type.get_number_of_features()):
tree_path = os.path.join(self.tree_folder_path, f"{i}.gz")
df_path = os.path.join(self.tree_folder_path, f"{i}.h5")
self.load_tree_and_data(tree_path, df_path, branch_idx=i)
self.load_tree_and_data(
tree_path,
df_path,
branch_idx=i,
tree_type=tree_type,
)
# else:
# Threads don't seem to work due to HDFstore key error
# with ThreadPoolExecutor(max_workers=self.num_processes) as executor:
Expand All @@ -96,7 +194,14 @@ def load_all_trees_and_data(self) -> None:
if self.verbose:
print(f"Loaded {len(self.tree_storage)} trees and data")

def load_tree_and_data(self, tree_path: str, df_path: str, hdf_key: str = "df", branch_idx: int = None) -> None:
def load_tree_and_data(
self,
tree_path: str,
df_path: str,
tree_type: TreeType = TreeType.DEFAULT,
hdf_key: str = "df",
branch_idx: int = None,
) -> None:
"""
Load a tree and data from the tree_folder_path, expects files named after the atom feature key and
the file extension .gz for the tree and .h5 for the data
Expand All @@ -118,7 +223,8 @@ def load_tree_and_data(self, tree_path: str, df_path: str, hdf_key: str = "df",
branch_idx = int(os.path.basename(tree_path).split(".")[0])
with gzip.open(tree_path, "rb") as f:
tree = pickle.load(f)
df = pd.read_hdf(df_path, key=hdf_key, mode="r")
columns = [COLUMN_DICT[v] for v in COLUMNS[tree_type]]
df = pd.read_hdf(df_path, key=hdf_key, mode="r", columns=columns)
self.tree_storage[branch_idx] = tree
self.data_storage[branch_idx] = df

Expand Down Expand Up @@ -219,9 +325,12 @@ def match_new_atom(
neighbor_dict = init_neighbor_dict(mol, af=self.atom_feature_type)

# get layer 0, and init all relevant containers
branch_idx, matched_node_path, atom_indices_in_subgraph, max_depth = self._get_init_layer(
mol=mol, atom=atom, max_depth=max_depth
)
(
branch_idx,
matched_node_path,
atom_indices_in_subgraph,
max_depth,
) = self._get_init_layer(mol=mol, atom=atom, max_depth=max_depth)

# if data for branch is not preloaded, load it now
if branch_idx not in self.tree_storage:
Expand Down Expand Up @@ -251,9 +360,10 @@ def match_new_atom(
atom_indices_in_subgraph.append(atom)
node_attention = self.tree_storage[branch_idx][child][4]
cummulative_attention += node_attention
possible_new_atom_features_toAdd, possible_new_atom_idxs_toAdd = new_neighbors_atomic(
neighbor_dict, atom_indices_in_subgraph, atom
)
(
possible_new_atom_features_toAdd,
possible_new_atom_idxs_toAdd,
) = new_neighbors_atomic(neighbor_dict, atom_indices_in_subgraph, atom)
possible_new_atom_features.extend(possible_new_atom_features_toAdd)
possible_new_atom_idxs.extend(possible_new_atom_idxs_toAdd)
if cummulative_attention > attention_threshold:
Expand Down Expand Up @@ -399,8 +509,8 @@ def get_molecules_partial_charges(
attention_incremet_threshold: float = 0,
verbose: bool = False,
default_std_value: float = 0.1,
chg_key: str = "result",
chg_std_key: str = "stdDeviation",
chg_key: str | None = None,
chg_std_key: str | None = None,
nodePathList=None,
):
"""
Expand Down Expand Up @@ -433,6 +543,10 @@ def get_molecules_partial_charges(
dict
Dictionary containing the partial charges, standard deviations and match depths of all atoms
"""
if chg_key is None:
chg_key = self.default_value_column
if chg_std_key is None:
chg_std_key = self.default_std_column
return_list = []
tree_raw_charges = []
tree_charge_std = []
Expand Down Expand Up @@ -588,8 +702,8 @@ def get_molecular_dipole_moment(
self,
mol: Molecule,
inDebye: bool = True,
chg_key: str = "result",
chg_std_key: str = "std",
chg_key: str | None = None,
chg_std_key: str | None = None,
sngl_cnf=True,
nconfs=10,
pruneRmsThresh=0.5,
Expand All @@ -602,6 +716,10 @@ def get_molecular_dipole_moment(
Get the dipole moment of a molecule by matching all atoms to DASH tree subgraphs and
summing the dipole moments of the matched atoms
"""
if chg_key is None:
chg_key = self.default_value_column
if chg_std_key is None:
chg_std_key = self.default_std_column
chgs = self.get_molecules_partial_charges(
mol=mol,
norm_method="std_weighted",
Expand Down
Loading
Loading