In [1]:
from e3nn.o3 import spherical_harmonics
from e3nn.o3 import wigner_3j
from torch import Tensor
import torch
from torch import nn

In [2]:
r = torch.randn((10, 100, 3))
r_ang = r/torch.linalg.vector_norm(r, dim = -1).unsqueeze(-1)


Y_1 = spherical_harmonics('0e', r_ang, normalize = 'component')
Y_2 = spherical_harmonics('1o', r_ang, normalize = 'component')
Y_3 = spherical_harmonics('2e', r_ang, normalize = 'component')

features = torch.concatenate([Y_1, Y_2, Y_3], dim = -1)

features.shape

torch.Size([10, 100, 9])

In [3]:
Y_1.shape, Y_2.shape

(torch.Size([10, 100, 1]), torch.Size([10, 100, 3]))

In [4]:
r.shape

torch.Size([10, 100, 3])

In [5]:
l1, l2, l3 = 1, 2, 3
m1, m2, m3 = l1, l2, l3 # m1 = 0, m2 = 0, m3 = 0
C123 = wigner_3j(1, 2, 3)

In [6]:
C123_000 = C123[m1, m2, m3]

C123_000

tensor(0.2928)

In [7]:
C123.shape

torch.Size([3, 5, 7])

In [8]:
class weigner_3j_img_to_real():


    def __init__(self, l):
        self.l = l

    def __call__(self):
        matrix = torch.zeros((2*self.l + 1, 2*self.l + 1), dtype = torch.complex64)

        mult = 1
        for i in range(2*self.l + 1):
            
            if i < self.l:
                matrix[i, i] = 1.0j/2**(1/2.)
                matrix[2*self.l + 1 - i - 1, i] = 1/2**(1/2.)
            elif i == self.l:
                matrix[i, i] = 1.
            else:
                matrix[i, i] = (-1.0)*mult/2**(1/2.)
                matrix[2*self.l + 1 - i - 1, i] = (-1.0j)*mult/2**(1/2.)
                mult *= -1
        
        return matrix

In [9]:
weigner_3j_img_to_real(1)()*2**(1/2.)

tensor([[ 0.0000+1.0000j,  0.0000+0.0000j,  0.0000-1.0000j],
        [ 0.0000+0.0000j,  1.4142+0.0000j,  0.0000+0.0000j],
        [ 1.0000+0.0000j,  0.0000+0.0000j, -1.0000+0.0000j]])

In [10]:
def tri_ineq(l1, l2, l3):
    #print(max([l1, l2, l3]), min([l1 + l2, l2 + l3, l1 + l3]))
    return max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3])

class order_3_equvariant_tensor():
    
    def __init__(self):
        """ Basically transformation using
            C (l1 l2 l3)
              (m1 m2 m3)"""
        
        pass
    
    
    def __call__(self, l1, l2, l3, n1, n2, n3):
        self.l1 = l1; self.l2 = l2; self.l3 = l3
        self.n1 = n1; self.n2 = n2; self.n3 = n3
        
        weight = torch.zeros([n1, n2, n3])
        if tri_ineq(l1, l2, l3):
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            symbol_3j = wigner_3j(l1, l2, l3)
            return symbol_3j.view(*symbol_3j.shape, 1, 1, 1)*W.view(1, 1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, 2*l3 + 1, n1, n2, n3))
        
        
class order_2_equivariant_tensor():
    
    def __init__(self):
        """Makes the 2nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, l2, n1, n2):
        self.l1 = l1; self.l2 = l2
        self.n1 = n1; self.n2 = n2
        
        weight = torch.zeros([n1, n2])
        if l1 == l2: # same as tri_ineq(l1, l2, 0)
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            return W.view(*weight.shape)
        else:
            return torch.zeros((n1, n2))

        
class order_1_equivariant_tensor():
    
    def __init__(self):
        """Makes the 1nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, n1):
        self.l1 = l1;
        self.n1 = n1;
        
        weight = torch.zeros([n1])
        W = nn.Parameter(nn.init.uniform(weight))
        return torch.ones((2*l1 + 1, 1))*W.view(1, *weight.shape)


In [11]:
print(order_1_equivariant_tensor()(2, 3).shape)
print(order_1_equivariant_tensor()(2, 10).shape)

print(order_2_equivariant_tensor()(2, 2, 3, 4).shape)
print(order_2_equivariant_tensor()(2, 2, 3, 4).shape)


