In [1]:
import numpy as np
import pandas as pd
from typing import Tuple, Dict, Any
import matplotlib.pyplot as plt
import os
from PIL import Image

# --- RLE Encoding and Decoding Functions (Provided in the Prompt) ---

def rle_decode_instance_mask(rle: str, shape: Tuple[int, int]) -> np.ndarray:
    """
    Convert RLE triple string back into an instance mask of shape (H, W).
    """
    if not rle or str(rle).strip() in ("", "0", "nan"):
        return np.zeros(shape, dtype=np.uint16)
    
    # Handle potential float/NaN from pandas read
    if isinstance(rle, float) and np.isnan(rle):
        return np.zeros(shape, dtype=np.uint16)
    
    s = list(map(int, str(rle).split()))
    mask = np.zeros(shape[0]*shape[1], dtype=np.uint16)
    
    # s is a list of [val, start, length, val, start, length, ...]
    for i in range(0, len(s), 3):
        val, start, length = s[i], s[i+1], s[i+2]
        # RLE uses 1-based indexing
        mask[start-1:start-1+length] = val
        
    # Reshape back to (H, W) using 'F' order (column-major/Fortran order)
    return mask.reshape(shape, order="F")

def rle_encode_instance_mask(mask: np.ndarray) -> str:
    """
    Convert an instance segmentation mask (H,W) -> RLE triple string.
    0 = background, >0 = instance IDs.
    """
    # Flatten using Fortran order (column-major)
    pixels = mask.flatten(order="F").astype(np.int32)
    # Pad with 0s for run-length detection
    pixels = np.concatenate([[0], pixels, [0]])
    # Find the indices where the value changes
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1

    rle = []
    # Iterate through runs to extract the [value, start, length] triples
    for i in range(0, len(runs)-1):
        start = runs[i]
        end = runs[i+1]
        length = end - start
        val = pixels[start]
        # Only encode instances (val > 0)
        if val > 0:
            rle.extend([val, start, length])

    if not rle:
        return "0" # Return "0" if no instances are found

    return " ".join(map(str, rle))

In [2]:
# --- Configuration ---
DATA_DIR = "kaggle-data"
TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR = os.path.join(DATA_DIR, "test_final")
ANNOTATION_FILE = os.path.join(DATA_DIR, "train_ground_truth.csv")
TARGET_CLASSES = ['Epithelial', 'Lymphocyte', 'Macrophage', 'Neutrophil']

# --- Main Data Loading Function ---

def load_and_decode_masks(annotation_file: str, train_dir: str) -> pd.DataFrame:
    """
    Loads the ground truth CSV and adds image dimensions to each row.
    """
    # Read the annotation file
    df = pd.read_csv(annotation_file)
    
    # 1. Determine image shapes
    print("Determining image dimensions...")
    shapes = {}
    for image_id in df['image_id']:
        image_path = os.path.join(train_dir, f"{image_id}.tif")
        try:
            # Use PIL to quickly read the TIF metadata (size)
            with Image.open(image_path) as img:
                # Shape is (Height, Width)
                shapes[image_id] = (img.height, img.width)
        except FileNotFoundError:
            print(f"Warning: Image file not found for {image_id}")
            shapes[image_id] = (0, 0) # Placeholder for missing file

    # 2. Add H and W columns to the dataframe
    df['Height'] = df['image_id'].map(lambda x: shapes.get(x, (0, 0))[0])
    df['Width'] = df['image_id'].map(lambda x: shapes.get(x, (0, 0))[1])
    
    print(f"Loaded annotations for {len(df)} images.")
    return df

# --- Execution ---
# Load the dataframe with image sizes
train_df = load_and_decode_masks(ANNOTATION_FILE, TRAIN_DIR)
print("\nFirst 3 rows of the processed DataFrame:")
print(train_df[['image_id', 'Height', 'Width', 'Epithelial']].head(3))

# --- Example of Decoding a Single Image's Masks ---

EXAMPLE_ID = train_df['image_id'].iloc[0]
example_row = train_df[train_df['image_id'] == EXAMPLE_ID].iloc[0]
H, W = example_row['Height'], example_row['Width']

print(f"\nExample Image: {EXAMPLE_ID} (Shape: {H}x{W})")

# Dictionary to hold the instance mask for each class
decoded_masks: Dict[str, np.ndarray] = {}

