In [5]:
# Install compatible versions
!pip install --no-cache-dir torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --extra-index-url https://download.pytorch.org/whl/cu118
!pip install --no-cache-dir sympy==1.12  # Required for schnetpack
!pip install --no-cache-dir schnetpack==2.1.1  # Works with sympy 1.12
!pip install --no-cache-dir pytorch-lightning==2.2.1  # Compatible with torch 2.4.1
!pip install --no-cache-dir numpy  # No known issues
!pip install --no-cache-dir ase==3.23

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118
Collecting torch==2.4.1
  Downloading https://download.pytorch.org/whl/cu118/torch-2.4.1%2Bcu118-cp311-cp311-linux_x86_64.whl (857.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m857.6/857.6 MB[0m [31m245.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.19.1
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.19.1%2Bcu118-cp311-cp311-linux_x86_64.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m290.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.4.1
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.4.1%2Bcu118-cp311-cp311-linux_x86_64.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m206.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch==2.4.1)
  Downloading nvidia_cuda_nvrtc_cu11-11.8.89-p

Collecting pytorch-lightning==2.2.1
  Downloading pytorch_lightning-2.2.1-py3-none-any.whl.metadata (21 kB)
Downloading pytorch_lightning-2.2.1-py3-none-any.whl (801 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.6/801.6 kB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch-lightning
  Attempting uninstall: pytorch-lightning
    Found existing installation: pytorch-lightning 2.5.0.post0
    Uninstalling pytorch-lightning-2.5.0.post0:
      Successfully uninstalled pytorch-lightning-2.5.0.post0
Successfully installed pytorch-lightning-2.2.1
Collecting ase==3.23
  Downloading ase-3.23.0-py3-none-any.whl.metadata (3.8 kB)
Downloading ase-3.23.0-py3-none-any.whl (2.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m56.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ase
  Attempting uninstall: ase
    Found existing installation: ase 3.24.0
    Uninstalling ase-3.24.0:
    

In [1]:
import os
import ase
import schnetpack as spk
from schnetpack.datasets import QM9
from schnetpack.transform import ASENeighborList
import schnetpack.transform as trn

import torch
import torchmetrics
import pytorch_lightning as pl
from schnetpack.data import ASEAtomsData
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
qm9tut = './qm9tut'
if not os.path.exists('qm9tut'):
  os.makedirs(qm9tut)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device is using {device}")

Device is using cuda


In [4]:
!rm -rf qm9.db split_qm9.npz
qm9data = QM9(
    './qm9.db',
    batch_size =10,
    num_train=110000,
    num_val=10000,
    split_file='./split_qm9.npz',
    transforms=[ASENeighborList(cutoff=5.)]
)
qm9data.prepare_data()
qm9data.setup()

100%|██████████| 133885/133885 [02:15<00:00, 991.55it/s]


In [5]:
print('Number of reference calculations:', len(qm9data.dataset))
print('Number of train data:', len(qm9data.train_dataset))
print('Number of test data:', len(qm9data.test_dataset))
print('Available properties:')

for p in qm9data.dataset.available_properties:
  print('-', p)

Number of reference calculations: 133885
Number of train data: 110000
Number of test data: 13885
Available properties:
- rotational_constant_A
- rotational_constant_B
- rotational_constant_C
- dipole_moment
- isotropic_polarizability
- homo
- lumo
- gap
- electronic_spatial_extent
- zpve
- energy_U0
- energy_U
- enthalpy_H
- free_energy
- heat_capacity


In [7]:
example = qm9data.dataset[0]
print('Properties:')

for k, v in example.items():
  print('-', k, ':', v.shape)

Properties:
- _idx : torch.Size([1])
- rotational_constant_A : torch.Size([1])
- rotational_constant_B : torch.Size([1])
- rotational_constant_C : torch.Size([1])
- dipole_moment : torch.Size([1])
- isotropic_polarizability : torch.Size([1])
- homo : torch.Size([1])
- lumo : torch.Size([1])
- gap : torch.Size([1])
- electronic_spatial_extent : torch.Size([1])
- zpve : torch.Size([1])
- energy_U0 : torch.Size([1])
- energy_U : torch.Size([1])
- enthalpy_H : torch.Size([1])
- free_energy : torch.Size([1])
- heat_capacity : torch.Size([1])
- _n_atoms : torch.Size([1])
- _atomic_numbers : torch.Size([5])
- _positions : torch.Size([5, 3])
- _cell : torch.Size([1, 3, 3])
- _pbc : torch.Size([3])


In [8]:
for batch in qm9data.val_dataloader():
  print(batch.keys())
  print('Isotropic Polarizability:', batch['isotropic_polarizability'])
  break

dict_keys(['_idx', 'rotational_constant_A', 'rotational_constant_B', 'rotational_constant_C', 'dipole_moment', 'isotropic_polarizability', 'homo', 'lumo', 'gap', 'electronic_spatial_extent', 'zpve', 'energy_U0', 'energy_U', 'enthalpy_H', 'free_energy', 'heat_capacity', '_n_atoms', '_atomic_numbers', '_positions', '_cell', '_pbc', '_idx_i_local', '_idx_j_local', '_offsets', '_idx_m', '_idx_j', '_idx_i'])
Isotropic Polarizability: tensor([77.3200, 77.0100, 76.3100, 65.3300, 81.4100, 71.1500, 81.7100, 82.4300,
        82.1700, 68.3200], dtype=torch.float64)


In [9]:
print('system index:', batch['_idx_m'])
print('Center atom index:', batch['_idx_i'])
print('Neighbor atom index:', batch['_idx_j'])

system index: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9])
Center atom index: tensor([  0,   0,   0,  ..., 173, 173, 173])
Neighbor atom index: tensor([  1,   2,   4,  ..., 164, 171, 172])


