In [15]:
import itertools
import os
from enum import Enum
from math import ceil
import re

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.signal
from pandarallel import pandarallel
from scipy.spatial.distance import cdist
from tqdm import tqdm
import matplotlib.markers
from dataclasses import dataclass
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F

# Initialise pandarallel with progress bars enabled; the `.parallel_apply`
# call further down in this notebook depends on this having run first.
pandarallel.initialize(progress_bar=True)

# Location of the HDF5 dataset; the DATA_PATH environment variable overrides the default.
data_path = os.environ.get("DATA_PATH", "../data/raw/ANI-1ccx_clean_fullentry.h5")

# Element symbol -> atomic number for the elements handled by this notebook.
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8}

# Every unordered pair of atomic numbers (including same-element pairs),
# each pair stored in ascending order.
ATOMIC_PAIRS = list(itertools.combinations_with_replacement([1, 6, 7, 8], 2))

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [24]:
def load_h5_dataset_compact(path, show_progress=True, targets=None):
    """Load an HDF5 conformer dataset into dense arrays grouped by atom count.

    Every top-level entry of the file is keyed by an empirical formula and must
    contain a ``"coordinates"`` dataset of shape
    ``(num_conformers, num_atoms, 3)``. All conformers whose molecules have the
    same number of atoms are stacked into one array.

    Parameters
    ----------
    path : str
        Path of the HDF5 file to read.
    show_progress : bool, default True
        Show tqdm progress bars while scanning and loading the file.
    targets : str | list[str] | None
        Name(s) of per-conformer datasets to load next to the coordinates.
        ``None`` loads coordinates only.

    Returns
    -------
    mol_counter : dict[int, dict[str, int]]
        ``{num_atoms: {formula: num_conformers}}``. The outer dict is ordered by
        ascending atom count; each inner dict is ordered by formula in
        *descending* lexicographic order. Rows of the returned arrays follow
        exactly this order.
    coordinates : dict[int, np.ndarray]
        ``{num_atoms: float32 array of shape (total_conformers, num_atoms, 3)}``.
    target_vals : dict[int, np.ndarray]
        ``{num_atoms: float32 array of shape (total_conformers, len(targets))}``,
        or an empty dict when no targets were requested.
    """
    if targets is None:
        targets = []
    elif isinstance(targets, str):
        targets = [targets]

    mol_counter = defaultdict(dict)
    n_atom_counter = defaultdict(int)

    with h5py.File(path, "r") as f:
        # Pass 1: scout molecule sizes and conformer counts so the output
        # arrays can be allocated once instead of grown incrementally.
        for empirical_formula, entry in tqdm(
            f.items(), desc="Scanning molecules", disable=not show_progress
        ):
            num_conformers, num_atoms, _ = entry["coordinates"].shape
            n_atom_counter[num_atoms] += num_conformers
            mol_counter[num_atoms][empirical_formula] = num_conformers

        # Order by number of atoms, then by formula in reverse lexicographic
        # order; this ordering defines the row layout of the arrays below.
        n_atom_counter = dict(sorted(n_atom_counter.items(), key=lambda item: item[0]))
        mol_counter = {
            num_atoms: dict(sorted(counts.items(), key=lambda item: item[0], reverse=True))
            for num_atoms, counts in sorted(mol_counter.items(), key=lambda item: item[0])
        }

        # Preallocate the coordinate and target arrays.
        coordinates = {
            n: np.empty((count, n, 3), dtype=np.float32) for n, count in n_atom_counter.items()
        }
        target_vals = {}
        if len(targets):
            target_vals = {
                n: np.empty((count, len(targets)), dtype=np.float32)
                for n, count in n_atom_counter.items()
            }

        # Pass 2: copy each molecule's conformers into its slice.
        n_molecules = sum(len(counts) for counts in mol_counter.values())
        with tqdm(total=n_molecules, desc="Loading conformers", disable=not show_progress) as progress:
            for n_atoms, counter in mol_counter.items():
                start = 0
                for mol, n_conformers in counter.items():
                    stop = start + n_conformers
                    coordinates[n_atoms][start:stop] = f[mol]["coordinates"]
                    for i, target in enumerate(targets):
                        target_vals[n_atoms][start:stop, i] = f[mol][target]
                    start = stop
                    progress.update(1)

    return mol_counter, coordinates, target_vals


