In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

from torchvision import transforms as T
from torchvision import datasets

import random, os, pathlib, time, sys
from tqdm import tqdm
# from sklearn import datasets

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

## Fashion-MNIST dataset

In [None]:
# Length of one flattened input vector (matches the reshape(-1, 784) applied to the images below).
input_size = 784
# Number of target classes — presumably the 10 Fashion-MNIST categories; confirm.
output_size = 10

In [None]:
# Load the Fashion-MNIST train and test splits from data/ (download=True fetches them if needed).
train_dataset = datasets.FashionMNIST(root="data/", train=True, download=True)
test_dataset = datasets.FashionMNIST(root="data/", train=False, download=True)

In [None]:
# Flatten every image to a 784-vector and divide by 255
# (presumably mapping raw 0-255 pixel values to [0, 1] — confirm).
# NOTE(review): not idempotent — re-running this cell divides by 255 a second time.
train_dataset.data = train_dataset.data.reshape(-1, 784)/255.
test_dataset.data = test_dataset.data.reshape(-1, 784)/255.

In [None]:
# train_dataset.targets = train_dataset.targets.numpy()

In [None]:
# NOTE(review): repeats the constants already defined in an earlier cell with the same values.
input_size = 784
output_size = 10

In [None]:
class MNIST_Dataset(data.Dataset):
    """Minimal in-memory dataset that pairs a sample container with a label container."""

    def __init__(self, data, label):
        # Both containers are stored untouched and indexed in parallel by __getitem__.
        self.data, self.label = data, label

    def __len__(self):
        """Number of samples, taken from the sample container."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair selected by ``idx``."""
        return self.data[idx], self.label[idx]

In [None]:
# Replace the torchvision dataset objects with the lightweight wrapper, so indexing
# returns the flattened, rescaled tensors stored in .data together with .targets.
train_dataset = MNIST_Dataset(train_dataset.data, train_dataset.targets)
test_dataset = MNIST_Dataset(test_dataset.data, test_dataset.targets)

In [None]:
# Mini-batch loaders over both splits; only the training loader shuffles.
# NOTE(review): the cells visible in this notebook only reach through `.dataset`
# of these loaders and never iterate them — confirm they are still needed.
batch_size = 50
train_loader = data.DataLoader(dataset=train_dataset,
                                    num_workers=4, 
                                    batch_size=batch_size, 
                                    shuffle=True)

test_loader = data.DataLoader(dataset=test_dataset,
                                    num_workers=4, 
                                    batch_size=batch_size, 
                                    shuffle=False)

## Umap pytorch GD

In [None]:
from scipy.optimize import curve_fit

