In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [None]:
class Encoder(torch.nn.Module):
  """Convolutional encoder producing the parameters of a diagonal Gaussian.

  Args:
    n_input_channels: number of channels of the input images.
    hidden_size: channel width of the first two convolutions; the last two
      use twice this width.
    latent_dim: size of the latent vector.

  The linear heads expect a 4x4 feature map after the conv stack (hence the
  factor 16). With the strides/paddings below this is what a 28x28 input
  yields: 28 -> 14 -> 12 -> 10 -> 4. A 32x32 input would yield 5x5 (factor 25)
  and needs the heads resized accordingly.
  """

  def __init__(self, n_input_channels, hidden_size, latent_dim):
    super().__init__()
    wide = 2 * hidden_size

    conv_stack = [
        torch.nn.Conv2d(n_input_channels, hidden_size, kernel_size=3, stride=2, padding=1),
        torch.nn.GELU(),
        torch.nn.Conv2d(hidden_size, hidden_size, kernel_size=3),
        torch.nn.GELU(),
        torch.nn.Conv2d(hidden_size, wide, kernel_size=3),
        torch.nn.GELU(),
        torch.nn.Conv2d(wide, wide, kernel_size=3, stride=2),
        torch.nn.GELU(),
        torch.nn.Flatten(),
    ]
    self.model = torch.nn.Sequential(*conv_stack)

    # Flattened size of the (wide, 4, 4) feature map.
    flat_features = wide * 16
    self.linear_mean = torch.nn.Linear(flat_features, latent_dim)
    self.linear_logvar = torch.nn.Linear(flat_features, latent_dim)

  def forward(self, x):
    """Return `(mean, logvar)`, each of shape `(batch, latent_dim)`."""
    features = self.model(x)
    return self.linear_mean(features), self.linear_logvar(features)

In [None]:
class Decoder(torch.nn.Module):
  """Transposed-convolution decoder mapping a latent vector back to an image.

  Args:
    n_input_channels: number of channels of the generated image.
    hidden_size: base channel width (the first layers use twice this width).
    latent_dim: size of the latent vector.

  The latent is projected to a (2*hidden_size, 4, 4) map and upsampled as
  4 -> 8 -> 8 -> 16 -> 14 -> 28, so the output is 28x28 with values in
  [-1, 1] (Tanh).
  """

  def __init__(self, n_input_channels, hidden_size, latent_dim):
    super().__init__()
    wide = 2 * hidden_size

    self.linear = torch.nn.Sequential(
        torch.nn.Linear(latent_dim, wide * 16),
        torch.nn.GELU(),
    )

    # Every transposed convolution doubles the spatial resolution.
    upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
    self.model = torch.nn.Sequential(
        torch.nn.ConvTranspose2d(wide, wide, **upsample),
        torch.nn.GELU(),
        torch.nn.Conv2d(wide, wide, kernel_size=3, padding=1),
        torch.nn.GELU(),
        torch.nn.ConvTranspose2d(wide, hidden_size, **upsample),
        torch.nn.GELU(),
        # Unpadded on purpose: shrinks 16 -> 14 so the last upsampling gives 28.
        torch.nn.Conv2d(hidden_size, hidden_size, kernel_size=3),
        torch.nn.GELU(),
        torch.nn.ConvTranspose2d(hidden_size, n_input_channels, **upsample),
        torch.nn.Tanh(),
    )

  def forward(self, x):
    """Decode latents of shape `(batch, latent_dim)` into images."""
    batch = x.shape[0]
    seed_map = self.linear(x).reshape(batch, -1, 4, 4)
    return self.model(seed_map)

In [None]:
class Autoencoder(torch.nn.Module):
  """Variational autoencoder built from `Encoder` and `Decoder`.

  Args:
    n_input_channels: number of image channels.
    hidden_size: base channel width passed to encoder and decoder.
    latent_dim: size of the latent vector.
  """

  def __init__(self, n_input_channels, hidden_size, latent_dim):
    super().__init__()

    self.encoder = Encoder(n_input_channels, hidden_size, latent_dim)
    self.decoder = Decoder(n_input_channels, hidden_size, latent_dim)

  def forward(self, x):
    """Encode `x`, pick a latent code and decode it.

    In training mode the latent is sampled with the reparameterisation trick
    (`mean + std * eps`); in eval mode the posterior mean is used directly.

    Returns:
      `(reconstruction, means, logvar)`.
    """
    means, logvar = self.encoder(x)
    if self.training:
      stds = torch.exp(logvar / 2)
      # randn_like draws the noise on the same device/dtype as the encoder
      # output, so the module does not depend on a notebook-global `device`.
      eps = torch.randn_like(stds)
      z = means + stds * eps
    else:
      z = means
    return self.decoder(z), means, logvar

