In [1]:
# reference example
from nequip.data import dataset_from_config
from nequip.utils import Config
from nequip.utils.misc import get_default_device_name
from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS

from nequip.model import model_from_config




default_config = dict(
    root="./",
    tensorboard=False,
    wandb=False,
    model_builders=[
        "SimpleIrrepsConfig",
        "EnergyModel",
        "PerSpeciesRescale",
        "StressForceOutput",
        "RescaleEnergyEtc",
    ],
    dataset_statistics_stride=1,
    device=get_default_device_name(),
    default_dtype="float64",
    model_dtype="float32",
    allow_tf32=True,
    verbose="INFO",
    model_debug_mode=False,
    equivariance_test=False,
    grad_anomaly_mode=False,
    gpu_oom_offload=False,
    append=False,
    warn_unused=False,
    _jit_bailout_depth=2,  # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
    # Quote from eelison in PyTorch slack:
    # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA
    # > Right now the default behavior is to specialize twice on static shapes and then on dynamic shapes.
    # > To reduce warmup time you can do something like setFusionStrartegy({{FusionBehavior::DYNAMIC, 3}})
    # > ... Although we would wouldn't really expect to recompile a dynamic shape fusion in a model,
    # > provided broadcasting patterns remain fixed
    # We default to DYNAMIC alone because the number of edges is always dynamic,
    # even if the number of atoms is fixed:
    _jit_fusion_strategy=[("DYNAMIC", 3)],
    # Due to what appear to be ongoing bugs with nvFuser, we default to NNC (fuser1) for now:
    # TODO: still default to NNC on CPU regardless even if change this for GPU
    # TODO: default for ROCm?
    _jit_fuser="fuser1",
)

# All default_config keys are valid / requested
_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys())



In [2]:
config = Config.from_file('./configs/example.yaml', defaults=default_config)
    

dataset = dataset_from_config(config, prefix="dataset")

validation_dataset = None

dataset[0]

AtomicData(atom_types=[21, 1], cell=[3, 3], edge_cell_shift=[364, 3], edge_index=[2, 364], forces=[21, 3], pbc=[3], pos=[21, 3], total_energy=[1])

In [3]:
# Trainer
from nequip.train.trainer import Trainer

trainer = Trainer(model=None, **Config.as_dict(config))

# what is this
# to update wandb data?
config.update(trainer.params)

# = Train/test split =
trainer.set_dataset(dataset, validation_dataset)

# = Build model =
final_model = model_from_config(
    config=config, initialize=True, dataset=trainer.dataset_train
)

DEBUG:root:* Initialize Output
  ...generate file name results/aspirin/example/log
  ...open log file results/aspirin/example/log
  ...generate file name results/aspirin/example/metrics_epoch.csv
  ...open log file results/aspirin/example/metrics_epoch.csv
  ...generate file name results/aspirin/example/metrics_initialization.csv
  ...open log file results/aspirin/example/metrics_initialization.csv
  ...generate file name results/aspirin/example/metrics_batch_train.csv
  ...open log file results/aspirin/example/metrics_batch_train.csv
  ...generate file name results/aspirin/example/metrics_batch_val.csv
  ...open log file results/aspirin/example/metrics_batch_val.csv
  ...generate file name results/aspirin/example/best_model.pth
  ...generate file name results/aspirin/example/last_model.pth
  ...generate file name results/aspirin/example/trainer.pth
  ...generate file name results/aspirin/example/config.yaml
