In [44]:
%%html
<style type='text/css'>
/* Shrink the CodeMirror (cell editor) font so more of each cell fits on screen. */
.CodeMirror {
    font-size: 11px;
}
</style>

In [43]:
import torch
import numpy as np 

# Toy batch laid out as N x D x m x m = 3 samples, 2 channels, 5 x 5 pairwise data.
# torch.FloatTensor(*sizes) hands back *uninitialized* memory (possibly NaN/inf garbage),
# which makes every downstream result irreproducible; draw seeded random values instead.
torch.manual_seed(0)
inputs = torch.randn(3, 2, 5, 5)
print(len(inputs))         # size of the leading (batch) axis
print(inputs.shape)
print(inputs[0].shape)     # one sample: D x m x m
print(inputs[0][1].shape)  # one channel of one sample: m x m

3
torch.Size([3, 2, 5, 5])
torch.Size([2, 5, 5])
torch.Size([5, 5])


In [64]:
def ops_2_to_2(inputs, normalize=False):
    """
    Construct the 15 broadcast tensors for a 2 -> 2 equivariant layer.

    Every operation maps the input to a tensor of the same shape; the 15
    results are stacked along a new axis so a layer can take a learned
    linear combination of them.

    Args:
        inputs: tensor of shape (N, D, m, m).
        normalize: if True, operations that aggregate m entries are divided
            by m and those that aggregate m*m entries by m*m (means instead of
            sums), so activation scale does not grow with m. Default False
            keeps plain sums.

    Returns:
        Tensor of shape (N, D, 15, m, m).

    Raises:
        ValueError: if the last two axes of `inputs` differ in size.
    """
    N, D, m, m_cols = inputs.shape
    if m != m_cols:
        raise ValueError(f"expected square trailing axes, got {m} x {m_cols}")

    # Summary tensors
    diag_part = torch.diagonal(inputs, dim1=-2, dim2=-1)  # N x D x m
    sum_diag_part = diag_part.sum(dim=2, keepdim=True)    # N x D x 1 (trace)
    sum_rows = inputs.sum(dim=3)                          # N x D x m
    sum_cols = inputs.sum(dim=2)                          # N x D x m
    sum_all = inputs.sum(dim=(2, 3))                      # N x D

    if normalize:
        # Every use of these summaries below is an aggregate of m (or m*m)
        # entries, so scaling the summaries once normalizes all dependent ops.
        sum_diag_part = sum_diag_part / m
        sum_rows = sum_rows / m
        sum_cols = sum_cols / m
        sum_all = sum_all / (m * m)

    # Broadcast the summaries back to N x D x m x m. `expand(-1, ...)` keeps an
    # axis as-is; `tile` would interpret -1 as a (negative) repetition count.
    ops = [
        torch.diag_embed(diag_part),                                 # 1: diagonal kept on diagonal
        torch.diag_embed(sum_diag_part.expand(-1, -1, m)),           # 2: trace on diagonal
        torch.diag_embed(sum_rows),                                  # 3: row sum i at (i, i)
        torch.diag_embed(sum_cols),                                  # 4: col sum i at (i, i)
        torch.diag_embed(sum_all.unsqueeze(-1).expand(-1, -1, m)),   # 5: total sum on diagonal
        sum_cols.unsqueeze(3).expand(-1, -1, -1, m),                 # 6: col sum i along row i
        sum_rows.unsqueeze(3).expand(-1, -1, -1, m),                 # 7: row sum i along row i
        sum_cols.unsqueeze(2).expand(-1, -1, m, -1),                 # 8: col sum j along col j
        sum_rows.unsqueeze(2).expand(-1, -1, m, -1),                 # 9: row sum j along col j
        inputs,                                                      # 10: identity
        torch.transpose(inputs, 2, 3),                               # 11: transpose
        diag_part.unsqueeze(3).expand(-1, -1, -1, m),                # 12: x_ii along row i
        diag_part.unsqueeze(2).expand(-1, -1, m, -1),                # 13: x_jj along col j
        sum_diag_part.unsqueeze(3).expand(-1, -1, m, m),             # 14: trace everywhere
        sum_all.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, m, m),    # 15: total sum everywhere
    ]
    return torch.stack(ops, dim=2)

