In [129]:
from e3nn.o3 import spherical_harmonics
from e3nn.o3 import wigner_3j
from torch import Tensor
import torch
from torch import nn

In [130]:
r = torch.randn((10, 100, 3))
r_ang = r/torch.linalg.vector_norm(r, dim = -1).unsqueeze(-1)


Y_1 = spherical_harmonics('0e', r_ang, normalize = 'component')
Y_2 = spherical_harmonics('1o', r_ang, normalize = 'component')
Y_3 = spherical_harmonics('2e', r_ang, normalize = 'component')

features = torch.concatenate([Y_1, Y_2, Y_3], dim = -1)

features.shape

torch.Size([10, 100, 9])

In [131]:
Y_1.shape, Y_2.shape

(torch.Size([10, 100, 1]), torch.Size([10, 100, 3]))

In [132]:
r.shape

torch.Size([10, 100, 3])

In [133]:
l1, l2, l3 = 1, 2, 3
m1, m2, m3 = l1, l2, l3 # m1 = 0, m2 = 0, m3 = 0
C123 = wigner_3j(1, 2, 3)

In [134]:
C123_000 = C123[m1, m2, m3]

C123_000

tensor(0.2928)

In [135]:
C123.shape

torch.Size([3, 5, 7])

In [136]:
class weigner_3j_img_to_real():


    def __init__(self, l):
        self.l = l

    def __call__(self):
        matrix = torch.zeros((2*self.l + 1, 2*self.l + 1), dtype = torch.complex64)

        mult = 1
        for i in range(2*self.l + 1):
            
            if i < self.l:
                matrix[i, i] = 1.0j/2**(1/2.)
                matrix[2*self.l + 1 - i - 1, i] = 1/2**(1/2.)
            elif i == self.l:
                matrix[i, i] = 1.
            else:
                matrix[i, i] = (-1.0)*mult/2**(1/2.)
                matrix[2*self.l + 1 - i - 1, i] = (-1.0j)*mult/2**(1/2.)
                mult *= -1
        
        return matrix

In [137]:
weigner_3j_img_to_real(1)()*2**(1/2.)

tensor([[ 0.0000+1.0000j,  0.0000+0.0000j,  0.0000-1.0000j],
        [ 0.0000+0.0000j,  1.4142+0.0000j,  0.0000+0.0000j],
        [ 1.0000+0.0000j,  0.0000+0.0000j, -1.0000+0.0000j]])

In [138]:
def tri_ineq(l1, l2, l3):
    #print(max([l1, l2, l3]), min([l1 + l2, l2 + l3, l1 + l3]))
    return max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3])

class order_3_equvariant_tensor():
    
    def __init__(self):
        """ Basically transformation using
            C (l1 l2 l3)
              (m1 m2 m3)"""
        
        pass
    
    
    def __call__(self, l1, l2, l3, n1, n2, n3):
        self.l1 = l1; self.l2 = l2; self.l3 = l3
        self.n1 = n1; self.n2 = n2; self.n3 = n3
        
        weight = torch.zeros([n1, n2, n3])
        if tri_ineq(l1, l2, l3):
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            symbol_3j = wigner_3j(l1, l2, l3)
            return symbol_3j.view(*symbol_3j.shape, 1, 1, 1)*W.view(1, 1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, 2*l3 + 1, n1, n2, n3))
        
        
class order_2_equivariant_tensor():
    
    def __init__(self):
        """Makes the 2nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, l2, n1, n2):
        self.l1 = l1; self.l2 = l2
        self.n1 = n1; self.n2 = n2
        
        weight = torch.zeros([n1, n2])
        if l1 == l2: # same as tri_ineq(l1, l2, 0)
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            return W.view(*weight.shape)
        else:
            return torch.zeros((n1, n2))

        
class order_1_equivariant_tensor():
    
    def __init__(self):
        """Makes the 1nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, n1):
        self.l1 = l1;
        self.n1 = n1;
        
        weight = torch.zeros([n1])
        W = nn.Parameter(nn.init.uniform(weight))
        return torch.ones((2*l1 + 1, 1))*W.view(1, *weight.shape)


In [139]:
print(order_1_equivariant_tensor()(2, 3).shape)
print(order_1_equivariant_tensor()(2, 10).shape)

print(order_2_equivariant_tensor()(2, 2, 3, 4).shape)
print(order_2_equivariant_tensor()(2, 2, 3, 4).shape)


print(order_3_equvariant_tensor()(1, 2, 3, 4, 6, 8).sum())
print(order_3_equvariant_tensor()(1, 2, 3, 4, 6, 8).shape)



torch.Size([5, 3])
torch.Size([5, 10])
torch.Size([3, 4])
torch.Size([3, 4])
tensor(-11.5623, grad_fn=<SumBackward0>)
torch.Size([3, 5, 7, 4, 6, 8])


  W = nn.Parameter(nn.init.uniform(weight))


In [140]:
lmax = 3
w3j_big = [[[wigner_3j(i, j, k) if tri_ineq(i, j, k) else None for i in range (lmax)] for j in range(lmax)] for k in range(lmax)]


#len(w3j_big)/lmax**3

In [141]:
w3j_big

