<a href="https://colab.research.google.com/github/saisubash1013/MSc-Project/blob/main/H%26E_to_IHC_Project_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pix2Pix

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
import numpy as np

# Run on the GPU when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: {}".format(device))

# Root folder with the paired trainA/trainB/valA/valB tile folders on Google Drive.
DATA_PATH = "/content/drive/MyDrive/HER2/TrainValAB"

In [None]:
# Cell 2a: Mount Google Drive and locate the HER2 dataset folder
import os

from google.colab import drive

# Mount Drive so the /content/drive/MyDrive paths below resolve.
drive.mount('/content/drive')

base_path = "/content/drive/MyDrive"
drive_entries = os.listdir(base_path)  # list MyDrive once and reuse it below
print("Contents of MyDrive:")
print(drive_entries)

if 'HER2' not in drive_entries:
    # Help locate the dataset by listing only the top-level folders.
    print("\nHER2 folder not found. Available folders:")
    for entry in drive_entries:
        if os.path.isdir(f"{base_path}/{entry}"):
            print(f"  {entry}/")
else:
    print("\nContents of HER2 folder:")
    print(os.listdir(f"{base_path}/HER2"))

In [None]:
# Cell 2b: Report how many entries each train/val split folder holds
DATA_PATH = "/content/drive/MyDrive/HER2/TrainValAB"

print("Contents of TrainValAB:")
print(os.listdir(DATA_PATH))

for split in ('trainA', 'trainB', 'valA', 'valB'):
    split_dir = os.path.join(DATA_PATH, split)
    if not os.path.exists(split_dir):
        print(f"{split}: NOT FOUND")
        continue
    # Counts every directory entry, not only image files.
    print(f"{split}: {len(os.listdir(split_dir))} images")

In [None]:
# Cell 3: Generator (U-Net)
class Generator(nn.Module):
    """Five-level U-Net generator for H&E -> IHC translation.

    Each encoder stage halves the spatial size with a 4x4 stride-2 conv.
    Each decoder stage doubles it with a 4x4 stride-2 transposed conv and
    concatenates the matching encoder feature map (skip connection).
    The last layer maps back to 3 channels, and tanh squashes the output
    into [-1, 1].  Input height and width must be divisible by 32 (2**5)
    so the skip tensors line up.
    """

    def __init__(self):
        super().__init__()

        # Encoder: 3 -> 64 -> 128 -> 256 -> 512 -> 512 channels.
        # The first stage has no BatchNorm.
        self.enc1 = self.conv_block(3, 64, normalize=False)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.enc5 = self.conv_block(512, 512)

        # Decoder: from dec2 on, the input width is doubled by the skip concat.
        self.dec1 = self.upconv_block(512, 512)
        self.dec2 = self.upconv_block(1024, 256)
        self.dec3 = self.upconv_block(512, 128)
        self.dec4 = self.upconv_block(256, 64)
        self.dec5 = nn.ConvTranspose2d(128, 3, 4, 2, 1)

    def conv_block(self, in_channels, out_channels, normalize=True):
        """Return a stride-2 Conv -> [BatchNorm] -> LeakyReLU(0.2) downsampling stage."""
        conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1)
        norm = [nn.BatchNorm2d(out_channels)] if normalize else []
        return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2))

    def upconv_block(self, in_channels, out_channels):
        """Return a stride-2 ConvTranspose -> BatchNorm -> ReLU upsampling stage."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder pass, keeping every stage's output: skips = [e1, ..., e5].
        skips = []
        h = x
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            h = stage(h)
            skips.append(h)

        # The bottleneck e5 feeds dec1 directly; each decoder output is then
        # concatenated with e4, e3, e2, e1 in turn.
        h = skips.pop()
        for stage in (self.dec1, self.dec2, self.dec3, self.dec4):
            h = torch.cat([stage(h), skips.pop()], dim=1)

        return torch.tanh(self.dec5(h))

# Build the generator on the selected device and report its size.
generator = Generator().to(device)
print("Generator created")
print("Generator parameters: {:,}".format(sum(param.numel() for param in generator.parameters())))

In [None]:
# Cell 4: Discriminator (PatchGAN)
class Discriminator(nn.Module):
    """Conditional PatchGAN discriminator.

    The H&E input and the (real or generated) IHC image are stacked along the
    channel axis (3 + 3 = 6 channels).  The network maps them to a grid of
    per-patch probabilities in (0, 1).  For a 256x256 input the grid is 30x30.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride, use_batchnorm) for each 4x4 conv;
        # every conv is followed by LeakyReLU(0.2).
        specs = [
            (6, 64, 2, False),
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        ]
        layers = []
        for c_in, c_out, stride, use_bn in specs:
            layers.append(nn.Conv2d(c_in, c_out, 4, stride, 1))
            if use_bn:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))

        # One-channel patch map, squashed to probabilities (paired with nn.BCELoss).
        layers += [nn.Conv2d(512, 1, 4, 1, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, input_img, target_img):
        paired = torch.cat((input_img, target_img), dim=1)
        return self.model(paired)

# Build the discriminator on the selected device and report its size.
discriminator = Discriminator().to(device)
print("Discriminator created")
print("Discriminator parameters: {:,}".format(sum(param.numel() for param in discriminator.parameters())))

In [None]:
# Cell 5: Training Setup
# Adversarial loss on the discriminator's sigmoid outputs, plus an L1
# reconstruction loss between the generated and real IHC.
criterion_GAN = nn.BCELoss()
criterion_L1 = nn.L1Loss()

# Both networks share the same Adam settings.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)

num_epochs = 10   # start with 10 epochs
lambda_L1 = 100   # weight of the L1 term in the generator loss
start_epoch = 0   # set from load_checkpoint(...) when resuming

# Checkpoints are written here once per epoch.
os.makedirs('/content/drive/MyDrive/checkpointsPix2Pix', exist_ok=True)

print("Training setup complete")
print(f"Will train for {num_epochs} epochs")
print(f"L1 loss weight: {lambda_L1}")

In [None]:
# Cell 6: Resume Training Function
def load_checkpoint(checkpoint_path):
    """Restore the generator, the discriminator and both optimizers from a checkpoint.

    Args:
        checkpoint_path: Path to a ``.pth`` file written by the training loop.
            It is a dict with ``generator_state_dict``,
            ``discriminator_state_dict``, ``optimizer_G_state_dict``,
            ``optimizer_D_state_dict`` and ``epoch``.

    Returns:
        The stored ``epoch`` value.  The training loop saves ``epoch + 1``,
        i.e. the number of completed epochs, so this is the ``start_epoch``
        to resume from.

    Side effects:
        Loads state into the module-level ``generator``, ``discriminator``,
        ``optimizer_G`` and ``optimizer_D``.
    """
    # map_location puts every tensor on the active device, so a checkpoint
    # saved on a GPU can be resumed on a CPU-only runtime (and vice versa).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
    epoch = checkpoint['epoch']
    print(f"Resumed from epoch {epoch}")
    return epoch


# Uncomment to resume training from a specific epoch:
# start_epoch = load_checkpoint('/content/drive/MyDrive/checkpointsPix2Pix/checkpoint_epoch_5.pth')

print("Resume function ready")

In [None]:
# Cell 7: Training Loop with Checkpoints
generator.train()
discriminator.train()

# train_loader (a DataLoader yielding (H&E, IHC) batches) has to be built by
# a dataset cell run beforehand.  Fail early with an actionable message
# rather than a NameError part-way into the loop.
if 'train_loader' not in globals():
    raise NameError(
        "train_loader is not defined: build the paired H&E/IHC DataLoader "
        "before running the training loop."
    )
n_batches = len(train_loader)
if n_batches == 0:
    raise ValueError("train_loader is empty: check the trainA/trainB folders.")

# Per-epoch average losses; also stored inside every checkpoint.
training_losses = {'g_losses': [], 'd_losses': []}

for epoch in range(start_epoch, num_epochs):
    epoch_g_loss = 0.0
    epoch_d_loss = 0.0

    for i, (he_images, ihc_images) in enumerate(train_loader):
        he_images = he_images.to(device)
        ihc_images = ihc_images.to(device)

        # ---- Discriminator update: real pairs -> 1, generated pairs -> 0 ----
        optimizer_D.zero_grad()

        real_output = discriminator(he_images, ihc_images)
        # Build targets from the patch map's actual shape instead of a
        # hard-coded 30x30, which only holds for 256x256 inputs.
        real_labels = torch.ones_like(real_output)
        fake_labels = torch.zeros_like(real_output)
        d_real_loss = criterion_GAN(real_output, real_labels)

        fake_ihc = generator(he_images)
        # detach(): the discriminator loss must not backpropagate into G.
        fake_output = discriminator(he_images, fake_ihc.detach())
        d_fake_loss = criterion_GAN(fake_output, fake_labels)

        d_loss = (d_real_loss + d_fake_loss) * 0.5
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator update: fool D and stay close to the real IHC (L1) ----
        optimizer_G.zero_grad()

        fake_output = discriminator(he_images, fake_ihc)
        g_gan_loss = criterion_GAN(fake_output, real_labels)
        g_l1_loss = criterion_L1(fake_ihc, ihc_images)

        g_loss = g_gan_loss + lambda_L1 * g_l1_loss
        g_loss.backward()
        optimizer_G.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i}/{n_batches}], '
                  f'D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')

    # Average losses over the epoch.
    avg_g_loss = epoch_g_loss / n_batches
    avg_d_loss = epoch_d_loss / n_batches

    training_losses['g_losses'].append(avg_g_loss)
    training_losses['d_losses'].append(avg_d_loss)

    # Save a checkpoint every epoch; 'epoch' holds the number of completed epochs.
    checkpoint = {
        'epoch': epoch + 1,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'g_loss': avg_g_loss,
        'd_loss': avg_d_loss,
        'training_losses': training_losses,
    }
    torch.save(checkpoint, f'/content/drive/MyDrive/checkpointsPix2Pix/checkpoint_epoch_{epoch+1}.pth')

    print(f'Epoch [{epoch+1}/{num_epochs}] - G Loss: {avg_g_loss:.4f}, D Loss: {avg_d_loss:.4f} - SAVED')

print("Training completed!")

In [None]:
# --- STEP 3: Full evaluation (auto-detects Pix2Pix vs Cycle U-Net), paired metrics + FID ---
import os, glob, csv, time
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.models import inception_v3, Inception_V3_Weights
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from scipy import linalg

# -------- inputs that must already exist in the notebook namespace --------
# CHECKPOINT_PATH, VALA_DIR and VALB_DIR are read from globals; the messages
# refer to them as coming from "Step 2".
assert 'CHECKPOINT_PATH' in globals() and os.path.exists(CHECKPOINT_PATH), "Checkpoint not found. Re-run Step 2."
_val_dirs_ok = (
    'VALA_DIR' in globals() and 'VALB_DIR' in globals()
    and os.path.exists(VALA_DIR) and os.path.exists(VALB_DIR)
)
assert _val_dirs_ok, "valA/valB not found. Re-run Step 2."

# Outputs go under <ROOT>/pix2pix_eval/epoch_200_autodetect, where ROOT is
# two levels above the checkpoint file.
ROOT = os.path.dirname(os.path.dirname(CHECKPOINT_PATH))  # .../HER2
EVAL_DIR = os.path.join(ROOT, "pix2pix_eval", "epoch_200_autodetect")
FAKE_DIR = os.path.join(EVAL_DIR, "fakeIHC")
TRIP_DIR = os.path.join(EVAL_DIR, "triptychs_10")
for _out_dir in (FAKE_DIR, TRIP_DIR):
    os.makedirs(_out_dir, exist_ok=True)

CSV_PATH = os.path.join(EVAL_DIR, "metrics_per_image.csv")
SUMMARY_PATH = os.path.join(EVAL_DIR, "metrics_summary.txt")

N_EVAL = 1000           # cap on the number of valA/valB pairs evaluated
SAVE_TRIPTYCHS = True   # write H&E | Real | Fake strips for the first 10 pairs
COMPUTE_FID = True      # also compute set-level FID
FID_BATCH = 32          # Inception batch size for FID features

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
print("Using checkpoint:", CHECKPOINT_PATH)

# ----- Generators -----
class Pix2PixUNet(nn.Module):
    """Inference copy of the five-level Pix2Pix U-Net generator.

    Attribute names (enc1..enc5, dec1..dec5) and the Sequential indices inside
    them mirror the training ``Generator``, so its state_dict loads directly.
    Output is in [-1, 1] (tanh).
    """

    def __init__(self):
        super().__init__()
        # Encoder; the first stage has no BatchNorm.
        self.enc1 = self._c(3, 64, False)
        self.enc2 = self._c(64, 128)
        self.enc3 = self._c(128, 256)
        self.enc4 = self._c(256, 512)
        self.enc5 = self._c(512, 512)
        # Decoder; from dec2 on, the input width is doubled by the skip concat.
        self.dec1 = self._u(512, 512)
        self.dec2 = self._u(1024, 256)
        self.dec3 = self._u(512, 128)
        self.dec4 = self._u(256, 64)
        self.dec5 = nn.ConvTranspose2d(128, 3, 4, 2, 1)

    def _c(self, i, o, norm=True):
        """Stride-2 Conv -> optional BatchNorm -> LeakyReLU(0.2)."""
        modules = [nn.Conv2d(i, o, 4, 2, 1)]
        if norm:
            modules.append(nn.BatchNorm2d(o))
        modules.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*modules)

    def _u(self, i, o):
        """Stride-2 ConvTranspose -> BatchNorm -> ReLU."""
        up = nn.ConvTranspose2d(i, o, 4, 2, 1)
        return nn.Sequential(up, nn.BatchNorm2d(o), nn.ReLU())

    def forward(self, x):
        feats = []
        for enc in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            x = enc(x)
            feats.append(x)
        # feats = [e1..e5]; climb back up, pairing dec1..dec4 with e4..e1.
        x = feats[-1]
        decoders = (self.dec1, self.dec2, self.dec3, self.dec4)
        for dec, skip in zip(decoders, reversed(feats[:-1])):
            x = torch.cat([dec(x), skip], 1)
        return torch.tanh(self.dec5(x))

class ConvBlock(nn.Module):
    """One U-Net stage: 4x4 stride-2 conv (down) or transposed conv (up), then norm, then activation.

    Despite its name, ``use_bn=True`` selects ``nn.InstanceNorm2d``, not
    BatchNorm; ``use_bn=False`` skips normalisation.  ``act="relu"`` gives
    ReLU; any other value gives LeakyReLU(0.2).  Downsampling convs use
    reflect padding.  Neither conv has a bias.
    """

    def __init__(self, in_ch, out_ch, down=True, act="relu", use_bn=True):
        super().__init__()
        if down:
            layer = nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            layer = nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False)
        norm = nn.InstanceNorm2d(out_ch) if use_bn else nn.Identity()
        activation = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(layer, norm, activation)

    def forward(self, x):
        return self.conv(x)

class CycleUNet(nn.Module):
    """Eight-level U-Net generator built from ``ConvBlock`` stages.

    Layout:
    - encoder: ``initial_down`` and ``down1..down6``;
    - ``bottleneck``;
    - decoder: ``up1..up7``, with skip concatenation;
    - ``final_up``: tanh output in [-1, 1].

    Each stage halves or doubles the spatial size, so a 256x256 input reaches
    1x1 at the bottleneck.
    """

    def __init__(self, in_ch=3, out_ch=3, feat=64):
        super().__init__()
        f1, f2, f4, f8, f16 = feat, feat * 2, feat * 4, feat * 8, feat * 16

        # Encoder: LeakyReLU throughout; the first stage has no normalisation.
        self.initial_down = ConvBlock(in_ch, f1, True, "leaky", False)
        self.down1 = ConvBlock(f1, f2, True, "leaky")
        self.down2 = ConvBlock(f2, f4, True, "leaky")
        self.down3 = ConvBlock(f4, f8, True, "leaky")
        self.down4 = ConvBlock(f8, f8, True, "leaky")
        self.down5 = ConvBlock(f8, f8, True, "leaky")
        self.down6 = ConvBlock(f8, f8, True, "leaky")
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f8, f8, 4, 2, 1, padding_mode="reflect"),
            nn.ReLU(),
        )

        # Decoder: from up2 on, each input is a [previous, skip] concatenation.
        self.up1 = ConvBlock(f8, f8, False, "relu", True)
        self.up2 = ConvBlock(f16, f8, False, "relu", True)
        self.up3 = ConvBlock(f16, f8, False, "relu", True)
        self.up4 = ConvBlock(f16, f8, False, "relu", True)
        self.up5 = ConvBlock(f16, f4, False, "relu", True)
        self.up6 = ConvBlock(f8, f2, False, "relu", True)
        self.up7 = ConvBlock(f4, f1, False, "relu", True)
        self.final_up = nn.Sequential(nn.ConvTranspose2d(f2, out_ch, 4, 2, 1), nn.Tanh())

    def forward(self, x):
        encoders = (self.initial_down, self.down1, self.down2, self.down3,
                    self.down4, self.down5, self.down6)
        skips = []
        for enc in encoders:
            x = enc(x)
            skips.append(x)                      # skips = [d1, ..., d7]

        x = self.up1(self.bottleneck(x))
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for dec, skip in zip(decoders, reversed(skips[1:])):   # d7 down to d2
            x = dec(torch.cat([x, skip], 1))
        return self.final_up(torch.cat([x, skips[0]], 1))

