# Tutorial 3: Demystifying our data transformations -- Mastering the MasterTransformation

The goal of this tutorial is to understand the MasterTransformation class, which handles the transformation of data samples before they are passed to the model for training. The tutorial also includes a visualization of the most important transforms, the so-called basis transforms.

In [None]:
# import necessary packages
import os

import matplotlib.pyplot as plt
import numpy as np
import rich
import torch
from hydra import compose, initialize
from hydra.utils import instantiate

# this makes sure that code changes are reflected without restarting the notebook
# this can be helpful if you want to play around with the code in the repo
%load_ext autoreload
%autoreload 2

# omegaconf is used for configuration management
# omegaconf custom resolvers are small functions used in the config files like "get_len" to get lengths of lists
from mldft.utils import omegaconf_resolvers  # this registers omegaconf custom resolvers
from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree
from mldft.utils.molecules import build_molecule_ofdata

# download a small dataset from huggingface that contains QM9 and QMugs data (possibly already downloaded)
# and change the DFT_DATA environment variable to the directory where the data is stored

# https://huggingface.co/docs/datasets/cache#cache-directory
# The default cache directory is `~/.cache/huggingface/datasets`
# You can change it by setting this variable to any path you like
CACHE_DIR = None  # e.g. change it to "./hf_cache"

# clone the full repo
# https://huggingface.co/sciai-lab/structures25/tree/main
os.environ[
    "HF_HUB_DISABLE_PROGRESS_BARS"
] = "1"  # to avoid problems with the progress bar in some environments
from huggingface_hub import snapshot_download

data_path = snapshot_download(
    repo_id="sciai-lab/minimal_data_QM9_QMugs", cache_dir=CACHE_DIR, repo_type="dataset"
)

dft_data = os.environ.get("DFT_DATA", None)
os.environ["DFT_DATA"] = data_path
print(
    f"Environment variable DFT_DATA has been changed from {dft_data} to {os.environ['DFT_DATA']}."
)

In [None]:
# first, we load our large config, instantiate the datamodule and obtain a single sample
with initialize(version_base=None, config_path="../../configs/ml"):
    config = compose(
        config_name="train.yaml",
        overrides=[],
    )

# remove the hydra specific stuff that only works in @hydra.main decorated functions
config.paths.output_dir = "example_path"

datamodule = instantiate(config.data.datamodule)
datamodule.setup(stage="fit")
sample = datamodule.train_set[0]

# need basis info to build a pySCF molecule object
# see below for more details on basis_info
basis_info = instantiate(config.data.basis_info)

# build a pySCF molecule object from the OFData sample
mol = build_molecule_ofdata(sample, basis=basis_info.basis_dict)

In [None]:
# one important but (slightly) tricky topic is the set of transforms that are applied
#  to a data sample when it is loaded from the dataset
rich.print(dict_to_tree(config.data.transforms, guide_style="dim"))

## 1 The MasterTransformation class

Some of our data transformations are quite expensive and should therefore not be performed on the fly during training.
As a solution, we have precomputed several transformed versions of our datasets and saved them to the file servers (we call this cached data).
In this tutorial, two transforms have already been applied to the data, which is then loaded as "cached":
* the transformation into local frames (local reference frames at every atom)
* the global symmetric natural reparametrization (natrep), that is, an orthonormalization of the basis functions

### Cached data, basis transforms, pre- and post transforms

In [None]:
# all transforms are combined into one class called the MasterTransformation
from mldft.ml.data.components.basis_transforms import MasterTransformation

datamodule.transforms.__dict__
print("Name of (cached) transforms:", datamodule.transforms.name)
print("Whether to use cached data:", datamodule.transforms.use_cached_data, "\n")
print("cached_basis_transforms:", datamodule.transforms.cached_basis_transforms, "\n")
# if cached data is used, these transforms must not be applied again on the fly during training
# however, if use_cached_data=False, these transforms are applied on the fly instead

Global symmetric natrep is the reason that the basis function integrals are no longer zero for a majority of the basis functions (l>0 prior to natrep, cf. [Tutorial 1](./tutorial_1_datamodule.ipynb)).
In fact, natrep performs a change of basis to a new set of basis functions that are **orthonormal** on a global level.

In [None]:
# basis function integrals are no longer zero:
# side note: just ignore the word "dual" in the "dual_basis_integrals" attribute name
print("Basis function integrals after natrep (first 10):", sample.dual_basis_integrals[:10])

#### Global symmetric natrep
By diagonalizing the overlap matrix $O_{\mu\nu}$ of the basis functions $\omega_\mu$,
$$
O_{\mu\nu} = \int \mathrm d^3 \vec r \ \omega_\mu(\vec r) \, \omega_\nu(\vec r) ,
$$
we find a change of basis that can be used to make all basis functions mutually orthogonal. Furthermore, we can normalize the resulting basis functions so that the overlap matrix in the transformed basis becomes the identity matrix.
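The orthonormalization described above can be sketched numerically. The following is a minimal illustration using a random toy overlap matrix and the symmetric inverse square root $O^{-1/2}$ built from the eigendecomposition; the variable names are made up for this example, and it is not meant to reproduce the actual mldft implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy overlap matrix: symmetric positive definite with unit diagonal,
# mimicking normalized but non-orthogonal basis functions.
d = 5
B = rng.normal(size=(d, 3 * d))
O = B @ B.T
scale = 1.0 / np.sqrt(np.diag(O))
O = scale[:, None] * O * scale[None, :]

# Diagonalize: O = V diag(lam) V^T with orthonormal eigenvectors V.
lam, V = np.linalg.eigh(O)

# Symmetric inverse square root O^{-1/2}.
O_inv_sqrt = V @ np.diag(lam**-0.5) @ V.T

# With new basis functions omega'_mu = sum_nu (O^{-1/2})_{mu nu} omega_nu,
# the overlap matrix in the new basis is O^{-1/2} O O^{-1/2}.
O_new = O_inv_sqrt @ O @ O_inv_sqrt
print("New overlap matrix is the identity:", np.allclose(O_new, np.eye(d)))
```

Scaling the eigenvectors by `lam**-0.5` is the normalization step; multiplying by `V.T` again afterwards is what makes this the symmetric variant.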

All other transforms are applied on the fly during training
but there are several different types of such transforms:

First, the **pre_transforms** are directly applied to the OFData sample when it is loaded from the disk.
All molecules (transformed or not) are saved as individual zarr files on the disk.

In the default case the pre_transforms are
* ToTorch: to convert numpy arrays to torch tensors
* ProjectGradient: to project the gradient label orthogonally to the direction in which the number of electrons changes
(this is important since we want to keep the number of electrons constant during density optimization)
* AddFullEdgeIndex: to add a list of edges of a fully connected graph to the sample data
(used for message passing neural networks)
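To make the last item concrete, here is a small sketch of what an edge list of a fully connected graph can look like. The helper `full_edge_index` is hypothetical and written with NumPy for illustration; whether the actual AddFullEdgeIndex transform excludes self-loops or uses this exact layout is an assumption.

```python
import numpy as np


def full_edge_index(n_atoms):
    """Edges (source, target) of a fully connected directed graph without self-loops."""
    src, dst = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms), indexing="ij")
    mask = src != dst  # drop self-loops (an assumption of this sketch)
    return np.stack([src[mask], dst[mask]])


edge_index = full_edge_index(3)
print(edge_index.shape)  # (2, n_atoms * (n_atoms - 1))
print(edge_index)
```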

#### Gradient projection
The gradient projection ensures that a step in the direction of the gradient will not change the number of electrons.

Let $\mathrm w_\mu$ be the integral of the basis function $\omega_\mu(\vec r)$:
$$
\mathrm w_\mu = \int \mathrm d^3 \vec r \ \omega_\mu(\vec r) .
$$
The number of electrons for a given density is then
$$
N_e = \int \mathrm d^3 \vec r \ \sum_\mu p_\mu \omega_\mu(\vec r) = \sum_\mu p_\mu \mathrm w_\mu = \mathbf p^T \mathbf w ,
$$
where we have collected all $\mathrm w_\mu$ in a vector $\mathbf w$. The projection operator that acts on the gradients is given by
$$
\Pi = I - \frac{\mathbf w \mathbf w^T}{\mathbf w^T\mathbf w} .
$$
One can easily check that indeed $\Pi \Pi = \Pi$ and $\Pi^T = \Pi$. If we now consider an arbitrary change to our density coefficients, $\mathbf p \to \mathbf p' = \mathbf p + \Delta \mathbf p$, the number of electrons will in general change. But for $\mathbf p \to \mathbf p' = \mathbf p + \Pi \Delta \mathbf p$ it stays constant:
$$
N_e' = (\mathbf p + \Pi \Delta \mathbf p)^T \mathbf w = (\mathbf p^T + \Delta \mathbf p^T \Pi^T)  \mathbf w = N_e + \Delta \mathbf p^T \Big(I - \frac{\mathbf w \mathbf w^T}{\mathbf w^T\mathbf w} \Big) \mathbf w = N_e + \Delta \mathbf p^T \Big(\mathbf w - \mathbf w \frac{\mathbf w^T\mathbf w}{\mathbf w^T\mathbf w}\Big) = N_e .
$$
Indeed, $\Pi$ is a projection operator that, when applied to the gradient step ($\Delta \mathbf p = \text{learning rate} \times \nabla_p E$), preserves the number of electrons $N_e$ of the corresponding electron density.
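The derivation above can be checked numerically. The sketch below uses random toy values for the basis integrals, coefficients and gradient (none of them come from the dataset) and compares an unprojected gradient step with a projected one.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 6
w = rng.uniform(0.1, 1.0, size=d)  # toy basis function integrals w_mu
p = rng.normal(size=d)             # toy density coefficients
grad = rng.normal(size=d)          # toy energy gradient with respect to p

# projection operator Pi = I - w w^T / (w^T w)
Pi = np.eye(d) - np.outer(w, w) / (w @ w)

lr = 0.1  # learning rate
p_raw = p - lr * grad            # plain gradient step
p_proj = p - lr * (Pi @ grad)    # projected gradient step

print("N_e before:              ", p @ w)
print("N_e after raw step:      ", p_raw @ w)
print("N_e after projected step:", p_proj @ w)
print("Pi is idempotent:", np.allclose(Pi @ Pi, Pi))
```

Only the projected step leaves $\mathbf p^T \mathbf w$ unchanged, because $\Pi$ removes exactly the component of the gradient along $\mathbf w$.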

In [None]:
print("pre_transforms:", datamodule.transforms.pre_transforms)

Second, the **additional_pre_transforms**: In contrast to pre_transforms,
additional_pre_transforms are only used if cached data is NOT used.
Therefore, in our case, the AddOverlapMatrix transform is not applied
even though it is specified in the config.


In [None]:
print("additional_pre_transforms:", datamodule.transforms.additional_pre_transforms)

Third, the **basis_transforms** are always applied.

In [None]:
# since we use cached data, we do not apply any basis transforms on the fly
print("basis_transforms:", datamodule.transforms.basis_transforms, "\n")

Fourth, the **post_transforms**: These are also always applied
and typically prepare the data for the model,
e.g. one can change the dtype here between float32 and float64.
In the default case, we use ToTorch to make sure that all attributes in OFData are converted to torch tensors.

In [None]:
print("post_transforms:", datamodule.transforms.post_transforms)

**The reason for our complicated transform structure is the following:**  
Basis transforms, such as the local frames transforms or the natrep transformation, affect the basis functions (see below). Therefore, for consistency, basis transforms transform *all* fields in the sample according to their geometric representation. Thus, the pre_transforms are important to potentially add attributes to the data samples which should then be affected by the basis transforms, e.g. the AddOverlapMatrix transform.
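The interplay of the four transform groups and the `use_cached_data` flag can be summarized in a few lines. The helper `active_transforms` below is purely hypothetical and only mirrors the description in this tutorial; in particular, the exact ordering inside the real MasterTransformation is an assumption.

```python
def active_transforms(pre, additional_pre, cached_basis, basis, post, use_cached_data):
    """Return the names of the transforms applied on the fly, in order.

    Hypothetical helper for illustration only; not part of mldft.
    """
    stages = list(pre)
    if not use_cached_data:
        # untransformed data is loaded: first add extra fields, then apply
        # the basis transforms that would otherwise have been precomputed
        stages += list(additional_pre)
        stages += list(cached_basis)
    stages += list(basis)  # always applied
    stages += list(post)   # always applied
    return stages


cfg = dict(
    pre=["ToTorch", "ProjectGradient", "AddFullEdgeIndex"],
    additional_pre=["AddOverlapMatrix"],
    cached_basis=["LocalFrames", "SymmetricGlobalNatrep"],
    basis=[],
    post=["ToTorch"],
)
print("cached:   ", active_transforms(**cfg, use_cached_data=True))
print("not cached:", active_transforms(**cfg, use_cached_data=False))
```

The point of the ordering is that fields such as the overlap matrix are added before the basis transforms run, so that they are transformed consistently with everything else.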

Next, let us manually set `use_cached_data` to False so that
the AddOverlapMatrix transform is actually applied.
In that case, the data will be loaded as untransformed data
and then the LocalFrames and SymmetricGlobalNatrep basis transforms will be applied on the fly.

In [None]:
import time

# let us initialize a second datamodule but with use_cached_data=False
with initialize(version_base=None, config_path="../../configs/ml"):
    config_no_cache = compose(
        config_name="train.yaml",
        overrides=[
            # we need one simple override here but otherwise we just use the default setting (see tutorial 4 for more information)
            "data.transforms.use_cached_data=False",  # override to not use cached data
        ],
    )

# remove the hydra specific stuff that only works in @hydra.main decorated functions
config_no_cache.paths.output_dir = "example_path"

datamodule_no_cache = instantiate(config_no_cache.data.datamodule)
datamodule_no_cache.setup(stage="fit")

t0 = time.time()
sample_cache = datamodule.train_set[0]
print(
    f"Sample has overlap_matrix: {hasattr(sample_cache, 'overlap_matrix')}, since additional_pre_transforms are not used."
)
print(f"Loading sample with cached transforms took: {time.time()-t0:.2f} seconds")
t0 = time.time()
sample_no_cache = datamodule_no_cache.train_set[0]
print(
    f"Sample has overlap_matrix: {hasattr(sample_no_cache, 'overlap_matrix')}, since additional_pre_transforms are used."
)
print(f"Loading sample without cached transforms took: {time.time()-t0:.2f} seconds")

In [None]:
# We can confirm that after symmetric global natrep,
# the overlap matrix of the basis functions is indeed close to the identity matrix:
print(
    "Overlap matrix close to identity",
    torch.allclose(sample_no_cache.overlap_matrix, torch.eye(sample_no_cache.coeffs.shape[0])),
)

# let's confirm that otherwise the two samples are identical
for key in ["pos", "coeffs", "ground_state_coeffs", "dual_basis_integrals"]:
    print(
        f"Checking that {key} are close:", torch.allclose(sample_cache[key], sample_no_cache[key])
    )

## 2 Visualization of local frames transformation
In our project (when using the Graphformer architecture as in STRUCTURES25), we use local frames to canonicalize the geometric input data to achieve rotational equivariance.

Therefore, let us visualize the local frames (computed based on nearest-neighbor positions).
Note that the x-axis (the green arrow) of the local frames is not visible as it always points towards the nearest neighbor atom and is therefore "swallowed" by the bond between these two atoms.
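To see why local frames canonicalize the input, consider the following toy construction. Only the fact that the x-axis points to the nearest neighbor is taken from this tutorial; using the second-nearest neighbor and a Gram-Schmidt step for the remaining axes is an assumption of this sketch, and `local_frame` and `local_coords` are hypothetical helpers, not the mldft implementation.

```python
import numpy as np


def local_frame(pos, i):
    """Right-handed orthonormal frame at atom i; rows are the x, y and z axes."""
    dist = np.linalg.norm(pos - pos[i], axis=1)
    dist[i] = np.inf  # exclude the atom itself
    n1, n2 = np.argsort(dist)[:2]  # nearest and second-nearest neighbor
    x = pos[n1] - pos[i]
    x = x / np.linalg.norm(x)
    v = pos[n2] - pos[i]
    y = v - (v @ x) * x  # Gram-Schmidt: remove the component along x
    y = y / np.linalg.norm(y)
    z = np.cross(x, y)
    return np.stack([x, y, z])


def local_coords(pos, i):
    """Coordinates of all atoms relative to atom i, expressed in its local frame."""
    return (pos - pos[i]) @ local_frame(pos, i).T


rng = np.random.default_rng(0)
pos = rng.normal(size=(6, 3))  # toy "molecule" with random atom positions

# random proper rotation (det = +1) from a QR decomposition
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(Q) < 0:
    Q[:, 0] *= -1
pos_rot = pos @ Q.T

print(
    "Local coordinates unchanged by a global rotation:",
    np.allclose(local_coords(pos, 2), local_coords(pos_rot, 2)),
)
```

Because the frame rotates together with the molecule, quantities expressed in the local frame do not depend on the global orientation.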

In [None]:
import sys

# keep only the program name so downstream parsers don't see Jupyter's -f=...
sys.argv = sys.argv[:1]

import pyvista

# let us explicitly calculate local frames for the given sample:
from mldft.ml.models.components.local_frames_module import (
    LocalFramesTransformMatrixDense,
)
from mldft.utils.visualize_3d import (
    get_local_frames_mesh_dict,
    get_sticks_mesh_dict,
    visualize_orbital,
)

# predict the local frames from the atomic positions and atom types:
local_frames_module = LocalFramesTransformMatrixDense()
transformation_matrix, lframes = local_frames_module.sample_forward(sample, return_lframes=True)

local_frames_mesh = get_local_frames_mesh_dict(
    origins=sample.pos,
    bases=lframes,
    scale=2,
    # axes_radius_scale=0.06
)

# this gives a ball and stick model of the molecule
molecule_mesh = get_sticks_mesh_dict(mol)
molecule_mesh["opacity"] = 1

# plot the molecule and the local frames using pyvista:
pyvista.set_jupyter_backend("html")
pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)
pl.add_mesh(**local_frames_mesh)
pl.add_mesh(**molecule_mesh)
pl.enable_shadows()
pl.reset_camera(
    bounds=0.9 * np.stack([mol.atom_coords().min(0), mol.atom_coords().max(0)], axis=1).flatten()
)

print("3d visualization of local coordinate frames at every atom:")
img = pl.show(screenshot=True, window_size=(800, 400))

Let us illustrate the effect of the local frames transformation on a single basis function.

In [None]:
# let us visualize a single basis function:
basis_function_idx = 370
node_idx = sample.coeff_ind_to_node_ind[basis_function_idx]
coeffs = np.zeros(sample.coeffs.shape)
coeffs[basis_function_idx] = 1.0  # set one coefficient to one, all others to zero

# this can be used to visualize local frames
# (in this case the global coordinate frame at the position of the atom)
global_frame_mesh = get_local_frames_mesh_dict(
    origins=sample.pos[node_idx].view(1, 3),
    # origins=torch.zeros(1, 3),
    bases=torch.eye(3)[None],
    scale=2,
)

pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)
pl = visualize_orbital(
    mol=mol,
    coeff=coeffs,
    plotter=pl,
    mode="isosurface",
    resolution=0.15,
    isosurface_quantile=0.6,
)

pl.add_mesh(**global_frame_mesh)
print("P-Orbital like basis function without transforms (and global frame):")
img = pl.show(screenshot=True, window_size=(800, 400))
plt.show()

#### How to visualize transformed basis functions (some cool theory)

The basis functions are taken from the BasisInfo object (see above where it is initialized, and [Tutorial 1](./tutorial_1_datamodule.ipynb) for more details). Short recap: The basis functions $\omega_\mu(\vec r)$ are used to represent the electron density $\rho(\vec r)$ via linear combination:
$$
\rho(\vec r) = \sum_\mu p_\mu \omega_\mu(\vec r) = \mathbf p^T \boldsymbol \omega(\vec r),
$$
where we have grouped the coefficients $p_\mu$ and basis functions $\omega_\mu$ in $d$-dimensional vectors, i.e. $\mathbf p, \boldsymbol \omega(\vec r) \in \mathbb R^d$.

A basis transformation changes the basis functions $\omega_\mu$ into new $\omega'_\mu$ that are linear combinations of the $\omega_\mu$. Similarly, the coefficients are transformed into $p'_\mu$ that are linear combinations of the $p_\mu$. Let $A \in \mathrm{GL}(d)$ be a real $d \times d$ basis transformation matrix. Under this transformation, we *demand* that the coefficients transform like vectors, i.e. $p'_\mu = \sum_\nu A_{\mu \nu}  p_\nu$ or in short $\mathbf p' = A \mathbf p$. 

Now, we ask the following question: What must the transformed $\omega'_\mu$ look like so that the density function stays the same, i.e. $\rho'(\vec r) =  \mathbf p'^T \boldsymbol \omega'(\vec r) =  \mathbf p^T \boldsymbol \omega(\vec r) = \rho(\vec r)$? The answer is the following:
$$
\boldsymbol \omega' = \big(A^{-1}\big)^T \boldsymbol \omega \ \text{ , since then: } \ \rho' =  \mathbf p'^T \boldsymbol \omega' = (A \mathbf p)^T  \big(A^{-1}\big)^T \boldsymbol \omega = \mathbf p^T \big(A^{-1} A\big)^T \boldsymbol \omega = \rho .
$$

In components, this transformation behavior reads $\omega'_\mu = \sum_\nu \big[ \big(A^{-1}\big)^T\big]_{\mu \nu} \omega_\nu = \sum_\nu \omega_\nu \big(A^{-1}\big)_{\nu \mu}$.
Thus, when interpreting $\boldsymbol \omega$ as a row vector, it transforms like $\boldsymbol \omega'^T = \boldsymbol \omega^T A^{-1}$, that is, $\boldsymbol \omega$ transforms as a dual vector (with the inverse of $A$ from the right), as can be seen in the `transform_tensor` function in [basis_transforms.py](../../mldft/ml/data/components/basis_transforms.py).  


Lastly, we want to answer the following question: How can we look at the transformed basis functions $\boldsymbol \omega'$ without actually changing the basis functions, but only by changing the coefficients? For that, we consider the special "density" defined by
$$
\omega'_\sigma(\vec r) = \big (\mathbf p^{(\sigma)} \big )^T \boldsymbol \omega'(\vec r) , \ \text{ with }  \ p^{(\sigma)}_\mu = \begin{cases} 1 \ \text{ if }  \ \mu = \sigma \\
0  \ \text{ else } \end{cases} .
$$
Now, based on the above considerations, we know how to find the appropriate coefficients to visualize this density in the original untransformed basis, namely:
$$
\omega'_\sigma(\vec r) = \big ( \mathbf p^{(\sigma)} \big )^T \boldsymbol \omega'(\vec r) = \big (\mathbf p^{(\sigma)} \big )^T \Big( \big(A^{-1}\big)^T \boldsymbol \omega \Big) = \big (A^{-1}   \mathbf p^{(\sigma)} \big )^T \boldsymbol \omega
$$
So, we conclude that we can effectively visualize the transformed basis function $\omega'_\sigma(\vec r)$ in the untransformed basis by using the following coefficients:
$$
\mathbf p'^{(\sigma)} = A^{-1} \mathbf p^{(\sigma)} .
$$ 
This is exactly what we will do below.
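Before applying this to the real sample, the transformation rules can be verified on toy data. In the sketch below, the basis functions are represented by random values at a few sample points and $A$ is a random, diagonally dominant (hence invertible) matrix; all names are chosen for this example only.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_points = 4, 50

# omega[mu, k]: toy value of basis function mu at sample point r_k
omega = rng.normal(size=(d, n_points))
p = rng.normal(size=d)

# strictly diagonally dominant matrix, therefore guaranteed to be invertible
A = 3.0 * np.eye(d) + rng.uniform(-0.5, 0.5, size=(d, d))
A_inv = np.linalg.inv(A)

p_new = A @ p                # coefficients transform like vectors
omega_new = A_inv.T @ omega  # basis functions transform like dual vectors

# the density is the same in both bases
print("Density invariant:", np.allclose(p @ omega, p_new @ omega_new))

# coefficients that reproduce the transformed basis function sigma
# in the *untransformed* basis: A^{-1} applied to the unit vector e_sigma
sigma = 2
e_sigma = np.zeros(d)
e_sigma[sigma] = 1.0
coeffs_for_plot = A_inv @ e_sigma
print(
    "Transformed basis function recovered:",
    np.allclose(coeffs_for_plot @ omega, omega_new[sigma]),
)
```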

In [None]:
# now we apply a transform to the coeffs to see how the basis function changes:
from mldft.ml.data.components.basis_transforms import transform_tensor
from mldft.ml.data.components.of_data import Representation

# actually we transform the coeffs with the inverse to see how the basis function will change:
# (see explanation above)
transformed_coeffs = transform_tensor(
    tensor=torch.from_numpy(coeffs).float(),
    transformation_matrix=transformation_matrix.T,  # the transpose is the inverse for Wigner-D matrices
    inv_transformation_matrix=transformation_matrix,  # the inverse of the inverse is the original matrix
    representation=Representation.VECTOR,  # ensures multiplication with A^{-1} from the left (see above)
)

# this can be used to visualize local frames
global_frame_mesh = get_local_frames_mesh_dict(
    origins=sample.pos[node_idx].view(1, 3),  # at the position of the atom
    bases=lframes[node_idx].view(1, 3, 3),  # use local frame instead of global frame now
    scale=2.5,
    axes_radius_scale=0.06,
)

pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)
pl = visualize_orbital(
    mol=mol,
    coeff=transformed_coeffs,
    plotter=pl,
    mode="isosurface",
    resolution=0.15,
    isosurface_quantile=0.6,
    bond_radius=0.1,
)

pl.add_mesh(**global_frame_mesh)
print("P-Orbital like basis function after local frames transformation (and local frame):")
img = pl.show(screenshot=True, window_size=(800, 400))

For more information on irreducible representations, equivariance with respect to rotations, and the Wigner-D matrices, consider watching the following lecture videos: [part 1](https://www.youtube.com/watch?v=gbEaHqJA9vI) and [part 2](https://www.youtube.com/watch?v=1-Z50VmIf9s).

Indeed, we can see that the basis function is transformed. Before the transform, the blue part of the handles points in the direction of the red axis of the **global** reference frame (first visualization), and after the transform the blue part of the handles points in the direction of the red axis of the **local** reference frame.

## 3 Visualization of NatRep transformation


As a next step, let us visualize the effect of global symmetric natrep
together with the local frames transform on the same basis function that we visualized above.


The effect of global symmetric natrep is that the basis function is now a linear combination of all basis functions where the coefficients of that linear combination are obtained from the basis change that diagonalizes the overlap matrix. 


Therefore, the basis function will be more delocalized. However, the *symmetric* version of global natrep ensures that the overlap of the old and the new basis functions is maximized or, equivalently, that the distance between them (summed over all basis functions) is minimized under the following metric:
$$
\left \| u - v \right \|^2 = \int \vert u(\vec r) - v(\vec r) \vert^2 \ \mathrm d^3\vec r
$$
Thus, the old and the new basis functions are still fairly similar.
In doing so, the *symmetric* natrep ensures that the new basis functions are still fairly localized.
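This closeness property can be illustrated by comparing the symmetric orthonormalization with another valid one. The sketch below uses a random toy overlap matrix and, as an arbitrary point of comparison, an orthonormalization based on a Cholesky factorization; the helper `total_sq_distance` is hypothetical and evaluates the metric above summed over all basis functions.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
B = rng.normal(size=(d, 3 * d))
O = B @ B.T
scale = 1.0 / np.sqrt(np.diag(O))
O = scale[:, None] * O * scale[None, :]  # toy overlap matrix with unit diagonal


def total_sq_distance(X, O):
    """Sum over mu of ||omega'_mu - omega_mu||^2 for omega'_mu = sum_nu omega_nu X[nu, mu]."""
    D = X - np.eye(len(O))
    return np.trace(D.T @ O @ D)


# symmetric choice: X = O^{-1/2}
lam, V = np.linalg.eigh(O)
X_sym = V @ np.diag(lam**-0.5) @ V.T

# alternative choice from a Cholesky factorization O = L L^T: X = L^{-T}
L = np.linalg.cholesky(O)
X_chol = np.linalg.inv(L).T

for name, X in [("symmetric", X_sym), ("cholesky", X_chol)]:
    assert np.allclose(X.T @ O @ X, np.eye(d))  # both yield an orthonormal basis
    print(f"{name:>9}: total squared distance = {total_sq_distance(X, O):.4f}")
```

Both choices produce an orthonormal basis, but the symmetric one changes the basis functions the least.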

In [None]:
# the following will ensure that the basis transforms transformation matrix is added to the sample:
datamodule_no_cache.transforms.add_transformation_matrix = True
sample_with_trafo = datamodule_no_cache.train_set[0]

# now we apply a transform to the coeffs to see how the basis function changes:

# actually we transform the coeffs with the inverse to see how the basis function will change:
# (see explanation above)
transformed_coeffs2 = transform_tensor(
    tensor=torch.from_numpy(coeffs).float(),
    transformation_matrix=sample_with_trafo.inv_transformation_matrix,  # use the inverse transformation matrix
    inv_transformation_matrix=sample_with_trafo.transformation_matrix,  # the inverse of the inverse is the original matrix
    representation=Representation.VECTOR,  # ensures multiplication with A^{-1} from the left (see above)
)

# this can be used to visualize local frames
# (in this case the local frame at the position of the atom)
global_frame_mesh = get_local_frames_mesh_dict(
    origins=sample.pos[node_idx].view(1, 3),
    bases=lframes[node_idx].view(1, 3, 3),  # use local frame instead of global frame now
    scale=2.5,
    axes_radius_scale=0.06,
)

isosurface_quantile = 0.9
pl = pyvista.Plotter(off_screen=True, notebook=True, image_scale=1)
pl = visualize_orbital(
    mol=mol,
    coeff=transformed_coeffs2,
    plotter=pl,
    mode="isosurface",
    resolution=0.15,
    isosurface_quantile=isosurface_quantile,
    bond_radius=0.1,
)

pl.add_mesh(**global_frame_mesh)

print(
    "P-Orbital like basis function after local frames transformation and global symmetric natrep (and local frame):"
)
print(
    f"Visualized with isosurface_quantile={isosurface_quantile} (play around to see different iso surfaces)."
)
img = pl.show(screenshot=True, window_size=(800, 400))