In [2]:
!pip install torch torchvision albumentations segmentation-models-pytorch opencv-python-headless xmltodict

Collecting albumentations
  Downloading albumentations-2.0.8-py3-none-any.whl.metadata (43 kB)
Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Collecting opencv-python-headless
  Downloading opencv_python_headless-4.12.0.88-cp37-abi3-macosx_13_0_arm64.whl.metadata (19 kB)
Collecting xmltodict
  Downloading xmltodict-1.0.2-py3-none-any.whl.metadata (15 kB)
Collecting albucore==0.0.24 (from albumentations)
  Downloading albucore-0.0.24-py3-none-any.whl.metadata (5.3 kB)
Collecting stringzilla>=3.10.4 (from albucore==0.0.24->albumentations)
  Downloading stringzilla-4.2.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (110 kB)
Collecting simsimd>=5.9.2 (from albucore==0.0.24->albumentations)
  Downloading simsimd-6.5.3-cp312-cp312-macosx_11_0_arm64.whl.metadata (70 kB)
Collecting numpy (from torchvision)
  Downloading numpy-2.2.6-cp312-cp312-macosx_14_0_arm64.whl.metadata (62 kB)
Downloading albumentations-2.0.8-py3-none

In [19]:
!pip install hf_xet

Collecting hf_xet
  Downloading hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl.metadata (4.9 kB)