print(order_3_equvariant_tensor()(1, 2, 3, 4, 6, 8).sum())
print(order_3_equvariant_tensor()(1, 2, 3, 4, 6, 8).shape)



torch.Size([5, 3])
torch.Size([5, 10])
torch.Size([3, 4])
torch.Size([3, 4])
tensor(-0.3470, grad_fn=<SumBackward0>)
torch.Size([3, 5, 7, 4, 6, 8])


  W = nn.Parameter(nn.init.uniform(weight))


In [12]:
tri_ineq(0, 0, 0)

True

In [13]:
C123.shape

torch.Size([3, 5, 7])

### Loading qm9 dataset

In [14]:
# reference example
from nequip.data import dataset_from_config
from nequip.utils import Config
#from nequip.utils.misc import get_default_device_name
#from nequip.utils.config import _GLOBAL_ALL_ASKED_FOR_KEYS

from nequip.model import model_from_config


default_config = dict(
    root="./",
    tensorboard=False,
    wandb=False,
    model_builders=[
        "SimpleIrrepsConfig",
        "EnergyModel",
        "PerSpeciesRescale",
        "StressForceOutput",
        "RescaleEnergyEtc",
    ],
    dataset_statistics_stride=1,
    device='cuda',
    default_dtype="float64",
    model_dtype="float32",
    allow_tf32=True,
    verbose="INFO",
    model_debug_mode=False,
    equivariance_test=False,
    grad_anomaly_mode=False,
    gpu_oom_offload=False,
    append=False,
    warn_unused=False,
    _jit_bailout_depth=2,  # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
    # Quote from eelison in PyTorch slack:
    # https://pytorch.slack.com/archives/CDZD1FANA/p1644259272007529?thread_ts=1644064449.039479&cid=CDZD1FANA
    # > Right now the default behavior is to specialize twice on static shapes and then on dynamic shapes.
    # > To reduce warmup time you can do something like setFusionStrartegy({{FusionBehavior::DYNAMIC, 3}})
    # > ... Although we would wouldn't really expect to recompile a dynamic shape fusion in a model,
    # > provided broadcasting patterns remain fixed
    # We default to DYNAMIC alone because the number of edges is always dynamic,
    # even if the number of atoms is fixed:
    _jit_fusion_strategy=[("DYNAMIC", 3)],
    # Due to what appear to be ongoing bugs with nvFuser, we default to NNC (fuser1) for now:
    # TODO: still default to NNC on CPU regardless even if change this for GPU
    # TODO: default for ROCm?
    _jit_fuser="fuser1",
)

# All default_config keys are valid / requested
#_GLOBAL_ALL_ASKED_FOR_KEYS.update(default_config.keys())

In [15]:
config = Config.from_file('./configs/example_SpinGNNPlus.yaml', defaults=default_config)
    

dataset = dataset_from_config(config, prefix="dataset")

validation_dataset = None

dataset[0]

AtomicData(atom_types=[19, 1], cell=[3, 3], edge_cell_shift=[340, 3], edge_index=[2, 340], pbc=[3], pos=[19, 3])

In [16]:
config

