<a href="https://colab.research.google.com/github/prabhasg5/EPICS/blob/main/fashion_qgan_online.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# QGAN for Fashion Design (FashionGen-compatible)

This notebook implements a Quantum GAN (QGAN) where the Generator is a quantum variational circuit (PennyLane) wrapped into a PyTorch module and the Discriminator is a classical convolutional neural network (PyTorch).

Goal: train on the FashionGen dataset if provided (path above). If FashionGen isn't available locally, the notebook no longer falls back to Fashion-MNIST — the notebook requires the FashionGen dataset to be available locally.

What this notebook includes:
- Setup and dependency checks
- Dataset loader that accepts a FashionGen root path (images) and auto-detects flat or folder layouts
- Quantum Generator implemented with PennyLane's TorchLayer
- Classical Discriminator (PyTorch)
- Training loop with checkpoints and sample visualization

Notes & assumptions:
- FashionGen is large. Point `FASHIONGEN_PATH` (or place files in `./fashion-dataset/images`) to your local FashionGen images arranged as images or subfolders.
- The quantum circuit maps a small latent vector into features; a classical head upsamples to full image size.
- For a research review, run the full training on GPU (if available) and with the full FashionGen dataset. The notebook provides the training scaffolding and checkpointing.

Please run the cells in order from top to bottom.

In [2]:
#!pip install -q kagglehub torchvision


In [3]:
import sys
import subprocess
import importlib

