In [1]:
import importlib
import moment_kernels as mk
importlib.reload(mk)
import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
%matplotlib notebook
import numpy as np

In [2]:
dataset = CIFAR10('/home/dtward/data',transform=ToTensor())

In [3]:
loader = torch.utils.data.DataLoader(dataset,batch_size=32,shuffle=True)

In [4]:
for x,l in loader:
    break
fig,ax = plt.subplots()
ax.imshow(x[0].permute(1,2,0))
ax.set_title(f'{l[0]}')

<IPython.core.display.Javascript object>

Text(0.5, 1.0, '6')

In [5]:
class Net(torch.nn.Module):
    def __init__(self,):
        super().__init__()
        
        channels0 = 8
        kernel_size = 3
        padding = 1
        out_channels = 10
        self.c0 = mk.ScalarVectorToScalarVector(in_scalars=3,in_vectors=0,out_scalars=channels0,out_vectors=channels0, kernel_size=kernel_size, padding=1)
        self.b0 = mk.ScalarVectorBatchnorm(channels0,channels0)
        self.s0 = mk.ScalarVectorSigmoid(channels0)
        
        self.c1 = mk.ScalarVectorToScalarVector(in_scalars=channels0,in_vectors=channels0,out_scalars=channels0*2,out_vectors=channels0*2,kernel_size=kernel_size,padding=padding)
        self.d1 = mk.Downsample()
        self.b1 = mk.ScalarVectorBatchnorm(channels0*2,channels0*2)
        self.s1 = mk.ScalarVectorSigmoid(channels0*2)
        
        self.c2 = mk.ScalarVectorToScalarVector(in_scalars=channels0*2,in_vectors=channels0*2,out_scalars=channels0*4,out_vectors=channels0*4,kernel_size=kernel_size,padding=padding)
        self.d2 = mk.Downsample()
        self.b2 = mk.ScalarVectorBatchnorm(channels0*4,channels0*4)
        self.s2 = mk.ScalarVectorSigmoid(channels0*4)
        
        self.c3 = mk.ScalarVectorToScalarVector(in_scalars=channels0*4,in_vectors=channels0*4,out_scalars=channels0*8,out_vectors=channels0*8,kernel_size=kernel_size,padding=padding)
        self.d3 = mk.Downsample()
        self.b3 = mk.ScalarVectorBatchnorm(channels0*8,channels0*8)
        self.s3 = mk.ScalarVectorSigmoid(channels0*8)
        
        self.c4 = mk.ScalarVectorToScalarVector(in_scalars=channels0*8,in_vectors=channels0*8,out_scalars=channels0*16,out_vectors=channels0*16,kernel_size=kernel_size,padding=padding)
        self.d4 = mk.Downsample()
        self.b4 = mk.ScalarVectorBatchnorm(channels0*16,channels0*16)
        self.s4 = mk.ScalarVectorSigmoid(channels0*16)
        
        self.c5 = mk.ScalarVectorToScalarVector(in_scalars=channels0*16,in_vectors=channels0*16,out_scalars=out_channels,out_vectors=0,kernel_size=kernel_size,padding=padding)
        self.d5 = mk.Downsample()
        
        
        
        
        
    def forward(self,x):
        x = self.c0(x)
        x = self.b0(x)
        x = self.s0(x)
        
        x = self.c1(x)
        x = self.d1(x)
        x = self.b1(x)
        x = self.s1(x)
        
        x = self.c2(x)
        x = self.d2(x)
        x = self.b2(x)
        x = self.s2(x)
        
        x = self.c3(x)
        x = self.d3(x)
        x = self.b3(x)
        x = self.s3(x)
        
        x = self.c4(x)
        x = self.d4(x)
        x = self.b4(x)
        x = self.s4(x)
        
        # last time no nonlinearity
        x = self.c5(x)
        x = self.d5(x)
        
        # average out any spatial dimensions
        x = torch.mean(x,(-1,-2))
        
        
        return x

In [6]:
net = Net()
count = 0
for p in net.parameters():
    count += p.numel()
print(f'{count} parameters')

138580 parameters


In [7]:
net(x).shape

torch.Size([32, 10])

In [8]:
kernel = net.c1.vv.c

In [9]:
# there are 16 input channels (8 vectors)
# there are 32 output channels (16 vectors)
# for each row/col, we can divide it into 2x2 blocks
# we take a bunch of 2x2 blocks, and stack them together to use pytorch's built in conv function.
kernel.shape

torch.Size([32, 16, 3, 3])

In [10]:
kernel[:2,:2,1,1] # note this is proportional to identity, times -0.2

tensor([[-0.0936, -0.0000],
        [-0.0000, -0.0936]], grad_fn=<SelectBackward0>)

In [11]:
kernel[:2,:2,1,0] # not proportional to identity, because it has the xx^T and id mixed together

