# Tutorial 1: Understanding our OFDataModule, which handles all data samples during training

The goal of this tutorial is to understand the behaviour and interplay of the following classes, which handle the loading and processing of our data: OFDataset, OFData, OFBatch, OFLoader, OFDataModule.

In [None]:
# import necessary packages
import os

import matplotlib.pyplot as plt
import rich
import torch
from hydra import compose, initialize
from hydra.utils import instantiate

# IPython magics (only valid inside Jupyter/IPython, not in a plain .py file):
# autoreload re-imports changed modules before each cell is executed, so code changes
# in the repo are picked up without restarting the notebook kernel.
# This is helpful if you want to play around with the code in the repo.
%load_ext autoreload
%autoreload 2

# omegaconf is used for configuration management
# omegaconf custom resolvers are small functions used in the config files like "get_len" to get lengths of lists
from mldft.utils import omegaconf_resolvers  # noqa: F401 - imported only to register the custom resolvers
from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree

# download a small dataset from huggingface that contains QM9 and QMugs data (reused if already downloaded)
# and point the DFT_DATA environment variable to the directory where the data is stored

# With cache_dir=None, `snapshot_download` stores files in the huggingface_hub cache
# (by default `~/.cache/huggingface/hub`, configurable via HF_HOME / HF_HUB_CACHE),
# see https://huggingface.co/docs/huggingface_hub/guides/manage-cache
# You can change it by setting this variable to any path you like
CACHE_DIR = None  # e.g. change it to "./hf_cache"

# the full STRUCTURES25 repository can be found at:
# https://huggingface.co/sciai-lab/structures25/tree/main
# The environment variable is set *before* importing huggingface_hub, presumably because the
# library reads it at import time.
os.environ[
    "HF_HUB_DISABLE_PROGRESS_BARS"
] = "1"  # to avoid problems with the progress bar in some environments
from huggingface_hub import snapshot_download

# returns the local directory of the downloaded (or cached) dataset snapshot
data_path = snapshot_download(
    repo_id="sciai-lab/minimal_data_QM9_QMugs", cache_dir=CACHE_DIR, repo_type="dataset"
)

# keep the previous value only to report the change; afterwards DFT_DATA points at the downloaded data
dft_data = os.environ.get("DFT_DATA", None)
os.environ["DFT_DATA"] = data_path
print(
    f"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}."
)

## 1 Loading the datamodule from our gigantic config

First, we load a large config as an OmegaConf `DictConfig` for training a model
with the default settings for data, optimizer, transforms, basis set, etc.
For now, you can think of the config as a large nested dictionary that contains all settings
and hyperparameters used for training our OF-DFT model.
Later, in [tutorial 4 (hydra and omegaconf)](./tutorial_4_hydra_omegaconf.ipynb), we will go into more detail about how this works.

In [None]:
from omegaconf.dictconfig import DictConfig

# hydra's `initialize` sets up the config search path; inside it, `compose` merges all config
# files referenced by train.yaml (e.g. data and model configs) into one combined config
with initialize(version_base=None, config_path="../../configs/ml"):
    config: DictConfig = compose(config_name="train.yaml")

# hydra-specific runtime values only exist inside @hydra.main decorated functions, so the
# output directory (presumably defined via such a value in the config) is overwritten here
config.paths.output_dir = "example_path"

# display the part of the config that is used specifically for configuring the data module
rich.print(dict_to_tree(config.data.datamodule, guide_style="dim"))

In [None]:
from mldft.ml.data.datamodule import OFDataModule  # noqa: F401 - imported for go-to-definition

# instantiate the datamodule (one important part of the full training pipeline)
# from the `data.datamodule` part of the config
datamodule = instantiate(config.data.datamodule)
# a relatively small batch size keeps the demonstration output readable
datamodule.batch_size = 4
print(f"Successfully instantiated datamodule: {type(datamodule)}")

# prepare the data, e.g. split it into train, val and test sets; with stage="fit" only the
# train and validation sets used during training are prepared (no test set)
datamodule.setup(stage="fit")
# with stage="fit" no test set is prepared, only the train and validation set used during training

In [None]:
# vars() returns the instance's __dict__, giving a quick overview of everything stored on the datamodule
vars(datamodule)