{'_jit_bailout_depth': 2, '_jit_fusion_strategy': [('DYNAMIC', 3)], '_jit_fuser': 'fuser1', 'root': 'results/qm9', 'tensorboard': False, 'wandb': False, 'model_builders': ['allegro.model.SpinGNNPlus', 'PerSpeciesRescale', 'ParaStressForceSpinForceOutput', 'RescaleEnergyEtc'], 'dataset_statistics_stride': 1, 'device': 'cuda', 'default_dtype': 'float32', 'model_dtype': 'float32', 'allow_tf32': True, 'verbose': 'debug', 'model_debug_mode': False, 'equivariance_test': False, 'grad_anomaly_mode': False, 'gpu_oom_offload': False, 'append': True, 'warn_unused': False, 'run_name': 'example', 'seed': 123456, 'dataset_seed': 123456, 'r_max': 6.0, 'avg_num_neighbors': 'auto', 'BesselBasis_trainable': True, 'PolynomialCutoff_p': 6, 'l_max': 2, 'parity': 'o3_full', 'num_layers': 2, 'env_embed_multiplicity': 64, 'embed_initial_edge': True, 'two_body_latent_mlp_latent_dimensions': [128, 256, 512, 1024], 'two_body_latent_mlp_nonlinearity': 'silu', 'two_body_latent_mlp_initialization': 'uniform', 'late

In [17]:

import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
from allegro import with_edge_spin_length
from allegro import _keys
from torch import nn
import math

data = [AtomicData.to_AtomicDataDict(dataset[i]) for i in range(10)]

In [18]:
len(data)

10

In [19]:
i = 2

num_types = len(config['chemical_symbols'])


atom_types_embed = data[i]['atom_types'][data[i]['edge_index'][0]]*num_types + data[i]['atom_types'][data[i]['edge_index'][1]]

atom_types_embed.shape

torch.Size([494, 1])

In [20]:
from nequip.data import AtomicDataDict, AtomicDataset
from nequip.nn.embedding import (
    OneHotAtomEncoding,
    SphericalHarmonicEdgeAttrs,
    RadialBasisEdgeEncoding,
)

In [21]:
from e3nn import o3

In [22]:
o3.Irreps.spherical_harmonics(2)

1x0ee+1x1oe+1x2ee

In [23]:
from nequip.data import AtomicDataDict, AtomicDataset
from nequip.nn.embedding import (
    OneHotAtomEncoding,
    SphericalHarmonicEdgeAttrs,
    RadialBasisEdgeEncoding,
)
from e3nn import o3

L = 2

irreps_edge_sh = o3.Irreps.spherical_harmonics(L)

rbe = RadialBasisEdgeEncoding(basis_kwargs={'r_max': config['r_max'], 
                                            'num_basis': 8},
                              cutoff_kwargs={'r_max': config['r_max']},
                              out_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
                              )

sh = SphericalHarmonicEdgeAttrs(irreps_edge_sh=irreps_edge_sh)

data = [rbe(data[i]) for i in range(len(data))]

data = [sh(data[i]) for i in range(len(data))]


print(data[i]['edge_embedding'].shape, data[i]['edge_attrs'].shape)

torch.Size([494, 8]) torch.Size([494, 9])


## Feature Gen

In [24]:
from torch_runstats.scatter import scatter

N_rad = 8

N_spec_rank = 4
N_rad_rank = 4

Q = data[i]['edge_embedding']

A = torch.randn(L + 1, N_spec_rank, num_types**2)
B = torch.randn(L + 1, N_rad_rank, N_rad, N_spec_rank)


a = A[:, :, atom_types_embed].squeeze(-1)

b = torch.einsum('Lrnk,LkE,En->ELr', B, a, Q)

Y = data[i]['edge_attrs']

#print(data[i]['edge_attrs'][:, slices])
F = torch.concat([torch.einsum('Em,En->Emn', Y[:, slices],
                               b[:, l]) for l, slices in enumerate(irreps_edge_sh.slices())], dim = -2)

print(a.shape, b.shape, F.shape)


species = data[i][AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1)
edge_center = data[i][AtomicDataDict.EDGE_INDEX_KEY][0]
edge_neighbor = data[i][AtomicDataDict.EDGE_INDEX_KEY][1]

center_species = species[edge_center]
neighbor_species = species[edge_neighbor]

F = scatter(F, edge_center, dim=0, dim_size=len(species))

print(F.shape)

torch.Size([3, 4, 494]) torch.Size([494, 3, 4]) torch.Size([494, 9, 4])
torch.Size([23, 9, 4])


### Order 2 tensor

In [25]:
n1 = 4
n2 = 32
L1 = L2 = range(L + 1)

u_in = F.clone()
T_2 = [order_2_equivariant_tensor()(l, l, n1, n2) for l in L1]


u_out = torch.zeros((u_in.shape[0], u_in.shape[1], n2))
for i, slices in enumerate(irreps_edge_sh.slices()):
    u_out[:, slices, :] = torch.einsum('ij,Nmi->Nmj', T_2[i], u_in[:, slices, :])

### Order 3 tensor

In [26]:
n3 = 32
L3 = range(L + 1)

T_3 = [[[order_3_equvariant_tensor()(l3, l1, l2, n3, n1, n2) for l2 in L2] for l1 in L1] for l3 in L3]

In [27]:
u_in = u_out

#T_2_tmp = torch.zeros((u_in.shape[0], u_in.shape[1], n2))

T_2_tmp = [[None for l1 in L1] for l3 in L3]

u_out = torch.zeros((u_in.shape[0], u_in.shape[1], n3))


v = F.clone()
for l2, slices in enumerate(irreps_edge_sh.slices()):
    if l2 == 0:
        for l3 in L3:
            for l1 in L1:
                T_2_tmp[l3][l1] = torch.einsum('abcijk,Nck->Nabij', T_3[l3][l1][l2], u_in[:, slices, :])
    else:
        for l3 in L3:
            for l1 in L1:
                T_2_tmp[l3][l1] += torch.einsum('abcijk,Nck->Nabij', T_3[l3][l1][l2], u_in[:, slices, :])
    
    
for l3 in L3:    
    for l1, slices in enumerate(irreps_edge_sh.slices()):
        u_out[:, irreps_edge_sh.slices()[l3], :] += torch.einsum('Nabij,Nbj->Nai', T_2_tmp[l3][l1], v[:, slices, :])
        
        
u_out.min()

tensor(-63.2352, grad_fn=<MinBackward1>)

In [28]:
u_out.max()

tensor(51.1680, grad_fn=<MaxBackward1>)

In [29]:
u_out.shape

torch.Size([23, 9, 32])

In [30]:
max([ T_2[i][j].max() for i in [0, 1, 2] for j in [0, 1, 2] ])

tensor(0.4300, grad_fn=<MaxBackward1>)

In [31]:
T_2[1].shape

torch.Size([4, 32])

In [32]:
u_out.shape

torch.Size([23, 9, 32])

### Equivariance testing

In [33]:
x = torch.Tensor([[1, 12, 34]]*10)

edge_vec_plus = x / torch.linalg.norm(x, dim = 1).unsqueeze(-1)
edge_vec_minus = -x / torch.linalg.norm(x, dim = 1).unsqueeze(-1)

irreps_sh = o3.Irreps('1x0e + 1x1o + 1x2e') #o3.Irreps.spherical_harmonics(lmax=2)
irreps_sh_r = o3.Irreps('1x1o')
print(irreps_sh)


alpha, beta, gamma = o3.rand_angles(100)

rot_matrix = irreps_sh.D_from_angles(alpha[0], beta[0], gamma[0])
rot_matrix_r = irreps_sh_r.D_from_angles(alpha[0], beta[0], gamma[0])


sh_plus = o3.spherical_harmonics(irreps_sh, edge_vec_plus, normalize=True)

sh_plus_of_rot =  o3.spherical_harmonics(irreps_sh, edge_vec_plus @ rot_matrix_r, 
                                        normalize=True)

sh_plus_rot = o3.spherical_harmonics(irreps_sh, edge_vec_plus, 
                                        normalize=True) @ rot_matrix

sh_plus_of_rot_rot = o3.spherical_harmonics(irreps_sh, edge_vec_plus @ rot_matrix_r, 
                                        normalize=True) @ rot_matrix.T





sh_minus = o3.spherical_harmonics(irreps_sh, edge_vec_minus, normalize=True)
# normalize=True ensure that x is divided by |x| before computing the sh

sh_plus.pow(2).mean()  # should be close to 1

1x0ee+1x1oe+1x2ee


tensor(0.0796)

In [34]:
sh_plus_rot[0]

tensor([ 0.2821, -0.4627,  0.1386, -0.0735,  0.1557, -0.2935, -0.2393, -0.0466,
        -0.4776])

In [35]:
print(sh_plus[0])

print(sh_plus_of_rot_rot[0])

tensor([ 0.2821,  0.0135,  0.1626,  0.4606,  0.0286,  0.0101, -0.2107,  0.3426,
         0.4850])
tensor([ 0.2821,  0.0135,  0.1626,  0.4606,  0.0286,  0.0101, -0.2107,  0.3426,
         0.4850])


In [36]:
def stage_0(data, i):
    

    from nequip.data import AtomicDataDict, AtomicDataset
    from nequip.nn.embedding import (
        OneHotAtomEncoding,
        SphericalHarmonicEdgeAttrs,
        RadialBasisEdgeEncoding,
    )
    from e3nn import o3

    
    torch.manual_seed(32)
    
    L = 2

    irreps_edge_sh = o3.Irreps.spherical_harmonics(L)

    rbe = RadialBasisEdgeEncoding(basis_kwargs={'r_max': config['r_max'], 
                                                'num_basis': 8},
                                  cutoff_kwargs={'r_max': config['r_max']},
                                  out_field=AtomicDataDict.EDGE_EMBEDDING_KEY,
                                  )

    sh = SphericalHarmonicEdgeAttrs(irreps_edge_sh=irreps_edge_sh)

    data = [rbe(data[i]) for i in range(len(data))]

    data = [sh(data[i]) for i in range(len(data))]


    #print(data[i]['edge_embedding'].shape, data[i]['edge_attrs'].shape)
    
    
    from torch_runstats.scatter import scatter

    N_rad = 8

    N_spec_rank = 4
    N_rad_rank = 4

    Q = data[i]['edge_embedding']

    A = torch.randn(L + 1, N_spec_rank, num_types**2)
    B = torch.randn(L + 1, N_rad_rank, N_rad, N_spec_rank)


    a = A[:, :, atom_types_embed].squeeze(-1)

    b = torch.einsum('Lrnk,LkE,En->ELr', B, a, Q)

    #print(data[i]['pos'][2])
    Y = data[i]['edge_attrs']
    #print(Y[2])
    #print(data[i]['edge_attrs'][:, slices])
    F = torch.concat([torch.einsum('Em,En->Emn', Y[:, slices],
                                   b[:, l]) for l, slices in enumerate(irreps_edge_sh.slices())], dim = -2)
    
    print(F[2, :, 0])
    #F = Y.unsqueeze(-1)
    #print(a.shape, b.shape, F.shape)


    species = data[i][AtomicDataDict.ATOM_TYPE_KEY].squeeze(-1)
    edge_center = data[i][AtomicDataDict.EDGE_INDEX_KEY][0]
    edge_neighbor = data[i][AtomicDataDict.EDGE_INDEX_KEY][1]

    center_species = species[edge_center]
    neighbor_species = species[edge_neighbor]

    F = scatter(F, edge_center, dim=0, dim_size=len(species))

    return F

In [37]:

import torch
from torch.nn.functional import one_hot
from nequip.data import AtomicData, AtomicDataDict
from torch.nn.functional import one_hot
from e3nn.nn import FullyConnectedNet
from allegro import with_edge_spin_length
from allegro import _keys
from torch import nn
import math

data = [AtomicData.to_AtomicDataDict(dataset[i]) for i in range(10)]


import copy

data_new = [{key: torch.clone(el[key]) for key in el} for el in data]

irreps_sh = o3.Irreps('1x0e + 1x1o + 1x2e') #o3.Irreps.spherical_harmonics(lmax=2)
irreps_sh_r = o3.Irreps('1x1o')

alpha, beta, gamma = o3.rand_angles(100)

rot_matrix = irreps_sh.D_from_angles(alpha[0], beta[0], gamma[0])
rot_matrix_r = irreps_sh_r.D_from_angles(alpha[0], beta[0], gamma[0])


for i, el in enumerate(data_new):
    data_new[i]['pos'] = data_new[i]['pos'] @ rot_matrix_r

In [38]:
F = stage_0(data, 2)

F_new = stage_0(data_new, 2)


F_new_rot = torch.einsum('Njn,jk->Nkn', F_new, rot_matrix.T)

tensor([-0.1360,  0.6565, -0.1762,  0.0175, -0.0983,  0.9870,  0.9094,  0.0264,
         1.8374], grad_fn=<SelectBackward0>)
tensor([-0.1360, -0.0921, -0.0290,  0.6730,  0.5290, -0.0228,  1.1326,  0.1664,
        -1.8966], grad_fn=<SelectBackward0>)


In [39]:
torch.allclose(F, F_new_rot, atol=1e-03)

True

In [40]:
def stage_1(F):
    
    torch.manual_seed(32)
    
    n1 = 4
    n2 = 32
    L1 = L2 = range(L + 1)

    u_in = F.clone()
    T_2 = [order_2_equivariant_tensor()(l, l, n1, n2) for l in L1]


    u_out = torch.zeros((u_in.shape[0], u_in.shape[1], n2))
    for i, slices in enumerate(irreps_edge_sh.slices()):
        u_out[:, slices, :] = torch.einsum('ij,Nmi->Nmj', T_2[i], u_in[:, slices, :])
        
    return u_out

In [41]:
u_out = stage_1(F)

u_out_new = stage_1(F_new)

u_out_new_rot = torch.einsum('Njn,jk->Nkn', u_out_new, rot_matrix.T)

In [42]:
torch.allclose(u_out, u_out_new_rot, atol=1e-03)

True

In [43]:
def stage_2(u_in, F):
    
    torch.manual_seed(32)
    
    n1 = 4
    n2 = 32
    L1 = L2 = range(L + 1)
    
    n3 = 32
    L3 = range(L + 1)

    T_3 = [[[order_3_equvariant_tensor()(l3, l1, l2, n3, n1, n2) for l2 in L2] for l1 in L1] for l3 in L3]

    #T_2_tmp = torch.zeros((u_in.shape[0], u_in.shape[1], n2))

    T_2_tmp = [[None for l1 in L1] for l3 in L3]

    u_out = torch.zeros((u_in.shape[0], u_in.shape[1], n3))


    v = F.clone()
    for l2, slices in enumerate(irreps_edge_sh.slices()):
        if l2 == 0:
            for l3 in L3:
                for l1 in L1:
                    T_2_tmp[l3][l1] = torch.einsum('abcijk,Nck->Nabij', T_3[l3][l1][l2], u_in[:, slices, :])
        else:
            for l3 in L3:
                for l1 in L1:
                    T_2_tmp[l3][l1] += torch.einsum('abcijk,Nck->Nabij', T_3[l3][l1][l2], u_in[:, slices, :])


    for l3 in L3:    
        for l1, slices in enumerate(irreps_edge_sh.slices()):
            u_out[:, irreps_edge_sh.slices()[l3], :] += torch.einsum('Nabij,Nbj->Nai', T_2_tmp[l3][l1], v[:, slices, :])

            
    return u_out

In [44]:
u_out_2 = stage_2(u_out, F)

u_out_2_new = stage_2(u_out_new, F_new)

u_out_2_new_rot = torch.einsum('Njn,jk->Nkn', u_out_2_new, rot_matrix.T)

In [45]:
torch.allclose(u_out_2, u_out_2_new_rot, atol=1e-03)

True

### Reshaping

In [63]:
def tri_ineq(l1, l2, l3):
    #print(max([l1, l2, l3]), min([l1 + l2, l2 + l3, l1 + l3]))
    return max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3])