def molecule_to_numbers(molecule: str) -> np.ndarray:
    """Expand an empirical formula into one atomic number per atom.

    Each element symbol is paired with the count that directly follows it; a
    symbol without a count stands for a single atom, so ``"CH4"`` and
    ``"C1H4"`` give the same result. Atoms appear in the order the elements
    are written in the formula.

    Parameters
    ----------
    molecule : str
        Empirical formula such as ``"C2H6"`` or ``"C2H6O"``.

    Returns
    -------
    np.ndarray
        1-D ``int8`` array with one atomic number per atom.

    Raises
    ------
    KeyError
        If a symbol in the formula is not a key of ``ATOMIC_NUMBERS``.
    """
    numbers = []
    # Capture symbol and count together so that a missing count cannot shift
    # the pairing between elements and counts.
    for symbol, count in re.findall(r"([A-Za-z][a-z]?)(\d*)", molecule):
        repeat = int(count) if count else 1
        numbers.extend([ATOMIC_NUMBERS[symbol]] * repeat)
    return np.array(numbers, dtype=np.int8)


def get_formula_range(formula, mol_counts):
    """Return the ``(start, end)`` row range occupied by ``formula``.

    ``mol_counts`` maps ``num_atoms -> {formula: num_conformers}``. The
    iteration order of each inner dict is taken to be the order in which the
    conformers are stacked, so the start offset is the total conformer count
    of every formula listed before ``formula``.

    Parameters
    ----------
    formula : str
        Empirical formula to look up.
    mol_counts : dict[int, dict[str, int]]
        Conformer counts per atom count and formula.

    Returns
    -------
    tuple[int, int]
        Half-open range ``[start, end)`` within the group for that atom count.

    Raises
    ------
    KeyError
        If the atom count or the formula is not present in ``mol_counts``.
    """
    atomic_numbers = molecule_to_numbers(formula)
    num_atoms = atomic_numbers.shape[-1]

    start = 0
    for other_formula, n_conformers in mol_counts[num_atoms].items():
        if other_formula == formula:
            return start, start + n_conformers
        # Conformers of earlier formulas precede this one.
        start += n_conformers
    raise KeyError(formula)
    

def torch_cdist(x, y=None):
    """Compute pairwise distances between two NumPy arrays with ``torch.cdist``.

    Both inputs are converted with ``torch.from_numpy``. When ``y`` is omitted
    the distances are computed between ``x`` and itself. Returns a
    ``torch.Tensor``.
    """
    left = torch.from_numpy(x)
    right = left if y is None else torch.from_numpy(y)
    return torch.cdist(left, right)

In [42]:
@dataclass
class Dataset:
    """Conformer dataset grouped by number of atoms.

    Thin wrapper around the three dictionaries returned by
    ``load_h5_dataset_compact``; all three are keyed by atom count.
    """

    # {num_atoms: {formula: num_conformers}}
    mol_counts: dict
    # {num_atoms: coordinate array}
    coordinates: dict
    # {num_atoms: target array}; empty when no targets were requested
    targets: dict

    def __init__(self, path, targets=None):
        """Load the dataset stored at ``path``, optionally with ``targets``."""
        self.mol_counts, self.coordinates, self.targets = load_h5_dataset_compact(path, targets=targets)

    def __getitem__(self, n):
        """Return ``(mol_counts, coordinates, targets)`` for molecules with ``n`` atoms.

        The third element is ``None`` when the dataset holds no targets for
        ``n`` (e.g. it was loaded without targets), so indexing does not fail
        in that case.
        """
        return self.mol_counts[n], self.coordinates[n], self.targets.get(n)

    def atomic_numbers(self, n=None):
        """Yield one atomic-number array per conformer, in storage order.

        With ``n=None`` every atom-count group is traversed in the order of
        ``mol_counts``; otherwise only the group with ``n`` atoms is used.
        """
        groups = self.mol_counts.values() if n is None else (self.mol_counts[n],)
        for counts in groups:
            for mol, count in counts.items():
                # Parse the formula once per molecule rather than once per
                # conformer; copies keep the yielded arrays independent.
                numbers = molecule_to_numbers(mol)
                for _ in range(count):
                    yield numbers.copy()

    def atomic_numbers_from_formula(self, formula):
        """Return the atomic numbers for ``formula`` via ``molecule_to_numbers``."""
        return molecule_to_numbers(formula)

In [43]:
# Load the dataset, requesting the two HDF5 datasets named below as targets.
dset = Dataset(
    data_path,
    targets=[
        "dftb.elec_energy",
        "ccsd(t)_cbs.energy",
    ],
)

In [44]:
# Pairwise atom-atom distance matrices for every conformer, keyed by atom count.
# Read the coordinates from `dset`: a bare `coordinates` name is not defined by
# any cell of this notebook and only worked through leftover kernel state.
distances = {n: torch_cdist(c).numpy() for n, c in dset.coordinates.items()}

