In [23]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [24]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from tqdm import tqdm
from sklearn import datasets
import random

In [25]:
import mylibrary.datasets as datasets

In [26]:
device = torch.device("cuda:0")
# device = torch.device("cuda:1")
# device = torch.device("cpu")

In [27]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

train_size = len(train_label_)

In [28]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [29]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)

In [30]:
input_size = 784
output_size = 10

In [31]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [32]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [33]:
learning_rate = 0.0003
batch_size = 50

In [34]:
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

## Making of matrix factorized layer

In [47]:
class PairFactorizedLinear(nn.Module):
    
    def __init__(self, input_dim, output_dim=None, bias=True):
        super().__init__()
        self.input_dim = input_dim
        if output_dim is None:
            output_dim = input_dim
        self.output_dim = output_dim
        
        self.num_factors = int(np.ceil(np.log2(max(self.input_dim, self.output_dim))))
        self.dim = 2**self.num_factors
        self.pair_indices = [self.get_pair_indices(self.dim, i+1) for i in range(self.num_factors)]
        self.pair_indices = [tuple(ind.t()) for ind in self.pair_indices]
        
        self.weights = []
        for _ in range(self.num_factors):
            w = nn.Parameter(torch.randn(self.dim*2))
            self.weights.append(w)
        self.weights = nn.ParameterList(self.weights)
        self.bias = nn.Parameter(torch.zeros(self.output_dim))
        
        
