<a href="https://colab.research.google.com/github/prashamsa0512/NLP/blob/main/assignment%2010.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal-Gaussian posterior.

    The encoder maps a flattened input to ``(mu, log_var)`` of ``q(z | x)``.
    The decoder maps a latent vector back to ``input_dim`` values squashed into
    (0, 1) by a sigmoid, which suits binary cross-entropy on pixel intensities
    such as MNIST.

    Args:
        input_dim: Number of features per flattened input (784 = 28*28 for MNIST).
        hidden_dim: Width of the hidden layers in encoder and decoder.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Kept so forward() (and callers) can use the real sizes instead of
        # hard-coded MNIST constants. Plain ints do not enter the state_dict.
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder: input -> hidden representation shared by both latent heads.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Latent space parameters. NOTE: ``fc_var`` predicts the *log*-variance
        # (see reparameterize); the attribute name is kept for state_dict
        # compatibility with existing checkpoints.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder: latent -> hidden -> reconstruction in (0, 1).
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # For MNIST images (values between 0 and 1)
        )

    def encode(self, x):
        """Return ``(mu, log_var)`` for a batch of flattened inputs ``x``."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``, differentiably."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructions of size ``input_dim``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent and decode.

        ``x`` is reshaped to ``(-1, input_dim)``, so image batches such as
        ``(B, 1, 28, 28)`` are accepted. Returns ``(recon, mu, log_var)``.
        """
        # Use the configured input size instead of a hard-coded 784, and
        # reshape (not view) so non-contiguous inputs also work.
        mu, log_var = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

def train_vae(model, train_loader, num_epochs=50, learning_rate=1e-3, device='cuda',
              log_every=10):
    """Train a VAE with Adam on the negative ELBO (summed BCE + analytic KL).

    Args:
        model: Module whose ``forward(x)`` returns ``(recon, mu, log_var)``.
            ``recon`` must hold one row per sample with values in [0, 1],
            as binary cross-entropy requires.
        train_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: Device the batches are moved to; should match the model's.
        log_every: Print epoch statistics every ``log_every`` epochs and after
            the final epoch. ``0`` or a negative value disables printing.

    Returns:
        A list with one dict per epoch. Each dict holds the per-sample
        averages ``'loss'``, ``'recon_loss'`` and ``'kl_loss'``.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    history = []

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        recon_loss_total = 0.0
        kl_loss_total = 0.0
        num_samples = 0

        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()

            # Forward pass
            recon_batch, mu, log_var = model(data)

            # Flatten the targets to match the decoder output instead of
            # assuming 784 (28x28) features, so other input sizes work too.
            target = data.reshape_as(recon_batch)

            # Reconstruction loss (binary cross entropy), summed over the batch
            recon_loss = nn.functional.binary_cross_entropy(
                recon_batch, target, reduction='sum')

            # Analytic KL(q(z|x) || N(0, I)), summed over batch and latent dims
            kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

            # Total loss (negative ELBO)
            loss = recon_loss + kl_loss

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            recon_loss_total += recon_loss.item()
            kl_loss_total += kl_loss.item()
            num_samples += recon_batch.size(0)

        # Losses are summed per batch, so average over the samples actually
        # seen. This stays correct with drop_last or custom samplers, unlike
        # len(dataset). max() guards against an empty loader.
        denom = max(num_samples, 1)
        avg_loss = total_loss / denom
        avg_recon_loss = recon_loss_total / denom
        avg_kl_loss = kl_loss_total / denom
        history.append({
            'loss': avg_loss,
            'recon_loss': avg_recon_loss,
            'kl_loss': avg_kl_loss,
        })

        # Print epoch statistics (always including the final epoch)
        is_last = epoch + 1 == num_epochs
        if log_every > 0 and ((epoch + 1) % log_every == 0 or is_last):
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Average Loss: {avg_loss:.4f}')
            print(f'Reconstruction Loss: {avg_recon_loss:.4f}')
            print(f'KL Loss: {avg_kl_loss:.4f}\n')

    return history

def visualize_results(model, test_loader, device='cuda', num_images=8, image_shape=(28, 28)):
    """Plot original, reconstructed and prior-sampled images in a 3-row grid.

    Takes the first batch of ``test_loader`` and shows up to ``num_images``
    originals and their reconstructions. It also decodes the same number of
    latent vectors drawn from ``N(0, I)``.

    Args:
        model: VAE exposing ``forward(x) -> (recon, mu, log_var)`` and
            ``decode(z)``.
        test_loader: Iterable of ``(data, label)`` batches; only the first
            batch is used.
        device: Device the batch and latent samples are placed on.
        num_images: Maximum number of columns to plot. Clamped to the batch size.
        image_shape: Shape each flattened image is reshaped to for display.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Get a batch of test data
            data, _ = next(iter(test_loader))
            data = data.to(device)

            # Reconstruct images; mu also tells us the model's latent size.
            recon_batch, mu, _ = model(data)

            # Never index past the batch (the loader may yield fewer images).
            n = min(num_images, data.size(0))

            # Generate new images from random latent vectors of the model's
            # real latent dimensionality (not a hard-coded 20).
            z = torch.randn(n, mu.size(1), device=device)
            generated = model.decode(z)
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    rows = (
        ('Original', data),
        ('Reconstructed', recon_batch),
        ('Generated', generated),
    )

    # Plot results: one row per image kind, one column per example.
    plt.figure(figsize=(12, 6))
    for row, (title, images) in enumerate(rows):
        for col in range(n):
            plt.subplot(len(rows), n, row * n + col + 1)
            plt.imshow(images[col].cpu().numpy().reshape(image_shape), cmap='gray')
            plt.axis('off')
            if col == 0:
                plt.title(title)

    plt.tight_layout()
    plt.show()

def main():
    """Train a VAE on MNIST, then plot a test batch, its reconstructions and samples.

    MNIST is downloaded to ``./data`` if missing. Training uses the defaults of
    ``train_vae``. Everything runs on CUDA when available, otherwise on CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    to_tensor = transforms.Compose([transforms.ToTensor()])

    def load_mnist(is_train):
        # Same root/transform for both splits; only the split flag differs.
        return torchvision.datasets.MNIST(
            root='./data', train=is_train, transform=to_tensor, download=True)

    train_split = load_mnist(True)
    test_split = load_mnist(False)

    # Large shuffled batches for training; one small batch is enough for plots.
    train_batches = DataLoader(train_split, batch_size=128, shuffle=True)
    test_batches = DataLoader(test_split, batch_size=8, shuffle=True)

    vae = VAE().to(device)
    train_vae(vae, train_batches, device=device)
    visualize_results(vae, test_batches, device=device)

# Run the train-and-visualize pipeline only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()