In [None]:
!pip install torch torchvision numpy Pillow pytest

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pytest
import pathlib

## One-Dimensional Steerable Kernels (1pt):
In the lecture, we derived a basis for a steerable kernel for the D4 group with its standard action on a 3x3 grid, and the B1 representation. 

Here, we want to generalize this to any 1-dimensional representation of D4.

Your task is to implement the following function, which the cells below will then use to generate visualizations of D4-steerable kernels w.r.t. the different 1-dimensional representations.

In [None]:
def impose_weight_sharing(start, a, b):
    """
    Project `start` onto the kernels that are D4-steerable w.r.t. a
    1-dimensional representation.

    start: initial vector (a square 2D grid, e.g. 3x3 or 5x5)
    a: representation for 90° rotation
    b: representation for flip around axis 1

    Returns a float array of the same shape as `start`: the average of
    rho(g) * (g applied to start) over the 8 elements g = r^k s^f of D4,
    where rho(r^k s^f) = a**k * b**f.
    """
    start = np.asarray(start, dtype=float)
    result = np.zeros_like(start)
    for f in (0, 1):
        # s^f: optional reflection, applied before the rotations
        reflected = np.flip(start, axis=1) if f else start
        for k in range(4):
            # A 1-dimensional representation of D4 only takes the values +-1,
            # so rho(g) is its own inverse and the rotation direction used by
            # np.rot90 does not affect the result.
            result += (a ** k) * (b ** f) * np.rot90(reflected, k)
    # Dividing by the group order makes the map a projection (idempotent).
    return result / 8

In [None]:
# Unit impulse in the top-left corner of a 3x3 grid.
test = np.zeros((3, 3))
test[0, 0] = 1

# For a = b = 1 the weight is expected to be shared equally by the four corners.
expected = np.zeros((3, 3))
for corner in ((0, 0), (2, 0), (0, 2), (2, 2)):
    expected[corner] = 0.25

result = impose_weight_sharing(test, 1, 1)
assert pytest.approx(result) == expected

In [None]:
def visualize_weight_sharing(a, b, dim=5):
    """
    Plot the weight-sharing pattern generated by every pixel of a dim x dim grid.

    a: representation for 90° rotation (forwarded to `impose_weight_sharing`)
    b: representation for flip around axis 1 (forwarded to `impose_weight_sharing`)
    dim: side length of the square grid; defaults to 5

    Panel (x, y) shows `impose_weight_sharing` applied to the unit impulse at
    pixel (x, y). Each panel is colour-scaled independently by `imshow`.
    """
    # squeeze=False keeps `ax` two-dimensional, so ax[x, y] also works for dim == 1
    fig, ax = plt.subplots(dim, dim, squeeze=False)
    for x in range(dim):
        for y in range(dim):
            v = np.zeros((dim, dim))
            v[x, y] = 1
            ax[x, y].imshow(impose_weight_sharing(v, a, b))
    plt.show()

### A1 - Trivial representation

In [None]:
# A1: a = +1 for the 90° rotation, b = +1 for the flip.
visualize_weight_sharing(1, 1)

### A2

In [None]:
# A2: a = +1 for the 90° rotation, b = -1 for the flip.
visualize_weight_sharing(1, -1)

### B1

In [None]:
# B1: a = -1 for the 90° rotation, b = +1 for the flip.
visualize_weight_sharing(-1, 1)

### B2

In [None]:
# B2: a = -1 for the 90° rotation, b = -1 for the flip.
visualize_weight_sharing(-1, -1)

## Equivariant convolutional network (3pts)
In this part, we are going to develop a K4-invariant convolutional network following the GDL blueprint.
The K4 group consists of 4 elements: the identity, the horizontal flip, the vertical flip, and the 180° rotation.