## 2 A first look at the dataset and a single sample

In [None]:
# Importing the relevant class in a cell like this lets you jump to its definition in your IDE
from mldft.ml.data.components.dataset import OFDataset

# let's look at the dataset(s):
# print the length of the train and validation set used during training
# (so-called "split files" handle the split into disjoint train, val and test sets)
print(f"Length of train set: {len(datamodule.train_set)}")
print(f"Length of val set: {len(datamodule.val_set)}")
print(f"type of train set: {type(datamodule.train_set)}")
print(f"isinstance of OFDataset {isinstance(datamodule.train_set, OFDataset)}")

In [None]:
from mldft.ml.data.components.of_data import OFData

# get a single sample (an OFData object) from the train set
sample = datamodule.train_set[0]
# atom positions (the dataset uses Bohr as its length unit, see the conversion cell below)
print("Atom positions:", sample.pos)
print("Atom types:", sample.atomic_numbers)
# density coefficients p_mu (one per basis function, see section 3)
print("number of coefficients:", sample.coeffs.shape)
print(
    "Integrals of basis functions used to describe the density:", sample.dual_basis_integrals.shape
)
# presumably the index of the SCF iteration this density sample was taken from
print("scf_iteration:", sample.scf_iteration)
print("Energy label (kinetic energy + XC energy):", sample.energy_label)
# which energy is used as the label is configured via the dataset's of_data_kwargs
print("Energy key:", datamodule.train_set.of_data_kwargs["energy_key"])
print("Is sample an instance of OFData?", isinstance(sample, OFData))

In [None]:
from mldft.utils.molecules import build_molecule_ofdata

# building a pySCF molecule requires the basis information, which is
# instantiated from the config (it is discussed in more detail below)
basis_info = instantiate(config.data.basis_info)

# convert the OFData sample into a pySCF molecule object and show it in xyz format
mol = build_molecule_ofdata(sample, basis=basis_info.basis_dict)
xyz_string = mol.tostring("xyz")
print(f"type : {type(mol)}, xyz string of that molecule:\n")
print(xyz_string)

In [None]:
from pyscf.lib import param

# pySCF's Bohr radius in Angstrom: 1 Bohr is approx. 0.529177 Angstrom
bohr2ang = param.BOHR
# our dataset works in Bohr, while other tools (like RDKit, used below) work in Angstrom;
# converting the positions shows that both are consistent
print("Positions in Angstrom:\n", sample.pos * bohr2ang)

# a pySCF Mole object can also report its atom positions directly in either unit
for unit in ("Angstrom", "Bohr"):
    print(f"Positions in {unit} from pyscf.Mole:\n", mol.atom_coords(unit=unit))

In [None]:
from rdkit.Chem import Draw

from mldft.utils.conversions import pyscf_to_rdkit

# Please note that converting a set of atom positions (e.g. an xyz file) into an RDKit molecule
# with bonds (and nice pictures/structures as below) is not always well defined,
# since it is non-trivial to infer chemical bonds from just positions and atom types
# (though this should not be an issue for classic QM9 and QMugs molecules)

rdkit_mol = pyscf_to_rdkit(mol)
print("type", type(rdkit_mol))
# render a 2D depiction of the molecule with RDKit and display it with matplotlib (axes hidden)
img = Draw.MolToImage(rdkit_mol)
plt.imshow(img)
plt.axis("off")
plt.show()

## 3 Our density representation and the BasisInfo class

