In [4]:
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
import xml.etree.ElementTree as ET
from pathlib import Path
from tqdm import tqdm

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as T

# --- Configuration ---
# Set main data directory
DATA_DIR = Path("kaggle-data")
TRAIN_IMG_DIR = DATA_DIR / "train"
VAL_IMG_DIR = DATA_DIR / "val"
TEST_IMG_DIR = DATA_DIR / "test_final"

# Model parameters
NUM_CLASSES = 5 # 4 classes + 1 background
BATCH_SIZE = 2 # Keep this low for CPU training, 1 or 2 is good.
NUM_EPOCHS = 30 # A more realistic number for starting to see results.
LEARNING_RATE = 0.005

# Class mapping
CLASS_MAP = {
    "Epithelial": 1,
    "Lymphocyte": 2,
    "Macrophage": 3,
    "Neutrophil": 4,
}
# Create an inverse map for prediction
INV_CLASS_MAP = {v: k for k, v in CLASS_MAP.items()}

# --- KEY FIX (Device) ---
# Force CPU to avoid MPS bugs on M-series chips.
DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")


# --- RLE Encoding Function (from problem description) ---
def rle_encode_instance_mask(mask: np.ndarray) -> str:
    """
    Convert an instance segmentation mask (H,W) -> RLE triple string.
    0 = background, >0 = instance IDs.
    """
    pixels = mask.flatten(order="F").astype(np.int32)
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    rle = []
    for i in range(len(runs) - 1):
        start = runs[i]
        end = runs[i + 1]
        length = end - start
        val = pixels[start]
        if val > 0:
            rle.extend([val, start, length])
    if not rle:
        return "0"
    return " ".join(map(str, rle))


# --- Custom Dataset Class (Verified Logic) ---
class NucleiDataset(Dataset):
    def __init__(self, image_dir, transforms=None):
        self.image_dir = Path(image_dir)
        self.transforms = transforms
        
        all_image_files = sorted([f for f in self.image_dir.glob("*.tif")])
        all_xml_files = sorted([f for f in self.image_dir.glob("*.xml")])

        self.image_files = []
        self.xml_files = []
        
        print(f"Filtering dataset in {image_dir}...")
        for img_path, xml_path in tqdm(zip(all_image_files, all_xml_files), total=len(all_image_files)):
            try:
                tree = ET.parse(xml_path)
                root = tree.getroot()
                
                # --- KEY FIX (Correct Hierarchy Parsing) ---
                has_valid_annotation = False
                
                # Loop over top-level Annotations
                for annotation in root.findall("Annotation"):
                    # Find the class name for this ANNOTATION
                    attribute = annotation.find("Attributes/Attribute")
                    if attribute is None:
                        continue
                    
                    label_name = attribute.get("Name")
                    if label_name in CLASS_MAP:
                        # This annotation is for a class we care about.
                        # Does it contain any regions (instances)?
                        if annotation.find("Regions/Region") is not None:
                             has_valid_annotation = True
                             break # Found a valid class with at least one region, this file is good.
                
                if has_valid_annotation:
                    self.image_files.append(img_path)
                    self.xml_files.append(xml_path)
            except ET.ParseError:
                print(f"Warning: Skipping corrupted XML file: {xml_path}")
        
        print(f"Found {len(self.image_files)} images with valid annotations out of {len(all_image_files)} total.")


    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        try:
            img_path = self.image_files[idx]
            xml_path = self.xml_files[idx]
            
            # Load image
            img = Image.open(img_path).convert("RGB")
            width, height = img.size

            # Parse XML annotations
            tree = ET.parse(xml_path)
            root = tree.getroot()

            masks, labels, boxes = [], [], []
            
            # --- KEY FIX (Correct Hierarchy Parsing) ---
            # Loop over top-level Annotations
            for annotation in root.findall("Annotation"):
                # Find the class name for this ANNOTATION
                attribute = annotation.find("Attributes/Attribute")
                if attribute is None:
                    continue
                    
                label_name = attribute.get("Name")
                if label_name not in CLASS_MAP:
                    continue
                
                # This is a class we care about, e.g., "Epithelial"
                label_id = CLASS_MAP[label_name]

                # Now, find all regions *within this annotation's Regions tag*
                for region in annotation.findall("Regions/Region"):
                    vertices = []
                    # Try the path from slide1.xml: Region -> Vertices -> Vertex
                    vertex_nodes = region.findall("Vertices/Vertex")
                    
                    if not vertex_nodes:
                        # If that fails, try the other common path: Region -> Vertex
                        vertex_nodes = region.findall("Vertex")

                    for vertex in vertex_nodes:
                        x = float(vertex.get("X"))
                        y = float(vertex.get("Y"))
                        vertices.append((x, y))

                    if not vertices:
                        # This region has a class but no vertices, skip it.
                        continue

                    # Create a binary mask for this single instance
                    instance_mask = Image.new("L", (width, height), 0)
                    ImageDraw.Draw(instance_mask).polygon(vertices, outline=1, fill=1)
                    mask_np = np.array(instance_mask)
                    
                    # Get bounding box from mask
                    pos = np.where(mask_np)
                    if pos[0].size == 0 or pos[1].size == 0:
                        continue # Skip empty masks
                    xmin, xmax = np.min(pos[1]), np.max(pos[1])
                    ymin, ymax = np.min(pos[0]), np.max(pos[0])
                    
                    # Check for valid box area
                    if xmax > xmin and ymax > ymin:
                        masks.append(mask_np)
                        labels.append(label_id)
                        boxes.append([xmin, ymin, xmax, ymax])

            if not boxes: # This should now only happen if a "valid" file has 0-area polygons
                target = {
                    "boxes": torch.zeros((0, 4), dtype=torch.float32),
                    "labels": torch.zeros(0, dtype=torch.int64),
                    "masks": torch.zeros((0, height, width), dtype=torch.uint8)
                }
            else:
                # Convert to tensors
                boxes = torch.as_tensor(boxes, dtype=torch.float32)
                labels = torch.as_tensor(labels, dtype=torch.int64)
                masks = torch.as_tensor(np.array(masks), dtype=torch.uint8)

                target = {"boxes": boxes, "labels": labels, "masks": masks}
            
            # Apply transforms
            img_tensor = T.ToImage()(img) # Convert PIL to tensor
            img_tensor = T.ToDtype(torch.float32, scale=True)(img_tensor) # Normalize to [0,1]
            
            if self.transforms:
                # Apply augmentations if any (e.g., RandomHorizontalFlip)
                # Note: v2 transforms update target in-place
                img_tensor, target = self.transforms(img_tensor, target)

            return img_tensor, target
        
        except Exception as e:
            print(f"Error loading item at index {idx}, path: {self.image_files[idx]}. Error: {e}")
            return None, None # Return None to be filtered by collate_fn

