# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Hyperparameters
latent_dim = 20  # dimensionality of the VAE latent code
batch_size = 256  # images per mini-batch, shared by both loaders
learning_rate = 5e-4  # Adam step size
epochs = 50  # number of passes over the training set

# Data Loading: FashionMNIST train/test splits, converted to tensors by
# ToTensor() and downloaded into ./data if not already present.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, test_dataset = (
    datasets.FashionMNIST(root='./data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)
# Only the training loader shuffles; the test set is iterated in a fixed order.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (test_dataset, False))
)

# VAE Model
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder flattens each input and passes it through one ReLU hidden
    layer. Two linear heads then produce the mean and log-variance of a
    diagonal Gaussian posterior. The decoder maps a latent sample back to
    ``input_dim`` values squashed into (0, 1) by a sigmoid. Its output is
    therefore flat: ``(batch, input_dim)``.

    Args:
        latent_dim: size of the latent code.
        input_dim: number of values per flattened input (default 28 * 28).
        hidden_dim: width of the hidden layer in encoder and decoder
            (default 400).
    """

    def __init__(self, latent_dim, input_dim=28 * 28, hidden_dim=400):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        # Submodule names and Sequential layouts are kept as before so
        # state_dict keys (encoder.1.*, fc_mu.*, fc_logvar.*, decoder.0.*,
        # decoder.2.*) stay loadable.
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * sigma`` with ``eps ~ N(0, I)``.

        This is the reparameterization trick. It samples in both train and
        eval mode.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2), so this is sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            ``(reconstruction, mu, logvar)``. ``reconstruction`` has shape
            ``(batch, input_dim)``; ``mu`` and ``logvar`` have shape
            ``(batch, latent_dim)``.
        """
        hidden = self.encoder(x)
        mu, logvar = self.fc_mu(hidden), self.fc_logvar(hidden)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

