In [29]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

import random, os, pathlib, time
from tqdm import tqdm
from sklearn import datasets

from sklearn import tree, cluster

In [30]:
## Global configuration: compute device and random seeds.
# Switch to torch.device("cuda:0") to run on a GPU.
device = torch.device("cpu")

# Seed every RNG used below (torch for center init / jitter, numpy for the
# random sample of centers and sklearn's default random_state) so that
# Restart & Run All reproduces the same numbers.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Fashion-MNIST dataset

In [31]:
import mylibrary.datasets as datasets
import mylibrary.nnlib as tnn

In [32]:
## Load Fashion-MNIST through the project's dataset wrapper.
# Presumably needed once to fetch/cache the raw files — uncomment if so:
# mnist.download_mnist(); mnist.save_mnist()
mnist = datasets.FashionMNIST()
train_data, train_label_, test_data, test_label_ = mnist.load()

# Scale pixel values by 1/255 (presumably 8-bit pixels -> [0, 1]).
train_data, test_data = train_data / 255., test_data / 255.

# One-hot labels (tnn.Logits.index_to_logit(train_label_)) are not used here.
train_size = len(train_label_)

In [33]:
## Convert the (assumed numpy) arrays to torch tensors:
## float32 features and int64 labels shaped as an (n, 1) column.
train_data = torch.tensor(train_data, dtype=torch.float32)
test_data = torch.tensor(test_data, dtype=torch.float32)
train_label = torch.tensor(train_label_, dtype=torch.long).reshape(-1, 1)
test_label = torch.tensor(test_label_, dtype=torch.long).reshape(-1, 1)

