In [None]:
# Mount Google Drive at /content/drive so files stored there (e.g. the
# checkpoint under ./Weights loaded further down) are reachable from this
# Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd drive/My \Drive/ML/

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.utils as vutils
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
sns.set()

# Set device: use the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# VAE Encoder
latent_dim = 100


class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to ``(mu, logvar)``.

    Three stride-2 convolutions (64 -> 128 -> 256 feature maps) feed two
    parallel linear heads of size ``latent_dim``. The heads expect a flattened
    256 x 4 x 4 feature map. The number of input planes is read from the
    module-level ``channels`` when the encoder is instantiated.
    """

    def __init__(self):
        super().__init__()
        # Layer order (and therefore the Sequential indices) is kept stable so
        # that saved state_dict keys continue to match.
        stages = [
            nn.Conv2d(channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        self.main = nn.Sequential(*stages)

        flat_features = 256 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, latent_dim)``."""
        feature_map = self.main(x)
        # Collapse everything except the batch dimension for the linear heads.
        flat = feature_map.view(feature_map.shape[0], -1)
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)
        return mu, logvar

# VAE Decoder
class Decoder(nn.Module):
    """Transposed-convolution decoder that maps latent vectors to images.

    A linear layer expands a ``latent_dim`` vector to a 256 x 4 x 4 feature
    map, which three stride-2 transposed convolutions upsample to ``channels``
    output planes. The final ``Tanh`` bounds outputs to (-1, 1).
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
        # Layer order (and therefore the Sequential indices) is kept stable so
        # that saved state_dict keys continue to match.
        upsampling = [
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*upsampling)

    def forward(self, x):
        """Decode latent vectors of shape ``(batch, latent_dim)`` into images."""
        expanded = self.fc(x)
        # Reshape the flat activation into the (256, 4, 4) map the
        # transposed convolutions start from.
        seed_map = expanded.view(expanded.shape[0], 256, 4, 4)
        return self.main(seed_map)

# VAE
class VAE(nn.Module):
    """Variational autoencoder built from ``Encoder`` and ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``.

        ``sigma`` is recovered from the log-variance as ``exp(0.5 * logvar)``.
        Sampling through this expression keeps the draw differentiable with
        respect to ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input batch ``x``."""
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

# MNIST dataset
# Data settings; `channels` is also read by the Encoder/Decoder constructors.
batch_size = 128
image_size = 28
channels = 1

# Preprocessing pipeline: resize to `image_size`, convert to a tensor, then
# normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training split, downloaded into ./data when it is not already present.
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)

def visualize_latent_space(model, dataloader, num_samples=1000):
    """Plot a 2-D t-SNE embedding of the encoder means, coloured by label.

    Args:
        model: Model exposing ``encoder(images)`` that returns
            ``(mu, logvar)``; only ``mu`` is used.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        num_samples: Maximum number of samples to embed. Exactly this many are
            used unless the dataloader is exhausted first.

    Raises:
        ValueError: If no samples were collected (``num_samples <= 0`` or the
            dataloader yielded nothing).

    Side effects:
        Switches ``model`` to eval mode, writes the figure to
        ``./MNIST_VAE_Clusters.png`` and displays it.
    """
    model.eval()
    latent_chunks = []
    label_chunks = []
    collected = 0

    with torch.no_grad():
        for images, image_labels in dataloader:
            # Count the samples actually seen instead of assuming a batch
            # size: the loader passed in may use a different batch size than
            # the module-level `batch_size`.
            remaining = num_samples - collected
            if remaining <= 0:
                break
            # Trim the final batch so no more than num_samples are embedded.
            images = images[:remaining].to(device)
            mu, _ = model.encoder(images)
            latent_chunks.append(mu.cpu().numpy())
            label_chunks.append(np.asarray(image_labels[:remaining]))
            collected += mu.shape[0]

    if not latent_chunks:
        raise ValueError(
            "No samples collected: num_samples must be positive and the "
            "dataloader must yield at least one batch."
        )

    latent_vectors = np.concatenate(latent_chunks, axis=0)
    labels = np.concatenate(label_chunks, axis=0)

    # Project the latent means to 2-D; the fixed seed keeps the layout
    # reproducible for a given set of samples.
    tsne = TSNE(n_components=2, random_state=0)
    latent_2D = tsne.fit_transform(latent_vectors)

    fig, ax = plt.subplots(figsize=(10, 10))
    scatter = ax.scatter(latent_2D[:, 0], latent_2D[:, 1], c=labels, cmap='viridis', s=20, alpha=0.8)
    legend = ax.legend(*scatter.legend_elements(), title="Digits", loc="upper right", title_fontsize=12)
    ax.add_artist(legend)
    plt.title("Clusters of digits in VAE latent space")
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.savefig('./MNIST_VAE_Clusters.png', format='png', dpi=300)
    plt.show()

# Load the model
# map_location remaps tensors stored in the checkpoint onto `device`, so a
# checkpoint saved from a GPU session still loads on a CPU-only runtime.
vae = VAE()
vae.load_state_dict(torch.load("./Weights/MNIST_VAE.pth", map_location=device))
vae.to(device)

# Build a separate, shuffled loader over the same MNIST training set for
# visualization. NOTE(review): despite the "small" naming, these batches of
# 500 are larger than the training batch_size of 128.
small_batch_size = 500
small_data_loader = DataLoader(mnist_data, batch_size=small_batch_size, shuffle=True)

# Visualize the latent space (num_samples is left at its default of 1000).
visualize_latent_space(vae, small_data_loader)