In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from tqdm import tqdm
from sklearn import datasets
import random, sys, os

In [3]:
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms

In [4]:
device = torch.device("cuda:0")
# device = torch.device("cuda:1")
# device = torch.device("cpu")

In [5]:
train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

train_dataset = datasets.FashionMNIST(root="../../../../_Datasets/FMNIST/", train=True, download=True, transform=train_transform)
test_dataset = datasets.FashionMNIST(root="../../../../_Datasets/FMNIST/", train=False, download=True, transform=test_transform)

In [6]:
LR = 0.0001
BS = 200

In [7]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BS, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BS, shuffle=False, num_workers=2)

In [8]:
## demo of train loader
xx, yy = iter(train_loader).next()
xx.shape

torch.Size([200, 1, 28, 28])

## Cuda -bmm2x2

In [9]:
import bmm2x2_cuda

In [10]:
class BMM2x2Function(torch.autograd.Function):
    @staticmethod
#     @torch.jit.ignore
    def forward(ctx, inputs, weights):
        outputs = bmm2x2_cuda.forward(inputs, weights)
        ctx.save_for_backward(inputs, weights)
        return outputs[0]
    
    @staticmethod
#     @torch.jit.ignore
    def backward(ctx, grad_output):
        inputs, weights = ctx.saved_tensors
        del_input, del_weights = bmm2x2_cuda.backward(
            inputs, 
            weights, 
            grad_output)
    
        return del_input, del_weights

In [11]:
class PairWeight2(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.weight = torch.eye(2).unsqueeze(0).repeat_interleave(input_dim//2, dim=0)
        self.weight = nn.Parameter(self.weight)
        self.bmmfunc = BMM2x2Function()
        
    @torch.jit.ignore
    def bmm(self, x, w):
        return BMM2x2Function.apply(x, w)
        
    def forward(self, x):
        bs, dim = x.shape[0], x.shape[1]
        x = x.view(bs, -1, 2)
        x = self.bmm(x, self.weight)
        x = x.view(bs, -1)
        return x

## Cuda - Bilinear2x2

In [12]:
import bilinear2x2_cuda

In [13]:
class BiLinear2x2Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weights):
        outputs = bilinear2x2_cuda.forward(inputs, weights)
        ctx.save_for_backward(inputs, weights)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        inputs, weights = ctx.saved_tensors
#         del_input, del_weights = bmm2x2_cuda.backward(
#             grad_output.contiguous(), 
#             grad_cell.contiguous(), 
#             grad_output.contiguous())
        del_input, del_weights = bilinear2x2_cuda.backward(
            inputs, 
            weights, 
            grad_output)
    
        return del_input, del_weights

In [14]:
class PairBilinear2(nn.Module):
    def __init__(self, dim, grid_width):
        super().__init__()
        self.dim = dim
        self.grid_width = grid_width
        
        self.num_pairs = self.dim // 2
        along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1).t()
        along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1).t()
        
        self.Y = torch.stack([along_row+along_col*0, along_row*0+along_col])
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_pairs, dim=0)
        self.Y = nn.Parameter(self.Y)
        
        self.pairW = torch.eye(2).unsqueeze(0).repeat_interleave(self.num_pairs, dim=0)
        self.pairW = nn.Parameter(self.pairW)
    
#     @torch.jit.ignore
#     def pairbl2x2(self, x, w):
#         return BiLinear2x2Function.apply(x, w)
    
#     @torch.jit.ignore
    def forward(self, x):
        bs = x.shape[0]
        
############# This block ########################
        ### this block is significantly faster
    
#         x = x.view(bs, -1, 2).transpose(0,1)
#         x = torch.bmm(x, self.pairW)
#         x = x.transpose(1,0)#.reshape(-1, 2)
        
############# OR This block ########################
        x = x.view(bs, -1, 2)
        x = BMM2x2Function.apply(x, self.pairW)
####################################################
#         x = x.view(bs, -1, 2)
        x = BiLinear2x2Function.apply(x, self.Y)
        x = x.view(bs, -1)
        return x

## Fused 2x2 Ops

In [15]:
import fused2x2ops_cuda

