In [1]:
import numpy as np
import random
import matplotlib.pyplot as plt
import math
import scipy
from scipy.spatial import distance
from sklearn.metrics import mean_squared_error
from sklearn import preprocessing
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.autonotebook import tqdm



  from tqdm.autonotebook import tqdm


In [5]:

# Default hyper-parameters for FpLinear.
T_DECAY_DEFAULT = 0.0005  # fractional decay applied to every threshold after each training sample
W_BOOST_DEFAULT = 0.02    # step size pulling an excited neuron's weights toward the sample


class FpLinear(nn.Linear):
    """Bias-free linear layer trained by "front propagation" instead of backprop.

    Each weight row (one per output neuron) is kept at unit L2 norm and each
    input sample is normalised to unit L2 norm, so a neuron's raw response is
    the cosine similarity between its weight row and the sample.

    For every sample, while the layer is not frozen:
      * a neuron is "excited" when its response is >= its threshold ``t``;
      * excited neurons move their weight row a fraction ``w_boost`` toward
        the sample, after which the row is re-normalised;
      * excited neurons set their threshold to the current response, and then
        every threshold is multiplied by ``(1 - t_decay)``.

    ``forward`` returns, per sample and neuron, how far the response exceeds
    the neuron's threshold (0 when it does not).

    Args:
        in_features: size of each input sample.
        out_features: number of neurons.
        bias: must be False; bias is not supported.
        t_decay: threshold decay rate, must be > 0.
        w_boost: weight learning rate, must be > 0.

    Raises:
        Exception: if ``bias`` is True.
    """

    def __init__(self, in_features, out_features, bias=True, t_decay=T_DECAY_DEFAULT, w_boost=W_BOOST_DEFAULT):
        if bias:
            raise Exception("Bias not supported for FrontPropLinear")

        super().__init__(in_features, out_features, bias=False)

        # learning is through front propagation only
        self.weight.requires_grad = False

        # Gaussian init, then project every row onto the unit sphere so the very
        # first responses are already cosine similarities in [-1, 1].
        with torch.no_grad():
            nn.init.normal_(self.weight, mean=0, std=1)
            self.weight.copy_(self.__normalise_unitary(self.weight))

        # hyper params
        self.t_decay = t_decay
        self.w_boost = w_boost

        self.frozen = False

        assert self.t_decay > 0
        assert self.w_boost > 0

        # Per-neuron thresholds. A buffer (not a plain attribute) so that
        # .to(device) and state_dict() carry them along with the weights.
        self.register_buffer(
            "t", torch.ones(self.out_features, device=self.weight.device, dtype=self.weight.dtype)
        )

    def __normalise_unitary(self, data, eps=1e-12):
        """Scale each row of a 2-D tensor to unit L2 norm.

        Rows whose norm is below ``eps`` (e.g. an all-zero sample) are divided
        by ``eps`` instead, so they stay ~0 rather than turning into NaN.
        """
        return data / torch.norm(data, dim=1, keepdim=True).clamp_min(eps)

    def __get_weights_boost(self, data_vector, excitations):
        """Return the weight update for one sample.

        Excited neurons' rows move a fraction ``w_boost`` of the way toward
        ``data_vector``; rows of non-excited neurons get a zero update.

        Args:
            data_vector: normalised sample repeated once per neuron,
                shape ``(out_features, in_features)``.
            excitations: 1.0 for excited neurons, 0.0 otherwise,
                shape ``(out_features,)``.
        """
        assert data_vector.shape == self.weight.shape

        w_boost = self.w_boost * (data_vector - self.weight)
        # unsqueeze so the per-neuron mask broadcasts across each weight row
        w_boost = w_boost * excitations.unsqueeze(1)

        assert w_boost.shape == self.weight.shape
        return w_boost

    def __assert_invariants(self):
        """Check that every weight row is unit-norm and there is one threshold per neuron."""
        row_norms = torch.norm(self.weight, dim=1)
        assert torch.allclose(row_norms, torch.ones_like(row_norms), atol=1e-3)
        assert self.t.shape == (self.out_features,)

    def forward_single_sample(self, data):
        """Process one sample and, unless frozen, update weights and thresholds.

        Args:
            data: tensor of shape ``(in_features,)``.

        Returns:
            Tensor of shape ``(out_features,)``: cosine similarity between the
            sample and each neuron's weight row, computed before any update.
        """
        data_vector = data.expand(self.out_features, -1)
        data_vector = self.__normalise_unitary(data_vector)

        output = torch.sum(data_vector * self.weight, dim=1)

        # Both operands are unit-norm, so every response is a cosine similarity.
        assert torch.all(output > -1.01) and torch.all(output < 1.01)

        excitations = (output >= self.t).to(output.dtype)

        if not self.frozen:
            with torch.no_grad():
                # Update the Parameter in place: assigning a plain tensor to
                # `self.weight` would raise a TypeError in nn.Module.
                boosted = self.weight + self.__get_weights_boost(data_vector, excitations)
                self.weight.copy_(self.__normalise_unitary(boosted))
                # Excited neurons raise their threshold to the current response,
                # then every threshold decays.
                self.t = excitations * output + (1.0 - excitations) * self.t
                self.t = self.t * (1.0 - self.t_decay)

        assert output.shape == (self.out_features,)
        self.__assert_invariants()

        return output

    def forward(self, input):
        """Run a batch through the layer one sample at a time.

        Args:
            input: tensor of shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, out_features)``: response minus threshold
            where the response is >= the threshold, else 0. The same tensor is
            also stored on ``self.output``.
        """
        # FIXME:
        # As of now, samples are just processed sequentially, for simplicity.
        assert input.dim() == 2 and input.shape[1] == self.in_features, (
            f"expected input of shape (batch, {self.in_features}), got {tuple(input.shape)}"
        )

        output = torch.empty(
            (input.shape[0], self.out_features),
            dtype=torch.promote_types(input.dtype, self.weight.dtype),
            device=self.weight.device,
        )
        for i, sample in enumerate(input):
            output[i] = self.forward_single_sample(sample)

        # ReLU-like non-linear transformation via cutoff threshold
        #
        #   TODO: Should we output the absolute value or only the diff above threshold ?
        #   (note the thresholds were updated by every sample in the batch, so this
        #   cutoff uses the post-batch thresholds)
        output = torch.where(output >= self.t, output - self.t, torch.zeros_like(output))

        self.output = output
        return output

    def backward(self, grad_output):
        """Guard against gradient-based training; this layer learns only in ``forward``.

        PyTorch autograd does not call a Module's ``backward`` method, so this
        only fires when invoked explicitly.
        """
        raise Exception("Backward pass not implemented for FrontPropLinear")

    def freeze(self):
        """Stop updating weights and thresholds in ``forward``."""
        self.frozen = True

    def unfreeze(self):
        """Resume updating weights and thresholds in ``forward``."""
        self.frozen = False