###  First layer
As a first task, implement an equivariant convolution layer as described in the lecture (i.e., equivariant w.r.t. the regular representation). As the group has four elements, the output is supposed to be of shape `batch x out_channels x 4 x w x h`, i.e., do not merge the channels together in one dimension. While the ordering of group elements in the 4 sub-channels is in principle arbitrary, to keep things easily comparable (and make the asserts happy), please use the following ordering:
- identity
- flip(dim=2)
- flip(dim=3)
- rot-180°

In [None]:
class K4ConvLayer(nn.Module):
    """
    Lifting convolution from an ordinary image to a K4 regular feature map.

    A single learnable filter bank of shape
    (out_channels, in_channels, kernel_size, kernel_size) is applied in four
    versions, one per group element, in the order
    identity, flip(dim=2), flip(dim=3), rot-180°.
    There is no bias term and no padding.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty((out_channels, in_channels, kernel_size, kernel_size)))
        self.reset_parameters()
    
    def reset_parameters(self):
        """Re-initialise the filter bank with Kaiming-uniform values."""
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, input):
        """
        input: tensor of shape (batch, in_channels, H, W)

        Returns a tensor of shape (batch, out_channels, 4, H', W') whose
        sub-channel g holds the correlation of `input` with the g-transformed
        filters. Transforming the input by a K4 element therefore flips each
        map spatially and permutes the four sub-channels.
        """
        weight = self.weight
        out_channels, in_channels, k_h, k_w = weight.shape
        # All four transformed copies of the filters: (out_channels, 4, in_channels, k, k)
        bank = torch.stack(
            (
                weight,
                torch.flip(weight, dims=(2,)),
                torch.flip(weight, dims=(3,)),
                torch.flip(weight, dims=(2, 3)),  # rot-180° == both flips
            ),
            dim=1,
        )
        # One conv2d call evaluates all copies; output channel index is o * 4 + g.
        response = F.conv2d(input, bank.reshape(4 * out_channels, in_channels, k_h, k_w))
        return response.reshape(input.shape[0], out_channels, 4, *response.shape[-2:])

In [None]:
test_layer = K4ConvLayer(1, 1, 3)
test_input = torch.rand((1, 1, 3, 3))
original = test_layer(test_input)
transformed = test_layer(torch.flip(test_input, dims=(2,)))

# check that channels permute according to regular representation:
# flipping the input along dim 2 swaps identity <-> flip(dim=2) and flip(dim=3) <-> rot-180°
for src, dst in ((0, 1), (1, 0), (2, 3), (3, 2)):
    assert pytest.approx(original[0, 0, src, 0, 0].item()) == transformed[0, 0, dst, 0, 0].item()

### Transformation according to the regular representation
For a regular CNN, we would be done now, and could just start stacking these layers on top of each other.
However, there is a problem now. In the first layer, when we transform the input according to a group element, it is only the position of pixels that changes, but the channels (RGB colors) remain invariant. But the feature map we gain after the first group-convolution transforms in a different way; the pixels change position, and also, there is a permutation of the channel maps. 
Therefore, we need to define a group convolution layer that is equivariant with respect to this new kind of input transformation.

As a first step, implement the function `transform_regular`, which, when given a feature map such as the one produced by our convolution above, and an element of the K4 group (k = 0: identity, 1: flip-2, 2: flip-3, 3: rot-180), produces a version of the feature map as it is transformed by the regular representation of K4. (Strictly speaking, since we have multiple independent `out_channels`, it transforms according to the direct sum of `out_channels` many copies of the regular representation).

In [None]:
def transform_regular(data, k):
    """
    Apply a K4 element to a regular feature map.

    data: tensor of shape (batch, channels, 4, H, W); dim 2 indexes the group
        sub-channels in the order identity, flip-2, flip-3, rot-180
    k: group element (0: identity, 1: flip-2, 2: flip-3, 3: rot-180)

    Returns a new tensor of the same shape in which every map is flipped
    spatially by element k and the four sub-channels are permuted:
    result[:, :, j] = flip_k(data[:, :, j ^ k]).

    Raises ValueError if k is not one of 0, 1, 2, 3.
    """
    if k not in (0, 1, 2, 3):
        raise ValueError(f"k must be one of 0, 1, 2, 3, got {k!r}")
    # With this ordering the K4 product is the bitwise XOR of the indices
    # (flip-2 followed by flip-3 is rot-180: 1 ^ 2 == 3), and every element is
    # its own inverse.
    permuted = data[:, :, [j ^ k for j in range(4)]]
    # Spatial action on the last two dims: flip-2 acts on H, flip-3 on W.
    spatial_dims = ((), (-2,), (-1,), (-2, -1))[k]
    if spatial_dims:
        permuted = torch.flip(permuted, dims=spatial_dims)
    return permuted


In [None]:
dummy = torch.zeros(1, 1, 4, 3, 3)
dummy[0, 0, 0, 0, 1] = 1.0

# identity
result = transform_regular(dummy, 0)
assert np.all(result.numpy() == dummy.numpy()), (result, dummy)

# single reflections, given as (group element, index where the impulse must land):
#   k=1, reflection around first axis; sub-channel shifts by one position
#   k=2, reflection around second axis; sub-channel shifts by two
for k, landing in ((1, (0, 0, 1, 2, 1)), (2, (0, 0, 2, 0, 1))):
    result = transform_regular(dummy, k)
    expected = torch.zeros(1, 1, 4, 3, 3)
    expected[landing] = 1
    assert np.all(result.numpy() == expected.numpy()), (result, expected)

# combining both reflections
result = transform_regular(dummy, 3)
expected = transform_regular(transform_regular(dummy, 1), 2)
assert np.all(result.numpy() == expected.numpy()), (result, expected)

### Regular-regular steerable convolution
Now we are ready to implement a second convolution layer, which implements an equivariant convolution for the case when the input transforms according to the regular representation.

In [None]:
class K4GConvLayer(nn.Module):
    """
    Group convolution mapping K4 regular feature maps to K4 regular feature maps.

    in_channels counts every group sub-channel of the input, i.e. it is
    4 x (number of regular input fields); out_channels is the number of regular
    output fields. The sub-channel order is identity, flip-2, flip-3, rot-180.
    There is no bias term and no padding.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        if in_channels % 4 != 0:
            raise ValueError(f"in_channels must be a multiple of 4, got {in_channels}")
        self.weight = torch.nn.Parameter(torch.empty((out_channels, in_channels, kernel_size, kernel_size)))
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.reset_parameters()
    
    def reset_parameters(self):
        """Re-initialise the filter bank with Kaiming-uniform values."""
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, input):
        """
        input: tensor of shape (batch, in_channels // 4, 4, H, W)

        Returns a tensor of shape (batch, out_channels, 4, H', W'). Output
        sub-channel g correlates the input with the filter bank transformed by
        group element g under the regular representation, so applying a K4
        element to the input (spatial flip plus sub-channel permutation)
        applies the same transformation to the output.
        """
        batch = input.shape[0]
        k = self.kernel_size
        # View the filters as regular fields: (out_channels, fields, 4, k, k).
        base = self.weight.reshape(self.out_channels, self.in_channels // 4, 4, k, k)
        bank = []
        for g, spatial_dims in enumerate(((), (3,), (4,), (3, 4))):
            # The K4 product is the XOR of the element indices, so g permutes
            # the group axis of the filter by h -> h ^ g ...
            kernel = base[:, :, [h ^ g for h in range(4)]]
            # ... and flips the filter spatially.
            if spatial_dims:
                kernel = torch.flip(kernel, dims=spatial_dims)
            bank.append(kernel)
        # (out_channels, 4, fields, 4, k, k) -> a plain conv2d weight; output channel is o * 4 + g.
        bank = torch.stack(bank, dim=1).reshape(4 * self.out_channels, self.in_channels, k, k)
        # Merge (fields, 4) into one channel axis, matching the c * 4 + h layout of the weight.
        merged = input.reshape(batch, self.in_channels, *input.shape[-2:])
        response = F.conv2d(merged, bank)
        return response.reshape(batch, self.out_channels, 4, *response.shape[-2:])

In [None]:
aux_layer = K4ConvLayer(1, 1, 3)
test_layer = K4GConvLayer(4, 1, 3)
test_raw = torch.rand((1, 1, 5, 5))

# Lift the raw image and its flip along dim 2, then apply the group convolution to both.
test_input = aux_layer(test_raw)
t_input = aux_layer(torch.flip(test_raw, dims=(2,)))
o_out = test_layer(test_input)
t_out = test_layer(t_input)

for response in (o_out, t_out):
    print(response.detach().numpy()[0, 0, :, 0, 0])

# check that channels permute according to regular representation
for src, dst in ((0, 1), (1, 0), (2, 3), (3, 2)):
    assert pytest.approx(o_out[0, 0, src, 0, 0].item(), rel=1e-3) == t_out[0, 0, dst, 0, 0].item()

### Putting it all together
We are finally ready to build the group-invariant CNN. The cells below define and train both a regular CNN and an invariant CNN based on the layers we constructed above.

Get CIFAR data. We are interested in the data-scarce regime, so we subsample the training set down to just two percent of the original data.

In [None]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
batch_size = 8

# Full CIFAR-10 training set; only small fractions of it are kept below.
trainset = torchvision.datasets.CIFAR10(
    root='/coursedata', train=True, download=True, transform=transform
)

# 2% for training, 8% for validation; the remaining 90% is discarded.
# The fixed generator seed keeps the split identical across runs.
split_generator = torch.Generator().manual_seed(42)
trainset, valset, _ = torch.utils.data.random_split(
    trainset, [0.02, 0.08, 0.9], generator=split_generator
)

trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=batch_size)
valloader = torch.utils.data.DataLoader(valset, shuffle=False, batch_size=batch_size)

