In [None]:
# run_nuclei_segmentation.py

import os
import random
import xml.etree.ElementTree as ET
import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision import transforms as T
import warnings

warnings.filterwarnings("ignore")

print(f"PyTorch Version: {torch.__version__}")
print(f"Torchvision Version: {torchvision.__version__}")

# Choose the compute backend: Apple-Silicon MPS is tried first, then CUDA,
# and the CPU is used when neither accelerator reports itself available.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")


# --- 1. Configuration ---
class Config:
    """Configuration settings for the project.

    All settings are class attributes, so they can be read from the class
    itself or from the module-level ``config`` instance created below.
    """

    # Data locations, relative to the working directory.
    DATA_PATH = "kaggle-data/"
    TRAIN_DIR = os.path.join(DATA_PATH, "train")
    VAL_DIR = os.path.join(DATA_PATH, "val")
    TEST_DIR = os.path.join(DATA_PATH, "test_final")

    # Class mapping: annotation class name -> label id. Ids start at 1
    # because id 0 is the background class.
    CLASS_MAP = {
        "Epithelial": 1,
        "Lymphocyte": 2,
        "Macrophage": 3,
        "Neutrophil": 4,
    }
    # Inverse mapping (label id -> class name) used when decoding predictions.
    INV_CLASS_MAP = {v: k for k, v in CLASS_MAP.items()}

    # Model parameters
    # Derived from CLASS_MAP so the two can never drift apart
    # (4 classes + 1 background with the mapping above).
    NUM_CLASSES = len(CLASS_MAP) + 1
    BATCH_SIZE = 2  # Lower if you run out of memory
    NUM_EPOCHS = 15  # Increase for better results
    LEARNING_RATE = 0.001


# Instantiate config
config = Config()

# --- 2. RLE Encoding Helper ---
def rle_encode_instance_mask(mask: np.ndarray) -> str:
    """
    Run-length encode an instance mask for the submission file.

    The mask is flattened in column-major (Fortran) order. Every maximal run
    of identical non-zero values is emitted as a ``value start length``
    triplet, where ``start`` is the 1-based position of the run's first
    pixel in that flattened order. Runs of 0 (background) are omitted.

    Returns the space-separated triplets, or "0" when the mask contains no
    non-zero pixels.
    """
    # Surround the flattened pixels with a 0 on each side so that every run
    # has a detectable start and end transition; the leading pad also makes
    # the resulting positions 1-based.
    padded = np.pad(mask.flatten(order="F").astype(np.int32), 1)
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1

    run_starts = boundaries[:-1]
    run_lengths = np.diff(boundaries)
    run_values = padded[run_starts]

    keep = run_values > 0
    if not keep.any():
        return "0"  # Return "0" if no instances are found

    triplets = np.column_stack((run_values[keep], run_starts[keep], run_lengths[keep]))
    return " ".join(map(str, triplets.ravel()))