In [34]:
class DistanceTransform(nn.Module):
    """Map inputs to a softmax over their distances to a set of learnable centers.

    Args:
        input_dim: dimensionality of each input row.
        num_centers: number of centers, i.e. the output width of ``forward``.
        p: norm order passed to ``torch.cdist``.

    Attributes:
        centers: ``nn.Parameter`` of shape ``(num_centers, input_dim)``.
        inv_params: ``(pinv(A), Z)`` cached by ``compute_inverse_matrix`` for use
            by ``inverse``; ``None`` until that method has been called.
    """

    def __init__(self, input_dim, num_centers, p=2):
        super().__init__()
        self.input_dim = input_dim
        self.num_centers = num_centers
        self.p = p

        # Small random initial centers; may be overwritten by one of the
        # set_centroid_* initialisers or by assigning to ``centers.data``.
        self.centers = nn.Parameter(torch.randn(num_centers, input_dim) / 3.)
        self.inv_params = None

    def forward(self, x):
        """Return ``softmax(cdist(x, centers), dim=1)``.

        For 2-D ``x`` of shape ``(batch, input_dim)`` the result has shape
        ``(batch, num_centers)`` and each row sums to 1.

        NOTE(review): the softmax is applied to the raw (positive) distances, so
        the *farthest* center receives the largest weight. If nearness should
        dominate, ``softmax(-dists)`` would do that — confirm intent first, since
        it changes every downstream result.
        """
        dists = torch.cdist(x, self.centers, p=self.p)

        # Alternative per-row normalisations (kept for experimentation):
        # dists = dists - dists.min(dim=1, keepdim=True)[0]
        # dists = (dists - dists.mean(dim=1, keepdim=True)) / dists.std(dim=1, keepdim=True)

        return torch.softmax(dists, dim=1)

    def set_centroid_to_data_randomly(self, data_loader):
        """Place each center on a distinct random sample, plus N(0, 0.01^2) jitter.

        Assumes ``data_loader.dataset.data`` is an indexable float tensor with at
        least ``num_centers`` rows — TODO confirm for the loaders in use.
        """
        samples = data_loader.dataset.data
        indices = np.random.permutation(len(samples))[:self.centers.shape[0]]
        self.centers.data = samples[indices].to(self.centers.device)
        self.centers.data += torch.randn_like(self.centers) * 0.01

    def set_centroid_to_data_maxdist(self, data_loader):
        """Greedy max-min initialisation: pick samples that are far from each other.

        The first ``num_centers`` samples yielded by the loader fill the centers.
        Each later sample ``x`` then replaces the center whose nearest-other-center
        distance is currently smallest, whenever ``x``'s distance to its own
        nearest center is larger than that.

        Raises:
            ValueError: if the loader yields fewer samples than there are centers.
        """
        N = self.centers.shape[0]
        new_center = torch.empty_like(self.centers.data)
        min_dists = None  # per-center distance to the nearest other center
        count = 0
        for xx, _ in tqdm(data_loader):
            xx = xx.to(new_center.device)
            if count < N:
                # Size the fill by the actual batch length: the last batch may be
                # short, and the fill may end exactly on a batch boundary.
                take = min(N - count, len(xx))
                new_center[count:count + take] = xx[:take]
                count += take
                xx = xx[take:]
                if count < N:
                    continue
                # All slots are filled: compute each center's distance to its
                # nearest *other* center (diagonal excluded).
                dists = torch.cdist(new_center, new_center)
                dists.fill_diagonal_(float("inf"))
                min_dists = dists.min(dim=0)[0]

            ammd = min_dists.argmin()
            for x in xx:
                md = torch.norm(new_center - x, dim=1).min()
                if md > min_dists[ammd]:
                    min_dists[ammd] = md
                    new_center[ammd] = x
                    ammd = min_dists.argmin()

        if count < N:
            raise ValueError(
                f"data_loader yielded {count} samples but {N} centers are required")
        self.centers.data = new_center.to(self.centers.device)

    def set_centroid_to_data(self, data_loader):
        """Snap each center onto its nearest data sample.

        Distances are measured from the centers as they are when the method is
        called. Within one batch, if several centers share the same nearest
        sample, only the first of them in a random order may claim it; the others
        keep their best match from earlier batches.
        """
        new_center = self.centers.data.clone()
        min_dists = torch.ones(self.centers.shape[0]) * 1e9

        for xx, _ in data_loader:
            dists = torch.cdist(xx, self.centers.data)
            # Per center: distance to, and index of, its nearest sample in this batch.
            min_d, arg_md = dists.min(dim=0)

            # Don't let one sample become the nearest point of several centers.
            occupied = set()
            for i in np.random.permutation(len(arg_md)):
                ind = int(arg_md[i])
                if ind in occupied:
                    min_d[i] = min_dists[i]
                    arg_md[i] = -1
                else:
                    occupied.add(ind)

            # Centers that found a strictly closer sample in this batch.
            idx = torch.nonzero(min_d < min_dists).reshape(-1)
            new_center[idx] = xx[arg_md[idx]]
            min_dists[idx] = min_d[idx]

        self.centers.data = new_center.to(self.centers.device)

    def compute_inverse_matrix(self):
        """Precompute the linear system used by ``inverse`` (Euclidean trilateration).

        For a point x with d_i = ||x - c_i||, consecutive differences satisfy
        d_i^2 - d_{i+1}^2 = 2 (c_{i+1} - c_i) . x + (||c_i||^2 - ||c_{i+1}||^2),
        i.e. D = A x + Z. Stores ``(pinv(A), Z)`` in ``self.inv_params``. Exact
        recovery needs A to have full column rank (num_centers - 1 >= input_dim,
        centers in general position).

        Uses ``self.centers`` (not ``.data``), so the cached tensors stay attached
        to the autograd graph.
        """
        A = 2 * (self.centers[1:] - self.centers[:-1])
        c2 = self.centers ** 2
        Z = (c2[:-1] - c2[1:]).sum(dim=1, keepdim=True)
        invA = torch.pinverse(A)
        self.inv_params = (invA, Z)

    def inverse(self, dists):
        """Reconstruct points from their Euclidean distances to the centers.

        Args:
            dists: raw Euclidean distances of shape ``(batch, num_centers)``, e.g.
                ``torch.cdist(x, centers)``. This is not the softmax output of
                ``forward``.

        Returns:
            Tensor of shape ``(batch, input_dim)``.
        """
        assert self.inv_params is not None, "call compute_inverse_matrix() first"

        d2 = dists ** 2
        D = d2[:, :-1] - d2[:, 1:]

        invA, Z = self.inv_params
        return torch.matmul(invA, D.t() - Z).t()

In [35]:
class EMA(object):
    """Exponential moving average in which ``momentum`` weights the *previous* value.

    Each call computes ``mu = momentum * mu + (1 - momentum) * x``, stores it and
    returns it. When ``mu`` is still None (the first call with the default
    ``mu``), the previous value is taken to be ``x`` itself, so that call returns
    ``x`` up to floating-point rounding. With the default ``momentum=0.1`` the
    newest sample receives 90% of the weight.
    """

    def __init__(self, momentum=0.1, mu=None):
        self.momentum = momentum
        self.mu = mu

    def __call__(self, x):
        previous = x if self.mu is None else self.mu
        self.mu = self.momentum * previous + (1.0 - self.momentum) * x
        return self.mu

## Ordinary Decision Tree

In [36]:
## Baseline: depth-limited decision tree on raw pixels.
# random_state pins the tree's internal random feature permutation (used to break
# ties between equally good splits) so the fit is reproducible.
dtr = tree.DecisionTreeClassifier(max_depth=5, random_state=0)

