In [1]:
# reference example
from nequip.data import dataset_from_config
from nequip.utils import Config
from nequip.utils.misc import get_default_device_name
from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS

from nequip.model import model_from_config




default_config = dict(
    root="./",
    tensorboard=False,
    wandb=False,
    model_builders=[
        "SimpleIrrepsConfig",
        "EnergyModel",
        "PerSpeciesRescale",
        "StressForceOutput",
        "RescaleEnergyEtc",
    ],
    dataset_statistics_stride=1,
    device=get_default_device_name(),
    default_dtype="float64",
    model_dtype="float32",
    allow_tf32=True,
    verbose="INFO",
    model_debug_mode=False,
    equivariance_test=False,
    grad_anomaly_mode=False,
    gpu_oom_offload=False,
    append=False,
    warn_unused=False,
    _jit_bailout_depth=2,  # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
    # Quote from eelison in PyTorch slack:
    # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA
    # > Right now the default behavior is to specialize twice on static shapes and then on dynamic shapes.
    # > To reduce warmup time you can do something like setFusionStrartegy({{FusionBehavior::DYNAMIC, 3}})
    # > ... Although we would wouldn't really expect to recompile a dynamic shape fusion in a model,
    # > provided broadcasting patterns remain fixed
    # We default to DYNAMIC alone because the number of edges is always dynamic,
    # even if the number of atoms is fixed:
    _jit_fusion_strategy=[("DYNAMIC", 3)],
    # Due to what appear to be ongoing bugs with nvFuser, we default to NNC (fuser1) for now:
    # TODO: still default to NNC on CPU regardless even if change this for GPU
    # TODO: default for ROCm?
    _jit_fuser="fuser1",
)

# All default_config keys are valid / requested
_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys())



In [2]:
config = Config.from_file('./configs/example.yaml', defaults=default_config)
    

dataset = dataset_from_config(config, prefix="dataset")

validation_dataset = None

dataset[0]

AtomicData(atom_types=[15, 1], cell=[3, 3], edge_cell_shift=[152, 3], edge_index=[2, 152], forces=[15, 3], pbc=[3], pos=[15, 3], total_energy=[1])

In [17]:
# Trainer
from nequip.train.trainer import Trainer

trainer = Trainer(model=None, **Config.as_dict(config))

# what is this
# to update wandb data?
config.update(trainer.params)

# = Train/test split =
trainer.set_dataset(dataset, validation_dataset)

Torch device: cpu


In [24]:
import torch

from vladimir import GraphModuleMixin
from vladimir import OneHotAtomEncoding

one_hot = OneHotAtomEncoding(2)

from vladimir import AtomwiseLinear
from e3nn import o3

torch.manual_seed(32)


# Linear
al = AtomwiseLinear(irreps_in=one_hot.irreps_out, irreps_out=o3.Irreps([(32, (0, 1))]))


# Rbe
from vladimir import RadialBasisEdgeEncoding
from e3nn import o3

rbe = RadialBasisEdgeEncoding(basis_kwargs={'r_max': 4}, cutoff_kwargs={'r_max': 4}, irreps_in=al.irreps_out)

# SH
torch.manual_seed(32)

from vladimir import SphericalHarmonicEdgeAttrs

sh = SphericalHarmonicEdgeAttrs(2, irreps_in=rbe.irreps_out)


# Convolution
from vladimir import conv
from vladimir import InteractionBlock, ConvNetLayer
from e3nn import o3

torch.manual_seed(32)


avg_num_neighbors = None


# 3 conv layers
conv1 = ConvNetLayer(irreps_in = sh.irreps_out, 
                    feature_irreps_hidden = '32x0e+32x1e+32x2e+32x1o+32x2o',
                    convolution = InteractionBlock,
                    convolution_kwargs={'invariant_layers': 2, 
                                        'invariant_neurons': 64,
                                        'avg_num_neighbors': avg_num_neighbors}
                   )


conv2 = ConvNetLayer(irreps_in = conv1.irreps_out, 
                    feature_irreps_hidden = '32x0e+32x1e+32x2e+32x1o+32x2o',
                    convolution = InteractionBlock,
                    convolution_kwargs={'invariant_layers': 2, 
                                        'invariant_neurons': 64,
                                        'avg_num_neighbors': avg_num_neighbors}
                   )

