In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
# from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms as T

import random, os, pathlib, time
from tqdm import tqdm
# from sklearn import datasets

In [2]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs end-to-end on machines without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
from tqdm import tqdm
import os, time, sys
import json

In [4]:
import dtnnlib as dtnn

In [164]:
# Convert images to tensors, then normalize the single channel with mean 0.5 / std 0.5.
dataset_root = "../../../../_Datasets/"
mnist_transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.5,], std=[0.5,])])

# To run the same experiment on Fashion-MNIST, swap datasets.MNIST for
# datasets.FashionMNIST in the two lines below (same arguments).
train_dataset = datasets.MNIST(root=dataset_root, train=True, download=True, transform=mnist_transform)
test_dataset = datasets.MNIST(root=dataset_root, train=False, download=True, transform=mnist_transform)

In [165]:
# Mini-batch loaders: training batches are shuffled, test batches keep dataset order.
batch_size = 50
loader_kwargs = dict(num_workers=4, batch_size=batch_size)
train_loader = data.DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_loader = data.DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [166]:
# Sanity check: take the first training batch, move it to the device and show its shapes.
for batch_images, batch_targets in train_loader:
    xx = batch_images.to(device)
    yy = batch_targets.to(device)
    print(xx.shape, yy.shape)
    break

torch.Size([50, 1, 28, 28]) torch.Size([50])


In [188]:
# h: output size of the distance-transform layer and input size of the linear classifier.
h = 1000
layers = [
    # 784 = flattened 28x28 input. Alternative layer tried here: dtnn.DistanceTransform_Exp(784, h)
    dtnn.DistanceTransform_MinExp(784, h),
    nn.LeakyReLU(),
    nn.Linear(h, 10),
]
model = nn.Sequential(*layers)

In [189]:
# Move the model to the selected device; the returned module is echoed as the cell output.
model.to(device)

Sequential(
  (0): DistanceTransform_MinExp()
  (1): LeakyReLU(negative_slope=0.01)
  (2): Linear(in_features=1000, out_features=10, bias=True)
)

In [179]:
# model[0].set_centroid_to_data_maxdist(train_loader)
# model[0].set_centroid_to_data(train_loader)
# model[0].set_centroid_to_data_randomly(train_loader)

## Center initialization: random training samples

In [180]:
# Use the first N samples yielded by the (shuffled) training loader as centers,
# and remember the label that belongs to each of them.
N = model[0].centers.shape[0]
center_chunks, label_chunks = [], []
count = 0
for xx, yy in train_loader:
    if count >= N:
        break
    xx = xx.reshape(-1, model[0].input_dim).to(model[0].centers.device)
    # Take the whole batch, or only as many samples as are still missing.
    take = min(N - count, xx.shape[0])
    center_chunks.append(xx[:take])
    label_chunks.append(yy[:take])
    count += take
    if count == N:
        break

new_center = torch.cat(center_chunks, dim=0)
new_labels = torch.cat(label_chunks, dim=0)

## Center initialization: max-min distance selection (Maxdist)

In [112]:
# Fraction of one pass over train_loader to scan; multiplied by len(train_loader) to get the step budget.
epoch = 0.2

In [113]:
# Greedy "maxdist" center selection.
#
# Phase 1 fills the N center slots with the first N samples seen (and their labels).
# Phase 2 scans further samples: a candidate replaces the center whose
# nearest-neighbour distance is currently the smallest, provided the candidate's
# own distance to its nearest center is larger than that value.
#
# Fixes relative to the earlier version of this cell:
#   * min_dists is always computed once the buffer is full (previously it stayed
#     uninitialized whenever N was an exact multiple of the batch size);
#   * labels of the last partial fill are stored in new_labels;
#   * exactly `steps` batches are processed (was steps + 1);
#   * the inner loop no longer reuses the outer loop variable.
N = model[0].centers.shape[0]
centers_device = model[0].centers.device
new_center = torch.empty_like(model[0].centers)
new_labels = torch.empty(model[0].num_centers, dtype=torch.long)

