In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from tqdm import tqdm
from sklearn import datasets
import random, sys, os

In [3]:
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms

In [4]:
device = torch.device("cuda:0")
# device = torch.device("cuda:1")
# device = torch.device("cpu")

In [5]:
train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

train_dataset = datasets.FashionMNIST(root="../../../../_Datasets/FMNIST/", train=True, download=True, transform=train_transform)
test_dataset = datasets.FashionMNIST(root="../../../../_Datasets/FMNIST/", train=False, download=True, transform=test_transform)

In [6]:
LR = 0.0001
BS = 200

In [7]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BS, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BS, shuffle=False, num_workers=2)

In [8]:
## demo of train loader
xx, yy = iter(train_loader).next()
xx.shape

torch.Size([200, 1, 28, 28])

## Cuda -bmm2x2

In [9]:
import bmm2x2_cuda

In [10]:
class BMM2x2Function(torch.autograd.Function):
    @staticmethod
#     @torch.jit.ignore
    def forward(ctx, inputs, weights):
        outputs = bmm2x2_cuda.forward(inputs, weights)
        ctx.save_for_backward(inputs, weights)
        return outputs[0]
    
    @staticmethod
#     @torch.jit.ignore
    def backward(ctx, grad_output):
        inputs, weights = ctx.saved_tensors
        del_input, del_weights = bmm2x2_cuda.backward(
            inputs, 
            weights, 
            grad_output)
    
        return del_input, del_weights

In [11]:
class PairWeight2(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.weight = torch.eye(2).unsqueeze(0).repeat_interleave(input_dim//2, dim=0)
        self.weight = nn.Parameter(self.weight)
        self.bmmfunc = BMM2x2Function()
        
    @torch.jit.ignore
    def bmm(self, x, w):
        return BMM2x2Function.apply(x, w)
        
    def forward(self, x):
        bs, dim = x.shape[0], x.shape[1]
        x = x.view(bs, -1, 2)
        x = self.bmm(x, self.weight)
        x = x.view(bs, -1)
        return x

## Cuda - Bilinear2x2

In [12]:
import bilinear2x2_cuda

In [13]:
class BiLinear2x2Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weights):
        outputs = bilinear2x2_cuda.forward(inputs, weights)
        ctx.save_for_backward(inputs, weights)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        inputs, weights = ctx.saved_tensors
#         del_input, del_weights = bmm2x2_cuda.backward(
#             grad_output.contiguous(), 
#             grad_cell.contiguous(), 
#             grad_output.contiguous())
        del_input, del_weights = bilinear2x2_cuda.backward(
            inputs, 
            weights, 
            grad_output)
    
        return del_input, del_weights

In [14]:
class PairBilinear2(nn.Module):
    def __init__(self, dim, grid_width):
        super().__init__()
        self.dim = dim
        self.grid_width = grid_width
        
        self.num_pairs = self.dim // 2
        along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1).t()
        along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1).t()
        
        self.Y = torch.stack([along_row+along_col*0, along_row*0+along_col])
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_pairs, dim=0)
        self.Y = nn.Parameter(self.Y)
        
        self.pairW = torch.eye(2).unsqueeze(0).repeat_interleave(self.num_pairs, dim=0)
        self.pairW = nn.Parameter(self.pairW)
    
#     @torch.jit.ignore
#     def pairbl2x2(self, x, w):
#         return BiLinear2x2Function.apply(x, w)
    
#     @torch.jit.ignore
    def forward(self, x):
        bs = x.shape[0]
        
############# This block ########################
        ### this block is significantly faster
    
#         x = x.view(bs, -1, 2).transpose(0,1)
#         x = torch.bmm(x, self.pairW)
#         x = x.transpose(1,0)#.reshape(-1, 2)
        
############# OR This block ########################
        x = x.view(bs, -1, 2)
#         x = BMM2x2Function.apply(x, self.pairW)
####################################################
#         x = x.view(bs, -1, 2)
        x = BiLinear2x2Function.apply(x, self.Y)
        x = x.view(bs, -1)
        return x

## Fused 2x2 Ops

In [15]:
import fused2x2ops_cuda

In [16]:
class FusedBiLinear2x2Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weights, grids):
        outputs, input_buffer, dbg = fused2x2ops_cuda.bilinear2x2_forward(inputs, weights, grids)