# Loss Function
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss summed over the batch: BCE reconstruction plus KL.

    Args:
        recon_x: decoder output with values in (0, 1), e.g. shape
            ``(batch, features)``.
        x: target images with values in [0, 1]. Any shape holding the same
            number of elements as ``recon_x`` works, e.g. ``(batch, 1, 28, 28)``.
        mu: posterior means.
        logvar: posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor. It is the summed binary cross-entropy plus the
        closed-form KL(N(mu, sigma^2) || N(0, I)). Both terms are summed over
        the batch, not averaged.
    """
    # reshape_as (instead of view(-1, 28 * 28)) derives the target shape from
    # the reconstruction. This works for any image size and also for
    # non-contiguous targets, where view() would raise.
    reconstruction_loss = F.binary_cross_entropy(recon_x, x.reshape_as(recon_x), reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_loss + kl_divergence

# Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Record the mean per-sample loss for each epoch. loss_function sums over the
# batch, so the running total is divided by the number of training samples.
train_losses = []
num_train_samples = len(train_loader.dataset)
for epoch_number in range(1, epochs + 1):
    model.train()
    epoch_total = 0
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        batch_loss = loss_function(reconstruction, images, mu, logvar)
        batch_loss.backward()
        epoch_total += batch_loss.item()
        optimizer.step()

    epoch_mean = epoch_total / num_train_samples
    train_losses.append(epoch_mean)
    print(f'Epoch {epoch_number}, Loss: {epoch_mean:.4f}')

# Plot Loss Curve: one point per epoch, with epochs numbered from 1.
plt.plot(range(1, epochs + 1), train_losses)
plt.title('Training Loss Curve')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()

# Reconstruction and Visualization
def visualize_reconstructions(model, data_loader, num_images=10):
    """Show images from the first batch of ``data_loader`` above their reconstructions.

    The top row shows the original images and the bottom row shows the
    model's reconstructions. Each is reshaped to 28x28 and drawn in grayscale.

    Args:
        model: VAE whose forward pass returns ``(reconstruction, mu, logvar)``.
        data_loader: iterable of ``(images, labels)`` batches. Only the first
            batch is used.
        num_images: number of columns to draw (default 10). It is capped at
            the first batch's size, so a small batch no longer raises
            IndexError.

    Raises:
        ValueError: if there is no image to draw (``num_images`` < 1 or the
            first batch is empty).
    """
    model.eval()
    # Run on the model's own device instead of relying on a module-level global.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        data, _ = next(iter(data_loader))
        data = data.to(model_device)
        recon_data, _, _ = model(data)
    data, recon_data = data.cpu(), recon_data.cpu()

    n_cols = min(num_images, len(data))
    if n_cols < 1:
        raise ValueError("need at least one image to visualize")
    # squeeze=False keeps ``axes`` 2-D even when only one column is drawn.
    _, axes = plt.subplots(2, n_cols, figsize=(n_cols, 2), squeeze=False)
    for col in range(n_cols):
        axes[0, col].imshow(data[col].reshape(28, 28), cmap='gray')
        axes[1, col].imshow(recon_data[col].reshape(28, 28), cmap='gray')
        axes[0, col].axis('off')
        axes[1, col].axis('off')
    plt.show()

# Compare test images with their reconstructions.
visualize_reconstructions(model=model, data_loader=test_loader)

# Latent Space Visualization
def visualize_latent_space(model, data_loader, random_state=None):
    """Scatter-plot a 2-D t-SNE embedding of the posterior means, colored by label.

    Args:
        model: VAE whose forward pass returns ``(reconstruction, mu, logvar)``.
        data_loader: iterable of ``(images, labels)`` batches. Every batch is
            encoded, and t-SNE runs over all of the collected means.
        random_state: forwarded to ``TSNE``. Pass an int for a reproducible
            layout; the default ``None`` keeps sklearn's default behavior.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    # Run on the model's own device instead of relying on a module-level global.
    model_device = next(model.parameters()).device
    latents, labels = [], []
    with torch.no_grad():
        for data, label in data_loader:
            _, mu, _ = model(data.to(model_device))
            # Each point is placed at the posterior mean, not at a sampled z.
            latents.append(mu.cpu().numpy())
            labels.append(label.numpy())
    if not latents:
        raise ValueError("data_loader yielded no batches")
    latents, labels = np.vstack(latents), np.hstack(labels)
    tsne = TSNE(n_components=2, random_state=random_state)
    latents_2d = tsne.fit_transform(latents)
    plt.scatter(latents_2d[:, 0], latents_2d[:, 1], c=labels, cmap='jet', alpha=0.5)
    plt.colorbar()
    plt.title('Latent Space Visualization')
    plt.show()

# Embed the test set's latent means in 2-D.
visualize_latent_space(model=model, data_loader=test_loader)

# Generate New Images
def generate_images(model, num_images=10):
    """Decode latent vectors drawn from N(0, I) and show them in a single row.

    Args:
        model: VAE exposing ``decoder``, an ``nn.Sequential`` whose first layer
            is an ``nn.Linear`` that takes the latent code. The decoder must
            produce 784 values per sample so each can be shown as 28x28.
        num_images: number of samples to draw and display; must be >= 1.

    Raises:
        ValueError: if ``num_images`` < 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    model.eval()
    model_device = next(model.parameters()).device
    # Read the latent size from the decoder's input layer instead of the
    # module-level ``latent_dim``.
    z_dim = model.decoder[0].in_features
    with torch.no_grad():
        z = torch.randn(num_images, z_dim).to(model_device)
        samples = model.decoder(z).cpu()
    # squeeze=False keeps ``axes`` 2-D. Without it, plt.subplots(1, 1)
    # returns a bare Axes and indexing it would fail for num_images == 1.
    _, axes = plt.subplots(1, num_images, figsize=(num_images, 1), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(samples[i].reshape(28, 28), cmap='gray')
        ax.axis('off')
    plt.show()

# Sample ten new images from the prior.
generate_images(model=model, num_images=10)