def ensure_package(pkg_name, pip_name=None):
    pip_name = pip_name or pkg_name
    try:
        importlib.import_module(pkg_name)
        print(f"{pkg_name} already installed")
    except Exception:
        print(f"Installing {pip_name}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name])

# Install packages
ensure_package('pennylane')
ensure_package('torch')
ensure_package('torchvision')
ensure_package('matplotlib')
ensure_package('PIL', 'pillow')

# Imports
import os
import random
import math
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from PIL import Image
import pennylane as qml
from pennylane import numpy as pnp
from pennylane.qnn import TorchLayer

# ===== GPU SETUP =====
# Check CUDA availability and GPU info
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('=' * 60)
print(f'Device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'CUDA Version: {torch.version.cuda}')
    print(f'Available GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
    # Enable TF32 for better performance on Ampere GPUs (T4 doesn't have it, but safe to enable)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # cuDNN optimization
    torch.backends.cudnn.benchmark = True
else:
    print('⚠️ WARNING: GPU not available! Using CPU.')
print('=' * 60)

# Reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Utility: show image grid
def show_image_tensor(img_tensor, nrow=8, title=None, denorm=True):
    if denorm:
        img_tensor = (img_tensor + 1.0) / 2.0
    grid = make_grid(img_tensor.cpu(), nrow=nrow, normalize=False)
    plt.figure(figsize=(6,6))
    plt.axis('off')
    if title: plt.title(title)
    plt.imshow(grid.permute(1,2,0).clamp(0,1))
    plt.show()

print('✅ Setup complete')

Installing pennylane...
torch already installed
torchvision already installed
matplotlib already installed
PIL already installed
Device: cpu
✅ Setup complete




In [4]:
from glob import glob
from torch.utils.data import Dataset
from torchvision import datasets
import kagglehub

# Download dataset
path = kagglehub.dataset_download("paramaggarwal/fashion-product-images-dataset")
print("✅ Kaggle dataset downloaded to:", path)

data_root = os.path.join(path, "images")
if not os.path.exists(data_root):
    data_root = path
print("📁 Using data root:", data_root)

# ===== GPU-OPTIMIZED SETTINGS =====
image_size = 28  # Keep at 28 for quantum circuit compatibility; increase to 64/128 for better quality
batch_size = 128  # ⬆️ Increased from 64 - T4 can handle larger batches
num_workers = 2   # ⬆️ Increased from 0 - better CPU->GPU pipeline

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Helper: FlatImageDataset
class FlatImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.paths = sorted(
            glob(os.path.join(root, "*.jpg")) +
            glob(os.path.join(root, "*.jpeg")) +
            glob(os.path.join(root, "*.png"))
        )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, 0

# Detect structure
has_subdirs = any(os.path.isdir(os.path.join(data_root, p)) for p in os.listdir(data_root))
if has_subdirs:
    print(f"📦 Loading dataset using ImageFolder from {data_root}")
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
else:
    print(f"📷 Loading dataset from flat folder: {data_root}")
    dataset = FlatImageDataset(data_root, transform=transform)

if len(dataset) == 0:
    raise RuntimeError(f"No images found in {data_root}")

# ===== GPU-OPTIMIZED DATALOADER =====
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,          # ✅ Speeds up CPU->GPU transfer
    persistent_workers=(num_workers > 0),
    prefetch_factor=2 if num_workers > 0 else None  # ✅ Prefetch batches
)

print(f"✅ Dataset size: {len(dataset)}")

# Detect channels
batch_example = next(iter(dataloader))
example_imgs = batch_example[0]
detected_channels = example_imgs.shape[1]
detected_size = example_imgs.shape[2]
print(f"🔍 Detected: channels={detected_channels}, image_size={detected_size}")

num_channels = detected_channels
image_size = detected_size

Downloading from https://www.kaggle.com/api/v1/datasets/download/paramaggarwal/fashion-product-images-dataset?dataset_version_number=1...


 73%|███████▎  | 16.7G/23.1G [07:51<02:58, 38.1MB/s]


KeyboardInterrupt: 

In [6]:
from glob import glob
from torch.utils.data import Dataset
from torchvision import datasets
import kagglehub

# Download dataset
path = kagglehub.dataset_download("paramaggarwal/fashion-product-images-dataset")
print("✅ Kaggle dataset downloaded to:", path)

data_root = os.path.join(path, "images")
if not os.path.exists(data_root):
    data_root = path
print("📁 Using data root:", data_root)

# ===== GPU-OPTIMIZED SETTINGS =====
image_size = 28  # Keep at 28 for quantum circuit compatibility; increase to 64/128 for better quality
batch_size = 128  # ⬆️ Increased from 64 - T4 can handle larger batches
num_workers = 0   # Set to 0 for Colab stability (multiprocessing issues)

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Helper: FlatImageDataset
class FlatImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.paths = sorted(
            glob(os.path.join(root, "*.jpg")) +
            glob(os.path.join(root, "*.jpeg")) +
            glob(os.path.join(root, "*.png"))
        )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, 0

# Detect structure
has_subdirs = any(os.path.isdir(os.path.join(data_root, p)) for p in os.listdir(data_root))
if has_subdirs:
    print(f"📦 Loading dataset using ImageFolder from {data_root}")
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
else:
    print(f"📷 Loading dataset from flat folder: {data_root}")
    dataset = FlatImageDataset(data_root, transform=transform)

if len(dataset) == 0:
    raise RuntimeError(f"No images found in {data_root}")

# ===== GPU-OPTIMIZED DATALOADER =====
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,          # ✅ Speeds up CPU->GPU transfer
    persistent_workers=False  # Must be False when num_workers=0
)

print(f"✅ Dataset size: {len(dataset)}")

# Detect channels
batch_example = next(iter(dataloader))
example_imgs = batch_example[0]
detected_channels = example_imgs.shape[1]
detected_size = example_imgs.shape[2]
print(f"🔍 Detected: channels={detected_channels}, image_size={detected_size}")

num_channels = detected_channels
image_size = detected_size


Resuming download from 17975738368 bytes (6795477372 bytes left)...
Resuming download from https://www.kaggle.com/api/v1/datasets/download/paramaggarwal/fashion-product-images-dataset?dataset_version_number=1 (17975738368/24771215740) bytes left.


100%|██████████| 23.1G/23.1G [02:48<00:00, 40.3MB/s]

Extracting files...





✅ Kaggle dataset downloaded to: /root/.cache/kagglehub/datasets/paramaggarwal/fashion-product-images-dataset/versions/1
📁 Using data root: /root/.cache/kagglehub/datasets/paramaggarwal/fashion-product-images-dataset/versions/1
📦 Loading dataset using ImageFolder from /root/.cache/kagglehub/datasets/paramaggarwal/fashion-product-images-dataset/versions/1
✅ Dataset size: 88882




🔍 Detected: channels=3, image_size=28


In [7]:
from google.colab import drive
drive.mount('/content/drive')
# After mounting
checkpoint_path = "/content/drive/MyDrive/qgan_training/checkpoint_latest.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)


Mounted at /content/drive


In [11]:
latent_dim = 8
n_qubits = 4
n_layers = 3
image_dim = num_channels * image_size * image_size

# ===== QUANTUM DEVICE =====
# NOTE: PennyLane's quantum circuits run on CPU, but we optimize data flow
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch", diff_method="backprop")
def qcircuit(inputs, weights):
    qml.templates.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.templates.StronglyEntanglingLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

weight_shapes = {"weights": (n_layers, n_qubits, 3)}
qlayer = qml.qnn.TorchLayer(qcircuit, weight_shapes)

class QuantumGenerator(nn.Module):
    def __init__(self, latent_dim=latent_dim, n_qubits=n_qubits,
                 image_dim=image_dim, out_channels=num_channels):
        super().__init__()
        self.pre = nn.Linear(latent_dim, n_qubits)
        self.qlayer = qlayer
        post_hidden = 256
        self.post = nn.Sequential(
            nn.Linear(n_qubits, post_hidden),
            nn.ReLU(True),
            nn.Linear(post_hidden, image_dim),
            nn.Tanh(),
        )
        self.out_channels = out_channels

    def forward(self, z):
        x = self.pre(z)
        x = self.qlayer(x)  # Quantum layer (CPU bottleneck)
        out = self.post(x)
        out = out.view(-1, self.out_channels, image_size, image_size)
        return out

# ===== INITIALIZE ON GPU =====
gen = QuantumGenerator().to(device)
print(gen)
print(f"✅ Generator on: {next(gen.parameters()).device}")

# Smoke test
with torch.no_grad():
    sample_z = torch.randn(4, latent_dim, device=device)
    fake = gen(sample_z)
    print(f'✅ Fake image shape: {fake.shape}')

QuantumGenerator(
  (pre): Linear(in_features=8, out_features=4, bias=True)
  (qlayer): <Quantum Torch Layer: func=qcircuit>
  (post): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=2352, bias=True)
    (3): Tanh()
  )
)
✅ Generator on: cpu
✅ Fake image shape: torch.Size([4, 3, 28, 28])


