In [2]:
import numpy as np
import random
import matplotlib.pyplot as plt
import math
import scipy
from scipy.spatial import distance
from sklearn.metrics import mean_squared_error
from sklearn import preprocessing
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.autonotebook import tqdm



  from tqdm.autonotebook import tqdm


In [4]:
# torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

# TODO: normalise input and weights (Gaussian initialisation of the weights is already done in __init__)

T_DECAY_DEFAULT = 0.0005
W_BOOST_DEFAULT = 0.02

class FpConv2d(nn.Conv2d):
    """2D convolution that learns during the forward pass ("front propagation").

    Each output location (channel, row, column) owns a firing threshold. The
    layer outputs the activation in excess of that threshold (zero when the
    threshold is not reached). While the layer is not frozen, every forward pass
    also:

    * pulls the kernel of each firing location towards the (normalised) input
      patch that made it fire, then re-normalises that kernel;
    * raises the thresholds of firing locations and decays all thresholds.

    Restrictions: no bias, no dilation, no groups, square kernels only and
    numeric (non-string) padding. `padding_mode` is forwarded to `nn.Conv2d`,
    but `forward` calls `F.conv2d` directly, so padding is always zero padding.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode, device, dtype: as for `nn.Conv2d`
            (`kernel_size` may be an int or a pair of equal ints).
        t_decay: fraction by which every threshold shrinks per learning pass (> 0).
        w_boost: step size of the weight update towards the input patch (> 0).

    Raises:
        Exception: if an unsupported option (bias, dilation, groups, non-square
            kernel, string padding) is requested.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, padding_mode='zeros', t_decay=T_DECAY_DEFAULT, w_boost=W_BOOST_DEFAULT, device=None, dtype=None):
        if bias:
            raise Exception("Bias not supported for FrontPropConv2d")
        if dilation != 1:
            raise Exception("Dilation not supported for FrontPropConv2d")
        if groups != 1:
            raise Exception("Groups not supported for FrontPropConv2d")
        # accept the int shorthand that nn.Conv2d accepts
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if kernel_size[0] != kernel_size[1]:
            raise Exception("Non-square kernel not supported for FrontPropConv2d")
        # patch extraction needs numeric padding amounts
        if isinstance(padding, str):
            raise Exception("String padding not supported for FrontPropConv2d")

        super(FpConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)

        # learning is through front propagation only
        self.weight.requires_grad = False

        # init using Gaussian
        nn.init.normal_(self.weight, mean=0, std=1)

        # NOTE: replaces nn.Conv2d's (k, k) tuple with the scalar side length
        self.kernel_size = kernel_size[0]

        # hyper params
        self.t_decay = t_decay
        self.w_boost = w_boost

        self.frozen = False

        assert self.t_decay > 0
        assert self.w_boost > 0

        # thresholds, created on the first forward pass (see lazy_init_thresholds())
        self.t = None


    def lazy_init_thresholds(self, out_h, out_w):
        """Create the per-location thresholds (all ones) on first use.

        The output size is only known once an input has been seen, hence the
        lazy creation. Shape: (out_channels, out_h, out_w), shared by all
        samples of a batch.
        """
        if self.t is None:
            # nn.Module has no `device` attribute; follow the weights instead
            self.t = torch.ones(self.out_channels, out_h, out_w, device=self.weight.device, dtype=self.weight.dtype)


    def __normalise_unitary(self, data, dim, eps=1e-12):
        """Scale `data` to unit norm along `dim`.

        The norm is clamped to `eps` so that an all-zero slice stays zero
        instead of turning into NaN (which would poison the weights for good).
        """
        return data / torch.norm(data, dim=dim, keepdim=True).clamp_min(eps)


    def __get_weights_boost(self, kernel_idx, data_tensor):
        """Return the weight delta moving kernel `kernel_idx` towards `data_tensor`."""
        # FIXME: can i use this addition, or do i need to use the angle?

        kernel_weights = self.weight[kernel_idx]

        assert data_tensor.shape == kernel_weights.shape

        w_boost = self.w_boost * (data_tensor - kernel_weights)

        return w_boost


    def get_input_patch(self, input, sample_idx, out_h, out_w):
        """Return the input window that produced output location (out_h, out_w).

        Args:
            input: the (unpadded) batch given to `forward`, indexed as
                [sample, channel, row, column].
            sample_idx: index of the sample within the batch.
            out_h, out_w: row / column of the output location (int or 0-d tensor).

        Returns:
            Tensor of shape (channels, kernel_size, kernel_size). Parts of the
            window that fall into the padding area are zero, matching the zero
            padding applied by the convolution.
        """
        stride_h, stride_w = self.stride
        pad_h, pad_w = self.padding
        k = self.kernel_size

        # top-left corner of the window in unpadded input coordinates
        top = int(out_h) * stride_h - pad_h
        left = int(out_w) * stride_w - pad_w

        in_rows, in_cols = input.shape[-2:]
        patch = input.new_zeros((input.shape[1], k, k))

        # intersect the window with the real image; the rest stays zero
        src_top, src_bottom = max(top, 0), min(top + k, in_rows)
        src_left, src_right = max(left, 0), min(left + k, in_cols)
        if src_bottom > src_top and src_right > src_left:
            patch[:, src_top - top : src_bottom - top, src_left - left : src_right - left] = \
                input[sample_idx, :, src_top:src_bottom, src_left:src_right]

        return patch


    def _update_weights(self, input):
        """Shift the kernel of every firing location towards its input patch.

        Locations are visited in random order; after each step the touched
        kernel is re-normalised per input channel.
        """
        excitation_idxs = torch.nonzero(self.output)
        # shuffle indices randomly
        order = torch.randperm(excitation_idxs.shape[0], device=excitation_idxs.device)
        patch_shape = (self.in_channels, self.kernel_size, self.kernel_size)

        # plain Python ints are cheaper to index with than 0-d tensors
        for sample_idx, kernel_idx, h, w in excitation_idxs[order].tolist():
            data_tensor = self.get_input_patch(input, sample_idx, h, w)
            assert data_tensor.shape == patch_shape
            data_tensor = self.__normalise_unitary(data_tensor, dim=(1, 2))
            # update weights
            self.weight[kernel_idx] += self.__get_weights_boost(kernel_idx, data_tensor)
            # normalise weights
            self.weight[kernel_idx] = self.__normalise_unitary(self.weight[kernel_idx], dim=(1, 2))
            assert self.weight[kernel_idx].shape == patch_shape


    def _update_thresholds(self, activation):
        """Raise thresholds where they were exceeded, then decay all of them."""
        # strictly positive output <=> the threshold was exceeded
        self.binary_excitations = self.output > 0

        if activation.shape[0] > 0:
            # NOTE(review): the threshold is raised to the strongest raw
            # (pre-threshold) activation seen in the batch at that location.
            # Where no sample fired the batch maximum is <= t, so t is kept.
            # Confirm that the raw activation (not the excess above the
            # threshold) is the intended "new value".
            self.t = torch.maximum(self.t, activation.amax(dim=0))

        # decay all thresholds
        self.t = self.t * (1.0 - self.t_decay)


    def forward(self, input):
        """Convolve, apply the threshold cut-off and (unless frozen) learn.

        Args:
            input: tensor indexed as [sample, channel, row, column] with
                `in_channels` channels.

        Returns:
            Tensor of shape (batch, out_channels, out_h, out_w) holding the
            activation above the threshold, or zero where it was not reached.
            The same tensor is kept in `self.output`.
        """
        assert input.shape[1] == self.in_channels

        activation = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

        assert activation.shape[:2] == (input.shape[0], self.out_channels)

        self.lazy_init_thresholds(activation.shape[-2], activation.shape[-1])

        # thresholds carry no batch dimension; they broadcast over the samples
        assert self.t.shape == activation.shape[1:], "input size changed since thresholds were created"

        # ReLU-like non-linear transformation via cutoff threshold
        self.output = torch.where(activation >= self.t, activation - self.t, torch.zeros_like(activation))

        # Learning happens below:
        #
        #   For each location where the threshold was exceeded, update the weights.
        #   The weights are updated by a small amount towards the input data tensor
        #   and thresholds are raised where the threshold was exceeded.
        #   Thresholds are also decayed by a small amount on each pass.
        #
        #   Convolution kernels' weights are updated in random order (of samples and locations).
        #
        #   ---
        #   TODO: Should we output the absolute value or only the diff above threshold ?
        #   (note the threshold changes on each pass)
        #
        #   FIXME: all inputs should be first normalised - this is tricky
        #   TODO: Optimise this
        #   * Try removing loops
        #   * maybe use torch.sparse_coo_tensor() to save memory
        if not self.frozen:
            # learning must never be recorded by autograd
            with torch.no_grad():
                self._update_weights(input)
                self._update_thresholds(activation)

        return self.output

    def backward(self, grad_output):
        """Not supported: this layer learns in `forward` only."""
        raise Exception("Backward pass not implemented for FrontPropConv2d")


    def freeze(self):
        """Stop learning: subsequent forward passes leave weights and thresholds untouched."""
        self.frozen = True


    def unfreeze(self):
        """Resume learning on subsequent forward passes."""
        self.frozen = False


In [20]:
# Hyper-parameters of the reference convolution used in the cells below
in_ch, out_ch = 3, 5
kernel_size, stride = 2, 1

# Stock PyTorch layer, only used to inspect parameter and output shapes
conv2d = nn.Conv2d(
    in_channels=in_ch,
    out_channels=out_ch,
    kernel_size=kernel_size,
    stride=stride,
)

In [21]:
batch_size = 10
img_h = 16
img_w = 16

# Random batch of 3-channel 16x16 images.
# Conv2d expects (N, C, H, W), so height comes before width; the order only
# matters once the image is no longer square.
x = torch.rand((batch_size, in_ch, img_h, img_w))

In [32]:
# One row of (sample, channel, row, column) indices per non-zero element of x
x.nonzero()

tensor([[ 0,  0,  0,  0],
        [ 0,  0,  0,  1],
        [ 0,  0,  0,  2],
        ...,
        [ 9,  2, 15, 13],
        [ 9,  2, 15, 14],
        [ 9,  2, 15, 15]])

In [38]:
# True wherever an index component is not positive (indices are integers,
# so `<= 0` is the exact negation of `> 0`)
torch.nonzero(x) <= 0

tensor([[ True,  True,  True,  True],
        [ True,  True,  True, False],
        [ True,  True,  True, False],
        ...,
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]])

In [22]:
# Weight layout: (out_ch, in_ch, kernel_size, kernel_size)
conv2d.weight.size()

torch.Size([5, 3, 2, 2])

In [23]:
# One bias value per output channel: (out_ch,)
conv2d.bias.size()

torch.Size([5])

In [24]:
# Forward pass of the reference convolution.
# Result layout: (batch_size, out_ch, out_h, out_w); with kernel_size=2,
# stride=1 and no padding each spatial side shrinks from 16 to 15.
out = conv2d(x)
out.size()

torch.Size([10, 5, 15, 15])

In [26]:
# Slide a kernel_size x kernel_size window over the rows (dim 2) and the
# columns (dim 3): patches[n, c, i, j, a, b] == x[n, c, i*stride + a, j*stride + b]
patches = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
print(patches.shape)

# Number of window positions per axis (equals the conv output height/width).
# Taken from `patches` itself so this cell does not depend on `out`.
out_height, out_width = patches.shape[2:4]
print(out_height, out_width)

# Lay the patches out as tiles of one big image:
#   exploded_output[n, c, i*kernel_size + a, j*kernel_size + b] == patches[n, c, i, j, a, b]
# The patch-row axis must sit next to the position-row axis before merging;
# a plain view() without this permute scrambles rows across patches.
exploded_output = (
    patches
    .permute(0, 1, 2, 4, 3, 5)
    .reshape(batch_size, in_ch, out_height * kernel_size, out_width * kernel_size)
)
exploded_output.shape

torch.Size([10, 3, 15, 15, 2, 2])
15 15


torch.Size([10, 3, 30, 30])

In [30]:
# Fetch a pretrained ResNet-18 from the pinned torchvision release via
# torch.hub (downloads on first use), to inspect its layer structure.
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'resnet18',
    pretrained=True,
)
# Switch to inference mode; as the cell's last expression this also shows the module tree
model.eval()

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to C:\Users\karol/.cache\torch\hub\v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\karol/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|████████████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:04<00:00, 9.98MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  