# Conditional DCGAN for Prostate Cancer Biopsy Image Synthesis

This notebook implements a Conditional Deep Convolutional GAN (CDCGAN) for synthesizing prostate cancer histopathology images using the PANDA dataset.

## Features
- Conditional generation by ISUP grade (0-5)
- Spectral normalization in the discriminator for training stability
- Data augmentation
- FID score evaluation

In [None]:
# Install dependencies
!pip install torch torchvision tqdm matplotlib numpy pandas pillow scikit-image scipy

In [None]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.nn.utils import spectral_norm
from scipy import linalg
from torchvision.models import inception_v3

def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Integer seed handed to every random number generator.

    The CUDA generators are seeded only when a CUDA device is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every RNG before anything random runs, then pick the compute device:
# CUDA when available, CPU otherwise.
set_seed(42)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## Configuration

In [None]:
class Config:
    """Hyperparameters and output locations shared by the notebook cells."""

    # --- image / network dimensions ---
    image_size = 256   # side length of the square images, in pixels
    nc = 3             # image channels
    nz = 128           # length of the latent noise vector
    ngf = 64           # base channel width of the generator
    ndf = 64           # base channel width of the discriminator
    num_classes = 6    # number of condition labels (ISUP grades 0-5)
    embed_dim = 128    # size of the generator's label embedding

    # --- optimisation ---
    batch_size = 16
    num_epochs = 200
    lr_g = lr_d = 2e-4
    beta1, beta2 = 0.5, 0.999
    label_smoothing = 0.1   # subtracted from the "real" target of 1.0

    # --- output directories ---
    checkpoint_dir = './checkpoints'
    samples_dir = './samples'

config = Config()
# Make sure both output folders exist before any later cell writes into them.
for _out_dir in (config.checkpoint_dir, config.samples_dir):
    os.makedirs(_out_dir, exist_ok=True)

## Data Paths and Download Instructions

In [None]:
# Every dataset location hangs off one root folder.
DATA_DIR = './panda_data'
TRAIN_IMAGES_DIR, TRAIN_CSV, PATCHES_DIR = (
    os.path.join(DATA_DIR, name)
    for name in ('train_images', 'train.csv', 'patches_256')
)

# Create the root and the patch output folder up front; the image folder and
# the CSV are expected to come from the download itself.
for _folder in (DATA_DIR, PATCHES_DIR):
    os.makedirs(_folder, exist_ok=True)

# Download hints, printed one per line.
print(
    'Download PANDA dataset from:',
    'https://www.kaggle.com/competitions/prostate-cancer-grade-assessment/data',
    '\nUsing Kaggle API:',
    '!pip install kaggle',
    '!kaggle competitions download -c prostate-cancer-grade-assessment',
    sep='\n',
)

## Patch Extraction from Whole-Slide Images (WSI)

In [None]:
def extract_patches(image_path, patch_size=256, max_patches=20, tissue_thresh=0.5):
    """Cut an image into non-overlapping square tiles and keep the best ones.

    A tile is kept when at least ``tissue_thresh`` of its pixels have a mean
    RGB value below 220 and its overall pixel variance exceeds 100.

    Args:
        image_path: Path of the image to tile.
        patch_size: Side length of each tile, in pixels.
        max_patches: Maximum number of tiles to return.
        tissue_thresh: Minimum fraction of below-220 pixels a tile must have.

    Returns:
        A list of at most ``max_patches`` RGB PIL images, ordered by decreasing
        tissue fraction. The list is empty when the image cannot be read (a
        warning is issued) or when it is smaller than one tile.
    """
    import warnings  # local import: not part of the notebook's import cell

    try:
        # Decode inside a context manager so the underlying file is closed as
        # soon as the pixel data has been copied into the RGB image.
        with Image.open(image_path) as src:
            img = src.convert('RGB')
    except Exception as exc:
        # Best effort: one unreadable file must not abort a whole preprocessing
        # run, but the failure should at least be visible.
        warnings.warn(f'Skipping unreadable image {image_path}: {exc}')
        return []

    w, h = img.size
    scored = []

    # Each range stops at the last offset where a whole tile still fits, so
    # partial tiles are never produced and no extra bounds check is needed.
    for y in range(0, h - patch_size + 1, patch_size):
        for x in range(0, w - patch_size + 1, patch_size):
            patch = img.crop((x, y, x + patch_size, y + patch_size))
            arr = np.array(patch)
            # Fraction of pixels whose channel mean is darker than 220.
            tissue_ratio = np.mean(np.mean(arr, axis=2) < 220)

            # The variance test additionally rejects near-uniform tiles.
            if tissue_ratio >= tissue_thresh and np.var(arr) > 100:
                scored.append((patch, tissue_ratio))

    # list.sort is stable, so tiles with equal scores keep their scan order.
    scored.sort(key=lambda item: item[1], reverse=True)
    return [patch for patch, _ in scored[:max_patches]]