for class_name in TARGET_CLASSES:
    rle_string = example_row[class_name]
    # Decode the RLE string into a 2D instance mask
    mask = rle_decode_instance_mask(rle_string, (H, W))
    decoded_masks[class_name] = mask
    
    num_instances = np.unique(mask[mask > 0]).size
    print(f"  - {class_name}: {num_instances} instances found.")

# --- Visualization (Optional but highly recommended) ---
def visualize_masks(image_id: str, image_dir: str, masks: Dict[str, np.ndarray]):
    """Visualizes the original image and the decoded instance masks."""
    
    image_path = os.path.join(image_dir, f"{image_id}.tif")
    try:
        image = np.array(Image.open(image_path))
    except FileNotFoundError:
        print(f"Could not load image {image_id} for visualization.")
        return

    fig, axes = plt.subplots(1, 5, figsize=(20, 5))
    
    # 1. Original Image
    axes[0].imshow(image)
    axes[0].set_title(f'Original H&E ({image_id})')
    axes[0].axis('off')

    # Assign a unique color to each class for the overlay
    class_colors = {
        'Epithelial': 'red',
        'Lymphocyte': 'yellow',
        'Macrophage': 'cyan',
        'Neutrophil': 'lime'
    }
    
    # 2. Combined Overlay (all classes)
    overlay = np.zeros(masks['Epithelial'].shape + (3,), dtype=np.uint8)
    for i, class_name in enumerate(TARGET_CLASSES):
        mask = masks[class_name]
        # Create a single color layer for this class
        color_rgb = np.array(plt.colormaps.get_cmap('hsv')(i / len(TARGET_CLASSES))[:3]) * 255
        # Find all instance pixels
        class_pixels = mask > 0
        # Add color to the overlay
        overlay[class_pixels] = color_rgb.astype(np.uint8)

    # Blend the image and the mask
    # For visualization, we use a simple transparency overlay
    blended = image.copy()
    alpha = 0.5
    blended[overlay > 0] = (blended[overlay > 0] * (1 - alpha) + overlay[overlay > 0] * alpha).astype(np.uint8)
    
    axes[1].imshow(blended)
    axes[1].set_title('Combined Instance Overlay')
    axes[1].axis('off')

    # 3-6. Individual Class Masks
    for i, class_name in enumerate(TARGET_CLASSES):
        ax = axes[i + 1]
        mask = masks[class_name]
        # Show mask boundaries or color-coded instances
        ax.imshow(mask > 0, cmap='gray') 
        ax.set_title(f'{class_name} ({np.unique(mask[mask>0]).size} Inst.)', color=class_colors.get(class_name, 'white'))
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Run the visualization for the example image
# visualize_masks(EXAMPLE_ID, TRAIN_DIR, decoded_masks) 
# Uncomment the line above to run the visualization in a Jupyter environment

Determining image dimensions...
Loaded annotations for 209 images.

First 3 rows of the processed DataFrame:
  image_id  Height  Width                                         Epithelial
0   slide1     297    204                                                  0
1   slide2     783    915  191 3596 1 191 4378 14 191 5162 15 191 5947 16...
2   slide3     231    275                                                  0

Example Image: slide1 (Shape: 297x204)
  - Epithelial: 0 instances found.
  - Lymphocyte: 0 instances found.
  - Macrophage: 5 instances found.
  - Neutrophil: 0 instances found.


In [3]:
from scipy.ndimage import distance_transform_edt
from skimage.segmentation import watershed

