In [1]:
!pip install cellpose stardist scikit-image scikit-learn opencv-python torch torchvision pandas tqdm scipy joblib


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import os
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import cv2
import warnings
warnings.filterwarnings('ignore')

# Import pretrained models
from cellpose import models as cellpose_models
from cellpose.models import CellposeModel
import stardist
from stardist.models import StarDist2D
from stardist.data import test_image_nuclei
import xml.etree.ElementTree as ET
from scipy import ndimage
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
import joblib

# Select the compute device: Apple-silicon MPS when PyTorch reports it as
# available, otherwise plain CPU. The choice is announced on stdout.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print("Using MPS (Metal Performance Shaders)" if device.type == 'mps' else "Using CPU")

# ============================================================================
# 1. RLE ENCODING/DECODING
# ============================================================================

def rle_decode_instance_mask(rle: str, shape: tuple) -> np.ndarray:
    """Rebuild an instance mask from its run-length encoded string.

    The string is a whitespace-separated sequence of ``value start length``
    triples, where ``start`` is a 1-based index into the column-major
    (Fortran-order) flattening of the mask.

    Args:
        rle: Encoded mask. A falsy value, or anything whose string form
            strips to ``""``, ``"0"`` or ``"nan"``, decodes to an empty mask.
        shape: ``(height, width)`` of the mask to produce.

    Returns:
        A ``uint16`` array of the given shape; unlisted pixels are 0.
    """
    empty_markers = ("", "0", "nan")
    if not rle or str(rle).strip() in empty_markers:
        return np.zeros(shape, dtype=np.uint16)

    tokens = [int(tok) for tok in rle.split()]
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint16)

    cursor = 0
    while cursor < len(tokens):
        label, first, run = tokens[cursor], tokens[cursor + 1], tokens[cursor + 2]
        begin = first - 1  # encoded positions are 1-based
        flat[begin:begin + run] = label
        cursor += 3

    return flat.reshape(shape, order="F")

def rle_encode_instance_mask(mask: np.ndarray) -> str:
    """Run-length encode an instance mask.

    The mask is flattened in column-major (Fortran) order and every run of a
    constant non-zero value is written as ``value start length`` with a
    1-based ``start``. Background (0) runs are omitted.

    Args:
        mask: Integer instance mask.

    Returns:
        The space-separated triples, or ``"0"`` when the mask has no
        non-zero pixel.
    """
    flat = mask.flatten(order="F").astype(np.int32)
    # Zero sentinels on both ends make every run start and end at a value
    # change; an index into the padded array equals the 1-based pixel index.
    padded = np.concatenate([[0], flat, [0]])
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1

    tokens = []
    for begin, stop in zip(boundaries[:-1], boundaries[1:]):
        label = padded[begin]
        if label > 0:
            tokens += [str(label), str(begin), str(stop - begin)]

    return " ".join(tokens) if tokens else "0"

# ============================================================================
# 2. XML ANNOTATION PARSING
# ============================================================================

def parse_xml_annotations(xml_path):
    """Parse polygon annotations for the four supported cell types.

    Reads every ``Annotation`` element directly under the XML root whose
    ``Name`` attribute is one of ``Epithelial``, ``Lymphocyte``,
    ``Macrophage`` or ``Neutrophil``. Each direct ``Region`` child becomes
    one polygon built from its direct ``Vertex`` children (``X``/``Y``
    attributes).

    Regions from several ``Annotation`` elements that share the same name
    are accumulated rather than overwritten, and decimal coordinates
    (e.g. ``"12.5"``) are accepted and rounded to the nearest pixel.

    Args:
        xml_path: Path or file object accepted by ``ElementTree.parse``.

    Returns:
        Dict mapping cell-type name to a list of ``int32`` arrays of shape
        ``(n_vertices, 2)`` holding ``[x, y]`` pairs. Cell types with no
        matching ``Annotation`` element are absent from the dict.
    """
    wanted_types = ('Epithelial', 'Lymphocyte', 'Macrophage', 'Neutrophil')
    root = ET.parse(xml_path).getroot()

    annotations = {}
    for annotation in root.findall('Annotation'):
        cell_type = annotation.get('Name')
        if cell_type not in wanted_types:
            continue

        # setdefault keeps regions collected from an earlier Annotation
        # element that carried the same Name.
        regions = annotations.setdefault(cell_type, [])
        for region in annotation.findall('Region'):
            vertices = [
                # float() first so that "12.5"-style coordinates parse
                [int(round(float(vertex.get('X')))),
                 int(round(float(vertex.get('Y'))))]
                for vertex in region.findall('Vertex')
            ]
            if vertices:
                regions.append(np.array(vertices, dtype=np.int32))

    return annotations

