In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from tqdm import tqdm
from sklearn import datasets
import random, sys, os

In [3]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [4]:
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms

In [5]:
# device = torch.device("cuda:0")
device = torch.device("cuda:1")
# device = torch.device("cpu")

In [6]:
train_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

train_dataset = datasets.FashionMNIST(root="../../../../_Datasets/FMNIST/", train=True, download=True, transform=train_transform)
test_dataset = datasets.FashionMNIST(root="../../../../_Datasets/FMNIST/", train=False, download=True, transform=test_transform)

In [7]:
LR = 0.0001
BS = 200

In [8]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BS, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BS, shuffle=False, num_workers=2)

In [9]:
## demo of train loader
xx, yy = iter(train_loader).next()
xx.shape

torch.Size([200, 1, 28, 28])

## Cuda -bmm2x2

In [10]:
import bmm2x2_cuda

In [11]:
class BMM2x2Function(torch.autograd.Function):
    @staticmethod
#     @torch.jit.ignore
    def forward(ctx, inputs, weights):
        outputs = bmm2x2_cuda.forward(inputs, weights)
        ctx.save_for_backward(inputs, weights)
        return outputs[0]
    
    @staticmethod
#     @torch.jit.ignore
    def backward(ctx, grad_output):
        inputs, weights = ctx.saved_tensors
        del_input, del_weights = bmm2x2_cuda.backward(
            inputs, 
            weights, 
            grad_output)
    
        return del_input, del_weights

In [12]:
class PairWeight2(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.weight = torch.eye(2).unsqueeze(0).repeat_interleave(input_dim//2, dim=0)
        self.weight = nn.Parameter(self.weight)
        self.bmmfunc = BMM2x2Function()
        
    @torch.jit.ignore
    def bmm(self, x, w):
        return BMM2x2Function.apply(x, w)
        
    def forward(self, x):
        bs, dim = x.shape[0], x.shape[1]
        x = x.view(bs, -1, 2)
        x = self.bmm(x, self.weight)
        x = x.view(bs, -1)
        return x

## Cuda - Bilinear2x2

In [13]:
import bilinear2x2_cuda

In [14]:
del_input_list = []

In [15]:
class BiLinear2x2Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weights):
        outputs = bilinear2x2_cuda.forward(inputs, weights)
        ctx.save_for_backward(inputs, weights)
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        global del_input_list
        inputs, weights = ctx.saved_tensors
#         del_input, del_weights = bmm2x2_cuda.backward(
#             grad_output.contiguous(), 
#             grad_cell.contiguous(), 
#             grad_output.contiguous())
        del_input, del_weights = bilinear2x2_cuda.backward(
            inputs, 
            weights, 
            grad_output)
#         del_input_list.append(del_input)
        return del_input, del_weights

In [16]:
class PairBilinear2(nn.Module):
    def __init__(self, dim, grid_width):
        super().__init__()
        self.dim = dim
        self.grid_width = grid_width
        
        self.num_pairs = self.dim // 2
        along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1).t()
        along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1).t()
        
        self.Y = torch.stack([along_row+along_col*0, along_row*0+along_col])
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_pairs, dim=0)
        self.Y = nn.Parameter(self.Y)
        
        self.pairW = torch.eye(2).unsqueeze(0).repeat_interleave(self.num_pairs, dim=0)
        self.pairW = nn.Parameter(self.pairW)
    
#     @torch.jit.ignore
#     def pairbl2x2(self, x, w):
#         return BiLinear2x2Function.apply(x, w)
    
#     @torch.jit.ignore
    def forward(self, x):
        bs = x.shape[0]
        
############# This block ########################
        ### this block is significantly faster
    
#         x = x.view(bs, -1, 2).transpose(0,1)
#         x = torch.bmm(x, self.pairW)
#         x = x.transpose(1,0)#.reshape(-1, 2)
        
############# OR This block ########################
        x = x.contiguous().view(bs, -1, 2)
#         x = BMM2x2Function.apply(x, self.pairW)
####################################################
#         x = x.view(bs, -1, 2)
        x = BiLinear2x2Function.apply(x, self.Y)
        x = x.view(bs, -1)
        return x

