# Installing the necessary libraries

In [1]:
# Install compatible versions
!pip install --no-cache-dir torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --extra-index-url https://download.pytorch.org/whl/cu118
!pip install --no-cache-dir sympy==1.12  # Required for schnetpack
!pip install --no-cache-dir schnetpack==2.1.1  # Works with sympy 1.12
!pip install --no-cache-dir pytorch-lightning==2.2.1  # Compatible with torch 2.4.1
!pip install --no-cache-dir numpy  # No known issues
!pip install --no-cache-dir ase==3.23

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118
Collecting torch==2.4.1
  Downloading https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp311-cp311-linux_x86_64.whl (857.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m857.6/857.6 MB[0m [31m154.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.19.1
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.19.1%2Bcu118-cp311-cp311-linux_x86_64.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m98.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.4.1
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.4.1%2Bcu118-cp311-cp311-linux_x86_64.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m185.4 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch==2.4.1)
  Downloading nvidia_cuda_nvrtc_cu11-11.8.89-py

Collecting pytorch-lightning==2.2.1
  Downloading pytorch_lightning-2.2.1-py3-none-any.whl.metadata (21 kB)
Downloading pytorch_lightning-2.2.1-py3-none-any.whl (801 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch-lightning
  Attempting uninstall: pytorch-lightning
    Found existing installation: pytorch-lightning 2.5.0.post0
    Uninstalling pytorch-lightning-2.5.0.post0:
      Successfully uninstalled pytorch-lightning-2.5.0.post0
Successfully installed pytorch-lightning-2.2.1
Collecting ase==3.23
  Downloading ase-3.23.0-py3-none-any.whl.metadata (3.8 kB)
Downloading ase-3.23.0-py3-none-any.whl (2.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ase
  Attempting uninstall: ase
    Found existing installation: ase 3.24.0
    Uninstalling ase-3.24.0:
    

In [15]:
import os
import ase
import schnetpack as spk
from schnetpack.datasets import QM9
from schnetpack.transform import ASENeighborList
import schnetpack.transform as trn

import torch
import torchmetrics
import pytorch_lightning as pl
from schnetpack.data import ASEAtomsData
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [33]:
qm9tut = './qm9tut'
if not os.path.exists('qm9tut'):
  os.makedirs(qm9tut)

In [34]:
!rm -rf qm9.db split_qm9.npz
qm9data = QM9(
    './qm9.db',
    batch_size =10,
    num_train=110000,
    num_val=10000,
    split_file='./split_qm9.npz',
    transforms=[ASENeighborList(cutoff=5.)]
)
qm9data.prepare_data()
qm9data.setup()

100%|██████████| 133885/133885 [03:25<00:00, 650.52it/s]


In [35]:
print('Number of reference calculations:', len(qm9data.dataset))
print('Number of train data:', len(qm9data.train_dataset))
print('Number of test data:', len(qm9data.test_dataset))
print('Available properties:')

for p in qm9data.dataset.available_properties:
  print('-', p)

Number of reference calculations: 133885
Number of train data: 110000
Number of test data: 13885
Available properties:
- rotational_constant_A
- rotational_constant_B
- rotational_constant_C
- dipole_moment
- isotropic_polarizability
- homo
- lumo
- gap
- electronic_spatial_extent
- zpve
- energy_U0
- energy_U
- enthalpy_H
- free_energy
- heat_capacity


In [36]:
example = qm9data.dataset[0]
print('Properties:')

for k, v in example.items():
  print('-', k, ':', v.shape)

Properties:
- _idx : torch.Size([1])
- rotational_constant_A : torch.Size([1])
- rotational_constant_B : torch.Size([1])
- rotational_constant_C : torch.Size([1])
- dipole_moment : torch.Size([1])
- isotropic_polarizability : torch.Size([1])
- homo : torch.Size([1])
- lumo : torch.Size([1])
- gap : torch.Size([1])
- electronic_spatial_extent : torch.Size([1])
- zpve : torch.Size([1])
- energy_U0 : torch.Size([1])
- energy_U : torch.Size([1])
- enthalpy_H : torch.Size([1])
- free_energy : torch.Size([1])
- heat_capacity : torch.Size([1])
- _n_atoms : torch.Size([1])
- _atomic_numbers : torch.Size([5])
- _positions : torch.Size([5, 3])
- _cell : torch.Size([1, 3, 3])
- _pbc : torch.Size([3])


In [37]:
for batch in qm9data.val_dataloader():
  print(batch.keys())
  print('Isotropic Polarizability:', batch['isotropic_polarizability'])
  break

dict_keys(['_idx', 'rotational_constant_A', 'rotational_constant_B', 'rotational_constant_C', 'dipole_moment', 'isotropic_polarizability', 'homo', 'lumo', 'gap', 'electronic_spatial_extent', 'zpve', 'energy_U0', 'energy_U', 'enthalpy_H', 'free_energy', 'heat_capacity', '_n_atoms', '_atomic_numbers', '_positions', '_cell', '_pbc', '_idx_i_local', '_idx_j_local', '_offsets', '_idx_m', '_idx_j', '_idx_i'])
Isotropic Polarizability: tensor([66.4100, 74.6600, 70.2400, 67.2600, 84.0100, 56.5300, 83.5500, 71.0000,
        78.7400, 55.7300], dtype=torch.float64)


In [38]:
print('system index:', batch['_idx_m'])
print('Center atom index:', batch['_idx_i'])
print('Neighbor atom index:', batch['_idx_j'])

system index: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9])
Center atom index: tensor([  0,   0,   0,  ..., 174, 174, 174])
Neighbor atom index: tensor([ 16,  15,  14,  ..., 171, 172, 173])


In [39]:
import ase.units

ase.units.a0 = ase.units.Bohr

qm9data = QM9(
    './qm9.db',
    batch_size=100,
    num_train=1000,
    num_val=1000,
    transforms=[
        trn.ASENeighborList(cutoff=5.),
        trn.RemoveOffsets(QM9.alpha, remove_mean=True, remove_atomrefs=False),
        trn.CastTo32()
    ],
    property_units={QM9.alpha: 'Bohr'},
    num_workers=2,
    split_file=os.path.join(qm9tut, 'split.npz'),
    pin_memory=True,
    load_properties=[QM9.alpha],
)
qm9data.prepare_data()
qm9data.setup()

100%|██████████| 10/10 [00:02<00:00,  4.28it/s]


In [40]:
means, stddevs = qm9data.get_stats(
    QM9.alpha, divide_by_atoms=True, remove_atomref=False
)
print('Mean atomization energy / atoms:', means.item())
print('Std. dev. atomization energy / atom:', stddevs.item())

Mean atomization energy / atoms: 1.1944486331826614
Std. dev. atomization energy / atom: 0.1530654943424375


# Setting up the model

In [41]:
cutoff = 5.
n_atom_basis = 40

pairwise_distance = spk.atomistic.PairwiseDistances()  # calculates pairwise distances between atoms
radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=cutoff)
schnet = spk.representation.SchNet(
    n_atom_basis=n_atom_basis, n_interactions=3,
    radial_basis=radial_basis,
    cutoff_fn=spk.nn.CosineCutoff(cutoff)
)
pred_alpha = spk.atomistic.Atomwise(n_in=n_atom_basis, output_key=QM9.alpha)

nnpot = spk.model.NeuralNetworkPotential(
    representation=schnet,
    input_modules=[pairwise_distance],
    output_modules=[pred_alpha],
    postprocessors=[trn.CastTo64(), trn.AddOffsets(QM9.alpha, add_mean=True, add_atomrefs=False)]
)

In [42]:
output_alpha = spk.task.ModelOutput(
    name=QM9.alpha,
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1.,
    metrics={
        "MAE" : torchmetrics.MeanAbsoluteError()
    }
)

In [43]:
task = spk.task.AtomisticTask(
    model=nnpot,
    outputs=[output_alpha],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={"lr": 1e-4}
)

## Training the model

The model is now ready for training. Since we already defined all necessary components, the only thing left to do is passing it to the pytorch Lightning Trainer together with the data module.
Additionally, we can provide callbacks that take care of logging, checkpointing etc.

In [44]:
logger = pl.loggers.TensorBoardLogger(save_dir=qm9tut, name="logs")
callbacks = [
    spk.train.ModelCheckpoint(
        model_path=os.path.join(qm9tut, "best_inference_model"),
        save_top_k=1,
        monitor="val_loss"
    )
]

trainer = pl.Trainer(
    callbacks=callbacks,
    logger=False,
    default_root_dir=qm9tut,
    max_epochs=5,
)
trainer.fit(task, datamodule=qm9data)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | NeuralNetworkPotential | 26.9 K
1 | outputs | ModuleList             | 0     
---------------------------------------------------
26.9 K    Trainable params
0         Non-trainable params
26.9 K    Total params
0.108     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


# Inference

Having trained a model for QM9, we are going to use it to obtain some predictions. First, we need to load the model. The Trainer stores the best model in the model directory which can be loaded using PyTorch.

In [45]:
import torch
import numpy as np
from ase import Atoms

best_model = torch.load(os.path.join(qm9tut, 'best_inference_model'), map_location="cpu")

  best_model = torch.load(os.path.join(qm9tut, 'best_inference_model'), map_location="cpu")


In [46]:
for batch in qm9data.test_dataloader():
  result = best_model(batch)
  print("Result dictionary;", result)
  break

Result dictionary; {'isotropic_polarizability': tensor([20.9852, 19.7717, 18.8356, 16.3683, 20.8592, 19.2635, 10.5133, 19.8983,
        18.6788, 23.7478, 21.5835, 23.5999, 15.8748, 20.6478, 22.0007, 21.0209,
        18.1073, 22.4108, 24.8168, 20.8264, 25.0146, 20.4892, 18.5690, 23.2242,
        22.2541, 12.0947, 24.5473, 18.2514, 23.9130, 22.9744, 18.2944, 22.7574,
        24.9095, 21.8899, 24.5899, 20.4416, 19.9822, 25.7000, 20.7515, 20.3550,
        24.2720, 23.1929, 22.8775, 23.0014, 22.5407, 21.7377, 17.8098, 18.4268,
        19.8969, 21.6913, 10.7409, 18.2805, 22.9792,  9.5826, 18.5075, 18.8369,
        21.4769, 22.4666, 18.8475, 20.0384, 22.3571, 22.9299, 17.6007, 23.6132,
        16.7335, 21.8160, 13.4048, 20.8975, 21.4407, 18.1268, 12.8401, 21.1969,
        24.5541, 20.6162, 23.3412, 15.7490, 19.4706, 22.4822, 25.8929, 17.0330,
        21.4406, 19.1577, 18.5924, 18.0052, 19.4345, 19.2161, 23.8298, 18.9110,
        23.0666, 23.5798, 23.7668, 16.8106, 20.9924, 23.3610, 24.5090, 2

In [47]:
converter = spk.interfaces.AtomsConverter(neighbor_list=trn.ASENeighborList(cutoff=5.), dtype=torch.float32)

In [48]:
numbers = np.array([6, 1, 1, 1, 1])
positions = np.array([[-0.0126981359, 1.0858041578, 0.0080009958],
                      [0.002150416, -0.0060313176, 0.0019761204],
                      [1.0117308433, 1.4637511618, 0.0002765748],
                      [-0.540815069, 1.4475266138, -0.8766437152],
                      [-0.5238136345, 1.4379326443, 0.9063972942]])
atoms = Atoms(numbers=numbers, positions=positions)

In [49]:
inputs = converter(atoms)

print('Keys:', list(inputs.keys()))

pred = best_model(inputs)

print('Prediction:', pred[QM9.alpha])

Keys: ['_n_atoms', '_atomic_numbers', '_positions', '_cell', '_pbc', '_idx', '_idx_i_local', '_idx_j_local', '_offsets', '_idx_m', '_idx_j', '_idx_i']
Prediction: tensor([6.4133], dtype=torch.float64, grad_fn=<AddBackward0>)


In [50]:
calculator = spk.interfaces.SpkCalculator(
    model_file=os.path.join(qm9tut, "best_inference_model"),  # Path to model
    neighbor_list=trn.ASENeighborList(cutoff=5.),
    polar_key=QM9.alpha,  # Name of polarizability property in model
    energy_unit='Bohr',
    device='cpu'
)
atoms.set_calculator(calculator)
print('Prediction:', )

  model = torch.load(model_path, map_location=device, **kwargs)
  atoms.set_calculator(calculator)


AtomsConverterError: 'energy' is not a property of your model. Please check the model properties!