In [5]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

from typing import Union, Tuple

In [6]:
device = torch.device("cuda:0")
# device = torch.device("cpu")

## MNIST dataset

In [7]:
import mylibrary.datasets as datasets
import mylibrary.nnlib as tnn

In [8]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [9]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [10]:
input_size = 784
output_size = 10

In [11]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
#         print(idx)
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [12]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [13]:
batch_size = 50
train_loader = data.DataLoader(dataset=train_dataset,
                                    num_workers=4, 
                                    batch_size=batch_size, 
                                    shuffle=True)

test_loader = data.DataLoader(dataset=test_dataset,
                                    num_workers=1, 
                                    batch_size=batch_size, 
                                    shuffle=False)

In [14]:
class ShiftScale(nn.Module):
    
    def __init__(self, input_dim):
        super().__init__()
        self.scaler = nn.Parameter(torch.ones(1, input_dim))
        self.shifter = nn.Parameter(torch.zeros(1, input_dim))
        
    def forward(self, x):
        return (x+self.shifter)*self.scaler

In [15]:
# class DistanceTransform(nn.Module):
    
#     def __init__(self, input_dim, num_centers):
#         super().__init__()
#         self.input_dim = input_dim
#         self.num_centers = num_centers
        
# #         self.centers = torch.randn(num_centers, input_dim)/2.
#         self.centers = torch.rand(num_centers, input_dim)
#         self.centers = nn.Parameter(self.centers)
# #         self.scaler = nn.Parameter(torch.Tensor([1.0]))
# #         self.layernorm = nn.LayerNorm(num_centers, elementwise_affine=False)
        
#     def forward(self, x):
#         x = x[:, :self.input_dim]
#         dists = torch.cdist(x, self.centers)
        
#         ### normalize similar to UMAP
# #         dists = dists-dists.min(dim=1, keepdim=True)[0]
#         dists = dists-dists.mean(dim=1, keepdim=True)
#         dists = dists/dists.std(dim=1, keepdim=True)

# #         dists = self.layernorm(dists)

# #         dists = torch.exp(-(dists**2) * self.scaler)
# #         dists = torch.softmax(-dists*self.scaler, dim=1)
        
#         return dists

In [16]:
### shift normalized dists towards 0 for sparse activation with exponential
class DistanceTransform(nn.Module):
    
    def __init__(self, input_dim, num_centers, p=2):
        super().__init__()
        self.input_dim = input_dim
        self.num_centers = num_centers
        self.p = p
        
#         self.centers = torch.randn(num_centers, input_dim)/2.
        self.centers = torch.rand(num_centers, input_dim)
        self.centers = nn.Parameter(self.centers)
        
        self.scaler = nn.Parameter(torch.ones(1, num_centers)*2/3)
        self.bias = nn.Parameter(torch.ones(1, num_centers)*-0.1)# if bias else None
        
        self.layernorm = nn.LayerNorm(num_centers, elementwise_affine=False)
        
    def forward(self, x):
#         x = x[:, :self.input_dim]
        dists = torch.cdist(x, self.centers)
        
        ### normalize similar to UMAP
#         dists = dists-dists.min(dim=1, keepdim=True)[0]
        dists = dists-dists.mean(dim=1, keepdim=True)
        dists = dists/dists.std(dim=1, keepdim=True)

#         dists = self.layernorm(dists)

#         dists = torch.exp(-dists*self.scaler)+self.bias
        dists = torch.exp((-dists-3)*self.scaler)+self.bias
    
#         dists = torch.softmax(-dists*self.scaler, dim=1)

        return dists

In [17]:
# ## bias to basic dist
# class DistanceTransform(nn.Module):
    
#     def __init__(self, input_dim, num_centers, p=2, bias=True):
#         super().__init__()
#         self.input_dim = input_dim
#         self.num_centers = num_centers
#         self.p = p
#         self.bias = nn.Parameter(torch.zeros(1, num_centers)) if bias else None
        
# #         self.centers = torch.randn(num_centers, input_dim)/2.
#         self.centers = torch.rand(num_centers, input_dim)
#         self.centers = nn.Parameter(self.centers)
# #         self.scaler = nn.Parameter(torch.Tensor([1.0]))
# #         self.layernorm = nn.LayerNorm(num_centers, elementwise_affine=False)
        