In [10]:
import ase.units

ase.units.a0 = ase.units.Bohr

qm9data = QM9(
    './qm9.db',
    batch_size=100,
    num_train=1000,
    num_val=1000,
    transforms=[
        trn.ASENeighborList(cutoff=5.),
        trn.RemoveOffsets(QM9.alpha, remove_mean=True, remove_atomrefs=False),
        trn.CastTo32()
    ],
    property_units={QM9.alpha: 'Bohr'},
    num_workers=4,
    split_file=os.path.join(qm9tut, 'split.npz'),
    pin_memory=True,
    load_properties=[QM9.alpha],
)
qm9data.prepare_data()
qm9data.setup()

100%|██████████| 10/10 [00:01<00:00,  6.32it/s]


In [11]:
means, stddevs = qm9data.get_stats(
    QM9.alpha, divide_by_atoms=True, remove_atomref=False
)
print('Mean atomization energy / atoms:', means.item())
print('Std. dev. atomization energy / atom:', stddevs.item())

Mean atomization energy / atoms: 1.184216322484864
Std. dev. atomization energy / atom: 0.13487847418420648


# Setting up the model

In [12]:
cutoff = 5.
n_atom_basis = 40

pairwise_distance = spk.atomistic.PairwiseDistances().to(device)  # calculates pairwise distances between atoms
radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=cutoff).to(device)
schnet = spk.representation.SchNet(
    n_atom_basis=n_atom_basis, n_interactions=6,
    radial_basis=radial_basis,
    cutoff_fn=spk.nn.CosineCutoff(cutoff)
).to(device)
pred_alpha = spk.atomistic.Atomwise(n_in=n_atom_basis, output_key=QM9.alpha).to(device)

nnpot = spk.model.NeuralNetworkPotential(
    representation=schnet,
    input_modules=[pairwise_distance],
    output_modules=[pred_alpha],
    postprocessors=[trn.CastTo64(), trn.AddOffsets(QM9.alpha, add_mean=True, add_atomrefs=False)]
).to(device)

In [13]:
output_alpha = spk.task.ModelOutput(
    name=QM9.alpha,
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1.,
    metrics={
        "MAE" : torchmetrics.MeanAbsoluteError(),
        "RMSE": torchmetrics.MeanSquaredError(squared=False)
    }
)