def generate_target_maps(instance_masks: Dict[str, np.ndarray], shape: Tuple[int, int]) -> np.ndarray:
    """
    Generates a multi-channel target map for a U-Net model.
    Channels: [Semantic_Nuclei_Mask, Dist_Map_Centers, Class_Map]
    """
    H, W = shape
    
    # 1. Combined Instance Mask (for distance calculation)
    # Combine all nuclei into one mask, maintaining unique instance IDs
    combined_instance_mask = np.zeros((H, W), dtype=np.uint16)
    
    # 2. Semantic Nuclei Mask (Channel 0: All nuclei, value 1)
    semantic_nuclei_mask = np.zeros((H, W), dtype=np.uint8)
    
    # 3. Class Map (Channel 2: 0=Bkg, 1=Epithelial, 2=Lymphocyte, 3=Macrophage, 4=Neutrophil)
    # The classification targets are crucial for the classification loss.
    class_map = np.zeros((H, W), dtype=np.uint8)
    
    class_to_id = {'Epithelial': 1, 'Lymphocyte': 2, 'Macrophage': 3, 'Neutrophil': 4}
    
    # Track global instance ID across all classes
    current_max_id = 0
    
    for class_name, class_id in class_to_id.items():
        mask = instance_masks.get(class_name, np.zeros((H, W), dtype=np.uint16))
        
        # Get unique instances for this class
        unique_ids = np.unique(mask[mask > 0])
        
        for inst_id in unique_ids:
            instance_region = (mask == inst_id)
            
            # Update Semantic Mask
            semantic_nuclei_mask[instance_region] = 1
            
            # Update Class Map
            class_map[instance_region] = class_id
            
            # Update Combined Instance Mask with unique global IDs
            # This is needed for a single distance map calculation
            # We shift the local instance ID to a global unique ID
            new_global_id = current_max_id + 1
            combined_instance_mask[instance_region] = new_global_id
            current_max_id += 1
            
    # 4. Distance Map to Centers (Channel 1)
    # The distance transform is a good way to create targets for separating instances.
    dist_map = np.zeros((H, W), dtype=np.float32)
    
    # Calculate a distance map for each instance and take the maximum distance for that pixel
    unique_global_ids = np.unique(combined_instance_mask[combined_instance_mask > 0])
    for inst_id in unique_global_ids:
        instance_region = (combined_instance_mask == inst_id)
        # Calculate distance to the boundary of this instance
        dist_to_boundary = distance_transform_edt(instance_region)
        # Update the final distance map: dist_map will hold the distance to the center of the instance it belongs to
        dist_map[instance_region] = dist_to_boundary[instance_region]
        
    # Normalize the distance map for better U-Net training stability
    max_dist = dist_map.max()
    if max_dist > 0:
        dist_map /= max_dist
        
    # Stack the channels: (H, W, 3) -> [Semantic Mask, Normalized Distance Map, Class Map]
    # Note: Class Map is often best handled by applying classification loss only on nuclei pixels
    target = np.stack([
        semantic_nuclei_mask, 
        dist_map, 
        class_map
    ], axis=-1)
    
    return target

# --- End of Target Generation Utilities ---

In [18]:
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

class NucleiDataset(Dataset):
    def __init__(self, df: pd.DataFrame, data_dir: str, transforms: A.Compose = None):
        self.df = df
        self.data_dir = data_dir
        self.transforms = transforms
        self.target_classes = TARGET_CLASSES # Inherited from Section 2

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_id = row['image_id']
        H, W = row['Height'], row['Width']
        
        # 1. Load Image
        image_path = os.path.join(self.data_dir, f"{image_id}.tif")
        image = np.array(Image.open(image_path).convert("RGB")) # Convert to RGB if needed

        # 2. Decode RLE Masks for all classes
        instance_masks: Dict[str, np.ndarray] = {}
        for class_name in self.target_classes:
            rle_string = row[class_name]
            instance_masks[class_name] = rle_decode_instance_mask(rle_string, (H, W))

        # 3. Generate Target Maps (Semantic, Distance, Class)
        target_maps = generate_target_maps(instance_masks, (H, W))
        
        # Split target maps for augmentation
        semantic_mask = target_maps[:, :, 0]
        dist_map = target_maps[:, :, 1]
        class_map = target_maps[:, :, 2]
        
        # Stack targets for augmentation: (H, W, 3)
        # Note: Class map should be handled carefully if using color/stain augmentations.
        # Here we stack [Image, Semantic Mask, Distance Map, Class Map]
        augmented = None
        if self.transforms:
            # Augmentation applies simultaneously to image and masks
            augmented = self.transforms(image=image, masks=[semantic_mask, dist_map, class_map])
            image = augmented['image']
            
            # The masks list is returned in the order they were passed
            semantic_mask = augmented['masks'][0]
            dist_map = augmented['masks'][1]
            class_map = augmented['masks'][2]

        # Convert to PyTorch tensors and appropriate shapes/types
        image = image.float() / 255.0 # Normalize image
        
        # Targets are 2D, squeeze the HxW to (1, H, W) or keep as (H, W) for direct loss calculation
        # Semantic/Dist maps for segmentation head
        seg_target = torch.stack([torch.tensor(semantic_mask).float(), torch.tensor(dist_map).float()], dim=0) # (2, H, W)
        
        # Class map for classification head (long type for CrossEntropy)
        class_target = torch.tensor(class_map).long() # (H, W)
        
        return {
            'image': image,             # (3, H, W)
            'seg_target': seg_target,   # (2, H, W) -> [Semantic, Distance]
            'class_target': class_target, # (H, W) -> [0, 1, 2, 3, 4]
            'image_id': image_id
        }

