In [11]:
# Reference example: build an AtomicData object from a structure file and
# inspect the labels stored in it.  The deployed-model part is kept commented
# out, so nothing below is a model prediction.

import torch
import ase.io
from nequip.data import AtomicData, AtomicDataDict
from nequip.scripts.deploy import load_deployed_model, R_MAX_KEY

# Disabled: loading the deployed model (only needed for the inference call below).
#device = "cpu"  # "cuda" etc.
#model, metadata = load_deployed_model(
#    "models/NaBr/NaBr-deployed.pth",
#    device=device,
#)

# Load some input structure from an XYZ or other ASE readable file.
# NOTE(review): `ase.io.read` is called without an index, so a single frame is
# read from the file (the last one by default, per the ASE docs) — confirm that
# is the frame you want.
# `r_max=5.` is forwarded to `AtomicData.from_ase`; presumably the neighbor
# cutoff radius — TODO confirm units against the dataset.
data = AtomicData.from_ase(ase.io.read("./datasets/NaBr-data.xyz"), r_max=5.)
#data = data.to(device)

# Disabled: run the model on the dict form of the data.
#out = model(AtomicData.to_AtomicDataDict(data))

# These values are read straight from `data`, i.e. they are whatever
# `from_ase` stored under the total-energy and forces keys.
print(type(data))
print(f"Total energy: {data[AtomicDataDict.TOTAL_ENERGY_KEY]}")
print(f"Force on atom 0: {data[AtomicDataDict.FORCE_KEY][0]}")


<class 'nequip.data.AtomicData.AtomicData'>
Total energy: tensor([-189.3690])
Force on atom 0: tensor([0.0591, 0.8510, 0.0488])


In [12]:
# Bare expression: displays the AtomicData repr (each stored field with its shape).
data

AtomicData(atomic_numbers=[64, 1], cell=[3, 3], edge_cell_shift=[1226, 3], edge_index=[2, 1226], forces=[64, 3], free_energy=[1], pbc=[3], pos=[64, 3], stress=[3, 3], total_energy=[1])

## AtomicDataDict

In [16]:
"""Keys for dictionaries/AtomicData objects.

This is a separate module to compensate for a TorchScript bug that can only recognize constants 
when they are accessed as attributes of an imported module.
"""

import sys
from typing import List

# `Final` was added to the standard `typing` module in Python 3.8.  Compare the
# full (major, minor) pair: looking at the minor number alone would pick the
# wrong branch on any release whose minor version is below 8.
if sys.version_info >= (3, 8):
    from typing import Final
else:
    from typing_extensions import Final

# == Define allowed keys as constants ==
# The positions of the atoms in the system
POSITIONS_KEY: Final[str] = "pos"
# The [2, n_edge] index tensor giving center -> neighbor relations
EDGE_INDEX_KEY: Final[str] = "edge_index"
# A [n_edge, 3] tensor of how many periodic cells each edge crosses in each cell vector
EDGE_CELL_SHIFT_KEY: Final[str] = "edge_cell_shift"
# [n_batch, 3, 3] or [3, 3] tensor where rows are the cell vectors
CELL_KEY: Final[str] = "cell"
# [n_batch, 3] bool tensor
PBC_KEY: Final[str] = "pbc"
# [n_atom, 1] long tensor
ATOMIC_NUMBERS_KEY: Final[str] = "atomic_numbers"
# [n_atom, 1] long tensor
ATOM_TYPE_KEY: Final[str] = "atom_types"

# Keys that describe the bare structure/graph (no model inputs or outputs).
BASIC_STRUCTURE_KEYS: Final[List[str]] = [
    POSITIONS_KEY,
    EDGE_INDEX_KEY,
    EDGE_CELL_SHIFT_KEY,
    CELL_KEY,
    PBC_KEY,
    ATOM_TYPE_KEY,
    ATOMIC_NUMBERS_KEY,
]

# A [n_edge, 3] tensor of displacement vectors associated to edges
EDGE_VECTORS_KEY: Final[str] = "edge_vectors"
# A [n_edge] tensor of the lengths of EDGE_VECTORS
EDGE_LENGTH_KEY: Final[str] = "edge_lengths"
# [n_edge, dim] (possibly equivariant) attributes of each edge
EDGE_ATTRS_KEY: Final[str] = "edge_attrs"
# [n_edge, dim] invariant embedding of the edges
EDGE_EMBEDDING_KEY: Final[str] = "edge_embedding"
EDGE_FEATURES_KEY: Final[str] = "edge_features"
# [n_edge, 1] invariant of the radial cutoff envelope for each edge, allows reuse of cutoff envelopes
EDGE_CUTOFF_KEY: Final[str] = "edge_cutoff"
# edge energy as in Allegro
EDGE_ENERGY_KEY: Final[str] = "edge_energy"

NODE_FEATURES_KEY: Final[str] = "node_features"
NODE_ATTRS_KEY: Final[str] = "node_attrs"

PER_ATOM_ENERGY_KEY: Final[str] = "atomic_energy"
TOTAL_ENERGY_KEY: Final[str] = "total_energy"
FORCE_KEY: Final[str] = "forces"
PARTIAL_FORCE_KEY: Final[str] = "partial_forces"
STRESS_KEY: Final[str] = "stress"
VIRIAL_KEY: Final[str] = "virial"

# Energy keys together with the force/stress/virial keys listed alongside them.
ALL_ENERGY_KEYS: Final[List[str]] = [
    EDGE_ENERGY_KEY,
    PER_ATOM_ENERGY_KEY,
    TOTAL_ENERGY_KEY,
    FORCE_KEY,
    PARTIAL_FORCE_KEY,
    STRESS_KEY,
    VIRIAL_KEY,
]

BATCH_KEY: Final[str] = "batch"
BATCH_PTR_KEY: Final[str] = "ptr"

# Make a list of allowed keys: the value of every global whose name ends in
# "_KEY".  The namespace is snapshotted with `list(...)` so the scan cannot be
# disturbed by names being bound while it runs, and `globals()` is used so the
# scan also works when this code is executed without being registered in
# `sys.modules`.
# NOTE(review): inside a notebook all cells share one namespace, so any other
# global ending in "_KEY" (e.g. an imported `R_MAX_KEY`) is picked up as well.
ALLOWED_KEYS: List[str] = [
    value for name, value in list(globals().items()) if name.endswith("_KEY")
]

In [18]:
"""nequip.data.jit: TorchScript functions for dealing with AtomicData.

These TorchScript functions operate on ``Dict[str, torch.Tensor]`` representations
of the ``AtomicData`` class which are produced by ``AtomicData.to_AtomicDataDict()``.

Authors: Albert Musaelian
"""

from typing import Dict, Any

import torch
import torch.jit

from e3nn import o3

# Make the keys available in this module
#from ._keys import *  # noqa: F403, F401

# Also import the module to use in TorchScript, this is a hack to avoid bug:
# https://github.com/pytorch/pytorch/issues/52312
#from . import _keys

# Type alias for the plain-dict form of an AtomicData (field name -> tensor)
# that the functions in this cell take and return.
Type = Dict[str, torch.Tensor]


def validate_keys(keys, graph_required=True):
    """Check that a collection of field names forms a usable combination.

    Args:
        keys: any container of field names that supports ``in``.
        graph_required: when true, both ``POSITIONS_KEY`` and
            ``EDGE_INDEX_KEY`` have to be present.

    Raises:
        KeyError: ``graph_required`` is true and positions or the edge index
            are missing.
        ValueError: ``EDGE_CELL_SHIFT_KEY`` is present but ``CELL_KEY`` is not.
    """
    if graph_required and not (POSITIONS_KEY in keys and EDGE_INDEX_KEY in keys):
        raise KeyError("At least pos and edge_index must be supplied")
    # Use the key constant rather than a hard-coded "cell" literal so this check
    # stays in sync with the key definitions.
    if EDGE_CELL_SHIFT_KEY in keys and CELL_KEY not in keys:
        raise ValueError("If `edge_cell_shift` given, `cell` must be given.")


_SPECIAL_IRREPS = [None]


def _fix_irreps_dict(d: Dict[str, Any]):
    """Just adds a zero irrep to the dict of irreps"""
    return {k: (i if i in _SPECIAL_IRREPS else o3.Irreps(i)) for k, i in d.items()}


def _irreps_compatible(ir1: Dict[str, o3.Irreps], ir2: Dict[str, o3.Irreps]):
    """Tell whether two irreps dicts agree on every key they have in common.

    Keys that appear in only one of the dicts are ignored, so two dicts with
    no shared keys are considered compatible.
    """
    for name in ir1:
        if name in ir2 and not ir1[name] == ir2[name]:
            return False
    return True


#@torch.jit.script
def with_edge_vectors(data: Type, with_lengths: bool = True) -> Type:
    """Compute the edge displacement vectors for a graph.

    If ``data.pos.requires_grad`` and/or ``data.cell.requires_grad``, this
    method will return edge vectors correctly connected in the autograd graph.

    Args:
        data: dict of tensors.  Unless ``EDGE_VECTORS_KEY`` is already present
            it must hold ``POSITIONS_KEY`` and ``EDGE_INDEX_KEY``; when
            ``CELL_KEY`` is present, ``EDGE_CELL_SHIFT_KEY`` is read as well,
            and ``BATCH_KEY`` is read if the cell has more than one entry
            along its batch dimension.
        with_lengths: also store the norm of every edge vector under
            ``EDGE_LENGTH_KEY``.

    Returns:
        The same ``data`` dict, updated in place with ``EDGE_VECTORS_KEY``
        (``[n_edges, 3]``) and, if requested, ``EDGE_LENGTH_KEY``.
    """
    # If already computed
    if EDGE_VECTORS_KEY in data:
        # Only the lengths may still be missing.
        if with_lengths and EDGE_LENGTH_KEY not in data:
            data[EDGE_LENGTH_KEY] = torch.linalg.norm(
                data[EDGE_VECTORS_KEY], dim=-1
            )
        return data
    else:
        # Build it dynamically
        # Note that this is
        # (1) backwardable, because everything (pos, cell, shifts)
        #     is Tensors.
        # (2) works on a Batch constructed from AtomicData
        pos = data[POSITIONS_KEY]
        edge_index = data[EDGE_INDEX_KEY]
        # edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
        edge_vec = torch.index_select(pos, 0, edge_index[1]) - torch.index_select(
            pos, 0, edge_index[0]
        )
        if CELL_KEY in data:  # just cell is given in data
            # ^ note that to save time we don't check that the edge_cell_shifts
            # are trivial if no cell is provided; we just assume they are either not present or all zero.
            # -1 gives a batch dim no matter what
            cell = data[CELL_KEY].view(-1, 3, 3)
            # The key constants live directly in this namespace (the `_keys`
            # module import is disabled), so reference the name unqualified
            # like every other key in this function.
            edge_cell_shift = data[EDGE_CELL_SHIFT_KEY]
            if cell.shape[0] > 1:
                batch = data[BATCH_KEY]
                # Cell has a batch dimension
                # note the ASE cell vectors as rows convention
                edge_vec = edge_vec + torch.einsum(
                    "ni,nij->nj", edge_cell_shift, cell[batch[edge_index[0]]]
                )
                # TODO: is there a more efficient way to do the above without
                # creating an [n_edge] and [n_edge, 3, 3] tensor?
            else:
                # Cell has either no batch dimension, or a useless one,
                # so we can avoid creating the large intermediate cell tensor.
                # Note that we do NOT check that the batch array, if it is present,
                # is trivial — but this does need to be consistent.
                edge_vec = edge_vec + torch.einsum(
                    "ni,ij->nj",
                    edge_cell_shift,
                    cell.squeeze(0),  # remove batch dimension
                )
        data[EDGE_VECTORS_KEY] = edge_vec
        if with_lengths:
            data[EDGE_LENGTH_KEY] = torch.linalg.norm(edge_vec, dim=-1)
        return data


#@torch.jit.script
def with_batch(data: Type) -> Type:
    """Make sure ``data`` carries batch bookkeeping tensors.

    A dict that already has ``BATCH_KEY`` is returned untouched.  Otherwise an
    all-zero ``BATCH_KEY`` tensor (one entry per position) and a matching
    ``BATCH_PTR_KEY`` are written into ``data`` in place, and ``data`` is
    returned.
    """
    if BATCH_KEY not in data:
        positions = data[POSITIONS_KEY]
        n_atoms = len(positions)
        data[BATCH_KEY] = torch.zeros(
            n_atoms, dtype=torch.long, device=positions.device
        )
        # arange(0, n + 1, n) yields [0, n]; building it this way creates the
        # tensor directly with the right dtype on the right device.
        data[BATCH_PTR_KEY] = torch.arange(
            start=0,
            end=n_atoms + 1,
            step=n_atoms,
            dtype=torch.long,
            device=positions.device,
        )
    return data

### Some utils

In [24]:
import torch

# There is no built-in way to check if a Tensor is of an integer type, so the
# dtypes treated as "integer" are enumerated explicitly.
# NOTE(review): only torch.int and torch.long are listed; other integer dtypes
# (e.g. torch.int16, torch.uint8) will not match — confirm this is intended.
_TORCH_INTEGER_DTYPES = (torch.int, torch.long)

In [25]:
import re
import copy
import collections

import torch

# from ..utils.num_nodes import maybe_num_nodes

# Warning template with a single `{}` placeholder for the kind of index the
# node count was inferred from.  In this cell it is only referenced from the
# commented-out fallback code inside `Data.num_nodes`.
__num_nodes_warn_msg__ = (
    "The number of nodes in your data object can only be inferred by its {} "
    "indices, and hence may result in unexpected batch-wise behavior, e.g., "
    "in case there exists isolated nodes. Please consider explicitly setting "
    "the number of nodes for this data object by assigning it to "
    "data.num_nodes."
)


def size_repr(key, item, indent=0):
    """Format ``key=<summary of item>`` for use in a ``Data`` repr.

    The summary depends on the type of ``item``:

    * 0-dim tensor: its Python scalar value
    * other tensor: its shape as a list
    * list / tuple: ``[len(item)]``
    * dict: a brace-delimited block with one recursively formatted entry per line
    * str: the string wrapped in double quotes
    * anything else: ``str(item)``

    Args:
        key: name shown on the left of the ``=``.
        item: value to summarize.
        indent: number of spaces put in front of the result.
    """
    pad = " " * indent

    if torch.is_tensor(item):
        if item.dim() == 0:
            shown = item.item()
        else:
            shown = str(list(item.size()))
    elif isinstance(item, (list, tuple)):
        shown = str([len(item)])
    elif isinstance(item, dict):
        entries = []
        for sub_key, sub_item in item.items():
            # nested entries carry a fixed inner indent of 2 on top of `pad`
            entries.append(pad + size_repr(sub_key, sub_item, 2))
        shown = "{\n" + ",\n".join(entries) + "\n" + pad + "}"
    elif isinstance(item, str):
        shown = f'"{item}"'
    else:
        shown = str(item)

    return f"{pad}{key}={shown}"