In [17]:
class FactorizedPairBilinearSpline_2(nn.Module):
    
    def __init__(self, input_dim, grid_width, num_layers=None):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.input_dim = input_dim
        self.num_layers = int(np.ceil(np.log2(self.input_dim)))
        if num_layers is not None:
            self.num_layers = num_layers
            
        self.facto_nets = []
        for i in range(self.num_layers):
            net = PairBilinear2(self.input_dim, grid_width)
            self.facto_nets.append(net)
        self.facto_nets = nn.ModuleList(self.facto_nets)
            
    def forward(self, x):
        
        global del_input_list
        del_input_list = []
        
        ## swap first and then forward and reverse-swap
        bs = x.shape[0]
        y = x
#         for i in range(len(self.facto_nets)):
        for i, fn in enumerate(self.facto_nets):
            y = y.view(-1,2,2**(i)).permute(0, 2,1).contiguous().view(bs, -1)
            y = fn(y) 
            y = y.view(-1,2**(i),2).permute(0, 2,1).contiguous()
#         y = x + y ## this is residual addition... remove if only want feed forward
        return y.view(bs, -1)

In [18]:
# class PairBilinearBlock_2(FactorizedPairBilinearSpline_2):
    
#     def __init__(self, input_dim, grid_width):
#         num_layers = int(np.ceil(np.log2(input_dim)))
#         extra =  2**num_layers - input_dim
#         torch.manual_seed(123)
#         self.selector = torch.randperm(input_dim)[:extra]
        
#         super().__init__(2**num_layers, grid_width)
        
#     def forward(self, x):
#         '''
#         x should have dimension -> bs, M
#         '''
#         x = torch.cat((x, x[:, self.selector]), dim=1)
#         return super().forward(x)

## Fused 2x2 Ops

In [19]:
import fused2x2ops_cuda

In [20]:
recent_del_input = None

In [21]:
class FusedBiLinear2x2Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, weights, grids):
        outputs, input_buffer, dbg = fused2x2ops_cuda.bilinear2x2_forward(inputs, weights, grids)
#         print(dbg)

#         outputs, input_buffer = fused2x2ops_cuda.bilinear2x2_forward(inputs, weights, grids)
        ctx.save_for_backward(input_buffer, weights, grids)
        return outputs, dbg

    @staticmethod
    def backward(ctx, grad_output, a):
        global recent_del_input
        input_buffer, weights, grids = ctx.saved_tensors
#         del_input, del_weights = bmm2x2_cuda.backward(
#             grad_output.contiguous(), 
#             grad_cell.contiguous(), 
#             grad_output.contiguous())
        del_input, del_weights, del_grids, del_inputs = fused2x2ops_cuda.bilinear2x2_backward(
            input_buffer, 
            weights, 
            grids,
            grad_output)
    
#         recent_del_input = del_inputs
        return del_input, del_weights, del_grids

In [22]:
class Fused2x2BiLinear(nn.Module):
    def __init__(self, dim, grid_width, num_layers = None):
        super().__init__()
        self.dim = dim
        self.grid_width = grid_width
        
        self.num_layers = int(np.ceil(np.log2(self.dim)))
        if num_layers is not None:
            self.num_layers = num_layers
        
        
        self.num_pairs = self.dim // 2
        
#         along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1)
#         along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1)
        
        along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1).t()
        along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1).t()
        
        self.Y = torch.stack([along_row+along_col*0, along_row*0+along_col])
        
        ### repeat same for num_pairs
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_pairs, dim=0)
        ### repeat same for num_Layers
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_layers, dim=0)
        
        print(self.Y.shape)
        self.Y = nn.Parameter(self.Y)
        
        ### repeat same for num_pairs
        self.pairW = torch.eye(2).unsqueeze(0).repeat_interleave(self.num_pairs, dim=0)
        ### repeat same for num_Layers
        self.pairW = self.pairW.unsqueeze(0).repeat_interleave(self.num_layers, dim=0)
        
        print(self.pairW.shape)
        self.pairW = nn.Parameter(self.pairW)
    
    def forward(self, x):
        self.Y.data.clamp_(-10, 10)
        x, self.dbg = FusedBiLinear2x2Function.apply(x.clone(), self.pairW, self.Y)
        return x

In [23]:
N = 16
fbl = Fused2x2BiLinear(N, 3, num_layers=1).to(device)

torch.Size([1, 8, 2, 3, 3])
torch.Size([1, 8, 2, 2])


In [24]:
# pbl = FactorizedPairBilinearSpline_2(N, 3, num_layers=1).to(device)

In [25]:
a = torch.randn(1, N).to(device)
y = fbl(a)