# --- 3. Custom Dataset ---
class NucleiDataset(Dataset):
    """
    Custom PyTorch Dataset for loading nuclei images and their corresponding
    XML annotations for instance segmentation.

    ``image_dir`` holds ``<id>.tif`` images, each paired with an ``<id>.xml``
    file of polygon regions. Each item is an ``(image, target)`` pair, where
    ``target`` holds ``boxes``, ``labels``, ``masks``, ``image_id``, ``area``
    and ``iscrowd``. Only annotations whose class is in ``config.CLASS_MAP``
    are kept.
    """

    def __init__(self, image_dir, transforms=None):
        self.image_dir = image_dir
        self.transforms = transforms
        # splitext strips only the final extension, so ids that themselves
        # contain dots still map back to the matching .tif/.xml files.
        self.image_ids = sorted(
            os.path.splitext(f)[0] for f in os.listdir(image_dir) if f.endswith('.tif')
        )

    def __len__(self):
        return len(self.image_ids)

    @staticmethod
    def _class_name(annotation):
        """Return the known class name of an <Annotation> element, or None.

        The element's own ``Name`` attribute is used when it is a known class.
        Otherwise, the ``Name`` of the first nested <Attribute> element that
        is a known class is used.
        """
        name = annotation.get("Name")
        if name in config.CLASS_MAP:
            return name
        # NOTE(review): some ImageScope-style XML exports store the class on
        # a nested <Attribute Name="..."> rather than on <Annotation> itself;
        # confirm which form this dataset uses.
        for attribute in annotation.iter("Attribute"):
            candidate = attribute.get("Name")
            if candidate in config.CLASS_MAP:
                return candidate
        return None

    @staticmethod
    def _rasterize_region(region, height, width):
        """Rasterize one <Region> polygon into a binary mask plus bounding box.

        Returns ``(mask, [xmin, ymin, xmax, ymax])`` where the box holds
        inclusive pixel indices. Returns ``None`` when the region has no
        vertices, fills no pixels, or has zero width or height.
        """
        vertices = [
            (float(vertex.get("X")), float(vertex.get("Y")))
            for vertex in region.findall(".//Vertex")
        ]
        if not vertices:
            return None

        mask = np.zeros((height, width), dtype=np.uint8)
        # Vertex coordinates are truncated to integer pixel positions.
        pts = np.array(vertices, dtype=np.int32)
        cv2.fillPoly(mask, [pts], 1)

        ys, xs = np.nonzero(mask)
        if ys.size == 0:
            return None
        xmin, xmax = xs.min(), xs.max()
        ymin, ymax = ys.min(), ys.max()
        # Skip degenerate boxes (zero width or height).
        if xmin >= xmax or ymin >= ymax:
            return None
        return mask, [xmin, ymin, xmax, ymax]

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, f"{image_id}.tif")
        xml_path = os.path.join(self.image_dir, f"{image_id}.xml")

        # Convert to RGB inside a context manager so the TIFF file handle is
        # closed immediately instead of whenever the garbage collector runs.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        root = ET.parse(xml_path).getroot()

        masks, labels, boxes = [], [], []
        for annotation in root.findall(".//Annotation"):
            label_name = self._class_name(annotation)
            if label_name is None:
                continue
            label_id = config.CLASS_MAP[label_name]
            for region in annotation.findall(".//Region"):
                instance = self._rasterize_region(region, img.height, img.width)
                if instance is None:
                    continue
                mask, box = instance
                masks.append(mask)
                labels.append(label_id)
                boxes.append(box)

        num_instances = len(masks)
        target = {}
        if num_instances:
            target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
            target["masks"] = torch.as_tensor(np.array(masks), dtype=torch.uint8)
        else:  # Images with no annotations get correctly-shaped empty tensors.
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros(0, dtype=torch.int64)
            target["masks"] = torch.zeros((0, img.height, img.width), dtype=torch.uint8)
        target["image_id"] = torch.tensor([idx])
        # Area of each bounding box (not of the mask); empty when no instances.
        boxes_t = target["boxes"]
        target["area"] = (boxes_t[:, 3] - boxes_t[:, 1]) * (boxes_t[:, 2] - boxes_t[:, 0])
        target["iscrowd"] = torch.zeros((num_instances,), dtype=torch.int64)

        if self.transforms:
            img, target = self.transforms(img, target)

        return img, target

