In [2]:
from google.colab import drive
# Mount Google Drive at /content/drive; the project files used by the later cells
# are read from /content/drive/MyDrive/CAF-GAN/.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

# Navigate to your project directory in Google Drive
%cd /content/drive/MyDrive/CAF-GAN/

# Unzip the image files (this might take a few minutes)
# -q keeps the output quiet; -n never overwrites files that already exist,
# so re-running this cell skips everything an earlier run already extracted.
!unzip -q -n mimic-cxr-jpg-2.0.0.zip

# NOTE(review): this message is printed unconditionally; it does not check whether unzip succeeded.
print("✅ Workspace ready and images unzipped.")

/content/drive/MyDrive/CAF-GAN
✅ Workspace ready and images unzipped.


In [3]:
# Use the %pip magic (not !pip) so the packages are installed into the environment
# of the running kernel rather than whatever `pip` the subshell resolves to.
# NOTE(review): versions are unpinned; pin them (pkg==x.y.z) once a working set is known
# so that a fresh runtime reproduces the same environment.
%pip install -q pyyaml pandas scikit-learn albumentations segmentation-models-pytorch

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [1]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image

# This single dataset file will serve both critic training scripts.

class MIMICCXRClassifierDataset(Dataset):
    """
    Image/label dataset used to train the Cdiag (classification) model.

    Each item corresponds to one row of ``df``: the JPG the row points to is
    read from disk, forced to 3-channel RGB (as required by ResNet) and paired
    with the row's ``Pneumonia`` value.

    Args:
        df: table with ``subject_id``, ``study_id``, ``dicom_id`` and
            ``Pneumonia`` columns.
        image_dir: root folder of the
            ``p<XX>/p<subject_id>/s<study_id>/<dicom_id>.jpg`` tree.
        transform: optional callable invoked as ``transform(image=...)`` whose
            result exposes the processed image under the ``'image'`` key.
    """
    def __init__(self, df, image_dir, transform=None):
        self.df, self.image_dir, self.transform = df, image_dir, transform

    def __len__(self):
        # One sample per row of the table
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        subject_id, study_id = str(record['subject_id']), str(record['study_id'])
        dicom_id = record['dicom_id']

        # Files live under <image_dir>/p<first two chars>/p<subject>/s<study>/<dicom>.jpg
        path_parts = [f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg']
        image_path = os.path.join(self.image_dir, *path_parts)

        # Decode as RGB and hand the pixels over as a numpy array
        pixels = np.array(Image.open(image_path).convert("RGB"))

        # Optional augmentation pipeline (keyword-style call, result read from 'image')
        if self.transform:
            pixels = self.transform(image=pixels)['image']

        # Float32 target with a leading dimension of size 1
        target = torch.tensor(record['Pneumonia'], dtype=torch.float32).unsqueeze(0)
        return pixels, target


class MIMICXRSegmentationDataset(Dataset):
    """
    Dataset for the Cseg (segmentation) task.

    For every row of ``df`` it
    - loads the JPG image as 3-channel RGB (to match the model's expected input channels),
    - loads the corresponding pre-generated PNG mask as single-channel grayscale,
    - resizes the image to ``image_size`` so image and mask agree before augmentation,
    - returns the image and the mask, the mask as a tensor with a leading channel dimension.

    Args:
        df: table with ``subject_id``, ``study_id`` and ``dicom_id`` columns.
        image_dir: root folder of the
            ``p<XX>/p<subject_id>/s<study_id>/<dicom_id>.jpg`` tree.
        mask_dir: folder containing one ``<dicom_id>.png`` mask per image.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            returning a mapping with ``'image'`` and ``'mask'`` entries.
        image_size: ``(width, height)`` the image is resized to before
            augmentation. Defaults to ``(256, 256)``, which is assumed to be
            the size of the stored masks.
    """
    def __init__(self, df, image_dir, mask_dir, transform=None, image_size=(256, 256)):
        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_size = tuple(image_size)

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _add_channel_dim(mask):
        """Return ``mask`` as a torch tensor with a new leading channel axis."""
        # Without a transform (or with one that lacks a to-tensor step) the mask is
        # still a numpy array, which has no ``unsqueeze``; convert it first.
        if not torch.is_tensor(mask):
            # ascontiguousarray: torch.from_numpy rejects negatively-strided views (e.g. flips)
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        return mask.unsqueeze(0)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']

        image_path = os.path.join(
            self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
        )
        mask_path = os.path.join(self.mask_dir, f"{dicom_id}.png")

        # Load the image as RGB and resize it BEFORE augmentation so that the
        # transform sees an image and a mask of matching size. The result is uint8.
        with Image.open(image_path) as image_file:
            image = np.array(image_file.convert("RGB").resize(self.image_size, Image.BILINEAR))

        # Load the mask as grayscale float32
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)

        # Map the 255 foreground value to 1.0; any other value is left unchanged.
        mask[mask == 255.0] = 1.0

        # Apply augmentations jointly to image and mask
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Add a channel dimension for the mask for consistency
        return image, self._add_channel_dim(mask)