In [16]:
class FusedBiLinear2x2Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weights, grids):
        outputs, input_buffer = fused2x2ops_cuda.bilinear2x2_forward(inputs, weights, grids)
        ctx.save_for_backward(input_buffer, weights, grids)
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        input_buffer, weights, grids = ctx.saved_tensors
#         del_input, del_weights = bmm2x2_cuda.backward(
#             grad_output.contiguous(), 
#             grad_cell.contiguous(), 
#             grad_output.contiguous())
        del_input, del_weights, del_grids = fused2x2ops_cuda.bilinear2x2_backward(
            input_buffer, 
            weights, 
            grids,
            grad_output)
    
        return del_input, del_weights, del_grids

In [17]:
class Fused2x2BiLinear(nn.Module):
    def __init__(self, dim, grid_width, num_layers = None):
        super().__init__()
        self.dim = dim
        self.grid_width = grid_width
        
        self.num_layers = int(np.ceil(np.log2(self.dim)))
        if num_layers is not None:
            self.num_layers = num_layers
        
        
        self.num_pairs = self.dim // 2
        
#         along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1)
#         along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1)
        
        along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1).t()
        along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1).t()
        
        self.Y = torch.stack([along_row+along_col*0, along_row*0+along_col])
        
        ### repeat same for num_pairs
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_pairs, dim=0)
        ### repeat same for num_Layers
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_layers, dim=0)
        
        print(self.Y.shape)
        self.Y = nn.Parameter(self.Y)
        
        ### repeat same for num_pairs
        self.pairW = torch.eye(2).unsqueeze(0).repeat_interleave(self.num_pairs, dim=0)
        ### repeat same for num_Layers
        self.pairW = self.pairW.unsqueeze(0).repeat_interleave(self.num_layers, dim=0)
        
        print(self.pairW.shape)
        self.pairW = nn.Parameter(self.pairW)
    
    def forward(self, x):
        x = FusedBiLinear2x2Function.apply(x, self.pairW, self.Y)
        return x

In [18]:
fbl = Fused2x2BiLinear(16, 3, num_layers=1).to(device)

torch.Size([1, 8, 2, 3, 3])
torch.Size([1, 8, 2, 2])


In [19]:
a = torch.randn(2, 16).to(device)
y = fbl(a)

In [20]:
y.shape

torch.Size([2, 16])

In [21]:
y.requires_grad

True

In [22]:
a

tensor([[ 2.1093,  0.7201,  1.9283,  0.7402,  0.1188,  1.2941, -0.0140,  0.6903,
         -0.2115, -2.6138,  0.4018,  0.3808, -0.3816,  0.2621, -0.7745, -0.2235],
        [-1.7910,  0.0609, -1.1728,  0.5823,  0.0888, -0.0838,  0.7788,  1.9868,
          0.1321,  1.8169,  1.0044, -0.9013,  1.3042, -1.5051, -0.0267, -1.0646]],
       device='cuda:0')

In [23]:
y

tensor([[ 2.1093,  0.7201,  1.9283,  0.7402,  0.1188,  1.2941, -0.0140,  0.6903,
         -0.2115, -2.6138,  0.4018,  0.3808, -0.3816,  0.2621, -0.7745, -0.2235],
        [-1.7910,  0.0609, -1.1728,  0.5823,  0.0888, -0.0838,  0.7788,  1.9868,
          0.1321,  1.8169,  1.0044, -0.9013,  1.3042, -1.5051, -0.0267, -1.0646]],
       device='cuda:0', grad_fn=<FusedBiLinear2x2FunctionBackward>)

In [24]:
y.mean().backward()

In [25]:
fbl.Y