Downloading hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl (2.7 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m48.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: hf_xet
Successfully installed hf_xet-1.2.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
# Multi-Organ Nuclei Segmentation & Classification
# CS GY 6643 - Project 2

# Install required packages (run once)
# !pip install torch torchvision albumentations segmentation-models-pytorch opencv-python-headless xmltodict

import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from pathlib import Path
from tqdm import tqdm
import warnings
import xmltodict
from PIL import Image
warnings.filterwarnings('ignore')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm
  from scipy import special


Using device: cpu


In [6]:
# ============================================
# 1. DATA LOADING AND RLE UTILITIES
# ============================================

def rle_decode_instance_mask(rle: str, shape: tuple) -> np.ndarray:
    """
    Convert RLE triple string back into an instance mask of shape (H, W).
    """
    if not rle or str(rle).strip() in ("", "0", "nan"):
        return np.zeros(shape, dtype=np.uint16)
    s = list(map(int, rle.split()))
    mask = np.zeros(shape[0]*shape[1], dtype=np.uint16)
    for i in range(0, len(s), 3):
        val, start, length = s[i], s[i+1], s[i+2]
        mask[start-1:start-1+length] = val
    return mask.reshape(shape, order="F")

def rle_encode_instance_mask(mask: np.ndarray) -> str:
    """
    Convert an instance segmentation mask (H,W) -> RLE triple string.
    0 = background, >0 = instance IDs.
    """
    pixels = mask.flatten(order="F").astype(np.int32)
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1

    rle = []
    for i in range(0, len(runs)-1):
        start = runs[i]
        end = runs[i+1] if i+1 < len(runs) else len(pixels)-1
        length = end - start
        val = pixels[start]
        if val > 0:
            rle.extend([val, start, length])

    if not rle:
        return "0"
    return " ".join(map(str, rle))

def load_xml_annotations(xml_path):
    """
    Load instance-level annotations from XML file.
    Returns dict: {class_name: list of polygon coordinates}
    """
    with open(xml_path, 'r') as f:
        data = xmltodict.parse(f.read())
    
    annotations = {}
    regions_data = data.get('Annotations', {}).get('Annotation', [])
    
    if not isinstance(regions_data, list):
        regions_data = [regions_data]
    
    for annotation in regions_data:
        class_name = annotation.get('@Name', 'Unknown')
        regions = annotation.get('Regions', {}).get('Region', [])
        
        if not isinstance(regions, list):
            regions = [regions]
        
        polygons = []
        for region in regions:
            vertices = region.get('Vertices', {}).get('Vertex', [])
            if not isinstance(vertices, list):
                vertices = [vertices]
            
            points = [(float(v['@X']), float(v['@Y'])) for v in vertices]
            if len(points) >= 3:
                polygons.append(np.array(points, dtype=np.int32))
        
        if polygons:
            annotations[class_name] = polygons
    
    return annotations

def create_instance_masks_from_xml(xml_path, shape):
    """
    Create instance masks for all 4 classes from XML annotations.
    Returns: dict with keys ['Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage']
    Each value is a 2D array where each unique instance has a unique ID.
    """
    annotations = load_xml_annotations(xml_path)
    
    masks = {
        'Epithelial': np.zeros(shape, dtype=np.uint16),
        'Lymphocyte': np.zeros(shape, dtype=np.uint16),
        'Neutrophil': np.zeros(shape, dtype=np.uint16),
        'Macrophage': np.zeros(shape, dtype=np.uint16)
    }
    
    for class_name in masks.keys():
        if class_name in annotations:
            instance_id = 1
            for polygon in annotations[class_name]:
                cv2.fillPoly(masks[class_name], [polygon], instance_id)
                instance_id += 1
    
    return masks

# Load training data
train_df = pd.read_csv('kaggle-data/train_ground_truth.csv')
print(f"Training samples: {len(train_df)}")
print(f"Columns: {train_df.columns.tolist()}")
print(train_df.head())

Training samples: 209
Columns: ['image_id', 'Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage']
  image_id                                         Epithelial  \
0   slide1                                                  0   
1   slide2  191 3596 1 191 4378 14 191 5162 15 191 5947 16...   
2   slide3                                                  0   
3   slide4  17 606 1 17 1114 9 61 1425 21 17 1624 13 61 19...   
4   slide5  1 106286 7 1 106708 9 1 107131 11 1 107553 13 ...   

                                          Lymphocyte  \
0                                                  0   
1  6 15974 1 1 16131 2 6 16755 8 1 16911 8 6 1753...   
2                                                  0   
3                                                  0   
4  70 1748 2 69 2128 3 70 2168 8 69 2550 7 70 259...   

                                          Neutrophil  \
0                                                  0   
1                                                  0   
2  2

In [24]:
# ============================================
# 2. DATASET CLASS
# ============================================

class NucleiDataset(Dataset):
    def __init__(self, df, img_dir, mode='train', transform=None, use_xml=False):
        self.df = df
        self.img_dir = Path(img_dir)
        self.mode = mode
        self.transform = transform
        self.use_xml = use_xml
        self.classes = ['Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage']
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_id = row['image_id']
        
        # Load image
        img_path = self.img_dir / f"{image_id}.tif"
        image = cv2.imread(str(img_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]
        
        if self.mode == 'test':
            if self.transform:
                transformed = self.transform(image=image)
                image = transformed['image']
            return {'image': image, 'image_id': image_id, 'shape': (h, w)}
        
        # Load masks
        if self.use_xml:
            xml_path = self.img_dir / f"{image_id}.xml"
            if xml_path.exists():
                instance_masks = create_instance_masks_from_xml(str(xml_path), (h, w))
            else:
                instance_masks = {cls: np.zeros((h, w), dtype=np.uint16) for cls in self.classes}
        else:
            instance_masks = {}
            for cls in self.classes:
                rle = row[cls]
                instance_masks[cls] = rle_decode_instance_mask(rle, (h, w))
        
        # Convert to semantic segmentation (class per pixel)
        semantic_mask = np.zeros((h, w), dtype=np.uint8)
        for class_idx, cls in enumerate(self.classes, start=1):
            semantic_mask[instance_masks[cls] > 0] = class_idx
        
        if self.transform:
            transformed = self.transform(image=image, mask=semantic_mask)
            image = transformed['image']
            semantic_mask = transformed['mask']
        
        return {
            'image': image,
            'mask': semantic_mask.long(),
            'image_id': image_id,
            'instance_masks': instance_masks
        }

# Define augmentations
def get_train_transforms(img_size=512):
    return A.Compose([
        A.Resize(height=img_size, width=img_size),  # First resize to ensure consistent size
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.OneOf([
            A.GaussNoise(var_limit=(10.0, 50.0)),
            A.GaussianBlur(),
            A.MotionBlur(),
        ], p=0.3),
        A.OneOf([
            A.OpticalDistortion(distort_limit=0.1),
            A.GridDistortion(num_steps=5, distort_limit=0.3),
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50),
        ], p=0.3),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.3),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def get_val_transforms(img_size=512):
    return A.Compose([
        A.Resize(height=img_size, width=img_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

In [25]:
# ============================================
# 3. MODEL ARCHITECTURE
# ============================================

class NucleiSegmentationModel(nn.Module):
    """
    U-Net with ResNet encoder for semantic segmentation.
    Outputs: 5 classes (background + 4 cell types)
    """
    def __init__(self, encoder_name='resnet50', num_classes=5, pretrained=True):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights='imagenet' if pretrained else None,
            in_channels=3,
            classes=num_classes,
        )
    
    def forward(self, x):
        return self.model(x)

# Loss function that handles class imbalance
class WeightedCombinedLoss(nn.Module):
    def __init__(self, weights=None):
        super().__init__()
        self.weights = weights
        self.ce_loss = nn.CrossEntropyLoss(weight=weights)
    
    def forward(self, pred, target):
        ce = self.ce_loss(pred, target)
        
        # Dice loss for better segmentation
        pred_softmax = F.softmax(pred, dim=1)
        dice = 0
        num_classes = pred.shape[1]
        
        for c in range(num_classes):
            pred_c = pred_softmax[:, c]
            target_c = (target == c).float()
            intersection = (pred_c * target_c).sum()
            dice += (2. * intersection + 1e-5) / (pred_c.sum() + target_c.sum() + 1e-5)
        
        dice = 1 - dice / num_classes
        return ce + dice

# Initialize model
def create_model(num_classes=5, device='cuda'):
    model = NucleiSegmentationModel(
        encoder_name='resnet50',
        num_classes=num_classes,
        pretrained=True
    )
    model = model.to(device)
    
    # Class weights to handle imbalance (emphasize rare classes)
    class_weights = torch.tensor([1.0, 1.0, 1.0, 5.0, 5.0]).to(device)
    criterion = WeightedCombinedLoss(weights=class_weights)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
    
    return model, criterion, optimizer, scheduler

In [26]:
# ============================================
# 4. TRAINING LOOP
# ============================================

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    pbar = tqdm(loader, desc='Training')
    for batch in pbar:
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': loss.item()})
    
    return total_loss / len(loader)

def validate_epoch(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(loader, desc='Validating'):
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            total_loss += loss.item()
    
    return total_loss / len(loader)

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, 
                num_epochs=30, device='cuda'):
    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': []}
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss = validate_epoch(model, val_loader, criterion, device)
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        
        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            print("✓ Model saved!")
        
        scheduler.step()
    
    return history

# Create data loaders
IMG_SIZE = 512
BATCH_SIZE = 8

train_dataset = NucleiDataset(
    train_df, 
    'kaggle-data/train', 
    mode='train',
    transform=get_train_transforms(IMG_SIZE),
    use_xml=True  # Set to True if you want to use XML files
)

# Split for validation (80-20 split)
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = torch.utils.data.random_split(
    train_dataset, [train_size, val_size]
)

# Set num_workers=0 for Jupyter notebooks to avoid multiprocessing issues
# If running as a .py script, you can increase num_workers for faster loading
train_loader = DataLoader(
    train_subset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=0,  # Changed from 4 to 0 for Jupyter compatibility
    pin_memory=True if torch.cuda.is_available() else False
)

val_loader = DataLoader(
    val_subset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=0,  # Changed from 4 to 0 for Jupyter compatibility
    pin_memory=True if torch.cuda.is_available() else False
)

# Create model and train
model, criterion, optimizer, scheduler = create_model(num_classes=5, device=device)

# Train the model
history = train_model(
    model, train_loader, val_loader, 
    criterion, optimizer, scheduler,
    num_epochs=30, 
    device=device
)


Epoch 1/30


Training:   0%|          | 0/21 [00:00<?, ?it/s]


RuntimeError: stack expects each tensor to be equal size, but got [597, 657] at entry 0 and [398, 265] at entry 1