In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
# from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms as T

import random, os, pathlib, time
from tqdm import tqdm
# from sklearn import datasets

In [2]:
# device = torch.device("cuda:0")
device = torch.device("cpu")

In [3]:
from tqdm import tqdm
import os, time, sys
import json

In [4]:
import dtnnlib as dtnn

In [5]:
mnist_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(
        mean=[0.5,],
        std=[0.5,],
    ),
])

train_dataset = datasets.FashionMNIST(root="../../../../_Datasets/", train=True, download=True, transform=mnist_transform)
test_dataset = datasets.FashionMNIST(root="../../../../_Datasets/", train=False, download=True, transform=mnist_transform)

# train_dataset = datasets.MNIST(root="../../../../_Datasets/", train=True, download=True, transform=mnist_transform)
# test_dataset = datasets.MNIST(root="../../../../_Datasets/", train=False, download=True, transform=mnist_transform)

In [6]:
batch_size = 50
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [7]:
for xx, yy in train_loader:
    xx, yy = xx.to(device), yy.to(device)
    print(xx.shape, yy.shape)
    break

torch.Size([50, 1, 28, 28]) torch.Size([50])


### Single Layer overfitting

In [8]:
h1 = 200
model = nn.Sequential(
            dtnn.DistanceTransform_MinExp(784, h1),
#             dtnn.DistanceTransform_Exp(784, h1),
#             nn.BatchNorm1d(h1),
#             nn.LeakyReLU(),
            nn.Linear(h1, 10))

In [9]:
model.to(device)

Sequential(
  (0): DistanceTransform_MinExp()
  (1): Linear(in_features=200, out_features=10, bias=True)
)

In [10]:
# model[0].set_centroid_to_data_maxdist(train_loader)
# model[0].set_centroid_to_data(train_loader)
# model[0].set_centroid_to_data_randomly(train_loader)

## Randomly

In [11]:
N = model[0].centers.shape[0]
new_center = []
new_labels = []
count = 0
for i, (xx, yy) in enumerate(train_loader):
    xx = xx.reshape(-1, model[0].input_dim).to(model[0].centers.device)
    if count+xx.shape[0] < N:
        new_center.append(xx)
        new_labels.append(yy)
        count += xx.shape[0]
    elif count >= N:
        break
    else:
        new_center.append(xx[:N-count])
        new_labels.append(yy[:N-count])
        count = N
        break
        
new_center = torch.cat(new_center, dim=0)
new_labels = torch.cat(new_labels, dim=0)

## Maxdist

In [12]:
# epoch = 0.2

In [13]:
# N = model[0].centers.shape[0]
# new_center = torch.empty_like(model[0].centers)
# new_labels = torch.empty(model[0].num_centers, dtype=torch.long)

# min_dists = torch.empty(N)
# count = 0
# steps = int(epoch*len(train_loader))
# for i, (xx, yy) in enumerate(tqdm(train_loader)):
#     if i > steps: break

#     xx = xx.reshape(-1, model[0].input_dim).to(model[0].centers.device)
#     if count < N:
#         if N-count < train_loader.batch_size:
#             #### final fillup
#             new_center[count:count+N-count] = xx[:N-count]
#             xx = xx[N-count:]
#             yy = yy[N-count:]
#             dists = torch.cdist(new_center, new_center)+torch.eye(N).to(model[0].centers.device)*1e5
#             min_dists = dists.min(dim=0)[0]
#             count = N

#         else:#### fill the center
#             new_center[count:count+len(xx)] = xx
#             new_labels[count:count+len(xx)] = yy
#             count += len(xx)
#             continue

#     ammd = min_dists.argmin()
#     for i, x in enumerate(xx):
#         dists = torch.norm(new_center-x, dim=1)
#         md = dists.min()
#         if md > min_dists[ammd]:
#             min_dists[ammd] = md
#             new_center[ammd] = x
#             new_labels[ammd] = yy[i]
#             ammd = min_dists.argmin()
            
# # self.centers.data = new_center.to(self.centers.device)

