# Tutorial 4: Hydra, OmegaConf, Overrides

In this tutorial, you will learn more about the underlying configuration structure: the config files, the Hydra structure and the OmegaConf magic. We will also cover overrides, which allow us to use non-default settings.

## 0 Imports

In [None]:
# import necessary packages
import os

import matplotlib.pyplot as plt
import numpy as np
from hydra import compose, initialize
from hydra.utils import instantiate

# this makes sure that code changes are reflected without restarting the notebook
# this can be helpful if you want to play around with the code in the repo
%load_ext autoreload
%autoreload 2

# omegaconf is used for configuration management
# omegaconf custom resolvers are small functions used in the config files. For example, "get_len" is used to get lengths of lists.
from mldft.utils import omegaconf_resolvers  # this registers omegaconf custom resolvers

# download a small dataset from huggingface that contains QM9 and QMugs data (possibly already downloaded)
# and change the DFT_DATA environment variable to the directory where the data is stored

# https://huggingface.co/docs/datasets/cache#cache-directory
# The default cache directory is `~/.cache/huggingface/datasets`
# You can change it by setting this variable to any path you like
CACHE_DIR = None  # e.g. change it to "./hf_cache"

# clone the full repo
# https://huggingface.co/sciai-lab/structures25/tree/main
# disable progress bars to avoid problems in some environments;
# the variable is set before importing huggingface_hub so that it takes effect
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from huggingface_hub import snapshot_download

data_path = snapshot_download(
    repo_id="sciai-lab/minimal_data_QM9_QMugs", cache_dir=CACHE_DIR, repo_type="dataset"
)

dft_data = os.environ.get("DFT_DATA", None)
os.environ["DFT_DATA"] = data_path
print(
    f"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}."
)

## 1 Hydra

Hydra is used for configuration management. It can be thought of as a tree of configuration files. 
Usually a parent config file is used to set global variables and to specify which other child config files to use. 
The child config files then set specific variables that are usually related to a specific topic (e.g. model architecture, training parameters, data parameters, etc.).

To understand how the child config files are implemented, it is recommended to take a look at the OmegaConf magic in the next section.

## 2 OmegaConf Magic

First, we look at an example config and see what its tree structure looks like.

Taking a closer look at the config, we can see that an OmegaConf resolver is used to get the length of the list `l`.
In an additional example in Appendix 1, we show how you can create your own custom resolvers.


We also use this config to instantiate a model (an MLP) from its `_target_` field.

In [None]:
from omegaconf import OmegaConf

# understanding the omegaconf config magic and syntax:
example_config = {
    "sub_dict": {"a": 1, "b": 2, "l": [1, 2, 3]},
    "len_of_l": "${get_len:${sub_dict.l}}",  # this uses the custom resolver "get_len" to get the length of list l
    # use this structure to cross-reference within a config ${sub_dict.l}
    "mlp": {
        "_target_": "mldft.ml.models.components.mlp.MLP",  # this is used by hydra to instantiate an object of the given class
        "in_channels": 3,
        "hidden_channels": [16, 16, 1],
    },
}

omegaconf_example_config = OmegaConf.create(example_config)
# if you print the config naively, it shows just the strings
print("OmegaConf config:", omegaconf_example_config)
# BUT if you access the value, it resolves the string using the custom resolver
print("Value from accessing len_of_l", omegaconf_example_config.len_of_l)  # prints 3

# instantiate the MLP based on the example config above:
# when calling instantiate, the _target_ field is used to find the class
# and all other fields are passed as arguments to the class constructor:
# in this case the MLP class from mldft.ml.models.components.mlp
# with in_channels=3 and hidden_channels=[16, 16, 1] as arguments
mlp = instantiate(omegaconf_example_config.mlp)
print("\nInstantiated MLP:", mlp, "\n")

In YAML format, the above config looks like this:

```yaml
# example_config.yaml

sub_dict:
  a: 1
  b: 2
  l: [1, 2, 3]

# uses the custom resolver "get_len" to get the length of list l;
# ${sub_dict.l} cross-references another value within the config
len_of_l: ${get_len:${sub_dict.l}}

mlp:
  # used by Hydra to instantiate an object of the given class
  _target_: mldft.ml.models.components.mlp.MLP
  in_channels: 3
  hidden_channels: [16, 16, 1]
```

All our configs are collected, in a hierarchical structure, in the `configs` folder. The top-level config to start from for model training is [configs/ml/train.yaml](../../configs/ml/train.yaml).

## 3 Config for model training

Now, we load the config for the actual model training. Additionally, we load the data and create batches that can easily be handled by the model.


In [None]:
from omegaconf.dictconfig import DictConfig