Parameter containing:
tensor([[[[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0.5000],
           [1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000]]],


         [[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0.5000],
           [1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000]]],


         [[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0.5000],
           [1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000]]],


         [[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0.5000],
           [1.0000, 1.0000, 1.0000]],

          [[0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000],
           [0.0000, 0.5000, 1.0000]]],


         [[[0.0000, 0.0000, 0.0000],
           [0.5000, 0.5000, 0

In [26]:
pbl = PairBilinear2(16, 3).to(device)

In [27]:
pbl.Y

Parameter containing:
tensor([[[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000, 1.0000, 1.0000]],

         [[0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000]]],


        [[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000, 1.0000, 1.0000]],

         [[0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000]]],


        [[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000, 1.0000, 1.0000]],

         [[0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000]]],


        [[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000, 1.0000, 1.0000]],

         [[0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000],
          [0.0000, 0.5000, 1.0000]]],


        [[[0.0000, 0.0000, 0.0000],
          [0.5000, 0.5000, 0.5000],
          [1.0000,

In [28]:
y1 = pbl(a)

In [29]:
y1

tensor([[ 2.1093,  0.7201,  1.9283,  0.7402,  0.1188,  1.2941, -0.0140,  0.6903,
         -0.2115, -2.6138,  0.4018,  0.3808, -0.3816,  0.2621, -0.7745, -0.2235],
        [-1.7910,  0.0609, -1.1728,  0.5823,  0.0888, -0.0838,  0.7788,  1.9868,
          0.1321,  1.8169,  1.0044, -0.9013,  1.3042, -1.5051, -0.0267, -1.0646]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [30]:
a

tensor([[ 2.1093,  0.7201,  1.9283,  0.7402,  0.1188,  1.2941, -0.0140,  0.6903,
         -0.2115, -2.6138,  0.4018,  0.3808, -0.3816,  0.2621, -0.7745, -0.2235],
        [-1.7910,  0.0609, -1.1728,  0.5823,  0.0888, -0.0838,  0.7788,  1.9868,
          0.1321,  1.8169,  1.0044, -0.9013,  1.3042, -1.5051, -0.0267, -1.0646]],
       device='cuda:0')

In [31]:
y1.mean().backward()

In [32]:
fbl.Y.grad[0][1]

tensor([[[ 0.0000,  0.0873,  0.0172],
         [ 0.0000, -0.0914, -0.0399],
         [ 0.0000,  0.0464,  0.0429]],

        [[ 0.0000,  0.0873,  0.0172],
         [ 0.0000, -0.0914, -0.0399],
         [ 0.0000,  0.0464,  0.0429]]], device='cuda:0')

In [33]:
pbl.Y.grad[1]

tensor([[[ 0.0000,  0.0873,  0.0172],
         [ 0.0000, -0.0914, -0.0399],
         [ 0.0000,  0.0464,  0.0429]],

        [[ 0.0000,  0.0873,  0.0172],
         [ 0.0000, -0.0914, -0.0399],
         [ 0.0000,  0.0464,  0.0429]]], device='cuda:0')

In [34]:
# fbl.Y.grad[0] - pbl.Y.grad.transpose(-1, -2)
fbl.Y.grad[0] - pbl.Y.grad

tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
  

In [35]:
pbl.pairW.grad[1]

tensor([[0.0236, 0.0236],
        [0.0413, 0.0413]], device='cuda:0')

In [36]:
fbl.pairW.grad[0][1]

tensor([[0.0236, 0.0236],
        [0.0413, 0.0413]], device='cuda:0')

In [37]:
pbl.pairW[1]

tensor([[1., 0.],
        [0., 1.]], device='cuda:0', grad_fn=<SelectBackward>)

In [38]:
fbl.pairW[0][1]

tensor([[1., 0.],
        [0., 1.]], device='cuda:0', grad_fn=<SelectBackward>)

In [39]:
class FusedPairBilinearBlock(nn.Module):
    
    def __init__(self, input_dim, grid_width):
        super().__init__()
        self.input_dim = input_dim
        self.num_layers = int(np.ceil(np.log2(input_dim)))
        
        extra =  2**self.num_layers - input_dim
        torch.manual_seed(123)
        self.selector = torch.randperm(self.input_dim)[:extra]
        
        self.fused_pair_bilinear = Fused2x2BiLinear(2**self.num_layers, grid_width=grid_width, )#num_layers=5)
        
        
    def forward(self, x):
        '''
        x should have dimension -> bs, M
        '''
        x = torch.cat((x, x[:, self.selector]), dim=1)
        return self.fused_pair_bilinear(x)

In [40]:
fpbl = FusedPairBilinearBlock(784, 3).to(device)

torch.Size([10, 512, 2, 3, 3])
torch.Size([10, 512, 2, 2])


In [41]:
y = fpbl(torch.randn(10, 784).to(device))
y.shape

torch.Size([10, 1024])

## Modules and Layers

In [42]:
class BiasLayer(nn.Module):
    def __init__(self, dim, init_val=0):
        super().__init__()
        self.bias = nn.Parameter(torch.ones(dim)*init_val)
        
    def forward(self, x):
        return x+self.bias

In [43]:
class FactorizedPairBilinearSpline(nn.Module):
    
    def __init__(self, input_dim, grid_width):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.input_dim = input_dim
        self.num_layers = int(np.ceil(np.log2(input_dim)))
            
        self.facto_nets = []
        self.idx_revidx = []
        for i in range(self.num_layers):
            idrid = self.get_pair(self.input_dim, i+1)
            net = PairBilinear2(self.input_dim, grid_width)
            self.facto_nets.append(net)
            self.idx_revidx.append(idrid)
        self.facto_nets = nn.ModuleList(self.facto_nets)
            
#     @torch.jit.ignore
    def get_pair(self, inp_dim, step=1):
        dim = 2**int(np.ceil(np.log2(inp_dim)))
        assert isinstance(step, int), "Step must be integer"

        blocks = (2**step)
        range_ = dim//blocks
        adder_ = torch.arange(0, range_)*blocks

        pairs_ = torch.Tensor([0, blocks//2])
        repeat_ = torch.arange(0, blocks//2).reshape(-1,1)
        block_map = (pairs_+repeat_).reshape(-1)

        reorder_for_pair = (block_map+adder_.reshape(-1,1)).reshape(-1)
        indx = reorder_for_pair.type(torch.long)
        indx = indx[indx<inp_dim]

        rev_indx = torch.argsort(indx)
        return indx, rev_indx
    
    def forward(self, x):
        ## swap first and then forward and reverse-swap
        y = x
#         for i in range(len(self.facto_nets)):
        for i, fn in enumerate(self.facto_nets):
            idx, revidx = self.idx_revidx[i]
            y = y[:, idx]
            y = fn(y) 
            y = y[:, revidx]
#         y = x + y ## this is residual addition... remove if only want feed forward
        return y

In [44]:
class PairBilinearBlock(FactorizedPairBilinearSpline):
    
    def __init__(self, input_dim, grid_width):
        num_layers = int(np.ceil(np.log2(input_dim)))
        extra =  2**num_layers - input_dim
        torch.manual_seed(123)
        self.selector = torch.randperm(input_dim)[:extra]
        
        super().__init__(2**num_layers, grid_width)
        
    def forward(self, x):
        '''
        x should have dimension -> bs, M
        '''
        x = torch.cat((x, x[:, self.selector]), dim=1)
        return super().forward(x)

In [45]:
pfL = FactorizedPairBilinearSpline(784, 10).to(device)

In [46]:
pfL(torch.randn(100, 784).to(device))

tensor([[ 0.5040,  2.9259, -0.2752,  ...,  0.8611, -1.0041,  0.2877],
        [ 0.1645, -0.7405, -0.9497,  ...,  0.4982,  0.1717,  0.2237],
        [-1.8787, -0.4780, -1.1443,  ...,  1.0046,  0.4900,  0.2338],
        ...,
        [-1.2427,  0.2971,  0.4793,  ..., -1.9474, -0.2935,  1.4595],
        [ 0.4139, -0.8466, -0.0513,  ..., -0.5742, -0.8309,  0.6456],
        [-1.3414,  0.4671, -1.7051,  ...,  0.7226,  0.0492,  0.7175]],
       device='cuda:0', grad_fn=<IndexBackward>)

In [47]:
_a = torch.randn(100, 784).to(device)

# %timeit pfL(_a)

In [48]:
pfL.facto_nets

ModuleList(
  (0): PairBilinear2()
  (1): PairBilinear2()
  (2): PairBilinear2()
  (3): PairBilinear2()
  (4): PairBilinear2()
  (5): PairBilinear2()
  (6): PairBilinear2()
  (7): PairBilinear2()
  (8): PairBilinear2()
  (9): PairBilinear2()
)

In [49]:
param_count = sum([torch.numel(p) for p in pfL.parameters()])
param_count

799680

In [50]:
class FactorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = BiasLayer(784)
        self.la1 = FactorizedPairBilinearSpline(784, grid_width=3)
        self.bn1 = nn.BatchNorm1d(784)
        self.fc = nn.Linear(784, 10)
        
    def forward(self, x):
        x = self.bias(x)
        x = self.la1(x)
#         x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc(x)
        return x

In [51]:
class FactorNet2(nn.Module):
    def __init__(self):
        super().__init__()
#         self.bias = nn.Linear(784, 512)
#         self.la1 = FusedPairBilinearBlock(512, grid_width=3)
#         self.bn1 = nn.BatchNorm1d(512)
#         self.fc = nn.Linear(512, 10)

        self.bias = BiasLayer(784)
        self.la1 = FusedPairBilinearBlock(784, grid_width=3)
#         self.la1 = PairBilinearBlock(784, grid_width=3)
        self.bn1 = nn.BatchNorm1d(1024)
        self.fc = nn.Linear(1024, 10)
        
    def forward(self, x):
        x = self.bias(x)
        x = self.la1(x)
#         x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc(x)
        return x
    
    
class FactorNet2_debug(nn.Module):
    def __init__(self):
        super().__init__()
#         self.bias = nn.Linear(784, 512)
#         self.la1 = FusedPairBilinearBlock(512, grid_width=3)
#         self.bn1 = nn.BatchNorm1d(512)
#         self.fc = nn.Linear(512, 10)

        self.bias = BiasLayer(784)
#         self.la1 = FusedPairBilinearBlock(784, grid_width=3)
        self.la1 = PairBilinearBlock(784, grid_width=3)
        self.bn1 = nn.BatchNorm1d(1024)
        self.fc = nn.Linear(1024, 10)
        
    def forward(self, x):
        x = self.bias(x)
        x = self.la1(x)
#         x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc(x)
        return x

In [52]:
class OrdinaryNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.la1 = nn.Linear(784, 784, bias=False)
        self.bn1 = nn.BatchNorm1d(784)
        self.la2 = nn.Linear(784, 10)
        
    def forward(self, x):
        x = self.bn1(self.la1(x))
        x = torch.relu(x)
        x = self.la2(x)
        return x

In [53]:
model = FactorNet2()
param_count = sum([torch.numel(p) for p in model.parameters()])
param_count

torch.Size([10, 512, 2, 3, 3])
torch.Size([10, 512, 2, 2])


125722

In [54]:
# model.la1.fused_pair_bilinear.Y.shape

In [55]:
model = OrdinaryNet()
param_count1 = sum([torch.numel(p) for p in model.parameters()])
param_count1, param_count1/param_count

(624074, 4.963920395793894)

### Model Development

In [56]:
torch.manual_seed(0)
model = FactorNet2().to(device)
# model = OrdinaryNet().to(device)

model

torch.Size([10, 512, 2, 3, 3])
torch.Size([10, 512, 2, 2])


FactorNet2(
  (bias): BiasLayer()
  (la1): FusedPairBilinearBlock(
    (fused_pair_bilinear): Fused2x2BiLinear()
  )
  (bn1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=1024, out_features=10, bias=True)
)

In [57]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)

In [58]:
torch.manual_seed(0)
model_deb = FactorNet2_debug().to(device)
model_deb

FactorNet2_debug(
  (bias): BiasLayer()
  (la1): PairBilinearBlock(
    (facto_nets): ModuleList(
      (0): PairBilinear2()
      (1): PairBilinear2()
      (2): PairBilinear2()
      (3): PairBilinear2()
      (4): PairBilinear2()
      (5): PairBilinear2()
      (6): PairBilinear2()
      (7): PairBilinear2()
      (8): PairBilinear2()
      (9): PairBilinear2()
    )
  )
  (bn1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=1024, out_features=10, bias=True)
)

In [59]:
optimizer_deb = torch.optim.Adam(model_deb.parameters(), lr=0.0003)

In [60]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  125722


In [61]:
for p in model.parameters():
    print(p.shape)

torch.Size([784])
torch.Size([10, 512, 2, 3, 3])
torch.Size([10, 512, 2, 2])
torch.Size([1024])
torch.Size([1024])
torch.Size([10, 1024])
torch.Size([10])


In [62]:
print("number of params: ", sum(p.numel() for p in model_deb.parameters()))

number of params:  125722


In [63]:
for p in model_deb.parameters():
    print(p.shape)

torch.Size([784])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([512, 2, 3, 3])
torch.Size([512, 2, 2])
torch.Size([1024])
torch.Size([1024])
torch.Size([10, 1024])
torch.Size([10])


In [64]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 20

for epoch in range(EPOCHS):
    
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx = xx.view(xx.shape[0], -1)
        xx, yy = xx.to(device), yy.to(device)

        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        yout_ = model_deb(xx)
        loss_ = criterion(yout_, yy)
        optimizer_deb.zero_grad()
        loss_.backward()
        optimizer_deb.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)
        
        ########3 Find Similarities and Differences @@@@@@@@@@
        
        with torch.no_grad():
            diff = yout - yout_
            print(f"Diff: {diff.abs().mean()}")
        
#         if torch.any(torch.isnan(model.la1.fused_pair_bilinear.Y)):
#             print(f"NAN Grid back; loss {float(loss)}", flush=True)
#         if torch.any(torch.isnan(model.la1.fused_pair_bilinear.pairW)):
#             print(f"NAN Weight back; loss {float(loss)}", flush=True)
#         sys.stdout.flush()
        

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx = xx.view(xx.shape[0], -1)
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
            yout_ = model_deb(xx)
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

  0%|          | 1/300 [00:00<01:31,  3.27it/s]

Diff: 0.0


  1%|          | 2/300 [00:00<01:15,  3.93it/s]

Diff: 1.5234649026751867e-06


  1%|          | 3/300 [00:00<01:11,  4.15it/s]

Diff: 2.619203041831497e-05


  1%|▏         | 4/300 [00:00<01:08,  4.33it/s]

Diff: 4.2597308492986485e-05


  2%|▏         | 5/300 [00:01<01:07,  4.37it/s]

Diff: 7.820215978426859e-05


  2%|▏         | 6/300 [00:01<01:06,  4.42it/s]

Diff: 0.00010962785017909482


  2%|▏         | 7/300 [00:01<01:05,  4.47it/s]

Diff: 0.00013302225852385163


  3%|▎         | 8/300 [00:01<01:05,  4.49it/s]

Diff: 0.0001575818459969014


  3%|▎         | 9/300 [00:02<01:04,  4.51it/s]

Diff: 0.000198329784325324


  3%|▎         | 10/300 [00:02<01:04,  4.50it/s]

Diff: 0.00020950505859218538


  4%|▎         | 11/300 [00:02<01:04,  4.50it/s]

Diff: 0.0002351947478018701


  4%|▍         | 12/300 [00:02<01:03,  4.54it/s]

Diff: 0.000258908374235034


  4%|▍         | 13/300 [00:02<01:02,  4.57it/s]

Diff: 0.0002811381418723613


  5%|▍         | 14/300 [00:03<01:02,  4.58it/s]

Diff: 0.0003019824798684567


  5%|▌         | 15/300 [00:03<01:02,  4.59it/s]

Diff: 0.0003342173877172172


  5%|▌         | 16/300 [00:03<01:01,  4.60it/s]

Diff: 0.0003759414830710739


  6%|▌         | 17/300 [00:03<01:01,  4.60it/s]

Diff: 0.0004014784062746912


  6%|▌         | 17/300 [00:04<01:07,  4.20it/s]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/tsuman/Program_Files/Python/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-64-6e0d593283de>", line 23, in <module>
    loss_.backward()
  File "/home/tsuman/Program_Files/Python/miniconda3/lib/python3.7/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/tsuman/Program_Files/Python/miniconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/tsuman/Program_Files/Python/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 1828, in showtraceback
    s

KeyboardInterrupt: 

In [65]:
model.la1.fused_pair_bilinear.Y.shape

torch.Size([10, 512, 2, 3, 3])

In [66]:
model.la1.fused_pair_bilinear.pairW.shape

torch.Size([10, 512, 2, 2])

In [67]:
model.la1.fused_pair_bilinear.Y

Parameter containing:
tensor([[[[[-4.8426e-03,  3.4685e-03,  0.0000e+00],
           [ 5.0352e-01,  4.9969e-01,  5.0000e-01],
           [ 1.0000e+00,  1.0000e+00,  1.0000e+00]],

          [[-4.3923e-03,  5.0313e-01,  1.0000e+00],
           [ 3.1356e-03,  4.9971e-01,  1.0000e+00],
           [ 0.0000e+00,  5.0000e-01,  1.0000e+00]]],


         [[[-2.8686e-03,  1.5614e-03,  1.5437e-03],
           [ 4.9823e-01,  5.0178e-01,  5.0149e-01],
           [ 1.0000e+00,  1.0000e+00,  1.0000e+00]],

          [[ 6.8547e-04,  4.9926e-01,  9.9846e-01],
           [ 3.2321e-03,  4.9949e-01,  9.9851e-01],
           [ 0.0000e+00,  5.0000e-01,  1.0000e+00]]],


         [[[-2.6297e-03, -8.4663e-04, -1.7677e-03],
           [ 4.9717e-01,  4.9811e-01,  4.9824e-01],
           [ 9.9889e-01,  9.9809e-01,  9.9864e-01]],

          [[-3.3490e-03,  4.9816e-01,  9.9825e-01],
           [-1.9572e-03,  4.9808e-01,  9.9826e-01],
           [-1.1171e-03,  4.9805e-01,  9.9863e-01]]],


         ...,


        

In [68]:
model.la1.fused_pair_bilinear.pairW

Parameter containing:
tensor([[[[ 1.0032e+00,  2.8594e-03],
          [ 3.1364e-03,  1.0028e+00]],

         [[ 9.9887e-01,  1.4349e-03],
          [ 1.5738e-03,  9.9914e-01]],

         [[ 9.9724e-01, -2.1724e-03],
          [-2.2809e-03,  9.9759e-01]],

         ...,

         [[ 9.9701e-01, -4.8337e-03],
          [-3.4174e-03,  1.0047e+00]],

         [[ 1.0052e+00, -3.2165e-03],
          [-4.8172e-03,  9.9543e-01]],

         [[ 1.0055e+00, -2.7365e-03],
          [ 4.5138e-03,  9.9565e-01]]],


        [[[ 1.0035e+00,  3.5056e-03],
          [ 3.1649e-03,  9.9912e-01]],

         [[ 1.0031e+00, -2.1072e-03],
          [-2.2906e-03,  9.9911e-01]],

         [[ 9.9725e-01, -3.8963e-04],
          [-1.4932e-03,  9.9862e-01]],

         ...,

         [[ 1.0037e+00,  2.3437e-03],
          [ 4.7492e-03,  1.0047e+00]],

         [[ 1.0052e+00,  2.2877e-03],
          [-4.9567e-03,  1.0055e+00]],

         [[ 9.9543e-01, -1.2543e-03],
          [-2.2327e-03,  9.9568e-01]]],


        

In [69]:
torch.any(torch.isnan(model.la1.fused_pair_bilinear.Y))

tensor(False, device='cuda:0')

In [70]:
'''
The two models with same computation differ in gradients..
Below code analyses models for the differences in the output.

'''

'\nThe two models with same computation differ in gradients..\nBelow code analyses models for the differences in the output.\n\n'

In [105]:
criterion = nn.CrossEntropyLoss()

torch.manual_seed(0)
model = FactorNet2().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.03)

torch.manual_seed(0)
model_deb = FactorNet2_debug().to(device)
optimizer_deb = torch.optim.Adam(model_deb.parameters(), lr=0.03)

print("number of params: ", sum(p.numel() for p in model.parameters()))

print("number of params: ", sum(p.numel() for p in model_deb.parameters()))

torch.Size([10, 512, 2, 3, 3])
torch.Size([10, 512, 2, 2])
number of params:  125722
number of params:  125722


In [106]:
model.la1.fused_pair_bilinear.pairW[0].shape

torch.Size([512, 2, 2])

In [107]:
model_deb.la1.facto_nets[0].pairW.shape

torch.Size([512, 2, 2])

In [108]:
model.la1.fused_pair_bilinear.Y[0][0]

tensor([[[0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.5000],
         [1.0000, 1.0000, 1.0000]],

        [[0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 1.0000]]], device='cuda:0', grad_fn=<SelectBackward>)

In [109]:
model_deb.la1.facto_nets[0].Y[0]

tensor([[[0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.5000],
         [1.0000, 1.0000, 1.0000]],

        [[0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 1.0000]]], device='cuda:0', grad_fn=<SelectBackward>)

In [115]:
for xx, yy in tqdm(train_loader):
    xx = xx.view(xx.shape[0], -1)
    xx, yy = xx.to(device), yy.to(device)

    yout = model(xx)
    loss = criterion(yout, yy)
    optimizer.zero_grad()
    loss.backward()
#     optimizer.step()

    yout_ = model_deb(xx)
    loss_ = criterion(yout_, yy)
    optimizer_deb.zero_grad()
    loss_.backward()
#     optimizer_deb.step()

    break

  0%|          | 0/300 [00:00<?, ?it/s]


In [116]:
with torch.no_grad():
    diff = yout - yout_
    print(f"Diff: {diff.abs().mean()}")

Diff: 0.06269387900829315


In [117]:
model.la1.fused_pair_bilinear.Y.grad[7][0]

tensor([[[-8.0895e-04, -3.5383e-04, -3.9412e-04],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[ 2.0205e-05,  2.1356e-03,  2.1439e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]], device='cuda:0')

In [118]:
model_deb.la1.facto_nets[7].Y.grad[0]

tensor([[[-1.9456e-03, -1.0282e-04, -1.9194e-07],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[-6.9390e-05,  1.4888e-03, -3.9113e-06],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]], device='cuda:0')

In [119]:
for i in range(10):
    diff = model.la1.fused_pair_bilinear.Y.grad[i] - model_deb.la1.facto_nets[i].Y.grad
    print(f"Y grad diff {torch.abs(diff).mean()}")
    
    diff = model.la1.fused_pair_bilinear.pairW.grad[i] - model_deb.la1.facto_nets[i].pairW.grad
    print(f"W grad diff {torch.abs(diff).mean()}")
    
    diff = model.la1.fused_pair_bilinear.Y[i] - model_deb.la1.facto_nets[i].Y
    print(f"Y diff {torch.abs(diff).mean()}")
    
    diff = model.la1.fused_pair_bilinear.pairW[i] - model_deb.la1.facto_nets[i].pairW
    print(f"W diff {torch.abs(diff).mean()}")

Y grad diff 0.001012712367810309
W grad diff 0.0014948141761124134
Y diff 0.0
W diff 0.0
Y grad diff 0.001003076322376728
W grad diff 0.00146401091478765
Y diff 0.0
W diff 0.0
Y grad diff 0.0010161908576264977
W grad diff 0.0014503607526421547
Y diff 0.0
W diff 0.0
Y grad diff 0.001030481536872685
W grad diff 0.0014273985289037228
Y diff 0.0
W diff 0.0
Y grad diff 0.0010648692259564996
W grad diff 0.0014739538310095668
Y diff 0.0
W diff 0.0
Y grad diff 0.0012608177494257689
W grad diff 0.0019115135073661804
Y diff 0.0
W diff 0.0
Y grad diff 0.001483541214838624
W grad diff 0.0022428082302212715
Y diff 0.0
W diff 0.0
Y grad diff 0.0017222966998815536
W grad diff 0.00259769125841558
Y diff 0.0
W diff 0.0
Y grad diff 0.0021347692236304283
W grad diff 0.0028894953429698944
Y diff 0.0
W diff 0.0
Y grad diff 0.002208049176260829
W grad diff 0.0030084566678851843
Y diff 0.0
W diff 0.0


In [120]:
model.la1.fused_pair_bilinear.Y[0][0]

tensor([[[0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.5000],
         [1.0000, 1.0000, 1.0000]],

        [[0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 1.0000]]], device='cuda:0', grad_fn=<SelectBackward>)

In [121]:
model_deb.la1.facto_nets[0].Y[0]

tensor([[[0.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.5000],
         [1.0000, 1.0000, 1.0000]],

        [[0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 1.0000]]], device='cuda:0', grad_fn=<SelectBackward>)