In [1]:
from mdgen.model.gemnet.gemnet import GemNetT


MODELS_PROJECT_ROOT: /home/tuoping/odefed_mdgen/odefed_mdgen/mdgen


In [2]:
from mdgen.model.gemnet.layers.embedding_block import AtomEmbedding

In [3]:
device = "cuda:0"
def get_model(**kwargs) -> GemNetT:
    return GemNetT(
        atom_embedding=AtomEmbedding(emb_size=4),
        num_targets=1,
        latent_dim=4,
        num_radial=4,
        num_blocks=1,
        emb_size_atom=4,
        emb_size_edge=4,
        emb_size_trip=4,
        emb_size_bil_trip=4,
        otf_graph=True,
        scale_file=f"mdgen/model/gemnet/gemnet-dT.json",
        **kwargs,
    ).to(device)

In [4]:
model = get_model(
    max_neighbors=20,
    cutoff=7.0,
    # regress stress in a non-conservative way
    regress_stress=True,
    max_cell_images_per_dim=20,
)

In [5]:
model.eval()

GemNetT(
  (angle_edge_emb): Sequential(
    (0): Linear(in_features=7, out_features=4, bias=True)
    (1): ReLU()
    (2): Linear(in_features=4, out_features=4, bias=True)
  )
  (radial_basis): RadialBasis(
    (envelope): PolynomialEnvelope()
    (rbf): GaussianSmearing()
  )
  (cbf_basis3): CircularBasisLayer(
    (radial_basis): RadialBasis(
      (envelope): PolynomialEnvelope()
      (rbf): GaussianSmearing()
    )
  )
  (lattice_out_blocks): ModuleList(
    (0-1): 2 x RBFBasedLatticeUpdateBlockFrac(
      (mlp): Sequential(
        (0): Dense(
          (linear): Linear(in_features=4, out_features=4, bias=False)
          (_activation): ScaledSiLU(
            (_activation): SiLU()
          )
        )
        (1): Dense(
          (linear): Linear(in_features=4, out_features=4, bias=False)
          (_activation): Identity()
        )
      )
      (dense_rbf_F): Dense(
        (linear): Linear(in_features=16, out_features=4, bias=False)
        (_activation): Identity()
     

In [6]:
from mdgen.model.tests.testutils import get_mp_20_debug_batch

In [7]:

batch = get_mp_20_debug_batch().to(device)

In [8]:
import torch
with torch.no_grad():
    model_out = model.forward(
        None,
        batch.frac_coords,
        batch.atom_types,
        batch.num_atoms,
        batch.batch,
        batch.lengths,
        batch.angles,
    )

cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0 cuda:0


In [9]:
# for mypy
assert model_out.stress is not None
assert torch.allclose(model_out.stress, model_out.stress.transpose(1, 2), atol=1e-5)

In [10]:
print(model_out)

ModelOutput(energy=tensor([[ 7.3401],
        [19.4495],
        [12.7095],
        [ 1.6860],
        [ 5.2052],
        [23.3484],
        [ 7.0529],
        [17.4456],
        [11.0498],
        [13.5993],
        [14.6290],
        [ 3.7225],
        [ 8.2983],
        [ 7.4768],
        [29.1272],
        [ 4.2671],
        [ 7.3273],
        [12.6814],
        [21.7417],
        [14.6777],
        [ 9.4412],
        [ 1.2599],
        [19.9703],
        [10.7226],
        [ 5.4895],
        [ 5.3953],
        [23.3612],
        [ 1.3512],
        [ 8.2811],
        [30.0697],
        [35.4383],
        [ 9.7728],
        [29.9570],
        [11.9289],
        [ 8.0838],
        [ 8.6024],
        [ 3.8489],
        [ 3.3731],
        [17.0748],
        [ 6.0037],
        [ 0.0480],
        [14.7376],
        [ 5.1421],
        [20.8301],
        [ 6.3134],
        [12.9787],
        [12.0138],
        [39.5103],
        [ 5.8957],
        [ 6.1863],
        [17.3277],
        [19.