In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

In [2]:
device = torch.device("cuda:1")
# device = torch.device("cpu")

## MNIST dataset

In [3]:
import mylibrary.datasets as datasets
import mylibrary.nnlib as tnn

In [4]:
mnist = datasets.FashionMNIST()
# mnist.download_mnist()
# mnist.save_mnist()
train_data, train_label_, test_data, test_label_ = mnist.load()

train_data = train_data / 255.
test_data = test_data / 255.

# train_label = tnn.Logits.index_to_logit(train_label_)
train_size = len(train_label_)

In [5]:
## converting data to pytorch format
train_data = torch.Tensor(train_data)
test_data = torch.Tensor(test_data)
train_label = torch.LongTensor(train_label_)
test_label = torch.LongTensor(test_label_)

In [6]:
input_size = 784
output_size = 10

In [7]:
class MNIST_Dataset(data.Dataset):
    
    def __init__(self, data, label):
        self.data = data
        self.label = label
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
#         print(idx)
        img, lbl = self.data[idx], self.label[idx]
        return img, lbl

In [8]:
train_dataset = MNIST_Dataset(train_data, train_label)
test_dataset = MNIST_Dataset(test_data, test_label)

In [9]:
batch_size = 50
train_loader = data.DataLoader(dataset=train_dataset,
                                    num_workers=4, 
                                    batch_size=batch_size, 
                                    shuffle=True)

test_loader = data.DataLoader(dataset=test_dataset,
                                    num_workers=1, 
                                    batch_size=batch_size, 
                                    shuffle=False)

In [157]:
class StereographicTransform(nn.Module):
    
    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.inp_scaler = nn.Parameter(torch.Tensor([1/np.sqrt(self.input_dim)]))
        self.linear = nn.Linear(input_dim+1, output_dim, bias=bias)

    ## https://github.com/pswkiki/SphereGAN/blob/master/model_sphere_gan.py
    def forward(self, x):
#         self.linear.weight.data /= self.linear.weight.data.norm(p=2, dim=1, keepdim=True)
        ### linear has weight -> (outdim, indim) format, so normalizing per output dimension
#         print(self.linear.weight.data.norm(dim=1))
        
        x = x*self.inp_scaler
        sqnorm = (x**2).sum(dim=1, keepdim=True) ## l2 norm squared
        x = x*2/(sqnorm+1)
        new_dim = (sqnorm-1)/(sqnorm+1)
        x = torch.cat((x, new_dim), dim=1)
        x = self.linear(x)
        return x

In [158]:
st = StereographicTransform(784, 20)

In [159]:
dists = st(torch.randn(2, 784))

In [160]:
dists.shape

torch.Size([2, 20])

In [161]:
dists.norm(dim=1)

tensor([0.1274, 0.1317], grad_fn=<CopyBackwards>)

In [162]:
# dists[:, -1]
st.linear.weight.data.norm(dim=1)

tensor([0.5782, 0.5793, 0.5832, 0.5738, 0.5695, 0.5617, 0.5714, 0.5817, 0.5638,
        0.5788, 0.5761, 0.5703, 0.5660, 0.5754, 0.5762, 0.5644, 0.5762, 0.5600,
        0.5714, 0.5678])

In [183]:
model = nn.Sequential(
                StereographicTransform(784, 785),
                nn.BatchNorm1d(785),
                nn.LeakyReLU(),
                StereographicTransform(785, 200),
                nn.BatchNorm1d(200),
                nn.LeakyReLU(),
                StereographicTransform(200, 50),
                nn.BatchNorm1d(50),
                nn.LeakyReLU(),
                StereographicTransform(50, 10),
                nn.BatchNorm1d(10),
#                 nn.Linear(10, 10),
            )
model.to(device)