In [4]:
import yaml

# Training configuration for the Cdiag classifier, kept inline as YAML text.
# Notice the paths now point to our Colab workspace in Google Drive
config_yaml = """
# Configuration for training the Cdiag (Pneumonia Classifier) model

# --- Data Paths ---
IMAGE_DIR: "/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/"
TRAIN_CSV_PATH: "/content/drive/MyDrive/CAF-GAN/data/splits/train.csv"
VAL_CSV_PATH: "/content/drive/MyDrive/CAF-GAN/data/splits/val.csv"

# --- Output Paths ---
OUTPUT_DIR: "/content/drive/MyDrive/CAF-GAN/outputs/cdiag/"
MODEL_NAME: "best_cdiag_colab.pth"

# --- Model & Training Hyperparameters ---
IMG_SIZE: 256
BATCH_SIZE: 32  # We can use a larger batch size on a Colab GPU
EPOCHS: 20
LEARNING_RATE: 0.0001

# --- System ---
DEVICE: "cuda"
NUM_WORKERS: 2
"""

# Parse the YAML text into a plain dict; safe_load only builds basic Python types.
# The training cell reads its paths and hyperparameters from CONFIG.
CONFIG = yaml.safe_load(config_yaml)
# Echo the parsed settings so the values used for this run are visible in the output
print("Configuration loaded:")
print(CONFIG)

Configuration loaded:
{'IMAGE_DIR': '/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/', 'TRAIN_CSV_PATH': '/content/drive/MyDrive/CAF-GAN/data/splits/train.csv', 'VAL_CSV_PATH': '/content/drive/MyDrive/CAF-GAN/data/splits/val.csv', 'OUTPUT_DIR': '/content/drive/MyDrive/CAF-GAN/outputs/cdiag/', 'MODEL_NAME': 'best_cdiag_colab.pth', 'IMG_SIZE': 256, 'BATCH_SIZE': 32, 'EPOCHS': 20, 'LEARNING_RATE': 0.0001, 'DEVICE': 'cuda', 'NUM_WORKERS': 2}


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
import pandas as pd
from tqdm import tqdm
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
# The dataset classes are already defined in a previous cell

# --- ⚙️ Setup ---
# The CONFIG dictionary is already loaded from the previous cell
# Honour the device requested in CONFIG (previously the key was loaded but ignored),
# falling back to the CPU when CUDA is requested but not available on this runtime.
_requested_device = str(CONFIG.get('DEVICE', 'cuda'))
if _requested_device.startswith('cuda') and not torch.cuda.is_available():
    DEVICE = 'cpu'
else:
    DEVICE = _requested_device
# Make sure the checkpoint directory exists before training starts
os.makedirs(CONFIG['OUTPUT_DIR'], exist_ok=True)
print(f"Using device: {DEVICE}")

# --- 🏋️‍♀️ Training & Validation Functions ---
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """
    Run one optimisation pass over ``dataloader`` and return the mean training loss.

    Args:
        model: network to train; it is switched to training mode.
        dataloader: iterable yielding ``(images, labels)`` batches of tensors.
        optimizer: optimizer that updates ``model``'s parameters.
        criterion: loss, called as ``criterion(outputs, labels)``.
        device: device every batch is moved to before the forward pass.

    Returns:
        float: loss averaged over the samples that were actually processed.
        This stays correct when the loader yields fewer samples than
        ``len(dataloader.dataset)`` (e.g. ``drop_last=True`` or a sampler).

    Raises:
        ZeroDivisionError: if the dataloader yields no samples.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        batch_size = images.size(0)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Assumes the criterion returns the batch mean (the default reduction of
        # nn.BCEWithLogitsLoss); weight it by the batch size so a smaller last
        # batch does not skew the epoch average.
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
    return running_loss / samples_seen

def validate(model, dataloader, criterion, device):
    """
    Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        model: network to evaluate; it is switched to eval mode.
        dataloader: iterable yielding ``(images, labels)`` batches of tensors.
        criterion: loss, called as ``criterion(outputs, labels)``.
        device: device every batch is moved to before the forward pass.

    Returns:
        tuple: ``(val_loss, val_acc)``, both averaged over the samples that
        were actually processed rather than over ``len(dataloader.dataset)``.
        A prediction is positive when ``sigmoid(output) > 0.5``. Accuracy
        assumes one output per sample.

    Raises:
        ZeroDivisionError: if the dataloader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct_preds = 0
    samples_seen = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            batch_size = images.size(0)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Assumes the criterion returns the batch mean; re-weight by batch size
            running_loss += loss.item() * batch_size
            # Outputs are logits: squash to probabilities, then threshold at 0.5
            preds = torch.sigmoid(outputs) > 0.5
            correct_preds += (preds == labels).sum().item()
            samples_seen += batch_size

    val_loss = running_loss / samples_seen
    val_acc = correct_preds / samples_seen
    return val_loss, val_acc

