In [11]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [12]:
import os

# Navigate to your project directory in Google Drive
%cd /content/drive/MyDrive/CAF-GAN/

# Unzip the image files (this might take a few minutes)
# The -q makes the output quiet, -n prevents unzipping if already done
!unzip -q -n mimic-cxr-jpg-2.0.0.zip

print("✅ Workspace ready and images unzipped.")

/content/drive/MyDrive/CAF-GAN
✅ Workspace ready and images unzipped.


In [13]:
!pip install pyyaml pandas scikit-learn albumentations segmentation-models-pytorch -q

In [27]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image

# This single dataset file will serve both critic training scripts.

class MIMICCXRClassifierDataset(Dataset):
    """
    Dataset for the Cdiag (classification) task.
    - Loads a JPG image.
    - Converts it to RGB (as required by ResNet).
    - Returns the image and its corresponding Pneumonia label.
    """
    def __init__(self, df, image_dir, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']

        # Construct the path to the JPG image
        image_path = os.path.join(
            self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
        )

        # Load image and convert to a numpy array in RGB format
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)

        # Apply augmentations
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        # Get the label
        label = torch.tensor(row['Pneumonia'], dtype=torch.float32)

        return image, label.unsqueeze(0)


class MIMICXRSegmentationDataset(Dataset):
    """
    Dataset for the Cseg (segmentation) task.
    - Loads a JPG image (as 3-channel RGB). <--- UPDATED
    - Loads its corresponding pre-generated PNG mask.
    - Returns both the image and the mask.
    """
    def __init__(self, df, image_dir, mask_dir, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']

        image_path = os.path.join(
            self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
        )
        mask_path = os.path.join(self.mask_dir, f"{dicom_id}.png")

        # --- KEY CHANGE ---
        # Load image and convert to RGB to match the model's expected input channels.
        image = np.array(Image.open(image_path).convert("RGB"), dtype=np.float32)

        # Load mask as grayscale
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)

        # RESIZE THE IMAGE to match the mask size (256×256) BEFORE augmentation
        image = np.array(Image.fromarray(image.astype(np.uint8)).resize((256, 256), Image.BILINEAR))

        # Normalize mask values from [0, 255] to [0.0, 1.0]
        mask[mask == 255.0] = 1.0

        # Apply augmentations (Albumentations will now see matching input sizes from the Resize transform)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Add a channel dimension for the mask for consistency
        return image, mask.unsqueeze(0)

In [28]:
import yaml

# Notice the paths now point to our Colab workspace in Google Drive
config_yaml = """
# Configuration for training the Cseg (Lung Segmenter) model

# --- Data Paths ---
IMAGE_DIR: "/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/"
MASK_DIR: "/content/drive/MyDrive/CAF-GAN/data/masks/" # <-- Path to the masks we generated
TRAIN_CSV_PATH: "/content/drive/MyDrive/CAF-GAN/data/splits/train.csv"
VAL_CSV_PATH: "/content/drive/MyDrive/CAF-GAN/data/splits/val.csv"

# --- Output Paths ---
OUTPUT_DIR: "/content/drive/MyDrive/CAF-GAN/outputs/cseg/" # <-- Updated output directory
MODEL_NAME: "best_cseg_colab.pth" # <-- Updated model name

# --- Model & Training Hyperparameters ---
IMG_SIZE: 256
BATCH_SIZE: 16 # Segmentation is more memory intensive, so we use a smaller batch size
EPOCHS: 25     # Segmentation often benefits from more epochs
LEARNING_RATE: 0.0001

# --- System ---
DEVICE: "cuda"
NUM_WORKERS: 2
"""

CONFIG = yaml.safe_load(config_yaml)
print("Configuration loaded for Cseg training:")
print(CONFIG)