Sequential(
  (0): StereographicTransform(
    (linear): Linear(in_features=785, out_features=785, bias=True)
  )
  (1): StereographicTransform(
    (linear): Linear(in_features=786, out_features=200, bias=True)
  )
  (2): StereographicTransform(
    (linear): Linear(in_features=201, out_features=50, bias=True)
  )
  (3): StereographicTransform(
    (linear): Linear(in_features=51, out_features=10, bias=True)
  )
  (4): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [184]:
# model = nn.Sequential(
#                 nn.Linear(784, 785),
#                 nn.BatchNorm1d(785),
#                 nn.LeakyReLU(),
#                 nn.Linear(785, 200),
#                 nn.BatchNorm1d(200),
#                 nn.LeakyReLU(),
#                 nn.Linear(200, 50),
#                 nn.BatchNorm1d(50),
#                 nn.LeakyReLU(),
#                 nn.Linear(50, 10),
#                 nn.BatchNorm1d(10)
#             )
# model.to(device)

In [185]:
# model[0].weight.data = dt.centers.data.clone().to(device)/85.0

In [186]:
optimizer = optim.Adam(list(model.parameters()), 
                            lr=0.001)
criterion = nn.CrossEntropyLoss()

In [187]:
index = 0
train_accs, test_accs = [], []
for epoch in tqdm(list(range(40))):
    model.train()
    train_acc = 0
    train_count = 0
    for xx, yy in train_loader:
        xx, yy = xx.to(device), yy.to(device)
        yout = model(xx)
        loss = criterion(yout, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        train_acc += correct
        train_count += len(outputs)

    train_accs.append(float(train_acc)/train_count*100)
    train_acc = 0
    train_count = 0
    
#     if epoch%5 == 0:
#         print(f"Shifting the centroids to the nearest data point")
#         model[0].set_centroid_to_data(train_loader)

    print(f'Epoch: {epoch}:{index},  Loss:{float(loss)}')
    test_count = 0
    test_acc = 0
    model.eval()
    for xx, yy in test_loader:
        xx, yy = xx.to(device), yy.to(device)
        with torch.no_grad():
            yout = model(xx)
        outputs = torch.argmax(yout, dim=1).data.cpu().numpy()
        correct = (outputs == yy.data.cpu().numpy()).astype(float).sum()
        test_acc += correct
        test_count += len(xx)
    test_accs.append(float(test_acc)/test_count*100)
    print(f'Train Acc:{train_accs[-1]:.2f}%, Test Acc:{test_accs[-1]:.2f}%')
    print()

### after each class index is finished training
print(f'\t-> MAX Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 0:0,  Loss:0.5743033289909363


  2%|▎         | 1/40 [00:04<02:55,  4.49s/it]

Train Acc:80.33%, Test Acc:82.86%

Epoch: 1:0,  Loss:0.4889901876449585


  5%|▌         | 2/40 [00:08<02:50,  4.50s/it]

Train Acc:83.56%, Test Acc:84.00%

Epoch: 2:0,  Loss:0.5416829586029053


  8%|▊         | 3/40 [00:13<02:46,  4.50s/it]

Train Acc:84.13%, Test Acc:83.74%

Epoch: 3:0,  Loss:0.3761681020259857


 10%|█         | 4/40 [00:17<02:41,  4.48s/it]

Train Acc:84.62%, Test Acc:83.11%

Epoch: 4:0,  Loss:0.5676597356796265


 12%|█▎        | 5/40 [00:22<02:36,  4.48s/it]

Train Acc:84.96%, Test Acc:84.25%

Epoch: 5:0,  Loss:0.5185732841491699


 15%|█▌        | 6/40 [00:26<02:31,  4.46s/it]

Train Acc:84.89%, Test Acc:84.36%

Epoch: 6:0,  Loss:0.31887975335121155


 18%|█▊        | 7/40 [00:31<02:27,  4.46s/it]

Train Acc:85.00%, Test Acc:84.05%

Epoch: 7:0,  Loss:0.4555157423019409


 20%|██        | 8/40 [00:35<02:23,  4.47s/it]

Train Acc:85.53%, Test Acc:85.05%

Epoch: 8:0,  Loss:0.5366398692131042


 22%|██▎       | 9/40 [00:40<02:18,  4.48s/it]

Train Acc:85.72%, Test Acc:84.76%

Epoch: 9:0,  Loss:0.41289252042770386


 25%|██▌       | 10/40 [00:44<02:13,  4.45s/it]

Train Acc:85.97%, Test Acc:85.06%

Epoch: 10:0,  Loss:0.30593663454055786


 28%|██▊       | 11/40 [00:49<02:09,  4.46s/it]

Train Acc:86.15%, Test Acc:84.46%

Epoch: 11:0,  Loss:0.3024251163005829


 30%|███       | 12/40 [00:53<02:04,  4.46s/it]

Train Acc:86.36%, Test Acc:85.45%

Epoch: 12:0,  Loss:0.4340760111808777


 32%|███▎      | 13/40 [00:58<01:59,  4.44s/it]

Train Acc:86.71%, Test Acc:85.29%

Epoch: 13:0,  Loss:0.3820713460445404


 35%|███▌      | 14/40 [01:02<01:55,  4.44s/it]

Train Acc:86.77%, Test Acc:85.86%

Epoch: 14:0,  Loss:0.5808930993080139


 38%|███▊      | 15/40 [01:06<01:50,  4.44s/it]

Train Acc:86.89%, Test Acc:85.52%

Epoch: 15:0,  Loss:0.35580500960350037


 40%|████      | 16/40 [01:11<01:46,  4.44s/it]

Train Acc:87.24%, Test Acc:85.96%



 40%|████      | 16/40 [01:14<01:52,  4.68s/it]


KeyboardInterrupt: 

In [None]:
print(f'\t-> MAX Train Acc {max(train_accs)} ; Test Acc {max(test_accs)}')

In [30]:
torch.save(model.state_dict(), "./models/temp_03_1_model_dec1_v0.pth")

# model.load_state_dict(torch.load("./temp_01_2_model_nov26.pth", map_location=device))