# --- 🚀 Main Execution ---
def run_training():
    """
    Fine-tune an ImageNet-pretrained ResNet-50 as the Cdiag pneumonia classifier.

    Everything is driven by the module-level ``CONFIG`` dict and ``DEVICE``.
    The train/val CSVs are read and wrapped in ``MIMICCXRClassifierDataset``
    with Albumentations pipelines. The network (single-logit head) is trained
    for ``CONFIG['EPOCHS']`` epochs with Adam and ``BCEWithLogitsLoss``.
    Whenever the validation loss drops below the best value seen so far, the
    model's ``state_dict`` is written to
    ``CONFIG['OUTPUT_DIR']`` / ``CONFIG['MODEL_NAME']``.

    Returns:
        None. Progress and metrics are reported through ``print``.
    """
    side = CONFIG['IMG_SIZE']

    def _normalize_and_tensorize():
        # Fresh transform instances for each pipeline; ImageNet mean/std match the pretrained backbone
        return [
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]

    # --- Augmentation pipelines: random geometry/intensity for training only ---
    train_transform = A.Compose([
        A.Resize(side, side),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=15, p=0.7),
        A.RandomBrightnessContrast(p=0.5),
        *_normalize_and_tensorize(),
    ])
    val_transform = A.Compose([
        A.Resize(side, side),
        *_normalize_and_tensorize(),
    ])

    # --- Datasets and loaders ---
    train_df = pd.read_csv(CONFIG['TRAIN_CSV_PATH'])
    val_df = pd.read_csv(CONFIG['VAL_CSV_PATH'])

    image_root = CONFIG['IMAGE_DIR']
    train_dataset = MIMICCXRClassifierDataset(train_df, image_root, transform=train_transform)
    val_dataset = MIMICCXRClassifierDataset(val_df, image_root, transform=val_transform)

    shared_loader_args = dict(batch_size=CONFIG['BATCH_SIZE'], num_workers=CONFIG['NUM_WORKERS'])
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)

    # --- Network: pretrained backbone with a single-logit head ---
    model = models.resnet50(weights='IMAGENET1K_V1')
    model.fc = nn.Linear(model.fc.in_features, 1)
    model = model.to(DEVICE)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['LEARNING_RATE'])

    # --- Epoch loop with best-on-validation checkpointing ---
    model_path = os.path.join(CONFIG['OUTPUT_DIR'], CONFIG['MODEL_NAME'])
    best_val_loss = float('inf')
    for epoch in range(CONFIG['EPOCHS']):
        print(f"\n--- Epoch {epoch+1}/{CONFIG['EPOCHS']} ---")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        val_loss, val_acc = validate(model, val_loader, criterion, DEVICE)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.4f}")

        # Only a strict improvement in validation loss overwrites the checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), model_path)
            print(f"✨ New best model saved to {model_path}")

    print("\n✅ Training of Cdiag complete!")

# Run the training process. Training starts as soon as this cell executes;
# per-epoch losses and validation accuracy are printed below.
run_training()

Using device: cuda
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 192MB/s]



--- Epoch 1/20 ---


Training: 100%|██████████| 44/44 [15:40<00:00, 21.38s/it]
Validating: 100%|██████████| 10/10 [03:30<00:00, 21.08s/it]


Train Loss: 0.6677 | Val Loss: 0.6163 | Val Accuracy: 0.6600
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cdiag/best_cdiag_colab.pth

--- Epoch 2/20 ---


Training: 100%|██████████| 44/44 [01:35<00:00,  2.18s/it]
Validating: 100%|██████████| 10/10 [00:21<00:00,  2.10s/it]


Train Loss: 0.5986 | Val Loss: 0.6603 | Val Accuracy: 0.6800

--- Epoch 3/20 ---


