In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install pyyaml pandas scikit-learn albumentations segmentation-models-pytorch -q

[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m154.8/154.8 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [4]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [5]:
import os
import math
import warnings
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import torchvision.models as models
import torchvision.utils as vutils

import pandas as pd
import numpy as np
from PIL import Image

import albumentations as A
from albumentations.pytorch import ToTensorV2

import segmentation_models_pytorch as smp

warnings.filterwarnings("ignore")

# ---------------------------
# CONFIG
# ---------------------------
CONFIG = {
    "IMAGE_DIR": "/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/",
    "TRAIN_CSV_PATH": "/content/drive/MyDrive/CAF-GAN/data/splits/train.csv",
    "CDIAG_PATH": "/content/drive/MyDrive/CAF-GAN/outputs/cdiag_512/best_cdiag_512.pth",
    "CSEG_PATH": "/content/drive/MyDrive/CAF-GAN/outputs/cseg_512/best_cseg_512.pth",
    "OUTPUT_DIR": "/content/drive/MyDrive/CAF-GAN/outputs/caf_gan_final/",
    "IMAGE_OUTPUT_DIR": "/content/drive/MyDrive/CAF-GAN/outputs/caf_gan_final/generated_images/",
    "TARGET_IMG_SIZE": 512,
    "LATENT_DIM": 512,
    "CHANNELS": 3,
    "PROGRESSIVE_EPOCHS": [20, 20, 20, 40, 40, 40, 100],  # 7 entries for 8..512
    "BATCH_SIZES": [8, 8, 8, 8, 8, 8, 4],
    "G_LR": 1e-3,
    "D_LR": 1e-3,
    "B1": 0.0,
    "B2": 0.99,
    "R1_REG_FREQ": 16,
    "LAMBDA_R1": 10,
    "LAMBDA_CLINIC": 0.5,  ### ROBUSTNESS: Tuned weight
    "LAMBDA_FAIR": 0.2,   ### ROBUSTNESS: Tuned weight

    # --- üí° NEW UTILITY LOSS (FROM EVALUATION FEEDBACK) ---\n",
    "LAMBDA_UTILITY": 0.3, # Pushes generator to make diagnosable Pneumonia"

    "LOSS_ANNEAL_START_EPOCH": 15,
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "NUM_WORKERS": 2,
    "LOG_FREQ": 100,
    "SAVE_IMG_FREQ": 1,
    "CHECKPOINT_FREQ": 500,

    ### ROBUSTNESS FIX: Parameters for new L_clinical and L_fairness
    "CLINICAL_LOSS_TYPE": "area", # 'area' implements L_area
    "PLAUSIBLE_LUNG_AREA_MEAN": 0.220646, # Mean area of lungs as % of image (NEEDS TO BE PRE-CALCULATED from your Cseg on real data)
    "PLAUSIBLE_LUNG_AREA_STD": 0.066277, # Std dev of lung area (NEEDS TO BE PRE-CALCULATED)
    "FAIRNESS_BATCH_SIZE": 30, # Large, stable batch size for fairness loss
    "NUM_RACE_GROUPS": 5, # From your dataset (White, Black, Asian, Hispanic, Other)

    # --- üí° NEW FAIRNESS PUSH (FROM EVALUATION FEEDBACK) ---",
    "FAIRNESS_PUSH_WEIGHT": 0.5 # Balances std(TPR) vs. (1 - mean(TPR))\n",
}

os.makedirs(CONFIG['IMAGE_OUTPUT_DIR'], exist_ok=True)
os.makedirs(CONFIG['OUTPUT_DIR'], exist_ok=True)


# ---------------------------
# Basic building blocks
# (PixelNorm, WSConv2d, InjectNoise, AdaIN)
# ---------------------------
class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8
    def forward(self, x):
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)

class WSConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None
        nn.init.normal_(self.conv.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    def forward(self, x):
        out = self.conv(x * self.scale)
        if self.bias is not None:
            out = out + self.bias.view(1, self.bias.shape[0], 1, 1)
        return out

class InjectNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))
    def forward(self, x):
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise

class AdaIN(nn.Module):
    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = nn.Linear(w_dim, channels)
        self.style_bias = nn.Linear(w_dim, channels)
    def forward(self, x, w):
        x = self.instance_norm(x)
        style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
        return style_scale * x + style_bias

# ---------------------------
# Mapping & Generator blocks
# (MappingNetwork, GenBlock, Generator)
# ---------------------------
class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        layers = [PixelNorm()]
        for i in range(8):
            if i == 0:
                layers.append(nn.Linear(z_dim, w_dim))
            else:
                layers.append(nn.Linear(w_dim, w_dim))
            if i < 7:
                layers.append(nn.ReLU())
        self.mapping = nn.Sequential(*layers)
    def forward(self, x):
        return self.mapping(x)

class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)
    def forward(self, x, w):
        x = self.leaky(self.inject_noise1(self.conv1(x)))
        x = self.adain1(x, w)
        x = self.leaky(self.inject_noise2(self.conv2(x)))
        x = self.adain2(x, w)
        return x

class Generator(nn.Module):
    def __init__(self, z_dim, w_dim, base_channels, img_channels=3):
        super().__init__()
        self.starting_const = nn.Parameter(torch.randn(1, base_channels, 4, 4))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_conv = WSConv2d(base_channels, base_channels, kernel_size=3, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.factors = [512, 512, 512, 256, 128, 64, 32, 16]
        assert base_channels == self.factors[0]
        self.prog_blocks = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.to_rgbs.append(WSConv2d(self.factors[0], img_channels, kernel_size=1, padding=0))
        for i in range(1, len(self.factors)):
            in_c = self.factors[i-1]
            out_c = self.factors[i]
            self.prog_blocks.append(GenBlock(in_c, out_c, w_dim))
            self.to_rgbs.append(WSConv2d(out_c, img_channels, kernel_size=1, padding=0))

    def forward(self, z, alpha, steps):
        w = self.map(z)
        batch = z.shape[0]
        x = self.starting_const.repeat(batch, 1, 1, 1)
        x = self.initial_conv(x)
        x = self.leaky(x)
        if steps == 0:
            out = self.to_rgbs[0](x)
            return torch.tanh(out)
        prev = None
        for step in range(1, steps + 1):
            prev = x
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            x = self.prog_blocks[step - 1](x, w)
        final_out = self.to_rgbs[steps](x)
        if alpha < 1.0 and prev is not None:
            prev_rgb = self.to_rgbs[steps - 1](prev)
            prev_rgb_upsampled = F.interpolate(prev_rgb, scale_factor=2, mode='bilinear', align_corners=False)
            out = alpha * final_out + (1.0 - alpha) * prev_rgb_upsampled
        else:
            out = final_out
        return torch.tanh(out)

# ---------------------------
# Discriminator
# (DiscBlock, Discriminator)
# ---------------------------
class DiscBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, in_channels)
        self.conv2 = WSConv2d(in_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
    def forward(self, x):
        x = self.leaky(self.conv1(x))
        x = self.leaky(self.conv2(x))
        return x

class Discriminator(nn.Module):
    def __init__(self, base_channels, img_channels=3):
        super().__init__()
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.factors = [512, 512, 512, 256, 128, 64, 32, 16]
        assert base_channels == self.factors[0]
        self.from_rgbs = nn.ModuleList([WSConv2d(img_channels, ch, kernel_size=1, padding=0) for ch in self.factors])
        self.prog_blocks = nn.ModuleList()
        for i in range(1, len(self.factors)):
            in_c = self.factors[i]
            out_c = self.factors[i-1]
            self.prog_blocks.append(DiscBlock(in_c, out_c))
        self.final_block = nn.Sequential(
            WSConv2d(self.factors[0] + 1, self.factors[0]), self.leaky,
            WSConv2d(self.factors[0], self.factors[0], kernel_size=4, padding=0, stride=1), self.leaky,
            WSConv2d(self.factors[0], 1, kernel_size=1, padding=0, stride=1)
        )
    def forward(self, x, alpha, steps):
        out = self.leaky(self.from_rgbs[steps](x))
        for s in range(steps, 0, -1):
            out = self.prog_blocks[s - 1](out)
            out = F.avg_pool2d(out, kernel_size=2)
            if alpha < 1.0 and s == steps:
                downscaled_image = F.avg_pool2d(x, kernel_size=2)
                low_res_features = self.leaky(self.from_rgbs[s - 1](downscaled_image))
                out = alpha * out + (1.0 - alpha) * low_res_features
        std = torch.std(out, dim=0, unbiased=False).mean().view(1, 1, 1, 1)
        std_map = std.repeat(out.shape[0], 1, out.shape[2], out.shape[3])
        out = torch.cat([out, std_map], dim=1)
        out = self.final_block(out)
        return out.view(out.shape[0], -1)

# --------------------
# Dataset & utilities
# --------------------
class MIMIC_GAN_Dataset(Dataset):
    def __init__(self, df, image_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform
        self.df['race_group'] = self.df['race_group'].astype('category')
        self.one_hot_races = pd.get_dummies(self.df['race_group']).reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']
        image_path = os.path.join(self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg')
        try:
            image = np.array(Image.open(image_path).convert("RGB"))
        except Exception:
            return None, None, None
        if self.transform:
            image = self.transform(image=image)['image']
        label = torch.tensor(row['Pneumonia'], dtype=torch.float32)
        race = torch.tensor(self.one_hot_races.iloc[idx].values, dtype=torch.float32)
        return image, label, race

def custom_collate(batch):
    batch = list(filter(lambda x: x[0] is not None, batch))
    if not batch:
        return torch.Tensor(), torch.Tensor(), torch.Tensor()
    return torch.utils.data.dataloader.default_collate(batch)

def load_critics(cdiag_path, cseg_path, device):
    print("üß† Loading frozen critic models...")
    Cdiag = models.resnet50()
    Cdiag.fc = nn.Linear(Cdiag.fc.in_features, 1)
    Cdiag.load_state_dict(torch.load(cdiag_path, map_location=device))
    Cdiag.eval().to(device)
    for param in Cdiag.parameters():
        param.requires_grad = False
    print("   - Cdiag (Diagnostic Critic) loaded.")

    Cseg = smp.Unet("resnet34", in_channels=3, classes=1)
    Cseg.load_state_dict(torch.load(cseg_path, map_location=device))
    Cseg.eval().to(device)
    for param in Cseg.parameters():
        param.requires_grad = False
    print("   - Cseg (Segmentation Critic) loaded.")
    return Cdiag, Cseg

def save_checkpoint(gen, disc, opt_gen, opt_disc, epoch, step, alpha, path):
    state = {
        'gen': gen.state_dict(), 'disc': disc.state_dict(),
        'opt_gen': opt_gen.state_dict(), 'opt_disc': opt_disc.state_dict(),
        'epoch': epoch, 'step': step, 'alpha': alpha,
    }
    torch.save(state, path)

def load_checkpoint(gen, disc, opt_gen, opt_disc, path, device):
    if os.path.exists(path):
        print(f"‚úÖ Resuming training from checkpoint: {path}")
        state = torch.load(path, map_location=device)
        gen.load_state_dict(state['gen'])
        disc.load_state_dict(state['disc'])
        opt_gen.load_state_dict(state['opt_gen'])
        opt_disc.load_state_dict(state['opt_disc'])
        return state['epoch'], state['step'], state['alpha']
    print("üì≠ No checkpoint found, starting from scratch.")
    return 0, 0, 1.0


# ----------------------------------------------------
### ROBUSTNESS FIX: New Clinical Loss ###
# ----------------------------------------------------
def calculate_clinical_loss(fake_masks, config):
    """
    Calculates the L_clinical loss based.
    This version implements L_area.
    """
    if config["CLINICAL_LOSS_TYPE"] == "area":
        # fake_masks are logits (B, 1, H, W)
        fake_masks_prob = torch.sigmoid(fake_masks)

        # Calculate the area of the mask as a percentage of total pixels
        # (B, 1, H, W) -> (B,)
        total_pixels = fake_masks_prob.shape[2] * fake_masks_prob.shape[3]
        mask_area_percent = fake_masks_prob.sum(dim=[1, 2, 3]) / total_pixels

        # Get target plausible area stats
        mean_area = config["PLAUSIBLE_LUNG_AREA_MEAN"]
        std_area = config["PLAUSIBLE_LUNG_AREA_STD"]

        # Calculate L1 distance from the mean, normalized by std
        # This penalizes masks that are too large or too small
        area_loss = F.l1_loss(mask_area_percent, torch.tensor(mean_area, device=mask_area_percent.device)) / std_area
        return area_loss
    else:
        # Placeholder for L_contiguity or L_shape
        return torch.tensor(0.0, device=fake_masks.device)

# --------------------------------------------------------------
# --- üí° NEW UTILITY LOSS FUNCTION (FROM EVALUATION FEEDBACK) ---
# --------------------------------------------------------------
def calculate_diag_utility_loss(fake_images_512, cdiag):
    """
    Calculates the utility loss.
    This pushes the generator to create images that Cdiag sees as 'Pneumonia'.
    """
    # Get predictions (logits) from Cdiag for the *in-batch* fakes
    diag_preds_logits = cdiag(fake_images_512).squeeze()

    # Create target labels of all '1.0' (Pneumonia)
    target_labels = torch.ones_like(diag_preds_logits)

    # Calculate the loss: how far are the predictions from 1.0?
    loss = F.binary_cross_entropy_with_logits(diag_preds_logits, target_labels)
    return loss


# ----------------------------------------------------------------
# --- üí° UPGRADED FAIRNESS LOSS (FROM EVALUATION FEEDBACK) ---
# ----------------------------------------------------------------
def calculate_fairness_loss(gen, cdiag, alpha, steps, config):
    """
    Calculates a robust L_fairness.
    Now combines two objectives:
    1. loss_std: Make all TPRs EQUAL (std -> 0)
    2. loss_push: Make all TPRs HIGH (mean -> 1.0)
    """
    gen.eval() # Use eval mode for this calculation

    B = config["FAIRNESS_BATCH_SIZE"]
    N_GROUPS = config["NUM_RACE_GROUPS"]
    DEVICE = config["DEVICE"]

    if B % N_GROUPS != 0:
        raise ValueError("FAIRNESS_BATCH_SIZE must be divisible by NUM_RACE_GROUPS")

    with torch.no_grad():
        z_fair = torch.randn(B, config["LATENT_DIM"], device=DEVICE)
        fake_images_fair = gen(z_fair, alpha, steps)
        fake_images_512 = F.interpolate(fake_images_fair, size=(512, 512), mode='bilinear', align_corners=False)

    diag_preds_logits = cdiag(fake_images_512).squeeze()
    diag_preds_probs = torch.sigmoid(diag_preds_logits)

    samples_per_group = B // N_GROUPS
    tpr_per_group = []
    for j in range(N_GROUPS):
        start_idx = j * samples_per_group
        end_idx = (j + 1) * samples_per_group
        tpr = diag_preds_probs[start_idx:end_idx].mean()
        tpr_per_group.append(tpr)

    if len(tpr_per_group) > 1:
        tpr_stack = torch.stack(tpr_per_group)

        # 1. Fairness (Equality) Loss: Make them equal
        loss_std = torch.std(tpr_stack)

        # 2. Utility (Push) Loss: Make them HIGH
        # We want the mean TPR to be 1.0. Loss is (1.0 - mean_tpr)
        loss_push = 1.0 - torch.mean(tpr_stack)

        # 3. Combine them using the push weight
        loss_g_fair = loss_std + config["FAIRNESS_PUSH_WEIGHT"] * loss_push

    else:
        loss_g_fair = torch.tensor(0.0, device=DEVICE)

    gen.train() # Set generator back to train mode
    return loss_g_fair


# ---------------------------
# Training loop (MODIFIED)
# ---------------------------
def run_training():
    DEVICE = CONFIG['DEVICE']
    print("üöÄ Initializing Robust CAF-GAN Training...")

    # --- Models & Optimizers ---
    gen = Generator(CONFIG['LATENT_DIM'], CONFIG['LATENT_DIM'], base_channels=512, img_channels=CONFIG['CHANNELS']).to(DEVICE)
    disc = Discriminator(base_channels=512, img_channels=CONFIG['CHANNELS']).to(DEVICE)
    opt_gen = optim.Adam(gen.parameters(), lr=CONFIG['G_LR'], betas=(CONFIG['B1'], CONFIG['B2']))
    opt_disc = optim.Adam(disc.parameters(), lr=CONFIG['D_LR'], betas=(CONFIG['B1'], CONFIG['B2']))

    # --- Load Frozen Critics (Cdiag and Cseg) ---
    Cdiag, Cseg = load_critics(CONFIG['CDIAG_PATH'], CONFIG['CSEG_PATH'], DEVICE)

    # --- Checkpoint Loading ---
    checkpoint_path = os.path.join(CONFIG['OUTPUT_DIR'], "caf_gan_checkpoint.pth")
    start_epoch, start_step, alpha = load_checkpoint(gen, disc, opt_gen, opt_disc, checkpoint_path, DEVICE)

    train_df = pd.read_csv(CONFIG['TRAIN_CSV_PATH'])
    fixed_noise = torch.randn(8, CONFIG['LATENT_DIM'], device=DEVICE)

    gen.train()
    disc.train()

    num_progressive_stages = len(CONFIG['PROGRESSIVE_EPOCHS'])

    # --- Main Training Loop ---
    for res_step in range(start_step, num_progressive_stages):
        steps = res_step + 1
        img_size = 4 * (2 ** steps)
        if img_size > CONFIG['TARGET_IMG_SIZE']: break

        loader_batch_size = CONFIG['BATCH_SIZES'][res_step]
        num_epochs_for_res = CONFIG['PROGRESSIVE_EPOCHS'][res_step]

        epoch_offset = start_epoch if res_step == start_step else 0
        if epoch_offset > 0:
            print(f"Resuming at Epoch {epoch_offset+1}/{num_epochs_for_res} for this stage.")
        start_epoch = 0

        print("\n" + "="*50)
        print(f"üìà Starting Training for Resolution: {img_size}x{img_size}")
        print(f"   Epochs for this stage: {num_epochs_for_res} | Batch Size: {loader_batch_size}")
        print("="*50)

        transform = A.Compose([
            A.Resize(width=img_size, height=img_size, interpolation=Image.LANCZOS),
            A.HorizontalFlip(p=0.5),
            A.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5]),
            ToTensorV2(),
        ])
        dataset = MIMIC_GAN_Dataset(train_df, CONFIG['IMAGE_DIR'], transform=transform)
        loader = DataLoader(dataset, batch_size=loader_batch_size, shuffle=True, num_workers=CONFIG['NUM_WORKERS'], pin_memory=True, collate_fn=custom_collate)

        for epoch in range(epoch_offset, num_epochs_for_res):
            current_epoch_total = sum(CONFIG['PROGRESSIVE_EPOCHS'][:res_step]) + epoch + 1
            print(f"\n--- Resolution {img_size}x{img_size} | Epoch {epoch+1}/{num_epochs_for_res} (Total: {current_epoch_total}) ---")

            # --- Alpha and Lambda Annealing ---
            if num_epochs_for_res > 1:
                alpha = min(1.0, (epoch + 1) / (num_epochs_for_res * 0.5))
            else:
                alpha = 1.0

            clinical_fair_lambda_scale = 1.0
            if epoch < CONFIG['LOSS_ANNEAL_START_EPOCH']:
                clinical_fair_lambda_scale = 0.0
            elif epoch < num_epochs_for_res:
                clinical_fair_lambda_scale = max(0.0, (epoch - CONFIG['LOSS_ANNEAL_START_EPOCH']) / max(1, (num_epochs_for_res - CONFIG['LOSS_ANNEAL_START_EPOCH'])))

            loop = tqdm(loader, leave=True)
            for batch_idx, (real, labels, races) in enumerate(loop):
                if real.nelement() == 0:
                    continue
                real = real.to(DEVICE)

                # ========== Train Discriminator ==========
                noise = torch.randn(real.shape[0], CONFIG['LATENT_DIM'], device=DEVICE)
                fake = gen(noise, alpha, steps)

                disc_real = disc(real, alpha, steps)
                disc_fake = disc(fake.detach(), alpha, steps)

                # R1 Gradient Penalty
                gp = 0.0
                if batch_idx % CONFIG['R1_REG_FREQ'] == 0:
                    real.requires_grad = True
                    disc_real_for_gp = disc(real, alpha, steps)
                    grad = torch.autograd.grad(outputs=disc_real_for_gp.sum(), inputs=real, create_graph=True)[0]
                    grad_penalty = (grad.view(grad.shape[0], -1).norm(2, dim=1) ** 2).mean()
                    gp = CONFIG['LAMBDA_R1'] * grad_penalty
                    real.requires_grad = False

                loss_disc = (torch.mean(F.softplus(disc_fake)) + torch.mean(F.softplus(-disc_real))) + gp

                opt_disc.zero_grad()
                loss_disc.backward()
                opt_disc.step()

                # ========== Train Generator ==========
                gen_fake = gen(noise, alpha, steps)
                disc_fake_for_g = disc(gen_fake, alpha, steps)
                loss_g_adv = torch.mean(F.softplus(-disc_fake_for_g))

                loss_g_clinic = torch.tensor(0.0, device=DEVICE)
                loss_g_fair = torch.tensor(0.0, device=DEVICE)
                # --- üí° NEW UTILITY LOSS VARIABLE ---
                loss_g_utility = torch.tensor(0.0, device=DEVICE)

                # --- Apply Robust Clinical, Fair, and Utility Losses ---
                if clinical_fair_lambda_scale > 0.0 and img_size == CONFIG['TARGET_IMG_SIZE']:

                    fake_512 = F.interpolate(gen_fake, size=(512, 512), mode='bilinear', align_corners=False)

                    # --- 1. L_clinical (from Cseg) ---
                    fake_masks = Cseg(fake_512)
                    loss_g_clinic = calculate_clinical_loss(fake_masks, CONFIG)

                    # --- 2. üí° NEW L_utility (from Cdiag) ---
                    loss_g_utility = calculate_diag_utility_loss(fake_512, Cdiag)

                    # --- 3. üí° UPGRADED L_fairness (from Cdiag) ---
                    loss_g_fair = calculate_fairness_loss(gen, Cdiag, alpha, steps, CONFIG)


                # --- üí° UPDATED Total Generator Loss ---
                loss_gen = (
                    loss_g_adv +
                    clinical_fair_lambda_scale * CONFIG['LAMBDA_CLINIC'] * loss_g_clinic +
                    clinical_fair_lambda_scale * CONFIG['LAMBDA_FAIR'] * loss_g_fair +
                    clinical_fair_lambda_scale * CONFIG['LAMBDA_UTILITY'] * loss_g_utility
                )

                opt_gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()

                # --- üí° UPDATED TQDM Postfix ---
                loop.set_postfix(
                    D_loss=f"{loss_disc.item():.4f}",
                    G_loss=f"{loss_gen.item():.4f}",
                    G_adv=f"{loss_g_adv.item():.4f}",
                    G_clinic=f"{loss_g_clinic.item():.4f}",
                    G_fair=f"{loss_g_fair.item():.4f}",
                    G_util=f"{loss_g_utility.item():.4f}",
                    Alpha=f"{alpha:.3f}",
                    Œª_Scale=f"{clinical_fair_lambda_scale:.3f}"
                )

                if batch_idx % CONFIG['CHECKPOINT_FREQ'] == 0:
                    save_checkpoint(gen, disc, opt_gen, opt_disc, epoch, res_step, alpha, checkpoint_path)

            # --- End of Epoch Actions ---
            if (epoch + 1) % CONFIG['SAVE_IMG_FREQ'] == 0:
                gen.eval()
                with torch.no_grad():
                    img_grid = gen(fixed_noise, alpha, steps).detach().cpu()
                vutils.save_image(img_grid, os.path.join(CONFIG['IMAGE_OUTPUT_DIR'], f"res_{img_size}_epoch_{current_epoch_total}.png"), normalize=True, nrow=4)
                print(f"üñºÔ∏è Saved generated images for epoch {current_epoch_total}")
                gen.train()

            save_checkpoint(gen, disc, opt_gen, opt_disc, epoch, res_step, alpha, checkpoint_path)
            print(f"üíæ End of epoch {current_epoch_total}. Checkpoint saved.")

    print("\nüéâüéâüéâ CAF-GAN Training Complete! üéâüéâüéâ")
    torch.save(gen.state_dict(), os.path.join(CONFIG['OUTPUT_DIR'], "caf_gan_generator_final.pth"))

if __name__ == "__main__":
    run_training()



üöÄ Initializing Robust CAF-GAN Training...
üß† Loading frozen critic models...
   - Cdiag (Diagnostic Critic) loaded.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

   - Cseg (Segmentation Critic) loaded.
‚úÖ Resuming training from checkpoint: /content/drive/MyDrive/CAF-GAN/outputs/caf_gan_final/caf_gan_checkpoint.pth
Resuming at Epoch 96/100 for this stage.

üìà Starting Training for Resolution: 512x512
   Epochs for this stage: 100 | Batch Size: 4

--- Resolution 512x512 | Epoch 96/100 (Total: 276) ---


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 350/350 [09:24<00:00,  1.61s/it, Alpha=1.000, D_loss=0.0385, G_adv=1.5693, G_clinic=0.3712, G_fair=0.0170, G_loss=1.7518, G_util=0.0166, Œª_Scale=0.941]


üñºÔ∏è Saved generated images for epoch 276
üíæ End of epoch 276. Checkpoint saved.

--- Resolution 512x512 | Epoch 97/100 (Total: 277) ---


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 350/350 [09:03<00:00,  1.55s/it, Alpha=1.000, D_loss=0.2446, G_adv=1.9582, G_clinic=0.6295, G_fair=0.0939, G_loss=2.3463, G_util=0.2459, Œª_Scale=0.953]


üñºÔ∏è Saved generated images for epoch 277
üíæ End of epoch 277. Checkpoint saved.

--- Resolution 512x512 | Epoch 98/100 (Total: 278) ---


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 350/350 [09:03<00:00,  1.55s/it, Alpha=1.000, D_loss=0.2477, G_adv=4.4539, G_clinic=0.3505, G_fair=0.0439, G_loss=4.6482, G_util=0.0579, Œª_Scale=0.965]


üñºÔ∏è Saved generated images for epoch 278
üíæ End of epoch 278. Checkpoint saved.

--- Resolution 512x512 | Epoch 99/100 (Total: 279) ---


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 350/350 [09:04<00:00,  1.56s/it, Alpha=1.000, D_loss=0.1615, G_adv=5.8137, G_clinic=0.4206, G_fair=0.0370, G_loss=6.0320, G_util=0.0198, Œª_Scale=0.976]


üñºÔ∏è Saved generated images for epoch 279
üíæ End of epoch 279. Checkpoint saved.

--- Resolution 512x512 | Epoch 100/100 (Total: 280) ---


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 350/350 [09:06<00:00,  1.56s/it, Alpha=1.000, D_loss=0.5671, G_adv=4.5973, G_clinic=0.4247, G_fair=0.0915, G_loss=4.8767, G_util=0.1738, Œª_Scale=0.988]


üñºÔ∏è Saved generated images for epoch 280
üíæ End of epoch 280. Checkpoint saved.

üéâüéâüéâ CAF-GAN Training Complete! üéâüéâüéâ
