In [6]:
import os
import glob
import pandas as pd
import numpy as np
import cv2
from PIL import Image
from tqdm.notebook import tqdm
import xml.etree.ElementTree as ET
from sklearn.model_selection import train_test_split

# PyTorch and TorchVision for Deep Learning
# Note: For this single-block example, we will use torchvision's Mask R-CNN as a high-level model placeholder.
# In a full solution, a customized architecture (like Hover-Net or a multi-task U-Net) would be used.
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# --- Configuration and File Paths ---
# IMPORTANT: Adjust this to your actual Kaggle data directory
DATA_DIR = 'kaggle-data'
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
TEST_DIR = os.path.join(DATA_DIR, 'test_final')
TRAIN_CSV_PATH = os.path.join(DATA_DIR, 'train_ground_truth.csv')
SUBMISSION_CSV_PATH = 'submission.csv'

# Define the four target classes
CLASS_NAMES = ['Epithelial', 'Lymphocyte', 'Macrophage', 'Neutrophil']
NUM_CLASSES = len(CLASS_NAMES) + 1 # +1 for background
CLASS_TO_ID = {name: i + 1 for i, name in enumerate(CLASS_NAMES)}
ID_TO_CLASS = {i + 1: name for i, name in enumerate(CLASS_NAMES)}

# Default image size for training/inference. H&E images are large, so they must be tiled.
# We will use a smaller size for this conceptual example.
TILE_SIZE = (512, 512) 

# --- RLE Encoding/Decoding Functions (Provided in Description) ---
def rle_encode_instance_mask(mask: np.ndarray) -> str:
    """
    Convert an instance segmentation mask (H,W) -> RLE triple string.
    0 = background, >0 = instance IDs.
    """
    # Ensure mask is a numpy array of integers
    if mask.dtype != np.int32:
        pixels = mask.flatten(order="F").astype(np.int32)
    else:
        pixels = mask.flatten(order="F")

    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1

    rle = []
    for i in range(0, len(runs) - 1):
        start = runs[i]
        end = runs[i+1] if i+1 < len(runs) else len(pixels)-1
        length = end - start
        val = pixels[start]
        if val > 0:
            # val: Instance ID, start: 1-based start index, length: run length
            rle.extend([val, start, length])

    if not rle:
        return "0" 

    return " ".join(map(str, rle))

def rle_decode_instance_mask(rle: str, shape: tuple[int, int]) -> np.ndarray:
    """
    Convert RLE triple string back into an instance mask of shape (H, W).
    """
    if not rle or str(rle).strip() in ("", "0", "nan"):
        return np.zeros(shape, dtype=np.uint16)
    try:
        s = list(map(int, rle.split()))
    except ValueError:
        return np.zeros(shape, dtype=np.uint16)
        
    mask = np.zeros(shape[0]*shape[1], dtype=np.uint16)
    for i in range(0, len(s), 3):
        val, start, length = s[i], s[i+1], s[i+2]
        # start is 1-based, mask is 0-based
        mask[start-1:start-1+length] = val 
    return mask.reshape(shape, order="F")

# --- XML Annotation Parser to Generate Instance Masks ---
def get_instance_masks_from_xml(xml_path, image_shape, class_to_id):
    """
    Parses the raw XML annotation file to generate a dictionary of 
    instance masks, one 2D array per cell class.
    
    Args:
        xml_path (str): Path to the .xml file.
        image_shape (tuple): (H, W) of the corresponding image.
        class_to_id (dict): Mapping from class name to ID.

    Returns:
        dict: {class_name: 2D numpy array (H, W) of instance masks}
    """
    H, W = image_shape
    instance_masks = {name: np.zeros((H, W), dtype=np.uint16) for name in class_to_id.keys()}
    
    if not os.path.exists(xml_path):
        return instance_masks

    tree = ET.parse(xml_path)
    root = tree.getroot()

    for annotation in root.findall('.//Annotation'):
        class_name_tag = annotation.find('.//Attribute[@Name]')
        if class_name_tag is None:
            continue
            
        class_name = class_name_tag.get('Name')
        if class_name not in class_to_id:
            continue
            
        class_id = class_to_id[class_name]
        instance_counter = 1

        for region in annotation.findall('.//Region'):
            vertices = region.findall('.//Vertex')
            if not vertices:
                continue
                
            # Extract polygon vertices
            polygon = []
            for vertex in vertices:
                # Vertices are given as floating-point in the XML, convert to int
                x = int(float(vertex.get('X')))
                y = int(float(vertex.get('Y')))
                polygon.append((x, y))
            
            # Draw the polygon mask
            if polygon:
                # cv2.fillPoly expects an array of polygons, each polygon being an array of points
                poly_mask = np.zeros((H, W), dtype=np.uint8)
                cv2.fillPoly(poly_mask, [np.array(polygon)], color=instance_counter)
                
                # Add this instance to the class-specific mask array
                # Ensure no overlap with existing instances in this class mask (shouldn't happen with correct XML)
                # We use bitwise OR to combine new instance with existing ones
                current_class_mask = instance_masks[class_name]
                current_class_mask[poly_mask > 0] = instance_counter
                
                instance_counter += 1
                
    return instance_masks