In [14]:
new_center.shape, new_labels.shape

(torch.Size([200, 784]), torch.Size([200]))

## Set data as parameters


In [15]:
weights = torch.zeros(len(new_labels), 10)
for i in range(len(new_labels)):
    weights[i, new_labels[i]] = 1.
# weights

In [16]:
weights.shape

torch.Size([200, 10])

In [17]:
model[0].centers.data = new_center.to(model[0].centers.device)
model[-1].weight.data = weights.t().to(model[-1].weight.data)

In [18]:
best_acc = -1
def test(epoch, model):
    global best_acc
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
#         for batch_idx, (inputs, targets) in enumerate(tqdm(test_loader)):
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device).view(-1, 28*28), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    print(f"[Test] {epoch} Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f} {correct}/{total}")
    
    acc = 100.*correct/total
    return acc

In [19]:
model.eval()

Sequential(
  (0): DistanceTransform_MinExp()
  (1): Linear(in_features=200, out_features=10, bias=True)
)

In [20]:
criterion = nn.CrossEntropyLoss()

In [21]:
test_acc = test(0, model)
test_acc

[Test] 0 Loss: 1.215 | Acc: 66.930 6693/10000


66.93

# Test MultiLayer data init

In [22]:
h2 = 100
model = nn.Sequential(
            dtnn.DistanceTransform_MinExp(784, h1),
            dtnn.DistanceTransform_MinExp(h1, h2),
            nn.Linear(h2, 10))

In [23]:
model.eval()

Sequential(
  (0): DistanceTransform_MinExp()
  (1): DistanceTransform_MinExp()
  (2): Linear(in_features=100, out_features=10, bias=True)
)

### Randomly select samples

In [24]:
N = h2
new_center2 = []
new_labels2 = []
count = 0
for i, (xx, yy) in enumerate(train_loader):
    xx = xx.reshape(-1, model[0].input_dim).to(model[0].centers.device)
    if count+xx.shape[0] < N:
        new_center2.append(xx)
        new_labels2.append(yy)
        count += xx.shape[0]
    elif count >= N:
        break
    else:
        new_center2.append(xx[:N-count])
        new_labels2.append(yy[:N-count])
        count = N
        break
        
new_center2 = torch.cat(new_center2, dim=0)
new_labels2 = torch.cat(new_labels2, dim=0)

In [25]:
new_center.shape, new_center2.shape

(torch.Size([200, 784]), torch.Size([100, 784]))

In [26]:
# if len(new_center) < h2:
#     raise Exception("The function below does not support increasing hidden units ..")

#### Layer 1

In [27]:
model[0].centers.data = new_center.to(model[0].centers.device)

#### Layer 2

In [28]:
xx_ = model[0].centers.data
yy_ = new_labels
if h2>h1:
    xx_ = torch.cat([new_center, new_center2[:h2-h1]], dim=0)
    yy_ = torch.cat([new_labels, new_labels2[:h2-h1]], dim=0)

a1 = model[0](xx_)
a1.shape

torch.Size([200, 200])

In [29]:
model[1].centers.shape

torch.Size([100, 200])

In [30]:
model[1].centers.data = a1.data[:h2].clone()

#### Layer 3 (Final layer)

In [31]:
weights = torch.zeros(h2, 10)
for i in range(h2): ## not all activations(of centers) fit into neurons in layer 2
    weights[i, yy_[i]] = 1.
# weights

In [32]:
weights.shape

torch.Size([100, 10])

In [33]:
model[-1].weight.shape

torch.Size([10, 100])

In [34]:
model[-1].weight.data = weights.t().to(model[-1].weight.data)

In [35]:
model(xx).shape

torch.Size([50, 10])

#### Test performance

In [36]:
test_acc2 = test(0, model)
test_acc2

[Test] 0 Loss: 1.817 | Acc: 59.800 5980/10000


59.8

In [37]:
model[1].centers.shape

torch.Size([100, 200])

## Initialize second layer with different sets of samples

