In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from tqdm import tqdm
from sklearn import datasets
import random

In [3]:
import mylibrary.datasets as datasets

In [4]:
# Prefer the first CUDA GPU, but fall back to the CPU when no GPU is visible so
# that the pure-PyTorch cells still run. (Use "cuda:1" here to pick a second GPU.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Load the Fashion-MNIST splits through the project-local `datasets` helper.
# One-time setup calls, kept for reference:
#   mnist.download_mnist()
#   mnist.save_mnist()
mnist = datasets.FashionMNIST()
train_data, train_label_, test_data, test_label_ = mnist.load()

# Divide both image sets by 255 (presumably raw 8-bit pixels -> [0, 1]; confirm
# against the loader).
train_data, test_data = train_data / 255., test_data / 255.

train_size = len(train_label_)

In [6]:
## wrap the loaded splits as torch tensors: images via torch.Tensor, labels via torch.LongTensor
train_data, test_data = torch.Tensor(train_data), torch.Tensor(test_data)
train_label, test_label = torch.LongTensor(train_label_), torch.LongTensor(test_label_)

In [7]:
## converting data to pytorch format
## NOTE(review): this cell repeats the conversions of the previous cell (all but
## test_label) on values that are already tensors; it looks redundant and is a
## candidate for removal.
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)

In [8]:
# Feature and class counts (presumably 28x28 flattened images and 10 classes — confirm).
# NOTE(review): the model classes further down hard-code 784 and 10 instead of
# reading these names.
input_size = 784
output_size = 10

In [9]:
class MNIST_Dataset(data.Dataset):
    """In-memory dataset of (image, label) pairs.

    The samples are permuted once, at construction time, using Python's
    ``random`` module; images and labels are permuted together so the pairs
    stay aligned.
    """

    def __init__(self, data, label):
        self.data, self.label = data, label
        self._shuffle_data_()

    def __len__(self):
        return len(self.data)

    def _shuffle_data_(self):
        """Apply one random permutation to both ``data`` and ``label``."""
        count = len(self.data)
        order = random.sample(range(count), k=count)
        self.data, self.label = self.data[order], self.label[order]

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]

In [10]:
# Wrap both splits; MNIST_Dataset permutes its samples once when constructed.
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [11]:
# Training hyper-parameters: intended optimiser step size, and the mini-batch
# size passed to both data loaders.
learning_rate = 0.0003
batch_size = 50

In [12]:
# Mini-batch iterators with 4 worker processes each; only the training stream is
# reshuffled by the loader.
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
test_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

## Building the matrix-factorized layer

