<a href="https://colab.research.google.com/github/saisubash1013/MSc-Project/blob/main/H%26E_to_IHC_Project_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pix2Pix

In [None]:
# Pix2Pix — imports & basic setup
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Run on CUDA when torch can see a GPU, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Google Drive location of the dataset; the split folders
# (trainA, trainB, valA, valB) are joined onto this root in the next cell.
DATA_PATH = "/content/drive/MyDrive/HER2/TrainValAB"


In [None]:
# Cell 2 — mount Drive and point to dataset
from google.colab import drive

# Make Google Drive visible under /content/drive so DATA_PATH resolves.
drive.mount('/content/drive')

# The four split folders sit directly under DATA_PATH;
# checkpoints are written to a separate Drive folder.
TRAIN_A_DIR, TRAIN_B_DIR, VALA_DIR, VALB_DIR = (
    os.path.join(DATA_PATH, split) for split in ("trainA", "trainB", "valA", "valB")
)
CKPT_DIR = "/content/drive/MyDrive/checkpointsPix2Pix"

# Stop immediately if the dataset root is absent, then report each split folder.
assert os.path.exists(DATA_PATH), f"Dataset root not found: {DATA_PATH}"
for d in (TRAIN_A_DIR, TRAIN_B_DIR, VALA_DIR, VALB_DIR):
    # Indexing the tuple with the boolean picks "OK" for an existing folder.
    print(f"{os.path.basename(d):7s}:", ("MISSING", "OK")[os.path.exists(d)])

# count images (jpg/png/jpeg)
def count_imgs(p):
    """Return how many .jpg/.jpeg/.png entries (case-insensitive) are in folder *p*.

    A path that does not exist counts as zero images.
    """
    if os.path.exists(p):
        image_suffixes = (".jpg", ".jpeg", ".png")
        matches = [name for name in os.listdir(p) if name.lower().endswith(image_suffixes)]
        return len(matches)
    return 0

# Ensure the checkpoint folder exists (no-op when it is already there).
os.makedirs(CKPT_DIR, exist_ok=True)

# Per-split image counts for both domains.
print(f"trainA: {count_imgs(TRAIN_A_DIR)}  |  trainB: {count_imgs(TRAIN_B_DIR)}")
print(f"valA  : {count_imgs(VALA_DIR)}  |  valB  : {count_imgs(VALB_DIR)}")

In [None]:
# Cell 3 — U-Net generator (Pix2Pix)
class Generator(nn.Module):
    """U-Net generator with five stride-2 encoder stages and five decoder stages.

    Each decoder output (except the last) is concatenated with the encoder
    feature map of matching resolution before entering the next decoder stage.
    The final layer maps back to 3 channels and is squashed with tanh.
    """

    # (in_channels, out_channels, use_batchnorm) for enc1..enc5
    _ENCODER_SPEC = (
        (3, 64, False),
        (64, 128, True),
        (128, 256, True),
        (256, 512, True),
        (512, 512, True),
    )
    # (in_channels, out_channels) for dec1..dec4; inputs from dec2 on are
    # doubled by the skip concatenation.
    _DECODER_SPEC = ((512, 512), (1024, 256), (512, 128), (256, 64))

    def __init__(self):
        super().__init__()
        # Attribute names enc1..enc5 / dec1..dec5 are kept so saved state dicts still load.
        for level, (c_in, c_out, use_norm) in enumerate(self._ENCODER_SPEC, start=1):
            setattr(self, f"enc{level}", self._down(c_in, c_out, norm=use_norm))
        for level, (c_in, c_out) in enumerate(self._DECODER_SPEC, start=1):
            setattr(self, f"dec{level}", self._up(c_in, c_out))
        # Last upsampling step: 64 decoder + 64 skip channels -> 3 output channels.
        self.dec5 = nn.ConvTranspose2d(128, 3, 4, 2, 1)

    @staticmethod
    def _down(i, o, norm=True):
        """Stride-2 4x4 conv, optional BatchNorm, then LeakyReLU(0.2)."""
        if norm:
            return nn.Sequential(
                nn.Conv2d(i, o, 4, 2, 1),
                nn.BatchNorm2d(o),
                nn.LeakyReLU(0.2, inplace=True),
            )
        return nn.Sequential(
            nn.Conv2d(i, o, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
        )

    @staticmethod
    def _up(i, o):
        """Stride-2 4x4 transposed conv, BatchNorm, then ReLU."""
        stages = [
            nn.ConvTranspose2d(i, o, 4, 2, 1),
            nn.BatchNorm2d(o),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stages)

    def forward(self, x):
        """Translate a batch of images; the output has x's spatial size and lies in [-1, 1]."""
        # Encoder pass, remembering every stage output for the skip connections.
        skips = []
        h = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            h = encoder(h)
            skips.append(h)
        skips.pop()  # the deepest map is the decoder input, not a skip

        # Decoder pass: upsample, then append the matching encoder map on the channel axis.
        for decoder in (self.dec1, self.dec2, self.dec3, self.dec4):
            h = torch.cat([decoder(h), skips.pop()], dim=1)

        return torch.tanh(self.dec5(h))

# Build the Pix2Pix generator, then move it to the selected device.
generator = Generator()
generator = generator.to(device)
print("Generator ready.")


In [None]:
# Cell 4 — PatchGAN discriminator (Pix2Pix)
class Discriminator(nn.Module):
    """Conditional PatchGAN: scores (input, target) pairs patch by patch.

    The two 3-channel images are stacked to 6 channels and passed through
    three stride-2 and two stride-1 4x4 convolutions; a sigmoid turns the
    single-channel result into per-patch probabilities.
    """

    def __init__(self):
        super().__init__()
        stages = [nn.Conv2d(6, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True)]
        # (in_channels, out_channels, stride) for the normalised middle stages.
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            stages.extend([
                nn.Conv2d(c_in, c_out, 4, stride, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        stages.extend([nn.Conv2d(512, 1, 4, 1, 1), nn.Sigmoid()])
        self.net = nn.Sequential(*stages)

    def forward(self, x_in, x_tgt):
        """Return the patch-wise real/fake probability map for the pair."""
        pair = torch.cat((x_in, x_tgt), dim=1)  # condition on the input image
        return self.net(pair)

# Build the Pix2Pix discriminator, then move it to the selected device.
discriminator = Discriminator()
discriminator = discriminator.to(device)
print("Discriminator ready.")


In [None]:
# Cell 5 — training setup (losses, optimizers, params)
# Adversarial term: binary cross-entropy on the discriminator's sigmoid patch map.
criterion_GAN = nn.BCELoss()
# Reconstruction term: mean absolute pixel error.
criterion_L1 = nn.L1Loss()

# One Adam optimizer per network, both with lr=2e-4 and betas=(0.5, 0.999).
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# Epoch budget, L1 weight in the generator loss, and first epoch index
# (start_epoch is overwritten when resuming from a checkpoint).
num_epochs, lambda_L1, start_epoch = 10, 100, 0

print("Training setup ready → epochs:", num_epochs, "| λ_L1:", lambda_L1)


In [None]:
# Cell 6 — resume training from a saved checkpoint
def load_checkpoint(checkpoint_path):
    """Restore both networks and both optimizers from a saved training checkpoint.

    The checkpoint is read onto the module-level ``device`` and applied, in
    order, to ``generator``, ``discriminator``, ``optimizer_G`` and
    ``optimizer_D``.

    Returns:
        The value stored under 'epoch' in the checkpoint, or 0 if that key is absent.
    """
    state = torch.load(checkpoint_path, map_location=device)
    restore_plan = (
        (generator, 'generator_state_dict'),
        (discriminator, 'discriminator_state_dict'),
        (optimizer_G, 'optimizer_G_state_dict'),
        (optimizer_D, 'optimizer_D_state_dict'),
    )
    for target, key in restore_plan:
        target.load_state_dict(state[key])
    resumed_epoch = state.get('epoch', 0)
    print(f"Resumed from epoch {resumed_epoch}")
    return resumed_epoch

# To resume, uncomment and point to a file:
# start_epoch = load_checkpoint('/content/drive/MyDrive/checkpointsPix2Pix/checkpoint_epoch_5.pth')


In [None]:
# Cell 6.5 — dataset + dataloader (creates `train_loader`)
class PairedFolderDataset(Dataset):
    """Paired image dataset read from two sibling folders under *root*.

    Images are matched by basename (file name without extension) whenever the
    two folders share basenames, so a file missing on one side cannot shift
    every later pair. If the folders share no basenames at all, pairing falls
    back to sorted index order and is truncated to the shorter folder.

    Args:
        root: Directory containing both sub-folders.
        sub_a: Sub-folder with domain-A images.
        sub_b: Sub-folder with domain-B images.
        size: Both images are resized to (size, size).

    Each item is a tuple ``(tensor_a, tensor_b)`` of 3-channel tensors
    normalized to [-1, 1].
    """

    def __init__(self, root, sub_a='trainA', sub_b='trainB', size=256):
        self.dir_a = os.path.join(root, sub_a)
        self.dir_b = os.path.join(root, sub_b)

        exts = ('.jpg', '.png', '.jpeg')
        self.files_a = sorted(f for f in os.listdir(self.dir_a) if f.lower().endswith(exts))
        self.files_b = sorted(f for f in os.listdir(self.dir_b) if f.lower().endswith(exts))

        # Prefer pairing by basename: index pairing over two independently
        # filtered listings silently misaligns if one side lacks a file.
        by_stem_a = {os.path.splitext(f)[0]: f for f in self.files_a}
        by_stem_b = {os.path.splitext(f)[0]: f for f in self.files_b}
        shared = sorted(by_stem_a.keys() & by_stem_b.keys())
        if shared:
            unmatched = len(self.files_a) + len(self.files_b) - 2 * len(shared)
            if unmatched:
                print(f"PairedFolderDataset: ignoring {unmatched} file(s) without a same-named partner")
            self.files_a = [by_stem_a[stem] for stem in shared]
            self.files_b = [by_stem_b[stem] for stem in shared]
        elif len(self.files_a) != len(self.files_b):
            # No common basenames: keep index pairing, but flag the size mismatch.
            print(
                "PairedFolderDataset: no shared basenames and folder sizes differ "
                f"({len(self.files_a)} vs {len(self.files_b)}); pairing by sorted index"
            )

        self.length = min(len(self.files_a), len(self.files_b))

        # Resize to size×size, convert to tensor, normalize to [-1, 1] (matches generator tanh).
        self.tf = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Context managers close the underlying files as soon as the RGB copies exist.
        with Image.open(os.path.join(self.dir_a, self.files_a[idx])) as raw_a:
            img_a = raw_a.convert('RGB')
        with Image.open(os.path.join(self.dir_b, self.files_b[idx])) as raw_b:
            img_b = raw_b.convert('RGB')
        return self.tf(img_a), self.tf(img_b)

# build loader
# Training split: trainA/trainB pairs at 256×256, served in shuffled mini-batches of 4.
train_dataset = PairedFolderDataset(DATA_PATH, sub_a='trainA', sub_b='trainB', size=256)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

print(f"Train pairs: {len(train_dataset)}  |  Batch size: {train_loader.batch_size}")


In [None]:
# Cell 7 — training loop + per-epoch checkpoints
generator.train()
discriminator.train()

# Per-epoch mean losses; the running history is embedded in every checkpoint.
# NOTE(review): the history starts empty even when start_epoch > 0 (resumed run).
training_losses = {'g': [], 'd': []}

for epoch in range(start_epoch, num_epochs):
    g_sum = 0.0
    d_sum = 0.0

    for i, (he_images, ihc_images) in enumerate(train_loader):
        he_images  = he_images.to(device)
        ihc_images = ihc_images.to(device)

        # --- Discriminator ---
        optimizer_D.zero_grad()
        out_real   = discriminator(he_images, ihc_images)

        # Targets take their shape from the discriminator's own patch map, so the
        # loop keeps working if the image size changes (a fixed 30×30 map only
        # matches 256×256 inputs).
        real_labels = torch.ones_like(out_real)
        fake_labels = torch.zeros_like(out_real)
        d_real     = criterion_GAN(out_real, real_labels)

        fake_ihc   = generator(he_images)
        # detach(): the discriminator step must not back-propagate into the generator.
        out_fake   = discriminator(he_images, fake_ihc.detach())
        d_fake     = criterion_GAN(out_fake, fake_labels)

        d_loss = 0.5 * (d_real + d_fake)
        d_loss.backward()
        optimizer_D.step()

        # --- Generator ---
        optimizer_G.zero_grad()
        # Re-score the same fakes (not detached) with the just-updated discriminator.
        out_fake = discriminator(he_images, fake_ihc)
        g_gan    = criterion_GAN(out_fake, real_labels)
        g_l1     = criterion_L1(fake_ihc, ihc_images)
        g_loss   = g_gan + lambda_L1 * g_l1
        g_loss.backward()
        optimizer_G.step()

        g_sum += g_loss.item()
        d_sum += d_loss.item()

        if i % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs} | Step {i:04d}/{len(train_loader)} | D {d_loss.item():.4f} | G {g_loss.item():.4f}")

    # epoch averages (max(1, ...) avoids dividing by zero on an empty loader)
    g_avg = g_sum / max(1, len(train_loader))
    d_avg = d_sum / max(1, len(train_loader))
    training_losses['g'].append(g_avg)
    training_losses['d'].append(d_avg)

    # save checkpoint; 'epoch' stores the number of completed epochs, which is
    # the value load_checkpoint hands back as the next start_epoch
    ckpt = {
        'epoch': epoch + 1,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'g_loss': g_avg,
        'd_loss': d_avg,
        'training_losses': training_losses
    }
    torch.save(ckpt, os.path.join(CKPT_DIR, f"checkpoint_epoch_{epoch+1}.pth"))
    print(f"Epoch {epoch+1}/{num_epochs} | G {g_avg:.4f} | D {d_avg:.4f} — saved")

print("Training complete.")


In [None]:
# Cell 8 — quick validation preview (save a few H&E | Real IHC | Fake IHC triptychs)
from torchvision.utils import save_image

# small helper: [-1,1] → [0,1]
def denorm(x):
    """Rescale a tensor from [-1, 1] to [0, 1], clipping anything outside that range."""
    rescaled = x * 0.5 + 0.5
    return torch.clamp(rescaled, min=0, max=1)

# build a val loader (same pairing logic)
val_dataset = PairedFolderDataset(DATA_PATH, 'valA', 'valB', size=256)
val_loader  = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=2)

SAMPLE_DIR = os.path.join(CKPT_DIR, "samples_val")
os.makedirs(SAMPLE_DIR, exist_ok=True)

# generate N samples and save triptychs
N = min(12, len(val_dataset))
was_training = generator.training
generator.eval()

# try/finally puts the generator back into its previous mode even if
# inference or saving fails part-way through.
try:
    with torch.no_grad():
        for i, (he, ihc) in enumerate(val_loader):
            if i >= N: break
            he  = he.to(device)
            ihc = ihc.to(device)

            fake = generator(he)
            # dim=3 is the width axis of an (N, C, H, W) batch: H&E | real IHC | fake IHC
            trip = torch.cat([denorm(he), denorm(ihc), denorm(fake)], dim=3)
            save_image(trip.cpu(), os.path.join(SAMPLE_DIR, f"val_{i:03d}_triptych.png"))
finally:
    if was_training:
        generator.train()

print(f"Saved {N} triptychs to: {SAMPLE_DIR}")


In [None]:
# === Generate 10 triptychs from the given epoch-200 checkpoint ===
import os, glob
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

# Device as a plain string; torch accepts it wherever a device is expected.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Checkpoint to preview, validation folders, and the output folder (created if needed).
CHECKPOINT_PATH = "/content/drive/MyDrive/HER2/pix2pix_checkpoints/generator_epoch_200.pth"
VALA_DIR = "/content/drive/MyDrive/HER2/TrainValAB/valA"  # H&E
VALB_DIR = "/content/drive/MyDrive/HER2/TrainValAB/valB"  # IHC
OUT_DIR  = "/content/drive/MyDrive/HER2/pix2pix_eval/epoch_200_preview_10"
os.makedirs(OUT_DIR, exist_ok=True)

# Stop early when the checkpoint or either validation folder is missing.
assert os.path.exists(CHECKPOINT_PATH), f"Checkpoint not found: {CHECKPOINT_PATH}"
assert all(os.path.exists(folder) for folder in (VALA_DIR, VALB_DIR)), "valA/valB paths not found."

# ---- Model defs ----
class Pix2PixUNet(nn.Module):
    """Five-level U-Net using the enc1..enc5 / dec1..dec5 module names.

    Same layer layout as the training-time Pix2Pix generator in this file,
    so its state dict keys line up with those checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Encoder: each stage halves the resolution.
        self.enc1 = self._c(3, 64, False)
        self.enc2 = self._c(64, 128)
        self.enc3 = self._c(128, 256)
        self.enc4 = self._c(256, 512)
        self.enc5 = self._c(512, 512)
        # Decoder: inputs from dec2 on include the concatenated skip channels.
        self.dec1 = self._u(512, 512)
        self.dec2 = self._u(1024, 256)
        self.dec3 = self._u(512, 128)
        self.dec4 = self._u(256, 64)
        self.dec5 = nn.ConvTranspose2d(128, 3, 4, 2, 1)

    def _c(self, i, o, norm=True):
        """Downsampling stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2)."""
        conv = nn.Conv2d(i, o, 4, 2, 1)
        middle = (nn.BatchNorm2d(o),) if norm else ()
        return nn.Sequential(conv, *middle, nn.LeakyReLU(0.2))

    def _u(self, i, o):
        """Upsampling stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU."""
        stages = (nn.ConvTranspose2d(i, o, 4, 2, 1), nn.BatchNorm2d(o), nn.ReLU())
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return the tanh-squashed translation of x at the same spatial size."""
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        bottleneck = self.enc5(e4)

        # Each decoder output is joined with the encoder map of the same resolution.
        h = bottleneck
        for decoder, skip in ((self.dec1, e4), (self.dec2, e3), (self.dec3, e2), (self.dec4, e1)):
            h = torch.cat([decoder(h), skip], 1)
        return torch.tanh(self.dec5(h))

class ConvBlock(nn.Module):
    """One resampling stage: 4x4 stride-2 (transposed) conv, norm, activation.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        down: True for a reflect-padded Conv2d (halves H and W), False for a
            ConvTranspose2d (doubles H and W).
        act: "relu" selects ReLU; any other value selects LeakyReLU(0.2).
        use_bn: When True an InstanceNorm2d follows the conv, otherwise an Identity.
    """

    def __init__(self, in_ch, out_ch, down=True, act="relu", use_bn=True):
        super().__init__()
        if down:
            resample = nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            resample = nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False)

        normalise = nn.InstanceNorm2d(out_ch) if use_bn else nn.Identity()

        if act == "relu":
            activation = nn.ReLU()
        else:
            activation = nn.LeakyReLU(0.2)

        self.conv = nn.Sequential(resample, normalise, activation)

    def forward(self, x):
        return self.conv(x)

