# Capstone Project: Polyp Segmentation (Kaggle Version)

## Semi-Supervised Learning with Swin-UNet

This notebook is adapted for the Kaggle environment. It implements a Semi-Supervised Learning (SSL) approach using a Swin-UNet architecture.

**Key Features:**
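- **Data Source**: Expects the CVC-ClinicDB-612 dataset (a folder containing `images/` and `Ground Truth/` sub-folders) under `/kaggle/input`. It can be already extracted (e.g. `dataset1/CVC-ClinicDB-612`) or supplied as a zip archive (e.g. `dataset1.zip`), which is extracted automatically.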
- **Model**: Swin-UNet (Swin Transformer Encoder + U-Net Decoder).
- **Training**: Iterative Pseudo-Labeling on unlabeled data.
- **Outputs**: Saves models and plots to `/kaggle/working/capstone_output/<TRAINING_SIZE>/`.

In [None]:
!pip install -q transformers

In [None]:
# --- Import Libraries ---
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import PIL.Image
import cv2
import glob
import time
import shutil
import zipfile

# --- PyTorch Imports ---
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset

# --- Torchvision Imports ---
import torchvision.transforms as T
from torchvision.transforms.functional import to_pil_image, to_tensor, resize

# --- Other Imports ---
from sklearn.model_selection import train_test_split
from transformers import SwinModel

# --- Setup Seed for Reproducibility ---
# Seed Python's `random`, NumPy's global RNG and PyTorch with one shared value;
# on GPU machines also seed CUDA and pin cuDNN to deterministic kernels.
SEED = 42

for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    # Disable cuDNN autotuning and request deterministic algorithms so that
    # repeated runs produce the same numbers (at some speed cost).
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
# --- 1. Kaggle Environment Setup & Data Handling ---
# Locate the dataset root: a directory holding both an 'images' and a
# 'Ground Truth' sub-folder. Search order:
#   1) the known extracted locations in `possible_paths`,
#   2) any matching directory anywhere under /kaggle/input,
#   3) a zip under /kaggle/input (preferring PREFERRED_ZIP_NAME), extracted
#      into TEMP_DATA_DIR,
#   4) a hard-coded best-guess path, so DATA_ROOT is never None.

# Define paths
KAGGLE_INPUT_DIR = '/kaggle/input'
KAGGLE_WORKING_DIR = '/kaggle/working'
TEMP_DATA_DIR = '/kaggle/temp/dataset'  # Temporary directory for extraction if needed

DATASET_SUBDIRS = ('images', 'Ground Truth')
PREFERRED_ZIP_NAME = 'dataset1.zip'
FALLBACK_DATA_ROOT = "/kaggle/input/dataset1/CVC-ClinicDB-612"  # Best guess fallback

# Likely locations of an already-extracted dataset, checked first and in order.
possible_paths = [
    os.path.join(KAGGLE_INPUT_DIR, 'dataset1', 'CVC-ClinicDB-612'),
    os.path.join(KAGGLE_INPUT_DIR, 'dataset1'),
    os.path.join(KAGGLE_INPUT_DIR, 'tanishqgupta3142/dataset1'),
    # Add more likely paths if known
]


def _is_dataset_root(path):
    """Return True if *path* contains every folder named in DATASET_SUBDIRS."""
    return all(os.path.isdir(os.path.join(path, sub)) for sub in DATASET_SUBDIRS)


def _find_dataset_root_under(top):
    """Return the first directory at or below *top* that is a dataset root, or None."""
    for root, dirs, _ in os.walk(top):
        if all(sub in dirs for sub in DATASET_SUBDIRS):
            return root
    return None


def _find_zip(top):
    """Return a .zip path under *top*, preferring PREFERRED_ZIP_NAME; None if none exist."""
    first_zip = None
    for root, _, files in os.walk(top):
        for name in sorted(files):
            if not name.lower().endswith('.zip'):
                continue
            path = os.path.join(root, name)
            if name == PREFERRED_ZIP_NAME:
                return path
            if first_zip is None:
                first_zip = path
    return first_zip


# 1) Known extracted locations.
dataset_root = next((p for p in possible_paths if _is_dataset_root(p)), None)
if dataset_root is not None:
    print(f"Found extracted dataset at: {dataset_root}")

# 2) An extracted copy elsewhere under the input directory (Kaggle often
#    auto-extracts uploaded archives into an arbitrary sub-folder).
if dataset_root is None:
    print("Dataset not immediately found in expected extraction paths. Searching input directory...")
    dataset_root = _find_dataset_root_under(KAGGLE_INPUT_DIR)
    if dataset_root is not None:
        print(f"Found extracted dataset at: {dataset_root}")

# 3) A zip archive, extracted to scratch space.
if dataset_root is None:
    zip_file = _find_zip(KAGGLE_INPUT_DIR)
    if zip_file:
        print(f"Found zip file: {zip_file}. Extracting to {TEMP_DATA_DIR}...")
        os.makedirs(TEMP_DATA_DIR, exist_ok=True)
        with zipfile.ZipFile(zip_file, 'r') as zf:
            zf.extractall(TEMP_DATA_DIR)
        dataset_root = _find_dataset_root_under(TEMP_DATA_DIR)
        if dataset_root is not None:
            print(f"Dataset extracted to: {dataset_root}")
        else:
            print(f"WARNING: {zip_file} contains no folder with both 'images' and 'Ground Truth'.")

# 4) Last resort, so that the os.path.join calls below never receive None.
if dataset_root is None:
    print(f"WARNING: Could not find dataset zip or folders. Falling back to {FALLBACK_DATA_ROOT}; "
          "fix the path manually if this is wrong.")
    dataset_root = FALLBACK_DATA_ROOT

DATA_ROOT = dataset_root
IMAGE_DIR = os.path.join(DATA_ROOT, "images")
MASK_DIR = os.path.join(DATA_ROOT, "Ground Truth")

# --- 2. Output Configuration ---
# Define output paths for saving models and plots
OUTPUT_ROOT = os.path.join(KAGGLE_WORKING_DIR, 'capstone_output')
os.makedirs(OUTPUT_ROOT, exist_ok=True)

print(f"Data Source: {DATA_ROOT}")
print(f"Output Directory: {OUTPUT_ROOT}")

In [None]:
# --- 3. Core Experiment Parameters ---
# TRAINING_SIZE < 490 is reported as the semi-supervised scenario below;
# anything else is reported as fully supervised.
TRAINING_SIZE = 25          # Options: 25, 50, 75, 100 (SSL) or 490 (Fully Supervised)
NUM_EPOCHS = 20
START_EPOCH = 0             # Non-zero requests a resume from 'latest_model.pth' (see the training-loop cell)
ITERATIONS_PER_EPOCH = 20   # Number of training batches drawn per epoch
PSEUDO_LABEL_INTERVAL = 5   # Intended pseudo-labelling interval (SSL); NOTE(review): not read by the training loop as written
CONFIDENCE_THRESHOLD = 0.79 # A pseudo-label is accepted when its confidence exceeds this (SSL)

BATCH_SIZE = 16
LEARNING_RATE = 1e-4
IMAGE_SIZE = 224            # Images and masks are resized to IMAGE_SIZE x IMAGE_SIZE
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Specific Run Paths: one output folder per TRAINING_SIZE.
RUN_PATH = os.path.join(OUTPUT_ROOT, f'{TRAINING_SIZE}')
PLOT_PATH = os.path.join(RUN_PATH, 'plots')
LATEST_MODEL_PATH = os.path.join(RUN_PATH, 'latest_model.pth')
BEST_MODEL_PATH = os.path.join(RUN_PATH, 'best_model.pth')

for _run_dir in (RUN_PATH, PLOT_PATH):
    os.makedirs(_run_dir, exist_ok=True)
del _run_dir

_scenario = 'Semi-Supervised' if TRAINING_SIZE < 490 else 'Fully Supervised'
print(f"--- Experiment Configuration ---")
print(f"Training Scenario: {_scenario}")
print(f"Training Size: {TRAINING_SIZE}")
print(f"Device: {DEVICE}")
print(f"Run Path: {RUN_PATH}")

In [None]:
# --- 4. Model Architecture (Swin-UNet) ---

class SwinUNet(nn.Module):
    """U-Net style segmentation network with a pre-trained Swin-B encoder.

    Encoder: ``microsoft/swin-base-patch4-window7-224`` loaded with
    ``output_hidden_states=True``. For a 224x224 input, its feature maps at
    56x56 (128 ch), 28x28 (256 ch), 14x14 (512 ch) and 7x7 (1024 ch) feed three
    transposed-conv + conv decoder stages with skip connections. The result is
    upsampled 4x and projected to ``num_classes`` logit channels.

    Args:
        num_classes: number of output channels (1 for binary segmentation).
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # Encoder: Pre-trained Swin Transformer
        self.swin = SwinModel.from_pretrained(
            "microsoft/swin-base-patch4-window7-224",
            output_hidden_states=True,
        )

        # Freeze the patch embeddings and the first two Swin stages; only the
        # deeper stages and the decoder are fine-tuned.
        for name, param in self.swin.named_parameters():
            if "layers.0" in name or "layers.1" in name or "embed" in name:
                param.requires_grad = False

        # --- Decoder ---
        # Each stage: 2x transposed-conv upsample, concat with the encoder skip
        # (doubling the channels), then a 3x3 conv back down.
        self.decoder4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2, padding=0)
        self.conv4 = nn.Sequential(nn.Conv2d(1024, 512, 3, padding=1), nn.ReLU())

        self.decoder3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0)
        self.conv3 = nn.Sequential(nn.Conv2d(512, 256, 3, padding=1), nn.ReLU())

        self.decoder2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Sequential(nn.Conv2d(256, 128, 3, padding=1), nn.ReLU())

        # Final upsampling layers (56x56 -> 224x224 for a 224x224 input)
        self.final_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.final_conv = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

    def forward(self, x):
        """Return raw logits of shape (B, num_classes, 224, 224) for a (B, 3, 224, 224) input."""
        # reshaped_hidden_states carries the same tensors as hidden_states but
        # already laid out as (B, C, H, W), so no square-token-grid reshaping is
        # needed here. With HF Swin's default hidden-state layout, index 0 is the
        # patch-embedding output and index i >= 1 is the output of stage i after
        # its patch-merging step; index 4 is the final stage's own output.
        feats = self.swin(x).reshaped_hidden_states

        s1 = feats[0]  # (B, 128, 56, 56)
        s2 = feats[1]  # (B, 256, 28, 28)
        s3 = feats[2]  # (B, 512, 14, 14)
        # Bottleneck: the output of the last Swin stage. feats[3] has the same
        # shape but is only that stage's *input*; using it would leave the final
        # stage's computation (and its weights) unused.
        s4 = feats[4]  # (B, 1024, 7, 7)

        # --- Decoder Path with Skip Connections ---
        # 7x7 -> 14x14, fuse with s3
        d4 = self.conv4(torch.cat([self.decoder4(s4), s3], dim=1))
        # 14x14 -> 28x28, fuse with s2
        d3 = self.conv3(torch.cat([self.decoder3(d4), s2], dim=1))
        # 28x28 -> 56x56, fuse with s1
        d2 = self.conv2(torch.cat([self.decoder2(d3), s1], dim=1))

        # Final 4x upsampling back to the input resolution, then per-pixel logits.
        return self.final_conv(self.final_upsample(d2))

In [None]:
# --- 5. Loss & Metrics ---

class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities, pooled over the whole batch.

    All elements of ``logits`` and ``targets`` are flattened into single vectors,
    so this is one batch-level Dice score rather than a mean of per-image scores.
    The returned loss is ``1 - dice``.

    Args:
        smooth: constant added to numerator and denominator so the ratio stays
            finite when both sums are zero.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        # reshape (not view) so non-contiguous inputs (e.g. transposed or
        # permuted tensors) are accepted as well.
        probs_flat = torch.sigmoid(logits).reshape(-1)
        targets_flat = targets.reshape(-1)
        intersection = (probs_flat * targets_flat).sum()
        dice = (2. * intersection + self.smooth) / (probs_flat.sum() + targets_flat.sum() + self.smooth)
        return 1 - dice

def combined_loss(logits, targets):
    """Return mean binary cross-entropy (computed from logits) plus soft Dice loss."""
    bce_term = F.binary_cross_entropy_with_logits(logits, targets)
    dice_term = DiceLoss()(logits, targets)
    return bce_term + dice_term

def calculate_metrics(logits, targets, threshold=0.5, epsilon=1e-6):
    """Return ``(iou, dice)`` of thresholded predictions as Python floats.

    Pixels with ``sigmoid(logits) > threshold`` are predicted foreground and
    compared with ``targets``. Every element of the batch is pooled into one
    vector, so the scores are batch-level rather than per-image averages.
    When prediction and target are both entirely empty, the prediction is
    exactly right and ``(1.0, 1.0)`` is returned instead of ``(0.0, 0.0)``.

    Args:
        logits: raw model outputs.
        targets: binary ground truth with the same number of elements.
        threshold: probability above which a pixel counts as foreground.
        epsilon: added to denominators to avoid division by zero.
    """
    # reshape (not view) also accepts non-contiguous tensors.
    preds_flat = (torch.sigmoid(logits) > threshold).float().reshape(-1)
    targets_flat = targets.float().reshape(-1)

    pred_sum = preds_flat.sum()
    target_sum = targets_flat.sum()
    if pred_sum.item() == 0 and target_sum.item() == 0:
        return 1.0, 1.0

    intersection = (preds_flat * targets_flat).sum()
    dice_score = (2. * intersection) / (pred_sum + target_sum + epsilon)

    union = pred_sum + target_sum - intersection
    iou_score = intersection / (union + epsilon)

    return iou_score.item(), dice_score.item()

def get_confidence_score(probs, pred_mask, epsilon=1e-6):
    """Return the mean probability over the pixels selected by ``pred_mask``.

    An empty mask yields 0.0 (the numerator is zero and epsilon keeps the
    denominator non-zero), so empty predictions are never considered confident.
    """
    foreground_prob_total = (probs * pred_mask).sum()
    foreground_pixel_count = pred_mask.sum() + epsilon
    return (foreground_prob_total / foreground_pixel_count).item()

In [None]:
# --- 6. Dataset Class ---

class PolypDataset(Dataset):
    """Image/mask pairs for binary polyp segmentation.

    Each item is ``(image, mask, filename)``:
      * ``image`` is resized to IMAGE_SIZE x IMAGE_SIZE, converted to a tensor
        and normalised with the ImageNet mean/std;
      * ``mask`` is resized to the same size and binarised at 0.5;
      * if ``filename`` is a key of ``pseudo_masks``, that array is used as the
        mask instead of the file in ``mask_dir``.

    Args:
        image_filenames: file names, looked up in both image_dir and mask_dir.
        image_dir: directory with the RGB images.
        mask_dir: directory with ground-truth masks stored under the same names.
        transform: accepted for API compatibility but NOT applied by __getitem__.
        pseudo_masks: optional dict mapping filename -> mask array (presumably
            0/1 values; it is scaled by 255). Stored by reference, so entries the
            caller adds later are visible.
    """

    def __init__(self, image_filenames, image_dir, mask_dir, transform=None, pseudo_masks=None):
        self.image_filenames = image_filenames
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform  # NOTE: currently unused
        self.pseudo_masks = pseudo_masks if pseudo_masks is not None else {}

        self.image_transform = T.Compose([
            T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        # Masks are only resized here; binarisation happens in __getitem__.
        self.mask_transform = T.Compose([
            T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_filenames)

    def _load_image(self, filename):
        """Read *filename* from image_dir as an RGB PIL image and close the file."""
        # convert() returns a fully loaded, independent image, so the file
        # handle can be released as soon as the with-block exits.
        with PIL.Image.open(os.path.join(self.image_dir, filename)) as img:
            return img.convert("RGB")

    def _load_mask(self, filename):
        """Return the mask for *filename* as a PIL image (pseudo-label first, else ground truth)."""
        if filename in self.pseudo_masks:
            mask_np = self.pseudo_masks[filename]
            return PIL.Image.fromarray((mask_np * 255).astype(np.uint8))
        with PIL.Image.open(os.path.join(self.mask_dir, filename)) as mask:
            return mask.convert("L")

    def __getitem__(self, idx):
        filename = self.image_filenames[idx]

        image = self.image_transform(self._load_image(filename))
        mask = self.mask_transform(self._load_mask(filename))
        mask = (mask > 0.5).float()

        return image, mask, filename

In [None]:
# --- 7. Data Splitting & Learners ---
# Split the images into a fixed validation pool (15%) and a training pool.
# Semi-supervised runs label only the first TRAINING_SIZE training images
# (the pool is already shuffled by train_test_split); the rest are unlabelled.

all_filenames = sorted(f for f in os.listdir(IMAGE_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg')))
if not all_filenames:
    raise FileNotFoundError(f"No .png/.jpg/.jpeg images found in {IMAGE_DIR}; check DATA_ROOT.")
print(f"Found {len(all_filenames)} total images")

# Split main pool and fixed validation pool
train_pool_files, val_pool_files = train_test_split(
    all_filenames,
    test_size=0.15,
    random_state=SEED
)

print(f"Training Pool: {len(train_pool_files)}")
print(f"Validation Pool: {len(val_pool_files)}")

val_dataset = PolypDataset(
    image_filenames=val_pool_files,
    image_dir=IMAGE_DIR,
    mask_dir=MASK_DIR
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2
)

# Define Pools.
# Fully supervised whenever TRAINING_SIZE >= 490, so this agrees with the
# `TRAINING_SIZE < 490` semi-supervised checks in the config and training cells.
if TRAINING_SIZE >= 490:  # Fully supervised
    labeled_pool_files = train_pool_files
    unlabeled_pool_files = []
    train_loader = DataLoader(
        PolypDataset(labeled_pool_files, IMAGE_DIR, MASK_DIR),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2
    )
else:  # Semi-Supervised
    labeled_pool_files = train_pool_files[:TRAINING_SIZE]
    unlabeled_pool_files = train_pool_files[TRAINING_SIZE:]
    # Train loader will be re-created in the loop
    print(f"Initial Labeled: {len(labeled_pool_files)}, Unlabeled: {len(unlabeled_pool_files)}")

# filename -> binary mask array for accepted pseudo-labels (filled during training).
pseudo_masks = {}

In [None]:
# --- 8. Visualization Utilities ---

def _save_train_val_curve(train_values, val_values, quantity, current_epoch, file_path):
    """Plot one training-vs-validation curve for *quantity* and save it to *file_path*."""
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        for values, style, split in ((train_values, 'b-o', 'Training'),
                                     (val_values, 'r-o', 'Validation')):
            # x-axis derived from the series itself: the last value belongs to
            # current_epoch, so lengths always match (even when training starts
            # at epoch 0 or resumes mid-way).
            epochs = range(current_epoch - len(values) + 1, current_epoch + 1)
            ax.plot(epochs, values, style, label=f'{split} {quantity}')
        ax.set_title(f'Training vs. Validation {quantity} (Epoch {current_epoch})')
        ax.set_xlabel('Epochs')
        ax.set_ylabel(quantity)
        ax.legend()
        ax.grid(True)
        fig.savefig(file_path)
    finally:
        plt.close(fig)


def save_loss_metric_plots(train_losses, val_losses, train_metrics, val_metrics, current_epoch, save_path):
    """Save loss and mIoU curves (training vs. validation) as two PNGs in *save_path*.

    Each series holds one value per epoch, the last one for *current_epoch*.
    Writes ``loss_curve_epoch_{current_epoch}.png`` and
    ``miou_curve_epoch_{current_epoch}.png``.
    """
    _save_train_val_curve(train_losses, val_losses, 'Loss', current_epoch,
                          os.path.join(save_path, f'loss_curve_epoch_{current_epoch}.png'))
    _save_train_val_curve(train_metrics, val_metrics, 'mIoU', current_epoch,
                          os.path.join(save_path, f'miou_curve_epoch_{current_epoch}.png'))

def save_data_growth_plot(pool_sizes, current_epoch, save_path):
    """Save a bar chart of the labelled-pool size per epoch.

    ``pool_sizes`` holds one entry per epoch, the last one for *current_epoch*.
    Does nothing when it is empty. Writes
    ``data_growth_epoch_{current_epoch}.png`` into *save_path*.
    """
    if not pool_sizes:
        return

    # Derive the x-axis from the data so bar positions and heights always have
    # the same length, whatever epoch numbering the caller started from.
    epochs = range(current_epoch - len(pool_sizes) + 1, current_epoch + 1)
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.bar(epochs, pool_sizes, color='green')
        ax.set_title(f'Labeled Pool Size Growth (Epoch {current_epoch})')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Number of Labeled Samples')
        ax.set_xticks(epochs)
        ax.grid(axis='y', linestyle='--')
        fig.savefig(os.path.join(save_path, f'data_growth_epoch_{current_epoch}.png'))
    finally:
        plt.close(fig)

def save_new_labels_plot(new_labels, current_epoch, save_path):
    """Save a bar chart of the pseudo-labels added in each epoch.

    ``new_labels`` holds one count per epoch, the last one for *current_epoch*.
    Does nothing when it is empty. Writes
    ``new_labels_per_epoch_{current_epoch}.png`` into *save_path*.
    """
    if not new_labels:
        return

    # Derive the x-axis from the data so bar positions and heights always have
    # the same length, whatever epoch numbering the caller started from.
    epochs = range(current_epoch - len(new_labels) + 1, current_epoch + 1)
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.bar(epochs, new_labels, color='dodgerblue')
        ax.set_title(f'New Pseudo-Labels Added Per Epoch (Epoch {current_epoch})')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Number of New Samples Added')
        ax.set_xticks(epochs)
        ax.grid(axis='y', linestyle='--')
        fig.savefig(os.path.join(save_path, f'new_labels_per_epoch_{current_epoch}.png'))
    finally:
        plt.close(fig)

def save_sample_predictions(model, loader, device, epoch, save_path, num_samples=3):
    """Save Image / Ground Truth / Prediction panels for the first batch of *loader*.

    Up to ``num_samples`` figures are written to *save_path* as
    ``epoch_{epoch}_sample_{i}.png``. Images are un-normalised with the ImageNet
    mean/std before display, and predictions are ``sigmoid(logits) > 0.5``.

    Plotting is best-effort: any exception is printed rather than raised. The
    model's train/eval mode on entry is restored on exit.
    """
    was_training = model.training
    model.eval()
    try:
        images, masks, _ = next(iter(loader))
        with torch.no_grad():
            probs = torch.sigmoid(model(images.to(device)))
        preds = (probs > 0.5).float().cpu()
        images = images.cpu()
        masks = masks.cpu()

        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        for i in range(min(num_samples, len(images))):
            panels = (
                ("Image", to_pil_image(((images[i] * std) + mean).clamp(0, 1)), None),
                ("Ground Truth", to_pil_image(masks[i].squeeze(0)), 'gray'),
                ("Prediction", to_pil_image(preds[i].squeeze(0)), 'gray'),
            )
            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            try:
                for ax, (title, picture, cmap) in zip(axes, panels):
                    ax.imshow(picture, cmap=cmap)
                    ax.set_title(title)
                    ax.axis('off')
                fig.suptitle(f'Epoch {epoch} Sample {i+1}')
                fig.savefig(os.path.join(save_path, f'epoch_{epoch}_sample_{i+1}.png'))
            finally:
                # Close even if drawing/saving fails so figures don't accumulate.
                plt.close(fig)
    except Exception as e:
        # Best-effort visualisation: a plotting failure must not stop training.
        print(f"Error plotting samples: {e}")
    finally:
        model.train(was_training)

## Training Loop

In [None]:
# --- 9. Training Loop ---
# For each epoch:
#   1. train for ITERATIONS_PER_EPOCH batches on labelled (+ pseudo-labelled) data;
#   2. validate;
#   3. in semi-supervised runs, add confident pseudo-labels;
#   4. checkpoint and write plots.
# Epochs are numbered 1..NUM_EPOCHS, so after epoch k every history list holds
# exactly k entries (matching 1-based epoch labels in the plots).
# START_EPOCH > 0 resumes from LATEST_MODEL_PATH.

model = SwinUNet(num_classes=1).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = combined_loss

# Same scenario test as the configuration cell.
is_semi_supervised = TRAINING_SIZE < 490

train_losses = []
val_losses = []
train_mious = []
val_mious = []
labeled_pool_sizes = []
new_labels_per_epoch = []

best_miou = 0.0


def _build_train_loader():
    """Return (file list, shuffled DataLoader) over labelled + pseudo-labelled images."""
    files = labeled_pool_files + list(pseudo_masks.keys())
    dataset = PolypDataset(files, IMAGE_DIR, MASK_DIR, pseudo_masks=pseudo_masks)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    return files, loader


def _train_one_epoch(loader):
    """Run ITERATIONS_PER_EPOCH optimiser steps, restarting *loader* when exhausted.

    Returns (mean loss, mean batch IoU) over those steps.
    """
    model.train()
    loss_sum = 0.0
    iou_sum = 0.0
    batches = iter(loader)
    for _ in tqdm(range(ITERATIONS_PER_EPOCH), desc="Training"):
        try:
            images, masks, _ = next(batches)
        except StopIteration:
            # Fewer batches than ITERATIONS_PER_EPOCH: start another pass.
            batches = iter(loader)
            images, masks, _ = next(batches)
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        iou_batch, _ = calculate_metrics(logits.detach(), masks)
        loss_sum += loss.item()
        iou_sum += iou_batch
    return loss_sum / ITERATIONS_PER_EPOCH, iou_sum / ITERATIONS_PER_EPOCH


def _validate():
    """Return (mean loss, mean batch IoU) over val_loader (unweighted mean of per-batch values)."""
    model.eval()
    loss_sum = 0.0
    iou_sum = 0.0
    with torch.no_grad():
        for images, masks, _ in val_loader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            logits = model(images)
            loss_sum += criterion(logits, masks).item()
            iou, _ = calculate_metrics(logits, masks)
            iou_sum += iou
    num_batches = len(val_loader)
    return loss_sum / num_batches, iou_sum / num_batches


def _pseudo_label_unlabeled_pool():
    """Move confidently predicted images from unlabeled_pool_files into pseudo_masks.

    An image is accepted when get_confidence_score (mean probability inside its
    predicted foreground) exceeds CONFIDENCE_THRESHOLD. Its binary mask is stored
    in pseudo_masks, and the file is removed from unlabeled_pool_files.
    Returns the number of newly accepted images.
    """
    print("Generating pseudo-labels...")
    if not unlabeled_pool_files:
        return 0
    model.eval()
    # NOTE: PolypDataset also reads the ground-truth mask for these files; it is ignored here.
    loader = DataLoader(
        PolypDataset(unlabeled_pool_files, IMAGE_DIR, MASK_DIR),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )
    accepted = []
    with torch.no_grad():
        for images, _, filenames in loader:
            probs = torch.sigmoid(model(images.to(DEVICE))).squeeze(1)  # (B, H, W)
            preds = (probs > 0.5).float()
            for fname, prob_map, pred_mask in zip(filenames, probs, preds):
                if get_confidence_score(prob_map, pred_mask) > CONFIDENCE_THRESHOLD:
                    pseudo_masks[fname] = pred_mask.cpu().numpy()
                    accepted.append(fname)
    # One O(n) pass with a set instead of repeated list.remove(); updated in
    # place so every reference to this list sees the new pool.
    accepted_set = set(accepted)
    unlabeled_pool_files[:] = [f for f in unlabeled_pool_files if f not in accepted_set]
    print(f"Added {len(accepted)} new pseudo-labels. Total Pseudo: {len(pseudo_masks)}")
    return len(accepted)


def _save_latest_checkpoint(epoch):
    """Write model/optimiser state, metric history and SSL state to LATEST_MODEL_PATH."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_mious': train_mious,
        'val_mious': val_mious,
        # Extra state needed to resume semi-supervised training faithfully.
        'labeled_pool_sizes': labeled_pool_sizes,
        'new_labels_per_epoch': new_labels_per_epoch,
        'best_miou': best_miou,
        # Masks are binary; storing them as bool keeps the file small.
        'pseudo_masks': {k: v.astype(bool) for k, v in pseudo_masks.items()},
        'unlabeled_pool_files': unlabeled_pool_files,
    }, LATEST_MODEL_PATH)


def _resume_from_latest_checkpoint():
    """Restore state from LATEST_MODEL_PATH; return (completed epochs, best val mIoU).

    Raises FileNotFoundError if there is no checkpoint to resume from.
    """
    if not os.path.exists(LATEST_MODEL_PATH):
        raise FileNotFoundError(f"START_EPOCH={START_EPOCH} but no checkpoint found at {LATEST_MODEL_PATH}")
    # weights_only=False because this notebook's checkpoint also stores Python
    # lists and NumPy arrays. Only load checkpoints you wrote yourself.
    ckpt = torch.load(LATEST_MODEL_PATH, map_location=DEVICE, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    train_losses[:] = ckpt['train_losses']
    val_losses[:] = ckpt['val_losses']
    train_mious[:] = ckpt['train_mious']
    val_mious[:] = ckpt['val_mious']
    # The history length is the source of truth for how many epochs are done.
    completed = len(train_losses)
    # Older checkpoints lack SSL history; zero-fill so every series still has
    # one entry per completed epoch.
    labeled_pool_sizes[:] = ckpt.get('labeled_pool_sizes', [0] * completed)
    new_labels_per_epoch[:] = ckpt.get('new_labels_per_epoch', [0] * completed)
    if 'pseudo_masks' in ckpt:
        pseudo_masks.clear()
        pseudo_masks.update({k: v.astype(np.float32) for k, v in ckpt['pseudo_masks'].items()})
        unlabeled_pool_files[:] = ckpt['unlabeled_pool_files']
    else:
        print("WARNING: checkpoint has no pseudo-label state; pseudo-labelling restarts from scratch.")
    best = ckpt.get('best_miou', max(val_mious, default=0.0))
    return completed, best


first_epoch = 1
if START_EPOCH > 0:
    completed_epochs, best_miou = _resume_from_latest_checkpoint()
    first_epoch = completed_epochs + 1
    print(f"Resumed from {LATEST_MODEL_PATH}: {completed_epochs} epoch(s) already completed")

for epoch in range(first_epoch, NUM_EPOCHS + 1):
    print(f"\n=== Epoch {epoch}/{NUM_EPOCHS} ===")

    # --- 1. Prepare Training Data ---
    current_labeled_files, train_loader = _build_train_loader()

    # --- 2. Train ---
    avg_train_loss, avg_train_iou = _train_one_epoch(train_loader)
    train_losses.append(avg_train_loss)
    train_mious.append(avg_train_iou)
    labeled_pool_sizes.append(len(current_labeled_files))

    # --- 3. Validate ---
    avg_val_loss, avg_val_miou = _validate()
    val_losses.append(avg_val_loss)
    val_mious.append(avg_val_miou)

    print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val mIoU: {avg_val_miou:.4f}")

    # --- 4. Pseudo-Labeling (Semi-Supervised Only) ---
    # Runs after every epoch. NOTE(review): PSEUDO_LABEL_INTERVAL from the config
    # cell is not applied here; confirm whether it should gate this step.
    new_pseudos_this_epoch = _pseudo_label_unlabeled_pool() if is_semi_supervised else 0
    new_labels_per_epoch.append(new_pseudos_this_epoch)

    # --- 5. Save Checkpoints & Plots ---
    # Update the best model first so the latest checkpoint records the current best_miou.
    if avg_val_miou > best_miou:
        best_miou = avg_val_miou
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f"New Best Model Saved! (mIoU: {best_miou:.4f})")
    _save_latest_checkpoint(epoch)

    # Call Visualization Functions
    save_sample_predictions(model, val_loader, DEVICE, epoch, PLOT_PATH)
    save_loss_metric_plots(train_losses, val_losses, train_mious, val_mious, epoch, PLOT_PATH)
    if is_semi_supervised:
        save_data_growth_plot(labeled_pool_sizes, epoch, PLOT_PATH)
        save_new_labels_plot(new_labels_per_epoch, epoch, PLOT_PATH)

    print(f"Model and Plots saved to {RUN_PATH}")