def preprocess_dataset(csv_path, images_dir, output_dir, max_per_class=500):
    """Tile the slides listed in a CSV and save the tiles into per-grade folders.

    Reads the ``image_id`` and ``isup_grade`` columns of ``csv_path``, looks for
    ``<images_dir>/<image_id>.tiff`` and writes every tile returned by
    ``extract_patches`` to ``<output_dir>/<grade>/<image_id>_p<idx>.png``.

    Args:
        csv_path: CSV file with ``image_id`` and ``isup_grade`` columns.
        images_dir: Folder containing the ``.tiff`` slides.
        output_dir: Root folder; sub-folders ``0`` .. ``5`` are created in it.
        max_per_class: Maximum number of slides processed per grade. A slide
            counts as soon as its file exists, even if it yields no tiles.
    """
    df = pd.read_csv(csv_path)
    print(f'Total images: {len(df)}')
    print(df['isup_grade'].value_counts().sort_index())

    grades = range(6)
    for g in grades:
        os.makedirs(os.path.join(output_dir, str(g)), exist_ok=True)

    counts = dict.fromkeys(grades, 0)
    skipped = 0

    # Walk the two needed columns directly instead of building a Series per
    # row with iterrows().
    rows = zip(df['image_id'], df['isup_grade'])
    for image_id, raw_grade in tqdm(rows, total=len(df)):
        # Normalise to a plain int so folder names are always "0".."5", even
        # when pandas parsed the column as float (e.g. because of blank cells).
        try:
            grade = int(raw_grade)
        except (TypeError, ValueError):
            skipped += 1
            continue
        if grade not in counts:
            skipped += 1
            continue

        if counts[grade] >= max_per_class:
            continue

        img_path = os.path.join(images_dir, f"{image_id}.tiff")
        if not os.path.exists(img_path):
            continue

        for idx, patch in enumerate(extract_patches(img_path)):
            patch.save(os.path.join(output_dir, str(grade), f"{image_id}_p{idx}.png"))
        counts[grade] += 1

    if skipped:
        print(f'Skipped {skipped} rows with a missing or out-of-range grade')
    print('Done!')

## Dataset Class

