In [1]:
import os
import schnetpack as spk
from schnetpack.atomistic import Atomwise
from schnetpack.representation import SchNet
import schnetpack.transform as trn

import matplotlib.pyplot as plt

import torch
import torchmetrics
import pytorch_lightning as pl

In [2]:
go_mol_data = spk.data.AtomsDataModule(
    './processed/coronene_schnet.db',
    batch_size=1,
    distance_unit='Ang',
    property_units={'Spectrum':'eV'},
    num_train=250,
    num_val=39,
    transforms=[
        trn.ASENeighborList(cutoff=5.),
        trn.RemoveOffsets('Spectrum', remove_mean=True, remove_atomrefs=False),
        trn.CastTo32()
    ],
    num_workers=1,
    pin_memory=True
)
go_mol_data.prepare_data()
go_mol_data.setup()

100%|██████████| 250/250 [00:00<00:00, 292.35it/s]


In [3]:
properties = go_mol_data.dataset[0]
print('Loaded properties:\n', *['{:s}\n'.format(i) for i in properties.keys()])

Loaded properties:
 _idx
 Spectrum
 _n_atoms
 _atomic_numbers
 _positions
 _cell
 _pbc



In [4]:
print('Spectrum:\n', properties['Spectrum'])
print('Shape:\n', properties['Spectrum'].shape)

Spectrum:
 tensor([1.4555e-14, 8.8380e-14, 5.0741e-13, 2.7545e-12, 1.4138e-11, 6.8616e-11,
        3.1489e-10, 1.3664e-09, 5.6073e-09, 2.1759e-08, 7.9854e-08, 2.7716e-07,
        9.0987e-07, 2.8253e-06, 8.2991e-06, 2.3063e-05, 6.0645e-05, 1.5091e-04,
        3.5542e-04, 7.9244e-04, 1.6729e-03, 3.3450e-03, 6.3363e-03, 1.1375e-02,
        1.9358e-02, 3.1245e-02, 4.7850e-02, 6.9561e-02, 9.6037e-02, 1.2599e-01,
        1.5713e-01, 1.8640e-01, 2.1045e-01, 2.2630e-01, 2.3199e-01, 2.2712e-01,
        2.1312e-01, 1.9311e-01, 1.7162e-01, 1.5402e-01, 1.4597e-01, 1.5280e-01,
        1.7896e-01, 2.2732e-01, 2.9850e-01, 3.9033e-01, 4.9752e-01, 6.1191e-01,
        7.2338e-01, 8.2145e-01, 8.9708e-01, 9.4430e-01, 9.6115e-01, 9.4954e-01,
        9.1423e-01, 8.6124e-01, 7.9631e-01, 7.2399e-01, 6.4750e-01, 5.6930e-01,
        4.9193e-01, 4.1858e-01, 3.5322e-01, 3.0002e-01, 2.6248e-01, 2.4242e-01,
        2.3935e-01, 2.5033e-01, 2.7050e-01, 2.9409e-01, 3.1572e-01, 3.3159e-01,
        3.4029e-01, 3.4297e-0

In [5]:
cutoff = 5.
n_atom_basis = 2

pairwise_distance = spk.atomistic.PairwiseDistances()
radial_basis = spk.nn.GaussianRBF(n_rbf=10, cutoff=cutoff)
schnet = SchNet(
    n_atom_basis=n_atom_basis, n_interactions=5,
    radial_basis=radial_basis, cutoff_fn=spk.nn.CosineCutoff(cutoff)
)

pred_spectrum = Atomwise(n_in=n_atom_basis, n_out=200, n_hidden=[64, 128, 128],
                         n_layers=3, output_key='Spectrum')

nnpot = spk.model.NeuralNetworkPotential(
    representation=schnet,
    input_modules=[pairwise_distance],
    output_modules=[pred_spectrum],
    postprocessors=[trn.CastTo64(), trn.AddOffsets('Spectrum', add_mean=True, add_atomrefs=False)]
)

In [6]:
output_spec = spk.task.ModelOutput(
    name = 'Spectrum',
    loss_fn=torch.nn.MSELoss(),
    loss_weight=0.01,
    metrics={'MAE': torchmetrics.MeanAbsoluteError()}
)

In [7]:
task = spk.task.AtomisticTask(
    model=nnpot,
    outputs=[output_spec],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={'lr': 1e-4}
)
print(task)

AtomisticTask(
  (model): NeuralNetworkPotential(
    (postprocessors): ModuleList(
      (0): CastTo64()
      (1): AddOffsets()
    )
    (representation): SchNet(
      (radial_basis): GaussianRBF()
      (cutoff_fn): CosineCutoff()
      (embedding): Embedding(100, 2, padding_idx=0)
      (interactions): ModuleList(
        (0-4): 5 x SchNetInteraction(
          (in2f): Dense(
            in_features=2, out_features=2, bias=False
            (activation): Identity()
          )
          (f2out): Sequential(
            (0): Dense(in_features=2, out_features=2, bias=True)
            (1): Dense(
              in_features=2, out_features=2, bias=True
              (activation): Identity()
            )
          )
          (filter_network): Sequential(
            (0): Dense(in_features=10, out_features=2, bias=True)
            (1): Dense(
              in_features=2, out_features=2, bias=True
              (activation): Identity()
            )
          )
        )
      )
    

/home/samjhall/anaconda3/envs/pyg-schnet/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [8]:
path = './schnet'
logger = pl.loggers.TensorBoardLogger(save_dir=path)
callbacks = [
    spk.train.ModelCheckpoint(
        model_path=os.path.join(path, 'best_inference_model'),
        save_top_k=1,
        monitor='val_loss'
    )
]

trainer = pl.Trainer(
    callbacks=callbacks,
    logger=logger,
    max_epochs=10
)

trainer.fit(task, datamodule=go_mol_data)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: ./schnet/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | NeuralNetworkPotential | 34.7 K
1 | outputs | ModuleList             | 0     
---------------------------------------------------
34.7 K    Trainable params
0         Non-trainable params
34.7 K    Total params
0.139     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/samjhall/anaconda3/envs/pyg-schnet/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)
/home/samjhall/anaconda3/envs/pyg-schnet/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


RuntimeError: Predictions and targets are expected to have the same shape, but got torch.Size([1, 200]) and torch.Size([200]).

In [None]:
best_model = torch.load('.schnet/best_inference_model', map_location='cpu')

In [None]:
for batch in go_mol_data.test_dataloader():
    result = best_model(batch)
    break

In [None]:
spec_pred = torch.flatten(result['Spectrum'])

plt.plot(spec_pred.detach().numpy())
plt.plot(go_mol_data.test_dataset[0]['Spectrum'])
plt.show()