# Training a neural network on QM9

This tutorial will explain how to use SchNetPack for training a model
on the QM9 dataset and how the trained model can be used for further applications.

First, we import the necessary modules and create a new directory for the data and our model.

In [1]:
import os
import schnetpack as spk
from schnetpack.datasets import QM9
import schnetpack.transform as trn

import torch
import torchmetrics
import pytorch_lightning as pl



## Loading the data

As explained in the [previous tutorial](tutorial_01_preparing_data.ipynb), datasets in SchNetPack are loaded with the `AtomsDataModule` class or one of its subclasses that are specialized for common benchmark datasets.
The `QM9` dataset class will download and convert the data. We will only use the dipole moment `mu`, so the other properties do not need to be loaded:

In [2]:
qm9tut = './qm9tut'
if not os.path.exists(qm9tut):
    os.makedirs(qm9tut)

%rm qm9tut/split.npz

qm9data = QM9(
    f'{qm9tut}/qm9.db', 
    batch_size=100,
    num_train=110000,
    num_val=10000,
    transforms=[
        trn.SubtractCenterOfMass(),
        #trn.ASENeighborList(cutoff=5.),
        trn.MatScipyNeighborList(cutoff=5.),
        #trn.RemoveOffsets(QM9.mu, remove_mean=True, remove_atomrefs=True),
        trn.CastTo32()
    ],
    property_units={QM9.mu: 'Debye'},
    num_workers=0,
    split_file=os.path.join(qm9tut, "split.npz"),
    pin_memory=False, # set to false, when not using a GPU
    load_properties=[QM9.mu], #only load mu property
)
qm9data.prepare_data()
qm9data.setup()

The dataset is downloaded and partitioned automatically. PyTorch `DataLoader`s can be obtained using `qm9data.train_dataloader()`, `qm9data.val_dataloader()` and `qm9data.test_dataloader()`.
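
The partitioning can be pictured as a shuffled index list that is cut into three parts. The following plain-Python sketch is only an illustration: `random_split` is a hypothetical helper, the dataset size is made up, and it assumes that all indices left over after `num_train` and `num_val` end up in the test set.

```python
import random

def random_split(n_total, num_train, num_val, seed=0):
    """Shuffle indices 0..n_total-1 and cut them into train/val/test."""
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)  # seeded for a reproducible split
    train = indices[:num_train]
    val = indices[num_train:num_train + num_val]
    test = indices[num_train + num_val:]  # assumption: the remainder is the test set
    return train, val, test

# Hypothetical dataset size, chosen only for illustration
train_idx, val_idx, test_idx = random_split(n_total=1000, num_train=800, num_val=100)
print(len(train_idx), len(val_idx), len(test_idx))
```

Storing such index lists in a file, as the `split_file` argument above suggests, lets later runs reuse the same partition.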

Before building the model, it helps to remove offsets from the energy so that training starts from good initial conditions. The required statistics are computed from the training dataset. The `RemoveOffsets` transform can do this automatically; in the transform list above it is commented out.
In the following we show what happens under the hood.
For QM9, single-atom reference values are also stored in the metadata.

These can be used together with the mean and standard deviation of the energy per atom to initialize the model with a good guess of the energy of a molecule. When calculating these statistics, we pass the atomref so that the single-atom contributions, which the model adds back to the predicted energy later, are excluded from the statistics, i.e.
\begin{equation}
\mu_{U_0} = \frac{1}{n_\text{train}} \sum_{n=1}^{n_\text{train}} \left( U_{0,n} - \sum_{i=1}^{n_{\text{atoms},n}} U_{0,Z_{n,i}} \right)
\end{equation}
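for the mean and analogously for the standard deviation. In this case, this corresponds to the mean and standard deviation of the *atomization energy* per atom. The call below computes the same statistics for the dipole moment `QM9.mu`, the property loaded above.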

means, stddevs = qm9data.get_stats(
    QM9.mu, divide_by_atoms=True, remove_atomref=True
)
print('Mean dipole moment / atom:', means.item())
print('Std. dev. dipole moment / atom:', stddevs.item())
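
To make the statistic above concrete, here is a small plain-Python sketch on toy numbers. All energies and reference values are hypothetical, the residual is divided by the atom count (as `divide_by_atoms=True` requests), and the population standard deviation is an assumption about the normalization.

```python
import math

# Hypothetical single-atom reference energies (eV), keyed by atomic number
atomref = {1: -13.6, 6: -1029.0, 8: -2042.0}

# Hypothetical training set: (atomic numbers, total energy in eV)
molecules = [
    ([8, 1, 1], -2079.0),        # water-like
    ([6, 1, 1, 1, 1], -1101.0),  # methane-like
    ([6, 8], -3082.0),           # CO-like
]

def per_atom_residual(numbers, energy):
    """Energy minus the summed atom references, divided by the atom count."""
    reference = sum(atomref[z] for z in numbers)
    return (energy - reference) / len(numbers)

residuals = [per_atom_residual(z, e) for z, e in molecules]
mean = sum(residuals) / len(residuals)
# Assumption: population standard deviation (divide by N, not N - 1)
std = math.sqrt(sum((r - mean) ** 2 for r in residuals) / len(residuals))

print("Mean residual / atom:", mean)
print("Std. dev. residual / atom:", std)
```

The residuals are much smaller in magnitude than the total energies, which is what makes them a convenient training target.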

## Setting up the model

Next, we need to build the model and define how it should be trained.

In SchNetPack, a neural network potential usually consists of three parts:

1. A list of input modules that prepare the batched data before building the representation.
   This includes, e.g., calculating pairwise distances between atoms based on neighbor indices or adding auxiliary
   inputs for response properties.
2. The representation, which constructs atom-wise features, e.g. with SchNet or PaiNN.
3. One or more output modules for property prediction.
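
The flow through these three parts can be sketched with plain functions that each read from and write to a shared dictionary. This is a toy illustration, not SchNetPack code: `pairwise_distances`, `toy_representation`, `toy_output` and `run_model` are hypothetical names, and the "representation" is a deliberately trivial stand-in.

```python
import math

def pairwise_distances(inputs):
    """Input module: add distances for the given neighbor index pairs."""
    pos = inputs["positions"]
    inputs["r_ij"] = [math.dist(pos[i], pos[j]) for i, j in inputs["pairs"]]
    return inputs

def toy_representation(inputs):
    """Representation: one scalar feature per atom (sum of inverse distances)."""
    features = [0.0] * len(inputs["positions"])
    for (i, _), r in zip(inputs["pairs"], inputs["r_ij"]):
        features[i] += 1.0 / r
    inputs["features"] = features
    return inputs

def toy_output(inputs):
    """Output module: sum atom-wise contributions into one molecular value."""
    inputs["prediction"] = sum(inputs["features"])
    return inputs

def run_model(inputs, input_modules, representation, output_modules):
    """Apply input modules, then the representation, then the output modules."""
    for module in input_modules:
        inputs = module(inputs)
    inputs = representation(inputs)
    for module in output_modules:
        inputs = module(inputs)
    return inputs

batch = {
    "positions": [(0.0, 0.0, 0.0), (0.0, 0.0, 2.0)],
    "pairs": [(0, 1), (1, 0)],  # directed neighbor pairs
}
result = run_model(batch, [pairwise_distances], toy_representation, [toy_output])
print(result["prediction"])
```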

Here, we use the `PaiNN` representation with 3 interaction layers, a 5 Angstrom cosine cutoff with pairwise distances
expanded on 20 Gaussians and 128 atom-wise features. Then, we use a `DipoleMoment` module to predict the magnitude
of the dipole moment.
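
The distance expansion and the cutoff mentioned above can be sketched in plain Python before looking at the actual model definition. The cosine form of the cutoff is the standard one; the evenly spaced centers and the choice of width equal to the center spacing are assumptions for illustration and may differ from the library defaults.

```python
import math

def cosine_cutoff(r, cutoff):
    """Smoothly decays from 1 at r=0 to 0 at r=cutoff; zero beyond."""
    if r >= cutoff:
        return 0.0
    return 0.5 * (math.cos(math.pi * r / cutoff) + 1.0)

def gaussian_rbf(r, n_rbf, cutoff):
    """Expand a distance on n_rbf Gaussians with centers evenly spaced on [0, cutoff]."""
    spacing = cutoff / (n_rbf - 1)
    centers = [k * spacing for k in range(n_rbf)]
    # Assumption: the width of every Gaussian equals the center spacing
    return [math.exp(-0.5 * ((r - c) / spacing) ** 2) for c in centers]

# A distance of 1.5 Angstrom, expanded on 20 Gaussians with a 5 Angstrom cutoff
features = gaussian_rbf(1.5, n_rbf=20, cutoff=5.0)
weight = cosine_cutoff(1.5, cutoff=5.0)
print(len(features), round(weight, 3))
```

Each distance thus becomes a vector of 20 smooth features, and the cutoff weight makes the contribution of a neighbor vanish continuously as it approaches 5 Angstrom.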

In [3]:
cutoff = 5.
n_atom_basis = 128

pairwise_distance = spk.atomistic.PairwiseDistances() # calculates pairwise distances between atoms
radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=cutoff)
painn = spk.representation.PaiNN(
    n_atom_basis=n_atom_basis, n_interactions=3,
    radial_basis=radial_basis,
    cutoff_fn=spk.nn.CosineCutoff(cutoff)
)
#pred_U0 = spk.atomistic.Atomwise(n_in=n_atom_basis, output_key=QM9.U0)
pred_mu = spk.atomistic.DipoleMoment(
    n_in=n_atom_basis,
    dipole_key='dipole_moment',
    predict_magnitude=True,
    use_vector_representation=False
)

nnpot = spk.model.NeuralNetworkPotential(
    representation=painn,
    input_modules=[pairwise_distance],
    output_modules=[pred_mu],
    postprocessors=[trn.CastTo64()]
)

The last argument here is a list of postprocessors that will only be used if `nnpot.inference_mode=True` is set.
They are not applied in training or validation, but only for predictions.
Postprocessors are used to deal with numerical accuracy and normalization of model outputs.
When single-atom energies and the mean energy per atom are subtracted in the preprocessing (see above) to make training easier,
this does not matter for the loss, but the final prediction should return the real energies.
In that case, the offsets are removed *before* casting to float32 in the preprocessor, which avoids loss of numerical precision.
Analogously, the postprocessor first has to cast to float64 before re-adding the offsets.
Here, no offsets were removed, so the only postprocessor is the cast to float64.
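
The precision argument can be checked with a few lines of plain Python. The sketch below emulates a float32 cast with the `struct` module; the offset and residual values are hypothetical and only chosen to have very different magnitudes.

```python
import struct

def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

# Hypothetical numbers: a large offset plus a small residual
offset = -11000.0      # e.g. summed single-atom energies
residual = 0.123456    # the part the network has to learn
total = offset + residual

# Casting the full value to float32 loses digits of the residual
error_total = abs(to_float32(total) - total)

# Casting only the residual keeps it almost exact; the offset is re-added in float64
error_residual = abs((offset + to_float32(residual)) - total)

print(f"cast after adding offset:  {error_total:.2e}")
print(f"cast before adding offset: {error_residual:.2e}")
```

The first error is several orders of magnitude larger than the second, which is why offsets are handled in float64 on both sides of the model.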

The output modules store the prediction in a dictionary under an output key (here: `dipole_key='dipole_moment'`), which is connected to
a target property with loss functions and evaluation metrics using the `ModelOutput` class:

In [4]:
output_mu = spk.task.ModelOutput(
    name=QM9.mu,
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1.,
    metrics={
        "MAE": torchmetrics.MeanAbsoluteError()
    }
)

By default, the target is assumed to have the same name as the output. Otherwise, a different `target_name`
has to be provided.
Here, we already gave the output the same name as the target in the dataset (`QM9.mu`).
In case of multiple outputs, the full loss is a weighted sum of all output losses.
Therefore, it is possible to provide a `loss_weight`, which we simply set to 1 here.
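
As a small worked example of the weighted sum, the following plain-Python sketch combines two hypothetical outputs. The property names, values and weights are made up for illustration; only the formula, total loss = sum of `loss_weight` times the per-output loss, is taken from the text above.

```python
def mse(predictions, targets):
    """Mean squared error over paired values."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)

# Hypothetical outputs: (predictions, targets, loss_weight)
outputs = {
    "energy": ([1.0, 2.0], [1.5, 2.5], 0.25),
    "dipole_moment": ([0.25, 1.0], [0.5, 0.5], 1.0),
}

# Full loss: each output's loss scaled by its weight, then summed
total_loss = sum(w * mse(p, t) for p, t, w in outputs.values())
print(total_loss)  # 0.25 * 0.25 + 1.0 * 0.15625 = 0.21875
```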

All components defined above are then passed to `AtomisticTask`, which is a subclass of
[`LightningModule`](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html).
This connects the model and training process and can then be passed to the PyTorch Lightning `Trainer`.

In [5]:
task = spk.task.AtomisticTask(
    model=nnpot,
    outputs=[output_mu],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={"lr": 1e-4}
)



## Training the model

Now, the model is ready for training. Since we already defined all necessary components, the only thing left to do is
to pass the task to the PyTorch Lightning `Trainer` together with the data module.

Additionally, we can provide callbacks that take care of logging, checkpointing etc.

The `ModelCheckpoint` of SchNetPack is equivalent to that in PyTorch Lightning,
except that we also store the best inference model. We will show how to use this in the next section.
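
The idea of keeping only the best model can be sketched independently of PyTorch Lightning. The following plain-Python sketch assumes that the monitored quantity is a validation loss; `track_best` and the loss values are hypothetical and not part of either library.

```python
def track_best(val_losses):
    """Return (epoch, loss) of the lowest validation loss in the history."""
    best_epoch, best_loss = None, float("inf")
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # A real callback would write the model to disk at this point
            best_epoch, best_loss = epoch, loss
    return best_epoch, best_loss

# Hypothetical validation losses, one per epoch
history = [0.90, 0.42, 0.47, 0.35, 0.39]
best_epoch, best_loss = track_best(history)
print(best_epoch, best_loss)  # 3 0.35
```

Later epochs with a higher validation loss leave the stored model untouched, so the file on disk always corresponds to the best epoch seen so far.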

You can have a look at the training log using Tensorboard:
```
tensorboard --logdir=qm9tut/default
```



## Testing

Having trained a model for QM9, we are going to use it to obtain some predictions.
First, we need to load the model. The `Trainer` stores the best model in the model directory which can be loaded using PyTorch:

In [6]:
import torch
import numpy as np
from ase import Atoms

best_model = torch.load(os.path.join(qm9tut, 'best_inference_model'), map_location='cpu')

We can use the test dataloader from the QM9 data to obtain batches of molecules and apply the model:

In [7]:
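def infer(dataloader):
    """Collect reference and predicted dipole moments over a whole dataloader."""
    reals, preds = [], []
    for batch in dataloader:
        # read the target first, since the model may overwrite this key in the batch
        reals.append(batch[QM9.mu])
        # detach so that no autograd graph is kept for every batch
        preds.append(best_model(batch)["dipole_moment"].detach())
    # concatenate once instead of growing a tensor inside the loop
    return torch.cat(reals), torch.cat(preds)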

In [None]:
test_reals, test_preds = infer(qm9data.test_dataloader())
print(f"Test MAE: {torch.mean(torch.abs(test_reals-test_preds))}")
print(f"Test MSE: {torch.mean(torch.pow(test_reals-test_preds,2))}")
val_reals, val_preds = infer(qm9data.val_dataloader())
print(f"Val MAE: {torch.mean(torch.abs(val_reals-val_preds))}")
print(f"Val MSE: {torch.mean(torch.pow(val_reals-val_preds,2))}")
train_reals, train_preds = infer(qm9data.train_dataloader())
print(f"Train MAE: {torch.mean(torch.abs(train_reals-train_preds))}")
print(f"Train MSE: {torch.mean(torch.pow(train_reals-train_preds,2))}")
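
The two metrics printed above are easy to verify by hand on a toy example. The following plain-Python sketch uses hypothetical dipole magnitudes; it also prints the root of the MSE, which has the same unit as the MAE and is therefore easier to compare with it.

```python
import math

def mae(reals, preds):
    """Mean absolute error."""
    return sum(abs(r - p) for r, p in zip(reals, preds)) / len(reals)

def mse(reals, preds):
    """Mean squared error."""
    return sum((r - p) ** 2 for r, p in zip(reals, preds)) / len(reals)

# Hypothetical dipole magnitudes in Debye
reals = [1.0, 2.0, 0.5, 3.0]
preds = [1.5, 2.0, 0.0, 2.0]

print("MAE:", mae(reals, preds))               # (0.5 + 0 + 0.5 + 1.0) / 4 = 0.5
print("MSE:", mse(reals, preds))               # (0.25 + 0 + 0.25 + 1.0) / 4 = 0.375
print("RMSE:", math.sqrt(mse(reals, preds)))
```

Because the MSE squares each error, the single large deviation of 1.0 dominates it, while the MAE weights all deviations linearly.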