Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
stefdoerr committed Mar 27, 2024
1 parent c1ac21b commit 9ece01b
Showing 1 changed file with 40 additions and 20 deletions.
60 changes: 40 additions & 20 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,26 +116,28 @@ def test_cuda_graph_compatible(model_name):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
z, pos, batch = create_example_batch()
args = {"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_errors": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
args = {
"model": model_name,
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"check_errors": False,
"static_shapes": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32,
}
model = create_model(args).to(device="cuda")
model.eval()
z = z.to("cuda")
Expand Down Expand Up @@ -260,3 +262,21 @@ def test_ensemble():
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()

import zipfile
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
ensemble_zip = join(tmpdir, "ensemble.zip")
with zipfile.ZipFile(ensemble_zip, "w") as zipf:
for i, ckpt in enumerate(ckpts):
zipf.write(ckpt, f"model_{i}.ckpt")
ensemble_model = load_model(ensemble_zip, return_std=True)
pred_ensemble, deriv_ensemble, y_std, neg_dy_std = ensemble_model(z, pos, batch)

torch.testing.assert_close(pred, pred_ensemble, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(deriv, deriv_ensemble, atol=1e-5, rtol=1e-5)
assert y_std.shape == pred.shape
assert neg_dy_std.shape == deriv.shape
assert (y_std == 0).all()
assert (neg_dy_std == 0).all()

0 comments on commit 9ece01b

Please sign in to comment.