## Trainer

In [None]:
""" Train a network."""

import logging
import argparse
import warnings

# This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process 
# that has itself initialized PyTorch.
# Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance.
import numpy as np  # noqa: F401

from os.path import exists, isdir
from shutil import rmtree
from pathlib import Path

import torch

from nequip.model import model_from_config
from nequip.utils import Config
from nequip.data import dataset_from_config
from nequip.utils import load_file
from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS
from nequip.utils.test import assert_AtomicData_equivariant
from nequip.utils.versions import check_code_version
from nequip.utils.misc import get_default_device_name
from nequip.utils._global_options import _set_global_options
from nequip.scripts._logger import set_up_script_logger

# Baseline options; values from the user's YAML file are layered on top of these.
default_config = {
    "root": "./",
    "tensorboard": False,
    "wandb": False,
    "model_builders": [
        "SimpleIrrepsConfig",
        "EnergyModel",
        "PerSpeciesRescale",
        "StressForceOutput",
        "RescaleEnergyEtc",
    ],
    "dataset_statistics_stride": 1,
    "device": get_default_device_name(),
    "default_dtype": "float64",
    "model_dtype": "float32",
    "allow_tf32": True,
    "verbose": "INFO",
    "model_debug_mode": False,
    "equivariance_test": False,
    "grad_anomaly_mode": False,
    "gpu_oom_offload": False,
    "append": False,
    "warn_unused": False,
    # Avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
    "_jit_bailout_depth": 2,
    # Quote from eelison in PyTorch slack:
    # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA
    # > Right now the default behavior is to specialize twice on static shapes and then on dynamic shapes.
    # > To reduce warmup time you can do something like setFusionStrartegy({{FusionBehavior::DYNAMIC, 3}})
    # > ... Although we would wouldn't really expect to recompile a dynamic shape fusion in a model,
    # > provided broadcasting patterns remain fixed
    # DYNAMIC alone is the default because the number of edges is always dynamic,
    # even if the number of atoms is fixed:
    "_jit_fusion_strategy": [("DYNAMIC", 3)],
    # Due to what appear to be ongoing bugs with nvFuser, NNC (fuser1) is the default for now:
    # TODO: still default to NNC on CPU regardless even if change this for GPU
    # TODO: default for ROCm?
    "_jit_fuser": "fuser1",
}
# Every default key counts as a valid / requested key
_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys())


def main(args=None, running_as_script: bool = True):
    """Parse the command line, then start a fresh training run or resume one.

    A run is resumed when ``<root>/<run_name>/trainer.pth`` exists and the
    ``append`` option is set. A run directory without ``trainer.pth`` is
    deleted and training starts from scratch.

    Args:
        args: argument list forwarded to ``parse_command_line``; ``None``
            makes argparse read ``sys.argv``.
        running_as_script (bool): when true, the script logger is configured
            before anything else happens.

    Raises:
        RuntimeError: if ``trainer.pth`` already exists but ``append`` is not
            set, or if GPU OOM offloading is requested without CUDA.
    """
    config = parse_command_line(args)

    if running_as_script:
        set_up_script_logger(config.get("log", None), config.verbose)

    run_dir = f"{config.root}/{config.run_name}"
    found_restart_file = exists(f"{run_dir}/trainer.pth")

    if found_restart_file:
        if not config.append:
            raise RuntimeError(
                f"Training instance exists at {run_dir}; "
                "either set append to True or use a different root or runname"
            )
        trainer = restart(config)
    else:
        if isdir(run_dir):
            # The output directory exists but holds no ``trainer.pth``, which
            # suggests the previous run crashed during its first training
            # epoch (usually due to memory).
            warnings.warn(
                f"Previous run folder at {run_dir} exists, but a saved model "
                f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will "
                f"be started."
            )
            rmtree(run_dir)
        trainer = fresh_start(config)

    # Save once before the first epoch, then train.
    trainer.save()

    if not config.get("gpu_oom_offload", False):
        trainer.train()
        return

    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available; --gpu-oom-offload doesn't make sense."
        )
    warnings.warn(
        "! GPU OOM Offloading is ON:\n"
        "This is meant for training models that would be impossible otherwise due to OOM.\n"
        "Note that this comes at a speed cost and SHOULD NOT be used if your training fits in GPU memory without it.\n"
        "Please also consider whether a smaller model is a more appropriate solution!\n"
        "Also, a warning from PyTorch: 'If you overuse pinned memory, it can cause serious problems when running low on RAM!'"
    )
    # Tensors saved for backward are kept in pinned host memory while training.
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        trainer.train()

    return


def parse_command_line(args=None):
    """Build the training config from the command line and the YAML file.

    The YAML file named by the positional ``config`` argument is loaded on
    top of ``default_config``. For each of the debugging/behaviour flags, a
    truthy command-line value takes precedence over the value in the file.

    Args:
        args: argument list for argparse; ``None`` makes argparse read
            ``sys.argv``.

    Returns:
        The ``Config`` object returned by ``Config.from_file`` with the
        command-line flags merged in.
    """
    parser = argparse.ArgumentParser(
        description="Train (or restart training of) a NequIP model."
    )

    def _add_switch(option, description):
        # Plain on/off option that stores True when present.
        parser.add_argument(option, help=description, action="store_true")

    parser.add_argument(
        "config", help="YAML file configuring the model, dataset, and other options"
    )
    # With no value given, ``--equivariance-test`` means one random frame.
    parser.add_argument(
        "--equivariance-test",
        help="test the model's equivariance before training on n (default 1) random frames from the dataset",
        const=1,
        type=int,
        nargs="?",
    )
    _add_switch(
        "--model-debug-mode",
        "enable model debug mode, which can sometimes give much more useful error messages at the cost of some speed. Do not use for production training!",
    )
    _add_switch(
        "--grad-anomaly-mode",
        "enable PyTorch autograd anomaly mode to debug NaN gradients. Do not use for production training!",
    )
    _add_switch(
        "--gpu-oom-offload",
        "Use `torch.autograd.graph.save_on_cpu` to offload intermediate tensors to CPU (host) memory in order to train models that would be impossible otherwise due to OOM. Note that this comes as at a speed cost and SHOULD NOT be used if your training fits in GPU memory without it. Please also consider whether a smaller model is a more appropriate solution.",
    )
    # NOTE(review): ``--log`` is parsed here but is not among the flags copied
    # into the config below — confirm whether that is intended.
    parser.add_argument(
        "--log",
        help="log file to store all the screen logging",
        type=Path,
        default=None,
    )
    _add_switch(
        "--warn-unused",
        "Warn instead of error when the config contains unused keys",
    )
    parsed = parser.parse_args(args=args)

    config = Config.from_file(parsed.config, defaults=default_config)

    cli_flags = (
        "model_debug_mode",
        "equivariance_test",
        "grad_anomaly_mode",
        "warn_unused",
        "gpu_oom_offload",
    )
    for flag in cli_flags:
        cli_value = getattr(parsed, flag)
        # A truthy command-line value wins; otherwise keep the file/default value.
        config[flag] = cli_value if cli_value else config[flag]

    return config