Define the reference CNN model.

In [None]:
class CNN(nn.Module):
    """
    Reference network: six 3x3 convolutions (`conv1` ... `conv6`) with ReLU
    activations and 2x2 max-pooling after the second and fourth convolution.
    """

    def __init__(self):
        super().__init__()
        # Channel widths from the 3 input channels to the 10 output scores.
        widths = (3, 12*2, 12*2, 24*2, 24*2, 48*2, 10)
        # Layers are created in order and registered as conv1 ... conv6.
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{idx}", nn.Conv2d(n_in, n_out, 3))

    def forward(self, x):
        """Return the (batch, 10) scores read from spatial position (0, 0) of conv6's output."""
        for idx in range(1, 6):
            x = F.relu(getattr(self, f"conv{idx}")(x))
            if idx % 2 == 0:
                # halve the resolution after conv2 and conv4
                x = F.max_pool2d(x, 2)
        # conv6 is not followed by a ReLU
        return self.conv6(x)[..., 0, 0]

Defining the GCNN model. You may leave the code as-is, but this cell is deliberately left editable to encourage experimentation :)

In [None]:
class GCNN(nn.Module):
    """
    K4-invariant network: one lifting convolution, five regular-to-regular group
    convolutions, and a final average over the four group sub-channels.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = K4ConvLayer(3, 12, 3)
        self.conv2 = K4GConvLayer(12*4, 12, 3)
        self.conv3 = K4GConvLayer(12*4, 24, 3)
        self.conv4 = K4GConvLayer(24*4, 24, 3)
        self.conv5 = K4GConvLayer(24*4, 48, 3)
        self.conv6 = K4GConvLayer(48*4, 10, 3)

    @staticmethod
    def _pool(x):
        """2x2 max-pool each map of a (batch, fields, 4, H, W) tensor, keeping the 5D layout."""
        batch, fields, group = x.shape[:3]
        # Shapes are read from the tensor rather than hard-coded, so channel
        # widths and input resolution can be changed freely.
        merged = torch.reshape(x, (batch, fields * group, *x.shape[-2:]))
        pooled = F.max_pool2d(merged, 2)
        # NOTE: pooling only commutes with the flips when H and W are even
        # (28 and 10 for 32x32 inputs).
        return torch.reshape(pooled, (batch, fields, group, *pooled.shape[-2:]))
    
    def forward_equivariant_features(self, x):
        """Return the regular feature map after conv5, shape (batch, 48, 4, H', W')."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self._pool(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self._pool(x)
        x = F.relu(self.conv5(x))
        return x
    
    def forward_equivariant_cls(self, x):
        """Return per-group class scores of shape (batch, 10, 4), read at spatial position (0, 0)."""
        x = self.forward_equivariant_features(x)
        return self.conv6(x)[..., 0, 0]

    def forward(self, x):
        """Return invariant class scores of shape (batch, 10)."""
        x = self.forward_equivariant_cls(x)
        # Averaging over the group axis removes the sub-channel permutation.
        x = torch.mean(x, dim=2)
        return x