[[[tensor([[[1.]]]), None, None],
  [None, tensor([[[0.5774],
            [0.0000],
            [0.0000]],
   
           [[0.0000],
            [0.5774],
            [0.0000]],
   
           [[0.0000],
            [0.0000],
            [0.5774]]]), None],
  [None,
   None,
   tensor([[[0.4472],
            [0.0000],
            [0.0000],
            [0.0000],
            [0.0000]],
   
           [[0.0000],
            [0.4472],
            [0.0000],
            [0.0000],
            [0.0000]],
   
           [[0.0000],
            [0.0000],
            [0.4472],
            [0.0000],
            [0.0000]],
   
           [[0.0000],
            [0.0000],
            [0.0000],
            [0.4472],
            [0.0000]],
   
           [[0.0000],
            [0.0000],
            [0.0000],
            [0.0000],
            [0.4472]]])]],
 [[None,
   tensor([[[0.5774, 0.0000, 0.0000]],
   
           [[0.0000, 0.5774, 0.0000]],
   
           [[0.0000, 0.0000, 0.5774]]]),
   None],
  [

In [142]:
9261*27

250047

In [143]:
tri_ineq(0, 0, 0)

True

In [144]:
C123.shape

torch.Size([3, 5, 7])

### Loading qm9 dataset

In [145]:
# reference example
from nequip.data import dataset_from_config
from nequip.utils import Config
#from nequip.utils.misc import get_default_device_name
#from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS

from nequip.model import model_from_config


default_config = dict(
    root="./",
    tensorboard=False,
    wandb=False,
    model_builders=[
        "SimpleIrrepsConfig",
        "EnergyModel",
        "PerSpeciesRescale",
        "StressForceOutput",
        "RescaleEnergyEtc",
    ],
    dataset_statistics_stride=1,
    device='cuda',
    default_dtype="float64",
    model_dtype="float32",
    allow_tf32=True,
    verbose="INFO",
    model_debug_mode=False,
    equivariance_test=False,
    grad_anomaly_mode=False,
    gpu_oom_offload=False,
    append=False,
    warn_unused=False,
    _jit_bailout_depth=2,  # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
    # Quote from eelison in PyTorch slack:
    # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA
    # > Right now the default behavior is to specialize twice on static shapes and then on dynamic shapes.
    # > To reduce warmup time you can do something like setFusionStrartegy({{FusionBehavior::DYNAMIC, 3}})
    # > ... Although we would wouldn't really expect to recompile a dynamic shape fusion in a model,
    # > provided broadcasting patterns remain fixed
    # We default to DYNAMIC alone because the number of edges is always dynamic,
    # even if the number of atoms is fixed:
    _jit_fusion_strategy=[("DYNAMIC", 3)],
    # Due to what appear to be ongoing bugs with nvFuser, we default to NNC (fuser1) for now:
    # TODO: still default to NNC on CPU regardless even if change this for GPU
    # TODO: default for ROCm?
    _jit_fuser="fuser1",
)

# All default_config keys are valid / requested
#_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys())

In [146]:
config = Config.from_file('./configs/example_ETN.yaml', defaults=default_config)
    

dataset = dataset_from_config(config, prefix="dataset")

validation_dataset = None

dataset.type_mapper.num_types

config['num_types'] = dataset.type_mapper.num_types

In [147]:

import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
from allegro import with_edge_spin_length
from allegro import _keys
from torch import nn
import math

data = [AtomicData.to_AtomicDataDict(dataset[i]) for i in range(10)]

In [148]:
len(data)

10

In [149]:
i = 2

num_types = config['num_types']


atom_types_embed = data[i]['atom_types'][data[i]['edge_index'][0]]*num_types + data[i]['atom_types'][data[i]['edge_index'][1]]

atom_types_embed.shape

torch.Size([368, 1])

In [150]:
from nequip.data import AtomicDataDict, AtomicDataset
from nequip.nn.embedding import (
    OneHotAtomEncoding,
    SphericalHarmonicEdgeAttrs,
    RadialBasisEdgeEncoding,
)

In [151]:
from e3nn import o3

In [152]:
o3.Irreps.spherical_harmonics(2)

1x0ee+1x1oe+1x2ee

In [153]:
from nequip.data import AtomicDataDict, AtomicDataset
from nequip.nn.embedding import (
    OneHotAtomEncoding,
    SphericalHarmonicEdgeAttrs,
    RadialBasisEdgeEncoding,
)
from e3nn import o3

L = 2

irreps_edge_sh = o3.Irreps.spherical_harmonics(L)

rbe = RadialBasisEdgeEncoding(basis_kwargs={'r_max': config['r_max'], 
                                            'num_basis': 8},
                              cutoff_kwargs={'r_max': config['r_max']},
                              out_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
                              )

sh = SphericalHarmonicEdgeAttrs(irreps_edge_sh=irreps_edge_sh)

data = [rbe(data[i]) for i in range(len(data))]

data = [sh(data[i]) for i in range(len(data))]


print(data[i]['edge_embedding'].shape, data[i]['edge_attrs'].shape)

torch.Size([368, 8]) torch.Size([368, 9])


In [154]:
data[1].keys()

dict_keys(['edge_index', 'pos', 'edge_cell_shift', 'total_energy', 'cell', 'pbc', 'forces', 'atom_types', 'edge_vectors', 'edge_lengths', 'edge_embedding', 'edge_cutoff', 'edge_attrs'])

## Feature Gen

In [155]:
from torch_runstats.scatter import scatter

N_rad = 8

N_spec_rank = 4
N_rad_rank = 4

Q = data[i]['edge_embedding']

A = torch.randn(L + 1, N_spec_rank, num_types**2)
B = torch.randn(L + 1, N_rad_rank, N_rad, N_spec_rank)


a = A[:, :, atom_types_embed].squeeze(-1)

b = torch.einsum('Lrnk,LkE,En->ELr', B, a, Q)

Y = data[i]['edge_attrs']

#print(data[i]['edge_attrs'][:, slices])
F = torch.concat([torch.einsum('Em,En->Emn', Y[:, slices],
                               b[:, l]) for l, slices in enumerate(irreps_edge_sh.slices())], dim = -2)

print(a.shape, b.shape, F.shape)


species = data[i][AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1)
edge_center = data[i][AtomicDataDict.EDGE_INDEX_KEY][0]
edge_neighbor = data[i][AtomicDataDict.EDGE_INDEX_KEY][1]

center_species = species[edge_center]
neighbor_species = species[edge_neighbor]

F = scatter(F, edge_center, dim=0, dim_size=len(species))

print(F.shape)

torch.Size([3, 4, 368]) torch.Size([368, 3, 4]) torch.Size([368, 9, 4])
torch.Size([21, 9, 4])


### Order 2 tensor

In [156]:
n1 = 4
n2 = 32
L1 = L2 = range(L + 1)

u_in = F.clone()
T_2 = [order_2_equivariant_tensor()(l, l, n1, n2) for l in L1]


u_out = torch.zeros((u_in.shape[0], u_in.shape[1], n2))
for i, slices in enumerate(irreps_edge_sh.slices()):
    u_out[:, slices, :] = torch.einsum('ij,Nmi->Nmj', T_2[i], u_in[:, slices, :])

### Order 3 tensor

In [157]:
n3 = 32
L3 = range(L + 1)

T_3 = [[[order_3_equvariant_tensor()(l3, l1, l2, n3, n1, n2) for l2 in L2] for l1 in L1] for l3 in L3]

In [158]:
u_in = u_out

#T_2_tmp = torch.zeros((u_in.shape[0], u_in.shape[1], n2))

T_2_tmp = [[None for l1 in L1] for l3 in L3]

u_out = torch.zeros((u_in.shape[0], u_in.shape[1], n3))


v = F.clone()
for l2, slices in enumerate(irreps_edge_sh.slices()):
    if l2 == 0:
        for l3 in L3:
            for l1 in L1:
                T_2_tmp[l3][l1] = torch.einsum('abcijk,Nck->Nabij', T_3[l3][l1][l2], u_in[:, slices, :])
    else:
        for l3 in L3:
            for l1 in L1:
                T_2_tmp[l3][l1] += torch.einsum('abcijk,Nck->Nabij', T_3[l3][l1][l2], u_in[:, slices, :])
    
    
for l3 in L3:    
    for l1, slices in enumerate(irreps_edge_sh.slices()):
        u_out[:, irreps_edge_sh.slices()[l3], :] += torch.einsum('Nabij,Nbj->Nai', T_2_tmp[l3][l1], v[:, slices, :])
        
        
u_out.min()

tensor(-23.5623, grad_fn=<MinBackward1>)

In [159]:
u_out.max()

tensor(19.3212, grad_fn=<MaxBackward1>)

In [160]:
u_out.shape

torch.Size([21, 9, 32])

In [161]:
max([ T_2[i][j].max() for i in [0, 1, 2] for j in [0, 1, 2] ])

tensor(0.4326, grad_fn=<MaxBackward1>)

In [162]:
T_2[1].shape

torch.Size([4, 32])

In [163]:
u_out.shape

torch.Size([21, 9, 32])

### Equivariance testing

In [164]:
x = torch.Tensor([[1, 12, 34]]*10)

edge_vec_plus = x / torch.linalg.norm(x, dim = 1).unsqueeze(-1)
edge_vec_minus = -x / torch.linalg.norm(x, dim = 1).unsqueeze(-1)

irreps_sh = o3.Irreps('1x0e + 1x1o + 1x2e') #o3.Irreps.spherical_harmonics(lmax=2)
irreps_sh_r = o3.Irreps('1x1o')
print(irreps_sh)


alpha, beta, gamma = o3.rand_angles(100)

rot_matrix = irreps_sh.D_from_angles(alpha[0], beta[0], gamma[0])
rot_matrix_r = irreps_sh_r.D_from_angles(alpha[0], beta[0], gamma[0])


sh_plus = o3.spherical_harmonics(irreps_sh, edge_vec_plus, normalize=True)

sh_plus_of_rot =  o3.spherical_harmonics(irreps_sh, edge_vec_plus @ rot_matrix_r, 
                                        normalize=True)

sh_plus_rot = o3.spherical_harmonics(irreps_sh, edge_vec_plus, 
                                        normalize=True) @ rot_matrix

sh_plus_of_rot_rot = o3.spherical_harmonics(irreps_sh, edge_vec_plus @ rot_matrix_r, 
                                        normalize=True) @ rot_matrix.T





sh_minus = o3.spherical_harmonics(irreps_sh, edge_vec_minus, normalize=True)
# normalize=True ensure that x is divided by |x| before computing the sh

sh_plus.pow(2).mean()  # should be close to 1

1x0ee+1x1oe+1x2ee


tensor(0.0796)

In [165]:
sh_plus_rot[0]

tensor([ 0.2821,  0.0076, -0.1719,  0.4573,  0.0159, -0.0060, -0.1983, -0.3597,
         0.4784])

In [166]:
print(sh_plus[0])

print(sh_plus_of_rot_rot[0])

tensor([ 0.2821,  0.0135,  0.1626,  0.4606,  0.0286,  0.0101, -0.2107,  0.3426,
         0.4850])
tensor([ 0.2821,  0.0135,  0.1626,  0.4606,  0.0286,  0.0101, -0.2107,  0.3426,
         0.4850])


In [167]:
def stage_0(data, i):
    

    from nequip.data import AtomicDataDict, AtomicDataset
    from nequip.nn.embedding import (
        OneHotAtomEncoding,
        SphericalHarmonicEdgeAttrs,
        RadialBasisEdgeEncoding,
    )
    from e3nn import o3

    
    #torch.manual_seed(32)
    
    L = 2

    irreps_edge_sh = o3.Irreps.spherical_harmonics(L)
    print(irreps_edge_sh)
    
    rbe = RadialBasisEdgeEncoding(basis_kwargs={'r_max': config['r_max'], 
                                                'num_basis': 8},
                                  cutoff_kwargs={'r_max': config['r_max']},
                                  out_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
                                  )

    sh = SphericalHarmonicEdgeAttrs(irreps_edge_sh=irreps_edge_sh)

    data = [rbe(data[i]) for i in range(len(data))]

    data = [sh(data[i]) for i in range(len(data))]


    #print(data[i]['edge_embedding'].shape, data[i]['edge_attrs'].shape)
    
    
    from torch_runstats.scatter import scatter

    N_rad = 8

    N_spec_rank = 4
    N_rad_rank = 4

    Q = data[i]['edge_embedding']

    A = torch.randn(L + 1, N_spec_rank, num_types**2)
    B = torch.randn(L + 1, N_rad_rank, N_rad, N_spec_rank)


    a = A[:, :, atom_types_embed].squeeze(-1)

    b = torch.einsum('Lrnk,LkE,En->ELr', B, a, Q)

    #print(data[i]['pos'][2])
    Y = data[i]['edge_attrs']
    #print(Y[2])
    #print(data[i]['edge_attrs'][:, slices])
    F = torch.concat([torch.einsum('Em,En->Emn', Y[:, slices],
                                   b[:, l]) for l, slices in enumerate(irreps_edge_sh.slices())], dim = -2)
    
    print(F[2, :, 0])
    #F = Y.unsqueeze(-1)
    #print(a.shape, b.shape, F.shape)


    species = data[i][AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1)
    edge_center = data[i][AtomicDataDict.EDGE_INDEX_KEY][0]
    edge_neighbor = data[i][AtomicDataDict.EDGE_INDEX_KEY][1]

    center_species = species[edge_center]
    neighbor_species = species[edge_neighbor]

    F = scatter(F, edge_center, dim=0, dim_size=len(species))

    #return F
    
    return data[i]['edge_attrs'].unsqueeze(-1)

In [168]:

import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
from allegro import with_edge_spin_length
from allegro import _keys
from torch import nn
import math

data = [AtomicData.to_AtomicDataDict(dataset[i]) for i in range(10)]


import copy

data_new = [{key: torch.clone(el[key]) for key in el} for el in data]

irreps_sh = o3.Irreps('1x0e + 1x1o + 1x2e') #o3.Irreps.spherical_harmonics(lmax=2)
irreps_sh_r = o3.Irreps('1x1o')

alpha, beta, gamma = o3.rand_angles(100)

rot_matrix = irreps_sh.D_from_angles(alpha[0], beta[0], gamma[0])
rot_matrix_r = irreps_sh_r.D_from_angles(alpha[0], beta[0], gamma[0])


for i, el in enumerate(data_new):
    data_new[i]['pos'] = data_new[i]['pos'] @ rot_matrix_r

In [169]:
data[2]

{'edge_index': tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
           1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
           3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,
           4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,
           5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,
           6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
           6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
           7,  7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
           8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,
           9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
          10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
  

In [170]:
F = stage_0(data, 2)

F_new = stage_0(data_new, 2)

#F = data[2]['edge_attrs'].unsqueeze(-1)
#F_new = data_new[2]['edge_attrs'].unsqueeze(-1)

F_new_rot = torch.einsum('Njn,jk->Nkn', F_new, rot_matrix.T)

1x0ee+1x1oe+1x2ee
tensor([-0.4745, -0.2149, -0.1613,  0.2874, -0.4099,  0.2300, -0.1470, -0.3077,
         0.1209], grad_fn=<SelectBackward0>)
1x0ee+1x1oe+1x2ee
tensor([-0.4563,  1.6473,  0.0662,  0.4577,  0.1852,  0.0268, -0.2067,  0.0074,
        -0.3077], grad_fn=<SelectBackward0>)


In [171]:
torch.allclose(F, F_new_rot, atol=1e-03)

True

In [172]:
F[0, :, 0]

tensor([ 1.0000, -1.6557, -0.1353,  0.4903, -1.0481,  0.2892, -1.0976, -0.0857,
        -1.6143])

In [173]:
F_new_rot[0, :, 0]

tensor([ 1.0000, -1.6557, -0.1353,  0.4903, -1.0481,  0.2892, -1.0976, -0.0857,
        -1.6143])

In [174]:
def stage_1(F):
    
    #torch.manual_seed(32)
    
    n1 = 4
    n2 = 32
    L1 = L2 = range(L + 1)

    u_in = F.clone()
    T_2 = [order_2_equivariant_tensor()(l, l, n1, n2) for l in L1]


    u_out = torch.zeros((u_in.shape[0], u_in.shape[1], n2))
    for i, slices in enumerate(irreps_edge_sh.slices()):
        u_out[:, slices, :] = torch.einsum('ij,Nmi->Nmj', T_2[i], u_in[:, slices, :])
        
    return u_out

In [175]:
u_out = stage_1(F)

u_out_new = stage_1(F_new)

u_out_new_rot = torch.einsum('Njn,jk->Nkn', u_out_new, rot_matrix.T)

In [176]:
torch.allclose(u_out, u_out_new_rot, atol=1e-03)

False

In [177]:
list(range(3 - 2, 0, -1))

[1]

In [178]:
def stage_2(u_in, F):
    
    torch.manual_seed(32)
    
    n1 = 4
    n2 = 32
    L1 = L2 = range(L + 1)
    
    n3 = 32
    L3 = range(L + 1)

    T_3 = [[[order_3_equvariant_tensor()(l3, l1, l2, n3, n1, n2) for l2 in L2] for l1 in L1] for l3 in L3]

    #T_2_tmp = torch.zeros((u_in.shape[0], u_in.shape[1], n2))

    T_2_tmp = [[None for l1 in L1] for l3 in L3]

    u_out = torch.zeros((u_in.shape[0], u_in.shape[1], n3))


    v = F.clone()
    for l2, slices in enumerate(irreps_edge_sh.slices()):
        if l2 == 0:
            for l3 in L3:
                for l1 in L1:
                    T_2_tmp[l3][l1] = torch.einsum('abcijk,Nck->Nabij', T_3[l3][l1][l2], u_in[:, slices, :])
        else:
            for l3 in L3:
                for l1 in L1:
                    T_2_tmp[l3][l1] += torch.einsum('abcijk,Nck->Nabij', T_3[l3][l1][l2], u_in[:, slices, :])


    for l3 in L3:    
        for l1, slices in enumerate(irreps_edge_sh.slices()):
            u_out[:, irreps_edge_sh.slices()[l3], :] += torch.einsum('Nabij,Nbj->Nai', T_2_tmp[l3][l1], v[:, slices, :])

            
    return u_out

In [179]:
u_out_2 = stage_2(u_out, F)

u_out_2_new = stage_2(u_out_new, F_new)

u_out_2_new_rot = torch.einsum('Njn,jk->Nkn', u_out_2_new, rot_matrix.T)

In [180]:
torch.allclose(u_out_2, u_out_2_new_rot, atol=1e-03)

False

### Reshaping

In [219]:
def tri_ineq(l1, l2, l3):
    #print(max([l1, l2, l3]), min([l1 + l2, l2 + l3, l1 + l3]))
    return max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3])

class order_3_equvariant_tensor():
    
    def __init__(self, l1, l2, l3, n1, n2, n3):
        """ Basically transformation using
            C (l1 l2 l3)
              (m1 m2 m3)"""
        self.l1 = l1; self.l2 = l2; self.l3 = l3
        self.n1 = n1; self.n2 = n2; self.n3 = n3
    
        weight = torch.zeros([n1, n2, n3])
        self.weight = nn.Parameter(nn.init.kaiming_uniform_(weight))
        
        
        self.symbol_3j = wigner_3j(l1, l2, l3)
        
        
    def __call__(self):
        
        if tri_ineq(l1, l2, l3):
            return self.symbol_3j.view(*symbol_3j.shape, 1, 1, 1)*self.weight.view(1, 1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, 2*l3 + 1, n1, n2, n3))
        
        
def contract_2_tensors(tensor_1, tensor_2):
    weight_new = (tensor_1.weight.flatten(end_dim = -2) @ tensor_2.weight.flatten(start_dim = 1)).view(tensor_1.n1, tensor_1.n2, tensor_2.n2,
                                                                                                       tensor_2.n3)
    
    
    return weight_new, tensor_1.symbol_3j, tensor_2.symbol_3j
    
        
class order_2_equivariant_tensor():
    
    def __init__(self):
        """Makes the 2nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, l2, n1, n2):
        self.l1 = l1; self.l2 = l2
        self.n1 = n1; self.n2 = n2
        
        weight = torch.zeros([n1, n2])
        if l1 == l2: # same as tri_ineq(l1, l2, 0)
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            return W.view(*weight.shape)
        else:
            return torch.zeros((n1, n2))

        
class order_1_equivariant_tensor():
    
    def __init__(self):
        """Makes the 1nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, n1):
        self.l1 = l1;
        self.n1 = n1;
        
        weight = torch.zeros([n1])
        W = nn.Parameter(nn.init.uniform(weight))
        return torch.ones((2*l1 + 1, 1))*W.view(1, *weight.shape)