# ----- Load checkpoint & auto-detect arch -----
def load_state_dict_from_checkpoint(path, device):
    """Load a checkpoint file and return the generator's state_dict.

    Accepts either a bare state_dict (every value is a tensor) or a dict that
    wraps it under one of the keys tried below.  If the first key starts with
    ``module.`` (the prefix typically left by DataParallel-style wrappers),
    that prefix is stripped from every key.

    Args:
        path: Checkpoint file passed to ``torch.load``.
        device: ``map_location`` for the loaded tensors.

    Returns:
        A dict mapping parameter names to tensors.

    Raises:
        RuntimeError: if the file is not a dict, no state_dict can be located,
            or the located state_dict is empty.
    """
    ckpt = torch.load(path, map_location=device)
    if not isinstance(ckpt, dict):
        raise RuntimeError("Checkpoint must be a dict/state_dict.")

    wrapper_keys = ("generator_state_dict", "gen_A_state_dict", "state_dict", "model", "netG", "G")
    sd = next((ckpt[k] for k in wrapper_keys if k in ckpt and isinstance(ckpt[k], dict)), None)
    if sd is None:
        if not all(torch.is_tensor(v) for v in ckpt.values()):
            raise RuntimeError(f"Unrecognized checkpoint keys: {list(ckpt.keys())[:8]}")
        sd = ckpt

    # An empty state_dict would otherwise surface as a bare StopIteration below.
    if not sd:
        raise RuntimeError(f"Checkpoint {path!r} contains an empty state_dict.")

    if next(iter(sd)).startswith("module."):
        sd = {k.replace("module.", "", 1): v for k, v in sd.items()}
    return sd

def detect_arch(sd_keys):
    """Guess which generator class a state_dict belongs to from its key names.

    Returns:
        ``"pix2pix"`` if any key looks like an ``enc1`` layer (checked first).
        ``"cyc_unet"`` if any key contains ``initial_down`` or ``.down1.``.
        ``"unknown"`` otherwise.
    """
    keys = list(sd_keys)

    def _matches(pred):
        return any(pred(k) for k in keys)

    if _matches(lambda k: k.startswith("enc1") or ".enc1." in k):
        return "pix2pix"
    if _matches(lambda k: k.startswith("initial_down") or ".initial_down." in k or ".down1." in k):
        return "cyc_unet"
    return "unknown"

sd = load_state_dict_from_checkpoint(CHECKPOINT_PATH, DEVICE)
arch = detect_arch(sd.keys())
print("Detected checkpoint architecture:", arch)

if arch == "pix2pix":
    G = Pix2PixUNet().to(DEVICE).eval()
elif arch == "cyc_unet":
    G = CycleUNet().to(DEVICE).eval()
else:
    raise RuntimeError("Could not detect architecture from checkpoint keys.")

# strict=False tolerates key mismatches, but a partial load leaves the
# unmatched layers at random initialisation and makes every metric below
# meaningless.  Report mismatches instead of silently ignoring them.
load_result = G.load_state_dict(sd, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"WARNING: generator load mismatch -- {len(load_result.missing_keys)} missing, "
          f"{len(load_result.unexpected_keys)} unexpected keys")
    print("  missing (first 5):   ", load_result.missing_keys[:5])
    print("  unexpected (first 5):", load_result.unexpected_keys[:5])

# ----- Pair files by basename -----
def index_by_base(folder):
    """Map file basename (extension removed) -> path for the image files in *folder*.

    Extension matching is case-insensitive, so ``X.PNG`` and ``Y.JPG`` are
    found too; plain glob patterns are case-sensitive on Linux and missed
    them.  Hidden files and sub-directories are skipped.  A missing folder
    yields an empty dict.  If two files share a basename, the one whose name
    sorts last wins.
    """
    image_exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp"}
    idx = {}
    if not os.path.isdir(folder):
        return idx
    for name in sorted(os.listdir(folder)):
        stem, ext = os.path.splitext(name)
        if name.startswith(".") or ext.lower() not in image_exts:
            continue
        path = os.path.join(folder, name)
        if os.path.isfile(path):
            idx[stem] = path
    return idx

A_idx = index_by_base(VALA_DIR)   # H&E tiles, keyed by basename
B_idx = index_by_base(VALB_DIR)   # real IHC tiles, keyed by basename
common = sorted(A_idx.keys() & B_idx.keys())
assert common, "No matching basenames between valA and valB."
common = common[:N_EVAL]
print(f"Evaluating {len(common)} paired tiles.")

# ----- Transforms -----
# Generator input: shorter side resized to 256, then scaled to [-1, 1].
tx_in = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
# Same resize, left in [0, 1] for metrics and image saving.
tx_01 = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])


def to01(t):
    """Undo the [-1, 1] normalisation: map to [0, 1] and clamp."""
    return (t * 0.5 + 0.5).clamp(0, 1)


to_pil = transforms.ToPILImage()


def tnp01(t):
    """Convert a CHW tensor to an HWC float32 NumPy array on the CPU."""
    return t.permute(1, 2, 0).cpu().numpy().astype(np.float32)