In [68]:
def ops_3_to_3(inputs):
    """
    Construct a minimal subset (20) of the 3 -> 3 broadcast tensors
    """
    N, D, m, m, m = inputs.shape
    # Summation tensors
    sum_all = inputs.sum(dim=(-1, -2, -3))
    sum_c1 = inputs.sum(dim=-1)
    sum_c2 = inputs.sum(dim=-2)
    sum_c3 = inputs.sum(dim=-3)
    sum_c12 = inputs.sum(dim=(-1, -2))
    sum_c13 = inputs.sum(dim=(-1, -3))
    sum_c23 = inputs.sum(dim=(-2, -3))
    # Broadcast the summation tensors
    ops = [None] * 20
    ops[1] = sum_all.view(N, D, 1, 1, 1).expand(-1, -1, dim, dim, dim) / (m * m * m)
    ops[2]  = sum_c1.unsqueeze(-1).expand(-1, -1, -1, -1, m) / m
    ops[3]  = sum_c1.unsqueeze(-2).expand(-1, -1, -1, m, -1) / m
    ops[4]  = sum_c1.unsqueeze(-3).expand(-1, -1, m, -1, -1) / m
    ops[5]  = sum_c2.unsqueeze(-1).expand(-1, -1, -1, -1, m) / m
    ops[6]  = sum_c2.unsqueeze(-2).expand(-1, -1, -1, m, -1) / m
    ops[7]  = sum_c2.unsqueeze(-3).expand(-1, -1, m, -1, -1) / m
    ops[8]  = sum_c3.unsqueeze(-1).expand(-1, -1, -1, -1, m) / m
    ops[9]  = sum_c3.unsqueeze(-2).expand(-1, -1, -1, m, -1) / m
    ops[10] = sum_c3.unsqueeze(-3).expand(-1, -1, m, -1, -1) / m
    ops[11] = sum_c12.view(N, D, m, 1, 1).expand(-1, -1, -1, m, m) / (m*m)
    ops[12] = sum_c12.view(N, D, 1, m, 1).expand(-1, -1, m, -1, m) / (m*m)
    ops[13] = sum_c12.view(N, D, 1, 1, m).expand(-1, -1, m, m, -1) / (m*m)
    ops[14] = sum_c13.view(N, D, m, 1, 1).expand(-1, -1, -1, m, m) / (m*m)
    ops[15] = sum_c13.view(N, D, 1, m, 1).expand(-1, -1, m, -1, m) / (m*m)
    ops[16] = sum_c13.view(N, D, 1, 1, m).expand(-1, -1, m, m, -1) / (m*m)
    ops[17] = sum_c23.view(N, D, m, 1, 1).expand(-1, -1, -1, m, m) / (m*m)
    ops[18] = sum_c23.view(N, D, 1, m, 1).expand(-1, -1, m, -1, m) / (m*m)
    ops[19] = sum_c23.view(N, D, 1, 1, m).expand(-1, -1, m, m, -1) / (m*m)
    return torch.stack(ops[1:], dim=2)

In [71]:
import torch.nn as nn

In [72]:
class Eq2to2(nn.Module):
    """
    Linear 2 -> 2 equivariant layer built from the 15 `ops_2_to_2` operations.

    Output channel s is sum_{d, b} coefs[d, s, b] * op_b(inputs)[:, d] + bias[s].

    Args:
        in_dim: number of input channels D.
        out_dim: number of output channels.
        normalize: forwarded to `ops_2_to_2` (mean instead of sum aggregation).
            Default False preserves the original behavior.
    """
    def __init__(self, in_dim, out_dim, normalize=False):
        super().__init__()
        self.basis = 15  # Bell(2+2) = 15
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.normalize = normalize
        # Random (Glorot-style) init. With all-zero coefs the output ignores the
        # input, and since the output is linear in coefs, the gradient reaching
        # earlier layers in a stack is also zero.
        std = (2.0 / (in_dim + out_dim)) ** 0.5
        self.coefs = nn.Parameter(torch.randn(in_dim, out_dim, self.basis) * std)
        self.bias = nn.Parameter(torch.zeros(1, out_dim, 1, 1))

    def forward(self, inputs):
        """
        Args:
            inputs: tensor of shape (N, in_dim, m, m).

        Returns:
            Tensor of shape (N, out_dim, m, m).
        """
        ops = ops_2_to_2(inputs, normalize=self.normalize)  # N x D x 15 x m x m
        output = torch.einsum('dsb,ndbij->nsij', self.coefs, ops)
        return output + self.bias

In [39]:
inputs.size()  # same torch.Size as .shape: (N, D, m, m)

torch.Size([3, 2, 5, 5])

In [40]:
diag_part = inputs.diagonal(dim1=-2, dim2=-1)  # main diagonal of each trailing m x m matrix

In [42]:
diag_part.size()  # same torch.Size as .shape

torch.Size([3, 2, 5])