class CycleUNet(nn.Module):
    """Eight-level U-Net built from ConvBlock stages.

    Seven ConvBlock downsamplings plus a stride-2 bottleneck conv halve the
    resolution eight times; seven ConvBlock upsamplings and a final transposed
    conv mirror them, each consuming the matching encoder map as a skip input.
    Output passes through Tanh.

    Args:
        in_ch: Input image channels.
        out_ch: Output image channels.
        feat: Base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_ch=3, out_ch=3, feat=64):
        super().__init__()
        # Encoder (attribute names must stay as-is: they are the state dict keys).
        self.initial_down = ConvBlock(in_ch, feat, True, "leaky", False)
        self.down1 = ConvBlock(feat, feat * 2, True, "leaky")
        self.down2 = ConvBlock(feat * 2, feat * 4, True, "leaky")
        self.down3 = ConvBlock(feat * 4, feat * 8, True, "leaky")
        self.down4 = ConvBlock(feat * 8, feat * 8, True, "leaky")
        self.down5 = ConvBlock(feat * 8, feat * 8, True, "leaky")
        self.down6 = ConvBlock(feat * 8, feat * 8, True, "leaky")
        self.bottleneck = nn.Sequential(
            nn.Conv2d(feat * 8, feat * 8, 4, 2, 1, padding_mode="reflect"),
            nn.ReLU(),
        )
        # Decoder; from up2 on the input width is doubled by the skip concat.
        self.up1 = ConvBlock(feat * 8, feat * 8, False, "relu", True)
        self.up2 = ConvBlock(feat * 16, feat * 8, False, "relu", True)
        self.up3 = ConvBlock(feat * 16, feat * 8, False, "relu", True)
        self.up4 = ConvBlock(feat * 16, feat * 8, False, "relu", True)
        self.up5 = ConvBlock(feat * 16, feat * 4, False, "relu", True)
        self.up6 = ConvBlock(feat * 8, feat * 2, False, "relu", True)
        self.up7 = ConvBlock(feat * 4, feat, False, "relu", True)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(feat * 2, out_ch, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        # skips[0] is the initial_down output, skips[-1] the deepest ConvBlock output.
        skips = [self.initial_down(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            skips.append(stage(skips[-1]))

        h = self.up1(self.bottleneck(skips[-1]))

        # up2..up7 consume the skips from deepest to shallowest, leaving skips[0] for final_up.
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for stage, skip in zip(decoders, reversed(skips[1:])):
            h = stage(torch.cat([h, skip], 1))
        return self.final_up(torch.cat([h, skips[0]], 1))

# ---- load checkpoint and detect arch ----
# Read the checkpoint on the CPU and locate the generator's state dict inside it.
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
if not isinstance(ckpt, dict):
    raise RuntimeError("Unexpected checkpoint format.")

sd = None
# Wrapper keys tried in order; the first one holding a dict wins.
for k in ["generator_state_dict","gen_A_state_dict","state_dict","model","netG","G"]:
    if k in ckpt and isinstance(ckpt[k], dict):
        sd = ckpt[k]; break
# Otherwise accept a bare state dict (every value is a tensor).
if sd is None and ckpt and all(torch.is_tensor(v) for v in ckpt.values()):
    sd = ckpt
if not sd:
    # Without this check a missing/empty state dict surfaces later as an opaque
    # TypeError/StopIteration from next(iter(sd)).
    raise RuntimeError(
        f"No generator state dict found in checkpoint; top-level keys: {list(ckpt)[:10]}"
    )

# Strip the "module." prefix that nn.DataParallel adds, and only where it
# really is a prefix (str.replace would also hit the first match mid-key).
first = next(iter(sd))
if first.startswith("module."):
    sd = {(k[len("module."):] if k.startswith("module.") else k): v for k, v in sd.items()}

def detect_arch(keys):
    """Guess the generator architecture from state dict key names.

    Returns "pix2pix" when any key starts with "enc1" or contains ".enc1.",
    "cyc_unet" when any key starts with "initial_down", contains
    ".initial_down." or contains ".down1.", and "unknown" otherwise.
    The pix2pix markers are checked first.
    """
    names = list(keys)
    # (label, prefix marker, infix marker), checked in priority order.
    markers = (
        ("pix2pix", "enc1", ".enc1."),
        ("cyc_unet", "initial_down", ".initial_down."),
    )
    for label, prefix, infix in markers:
        for name in names:
            if name.startswith(prefix) or infix in name:
                return label
    for name in names:
        if ".down1." in name:
            return "cyc_unet"
    return "unknown"

arch = detect_arch(sd.keys())
print("Detected arch:", arch)
if arch == "unknown":
    print("WARNING: could not recognise the checkpoint layout; trying CycleUNet.")

# Anything that is not recognised as pix2pix is loaded into the CycleUNet layout.
G = Pix2PixUNet().to(DEVICE).eval() if arch=="pix2pix" else CycleUNet().to(DEVICE).eval()

# strict=False tolerates extra or missing entries, but it also accepts a
# checkpoint in which *nothing* matches, leaving G randomly initialised.
# Inspect the returned report so that case cannot pass silently.
load_report = G.load_state_dict(sd, strict=False)
if len(load_report.missing_keys) >= len(G.state_dict()):
    raise RuntimeError(
        "No tensor in the checkpoint matched the model; refusing to preview an untrained generator."
    )
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        f"WARNING: partial checkpoint load — {len(load_report.missing_keys)} missing, "
        f"{len(load_report.unexpected_keys)} unexpected key(s)."
    )

# ---- pair 10 filenames by basename ----
def index_by_base(folder):
    """Map image basenames (file name without extension) to full paths in *folder*.

    Extensions are matched case-insensitively, in line with the
    ``.lower().endswith(...)`` filtering used by the datasets in this file
    (glob patterns such as "*.png" miss "X.PNG" on case-sensitive file systems).
    Hidden entries (leading dot) and directories are skipped. When one basename
    exists with several extensions, the extension listed later in the
    supported tuple wins.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        dict of basename -> path; empty when *folder* is not a directory.
    """
    exts = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp")
    if not os.path.isdir(folder):
        return {}
    # Sorted listing makes the result independent of directory enumeration order.
    names = sorted(n for n in os.listdir(folder) if not n.startswith("."))
    idx = {}
    for ext in exts:
        for name in names:
            if not name.lower().endswith(ext):
                continue
            path = os.path.join(folder, name)
            if os.path.isfile(path):
                idx[os.path.splitext(name)[0]] = path
    return idx

# Keep (at most) the first 10 basenames present in both validation folders.
A = index_by_base(VALA_DIR); B = index_by_base(VALB_DIR)
common = sorted(set(A) & set(B))[:10]
assert common, "No matching basenames between valA and valB."

# ---- transforms ----
# Resize to an exact 256×256 square, as the training datasets in this file do.
# Resize(256) would only fix the shorter side, so non-square inputs would
# reach the network (and the triptych concat) at mismatched sizes.
tx_in = transforms.Compose([
    transforms.Resize((256, 256)), transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]),
])
tx_01 = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])


def to01(t):
    """Map generator output from [-1, 1] to [0, 1] for saving."""
    return (t*0.5 + 0.5).clamp(0,1)


# ---- generate ----
# Report the real number of pairs; fewer than 10 may be available.
print(f"Saving {len(common)} triptychs to:", OUT_DIR)
with torch.no_grad():
    for base in common:
        # Context managers close the source files once the RGB copies exist.
        with Image.open(A[base]) as raw_he:
            he_img = raw_he.convert("RGB")
        with Image.open(B[base]) as raw_ihc:
            ihc_img = raw_ihc.convert("RGB")
        x = tx_in(he_img).unsqueeze(0).to(DEVICE)
        fake = G(x)
        fake01 = to01(fake)[0].cpu()
        he01   = tx_01(he_img)
        real01 = tx_01(ihc_img)
        # dim=2 is the width axis of a (C, H, W) image: H&E | real IHC | fake IHC
        trip = torch.cat([he01, real01, fake01], dim=2)
        save_image(trip,   os.path.join(OUT_DIR, f"{base}_TRIPTYCH_e200.png"))
        save_image(fake01, os.path.join(OUT_DIR, f"{base}_fakeIHC_e200.png"))
print("Done.")


In [None]:
# View the 10 saved triptychs (H&E | Real IHC | Generated IHC)
import os, glob, math
from PIL import Image
import matplotlib.pyplot as plt

OUT_DIR = "/content/drive/MyDrive/HER2/pix2pix_eval/epoch_200_preview_10"

# Collect the triptychs written by the generation cell, in name order.
trip_paths = sorted(glob.glob(os.path.join(OUT_DIR, "*_TRIPTYCH_e200.png")))
print(f"Found {len(trip_paths)} triptychs in:", OUT_DIR)
assert trip_paths, f"No triptych images found in {OUT_DIR}. Check that the previous generation step ran."

# Grid with 5 columns and as many rows as needed; each cell is 4×4 inches.
cols = 5
rows = math.ceil(len(trip_paths) / cols)
plt.figure(figsize=(cols*4, rows*4))
for slot, path in enumerate(trip_paths[:rows*cols], start=1):
    ax = plt.subplot(rows, cols, slot)
    ax.imshow(Image.open(path).convert("RGB"))
    ax.set_title(os.path.basename(path), fontsize=8)
    ax.axis("off")
plt.tight_layout()
plt.show()


In [None]:
import random, matplotlib.pyplot as plt
from PIL import Image

OUT_DIR = "/content/drive/MyDrive/HER2/pix2pix_eval/epoch_200_preview_10"
trip_paths = sorted(glob.glob(os.path.join(OUT_DIR, "*_TRIPTYCH_e200.png")))
# random.choice raises a bare IndexError on an empty list; fail with the same
# actionable message the grid viewer uses instead.
assert trip_paths, f"No triptych images found in {OUT_DIR}. Check that the previous generation step ran."

p = random.choice(trip_paths)  # or set p to a specific file path
plt.figure(figsize=(12,4))
plt.imshow(Image.open(p).convert("RGB"))
plt.title(os.path.basename(p))
plt.axis('off')
plt.show()


# **CycleGAN hybrid**

In [None]:
# Device and dataset root for the CycleGAN section.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
DATA_PATH = "/content/drive/MyDrive/HER2/TrainValAB"

# Split folders, used as two unpaired domains: A = H&E, B = IHC.
TRAIN_A_DIR, TRAIN_B_DIR, VALA_DIR, VALB_DIR = (
    os.path.join(DATA_PATH, split) for split in ("trainA", "trainB", "valA", "valB")
)

# Checkpoint folder, created up front.
CKPT_CYC_DIR = "/content/drive/MyDrive/checkpointsCycleHybrid"
os.makedirs(CKPT_CYC_DIR, exist_ok=True)

# Hyper-parameters.
IMG_SIZE, BATCH_SIZE, EPOCHS_CYC = 256, 4, 10
LR_CYC, BETAS_CYC = 2e-4, (0.5, 0.999)
LAMBDA_CYC = 10  # cycle-consistency weight
LAMBDA_ID = 5    # identity weight (often 0.5 * LAMBDA_CYC)

# Report which split folders exist and where checkpoints will go.
for d in (TRAIN_A_DIR, TRAIN_B_DIR, VALA_DIR, VALB_DIR):
    print(f"{os.path.basename(d):6s}:", ("MISSING", "OK")[os.path.exists(d)])
print(f"Checkpoints → {CKPT_CYC_DIR}")

In [None]:
# CycleGAN (hybrid) — Cell 3: models (U-Net generators + PatchGAN discriminators)

# --- small conv block (down/upsample with optional InstanceNorm) ---
class ConvBlock(nn.Module):
    """Resampling stage: 4x4 stride-2 (transposed) conv, optional InstanceNorm, activation.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        down: True builds a reflect-padded Conv2d (halves H and W); False
            builds a ConvTranspose2d (doubles H and W).
        act: "relu" selects ReLU; anything else selects LeakyReLU(0.2).
        use_bn: True inserts InstanceNorm2d after the conv, False an Identity.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_bn=True):
        super().__init__()
        if down:
            resample = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)

        normalise = nn.InstanceNorm2d(out_channels) if use_bn else nn.Identity()

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.LeakyReLU(0.2, inplace=True)

        # Stored under `conv` so state dict keys read conv.0.*
        self.conv = nn.Sequential(resample, normalise, activation)

    def forward(self, x):
        return self.conv(x)

