## Install packages

In [None]:
# !pip install umap-learn
# !pip install tqdm
# !pip install tensorflowjs

In [None]:
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from IPython import display
from tqdm import tqdm_notebook as tqdm
plt.style.use('ggplot')
plt.style.use('seaborn-colorblind')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torchvision import models



from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP
import re
import math
from time import time
import os, shutil
from glob import glob
from natsort import natsorted
import json


In [None]:
# Fetch (downloading on first use) the Fashion-MNIST train and test splits
# into ./dataset/.
dataset_train, dataset_test = (
    datasets.FashionMNIST('./dataset/', train=is_train, download=True)
    for is_train in (True, False)
)

In [None]:
%%time
n = 10000
imgs = dataset_train.data[:n].float().numpy()
imgs /= 255.0
labels = dataset_train.targets[:n].numpy().astype(dtype=np.uint8)

umap = UMAP(metric='l1', min_dist=0.4)
embeddings = umap.fit_transform(imgs.reshape([-1,28*28]))

print(
    imgs.shape, embeddings.shape, labels.shape, imgs.dtype, embeddings.dtype, labels.dtype
)

## save computed embeddings and dataset
!mkdir ./tmp
embeddings.tofile('tmp/embeddings.bin')
labels.tofile('tmp/labels.bin')
imgs.tofile('tmp/imgs.bin')

## load computed embeddings and dataset

In [None]:
# embeddings = np.fromfile('tmp/embeddings.bin', dtype=np.float32).reshape([-1,2])
# imgs = np.fromfile('tmp/imgs.bin', dtype=np.float32).reshape([-1,28,28])
# labels = np.fromfile('tmp/labels.bin', dtype=np.uint8)

## show the embedding

In [None]:
# Scatter the 2-D embedding, colouring every point by its class label.
plt.figure(figsize=(7, 5))
plt.scatter(
    x=embeddings[:, 0],
    y=embeddings[:, 1],
    s=2,
    c=labels,
    cmap='tab10',
)
plt.colorbar()
plt.axis('square')
plt.title('embeddings')
plt.show()

## Create an embedding dataset

In [None]:
class EmbeddingDataset(Dataset):
    """Dataset yielding aligned (image, embedding, label) triples.

    The three containers are indexed in lockstep, so entry ``i`` of each one
    is expected to describe the same sample.
    """

    def __init__(self, imgs, embeddings, labels):
        # Keep references only; nothing is copied or converted here.
        self.imgs, self.embeddings, self.labels = imgs, embeddings, labels

    def __len__(self):
        """Return the number of samples (first dimension of ``imgs``)."""
        n_samples = self.imgs.shape[0]
        return n_samples

    def __getitem__(self, i):
        """Return the ``(image, embedding, label)`` triple stored at index ``i``."""
        sample = (self.imgs[i], self.embeddings[i], self.labels[i])
        return sample
    
# Split the samples: the first `nTrain` are used for training and the last
# `nTest` are held out for evaluation.
nTrain, nTest = 9000,1000
train_dataset = EmbeddingDataset(imgs[:nTrain], embeddings[:nTrain], labels[:nTrain])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = EmbeddingDataset(imgs[-nTest:], embeddings[-nTest:], labels[-nTest:])
# The test loader must wrap the held-out split; wrapping `train_dataset` here
# would make every "test" plot show training data instead.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

## model

In [None]:
class Lambda(nn.Module):
    """Wrap a plain callable so it can be used as a layer (e.g. in ``nn.Sequential``)."""

    def __init__(self, func):
        super(Lambda, self).__init__()
        # The callable is stored as-is and applied on every forward pass.
        self.func = func

    def forward(self, x):
        """Apply the wrapped callable to ``x`` and return its result."""
        out = self.func(x)
        return out

def preprocess(x):
    """View ``x`` as a batch of single-channel 28x28 images: (-1, 1, 28, 28)."""
    image_shape = (1, 28, 28)
    return x.view(-1, *image_shape)

def flatten(ndim):
    """Build a function that views its input as rows of ``ndim`` values each."""
    def to_rows(x):
        # -1 lets the number of rows be inferred from the element count.
        row_shape = (-1, ndim)
        return x.view(row_shape)
    return to_rows

def postprocess(x):
    """View ``x`` as a batch of 28x28 images: (-1, 28, 28)."""
    side = 28
    return x.view(-1, side, side)