In [26]:
fbl.dbg

RuntimeError: CUDA error: an illegal memory access was encountered

In [27]:
y.shape

torch.Size([1, 16])

In [28]:
a

RuntimeError: CUDA error: an illegal memory access was encountered

In [29]:
y

RuntimeError: CUDA error: an illegal memory access was encountered

In [30]:
y.mean().backward()

RuntimeError: CUDA error: an illegal memory access was encountered

In [None]:
fbl.Y

In [None]:
pbl = FactorizedPairBilinearSpline_2(N, 3, num_layers=1).to(device)

In [None]:
pbl.facto_nets[0].Y

In [None]:
y1 = pbl(a)

In [None]:
y1

In [None]:
a

In [None]:
y1.mean().backward()

In [None]:
fbl.Y.grad[0][1]

In [None]:
pbl.facto_nets[0].Y.grad[1]

In [None]:
# fbl.Y.grad[0] - pbl.Y.grad.transpose(-1, -2)
# fbl.Y.grad[0] - pbl.Y.grad

In [None]:
pbl.facto_nets[0].pairW.grad[1]

In [None]:
fbl.pairW.grad[0][1]

In [None]:
pbl.facto_nets[0].pairW[1]

In [31]:
fbl.pairW[0][1]

RuntimeError: CUDA error: an illegal memory access was encountered

## Optimizing both models

In [32]:
N = 16
L = 2
a = torch.randn(5, N).to(device)
t = torch.randn_like(a)

RuntimeError: CUDA error: an illegal memory access was encountered

In [45]:
def criterion(y, t):
    return ((y-t)**2).sum()

In [46]:
fbl = Fused2x2BiLinear(N, 2, num_layers=L).to(device)

torch.Size([2, 8, 2, 2, 2])
torch.Size([2, 8, 2, 2])


In [47]:
fbl_opt = torch.optim.Adam(fbl.parameters(), lr=0.01)

In [48]:
pbl = FactorizedPairBilinearSpline_2(N, 2, num_layers=L).to(device)

In [49]:
pbl_opt = torch.optim.Adam(pbl.parameters(), lr=0.01)

In [50]:
# fbl.pairW.data = torch.randn_likee(fbl.pairW.data)

In [51]:
# pbl.pairW.data = fbl.pairW.data[0].clone()

In [52]:
fbl.pairW.data

tensor([[[[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]]],


        [[[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]]]], device='cuda:0')

In [53]:
pbl.facto_nets[0].pairW.data

tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]], device='cuda:0')

In [54]:
fbl.Y.data = torch.randn_like(fbl.Y.data)

for i in range(L):
    pbl.facto_nets[i].Y.data = fbl.Y.data[i].clone()#.transpose(-1,-2).contiguous()

In [55]:
fbl.Y.data[0][0]

tensor([[[ 0.3791, -0.8440],
         [-0.1559, -0.6066]],

        [[ 1.6306,  3.1744],
         [ 0.5283, -0.7442]]], device='cuda:0')

In [56]:
pbl.facto_nets[0].Y.data[0]

tensor([[[ 0.3791, -0.8440],
         [-0.1559, -0.6066]],

        [[ 1.6306,  3.1744],
         [ 0.5283, -0.7442]]], device='cuda:0')

### Repeat from here

In [57]:
y = fbl(a)

In [58]:
fbl_opt.zero_grad()
criterion(y,t).backward()

In [59]:
y1 = pbl(a)

In [60]:
pbl_opt.zero_grad()
criterion(y1,t).backward()

In [61]:
y1, y