# --- U-Net-like generator (used for A→B and B→A) ---
class Generator(nn.Module):
    """U-Net-style generator used for both translation directions (A→B and B→A).

    Seven ConvBlock downsamplings followed by a stride-2 bottleneck conv halve
    the resolution eight times in total. The decoder mirrors this with seven
    ConvBlock upsamplings and a final transposed conv, each fed the matching
    encoder output as a skip connection. Output passes through Tanh ([-1, 1]).

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        features: Base width; deeper stages use 2×, 4× and 8× this value.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        f1, f2, f4, f8 = features, features * 2, features * 4, features * 8

        # Encoder (the first stage has no normalisation).
        self.initial_down = ConvBlock(in_channels, f1, down=True, act="leaky", use_bn=False)
        self.down1 = ConvBlock(f1, f2, down=True, act="leaky")
        self.down2 = ConvBlock(f2, f4, down=True, act="leaky")
        self.down3 = ConvBlock(f4, f8, down=True, act="leaky")
        self.down4 = ConvBlock(f8, f8, down=True, act="leaky")
        self.down5 = ConvBlock(f8, f8, down=True, act="leaky")
        self.down6 = ConvBlock(f8, f8, down=True, act="leaky")
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f8, f8, 4, 2, 1, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )

        # Decoder; from up2 on, the input width is doubled by the skip concat.
        self.up1 = ConvBlock(f8, f8, down=False, act="relu", use_bn=True)
        self.up2 = ConvBlock(f8 * 2, f8, down=False, act="relu", use_bn=True)
        self.up3 = ConvBlock(f8 * 2, f8, down=False, act="relu", use_bn=True)
        self.up4 = ConvBlock(f8 * 2, f8, down=False, act="relu", use_bn=True)
        self.up5 = ConvBlock(f8 * 2, f4, down=False, act="relu", use_bn=True)
        self.up6 = ConvBlock(f4 * 2, f2, down=False, act="relu", use_bn=True)
        self.up7 = ConvBlock(f2 * 2, f1, down=False, act="relu", use_bn=True)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f1 * 2, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        # skips[0] is the initial_down output, skips[-1] the deepest ConvBlock output.
        skips = [self.initial_down(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            skips.append(stage(skips[-1]))

        h = self.up1(self.bottleneck(skips[-1]))

        # up2..up7 take the skips from deepest to shallowest; skips[0] is left for final_up.
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for stage, skip in zip(decoders, reversed(skips[1:])):
            h = stage(torch.cat([h, skip], 1))
        return self.final_up(torch.cat([h, skips[0]], 1))

# --- PatchGAN discriminator (four stride-2 convs + one stride-1 conv → 94×94 receptive field; spectral norm, logits output) ---
class Discriminator(nn.Module):
    """Patch discriminator with spectral-normalised convolutions and raw logit output.

    Args:
        in_channels: Channels of the image being judged.
        features: Channel widths of the stride-2 stages; the first entry is
            used by the un-normalised initial stage, each later entry adds a
            conv + InstanceNorm + LeakyReLU stage.

    The last layer is a stride-1 conv to one channel with no sigmoid, so the
    forward pass returns a map of logits.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        sn = nn.utils.spectral_norm
        widths = list(features)

        self.initial = nn.Sequential(
            sn(nn.Conv2d(in_channels, widths[0], 4, 2, 1, padding_mode="reflect")),
            nn.LeakyReLU(0.2, inplace=True),
        )

        body = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            body.append(sn(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False, padding_mode="reflect")))
            body.append(nn.InstanceNorm2d(c_out))
            body.append(nn.LeakyReLU(0.2, inplace=True))
        # Final 1-channel projection, deliberately without Sigmoid.
        body.append(sn(nn.Conv2d(widths[-1], 1, 4, 1, 1, padding_mode="reflect")))
        self.model = nn.Sequential(*body)

    def forward(self, x):
        hidden = self.initial(x)
        return self.model(hidden)  # logits map

# --- weight init (DCGAN-style) ---
def initialize_weights(model):
    """Apply DCGAN-style initialisation in place to every layer of *model*.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02) and their
    biases zeroed; BatchNorm2d / InstanceNorm2d weights are drawn from
    N(1, 0.02) and their biases zeroed (affine-less norms are skipped).

    Layers wrapped with ``nn.utils.spectral_norm`` keep their trainable tensor
    in ``weight_orig``; ``weight`` is only a derived tensor recomputed from it
    on each forward pass. The underlying parameter is therefore initialised
    here; writing to ``weight`` alone would be overwritten.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            weight = getattr(m, "weight_orig", None)
            if weight is None:
                weight = m.weight
            if weight is not None: nn.init.normal_(weight, 0.0, 0.02)
            if m.bias is not None:  nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            if m.weight is not None: nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:   nn.init.zeros_(m.bias)

# --- optional: VGG19 perceptual loss (feature L1) ---
from torchvision.models import vgg19, VGG19_Weights

class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two image batches.

    Inputs are expected in [-1, 1]. They are mapped to [0, 1] and then
    standardised with the ImageNet channel mean/std that the pretrained VGG19
    weights were trained with, before being passed through the network.

    Args:
        device: Device on which the VGG layers and normalisation constants live.
    """

    def __init__(self, device):
        super().__init__()
        # features[:36] keeps layers 0..35, i.e. through the ReLU after conv5_4 (relu5_4).
        vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features[:36].eval().to(device)
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        self.l1  = nn.L1Loss()
        # ImageNet statistics, shaped for broadcasting over (N, 3, H, W).
        # Non-persistent: they are constants and stay out of any saved state dict.
        self.register_buffer(
            "imagenet_mean",
            torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "imagenet_std",
            torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1),
            persistent=False,
        )

    def _to_vgg_input(self, t):
        """[-1, 1] → [0, 1] → ImageNet-standardised."""
        t01 = (t * 0.5 + 0.5).clamp(0, 1)
        return (t01 - self.imagenet_mean) / self.imagenet_std

    def forward(self, x, y):
        """Return the mean absolute difference between the VGG features of x and y."""
        return self.l1(self.vgg(self._to_vgg_input(x)), self.vgg(self._to_vgg_input(y)))

print("CycleGAN (hybrid) models ready.")


In [None]:
# CycleGAN — build unpaired loaders (A=H&E, B=IHC)
class UnpairedFolderDataset(Dataset):
    """Serve one image from each of two folders per index (A = H&E, B = IHC).

    The dataset length is the larger folder's size; the shorter folder is
    cycled with a modulo index. Both images go through the same transform
    pipeline, but the random flip is sampled independently for each of them.

    Args:
        root_a: Folder with domain-A images.
        root_b: Folder with domain-B images.
        size: Images are resized to (size, size).
        augment: When True a random horizontal flip (p=0.5) precedes the resize.
    """

    def __init__(self, root_a, root_b, size=IMG_SIZE, augment=True):
        self.dir_a, self.dir_b = root_a, root_b
        exts = ('.jpg', '.jpeg', '.png')
        self.a_files = sorted(f for f in os.listdir(self.dir_a) if f.lower().endswith(exts))
        self.b_files = sorted(f for f in os.listdir(self.dir_b) if f.lower().endswith(exts))
        self.n_a, self.n_b = len(self.a_files), len(self.b_files)
        self.n = max(self.n_a, self.n_b)  # unpaired → cycle the shorter side

        # Honour the `size` argument; the global IMG_SIZE is only its default.
        ops = [transforms.Resize((size, size))]
        if augment: ops.insert(0, transforms.RandomHorizontalFlip(p=0.5))
        ops += [transforms.ToTensor(), transforms.Normalize([0.5]*3, [0.5]*3)]
        self.tf = transforms.Compose(ops)

    def __len__(self): return self.n

    def __getitem__(self, idx):
        a_path = os.path.join(self.dir_a, self.a_files[idx % self.n_a])
        b_path = os.path.join(self.dir_b, self.b_files[idx % self.n_b])
        # Context managers close the files once the RGB copies exist.
        with Image.open(a_path) as raw_a:
            a = raw_a.convert('RGB')
        with Image.open(b_path) as raw_b:
            b = raw_b.convert('RGB')
        return self.tf(a), self.tf(b)

# train / val loaders
# Training split: flip augmentation on, shuffled mini-batches.
cyc_train = UnpairedFolderDataset(TRAIN_A_DIR, TRAIN_B_DIR, size=IMG_SIZE, augment=True)
cyc_loader = DataLoader(
    dataset=cyc_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Validation split: no augmentation, one image pair at a time, fixed order.
cyc_val = UnpairedFolderDataset(VALA_DIR, VALB_DIR, size=IMG_SIZE, augment=False)
cyc_val_loader = DataLoader(dataset=cyc_val, batch_size=1, shuffle=False, num_workers=2)

print(f"Unpaired train — A:{cyc_train.n_a}  B:{cyc_train.n_b}  |  batches:{len(cyc_loader)}")


In [None]:
# Cell 4: Training Loop for CycleGAN (WGAN-GP + cycle + identity + perceptual, AMP, replay buffer)

# --- Replay buffer ---
class ReplayBuffer:
    """Fixed-capacity pool of previously generated images.

    While the pool is below capacity every incoming image is stored and
    returned unchanged. Once it is full, each incoming image is either swapped
    with a randomly chosen stored image (whose clone is returned) or returned
    as-is; the choice is a `random.uniform(0, 1) > 0.5` draw.
    """

    def __init__(self, max_size=50):
        assert max_size > 0
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, images):
        """Feed a batch through the pool and return a batch of the same size.

        The incoming batch is detached first, so nothing stored or returned
        carries the autograd graph of *images*.
        """
        returned = []
        for sample in images.detach():
            sample = sample.unsqueeze(0)  # keep a leading batch axis of 1

            if len(self.data) < self.max_size:
                # Still filling up: store and pass through.
                self.data.append(sample)
                returned.append(sample)
                continue

            if random.uniform(0, 1) > 0.5:
                # Swap: hand back an older image and keep the new one instead.
                slot = random.randint(0, self.max_size - 1)
                returned.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                returned.append(sample)
        return torch.cat(returned)