In [220]:
tensor_1 = order_3_equvariant_tensor(0, 0, 0, 4, 8, 32)
tensor_2 = order_3_equvariant_tensor(0, 0, 0, 32, 8, 8)

tensor_4th_order_weight = contract_2_tensors(tensor_1, tensor_2)[0]

tensor_4th_order_weight.shape

torch.Size([4, 8, 8, 8])

In [221]:
sum([max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3]) for l1 in range(10) for l2 in range(10) for l3 in range(10)])




505

In [782]:
torch.manual_seed(32)

class order_3_equvariant_tensor_free_params():
    
    def __init__(self):
        """ Basically transformation using
            C (l1 l2 l3)
              (m1 m2 m3)"""
        
        pass
    
    
    def __call__(self, l1, l2, l3, n1, n2, n3):
        self.l1 = l1; self.l2 = l2; self.l3 = l3
        self.n1 = n1; self.n2 = n2; self.n3 = n3
        
        weight = torch.zeros([n1, n2, n3])
        if tri_ineq(l1, l2, l3):
            nn.Parameter(nn.init.normal_(weight))
            #symbol_3j = wigner_3j(l1, l2, l3)
            return weight
        else:
            return torch.zeros((n1, n2, n3))

In [783]:
torch.manual_seed(32)