def create_instance_masks_from_xml(xml_path, image_shape):
    """Rasterise the XML polygons into one instance mask per cell type.

    Args:
        xml_path: Annotation file, forwarded to ``parse_xml_annotations``.
        image_shape: Shape of the matching image; only the first two
            entries (height, width) are used.

    Returns:
        Dict with the keys ``Epithelial``, ``Lymphocyte``, ``Macrophage``
        and ``Neutrophil``. Each value is a ``uint16`` mask in which the
        k-th polygon of that type (counting from 1) is filled with k; a
        type with no polygons yields an all-zero mask.
    """
    polygons_by_type = parse_xml_annotations(xml_path)
    height_width = image_shape[:2]

    masks = {}
    for type_name in ('Epithelial', 'Lymphocyte', 'Macrophage', 'Neutrophil'):
        canvas = np.zeros(height_width, dtype=np.uint16)
        label = 0
        for outline in polygons_by_type.get(type_name, ()):
            label += 1
            # thickness -1 fills the polygon interior with the label value
            cv2.drawContours(canvas, [outline], 0, label, -1)
        masks[type_name] = canvas

    return masks

# ============================================================================
# 3. FEATURE EXTRACTION FOR CLASSIFICATION
# ============================================================================

def extract_nucleus_features(image, instance_mask, nucleus_id):
    """Compute intensity and shape features for one labelled nucleus.

    Args:
        image: Intensity image indexed by the same pixel grid as
            ``instance_mask``.
        instance_mask: Label image; assumed 2-D (rows, columns).
        nucleus_id: Label value selecting the nucleus.

    Returns:
        ``None`` if no pixel carries ``nucleus_id``; otherwise a dict with,
        in this order, ``mean_intensity``, ``std_intensity``,
        ``min_intensity``, ``max_intensity``, ``area`` (pixel count),
        ``perimeter`` (number of pixel edges separating the nucleus from
        everything else) and ``solidity``.

        ``solidity`` is the area divided by the area of the axis-aligned
        bounding box (strictly speaking the *extent*); the key name is kept
        so existing feature vectors keep their layout.
    """
    nucleus_region = (instance_mask == nucleus_id)

    coords = np.nonzero(nucleus_region)
    if coords[0].size == 0:
        return None

    y_min, y_max = coords[0].min(), coords[0].max()
    x_min, x_max = coords[1].min(), coords[1].max()

    nucleus_pixels = image[nucleus_region]
    area = nucleus_region.sum()
    bbox_area = (y_max - y_min + 1) * (x_max - x_min + 1)

    # Perimeter as boundary-edge count: crop to the bounding box, add a
    # one-pixel background frame, and count every horizontal or vertical
    # neighbour pair that differs. The previous code passed np.where()
    # output to cv2.contourArea, which is neither a contour nor a
    # perimeter measure.
    framed = np.pad(
        nucleus_region[y_min:y_max + 1, x_min:x_max + 1],
        1,
        mode="constant",
        constant_values=False,
    )
    perimeter = float(
        np.count_nonzero(framed[1:, :] != framed[:-1, :])
        + np.count_nonzero(framed[:, 1:] != framed[:, :-1])
    )

    return {
        'mean_intensity': nucleus_pixels.mean(),
        'std_intensity': nucleus_pixels.std(),
        'min_intensity': nucleus_pixels.min(),
        'max_intensity': nucleus_pixels.max(),
        'area': area,
        'perimeter': perimeter,
        'solidity': area / bbox_area,
    }