# --- WGAN-GP losses ---
def discriminator_loss_wgan_gp(disc_real_pred, disc_fake_pred, real_img, fake_img, discriminator, lambda_gp, device):
    """WGAN-GP critic loss: mean(fake) - mean(real) + lambda_gp * gradient penalty.

    The penalty is evaluated at random per-sample interpolations between the
    real and fake images and pushes the L2 norm of the critic's input
    gradient towards 1.

    Args:
        disc_real_pred: Critic output for the real batch.
        disc_fake_pred: Critic output for the fake batch.
        real_img: Real image batch; the first axis is the batch axis.
        fake_img: Fake image batch, broadcast-compatible with real_img.
        discriminator: Callable critic applied to the interpolated images.
        lambda_gp: Weight of the gradient penalty.
        device: Device on which the interpolation coefficients are created.

    Returns:
        Scalar loss tensor; the penalty term stays differentiable with
        respect to the critic's parameters.
    """
    alpha = torch.rand(real_img.size(0), 1, 1, 1, device=device)
    # Detach both endpoints: the penalty is a function of the critic only and
    # must not extend the generator's autograd graph.
    interpolates = (alpha * real_img.detach() + (1 - alpha) * fake_img.detach()).requires_grad_(True)
    disc_interpolates = discriminator(interpolates)
    # torch.autograd is referenced through the already-imported torch package,
    # so no separate `autograd` name has to be in scope.
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,   # the penalty itself is back-propagated through
        retain_graph=True,
    )[0]
    # reshape (not view) also handles a non-contiguous gradient tensor.
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
    return disc_fake_pred.mean() - disc_real_pred.mean() + gradient_penalty

def generator_loss_wgan_gp(disc_fake_pred):
    """WGAN generator loss: the negated mean critic score of the fake batch."""
    mean_score = disc_fake_pred.mean()
    return torch.neg(mean_score)

# --- models ---
# Two generators (gen_A: A→B, gen_B: B→A) and one discriminator per domain.
# They are built in the same order as before so the random initialisation
# sequence is unchanged.
# NOTE(review): DEVICE, LEARNING_RATE_GEN and LEARNING_RATE_DISC are not defined
# in the CycleGAN config cell visible above (it defines `device` and LR_CYC) —
# confirm where these names come from.
gen_A, gen_B = (
    Generator(in_channels=3, out_channels=3, features=64).to(DEVICE) for _ in range(2)
)
disc_A, disc_B = (Discriminator(in_channels=3).to(DEVICE) for _ in range(2))

# DCGAN-style re-initialisation of all four networks.
initialize_weights(gen_A)
initialize_weights(gen_B)
initialize_weights(disc_A)
initialize_weights(disc_B)

# --- losses ---
criterion_Cycle, criterion_Identity = nn.L1Loss(), nn.L1Loss()
criterion_Perceptual = PerceptualLoss(DEVICE)

# --- optimizers ---
# A single Adam instance drives both generators; each discriminator has its own.
optimizer_gen = optim.Adam(
    [*gen_A.parameters(), *gen_B.parameters()],
    lr=LEARNING_RATE_GEN,
    betas=(0.5, 0.999),
)
optimizer_disc_A, optimizer_disc_B = (
    optim.Adam(critic.parameters(), lr=LEARNING_RATE_DISC, betas=(0.5, 0.999))
    for critic in (disc_A, disc_B)
)

# --- schedulers (linear decay) ---
def lambda_rule(epoch):
    """Return the LR multiplier for ``epoch``.

    The multiplier is 1.0 up to DECAY_EPOCH_START and then falls linearly,
    reaching 0 at NUM_EPOCHS. The decay span is clamped to at least one epoch
    so NUM_EPOCHS == DECAY_EPOCH_START cannot divide by zero, and the result
    is floored at 0 so the learning rate never turns negative.
    """
    decay_span = max(1, NUM_EPOCHS - DECAY_EPOCH_START)
    epochs_into_decay = max(0, epoch - DECAY_EPOCH_START)
    return max(0.0, 1.0 - epochs_into_decay / float(decay_span))

scheduler_gen    = torch.optim.lr_scheduler.LambdaLR(optimizer_gen,    lr_lambda=lambda_rule)
scheduler_disc_A = torch.optim.lr_scheduler.LambdaLR(optimizer_disc_A, lr_lambda=lambda_rule)
scheduler_disc_B = torch.optim.lr_scheduler.LambdaLR(optimizer_disc_B, lr_lambda=lambda_rule)

# --- misc ---
start_epoch       = 0
scaler            = GradScaler(init_scale=2.**10)
fake_A_buffer     = ReplayBuffer()
fake_B_buffer     = ReplayBuffer()

# --- resume (optional) ---
# Restores models and optimizers, plus the LR schedulers and the AMP scaler when
# the checkpoint carries them. Without the scheduler state a resumed run would
# restart the linear LR decay from its beginning.
if LOAD_MODEL:
    try:
        checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, "checkpoint_latest.pth"), map_location=DEVICE)
        gen_A.load_state_dict(checkpoint['gen_A_state_dict'])
        gen_B.load_state_dict(checkpoint['gen_B_state_dict'])
        disc_A.load_state_dict(checkpoint['disc_A_state_dict'])
        disc_B.load_state_dict(checkpoint['disc_B_state_dict'])
        optimizer_gen.load_state_dict(checkpoint['optimizer_gen_state_dict'])
        optimizer_disc_A.load_state_dict(checkpoint['optimizer_disc_A_state_dict'])
        optimizer_disc_B.load_state_dict(checkpoint['optimizer_disc_B_state_dict'])
        # Optional entries: skipped when absent or empty so older checkpoints still load.
        optional_state = (
            ('scheduler_gen_state_dict', scheduler_gen),
            ('scheduler_disc_A_state_dict', scheduler_disc_A),
            ('scheduler_disc_B_state_dict', scheduler_disc_B),
            ('scaler_state_dict', scaler),
        )
        for state_key, target in optional_state:
            saved_state = checkpoint.get(state_key)
            if saved_state:
                target.load_state_dict(saved_state)
        start_epoch = checkpoint['epoch']
        print(f"Resuming from epoch {start_epoch + 1}/{NUM_EPOCHS}")
    except FileNotFoundError:
        print("No latest checkpoint found. Starting from scratch.")
    except Exception as e:
        print(f"Error loading models: {e}\nStarting from scratch.")

# --- train ---
# AMP bookkeeping (one GradScaler shared by three optimizers):
#   * each loss is scaled and back-propagated exactly once per optimizer step;
#   * gradients are unscaled before clipping, so max_norm applies to true gradients;
#   * scaler.update() runs once per iteration, after all optimizers have stepped.
print("\nStarting CycleGAN training...")
for epoch in range(start_epoch, NUM_EPOCHS):
    gen_A.train(); gen_B.train(); disc_A.train(); disc_B.train()
    loop = tqdm(train_loader, leave=True, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

    avg_gen_loss_total = avg_disc_A_loss = avg_disc_B_loss = 0.0
    avg_perceptual_loss_A = avg_perceptual_loss_B = 0.0
    avg_gen_adv_A = avg_gen_adv_B = 0.0
    avg_cycle_A = avg_cycle_B = 0.0
    avg_identity_A = avg_identity_B = 0.0

    for real_A, real_B in loop:
        real_A = real_A.to(DEVICE)
        real_B = real_B.to(DEVICE)

        # --- D_B ---
        optimizer_disc_B.zero_grad()
        with autocast(DEVICE):
            fake_B = gen_A(real_A)
            disc_real_B_pred = disc_B(real_B)
            # The critic never needs generator gradients, so only detached
            # fakes enter the replay buffer and the critic.
            fake_B_from_buffer = fake_B_buffer.push_and_pop(fake_B.detach()).detach()
            disc_fake_B_pred = disc_B(fake_B_from_buffer)

        # discriminator_loss_wgan_gp already returns Wasserstein term + penalty,
        # so it is the complete critic loss and is back-propagated once.
        loss_disc_B = discriminator_loss_wgan_gp(
            disc_real_B_pred, disc_fake_B_pred, real_B, fake_B_from_buffer, disc_B, LAMBDA_GP, DEVICE
        )
        scaler.scale(loss_disc_B).backward()
        scaler.unscale_(optimizer_disc_B)
        torch.nn.utils.clip_grad_norm_(disc_B.parameters(), max_norm=5.0)
        scaler.step(optimizer_disc_B)
        loss_disc_B_total = loss_disc_B.item()

        # --- D_A ---
        optimizer_disc_A.zero_grad()
        with autocast(DEVICE):
            fake_A = gen_B(real_B)
            disc_real_A_pred = disc_A(real_A)
            fake_A_from_buffer = fake_A_buffer.push_and_pop(fake_A.detach()).detach()
            disc_fake_A_pred = disc_A(fake_A_from_buffer)

        loss_disc_A = discriminator_loss_wgan_gp(
            disc_real_A_pred, disc_fake_A_pred, real_A, fake_A_from_buffer, disc_A, LAMBDA_GP, DEVICE
        )
        scaler.scale(loss_disc_A).backward()
        scaler.unscale_(optimizer_disc_A)
        torch.nn.utils.clip_grad_norm_(disc_A.parameters(), max_norm=5.0)
        scaler.step(optimizer_disc_A)
        loss_disc_A_total = loss_disc_A.item()

        # --- Generators ---
        optimizer_gen.zero_grad()
        with autocast(DEVICE):
            fake_B = gen_A(real_A); fake_A = gen_B(real_B)
            loss_gen_adv_A_val = generator_loss_wgan_gp(disc_B(fake_B))
            loss_gen_adv_B_val = generator_loss_wgan_gp(disc_A(fake_A))
            cycled_A = gen_B(fake_B); loss_cycle_A_val = criterion_Cycle(cycled_A, real_A)
            cycled_B = gen_A(fake_A); loss_cycle_B_val = criterion_Cycle(cycled_B, real_B)
            identity_B = gen_A(real_B); loss_identity_B_val = criterion_Identity(identity_B, real_B)
            identity_A = gen_B(real_A); loss_identity_A_val = criterion_Identity(identity_A, real_A)
            loss_perceptual_A_val = criterion_Perceptual(fake_B.float(), real_B)
            loss_perceptual_B_val = criterion_Perceptual(fake_A.float(), real_A)
            loss_gen_total = (
                loss_gen_adv_A_val + loss_gen_adv_B_val
                + LAMBDA_CYCLE * loss_cycle_A_val + LAMBDA_CYCLE * loss_cycle_B_val
                + LAMBDA_IDENTITY * loss_identity_A_val + LAMBDA_IDENTITY * loss_identity_B_val
                + LAMBDA_PERCEPTUAL * loss_perceptual_A_val + LAMBDA_PERCEPTUAL * loss_perceptual_B_val
            )

        scaler.scale(loss_gen_total).backward()
        # Unscale first: clipping scaled gradients would shrink the real update
        # by roughly the loss-scale factor.
        scaler.unscale_(optimizer_gen)
        torch.nn.utils.clip_grad_norm_(list(gen_A.parameters()) + list(gen_B.parameters()), max_norm=5.0)
        scaler.step(optimizer_gen)
        scaler.update()

        # --- logs (running sums, averaged at epoch end) ---
        avg_disc_A_loss        += loss_disc_A_total
        avg_disc_B_loss        += loss_disc_B_total
        avg_gen_loss_total     += loss_gen_total.item()
        avg_gen_adv_A          += loss_gen_adv_A_val.item()
        avg_gen_adv_B          += loss_gen_adv_B_val.item()
        avg_cycle_A            += loss_cycle_A_val.item()
        avg_cycle_B            += loss_cycle_B_val.item()
        avg_identity_A         += loss_identity_A_val.item()
        avg_identity_B         += loss_identity_B_val.item()
        avg_perceptual_loss_A  += loss_perceptual_A_val.item()
        avg_perceptual_loss_B  += loss_perceptual_B_val.item()

        loop.set_postfix(
            D_A_loss=loss_disc_A_total, D_B_loss=loss_disc_B_total,
            G_total_loss=loss_gen_total.item(),
            G_adv_A=loss_gen_adv_A_val.item(), G_adv_B=loss_gen_adv_B_val.item(),
            Cycle_A=loss_cycle_A_val.item(), Cycle_B=loss_cycle_B_val.item(),
            Identity_A=loss_identity_A_val.item(), Identity_B=loss_identity_B_val.item(),
            Percept_A=loss_perceptual_A_val.item(), Percept_B=loss_perceptual_B_val.item(),
        )

    # --- epoch end ---
    n_batches = max(1, len(train_loader))  # guard against an empty loader
    avg_disc_A_loss       /= n_batches
    avg_disc_B_loss       /= n_batches
    avg_gen_loss_total    /= n_batches
    avg_gen_adv_A         /= n_batches
    avg_gen_adv_B         /= n_batches
    avg_cycle_A           /= n_batches
    avg_cycle_B           /= n_batches
    avg_identity_A        /= n_batches
    avg_identity_B        /= n_batches
    avg_perceptual_loss_A /= n_batches
    avg_perceptual_loss_B /= n_batches

    scheduler_gen.step(); scheduler_disc_A.step(); scheduler_disc_B.step()

    # Header goes in when the log is empty or a from-scratch run begins; a resumed
    # run appends rows only, so the CSV is not interrupted by repeated headers.
    log_is_empty = (not os.path.exists(TRAINING_LOG_PATH)) or os.path.getsize(TRAINING_LOG_PATH) == 0
    with open(TRAINING_LOG_PATH, 'a') as f:
        if log_is_empty or epoch == 0:
            f.write("Epoch,Avg_D_A_Loss,Avg_D_B_Loss,Avg_G_Total_Loss,Avg_G_Adv_A,Avg_G_Adv_B,Avg_Cycle_A,Avg_Cycle_B,Avg_Identity_A,Avg_Identity_B,Avg_Percept_A,Avg_Percept_B\n")
        f.write(f"{epoch+1},{avg_disc_A_loss:.6f},{avg_disc_B_loss:.6f},{avg_gen_loss_total:.6f},{avg_gen_adv_A:.6f},{avg_gen_adv_B:.6f},{avg_cycle_A:.6f},{avg_cycle_B:.6f},{avg_identity_A:.6f},{avg_identity_B:.6f},{avg_perceptual_loss_A:.6f},{avg_perceptual_loss_B:.6f}\n")

    if (epoch + 1) % 5 == 0 or epoch == start_epoch:
        if epoch == start_epoch:
            print("\n" + "="*160)
            print(f"{'Epoch':<8} | {'D_A Loss':<10} | {'D_B Loss':<10} | {'G_Total Loss':<14} | {'G_Adv_A':<10} | {'G_Adv_B':<10} | {'Cycle_A':<10} | {'Cycle_B':<10} | {'Identity_A':<12} | {'Identity_B':<12} | {'Percept_A':<11} | {'Percept_B':<11}")
            print("="*160)
        print(f"{epoch+1:<8} | {avg_disc_A_loss:<10.4f} | {avg_disc_B_loss:<10.4f} | {avg_gen_loss_total:<14.4f} | {avg_gen_adv_A:<10.4f} | {avg_gen_adv_B:<10.4f} | {avg_cycle_A:<10.4f} | {avg_cycle_B:<10.4f} | {avg_identity_A:<12.4f} | {avg_identity_B:<12.4f} | {avg_perceptual_loss_A:<11.4f} | {avg_perceptual_loss_B:<11.4f}")

    # --- save ---
    if SAVE_MODEL and (epoch + 1) % 10 == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'gen_A_state_dict': gen_A.state_dict(),
            'gen_B_state_dict': gen_B.state_dict(),
            'disc_A_state_dict': disc_A.state_dict(),
            'disc_B_state_dict': disc_B.state_dict(),
            'optimizer_gen_state_dict': optimizer_gen.state_dict(),
            'optimizer_disc_A_state_dict': optimizer_disc_A.state_dict(),
            'optimizer_disc_B_state_dict': optimizer_disc_B.state_dict(),
            'scheduler_gen_state_dict': scheduler_gen.state_dict(),
            'scheduler_disc_A_state_dict': scheduler_disc_A.state_dict(),
            'scheduler_disc_B_state_dict': scheduler_disc_B.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
        }
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{epoch+1}.pth"))
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, "checkpoint_latest.pth"))
        print(f"Saved checkpoint at epoch {epoch+1}")

    # one sample per epoch (only the first validation batch is used)
    gen_A.eval(); gen_B.eval()
    with torch.no_grad():
        for i, (val_A, val_B) in enumerate(val_loader):
            if i >= 1: break
            val_A = val_A.to(DEVICE); val_B = val_B.to(DEVICE)
            with autocast(DEVICE):
                fake_B_val   = gen_A(val_A);  cycled_A_val = gen_B(fake_B_val)
                fake_A_val   = gen_B(val_B);  cycled_B_val = gen_A(fake_A_val)
            # map [-1, 1] back to [0, 1] before writing images
            val_A_den = (val_A * 0.5 + 0.5).clamp(0, 1)
            val_B_den = (val_B * 0.5 + 0.5).clamp(0, 1)
            save_image(val_A_den,          os.path.join(SAMPLE_IMAGES_DIR, f"real_A_he_sample.png"))
            save_image(val_B_den,          os.path.join(SAMPLE_IMAGES_DIR, f"real_B_ihc_sample.png"))
            save_image((fake_B_val*0.5+0.5).clamp(0,1), os.path.join(SAMPLE_IMAGES_DIR, f"generated_B_ihc_epoch_{epoch+1}.png"))
            save_image((cycled_A_val*0.5+0.5).clamp(0,1), os.path.join(SAMPLE_IMAGES_DIR, f"cycled_A_he_epoch_{epoch+1}.png"))
            save_image((fake_A_val*0.5+0.5).clamp(0,1), os.path.join(SAMPLE_IMAGES_DIR, f"generated_A_he_epoch_{epoch+1}.png"))
            save_image((cycled_B_val*0.5+0.5).clamp(0,1), os.path.join(SAMPLE_IMAGES_DIR, f"cycled_B_ihc_epoch_{epoch+1}.png"))
    gen_A.train(); gen_B.train()