#     def forward(self, x):
#         x = x[:, :self.input_dim]
#         dists = torch.cdist(x, self.centers)
        
#         ### normalize similar to UMAP
# #         dists = dists-dists.min(dim=1, keepdim=True)[0]
# #         dists = dists-dists.max(dim=1, keepdim=True)[0]
#         dists = dists-dists.mean(dim=1, keepdim=True)
# #         dists = dists/dists.std(dim=1, keepdim=True)*(2/3)

# #         dists = self.layernorm(dists)

# #         dists = torch.exp(-dists*(2/3)-2)
# #         dists = torch.softmax(-dists*self.scaler, dim=1)

#         if self.bias is not None: dists = dists+self.bias
#         return dists

In [18]:
dt = DistanceTransform(784, 785)

In [19]:
dists = dt(torch.randn(2, 784))
dists.shape

torch.Size([2, 785])

In [194]:
###### Convolution to distance transform

In [195]:
class Conv2D_DT(nn.Module):
    
    def __init__(self, in_channels, out_channels,
                 kernel_size: Union[int, Tuple[int, ...]],
                 dilation: Union[int, Tuple[int, ...]] = 1,
                 padding: Union[int, Tuple[int, ...]] = 0,
                 stride: Union[int, Tuple[int, ...]] = 1,):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding  # format for padding -> l, r, t, b
        self.stride = stride
        self._preprocess_()
        
        self.unfold = nn.Unfold(self.kernel_size, self.dilation, self.padding, self.stride)
        self.dt = DistanceTransform(self.kernel_size[0]*self.kernel_size[1]*in_channels, out_channels)
    
    def _preprocess_(self):
        if not isinstance(self.kernel_size, (tuple, list)):
            self.kernel_size = (self.kernel_size, self.kernel_size)
        if not isinstance(self.dilation, (tuple, list)):
            self.dilation = (self.dilation, self.dilation)
        if not isinstance(self.stride, (tuple, list)):
            self.stride = (self.stride, self.stride)
#         if not isinstance(self.padding, (tuple, list)):
#             self.padding = (self.padding, self.padding, self.padding, self.padding)
#         assert len(self.padding) == 4, 'padding must be specified for all sides of image'
        if not isinstance(self.padding, (tuple, list)):
            self.padding = (self.padding, self.padding)
        assert len(self.padding) == 2, 'padding must be specified for TB, LR'
        return
        
    def _get_output_size_(self, inputH, inputW):
        ### input change due to padding
        inputH, inputW = (inputH+2*self.padding[0], inputW+2*self.padding[1])
        
        ### kernel change due to dilation
        kH = (self.kernel_size[0]-1)*self.dilation[0]+1
        kW = (self.kernel_size[1]-1)*self.dilation[1]+1
        
        oH = (inputH-kH)/self.stride[0]+1
        oW = (inputW-kW)/self.stride[1]+1
        return (int(oH), int(oW))
        
    def forward(self, x):
        c = x.shape
        x = self.unfold(x).transpose(1,2).reshape(-1, self.dt.input_dim)
        x = self.dt(x).view(c[0], -1, self.dt.num_centers).transpose(1,2)\
                    .view(c[0], -1, *self._get_output_size_(c[2], c[3]))
        return x

In [196]:
convdt = Conv2D_DT(3, 4, kernel_size=3, padding=(0,1))

In [197]:
x = torch.randn(1, 3, 10, 10)
# x = torch.arange(torch.numel(x), dtype=x.dtype).reshape(x.shape)

In [198]:
y = convdt(x)
y.shape ## B, K1*K2*C, N_window

torch.Size([1, 4, 8, 10])

In [199]:
y