def extract_training_features(train_dir, train_csv, xml_dir):
    """Build the classifier training set from annotated training images.

    For every ``image_id`` listed in ``train_csv`` the function looks for
    ``<train_dir>/<image_id>.tif`` and ``<xml_dir>/<image_id>.xml``; rows
    whose files are missing or whose image cannot be read are skipped.
    Each annotated nucleus contributes one feature row (computed on the
    grayscale image) and one integer label.

    Args:
        train_dir: Directory holding the ``.tif`` images.
        train_csv: CSV file with an ``image_id`` column.
        xml_dir: Directory holding the ``.xml`` annotations.

    Returns:
        Tuple ``(features, labels)``: the feature table as a NumPy array and
        an array of class indices, where the index is the position of the
        cell type in ``[Epithelial, Lymphocyte, Macrophage, Neutrophil]``.
    """
    class_names = ['Epithelial', 'Lymphocyte', 'Macrophage', 'Neutrophil']
    table = pd.read_csv(train_csv)

    feature_rows = []
    targets = []

    progress = tqdm(table.iterrows(), total=len(table), desc="Extracting features")
    for _, record in progress:
        image_id = record['image_id']
        tif_file = Path(train_dir) / f"{image_id}.tif"
        xml_file = Path(xml_dir) / f"{image_id}.xml"
        if not (tif_file.exists() and xml_file.exists()):
            continue

        bgr = cv2.imread(str(tif_file))
        if bgr is None:
            continue

        # Intensity features are computed on a single-channel image
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        masks_by_type = create_instance_masks_from_xml(xml_file, bgr.shape)

        for type_name, type_mask in masks_by_type.items():
            for nucleus_id in np.unique(type_mask):
                if nucleus_id == 0:  # background
                    continue
                feat = extract_nucleus_features(gray, type_mask, nucleus_id)
                if not feat:
                    continue
                feature_rows.append(feat)
                targets.append(class_names.index(type_name))

    return pd.DataFrame(feature_rows).values, np.array(targets)

# ============================================================================
# 4. SEGMENTATION WITH CELLPOSE
# ============================================================================

class CellposeSegmenter:
    """Nuclei segmentation using the pretrained Cellpose ``nuclei`` model."""

    def __init__(self, device='cpu'):
        """Load the pretrained model on the requested device.

        Args:
            device: Device name (``'cpu'``, ``'mps'``, ``'cuda'``) or a
                ``torch.device``. It is stored unchanged on ``self.device``.
        """
        self.device = device
        # Normalise to a torch.device: a torch.device never compares equal
        # to a plain string, and Cellpose is handed a real device object
        # rather than a name.
        torch_device = torch.device(device)
        self.model = CellposeModel(
            gpu=torch_device.type in ('mps', 'cuda'),
            model_type='nuclei',
            device=torch_device,
        )

    def segment(self, image):
        """Segment nuclei in an image.

        Args:
            image: Either a 3-channel image in OpenCV's BGR channel order
                or an array of any other shape (e.g. 2-D grayscale), which
                is passed through unchanged.

        Returns:
            The first element returned by the model's ``eval`` call (the
            label mask).
        """
        if image.ndim == 3 and image.shape[2] == 3:
            # Reverse the channel axis (BGR -> RGB). The original branch
            # was inverted: colour images were left as BGR and everything
            # else was sent through a BGR->RGB conversion that needs three
            # channels.
            img = np.ascontiguousarray(image[:, :, ::-1])
        else:
            img = image

        # Index instead of unpacking a fixed number of values, so the call
        # works whether eval returns three or four outputs.
        outputs = self.model.eval(img, diameter=30, channels=[0, 0])
        return outputs[0]

# ============================================================================
# 5. SEGMENTATION WITH STARDIST
# ============================================================================

class StarDistSegmenter:
    """Nuclei segmentation using a pretrained StarDist 2D model."""

    def __init__(self, model_name='2D_versatile_fluo'):
        """Load a pretrained StarDist2D model.

        Args:
            model_name: Name or alias of a registered pretrained model.
                The default is a single-channel model, matching the
                grayscale input produced by ``segment``. The previous
                hard-coded name ``'2D_nuclei_obj'`` does not appear among
                StarDist's registered 2D models.

        Raises:
            RuntimeError: If ``from_pretrained`` returns ``None``, so the
                failure surfaces here rather than as an ``AttributeError``
                on the first prediction.
        """
        self.model = StarDist2D.from_pretrained(model_name)
        if self.model is None:
            raise RuntimeError(
                f"StarDist pretrained model '{model_name}' could not be loaded"
            )

    def segment(self, image):
        """Segment nuclei in an image.

        Args:
            image: A 3-channel BGR image (converted to grayscale) or any
                other array, which is used as is.

        Returns:
            The label image returned by ``predict_instances``.
        """
        if len(image.shape) == 3 and image.shape[2] == 3:
            img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            img = image

        # Scale by 1/255 (maps 8-bit input to 0-1) and cast to float32
        img = (img / 255.0).astype(np.float32)

        labels, _ = self.model.predict_instances(img)

        return labels

# ============================================================================
# 6. ENSEMBLE SEGMENTATION
# ============================================================================

# Segmenter instances keyed by model kind, so the pretrained networks are
# loaded once per process instead of once per image.
_SEGMENTER_CACHE = {}