tensor([[-0.1880,  0.0000],
        [ 0.0000, -0.4479]], grad_fn=<SelectBackward0>)

In [12]:
kernel[:2,:2,0,0] # as we look at different rows and columns we see different numbers

tensor([[0.0104, 0.0674],
        [0.0674, 0.0104]], grad_fn=<SelectBackward0>)

In [13]:
kernel[:2,:2,1,2] # note that this is evaluated at x=(0,1), the previous was y=(0,-1)= -x

tensor([[-0.1880, -0.0000],
        [-0.0000, -0.4479]], grad_fn=<SelectBackward0>)

In [14]:
# there is a matrix valued function for each row and column (row/column = "location x")
# the components are not independent of one another, 
# they depend on one another in a predictable way as a function of space

In [15]:
# check the equivariance
with torch.no_grad():
    err = net(x) - net(x.rot90(k=1,dims=(-1,-2)))
    print(torch.sqrt(torch.mean(err**2)).item())
    
    err = net(x) - net(x.flip(dims=(-1,)))
    print(torch.sqrt(torch.mean(err**2)).item())

3.871007265843218e-07
2.905766791627684e-07


In [16]:
optimizer = torch.optim.Adam(net.parameters())
loss = torch.nn.CrossEntropyLoss()

In [17]:
class Net1(torch.nn.Module):
    def __init__(self,channels0):
        super().__init__()
        
        #channels0 = 16
        kernel_size = 3
        padding = 1
        out_channels = 10
        self.c0 = torch.nn.Conv2d(3,channels0,kernel_size=kernel_size,padding=1)
        self.b0 = torch.nn.BatchNorm2d(channels0)
        self.s0 = torch.nn.ReLU()
        
        self.c1 = torch.nn.Conv2d(channels0,channels0*2,kernel_size=kernel_size,padding=padding,stride=2)
        self.b1 = torch.nn.BatchNorm2d(channels0*2)
        self.s1 = torch.nn.ReLU()
        
        self.c2 = torch.nn.Conv2d(channels0*2,channels0*4,kernel_size=kernel_size,padding=padding,stride=2)
        self.b2 = torch.nn.BatchNorm2d(channels0*4)
        self.s2 = torch.nn.ReLU()
        
        self.c3 = torch.nn.Conv2d(channels0*4,channels0*8,kernel_size=kernel_size,padding=padding,stride=2)
        self.b3 = torch.nn.BatchNorm2d(channels0*8)
        self.s3 = torch.nn.ReLU()
        
        self.c4 = torch.nn.Conv2d(channels0*8,channels0*16,kernel_size=kernel_size,padding=padding,stride=2)
        self.b4 = torch.nn.BatchNorm2d(channels0*16)
        self.s4 = torch.nn.ReLU()
        
        self.c5 = torch.nn.Conv2d(channels0*16,out_channels,kernel_size=kernel_size,padding=padding,stride=2)
        
        
        
        
        
    def forward(self,x):
        x = self.c0(x)
        x = self.b0(x)
        x = self.s0(x)
        
        x = self.c1(x)        
        x = self.b1(x)
        x = self.s1(x)
        
        x = self.c2(x)        
        x = self.b2(x)
        x = self.s2(x)
        
        x = self.c3(x)        
        x = self.b3(x)
        x = self.s3(x)
        
        x = self.c4(x)        
        x = self.b4(x)
        x = self.s4(x)
        
        # last time no nonlinearity
        x = self.c5(x)        
        
        # average out any spatial dimensions
        x = torch.mean(x,(-1,-2))
        
        
        return x

In [18]:
# about the same number of parameters
net1 = Net1(9)
count = 0
for p in net1.parameters():
    count += p.numel()
print(f'{count} parameters')

137980 parameters


In [19]:
optimizer1 = torch.optim.Adam(net1.parameters())

In [None]:

nepochs = 10
Esave = []
E1save = []
fig,ax = plt.subplots()
for e in range(nepochs):
    E_ = []
    E1_ = []
    count = 0
    for x,l in loader:
        optimizer.zero_grad()
        lhat = net(x)
        E = loss(lhat,l)
        E.backward()
        optimizer.step()
        E_.append(E.item())
        
        optimizer1.zero_grad()
        lhat = net1(x)
        E = loss(lhat,l)
        E.backward()
        optimizer1.step()
        E1_.append(E.item())
        
        
        if count > 100: 
            break # stop early for this example
        count += 1
    Esave.append(np.mean(E_))
    E1save.append(np.mean(E1_))
    ax.cla()
    ax.plot(Esave,label='Rot')
    ax.plot(E1save,label='Standard')
    ax.legend()
    fig.canvas.draw()
    
    

<IPython.core.display.Javascript object>

  allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