In [None]:
class PANDADataset(Dataset):
    """Image-patch dataset read from ``<patches_dir>/<grade>/`` folders.

    ``samples`` holds ``(path, grade)`` pairs, where ``grade`` is the integer
    folder name (0-5). Indexing returns ``(image, grade)`` with the image
    loaded as RGB and passed through ``transform`` when one is given.

    Args:
        patches_dir: Directory containing one sub-folder per grade.
        transform: Optional callable applied to the loaded PIL image.
        balance: If True, oversample so that every grade that has at least one
            patch contributes exactly as many samples as the largest grade.
    """

    NUM_GRADES = 6
    EXTENSIONS = ('.png', '.jpg')

    def __init__(self, patches_dir, transform=None, balance=True):
        self.transform = transform

        # Group files by grade while scanning. Sorting the listing makes the
        # sample order (and therefore a seeded shuffle) independent of the
        # order in which the filesystem returns entries.
        by_grade = {}
        for grade in range(self.NUM_GRADES):
            gdir = os.path.join(patches_dir, str(grade))
            if not os.path.isdir(gdir):
                continue
            files = [
                os.path.join(gdir, name)
                for name in sorted(os.listdir(gdir))
                if name.endswith(self.EXTENSIONS)
            ]
            if files:
                by_grade[grade] = files

        if balance and by_grade:
            self.samples = self._oversample(by_grade)
        else:
            self.samples = [
                (path, grade) for grade, files in by_grade.items() for path in files
            ]

        print(f'Loaded {len(self.samples)} samples')

    @staticmethod
    def _oversample(by_grade):
        """Return a shuffled sample list in which every grade appears equally often."""
        target = max(len(files) for files in by_grade.values())
        balanced = []
        for grade, files in by_grade.items():
            pool = list(files)
            # Shuffle first so the files that get one extra copy are random.
            random.shuffle(pool)
            repeats = -(-target // len(pool))  # ceiling division
            # Every file appears at least once; the slice trims the surplus
            # so each grade ends up with exactly `target` entries.
            balanced.extend((path, grade) for path in (pool * repeats)[:target])
        random.shuffle(balanced)
        return balanced

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # Close the file as soon as the pixels have been decoded.
        with Image.open(path) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label


def get_transforms(size, augment=True):
    """Build the torchvision preprocessing pipeline for image patches.

    Args:
        size: Output side length; images are resized to ``(size, size)``.
        augment: If True, insert random flips, a random quarter-turn rotation
            and mild colour jitter between the resize and tensor conversion.

    Returns:
        A ``transforms.Compose`` that ends with ``ToTensor`` followed by
        ``Normalize`` using mean 0.5 and std 0.5 on each of the three channels.
    """
    head = [transforms.Resize((size, size))]
    tail = [transforms.ToTensor(), transforms.Normalize([0.5] * 3, [0.5] * 3)]
    if not augment:
        return transforms.Compose(head + tail)

    # Rotate only by exact multiples of 90 degrees. RandomRotation(90) samples
    # an arbitrary angle in [-90, 90], which fills the exposed corners with
    # black and resamples every pixel; both artefacts end up in the training
    # images.
    quarter_turn = transforms.RandomChoice(
        [transforms.RandomRotation((angle, angle)) for angle in (0, 90, 180, 270)]
    )
    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        quarter_turn,
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
    ]
    return transforms.Compose(head + augmentations + tail)

## Generator Network

