In [2]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [3]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from tqdm import tqdm
from sklearn import datasets
import random

In [4]:
import mylibrary.datasets as datasets

In [5]:
device = torch.device("cuda:0")
# device = torch.device("cuda:1")
# device = torch.device("cpu")

In [6]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

train_size = len(train_label_)

In [7]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [8]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)

In [9]:
input_size = 784
output_size = 10

In [10]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
#         self.label = mask.type(torch.float32).reshape(-1,1)
        self._shuffle_data_()
        
    def __len__(self):
        return len(self.data)
    
    def _shuffle_data_(self):
        randidx = random.sample(range(len(self.data)), k=len(self.data))
        self.data = self.data[randidx]
        self.label = self.label[randidx]
    
    def __getitem__(self, idx):
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [11]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [12]:
learning_rate = 0.0003
batch_size = 50

In [13]:
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

## Making of matrix factorized layer

In [14]:
class GroupLinear(nn.Module):
    def __init__(self, input_dim, output_dim, group_size=2, bias=True):
        super().__init__()
        assert input_dim >= group_size, \
                    f"Input dim:{input_dim} must be >= Group size: {group_size}"
        if output_dim < input_dim/group_size:
            print("Some inputs are ignored in the output")
            
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.group_size = group_size
        
        self.weight = torch.randn(group_size, 1).unsqueeze(0).repeat_interleave(output_dim, dim=0)
        self.weight = nn.Parameter(self.weight)
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.zeros(output_dim))
            
        self.order = self.get_random_groups()
#         print(self.order)
#         print(np.unique(self.order, return_counts=True))
        
    def get_random_groups(self):
        rand_order = np.random.permutation(self.input_dim) ## all inputs are included
        if self.output_dim*self.group_size < self.input_dim:
            rand_order = rand_order[:self.output_dim*self.group_size]
#         print(rand_order)
        
        ord0 = []
        for i in range(self.output_dim):
            v = np.random.permutation(self.input_dim)[:self.group_size]
            ord0.append(v)
        ord0 = np.array(ord0)
        
#         print(ord0.shape)
#         print(ord0.T)
        
        ord0 = ord0.T.reshape(-1)
        ord0[:len(rand_order)] = rand_order
        ord0 = ord0.reshape(self.group_size, -1).T
#         print(ord0)
        
        return ord0.reshape(-1)

#     def get_random_groups(self):
#         if self.output_dim*self.group_size < self.input_dim:
#             return np.random.permutation(self.input_dim)[:self.output_dim*self.group_size]
        
#         rand_order = np.random.permutation(self.input_dim)
#         _temp = np.random.permutation(self.output_dim*self.group_size-input_dim)%self.input_dim
#         rand_order.
#         print(rand_order)
#         rand_order = rand_order.reshape(self.group_size, -1)
#         print(rand_order)
#         ### change the order if group contains same units
#         pass
        
    def forward(self, x):
        bs, gs = x.shape[0], self.group_size
        x = x[:, self.order]
        
        x = x.view(bs, -1, gs).transpose(0,1)
        x = torch.bmm(x, self.weight)
        x = x.squeeze(2).transpose(1,0)
        if self.bias is not None:
            x = x + self.bias
        return x