In [12]:
class Discriminator(nn.Module):
    def __init__(self, img_channels=num_channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(512 * (image_size // 16) * (image_size // 16), 1),
            # ✅ REMOVED Sigmoid - use BCEWithLogitsLoss instead
        )

    def forward(self, x):
        return self.model(x).view(-1)

# ===== INITIALIZE ON GPU =====
disc = Discriminator().to(device)
print(f"✅ Discriminator on: {next(disc.parameters()).device}")

# ===== GPU-OPTIMIZED OPTIMIZERS =====
# Increased learning rate slightly for faster convergence
optimG = torch.optim.Adam(gen.parameters(), lr=3e-4, betas=(0.5, 0.999))
optimD = torch.optim.Adam(disc.parameters(), lr=3e-4, betas=(0.5, 0.999))

# ✅ USE BCEWithLogitsLoss (combines Sigmoid + BCE, safe for mixed precision)
criterion = nn.BCEWithLogitsLoss()

print("✅ Models and optimizers ready")

✅ Discriminator on: cpu
✅ Models and optimizers ready


In [None]:
# # Delete old checkpoint to start training from scratch
# import os
# if os.path.exists(checkpoint_path):
#     os.remove(checkpoint_path)
#     print("✅ Old checkpoint deleted")

In [13]:
checkpoint_path = "/content/checkpoint_latest.pth"
save_every_batches = 500
atomic_tmp = checkpoint_path + ".tmp"

def save_checkpoint(path, epoch, batch_idx, gen, disc, optimG, optimD):
    ckpt = {
        "epoch": epoch,
        "batch_idx": batch_idx,
        "gen_state": gen.state_dict(),
        "disc_state": disc.state_dict(),
        "optimG_state": optimG.state_dict(),
        "optimD_state": optimD.state_dict(),
        "torch_rng": torch.get_rng_state(),
        "cuda_rng_all": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "np_rng": np.random.get_state(),
        "py_random": random.getstate(),
    }
    torch.save(ckpt, atomic_tmp)
    os.replace(atomic_tmp, path)
    print(f"💾 Saved checkpoint -> {path} (epoch={epoch}, batch={batch_idx})")

def load_checkpoint(path, gen, disc, optimG, optimD, device):
    if not os.path.exists(path):
        return 1, 0
    print("📥 Loading checkpoint:", path)

    # Fix for PyTorch 2.6+ weights_only default change
    # Safe to set False since we trust our own checkpoints
    ckpt = torch.load(path, map_location=device, weights_only=False)

    gen.load_state_dict(ckpt["gen_state"])
    disc.load_state_dict(ckpt["disc_state"])
    optimG.load_state_dict(ckpt["optimG_state"])
    optimD.load_state_dict(ckpt["optimD_state"])

    # Fix for RNG state type compatibility
    try:
        torch_rng = ckpt["torch_rng"]
        if isinstance(torch_rng, np.ndarray):
            torch_rng = torch.from_numpy(torch_rng)
        torch.set_rng_state(torch_rng)
    except Exception as e:
        print(f"⚠️ Could not restore torch RNG state: {e}")

    try:
        if torch.cuda.is_available() and ckpt.get("cuda_rng_all") is not None:
            cuda_rng = ckpt["cuda_rng_all"]
            if isinstance(cuda_rng, list):
                cuda_rng = [torch.from_numpy(s) if isinstance(s, np.ndarray) else s for s in cuda_rng]
            torch.cuda.set_rng_state_all(cuda_rng)
    except Exception as e:
        print(f"⚠️ Could not restore CUDA RNG state: {e}")

    try:
        np.random.set_state(ckpt["np_rng"])
    except Exception as e:
        print(f"⚠️ Could not restore numpy RNG state: {e}")

    try:
        random.setstate(ckpt["py_random"])
    except Exception as e:
        print(f"⚠️ Could not restore Python RNG state: {e}")

    start_epoch = int(ckpt.get("epoch", 1))
    start_batch = int(ckpt.get("batch_idx", 0))
    print(f"✅ Resuming from epoch={start_epoch}, batch={start_batch}")
    return start_epoch, start_batch

start_epoch, start_batch = load_checkpoint(checkpoint_path, gen, disc, optimG, optimD, device)



In [None]:
from torch.cuda.amp import autocast, GradScaler

# ===== MIXED PRECISION TRAINING =====
# Use updated API for PyTorch 2.0+
use_amp = torch.cuda.is_available()
try:
    # New API (PyTorch 2.0+)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
except AttributeError:
    # Fallback for older PyTorch versions
    scaler = GradScaler(enabled=use_amp)

# Use new autocast API
try:
    from torch.amp import autocast
    autocast_context = lambda: autocast('cuda', enabled=use_amp)
except ImportError:
    from torch.cuda.amp import autocast
    autocast_context = lambda: autocast(enabled=use_amp)

print(f"🚀 Mixed Precision Training: {'ENABLED' if use_amp else 'DISABLED'}")

# Training hyperparams
num_epochs = 10
print_every = 50

# Track timing
epoch_start_time = time.time()
batch_count = 0

for epoch in range(start_epoch, num_epochs + 1):
    for batch_idx, (real_imgs, _) in enumerate(dataloader):
        # Skip batches if resuming
        if epoch == start_epoch and batch_idx < start_batch:
            continue

        batch_count += 1
        current_batch_size = real_imgs.size(0)

        # ===== MOVE DATA TO GPU =====
        real_imgs = real_imgs.to(device, non_blocking=True)  # ✅ non_blocking for speed

        # Labels
        real_labels = torch.ones(current_batch_size, device=device)
        fake_labels = torch.zeros(current_batch_size, device=device)

        # ========== Train Discriminator ==========
        optimD.zero_grad(set_to_none=True)  # ✅ Faster than zero_grad()

        with autocast_context():  # ✅ Mixed precision
            # Real images
            real_out = disc(real_imgs)
            loss_real = criterion(real_out, real_labels)

            # Fake images
            z = torch.randn(current_batch_size, latent_dim, device=device)
            fake_imgs = gen(z).detach()
            fake_out = disc(fake_imgs)
            loss_fake = criterion(fake_out, fake_labels)

            loss_D = loss_real + loss_fake

        scaler.scale(loss_D).backward()
        scaler.step(optimD)
        scaler.update()

        # ========== Train Generator ==========
        optimG.zero_grad(set_to_none=True)

        with autocast_context():
            z = torch.randn(current_batch_size, latent_dim, device=device)
            fake_imgs = gen(z)
            fake_out = disc(fake_imgs)
            loss_G = criterion(fake_out, real_labels)

        scaler.scale(loss_G).backward()
        scaler.step(optimG)
        scaler.update()

        # ========== Logging ==========
        if batch_count % print_every == 0:
            elapsed = time.time() - epoch_start_time
            batches_per_sec = batch_count / elapsed
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] "
                  f"D_loss: {loss_D.item():.4f} | G_loss: {loss_G.item():.4f} "
                  f"| Speed: {batches_per_sec:.2f} batches/s")

        # ========== Checkpointing ==========
        if batch_count % save_every_batches == 0:
            save_checkpoint(checkpoint_path, epoch, batch_idx, gen, disc, optimG, optimD)

    # End of epoch
    save_checkpoint(checkpoint_path, epoch, len(dataloader), gen, disc, optimG, optimD)
    print(f"✅ Epoch {epoch} complete")

    # Show generated samples
    with torch.no_grad():
        sample_z = torch.randn(16, latent_dim, device=device)
        samples = gen(sample_z)
        show_image_tensor(samples, nrow=4, title=f'Generated Samples - Epoch {epoch}')