def fresh_start(config):
    """Create a trainer, dataset(s) and model for a brand-new training run.

    The steps run in a fixed order: set global options, construct the trainer
    (WandB, TensorBoard or default, depending on the config), load the
    dataset(s), build the model, optionally run the equivariance test, and
    finally attach the model and an unused-key check to the trainer.

    Args:
        config: the training configuration; it is updated in place with
            ``trainer.params`` and may be replaced by the result of
            ``init_n_update`` when ``wandb`` is enabled.

    Returns:
        The trainer, with its datasets and model set.

    Raises:
        AssertionError: if ``equivariance_test`` asks for more frames than
            the training set contains.
    """
    # we use add_to_config cause it's a fresh start and need to record it
    check_code_version(config, add_to_config=True)
    _set_global_options(config)
    if config["default_dtype"] != "float64":
        warnings.warn(
            f"default_dtype={config['default_dtype']} but we strongly recommend float64"
        )

    # = Make the trainer =
    # The trainer class is imported lazily and always bound to the name
    # ``Trainer``, so the construction below is the same for all three cases.
    if config.wandb:

        import wandb  # noqa: F401
        from nequip.train.trainer_wandb import TrainerWandB as Trainer

        # download parameters from wandb in case of sweeping
        from nequip.utils.wandb import init_n_update

        config = init_n_update(config)

    elif config.tensorboard:
        from nequip.train.trainer_tensorboard import TrainerTensorBoard as Trainer
    else:
        # default: neither wandb nor tensorboard was requested
        from nequip.train.trainer import Trainer

    # The model is attached later, once the dataset is available to build it.
    trainer = Trainer(model=None, **Config.as_dict(config))

    # Copy ``trainer.params`` back into the config.
    # NOTE(review): presumably so that values resolved by the trainer are
    # visible to the dataset/model builders below — confirm.
    config.update(trainer.params)

    # = Load the dataset =
    dataset = dataset_from_config(config, prefix="dataset")
    logging.info(f"Successfully loaded the data set of type {dataset}...")
    try:
        validation_dataset = dataset_from_config(config, prefix="validation_dataset")
        logging.info(
            f"Successfully loaded the validation data set of type {validation_dataset}..."
        )
    except KeyError:
        # It couldn't be found
        validation_dataset = None

    # = Train/test split =
    trainer.set_dataset(dataset, validation_dataset)

    # = Build model =
    final_model = model_from_config(
        config=config, initialize=True, dataset=trainer.dataset_train
    )
    logging.info("Successfully built the network...")

    # Equivariance test on ``equivariance_test`` randomly chosen training frames.
    # With the default value ``False`` the comparison below is false and the
    # test is skipped.
    if config.equivariance_test > 0:
        n_train: int = len(trainer.dataset_train)
        assert config.equivariance_test <= n_train
        # The test runs with the model in eval mode; train mode is restored after.
        final_model.eval()
        indexes = torch.randperm(n_train)[: config.equivariance_test]
        errstr = assert_AtomicData_equivariant(
            final_model, [trainer.dataset_train[i] for i in indexes]
        )
        final_model.train()
        logging.info(
            "Equivariance test passed; equivariance errors:\n"
            "   Errors are in real units, where relevant.\n"
            "   Please note that the large scale of the typical\n"
            "   shifts to the (atomic) energy can cause\n"
            "   catastrophic cancellation and give incorrectly\n"
            "   the equivariance error as zero for those fields.\n"
            f"{errstr}"
        )
        del errstr, indexes, n_train

    # Set the trainer
    trainer.model = final_model

    # Store any updated config information in the trainer
    trainer.update_kwargs(config)

    # Only run the unused check as a callback after the trainer has
    # initialized everything (metrics, early stopping, etc.)
    def _unused_check():
        # Warn or raise (depending on ``warn_unused``) if any config key was
        # never used.
        unused = config._unused_keys()
        if len(unused) > 0:
            message = f"The following keys in the config file were not used, did you make a typo?: {', '.join(unused)}. (If this sounds wrong, please file an issue. You can turn this error into a warning with `--warn-unused`, but please make sure that the key really is correctly spelled and used!.)"
            if config.warn_unused:
                warnings.warn(message)
            else:
                raise KeyError(message)

    trainer._post_init_callback = _unused_check

    return trainer