# --- PyTorch Dataset for Training ---
class NucleiDataset(Dataset):
    def __init__(self, image_ids, base_dir, xml_dir, transforms=None):
        self.image_ids = image_ids
        self.base_dir = base_dir
        self.xml_dir = xml_dir
        self.transforms = transforms

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.base_dir, f'{image_id}.tif')
        xml_path = os.path.join(self.xml_dir, f'{image_id}.xml')

        # Load image (H, W, C)
        image = np.array(Image.open(image_path).convert("RGB"))
        H, W, C = image.shape
        
        # NOTE on large images: Real H&E images are often WSI (Whole Slide Images) 
        # and would need to be tiled/patched here to fit into GPU memory. 
        # For simplicity, we assume the provided .tif files are small enough, but 
        # this is a critical simplification for a real project.

        # Get all instance masks for this image
        instance_masks_per_class = get_instance_masks_from_xml(xml_path, (H, W), CLASS_TO_ID)

        # Combine all class masks into a single set of targets for an instance segmentation model
        boxes = []
        labels = []
        masks = []

        # Instance ID must be unique across all classes in the model target format.
        # We process them class by class.
        for class_name, class_id in CLASS_TO_ID.items():
            class_mask = instance_masks_per_class[class_name]
            
            # Find unique instance IDs (excluding background 0)
            instance_ids = np.unique(class_mask)
            instance_ids = instance_ids[instance_ids != 0]

            for inst_id in instance_ids:
                mask = (class_mask == inst_id).astype(np.uint8)
                
                # Get bounding box from mask
                pos = np.where(mask)
                xmin = np.min(pos[1])
                xmax = np.max(pos[1])
                ymin = np.min(pos[0])
                ymax = np.max(pos[0])

                if xmax > xmin and ymax > ymin:
                    boxes.append([xmin, ymin, xmax, ymax])
                    labels.append(class_id)
                    masks.append(mask)

        # Convert to Tensors
        image = torch.as_tensor(image, dtype=torch.float32).permute(2, 0, 1) # C, H, W
        
        target = {}
        target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
        target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
        target["masks"] = torch.as_tensor(np.array(masks), dtype=torch.uint8)
        target["image_id"] = torch.tensor([idx])

        # Apply transforms (e.g., color normalization, augmentation)
        if self.transforms:
             # Using torchvision's transform setup requires a function that takes (image, target)
             # and returns (image, target). A full implementation would use Albumentations 
             # for more robust medical image augmentation.
             pass # Simplified for this example

        return image, target

# --- Model Definition (Mask R-CNN Placeholder) ---
def get_instance_segmentation_model(num_classes):
    # Load a pre-trained Mask R-CNN model (e.g., ResNet50-FPN backbone)
    # The weights are pre-trained on COCO dataset, which is a good starting point.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=torchvision.models.detection.MaskRCNN_ResNet50_FPN_Weights.COCO_V1)

    # Get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the pre-trained box predictor with one that knows our number of classes
    model.roi_heads.box_predictor = FastRCNN(in_features, num_classes)

    # Get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    # Replace the pre-trained mask predictor with one that knows our number of classes
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       model.roi_heads.mask_predictor.conv5_mask.out_channels,
                                                       num_classes)
    return model

# --- Utility: Collate Function for DataLoader ---
def collate_fn(batch):
    return tuple(zip(*batch))

