In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install pyyaml pandas scikit-learn albumentations segmentation-models-pytorch -q

[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m154.8/154.8 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
def save_checkpoint(model, optimizer, epoch, best_val_dice, path):
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "best_val_dice": best_val_dice
    }
    torch.save(checkpoint, path)
    print(f"üíæ Checkpoint saved at {path} (Epoch {epoch+1}, Dice {best_val_dice:.4f})")

def load_checkpoint(model, optimizer, path, device):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    start_epoch = checkpoint["epoch"] + 1  # resume from next epoch
    best_val_dice = checkpoint["best_val_dice"]
    print(f"‚úÖ Resumed from checkpoint at epoch {start_epoch} with best Dice {best_val_dice:.4f}")
    return model, optimizer, start_epoch, best_val_dice

In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
import segmentation_models_pytorch as smp

# --- 1. Configuration ---
CONFIG = {
    "IMAGE_DIR": "/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/",
    "MASK_DIR": "/content/drive/MyDrive/CAF-GAN/data/masks_512x512/",
    "TRAIN_CSV_PATH": "/content/drive/MyDrive/CAF-GAN/data/splits/train.csv",
    "VAL_CSV_PATH": "/content/drive/MyDrive/CAF-GAN/data/splits/val.csv",
    "OUTPUT_DIR": "/content/drive/MyDrive/CAF-GAN/outputs/cseg_512/",
    "MODEL_NAME": "best_cseg_512.pth",
    "IMG_SIZE": 512,
    "BATCH_SIZE": 8,
    "EPOCHS": 25,
    "LEARNING_RATE": 0.0001,
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "NUM_WORKERS": 2
}
print("‚úÖ Configuration loaded for 512x512 Cseg training.")
print(CONFIG)

# --- 2. PyTorch Dataset (MODIFIED) ---
class MIMICXRSegmentationDataset(Dataset):
    # We pass img_size to the constructor
    def __init__(self, df, image_dir, mask_dir, img_size, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.img_size = img_size # <-- Store image size
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id, study_id, dicom_id = str(row['subject_id']), str(row['study_id']), row['dicom_id']

        image_path = os.path.join(self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg')
        mask_path = os.path.join(self.mask_dir, f"{dicom_id}.png")

        # Load high-resolution image
        image = Image.open(image_path).convert("RGB")

        # --- KEY CHANGE: Resize the image to match the mask size BEFORE augmentation ---
        # Use high-quality LANCZOS resampling
        image = image.resize((self.img_size, self.img_size), Image.LANCZOS)
        image = np.array(image)
        # ---------------------------------------------------------------------------------

        # Load the 512x512 mask
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)

        # Normalize mask to 0.0-1.0 range
        mask[mask == 255.0] = 1.0

        # Now, both image and mask are 512x512, so the transform will work
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask.unsqueeze(0)

# --- 3. Loss, Metrics & Training Functions (Unchanged) ---
class DiceBCELoss(nn.Module):
    def __init__(self): super().__init__()
    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)
        inputs_flat, targets_flat = inputs.view(-1), targets.view(-1)
        intersection = (inputs_flat * targets_flat).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs_flat.sum() + targets_flat.sum() + smooth)
        bce_loss = nn.BCEWithLogitsLoss()(inputs, targets)
        return bce_loss + dice_loss

def dice_score(preds, targets, smooth=1e-6):
    preds = torch.sigmoid(preds) > 0.5
    preds_flat, targets_flat = preds.float().view(-1), targets.view(-1)
    intersection = (preds_flat * targets_flat).sum()
    return (2. * intersection + smooth) / (preds_flat.sum() + targets_flat.sum() + smooth)

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(dataloader, desc="Training"):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    return running_loss / len(dataloader.dataset)

def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss, total_dice = 0.0, 0.0
    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Validating"):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            running_loss += loss.item() * images.size(0)
            total_dice += dice_score(outputs, masks).item()
    return running_loss / len(dataloader.dataset), total_dice / len(dataloader)