class Data(object):
    r"""A plain old python object modeling a single graph with various
    (optional) attributes:

    Args:
        x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes,
            num_node_features]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in COO format
            with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y (Tensor, optional): Graph or node targets with arbitrary shape.
            (default: :obj:`None`)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
        normal (Tensor, optional): Normal vector matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)
        face (LongTensor, optional): Face adjacency matrix with shape
            :obj:`[3, num_faces]`. (default: :obj:`None`)

    The data object is not restricted to these attributes and can be extended
    by any other additional data.

    Attributes are stored as ordinary instance attributes; item access
    (``data[key]``) is a thin wrapper around ``getattr``/``setattr``.

    Example::

        data = Data(x=x, edge_index=edge_index)
        data.train_idx = torch.tensor([...], dtype=torch.long)
        data.test_mask = torch.tensor([...], dtype=torch.bool)
    """

    def __init__(
        self,
        x=None,
        edge_index=None,
        edge_attr=None,
        y=None,
        pos=None,
        normal=None,
        face=None,
        **kwargs,
    ):
        """Store the given attributes and check the dtype of the index tensors.

        Raises:
            ValueError: if ``edge_index`` or ``face`` is given and is not of
                dtype ``torch.long``.
        """
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos
        self.normal = normal
        self.face = face
        for key, item in kwargs.items():
            if key == "num_nodes":
                # an explicit node count is kept in a dunder attribute so that
                # it is hidden from `keys` and takes priority in `num_nodes`
                self.__num_nodes__ = item
            else:
                self[key] = item

        if edge_index is not None and edge_index.dtype != torch.long:
            raise ValueError(
                (
                    f"Argument `edge_index` needs to be of type `torch.long` but "
                    f"found type `{edge_index.dtype}`."
                )
            )

        if face is not None and face.dtype != torch.long:
            raise ValueError(
                (
                    f"Argument `face` needs to be of type `torch.long` but found "
                    f"type `{face.dtype}`."
                )
            )

    @classmethod
    def from_dict(cls, dictionary):
        r"""Creates a data object from a python dictionary."""
        # NOTE: the object is created with `cls()` (no arguments) and the
        # entries are then assigned one by one.
        data = cls()

        for key, item in dictionary.items():
            data[key] = item

        return data

    def to_dict(self):
        """Return the present attributes as a ``{name: value}`` dict."""
        return {key: item for key, item in self}

    def to_namedtuple(self):
        """Return the present attributes as a ``DataTuple`` namedtuple."""
        keys = self.keys
        DataTuple = collections.namedtuple("DataTuple", keys)
        return DataTuple(*[self[key] for key in keys])

    def __getitem__(self, key):
        r"""Gets the data of the attribute :obj:`key`.

        Returns :obj:`None` (instead of raising) when the attribute is missing.
        """
        return getattr(self, key, None)

    def __setitem__(self, key, value):
        """Sets the attribute :obj:`key` to :obj:`value`."""
        setattr(self, key, value)

    def __delitem__(self, key):
        r"""Delete the data of the attribute :obj:`key`."""
        return delattr(self, key)

    @property
    def keys(self):
        r"""Returns all names of graph attributes."""
        # only attributes that are actually set (not None) count as present
        keys = [key for key in self.__dict__.keys() if self[key] is not None]
        # drop names that start or end with a double underscore
        # (e.g. the internal `__num_nodes__`)
        keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
        return keys

    def __len__(self):
        r"""Returns the number of all present attributes."""
        return len(self.keys)

    def __contains__(self, key):
        r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
        data."""
        return key in self.keys

    def __iter__(self):
        r"""Iterates over all present attributes in the data, yielding their
        attribute names and content (in sorted name order)."""
        for key in sorted(self.keys):
            yield key, self[key]

    def __call__(self, *keys):
        r"""Iterates over all attributes :obj:`*keys` in the data, yielding
        their attribute names and content.
        If :obj:`*keys` is not given this method will iterate over all
        present attributes.
        Requested keys that are not present are skipped."""
        for key in sorted(self.keys) if not keys else keys:
            if key in self:
                yield key, self[key]

    def __cat_dim__(self, key, value):
        r"""Returns the dimension for which :obj:`value` of attribute
        :obj:`key` will get concatenated when creating batches.

        .. note::

            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        # attributes whose name contains "index" or "face" use the last
        # dimension; everything else uses the first
        if bool(re.search("(index|face)", key)):
            return -1
        return 0

    def __inc__(self, key, value):
        r"""Returns the incremental count to cumulatively increase the value
        of the next attribute of :obj:`key` when creating batches.

        .. note::

            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        # Only `*index*` and `*face*` attributes should be cumulatively summed
        # up when creating batches.
        return self.num_nodes if bool(re.search("(index|face)", key)) else 0

    @property
    def num_nodes(self):
        r"""Returns or sets the number of nodes in the graph.

        .. note::
            The number of nodes in your data object is typically automatically
            inferred, *e.g.*, when node features :obj:`x` are present.
            In some cases however, a graph may only be given by its edge
            indices :obj:`edge_index`.
            PyTorch Geometric then *guesses* the number of nodes
            according to :obj:`edge_index.max().item() + 1`, but in case there
            exists isolated nodes, this number has not to be correct and can
            therefore result in unexpected batch-wise behavior.
            Thus, we recommend to set the number of nodes in your data object
            explicitly via :obj:`data.num_nodes = ...`.
            You will be given a warning that requests you to do so.

        NOTE(review): in this copy the edge-index/face based guessing described
        above is commented out, so :obj:`None` is returned when the count
        cannot be taken from an explicit value, a node-level attribute or an
        adjacency attribute.
        """
        # an explicitly assigned count always wins
        if hasattr(self, "__num_nodes__"):
            return self.__num_nodes__
        # otherwise use the first of these attributes that is present
        for key, item in self("x", "pos", "normal", "batch"):
            return item.size(self.__cat_dim__(key, item))
        if hasattr(self, "adj"):
            return self.adj.size(0)
        if hasattr(self, "adj_t"):
            return self.adj_t.size(1)
        # if self.face is not None:
        #     logging.warning(__num_nodes_warn_msg__.format("face"))
        #     return maybe_num_nodes(self.face)
        # if self.edge_index is not None:
        #     logging.warning(__num_nodes_warn_msg__.format("edge"))
        #     return maybe_num_nodes(self.edge_index)
        return None

    @num_nodes.setter  # setter half of the `num_nodes` property
    def num_nodes(self, num_nodes):
        self.__num_nodes__ = num_nodes

    @property
    def num_edges(self):
        """
        Returns the number of edges in the graph.
        For undirected graphs, this will return the number of bi-directional
        edges, which is double the amount of unique edges.
        """
        # first present attribute decides; None when none of them is present
        for key, item in self("edge_index", "edge_attr"):
            return item.size(self.__cat_dim__(key, item))
        for key, item in self("adj", "adj_t"):
            return item.nnz()
        return None

    @property
    def num_faces(self):
        r"""Returns the number of faces in the mesh."""
        if self.face is not None:
            return self.face.size(self.__cat_dim__("face", self.face))
        return None

    @property
    def num_node_features(self):
        r"""Returns the number of features per node in the graph."""
        if self.x is None:
            return 0
        return 1 if self.x.dim() == 1 else self.x.size(1)

    @property
    def num_features(self):
        r"""Alias for :py:attr:`~num_node_features`."""
        return self.num_node_features

    @property
    def num_edge_features(self):
        r"""Returns the number of features per edge in the graph."""
        if self.edge_attr is None:
            return 0
        return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1)

    def __apply__(self, item, func):
        """Apply ``func`` to ``item`` if it is a tensor, recursing into
        tuples/lists (returned as lists) and dicts; anything else is returned
        unchanged."""
        if torch.is_tensor(item):
            return func(item)
        elif isinstance(item, (tuple, list)):
            return [self.__apply__(v, func) for v in item]
        elif isinstance(item, dict):
            return {k: self.__apply__(v, func) for k, v in item.items()}
        else:
            return item

    def apply(self, func, *keys):
        r"""Applies the function :obj:`func` to all tensor attributes
        :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to
        all present attributes.
        The attributes are replaced on this object, which is then returned.
        """
        for key, item in self(*keys):
            self[key] = self.__apply__(item, func)
        return self

    def contiguous(self, *keys):
        r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`.
        If :obj:`*keys` is not given, all present attributes are ensured to
        have a contiguous memory layout."""
        return self.apply(lambda x: x.contiguous(), *keys)

    def to(self, device, *keys, **kwargs):
        r"""Performs tensor dtype and/or device conversion to all attributes
        :obj:`*keys`.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(lambda x: x.to(device, **kwargs), *keys)

    def cpu(self, *keys):
        r"""Copies all attributes :obj:`*keys` to CPU memory.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(lambda x: x.cpu(), *keys)

    def cuda(self, device=None, non_blocking=False, *keys):
        r"""Copies all attributes :obj:`*keys` to CUDA memory.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        # NOTE: `*keys` comes after `device` and `non_blocking`, so keys passed
        # positionally must follow those two arguments.
        return self.apply(
            lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys
        )

    def clone(self):
        r"""Performs a deep-copy of the data object."""
        # iterates the raw instance dict, so unset (None) attributes and the
        # internal `__num_nodes__` are carried over as well
        return self.__class__.from_dict(
            {
                k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v)
                for k, v in self.__dict__.items()
            }
        )

    def pin_memory(self, *keys):
        r"""Copies all attributes :obj:`*keys` to pinned memory.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(lambda x: x.pin_memory(), *keys)

    def debug(self):
        """Run consistency checks on the standard attributes.

        Checks dtypes and shapes of ``edge_index``/``face``, that their values
        lie within ``[0, num_nodes - 1]``, and that ``edge_attr``, ``x``,
        ``pos`` and ``normal`` have matching first dimensions.

        Raises:
            RuntimeError: on the first check that fails.
        """
        # --- dtype checks ---
        if self.edge_index is not None:
            if self.edge_index.dtype != torch.long:
                raise RuntimeError(
                    (
                        "Expected edge indices of dtype {}, but found dtype " " {}"
                    ).format(torch.long, self.edge_index.dtype)
                )

        if self.face is not None:
            if self.face.dtype != torch.long:
                raise RuntimeError(
                    (
                        "Expected face indices of dtype {}, but found dtype " " {}"
                    ).format(torch.long, self.face.dtype)
                )

        # --- edge_index shape and value range ---
        if self.edge_index is not None:
            if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2:
                raise RuntimeError(
                    (
                        "Edge indices should have shape [2, num_edges] but found"
                        " shape {}"
                    ).format(self.edge_index.size())
                )

        if self.edge_index is not None and self.num_nodes is not None:
            if self.edge_index.numel() > 0:
                min_index = self.edge_index.min()
                max_index = self.edge_index.max()
            else:
                # empty index tensor: nothing can be out of range
                min_index = max_index = 0
            if min_index < 0 or max_index > self.num_nodes - 1:
                raise RuntimeError(
                    (
                        "Edge indices must lay in the interval [0, {}]"
                        " but found them in the interval [{}, {}]"
                    ).format(self.num_nodes - 1, min_index, max_index)
                )

        # --- face shape and value range ---
        if self.face is not None:
            if self.face.dim() != 2 or self.face.size(0) != 3:
                raise RuntimeError(
                    (
                        "Face indices should have shape [3, num_faces] but found"
                        " shape {}"
                    ).format(self.face.size())
                )

        if self.face is not None and self.num_nodes is not None:
            if self.face.numel() > 0:
                min_index = self.face.min()
                max_index = self.face.max()
            else:
                # empty index tensor: nothing can be out of range
                min_index = max_index = 0
            if min_index < 0 or max_index > self.num_nodes - 1:
                raise RuntimeError(
                    (
                        "Face indices must lay in the interval [0, {}]"
                        " but found them in the interval [{}, {}]"
                    ).format(self.num_nodes - 1, min_index, max_index)
                )

        # --- per-edge / per-node attribute sizes ---
        if self.edge_index is not None and self.edge_attr is not None:
            if self.edge_index.size(1) != self.edge_attr.size(0):
                raise RuntimeError(
                    (
                        "Edge indices and edge attributes hold a differing "
                        "number of edges, found {} and {}"
                    ).format(self.edge_index.size(), self.edge_attr.size())
                )

        if self.x is not None and self.num_nodes is not None:
            if self.x.size(0) != self.num_nodes:
                raise RuntimeError(
                    (
                        "Node features should hold {} elements in the first "
                        "dimension but found {}"
                    ).format(self.num_nodes, self.x.size(0))
                )

        if self.pos is not None and self.num_nodes is not None:
            if self.pos.size(0) != self.num_nodes:
                raise RuntimeError(
                    (
                        "Node positions should hold {} elements in the first "
                        "dimension but found {}"
                    ).format(self.num_nodes, self.pos.size(0))
                )

        if self.normal is not None and self.num_nodes is not None:
            if self.normal.size(0) != self.num_nodes:
                raise RuntimeError(
                    (
                        "Node normals should hold {} elements in the first "
                        "dimension but found {}"
                    ).format(self.num_nodes, self.normal.size(0))
                )

    def __repr__(self):
        """Return ``ClassName(key=<size summary>, ...)`` for all present
        attributes; a multi-line layout is used when any attribute is a dict."""
        cls = str(self.__class__.__name__)
        has_dict = any([isinstance(item, dict) for _, item in self])

        if not has_dict:
            info = [size_repr(key, item) for key, item in self]
            return "{}({})".format(cls, ", ".join(info))
        else:
            info = [size_repr(key, item, indent=2) for key, item in self]
            return "{}(\n{}\n)".format(cls, ",\n".join(info))

## Atomic data

In [26]:
"""AtomicData: neighbor graphs in (periodic) real space.

Authors: Albert Musaelian
"""

import warnings
from copy import deepcopy
from typing import Union, Tuple, Dict, Optional, List, Set, Sequence, Final
from collections.abc import Mapping
import os

import numpy as np
import ase
from ase.calculators.singlepoint import SinglePointCalculator, SinglePointDFTCalculator
from ase.calculators.calculator import all_properties as ase_all_properties # maybe problem
from ase.stress import voigt_6_to_full_3x3_stress, full_3x3_to_voigt_6_stress

import torch
import e3nn.o3
from e3nn.io import CartesianTensor# may be problem
#from . import AtomicDataDict
#from nequip.data._util import _TORCH_INTEGER_DTYPES
#from nequip.utils.torch_geometric import Data

# A type representing ASE-style periodic boundary conditions, which can be partial (the tuple case)
PBC = Union[bool, Tuple[bool, bool, bool]]

# === Key Registration ===
# The `_DEFAULT_*` collections are the built-in registrations; `deregister_fields`
# refuses to remove anything that appears in them.

# Fields that `_process_dict` converts to `torch.long`.
_DEFAULT_LONG_FIELDS: Set[str] = {
    EDGE_INDEX_KEY,
    ATOMIC_NUMBERS_KEY,
    ATOM_TYPE_KEY,
    BATCH_KEY,
}
# Built-in node fields (per-atom, in the wording of `register_fields`).
_DEFAULT_NODE_FIELDS: Set[str] = {
    POSITIONS_KEY,
    NODE_FEATURES_KEY,
    NODE_ATTRS_KEY,
    ATOMIC_NUMBERS_KEY,
    ATOM_TYPE_KEY,
    FORCE_KEY,
    PER_ATOM_ENERGY_KEY,
    BATCH_KEY,
}
# Built-in edge fields (per-edge).
_DEFAULT_EDGE_FIELDS: Set[str] = {
    EDGE_CELL_SHIFT_KEY,
    EDGE_VECTORS_KEY,
    EDGE_LENGTH_KEY,
    EDGE_ATTRS_KEY,
    EDGE_EMBEDDING_KEY,
    EDGE_FEATURES_KEY,
    EDGE_CUTOFF_KEY,
    EDGE_ENERGY_KEY,
}
# Built-in graph fields (per-frame).
_DEFAULT_GRAPH_FIELDS: Set[str] = {
    TOTAL_ENERGY_KEY,
    STRESS_KEY,
    VIRIAL_KEY,
    PBC_KEY,
    CELL_KEY,
    BATCH_PTR_KEY,
}
# Field name -> index formula handed to `CartesianTensor` ("ij=ji" is the
# formula used for both built-in entries).
_DEFAULT_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = {
    STRESS_KEY: "ij=ji",
    VIRIAL_KEY: "ij=ji",
}
# Mutable working copies of the defaults; these are what `register_fields`
# and `deregister_fields` modify.
_NODE_FIELDS: Set[str] = set(_DEFAULT_NODE_FIELDS)
_EDGE_FIELDS: Set[str] = set(_DEFAULT_EDGE_FIELDS)
_GRAPH_FIELDS: Set[str] = set(_DEFAULT_GRAPH_FIELDS)
_LONG_FIELDS: Set[str] = set(_DEFAULT_LONG_FIELDS)
_CARTESIAN_TENSOR_FIELDS: Dict[str, str] = dict(_DEFAULT_CARTESIAN_TENSOR_FIELDS)


def register_fields(
    node_fields: Sequence[str] = (),
    edge_fields: Sequence[str] = (),
    graph_fields: Sequence[str] = (),
    long_fields: Sequence[str] = (),
    cartesian_tensor_fields: Optional[Dict[str, str]] = None,
) -> None:
    r"""Register fields as being per-atom, per-edge, or per-frame.

    All checks run before any registry is touched, so a rejected call leaves
    the registries unchanged.

    Args:
        node_fields: names of per-atom fields.
        edge_fields: names of per-edge fields.
        graph_fields: names of per-frame fields.
        long_fields: names of fields that ``_process_dict`` converts to
            ``torch.long``.
        cartesian_tensor_fields: maps a field name to the index formula given
            to ``CartesianTensor`` (e.g. ``"ij=ji"``).  Only rank-2 formulas
            are accepted.  ``None`` means no Cartesian tensor fields.

    Raises:
        AssertionError: a name would end up in more than one of the
            node/edge/graph categories, either within this call or against
            what is already registered.
        NotImplementedError: a Cartesian tensor formula is not rank 2.
    """
    # Immutable/None defaults in the signature avoid the shared-mutable-default
    # pitfall; local copies keep the caller's containers untouched.
    new_by_kind = {
        "node": set(node_fields),
        "edge": set(edge_fields),
        "graph": set(graph_fields),
    }
    new_long = set(long_fields)
    new_cartesian = dict(cartesian_tensor_fields or {})

    # error checking: prevents registering fields as contradictory types
    # potentially unregistered fields
    for kind_a, kind_b in (("node", "edge"), ("node", "graph"), ("edge", "graph")):
        clash = new_by_kind[kind_a] & new_by_kind[kind_b]
        assert (
            len(clash) == 0
        ), f"Fields {sorted(clash)} given as both {kind_a} and {kind_b} fields"

    # already registered fields
    registered_by_kind = {
        "node": _NODE_FIELDS,
        "edge": _EDGE_FIELDS,
        "graph": _GRAPH_FIELDS,
    }
    for registered_kind, registered in registered_by_kind.items():
        for new_kind, new_names in new_by_kind.items():
            if new_kind == registered_kind:
                continue
            clash = registered & new_names
            assert len(clash) == 0, (
                f"Fields {sorted(clash)} are already registered as "
                f"{registered_kind} fields and cannot be {new_kind} fields"
            )

    # check that Cartesian tensor fields to add are rank-2 (higher ranks not supported)
    for cart_tensor_key, cart_tensor_formula in new_cartesian.items():
        cart_tensor_rank = len(CartesianTensor(cart_tensor_formula).indices)
        if cart_tensor_rank != 2:
            raise NotImplementedError(
                f"Only rank-2 tensor data processing supported, but got {cart_tensor_key} is rank {cart_tensor_rank}. Consider raising a GitHub issue if higher-rank tensor data processing is desired."
            )

    # update fields
    _NODE_FIELDS.update(new_by_kind["node"])
    _EDGE_FIELDS.update(new_by_kind["edge"])
    _GRAPH_FIELDS.update(new_by_kind["graph"])
    _LONG_FIELDS.update(new_long)
    _CARTESIAN_TENSOR_FIELDS.update(new_cartesian)


def deregister_fields(*fields: str) -> None:
    r"""Deregister a field registered with ``register_fields``.

    Silently ignores fields that were never registered to begin with.

    Every name is validated before anything is removed, so a call that is
    rejected because of a built-in field leaves all registries unchanged.

    Args:
        *fields: fields to deregister.

    Raises:
        AssertionError: one of ``fields`` is a built-in field.
    """
    builtin_registries = (
        _DEFAULT_NODE_FIELDS,
        _DEFAULT_EDGE_FIELDS,
        _DEFAULT_GRAPH_FIELDS,
        _DEFAULT_LONG_FIELDS,
        _DEFAULT_CARTESIAN_TENSOR_FIELDS,
    )
    # First pass: validate everything, touching nothing.
    for f in fields:
        for builtin in builtin_registries:
            assert f not in builtin, "Cannot deregister built-in field"

    # Second pass: actually remove the (now known to be user-registered) names.
    for f in fields:
        _NODE_FIELDS.discard(f)
        _EDGE_FIELDS.discard(f)
        _GRAPH_FIELDS.discard(f)
        _LONG_FIELDS.discard(f)
        _CARTESIAN_TENSOR_FIELDS.pop(f, None)


def _register_field_prefix(prefix: str) -> None:
    """Register a ``prefix``-ed twin of every node, edge, graph and long field.

    Each currently registered name ``e`` is registered again as ``prefix + e``
    in the same category.  ``prefix`` must end with an underscore.

    NOTE(review): Cartesian tensor registrations are not forwarded here, so
    prefixed names are not added to ``_CARTESIAN_TENSOR_FIELDS`` — confirm
    that is intended.
    """
    assert prefix.endswith("_")

    def _prefixed(names):
        # materialize a list so the registries are not being iterated while
        # `register_fields` updates them
        return [prefix + name for name in names]

    register_fields(
        node_fields=_prefixed(_NODE_FIELDS),
        edge_fields=_prefixed(_EDGE_FIELDS),
        graph_fields=_prefixed(_GRAPH_FIELDS),
        long_fields=_prefixed(_LONG_FIELDS),
    )


#  === AtomicData ===


