# Tutorial 5: MLDFT lit module

In this section, we dive into the model structure and explain how PyTorch Lightning is used.

## 0 Imports

In [None]:
# import necessary packages
import os

import matplotlib.pyplot as plt
import rich
import torch
from hydra import compose, initialize
from hydra.utils import instantiate

from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree

# this makes sure that code changes are reflected without restarting the notebook
# this can be helpful if you want to play around with the code in the repo
%load_ext autoreload
%autoreload 2

# omegaconf is used for configuration management
# omegaconf custom resolvers are small functions used in the config files like "get_len" to get lengths of lists
from mldft.utils import omegaconf_resolvers  # this registers omegaconf custom resolvers

# download a small dataset from huggingface that contains QM9 and QMugs data (possibly already downloaded)
# and change the DFT_DATA environment variable to the directory where the data is stored

# https://huggingface.co/docs/datasets/cache#cache-directory
# The default cache directory is `~/.cache/huggingface/datasets`
# You can change it by setting this variable to any path you like
CACHE_DIR = None  # e.g. change it to "./hf_cache"

# clone the full repo
# https://huggingface.co/sciai-lab/structures25/tree/main
# disable the progress bars to avoid problems in some environments
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from huggingface_hub import snapshot_download

data_path = snapshot_download(
    repo_id="sciai-lab/minimal_data_QM9_QMugs", cache_dir=CACHE_DIR, repo_type="dataset"
)

dft_data = os.environ.get("DFT_DATA", None)
os.environ["DFT_DATA"] = data_path
print(
    f"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}."
)

## 1 Config and data loading

As a first step, we load the "train.yaml" config as an OmegaConf `DictConfig`. Apart from a few overrides needed for this demonstration, we use the default settings for data, optimizer, transforms, basis set, etc. Because the [hydra "tree" structure](.notebooks/tutorial_4_hydra_omegaconf.ipynb) is used, composing the config already handles the communication and combination of the different config files, e.g. for data and the model.

After the data is loaded, we focus for demonstration purposes on one individual sample molecule.
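
To make the effect of an override string such as `trainer.max_epochs=1` concrete, here is a minimal pure-Python sketch that applies dotted `key=value` overrides to a nested dictionary. It only illustrates the idea and is not how Hydra or OmegaConf are implemented. The helper `apply_override` and the toy default values are made up for this example (in particular, the leading `+` is simply stripped here, whereas Hydra uses it to mark keys that are newly added).

```python
import ast


def apply_override(config: dict, override: str) -> None:
    """Apply one 'a.b.c=value' override to a nested dict in place (toy helper)."""
    dotted_key, raw_value = override.split("=", 1)
    *parents, leaf = dotted_key.lstrip("+").split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})  # create missing levels on the way down
    try:
        node[leaf] = ast.literal_eval(raw_value)  # "1" -> 1, "False" -> False
    except (ValueError, SyntaxError):
        node[leaf] = raw_value  # keep plain strings such as dataset names


# toy defaults, NOT the real contents of train.yaml
toy_config = {"trainer": {"max_epochs": 100}, "data": {"dataset_name": "default"}}

for override in [
    "trainer.max_epochs=1",
    "+trainer.limit_train_batches=1",
    "data.dataset_name=QM9_perturbed_fock",
]:
    apply_override(toy_config, override)

print(toy_config)
```

The real `compose` call additionally resolves the defaults list and interpolations, which this sketch ignores.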

In [None]:
from omegaconf.dictconfig import DictConfig

from mldft.utils.molecules import build_molecule_ofdata

# the following initialize already handles the communication and combination
# of the different config files, e.g. for data and the model
with initialize(version_base=None, config_path="../../configs/ml"):
    config = compose(
        config_name="train.yaml",
        overrides=[
            # we need one simple override here but otherwise we just use the default setting (see tutorial 4 for more information)
            "data.dataset_name=QM9_perturbed_fock",  # this will no longer be necessary once the "fixed" is removed from the dataset_name
            # Add trainer overrides for demonstration purposes
            "trainer.max_epochs=1",
            "+trainer.limit_train_batches=1",
            "+trainer.limit_val_batches=1",
            "+trainer.enable_checkpointing=False",
            "data.datamodule.num_workers=0",
        ],
    )

# remove the hydra specific stuff that only works in @hydra.main decorated functions
config.paths.output_dir = "example_path"

datamodule = instantiate(config.data.datamodule)
datamodule.setup(stage="fit")
datamodule.batch_size = 4  # set batch size to 4 (relatively small) for demonstration purposes
train_loader = datamodule.train_dataloader()

sample = datamodule.train_set[0]

# need basis info to build a pySCF molecule object
# see below for more details on basis_info
basis_info = instantiate(config.data.basis_info)

# build a pySCF molecule object from the OFData sample
mol = build_molecule_ofdata(sample, basis=basis_info.basis_dict)

Next, we want to take a look at the machine learning model used to predict
the kinetic energy (and possibly other energies) from a given electron density.
The main module which handles the training is the MLDFTLitModule.

For this, let's take a look at the part of the config that is used to configure the model.
It is a very long and nested config, which specifies everything needed for training.

You will find in there amongst other things:
* The optimizer used to update the model parameters during training.
* The learning rate scheduler used to adjust the learning rate after every epoch during training.
* The loss function used to compute the training loss: It is used for backpropagation
to compute the gradients of the model parameters which will be applied to update each parameter.
* The net which is the main neural network architecture that takes the batched sample as input
and outputs a prediction for the energy.
* The basis_info which specifies the basis set used to represent the density.
* The dataset_statistics used to standardize the input densities and the output energy labels to
improve and stabilize training.
* The density_optimizer and denop_settings which specify how density optimization is performed with a trained model.

Question: Can you find the optimizer and the learning rate that we use for training?

Question: Can you also find the optimizer and learning rate that we use during density optimization (denop)?

In [None]:
import functools

from mldft.ml.models.mldft_module import MLDFTLitModule

rich.print(dict_to_tree(config.model, guide_style="dim"))

# The getattribute function prints a message whenever a hook method is called
# Therefore, we can later see in the output which hooks are called during training
# (e.g., on_train_start, training_step, etc.)
# find more information in the output after the trainer is called


def getattribute(self, name):
    attr = object.__getattribute__(self, name)
    hook_prefixes = ("on_", "training_", "validation_", "test_", "predict_")
    if callable(attr) and any(name.startswith(p) for p in hook_prefixes):

        @functools.wraps(attr)
        def wrapper(*args, **kwargs):
            print(f"Our lightning module is now calling: {name}")
            return attr(*args, **kwargs)

        return wrapper
    return attr


mldft_module = instantiate(config.model)
mldft_module.__class__.__getattribute__ = getattribute
# the MLDFTLitModule inherits from pl.LightningModule
# which is a PyTorch Lightning specific class that handles the training loop
print("Successfully instantiated model:", type(mldft_module))

## 2 Forward pass through model

Now, let's do a forward pass through the model with one batch of data. 

The forward output consists of three parts:
* First, the predicted energy for the given input electron density (in our case kinetic energy + XC energy).
* Second, the predicted gradients of the energy with respect to the input density coefficients. These are computed via automatic differentiation (autodiff) in PyTorch (see the example in Appendix 1).
* Third, a direct prediction of the ground state density coefficients (or rather the difference between the input density coeffs and the ground state density coeffs).

We usually don't use the latter during training: the coefficient_loss has weight 0.0 in the config above.
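
The effect of a weight of 0.0 can be sketched with a plain weighted sum of loss terms. This is only an illustration of the idea, not the loss implementation used in the repository: apart from `coefficient_loss`, the term names and all numbers below are placeholders.

```python
def total_loss(loss_terms: dict, weights: dict) -> float:
    """Weighted sum of individual loss terms; a weight of 0.0 switches a term off."""
    return sum(weights[name] * value for name, value in loss_terms.items())


# illustrative numbers only; "energy_loss" and "gradient_loss" are placeholder names
loss_terms = {"energy_loss": 0.5, "gradient_loss": 2.0, "coefficient_loss": 10.0}
weights = {"energy_loss": 1.0, "gradient_loss": 1.0, "coefficient_loss": 0.0}

print(total_loss(loss_terms, weights))  # 2.5, the coefficient term does not contribute
```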

In [None]:
# let's do a forward pass through the model with one batch of data
# this is taking some time because the model was not moved to the GPU:
print("mldft_module.device:", mldft_module.device, "\n")
batch = next(iter(train_loader))
forward_out = mldft_module.forward(batch)  # which does the same as mldft_module(batch)

print("Model output:", forward_out)

## 3 Training step

This was a single forward pass through the model, which does not yet look much like training.
Instead, we can perform a training_step with the model.

In more detail, during each training step, the following happens:
1. The training loop calls the `training_step` method of the `MLDFTLitModule`.
2. Inside `training_step`, the model processes the input batch to produce predictions.
3. The loss function computes the loss by comparing the predictions to the true labels.
4. Additional training metrics are computed and logged.
5. The computational graph is kept for the backward pass.

Afterwards, the gradients of the loss are computed via backpropagation and the optimizer uses them to update the model weights.

For more information on the optimizer, see Appendix 3 in this notebook.
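
The sequence "forward pass, loss, gradients, parameter update" can be sketched without any framework. The following toy example fits a single weight with plain gradient descent. It is a simplified stand-in under stated assumptions: the data, the one-parameter "model", the hand-derived gradient and the learning rate are all invented for illustration, whereas the real module relies on PyTorch autodiff and the configured optimizer.

```python
# toy data generated by y = 3 * x; the "model" is a single weight w
toy_batch = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)]
w = 0.0  # untrained parameter, so the first loss is large
learning_rate = 0.05


def toy_training_step(w, batch):
    """Forward pass and loss, plus the analytic gradient d(loss)/dw."""
    predictions = [w * x for x, _ in batch]
    errors = [pred - y for pred, (_, y) in zip(predictions, batch)]
    loss = sum(e**2 for e in errors) / len(batch)  # mean squared error
    grad = sum(2 * e * x for e, (x, _) in zip(errors, batch)) / len(batch)
    return loss, grad


losses = []
for step in range(20):
    loss, grad = toy_training_step(w, toy_batch)
    losses.append(loss)
    w = w - learning_rate * grad  # plain gradient-descent update

print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.6f}, w: {w:.4f}")
```

As with the untrained model in this notebook, the first loss is large and shrinks as the parameter is updated.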

To do the training step, we will need a trainer attached to the model.
(By the way, the model which we have just loaded is untrained, so the loss will be very large.)

**Command for classical training:**

Usually, you would start a training run with a command similar to this one:  
```CUDA_VISIBLE_DEVICES=2 python mldft/ml/train.py experiment=str25/qm9_tf```

Quick note: With ```CUDA_VISIBLE_DEVICES```, you select which GPU the job runs on. After logging in to the server, please check which GPU is currently free with the following command: ```gpustat```.

The rest of the command calls the main training script with experiment-specific config options.

In [None]:
# instantiate the trainer
# for that, remove the hydra specific stuff that only works in @hydra.main decorated functions
config.paths.output_dir = "example_path"
config.paths.work_dir = "example_path"
trainer = instantiate(config.trainer)
print("Successfully instantiated trainer:", type(trainer))
mldft_module.trainer = trainer  # add the trainer to the module

# also, let us disable the logging for this tutorial:
mldft_module.log = lambda *args, **kwargs: None
mldft_module.log_dict = lambda *args, **kwargs: None

train_step_out = mldft_module.training_step(batch)
print("Output of training step:", train_step_out.keys())

The training step returns a dictionary containing the following things:
* 'loss': the total loss computed for the batch, which is used for backpropagation
* 'model_outputs': containing the three outputs of the forward pass ('pred_energy', 'pred_gradients', 'pred_diff')
* 'projected_gradient_difference': the difference between the predicted and true energy gradients projected (to preserve the number of electrons)
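
One common way to make a gradient compatible with a linear constraint, such as a fixed number of electrons, is to remove its component along the constraint direction. The sketch below assumes that the electron number can be written as a dot product `w . p` of a weight vector `w` with the density coefficients `p`. The vector `w`, the helper `project` and all numbers are hypothetical, and the projection actually implemented in `MLDFTLitModule` may differ.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def project(gradient, w):
    """Remove the component of `gradient` along `w` (orthogonal projection)."""
    scale = dot(gradient, w) / dot(w, w)
    return [g - scale * w_i for g, w_i in zip(gradient, w)]


w = [1.0, 2.0, 2.0]  # hypothetical: electrons contributed per unit coefficient
pred_gradient = [0.5, -1.0, 3.0]
true_gradient = [0.4, -0.8, 2.5]

difference = [p - t for p, t in zip(pred_gradient, true_gradient)]
projected_difference = project(difference, w)

# a step along the projected direction leaves w . p (the electron count) unchanged
print(dot(projected_difference, w))
```

The printed value is zero up to floating-point rounding, which is exactly the statement that the projected direction does not change the constrained quantity.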

## 4 Running a full training epoch

The LitModule follows a specific "syntax" so that many training details are handled under the hood.
For instance, the backward pass on the loss and the optimizer step are performed automatically.
By default, Lightning uses the "loss" value
returned in the output dictionary of the training step for this.

Furthermore, in the lightning module we do not see an explicit training loop
that loops over the batches in the train_loader.
This is handled automatically by the PyTorch Lightning trainer, which combines the
model with the dataloader(s), e.g. in
`trainer.fit(model, train_dataloader, val_dataloader)`.

In addition, logging to TensorBoard is handled simply via self.log in the lightning module,
which uses the logger that is attached to the trainer (via the trainer that is attached to the model).

The validation loop is also handled automatically.
It is based on the validation_step method, just as the training loop is based on the training_step method.
There are quite a few more methods that follow a standard syntax and can be used to
achieve certain behavior during training, e.g. on_train_epoch_start, on_train_epoch_end, etc.;
see https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html

Additional information to the lightning module:

1. The **Lightning module** holds our machine learning model at its core and defines what to train and how to train it. It has several components:
* The **model architecture** is defined under "mldft_module.net". Its forward pass is called in the ```forward``` method of the mldft_module.
+ *Analogy:* You can think of ```__init__``` as the ingredient list of a recipe and ```forward``` as a simple instruction. However, the magic happens in between...
* The **training logic** lies in the ```training_step```.
* The **validation and test logic** is in the ```validation_step``` and ```test_step```.
* The **optimizer(s)** are set in ```configure_optimizers```.
+ *Analogy:* To continue with the analogy, you can think of the additional methods in the lightning module class as your way to optimize the recipe. With each training step you learn new things and adjust the recipe (parameters). To make sure your changes are actually good, you also continuously validate them. In the cooking example, the optimization process happens in your brain as you consider adding more salt etc. In the code, the optimization process happens in the optimizer.

2. The **LightningDataModule** organizes all the data-related logic:
* With the ```setup``` function in the DataModule class one defines how to load the data.
* Also, the DataLoaders have to be created in ```train_dataloader```, ```val_dataloader```, etc.
* Optionally, one can define a preprocessing of the data.

+ Note: The LightningDataModule helps keep data handling clean and separate from the model logic.

3. Think of the **Trainer** as an orchestrator, which handles:
* Training loops
  * Note: You don't manually write the training loops in Lightning - the ```Trainer``` automates them.
* Validation & testing
* Logging
* Checkpointing
* Device placement (CPU, GPU)
* Distributed training

Important note: In our file structure, you can find a ["train.py"](../../mldft/ml/train.py) file which is the main entry point for training. It instantiates all relevant components, i.e. Trainer, datamodule, lightning module, etc. 

In the subfolder data, there is a ["datamodule.py"](../../mldft/ml/data/datamodule.py) file associated with the DataModule and, lastly, in the folder models a ["mldft_module.py"](../../mldft/ml/models/mldft_module.py) file which handles the core of the model.
[Config files](../../configs/ml/train.yaml) (as discussed in [Tutorial 4](./tutorial_4_hydra_omegaconf.ipynb)) are the place where most of the variables for the training are stored.

#### Now it all comes together in our very first training epoch
The lightning module follows a specific syntax of methods that are executed in a well-defined order during training. For instance, the `on_train_epoch_start` method is executed every time a training epoch starts (as one might have guessed). Similarly, the `on_before_backward` method is called shortly before the backward pass, and `on_validation_batch_end` is called after each validation batch has been processed.
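
The idea of a trainer that calls hook methods in a fixed order can be sketched in a few lines of plain Python. This is a deliberately simplified stand-in: `ToyModule` and `toy_fit` are invented for this illustration, and the real Lightning `Trainer` calls many more hooks and performs the backward pass and the optimizer step, which are only hinted at in a comment here.

```python
calls = []


class ToyModule:
    """Stand-in for a lightning module: each hook only records that it ran."""

    def on_train_epoch_start(self):
        calls.append("on_train_epoch_start")

    def training_step(self, batch):
        calls.append("training_step")
        return {"loss": float(sum(batch))}

    def on_before_backward(self, loss):
        calls.append("on_before_backward")

    def on_train_epoch_end(self):
        calls.append("on_train_epoch_end")


def toy_fit(module, batches, max_epochs=1):
    """Heavily simplified loop; the real Trainer calls many more hooks."""
    for _ in range(max_epochs):
        module.on_train_epoch_start()
        for batch in batches:
            out = module.training_step(batch)
            module.on_before_backward(out["loss"])
            # backward pass and optimizer step would happen here
        module.on_train_epoch_end()


toy_fit(ToyModule(), batches=[[1.0, 2.0], [3.0, 4.0]])
print(calls)
```

The epoch-level hooks run once, while `training_step` and `on_before_backward` run once per batch.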

Below, we will execute a "full training" (but only for one epoch) and log all these methods in the order in which they are executed.

In [None]:
import warnings

# freshly instantiate the trainer for a clean state
trainer = instantiate(config.trainer, enable_progress_bar=False)

# disable all user warnings for the following trainer.fit call
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)

    trainer.fit(mldft_module, datamodule=datamodule);

## Appendix 1: Automatic differentiation

We will have a short intermezzo on understanding how automatic differentiation (via backpropagation) works in PyTorch.
When you have a tensor with requires_grad=True, all operations on that tensor are tracked
and a computation graph is built in the background.
Then when you call backward() on a tensor, the gradients of that tensor with respect to
all tensors that have requires_grad=True and were used to compute that tensor
are computed via backpropagation through the computation graph.
Here is a simple example:

In [None]:
x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3 * x + 1
print("y:", y, "\n")
y.backward()  # this computes the gradient of y with respect to x via backpropagation
print("dy/dx:", x.grad, "\n")  # dy/dx = 2*x + 3 = 2*2 + 3 = 7

# small subtlety: if you call backward() multiple times,
# the gradients are accumulated in the .grad attribute
x = torch.tensor(2.0, requires_grad=True)
y1 = x**2
y2 = x**3
y1.backward()  # this computes the gradient of y1 with respect to x via backpropagation
print("dy1/dx:", x.grad)  # dy1/dx = 2*x = 2*2 = 4
y2.backward()  # this computes the gradient of y2 with respect to x via backpropagation
print("dy1/dx + dy2/dx:", x.grad, "\n")  # dy1/dx + dy2/dx = 2*x + 3*x**2 = 2*2 + 3*2**2 = 16

# to zero the gradients, you can use the zero_() method
x.grad.zero_()
print("zeroed gradients:", x.grad, "\n")

# detach can be used to stop tracking operations on a tensor
x = torch.tensor(2.0, requires_grad=True)
y = x**2
z = y.detach() + 3 * x  # detach stops tracking operations on y
print("z:", z)
z.backward()  # this computes the gradient of z with respect to x via backpropagation
print("dz/dx:", x.grad, "\n")  # dz/dx = 3, since y was detached

# by default, after one calls backward(), the computation graph is deleted to save memory
# if you want to differentiate through a gradient, for instance to compute a second derivative
# (as we actually do in our project when we first compute the energy gradient w.r.t. the density
# and then use that energy gradient to compute a loss function that is then used
# for another backward call to update the model parameters)
# you need create_graph=True, which records the gradient computation itself
# (retain_graph then defaults to the value of create_graph)
x = torch.tensor(2.0, requires_grad=True)
y = x**6
# this computes dy/dx = 6*x**5
dy_dx = torch.autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
print("dy/dx:", dy_dx)
d2y_dx2 = torch.autograd.grad(dy_dx, x)[0]  # this computes d2y/dx2 = 30*x**4 = 30*2**4 = 480
print("d2y/dx2:", d2y_dx2, "\n")

## Appendix 2: Partial

Since the optimizer in the config is only partially ("_partial_") initialized, we want to take a look at what this actually means in the example below.

As you might know, the standard normal distribution is a special case of the general Gaussian distribution. To encode this knowledge and simplify later calls, we can use the partial function to fix the mean and standard deviation to the values that make a Gaussian a standard normal distribution.

In [None]:
from functools import partial


def gaussian(x, mean, std):
    return torch.exp(-0.5 * ((x - mean) / std) ** 2) / (std * (2 * torch.pi) ** 0.5)


standard_normal = partial(gaussian, mean=0.0, std=1.0)
print(standard_normal)
# standard_normal is now a function that only takes x as argument
# and mean and std are fixed to 0.0 and 1.0 respectively
x = torch.linspace(-5, 5, steps=100)
y = standard_normal(x=x)
plt.plot(x.numpy(), y.numpy())
plt.title("Standard normal distribution")
plt.xlabel("x")
plt.ylabel("Probability density")
plt.show()

## Appendix 3: Optimizer

Now, we want to exemplify how the model parameters are updated during training.
For that, we need to attach an optimizer to the model.

If you look carefully, you will find that the optimizer in the config is only partially ("_partial_") initialized. This means that some of the arguments are missing and will be filled in later (for more information, see Appendix 2). In particular, the model parameters that should be optimized are missing, because they are not known before the model is instantiated.
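
The two-stage construction can be sketched with `functools.partial` and a toy class. This is an illustration only: `ToySGD` is a made-up stand-in and not a real `torch.optim` optimizer, and the learning rate and parameter list are placeholder values.

```python
from functools import partial


class ToySGD:
    """Minimal stand-in for an optimizer; not a real torch.optim class."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr


# step 1 (config time): hyperparameters are known, the parameters are not
optimizer_factory = partial(ToySGD, lr=0.01)

# step 2 (after the model exists): supply the missing argument
toy_params = [0.1, -0.3, 0.7]  # placeholder for model.parameters()
toy_optimizer = optimizer_factory(params=toy_params)

print(type(toy_optimizer).__name__, toy_optimizer.lr, len(toy_optimizer.params))
```

The factory can be created long before the model exists, which is why the config only needs to know the hyperparameters.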


In [None]:
optimizer_partially_initialized = instantiate(config.model.optimizer)
optimizer = optimizer_partially_initialized(params=mldft_module.parameters())

# an alternative more compact option would be the following:
# optimizer = instantiate(config.model.optimizer, params=mldft_module.parameters())

mldft_module.optimizer = optimizer  # add the optimizer to the module
print("Successfully instantiated optimizer and linked it with model parameters:", type(optimizer))

for name, model_param in mldft_module.named_parameters():
    print(name, model_param.shape)
    break  # just the first parameter

# first we zero the gradients of the model parameters
mldft_module.optimizer.zero_grad()
# print the gradient of the first parameter (should be None after zeroing the grads):
print("Gradient of first parameter before backward:", model_param.grad)
# then we call backward on the loss to compute the gradients of the model parameters
try:
    # retain_graph=False frees the graph after this call, so running this cell
    # a second time raises the RuntimeError that is caught below
    train_step_out["loss"].backward(retain_graph=False)
    # now the gradients of the model parameters are stored in the .grad attribute of each parameter
    print("Gradient of first parameter after backward:", model_param.grad, model_param.grad.shape)
    old_model_param = (
        model_param.clone().detach()
    )  # clone and detach to keep a copy of the old parameters

    # now, we can update the model parameters with one step of the optimizer
    mldft_module.optimizer.step()
    print(
        "Maximum relative change in first parameter after one optimizer step:",
        ((model_param - old_model_param) / old_model_param).abs().max(),
    )
except RuntimeError as e:
    print("Caught RuntimeError (expected if backward was already called on this graph):", e)

## Appendix 4: Dataset statistics

One small but important detail of our training is the dataset statistics.
These are used to standardize the input densities and the output energy labels
to improve and stabilize training.
As such, the dataset statistics are specific to the dataset (QM9 or QMugs) and the energy label used (E_kin, E_xc, E_kin + E_xc, etc.),
as well as to the transforms applied to the input densities (e.g. local_frames_global_symmetric_natrep).

The dataset_statistics are essentially stored in a .zarr folder, whose path can be seen in the config. After instantiating it, we see some additional statistical values for each relevant quantity, such as the mean and std, as well as the abs_max value.
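
What standardization with a stored mean and std does can be sketched with the Python standard library. This is a minimal illustration and not the `DatasetStatistics` implementation: the energy values are made up, and the helpers `standardize` and `destandardize` are hypothetical names.

```python
from statistics import mean, pstdev

# made-up energy labels, NOT values from QM9 or QMugs
energies = [-40.5, -76.4, -56.6, -79.8, -115.7]

stats = {"mean": mean(energies), "std": pstdev(energies)}  # computed once per dataset


def standardize(value, stats):
    return (value - stats["mean"]) / stats["std"]


def destandardize(value, stats):
    return value * stats["std"] + stats["mean"]


scaled = [standardize(e, stats) for e in energies]
print("mean of scaled labels:", round(mean(scaled), 6))
print("std of scaled labels:", round(pstdev(scaled), 6))
print("round trip:", destandardize(scaled[0], stats))
```

The scaled labels have roughly zero mean and unit standard deviation, and the inverse transform recovers the original value, which is why the statistics must match the dataset and label they were computed for.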

In [None]:
# let's take a look at the respective part in the config to verify that:
rich.print(dict_to_tree(config.data.dataset_statistics, guide_style="dim"))

from mldft.ml.preprocess.dataset_statistics import DatasetStatistics

dataset_statistics = instantiate(config.data.dataset_statistics)
print("Successfully instantiated dataset_statistics:", type(dataset_statistics))
dataset_statistics