# SOTA Projected GAN for Prostate Cancer Biopsy Image Synthesis

This notebook implements a high-fidelity **Projected GAN** for synthesizing prostate biopsy patches conditioned on ISUP grade (0-5).

## Key Innovations
- **Style-Modulated Generator**: Maps the noise vector and a class embedding to a style vector that modulates a learned constant 4×4 tensor through AdaIN layers at every resolution up to 256×256.
- **Projected Discriminator**: Extracts multi-resolution features with a frozen, ImageNet-pretrained EfficientNet-B0 backbone and scores each scale with a lightweight convolutional head.
- **Projection Discrimination**: Adds the inner product between the pooled deepest backbone features and a learned class-label embedding to the score, so the discriminator judges class consistency as well as realism.
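- **Diversity Loss (LZ)**: A mode-seeking term that penalizes the ratio of latent distance to image distance between two samples of the same class. This encourages different noise vectors to produce visibly different images and mitigates mode collapse.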

In [None]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.models import efficientnet_b0

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## Model Architectures

In [None]:
class AdaIN(nn.Module):
    """Adaptive instance normalization driven by a style vector.

    ``x`` is instance-normalized (``nn.InstanceNorm2d`` with its default,
    non-affine settings).  A linear layer then predicts a per-channel scale
    offset ``gamma`` and shift ``beta`` from ``style``.  The output is
    ``norm(x) * (1 + gamma) + beta``, so a zero-initialised ``gamma``
    leaves the normalized activations unscaled.

    Args:
        style_dim: Length of the style vector.
        num_features: Number of channels in ``x``.
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features)
        # One scale and one shift per channel, predicted jointly.
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, style):
        # Lift (B, 2C) -> (B, 2C, 1, 1) so the parameters broadcast spatially.
        modulation = self.fc(style)[:, :, None, None]
        gamma, beta = torch.chunk(modulation, 2, dim=1)
        normalized = self.norm(x)
        return normalized * (gamma + 1) + beta

class SynthesisBlock(nn.Module):
    """One generator stage: optional 2x bilinear upsample, then two modulated convs.

    Each of the two steps is ``conv3x3 (no bias) -> AdaIN(style) ->
    LeakyReLU(0.2)``.  The convolutions carry no bias because AdaIN's
    ``beta`` already supplies a per-channel shift.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        style_dim: Length of the style vector passed to ``forward``.
        upsample: When True, double the spatial size before convolving.
    """

    def __init__(self, in_channels, out_channels, style_dim, upsample=True):
        super().__init__()
        self.upsample = upsample
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
        self.adain1 = AdaIN(style_dim, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.adain2 = AdaIN(style_dim, out_channels)

    def forward(self, x, style):
        if self.upsample:
            h = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        else:
            h = x
        stages = ((self.conv1, self.adain1), (self.conv2, self.adain2))
        for conv, modulate in stages:
            h = F.leaky_relu(modulate(conv(h), style), 0.2)
        return h

class Generator(nn.Module):
    """Class-conditional style-based generator.

    The noise ``z`` and a learned 128-d class embedding are concatenated and
    passed through a two-layer MLP, giving one style vector per sample.  That
    vector modulates every ``SynthesisBlock``.  Synthesis starts from a learned
    constant 4x4 tensor, and each of the six blocks doubles the resolution, so
    the output is 256x256 RGB in ``[-1, 1]`` (final ``Tanh``).

    Args:
        nz: Length of the noise vector.
        style_dim: Length of the style vector given to every AdaIN layer.
        n_classes: Size of the class-embedding table.
        ngf: Base channel width; the constant input has ``ngf * 16`` channels.
    """

    def __init__(self, nz=512, style_dim=512, n_classes=6, ngf=64):
        super().__init__()
        label_dim = 128
        self.label_emb = nn.Embedding(n_classes, label_dim)
        self.mapping = nn.Sequential(
            nn.Linear(nz + label_dim, style_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(style_dim, style_dim),
        )
        self.const = nn.Parameter(torch.randn(1, ngf * 16, 4, 4))
        # Channel multipliers (of ngf) at 4, 8, 16, 32, 64, 128 and 256 px.
        widths = [16, 8, 4, 2, 2, 1, 1]
        self.blocks = nn.ModuleList(
            SynthesisBlock(ngf * c_in, ngf * c_out, style_dim)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        )
        self.to_rgb = nn.Sequential(nn.Conv2d(ngf, 3, 1), nn.Tanh())

    def forward(self, z, labels):
        # Assumes z is (B, nz) and labels is a 1-D LongTensor of length B.
        latent = torch.cat([z, self.label_emb(labels)], dim=1)
        styles = self.mapping(latent)
        h = self.const.repeat(z.size(0), 1, 1, 1)
        for stage in self.blocks:
            h = stage(h, styles)
        return self.to_rgb(h)

class Discriminator(nn.Module):
    """Projected, class-conditional discriminator on a frozen EfficientNet-B0.

    Images are passed through the pretrained ``efficientnet_b0().features``
    stack.  Four intermediate outputs (stages 2, 3, 5 and 7, with 24, 40, 112
    and 320 channels) are each scored by a 3x3 conv head, spatially averaged,
    and the four scores are averaged.  A projection term is added to this: the
    inner product between the pooled, 512-d projection of the deepest tapped
    features and a learned embedding of ``labels``.

    The backbone weights are frozen, and the backbone is also pinned to eval
    mode.  Without that, its BatchNorm layers would normalise with per-batch
    statistics and keep updating running stats during training.  Real and fake
    batches are fed separately, so batch statistics would leak real/fake
    identity.

    NOTE(review): the pretrained weights expect ImageNet-normalised RGB, and
    no normalisation happens here.  Confirm that the real-image transform and
    the generator's [-1, 1] output are mapped consistently.

    Args:
        n_classes: Size of the class-embedding table.

    Returns (from ``forward``):
        A ``(B, 1)`` tensor of unbounded real/fake scores.
    """

    # Indices into ``efficientnet_b0().features`` whose outputs are tapped.
    # Their channel counts (24, 40, 112, 320) match ``self.heads`` in order.
    TAP_INDICES = (2, 3, 5, 7)

    def __init__(self, n_classes=6):
        super().__init__()
        backbone = self._pretrained_features()
        # Keep every stage, even unused ones, so state_dict keys stay the same.
        self.features = nn.ModuleList(list(backbone))
        self.features.requires_grad_(False)
        self.features.eval()
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Conv2d(24, 1, 3, padding=1)),
            nn.Sequential(nn.Conv2d(40, 1, 3, padding=1)),
            nn.Sequential(nn.Conv2d(112, 1, 3, padding=1)),
            nn.Sequential(nn.Conv2d(320, 1, 3, padding=1)),
        ])
        self.proj = nn.Conv2d(320, 512, 1)
        self.cls_emb = nn.Embedding(n_classes, 512)

    @staticmethod
    def _pretrained_features():
        """Return the ``features`` stack of an ImageNet-pretrained EfficientNet-B0."""
        try:
            from torchvision.models import EfficientNet_B0_Weights
        except ImportError:  # torchvision < 0.13 predates the weights enums
            return efficientnet_b0(pretrained=True).features
        # IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` resolved to.
        return efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1).features

    def train(self, mode=True):
        """Toggle train/eval for the trainable parts only; the backbone stays in eval."""
        super().train(mode)
        self.features.eval()
        return self

    def forward(self, x, labels):
        last_tap = max(self.TAP_INDICES)
        feats = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.TAP_INDICES:
                feats.append(x)
            if i == last_tap:
                break  # later stages (e.g. the 1280-channel head) are unused
        score = sum(
            head(f).mean(dim=(2, 3)) for head, f in zip(self.heads, feats)
        ) / len(self.heads)
        pooled = self.proj(feats[-1]).mean(dim=(2, 3))
        proj_score = (pooled * self.cls_emb(labels)).sum(dim=1, keepdim=True)
        return score + proj_score

## Training Configuration

In [None]:
class Config:
    """Training hyperparameters, read as class attributes via ``config``."""

    # Number of passes over the training data.
    epochs: int = 300
    # Samples per optimisation step.
    batch_size: int = 16
    # Base Adam learning rate (the discriminator optimizer scales it by 4).
    lr: float = 1e-4
    # Length of the generator's noise vector.
    nz: int = 512
    # Directory holding the 256px training patches.
    data_dir: str = './panda_data/patches_256'

# Instantiate the models and optimizers from the shared configuration.
config = Config()
# Pass nz explicitly: train_step samples noise of width config.nz, so the
# generator's mapping network must be built for the same width.
G = Generator(nz=config.nz).to(device)
D = Discriminator().to(device)
opt_G = optim.Adam(G.parameters(), lr=config.lr, betas=(0.0, 0.99))
# Two-timescale update: the discriminator learns 4x faster.  Frozen backbone
# parameters never receive gradients, so Adam leaves them untouched.
opt_D = optim.Adam(D.parameters(), lr=config.lr * 4, betas=(0.0, 0.99))

## Training Step with Diversity Loss

In [None]:
def train_step(real_imgs, labels, lambda_lz=0.1, lz_eps=1e-8):
    """Run one discriminator update followed by one generator update.

    Relies on the module-level ``G``, ``D``, ``opt_G``, ``opt_D``, ``config``
    and ``device``.

    The discriminator uses the hinge loss.  The generator minimises the
    non-saturating hinge objective ``-D(G(z))`` plus a mode-seeking diversity
    term: the mean latent L1 distance between two noise draws divided by the
    mean pixel L1 distance of the images they produce for the same labels.

    Args:
        real_imgs: Batch of real images (moved to ``device`` here).
        labels: Class labels for the batch (moved to ``device`` here).
        lambda_lz: Weight of the diversity term in the generator loss.
        lz_eps: Added to the image distance to avoid division by zero.

    Returns:
        ``(loss_D, loss_G)`` as Python floats.
    """
    real_imgs, labels = real_imgs.to(device), labels.to(device)
    bs = real_imgs.size(0)

    # ---- Discriminator update (hinge loss) ----
    opt_D.zero_grad()
    d_real = D(real_imgs, labels)
    z = torch.randn(bs, config.nz, device=device)
    # Fakes are treated as constants here, so skip building G's autograd graph.
    with torch.no_grad():
        fake = G(z, labels)
    d_fake = D(fake, labels)
    loss_D = F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()
    loss_D.backward()
    opt_D.step()

    # ---- Generator update ----
    # D's trainable parameters also receive gradients here; they are discarded
    # by opt_D.zero_grad() at the start of the next step.
    opt_G.zero_grad()
    z1 = torch.randn(bs, config.nz, device=device)
    z2 = torch.randn(bs, config.nz, device=device)
    f1, f2 = G(z1, labels), G(z2, labels)
    loss_G_adv = -D(f1, labels).mean()
    # Mode-seeking diversity loss: small when distinct z give distinct images.
    loss_G_lz = (z1 - z2).abs().mean() / ((f1 - f2).abs().mean() + lz_eps)
    loss_G = loss_G_adv + loss_G_lz * lambda_lz
    loss_G.backward()
    opt_G.step()
    return loss_D.item(), loss_G.item()