d = 5
lmax = 2
N_rank_ett = [4 for i in range(d-1)]
Nc = [10]*d
cores = [torch.zeros(lmax+1, N_rank_ett[i], lmax+1, Nc[i], lmax+1, N_rank_ett[i+1]) for i in range(d-2)]

for dd in range(d-2):
    for l1 in range(lmax + 1):
        for l2 in range(lmax + 1):
            for l3 in range(lmax + 1):
                cores[dd][l1, :, l2, :, l3, :] = order_3_equvariant_tensor_free_params()(l1,l2,l3, N_rank_ett[dd], Nc[dd], N_rank_ett[dd+1])
                
cores0 = torch.zeros(lmax+1, Nc[0], lmax+1, N_rank_ett[0])
coresd = torch.zeros(lmax+1, N_rank_ett[0], lmax+1, Nc[-1])

for i in range(lmax + 1):
    weight = torch.empty([Nc[0], N_rank_ett[0]])
    nn.Parameter(nn.init.normal_(weight))
    cores0[i, :, i, :] =  weight
    
    weight = torch.empty([N_rank_ett[-1], Nc[-1]])
    nn.Parameter(nn.init.normal_(weight))
    coresd[i, :, i, :] =  weight

cores = [cores0] + cores + [coresd]

N_rank_ett = [1] + N_rank_ett + [1]

