In [1]:
from e3nn.o3 import spherical_harmonics
from e3nn.o3 import wigner_3j
from torch import Tensor
import torch
from torch import nn

In [2]:
r = torch.randn((10, 100, 3))
r_ang = r/torch.linalg.vector_norm(r, dim = -1).unsqueeze(-1)


Y_1 = spherical_harmonics('1e', r_ang, normalize = 'component')
Y_2 = spherical_harmonics('2e', r_ang, normalize = 'component')
Y_3 = spherical_harmonics('2e', r_ang, normalize = 'component')

features = torch.concatenate([Y_1, Y_2, Y_3], dim = -1)

features.shape

torch.Size([10, 100, 13])

In [3]:
Y_1.shape, Y_2.shape

(torch.Size([10, 100, 3]), torch.Size([10, 100, 5]))

In [4]:
r.shape

torch.Size([10, 100, 3])

In [5]:
l1, l2, l3 = 1, 2, 3
m1, m2, m3 = l1, l2, l3 # m1 = 0, m2 = 0, m3 = 0
C123 = wigner_3j(1, 2, 3)

In [6]:
C123_000 = C123[m1, m2, m3]

C123_000

tensor(0.2928)

In [7]:
C123.shape

torch.Size([3, 5, 7])

In [8]:
class weigner_3j_img_to_real():


    def __init__(self, l):
        self.l = l

    def __call__(self):
        matrix = torch.zeros((2*self.l + 1, 2*self.l + 1), dtype = torch.complex64)

        mult = 1
        for i in range(2*self.l + 1):
            
            if i < self.l:
                matrix[i, i] = 1.0j/2**(1/2.)
                matrix[2*self.l + 1 - i - 1, i] = 1/2**(1/2.)
            elif i == self.l:
                matrix[i, i] = 1.
            else:
                matrix[i, i] = (-1.0)*mult/2**(1/2.)
                matrix[2*self.l + 1 - i - 1, i] = (-1.0j)*mult/2**(1/2.)
                mult *= -1
        
        return matrix

In [9]:
weigner_3j_img_to_real(1)()*2**(1/2.)

tensor([[ 0.0000+1.0000j,  0.0000+0.0000j,  0.0000-1.0000j],
        [ 0.0000+0.0000j,  1.4142+0.0000j,  0.0000+0.0000j],
        [ 1.0000+0.0000j,  0.0000+0.0000j, -1.0000+0.0000j]])

In [10]:


def tri_ineq(l1, l2, l3):
    print(max([l1, l2, l3]), min([l1 + l2, l2 + l3, l1 + l3]))
    return max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3])

class order_3_equvariant_tensor():
    
    def __init__(self):
        """ Basically transformation using
            C (l1 l2 l3)
              (m1 m2 m3)"""
        
        pass
    
    
    def __call__(self, l1, l2, l3, n1, n2, n3):
        self.l1 = l1; self.l2 = l2; self.l3 = l3
        self.n1 = n1; self.n2 = n2; self.n3 = n3
        
        weight = torch.zeros([n1, n2, n3])
        if tri_ineq(l1, l2, l3):
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            symbol_3j = wigner_3j(1, 2, 3)
            return symbol_3j.view(*symbol_3j.shape, 1, 1, 1)*W.view(1, 1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, 2*l3 + 1, n1, n2, n3))
        
        
