<a href="https://colab.research.google.com/github/tahirabatool123/Task/blob/main/dlgm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive so later cells can read/write files there.
# Depends on the `google.colab` package, so this cell is only expected to run inside Colab.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install torch torchvision --quiet


In [None]:
# VAE for MNIST - PyTorch implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import numpy as np
import random

# Reproducibility: seed Python's, NumPy's and PyTorch's RNGs with one shared value
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ------------------------
# Model Class
# ------------------------
class VAEModel(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    The encoder reduces the image with three strided convolutions
    (28 -> 14 -> 7 -> 4) and two linear heads map the flattened features to
    the mean and log-variance of the latent Gaussian. The decoder maps a
    latent vector back to a 128x4x4 feature map, upsamples it with three
    transposed convolutions (4 -> 8 -> 16 -> 32) and the result is
    center-cropped to 28x28.

    Args:
        latent_dim: size of the latent vector.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim
        flat_features = 128 * 4 * 4  # 128 channels on a 4x4 grid after the encoder

        # Encoder: strided convolutions, each followed by a ReLU
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),    # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),   # 14 -> 7
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 7 -> 4
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Linear heads producing the latent mean and log-variance
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder: linear projection, then transposed convolutions
        self.fc_dec = nn.Linear(latent_dim, flat_features)
        decoder_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 4 -> 8
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),   # 16 -> 32
            nn.ReLU(),
            # size-preserving conv down to one channel; cropping to 28x28 happens in decode()
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),  # keep output pixels in [0, 1]
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent Gaussian for a batch of images."""
        features = self.encoder(x)                    # [B, 128, 4, 4]
        flat = features.view(len(features), -1)       # [B, 128*4*4]
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        ``logvar`` is ``log(sigma^2)``, hence the factor 0.5 when recovering sigma.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def decode(self, z):
        """Map latent vectors to images, center-cropping the decoder output to 28x28."""
        feature_map = self.fc_dec(z).view(-1, 128, 4, 4)
        image = self.decoder(feature_map)
        # Only the last (width) dimension is inspected to decide whether to crop
        if image.shape[-1] == 28:
            return image
        height, width = image.shape[-2:]
        top = (height - 28) // 2
        left = (width - 28) // 2
        return image[:, :, top:top + 28, left:left + 28]

    def forward(self, x):
        """Encode, sample a latent, decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

# ------------------------
# Loss Function
# ------------------------
def vae_loss_function(reconstruction, mu, log_var, original):
    """Compute the VAE objective (binary cross-entropy + KL divergence).

    Args:
        reconstruction: decoder output with values in [0, 1], same shape as ``original``.
        mu: latent means.
        log_var: latent log-variances.
        original: target images with values in [0, 1]; dimension 0 is the batch.

    Returns:
        A 3-tuple ``(total_loss, recon_per_image, kl_per_image)``:
        ``total_loss`` is a tensor holding BCE + KL summed over the whole batch
        (the quantity to backpropagate); the other two are Python floats with
        each term divided by the batch size, intended for logging.
    """
    n_images = original.size(0)

    # Reconstruction term: BCE summed over every pixel of every image in the batch
    recon_term = F.binary_cross_entropy(reconstruction, original, reduction='sum')

    # KL(q(z|x) || N(0, I)) in closed form, summed over latent dims and batch
    kl_integrand = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_term = -0.5 * torch.sum(kl_integrand)

    return (
        recon_term + kl_term,
        (recon_term / n_images).item(),
        (kl_term / n_images).item(),
    )