In [66]:
def group_distances(distance_matrix, atomic_numbers):
    """Bucket the upper-triangle entries of a distance matrix by element pair.

    Returns a dict with one list per pair in ``ATOMIC_PAIRS``. Each entry
    ``distance_matrix[i, j]`` with ``i < j`` is appended to the list keyed by
    the ascending tuple of the two atoms' atomic numbers. A pair that is not in
    ``ATOMIC_PAIRS`` raises ``KeyError``.
    """
    grouped = {}
    for pair in ATOMIC_PAIRS:
        grouped[pair] = []

    rows, cols = np.triu_indices_from(distance_matrix, k=1)
    for row, col in zip(rows, cols):
        key = tuple(sorted((atomic_numbers[row], atomic_numbers[col])))
        grouped[key].append(distance_matrix[row, col])
    return grouped


def distances_from_coordinates(coordinates, atomic_numbers):
    """Compute pairwise distances from coordinates and group them by element pair.

    Parameters
    ----------
    coordinates : np.ndarray
        Either a single conformer (2-D array, one row per atom) or a batch of
        conformers (leading batch axis). Pass coordinates, not distance
        matrices: the distances are computed here.
    atomic_numbers : np.ndarray or iterable
        A 1-D array shared by all conformers, or one sequence of atomic
        numbers per conformer.

    Returns
    -------
    dict or list[dict]
        The ``group_distances`` result for each conformer. When exactly one
        conformer was processed the dict itself is returned rather than a
        one-element list.
    """
    if len(coordinates.shape) == 2:
        pairwise_distances = [cdist(coordinates, coordinates)]
    else:
        pairwise_distances = torch_cdist(coordinates).numpy()

    if isinstance(atomic_numbers, np.ndarray) and atomic_numbers.ndim == 1:
        # Share the same atomic numbers across all conformers. Repeat by the
        # number of distance matrices: for a single 2-D conformer
        # `len(coordinates)` would be the atom count, not the conformer count.
        atomic_numbers = [atomic_numbers] * len(pairwise_distances)

    all_distances = [
        group_distances(pairwise_distance, _atomic_numbers)
        for pairwise_distance, _atomic_numbers in zip(pairwise_distances, atomic_numbers)
    ]

    if len(all_distances) == 1:
        return all_distances[0]
    return all_distances

In [67]:
import multiprocessing

# `distances[n]` already holds one pairwise distance matrix per conformer, so
# the matrices are grouped directly with `group_distances`. Passing them to
# `distances_from_coordinates` would treat each matrix as coordinates and run
# cdist a second time, producing distances between matrix rows instead of
# distances between atoms.
with multiprocessing.Pool(10) as pool:
    distances_grouped = {
        n: pool.starmap(group_distances, zip(distances[n], dset.atomic_numbers(n)))
        for n in distances
    }

In [68]:
# NOTE(review): `load_h5_dataset` and `n_formulas` are not defined in the cells
# visible here — confirm they exist before this cell runs on a fresh kernel.
df = load_h5_dataset(data_path, n_formulas=n_formulas).set_index("mol", drop=True)

# Heavy-atom count: total atoms minus the entries whose atomic number equals 1.
df["num_heavy_atoms"] = df["atomic_numbers"].parallel_apply(
    lambda numbers: len(numbers) - np.sum(np.array(numbers) == 1)
)


[defaultdict(list, {(7, 7): [1.5835213303514193]}),
 defaultdict(list, {(7, 7): [1.5904232982635633]}),
 defaultdict(list, {(7, 7): [1.5947991526633585]}),
 defaultdict(list, {(7, 7): [1.5744135649756092]}),
 defaultdict(list, {(7, 7): [1.574230816240466]}),
 defaultdict(list, {(7, 7): [1.5887283206038492]}),
 defaultdict(list, {(7, 7): [1.5762096950717495]}),
 defaultdict(list, {(7, 7): [1.5954217459095614]}),
 defaultdict(list, {(7, 7): [1.590901917875253]}),
 defaultdict(list, {(7, 7): [1.5540571429070296]}),
 defaultdict(list, {(7, 7): [1.5964391708326144]}),
 defaultdict(list, {(7, 7): [1.5795022068774205]}),
 defaultdict(list, {(7, 7): [1.5860607622678977]}),
 defaultdict(list, {(7, 7): [1.5733897337315825]}),
 defaultdict(list, {(7, 7): [1.587566584871495]}),
 defaultdict(list, {(7, 7): [1.5783986337960112]}),
 defaultdict(list, {(7, 7): [1.5876254218720163]}),
 defaultdict(list, {(7, 7): [1.5894925272610487]}),
 defaultdict(list, {(7, 7): [1.5669798724226518]}),
 defaultdict(li