#         print(dbg)

#         outputs, input_buffer = fused2x2ops_cuda.bilinear2x2_forward(inputs, weights, grids)
        ctx.save_for_backward(input_buffer, weights, grids)
        return outputs, dbg

    @staticmethod
    def backward(ctx, grad_output, a):
        input_buffer, weights, grids = ctx.saved_tensors
#         del_input, del_weights = bmm2x2_cuda.backward(
#             grad_output.contiguous(), 
#             grad_cell.contiguous(), 
#             grad_output.contiguous())
        del_input, del_weights, del_grids = fused2x2ops_cuda.bilinear2x2_backward(
            input_buffer, 
            weights, 
            grids,
            grad_output)
    
        return del_input, del_weights, del_grids

In [17]:
class Fused2x2BiLinear(nn.Module):
    def __init__(self, dim, grid_width, num_layers = None):
        super().__init__()
        self.dim = dim
        self.grid_width = grid_width
        
        self.num_layers = int(np.ceil(np.log2(self.dim)))
        if num_layers is not None:
            self.num_layers = num_layers
        
        
        self.num_pairs = self.dim // 2
        
#         along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1)
#         along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1)
        
        along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1).t()
        along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1).t()
        
        self.Y = torch.stack([along_row+along_col*0, along_row*0+along_col])
        
        ### repeat same for num_pairs
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_pairs, dim=0)
        ### repeat same for num_Layers
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_layers, dim=0)
        
        print(self.Y.shape)
        self.Y = nn.Parameter(self.Y)
        
        ### repeat same for num_pairs
        self.pairW = torch.eye(2).unsqueeze(0).repeat_interleave(self.num_pairs, dim=0)
        ### repeat same for num_Layers
        self.pairW = self.pairW.unsqueeze(0).repeat_interleave(self.num_layers, dim=0)
        
        print(self.pairW.shape)
        self.pairW = nn.Parameter(self.pairW)
    
    def forward(self, x):
        x, self.dbg = FusedBiLinear2x2Function.apply(x.clone(), self.pairW, self.Y)
        return x

In [18]:
N = 16
fbl = Fused2x2BiLinear(N, 3, num_layers=None).to(device)

torch.Size([4, 8, 2, 3, 3])
torch.Size([4, 8, 2, 2])


In [19]:
a = torch.randn(1, N).to(device)
y = fbl(a)

In [20]:
fbl.dbg

tensor([[[[ 0.,  0.,  1.],
          [ 1.,  2.,  3.],
          [ 2.,  4.,  5.],
          [ 3.,  6.,  7.],
          [ 4.,  8.,  9.],
          [ 5., 10., 11.],
          [ 6., 12., 13.],
          [ 7., 14., 15.]]],


        [[[ 0.,  0.,  2.],
          [ 1.,  1.,  3.],
          [ 2.,  4.,  6.],
          [ 3.,  5.,  7.],
          [ 4.,  8., 10.],
          [ 5.,  9., 11.],
          [ 6., 12., 14.],
          [ 7., 13., 15.]]],


        [[[ 0.,  0.,  4.],
          [ 1.,  1.,  5.],
          [ 2.,  2.,  6.],
          [ 3.,  3.,  7.],
          [ 4.,  8., 12.],
          [ 5.,  9., 13.],
          [ 6., 10., 14.],
          [ 7., 11., 15.]]],


        [[[ 0.,  0.,  8.],
          [ 1.,  1.,  9.],
          [ 2.,  2., 10.],
          [ 3.,  3., 11.],
          [ 4.,  4., 12.],
          [ 5.,  5., 13.],
          [ 6.,  6., 14.],
          [ 7.,  7., 15.]]]], device='cuda:0',
       grad_fn=<FusedBiLinear2x2FunctionBackward>)

In [21]:
y.shape

torch.Size([1, 16])

In [22]:
y.requires_grad

True

In [23]:
a

tensor([[ 0.0081, -0.7993,  0.0350,  1.5818,  0.6688, -0.9515,  1.2559,  0.4520,
         -2.5717,  0.0262, -1.1490, -0.1717, -0.7478, -1.1236, -1.2404, -0.7634]],
       device='cuda:0')

In [24]:
y

