In [8]:
from e3nn.o3 import spherical_harmonics
from e3nn.o3 import wigner_3j
from torch import Tensor
import torch
from torch import nn

In [9]:
r = torch.Tensor([1., 10., 20.])
r_ang = r/torch.linalg.vector_norm(r)


Y_1 = spherical_harmonics('1e', r_ang, normalize = 'component')
Y_2 = spherical_harmonics('2e', r_ang, normalize = 'component')
Y_3 = spherical_harmonics('2e', r_ang, normalize = 'component')

features = torch.concatenate([Y_1, Y_2, Y_3])

features

tensor([ 0.0218,  0.2183,  0.4366,  0.0436,  0.0218, -0.1265,  0.4361,  0.4351,
         0.0436,  0.0218, -0.1265,  0.4361,  0.4351])

In [10]:
l1, l2, l3 = 1, 2, 3
m1, m2, m3 = l1, l2, l3 # m1 = 0, m2 = 0, m3 = 0
C123 = wigner_3j(1, 2, 3)

In [11]:
C123_000 = C123[m1, m2, m3]

C123_000

tensor(0.2928)

In [12]:
C123.shape

torch.Size([3, 5, 7])

In [13]:
class weigner_3j_img_to_real():


    def __init__(self, l):
        self.l = l

    def __call__(self):
        matrix = torch.zeros((2*self.l + 1, 2*self.l + 1), dtype = torch.complex64)

        mult = 1
        for i in range(2*self.l + 1):
            
            if i < self.l:
                matrix[i, i] = 1.0j/2**(1/2.)
                matrix[2*self.l + 1 - i - 1, i] = 1/2**(1/2.)
            elif i == self.l:
                matrix[i, i] = 1.
            else:
                matrix[i, i] = (-1.0)*mult/2**(1/2.)
                matrix[2*self.l + 1 - i - 1, i] = (-1.0j)*mult/2**(1/2.)
                mult *= -1
        
        return matrix

In [14]:
weigner_3j_img_to_real(1)()*2**(1/2.)

tensor([[ 0.0000+1.0000j,  0.0000+0.0000j,  0.0000-1.0000j],
        [ 0.0000+0.0000j,  1.4142+0.0000j,  0.0000+0.0000j],
        [ 1.0000+0.0000j,  0.0000+0.0000j, -1.0000+0.0000j]])

In [62]:


def tri_ineq(l1, l2, l3):
    print(max([l1, l2, l3]), min([l1 + l2, l2 + l3, l1 + l3]))
    return max([l1, l2, l3]) <= min([l1 + l2, l2 + l3, l1 + l3])

class order_3_equvariant_tensor():
    
    def __init__(self):
        """ Basically transformation using
            C (l1 l2 l3)
              (m1 m2 m3)"""
        
        pass
    
    
    def __call__(self, l1, l2, l3, n1, n2, n3):
        self.l1 = l1; self.l2 = l2; self.l3 = l3
        self.n1 = n1; self.n2 = n2; self.n3 = n3
        
        weight = torch.zeros([n1, n2, n3])
        if tri_ineq(l1, l2, l3):
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            symbol_3j = wigner_3j(1, 2, 3)
            return symbol_3j.view(*symbol_3j.shape, 1, 1, 1)*W.view(1, 1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, 2*l3 + 1, n1, n2, n3))
        
        
class order_2_equivariant_tensor():
    
    def __init__(self):
        """Makes the 2nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, l2, n1, n2):
        self.l1 = l1; self.l2 = l2
        self.n1 = n1; self.n2 = n2
        
        weight = torch.zeros([n1, n2])
        if l1 == l2: # same as tri_ineq(l1, l2, 0)
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            return torch.ones((2*l1 + 1, 2*l2 + 1, 1, 1))*W.view(1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, n1, n2))

        
class order_2_equivariant_tensor():
    
    def __init__(self):
        """Makes the 2nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, l2, n1, n2):
        self.l1 = l1; self.l2 = l2
        self.n1 = n1; self.n2 = n2
        
        weight = torch.zeros([n1, n2])
        if l1 == l2: # same as tri_ineq(l1, l2, 0)
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            return torch.ones((2*l1 + 1, 2*l2 + 1, 1, 1))*W.view(1, 1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, 2*l2 + 1, n1, n2))

        
class order_1_equivariant_tensor():
    
    def __init__(self):
        """Makes the 2nd order tensor in a way that
           each lm is multiplied by coefficient c, no angular momentum mixing
        """
        pass
    
    def __call__(self, l1, n1):
        self.l1 = l1; self.l2 = l2
        self.n1 = n1; self.n2 = n2
        
        weight = torch.zeros([n1])
        if l1 == l2: # same as tri_ineq(l1, l2, 0)
            W = nn.Parameter(nn.init.kaiming_uniform_(weight))
            return torch.ones((2*l1 + 1, 1))*W.view(1, *weight.shape)
        else:
            return torch.zeros((2*l1 + 1, n1))


In [63]:
print(order_1_equivariant_tensor()(2, 2, 3).shape)
print(order_1_equivariant_tensor()(2, 2, 3, 4).shape)

print(order_2_equivariant_tensor()(2, 2, 3, 4).shape)
print(order_2_equivariant_tensor()(2, 2, 3, 4).shape)


print(order_3_equvariant_tensor()(1, 2, 3, 4, 6, 8).sum())
print(order_3_equvariant_tensor()(1, 2, 3, 4, 6, 8).shape)



torch.Size([5, 5, 3, 4])
torch.Size([5, 5, 3, 4])
3 3
tensor(-11.6505, grad_fn=<SumBackward0>)
3 3
torch.Size([3, 5, 7, 4, 6, 8])


In [50]:
tri_ineq(0, 0, 0)

0 0


True