# --- Main Execution Block ---
if __name__ == '__main__':
    # 1. Setup Data Paths and IDs
    train_df = pd.read_csv(TRAIN_CSV_PATH)
    all_image_ids = train_df['image_id'].tolist()
    
    # Simple split for conceptual training/validation
    train_ids, val_ids = train_test_split(all_image_ids, test_size=0.1, random_state=42)
    
    # 2. Setup Dataset and DataLoader
    # NOTE: The full dataset is very large. This uses a conceptual subset/setup.
    train_xml_dir = os.path.join(TRAIN_DIR, 'annotations') # Assuming XMLs are in 'train/annotations'
    train_dataset = NucleiDataset(train_ids, TRAIN_DIR, train_xml_dir)
    
    # 3. Model Initialization (Conceptual)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = get_instance_segmentation_model(NUM_CLASSES)
    model.to(device)
    
    # 4. Conceptual Training Loop (Skipped for brevity, but this is where you handle wPQ optimization)
    # In a real scenario, you would set model.train(), define an optimizer, 
    # use a custom loss or sampler to handle class imbalance, and train for epochs.
    # Training code:
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
    # train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
    # ... train loop ...
    
    # 5. Inference on Test Set
    test_image_paths = sorted(glob.glob(os.path.join(TEST_DIR, '*.tif')))
    test_image_ids = [os.path.basename(p).replace('.tif', '') for p in test_image_paths]

    submission_rows = []
    
    # Set model to evaluation mode
    model.eval()

    print(f"Starting inference on {len(test_image_ids)} test images...")
    for image_id, image_path in tqdm(zip(test_image_ids, test_image_paths), total=len(test_image_ids)):
        try:
            image = np.array(Image.open(image_path).convert("RGB"))
        except FileNotFoundError:
            # Handle case where file might be missing or path is wrong
            print(f"Image file not found for {image_id}. Skipping.")
            continue
            
        H, W, C = image.shape
        # Preprocessing: Convert to PyTorch tensor (C, H, W) and normalize (conceptual)
        input_tensor = torch.as_tensor(image, dtype=torch.float32).permute(2, 0, 1).to(device)
        input_tensor = input_tensor / 255.0 # Simple normalization

        # Run inference
        with torch.no_grad():
            prediction = model([input_tensor])

        if not prediction or not prediction[0]:
            # If no predictions, all RLEs are "0"
            rle_results = {name: "0" for name in CLASS_NAMES}
        else:
            pred = prediction[0]
            # Use a threshold for mask/detection confidence (e.g., 0.5)
            # and a mask threshold (e.g., 0.5) to get binary masks
            keep = pred['scores'] > 0.5
            
            masks = pred['masks'][keep].squeeze(1) # [N, H, W]
            labels = pred['labels'][keep]
            
            # The output masks are floats (logits), convert to binary
            masks = (masks > 0.5).cpu().numpy().astype(np.uint8)

            rle_results = {name: "0" for name in CLASS_NAMES}
            
            # Aggregate instance masks by class
            for class_idx in range(1, NUM_CLASSES):
                class_name = ID_TO_CLASS[class_idx]
                
                # Filter masks and labels for the current class
                class_indices = (labels == class_idx).nonzero(as_tuple=True)[0]
                
                if len(class_indices) > 0:
                    class_masks = masks[class_indices]
                    
                    # Create the final class-specific instance map
                    # The value of each pixel must be the *unique instance ID* (1, 2, 3...)
                    class_instance_map = np.zeros((H, W), dtype=np.uint16)
                    
                    for i, mask in enumerate(class_masks):
                        # i + 1 is the unique instance ID for this class
                        instance_id = i + 1
                        class_instance_map[mask > 0] = instance_id 

                    # Convert the instance map to RLE string
                    rle_string = rle_encode_instance_mask(class_instance_map)
                    rle_results[class_name] = rle_string

        # Collect results for the submission DataFrame
        row = {'image_id': image_id}
        row.update(rle_results)
        submission_rows.append(row)
        
    # 6. Generate Submission CSV
    submission_df = pd.DataFrame(submission_rows)
    
    # Crucial Step: Sort the submission by image_id lexicographically
    submission_df = submission_df.sort_values(by='image_id').reset_index(drop=True)
    
    # Re-order columns to match required submission format: image_id,Epithelial,Lymphocyte,Neutrophil,Macrophage
    # Note the specific order of Neutrophil and Macrophage.
    final_cols = ['image_id', 'Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage']
    submission_df = submission_df[final_cols]
    
    # Replace any NaN RLEs with the required string "0"
    submission_df = submission_df.fillna("0")
    
    # Save the final submission file
    submission_df.to_csv(SUBMISSION_CSV_PATH, index=False)
    print(f"\nSubmission file created: {SUBMISSION_CSV_PATH}")
    print("\nFirst few rows of submission:")
    print(submission_df.head())

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /Users/swooshie/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
100%|██████████| 170M/170M [00:02<00:00, 64.1MB/s] 


NameError: name 'FastRCNN' is not defined