def restart(config):
    """Rebuild a trainer from ``<root>/<run_name>/trainer.pth`` and reload its data.

    Only ``max_epochs`` and keys starting with ``early_stop`` may differ
    between the new config and the saved state; those are copied into the
    saved state. Any other key whose new value differs from the saved one
    while being an instance of the saved value's type is an error.

    Args:
        config: the freshly parsed configuration for the resumed run.

    Returns:
        The trainer restored with ``from_dict``, with its datasets set.

    Raises:
        ValueError: if a non-updatable key differs between ``config`` and
            the saved ``trainer.pth`` contents.
    """
    # load the saved trainer state
    checkpoint_path = f"{config.root}/{config.run_name}/trainer.pth"
    saved_state = load_file(
        supported_formats={"torch": ["pt", "pth"]},
        filename=checkpoint_path,
        enforced_format="torch",
    )

    # compare the saved state to config and update stop-condition related arguments
    for key in config.keys():
        requested = config[key]
        stored = saved_state.get(key, "")
        differs = requested != stored
        if not differs:
            continue

        if key == "max_epochs" or key.startswith("early_stop"):
            saved_state[key] = requested
            logging.info(f'Update "{key}" to {saved_state[key]}')
        elif isinstance(requested, type(stored)):
            # same type but a different value
            raise ValueError(
                f'Key "{key}" is different in config and the result trainer.pth file. Please double check'
            )

    # note, "trainer.pth"/the saved state also stores code versions,
    # which will not be stored in config and thus not checked here
    check_code_version(config)

    # From here on the configuration comes from the saved state.
    config = Config(saved_state, exclude_keys=["state_dict", "progress"])

    # dtype, etc.
    _set_global_options(config)

    # note, the from_dict method will check whether the code version
    # in trainer.pth is consistent and issue warnings
    if config.wandb:
        from nequip.train.trainer_wandb import TrainerWandB as trainer_cls
        from nequip.utils.wandb import resume

        resume(config)
    else:
        from nequip.train.trainer import Trainer as trainer_cls

    trainer = trainer_cls.from_dict(saved_state)

    # = Load the dataset =
    dataset = dataset_from_config(config, prefix="dataset")
    logging.info(f"Successfully re-loaded the data set of type {dataset}...")

    validation_dataset = None
    try:
        validation_dataset = dataset_from_config(config, prefix="validation_dataset")
        logging.info(
            f"Successfully re-loaded the validation data set of type {validation_dataset}..."
        )
    except KeyError:
        # No validation dataset could be found
        validation_dataset = None

    trainer.set_dataset(dataset, validation_dataset)

    return trainer


# Run training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main(running_as_script=True)

In [1]:
""" Train a network."""

import logging
import argparse
import warnings

# This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process 
# that has itself initialized PyTorch.
# Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance.
import numpy as np  # noqa: F401

from os.path import exists, isdir
from shutil import rmtree
from pathlib import Path

import torch

In [8]:
from utils import Config, _GLOBAL_ALL_ASKED_FOR_KEYS
from utils import get_default_device_name
from utils import load_callable

# Baseline options; values from the YAML file loaded below are layered on top.
default_config = {
    "root": "./",
    "tensorboard": False,
    "wandb": False,
    "model_builders": [
        "SimpleIrrepsConfig",
        "EnergyModel",
        "PerSpeciesRescale",
        "StressForceOutput",
        "RescaleEnergyEtc",
    ],
    "dataset_statistics_stride": 1,
    "device": get_default_device_name(),
    "default_dtype": "float64",
    "model_dtype": "float32",
    "allow_tf32": True,
    "verbose": "INFO",
    "model_debug_mode": False,
    "equivariance_test": False,
    "grad_anomaly_mode": False,
    "gpu_oom_offload": False,
    "append": False,
    "warn_unused": False,
    # Avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
    "_jit_bailout_depth": 2,
    # Quote from eelison in PyTorch slack:
    # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA
    # > Right now the default behavior is to specialize twice on static shapes and then on dynamic shapes.
    # > To reduce warmup time you can do something like setFusionStrartegy({{FusionBehavior::DYNAMIC, 3}})
    # > ... Although we would wouldn't really expect to recompile a dynamic shape fusion in a model,
    # > provided broadcasting patterns remain fixed
    # DYNAMIC alone is the default because the number of edges is always dynamic,
    # even if the number of atoms is fixed:
    "_jit_fusion_strategy": [("DYNAMIC", 3)],
    # Due to what appear to be ongoing bugs with nvFuser, NNC (fuser1) is the default for now:
    # TODO: still default to NNC on CPU regardless even if change this for GPU
    # TODO: default for ROCm?
    "_jit_fuser": "fuser1",
}
# Every default key counts as a valid / requested key
_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys())

# Load the example configuration on top of the defaults.
config = Config.from_file("./example.yaml", defaults=default_config)

In [9]:
config