In [20]:
# Reference convolution whose output shape the patch-extraction cells must match.
# Seed torch's global RNG so the Conv2d init (and later torch.rand draws) are reproducible.
SEED = 0
torch.manual_seed(SEED)

in_ch = 3
out_ch = 5
kernel_size = 2
stride = 1

conv2d = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size, stride=stride)

In [21]:
batch_size = 10
img_h = 16
img_w = 16

# Random input batch in PyTorch's NCHW layout: (batch_size, in_ch, img_h, img_w)
x = torch.rand((batch_size, in_ch, img_h, img_w))

In [32]:
x.nonzero()  # one row of (batch, channel, row, col) indices per nonzero element of x

tensor([[ 0,  0,  0,  0],
        [ 0,  0,  0,  1],
        [ 0,  0,  0,  2],
        ...,
        [ 9,  2, 15, 13],
        [ 9,  2, 15, 14],
        [ 9,  2, 15, 15]])

In [38]:
torch.nonzero(x).le(0)  # True where an index coordinate is 0 (indices are never negative)

tensor([[ True,  True,  True,  True],
        [ True,  True,  True, False],
        [ True,  True,  True, False],
        ...,
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]])

In [22]:
# weight layout: (out_ch, in_ch, kernel_size, kernel_size)
conv2d.weight.size()

torch.Size([5, 3, 2, 2])

In [23]:
# one bias term per output channel: (out_ch,)
conv2d.bias.size()

torch.Size([5])

In [24]:
# Result layout: (batch_size, out_ch, out_h, out_w). With no padding,
# out_h = (img_h - kernel_size) // stride + 1, and likewise for out_w.
out = conv2d(x)
out.size()

torch.Size([10, 5, 15, 15])

In [26]:
# Extract every kernel_size x kernel_size window the convolution sees:
# shape (batch_size, in_ch, out_h, out_w, kernel_size, kernel_size)
patches = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
print(patches.shape)

# Spatial size of the convolution output (= number of windows along each axis)
out_height, out_width = out.shape[-2:]
print(out_height, out_width)
assert patches.shape[2:4] == (out_height, out_width)

# Tile the windows side by side into an "exploded" image. Kernel rows must sit
# next to the window-row axis and kernel columns next to the window-column axis,
# so permute to (batch, in_ch, out_h, k, out_w, k) before flattening. A plain
# .view() of the unfold layout would interleave rows from different windows.
exploded_output = patches.permute(0, 1, 2, 4, 3, 5).reshape(
    batch_size, in_ch, out_height * kernel_size, out_width * kernel_size
)
# Sanity check: the top-left tile is exactly the top-left window of the input.
assert torch.equal(
    exploded_output[..., :kernel_size, :kernel_size], x[..., :kernel_size, :kernel_size]
)
exploded_output.shape

torch.Size([10, 3, 15, 15, 2, 2])
15 15


torch.Size([10, 3, 30, 30])

In [30]:
# Pretrained ResNet-18 from the torchvision v0.10.0 hub repo (weights are downloaded
# on first use), switched to inference mode; the bare name displays the architecture.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True).eval()
model

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to C:\Users\karol/.cache\torch\hub\v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\karol/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|████████████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:04<00:00, 9.98MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  