# --- Augmentation Pipeline ---
train_transforms = A.Compose([
    # Geometric Augmentations (Essential for Multi-Organ Robustness)
    A.Resize(256, 256), # Resize to a fixed size for mini-batching
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=10, p=0.5),
    
    # Photometric Augmentations (Essential for H&E Stain Variation)
    A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, p=0.5),
    A.GaussNoise(p=0.2),
    
    # ToTensorV2 converts numpy arrays to PyTorch tensors and moves the channel dim
    ToTensorV2(transpose_mask=False), # False to keep HxW for mask, we stack later
])

# --- Create Dataset and DataLoader ---
# Split data (simplified 90/10 split)
train_split_df = train_df.sample(frac=0.9, random_state=42)
val_split_df = train_df.drop(train_split_df.index)

train_dataset = NucleiDataset(train_split_df, TRAIN_DIR, transforms=train_transforms)
val_dataset = NucleiDataset(val_split_df, TRAIN_DIR, transforms=A.Compose([
    A.Resize(256, 256), # Only resize for validation, no heavy augmentations
    ToTensorV2(transpose_mask=False),
]))

BATCH_SIZE = 4
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) 
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"\nDataLoader created. Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")


DataLoader created. Train size: 188, Validation size: 21


In [19]:
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class SimpleUNet(nn.Module):
    def __init__(self, in_channels, out_seg_channels, out_class_channels):
        super().__init__()
        
        # Encoder
        self.inc = DoubleConv(in_channels, 64)
        self.down1 = nn.MaxPool2d(2)
        self.conv1 = DoubleConv(64, 128)
        self.down2 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(128, 256)
        self.down3 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(256, 512)
        
        # Decoder
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv_up2 = DoubleConv(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv_up3 = DoubleConv(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv_up4 = DoubleConv(128, 64)
        
        # Output Heads (from the final upsampled feature map)
        self.seg_head = nn.Conv2d(64, out_seg_channels, kernel_size=1)
        self.class_head = nn.Conv2d(64, out_class_channels, kernel_size=1)

    def forward(self, x):
        # Encoder passes
        x1 = self.inc(x)
        x2 = self.conv1(self.down1(x1))
        x3 = self.conv2(self.down2(x2))
        x4 = self.conv3(self.down3(x3))
        
        # Decoder passes (with skip connections)
        x = self.up2(x4)
        x = torch.cat([x, x3], dim=1)
        x = self.conv_up2(x)
        
        x = self.up3(x)
        x = torch.cat([x, x2], dim=1)
        x = self.conv_up3(x)
        
        x = self.up4(x)
        x = torch.cat([x, x1], dim=1)
        x = self.conv_up4(x)
        
        # Output Heads
        seg_output = self.seg_head(x)
        class_output = self.class_head(x)
        
        return seg_output, class_output

In [7]:
!pip3 install torch torchvision torchaudio

Collecting torchaudio
  Downloading torchaudio-2.9.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.9 kB)
INFO: pip is looking at multiple versions of torchaudio to determine which version is compatible with other requirements. This could take a while.
  Downloading torchaudio-2.8.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (7.2 kB)
  Downloading torchaudio-2.7.1-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.6 kB)
  Downloading torchaudio-2.7.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.6 kB)
  Downloading torchaudio-2.6.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.6 kB)
