In [4]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [5]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from tqdm import tqdm
from sklearn import datasets
import random

In [6]:
import mylibrary.datasets as datasets

In [7]:
device = torch.device("cuda:0")
# device = torch.device("cuda:1")
# device = torch.device("cpu")

In [8]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

train_size = len(train_label_)

In [9]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [10]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)

In [11]:
input_size = 784
output_size = 10

In [12]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [13]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [14]:
learning_rate = 0.0003
batch_size = 50

In [15]:
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

## Making of matrix factorized layer

In [97]:
class PairBilinear(nn.Module):
    def __init__(self, dim, grid_width):
        super().__init__()
        self.dim = dim
        self.grid_width = grid_width
        
        self.num_pairs = self.dim // 2
        along_row = torch.linspace(0, 1, self.grid_width).reshape(1, -1)
        along_col = torch.linspace(0, 1, self.grid_width).reshape(-1, 1)
        self.Y = torch.stack([along_row+along_col*0, along_row*0+along_col])
        self.Y = torch.repeat_interleave(self.Y.unsqueeze(0), self.num_pairs, dim=0)
        self.Y = nn.Parameter(self.Y)
        
        self.pairW = torch.eye(2).unsqueeze(0).repeat_interleave(self.num_pairs, dim=0)*0.5
        self.pairW = nn.Parameter(self.pairW)
    
    def forward(self, x):
        bs = x.shape[0]
        
############# This block ########################
        ### this block is significantly faster
        x = x.view(bs, -1, 2).transpose(0,1)
        x = torch.bmm(x, self.pairW)
        x = x.transpose(1,0).reshape(-1, 2)
        
############# OR This block ########################
#         x = x.view(-1, 2)
#         _wi = torch.arange(self.num_pairs).repeat(bs)
#         _W = self.pairW[_wi]
#         x = torch.bmm(x.unsqueeze(1), _W).squeeze(1)
####################################################
        
        x = x*self.grid_width
        index = torch.clamp(x.data, 0, self.grid_width-2)
        index = torch.floor(index)
        x = x-index
        
        index = (index.repeat_interleave(2, dim=0))
        
        _bi = torch.arange(bs).repeat_interleave(self.num_pairs*2)
        _gi = torch.arange(self.num_pairs).repeat_interleave(2).repeat(bs)
        _pi = torch.tensor([0,1], dtype=torch.long).repeat(bs*self.num_pairs)

#         _xc, _yc = tuple(index.type(torch.long).t())
        index_ = index.long().t()
        _xc, _yc = index_[0], index_[1]

#         print("types", _gi.dtype, _pi.dtype, _xc.dtype, _yc.dtype)
        f00 = self.Y[_gi, _pi, _xc, _yc]
        f01 = self.Y[_gi, _pi, _xc, _yc+1]
        f10 = self.Y[_gi, _pi, _xc+1, _yc]
        f11 = self.Y[_gi, _pi, _xc+1, _yc+1]
        
        a00 = f00
        a10 = f10-f00
        a01 = f01-f00
        a11 = f11-f10-f01+f00
        
        ##### this doubles the multiplication for x,y
#         x = x.repeat_interleave(2, dim=0)
#         y = a00 + x[:, 0]*a10 + x[:, 1]*a01 + x[:, 0]*x[:, 1]*a11

        ##### this repeats in individual way
        _x, _y = x[:, 0].repeat_interleave(2), x[:, 1].repeat_interleave(2)
        y = a00 + _x*a10 + _y*a01 + _x*_y*a11

        ### now 
        y = y.view(bs, -1)
        return y

In [98]:
pairBL = PairBilinear(784, 10).to(device)

pairBL_s = torch.jit.script(pairBL)
_a = torch.randn(100, 784).to(device)

In [99]:
# pairBL(_a) 
pairBL_s(_a) 