def _process_dict(kwargs, ignore_fields=()):
    """Convert a dict of data, in place, into the dtypes/shapes implied by each key.

    Two passes are made over ``kwargs``:

    1. dtype pass: long fields become ``torch.long`` tensors; bools, numpy
       arrays, lists and floating-point scalars become tensors (floating data
       is cast to ``torch.get_default_dtype()``). Existing tensors are left
       as they are. Any other value only triggers a warning.
    2. shape pass: scalar tensors gain a trailing data dimension, 1-D node/edge
       fields are expanded to ``[N, 1]``, and the leading dimension of node,
       edge and graph fields is checked against the positions, the edge index
       and the number of frames respectively.

    Args:
        kwargs: mapping from field name to value; modified in place.
        ignore_fields: names that are skipped by both passes.

    Raises:
        ValueError: if a node, edge or graph field has the wrong leading dimension.
    """
    # Deal with _some_ dtype issues
    for k, v in kwargs.items():
        if k in ignore_fields:
            continue

        if k in _LONG_FIELDS:
            # Any property used as an index must be long (or byte or bool, but those are not relevant for atomic scale systems)
            # int32 would pass later checks, but is actually disallowed by torch
            kwargs[k] = torch.as_tensor(v, dtype=torch.long)
        elif isinstance(v, bool):
            kwargs[k] = torch.as_tensor(v)
        elif isinstance(v, np.ndarray):
            if np.issubdtype(v.dtype, np.floating):
                kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype())
            else:
                kwargs[k] = torch.as_tensor(v)
        elif isinstance(v, list):
            ele_dtype = np.array(v).dtype
            if np.issubdtype(ele_dtype, np.floating):
                kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype())
            else:
                kwargs[k] = torch.as_tensor(v)
        elif np.issubdtype(type(v), np.floating):
            # Force scalars to be tensors with a data dimension
            # This makes them play well with irreps
            kwargs[k] = torch.as_tensor(v, dtype=torch.get_default_dtype())
        elif isinstance(v, torch.Tensor):
            # Already a tensor: nothing to convert. Scalar (0-dim) tensors get
            # their data dimension in the shape pass below.
            pass
        else:
            warnings.warn(
                f"Value for field {k} was of unsupported type {type(v)} (value was {v})"
            )

    if BATCH_KEY in kwargs:
        num_frames = kwargs[BATCH_KEY].max() + 1
    else:
        num_frames = 1

    # loop-invariant, so build it once rather than per key
    node_or_edge_fields = set.union(_NODE_FIELDS, _EDGE_FIELDS)

    for k, v in kwargs.items():
        if k in ignore_fields:
            continue

        if not isinstance(v, torch.Tensor):
            # Unsupported values were already warned about above; they have no
            # tensor shape to normalize, so leave them untouched instead of
            # failing on ``v.shape``.
            continue

        if len(v.shape) == 0:
            kwargs[k] = v.unsqueeze(-1)
            v = kwargs[k]

        if k in node_or_edge_fields and len(v.shape) == 1:
            kwargs[k] = v.unsqueeze(-1)
            v = kwargs[k]

        if (
            k in _NODE_FIELDS
            and POSITIONS_KEY in kwargs
            and v.shape[0] != kwargs[POSITIONS_KEY].shape[0]
        ):
            raise ValueError(
                f"{k} is a node field but has the wrong dimension {v.shape}"
            )
        elif (
            k in _EDGE_FIELDS
            and EDGE_INDEX_KEY in kwargs
            and v.shape[0] != kwargs[EDGE_INDEX_KEY].shape[1]
        ):
            raise ValueError(
                f"{k} is a edge field but has the wrong dimension {v.shape}"
            )
        elif k in _GRAPH_FIELDS:
            if num_frames > 1 and v.shape[0] != num_frames:
                raise ValueError(f"Wrong shape for graph property {k}")


class AtomicData(Data):
    """A neighbor graph for points in (periodic triclinic) real space.

    For typical cases either ``from_points`` or ``from_ase`` should be used to
    construct an ``AtomicData``; they also standardize and check their input much
    more thoroughly.

    In general, ``node_features`` are features or input information on the nodes
    that will be fed through and transformed by the network, while ``node_attrs``
    are _encodings_ of fixed, inherent attributes of the atoms themselves that
    remain constant through the network.
    For example, a one-hot _encoding_ of atomic species is a node attribute,
    while some observed instantaneous property of that atom (current partial
    charge, for example) would be a feature.

    In general, ``torch.Tensor`` arguments should be of consistent dtype.
    Numpy arrays will be converted to ``torch.Tensor``s;
    those of floating point dtype will be converted to ``torch.get_default_dtype()``
    regardless of their original precision.
    Scalar values (Python scalars or ``torch.Tensor``s of shape ``()``) are resized
    to tensors of shape ``[1]``.
    Per-atom scalar values should be given with shape ``[N_at, 1]``.

    ``AtomicData`` should be used for all data creation and manipulation outside of the model;
    inside of the model ``Type`` is used.

    Args:
        pos (Tensor [n_nodes, 3]): Positions of the nodes.
        edge_index (LongTensor [2, n_edges]): ``edge_index[0]`` is the per-edge
            index of the source node and ``edge_index[1]`` is the target node.
        edge_cell_shift (Tensor [n_edges, 3], optional): which periodic image
            of the target point each edge goes to, relative to the source point.
        cell (Tensor [1, 3, 3], optional): the periodic cell for
            ``edge_cell_shift`` as the three triclinic cell vectors.
        node_features (Tensor [n_atom, ...]): the input features of the nodes, optional
        node_attrs (Tensor [n_atom, ...]): the attributes of the nodes, for instance the atom type, optional
        batch (Tensor [n_atom]): the graph to which the node belongs, optional
        atomic_numbers (Tensor [n_atom]): optional.
        atom_type (Tensor [n_atom]): optional.
        **kwargs: other data, optional.
    """

    def __init__(
        self, irreps: Dict[str, e3nn.o3.Irreps] = {}, _validate: bool = True, **kwargs
    ):
        """Create the data object, optionally validating keys, dtypes and shapes.

        Args:
            irreps: mapping from field name to the irreps of that field. When
                validating, the last dimension of each listed field must equal
                the ``dim`` of its irreps. The default dict is never mutated here.
            _validate: if true, run ``validate_keys`` and ``_process_dict`` on
                ``kwargs`` and assert the shapes/dtypes of the standard fields.
            **kwargs: the fields to store; must contain ``"pos"`` unless the
                object is created completely empty.
        """
        # empty init needed by get_example
        if len(kwargs) == 0 and len(irreps) == 0:
            super().__init__()
            return

        # Check the keys
        if _validate:
            validate_keys(kwargs)
            _process_dict(kwargs)

        super().__init__(num_nodes=len(kwargs["pos"]), **kwargs)

        if _validate:
            # Validate shapes
            assert self.pos.dim() == 2 and self.pos.shape[1] == 3
            assert self.edge_index.dim() == 2 and self.edge_index.shape[0] == 2
            if "edge_cell_shift" in self and self.edge_cell_shift is not None:
                assert self.edge_cell_shift.shape == (self.num_edges, 3)
                assert self.edge_cell_shift.dtype == self.pos.dtype
            if "cell" in self and self.cell is not None:
                assert (self.cell.shape == (3, 3)) or (
                    self.cell.dim() == 3 and self.cell.shape[1:] == (3, 3)
                )
                assert self.cell.dtype == self.pos.dtype
            if "node_features" in self and self.node_features is not None:
                assert self.node_features.shape[0] == self.num_nodes
                assert self.node_features.dtype == self.pos.dtype
            if "node_attrs" in self and self.node_attrs is not None:
                assert self.node_attrs.shape[0] == self.num_nodes
                assert self.node_attrs.dtype == self.pos.dtype

            if (
                ATOMIC_NUMBERS_KEY in self
                and self.atomic_numbers is not None
            ):
                assert self.atomic_numbers.dtype in _TORCH_INTEGER_DTYPES
            if "batch" in self and self.batch is not None:
                assert self.batch.dim() == 2 and self.batch.shape[0] == self.num_nodes
                # Check that there are the right number of cells
                if "cell" in self and self.cell is not None:
                    cell = self.cell.view(-1, 3, 3)
                    assert cell.shape[0] == self.batch.max() + 1

            # Validate irreps
            # __*__ is the only way to hide from torch_geometric
            self.__irreps__ = _fix_irreps_dict(irreps)
            # Iterate over (field, irreps) pairs: iterating the mapping itself
            # would only yield the field names.
            for field, field_irreps in self.__irreps__.items():
                if field_irreps is not None:
                    assert self[field].shape[-1] == field_irreps.dim

    @classmethod
    def from_points(
        cls,  # points to the class, so can be used as constructor as cls(...)
        pos=None,
        r_max: float = None,
        self_interaction: bool = False,
        strict_self_interaction: bool = True,
        cell=None,
        pbc: Optional[PBC] = None,
        **kwargs,
    ):
        """Build neighbor graph from points, optionally with PBC.

        Args:
            pos (np.ndarray/torch.Tensor shape [N, 3]): node positions. If Tensor, must be on the CPU.
            r_max (float): neighbor cutoff radius.
            cell (ase.Cell/ndarray [3,3], optional): periodic cell for the points. Defaults to ``None``.
            pbc (bool or 3-tuple of bool, optional): whether to apply periodic boundary conditions to all
                or each of the three cell vector directions. If omitted, no periodicity is used;
                omitting it while passing a ``cell`` is an error.
            self_interaction (bool, optional): whether to include self edges for points. Defaults to ``False``. Note
                that edges between the same atom in different periodic images are still included. (See
                ``strict_self_interaction`` to control this behaviour.)
            strict_self_interaction (bool): Whether to include *any* self interaction edges in the graph, even
                if the two instances of the atom are in different periodic images.
                Defaults to True, should be True for most applications.
            **kwargs (optional): other fields to add. Keys listed in ``*_KEY`` will be treated specially.

        Returns:
            A new instance of ``cls`` holding the positions, the neighbor list
            and any extra fields.

        Raises:
            ValueError: if ``pos`` or ``r_max`` is missing, or if ``cell`` is
                given without ``pbc``.
        """
        if pos is None or r_max is None:
            raise ValueError("pos and r_max must be given.")

        if pbc is None:
            if cell is not None:
                raise ValueError(
                    "A cell was provided, but pbc weren't. Please explicitly probide PBC."
                )
            # there are no PBC if cell and pbc are not provided
            pbc = False

        if isinstance(pbc, bool):
            pbc = (pbc,) * 3
        else:
            assert len(pbc) == 3

        pos = torch.as_tensor(pos, dtype=torch.get_default_dtype())

        edge_index, edge_cell_shift, cell = neighbor_list_and_relative_vec(
            pos=pos,
            r_max=r_max,
            self_interaction=self_interaction,
            strict_self_interaction=strict_self_interaction,
            cell=cell,
            pbc=pbc,
        )

        # Make torch tensors for data:
        if cell is not None:
            kwargs[CELL_KEY] = cell.view(3, 3)
            kwargs[EDGE_CELL_SHIFT_KEY] = edge_cell_shift
        # pbc is always a length-3 sequence at this point
        kwargs[PBC_KEY] = torch.as_tensor(pbc, dtype=torch.bool).view(3)

        # pos was already converted to a tensor above
        return cls(edge_index=edge_index, pos=pos, **kwargs)

    @classmethod
    def from_ase(
        cls,
        atoms,
        r_max,
        key_mapping: Optional[Dict[str, str]] = {},
        include_keys: Optional[list] = [],
        **kwargs,
    ):
        """Build a ``AtomicData`` from an ``ase.Atoms`` object.

        Respects ``atoms``'s ``pbc`` and ``cell`` unless they are overridden
        through ``kwargs``.

        Extra fields are collected from three places, later ones overriding
        earlier ones: ``atoms.arrays``, ``atoms.info``, and the results of a
        single-point calculator attached to ``atoms`` (if any). ASE's
        ``forces`` and ``energy`` are renamed to ``FORCE_KEY`` and
        ``TOTAL_ENERGY_KEY`` unless ``key_mapping`` says otherwise.

        `get_atomic_numbers()` will be stored as the atomic_numbers attribute.

        Args:
            atoms (ase.Atoms): the input.
            r_max (float): neighbor cutoff radius.
            include_keys (list): list of additional keys to include in AtomicData aside from the ones defined in
                 ase.calculators.calculator.all_properties. Optional; ``None`` means no additional keys.
            key_mapping (dict): rename ase property name to a new string name. Optional; ``None`` means
                 no additional renaming.
            **kwargs (optional): other arguments for the ``AtomicData`` constructor.

        Returns:
            A ``AtomicData``.

        Raises:
            NotImplementedError: if ``atoms`` carries a calculator of an unsupported type.
            RuntimeError: if a registered Cartesian tensor field has an unsupported shape
                or is registered as neither a graph nor a node field.
        """
        from nequip.ase import NequIPCalculator

        assert "pos" not in kwargs

        # Both arguments are declared Optional, so accept ``None`` as "nothing
        # extra". New containers are built below; the defaults are never mutated.
        key_mapping = {} if key_mapping is None else key_mapping
        include_keys = [] if include_keys is None else list(include_keys)

        default_args = set(
            [
                "numbers",
                "positions",
            ]  # ase internal names for position and atomic_numbers
            + ["pbc", "cell", "pos", "r_max"]  # arguments for from_points method
            + list(kwargs.keys())
        )
        # the keys that are duplicated in kwargs are removed from the include_keys
        include_keys = list(
            set(include_keys + ase_all_properties + list(key_mapping.keys()))
            - default_args
        )

        km = {
            "forces": FORCE_KEY,
            "energy": TOTAL_ENERGY_KEY,
        }
        km.update(key_mapping)
        key_mapping = km

        # Get info from atoms.arrays; lowest priority. copy first
        add_fields = {
            key_mapping.get(k, k): v
            for k, v in atoms.arrays.items()
            if k in include_keys
        }

        # Get info from atoms.info; second lowest priority.
        add_fields.update(
            {
                key_mapping.get(k, k): v
                for k, v in atoms.info.items()
                if k in include_keys
            }
        )

        if atoms.calc is not None:
            if isinstance(
                atoms.calc, (SinglePointCalculator, SinglePointDFTCalculator)
            ):
                add_fields.update(
                    {
                        key_mapping.get(k, k): deepcopy(v)
                        for k, v in atoms.calc.results.items()
                        if k in include_keys
                    }
                )
            elif isinstance(atoms.calc, NequIPCalculator):
                pass  # otherwise the calculator breaks
            else:
                raise NotImplementedError(
                    f"`from_ase` does not support calculator {atoms.calc}"
                )

        add_fields[ATOMIC_NUMBERS_KEY] = atoms.get_atomic_numbers()

        # cell and pbc in kwargs can override the ones stored in atoms
        cell = kwargs.pop("cell", atoms.get_cell())
        pbc = kwargs.pop("pbc", atoms.pbc)

        # IMPORTANT: the following reshape logic only applies to rank-2 Cartesian tensor fields
        for key in add_fields:
            if key in _CARTESIAN_TENSOR_FIELDS:
                # enforce (3, 3) shape for graph fields, e.g. stress, virial
                if key in _GRAPH_FIELDS:
                    # handle ASE-style 6 element Voigt order stress
                    if key in (STRESS_KEY, VIRIAL_KEY):
                        if add_fields[key].shape == (6,):
                            add_fields[key] = voigt_6_to_full_3x3_stress(
                                add_fields[key]
                            )
                    if add_fields[key].shape == (3, 3):
                        # it's already 3x3, do nothing else
                        pass
                    elif add_fields[key].shape == (9,):
                        add_fields[key] = add_fields[key].reshape((3, 3))
                    else:
                        raise RuntimeError(
                            f"bad shape for {key} registered as a Cartesian tensor graph field---please note that only rank-2 Cartesian tensors are currently supported"
                        )
                # enforce (N_atom, 3, 3) shape for node fields, e.g. Born effective charges
                elif key in _NODE_FIELDS:
                    if add_fields[key].shape[1:] == (3, 3):
                        pass
                    elif add_fields[key].shape[1:] == (9,):
                        add_fields[key] = add_fields[key].reshape((-1, 3, 3))
                    else:
                        raise RuntimeError(
                            f"bad shape for {key} registered as a Cartesian tensor node field---please note that only rank-2 Cartesian tensors are currently supported"
                        )
                else:
                    raise RuntimeError(
                        f"{key} registered as a Cartesian tensor field was not registered as either a graph or node field"
                    )

        return cls.from_points(
            pos=atoms.positions,
            r_max=r_max,
            cell=cell,
            pbc=pbc,
            **kwargs,
            **add_fields,
        )

    def to_ase(
        self,
        type_mapper=None,
        extra_fields: List[str] = [],
    ) -> Union[List[ase.Atoms], ase.Atoms]:
        """Build a (list of) ``ase.Atoms`` object(s) from an ``AtomicData`` object.

        For each unique batch number provided in ``BATCH_KEY``,
        an ``ase.Atoms`` object is created. If ``BATCH_KEY`` does not
        exist in self, a single ``ase.Atoms`` object is created.

        If any of total energy, forces, per-atom energy or stress is present,
        those values are attached through a ``SinglePointCalculator``.

        Args:
            type_mapper: if provided, will be used to map ``ATOM_TYPES`` back into
                elements, if the configuration of the ``type_mapper`` allows.
            extra_fields: fields other than those handled explicitly (currently
                those defining the structure as well as energy, per-atom energy,
                and forces) to include in the output object. Per-atom (per-node)
                quantities will be included in ``arrays``; per-graph and per-edge
                quantities will be included in ``info``. The default list is
                only read, never mutated.

        Returns:
            A list of ``ase.Atoms`` objects if ``BATCH_KEY`` is in self
            and is not None. Otherwise, a single ``ase.Atoms`` object is returned.

        Raises:
            TypeError: if the data is not on the CPU.
            RuntimeError: if an extra field is not registered as node/edge/graph.
        """
        positions = self.pos
        edge_index = self[EDGE_INDEX_KEY]
        if positions.device != torch.device("cpu"):
            raise TypeError(
                "Explicitly move this `AtomicData` to CPU using `.to()` before calling `to_ase()`."
            )
        if ATOMIC_NUMBERS_KEY in self:
            atomic_nums = self.atomic_numbers
        elif type_mapper is not None and type_mapper.has_chemical_symbols:
            atomic_nums = type_mapper.untransform(self[ATOM_TYPE_KEY])
        else:
            warnings.warn(
                "AtomicData.to_ase(): self didn't contain atomic numbers... using atom_type as atomic numbers instead, but this means the chemical symbols in ASE (outputs) will be wrong"
            )
            atomic_nums = self[ATOM_TYPE_KEY]
        pbc = getattr(self, PBC_KEY, None)
        cell = getattr(self, CELL_KEY, None)
        batch = getattr(self, BATCH_KEY, None)
        energy = getattr(self, TOTAL_ENERGY_KEY, None)
        energies = getattr(self, PER_ATOM_ENERGY_KEY, None)
        force = getattr(self, FORCE_KEY, None)
        do_calc = any(
            k in self
            for k in [
                TOTAL_ENERGY_KEY,
                FORCE_KEY,
                PER_ATOM_ENERGY_KEY,
                STRESS_KEY,
            ]
        )

        # exclude those that are special for ASE and that we process separately
        special_handling_keys = [
            POSITIONS_KEY,
            CELL_KEY,
            PBC_KEY,
            ATOMIC_NUMBERS_KEY,
            TOTAL_ENERGY_KEY,
            FORCE_KEY,
            PER_ATOM_ENERGY_KEY,
            STRESS_KEY,
        ]
        assert (
            len(set(extra_fields).intersection(special_handling_keys)) == 0
        ), f"Cannot specify keys handled in special ways ({special_handling_keys}) as `extra_fields` for atoms output--- they are output by default"

        if cell is not None:
            cell = cell.view(-1, 3, 3)
        if pbc is not None:
            pbc = pbc.view(-1, 3)

        if batch is not None:
            n_batches = batch.max() + 1
            cell = cell.expand(n_batches, 3, 3) if cell is not None else None
            pbc = pbc.expand(n_batches, 3) if pbc is not None else None
        else:
            n_batches = 1

        batch_atoms = []
        for batch_idx in range(n_batches):
            if batch is not None:
                mask = batch == batch_idx
                mask = mask.view(-1)
                # if both ends of the edge are in the batch, the edge is in the batch
                edge_mask = mask[edge_index[0]] & mask[edge_index[1]]
            else:
                # no batch information: select everything
                mask = slice(None)
                edge_mask = slice(None)

            mol = ase.Atoms(
                numbers=atomic_nums[mask].view(-1),  # must be flat for ASE
                positions=positions[mask],
                cell=cell[batch_idx] if cell is not None else None,
                pbc=pbc[batch_idx] if pbc is not None else None,
            )

            if do_calc:
                fields = {}
                if energies is not None:
                    fields["energies"] = energies[mask].cpu().numpy()
                if energy is not None:
                    fields["energy"] = energy[batch_idx].cpu().numpy()
                if force is not None:
                    fields["forces"] = force[mask].cpu().numpy()
                if STRESS_KEY in self:
                    fields["stress"] = full_3x3_to_voigt_6_stress(
                        self["stress"].view(-1, 3, 3)[batch_idx].cpu().numpy()
                    )
                mol.calc = SinglePointCalculator(mol, **fields)

            # add other information
            for key in extra_fields:
                if key in _NODE_FIELDS:
                    # Select this frame's atoms and flatten the trailing dims.
                    # The row count comes from the selected array itself because
                    # ``mask`` is a plain ``slice`` (no ``.sum()``) when there is
                    # no batch.
                    node_values = self[key][mask].cpu().numpy()
                    mol.arrays[key] = node_values.reshape(node_values.shape[0], -1)
                elif key in _EDGE_FIELDS:
                    # same reasoning as above for ``edge_mask``
                    edge_values = self[key][edge_mask].cpu().numpy()
                    mol.info[key] = edge_values.reshape(edge_values.shape[0], -1)
                elif key == EDGE_INDEX_KEY:
                    mol.info[key] = self[key][:, edge_mask].cpu().numpy()
                elif key in _GRAPH_FIELDS:
                    mol.info[key] = self[key][batch_idx].cpu().numpy().reshape(-1)
                else:
                    raise RuntimeError(
                        f"Extra field `{key}` isn't registered as node/edge/graph"
                    )

            batch_atoms.append(mol)

        if batch is not None:
            return batch_atoms
        else:
            assert len(batch_atoms) == 1
            return batch_atoms[0]

    def get_edge_vectors(data: Data) -> torch.Tensor:
        """Return the edge vectors stored under ``EDGE_VECTORS_KEY``.

        The first parameter is named ``data`` rather than ``self``: when called
        on an instance it receives that instance, and the function can also be
        called as ``AtomicData.get_edge_vectors(data)``.
        """
        data = with_edge_vectors(AtomicData.to_AtomicDataDict(data))
        return data[EDGE_VECTORS_KEY]

    @staticmethod
    def to_AtomicDataDict(
        data: Union[Data, Mapping], exclude_keys=tuple()
    ) -> Type:
        """Collect the tensor-valued entries of ``data`` into a plain dict.

        Args:
            data: a ``Data`` object or a mapping.
            exclude_keys: keys to leave out of the result.

        Returns:
            A dict with every key of ``data`` that is not excluded and whose
            value is a ``torch.Tensor``.

        Raises:
            ValueError: if ``data`` is neither a ``Data`` nor a ``Mapping``.
        """
        if isinstance(data, Data):
            keys = data.keys
        elif isinstance(data, Mapping):
            keys = data.keys()
        else:
            raise ValueError(f"Invalid data `{repr(data)}`")

        return {
            k: data[k]
            for k in keys
            if (
                k not in exclude_keys
                and data[k] is not None
                and isinstance(data[k], torch.Tensor)
            )
        }

    @classmethod
    def from_AtomicDataDict(cls, data: Type):
        """Wrap an ``AtomicDataDict``-style mapping without validating it."""
        # it's an AtomicDataDict, so don't validate-- assume valid:
        return cls(_validate=False, **data)

    @property
    def irreps(self):
        """The irreps mapping stored by ``__init__``.

        NOTE(review): ``__irreps__`` is only assigned when ``__init__`` runs
        with ``_validate=True`` and non-empty arguments, so this lookup depends
        on that path having been taken.
        """
        return self.__irreps__

    def __cat_dim__(self, key, value):
        """Return the dimension along which ``key`` is concatenated when batching."""
        if key == EDGE_INDEX_KEY:
            return 1  # always cat in the edge dimension
        elif key in _GRAPH_FIELDS:
            # graph-level properties and so need a new batch dimension
            return None
        else:
            return 0  # cat along node/edge dimension

    def without_nodes(self, which_nodes):
        """Return a copy of ``self`` with ``which_nodes`` removed.

        Edges touching a removed node are dropped and the remaining edge
        indices are renumbered. The returned object may share references to
        some underlying data tensors with ``self``.

        Args:
            which_nodes (index tensor or boolean mask): the nodes to remove.

        Returns:
            A new data object.
        """
        which_nodes = torch.as_tensor(which_nodes)
        if which_nodes.dtype == torch.bool:
            mask = ~which_nodes
        else:
            mask = torch.ones(self.num_nodes, dtype=torch.bool)
            mask[which_nodes] = False
        assert mask.shape == (self.num_nodes,)
        n_keeping = mask.sum()

        # Only keep edges where both from and to are kept
        edge_mask = mask[self.edge_index[0]] & mask[self.edge_index[1]]
        # Create an index mapping from old node index to new node index;
        # removed nodes map to -1 (they are never referenced by a kept edge)
        new_index = torch.full((self.num_nodes,), -1, dtype=torch.long)
        new_index[mask] = torch.arange(n_keeping, dtype=torch.long)

        new_dict = {}
        for k in self.keys:
            if k == EDGE_INDEX_KEY:
                new_dict[EDGE_INDEX_KEY] = new_index[
                    self.edge_index[:, edge_mask]
                ]
            elif k == EDGE_CELL_SHIFT_KEY:
                new_dict[EDGE_CELL_SHIFT_KEY] = self.edge_cell_shift[
                    edge_mask
                ]
            elif k == CELL_KEY:
                new_dict[k] = self[k]
            else:
                # anything whose leading dimension matches the node count is
                # treated as per-node data and masked
                if isinstance(self[k], torch.Tensor) and len(self[k]) == self.num_nodes:
                    new_dict[k] = self[k][mask]
                else:
                    new_dict[k] = self[k]

        new_dict["irreps"] = self.__irreps__

        return type(self)(**new_dict)