print("\nCycleGAN training complete!")


In [None]:
# ===== Cell H1: metrics libs + imports =====
!pip -q install torch-fidelity==0.3.0 lpips==0.1.4 piq==0.8.0

import os, json, random
from pathlib import Path
from PIL import Image

import torch
import torchvision.transforms as T
from torchvision.utils import save_image
import pandas as pd
import numpy as np

from torch_fidelity import calculate_metrics  # FID/KID
import lpips                                   # LPIPS
import piq                                     # SSIM/PSNR

print("Ready on:", DEVICE)


In [None]:
# ===== Cell H3: Eval config & pairing =====
# Use the same val split as vanilla
VAL_DIR_A, VAL_DIR_B = (
    '/content/drive/MyDrive/HER2/TrainValAB/valA',
    '/content/drive/MyDrive/HER2/TrainValAB/valB',
)

EVAL_TAG = "hybrid_epoch_200"
OUT_ROOT = "/content/drive/MyDrive/HER2/hybrid_eval"
# GEN_DIR receives the generated A->B images, LOG_DIR the metric files.
GEN_DIR, LOG_DIR = (
    Path(OUT_ROOT) / sub
    for sub in (f"eval_{EVAL_TAG}_A2B", f"metrics_{EVAL_TAG}")
)
for out_dir in (GEN_DIR, LOG_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 256  # keep consistent with training

def list_images(folder):
    """Recursively collect the image files under ``folder``, sorted by path.

    A file counts as an image when its lower-cased suffix is one of
    png / jpg / jpeg / tif / tiff / bmp.
    """
    image_suffixes = ('.png','.jpg','.jpeg','.tif','.tiff','.bmp')
    found = []
    for candidate in Path(folder).rglob('*'):
        if candidate.suffix.lower() in image_suffixes:
            found.append(candidate)
    found.sort()
    return found

paths_A = list_images(VAL_DIR_A)
paths_B = list_images(VAL_DIR_B)

# Index B by basename, then keep the A images that have a same-named B partner.
bname_to_B = {}
for p in paths_B:
    bname_to_B[Path(p).stem] = p

pairs = []
for pA in paths_A:
    stem = Path(pA).stem
    if stem in bname_to_B:
        pairs.append((str(pA), bname_to_B[stem]))

print(f"A images: {len(paths_A)} | B images: {len(paths_B)} | Paired matches: {len(pairs)}")

# Transforms
# to_tensor: PIL -> normalized [-1,1] tensor at IMG_SIZE; to_pil is its inverse.
to_tensor = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize((0.5,)*3, (0.5,)*3),
    ]
)
to_pil = T.Compose(
    [
        T.Lambda(lambda x: (x * 0.5 + 0.5).clamp(0,1)),
        T.ToPILImage(),
    ]
)

# For LPIPS/SSIM/PSNR (both sides same size)
resize_01 = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
        T.ToTensor(),  # [0,1]
    ]
)


In [None]:
# ===== Cell H4: Generate A->B for all valA =====
# Outputs already present in GEN_DIR are left untouched, so the cell can be re-run.
with torch.no_grad():
    for pA in paths_A:
        out_path = GEN_DIR / f"{Path(pA).stem}.png"
        if not out_path.exists():
            imgA = Image.open(pA).convert('RGB')
            x = to_tensor(imgA).unsqueeze(0).to(DEVICE)
            y = gen_A_h(x)[0].cpu()            # [-1,1]
            to_pil(y).save(out_path)

n_saved = sum(1 for _ in GEN_DIR.glob('*.png'))
print(f"Saved {n_saved} generated images to {GEN_DIR}")


In [None]:
# ===== Cell H5: FID & KID =====
# KID is estimated on random subsets of `kid_subset_size` samples (library default
# 1000) and torch-fidelity refuses inputs with fewer images than that, so the
# subset size is capped by the smaller of the two image sets.
# NOTE(review): the count below assumes torch-fidelity's default file discovery
# (non-recursive, png/jpg/jpeg) — confirm if those options are ever overridden.
_fid_exts = (".png", ".jpg", ".jpeg")
n_generated = sum(1 for p in Path(GEN_DIR).iterdir() if p.suffix.lower() in _fid_exts)
n_reference = sum(1 for p in Path(VAL_DIR_B).iterdir() if p.suffix.lower() in _fid_exts)
kid_subset_size = max(1, min(1000, n_generated, n_reference))

fidkid = calculate_metrics(
    input1=str(GEN_DIR),
    input2=str(VAL_DIR_B),
    cuda=torch.cuda.is_available(),
    isc=False, fid=True, kid=True, prc=False, verbose=False,
    kid_subset_size=kid_subset_size,
)
with open(LOG_DIR / "fid_kid.json", "w") as f:
    json.dump(fidkid, f, indent=2)
print(json.dumps(fidkid, indent=2))


In [None]:
# ===== Cell H6: LPIPS / SSIM / PSNR on paired matches =====
lpips_fn = lpips.LPIPS(net='vgg').to(DEVICE).eval()

rows = []
with torch.no_grad():
    for pA, pB in pairs:
        gen_path = GEN_DIR / (Path(pA).stem + ".png")
        if not gen_path.exists():
            continue

        # Resize both sides to the same eval size; the context manager closes
        # the image files as soon as they have been decoded.
        with Image.open(gen_path) as gen_img, Image.open(pB) as ref_img:
            G01 = resize_01(gen_img.convert('RGB')).to(DEVICE).unsqueeze(0)
            B01 = resize_01(ref_img.convert('RGB')).to(DEVICE).unsqueeze(0)

        # LPIPS expects [-1,1]
        Gm1p1, Bm1p1 = G01*2-1, B01*2-1
        lp = lpips_fn(Gm1p1, Bm1p1).item()

        ssim = piq.ssim(G01, B01, data_range=1.0).item()
        psnr = piq.psnr(G01, B01, data_range=1.0).item()

        rows.append({"basename": Path(pA).stem, "lpips": lp, "ssim": ssim, "psnr": psnr})

# An empty frame has no "basename" column, so sort_values would fail with an
# opaque KeyError; report the actual problem instead.
if not rows:
    raise RuntimeError(
        f"No generated/reference pairs to score: check that {GEN_DIR} contains the "
        "generated images and that valA/valB file names match."
    )

df = pd.DataFrame(rows).sort_values("basename")
df.to_csv(LOG_DIR / "paired_metrics.csv", index=False)