conv3 = ConvNetLayer(irreps_in = conv2.irreps_out, 
                    feature_irreps_hidden = '32x0e+32x0o+32x1e+32x2e+32x1o+32x2o',
                    convolution = InteractionBlock,
                    convolution_kwargs={'invariant_layers': 2, 
                                        'invariant_neurons': 64,
                                        'avg_num_neighbors': avg_num_neighbors}
                   )

# Last linear
from vladimir import conv
from vladimir import AtomwiseLinear
from e3nn import o3

torch.manual_seed(32)


al2 = AtomwiseLinear(irreps_in=conv3.irreps_out, irreps_out=o3.Irreps([(16, (0, 1))]))
al3 = AtomwiseLinear(irreps_in=al2.irreps_out, irreps_out=o3.Irreps([(1, (0, 1))]))


# Shift and scale
from vladimir import PerSpeciesScaleShift

torch.manual_seed(32)


num_types = 2

scales = [-11319.556641, -11319.556641]
shifts = [1/30.621034622192383, 1/30.621034622192383]

psss = PerSpeciesScaleShift(
    field = 'node_features',
    type_names = ['H', 'C'],
    num_types = num_types,
    shifts = shifts,
    scales = scales,
    arguments_in_dataset_units = True,
    scales_trainable = True,
    shifts_trainable = True,
    irreps_in = al3.irreps_out)

# Reduce
from vladimir import AtomwiseReduce

torch.manual_seed(32)


num_types = 2

scales = [1., 1.]
shifts = [0., 0.]

ar = AtomwiseReduce(
    field = 'shifted_node_features',
    out_field = 'total_energy',
    reduce = 'sum',
    irreps_in = psss.irreps_out)


from vladimir import SequentialGraphNetwork

module_list = [one_hot, al,
               rbe, sh,
               conv1, conv2, conv3,
               al2, al3,
               psss, ar]

graph_func = SequentialGraphNetwork(module_list)

from vladimir import StressOutput

torch.manual_seed(32)

so = StressOutput(graph_func)

from vladimir import RescaleOutput

torch.manual_seed(32)

rescale = RescaleOutput(model = so,
        scale_keys = ['pos', 'total_energy', 'forces', 'stress', 'virial'],
        shift_keys = ['total_energy'],
        scale_by=None,
        shift_by=None,
        shift_trainable = False,
        scale_trainable = False,
        irreps_in = so.irreps_in)

from vladimir import GraphModel

torch.manual_seed(32)

<torch._C.Generator at 0x7fb0f5eb8250>

In [25]:
from nequip.model._gmm import GraphModel

gm = GraphModel(rescale)

In [26]:
print(type(gm))

<class 'nequip.nn._graph_model.GraphModel'>


In [27]:
# Trainer
from nequip.train.trainer import Trainer

trainer = Trainer(model = gm, **Config.as_dict(config))

# what is this
# to update wandb data?
config.update(trainer.params)

# = Train/test split =
trainer.set_dataset(dataset, validation_dataset)

trainer.train()

Torch device: cpu
Number of weights: 227420
Number of trainable weights: 227420
! Starting training ...

validation
# Epoch batch         loss       loss_f       loss_e        f_mae       f_rmse      H_f_mae      C_f_mae  psavg_f_mae     H_f_rmse     C_f_rmse psavg_f_rmse        e_mae      e/N_mae
      0     5      1.6e+08     1.67e+07     1.44e+08     3.02e+03     4.09e+03     2.21e+03     3.94e+03     3.07e+03     2.95e+03     5.09e+03     4.02e+03      1.8e+05      1.2e+04


  Initialization     #    Epoch      wal       LR       loss_f       loss_e         loss        f_mae       f_rmse      H_f_mae      C_f_mae  psavg_f_mae     H_f_rmse     C_f_rmse psavg_f_rmse        e_mae      e/N_mae
! Initial Validation          0    2.061    0.005     1.47e+07     1.44e+08     1.59e+08     2.85e+03     3.84e+03     1.98e+03     3.85e+03     2.91e+03     2.59e+03     4.88e+03     3.74e+03      1.8e+05      1.2e+04
Wall time: 2.0631081250030547
! Best model        0 159085408.000
! Stop train