print("🎉 Training complete!")


In [None]:
def load_checkpoint_for_inference(path, gen, disc, optG=None, optD=None, map_location=None):
    """Load checkpoint for inference or resuming training"""
    ckpt = torch.load(path, map_location=map_location)
    gen.load_state_dict(ckpt['gen_state'])
    disc.load_state_dict(ckpt['disc_state'])
    if optG and 'optimG_state' in ckpt:
        optG.load_state_dict(ckpt['optimG_state'])
    if optD and 'optimD_state' in ckpt:
        optD.load_state_dict(ckpt['optimD_state'])
    print(f'✅ Loaded checkpoint: {path}')
    return ckpt.get('epoch', None)

# Create fixed noise for consistent visualization
fixed_noise = torch.randn(64, latent_dim, device=device)

# Try to load the latest checkpoint
if os.path.exists(checkpoint_path):
    print("📥 Loading checkpoint for visualization...")
    load_checkpoint_for_inference(checkpoint_path, gen, disc, optimG, optimD, map_location=device)

    # Generate samples
    gen.eval()  # Set to evaluation mode
    with torch.no_grad():
        samples = gen(fixed_noise)
        show_image_tensor(samples, nrow=8, title='Generated Samples from Checkpoint')
    gen.train()  # Back to training mode
else:
    print('⚠️ No checkpoint found at:', checkpoint_path)
    print('💡 Generate some samples with the untrained model:')
    with torch.no_grad():
        samples = gen(fixed_noise[:16])
        show_image_tensor(samples, nrow=4, title='Samples (Untrained Model)')

print('\n' + '='*60)
print('✅ Notebook ready!')
print('='*60)
print('📝 Training tips:')
print('  • Current batch size: 128 (optimal for T4 GPU)')
print('  • Mixed precision: ENABLED for 2-3x speed boost')
print('  • Checkpoints save every 500 batches')
print('  • To train longer: increase num_epochs and rerun training cell')
print('='*60)