Training: 100%|██████████| 44/44 [01:38<00:00,  2.24s/it]
Validating: 100%|██████████| 10/10 [00:20<00:00,  2.06s/it]


Train Loss: 0.5575 | Val Loss: 0.6395 | Val Accuracy: 0.6233

--- Epoch 4/20 ---


Training: 100%|██████████| 44/44 [01:38<00:00,  2.24s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.99s/it]


Train Loss: 0.5294 | Val Loss: 0.6861 | Val Accuracy: 0.6667

--- Epoch 5/20 ---


Training: 100%|██████████| 44/44 [01:41<00:00,  2.32s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.99s/it]


Train Loss: 0.4720 | Val Loss: 0.7663 | Val Accuracy: 0.6033

--- Epoch 6/20 ---


Training: 100%|██████████| 44/44 [01:41<00:00,  2.31s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]


Train Loss: 0.4276 | Val Loss: 0.7407 | Val Accuracy: 0.6800

--- Epoch 7/20 ---


Training: 100%|██████████| 44/44 [01:41<00:00,  2.31s/it]
Validating: 100%|██████████| 10/10 [00:20<00:00,  2.00s/it]


Train Loss: 0.3818 | Val Loss: 0.8811 | Val Accuracy: 0.6767

--- Epoch 8/20 ---


Training: 100%|██████████| 44/44 [01:42<00:00,  2.33s/it]
Validating: 100%|██████████| 10/10 [00:20<00:00,  2.01s/it]


Train Loss: 0.3065 | Val Loss: 0.9475 | Val Accuracy: 0.6400

--- Epoch 9/20 ---


Training: 100%|██████████| 44/44 [01:43<00:00,  2.35s/it]
Validating: 100%|██████████| 10/10 [00:20<00:00,  2.03s/it]


Train Loss: 0.2955 | Val Loss: 1.2725 | Val Accuracy: 0.6667

--- Epoch 10/20 ---


Training: 100%|██████████| 44/44 [01:42<00:00,  2.33s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.99s/it]


Train Loss: 0.2717 | Val Loss: 0.8627 | Val Accuracy: 0.6800

--- Epoch 11/20 ---


Training: 100%|██████████| 44/44 [01:42<00:00,  2.32s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  2.00s/it]


Train Loss: 0.2274 | Val Loss: 0.9855 | Val Accuracy: 0.6967

--- Epoch 12/20 ---


Training: 100%|██████████| 44/44 [01:42<00:00,  2.32s/it]
Validating: 100%|██████████| 10/10 [00:20<00:00,  2.01s/it]


Train Loss: 0.2349 | Val Loss: 0.9656 | Val Accuracy: 0.6500

--- Epoch 13/20 ---


Training: 100%|██████████| 44/44 [01:41<00:00,  2.30s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.99s/it]


Train Loss: 0.1880 | Val Loss: 1.0389 | Val Accuracy: 0.6667

--- Epoch 14/20 ---


Training: 100%|██████████| 44/44 [01:40<00:00,  2.27s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]


Train Loss: 0.1483 | Val Loss: 1.1956 | Val Accuracy: 0.6533

--- Epoch 15/20 ---


Training: 100%|██████████| 44/44 [01:42<00:00,  2.33s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.97s/it]


Train Loss: 0.1838 | Val Loss: 1.3567 | Val Accuracy: 0.5900

--- Epoch 16/20 ---


Training: 100%|██████████| 44/44 [01:42<00:00,  2.33s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.98s/it]


Train Loss: 0.1665 | Val Loss: 1.2590 | Val Accuracy: 0.6600

--- Epoch 17/20 ---


Training: 100%|██████████| 44/44 [01:40<00:00,  2.29s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.96s/it]


Train Loss: 0.1205 | Val Loss: 1.0907 | Val Accuracy: 0.6533

--- Epoch 18/20 ---


Training: 100%|██████████| 44/44 [01:40<00:00,  2.29s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.99s/it]


Train Loss: 0.1025 | Val Loss: 1.4623 | Val Accuracy: 0.6367

--- Epoch 19/20 ---


Training: 100%|██████████| 44/44 [01:42<00:00,  2.33s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.99s/it]


Train Loss: 0.1152 | Val Loss: 1.0179 | Val Accuracy: 0.6700

--- Epoch 20/20 ---


Training: 100%|██████████| 44/44 [01:41<00:00,  2.31s/it]
Validating: 100%|██████████| 10/10 [00:19<00:00,  1.98s/it]

Train Loss: 0.1337 | Val Loss: 1.2141 | Val Accuracy: 0.6567

✅ Training of Cdiag complete!