In [784]:
N_rank_ett[d-2]

4

In [785]:
N_rank_ett

[1, 4, 4, 4, 4, 1]

In [786]:
for i in range(d - 2, 0, -1):
    print(cores[i].shape)
    print(N_rank_ett[i], N_rank_ett[i-1])

torch.Size([3, 4, 3, 10, 3, 4])
4 4
torch.Size([3, 4, 3, 10, 3, 4])
4 4
torch.Size([3, 4, 3, 10, 3, 4])
4 1


In [787]:
def stage_0(data, i):
    

    from nequip.data import AtomicDataDict, AtomicDataset
    from nequip.nn.embedding import (
        OneHotAtomEncoding,
        SphericalHarmonicEdgeAttrs,
        RadialBasisEdgeEncoding,
    )
    from e3nn import o3

    
    torch.manual_seed(32)
    
    L = 2

    irreps_edge_sh = o3.Irreps.spherical_harmonics(L, p = 1)
    print(irreps_edge_sh)
    
    rbe = RadialBasisEdgeEncoding(basis_kwargs={'r_max': config['r_max'], 
                                                'num_basis': 8},
                                  cutoff_kwargs={'r_max': config['r_max']},
                                  out_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
                                  )

    sh = SphericalHarmonicEdgeAttrs(irreps_edge_sh=irreps_edge_sh)

    data = [rbe(data[i]) for i in range(len(data))]

    data = [sh(data[i]) for i in range(len(data))]


    #print(data[i]['edge_embedding'].shape, data[i]['edge_attrs'].shape)
    
    
    from torch_runstats.scatter import scatter

    N_rad = 8

    N_spec_rank = 4
    N_rad_rank = 4

    Q = data[i]['edge_embedding']

    A = torch.randn(L + 1, N_spec_rank, num_types**2)
    B = torch.randn(L + 1, N_rad_rank, N_rad, N_spec_rank)


    a = A[:, :, atom_types_embed].squeeze(-1)

    b = torch.einsum('Lrnk,LkE,En->ELr', B, a, Q)

    #print(data[i]['pos'][2])
    Y = data[i]['edge_attrs']
    #print(Y[2])
    #print(data[i]['edge_attrs'][:, slices])
    F = torch.concat([torch.einsum('Em,En->Emn', Y[:, slices],
                                   b[:, l]) for l, slices in enumerate(irreps_edge_sh.slices())], dim = -2)
    
    #print(F[2, :, 0])
    #F = Y.unsqueeze(-1)
    #print(a.shape, b.shape, F.shape)


    species = data[i][AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1)
    edge_center = data[i][AtomicDataDict.EDGE_INDEX_KEY][0]
    edge_neighbor = data[i][AtomicDataDict.EDGE_INDEX_KEY][1]

    center_species = species[edge_center]
    neighbor_species = species[edge_neighbor]

    F = scatter(F, edge_center, dim=0, dim_size=len(species))

    #return F
    
    return data[i]['edge_attrs'].unsqueeze(-1)

In [788]:
from e3nn.o3 import wigner_3j

def tri_ineq(l1, l2, l3):
    #print(max([l1, l2, l3]), min([l1 + l2, l2 + l3, l1 + l3]))
    return max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3])


