In [18]:
# reference example
from nequip.data import dataset_from_config
from nequip.utils import Config
#from nequip.utils.misc import get_default_device_name
#from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS

from nequip.model import model_from_config


# Baseline option table; it is handed to Config.from_file(..., defaults=...) when the
# YAML configuration is loaded.
default_config = {
    # Output location and experiment trackers.
    "root": "./",
    "tensorboard": False,
    "wandb": False,
    # Names of the builders listed for model construction, in order.
    "model_builders": [
        "SimpleIrrepsConfig",
        "EnergyModel",
        "PerSpeciesRescale",
        "StressForceOutput",
        "RescaleEnergyEtc",
    ],
    "dataset_statistics_stride": 1,
    # Device and numeric precision.
    "device": "cpu",
    "default_dtype": "float64",
    "model_dtype": "float32",
    "allow_tf32": True,
    # Logging and debugging switches.
    "verbose": "INFO",
    "model_debug_mode": False,
    "equivariance_test": False,
    "grad_anomaly_mode": False,
    "gpu_oom_offload": False,
    "append": False,
    "warn_unused": False,
    # Keeps TorchScript warm-up short ("avoid 20 iters of pain"), see
    # https://github.com/pytorch/pytorch/issues/52286
    "_jit_bailout_depth": 2,
    # Fusion strategy. Quote from eelison in PyTorch slack:
    # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA
    # > Right now the default behavior is to specialize twice on static shapes and then on dynamic shapes.
    # > To reduce warmup time you can do something like setFusionStrartegy({{FusionBehavior::DYNAMIC, 3}})
    # > ... Although we would wouldn't really expect to recompile a dynamic shape fusion in a model,
    # > provided broadcasting patterns remain fixed
    # DYNAMIC alone is the default here because the number of edges is always dynamic,
    # even when the number of atoms is fixed.
    "_jit_fusion_strategy": [("DYNAMIC", 3)],
    # NNC (fuser1) is the default for now because of what appear to be ongoing nvFuser bugs.
    # TODO: still default to NNC on CPU regardless even if change this for GPU
    # TODO: default for ROCm?
    "_jit_fuser": "fuser1",
}

# All default_config keys are valid / requested
#_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys())

In [20]:
# No separate validation set is supplied; None is forwarded to the trainer later on.
validation_dataset = None

# Read the run configuration, passing default_config as the defaults.
CONFIG_PATH = './configs/example_ETN.yaml'
config = Config.from_file(CONFIG_PATH, defaults=default_config)

# Build the dataset from the configuration, using the "dataset" prefix.
dataset = dataset_from_config(config, prefix="dataset")

# Display the first frame as a quick sanity check.
dataset[0]

search for AtomicData_options with prefix dataset
search for r_max with prefix dataset
          0_args :                                               r_max
instantiate TypeMapper
   optional_args :                             chemical_symbol_to_type
...TypeMapper_param = dict(
...   optional_args = {'type_names': None, 'chemical_symbol_to_type': {'H': 0, 'C': 1, 'O': 2}, 'type_to_chemical_symbol': None, 'chemical_symbols': None},
...   positional_args = {})
instantiate register_fields
...register_fields_param = dict(
...   optional_args = {'node_fields': [], 'edge_fields': [], 'graph_fields': [], 'long_fields': []},
...   positional_args = {})
instantiate NpzDataset
   optional_args :                                                root
   optional_args :                                         key_mapping
   optional_args :                                npz_fixed_field_keys
   optional_args :                                  AtomicData_options <-                         dataset_Atom

AtomicData(atom_types=[21, 1], cell=[3, 3], edge_cell_shift=[364, 3], edge_index=[2, 364], forces=[21, 3], pbc=[3], pos=[21, 3], total_energy=[1])

In [21]:
# Trainer
from nequip.train.trainer import Trainer
from e3nn import o3

# Create the trainer without a model; all configuration entries are forwarded as keyword arguments.
trainer_kwargs = Config.as_dict(config)
trainer = Trainer(model=None, **trainer_kwargs)

# Copy the trainer's parameters back into the configuration.
# NOTE(review): the reason for this is not visible in this notebook (the original
# comment wondered whether it is for wandb) — confirm.
config.update(trainer.params)

# = Train/test split =
trainer.set_dataset(dataset, validation_dataset)

# Extra configuration entries set before the model is built.
# NOTE(review): presumably read by the model builders; their meaning is not shown here — confirm.
#config['model_input_fields'] = {'node_spin': o3.Irreps('1x1e')}
Nc = 10
N_rank_spec = 4
for option, value in (
    ('Nc', Nc),
    ('N_rank_spec', N_rank_spec),
    ('N_rank_ett', [4, 4, 4]),
    ('d', 4),
):
    config[option] = value


# = Build model =
final_model = model_from_config(
    config=config,
    initialize=True,
    dataset=trainer.dataset_train,
)