(tensor([[-1.5037e+00, -3.7595e+00, -1.8056e+00,  5.1659e-01, -1.0865e+00,
           1.2358e+01, -7.4260e-01,  2.4485e+01,  5.0571e+00,  2.1378e+00,
           2.5427e+00,  2.5374e+00,  6.9556e+00, -1.8317e+00, -2.6482e+00,
           1.1066e+00],
         [ 9.7731e+00,  1.6325e+00, -3.8314e+00,  9.4500e-01,  2.3571e-01,
           7.2137e-02, -4.0172e+00,  4.7148e-02, -6.4294e+00,  2.6724e+00,
           8.9784e+00,  1.4810e+00, -6.1784e+00, -1.0216e+00, -2.0686e+00,
          -2.8565e+00],
         [-1.6147e+00, -2.9150e+00,  1.1380e+00, -1.5460e+00, -9.7056e+00,
          -2.8597e+01,  2.3266e+01, -4.0729e+01,  7.2298e+00, -8.5021e-01,
          -2.1144e+00, -3.2863e+00, -2.5194e-01,  2.2901e-01,  4.0116e-01,
          -3.4700e+00],
         [-1.4724e+00, -9.0094e-01, -1.0729e+00, -6.8572e-01, -3.9127e+00,
           4.6269e+01,  1.8246e+00,  9.7919e+01,  4.7538e+00,  6.3082e-01,
          -3.1558e-01, -1.1043e+00,  4.2851e-01, -1.1215e+00,  1.3154e+00,
          -2.9661e+00],
    

In [62]:
# a

In [63]:
'''
OBSERVATIONS:
1. The kernel for copy operation works.; Now try matrix multiplication.. also works
2. 
'''

'\nOBSERVATIONS:\n1. The kernel for copy operation works.; Now try matrix multiplication.. also works\n2. \n'

In [64]:
y1.data-y.data

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda:0')

In [65]:
# fbl.Y.grad

In [66]:
# pbl.Y.grad

In [67]:
fbl_opt.step()
pbl_opt.step()

In [68]:
# fbl.Y.grad[0] - pbl.Y.grad.transpose(-1, -2)
# torch.abs(fbl.Y.grad[0] - pbl.Y.grad).abs().sum()

In [69]:
torch.abs(fbl.Y.data[0] - pbl.facto_nets[0].Y.data).abs().sum()

tensor(9.3132e-10, device='cuda:0')

In [70]:
torch.abs(fbl.Y.grad[0] - pbl.facto_nets[0].Y.grad).abs().sum()

tensor(0.0022, device='cuda:0')

In [71]:
torch.abs(fbl.pairW.data[0] - pbl.facto_nets[0].pairW.data).abs().sum()

tensor(0., device='cuda:0')

In [72]:
torch.abs(fbl.pairW.grad[0] - pbl.facto_nets[0].pairW.grad).abs().sum()

TypeError: unsupported operand type(s) for -: 'Tensor' and 'NoneType'

In [73]:
# pbl.pairW.grad[1]

In [74]:
# fbl.pairW.grad[0][1]

In [75]:
pbl.facto_nets[0].pairW

Parameter containing:
tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]], device='cuda:0', requires_grad=True)

In [76]:
fbl.pairW[0]

tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]], device='cuda:0', grad_fn=<SelectBackward0>)

In [77]:
torch.abs(fbl.pairW.data[0] - pbl.facto_nets[0].pairW.data)