In [14]:
task = spk.task.AtomisticTask(
    model=nnpot.to(device),
    outputs=[output_alpha],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={"lr": 1e-4}
)

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


## Training the model

The model is now ready for training. Since we already defined all necessary components, the only thing left to do is passing it to the pytorch Lightning Trainer together with the data module.
Additionally, we can provide callbacks that take care of logging, checkpointing etc.

In [15]:
logger = pl.loggers.TensorBoardLogger(save_dir=qm9tut, name="logs")
callbacks = [
    spk.train.ModelCheckpoint(
        model_path=os.path.join(qm9tut, "best_inference_model"),
        save_top_k=1,
        monitor="val_loss"
    )
]

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    callbacks=callbacks,
    logger=False,
    default_root_dir=qm9tut,
    max_epochs=20,
)
trainer.fit(task, datamodule=qm9data)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | NeuralNetworkPotential | 49.0 K
1 | outputs | ModuleList             | 0     
---------------------------------------------------
49.0 K    Trainable params
0         Non-trainable params
49.0 K    Total params
0.196     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


# Inference

Having trained a model for QM9, we are going to use it to obtain some predictions. First, we need to load the model. The Trainer stores the best model in the model directory which can be loaded using PyTorch.

In [16]:
import torch
import numpy as np
from ase import Atoms

best_model = torch.load(os.path.join(qm9tut, 'best_inference_model'), map_location=device)
best_model.to(device)

  best_model = torch.load(os.path.join(qm9tut, 'best_inference_model'), map_location=device)