In [None]:
class Generator(nn.Module):
    """Conditional generator: (noise vector, class label) -> image.

    The label is embedded, concatenated with the noise vector and reshaped to
    a 1x1 feature map. A stack of seven transposed convolutions then grows it
    from 1x1 to 4x4 and on to 256x256; the final ``Tanh`` bounds the output to
    [-1, 1].

    Args:
        nz: Length of the noise vector.
        ngf: Base channel width; the widest layer has ``ngf * 32`` channels.
        nc: Number of output image channels.
        num_classes: Number of distinct labels.
        embed_dim: Size of the label embedding.
    """

    def __init__(self, nz, ngf, nc, num_classes, embed_dim):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, embed_dim)

        # Channel widths of the hidden feature maps, widest first.
        widths = [ngf * mult for mult in (32, 16, 8, 4, 2, 1)]

        # Stem: 1x1 -> 4x4.
        layers = [
            nn.ConvTranspose2d(nz + embed_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Five stride-2 stages, each doubling the resolution (4 -> 128).
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Output stage: 128 -> 256.
        layers += [
            nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Draw conv weights from N(0, 0.02) and batch-norm scales from N(1, 0.02)."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.normal_(layer.weight, 1, 0.02)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(layer.weight, 0, 0.02)

    def forward(self, z, labels):
        """Generate one image per row of ``z``, conditioned on ``labels``."""
        batch = z.size(0)
        conditioned = torch.cat((z, self.label_emb(labels)), dim=1)
        return self.main(conditioned.view(batch, -1, 1, 1))

## Discriminator Network

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator: (image, class label) -> probability.

    The label is embedded into an ``img_size x img_size`` plane and appended to
    the image as an extra channel. Six stride-2 convolutions followed by a 4x4
    unpadded convolution reduce a 256x256 input to a single value, which a
    ``Sigmoid`` maps to (0, 1). Every convolution is wrapped in spectral
    normalization.

    Args:
        ndf: Base channel width; the widest layer has ``ndf * 32`` channels.
        nc: Number of image channels.
        num_classes: Number of distinct labels.
        img_size: Side length of the (square) input images.
    """

    def __init__(self, ndf, nc, num_classes, img_size):
        super().__init__()
        self.img_size = img_size
        self.label_emb = nn.Embedding(num_classes, img_size * img_size)

        def sn_conv(c_in, c_out, stride, pad):
            # 4x4 bias-free convolution wrapped in spectral normalization.
            return spectral_norm(nn.Conv2d(c_in, c_out, 4, stride, pad, bias=False))

        widths = [ndf * mult for mult in (1, 2, 4, 8, 16, 32)]

        # First stage has no batch norm.
        layers = [sn_conv(nc + 1, widths[0], 2, 1), nn.LeakyReLU(0.2, True)]
        # Five more stride-2 stages, each halving the resolution.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                sn_conv(c_in, c_out, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, True),
            ]
        # Collapse the remaining 4x4 map to one score.
        layers += [sn_conv(widths[-1], 1, 1, 0), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Initialise batch-norm layers and any conv not under spectral norm."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.normal_(layer.weight, 1, 0.02)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Conv2d) and not hasattr(layer, 'weight_orig'):
                # Convs wrapped by spectral_norm carry `weight_orig` and are
                # left at their default initialisation.
                nn.init.normal_(layer.weight, 0, 0.02)

    def forward(self, imgs, labels):
        """Score each image given its label; returns a ``(batch, 1)`` tensor."""
        side = self.img_size
        label_plane = self.label_emb(labels).view(imgs.size(0), 1, side, side)
        scores = self.main(torch.cat((imgs, label_plane), dim=1))
        return scores.view(-1, 1)

## Initialize Models

In [None]:
# Build both networks from the shared configuration and move them to `device`.
G = Generator(config.nz, config.ngf, config.nc, config.num_classes, config.embed_dim).to(device)
D = Discriminator(config.ndf, config.nc, config.num_classes, config.image_size).to(device)

# Report the total number of parameter elements in each network.
print(f'Generator params: {sum(p.numel() for p in G.parameters()):,}')
print(f'Discriminator params: {sum(p.numel() for p in D.parameters()):,}')

# Shape smoke test: push two random latent vectors with random labels through
# G, then feed the result to D and print both output shapes. No
# torch.no_grad() or eval() call surrounds this, so it runs like a training
# forward pass.
# NOTE(review): the label range is hard-coded to 6 rather than read from
# config.num_classes — keep the two in sync.
z = torch.randn(2, config.nz, device=device)
y = torch.randint(0, 6, (2,), device=device)
fake = G(z, y)
print(f'G output: {fake.shape}')
print(f'D output: {D(fake, y).shape}')

## Training Setup

In [None]:
# One shared binary cross-entropy loss, plus an Adam optimiser and a
# step-decay schedule for each network.
criterion = nn.BCELoss()
_betas = (config.beta1, config.beta2)
opt_G = optim.Adam(G.parameters(), lr=config.lr_g, betas=_betas)
opt_D = optim.Adam(D.parameters(), lr=config.lr_d, betas=_betas)
# Halve each learning rate after every 50 scheduler steps.
sched_G = optim.lr_scheduler.StepLR(opt_G, step_size=50, gamma=0.5)
sched_D = optim.lr_scheduler.StepLR(opt_D, step_size=50, gamma=0.5)

## Training Functions

In [None]:
def train_step(real_imgs, labels):
    """Run one discriminator update followed by one generator update.

    Uses the notebook-level ``G``, ``D``, ``criterion``, ``opt_G``, ``opt_D``,
    ``config`` and ``device``.

    Args:
        real_imgs: Batch of real images.
        labels: Class labels for the batch; they also condition the fakes.

    Returns:
        A tuple of four floats: discriminator loss, generator loss, mean
        discriminator output on the real batch, and mean discriminator output
        on the fake batch used in the discriminator update.
    """
    bs = real_imgs.size(0)
    real_imgs = real_imgs.to(device)
    labels = labels.to(device)

    # One-sided label smoothing: only the discriminator's target for real
    # images is softened. The generator is trained against the hard target 1.0
    # so that it is not pulled towards making D output the smoothed value.
    smooth_real = torch.full((bs, 1), 1 - config.label_smoothing, device=device)
    hard_real = torch.ones(bs, 1, device=device)
    fake_lbl = torch.zeros(bs, 1, device=device)

    # Train D
    D.zero_grad()
    out_real = D(real_imgs, labels)
    loss_real = criterion(out_real, smooth_real)

    z = torch.randn(bs, config.nz, device=device)
    fake = G(z, labels)
    # detach() keeps the discriminator loss from back-propagating into G.
    out_fake = D(fake.detach(), labels)
    loss_fake = criterion(out_fake, fake_lbl)

    loss_D = loss_real + loss_fake
    loss_D.backward()
    opt_D.step()

    # Train G on a fresh noise batch, scored by the just-updated D.
    G.zero_grad()
    z = torch.randn(bs, config.nz, device=device)
    fake = G(z, labels)
    out = D(fake, labels)
    loss_G = criterion(out, hard_real)
    loss_G.backward()
    opt_G.step()

    return loss_D.item(), loss_G.item(), out_real.mean().item(), out_fake.mean().item()


def generate_samples(n_per_class=4):
    """Generate ``n_per_class`` images for every class label.

    The notebook-level generator ``G`` is put in eval mode for the duration of
    the call and returned to whatever mode it was in before, even if
    generation fails.

    Args:
        n_per_class: Number of images to generate per class.

    Returns:
        One tensor holding all generated images, grouped by class in
        increasing label order.
    """
    was_training = G.training
    G.eval()
    try:
        batches = []
        with torch.no_grad():
            for g in range(config.num_classes):
                z = torch.randn(n_per_class, config.nz, device=device)
                y = torch.full((n_per_class,), g, dtype=torch.long, device=device)
                batches.append(G(z, y))
    finally:
        # Restore the caller's mode instead of forcing train mode.
        G.train(was_training)
    return torch.cat(batches)


def save_grid(samples, epoch):
    """Save a batch of generated images as one grid picture.

    Args:
        samples: Image tensor with values in [-1, 1].
        epoch: Epoch number, used in the title and the file name.

    The grid has four images per row and is written to
    ``<config.samples_dir>/epoch_<epoch, zero-padded to 4 digits>.png``.
    """
    # Map [-1, 1] to [0, 1] for display.
    rescaled = ((samples + 1) / 2).clamp(0, 1)
    grid = vutils.make_grid(rescaled, nrow=4, padding=2)
    picture = grid.permute(1, 2, 0).cpu().numpy()
    out_path = f'{config.samples_dir}/epoch_{epoch:04d}.png'

    plt.figure(figsize=(12, 18))
    plt.imshow(picture)
    plt.axis('off')
    plt.title(f'Epoch {epoch}')
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close()

## Main Training Loop

In [None]:
def train(dataloader, epochs):
    """Train the notebook-level ``G`` and ``D`` for ``epochs`` epochs.

    Args:
        dataloader: Iterable of ``(images, labels)`` batches.
        epochs: Number of passes over ``dataloader``.

    Returns:
        A dict with keys ``d_loss``, ``g_loss``, ``d_real`` and ``d_fake``,
        each a list with one per-epoch mean.

    Side effects: steps both schedulers once per epoch, prints losses on
    epoch 1 and every 5th epoch, saves a sample grid on epoch 1 and every 10th
    epoch, writes a checkpoint every 25th epoch, and saves the final generator
    weights to ``G_final.pt``.
    """
    # Same order as the tuple returned by train_step.
    metric_names = ('d_loss', 'g_loss', 'd_real', 'd_fake')
    history = {name: [] for name in metric_names}

    for epoch in range(1, epochs + 1):
        per_batch = {name: [] for name in metric_names}
        for imgs, labels in dataloader:
            for name, value in zip(metric_names, train_step(imgs, labels)):
                per_batch[name].append(value)

        for name in metric_names:
            history[name].append(np.mean(per_batch[name]))

        sched_G.step()
        sched_D.step()

        first_epoch = epoch == 1
        if first_epoch or epoch % 5 == 0:
            print(f'[{epoch}/{epochs}] D:{history["d_loss"][-1]:.4f} G:{history["g_loss"][-1]:.4f}')

        if first_epoch or epoch % 10 == 0:
            save_grid(generate_samples(), epoch)

        if epoch % 25 == 0:
            snapshot = {'G': G.state_dict(), 'D': D.state_dict(), 'epoch': epoch}
            torch.save(snapshot, f'{config.checkpoint_dir}/ckpt_{epoch}.pt')

    torch.save(G.state_dict(), f'{config.checkpoint_dir}/G_final.pt')
    return history

## Visualization

In [None]:
def plot_history(history):
    """Plot per-epoch losses and discriminator outputs side by side.

    Args:
        history: Dict with ``d_loss``, ``g_loss``, ``d_real`` and ``d_fake``
            sequences, as returned by ``train``.

    The figure is saved to ``<config.samples_dir>/history.png`` and shown.
    """
    fig, (loss_ax, output_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: losses.
    for key, tag in (('d_loss', 'D'), ('g_loss', 'G')):
        loss_ax.plot(history[key], label=tag)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.set_title('Losses')

    # Right panel: mean discriminator outputs, with a dashed line at 0.5.
    for key, tag in (('d_real', 'D(x)'), ('d_fake', 'D(G(z))')):
        output_ax.plot(history[key], label=tag)
    output_ax.axhline(0.5, c='r', ls='--', alpha=0.5)
    output_ax.set_xlabel('Epoch')
    output_ax.legend()
    output_ax.set_title('Discriminator Output')

    plt.tight_layout()
    plt.savefig(f'{config.samples_dir}/history.png', dpi=150)
    plt.show()


def visualize_by_grade(n=4):
    """Show and save a figure with ``n`` generated images per grade.

    Each row of the figure is one grade and is labelled on its first image.
    The notebook-level generator ``G`` is put in eval mode while sampling and
    then returned to the mode it was in before the call.

    Args:
        n: Number of images per grade (columns of the figure).

    The figure is saved to ``<config.samples_dir>/by_grade.png`` and shown.
    """
    grades = ['0-Benign', '1-G3+3', '2-G3+4', '3-G4+3', '4-G4+4', '5-High']
    # squeeze=False keeps `axes` two-dimensional even when n == 1.
    fig, axes = plt.subplots(len(grades), n, figsize=(n*3, 18), squeeze=False)

    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            for g, grade_name in enumerate(grades):
                z = torch.randn(n, config.nz, device=device)
                y = torch.full((n,), g, dtype=torch.long, device=device)
                imgs = ((G(z, y) + 1) / 2).clamp(0, 1)
                for i in range(n):
                    ax = axes[g, i]
                    ax.imshow(imgs[i].permute(1, 2, 0).cpu().numpy())
                    # Hide ticks and frame but keep the axis itself:
                    # axis('off') would also hide the row label set below.
                    ax.set_xticks([])
                    ax.set_yticks([])
                    for spine in ax.spines.values():
                        spine.set_visible(False)
                axes[g, 0].set_ylabel(grade_name)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        G.train(was_training)

    plt.suptitle('Generated by ISUP Grade')
    plt.tight_layout()
    plt.savefig(f'{config.samples_dir}/by_grade.png', dpi=150)
    plt.show()

## FID Score

In [None]:
class FIDCalculator:
    """Fréchet distance between Inception features of real and generated images.

    Features come from torchvision's pretrained ``inception_v3`` with its final
    fully connected layer replaced by an identity.
    NOTE(review): scores from this extractor are presumably not comparable with
    FID values published using other Inception weights — confirm before
    reporting numbers.
    """

    def __init__(self):
        self.inception = inception_v3(pretrained=True, transform_input=False)
        self.inception.fc = nn.Identity()
        self.inception = self.inception.to(device).eval()
        self.resize = transforms.Resize((299, 299))
        self.norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def get_features(self, imgs):
        """Return Inception features for a batch of images in [-1, 1]."""
        imgs = (imgs + 1) / 2  # [-1, 1] -> [0, 1]
        imgs = self.norm(self.resize(imgs))
        with torch.no_grad():
            return self.inception(imgs)

    @staticmethod
    def _frechet_distance(mu1, s1, mu2, s2, eps=1e-6):
        """Fréchet distance between the Gaussians N(mu1, s1) and N(mu2, s2)."""
        diff = mu1 - mu2
        covmean = linalg.sqrtm(s1.dot(s2))
        if not np.isfinite(covmean).all():
            # The product of near-singular covariances can make sqrtm return
            # inf/nan; a small diagonal offset makes the problem well-posed.
            offset = np.eye(s1.shape[0]) * eps
            covmean = linalg.sqrtm((s1 + offset).dot(s2 + offset))
        if np.iscomplexobj(covmean):
            # Imaginary parts here are numerical noise; keep the real part.
            covmean = covmean.real
        return float(diff.dot(diff) + np.trace(s1) + np.trace(s2) - 2 * np.trace(covmean))

    def compute(self, real_loader, generator, n=1000):
        """Compute the score from up to ``n`` real and ``n`` generated images.

        Args:
            real_loader: Iterable of ``(images, labels)`` batches of real images.
            generator: Conditional generator called as ``generator(z, labels)``.
            n: Maximum number of images used from each side.

        Returns:
            The Fréchet distance as a float.

        Raises:
            ValueError: If ``n`` is below 2 or ``real_loader`` yields fewer
                than two images, since a covariance cannot be estimated.

        The generator is put in eval mode for the computation and returned to
        its previous mode afterwards.
        """
        if n < 2:
            raise ValueError('FID needs at least two samples per side')

        was_training = generator.training
        generator.eval()
        try:
            real_feats = []
            cnt = 0
            for imgs, _ in real_loader:
                if cnt >= n:
                    break
                real_feats.append(self.get_features(imgs.to(device)))
                cnt += imgs.size(0)
            if cnt < 2:
                raise ValueError('FID needs at least two real images')
            real_feats = torch.cat(real_feats)[:n].cpu().numpy().astype(np.float64)

            fake_feats = []
            produced = 0
            with torch.no_grad():
                while produced < n:
                    z = torch.randn(32, config.nz, device=device)
                    y = torch.randint(0, config.num_classes, (32,), device=device)
                    fake_feats.append(self.get_features(generator(z, y)))
                    produced += 32
            fake_feats = torch.cat(fake_feats)[:n].cpu().numpy().astype(np.float64)
        finally:
            # Restore the caller's mode instead of forcing train mode.
            generator.train(was_training)

        mu1, s1 = np.mean(real_feats, 0), np.cov(real_feats, rowvar=False)
        mu2, s2 = np.mean(fake_feats, 0), np.cov(fake_feats, rowvar=False)
        return self._frechet_distance(mu1, s1, mu2, s2)

## Generate Synthetic Dataset

In [None]:
def generate_dataset(output_dir, n_per_class=100):
    """Generate images with the notebook-level ``G`` and save them as PNG files.

    Images for class ``g`` are written to
    ``<output_dir>/grade_<g>/syn_<g>_<index>.png``. The generator is put in
    eval mode while sampling and returned to its previous mode afterwards.

    Args:
        output_dir: Root folder for the generated images; created if missing.
        n_per_class: Number of images to generate for each class.
    """
    os.makedirs(output_dir, exist_ok=True)

    was_training = G.training
    G.eval()
    try:
        for g in range(config.num_classes):
            gdir = os.path.join(output_dir, f'grade_{g}')
            os.makedirs(gdir, exist_ok=True)

            print(f'Grade {g}...')
            with torch.no_grad():
                for i in tqdm(range(n_per_class)):
                    z = torch.randn(1, config.nz, device=device)
                    y = torch.tensor([g], device=device)
                    # Index the batch dimension explicitly rather than
                    # squeezing every size-1 dimension.
                    img = ((G(z, y)[0] + 1) / 2).clamp(0, 1)
                    # Round to the nearest 8-bit level; a bare astype would
                    # truncate and darken the image slightly.
                    img = np.rint(img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    Image.fromarray(img).save(f'{gdir}/syn_{g}_{i:04d}.png')
    finally:
        # Restore the caller's mode instead of forcing train mode.
        G.train(was_training)

    print(f'Generated {n_per_class * config.num_classes} images')

## Demo with Procedurally Generated Placeholder Data

In [None]:
class DemoDataset(Dataset):
    """Procedurally generated stand-in images for smoke-testing the training code.

    Item ``i`` has label ``i % 6``. Its image is a label-specific base colour
    plus Gaussian noise, darkened inside ``20 + 10 * label`` randomly placed
    discs, and is returned as a float tensor scaled to [-1, 1].

    Args:
        n: Number of items the dataset reports.
        size: Side length of the square images.
    """

    def __init__(self, n=600, size=256):
        self.n, self.size = n, size
        # One base RGB colour per label, in [0, 1].
        self.colors = [
            [.9, .8, .9],
            [.85, .7, .85],
            [.8, .6, .8],
            [.75, .5, .75],
            [.7, .4, .7],
            [.65, .3, .65],
        ]

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        side = self.size
        g = i % 6

        noise = np.random.randn(side, side, 3) * 0.1
        img = np.clip(np.array(self.colors[g]) + noise, 0, 1)

        # Pixel coordinates, reused for every disc mask.
        rows = np.arange(side)[:, None]
        cols = np.arange(side)[None, :]
        for _ in range(20 + g * 10):
            cx, cy = np.random.randint(0, side, 2)
            r = np.random.randint(3, 10)
            inside = (cols - cx) ** 2 + (rows - cy) ** 2 <= r ** 2
            img[inside] = np.clip(img[inside] - 0.2, 0, 1)

        # HWC in [0, 1] -> CHW float tensor in [-1, 1].
        chw = torch.from_numpy(img).float().permute(2, 0, 1)
        return chw * 2 - 1, g

In [None]:
# Run demo: a short smoke test of the training step on DemoDataset images.
# Set RUN_DEMO to False to skip it. Note that it calls train_step, so it
# updates the notebook-level G and D and their optimisers.
RUN_DEMO = True

if RUN_DEMO:
    print('Running demo with synthetic data...')
    # drop_last=True discards a final batch smaller than batch_size.
    demo_loader = DataLoader(DemoDataset(600, config.image_size), 
                             batch_size=config.batch_size, shuffle=True, drop_last=True)
    
    # Five epochs, printing the mean discriminator and generator loss of each.
    for epoch in range(1, 6):
        losses_d, losses_g = [], []
        for imgs, labels in demo_loader:
            ld, lg, _, _ = train_step(imgs, labels)
            losses_d.append(ld)
            losses_g.append(lg)
        print(f'Epoch {epoch}: D={np.mean(losses_d):.4f} G={np.mean(losses_g):.4f}')
    
    # Show four generated images per grade from the briefly trained generator.
    print('\nGenerating samples...')
    visualize_by_grade(4)
    print('Demo complete!')

## Train with Real Data

Uncomment and run after downloading the PANDA dataset.

In [None]:
# Preprocess (run once)
# if os.path.exists(TRAIN_CSV):
#     preprocess_dataset(TRAIN_CSV, TRAIN_IMAGES_DIR, PATCHES_DIR)

# Create dataloader
# transform = get_transforms(config.image_size)
# dataset = PANDADataset(PATCHES_DIR, transform)
# dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, drop_last=True)

# Train
# history = train(dataloader, config.num_epochs)
# plot_history(history)

# Evaluate
# fid_calc = FIDCalculator()
# fid = fid_calc.compute(dataloader, G)
# print(f'FID: {fid:.2f}')

# Generate synthetic data
# generate_dataset('./synthetic_data', n_per_class=100)

## Summary

This CDCGAN implementation includes:
1. **Data Pipeline**: WSI patch extraction, augmentation, class balancing
2. **Generator**: Label embedding + transposed convolutions (256x256)
3. **Discriminator**: Spectral normalization for stability
4. **Training**: BCE loss with label smoothing
5. **Evaluation**: FID score

### Next Steps
1. Download PANDA dataset from Kaggle
2. Run preprocessing
3. Train the model
4. Generate synthetic data