min_dists = None  # nearest-neighbour distance per center; set when the buffer fills up
count = 0
steps = int(epoch*len(train_loader))
for i, (xx, yy) in enumerate(tqdm(train_loader)):
    if i >= steps: break

    xx = xx.reshape(-1, model[0].input_dim).to(centers_device)
    if count < N:
        #### fill the centers (possibly with only part of this batch)
        take = min(N-count, len(xx))
        new_center[count:count+take] = xx[:take]
        new_labels[count:count+take] = yy[:take]
        count += take
        if count < N:
            continue

        #### buffer just became full: the rest of the batch are replacement candidates
        xx, yy = xx[take:], yy[take:]
        # Large constant on the diagonal so a center's zero distance to itself
        # is never picked as its minimum.
        dists = torch.cdist(new_center, new_center) + torch.eye(N, device=centers_device)*1e5
        min_dists = dists.min(dim=0)[0]

    # Plain int index so it can address tensors regardless of which device they live on.
    ammd = int(min_dists.argmin())
    for j, x in enumerate(xx):
        dists = torch.norm(new_center-x, dim=1)
        md = dists.min()
        if md > min_dists[ammd]:
            min_dists[ammd] = md
            new_center[ammd] = x
            new_labels[ammd] = yy[j]
            ammd = int(min_dists.argmin())

# torch.empty leaves unfilled rows with arbitrary values, so refuse to continue with them.
assert count == N, f"only {count} of {N} centers were filled; increase `epoch`"

 20%|██████████                                        | 241/1200 [00:18<01:15, 12.73it/s]


In [114]:
# Inspect the shapes of the selected centers and their labels.
new_center.shape, new_labels.shape

(torch.Size([100, 784]), torch.Size([100]))

## Set the centers and output weights


In [190]:
# One-hot encode the label of every selected center: row i gets a 1 in column new_labels[i].
num_classes = 10
weights = torch.zeros(len(new_labels), num_classes)
row_index = torch.arange(len(new_labels))
weights[row_index, new_labels.to(weights.device)] = 1.

In [191]:
# Expected: (number of selected centers, 10), i.e. one one-hot row per center.
weights.shape

torch.Size([1000, 10])

In [192]:
# Load the selected samples as the centers of the first layer and the one-hot
# label matrix (transposed to out_features x in_features) as the weights of the
# final linear layer.
# copy_ converts device/dtype automatically and, unlike assigning to `.data`,
# raises on a shape mismatch instead of silently resizing the layer, e.g. when
# new_center is left over from a run with a different number of centers.
with torch.no_grad():
    model[0].centers.copy_(new_center)
    model[-1].weight.copy_(weights.t())

In [193]:
best_acc = -1  # module-level placeholder; test() itself neither reads nor updates it
def test(epoch, model):
    """Evaluate `model` on `test_loader` and return the accuracy in percent.

    Relies on the notebook globals `test_loader`, `device` and `criterion`,
    which must be defined before the first call.

    Args:
        epoch: label printed in the summary line; not used in any computation.
        model: module called on input batches flattened to (batch, features).

    Returns:
        Accuracy over all evaluated samples as a float percentage.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            # Flatten everything except the batch dimension (28*28 = 784 for MNIST).
            inputs, targets = inputs.to(device).flatten(1), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    acc = 100.*correct/total
    print(f"[Test] {epoch} Loss: {test_loss/num_batches:.3f} | Acc: {acc:.3f} {correct}/{total}")
    return acc

In [194]:
# Switch the model to evaluation mode; the returned module is echoed as the cell output.
model.eval()

Sequential(
  (0): DistanceTransform_MinExp()
  (1): LeakyReLU(negative_slope=0.01)
  (2): Linear(in_features=1000, out_features=10, bias=True)
)

In [195]:
# Loss read by test() as a global; it must be defined before test() is called.
criterion = nn.CrossEntropyLoss()

In [196]:
# Evaluate the model with the centers and output weights assigned earlier in this notebook
# (reported under epoch label 0), then echo the accuracy.
test_acc = test(0, model)
test_acc

100%|██████████████████████████████████████████████████| 200/200 [00:01<00:00, 146.76it/s]

[Test] 0 Loss: 1.090 | Acc: 86.130 8613/10000





86.13

In [None]:
# NOTE(review): bare literal in an unexecuted cell — presumably an accuracy noted from another run; confirm or remove.
65.09