# Whether building a neighbor list that ends up with no edges is an error.
# Controlled by the NEQUIP_ERROR_ON_NO_EDGES environment variable, which must
# be "true" or "false" (case-insensitive); defaults to "true".
# The raw string is kept under its own name so that `_ERROR_ON_NO_EDGES` is
# only ever bound to a bool, matching its annotation.
_error_on_no_edges_env: str = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower()
assert _error_on_no_edges_env in (
    "true",
    "false",
), f"NEQUIP_ERROR_ON_NO_EDGES must be 'true' or 'false', got {_error_on_no_edges_env!r}"
_ERROR_ON_NO_EDGES: bool = _error_on_no_edges_env == "true"

# Neighbor list backend, selected by the NEQUIP_NL environment variable.
# use "ase" as default
# TODO: eventually, choose fastest as default
# NOTE:
# - vesin and matscipy do not support self-interaction
# - vesin does not allow for mixed pbcs
_NEQUIP_NL: Final[str] = os.environ.get("NEQUIP_NL", "ase").lower()

# Only the selected backend is imported, so the others need not be installed.
if _NEQUIP_NL == "vesin":
    from vesin import NeighborList as vesin_nl
elif _NEQUIP_NL == "matscipy":
    import matscipy.neighbours
elif _NEQUIP_NL == "ase":
    import ase.neighborlist
else:
    raise NotImplementedError(f"Unknown neighborlist NEQUIP_NL = {_NEQUIP_NL}")


def neighbor_list_and_relative_vec(
    pos,
    r_max,
    self_interaction=False,
    strict_self_interaction=True,
    cell=None,
    pbc=False,
):
    """Build a radial-cutoff neighbor list with periodic cell shifts.

    The neighbor search itself is delegated to the backend selected by the
    module-level ``_NEQUIP_NL`` setting (``"vesin"``, ``"matscipy"`` or
    ``"ase"``) and runs on numpy copies of the inputs.

    Edge convention:
    - ``edge_index[0]`` is the *source* (convolution center).
    - ``edge_index[1]`` is the *target* (neighbor).

    All outputs are tensors on the device of ``pos`` (CPU when ``pos`` is not a
    tensor). A warning is issued when ``pos`` lives on another device, because
    the search needs a round trip through the CPU.

    Args:
        pos (shape [N, 3]): positions; Tensor or anything ``np.asarray`` accepts.
        r_max (float): radial cutoff distance for neighbor finding.
        self_interaction (bool): if false, edges from an atom to itself within
            the same periodic image are removed from the result.
        strict_self_interaction (bool): forwarded as ``self_interaction`` to the
            ASE backend; the vesin and matscipy backends require it to be true
            (and ``self_interaction`` to be false).
        cell (shape [3, 3], optional): cell vectors; Tensor or array-like. A
            zero cell is used when omitted.
        pbc (bool or 3 bools): periodicity along each of the three cell vectors.

    Returns:
        A tuple ``(edge_index, shifts, cell_tensor)``:
        ``edge_index`` is a long tensor of shape ``[2, num_edges]``;
        ``shifts`` holds the per-edge periodic cell shifts in the dtype of ``pos``;
        ``cell_tensor`` is the input cell (or the zero cell) as a tensor.

    Raises:
        ValueError: for mixed periodic boundary conditions with the vesin
            backend, or when ``_ERROR_ON_NO_EDGES`` is set and no edge is left
            after removing self edges.
    """
    if isinstance(pbc, bool):
        pbc = (pbc,) * 3

    # Work out the output device/dtype and get a numpy view of the positions
    if isinstance(pos, torch.Tensor):
        out_device, out_dtype = pos.device, pos.dtype
        temp_pos = pos.detach().cpu().numpy()
    else:
        out_device, out_dtype = torch.device("cpu"), torch.get_default_dtype()
        temp_pos = np.asarray(pos)

    # Right now, GPU tensors require a round trip
    if out_device.type != "cpu":
        warnings.warn(
            "Currently, neighborlists require a round trip to the CPU. Please pass CPU tensors if possible."
        )

    # Get a numpy cell for the search and a tensor cell for the caller
    if isinstance(cell, torch.Tensor):
        temp_cell = cell.detach().cpu().numpy()
        cell_tensor = cell.to(device=out_device, dtype=out_dtype)
    else:
        if cell is None:
            # ASE will "complete" this correctly.
            temp_cell = np.zeros((3, 3), dtype=temp_pos.dtype)
        else:
            temp_cell = np.asarray(cell)
        cell_tensor = torch.as_tensor(temp_cell, device=out_device, dtype=out_dtype)

    # ASE dependent part
    temp_cell = ase.geometry.complete_cell(temp_cell)

    cutoff = float(r_max)

    if _NEQUIP_NL == "vesin":
        assert strict_self_interaction and not self_interaction
        # use same mixed pbc logic as
        # https://github.com/Luthaf/vesin/blob/main/python/vesin/src/vesin/_ase.py
        if all(pbc[axis] for axis in range(3)):
            periodic = True
        elif not any(pbc[axis] for axis in range(3)):
            periodic = False
        else:
            raise ValueError(
                "different periodic boundary conditions on different axes are not supported by vesin neighborlist, use ASE or matscipy"
            )

        calculator = vesin_nl(cutoff=cutoff, full_list=True)
        first_idex, second_idex, shifts = calculator.compute(
            points=temp_pos, box=temp_cell, periodic=periodic, quantities="ijS"
        )

    elif _NEQUIP_NL == "matscipy":
        assert strict_self_interaction and not self_interaction
        first_idex, second_idex, shifts = matscipy.neighbours.neighbour_list(
            "ijS",
            pbc=pbc,
            cell=temp_cell,
            positions=temp_pos,
            cutoff=cutoff,
        )
    elif _NEQUIP_NL == "ase":
        first_idex, second_idex, shifts = ase.neighborlist.primitive_neighbor_list(
            "ijS",
            pbc,
            temp_cell,
            temp_pos,
            cutoff=cutoff,
            # we want edges from atom to itself in different periodic images!
            self_interaction=strict_self_interaction,
            use_scaled_positions=False,
        )

    # Eliminate true self-edges that don't cross periodic boundaries
    if not self_interaction:
        same_atom = first_idex == second_idex
        same_image = np.all(shifts == 0, axis=1)
        keep_edge = ~(same_atom & same_image)
        if _ERROR_ON_NO_EDGES and (not np.any(keep_edge)):
            raise ValueError(
                f"Every single atom has no neighbors within the cutoff r_max={r_max} (after eliminating self edges, no edges remain in this system)"
            )
        first_idex, second_idex, shifts = (
            first_idex[keep_edge],
            second_idex[keep_edge],
            shifts[keep_edge],
        )

    # Build output:
    sources = torch.LongTensor(first_idex)
    targets = torch.LongTensor(second_idex)
    edge_index = torch.vstack((sources, targets)).to(device=out_device)

    shifts = torch.as_tensor(shifts, dtype=out_dtype, device=out_device)
    return edge_index, shifts, cell_tensor

In [27]:
# Same check as the reference example at the top of the notebook, but this time
# going through the AtomicData class defined in this notebook: read a structure
# from an XYZ (or any other ASE-readable) file and build its neighbor graph.
data = AtomicData.from_ase(
    atoms=ase.io.read("./datasets/NaBr-data.xyz"),
    r_max=5.0,
)

# Model evaluation stays disabled here, as in the reference example:
# data = data.to(device)
# out = model(AtomicData.to_AtomicDataDict(data))

print(type(data))
print(f"Total energy: {data[AtomicDataDict.TOTAL_ENERGY_KEY]}")
print(f"Force on atom 0: {data[AtomicDataDict.FORCE_KEY][0]}")

<class '__main__.AtomicData'>
Total energy: tensor([-189.3690])
Force on atom 0: tensor([0.0591, 0.8510, 0.0488])


### Next: implementing collections of atomic data

In [32]:
# reference example
from nequip.data import dataset_from_config
from nequip.utils import Config
from nequip.utils.misc import get_default_device_name
from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS

# Fallback values for the run configuration. They are passed as
# `defaults=default_config` to `Config.from_file` further down; judging by the
# `config` dump shown below, values from the YAML file take precedence.
default_config = dict(
    root="./",
    tensorboard=False,
    wandb=False,
    model_builders=[
        "SimpleIrrepsConfig",
        "EnergyModel",
        "PerSpeciesRescale",
        "StressForceOutput",
        "RescaleEnergyEtc",
    ],
    dataset_statistics_stride=1,
    device=get_default_device_name(),
    default_dtype="float64",
    model_dtype="float32",
    allow_tf32=True,
    verbose="INFO",
    model_debug_mode=False,
    equivariance_test=False,
    grad_anomaly_mode=False,
    gpu_oom_offload=False,
    append=False,
    warn_unused=False,
    _jit_bailout_depth=2,  # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
    # Quote from eelison in PyTorch slack:
    # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA
    # > Right now the default behavior is to specialize twice on static shapes and then on dynamic shapes.
    # > To reduce warmup time you can do something like setFusionStrartegy({{FusionBehavior::DYNAMIC, 3}})
    # > ... Although we would wouldn't really expect to recompile a dynamic shape fusion in a model,
    # > provided broadcasting patterns remain fixed
    # We default to DYNAMIC alone because the number of edges is always dynamic,
    # even if the number of atoms is fixed:
    _jit_fusion_strategy=[("DYNAMIC", 3)],
    # Due to what appear to be ongoing bugs with nvFuser, we default to NNC (fuser1) for now:
    # TODO: still default to NNC on CPU regardless even if change this for GPU
    # TODO: default for ROCm?
    _jit_fuser="fuser1",
)

# All default_config keys are valid / requested
# NOTE(review): presumably this registry feeds nequip's unused-key check, so
# registering the defaults keeps them from being reported -- confirm.
_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys())

In [37]:
config