We represent the electron density $\rho(\vec r)$, which is a function of 3D space, as a linear combination of so-called atom-centered basis functions (each is a function of 3D space localized around one of the atoms in the molecule).
$$\rho(\vec r) = \sum_\mu p_\mu \omega_\mu(\vec r)$$
$p_\mu$ are the density coefficients and $\omega_\mu(\vec r)$ are the different basis functions. We use Gaussian type orbitals (GTOs) as basis functions which combine a Gaussian-like radial part with a spherical harmonic angular part. Please take a look at the [STRUCTURES25 paper](https://pubs.acs.org/doi/10.1021/jacs.5c06219) for more details.  

Above, we have seen the density coefficients of a sample (`sample.coeffs`); together with the basis functions, they describe the electron density.
Let us now look at the basis info object in more detail:

In [None]:
from mldft.ml.data.components.basis_info import BasisInfo  # noqa: F401 - imported for go-to-definition

# the essential info about all basis functions for the different atom types is stored in
# basis_info.basis_dict, a dictionary in pySCF's basis format:
# key: atom type (presumably the element symbol - check the output below)
# value: list of shells, each of the form
#        [angular momentum l, [exponent, contraction coeff(s)], [exponent, contraction coeff(s)], ...]
# see https://pyscf.org/user/gto.html#basis-format for details
# (the last expression of the cell is displayed as its output)
basis_info.basis_dict

In [None]:
print("atomic numbers in the dataset:", basis_info.atomic_numbers)
print("Number of basis functions/coeffs per atom type:", basis_info.basis_dim_per_atom)

# for instance, we can take a look at the integrals of the basis functions for hydrogen:
# all basis functions with angular momentum l > 0 integrate to zero.
# Hydrogen's position is looked up in basis_info.atomic_numbers instead of assuming it is
# the first atom type (assumes basis_info.integrals uses the same atom-type ordering).
hydrogen_index = list(basis_info.atomic_numbers).index(1)
basis_info.integrals[hydrogen_index]

## 4 Our dataloader converts OFData into OFBatch objects
We are gradually moving towards training a model. For that, we take a look at the dataloaders, which combine multiple molecules into batches that are then passed to the model for training.

In [None]:
from mldft.ml.data.components.loader import OFLoader  # noqa: F401 - imported for go-to-definition
from mldft.ml.data.components.of_batch import OFBatch

datamodule.batch_size = 4  # set batch size to 4 (relatively small) for demonstration purposes
train_loader = datamodule.train_dataloader()

# get the first batch from the train loader, as the model would.
# next() always (re)binds `batch`: with a for-loop + break, an empty loader would silently
# leave a stale `batch` from an earlier notebook run in place.
batch = next(iter(train_loader), None)
if batch is None:
    raise RuntimeError("The train dataloader did not yield any batch.")

# one special thing about geometric graph data:
# different molecules have different numbers of atoms, therefore combining them into
# one batch is not as simple as stacking them into a tensor.
# Instead, they are appended into one large graph containing all atoms, and the
# batch.batch tensor indicates which atom belongs to which molecule in the large graph
print("number of molecules in the batch:", batch.num_graphs)
print("Number of atoms in the batch:", batch.num_nodes)
print("batch.batch:", batch.batch, "len(batch.batch):", len(batch.batch))
print("Length of 'concatenated' atom positions:", batch.pos.shape)
print("Length of 'concatenated' atomic numbers:", batch.atomic_numbers.shape)

# find out how many atoms are in each molecule in the batch;
# minlength guarantees exactly one count per molecule in the batch
num_atoms_per_mol = torch.bincount(batch.batch, minlength=batch.num_graphs)
print("Number of atoms per molecule in the batch:", num_atoms_per_mol)
# average, max, min number of atoms in the molecules in the batch
print(
    "average number of atoms per molecule in the batch:", num_atoms_per_mol.float().mean().item()
)
print("max number of atoms per molecule in the batch:", num_atoms_per_mol.max().item())
print("min number of atoms per molecule in the batch:", num_atoms_per_mol.min().item())

In [None]:
# to_data_list() splits a batch back into individual data samples (one per molecule)
list_of_molecules = batch.to_data_list()
first_molecule = list_of_molecules[0]
print(
    "Length of list_of_molecules:",
    len(list_of_molecules),
    "first mol in list is:",
    first_molecule,
)

It is also possible to create batches manually from a list of OFData samples:

In [None]:
from mldft.ml.data.components.of_data import Representation

# attach a new scalar property to every molecule; unlike plain attributes, every item
# added to OFData is registered together with its representation (see sample.representations)
for molecule in list_of_molecules:
    molecule.add_item(
        key="example_property",
        representation=Representation.SCALAR,
        value=torch.tensor(42.0),
    )

# combine the modified molecules into a new batch and display the batched new property
batch = OFBatch.from_data_list(list_of_molecules)
batch.example_property