In [13]:
class PairBilinear(nn.Module):
    """Pairwise 2x2 linear mix followed by a learnable bilinear lookup (pure PyTorch).

    The ``dim`` input features are split into ``dim // 2`` consecutive pairs.
    Each pair is multiplied by its own 2x2 matrix (``pairW``, initialised to the
    identity) and the resulting 2-D point is used to bilinearly interpolate a
    per-pair table ``Y`` of shape ``(num_pairs, 2, grid_width, grid_width)``,
    producing two outputs per pair.

    Args:
        dim: number of input features; must be even.
        grid_width: number of grid points per axis of each table; must be >= 2.
    """

    def __init__(self, dim, grid_width):
        super().__init__()
        # Same contract as PairWeight2: features are consumed two at a time.
        assert dim % 2 == 0, "Input dim must be even number"
        # Interpolation reads grid points idx and idx+1, so two points are needed.
        assert grid_width >= 2, "grid_width must be at least 2"
        self.dim = dim
        self.grid_width = grid_width

        self.num_pairs = self.dim // 2
        along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1)
        along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1)
        # Channel 0 varies along the last grid axis, channel 1 along the first.
        grid = torch.stack([along_row+along_col*0, along_row*0+along_col])
        grid = torch.repeat_interleave(grid.unsqueeze(0), self.num_pairs, dim=0)
        self.Y = nn.Parameter(grid)

        pair_w = torch.eye(2).unsqueeze(0).repeat_interleave(self.num_pairs, dim=0)
        self.pairW = nn.Parameter(pair_w)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, dim)`` to an output of the same shape."""
        bs = x.shape[0]
        # Build the index tensors on the input's device instead of on the CPU.
        dev = x.device

        # Per-pair 2x2 mix as one batched matmul: (num_pairs, bs, 2) @ (num_pairs, 2, 2).
        x = x.view(bs, -1, 2).transpose(0, 1)
        x = torch.bmm(x, self.pairW)
        x = x.transpose(1, 0).reshape(-1, 2)

        # Scale to grid coordinates, then split into an integer cell index and an
        # offset. The index is computed from detached values (no gradient through
        # clamp/floor) and clamped so that index+1 stays inside the grid; offsets
        # outside [0, 1] therefore extrapolate from the border cell.
        x = x*self.grid_width
        index = torch.clamp(x.detach(), 0, self.grid_width-2)
        index = torch.floor(index)
        x = x-index

        # Every (sample, pair) produces two outputs, so each lookup row is duplicated.
        index = index.repeat_interleave(2, dim=0)

        _gi = torch.arange(self.num_pairs, device=dev).repeat_interleave(2).repeat(bs)
        _pi = torch.tensor([0, 1], dtype=torch.long, device=dev).repeat(bs*self.num_pairs)

        index_ = index.long().t()
        _xc, _yc = index_[0], index_[1]

        f00 = self.Y[_gi, _pi, _xc, _yc]
        f01 = self.Y[_gi, _pi, _xc, _yc+1]
        f10 = self.Y[_gi, _pi, _xc+1, _yc]
        f11 = self.Y[_gi, _pi, _xc+1, _yc+1]
        #### https://en.wikipedia.org/wiki/Bilinear_interpolation
        a00 = f00
        a01 = f01-f00
        a10 = f10-f00
        a11 = f11-f10-f01+f00

        # Repeat each coordinate separately so it lines up with the two output channels.
        _x, _y = x[:, 0].repeat_interleave(2), x[:, 1].repeat_interleave(2)
        y = a00 + _x*a10 + _y*a01 + _x*_y*a11

        y = y.view(bs, -1)
        return y

In [14]:
pairBL = PairBilinear(784, 10).to(device)

# pairBL_s = torch.jit.script(pairBL)
_a = torch.randn(100, 784).to(device)

In [15]:
pairBL(_a) 
# pairBL_s(_a) 

tensor([[ 0.6556,  0.1478,  1.8715,  ..., -1.1404, -0.5592, -0.8718],
        [-0.7999,  1.8631,  0.4950,  ..., -1.8031, -0.1771,  0.2058],
        [-0.5407,  1.2427,  1.0801,  ..., -0.2061, -0.7866,  0.1892],
        ...,
        [-0.0801,  0.4917, -0.7197,  ...,  0.2450, -0.8682, -1.1732],
        [-0.2825, -2.1012, -1.1710,  ...,  0.8991, -1.9083,  1.4222],
        [-0.5166,  0.1732, -0.6692,  ...,  0.3885,  1.8911,  1.1915]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [16]:
_a

tensor([[ 0.1330,  0.5901, -0.0956,  ..., -0.8563, -0.7847, -0.5033],
        [ 1.6768, -0.7199, -0.3840,  ..., -1.4337,  0.1852, -0.1594],
        [ 1.1184, -0.4866, -1.7942,  ..., -0.5875,  0.1702, -0.7080],
        ...,
        [ 0.4426, -0.0721, -0.4282,  ..., -1.4738, -1.0559, -0.7813],
        [-1.8911, -0.2542, -1.9183,  ..., -1.3509,  1.2799, -1.7175],
        [ 0.1559, -0.4649, -1.0354,  ...,  0.4883,  1.0723,  1.7020]],
       device='cuda:0')

In [16]:
# %timeit pairBL(_a) 

In [17]:
# %timeit pairBL_s(_a) 

In [18]:
pairBL.pairW

Parameter containing:
tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        ...,

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]], device='cuda:0', requires_grad=True)

In [19]:
import bmm2x2_cuda

## CUDA - bmm2x2

In [20]:
class BMM2x2Function(torch.autograd.Function):
    """Autograd wrapper around the compiled ``bmm2x2_cuda`` extension.

    ``forward`` and ``backward`` delegate to the extension's functions of the
    same names; inputs and weights are saved for the backward pass.
    """

    @staticmethod
    def forward(ctx, inputs, weights):
        outputs = bmm2x2_cuda.forward(inputs, weights)
        ctx.save_for_backward(inputs, weights)
        # The extension returns an indexable result; only its first entry is exposed.
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        inputs, weights = ctx.saved_tensors
        # Upstream gradients may be non-contiguous (e.g. the stride-0 expanded
        # gradient of `.mean()`). NOTE(review): the kernel's layout requirements are
        # not visible here; a dense gradient is passed defensively, which is a
        # no-op when it is already contiguous.
        del_input, del_weights = bmm2x2_cuda.backward(
            inputs,
            weights,
            grad_output.contiguous())

        return del_input, del_weights

In [21]:
class PairWeight2(nn.Module):
    """Applies an independent learnable 2x2 matrix to each consecutive feature pair.

    The weights have shape ``(input_dim // 2, 2, 2)`` and start as identity
    matrices. The product itself is computed by ``BMM2x2Function``.

    Args:
        input_dim: number of input features; must be even.
    """

    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        weight = torch.eye(2).unsqueeze(0).repeat_interleave(input_dim//2, dim=0)
        self.weight = nn.Parameter(weight)

    @torch.jit.ignore
    def bmm(self, x, w):
        # Autograd functions are invoked through the class-level `apply`; no
        # instance of the Function is needed.
        return BMM2x2Function.apply(x, w)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, input_dim)`` to an output of the same shape."""
        bs = x.shape[0]
        x = x.view(bs, -1, 2)
        x = self.bmm(x, self.weight)
        x = x.view(bs, -1)
        return x

In [22]:
pw = PairWeight2(784).to(device)
pw(torch.randn(2,784).to(device))

tensor([[ 0.4846,  0.2139, -0.5816,  ...,  1.2904, -0.8244,  0.8893],
        [-0.2312, -1.1257, -0.1671,  ...,  0.7801, -1.0979, -1.0750]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [23]:
_a = torch.randn(100, 784).to(device)

In [24]:
%timeit pw(_a)

73.1 µs ± 32 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [25]:
# NOTE: the scripted module returned here is only displayed, not bound to a name,
# so the timing cell that follows still measures the original (eager) `pw`.
torch.jit.script(pw)

RecursiveScriptModule(original_name=PairWeight2)

In [26]:
%timeit pw(_a)

73.4 µs ± 1.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [27]:
import bilinear2x2_cuda

## Cuda - Bilinear2x2

In [28]:
class BiLinear2x2Function(torch.autograd.Function):
    """Autograd wrapper around the compiled ``bilinear2x2_cuda`` extension.

    ``forward`` and ``backward`` delegate to the extension's functions of the
    same names; inputs and weights are saved for the backward pass.
    """

    @staticmethod
    def forward(ctx, inputs, weights):
        outputs = bilinear2x2_cuda.forward(inputs, weights)
        ctx.save_for_backward(inputs, weights)
        # The extension returns an indexable result; only its first entry is exposed.
        return outputs[0]

    @staticmethod
    def backward(ctx, grad_output):
        inputs, weights = ctx.saved_tensors
        # Upstream gradients may be non-contiguous (e.g. the stride-0 expanded
        # gradient of `.mean()`). NOTE(review): the kernel's layout requirements are
        # not visible here; a dense gradient is passed defensively, which is a
        # no-op when it is already contiguous.
        del_input, del_weights = bilinear2x2_cuda.backward(
            inputs,
            weights,
            grad_output.contiguous())

        return del_input, del_weights

In [1]:
class PairBilinear2(nn.Module):
    """Pairwise 2x2 mix plus bilinear table lookup, backed by the custom autograd functions.

    Holds the same two parameters as ``PairBilinear``: ``Y`` with shape
    ``(num_pairs, 2, grid_width, grid_width)`` and ``pairW`` with shape
    ``(num_pairs, 2, 2)`` (identity matrices at initialisation).
    """

    def __init__(self, dim, grid_width):
        super().__init__()
        self.dim, self.grid_width = dim, grid_width
        self.num_pairs = dim // 2

        # Channel 0 of every table varies along the last grid axis, channel 1
        # along the first; both run from 0 to 1.
        ticks = torch.linspace(0, 1, grid_width)
        last_axis = ticks.reshape(1, -1).expand(grid_width, grid_width)
        first_axis = ticks.reshape(-1, 1).expand(grid_width, grid_width)
        table = torch.stack([last_axis, first_axis]).unsqueeze(0)
        self.Y = nn.Parameter(table.repeat(self.num_pairs, 1, 1, 1))

        identity = torch.eye(2).unsqueeze(0)
        self.pairW = nn.Parameter(identity.repeat(self.num_pairs, 1, 1))

    def forward(self, x):
        """Map ``x`` of shape ``(batch, dim)`` to an output of the same shape."""
        batch = x.shape[0]
        pairs = x.view(batch, -1, 2)
        mixed = BMM2x2Function.apply(pairs, self.pairW)
        looked_up = BiLinear2x2Function.apply(mixed, self.Y)
        return looked_up.view(batch, -1)

NameError: name 'nn' is not defined

In [101]:
pbl2 = PairBilinear2(784, 10).to(device)
# pbl2.pairW.data = pairBL.pairW.data.clone()
# pbl2.Y.data = pairBL.Y.data.clone()

In [102]:
pbl2.pairW.data

tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        ...,

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]], device='cuda:0')