# --- 4. Data Augmentation and Transforms ---
class Compose:
    """Chain ``(image, target)`` transforms, applying them in list order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        img, tgt = image, target
        for step in self.transforms:
            img, tgt = step(img, tgt)
        return img, tgt

class ToTensor:
    """Convert the image of an ``(image, target)`` pair with torchvision's
    ``ToTensor``; the target is passed through unchanged."""

    def __call__(self, image, target):
        to_tensor = T.ToTensor()
        return to_tensor(image), target

class RandomHorizontalFlip:
    """Randomly mirror a tensor image and its target left-to-right.

    Must run after ``ToTensor``: the image is flipped along its last (width)
    dimension, ``target["boxes"]`` holds ``[xmin, ymin, xmax, ymax]``
    inclusive pixel indices (as produced by ``NucleiDataset``), and
    ``target["masks"]``, if present, is flipped along its last dimension.
    The caller's target dict is not mutated.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() >= self.prob:
            return image, target
        width = image.shape[-1]
        image = image.flip(-1)
        target = dict(target)
        old_boxes = target["boxes"]
        boxes = old_boxes.clone()
        # Column x maps to (width - 1 - x), so xmin and xmax swap roles.
        boxes[:, [0, 2]] = (width - 1) - old_boxes[:, [2, 0]]
        target["boxes"] = boxes
        if "masks" in target:
            target["masks"] = target["masks"].flip(-1)
        return image, target


def get_transform(train):
    """Build the ``(image, target)`` transform pipeline.

    Args:
        train: when True, random horizontal flipping (p=0.5) is added as
            data augmentation after tensor conversion; otherwise the
            pipeline only converts the image to a tensor.
    """
    transforms = [ToTensor()]
    if train:
        transforms.append(RandomHorizontalFlip(0.5))
    return Compose(transforms)
    
def collate_fn(batch):
    """
    Drop samples without annotated boxes, then transpose the remainder.

    Returns ``(images, targets)`` as two parallel tuples, or ``(None, None)``
    when no sample in the batch has at least one box.
    """
    kept = [sample for sample in batch if len(sample[1]["boxes"]) > 0]
    if len(kept) == 0:
        return None, None
    return tuple(zip(*kept))