Configuration loaded for Cseg training:
{'IMAGE_DIR': '/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/', 'MASK_DIR': '/content/drive/MyDrive/CAF-GAN/data/masks/', 'TRAIN_CSV_PATH': '/content/drive/MyDrive/CAF-GAN/data/splits/train.csv', 'VAL_CSV_PATH': '/content/drive/MyDrive/CAF-GAN/data/splits/val.csv', 'OUTPUT_DIR': '/content/drive/MyDrive/CAF-GAN/outputs/cseg/', 'MODEL_NAME': 'best_cseg_colab.pth', 'IMG_SIZE': 256, 'BATCH_SIZE': 16, 'EPOCHS': 25, 'LEARNING_RATE': 0.0001, 'DEVICE': 'cuda', 'NUM_WORKERS': 2}


In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from tqdm import tqdm
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
import segmentation_models_pytorch as smp

# --- ⚙️ Setup ---
# The CONFIG dictionary is already loaded from the previous cell
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
os.makedirs(CONFIG['OUTPUT_DIR'], exist_ok=True)
print(f"Using device: {DEVICE}")

# --- 🏋️‍♀️ Loss, Metrics, and Training Functions ---
# (DiceBCELoss, dice_score, train_one_epoch, validate functions remain the same as before)
class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()
    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        BCE = nn.BCELoss()(inputs, targets)
        return BCE + dice_loss

def dice_score(preds, targets, smooth=1e-6):
    preds = torch.sigmoid(preds) > 0.5
    preds = preds.float().view(-1)
    targets = targets.view(-1)
    intersection = (preds * targets).sum()
    return (2. * intersection + smooth) / (preds.sum() + targets.sum() + smooth)

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(dataloader, desc="Training"):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    return running_loss / len(dataloader.dataset)

def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    total_dice_score = 0.0
    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Validating"):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            running_loss += loss.item() * images.size(0)
            total_dice_score += dice_score(outputs, masks).item()
    val_loss = running_loss / len(dataloader.dataset)
    val_dice = total_dice_score / len(dataloader)
    return val_loss, val_dice

# --- 🚀 Main Execution ---
def run_training():
    # --- Data Loading & Augmentation ---
    # --- KEY CHANGE ---
    # The transform now resizes both image and mask to the same size FIRST.
    # It also uses ImageNet normalization for the 3-channel input image.
    transform = A.Compose([
        A.Resize(CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE']), # This fixes the size mismatch!
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet stats for RGB
        ToTensorV2(),
    ])

    train_df = pd.read_csv(CONFIG['TRAIN_CSV_PATH'])
    val_df = pd.read_csv(CONFIG['VAL_CSV_PATH'])

    train_dataset = MIMICXRSegmentationDataset(train_df, CONFIG['IMAGE_DIR'], CONFIG['MASK_DIR'], transform=transform)
    val_dataset = MIMICXRSegmentationDataset(val_df, CONFIG['IMAGE_DIR'], CONFIG['MASK_DIR'], transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=CONFIG['NUM_WORKERS'])
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=CONFIG['NUM_WORKERS'])

    # --- Model Setup ---
    # --- KEY CHANGE ---
    # The model MUST be initialized with in_channels=3 to match the data.
    model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", in_channels=3, classes=1)
    model.to(DEVICE)

    criterion = DiceBCELoss()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['LEARNING_RATE'])

    best_val_dice = 0.0

    # --- Training Loop ---
    for epoch in range(CONFIG['EPOCHS']):
        print(f"\n--- Epoch {epoch+1}/{CONFIG['EPOCHS']} ---")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        val_loss, val_dice = validate(model, val_loader, criterion, DEVICE)

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice Score: {val_dice:.4f}")

        if val_dice > best_val_dice:
            best_val_dice = val_dice
            model_path = os.path.join(CONFIG['OUTPUT_DIR'], CONFIG['MODEL_NAME'])
            torch.save(model.state_dict(), model_path)
            print(f"✨ New best model saved to {model_path} (Dice: {val_dice:.4f})")

    print("\n✅ Training of Cseg complete!")

# Run the training process
run_training()

Using device: cuda

--- Epoch 1/25 ---


