In [24]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [25]:
# Device selection. The second GPU ("cuda:1") was the original hard-coded choice;
# fall back to the first GPU, then to the CPU, so that `model.to(device)` does not
# fail on machines that have fewer than two CUDA devices.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Fashion-MNIST dataset

In [26]:
import mylibrary.datasets as datasets
import mylibrary.nnlib as tnn

In [27]:
# Load the Fashion-MNIST splits through the project-local dataset helper.
mnist = datasets.FashionMNIST()
# Disabled one-off setup calls, kept for reference:
#   mnist.download_mnist()
#   mnist.save_mnist()
(train_data, train_label_,
 test_data, test_label_) = mnist.load()

# Scale both splits by 1/255 (presumably mapping 8-bit pixel values into [0, 1]).
train_data, test_data = train_data / 255., test_data / 255.

# Disabled alternative label encoding, kept for reference:
#   train_label = tnn.Logits.index_to_logit(train_label_)

# Number of training labels (one per training example).
train_size = len(train_label_)

In [28]:
## converting data to pytorch format
# Explicit dtypes instead of the legacy torch.Tensor / torch.LongTensor constructors:
# float32 for the inputs and int64 class indices for the labels (the label dtype
# that nn.CrossEntropyLoss expects), independent of the global default tensor type.
train_data = torch.as_tensor(train_data, dtype=torch.float32)
test_data = torch.as_tensor(test_data, dtype=torch.float32)
train_label = torch.as_tensor(train_label_, dtype=torch.long)
test_label = torch.as_tensor(test_label_, dtype=torch.long)

In [29]:
# Feature count per input sample (matches the 784 in-features of the model's
# first layer) and number of class scores the network emits.
input_size, output_size = 784, 10

In [30]:
class MNIST_Dataset(data.Dataset):
    """Map-style dataset that pairs each sample with its label.

    Args:
        data: indexable, sized collection of samples (e.g. a tensor whose
            first dimension enumerates the samples).
        label: indexable, sized collection of labels, one per sample.

    Raises:
        ValueError: if ``data`` and ``label`` do not have the same length.
    """

    def __init__(self, data, label):
        # Fail early: __len__ is driven by `data` alone, so a length mismatch
        # would otherwise surface as an IndexError in the middle of an epoch
        # (or silently ignore trailing labels).
        if len(data) != len(label):
            raise ValueError(
                f"data and label must have the same length, "
                f"got {len(data)} and {len(label)}"
            )
        self.data = data
        self.label = label

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.label[idx]

In [31]:
# Wrap each split's tensors so a DataLoader can index (sample, label) pairs.
train_dataset, test_dataset = (
    MNIST_Dataset(data=train_data, label=train_label),
    MNIST_Dataset(data=test_data, label=test_label),
)

In [32]:
batch_size = 50

# Training batches: reshuffled every epoch, fetched by 4 worker processes.
train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                               shuffle=True, num_workers=4)

# Evaluation batches: dataset order preserved, fetched by 1 worker process.
test_loader = data.DataLoader(test_dataset, batch_size=batch_size,
                              shuffle=False, num_workers=1)

In [33]:
import dtnnlib as dtnn

In [34]:
# Stand-alone instance of the project-local StereographicTransform, built with the
# arguments (784, 20); the following cells feed it random input as a quick check.
# Construction prints a size of [20, 785] (see the cell output below).
dt = dtnn.StereographicTransform(784, 20)

torch.Size([20, 785])


In [35]:
# Apply the transform to a random batch: 2 rows of 784 standard-normal values.
dists = dt(torch.randn(2, 784))

In [36]:
# Shape check: one row per input sample and 20 values per row (see output below).
dists.shape

torch.Size([2, 20])

In [37]:
# Show the raw output values. detach() yields the same values without autograd
# history and is the supported replacement for the legacy `.data` attribute.
dists.detach()