tensor([[[[-1.5820e-02, -2.9336e-01, -1.1420e-01, -2.8184e-01, -9.9238e-02,
            1.1778e-01, -7.6962e-02, -5.0676e-01, -1.6824e-01, -3.3333e-01],
          [ 1.2480e-01, -1.8395e-01, -1.3378e-01, -4.8187e-02, -1.7141e-01,
           -4.0958e-01, -9.9412e-02, -1.7116e-01, -2.1677e-02, -1.9155e-01],
          [-4.1571e-01, -2.2958e-01, -1.0402e-01, -1.7689e-01, -1.5786e-01,
           -6.3832e-02, -1.9401e-01, -1.0064e-01, -5.3067e-01, -2.8397e-01],
          [-2.2334e-01, -5.1131e-01, -2.5324e-01, -1.8139e-01, -3.8438e-01,
           -2.0437e-01, -1.9531e-01, -2.1935e-01, -1.3157e-01, -3.7372e-01],
          [-5.1013e-02,  3.1500e-03, -6.6307e-02, -2.5629e-01,  8.9358e-02,
           -1.0644e-01, -1.7593e-01, -2.9563e-01,  5.9885e-02, -4.7546e-02],
          [-3.9073e-01, -3.7214e-01, -3.6599e-01, -4.8006e-01, -2.2589e-01,
           -4.6470e-01, -3.0997e-01, -2.1423e-01, -2.2089e-01, -3.3721e-01],
          [-1.5872e-01,  9.2511e-03,  7.4882e-02,  2.8109e-02,  1.0479e-01,
      

In [200]:
#######################

In [201]:
class MyModel(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.cnn = nn.Sequential(
                Conv2D_DT(1, 20, 5),
                nn.MaxPool2d(2),
                nn.LeakyReLU(),
                Conv2D_DT(20, 50, 5),
                nn.MaxPool2d(2),
                nn.LeakyReLU(),
                )
        self.fc = nn.Sequential(
                DistanceTransform(4*4*50, 500),
                nn.LeakyReLU(),
                DistanceTransform(500, 10),
                ShiftScale(10),
            )
    def forward(self,x):
        x = self.cnn(x)
        x = x.reshape(-1, 4*4*50)
        x = self.fc(x)
        return x

In [202]:
# class MyModel(nn.Module):
    
#     def __init__(self):
#         super().__init__()
        
#         self.cnn = nn.Sequential(
#                 nn.Conv2d(1, 20, 5),
#                 nn.MaxPool2d(2),
#                 nn.LeakyReLU(),
#                 nn.Conv2d(20, 50, 5),
#                 nn.MaxPool2d(2),
#                 nn.LeakyReLU(),
#                 )
#         self.fc = nn.Sequential(
#                 nn.Linear(4*4*50, 500),
#                 nn.LeakyReLU(),
#                 nn.Linear(500, 10),
#                 nn.BatchNorm1d(10),
#             )
#     def forward(self,x):
#         x = self.cnn(x)
#         x = x.reshape(-1, 4*4*50)
#         x = self.fc(x)
#         return x

In [203]:
model = MyModel()
model.to(device)

MyModel(
  (cnn): Sequential(
    (0): Conv2D_DT(
      (unfold): Unfold(kernel_size=(5, 5), dilation=(1, 1), padding=(0, 0), stride=(1, 1))
      (dt): DistanceTransform()
    )
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Conv2D_DT(
      (unfold): Unfold(kernel_size=(5, 5), dilation=(1, 1), padding=(0, 0), stride=(1, 1))
      (dt): DistanceTransform()
    )
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): LeakyReLU(negative_slope=0.01)
  )
  (fc): Sequential(
    (0): DistanceTransform()
    (1): LeakyReLU(negative_slope=0.01)
    (2): DistanceTransform()
    (3): ShiftScale()
  )
)

In [204]:
xx, yy = iter(train_loader).next()
model(xx.reshape(-1, 1, 28, 28).to(device)).shape

torch.Size([50, 10])

In [205]:
optimizer = optim.Adam(list(model.parameters()), 
                            lr=0.001)
criterion = nn.CrossEntropyLoss()

In [206]:
index = 0
train_accs, test_accs = [], []
for epoch in tqdm(list(range(40))):
    model.train()
    train_acc = 0
    train_count = 0
    for xx, yy in train_loader:
        xx, yy = xx.reshape(-1, 1, 28, 28).to(device), yy.to(device)
        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0
    
#     if epoch%5 == 0:
#         print(f"Shifting the centroids to the nearest data point")
#         model[0].set_centroid_to_data(train_loader)

    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    model.eval()
    for xx, yy in test_loader:
        xx, yy = xx.reshape(-1, 1, 28, 28).to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> MAX Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:0.4128994047641754


  2%|▎         | 1/40 [00:08<05:26,  8.37s/it]

Train Acc:70.59%, Test Acc:80.02%

Epoch: 1:0,  Loss:0.32194313406944275


  5%|▌         | 2/40 [00:16<05:19,  8.42s/it]

Train Acc:83.30%, Test Acc:84.16%

Epoch: 2:0,  Loss:0.2974202036857605


  8%|▊         | 3/40 [00:25<05:11,  8.42s/it]

Train Acc:86.00%, Test Acc:85.90%

Epoch: 3:0,  Loss:0.3440536558628082


 10%|█         | 4/40 [00:33<05:03,  8.42s/it]

Train Acc:87.26%, Test Acc:86.17%

Epoch: 4:0,  Loss:0.3156658113002777


 12%|█▎        | 5/40 [00:42<04:54,  8.42s/it]

Train Acc:88.13%, Test Acc:87.27%

Epoch: 5:0,  Loss:0.22195637226104736


 15%|█▌        | 6/40 [00:50<04:45,  8.41s/it]

Train Acc:88.74%, Test Acc:88.03%

Epoch: 6:0,  Loss:0.29990965127944946


 18%|█▊        | 7/40 [00:58<04:37,  8.40s/it]

Train Acc:89.18%, Test Acc:88.49%

Epoch: 7:0,  Loss:0.23408469557762146


 20%|██        | 8/40 [01:07<04:29,  8.41s/it]

Train Acc:89.68%, Test Acc:89.08%

Epoch: 8:0,  Loss:0.13937540352344513


 22%|██▎       | 9/40 [01:15<04:20,  8.41s/it]

Train Acc:90.22%, Test Acc:89.48%

Epoch: 9:0,  Loss:0.2537733018398285


 25%|██▌       | 10/40 [01:24<04:12,  8.41s/it]

Train Acc:90.40%, Test Acc:89.31%

Epoch: 10:0,  Loss:0.2230251133441925


 28%|██▊       | 11/40 [01:32<04:03,  8.41s/it]

Train Acc:90.70%, Test Acc:89.41%

Epoch: 11:0,  Loss:0.43614470958709717


 30%|███       | 12/40 [01:40<03:55,  8.40s/it]

Train Acc:90.95%, Test Acc:89.52%

Epoch: 12:0,  Loss:0.18129204213619232


 32%|███▎      | 13/40 [01:49<03:46,  8.40s/it]

Train Acc:91.23%, Test Acc:89.49%

Epoch: 13:0,  Loss:0.14957067370414734


 35%|███▌      | 14/40 [01:57<03:38,  8.41s/it]

Train Acc:91.50%, Test Acc:89.21%

Epoch: 14:0,  Loss:0.30811789631843567


 38%|███▊      | 15/40 [02:06<03:30,  8.41s/it]

Train Acc:91.73%, Test Acc:89.79%

Epoch: 15:0,  Loss:0.30762791633605957


 40%|████      | 16/40 [02:14<03:21,  8.40s/it]

Train Acc:91.95%, Test Acc:90.37%

Epoch: 16:0,  Loss:0.3574450612068176


 42%|████▎     | 17/40 [02:22<03:13,  8.40s/it]

Train Acc:92.11%, Test Acc:89.94%

Epoch: 17:0,  Loss:0.20711122453212738


 45%|████▌     | 18/40 [02:31<03:04,  8.39s/it]

Train Acc:92.34%, Test Acc:90.04%

Epoch: 18:0,  Loss:0.283222496509552


 48%|████▊     | 19/40 [02:39<02:56,  8.40s/it]

Train Acc:92.52%, Test Acc:90.31%

Epoch: 19:0,  Loss:0.2070496678352356


 50%|█████     | 20/40 [02:48<02:48,  8.40s/it]

Train Acc:92.71%, Test Acc:90.33%

Epoch: 20:0,  Loss:0.12107142806053162


 52%|█████▎    | 21/40 [02:56<02:39,  8.41s/it]

Train Acc:92.86%, Test Acc:90.61%

Epoch: 21:0,  Loss:0.1419389247894287


 55%|█████▌    | 22/40 [03:04<02:31,  8.42s/it]

Train Acc:93.01%, Test Acc:90.19%

Epoch: 22:0,  Loss:0.11178340017795563


 57%|█████▊    | 23/40 [03:13<02:22,  8.41s/it]

Train Acc:93.17%, Test Acc:90.59%

Epoch: 23:0,  Loss:0.1981443464756012


 60%|██████    | 24/40 [03:21<02:14,  8.41s/it]

Train Acc:93.41%, Test Acc:90.45%

Epoch: 24:0,  Loss:0.09476640820503235


 62%|██████▎   | 25/40 [03:30<02:06,  8.41s/it]

Train Acc:93.50%, Test Acc:90.16%

Epoch: 25:0,  Loss:0.10581621527671814


 65%|██████▌   | 26/40 [03:38<01:57,  8.40s/it]

Train Acc:93.67%, Test Acc:90.63%

Epoch: 26:0,  Loss:0.16052870452404022


 68%|██████▊   | 27/40 [03:46<01:49,  8.39s/it]

Train Acc:93.58%, Test Acc:90.90%

Epoch: 27:0,  Loss:0.35116758942604065


 70%|███████   | 28/40 [03:55<01:40,  8.40s/it]

Train Acc:93.85%, Test Acc:90.84%

Epoch: 28:0,  Loss:0.25284063816070557


 72%|███████▎  | 29/40 [04:03<01:32,  8.41s/it]

Train Acc:94.03%, Test Acc:90.60%

Epoch: 29:0,  Loss:0.11799734830856323


 75%|███████▌  | 30/40 [04:12<01:24,  8.41s/it]

Train Acc:94.17%, Test Acc:90.89%

Epoch: 30:0,  Loss:0.10534902662038803


 78%|███████▊  | 31/40 [04:20<01:15,  8.42s/it]

Train Acc:94.20%, Test Acc:91.04%

Epoch: 31:0,  Loss:0.16730950772762299


 80%|████████  | 32/40 [04:29<01:07,  8.41s/it]

Train Acc:94.36%, Test Acc:90.34%

Epoch: 32:0,  Loss:0.1351667046546936


 82%|████████▎ | 33/40 [04:37<00:58,  8.41s/it]

Train Acc:94.50%, Test Acc:91.43%

Epoch: 33:0,  Loss:0.2073395848274231


 85%|████████▌ | 34/40 [04:45<00:50,  8.41s/it]

Train Acc:94.65%, Test Acc:90.86%

Epoch: 34:0,  Loss:0.16478420794010162


 88%|████████▊ | 35/40 [04:54<00:42,  8.40s/it]

Train Acc:94.77%, Test Acc:91.06%

Epoch: 35:0,  Loss:0.13498421013355255


 90%|█████████ | 36/40 [05:02<00:33,  8.40s/it]

Train Acc:94.79%, Test Acc:91.11%

Epoch: 36:0,  Loss:0.19256603717803955


 92%|█████████▎| 37/40 [05:11<00:25,  8.40s/it]

Train Acc:94.91%, Test Acc:91.25%

Epoch: 37:0,  Loss:0.29411786794662476


 95%|█████████▌| 38/40 [05:19<00:16,  8.41s/it]

Train Acc:94.96%, Test Acc:91.19%

Epoch: 38:0,  Loss:0.19522039592266083


 98%|█████████▊| 39/40 [05:27<00:08,  8.42s/it]

Train Acc:95.24%, Test Acc:90.40%

Epoch: 39:0,  Loss:0.08623470366001129


100%|██████████| 40/40 [05:36<00:00,  8.41s/it]

Train Acc:95.24%, Test Acc:91.29%

	-> MAX Train Acc 95.24333333333334 ; Test Acc 91.43





In [None]:
#### Report
# -mean works better than -min
# -min and -max work the same
# non exp works == exps ?? (need expansion)
# layernorm works best for activation functions, not needed if batchnorm used
# 

In [164]:
model.cnn[0]

Conv2D_DT(
  (unfold): Unfold(kernel_size=(5, 5), dilation=(1, 1), padding=(0, 0), stride=(1, 1))
  (dt): DistanceTransform(
    (layernorm): LayerNorm((20,), eps=1e-05, elementwise_affine=False)
  )
)

In [165]:
model.eval()
xx = train_dataset[np.random.randint(0, len(train_dataset), 50)][0].to(device)
hh = model.cnn(xx.reshape(-1, 1, 28, 28))
hh.shape

torch.Size([50, 50, 4, 4])

In [174]:
hh[0,:,:,:]

tensor([[[0.3639, 0.4713, 0.4519, 0.5083],
         [0.3742, 0.4260, 0.3797, 0.5161],
         [0.6105, 0.7115, 0.8506, 0.3109],
         [0.4332, 0.3778, 0.4804, 0.4227]],

        [[0.4507, 0.4439, 0.4287, 0.3602],
         [0.4195, 0.3900, 0.4831, 0.5990],
         [0.4930, 0.7527, 0.5622, 0.3342],
         [0.5466, 0.4771, 0.4981, 0.4912]],

        [[0.5052, 0.3938, 0.3243, 0.3371],
         [0.4597, 0.3492, 0.2942, 0.3626],
         [0.5308, 0.3488, 0.5289, 0.6275],
         [0.7375, 0.6787, 0.5943, 0.5946]],

        [[0.3175, 0.3861, 0.3877, 0.3782],
         [0.3409, 0.2672, 0.2723, 0.3228],
         [0.2752, 0.3433, 0.4133, 0.2556],
         [0.6312, 0.4924, 0.6295, 0.6080]],

        [[0.7979, 0.8010, 0.7175, 0.6030],
         [0.8726, 0.6133, 0.7082, 0.6839],
         [0.5143, 0.7394, 0.5239, 0.5849],
         [0.7327, 0.6134, 0.5752, 0.6405]],

        [[0.4284, 0.3650, 0.5119, 0.4893],
         [0.3136, 0.5471, 0.4954, 0.3413],
         [0.4130, 0.4061, 0.4837, 0.3165],
 

In [75]:
print(f'\t-> MAX Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

	-> MAX Train Acc 99.41499999999999 ; Test Acc 90.93


In [None]:
torch.save(model.state_dict(), "./models/temp_03_1_model_dec1_v0.pth")

# model.load_state_dict(torch.load("./temp_01_2_model_nov26.pth", map_location=device))

In [None]:
model.eval()
xx = train_dataset[np.random.randint(0, len(train_dataset), 50)][0].to(device)
dists = model[0](xx.reshape(-1, 1, 28, 28))
# dists = xx@model[0].weight.data.t()

# dists = model[1](dists)
model.train()
dists.shape

In [None]:
# xx[0], model[0].weight.data[0]

In [None]:
dists.mean(), dists.std(), dists.max(dim=1)

In [None]:
# model[0].scaler

## UMAP - from library

In [None]:
import umap

In [None]:
embed = umap.UMAP(n_neighbors=784, n_components=2, min_dist=0.1, spread=1, metric="euclidean")
# embed = umap.UMAP(n_neighbors=784, n_components=2, min_dist=0.1, spread=1,
# #                   target_metric='euclidean',
#                   target_metric='categorical',
#                   target_weight=0.1
#                  )

In [None]:
center_lbl = model(model[0].centers.data)
# center_lbl = model(model[0].weight.data)
# output_cent = torch.softmax(center_lbl, dim=1).argmax(dim=1).data.cpu()
output_cent = center_lbl.argmax(dim=1).data.cpu()

# output_cent = center_lbl.data.cpu().numpy()

torch.unique(output_cent, return_counts=True)

In [None]:
centers = model[0].centers.data.cpu().numpy()
# centers = model[0].weight.data.cpu().numpy()

embedding = embed.fit_transform(centers)
# embedding = embed.fit_transform(centers, output_cent)

In [None]:
dists.shape

In [None]:
i = 0

In [None]:
activ = dists.data.cpu()[i]
# activ = activ - activ.min()
# activ = torch.exp(activ)

i += 1
print(f"{i}/{len(dists)}")
plt.figure(figsize=(8,8))
plt.scatter(embedding[:,0], embedding[:, 1], c=output_cent, s=np.maximum(activ*50, 0.5), cmap="tab10")
# plt.scatter(embedding[:,0], embedding[:, 1], c=output_cent, s=activ, cmap="tab10")

In [None]:
aa = dists.data.cpu()[i]
aa.mean(), aa.min(), aa.max(), aa.std()

In [None]:
j = 0

In [None]:
# plt.imshow(xx.cpu()[j].reshape(28,28))
plt.imshow(centers[j].reshape(28,28))
j += 1

In [None]:
class ScaleShift(nn.Module):
    
    def __init__(self, input_dim):
        super().__init__()
        self.scaler = nn.Parameter(torch.ones(1, input_dim))
        self.shifter = nn.Parameter(torch.zeros(1, input_dim))
        
    def forward(self, x):
        return x*self.scaler+self.shifter