In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
import itertools
import gc
import pickle

In [None]:
class VectorQuantizerEMA(torch.nn.Module):
    """Vector-quantisation layer whose codebook is learned by exponential moving averages.

    Every spatial feature vector of the input is replaced by its nearest codebook entry
    (squared Euclidean distance). The codebook receives no gradient: in training mode it is
    updated from EMA statistics of the vectors assigned to each entry, with Laplace
    smoothing of the per-entry counts. The encoder is trained through the straight-through
    estimator plus a commitment loss.

    Args:
        num_embeddings: number of codebook entries K.
        embedding_dim: length D of each codebook vector; must equal the input channel count.
        commitment_cost: weight of the commitment loss term.
        decay: EMA decay applied to the cluster counts and to the summed assigned vectors.
        epsilon: Laplace-smoothing constant for the cluster counts.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super().__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = torch.nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        # Kept as a Parameter so state_dict keys / checkpoints are unchanged. It is only ever
        # written in place by the EMA update below and never receives a gradient.
        self._ema_w = torch.nn.Parameter(torch.empty(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        """Quantise ``inputs`` and, in training mode, update the codebook.

        Args:
            inputs: tensor of shape (B, C, H, W) with C == embedding_dim.

        Returns:
            (loss, quantized, perplexity, encodings):
              loss: commitment loss ``commitment_cost * mse(stop_grad(quantized), inputs)``.
              quantized: (B, C, H, W) quantised tensor carrying straight-through gradients to ``inputs``.
              perplexity: exp of the entropy of the batch-average code usage.
              encodings: (B*H*W, K) one-hot code assignments.
        """
        # BCHW -> BHWC so every spatial position becomes one D-dimensional vector.
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        flat_input = inputs.view(-1, self._embedding_dim)

        # Code assignment goes through argmin and is never differentiated, so build no graph for it.
        with torch.no_grad():
            codebook = self._embedding.weight
            # ||x||^2 + ||e||^2 - 2 x.e for every (vector, code) pair.
            distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                         + torch.sum(codebook ** 2, dim=1)
                         - 2 * torch.matmul(flat_input, codebook.t()))

            encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
            encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings,
                                    device=inputs.device, dtype=codebook.dtype)
            encodings.scatter_(1, encoding_indices, 1)

            # Quantise with the codebook as it was *before* this step's EMA update.
            quantized = torch.matmul(encodings, codebook).view(input_shape)

        if self.training:
            # Update the EMA state *in place*: replacing the Parameters with new objects every step
            # would leave the optimizer (and anything else holding model.parameters()) with stale
            # references, and would drag an autograd graph through flat_input for no reason.
            with torch.no_grad():
                cluster_size = self._ema_cluster_size * self._decay + \
                    (1 - self._decay) * torch.sum(encodings, 0)

                # Laplace smoothing keeps every count strictly positive (for a non-empty batch),
                # so the division below is safe even for codes that were never used.
                n = torch.sum(cluster_size)
                cluster_size = ((cluster_size + self._epsilon)
                                / (n + self._num_embeddings * self._epsilon) * n)
                self._ema_cluster_size.copy_(cluster_size)

                dw = torch.matmul(encodings.t(), flat_input)
                self._ema_w.copy_(self._ema_w * self._decay + (1 - self._decay) * dw)

                self._embedding.weight.copy_(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Commitment loss pulls encoder outputs towards their (fixed) assigned codes.
        e_latent_loss = torch.nn.functional.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward uses the codes, backward copies gradients to inputs.
        quantized = inputs + (quantized - inputs).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

In [None]:
class Encoder(torch.nn.Module):
  """Convolutional encoder: (B, n_input_channels, H, W) -> (B, 2 * hidden_size, H', W').

  Four 3x3 convolutions, each followed by GELU; the first and last use stride 2
  (e.g. a 28x28 input yields a 4x4 feature map). ``latent_dim`` is accepted for
  signature compatibility but is not used.
  """

  def __init__(self, n_input_channels, hidden_size, latent_dim):
    super().__init__()
    wide = 2 * hidden_size
    # (in_channels, out_channels, extra Conv2d kwargs) for each 3x3 convolution, in order.
    conv_specs = [
        (n_input_channels, hidden_size, dict(stride=2, padding=1)),
        (hidden_size, hidden_size, {}),
        (hidden_size, wide, {}),
        (wide, wide, dict(stride=2)),
    ]
    layers = []
    for c_in, c_out, extra in conv_specs:
      layers += [torch.nn.Conv2d(c_in, c_out, kernel_size=3, **extra), torch.nn.GELU()]
    self.model = torch.nn.Sequential(*layers)

  def forward(self, x):
    return self.model(x)

In [None]:
class Decoder(torch.nn.Module):
  """Convolutional decoder: (B, 2 * hidden_size, H, W) -> (B, n_input_channels, H', W').

  Three stride-2 transposed convolutions double the spatial size, interleaved with 3x3
  convolutions (e.g. 4x4 -> 28x28). The final Tanh keeps outputs in [-1, 1].
  ``latent_dim`` is accepted for signature compatibility but is not used.
  """

  def __init__(self, n_input_channels, hidden_size, latent_dim):
    super().__init__()
    wide = 2 * hidden_size

    def up(c_in, c_out):
      # stride-2 transposed 3x3 conv that exactly doubles height and width
      return torch.nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1)

    self.model = torch.nn.Sequential(
        up(wide, wide), torch.nn.GELU(),
        torch.nn.Conv2d(wide, wide, kernel_size=3, padding=1), torch.nn.GELU(),
        up(wide, hidden_size), torch.nn.GELU(),
        # unpadded 3x3 conv: shrinks each spatial side by 2
        torch.nn.Conv2d(hidden_size, hidden_size, kernel_size=3), torch.nn.GELU(),
        up(hidden_size, n_input_channels), torch.nn.Tanh(),
    )

  def forward(self, x):
    return self.model(x)

In [None]:
class Autoencoder(torch.nn.Module):
  """VQ-VAE: Encoder -> VectorQuantizerEMA -> Decoder.

  ``embedding_dim`` should equal the encoder's output channel count (2 * hidden_size for the
  Encoder in this notebook), because the quantiser splits encoder features into vectors of
  length ``embedding_dim``. ``latent_dim`` is forwarded to Encoder/Decoder, which do not use it.
  ``epsilon`` is the Laplace-smoothing constant of the quantiser's EMA cluster counts.
  """

  def __init__(self, n_input_channels, hidden_size, latent_dim, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
    super().__init__()

    self.encoder = Encoder(n_input_channels, hidden_size, latent_dim)
    self.decoder = Decoder(n_input_channels, hidden_size, latent_dim)
    # Pass the caller's epsilon through instead of a hard-coded 1e-5.
    self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay, epsilon=epsilon)

  def forward(self, x):
    """Return (reconstruction, one-hot encodings, VQ commitment loss, codebook perplexity)."""
    z = self.encoder(x)
    vq_loss, quantized, perplexity, encodings = self._vq_vae(z)
    x_hat = self.decoder(quantized)
    return x_hat, encodings, vq_loss, perplexity

In [None]:
def visualize_grid(x_batch, scale=4):
  """Tile a batch of images with values in [-1, 1] into one PIL image, enlarged ``scale`` times.

  Args:
    x_batch: tensor of shape (B, C, H, W), values expected in [-1, 1]. It may live on any
      device and may require grad (it is detached before conversion).
    scale: integer upscaling factor applied to the tiled grid (default 4).

  Returns:
    The tiled grid as a PIL image.
  """
  # Map [-1, 1] -> [0, 255]; clamp so out-of-range values cannot wrap around in the uint8 cast.
  pixels = ((x_batch.detach() * 0.5 + 0.5) * 255).clamp(0, 255)
  grid = torchvision.utils.make_grid(pixels)
  im_rec = Image.fromarray(grid.permute(1, 2, 0).cpu().numpy().astype(np.uint8))
  return im_rec.resize((im_rec.size[0] * scale, im_rec.size[1] * scale))

In [None]:
# MNIST pixels are mapped from [0, 1] to [-1, 1], matching the decoder's Tanh output range.
batch_size = 128

# Single-channel normalisation. For a 3-channel dataset use
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) instead.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
# Shuffle training batches so every epoch sees the data in a new random order.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=8)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
# Fixed test order keeps the visualised batch reproducible.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# MNIST labels are the digits 0-9.
classes = tuple(str(digit) for digit in range(10))

In [None]:
# Model / training hyper-parameters.
torch.manual_seed(0)  # reproducible model initialisation and data shuffling

# Only the input channel count depends on the dataset: raw data whose last dim is 3 is treated as RGB.
n_input_channels = 3 if trainset.data.shape[-1] == 3 else 1
hidden_size, latent_dim = 28, 24  # latent_dim is passed through but unused by Encoder/Decoder

# The quantiser's vector length must equal the encoder's output channels (2 * hidden_size).
num_embeddings, embedding_dim, commitment_cost = 128, 2 * hidden_size, 0.25
learning_rate = 1e-3
decay = 0.99
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Autoencoder(n_input_channels, hidden_size, latent_dim, num_embeddings, embedding_dim, commitment_cost, decay).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)

# Full (num_embeddings, embedding_dim) grid for the disabled sweep further down.
# Distinct loop names keep the hyper-parameter globals above intact.
# NOTE: an embedding_dim other than 2 * hidden_size does not match the encoder's output channels
# (the quantiser would error or mis-group features), so such a sweep must also vary hidden_size.
trials = {
    (n_emb, emb_dim): []
    for n_emb in range(5, 350, 5)
    for emb_dim in range(5, 350, 5)
}

In [None]:
# Variance of the training pixels scaled to [0, 1]; it normalises the reconstruction MSE below.
# ndarray-backed datasets use np.var (population variance); tensor-backed ones use torch.var (unbiased).
variance_fn = np.var if type(trainset.data) is np.ndarray else torch.var
data_variance = variance_fn(trainset.data / 255.0)
del variance_fn

In [None]:
num_epochs = 7
log_every = 100     # batches between loss printouts
lr_decay_every = 3  # epochs between 10x learning-rate drops

model.train()  # the VQ codebook only receives EMA updates in training mode
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (x, labels) in enumerate(trainloader):
        x = x.to(device)
        optimizer.zero_grad()
        x_hat, enc, vq_loss, perplexity = model(x)
        # Reconstruction error normalised by the data variance, plus the VQ commitment loss.
        mse_loss = torch.nn.functional.mse_loss(x_hat, x) / data_variance
        loss = mse_loss + vq_loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report the mean loss over exactly the `log_every` batches accumulated since the last report.
        if (i + 1) % log_every == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

    if (epoch + 1) % lr_decay_every == 0:
        # Scale each group's own lr, so every param group is decayed exactly once per drop.
        for g in optimizer.param_groups:
            g['lr'] *= 0.1

In [None]:
# for num_embeddings,embedding_dim in trials.keys():
#     learning_rate = 1e-3
#     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#     model = Autoencoder(n_input_channels, hidden_size, latent_dim, num_embeddings, embedding_dim, commitment_cost).to(device)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)

#     for epoch in range(7):
#         running_loss = 0.0
#         for i, data in enumerate(trainloader, 0):
#             x, labels = data
#             x = x.to(device)
#             optimizer.zero_grad()
#             #x, enc, loss, perplexity
#             x_hat, enc, vq_loss, perplexity = model(x)
#             mse_loss = torch.nn.functional.mse_loss(x, x_hat) / data_variance
#             #mse_loss = mse_loss.sum(dim=[1, 2, 3]).mean(dim=[0])
#             #bce_loss = torch.nn.functional.binary_cross_entropy(x_hat.view(-1, 1024), x.view(-1, 1024), reduction='sum')
#             #loss = bce_loss + kl_loss 
#             loss = mse_loss + vq_loss
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item()
#             if i % batch_size * 5 == batch_size * 5 - 5:
#                 print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / (batch_size * 5):.3f}')
#                 running_loss = 0.0

#         if (epoch + 1) % 3 == 0:
#             for g in optimizer.param_groups:
#                 learning_rate *= 0.1
#                 g['lr'] = learning_rate
                
#     trials[num_embeddings,embedding_dim].append(running_loss / (batch_size * 5))
#     gc.collect()
# with open('tune_hyperperameters.pkl', 'bw') as f:
#         pickle.dump(trials, f)
# print(trials)

# with open('tune_hyperperameters.txt', 'w') as f:
#     for key, value in trials.items():
#         f.write(f'num_embeddings:{key[0]} - embedding_dim:{key[1]} - loss:{value}\n')

# Show the (num_embeddings, embedding_dim) -> loss-list pairs. The lists stay empty while the sweep
# above remains commented out.
# NOTE(review): the commented-out sweep would not run as-is. Its Autoencoder(...) call omits the
# required `decay` argument, and its logging test `i % batch_size * 5 == ...` parses as
# `(i % batch_size) * 5`.
trials.items()

In [None]:
# Reconstruct the first test batch for visual inspection.
# eval() stops VectorQuantizerEMA from updating its codebook during this forward pass.
# The previous train/eval mode is restored afterwards so later training cells are unaffected.
was_training = model.training
model.eval()

x, labels = next(iter(testloader))
x = x.to(device)

with torch.no_grad():
  x_hat = model(x)[0]

model.train(was_training)

In [None]:
# Reconstructed test batch (decoder output).
visualize_grid(x_batch=x_hat)

In [None]:
# Original test batch, for side-by-side comparison with the reconstruction.
visualize_grid(x_batch=x)

In [None]:
# Disabled t-SNE of encoder features. The code is a string literal, so this cell only displays the text.
# NOTE(review): it would not work if re-enabled as-is:
#   - model.encoder(x) returns a 4-D feature map (batch, 2*hidden_size, H, W), so `[0]` keeps only
#     the first image of each batch (presumably a leftover from an encoder returning (mean, logvar));
#   - torch.cat(means, 0) then stacks channels rather than samples;
#   - TSNE.fit_transform needs a 2-D (n_samples, n_features) array.
# Something like model.encoder(x).flatten(1) per batch is probably what was intended. TODO confirm.
'''means, lbls = [], []
for data in testloader:
  x, labels = data
  x = x.to(device)

  with torch.no_grad():
    x_mean = model.encoder(x)[0]
  means.append(x_mean)
  lbls.append(labels)

features = torch.cat(means,0)
features = features.detach().cpu().numpy()
labels = torch.cat(lbls).numpy()

tsne = TSNE(n_components=2).fit_transform(features)'''

In [None]:
# Disabled scatter plot of the t-SNE embedding, coloured by label. It is a string literal, so the cell
# only displays the text.
# NOTE(review): it depends on `tsne` and `labels` from the disabled t-SNE cell. With that cell disabled,
# `tsne` is undefined and `labels` is the single test batch from the reconstruction cell.
'''colors = np.array(["red","green","blue","yellow","pink","black","orange","purple","beige","brown"])
c = np.array([colors[el] for el in labels])
tsne_sel = tsne#[(labels==1)|(labels==4)]
col_sel = c#[(labels==3)|(labels==5)]
plt.scatter(tsne_sel[:,0], tsne_sel[:,1], c=col_sel)'''