def ETN(data, i, cores, ranks, Nc, lmax = 2):

    torch.manual_seed(32)
    
    F = stage_0(data, i)
    
    # Init irreps
    irreps_edge_sh = o3.Irreps.spherical_harmonics(lmax, p = 1)
    slices = irreps_edge_sh.slices()
    #Init w3j
    w3j_big = [[[wigner_3j(l1, l2, l3) if tri_ineq(l1, l2, l3) else None for l3 in range (lmax+1)] for l2 in range(lmax+1)] for l1 in range(lmax+1)]
    
    
    u_out = torch.zeros((F.shape[0], F.shape[1], ranks[-2]))
    for i, slice in enumerate(slices):
        u_out[:, slice, :] = torch.einsum('ij,Nmj->Nmi', cores[-1][i, :, i, :], F[:, slice, :])    
    
    
    
    # Series third order tensors
    for i in range(d - 2, 0, -1):

        # TODO: now define localy, mb define for all, to ensure computational graph
        T_2_tmp = [[torch.zeros(F.shape[0], 2*l1+1, ranks[i], 2*l2+1, Nc[i], dtype=F.dtype, device=F.device) for l2 in range(lmax + 1)] for l1 in range(lmax + 1)] # result of first reduction of order 3 tensor

        #print(i)
        # First contraction with previous feature vector
        for l1 in range(lmax + 1):
            for l2 in range(lmax + 1):
                for l3, slice in enumerate(slices):
                    if tri_ineq(l1, l2, l3):
                        #T_3 = self.w3j_big[l1][l2][l3][..., None, None, None] * self.cores3[i][(l1, l2, l3)][None, None, None, ...]
                        T_2_tmp[l1][l2] += torch.einsum('abc,ijk,Nck->Naibj', w3j_big[l1][l2][l3], cores[i][l1, :, l2, :, l3, :], u_out[:, slice, :])


       # Second contraction with F vector
        u_out_new = torch.zeros((F.shape[0], F.shape[1], ranks[i]), dtype=F.dtype,
                        device=F.device) # temporary verctor output of etn

        for l1 in range(lmax + 1):    
            for l2, slice in enumerate(slices):
                u_out_new[:, slices[l1], :] += torch.einsum('Naibj,Nbj->Nai', T_2_tmp[l1][l2], F[:, slice, :])


        u_out = u_out_new

    
    u_final = torch.zeros((F.shape[0], F.shape[1], Nc[0]))
    for i, slices in enumerate(slices):
        u_final[:, slices, :] = torch.einsum('ij,Nmj->Nmi', cores[0][i, :, i, :], u_out[:, slices, :])
        
    
    
    out = (( u_final * F ).sum(dim = (-2, -1) )).unsqueeze(-1)
    
    return u_final, out

In [789]:
u_final, out = ETN(data, 2, cores, N_rank_ett, Nc)

1x0ee+1x1ee+1x2ee


In [790]:
out

