In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [2]:
device = torch.device("cuda:1")
# device = torch.device("cpu")

## MNIST dataset

In [3]:
import mylibrary.datasets as datasets
import mylibrary.nnlib as tnn

In [4]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [5]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [6]:
input_size = 784
output_size = 10

In [7]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
#         print(idx)
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [8]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [9]:
batch_size = 50
train_loader = data.DataLoader(dataset=train_dataset,
                                    num_workers=4, 
                                    batch_size=batch_size, 
                                    shuffle=True)

test_loader = data.DataLoader(dataset=test_dataset,
                                    num_workers=1, 
                                    batch_size=batch_size, 
                                    shuffle=False)

In [10]:
import dtnnlib as dtnn

In [11]:
dt = dtnn.StereographicTransform(784, 20)

torch.Size([20, 785])


In [12]:
dists = dt(torch.randn(2, 784))

In [13]:
dists.shape

torch.Size([2, 20])

In [14]:
dists.data

tensor([[-0.0094,  0.0283,  0.0198,  0.0252, -0.0282, -0.0212, -0.0039, -0.0099,
          0.0283,  0.0258, -0.0128, -0.0144, -0.0130, -0.0125,  0.0005, -0.0326,
          0.0244,  0.0099, -0.0314, -0.0164],
        [-0.0309,  0.0080, -0.0004,  0.0054, -0.0459, -0.0396, -0.0232, -0.0287,
          0.0085,  0.0072, -0.0328, -0.0306, -0.0321, -0.0333, -0.0180, -0.0510,
          0.0090, -0.0080, -0.0520, -0.0330]])

In [15]:
class OneActiv(nn.Module):
    '''
    Mode:
    -softplus
    -relu
    -exp_1.6
    -exp_abs
    '''
    def __init__(self, input_dim, mode='softplus', beta_init=0):
        super().__init__()
        self.input_dim = input_dim
        self.beta = nn.Parameter(torch.ones(1, input_dim)*beta_init)
        self.func_mode = None
        if mode == "softplus":
            self.func_mode = self.func_softplus
        elif mode == "exp_1.6":
            self.func_mode = self.func_exp_16
        elif mode == "exp_abs":
            self.func_mode = self.func_exp_abs
        else:
            self.func_mode = self.func_relu
        pass
        
    def func_softplus(self, x):
        x = torch.exp(self.beta)*(x-1) + 1
        x = nn.functional.softplus(x, beta=6)
        return x
    
    def func_relu(self, x):
        x = torch.exp(self.beta)*(x-1) + 1
        x = torch.relu(x)
        return x
    
    def func_exp_16(self, x):
        x = torch.exp(-torch.exp(2*self.beta)*(torch.abs(x-1)**1.6))
        return x
        
    def func_exp_abs(self, x):
        x = torch.exp(-torch.exp(2*self.beta)*torch.abs(x-1))
        return x
    
    def forward(self, x):
        return self.func_mode(x)

In [16]:
#######################

In [17]:
model = nn.Sequential(
                dtnn.StereographicTransform(784, 785, normalize=True, bias=False),
#                 nn.LeakyReLU(),
#                 nn.BatchNorm1d(785),
                OneActiv(785),
                nn.Linear(785, 200),
                nn.BatchNorm1d(200),
                nn.LeakyReLU(),
#                 OneActiv(200),
                dtnn.StereographicTransform(200, 50),
#                 nn.BatchNorm1d(50),
#                 nn.LeakyReLU(),
                OneActiv(50),
                dtnn.StereographicTransform(50, 10),
                nn.BatchNorm1d(10),
#                 nn.Linear(10, 10),
            )
model.to(device)

torch.Size([785, 785])
torch.Size([50, 201])
torch.Size([10, 51])


Sequential(
  (0): StereographicTransform(
    (linear): Linear(in_features=784, out_features=785, bias=False)
  )
  (1): OneActiv()
  (2): Linear(in_features=785, out_features=200, bias=True)
  (3): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): LeakyReLU(negative_slope=0.01)
  (5): StereographicTransform(
    (linear): Linear(in_features=200, out_features=50, bias=True)
  )
  (6): OneActiv()
  (7): StereographicTransform(
    (linear): Linear(in_features=50, out_features=10, bias=True)
  )
  (8): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [18]:
# model[0].centers.requires_grad=False
# model[0].set_centroid_to_data_randomly(train_loader) ## this worked best for preserving locality
# model[0].set_centroid_to_data_maxdist(train_loader)

