In [1]:
# Mount Google Drive at /content/drive so later cells can read and write
# files under it (the google.colab module is only available on Colab).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

# Navigate to your project directory in Google Drive
%cd /content/drive/MyDrive/CAF-GAN/

# Unzip the image files (this might take a few minutes on the first run)
# -q suppresses unzip's output; -n never overwrites existing files, so re-running skips files that were already extracted
!unzip -q -n mimic-cxr-jpg-2.0.0.zip

print("✅ Workspace ready and images unzipped.")

/content/drive/MyDrive/CAF-GAN
✅ Workspace ready and images unzipped.


In [2]:
!pip install pyyaml pandas scikit-learn albumentations segmentation-models-pytorch -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """
    DCGAN-style generator network for CAF-GAN.

    Maps a latent tensor of shape (N, latent_dim, 1, 1) to an image tensor of
    shape (N, channels, 256, 256). The closing Tanh keeps every output value
    inside [-1, 1].

    Args:
        latent_dim: Channel count of the latent input z.
        channels: Channel count of the generated image (defaults to 1).
    """
    def __init__(self, latent_dim, channels=1):
        super().__init__()
        self.latent_dim = latent_dim

        # Output widths of the hidden stages. The first stage expands the
        # 1x1 latent to 4x4; each later stage doubles the spatial size
        # (8, 16, 32, 64, 128).
        hidden_widths = (1024, 512, 256, 128, 64, 32)

        stack = []
        in_width = latent_dim
        for position, out_width in enumerate(hidden_widths):
            stride, padding = (1, 0) if position == 0 else (2, 1)
            stack.append(nn.ConvTranspose2d(in_width, out_width, 4, stride, padding, bias=False))
            stack.append(nn.BatchNorm2d(out_width))
            stack.append(nn.ReLU(True))
            in_width = out_width

        # Final doubling from 128x128 to 256x256, with no batch norm before Tanh.
        stack.append(nn.ConvTranspose2d(in_width, channels, 4, 2, 1, bias=False))
        stack.append(nn.Tanh())

        # A flat Sequential keeps the state_dict keys as main.<index>.<param>.
        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Run the latent batch through the transposed-convolution stack."""
        return self.main(input)

In [4]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image

# The dataset classes below serve the two critic training scripts (Cdiag and Cseg) and the main GAN training.

class MIMICCXRClassifierDataset(Dataset):
    """
    Image/label dataset for the Cdiag (classification) task.

    Every item is built from one row of ``df``:
    - the JPG is located under ``image_dir`` at
      p<first two characters of subject_id>/p<subject_id>/s<study_id>/<dicom_id>.jpg,
    - it is decoded as 3-channel RGB (as required by ResNet),
    - ``transform`` (called as ``transform(image=...)``) is applied when given,
    - the row's ``Pneumonia`` value is returned as a float32 tensor of shape (1,).
    """
    def __init__(self, df, image_dir, transform=None):
        self.df, self.image_dir, self.transform = df, image_dir, transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        subject = str(record['subject_id'])
        study = str(record['study_id'])

        # Folder layout: <image_dir>/pXX/p<subject>/s<study>/<dicom>.jpg
        path_parts = (
            f'p{subject[:2]}',
            f'p{subject}',
            f's{study}',
            f"{record['dicom_id']}.jpg",
        )
        image_path = os.path.join(self.image_dir, *path_parts)

        # Decode straight into an RGB numpy array
        pixels = np.array(Image.open(image_path).convert("RGB"))

        # Optional augmentation pipeline (dict in, dict out)
        if self.transform:
            pixels = self.transform(image=pixels)['image']

        target = torch.tensor(record['Pneumonia'], dtype=torch.float32)
        return pixels, target.unsqueeze(0)


class MIMICXRSegmentationDataset(Dataset):
    """
    Dataset for the Cseg (segmentation) task.

    For each row of ``df`` it:
    - loads the JPG image as 3-channel RGB and resizes it to 256x256,
    - loads the pre-generated PNG mask ``<mask_dir>/<dicom_id>.png`` as
      single-channel float32,
    - applies ``transform`` (called as ``transform(image=..., mask=...)``)
      when one is given,
    - returns ``(image, mask)`` where the mask is a tensor with a leading
      channel dimension.

    When ``transform`` is None the image is returned as the resized uint8
    numpy array of shape (256, 256, 3).
    """
    def __init__(self, df, image_dir, mask_dir, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']

        image_path = os.path.join(
            self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
        )
        mask_path = os.path.join(self.mask_dir, f"{dicom_id}.png")

        # Decode as RGB and resize to 256x256 while still a PIL image. This
        # yields the same uint8 array as a float32 -> uint8 round-trip would,
        # without the extra copies.
        # NOTE(review): 256x256 is assumed to be the size of the stored masks - confirm.
        image = np.array(
            Image.open(image_path).convert("RGB").resize((256, 256), Image.BILINEAR)
        )

        # Load mask as single-channel grayscale
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)

        # Map foreground pixels stored as 255 to 1.0. Only the exact value 255
        # is remapped; any other value is left unchanged.
        # NOTE(review): this presumes strictly binary 0/255 masks - confirm.
        mask[mask == 255.0] = 1.0

        # Apply augmentations jointly so image and mask stay aligned
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without a tensor-producing transform the mask is still a numpy array,
        # which has no .unsqueeze(); as_tensor is a no-op for tensors.
        mask = torch.as_tensor(mask)

        # Add a channel dimension for the mask for consistency
        return image, mask.unsqueeze(0)

class MIMICCXR_GANDataset(Dataset):
    """
    Dataset for the main GAN training.

    - Loads a JPG image as a 3-channel RGB float32 array.
    - Returns ``(image, label, race)`` where ``label`` is the row's
      ``Pneumonia`` value as a scalar float32 tensor and ``race`` is the
      one-hot encoding of the row's ``race_group`` as a float32 tensor.

    Attributes:
        df: Copy of the input frame with ``race_group`` cast to ``category``.
        race_categories: The category values, in one-hot column order.
        one_hot_races: One-hot DataFrame, row-aligned with ``df``.
    """
    def __init__(self, df, image_dir, transform=None):
        # assign() builds a new frame, so the caller's DataFrame is left
        # untouched (the cast used to be written back into it in place).
        self.df = df.assign(race_group=df['race_group'].astype('category'))
        self.image_dir = image_dir
        self.transform = transform

        # Pre-process sensitive attributes once, up front
        self.race_categories = self.df['race_group'].cat.categories
        self.one_hot_races = pd.get_dummies(self.df['race_group'])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']

        image_path = os.path.join(
            self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
        )

        # Load the image as RGB and hand it to the transform as float32
        image = Image.open(image_path).convert("RGB")
        image = np.array(image, dtype=np.float32)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        label = torch.tensor(row['Pneumonia'], dtype=torch.float32)
        # Positional lookup keeps the one-hot row aligned with df.iloc[idx]
        race = torch.tensor(self.one_hot_races.iloc[idx].values, dtype=torch.float32)

        return image, label, race

In [7]:
config = {
  # Configuration for the final evaluation (Phase D): generate synthetic images,
  # train a classifier on them, then evaluate that classifier on the real test split.

  # --- Paths ---
  # Path to the trained generator from Phase C (loaded with torch.load + load_state_dict)
  "GENERATOR_CHECKPOINT": "/content/drive/MyDrive/CAF-GAN/outputs/gan/netG_epoch_200.pth", # IMPORTANT: point this at your best generator checkpoint
  # CSV describing the real test images; read by evaluate_on_real_data
  "REAL_DATA_CSV_TEST": "/content/drive/MyDrive/CAF-GAN/data/splits/test.csv",
  # Root folder under which the real JPGs are looked up
  "IMAGE_DIR_REAL": "/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/",

  # --- Synthetic Data Generation ---
  # Folder the generated JPGs are written to, and the CSV listing their labels
  "SYNTHETIC_DATA_DIR": "/content/drive/MyDrive/CAF-GAN/data/synthetic_images/",
  "SYNTHETIC_CSV_PATH": "/content/drive/MyDrive/CAF-GAN/data/synthetic_images/labels.csv",
  "NUM_SYNTHETIC_IMAGES": 2000, # Generate a dataset of the same size as our original subset
  "GENERATION_BATCH_SIZE": 32,

  # --- Downstream Classifier Training ---
  # Relative path: resolved against the current working directory at run time
  "CLASSIFIER_OUTPUT_DIR": "outputs/downstream_classifier/",
  # File name of the state_dict saved once training has finished
  "CLASSIFIER_MODEL_NAME": "best_synth_trained_classifier.pth",
  # Side length passed to the Resize transform for training and evaluation
  "IMG_SIZE": 256,
  "BATCH_SIZE": 32,
  "EPOCHS": 15,
  "LEARNING_RATE": 0.0001,

  # --- System ---
  # Requested device; main() falls back to 'cpu' when CUDA is unavailable
  "DEVICE": "cuda",
  "NUM_WORKERS": 2,
  # LATENT_DIM and CHANNELS are passed to Generator(...) before the checkpoint
  # is loaded, so they have to match the architecture the checkpoint was saved from
  "LATENT_DIM": 128,
  "CHANNELS": 3
}


In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
import torchvision.utils as vutils
import pandas as pd
from tqdm import tqdm
import os
import yaml
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, confusion_matrix

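# Generator and MIMICCXRClassifierDataset are defined in earlier cells of this notebook; when running this as a script inside the project, import them instead: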
# from src.models.generator import Generator
# from src.data.dataset import MIMICCXRClassifierDataset

def generate_synthetic_data(config, device):
    """Load the trained generator and write a synthetic image dataset to disk.

    Images are sampled in batches from random latent noise, saved as
    ``synth_<index>.jpg`` under ``config['SYNTHETIC_DATA_DIR']``, and listed
    together with a label in ``config['SYNTHETIC_CSV_PATH']``.

    NOTE(review): the generator is called with noise only, so the balanced
    0/1 ``Pneumonia`` labels produced here are assigned at random and carry no
    relation to the image content. A classifier trained on these pairs cannot
    learn a label-dependent signal; producing meaningful labels needs a
    label-conditioned generator or a separate labelling step.

    Args:
        config: Dict providing SYNTHETIC_DATA_DIR, SYNTHETIC_CSV_PATH,
            GENERATOR_CHECKPOINT, LATENT_DIM, CHANNELS, NUM_SYNTHETIC_IMAGES
            and GENERATION_BATCH_SIZE.
        device: Device the generator and the noise are placed on.

    Returns:
        pandas.DataFrame with columns ``image_id`` and ``Pneumonia``.
    """
    print("---  bước 1: Generating Synthetic Dataset ---")
    os.makedirs(config['SYNTHETIC_DATA_DIR'], exist_ok=True)
    # The CSV may be configured to live outside the image folder.
    csv_parent = os.path.dirname(config['SYNTHETIC_CSV_PATH'])
    if csv_parent:
        os.makedirs(csv_parent, exist_ok=True)

    # Load Generator
    netG = Generator(config['LATENT_DIM'], config['CHANNELS']).to(device)
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (or pass weights_only=True where supported).
    netG.load_state_dict(torch.load(config['GENERATOR_CHECKPOINT'], map_location=device))
    netG.eval()

    num_images = config['NUM_SYNTHETIC_IMAGES']
    batch_limit = config['GENERATION_BATCH_SIZE']

    # Create balanced labels for the synthetic data (see the NOTE in the docstring)
    num_positive = num_images // 2
    labels = np.array([1] * num_positive + [0] * (num_images - num_positive))
    np.random.shuffle(labels)

    image_ids = []
    generated_labels = []

    with torch.no_grad():
        for i in tqdm(range(0, num_images, batch_limit), desc="Generating Images"):
            # The last batch may be smaller than the configured batch size
            batch_size = min(batch_limit, num_images - i)
            noise = torch.randn(batch_size, config['LATENT_DIM'], 1, 1, device=device)
            fake_imgs = netG(noise)

            for j in range(batch_size):
                img_idx = i + j
                image_id = f"synth_{img_idx:05d}.jpg"
                # The generator ends in Tanh, so map the fixed range [-1, 1]
                # to [0, 1]. Without value_range every image would be
                # min-max stretched on its own, changing its contrast.
                vutils.save_image(
                    fake_imgs[j],
                    os.path.join(config['SYNTHETIC_DATA_DIR'], image_id),
                    normalize=True,
                    value_range=(-1, 1),
                )
                image_ids.append(image_id)
                generated_labels.append(labels[img_idx])

    # Save the labels to a CSV file
    synthetic_df = pd.DataFrame({'image_id': image_ids, 'Pneumonia': generated_labels})
    synthetic_df.to_csv(config['SYNTHETIC_CSV_PATH'], index=False)
    print(f"✅ Generated {config['NUM_SYNTHETIC_IMAGES']} synthetic images and saved labels to {config['SYNTHETIC_CSV_PATH']}")
    return synthetic_df

def train_downstream_classifier(config, device):
    """Fine-tune a ResNet-50 Pneumonia classifier on the synthetic dataset.

    The synthetic labels CSV is read from ``config['SYNTHETIC_CSV_PATH']`` and
    the images from ``config['SYNTHETIC_DATA_DIR']``. The network starts from
    the ImageNet weights (``IMAGENET1K_V1``) with its final layer replaced by a
    single-logit head, and is trained for ``config['EPOCHS']`` epochs with Adam
    and ``BCEWithLogitsLoss``. The state_dict after the last epoch is written
    to ``CLASSIFIER_OUTPUT_DIR/CLASSIFIER_MODEL_NAME``.

    Returns:
        The trained model, left on ``device``.
    """
    print("\n--- bước 2: Training Downstream Classifier on Synthetic Data ---")
    os.makedirs(config['CLASSIFIER_OUTPUT_DIR'], exist_ok=True)

    side = config['IMG_SIZE']
    augmentations = A.Compose([
        A.Resize(side, side),
        A.HorizontalFlip(p=0.5),
        # Channel statistics conventionally used with ImageNet-pretrained models
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    class SyntheticDataset(Dataset):
        """Flat-folder dataset: ``image_dir/<image_id>`` paired with its ``Pneumonia`` value."""
        def __init__(self, df, image_dir, transform):
            self.df, self.image_dir, self.transform = df, image_dir, transform

        def __len__(self):
            return len(self.df)

        def __getitem__(self, idx):
            record = self.df.iloc[idx]
            file_path = os.path.join(self.image_dir, record['image_id'])
            pixels = np.array(Image.open(file_path).convert("RGB"))
            target = torch.tensor(record['Pneumonia'], dtype=torch.float32)
            if self.transform:
                pixels = self.transform(image=pixels)['image']
            return pixels, target.unsqueeze(0)

    synthetic_labels = pd.read_csv(config['SYNTHETIC_CSV_PATH'])
    train_loader = DataLoader(
        SyntheticDataset(synthetic_labels, config['SYNTHETIC_DATA_DIR'], augmentations),
        batch_size=config['BATCH_SIZE'],
        shuffle=True,
        num_workers=config['NUM_WORKERS'],
    )

    # Pretrained backbone with a fresh single-logit output layer
    model = models.resnet50(weights='IMAGENET1K_V1')
    model.fc = nn.Linear(model.fc.in_features, 1)
    model.to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['LEARNING_RATE'])

    num_epochs = config['EPOCHS']
    for epoch_idx in range(num_epochs):
        model.train()
        progress = tqdm(train_loader, desc=f"Epoch {epoch_idx+1}/{num_epochs}")
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(batch_images), batch_labels)
            loss.backward()
            optimizer.step()

    checkpoint_path = os.path.join(config['CLASSIFIER_OUTPUT_DIR'], config['CLASSIFIER_MODEL_NAME'])
    torch.save(model.state_dict(), checkpoint_path)
    print("✅ Downstream classifier training complete.")
    return model

def evaluate_on_real_data(classifier, config, device):
    """Evaluate the synthetically-trained classifier on the real test set.

    Runs ``classifier`` over the images listed in
    ``config['REAL_DATA_CSV_TEST']``, thresholds the sigmoid outputs at 0.5
    and prints a report containing:

    - utility metrics: AUC, accuracy and F1;
    - fairness metrics: the true positive rate (TPR) per ``race_group`` and
      the Equal Opportunity Difference (max TPR - min TPR).

    A group without any positive sample has no defined TPR; it is reported as
    "n/a" and left out of the Equal Opportunity Difference instead of being
    counted as TPR 0. AUC is reported as NaN when the test labels contain a
    single class.

    Args:
        classifier: Model returning one logit per image.
        config: Dict providing IMG_SIZE, REAL_DATA_CSV_TEST, IMAGE_DIR_REAL,
            BATCH_SIZE, NUM_WORKERS and NUM_SYNTHETIC_IMAGES.
        device: Device the classifier lives on.

    Returns:
        Dict with keys ``auc``, ``accuracy``, ``f1``, ``tpr_per_group`` and
        ``eod`` holding the values that were printed.
    """
    print("\n--- bước 3: Evaluating on Real Test Data ---")

    transform = A.Compose([
        A.Resize(config['IMG_SIZE'], config['IMG_SIZE']),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    test_df = pd.read_csv(config['REAL_DATA_CSV_TEST'])

    # Reuse the MIMIC path handling of the classifier dataset and additionally
    # return the sample's race group for the fairness breakdown.
    class RealTestDataset(MIMICCXRClassifierDataset):
        def __getitem__(self, idx):
            image, label = super().__getitem__(idx)
            race = self.df.iloc[idx]['race_group']
            return image, label, race

    test_dataset = RealTestDataset(test_df, config['IMAGE_DIR_REAL'], transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=config['NUM_WORKERS'])

    classifier.eval()
    all_preds, all_labels, all_races = [], [], []

    with torch.no_grad():
        for images, labels, races in tqdm(test_loader, desc="Evaluating"):
            outputs = classifier(images.to(device))
            preds = torch.sigmoid(outputs).cpu().numpy().flatten()

            all_preds.extend(preds)
            # Labels are only collected, so they never need to leave the CPU
            all_labels.extend(labels.numpy().flatten())
            all_races.extend(races)

    # --- Calculate Metrics ---
    df_results = pd.DataFrame({'label': all_labels, 'pred_prob': all_preds, 'race': all_races})
    df_results['prediction'] = (df_results['pred_prob'] > 0.5).astype(int)

    # Utility Metrics
    # roc_auc_score raises when only one class is present, so guard for it
    if df_results['label'].nunique() > 1:
        auc = roc_auc_score(df_results['label'], df_results['pred_prob'])
    else:
        auc = float('nan')
    accuracy = accuracy_score(df_results['label'], df_results['prediction'])
    f1 = f1_score(df_results['label'], df_results['prediction'])

    # Fairness Metrics (Equal Opportunity Difference)
    tpr_per_group = {}
    for group in df_results['race'].unique():
        group_df = df_results[df_results['race'] == group]
        tn, fp, fn, tp = confusion_matrix(group_df['label'], group_df['prediction'], labels=[0,1]).ravel()
        # TPR is undefined without positive samples; NaN marks that case
        tpr_per_group[group] = tp / (tp + fn) if (tp + fn) > 0 else float('nan')

    # Only groups with a defined TPR take part in the max/min difference
    defined_tprs = [tpr for tpr in tpr_per_group.values() if not np.isnan(tpr)]
    eod = max(defined_tprs) - min(defined_tprs) if defined_tprs else float('nan')

    # --- Print Report ---
    print("\n--- 📊 Evaluation Report ---")
    print(f"Trained on {config['NUM_SYNTHETIC_IMAGES']} synthetic images. Evaluated on {len(df_results)} real test images.")
    print("\n## 🎯 Overall Performance (Utility)")
    print(f"**AUC:** {auc:.4f}")
    print(f"**Accuracy:** {accuracy:.4f}")
    print(f"**F1-Score:** {f1:.4f}")

    print("\n## ⚖️ Fairness Performance")
    print("True Positive Rate (TPR) by Group:")
    for group, tpr in tpr_per_group.items():
        if np.isnan(tpr):
            print(f"  - {group}: n/a (no positive samples)")
        else:
            print(f"  - {group}: {tpr:.4f}")
    print(f"**Equal Opportunity Difference (Max TPR - Min TPR): {eod:.4f}**")
    print("\n--- Evaluation Complete ---")

    return {
        'auc': auc,
        'accuracy': accuracy,
        'f1': f1,
        'tpr_per_group': tpr_per_group,
        'eod': eod,
    }


def main(config):
    """Run the evaluation pipeline end to end: generate, train, evaluate.

    ``config`` is used as passed in; loading it from configs/evaluate.yaml
    is not done here.
    """
    # Honour the configured device only when CUDA is actually present
    cuda_present = torch.cuda.is_available()
    device = config['DEVICE'] if cuda_present else 'cpu'

    # Stage 1: sample the synthetic dataset from the trained generator
    generate_synthetic_data(config, device)

    # Stage 2: fit the downstream classifier on that synthetic data
    synth_trained_model = train_downstream_classifier(config, device)

    # Stage 3: score the classifier on the real test split
    evaluate_on_real_data(synth_trained_model, config, device)

# Run the full pipeline with the config dict defined in the cell above; the
# guard keeps it from running if this code is imported as a module.
if __name__ == '__main__':
    main(config)

---  bước 1: Generating Synthetic Dataset ---


Generating Images: 100%|██████████| 63/63 [00:21<00:00,  2.87it/s]


✅ Generated 2000 synthetic images and saved labels to /content/drive/MyDrive/CAF-GAN/data/synthetic_images/labels.csv

--- bước 2: Training Downstream Classifier on Synthetic Data ---


Epoch 1/15: 100%|██████████| 63/63 [00:24<00:00,  2.55it/s]
Epoch 2/15: 100%|██████████| 63/63 [00:24<00:00,  2.62it/s]
Epoch 3/15: 100%|██████████| 63/63 [00:23<00:00,  2.69it/s]
Epoch 4/15: 100%|██████████| 63/63 [00:23<00:00,  2.66it/s]
Epoch 5/15: 100%|██████████| 63/63 [00:23<00:00,  2.65it/s]
Epoch 6/15: 100%|██████████| 63/63 [00:23<00:00,  2.65it/s]
Epoch 7/15: 100%|██████████| 63/63 [00:23<00:00,  2.67it/s]
Epoch 8/15: 100%|██████████| 63/63 [00:23<00:00,  2.66it/s]
Epoch 9/15: 100%|██████████| 63/63 [00:23<00:00,  2.66it/s]
Epoch 10/15: 100%|██████████| 63/63 [00:23<00:00,  2.64it/s]
Epoch 11/15: 100%|██████████| 63/63 [00:23<00:00,  2.67it/s]
Epoch 12/15: 100%|██████████| 63/63 [00:23<00:00,  2.67it/s]
Epoch 13/15: 100%|██████████| 63/63 [00:23<00:00,  2.67it/s]
Epoch 14/15: 100%|██████████| 63/63 [00:23<00:00,  2.64it/s]
Epoch 15/15: 100%|██████████| 63/63 [00:23<00:00,  2.67it/s]


✅ Downstream classifier training complete.

--- bước 3: Evaluating on Real Test Data ---


Evaluating: 100%|██████████| 10/10 [03:46<00:00, 22.62s/it]



--- 📊 Evaluation Report ---
Trained on 2000 synthetic images. Evaluated on 301 real test images.

## 🎯 Overall Performance (Utility)
**AUC:** 0.4849
**Accuracy:** 0.5548
**F1-Score:** 0.3232

## ⚖️ Fairness Performance
True Positive Rate (TPR) by Group:
  - WHITE: 0.2911
  - OTHER: 0.1250
  - BLACK: 0.2778
  - HISPANIC/LATINO: 0.3333
  - ASIAN: 0.2500
**Equal Opportunity Difference (Max TPR - Min TPR): 0.2083**

--- Evaluation Complete ---