tensor([[-17220.9043],
        [-17220.9219],
        [-17220.9277],
        [-17220.9023],
        [-17220.9336],
        [-17220.9043],
        [-17220.8926],
        [-17220.9023],
        [-17220.9219],
        [-17220.9082],
        [-17220.9219],
        [-17220.8945],
        [-17220.9238],
        [-17220.9219],
        [-17220.9258],
        [-17220.8926],
        [-17220.9023],
        [-17220.9414],
        [-17220.9180],
        [-17220.9023],
        [-17220.9004],
        [-17220.9336],
        [-17220.9160],
        [-17220.9375],
        [-17220.8984],
        [-17220.9160],
        [-17220.9434],
        [-17220.9102],
        [-17220.9023],
        [-17220.9062],
        [-17220.9355],
        [-17220.9238],
        [-17220.9102],
        [-17220.9180],
        [-17220.9258],
        [-17220.9297],
        [-17220.9180],
        [-17220.9004],
        [-17220.9141],
        [-17220.9219],
        [-17220.9180],
        [-17220.9238],
        [-17220.9062],
        [-1

In [791]:
import torch as tn

def QR(mat):
    """
    Compute the QR decomposition. Backend can be changed.

    Parameters
    ----------
    mat : tn array
        DESCRIPTION.

    Returns
    -------
    Q : the Q matrix
    R : the R matrix

    """
    Q,R = tn.linalg.qr(mat)
    return Q, R

In [792]:
torch.manual_seed(32)


d = 5
lmax = 2
N_rank_ett = [4 for i in range(d-1)]
Nc = [10]*d
cores = [torch.zeros(lmax+1, N_rank_ett[i], lmax+1, Nc[i], lmax+1, N_rank_ett[i+1]) for i in range(d-2)]

for dd in range(d-2):
    for l1 in range(lmax + 1):
        for l2 in range(lmax + 1):
            for l3 in range(lmax + 1):
                cores[dd][l1, :, l2, :, l3, :] = order_3_equvariant_tensor_free_params()(l1,l2,l3, N_rank_ett[dd], Nc[dd], N_rank_ett[dd+1])

                
#for dd in range(d-3):
#    cores_cur_tmp = cores[dd].flatten(0, -3)
#    cores_next_tmp = cores[dd+1].flatten(2)
    
#    for l in range(lmax + 1):
#        Q, R = QR(cores_cur_tmp[:, l, :])
#        cores[dd][..., l, :] = Q.reshape(cores[dd][..., l, :].shape)#/(2*l + 1)
#        cores[dd+1][l] = (R @ cores_next_tmp[l]).reshape(cores[dd + 1][l].shape)#/(2*l + 1)
                
cores0 = torch.zeros(lmax+1, Nc[0], lmax+1, N_rank_ett[0])
coresd = torch.zeros(lmax+1, N_rank_ett[0], lmax+1, Nc[-1])


for i in range(lmax + 1):
    weight = torch.empty([Nc[0], N_rank_ett[0]])
    nn.Parameter(nn.init.normal_(weight))
    cores0[i, :, i, :] =  weight
    
    weight = torch.empty([N_rank_ett[-1], Nc[-1]])
    nn.Parameter(nn.init.normal_(weight))
    coresd[i, :, i, :] =  weight

cores = [cores0] + cores + [coresd]

N_rank_ett = [1] + N_rank_ett + [1]

In [793]:
def ortho_simple(cores):
    """Simple orthogonalization for cores
       works like in the paper"""
    
    d = len(cores)
    lmax = cores[0].shape[0] - 1
    
    # Core 0
    cores_cur_tmp = cores[0]
    cores_next_tmp = cores[1].flatten(2)
    
    for l in range(lmax + 1):
        Q, R = QR(cores_cur_tmp[l, :, l, :])
        cores[0][l, :, l, :] = Q.reshape(cores[0][l, :, l, :].shape)#/(2*l + 1)
        cores[1][l] = (R @ cores_next_tmp[l]).reshape(cores[1][l].shape)#/(2*l + 1)
        
    
    
    # Cores triple
    for dd in range(1, d-2):
        cores_cur_tmp = cores[dd].flatten(0, -3)
        cores_next_tmp = cores[dd+1].flatten(2)

        for l in range(lmax + 1):
            Q, R = QR(cores_cur_tmp[:, l, :])
            cores[dd][..., l, :] = Q.reshape(cores[dd][..., l, :].shape)#/(2*l + 1)
            cores[dd+1][l] = (R @ cores_next_tmp[l]).reshape(cores[dd + 1][l].shape)#/(2*l + 1)
   

    # Core d
    cores_cur_tmp = cores[-2].flatten(0, -3)
    cores_next_tmp = cores[-1]
    
    for l in range(lmax + 1):
        Q, R = QR(cores_cur_tmp[:, l, :])
        cores[-2][..., l, :] = Q.reshape(cores[-2][..., l, :].shape)#/(2*l + 1)
        cores[-1][l, :, l, :] = (R @ cores_next_tmp[l, :, l, :]).reshape(cores[-1][l, :, l, :].shape)#/(2*l + 1)
    
    return cores

In [794]:
cores = ortho_simple(cores)

In [795]:
cores[3].flatten(0, -3)[:, 2].T @ cores[3].flatten(0, -3)[:, 2]

tensor([[ 1.0000e+00,  3.7253e-09,  1.4901e-08, -7.4506e-09],
        [ 3.7253e-09,  1.0000e+00, -1.8626e-09,  2.2352e-08],
        [ 1.4901e-08, -1.8626e-09,  1.0000e+00, -3.7253e-09],
        [-7.4506e-09,  2.2352e-08, -3.7253e-09,  1.0000e+00]])

In [796]:
u_final_new, out_new = ETN(data, 2, cores, N_rank_ett, Nc)

1x0ee+1x1ee+1x2ee


In [797]:
torch.allclose(out, out_new)

True

In [803]:
u_final[1, :, 0] / u_final_new[1, :, 0]

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [798]:
torch.allclose(u_final, u_final_new)

False

In [805]:
out.sum()

tensor(-6337296.5000)

In [806]:
out_new.sum()

tensor(-6337292.5000)

In [807]:
out_new[0], out[0]

(tensor([-17220.9023]), tensor([-17220.9043]))

In [579]:
torch.manual_seed(32)

import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
from allegro import with_edge_spin_length
from allegro import _keys
from torch import nn
import math

data = [AtomicData.to_AtomicDataDict(dataset[i]) for i in range(10)]


import copy

data_new = [{key: torch.clone(el[key]) for key in el} for el in data]

irreps_sh = o3.Irreps.spherical_harmonics(lmax=2, p = 1)
irreps_sh_r = o3.Irreps('1x1o')

alpha, beta, gamma = o3.rand_angles(100)

rot_matrix = irreps_sh.D_from_angles(alpha[0], beta[0], gamma[0])
rot_matrix_r = irreps_sh_r.D_from_angles(alpha[0], beta[0], gamma[0])


for i, el in enumerate(data_new):
    data_new[i]['pos'] = data_new[i]['pos'] @ rot_matrix_r

In [580]:
u_final_rot, out_new_rot = ETN(data_new, 2, cores, N_rank_ett, Nc)

1x0ee+1x1ee+1x2ee


In [581]:
out_new_rot 

tensor([[-17220.9258],
        [-17220.9023],
        [-17220.9062],
        [-17220.9160],
        [-17220.8984],
        [-17220.8984],
        [-17220.9023],
        [-17220.8965],
        [-17220.9043],
        [-17220.9297],
        [-17220.9141],
        [-17220.9297],
        [-17220.9258],
        [-17220.9023],
        [-17220.9277],
        [-17220.8867],
        [-17220.9238],
        [-17220.9004],
        [-17220.9082],
        [-17220.9258],
        [-17220.9102],
        [-17220.9004],
        [-17220.9082],
        [-17220.9023],
        [-17220.9336],
        [-17220.8867],
        [-17220.9258],
        [-17220.9023],
        [-17220.8984],
        [-17220.9102],
        [-17220.9453],
        [-17220.8984],
        [-17220.9277],
        [-17220.9121],
        [-17220.9336],
        [-17220.9316],
        [-17220.9023],
        [-17220.9336],
        [-17220.9219],
        [-17220.9043],
        [-17220.8984],
        [-17220.9062],
        [-17220.8867],
        [-1

In [582]:
out_new.sum(), out_new_rot.sum()

(tensor(-6337295.), tensor(-6337295.))

In [583]:
torch.allclose(out_new_rot, out_new)

True

In [584]:
u_final[1, :, 0]

tensor([  227.6372,  3350.7109, -7468.5576,  4818.2109,    34.0287,   -52.7466,
           46.9218,   -75.8481,    12.6338])

In [585]:
u_final_rot[0, :, 0]

tensor([  227.6365, -1593.0637, -8149.4727,  4611.9434,   -15.4862,    27.3645,
           66.3346,   -79.2211,    19.7419])

In [586]:
ETN_out_rot_rot = torch.einsum('Njn,jk->Nkn', u_final_rot, rot_matrix.T) # rotated ETN out features from rotated positions

In [587]:
torch.allclose(ETN_out_rot_rot, u_final)

False

In [588]:
ETN_out_rot_rot[1, :, 0]

tensor([  227.6384,  3350.7002, -7468.5659,  4818.1924,    34.0282,   -52.7467,
           46.9226,   -75.8486,    12.6338])

In [589]:
from itertools import product

L1 = range(3)
L2 = range(3)
L3 = range(3)

instr = []
for l3, l1, l2 in product(L3, L1, L2):
    if tri_ineq(l1, l2, l3):
        instr.append((l1, l2, l3))

In [590]:
len(instr)

15

In [591]:
irreps_edge_sh = o3.Irreps.spherical_harmonics(2, p=1)

In [592]:
from typing import Optional, List, Tuple

# Assume irreps does not change 
print(irreps_edge_sh)
base_in1 = o3.Irreps([el[1] for el in irreps_edge_sh])
base_in2 = o3.Irreps([el[1] for el in irreps_edge_sh])
base_out = o3.Irreps([el[1] for el in irreps_edge_sh])


# Building instructions
instructions: List[Tuple[int, int, int]] = []
tmp_i_out: int = 0
for i_out, (_, ir_out) in enumerate(base_out):
    for i_1, (_, ir_in1) in enumerate(base_in1):
        for i_2, (_, ir_in2) in enumerate(base_in2):
            if ir_out in ir_in1 * ir_in2:
                instructions.append((i_1, i_2, i_out))

                tmp_i_out += 1
                
                
instructions_1 = [(0, l, l) for l in range(irreps_edge_sh.lmax + 1)]
instructions_d = [(l, l, 0) for l in range(irreps_edge_sh.lmax + 1)]

1x0ee+1x1ee+1x2ee


In [593]:
instructions == instr

True

In [717]:
torch.manual_seed(40)

from torch.nn import Parameter, ParameterList

d = 5
lmax = 2
N_rank_ett = [4 for i in range(d-1)]
Nc = [10]*d

instruction_list = [instructions_1] + [instructions for i in range(d - 2)] + [instructions_d]

num_paths = len(instr)

cores = [torch.zeros(lmax+1, N_rank_ett[i], lmax+1, Nc[i], lmax+1, N_rank_ett[i+1]) for i in range(d-2)]

cores0 = Parameter(torch.empty(lmax+1, 1, Nc[0], N_rank_ett[0]).normal_())
coresd = Parameter(torch.empty(lmax+1, N_rank_ett[-1], Nc[-1], 1).normal_())

cores = ParameterList([cores0] 
                      + [Parameter(torch.empty(num_paths, N_rank_ett[r], Nc[r+1], N_rank_ett[r+1]).normal_()) for r in range(d - 2)] 
                      + [coresd]) 

N_rank_ett = [1] + N_rank_ett + [1]

In [718]:
def convert_to_dense(cores_old, instruction_list):
    """Converts the path matrix to dense format 
       indexed as l1, :,l2, :, l3, :"""
    
    lmax = cores_old[0].shape[0] - 1
    d = len(cores_old)
    
    cores0 = torch.zeros(lmax + 1, cores_old[0].shape[2], 
                         lmax + 1, cores_old[0].shape[-1])
    
    coresd = torch.zeros(lmax + 1, cores_old[-1].shape[1],
                         lmax + 1, cores_old[-1].shape[2])

    for i in range(lmax + 1):
        cores0[i, :, i, :] =  cores_old[0][i, 0]

        coresd[i, :, i, :] =  cores_old[-1][i, :, :, 0]
    
    cores = [torch.zeros(lmax+1, cores_old[i+1].shape[1], 
                         lmax+1, cores_old[i+1].shape[2],
                         lmax+1, cores_old[i+1].shape[3]) for i in range(d-2)]

    for dd in range(d-2):
        for i, (l1, l2, l3) in enumerate(instruction_list[dd + 1]):
            cores[dd][l1, :, l2, :, l3, :] = cores_old[dd + 1][i]
    
    cores = [cores0] + cores + [coresd] 
    return cores

In [719]:
from e3nn.o3 import wigner_3j

def tri_ineq(l1, l2, l3):
    #print(max([l1, l2, l3]), min([l1 + l2, l2 + l3, l1 + l3]))
    return max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3])


def ETN(data, i, cores, ranks, Nc, instruction_list, lmax = 2):

    torch.manual_seed(40)
    
    F = stage_0(data, i)
    
    print(F[10, :, 0])
    # Init irreps
    irreps_edge_sh = o3.Irreps.spherical_harmonics(lmax, p = 1)
    slices = irreps_edge_sh.slices()
    #Init w3j
    w3j_big = [[[wigner_3j(l1, l2, l3) if tri_ineq(l1, l2, l3) else None for l3 in range (lmax+1)] for l2 in range(lmax+1)] for l1 in range(lmax+1)]
    
    cores = convert_to_dense(cores, instruction_list)
    
    u_out = torch.zeros((F.shape[0], F.shape[1], ranks[-2]))
    for i, slice in enumerate(slices):
        u_out[:, slice, :] = torch.einsum('ij,Nmj->Nmi', cores[-1][i, :, i, :], F[:, slice, :])    
    
    
    # Series third order tensors
    for i in range(d - 2, 0, -1):

        # TODO: now define localy, mb define for all, to ensure computational graph
        T_2_tmp = [[torch.zeros(F.shape[0], 2*l1+1, ranks[i], 2*l2+1, Nc[i], dtype=F.dtype, device=F.device) for l2 in range(lmax + 1)] for l1 in range(lmax + 1)] # result of first reduction of order 3 tensor

        #print(i)
        # First contraction with previous feature vector
        for l1 in range(lmax + 1):
            for l2 in range(lmax + 1):
                for l3, slice in enumerate(slices):
                    if tri_ineq(l1, l2, l3):
                        #T_3 = self.w3j_big[l1][l2][l3][..., None, None, None] * self.cores3[i][(l1, l2, l3)][None, None, None, ...]
                        T_2_tmp[l1][l2] += torch.einsum('abc,ijk,Nck->Naibj', w3j_big[l1][l2][l3], cores[i][l1, :, l2, :, l3, :], u_out[:, slice, :])


       # Second contraction with F vector
        u_out_new = torch.zeros((F.shape[0], F.shape[1], ranks[i]), dtype=F.dtype,
                        device=F.device) # temporary verctor output of etn

        for l1 in range(lmax + 1):    
            for l2, slice in enumerate(slices):
                u_out_new[:, slices[l1], :] += torch.einsum('Naibj,Nbj->Nai', T_2_tmp[l1][l2], F[:, slice, :])


        u_out = u_out_new

    
    u_final = torch.zeros((F.shape[0], F.shape[1], Nc[0]))
    for i, slices in enumerate(slices):
        u_final[:, slices, :] = torch.einsum('ij,Nmj->Nmi', cores[0][i, :, i, :], u_out[:, slices, :])
        
    
    
    out = (( u_final * F ).sum(dim = (-2, -1) )).unsqueeze(-1)
    
    return u_final, out

In [720]:
u_final, out = ETN(data, 2, cores, N_rank_ett, Nc, instruction_list)

1x0ee+1x1ee+1x2ee
tensor([ 1.0000, -1.2899,  1.0922,  0.3787, -0.6306, -1.8187,  0.2156,  0.5339,
        -0.9814])


In [721]:
out.sum()

tensor(-548798.9375, grad_fn=<SumBackward0>)

In [722]:
from allegro import lr_orthogonal

In [723]:
cores, _ = lr_orthogonal(cores, N_rank_ett, instruction_list)

In [724]:
all(list(map(lambda a, b: torch.allclose(a, b), cores_new_new, cores_new)))

True

In [725]:
u_final_ortho, out_ortho = ETN(data, 2, cores, N_rank_ett, Nc, instruction_list)

1x0ee+1x1ee+1x2ee
tensor([ 1.0000, -1.2899,  1.0922,  0.3787, -0.6306, -1.8187,  0.2156,  0.5339,
        -0.9814])


In [728]:
out_ortho.sum()

tensor(-548793., grad_fn=<SumBackward0>)

In [727]:
out[0], out_ortho[0]

(tensor([-1491.3076], grad_fn=<SelectBackward0>),
 tensor([-1491.2832], grad_fn=<SelectBackward0>))

In [729]:
u_final_ortho[0, 0]

tensor([ 5111.2969,  4467.6646, -3891.8503,  5520.1807,  6226.9316,  1672.3468,
        -4797.0845,  5222.7334, -4501.8105, -5394.5918],
       grad_fn=<SelectBackward0>)

In [730]:
u_final[0, 0]


tensor([ 5111.2969,  4467.6636, -3891.8525,  5520.1831,  6226.9321,  1672.3439,
        -4797.0864,  5222.7334, -4501.8081, -5394.5894],
       grad_fn=<SelectBackward0>)