class order_3_equvariant_tensor():
    
    def __init__(self, l1, l2, l3, n1, n2, n3):
        """ Basically transformation using
            C (l1 l2 l3)
              (m1 m2 m3)"""
        self.l1 = l1; self.l2 = l2; self.l3 = l3
        self.n1 = n1; self.n2 = n2; self.n3 = n3
    
        weight = torch.zeros([n1, n2, n3])
        self.weight = nn.Parameter(nn.init.kaiming_uniform_(weight))
        
        
        self.symbol_3j = wigner_3j(l1, l2, l3)
        
        
    def __call__(self):
        
        if tri_ineq(l1, l2, l3):
            return self.symbol_3j.view(*symbol_3j.shape, 1, 1, 1)*self.weight.view(1, 1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, 2*l3 + 1, n1, n2, n3))
        
        
def contract_2_tensors(tensor_1, tensor_2):
    weight_new = (tensor_1.weight.flatten(end_dim = -2) @ tensor_2.weight.flatten(start_dim = 1)).view(tensor_1.n1, tensor_1.n2, tensor_2.n2,
                                                                                                       tensor_2.n3)
    
    
    return weight_new, tensor_1.symbol_3j, tensor_2.symbol_3j
    
        
class order_2_equivariant_tensor():
    
    def __init__(self):
        """Makes the 2nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, l2, n1, n2):
        self.l1 = l1; self.l2 = l2
        self.n1 = n1; self.n2 = n2
        
        weight = torch.zeros([n1, n2])
        if l1 == l2: # same as tri_ineq(l1, l2, 0)
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            return W.view(*weight.shape)
        else:
            return torch.zeros((n1, n2))

        
class order_1_equivariant_tensor():
    
    def __init__(self):
        """Makes the 1nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, n1):
        self.l1 = l1;
        self.n1 = n1;
        
        weight = torch.zeros([n1])
        W = nn.Parameter(nn.init.uniform(weight))
        return torch.ones((2*l1 + 1, 1))*W.view(1, *weight.shape)