In [None]:
test_model = GCNN()
test_input = torch.rand((1, 3, 32, 32))
t_input = torch.flip(test_input, dims=(2,))

o_out = test_model(test_input)
t_out = test_model(t_input)

# check that the group-averaged output is the same for the flipped input (invariance)
assert pytest.approx(o_out.detach().numpy(), rel=1e-3) == t_out.detach().numpy()

Define a training loop with early stopping based on the validation set

In [None]:
def evaluate(net, dataloader):
    """
    Count the correct top-1 predictions of `net` on `dataloader`.

    net: module mapping a batch of images to (batch, num_classes) scores
    dataloader: iterable yielding (images, labels) batches

    Returns (correct, total): the number of correctly classified samples and
    the number of samples seen. Both are 0 for an empty dataloader.
    """
    correct = 0
    total = 0
    # Score in eval mode so dropout / batch-norm layers (if any) behave
    # deterministically, then restore the caller's mode.
    was_training = net.training
    net.eval()
    try:
        # since we're not training, we don't need to calculate the gradients for our outputs
        with torch.no_grad():
            for images, labels in dataloader:
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)
    return correct, total

def train_and_evaluate(net, checkpoint_name: str):
    """
    Train `net` on `trainloader` with early stopping on `valloader`.

    net: the model to train; updated in place
    checkpoint_name: path of the checkpoint file. If it already exists, its
        weights are loaded into `net` and no training takes place.

    Trains with Adam (lr=1e-4) and cross-entropy for at most 100 epochs. After
    each epoch the number of correct validation predictions is compared with
    the best so far: a result at least as good overwrites the checkpoint and
    resets the patience counter, a worse one increments it. Training stops
    once the counter exceeds 10. The best checkpoint is reloaded at the end.
    """
    if pathlib.Path(checkpoint_name).exists():
        print(f"Loading existing checkpoint {checkpoint_name}")
        net.load_state_dict(torch.load(checkpoint_name, weights_only=True))
        return

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    best = 0
    patience = 0

    for epoch in range(100):  # loop over the dataset multiple times
        # make sure we are in training mode, whatever the evaluation step left behind
        net.train()

        running_loss = 0.0
        seen = 0
        for inputs, labels in trainloader:
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # `loss` is the mean over the batch; weight it by the batch size so
            # that the epoch total divided by `seen` is the per-sample mean.
            running_loss += loss.item() * labels.size(0)
            seen += labels.size(0)

        correct, total = evaluate(net, valloader)
        print(f'[{epoch + 1}] loss: {running_loss / max(seen, 1):.3f} eval: {100 * correct / total:.2f}')

        if correct < best:
            patience += 1
            if patience > 10:
                # early stop
                break
        else:
            torch.save(net.state_dict(), checkpoint_name)
            best = correct
            patience = 0

    # reload best weights
    net.load_state_dict(torch.load(checkpoint_name, weights_only=True))