Downloading torchaudio-2.6.0-cp312-cp312-macosx_11_0_arm64.whl (1.8 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m35.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchaudio
Successfully installed torchaudio-2.6.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m

In [20]:
import torch

# ----------------- CORRECTED DEVICE SETUP -----------------
# 1. Check for Apple Silicon (MPS) support
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    DEVICE = torch.device("mps")
    print("Using Apple Silicon (MPS) backend for GPU acceleration.")
else:
    DEVICE = torch.device("cpu")
    print("MPS backend not available. Falling back to CPU.")

# ----------------- CORRECTED CLASS_WEIGHTS INITIALIZATION -----------------

# 5 classes: Background (0), Epithelial (1), Lymphocyte (2), Macrophage (3), Neutrophil (4)
# Class Order: [Bkg, Epithelial, Lymphocyte, Macrophage, Neutrophil]

# Using estimated, manually set weights reflecting the problem structure:
CLASS_WEIGHTS = torch.tensor([0.1, 1.0, 1.0, 10.0, 10.0]).to(DEVICE) # <-- Use .to(DEVICE)
print(f"Class weights initialized and moved to {DEVICE}.")

# ----------------- CRITERION FUNCTION (No change needed inside) -----------------

def criterion(seg_output, class_output, seg_target, class_target):
    # ... (function body remains the same)
    # The function uses F.cross_entropy which handles the weights on the device correctly.
    # ...
    # Placeholder for the rest of the function (assuming it's defined in your notebook)
    pass
# Define the actual criterion function here if you are re-running this cell:
def criterion(seg_output, class_output, seg_target, class_target):
    import torch.nn.functional as F
    
    # 1. Segmentation Loss (MSE for Semantic + Distance Map)
    seg_loss = F.mse_loss(seg_output, seg_target)

    # 2. Classification Loss (Weighted Cross Entropy)
    class_output_flat = class_output.permute(0, 2, 3, 1).reshape(-1, class_output.size(1))
    class_target_flat = class_target.reshape(-1)
    
    class_loss = F.cross_entropy(
        class_output_flat, 
        class_target_flat, 
        weight=CLASS_WEIGHTS, # CLASS_WEIGHTS is already on DEVICE
        ignore_index=0
    )
    
    # Combined Loss
    total_loss = seg_loss + 0.5 * class_loss
    
    return total_loss, seg_loss, class_loss

Using Apple Silicon (MPS) backend for GPU acceleration.
Class weights initialized and moved to mps.


In [21]:
from scipy.ndimage import label
from skimage.measure import regionprops

def post_process_watershed(seg_output_np, class_output_np, min_size: int = 10) -> Dict[str, np.ndarray]:
    """
    Applies Watershed segmentation to separate instances and uses the class prediction.
    
    Args:
        seg_output_np: The (2, H, W) segmentation output (Semantic, Distance)
        class_output_np: The (5, H, W) class prediction output (Logits)
    
    Returns:
        A dictionary of {Class_Name: Instance_Mask}
    """
    
    # 1. Semantic Mask and Distance Map from U-Net output
    semantic_pred = (seg_output_np[0] > 0.5).astype(np.uint8) # Nuclei mask from semantic head
    dist_map_pred = seg_output_np[1] # Distance map from distance head

    # 2. Find Markers (Seeds) for Watershed
    # Markers are often the local maxima of the distance map, constrained by the semantic mask
    local_maxima = (dist_map_pred == distance_transform_edt(dist_map_pred, sampling=[1, 1])) * semantic_pred
    markers, num_markers = label(local_maxima)
    
    # If no markers found, return empty masks
    if num_markers == 0:
        return {c: np.zeros(semantic_pred.shape, dtype=np.uint16) for c in TARGET_CLASSES}
    
    # 3. Apply Watershed
    # The distance map is inverted to act as a "landscape" where seeds flow to basins
    # We use the negative distance map for the segmentation function
    watershed_labels = watershed(-dist_map_pred, markers, mask=semantic_pred)

    # 4. Classification
    final_masks: Dict[str, np.ndarray] = {c: np.zeros(semantic_pred.shape, dtype=np.uint16) for c in TARGET_CLASSES}
    class_id_to_name = {1: 'Epithelial', 2: 'Lymphocyte', 3: 'Macrophage', 4: 'Neutrophil'}

    # Get the class prediction for each pixel (Bkg: 0, Epithelial: 1, ...)
    # The classification head output (C, H, W) are logits. We take the argmax.
    class_pred_map = np.argmax(class_output_np, axis=0)

    # Iterate through all detected instances
    for region in regionprops(watershed_labels):
        inst_id = region.label
        coords = region.coords

        # Ensure the instance is not too small (filter noise)
        if len(coords) < min_size:
            continue
            
        # Determine the class for the entire instance
        # Use the majority vote within the instance mask based on the class prediction map
        instance_classes = class_pred_map[coords[:, 0], coords[:, 1]]
        
        # Filter out background (class 0) from the votes
        instance_classes_nuclei = instance_classes[instance_classes > 0]
        if len(instance_classes_nuclei) == 0:
            continue
            
        # Majority vote for the instance's final class
        final_class_id = np.argmax(np.bincount(instance_classes_nuclei))
        
        if final_class_id > 0:
            class_name = class_id_to_name[final_class_id]
            
            # Assign the instance to the correct class mask (using the original watershed ID)
            # IMPORTANT: The instance ID must be unique *per class mask*
            # We track the max ID for the current class mask
            current_max_id = final_masks[class_name].max()
            new_inst_id = current_max_id + 1
            
            # Apply the new ID to the final class mask
            for r, c in coords:
                final_masks[class_name][r, c] = new_inst_id
                
    return final_masks

In [23]:
import torch.optim as optim

# 1. Check for Apple Silicon (MPS) support
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    DEVICE = torch.device("mps")
    print("Using Apple Silicon (MPS) backend for GPU acceleration.")
else:
    DEVICE = torch.device("cpu")
    print("MPS backend not available. Falling back to CPU.")

# 2. Initialize and move model to the correct device
model = SimpleUNet(
    in_channels=3, 
    out_seg_channels=2, # Semantic, Distance
    out_class_channels=5 # Bkg, Epithelial, Lymphocyte, Macrophage, Neutrophil
).to(DEVICE) # <-- Uses the dynamically determined DEVICE ('mps' or 'cpu')

optimizer = optim.Adam(model.parameters(), lr=1e-4)

NUM_EPOCHS = 10 
# In a real competition, this would be much higher, with early stopping.

print(f"\nStarting training on {DEVICE} for {NUM_EPOCHS} epochs...")

# --- Simplified Training Loop ---
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    
    for batch_idx, data in enumerate(train_loader):
        images = data['image'].to(DEVICE)
        seg_targets = data['seg_target'].to(DEVICE)
        class_targets = data['class_target'].to(DEVICE)

        optimizer.zero_grad()
        
        # Forward pass
        seg_output, class_output = model(images)
        
        # Calculate loss
        loss, seg_loss, class_loss = criterion(seg_output, class_output, seg_targets, class_targets)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {avg_loss:.4f} (Seg: {seg_loss.item():.4f}, Class: {class_loss.item():.4f})")
    
    # (Optional: Add validation step with wPQ calculation here)

print("Training finished.")

Using Apple Silicon (MPS) backend for GPU acceleration.

Starting training on mps for 10 epochs...


  seg_target = torch.stack([torch.tensor(semantic_mask).float(), torch.tensor(dist_map).float()], dim=0) # (2, H, W)
  class_target = torch.tensor(class_map).long() # (H, W)


Epoch 1/10, Loss: 0.9179 (Seg: 0.1048, Class: 1.4659)
Epoch 2/10, Loss: 0.7261 (Seg: 0.1189, Class: 1.4899)
Epoch 3/10, Loss: 0.6871 (Seg: 0.0597, Class: 1.2909)
Epoch 4/10, Loss: 0.6332 (Seg: 0.0829, Class: 1.3233)
Epoch 5/10, Loss: 0.5967 (Seg: 0.0860, Class: 0.9329)
Epoch 6/10, Loss: 0.5325 (Seg: 0.0600, Class: 1.6381)
Epoch 7/10, Loss: 0.5448 (Seg: 0.0437, Class: 0.9104)
Epoch 8/10, Loss: 0.5195 (Seg: 0.0385, Class: 0.6598)
Epoch 9/10, Loss: 0.4842 (Seg: 0.0544, Class: 0.6249)
Epoch 10/10, Loss: 0.4862 (Seg: 0.0679, Class: 0.4931)
Training finished.


In [None]:
# Load Test Data
TEST_DF = pd.DataFrame({'image_id': [os.path.splitext(f)[0] for f in os.listdir(TEST_DIR) if f.endswith('.tif')]})

# Determine test image shapes
TEST_DF = load_and_decode_masks(os.path.join(DATA_DIR, "train_ground_truth.csv"), TEST_DIR)
# Note: Re-use the load_and_decode_masks function as it includes the shape discovery logic. 
# We need to manually filter the DF to just test images if the directory listing is used.

# For simplicity, assume all images in TEST_DIR were added to TEST_DF with correct shapes.
TEST_DF = TEST_DF[TEST_DF['image_id'].isin([os.path.splitext(f)[0] for f in os.listdir(TEST_DIR) if f.endswith('.tif')])]

print(f"\nLoaded {len(TEST_DF)} test images.")
print(TEST_DF.head())

class TestNucleiDataset(Dataset):
    def __init__(self, df: pd.DataFrame, data_dir: str):
        self.df = df
        self.data_dir = data_dir
        self.transforms = A.Compose([A.Resize(256, 256), ToTensorV2()])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_id = row['image_id']
        image_path = os.path.join(self.data_dir, f"{image_id}.tif")
        # print(f"Loading test image: {image_path}")
        image = np.array(Image.open(image_path).convert("RGB"))
        
        # The image needs to be stored *before* the transform for correct RLE encoding later
        original_shape = image.shape[:2]
        
        augmented = self.transforms(image=image)
        image_tensor = augmented['image'].float() / 255.0
        
        return {
            'image': image_tensor,
            'image_id': image_id,
            'original_shape': original_shape
        }

test_dataset = TestNucleiDataset(TEST_DF, TEST_DIR)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

# --- Submission Generation ---
submission_data = []

model.eval()
with torch.no_grad():
    for data in test_loader:
        image = data['image'].to(DEVICE)
        image_id = data['image_id'][0]
        # Original shape (H, W) for RLE encoding
        orig_H, orig_W = data['original_shape'][0].item(), data['original_shape'][1].item() 

        # Forward pass
        seg_output, class_output = model(image)

        # Convert outputs to numpy and resize back to original shape (critical!)
        seg_output_np = seg_output.cpu().squeeze(0).numpy()
        class_output_np = class_output.cpu().squeeze(0).numpy()
        
        # Resize output back to original image size
        # Use bilinear interpolation for continuous maps (seg) and nearest for class logits
        from skimage.transform import resize
        
        resized_seg_output = np.stack([
            resize(seg_output_np[i], (orig_H, orig_W), order=1, anti_aliasing=False) 
            for i in range(seg_output_np.shape[0])
        ], axis=0)

        resized_class_output = np.stack([
            resize(class_output_np[i], (orig_H, orig_W), order=0, anti_aliasing=False) 
            for i in range(class_output_np.shape[0])
        ], axis=0)
        
        # Post-process (Watershed) to get instance masks
        final_instance_masks = post_process_watershed(resized_seg_output, resized_class_output)
        
        # RLE Encode the final masks
        rle_row = {'image_id': image_id}
        for class_name in TARGET_CLASSES:
            mask = final_instance_masks[class_name]
            rle_row[class_name] = rle_encode_instance_mask(mask)
            
        submission_data.append(rle_row)

# Create final submission CSV
submission_df = pd.DataFrame(submission_data)
# Ensure columns are in the correct order: image_id,Epithelial,Lymphocyte,Neutrophil,Macrophage
SUBMISSION_COLUMNS = ['image_id', 'Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage']
submission_df = submission_df[SUBMISSION_COLUMNS]

# Save the submission file
SUBMISSION_FILE_PATH = 'submission.csv'
submission_df.to_csv(SUBMISSION_FILE_PATH, index=False)
print(f"\nSubmission file created at: {SUBMISSION_FILE_PATH}")
print(submission_df.head())

Determining image dimensions...
Loaded annotations for 209 images.

Loaded 40 test images.
  image_id                                         Epithelial  \
0   slide1                                                  0   
1   slide2  191 3596 1 191 4378 14 191 5162 15 191 5947 16...   
2   slide3                                                  0   
3   slide4  17 606 1 17 1114 9 61 1425 21 17 1624 13 61 19...   
4   slide5  1 106286 7 1 106708 9 1 107131 11 1 107553 13 ...   

                                          Lymphocyte  \
0                                                  0   
1  6 15974 1 1 16131 2 6 16755 8 1 16911 8 6 1753...   
2                                                  0   
3                                                  0   
4  70 1748 2 69 2128 3 70 2168 8 69 2550 7 70 259...   

                                          Neutrophil  \
0                                                  0   
1                                                  0   
2  2 4323 8 2