tensor([[-0.0002, -0.0383, -0.0249, -0.0016, -0.0291, -0.0347, -0.0024, -0.0139,
         -0.0152, -0.0312, -0.0107, -0.0360,  0.0024,  0.0061, -0.0171, -0.0350,
          0.0039, -0.0420, -0.0239,  0.0109],
        [ 0.0450,  0.0140,  0.0221,  0.0483,  0.0219,  0.0168,  0.0464,  0.0408,
          0.0349,  0.0213,  0.0454,  0.0130,  0.0523,  0.0550,  0.0294,  0.0167,
          0.0467,  0.0095,  0.0257,  0.0580]])

In [38]:
class OneActiv(nn.Module):
    '''
    Element-wise activation built around the input value 1, with a learnable
    per-feature log-scale ``beta``.

    Mode (formula applied to the input ``x``):
    - "softplus": softplus(exp(beta)*(x-1) + 1, beta=6)
    - "relu":     relu(exp(beta)*(x-1) + 1)
    - "exp_1.6":  exp(-exp(2*beta) * |x-1|**1.6)
    - "exp_abs":  exp(-exp(2*beta) * |x-1|)

    Args:
        input_dim: number of features. ``beta`` has shape (1, input_dim), so it
            broadcasts against inputs whose last dimension is ``input_dim``.
        mode: one of the mode names above. Any other value falls back to
            "relu" (the historical behaviour) and emits a ``UserWarning`` so
            that a mistyped mode name no longer goes unnoticed.
        beta_init: initial value assigned to every entry of ``beta``.
    '''
    def __init__(self, input_dim, mode='softplus', beta_init=0):
        super().__init__()
        self.input_dim = input_dim
        self.beta = nn.Parameter(torch.ones(1, input_dim)*beta_init)

        dispatch = {
            "softplus": self.func_softplus,
            "relu": self.func_relu,
            "exp_1.6": self.func_exp_16,
            "exp_abs": self.func_exp_abs,
        }
        # isinstance guard keeps non-string (possibly unhashable) values on the
        # fallback path instead of raising inside the dict lookup.
        chosen = dispatch.get(mode) if isinstance(mode, str) else None
        if chosen is None:
            import warnings
            warnings.warn(
                f"OneActiv: unknown mode {mode!r}; falling back to 'relu'. "
                f"Valid modes: {sorted(dispatch)}",
                stacklevel=2,
            )
            chosen = self.func_relu
        self.func_mode = chosen

    def _rescale(self, x):
        # Scale the offset from 1 by exp(beta); x == 1 is mapped to 1 for any beta.
        return torch.exp(self.beta)*(x-1) + 1

    def func_softplus(self, x):
        '''Softplus (sharpness beta=6) of the rescaled input.'''
        return nn.functional.softplus(self._rescale(x), beta=6)

    def func_relu(self, x):
        '''ReLU of the rescaled input.'''
        return torch.relu(self._rescale(x))

    def func_exp_16(self, x):
        '''exp(-exp(2*beta) * |x-1|**1.6): equals 1 at x == 1 and decays away from it.'''
        return torch.exp(-torch.exp(2*self.beta)*(torch.abs(x-1)**1.6))

    def func_exp_abs(self, x):
        '''exp(-exp(2*beta) * |x-1|): equals 1 at x == 1 and decays away from it.'''
        return torch.exp(-torch.exp(2*self.beta)*torch.abs(x-1))

    def forward(self, x):
        '''Apply the activation selected by ``mode`` at construction time.'''
        return self.func_mode(x)

In [39]:
#######################

In [76]:
# Classifier, in execution order:
#   StereographicTransform(784, 785) -> OneActiv(785) -> Linear(785, 200) -> LeakyReLU
#   -> StereographicTransform(200, 50) -> OneActiv(50) -> Linear(50, 10)
# Disabled variants from earlier experiments, kept for reference:
#   - nn.LeakyReLU() / nn.BatchNorm1d(785) right after the first transform
#   - nn.BatchNorm1d(200) after the first Linear, OneActiv(200) in place of LeakyReLU
#   - nn.BatchNorm1d(50) / nn.LeakyReLU() right after the second transform
#   - dtnn.StereographicTransform(50, 10) as the output layer, nn.BatchNorm1d(10) after it
model = nn.Sequential(
    dtnn.StereographicTransform(784, 785),
    OneActiv(785),
    nn.Linear(785, 200),
    nn.LeakyReLU(),
    dtnn.StereographicTransform(200, 50),
    OneActiv(50),
    nn.Linear(50, 10),
).to(device)  # Module.to() moves the parameters in place and returns the module
model