# --- Data Augmentation ---
def get_transforms(is_train):
    transforms = []
    # ToImage() and ToDtype() are now handled in __getitem__
    if is_train:
        # Adds random horizontal flipping for augmentation
        transforms.append(T.RandomHorizontalFlip(0.5))
        # --- ADD MORE AUGMENTATIONS HERE ---
        # transforms.append(T.RandomVerticalFlip(0.5))
        # transforms.append(T.RandomRotation(degrees=30))
    
    # --- FIX: Return None if transforms list is empty ---
    if not transforms:
        return None
    
    return T.Compose(transforms)

# --- Model Definition ---
def get_model(num_classes):
    # Load a pre-trained instance segmentation model
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the box predictor
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    
    return model

# --- Utility for DataLoader ---
def collate_fn(batch):
    # Filter out None entries (from errors in __getitem__)
    batch = [(img, tgt) for img, tgt in batch if img is not None and tgt is not None]
    if not batch:
        return None, None # Return None to be skipped in training loop
    return tuple(zip(*batch))

# --- Training Loop ---
def train_one_epoch(model, optimizer, data_loader, device):
    model.train()
    loop = tqdm(data_loader, leave=True)
    total_loss = 0
    batches_processed = 0
    
    for images, targets in loop:
        # Check for empty batch from collate_fn
        if images is None or targets is None:
            print("Skipping empty or problematic batch.")
            continue
            
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Skip batches that became empty after filtering (e.t., all images had 0 valid instances)
        if len(images) == 0:
            print("Skipping batch with no valid instances after processing.")
            continue
            
        try:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            
            # Handle potential NaN losses
            if torch.isnan(losses):
                print("Warning: NaN loss detected. Skipping batch.")
                continue

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            
            total_loss += losses.item()
            batches_processed += 1
            loop.set_postfix(loss=losses.item())
        
        except Exception as e:
            print(f"Error during training step: {e}. Skipping batch.")
            
    return total_loss / batches_processed if batches_processed > 0 else 0.0

