Skip to content

Commit

Permalink
Merge branch 'main' into extra_fields_NNPs
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioMirarchi committed Apr 5, 2024
2 parents 9e546dd + 74702da commit b0d08e1
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 84 deletions.
83 changes: 52 additions & 31 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,38 +83,9 @@ This is a minimal example of a custom training loop:
optimizer.step()
Loading a model for inference
-----------------------------

Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.

.. code:: python
import torch
from torchmdnet.models.model import load_model
checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
model = load_model(checkpoint, derivative=True)
# An arbitrary set of inputs for the model
n_atoms = 10
zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
z = zs[torch.randint(0, len(zs), (n_atoms,))]
pos = torch.randn(len(z), 3)
batch = torch.zeros(len(z), dtype=torch.long)
y, neg_dy = model(z, pos, batch)
.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.

.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.

.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.


.. _delta-learning:
Training on relative energies
-----------------------------
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

It might be useful to train the model on relative energies but then make the model produce total energies when running inference.
TorchMD-Net supports delta training via the :code:`remove_ref_energy` option. Passing this option when training (either via the :ref:`configuration-file` or using the :ref:`torchmd-train` command line interface) will subtract the reference energy from each atom in a sample before passing it to the model.
Expand All @@ -126,7 +97,7 @@ If :code:`remove_ref_energy` is turned on, the reference energy is stored in the
.. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`.

Example
~~~~~~~
********

First we train a model with the :code:`remove_ref_energy` option turned on:

Expand All @@ -151,6 +122,56 @@ Then we load the model for inference:
batch = torch.zeros(len(z), dtype=torch.long)
y, neg_dy = model(z, pos, batch)
Loading a model for inference
-----------------------------

Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example.

.. code:: python
import torch
from torchmdnet.models.model import load_model
checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt"
model = load_model(checkpoint, derivative=True)
# An arbitrary set of inputs for the model
n_atoms = 10
zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long)
z = zs[torch.randint(0, len(zs), (n_atoms,))]
pos = torch.randn(len(z), 3)
batch = torch.zeros(len(z), dtype=torch.long)
y, neg_dy = model(z, pos, batch)
.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference.

.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case.

.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case.


Model Ensembles
---------------
It is possible to create an ensemble of models by loading multiple checkpoints and averaging their predictions. The following example shows how to do this:

.. code:: python
import torch
from torchmdnet.models.model import load_model
checkpoints = ["/path/to/checkpoint/my_checkpoint1.ckpt", "/path/to/checkpoint/my_checkpoint2.ckpt"]
model_ensemble = load_model(checkpoints, return_std=True)
y_ensemble, neg_dy_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)
.. note:: :py:mod:`torchmdnet.models.model.load_model` will return an instance of :py:mod:`torchmdnet.models.model.Ensemble` if a list of checkpoints is passed. The :code:`return_std` option can be used to return the standard deviation of the predictions.



.. autoclass:: torchmdnet.models.model.Ensemble
:noindex:





Expand Down
110 changes: 80 additions & 30 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import lightning as pl
from torchmdnet import models
from torchmdnet.models.model import create_model
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models import output_modules
from torchmdnet.models.utils import dtype_mapping

Expand Down Expand Up @@ -38,7 +38,9 @@ def test_forward(model_name, use_batch, use_extra_args, precision, additional_la
@mark.parametrize("precision", [32, 64])
def test_forward_output_modules(model_name, output_model, precision):
z, pos, batch = create_example_batch()
args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
args = load_example_args(
model_name, remove_prior=True, output_model=output_model, precision=precision
)
model = create_model(args)
model(z, pos, batch=batch)

Expand All @@ -63,18 +65,25 @@ def test_torchscript(model_name, device):
grad_outputs=grad_outputs,
)[0]


def test_torchscript_output_modification():
model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True))
model = create_model(
load_example_args("tensornet", remove_prior=True, derivative=True)
)

class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.model = model

def forward(self, z, pos, batch):
y, neg_dy = self.model(z, pos, batch=batch)
# A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor]
return y, 2*neg_dy
return y, 2 * neg_dy

torch.jit.script(MyModel())


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize("device", ["cpu", "cuda"])
def test_torchscript_dynamic_shapes(model_name, device):
Expand All @@ -86,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device):
model = torch.jit.script(
create_model(load_example_args(model_name, remove_prior=True, derivative=True))
).to(device=device)
#Repeat the input to make it dynamic
# Repeat the input to make it dynamic
for rep in range(0, 5):
print(rep)
zi = z.repeat_interleave(rep+1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep+1, dim=0).to(device=device)
zi = z.repeat_interleave(rep + 1, dim=0).to(device=device)
posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device)
batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device)
y, neg_dy = model(zi, posi, batch=batchi)
grad_outputs = [torch.ones_like(neg_dy)]
Expand All @@ -100,32 +109,34 @@ def test_torchscript_dynamic_shapes(model_name, device):
grad_outputs=grad_outputs,
)[0]

#Currently only tensornet is CUDA graph compatible

# Currently only tensornet is CUDA graph compatible
@mark.parametrize("model_name", ["tensornet"])
def test_cuda_graph_compatible(model_name):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
z, pos, batch = create_example_batch()
args = {"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_errors": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
args = {
"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_errors": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32,
model = create_model(args).to(device="cuda")
model.eval()
z = z.to("cuda")
Expand All @@ -144,6 +155,7 @@ def test_cuda_graph_compatible(model_name):
assert torch.allclose(y, y2)
assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5)


@mark.parametrize("model_name", models.__all_models__)
def test_seed(model_name):
args = load_example_args(model_name, remove_prior=True)
Expand All @@ -155,6 +167,7 @@ def test_seed(model_name):
for p1, p2 in zip(m1.parameters(), m2.parameters()):
assert (p1 == p2).all(), "Parameters don't match although using the same seed."


@mark.parametrize("model_name", models.__all_models__)
@mark.parametrize(
"output_model",
Expand Down Expand Up @@ -201,7 +214,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
), f"Set new reference outputs for {model_name} with output model {output_model}."

# compare actual ouput with reference
torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5)
torch.testing.assert_close(
pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5
)
if derivative:
torch.testing.assert_close(
deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5
Expand All @@ -220,7 +235,7 @@ def test_gradients(model_name):
remove_prior=True,
output_model=output_model,
derivative=derivative,
precision=precision
precision=precision,
)
model = create_model(args)
z, pos, batch = create_example_batch(n_atoms=5)
Expand All @@ -229,3 +244,38 @@ def test_gradients(model_name):
torch.autograd.gradcheck(
model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
)


def test_ensemble():
ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3
model = load_model(ckpts[0])
ensemble_model = load_model(ckpts, return_std=True)
z, pos, batch = create_example_batch(n_atoms=5)

pred, deriv = model(z, pos, batch)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
assert y_std.shape == pred.shape
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()

import zipfile
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
ensemble_zip = join(tmpdir, "ensemble.zip")
with zipfile.ZipFile(ensemble_zip, "w") as zipf:
for i, ckpt in enumerate(ckpts):
zipf.write(ckpt, f"model_{i}.ckpt")
ensemble_model = load_model(ensemble_zip, return_std=True)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
assert y_std.shape == pred.shape
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()
5 changes: 4 additions & 1 deletion torchmdnet/datasets/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def _preload_data(self):
# Watchout for the 1D case, embed can be shared for all samples
tmp = torch.tensor(np.array(data), dtype=dtype)
if tmp.ndim == 1:
tmp = tmp.unsqueeze(0).expand(size, -1)
if len(tmp) == size:
tmp = tmp.unsqueeze(-1)
else:
tmp = tmp.unsqueeze(0).expand(size, -1)
self.stored_data[field].append(tmp)
self.index.extend(list(zip([i] * size, range(size))))
i += 1
Expand Down
Loading

0 comments on commit b0d08e1

Please sign in to comment.