torch.Size([785, 785])
torch.Size([50, 201])


Sequential(
  (0): StereographicTransform(
    (linear): Linear(in_features=784, out_features=785, bias=True)
  )
  (1): OneActiv()
  (2): Linear(in_features=785, out_features=200, bias=True)
  (3): LeakyReLU(negative_slope=0.01)
  (4): StereographicTransform(
    (linear): Linear(in_features=200, out_features=50, bias=True)
  )
  (5): OneActiv()
  (6): Linear(in_features=50, out_features=10, bias=True)
)

In [77]:
# model[0].centers.requires_grad=False
# model[0].set_centroid_to_data_randomly(train_loader) ## this worked best for preserving locality
# model[0].set_centroid_to_data_maxdist(train_loader)

In [78]:
# center_lbl = model(model[0].centers.data)
# output_cent = torch.softmax(center_lbl, dim=1).argmax(dim=1).data.cpu()
# torch.unique(output_cent, return_counts=True)

In [79]:
# model = nn.Sequential(
#                 nn.Linear(784, 785),
#                 nn.BatchNorm1d(785),
#                 nn.LeakyReLU(),
#                 nn.Linear(785, 200),
#                 nn.BatchNorm1d(200),
#                 nn.LeakyReLU(),
#                 nn.Linear(200, 50),
#                 nn.BatchNorm1d(50),
#                 nn.LeakyReLU(),
#                 nn.Linear(50, 10),
#                 nn.BatchNorm1d(10)
#             )
# model.to(device)

In [80]:
# model[0].weight.data = dt.centers.data.clone().to(device)/85.0

In [81]:
# Cross-entropy between the model's raw class scores and the integer class labels.
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the model, learning rate 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [82]:
# Train for 40 epochs, recording train/test accuracy (in percent) after each one.
index = 0  # constant; only echoed in the per-epoch log line
train_accs, test_accs = [], []
for epoch in tqdm(range(40)):
    # ---- training pass ----
    model.train()
    train_acc = 0
    train_count = 0
    for xx, yy in train_loader:
        xx, yy = xx.to(device), yy.to(device)
        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count correct predictions on the device; .item() transfers one scalar
        # instead of copying the whole prediction and label batches to NumPy.
        train_acc += (yout.argmax(dim=1) == yy).sum().item()
        train_count += yy.size(0)

    train_accs.append(train_acc/train_count*100)