# --- NEW: Evaluation Function ---
def evaluate(model, data_loader, device):
    model.train() # <-- FIX: Set to train() to get loss dict, but no_grad() will stop updates
    total_loss = 0
    batches_processed = 0
    loop = tqdm(data_loader, leave=True, desc="Evaluating")
    
    with torch.no_grad(): # Don't calculate gradients
        for images, targets in loop:
            # Check for empty batch from collate_fn
            if images is None or targets is None:
                print("Skipping empty or problematic batch during evaluation.")
                continue
            
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            if len(images) == 0:
                print("Skipping batch with no valid instances during evaluation.")
                continue

            try:
                # During evaluation, the model still needs targets to calculate loss
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())
                
                if not torch.isnan(losses):
                    total_loss += losses.item()
                    batches_processed += 1
                
                loop.set_postfix(val_loss=losses.item())
            except Exception as e:
                print(f"Error during evaluation step: {e}. Skipping batch.")

    return total_loss / batches_processed if batches_processed > 0 else 0.0

# --- Main Execution Block ---
if __name__ == "__main__":
    
    # 1. Setup Datasets and DataLoaders
    print("Setting up datasets...")
    dataset_train = NucleiDataset(TRAIN_IMG_DIR, transforms=get_transforms(is_train=True))
    dataset_val = NucleiDataset(VAL_IMG_DIR, transforms=get_transforms(is_train=False))

    # --- KEY FIX (DataLoader) ---
    # Set num_workers=0 to avoid multiprocessing hangs on macOS
    train_loader = DataLoader(
        dataset_train, batch_size=BATCH_SIZE, shuffle=True, 
        num_workers=0, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        dataset_val, batch_size=1, shuffle=False, 
        num_workers=0, collate_fn=collate_fn
    )

    # 2. Initialize Model, Optimizer
    print("Initializing model...")
    model = get_model(NUM_CLASSES)
    model.to(DEVICE) # Move model to CPU

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=0.9, weight_decay=0.0005)
    
    # # 3. Training
    # print(f"Starting training for {NUM_EPOCHS} epochs on CPU. This will take a while...")
    
    # # --- NEW: Variables to track the best model ---
    # best_val_loss = float('inf')
    best_model_path = "nuclei_maskrcnn_model_cpu_BEST.pth"

    # for epoch in range(NUM_EPOCHS):
    #     # --- Train for one epoch ---
    #     avg_train_loss = train_one_epoch(model, optimizer, train_loader, DEVICE)
        
    #     # --- Evaluate on the validation set ---
    #     avg_val_loss = evaluate(model, val_loader, DEVICE)
        
    #     print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
    
    #     # --- NEW: Checkpoint logic ---
    #     # Save the model *only if* the validation loss improved
    #     if avg_val_loss < best_val_loss and avg_val_loss > 0:
    #         best_val_loss = avg_val_loss
    #         torch.save(model.state_dict(), best_model_path)
    #         print(f"New best model saved with validation loss: {avg_val_loss:.4f}")

    # # 4. Prediction/Inference
    # print("--- Training finished ---")
    print(f"Loading best model from {best_model_path} for prediction...")
    
    # --- NEW: Load the best model's weights ---
    # Ensure model is created before loading weights
    model = get_model(NUM_CLASSES) 
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval() # Set to evaluation mode
    
    test_files = sorted(list(TEST_IMG_DIR.glob("*.tif")))
    
    # --- NEW: Parameter sweep loop ---
    PREDICTION_DIR = Path("combinations-maskrcnn-final")
    PREDICTION_DIR.mkdir(exist_ok=True) # Create the folder if it doesn't exist

    # --- DEFINE YOUR PARAMETERS TO TEST HERE ---
    confidence_threshold_list = [0.1, 0.25, 0.4, 0.5]
    mask_threshold_list = [0.4, 0.5]
    
    print(f"Starting prediction sweep... saving results to {PREDICTION_DIR}")

    for confidence_threshold in confidence_threshold_list:
        for mask_threshold in mask_threshold_list:
            
            print(f"--- Running prediction for conf_thresh={confidence_threshold}, mask_thresh={mask_threshold} ---")
            results = []
            
            for img_path in tqdm(test_files):
                img = Image.open(img_path).convert("RGB")
                # Apply the simple tensor conversion and normalization
                img_tensor = T.ToImage()(img)
                img_tensor = T.ToDtype(torch.float32, scale=True)(img_tensor)
                
                with torch.no_grad():
                    prediction = model([img_tensor.to(DEVICE)])[0]
                    
                # Initialize one instance mask per class
                instance_masks = {class_name: np.zeros((img.height, img.width), dtype=np.int32) for class_name in CLASS_MAP.keys()}
                instance_counters = {class_name: 1 for class_name in CLASS_MAP.keys()}

                # Filter predictions by score
                scores = prediction['scores'].cpu().numpy()
                high_conf_indices = np.where(scores > confidence_threshold)[0]

                for i in high_conf_indices:
                    label_id = prediction['labels'][i].item()
                    if label_id not in INV_CLASS_MAP:
                        continue
                    
                    class_name = INV_CLASS_MAP[label_id]
                    
                    # Get the binary mask by thresholding the soft mask
                    mask = prediction['masks'][i, 0].cpu().numpy()
                    binary_mask = (mask > mask_threshold).astype(np.uint8)
                    
                    # Add instance to the correct class mask with a unique ID
                    instance_id = instance_counters[class_name]
                    # Ensure no overlap: only assign ID where mask is 1 AND current pixel is 0
                    instance_masks[class_name][(binary_mask == 1) & (instance_masks[class_name] == 0)] = instance_id
                    instance_counters[class_name] += 1
                    
                # RLE encode each class mask
                rle_results = {
                    "image_id": img_path.stem,
                    "Epithelial": rle_encode_instance_mask(instance_masks["Epithelial"]),
                    "Lymphocyte": rle_encode_instance_mask(instance_masks["Lymphocyte"]),
                    "Macrophage": rle_encode_instance_mask(instance_masks["Macrophage"]),
                    "Neutrophil": rle_encode_instance_mask(instance_masks["Neutrophil"])
                }
                results.append(rle_results)

            # 5. Create Submission CSV
            csv_filename = PREDICTION_DIR / f"submission_conf{confidence_threshold}_mask{mask_threshold}.csv"
            submission_df = pd.DataFrame(results, columns=["image_id", "Epithelial", "Lymphocyte", "Neutrophil", "Macrophage"])
            submission_df.to_csv(csv_filename, index=False)
            
            print(f"Saved {csv_filename}")
    
    print("--- Prediction sweep finished! ---")