In [19]:
# center_lbl = model(model[0].centers.data)
# output_cent = torch.softmax(center_lbl, dim=1).argmax(dim=1).data.cpu()
# torch.unique(output_cent, return_counts=True)

In [20]:
# model = nn.Sequential(
#                 nn.Linear(784, 785),
#                 nn.BatchNorm1d(785),
#                 nn.LeakyReLU(),
#                 nn.Linear(785, 200),
#                 nn.BatchNorm1d(200),
#                 nn.LeakyReLU(),
#                 nn.Linear(200, 50),
#                 nn.BatchNorm1d(50),
#                 nn.LeakyReLU(),
#                 nn.Linear(50, 10),
#                 nn.BatchNorm1d(10)
#             )
# model.to(device)

In [21]:
# model[0].weight.data = dt.centers.data.clone().to(device)/85.0

In [22]:
optimizer = optim.Adam(list(model.parameters()), 
                            lr=0.001)
criterion = nn.CrossEntropyLoss()

In [23]:
index = 0
train_accs, test_accs = [], []
for epoch in tqdm(list(range(40))):
    model.train()
    train_acc = 0
    train_count = 0
    for xx, yy in train_loader:
        xx, yy = xx.to(device), yy.to(device)
        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0
    
#     if epoch%5 == 0:
#         print(f"Shifting the centroids to the nearest data point")
#         model[0].set_centroid_to_data(train_loader)

    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    model.eval()
    for xx, yy in test_loader:
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> MAX Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:0.47725823521614075


  2%|▎         | 1/40 [00:04<03:01,  4.66s/it]

Train Acc:74.03%, Test Acc:82.15%

Epoch: 1:0,  Loss:0.41984522342681885


  5%|▌         | 2/40 [00:09<02:56,  4.65s/it]

Train Acc:84.27%, Test Acc:84.17%

Epoch: 2:0,  Loss:0.40634337067604065


  8%|▊         | 3/40 [00:13<02:51,  4.63s/it]

Train Acc:85.60%, Test Acc:84.01%

Epoch: 3:0,  Loss:0.26421597599983215


 10%|█         | 4/40 [00:18<02:46,  4.62s/it]

Train Acc:86.36%, Test Acc:83.67%

Epoch: 4:0,  Loss:0.34115272760391235


 12%|█▎        | 5/40 [00:23<02:42,  4.64s/it]

Train Acc:87.00%, Test Acc:86.71%

Epoch: 5:0,  Loss:0.36509743332862854


 15%|█▌        | 6/40 [00:27<02:36,  4.60s/it]

Train Acc:87.36%, Test Acc:85.14%

Epoch: 6:0,  Loss:0.3546651601791382


 18%|█▊        | 7/40 [00:32<02:32,  4.63s/it]

Train Acc:87.58%, Test Acc:86.14%

Epoch: 7:0,  Loss:0.22431057691574097


 20%|██        | 8/40 [00:37<02:28,  4.63s/it]

Train Acc:87.98%, Test Acc:86.01%

Epoch: 8:0,  Loss:0.20444272458553314


 22%|██▎       | 9/40 [00:41<02:22,  4.60s/it]

Train Acc:88.02%, Test Acc:86.93%

Epoch: 9:0,  Loss:0.33615151047706604


 25%|██▌       | 10/40 [00:46<02:17,  4.59s/it]

Train Acc:88.40%, Test Acc:87.48%

Epoch: 10:0,  Loss:0.20713095366954803


 28%|██▊       | 11/40 [00:50<02:12,  4.58s/it]

Train Acc:88.59%, Test Acc:86.70%

Epoch: 11:0,  Loss:0.3200035095214844


 30%|███       | 12/40 [00:55<02:08,  4.59s/it]

Train Acc:88.72%, Test Acc:85.99%

Epoch: 12:0,  Loss:0.3161141872406006


 32%|███▎      | 13/40 [00:59<02:03,  4.59s/it]

Train Acc:88.78%, Test Acc:86.57%



 32%|███▎      | 13/40 [01:02<02:09,  4.78s/it]


KeyboardInterrupt: 

In [79]:
model[0].s.shape, model[0].t.shape

(torch.Size([785, 1]), torch.Size([785, 1]))

In [1190]:
print(f'\t-> MAX Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

	-> MAX Train Acc 97.31666666666666 ; Test Acc 89.82
