In [None]:
# inspired by 
# https://github.com/e3nn/e3nn/blob/main/examples/tetris_polynomial.py
# https://github.com/e3nn/e3nn/blob/main/examples/tetris_gate.py
# Tensor Field Networks https://arxiv.org/abs/1802.08219

In [None]:
import torch
from torch_scatter import scatter
import numpy as np
import math

import matplotlib.pyplot as plt

In [None]:
# Transform from the "standard physics" (spherical, m = +1, 0, -1) vector basis to the Cartesian
# basis in which the spin-1 rotation matrices are real-valued.
# The columns are the spherical basis vectors e_{+1}, e_0, e_{-1} written in (x, y, z) components,
# so `rot3_basis @ [c_{+1}, c_0, c_{-1}]` gives Cartesian components.
# This could be folded into the coefficients below, but writing it out explicitly keeps things readable:
# the C-G coefficients below can then be read straight from Wikipedia
# https://en.wikipedia.org/w/index.php?title=Table_of_Clebsch%E2%80%93Gordan_coefficients&oldid=1083617208
# The matrix is unitary, so its conjugate transpose maps Cartesian components back to spherical ones.
rot3_basis = torch.tensor([  
    [-np.sqrt(0.5), 0., np.sqrt(0.5)],
    [-1.j*np.sqrt(0.5), 0., -1.j*np.sqrt(0.5)],
    [0., 1., 0.],
], dtype=torch.cfloat)

def scalar_scalar_mult(scalar1, scalar2):
    """ 
        Elementwise product of two scalar feature tensors (the result is again a scalar).
        input shapes: [b, n], [b, n]
        output shapes: [b, n]
    """
    product = scalar1 * scalar2
    return product

def scalar_spinor_mult(scalar, spinor):
    """ 
        Multiply a scalar and a spinor and output a spinor.
        Each scalar channel scales both spinor components of the same channel.
        input shapes: [b, n], [b, 2, n]
        output shapes: [b, 2, n]
    """
    # insert the spinor-component axis so the scalar broadcasts over both components
    return scalar.unsqueeze(1) * spinor

def scalar_vector_mult(scalar, vector):
    """ 
        Scale every vector channel by the matching scalar channel; the result is a vector.
        input shapes: [b, n], [b, 3, n]
        output shapes: [b, 3, n]
    """
    broadcast_scalar = scalar.unsqueeze(1)  # [b, 1, n]
    return broadcast_scalar * vector

def spinor_spinor_mult(spinor1, spinor2):
    """ 
        Couple two spinors into their spin-0 (scalar) and spin-1 (vector) parts.
        Both outputs are complex valued.
        input shapes: [b, 2, n], [b, 2, n]
        output shapes: [b, n], [b, 3, n]
    """
    up1, down1 = spinor1[:, 0], spinor1[:, 1]
    up2, down2 = spinor2[:, 0], spinor2[:, 1]
    half_sqrt = np.sqrt(0.5)

    # spin-1 (triplet) part in spherical components m = +1, 0, -1
    triplet = torch.stack(
        [
            up1 * up2,
            half_sqrt * (up1 * down2 + down1 * up2),
            down1 * down2,
        ],
        dim=1,
    )
    # change to the basis in which rotation matrices are real
    vector = rot3_basis @ triplet

    # spin-0 (singlet) part
    singlet = half_sqrt * (up1 * down2 - down1 * up2)

    return singlet, vector

def spinor_spinor_mult_vector(spinor1, spinor2):
    """ 
        Couple two spinors into their spin-1 (vector) part only. The vector is complex valued.
        input shapes: [b, 2, n], [b, 2, n]
        output shapes: [b, 3, n]
    """
    up1, down1 = spinor1[:, 0], spinor1[:, 1]
    up2, down2 = spinor2[:, 0], spinor2[:, 1]

    # spherical components m = +1, 0, -1
    components = [
        up1 * up2,
        np.sqrt(0.5) * (up1 * down2 + down1 * up2),
        down1 * down2,
    ]
    # change to the basis in which rotation matrices are real
    return rot3_basis @ torch.stack(components, dim=1)

def spinor_spinor_mult_scalar(spinor1, spinor2):
    """ 
        Couple two spinors into their spin-0 (singlet, antisymmetric) part.
        The scalar is complex valued.
        input shapes: [b, 2, n], [b, 2, n]
        output shapes: [b, n]
    """
    up1, down1 = spinor1[:, 0], spinor1[:, 1]
    up2, down2 = spinor2[:, 0], spinor2[:, 1]
    antisymmetric = up1 * down2 - down1 * up2
    return np.sqrt(0.5) * antisymmetric

def spinor_vector_mult(spinor, vector):
    """
        Couple a spinor with a vector into a spinor (the spin-1/2 part of 1 x 1/2).
        (The spin-3/2, 4-dimensional part is also possible but is ignored here.)
        input shapes: [b, 2, n], [b, 3, n]
        output shapes: [b, 2, n]
    """
    # back to the spherical basis, where the C-G coefficients are the textbook ones
    spherical = rot3_basis.conj().T @ vector.cfloat()

    up, down = spinor[:, 0], spinor[:, 1]
    v_plus, v_zero, v_minus = spherical[:, 0], spherical[:, 1], spherical[:, 2]

    big = np.sqrt(2/3)
    small = np.sqrt(1/3)

    up_new = big * v_plus * down - small * v_zero * up
    down_new = -big * v_minus * up + small * v_zero * down

    return torch.stack([up_new, down_new], dim=1)

def vector_vector_mult(vector1, vector2):
    """
        Couple two vectors into a scalar (dot product) and a vector (cross product).
        (The spin-2, 5-dimensional part is also possible but is ignored here.)
        input shapes: [b, 3, n], [b, 3, n]
        output shapes: [b, n], [b, 3, n]
    """
    dot = torch.linalg.vecdot(vector1, vector2, dim=1)
    crossed = torch.cross(vector1, vector2, dim=1)
    return dot, crossed

def spinor_to_vec(spinor):
    """
        Square a spinor into a (complex) vector by coupling it with itself.
        input shapes: [b, 2, n]
        output shapes: [b, 3, n]
    """
    _, vector = spinor_spinor_mult(spinor, spinor)
    return vector

def spinor_to_vec2(spinor):
    """
        Map a spinor (a, b) to the real 3-vector
            x = 2 Re(conj(a) b),  y = 2 Im(conj(a) b),  z = |a|^2 - |b|^2.
        input shapes: [b, 2, ...]
        output shapes: [b, 3, ...]  (real valued)
    """
    a = spinor[:, 0]
    b = spinor[:, 1]
    x = torch.real(a * torch.conj(b) + b * torch.conj(a))
    y = torch.imag(-a * torch.conj(b) + b * torch.conj(a))
    z = torch.real(a * torch.conj(a) - b * torch.conj(b))
    # stack along the component axis (was `dim1`, an undefined name)
    return torch.stack([x, y, z], dim=1)