def _get_cached_segmenter(key, factory):
    """Return the cached segmenter for ``key``, building it on first use."""
    if key not in _SEGMENTER_CACHE:
        _SEGMENTER_CACHE[key] = factory()
    return _SEGMENTER_CACHE[key]


def _merge_instance_masks(label_maps, overlap_threshold=0.5):
    """Union several label images, skipping nuclei that were already found.

    Label images are processed in order. A nucleus is dropped when more than
    ``overlap_threshold`` of its pixels are already labelled; otherwise it
    receives a fresh id and is painted only onto still-free pixels, so
    earlier detections are never overwritten or split.
    """
    merged = np.zeros(label_maps[0].shape, dtype=np.int32)
    next_id = 1

    for label_map in label_maps:
        for nucleus_id in np.unique(label_map):
            if nucleus_id == 0:  # background
                continue
            region = (label_map == nucleus_id)
            if (merged[region] > 0).mean() > overlap_threshold:
                continue  # same nucleus already contributed by an earlier model
            merged[region & (merged == 0)] = next_id
            next_id += 1

    return merged


def ensemble_segment(image, use_cellpose=True, use_stardist=True):
    """Segment nuclei with Cellpose and/or StarDist and combine the results.

    A model that raises is reported on stdout and skipped. With a single
    successful model its mask is returned unchanged. With two, the masks are
    merged: every Cellpose nucleus is kept, and a StarDist nucleus is added
    only if at most half of its pixels are already covered. The merged mask
    is ``int32`` with consecutive ids starting at 1.

    Args:
        image: Image array as read by ``cv2.imread``.
        use_cellpose: Whether to run the Cellpose model.
        use_stardist: Whether to run the StarDist model.

    Returns:
        Instance label mask (0 = background).

    Raises:
        RuntimeError: If no model produced a result.
    """
    results = []

    if use_cellpose:
        try:
            segmenter = _get_cached_segmenter(
                ('cellpose', str(device)),
                lambda: CellposeSegmenter(device=str(device)),
            )
            results.append(segmenter.segment(image))
        except Exception as e:
            print(f"Cellpose failed: {e}")

    if use_stardist:
        try:
            segmenter = _get_cached_segmenter(('stardist',), StarDistSegmenter)
            results.append(segmenter.segment(image))
        except Exception as e:
            print(f"StarDist failed: {e}")

    if not results:
        raise RuntimeError("Both segmentation models failed")

    if len(results) == 1:
        return results[0]

    return _merge_instance_masks(results)

# ============================================================================
# 7. CELL TYPE CLASSIFICATION
# ============================================================================

class NucleusClassifier:
    """Random-forest cell-type classifier with standardised input features."""

    def __init__(self):
        # Nothing is usable until train() or load() has run
        self.is_fitted = False
        self.scaler = StandardScaler()
        self.model = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)

    def train(self, features, labels):
        """Fit the scaler and the forest on the given features and labels."""
        standardised = self.scaler.fit_transform(features)
        self.model.fit(standardised, labels)
        self.is_fitted = True

    def predict(self, features):
        """Return predicted labels for ``features``.

        Raises:
            Exception: If called before ``train`` or ``load``.
        """
        if self.is_fitted:
            return self.model.predict(self.scaler.transform(features))
        raise Exception("Classifier not trained")

    def save(self, path):
        """Write the model and scaler to ``path`` with joblib."""
        payload = {'model': self.model, 'scaler': self.scaler}
        joblib.dump(payload, path)

    def load(self, path):
        """Restore the model and scaler from ``path`` and mark as fitted."""
        payload = joblib.load(path)
        self.model = payload['model']
        self.scaler = payload['scaler']
        self.is_fitted = True

# ============================================================================
# 8. INFERENCE PIPELINE
# ============================================================================