In [15]:
class PairWeight(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.weight = torch.eye(2).unsqueeze(0).repeat_interleave(input_dim//2, dim=0)
        self.weight = nn.Parameter(self.weight)
        
    def forward(self, x):
        bs, dim = x.shape[0], x.shape[1]
        x = x.view(bs, -1, 2).transpose(0,1)
        x = torch.bmm(x, self.weight)
        x = x.transpose(1,0).reshape(bs, -1)
        return x

In [16]:
class PairFactorizedLinear(nn.Module):
    
    def __init__(self, input_dim, bias=True):
        super().__init__()
        assert input_dim%2 == 0, "Input dim must be even number"
        self.input_dim = input_dim
        
        num_layers = int(np.ceil(np.log2(input_dim)))
            
        self.facto_nets = []
        self.idx_revidx = []
        for i in range(num_layers):
            idrid = self.get_pair(self.input_dim, i+1)
            net = PairWeight(self.input_dim)
            self.facto_nets.append(net)
            self.idx_revidx.append(idrid)
        self.facto_nets = nn.ModuleList(self.facto_nets)
        
        self.bias = None
        if bias: self.bias = nn.Parameter(torch.zeros(self.input_dim))
            
    def get_pair(self, inp_dim, step=1):
        dim = 2**int(np.ceil(np.log2(inp_dim)))
        assert isinstance(step, int), "Step must be integer"

        blocks = (2**step)
        range_ = dim//blocks
        adder_ = torch.arange(0, range_)*blocks

        pairs_ = torch.Tensor([0, blocks//2])
        repeat_ = torch.arange(0, blocks//2).reshape(-1,1)
        block_map = (pairs_+repeat_).reshape(-1)

        reorder_for_pair = (block_map+adder_.reshape(-1,1)).reshape(-1)
        indx = reorder_for_pair.type(torch.long)
        indx = indx[indx<inp_dim]

        rev_indx = torch.argsort(indx)
        return indx, rev_indx
    
    def forward(self, x):
        ## swap first and then forward and reverse-swap
        y = self.facto_nets[0](x)
        for i in range(1, len(self.facto_nets)):
            idx, revidx = self.idx_revidx[i]
            y = self.facto_nets[i](y[:, idx])[:, revidx]
        if self.bias is not None: 
            y = y+self.bias
        return y

In [17]:
class FactorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.la1 = GroupLinear(784, 2**15, group_size=10, bias=False)
        self.bn1 = nn.BatchNorm1d(2**15)
        self.la2 = PairFactorizedLinear(2**15)
        self.bn2 = nn.BatchNorm1d(2**15)
        self.la3 = PairFactorizedLinear(2**15)
        self.bn3 = nn.BatchNorm1d(2**15)
        self.la4 = GroupLinear(2**15, 2**9, group_size=2**6, bias=False)
        self.bn4 = nn.BatchNorm1d(2**9)
        self.fc = nn.Linear(2**9, 10)
        
    def forward(self, x):
        x = self.bn1(self.la1(x))
        x = torch.relu(x)
        x = self.bn2(self.la2(x))
        x = torch.relu(x)
        x = self.bn3(self.la3(x))
        x = torch.relu(x)
        x = self.bn4(self.la4(x))
        x = torch.relu(x)
        x = self.fc(x)
        return x

In [18]:
class OrdinaryNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.la1 = nn.Linear(784, 784, bias=False)
        self.bn1 = nn.BatchNorm1d(784)
        self.la2 = nn.Linear(784, 784, bias=False)
        self.bn2 = nn.BatchNorm1d(784)
        self.la3 = nn.Linear(784, 10)
        
    def forward(self, x):
        x = self.bn1(self.la1(x))
        x = torch.relu(x)
        x = self.bn2(self.la2(x))
        x = torch.relu(x)
        x = self.la3(x)
        return x

In [19]:
model = FactorNet()
param_count = sum([torch.numel(p) for p in model.parameters()])
param_count

2594826

In [20]:
model = OrdinaryNet()
param_count1 = sum([torch.numel(p) for p in model.parameters()])
param_count1, param_count1/param_count

(1240298, 0.4779888901991887)

### Model Development

In [21]:
torch.manual_seed(0)
model = FactorNet().to(device)
# model = OrdinaryNet().to(device)
model

FactorNet(
  (la1): GroupLinear()
  (bn1): BatchNorm1d(32768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (la2): PairFactorizedLinear(
    (facto_nets): ModuleList(
      (0): PairWeight()
      (1): PairWeight()
      (2): PairWeight()
      (3): PairWeight()
      (4): PairWeight()
      (5): PairWeight()
      (6): PairWeight()
      (7): PairWeight()
      (8): PairWeight()
      (9): PairWeight()
      (10): PairWeight()
      (11): PairWeight()
      (12): PairWeight()
      (13): PairWeight()
      (14): PairWeight()
    )
  )
  (bn2): BatchNorm1d(32768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (la3): PairFactorizedLinear(
    (facto_nets): ModuleList(
      (0): PairWeight()
      (1): PairWeight()
      (2): PairWeight()
      (3): PairWeight()
      (4): PairWeight()
      (5): PairWeight()
      (6): PairWeight()
      (7): PairWeight()
      (8): PairWeight()
      (9): PairWeight()
      (10): PairWeight()
      (11): PairWeig

In [28]:
model.la1.order

array([268, 705, 721, ..., 343, 103, 257])

In [23]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)

In [24]:
print("number of params: ", sum(p.numel() for p in model.parameters()))

number of params:  2594826


In [25]:
losses = []
train_accs = []
test_accs = []
EPOCHS = 20

for epoch in range(EPOCHS):
    
    train_acc = 0
    train_count = 0
    for xx, yy in tqdm(train_loader):
        xx, yy = xx.to(device), yy.to(device)

        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(float(loss))

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0

    print(f'Epoch: {epoch},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    for xx, yy in tqdm(test_loader):
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

100%|██████████| 1200/1200 [07:57<00:00,  2.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0,  Loss:0.25316715240478516


100%|██████████| 200/200 [00:11<00:00, 16.70it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:82.22%, Test Acc:86.19%



100%|██████████| 1200/1200 [07:58<00:00,  2.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1,  Loss:0.274682879447937


100%|██████████| 200/200 [00:11<00:00, 16.68it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:88.70%, Test Acc:87.60%



100%|██████████| 1200/1200 [07:58<00:00,  2.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2,  Loss:0.5031993985176086


100%|██████████| 200/200 [00:11<00:00, 16.69it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:90.91%, Test Acc:88.10%



100%|██████████| 1200/1200 [07:59<00:00,  2.51it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3,  Loss:0.18634822964668274


100%|██████████| 200/200 [00:12<00:00, 16.61it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:92.64%, Test Acc:88.65%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4,  Loss:0.29451826214790344


100%|██████████| 200/200 [00:12<00:00, 16.61it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:93.81%, Test Acc:88.91%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5,  Loss:0.12445548176765442


100%|██████████| 200/200 [00:12<00:00, 16.64it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:95.04%, Test Acc:89.02%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6,  Loss:0.15430307388305664


100%|██████████| 200/200 [00:12<00:00, 16.63it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:96.01%, Test Acc:89.13%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7,  Loss:0.08230007439851761


100%|██████████| 200/200 [00:12<00:00, 16.62it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:96.74%, Test Acc:88.94%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8,  Loss:0.17422203719615936


100%|██████████| 200/200 [00:12<00:00, 16.62it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:97.39%, Test Acc:89.25%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9,  Loss:0.047530338168144226


100%|██████████| 200/200 [00:12<00:00, 16.62it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:97.71%, Test Acc:89.40%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10,  Loss:0.03348413109779358


100%|██████████| 200/200 [00:12<00:00, 16.62it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:98.18%, Test Acc:89.10%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11,  Loss:0.15867704153060913


100%|██████████| 200/200 [00:12<00:00, 16.63it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:98.39%, Test Acc:88.91%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12,  Loss:0.01151068601757288


100%|██████████| 200/200 [00:12<00:00, 16.62it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:98.66%, Test Acc:89.13%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13,  Loss:0.06058363988995552


100%|██████████| 200/200 [00:12<00:00, 16.64it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:98.86%, Test Acc:89.16%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14,  Loss:0.019555795937776566


100%|██████████| 200/200 [00:12<00:00, 16.61it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:98.84%, Test Acc:89.31%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15,  Loss:0.01619919016957283


100%|██████████| 200/200 [00:12<00:00, 16.63it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:99.08%, Test Acc:89.18%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16,  Loss:0.03674845024943352


100%|██████████| 200/200 [00:12<00:00, 16.63it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:99.17%, Test Acc:89.16%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17,  Loss:0.014013567939400673


100%|██████████| 200/200 [00:12<00:00, 16.61it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:99.22%, Test Acc:89.01%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18,  Loss:0.021400120109319687


100%|██████████| 200/200 [00:12<00:00, 16.61it/s]
  0%|          | 0/1200 [00:00<?, ?it/s]

Train Acc:99.36%, Test Acc:89.15%



100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19,  Loss:0.007401998154819012


100%|██████████| 200/200 [00:12<00:00, 16.62it/s]

Train Acc:99.24%, Test Acc:89.27%

	-> Train Acc 99.35666666666667 ; Test Acc 89.4





In [1]:
model

NameError: name 'model' is not defined

In [None]:
### Big Network Accuracy
100%|██████████| 1200/1200 [08:00<00:00,  2.50it/s]
	-> Train Acc 99.35666666666667 ; Test Acc 89.4

    
FactorNet(
  (la1): GroupLinear()
  (bn1): BatchNorm1d(32768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (la2): PairFactorizedLinear(
    (facto_nets): ModuleList(
      (0): PairWeight()
      (1): PairWeight()
      (2): PairWeight()
      (3): PairWeight()
      (4): PairWeight()
      (5): PairWeight()
      (6): PairWeight()
      (7): PairWeight()
      (8): PairWeight()
      (9): PairWeight()
      (10): PairWeight()
      (11): PairWeight()
      (12): PairWeight()
      (13): PairWeight()
      (14): PairWeight()
    )
  )
  (bn2): BatchNorm1d(32768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (la3): PairFactorizedLinear(
    (facto_nets): ModuleList(
      (0): PairWeight()
      (1): PairWeight()
      (2): PairWeight()
      (3): PairWeight()
      (4): PairWeight()
      (5): PairWeight()
      (6): PairWeight()
      (7): PairWeight()
      (8): PairWeight()
      (9): PairWeight()
      (10): PairWeight()
      (11): PairWeight()
      (12): PairWeight()
      (13): PairWeight()
      (14): PairWeight()
    )
  )
  (bn3): BatchNorm1d(32768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (la4): GroupLinear()
  (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=512, out_features=10, bias=True)
)