In [None]:
class UmapEps(nn.Module):
    """UMAP-style embedding optimised by gradient descent, with extra "epsilon" nodes.

    Besides the ``num_data`` input points, the neighbour graph contains
    ``num_epsilons`` artificial nodes that sit at the fixed distance
    ``epsilon`` from every input point and at distance 0 from one another.
    One embedding row is learned per node; the epsilon nodes occupy the last
    ``num_epsilons`` rows of ``y_centers``.

    Args:
        input_dim: Stored on the instance; not read by any method of this class.
        output_dim: Dimensionality of the learned embedding.
        num_data: Number of input points (the epsilon nodes are added on top).
        num_neighbour: Number of neighbours kept per node when building the graph.
        min_dist: Passed to ``find_ab_params`` together with ``spread``.
        spread: Passed to ``find_ab_params`` together with ``min_dist``.
        negative_sample_rate: Scales both the number of negative samples per
            node and the weight of the negative loss term.
        num_epsilons: Number of epsilon nodes.
    """

    def __init__(self, input_dim, output_dim, num_data, num_neighbour,
                 min_dist=0.1, spread=1.0, negative_sample_rate=5, num_epsilons=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.min_dist = min_dist
        self.spread = spread
        self.num_neighbour = num_neighbour
        # NOTE: the stored count includes the epsilon nodes.
        self.num_data = num_data+num_epsilons
        self.num_epsilons = num_epsilons
        self.negative_sample_rate = negative_sample_rate

        self.a, self.b = self.find_ab_params(self.spread, self.min_dist)

        # One embedding row per node (data points first, epsilon nodes last).
        self.y_centers = nn.Parameter(torch.randn(self.num_data, self.output_dim)/3.)
        # Not used by fit_step (which calls _bceloss_); kept so the attribute stays available.
        self.criterion = nn.BCELoss()

        # Both are filled lazily by the first fit_step call.
        self.sigma = None
        self.cache = None

    def fit_step(self, x, epsilon):
        """Return the scalar embedding loss for the full data set ``x``.

        The high-dimensional graph (and ``sigma``) is computed on the first
        call only and stored in ``self.cache``; later calls reuse it and only
        read the row count and device of ``x``.  Reset ``self.cache`` and
        ``self.sigma`` to ``None`` before fitting different data.

        Args:
            x: 2-D tensor with one row per data point.
            epsilon: Distance assigned between every data point and every
                epsilon node.

        Returns:
            ``loss_positive + negative_sample_rate * loss_negative``.

        Raises:
            ValueError: If the row count of ``x`` plus ``num_epsilons`` does
                not match the number of embedding rows.  Without this check a
                shorter ``x`` would run silently with the epsilon nodes mapped
                to the wrong embedding rows.
        """
        num_nodes = x.shape[0]+self.num_epsilons
        if num_nodes != self.y_centers.shape[0]:
            raise ValueError(
                f"x has {x.shape[0]} rows, but the embedding was created for "
                f"{self.y_centers.shape[0]-self.num_epsilons} data points"
            )

        neg_num = self.negative_sample_rate*self.num_neighbour

        ### can do only once for same x and y ..
        if self.cache is None:
            assert x.shape[0] > self.num_neighbour

            ### pairwise distances; the large diagonal keeps a point from being its own neighbour
            dists = torch.cdist(x, x)+torch.eye(x.shape[0], device=x.device)*1e5

            ### add epsilon to all dists (extra columns: data point -> epsilon node)
            e = torch.ones(x.shape[0], self.num_epsilons, dtype=x.dtype, device=x.device)*epsilon
            dists = torch.cat([dists, e], dim=1)

            ### add epsilon itself as a node (extra rows; epsilon nodes are at distance 0 from each other)
            e = torch.ones(self.num_epsilons, num_nodes, dtype=x.dtype, device=x.device)*epsilon
            e[:,-self.num_epsilons:] = 0.

            dists = torch.cat([dists, e], dim=0)

            dists, indices = torch.topk(dists, k=self.num_neighbour, dim=1, largest=False, sorted=False)

            ### distances relative to each node's nearest neighbour
            dists = (dists-dists.min(dim=1, keepdim=True)[0])

            if self.sigma is None:
                self.sigma = self.get_sigma(dists.data)
                self.sigma[torch.isnan(self.sigma)] = 1.

            dists = torch.exp(-dists/self.sigma)

            probX = torch.zeros(dists.shape[0], dists.shape[0]).to(x.device)
            probX.scatter_(dim=1, index=indices, src=dists)
            ### symmetrise: p + p^T - p*p^T
            probX = probX+probX.t()-probX*probX.t()

            self.cache = (probX, indices)
        else:
            probX, indices = self.cache

        ### positive term: neighbours found in the input space
        probX_ = torch.gather(probX, dim=1, index=indices)
        dists = torch.cdist(self.y_centers, self.y_centers)

        probY = torch.gather(dists, dim=1, index=indices)
        probY = 1/(1+self.a*(probY**(2*self.b)))
        loss_positive = self._bceloss_(probX_, probY)

        ### negative term: uniformly random nodes with target 0, sampled directly on the target device
        negative_indices = torch.randint(low=0, high=num_nodes, size=(num_nodes, neg_num),
                                         device=indices.device)
        probX_ = torch.zeros(num_nodes, neg_num, device=x.device)

        probY = torch.gather(dists, dim=1, index=negative_indices)
        probY = 1/(1+self.a*(probY**(2*self.b)))
        loss_negative = self._bceloss_(probX_, probY)

        ### both terms are means, so the negative term is re-weighted by the sampling rate
        loss = loss_positive+loss_negative*self.negative_sample_rate

        return loss

    def get_sigma(self, dists, epoch=700, lr=0.03):
        """Fit one bandwidth per row of ``dists`` with Adam.

        Each row's ``sigma`` is optimised so that ``sum_j exp(-dists[i, j] / sigma[i])``
        approaches ``log2(k)``, where ``k`` is the number of columns.  The
        starting value is 0.2 times the row's standard deviation.

        Args:
            dists: 2-D tensor of per-row neighbour distances.
            epoch: Number of Adam steps.
            lr: Adam learning rate.

        Returns:
            Tensor with one column and one row per row of ``dists``.
        """
        k = dists.shape[1]
        sigma = nn.Parameter(torch.std(dists.data, dim=1, keepdim=True)*0.2)
        # Named so that it does not shadow a module-level `optim` alias.
        sigma_optimizer = torch.optim.Adam([sigma], lr=lr)
        target = torch.log2(torch.ones_like(sigma)*k).to(dists.device)
        for _ in range(epoch):
            delta = torch.sum(torch.exp(-dists/sigma), dim=1, keepdim=True)
            delta = delta-target

            sigma_optimizer.zero_grad()
            error = (delta**2).sum()
            error.backward()
            sigma_optimizer.step()
        return sigma.data

    def _bceloss_(self, pX, pY):
        """Mean binary cross-entropy with target ``pX`` and prediction ``pY``.

        Both logarithms are clamped at -100; the unclamped form produced a NaN loss.
        """
        logy = torch.clamp(torch.log(pY), min=-100)
        log1_y = torch.clamp(torch.log(1-pY), min=-100)
        return -torch.mean(pX*logy+(1-pX)*log1_y)

    def find_ab_params(self, spread, min_dist):
        """Fit ``a`` and ``b`` of the curve ``1 / (1 + a * x**(2*b))``.

        The target is 1 for ``x < min_dist`` and ``exp(-(x - min_dist) / spread)``
        otherwise, sampled at 300 points on ``[0, 3 * spread]``.

        Returns:
            ``(a, b)`` as plain Python floats, so no NumPy scalar ends up in
            tensor arithmetic.
        """

        def curve(x, a, b):
            return 1.0 / (1.0 + a * x ** (2 * b))

        xv = np.linspace(0, spread * 3, 300)
        yv = np.zeros(xv.shape)
        yv[xv < min_dist] = 1.0
        yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread)
        params, covar = curve_fit(curve, xv, yv)
        return float(params[0]), float(params[1])

In [None]:
# torch.randn(784, 784)[torch.eye(784).type(torch.bool)]

In [None]:
# Number of training samples to embed.
num_train = 5000
#### use at least 2 epsilons, so that the epsilons can attract each other and repel the rest.
ump = UmapEps(784, 2, num_data=num_train, num_neighbour=10, negative_sample_rate=2, num_epsilons=10*2).to(device)

In [None]:
# Show the curve parameters a and b fitted in the constructor by find_ab_params.
ump.a, ump.b

In [None]:
# Draw num_train distinct random training indices and fetch the matching
# samples and labels with a single indexing call on the dataset.
# NOTE(review): no random seed is set in the visible cells, so the subset differs between runs.
indices = np.random.permutation(len(train_loader.dataset.data))[:num_train]
xx, yy = train_loader.dataset[indices]
xx = xx.to(device)
xx.shape

In [None]:
# Mean pairwise distance between the sampled inputs (zero self-distances included);
# presumably a reference scale for choosing `epsilon` below — confirm.
torch.cdist(xx, xx).mean()

In [None]:
# Distance assigned between every data point and every epsilon node.
epsilon = 6
# The first call also builds and caches the neighbour graph and sigma inside `ump`.
ump.fit_step(xx, epsilon=epsilon) ## loss

In [None]:
# adasd

In [None]:
# Adam over the module's parameters (the embedding table `y_centers`).
optimizer = torch.optim.Adam(ump.parameters(), lr=0.25)

In [None]:
# Give the epsilon nodes the extra label 10 so they can be coloured along with the data.
# Slicing to num_train first keeps this cell idempotent: re-running it no longer
# appends another copy of the epsilon labels.
epsilon_labels = torch.full((ump.num_epsilons,), 10, dtype=yy.dtype, device=yy.device)
yy = torch.cat([yy[:num_train], epsilon_labels])

In [None]:
# Sanity check: expected to equal num_train + ump.num_epsilons after the concatenation above.
yy.shape

In [None]:
# Number of full-batch optimisation steps for fitting the embedding.
EPOCHS = 10000#//2

In [None]:
# Cosine learning-rate schedule spanning the whole run (T_max equals the step count).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
#### Train with Optimizer

# Full-batch optimisation of the embedding; the loss of every step is recorded.
train_error = []
for epoch in tqdm(range(EPOCHS)):

    loss = ump.fit_step(xx, epsilon=epsilon)
    optimizer.zero_grad()
    loss.backward()

    # Zero any NaN gradient entries so they cannot corrupt the update.
    ump.y_centers.grad[torch.isnan(ump.y_centers.grad)] = 0

    optimizer.step()
    scheduler.step()

    train_error.append(float(loss))

    if epoch%100 == 0:
        print(f'Epoch: {epoch},  Loss:{float(loss)}')
        # Copy the embedding to the host once and reuse it for both scatter calls.
        embedding = ump.y_centers.detach().cpu().numpy()
        plt.scatter(*embedding.T, c=yy, marker='.', cmap="tab10")
        # Epsilon nodes are the last rows of the embedding table.
        plt.scatter(*embedding[-ump.num_epsilons:].T, c='k', marker='*', s=100)

        plt.show()

In [None]:
# Final embedding of all nodes, coloured by label.
# NOTE(review): this cell uses cmap "tab20" while the other plots use "tab10" — confirm intended.
plt.scatter(*ump.y_centers.cpu().data.numpy().T, c=yy, marker='.', cmap="tab20", s=1)
# NOTE(review): the colorbar boundaries cover labels 0-9 only, while `yy` also holds
# label 10 for the epsilon nodes.
plt.colorbar(boundaries=np.arange(11)-0.5).set_ticks(np.arange(10))

# Epsilon nodes (last rows of the embedding table) drawn as black stars.
plt.scatter(*ump.y_centers.cpu().data.numpy()[-ump.num_epsilons:].T, c='k', marker='*', s=100)

In [None]:
# NOTE(review): `adasdasd` is not defined in the visible notebook, so this cell raises
# NameError — presumably a deliberate stop that keeps "Run All" from continuing into
# the test section; confirm before removing.
adasdasd

### Test data samples

In [None]:
# Pick 100 distinct random test samples (count hard-coded) and move the inputs to the device.
indices = np.random.permutation(len(test_loader.dataset.data))[:100]
test_xx, test_yy = test_loader.dataset[indices]
test_xx = test_xx.to(device)
test_xx.shape

### random uniform samples

In [None]:
# Blend weight of the uniform noise: 0 keeps the test samples unchanged, 1 replaces them entirely.
randalpha = 0.0

# The noise takes its shape from test_xx instead of a hard-coded (100, 784),
# so the cell keeps working when the sample count or input size changes.
# NOTE(review): the cell overwrites test_xx, so re-running it with randalpha > 0 blends noise in again.
test_xx = test_xx*(1-randalpha)+randalpha*torch.rand(*test_xx.shape).to(device)
test_xx.shape

In [None]:
def transform_step(self, train_x, x, testy_centers, epsilon, cache=None):
    """Embed new points against an already fitted ``UmapEps``-like model.

    The function has two modes, selected by ``cache``:

    * ``cache is None`` — pre-computation.  Builds the neighbour graph from
      the new points ``x`` to the training points ``train_x`` plus the epsilon
      nodes and returns ``(probX, indices)``.  ``testy_centers`` is not used.
    * ``cache`` given — loss evaluation.  Returns the scalar loss of the
      embedding ``testy_centers`` of the new points against the fixed
      ``self.y_centers``.

    Args:
        self: Fitted model providing ``num_epsilons``, ``num_neighbour``,
            ``negative_sample_rate``, ``a``, ``b``, ``y_centers``,
            ``get_sigma`` and ``_bceloss_``.
        train_x: 2-D tensor of the inputs the model was fitted on.
        x: 2-D tensor of new inputs, one row per point.
        testy_centers: Embedding of the new points, one row per row of ``x``.
        epsilon: Distance assigned between every new point and every epsilon node.
        cache: ``None`` or the tuple returned by the pre-computation mode.

    Returns:
        ``(probX, indices)`` in pre-computation mode; ``probX`` has shape
        ``(len(x), len(train_x) + num_epsilons)`` and ``indices`` holds the
        chosen neighbour columns.  Otherwise the scalar loss tensor.
    """
    if cache is None:
        ########################################
        ### Pre computation step
        dists = torch.cdist(x, train_x)

        ## disconnection_distance parameter not used

        ### add epsilon to all dists (extra columns: new point -> epsilon node)
        e = torch.ones(x.shape[0], self.num_epsilons).to(x)*epsilon
        dists = torch.cat([dists, e], dim=1)

        dists, indices = torch.topk(dists, k=self.num_neighbour, dim=1, largest=False, sorted=False)

        ### distances relative to each new point's nearest neighbour
        dists = (dists-dists.min(dim=1, keepdim=True)[0])

        sigma = self.get_sigma(dists)
        sigma[torch.isnan(sigma)] = 1

        dists = torch.exp(-dists/sigma)

        ### Rows are NEW points and columns are TRAINING nodes, so the matrix is
        ### rectangular.  It is not combined with its transpose: entry [i, j] and
        ### entry [j, i] refer to unrelated point pairs here, and mixing them
        ### altered the target probabilities.
        probX = torch.zeros(x.shape[0], train_x.shape[0]+self.num_epsilons,
                            dtype=dists.dtype, device=x.device)
        probX.scatter_(dim=1, index=indices, src=dists)

        return (probX, indices)

    probX, indices = cache

    ######################################
    ### positive sampling step
    probX = torch.gather(probX, dim=1, index=indices)

    dists = torch.cdist(testy_centers, self.y_centers)

    probY = torch.gather(dists, dim=1, index=indices)
    probY = 1/(1+self.a*(probY**(2*self.b)))

    loss_positive = self._bceloss_(probX, probY)

    #############################################
    ### negative sampling: uniformly random training nodes with target 0
    neg_num = self.negative_sample_rate*self.num_neighbour

    probX = torch.zeros(x.shape[0], neg_num, device=x.device)
    negative_indices = torch.randint(low=0, high=train_x.shape[0]+self.num_epsilons,
                                     size=(x.shape[0], neg_num), device=x.device)

    probY = torch.gather(dists, dim=1, index=negative_indices)
    probY = 1/(1+self.a*(probY**(2*self.b)))
    loss_negative = self._bceloss_(probX, probY)

    ### both terms are means, so the negative term is re-weighted by the sampling rate
    loss = loss_positive+loss_negative*self.negative_sample_rate
    return loss

In [None]:
# Trainable embedding for the test points: one random row per test sample,
# drawn the same way as the model's own table (randn / 3).
y_centers = nn.Parameter(torch.randn(test_xx.shape[0], ump.output_dim).to(device)/3.)
y_centers.requires_grad

In [None]:
# ## Initialize y_centers with nearest sample from training set
# nearest_idx = torch.cdist(test_xx, xx).argmin(dim=1)
# y_centers.data = ump.y_centers.data[nearest_idx]

In [None]:
# With cache=None, transform_step returns the (probX, indices) tuple instead of a loss.
cache = transform_step(ump, xx, test_xx, y_centers, epsilon, cache=None) ## first get cache
# cache

In [None]:
# With the cache supplied, the same call returns the loss of the current test embedding.
transform_step(ump, xx, test_xx, y_centers, epsilon, cache)

In [None]:
# Optimise only the test embedding; the fitted model's parameters are not given to this optimizer.
# NOTE: this rebinds EPOCHS, optimizer and scheduler from the fitting section.
EPOCHS = 3000
optimizer = torch.optim.Adam([y_centers], lr=0.25)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
#### Train with Optimizer

# Optimise the test embedding against the fixed training embedding; the loss of every step is recorded.
train_error = []
for epoch in tqdm(range(EPOCHS)):
    loss = transform_step(ump, xx, test_xx, y_centers, epsilon, cache)
    optimizer.zero_grad()
    loss.backward()

    # Zero any NaN gradient entries so they cannot corrupt the update.
    y_centers.grad[torch.isnan(y_centers.grad)] = 0

    optimizer.step()
    # Advance the cosine schedule built for this optimizer; without this call
    # the learning rate stayed at its initial value for the whole run.
    scheduler.step()

    train_error.append(float(loss))

    if epoch%100 == 0:
        print(f'Epoch: {epoch},  Loss:{float(loss)}')
        plt.scatter(*y_centers.detach().cpu().numpy().T, c=test_yy, marker='.', cmap="tab10")
        plt.show()

In [None]:
# Training embedding as a faint background, coloured by label.
plt.scatter(*ump.y_centers.cpu().data.numpy().T, c=yy, marker='.', cmap="tab10", s=3, alpha=0.2)

# Embedded test points drawn on top (high zorder) as outlined stars.
plt.scatter(*y_centers.cpu().data.numpy().T, c=test_yy, marker='*', edgecolors='k', s=50, cmap='tab10',
            alpha=0.5, zorder=100)

plt.colorbar(boundaries=np.arange(11)-0.5).set_ticks(np.arange(10))

# Epsilon nodes (last rows of the training embedding) as hollow circles behind everything else.
plt.scatter(*ump.y_centers.cpu().data.numpy()[-ump.num_epsilons:].T, marker='o', edgecolors='k', facecolors='None', s=100,
            alpha=0.3, zorder=-100)

#### Plot with interpolation

In [None]:
# Fresh random subset of 100 test samples; kept in `_test_xx` so the clean inputs
# stay available while `test_xx` is re-blended with noise below.
indices = np.random.permutation(len(test_loader.dataset.data))[:100]
_test_xx, test_yy = test_loader.dataset[indices]
_test_xx = _test_xx.to(device)

In [None]:
# Eleven noise blend weights from 0 to 1 in steps of 0.1.
randalphas = np.linspace(0, 1, 11)
# Starts at -1 because the "rerun" cell below increments before use.
alp_idx = -1
randalphas

In [None]:
# Fixed uniform-noise sample reused for every blend weight.  Its shape is taken
# from `_test_xx`, the tensor it is blended with, rather than from the stale
# `test_xx` of the previous section and a hard-coded width of 784.
randval = torch.rand(*_test_xx.shape).to(device)

## ==>> rerun below code from here

In [None]:
# Intentionally stateful: each re-run of this cell moves on to the next blend weight.
# NOTE(review): `randalphas` has 11 entries, so the 12th run raises IndexError.
alp_idx += 1
randalpha = randalphas[alp_idx]
print(randalpha)

In [None]:
# Linear blend of the clean test inputs with the fixed noise sample, weighted by randalpha.
test_xx = _test_xx*(1-randalpha)+randalpha*randval
test_xx.shape

In [None]:
# Re-initialise the trainable test embedding for the current blend weight.
y_centers = nn.Parameter(torch.randn(test_xx.shape[0], ump.output_dim).to(device)/3.)
y_centers.requires_grad

In [None]:
# Rebuild the neighbour cache, since test_xx changed with the blend weight.
cache = transform_step(ump, xx, test_xx, y_centers, epsilon, cache=None) ## first get cache

In [None]:
# Loss of the freshly initialised test embedding, shown before optimisation.
transform_step(ump, xx, test_xx, y_centers, epsilon, cache)

In [None]:
# New optimizer and cosine schedule bound to the re-initialised test embedding.
EPOCHS = 3000
optimizer = torch.optim.Adam([y_centers], lr=0.25)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
#### Train with Optimizer

# Optimise the test embedding for the current blend weight; the loss of every step is recorded.
train_error = []
for epoch in tqdm(range(EPOCHS)):
    loss = transform_step(ump, xx, test_xx, y_centers, epsilon, cache)
    optimizer.zero_grad()
    loss.backward()

    # Zero any NaN gradient entries so they cannot corrupt the update.
    y_centers.grad[torch.isnan(y_centers.grad)] = 0

    optimizer.step()
    # Advance the cosine schedule built for this optimizer; without this call
    # the learning rate stayed at its initial value for the whole run.
    scheduler.step()

    train_error.append(float(loss))


In [None]:
# Create the output folder, including a missing parent; re-running the cell is a no-op.
os.makedirs("./outputs/16_epsilon_umap/", exist_ok=True)

In [None]:
# Training embedding as a faint background, coloured by label.
plt.scatter(*ump.y_centers.cpu().data.numpy().T, c=yy, marker='.', cmap="tab10", s=3, alpha=0.2)

# Embedded (noise-blended) test points drawn on top as outlined stars.
plt.scatter(*y_centers.cpu().data.numpy().T, c=test_yy, marker='*', edgecolors='k', s=50, cmap='tab10',
            alpha=0.5, zorder=100)
plt.colorbar(boundaries=np.arange(11)-0.5).set_ticks(np.arange(10))


# Mean position of the epsilon nodes, shown as a single hollow circle.
plt.scatter(*ump.y_centers.cpu().data.numpy()[-ump.num_epsilons:].mean(axis=0, keepdims=True).T, marker='o', edgecolors='k', facecolors='None', s=100,
            alpha=1.0, lw=2, zorder=-100)

# One PDF per blend weight; the weight is rounded to one decimal for the file name.
plt.savefig(f"./outputs/16_epsilon_umap/embed_alpha{np.round(randalpha, decimals=1)}.pdf", bbox_inches="tight")