# --- 4. Main Training Execution ---
def run_training():
    DEVICE = CONFIG['DEVICE']
    os.makedirs(CONFIG['OUTPUT_DIR'], exist_ok=True)

    # The Albumentations transform no longer needs a Resize at the start,
    # but it is harmless to keep it as a safeguard.
    transform = A.Compose([
        A.Resize(CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE']),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=15, p=0.7),
        A.RandomBrightnessContrast(p=0.3),
        A.GaussNoise(p=0.2),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    train_df = pd.read_csv(CONFIG['TRAIN_CSV_PATH'])
    val_df = pd.read_csv(CONFIG['VAL_CSV_PATH'])

    # --- KEY CHANGE: Pass IMG_SIZE to the Dataset constructor ---
    train_dataset = MIMICXRSegmentationDataset(train_df, CONFIG['IMAGE_DIR'], CONFIG['MASK_DIR'], CONFIG['IMG_SIZE'], transform)
    val_dataset = MIMICXRSegmentationDataset(val_df, CONFIG['IMAGE_DIR'], CONFIG['MASK_DIR'], CONFIG['IMG_SIZE'], transform)

    train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=CONFIG['NUM_WORKERS'])
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=CONFIG['NUM_WORKERS'])

    model = smp.Unet("resnet34", encoder_weights="imagenet", in_channels=3, classes=1).to(DEVICE)
    criterion = DiceBCELoss()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['LEARNING_RATE'])

    # üîπ Check if checkpoint exists
    checkpoint_path = os.path.join(CONFIG['OUTPUT_DIR'], "checkpoint.pth")
    if os.path.exists(checkpoint_path):
        model, optimizer, start_epoch, best_val_dice = load_checkpoint(model, optimizer, checkpoint_path, DEVICE)
    else:
        start_epoch, best_val_dice = 0, 0.0
        print("üöÄ Starting training from scratch")

    print("\nüèãÔ∏è‚Äç‚ôÄÔ∏è Training Cseg model with 512x512 images...")
    for epoch in range(start_epoch, CONFIG['EPOCHS']):
        print(f"\n--- Epoch {epoch+1}/{CONFIG['EPOCHS']} ---")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        val_loss, val_dice = validate(model, val_loader, criterion, DEVICE)
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice Score: {val_dice:.4f}")

        # Save best model
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            model_path = os.path.join(CONFIG['OUTPUT_DIR'], CONFIG['MODEL_NAME'])
            torch.save(model.state_dict(), model_path)
            print(f"‚ú® Best model saved to {model_path} (Dice: {val_dice:.4f})")

        # üîπ Always save checkpoint (so you can resume tomorrow)
        save_checkpoint(model, optimizer, epoch, best_val_dice, checkpoint_path)

    print("\n‚úÖ Training complete!")

run_training()

‚úÖ Configuration loaded for 512x512 Cseg training.
{'IMAGE_DIR': '/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/', 'MASK_DIR': '/content/drive/MyDrive/CAF-GAN/data/masks_512x512/', 'TRAIN_CSV_PATH': '/content/drive/MyDrive/CAF-GAN/data/splits/train.csv', 'VAL_CSV_PATH': '/content/drive/MyDrive/CAF-GAN/data/splits/val.csv', 'OUTPUT_DIR': '/content/drive/MyDrive/CAF-GAN/outputs/cseg_512/', 'MODEL_NAME': 'best_cseg_512.pth', 'IMG_SIZE': 512, 'BATCH_SIZE': 8, 'EPOCHS': 25, 'LEARNING_RATE': 0.0001, 'DEVICE': 'cuda', 'NUM_WORKERS': 2}


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

‚úÖ Resumed from checkpoint at epoch 22 with best Dice 0.9525

üèãÔ∏è‚Äç‚ôÄÔ∏è Training Cseg model with 512x512 images...

--- Epoch 23/25 ---


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 175/175 [26:31<00:00,  9.09s/it]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 38/38 [05:36<00:00,  8.84s/it]


Train Loss: 0.6631 | Val Loss: 0.6670 | Val Dice Score: 0.9508
üíæ Checkpoint saved at /content/drive/MyDrive/CAF-GAN/outputs/cseg_512/checkpoint.pth (Epoch 23, Dice 0.9525)

--- Epoch 24/25 ---


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 175/175 [03:42<00:00,  1.27s/it]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 38/38 [00:42<00:00,  1.11s/it]


Train Loss: 0.6612 | Val Loss: 0.6681 | Val Dice Score: 0.9502
üíæ Checkpoint saved at /content/drive/MyDrive/CAF-GAN/outputs/cseg_512/checkpoint.pth (Epoch 24, Dice 0.9525)

--- Epoch 25/25 ---


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 175/175 [03:41<00:00,  1.26s/it]
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 38/38 [00:42<00:00,  1.12s/it]


Train Loss: 0.6618 | Val Loss: 0.6674 | Val Dice Score: 0.9501
üíæ Checkpoint saved at /content/drive/MyDrive/CAF-GAN/outputs/cseg_512/checkpoint.pth (Epoch 25, Dice 0.9525)

‚úÖ Training complete!
