From 1b67010080cd4026253d5d2b81ed9e5279155f7f Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Mon, 18 Mar 2024 14:12:46 +0200 Subject: [PATCH 01/17] Added support for ensemble models --- torchmdnet/models/model.py | 39 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index a2a80f901..ecc9c44d1 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -151,6 +151,10 @@ def load_model(filepath, args=None, device="cpu", **kwargs): Returns: nn.Module: An instance of the TorchMD_Net model. """ + if isinstance(filepath, (list, tuple)): + return Ensemble( + [load_model(f, args=args, device=device, **kwargs) for f in filepath] + ) ckpt = torch.load(filepath, map_location="cpu") if args is None: @@ -426,3 +430,38 @@ def forward( # Returning an empty tensor allows to decorate this method as always returning two tensors. # This is required to overcome a TorchScript limitation, xref https://github.com/openmm/openmm-torch/issues/135 return y, torch.empty(0) + + +class Ensemble(torch.nn.ModuleList): + """Average predictions over an ensemble of TorchMD-Net models""" + + def __init__(self, modules): + super().__init__(modules) + + def forward( + self, + z: Tensor, + pos: Tensor, + batch: Optional[Tensor] = None, + box: Optional[Tensor] = None, + q: Optional[Tensor] = None, + s: Optional[Tensor] = None, + extra_args: Optional[Dict[str, Tensor]] = None, + ): + y = [] + neg_dy = [] + for model in self: + res = model( + z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args + ) + y.append(res[0]) + neg_dy.append(res[1]) + + y = torch.stack(y) + print(y, neg_dy) + neg_dy = torch.stack(neg_dy) + y_mean = torch.mean(y, axis=0) + neg_dy_mean = torch.mean(neg_dy, axis=0) + y_std = torch.std(y, axis=0) + neg_dy_std = torch.std(neg_dy, axis=0) + return y_mean, neg_dy_mean, y_std, neg_dy_std From a5c0a3a842f0a27ef133c7051718523266de633d Mon Sep 17 00:00:00 2001 From: 
Stefan Doerr Date: Mon, 18 Mar 2024 14:15:18 +0200 Subject: [PATCH 02/17] remove debug print --- torchmdnet/models/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index ecc9c44d1..c30d894ff 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -458,7 +458,6 @@ def forward( neg_dy.append(res[1]) y = torch.stack(y) - print(y, neg_dy) neg_dy = torch.stack(neg_dy) y_mean = torch.mean(y, axis=0) neg_dy_mean = torch.mean(neg_dy, axis=0) From 2b99e77e7e4003658d7efc9a23ba799263a766b2 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 18 Mar 2024 13:40:04 +0100 Subject: [PATCH 03/17] Update load_model docstring --- torchmdnet/models/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index c30d894ff..6a0ecf7fa 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -142,8 +142,9 @@ def create_model(args, prior_model=None, mean=None, std=None): def load_model(filepath, args=None, device="cpu", **kwargs): """Load a model from a checkpoint file. + If a list of paths is given, an :py:mod:`Ensemble` model is returned. Args: - filepath (str): Path to the checkpoint file. + filepath (str or list): Path to the checkpoint file or a list of paths. args (dict, optional): Arguments for the model. Defaults to None. device (str, optional): Device on which the model should be loaded. Defaults to "cpu". **kwargs: Extra keyword arguments for the model. 
From d071cd38324a96a87801098e6a179509efaf2db5 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 18 Mar 2024 13:43:31 +0100 Subject: [PATCH 04/17] Update Ensemble docstring --- torchmdnet/models/model.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 6a0ecf7fa..618a024a7 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -434,9 +434,17 @@ def forward( class Ensemble(torch.nn.ModuleList): - """Average predictions over an ensemble of TorchMD-Net models""" + """Average predictions over an ensemble of TorchMD-Net models. - def __init__(self, modules): + This module behaves like a single TorchMD-Net model, but its forward method returns the average and standard deviation of the predictions over all models it was initialized with. + + Args: + modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over. + """ + + def __init__(self, modules: List[nn.Module]): + for module in modules: + assert isinstance(module, TorchMD_Net) super().__init__(modules) def forward( From 94f94d90208b43ca083ce7978ca6e7ddadd9085b Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Mon, 18 Mar 2024 14:07:14 +0100 Subject: [PATCH 05/17] Add test for ensemble --- tests/test_model.py | 99 ++++++++++++++++++++++++++++++--------------- 1 file changed, 67 insertions(+), 32 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index b792595b8..31a09a213 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,7 +9,7 @@ import torch import lightning as pl from torchmdnet import models -from torchmdnet.models.model import create_model +from torchmdnet.models.model import create_model, load_model from torchmdnet.models import output_modules from torchmdnet.models.utils import dtype_mapping @@ -23,7 +23,9 @@ def test_forward(model_name, use_batch, explicit_q_s, precision): z, pos, batch = create_example_batch() pos = 
pos.to(dtype=dtype_mapping[precision]) - model = create_model(load_example_args(model_name, prior_model=None, precision=precision)) + model = create_model( + load_example_args(model_name, prior_model=None, precision=precision) + ) batch = batch if use_batch else None if explicit_q_s: model(z, pos, batch=batch, q=None, s=None) @@ -33,10 +35,12 @@ def test_forward(model_name, use_batch, explicit_q_s, precision): @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("output_model", output_modules.__all__) -@mark.parametrize("precision", [32,64]) +@mark.parametrize("precision", [32, 64]) def test_forward_output_modules(model_name, output_model, precision): z, pos, batch = create_example_batch() - args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision) + args = load_example_args( + model_name, remove_prior=True, output_model=output_model, precision=precision + ) model = create_model(args) model(z, pos, batch=batch) @@ -61,18 +65,25 @@ def test_torchscript(model_name, device): grad_outputs=grad_outputs, )[0] + def test_torchscript_output_modification(): - model = create_model(load_example_args("tensornet", remove_prior=True, derivative=True)) + model = create_model( + load_example_args("tensornet", remove_prior=True, derivative=True) + ) + class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.model = model + def forward(self, z, pos, batch): y, neg_dy = self.model(z, pos, batch=batch) # A TorchScript bug is triggered if we modify an output of model marked as Optional[Tensor] - return y, 2*neg_dy + return y, 2 * neg_dy + torch.jit.script(MyModel()) + @mark.parametrize("model_name", models.__all_models__) @mark.parametrize("device", ["cpu", "cuda"]) def test_torchscript_dynamic_shapes(model_name, device): @@ -84,11 +95,11 @@ def test_torchscript_dynamic_shapes(model_name, device): model = torch.jit.script( create_model(load_example_args(model_name, remove_prior=True, 
derivative=True)) ).to(device=device) - #Repeat the input to make it dynamic + # Repeat the input to make it dynamic for rep in range(0, 5): print(rep) - zi = z.repeat_interleave(rep+1, dim=0).to(device=device) - posi = pos.repeat_interleave(rep+1, dim=0).to(device=device) + zi = z.repeat_interleave(rep + 1, dim=0).to(device=device) + posi = pos.repeat_interleave(rep + 1, dim=0).to(device=device) batchi = torch.randint(0, 10, (zi.shape[0],)).sort()[0].to(device=device) y, neg_dy = model(zi, posi, batch=batchi) grad_outputs = [torch.ones_like(neg_dy)] @@ -98,32 +109,35 @@ def test_torchscript_dynamic_shapes(model_name, device): grad_outputs=grad_outputs, )[0] -#Currently only tensornet is CUDA graph compatible + +# Currently only tensornet is CUDA graph compatible @mark.parametrize("model_name", ["tensornet"]) def test_cuda_graph_compatible(model_name): if not torch.cuda.is_available(): pytest.skip("CUDA not available") z, pos, batch = create_example_batch() - args = {"model": model_name, - "embedding_dimension": 128, - "num_layers": 2, - "num_rbf": 32, - "rbf_type": "expnorm", - "trainable_rbf": False, - "activation": "silu", - "cutoff_lower": 0.0, - "cutoff_upper": 5.0, - "max_z": 100, - "max_num_neighbors": 128, - "equivariance_invariance_group": "O(3)", - "prior_model": None, - "atom_filter": -1, - "derivative": True, - "check_error": False, - "static_shapes": True, - "output_model": "Scalar", - "reduce_op": "sum", - "precision": 32 } + args = { + "model": model_name, + "embedding_dimension": 128, + "num_layers": 2, + "num_rbf": 32, + "rbf_type": "expnorm", + "trainable_rbf": False, + "activation": "silu", + "cutoff_lower": 0.0, + "cutoff_upper": 5.0, + "max_z": 100, + "max_num_neighbors": 128, + "equivariance_invariance_group": "O(3)", + "prior_model": None, + "atom_filter": -1, + "derivative": True, + "check_error": False, + "static_shapes": True, + "output_model": "Scalar", + "reduce_op": "sum", + "precision": 32, + } model = 
create_model(args).to(device="cuda") model.eval() z = z.to("cuda") @@ -142,6 +156,7 @@ def test_cuda_graph_compatible(model_name): assert torch.allclose(y, y2) assert torch.allclose(neg_dy, neg_dy2, atol=1e-5, rtol=1e-5) + @mark.parametrize("model_name", models.__all_models__) def test_seed(model_name): args = load_example_args(model_name, remove_prior=True) @@ -153,6 +168,7 @@ def test_seed(model_name): for p1, p2 in zip(m1.parameters(), m2.parameters()): assert (p1 == p2).all(), "Parameters don't match although using the same seed." + @mark.parametrize("model_name", models.__all_models__) @mark.parametrize( "output_model", @@ -199,7 +215,9 @@ def test_forward_output(model_name, output_model, overwrite_reference=False): ), f"Set new reference outputs for {model_name} with output model {output_model}." # compare actual ouput with reference - torch.testing.assert_close(pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5) + torch.testing.assert_close( + pred, expected[model_name][output_model]["pred"], atol=1e-5, rtol=1e-5 + ) if derivative: torch.testing.assert_close( deriv, expected[model_name][output_model]["deriv"], atol=1e-5, rtol=1e-5 @@ -218,7 +236,7 @@ def test_gradients(model_name): remove_prior=True, output_model=output_model, derivative=derivative, - precision=precision + precision=precision, ) model = create_model(args) z, pos, batch = create_example_batch(n_atoms=5) @@ -227,3 +245,20 @@ def test_gradients(model_name): torch.autograd.gradcheck( model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3 ) + + +def test_ensemble(): + ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3 + model = load_model(ckpts[0]) + ensemble_model = load_model(ckpts) + z, pos, batch = create_example_batch(n_atoms=5) + + pred, deriv = model(z, pos, batch) + pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch) + + torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5) + 
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5) + assert y_std.shape == pred.shape + assert neg_dy_std.shape == deriv.shape + assert (y_std == 0).all() + assert (neg_dy_std == 0).all() From 5ba3db96e0b8670808f3ddc5f75ff324d192660b Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Tue, 19 Mar 2024 10:13:38 +0200 Subject: [PATCH 06/17] returning the standard deviation is now optional --- torchmdnet/models/model.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 618a024a7..ad94ad014 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -139,7 +139,7 @@ def create_model(args, prior_model=None, mean=None, std=None): return model -def load_model(filepath, args=None, device="cpu", **kwargs): +def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): """Load a model from a checkpoint file. If a list of paths is given, an :py:mod:`Ensemble` model is returned. @@ -147,6 +147,7 @@ def load_model(filepath, args=None, device="cpu", **kwargs): filepath (str or list): Path to the checkpoint file or a list of paths. args (dict, optional): Arguments for the model. Defaults to None. device (str, optional): Device on which the model should be loaded. Defaults to "cpu". + return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False. **kwargs: Extra keyword arguments for the model. 
Returns: @@ -154,7 +155,8 @@ def load_model(filepath, args=None, device="cpu", **kwargs): """ if isinstance(filepath, (list, tuple)): return Ensemble( - [load_model(f, args=args, device=device, **kwargs) for f in filepath] + [load_model(f, args=args, device=device, **kwargs) for f in filepath], + return_std=return_std, ) ckpt = torch.load(filepath, map_location="cpu") @@ -440,12 +442,14 @@ class Ensemble(torch.nn.ModuleList): Args: modules (List[nn.Module]): List of :py:mod:`TorchMD_Net` models to average predictions over. + return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. If set to True, the model returns 4 arguments (mean_y, mean_neg_dy, std_y, std_neg_dy) instead of 2 (mean_y, mean_neg_dy). """ - def __init__(self, modules: List[nn.Module]): + def __init__(self, modules: List[nn.Module], return_std: bool = False): for module in modules: assert isinstance(module, TorchMD_Net) super().__init__(modules) + self.return_std = return_std def forward( self, @@ -472,4 +476,8 @@ def forward( neg_dy_mean = torch.mean(neg_dy, axis=0) y_std = torch.std(y, axis=0) neg_dy_std = torch.std(neg_dy, axis=0) - return y_mean, neg_dy_mean, y_std, neg_dy_std + + if self.return_std: + return y_mean, neg_dy_mean, y_std, neg_dy_std + else: + return y_mean, neg_dy_mean From 417955d61617d2338a2c8f23f742162d862f670d Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Tue, 19 Mar 2024 10:42:30 +0200 Subject: [PATCH 07/17] fix test --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 31a09a213..88ac9a1ca 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -250,7 +250,7 @@ def test_gradients(model_name): def test_ensemble(): ckpts = [join(dirname(dirname(__file__)), "tests", "example.ckpt")] * 3 model = load_model(ckpts[0]) - ensemble_model = load_model(ckpts) + ensemble_model = load_model(ckpts, return_std=True) z, pos, batch = 
create_example_batch(n_atoms=5) pred, deriv = model(z, pos, batch) From 9546e88fba0c505d31612815af7d9bd3e0ceaaec Mon Sep 17 00:00:00 2001 From: Raul Date: Tue, 19 Mar 2024 12:08:00 +0100 Subject: [PATCH 08/17] Update test_model.py Fix typo in CUDA test. --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index b792595b8..00010f890 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -119,7 +119,7 @@ def test_cuda_graph_compatible(model_name): "prior_model": None, "atom_filter": -1, "derivative": True, - "check_error": False, + "check_errors": False, "static_shapes": True, "output_model": "Scalar", "reduce_op": "sum", From 5a60b40ae193dc6a1d7a0f4e8f57d727193be455 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 19 Mar 2024 15:06:29 +0100 Subject: [PATCH 09/17] Make Ensemble variadic --- torchmdnet/models/model.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index ad94ad014..0422d6de6 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -453,23 +453,18 @@ def __init__(self, modules: List[nn.Module], return_std: bool = False): def forward( self, - z: Tensor, - pos: Tensor, - batch: Optional[Tensor] = None, - box: Optional[Tensor] = None, - q: Optional[Tensor] = None, - s: Optional[Tensor] = None, - extra_args: Optional[Dict[str, Tensor]] = None, + *args, + **kwargs, ): + """Average predictions over all models in the ensemble. + The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble. 
+ """ y = [] neg_dy = [] for model in self: - res = model( - z=z, pos=pos, batch=batch, box=box, q=q, s=s, extra_args=extra_args - ) + res = model(*args, **kwargs) y.append(res[0]) neg_dy.append(res[1]) - y = torch.stack(y) neg_dy = torch.stack(neg_dy) y_mean = torch.mean(y, axis=0) From 013a691da1b213080b6f2921717d4a2e78f2258d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 19 Mar 2024 15:29:07 +0100 Subject: [PATCH 10/17] Update documentation --- docs/source/models.rst | 83 ++++++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 244c16928..5ec760e33 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -83,38 +83,9 @@ This is a minimal example of a custom training loop: optimizer.step() - - -Loading a model for inference ------------------------------ - -Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example. - -.. code:: python - - import torch - from torchmdnet.models.model import load_model - checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt" - model = load_model(checkpoint, derivative=True) - # An arbitrary set of inputs for the model - n_atoms = 10 - zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long) - z = zs[torch.randint(0, len(zs), (n_atoms,))] - pos = torch.randn(len(z), 3) - batch = torch.zeros(len(z), dtype=torch.long) - - y, neg_dy = model(z, pos, batch) - -.. note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference. - -.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case. - -.. 
note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case. - - .. _delta-learning: Training on relative energies ------------------------------ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ It might be useful to train the model on relative energies but then make the model produce total energies when running inference. TorchMD-Net supports delta training via the :code:`remove_ref_energy` option. Passing this option when training (either via the :ref:`configuration-file` or using the :ref:`torchmd-train` command line interface) will subtract the reference energy from each atom in a sample before passing it to the model. @@ -126,7 +97,7 @@ If :code:`remove_ref_energy` is turned on, the reference energy is stored in the .. note:: The reference energies are stored as an :py:mod:`torchmdnet.priors.Atomref` prior with :code:`enable=False`. Example -~~~~~~~ +******** First we train a model with the :code:`remove_ref_energy` option turned on: @@ -151,6 +122,56 @@ Then we load the model for inference: batch = torch.zeros(len(z), dtype=torch.long) y, neg_dy = model(z, pos, batch) + + +Loading a model for inference +----------------------------- + +Once you have trained a model you should have a checkpoint that you can load for inference using :py:mod:`torchmdnet.models.model.load_model` as in the following example. + +.. code:: python + + import torch + from torchmdnet.models.model import load_model + checkpoint = "/path/to/checkpoint/my_checkpoint.ckpt" + model = load_model(checkpoint, derivative=True) + # An arbitrary set of inputs for the model + n_atoms = 10 + zs = torch.tensor([1, 6, 7, 8, 9], dtype=torch.long) + z = zs[torch.randint(0, len(zs), (n_atoms,))] + pos = torch.randn(len(z), 3) + batch = torch.zeros(len(z), dtype=torch.long) + + y, neg_dy = model(z, pos, batch) + +.. 
note:: You can train a model using only the labels (i.e. energy) by passing :code:`derivative=False` and then set it to :code:`True` to compute its derivative (i.e. forces) only during inference. + +.. note:: Some models take additional inputs such as the charge :code:`q` and the spin :code:`s` of the atoms depending on the chosen priors/outputs. Check the documentation of the model you are using to see if this is the case. + +.. note:: When periodic boundary conditions are required, modules typically offer the possibility of providing the box vectors at construction and/or as an argument to the forward pass. Check the documentation of the class you are using to see if this is the case. + + +Model Ensembles +--------------- +It is possible to create an ensemble of models by loading multiple checkpoints and averaging their predictions. The following example shows how to do this: + +.. code:: python + + import torch + from torchmdnet.models.model import load_model + checkpoints = ["/path/to/checkpoint/my_checkpoint1.ckpt", "/path/to/checkpoint/my_checkpoint2.ckpt"] + model_ensemble = load_model(checkpoints, return_std=True) + y_ensemble, neg_dy_ensemble, y_std, neg_dy_std = model_ensemble(z, pos, batch) + + +.. note:: :py:mod:`torchmdnet.models.model.load_model` will return an instance of :py:mod:`torchmdnet.models.model.Ensemble` if a list of checkpoints is passed. The :code:`return_std` option can be used to return the standard deviation of the predictions. + + + +.. 
autoclass:: torchmdnet.models.model.Ensemble + :noindex: + + From 365edea1e121c4bd7e89aec9fef7819a699bd506 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 19 Mar 2024 15:29:14 +0100 Subject: [PATCH 11/17] Update docstrings --- torchmdnet/models/model.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 0422d6de6..d7ec6d7d3 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -194,29 +194,32 @@ def create_prior_models(args, dataset=None): 1. A single prior model name and its arguments as a dictionary: - ```python - args = { - "prior_model": "Atomref", - "prior_args": {"max_z": 100} - } - ``` + .. code:: python + + args = { + "prior_model": "Atomref", + "prior_args": {"max_z": 100} + } + + 2. A list of prior model names and their arguments as a list of dictionaries: - ```python + .. code:: python + + args = { + "prior_model": ["Atomref", "D2"], + "prior_args": [{"max_z": 100}, {"max_z": 100}] + } - args = { - "prior_model": ["Atomref", "D2"], - "prior_args": [{"max_z": 100}, {"max_z": 100}] - } - ``` 3. A list of prior model names and their arguments as a dictionary: - ```python - args = { - "prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}] - } - ``` + .. code:: python + + args = { + "prior_model": [{"Atomref": {"max_z": 100}}, {"D2": {"max_z": 100}}] + } + Args: args (dict): Arguments for the model. From f21f3de3b5e12d0781c642579081f3306d28b683 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Tue, 19 Mar 2024 15:31:12 +0100 Subject: [PATCH 12/17] Update docstring --- torchmdnet/models/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index d7ec6d7d3..c090f90f6 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -461,6 +461,12 @@ def forward( ): """Average predictions over all models in the ensemble. 
The arguments to this function are simply relayed to the forward method of each :py:mod:`TorchMD_Net` model in the ensemble. + Args: + *args: Positional arguments to forward to the models. + **kwargs: Keyword arguments to forward to the models. + Returns: + Tuple[Tensor, Optional[Tensor]] or Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: The average and standard deviation of the predictions over all models in the ensemble. If return_std is False, the output is a tuple (mean_y, mean_neg_dy). If return_std is True, the output is a tuple (mean_y, mean_neg_dy, std_y, std_neg_dy). + """ y = [] neg_dy = [] From c1ac21be22bd9572beeb1ffe4d5d0d6184e3959c Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Wed, 27 Mar 2024 16:51:39 +0200 Subject: [PATCH 13/17] support a zip of ckpt files for ensemble models --- torchmdnet/models/model.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index c090f90f6..eed99024c 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -144,7 +144,7 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): If a list of paths is given, an :py:mod:`Ensemble` model is returned. Args: - filepath (str or list): Path to the checkpoint file or a list of paths. + filepath (str or list): Path to the checkpoint file or a list of paths or a zip of checkpoints. args (dict, optional): Arguments for the model. Defaults to None. device (str, optional): Device on which the model should be loaded. Defaults to "cpu". return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False. 
@@ -159,6 +159,23 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): return_std=return_std, ) + if filepath.endswith(".zip"): + import zipfile + import tempfile + from glob import glob + import os + + with tempfile.TemporaryDirectory() as tmpdir: + with zipfile.ZipFile(filepath, "r") as z: + z.extractall(tmpdir) + + filepath = glob(os.path.join(tmpdir, "*.ckpt")) + + return Ensemble( + [load_model(f, args=args, device=device, **kwargs) for f in filepath], + return_std=return_std, + ) + ckpt = torch.load(filepath, map_location="cpu") if args is None: args = ckpt["hyper_parameters"] From 9ece01b87d04168a5809bf3c660df69bc4b98262 Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Wed, 27 Mar 2024 16:57:36 +0200 Subject: [PATCH 14/17] add test --- tests/test_model.py | 60 ++++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 1dd5e3549..f606559ef 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -116,26 +116,28 @@ def test_cuda_graph_compatible(model_name): if not torch.cuda.is_available(): pytest.skip("CUDA not available") z, pos, batch = create_example_batch() - args = {"model": model_name, - "embedding_dimension": 128, - "num_layers": 2, - "num_rbf": 32, - "rbf_type": "expnorm", - "trainable_rbf": False, - "activation": "silu", - "cutoff_lower": 0.0, - "cutoff_upper": 5.0, - "max_z": 100, - "max_num_neighbors": 128, - "equivariance_invariance_group": "O(3)", - "prior_model": None, - "atom_filter": -1, - "derivative": True, - "check_errors": False, - "static_shapes": True, - "output_model": "Scalar", - "reduce_op": "sum", - "precision": 32 } + args = { + "model": model_name, + "embedding_dimension": 128, + "num_layers": 2, + "num_rbf": 32, + "rbf_type": "expnorm", + "trainable_rbf": False, + "activation": "silu", + "cutoff_lower": 0.0, + "cutoff_upper": 5.0, + "max_z": 100, + "max_num_neighbors": 128, + 
"equivariance_invariance_group": "O(3)", + "prior_model": None, + "atom_filter": -1, + "derivative": True, + "check_errors": False, + "static_shapes": True, + "output_model": "Scalar", + "reduce_op": "sum", + "precision": 32, + } model = create_model(args).to(device="cuda") model.eval() z = z.to("cuda") @@ -260,3 +262,21 @@ def test_ensemble(): assert neg_dy_std.shape == deriv.shape assert (y_std == 0).all() assert (neg_dy_std == 0).all() + + import zipfile + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + ensemble_zip = join(tmpdir, "ensemble.zip") + with zipfile.ZipFile(ensemble_zip, "w") as zipf: + for i, ckpt in enumerate(ckpts): + zipf.write(ckpt, f"model_{i}.ckpt") + ensemble_model = load_model(ensemble_zip, return_std=True) + pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch) + + torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5) + assert y_std.shape == pred.shape + assert neg_dy_std.shape == deriv.shape + assert (y_std == 0).all() + assert (neg_dy_std == 0).all() From 78482a8a91ebbedff4fa461e2bf95ebde342af4a Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 28 Mar 2024 07:35:33 +0100 Subject: [PATCH 15/17] Update docstring --- torchmdnet/models/model.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index eed99024c..9ec852b6c 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -142,16 +142,21 @@ def create_model(args, prior_model=None, mean=None, std=None): def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): """Load a model from a checkpoint file. - If a list of paths is given, an :py:mod:`Ensemble` model is returned. + If a list of paths or a path to a zip file is given, an :py:mod:`Ensemble` model is returned. 
Args: - filepath (str or list): Path to the checkpoint file or a list of paths or a zip of checkpoints. + filepath (str or list): Can be any of the following: + + - Path to a checkpoint file. In this case, a :py:mod:`TorchMD_Net` model is returned. + - Path to a zip file containing multiple checkpoint files. In this case, an :py:mod:`Ensemble` model is returned. + - List of paths to checkpoint files. In this case, an :py:mod:`Ensemble` model is returned. + args (dict, optional): Arguments for the model. Defaults to None. device (str, optional): Device on which the model should be loaded. Defaults to "cpu". return_std (bool, optional): Whether to return the standard deviation of an Ensemble model. Defaults to False. **kwargs: Extra keyword arguments for the model. Returns: - nn.Module: An instance of the TorchMD_Net model. + nn.Module: An instance of the TorchMD_Net model or an Ensemble model. """ if isinstance(filepath, (list, tuple)): return Ensemble( From 1a7b2746c48d55124ff5e2fb92fee0f9f85894cb Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 28 Mar 2024 07:58:26 +0100 Subject: [PATCH 16/17] Move Emsemble loading to a different function --- torchmdnet/models/model.py | 73 ++++++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 23 deletions(-) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 9ec852b6c..913693043 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -1,8 +1,10 @@ # Copyright Universitat Pompeu Fabra 2020-2023 https://www.compscience.org # Distributed under the MIT License. 
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT) - +from glob import glob +import os import re +import tempfile from typing import Optional, List, Tuple, Dict import torch from torch.autograd import grad @@ -13,6 +15,7 @@ from torchmdnet import priors from lightning_utilities.core.rank_zero import rank_zero_warn import warnings +import zipfile def create_model(args, prior_model=None, mean=None, std=None): @@ -139,6 +142,47 @@ def create_model(args, prior_model=None, mean=None, std=None): return model +def load_ensemble(filepath, args=None, device="cpu", return_std=False, **kwargs): + """Load an ensemble of models from a list of checkpoint files or a zip file. + + Args: + filepath (str or list): Can be any of the following: + + - Path to a zip file containing multiple checkpoint files. + - List of paths to checkpoint files. + + args (dict, optional): Arguments for the model. Defaults to None. + device (str, optional): Device on which the model should be loaded. Defaults to "cpu". + return_std (bool, optional): Whether to return the standard deviation of the predictions. Defaults to False. + **kwargs: Extra keyword arguments for the model, will be passed to :py:mod:`load_model`. + + Returns: + nn.Module: An instance of :py:mod:`Ensemble`. + """ + if isinstance(filepath, (list, tuple)): + assert all(isinstance(f, str) for f in filepath), "Invalid filepath list." + model_list = [ + load_model(f, args=args, device=device, **kwargs) for f in filepath + ] + elif filepath.endswith(".zip"): + with tempfile.TemporaryDirectory() as tmpdir: + with zipfile.ZipFile(filepath, "r") as z: + z.extractall(tmpdir) + ckpt_list = glob(os.path.join(tmpdir, "*.ckpt")) + assert len(ckpt_list) > 0, "No checkpoint files found in zip file." + model_list = [ + load_model(f, args=args, device=device, **kwargs) for f in ckpt_list + ] + else: + raise ValueError( + "Invalid filepath. Must be a list of paths or a path to a zip file." 
+ ) + return Ensemble( + model_list, + return_std=return_std, + ) + + def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): """Load a model from a checkpoint file. @@ -158,29 +202,12 @@ def load_model(filepath, args=None, device="cpu", return_std=False, **kwargs): Returns: nn.Module: An instance of the TorchMD_Net model or an Ensemble model. """ - if isinstance(filepath, (list, tuple)): - return Ensemble( - [load_model(f, args=args, device=device, **kwargs) for f in filepath], - return_std=return_std, + isEnsemble = isinstance(filepath, (list, tuple)) or filepath.endswith(".zip") + if isEnsemble: + return load_ensemble( + filepath, args=args, device=device, return_std=return_std, **kwargs ) - - if filepath.endswith(".zip"): - import zipfile - import tempfile - from glob import glob - import os - - with tempfile.TemporaryDirectory() as tmpdir: - with zipfile.ZipFile(filepath, "r") as z: - z.extractall(tmpdir) - - filepath = glob(os.path.join(tmpdir, "*.ckpt")) - - return Ensemble( - [load_model(f, args=args, device=device, **kwargs) for f in filepath], - return_std=return_std, - ) - + assert isinstance(filepath, str) ckpt = torch.load(filepath, map_location="cpu") if args is None: args = ckpt["hyper_parameters"] From 74702dad9431dc4ea71a3f40deb59b6da9c537b0 Mon Sep 17 00:00:00 2001 From: Raul Date: Thu, 4 Apr 2024 08:34:07 +0200 Subject: [PATCH 17/17] Fix HDF5 not understanding some files (#313) * Fix bug in HDF5 that would cause an error during training when the dataset provides energies with shape (Nsamples,) instead of (Nsamples, 1) * Accommodate for previous behavior --- torchmdnet/datasets/hdf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py index 3d817a50c..c647b7d2a 100644 --- a/torchmdnet/datasets/hdf.py +++ b/torchmdnet/datasets/hdf.py @@ -89,7 +89,10 @@ def _preload_data(self): # Watchout for the 1D case, embed can be shared for all samples tmp = 
torch.tensor(np.array(data), dtype=dtype) if tmp.ndim == 1: - tmp = tmp.unsqueeze(0).expand(size, -1) + if len(tmp) == size: + tmp = tmp.unsqueeze(-1) + else: + tmp = tmp.unsqueeze(0).expand(size, -1) self.stored_data[field].append(tmp) self.index.extend(list(zip([i] * size, range(size)))) i += 1