In [37]:
# Fit the baseline tree on raw pixels and predict back on the training set.
features_train = train_data.numpy()
dtr.fit(features_train, train_label.numpy())
yout = dtr.predict(features_train)

In [38]:
np.shape(yout)  # one predicted class per training sample

(60000,)

In [39]:
# Training accuracy of the raw-pixel tree: fraction of exact label matches.
acc = float(np.mean(yout == train_label.numpy().ravel()))
print(acc)

0.7105333333333334


In [40]:
# %matplotlib inline
# plt.figure(figsize=(12,8))
# tree.plot_tree(dtr)
# plt.savefig("./models/tree_mnist_cls_ord.svg")

In [41]:
# Predictions of the raw-pixel tree on the held-out test set.
features_test = test_data.numpy()
yout = dtr.predict(features_test)

In [42]:
# Test accuracy of the raw-pixel tree: fraction of exact label matches.
acc = float(np.mean(yout == test_label.numpy().ravel()))
print(acc)

0.6938


### Decision tree with distance transform

In [43]:
# One center per input dimension plus one. Presumably chosen so the trilateration
# system in compute_inverse_matrix has 784 rows (full rank) — TODO confirm.
n_centers = 785
dt = DistanceTransform(784, num_centers=n_centers)

# Initialise the centers on random training images, plus a little Gaussian jitter.
sample_idx = np.random.permutation(len(train_data))[:n_centers]
dt.centers.data = train_data[sample_idx]
dt.centers.data += torch.randn_like(dt.centers) * 0.01

In [44]:
# ### centers using k-means
# kmeans = cluster.KMeans(init="k-means++", n_clusters=785, n_init=1, verbose=1)
# kmeans.fit(train_data)

In [45]:
# kmeans.cluster_centers_.shape

In [46]:
# dt.centers.data *= 0
# dt.centers.data += torch.Tensor(kmeans.cluster_centers_)

In [47]:
# Distance-softmax features for the training set. This is inference only, so no
# autograd graph is recorded (the result has no grad, as .data gave before).
with torch.no_grad():
    xx_ = dt(train_data)
xx_

tensor([[7.8973e-06, 1.5246e-03, 2.2275e-04,  ..., 2.9124e-05, 3.3851e-03,
         7.8630e-05],
        [1.7634e-03, 6.4385e-04, 3.5041e-03,  ..., 2.3562e-03, 1.8775e-04,
         7.9201e-05],
        [1.3078e-03, 2.5113e-03, 1.5333e-05,  ..., 8.3671e-05, 3.1240e-06,
         2.7901e-05],
        ...,
        [2.1806e-03, 4.9106e-04, 1.6244e-03,  ..., 1.6201e-03, 1.1216e-05,
         2.2745e-04],
        [1.5550e-03, 4.3281e-03, 1.5386e-05,  ..., 1.4530e-04, 4.7621e-05,
         1.6202e-05],
        [2.8211e-04, 2.4850e-03, 2.8136e-07,  ..., 9.5159e-06, 4.3840e-05,
         8.2374e-06]])

In [48]:
# Feature matrix vs. label column: the row counts must agree.
shapes = (xx_.shape, train_label.shape)
shapes

(torch.Size([60000, 785]), torch.Size([60000, 1]))

In [49]:
## Same depth-limited tree, trained on distance-transform features.
# random_state pins the tree's internal random feature permutation so the fit is
# reproducible and comparable with the baseline tree.
dtr_ = tree.DecisionTreeClassifier(max_depth=5, random_state=0)

In [None]:
labels_train = train_label.numpy()
dtr_.fit(xx_.numpy(), labels_train)

In [23]:
# Training-set predictions of the distance-feature tree.
features_train_ = xx_.numpy()
yout_ = dtr_.predict(features_train_)

In [24]:
np.shape(yout_)  # one predicted class per training sample

(60000,)

In [25]:
# Training accuracy of the distance-feature tree.
acc = float(np.mean(yout_ == train_label.numpy().ravel()))
print(acc)

0.6558666666666667


In [26]:
# %matplotlib inline
# plt.figure(figsize=(12,8))
# tree.plot_tree(dtr_)
# plt.savefig("./models/tree_mnist_cls_dist.svg")

In [27]:
# Distance-softmax features for the test set, computed without recording autograd,
# then classified by the distance-feature tree.
with torch.no_grad():
    xx_test_ = dt(test_data)
yout_ = dtr_.predict(xx_test_.numpy())

In [28]:
# Test accuracy of the distance-feature tree.
acc = float(np.mean(yout_ == test_label.numpy().ravel()))
print(acc)

0.6481