Torch device: cpu
instantiate Loss
...Loss_param = dict(
...   optional_args =

   optional_args :                                           out_field
   optional_args :                                mlp_output_dimension
...ScalarMLP_param = dict(
...   optional_args = {'mlp_nonlinearity': None, 'mlp_initialization': 'uniform', 'mlp_dropout_p': 0.0, 'mlp_batchnorm': False, 'field': 'edge_features', 'out_field': 'edge_energy', 'mlp_latent_dimensions': [128], 'mlp_output_dimension': 1},
...   positional_args = {'irreps_in': {'pos': 1x1o, 'edge_index': None, 'node_attrs': 3x0e, 'node_features': 3x0e, 'edge_embedding': 8x0e, 'edge_cutoff': 1x0e, 'edge_attrs': 1x0e+1x1o+1x2e, 'edge_features': 1024x0e}})
instantiate EdgewiseEnergySum
        all_args :                                           num_types
        all_args :                                   avg_num_neighbors
...EdgewiseEnergySum_param = dict(
...   optional_args = {'avg_num_neighbors': 17.211328506469727, 'normalize_edge_energy_sum': True, 'per_edge_species_scale': False, 'num_types': 3},
...   positiona

In [4]:
final_model

GraphModel(
  (model): RescaleOutput(
    (model): GradientOutput(
      (func): SequentialGraphNetwork(
        (one_hot): OneHotAtomEncoding()
        (radial_basis): RadialBasisEdgeEncoding(
          (basis): NormalizedBasis(
            (basis): BesselBasis()
          )
          (cutoff): PolynomialCutoff()
        )
        (spharm): SphericalHarmonicEdgeAttrs(
          (sh): SphericalHarmonics()
        )
        (allegro): Allegro_Module(
          (latents): ModuleList(
            (0-1): 2 x ScalarMLPFunction(
              (_forward): RecursiveScriptModule(original_name=GraphModule)
            )
          )
          (env_embed_mlps): ModuleList(
            (0-1): 2 x ScalarMLPFunction(
              (_forward): RecursiveScriptModule(original_name=GraphModule)
            )
          )
          (tps): ModuleList(
            (0-1): 2 x RecursiveScriptModule(original_name=GraphModule)
          )
          (linears): ModuleList(
            (0-1): 2 x RecursiveScriptMod

In [5]:
from e3nn import o3

l_max = 2

irreps_edge_sh = repr(
            o3.Irreps.spherical_harmonics(
                l_max, p=(-1)
            )
        )

In [6]:
irreps_edge_sh

'1x0e+1x1o+1x2e'

In [14]:
import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
    
from torch import nn
import math


data = AtomicData.to_AtomicDataDict(dataset[0])



# edge length embedding
torch.manual_seed(32)

num_basis = 8
r_max = 5


data_my = {key: torch.clone(data[key]) for key in data}
data_my = AtomicDataDict.with_edge_vectors(data_my, with_lengths=True)

edge_length = data_my['edge_lengths']

bessel_weights = (torch.linspace(start=1.0, end=num_basis, steps=num_basis) * math.pi)
bessel_weights = nn.Parameter(bessel_weights)

edge_length_embedding = 2/r_max*torch.sin(bessel_weights * edge_length.unsqueeze(-1) / r_max)/edge_length.unsqueeze(-1)

# cutoff
factor = 1/r_max
p = 6
    
x = edge_length * factor

cutoff = 1.0
cutoff = cutoff - (((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p))
cutoff = cutoff + (p * (p + 2.0) * torch.pow(x, p + 1.0))
cutoff = cutoff - ((p * (p + 1.0) / 2) * torch.pow(x, p + 2.0))
cutoff *= (x < 1.0)

cutoff = cutoff.unsqueeze(-1)

data_my['edge_embedding'] = edge_length_embedding * cutoff

# types embedding
num_classes = 3


edge_ind = data_my['edge_index']

types_embed = one_hot(dataset[0]['atom_types'], num_classes)
types_src = types_embed[edge_ind[0]].squeeze(1)
types_dst = types_embed[edge_ind[1]].squeeze(1)



# latent vector
latent_vector = torch.concatenate([types_src, types_dst, edge_length_embedding], dim = 1)

# MLP
invariant_layers = 2
invariant_neurons = 64
out_neurons = 32

fc = FullyConnectedNet(
    [latent_vector.shape[1]]
    + invariant_layers * [invariant_neurons]
    + [out_neurons],
    torch.nn.functional.silu)

latent_vector_out = fc(latent_vector)


data_my['scalar'] = latent_vector_out

In [15]:
latent_vector.shape

torch.Size([364, 14])

In [16]:
latent_vector_out.shape

torch.Size([364, 32])

In [25]:
from nequip.nn import AtomwiseLinear
from e3nn.o3 import Irreps


#data2_my = {key: torch.clone(data_my[key]) for key in data_my}

linear1 = o3.Linear('32x0e', '32x0e')

weight1 = linear1(latent_vector_out)
weight1.shape

torch.Size([364, 32])

In [34]:
from torch import nn
import math


torch.manual_seed(32)


l_max = 2
irreps_edge_sh = o3.Irreps.spherical_harmonics(2)

data2_my = {key: torch.clone(data2_my[key]) for key in data_my}
data2_my = AtomicDataDict.with_edge_vectors(data2_my, with_lengths=False)


harm_gen = o3.SphericalHarmonics(irreps_edge_sh, True, 'component')

edge_vec = data_my['edge_vectors']

harm_edge = harm_gen(edge_vec)
harm_edge.shape


data2_my['edge_features'] = harm_edge

In [39]:
from e3nn.o3 import TensorProduct, Linear, FullyConnectedTensorProduct
from torch_runstats.scatter import scatter

x = data2_my['edge_features']
edge_src = data2_my['edge_index'][1]
edge_dst = data2_my['edge_index'][0]


term_1 = harm_edge
edge_features = torch.einsum('ij,ib->ijb', weight1[edge_src], harm_edge[edge_src])

# TODO: Check if it really right result
edge_features = scatter(edge_features, edge_dst, dim=0, dim_size=len(x))
edge_features.shape

torch.Size([364, 32, 9])

In [45]:
from e3nn.o3 import TensorProduct, Linear, FullyConnectedTensorProduct

hidden_layer_irrep = o3.Irreps("32x0e + 32x0o + 32x1e + 32x1o + 32x2e + 32x2o")

irrep_in = o3.Irreps("1x0e + 1x1o + 1x2e")
irreps_edge = o3.Irreps("32x0e + 32x1o + 32x2o")


irreps_mid = []
instructions = []

# instructions means stuff for multiplicities
for i, (_, ir_in) in enumerate(irrep_in):
    for j, (mul, ir_edge) in enumerate(irreps_edge):
        for ir_out in ir_in * ir_edge:
            if ir_out in hidden_layer_irrep:
                k = len(irreps_mid)
                irreps_mid.append((mul, ir_out))
                instructions.append((i, j, k, "uvu", True))

# We sort the output irreps of the tensor product so that we can simplify them
# when they are provided to the second o3.Linear
irreps_mid = o3.Irreps(irreps_mid)
irreps_mid, p, _ = irreps_mid.sort()

fctp = FullyConnectedTensorProduct(
            irrep_in,
            irreps_edge,
            hidden_layer_irrep
        )

fctp

FullyConnectedTensorProduct(1x0e+1x1o+1x2e x 32x0e+32x1o+32x2o -> 32x0e+32x0o+32x1e+32x1o+32x2e+32x2o | 15360 paths | 15360 weights)

In [82]:
irrep_in.dim

9

In [83]:
irreps_edge.dim

288

In [89]:
torch.flatten(edge_features.transpose(2, 1), -2).shape

torch.Size([364, 288])

In [94]:
out = fctp(harm_edge, edge_features.transpose(2, 1).flatten(-2))

In [95]:
out.shape

torch.Size([364, 576])

In [70]:
harm_edge.unsqueeze(-1).shape

torch.Size([364, 9, 1])

In [76]:
(edge_features.transpose(2, 1)).shape

torch.Size([364, 9, 32])

In [73]:
edge_features.shape

torch.Size([364, 32, 9])