Training: 100%|██████████| 88/88 [06:22<00:00,  4.34s/it]
Validating: 100%|██████████| 19/19 [01:27<00:00,  4.58s/it]


Train Loss: 1.4151 | Val Loss: 1.1972 | Val Dice Score: 0.4877
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.4877)

--- Epoch 2/25 ---


Training: 100%|██████████| 88/88 [02:50<00:00,  1.93s/it]
Validating: 100%|██████████| 19/19 [00:37<00:00,  1.97s/it]


Train Loss: 1.1507 | Val Loss: 1.0915 | Val Dice Score: 0.5214
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5214)

--- Epoch 3/25 ---


Training: 100%|██████████| 88/88 [02:57<00:00,  2.01s/it]
Validating: 100%|██████████| 19/19 [00:36<00:00,  1.91s/it]


Train Loss: 1.0399 | Val Loss: 1.0155 | Val Dice Score: 0.5247
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5247)

--- Epoch 4/25 ---


Training: 100%|██████████| 88/88 [02:54<00:00,  1.98s/it]
Validating: 100%|██████████| 19/19 [00:35<00:00,  1.86s/it]


Train Loss: 0.9856 | Val Loss: 0.9722 | Val Dice Score: 0.5318
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5318)

--- Epoch 5/25 ---


Training: 100%|██████████| 88/88 [02:56<00:00,  2.01s/it]
Validating: 100%|██████████| 19/19 [00:34<00:00,  1.84s/it]


Train Loss: 0.9490 | Val Loss: 0.9547 | Val Dice Score: 0.5441
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5441)

--- Epoch 6/25 ---


Training: 100%|██████████| 88/88 [02:57<00:00,  2.02s/it]
Validating: 100%|██████████| 19/19 [00:34<00:00,  1.83s/it]


Train Loss: 0.9045 | Val Loss: 0.9156 | Val Dice Score: 0.5539
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5539)

--- Epoch 7/25 ---


Training: 100%|██████████| 88/88 [02:54<00:00,  1.99s/it]
Validating: 100%|██████████| 19/19 [00:38<00:00,  2.02s/it]


Train Loss: 0.8708 | Val Loss: 0.8884 | Val Dice Score: 0.5533

--- Epoch 8/25 ---


Training: 100%|██████████| 88/88 [02:51<00:00,  1.95s/it]
Validating: 100%|██████████| 19/19 [00:34<00:00,  1.84s/it]


Train Loss: 0.8389 | Val Loss: 0.9051 | Val Dice Score: 0.5518

--- Epoch 9/25 ---


Training: 100%|██████████| 88/88 [02:50<00:00,  1.94s/it]
Validating: 100%|██████████| 19/19 [00:38<00:00,  2.01s/it]


Train Loss: 0.8210 | Val Loss: 0.8682 | Val Dice Score: 0.5572
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5572)

--- Epoch 10/25 ---


Training: 100%|██████████| 88/88 [02:52<00:00,  1.96s/it]
Validating: 100%|██████████| 19/19 [00:35<00:00,  1.89s/it]


Train Loss: 0.7962 | Val Loss: 0.8806 | Val Dice Score: 0.5533

--- Epoch 11/25 ---


Training: 100%|██████████| 88/88 [02:54<00:00,  1.98s/it]
Validating: 100%|██████████| 19/19 [00:35<00:00,  1.88s/it]


Train Loss: 0.7727 | Val Loss: 0.8492 | Val Dice Score: 0.5485

--- Epoch 12/25 ---


Training: 100%|██████████| 88/88 [02:57<00:00,  2.02s/it]
Validating: 100%|██████████| 19/19 [00:35<00:00,  1.89s/it]


Train Loss: 0.7587 | Val Loss: 0.8884 | Val Dice Score: 0.5626
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5626)

--- Epoch 13/25 ---


Training: 100%|██████████| 88/88 [02:58<00:00,  2.02s/it]
Validating: 100%|██████████| 19/19 [00:37<00:00,  1.97s/it]