In [103]:
_a = torch.randn(100, 784).to(device)

In [104]:
# %timeit pbl2(_a) 

In [105]:
# pbl2 = torch.jit.script(pbl2)

In [106]:
# %timeit -n10000 -r7 pbl2(_a) ### why does scripting produce poorer performance ??

In [107]:
_y = pbl2(_a)
_y.mean().backward()
_y.shape

torch.Size([100, 784])

In [108]:
y = pairBL(_a)
y.mean().backward()
y.shape

torch.Size([100, 784])

In [109]:
(_y-y).abs().max()

tensor(4.7684e-07, device='cuda:0', grad_fn=<MaxBackward1>)

In [110]:
pbl2.Y.grad[0][0]

tensor([[ 1.8960e-02, -1.5911e-02,  5.3492e-04,  1.9585e-04,  3.1085e-04,
          5.8913e-05,  1.1378e-04,  1.4477e-04, -1.1113e-02,  1.2932e-02],
        [-1.6555e-02,  1.3890e-02, -4.9747e-04, -1.8149e-04, -2.8556e-04,
         -4.1738e-05, -9.7319e-05, -1.2630e-04,  9.6358e-03, -1.1210e-02],
        [ 2.5066e-04, -2.1867e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -2.9182e-05,  3.9054e-05],
        [ 1.3725e-04, -1.2182e-04,  9.4361e-06,  2.2923e-06,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 5.9988e-04, -5.1897e-04,  2.6828e-06,  2.0066e-07,  0.0000e+00,
          0.0000e+00,  4.5391e-06,  1.1132e-05, -1.0785e-06,  1.9395e-06],
        [ 4.9923e-04, -4.6348e-04,  6.6691e-07,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  3.2840e-06,  6.5553e-06, -2.1993e-05,  4.6497e-05],
        [ 9.3554e-05, -8.4381e-05,  0.0000e+00,  0.0000e+00,  2.1386e-06,
          1.9833e-06,  0.0000e+0

In [111]:
pairBL.Y.grad[0][0]

tensor([[ 3.0136e-02, -2.5613e-02,  9.7418e-04,  5.0602e-04,  4.7834e-04,
          2.0383e-04,  1.2791e-04,  5.7154e-04, -2.4811e-02,  2.8339e-02],
        [-2.4848e-02,  2.1005e-02, -9.0548e-04, -4.6497e-04, -4.2285e-04,
         -1.7482e-04, -1.0873e-04, -5.1074e-04,  2.2353e-02, -2.5456e-02],
        [ 4.9635e-04, -4.4098e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  3.2764e-06, -2.8335e-05,  3.9054e-05],
        [ 4.8818e-04, -4.3903e-04,  2.1276e-05,  1.5214e-05,  2.6411e-07,
          9.1284e-07,  0.0000e+00,  0.0000e+00, -1.3769e-04,  1.6480e-04],
        [ 1.2331e-03, -1.0884e-03,  9.8100e-06,  5.3056e-06,  2.5982e-06,
          8.9800e-06,  4.5391e-06,  1.1132e-05, -7.5191e-05,  9.9957e-05],
        [ 6.5121e-04, -5.7521e-04,  6.6691e-07,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  3.2840e-06,  6.5553e-06, -2.1993e-05,  4.6497e-05],
        [ 2.1518e-04, -1.7046e-04,  0.0000e+00,  0.0000e+00,  2.1386e-06,
          1.9833e-06,  0.0000e+0

In [112]:
(pbl2.Y.grad - pairBL.Y.grad).abs().max()

tensor(0.0485, device='cuda:0')

In [113]:
%timeit pbl2(_a) 
## without script -> 130 µs ± 6.04 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
## with script -> 492 µs ± 17.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

KeyboardInterrupt: 

In [None]:
%timeit pairBL(_a)

In [114]:
class BiasLayer(nn.Module):
    """Adds a learnable per-feature offset to its input.

    Args:
        dim: length of the bias vector.
        init_val: value every bias entry starts at (default 0).
    """

    def __init__(self, dim, init_val=0):
        super().__init__()
        initial = torch.ones(dim).mul(init_val)
        self.bias = nn.Parameter(initial)

    def forward(self, x):
        shifted = x + self.bias
        return shifted

In [137]:
class FactorizedPairBilinearSpline(nn.Module):
    """Stack of ``PairBilinear2`` layers, each applied under a different feature pairing.

    ``ceil(log2(input_dim))`` layers are created. Before layer ``i`` the features
    are permuted with ``get_pair(input_dim, i + 1)`` and afterwards the
    permutation is undone, so that successive layers pair up different features.

    Args:
        input_dim: number of features; must be even.
        grid_width: grid size forwarded to every ``PairBilinear2``.
    """

    def __init__(self, input_dim, grid_width):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.input_dim = input_dim
        self.num_layers = int(np.ceil(np.log2(input_dim)))

        nets = []
        # List of (idx, rev_idx) tensor tuples, one per layer. These are plain
        # attributes rather than buffers, so `.to(device)` does not move them;
        # `_indices_on` relocates them lazily on the first forward instead.
        self.idx_revidx = []
        for i in range(self.num_layers):
            self.idx_revidx.append(self.get_pair(self.input_dim, i+1))
            nets.append(PairBilinear2(self.input_dim, grid_width))
        self.facto_nets = nn.ModuleList(nets)

    def get_pair(self, inp_dim, step=1):
        """Return ``(indx, rev_indx)`` for the pairing used at level ``step``.

        Within blocks of ``2**step`` positions, position ``k`` is placed next to
        position ``k + 2**(step-1)``. The pattern is generated for the next power
        of two >= ``inp_dim`` and entries >= ``inp_dim`` are dropped.
        ``rev_indx`` is the argsort of ``indx``, i.e. the permutation that undoes it.
        """
        dim = 2**int(np.ceil(np.log2(inp_dim)))
        assert isinstance(step, int), "Step must be integer"

        blocks = (2**step)
        range_ = dim//blocks
        adder_ = torch.arange(0, range_)*blocks

        # Integer tensors throughout: index values stay exact for any dimension
        # (float32 arithmetic is only exact below 2**24).
        pairs_ = torch.tensor([0, blocks//2], dtype=torch.long)
        repeat_ = torch.arange(0, blocks//2).reshape(-1,1)
        block_map = (pairs_+repeat_).reshape(-1)

        indx = (block_map+adder_.reshape(-1,1)).reshape(-1)
        indx = indx[indx<inp_dim]

        rev_indx = torch.argsort(indx)
        return indx, rev_indx

    def _indices_on(self, device):
        """Return ``idx_revidx``, moving the index tensors to ``device`` once if needed."""
        if self.idx_revidx and self.idx_revidx[0][0].device != device:
            self.idx_revidx = [(idx.to(device), rev.to(device))
                               for idx, rev in self.idx_revidx]
        return self.idx_revidx

    def forward(self, x):
        """Permute, transform and un-permute ``x`` (``(batch, input_dim)``) per layer."""
        y = x
        for (idx, revidx), fn in zip(self._indices_on(x.device), self.facto_nets):
            y = y[:, idx]
            y = fn(y)
            y = y[:, revidx]
#         y = x + y ## this is residual addition... remove if only want feed forward
        return y

In [138]:
pfL = FactorizedPairBilinearSpline(784, 10).to(device)

In [139]:
# pfL = torch.jit.script(pfL)

In [140]:
pfL(torch.randn(100, 784).to(device))

tensor([[ 1.1644, -1.2263, -5.6556,  ..., -4.3227, -2.2853, -2.1946],
        [-0.5468, -1.1605, -0.7539,  ..., -2.3264,  0.3838, -0.2109],
        [-1.7189,  4.3969,  3.4433,  ...,  0.0726,  3.0803,  1.7158],
        ...,
        [ 1.4555, -3.7405,  1.5897,  ...,  2.3172,  0.2741,  1.2542],
        [ 3.3057,  3.0363, -4.1321,  ..., -2.6341, -1.8085,  4.3901],
        [-3.1164,  1.5452, -0.7791,  ..., -3.7420,  0.3990, -0.3774]],
       device='cuda:0', grad_fn=<IndexBackward>)

In [141]:
_a = torch.randn(100, 784).to(device)

%timeit pfL(_a)

3.49 ms ± 53.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [142]:
pfL.facto_nets

ModuleList(
  (0): PairBilinear2()
  (1): PairBilinear2()
  (2): PairBilinear2()
  (3): PairBilinear2()
  (4): PairBilinear2()
  (5): PairBilinear2()
  (6): PairBilinear2()
  (7): PairBilinear2()
  (8): PairBilinear2()
  (9): PairBilinear2()
)

In [143]:
param_count = sum([torch.numel(p) for p in pfL.parameters()])
param_count

799680

In [144]:
784*784

614656

In [169]:
class FactorNet(nn.Module):
    """Classifier: learnable bias -> factorized pair-bilinear layer -> ReLU -> linear head.

    ``bn1`` is created (so its parameters are part of the model) but is not
    applied in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.bias = BiasLayer(784)
        self.la1 = FactorizedPairBilinearSpline(784, grid_width=2)
        self.bn1 = nn.BatchNorm1d(784)
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        features = self.la1(self.bias(x))
        # Batch norm is currently disabled here: features = self.bn1(features)
        activated = torch.relu(features)
        return self.fc(activated)

In [170]:
# class FactorNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.bias = BiasLayer(784)
#         self.la1 = FactorizedPairBilinearSpline(784, grid_width=5)
#         self.bn1 = nn.BatchNorm1d(784)
#         self.la2 = FactorizedPairBilinearSpline(784, grid_width=5)
#         self.bn2 = nn.BatchNorm1d(784)
#         self.fc = nn.Linear(784, 10)
        
#     def forward(self, x):
#         x = self.bias(x)
#         x = self.bn1(self.la1(x))
#         x = torch.relu(x)
#         x = self.bn2(self.la2(x))
#         x = torch.relu(x)
#         x = self.fc(x)
#         return x

In [171]:
class OrdinaryNet(nn.Module):
    """Dense baseline: bias-free linear (784 -> 784) -> batch norm -> ReLU -> linear (784 -> 10)."""

    def __init__(self):
        super().__init__()
        width, num_classes = 784, 10
        self.la1 = nn.Linear(width, width, bias=False)
        self.bn1 = nn.BatchNorm1d(width)
        self.la2 = nn.Linear(width, num_classes)

    def forward(self, x):
        hidden = self.la1(x)
        hidden = self.bn1(hidden)
        hidden = torch.relu(hidden)
        return self.la2(hidden)

In [172]:
model = FactorNet()
param_count = sum([torch.numel(p) for p in model.parameters()])
param_count

57242

In [173]:
model = OrdinaryNet()
param_count1 = sum([torch.numel(p) for p in model.parameters()])
param_count1, param_count1/param_count

(624074, 10.902379371789944)

### Model Development

In [174]:
torch.manual_seed(0)
model = FactorNet().to(device)
# model = OrdinaryNet().to(device)
model

FactorNet(
  (bias): BiasLayer()
  (la1): FactorizedPairBilinearSpline(
    (facto_nets): ModuleList(
      (0): PairBilinear2()
      (1): PairBilinear2()
      (2): PairBilinear2()
      (3): PairBilinear2()
      (4): PairBilinear2()
      (5): PairBilinear2()
      (6): PairBilinear2()
      (7): PairBilinear2()
      (8): PairBilinear2()
      (9): PairBilinear2()
    )
  )
  (bn1): BatchNorm1d(784, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=784, out_features=10, bias=True)
)

In [175]:
criterion = nn.CrossEntropyLoss()
# Read the step size from the `learning_rate` hyper-parameter defined with
# `batch_size` (same value, 0.0003) instead of repeating the literal here.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [176]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  57242


In [177]:
losses = []       # per-batch training loss
train_accs = []   # per-epoch training accuracy, in percent
test_accs = []    # per-epoch test accuracy, in percent
EPOCHS = 20

for epoch in range(EPOCHS):

    # ---- training pass ----
    # Put mode-dependent layers (e.g. BatchNorm) into training behaviour.
    model.train()
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)

        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        # Accuracy is measured on the predictions made before this step's update.
        train_acc += (yout.argmax(dim=1) == yy).sum().item()
        train_count += len(yy)

    train_accs.append(train_acc/train_count*100)

    # The loss reported here is that of the last training batch of the epoch.
    print(f'Epoch: {epoch},  Loss:{float(loss)}')

    # ---- evaluation pass ----
    # Switch to inference behaviour so that BatchNorm uses its running statistics
    # rather than the statistics of each test batch.
    model.eval()
    test_count = 0
    test_acc = 0
    with torch.no_grad():
        for xx, yy in tqdm(test_loader):
            xx, yy = xx.to(device), yy.to(device)
            yout = model(xx)
            test_acc += (yout.argmax(dim=1) == yy).sum().item()
            test_count += len(xx)
    test_accs.append(test_acc/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### best accuracies over all epochs
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [00:40<00:00, 29.38it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0,  Loss:3.5326058864593506


100%|██████████| 200/200 [00:01<00:00, 157.48it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:76.29%, Test Acc:80.03%



100%|██████████| 1200/1200 [00:41<00:00, 29.16it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1,  Loss:1.3521384000778198


100%|██████████| 200/200 [00:01<00:00, 151.41it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:81.04%, Test Acc:81.78%



100%|██████████| 1200/1200 [00:40<00:00, 29.82it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2,  Loss:1.3875417709350586


100%|██████████| 200/200 [00:01<00:00, 170.68it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:82.56%, Test Acc:82.69%



100%|██████████| 1200/1200 [00:39<00:00, 30.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3,  Loss:0.9956820607185364


100%|██████████| 200/200 [00:01<00:00, 171.75it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:83.12%, Test Acc:81.69%



100%|██████████| 1200/1200 [00:39<00:00, 30.21it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4,  Loss:0.584894061088562


100%|██████████| 200/200 [00:01<00:00, 172.46it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:83.77%, Test Acc:81.18%



100%|██████████| 1200/1200 [00:40<00:00, 29.93it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5,  Loss:1.064032793045044


100%|██████████| 200/200 [00:01<00:00, 168.60it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:83.52%, Test Acc:82.07%



100%|██████████| 1200/1200 [00:39<00:00, 30.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6,  Loss:0.47845005989074707


100%|██████████| 200/200 [00:01<00:00, 173.02it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:83.78%, Test Acc:82.27%



100%|██████████| 1200/1200 [00:39<00:00, 30.13it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7,  Loss:0.5694294571876526


100%|██████████| 200/200 [00:01<00:00, 175.84it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:83.91%, Test Acc:82.70%



100%|██████████| 1200/1200 [00:39<00:00, 30.08it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8,  Loss:0.48077356815338135


100%|██████████| 200/200 [00:01<00:00, 170.84it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:84.77%, Test Acc:82.03%



100%|██████████| 1200/1200 [00:39<00:00, 30.11it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9,  Loss:0.25931403040885925


100%|██████████| 200/200 [00:01<00:00, 169.41it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:84.97%, Test Acc:83.25%



100%|██████████| 1200/1200 [00:39<00:00, 30.10it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10,  Loss:0.7134172916412354


100%|██████████| 200/200 [00:01<00:00, 166.53it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.66%, Test Acc:84.94%



100%|██████████| 1200/1200 [00:39<00:00, 30.00it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11,  Loss:0.3500666916370392


100%|██████████| 200/200 [00:01<00:00, 169.53it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.93%, Test Acc:81.61%



100%|██████████| 1200/1200 [00:39<00:00, 30.13it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12,  Loss:0.45489415526390076


100%|██████████| 200/200 [00:01<00:00, 171.88it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.18%, Test Acc:81.38%



100%|██████████| 1200/1200 [00:39<00:00, 30.09it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13,  Loss:0.3033628463745117


100%|██████████| 200/200 [00:01<00:00, 171.08it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.56%, Test Acc:85.18%



100%|██████████| 1200/1200 [00:40<00:00, 29.90it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14,  Loss:1.0426937341690063


100%|██████████| 200/200 [00:01<00:00, 173.00it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.53%, Test Acc:84.34%



100%|██████████| 1200/1200 [00:39<00:00, 30.10it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15,  Loss:0.2512720823287964


100%|██████████| 200/200 [00:01<00:00, 173.27it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.89%, Test Acc:85.40%



100%|██████████| 1200/1200 [00:39<00:00, 30.04it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16,  Loss:0.6238232851028442


100%|██████████| 200/200 [00:01<00:00, 170.68it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.55%, Test Acc:85.18%



100%|██████████| 1200/1200 [00:40<00:00, 29.97it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17,  Loss:0.38584384322166443


100%|██████████| 200/200 [00:01<00:00, 175.51it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.70%, Test Acc:83.79%



100%|██████████| 1200/1200 [00:39<00:00, 30.03it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18,  Loss:0.41940611600875854


100%|██████████| 200/200 [00:01<00:00, 173.75it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.87%, Test Acc:85.05%



100%|██████████| 1200/1200 [00:39<00:00, 30.06it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19,  Loss:0.2131788432598114


100%|██████████| 200/200 [00:01<00:00, 171.78it/s]

Train Acc:88.13%, Test Acc:85.17%

	-> Train Acc 88.13 ; Test Acc 85.39999999999999





In [None]:
# 100%|██████████| 1200/1200 [00:52<00:00, 22.72it/s] using called pairlinear
# 100%|██████████| 1200/1200 [00:10<00:00, 118.42it/s] using Ordinary

In [None]:
## stats: 20 epochs || Fact+BN+Linear ; lr0.0001 ##_with 3 bilinear layers
### for factor-net: 5grid : 73706-> 100%|██████████| 1200/1200 [00:24<00:00, 48.44it/s]
########### -> Train Acc 90.3367 ; Test Acc 88.06

### for factor-net: 50grid : 5894906-> 100%|██████████| 1200/1200 [00:28<00:00, 42.74it/s]
########### -> Train Acc 99.985 ; Test Acc 85.85

### for factor-net: 10grid : 250106-> 100%|██████████| 1200/1200 [00:24<00:00, 48.11it/s]
########### -> Train Acc 92.17167 ; Test Acc 88.36

In [None]:
### for factor-net: fact+bn+relu+linear : 5grid : lr 0.0003
####### -> Train Acc 92.42833333333334 ; Test Acc 88.42

### same : factor-net had default of 3 bilinear layers.. changed to log2(input dim)=10 to properly mix all.
#######  -> 100%|██████████| 1200/1200 [00:50<00:00, 23.88it/s]
### facto-net: fact+bn+relu+linear : 5grid  -> params=221882
######## -> Train Acc 95.165 ; Test Acc 89.45

### ordinary net || linear+BN+Linear : lr=0.0003 : params=624074  -> [579.83it/s]
######## -> Train Acc 95.96166666666667 ; Test Acc 89.33


### Sparse Dataset