# load the config for training a model as an OmegaConf DictConfig,
# with the default settings for data, optimizer, transforms, basis set, etc.
# compose already combines the different config files, e.g. for the data and the model
with initialize(version_base=None, config_path="../../configs/ml"):
    config = compose(
        config_name="train.yaml",
        overrides=[
            "data.dataset_name=QM9_perturbed_fock",  # this will no longer be necessary once the "fixed" is removed from the dataset_name
        ],
    )

# output_dir depends on Hydra-specific values that only resolve inside
# @hydra.main decorated functions, so we replace it with a fixed path
config.paths.output_dir = "example_path"

datamodule = instantiate(config.data.datamodule)
datamodule.setup(stage="fit")
datamodule.batch_size = 4  # set batch size to 4 (relatively small) for demonstration purposes
train_loader = datamodule.train_dataloader()

## 4 Overriding the default config

Since we do not always want to use the default config, here is an example of how to override settings.
More specifically, we now want to use the QMugs dataset instead of the default QM9 dataset.

Below, we also prepare the datasets for the fit stage of training, i.e. the training and validation sets, which here are built from the small dataset downloaded at the beginning.

In [None]:
with initialize(version_base=None, config_path="../../configs/ml"):
    config_qmugs = compose(
        config_name="train.yaml",
        overrides=[
            # this overrides the data used to the qmugs dataset
            "data.dataset_name=QMUGS_perturbed_fock",  # with the dot we override a nested field
            "data/transforms=no_basis_transforms",  # with the / we select another file from the data/transforms config group
        ],
    )

# output_dir depends on Hydra-specific values that only resolve inside
# @hydra.main decorated functions, so we replace it with a fixed path
config_qmugs.paths.output_dir = "example_path"

datamodule_qmugs = instantiate(config_qmugs.data.datamodule)
datamodule_qmugs.setup(stage="fit")  # prepare the datasets
print(f"Length of qmugs train set: {len(datamodule_qmugs.train_set)}")
print(f"Length of qmugs val set: {len(datamodule_qmugs.val_set)}")

To get a better intuition for how the QMugs dataset differs from QM9 in terms of complexity, we visualize a QMugs molecule below. For visualizations of example QM9 molecules, please have a look at [Tutorial 2](tutorial_2_visualization.ipynb).

In [None]:
import sys

# keep only the program name so downstream parsers don't see Jupyter's -f=...
sys.argv = sys.argv[:1]

import pyvista

from mldft.utils.molecules import build_molecule_ofdata
from mldft.utils.visualize_3d import get_sticks_mesh_dict

basis_info_qmugs = instantiate(config_qmugs.data.basis_info)

# look at a qmugs molecule to see that they are larger than qm9 molecules
for sample_qmugs in datamodule_qmugs.train_set:
    if sample_qmugs.mol_id.startswith("qmugs"):
        print("Found a qmugs molecule:", sample_qmugs.mol_id)
        break

mol_qmugs = build_molecule_ofdata(sample_qmugs, basis=basis_info_qmugs.basis_dict)


# this gives a ball-and-stick model of the molecule
molecule_mesh = get_sticks_mesh_dict(mol_qmugs)
molecule_mesh["opacity"] = 1

# plot the molecule and the global frame using pyvista:
pyvista.set_jupyter_backend("html")
pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)
pl.camera_position = "zx"
pl.enable_parallel_projection()
pl.add_mesh(**molecule_mesh)
pl.enable_shadows()
pl.reset_camera(
    bounds=0.9
    * np.stack([mol_qmugs.atom_coords().min(0), mol_qmugs.atom_coords().max(0)], axis=1).flatten()
)

img = pl.show(screenshot=True, window_size=(800, 400))
plt.show()

## Appendix 1: Self-created custom resolver

Above, `get_len` was already defined and registered as a custom resolver in `mldft.utils.omegaconf_resolvers`.
However, we can also define our own OmegaConf custom resolver:

In [None]:
# check if "sum" is already registered
if OmegaConf.has_resolver("sum"):
    print("sum is already registered")
else:
    print("registering sum")
    # register a custom resolver "sum" that sums up all its arguments
    OmegaConf.register_new_resolver("sum", lambda *args: sum(args))

example_config = {
    "sub_dict": {"a": 17, "b": -5},
    "sum_a_b": "${sum:${sub_dict.a}, ${sub_dict.b}}",  # uses the custom resolver "sum" to add a and b
}

omegaconf_example_config = OmegaConf.create(example_config)
# if you print the config naively, it shows just the strings
print("OmegaConf config:", omegaconf_example_config)
# BUT if you access the value, it resolves the string using the custom resolver
print("Value from accessing sum_a_b:", omegaconf_example_config.sum_a_b)  # prints 12