In [None]:
# Mount Google Drive; every checkpoint, data and output path configured below lives under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install pyyaml pandas scikit-learn joblib albumentations segmentation-models-pytorch -q

[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m154.8/154.8 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Configure the PyTorch CUDA allocator through its environment variable.
# NOTE(review): presumably this has to run before the first CUDA allocation to take effect — confirm.
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [None]:
import gc
import torch

# Collect unreachable Python objects first, then release PyTorch's cached CUDA blocks.
gc.collect()
torch.cuda.empty_cache()

In [14]:
#!/usr/bin/env python3
"""
MASTER_v4 - Paper-ready TSTR + Fairness + Clinical evaluation script
- Uses pre-generated balanced synthetic data (option A)
- Evaluates LR, RF, SVM for Train-on-Real and Train-on-Synth (TSTR)
- Saves .joblib downstream models to CLASSIFIER_OUTPUT_DIR
- Saves confusion matrices, results CSVs, and a summary CSV
- Robust memory-safe feature extraction + clinical scoring using Cseg
- Includes path to user's uploaded file for reproducibility tooling
"""

import os
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.utils as vutils
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, confusion_matrix
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import joblib
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# Path to the uploaded file (developer instruction)
# NOTE(review): this constant is not read anywhere in the visible notebook code — confirm it is still needed.
USER_UPLOADED_FILE = "/mnt/data/b77eaff1-0448-4a1c-b99a-2cc8209bda8c.pdf"

# -------------------------
# CONFIG - change if needed
# -------------------------
# Single source of settings for the evaluation run; passed to MasterEvaluatorV4 and
# also read directly by Preprocessor and the dataset class.
CONFIG = {
    # Checkpoints
    "GENERATOR_CHECKPOINT": "/content/drive/MyDrive/CAF-GAN/outputs/caf_gan_final/caf_gan_generator_final.pth",
    "CDIAG_CHECKPOINT": "/content/drive/MyDrive/CAF-GAN/outputs/cdiag_512/best_cdiag_512.pth",
    "CSEG_CHECKPOINT": "/content/drive/MyDrive/CAF-GAN/outputs/cseg_512/best_cseg_512.pth",

    # Data
    "REAL_TEST_CSV": "/content/drive/MyDrive/CAF-GAN/data/splits/test.csv",
    "REAL_IMAGE_DIR": "/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/",

    # Use existing balanced synthetic (choice A)
    "SYNTHETIC_IMAGE_DIR": "/content/drive/MyDrive/CAF-GAN/data/synthetic_images_balanced/",
    "SYNTHETIC_CSV_PATH": "/content/drive/MyDrive/CAF-GAN/data/synthetic_images_balanced/labels.csv",

    # Outputs
    "OUTPUT_DIR": "/content/drive/MyDrive/CAF-GAN/outputs/evaluation_results_master_v4/",
    "CLASSIFIER_OUTPUT_DIR": "/content/drive/MyDrive/CAF-GAN/outputs/classifiers/",

    # Model params
    # LATENT_DIM is passed to Generator as both z_dim and w_dim.
    "LATENT_DIM": 512,
    "BASE_CHANNELS": 512,
    "CHANNELS": 3,
    "EVAL_IMG_SIZE": 224,   # for ResNet feature extractor
    "SEG_IMG_SIZE": 512,    # for Cseg clinical scoring

    # Dataset / runtime
    # TARGET_PER_CLASS: minimum number of rows per Pneumonia class required in the synthetic CSV.
    "TARGET_PER_CLASS": 2500,
    # BATCH_SIZE: DataLoader batch size for real images; clinical scoring caps its own batch at 32.
    "BATCH_SIZE": 32,
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "NUM_WORKERS": 2,
    "MANUAL_SEED": 42,

    # Fairness / clinical
    "MIN_GROUP_POSITIVES_FOR_TPR": 1,  # show TPR for small groups (paper)
    # Reference mean/std of the lung-area fraction used to standardise the clinical score.
    # NOTE(review): provenance of these two numbers is not shown in this notebook — confirm.
    "PLAUSIBLE_LUNG_AREA_MEAN": 0.220646,
    "PLAUSIBLE_LUNG_AREA_STD": 0.066277,

    # Misc
    # SAVE_DOWNSTREAM_MODELS: when True, run() dumps each fitted classifier to CLASSIFIER_OUTPUT_DIR.
    "SAVE_DOWNSTREAM_MODELS": True,
}

# Make sure every directory named in CONFIG exists before anything reads or writes it.
for _dir_key in ('OUTPUT_DIR', 'CLASSIFIER_OUTPUT_DIR', 'SYNTHETIC_IMAGE_DIR'):
    os.makedirs(CONFIG[_dir_key], exist_ok=True)

# Seed the global torch and NumPy RNGs from the configured value.
_seed = CONFIG['MANUAL_SEED']
torch.manual_seed(_seed)
np.random.seed(_seed)

# -------------------------
# Generator architecture (same as trained model)
# -------------------------
class PixelNorm(nn.Module):
    """Divide the input by the root-mean-square of its values along dim 1."""

    def __init__(self):
        super().__init__()
        # Added under the square root so an all-zero input does not divide by zero.
        self.epsilon = 1e-8

    def forward(self, x):
        mean_square = (x ** 2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()

class WSConv2d(nn.Module):
    """Conv2d whose input is pre-scaled by sqrt(2 / (in_channels * kernel_size**2)).

    The inner convolution's bias parameter is moved onto this module (state-dict
    key ``bias``) and added after the convolution; the inner conv itself runs
    without a bias. Weights are drawn from N(0, 1) and the bias starts at zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * (kernel_size ** 2)
        self.scale = (2 / fan_in) ** 0.5
        # Re-home the bias: register it here, then strip it from the inner conv.
        self.bias = self.conv.bias
        self.conv.bias = None
        nn.init.normal_(self.conv.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        y = self.conv(x * self.scale)
        if self.bias is None:
            return y
        # Broadcast the per-channel bias over batch and spatial dimensions.
        return y + self.bias.view(1, self.bias.shape[0], 1, 1)

class InjectNoise(nn.Module):
    """Add fresh Gaussian noise to the input, scaled by a learned per-channel weight.

    One noise plane of shape (batch, 1, H, W) is drawn on every forward call and
    shared across channels; the weight starts at zero.
    """

    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        noise_shape = (x.size(0), 1, x.size(2), x.size(3))
        noise = torch.randn(noise_shape, device=x.device)
        return x + self.weight * noise

class AdaIN(nn.Module):
    """Instance-normalise ``x`` and re-style it with a scale and bias predicted from ``w``."""

    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = nn.Linear(w_dim, channels)
        self.style_bias = nn.Linear(w_dim, channels)

    def forward(self, x, w):
        normalized = self.instance_norm(x)
        # Append two singleton axes so the per-channel style broadcasts over H and W.
        gamma = self.style_scale(w)[:, :, None, None]
        beta = self.style_bias(w)[:, :, None, None]
        return gamma * normalized + beta

class MappingNetwork(nn.Module):
    """PixelNorm followed by eight Linear layers, with ReLU between (not after) them.

    The first Linear maps ``z_dim`` to ``w_dim``; the remaining seven map
    ``w_dim`` to ``w_dim``.
    """

    def __init__(self, z_dim, w_dim):
        super().__init__()
        num_layers = 8
        stack = [PixelNorm()]
        in_features = z_dim
        for depth in range(num_layers):
            stack.append(nn.Linear(in_features, w_dim))
            if depth != num_layers - 1:
                stack.append(nn.ReLU())
            in_features = w_dim
        self.mapping = nn.Sequential(*stack)

    def forward(self, x):
        return self.mapping(x)

class GenBlock(nn.Module):
    """Two identical stages of conv -> noise injection -> LeakyReLU(0.2) -> AdaIN."""

    def __init__(self, in_channels, out_channels, w_dim):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        stages = (
            (self.conv1, self.inject_noise1, self.adain1),
            (self.conv2, self.inject_noise2, self.adain2),
        )
        for conv, add_noise, adain in stages:
            x = adain(self.leaky(add_noise(conv(x))), w)
        return x

class Generator(nn.Module):
    """Progressive, style-modulated image generator.

    Starts from a learned 4x4 constant and doubles the resolution once per
    progressive block; ``forward(z, alpha, steps)`` runs the first ``steps``
    blocks. With ``alpha < 1`` the RGB output of the last block is blended with
    the upsampled RGB of the previous resolution. Output passes through tanh.
    """

    def __init__(self, z_dim, w_dim, base_channels, img_channels=3):
        super().__init__()
        self.starting_const = nn.Parameter(torch.randn(1, base_channels, 4, 4))
        self.map = MappingNetwork(z_dim, w_dim)
        self.initial_conv = WSConv2d(base_channels, base_channels, kernel_size=3, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        # Channel width at each resolution level (level 0 is the 4x4 constant).
        self.factors = [512, 512, 512, 256, 128, 64, 32, 16]
        self.prog_blocks = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.to_rgbs.append(WSConv2d(self.factors[0], img_channels, kernel_size=1, padding=0))
        for in_c, out_c in zip(self.factors[:-1], self.factors[1:]):
            self.prog_blocks.append(GenBlock(in_c, out_c, w_dim))
            self.to_rgbs.append(WSConv2d(out_c, img_channels, kernel_size=1, padding=0))

    def forward(self, z, alpha, steps):
        w = self.map(z)
        const = self.starting_const.repeat(z.shape[0], 1, 1, 1)
        x = self.leaky(self.initial_conv(const))
        if steps == 0:
            return torch.tanh(self.to_rgbs[0](x))

        prev = None
        for block_idx in range(steps):
            # Remember the features entering the last block for the fade-in path.
            prev = x
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            x = self.prog_blocks[block_idx](x, w)

        final_out = self.to_rgbs[steps](x)
        if prev is None or not (alpha < 1.0):
            return torch.tanh(final_out)

        # Fade-in: mix the new resolution with the upsampled previous-resolution RGB.
        prev_rgb = self.to_rgbs[steps - 1](prev)
        prev_rgb_upsampled = F.interpolate(prev_rgb, scale_factor=2, mode='bilinear', align_corners=False)
        return torch.tanh(alpha * final_out + (1.0 - alpha) * prev_rgb_upsampled)

# -------------------------
# Dataset loader for real data (keeps race)
# -------------------------
class RobustRealDatasetWithRace(Dataset):
    """Real chest X-ray dataset that yields ``(image, label, race)`` triples.

    Image paths are built as
    ``<image_dir>/p<first two chars of subject_id>/p<subject_id>/s<study_id>/<dicom_id>.jpg``.
    A missing or unreadable image yields a placeholder sample (zero image,
    label ``-1.0``, race ``"Unknown"``) instead of raising, so callers can filter
    on ``label != -1``.

    Args:
        df: DataFrame with columns ``subject_id``, ``study_id``, ``dicom_id``,
            ``Pneumonia`` and optionally ``race_group``.
        image_dir: root directory of the image tree.
        transform: optional callable used as ``transform(image=ndarray)['image']``;
            when omitted, ``torchvision.transforms.ToTensor`` is applied.
        eval_size: side length of the placeholder image. ``None`` (default) reads
            ``CONFIG['EVAL_IMG_SIZE']`` at construction time, so later edits to
            CONFIG are honoured without re-defining the class.
    """

    def __init__(self, df, image_dir, transform=None, eval_size=None):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform
        self.eval_size = CONFIG['EVAL_IMG_SIZE'] if eval_size is None else eval_size

    def __len__(self):
        return len(self.df)

    def _placeholder(self):
        """Return the sentinel sample used for missing or unreadable images."""
        blank = torch.zeros((3, self.eval_size, self.eval_size), dtype=torch.float32)
        return blank, torch.tensor(-1.0), "Unknown"

    @staticmethod
    def _clean_race(value):
        """Coerce a race value to ``str``; missing values (None/NaN) become "Unknown".

        Keeping every race a string lets the DataLoader collate a batch without
        tripping over a mix of strings and float NaNs.
        """
        if value is None:
            return "Unknown"
        if not isinstance(value, str) and pd.isna(value):
            return "Unknown"
        return str(value)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']
        image_path = os.path.join(
            self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
        )
        if not os.path.exists(image_path):
            return self._placeholder()
        try:
            # The context manager closes the file handle; convert() returns a loaded copy.
            with Image.open(image_path) as handle:
                img = handle.convert("RGB")
            if self.transform:
                out = self.transform(image=np.array(img))['image']
            else:
                out = transforms.ToTensor()(img)
            label = torch.tensor(row['Pneumonia'], dtype=torch.float32)
            race = self._clean_race(row.get('race_group', 'Unknown'))
            return out, label, race
        except Exception:
            # Deliberate best-effort: a corrupt image becomes a placeholder that callers drop.
            return self._placeholder()

# -------------------------
# Preprocessor for ResNet/Cseg
# -------------------------
class Preprocessor:
    """Turn image tensors into the inputs expected by the ResNet and Cseg models.

    Args:
        device: device on which the ImageNet mean/std tensors are created.
        eval_size: optional side length for ``prepare_for_resnet``; ``None``
            (default) uses ``CONFIG['EVAL_IMG_SIZE']`` at call time.
        seg_size: optional side length for ``prepare_for_cseg``; ``None``
            (default) uses ``CONFIG['SEG_IMG_SIZE']`` at call time.
    """

    def __init__(self, device, eval_size=None, seg_size=None):
        self.device = device
        self.eval_size = eval_size
        self.seg_size = seg_size
        self.mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)

    def gan_to_01(self, x):
        """Map values from [-1, 1] to [0, 1]."""
        return (x + 1.0) * 0.5

    @staticmethod
    def _square(override, config_key):
        """Return ``(side, side)`` from the explicit override or, failing that, CONFIG."""
        side = CONFIG[config_key] if override is None else override
        return (side, side)

    def prepare_for_resnet(self, x):
        """Convert a [-1, 1] image (C,H,W or N,C,H,W) to a resized, ImageNet-normalised batch."""
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x01 = self.gan_to_01(x)
        x_resized = F.interpolate(
            x01, size=self._square(self.eval_size, 'EVAL_IMG_SIZE'), mode='bilinear', align_corners=False
        )
        return (x_resized - self.mean) / self.std

    def prepare_for_cseg(self, x):
        """Return ``x`` as a resized batch for Cseg, without ImageNet normalisation.

        The input range is guessed from the data: any negative value means the
        tensor is treated as [-1, 1] and shifted to [0, 1]; a tensor with no
        negative values is passed through unchanged. (The previous test also
        shifted inputs already in [0, 1], compressing them into [0.5, 1].)
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x01 = self.gan_to_01(x) if x.min() < 0 else x
        return F.interpolate(
            x01, size=self._square(self.seg_size, 'SEG_IMG_SIZE'), mode='bilinear', align_corners=False
        )

# -------------------------
# Plotting helpers
# -------------------------
def save_confusion_matrix(y_true, y_pred, labels, outpath, title="Confusion Matrix"):
    """Render the confusion matrix of ``y_pred`` vs ``y_true`` as a heatmap and save it.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        labels: label values fixing the row/column order and the tick labels.
        outpath: destination passed to ``savefig``.
        title: figure title.
    """
    counts = confusion_matrix(y_true, y_pred, labels=labels)
    fig = plt.figure(figsize=(4, 4))
    ax = sns.heatmap(counts, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(outpath)
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

# -------------------------
# MASTER V4 evaluator
# -------------------------
class MasterEvaluatorV4:
    def __init__(self, config):
        """Load the generator and the Cdiag classifier, and build the feature extractor.

        Args:
            config: settings dict; reads DEVICE, LATENT_DIM, BASE_CHANNELS, CHANNELS,
                GENERATOR_CHECKPOINT and CDIAG_CHECKPOINT here.

        Raises:
            RuntimeError: if the generator checkpoint cannot be loaded. A failing
                Cdiag checkpoint only prints a warning and keeps the initial weights.
        """
        self.config = config
        self.device = torch.device(config['DEVICE'])
        self.preproc = Preprocessor(self.device)

        # Load generator
        latent = config['LATENT_DIM']
        self.generator = Generator(latent, latent, config['BASE_CHANNELS'], config['CHANNELS']).to(self.device)
        try:
            ckpt = torch.load(config['GENERATOR_CHECKPOINT'], map_location=self.device)
            # The state dict may be wrapped under 'gen' or 'state_dict'; otherwise use it as-is.
            state = ckpt
            if isinstance(ckpt, dict):
                for wrapper_key in ('gen', 'state_dict'):
                    if wrapper_key in ckpt:
                        state = ckpt[wrapper_key]
                        break
            self.generator.load_state_dict(state)
            self.generator.eval()
            print("‚úÖ Generator loaded.")
        except Exception as e:
            raise RuntimeError(f"Failed to load generator: {e}")

        # Load Cdiag (ResNet50)
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        resnet.fc = nn.Linear(resnet.fc.in_features, 1)
        self.cdiag = resnet
        try:
            cdiag_state = torch.load(config['CDIAG_CHECKPOINT'], map_location=self.device)
            self.cdiag.load_state_dict(cdiag_state)
            print("‚úÖ Cdiag loaded.")
        except Exception:
            print("‚ö†Ô∏è Cdiag checkpoint missing/invalid; using ImageNet init.")
        self.cdiag.to(self.device).eval()

        # feature extractor (remove final fc)
        backbone_layers = list(self.cdiag.children())[:-1]
        self.feature_extractor = nn.Sequential(*backbone_layers).to(self.device).eval()

    # Load existing balanced synthetic features from CSV (memory-safe)
    def load_synthetic_features_from_disk(self, df):
        df = df.reset_index(drop=True)
        feats = []
        labels = []
        # We'll process in small batches to avoid OOM
        batch_paths = []
        batch_labels = []
        for _, row in df.iterrows():
            batch_paths.append(os.path.join(self.config['SYNTHETIC_IMAGE_DIR'], str(row['image_id'])))
            batch_labels.append(int(row['Pneumonia']))
        # process in mini-batches
        bs = 64
        with torch.no_grad():
            for i in tqdm(range(0, len(batch_paths), bs), desc="Loading synthetic features"):
                paths_chunk = batch_paths[i:i+bs]
                labs_chunk = batch_labels[i:i+bs]
                imgs = []
                for p in paths_chunk:
                    try:
                        img = Image.open(p).convert("RGB")
                        t = transforms.ToTensor()(img)  # [0,1]
                        imgs.append(t)
                    except Exception:
                        imgs.append(torch.zeros(3, self.config['EVAL_IMG_SIZE'], self.config['EVAL_IMG_SIZE']))
                x = torch.stack(imgs, dim=0).to(self.device)
                # normalize for feature extractor
                x = F.interpolate(x, size=(self.config['EVAL_IMG_SIZE'], self.config['EVAL_IMG_SIZE']), mode='bilinear', align_corners=False)
                x = (x - self.preproc.mean) / self.preproc.std
                f = self.feature_extractor(x).view(x.size(0), -1).detach().cpu()
                feats.append(f)
                labels.extend(labs_chunk)
                # free
                del x, f; torch.cuda.empty_cache()
        feats = torch.cat(feats, dim=0).numpy()
        labels = np.array(labels, dtype=int)
        return feats, labels

    # If CSV exists & balanced, load features; otherwise raise error (we assume user already has balanced dataset)
    def get_synthetic_features(self):
        if not os.path.exists(self.config['SYNTHETIC_CSV_PATH']):
            raise RuntimeError("Synthetic CSV not found at configured path: " + self.config['SYNTHETIC_CSV_PATH'])
        df = pd.read_csv(self.config['SYNTHETIC_CSV_PATH'])
        cnt0 = len(df[df['Pneumonia']==0]); cnt1 = len(df[df['Pneumonia']==1])
        if cnt0 < self.config['TARGET_PER_CLASS'] or cnt1 < self.config['TARGET_PER_CLASS']:
            raise RuntimeError(f"Synthetic CSV found but not balanced: ({cnt0},{cnt1})")
        print(f"‚úÖ Existing balanced synthetic dataset found ({cnt0},{cnt1}). Loading features.")
        return self.load_synthetic_features_from_disk(df)

    # Real feature extraction (keeps race)
    def extract_real_features_with_race(self):
        """Extract features for the real test CSV, keeping labels and race groups aligned.

        Samples whose label is the ``-1`` sentinel (emitted by the dataset for
        missing or unreadable images) are dropped before feature extraction.

        Returns:
            ``(X, y, races)``: feature array, label array, and an object array of
            race values, all in the same order.

        Raises:
            RuntimeError: if no batch contained a usable sample.
        """
        size = self.config['EVAL_IMG_SIZE']
        df = pd.read_csv(self.config['REAL_TEST_CSV']).reset_index(drop=True)
        transform = A.Compose([
            A.Resize(size, size),
            A.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
            ToTensorV2(),
        ])
        dataset = RobustRealDatasetWithRace(df, self.config['REAL_IMAGE_DIR'], transform=transform, eval_size=size)
        loader = DataLoader(
            dataset,
            batch_size=self.config['BATCH_SIZE'],
            shuffle=False,
            num_workers=self.config['NUM_WORKERS'],
        )

        feat_chunks, label_chunks, race_list = [], [], []
        with torch.no_grad():
            for imgs, labels, batch_races in tqdm(loader, desc="Extracting real features"):
                keep = (labels != -1)
                if not keep.any():
                    continue
                batch = imgs[keep].to(self.device)
                f = self.feature_extractor(batch).view(batch.shape[0], -1).detach().cpu()
                feat_chunks.append(f)
                label_chunks.append(torch.tensor(labels[keep].cpu().numpy()))
                race_list.extend(r for i, r in enumerate(batch_races) if keep[i])
                del batch, f
                torch.cuda.empty_cache()

        if not feat_chunks:
            raise RuntimeError("No real features extracted.")
        X = torch.cat(feat_chunks, dim=0).numpy()
        y = torch.cat(label_chunks).numpy()
        return X, y, np.array(race_list, dtype=object)

    # Clinical scoring using Cseg: pass GAN output in [0,1] at 512x512, no ImageNet norm.
    def evaluate_clinical_plausibility(self, n_samples=1024):
        """Score generated images by how far their predicted lung area is from the reference.

        For each generated image the Cseg sigmoid mask is averaged over all pixels
        (a soft area fraction) and converted to
        ``|area - PLAUSIBLE_LUNG_AREA_MEAN| / PLAUSIBLE_LUNG_AREA_STD``; the mean
        over all images is returned.

        Exactly ``n_samples`` images are generated: the last batch is shortened
        instead of rounding the total up to a multiple of the batch size.

        Args:
            n_samples: number of images to generate; must be positive.

        Returns:
            The mean score as a float, or ``None`` if Cseg could not be loaded.

        Raises:
            ValueError: if ``n_samples`` is not positive.
        """
        if n_samples <= 0:
            raise ValueError("n_samples must be a positive integer")

        print("\nü©∫ Calculating Clinical Score (L_area) using Cseg...")
        try:
            cseg = smp.Unet('resnet34', in_channels=3, classes=1).to(self.device)
            cseg.load_state_dict(torch.load(self.config['CSEG_CHECKPOINT'], map_location=self.device))
            cseg.eval()
            print("‚úÖ Cseg loaded.")
        except Exception as e:
            print("‚ùå Cseg failed to load:", e)
            return None

        seg_size = (self.config['SEG_IMG_SIZE'], self.config['SEG_IMG_SIZE'])
        area_mean = self.config['PLAUSIBLE_LUNG_AREA_MEAN']
        area_std = self.config['PLAUSIBLE_LUNG_AREA_STD']
        # Cap the batch at 32 and keep it at least 1 so the loop always makes progress.
        max_batch = max(1, min(self.config['BATCH_SIZE'], 32))

        scores = []
        remaining = n_samples
        with torch.no_grad():
            while remaining > 0:
                current = min(max_batch, remaining)
                z = torch.randn(current, self.config['LATENT_DIM'], device=self.device)
                # steps=7 runs all seven progressive blocks of the Generator defined in this notebook.
                fake = self.generator(z, alpha=1.0, steps=7)  # [-1,1]
                fake_01 = self.preproc.gan_to_01(fake)         # [0,1]
                fake_resized = F.interpolate(fake_01, size=seg_size, mode='bilinear', align_corners=False)
                masks = torch.sigmoid(cseg(fake_resized))
                area_pct = masks.sum(dim=[2,3]) / (masks.shape[2] * masks.shape[3])
                score = torch.abs(area_pct - area_mean) / area_std
                scores.append(score.detach().cpu())
                remaining -= current
                del fake, fake_01, fake_resized, masks
                torch.cuda.empty_cache()

        avg_score = torch.cat(scores).mean().item()
        print(f"   üëâ Clinical L_area Score: {avg_score:.4f}")
        return avg_score

    # Fairness metric (TPR per race, EOD)
    def compute_fairness_metrics(self, y_true, y_pred, races, min_pos=None):
        if min_pos is None: min_pos = self.config['MIN_GROUP_POSITIVES_FOR_TPR']
        df = pd.DataFrame({'label': y_true, 'pred': y_pred, 'race': races})
        pos_df = df[df['label']==1]
        tprs = {}
        for r in np.unique(pos_df['race']):
            group = pos_df[pos_df['race']==r]
            if len(group) < min_pos:
                # include small groups if min_pos == 1 (paper setting)
                pass
            tp = group['pred'].sum()
            fn = len(group) - tp
            tpr = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            tprs[r] = tpr
        if len(tprs) >= 2:
            eod = float(max(tprs.values()) - min(tprs.values()))
        else:
            eod = 0.0
        return eod, tprs

    # Main run
    def run(self):
        """Run the TRTR vs TSTR comparison, then the fairness and clinical summaries.

        Steps:
            1. Load features of the pre-generated synthetic set.
            2. Extract features (plus labels and race) of the real CSV.
            3. Split the real data 70/30 (stratified); all testing uses the 30% part.
            4. For LogisticRegression, RandomForest and SVM, fit one model on the
               real training part and one on the synthetic set; score both on the
               real test part, save confusion matrices and (optionally) the models.
            5. Write the results CSV, compute fairness from the RandomForest
               trained on synthetic data, compute the clinical score, and write
               the summary CSV.

        The fairness step uses the predictions kept in memory from step 4, so it
        always reflects the synthetic-trained RandomForest, also when
        ``SAVE_DOWNSTREAM_MODELS`` is off (the old fallback used the real-trained model).

        Returns:
            dict with keys ``'results_table'`` (DataFrame), ``'fairness'``
            (summary dict) and ``'clinical_score'``.
        """
        # 1. Load synthetic (assumes you already have balanced CSV + images)
        X_synth, y_synth = self.get_synthetic_features()

        # 2. Extract real features (and races)
        X_real, y_real, races = self.extract_real_features_with_race()

        # 3. Split real into train/test (we test on test only)
        idxs = np.arange(len(X_real))
        train_idx, test_idx = train_test_split(idxs, test_size=0.3, random_state=42, stratify=y_real)
        X_real_train, y_real_train = X_real[train_idx], y_real[train_idx]
        X_real_test, y_real_test = X_real[test_idx], y_real[test_idx]
        races_test = races[test_idx]

        # 4. Classifiers: factories give every training regime a fresh, unfitted estimator.
        factories = {
            'LogisticRegression': lambda: LogisticRegression(max_iter=1000, random_state=42),
            'RandomForest': lambda: RandomForestClassifier(n_estimators=100, random_state=42),
            'SVM': lambda: SVC(probability=True, random_state=42),
        }
        # (results label, file tag, plot title tag, training features, training labels)
        regimes = (
            ('Train on Real', 'real', 'Train Real', X_real_train, y_real_train),
            ('Train on Synthetic (TSTR)', 'synth', 'Train Synth', X_synth, y_synth),
        )

        labels = [0, 1]
        # AUC is undefined when the test split contains a single class.
        has_both_classes = len(np.unique(y_real_test)) > 1
        results_rows = []
        synth_preds = {}  # classifier name -> predictions of the synthetic-trained model

        for name, make_clf in factories.items():
            for method, tag, title_tag, X_fit, y_fit in regimes:
                clf = make_clf()
                clf.fit(X_fit, y_fit)
                if hasattr(clf, "predict_proba"):
                    scores = clf.predict_proba(X_real_test)[:, 1]
                else:
                    scores = clf.decision_function(X_real_test)
                preds = clf.predict(X_real_test)

                results_rows.append({
                    'Method': method,
                    'Classifier': name,
                    'AUC': roc_auc_score(y_real_test, scores) if has_both_classes else float('nan'),
                    'Accuracy': accuracy_score(y_real_test, preds),
                    'F1': f1_score(y_real_test, preds, zero_division=0),
                })

                if self.config['SAVE_DOWNSTREAM_MODELS']:
                    outpath = os.path.join(self.config['CLASSIFIER_OUTPUT_DIR'], f"{name}_trained_on_{tag}.joblib")
                    joblib.dump(clf, outpath)

                cm_path = os.path.join(self.config['OUTPUT_DIR'], f"confmat_{name}_{tag}.png")
                save_confusion_matrix(y_real_test, preds, labels, cm_path, title=f"{name} ({title_tag})")

                if tag == 'synth':
                    synth_preds[name] = preds

            print(f"   Completed evaluation for {name}.")

        # Save results table
        df_res = pd.DataFrame(results_rows)
        results_csv = os.path.join(self.config['OUTPUT_DIR'], 'tstr_trtr_comparison_v4.csv')
        df_res.to_csv(results_csv, index=False)
        print("\nüìä Results table saved to:", results_csv)

        # Fairness & Clinical summary (use RandomForest trained on synth as primary)
        preds_rf = synth_preds['RandomForest']
        eod, tprs = self.compute_fairness_metrics(y_real_test, preds_rf, races_test)
        clin_score = self.evaluate_clinical_plausibility(n_samples=1024)

        summary = {
            'eod': eod,
            'tprs': tprs,
            'clinical_L_area': clin_score,
            'synth_count': len(y_synth)
        }
        pd.DataFrame([summary]).to_csv(os.path.join(self.config['OUTPUT_DIR'], 'master_v4_summary.csv'), index=False)
        print("\n‚úÖ MASTER_v4 complete. Outputs saved to:", self.config['OUTPUT_DIR'])

        return {'results_table': df_res, 'fairness': summary, 'clinical_score': clin_score}

# -------------------------
# Run main
# -------------------------
# Entry point: mount Drive when running on Colab, then run the full evaluation.
if __name__ == "__main__":
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=False)
    except Exception:
        # Best-effort: outside Colab the import (or the mount) fails and is ignored on purpose.
        pass

    # `evaluator` and `outputs` stay in the notebook namespace after this cell runs.
    evaluator = MasterEvaluatorV4(CONFIG)
    outputs = evaluator.run()
    print(outputs)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
‚úÖ Generator loaded.
‚úÖ Cdiag loaded.
‚úÖ Existing balanced synthetic dataset found (2500,2500). Loading features.



Loading synthetic features:   0%|          | 0/79 [00:00<?, ?it/s][A
Loading synthetic features:   1%|‚ñè         | 1/79 [00:00<01:06,  1.17it/s][A
Loading synthetic features:   3%|‚ñé         | 2/79 [00:01<01:04,  1.20it/s][A
Loading synthetic features:   4%|‚ñç         | 3/79 [00:02<01:03,  1.20it/s][A
Loading synthetic features:   5%|‚ñå         | 4/79 [00:03<01:03,  1.18it/s][A
Loading synthetic features:   6%|‚ñã         | 5/79 [00:04<00:58,  1.26it/s][A
Loading synthetic features:   8%|‚ñä         | 6/79 [00:04<00:57,  1.27it/s][A
Loading synthetic features:   9%|‚ñâ         | 7/79 [00:05<00:54,  1.31it/s][A
Loading synthetic features:  10%|‚ñà         | 8/79 [00:06<00:53,  1.33it/s][A
Loading synthetic features:  11%|‚ñà‚ñè        | 9/79 [00:07<00:52,  1.34it/s][A
Loading synthetic features:  13%|‚ñà‚ñé        | 10/79 [00:07<00:50,  1.35it/s][A
Loading synthetic features:  14%|‚ñà‚ñç        | 11/79 [00:08<00:49,  1.36it/s][A
Loading synthetic features:  15%|‚ñà‚ñå  

   Completed evaluation for LogisticRegression.
   Completed evaluation for RandomForest.
   Completed evaluation for SVM.

üìä Results table saved to: /content/drive/MyDrive/CAF-GAN/outputs/evaluation_results_master_v4/tstr_trtr_comparison_v4.csv

ü©∫ Calculating Clinical Score (L_area) using Cseg...
‚úÖ Cseg loaded.
   üëâ Clinical L_area Score: 1.8620

‚úÖ MASTER_v4 complete. Outputs saved to: /content/drive/MyDrive/CAF-GAN/outputs/evaluation_results_master_v4/
{'results_table':                       Method          Classifier       AUC  Accuracy        F1
0              Train on Real  LogisticRegression  0.563776  0.659341  0.474576
1  Train on Synthetic (TSTR)  LogisticRegression  0.647959  0.648352  0.238095
2              Train on Real        RandomForest  0.560969  0.593407  0.350877
3  Train on Synthetic (TSTR)        RandomForest  0.651276  0.626374  0.190476
4              Train on Real                 SVM  0.566582  0.670330  0.375000
5  Train on Synthetic (TSTR)        

In [17]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold
from scipy.spatial.distance import jensenshannon
import segmentation_models_pytorch as smp
import warnings

warnings.filterwarnings("ignore")

# ==========================================
#  THE ULTIMATE ROBUST EVALUATOR CLASS
# ==========================================

class UltimateCAFGANEvaluator:
    """Scorecard for synthetic data: utility, realism, fairness and clinical plausibility.

    The feature-based checks operate on pre-extracted feature matrices; the clinical
    check samples images from a generator and segments them with a U-Net.

    Config keys read by this class:
        DEVICE:          device string passed to ``torch.device``.
        OUTPUT_DIR:      directory that receives ``FINAL_ROBUST_SCORES.csv``.
        CSEG_CHECKPOINT: path of the segmentation state dict.
        LATENT_DIM:      width of the latent vector fed to the generator.
    """

    def __init__(self, config):
        """Store the config, create the output directory and try to load the segmenter.

        A failure while building/loading the segmentation model is reported and
        tolerated: ``has_cseg`` is set to False and the clinical check is skipped.
        """
        self.config = config
        self.device = torch.device(config['DEVICE'])
        os.makedirs(config['OUTPUT_DIR'], exist_ok=True)

        # Defined up front so the attribute exists even when loading fails.
        self.cseg = None
        self.has_cseg = False

        # Load CSEG for Clinical Score (Lung Area)
        print("ü©∫ Loading Clinical Segmentation Model...")
        try:
            cseg = smp.Unet('resnet34', in_channels=3, classes=1).to(self.device)
            cseg.load_state_dict(torch.load(config['CSEG_CHECKPOINT'], map_location=self.device))
            cseg.eval()
            self.cseg = cseg
            self.has_cseg = True
        except Exception as e:
            # Deliberate best-effort: the remaining metrics do not need the segmenter.
            print(f"‚ö†Ô∏è Could not load CSEG: {e}. Clinical scores will be skipped.")

    # -----------------------------------------------------------
    # 1. UTILITY: Train on Synthetic, Test on Real (TSTR)
    # -----------------------------------------------------------
    def evaluate_utility_tstr(self, X_synth, y_synth, X_real, y_real):
        """Train logistic regression on synthetic features and score ROC-AUC on real ones.

        With fewer than 10 real samples a single AUC over the whole real set is
        returned. Otherwise the real set is split into 5 stratified folds and the
        per-fold AUCs are averaged; folds containing a single class are skipped.

        Returns:
            The (mean) AUC, or 0.5 when no AUC can be computed because the real
            labels (or every fold) contain only one class.
        """
        print("\nü§ñ Running Utility Evaluation (TSTR)...")
        # Index-array selection below needs ndarray inputs.
        X_real = np.asarray(X_real)
        y_real = np.asarray(y_real)

        if len(y_real) < 10:
            print("‚ö†Ô∏è Very small test set. Doing simple train/test.")
            if len(np.unique(y_real)) < 2:
                # AUC is undefined with a single class; report chance level.
                return 0.5
            clf = LogisticRegression(max_iter=1000, random_state=42)
            clf.fit(X_synth, y_synth)
            return roc_auc_score(y_real, clf.predict_proba(X_real)[:, 1])

        # The training set is the same for every fold, so fit once instead of per fold.
        clf = LogisticRegression(max_iter=1000, random_state=42)
        clf.fit(X_synth, y_synth)

        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        aucs = []
        for _, test_idx in skf.split(X_real, y_real):
            y_test_fold = y_real[test_idx]
            if len(np.unique(y_test_fold)) > 1:
                probs = clf.predict_proba(X_real[test_idx])[:, 1]
                aucs.append(roc_auc_score(y_test_fold, probs))

        if not aucs:
            # Avoid np.mean([]) -> NaN; mirror the chance-level fallback used above.
            print("‚ö†Ô∏è No fold contained both classes. Returning chance-level AUC.")
            return 0.5

        mean_auc = float(np.mean(aucs))
        print(f"   ‚úÖ TSTR Mean AUC: {mean_auc:.4f}")
        return mean_auc

    # -----------------------------------------------------------
    # 2. PRIVACY/QUALITY: Distinguishability Check
    # -----------------------------------------------------------
    def evaluate_distinguishability(self, X_real, X_synth):
        """Measure how well a random forest separates real from synthetic features.

        An equal number of rows (at most 2000 per source) is drawn from each set
        with a fixed seed. Accuracy is computed from held-out predictions of a
        stratified k-fold split, so the classifier is never scored on rows it was
        fitted on (training-set accuracy is close to 1.0 for a forest regardless
        of how realistic the synthetic data is).

        Returns:
            Held-out accuracy (0.5 means indistinguishable), or NaN when fewer
            than two rows per source are available.
        """
        print("\nüïµÔ∏è Running Distinguishability Check...")
        X_real = np.asarray(X_real)
        X_synth = np.asarray(X_synth)

        # SAFE SAMPLING: Don't crash if we have few images
        n_samples = min(len(X_real), len(X_synth), 2000)
        n_splits = min(5, n_samples)  # each class has n_samples members
        if n_splits < 2:
            print("‚ö†Ô∏è Not enough samples for a held-out distinguishability check.")
            return float('nan')

        # Local seeded generator: reproducible and leaves the global NumPy RNG untouched.
        rng = np.random.RandomState(42)
        real_idx = rng.choice(len(X_real), n_samples, replace=False)
        synth_idx = rng.choice(len(X_synth), n_samples, replace=False)

        X_combined = np.vstack([X_real[real_idx], X_synth[synth_idx]])
        y_combined = np.array([0] * n_samples + [1] * n_samples)  # 0=Real, 1=Fake

        held_out_preds = np.empty_like(y_combined)
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        for train_idx, test_idx in skf.split(X_combined, y_combined):
            clf = RandomForestClassifier(n_estimators=50, max_depth=10, random_state=42)
            clf.fit(X_combined[train_idx], y_combined[train_idx])
            held_out_preds[test_idx] = clf.predict(X_combined[test_idx])

        acc = accuracy_score(y_combined, held_out_preds)

        print(f"   ‚úÖ Discriminator Accuracy: {acc:.4f} (0.50 is perfect realism)")
        return acc

    # -----------------------------------------------------------
    # 3. FAIRNESS: EOD and Demographic Parity
    # -----------------------------------------------------------
    def evaluate_fairness(self, X_synth, y_synth, X_real, y_real, races):
        """Compute group gaps for a synthetic-trained classifier applied to real data.

        Args:
            races: one group value per row of ``X_real``.

        Returns:
            ``(dp_diff, eod)`` where ``dp_diff`` is the max-min spread of the mean
            prediction per group and ``eod`` is the max-min spread of the
            true-positive rate per group (groups without positives are left out
            of the TPR comparison). Note that only TPR is compared; false-positive
            rates are not part of ``eod``. Either value is 0.0 when no group
            qualifies.
        """
        print("\n‚öñÔ∏è Running Fairness Evaluation...")
        # Boolean-mask indexing below needs ndarray inputs.
        y_real = np.asarray(y_real)
        races = np.asarray(races)

        clf = LogisticRegression(max_iter=1000, random_state=42)
        clf.fit(X_synth, y_synth)
        preds = clf.predict(X_real)

        pos_rates = {}
        tprs = {}
        for r in np.unique(races):
            mask = (races == r)
            y_true_g = y_real[mask]
            y_pred_g = preds[mask]

            # Demographic Parity: mean prediction within the group
            pos_rates[r] = np.mean(y_pred_g)

            # True-positive rate, only defined when the group has positives
            n_pos = np.sum(y_true_g == 1)
            if n_pos > 0:
                tprs[r] = np.sum((y_true_g == 1) & (y_pred_g == 1)) / n_pos

        dp_diff = (max(pos_rates.values()) - min(pos_rates.values())) if pos_rates else 0.0
        eod = (max(tprs.values()) - min(tprs.values())) if tprs else 0.0

        print(f"   ‚úÖ Fairness Gap (Demog Parity): {dp_diff:.4f}")
        print(f"   ‚úÖ Fairness Gap (Equalized Odds): {eod:.4f}")
        return dp_diff, eod

    # -----------------------------------------------------------
    # 4. CLINICAL: Lung Segmentation Score
    # -----------------------------------------------------------
    def evaluate_clinical(self, generator, n_samples=100):
        """Average absolute deviation of the segmented area fraction from 0.22.

        Exactly ``n_samples`` images are generated in batches of up to 16 (the
        last batch may be smaller), rescaled from [-1, 1] to [0, 1], resized to
        512x512 and segmented; the sigmoid mask is summed and divided by the
        image area.

        Args:
            generator: callable module invoked as ``generator(z, alpha=1.0, steps=7)``.
            n_samples: number of images to generate; must be positive.

        Returns:
            The mean deviation (lower is better), or 0.0 when the segmentation
            model is unavailable.
            NOTE(review): 0.0 is also the best possible score, so a skipped check
            is indistinguishable from a perfect one in the scorecard.

        Raises:
            ValueError: if ``n_samples`` is not positive.
        """
        if not self.has_cseg:
            return 0.0
        print("\nü©ª Running Clinical Plausibility Check...")

        if n_samples <= 0:
            raise ValueError("n_samples must be a positive integer")

        batch_size = 16
        image_size = 512
        target_area = 0.22  # Deviation from standard lung area ~0.22
        scores = []

        generator.eval()
        with torch.no_grad():
            # Step through the requested count so the remainder batch is not dropped
            # and small n_samples still yields at least one batch.
            for start in range(0, n_samples, batch_size):
                current = min(batch_size, n_samples - start)
                z = torch.randn(current, self.config['LATENT_DIM'], device=self.device)
                fake = generator(z, alpha=1.0, steps=7)
                fake = (fake + 1) * 0.5  # [-1,1] -> [0,1]

                fake_resized = F.interpolate(fake, size=(image_size, image_size), mode='bilinear')
                masks = torch.sigmoid(self.cseg(fake_resized))
                area = masks.sum(dim=[2, 3]) / (image_size * image_size)

                scores.append(torch.abs(area - target_area))

        avg_dev = torch.cat(scores).mean().item()
        print(f"   ‚úÖ Clinical Deviation: {avg_dev:.4f} (Lower is better)")
        return avg_dev

    # -----------------------------------------------------------
    # MASTER RUN
    # -----------------------------------------------------------
    def run_full_suite(self, X_synth, y_synth, X_real, y_real, races, generator):
        """Run all four evaluations, print the scorecard and write it to CSV.

        Returns:
            The one-row DataFrame that was written to
            ``<OUTPUT_DIR>/FINAL_ROBUST_SCORES.csv``.
        """
        tstr_score = self.evaluate_utility_tstr(X_synth, y_synth, X_real, y_real)
        dist_score = self.evaluate_distinguishability(X_real, X_synth)
        dp_diff, eod_score = self.evaluate_fairness(X_synth, y_synth, X_real, y_real, races)
        clin_score = self.evaluate_clinical(generator)

        print("\n" + "="*40)
        print("üèÜ CAF-GAN FINAL SCORECARD")
        print("="*40)
        print(f"1. Utility (TSTR AUC):      {tstr_score:.4f}  (Target: >0.70)")
        print(f"2. Realism (Discrim Acc):   {dist_score:.4f}  (Target: ~0.50 best)")
        print(f"3. Fairness (Demog Diff):   {dp_diff:.4f}  (Target: <0.10)")
        print(f"4. Fairness (EO Diff):      {eod_score:.4f}  (Target: <0.10)")
        print(f"5. Clinical (Lung Dev):     {clin_score:.4f}  (Target: <0.10)")
        print("="*40)

        res = pd.DataFrame([{
            'TSTR_AUC': tstr_score,
            'Realism_Acc': dist_score,
            'Fairness_DP': dp_diff,
            'Fairness_EOD': eod_score,
            'Clinical_Dev': clin_score
        }])
        # Report the exact path that was written (no doubled separator when
        # OUTPUT_DIR already ends with a slash).
        out_path = os.path.join(self.config['OUTPUT_DIR'], "FINAL_ROBUST_SCORES.csv")
        res.to_csv(out_path, index=False)
        print(f"Saved to {out_path}")
        return res

# ==========================================
# DATA LOADING & EXECUTION (Safe Mode)
# ==========================================

if __name__ == "__main__":
    # Step 1: reuse an `evaluator` already present in the namespace; build one only
    # when it is missing. MasterEvaluatorV4 and CONFIG are not defined in this cell,
    # so they must already exist when it runs.
    if not ('evaluator' in locals()):
        print("üîÑ Re-initializing MasterEvaluatorV4 to load data...")
        evaluator = MasterEvaluatorV4(CONFIG)

    # Step 2: pull the synthetic and real feature sets (plus race labels for the
    # real set) from the existing evaluator.
    print("üì• Loading Data features...")
    X_synth, y_synth = evaluator.get_synthetic_features()
    X_real, y_real, races = evaluator.extract_real_features_with_race()

    # Step 3: score everything with the evaluator class defined above, sampling
    # clinical images from the generator held by `evaluator`.
    ultimate_eval = UltimateCAFGANEvaluator(CONFIG)
    ultimate_eval.run_full_suite(
        X_synth=X_synth,
        y_synth=y_synth,
        X_real=X_real,
        y_real=y_real,
        races=races,
        generator=evaluator.generator,
    )

üì• Loading Data features...
‚úÖ Existing balanced synthetic dataset found (2500,2500). Loading features.



Loading synthetic features:   0%|          | 0/79 [00:00<?, ?it/s][A
Loading synthetic features:   1%|‚ñè         | 1/79 [00:02<03:51,  2.96s/it][A
Loading synthetic features:   3%|‚ñé         | 2/79 [00:03<02:05,  1.63s/it][A
Loading synthetic features:   4%|‚ñç         | 3/79 [00:04<01:32,  1.21s/it][A
Loading synthetic features:   5%|‚ñå         | 4/79 [00:05<01:15,  1.01s/it][A
Loading synthetic features:   6%|‚ñã         | 5/79 [00:05<01:06,  1.11it/s][A
Loading synthetic features:   8%|‚ñä         | 6/79 [00:06<01:01,  1.18it/s][A
Loading synthetic features:   9%|‚ñâ         | 7/79 [00:07<00:57,  1.25it/s][A
Loading synthetic features:  10%|‚ñà         | 8/79 [00:07<00:55,  1.28it/s][A
Loading synthetic features:  11%|‚ñà‚ñè        | 9/79 [00:08<00:53,  1.32it/s][A
Loading synthetic features:  13%|‚ñà‚ñé        | 10/79 [00:09<00:51,  1.34it/s][A
Loading synthetic features:  14%|‚ñà‚ñç        | 11/79 [00:10<00:50,  1.35it/s][A
Loading synthetic features:  15%|‚ñà‚ñå  

ü©∫ Loading Clinical Segmentation Model...

ü§ñ Running Utility Evaluation (TSTR)...
   ‚úÖ TSTR Mean AUC: 0.6308

üïµÔ∏è Running Distinguishability Check...
   ‚úÖ Discriminator Accuracy: 1.0000 (0.50 is perfect realism)

‚öñÔ∏è Running Fairness Evaluation...
   ‚úÖ Fairness Gap (Demog Parity): 0.2857
   ‚úÖ Fairness Gap (Equalized Odds): 0.5000

ü©ª Running Clinical Plausibility Check...
   ‚úÖ Clinical Deviation: 0.1314 (Lower is better)

üèÜ CAF-GAN FINAL SCORECARD
1. Utility (TSTR AUC):      0.6308  (Target: >0.70)
2. Realism (Discrim Acc):   1.0000  (Target: ~0.50 best)
3. Fairness (Demog Diff):   0.2857  (Target: <0.10)
4. Fairness (EO Diff):      0.5000  (Target: <0.10)
5. Clinical (Lung Dev):     0.1314  (Target: <0.10)
Saved to /content/drive/MyDrive/CAF-GAN/outputs/evaluation_results_master_v4//FINAL_ROBUST_SCORES.csv