#         self.zero_mat = torch.zeros(dim, dim)
        self.W = torch.eye(self.dim)
    
        if self.input_dim < self.dim:
            self.add_dim = torch.randint(0, self.input_dim, (self.dim-self.input_dim, ))
        if self.output_dim < self.dim:
            self.select_dim = torch.LongTensor(np.random.permutation(self.dim)[:self.output_dim])
        
    def forward(self, x):
        W = self.W.to(x.device)
        for pi, w in zip(self.pair_indices, self.weights):
            zero_mat = torch.zeros(self.dim, self.dim, device=x.device)
            zero_mat[pi] = w
            W = W@zero_mat
            
        
        if self.input_dim < self.dim:
            x = torch.cat([x, x[:, self.add_dim]], dim=1)
        y = x@W
        if self.output_dim < self.dim:
            y = y[:, self.select_dim]
        if self.bias is not None:
            y = y + self.bias
        return y

    def get_pair_indices(self, dim, step=1):
        assert 2**int(np.log2(dim)) == dim , "The dim must be power of 2"
        assert isinstance(step, int), "Step must be integer"

        blocks = (2**step)
        range_ = dim//blocks
        adder_ = torch.arange(0, range_)*blocks

        pairs_ = torch.Tensor([0, blocks//2])
        repeat_ = torch.arange(0, blocks//2).reshape(-1,1)

        block_map = (pairs_+repeat_).reshape(-1)
        reorder_for_pair = (block_map+adder_.reshape(-1,1)).reshape(-1)
        indx = reorder_for_pair.type(torch.long)
        indx = indx.reshape(-1, 2)

        map_idx = []
        for idx in indx:
            map_idx.append((idx[0], idx[0]))
            map_idx.append((idx[0], idx[1]))        
            map_idx.append((idx[1], idx[0]))        
            map_idx.append((idx[1], idx[1]))        
        map_idx = torch.LongTensor(map_idx)
        return map_idx

In [48]:
pfL = PairFactorizedLinear(100, 1024)

In [49]:
pfL.num_factors

10

In [50]:
pfL(torch.randn(2, 100))

tensor([[10.2526, -3.1802, -3.2392,  ...,  4.3016, -2.9256, -8.6059],
        [32.1401, -3.1110,  0.8775,  ..., -7.6219,  6.2570, -1.0045]],
       grad_fn=<AddBackward0>)

In [51]:
param_count = sum([torch.numel(p) for p in pfL.parameters()])
param_count

21504

In [52]:
1024*1024/param_count

48.76190476190476

In [56]:
class FactorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.la1 = PairFactorizedLinear(784, 1024, bias=False)
        self.bn1 = nn.BatchNorm1d(1024)
        self.la2 = PairFactorizedLinear(1024, bias=False)
        self.bn2 = nn.BatchNorm1d(1024)
        self.la3 = nn.Linear(1024, 10)
        
    def forward(self, x):
        x = self.bn1(self.la1(x))
        x = torch.relu(x)
        x = self.bn2(self.la2(x))
        x = torch.relu(x)
        x = self.la3(x)
        return x

In [60]:
class OrdinaryNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.la1 = nn.Linear(784, 1024, bias=False)
        self.bn1 = nn.BatchNorm1d(1024)
        self.la2 = nn.Linear(1024, 1024, bias=False)
        self.bn2 = nn.BatchNorm1d(1024)
        self.la3 = nn.Linear(1024, 10)
        
    def forward(self, x):
        x = self.bn1(self.la1(x))
        x = torch.relu(x)
        x = self.bn2(self.la2(x))
        x = torch.relu(x)
        x = self.la3(x)
        return x

In [61]:
model = FactorNet()
param_count = sum([torch.numel(p) for p in model.parameters()])
param_count

57354

In [62]:
model = OrdinaryNet()
param_count1 = sum([torch.numel(p) for p in model.parameters()])
param_count1, param_count1/param_count

(1865738, 32.530215852425286)

### Model Development

In [64]:
torch.manual_seed(0)
model = FactorNet().to(device)
# model = OrdinaryNet().to(device)
model

FactorNet(
  (la1): PairFactorizedLinear(
    (weights): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
        (1): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
        (2): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
        (3): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
        (4): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
        (5): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
        (6): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
        (7): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
        (8): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
        (9): Parameter containing: [torch.cuda.FloatTensor of size 2048 (GPU 0)]
    )
  )
  (bn1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (la2): PairFact

In [65]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [66]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  57354


In [67]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 20

for epoch in range(EPOCHS):
    
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)

        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [00:33<00:00, 35.58it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0,  Loss:0.6693528890609741


100%|██████████| 200/200 [00:03<00:00, 65.66it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:74.52%, Test Acc:79.94%



100%|██████████| 1200/1200 [00:33<00:00, 35.52it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1,  Loss:0.5355532169342041


100%|██████████| 200/200 [00:03<00:00, 65.23it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:81.64%, Test Acc:82.11%



100%|██████████| 1200/1200 [00:33<00:00, 35.47it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2,  Loss:0.2605569660663605


100%|██████████| 200/200 [00:03<00:00, 65.54it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:83.16%, Test Acc:82.90%



100%|██████████| 1200/1200 [00:33<00:00, 35.47it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3,  Loss:0.4016371965408325


100%|██████████| 200/200 [00:03<00:00, 65.51it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:84.42%, Test Acc:83.53%



100%|██████████| 1200/1200 [00:33<00:00, 35.52it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4,  Loss:0.5834341049194336


100%|██████████| 200/200 [00:03<00:00, 65.43it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.03%, Test Acc:84.40%



100%|██████████| 1200/1200 [00:33<00:00, 35.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5,  Loss:0.541366696357727


100%|██████████| 200/200 [00:03<00:00, 65.52it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:85.70%, Test Acc:84.34%



100%|██████████| 1200/1200 [00:33<00:00, 35.43it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6,  Loss:0.4174977242946625


100%|██████████| 200/200 [00:03<00:00, 65.56it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.14%, Test Acc:84.96%



100%|██████████| 1200/1200 [00:33<00:00, 35.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7,  Loss:0.51841139793396


100%|██████████| 200/200 [00:03<00:00, 65.39it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.37%, Test Acc:85.04%



100%|██████████| 1200/1200 [00:33<00:00, 35.39it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8,  Loss:0.43383219838142395


100%|██████████| 200/200 [00:03<00:00, 65.51it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.91%, Test Acc:85.47%



100%|██████████| 1200/1200 [00:33<00:00, 35.54it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9,  Loss:0.43389445543289185


100%|██████████| 200/200 [00:03<00:00, 65.45it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:86.95%, Test Acc:85.70%



100%|██████████| 1200/1200 [00:33<00:00, 35.44it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10,  Loss:0.27512410283088684


100%|██████████| 200/200 [00:03<00:00, 65.56it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.37%, Test Acc:86.05%



100%|██████████| 1200/1200 [00:33<00:00, 35.44it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11,  Loss:0.4482766389846802


100%|██████████| 200/200 [00:03<00:00, 65.34it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.63%, Test Acc:86.08%



100%|██████████| 1200/1200 [00:33<00:00, 35.54it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12,  Loss:0.3700610399246216


100%|██████████| 200/200 [00:03<00:00, 65.02it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.83%, Test Acc:86.15%



100%|██████████| 1200/1200 [00:33<00:00, 35.45it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13,  Loss:0.18712036311626434


100%|██████████| 200/200 [00:03<00:00, 65.46it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:87.90%, Test Acc:86.32%



100%|██████████| 1200/1200 [00:33<00:00, 35.46it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14,  Loss:0.3768487274646759


100%|██████████| 200/200 [00:03<00:00, 65.46it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.08%, Test Acc:86.48%



100%|██████████| 1200/1200 [00:33<00:00, 35.43it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15,  Loss:0.26308077573776245


100%|██████████| 200/200 [00:03<00:00, 65.01it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.39%, Test Acc:86.72%



100%|██████████| 1200/1200 [00:33<00:00, 35.41it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16,  Loss:0.287915974855423


100%|██████████| 200/200 [00:03<00:00, 65.46it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.45%, Test Acc:86.54%



100%|██████████| 1200/1200 [00:33<00:00, 35.48it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17,  Loss:0.6232171058654785


100%|██████████| 200/200 [00:03<00:00, 65.33it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.55%, Test Acc:86.75%



100%|██████████| 1200/1200 [00:33<00:00, 35.49it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18,  Loss:0.21898196637630463


100%|██████████| 200/200 [00:03<00:00, 65.30it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.77%, Test Acc:86.81%



100%|██████████| 1200/1200 [00:33<00:00, 35.53it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19,  Loss:0.2985161542892456


100%|██████████| 200/200 [00:03<00:00, 65.29it/s]

Train Acc:88.89%, Test Acc:86.79%

	-> Train Acc 88.89333333333333 ; Test Acc 86.81





In [None]:
## stats: 20 epochs
### for factor-net -> 57354 -> Train Acc 88.83666666666666 ; Test Acc 86.7
### for ordinary-net -> 1867786 -> Train Acc 97.95833333333334 ; Test Acc 89.61