Train Loss: 0.7447 | Val Loss: 0.8400 | Val Dice Score: 0.5624

--- Epoch 14/25 ---


Training: 100%|██████████| 88/88 [02:55<00:00,  1.99s/it]
Validating: 100%|██████████| 19/19 [00:36<00:00,  1.92s/it]


Train Loss: 0.7357 | Val Loss: 0.8359 | Val Dice Score: 0.5653
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5653)

--- Epoch 15/25 ---


Training: 100%|██████████| 88/88 [02:55<00:00,  1.99s/it]
Validating: 100%|██████████| 19/19 [00:34<00:00,  1.82s/it]


Train Loss: 0.7148 | Val Loss: 0.8330 | Val Dice Score: 0.5731
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5731)

--- Epoch 16/25 ---


Training: 100%|██████████| 88/88 [02:57<00:00,  2.01s/it]
Validating: 100%|██████████| 19/19 [00:35<00:00,  1.84s/it]


Train Loss: 0.7067 | Val Loss: 0.8356 | Val Dice Score: 0.5560

--- Epoch 17/25 ---


Training: 100%|██████████| 88/88 [02:50<00:00,  1.94s/it]
Validating: 100%|██████████| 19/19 [00:34<00:00,  1.83s/it]


Train Loss: 0.6907 | Val Loss: 0.8467 | Val Dice Score: 0.5715

--- Epoch 18/25 ---


Training: 100%|██████████| 88/88 [02:52<00:00,  1.95s/it]
Validating: 100%|██████████| 19/19 [00:38<00:00,  2.03s/it]


Train Loss: 0.6832 | Val Loss: 0.8392 | Val Dice Score: 0.5487

--- Epoch 19/25 ---


Training: 100%|██████████| 88/88 [02:51<00:00,  1.95s/it]
Validating: 100%|██████████| 19/19 [00:34<00:00,  1.82s/it]


Train Loss: 0.6717 | Val Loss: 0.8282 | Val Dice Score: 0.5662

--- Epoch 20/25 ---


Training: 100%|██████████| 88/88 [02:53<00:00,  1.98s/it]
Validating: 100%|██████████| 19/19 [00:38<00:00,  2.03s/it]


Train Loss: 0.6619 | Val Loss: 0.8302 | Val Dice Score: 0.5755
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5755)

--- Epoch 21/25 ---


Training: 100%|██████████| 88/88 [02:53<00:00,  1.97s/it]
Validating: 100%|██████████| 19/19 [00:34<00:00,  1.84s/it]


Train Loss: 0.6536 | Val Loss: 0.8249 | Val Dice Score: 0.5765
✨ New best model saved to /content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth (Dice: 0.5765)

--- Epoch 22/25 ---


Training: 100%|██████████| 88/88 [02:55<00:00,  1.99s/it]
Validating: 100%|██████████| 19/19 [00:37<00:00,  1.97s/it]


Train Loss: 0.6430 | Val Loss: 0.8377 | Val Dice Score: 0.5668

--- Epoch 23/25 ---


Training: 100%|██████████| 88/88 [02:50<00:00,  1.94s/it]
Validating: 100%|██████████| 19/19 [00:36<00:00,  1.93s/it]


Train Loss: 0.6359 | Val Loss: 0.8312 | Val Dice Score: 0.5656

--- Epoch 24/25 ---


Training: 100%|██████████| 88/88 [02:47<00:00,  1.90s/it]
Validating: 100%|██████████| 19/19 [00:36<00:00,  1.92s/it]


Train Loss: 0.6295 | Val Loss: 0.8266 | Val Dice Score: 0.5761

--- Epoch 25/25 ---


Training: 100%|██████████| 88/88 [02:50<00:00,  1.94s/it]
Validating: 100%|██████████| 19/19 [00:34<00:00,  1.80s/it]

Train Loss: 0.6226 | Val Loss: 0.8218 | Val Dice Score: 0.5686

✅ Training of Cseg complete!