tensor([[[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]], device='cuda:0')

## dissecting the difference in first gradient

In [78]:
fbl.Y.grad[0] - pbl.facto_nets[0].Y.grad

tensor([[[[ 3.0518e-05, -1.5259e-05],
          [-1.5259e-05,  1.1444e-05]],

         [[-9.5367e-07,  1.6689e-06],
          [ 4.7684e-07, -8.3447e-07]]],


        [[[ 0.0000e+00,  0.0000e+00],
          [-1.5259e-05,  3.8147e-06]],

         [[ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -1.9073e-06]]],


        [[[ 2.4414e-04, -6.1035e-05],
          [ 0.0000e+00,  1.2207e-04]],

         [[ 0.0000e+00,  2.4414e-04],
          [-9.7656e-04, -1.2207e-04]]],


        [[[-1.2207e-04,  6.1035e-05],
          [ 3.0518e-05, -7.6294e-06]],

         [[ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00]],

         [[ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00]]],


        [[[-6.1035e-05,  0.0000e+00],
          [-1.5259e-05,  3.0518e-05]],

         [[ 0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00],
          [ 0.

In [79]:
fbl.Y.grad[0][0], pbl.facto_nets[0].Y.grad[0]

(tensor([[[ 327.9010, -143.4838],
          [-111.4384,   48.9418]],
 
         [[  15.8838,    2.4484],
          [  -7.6424,   -1.4596]]], device='cuda:0'),
 tensor([[[ 327.9010, -143.4838],
          [-111.4384,   48.9418]],
 
         [[  15.8838,    2.4484],
          [  -7.6424,   -1.4596]]], device='cuda:0'))

In [80]:
fbl.Y.grad[1] - pbl.facto_nets[1].Y.grad

tensor([[[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]],


        [[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]]], device='cuda:0')

In [81]:
recent_del_input[1]#.sum()

tensor([[-3.3865e+00,  2.1475e+00, -5.0813e+01, -6.7917e+01, -7.7847e+00,
         -3.0167e+02, -9.2780e-01,  8.8429e+02, -5.5581e+00,  1.1693e+01,
          1.5529e+01,  4.6645e-01, -5.0991e+01,  7.8248e+00, -8.1608e+01,
          1.0434e+01],
        [ 1.3335e+02,  4.8941e+00, -6.6769e+01,  1.1105e+01,  2.9499e+01,
          1.1546e+00, -1.5340e+01, -2.4701e-01, -1.1079e+02, -1.2389e+01,
         -1.6451e+02, -5.8750e+00,  4.7702e+01, -1.1734e+01, -4.7073e+01,
          1.6948e+01],
        [ 2.7652e+00,  1.6842e+00,  3.8038e+01, -1.7971e+01, -7.3647e+02,
          1.5797e+03, -2.0910e+02,  1.0404e+03, -2.1347e+01, -2.2273e+01,
          1.0161e+02,  5.0339e+01,  4.0611e-01, -1.8502e+01,  4.3898e+00,
          1.2963e+01],
        [-3.4984e+00, -3.5313e-01, -6.3028e+00,  4.8524e+00, -7.9466e+00,
         -5.8966e+03, -2.5341e+00,  2.8729e+03, -1.5595e+01, -4.1621e-01,
          3.5140e+01, -1.9838e-01,  6.5884e-01, -1.2603e+01, -1.5375e+00,
          1.0425e+01],
        [-7.3082e+00

In [82]:
del_input_list[0]#.sum()

tensor([[[-3.3865e+00, -5.0813e+01],
         [ 2.1475e+00, -6.7917e+01],
         [-7.7847e+00, -9.2780e-01],
         [-3.0167e+02,  8.8429e+02],
         [-5.5581e+00,  1.5529e+01],
         [ 1.1693e+01,  4.6645e-01],
         [-5.0991e+01, -8.1608e+01],
         [ 7.8248e+00,  1.0434e+01]],

        [[ 1.3335e+02, -6.6769e+01],
         [ 4.8941e+00,  1.1105e+01],
         [ 2.9499e+01, -1.5340e+01],
         [ 1.1546e+00, -2.4701e-01],
         [-1.1079e+02, -1.6451e+02],
         [-1.2389e+01, -5.8750e+00],
         [ 4.7702e+01, -4.7073e+01],
         [-1.1734e+01,  1.6948e+01]],

        [[ 2.7652e+00,  3.8038e+01],
         [ 1.6842e+00, -1.7971e+01],
         [-7.3647e+02, -2.0910e+02],
         [ 1.5797e+03,  1.0404e+03],
         [-2.1347e+01,  1.0161e+02],
         [-2.2273e+01,  5.0339e+01],
         [ 4.0611e-01,  4.3898e+00],
         [-1.8502e+01,  1.2963e+01]],

        [[-3.4984e+00, -6.3028e+00],
         [-3.5313e-01,  4.8524e+00],
         [-7.9466e+00, -2.5341e+

In [83]:
del_input_list[0].reshape(5, -1, 2, 2).transpose(-1, -2).reshape(5, -1) - recent_del_input[1]

tensor([[ 2.3842e-07, -2.3842e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -5.9605e-08,  0.0000e+00, -4.7684e-07,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -7.6294e-06,
          0.0000e+00],
        [-1.5259e-05,  4.7684e-07,  7.6294e-06,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  4.4703e-08,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 0.0000e+00,  1.1921e-07,  0.0000e+00,  1.9073e-06,  6.1035e-05,
          0.0000e+00,  1.5259e-05,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          7.6294e-06,  0.0000e+00,  0.0000e+00, -1.9073e-06,  4.7684e-07,
          0.0000e+00],
        [ 0.0000e+00,  0.0000e+00, -4.7684e-07,  0.0000e+00,  9.5367e-07,
          0.0000e+00, -2.3842e-07,  0.0000e+00,  0.0000e+00,  2.9802e-08,
          3.8147e-06, -4.4703e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 0.0000e+00

## Model

In [84]:
class FusedPairBilinearBlock(nn.Module):
    
    def __init__(self, input_dim, grid_width):
        super().__init__()
        self.input_dim = input_dim
        self.num_layers = int(np.ceil(np.log2(input_dim)))
        
        extra =  2**self.num_layers - input_dim
        torch.manual_seed(123)
        self.selector = torch.randperm(self.input_dim)[:extra]
        
        self.fused_pair_bilinear = Fused2x2BiLinear(2**self.num_layers, grid_width=grid_width, )#num_layers=5)
        
        
    def forward(self, x):
        '''
        x should have dimension -> bs, M
        '''
        x = torch.cat((x, x[:, self.selector]), dim=1)
        return self.fused_pair_bilinear(x)

In [85]:
class BiasLayer(nn.Module):
    def __init__(self, dim, init_val=0):
        super().__init__()
        self.bias = nn.Parameter(torch.ones(dim)*init_val)
        
    def forward(self, x):
        return x+self.bias

In [86]:
class FactorNet2(nn.Module):
    def __init__(self):
        super().__init__()
        H = 512
        self.bias = nn.Linear(784, H)
        self.bn1 = nn.BatchNorm1d(H)
        self.fc = nn.Linear(H, 10)
        self.la1 = FusedPairBilinearBlock(H, grid_width=3)
#         self.la1 = PairBilinearBlock_2(H, grid_width=3)

#         self.bias = BiasLayer(784)
#         self.la1 = FusedPairBilinearBlock(784, grid_width=3)
# #         self.la1 = PairBilinearBlock(784, grid_width=3)
#         self.bn1 = nn.BatchNorm1d(1024)
#         self.fc = nn.Linear(1024, 10)
        
    def forward(self, x):
        x = self.bias(x)
        x = self.la1(x)
#         x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc(x)
        return x

In [87]:
torch.manual_seed(0)
model = FactorNet2().to(device)

model

torch.Size([9, 256, 2, 3, 3])
torch.Size([9, 256, 2, 2])


FactorNet2(
  (bias): Linear(in_features=784, out_features=512, bias=True)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=512, out_features=10, bias=True)
  (la1): FusedPairBilinearBlock(
    (fused_pair_bilinear): Fused2x2BiLinear()
  )
)

In [88]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00003)

In [89]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 20

for epoch in range(EPOCHS):
    
    train_acc = 0
    train_count = 0
    i = -1
    for xx, yy in tqdm(train_loader):
        i += 1 
        xx = xx.view(xx.shape[0], -1)
        xx, yy = xx.to(device), yy.to(device)

        yout = model(xx)
        
        if torch.any(torch.isnan(yout.data)):
            print(f"{i},Yout; NAN", flush=True)
            break
            
        loss = criterion(yout, yy)
        
        if torch.any(torch.isnan(loss.data)):
            print(f"{i},loss; NAN", flush=True)
            break
        
        optimizer.zero_grad()
        loss.backward()

        if torch.any(torch.isnan(model.la1.fused_pair_bilinear.Y.grad)):
            print(f"{i}, NAN Grid back grad; loss {float(loss)}", flush=True)
            break
        if torch.any(torch.isnan(model.la1.fused_pair_bilinear.pairW.grad)):
            print(f"{i}, NAN Weight back grad; loss {float(loss)}", flush=True)
            break

        
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)
        
        if torch.any(torch.isnan(model.la1.fused_pair_bilinear.Y)):
            print(f"{i}, NAN Grid update; loss {float(loss)}", flush=True)
            break
        if torch.any(torch.isnan(model.la1.fused_pair_bilinear.pairW)):
            print(f"{i}, NAN Weight update; loss {float(loss)}", flush=True)
            break

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx = xx.view(xx.shape[0], -1)
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████████████████████████████████████████████| 300/300 [00:02<00:00, 126.59it/s]


Epoch: 0,  Loss:0.9972912669181824


100%|████████████████████████████████████████████████████| 50/50 [00:00<00:00, 105.07it/s]


Train Acc:63.73%, Test Acc:72.11%



 89%|████████████████████████████████████████████▎     | 266/300 [00:02<00:00, 132.67it/s]

276,Yout; NAN


 92%|██████████████████████████████████████████████    | 276/300 [00:02<00:00, 124.19it/s]


Epoch: 1,  Loss:58913986707456.0


100%|████████████████████████████████████████████████████| 50/50 [00:00<00:00, 108.25it/s]


Train Acc:76.44%, Test Acc:62.26%



  0%|                                                             | 0/300 [00:00<?, ?it/s]

0, NAN Grid back grad; loss 6.467122680123337e+30


  0%|                                                             | 0/300 [00:00<?, ?it/s]


ZeroDivisionError: float division by zero

In [187]:
model.la1.fused_pair_bilinear.Y.data.max()

tensor(1.0163, device='cuda:0')

In [191]:
torch.any(torch.isnan(model.la1.fused_pair_bilinear.Y.data))

tensor(False, device='cuda:0')

In [192]:
yout = model(xx)

In [195]:
torch.any(torch.isnan(yout))

tensor(False, device='cuda:0')

In [196]:
yy

tensor([7, 1, 4, 5, 8, 3, 5, 7, 6, 4, 5, 9, 0, 0, 6, 9, 4, 1, 7, 0, 4, 1, 6, 8,
        6, 6, 5, 0, 2, 1, 8, 1, 9, 6, 5, 6, 6, 4, 0, 8, 6, 3, 4, 2, 4, 0, 7, 4,
        8, 9, 7, 9, 1, 1, 5, 7, 7, 3, 7, 1, 8, 4, 0, 2, 7, 3, 5, 4, 4, 9, 9, 8,
        4, 1, 9, 0, 7, 9, 0, 3, 6, 0, 2, 5, 0, 9, 5, 9, 0, 9, 1, 3, 4, 5, 4, 9,
        4, 3, 0, 5, 8, 2, 6, 2, 5, 4, 5, 7, 7, 4, 1, 8, 0, 8, 8, 9, 3, 1, 2, 7,
        7, 3, 1, 4, 4, 0, 6, 3, 8, 4, 5, 4, 0, 3, 2, 0, 5, 3, 4, 0, 3, 5, 5, 8,
        9, 1, 3, 0, 1, 2, 2, 2, 6, 7, 3, 0, 7, 6, 8, 7, 1, 2, 1, 2, 4, 4, 5, 9,
        0, 0, 8, 8, 9, 3, 6, 5, 3, 4, 8, 3, 9, 8, 0, 9, 3, 4, 5, 5, 1, 5, 9, 7,
        7, 7, 6, 8, 3, 5, 6, 6], device='cuda:0')

In [197]:
loss = criterion(yout, yy)

In [205]:
yout.argmax()//10

tensor(134, device='cuda:0')

In [206]:
yout[134]

tensor([ 1.1605e+29, -7.1041e+28, -1.0751e+29,  9.6640e+27, -1.3566e+29,
         1.1653e+29,  1.0879e+29,  1.7212e+29, -1.9855e+28,  1.7179e+29],
       device='cuda:0', grad_fn=<SelectBackward>)

In [218]:
yout[134]

tensor([ 1.1605e+29, -7.1041e+28, -1.0751e+29,  9.6640e+27, -1.3566e+29,
         1.1653e+29,  1.0879e+29,  1.7212e+29, -1.9855e+28,  1.7179e+29],
       device='cuda:0', grad_fn=<SelectBackward>)

In [219]:
model.la1.fused_pair_bilinear.Y.data

tensor([[[[[ 3.8739e-03, -7.1850e-03, -9.2713e-03],
           [ 5.0880e-01,  4.9659e-01,  5.0152e-01],
           [ 1.0121e+00,  1.0031e+00,  1.0011e+00]],

          [[-2.2405e-03,  5.0718e-01,  1.0115e+00],
           [-5.2883e-03,  4.9287e-01,  9.8861e-01],
           [-8.6327e-03,  4.9361e-01,  9.9927e-01]]],


         [[[ 2.1228e-03, -7.1292e-04, -8.8196e-03],
           [ 5.0610e-01,  5.0579e-01,  4.9202e-01],
           [ 1.0066e+00,  1.0050e+00,  9.9415e-01]],

          [[-1.5808e-03,  5.0555e-01,  1.0099e+00],
           [-4.6353e-03,  5.0876e-01,  1.0075e+00],
           [-4.6236e-03,  5.0164e-01,  1.0042e+00]]],


         [[[-4.0497e-03,  3.1962e-03, -1.1414e-02],
           [ 5.0873e-01,  4.9291e-01,  5.1071e-01],
           [ 1.0103e+00,  9.9000e-01,  1.0000e+00]],

          [[-1.8174e-03,  5.0718e-01,  1.0113e+00],
           [-1.7545e-03,  4.9468e-01,  9.8872e-01],
           [-9.1505e-03,  5.0697e-01,  1.0000e+00]]],


         ...,


         [[[ 5.7747e-03,  1.67