# ----- Optional FID setup -----
# NOTE(review): this uses torchvision's ImageNet Inception weights and their
# classification preprocessing, not the reference TF-Inception FID pipeline,
# so absolute FID values are presumably not comparable with published ones.
if COMPUTE_FID:
    weights = Inception_V3_Weights.IMAGENET1K_V1
    inception = inception_v3(weights=weights).to(DEVICE).eval()
    inception.fc = nn.Identity()  # forward now returns the pooled features
    pre_inception = weights.transforms()
    gen_pils, real_pils = [], []

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Fréchet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2).

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr((sigma1 sigma2)^(1/2))

    Args:
        mu1, mu2: Mean feature vectors of equal length.
        sigma1, sigma2: The matching covariance matrices.
        eps: Diagonal jitter added to both covariances when the matrix square
            root is not finite (e.g. singular covariances from few samples).

    Returns:
        The distance as a Python float.
    """
    # float64 throughout: feature means arrive as float32, and the trace
    # difference below is a cancellation of large, nearly equal numbers.
    mu1 = np.atleast_1d(np.asarray(mu1, dtype=np.float64))
    mu2 = np.atleast_1d(np.asarray(mu2, dtype=np.float64))
    sigma1 = np.atleast_2d(np.asarray(sigma1, dtype=np.float64))
    sigma2 = np.atleast_2d(np.asarray(sigma2, dtype=np.float64))

    # Plain call: the `disp` flag previously passed is deprecated in recent SciPy releases.
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
    # sqrtm can return a complex matrix whose imaginary part is rounding noise.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))

# ----- Evaluate -----
rows = []
# Only the first 10 evaluated tiles also get an H&E | Real | Fake triptych.
trip_bases = set(common[:10]) if SAVE_TRIPTYCHS else set()
t0 = time.time()

with torch.no_grad(), open(CSV_PATH, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["basename", "ssim", "psnr", "mae", "mse"])

    for base in tqdm(common, desc="Evaluating"):
        he_img = Image.open(A_idx[base]).convert("RGB")
        ihc_img = Image.open(B_idx[base]).convert("RGB")

        x = tx_in(he_img).unsqueeze(0).to(DEVICE)    # [1,C,H,W] in [-1,1]
        fake = G(x)                                  # [1,C,H,W] in [-1,1]
        fake01_cpu = to01(fake)[0].detach().cpu()    # [C,H,W] in [0,1] on the CPU
        real01 = tx_01(ihc_img)                      # [C,H,W] in [0,1], same resize as the input

        # Save the generated tile.
        save_image(fake01_cpu, os.path.join(FAKE_DIR, base + "_fakeIHC.png"))

        # Paired metrics on HxWxC float32 arrays in [0,1].
        g_np = tnp01(fake01_cpu)
        r_np = tnp01(real01)
        s = ssim(r_np, g_np, data_range=1.0, channel_axis=2)
        p = psnr(r_np, g_np, data_range=1.0)
        a = float(np.abs(r_np - g_np).mean())
        m = float(((r_np - g_np) ** 2).mean())
        writer.writerow([base, f"{s:.6f}", f"{p:.6f}", f"{a:.6f}", f"{m:.6f}"])
        rows.append((s, p, a, m))

        # Triptych: H&E | real IHC | generated IHC, side by side along the width.
        if base in trip_bases:
            he01 = tx_01(he_img)
            trip = torch.cat([he01, real01, fake01_cpu], dim=2)
            save_image(trip, os.path.join(TRIP_DIR, base + "_TRIPTYCH.png"))

        # Keep images for FID.  The real side stores the resized tile, which
        # has the same resolution and uint8 conversion as the fake, instead of
        # the full-resolution original.  FID then compares like with like,
        # and the list does not hold every full-size IHC image in memory.
        if COMPUTE_FID:
            gen_pils.append(to_pil(fake01_cpu))
            real_pils.append(to_pil(real01))

# ----- Summary -----
import statistics as stats

# Split the per-image (ssim, psnr, mae, mse) tuples into one list per metric.
S, P, A, M = ([row[j] for row in rows] for j in range(4))

summary = [
    f"Detected arch: {arch}",
    f"Checkpoint: {os.path.basename(CHECKPOINT_PATH)}",
    f"N evaluated: {len(rows)}",
]
summary.append(f"SSIM: mean={stats.mean(S):.4f}  median={stats.median(S):.4f}  std={stats.pstdev(S):.4f}")
summary.append(f"PSNR: mean={stats.mean(P):.2f} dB  median={stats.median(P):.2f} dB  std={stats.pstdev(P):.2f}")
summary.append(f"MAE:  mean={stats.mean(A):.4f}")
summary.append(f"MSE:  mean={stats.mean(M):.5f}")

# ----- FID -----
if COMPUTE_FID:
    def feats_from_pils(pils, bs=FID_BATCH):
        """Return Inception features for a list of PIL images, computed in batches of *bs*."""
        chunks = []
        for start in tqdm(range(0, len(pils), bs), desc="Inception features"):
            batch = torch.stack([pre_inception(im) for im in pils[start:start + bs]], 0).to(DEVICE)
            with torch.no_grad():
                chunks.append(inception(batch).cpu().numpy())
        return np.concatenate(chunks, 0)

    fake_feats = feats_from_pils(gen_pils, bs=FID_BATCH)
    real_feats = feats_from_pils(real_pils, bs=FID_BATCH)
    # Gaussian fit (mean, covariance) of each feature set.
    (mu_f, sigma_f), (mu_r, sigma_r) = [
        (feats.mean(0), np.cov(feats, rowvar=False)) for feats in (fake_feats, real_feats)
    ]
    fid = float(frechet_distance(mu_f, sigma_f, mu_r, sigma_r))
    summary.append(f"FID: {fid:.4f}")

summary.append(f"Time: {(time.time()-t0)/60:.1f} min")
report = "\n".join(summary)
print("\n" + report)

with open(SUMMARY_PATH, "w") as f:
    f.write(report)

for label, value in (("\nPer-image CSV:", CSV_PATH),
                     ("Summary text :", SUMMARY_PATH),
                     ("Fakes saved  :", FAKE_DIR)):
    print(label, value)
if SAVE_TRIPTYCHS:
    print("Triptychs (10):", TRIP_DIR)


In [None]:
# --- Add-on: Compute KID × 1000 for the Pix2Pix/Cycle-UNet eval using saved fakes ---
# Uses images already in FAKE_DIR and real IHCs in VALB_DIR.
import os, glob, math, random
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torchvision.models import inception_v3, Inception_V3_Weights

# Reuse these from STEP 3; adjust if needed
# Both folders must already exist: FAKE_DIR is filled by STEP 3, and
# VALB_DIR holds the real IHC tiles.
assert 'FAKE_DIR' in globals() and os.path.exists(FAKE_DIR), "FAKE_DIR missing; run STEP 3 first."
assert 'VALB_DIR' in globals() and os.path.exists(VALB_DIR), "VALB_DIR missing."

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
for _label, _value in (("Device:", DEVICE), ("FAKE_DIR:", FAKE_DIR), ("VALB_DIR:", VALB_DIR)):
    print(_label, _value)

# Index reals by basename
def index_by_base(folder):
    """Map file basename (extension removed) -> path for the image files in *folder*.

    Extension matching is case-insensitive, so ``X.PNG`` and ``Y.JPG`` are
    found too; plain glob patterns are case-sensitive on Linux and missed
    them.  Hidden files and sub-directories are skipped.  A missing folder
    yields an empty dict.  If two files share a basename, the one whose name
    sorts last wins.
    """
    image_exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp"}
    idx = {}
    if not os.path.isdir(folder):
        return idx
    for name in sorted(os.listdir(folder)):
        stem, ext = os.path.splitext(name)
        if name.startswith(".") or ext.lower() not in image_exts:
            continue
        path = os.path.join(folder, name)
        if os.path.isfile(path):
            idx[stem] = path
    return idx

real_idx = index_by_base(VALB_DIR)


def _fake_base(fake_path):
    """Recover the tile basename from a '<base>_fakeIHC.png' path."""
    return os.path.basename(fake_path).replace("_fakeIHC.png", "")


# Pair each saved fake with the real IHC tile of the same basename.
fake_paths = sorted(glob.glob(os.path.join(FAKE_DIR, "*_fakeIHC.png")))
pairs = [(fp, real_idx[_fake_base(fp)]) for fp in fake_paths if _fake_base(fp) in real_idx]
print(f"Found {len(pairs)} fake-vs-real pairs for KID.")

# Inception feature extractor: the classifier head is replaced so the
# forward pass returns pooled features.
weights = Inception_V3_Weights.IMAGENET1K_V1
preproc = weights.transforms()
inception = inception_v3(weights=weights).to(DEVICE).eval()
inception.fc = nn.Identity()  # 2048-d features

def feats_from_pairs(pairs, which="fake", bs=32):
    """Compute Inception features for one side of *pairs*.

    Args:
        pairs: List of ``(fake_path, real_path)`` tuples.
        which: ``"fake"`` selects the first path of each pair; any other value
            selects the second (real) path.
        bs: Images per Inception forward pass.

    Returns:
        ``np.ndarray`` with one feature row per pair.

    Raises:
        ValueError: if *pairs* is empty.

    Images are opened one batch at a time instead of all up front.  The real
    IHC tiles are read at their native resolution, so decoding every one of
    them at once can use a lot of RAM.
    """
    if not pairs:
        raise ValueError("feats_from_pairs: no pairs to extract features from.")
    paths = [fp if which == "fake" else rp for fp, rp in pairs]
    feats = []
    for i in tqdm(range(0, len(paths), bs), desc=f"Inception features ({which})"):
        batch = []
        for path in paths[i:i + bs]:
            with Image.open(path) as im:
                batch.append(preproc(im.convert("RGB")))
        xb = torch.stack(batch, 0).to(DEVICE)
        with torch.no_grad():
            feats.append(inception(xb).cpu().numpy())
    return np.concatenate(feats, 0)

# Features for the generated side first, then the real side.
fake_feats, real_feats = (feats_from_pairs(pairs, side, bs=32) for side in ("fake", "real"))
print("Feature shapes:", fake_feats.shape, real_feats.shape)

# --- KID (unbiased MMD^2 with polynomial kernel (x^T y / d + 1)^3), averaged over splits ---
def polynomial_mmd2_unbiased(X, Y, degree=3, gamma=None, coef0=1.0):
    """Unbiased MMD^2 with the polynomial kernel k(x, y) = (gamma * x.y + coef0) ** degree.

    This is the statistic behind KID (by default gamma = 1/d, coef0 = 1,
    degree = 3).  The within-set terms exclude the diagonal (i == j); the
    cross term averages over all pairs.

    Args:
        X: ``[m, d]`` feature array.
        Y: ``[n, d]`` feature array; ``n`` may differ from ``m``.
        degree, gamma, coef0: Kernel parameters; ``gamma`` defaults to ``1/d``.

    Returns:
        The MMD^2 estimate as a float.  Being unbiased, it can be slightly
        negative.

    Raises:
        ValueError: if either set has fewer than 2 rows or the feature
            dimensions differ.
    """
    # float64: summing ~m*n kernel values in float32 loses precision.
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    m, d = X.shape
    n = Y.shape[0]
    if Y.shape[1] != d:
        raise ValueError(f"Feature dimensions differ: {d} vs {Y.shape[1]}")
    if m < 2 or n < 2:
        raise ValueError("Unbiased MMD^2 needs at least 2 samples per set.")
    if gamma is None:
        gamma = 1.0 / d

    # No shuffling is needed: every term is a sum over (off-diagonal) pairs,
    # so the estimate does not depend on row order.
    K_XX = (gamma * (X @ X.T) + coef0) ** degree
    K_YY = (gamma * (Y @ Y.T) + coef0) ** degree
    K_XY = (gamma * (X @ Y.T) + coef0) ** degree

    within_x = (K_XX.sum() - np.trace(K_XX)) / (m * (m - 1))
    within_y = (K_YY.sum() - np.trace(K_YY)) / (n * (n - 1))
    return float(within_x + within_y - 2.0 * K_XY.mean())

def compute_kid(feats_fake, feats_real, n_splits=10, max_subset=1000, seed=42):
    """Compute KID as the mean and std of the unbiased polynomial MMD^2 over random subsets.

    Each of the *n_splits* rounds draws ``n = min(len(feats_fake),
    len(feats_real), max_subset)`` rows without replacement from each set,
    using a generator seeded with *seed*.

    NOTE(review): when both sets have exactly ``n`` rows, every round uses
    the full sets.  All rounds then give essentially the same value, and the
    reported std is ~0.

    Returns:
        ``(mean, std)`` of the per-round MMD^2 values, as floats.
    """
    rng = np.random.default_rng(seed)
    n = min(len(feats_fake), len(feats_real), max_subset)

    def _one_round():
        pick_fake = rng.choice(len(feats_fake), n, replace=False)
        pick_real = rng.choice(len(feats_real), n, replace=False)
        return polynomial_mmd2_unbiased(feats_fake[pick_fake], feats_real[pick_real])

    per_round = np.array([_one_round() for _ in range(n_splits)])
    return float(per_round.mean()), float(per_round.std())

kid_mean, kid_std = compute_kid(fake_feats, real_feats, n_splits=10, max_subset=1000, seed=123)
# KID is conventionally reported scaled by 1000.
kid_mean_x1000, kid_std_x1000 = kid_mean * 1000, kid_std * 1000
print(f"\nKID × 1000: {kid_mean_x1000:.2f} ± {kid_std_x1000:.2f}")


In [None]:
# === Generate 10 triptychs from the given epoch-200 checkpoint ===
import os, glob
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fixed inputs for the epoch-200 preview.
CHECKPOINT_PATH = "/content/drive/MyDrive/HER2/pix2pix_checkpoints/generator_epoch_200.pth"
VALA_DIR = "/content/drive/MyDrive/HER2/TrainValAB/valA"  # H&E
VALB_DIR = "/content/drive/MyDrive/HER2/TrainValAB/valB"  # IHC
# Output folder for the previews.
OUT_DIR = "/content/drive/MyDrive/HER2/pix2pix_eval/epoch_200_preview_10"
os.makedirs(OUT_DIR, exist_ok=True)

assert os.path.exists(CHECKPOINT_PATH), f"Checkpoint not found: {CHECKPOINT_PATH}"
assert all(os.path.exists(d) for d in (VALA_DIR, VALB_DIR)), "valA/valB paths not found."

# ---- Model defs ----
class Pix2PixUNet(nn.Module):
    """Inference copy of the five-level Pix2Pix U-Net generator.

    Attribute names (enc1..enc5, dec1..dec5) and the Sequential indices inside
    them mirror the training ``Generator``, so its state_dict loads directly.
    Output is in [-1, 1] (tanh).
    """

    def __init__(self):
        super().__init__()
        # Encoder; the first stage has no BatchNorm.
        self.enc1 = self._c(3, 64, False)
        self.enc2 = self._c(64, 128)
        self.enc3 = self._c(128, 256)
        self.enc4 = self._c(256, 512)
        self.enc5 = self._c(512, 512)
        # Decoder; from dec2 on, the input width is doubled by the skip concat.
        self.dec1 = self._u(512, 512)
        self.dec2 = self._u(1024, 256)
        self.dec3 = self._u(512, 128)
        self.dec4 = self._u(256, 64)
        self.dec5 = nn.ConvTranspose2d(128, 3, 4, 2, 1)

    def _c(self, i, o, norm=True):
        """Stride-2 Conv -> optional BatchNorm -> LeakyReLU(0.2)."""
        modules = [nn.Conv2d(i, o, 4, 2, 1)]
        if norm:
            modules.append(nn.BatchNorm2d(o))
        modules.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*modules)

    def _u(self, i, o):
        """Stride-2 ConvTranspose -> BatchNorm -> ReLU."""
        up = nn.ConvTranspose2d(i, o, 4, 2, 1)
        return nn.Sequential(up, nn.BatchNorm2d(o), nn.ReLU())

    def forward(self, x):
        feats = []
        for enc in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            x = enc(x)
            feats.append(x)
        # feats = [e1..e5]; climb back up, pairing dec1..dec4 with e4..e1.
        x = feats[-1]
        decoders = (self.dec1, self.dec2, self.dec3, self.dec4)
        for dec, skip in zip(decoders, reversed(feats[:-1])):
            x = torch.cat([dec(x), skip], 1)
        return torch.tanh(self.dec5(x))

class ConvBlock(nn.Module):
    """One U-Net stage: 4x4 stride-2 conv (down) or transposed conv (up), then norm, then activation.

    Despite its name, ``use_bn=True`` selects ``nn.InstanceNorm2d``, not
    BatchNorm; ``use_bn=False`` skips normalisation.  ``act="relu"`` gives
    ReLU; any other value gives LeakyReLU(0.2).  Downsampling convs use
    reflect padding.  Neither conv has a bias.
    """

    def __init__(self, in_ch, out_ch, down=True, act="relu", use_bn=True):
        super().__init__()
        if down:
            layer = nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            layer = nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False)
        norm = nn.InstanceNorm2d(out_ch) if use_bn else nn.Identity()
        activation = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(layer, norm, activation)

    def forward(self, x):
        return self.conv(x)

class CycleUNet(nn.Module):
    """Eight-level U-Net generator built from ``ConvBlock`` stages.

    Layout:
    - encoder: ``initial_down`` and ``down1..down6``;
    - ``bottleneck``;
    - decoder: ``up1..up7``, with skip concatenation;
    - ``final_up``: tanh output in [-1, 1].

    Each stage halves or doubles the spatial size, so a 256x256 input reaches
    1x1 at the bottleneck.
    """

    def __init__(self, in_ch=3, out_ch=3, feat=64):
        super().__init__()
        f1, f2, f4, f8, f16 = feat, feat * 2, feat * 4, feat * 8, feat * 16

        # Encoder: LeakyReLU throughout; the first stage has no normalisation.
        self.initial_down = ConvBlock(in_ch, f1, True, "leaky", False)
        self.down1 = ConvBlock(f1, f2, True, "leaky")
        self.down2 = ConvBlock(f2, f4, True, "leaky")
        self.down3 = ConvBlock(f4, f8, True, "leaky")
        self.down4 = ConvBlock(f8, f8, True, "leaky")
        self.down5 = ConvBlock(f8, f8, True, "leaky")
        self.down6 = ConvBlock(f8, f8, True, "leaky")
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f8, f8, 4, 2, 1, padding_mode="reflect"),
            nn.ReLU(),
        )

        # Decoder: from up2 on, each input is a [previous, skip] concatenation.
        self.up1 = ConvBlock(f8, f8, False, "relu", True)
        self.up2 = ConvBlock(f16, f8, False, "relu", True)
        self.up3 = ConvBlock(f16, f8, False, "relu", True)
        self.up4 = ConvBlock(f16, f8, False, "relu", True)
        self.up5 = ConvBlock(f16, f4, False, "relu", True)
        self.up6 = ConvBlock(f8, f2, False, "relu", True)
        self.up7 = ConvBlock(f4, f1, False, "relu", True)
        self.final_up = nn.Sequential(nn.ConvTranspose2d(f2, out_ch, 4, 2, 1), nn.Tanh())

    def forward(self, x):
        encoders = (self.initial_down, self.down1, self.down2, self.down3,
                    self.down4, self.down5, self.down6)
        skips = []
        for enc in encoders:
            x = enc(x)
            skips.append(x)                      # skips = [d1, ..., d7]

        x = self.up1(self.bottleneck(x))
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for dec, skip in zip(decoders, reversed(skips[1:])):   # d7 down to d2
            x = dec(torch.cat([x, skip], 1))
        return self.final_up(torch.cat([x, skips[0]], 1))

# ---- load checkpoint and detect arch ----
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
if not isinstance(ckpt, dict):
    raise RuntimeError("Unexpected checkpoint format.")

# The generator weights are either wrapped under a known key or the whole dict.
sd = None
for k in ["generator_state_dict", "gen_A_state_dict", "state_dict", "model", "netG", "G"]:
    if k in ckpt and isinstance(ckpt[k], dict):
        sd = ckpt[k]
        break
if sd is None and all(torch.is_tensor(v) for v in ckpt.values()):
    sd = ckpt

# Fail clearly when nothing usable was found (unrecognised layout or empty
# state_dict) instead of crashing on next(iter(...)) below.
if not sd:
    raise RuntimeError(
        f"No generator state_dict found in checkpoint; top-level keys: {list(ckpt.keys())[:8]}"
    )

# Strip a leading "module." prefix left by multi-GPU wrappers.
if next(iter(sd)).startswith("module."):
    sd = {k.replace("module.", "", 1): v for k, v in sd.items()}

def detect_arch(keys):
    """Guess which generator class a state_dict belongs to from its key names.

    Returns:
        ``"pix2pix"`` if any key looks like an ``enc1`` layer (checked first).
        ``"cyc_unet"`` if any key contains ``initial_down`` or ``.down1.``.
        ``"unknown"`` otherwise.
    """
    names = list(keys)

    def _matches(pred):
        return any(pred(k) for k in names)

    if _matches(lambda k: k.startswith("enc1") or ".enc1." in k):
        return "pix2pix"
    if _matches(lambda k: k.startswith("initial_down") or ".initial_down." in k or ".down1." in k):
        return "cyc_unet"
    return "unknown"

arch = detect_arch(sd.keys())
print("Detected arch:", arch)

# Refuse to guess.  An unknown layout used to fall through to CycleUNet,
# and with strict=False that would silently generate from random weights.
if arch == "pix2pix":
    G = Pix2PixUNet().to(DEVICE).eval()
elif arch == "cyc_unet":
    G = CycleUNet().to(DEVICE).eval()
else:
    raise RuntimeError("Could not detect generator architecture from checkpoint keys.")

# strict=False tolerates key mismatches; report them so a partial load is visible.
load_result = G.load_state_dict(sd, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"WARNING: generator load mismatch -- {len(load_result.missing_keys)} missing, "
          f"{len(load_result.unexpected_keys)} unexpected keys")
    print("  missing (first 5):   ", load_result.missing_keys[:5])
    print("  unexpected (first 5):", load_result.unexpected_keys[:5])

# ---- pair 10 filenames by basename ----
def index_by_base(folder):
    """Map file basename (extension removed) -> path for the image files in *folder*.

    Extension matching is case-insensitive, so ``X.PNG`` and ``Y.JPG`` are
    found too; plain glob patterns are case-sensitive on Linux and missed
    them.  Hidden files and sub-directories are skipped.  A missing folder
    yields an empty dict.  If two files share a basename, the one whose name
    sorts last wins.
    """
    image_exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp"}
    idx = {}
    if not os.path.isdir(folder):
        return idx
    for name in sorted(os.listdir(folder)):
        stem, ext = os.path.splitext(name)
        if name.startswith(".") or ext.lower() not in image_exts:
            continue
        path = os.path.join(folder, name)
        if os.path.isfile(path):
            idx[stem] = path
    return idx

A = index_by_base(VALA_DIR)
B = index_by_base(VALB_DIR)
common = sorted(A.keys() & B.keys())[:10]   # first 10 shared basenames
assert common, "No matching basenames between valA and valB."

# ---- transforms ----
tx_in = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),   # [0,1] -> [-1,1]
])
tx_01 = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])


def to01(t):
    """Map generator output from [-1, 1] to [0, 1], clamped."""
    return (t * 0.5 + 0.5).clamp(0, 1)


# ---- generate ----
print("Saving 10 triptychs to:", OUT_DIR)
with torch.no_grad():
    for stem in common:
        he_img = Image.open(A[stem]).convert("RGB")
        ihc_img = Image.open(B[stem]).convert("RGB")
        fake01 = to01(G(tx_in(he_img).unsqueeze(0).to(DEVICE)))[0].cpu()
        he01, real01 = tx_01(he_img), tx_01(ihc_img)
        # Side by side along the width: H&E | real IHC | generated IHC.
        trip = torch.cat([he01, real01, fake01], dim=2)
        save_image(trip, os.path.join(OUT_DIR, f"{stem}_TRIPTYCH_e200.png"))
        save_image(fake01, os.path.join(OUT_DIR, f"{stem}_fakeIHC_e200.png"))
print("Done.")


In [None]:
# View the 10 saved triptychs (H&E | Real IHC | Generated IHC)
import os, glob, math
from PIL import Image
import matplotlib.pyplot as plt

OUT_DIR = "/content/drive/MyDrive/HER2/pix2pix_eval/epoch_200_preview_10"

trip_paths = sorted(glob.glob(os.path.join(OUT_DIR, "*_TRIPTYCH_e200.png")))
print(f"Found {len(trip_paths)} triptychs in:", OUT_DIR)
assert trip_paths, f"No triptych images found in {OUT_DIR}. Check that the previous generation step ran."

# Grid with 5 columns; each cell is 4x4 inches.
cols = 5
rows = math.ceil(len(trip_paths) / cols)
fig = plt.figure(figsize=(cols * 4, rows * 4))
for slot, path in enumerate(trip_paths[:rows * cols], start=1):
    ax = fig.add_subplot(rows, cols, slot)
    ax.imshow(Image.open(path).convert("RGB"))
    ax.set_title(os.path.basename(path), fontsize=8)
    ax.axis("off")
plt.tight_layout()
plt.show()


In [None]:
import random, matplotlib.pyplot as plt
from PIL import Image

# Show one randomly chosen triptych at a larger size.
# os and glob are imported here so this cell does not depend on another cell's imports.
import glob
import os

OUT_DIR = "/content/drive/MyDrive/HER2/pix2pix_eval/epoch_200_preview_10"
trip_paths = sorted(glob.glob(os.path.join(OUT_DIR, "*_TRIPTYCH_e200.png")))
if not trip_paths:
    # random.choice([]) would only raise an opaque IndexError.
    raise FileNotFoundError(f"No triptych images found in {OUT_DIR}. Run the generation cell first.")
p = random.choice(trip_paths)  # or set p to a specific file path
plt.figure(figsize=(12, 4))
plt.imshow(Image.open(p).convert("RGB"))
plt.title(os.path.basename(p))
plt.axis('off')
plt.show()


# **CycleGAN (after Pix2Pix)**

In [None]:
# Cell 1: Imports and Global Configuration

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import os
from tqdm import tqdm
import random
import numpy as np
from google.colab import drive
from torch.utils.data import DataLoader
import torch.autograd as autograd # For gradient penalty
# Updated AMP syntax for both GradScaler and autocast
from torch.amp import GradScaler, autocast # Use torch.amp for both
from torchvision.models import vgg19 # For Perceptual Loss

# --- Mount Google Drive ---
drive.mount('/content/drive')

# --- Configuration ---
# Use the GPU whenever PyTorch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Dataset paths.
# The paired Pix2Pix folder layout is reused. UnpairedImageDataset matches A/B
# files by their position in the sorted file lists. If trainA/trainB hold
# identically named files (presumably true here - confirm), every sample is an
# aligned H&E/IHC pair. The adversarial, cycle and identity losses do not need
# that alignment; the perceptual loss (fake IHC vs. real IHC of the same batch) does.
TRAIN_DIR_A = '/content/drive/MyDrive/HER2/TrainValAB/trainA' # H&E training images
TRAIN_DIR_B = '/content/drive/MyDrive/HER2/TrainValAB/trainB' # IHC training images
VAL_DIR_A = '/content/drive/MyDrive/HER2/TrainValAB/valA'     # H&E validation images
VAL_DIR_B = '/content/drive/MyDrive/HER2/TrainValAB/valB'     # IHC validation images

# Checkpoint and sample image saving paths for CycleGAN
CHECKPOINT_DIR = '/content/drive/MyDrive/HER2/cyclegan_checkpoints'
SAMPLE_IMAGES_DIR = '/content/drive/MyDrive/HER2/cyclegan_sample_images'
TRAINING_LOG_PATH = os.path.join(CHECKPOINT_DIR, 'cyclegan_training_loss_history.csv')

# Ensure directories exist
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(SAMPLE_IMAGES_DIR, exist_ok=True)

# Hyperparameters
LEARNING_RATE_GEN = 0.00005 # Adjusted: Further reduced Generator LR for WGAN-GP stability
LEARNING_RATE_DISC = 0.000005 # Adjusted: Further reduced Discriminator LR for WGAN-GP stability
BATCH_SIZE = 1 # CycleGAN often uses batch size 1
NUM_EPOCHS = 200
LOAD_MODEL = True # True: resume from CHECKPOINT_DIR/checkpoint_latest.pth if present; set False for a fresh start
SAVE_MODEL = True
IMAGE_SIZE = 256 # Images will be resized to IMAGE_SIZE x IMAGE_SIZE

# Loss weights for CycleGAN
LAMBDA_CYCLE = 10.0 # Weight for cycle consistency loss (crucial for CycleGAN)
LAMBDA_IDENTITY = 0.5 * LAMBDA_CYCLE # Weight for identity mapping loss (prevents color shifts)
LAMBDA_GP = 10.0 # Weight for Gradient Penalty in WGAN-GP
LAMBDA_PERCEPTUAL = 0.01 # Adjusted: Reduced Perceptual Loss weight to prevent instability

# Learning Rate Decay parameters
DECAY_EPOCH_START = 100 # Epoch to start linear decay of learning rates

# Image transformations.
# Resize to an explicit (H, W): Resize(int) only matches the shorter side, so
# non-square inputs would stay non-square. The U-Net generator (8 stride-2
# stages with skip concatenation) is laid out for square IMAGE_SIZE inputs.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
])

print("Global configuration loaded for CycleGAN.")


In [None]:
# Cell 2: Dataset Class and DataLoader

class UnpairedImageDataset(torch.utils.data.Dataset):
    """Two-domain image dataset for CycleGAN training.

    Item ``i`` is ``(A[i % len(A)], B[i % len(B)])``. ``A`` and ``B`` are the
    sorted image filenames found directly inside ``root_A`` and ``root_B``.
    Because both lists are sorted, folders whose files share names yield
    aligned pairs.

    Args:
        root_A: directory of domain-A images (e.g. H&E).
        root_B: directory of domain-B images (e.g. IHC).
        transform: optional callable applied to each RGB PIL image.

    Raises:
        ValueError: if either directory contains no image files.
    """

    # Only these suffixes (case-insensitive) are loaded. Other entries such as
    # .DS_Store, desktop.ini or sub-folders would otherwise crash PIL in __getitem__.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')

    def __init__(self, root_A, root_B, transform=None):
        self.root_A = root_A # Images for Domain A (e.g., H&E)
        self.root_B = root_B # Images for Domain B (e.g., IHC)
        self.transform = transform

        self.images_A = self._list_images(root_A)
        self.images_B = self._list_images(root_B)
        for root, names in ((root_A, self.images_A), (root_B, self.images_B)):
            if not names:
                # Without this check, __getitem__ would divide by zero in the modulo below.
                raise ValueError(
                    f"No image files ({', '.join(self.IMG_EXTENSIONS)}) found in {root}"
                )

        # An epoch covers the LONGER domain; the shorter one is cycled via modulo
        # in __getitem__, so no image of the larger domain is skipped.
        self.length_dataset = max(len(self.images_A), len(self.images_B))

    @classmethod
    def _list_images(cls, root):
        """Return the sorted names of regular image files directly inside ``root``."""
        return sorted(
            name
            for name in os.listdir(root)
            if name.lower().endswith(cls.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(root, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and close the underlying file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        # Cycle through images if one domain has fewer images than the other
        img_A = self._load_rgb(os.path.join(self.root_A, self.images_A[index % len(self.images_A)]))
        img_B = self._load_rgb(os.path.join(self.root_B, self.images_B[index % len(self.images_B)]))

        if self.transform:
            img_A = self.transform(img_A)
            img_B = self.transform(img_B)

        return img_A, img_B

# Build the train/val datasets from the configured folders (same transform for both).
train_dataset, val_dataset = (
    UnpairedImageDataset(dir_A, dir_B, transform=transform)
    for dir_A, dir_B in ((TRAIN_DIR_A, TRAIN_DIR_B), (VAL_DIR_A, VAL_DIR_B))
)

# Training batches are reshuffled every epoch. Validation keeps a fixed order,
# so the first validation batch is the same one each time.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=2, pin_memory=True,
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=2, pin_memory=True,
)

print("Dataset and DataLoaders created for CycleGAN (handling unpaired data).")


In [None]:
# Cell 3: Model Definitions (Generator and Discriminator for CycleGAN)

# --- Generator Model (U-Net Architecture - Copied from your successful Pix2Pix inference) ---
# This is the Generator architecture that successfully loaded your previous checkpoint.
# We'll use this as the base for both G_A (H&E to IHC) and G_B (IHC to H&E).
class ConvBlock(nn.Module):
    """One U-Net stage: a 4x4 stride-2 (transposed) conv, optional InstanceNorm, then an activation.

    Args:
        in_channels / out_channels: channel counts of the stage.
        down: True -> Conv2d with reflect padding (halves H and W);
              False -> ConvTranspose2d (doubles H and W).
        act: "relu" selects ReLU; any other value selects LeakyReLU(0.2).
        use_bn: despite the name, selects InstanceNorm2d (True) or no normalisation (False).

    The layers live in ``self.conv`` as (sampler, norm, activation). That layout
    fixes the ``conv.0.weight`` state_dict key used by saved checkpoints.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_bn=True):
        super().__init__()
        if down:
            sampler = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            sampler = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)

        normalisation = nn.InstanceNorm2d(out_channels) if use_bn else nn.Identity()
        activation = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)

        self.conv = nn.Sequential(sampler, normalisation, activation)

    def forward(self, x):
        return self.conv(x)