# ------------------------
# Training Function
# ------------------------
def train_vae(model, dataloader, epochs=30, lr=1e-3, print_every=1):
    """Train ``model`` with Adam on the batches yielded by ``dataloader``.

    Args:
        model: module whose forward pass returns ``(reconstruction, mu, logvar)``.
        dataloader: iterable yielding ``(data, label)`` pairs; the label is ignored.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        print_every: print the epoch metrics every this many epochs;
            ``0`` or ``None`` disables printing.

    Returns:
        ``(model, history)`` where ``history`` is a dict with keys
        ``'recon_per_image'`` and ``'kl_per_image'``, each a list holding one
        float per epoch.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples in an epoch.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    history = {'recon_per_image': [], 'kl_per_image': []}
    for epoch in range(1, epochs + 1):
        recon_sum = 0.0
        kl_sum = 0.0
        images_seen = 0
        for data, _ in dataloader:
            data = data.to(device)
            batch_size = data.size(0)
            optimizer.zero_grad()
            recon, mu, logvar = model(data)
            loss, recon_per_image, kl_per_image = vae_loss_function(recon, mu, logvar, data)
            loss.backward()
            optimizer.step()
            # Weight each batch by its size: the last batch may be smaller, so a
            # plain mean of per-batch means would not be a true per-image average.
            recon_sum += recon_per_image * batch_size
            kl_sum += kl_per_image * batch_size
            images_seen += batch_size
        avg_recon = recon_sum / images_seen
        avg_kl = kl_sum / images_seen
        history['recon_per_image'].append(avg_recon)
        history['kl_per_image'].append(avg_kl)
        # Guard against print_every == 0 / None so logging can be switched off
        if print_every and epoch % print_every == 0:
            print(f"Epoch {epoch}/{epochs}  Recon_per_image: {avg_recon:.4f}  KL_per_image: {avg_kl:.4f}")
    return model, history

# ------------------------
# Sample Generation Function
# ------------------------
def generate_samples(model, num_samples=10):
    """Decode ``num_samples`` latent vectors drawn from a standard normal.

    The model is switched to eval mode (and left in it). Latents are sampled on
    the CPU and then moved to the device that holds the model's parameters, so
    the function works wherever the model lives instead of depending on a
    module-level ``device`` variable.

    Args:
        model: object exposing ``latent_dim`` and ``decode(z)``.
        num_samples: number of samples to generate.

    Returns:
        The decoded samples as a CPU tensor; dimension 0 has size ``num_samples``.
    """
    model.eval()
    # A model without parameters has no device of its own; default to the CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        z = torch.randn(num_samples, model.latent_dim).to(target_device)
        samples = model.decode(z).cpu()
    return samples

# ------------------------
# Run Experiment
# ------------------------
def run_vae_experiment(batch_size=128, epochs=30, latent_dim=20, lr=1e-3, samples_to_generate=10):
    """Train a VAE on MNIST, evaluate it on the test split and draw samples.

    Args:
        batch_size: batch size for both the train and test loaders.
        epochs: number of training epochs.
        latent_dim: size of the VAE latent vector.
        lr: Adam learning rate.
        samples_to_generate: number of images to decode from random latents.

    Returns:
        ``(model, generated_samples, recon_tensor, avg_kl_loss, mse_value, pixel_std, history)``:
        the trained model; the generated images (CPU tensor); reconstructions of
        the first 5 test images (CPU tensor); the per-image KL divergence averaged
        over the test set; the MSE between those 5 test images and their
        reconstructions; the standard deviation over all pixels of the generated
        samples; and the training history returned by ``train_vae``.
    """
    # Data (downloaded to ./data on first use)
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Model
    model = VAEModel(latent_dim=latent_dim).to(device)

    # Train
    model, history = train_vae(model, train_loader, epochs=epochs, lr=lr)

    # Evaluate on the test set
    n_examples = 5  # number of test images kept as reconstruction examples
    model.eval()
    kl_sum = 0.0
    images_seen = 0
    kept = 0
    originals_list = []
    reconstructions_list = []
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, logvar = model(data)
            # Accumulate the summed KL and the image count separately so the final
            # figure is a true per-image mean even when the last batch is smaller.
            kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            kl_sum += kl.item()
            images_seen += data.size(0)
            # Keep only as many leading test images as are still needed
            if kept < n_examples:
                take = n_examples - kept
                originals_list.append(data[:take].cpu())
                reconstructions_list.append(recon[:take].cpu())
                kept += min(take, data.size(0))
    avg_kl_loss = kl_sum / images_seen if images_seen > 0 else 0.0

    # First 5 test images and their reconstructions
    originals_tensor = torch.cat(originals_list, dim=0)[:n_examples]
    recon_tensor = torch.cat(reconstructions_list, dim=0)[:n_examples]

    # MSE for reconstructions (mean over pixels & images)
    mse_value = torch.mean((originals_tensor - recon_tensor) ** 2).item()

    # Generate samples
    generated_samples = generate_samples(model, num_samples=samples_to_generate)

    # Pixel std across all generated samples
    pixel_std = torch.std(generated_samples).item()

    # Ensure shapes as expected
    assert generated_samples.shape == (samples_to_generate, 1, 28, 28), "Generated samples shape mismatch"

    return model, generated_samples, recon_tensor, avg_kl_loss, mse_value, pixel_std, history

# ------------------------
# Example runner (call this)
# ------------------------
if __name__ == "__main__":
    # Recommended: run in Colab or machine with GPU
    vae_model, generated_samples, reconstructed_samples, avg_kl, mse_val, pix_std, history = run_vae_experiment(
        batch_size=128, epochs=30, latent_dim=20, lr=1e-3, samples_to_generate=10
    )
    print("Final results:")
    print("Generated samples shape:", generated_samples.shape)
    print(f"Avg KL per image: {avg_kl:.4f}")
    print(f"MSE (first 5 reconstructions): {mse_val:.6f}")
    print(f"Pixel std of generated samples: {pix_std:.6f}")

    # Save sample grids
    utils.save_image(generated_samples, "generated_samples.png", nrow=5)
    utils.save_image(reconstructed_samples, "reconstructions.png", nrow=5)

    # Save originals (top row) against their reconstructions (bottom row).
    # run_vae_experiment does not return the originals, so re-read them here: it
    # reconstructs the leading images of the unshuffled MNIST test split, which are
    # the first entries of the same dataset.
    test_images = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
    n_recon = reconstructed_samples.size(0)
    originals = torch.stack([test_images[i][0] for i in range(n_recon)])
    side_by_side = torch.cat([originals, reconstructed_samples], dim=0)
    utils.save_image(side_by_side, "originals_vs_reconstructions.png", nrow=n_recon)


In [None]:
# --- Persist the trained weights (state_dict only) to the working directory ---
model_path = "vae_model.pth"
torch.save(obj=vae_model.state_dict(), f=model_path)
print("Model saved successfully as", model_path)


In [None]:
# --- Load model ---
# latent_dim is hard-coded to 20 here; it has to match the architecture that
# produced the checkpoint or load_state_dict will reject the tensor shapes.
loaded_model = VAEModel(latent_dim=20)
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs and avoids executing arbitrary pickled code from the file.
loaded_model.load_state_dict(torch.load("vae_model.pth", map_location=device, weights_only=True))
loaded_model.to(device)
loaded_model.eval()
print("Model loaded successfully!")


In [None]:
# NOTE(review): this repeats the import and mount call from the first cell of the
# notebook; it mounts Google Drive at the same /content/drive path.
from google.colab import drive
drive.mount('/content/drive')