In [None]:
def visualize_grid(x_batch):
  """Render a batch of images as one PIL image grid, enlarged 4x.

  `x_batch` is expected in the [-1, 1] range (inputs normalised with
  mean 0.5 / std 0.5, or the decoder's Tanh output); it is mapped back to
  [0, 255] before conversion. For images already in [0, 1] (e.g. a Sigmoid
  decoder) the `* 0.5 + 0.5` step would have to be dropped.
  """
  pixels = (x_batch * 0.5 + 0.5) * 255
  grid = torchvision.utils.make_grid(pixels)
  # make_grid returns (C, H, W); PIL wants (H, W, C) uint8.
  hwc = grid.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
  im_rec = Image.fromarray(hwc)
  width, height = im_rec.size
  return im_rec.resize((width * 4, height * 4))

In [None]:
# Model and optimisation hyper-parameters.
n_input_channels = 1
hidden_size = 28
latent_dim = 128
batch_size = 128
learning_rate = 1e-3

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

model = Autoencoder(n_input_channels, hidden_size, latent_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)

In [None]:
# MNIST is single-channel, so one mean/std pair is enough. This maps pixel
# values from [0, 1] to [-1, 1], matching the decoder's Tanh output range.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))
     ])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=8)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Human-readable names for the ten MNIST labels (the digits 0-9).
classes = tuple(str(digit) for digit in range(10))

In [None]:
# Train the VAE: loss = per-sample summed squared error + KL(q(z|x) || N(0, I)),
# both averaged over the batch.
n_epochs = 30
log_every = 100        # print the mean loss every `log_every` iterations
lr_decay_every = 20    # multiply the learning rate by `lr_decay_factor` every N epochs
lr_decay_factor = 0.1

# Make sure the latent is sampled (reparameterisation) even if the model was
# switched to eval mode by an earlier run of a later cell.
model.train()

for epoch in range(n_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        x = data[0].to(device)  # labels are not needed for training
        optimizer.zero_grad()

        x_hat, means, logvar = model(x)
        # Closed-form KL divergence between N(means, exp(logvar)) and N(0, I).
        kl_loss = -0.5 * torch.sum(1 + logvar - means.pow(2) - logvar.exp(), dim=1)
        kl_loss = torch.mean(kl_loss)
        mse_loss = torch.nn.functional.mse_loss(x, x_hat, reduction="none")
        mse_loss = mse_loss.sum(dim=[1, 2, 3]).mean(dim=[0])
        loss = mse_loss + kl_loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # The divisor must equal the number of iterations accumulated since
        # the last report, otherwise the printed average is scaled wrongly.
        if (i + 1) % log_every == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

    # Step decay after every `lr_decay_every` completed epochs (not after the
    # very first one). Each group's own rate is scaled exactly once.
    if (epoch + 1) % lr_decay_every == 0:
        for g in optimizer.param_groups:
            g['lr'] *= lr_decay_factor

In [None]:
# Reconstruct the first test batch. The model is switched to eval mode so the
# decoder receives the posterior mean instead of a random sample.
model.eval()

x, labels = next(iter(testloader))
x = x.to(device)

with torch.no_grad():
  x_hat = model(x)[0]

In [None]:
# Draw `batch_size` latent codes from the standard-normal prior and decode them.
prior_loc = torch.zeros(latent_dim).to(device)
prior_scale = torch.ones(latent_dim).to(device)
p = torch.distributions.Normal(prior_loc, prior_scale)
z = p.rsample((batch_size,))

with torch.no_grad():
  x_gen = model.decoder(z)

In [None]:
# Images decoded from latent codes sampled from the prior.
visualize_grid(x_gen)

In [None]:
# Model reconstructions (`x_hat`) of the test batch `x`.
visualize_grid(x_hat)

In [None]:
# The original test batch, for comparison with the reconstructions.
visualize_grid(x)

In [None]:
# Encode the whole test set to its posterior means and embed them in 2-D.
# Loop variables are named separately from `x` / `labels` so the batch shown
# by the visualisation cells is not overwritten while iterating.
means, lbls = [], []
for x_batch, y_batch in testloader:
  with torch.no_grad():
    x_mean = model.encoder(x_batch.to(device))[0]
  # Move each batch to the CPU right away so latents do not pile up on the GPU.
  means.append(x_mean.cpu())
  lbls.append(y_batch)

features = torch.cat(means, 0).numpy()
labels = torch.cat(lbls).numpy()

# One latent vector per label, otherwise the colouring of the plot is wrong.
assert features.shape[0] == labels.shape[0], "features and labels are misaligned"
print(features.shape, labels.shape)

# Fixed random_state makes the embedding reproducible across runs.
tsne = TSNE(n_components=2, learning_rate='auto', init='pca', perplexity=30,
            random_state=0).fit_transform(features)

In [None]:
# One colour per class; the integer label indexes directly into `colors`.
colors = np.array(["red","green","blue","yellow","pink","black","orange","purple","beige","brown"])
c = colors[labels]

# To plot only some classes, index BOTH arrays with the same boolean mask,
# e.g. mask = (labels==1)|(labels==4); tsne[mask], c[mask].
tsne_sel = tsne
col_sel = c

xs, ys = tsne_sel[:, 0], tsne_sel[:, 1]
plt.scatter(xs, ys, c=col_sel)