class order_2_equivariant_tensor():
    
    def __init__(self):
        """Makes the 2nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, l2, n1, n2):
        self.l1 = l1; self.l2 = l2
        self.n1 = n1; self.n2 = n2
        
        weight = torch.zeros([n1, n2])
        if l1 == l2: # same as tri_ineq(l1, l2, 0)
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            return torch.ones((2*l1 + 1, 2*l2 + 1, 1, 1))*W.view(1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, n1, n2))

        
class order_2_equivariant_tensor():
    
    def __init__(self):
        """Makes the 2nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, l2, n1, n2):
        self.l1 = l1; self.l2 = l2
        self.n1 = n1; self.n2 = n2
        
        weight = torch.zeros([n1, n2])
        if l1 == l2: # same as tri_ineq(l1, l2, 0)
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            return torch.ones((2*l1 + 1, 2*l2 + 1, 1, 1))*W.view(1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, n1, n2))

        
class order_1_equivariant_tensor():
    
    def __init__(self):
        """Makes the 1nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, n1):
        self.l1 = l1;
        self.n1 = n1;
        
        weight = torch.zeros([n1])
        W = nn.Parameter(nn.init.uniform(weight))
        return torch.ones((2*l1 + 1, 1))*W.view(1, *weight.shape)


In [11]:
print(order_1_equivariant_tensor()(2, 3).shape)
print(order_1_equivariant_tensor()(2, 10).shape)

print(order_2_equivariant_tensor()(2, 2, 3, 4).shape)
print(order_2_equivariant_tensor()(2, 2, 3, 4).shape)


print(order_3_equvariant_tensor()(1, 2, 3, 4, 6, 8).sum())
print(order_3_equvariant_tensor()(1, 2, 3, 4, 6, 8).shape)



torch.Size([5, 3])
torch.Size([5, 10])
torch.Size([5, 5, 3, 4])
torch.Size([5, 5, 3, 4])
3 3
tensor(7.1833, grad_fn=<SumBackward0>)
3 3
torch.Size([3, 5, 7, 4, 6, 8])


  W = nn.Parameter(nn.init.uniform(weight))


In [12]:
tri_ineq(0, 0, 0)

0 0


True

In [13]:
C123.shape

torch.Size([3, 5, 7])

In [14]:
torch.einsum('iav,ibt,vtr->iabr', Y_1, Y_2, C123).shape

torch.Size([10, 100, 100, 7])

### Loading qm9 dataset

In [18]:
# reference example
from nequip.data import dataset_from_config
from nequip.utils import Config
#from nequip.utils.misc import get_default_device_name
#from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS

from nequip.model import model_from_config


default_config = dict(
    root="./",
    tensorboard=False,
    wandb=False,
    model_builders=[
        "SimpleIrrepsConfig",
        "EnergyModel",
        "PerSpeciesRescale",
        "StressForceOutput",
        "RescaleEnergyEtc",
    ],
    dataset_statistics_stride=1,
    device='cuda',
    default_dtype="float64",
    model_dtype="float32",
    allow_tf32=True,
    verbose="INFO",
    model_debug_mode=False,
    equivariance_test=False,
    grad_anomaly_mode=False,
    gpu_oom_offload=False,
    append=False,
    warn_unused=False,
    _jit_bailout_depth=2,  # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
    # Quote from eelison in PyTorch slack:
    # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA
    # > Right now the default behavior is to specialize twice on static shapes and then on dynamic shapes.
    # > To reduce warmup time you can do something like setFusionStrartegy({{FusionBehavior::DYNAMIC, 3}})
    # > ... Although we would wouldn't really expect to recompile a dynamic shape fusion in a model,
    # > provided broadcasting patterns remain fixed
    # We default to DYNAMIC alone because the number of edges is always dynamic,
    # even if the number of atoms is fixed:
    _jit_fusion_strategy=[("DYNAMIC", 3)],
    # Due to what appear to be ongoing bugs with nvFuser, we default to NNC (fuser1) for now:
    # TODO: still default to NNC on CPU regardless even if change this for GPU
    # TODO: default for ROCm?
    _jit_fuser="fuser1",
)

# All default_config keys are valid / requested
#_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys())

In [20]:
config = Config.from_file('./configs/example_SpinGNNPlus.yaml', defaults=default_config)
    

dataset = dataset_from_config(config, prefix="dataset")

validation_dataset = None

dataset[0]

Processing dataset...
Done!


AtomicData(atom_types=[19, 1], cell=[3, 3], edge_cell_shift=[340, 3], edge_index=[2, 340], pbc=[3], pos=[19, 3])

In [45]:
config

{'_jit_bailout_depth': 2, '_jit_fusion_strategy': [('DYNAMIC', 3)], '_jit_fuser': 'fuser1', 'root': 'results/qm9', 'tensorboard': False, 'wandb': False, 'model_builders': ['allegro.model.SpinGNNPlus', 'PerSpeciesRescale', 'ParaStressForceSpinForceOutput', 'RescaleEnergyEtc'], 'dataset_statistics_stride': 1, 'device': 'cuda', 'default_dtype': 'float32', 'model_dtype': 'float32', 'allow_tf32': True, 'verbose': 'debug', 'model_debug_mode': False, 'equivariance_test': False, 'grad_anomaly_mode': False, 'gpu_oom_offload': False, 'append': True, 'warn_unused': False, 'run_name': 'example', 'seed': 123456, 'dataset_seed': 123456, 'r_max': 6.0, 'avg_num_neighbors': 'auto', 'BesselBasis_trainable': True, 'PolynomialCutoff_p': 6, 'l_max': 2, 'parity': 'o3_full', 'num_layers': 2, 'env_embed_multiplicity': 64, 'embed_initial_edge': True, 'two_body_latent_mlp_latent_dimensions': [128, 256, 512, 1024], 'two_body_latent_mlp_nonlinearity': 'silu', 'two_body_latent_mlp_initialization': 'uniform', 'late

In [136]:

import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
from allegro import with_edge_spin_length
from allegro import _keys
from torch import nn
import math

data = [AtomicData.to_AtomicDataDict(dataset[i]) for i in range(10)]

In [137]:
len(data)

10

In [138]:
i = 2

num_types = len(config['chemical_symbols'])


atom_types_embed = data[i]['atom_types'][data[i]['edge_index'][0]]*num_types + data[i]['atom_types'][data[i]['edge_index'][1]]

atom_types_embed.shape

torch.Size([494, 1])

In [139]:
from nequip.data import AtomicDataDict, AtomicDataset
from nequip.nn.embedding import (
    OneHotAtomEncoding,
    SphericalHarmonicEdgeAttrs,
    RadialBasisEdgeEncoding,
)

In [140]:
from e3nn import o3

In [141]:
o3.Irreps.spherical_harmonics(2)

1x0ee+1x1oe+1x2ee

In [158]:
from nequip.data import AtomicDataDict, AtomicDataset
from nequip.nn.embedding import (
    OneHotAtomEncoding,
    SphericalHarmonicEdgeAttrs,
    RadialBasisEdgeEncoding,
)
from e3nn import o3

L = 2

irreps_edge_sh = o3.Irreps.spherical_harmonics(L)

rbe = RadialBasisEdgeEncoding(basis_kwargs={'r_max': config['r_max'], 
                                            'num_basis': 8},
                              cutoff_kwargs={'r_max': config['r_max']},
                              out_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
                              )

sh = SphericalHarmonicEdgeAttrs(irreps_edge_sh=irreps_edge_sh)

data = [rbe(data[i]) for i in range(len(data))]

data = [sh(data[i]) for i in range(len(data))]


print(data[i]['edge_embedding'].shape, data[i]['edge_attrs'].shape)

torch.Size([494, 8]) torch.Size([494, 9])


In [186]:
from torch_runstats.scatter import scatter

N_rad = 8

N_spec_rank = 4
N_rad_rank = 4

Q = data[i]['edge_embedding']

A = torch.randn(L + 1, N_spec_rank, num_types**2)
B = torch.randn(L + 1, N_rad_rank, N_rad, N_spec_rank)


a = A[:, :, atom_types_embed].squeeze(-1)

b = torch.einsum('Lrnk,LkE,En->ELr', B, a, Q)

Y = data[i]['edge_attrs']

#print(data[i]['edge_attrs'][:, slices])
F = torch.concat([torch.einsum('Em,En->Emn', Y[:, slices],
                               b[:, l]) for l, slices in enumerate(irreps_edge_sh.slices())], dim = -2)

print(a.shape, b.shape, F.shape)


species = data[i][AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1)
edge_center = data[i][AtomicDataDict.EDGE_INDEX_KEY][0]
edge_neighbor = data[i][AtomicDataDict.EDGE_INDEX_KEY][1]

center_species = species[edge_center]
neighbor_species = species[edge_neighbor]

F = scatter(F, edge_center, dim=0, dim_size=len(species))

print(F.shape)

torch.Size([3, 4, 494]) torch.Size([494, 3, 4]) torch.Size([494, 9, 4])
torch.Size([23, 9, 4])