{'_jit_bailout_depth': 2, '_jit_fusion_strategy': [('DYNAMIC', 3)], '_jit_fuser': 'fuser1', 'root': 'results/NaBr-tutorial', 'tensorboard': False, 'wandb': True, 'model_builders': ['allegro.model.Allegro', 'PerSpeciesRescale', 'ForceOutput', 'RescaleEnergyEtc'], 'dataset_statistics_stride': 1, 'device': 'cpu', 'default_dtype': 'float32', 'model_dtype': 'float32', 'allow_tf32': True, 'verbose': 'info', 'model_debug_mode': False, 'equivariance_test': False, 'grad_anomaly_mode': False, 'gpu_oom_offload': False, 'append': True, 'warn_unused': False, 'run_name': 'NaBr', 'seed': 123456, 'dataset_seed': 123456, 'r_max': 5.0, 'avg_num_neighbors': 'auto', 'BesselBasis_trainable': True, 'PolynomialCutoff_p': 6, 'l_max': 1, 'parity': 'o3_full', 'num_layers': 1, 'env_embed_multiplicity': 8, 'embed_initial_edge': True, 'two_body_latent_mlp_latent_dimensions': [32, 64, 128], 'two_body_latent_mlp_nonlinearity': 'silu', 'two_body_latent_mlp_initialization': 'uniform', 'latent_mlp_latent_dimensions': [

In [40]:
# Read the run configuration from YAML, using `default_config` for the defaults.
config = Config.from_file('./configs/NaBr.yaml', defaults=default_config)
    

# Build the dataset from the configuration via nequip's reference helper.
dataset = dataset_from_config(config, prefix="dataset")


# Show the first frame. Its repr below lists `atom_types` rather than
# `atomic_numbers`, unlike the raw `AtomicData` loaded earlier.
dataset[0]

AtomicData(atom_types=[64, 1], cell=[3, 3], edge_cell_shift=[1146, 3], edge_index=[2, 1146], forces=[64, 3], free_energy=[1], pbc=[3], pos=[64, 3], stress=[3, 3], total_energy=[1])

### Transforms: `TypeMapper`

In [41]:
from typing import Dict, Optional, Union, List
import warnings

import torch

import ase.data

#from nequip.data import AtomicData, AtomicDataDict


class TypeMapper:
    """Based on a configuration, map atomic numbers to types.

    At least one of ``chemical_symbols``, ``chemical_symbol_to_type`` or
    ``type_names`` must be given; ``chemical_symbols`` and
    ``chemical_symbol_to_type`` are mutually exclusive.

    When a chemical-symbol mapping is available, two lookup tables are built:

    * ``_Z_to_index``: indexed by ``Z - _min_Z``; holds the type index of each
      mapped atomic number and ``-1`` for every unmapped number in between.
    * ``_index_to_Z``: indexed by type index; holds the atomic number.

    When only ``type_names`` is given, these tables (and ``_min_Z``,
    ``_max_Z``, ``_valid_set``) are not created, so :meth:`transform` and
    :meth:`untransform` are unavailable.
    """

    num_types: int
    chemical_symbol_to_type: Optional[Dict[str, int]]
    type_to_chemical_symbol: Optional[Dict[int, str]]
    type_names: List[str]
    _min_Z: int
    _max_Z: int

    def __init__(
        self,
        type_names: Optional[List[str]] = None,
        chemical_symbol_to_type: Optional[Dict[str, int]] = None,
        type_to_chemical_symbol: Optional[Dict[int, str]] = None,
        chemical_symbols: Optional[List[str]] = None,
    ):
        """Validate the configuration and build the lookup tables.

        Args:
            type_names: alphanumeric name of every type, in type-index order.
                Derived from the chemical symbols when omitted.
            chemical_symbol_to_type: chemical symbol -> type index; the indices
                must be exactly ``0 .. n - 1``.
            type_to_chemical_symbol: type index -> chemical symbol; keys are
                coerced to ``int``. Must agree with ``chemical_symbol_to_type``
                when both are available.
            chemical_symbols: symbols that get type indices assigned in order
                of increasing atomic number.

        Raises:
            ValueError: if both ``chemical_symbols`` and
                ``chemical_symbol_to_type`` are given, or if no type names can
                be determined.
            AssertionError: if one of the consistency checks fails.
        """
        if chemical_symbols is not None:
            if chemical_symbol_to_type is not None:
                raise ValueError(
                    "Cannot provide both `chemical_symbols` and `chemical_symbol_to_type`"
                )
            # repro old, sane NequIP behaviour
            # checks also for validity of keys
            atomic_nums = [ase.data.atomic_numbers[sym] for sym in chemical_symbols]
            # https://stackoverflow.com/questions/29876580/how-to-sort-a-list-according-to-another-list-python
            chemical_symbols = [
                e[1] for e in sorted(zip(atomic_nums, chemical_symbols))
            ]
            chemical_symbol_to_type = {k: i for i, k in enumerate(chemical_symbols)}
            del chemical_symbols

        if type_to_chemical_symbol is not None:
            # Keys may arrive as strings (e.g. from a config file); normalise them.
            type_to_chemical_symbol = {
                int(k): v for k, v in type_to_chemical_symbol.items()
            }
            assert all(
                v in ase.data.chemical_symbols for v in type_to_chemical_symbol.values()
            )

        # Build from chem->type mapping, if provided
        self.chemical_symbol_to_type = chemical_symbol_to_type
        if self.chemical_symbol_to_type is not None:
            # Validate
            for sym, type_idx in self.chemical_symbol_to_type.items():
                assert sym in ase.data.atomic_numbers, f"Invalid chemical symbol {sym}"
                assert 0 <= type_idx, f"Invalid type number {type_idx}"
            assert set(self.chemical_symbol_to_type.values()) == set(
                range(len(self.chemical_symbol_to_type))
            )
            if type_names is None:
                # Make type_names
                type_names = [None] * len(self.chemical_symbol_to_type)
                for sym, type_idx in self.chemical_symbol_to_type.items():
                    type_names[type_idx] = sym
            else:
                # Make sure they agree on types
                # We already checked that chem->type is contiguous,
                # so enough to check length since type_names is a list
                assert len(type_names) == len(self.chemical_symbol_to_type)
            # Make mapper array
            valid_atomic_numbers = [
                ase.data.atomic_numbers[sym] for sym in self.chemical_symbol_to_type
            ]
            self._min_Z = min(valid_atomic_numbers)
            self._max_Z = max(valid_atomic_numbers)
            # -1 marks atomic numbers inside [min_Z, max_Z] that have no type.
            Z_to_index = torch.full(
                size=(1 + self._max_Z - self._min_Z,), fill_value=-1, dtype=torch.long
            )
            for sym, type_idx in self.chemical_symbol_to_type.items():
                Z_to_index[ase.data.atomic_numbers[sym] - self._min_Z] = type_idx
            self._Z_to_index = Z_to_index
            self._index_to_Z = torch.zeros(
                size=(len(self.chemical_symbol_to_type),), dtype=torch.long
            )
            for sym, type_idx in self.chemical_symbol_to_type.items():
                self._index_to_Z[type_idx] = ase.data.atomic_numbers[sym]
            self._valid_set = set(valid_atomic_numbers)
            true_type_to_chemical_symbol = {
                type_id: sym for sym, type_id in self.chemical_symbol_to_type.items()
            }
            if type_to_chemical_symbol is not None:
                assert type_to_chemical_symbol == true_type_to_chemical_symbol
            else:
                type_to_chemical_symbol = true_type_to_chemical_symbol

        # check
        if type_names is None:
            raise ValueError(
                "None of chemical_symbols, chemical_symbol_to_type, nor type_names was provided; exactly one is required"
            )
        # validate type names
        assert all(
            n.isalnum() for n in type_names
        ), "Type names must contain only alphanumeric characters"
        # Set to however many maps specified -- we already checked contiguous
        self.num_types = len(type_names)
        # Check type_names
        self.type_names = type_names
        self.type_to_chemical_symbol = type_to_chemical_symbol
        if self.type_to_chemical_symbol is not None:
            assert set(type_to_chemical_symbol.keys()) == set(range(self.num_types))

    def __call__(
        self,
        data: Union["AtomicDataDict.Type", "AtomicData"],
        types_required: bool = True,
    ) -> Union["AtomicDataDict.Type", "AtomicData"]:
        """Make sure ``data`` carries atom types, converting atomic numbers if needed.

        ``data`` is modified in place and returned. Existing atom types are
        kept as they are (with a warning if atomic numbers are present too).
        Otherwise the atomic numbers are removed and replaced by their types.

        Raises:
            KeyError: if ``data`` has neither field and ``types_required`` is true.
            AssertionError: if atomic numbers must be converted but this mapper
                has no chemical-symbol mapping.
        """
        if AtomicDataDict.ATOM_TYPE_KEY in data:
            if AtomicDataDict.ATOMIC_NUMBERS_KEY in data:
                warnings.warn(
                    "Data contained both ATOM_TYPE_KEY and ATOMIC_NUMBERS_KEY; ignoring ATOMIC_NUMBERS_KEY"
                )
        elif AtomicDataDict.ATOMIC_NUMBERS_KEY in data:
            assert (
                self.chemical_symbol_to_type is not None
            ), "Atomic numbers provided but there is no chemical_symbols/chemical_symbol_to_type mapping!"
            atomic_numbers = data[AtomicDataDict.ATOMIC_NUMBERS_KEY]
            del data[AtomicDataDict.ATOMIC_NUMBERS_KEY]

            data[AtomicDataDict.ATOM_TYPE_KEY] = self.transform(atomic_numbers)
        else:
            if types_required:
                raise KeyError(
                    "Data doesn't contain any atom type information (ATOM_TYPE_KEY or ATOMIC_NUMBERS_KEY)"
                )
        return data

    def transform(self, atomic_numbers):
        """core function to transform an array to specie index list

        Args:
            atomic_numbers: integer tensor of atomic numbers.

        Returns:
            Tensor of type indices with the same shape, on the same device.

        Raises:
            ValueError: if any atomic number has no type, whether it lies
                outside ``[_min_Z, _max_Z]`` or in a gap inside that range.
        """
        in_range = not (
            atomic_numbers.min() < self._min_Z or atomic_numbers.max() > self._max_Z
        )
        if in_range:
            types = self._Z_to_index.to(device=atomic_numbers.device)[
                atomic_numbers - self._min_Z
            ]
            # Unmapped numbers inside the range hit the -1 filler of the table;
            # without this check they would be returned as a bogus type index.
            if not bool((types < 0).any()):
                return types

        bad_set = set(torch.unique(atomic_numbers).cpu().tolist()) - self._valid_set
        raise ValueError(
            f"Data included atomic numbers {bad_set} that are not part of the atomic number -> type mapping!"
        )

    def untransform(self, atom_types):
        """Transform atom types back into atomic numbers"""
        # Move the lookup table to the index tensor's device *before* indexing,
        # as `transform` does; a CPU table cannot be indexed by a tensor that
        # lives on another device.
        return self._index_to_Z.to(device=atom_types.device)[atom_types]

    @property
    def has_chemical_symbols(self) -> bool:
        """Whether a chemical symbol -> type mapping is available."""
        return self.chemical_symbol_to_type is not None

    @staticmethod
    def format(
        data: list, type_names: List[str], element_formatter: str = ".6f"
    ) -> str:
        """Render per-type values as a bracketed, human-readable string.

        Args:
            data: ``None``, a scalar shared by all types, or one value per type.
            type_names: names of the types, in type-index order.
            element_formatter: format spec applied to each value.

        Raises:
            ValueError: if ``data`` is neither a scalar nor a 1-D sequence with
                one entry per type name.
        """
        data = torch.as_tensor(data) if data is not None else None
        if data is None:
            return f"[{', '.join(type_names)}: None]"
        elif data.ndim == 0:
            return (f"[{', '.join(type_names)}: {{:{element_formatter}}}]").format(data)
        elif data.ndim == 1 and len(data) == len(type_names):
            # Builds "{0[0]}: {0[1]:fmt}, {1[0]}: {1[1]:fmt}, ..." and fills it
            # with (name, value) pairs.
            return (
                "["
                + ", ".join(
                    f"{{{i}[0]}}: {{{i}[1]:{element_formatter}}}"
                    for i in range(len(data))
                )
                + "]"
            ).format(*zip(type_names, data))
        else:
            raise ValueError(
                f"Don't know how to format data=`{data}` for types {type_names} with element_formatter=`{element_formatter}`"
            )

### Utils

In [43]:
from typing import List, Optional, Callable, Union, Any, Tuple

import re
import copy
import warnings
import numpy as np
import os.path as osp
import sys
from collections.abc import Sequence

import torch.utils.data
from torch import Tensor

#from .data import Data
from nequip.utils.torch_geometric.utils import makedirs

# Everything besides a plain integer that can be used to index a dataset.
IndexType = Union[slice, Tensor, np.ndarray, Sequence]


class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_dataset.html>`__ for the accompanying tutorial.

    Subclasses provide :meth:`len` and :meth:`get`. A subclass that defines
    ``download`` and/or ``process`` directly in its own class body gets them
    run from the constructor (through :meth:`_download` / :meth:`_process`).

    Args:
        root (string, optional): Root directory where the dataset should be
            saved. (optional: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files to find in the :obj:`self.raw_dir` folder in
        order to skip the download."""
        raise NotImplementedError

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def len(self) -> int:
        r"""Number of stored examples, ignoring any index selection."""
        raise NotImplementedError

    # `Data` is quoted: this cell does not import it, and a bare name in an
    # annotation would have to be bound when the class body is executed.
    def get(self, idx: int) -> "Data":
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(
        self,
        root: Optional[str] = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
    ):
        super().__init__()

        if isinstance(root, str):
            root = osp.expanduser(osp.normpath(root))

        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        # `None` means "all examples in storage order"; see `indices()`.
        self._indices: Optional[Sequence] = None

        # Only hooks defined directly on the concrete class are triggered;
        # inherited implementations are not looked at here.
        if "download" in self.__class__.__dict__.keys():
            self._download()

        if "process" in self.__class__.__dict__.keys():
            self._process()

    def indices(self) -> Sequence:
        r"""Storage indices this (possibly sub-selected) dataset exposes."""
        return range(self.len()) if self._indices is None else self._indices

    @property
    def raw_dir(self) -> str:
        r"""``<root>/raw``."""
        return osp.join(self.root, "raw")

    @property
    def processed_dir(self) -> str:
        r"""``<root>/processed``."""
        return osp.join(self.root, "processed")

    @property
    def num_node_features(self) -> int:
        r"""Returns the number of features per node in the dataset."""
        data = self[0]
        if hasattr(data, "num_node_features"):
            return data.num_node_features
        raise AttributeError(
            f"'{data.__class__.__name__}' object has no "
            f"attribute 'num_node_features'"
        )

    @property
    def num_features(self) -> int:
        r"""Alias for :py:attr:`~num_node_features`."""
        return self.num_node_features

    @property
    def num_edge_features(self) -> int:
        r"""Returns the number of features per edge in the dataset."""
        data = self[0]
        if hasattr(data, "num_edge_features"):
            return data.num_edge_features
        raise AttributeError(
            f"'{data.__class__.__name__}' object has no "
            f"attribute 'num_edge_features'"
        )

    @property
    def raw_paths(self) -> List[str]:
        r"""The filepaths to find in order to skip the download."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self) -> List[str]:
        r"""The filepaths to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        r"""Run :meth:`download` unless every raw file already exists."""
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _process(self):
        r"""Run :meth:`process` unless every processed file already exists.

        Before that, warn if the ``pre_transform`` / ``pre_filter`` recorded
        next to an existing processed dataset differ from the current ones.
        After processing, record the current ones.
        """
        f = osp.join(self.processed_dir, "pre_transform.pt")
        if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
            warnings.warn(
                f"The `pre_transform` argument differs from the one used in "
                f"the pre-processed version of this dataset. If you want to "
                f"make use of another pre-processing technique, make sure to "
                f"sure to delete '{self.processed_dir}' first"
            )

        f = osp.join(self.processed_dir, "pre_filter.pt")
        if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
            # The last fragment needs the `f` prefix, otherwise the message
            # shows the literal text "{self.processed_dir}".
            warnings.warn(
                "The `pre_filter` argument differs from the one used in the "
                "pre-processed version of this dataset. If you want to make "
                "use of another pre-fitering technique, make sure to delete "
                f"'{self.processed_dir}' first"
            )

        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print("Processing dataset...", file=sys.stderr)

        makedirs(self.processed_dir)
        self.process()

        path = osp.join(self.processed_dir, "pre_transform.pt")
        torch.save(_repr(self.pre_transform), path)
        path = osp.join(self.processed_dir, "pre_filter.pt")
        torch.save(_repr(self.pre_filter), path)

        print("Done!", file=sys.stderr)

    def __len__(self) -> int:
        r"""The number of examples in the dataset."""
        return len(self.indices())

    def __getitem__(
        self,
        idx: Union[int, np.integer, IndexType],
    ) -> Union["Dataset", "Data"]:
        r"""In case :obj:`idx` is of type integer, will return the data object
        at index :obj:`idx` (and transforms it in case :obj:`transform` is
        present).
        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
        tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy
        :obj:`np.array`, will return a subset of the dataset at the specified
        indices."""
        if (
            isinstance(idx, (int, np.integer))
            or (isinstance(idx, Tensor) and idx.dim() == 0)
            or (isinstance(idx, np.ndarray) and np.isscalar(idx))
        ):

            data = self.get(self.indices()[idx])
            data = data if self.transform is None else self.transform(data)
            return data

        else:
            return self.index_select(idx)

    def index_select(self, idx: IndexType) -> "Dataset":
        r"""Return a shallow copy of the dataset restricted to :obj:`idx`.

        Tensors and arrays are converted to a list of integer positions first
        (boolean masks via their non-zero positions), then handled like any
        other sequence.

        Raises:
            IndexError: if :obj:`idx` is of an unsupported type or dtype.
        """
        indices = self.indices()

        if isinstance(idx, slice):
            indices = indices[idx]

        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            idx = idx.flatten().nonzero(as_tuple=False)
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            return self.index_select(idx.flatten().tolist())

        # `np.bool_` rather than the `np.bool` alias, which NumPy 1.24 removed.
        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
            idx = idx.flatten().nonzero()[0]
            return self.index_select(idx.flatten().tolist())

        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            indices = [indices[i] for i in idx]

        else:
            raise IndexError(
                f"Only integers, slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')"
            )

        # Shallow copy: the selection shares all other attributes with `self`.
        dataset = copy.copy(self)
        dataset._indices = indices
        return dataset

    def shuffle(
        self,
        return_perm: bool = False,
    ) -> Union["Dataset", Tuple["Dataset", Tensor]]:
        r"""Randomly shuffles the examples in the dataset.

        Args:
            return_perm (bool, optional): If set to :obj:`True`, will return
                the random permutation used to shuffle the dataset in addition.
                (default: :obj:`False`)
        """
        perm = torch.randperm(len(self))
        dataset = self.index_select(perm)
        return (dataset, perm) if return_perm is True else dataset

    def __repr__(self) -> str:
        arg_repr = str(len(self)) if len(self) > 1 else ""
        return f"{self.__class__.__name__}({arg_repr})"


def to_list(value: Any) -> Sequence:
    """Return ``value`` unchanged if it is a non-string sequence, else wrap it in a list."""
    needs_wrapping = isinstance(value, str) or not isinstance(value, Sequence)
    if needs_wrapping:
        return [value]
    return value


def files_exist(files: List[str]) -> bool:
    """Return ``True`` only if ``files`` is non-empty and every path in it exists."""
    # NOTE: An empty `files` yields `False`, which leads to a re-processing
    # of files on every instantiation.
    if len(files) == 0:
        return False
    found = [osp.exists(path) for path in files]
    return all(found)


def _repr(obj: Any) -> str:
    if obj is None:
        return "None"
    return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__())

In [44]:
from typing import List

from collections.abc import Sequence

import torch
import numpy as np
from torch import Tensor

#from .data import Data
#from .dataset import IndexType


class Batch(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.

    Batches built with :meth:`from_data_list` also remember, per key, where
    each example starts (``__slices__``), which offset was added to its values
    (``__cumsum__``) and along which dimension it was concatenated
    (``__cat_dims__``), so that the examples can be recovered again.
    """

    def __init__(self, batch=None, ptr=None, **kwargs):
        super(Batch, self).__init__(**kwargs)

        for key, item in kwargs.items():
            if key == "num_nodes":
                self.__num_nodes__ = item
            else:
                self[key] = item

        self.batch = batch
        self.ptr = ptr
        # Bookkeeping that only `from_data_list` fills in.
        self.__data_class__ = Data
        self.__slices__ = None
        self.__cumsum__ = None
        self.__cat_dims__ = None
        self.__num_nodes_list__ = None
        self.__num_graphs__ = None

    # The list defaults below are only read, never mutated.
    @classmethod
    def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`."""

        # The first example decides which keys are batched.
        keys = list(set(data_list[0].keys) - set(exclude_keys))
        assert "batch" not in keys and "ptr" not in keys

        batch = cls()
        for key in data_list[0].__dict__.keys():
            if key[:2] != "__" and key[-2:] != "__":
                batch[key] = None

        batch.__num_graphs__ = len(data_list)
        batch.__data_class__ = data_list[0].__class__
        for key in keys + ["batch"]:
            batch[key] = []
        batch["ptr"] = [0]

        device = None
        slices = {key: [0] for key in keys}
        cumsum = {key: [0] for key in keys}
        cat_dims = {}
        num_nodes_list = []
        for i, data in enumerate(data_list):
            for key in keys:
                item = data[key]

                # Increase values by `cumsum` value.
                cum = cumsum[key][-1]
                if isinstance(item, Tensor) and item.dtype != torch.bool:
                    if not isinstance(cum, int) or cum != 0:
                        item = item + cum
                elif isinstance(item, (int, float)):
                    item = item + cum

                # Gather the size of the `cat` dimension.
                size = 1
                cat_dim = data.__cat_dim__(key, data[key])
                # 0-dimensional tensors have no dimension along which to
                # concatenate, so we set `cat_dim` to `None`.
                if isinstance(item, Tensor) and item.dim() == 0:
                    cat_dim = None
                cat_dims[key] = cat_dim

                # Add a batch dimension to items whose `cat_dim` is `None`:
                if isinstance(item, Tensor) and cat_dim is None:
                    cat_dim = 0  # Concatenate along this new batch dimension.
                    item = item.unsqueeze(0)
                    device = item.device
                elif isinstance(item, Tensor):
                    size = item.size(cat_dim)
                    device = item.device

                batch[key].append(item)  # Append item to the attribute list.

                slices[key].append(size + slices[key][-1])
                inc = data.__inc__(key, item)
                if isinstance(inc, (tuple, list)):
                    inc = torch.tensor(inc)
                cumsum[key].append(inc + cumsum[key][-1])

                if key in follow_batch:
                    if isinstance(size, Tensor):
                        for j, size in enumerate(size.tolist()):
                            tmp = f"{key}_{j}_batch"
                            batch[tmp] = [] if i == 0 else batch[tmp]
                            batch[tmp].append(
                                torch.full((size,), i, dtype=torch.long, device=device)
                            )
                    else:
                        tmp = f"{key}_batch"
                        batch[tmp] = [] if i == 0 else batch[tmp]
                        batch[tmp].append(
                            torch.full((size,), i, dtype=torch.long, device=device)
                        )

            if hasattr(data, "__num_nodes__"):
                num_nodes_list.append(data.__num_nodes__)
            else:
                num_nodes_list.append(None)

            # One graph id per node, plus the running node count in `ptr`.
            num_nodes = data.num_nodes
            if num_nodes is not None:
                item = torch.full((num_nodes,), i, dtype=torch.long, device=device)
                batch.batch.append(item)
                batch.ptr.append(batch.ptr[-1] + num_nodes)

        batch.batch = None if len(batch.batch) == 0 else batch.batch
        batch.ptr = None if len(batch.ptr) == 1 else batch.ptr
        batch.__slices__ = slices
        batch.__cumsum__ = cumsum
        batch.__cat_dims__ = cat_dims
        batch.__num_nodes_list__ = num_nodes_list

        # Collapse the per-example lists into single tensors.
        ref_data = data_list[0]
        for key in batch.keys:
            items = batch[key]
            if None in items:
                raise ValueError(
                    f"Found a `None` in the provided data objects for batching in key `{key}`"
                )
            item = items[0]
            cat_dim = ref_data.__cat_dim__(key, item)
            cat_dim = 0 if cat_dim is None else cat_dim
            if isinstance(item, Tensor):
                batch[key] = torch.cat(items, cat_dim)
            elif isinstance(item, (int, float)):
                batch[key] = torch.tensor(items)

        # if torch_geometric.is_debug_enabled():
        #     batch.debug()

        return batch.contiguous()

    def get_example(self, idx: int) -> Data:
        r"""Reconstructs the :class:`torch_geometric.data.Data` object at index
        :obj:`idx` from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able to reconstruct the initial objects."""

        if self.__slices__ is None:
            raise RuntimeError(
                (
                    "Cannot reconstruct data list from batch because the batch "
                    "object was not created using `Batch.from_data_list()`."
                )
            )

        data = self.__data_class__()
        # Negative indices count from the end.
        idx = self.num_graphs + idx if idx < 0 else idx

        for key in self.__slices__.keys():
            item = self[key]
            if self.__cat_dims__[key] is None:
                # The item was concatenated along a new batch dimension,
                # so just index in that dimension:
                item = item[idx]
            else:
                # Narrow the item based on the values in `__slices__`.
                if isinstance(item, Tensor):
                    dim = self.__cat_dims__[key]
                    start = self.__slices__[key][idx]
                    end = self.__slices__[key][idx + 1]
                    item = item.narrow(dim, start, end - start)
                else:
                    start = self.__slices__[key][idx]
                    end = self.__slices__[key][idx + 1]
                    item = item[start:end]
                    item = item[0] if len(item) == 1 else item

            # Decrease its value by `cumsum` value:
            cum = self.__cumsum__[key][idx]
            if isinstance(item, Tensor):
                if not isinstance(cum, int) or cum != 0:
                    item = item - cum
            elif isinstance(item, (int, float)):
                item = item - cum

            data[key] = item

        if self.__num_nodes_list__[idx] is not None:
            data.num_nodes = self.__num_nodes_list__[idx]

        return data

    def index_select(self, idx: IndexType) -> List[Data]:
        r"""Reconstructs the examples selected by :obj:`idx` as a list.

        Raises:
            IndexError: if :obj:`idx` is of an unsupported type or dtype.
        """
        if isinstance(idx, slice):
            idx = list(range(self.num_graphs)[idx])

        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            idx = idx.flatten().tolist()

        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()

        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            idx = idx.flatten().tolist()

        # `np.bool_` rather than the `np.bool` alias, which NumPy 1.24 removed.
        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
            idx = idx.flatten().nonzero()[0].flatten().tolist()

        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            pass

        else:
            raise IndexError(
                f"Only integers, slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')"
            )

        return [self.get_example(i) for i in idx]

    def __getitem__(self, idx):
        r"""String keys read an attribute, integers return one example, and
        anything else is passed to :meth:`index_select`."""
        if isinstance(idx, str):
            return super(Batch, self).__getitem__(idx)
        elif isinstance(idx, (int, np.integer)):
            return self.get_example(idx)
        else:
            return self.index_select(idx)

    def to_data_list(self) -> List[Data]:
        r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects
        from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able to reconstruct the initial objects."""
        return [self.get_example(i) for i in range(self.num_graphs)]

    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the batch."""
        # Prefer the recorded count, then derive it from `ptr`, then `batch`.
        if self.__num_graphs__ is not None:
            return self.__num_graphs__
        elif self.ptr is not None:
            return self.ptr.numel() - 1
        elif self.batch is not None:
            return int(self.batch.max()) + 1
        else:
            raise ValueError

### Dataset

In [45]:
import numpy as np
import logging
import inspect
import itertools
import yaml
import hashlib
import math
from typing import Tuple, Dict, Any, List, Callable, Union, Optional

import torch

from torch_runstats.scatter import scatter_std, scatter_mean

#from nequip.utils.torch_geometric import Batch, Dataset
from nequip.utils.torch_geometric.utils import download_url, extract_zip

#import nequip
#from nequip.data import (
#    AtomicData,
#    AtomicDataDict,
#    _NODE_FIELDS,
#    _EDGE_FIELDS,
#    _GRAPH_FIELDS,
#)
from nequip.utils.batch_ops import bincount
from nequip.utils.regressor import solver
from nequip.utils.savenload import atomic_write
#from ..transforms import TypeMapper


class AtomicDataset(Dataset):
    """The base class for all NequIP datasets.

    Stores the default torch dtype at construction time and forwards ``root``
    and the type mapper (as ``transform``) to the parent ``Dataset``. It also
    derives a parameter-dependent directory name for the processed cache.
    """

    root: str
    dtype: torch.dtype

    def __init__(
        self,
        root: str,
        type_mapper: Optional[TypeMapper] = None,
    ):
        """Record the default dtype, then initialize the parent ``Dataset``.

        Args:
            root: passed to the parent initializer as ``root``.
            type_mapper: passed to the parent initializer as ``transform``.
        """
        self.dtype = torch.get_default_dtype()
        super().__init__(root=root, transform=type_mapper)

    def statistics(
        self,
        fields: List[Union[str, Callable]],
        modes: List[str],
        stride: int = 1,
        unbiased: bool = True,
        kwargs: Optional[Dict[str, dict]] = {},
    ) -> List[tuple]:
        """Placeholder for dataset statistics; always raises here.

        Raises:
            NotImplementedError: always, for the general ``AtomicDataset``.
        """
        # TODO: If needed, this can eventually be implemented for general AtomicDataset
        # by computing an online running mean and using Welford's method
        # for a stable running standard deviation:
        # https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/
        # That would be needed if we have lazy loading datasets.
        # TODO: When lazy-loading datasets are implemented, how to deal with statistics, sampling, and subsets?
        raise NotImplementedError("not implimented for general AtomicDataset yet")

    @property
    def type_mapper(self) -> Optional[TypeMapper]:
        """The transform handed to the parent ``Dataset`` (the type mapper)."""
        # self.transform is always a TypeMapper
        return self.transform

    def _get_parameters(self) -> Dict[str, Any]:
        """Get a dict of the parameters used to build this dataset.

        Collects every ``__init__`` parameter name that also exists as an
        attribute on ``self`` (except ``type_mapper``), then adds the dtype
        string and the installed nequip version.
        """
        # Imported here because this notebook cell only does
        # `from nequip.<submodule> import ...`, which does not bind the name
        # `nequip` itself (the plain `import nequip` is commented out).
        import nequip

        pnames = list(inspect.signature(self.__init__).parameters)
        IGNORE_KEYS = {
            # the type mapper is applied after saving, not before, so doesn't matter to cache validity
            "type_mapper"
        }
        params = {
            k: getattr(self, k)
            for k in pnames
            if k not in IGNORE_KEYS and hasattr(self, k)
        }
        # Add other relevant metadata:
        params["dtype"] = str(self.dtype)
        params["nequip_version"] = nequip.__version__
        return params

    @property
    def processed_dir(self) -> str:
        """Directory for the processed cache, named after a hash of the parameters."""
        # We want the file name to change when the parameters change
        # So, first we get all parameters:
        params = self._get_parameters()
        # Make some kind of string of them:
        # we don't care about this possibly changing between python versions,
        # since a change in python version almost certainly means a change in
        # versions of other things too, and is a good reason to recompute
        buffer = yaml.dump(params).encode("ascii")
        # And hash it:
        param_hash = hashlib.sha1(buffer).hexdigest()
        return f"{self.root}/processed_dataset_{param_hash}"


class AtomicInMemoryDataset(AtomicDataset):
    r"""Base class for all datasets that fit in memory.

    Please note that, as a ``pytorch_geometric`` dataset, it must be backed by some kind of disk storage.
    By default, the raw file will be stored at root/raw and the processed torch
    file will be at root/process.

    Subclasses must implement:
     - ``raw_file_names``
     - ``get_data()``

    Subclasses may implement:
     - ``download()`` or ``self.url`` or ``ClassName.URL``

    Args:
        root (str, optional): Root directory where the dataset should be saved. 
        Defaults to current working directory.
        file_name (str, optional): file name of data source. only used in children class
        url (str, optional): url to download data source
        AtomicData_options (dict, optional): extra key that are not stored in data 
        but needed for AtomicData initialization
        include_frames (list, optional): the frames to process with the constructor.
        type_mapper (TypeMapper): the transformation to map atomic information to species index. Optional
    """

    def __init__(
        self,
        root: str,
        file_name: Optional[str] = None,
        url: Optional[str] = None,
        AtomicData_options: Optional[Dict[str, Any]] = None,
        include_frames: Optional[List[int]] = None,
        type_mapper: Optional[TypeMapper] = None,
    ):
        """Store the dataset options, run the parent initializer, then load the cache.

        Args:
            root: forwarded to the parent initializer.
            file_name: name of the data source; falls back to the class
                attribute ``FILE_NAME`` when ``None``.
            url: download URL; a class attribute ``URL`` takes precedence when
                the class defines one.
            AtomicData_options: extra options for building ``AtomicData``;
                ``None`` (the default) means an empty dict.
            include_frames: the frames to keep, or ``None``.
            type_mapper: forwarded to the parent initializer.

        Raises:
            ValueError: if the ``include_frames`` stored in the cached file
                differs from the one given here.
        """
        # TO DO, this may be simplified
        # See if a subclass defines some inputs
        self.file_name = (
            getattr(type(self), "FILE_NAME", None) if file_name is None else file_name
        )
        self.url = getattr(type(self), "URL", url)

        # Use a fresh dict per instance: a `{}` default in the signature would be
        # a single object shared by every dataset built without explicit options.
        self.AtomicData_options = (
            {} if AtomicData_options is None else AtomicData_options
        )
        self.include_frames = include_frames

        self.data = None

        # !!! don't delete this block.
        # otherwise the inherent children class
        # will ignore the download function here
        class_type = type(self)
        if class_type != AtomicInMemoryDataset:
            if "download" not in self.__class__.__dict__:
                class_type.download = AtomicInMemoryDataset.download
            if "process" not in self.__class__.__dict__:
                class_type.process = AtomicInMemoryDataset.process

        # Initialize the InMemoryDataset, which runs download and process
        # See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets
        # Then pre-process the data if disk files are not found
        super().__init__(root=root, type_mapper=type_mapper)
        if self.data is None:
            # `process()` did not run in this session, so read the cached file
            self.data, include_frames = torch.load(self.processed_paths[0])
            if not np.all(include_frames == self.include_frames):
                raise ValueError(
                    f"the include_frames is changed. "
                    f"please delete the processed folder and rerun {self.processed_paths[0]}"
                )

    def len(self):
        if self.data is None:
            return 0
        return self.data.num_graphs

    @property
    def raw_file_names(self):
        raise NotImplementedError()

    @property
    def processed_file_names(self) -> List[str]:
        return ["data.pth", "params.yaml"]

    def get_data(
        self,
    ) -> Union[Tuple[Dict[str, Any], Dict[str, Any]], List[AtomicData]]:
        """Get the data --- called from ``process()``, can assume that ``raw_file_names()`` exist.

        Note that parameters for graph construction such as ``pbc`` and ``r_max`` should be included here as (likely, but not necessarily, fixed) fields.

        Returns:
        A dict:
            fields: dict
                mapping a field name ('pos', 'cell') to a list-like sequence of tensor-like objects giving that field's value for each example.
        Or:
            data_list: List[AtomicData]
        """
        raise NotImplementedError

    def download(self):
        if (not hasattr(self, "url")) or (self.url is None):
            # Don't download, assume present. 
            # Later could have FileNotFound if the files don't actually exist
            pass
        else:
            download_path = download_url(self.url, self.raw_dir)
            if download_path.endswith(".zip"):
                extract_zip(download_path, self.raw_dir)

    def process(self):
        """Build the batched dataset from ``get_data()`` and cache it to disk.

        ``get_data()`` may return either a list of ``AtomicData`` or a dict of
        per-example fields. Both are turned into a list of ``AtomicData``
        (restricted to ``self.include_frames`` when that is set), batched with
        ``Batch.from_data_list``, saved together with ``self.include_frames``
        to ``self.processed_paths[0]``, and the dataset parameters are dumped
        as YAML to ``self.processed_paths[1]``. Finally ``self.data`` is set to
        the batch.

        Raises:
            ValueError: if the fields have differing numbers of examples, or if
                ``get_data()`` returns neither a list nor a dict.
        """
        data = self.get_data()
        if isinstance(data, list):

            # It's a data list
            data_list = data
            # keep only the requested frames, in the order given by include_frames
            if not (self.include_frames is None or data_list is None):
                data_list = [data_list[i] for i in self.include_frames]
            assert all(isinstance(e, AtomicData) for e in data_list)
            assert all(AtomicDataDict.BATCH_KEY not in e for e in data_list)

            # defined so that the `del fields` below works for both branches
            fields = {}

        elif isinstance(data, dict):
            # It's fields
            # Get our data
            fields = data

            # check keys
            all_keys = set(fields.keys())
            assert AtomicDataDict.BATCH_KEY not in all_keys
            # Check bad key combinations, but don't require that this be a graph yet.
            AtomicDataDict.validate_keys(all_keys, graph_required=False)

            # check dimensionality
            num_examples = set([len(a) for a in fields.values()])
            if not len(num_examples) == 1:
                shape_dict = {f: v.shape for f, v in fields.items()}
                raise ValueError(
                    f"This dataset is invalid: expected all fields to have same length (same number of examples), but they had shapes {shape_dict}"
                )
            # the set holds exactly one value here: the common number of examples
            num_examples = next(iter(num_examples))

            include_frames = self.include_frames
            if include_frames is None:
                include_frames = range(num_examples)

            # Make AtomicData from it:
            if AtomicDataDict.EDGE_INDEX_KEY in all_keys:
                # This is already a graph, just build it
                constructor = AtomicData
            else:
                # do neighborlist from points
                constructor = AtomicData.from_points
                assert "r_max" in self.AtomicData_options
                assert AtomicDataDict.POSITIONS_KEY in all_keys

            # one AtomicData per included frame; AtomicData_options entries
            # override same-named fields because they are unpacked last
            data_list = [
                constructor(
                    **{
                        **{f: v[i] for f, v in fields.items()},
                        **self.AtomicData_options,
                    }
                )
                for i in include_frames
            ]

        else:
            raise ValueError("Invalid return from `self.get_data()`")

        # Batch it for efficient saving
        # This limits an AtomicInMemoryDataset to a maximum of LONG_MAX atoms _overall_, but that is a very big number and any dataset that large is probably not "InMemory" anyway
        data = Batch.from_data_list(data_list)
        del data_list
        del fields

        # size estimate for the log message only
        total_MBs = sum(item.numel() * item.element_size() for _, item in data) / (
            1024 * 1024
        )
        logging.info(
            f"Loaded data: {data}\n    processed data size: ~{total_MBs:.2f} MB"
        )
        del total_MBs

        # use atomic writes to avoid race conditions between
        # different trainings that use the same dataset
        # since those separate trainings should all produce the same results,
        # it doesn't matter if they overwrite each others cached'
        # datasets. It only matters that they don't simultaneously try
        # to write the _same_ file, corrupting it.
        with atomic_write(self.processed_paths[0], binary=True) as f:
            torch.save((data, self.include_frames), f)
        with atomic_write(self.processed_paths[1], binary=False) as f:
            yaml.dump(self._get_parameters(), f)

        logging.info("Cached processed data to disk")

        self.data = data

    def get(self, idx):
        return self.data.get_example(idx)

    def _selectors(
        self,
        stride: int = 1,
    ):
        """Build the graph, node and edge selectors for a strided view of the data.

        Args:
            stride: step used when taking every ``stride``-th graph.

        Returns:
            A tuple ``(graph_selector, node_selector, edge_selector)``:
            ``graph_selector`` holds the indexes of the selected graphs (sorted
            when taken from ``self._indices``); ``node_selector`` is a boolean
            mask over ``self.data.batch`` marking nodes of selected graphs;
            ``edge_selector`` is a boolean mask marking edges whose two
            endpoints are both selected.
        """
        if self._indices is not None:
            graph_selector = torch.as_tensor(self._indices)[::stride]
            # note that self._indices is _not_ necessarily in order,
            # while self.data --- which we take our arrays from ---
            # is always in the original order.
            # In particular, the values of `self.data.batch`
            # are indexes in the ORIGINAL order
            # thus we need graph level properties to also be in the original order
            # so that batch values index into them correctly
            # since self.data.batch is always sorted & contiguous
            # (because of Batch.from_data_list)
            # we sort it:
            graph_selector, _ = torch.sort(graph_selector)
        else:
            graph_selector = torch.arange(0, self.len(), stride)

        # `np.isin` is the supported replacement for `np.in1d`, which is
        # deprecated since NumPy 2.0. `ravel()` reproduces the flattening that
        # `in1d` applied to its first argument, so the mask is unchanged.
        node_selector = torch.as_tensor(
            np.isin(self.data.batch.numpy().ravel(), graph_selector.numpy())
        )

        edge_index = self.data[AtomicDataDict.EDGE_INDEX_KEY]
        # an edge is kept only if both of its endpoints are kept
        edge_selector = node_selector[edge_index[0]] & node_selector[edge_index[1]]

        return (graph_selector, node_selector, edge_selector)

    def statistics(
        self,
        fields: List[Union[str, Callable]],
        modes: List[str],
        stride: int = 1,
        unbiased: bool = True,
        kwargs: Optional[Dict[str, dict]] = {},
    ) -> List[tuple]:
        """Compute the statistics of ``fields`` in the dataset.

        If the values at the fields are vectors/multidimensional, they must be of fixed shape and elementwise
        statistics will be computed.

        Args:
            fields: the names of the fields to compute statistics for.
                Instead of a field name, a callable can also be given that reuturns a quantity 
                to compute the statisics for.

                If a callable is given, it will be called with a (possibly batched) ``Data``-like object 
                and must return a sequence of points to add to the set over which the statistics 
                will be computed.
                The callable must also return a string, one of ``"node"`` or ``"graph"``, 
                indicating whether the value it returns is a per-node or per-graph quantity.
                PLEASE NOTE: the argument to the callable may be "batched", 
                and it may not be batched "contiguously": ``batch`` and ``edge_index`` 
                may have "gaps" in their values.

                For example, to compute the overall statistics of the x,y, and z components 
                of a per-node vector ``force`` field:

                    data.statistics([lambda data: (data.force.flatten(), "node")])

                The above computes the statistics over a set of size 3N, where N is the total number of nodes
                in the dataset.

            modes: the statistic to compute for each field. Valid options are:
                 - ``count``
                 - ``rms``
                 - ``mean_std``
                 - ``per_atom_*``
                 - ``per_species_*``

            stride: the stride over the dataset while computing statistcs.

            unbiased: whether to use unbiased for standard deviations.

            kwargs: other options for individual statistics modes.

        Returns:
            List of statistics. For fields of floating dtype the statistics are the two-tuple (mean, std); for fields of integer dtype the statistics are a one-tuple (bincounts,)
        """

        # Short circut:
        assert len(modes) == len(fields)
        if len(fields) == 0:
            return []

        graph_selector, node_selector, edge_selector = self._selectors(stride=stride)

        num_graphs = len(graph_selector)
        num_nodes = node_selector.sum()
        num_edges = edge_selector.sum()

        if self.transform is not None:
            # pre-transform the data so that statistics process transformed data
            data_transformed = self.transform(self.data.to_dict(), types_required=False)
        else:
            data_transformed = self.data.to_dict()
        # pre-select arrays
        # this ensures that all following computations use the right data
        all_keys = set()
        selectors = {}
        for k in data_transformed.keys():
            all_keys.add(k)
            if k in _NODE_FIELDS:
                selectors[k] = node_selector
            elif k in _GRAPH_FIELDS:
                selectors[k] = graph_selector
            elif k == AtomicDataDict.EDGE_INDEX_KEY:
                selectors[k] = (slice(None, None, None), edge_selector)
            elif k in _EDGE_FIELDS:
                selectors[k] = edge_selector
        # TODO: do the batch indexes, edge_indexes, etc. after selection need to be
        # "compacted" to subtract out their offsets? For now, we just punt this
        # onto the writer of the callable field.
        # apply selector to actual data
        data_transformed = {
            k: data_transformed[k][selectors[k]]
            for k in data_transformed.keys()
            if k in selectors
        }

        atom_types: Optional[torch.Tensor] = None
        out: list = []
        for ifield, field in enumerate(fields):
            if callable(field):
                # make a joined thing? so it includes fixed fields
                arr, arr_is_per = field(data_transformed)
                arr = arr.to(self.dtype)  # all statistics must be on floating
                assert arr_is_per in ("node", "graph", "edge")
            else:
                if field not in all_keys:
                    raise RuntimeError(
                        f"The field key `{field}` is not present in this dataset"
                    )
                if field not in selectors:
                    # this means field is not selected and so not available
                    raise RuntimeError(
                        f"Only per-node and per-graph fields can have statistics computed; `{field}` has not been registered as either. If it is per-node or per-graph, please register it as such"
                    )
                arr = data_transformed[field]
                if field in _NODE_FIELDS:
                    arr_is_per = "node"
                elif field in _GRAPH_FIELDS:
                    arr_is_per = "graph"
                elif field in _EDGE_FIELDS:
                    arr_is_per = "edge"
                else:
                    raise RuntimeError

            # Check arr
            if arr is None:
                raise ValueError(
                    f"Cannot compute statistics over field `{field}` whose value is None!"
                )
            if not isinstance(arr, torch.Tensor):
                if np.issubdtype(arr.dtype, np.floating):
                    arr = torch.as_tensor(arr, dtype=self.dtype)
                else:
                    arr = torch.as_tensor(arr)
            if arr_is_per == "node":
                arr = arr.view(num_nodes, -1)
            elif arr_is_per == "graph":
                arr = arr.view(num_graphs, -1)
            elif arr_is_per == "edge":
                arr = arr.view(num_edges, -1)

            ana_mode = modes[ifield]
            # compute statistics
            if ana_mode == "count":
                # count integers
                uniq, counts = torch.unique(
                    torch.flatten(arr), return_counts=True, sorted=True
                )
                out.append((uniq, counts))
            elif ana_mode == "rms":
                # root-mean-square
                out.append((torch.sqrt(torch.mean(arr * arr)),))

            elif ana_mode == "mean_std":
                # mean and std
                if len(arr) < 2:
                    raise ValueError(
                        "Can't do per species standard deviation without at least two samples"
                    )
                mean = torch.mean(arr, dim=0)
                std = torch.std(arr, dim=0, unbiased=unbiased)
                out.append((mean, std))

            elif ana_mode == "absmax":
                out.append((arr.abs().max(),))

            elif ana_mode.startswith("per_species_"):
                # per-species
                algorithm_kwargs = kwargs.pop(field + ana_mode, {})

                ana_mode = ana_mode[len("per_species_") :]

                if atom_types is None:
                    atom_types = data_transformed[AtomicDataDict.ATOM_TYPE_KEY]

                results = self._per_species_statistics(
                    ana_mode,
                    arr,
                    arr_is_per=arr_is_per,
                    batch=data_transformed[AtomicDataDict.BATCH_KEY],
                    atom_types=atom_types,
                    unbiased=unbiased,
                    algorithm_kwargs=algorithm_kwargs,
                )
                out.append(results)

            elif ana_mode.startswith("per_atom_"):
                # per-atom
                # only makes sense for a per-graph quantity
                if arr_is_per != "graph":
                    raise ValueError(
                        f"It doesn't make sense to ask for `{ana_mode}` since `{field}` is not per-graph"
                    )
                ana_mode = ana_mode[len("per_atom_") :]
                results = self._per_atom_statistics(
                    ana_mode=ana_mode,
                    arr=arr,
                    batch=data_transformed[AtomicDataDict.BATCH_KEY],
                    unbiased=unbiased,
                )
                out.append(results)

            else:
                raise NotImplementedError(f"Cannot handle statistics mode {ana_mode}")

        return out

    @staticmethod
    def _per_atom_statistics(
        ana_mode: str,
        arr: torch.Tensor,
        batch: torch.Tensor,
        unbiased: bool = True,
    ):
        """Compute "per-atom" statistics that are normalized by the number of atoms in the system.

        Only makes sense for a graph-level quantity (checked by .statistics).
        """
        # using unique_consecutive handles the non-contiguous selected batch index
        _, N = torch.unique_consecutive(batch, return_counts=True)
        N = N.unsqueeze(-1)
        assert N.ndim == 2
        assert N.shape == (len(arr), 1)
        assert arr.ndim >= 2
        data_dim = arr.shape[1:]
        arr = arr / N
        assert arr.shape == (len(N),) + data_dim
        if ana_mode == "mean_std":
            if len(arr) < 2:
                raise ValueError(
                    "Can't do standard deviation without at least two samples"
                )
            mean = torch.mean(arr, dim=0)
            std = torch.std(arr, unbiased=unbiased, dim=0)
            return mean, std
        elif ana_mode == "rms":
            return (torch.sqrt(torch.mean(arr.square())),)
        elif ana_mode == "absmax":
            return (torch.max(arr.abs()),)
        else:
            raise NotImplementedError(
                f"{ana_mode} for per-atom analysis is not implemented"
            )

    def _per_species_statistics(
        self,
        ana_mode: str,
        arr: torch.Tensor,
        arr_is_per: str,
        atom_types: torch.Tensor,
        batch: torch.Tensor,
        unbiased: bool = True,
        algorithm_kwargs: Optional[dict] = {},
    ):
        """Compute "per-species" statistics.

        For a graph-level quantity, the per-graph counts of each atom type and
        ``arr`` are handed to ``solver`` (only ``mean_std`` is supported),
        modelling the quantity as a linear combination of those counts.

        For a per-node quantity, the requested statistic (``mean_std`` or
        ``rms``) is computed separately for each atom type instead of over all
        nodes.

        Raises:
            NotImplementedError: for an unsupported ``ana_mode`` or ``arr_is_per``.
            ValueError: for per-node ``mean_std`` when some type occurs fewer
                than twice in the selected data.
        """
        type_counts = bincount(atom_types.squeeze(-1), batch)
        assert type_counts.ndim == 2  # [batch, n_type]
        # deal with non-contiguous batch indexes
        type_counts = type_counts[(type_counts > 0).any(dim=1)]
        assert arr.ndim >= 2

        if arr_is_per not in ("graph", "node"):
            raise NotImplementedError

        if arr_is_per == "graph":
            if ana_mode != "mean_std":
                raise NotImplementedError(
                    f"{ana_mode} for per species analysis is not implemented for shape {arr.shape}"
                )
            return solver(type_counts.type(self.dtype), arr, **algorithm_kwargs)

        # per-node quantity from here on
        arr = arr.type(self.dtype)

        if ana_mode == "rms":
            mean_square = scatter_mean(arr.square(), atom_types, dim=0)
            # [N, dims] -> [type, dims]
            assert mean_square.shape[1:] == arr.shape[1:]
            assert len(mean_square) == type_counts.shape[1]
            # average away every trailing dimension, leaving one value per type
            for _ in range(mean_square.ndim - 1):
                mean_square = mean_square.mean(axis=-1)
            return (torch.sqrt(mean_square),)

        if ana_mode != "mean_std":
            raise NotImplementedError(
                f"Statistics mode {ana_mode} isn't yet implemented for per_species_"
            )

        # There need to be at least two occurrences of each atom type in the
        # WHOLE dataset, not in any given frame:
        if torch.any(type_counts.sum(dim=0) < 2):
            raise ValueError(
                "Can't do per species standard deviation without at least two samples per species"
            )
        mean = scatter_mean(arr, atom_types, dim=0)
        assert mean.shape[1:] == arr.shape[1:]  # [N, dims] -> [type, dims]
        assert len(mean) == type_counts.shape[1]
        std = scatter_std(arr, atom_types, dim=0, unbiased=unbiased)
        assert std.shape == mean.shape
        return mean, std

    def rdf(
        self, bin_width: float, stride: int = 1
    ) -> Dict[Tuple[int, int], Tuple[np.ndarray, np.ndarray]]:
        """Compute the pairwise RDFs of the dataset.

        Edge lengths are histogrammed once per unordered pair of types, with
        weights that are 1 for selected edges of that (center, neighbor) type
        pair and 0 otherwise.

        Args:
            bin_width: width of the histogram bin in distance units
            stride: stride of data to include

        Returns:
            dictionary mapping `(type1, type2)` to tuples of `(hist, bin_edges)` in the style of `np.histogram`.
            Both orderings of a pair map to the same tuple.
        """
        _, _, edge_selector = self._selectors(stride=stride)

        data = AtomicData.to_AtomicDataDict(self.data)
        data = AtomicDataDict.with_edge_vectors(data, with_lengths=True)

        types = self.type_mapper(data)[AtomicDataDict.ATOM_TYPE_KEY]

        # look up the type of both endpoints of every edge: [2, n_edge]
        flat_edge_index = data[AtomicDataDict.EDGE_INDEX_KEY].reshape(-1)
        edge_types = torch.index_select(types, 0, flat_edge_index).view(2, -1)
        types_center = edge_types[0].numpy()
        types_neigh = edge_types[1].numpy()

        r_max: float = self.AtomicData_options["r_max"]
        # + 1 to always have a zero bin at the end
        n_bins: int = int(math.ceil(r_max / bin_width)) + 1
        # +1 since these are bin_edges including rightmost
        bins = bin_width * np.arange(n_bins + 1)

        edge_lengths = data[AtomicDataDict.EDGE_LENGTH_KEY]
        type_pairs = itertools.combinations_with_replacement(
            range(self.type_mapper.num_types), 2
        )

        results = {}
        for type1, type2 in type_pairs:
            mask = np.logical_and(types_center == type1, types_neigh == type2)
            # restrict to the selected edges, reusing the mask buffer
            np.logical_and(mask, edge_selector, out=mask)
            histogram = np.histogram(
                edge_lengths,
                weights=mask.astype(np.int32),
                bins=bins,
                density=True,
            )
            results[(type1, type2)] = histogram
            # RDF is symmetric
            results[(type2, type1)] = histogram

        return results

In [None]:
import tempfile
import functools
import itertools
from os.path import dirname, basename, abspath
from typing import Dict, Any, List, Union, Optional, Sequence

import ase
import ase.io

import torch
import torch.multiprocessing as mp


from nequip.utils.multiprocessing import num_tasks
#from .. import AtomicData
#from ..transforms import TypeMapper
#from ._base_datasets import AtomicInMemoryDataset


def _ase_dataset_reader(
    rank: int,
    world_size: int,
    tmpdir: str,
    ase_kwargs: dict,
    atomicdata_kwargs: dict,
    include_frames,
    global_options: dict,
) -> Union[str, List[AtomicData]]:
    """Parallel reader for all frames in file.

    Each rank reads the frames ``rank, rank + world_size, ...`` with
    ``ase.io.iread`` and converts the included ones to ``AtomicData``.

    Args:
        rank: index of this reader among ``world_size`` readers.
        world_size: total number of readers.
        tmpdir: directory for the per-rank result file (used if ``world_size > 1``).
        ase_kwargs: keyword arguments for ``ase.io.iread``.
        atomicdata_kwargs: keyword arguments for ``AtomicData.from_ase``.
        include_frames: global frame indexes to convert, or ``None`` for all.
        global_options: passed to ``_set_global_options`` when ``world_size > 1``.

    Returns:
        A list of ``(global_index, data)`` pairs, where ``data`` is ``None``
        for frames that are not included. When ``world_size > 1`` the list is
        saved with ``torch.save`` and the path of that file is returned instead.
    """
    if world_size > 1:
        from nequip.utils._global_options import _set_global_options

        # ^ avoid import loop
        # we only `multiprocessing` if world_size > 1
        _set_global_options(global_options)
    # interleave--- in theory it is better for performance for the ranks
    # to read consecutive blocks, but the way ASE is written the whole
    # file gets streamed through all ranks anyway, so just trust the OS
    # to cache things sanely, which it will.
    # ASE handles correctly the case where there are no frames in index
    # and just gives an empty list, so that will succeed:
    index = slice(rank, None, world_size)
    # `None` means "every frame". Otherwise use a set so that each membership
    # test is O(1) instead of a scan over the whole list of frames.
    wanted = None if include_frames is None else frozenset(
        int(frame) for frame in include_frames
    )

    datas = []
    # stream them from ase too using iread
    for i, atoms in enumerate(ase.io.iread(**ase_kwargs, index=index, parallel=False)):
        global_index = rank + (world_size * i)
        if wanted is None or global_index in wanted:
            frame_data = AtomicData.from_ase(atoms=atoms, **atomicdata_kwargs)
        else:
            # in-memory dataset will ignore this later, but needed for indexing to work out
            frame_data = None
        datas.append((global_index, frame_data))
    # Save to a tempfile---
    # there can be a _lot_ of tensors here, and rather than dealing with
    # the complications of running out of file descriptors and setting
    # sharing methods, since this is a one time thing, just make it simple
    # and avoid shared memory entirely.
    if world_size > 1:
        path = f"{tmpdir}/rank{rank}.pth"
        torch.save(datas, path)
        return path
    else:
        return datas


class ASEDataset(AtomicInMemoryDataset):
    """In-memory dataset whose frames are read from a file through ASE.

    Args:
        ase_args (dict): arguments for ase.io.read
        include_keys (list): in addition to forces and energy, the keys that needs to
             be parsed into dataset
             The data stored in ase.atoms.Atoms.array has the lowest priority,
             and it will be overrided by data in ase.atoms.Atoms.info
             and ase.atoms.Atoms.calc.results. Optional
        key_mapping (dict): rename some of the keys to the value str. Optional

    Example: Given an atomic data stored in "H2.extxyz" that looks like below:

    ```H2.extxyz
    2
    Properties=species:S:1:pos:R:3 energy=-10 user_label=2.0 pbc="F F F"
     H       0.00000000       0.00000000       0.00000000
     H       0.00000000       0.00000000       1.02000000
    ```

    The yaml input should be

    ```
    dataset: ase
    dataset_file_name: H2.extxyz
    ase_args:
      format: extxyz
    include_keys:
      - user_label
    key_mapping:
      user_label: label0
    chemical_symbols:
      - H
    ```

    for VASP parser, the yaml input should be
    ```
    dataset: ase
    dataset_file_name: OUTCAR
    ase_args:
      format: vasp-out
    key_mapping:
      free_energy: total_energy
    chemical_symbols:
      - H
    ```

    """

    def __init__(
        self,
        root: str,
        ase_args: Optional[dict] = None,
        file_name: Optional[str] = None,
        url: Optional[str] = None,
        AtomicData_options: Optional[Dict[str, Any]] = None,
        include_frames: Optional[List[int]] = None,
        type_mapper: Optional[TypeMapper] = None,
        key_mapping: Optional[dict] = None,
        include_keys: Optional[List[str]] = None,
    ):
        """Merge the ASE read arguments and initialize the in-memory dataset.

        ``self.ase_args`` starts from the class attribute ``ASE_ARGS`` (if any)
        and is then updated with ``ase_args``. It must not contain ``index`` or
        ``filename``, which are supplied when reading. ``None`` for ``ase_args``
        or ``AtomicData_options`` means an empty dict.
        """
        self.ase_args = {}
        self.ase_args.update(getattr(type(self), "ASE_ARGS", dict()))
        if ase_args is not None:
            self.ase_args.update(ase_args)
        assert "index" not in self.ase_args
        assert "filename" not in self.ase_args

        self.include_keys = include_keys
        self.key_mapping = key_mapping

        super().__init__(
            file_name=file_name,
            url=url,
            root=root,
            # a fresh dict per instance instead of one shared `{}` default
            AtomicData_options={} if AtomicData_options is None else AtomicData_options,
            include_frames=include_frames,
            type_mapper=type_mapper,
        )

    @classmethod
    def from_atoms_list(cls, atoms: Sequence[ase.Atoms], **kwargs):
        """Make an ``ASEDataset`` from a list of ``ase.Atoms`` objects.

        If `root` is not provided, a temporary directory will be used.
        If `root` is provided, it is created if needed and the atoms file is
        written inside it.

        Please note that this is a convenience method that does NOT avoid a round-trip to disk; the provided ``atoms`` will be written out to a file.

        Ignores ``kwargs["file_name"]`` if it is provided.

        Args:
            atoms
            **kwargs: passed through to the constructor
        Returns:
            The constructed ``ASEDataset``.
        """
        # local import: only `os.path` helpers are imported at cell level
        import os

        if "root" not in kwargs:
            tmpdir = tempfile.TemporaryDirectory()
            kwargs["root"] = tmpdir.name
        else:
            tmpdir = None
            os.makedirs(kwargs["root"], exist_ok=True)
        # Build the path from `root`, which is set in both branches; `tmpdir`
        # is None when the caller supplied a root.
        kwargs["file_name"] = os.path.join(kwargs["root"], "atoms.xyz")
        atoms = list(atoms)
        # Write them out
        ase.io.write(kwargs["file_name"], atoms, format="extxyz")
        # Read them in
        obj = cls(**kwargs)
        if tmpdir is not None:
            # Make it keep a reference to the tmpdir to keep it alive
            # When the dataset is garbage collected, the tmpdir will
            # be too, and will (hopefully) get deleted eventually.
            # Or at least by end of program...
            obj._tmpdir_ref = tmpdir
        return obj

    @property
    def raw_file_names(self):
        """Single-element list with the base name of ``self.file_name``."""
        return [basename(self.file_name)]

    @property
    def raw_dir(self):
        """Absolute directory that contains ``self.file_name``."""
        return dirname(abspath(self.file_name))

    def get_data(self):
        """Read all frames with ASE, in parallel when more than one task is available.

        Returns:
            A list with one entry per frame in file order: an ``AtomicData``
            for included frames and ``None`` for frames excluded by
            ``self.include_frames``.
        """
        ase_args = {"filename": self.raw_dir + "/" + self.raw_file_names[0]}
        ase_args.update(self.ase_args)

        # skip the None arguments
        kwargs = dict(
            include_keys=self.include_keys,
            key_mapping=self.key_mapping,
        )
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        kwargs.update(self.AtomicData_options)
        n_proc = num_tasks()
        with tempfile.TemporaryDirectory() as tmpdir:
            from nequip.utils._global_options import _get_latest_global_options

            # ^ avoid import loop
            reader = functools.partial(
                _ase_dataset_reader,
                world_size=n_proc,
                tmpdir=tmpdir,
                ase_kwargs=ase_args,
                atomicdata_kwargs=kwargs,
                include_frames=self.include_frames,
                # get the global options of the parent to initialize the worker correctly
                global_options=_get_latest_global_options(),
            )
            if n_proc > 1:
                # things hang for some obscure OpenMP reason on some systems when using `fork` method
                ctx = mp.get_context("forkserver")
                with ctx.Pool(processes=n_proc) as p:
                    # map it over the `rank` argument
                    datas = p.map(reader, list(range(n_proc)))
                # clean up the pool before loading the data
                datas = [torch.load(d) for d in datas]
                # flatten the per-rank lists in one linear pass
                datas = list(itertools.chain.from_iterable(datas))
                # un-interleave the datas
                datas = sorted(datas, key=lambda e: e[0])
            else:
                datas = reader(rank=0)
                # datas here is already in order, stride 1 start 0
                # no need to un-interleave
        # return list of AtomicData:
        return [e[1] for e in datas]
    


In [46]:
import numpy as np
from os.path import dirname, basename, abspath
from typing import Dict, Any, List, Optional


#from .. import AtomicDataDict, _LONG_FIELDS, _NODE_FIELDS, _GRAPH_FIELDS
#from ..transforms import TypeMapper
#from ._base_datasets import AtomicInMemoryDataset


class NpzDataset(AtomicInMemoryDataset):
    """Load data from an npz file.

    To avoid loading unneeded data, keys are ignored by default unless they are in ``key_mapping``, ``include_keys``,
    or ``npz_fixed_field_keys``.

    Args:
        key_mapping (Dict[str, str]): mapping of npz keys to ``AtomicData`` keys. Optional
        include_keys (list): the attributes to be processed and stored. Optional
        npz_fixed_field_keys: the attributes that only have one instance but apply to all frames. Optional
            Note that the mapped keys (as determined by the _values_ in ``key_mapping``) should be used in
            ``npz_fixed_field_keys``, not the original npz keys from before mapping. If an npz key is not
            present in ``key_mapping``, it is mapped to itself, and this point is not relevant.

    Example: Given a npz file with 10 configurations, each with 14 atoms.

        position: (10, 14, 3)
        force: (10, 14, 3)
        energy: (10,)
        Z: (14)
        user_label1: (10)        # per config
        user_label2: (10, 14, 3) # per atom

    The input yaml should be

    ```yaml
    dataset: npz
    dataset_file_name: example.npz
    include_keys:
      - user_label1
      - user_label2
    npz_fixed_field_keys:
      - cell
      - atomic_numbers
    key_mapping:
      position: pos
      force: forces
      energy: total_energy
      Z: atomic_numbers
    graph_fields:
      - user_label1
    node_fields:
      - user_label2
    ```

    """

    def __init__(
        self,
        root: str,
        key_mapping: Dict[str, str] = {
            "positions": AtomicDataDict.POSITIONS_KEY,
            "energy": AtomicDataDict.TOTAL_ENERGY_KEY,
            "force": AtomicDataDict.FORCE_KEY,
            "forces": AtomicDataDict.FORCE_KEY,
            "Z": AtomicDataDict.ATOMIC_NUMBERS_KEY,
            "atomic_number": AtomicDataDict.ATOMIC_NUMBERS_KEY,
        },
        include_keys: List[str] = [],
        npz_fixed_field_keys: List[str] = [],
        file_name: Optional[str] = None,
        url: Optional[str] = None,
        AtomicData_options: Dict[str, Any] = {},
        include_frames: Optional[List[int]] = None,
        type_mapper: Optional[TypeMapper] = None,
    ):
        # These attributes are assigned before the parent constructor runs
        # (presumably because it may already trigger `get_data` — confirm
        # against AtomicInMemoryDataset).
        self.key_mapping = key_mapping
        self.npz_fixed_field_keys = npz_fixed_field_keys
        self.include_keys = include_keys

        super().__init__(
            file_name=file_name,
            url=url,
            root=root,
            AtomicData_options=AtomicData_options,
            include_frames=include_frames,
            type_mapper=type_mapper,
        )

    @property
    def raw_file_names(self):
        """Single-element list holding the base name of ``self.file_name``."""
        return [basename(self.file_name)]

    @property
    def raw_dir(self):
        """Absolute directory that contains ``self.file_name``."""
        return dirname(abspath(self.file_name))

    def get_data(self):
        """Read the requested arrays from the npz file.

        Only npz keys named in ``key_mapping``, ``npz_fixed_field_keys`` or
        ``include_keys`` (and actually present in the file) are read. Keys are
        renamed through ``key_mapping``, fields listed in ``_LONG_FIELDS`` are
        cast to ``int64``, and every fixed field is tiled so that it gains a
        leading axis of length ``num_examples``.

        Returns:
            dict: mapped key name -> numpy array.

        Raises:
            KeyError: if no per-frame positions array is available, or if a
                fixed field is registered as neither a node nor a graph field.
        """
        npz_path = self.raw_dir + "/" + self.raw_file_names[0]

        # NOTE(security): allow_pickle=True lets numpy unpickle object arrays,
        # which can execute arbitrary code; only load files from trusted sources.
        # The context manager closes the underlying archive once every needed
        # array has been materialized (NpzFile reads members lazily).
        with np.load(npz_path, allow_pickle=True) as data:
            # only the keys explicitly mentioned in the yaml file will be parsed
            keys = set(self.key_mapping.keys())
            keys.update(self.npz_fixed_field_keys)
            keys.update(self.include_keys)
            keys = keys.intersection(data.keys())

            mapped = {self.key_mapping.get(k, k): data[k] for k in keys}

        for intkey in _LONG_FIELDS:
            if intkey in mapped:
                mapped[intkey] = mapped[intkey].astype(np.int64)

        fields = {k: v for k, v in mapped.items() if k not in self.npz_fixed_field_keys}
        if AtomicDataDict.POSITIONS_KEY not in fields:
            raise KeyError(
                f"No per-frame `{AtomicDataDict.POSITIONS_KEY}` array could be read from `{npz_path}`; "
                "check `key_mapping` and make sure positions are not listed in `npz_fixed_field_keys`"
            )
        num_examples, num_atoms, n_dim = fields[AtomicDataDict.POSITIONS_KEY].shape
        assert n_dim == 3

        # now we replicate and add the fixed fields:
        for fixed_field in self.npz_fixed_field_keys:
            orig = mapped[fixed_field]
            if fixed_field in _NODE_FIELDS:
                assert orig.ndim >= 1  # [n_atom, feature_dims]
                assert orig.shape[0] == num_atoms
            elif fixed_field not in _GRAPH_FIELDS:
                raise KeyError(
                    f"npz_fixed_field_keys contains `{fixed_field}`, but it isn't registered as a node or graph field"
                )
            # node fields:  [n_atom, feature_dims] -> [n_example, n_atom, feature_dims]
            # graph fields: [feature_dims]         -> [n_example, feature_dims]
            replicated = np.expand_dims(orig, 0)
            replicated = np.tile(
                replicated,
                (num_examples,) + (1,) * len(replicated.shape[1:]),
            )
            fields[fixed_field] = replicated
        return fields

### Dataset_from_config

In [48]:
import inspect
from importlib import import_module

from nequip import data
#from nequip.data.transforms import TypeMapper
#from nequip.data import AtomicDataset
from nequip.utils import instantiate, get_w_prefix


def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
    """Build a dataset object from the entries of ``config``.

    ``config[prefix]`` selects the dataset type. It may be

    * a class, which is used directly;
    * a dotted path ``package.module.ClassName``, which is imported;
    * a (case insensitive) name of a class exposed by ``nequip.data``, with or
      without its ``Dataset`` suffix.

    All remaining constructor parameters are looked up in ``config`` as well.
    Examples: tests/data/test_dataset.py TestFromConfig and
    tests/datasets/test_simplest.py

    Note that ``config`` is modified: the key ``f"{prefix}_AtomicData_options"``
    is (re)written, including an ``r_max`` entry.

    Args:
        config (dict, nequip.utils.Config): dict/object that stores all the parameters
        prefix (str): Optional. The prefix of all dataset parameters

    Returns:
        dataset (nequip.data.AtomicDataset)

    Raises:
        KeyError: ``config`` has no entry for ``prefix``.
        NameError: no dataset class matches the requested name.
    """
    requested = config.get(prefix, None)
    if requested is None:
        raise KeyError(f"Dataset with prefix `{prefix}` isn't present in this config!")

    if inspect.isclass(requested):
        # a user-defined class was handed in directly
        class_name = requested
    else:
        try:
            # treat the value as `module.path.ClassName`
            module_path, _, attribute = requested.rpartition(".")
            class_name = getattr(import_module(module_path), attribute)
        except Exception:
            # ^ TODO: don't catch all Exception
            # fall back to the classes defined in nequip.data or nequip.dataset
            dataset_name = requested.lower()

            class_name = None
            for member_name, member in inspect.getmembers(data, inspect.isclass):
                accepted_names = {member_name.lower()}
                if member_name.endswith("Dataset"):
                    # also accept the name without its "Dataset" suffix
                    accepted_names.add(member_name[:-7].lower())
                if dataset_name in accepted_names:
                    # no early exit: the last matching member wins
                    class_name = member

    if class_name is None:
        raise NameError(f"dataset type {dataset_name} does not exists")

    # if dataset r_max is not found, use the universal r_max
    options_key = "AtomicData_options"
    prefixed_options_key = f"{prefix}_{options_key}"
    config[prefixed_options_key] = get_w_prefix(
        options_key, {}, prefix=prefix, arg_dicts=config
    )
    resolved_r_max = get_w_prefix(
        "r_max",
        prefix=prefix,
        arg_dicts=[config[prefixed_options_key], config],
    )
    config[prefixed_options_key]["r_max"] = resolved_r_max

    # Build a TypeMapper from the config
    type_mapper, _ = instantiate(TypeMapper, prefix=prefix, optional_args=config)

    dataset, _ = instantiate(
        class_name,
        prefix=prefix,
        positional_args={"type_mapper": type_mapper},
        optional_args=config,
    )
    return dataset

In [53]:
# Inspect the raw return value of `instantiate` for TypeMapper; per the recorded
# output it is a pair of (TypeMapper instance, dict of resolved keyword arguments).
# NOTE(review): this cell has a higher execution count than the cells below that
# assign `config`, so it presumably ran after them and will not reproduce in
# top-to-bottom order unless `config` is defined earlier — confirm and reorder.
instantiate(TypeMapper, prefix='dataset', optional_args=config)

(<__main__.TypeMapper at 0x7fae969c07c0>,
 {'type_names': None,
  'chemical_symbol_to_type': None,
  'type_to_chemical_symbol': None,
  'chemical_symbols': ['H', 'C']})

In [49]:
# Smoke test: build the dataset described by the NaBr config and show its first frame.
config = Config.from_file('./configs/NaBr.yaml', defaults=default_config)
    

# note: dataset_from_config also writes `dataset_AtomicData_options` into `config`
dataset = dataset_from_config(config, prefix="dataset")


# last expression of the cell -> displayed as the cell output
dataset[0]

AtomicData(atom_types=[64, 1], cell=[3, 3], edge_cell_shift=[1146, 3], edge_index=[2, 1146], forces=[64, 3], free_energy=[1], pbc=[3], pos=[64, 3], stress=[3, 3], total_energy=[1])

In [50]:
# Same smoke test with the example config. This rebinds the notebook-level names
# `config` and `dataset` that the NaBr cell above also assigns.
config = Config.from_file('./configs/example.yaml', defaults=default_config)
    

# per the recorded output, this run downloaded toluene_ccsd_t.zip and processed it
dataset = dataset_from_config(config, prefix="dataset")


# last expression of the cell -> displayed as the cell output
dataset[0]

Downloading http://quantum-machine.org/gdml/data/npz/toluene_ccsd_t.zip
Processing dataset...
Done!


AtomicData(atom_types=[15, 1], cell=[3, 3], edge_cell_shift=[152, 3], edge_index=[2, 152], forces=[15, 3], pbc=[3], pos=[15, 3], total_energy=[1])

### Dataloader

In [54]:
from typing import List, Optional, Iterator

import torch
from torch.utils.data import Sampler

from nequip.utils.torch_geometric import Batch, Data, Dataset


class Collater(object):
    """Callable that merges a list of ``AtomicData`` into one ``Batch``.

    Args:
        exclude_keys: keys of the input objects that must not be copied to the
            collated output
    """

    def __init__(
        self,
        exclude_keys: List[str] = [],
    ):
        # stored as a set; the caller's list is never kept or mutated
        self._exclude_keys = set(exclude_keys)

    @classmethod
    def for_dataset(
        cls,
        dataset,
        exclude_keys: List[str] = [],
    ):
        """Alternate constructor taking the dataset the collater is meant for.

        ``dataset`` is accepted but not used by this implementation.
        """
        return cls(exclude_keys=exclude_keys)

    def collate(self, batch: List[Data]) -> Batch:
        """Combine ``batch`` into a single ``Batch``, dropping the excluded keys."""
        return Batch.from_data_list(batch, exclude_keys=self._exclude_keys)

    def __call__(self, batch: List[Data]) -> Batch:
        """Same as :meth:`collate`, so an instance can serve as a ``collate_fn``."""
        return self.collate(batch)

    @property
    def exclude_keys(self):
        """The excluded keys as a new list."""
        return [*self._exclude_keys]


class DataLoader(torch.utils.data.DataLoader):
    """``torch.utils.data.DataLoader`` that always collates with a ``Collater``.

    Args:
        dataset: dataset to load from
        batch_size: forwarded positionally to the torch ``DataLoader``
        shuffle: forwarded positionally to the torch ``DataLoader``
        exclude_keys: keys handed to the ``Collater`` so they are left out of batches
        **kwargs: any further torch ``DataLoader`` options; a ``collate_fn``
            entry is discarded
    """

    def __init__(
        self,
        dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        exclude_keys: List[str] = [],
        **kwargs,
    ):
        # a caller-supplied collate function is silently dropped in favour of ours
        kwargs.pop("collate_fn", None)
        collater = Collater.for_dataset(dataset, exclude_keys=exclude_keys)

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=collater,
            **kwargs,
        )


class PartialSampler(Sampler[int]):
    r"""Samples elements without replacement, but divided across a number of calls to `__iter__`.

    To ensure deterministic reproducibility and restartability, dataset permutations are generated
    from a combination of the overall seed and the epoch number. As a result, the caller must
    tell this sampler the epoch number before each time `__iter__` is called by calling
    `my_partial_sampler.step_epoch(epoch_number_about_to_run)` each time.

    This sampler decouples epochs from the dataset size and cycles through the dataset over as
    many (partial) epochs as it may take. As a result, the _dataset_ epoch can change partway
    through a training epoch.

    Args:
        data_source (Dataset): dataset to sample from
        shuffle (bool): whether to shuffle the dataset each time the _entire_ dataset is consumed
        num_samples_per_epoch (int): number of samples to draw in each call to `__iter__`.
            If `None`, defaults to `len(data_source)`.
        generator (Generator): Generator whose `initial_seed()` is the base seed for shuffling.
            If `None`, the seed of torch's default generator (`torch.initial_seed()`) is used.
    """

    data_source: Dataset
    num_samples_per_epoch: int
    shuffle: bool
    # both stay None until `step_epoch` / the first completed `__iter__`
    _epoch: Optional[int]
    _prev_epoch: Optional[int]

    def __init__(
        self,
        data_source: Dataset,
        shuffle: bool = True,
        num_samples_per_epoch: Optional[int] = None,
        generator=None,
    ) -> None:
        self.data_source = data_source
        self.shuffle = shuffle
        if num_samples_per_epoch is None:
            num_samples_per_epoch = self.num_samples_total
        self.num_samples_per_epoch = num_samples_per_epoch
        assert self.num_samples_per_epoch <= self.num_samples_total
        assert self.num_samples_per_epoch >= 1
        self.generator = generator
        self._epoch = None
        self._prev_epoch = None

    @property
    def num_samples_total(self) -> int:
        """Current size of the underlying dataset."""
        # dataset size might change at runtime
        return len(self.data_source)

    def step_epoch(self, epoch: int) -> None:
        """Record the number of the epoch that is about to be iterated."""
        self._epoch = epoch

    def _base_seed(self) -> int:
        """Seed that the per-dataset-epoch permutations are derived from."""
        if self.generator is None:
            # No generator supplied: fall back to the global torch seed rather
            # than failing on `None.initial_seed()`.
            return torch.initial_seed()
        return self.generator.initial_seed()

    def __iter__(self) -> Iterator[int]:
        """Yield the `num_samples_per_epoch` indexes belonging to the current epoch.

        `step_epoch` must have been called first, and epochs must be visited
        consecutively (each one exactly one after the previously completed one).
        """
        assert self._epoch is not None
        assert (self._prev_epoch is None) or (self._epoch == self._prev_epoch + 1)
        assert self._epoch >= 0

        full_epoch_i, start_sample_i = divmod(
            # how much data we've already consumed:
            self._epoch * self.num_samples_per_epoch,
            # how much data there is the dataset:
            self.num_samples_total,
        )

        if self.shuffle:
            base_seed = self._base_seed()
            temp_rng = torch.Generator()
            # Get new randomness for each _full_ time through the dataset
            # This is deterministic w.r.t. the combination of dataset seed and epoch number
            # Both of which persist across restarts
            # (initial_seed() is restored by set_state())
            temp_rng.manual_seed(base_seed + full_epoch_i)
            full_order_this = torch.randperm(self.num_samples_total, generator=temp_rng)
            # reseed the generator for the _next_ epoch to get the shuffled order of the
            # _next_ dataset epoch to pad out this one for completing any partial batches
            # at the end:
            temp_rng.manual_seed(base_seed + full_epoch_i + 1)
            full_order_next = torch.randperm(self.num_samples_total, generator=temp_rng)
            del temp_rng
        else:
            full_order_this = torch.arange(self.num_samples_total)
            # without shuffling, the next epoch has the same sampling order as this one:
            full_order_next = full_order_this

        full_order = torch.cat((full_order_this, full_order_next), dim=0)
        del full_order_next, full_order_this

        this_segment_indexes = full_order[
            start_sample_i : start_sample_i + self.num_samples_per_epoch
        ]
        # because we cycle into indexes from the next dataset epoch,
        # we should _always_ be able to get num_samples_per_epoch
        assert len(this_segment_indexes) == self.num_samples_per_epoch
        yield from this_segment_indexes

        # only reached once the iterator has been fully consumed
        self._prev_epoch = self._epoch

    def __len__(self) -> int:
        """Number of indexes produced by one call to `__iter__`."""
        return self.num_samples_per_epoch
    