In [None]:
class ParametrizedTensorProduct(torch.nn.Module):
    """
    Learnable tensor product between two sets of features.

    Each set consists of scalars, spinors and vectors (channel counts given per type).
    For every coupling of a first-operand type with a second-operand type, each operand
    is first mixed over channels with a linear layer. The results are then combined with
    the matching product (scalar/spinor/vector multiplication, spinor coupling, dot or
    cross product). All contributions to the same output type are summed.

    Spinors are complex valued. Spinor-spinor couplings give complex scalars/vectors,
    whose real and imaginary parts are returned as separate channels. That is why the
    number of scalar/vector output channels must be even when both operands have spinors.
    """

    def __init__(self, 
                 in_scalars, in_spinors, in_vectors, 
                 in2_scalars, in2_spinors, in2_vectors,
                 out_scalars, out_spinors, out_vectors,
                ):
        super().__init__()
        self.in_scalars = in_scalars
        self.in_spinors = in_spinors
        self.in_vectors = in_vectors
        self.in2_scalars = in2_scalars
        self.in2_spinors = in2_spinors
        self.in2_vectors = in2_vectors
        self.out_scalars = out_scalars
        self.out_spinors = out_spinors
        self.out_vectors = out_vectors
        
        # spinor x spinor -> real scalars/vectors uses half the channels for real, half for imag parts
        if in_spinors > 0 and in2_spinors > 0 and (
            out_scalars % 2 != 0 or out_vectors % 2 != 0
        ):
            raise ValueError(
                'Implementation requires spinor multiplication to real valued output to have even nbr output channels'
            )
        
        # scalar x scalar -> scalar
        if in_scalars > 0 and in2_scalars > 0 and out_scalars > 0:
            self.scalar_to_scalar = torch.nn.Linear(in_scalars, out_scalars)
            self.scalar2_to_scalar = torch.nn.Linear(in2_scalars, out_scalars)
            
        # scalar x spinor -> spinor (both operand orders)
        if in_scalars > 0 and in2_spinors > 0 and out_spinors > 0:
            self.scalar_to_spinor = torch.nn.Linear(in_scalars, out_spinors)
            self.spinor2_to_spinor = torch.nn.Linear(in2_spinors, out_spinors, bias=False).to(torch.cfloat)
        if in_spinors > 0 and in2_scalars > 0 and out_spinors > 0:
            self.spinor_to_spinor = torch.nn.Linear(in_spinors, out_spinors, bias=False).to(torch.cfloat)
            self.scalar2_to_spinor = torch.nn.Linear(in2_scalars, out_spinors)
            
        # scalar x vector -> vector (both operand orders)
        if in_scalars > 0 and in2_vectors > 0 and out_vectors > 0:
            self.scalar_to_vector = torch.nn.Linear(in_scalars, out_vectors)
            self.vector2_to_vector = torch.nn.Linear(in2_vectors, out_vectors, bias=False)
        if in_vectors > 0 and in2_scalars > 0 and out_vectors > 0:
            self.scalar2_to_vector = torch.nn.Linear(in2_scalars, out_vectors)
            self.vector_to_vector = torch.nn.Linear(in_vectors, out_vectors, bias=False)
            
        # spinor x spinor -> scalar (complex, split into real/imag channels)
        if in_spinors > 0 and in2_spinors > 0 and out_scalars > 0:
            self.spinor_to_scalar = torch.nn.Linear(in_spinors, out_scalars // 2, bias=False).to(torch.cfloat)
            self.spinor2_to_scalar = torch.nn.Linear(in2_spinors, out_scalars // 2, bias=False).to(torch.cfloat)    
            
        # spinor x spinor -> vector (complex, split into real/imag channels)
        if in_spinors > 0 and in2_spinors > 0 and out_vectors > 0:
            self.spinor_to_vector = torch.nn.Linear(in_spinors, out_vectors // 2, bias=False).to(torch.cfloat)
            self.spinor2_to_vector = torch.nn.Linear(in2_spinors, out_vectors // 2, bias=False).to(torch.cfloat)
            
        # spinor x vector -> spinor (both operand orders)
        if in_vectors > 0 and in2_spinors > 0 and out_spinors > 0:
            self.vector_to_spinor = torch.nn.Linear(in_vectors, out_spinors, bias=False)
            self.spinor2_to_spinor_2 = torch.nn.Linear(in2_spinors, out_spinors, bias=False).to(torch.cfloat)
        if in2_vectors > 0 and in_spinors > 0 and out_spinors > 0:
            self.vector2_to_spinor = torch.nn.Linear(in2_vectors, out_spinors, bias=False)
            self.spinor_to_spinor_2 = torch.nn.Linear(in_spinors, out_spinors, bias=False).to(torch.cfloat)
            
        # vector x vector -> scalar (dot product)
        if in_vectors > 0 and in2_vectors > 0 and out_scalars > 0:
            self.vector_to_scalar = torch.nn.Linear(in_vectors, out_scalars, bias=False)
            self.vector2_to_scalar = torch.nn.Linear(in2_vectors, out_scalars, bias=False)
            
        # vector x vector -> vector (cross product)
        if in_vectors > 0 and in2_vectors > 0 and out_vectors > 0:
            self.vector_to_vector_2 = torch.nn.Linear(in_vectors, out_vectors, bias=False)
            self.vector2_to_vector_2 = torch.nn.Linear(in2_vectors, out_vectors, bias=False)
            
            
    def forward(self, scalars, spinors, vectors, scalars2, spinors2, vectors2):
        """
        Apply all configured products and sum them per output type.

        scalar shapes: [batch, channels]
        spinor shapes: [batch, 2, channels]  (complex)
        vector shapes: [batch, 3, channels]
        Inputs with a channel count of 0 may be passed as None.

        Returns (scalars_out, spinors_out, vectors_out). An entry is None when the
        corresponding output channel count is 0. Outputs are allocated on the device of
        the first non-None input.

        Raises ValueError if every input is None (the batch size cannot be inferred).
        """
        # Infer batch size and device from the first provided input (first operand preferred).
        reference = next(
            (f for f in (scalars, spinors, vectors, scalars2, spinors2, vectors2) if f is not None),
            None,
        )
        if reference is None:
            raise ValueError('At least one input feature tensor must be given')
        b = reference.shape[0]
        device = reference.device
        
        if self.out_scalars > 0:
            scalars_out = torch.zeros([b, self.out_scalars], device=device)
        else:
            scalars_out = None
            
        if self.out_spinors > 0:
            spinors_out = torch.zeros([b, 2, self.out_spinors], dtype=torch.cfloat, device=device)
        else:
            spinors_out = None
            
        if self.out_vectors > 0:
            vectors_out = torch.zeros([b, 3, self.out_vectors], device=device)
        else:
            vectors_out = None
        
        if self.in_scalars > 0 and self.in2_scalars > 0 and self.out_scalars > 0:
            scalars_out += scalar_scalar_mult(
                self.scalar_to_scalar(scalars),
                self.scalar2_to_scalar(scalars2),
            )
            
        if self.in_scalars > 0 and self.in2_spinors > 0 and self.out_spinors > 0:
            spinors_out += scalar_spinor_mult(
                self.scalar_to_spinor(scalars),
                self.spinor2_to_spinor(spinors2),
            )
        if self.in_spinors > 0 and self.in2_scalars > 0 and self.out_spinors > 0:
            spinors_out += scalar_spinor_mult(
                self.scalar2_to_spinor(scalars2),
                self.spinor_to_spinor(spinors),
            )
            
        if self.in_scalars > 0 and self.in2_vectors > 0 and self.out_vectors > 0:
            vectors_out += scalar_vector_mult(
                self.scalar_to_vector(scalars),
                self.vector2_to_vector(vectors2),
            )
        if self.in_vectors > 0 and self.in2_scalars > 0 and self.out_vectors > 0:
            vectors_out += scalar_vector_mult(
                self.scalar2_to_vector(scalars2),
                self.vector_to_vector(vectors),
            )
            
        if self.in_spinors > 0 and self.in2_spinors > 0 and self.out_scalars > 0:
            value = spinor_spinor_mult_scalar(
                self.spinor_to_scalar(spinors),
                self.spinor2_to_scalar(spinors2),
            )
            # complex result -> real channels: first half real parts, second half imaginary parts
            scalars_out += torch.cat([value.real, value.imag], dim=-1)
            
        if self.in_spinors > 0 and self.in2_spinors > 0 and self.out_vectors > 0:
            value = spinor_spinor_mult_vector(
                self.spinor_to_vector(spinors),
                self.spinor2_to_vector(spinors2),
            )
            vectors_out += torch.cat([value.real, value.imag], dim=-1)
            
        if self.in_vectors > 0 and self.in2_spinors > 0 and self.out_spinors > 0:
            spinors_out += spinor_vector_mult(
                self.spinor2_to_spinor_2(spinors2),
                self.vector_to_spinor(vectors),
            )
        if self.in2_vectors > 0 and self.in_spinors > 0 and self.out_spinors > 0:
            spinors_out += spinor_vector_mult(
                self.spinor_to_spinor_2(spinors),
                self.vector2_to_spinor(vectors2),
            )
            
        if self.in_vectors > 0 and self.in2_vectors > 0 and self.out_scalars > 0:
            scalars_out += torch.linalg.vecdot(
                self.vector_to_scalar(vectors),
                self.vector2_to_scalar(vectors2),
                dim=1,
            )
            
        if self.in_vectors > 0 and self.in2_vectors > 0 and self.out_vectors > 0:
            vectors_out += torch.cross(
                self.vector_to_vector_2(vectors),
                self.vector2_to_vector_2(vectors2),
                dim=1
            )
            
        return scalars_out, spinors_out, vectors_out
    
    
class GatedParametrizedTensorProduct(ParametrizedTensorProduct):
    """
    ParametrizedTensorProduct followed by a gated nonlinearity.

    Scalar outputs pass through GELU. Each spinor/vector output channel is multiplied by
    sigmoid(linear(activated scalars)). This needs scalar outputs whenever spinor or vector
    outputs exist; that requirement is checked in __init__.
    """

    def __init__(self, *args):
        super().__init__(*args)
        if (self.out_spinors > 0 or self.out_vectors > 0) and self.out_scalars <= 0:
            raise ValueError('Need output scalars to produce gates')
        if self.out_spinors > 0:
            self.spinor_gate = torch.nn.Linear(self.out_scalars, self.out_spinors, bias=True)
        if self.out_vectors > 0:
            self.vector_gate = torch.nn.Linear(self.out_scalars, self.out_vectors, bias=True)
            
    def forward(self, *args):
        """
        Same arguments and shapes as ParametrizedTensorProduct.forward:
        scalar shapes: [batch, channels]
        spinor shapes: [batch, 2, channels]
        vector shapes: [batch, 3, channels]
        """
        scalars_out, spinors_out, vectors_out = super().forward(*args)

        if scalars_out is None:
            # __init__ guarantees there are then no spinor/vector outputs to gate either
            return scalars_out, spinors_out, vectors_out
        
        scalars_out = torch.nn.functional.gelu(scalars_out)
        
        if self.out_spinors > 0:
            spinor_gates = torch.sigmoid(self.spinor_gate(scalars_out))
            # [b, channels] -> [b, 1, channels]: one gate per channel, shared by both components
            spinors_out = spinors_out * spinor_gates[:, None]
        if self.out_vectors > 0:
            vector_gates = torch.sigmoid(self.vector_gate(scalars_out))
            vectors_out = vectors_out * vector_gates[:, None]
        
        return scalars_out, spinors_out, vectors_out

In [None]:
def rot_y_2_tensor(a):
    """Batched spin-1/2 matrices [[cos(a/2), sin(a/2)], [-sin(a/2), cos(a/2)]] (complex), shape [n, 2, 2]."""
    half = 0.5 * a
    c = torch.cos(half)
    s = torch.sin(half)
    rows = [
        torch.stack([c, s], dim=1),
        torch.stack([-s, c], dim=1),
    ]
    return torch.stack(rows, dim=1).cfloat()

def rot_z_2_tensor(a):
    """Batched spin-1/2 matrices diag(exp(i a/2), exp(-i a/2)), shape [n, 2, 2]."""
    zero = torch.zeros_like(a).cfloat()
    top = torch.stack([torch.exp(0.5j*a), zero], dim=1)
    bottom = torch.stack([zero, torch.exp(-0.5j*a)], dim=1)
    return torch.stack([top, bottom], dim=1)

def rot_y_3_tensor(a):
    """Batched real 3x3 matrices [[c, 0, -s], [0, 1, 0], [s, 0, c]] with c=cos(a), s=sin(a), shape [n, 3, 3]."""
    c, s = torch.cos(a), torch.sin(a)
    zero, one = torch.zeros_like(a), torch.ones_like(a)
    rows = [
        torch.stack([c, zero, -s], dim=1),
        torch.stack([zero, one, zero], dim=1),
        torch.stack([s, zero, c], dim=1),
    ]
    return torch.stack(rows, dim=1)

def rot_z_3_tensor(a):
    """Batched real 3x3 matrices [[c, s, 0], [-s, c, 0], [0, 0, 1]] with c=cos(a), s=sin(a), shape [n, 3, 3]."""
    c, s = torch.cos(a), torch.sin(a)
    zero, one = torch.zeros_like(a), torch.ones_like(a)
    rows = [
        torch.stack([c, s, zero], dim=1),
        torch.stack([-s, c, zero], dim=1),
        torch.stack([zero, zero, one], dim=1),
    ]
    return torch.stack(rows, dim=1)

In [None]:
def sample_zyz(n):  # todo: make uniform over SU(2)?
    """
    Draw n ZYZ Euler-angle triples (a, b, c), each a tensor of shape [n].

    a, c are uniform in [0, 2*pi); b = acos(u) with u uniform in [-1, 1], i.e. cos(b) is uniform.
    """
    two_pi = 2 * math.pi
    alpha = two_pi * torch.rand(n)
    beta = torch.acos(2 * torch.rand(n) - 1)
    gamma = two_pi * torch.rand(n)
    return alpha, beta, gamma

def rot_zyz_2(a, b, c):
    """Spin-1/2 matrices Rz(a) @ Ry(b) @ Rz(c) for batched Euler angles, shape [n, 2, 2]."""
    z_then_y = rot_z_2_tensor(a) @ rot_y_2_tensor(b)
    return z_then_y @ rot_z_2_tensor(c)

def rot_zyz_3(a, b, c):
    """Real 3x3 matrices Rz(a) @ Ry(b) @ Rz(c) for batched Euler angles, shape [n, 3, 3]."""
    z_then_y = rot_z_3_tensor(a) @ rot_y_3_tensor(b)
    return z_then_y @ rot_z_3_tensor(c)

def sample_SU2_2_and_3(n):
    """
    Sample n random rotations as matching spin-1/2 (complex 2x2) and spin-1 (real 3x3) matrices.

    Both matrices are built from the same ZYZ Euler angles. With the angle ranges used by
    sample_zyz, the 2x2 matrices only reach one of the two SU(2) elements +U / -U per
    rotation. Each 2x2 matrix therefore gets an independent random sign.

    returns: R2 [n, 2, 2] (complex), R3 [n, 3, 3] (real)
    """
    a, b, c = sample_zyz(n)
    R2, R3 = rot_zyz_2(a, b, c), rot_zyz_3(a, b, c)
    # one sign per sample (a single shared sign would correlate all samples in the batch)
    flip = torch.rand(n) > .5
    R2 = torch.where(flip[:, None, None], -R2, R2)
    return R2, R3

In [None]:
def spin_data(random_su2=True, random_translation=True, random_noise=0, random_sign=False, label_type='integer'):
    """
    Build a toy batch of four 3-point triangles, each point carrying a 2-component spinor.

    Args:
        random_su2: apply a random rotation per triangle (R3 to positions, the matching R2 to
            spinor features and, for label_type='spinor', to the labels).
        random_translation: add a random Gaussian offset per triangle to its positions.
        random_noise: scale of Gaussian noise added to every point position (0 disables it).
        random_sign: multiply each triangle's spinor features by a random sign. For
            label_type='spinor' the labels also get an (independently drawn) random sign.
            Integer / one-hot labels are class indicators and are left unchanged.
        label_type: 'integer' (class index), 'spinor' (complex 2-vector); anything else gives one-hot.

    Returns:
        batch [12] (int64 triangle index per point), pos [12, 3], spin_feats [12, 2] (complex),
        labels ([4], [4, 2] or [4, 4] depending on label_type).
    """
    # classes should neither be determined by only the 3D point positions
    # or only the spin-1/2 features, but only by both jointly.
    pos = torch.tensor([
        [(0., 0, 0), (0, 0, 1), (1, 0, 0)],  # triangle 1
        [(0., 0, 0), (0, 0, 1), (0, 1, 2)],  # triangle 2
        [(0., 0, 0), (0, 0, 1), (1, 0, 0)],  # triangle 3
        [(0., 0, 0), (0, 0, 1), (0, 1, 2)],  # triangle 4
    ])

    spin_feats = torch.tensor([
        [(1., 1.j), (0., 1.), (1.j, 1.)],      # triangle 1
        [(1., 1.j), (0., 1.), (1.j, 1.)],      # triangle 2
        [(1, 0.), (0., 1.), (1.j, 1.j)],       # triangle 3
        [(1, 0.), (0., 1.), (1.j, 1.j)],       # triangle 4
    ])

    if label_type == 'integer':
        labels = torch.tensor(
            [
                0,  # triangle 1
                1,  # triangle 2
                2,  # triangle 3
                3,  # triangle 4
            ], dtype=torch.long,
        )
    elif label_type == 'spinor':
        labels = torch.tensor(
            [
                (1., 0.),  # triangle 1
                (np.sqrt(0.5), np.sqrt(0.5)),  # triangle 2
                (0., 1.),  # triangle 3
                (-np.sqrt(0.5), np.sqrt(0.5)),  # triangle 4
            ], dtype=torch.cfloat,
        )
    else:
        labels = torch.tensor(
            [
                [1, 0, 0, 0],  # triangle 1
                [0, 1, 0, 0],  # triangle 2
                [0, 0, 1, 0],  # triangle 3
                [0, 0, 0, 1],  # triangle 4
            ]
        )
    
    # triangle index of every (flattened) point: [0, 0, 0, 1, 1, 1, ...]
    num_shapes, num_points = pos.shape[0], pos.shape[1]
    batch = torch.arange(num_shapes).repeat_interleave(num_points)
    
    # apply random SU(2)-element
    if random_su2:
        R2, R3 = sample_SU2_2_and_3(pos.shape[0])
        pos = torch.einsum('bij, bnj -> bni', R3, pos)
        spin_feats = torch.einsum('bij, bnj -> bni', R2, spin_feats)
        
        if label_type == 'spinor':
            labels = torch.einsum('bij, bj -> bi', R2, labels)
            
    # multiply random sign on input and (for spinor labels) output (projective spaces)
    if random_sign:
        s = torch.randint(2, size=[spin_feats.shape[0], 1, 1])
        spin_feats *= 2. * (s - 0.5)
        # a sign flip is only meaningful for spinor labels; class indices / one-hot
        # vectors cannot (and should not) be negated
        if label_type == 'spinor':
            s = torch.randint(2, size=[labels.shape[0], 1])
            labels *= 2. * (s - 0.5)
        
    # apply random translation (should have no effect for conv-nets)
    if random_translation:
        pos += torch.randn([pos.shape[0], 1, 3])
    
    pos = pos.flatten(end_dim=1)
    spin_feats = spin_feats.flatten(end_dim=1)
    
    if random_noise > 0:
        pos += random_noise * torch.randn(pos.shape)
    
    return batch, pos, spin_feats, labels

In [None]:
# Rotation-invariance check for the ungated tensor product: two independently rotated copies
# of the same data (random SU(2) element, no translation) go through the same randomly
# initialised layer, and the 2 scalar outputs are compared.
# The inputs are expected to differ; the outputs are expected to agree.
tp = ParametrizedTensorProduct(0, 1, 1, 0, 1, 1, 2, 0, 0)

batch, pos1, spin_feats1, labels = spin_data(True, False)
batch, pos2, spin_feats2, labels = spin_data(True, False)

print('Test rot invariance')
print(f'Input positions equal: {torch.allclose(pos1, pos2)}')
print(f'Input spinors equal: {torch.allclose(spin_feats1, spin_feats2)}')

# Both operands are the per-point (spinor, position) features, each as a single channel ([..., None]).
sc1, _, _ = tp(None, spin_feats1[..., None], pos1[..., None], None, spin_feats1[..., None], pos1[..., None])
sc2, _, _ = tp(None, spin_feats2[..., None], pos2[..., None], None, spin_feats2[..., None], pos2[..., None])

print(f'''Outputs equal: {
    torch.allclose(sc1, sc2, atol=1e-7)
}''')


In [None]:
# Same rotation-invariance check as above, but for the gated tensor product
# (scalar outputs additionally pass through GELU).
gtp = GatedParametrizedTensorProduct(0, 1, 1, 0, 1, 1, 2, 0, 0)

batch, pos1, spin_feats1, labels = spin_data(True, False)
batch, pos2, spin_feats2, labels = spin_data(True, False)

print('Test rot invariance')
print(f'Input positions equal: {torch.allclose(pos1, pos2)}')
print(f'Input spinors equal: {torch.allclose(spin_feats1, spin_feats2)}')

# Both operands are the per-point (spinor, position) features, each as a single channel ([..., None]).
sc1, _, _ = gtp(None, spin_feats1[..., None], pos1[..., None], None, spin_feats1[..., None], pos1[..., None])
sc2, _, _ = gtp(None, spin_feats2[..., None], pos2[..., None], None, spin_feats2[..., None], pos2[..., None])

print(f'''Outputs equal: {
    torch.allclose(sc1, sc2, atol=1e-7)
}''')


In [None]:
# Inspect one generated batch (defaults: random rotation and translation, integer labels).
batch, pos, spin_feats, labels = spin_data()
print(batch)
print(pos)
print(batch.shape, pos.shape, spin_feats.shape, labels.shape)

In [None]:
def complete_graph(batch):
    """
    Build the edges of a complete graph (without self-loops) within each batch element.

    batch: 1-D tensor (or sequence) assigning every node to a graph.
    Returns a [2, E] int64 tensor (callers unpack it as edge_src, edge_dst). It contains
    every ordered pair (i, j) with i != j and batch[i] == batch[j], in row-major order.
    The result has shape [2, 0] and is still int64 when there are no such pairs.
    """
    batch = torch.as_tensor(batch)
    n = batch.shape[0]
    same_graph = batch[:, None] == batch[None, :]
    not_self = ~torch.eye(n, dtype=torch.bool, device=batch.device)
    # nonzero enumerates indices in row-major order, matching a nested (row, col) loop
    rows, cols = torch.nonzero(same_graph & not_self, as_tuple=True)
    return torch.stack([rows, cols], dim=0)

In [None]:
def scatter_feats(feats_list, edge_dst):
    """
    Mean-aggregate every tensor in `feats_list` over dim 0 using the index `edge_dst`.

    Rows sharing the same index are averaged. Returns a list with one aggregated tensor
    per input, in the same order.
    """
    aggregated = []
    for tensor in feats_list:
        aggregated.append(scatter(tensor, edge_dst, dim=0, reduce='mean'))
    return aggregated


class SO3EquivariantRegressor1(torch.nn.Module):
    """
    Spinor regressor that treats the data spinors as input features attached to each point.

    Three message-passing layers run over a complete graph per sample. On every edge a
    tensor product is applied with a constant scalar and the relative position as filter,
    followed by mean aggregation at the destination node. A final mean over each sample's
    points gives one spinor per output channel.
    """

    def __init__(self, nbr_scalar_feats=32, nbr_spin_feats=12, nbr_vector_feats=12, nbr_output_spinors=1, gated=False) -> None:
        super().__init__()
        self.nbr_scalar_feats = nbr_scalar_feats
        self.nbr_spin_feats = nbr_spin_feats
        self.nbr_vector_feats = nbr_vector_feats
        
        if gated:
            TP = GatedParametrizedTensorProduct
        else:
            TP = ParametrizedTensorProduct
        
        self.layer1 = TP(
            0, 1, 1,  # input feat counts (scalars, spinors, vectors)
            1, 0, 1,  # filter feat counts
            nbr_scalar_feats, nbr_spin_feats, nbr_vector_feats,  # output feat counts
        )
        
        self.layer2 = TP(
            nbr_scalar_feats, nbr_spin_feats, nbr_vector_feats,  # input feat counts (scalars, spinors, vectors)
            1, 0, 1,  # filter feat counts
            nbr_spin_feats if gated else 0, nbr_spin_feats, 0,  # output feat counts (gating needs scalars)
        )
        
        self.layer3 = ParametrizedTensorProduct(
            0, nbr_spin_feats, 0,
            1, 0, 1,
            0, nbr_output_spinors, 0,
        )

    @staticmethod
    def _scatter_complex(feats, index):
        """Mean-aggregate complex `feats` over dim 0 by `index`, scattering real and imaginary parts separately."""
        feats_r, feats_i = scatter_feats([feats.real, feats.imag], index)
        return feats_r + 1.j * feats_i

    def forward(self, batch, pos, spin_feats) -> torch.Tensor:
        """
        batch shape: [batches*nbr_points]
        pos shape: [batches*nbr_points, 3]
        spin_feats shape: [batch*nbr_points, 2]
        Returns one predicted spinor per sample, shape [batches, 2].
        """
        edge_src, edge_dst = complete_graph(batch)
        edge_vec = (pos[edge_src] - pos[edge_dst])[..., None]
        # constant scalar filter, allocated on the same device as the inputs
        edge_scalars = torch.ones(edge_vec.shape[0], 1, device=pos.device)
        
        vector_feats = scatter(edge_vec, edge_dst, dim=0, reduce='mean')
        spin_feats = spin_feats[..., None]
    
        scalar_feats, spin_feats, vector_feats = self.layer1(
            None, spin_feats[edge_src], vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )
   
        # complex spinors are scattered as separate real and imaginary parts
        scalar_feats, spin_feats_r, spin_feats_i, vector_feats = scatter_feats(
            [scalar_feats, spin_feats.real, spin_feats.imag, vector_feats],
            edge_dst,
        )
        spin_feats = spin_feats_r + 1.j * spin_feats_i
        
        _, spin_feats, _ = self.layer2(
            scalar_feats[edge_src], spin_feats[edge_src], vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )
        spin_feats = self._scatter_complex(spin_feats, edge_dst)
        
        _, spin_feats, _ = self.layer3(
            None, spin_feats[edge_src], None,
            edge_scalars, None, edge_vec,
        )
        spin_feats = self._scatter_complex(spin_feats, edge_dst)
        
        # Take global mean per sample to get a single spinor prediction.
        spin_feats = self._scatter_complex(spin_feats, batch)
        return spin_feats[..., 0]
    

class SO3EquivariantRegressor2(torch.nn.Module):
    """
    Spinor regressor that uses the data spinors as filters.

    The spinor of each edge's source point (together with a constant scalar and the relative
    position) is the second operand of every tensor-product layer. Three message-passing
    layers run over a complete graph per sample, each followed by mean aggregation at the
    destination node. A final mean over each sample's points gives one spinor per output channel.
    """

    def __init__(self, nbr_scalar_feats=32, nbr_spin_feats=4, nbr_vector_feats=4, nbr_output_spinors=1, gated=False) -> None:
        super().__init__()
        self.nbr_scalar_feats = nbr_scalar_feats
        self.nbr_spin_feats = nbr_spin_feats
        self.nbr_vector_feats = nbr_vector_feats
        
        if gated:
            TP = GatedParametrizedTensorProduct
        else:
            TP = ParametrizedTensorProduct
        
        self.layer1 = TP(
            0, 0, 1,  # input feat counts (scalars, spinors, vectors)
            1, 1, 1,  # filter feat counts
            nbr_scalar_feats, nbr_spin_feats, nbr_vector_feats,  # output feat counts
        )
        
        self.layer2 = TP(
            nbr_scalar_feats, nbr_spin_feats, nbr_vector_feats,  # input feat counts (scalars, spinors, vectors)
            1, 1, 1,  # filter feat counts
            nbr_scalar_feats, nbr_spin_feats, nbr_vector_feats,  # output feat counts
        )
        
        self.layer3 = ParametrizedTensorProduct(
            nbr_scalar_feats, nbr_spin_feats, nbr_vector_feats,  # input feat counts (scalars, spinors, vectors)
            1, 1, 1,  # filter feat counts
            0, nbr_output_spinors, 0,  # output feat counts
        )

    @staticmethod
    def _scatter_complex(feats, index):
        """Mean-aggregate complex `feats` over dim 0 by `index`, scattering real and imaginary parts separately."""
        feats_r, feats_i = scatter_feats([feats.real, feats.imag], index)
        return feats_r + 1.j * feats_i

    @staticmethod
    def _scatter_nodes(scalar_feats, spin_feats, vector_feats, index):
        """Mean-aggregate real scalars/vectors and complex spinors over dim 0 by `index`."""
        scalar_feats, spin_feats_r, spin_feats_i, vector_feats = scatter_feats(
            [scalar_feats, spin_feats.real, spin_feats.imag, vector_feats],
            index,
        )
        return scalar_feats, spin_feats_r + 1.j * spin_feats_i, vector_feats

    def forward(self, batch, pos, spin_feats) -> torch.Tensor:
        """
        batch shape: [batches*nbr_points]
        pos shape: [batches*nbr_points, 3]
        spin_feats shape: [batch*nbr_points, 2]
        Returns one predicted spinor per sample, shape [batches, 2].
        """
        edge_src, edge_dst = complete_graph(batch)
        edge_vec = (pos[edge_src] - pos[edge_dst])[..., None]
        edge_spin = (spin_feats[edge_src])[..., None]
        # constant scalar filter, allocated on the same device as the inputs
        edge_scalars = torch.ones(edge_vec.shape[0], 1, device=pos.device)
        
        vector_feats = scatter(edge_vec, edge_dst, dim=0, reduce='mean')
        
        scalar_feats, spin_feats, vector_feats = self.layer1(
            None, None, vector_feats[edge_src],
            edge_scalars, edge_spin, edge_vec,
        ) 
        scalar_feats, spin_feats, vector_feats = self._scatter_nodes(
            scalar_feats, spin_feats, vector_feats, edge_dst,
        )
        
        scalar_feats, spin_feats, vector_feats = self.layer2(
            scalar_feats[edge_src], spin_feats[edge_src], vector_feats[edge_src],
            edge_scalars, edge_spin, edge_vec,
        ) 
        scalar_feats, spin_feats, vector_feats = self._scatter_nodes(
            scalar_feats, spin_feats, vector_feats, edge_dst,
        )
        
        _, spin_feats, _ = self.layer3(
            scalar_feats[edge_src], spin_feats[edge_src], vector_feats[edge_src],
            edge_scalars, edge_spin, edge_vec,
        )
        spin_feats = self._scatter_complex(spin_feats, edge_dst)
        
        # Take global mean per sample to get a single spinor prediction.
        spin_feats = self._scatter_complex(spin_feats, batch)
        return spin_feats[..., 0]


class SO3EquivariantRegressor3(torch.nn.Module):
    """
    Spinor regressor that first squares the data spinors into vectors.

    The real and imaginary parts of each point's squared spinor are used as input vectors,
    together with the mean relative position. Two scalar/vector message-passing layers
    follow. The last layer uses the input spinors (of each edge's source point) as filter
    to produce spinors again. A final mean over each sample's points gives one spinor per
    output channel.
    """

    def __init__(self, nbr_scalar_feats=32, nbr_vector_feats=8, nbr_output_spinors=1, gated=False) -> None:
        super().__init__()
        self.nbr_scalar_feats = nbr_scalar_feats
        self.nbr_vector_feats = nbr_vector_feats
        
        if gated:
            TP = GatedParametrizedTensorProduct
        else:
            TP = ParametrizedTensorProduct
        
        self.layer1 = TP(
            0, 0, 3,  # input feat counts (scalars, spinors, vectors)
            1, 0, 1,  # filter feat counts
            nbr_scalar_feats, 0, nbr_vector_feats,  # output feat counts
        )
        
        self.layer2 = TP(
            nbr_scalar_feats, 0, nbr_vector_feats,  # input feat counts (scalars, spinors, vectors)
            1, 0, 1,  # filter feat counts
            nbr_scalar_feats, 0, nbr_vector_feats,  # output feat counts
        )
        
        self.layer3 = ParametrizedTensorProduct(
            nbr_scalar_feats, 0, nbr_vector_feats,
            0, 1, 0,
            0, nbr_output_spinors, 0,
        )

    @staticmethod
    def _scatter_complex(feats, index):
        """Mean-aggregate complex `feats` over dim 0 by `index`, scattering real and imaginary parts separately."""
        feats_r, feats_i = scatter_feats([feats.real, feats.imag], index)
        return feats_r + 1.j * feats_i

    def forward(self, batch, pos, spin_feats) -> torch.Tensor:
        """
        batch shape: [batches*nbr_points]
        pos shape: [batches*nbr_points, 3]
        spin_feats shape: [batch*nbr_points, 2]
        Returns one predicted spinor per sample, shape [batches, 2].
        """
        edge_src, edge_dst = complete_graph(batch)
        edge_vec = (pos[edge_src] - pos[edge_dst])[..., None]
        # constant scalar filter, allocated on the same device as the inputs
        edge_scalars = torch.ones(edge_vec.shape[0], 1, device=pos.device)
        
        # 3 input vector channels: mean relative position, Re and Im of the squared spinor
        vec = spinor_spinor_mult_vector(spin_feats[..., None], spin_feats[..., None])
        vector_feats = torch.cat([
            scatter(edge_vec, edge_dst, dim=0, reduce='mean'),
            vec.real,
            vec.imag,
        ], dim=-1)

        scalar_feats, _, vector_feats = self.layer1(
            None, None, vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )
        scalar_feats, vector_feats = scatter_feats(
            [scalar_feats, vector_feats],
            edge_dst,
        )
        
        scalar_feats, _, vector_feats = self.layer2(
            scalar_feats[edge_src], None, vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )
        scalar_feats, vector_feats = scatter_feats(
            [scalar_feats, vector_feats],
            edge_dst,
        )
   
        # the source point's input spinor acts as the (only) filter of the last layer
        _, spin_out, _ = self.layer3(
            scalar_feats[edge_src], None, vector_feats[edge_src],
            None, spin_feats[edge_src, :, None], None,
        )
        spin_out = self._scatter_complex(spin_out, edge_dst)
        
        # Take global mean per sample to get a single spinor prediction.
        spin_out = self._scatter_complex(spin_out, batch)
        return spin_out[..., 0]
    

class SO3VectorRegressor3(torch.nn.Module):
    """Regressor that squares each data spinor into vector features and reads out scalars.

    The spinor attached to every point is turned into two real vector features
    (real and imaginary part of its spinor-spinor product). The last layer emits
    ``nbr_output_scalars`` real scalars per point which, after aggregation, are
    reinterpreted as one complex spinor: the first two scalars form the real
    part and the last two the imaginary part.
    """

    def __init__(self, nbr_scalar_feats=32, nbr_vector_feats=8, nbr_output_scalars=4, gated=False) -> None:
        """
        Args:
            nbr_scalar_feats: number of hidden scalar channels.
            nbr_vector_feats: number of hidden vector channels.
            nbr_output_scalars: must be 4 (2 real + 2 imaginary spinor components).
            gated: use ``GatedParametrizedTensorProduct`` for the two hidden layers.

        Raises:
            ValueError: if ``nbr_output_scalars`` is not 4.
        """
        super().__init__()
        # forward() splits the output as [:, :2] + 1j * [:, 2:]; any other count
        # either fails there or silently broadcasts into a wrongly shaped result.
        if nbr_output_scalars != 4:
            raise ValueError(
                'nbr_output_scalars must be 4 (2 real + 2 imaginary spinor '
                f'components), got {nbr_output_scalars}'
            )
        self.nbr_scalar_feats = nbr_scalar_feats
        self.nbr_vector_feats = nbr_vector_feats

        if gated:
            TP = GatedParametrizedTensorProduct
        else:
            TP = ParametrizedTensorProduct

        self.layer1 = TP(
            0, 0, 3,  # input feat counts (scalars, spinors, vectors)
            1, 0, 1,  # filter feat counts
            nbr_scalar_feats, 0, nbr_vector_feats,  # output feat counts
        )

        self.layer2 = TP(
            nbr_scalar_feats, 0, nbr_vector_feats,  # input feat counts (scalars, spinors, vectors)
            1, 0, 1,  # filter feat counts
            nbr_scalar_feats, 0, nbr_vector_feats,  # output feat counts
        )

        # The read-out layer is always ungated and emits scalars only.
        self.layer3 = ParametrizedTensorProduct(
            nbr_scalar_feats, 0, nbr_vector_feats,  # input feat counts
            1, 0, 1,  # filter feat counts
            nbr_output_scalars, 0, 0,  # output feat counts
        )

    def forward(self, batch, pos, spin_feats) -> torch.Tensor:
        """Predict one complex spinor per batch element.

        Args:
            batch: batch index of every point, shape [batches*nbr_points].
            pos: point positions, shape [batches*nbr_points, 3].
            spin_feats: complex spinor per point, shape [batch*nbr_points, 2].

        Returns:
            Complex tensor built from the batch-averaged output scalars as
            ``[:, :2] + 1j * [:, 2:]``.
        """
        edge_src, edge_dst = complete_graph(batch)
        edge_vec = (pos[edge_src] - pos[edge_dst])[..., None]
        edge_scalars = torch.ones(edge_vec.shape[0], 1)

        # Square each point's spinor into a complex vector; its real and imaginary
        # parts join the mean relative position as the three vector features.
        vec = spinor_spinor_mult_vector(spin_feats[..., None], spin_feats[..., None])
        vector_feats = torch.cat([
            scatter(edge_vec, edge_dst, dim=0, reduce='mean'),
            vec.real,
            vec.imag,
        ], dim=-1)

        scalar_feats, _, vector_feats = self.layer1(
            None, None, vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )

        scalar_feats, vector_feats = scatter_feats(
            [scalar_feats, vector_feats],
            edge_dst,
        )

        scalar_feats, _, vector_feats = self.layer2(
            scalar_feats[edge_src], None, vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )

        scalar_feats, vector_feats = scatter_feats(
            [scalar_feats, vector_feats],
            edge_dst,
        )

        scalar_feats, _, _ = self.layer3(
            scalar_feats[edge_src], None, vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )

        scalar_feats = scatter_feats(
            [scalar_feats],
            edge_dst,
        )[0]

        # Mean over the points of each batch element, then reinterpret the
        # 4 real scalars as a 2-component complex spinor.
        spinor_out = scatter(scalar_feats, batch, dim=0, reduce='mean')
        return spinor_out[:, :2] + 1.j * spinor_out[:, 2:]

    
class SO3EquivariantRegressor4(torch.nn.Module):
    """Equivariant regressor that uses squared spinors as vector *filters*.

    Every edge carries three vector filters: the relative position of its end
    points, and the real and imaginary part of the source point's spinor
    squared into a complex vector. The spinor output is produced in the last
    layer by using the input spinors as filters.
    """

    def __init__(self, nbr_scalar_feats=32, nbr_vector_feats=8, nbr_output_spinors=1, gated=False) -> None:
        super().__init__()
        self.nbr_scalar_feats = nbr_scalar_feats
        self.nbr_vector_feats = nbr_vector_feats

        hidden_tp = GatedParametrizedTensorProduct if gated else ParametrizedTensorProduct
        hidden = (nbr_scalar_feats, 0, nbr_vector_feats)  # (scalars, spinors, vectors)

        # Constructor arguments: input counts, filter counts, output counts,
        # each given as (scalars, spinors, vectors).
        self.layer1 = hidden_tp(0, 0, 1, 1, 0, 3, *hidden)
        self.layer2 = hidden_tp(*hidden, 1, 0, 3, *hidden)
        # Read-out layer: always ungated, spinor filter in, spinors out.
        self.layer3 = ParametrizedTensorProduct(*hidden, 0, 1, 0, 0, nbr_output_spinors, 0)

    def forward(self, batch, pos, spin_feats) -> torch.Tensor:
        """Predict a complex spinor for every batch element.

        Args:
            batch: batch index of every point, shape [batches*nbr_points].
            pos: point positions, shape [batches*nbr_points, 3].
            spin_feats: complex spinor per point, shape [batch*nbr_points, 2].

        Returns:
            Complex tensor holding the first output spinor channel of each
            batch element, aggregated with ``scatter_feats``.
        """
        def scatter_complex(feats, index):
            # Aggregate real and imaginary parts separately, then recombine.
            re_part, im_part = scatter_feats([feats.real, feats.imag], index)
            return re_part + 1.j * im_part

        edge_src, edge_dst = complete_graph(batch)
        rel_pos = (pos[edge_src] - pos[edge_dst])[..., None]

        # Square the source point's spinor of every edge into a complex vector.
        src_spinors = spin_feats[edge_src, :, None]
        spin_sq = spinor_spinor_mult_vector(src_spinors, src_spinors)
        edge_vec = torch.cat([rel_pos, spin_sq.real, spin_sq.imag], dim=-1)

        edge_scalars = torch.ones(edge_vec.shape[0], 1)

        # Initial node features: mean relative position over incoming edges.
        node_vectors = scatter(rel_pos, edge_dst, dim=0, reduce='mean')

        scalars, _, vectors = self.layer1(
            None, None, node_vectors[edge_src],
            edge_scalars, None, edge_vec,
        )
        scalars, vectors = scatter_feats([scalars, vectors], edge_dst)

        scalars, _, vectors = self.layer2(
            scalars[edge_src], None, vectors[edge_src],
            edge_scalars, None, edge_vec,
        )
        scalars, vectors = scatter_feats([scalars, vectors], edge_dst)

        _, edge_spinors, _ = self.layer3(
            scalars[edge_src], None, vectors[edge_src],
            None, spin_feats[edge_src, :, None], None,
        )

        node_spinors = scatter_complex(edge_spinors, edge_dst)
        # Aggregate over each batch element to obtain a single spinor.
        return scatter_complex(node_spinors, batch)[..., 0]
    

class SO3EquivariantRegressor5(torch.nn.Module):
    """Equivariant regressor that ignores the input spinors until the last layer.

    The hidden layers only see geometry (relative positions); the spinor output
    is produced in the last layer by using the input spinors as filters.
    """

    def __init__(self, nbr_scalar_feats=32, nbr_vector_feats=8, nbr_output_spinors=1, gated=False) -> None:
        super().__init__()
        self.nbr_scalar_feats = nbr_scalar_feats
        self.nbr_vector_feats = nbr_vector_feats

        layer_cls = ParametrizedTensorProduct
        if gated:
            layer_cls = GatedParametrizedTensorProduct

        hidden = (nbr_scalar_feats, 0, nbr_vector_feats)  # (scalars, spinors, vectors)
        # Arguments per layer: input counts, filter counts, output counts.
        self.layer1 = layer_cls(0, 0, 1, 1, 0, 1, *hidden)
        self.layer2 = layer_cls(*hidden, 1, 0, 1, *hidden)
        # Read-out layer: always ungated, the input spinor is its only filter.
        self.layer3 = ParametrizedTensorProduct(*hidden, 0, 1, 0, 0, nbr_output_spinors, 0)

    def forward(self, batch, pos, spin_feats) -> torch.Tensor:
        """Predict a complex spinor for every batch element.

        Args:
            batch: batch index of every point, shape [batches*nbr_points].
            pos: point positions, shape [batches*nbr_points, 3].
            spin_feats: complex spinor per point, shape [batch*nbr_points, 2].

        Returns:
            Complex tensor holding the first output spinor channel of each
            batch element, aggregated with ``scatter_feats``.
        """
        def scatter_complex(feats, index):
            # Aggregate real and imaginary parts separately, then recombine.
            re_part, im_part = scatter_feats([feats.real, feats.imag], index)
            return re_part + 1.j * im_part

        edge_src, edge_dst = complete_graph(batch)
        edge_vec = (pos[edge_src] - pos[edge_dst])[..., None]
        edge_scalars = torch.ones(edge_vec.shape[0], 1)

        # Only node feature: mean relative position over incoming edges.
        node_vectors = scatter(edge_vec, edge_dst, dim=0, reduce='mean')

        scalars, _, vectors = self.layer1(
            None, None, node_vectors[edge_src],
            edge_scalars, None, edge_vec,
        )
        scalars, vectors = scatter_feats([scalars, vectors], edge_dst)

        scalars, _, vectors = self.layer2(
            scalars[edge_src], None, vectors[edge_src],
            edge_scalars, None, edge_vec,
        )
        scalars, vectors = scatter_feats([scalars, vectors], edge_dst)

        _, edge_spinors, _ = self.layer3(
            scalars[edge_src], None, vectors[edge_src],
            None, spin_feats[edge_src, :, None], None,
        )

        node_spinors = scatter_complex(edge_spinors, edge_dst)
        # Aggregate over each batch element to obtain a single spinor.
        return scatter_complex(node_spinors, batch)[..., 0]
    
    
    
class NonEquivariantRegressor(torch.nn.Module):
    """Baseline that treats the data spinors as plain (non-invariant) scalar inputs.

    The real and imaginary parts of each point's spinor become 4 scalar input
    features. The last layer emits ``nbr_output_scalars`` real scalars which,
    after aggregation, are reinterpreted as one complex spinor: the first two
    scalars form the real part and the last two the imaginary part.
    """

    def __init__(self, nbr_scalar_feats=32, nbr_vector_feats=8, nbr_output_scalars=4, gated=False) -> None:
        """
        Args:
            nbr_scalar_feats: number of hidden scalar channels.
            nbr_vector_feats: number of hidden vector channels.
            nbr_output_scalars: must be 4 (2 real + 2 imaginary spinor components).
            gated: use ``GatedParametrizedTensorProduct`` for the two hidden layers.

        Raises:
            ValueError: if ``nbr_output_scalars`` is not 4.
        """
        super().__init__()
        # forward() splits the output as [:, :2] + 1j * [:, 2:]; any other count
        # either fails there or silently broadcasts into a wrongly shaped result.
        if nbr_output_scalars != 4:
            raise ValueError(
                'nbr_output_scalars must be 4 (2 real + 2 imaginary spinor '
                f'components), got {nbr_output_scalars}'
            )
        self.nbr_scalar_feats = nbr_scalar_feats
        self.nbr_vector_feats = nbr_vector_feats

        if gated:
            TP = GatedParametrizedTensorProduct
        else:
            TP = ParametrizedTensorProduct

        self.layer1 = TP(
            4, 0, 1,  # input feat counts (scalars, spinors, vectors)
            1, 0, 1,  # filter feat counts
            nbr_scalar_feats, 0, nbr_vector_feats,  # output feat counts
        )

        self.layer2 = TP(
            nbr_scalar_feats, 0, nbr_vector_feats,  # input feat counts (scalars, spinors, vectors)
            1, 0, 1,  # filter feat counts
            nbr_scalar_feats, 0, nbr_vector_feats,  # output feat counts
        )

        # The read-out layer is always ungated and emits scalars only.
        self.layer3 = ParametrizedTensorProduct(
            nbr_scalar_feats, 0, nbr_vector_feats,  # input feat counts (scalars, spinors, vectors)
            1, 0, 1,  # filter feat counts
            nbr_output_scalars, 0, 0,  # output feat counts
        )

    def forward(self, batch, pos, spin_feats) -> torch.Tensor:
        """Predict one complex spinor per batch element.

        Args:
            batch: batch index of every point, shape [batches*nbr_points].
            pos: point positions, shape [batches*nbr_points, 3].
            spin_feats: complex spinor per point, shape [batch*nbr_points, 2].

        Returns:
            Complex tensor built from the batch-averaged output scalars as
            ``[:, :2] + 1j * [:, 2:]``.
        """
        edge_src, edge_dst = complete_graph(batch)
        edge_vec = (pos[edge_src] - pos[edge_dst])[..., None]

        edge_scalars = torch.ones(edge_vec.shape[0], 1)

        # Real and imaginary spinor components as 4 scalar features; this breaks
        # rotation invariance on purpose (it is the non-equivariant baseline).
        scalar_feats = torch.cat(
            [
                spin_feats.real,
                spin_feats.imag,
            ], dim=-1
        )

        vector_feats = scatter(edge_vec, edge_dst, dim=0, reduce='mean')

        scalar_feats, _, vector_feats = self.layer1(
            scalar_feats[edge_src], None, vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )

        scalar_feats, vector_feats = scatter_feats(
            [scalar_feats, vector_feats],
            edge_dst,
        )

        scalar_feats, _, vector_feats = self.layer2(
            scalar_feats[edge_src], None, vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )

        scalar_feats, vector_feats = scatter_feats(
            [scalar_feats, vector_feats],
            edge_dst,
        )

        scalar_feats, _, _ = self.layer3(
            scalar_feats[edge_src], None, vector_feats[edge_src],
            edge_scalars, None, edge_vec,
        )

        scalar_feats = scatter_feats(
            [scalar_feats],
            edge_dst,
        )[0]

        # Mean over the points of each batch element, then reinterpret the
        # 4 real scalars as a 2-component complex spinor.
        spin_out = scatter(scalar_feats, batch, dim=0, reduce='mean')
        return spin_out[:, :2] + 1.j * spin_out[:, 2:]




In [None]:
# Gated instance of the spinor-feature regressor, used by the equivariance check in the next cell.
regressor = SO3EquivariantRegressor1(gated=True)

In [None]:
# Numerical equivariance check: transform positions with R3 and spinors with R2
# (presumably the 3x3 and 2x2 representations of one random SU(2) element) and
# verify that the prediction transforms with R2.
batch, pos1, spin_feats1, labels = spin_data(label_type='spinor')

R2, R3 = sample_SU2_2_and_3(1)
R2 = R2[0]
R3 = R3[0]
pos2 = torch.einsum('ij, bj -> bi', R3, pos1)
spin_feats2 = torch.einsum('ij, bj -> bi', R2, spin_feats1)

print('Test rot equivariance')
# Sanity check that the rotation actually changed the inputs.
print(f'Input positions equal: {torch.allclose(pos1, pos2)}')
print(f'Input spinors equal: {torch.allclose(spin_feats1, spin_feats2)}')

# Run each forward pass once so the printed tensors are exactly the ones compared.
with torch.no_grad():
    rotated_pred = torch.einsum('ij, bj -> bi', R2, regressor(batch, pos1, spin_feats1))
    pred_of_rotated = regressor(batch, pos2, spin_feats2)

print(f'Outputs equal: {torch.allclose(rotated_pred, pred_of_rotated, atol=1e-6)}')

print(rotated_pred)
print(pred_of_rotated)

In [None]:
def acc(pred, labels):
    """Return the fraction of rows of ``pred`` whose arg-max index equals ``labels``."""
    predicted_class = pred.argmax(dim=1)
    hits = (predicted_class == labels).double()
    return hits.mean(dim=0).item()

def spinor_loss(pred, labels):
    """Mean distance between normalised spinors, ignoring an overall sign.

    Both ``pred`` and ``labels`` are scaled to unit norm along dim 1 (norms are
    clamped at 1e-8 to avoid division by zero). Per row, the smaller of
    ``|pred - labels|`` and ``|pred + labels|`` is taken, then averaged.
    """
    def _unit(t):
        return t / t.norm(dim=1, keepdim=True).clamp(min=1e-8)

    pred_hat, labels_hat = _unit(pred), _unit(labels)
    dist_same_sign = torch.norm(pred_hat - labels_hat, dim=1)
    dist_flipped_sign = torch.norm(pred_hat + labels_hat, dim=1)
    return torch.minimum(dist_same_sign, dist_flipped_sign).mean()
    
def train(model, rot_augm=False, noise=0, epochs=400, lr=1e-2, spinor_labels=False, print_out=True):
    """Fit ``model`` on synthetic spin data with Adam.

    Args:
        model: module called as ``model(batch, pos, spin_feats)``.
        rot_augm: resample data with ``random_su2=True`` at every step.
        noise: if > 0, resample data with this noise level at every step
            (``random_su2`` then follows ``rot_augm``).
        epochs: number of optimisation steps.
        lr: Adam learning rate.
        spinor_labels: regress spinor labels with ``spinor_loss`` instead of
            classifying integer labels with cross-entropy.
        print_out: log the loss (and accuracy for classification) every 10 steps.
    """
    model.train()

    loss_fn = spinor_loss if spinor_labels else torch.nn.CrossEntropyLoss()
    label_type = 'spinor' if spinor_labels else 'integer'

    # Initial data set; it is kept for all steps unless noise or rotation
    # augmentation requests a fresh sample each step.
    batch, pos, spin_feats, labels = spin_data(random_su2=False, label_type=label_type)

    optim = torch.optim.Adam(model.parameters(), lr=lr)

    for step in range(epochs):
        resample_kwargs = None
        if noise > 0:
            resample_kwargs = dict(random_su2=rot_augm, random_noise=noise, label_type=label_type)
        elif rot_augm:
            resample_kwargs = dict(random_su2=True, label_type=label_type)
        if resample_kwargs is not None:
            batch, pos, spin_feats, labels = spin_data(**resample_kwargs)

        pred = model(batch, pos, spin_feats)
        loss = loss_fn(pred, labels)

        optim.zero_grad()
        loss.backward()
        optim.step()

        if not print_out or step % 10 != 0:
            continue
        if spinor_labels:
            print(f"epoch {step:5d} | loss {loss:<10.3f}")
        else:
            accuracy = acc(pred, labels)
            print(f"epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy")

            
def evaluate(model, noise=0, epochs=100, spinor_labels=False):
    """Average the loss (and accuracy) of ``model`` over freshly sampled data.

    Each of the ``epochs`` steps draws a new data set from ``spin_data`` with
    ``random_su2=True``, ``random_sign=True`` and ``random_noise=noise``.
    The caller is responsible for putting the model into eval mode.

    Args:
        model: module called as ``model(batch, pos, spin_feats)``.
        noise: noise level passed to ``spin_data``.
        epochs: number of sampled data sets to average over.
        spinor_labels: evaluate with ``spinor_loss`` on spinor labels instead
            of cross-entropy (and accuracy) on integer labels.

    Returns:
        The mean loss if ``spinor_labels``; otherwise ``(mean accuracy, mean loss)``.
    """
    accuracy = 0.
    loss = 0.

    if spinor_labels:
        loss_fn = spinor_loss
        label_type = 'spinor'
    else:
        loss_fn = torch.nn.CrossEntropyLoss()
        label_type = 'integer'

    # Evaluation needs no gradients. Without no_grad, the running ``loss`` sum
    # keeps every step's autograd graph alive until the function returns, so
    # memory grows with ``epochs``.
    with torch.no_grad():
        for _ in range(epochs):
            batch, pos, spin_feats, labels = spin_data(random_su2=True,
                                                       random_sign=True,
                                                       random_noise=noise,
                                                       label_type=label_type)
            pred = model(batch, pos, spin_feats)

            loss += loss_fn(pred, labels) / epochs
            if not spinor_labels:
                accuracy += acc(pred, labels) / epochs

    if spinor_labels:
        return loss
    return accuracy, loss

In [None]:
def numel(p, complex_2_real=True):
    """Return the number of scalar entries in tensor ``p``.

    If ``complex_2_real`` is true, each complex entry counts as two real
    parameters (real and imaginary part). Every complex dtype is recognised
    via ``Tensor.is_complex()``, not only ``cfloat``/``cdouble``.
    """
    count = p.numel()
    return 2 * count if complex_2_real and p.is_complex() else count

In [None]:
# Noise level passed to train() and evaluate() in the individual model cells below.
noise_level = 0.2

In [None]:
# Spinors as features (gated); parameter counts with complex entries counted twice, then raw.
model = SO3EquivariantRegressor1(gated=True)
print(sum(numel(p) for p in model.parameters()))
print(sum(numel(p, complex_2_real=False) for p in model.parameters()))
train(model, noise=noise_level, spinor_labels=True, lr=1e-4, epochs=200)

In [None]:
# Switch to eval mode and average the spinor loss over 200 freshly sampled data sets.
model.eval()
loss = evaluate(model, spinor_labels=True, noise=noise_level, epochs=200)
print(f'EVALUATION --- loss: {loss}')

In [None]:
# Spinors as filters (gated); trained for train()'s default number of epochs.
model = SO3EquivariantRegressor2(nbr_scalar_feats=32, gated=True)
print(sum(numel(p) for p in model.parameters()))
print(sum(numel(p, complex_2_real=False) for p in model.parameters()))
train(model, noise=noise_level, spinor_labels=True, lr=1e-2)

In [None]:
# Switch to eval mode and average the spinor loss over 200 freshly sampled data sets.
model.eval()
loss = evaluate(model, spinor_labels=True, noise=noise_level, epochs=200)
print(f'EVALUATION --- loss: {loss}')

In [None]:
# Spinors squared as vector features (gated).
model = SO3EquivariantRegressor3(nbr_scalar_feats=32, gated=True)
print(sum(numel(p) for p in model.parameters()))
print(sum(numel(p, complex_2_real=False) for p in model.parameters()))
train(model, noise=noise_level, spinor_labels=True, lr=1e-2, epochs=200)

In [None]:
# Switch to eval mode and average the spinor loss over 200 freshly sampled data sets.
model.eval()
loss = evaluate(model, spinor_labels=True, noise=noise_level, epochs=200)
print(f'EVALUATION --- loss: {loss}')

In [None]:
# Spinors squared as vector features, with scalar outputs reinterpreted as a spinor (gated).
model = SO3VectorRegressor3(nbr_scalar_feats=32, gated=True)
print(sum(numel(p) for p in model.parameters()))
print(sum(numel(p, complex_2_real=False) for p in model.parameters()))
train(model, noise=noise_level, spinor_labels=True, lr=1e-3, epochs=200)

In [None]:
# Switch to eval mode and average the spinor loss over 200 freshly sampled data sets.
model.eval()
loss = evaluate(model, spinor_labels=True, noise=noise_level, epochs=200)
print(f'EVALUATION --- loss: {loss}')

In [None]:
# Spinors squared as vector filters (gated); trained for train()'s default number of epochs.
model = SO3EquivariantRegressor4(nbr_scalar_feats=32, gated=True)
print(sum(numel(p) for p in model.parameters()))
print(sum(numel(p, complex_2_real=False) for p in model.parameters()))
train(model, noise=noise_level, spinor_labels=True, lr=1e-3)

In [None]:
# Switch to eval mode and average the spinor loss over 200 freshly sampled data sets.
model.eval()
loss = evaluate(model, spinor_labels=True, noise=noise_level, epochs=200)
print(f'EVALUATION --- loss: {loss}')

In [None]:
# Input spinors used only as filters in the last layer (gated).
model = SO3EquivariantRegressor5(nbr_scalar_feats=32, gated=True)
print(sum(numel(p) for p in model.parameters()))
print(sum(numel(p, complex_2_real=False) for p in model.parameters()))
train(model, noise=noise_level, spinor_labels=True, lr=1e-3, epochs=200)

In [None]:
# Switch to eval mode and average the spinor loss over 200 freshly sampled data sets.
model.eval()
loss = evaluate(model, spinor_labels=True, noise=noise_level, epochs=200)
print(f'EVALUATION --- loss: {loss}')

In [None]:
# Non-equivariant baseline: spinors as plain scalars, no rotation augmentation.
model = NonEquivariantRegressor(nbr_scalar_feats=32, gated=True)
print(sum(numel(p) for p in model.parameters()))
print(sum(numel(p, complex_2_real=False) for p in model.parameters()))
train(model, noise=noise_level, spinor_labels=True, lr=1e-3, epochs=200)

In [None]:
# Switch to eval mode and average the spinor loss over 200 freshly sampled data sets.
model.eval()
loss = evaluate(model, spinor_labels=True, noise=noise_level, epochs=200)
print(f'EVALUATION --- loss: {loss}')

In [None]:
# Non-equivariant baseline trained with rotation augmentation.
model = NonEquivariantRegressor(nbr_scalar_feats=32, gated=True)
print(sum(numel(p) for p in model.parameters()))
print(sum(numel(p, complex_2_real=False) for p in model.parameters()))
train(model, noise=noise_level, spinor_labels=True, rot_augm=True, lr=1e-3, epochs=200)

In [None]:
# Switch to eval mode and average the spinor loss over 200 freshly sampled data sets.
model.eval()
loss = evaluate(model, spinor_labels=True, noise=noise_level, epochs=200)
print(f'EVALUATION --- loss: {loss}')

In [None]:
from dataclasses import dataclass


@dataclass(eq=False)
class Experiment:
    """Configuration of one experiment in the noise-level sweep.

    ``eq=False`` keeps identity-based equality and hashing, exactly as for the
    previous plain class; the dataclass adds a readable ``repr``.
    """
    model_class: type  # regressor class, instantiated as model_class(gated=...)
    gated: bool  # forwarded to the model constructor
    nbr_epochs: int  # training steps per run
    rot_augm: bool  # train with rotation augmentation
    name: str  # legend label (may contain newlines)

In [None]:
# Regression evaluation over several noise levels and multiple runs.
# Every (experiment, noise level, run) trains a freshly initialised model, so the
# runs are independent and the spread over runs is meaningful.
nbr_epochs = 300

experiments = [
    Experiment(SO3EquivariantRegressor1, True, nbr_epochs, False, 'Spinors as features'),
    Experiment(SO3EquivariantRegressor2, True, nbr_epochs, False, 'Spinors as filters'),
    Experiment(SO3EquivariantRegressor3, True, nbr_epochs, False, 'Spinors squared\nas vector features'),
    Experiment(SO3EquivariantRegressor4, True, nbr_epochs, False, 'Spinors squared\nas vector filters'),
    Experiment(NonEquivariantRegressor, True, nbr_epochs, False, 'Spinors as scalars'),
    Experiment(NonEquivariantRegressor, True, nbr_epochs, True, 'Spinors as scalars\nwith rot. augm.'),
]

nbr_runs = 30
noise_levels = [0, 0.1, 0.2, 0.3, 0.4]

eval_losses = np.zeros([len(experiments), len(noise_levels), nbr_runs])
for e_idx, exp in enumerate(experiments):
    print(f'Running experiment: {exp.name}.')
    for n_idx, noise_level in enumerate(noise_levels):
        for r_idx in range(nbr_runs):
            # New model per run: reusing one instance would keep training the same
            # weights across all runs and noise levels.
            model = exp.model_class(gated=exp.gated)
            train(
                model, noise=noise_level, lr=1e-2, spinor_labels=True,
                rot_augm=exp.rot_augm, epochs=exp.nbr_epochs, print_out=False,
            )

            model.eval()
            # float() turns the 0-d loss tensor into a plain number for the numpy array.
            eval_losses[e_idx, n_idx, r_idx] = float(evaluate(
                model, noise=noise_level, epochs=1000, spinor_labels=True,
            ))

In [None]:
# Reset matplotlib rcParams to the library defaults before the styled plot below.
plt.style.use('default')

In [None]:
# Visualize experiments: mean evaluation loss vs. noise level, with the
# standard deviation over runs as error bars.

plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": "Computer Modern Serif",
    "font.size": 14,
})
plt.rcParams["font.weight"] = "bold"
plt.rcParams["axes.labelweight"] = "bold"
plt.style.use('tableau-colorblind10')

loss_mean = np.mean(eval_losses, axis=2)
loss_std = eval_losses.std(axis=2)

# Experiments whose name contains 'skip' are left out of the plot.
plotted = [e_idx for e_idx, exp in enumerate(experiments) if 'skip' not in exp.name]
for e_idx in plotted:
    plt.errorbar(noise_levels, loss_mean[e_idx], label=experiments[e_idx].name, yerr=loss_std[e_idx])

# Order legend entries by decreasing loss at the second-to-last noise level.
# ``order`` indexes only the plotted curves, so it stays aligned with the
# legend handles even when some experiments are skipped.
legend_handles, legend_labels = plt.gca().get_legend_handles_labels()
order = np.argsort(loss_mean[plotted, -2])[::-1]
plt.legend([legend_handles[idx] for idx in order], [legend_labels[idx] for idx in order])

plt.xlabel('Noise level')
plt.ylabel('Evaluation loss')


In [None]:
# Visualize spin data 
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from itertools import product, combinations

from pathlib import Path


def plot_all_reg(pos,
                 spin_feats,
                 labels,
                 save_path=None,
                 colors=('orange', 'forestgreen', 'blueviolet'),
                 r=0.05,
                 plot_limit=1.1,
                 nbr_shapes=4,
                 points_per_shape=3,
                ):
    """Plot each shape of the regression data set in 3D together with its spinors.

    For every shape two figures are produced:
      * a 3D view with a wireframe sphere of radius ``r`` at every point,
        centred on the shape's mean position;
      * a row of 2D panels, one per point, with the real (red) and imaginary
        (blue) parts of the point's spinor drawn as arrows, followed by a
        panel with the shape's label spinor.

    Args:
        pos: point positions; assumed to be grouped as consecutive runs of
            ``points_per_shape`` points per shape.
        spin_feats: complex spinor per point, in the same order as ``pos``.
        labels: one label spinor per shape.
        save_path: directory (``str`` or ``Path``) for the SVG files, or
            ``None`` to skip saving.
        colors: one colour per point within a shape (at least
            ``points_per_shape`` entries).
        r: radius of the drawn spheres.
        plot_limit: half-width of the cubic 3D plotting volume.
        nbr_shapes: number of shapes to plot.
        points_per_shape: number of points making up each shape.
    """
    if save_path is not None:
        save_path = Path(save_path)

    def _draw_spinor(ax, s):
        # Real part in red, imaginary part in blue, on a bare square panel.
        ax.quiver(0., 0., s[0].real, s[1].real, scale=3., width=.03, headwidth=3.4, color='r')
        ax.quiver(0., 0., s[0].imag, s[1].imag, scale=3., width=.03, headwidth=3.4, color='b')
        ax.tick_params(top=False, bottom=False, left=False, right=False,
                       labelleft=False, labelbottom=False)
        ax.set_aspect("equal")

    # Sphere surface parametrisation, shared by all points.
    u, v = np.mgrid[0:2*np.pi:30j, 0:np.pi:15j]

    for shape_idx in range(nbr_shapes):
        start = points_per_shape * shape_idx
        stop = points_per_shape * (shape_idx + 1)

        fig3d = plt.figure()
        ax = fig3d.add_subplot(projection='3d')

        # Draw a sphere at each point, centred on the shape's mean position.
        shape_pos = pos[start:stop]
        mean_pos = torch.mean(shape_pos, dim=0)
        for pos_idx, p in enumerate(shape_pos):
            x = p[0] - mean_pos[0] + r*np.cos(u)*np.sin(v)
            y = p[1] - mean_pos[1] + r*np.sin(u)*np.sin(v)
            z = p[2] - mean_pos[2] + r*np.cos(v)
            ax.plot_wireframe(x, y, z, color=colors[pos_idx])

        ax.set_aspect("equal")
        ax.axes.set_xlim3d(left=-plot_limit, right=plot_limit)
        ax.axes.set_ylim3d(bottom=-plot_limit, top=plot_limit)
        ax.axes.set_zlim3d(bottom=-plot_limit, top=plot_limit)
        ax.tick_params(top=False, bottom=False, left=False, right=False,
                       labelleft=False, labelbottom=False)

        if save_path is not None:
            fig3d.savefig(save_path / f'3D_{shape_idx}.svg')
        plt.show()

        # Spinor panels: one per point, framed in that point's colour.
        fig, axs = plt.subplots(1, points_per_shape + 1, figsize=[8, 2])
        for spin_idx, s in enumerate(spin_feats[start:stop]):
            _draw_spinor(axs[spin_idx], s)
            for spine in axs[spin_idx].spines.values():
                spine.set_edgecolor(colors[spin_idx])
                spine.set_linewidth(2.)

        # Last panel: the shape's label spinor.
        label_ax = axs[points_per_shape]
        _draw_spinor(label_ax, labels[shape_idx])
        for spine in label_ax.spines.values():
            spine.set_linewidth(2.)

        if save_path is not None:
            fig.savefig(save_path / f'2D_{shape_idx}.svg')


# Sample a spinor-labelled data set with random_su2=False and plot it (figures are not saved).
batch, pos, spin_feats, labels = spin_data(label_type='spinor', random_su2=False)
plot_all_reg(pos, spin_feats, labels)