Train the standard CNN

In [None]:
# Build the reference CNN, report its parameter count, then train it
# (or load the weights from "ckp-cnn" if that checkpoint already exists).
cnn = CNN()
print(f"Parameters: {sum(p.numel() for p in cnn.parameters())}")
train_and_evaluate(cnn, "ckp-cnn")

Train the group-invariant CNN

In [None]:
# Build the group-invariant CNN, report its parameter count, then train it
# (or load the weights from "ckp-gcnn" if that checkpoint already exists).
gcnn = GCNN()
print(f"Parameters: {sum(p.numel() for p in gcnn.parameters())}")
train_and_evaluate(gcnn, "ckp-gcnn")

### Evaluate both the original model, and its group-smoothed version
#### First on original CIFAR data

In [None]:
# Un-augmented CIFAR-10 test split, served in its stored order.
test_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
testset = torchvision.datasets.CIFAR10(
    root='/coursedata', train=False, download=True, transform=test_transform
)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=batch_size)

In [None]:
# Test accuracy on the current `testloader`: reference CNN (c, t) vs. group-invariant CNN (c4, t4).
c, t = evaluate(cnn, testloader)
c4, t4 = evaluate(gcnn, testloader)
print(f"Accuracy: {100 * c / t:.2f}% vs {100 * c4 / t4:.2f}%")