Using device: cpu
Setting up datasets...
Filtering dataset in kaggle-data/train...


100%|██████████| 209/209 [00:02<00:00, 73.99it/s]


Found 209 images with valid annotations out of 209 total.
Filtering dataset in kaggle-data/val...


100%|██████████| 45/45 [00:00<00:00, 80.53it/s]


Found 45 images with valid annotations out of 45 total.
Initializing model...
Loading best model from nuclei_maskrcnn_model_cpu_BEST.pth for prediction...
Starting prediction sweep... saving results to combinations-maskrcnn-final
--- Running prediction for conf_thresh=0.1, mask_thresh=0.4 ---


100%|██████████| 40/40 [01:25<00:00,  2.14s/it]


Saved combinations-maskrcnn-final/submission_conf0.1_mask0.4.csv
--- Running prediction for conf_thresh=0.1, mask_thresh=0.5 ---


100%|██████████| 40/40 [01:12<00:00,  1.80s/it]


Saved combinations-maskrcnn-final/submission_conf0.1_mask0.5.csv
--- Running prediction for conf_thresh=0.25, mask_thresh=0.4 ---


100%|██████████| 40/40 [01:25<00:00,  2.13s/it]


Saved combinations-maskrcnn-final/submission_conf0.25_mask0.4.csv
--- Running prediction for conf_thresh=0.25, mask_thresh=0.5 ---


100%|██████████| 40/40 [01:25<00:00,  2.14s/it]


Saved combinations-maskrcnn-final/submission_conf0.25_mask0.5.csv
--- Running prediction for conf_thresh=0.4, mask_thresh=0.4 ---


100%|██████████| 40/40 [01:27<00:00,  2.19s/it]


Saved combinations-maskrcnn-final/submission_conf0.4_mask0.4.csv
--- Running prediction for conf_thresh=0.4, mask_thresh=0.5 ---


100%|██████████| 40/40 [01:27<00:00,  2.19s/it]


Saved combinations-maskrcnn-final/submission_conf0.4_mask0.5.csv
--- Running prediction for conf_thresh=0.5, mask_thresh=0.4 ---


100%|██████████| 40/40 [01:25<00:00,  2.14s/it]


Saved combinations-maskrcnn-final/submission_conf0.5_mask0.4.csv
--- Running prediction for conf_thresh=0.5, mask_thresh=0.5 ---


100%|██████████| 40/40 [01:24<00:00,  2.12s/it]

Saved combinations-maskrcnn-final/submission_conf0.5_mask0.5.csv
--- Prediction sweep finished! ---