class Generator(nn.Module):
    """U-Net generator with 7 encoder stages, a 1x1 bottleneck and a mirrored decoder.

    Every decoder stage after ``up1`` consumes ``[previous output, encoder skip]``
    concatenated on the channel axis. The output goes through Tanh, so values
    lie in [-1, 1]. The spatial sizes below assume the 256x256 training resolution.
    Attribute names are kept stable because checkpoints store state_dict keys by them.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        f1, f2, f4, f8 = features, features * 2, features * 4, features * 8

        # ---- encoder: each stage halves H and W (256 -> 2) ----
        self.initial_down = ConvBlock(in_channels, f1, down=True, act="leaky", use_bn=False)  # d1: f1 @ 128
        self.down1 = ConvBlock(f1, f2, down=True, act="leaky")  # d2: f2 @ 64
        self.down2 = ConvBlock(f2, f4, down=True, act="leaky")  # d3: f4 @ 32
        self.down3 = ConvBlock(f4, f8, down=True, act="leaky")  # d4: f8 @ 16
        self.down4 = ConvBlock(f8, f8, down=True, act="leaky")  # d5: f8 @ 8
        self.down5 = ConvBlock(f8, f8, down=True, act="leaky")  # d6: f8 @ 4
        self.down6 = ConvBlock(f8, f8, down=True, act="leaky")  # d7: f8 @ 2

        # ---- bottleneck: 2x2 -> 1x1, no normalisation ----
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f8, f8, 4, 2, 1, padding_mode="reflect"),
            nn.ReLU(),
        )

        # ---- decoder: each stage doubles H and W ----
        self.up1 = ConvBlock(f8, f8, down=False, act="relu", use_bn=True)      # -> @2, then joined with d7
        self.up2 = ConvBlock(f8 * 2, f8, down=False, act="relu", use_bn=True)  # -> @4, then joined with d6
        self.up3 = ConvBlock(f8 * 2, f8, down=False, act="relu", use_bn=True)  # -> @8, then joined with d5
        self.up4 = ConvBlock(f8 * 2, f8, down=False, act="relu", use_bn=True)  # -> @16, then joined with d4
        self.up5 = ConvBlock(f8 * 2, f4, down=False, act="relu", use_bn=True)  # -> @32, then joined with d3
        self.up6 = ConvBlock(f4 * 2, f2, down=False, act="relu", use_bn=True)  # -> @64, then joined with d2
        self.up7 = ConvBlock(f2 * 2, f1, down=False, act="relu", use_bn=True)  # -> @128, then joined with d1

        # ---- output head: back to full resolution, squashed to [-1, 1] ----
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f1 * 2, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoders = (self.initial_down, self.down1, self.down2, self.down3,
                    self.down4, self.down5, self.down6)
        skips = []  # filled with d1 ... d7, shallowest first
        h = x
        for encode in encoders:
            h = encode(h)
            skips.append(h)

        h = self.up1(self.bottleneck(h))

        # up2 pairs with d7, up3 with d6, ..., up7 with d2 (deepest skip first).
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for decode, skip in zip(decoders, reversed(skips[1:])):
            h = decode(torch.cat([h, skip], 1))

        return self.final_up(torch.cat([h, skips[0]], 1))

# --- Discriminator Model (PatchGAN with Spectral Normalization) ---
# For CycleGAN, we need two discriminators, one for each domain.
# Each Discriminator takes a single image (not concatenated input/output).
class Discriminator(nn.Module):
    """PatchGAN critic: maps one image to a grid of unbounded per-patch scores.

    Every convolution is wrapped in ``nn.utils.spectral_norm``. Submodule names
    (``initial``, ``model``) and their layout are unchanged, because saved
    checkpoints store state_dict keys by them (e.g. ``model.0.0.*``).

    Args:
        in_channels: channels of the input image.
        features: output channels of the stride-2 stages; ``features[0]`` is the
            first, un-normalised stage. Any non-empty sequence is accepted. The
            default is a tuple, so no mutable default is shared between instances.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel count")

        # First stage: spectral-normed stride-2 conv + LeakyReLU, no normalisation.
        self.initial = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channels, features[0], 4, 2, 1, padding_mode="reflect")),
            nn.LeakyReLU(0.2),
        )

        layers = []
        prev_channels = features[0]
        for feature in features[1:]:
            # Each further stage: spectral-normed stride-2 conv -> InstanceNorm -> LeakyReLU.
            layers.append(
                nn.Sequential(
                    nn.utils.spectral_norm(
                        nn.Conv2d(prev_channels, feature, 4, 2, 1, bias=False, padding_mode="reflect")
                    ),
                    nn.InstanceNorm2d(feature),
                    nn.LeakyReLU(0.2),
                )
            )
            prev_channels = feature

        # Stride-1 projection to one score per patch. There is no final activation,
        # as expected for a WGAN critic.
        layers.append(
            nn.utils.spectral_norm(nn.Conv2d(prev_channels, 1, 4, 1, 1, padding_mode="reflect"))
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(self.initial(x))

# --- Initialize Weights (Optional but Good Practice for GANs) ---
def initialize_weights(model):
    """Re-initialise conv and norm layers of ``model`` in place.

    (Transposed) convolutions get weights drawn from N(0, 0.02); BatchNorm2d and
    InstanceNorm2d layers get scale ~ N(1, 0.02). Every bias of those layers is
    set to 0. Parameters that are absent (e.g. ``bias=False`` convs or
    non-affine norms) are skipped. All other module types are left untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d)

    for module in model.modules():
        if isinstance(module, conv_types):
            weight_mean = 0.0
        elif isinstance(module, norm_types):
            weight_mean = 1.0  # norm scale (gamma) is centred on identity
        else:
            continue

        weight = getattr(module, 'weight', None)
        if weight is not None:
            nn.init.normal_(weight.data, weight_mean, 0.02)

        bias = getattr(module, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

# --- Perceptual Loss (VGG-based) ---
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two images.

    Inputs are expected in the generator's [-1, 1] range (Tanh output, or data
    passed through Normalize(0.5, 0.5)). They are mapped to [0, 1] and then
    standardised with the ImageNet mean/std that torchvision's pretrained VGG
    weights expect. Skipping that standardisation feeds the network inputs it
    was not trained on.

    Args:
        device: device on which the VGG trunk and normalisation constants live.
    """

    # ImageNet statistics used by torchvision's pretrained classification models.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, device):
        super().__init__()
        # features[:36] keeps VGG19 layers conv1_1 ... relu5_4 (indices 0-35) and drops
        # the final max-pool at index 36. weights="IMAGENET1K_V1" is the same checkpoint
        # the deprecated pretrained=True selected.
        vgg = vgg19(weights="IMAGENET1K_V1").features[:36].eval().to(device)
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.criterion = nn.L1Loss() # Use L1 loss on feature maps
        # Non-persistent buffers: they follow .to(device) but stay out of state_dict().
        self.register_buffer(
            "mean", torch.tensor(self._IMAGENET_MEAN, device=device).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(self._IMAGENET_STD, device=device).view(1, 3, 1, 1), persistent=False
        )

    def _prepare(self, image):
        """Map a [-1, 1] RGB batch to ImageNet-standardised VGG input."""
        image01 = (image * 0.5 + 0.5).clamp(0, 1)
        return (image01 - self.mean) / self.std

    def forward(self, input_image, target_image):
        features_input = self.vgg(self._prepare(input_image))
        features_target = self.vgg(self._prepare(target_image))
        return self.criterion(features_input, features_target)

print("Models defined for CycleGAN.")


In [None]:
# Cell 4: Training Loop for CycleGAN

# --- Fake Image Replay Buffer ---
# Keeps a pool of previously generated images so the critic also sees older fakes.
class ReplayBuffer:
    """Fixed-capacity history of generated images.

    Until the pool holds ``max_size`` images, every incoming image is stored and
    returned unchanged. After that, each incoming image has a 50% chance of being
    returned as-is. Otherwise it replaces a random stored image, and a copy of
    that stored image is returned instead.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole"
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, images):
        """Push a batch (first dim = batch) and return a same-sized, detached batch."""
        returned = []
        # detach() so stored images never keep the generator's autograd graph alive
        for sample in images.detach():
            sample = sample.unsqueeze(0)  # restore the batch dimension

            if len(self.data) < self.max_size:
                self.data.append(sample)
                returned.append(sample)
                continue

            if random.uniform(0, 1) <= 0.5:
                returned.append(sample)
                continue

            slot = random.randint(0, self.max_size - 1)
            returned.append(self.data[slot].clone())
            self.data[slot] = sample

        return torch.cat(returned)

# --- WGAN-GP Loss Functions ---
def discriminator_loss_wgan_gp(disc_real_pred, disc_fake_pred, real_img, fake_img, discriminator, lambda_gp, device):
    """Return the full WGAN-GP critic loss ``mean(D(fake)) - mean(D(real)) + GP``.

    ``GP = lambda_gp * mean((||dD(x_hat)/dx_hat||_2 - 1) ** 2)``, where ``x_hat``
    is a per-sample random interpolation between ``real_img`` and ``fake_img``.

    Args:
        disc_real_pred: critic output for ``real_img``.
        disc_fake_pred: critic output for ``fake_img``.
        real_img, fake_img: the two batches. The interpolation weight has shape
            (batch, 1, 1, 1), so 4-D image tensors are assumed.
        discriminator: the critic, re-evaluated on the interpolated batch.
        lambda_gp: gradient-penalty weight.
        device: device for the random interpolation weights.

    Returns:
        A scalar tensor that already contains the adversarial term. Back-propagate it once.
    """
    # Detach both endpoints so the penalty only produces gradients for the critic,
    # never for whatever produced fake_img (e.g. a generator).
    real_img = real_img.detach()
    fake_img = fake_img.detach()

    alpha = torch.rand(real_img.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real_img + (1 - alpha) * fake_img).requires_grad_(True)

    disc_interpolates = discriminator(interpolates)

    # create_graph=True makes the penalty differentiable w.r.t. the critic weights;
    # retain_graph defaults to the value of create_graph.
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp

    return disc_fake_pred.mean() - disc_real_pred.mean() + gradient_penalty

def generator_loss_wgan_gp(disc_fake_pred):
    """WGAN generator objective: the negated mean critic score on generated images."""
    return torch.neg(disc_fake_pred.mean())


# Networks:
#   gen_A : H&E (A) -> IHC (B)      disc_B : real IHC vs. fake IHC
#   gen_B : IHC (B) -> H&E (A)      disc_A : real H&E vs. fake H&E
gen_A, gen_B = (
    Generator(in_channels=3, out_channels=3, features=64).to(DEVICE) for _ in range(2)
)
disc_A, disc_B = (Discriminator(in_channels=3).to(DEVICE) for _ in range(2))

# Re-initialise all four networks (same order as they were built).
initialize_weights(gen_A)
initialize_weights(gen_B)
initialize_weights(disc_A)
initialize_weights(disc_B)

# Non-adversarial losses; the adversarial terms come from the WGAN-GP helpers.
criterion_Cycle, criterion_Identity = nn.L1Loss(), nn.L1Loss()
criterion_Perceptual = PerceptualLoss(DEVICE)

# One Adam optimiser shared by both generators, one per critic.
optimizer_gen = optim.Adam(
    [*gen_A.parameters(), *gen_B.parameters()],
    lr=LEARNING_RATE_GEN,
    betas=(0.5, 0.999),
)
optimizer_disc_A, optimizer_disc_B = (
    optim.Adam(critic.parameters(), lr=LEARNING_RATE_DISC, betas=(0.5, 0.999))
    for critic in (disc_A, disc_B)
)

# Learning Rate Schedulers (Linear Decay)
def lambda_rule(epoch, decay_start=None, total_epochs=None):
    """Return the LR multiplier for scheduler step ``epoch``.

    The multiplier is 1.0 up to ``decay_start``, then falls linearly to 0.0 at
    ``total_epochs``. It is clamped to [0, 1], so stepping past the end never
    yields a negative learning rate.

    Args:
        epoch: scheduler step count (LambdaLR passes its ``last_epoch``).
        decay_start: first step of the decay; defaults to ``DECAY_EPOCH_START``.
        total_epochs: step at which the multiplier reaches 0; defaults to ``NUM_EPOCHS``.

    Returns:
        A float in [0.0, 1.0]. If the decay window is empty (``total_epochs <=
        decay_start``), returns 1.0 instead of dividing by zero.
    """
    if decay_start is None:
        decay_start = DECAY_EPOCH_START
    if total_epochs is None:
        total_epochs = NUM_EPOCHS

    decay_span = total_epochs - decay_start
    if decay_span <= 0:
        return 1.0  # no decay window inside the run: keep the constant LR

    elapsed = max(0, epoch - decay_start)
    return max(0.0, 1.0 - elapsed / float(decay_span))

# One LambdaLR per optimiser, all following the same linear-decay rule.
scheduler_gen, scheduler_disc_A, scheduler_disc_B = (
    torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda_rule)
    for opt in (optimizer_gen, optimizer_disc_A, optimizer_disc_B)
)

# Epoch to start from; replaced by the checkpoint's epoch when resuming.
start_epoch = 0

# AMP loss scaler, starting from 2**10 instead of the library default.
scaler = GradScaler(init_scale=2.0 ** 10)

# Histories of generated H&E (A) and IHC (B) images fed to the critics.
fake_A_buffer, fake_B_buffer = ReplayBuffer(), ReplayBuffer()

# Load checkpoint if resuming training
if LOAD_MODEL:
    latest_checkpoint_path = os.path.join(CHECKPOINT_DIR, "checkpoint_latest.pth")
    try:
        checkpoint = torch.load(latest_checkpoint_path, map_location=DEVICE)
    except FileNotFoundError:
        checkpoint = None
        print("No latest checkpoint found. Starting training from scratch.")
    except Exception as e:
        checkpoint = None
        print(f"Error loading models: {e}")
        print("Starting training from scratch.")

    if checkpoint is not None:
        try:
            gen_A.load_state_dict(checkpoint['gen_A_state_dict'])
            gen_B.load_state_dict(checkpoint['gen_B_state_dict'])
            disc_A.load_state_dict(checkpoint['disc_A_state_dict'])
            disc_B.load_state_dict(checkpoint['disc_B_state_dict'])

            optimizer_gen.load_state_dict(checkpoint['optimizer_gen_state_dict'])
            optimizer_disc_A.load_state_dict(checkpoint['optimizer_disc_A_state_dict'])
            optimizer_disc_B.load_state_dict(checkpoint['optimizer_disc_B_state_dict'])

            # Restore the LR schedules and AMP scale that the checkpoint code saves.
            # Without this, a resumed run restarts the linear decay from scheduler
            # step 0, so the LR jumps back to its full value.
            for state_key, scheduler in (
                ('scheduler_gen_state_dict', scheduler_gen),
                ('scheduler_disc_A_state_dict', scheduler_disc_A),
                ('scheduler_disc_B_state_dict', scheduler_disc_B),
            ):
                if state_key in checkpoint:
                    scheduler.load_state_dict(checkpoint[state_key])
                else:
                    print(f"Warning: checkpoint has no '{state_key}'; its LR schedule restarts from step 0.")
            if 'scaler_state_dict' in checkpoint:
                scaler.load_state_dict(checkpoint['scaler_state_dict'])

            start_epoch = checkpoint['epoch'] # Get the epoch to resume from
            print(f"Loaded models and optimizers. Resuming training from epoch {start_epoch + 1}/{NUM_EPOCHS}")
        except Exception as e:
            # NOTE: modules loaded before the failure keep the checkpoint weights.
            print(f"Error loading models: {e}")
            print("Starting training from scratch.") # Fallback to scratch if loading fails

# Training loop
# Per-batch loss names, in the column order of the CSV log and the epoch table.
_LOSS_KEYS = (
    "D_A_loss", "D_B_loss", "G_total_loss",
    "G_adv_A", "G_adv_B",
    "Cycle_A", "Cycle_B",
    "Identity_A", "Identity_B",
    "Percept_A", "Percept_B",
)
_CSV_HEADER = (
    "Epoch,Avg_D_A_Loss,Avg_D_B_Loss,Avg_G_Total_Loss,Avg_G_Adv_A,Avg_G_Adv_B,"
    "Avg_Cycle_A,Avg_Cycle_B,Avg_Identity_A,Avg_Identity_B,Avg_Percept_A,Avg_Percept_B\n"
)
# Both generators' parameters, hoisted once for gradient clipping.
gen_params = list(gen_A.parameters()) + list(gen_B.parameters())


def _critic_step(critic, optimizer, real, fake_pool):
    """Run one WGAN-GP update of ``critic`` and return its loss (adversarial + GP) as a float.

    The adversarial term and the gradient penalty are back-propagated together in
    one scaled backward pass. That way the adversarial term counts once, and every
    gradient is unscaled exactly once by ``scaler.unscale_`` before clipping.
    Calling unscale_ between two backward passes would leave the second pass's
    gradients multiplied by the loss scale.
    """
    optimizer.zero_grad()
    with autocast(DEVICE):
        disc_real_pred = critic(real)
        disc_fake_pred = critic(fake_pool)
    # The gradient penalty (with its double-backward) is computed outside autocast.
    loss = discriminator_loss_wgan_gp(
        disc_real_pred, disc_fake_pred, real, fake_pool, critic, LAMBDA_GP, DEVICE
    )
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # clip on the true (unscaled) gradient norm
    torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=5.0)
    scaler.step(optimizer)
    return loss.item()


def _denorm(t):
    """Map a tensor from the [-1, 1] training range to [0, 1] for save_image."""
    return (t * 0.5 + 0.5).clamp(0, 1)


print("\nStarting CycleGAN training...")
# Loop from start_epoch to NUM_EPOCHS
for epoch in range(start_epoch, NUM_EPOCHS):
    gen_A.train()
    gen_B.train()
    disc_A.train()
    disc_B.train()

    loop = tqdm(train_loader, leave=True, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    epoch_sums = dict.fromkeys(_LOSS_KEYS, 0.0)

    for real_A, real_B in loop:
        real_A = real_A.to(DEVICE) # H&E
        real_B = real_B.to(DEVICE) # IHC

        # --- Critic D_B (IHC): real IHC vs. replayed fake IHC ---
        # The critic step needs no generator graph, so build the fake without autograd.
        with torch.no_grad(), autocast(DEVICE):
            fake_B = gen_A(real_A)
        fake_B_pool = fake_B_buffer.push_and_pop(fake_B)
        loss_disc_B_total = _critic_step(disc_B, optimizer_disc_B, real_B, fake_B_pool)

        # --- Critic D_A (H&E): real H&E vs. replayed fake H&E ---
        with torch.no_grad(), autocast(DEVICE):
            fake_A = gen_B(real_B)
        fake_A_pool = fake_A_buffer.push_and_pop(fake_A)
        loss_disc_A_total = _critic_step(disc_A, optimizer_disc_A, real_A, fake_A_pool)

        # --- Train Generators G_A and G_B ---
        optimizer_gen.zero_grad()
        with autocast(DEVICE):
            fake_B = gen_A(real_A)
            fake_A = gen_B(real_B)

            # Adversarial loss: G_A tries to fool D_B, G_B tries to fool D_A
            loss_gen_adv_A_val = generator_loss_wgan_gp(disc_B(fake_B))
            loss_gen_adv_B_val = generator_loss_wgan_gp(disc_A(fake_A))

            # Cycle consistency (A -> B -> A and B -> A -> B)
            cycled_A = gen_B(fake_B)
            loss_cycle_A_val = criterion_Cycle(cycled_A, real_A)
            cycled_B = gen_A(fake_A)
            loss_cycle_B_val = criterion_Cycle(cycled_B, real_B)

            # Identity: each generator should leave images of its target domain unchanged
            identity_B = gen_A(real_B)
            loss_identity_B_val = criterion_Identity(identity_B, real_B)
            identity_A = gen_B(real_A)
            loss_identity_A_val = criterion_Identity(identity_A, real_A)

            # Perceptual loss against the batch's real image of the target domain
            loss_perceptual_A_val = criterion_Perceptual(fake_B.float(), real_B)
            loss_perceptual_B_val = criterion_Perceptual(fake_A.float(), real_A)

            loss_gen_total = (
                loss_gen_adv_A_val
                + loss_gen_adv_B_val
                + LAMBDA_CYCLE * loss_cycle_A_val
                + LAMBDA_CYCLE * loss_cycle_B_val
                + LAMBDA_IDENTITY * loss_identity_A_val
                + LAMBDA_IDENTITY * loss_identity_B_val
                + LAMBDA_PERCEPTUAL * loss_perceptual_A_val
                + LAMBDA_PERCEPTUAL * loss_perceptual_B_val
            )

        scaler.scale(loss_gen_total).backward()
        # Unscale before clipping; otherwise max_norm applies to loss-scaled gradients.
        scaler.unscale_(optimizer_gen)
        torch.nn.utils.clip_grad_norm_(gen_params, max_norm=5.0)
        scaler.step(optimizer_gen)
        # One update per iteration, after all three optimizers have stepped.
        scaler.update()

        # --- Update Averages and Progress Bar ---
        batch_losses = {
            "D_A_loss": loss_disc_A_total,
            "D_B_loss": loss_disc_B_total,
            "G_total_loss": loss_gen_total.item(),
            "G_adv_A": loss_gen_adv_A_val.item(),
            "G_adv_B": loss_gen_adv_B_val.item(),
            "Cycle_A": loss_cycle_A_val.item(),
            "Cycle_B": loss_cycle_B_val.item(),
            "Identity_A": loss_identity_A_val.item(),
            "Identity_B": loss_identity_B_val.item(),
            "Percept_A": loss_perceptual_A_val.item(),
            "Percept_B": loss_perceptual_B_val.item(),
        }
        for key, value in batch_losses.items():
            epoch_sums[key] += value
        loop.set_postfix(**batch_losses)

    # --- End of Epoch ---
    num_batches = max(len(train_loader), 1)  # guard against an empty loader
    avg = {key: total / num_batches for key, total in epoch_sums.items()}

    # Step learning rate schedulers
    scheduler_gen.step()
    scheduler_disc_A.step()
    scheduler_disc_B.step()

    # Append loss history to CSV. Write the header only into a new/empty file, so
    # resumed runs do not insert a duplicate header line mid-file.
    write_header = (not os.path.exists(TRAINING_LOG_PATH)) or os.path.getsize(TRAINING_LOG_PATH) == 0
    with open(TRAINING_LOG_PATH, 'a') as f:
        if write_header:
            f.write(_CSV_HEADER)
        f.write(f"{epoch+1}," + ",".join(f"{avg[key]:.6f}" for key in _LOSS_KEYS) + "\n")

    # Print formatted table every 5 epochs (and on the first epoch of this run)
    if (epoch + 1) % 5 == 0 or epoch == start_epoch:
        if epoch == start_epoch: # Print header only on the first epoch of this run
            print("\n" + "="*160)
            print(f"{'Epoch':<8} | {'D_A Loss':<10} | {'D_B Loss':<10} | {'G_Total Loss':<14} | {'G_Adv_A':<10} | {'G_Adv_B':<10} | {'Cycle_A':<10} | {'Cycle_B':<10} | {'Identity_A':<12} | {'Identity_B':<12} | {'Percept_A':<11} | {'Percept_B':<11}")
            print("="*160)
        print(f"{epoch+1:<8} | {avg['D_A_loss']:<10.4f} | {avg['D_B_loss']:<10.4f} | {avg['G_total_loss']:<14.4f} | {avg['G_adv_A']:<10.4f} | {avg['G_adv_B']:<10.4f} | {avg['Cycle_A']:<10.4f} | {avg['Cycle_B']:<10.4f} | {avg['Identity_A']:<12.4f} | {avg['Identity_B']:<12.4f} | {avg['Percept_A']:<11.4f} | {avg['Percept_B']:<11.4f}")

    # Save checkpoint
    if SAVE_MODEL and (epoch + 1) % 10 == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'gen_A_state_dict': gen_A.state_dict(),
            'gen_B_state_dict': gen_B.state_dict(),
            'disc_A_state_dict': disc_A.state_dict(),
            'disc_B_state_dict': disc_B.state_dict(),
            'optimizer_gen_state_dict': optimizer_gen.state_dict(),
            'optimizer_disc_A_state_dict': optimizer_disc_A.state_dict(),
            'optimizer_disc_B_state_dict': optimizer_disc_B.state_dict(),
            'scheduler_gen_state_dict': scheduler_gen.state_dict(),
            'scheduler_disc_A_state_dict': scheduler_disc_A.state_dict(),
            'scheduler_disc_B_state_dict': scheduler_disc_B.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
        }
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{epoch+1}.pth"))
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, "checkpoint_latest.pth"))
        print(f"Models, optimizers, schedulers, and scaler saved at epoch {epoch+1}")

    # Save sample images from the first validation batch
    gen_A.eval() # G_A: H&E -> IHC
    gen_B.eval() # G_B: IHC -> H&E
    with torch.no_grad():
        val_batch = next(iter(val_loader), None)
        if val_batch is not None:
            val_A, val_B = (t.to(DEVICE) for t in val_batch)

            with autocast(DEVICE):
                fake_B_val = gen_A(val_A)           # H&E -> fake IHC
                cycled_A_val = gen_B(fake_B_val)    # ... -> back to H&E
                fake_A_val = gen_B(val_B)           # IHC -> fake H&E
                cycled_B_val = gen_A(fake_A_val)    # ... -> back to IHC

            sample_files = {
                "real_A_he_sample.png": val_A,
                "real_B_ihc_sample.png": val_B,
                f"generated_B_ihc_epoch_{epoch+1}.png": fake_B_val,
                f"cycled_A_he_epoch_{epoch+1}.png": cycled_A_val,
                f"generated_A_he_epoch_{epoch+1}.png": fake_A_val,
                f"cycled_B_ihc_epoch_{epoch+1}.png": cycled_B_val,
            }
            for filename, tensor in sample_files.items():
                save_image(_denorm(tensor), os.path.join(SAMPLE_IMAGES_DIR, filename))

    gen_A.train()
    gen_B.train() # Set back to train mode

print("\nCycleGAN training complete!")


In [None]:
# ===== Cell H1: metrics libs + imports =====
!pip -q install torch-fidelity==0.3.0 lpips==0.1.4 piq==0.8.0

import os, json, random
from pathlib import Path
from PIL import Image

import torch
import torchvision.transforms as T
from torchvision.utils import save_image
import pandas as pd
import numpy as np

from torch_fidelity import calculate_metrics  # FID/KID
import lpips                                   # LPIPS
import piq                                     # SSIM/PSNR

print("Ready on:", DEVICE)


In [None]:
# ===== Cell H3: Eval config & pairing =====
# Evaluate on the same validation split that the vanilla CycleGAN uses.
_SPLIT_ROOT = '/content/drive/MyDrive/HER2/TrainValAB'
VAL_DIR_A = f'{_SPLIT_ROOT}/valA'   # domain A inputs
VAL_DIR_B = f'{_SPLIT_ROOT}/valB'   # domain B references

# Tag naming the output folders of this evaluation run.
EVAL_TAG = "hybrid_epoch_200"
OUT_ROOT = "/content/drive/MyDrive/HER2/hybrid_eval"
_out_root = Path(OUT_ROOT)
GEN_DIR = _out_root / f"eval_{EVAL_TAG}_A2B"   # generated A->B images
LOG_DIR = _out_root / f"metrics_{EVAL_TAG}"    # metric files
for _folder in (GEN_DIR, LOG_DIR):
    _folder.mkdir(parents=True, exist_ok=True)

# Network input size; keep consistent with training.
IMG_SIZE = 256

def list_images(folder):
    """Return the image files under *folder* (searched recursively), sorted by path.

    A path is kept when its suffix (case-insensitive) is one of .png, .jpg,
    .jpeg, .tif, .tiff or .bmp AND it is a regular file. Directories whose
    names happen to end in an image suffix are skipped so they never reach
    ``Image.open`` (which would fail on them).
    """
    exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')
    return sorted(p for p in Path(folder).rglob('*') if p.suffix.lower() in exts and p.is_file())

paths_A, paths_B = list_images(VAL_DIR_A), list_images(VAL_DIR_B)

# Index B by filename stem (a later duplicate stem overwrites an earlier one),
# then keep each A image that has a same-stem B partner, in sorted A order.
bname_to_B = {}
for _pB in paths_B:
    bname_to_B[Path(_pB).stem] = _pB
pairs = []
for _pA in paths_A:
    _stem = Path(_pA).stem
    if _stem in bname_to_B:
        pairs.append((str(_pA), bname_to_B[_stem]))

print(f"A images: {len(paths_A)} | B images: {len(paths_B)} | Paired matches: {len(pairs)}")

# Network input: bicubic resize -> tensor -> scaled to [-1, 1].
_bicubic = T.InterpolationMode.BICUBIC
to_tensor = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE), interpolation=_bicubic),
    T.ToTensor(),
    T.Normalize((0.5,) * 3, (0.5,) * 3),
])
# Inverse for saving: [-1, 1] tensor -> clamped [0, 1] -> PIL image.
to_pil = T.Compose([T.Lambda(lambda t: (t * 0.5 + 0.5).clamp(0, 1)), T.ToPILImage()])

# Metric input (LPIPS/SSIM/PSNR): both sides resized to the same size, kept in [0, 1].
resize_01 = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE), interpolation=_bicubic, antialias=True),
    T.ToTensor(),
])


In [None]:
# ===== Cell H4: Generate A->B for all valA =====
# Inference mode for the export, as Cell 7 does for the vanilla CycleGAN.
# gen_A_h is created outside this section and its current train/eval mode is
# not visible here; without .eval() any dropout / batch-norm layers it has
# would run with training behaviour during the export.
gen_A_h.eval()
with torch.no_grad():
    for pA in paths_A:
        out_path = GEN_DIR / (Path(pA).stem + ".png")
        # Existing outputs are reused, not regenerated. After loading a
        # different checkpoint, change EVAL_TAG or clear GEN_DIR first.
        if out_path.exists():
            continue
        imgA = Image.open(pA).convert('RGB')
        x = to_tensor(imgA).unsqueeze(0).to(DEVICE)
        y = gen_A_h(x)[0].cpu()            # [-1,1]
        to_pil(y).save(out_path)

print(f"Saved {len(list(GEN_DIR.glob('*.png')))} generated images to {GEN_DIR}")


In [None]:
# ===== Cell H5: FID & KID =====
def _n_fidelity_images(folder):
    """Count the files torch-fidelity picks up from `folder` with its default
    discovery settings. NOTE(review): assumes the defaults are top level only
    and case-sensitive *.png/*.jpg/*.jpeg; confirm against the installed version."""
    return sum(
        1 for p in Path(folder).iterdir()
        if p.is_file() and not p.name.startswith('.') and p.suffix in ('.png', '.jpg', '.jpeg')
    )

# KID is estimated on random subsets of `kid_subset_size` images (default 1000).
# torch-fidelity aborts if either input has fewer images than that, so the
# subset size is capped at the smaller folder. With >= 1000 images on both
# sides this is identical to the default.
fidkid = calculate_metrics(
    input1=str(GEN_DIR),
    input2=str(VAL_DIR_B),
    cuda=torch.cuda.is_available(),
    isc=False, fid=True, kid=True, prc=False, verbose=False,
    kid_subset_size=min(1000, _n_fidelity_images(GEN_DIR), _n_fidelity_images(VAL_DIR_B)),
)
with open(LOG_DIR / "fid_kid.json", "w") as f:
    json.dump(fidkid, f, indent=2)
print(json.dumps(fidkid, indent=2))


In [None]:
# ===== Cell H6: LPIPS / SSIM / PSNR on paired matches =====
lpips_fn = lpips.LPIPS(net='vgg').to(DEVICE).eval()

rows = []
with torch.no_grad():
    for pA, pB in pairs:
        gen_path = GEN_DIR / (Path(pA).stem + ".png")
        if not gen_path.exists():
            continue

        # Resize both sides to the same eval size (256)
        G01 = resize_01(Image.open(gen_path).convert('RGB')).to(DEVICE).unsqueeze(0)
        B01 = resize_01(Image.open(pB).convert('RGB')).to(DEVICE).unsqueeze(0)

        # LPIPS expects [-1,1]
        Gm1p1, Bm1p1 = G01*2-1, B01*2-1
        lp = lpips_fn(Gm1p1, Bm1p1).item()

        ssim = piq.ssim(G01, B01, data_range=1.0).item()
        psnr = piq.psnr(G01, B01, data_range=1.0).item()

        rows.append({"basename": Path(pA).stem, "lpips": lp, "ssim": ssim, "psnr": psnr})

# Use explicit columns so the frame stays well-formed when nothing matched:
# a column-less pd.DataFrame([]) makes sort_values("basename") raise KeyError.
df = pd.DataFrame(rows, columns=["basename", "lpips", "ssim", "psnr"]).sort_values("basename")
df.to_csv(LOG_DIR / "paired_metrics.csv", index=False)

# Mean and population std (ddof=0) per metric. With no rows they are null
# instead of NaN, because NaN is not valid JSON.
_has_rows = len(df) > 0
summary = {"tag": EVAL_TAG, "N_pairs": int(len(df))}
for _col, _label in (("lpips", "LPIPS"), ("ssim", "SSIM"), ("psnr", "PSNR")):
    summary[f"{_label}_mean"] = float(df[_col].mean()) if _has_rows else None
    summary[f"{_label}_std"] = float(df[_col].std(ddof=0)) if _has_rows else None
if not _has_rows:
    print("[WARN] No generated images matched a paired B image; metrics are null.")

with open(LOG_DIR / "paired_metrics_summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print(json.dumps(summary, indent=2))


In [None]:
# ===== Cell H7: Triplet grid for the thesis =====
import matplotlib.pyplot as plt
random.seed(42)

N = 12  # maximum number of triplets (rows) to plot
sample_pairs = pairs[:]
random.shuffle(sample_pairs)
# Keep only pairs whose A->B translation exists in GEN_DIR, then cap at N rows.
# A missing file would otherwise abort the figure with FileNotFoundError.
sample_pairs = [
    (pA, pB) for pA, pB in sample_pairs
    if (GEN_DIR / (Path(pA).stem + ".png")).exists()
][:N]
if not sample_pairs:
    raise RuntimeError(f"No generated images in {GEN_DIR} match the paired validation files.")

# One row per triplet, so no blank rows when fewer than N pairs exist.
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single row.
ncols, nrows = 3, len(sample_pairs)
fig, axes = plt.subplots(nrows, ncols, figsize=(9, 3*nrows), dpi=150, squeeze=False)

for i, (pA, pB) in enumerate(sample_pairs):
    gen = GEN_DIR / (Path(pA).stem + ".png")
    A = Image.open(pA).convert('RGB')
    B = Image.open(pB).convert('RGB')
    G = Image.open(gen).convert('RGB')

    axes[i,0].imshow(A); axes[i,0].set_title("H&E (A)"); axes[i,0].axis('off')
    axes[i,1].imshow(B); axes[i,1].set_title("IHC – Expected (B)"); axes[i,1].axis('off')
    axes[i,2].imshow(G); axes[i,2].set_title("IHC – Generated (A→B)"); axes[i,2].axis('off')

plt.tight_layout()
FIG_PATH = Path(OUT_ROOT) / f"eval_{EVAL_TAG}_A-B-G_triplets.png"
plt.savefig(FIG_PATH, bbox_inches='tight')
print("Saved figure:", FIG_PATH)


# Plain CycleGAN

In [None]:
# ===== Cell 1: Imports and Global Configuration =====
import os, random, math
from pathlib import Path
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from PIL import Image

# (Optional) Pretty loss tables
try:
    import pandas as pd
except Exception:
    pd = None

# --- Mount Google Drive (Colab) ---
from google.colab import drive
drive.mount('/content/drive')

# --- Configuration ---
@dataclass
class Config:
    """Paths and hyper-parameters for the vanilla (ResNet-9) CycleGAN experiment.

    Field order is significant. It is the positional-constructor order and the
    key order of ``asdict(cfg)``, which is printed and stored in checkpoints.
    """

    # Dataset folders (pix2pix-style layout: trainA / trainB / valA / valB).
    TRAIN_DIR_A: str = '/content/drive/MyDrive/HER2/TrainValAB/trainA'
    TRAIN_DIR_B: str = '/content/drive/MyDrive/HER2/TrainValAB/trainB'
    VAL_DIR_A: str = '/content/drive/MyDrive/HER2/TrainValAB/valA'
    VAL_DIR_B: str = '/content/drive/MyDrive/HER2/TrainValAB/valB'

    # Output root for this experiment, and how often (in epochs) to write a
    # checkpoint and a visual sample.
    EXP_ROOT: str = '/content/drive/MyDrive/HER2/cyclegan_vanilla'
    SAVE_EVERY: int = 10
    SAMPLE_EVERY: int = 10

    # Schedule, input size, seed and device.
    EPOCHS: int = 200
    BATCH_SIZE: int = 1
    IMG_SIZE: int = 256
    SEED: int = 42
    # Resolved once, when the class body executes: CUDA, else Apple MPS, else CPU.
    DEVICE: str = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

    # Adam settings; LR is held until LR_DECAY_START and then decays linearly.
    LR: float = 2e-4
    BETA1: float = 0.5
    BETA2: float = 0.999
    LR_DECAY_START: int = 100

    # Loss weights: cycle consistency, and identity at half the cycle weight.
    LAMBDA_CYCLE: float = 10.0
    LAMBDA_ID: float = 5.0

cfg = Config()

# Output layout: <EXP_ROOT>/{checkpoints, samples, logs}, created if missing.
exp_dir = Path(cfg.EXP_ROOT)
ckpt_dir, samples_dir, logs_dir = (exp_dir / name for name in ('checkpoints', 'samples', 'logs'))
for d in (exp_dir, ckpt_dir, samples_dir, logs_dir):
    d.mkdir(parents=True, exist_ok=True)

# Seed Python's and torch's RNGs (and every CUDA device when one is present).
random.seed(cfg.SEED)
torch.manual_seed(cfg.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.SEED)

# Input pipeline: bicubic resize to IMG_SIZE x IMG_SIZE, to tensor, scale to [-1, 1].
_norm_stats = [0.5] * 3
transform = transforms.Compose([
    transforms.Resize((cfg.IMG_SIZE,) * 2, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_stats, std=_norm_stats),
])

print("Config:", asdict(cfg))
print("Outputs ->", str(exp_dir))
print("Device:", cfg.DEVICE)


In [None]:
# ===== Cell 2: Dataset Class and DataLoader =====
class UnpairedAB(Dataset):
    """Two-domain image dataset used for CycleGAN training and validation.

    Takes two folders (e.g. trainA/ + trainB/ or valA/ + valB/). Image files
    (.png/.jpg/.jpeg/.tif/.tiff/.bmp, any case) are collected recursively and sorted.

    Item ``idx`` is ``(tensor_A, tensor_B, name_A, name_B)``:
      * A is ``paths_A[idx % len(paths_A)]``;
      * B is the first B file (in sorted order) whose filename stem equals A's,
        or a uniformly random B file when none matches.
    This stem matching applies to every item, including during training. If
    the A/B filenames coincide, the model therefore sees the matched images,
    not random ones.

    Both images are bicubic-resized to ``cfg.IMG_SIZE`` and normalised to
    [-1, 1]. With ``augment=True``, random horizontal/vertical flips are drawn
    independently for A and B.
    """
    def __init__(self, dir_A, dir_B, augment=False):
        self.dir_A = Path(dir_A)
        self.dir_B = Path(dir_B)
        exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')
        self.paths_A = sorted([p for p in self.dir_A.rglob('*') if p.suffix.lower() in exts])
        self.paths_B = sorted([p for p in self.dir_B.rglob('*') if p.suffix.lower() in exts])
        if len(self.paths_A) == 0 or len(self.paths_B) == 0:
            raise RuntimeError(f"No images found in:\n{self.dir_A}\n{self.dir_B}")

        # stem -> first B path with that stem (sorted order). This gives an O(1)
        # lookup per item instead of scanning all of paths_B in __getitem__.
        self._B_by_stem = {}
        for p in self.paths_B:
            self._B_by_stem.setdefault(p.stem, p)

        # Basic aug (optional)
        t = [transforms.Resize((cfg.IMG_SIZE, cfg.IMG_SIZE), interpolation=transforms.InterpolationMode.BICUBIC)]
        if augment:
            t.append(transforms.RandomHorizontalFlip())
            t.append(transforms.RandomVerticalFlip())
        t += [transforms.ToTensor(), transforms.Normalize((0.5,)*3, (0.5,)*3)]
        self.tf = transforms.Compose(t)

    def __len__(self):
        # Length of the larger domain; the A index wraps with a modulo.
        return max(len(self.paths_A), len(self.paths_B))

    def __getitem__(self, idx):
        path_A = self.paths_A[idx % len(self.paths_A)]
        # Same-stem B if one exists, else a random B. The RNG is only consumed
        # in the fallback case, exactly as before.
        path_B = self._B_by_stem.get(path_A.stem)
        if path_B is None:
            path_B = self.paths_B[random.randint(0, len(self.paths_B)-1)]

        img_A = Image.open(path_A).convert('RGB')
        img_B = Image.open(path_B).convert('RGB')
        return self.tf(img_A), self.tf(img_B), path_A.name, path_B.name

# DataLoaders: random flips for training only.
train_ds = UnpairedAB(cfg.TRAIN_DIR_A, cfg.TRAIN_DIR_B, augment=True)
val_ds = UnpairedAB(cfg.VAL_DIR_A, cfg.VAL_DIR_B, augment=False)

# Both loaders shuffle. For validation this matters because visualize_epoch
# takes next(iter(val_loader)), i.e. the first batch of a freshly shuffled pass.
_shared_loader_kwargs = dict(shuffle=True, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, batch_size=cfg.BATCH_SIZE, **_shared_loader_kwargs)
val_loader = DataLoader(val_ds, batch_size=1, **_shared_loader_kwargs)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

# Helpers for visualization.
# Normalize computes (x - mean) / std = (x + 1) / 2, mapping [-1, 1] to [0, 1]
# (without clamping).
inv_norm = transforms.Normalize(mean=[-1] * 3, std=[2] * 3)
def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1], clamping anything outside that range."""
    return x.mul(0.5).add(0.5).clamp(0, 1)


In [None]:
# ===== Cell 3: Model Definitions (Generators & Discriminators) =====
class ResnetBlock(nn.Module):
    """Residual block: x + [pad-conv3x3-IN-ReLU-pad-conv3x3-IN](x), channel count unchanged.

    Layer order inside ``self.block`` is kept fixed so checkpoint keys
    (block.1.*, block.5.*) stay loadable.
    """
    def __init__(self, dim):
        super().__init__()

        def pad_conv_norm():
            # Reflection padding keeps the spatial size for the 3x3 convolution.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, 3, padding=0),
                nn.InstanceNorm2d(dim, affine=False),
            ]

        self.block = nn.Sequential(*pad_conv_norm(), nn.ReLU(inplace=True), *pad_conv_norm())

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class ResnetGenerator(nn.Module):
    """ResNet CycleGAN generator.

    The network runs: 7x7 stem, two stride-2 downsamplings (channels x2 each),
    `n_blocks` ResnetBlocks, then two transposed-conv upsamplings (channels /2
    each). A 7x7 conv and Tanh at the end put the output in [-1, 1]. Layers
    live in ``self.model`` in this exact order so saved checkpoints keep loading.
    """
    def __init__(self, in_c=3, out_c=3, n_filters=64, n_blocks=9):
        super().__init__()
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_c, n_filters, 7, padding=0),
            nn.InstanceNorm2d(n_filters, affine=False),
            nn.ReLU(inplace=True),
        ]
        width = n_filters

        # Encoder: halve the resolution and double the channels, twice.
        for _ in range(2):
            layers.extend((
                nn.Conv2d(width, width * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2, affine=False),
                nn.ReLU(inplace=True),
            ))
            width *= 2

        # Bottleneck residual stack at the lowest resolution.
        layers.extend(ResnetBlock(width) for _ in range(n_blocks))

        # Decoder: double the resolution and halve the channels, twice.
        for _ in range(2):
            layers.extend((
                nn.ConvTranspose2d(width, width // 2, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(width // 2, affine=False),
                nn.ReLU(inplace=True),
            ))
            width //= 2

        layers.extend((nn.ReflectionPad2d(3), nn.Conv2d(width, out_c, 7, padding=0), nn.Tanh()))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class PatchDiscriminator(nn.Module):
    """PatchGAN critic returning a map of raw (un-squashed) per-patch scores.

    It has three stride-2 4x4 conv stages (64 -> 128 -> 256 channels by default,
    InstanceNorm on all but the first, LeakyReLU 0.2), then a stride-1 4x4 conv
    down to one channel. Layer order in ``self.model`` is fixed for checkpoint
    compatibility.
    """
    def __init__(self, in_c=3, n_filters=64):
        super().__init__()
        widths = [in_c, n_filters, n_filters * 2, n_filters * 4]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if stage > 0:  # the first stage is left un-normalised
                layers.append(nn.InstanceNorm2d(c_out, affine=False))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(n_filters * 4, 1, 4, stride=1, padding=1))  # patch scores
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Instantiate models. Construction order is unchanged because it decides which
# draws from the seeded torch RNG each network's initial weights receive.
G_A2B = ResnetGenerator().to(cfg.DEVICE)   # A -> B translator
G_B2A = ResnetGenerator().to(cfg.DEVICE)   # B -> A translator
D_A = PatchDiscriminator().to(cfg.DEVICE)  # scores real vs. generated A
D_B = PatchDiscriminator().to(cfg.DEVICE)  # scores real vs. generated B

# Optimizers: one Adam over both generators, one per discriminator.
_adam_kwargs = dict(lr=cfg.LR, betas=(cfg.BETA1, cfg.BETA2))
opt_G = torch.optim.Adam([*G_A2B.parameters(), *G_B2A.parameters()], **_adam_kwargs)
opt_D_A = torch.optim.Adam(D_A.parameters(), **_adam_kwargs)
opt_D_B = torch.optim.Adam(D_B.parameters(), **_adam_kwargs)

# Schedulers: linear decay after LR_DECAY_START
def lambda_rule(epoch):
    """LambdaLR multiplier: 1.0 before ``cfg.LR_DECAY_START``, then linear decay.

    After the decay start the factor drops by 1 / (EPOCHS - LR_DECAY_START)
    per scheduler step; the denominator is floored at 1.
    """
    n_decay_steps = float(max(1, cfg.EPOCHS - cfg.LR_DECAY_START))
    steps_into_decay = epoch - cfg.LR_DECAY_START
    return 1.0 if steps_into_decay < 0 else 1.0 - steps_into_decay / n_decay_steps

# One LambdaLR per optimizer, all sharing the same linear-decay rule.
sch_G, sch_D_A, sch_D_B = (
    torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda_rule)
    for opt in (opt_G, opt_D_A, opt_D_B)
)

# Losses: least-squares adversarial loss (MSE) and L1 for cycle / identity terms.
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()

# Image pools (stabilize D)
class ImagePool:
    """History buffer of generated images used to stabilise discriminator updates.

    ``query(images)`` iterates over the first dimension of ``images``. While the
    buffer holds fewer than ``pool_size`` entries, each detached image is stored
    and returned as-is. Once full, each image is, with 50% probability, swapped
    into a random slot and the previously stored image is returned instead;
    otherwise it passes straight through. The results are stacked along dim 0.
    With ``pool_size == 0`` the input is returned unchanged (not detached).
    """
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def _pick(self, img):
        # Fill phase: keep every image until the buffer is full.
        if len(self.images) < self.pool_size:
            self.images.append(img)
            return img
        # Full buffer: coin flip between passing through and swapping with history.
        if random.random() <= 0.5:
            return img
        slot = random.randint(0, self.pool_size - 1)
        previous = self.images[slot].clone()
        self.images[slot] = img
        return previous

    def query(self, images):
        if self.pool_size == 0:
            return images
        picked = [self._pick(img.detach()) for img in images]
        return torch.stack(picked, dim=0)

# One history buffer per domain for the discriminator updates.
pool_A, pool_B = ImagePool(50), ImagePool(50)

# Resume from checkpoints/latest.pt when present; otherwise start at epoch 1.
start_epoch = 1
latest = ckpt_dir / 'latest.pt'
if latest.exists():
    print("Resuming from", latest)
    s = torch.load(latest, map_location=cfg.DEVICE)
    # Restored in the same order as the state was saved: nets, optimizers, schedulers.
    _restorable = (
        ('G_A2B', G_A2B), ('G_B2A', G_B2A), ('D_A', D_A), ('D_B', D_B),
        ('opt_G', opt_G), ('opt_D_A', opt_D_A), ('opt_D_B', opt_D_B),
        ('sch_G', sch_G), ('sch_D_A', sch_D_A), ('sch_D_B', sch_D_B),
    )
    for _key, _component in _restorable:
        _component.load_state_dict(s[_key])
    start_epoch = s['epoch'] + 1
    print("Start epoch:", start_epoch)


In [None]:
# ===== Cell 4 (REPLACED): Training Loop for CycleGAN with tqdm progress bar =====
from tqdm import tqdm
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

def save_checkpoint(epoch):
    """Save the full training state for `epoch`.

    Writes ``checkpoints/epoch_{epoch:03d}.pt`` and refreshes
    ``checkpoints/latest.pt``, which the resume logic reads. The payload holds
    the epoch number, the state dicts of both generators, both discriminators,
    the three optimizers and the three schedulers, plus ``asdict(cfg)``.

    ``latest.pt`` is first written to ``latest.pt.tmp`` and then moved over the
    old file with ``os.replace``. If a save is interrupted (e.g. a Colab
    disconnect), the previous ``latest.pt`` stays intact instead of being left
    truncated, which would break resuming.
    """
    obj = {
        'epoch': epoch,
        'G_A2B': G_A2B.state_dict(),
        'G_B2A': G_B2A.state_dict(),
        'D_A': D_A.state_dict(),
        'D_B': D_B.state_dict(),
        'opt_G': opt_G.state_dict(),
        'opt_D_A': opt_D_A.state_dict(),
        'opt_D_B': opt_D_B.state_dict(),
        'sch_G': sch_G.state_dict(),
        'sch_D_A': sch_D_A.state_dict(),
        'sch_D_B': sch_D_B.state_dict(),
        'cfg': asdict(cfg),
    }
    path = ckpt_dir / f'epoch_{epoch:03d}.pt'
    torch.save(obj, path)

    tmp_latest = ckpt_dir / 'latest.pt.tmp'
    torch.save(obj, tmp_latest)
    os.replace(tmp_latest, ckpt_dir / 'latest.pt')
    print(f'[Checkpoint] Saved {path}')

@torch.no_grad()
def visualize_epoch(epoch):
    """Save and show a validation sample grid for `epoch`.

    Draws one batch from a fresh ``val_loader`` iterator, translates A -> B with
    ``G_A2B`` and builds a grid with one row per sample and three columns:
    real A | real B | G_A2B(A). The grid is saved to
    ``samples/epoch_{epoch:03d}.png`` and displayed inline.

    Both generators are switched to eval mode and left there; the training
    loop calls ``.train()`` at the start of every epoch.
    """
    G_A2B.eval(); G_B2A.eval()
    try:
        img_A, img_B, name_A, name_B = next(iter(val_loader))
    except StopIteration:
        return
    img_A = img_A.to(cfg.DEVICE); img_B = img_B.to(cfg.DEVICE)
    fake_B = G_A2B(img_A)

    # UnpairedAB matches files by Path.stem, so compare stems the same way
    # (str.split('.')[0] mis-handles names containing extra dots).
    paired = Path(name_A[0]).stem == Path(name_B[0]).stem
    title_B = "Expected (paired B)" if paired else "Reference (real B, unpaired)"

    # Interleave as A0, B0, G0, A1, B1, G1, ... with nrow=3 so each sample is one
    # row whose columns match the "Left / Middle / Right" title. Concatenating
    # per category with nrow=batch_size would stack A, B and G vertically.
    triplets = torch.stack([denorm(img_A), denorm(img_B), denorm(fake_B)], dim=1).flatten(0, 1)
    grid = make_grid(triplets, nrow=3)
    out_path = samples_dir / f'epoch_{epoch:03d}.png'
    save_image(grid, out_path)
    print(f"[Sample] Saved {out_path}")

    plt.figure(figsize=(12,4))
    plt.imshow(grid.permute(1,2,0).cpu().numpy()); plt.axis('off')
    plt.title(f"Epoch {epoch} — Left: Actual (A), Middle: {title_B}, Right: Generated (A→B)")
    plt.show()

class LossBook:
    """Per-epoch loss log.

    ``add(epoch, **losses)`` appends the row ``{'epoch': epoch, name: float(value), ...}``.
    ``table()`` returns all rows as a pandas DataFrame, or None when pandas is
    unavailable (``pd is None``).
    """
    def __init__(self):
        self.rows = []

    def add(self, epoch, **kw):
        converted = {name: float(value) for name, value in kw.items()}
        self.rows.append({'epoch': epoch, **converted})

    def table(self):
        return None if pd is None else pd.DataFrame(self.rows)

# Per-epoch mean losses for this session only. On resume it starts empty;
# earlier epochs are not reloaded here.
lossbook = LossBook()

# NOTE(review): these two constants are not referenced below; the targets are
# built with torch.ones_like / torch.zeros_like instead.
real_label = 1.0
fake_label = 0.0

# Main CycleGAN loop, from start_epoch (1, or the epoch after latest.pt) through cfg.EPOCHS.
for epoch in range(start_epoch, cfg.EPOCHS + 1):
    # visualize_epoch leaves the generators in eval mode; restore train mode each epoch.
    G_A2B.train(); G_B2A.train(); D_A.train(); D_B.train()

    # Running sums of each loss over the epoch's batches; nb counts batches.
    sums = {k:0.0 for k in [
        'G_total','G_adv_A2B','G_adv_B2A','cycle_A','cycle_B','id_A','id_B','D_A','D_B'
    ]}
    nb = 0

    # tqdm progress bar over train batches
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.EPOCHS}", ncols=100, leave=False)
    for imgs_A, imgs_B, _, _ in pbar:
        imgs_A = imgs_A.to(cfg.DEVICE)
        imgs_B = imgs_B.to(cfg.DEVICE)

        # === Generators ===
        opt_G.zero_grad()

        # Identity: G_A2B(B) ≈ B, G_B2A(A) ≈ A
        id_B = G_A2B(imgs_B)
        loss_id_B = recon_criterion(id_B, imgs_B) * cfg.LAMBDA_ID
        id_A = G_B2A(imgs_A)
        loss_id_A = recon_criterion(id_A, imgs_A) * cfg.LAMBDA_ID

        # GAN: A->B (least-squares: push D_B's score on fakes toward 1)
        fake_B = G_A2B(imgs_A)
        pred_fake_B = D_B(fake_B)
        valid_B = torch.ones_like(pred_fake_B, device=cfg.DEVICE)
        loss_G_A2B = adv_criterion(pred_fake_B, valid_B)

        # GAN: B->A
        fake_A = G_B2A(imgs_B)
        pred_fake_A = D_A(fake_A)
        valid_A = torch.ones_like(pred_fake_A, device=cfg.DEVICE)
        loss_G_B2A = adv_criterion(pred_fake_A, valid_A)

        # Cycle
        rec_A = G_B2A(fake_B)
        rec_B = G_A2B(fake_A)
        loss_cyc_A = recon_criterion(rec_A, imgs_A) * cfg.LAMBDA_CYCLE
        loss_cyc_B = recon_criterion(rec_B, imgs_B) * cfg.LAMBDA_CYCLE

        # D_A / D_B are in this graph, so backward() also fills their .grad.
        # Those gradients are discarded by opt_D_A/opt_D_B.zero_grad() below,
        # and only opt_G (generator parameters) steps here.
        loss_G = loss_G_A2B + loss_G_B2A + loss_cyc_A + loss_cyc_B + loss_id_A + loss_id_B
        loss_G.backward()
        opt_G.step()

        # === D_A ===
        opt_D_A.zero_grad()
        pred_real_A = D_A(imgs_A)
        valid = torch.ones_like(pred_real_A, device=cfg.DEVICE)
        loss_D_A_real = adv_criterion(pred_real_A, valid)

        # Fakes come from the history pool (ImagePool.query already detaches them;
        # the extra .detach() is a harmless safeguard).
        fake_A_pool = pool_A.query(fake_A)
        pred_fake_A = D_A(fake_A_pool.detach())
        fake = torch.zeros_like(pred_fake_A, device=cfg.DEVICE)
        loss_D_A_fake = adv_criterion(pred_fake_A, fake)
        # Halve the discriminator loss so D learns more slowly relative to G.
        loss_DA = 0.5*(loss_D_A_real + loss_D_A_fake)
        loss_DA.backward()
        opt_D_A.step()

        # === D_B ===
        opt_D_B.zero_grad()
        pred_real_B = D_B(imgs_B)
        valid = torch.ones_like(pred_real_B, device=cfg.DEVICE)
        loss_D_B_real = adv_criterion(pred_real_B, valid)

        fake_B_pool = pool_B.query(fake_B)
        pred_fake_B = D_B(fake_B_pool.detach())
        fake = torch.zeros_like(pred_fake_B, device=cfg.DEVICE)
        loss_D_B_fake = adv_criterion(pred_fake_B, fake)
        loss_DB = 0.5*(loss_D_B_real + loss_D_B_fake)
        loss_DB.backward()
        opt_D_B.step()

        # Accumulate + update progress bar postfix
        sums['G_total']    += loss_G.item()
        sums['G_adv_A2B']  += loss_G_A2B.item()
        sums['G_adv_B2A']  += loss_G_B2A.item()
        sums['cycle_A']    += loss_cyc_A.item()
        sums['cycle_B']    += loss_cyc_B.item()
        sums['id_A']       += loss_id_A.item()
        sums['id_B']       += loss_id_B.item()
        sums['D_A']        += loss_DA.item()
        sums['D_B']        += loss_DB.item()
        nb += 1

        pbar.set_postfix({
            "G": f"{sums['G_total']/nb:.3f}",
            "D_A": f"{sums['D_A']/nb:.3f}",
            "D_B": f"{sums['D_B']/nb:.3f}",
        })

    # Step schedulers once per epoch, so lambda_rule receives the number of
    # completed scheduler steps (LambdaLR's last_epoch).
    sch_G.step(); sch_D_A.step(); sch_D_B.step()

    # Epoch means (max(1, nb) guards against an empty loader).
    means = {k: v/max(1,nb) for k,v in sums.items()}
    lossbook.add(epoch, **means)

    # End-of-epoch summary line
    print(f"Epoch {epoch:03d}/{cfg.EPOCHS} | "
          f"G:{means['G_total']:.4f} | D_A:{means['D_A']:.4f} D_B:{means['D_B']:.4f} | "
          f"cycA:{means['cycle_A']:.3f} cycB:{means['cycle_B']:.3f} | idA:{means['id_A']:.3f} idB:{means['id_B']:.3f}")

    # Sample + table every SAMPLE_EVERY
    if epoch % cfg.SAMPLE_EVERY == 0:
        visualize_epoch(epoch)
        if pd is not None:
            df = lossbook.table()
            print(df.to_string(index=False))
        else:
            print("[Info] Install pandas for a formatted loss table.")

    # Checkpoint every SAVE_EVERY (and on final epoch)
    if epoch % cfg.SAVE_EVERY == 0 or epoch == cfg.EPOCHS:
        save_checkpoint(epoch)

# Save final CSV and show final sample
if pd is not None:
    df = lossbook.table()
    csv_path = logs_dir / 'epoch_losses.csv'
    # lossbook only holds the epochs run in THIS session (possibly none, when
    # the cell is re-run after training finished). Keep the rows for earlier
    # epochs from an existing CSV instead of overwriting them. Epochs from a
    # session interrupted before reaching this cell were never written and
    # cannot be recovered here.
    if csv_path.exists():
        try:
            prev = pd.read_csv(csv_path)
        except pd.errors.EmptyDataError:
            prev = None
        if prev is not None and 'epoch' in prev.columns:
            prev = prev[prev['epoch'] < start_epoch]
            frames = [f for f in (prev, df) if not f.empty]
            if frames:
                df = pd.concat(frames, ignore_index=True)
    df.to_csv(csv_path, index=False)
    print(f"[Log] Saved loss CSV to {csv_path}")

visualize_epoch(epoch=cfg.EPOCHS)


In [None]:
# ===== Cell 5: Install metrics libs (run once per fresh Colab) =====
!pip -q install torch-fidelity==0.3.0 lpips==0.1.4 piq==0.8.0


In [None]:
from torch_fidelity import calculate_metrics

In [None]:
# ===== Cell 6: Eval config, pairing, helpers =====
import os, glob, random, json
from pathlib import Path
from PIL import Image
import torch
import torchvision.transforms as T
import pandas as pd
import numpy as np

from torch_fidelity import calculate_metrics  # torch-fidelity
import lpips  # perceptual
import piq    # SSIM/PSNR

# --- Split to evaluate: the validation folders by default (repoint for a test split) ---
EVAL_DIR_A, EVAL_DIR_B = cfg.VAL_DIR_A, cfg.VAL_DIR_B

# --- Checkpoint to evaluate ---
# EVAL_TAG names a file in checkpoints/ without the ".pt" suffix; save_checkpoint
# writes epoch_XXX.pt, so "epoch_150" selects checkpoints/epoch_150.pt.
EVAL_TAG = "epoch_150"
GEN_OUT_DIR = samples_dir / f"eval_{EVAL_TAG}_A2B"    # generated A->B images
METRICS_OUTDIR = logs_dir / f"metrics_{EVAL_TAG}"     # metric files
for _out in (GEN_OUT_DIR, METRICS_OUTDIR):
    _out.mkdir(parents=True, exist_ok=True)

# Network input: bicubic resize -> tensor -> [-1, 1]; and the inverse used to save PNGs.
to_tensor = T.Compose([
    T.Resize((cfg.IMG_SIZE,) * 2, interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize((0.5,) * 3, (0.5,) * 3),
])
to_pil = T.Compose([T.Lambda(lambda t: (t * 0.5 + 0.5).clamp(0, 1)), T.ToPILImage()])

# Pair files by basename (only for LPIPS/SSIM/PSNR)
def list_images(folder):
    """Return the image files under *folder* (searched recursively), sorted by path.

    A path is kept when its suffix (case-insensitive) is one of .png, .jpg,
    .jpeg, .tif, .tiff or .bmp AND it is a regular file. Directories whose
    names happen to end in an image suffix are skipped so they never reach
    ``Image.open`` (which would fail on them).
    """
    exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')
    return sorted(p for p in Path(folder).rglob('*') if p.suffix.lower() in exts and p.is_file())

paths_A = list_images(EVAL_DIR_A)
paths_B = list_images(EVAL_DIR_B)
# Stem -> B path (a later duplicate stem overwrites an earlier one).
bname_to_B = {Path(p).stem: p for p in paths_B}
# A images with a same-stem B partner, in sorted A order: (str A path, B path).
pairs = [(str(pA), bname_to_B[Path(pA).stem]) for pA in paths_A if Path(pA).stem in bname_to_B]

print(f"Found {len(paths_A)} A images, {len(paths_B)} B images, and {len(pairs)} paired matches by filename.")


In [None]:
# ===== Cell 7: Load checkpoint & export A→B translations =====
# Restore both generators from checkpoints/<EVAL_TAG>.pt and switch them to inference mode.
ckpt = torch.load(ckpt_dir / f"{EVAL_TAG}.pt", map_location=cfg.DEVICE)
for _gen_key, _generator in (('G_A2B', G_A2B), ('G_B2A', G_B2A)):
    _generator.load_state_dict(ckpt[_gen_key])
    _generator.eval()

# Translate every A image and write it as <stem>.png (existing files are overwritten).
with torch.no_grad():
    for pA in paths_A:
        batch = to_tensor(Image.open(pA).convert('RGB')).unsqueeze(0).to(cfg.DEVICE)
        translated = to_pil(G_A2B(batch)[0].cpu())
        translated.save(GEN_OUT_DIR / (Path(pA).stem + ".png"))

print(f"Saved {len(paths_A)} generated images to {GEN_OUT_DIR}")


In [None]:
# ===== Cell 8: Compute FID & KID =====
def _n_fidelity_images(folder):
    """Count the files torch-fidelity picks up from `folder` with its default
    discovery settings. NOTE(review): assumes the defaults are top level only
    and case-sensitive *.png/*.jpg/*.jpeg; confirm against the installed version."""
    return sum(
        1 for p in Path(folder).iterdir()
        if p.is_file() and not p.name.startswith('.') and p.suffix in ('.png', '.jpg', '.jpeg')
    )

# torch-fidelity expects directories with images.
# KID is estimated on random subsets of `kid_subset_size` images (default 1000).
# torch-fidelity aborts if either input has fewer images than that, so the
# subset size is capped at the smaller folder. With >= 1000 images on both
# sides this is identical to the default.
metrics = calculate_metrics(
    input1=str(GEN_OUT_DIR),
    input2=str(EVAL_DIR_B),
    cuda=torch.cuda.is_available(),
    isc=False, fid=True, kid=True, prc=False, verbose=False,
    kid_subset_size=min(1000, _n_fidelity_images(GEN_OUT_DIR), _n_fidelity_images(EVAL_DIR_B)),
)
with open(METRICS_OUTDIR / "fid_kid.json", "w") as f:
    json.dump(metrics, f, indent=2)
print(json.dumps(metrics, indent=2))


In [None]:
# ===== Cell 9 (REPLACED): LPIPS, SSIM, PSNR on paired matches =====
import torch
import torchvision.transforms as T
from PIL import Image
import pandas as pd
import lpips
import piq

# Metrics are computed at the networks' training resolution.
EVAL_SIZE = cfg.IMG_SIZE

# PIL -> float tensor in [0, 1], no resizing.
to_unit = T.ToTensor()
# PIL -> bicubic, antialiased resize to EVAL_SIZE x EVAL_SIZE -> tensor in [0, 1].
_resize_step = T.Resize((EVAL_SIZE, EVAL_SIZE), interpolation=T.InterpolationMode.BICUBIC, antialias=True)
resize_01 = T.Compose([_resize_step, T.ToTensor()])
def to_m1p1(x01):
    """Rescale values from [0, 1] to [-1, 1], the input range LPIPS expects."""
    return 2 * x01 - 1

lpips_fn = lpips.LPIPS(net='vgg').to(cfg.DEVICE).eval()

rows = []
with torch.no_grad():
    for pA, pB in pairs:
        gen_path = GEN_OUT_DIR / (Path(pA).stem + ".png")
        if not gen_path.exists():
            continue

        # Load & resize both to the SAME size
        Gimg01 = resize_01(Image.open(gen_path).convert('RGB')).to(cfg.DEVICE).unsqueeze(0)  # [1,3,H,W], 0..1
        Bimg01 = resize_01(Image.open(pB).convert('RGB')).to(cfg.DEVICE).unsqueeze(0)

        # LPIPS expects [-1,1]
        Gm1p1 = to_m1p1(Gimg01)
        Bm1p1 = to_m1p1(Bimg01)
        lp = lpips_fn(Gm1p1, Bm1p1).item()

        # SSIM / PSNR on [0,1]
        ssim_val = piq.ssim(Gimg01, Bimg01, data_range=1.0).item()
        psnr_val = piq.psnr(Gimg01, Bimg01, data_range=1.0).item()

        rows.append({
            "basename": Path(pA).stem,
            "lpips": lp,
            "ssim": ssim_val,
            "psnr": psnr_val
        })

# Use explicit columns so the frame stays well-formed when nothing matched:
# a column-less pd.DataFrame([]) makes sort_values("basename") raise KeyError.
df = pd.DataFrame(rows, columns=["basename", "lpips", "ssim", "psnr"]).sort_values("basename")
df.to_csv(METRICS_OUTDIR / "paired_metrics.csv", index=False)

# Mean and population std (ddof=0) per metric. With no rows they are null
# instead of NaN, because NaN is not valid JSON.
has_rows = len(df) > 0
agg = {"N_pairs": int(len(df))}
for col, label in (("lpips", "LPIPS"), ("ssim", "SSIM"), ("psnr", "PSNR")):
    agg[f"{label}_mean"] = float(df[col].mean()) if has_rows else None
    agg[f"{label}_std"] = float(df[col].std(ddof=0)) if has_rows else None
if not has_rows:
    print("[WARN] No generated images matched a paired B image; metrics are null.")

with open(METRICS_OUTDIR / "paired_metrics_summary.json", "w") as f:
    json.dump(agg, f, indent=2)

print(agg)


In [None]:
# ===== Cell 10: Create a grid of examples for the thesis =====
import matplotlib.pyplot as plt
import math

N = 12  # maximum number of triplets (rows) to show
sample_pairs = pairs[:]
random.shuffle(sample_pairs)
# Keep only pairs whose A->B translation exists in GEN_OUT_DIR, then cap at N.
# A missing file would otherwise abort the figure with FileNotFoundError.
sample_pairs = [
    (pA, pB) for pA, pB in sample_pairs
    if (GEN_OUT_DIR / (Path(pA).stem + ".png")).exists()
][:N]
if not sample_pairs:
    raise RuntimeError(f"No generated images in {GEN_OUT_DIR} match the paired files.")

# One row per triplet, so no blank rows when fewer than N pairs exist.
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single row.
ncols = 3
nrows = len(sample_pairs)
fig, axes = plt.subplots(nrows, ncols, figsize=(9, 3*nrows), dpi=150, squeeze=False)

for i, (pA, pB) in enumerate(sample_pairs):
    gen = GEN_OUT_DIR / (Path(pA).stem + ".png")
    A = Image.open(pA).convert('RGB')
    B = Image.open(pB).convert('RGB')
    G = Image.open(gen).convert('RGB')

    axes[i,0].imshow(A); axes[i,0].set_title("H&E (A)"); axes[i,0].axis('off')
    axes[i,1].imshow(B); axes[i,1].set_title("IHC – Expected (B)"); axes[i,1].axis('off')
    axes[i,2].imshow(G); axes[i,2].set_title("IHC – Generated (A→B)"); axes[i,2].axis('off')

plt.tight_layout()
fig_path = samples_dir / f"eval_{EVAL_TAG}_A-B-G_triplets.png"
plt.savefig(fig_path, bbox_inches='tight')
print(f"Saved figure: {fig_path}")


# Results

In [None]:
_HER2_ROOT = "/content/drive/MyDrive/HER2"
_EVAL_ROOT = f"{_HER2_ROOT}/eval_all_models"

# Models to compare: architecture type, generator checkpoint, and output folder.
MODELS = {
    # 1) Pix2Pix (paired)
    "pix2pix": dict(
        type="pix2pix",
        ckpt=f"{_HER2_ROOT}/pix2pix_checkpoints/generator_epoch_200.pth",
        out=f"{_EVAL_ROOT}/pix2pix",
    ),
    # 2) CycleGAN-Base (ResNet-9), saved by save_checkpoint under EXP_ROOT/checkpoints/.
    #    If epoch_200.pt is missing, point "ckpt" at .../checkpoints/latest.pt instead.
    "cyclegan_base": dict(
        type="cyclegan_resnet9",
        ckpt=f"{_HER2_ROOT}/cyclegan_vanilla/checkpoints/epoch_200.pt",
        out=f"{_EVAL_ROOT}/cyclegan_base",
    ),
    # 3) CycleGAN with a U-Net generator.
    "cyclegan_unet": dict(
        type="cyclegan_unet",
        ckpt=f"{_HER2_ROOT}/cyclegan_checkpoints/checkpoint_epoch_200.pth",
        out=f"{_EVAL_ROOT}/cyclegan_unet",
    ),
}


In [None]:
def build_generator(model_type: str):
    """Instantiate the generator architecture for `model_type` on DEVICE, in eval mode.

    Only the architecture is built; no weights are loaded here.

    Raises:
        ValueError: if `model_type` is not "pix2pix", "cyclegan_resnet9" or
            "cyclegan_unet".
    """
    # Builders are looked up lazily, so a class is only needed when its type is requested.
    builders = {
        # NOTE(review): `Pix2PixUNet` must be defined elsewhere in the notebook;
        # the Pix2Pix section's U-Net class is named `Generator`. Confirm the name.
        "pix2pix": lambda: Pix2PixUNet(),
        # ResNet generator from the plain-CycleGAN section. Its actual signature
        # is ResnetGenerator(in_c=3, out_c=3, n_filters=64, n_blocks=9).
        "cyclegan_resnet9": lambda: ResnetGenerator(3, 3, n_blocks=9),
        # NOTE(review): `CycleUNet` must be defined elsewhere in the notebook.
        "cyclegan_unet": lambda: CycleUNet(),
    }
    make = next((fn for key, fn in builders.items() if model_type == key), None)
    if make is None:
        raise ValueError(f"Unknown model_type: {model_type}")
    G = make().to(DEVICE)
    G.eval()
    return G


In [None]:
from pathlib import Path
_VAL_ROOT = Path("/content/drive/MyDrive/HER2/TrainValAB")
VAL_DIR_A = _VAL_ROOT / "valA"  # H&E
VAL_DIR_B = _VAL_ROOT / "valB"  # IHC


In [None]:
# Auto-locate valA / valB in Google Drive and set VAL_DIR_A / VAL_DIR_B

from pathlib import Path
import os

# Ensure Drive is mounted:
# from google.colab import drive; drive.mount('/content/drive')

# Search roots, most specific first; they are searched in this order.
_MYDRIVE = Path("/content/drive/MyDrive")
ROOTS = [_MYDRIVE / "HER2" / "TrainValAB", _MYDRIVE / "HER2", _MYDRIVE]

def is_img_dir(p: Path):
    """Return True if `p` is a directory with at least one image file directly inside it.

    Only immediate children are checked (no recursion). The suffix test is
    case-insensitive over .png/.jpg/.jpeg/.tif/.tiff/.bmp/.webp.
    """
    if not p.is_dir():
        return False
    exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp"}
    for name in os.listdir(p):
        entry = p / name
        if entry.is_file() and entry.suffix.lower() in exts:
            return True
    return False

found_A = found_B = None  # set by the search below

def ci_eq(a, b):
    """Case-insensitive equality: compare the str.lower() forms of `a` and `b`."""
    lowered_a, lowered_b = a.lower(), b.lower()
    return lowered_a == lowered_b

# Search the roots in order. A match is a directory that directly contains
# BOTH image folders of the requested pair (names compared case-insensitively),
# so the A and B folders always come from the same dataset split. Matching valA
# and valB independently could pair folders from unrelated locations.
def _find_sibling_pair(roots, name_a, name_b):
    """Return ``(dir_a, dir_b)`` for the first directory (roots in order, each
    walked top-down with os.walk) whose image-folder children include `name_a`
    and `name_b`; return ``(None, None)`` when no root contains such a pair."""
    for root in roots:
        if not root.exists():
            continue
        for dirpath, dirnames, _ in os.walk(root):
            parent = Path(dirpath)
            cand_a = next((parent / d for d in dirnames
                           if ci_eq(d, name_a) and is_img_dir(parent / d)), None)
            cand_b = next((parent / d for d in dirnames
                           if ci_eq(d, name_b) and is_img_dir(parent / d)), None)
            if cand_a is not None and cand_b is not None:
                return cand_a, cand_b
    return None, None

found_A, found_B = _find_sibling_pair(ROOTS, "valA", "valB")

# fallback: use trainA/trainB if valA/valB absent
if found_A is None or found_B is None:
    found_A, found_B = _find_sibling_pair(ROOTS, "trainA", "trainB")
    if found_A is not None:
        print("[WARN] valA/valB not found; falling back to trainA/trainB.")

# Explicit raise instead of `assert`, which is skipped under `python -O`.
if found_A is None or found_B is None:
    raise FileNotFoundError("Couldn't find valA/valB (or trainA/trainB) anywhere under MyDrive.")

VAL_DIR_A, VAL_DIR_B = found_A, found_B
print("VAL_DIR_A:", VAL_DIR_A)
print("VAL_DIR_B:", VAL_DIR_B)
print("A count:", len([f for f in os.listdir(VAL_DIR_A) if (VAL_DIR_A/f).is_file()]))
print("B count:", len([f for f in os.listdir(VAL_DIR_B) if (VAL_DIR_B/f).is_file()]))