def predict_image_with_classification(image_path, segmenter, classifier, device):
    """Segment an image and sort the detected nuclei into cell-type masks.

    Args:
        image_path: Path of the image to read with ``cv2.imread``.
        segmenter: Optional object with a ``segment(image)`` method. When
            ``None`` the ensemble of both pretrained models is used
            (previously this argument was ignored).
        classifier: Object whose ``predict`` maps a 2-D feature array to
            class indices into
            ``[Epithelial, Lymphocyte, Macrophage, Neutrophil]``.
        device: Unused; kept for interface compatibility.

    Returns:
        Dict mapping each of the four cell-type names to an instance mask
        shaped like the segmentation result, with ids numbered from 1
        independently per type.

    Raises:
        ValueError: If the image cannot be read.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        raise ValueError(f"Cannot read image: {image_path}")

    # Segment
    if segmenter is not None:
        instance_mask = segmenter.segment(image)
    else:
        instance_mask = ensemble_segment(image)

    # Features are computed on a single-channel image
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    cell_types = ['Epithelial', 'Lymphocyte', 'Macrophage', 'Neutrophil']
    classified_masks = {ct: np.zeros_like(instance_mask) for ct in cell_types}

    # Collect the features of every nucleus first so the classifier is
    # invoked once per image rather than once per nucleus.
    nucleus_ids = []
    feature_rows = []
    for nucleus_id in np.unique(instance_mask):
        if nucleus_id == 0:  # background
            continue
        feat = extract_nucleus_features(gray, instance_mask, nucleus_id)
        if feat is None:
            continue
        nucleus_ids.append(nucleus_id)
        feature_rows.append(list(feat.values()))

    if not feature_rows:
        return classified_masks

    predicted_labels = classifier.predict(np.array(feature_rows))

    current_ids = {ct: 1 for ct in cell_types}
    for nucleus_id, predicted_label in zip(nucleus_ids, predicted_labels):
        cell_type = cell_types[int(predicted_label)]
        classified_masks[cell_type][instance_mask == nucleus_id] = current_ids[cell_type]
        current_ids[cell_type] += 1

    return classified_masks

# ============================================================================
# 9. MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    # Dataset layout: everything lives under one data directory
    data_dir = Path("kaggle-data")
    train_dir, val_dir, test_dir = (
        data_dir / "train",
        data_dir / "val",
        data_dir / "test_final",
    )
    train_csv = data_dir / "train_ground_truth.csv"
    val_csv = data_dir / "val_truth.csv"

    # Column order of the submission file (after image_id)
    submission_columns = ['Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage']
    banner = "=" * 80

    print(banner)
    print("CELL SEGMENTATION & CLASSIFICATION PIPELINE")
    print(banner)

    # ---- Step 1: fit the cell-type classifier -----------------------------
    print("\n[1/4] Extracting features for classifier training...")
    try:
        # The training directory is passed as the XML directory as well
        X_train, y_train = extract_training_features(train_dir, train_csv, train_dir)

        classifier = NucleusClassifier()
        print(f"Training classifier on {len(X_train)} nuclei...")
        classifier.train(X_train, y_train)
        classifier.save("nucleus_classifier.pkl")
        print("✓ Classifier trained and saved")
    except Exception as e:
        # Any failure leaves the pipeline in segmentation-only mode
        print(f"⚠ Classifier training failed: {e}")
        classifier = None

    # ---- Step 2: smoke-test on the first three validation images ----------
    print("\n[2/4] Validating on validation set...")
    try:
        for img_path in sorted(val_dir.glob("*.tif"))[:3]:
            print(f"Processing {img_path.stem}...")
            masks = predict_image_with_classification(img_path, None, classifier, device)
            nuclei_total = sum(np.max(m) for m in masks.values())
            print(f"  Found: {nuclei_total} nuclei")
    except Exception as e:
        print(f"⚠ Validation failed: {e}")

    # ---- Step 3: predict every test image ---------------------------------
    print("\n[3/4] Generating predictions on test set...")
    test_images = sorted(test_dir.glob("*.tif"))

    results = []
    for img_path in tqdm(test_images, desc="Predicting"):
        img_id = img_path.stem

        try:
            if classifier and classifier.is_fitted:
                masks = predict_image_with_classification(img_path, None, classifier, device)
            else:
                # No classifier: report every detection as Epithelial
                image = cv2.imread(str(img_path))
                instance_mask = ensemble_segment(image)
                masks = {
                    name: (instance_mask if name == 'Epithelial' else np.zeros_like(instance_mask))
                    for name in ('Epithelial', 'Lymphocyte', 'Macrophage', 'Neutrophil')
                }

            row = {'image_id': img_id}
            row.update(
                (name, rle_encode_instance_mask(masks[name])) for name in submission_columns
            )
        except Exception as e:
            # A failed image still gets a row, with every mask empty
            print(f"⚠ Prediction failed for {img_id}: {e}")
            row = {'image_id': img_id, **dict.fromkeys(submission_columns, "0")}

        results.append(row)

    # ---- Step 4: write the submission file --------------------------------
    print("\n[4/4] Saving submission...")
    submission_df = pd.DataFrame(results)
    submission_df.to_csv("submission.csv", index=False)
    print(f"✓ Submission saved: {len(results)} images")
    print(banner)