# --- 5. Model Definition ---
def get_model(num_classes):
    """
    Build torchvision's Mask R-CNN (ResNet-50 FPN, ``weights="DEFAULT"``)
    and replace its box and mask prediction heads so they output
    ``num_classes`` classes (background included).
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
    heads = model.roi_heads

    box_in = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in, num_classes)

    mask_in = heads.mask_predictor.conv5_mask.in_channels
    mask_hidden = 256
    heads.mask_predictor = MaskRCNNPredictor(mask_in, mask_hidden, num_classes)
    return model

# --- 6. Training Logic ---
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10, max_grad_norm=None):
    """Run one training epoch and return the mean loss.

    Args:
        model: module whose forward ``model(images, targets)`` returns a dict
            of loss tensors in training mode (as torchvision detection
            models do); the losses are summed into one objective.
        optimizer: optimizer over the model's trainable parameters.
        data_loader: iterable of ``(images, targets)`` batches. Batches that
            ``collate_fn`` emptied arrive as ``(None, None)`` and are skipped.
        device: device the images and target tensors are moved to.
        epoch: zero-based epoch index, used only for logging.
        print_freq: log the current loss every ``print_freq`` iterations.
        max_grad_norm: if not None, clip the total gradient norm to this
            value before each optimizer step.

    Returns:
        Mean summed loss over the iterations that performed an optimizer
        step, or ``float("nan")`` if every batch was skipped.
    """
    model.train()
    total_loss = 0.0
    num_steps = 0

    for i, (images, targets) in enumerate(data_loader):
        # Skip iteration if the batch is empty after filtering
        if images is None or not images:
            continue

        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        if not torch.isfinite(losses):
            print(f"Skipping iteration {i} due to non-finite loss: {losses.item()}")
            continue

        optimizer.zero_grad()
        losses.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        loss_value = losses.item()
        total_loss += loss_value
        num_steps += 1

        if (i + 1) % print_freq == 0:
            print(f"Epoch [{epoch+1}], Iter [{i+1}/{len(data_loader)}], Loss: {loss_value:.4f}")

    return total_loss / num_steps if num_steps else float("nan")

# --- 7. Main Execution Block ---
def _load_rgb(path):
    """Open an image file, convert it to RGB, and close the underlying file."""
    with Image.open(path) as raw:
        return raw.convert("RGB")


def _list_image_ids(directory):
    """Return the sorted ids (file name minus extension) of ``.tif`` files in ``directory``."""
    return sorted(os.path.splitext(f)[0] for f in os.listdir(directory) if f.endswith('.tif'))


def _predict_instance_maps(model, img, score_threshold=0.5, mask_threshold=0.5):
    """Run ``model`` on one RGB image and build one instance-id map per class.

    Returns a dict mapping every class name in ``config.CLASS_MAP`` to an
    ``(H, W)`` int32 array. Each detection with score above
    ``score_threshold`` and a known label is painted with the next id
    (1, 2, ...) of its class. It only claims pixels that are still 0, so
    earlier detections win overlaps.
    """
    img_tensor = T.ToTensor()(img).to(device)
    with torch.no_grad():
        prediction = model([img_tensor])[0]

    instance_maps = {name: np.zeros((img.height, img.width), dtype=np.int32)
                     for name in config.CLASS_MAP}
    next_ids = {name: 1 for name in config.CLASS_MAP}

    scores = prediction['scores'].cpu().numpy()
    labels = prediction['labels'].cpu().numpy()
    soft_masks = prediction['masks'].cpu().numpy()

    # NOTE(review): assumes torchvision returns detections sorted by
    # descending score, so higher-scoring instances keep overlapping pixels.
    for score, label_id, soft_mask in zip(scores, labels, soft_masks):
        if score > score_threshold and label_id in config.INV_CLASS_MAP:
            class_name = config.INV_CLASS_MAP[label_id]
            instance_map = instance_maps[class_name]
            instance_map[(soft_mask[0] > mask_threshold) & (instance_map == 0)] = next_ids[class_name]
            next_ids[class_name] += 1
    return instance_maps


def main():
    """Train Mask R-CNN on the training split, save its weights, then write
    ``submission.csv`` with per-class RLE instance masks for the test split."""
    dataset_train = NucleiDataset(config.TRAIN_DIR, get_transform(train=True))
    data_loader_train = DataLoader(
        dataset_train, batch_size=config.BATCH_SIZE, shuffle=True,
        num_workers=0, collate_fn=collate_fn
    )

    print("Datasets and DataLoaders created successfully.")

    model = get_model(config.NUM_CLASSES)
    model.to(device)
    print("Model loaded and moved to device.")

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=config.LEARNING_RATE)

    print("Starting training...")
    for epoch in range(config.NUM_EPOCHS):
        train_one_epoch(model, optimizer, data_loader_train, device, epoch)
        print(f"--- Epoch {epoch+1} finished ---")

    print("Training complete.")

    model_save_path = "nuclei_segmentation_model.pth"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    print("Starting inference on the test set...")
    model.eval()

    submission_data = []
    for image_id in _list_image_ids(config.TEST_DIR):
        img = _load_rgb(os.path.join(config.TEST_DIR, f"{image_id}.tif"))
        instance_maps = _predict_instance_maps(model, img)

        row = {'image_id': image_id}
        for class_name, instance_map in instance_maps.items():
            row[class_name] = rle_encode_instance_mask(instance_map)
        submission_data.append(row)
        print(f"Processed test image: {image_id}")

    # Passing ``columns`` fixes the column order and keeps the header even
    # when the test directory has no images (selecting columns from an
    # empty DataFrame would raise KeyError).
    submission_df = pd.DataFrame(
        submission_data,
        columns=['image_id', 'Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage'],
    )
    submission_df.to_csv("submission.csv", index=False)

    print("\nSubmission file 'submission.csv' created successfully!")


# Entry point: run the train -> save -> predict pipeline when this file is
# executed directly (a notebook kernel also has __name__ == "__main__").
if __name__ == "__main__":
    main()

PyTorch Version: 2.9.0
Torchvision Version: 0.24.0
Using device: mps
Starting training...


KeyboardInterrupt: 