* Initialize Trainer
* Initialize Output
  ...generate file name results/aspirin/example/log
  ...open log file results/aspirin/example/log
  ...generate file name results/aspirin/example/metrics_epoch.csv
  ...open log file results/aspirin/example/metrics_epoch.csv
  ...generate file name results/aspirin/example/metrics_initialization.csv
  ...open log file results/aspirin/example/metrics_initialization.csv
  ...generate file name results/aspirin/example/metrics_batch_train.csv
  ...open log file results/aspirin/example/metrics_batch_train.csv
  ...generate file name results/aspirin/example/metrics_batch_val.csv
  ...open log file results/aspirin/example/metrics_batch_val.csv
  ...generate file name results/aspirin/example/best_model.pth
  ...generate file name results/aspirin/example/last_model.pth
  ...generate file name results/aspirin/example/trainer.pth
  ...generate file name results/aspirin/example/config.yaml
Torch device: cpu
instantiate Loss
...Loss_param = dict(
...   optio

Replace string dataset_forces_rms to 31.252248764038086
Initially outputs are globally scaled by: 31.252248764038086, total_energy are globally shifted by None.
PerSpeciesScaleShift's arguments were in dataset units; rescaling:
  Original scales: [H: 31.252249, C: 31.252249, O: 31.252249] shifts: [H: -19318.355469, C: -19318.355469, O: -19318.355469]
  New scales: [H: 1.000000, C: 1.000000, O: 1.000000] shifts: [H: -618.142883, C: -618.142883, O: -618.142883]


In [22]:

import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
from allegro import with_edge_spin_length
from allegro import _keys
from torch import nn
import math

# Convert the first dataset frame to its dictionary form; later cells feed it to the model.
first_frame = dataset[0]
data0 = AtomicData.to_AtomicDataDict(first_frame)

# The trainer was created with model=None, so hand it the model that was just built.
trainer.model = final_model

In [23]:
# Run training with the model attached to the trainer; the captured log follows below.
trainer.train()

Number of weights: 1316
Number of trainable weights: 1316
instantiate Adam
        all_args :                                             amsgrad <-                           optimizer_params.amsgrad
        all_args :                                                 eps <-                               optimizer_params.eps
        all_args :                                        weight_decay <-                      optimizer_params.weight_decay
        all_args :                                               betas <-                             optimizer_params.betas
...Adam_param = dict(
...   optional_args = {'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.0, 'amsgrad': False, 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None},
...   positional_args = {'params': <generator object Module.parameters at 0x7fbece1b7060>, 'lr': 0.001})
instantiate ReduceLROnPlateau
        all_args :                                              factor 

      4   100        0.147        0.147     0.000185         8.48           12         6.93         8.93
      4   190         0.14        0.139     0.000492          8.6         11.7         11.9         14.6

validation
# Epoch batch         loss       loss_f       loss_e        f_mae       f_rmse        e_mae       e_rmse
      4    10        0.148        0.148     0.000491         8.56           12           12         14.5


  Train      #    Epoch      wal       LR       loss_f       loss_e         loss        f_mae       f_rmse        e_mae       e_rmse
! Train               4   66.744    0.001        0.144     0.000252        0.144         8.56         11.9         8.13         10.4
! Validation          4   66.744    0.001         0.13     0.000227         0.13         8.14         11.3         7.68         9.88
Wall time: 66.74657141198986
! Best model        4    0.130
Saved trainer to results/aspirin/example/trainer.pth
Saved last model to to results/aspirin/example/last_mo



  Train      #    Epoch      wal       LR       loss_f       loss_e         loss        f_mae       f_rmse        e_mae       e_rmse
! Train              11  179.485    0.001       0.0601     0.000231       0.0604         5.61         7.66         7.91         9.97
! Validation         11  179.485    0.001       0.0567     0.000207        0.057         5.44         7.44         7.47         9.44
Wall time: 179.48706013397896
! Best model       11    0.057
Saved trainer to results/aspirin/example/trainer.pth
Saved last model to to results/aspirin/example/last_model.pth

training
# Epoch batch         loss       loss_f       loss_e        f_mae       f_rmse        e_mae       e_rmse
     12   100       0.0526       0.0524     0.000182         5.11         7.16         7.34         8.84
     12   190       0.0526       0.0522     0.000383         5.17         7.14         10.1         12.8

validation
# Epoch batch         loss       loss_f       loss_e        f_mae       f_rmse        

In [6]:
import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
    
from torch import nn
import math

# Forward pass on the first frame; the result is indexed by output name in the following cells.
# NOTE(review): the model may also modify data0 in place — confirm before reusing data0 as a pristine input.
data_new = final_model(data0)

In [7]:
# Inspect the shape of the 'node_features_ETN' output for the first frame.
data_new['node_features_ETN'].shape

torch.Size([21, 9, 10])

In [8]:
# Inspect the shape of the 'node_features_F' output for the first frame.
data_new['node_features_F'].shape

torch.Size([21, 9, 10])

In [9]:

import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
from allegro import with_edge_spin_length
from allegro import _keys
from torch import nn
import math

# Build a rotated copy of the first frame for the equivariance checks that follow.
data = data0  # alias of the unrotated input (same dict object, not a copy)


import copy

# Clone every tensor so that editing the rotated copy leaves data0 untouched.
data_rot = {key: torch.clone(data0[key]) for key in data0}

# irreps_sh supplies the matrix applied to the feature axis in the checks below
# (1 + 3 + 5 = 9 components); irreps_sh_r is the plain vector irrep used for coordinates.
irreps_sh = o3.Irreps('1x0e + 1x1o + 1x2e') #o3.Irreps.spherical_harmonics(lmax=2)
irreps_sh_r = o3.Irreps('1x1o')

# 100 random Euler-angle triples are drawn, but only the first one is used.
alpha, beta, gamma = o3.rand_angles(100)

# Both matrices are built from the same angles, so they represent the same rotation.
rot_matrix = irreps_sh.D_from_angles(alpha[0], beta[0], gamma[0])
rot_matrix_r = irreps_sh_r.D_from_angles(alpha[0], beta[0], gamma[0])


# Rotate the atomic positions (row-vector convention: x' = x @ D).
data_rot['pos'] = data_rot['pos'] @ rot_matrix_r

# Rotate the lattice vectors with the same matrix. Without this, a periodic frame
# whose edges use cell shifts would no longer be a rigid rotation of the original,
# and the checks would compare against a distorted structure.
# Assumes the rows of 'cell' are lattice vectors (nequip convention) — TODO confirm.
if 'cell' in data_rot:
    data_rot['cell'] = data_rot['cell'] @ rot_matrix_r.to(data_rot['cell'].dtype)

In [10]:
# Equivariance check for 'node_features_F': evaluate the model on the original and on
# the rotated input, map the rotated result back, and compare.
# The RNG is reset to the same seed before each forward pass
# (presumably so that any randomness in the model is identical in both passes).
torch.manual_seed(32)

F = final_model(data)['node_features_F']

torch.manual_seed(32)

F_rot = final_model(data_rot)['node_features_F']


# Apply rot_matrix along the middle (j) axis to undo the input rotation.
F_rot_rot = torch.einsum('Njn,jk->Nkn', F_rot, rot_matrix.T)

if torch.allclose(F, F_rot_rot, atol=1e-05):
    print('F is equivariant')
else:
    # Report a failed check explicitly instead of printing nothing.
    max_dev = (F - F_rot_rot).abs().max().item()
    print(f'F is NOT equivariant (max abs deviation: {max_dev:.3e})')

F is equivariant


In [11]:
# Equivariance check for 'node_features_ETN': evaluate the model on the original and on
# the rotated input, map the rotated result back, and compare.
# The RNG is reset to the same seed before each forward pass
# (presumably so that any randomness in the model is identical in both passes).
torch.manual_seed(32)

F = final_model(data)['node_features_ETN']


torch.manual_seed(32)

F_rot = final_model(data_rot)['node_features_ETN']


# Apply rot_matrix along the middle (j) axis to undo the input rotation.
F_rot_rot = torch.einsum('Njn,jk->Nkn', F_rot, rot_matrix.T)

if torch.allclose(F, F_rot_rot, atol=1e-05):
    print('F is equivariant')
else:
    # Report a failed check explicitly instead of printing nothing.
    max_dev = (F - F_rot_rot).abs().max().item()
    print(f'F is NOT equivariant (max abs deviation: {max_dev:.3e})')

F is equivariant


In [12]:
# Invariance check for 'atomic_energy': the values for the original and the rotated
# input are compared directly, with no back-rotation.
# The RNG is reset to the same seed before each forward pass
# (presumably so that any randomness in the model is identical in both passes).
torch.manual_seed(32)

F = final_model(data)['atomic_energy']


torch.manual_seed(32)

F_rot = final_model(data_rot)['atomic_energy']


if torch.allclose(F, F_rot, atol=1e-05):
    print('atomic energy is invariant')
else:
    # Report a failed check explicitly instead of printing nothing.
    max_dev = (F - F_rot).abs().max().item()
    print(f'atomic energy is NOT invariant (max abs deviation: {max_dev:.3e})')

atomic energy is invariant


In [13]:
# Invariance check for 'total_energy': the values for the original and the rotated
# input are compared directly, with no back-rotation.
# The RNG is reset to the same seed before each forward pass
# (presumably so that any randomness in the model is identical in both passes).
torch.manual_seed(32)

F = final_model(data)['total_energy']


torch.manual_seed(32)

F_rot = final_model(data_rot)['total_energy']


if torch.allclose(F, F_rot, atol=1e-05):
    # The message names the quantity actually checked here (it previously said 'atomic energy').
    print('total energy is invariant')
else:
    # Report a failed check explicitly instead of printing nothing.
    max_dev = (F - F_rot).abs().max().item()
    print(f'total energy is NOT invariant (max abs deviation: {max_dev:.3e})')

atomic energy is invariant