#     if epoch%5 == 0:
#         print(f"Shifting the centroids to the nearest data point")
#         model[0].set_centroid_to_data(train_loader)

    # NOTE: the reported loss is that of the last training batch, not an epoch mean.
    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')

    # ---- evaluation pass ----
    model.eval()
    test_count = 0
    test_acc = 0
    with torch.no_grad():  # no autograd graph is needed while evaluating
        for xx, yy in test_loader:
            xx, yy = xx.to(device), yy.to(device)
            yout = model(xx)
            test_acc += (yout.argmax(dim=1) == yy).sum().item()
            test_count += yy.size(0)
    test_accs.append(test_acc/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

# Best per-epoch accuracies once all epochs have completed.
print(f'\t-> MAX Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:0.7326363325119019


  2%|▎         | 1/40 [00:03<02:27,  3.78s/it]

Train Acc:43.71%, Test Acc:79.11%

Epoch: 1:0,  Loss:0.43709030747413635


  5%|▌         | 2/40 [00:07<02:29,  3.95s/it]

Train Acc:83.03%, Test Acc:83.85%

Epoch: 2:0,  Loss:0.264369934797287


  8%|▊         | 3/40 [00:11<02:24,  3.91s/it]

Train Acc:85.99%, Test Acc:84.86%

Epoch: 3:0,  Loss:0.3002038896083832


 10%|█         | 4/40 [00:15<02:21,  3.92s/it]

Train Acc:87.44%, Test Acc:85.64%

Epoch: 4:0,  Loss:0.24726207554340363


 12%|█▎        | 5/40 [00:19<02:17,  3.92s/it]

Train Acc:88.21%, Test Acc:85.93%

Epoch: 5:0,  Loss:0.26405957341194153


 15%|█▌        | 6/40 [00:23<02:12,  3.90s/it]

Train Acc:88.70%, Test Acc:86.77%

Epoch: 6:0,  Loss:0.24867944419384003


 18%|█▊        | 7/40 [00:27<02:09,  3.91s/it]

Train Acc:89.31%, Test Acc:87.28%

Epoch: 7:0,  Loss:0.2721996009349823


 20%|██        | 8/40 [00:31<02:04,  3.90s/it]

Train Acc:89.68%, Test Acc:86.99%

Epoch: 8:0,  Loss:0.32429009675979614


 22%|██▎       | 9/40 [00:35<02:00,  3.89s/it]

Train Acc:90.20%, Test Acc:88.05%

Epoch: 9:0,  Loss:0.35445114970207214


 25%|██▌       | 10/40 [00:38<01:56,  3.88s/it]

Train Acc:90.63%, Test Acc:88.09%

Epoch: 10:0,  Loss:0.2983971834182739


 28%|██▊       | 11/40 [00:42<01:52,  3.89s/it]

Train Acc:90.77%, Test Acc:88.22%

Epoch: 11:0,  Loss:0.061349235475063324


 30%|███       | 12/40 [00:46<01:49,  3.92s/it]

Train Acc:90.99%, Test Acc:88.38%

Epoch: 12:0,  Loss:0.23819446563720703


 32%|███▎      | 13/40 [00:50<01:45,  3.91s/it]

Train Acc:91.48%, Test Acc:88.70%

Epoch: 13:0,  Loss:0.3847692608833313


 35%|███▌      | 14/40 [00:54<01:41,  3.91s/it]

Train Acc:91.66%, Test Acc:88.99%

Epoch: 14:0,  Loss:0.22984769940376282


 38%|███▊      | 15/40 [00:58<01:37,  3.89s/it]

Train Acc:92.09%, Test Acc:89.17%

Epoch: 15:0,  Loss:0.19259808957576752


 40%|████      | 16/40 [01:02<01:33,  3.88s/it]

Train Acc:92.41%, Test Acc:88.45%

Epoch: 16:0,  Loss:0.281254380941391


 42%|████▎     | 17/40 [01:06<01:29,  3.90s/it]

Train Acc:92.62%, Test Acc:88.53%

Epoch: 17:0,  Loss:0.12399900704622269


 45%|████▌     | 18/40 [01:10<01:25,  3.89s/it]

Train Acc:92.93%, Test Acc:89.00%

Epoch: 18:0,  Loss:0.1841464787721634


 48%|████▊     | 19/40 [01:14<01:21,  3.87s/it]

Train Acc:93.30%, Test Acc:89.15%

Epoch: 19:0,  Loss:0.1622452288866043


 50%|█████     | 20/40 [01:17<01:17,  3.88s/it]

Train Acc:93.39%, Test Acc:88.58%

Epoch: 20:0,  Loss:0.2504052519798279


 52%|█████▎    | 21/40 [01:21<01:13,  3.88s/it]

Train Acc:93.73%, Test Acc:89.21%

Epoch: 21:0,  Loss:0.19857482612133026


 55%|█████▌    | 22/40 [01:25<01:08,  3.83s/it]

Train Acc:93.96%, Test Acc:88.77%



 55%|█████▌    | 22/40 [01:27<01:11,  4.00s/it]


KeyboardInterrupt: 

In [79]:
# Inspect the shapes of the `s` and `t` attributes of the model's first layer
# (the StereographicTransform at index 0 of the Sequential).
model[0].s.shape, model[0].t.shape

(torch.Size([785, 1]), torch.Size([785, 1]))

In [1190]:
# Report the best per-epoch accuracies recorded so far; this is the same summary
# line the training cell prints when it runs to completion.
print(f'\t-> MAX Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

	-> MAX Train Acc 97.31666666666666 ; Test Acc 89.82