class Autoencoder(nn.Module):
    """Fully connected autoencoder with a 2-value bottleneck.

    ``encoder`` flattens its input to 784 values and reduces it to a 2-value
    code; ``decoder`` expands a 2-value code to 784 sigmoid outputs and views
    them as 28x28.
    """

    def __init__(self):
        super().__init__()

        # Encoder: 784 -> 250 -> 100 -> 10 -> 2, LeakyReLU between the linear
        # layers and no activation on the final code.
        encoder_layers = [
            Lambda(flatten(28*28)),
            # Disabled convolutional front end:
            #   Lambda(preprocess),
            #   nn.Conv2d(1,20,3),
            #   nn.LeakyReLU(),
            #   nn.MaxPool2d(2),
            nn.Linear(784, 250),
            nn.LeakyReLU(),
            nn.Linear(250, 100),
            nn.LeakyReLU(),
            nn.Linear(100, 10),
            nn.LeakyReLU(),
            nn.Linear(10, 2),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: 2 -> 100 -> 100 -> 784. The linear layers sit at indices
        # 0, 2 and 4 of the Sequential; keep that layout stable for code that
        # indexes into `decoder`.
        decoder_layers = [
            nn.Linear(2, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 784),
            nn.Sigmoid(),
            Lambda(postprocess),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to its 2-value code and decode it back."""
        return self.decoder(self.encoder(x))




## Util functions for vis

In [None]:
def imshowGrid(imgs):
    """Draw a batch of images as one padded grid, with the axes hidden."""
    # Move to CPU, drop autograd tracking, and insert a channel dimension at
    # position 1 before handing the batch to make_grid.
    batch = imgs.cpu().detach().unsqueeze(1)
    grid = make_grid(batch, padding=2, pad_value=1)
    # Reorder dimensions (0, 1, 2) -> (1, 2, 0) so the channel comes last for imshow.
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')

## Training

In [None]:
# Train on the GPU whenever CUDA is available.
use_cuda = torch.cuda.is_available()

model = Autoencoder()
model = model.cuda() if use_cuda else model

# Loss modules and helpers made available to the training cell.
mse, l1, kl = nn.MSELoss(), nn.L1Loss(), nn.KLDivLoss()
sigmoid = nn.Sigmoid()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

## show progress bar here
nepoch = 30
epochBar = tqdm(range(nepoch))

In [None]:
# Train the autoencoder so that its 2-value code follows the precomputed
# embedding while the decoder still reconstructs the input image. A progress
# figure (loss curve, learned vs. target embedding, sample reconstructions)
# is redrawn after the first epoch and after every 10th epoch.
gs = GridSpec(2,3,width_ratios=[1.5,1.5,1])
lossHistory = []  # log(loss) of every optimisation step
for epoch in epochBar:
    # Batch variables use their own names so the module-level `imgs`,
    # `embeddings` and `labels` arrays are not overwritten by a mini-batch.
    for batch_imgs, batch_embeddings, batch_labels in train_loader:
        if use_cuda:
            batch_imgs = batch_imgs.cuda()
            batch_embeddings = batch_embeddings.cuda()

        code = model.encoder(batch_imgs)
        recon1 = model.decoder(code) #autoencoder

        # Embedding-matching term plus a much more heavily weighted
        # reconstruction term.
        loss = 1*mse(code, batch_embeddings) + 1*l1(code, batch_embeddings)
        loss += 100*mse(recon1, batch_imgs) + 50*l1(recon1, batch_imgs)
        # Disabled term: decode straight from the target embedding. The
        # forward pass lives with the term so it only runs when it is used.
#         recon2 = model.decoder(batch_embeddings) #decoder from ground truth
#         loss += 4*mse(recon2, batch_imgs) + 2*l1(recon2, batch_imgs)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epochBar.set_postfix({'loss': loss.item()})
        lossHistory.append(np.log(loss.item()))

    if epoch==0 or epoch % 10 == 9:
        with torch.no_grad():
            display.clear_output(wait=True)
            plt.figure(figsize=[20,6])

            plt.subplot(gs[:,0])
            plt.plot(lossHistory)
            plt.title('loss')

            plt.subplot(gs[:,1])
            for test_imgs, test_embeddings, test_labels in test_loader:
                if use_cuda:
                    code = model.encoder(test_imgs.cuda())
                else:
                    code = model.encoder(test_imgs)

                recon = model.decoder(code)
                code = code.cpu().numpy()
                # Target embedding in grey underneath, learned code on top.
                plt.scatter(test_embeddings[:,0], test_embeddings[:,1], s=1, c='#333333', zorder=1)
                # Pin the colour range to the 10 class ids: each batch is a
                # separate scatter call and would otherwise be normalised by
                # its own min/max label, giving inconsistent class colours.
                plt.scatter(code[:,0], code[:,1], s=1, c=test_labels, cmap='tab10',
                            vmin=0, vmax=9, zorder=2)

            plt.colorbar()
            plt.axis('equal')
            # `epoch` is 0-based; show the number of completed epochs.
            plt.title('learned embedding (epoch {}/{})'.format(epoch + 1, nepoch))
#             plt.xlim([-8,6])
#             plt.ylim([-4,6])

            # Originals and reconstructions of the last evaluated batch.
            plt.subplot(gs[0,2])
            imshowGrid(test_imgs)
            plt.title('original')
            plt.subplot(gs[1,2])
            imshowGrid(recon)
            plt.title('learned')
            plt.show()
    


## Reload embeddings, get parameters for frontend

In [None]:
# Restore the arrays from the raw binaries in tmp/. The files carry no
# metadata, so dtype and shape are given explicitly here.
embeddings = np.fromfile('tmp/embeddings.bin', dtype=np.float32).reshape(-1, 2)
imgs = np.fromfile('tmp/imgs.bin', dtype=np.float32).reshape(-1, 28, 28)
labels = np.fromfile('tmp/labels.bin', dtype=np.uint8)

# Per-axis extent of the embedding.
xmin, ymin = embeddings.min(axis=0)
xmax, ymax = embeddings.max(axis=0)
vmin, vmax = torch.tensor([xmin, ymin]), torch.tensor([xmax, ymax])

print(xmin, xmax, ymin, ymax)

## Dump data to js files 

In [None]:
# Output directory for the exported frontend files; the trailing '/' matters
# because file paths are built by plain string concatenation.
data_dir = './reconstruct-fashion/'
# Start from an empty directory. Removing only when it exists keeps the first
# run quiet (a bare `rm -r` errors on a missing path) and avoids relying on a
# POSIX shell.
if os.path.isdir(data_dir):
    shutil.rmtree(data_dir)
os.makedirs(data_dir)

In [None]:
# constants.js: the x / y extent of the embedding as two-element arrays.
with open(data_dir+'constants.js', 'w') as f:
    f.write('let constants = {};\n')
    for header, bounds in (
        ('constants.xrange = \n', (xmin, xmax)),
        ('constants.yrange = \n', (ymin, ymax)),
    ):
        f.write(header)
        # float() turns the values into plain Python floats for json.
        json.dump([float(b) for b in bounds], f)
        f.write(';\n')

# data.js: the embedding coordinates and the class label of every sample.
with open(data_dir+'data.js', 'w') as f:
    f.write('let data = {}; \ndata.embeddings = \n')
    json.dump(embeddings.tolist(), f)
    f.writelines(['\n\n', 'data.labels = \n'])
    json.dump(labels.tolist(), f)

## Convert to a tfjs model

In [None]:
from tensorflow import keras
import tensorflowjs as tfjs

## define a Keras model that is equivalent to the PyTorch decoder
model_keras = keras.models.Sequential()
model_keras.add(keras.layers.Dense(100, activation='relu', input_shape=(2,)))
model_keras.add(keras.layers.Dense(100, activation='relu'))
model_keras.add(keras.layers.Dense(784, activation='sigmoid'))

## copy the weights
print([w.shape for w in model_keras.get_weights()])
# Collect [weight, bias] of the three Linear layers (decoder indices 0, 2, 4)
# as one flat list. Each array is transposed before being handed to Keras;
# for the 1-D biases the transpose is a no-op.
weights = [
    param.data.cpu().numpy().T
    for layer_idx in [0, 2, 4]
    for param in (model.decoder[layer_idx].weight, model.decoder[layer_idx].bias)
]
model_keras.set_weights(weights)

# model_keras.save('model_keras.h5')
tfjs.converters.save_keras_model(model_keras, data_dir)

## Compress into a zip (optional)

In [None]:
# Archive name: `data_dir` minus its last character (the trailing '/') + '.zip'.
zip_filename = data_dir[:-1] + '.zip'
# Delete any previous archive first, presumably so stale entries do not
# survive; `rm` prints an error (harmlessly) when no archive exists yet.
!rm $zip_filename
# Recursively zip the whole output directory.
!zip -r $zip_filename $data_dir