In [65]:
tensor_1 = order_3_equvariant_tensor(0, 0, 0, 4, 8, 32)
tensor_2 = order_3_equvariant_tensor(0, 0, 0, 32, 8, 8)

tensor_4th_order_weight = contract_2_tensors(tensor_1, tensor_2)[0]

tensor_4th_order_weight.shape

torch.Size([4, 8, 8, 8])

In [67]:
sum([max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3]) for l1 in range(10) for l2 in range(10) for l3 in range(10)])




505

In [1]:
class InterpolativeLiftingKernel(LiftingKernelBase):

    def __init__(self, group, kernel_size, in_channels, out_channels):
        super().__init__(group, kernel_size, in_channels, out_channels)

        # Create and initialise a set of weights, we will interpolate these
        # to create our transformed spatial kernels.
        self.weight = torch.nn.Parameter(torch.zeros((
            self.out_channels,
            self.in_channels,
            self.kernel_size,
            self.kernel_size
        ), device=self.group.identity.device))

        # Initialize weights using kaiming uniform intialisation.
        torch.nn.init.kaiming_uniform_(self.weight.data, a=math.sqrt(5))

    def sample(self):
        """ Sample convolution kernels for a given number of group elements

        should return:
        :return kernels: filter bank extending over all input channels,
            containing kernels transformed for all output group elements.
        """
        # First, we fold the output channel dim into the input channel dim;
        # this allows us to transform the entire filter bank in one go using the
        # torch grid_sample function.

        ## YOUR CODE STARTS HERE ##
        weight = self.weight.view(
            self.out_channels * self.in_channels,
            self.kernel_size,
            self.kernel_size
        )
        ## AND ENDS HERE ##

        # Sample the transformed kernels.
        transformed_weight = []
        for spatial_grid_idx in range(self.group.elements().numel()):
            transformed_weight.append(
                bilinear_interpolation(weight, self.transformed_grid_R2[:, spatial_grid_idx, :, :])
            )
        transformed_weight = torch.stack(transformed_weight)

        # Separate input and output channels.
        transformed_weight = transformed_weight.view(
            self.group.elements().numel(),
            self.out_channels,
            self.in_channels,
            self.kernel_size,
            self.kernel_size
        )

        # Put out channel dimension before group dimension. We do this
        # to be able to use pytorched Conv2D. Details below!
        transformed_weight = transformed_weight.transpose(0, 1)

        return transformed_weight