NeuralNetworkPotential(
  (postprocessors): ModuleList(
    (0): CastTo64()
    (1): AddOffsets()
  )
  (representation): SchNet(
    (radial_basis): GaussianRBF()
    (cutoff_fn): CosineCutoff()
    (embedding): Embedding(100, 40)
    (electronic_embeddings): ModuleList()
    (interactions): ModuleList(
      (0-5): 6 x SchNetInteraction(
        (in2f): Dense(
          in_features=40, out_features=40, bias=False
          (activation): Identity()
        )
        (f2out): Sequential(
          (0): Dense(in_features=40, out_features=40, bias=True)
          (1): Dense(
            in_features=40, out_features=40, bias=True
            (activation): Identity()
          )
        )
        (filter_network): Sequential(
          (0): Dense(in_features=20, out_features=40, bias=True)
          (1): Dense(
            in_features=40, out_features=40, bias=True
            (activation): Identity()
          )
        )
      )
    )
  )
  (input_modules): ModuleList(
    (0): PairwiseD

In [17]:
for batch in qm9data.test_dataloader():
  batch = {key: value.to(device) for key, value in batch.items()}

  with torch.no_grad():
    result = best_model(batch)
  print("Result dictionary:", result)
  break

Result dictionary: {'isotropic_polarizability': tensor([24.9568, 23.9596, 26.9670, 21.3839, 20.7384, 20.3058, 21.9999, 19.7879,
        22.8275, 19.2517, 17.8356, 19.3702, 20.3088, 23.6591, 21.9268, 18.1710,
        21.2162, 16.9730, 23.3870, 20.8245, 22.9054, 12.9800, 22.1471, 22.3530,
        22.9300, 17.6501, 18.7496, 23.2286, 20.8358, 19.7797, 16.6239, 20.2062,
        14.4100, 16.8241, 20.9443, 19.7146, 20.4793, 20.9321, 25.2529, 20.7754,
        21.1689, 22.1162, 21.2642, 23.4901, 23.8527, 23.7891, 22.6324, 20.6130,
        24.4188, 19.1816, 21.3923, 18.7406, 11.7230, 22.1732, 17.3631, 21.5483,
        18.0201, 20.7089, 23.7005, 24.5080, 22.4149, 23.7042, 16.1315, 23.1734,
        20.4832, 22.3278, 18.2101, 23.5817, 21.4780, 19.0557, 22.3942, 22.3745,
        22.1325, 16.1626, 15.9431, 17.4626, 17.2725, 19.9944, 22.7604, 21.1472,
        19.4093, 18.9336, 21.5685, 24.7496, 18.1675, 18.8231, 24.3453, 18.2538,
        18.8290, 21.8607, 20.8447, 27.5245, 21.7142, 24.6377, 18.5049, 2

In [18]:
import torch
from schnetpack.interfaces import AtomsConverter
import schnetpack.transform as trn

def predict_polarizability(atoms_obj, model, device):
  """
  predicts the polarizability of a given molecular structure.

  Args:
      atoms_obj: ASE Atoms object representing a molecule.
      model: Trained SchNet model.
      device: The device where the model runs.

  Return:
      float: Predicted polarizability value.
  """
  # Ensure model is in evaluation mode
  model.eval()

  # Convert ASE Atoms  object to SchNetPack input format
  converter = AtomsConverter(
      neighbor_list=trn.ASENeighborList(cutoff=5.), dtype=torch.float32, device=device
  )
  inputs = converter(atoms_obj)

  #Run inference
  with torch.no_grad():
    result = model(inputs)

    # Debugging: Print available keys if an error occurs
    print("Model output keys:", result.keys())

    # Extract the predicted polarizability value
    predicted_polarizability = result[QM9.alpha].item()

    return predicted_polarizability

In [21]:
print("Available keys in the first test sample:", qm9data.test_dataset[0].keys())

Available keys in the first test sample: dict_keys(['_idx', 'isotropic_polarizability', '_n_atoms', '_atomic_numbers', '_positions', '_cell', '_pbc', '_idx_i', '_idx_j', '_offsets'])


In [22]:
from ase import Atoms

# Pick a molecule from the dataset
sample_data = qm9data.test_dataset[0]
atomic_numbers = sample_data["_atomic_numbers"].numpy()  # Convert tensors to NumPy
positions = sample_data["_positions"].numpy()   # Convert tensors to NumPy

# Convert an ASE Atoms opbject
sample_molecule = Atoms(numbers=atomic_numbers, positions=positions)

# Call the function to predict polarizability
predicted_value = predict_polarizability(sample_molecule, best_model, device)

# Print the predicted value and the actual value
actual_polarizability = sample_data[QM9.alpha].item()
print(f"Predicted polarizability: {predicted_value}")
print(f"Actual Polarizability: {actual_polarizability}")

Model output keys: dict_keys(['isotropic_polarizability'])
Predicted polarizability: 24.95676283205382
Actual Polarizability: -1.2397361993789673


In [None]:
converter = spk.interfaces.AtomsConverter(neighbor_list=trn.ASENeighborList(cutoff=5.), dtype=torch.float32)

In [None]:
numbers = np.array([6, 1, 1, 1, 1])
positions = np.array([[-0.0126981359, 1.0858041578, 0.0080009958],
                      [0.002150416, -0.0060313176, 0.0019761204],
                      [1.0117308433, 1.4637511618, 0.0002765748],
                      [-0.540815069, 1.4475266138, -0.8766437152],
                      [-0.5238136345, 1.4379326443, 0.9063972942]])
atoms = Atoms(numbers=numbers, positions=positions)

In [None]:
inputs = converter(atoms)

print('Keys:', list(inputs.keys()))

pred = best_model(inputs)

print('Prediction:', pred[QM9.alpha])

Keys: ['_n_atoms', '_atomic_numbers', '_positions', '_cell', '_pbc', '_idx', '_idx_i_local', '_idx_j_local', '_offsets', '_idx_m', '_idx_j', '_idx_i']
Prediction: tensor([6.4133], dtype=torch.float64, grad_fn=<AddBackward0>)


In [None]:
calculator = spk.interfaces.SpkCalculator(
    model_file=os.path.join(qm9tut, "best_inference_model"),  # Path to model
    neighbor_list=trn.ASENeighborList(cutoff=5.),
    polar_key=QM9.alpha,  # Name of polarizability property in model
    energy_unit='Bohr',
    device='cpu'
)
atoms.set_calculator(calculator)
print('Prediction:', )

  model = torch.load(model_path, map_location=device, **kwargs)
  atoms.set_calculator(calculator)


AtomsConverterError: 'energy' is not a property of your model. Please check the model properties!