# Tutorial 6: Density Optimization

In this notebook, you will learn about the density optimization (denop) method.
Starting from an initial guess, the electron density is updated iteratively and converges towards the ground state, provided that enough iterations are performed.
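
To make the idea concrete before touching the real code, here is a minimal toy sketch of such an iterative optimization. It is not the `mldft` implementation: the quadratic energy, the names `toy_energy_grad`, `project` and `optimize`, and all numbers are assumptions chosen purely for illustration. The gradient is projected so that the electron count stays fixed during the descent.

```python
import math

# Hypothetical toy problem: E(p) = sum_i (0.5 * A_i * p_i^2 - b_i * p_i)
A = [2.0, 1.0, 4.0]  # diagonal of a positive-definite matrix (assumption)
b = [1.0, 0.5, 2.0]
w = [1.0, 1.0, 1.0]  # toy "basis integrals": electron count is sum_i w_i * p_i


def toy_energy_grad(p):
    """Return the toy energy and its gradient at the coefficients p."""
    energy = sum(0.5 * a * x * x - bi * x for a, bi, x in zip(A, b, p))
    grad = [a * x - bi for a, bi, x in zip(A, b, p)]
    return energy, grad


def project(grad):
    """Remove the gradient component that would change the electron count."""
    scale = sum(g * wi for g, wi in zip(grad, w)) / sum(wi * wi for wi in w)
    return [g - scale * wi for g, wi in zip(grad, w)]


def optimize(p, step=0.2, max_cycle=200, tolerance=1e-8):
    """Plain projected gradient descent; stops once the gradient norm is small."""
    history = []
    for _ in range(max_cycle):
        energy, grad = toy_energy_grad(p)
        projected = project(grad)
        grad_norm = math.sqrt(sum(g * g for g in projected))
        history.append((energy, grad_norm))
        if grad_norm < tolerance:
            break
        p = [x - step * g for x, g in zip(p, projected)]
    return p, history


# start from a guess that already contains the right number of electrons (2)
p_final, history = optimize([2.0, 0.0, 0.0])
```

The `history` list plays the same role as the per-iteration records inspected later in this notebook: it lets you check that the energy decreases and that the gradient norm falls below the tolerance.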

## 0 Imports

In [None]:
# import necessary packages
import os

import matplotlib.pyplot as plt
import numpy as np
import rich
import torch
from hydra import compose, initialize
from hydra.utils import instantiate

# this makes sure that code changes are reflected without restarting the notebook
# this can be helpful if you want to play around with the code in the repo
%load_ext autoreload
%autoreload 2

from mldft.ml.models.mldft_module import MLDFTLitModule

# omegaconf is used for configuration management
# omegaconf custom resolvers are small functions used in the config files like "get_len" to get lengths of lists
from mldft.utils import omegaconf_resolvers  # this registers omegaconf custom resolvers
from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree

## 1 (Config) settings for denop

The main density optimization block is applied after the model has been trained. 
From the command line, it can be executed with the following command: 

```CUDA_VISIBLE_DEVICES=6 python mldft/ofdft/run_density_optimization.py run_path="/export/scratch/ialgroup/dft_str25/models/train/runs/088__from_checkpoint_009__str25\qm9_tf"  n_molecules=10 split=test```

However, it can also be used during training to improve the model's performance. 
For this notebook, we will use a pretrained model, which can be accessed via a checkpoint path.

In [None]:
# download a small dataset from huggingface that contains QM9 and QMugs data
# and change the DFT_DATA environment variable to the directory where the data is stored

# https://huggingface.co/docs/datasets/cache#cache-directory
# The default cache directory is `~/.cache/huggingface/datasets`
# You can change it by setting this variable to any path you like
CACHE_DIR = None  # e.g. change it to "./hf_cache"

REPO_ID = "sciai-lab/minimal_data_QM9_QMugs"

print("Using tiny datasets")

# clone the full repo
# https://huggingface.co/sciai-lab/structures25/tree/main

os.environ[
    "HF_HUB_DISABLE_PROGRESS_BARS"
] = "1"  # to avoid problems with the progress bar in some environments
from huggingface_hub import hf_hub_download, snapshot_download

# use the REPO_ID defined above instead of repeating the string
data_path = snapshot_download(repo_id=REPO_ID, cache_dir=CACHE_DIR, repo_type="dataset")

dft_data = os.environ.get("DFT_DATA", None)
os.environ["DFT_DATA"] = data_path
print(
    f"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}."
)

In [None]:
import contextlib

# load the model from the checkpoint (downloaded from our huggingface model repo):

# https://huggingface.co/docs/datasets/cache#cache-directory
# The default cache directory is `~/.cache/huggingface/datasets`
# You can change it by setting this variable to any path you like
CACHE_DIR = None  # e.g. change it to "./hf_cache"

# https://huggingface.co/sciai-lab/structures25/tree/main
print("Using QM9 model")
qm9_model_path = hf_hub_download(
    repo_id="sciai-lab/structures25",
    filename="trained-on-qm9/trained-on-qm9.ckpt",
    cache_dir=CACHE_DIR,
)


@contextlib.contextmanager
def _safe_map_location():
    tls = torch._utils._thread_local_state
    had_attr = hasattr(tls, "map_location")
    if not had_attr:
        setattr(tls, "map_location", None)
    try:
        yield
    finally:
        if hasattr(tls, "map_location") and not had_attr:
            delattr(tls, "map_location")


def safe_load_from_ckpt(path):
    for attempt in (1, 2):  # retry once, because the first call can trip the bug
        try:
            with _safe_map_location():
                return MLDFTLitModule.load_from_checkpoint(path, map_location="cpu")
        except AttributeError as e:
            if "map_location" in str(e) and attempt == 1:
                continue
            raise


mldft_module_trained = safe_load_from_ckpt(qm9_model_path)
mldft_module_trained.eval()  # set model to eval mode

print("Successfully loaded trained model from checkpoint:", type(mldft_module_trained))

For this model, we have to use the "local_frames_global_natrep_add_lframes" transformation, which we can access by overriding the default.

But be careful with the dataset statistics! The dataset statistics are used to create the SAD initial guess, which is the starting point for denop. In the denop config, we find two dataset statistics, one in the model section and one in the data section. They differ by the transformations applied. The initial SAD guess is built without any transforms applied and is only later transformed using "sample.transformation_matrix" shortly before starting the denop. To ensure that no transforms are applied, we use "config_denop.model.dataset_statistics", and a dedicated omegaconf custom resolver ("to_no_basis_transforms_dataset_statistics") handles the rest.

With the "datamodule.test_dataloader()", we prepare a test set for density optimization. 

Additionally, we want to batch the data and, for demonstration purposes, only use the first batch.

In [None]:
from omegaconf import open_dict

from mldft.ml.data.components.convert_transforms import PrepareForDensityOptimization
from mldft.ofdft.functional_factory import requires_grid
from mldft.utils.omegaconf_resolvers import to_no_basis_transforms_dataset_statistics

# for this model we need to use local frames so we need slightly different transforms
with initialize(version_base=None, config_path="../../configs/ml"):
    config_denop = compose(
        config_name="train.yaml",
        overrides=[
            "data/transforms=local_frames_global_natrep_add_lframes",
            "data.dataset_name=QM9_perturbed_fock",  # this will no longer be necessary once the "fixed" is removed from the dataset_name
            "data.transforms.use_cached_data=False",  # to use untransformed data paths
        ],
    )
    # IMPORTANT: for denop we need to instantiate the model.dataset_statistics, the basis_info and the datamodule
    # use the open_dict context manager to modify the config since it is frozen by default
    with open_dict(config_denop):
        config_denop.model.dataset_statistics = config_denop.data.dataset_statistics.copy()
        config_denop.model.dataset_statistics.path = to_no_basis_transforms_dataset_statistics(
            dataset_statistics_path=config_denop.model.dataset_statistics.path,
            transformation_name=config_denop.data.transforms.name,
        )

model_dataset_statistics_for_denop = instantiate(config_denop.model.dataset_statistics)

# remove the hydra specific stuff that only works in @hydra.main decorated functions
config_denop.paths.output_dir = "example_path"

datamodule = instantiate(config_denop.data.datamodule)
datamodule.setup(stage="test")

Next, some transformations to the right data type have to be done.

Additionally, the denop settings and the density optimizer settings in the config need to be instantiated.

In [None]:
# instantiate the ofdft config:
with initialize(version_base=None, config_path="../../configs/ofdft"):
    config_ofdft = compose(
        config_name="ofdft.yaml",
    )


# print configs for denop:
print("\nConfig for density_optimizer:")
rich.print(dict_to_tree(config_ofdft.optimizer, guide_style="dim"))

# for denop our model needs the following additional things:
# denop_settings = instantiate(config_denop.model.denop_settings)
density_optimizer = instantiate(config_ofdft.optimizer)

In [None]:
from mldft.ml.data.components.dataset import OFDataset

# set pytorch dtype default to float64 for better numerical accuracy in denop
from mldft.ml.data.components.of_data import Representation

torch.set_default_dtype(torch.float64)

# customise the transformations for density optimization:
basis_info = instantiate(config_denop.data.basis_info)
transforms = instantiate(config_denop.data.transforms)
add_grid = requires_grid(
    config_denop.data.target_key, config_ofdft.negative_integrated_density_penalty_weight
)
transforms.pre_transforms.insert(0, PrepareForDensityOptimization(basis_info, add_grid=add_grid))
transforms.add_transformation_matrix = True
transforms.use_cached_data = False

dataset_kwargs = instantiate(config_denop.data.datamodule.dataset_kwargs)
dataset_kwargs.update(
    {
        "limit_scf_iterations": -1,
        "additional_keys_at_ground_state": {
            "of_labels/energies/e_electron": Representation.SCALAR,
            "of_labels/energies/e_ext": Representation.SCALAR,
            "of_labels/energies/e_hartree": Representation.SCALAR,
            "of_labels/energies/e_kin": Representation.SCALAR,
            "of_labels/energies/e_kin_plus_xc": Representation.SCALAR,
            "of_labels/energies/e_kin_minus_apbe": Representation.SCALAR,
            "of_labels/energies/e_kinapbe": Representation.SCALAR,
            "of_labels/energies/e_xc": Representation.SCALAR,
            "of_labels/energies/e_tot": Representation.SCALAR,
        },
    }
)

denop_dataset = OFDataset(
    paths=datamodule.test_set.paths,
    num_scf_iterations_per_path=None,
    basis_info=basis_info,
    transforms=transforms,
    **dataset_kwargs,
)

sample_double = denop_dataset[0]

In [None]:
from mldft.ml.data.components.convert_transforms import ToTorch

# by default for denop we use double precision
print("sample.pos.dtype before ToTorch:", sample_double.pos.dtype)
to_float_32 = ToTorch(float_dtype=torch.float32)
sample_float = sample_double.clone()
sample_float = to_float_32(sample_float)
print("sample.pos.dtype after ToTorch:", sample_float.pos.dtype)

# after converting to float we can do a simple forward pass through the trained model
forward_out_trained = mldft_module_trained.forward(sample_float)

## 2 Functional Factory and SAD guess

Now that the config settings are prepared and instantiated, we can take a look at the "FunctionalFactory" and the SAD guess, which will be called during the actual denop procedure.

The "FunctionalFactory" is used to create an energy functional from the trained model that can be used for density optimization.

The contributions returned are therefore: ```contributions = [mldft_module_trained, "hartree", "nuclear_attraction"]```
* With our model, we predict T_s + E_xc, i.e. the non-interacting kinetic energy plus the exchange-correlation energy.
* The hartree energy is the classical electron-electron repulsion energy based on the current density.
* The nuclear attraction energy is the attraction of the electrons to the nuclei based on the current density.

--
* Also note: Since the nuclear repulsion energy does not depend on the density, it is not part of the functional
but is computed later on directly from the mol object via nuclear_repulsion = mol.energy_nuc()
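
The composition described above can be sketched as a sum of independent callables. This is only a toy illustration, not the `FunctionalFactory` API: the functions `toy_kin_plus_xc`, `toy_hartree`, `toy_nuclear_attraction` and `make_energy_functional`, as well as all numbers, are hypothetical stand-ins.

```python
# Toy stand-ins for the three density-dependent contributions (all hypothetical).
def toy_kin_plus_xc(coeffs):
    # in the tutorial, this role is played by the trained model
    return sum(0.5 * c * c for c in coeffs)


def toy_hartree(coeffs):
    # classical electron-electron repulsion: quadratic in the density
    total = sum(coeffs)
    return 0.25 * total * total


def toy_nuclear_attraction(coeffs):
    # attraction to the nuclei: linear in the density and negative
    return -1.5 * sum(coeffs)


def make_energy_functional(contributions):
    """Return a function that evaluates every contribution and their sum."""

    def energy_functional(coeffs):
        energies = {name: func(coeffs) for name, func in contributions.items()}
        energies["electronic_total"] = sum(energies.values())
        return energies

    return energy_functional


functional = make_energy_functional(
    {
        "kin_plus_xc": toy_kin_plus_xc,
        "hartree": toy_hartree,
        "nuclear_attraction": toy_nuclear_attraction,
    }
)
energies = functional([1.0, 0.5])

# the nuclear repulsion does not depend on the density, so it is added afterwards
nuclear_repulsion = 0.7  # made-up constant
total_energy = energies["electronic_total"] + nuclear_repulsion
```

Keeping the contributions separate makes it easy to report each energy term individually, which is what the plots at the end of this notebook rely on.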

In [None]:
# functional factory (which is a slightly more complicated thing)
from mldft.ofdft.functional_factory import FunctionalFactory

func_factory = FunctionalFactory.from_module(
    module=mldft_module_trained,
    xc_functional=config_ofdft.xc_functional,  # not used in our case since we predict T_s + E_xc
    negative_integrated_density_penalty_weight=config_ofdft.negative_integrated_density_penalty_weight,
    # the latter is zero by default (no penalty for regions with negative electron densities)
)

Now, on to the SAD (Sum of Atomic Densities):

The SAD guess is a sum of independent atom-type specific densities that are based on dataset statistics
and for which the total number of electrons matches the total number of electrons in the molecule.
Even though there are other initial-guess methods such as MINAO or HÜCKEL (well-established initial guesses already implemented in the `pyscf` package), as well as the option to learn the initial guess, we use the simple SAD guess as the default in STRUCTURES25 since it is the cheapest of the ones listed.
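
As a rough illustration of the idea, the following toy sketch concatenates per-element mean coefficients and rescales them so that the guess integrates to the right number of electrons. Everything here is an assumption made for the example: the statistics tables, the function `toy_sad_guess`, and the single global rescaling, which is simpler than the `PER_ATOM_WEIGHTED` mode used below and is not the `SADGuesser` implementation.

```python
# Hypothetical per-element statistics: mean coefficients and basis function integrals.
MEAN_COEFFS = {"H": [0.9, 0.2], "O": [6.5, 1.8, 0.4]}
BASIS_INTEGRALS = {"H": [1.0, 0.5], "O": [1.0, 0.5, 0.25]}
ATOMIC_NUMBERS = {"H": 1, "O": 8}


def toy_sad_guess(elements, charge=0):
    """Concatenate per-element mean coefficients and rescale to the electron count."""
    coeffs, integrals = [], []
    for element in elements:
        coeffs.extend(MEAN_COEFFS[element])
        integrals.extend(BASIS_INTEGRALS[element])
    n_electrons = sum(ATOMIC_NUMBERS[element] for element in elements) - charge
    # number of electrons contained in the unnormalized guess
    n_guess = sum(c * w for c, w in zip(coeffs, integrals))
    scale = n_electrons / n_guess
    return [c * scale for c in coeffs], integrals


# water: one oxygen and two hydrogens, 10 electrons in total
water_coeffs, water_integrals = toy_sad_guess(["O", "H", "H"])
n_electrons_in_guess = sum(c * w for c, w in zip(water_coeffs, water_integrals))
```

Because the guess only needs a table lookup per atom and one rescaling, it is very cheap compared to guesses that require solving an eigenvalue problem.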

In [None]:
from mldft.ofdft.callbacks import ConvergenceCallback
from mldft.ofdft.density_optimization import density_optimization_with_label
from mldft.utils.sad_guesser import SADNormalizationMode

# since we do use SAD (Sum of Atomic Densities) as initial density guess, we have to specify
# the following keyword arguments that are passed to the SAD guesser,
# see SADGuesser class for details:

sad_guess_kwargs = dict(
    dataset_statistics=model_dataset_statistics_for_denop,
    normalization_mode=SADNormalizationMode.PER_ATOM_WEIGHTED,
    basis_info=basis_info,
    weigher_key="ground_state_only",
    spherical_average=True,
)

## 3 Denop process and results

In [None]:
# change max number of iterations:
density_optimizer.max_cycle = 10  # You might want to change the number of iterations to speed things up, but then convergence is not guaranteed

metric_dict, callback, energies_label, energy_functional = density_optimization_with_label(
    sample=sample_double,  # OFData sample object containing required tensors for the functional.
    mol=sample_double.mol,  # Molecule object used for the initial guess and building the grid (used for eval of XC functional).
    optimizer=density_optimizer,  # Optimizer used for the density optimization process.
    func_factory=func_factory,  # see above
    callback=ConvergenceCallback(),  # specifies which iteration to report as the converged result
    # (in our case of "last_iter" as convergence criterion, this is simply the last iteration.)
    initial_guess_str=config_ofdft.initialization,  # in our case SAD is used as initial guess (see above)
    max_xc_memory=config_ofdft.ofdft_kwargs.max_xc_memory,  # XC is computed on the grid this defines an upper limit for the grid size
    # not relevant when using e_kin_plus_xc as training target
    # best doc string explanation is the following:
    #  Guess of the maximum memory that should be taken by the aos in MB. Total usage might be higher.
    #       Defaults to the pyscf default of 4000MB
    normalize_initial_guess=config_ofdft.ofdft_kwargs.normalize_initial_guess,  # Whether to normalize the initial guess to the correct number of electrons.
    proj_minao_module=None,  # Lightning module used to improve the initial guess from SAD to some learned initial guess.
    sad_guess_kwargs=sad_guess_kwargs,  # see above
    convergence_criterion=config_ofdft.convergence_criterion,  # in our case "last_iter", i.e. we simply take the last iteration and stop iterating if the gradient norm is below the convergence_tolerance
    disable_printing=False,  # Whether to disable printing of the optimization progress.
)

The function "density_optimization_with_label" returns the following objects:
* metric_dict: Dictionary containing various metrics collected during the optimization process.
* callback: ConvergenceCallback object used to determine convergence.
* energies_label: Dictionary containing the energies computed during the optimization process, including the label energies.
* energy_functional: The energy functional used for the optimization.
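
To illustrate how a callback can record every iteration and then report one of them as the converged result, here is a small toy sketch. `ToyConvergenceCallback` is a hypothetical stand-in and not the `ConvergenceCallback` from `mldft`; in particular, the `"min_gradient_norm"` criterion is made up for contrast with `"last_iter"`.

```python
class ToyConvergenceCallback:
    """Hypothetical stand-in that records one entry per optimization step."""

    def __init__(self):
        self.energy = []
        self.gradient_norm = []

    def __call__(self, energy, gradient_norm):
        self.energy.append(energy)
        self.gradient_norm.append(gradient_norm)

    def get_convergence_result(self, criterion="last_iter"):
        if criterion == "last_iter":
            index = len(self.energy) - 1
        elif criterion == "min_gradient_norm":  # made-up alternative for illustration
            index = min(range(len(self.gradient_norm)), key=self.gradient_norm.__getitem__)
        else:
            raise ValueError(f"Unknown criterion: {criterion}")
        return {
            "iteration": index,
            "energy": self.energy[index],
            "gradient_norm": self.gradient_norm[index],
        }


callback_toy = ToyConvergenceCallback()
for energy, grad_norm in [(-1.00, 1e-1), (-1.20, 1e-3), (-1.19, 5e-3)]:
    callback_toy(energy, grad_norm)

last = callback_toy.get_convergence_result()
best = callback_toy.get_convergence_result("min_gradient_norm")
```

With `"last_iter"`, the reported state is simply the final one, even if an earlier iteration had a smaller gradient norm, as in this made-up trajectory.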

If you want to, you can also look at the results individually:

In [None]:
# contains metrics evaluating the final density and energy after denop (gs=ground_state)
metric_dict

In [None]:
from mldft.ofdft.energies import Energies

# contains the different energies at the "converged" state
energies_label.energies_dict

In [None]:
# The callback contains the states of all iterations
print("Length of callback.states:", len(callback.energy))
# the converged result can be obtained via:
print("Converged result:")
callback.get_convergence_result()

## 4 Visualizations of results

As a last step, let's visualize the density optimization process and see how it converges toward the set tolerance.

In [None]:
# plot how the gradient norm evolved during denop:
plt.plot(callback.gradient_norm, label="Gradient norm")
plt.yscale("log")
plt.xlabel("Iteration")
plt.ylabel("Gradient norm")
plt.title("Density optimization convergence")
plt.axhline(
    density_optimizer.convergence_tolerance, color="red", linestyle="--", label="Tolerance"
)
plt.legend()
plt.show()

Furthermore, we want to look at how the energies evolve per iteration (remember that we predict T_s + E_xc, while the other contributions are either adjusted according to the density, like the "hartree" and "nuclear_attraction" energies, or independent of it).

In [None]:
from mldft.ofdft.ofstate import OFState

# OFState returned by the ConvergenceCallback,
# amongst other things, it contains the predicted coefficients and energies
# at every iteration
energies_dict = {key: [] for key in energies_label.energies_dict.keys()}
for key in energies_label.energies_dict.keys():
    for energy in callback.energy:
        energies_dict[key].append(energy.energies_dict[key])
    energies_dict[key] = np.array(energies_dict[key])

# build the total energy from the different contributions
total_energy = np.zeros_like(next(iter(energies_dict.values())))
for key in energies_dict.keys():
    total_energy += energies_dict[key]
energies_dict["total_energy"] = total_energy

# plot a curve of how the energy (as predicted by our model) evolved during denop
# kin_plus_xc is the energy that our model directly predicts
# all other energy contributions are computed from the density (with existing functionals)
# the total energy is the sum of all contributions
# important note: the total energy is therefore an approximation to the true DFT energy
# based on our learned functional (model)
for key in energies_dict.keys():
    plt.plot(energies_dict[key], label=key)
plt.xlabel("Denop iteration")
plt.ylabel("Predicted energy (mHa)")
plt.title("Energy during density optimization")
plt.legend(loc=(1.01, 0))
plt.show()

Lastly, a full overview of the denop process:

In [None]:
# okay in the above plot we don't see much change in the energies during denop
# more interesting is to look at the difference between the predicted energies
# and the energy ground state labels
# for that, we can use the plot from plot_density_optimization
# that is also shown in the pdf of denop plots:
from mldft.ml.data.components.basis_transforms import transform_tensor_with_sample
from mldft.ml.data.components.of_data import Representation
from mldft.utils.plotting.density_optimization import plot_density_optimization

# for comparison, transform the ground state coeffs back to the untransformed representation,
# since the trajectory coeffs are transformed back before in the callback
gs_coeffs = transform_tensor_with_sample(
    sample_double, sample_double.ground_state_coeffs, Representation.VECTOR, invert=True
)

fig = plot_density_optimization(
    callback=callback,
    energies_label=energies_label,
    coeffs_label=gs_coeffs,  # ground state density coefficients as label used for computing the density error
    sample=sample_double,
)

In this plot, we first see the energy differences to the ground state energy in mHa. In more detail, 
the energy oscillates around the ground state energy until it converges to a final energy that is slightly above the ground state energy.

The second panel shows the error between the predicted and the target density as well as the gradient norm.
The density error is computed as the L2 norm of the difference between the predicted and target density on the grid.
The gradient norm is the norm of the gradient of the energy with respect to the density coefficients. 
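
The grid-based L2 error can be sketched with a toy one-dimensional example. The Gaussian densities, the uniform grid, and the helper names `gaussian_density` and `l2_density_error` are assumptions for illustration only; the real plot uses the molecular integration grid and the library's own routines.

```python
import math


def gaussian_density(x, center, width, n_electrons):
    """1D Gaussian density that integrates to n_electrons."""
    norm = n_electrons / (width * math.sqrt(2.0 * math.pi))
    return norm * math.exp(-0.5 * ((x - center) / width) ** 2)


# uniform 1D grid with a constant quadrature weight (hypothetical toy grid)
n_points = 2001
x_min, x_max = -10.0, 10.0
spacing = (x_max - x_min) / (n_points - 1)
grid = [x_min + i * spacing for i in range(n_points)]


def l2_density_error(rho_a, rho_b, weight):
    """sqrt( sum_i w_i * (rho_a(r_i) - rho_b(r_i))^2 ) with constant weight w_i."""
    return math.sqrt(sum(weight * (a - b) ** 2 for a, b in zip(rho_a, rho_b)))


rho_target = [gaussian_density(x, 0.0, 1.0, 2.0) for x in grid]
rho_predicted = [gaussian_density(x, 0.1, 1.0, 2.0) for x in grid]  # slightly shifted

error = l2_density_error(rho_predicted, rho_target, spacing)
n_electrons_on_grid = sum(rho_target) * spacing
```

Note that both toy densities contain the same number of electrons, so the error measures only the difference in shape, which is also the situation after a normalized initial guess.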

The third panel shows the change in the density coefficients during denop.

Finally, the last panel shows the dipole moment differences to the ground state dipole moment in au (atomic units).


In this tutorial, we have illustrated some of the inner workings behind density optimization. For many of the small steps, like getting a data sample on which we can run density optimization, 
there exists some high-level functionality in our code base (in [run_density_optimization.py](../../mldft/ofdft/run_density_optimization.py)):  

`SampleGenerator`: a class to obtain individual data samples from a full model/data config  
`run_singlepoint_ofdft`: a function that runs a full density optimization for the given molecule  