class LiftingConvolution(torch.nn.Module):

    def __init__(self, group, in_channels, out_channels, kernel_size, padding):
        super().__init__()

        self.kernel = InterpolativeLiftingKernel(
            group=group,
            kernel_size=kernel_size,
            in_channels=in_channels,
            out_channels=out_channels
        )

        self.padding = padding

    def forward(self, x):
        """ Perform lifting convolution

        @param x: Input sample [batch_dim, in_channels, spatial_dim_1,
            spatial_dim_2]
        @return: Function on a homogeneous space of the group
            [batch_dim, out_channels, num_group_elements, spatial_dim_1,
            spatial_dim_2]
        """

        # Obtain convolution kernels transformed under the group.

        ## YOUR CODE STARTS HERE ##
        conv_kernels = self.kernel.sample()
        ## AND ENDS HERE ##

        # Apply lifting convolution. Note that using a reshape we can fold the
        # group dimension of the kernel into the output channel dimension. We
        # treat every transformed kernel as an additional output channel. This
        # way we can use pytorch's conv2d function!

        # Question: Do you see why we (can) do this?

        ## YOUR CODE STARTS HERE ##
        x = torch.nn.functional.conv2d(
            input=x,
            weight=conv_kernels.reshape(
                self.kernel.out_channels * self.kernel.group.elements().numel(),
                self.kernel.in_channels,
                self.kernel.kernel_size,
                self.kernel.kernel_size
            ),
            padding=self.padding
        )
        ## AND ENDS HERE ##

        # Reshape [batch_dim, in_channels * num_group_elements, spatial_dim_1,
        # spatial_dim_2] into [batch_dim, in_channels, num_group_elements,
        # spatial_dim_1, spatial_dim_2], separating channel and group
        # dimensions.
        x = x.view(
            -1,
            self.kernel.out_channels,
            self.kernel.group.elements().numel(),
            x.shape[-1],
            x.shape[-2]
        )

        return x


NameError: name 'torch' is not defined

In [None]:
class myLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        
    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)
        
    def forward(self, input):
        x, y = input.shape
        if y != self.in_features:
            print(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
            return 0
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
        return ret
    
    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )