In [None]:
import umap

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
%matplotlib inline
# from mpl_toolkits.mplot3d import Axes3D

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, transforms as T

import random, os, pathlib, time
from tqdm import tqdm
# from sklearn import datasets

In [None]:
# Run on the CPU; the CUDA device is kept below as a commented-out alternative.
# device = torch.device("cuda:0")
device = torch.device("cpu")

In [None]:
from tqdm import tqdm
import os, time, sys
import json

In [None]:
import dtnnlib as dtnn

In [None]:
# Only ToTensor is applied: images become float tensors and no mean/std
# normalisation is performed.
mnist_transform = T.Compose([T.ToTensor()])

# Fashion-MNIST train and test splits, downloaded into ./data when missing.
train_dataset, test_dataset = (
    datasets.FashionMNIST(root="./data", train=is_train, download=True, transform=mnist_transform)
    for is_train in (True, False)
)

In [None]:
batch_size = 50

# Both loaders share worker count and batch size; only the training loader shuffles.
_loader_kwargs = dict(num_workers=4, batch_size=batch_size)
train_loader = data.DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
test_loader = data.DataLoader(dataset=test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Draw a single training batch to sanity-check tensor shapes.
xx, yy = next(iter(train_loader))
xx = xx.to(device)
yy = yy.to(device)
print(xx.shape, yy.shape)

In [None]:
# Inspect the value range of the sampled batch.
xx.min(), xx.max()

In [None]:
# Shapes of the raw training images and of their label vector.
train_dataset.data.shape, train_dataset.targets.shape

In [None]:
# 2-D UMAP reducer; no `metric` is passed, so UMAP's default metric is used.
embed = umap.UMAP(
    n_neighbors=50,
    n_components=2,
    min_dist=0.1,
    spread=1,
)

In [None]:
# Flatten every test image to a 784-vector and divide the raw pixel values by 255.
selected_points = test_dataset.data.reshape(-1, 28 * 28).float() / 255

# Fit UMAP on those vectors; `embedding` holds one 2-D point per test image.
embedding = embed.fit_transform(selected_points)

In [None]:
# Draw a fresh training batch and flatten each image to a 784-vector.
xx, yy = next(iter(train_loader))
xx = xx.reshape(-1, 28 * 28)
print(xx.shape, yy.shape)

In [None]:
# Largest value in the flattened batch.
xx.max()

In [None]:
# Project the batch into the already-fitted UMAP space (one 2-D point per sample).
batch_as_array = xx.reshape(-1, 28 * 28).numpy()
xtransf = embed.transform(batch_as_array)

In [None]:
# Re-check the batch maximum after the UMAP transform call.
xx.max()

In [None]:
def get_sigma(dists, k=50, epoch=700, lr=0.03):
    """Fit one kernel bandwidth per row from its k nearest distances.

    For every row of ``dists`` the ``k`` smallest entries are kept and a
    positive-initialised ``sigma`` is optimised with Adam so that
    ``sum_j exp(-d_j / sigma)`` over those entries approaches ``log2(k)``.

    Args:
        dists: 2-D tensor of shape ``(n_samples, n_points)`` where smaller
            means closer. Callers in this notebook subtract the row minimum
            first, so the nearest entry is 0.
        k: Number of nearest entries used per row. Clamped to ``n_points``
            so an oversized value no longer makes ``torch.topk`` raise.
        epoch: Number of Adam steps. ``0`` returns the initial guess.
        lr: Adam learning rate.

    Returns:
        Detached tensor of shape ``(n_samples, 1)`` with the fitted sigmas.
    """
    # Never ask for more neighbours than there are columns.
    k = min(k, dists.shape[1])
    # Nearest neighbours are the *smallest* distances, hence largest=False.
    # Detach so the optimisation below never back-propagates into the caller's graph.
    dists = torch.topk(dists.detach(), k=k, dim=1, largest=False)[0]

    # Initial guess: a fifth of the spread of the selected distances.
    sigma = nn.Parameter(torch.std(dists, dim=1, keepdim=True) * 0.2)
    # Named `optimizer` so the imported `optim` module is not shadowed.
    optimizer = torch.optim.Adam([sigma], lr=lr)
    target = torch.log2(torch.ones_like(sigma) * k)

    for _ in range(epoch):
        kernel_mass = torch.sum(torch.exp(-dists / sigma), dim=1, keepdim=True)
        error = ((kernel_mass - target) ** 2).sum()

        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    return sigma.detach()

In [None]:
# sigma = get_sigma(dists)

In [None]:
# Create the figure output directory together with its parent. This is a no-op
# when it already exists, so the cell is safe under Restart & Run All.
os.makedirs("outputs/07_visualize_actv_umap", exist_ok=True)

## Get scale to top-k points

In [None]:
# Build the dtnnlib distance layer and overwrite its centres with the flattened test images.
# NOTE(review): presumably the arguments are (input_dim, num_centers) and the layer
# returns one distance per centre — confirm against dtnnlib.
transform = dtnn.DistanceTransformBase(28*28, len(test_dataset))
transform.centers.data = selected_points

In [None]:
# Evaluate the distance layer on the batch without tracking gradients.
with torch.no_grad():
    dists = transform(xx)

In [None]:
# Re-reference every row to its own minimum so the closest entry sits at 0.
dists = dists - dists.min(dim=1, keepdim=True).values
# k=10000 is presumably the number of reference points (test-set size) — TODO confirm.
sigma = get_sigma(dists, k=10000)
# Exponential kernel turns the shifted distances into activations.
topk_dists = torch.exp(-dists / sigma)

In [None]:
# Shape of the activation matrix.
topk_dists.shape

In [None]:
# Inspect the range of the activations.
topk_dists.max(), topk_dists.min()

In [None]:
# Start one before the first sample; the plotting cell below increments `i` before using it.
i = -1

In [None]:
# Advance to the next sample of the batch; re-running this cell steps through `xx`.
i += 1
activ = topk_dists[i].detach().cpu().numpy()

print(f"index:{i}/{len(dists)}")
fig, ax = plt.subplots(figsize=(10, 8))

# Marker area encodes each reference point's activation; colour encodes its class.
points = ax.scatter(
    embedding[:, 0], embedding[:, 1],
    c=test_dataset.targets, cmap="tab10",
    s=np.maximum(activ * 80, 0.001),
)
ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Discrete colourbar with one band per class label 0-9.
cbar = fig.colorbar(points, ax=ax, ticks=range(10), boundaries=np.arange(11) - 0.5)
cbar.set_ticks(np.arange(10))
cbar.set_ticklabels(list(range(10)))

# Thumbnail of the query image, placed in axes-fraction coordinates.
ins = ax.inset_axes([0.4, 0.75, 0.2, 0.2])
ins.imshow(xx[i].numpy().reshape(28, 28), cmap='gray_r')
ins.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Black cross: UMAP position of the query sample itself.
ax.scatter(xtransf[i, 0], xtransf[i, 1], marker='x', color='k', s=100)

# Target label of the query sample in the top-right corner.
ax.text(.84, .97, f'T={int(yy[i])}', ha='left', va='top', transform=ax.transAxes, fontsize="large")
fig.savefig(f"./outputs/07_visualize_actv_umap/umap_scaled_actf_distance_i{i}.pdf", bbox_inches='tight')
plt.show()

### Visualize the distance-based UMAP embedding with dot-product activation

In [None]:
# bias=False: nn.Linear otherwise adds a randomly initialised bias to every output,
# which would perturb the dot products taken against the reference points.
transform2 = nn.Linear(28*28, len(test_dataset), bias=False)
# Each weight row is one flattened test image, so output j is the dot product with image j.
transform2.weight.data = selected_points

In [None]:
# Negate the linear layer's output so that a larger similarity becomes a smaller "distance".
with torch.no_grad():
    dists = -transform2(xx)

In [None]:
# Re-reference every row to its own minimum so the closest entry sits at 0.
dists = dists - dists.min(dim=1, keepdim=True).values
# k=10000 is presumably the number of reference points (test-set size) — TODO confirm.
sigma = get_sigma(dists, k=10000)
# Exponential kernel turns the shifted scores into activations.
topk_dists = torch.exp(-dists / sigma)

In [None]:
# Inspect the range of the dot-product activations.
topk_dists.max(), topk_dists.min()

In [None]:
# Restart the sample counter; the plotting cell below increments `i` before using it.
i = -1

In [None]:
# Advance to the next sample of the batch; re-running this cell steps through `xx`.
i += 1
activ = topk_dists[i].detach().cpu().numpy()

print(f"index:{i}/{len(dists)}")
fig, ax = plt.subplots(figsize=(10, 8))

# Marker area encodes each reference point's activation; colour encodes its class.
points = ax.scatter(
    embedding[:, 0], embedding[:, 1],
    c=test_dataset.targets, cmap="tab10",
    s=np.maximum(activ * 80, 0.001),
)
ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Discrete colourbar with one band per class label 0-9.
cbar = fig.colorbar(points, ax=ax, ticks=range(10), boundaries=np.arange(11) - 0.5)
cbar.set_ticks(np.arange(10))
cbar.set_ticklabels(list(range(10)))

# Thumbnail of the query image, placed in axes-fraction coordinates.
ins = ax.inset_axes([0.4, 0.75, 0.2, 0.2])
ins.imshow(xx[i].numpy().reshape(28, 28), cmap='gray_r')
ins.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Black cross: UMAP position of the query sample itself.
ax.scatter(xtransf[i, 0], xtransf[i, 1], marker='x', color='k', s=100)

# Target label of the query sample in the top-right corner.
ax.text(.84, .97, f'T={int(yy[i])}', ha='left', va='top', transform=ax.transAxes, fontsize="large")
fig.savefig(f"./outputs/07_visualize_actv_umap/umap_scaled_actf_linear_i{i}.pdf", bbox_inches='tight')
plt.show()


### Visualize the distance-based UMAP embedding with cosine-angle activation

In [None]:
# bias=False: nn.Linear otherwise adds a randomly initialised bias to every output,
# which would distort cosine similarities that only span a narrow range.
transform3 = nn.Linear(28*28, len(test_dataset), bias=False)
# Each weight row is one unit-normalised test image.
transform3.weight.data = selected_points/torch.norm(selected_points, dim=1, keepdim=True)

In [None]:
# Unit-normalise each sample, then negate the layer output so that a larger
# similarity becomes a smaller "distance".
with torch.no_grad():
    xx_ = xx / xx.norm(dim=1, keepdim=True)
    dists = transform3(xx_).neg()

In [None]:
# Re-reference every row to its own minimum so the closest entry sits at 0.
dists = dists - dists.min(dim=1, keepdim=True).values
# A smaller learning rate than the get_sigma default is used for this fit.
sigma = get_sigma(dists, k=10000, lr=0.002)
# Exponential kernel turns the shifted scores into activations.
topk_dists = torch.exp(-dists / sigma)

In [None]:
# Inspect the range of the cosine activations.
topk_dists.max(), topk_dists.min()

In [None]:
# Restart the sample counter; the plotting cell below increments `i` before using it.
i = -1

In [None]:
# Advance to the next sample of the batch; re-running this cell steps through `xx`.
i += 1
activ = topk_dists[i].detach().cpu().numpy()

print(f"index:{i}/{len(dists)}")
fig, ax = plt.subplots(figsize=(10, 8))

# Marker area encodes each reference point's activation; colour encodes its class.
points = ax.scatter(
    embedding[:, 0], embedding[:, 1],
    c=test_dataset.targets, cmap="tab10",
    s=np.maximum(activ * 80, 0.001),
)
ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Discrete colourbar with one band per class label 0-9.
cbar = fig.colorbar(points, ax=ax, ticks=range(10), boundaries=np.arange(11) - 0.5)
cbar.set_ticks(np.arange(10))
cbar.set_ticklabels(list(range(10)))

# Thumbnail of the query image, placed in axes-fraction coordinates.
ins = ax.inset_axes([0.4, 0.75, 0.2, 0.2])
ins.imshow(xx[i].numpy().reshape(28, 28), cmap='gray_r')
ins.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Black cross: UMAP position of the query sample itself.
ax.scatter(xtransf[i, 0], xtransf[i, 1], marker='x', color='k', s=100)

# Target label of the query sample in the top-right corner.
ax.text(.84, .97, f'T={int(yy[i])}', ha='left', va='top', transform=ax.transAxes, fontsize="large")
fig.savefig(f"./outputs/07_visualize_actv_umap/umap_scaled_actf_angle_i{i}.pdf", bbox_inches='tight')
plt.show()


# Other visualizations

### Visualize the cosine-based UMAP embedding with cosine-angle activation

In [None]:
# Second UMAP reducer with the same hyper-parameters but the cosine metric.
embed2 = umap.UMAP(
    n_neighbors=50,
    n_components=2,
    min_dist=0.1,
    spread=1,
    metric='cosine',
)

In [None]:
# Fit the cosine-metric UMAP on the flattened test images.
embedding2 = embed2.fit_transform(selected_points)

In [None]:
# Project the batch into the cosine-metric UMAP space (one 2-D point per sample).
batch_as_array = xx.reshape(-1, 28 * 28).numpy()
xtransf2 = embed2.transform(batch_as_array)

In [None]:
# Rebuild the cosine-similarity layer for this section.
# bias=False: nn.Linear otherwise adds a randomly initialised bias to every output,
# which would distort cosine similarities that only span a narrow range.
transform3 = nn.Linear(28*28, len(test_dataset), bias=False)
# Each weight row is one unit-normalised test image.
transform3.weight.data = selected_points/torch.norm(selected_points, dim=1, keepdim=True)

In [None]:
# Shape of the reference-point matrix.
selected_points.shape

In [None]:
# Restart the sample counter; the next cell increments `i` before using it.
i = -1

In [None]:
# Advance to the next sample; re-running this cell steps through the batch.
i += 1
with torch.no_grad():
    # xx[i:i+1] is a view of `xx`. Normalise out of place so the batch itself is
    # not silently overwritten (the original `/=` modified xx[i] in place).
    xx_ = xx[i:i+1]
    xx_ = xx_ / torch.norm(xx_, dim=1, keepdim=True)
    # Negated similarity of this one sample to every reference point; a single-row result.
    dists = -transform3(xx_)
print(yy[i])

In [None]:
# Re-reference the row to its own minimum so the closest entry sits at 0.
dists = dists - dists.min(dim=1, keepdim=True).values
# A smaller learning rate than the get_sigma default is used for this fit.
sigma = get_sigma(dists, k=10000, lr=0.002)
# Exponential kernel turns the shifted scores into activations.
topk_dists = torch.exp(-dists / sigma)

In [None]:
# Inspect the range of the single-sample activations.
topk_dists.max(), topk_dists.min()

In [None]:
# `topk_dists` holds a single row here (one sample), so take row 0.
activ = topk_dists.detach().cpu().numpy()[0]

# Report progress against the batch size; `dists` has just one row in this section,
# so len(dists) would always print 1.
print(f"index:{i}/{len(xx)}")
fig, ax = plt.subplots(figsize=(10, 8))

# Cosine activations drawn on the first UMAP embedding (`embedding`).
# Marker area encodes activation; colour encodes class.
points = ax.scatter(
    embedding[:, 0], embedding[:, 1],
    c=test_dataset.targets, cmap="tab10",
    s=np.maximum(activ * 80, 0.001),
)
ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Discrete colourbar with one band per class label 0-9.
cbar = fig.colorbar(points, ax=ax, ticks=range(10), boundaries=np.arange(11) - 0.5)
cbar.set_ticks(np.arange(10))
cbar.set_ticklabels(list(range(10)))

# Thumbnail of the query image, placed in axes-fraction coordinates.
ins = ax.inset_axes([0.6, 0.75, 0.2, 0.2])
ins.imshow(xx[i].numpy().reshape(28, 28), cmap='gray_r')
ins.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Black cross: position of the query sample in the first UMAP embedding.
ax.scatter(xtransf[i, 0], xtransf[i, 1], marker='x', color='k', s=100)

# Target label of the query sample in the top-right corner.
ax.text(.84, .97, f'T={int(yy[i])}', ha='left', va='top', transform=ax.transAxes, fontsize="large")

plt.show()


In [None]:
# `topk_dists` holds a single row here (one sample), so take row 0.
activ = topk_dists.detach().cpu().numpy()[0]

# Report progress against the batch size; `dists` has just one row in this section,
# so len(dists) would always print 1.
print(f"index:{i}/{len(xx)}")
fig, ax = plt.subplots(figsize=(10, 8))

# Cosine activations drawn on the cosine-metric UMAP embedding (`embedding2`).
# Marker area encodes activation; colour encodes class.
points = ax.scatter(
    embedding2[:, 0], embedding2[:, 1],
    c=test_dataset.targets, cmap="tab10",
    s=np.maximum(activ * 80, 0.001),
)
ax.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Discrete colourbar with one band per class label 0-9.
cbar = fig.colorbar(points, ax=ax, ticks=range(10), boundaries=np.arange(11) - 0.5)
cbar.set_ticks(np.arange(10))
cbar.set_ticklabels(list(range(10)))

# Thumbnail of the query image, placed in axes-fraction coordinates.
ins = ax.inset_axes([0.6, 0.75, 0.2, 0.2])
ins.imshow(xx[i].numpy().reshape(28, 28), cmap='gray_r')
ins.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)

# Black cross: position of the query sample in the cosine-metric embedding.
ax.scatter(xtransf2[i, 0], xtransf2[i, 1], marker='x', color='k', s=100)

# Target label of the query sample in the top-right corner.
ax.text(.84, .97, f'T={int(yy[i])}', ha='left', va='top', transform=ax.transAxes, fontsize="large")

plt.show()
