# SchNet

![SchNet](https://user-images.githubusercontent.com/11532812/60562621-3a935c00-9d93-11e9-8af3-59318e61172b.png)

SchNet is a deep neural network architecture for predicting molecular energies and other chemical properties. It has reached state-of-the-art accuracy on the QM9 benchmark:

https://paperswithcode.com/sota/formation-energy-on-qm9

This kernel shows how to train SchNet models using SchNetPack.

SchNet is built from the following components:

## Embedding
SchNet builds the initial representation of each local chemical environment *i* from an embedding that depends only on the type of the center atom:
![embedding](https://user-images.githubusercontent.com/11532812/60563147-8d6e1300-9d95-11e9-9734-3df9b5d457fa.png)
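The embedding is essentially a lookup table with one learnable vector per element; in the notation of the SchNet paper, x_i^0 = a_{Z_i}, where Z_i is the atomic number of atom *i*. The NumPy sketch below illustrates that lookup. The table size, the feature dimension, and the random initialization are assumptions chosen for illustration, not SchNetPack's settings.

```python
import numpy as np

# Illustrative sizes (assumptions): atomic numbers below 100, 64 features per atom.
max_z, n_features = 100, 64
rng = np.random.default_rng(0)

# Embedding table: row Z holds the learnable initial feature vector a_Z of element Z.
embedding = rng.normal(size=(max_z, n_features))

# Atomic numbers of a water molecule (O, H, H).
Z = np.array([8, 1, 1])

# Initial representation x_i^0 = a_{Z_i}: one row lookup per atom.
x0 = embedding[Z]
print(x0.shape)  # (3, 64)
```

Both hydrogen atoms receive identical initial features; they only become distinguishable after the interaction layers mix in information about their neighbors.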

## Atom-wise, Fully-connected Layers
![Atom-wise FC](https://user-images.githubusercontent.com/11532812/60563347-66fca780-9d96-11e9-8b4a-63debdb676ac.png)
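An atom-wise layer applies the same affine map to the feature vector of every atom, x_i^{l+1} = W^l x_i^l + b^l. A minimal NumPy sketch with made-up dimensions and random weights (assumptions of this example) shows that sharing the weights across atoms makes the layer independent of how the atoms are ordered:

```python
import numpy as np

rng = np.random.default_rng(0)
n_atoms, n_in, n_out = 5, 64, 64  # illustrative sizes (assumptions)

# Features of all atoms in one molecule, one row per atom.
x = rng.normal(size=(n_atoms, n_in))

# A single weight matrix and bias shared by every atom.
W = rng.normal(size=(n_out, n_in))
b = rng.normal(size=n_out)

# Atom-wise layer x_i^{l+1} = W x_i^l + b, applied to each row independently.
x_next = x @ W.T + b

# Shared weights: permuting the input atoms simply permutes the output rows.
perm = rng.permutation(n_atoms)
assert np.allclose(x[perm] @ W.T + b, x_next[perm])
print(x_next.shape)  # (5, 64)
```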

## Continuous-filter Convolutional Layers
![cont-filter-conv](https://user-images.githubusercontent.com/11532812/60563492-002bbe00-9d97-11e9-8784-9e7393a18285.png)
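A continuous-filter convolution updates atom *i* by summing the features of the other atoms, each multiplied elementwise by a filter that depends on the interatomic position: x_i^{l+1} = Σ_j x_j^l ∘ W^l(r_i - r_j). Because SchNet's filters depend only on the distance |r_i - r_j|, the result is invariant to rotations and translations of the molecule. In the sketch below, `toy_filter` is a hypothetical fixed function standing in for the learned filter-generating network, and skipping j = i is a choice of this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
n_atoms, n_features = 4, 8  # illustrative sizes (assumptions)

positions = rng.normal(size=(n_atoms, 3))   # r_i, arbitrary units
x = rng.normal(size=(n_atoms, n_features))  # x_j^l, atom features

# Hypothetical stand-in for the filter-generating network:
# maps a distance to one weight per feature channel.
widths = np.linspace(0.5, 2.0, n_features)

def toy_filter(d):
    return np.exp(-(d / widths) ** 2)

def cfconv(x, positions, filter_fn):
    """x_i^{l+1} = sum over j != i of x_j^l * W(|r_i - r_j|), elementwise."""
    out = np.zeros_like(x)
    for i in range(len(x)):
        for j in range(len(x)):
            if i == j:
                continue  # this sketch skips self-interaction
            d = np.linalg.norm(positions[i] - positions[j])
            out[i] += x[j] * filter_fn(d)
    return out

y = cfconv(x, positions, toy_filter)

# Rotating the molecule leaves all distances, and hence the output, unchanged.
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])
assert np.allclose(cfconv(x, positions @ R.T, toy_filter), y)
```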

## Filter-generating Networks
Distance features are created by expanding the pairwise interatomic distances in a set of Gaussian radial basis functions:
![distance-expansion](https://user-images.githubusercontent.com/11532812/60562913-7e3a9580-9d94-11e9-9067-ba40e70b2213.png)

The resulting feature vector is the input to a multilayer fully-connected neural network, which generates the convolution filter.
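Putting the two steps together: each distance d_ij is expanded as e_k(d_ij) = exp(-γ (d_ij - μ_k)^2) over a grid of centers μ_k, and the expanded vector passes through dense layers that produce one filter value per feature channel. The sketch below assumes centers every 0.1 Å between 0 and 5 Å, γ = 10 Å⁻², two dense layers with shifted-softplus activations ssp(x) = ln(0.5 e^x + 0.5), and random weights; these are illustrative choices, not the exact SchNetPack configuration.

```python
import numpy as np

def gaussian_expansion(d, mu, gamma):
    """e_k(d) = exp(-gamma * (d - mu_k)^2) for every center mu_k."""
    return np.exp(-gamma * (d[..., None] - mu) ** 2)

def shifted_softplus(x):
    """ssp(x) = ln(0.5 * e^x + 0.5), computed stably as softplus(x) - ln 2."""
    return np.logaddexp(0.0, x) - np.log(2.0)

# Illustrative hyperparameters (assumptions): 51 centers from 0 to 5 Angstrom.
mu = np.linspace(0.0, 5.0, 51)
gamma = 10.0
n_filters = 64

# Random weights stand in for trained parameters.
rng = np.random.default_rng(0)
W1 = rng.normal(scale=0.1, size=(len(mu), n_filters))
b1 = np.zeros(n_filters)
W2 = rng.normal(scale=0.1, size=(n_filters, n_filters))
b2 = np.zeros(n_filters)

def filter_network(d):
    """Map distances of shape (n,) to filters of shape (n, n_filters)."""
    e = gaussian_expansion(d, mu, gamma)
    h = shifted_softplus(e @ W1 + b1)
    return shifted_softplus(h @ W2 + b2)

d = np.array([0.96, 1.52, 2.40])  # example interatomic distances in Angstrom
filters = filter_network(d)
print(filters.shape)  # (3, 64)
```

Because each distance activates mainly the few Gaussians centered near it, the dense layers receive a smooth, localized encoding of the distance rather than a single raw number.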

## Output Layers
The atomic features {**x**_i} are fed to atom-wise multilayer perceptrons, which output the contribution of each atom to a chemical property. The property is obtained as the sum of these contributions.
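The sum over atom-wise contributions can be sketched as one shared two-layer MLP applied to every atom, followed by a sum. The layer sizes, the shifted-softplus activation, and the random weights are assumptions for illustration; the sum pooling is what makes the prediction independent of atom ordering and lets it grow with molecule size.

```python
import numpy as np

def shifted_softplus(x):
    return np.logaddexp(0.0, x) - np.log(2.0)

rng = np.random.default_rng(0)
n_atoms, n_features, n_hidden = 5, 64, 32  # illustrative sizes (assumptions)

x = rng.normal(size=(n_atoms, n_features))  # final atomic features x_i
W1 = rng.normal(scale=0.1, size=(n_features, n_hidden))
b1 = np.zeros(n_hidden)
W2 = rng.normal(scale=0.1, size=(n_hidden, 1))
b2 = np.zeros(1)

def atomwise_mlp(x):
    """Shared two-layer MLP returning one scalar contribution per atom."""
    return (shifted_softplus(x @ W1 + b1) @ W2 + b2)[:, 0]

contributions = atomwise_mlp(x)  # shape (n_atoms,)
prediction = contributions.sum()  # molecular property

# Sum pooling: reordering the atoms does not change the prediction.
perm = rng.permutation(n_atoms)
assert np.isclose(atomwise_mlp(x[perm]).sum(), prediction)
```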

### Potential Energy
![potential-energy](https://user-images.githubusercontent.com/11532812/60563771-20a84800-9d98-11e9-8720-bf8cff96157a.png)
Here E_i is the contribution of atom *i* to the potential energy E.

### Dipole Moment
![dipole-moment](https://user-images.githubusercontent.com/11532812/60563830-53ead700-9d98-11e9-9908-98c06b95b1cf.png)
Here q_i is the atomic charge that SchNet predicts for atom *i*, and r_i is the position of atom *i*.
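The dipole moment μ = Σ_i q_i r_i is a charge-weighted sum of atomic positions. The sketch below evaluates it for made-up, water-like charges and coordinates (illustrative values, not SchNet predictions). When the charges sum to zero, as here, the result does not depend on the choice of origin.

```python
import numpy as np

def dipole_moment(charges, positions):
    """mu = sum_i q_i * r_i for charges of shape (n,) and positions of shape (n, 3)."""
    return (charges[:, None] * positions).sum(axis=0)

# Illustrative charges (e) and positions (Angstrom) for O, H, H; not model output.
q = np.array([-0.8, 0.4, 0.4])
r = np.array([[0.0, 0.0, 0.0],
              [0.76, 0.59, 0.0],
              [-0.76, 0.59, 0.0]])

dipole = dipole_moment(q, r)  # approximately (0, 0.472, 0) in e*Angstrom
```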

## References
* K. T. Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko, K.-R. Müller.
SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.
Advances in Neural Information Processing Systems 30, pp. 992-1002 (2017)
[link](http://papers.nips.cc/paper/6700-schnet-a-continuous-filter-convolutional-neural-network-for-modeling-quantum-interactions)
[arXiv](https://arxiv.org/abs/1706.08566v5)
* K. T. Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko, K.-R. Müller.
SchNet - a deep learning architecture for molecules and materials.
The Journal of Chemical Physics 148(24), 241722 (2018)
[link](https://doi.org/10.1063/1.5019779)
[arXiv](https://arxiv.org/abs/1712.06113)
* K. T. Schütt, P. Kessel, M. Gastegger, K. A. Nicoli, A. Tkatchenko, K.-R. Müller.
SchNetPack: A Deep Learning Toolbox For Atomistic Systems.
Journal of Chemical Theory and Computation 15(1), 448-455 (2019)
[link](https://pubs.acs.org/doi/10.1021/acs.jctc.8b00908)
[arXiv](https://arxiv.org/pdf/1809.01072.pdf)
* K. A. Nicoli, P. Kessel, M. Gastegger, K. T. Schütt.
Analysis of Atomistic Representations Using Weighted Skip-Connections.
32nd Conference on Neural Information Processing Systems (NIPS 2018)
[arXiv](https://arxiv.org/abs/1810.09751)
* K. T. Schütt, A. Tkatchenko, K.-R. Müller.
Learning representations of molecules and materials with atomistic neural networks.
(2018)
[arXiv](https://arxiv.org/abs/1812.04690)

## Implementations
* PyTorch:
  https://github.com/atomistic-machine-learning/schnetpack
* TensorFlow:
  https://github.com/atomistic-machine-learning/SchNet


In [None]:
!pip install ase==3.17 schnetpack==0.2.1

We need `ASE 3.17` for `SchNetPack 0.2.1`.

The following code is adapted from

https://github.com/atomistic-machine-learning/schnetpack/blob/v0.2.1/src/examples/qm9_schnet.py

In [None]:
import pandas as pd

import torch
import torch.nn.functional as F
from torch.optim import Adam

import schnetpack as spk
import schnetpack.atomistic as atm
import schnetpack.representation as rep
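from schnetpack.datasets import QM9

# use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")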

# load qm9 dataset and download if necessary
data = QM9("qm9/", properties=[QM9.U0], remove_uncharacterized=True)

# summary statistics and histogram of U0 over the whole dataset
energies = [data[i][QM9.U0].item() for i in range(len(data))]
energies = pd.Series(energies, name=QM9.U0)
display(energies.describe())
ax = energies.hist(bins=50)
_ = ax.set_xlabel(QM9.U0)

In [None]:
# uncomment to delete the results of a previous run before retraining
#!rm -r output log

# split into training, validation and test sets; the test set receives the remaining n_val molecules
n_val = 10000
train, val, test = data.create_splits(len(data)-n_val*2, n_val)
loader = spk.data.AtomsLoader(train, batch_size=128, num_workers=2)
val_loader = spk.data.AtomsLoader(val, batch_size=256, num_workers=2)

# create model
reps = rep.SchNet(n_interactions=6)
output = atm.Atomwise()
model = atm.AtomisticModel(reps, output)
model = model.to(device)

# create trainer
max_epochs = 100
opt = Adam(model.parameters(), lr=2e-4, weight_decay=1e-6)
loss = lambda b, p: F.mse_loss(p["y"], b[QM9.U0])
metric_list = [
    spk.metrics.MeanAbsoluteError(QM9.U0, "y"),
    spk.metrics.RootMeanSquaredError(QM9.U0, "y"),
]
hooks = [
    spk.train.MaxEpochHook(max_epochs),
    spk.train.CSVHook('log', metric_list, every_n_epochs=1),
]
trainer = spk.train.Trainer("output/", model, loss,
                            opt, loader, val_loader, hooks=hooks)

# start training
trainer.train(device)

In [None]:
df = pd.read_csv('log/log.csv')
display(df.tail())
_ = df[['MAE_energy_U0','RMSE_energy_U0']].plot(ylim=(0,100))

In [None]:
# This function comes from the following script:
# https://github.com/atomistic-machine-learning/schnetpack/blob/v0.2.1/src/scripts/schnetpack_qm9.py
def evaluate_dataset(metrics, model, loader, device):
    for metric in metrics:
        metric.reset()

    with torch.no_grad():
        for batch in loader:
            batch = {
                k: v.to(device)
                for k, v in batch.items()
            }
            result = model(batch)

            for metric in metrics:
                metric.add_batch(batch, result)

    results = [
        metric.aggregate() for metric in metrics
    ]
    return results
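`evaluate_dataset` relies on metric objects that are reset, fed batch by batch through `add_batch`, and finally combined by `aggregate`. The hypothetical `RunningErrors` class below is not SchNetPack code; it only illustrates one way to accumulate MAE and RMSE so that errors are pooled over all samples instead of averaging per-batch means.

```python
import math

class RunningErrors:
    """Hypothetical accumulator illustrating the reset/add_batch/aggregate pattern."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.abs_sum = 0.0  # running sum of |error|
        self.sq_sum = 0.0   # running sum of error^2
        self.count = 0      # number of samples seen

    def add_batch(self, targets, predictions):
        for t, p in zip(targets, predictions):
            err = p - t
            self.abs_sum += abs(err)
            self.sq_sum += err * err
            self.count += 1

    def aggregate(self):
        mae = self.abs_sum / self.count
        rmse = math.sqrt(self.sq_sum / self.count)
        return mae, rmse

m = RunningErrors()
m.add_batch([1.0, 2.0], [1.5, 2.0])  # errors 0.5 and 0.0
m.add_batch([0.0], [-1.0])           # error -1.0
mae, rmse = m.aggregate()
```

With these two uneven batches the pooled MAE is 0.5, whereas averaging the per-batch MAEs (0.25 and 1.0) would give 0.625.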

In [None]:
# restore the parameters saved as the best model during training
model.load_state_dict(torch.load('output/best_model', map_location=device))
test_loader = spk.data.AtomsLoader(test, batch_size=256, num_workers=2)
model.eval()

df = pd.DataFrame()
df['metric'] = ['MAE', 'RMSE']
df['training'] = evaluate_dataset(metric_list, model, loader, device)
df['validation'] = evaluate_dataset(metric_list, model, val_loader, device)
df['test'] = evaluate_dataset(metric_list, model, test_loader, device)
df

In [None]:
!ls