# Population statistics (ddof=0) over the scored pairs.
summary = {
    "tag": EVAL_TAG,
    "N_pairs": int(len(df)),
    "LPIPS_mean": float(df.lpips.mean()),
    "LPIPS_std": float(df.lpips.std(ddof=0)),
    "SSIM_mean": float(df.ssim.mean()),
    "SSIM_std": float(df.ssim.std(ddof=0)),
    "PSNR_mean": float(df.psnr.mean()),
    "PSNR_std": float(df.psnr.std(ddof=0)),
}
with open(LOG_DIR / "paired_metrics_summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print(json.dumps(summary, indent=2))


In [None]:
# ===== Cell H7: Triplet grid for the thesis =====
import matplotlib.pyplot as plt
random.seed(42)

N = 12  # maximum number of triplets

# Only pairs whose generated image exists can be drawn.
sample_pairs = [
    pair for pair in pairs
    if (GEN_DIR / (Path(pair[0]).stem + ".png")).exists()
]
random.shuffle(sample_pairs)
sample_pairs = sample_pairs[:N]
if not sample_pairs:
    raise RuntimeError(f"No A/B pairs with a generated image in {GEN_DIR}; nothing to plot.")

# One row per available triplet. squeeze=False keeps `axes` two-dimensional
# even when there is a single row, so axes[i, j] indexing always works.
ncols, nrows = 3, len(sample_pairs)
fig, axes = plt.subplots(nrows, ncols, figsize=(9, 3*nrows), dpi=150, squeeze=False)

for i, (pA, pB) in enumerate(sample_pairs):
    gen = GEN_DIR / (Path(pA).stem + ".png")
    A = Image.open(pA).convert('RGB')
    B = Image.open(pB).convert('RGB')
    G = Image.open(gen).convert('RGB')

    axes[i,0].imshow(A); axes[i,0].set_title("H&E (A)"); axes[i,0].axis('off')
    axes[i,1].imshow(B); axes[i,1].set_title("IHC – Expected (B)"); axes[i,1].axis('off')
    axes[i,2].imshow(G); axes[i,2].set_title("IHC – Generated (A→B)"); axes[i,2].axis('off')

plt.tight_layout()
FIG_PATH = Path(OUT_ROOT) / f"eval_{EVAL_TAG}_A-B-G_triplets.png"
plt.savefig(FIG_PATH, bbox_inches='tight')
print("Saved figure:", FIG_PATH)


# Plain CycleGAN

In [None]:
# ===== Cell 1: Imports and Global Configuration =====
import os, random, math
from pathlib import Path
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from PIL import Image

# (Optional) Pretty loss tables
try:
    import pandas as pd
except Exception:
    pd = None

# --- Mount Google Drive (Colab) ---
from google.colab import drive
drive.mount('/content/drive')

# --- Configuration ---
@dataclass
class Config:
    """Settings for the vanilla CycleGAN run.

    Groups the dataset folders, the experiment output root, the training
    schedule, the Adam hyper-parameters and the loss weights. Every field has
    a default, so ``Config()`` yields the configuration used by this notebook.
    """
    # Your dataset folders (pix2pix-style)
    TRAIN_DIR_A: str = '/content/drive/MyDrive/HER2/TrainValAB/trainA'
    TRAIN_DIR_B: str = '/content/drive/MyDrive/HER2/TrainValAB/trainB'
    VAL_DIR_A:   str = '/content/drive/MyDrive/HER2/TrainValAB/valA'
    VAL_DIR_B:   str = '/content/drive/MyDrive/HER2/TrainValAB/valB'

    # Experiment output (new folder for vanilla CycleGAN)
    EXP_ROOT: str = '/content/drive/MyDrive/HER2/cyclegan_vanilla'
    SAVE_EVERY: int = 10      # checkpoint cadence
    SAMPLE_EVERY: int = 10    # visualization cadence

    # Training
    EPOCHS: int = 200
    BATCH_SIZE: int = 1
    IMG_SIZE: int = 256
    SEED: int = 42
    # Evaluated once, when the class body runs: CUDA first, then MPS, else CPU.
    DEVICE: str = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')

    # Optimizer (vanilla CycleGAN uses LSGAN; these are standard)
    LR: float = 2e-4
    BETA1: float = 0.5
    BETA2: float = 0.999
    LR_DECAY_START: int = 100  # start linear decay here

    # Loss weights
    LAMBDA_CYCLE: float = 10.0
    LAMBDA_ID: float = 5.0     # 0.5 * LAMBDA_CYCLE

cfg = Config()

# Create output dirs
exp_dir = Path(cfg.EXP_ROOT)
ckpt_dir, samples_dir, logs_dir = (
    exp_dir / sub for sub in ('checkpoints', 'samples', 'logs')
)
for d in (exp_dir, ckpt_dir, samples_dir, logs_dir):
    d.mkdir(parents=True, exist_ok=True)

# Reproducibility: seed Python's RNG and torch (plus every CUDA device, if any)
random.seed(cfg.SEED)
torch.manual_seed(cfg.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.SEED)

# Transforms: bicubic resize to IMG_SIZE, then scale pixel values to [-1, 1]
transform = transforms.Compose(
    [
        transforms.Resize(
            (cfg.IMG_SIZE, cfg.IMG_SIZE),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

print("Config:", asdict(cfg))
print("Outputs ->", str(exp_dir))
print("Device:", cfg.DEVICE)


In [None]:
# ===== Cell 2: Dataset Class and DataLoader =====
class UnpairedAB(Dataset):
    """
    Image dataset over two folders (domain A and domain B).

    Each item is ``(tensor_A, tensor_B, name_A, name_B)``. The B image is the
    one whose basename equals the A image's basename when such a file exists;
    otherwise a random B image is drawn. Folders are scanned recursively.

    NOTE: with ``augment=True`` the random flips are sampled independently for
    A and B, so a basename-matched pair is not guaranteed to share a flip.
    """
    def __init__(self, dir_A, dir_B, augment=False):
        """
        Args:
            dir_A: folder holding domain-A images.
            dir_B: folder holding domain-B images.
            augment: add random horizontal/vertical flips when True.

        Raises:
            RuntimeError: if either folder contains no supported image.
        """
        self.dir_A = Path(dir_A)
        self.dir_B = Path(dir_B)
        exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')
        self.paths_A = sorted([p for p in self.dir_A.rglob('*') if p.suffix.lower() in exts])
        self.paths_B = sorted([p for p in self.dir_B.rglob('*') if p.suffix.lower() in exts])
        if len(self.paths_A) == 0 or len(self.paths_B) == 0:
            raise RuntimeError(f"No images found in:\n{self.dir_A}\n{self.dir_B}")

        # Basename -> B path, built once so __getitem__ does not rescan paths_B
        # for every sample. setdefault keeps the first match in sorted order.
        self._b_by_stem = {}
        for p in self.paths_B:
            self._b_by_stem.setdefault(p.stem, p)

        # Basic aug (optional)
        t = [transforms.Resize((cfg.IMG_SIZE, cfg.IMG_SIZE), interpolation=transforms.InterpolationMode.BICUBIC)]
        if augment:
            t.append(transforms.RandomHorizontalFlip())
            t.append(transforms.RandomVerticalFlip())
        t += [transforms.ToTensor(), transforms.Normalize((0.5,)*3, (0.5,)*3)]
        self.tf = transforms.Compose(t)

    def __len__(self):
        # The larger domain sets the epoch length; A indices wrap around.
        return max(len(self.paths_A), len(self.paths_B))

    def __getitem__(self, idx):
        path_A = self.paths_A[idx % len(self.paths_A)]
        # Prefer the B image with the same basename; else pick one at random.
        path_B = self._b_by_stem.get(path_A.stem)
        if path_B is None:
            path_B = self.paths_B[random.randint(0, len(self.paths_B)-1)]

        # convert() returns a loaded copy, so the files can be closed right away.
        with Image.open(path_A) as raw_A:
            img_A = raw_A.convert('RGB')
        with Image.open(path_B) as raw_B:
            img_B = raw_B.convert('RGB')
        return self.tf(img_A), self.tf(img_B), path_A.name, path_B.name

# DataLoaders
train_ds = UnpairedAB(cfg.TRAIN_DIR_A, cfg.TRAIN_DIR_B, augment=True)
val_ds   = UnpairedAB(cfg.VAL_DIR_A,   cfg.VAL_DIR_B,   augment=False)

# Both loaders shuffle and share worker settings; only the batch size differs.
loader_opts = dict(shuffle=True, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, batch_size=cfg.BATCH_SIZE, **loader_opts)
val_loader   = DataLoader(val_ds, batch_size=1, **loader_opts)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

# Helpers for visualization
inv_norm = transforms.Normalize(mean=[-1,-1,-1], std=[2,2,2])
def denorm(x):  # [-1,1] -> [0,1]
    """Rescale a tensor from [-1, 1] to [0, 1], clamping anything outside."""
    rescaled = x * 0.5 + 0.5
    return rescaled.clamp(0,1)


In [None]:
# ===== Cell 3: Model Definitions (Generators & Discriminators) =====
class ResnetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 conv + InstanceNorm stages
    (ReLU in between) whose output is added back to the input."""
    def __init__(self, dim):
        super().__init__()

        def conv_stage():
            # pad -> 3x3 conv -> instance norm; channel count stays at `dim`
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, 3, padding=0),
                nn.InstanceNorm2d(dim, affine=False),
            ]

        layers = conv_stage() + [nn.ReLU(inplace=True)] + conv_stage()
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class ResnetGenerator(nn.Module):
    """ResNet-style generator: 7x7 stem, two stride-2 downsampling convs,
    ``n_blocks`` residual blocks, two transposed-conv upsampling stages and a
    7x7 output conv followed by tanh."""
    def __init__(self, in_c=3, out_c=3, n_filters=64, n_blocks=9):
        super().__init__()
        # Stem
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_c, n_filters, 7, padding=0),
            nn.InstanceNorm2d(n_filters, affine=False),
            nn.ReLU(inplace=True),
        ]
        width = n_filters

        # Two downsampling stages, each doubling the channel count
        for _ in range(2):
            layers.extend([
                nn.Conv2d(width, width*2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(width*2, affine=False),
                nn.ReLU(inplace=True),
            ])
            width *= 2

        # Residual trunk at the bottleneck width
        layers.extend(ResnetBlock(width) for _ in range(n_blocks))

        # Two upsampling stages, each halving the channel count
        for _ in range(2):
            layers.extend([
                nn.ConvTranspose2d(width, width//2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(width//2, affine=False),
                nn.ReLU(inplace=True),
            ])
            width //= 2

        # Output head
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(width, out_c, 7, padding=0),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class PatchDiscriminator(nn.Module):
    """Patch discriminator: three stride-2 4x4 conv stages (InstanceNorm on all
    but the first) followed by a 1-channel 4x4 conv that emits a score map."""
    def __init__(self, in_c=3, n_filters=64):
        super().__init__()
        # (input channels, output channels, use InstanceNorm)
        plan = [
            (in_c, n_filters, False),
            (n_filters, n_filters*2, True),
            (n_filters*2, n_filters*4, True),
        ]
        stages = []
        for src, dst, use_norm in plan:
            stages.append(nn.Conv2d(src, dst, 4, stride=2, padding=1))
            if use_norm:
                stages.append(nn.InstanceNorm2d(dst, affine=False))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        stages.append(nn.Conv2d(n_filters*4, 1, 4, stride=1, padding=1))  # Patch score
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)

# Instantiate models (built in the order G_A2B, G_B2A, D_A, D_B)
G_A2B, G_B2A = (ResnetGenerator().to(cfg.DEVICE) for _ in range(2))
D_A, D_B     = (PatchDiscriminator().to(cfg.DEVICE) for _ in range(2))

# Optimizers: one Adam for both generators, one per discriminator
adam_kwargs = dict(lr=cfg.LR, betas=(cfg.BETA1, cfg.BETA2))
opt_G   = torch.optim.Adam([*G_A2B.parameters(), *G_B2A.parameters()], **adam_kwargs)
opt_D_A = torch.optim.Adam(D_A.parameters(), **adam_kwargs)
opt_D_B = torch.optim.Adam(D_B.parameters(), **adam_kwargs)

# Schedulers: linear decay after LR_DECAY_START
def lambda_rule(epoch):
    """LR multiplier: 1.0 before LR_DECAY_START, then linear decay over the
    remaining epochs (the span is clamped to at least one epoch)."""
    epochs_into_decay = epoch - cfg.LR_DECAY_START
    if epochs_into_decay < 0:
        return 1.0
    decay_span = float(max(1, cfg.EPOCHS - cfg.LR_DECAY_START))
    return 1.0 - epochs_into_decay / decay_span

sch_G   = torch.optim.lr_scheduler.LambdaLR(opt_G,   lr_lambda=lambda_rule)
sch_D_A = torch.optim.lr_scheduler.LambdaLR(opt_D_A, lr_lambda=lambda_rule)
sch_D_B = torch.optim.lr_scheduler.LambdaLR(opt_D_B, lr_lambda=lambda_rule)

# Losses (LSGAN + L1)
adv_criterion, recon_criterion = nn.MSELoss(), nn.L1Loss()

# Image pools (stabilize D)
class ImagePool:
    """History buffer of generated images used when training a discriminator.

    Until the pool is full every queried image is stored and returned as is.
    Afterwards each image is, with probability 0.5, swapped with a randomly
    chosen stored image (the stored one is returned); otherwise it is returned
    unchanged. A ``pool_size`` of 0 disables the pool.
    """
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        """Return a batch the same size as ``images``, possibly mixing in history."""
        if self.pool_size == 0:
            return images
        returned = []
        for candidate in images:
            candidate = candidate.detach()
            if len(self.images) < self.pool_size:
                # Still filling up: remember the image and hand it straight back.
                self.images.append(candidate)
                returned.append(candidate)
                continue
            if random.random() > 0.5:
                # Swap: return an older image and store the new one in its slot.
                slot = random.randint(0, self.pool_size - 1)
                returned.append(self.images[slot].clone())
                self.images[slot] = candidate
            else:
                returned.append(candidate)
        return torch.stack(returned, dim=0)

pool_A, pool_B = ImagePool(50), ImagePool(50)

# Resume if latest exists
start_epoch = 1
latest = ckpt_dir / 'latest.pt'
if latest.exists():
    print("Resuming from", latest)
    s = torch.load(latest, map_location=cfg.DEVICE)
    # Checkpoint key -> object whose state is restored from it.
    restorable = {
        'G_A2B': G_A2B, 'G_B2A': G_B2A,
        'D_A': D_A, 'D_B': D_B,
        'opt_G': opt_G, 'opt_D_A': opt_D_A, 'opt_D_B': opt_D_B,
        'sch_G': sch_G, 'sch_D_A': sch_D_A, 'sch_D_B': sch_D_B,
    }
    for state_key, target in restorable.items():
        target.load_state_dict(s[state_key])
    # The checkpoint stores the last finished epoch; continue with the next one.
    start_epoch = s['epoch'] + 1
    print("Start epoch:", start_epoch)


In [None]:
# ===== Cell 4 (REPLACED): Training Loop for CycleGAN with tqdm progress bar =====
from tqdm import tqdm
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

def save_checkpoint(epoch):
    """Write the full training state for ``epoch`` to ``ckpt_dir``.

    The same object is saved twice: as ``epoch_XXX.pt`` and as ``latest.pt``.
    It holds the epoch number, the state dicts of both generators, both
    discriminators, the three optimizers and the three schedulers, and the
    config as a plain dict.
    """
    stateful = {
        'G_A2B': G_A2B,
        'G_B2A': G_B2A,
        'D_A': D_A,
        'D_B': D_B,
        'opt_G': opt_G,
        'opt_D_A': opt_D_A,
        'opt_D_B': opt_D_B,
        'sch_G': sch_G,
        'sch_D_A': sch_D_A,
        'sch_D_B': sch_D_B,
    }
    obj = {'epoch': epoch}
    for key, component in stateful.items():
        obj[key] = component.state_dict()
    obj['cfg'] = asdict(cfg)

    path = ckpt_dir / f'epoch_{epoch:03d}.pt'
    for destination in (path, ckpt_dir / 'latest.pt'):
        torch.save(obj, destination)
    print(f'[Checkpoint] Saved {path}')

@torch.no_grad()
def visualize_epoch(epoch):
    """Save and display an A / B / generated-B grid for one validation batch.

    Switches both generators to eval mode, takes one batch from ``val_loader``,
    translates A→B, writes the grid to ``samples_dir`` as ``epoch_XXX.png`` and
    shows it with matplotlib. Returns without output if the loader is empty.
    """
    G_A2B.eval(); G_B2A.eval()
    batch = next(iter(val_loader), None)
    if batch is None:
        return
    img_A, img_B, name_A, name_B = batch
    img_A, img_B = img_A.to(cfg.DEVICE), img_B.to(cfg.DEVICE)
    fake_B = G_A2B(img_A)

    # The pair counts as matched when the text before the first '.' agrees.
    if name_A[0].split('.')[0] == name_B[0].split('.')[0]:
        title_B = "Expected (paired B)"
    else:
        title_B = "Reference (real B, unpaired)"

    panels = [denorm(img_A), denorm(img_B), denorm(fake_B)]
    grid = make_grid(torch.cat(panels, dim=0), nrow=img_A.size(0))
    out_path = samples_dir / f'epoch_{epoch:03d}.png'
    save_image(grid, out_path)
    print(f"[Sample] Saved {out_path}")

    plt.figure(figsize=(12,4))
    plt.imshow(grid.permute(1,2,0).cpu().numpy())
    plt.axis('off')
    plt.title(f"Epoch {epoch} — Left: Actual (A), Middle: {title_B}, Right: Generated (A→B)")
    plt.show()

class LossBook:
    """Collects one row of float-valued loss terms per epoch."""
    def __init__(self):
        self.rows = []

    def add(self, epoch, **kw):
        """Append a row for ``epoch``; every keyword value is coerced to float."""
        self.rows.append({'epoch': epoch, **{name: float(value) for name, value in kw.items()}})

    def table(self):
        """Return the rows as a DataFrame, or None when pandas is unavailable."""
        return None if pd is None else pd.DataFrame(self.rows)

# Per-epoch loss means are collected here and written to CSV after training.
# NOTE(review): the book starts empty on every run, so after a resume the CSV
# written at the end only contains the epochs of the current run.
lossbook = LossBook()

# NOTE(review): these two labels are not referenced in the loop below; the
# targets are built with ones_like / zeros_like instead.
real_label = 1.0
fake_label = 0.0

# Epochs are 1-based: start_epoch is 1 for a fresh run, or last saved epoch + 1.
for epoch in range(start_epoch, cfg.EPOCHS + 1):
    G_A2B.train(); G_B2A.train(); D_A.train(); D_B.train()

    # Running sums of every loss term; divided by the batch count at epoch end.
    sums = {k:0.0 for k in [
        'G_total','G_adv_A2B','G_adv_B2A','cycle_A','cycle_B','id_A','id_B','D_A','D_B'
    ]}
    nb = 0

    # tqdm progress bar over train batches
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.EPOCHS}", ncols=100, leave=False)
    for imgs_A, imgs_B, _, _ in pbar:
        imgs_A = imgs_A.to(cfg.DEVICE)
        imgs_B = imgs_B.to(cfg.DEVICE)

        # === Generators ===
        opt_G.zero_grad()

        # Identity: G_A2B(B) ≈ B, G_B2A(A) ≈ A
        id_B = G_A2B(imgs_B)
        loss_id_B = recon_criterion(id_B, imgs_B) * cfg.LAMBDA_ID
        id_A = G_B2A(imgs_A)
        loss_id_A = recon_criterion(id_A, imgs_A) * cfg.LAMBDA_ID

        # GAN: A->B (the generator wants D_B to score its fakes as 1)
        fake_B = G_A2B(imgs_A)
        pred_fake_B = D_B(fake_B)
        valid_B = torch.ones_like(pred_fake_B, device=cfg.DEVICE)
        loss_G_A2B = adv_criterion(pred_fake_B, valid_B)

        # GAN: B->A
        fake_A = G_B2A(imgs_B)
        pred_fake_A = D_A(fake_A)
        valid_A = torch.ones_like(pred_fake_A, device=cfg.DEVICE)
        loss_G_B2A = adv_criterion(pred_fake_A, valid_A)

        # Cycle: A -> B -> A and B -> A -> B should reproduce the inputs
        rec_A = G_B2A(fake_B)
        rec_B = G_A2B(fake_A)
        loss_cyc_A = recon_criterion(rec_A, imgs_A) * cfg.LAMBDA_CYCLE
        loss_cyc_B = recon_criterion(rec_B, imgs_B) * cfg.LAMBDA_CYCLE

        loss_G = loss_G_A2B + loss_G_B2A + loss_cyc_A + loss_cyc_B + loss_id_A + loss_id_B
        loss_G.backward()
        opt_G.step()

        # === D_A ===
        # Real images are pushed towards 1, pooled fakes towards 0.
        opt_D_A.zero_grad()
        pred_real_A = D_A(imgs_A)
        valid = torch.ones_like(pred_real_A, device=cfg.DEVICE)
        loss_D_A_real = adv_criterion(pred_real_A, valid)

        # Fakes come from the image pool (which detaches them), so no gradient
        # reaches the generators from this step.
        fake_A_pool = pool_A.query(fake_A)
        pred_fake_A = D_A(fake_A_pool.detach())
        fake = torch.zeros_like(pred_fake_A, device=cfg.DEVICE)
        loss_D_A_fake = adv_criterion(pred_fake_A, fake)
        loss_DA = 0.5*(loss_D_A_real + loss_D_A_fake)
        loss_DA.backward()
        opt_D_A.step()

        # === D_B ===
        opt_D_B.zero_grad()
        pred_real_B = D_B(imgs_B)
        valid = torch.ones_like(pred_real_B, device=cfg.DEVICE)
        loss_D_B_real = adv_criterion(pred_real_B, valid)

        fake_B_pool = pool_B.query(fake_B)
        pred_fake_B = D_B(fake_B_pool.detach())
        fake = torch.zeros_like(pred_fake_B, device=cfg.DEVICE)
        loss_D_B_fake = adv_criterion(pred_fake_B, fake)
        loss_DB = 0.5*(loss_D_B_real + loss_D_B_fake)
        loss_DB.backward()
        opt_D_B.step()

        # Accumulate + update progress bar postfix
        # (cycle and identity terms are stored already multiplied by their weights)
        sums['G_total']    += loss_G.item()
        sums['G_adv_A2B']  += loss_G_A2B.item()
        sums['G_adv_B2A']  += loss_G_B2A.item()
        sums['cycle_A']    += loss_cyc_A.item()
        sums['cycle_B']    += loss_cyc_B.item()
        sums['id_A']       += loss_id_A.item()
        sums['id_B']       += loss_id_B.item()
        sums['D_A']        += loss_DA.item()
        sums['D_B']        += loss_DB.item()
        nb += 1

        # The postfix shows running means over the batches seen so far.
        pbar.set_postfix({
            "G": f"{sums['G_total']/nb:.3f}",
            "D_A": f"{sums['D_A']/nb:.3f}",
            "D_B": f"{sums['D_B']/nb:.3f}",
        })

    # Step schedulers once per epoch
    sch_G.step(); sch_D_A.step(); sch_D_B.step()

    # max(1, nb) avoids a division by zero when the loader yielded no batch.
    means = {k: v/max(1,nb) for k,v in sums.items()}
    lossbook.add(epoch, **means)

    # End-of-epoch summary line
    print(f"Epoch {epoch:03d}/{cfg.EPOCHS} | "
          f"G:{means['G_total']:.4f} | D_A:{means['D_A']:.4f} D_B:{means['D_B']:.4f} | "
          f"cycA:{means['cycle_A']:.3f} cycB:{means['cycle_B']:.3f} | idA:{means['id_A']:.3f} idB:{means['id_B']:.3f}")

    # Sample + table every SAMPLE_EVERY
    if epoch % cfg.SAMPLE_EVERY == 0:
        visualize_epoch(epoch)
        if pd is not None:
            df = lossbook.table()
            print(df.to_string(index=False))
        else:
            print("[Info] Install pandas for a formatted loss table.")

    # Checkpoint every SAVE_EVERY (and on final epoch)
    if epoch % cfg.SAVE_EVERY == 0 or epoch == cfg.EPOCHS:
        save_checkpoint(epoch)

# Save final CSV and show final sample
if pd is not None:
    df = lossbook.table()
    csv_path = logs_dir / 'epoch_losses.csv'
    df.to_csv(csv_path, index=False)
    print(f"[Log] Saved loss CSV to {csv_path}")

visualize_epoch(epoch=cfg.EPOCHS)


In [None]:
# ===== Cell 5: Install metrics libs (run once per fresh Colab) =====
!pip -q install torch-fidelity==0.3.0 lpips==0.1.4 piq==0.8.0


In [None]:
from torch_fidelity import calculate_metrics

In [None]:
# ===== Cell 6: Eval config, pairing, helpers =====
import os, glob, random, json
from pathlib import Path
from PIL import Image
import torch
import torchvision.transforms as T
import pandas as pd
import numpy as np

from torch_fidelity import calculate_metrics  # torch-fidelity
import lpips  # perceptual
import piq    # SSIM/PSNR

# --- Choose the split you want to evaluate (val or test) ---
# EVAL_DIR_A holds the source images that get translated by G_A2B;
# EVAL_DIR_B holds the reference images the translations are compared against.
EVAL_DIR_A = cfg.VAL_DIR_A  # '/content/drive/MyDrive/HER2/TrainValAB/valA'
EVAL_DIR_B = cfg.VAL_DIR_B  # '/content/drive/MyDrive/HER2/TrainValAB/valB'

# --- Where to dump generated images & results for this checkpoint ---
# EVAL_TAG doubles as the checkpoint file stem ("<EVAL_TAG>.pt") and as part of
# the output folder names, so every evaluated checkpoint gets its own folders.
EVAL_TAG       = "epoch_150"  # <-- change to the checkpoint you want to evaluate
GEN_OUT_DIR    = samples_dir / f"eval_{EVAL_TAG}_A2B"
METRICS_OUTDIR = logs_dir / f"metrics_{EVAL_TAG}"
# Existing folders are reused as-is (nothing is deleted), so files left over
# from an earlier run with the same tag stay in place.
GEN_OUT_DIR.mkdir(parents=True, exist_ok=True)
METRICS_OUTDIR.mkdir(parents=True, exist_ok=True)

# Input pipeline: PIL image -> bicubic resize to the network resolution ->
# tensor -> per-channel (x - 0.5) / 0.5 normalisation.
_eval_hw = (cfg.IMG_SIZE, cfg.IMG_SIZE)
_half_rgb = (0.5, 0.5, 0.5)
to_tensor = T.Compose([
    T.Resize(_eval_hw, interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=_half_rgb, std=_half_rgb),
])


def _denormalize(x):
    """Undo the 0.5/0.5 normalisation and clip the result to [0,1]."""
    return (x * 0.5 + 0.5).clamp(0, 1)


# Output pipeline: network output tensor -> [0,1] -> PIL image for saving.
to_pil = T.Compose([T.Lambda(_denormalize), T.ToPILImage()])

# Pair files by basename (only for LPIPS/SSIM/PSNR)
def list_images(folder):
    """Return every image file under *folder* (searched recursively), sorted by path.

    Args:
        folder: Directory to scan; anything ``pathlib.Path`` accepts.

    Returns:
        A sorted list of ``Path`` objects whose suffix, compared
        case-insensitively, is one of png/jpg/jpeg/tif/tiff/bmp.
        Directories are never returned, even if their name ends in an
        image extension.
    """
    exts = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'}
    return sorted(
        p for p in Path(folder).rglob('*')
        # check the cheap suffix first; is_file() touches the filesystem
        if p.suffix.lower() in exts and p.is_file()
    )

# Collect both sides of the evaluation split.
paths_A = list_images(EVAL_DIR_A)
paths_B = list_images(EVAL_DIR_B)

# Index the B images by filename stem. If two B files share a stem, the one
# that comes later in paths_B wins.
bname_to_B = {}
for b_path in paths_B:
    bname_to_B[Path(b_path).stem] = b_path

# An A image is paired with the B image of the same stem; A images without a
# counterpart are dropped. The A side is stored as str, the B side as the
# object returned by list_images.
pairs = [
    (str(a_path), bname_to_B[Path(a_path).stem])
    for a_path in paths_A
    if Path(a_path).stem in bname_to_B
]

print(f"Found {len(paths_A)} A images, {len(paths_B)} B images, and {len(pairs)} paired matches by filename.")


In [None]:
# ===== Cell 7: Load checkpoint & export A→B translations =====
# Restore both generators from the checkpoint selected by EVAL_TAG and switch
# them to eval mode.
ckpt_path = ckpt_dir / f"{EVAL_TAG}.pt"
ckpt = torch.load(ckpt_path, map_location=cfg.DEVICE)
G_A2B.load_state_dict(ckpt['G_A2B'])
G_B2A.load_state_dict(ckpt['G_B2A'])
G_A2B.eval()
G_B2A.eval()

# Translate every A image of the eval split and write the result as
# "<stem>.png" into GEN_OUT_DIR; no gradients are tracked during inference.
with torch.no_grad():
    for src_path in paths_A:
        src_img = Image.open(src_path).convert('RGB')
        batch = to_tensor(src_img).unsqueeze(0).to(cfg.DEVICE)
        fake_B = G_A2B(batch)[0].cpu()
        to_pil(fake_B).save(GEN_OUT_DIR / f"{Path(src_path).stem}.png")

print(f"Saved {len(paths_A)} generated images to {GEN_OUT_DIR}")


In [None]:
# ===== Cell 8: Compute FID & KID =====
# torch-fidelity expects directories with images


def _count_fidelity_samples(folder):
    """Count the image files directly inside *folder* (png/jpg/jpeg, any case).

    NOTE(review): this is meant to mirror what torch-fidelity picks up with its
    default settings (non-recursive scan, extensions png/jpg/jpeg) - confirm
    against the installed version if the counts look off.
    """
    exts = {'.png', '.jpg', '.jpeg'}
    return sum(
        1 for p in Path(folder).iterdir()
        if p.suffix.lower() in exts and p.is_file()
    )


# KID is estimated on random subsets of `kid_subset_size` samples (library
# default: 1000). NOTE(review): torch-fidelity is understood to reject a subset
# size larger than either input set - confirm against the installed version.
# Clamp it so the cell still works on smaller eval splits.
KID_DEFAULT_SUBSET = 1000
n_generated = _count_fidelity_samples(GEN_OUT_DIR)
n_reference = _count_fidelity_samples(EVAL_DIR_B)
kid_subset_size = max(1, min(KID_DEFAULT_SUBSET, n_generated, n_reference))
if kid_subset_size < KID_DEFAULT_SUBSET:
    print(f"[Info] Only {n_generated} generated / {n_reference} reference images; "
          f"using kid_subset_size={kid_subset_size} instead of {KID_DEFAULT_SUBSET}.")

metrics = calculate_metrics(
    input1=str(GEN_OUT_DIR),
    input2=str(EVAL_DIR_B),
    cuda=torch.cuda.is_available(),
    isc=False, fid=True, kid=True, prc=False, verbose=False,
    kid_subset_size=kid_subset_size,
)
with open(METRICS_OUTDIR / "fid_kid.json", "w") as f:
    json.dump(metrics, f, indent=2)
print(json.dumps(metrics, indent=2))


In [None]:
# ===== Cell 9 (REPLACED): LPIPS, SSIM, PSNR on paired matches =====
import torch
import torchvision.transforms as T
from PIL import Image
import pandas as pd
import lpips
import piq

# set eval size = training size used by the nets
EVAL_SIZE = cfg.IMG_SIZE

# transforms
# NOTE(review): `to_unit` is not referenced by the rest of this cell.
to_unit = T.ToTensor()  # [0,1]
# Applied to both the generated and the reference image so the two tensors
# share the same spatial size before LPIPS/SSIM/PSNR are computed.
resize_01 = T.Compose([
    T.Resize((EVAL_SIZE, EVAL_SIZE), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
    T.ToTensor(),  # [0,1]
])
def to_m1p1(x01):
    """Rescale values from [0,1] to [-1,1] (0 -> -1, 0.5 -> 0, 1 -> 1)."""
    doubled = x01 * 2
    return doubled - 1

lpips_fn = lpips.LPIPS(net='vgg').to(cfg.DEVICE).eval()


def _load_unit_batch(path):
    """Load *path* as RGB, push it through ``resize_01`` and return a batch of one on cfg.DEVICE."""
    with Image.open(path) as img:
        return resize_01(img.convert('RGB')).to(cfg.DEVICE).unsqueeze(0)  # [1,3,H,W], 0..1


rows = []
n_missing = 0  # pairs skipped because no generated image exists for them
with torch.no_grad():
    for pA, pB in pairs:
        gen_path = GEN_OUT_DIR / (Path(pA).stem + ".png")
        if not gen_path.exists():
            n_missing += 1
            continue

        # Load & resize both to the SAME size
        Gimg01 = _load_unit_batch(gen_path)
        Bimg01 = _load_unit_batch(pB)

        # LPIPS expects [-1,1]
        lp = lpips_fn(to_m1p1(Gimg01), to_m1p1(Bimg01)).item()

        # SSIM / PSNR on [0,1]
        ssim_val = piq.ssim(Gimg01, Bimg01, data_range=1.0).item()
        psnr_val = piq.psnr(Gimg01, Bimg01, data_range=1.0).item()

        rows.append({
            "basename": Path(pA).stem,
            "lpips": lp,
            "ssim": ssim_val,
            "psnr": psnr_val
        })

if n_missing:
    print(f"[Warn] Skipped {n_missing} of {len(pairs)} pairs: no generated image in {GEN_OUT_DIR}")

# Without this guard an empty result surfaces as an opaque KeyError from
# sort_values("basename") on a column-less DataFrame.
if not rows:
    raise RuntimeError(
        f"No (generated, reference) pairs to score: {len(pairs)} filename matches, "
        f"{n_missing} of them without a generated PNG in {GEN_OUT_DIR}."
    )

df = pd.DataFrame(rows).sort_values("basename")
df.to_csv(METRICS_OUTDIR / "paired_metrics.csv", index=False)

# Population statistics (ddof=0) over all scored pairs.
agg = {
    "N_pairs": int(len(df)),
    "LPIPS_mean": float(df["lpips"].mean()),
    "LPIPS_std": float(df["lpips"].std(ddof=0)),
    "SSIM_mean": float(df["ssim"].mean()),
    "SSIM_std": float(df["ssim"].std(ddof=0)),
    "PSNR_mean": float(df["psnr"].mean()),
    "PSNR_std": float(df["psnr"].std(ddof=0)),
}
with open(METRICS_OUTDIR / "paired_metrics_summary.json", "w") as f:
    json.dump(agg, f, indent=2)

print(agg)


In [None]:
# ===== Cell 10: Create a grid of examples for the thesis =====
import matplotlib.pyplot as plt
import math

N = 12  # how many triplets to show (upper bound; fewer if fewer are available)


def _generated_path(path_a):
    """Location of the generated image that belongs to source image *path_a*."""
    return GEN_OUT_DIR / (Path(path_a).stem + ".png")


# Only pairs whose generated image exists on disk can be drawn.
drawable_pairs = [(a, b) for a, b in pairs if _generated_path(a).exists()]
sample_pairs = random.sample(drawable_pairs, min(N, len(drawable_pairs)))

ncols = 3
nrows = len(sample_pairs)

if nrows == 0:
    print("[Warn] No A/B pairs with a generated image found; skipping the triplet figure.")
else:
    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # axes[i] is always a row of three Axes.
    fig, axes = plt.subplots(nrows, ncols, figsize=(9, 3*nrows), dpi=150, squeeze=False)
    column_titles = ("H&E (A)", "IHC – Expected (B)", "IHC – Generated (A→B)")

    for i, (pA, pB) in enumerate(sample_pairs):
        triplet = (
            Image.open(pA).convert('RGB'),
            Image.open(pB).convert('RGB'),
            Image.open(_generated_path(pA)).convert('RGB'),
        )
        for ax, img, title in zip(axes[i], triplet, column_titles):
            ax.imshow(img)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    fig_path = samples_dir / f"eval_{EVAL_TAG}_A-B-G_triplets.png"
    plt.savefig(fig_path, bbox_inches='tight')
    print(f"Saved figure: {fig_path}")


# Results

In [None]:
# === Comparison charts: Our model, Pix2Pix, CycleGAN Hybrid, CycleGAN Vanilla (e150), Paper (CycleGAN, ASP) ===
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import math

# ----- Results table (MIST–HER2): one tuple per model -----
# Tuple layout: (model name, epoch or "paper", FID, SSIM, PSNR in dB).
# The first row is "our model"; edit it if a different run should be featured.
rows = [
    ("Our model (CycleGAN vanilla)", 110,  75.16, 0.1380, 14.55),
    ("Pix2pix (U-Net + PatchGAN)",   200, 325.08, 0.0765, 10.33),
    ("CycleGAN (hybrid)",            200,  93.30, 0.1208, 13.71),
    ("CycleGAN (vanilla)",           150,  69.24, 0.1400, 14.52),
    # Values quoted from the paper (Li et al., MICCAI'23) on MIST–HER2;
    # PSNR is not reported there, hence NaN.
    ("CycleGAN (paper, MIST–HER2)", "paper", 240.30, 0.6386, np.nan),
    ("ASP (paper, MIST–HER2)",      "paper",  51.40, 0.4881, np.nan),
]

table_columns = ["Model", "Epoch", "FID", "SSIM", "PSNR (dB)"]
df = pd.DataFrame(data=rows, columns=table_columns)

# --- Output folder & CSV
outdir = Path("charts_comparison")
outdir.mkdir(exist_ok=True)

# The Epoch column mixes ints and the string "paper"; write it as text.
csv_view = df.copy()
csv_view["Epoch"] = csv_view["Epoch"].astype(str)
csv_view.to_csv(outdir / "comparison_table.csv", index=False)

# --- Axis labels like "Model (e150)" or "(paper)"
def epoch_tag(e):
    """Format an epoch entry for use inside an axis label.

    Args:
        e: Epoch value from the table; a number, a numeric string, or free
            text such as ``"paper"``.

    Returns:
        ``"e<int>"`` when ``int(e)`` succeeds (e.g. ``150`` -> ``"e150"``),
        otherwise ``str(e)`` unchanged (e.g. ``"paper"``, NaN, ``None``).
    """
    try:
        return f"e{int(e)}"
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to the literal text; anything
        # else (e.g. KeyboardInterrupt) is allowed to propagate.
        return str(e)

def _axis_label(row):
    """Build "<Model> (<epoch tag>)" for one row of the comparison table."""
    return f"{row['Model']} ({epoch_tag(row['Epoch'])})"


labels = df.apply(_axis_label, axis=1)

# --- Colors
# Keyed by the full axis label; labels missing here fall back to grey.
palette = {
    "Our model (CycleGAN vanilla) (e110)": "#1f77b4",  # blue
    "Pix2pix (U-Net + PatchGAN) (e200)":   "#ff7f0e",  # orange
    "CycleGAN (hybrid) (e200)":            "#2ca02c",  # green
    "CycleGAN (vanilla) (e150)":           "#d62728",  # red
    "CycleGAN (paper, MIST–HER2) (paper)": "#9467bd",  # purple
    "ASP (paper, MIST–HER2) (paper)":      "#8c564b",  # brown
}
fallback_color = "#888888"
colors = []
for label_text in labels:
    colors.append(palette.get(label_text, fallback_color))

def plot_metric(metric, lower_is_better=False, title_prefix="MIST–HER2"):
    """Draw, save and show a bar chart of one metric column across all models.

    Reads the module-level ``df``, ``labels``, ``colors`` and ``outdir``.

    Args:
        metric: Name of the ``df`` column to plot.
        lower_is_better: Selects the "(lower is better)" / "(higher is
            better)" hint in the title.
        title_prefix: Leading part of the title and of the output file name.

    Non-finite values (e.g. NaN for an unreported metric) get no numeric
    annotation and are ignored when the y-axis limit is chosen.
    """
    positions = np.arange(len(labels))
    heights = df[metric].astype(float).to_numpy()
    is_finite = np.isfinite(heights)

    # Leave 25% headroom above the tallest bar for the rotated annotations.
    if is_finite.any():
        ymax = float(heights[is_finite].max()) * 1.25
    else:
        ymax = 1.0

    fig, ax = plt.subplots(figsize=(11, 5))
    ax.bar(positions, heights, color=colors)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=20, ha="right")
    ax.set_ylabel(metric)
    if lower_is_better:
        direction = "(lower is better)"
    else:
        direction = "(higher is better)"
    ax.set_title(f"{title_prefix}: {metric} by model {direction}")
    ax.set_ylim(0, ymax)

    # annotate numeric bars
    for xpos, height in zip(positions[is_finite], heights[is_finite]):
        ax.text(xpos, height, f"{height:.2f}", ha="center", va="bottom", fontsize=9, rotation=90)

    fig.tight_layout()

    # File name: spaces become underscores, parentheses are dropped from the metric.
    prefix_part = title_prefix.replace(' ', '_')
    metric_part = metric.translate(str.maketrans({' ': '_', '(': None, ')': None}))
    fname = f"{prefix_part}_{metric_part}.png"
    target = outdir / fname
    fig.savefig(target, dpi=200)
    plt.show()
    print("Saved:", target)

# --- Build the three charts (FID: lower is better; SSIM and PSNR: higher is better)
for metric_name, lower_wins in (("FID", True), ("SSIM", False), ("PSNR (dB)", False)):
    plot_metric(metric_name, lower_is_better=lower_wins)

csv_location = (outdir / "comparison_table.csv").resolve()
print("CSV saved to:", csv_location)
print("Charts saved in:", outdir.resolve())