{'_jit_bailout_depth': 2, '_jit_fusion_strategy': [('DYNAMIC', 3)], '_jit_fuser': 'fuser1', 'root': 'results/toluene', 'tensorboard': False, 'wandb': True, 'model_builders': ['SimpleIrrepsConfig', 'EnergyModel', 'PerSpeciesRescale', 'StressForceOutput', 'RescaleEnergyEtc'], 'dataset_statistics_stride': 1, 'device': 'cuda', 'default_dtype': 'float64', 'model_dtype': 'float32', 'allow_tf32': True, 'verbose': 'info', 'model_debug_mode': False, 'equivariance_test': False, 'grad_anomaly_mode': False, 'gpu_oom_offload': False, 'append': True, 'warn_unused': False, 'run_name': 'example-run-toluene', 'seed': 123, 'dataset_seed': 456, 'r_max': 4.0, 'num_layers': 4, 'l_max': 2, 'parity': True, 'num_features': 32, 'nonlinearity_type': 'gate', 'nonlinearity_scalars': {'e': 'silu', 'o': 'tanh'}, 'nonlinearity_gates': {'e': 'silu', 'o': 'tanh'}, 'num_basis': 8, 'BesselBasis_trainable': True, 'PolynomialCutoff_p': 6, 'invariant_layers': 2, 'invariant_neurons': 64, 'avg_num_neighbors': 'auto', 'use_sc

In [6]:
def model_from_config(
    config: Config,
    initialize: bool = False,
    dataset: Optional[AtomicDataset] = None,
    deploy: bool = False,
) -> GraphModuleMixin:
    """Build a model based on `config`.

    Model builders (`model_builders`) can have arguments:
     - ``config``: the config. Always present.
     - ``model``: the model produced by the previous builder. Cannot be requested by the first builder, must be requested by subsequent ones.
     - ``initialize``: whether to initialize the model
     - ``dataset``: if ``initialize`` is True, the dataset
     - ``deploy``: whether the model object is for deployment / inference
     - ``graph_model``: the wrapped ``GraphModel``. The first builder that requests it
       ends the plain-model phase; it and all later builders run on the wrapped model.

    Note that this function temporarily sets ``torch.set_default_dtype()`` and as such is not thread safe.

    Args:
        config: the configuration; a plain ``dict`` is converted with ``Config.from_dict``.
            It is updated in place with ``num_types``, ``type_names``,
            ``type_to_chemical_symbol`` (when a type mapper is available) and ``model_dtype``.
        initialize (bool): whether ``model_builders`` should be instructed to initialize the model
        dataset: dataset for initializers if ``initialize`` is True.
        deploy (bool): whether ``model_builders`` should be told the model is for deployment / inference

    Returns:
        The built model.

    Raises:
        ValueError: if the default dtype is float32 but ``model_dtype`` is float64, if a
            builder requests ``dataset`` without ``initialize``, or if a builder requests
            ``model`` after some builder has requested ``graph_model``.
        RuntimeError: if a builder needs a dataset that was not provided, asks for a model
            before one exists, or does not accept the model once one exists.
        TypeError: if a builder returns an object of the wrong type.
    """
    if isinstance(config, dict):
        config = Config.from_dict(config)
    # Pre-process config
    type_mapper = None
    if dataset is not None:
        type_mapper = dataset.type_mapper
    else:
        # Without a dataset, try to build a type mapper from the config alone;
        # a failure here is tolerated and simply leaves ``type_mapper`` unset.
        try:
            type_mapper, _ = instantiate(TypeMapper, all_args=config)
        except RuntimeError:
            pass

    if type_mapper is not None:
        if "num_types" in config:
            assert (
                config["num_types"] == type_mapper.num_types
            ), "inconsistant config & dataset"
        if "type_names" in config:
            assert (
                config["type_names"] == type_mapper.type_names
            ), "inconsistant config & dataset"
        config["num_types"] = type_mapper.num_types
        config["type_names"] = type_mapper.type_names
        config["type_to_chemical_symbol"] = type_mapper.type_to_chemical_symbol
        # We added them, so they are by definition valid:
        _GLOBAL_ALL_ASKED_FOR_KEYS.update(
            {"num_types", "type_names", "type_to_chemical_symbol"}
        )

    default_dtype = torch.get_default_dtype()
    model_dtype: torch.dtype = dtype_from_name(config.get("model_dtype", default_dtype))
    # Store the bare dtype name (e.g. "float32"). ``str.lstrip("torch.")`` must not be
    # used for this: it strips any leading run of the characters t/o/r/c/h/. rather
    # than the prefix, which would mangle names such as "torch.half".
    model_dtype_name = str(model_dtype)
    if model_dtype_name.startswith("torch."):
        model_dtype_name = model_dtype_name[len("torch.") :]
    config["model_dtype"] = model_dtype_name
    # confirm sanity
    assert default_dtype in (torch.float32, torch.float64)
    if default_dtype == torch.float32 and model_dtype == torch.float64:
        raise ValueError(
            "Overall default_dtype=float32, but model_dtype=float64 is a higher precision- change default_dtype to float64"
        )
    # temporarily set the default dtype
    start_graph_model_builders = None
    with torch_default_dtype(model_dtype):

        # Build
        builders = [
            load_callable(b, prefix="nequip.model")
            for b in config.get("model_builders", [])
        ]

        model = None

        for builder_i, builder in enumerate(builders):
            # Each builder is passed only the arguments its signature names.
            pnames = inspect.signature(builder).parameters
            params = {}
            if "graph_model" in pnames:
                # start graph_model builders, which happen later
                start_graph_model_builders = builder_i
                break
            if "initialize" in pnames:
                params["initialize"] = initialize
            if "deploy" in pnames:
                params["deploy"] = deploy
            if "config" in pnames:
                params["config"] = config
            if "dataset" in pnames:
                if "initialize" not in pnames:
                    raise ValueError(
                        "Cannot request dataset without requesting initialize"
                    )
                if (
                    initialize
                    and pnames["dataset"].default == inspect.Parameter.empty
                    and dataset is None
                ):
                    raise RuntimeError(
                        f"Builder {builder.__name__} requires the dataset, initialize is true, but no dataset was provided to `model_from_config`."
                    )
                params["dataset"] = dataset
            if "model" in pnames:
                if model is None:
                    raise RuntimeError(
                        f"Builder {builder.__name__} asked for the model as an input, but no previous builder has returned a model"
                    )
                params["model"] = model
            else:
                if model is not None:
                    raise RuntimeError(
                        f"All model_builders after the first one that returns a model must take the model as an argument; {builder.__name__} doesn't"
                    )
            model = builder(**params)
            if model is not None and not isinstance(model, GraphModuleMixin):
                raise TypeError(
                    f"Builder {builder.__name__} didn't return a GraphModuleMixin, got {type(model)} instead"
                )
    # reset to default dtype by context manager

    # Wrap the model up
    model = GraphModel(
        model,
        model_dtype=model_dtype,
        model_input_fields=config.get("model_input_fields", {}),
    )

    # Run GraphModel builders
    if start_graph_model_builders is not None:
        for builder in builders[start_graph_model_builders:]:
            pnames = inspect.signature(builder).parameters
            params = {}
            assert "graph_model" in pnames
            params["graph_model"] = model
            if "model" in pnames:
                raise ValueError(
                    f"Once any builder requests `graph_model` (first requested by {builders[start_graph_model_builders].__name__}), no builder can request `model`, but {builder.__name__} did"
                )
            if "initialize" in pnames:
                params["initialize"] = initialize
            if "deploy" in pnames:
                params["deploy"] = deploy
            if "config" in pnames:
                params["config"] = config
            if "dataset" in pnames:
                if "initialize" not in pnames:
                    raise ValueError(
                        "Cannot request dataset without requesting initialize"
                    )
                if (
                    initialize
                    and pnames["dataset"].default == inspect.Parameter.empty
                    and dataset is None
                ):
                    raise RuntimeError(
                        f"Builder {builder.__name__} requires the dataset, initialize is true, but no dataset was provided to `model_from_config`."
                    )
                params["dataset"] = dataset

            model = builder(**params)
            if not isinstance(model, GraphModel):
                raise TypeError(
                    f"Builder {builder.__name__} didn't return a GraphModel, got {type(model)} instead"
                )

    return model

NameError: name 'Optional' is not defined

In [None]:
class GraphModuleMixin:
    r"""Mixin for ``torch.nn.Module``s whose inputs and outputs are ``AtomicDataDict.Type`` graph data.

    Subclasses are expected to call ``_init_irreps`` from their ``__init__`` to declare the data fields
    they expect, require, and produce, together with the irreps of those fields.
    """

    def _init_irreps(
        self,
        irreps_in: Dict[str, Any] = {},
        my_irreps_in: Dict[str, Any] = {},
        required_irreps_in: Sequence[str] = [],
        irreps_out: Dict[str, Any] = {},
    ):
        """Record the input and output fields of this graph module and their irreps.

        ``None`` is a valid irreps value for a field that is invariant but is not well described by an
        ``e3nn.o3.Irreps``; edge indexes, which are integers rather than ``0e`` scalars, are one example.

        Sets ``self.irreps_in`` to the (normalized) inputs and ``self.irreps_out`` to those inputs
        overwritten by ``irreps_out``. The positions and edge-index fields are always added to the inputs.

        Args:
            irreps_in (dict): maps the names of all input fields coming from previous modules or
                from the data to their irreps; ``None`` is treated as an empty dict
            my_irreps_in (dict): maps field names to the irreps this module needs them to have;
                checked against ``irreps_in`` for every field present in both
            required_irreps_in: names of fields that must appear in ``irreps_in``, with any irreps
            irreps_out (dict): maps the names of fields this module modifies or outputs to their irreps

        Raises:
            ValueError: if positions are not ``1o``, if edge indexes have non-``None`` irreps, if a
                field in ``my_irreps_in`` has different irreps in ``irreps_in``, or if a required
                field is missing from ``irreps_in``.
        """
        # NOTE(review): the default arguments are shared mutable objects; this relies on
        # ``_fix_irreps_dict`` returning a new dict rather than its argument — confirm.
        if irreps_in is None:
            irreps_in = {}
        irreps_in = AtomicDataDict._fix_irreps_dict(irreps_in)

        # positions are *always* 1o, and always present
        pos_key = AtomicDataDict.POSITIONS_KEY
        if pos_key in irreps_in and irreps_in[pos_key] != o3.Irreps("1x1o"):
            raise ValueError(
                f"Positions must have irreps 1o, got instead `{irreps_in[pos_key]}`"
            )
        irreps_in[pos_key] = o3.Irreps("1o")

        # edges are also always present, and carry no irreps
        edge_key = AtomicDataDict.EDGE_INDEX_KEY
        if edge_key in irreps_in and irreps_in[edge_key] is not None:
            raise ValueError(
                f"Edge indexes must have irreps None, got instead `{irreps_in[edge_key]}`"
            )
        irreps_in[edge_key] = None

        my_irreps_in = AtomicDataDict._fix_irreps_dict(my_irreps_in)
        irreps_out = AtomicDataDict._fix_irreps_dict(irreps_out)

        # Every field this module constrains must match what it is actually given
        for field, expected in my_irreps_in.items():
            if field in irreps_in and irreps_in[field] != expected:
                raise ValueError(
                    f"The given input irreps {irreps_in[field]} for field '{field}' is incompatible with this configuration {type(self)}; should have been {expected}"
                )

        # Every required field must be present, whatever its irreps
        for field in required_irreps_in:
            if field not in irreps_in:
                raise ValueError(
                    f"This {type(self)} requires field '{field}' to be in irreps_in"
                )

        self.irreps_in = irreps_in
        # Outputs are all the inputs, overwritten by whatever this module outputs.
        combined = irreps_in.copy()
        combined.update(irreps_out)
        self.irreps_out = combined

    def _add_independent_irreps(self, irreps: Dict[str, Any]):
        """
        Expose additional, independent irreps through ``self.irreps_in`` and ``self.irreps_out``.
        Fields already present in ``self.irreps_in`` are dropped first; of the remaining ones, only
        those not already in ``self.irreps_out`` are added to the outputs.

        Args:
            irreps (dict): maps names of all new fields
        """
        fresh = {
            name: value
            for name, value in irreps.items()
            if name not in self.irreps_in
        }
        extra_in = AtomicDataDict._fix_irreps_dict(fresh)

        not_yet_output = {
            name: value
            for name, value in fresh.items()
            if name not in self.irreps_out
        }
        extra_out = AtomicDataDict._fix_irreps_dict(not_yet_output)

        self.irreps_in.update(extra_in)
        self.irreps_out.update(extra_out)

    def _make_tracing_inputs(self, n):
        # Implemented so that graph modules can be traced
        # TODO: handle None case
        # TODO: do only required inputs
        # TODO: dummy input if empty?
        def _one_example():
            # Random batch size, then one random tensor per field that has irreps.
            batch = random.randint(1, 4)
            fields = {
                name: irreps.randn(batch, -1)
                for name, irreps in self.irreps_in.items()
                if irreps is not None
            }
            return {"forward": (fields,)}

        return [_one_example() for _ in range(n)]


class SequentialGraphNetwork(GraphModuleMixin, torch.nn.Sequential):
    r"""A ``torch.nn.Sequential`` of ``GraphModuleMixin``s.

    The network's ``irreps_in`` are those of the first module and its
    ``irreps_out`` are those of the last module; consecutive modules must have
    compatible irreps.

    Args:
        modules (list or dict of ``GraphModuleMixin``s): the sequence of graph modules.
        If a list, the modules will be named ``"module0", "module1", ...``.
    """

    def __init__(
        self,
        modules: Union[Sequence[GraphModuleMixin], Dict[str, GraphModuleMixin]],
    ):
        if isinstance(modules, dict):
            module_list = list(modules.values())
        else:
            module_list = list(modules)
        # check in/out irreps compatible
        for m1, m2 in zip(module_list, module_list[1:]):
            assert AtomicDataDict._irreps_compatible(
                m1.irreps_out, m2.irreps_in
            ), f"Incompatible irreps_out from {type(m1).__name__} for input to {type(m2).__name__}: {m1.irreps_out} -> {m2.irreps_in}"
        self._init_irreps(
            irreps_in=module_list[0].irreps_in,
            my_irreps_in=module_list[0].irreps_in,
            irreps_out=module_list[-1].irreps_out,
        )
        # torch.nn.Sequential will name children correctly if passed an OrderedDict
        if isinstance(modules, dict):
            modules = OrderedDict(modules)
        else:
            modules = OrderedDict((f"module{i}", m) for i, m in enumerate(module_list))
        super().__init__(modules)

    @classmethod
    def from_parameters(
        cls,
        shared_params: Mapping,
        layers: Dict[str, Union[Callable, Tuple[Callable, Dict[str, Any]]]],
        irreps_in: Optional[dict] = None,
    ):
        r"""Construct a ``SequentialGraphNetwork`` of modules built from a shared set of parameters.

        For some layer, a parameter with name ``param`` will be taken, in order of priority, from:
          1. The specific value in the parameter dictionary for that layer, if provided
          2. ``name_param`` in ``shared_params`` where ``name`` is the name of the layer
          3. ``param`` in ``shared_params``

        Each layer is built with ``irreps_in`` set to the ``irreps_out`` of the
        previously built layer (or to the ``irreps_in`` argument for the first layer).

        Args:
            shared_params (dict-like): shared parameters from which to pull when instantiating the module
            layers (dict): dictionary mapping unique names of layers to either:
                  1. A callable (such as a class or function) that can be used to ``instantiate`` a module for that layer
                  2. A tuple of such a callable and a dictionary mapping parameter names to values. The given dictionary of parameters will override for this layer values found in ``shared_params``.
                Options 1. and 2. can be mixed.
            irreps_in (optional dict): ``irreps_in`` for the first module in the sequence.

        Returns:
            The constructed SequentialGraphNetwork.

        Raises:
            ValueError: if a layer name is not a ``str``.
            TypeError: if a builder is not callable, or does not return a ``GraphModuleMixin``.
        """
        # note that dictionary order is guaranteed in >=3.7, so it's fine to do an ordered sequential as a dict.
        built_modules = []
        for name, builder in layers.items():
            if not isinstance(name, str):
                raise ValueError(f"`'name'` must be a str; got `{name}`")
            if isinstance(builder, tuple):
                builder, params = builder
            else:
                params = {}
            if not callable(builder):
                raise TypeError(
                    f"The builder has to be a class or a function. got {type(builder)}"
                )

            # chain the irreps: each layer consumes what the previous one produced
            layer_irreps_in = built_modules[-1].irreps_out if built_modules else irreps_in

            instance, _ = instantiate(
                builder=builder,
                prefix=name,
                positional_args=dict(irreps_in=layer_irreps_in),
                optional_args=params,
                all_args=shared_params,
            )

            if not isinstance(instance, GraphModuleMixin):
                raise TypeError(
                    f"Builder `{builder}` for layer with name `{name}` did not return a GraphModuleMixin, instead got a {type(instance).__name__}"
                )

            built_modules.append(instance)

        return cls(
            OrderedDict(zip(layers.keys(), built_modules)),
        )

    @torch.jit.unused
    def append(self, name: str, module: GraphModuleMixin) -> None:
        r"""Append a module to the SequentialGraphNetwork.

        The network's ``irreps_out`` is replaced by a copy of the appended
        module's ``irreps_out``.

        Args:
            name (str): the name for the module
            module (GraphModuleMixin): the module to append
        """
        assert AtomicDataDict._irreps_compatible(
            self.irreps_out, module.irreps_in
        ), f"Incompatible irreps for appended module `{name}`: {self.irreps_out} -> {module.irreps_in}"
        self.add_module(name, module)
        self.irreps_out = dict(module.irreps_out)
        return

    @torch.jit.unused
    def append_from_parameters(
        self,
        shared_params: Mapping,
        name: str,
        builder: Callable,
        params: Optional[Dict[str, Any]] = None,
    ) -> GraphModuleMixin:
        r"""Build a module from parameters and append it.

        Args:
            shared_params (dict-like): shared parameters from which to pull when instantiating the module
            name (str): the name for the module
            builder (callable): a class or function to build a module
            params (dict, optional): extra specific parameters for this module that take priority over those in ``shared_params``

        Returns:
            the built module
        """
        # avoid a shared mutable default; an omitted ``params`` still means "no overrides"
        if params is None:
            params = {}
        instance, _ = instantiate(
            builder=builder,
            prefix=name,
            positional_args=dict(irreps_in=self[-1].irreps_out),
            optional_args=params,
            all_args=shared_params,
        )
        self.append(name, instance)
        return instance

    @torch.jit.unused
    def insert(
        self,
        name: str,
        module: GraphModuleMixin,
        after: Optional[str] = None,
        before: Optional[str] = None,
    ) -> None:
        """Insert a module after the module named ``after`` or before the module named ``before``.

        Exactly one of ``after`` and ``before`` must be given. Every module that
        follows the inserted one gets the inserted module's ``irreps_out`` added
        as independent irreps, and the network's ``irreps_out`` is refreshed
        from the (possibly new) last module.

        Args:
            name: the name of the module to insert
            module: the module to insert
            after: the module to insert after
            before: the module to insert before

        Raises:
            ValueError: if both or neither of ``after``/``before`` are given, or
                if the referenced module name does not exist.
        """

        if (before is None) is (after is None):
            raise ValueError("Only one of before or after argument needs to be defined")
        insert_location = after if before is None else before

        # This checks names, etc.
        self.add_module(name, module)
        # Now insert in the right place by overwriting.
        # `add_module` put the new entry last; inserting it a second time at the
        # target index and rebuilding the dict keeps only the first (correctly
        # placed) occurrence of the key.
        names = list(self._modules.keys())
        modules = list(self._modules.values())
        idx = names.index(insert_location)
        if before is None:
            idx += 1
        names.insert(idx, name)
        modules.insert(idx, module)

        self._modules = OrderedDict(zip(names, modules))

        module_list = list(self._modules.values())

        # sanity check the compatibility with the predecessor, if any...
        if idx > 0:
            assert AtomicDataDict._irreps_compatible(
                module_list[idx - 1].irreps_out, module.irreps_in
            )
        # ...and with the successor, if any. The successor lives at `idx + 1`,
        # so it only exists when `idx + 1` is a valid index (inserting after the
        # last module has no successor).
        if idx + 1 < len(module_list):
            assert AtomicDataDict._irreps_compatible(
                module_list[idx + 1].irreps_in, module.irreps_out
            )

        # insert the new irreps_out to the later modules
        for next_module in module_list[idx + 1 :]:
            next_module._add_independent_irreps(module.irreps_out)

        # update the final wrapper irreps_out
        self.irreps_out = dict(module_list[-1].irreps_out)

        return

    @torch.jit.unused
    def insert_from_parameters(
        self,
        shared_params: Mapping,
        name: str,
        builder: Callable,
        params: Optional[Dict[str, Any]] = None,
        after: Optional[str] = None,
        before: Optional[str] = None,
    ) -> GraphModuleMixin:
        r"""Build a module from parameters and insert it after ``after`` or before ``before``.

        The new module is built with ``irreps_in`` set to the ``irreps_out`` of
        the module that will precede it, or to this network's ``irreps_in`` when
        it is inserted before the first module.

        Args:
            shared_params (dict-like): shared parameters from which to pull when instantiating the module
            name (str): the name for the module
            builder (callable): a class or function to build a module
            params (dict, optional): extra specific parameters for this module that take priority over those in ``shared_params``
            after: the name of the module to insert after
            before: the name of the module to insert before

        Returns:
            the inserted module

        Raises:
            ValueError: if both or neither of ``after``/``before`` are given, or
                if the referenced module name does not exist.
        """
        if (before is None) is (after is None):
            raise ValueError("Only one of before or after argument needs to be defined")
        # avoid a shared mutable default; an omitted ``params`` still means "no overrides"
        if params is None:
            params = {}

        # index of the module that will directly precede the new one
        names = list(self._modules.keys())
        if before is None:
            prev_idx = names.index(after)
        else:
            prev_idx = names.index(before) - 1

        if prev_idx < 0:
            # Inserting before the first module: there is no predecessor, so the
            # new module sees the network's own inputs. (Indexing with -1 here
            # would wrongly pick up the *last* module's outputs.)
            new_irreps_in = dict(self.irreps_in)
        else:
            new_irreps_in = self[prev_idx].irreps_out

        instance, _ = instantiate(
            builder=builder,
            prefix=name,
            positional_args=dict(irreps_in=new_irreps_in),
            optional_args=params,
            all_args=shared_params,
        )
        self.insert(after=after, before=before, name=name, module=instance)
        return instance

    # Copied from https://pytorch.org/docs/stable/_modules/torch/nn/modules/container.html#Sequential
    # with type annotations added
    def forward(self, input: AtomicDataDict.Type) -> AtomicDataDict.Type:
        for module in self:
            input = module(input)
        return input

In [None]:
from typing import List, Dict, Any, Optional

import torch

from e3nn.util._argtools import _get_device

from nequip.data import AtomicDataDict


class GraphModel(GraphModuleMixin, torch.nn.Module):
    """Top-level module for any complete `nequip` model.

    Manages top-level rescaling, dtypes, and more.

    Args:
        model: the wrapped graph module that does the actual computation.
        model_dtype: floating point dtype that inputs are cast to before being
            passed to ``model``; if ``None``, the current torch default dtype is used.
        model_input_fields: irreps of extra input fields accepted in addition to
            the always-present ones (positions, edge index, edge cell shift,
            cell, batch, batch pointer, atom types). Must not repeat any of those.
    """

    model_dtype: torch.dtype
    model_input_fields: List[str]

    _num_rescale_layers: int

    def __init__(
        self,
        model: GraphModuleMixin,
        model_dtype: Optional[torch.dtype] = None,
        model_input_fields: Dict[str, Any] = {},
    ) -> None:
        super().__init__()
        # Things that always make sense as inputs:
        irreps_in = {
            AtomicDataDict.POSITIONS_KEY: "1o",
            AtomicDataDict.EDGE_INDEX_KEY: None,
            AtomicDataDict.EDGE_CELL_SHIFT_KEY: None,
            AtomicDataDict.CELL_KEY: "1o",  # 3 of them, but still
            AtomicDataDict.BATCH_KEY: None,
            AtomicDataDict.BATCH_PTR_KEY: None,
            AtomicDataDict.ATOM_TYPE_KEY: None,
        }
        model_input_fields = AtomicDataDict._fix_irreps_dict(model_input_fields)
        # user-provided fields may not redefine any of the built-in ones
        assert set(irreps_in.keys()).isdisjoint(model_input_fields.keys())
        irreps_in.update(model_input_fields)
        self._init_irreps(irreps_in=irreps_in, irreps_out=model.irreps_out)

        # every input the wrapped model declares must be provided, with identical irreps
        for k, irreps in model.irreps_in.items():
            if self.irreps_in.get(k, None) == irreps:
                continue
            raise RuntimeError(
                f"Model has `{k}` in its irreps_in with irreps `{irreps}`, but `{k}` is missing from/has inconsistent irreps in model_input_fields of `{self.irreps_in.get(k, 'missing')}`"
            )

        self.model = model
        if model_dtype is None:
            model_dtype = torch.get_default_dtype()
        self.model_dtype = model_dtype
        self.model_input_fields = list(self.irreps_in.keys())

        # count how many directly nested RescaleOutput layers wrap the model
        depth = 0
        current = self.model
        while isinstance(current, RescaleOutput):
            depth += 1
            current = current.model
        self._num_rescale_layers = depth

    # == Rescaling ==
    @torch.jit.unused
    def all_RescaleOutputs(self) -> List[RescaleOutput]:
        """All ``RescaleOutput``s wrapping the model, in evaluation order."""
        total = self._num_rescale_layers
        if total == 0:
            return []
        # walk from the outermost wrapper inwards
        chain = [self.model]
        while len(chain) < total:
            chain.append(chain[-1].model)
        assert len(chain) == total
        # outermost-to-innermost is the opposite of evaluation order, so flip it
        chain.reverse()
        return chain

    @torch.jit.unused
    def unscale(
        self, data: AtomicDataDict.Type, force_process: bool = False
    ) -> AtomicDataDict.Type:
        """Apply ``unscale`` of every wrapping ``RescaleOutput`` to a copy of ``data``, outermost first."""
        result = data.copy()
        # we need to unscale from the outside-in:
        for rescale in reversed(self.all_RescaleOutputs()):
            result = rescale.unscale(result, force_process=force_process)
        return result

    @torch.jit.unused
    def scale(
        self, data: AtomicDataDict.Type, force_process: bool = False
    ) -> AtomicDataDict.Type:
        """Apply ``scale`` of every wrapping ``RescaleOutput`` to a copy of ``data``, innermost first."""
        result = data.copy()
        # we need to scale from the inside out:
        for rescale in self.all_RescaleOutputs():
            result = rescale.scale(result, force_process=force_process)
        return result

    # == Inference ==

    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        # Only pass allowed keys on to the model, casting floating point tensors
        # to model_dtype. Building a fresh dict also keeps the model from
        # directly using the caller's dict, preventing weird pass-by-reference bugs.
        model_inputs: AtomicDataDict.Type = {}
        for key, value in data.items():
            if key in self.model_input_fields:
                if value.is_floating_point():
                    value = value.to(dtype=self.model_dtype)
                model_inputs[key] = value
        # run the model
        return self.model(model_inputs)

    # == Helpers ==

    @torch.jit.unused
    def get_device(self) -> torch.device:
        return _get_device(self)

In [None]:
def EnergyModel(
    config, initialize: bool, dataset: Optional[AtomicDataset] = None
) -> SequentialGraphNetwork:
    """Base default energy model architecture.

    Assembles, in order: the encoding layers, a chemical embedding,
    ``num_layers`` convnet layers (``config["num_layers"]``, default 3), two
    atomwise linear output layers producing a per-atom energy, and a sum
    reduction of the per-atom energies into the total energy.

    For minimal and full configuration option listings, see ``minimal.yaml`` and ``example.yaml``.
    """
    logging.debug("Start building the network model")

    builder_utils.add_avg_num_neighbors(
        config=config, initialize=initialize, dataset=dataset
    )

    n_convnet = config.get("num_layers", 3)

    # -- Encode, then embed features --
    input_stage = {
        "one_hot": OneHotAtomEncoding,
        "spharm_edges": SphericalHarmonicEdgeAttrs,
        "radial_basis": RadialBasisEdgeEncoding,
        "chemical_embedding": AtomwiseLinear,
    }

    # -- convnet layers --
    convnet_stage = {
        f"layer{layer_i}_convnet": ConvNetLayer for layer_i in range(n_convnet)
    }

    # -- output block --
    # TODO: the next linear throws out all L > 0, don't create them in the last layer of convnet
    output_stage = {
        "conv_to_output_hidden": AtomwiseLinear,
        "output_hidden_to_scalar": (
            AtomwiseLinear,
            {
                "irreps_out": "1x0e",
                "out_field": AtomicDataDict.PER_ATOM_ENERGY_KEY,
            },
        ),
        "total_energy_sum": (
            AtomwiseReduce,
            {
                "reduce": "sum",
                "field": AtomicDataDict.PER_ATOM_ENERGY_KEY,
                "out_field": AtomicDataDict.TOTAL_ENERGY_KEY,
            },
        ),
    }

    # dict unpacking keeps insertion order, so the stages stay in sequence
    layers = {**input_stage, **convnet_stage, **output_stage}

    return SequentialGraphNetwork.from_parameters(
        shared_params=config,
        layers=layers,
    )

In [11]:
# Build
# Commented out: would map every configured builder name through
# `load_callable(..., prefix="model")` to obtain the builder callables.
#builders = [
#    load_callable(b, prefix="model")
#    for b in config.get("model_builders", [])
#]

# Show the raw "model_builders" entry of the config (empty list if the key is absent).
config.get("model_builders", [])

['SimpleIrrepsConfig',
 'EnergyModel',
 'PerSpeciesRescale',
 'StressForceOutput',
 'RescaleEnergyEtc']

In [13]:
# Scratch check that building the layer dict step by step keeps the keys in
# insertion order; every value is just the placeholder 1.

# -- Encode --, then -- Embed features --
layers = dict.fromkeys(
    ("one_hot", "spharm_edges", "radial_basis", "chemical_embedding"),
    1,
)

# add convnet layers
# insertion preserves order
for layer_i in range(3):
    layers["layer" + str(layer_i) + "_convnet"] = 1

# -- output block --
# TODO: the next linear throws out all L > 0, don't create them in the last layer of convnet
# keyword-style .update also maintains insertion order
layers.update(conv_to_output_hidden=1, output_hidden_to_scalar=1)

layers["total_energy_sum"] = 1

layers

{'one_hot': 1,
 'spharm_edges': 1,
 'radial_basis': 1,
 'chemical_embedding': 1,
 'layer0_convnet': 1,
 'layer1_convnet': 1,
 'layer2_convnet': 1,
 'conv_to_output_hidden': 1,
 'output_hidden_to_scalar': 1,
 'total_energy_sum': 1}

In [None]:
# Create the trainer without a model yet; every config entry is passed as a keyword argument.
trainer = Trainer(model=None, **Config.as_dict(config))

# what is this
# to update wandb data?
# NOTE(review): this copies `trainer.params` back into the config; the reason
# is not visible in this cell — confirm against Trainer.
config.update(trainer.params)

# = Load the dataset =
dataset = dataset_from_config(config, prefix="dataset")

# = Train/test split =
# NOTE(review): `validation_dataset` is not assigned anywhere in this cell;
# presumably it comes from an earlier cell — confirm, otherwise this raises NameError.
trainer.set_dataset(dataset, validation_dataset)

# = Build model =
# The model is built with the trainer's training split as its dataset.
final_model = model_from_config(
    config=config, initialize=True, dataset=trainer.dataset_train
)