tensor([[-0.0071,  0.7113, -0.7153,  ..., -0.3795,  0.4411, -0.3208],
        [-1.1016,  0.3758,  0.8971,  ..., -0.8615,  0.5358, -0.6306],
        [-0.1761,  0.2188, -0.0178,  ..., -0.5722,  0.6258, -0.3648],
        ...,
        [ 0.5361,  0.0208,  0.4173,  ..., -0.2884,  0.2662, -0.4721],
        [-0.4196, -0.7848, -0.2637,  ...,  0.2242, -0.6709, -0.7388],
        [ 0.2409,  0.2897,  0.1081,  ...,  0.0153,  1.4901, -0.2373]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [104]:
%timeit pairBL(_a) 

10.8 ms ± 92.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [105]:
%timeit pairBL_s(_a) 

9.76 ms ± 123 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [90]:
class BiasLayer(nn.Module):
    def __init__(self, dim, init_val=0):
        super().__init__()
        self.bias = nn.Parameter(torch.ones(dim)*init_val)
        
    def forward(self, x):
        return x+self.bias

In [91]:
class FactorizedPairBilinearSpline(nn.Module):
    
    def __init__(self, input_dim, grid_width):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.input_dim = input_dim
        num_layers = int(np.ceil(np.log2(input_dim)))
            
        self.facto_nets = []
        self.idx_revidx = []
        for i in range(num_layers):
            idrid = self.get_pair(self.input_dim, i+1)
            net = PairBilinear(self.input_dim, grid_width)
            self.facto_nets.append(net)
            self.idx_revidx.append(idrid)
        self.facto_nets = nn.ModuleList(self.facto_nets)
            
    def get_pair(self, inp_dim, step=1):
        dim = 2**int(np.ceil(np.log2(inp_dim)))
        assert isinstance(step, int), "Step must be integer"

        blocks = (2**step)
        range_ = dim//blocks
        adder_ = torch.arange(0, range_)*blocks

        pairs_ = torch.Tensor([0, blocks//2])
        repeat_ = torch.arange(0, blocks//2).reshape(-1,1)
        block_map = (pairs_+repeat_).reshape(-1)

        reorder_for_pair = (block_map+adder_.reshape(-1,1)).reshape(-1)
        indx = reorder_for_pair.type(torch.long)
        indx = indx[indx<inp_dim]

        rev_indx = torch.argsort(indx)
        return indx, rev_indx
    
    def forward(self, x):
        ## swap first and then forward and reverse-swap
        y = x
        for i in range(len(self.facto_nets)):
            idx, revidx = self.idx_revidx[i]
            y = y[:, idx]
            y = self.facto_nets[i](y) 
            y = y[:, revidx]
        y = x + y ## this is residual addition... remove if only want feed forward
        return y

In [92]:
pfL = FactorizedPairBilinearSpline(100, 5)

In [93]:
pfL(torch.randn(2, 100))

tensor([[ 0.9614,  0.6828, -0.9391,  2.1458, -0.0869,  0.6399,  1.3007, -0.6229,
         -0.1580,  0.0300, -0.9818,  0.7819, -1.7387, -0.7296, -1.2122,  0.9295,
         -0.0530,  1.2357, -0.4474, -2.5562,  0.7614, -0.4462, -0.3371, -1.1049,
          1.2296,  0.7923,  0.7100,  0.7280, -1.7893, -1.6331, -0.6923,  1.6057,
         -2.0934, -1.9059, -0.2090,  1.0632, -1.3126,  0.3149,  0.5082,  1.1400,
         -0.1142,  1.9484,  0.3631,  2.1911,  0.1444,  1.3763, -1.4698,  1.4223,
         -2.3476, -0.1818, -1.6571,  0.9826,  0.1942,  1.4795,  0.1006, -0.3802,
         -0.3359,  0.2492, -1.5641, -0.7567,  0.0277,  0.0558,  0.6531,  0.2979,
          1.8045,  0.5515, -0.7329, -0.9733, -0.9201, -0.6423, -0.6129,  0.3823,
          0.4918,  0.6434, -0.1310,  0.1267,  1.2961, -1.0392, -2.4240,  1.4502,
         -0.9543, -0.7714,  0.9922,  0.0579,  1.7345, -1.2354,  0.2363,  0.1950,
         -0.0647, -1.3091,  1.1971,  1.0509,  0.2714, -0.8513, -0.2608, -0.1523,
         -0.3489, -0.2296, -

In [94]:
pfL.facto_nets

ModuleList(
  (0): PairBilinear()
  (1): PairBilinear()
  (2): PairBilinear()
  (3): PairBilinear()
  (4): PairBilinear()
  (5): PairBilinear()
  (6): PairBilinear()
)

In [95]:
param_count = sum([torch.numel(p) for p in pfL.parameters()])
param_count

18900

In [96]:
class FactorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = BiasLayer(784)
        self.la1 = FactorizedPairBilinearSpline(784, grid_width=5)
        self.bn1 = nn.BatchNorm1d(784)
        self.fc = nn.Linear(784, 10)
        
    def forward(self, x):
        x = self.bias(x)
        x = self.la1(x)
#         x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc(x)
        return x

In [97]:
# class FactorNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.bias = BiasLayer(784)
#         self.la1 = FactorizedPairBilinearSpline(784, grid_width=5)
#         self.bn1 = nn.BatchNorm1d(784)
#         self.la2 = FactorizedPairBilinearSpline(784, grid_width=5)
#         self.bn2 = nn.BatchNorm1d(784)
#         self.fc = nn.Linear(784, 10)
        
#     def forward(self, x):
#         x = self.bias(x)
#         x = self.bn1(self.la1(x))
#         x = torch.relu(x)
#         x = self.bn2(self.la2(x))
#         x = torch.relu(x)
#         x = self.fc(x)
#         return x

In [98]:
class OrdinaryNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.la1 = nn.Linear(784, 784, bias=False)
        self.bn1 = nn.BatchNorm1d(784)
        self.la2 = nn.Linear(784, 10)
        
    def forward(self, x):
        x = self.bn1(self.la1(x))
        x = torch.relu(x)
        x = self.la2(x)
        return x

In [99]:
model = FactorNet()
param_count = sum([torch.numel(p) for p in model.parameters()])
param_count

221882

In [100]:
model = OrdinaryNet()
param_count1 = sum([torch.numel(p) for p in model.parameters()])
param_count1, param_count1/param_count

(624074, 2.812639150539476)

### Model Development

In [105]:
torch.manual_seed(0)
# model = FactorNet().to(device)
model = OrdinaryNet().to(device)
model

OrdinaryNet(
  (la1): Linear(in_features=784, out_features=784, bias=False)
  (bn1): BatchNorm1d(784, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (la2): Linear(in_features=784, out_features=10, bias=True)
)

In [106]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)

In [107]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  624074


In [108]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 20

for epoch in range(EPOCHS):
    
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)

        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [00:02<00:00, 592.29it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0,  Loss:0.5296090841293335


100%|██████████| 200/200 [00:00<00:00, 719.11it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:84.43%, Test Acc:85.73%



100%|██████████| 1200/1200 [00:02<00:00, 595.07it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1,  Loss:0.37414395809173584


100%|██████████| 200/200 [00:00<00:00, 760.41it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.76%, Test Acc:86.73%



100%|██████████| 1200/1200 [00:02<00:00, 592.36it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2,  Loss:0.3412091135978699


100%|██████████| 200/200 [00:00<00:00, 725.45it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.92%, Test Acc:87.05%



100%|██████████| 1200/1200 [00:02<00:00, 590.42it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3,  Loss:0.2261812388896942


100%|██████████| 200/200 [00:00<00:00, 680.74it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:89.99%, Test Acc:87.35%



100%|██████████| 1200/1200 [00:02<00:00, 595.18it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4,  Loss:0.4199630618095398


100%|██████████| 200/200 [00:00<00:00, 680.95it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.59%, Test Acc:87.85%



100%|██████████| 1200/1200 [00:02<00:00, 594.17it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5,  Loss:0.16107900440692902


100%|██████████| 200/200 [00:00<00:00, 693.33it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.43%, Test Acc:88.01%



100%|██████████| 1200/1200 [00:02<00:00, 582.23it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6,  Loss:0.14165063202381134


100%|██████████| 200/200 [00:00<00:00, 723.12it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:91.78%, Test Acc:88.29%



100%|██████████| 1200/1200 [00:02<00:00, 581.19it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7,  Loss:0.15212902426719666


100%|██████████| 200/200 [00:00<00:00, 735.15it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.31%, Test Acc:88.70%



100%|██████████| 1200/1200 [00:02<00:00, 589.26it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8,  Loss:0.4977302849292755


100%|██████████| 200/200 [00:00<00:00, 719.03it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.80%, Test Acc:88.67%



100%|██████████| 1200/1200 [00:02<00:00, 588.76it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9,  Loss:0.29144924879074097


100%|██████████| 200/200 [00:00<00:00, 684.53it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.18%, Test Acc:88.78%



100%|██████████| 1200/1200 [00:02<00:00, 590.76it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10,  Loss:0.1441831886768341


100%|██████████| 200/200 [00:00<00:00, 725.58it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.62%, Test Acc:88.68%



100%|██████████| 1200/1200 [00:02<00:00, 592.36it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11,  Loss:0.14703786373138428


100%|██████████| 200/200 [00:00<00:00, 677.66it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.95%, Test Acc:88.82%



100%|██████████| 1200/1200 [00:02<00:00, 598.21it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12,  Loss:0.23298802971839905


100%|██████████| 200/200 [00:00<00:00, 765.67it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.33%, Test Acc:88.82%



100%|██████████| 1200/1200 [00:02<00:00, 593.57it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13,  Loss:0.05440201610326767


100%|██████████| 200/200 [00:00<00:00, 643.07it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.47%, Test Acc:88.56%



100%|██████████| 1200/1200 [00:02<00:00, 596.80it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14,  Loss:0.20857298374176025


100%|██████████| 200/200 [00:00<00:00, 706.05it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:94.93%, Test Acc:88.88%



100%|██████████| 1200/1200 [00:02<00:00, 596.68it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15,  Loss:0.11299629509449005


100%|██████████| 200/200 [00:00<00:00, 689.84it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.10%, Test Acc:88.79%



100%|██████████| 1200/1200 [00:01<00:00, 603.19it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16,  Loss:0.08547037839889526


100%|██████████| 200/200 [00:00<00:00, 705.07it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.52%, Test Acc:88.72%



100%|██████████| 1200/1200 [00:02<00:00, 591.74it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17,  Loss:0.20652280747890472


100%|██████████| 200/200 [00:00<00:00, 756.41it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.47%, Test Acc:89.14%



100%|██████████| 1200/1200 [00:02<00:00, 579.83it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18,  Loss:0.07988645136356354


100%|██████████| 200/200 [00:00<00:00, 730.51it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.81%, Test Acc:89.33%



100%|██████████| 1200/1200 [00:02<00:00, 590.30it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19,  Loss:0.10500206798315048


100%|██████████| 200/200 [00:00<00:00, 669.59it/s]

Train Acc:95.96%, Test Acc:89.01%

	-> Train Acc 95.96166666666667 ; Test Acc 89.33





In [None]:
## stats: 20 epochs || Fact+BN+Linear ; lr0.0001 ##_with 3 bilinear layers
### for factor-net: 5grid : 73706-> 100%|██████████| 1200/1200 [00:24<00:00, 48.44it/s]
########### -> Train Acc 90.3367 ; Test Acc 88.06

### for factor-net: 50grid : 5894906-> 100%|██████████| 1200/1200 [00:28<00:00, 42.74it/s]
########### -> Train Acc 99.985 ; Test Acc 85.85

### for factor-net: 10grid : 250106-> 100%|██████████| 1200/1200 [00:24<00:00, 48.11it/s]
########### -> Train Acc 92.17167 ; Test Acc 88.36

In [None]:
### for factor-net: fact+bn+relu+linear : 5grid : lr 0.0003
####### -> Train Acc 92.42833333333334 ; Test Acc 88.42

### same : factor-net had default of 3 bilinear layers.. changed to log2(input dim)=10 to properly mix all.
#######  -> 100%|██████████| 1200/1200 [00:50<00:00, 23.88it/s]
### facto-net: fact+bn+relu+linear : 5grid  -> params=221882
######## -> Train Acc 95.165 ; Test Acc 89.45

### ordinary net || linear+BN+Linear : lr=0.0003 : params=624074
######## -> Train Acc 95.96166666666667 ; Test Acc 89.33