In [38]:
# N = h2
# new_center2 = []
# new_labels2 = []
# count = 0
# for i, (xx, yy) in enumerate(train_loader):
#     xx = xx.reshape(-1, model[0].input_dim).to(model[0].centers.device)
#     if count+xx.shape[0] < N:
#         new_center2.append(xx)
#         new_labels2.append(yy)
#         count += xx.shape[0]
#     elif count >= N:
#         break
#     else:
#         new_center2.append(xx[:N-count])
#         new_labels2.append(yy[:N-count])
#         count = N
#         break
        
# new_center2 = torch.cat(new_center2, dim=0)
# new_labels2 = torch.cat(new_labels2, dim=0)

In [39]:
new_center2.shape

torch.Size([100, 784])

#### Layer 1

In [40]:
model[0].centers.data = new_center.to(model[0].centers.device)

In [41]:
a1 = model[0](new_center2)
a1.shape

torch.Size([100, 200])

#### Layer 2

In [42]:
model[1].centers.shape

torch.Size([100, 200])

In [43]:
model[1].centers.data = a1.data[:h2].clone()

#### Layer 3 (Final layer)

In [44]:
weights = torch.zeros(h2, 10)
for i in range(h2): ## not all activations(of centers) fit into neurons in layer 2
    weights[i, new_labels2[i]] = 1.
# weights

In [45]:
model[-1].weight.data = weights.t().to(model[-1].weight.data)

#### Test performance

In [46]:
test_acc3 = test(0, model)
test_acc3

[Test] 0 Loss: 1.582 | Acc: 61.030 6103/10000


61.03

In [47]:
test_acc, test_acc2, test_acc3

(66.93, 59.8, 61.03)

## Benchmark with stats

In [48]:
# HIDDEN_UNITS = [10, 50, 200, 1000, 5000, 20000]

# seed = 2023
# np.random.seed(seed)
# SEEDS = np.random.randint(0, high=9999, size=20)
# SEEDS

In [49]:
# test_accuracy = {h:[] for h in HIDDEN_UNITS}
# test_accuracy

In [50]:
# def get_centers_and_labels(data_loader, N):
#     new_center = []
#     new_labels = []
#     count = 0
#     for i, (xx, yy) in enumerate(data_loader):
#         xx = xx.reshape(-1, model[0].input_dim).to(model[0].centers.device)
#         if count+xx.shape[0] < N:
#             new_center.append(xx)
#             new_labels.append(yy)
#             count += xx.shape[0]
#         elif count >= N:
#             break
#         else:
#             new_center.append(xx[:N-count])
#             new_labels.append(yy[:N-count])
#             count = N
#             break

#     new_center = torch.cat(new_center, dim=0)
#     new_labels = torch.cat(new_labels, dim=0)
#     return new_center, new_labels

In [51]:
# batch_size = 50
# for h in HIDDEN_UNITS:
#     print(f"Experiment for Hidden units: {h}")
#     for seed in tqdm(SEEDS):
#         seed = int(seed)
#         torch.manual_seed(seed)
#         np.random.seed(seed)
#         random.seed(seed)
        
#         train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
#         test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)
        
#         model = nn.Sequential(
#                     dtnn.DistanceTransform_MinExp(784, h),
#                     nn.LeakyReLU(),
#                     nn.Linear(h, 10)).to(device)
        
#         new_center, new_labels = get_centers_and_labels(train_loader, h)
#         weights = torch.zeros(len(new_labels), 10)
#         for i in range(len(new_labels)):
#             weights[i, new_labels[i]] = 1.
            
#         model[0].centers.data = new_center.to(model[0].centers.device)
#         model[-1].weight.data = weights.t().to(model[-1].weight.data)
#         model.eval()
        
#         test_acc = test(0, model)
#         test_accuracy[h].append(test_acc)

In [52]:
# print(f"H \tMean \t\tSTD \tMAX")
# for k, v in test_accuracy.items():
# #     print(k, v)
#     print(f"{k} \t{np.mean(v):.4f} \t{np.std(v):.4f} \t{np.max(v)}")

In [None]:
## Simplified init for 