tensor([[ 0.0081, -0.7993,  0.0350,  1.5818,  0.6688, -0.9515,  1.2559,  0.4520,
         -2.5717,  0.0262, -1.1490, -0.1717, -0.7478, -1.1236, -1.2404, -0.7634]],
       device='cuda:0', grad_fn=<FusedBiLinear2x2FunctionBackward>)

In [25]:
y.mean().backward()

In [26]:
fbl.Y

Parameter containing:
tensor([[[[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0.5000],
           [1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000]]],


         [[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0.5000],
           [1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000]]],


         [[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0.5000],
           [1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000]]],


         [[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0.5000],
           [1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000]]],


         [[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0

In [27]:
pbl = PairBilinear2(16, 3).to(device)

In [28]:
pbl.Y

Parameter containing:
tensor([[[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000, 1.0000, 1.0000]],

         [[0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000]]],


        [[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000, 1.0000, 1.0000]],

         [[0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000]]],


        [[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000, 1.0000, 1.0000]],

         [[0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000]]],


        [[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000, 1.0000, 1.0000]],

         [[0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000]]],


        [[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000,

In [29]:
y1 = pbl(a)

In [30]:
y1

tensor([[ 0.0081, -0.7993,  0.0350,  1.5818,  0.6688, -0.9515,  1.2559,  0.4520,
         -2.5717,  0.0262, -1.1490, -0.1717, -0.7478, -1.1236, -1.2404, -0.7634]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [31]:
a

tensor([[ 0.0081, -0.7993,  0.0350,  1.5818,  0.6688, -0.9515,  1.2559,  0.4520,
         -2.5717,  0.0262, -1.1490, -0.1717, -0.7478, -1.1236, -1.2404, -0.7634]],
       device='cuda:0')

In [32]:
y1.mean().backward()

In [33]:
fbl.Y.grad[0][1]

tensor([[[ 0.0000, -0.0676,  0.1257],
         [ 0.0000, -0.0051,  0.0095],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.0000, -0.0676,  0.1257],
         [ 0.0000, -0.0051,  0.0095],
         [ 0.0000,  0.0000,  0.0000]]], device='cuda:0')

In [34]:
pbl.Y.grad[1]

tensor([[[ 0.0000, -0.0676,  0.1257],
         [ 0.0000, -0.0051,  0.0095],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.0000, -0.0676,  0.1257],
         [ 0.0000, -0.0051,  0.0095],
         [ 0.0000,  0.0000,  0.0000]]], device='cuda:0')

In [35]:
# fbl.Y.grad[0] - pbl.Y.grad.transpose(-1, -2)
fbl.Y.grad[0] - pbl.Y.grad

tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
  

In [36]:
pbl.pairW.grad[1]

TypeError: 'NoneType' object is not subscriptable

In [37]:
fbl.pairW.grad[0][1]

tensor([[0., 0.],
        [0., 0.]], device='cuda:0')

In [38]:
pbl.pairW[1]

tensor([[1., 0.],
        [0., 1.]], device='cuda:0', grad_fn=<SelectBackward>)

In [39]:
fbl.pairW[0][1]

tensor([[1., 0.],
        [0., 1.]], device='cuda:0', grad_fn=<SelectBackward>)

In [40]:
class FusedPairBilinearBlock(nn.Module):
    
    def __init__(self, input_dim, grid_width):
        super().__init__()
        self.input_dim = input_dim
        self.num_layers = int(np.ceil(np.log2(input_dim)))
        
        extra =  2**self.num_layers - input_dim
        torch.manual_seed(123)
        self.selector = torch.randperm(self.input_dim)[:extra]
        
        self.fused_pair_bilinear = Fused2x2BiLinear(2**self.num_layers, grid_width=grid_width, )#num_layers=5)
        
        
    def forward(self, x):
        '''
        x should have dimension -> bs, M
        '''
        x = torch.cat((x, x[:, self.selector]), dim=1)
        return self.fused_pair_bilinear(x)

In [41]:
fpbl = FusedPairBilinearBlock(784, 3).to(device)

torch.Size([10, 512, 2, 3, 3])
torch.Size([10, 512, 2, 2])


In [42]:
y = fpbl(torch.randn(10, 784).to(device))
y.shape

torch.Size([10, 1024])

## Modules and Layers

In [43]:
class BiasLayer(nn.Module):
    def __init__(self, dim, init_val=0):
        super().__init__()
        self.bias = nn.Parameter(torch.ones(dim)*init_val)
        
    def forward(self, x):
        return x+self.bias

In [44]:
class FactorizedPairBilinearSpline(nn.Module):
    
    def __init__(self, input_dim, grid_width):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.input_dim = input_dim
        self.num_layers = int(np.ceil(np.log2(input_dim)))
            
        self.facto_nets = []
        self.idx_revidx = []
        for i in range(self.num_layers):
            idrid = self.get_pair(self.input_dim, i+1)
            net = PairBilinear2(self.input_dim, grid_width)
            self.facto_nets.append(net)
            self.idx_revidx.append(idrid)
        self.facto_nets = nn.ModuleList(self.facto_nets)
            
#     @torch.jit.ignore
    def get_pair(self, inp_dim, step=1):
        dim = 2**int(np.ceil(np.log2(inp_dim)))
        assert isinstance(step, int), "Step must be integer"

        blocks = (2**step)
        range_ = dim//blocks
        adder_ = torch.arange(0, range_)*blocks

        pairs_ = torch.Tensor([0, blocks//2])
        repeat_ = torch.arange(0, blocks//2).reshape(-1,1)
        block_map = (pairs_+repeat_).reshape(-1)

        reorder_for_pair = (block_map+adder_.reshape(-1,1)).reshape(-1)
        indx = reorder_for_pair.type(torch.long)
        indx = indx[indx<inp_dim]

        rev_indx = torch.argsort(indx)
        return indx, rev_indx
    
    def forward(self, x):
        ## swap first and then forward and reverse-swap
        y = x
#         for i in range(len(self.facto_nets)):
        for i, fn in enumerate(self.facto_nets):
            idx, revidx = self.idx_revidx[i]
            y = y[:, idx]
            y = fn(y) 
            y = y[:, revidx]
#         y = x + y ## this is residual addition... remove if only want feed forward
        return y

In [45]:
class PairBilinearBlock(FactorizedPairBilinearSpline):
    
    def __init__(self, input_dim, grid_width):
        num_layers = int(np.ceil(np.log2(input_dim)))
        extra =  2**num_layers - input_dim
        torch.manual_seed(123)
        self.selector = torch.randperm(input_dim)[:extra]
        
        super().__init__(2**num_layers, grid_width)
        
    def forward(self, x):
        '''
        x should have dimension -> bs, M
        '''
        x = torch.cat((x, x[:, self.selector]), dim=1)
        return super().forward(x)

In [46]:
class FactorizedPairBilinearSpline_2(nn.Module):
    
    def __init__(self, input_dim, grid_width):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.input_dim = input_dim
        self.num_layers = int(np.ceil(np.log2(input_dim)))
            
        self.facto_nets = []
        for i in range(self.num_layers):
            net = PairBilinear2(self.input_dim, grid_width)
            self.facto_nets.append(net)
        self.facto_nets = nn.ModuleList(self.facto_nets)
            
    def forward(self, x):
        ## swap first and then forward and reverse-swap
        bs = x.shape[0]
        y = x
#         for i in range(len(self.facto_nets)):
        for i, fn in enumerate(self.facto_nets):
#             y = y.view(-1,2,2**(i+1)).permute(0, 2,1).contiguous().view(bs, -1)
            y = y.view(-1,2,2**i).permute(0, 2,1).contiguous().view(bs, -1)
            y = fn(y) 
#             y = y.view(-1,2**(i+1),2).permute(0, 2,1).contiguous()
            y = y.view(-1,2**i,2).permute(0, 2,1).contiguous()
#         y = x + y ## this is residual addition... remove if only want feed forward
        return y.view(bs, -1)

In [47]:
class PairBilinearBlock_2(FactorizedPairBilinearSpline_2):
    
    def __init__(self, input_dim, grid_width):
        num_layers = int(np.ceil(np.log2(input_dim)))
        extra =  2**num_layers - input_dim
        torch.manual_seed(123)
        self.selector = torch.randperm(input_dim)[:extra]
        
        super().__init__(2**num_layers, grid_width)
        
    def forward(self, x):
        '''
        x should have dimension -> bs, M
        '''
        x = torch.cat((x, x[:, self.selector]), dim=1)
        return super().forward(x)

In [48]:
pfL = FactorizedPairBilinearSpline(784, 10).to(device)

In [49]:
pfL(torch.randn(100, 784).to(device))

tensor([[ 0.5040,  2.9259, -0.2752,  ...,  0.8611, -1.0041,  0.2877],
        [ 0.1645, -0.7405, -0.9497,  ...,  0.4982,  0.1717,  0.2237],
        [-1.8787, -0.4780, -1.1443,  ...,  1.0046,  0.4900,  0.2338],
        ...,
        [-1.2427,  0.2971,  0.4793,  ..., -1.9474, -0.2935,  1.4595],
        [ 0.4139, -0.8466, -0.0513,  ..., -0.5742, -0.8309,  0.6456],
        [-1.3414,  0.4671, -1.7051,  ...,  0.7226,  0.0492,  0.7175]],
       device='cuda:0', grad_fn=<IndexBackward>)

In [50]:
_a = torch.randn(100, 784).to(device)

# %timeit pfL(_a)

In [51]:
pfL.facto_nets

ModuleList(
  (0): PairBilinear2()
  (1): PairBilinear2()
  (2): PairBilinear2()
  (3): PairBilinear2()
  (4): PairBilinear2()
  (5): PairBilinear2()
  (6): PairBilinear2()
  (7): PairBilinear2()
  (8): PairBilinear2()
  (9): PairBilinear2()
)

In [52]:
param_count = sum([torch.numel(p) for p in pfL.parameters()])
param_count

799680

In [53]:
class FactorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = BiasLayer(784)
        self.la1 = FactorizedPairBilinearSpline(784, grid_width=3)
        self.bn1 = nn.BatchNorm1d(784)
        self.fc = nn.Linear(784, 10)
        
    def forward(self, x):
        x = self.bias(x)
        x = self.la1(x)
#         x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc(x)
        return x

In [74]:
class FactorNet2(nn.Module):
    def __init__(self):
        super().__init__()
        H = 16
        self.bias = nn.Linear(784, H)
        self.bn1 = nn.BatchNorm1d(H)
        self.fc = nn.Linear(H, 10)
        self.la1 = FusedPairBilinearBlock(H, grid_width=3)
#         self.la1 = PairBilinearBlock_2(H, grid_width=3)

#         self.bias = BiasLayer(784)
#         self.la1 = FusedPairBilinearBlock(784, grid_width=3)
# #         self.la1 = PairBilinearBlock(784, grid_width=3)
#         self.bn1 = nn.BatchNorm1d(1024)
#         self.fc = nn.Linear(1024, 10)
        
    def forward(self, x):
        x = self.bias(x)
        x = self.la1(x)
#         x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc(x)
        return x
    
    
class FactorNet2_debug(nn.Module):
    def __init__(self):
        super().__init__()
        H = 16
        self.bias = nn.Linear(784, H)
        self.bn1 = nn.BatchNorm1d(H)
        self.fc = nn.Linear(H, 10)
        self.la1 = PairBilinearBlock_2(H, grid_width=3)

#         self.bias = BiasLayer(784)
# #         self.la1 = FusedPairBilinearBlock(784, grid_width=3)
#         self.la1 = PairBilinearBlock_2(784, grid_width=3)
#         self.bn1 = nn.BatchNorm1d(1024)
#         self.fc = nn.Linear(1024, 10)
        
    def forward(self, x):
        x = self.bias(x)
        x = self.la1(x)
#         x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc(x)
        return x

In [75]:
class OrdinaryNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.la1 = nn.Linear(784, 784, bias=False)
        self.bn1 = nn.BatchNorm1d(784)
        self.la2 = nn.Linear(784, 10)
        
    def forward(self, x):
        x = self.bn1(self.la1(x))
        x = torch.relu(x)
        x = self.la2(x)
        return x

In [76]:
model = FactorNet2()
param_count = sum([torch.numel(p) for p in model.parameters()])
param_count

torch.Size([4, 8, 2, 3, 3])
torch.Size([4, 8, 2, 2])


13466

In [77]:
# model.la1.fused_pair_bilinear.Y.shape

In [78]:
model = OrdinaryNet()
param_count1 = sum([torch.numel(p) for p in model.parameters()])
param_count1, param_count1/param_count

(624074, 46.34442299123719)

### Model Development

In [79]:
torch.manual_seed(0)
model = FactorNet2().to(device)
# model = OrdinaryNet().to(device)

model

torch.Size([4, 8, 2, 3, 3])
torch.Size([4, 8, 2, 2])


FactorNet2(
  (bias): Linear(in_features=784, out_features=16, bias=True)
  (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=16, out_features=10, bias=True)
  (la1): FusedPairBilinearBlock(
    (fused_pair_bilinear): Fused2x2BiLinear()
  )
)

In [80]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)

In [81]:
torch.manual_seed(0)
model_deb = FactorNet2_debug().to(device)
model_deb

FactorNet2_debug(
  (bias): Linear(in_features=784, out_features=16, bias=True)
  (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=16, out_features=10, bias=True)
  (la1): PairBilinearBlock_2(
    (facto_nets): ModuleList(
      (0): PairBilinear2()
      (1): PairBilinear2()
      (2): PairBilinear2()
      (3): PairBilinear2()
    )
  )
)

In [82]:
optimizer_deb = torch.optim.Adam(model_deb.parameters(), lr=0.0003)

In [83]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  13466


In [84]:
for p in model.parameters():
    print(p.shape)

torch.Size([16, 784])
torch.Size([16])
torch.Size([16])
torch.Size([16])
torch.Size([10, 16])
torch.Size([10])
torch.Size([4, 8, 2, 3, 3])
torch.Size([4, 8, 2, 2])


In [85]:
print("number of params: ", sum(p.numel() for p in model_deb.parameters()))

number of params:  13466


In [86]:
for p in model_deb.parameters():
    print(p.shape)

torch.Size([16, 784])
torch.Size([16])
torch.Size([16])
torch.Size([16])
torch.Size([10, 16])
torch.Size([10])
torch.Size([8, 2, 3, 3])
torch.Size([8, 2, 2])
torch.Size([8, 2, 3, 3])
torch.Size([8, 2, 2])
torch.Size([8, 2, 3, 3])
torch.Size([8, 2, 2])
torch.Size([8, 2, 3, 3])
torch.Size([8, 2, 2])


In [87]:
(model.bias.weight.data - model_deb.bias.weight.data).abs().sum()

tensor(0., device='cuda:0')

In [88]:
(model.bias.bias.data - model_deb.bias.bias.data).abs().sum()

tensor(0., device='cuda:0')

In [89]:
(model.bn1.weight.data - model_deb.bn1.weight.data).abs().sum()

tensor(0., device='cuda:0')

In [90]:
(model.bn1.bias.data - model_deb.bn1.bias.data).abs().sum()

tensor(0., device='cuda:0')

In [91]:
(model.fc.weight.data - model_deb.fc.weight.data).abs().sum()

tensor(0., device='cuda:0')

In [92]:
(model.fc.bias.data - model_deb.fc.bias.data).abs().sum()

tensor(0., device='cuda:0')

In [93]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 20

for epoch in range(EPOCHS):
    
    train_acc = 0
    train_count = 0
    i = -1
    for xx, yy in tqdm(train_loader):
        i += 1 
        xx = xx.view(xx.shape[0], -1)
        xx, yy = xx.to(device), yy.to(device)

        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        yout_ = model_deb(xx)
        loss_ = criterion(yout_, yy)
        optimizer_deb.zero_grad()
        loss_.backward()
        optimizer_deb.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)
        
        ########3 Find Similarities and Differences @@@@@@@@@@
        
        with torch.no_grad():
            diff = yout - yout_
            print(f"{i}: Diff: {diff.abs().mean()}")
        
        if torch.any(torch.isnan(model.la1.fused_pair_bilinear.Y)):
            print(f"{i}, NAN Grid back; loss {float(loss)}", flush=True)
        if torch.any(torch.isnan(model.la1.fused_pair_bilinear.pairW)):
            print(f"{i}, NAN Weight back; loss {float(loss)}", flush=True)
#         sys.stdout.flush()
        
#         break
#     break

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
#     i = 1
    for xx, yy in tqdm(test_loader):
#         i = -1
        xx = xx.view(xx.shape[0], -1)
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
            yout_ = model_deb(xx)
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

  3%|▎         | 8/300 [00:00<00:06, 42.20it/s]

0: Diff: 0.0
1: Diff: 0.0
2: Diff: 1.2451782938072142e-09
3: Diff: 2.392567699516235e-09
4: Diff: 4.7832728888863585e-09
5: Diff: 5.233101951773733e-09
6: Diff: 6.3853806686609005e-09
7: Diff: 6.981892841650961e-09
8: Diff: 8.173286936141722e-09
9: Diff: 8.870848056119485e-09
10: Diff: 8.894130765213504e-09
11: Diff: 1.0334188615956919e-08
12: Diff: 1.1385418829945593e-08


  7%|▋         | 22/300 [00:00<00:04, 55.67it/s]

13: Diff: 1.2611039323928708e-08
14: Diff: 1.292815454689844e-08
15: Diff: 1.3666228149133985e-08
16: Diff: 1.4203600962048313e-08
17: Diff: 1.7596642365447224e-08
18: Diff: 1.652189673961857e-08
19: Diff: 1.871678989573411e-08
20: Diff: 2.0723790683518928e-08
21: Diff: 2.0030887171174072e-08
22: Diff: 2.1864195787202334e-08
23: Diff: 2.4599025039151456e-08
24: Diff: 2.319878156242794e-08
25: Diff: 2.590334169383368e-08


 13%|█▎        | 38/300 [00:00<00:04, 64.29it/s]

26: Diff: 2.742931393129311e-08
27: Diff: 2.8549926867071918e-08
28: Diff: 2.7035364169591958e-08
29: Diff: 3.270851678394138e-08
30: Diff: 3.241282087174113e-08
31: Diff: 3.1503383013387065e-08
32: Diff: 3.513880386663004e-08
33: Diff: 3.33590435275255e-08
34: Diff: 3.792089486864825e-08
35: Diff: 4.3436422458853485e-08
36: Diff: 4.1459223609763285e-08
37: Diff: 4.6123751928917045e-08
38: Diff: 4.665297836936588e-08
39: Diff: 5.143881054436861e-08


 17%|█▋        | 52/300 [00:00<00:03, 65.34it/s]

40: Diff: 5.1418322044582965e-08
41: Diff: 5.696528049270455e-08
42: Diff: 6.182119705044897e-08
43: Diff: 6.904267024765431e-08
44: Diff: 6.956420861570223e-08
45: Diff: 7.118936906636009e-08
46: Diff: 8.006813345673436e-08
47: Diff: 8.258177075504136e-08
48: Diff: 1.0021916097002759e-07
49: Diff: 9.863451566616277e-08
50: Diff: 1.0347552859002462e-07
51: Diff: 1.0620477297607067e-07
52: Diff: 1.3457332670441247e-07
53: Diff: 1.4613476650993107e-07


 22%|██▏       | 66/300 [00:01<00:03, 63.69it/s]

54: Diff: 1.6478171005473996e-07
55: Diff: 1.8965826598105195e-07
56: Diff: 2.2539497024354205e-07
57: Diff: 2.688025233510416e-07
58: Diff: 2.7721515039047517e-07
59: Diff: 3.163497979130625e-07
60: Diff: 3.340910268434527e-07
61: Diff: 3.3704100133036263e-07
62: Diff: 3.367140948284941e-07
63: Diff: 3.7695701848861063e-07
64: Diff: 3.8332353824444e-07
65: Diff: 4.0258188960251573e-07
66: Diff: 4.623546487891872e-07


 27%|██▋       | 80/300 [00:01<00:03, 63.06it/s]

67: Diff: 4.460611648937629e-07
68: Diff: 4.431917659530882e-07
69: Diff: 5.708146204597142e-07
70: Diff: 5.321242042555241e-07
71: Diff: 6.219912620508694e-07
72: Diff: 6.085001018618641e-07
73: Diff: 6.348551551127457e-07
74: Diff: 6.33542924788344e-07
75: Diff: 7.215524533421558e-07
76: Diff: 7.80806885813945e-07
77: Diff: 6.846651103842305e-07
78: Diff: 6.740010007888486e-07
79: Diff: 7.414836886709963e-07
80: Diff: 8.448935773230914e-07


 31%|███▏      | 94/300 [00:01<00:03, 62.42it/s]

81: Diff: 7.063681550789624e-07
82: Diff: 9.134775496022485e-07
83: Diff: 9.107887990467134e-07
84: Diff: 7.783780233694415e-07
85: Diff: 8.459687705908436e-07
86: Diff: 1.0565156571828993e-06
87: Diff: 9.718090723254136e-07
88: Diff: 9.842761983236414e-07
89: Diff: 1.1517340681166388e-06
90: Diff: 1.1747778216886218e-06
91: Diff: 1.241836798726581e-06
92: Diff: 1.4881474044159404e-06
93: Diff: 1.3483172551786993e-06


 34%|███▎      | 101/300 [00:01<00:03, 61.36it/s]

94: Diff: 1.5437240108440164e-06
95: Diff: 1.4906592014085618e-06
96: Diff: 1.5745815744594438e-06
97: Diff: 1.2384556384859025e-06
98: Diff: 1.3709920949622756e-06
99: Diff: 1.2577017969306326e-06
100: Diff: 1.2021577049381449e-06
101: Diff: 1.3562711274062167e-06
102: Diff: 1.1947472557949368e-06
103: Diff: 1.323953370047093e-06
104: Diff: 1.4028270243215957e-06
105: Diff: 1.1558645383047406e-06
106: Diff: 1.3995469316796516e-06


 38%|███▊      | 113/300 [00:01<00:03, 58.33it/s]

107: Diff: 1.5992840189937851e-06
108: Diff: 1.751044351294695e-06
109: Diff: 1.472750795983302e-06
110: Diff: 1.6953973727140692e-06
111: Diff: 1.6320394706781371e-06
112: Diff: 1.7367387954436708e-06





KeyboardInterrupt: 

In [None]:
model.la1.fused_pair_bilinear.Y.shape

In [None]:
model.la1.fused_pair_bilinear.pairW.shape

In [None]:
model.la1.fused_pair_bilinear.Y

In [None]:
model.la1.fused_pair_bilinear.pairW

In [None]:
torch.any(torch.isnan(model.la1.fused_pair_bilinear.Y))

In [None]:
'''
The two models with same computation differ in gradients..
Below code analyses models for the differences in the output.

'''

In [None]:
criterion = nn.CrossEntropyLoss()

torch.manual_seed(0)
model = FactorNet2().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)

torch.manual_seed(0)
model_deb = FactorNet2_debug().to(device)
optimizer_deb = torch.optim.SGD(model_deb.parameters(), lr=0.03)

print("number of params: ", sum(p.numel() for p in model.parameters()))

print("number of params: ", sum(p.numel() for p in model_deb.parameters()))

In [None]:
model.la1.fused_pair_bilinear.pairW[0].shape

In [None]:
model_deb.la1.facto_nets[0].pairW.shape

In [None]:
model.la1.fused_pair_bilinear.Y[0][0]

In [None]:
model_deb.la1.facto_nets[0].Y[0]

In [None]:
for xx, yy in tqdm(train_loader):
    xx = xx.view(xx.shape[0], -1)
    xx, yy = xx.to(device), yy.to(device)

    yout = model(xx)
    loss = criterion(yout.clone(), yy.clone())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    yout_ = model_deb(xx)
    loss_ = criterion(yout_.clone(), yy.clone())
    optimizer_deb.zero_grad()
    loss_.backward()
    optimizer_deb.step()

    break

In [None]:
with torch.no_grad():
    diff = yout - yout_
    print(f"Diff: {diff.abs().mean()}")

In [None]:
model.la1.fused_pair_bilinear.Y.grad[-1][0]

In [None]:
model_deb.la1.facto_nets[-1].Y.grad[0]

In [None]:
for i in range(3):
    diff = model.la1.fused_pair_bilinear.Y.grad[i] - model_deb.la1.facto_nets[i].Y.grad
    print(f"Y grad diff {torch.abs(diff).sum()}")
    
    diff = model.la1.fused_pair_bilinear.pairW.grad[i] - model_deb.la1.facto_nets[i].pairW.grad
    print(f"W grad diff {torch.abs(diff).sum()}")
    
    diff = model.la1.fused_pair_bilinear.Y[i] - model_deb.la1.facto_nets[i].Y
    print(f"Y diff {torch.abs(diff).mean()}")
    
    diff = model.la1.fused_pair_bilinear.pairW[i] - model_deb.la1.facto_nets[i].pairW
    print(f"W diff {torch.abs(diff).mean()}")

In [None]:
model.la1.fused_pair_bilinear.Y[0][0]

In [None]:
model_deb.la1.facto_nets[0].Y[0]

In [None]:
model.bias.bias.grad

In [None]:
model_deb.bias.bias.grad