#### Then on augmented CIFAR

In [None]:
# CIFAR-10 test split with random horizontal and vertical flips applied after ToTensor.
# NOTE: this rebinds test_transform, testset and testloader from the un-augmented cell.
augmented_steps = [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomVerticalFlip(),
]
test_transform = torchvision.transforms.Compose(augmented_steps)
testset = torchvision.datasets.CIFAR10(
    root='/coursedata', train=False, download=True, transform=test_transform
)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=batch_size)

In [None]:
# Test accuracy on the current `testloader`: reference CNN (c, t) vs. group-invariant CNN (c4, t4).
# NOTE(review): if `testloader` applies random flips, each evaluate() call presumably
# draws its own flips, so the two models need not see identical images — confirm.
c, t = evaluate(cnn, testloader)
c4, t4 = evaluate(gcnn, testloader)
print(f"Accuracy: {100 * c / t:.2f}% vs {100 * c4 / t4:.2f}%")

### Visualization of equivariant features.
Pick a feature map (`midx`) and a training sample (`sidx`), then run the cells below to see how the feature map transforms when the input transforms.

In [None]:
# pick one feature map
midx = 4
sidx = 1

In [None]:
# Fetch the chosen training image and its label, then display the image.
sample = trainset[sidx][0]
label = trainset[sidx][1]
# presumably `sample` is channel-first (C, H, W); permute moves the channels last for imshow
plt.imshow(sample.permute(1, 2, 0).detach().numpy())
plt.show()

In [None]:
# The four K4 elements applied to the input batch, in sub-channel order:
# identity, flip(dim=2), flip(dim=3), rot-180°.
input_transforms = (
    lambda batch: batch,
    lambda batch: torch.flip(batch, dims=(2,)),
    lambda batch: torch.flip(batch, dims=(3,)),
    lambda batch: torch.rot90(batch, k=2, dims=(2,3)),
)

fig, ax = plt.subplots(4, 4)

# Row r: feature map `midx` for the input transformed by element r;
# column i: group sub-channel i of that feature map.
for row, act in enumerate(input_transforms):
    s = act(sample[None, ...])
    fmap = gcnn.forward_equivariant_features(s)[0, midx, :, ...]
    lmap = gcnn.forward_equivariant_cls(s)[0, label]
    # per-group scores of the true class
    print(lmap.detach().numpy())
    for i in range(4):
        ax[row, i].imshow(fmap[i].detach().numpy())

# hide the ticks on every panel
for panel in ax.flat:
    panel.set_xticks([])
    panel.set_yticks([])