In [7]:
import os
import xml.etree.ElementTree as ET
import numpy as np
import cv2
from glob import glob
from tifffile import imread
import logging
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from scipy import ndimage as ndi  # <-- IMPORTING SCIPY FOR HV MAPS

# --- HoVer-Net Library Imports ---
from cellseg_models_pytorch import models
from cellseg_models_pytorch.models import hovernet
from cellseg_models_pytorch import losses
# *** ALL BROKEN UTILS IMPORTS ARE REMOVED ***

# --- Configuration ---
# Settings for the training run: dataset folders, checkpoint destination,
# class vocabulary, and optimisation hyper-parameters.
CONFIG = {
    # Dataset root; main() joins "train_dir" / "val_dir" onto it.
    "base_dir": "kaggle-data",
    "train_dir": "train",
    "val_dir": "val",
    # Folder and file name the best checkpoint is written to.
    "model_dir": os.path.join("kaggle-data", "hovernet_model_mps"),
    "model_name": "hovernet_best.pth",
    # Annotation names recognised by the XML parser; list position + 1 is the class ID.
    "class_names": ["Epithelial", "Lymphocyte", "Macrophage", "Neutrophil"],
    # Training parameters
    "n_epochs": 100,
    "batch_size": 1, # Start with 1. Increase to 2 or 4 if your M3 Pro can handle it.
    "lr": 1e-4,
    "num_classes": 5 # 4 classes + 1 background
}

# -----------------------------------------------------------------------------
# HV-MAP TARGET GENERATION (local stand-in for the library's 'prep_data' helper)
# -----------------------------------------------------------------------------

def compute_hv_maps(inst_map: np.ndarray) -> np.ndarray:
    """
    Computes the Horizontal/Vertical (HV) maps for HoVer-Net.

    For every labelled instance, each pixel stores its offset to the
    instance centroid (centroid minus pixel coordinate). Each axis is
    scaled independently by the instance's largest absolute offset, so the
    values fall in [-1, 1]. Background pixels stay at 0.

    Args:
        inst_map (np.ndarray): Integer instance map (H, W). 0 is background
            and every instance carries one positive ID. IDs do not have to
            be contiguous: IDs that own no pixels are skipped.

    Returns:
        np.ndarray: float32 HV map (H, W, 2). Channel 0 holds the vertical
        (y) offsets and channel 1 the horizontal (x) offsets.
    """
    # NOTE(review): the sign convention (centroid - pixel) and the channel
    # order (y first) are kept as they were; confirm they match what the loss
    # and the post-processing of the model library expect.
    height, width = inst_map.shape
    hv_map = np.zeros((height, width, 2), dtype=np.float32)

    # Nothing to do for a zero-sized image.
    if inst_map.size == 0:
        return hv_map

    # int() sidesteps fixed-width wrap-around when the map is e.g. uint16.
    max_id = int(inst_map.max())
    if max_id < 1:
        return hv_map

    # One bounding box per ID in 1..max_id; None marks an ID with no pixels
    # (for example a polygon that was completely painted over by a later one).
    # Working inside the box avoids scanning the whole image per instance.
    boxes = ndi.find_objects(inst_map, max_label=max_id)

    for inst_id, box in enumerate(boxes, start=1):
        if box is None:
            continue

        row_slice, col_slice = box
        local_rows, local_cols = np.nonzero(inst_map[box] == inst_id)
        rows = local_rows + row_slice.start
        cols = local_cols + col_slice.start

        # Offset from each pixel to the (unweighted) centroid of the instance.
        offset_y = rows.mean() - rows
        offset_x = cols.mean() - cols

        # Normalise per axis; a one-pixel-wide extent has scale 0 and is
        # left at 0 rather than divided.
        scale_y = np.abs(offset_y).max()
        scale_x = np.abs(offset_x).max()
        if scale_y > 0:
            offset_y /= scale_y
        if scale_x > 0:
            offset_x /= scale_x

        hv_map[rows, cols, 0] = offset_y
        hv_map[rows, cols, 1] = offset_x

    return hv_map

# -----------------------------------------------------------------------------
# DATA LOADING FUNCTIONS
# -----------------------------------------------------------------------------

def parse_xml_for_hovernet(xml_path: str) -> (list, list):
    """
    Collects annotated polygons and their class IDs from one XML file.

    An <Annotation> is kept when its own 'Name' attribute, or failing that
    the 'Name' of its first <Attribute> descendant, is one of
    CONFIG["class_names"]. Every <Region> below a kept annotation that has
    at least one <Vertex> contributes one polygon.

    Args:
        xml_path (str): Path of the annotation file. A missing file is not
            an error and yields two empty lists.

    Returns:
        (list, list): int32 vertex arrays of shape (N, 2) in (x, y) order,
        and the matching 1-based class IDs.
    """
    # Class IDs start at 1 so that 0 stays free for background.
    label_to_id = {label: idx for idx, label in enumerate(CONFIG["class_names"], start=1)}

    polygons, class_ids = [], []
    if not os.path.exists(xml_path):
        return polygons, class_ids

    for annotation in ET.parse(xml_path).getroot().findall('.//Annotation'):
        # Prefer the annotation's own name; otherwise fall back to its first Attribute.
        label = annotation.attrib.get('Name')
        if label not in label_to_id:
            first_attribute = annotation.find('.//Attribute')
            label = None if first_attribute is None else first_attribute.attrib.get('Name')
            if label not in label_to_id:
                label = None

        if not label:
            continue

        class_id = label_to_id[label]
        for region in annotation.findall('.//Region'):
            points = [
                [float(vertex.get('X')), float(vertex.get('Y'))]
                for vertex in region.findall('.//Vertex')
            ]
            if points:
                # The int32 cast truncates the float coordinates.
                polygons.append(np.array(points, dtype=np.int32))
                class_ids.append(class_id)

    return polygons, class_ids

def create_hovernet_maps(polygons: list, class_ids: list, height: int, width: int) -> (np.ndarray, np.ndarray):
    """
    Rasterises the polygons into the two label images used downstream.

    Args:
        polygons (list): Vertex arrays, one per nucleus.
        class_ids (list): Class ID of each polygon, in the same order.
        height (int): Number of rows of the output maps.
        width (int): Number of columns of the output maps.

    Returns:
        (np.ndarray, np.ndarray): The uint16 instance map, where the k-th
        polygon is filled with the value k (starting at 1), and the uint8
        type map, where each polygon is filled with its class ID. Pixels
        outside every polygon are 0; later polygons paint over earlier ones.
    """
    shape = (height, width)
    inst_map = np.zeros(shape, dtype=np.uint16)
    type_map = np.zeros(shape, dtype=np.uint8)

    paired = zip(polygons, class_ids)
    for instance_id, (polygon, class_id) in enumerate(paired, start=1):
        cv2.fillPoly(inst_map, [polygon], instance_id)
        cv2.fillPoly(type_map, [polygon], class_id)

    return inst_map, type_map

class NucleiDataset(Dataset):
    """
    Dataset over the '*.tif' images of one folder.

    Each item is built on demand from the image and the '.xml' file that
    shares its base name, so only one sample is processed at a time.
    Samples that have no usable annotation, or that fail to load, come back
    as None and are expected to be dropped by the collate function.
    """
    def __init__(self, data_dir):
        pattern = os.path.join(data_dir, '*.tif')
        # Sorted so that indices are stable between runs.
        self.image_files = sorted(glob(pattern))
        print(f"Dataset for {os.path.basename(data_dir)} found {len(self.image_files)} images.")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        folder, file_name = os.path.split(img_path)
        stem = os.path.splitext(file_name)[0]
        xml_path = os.path.join(folder, f"{stem}.xml")

        try:
            # Read the image and drop a fourth (alpha) channel if present.
            image = imread(img_path)
            has_alpha = image.ndim == 3 and image.shape[-1] == 4
            if has_alpha:
                image = image[..., :3]
            height, width = image.shape[:2]

            # No polygons means there is nothing to train on for this image.
            polygons, class_ids = parse_xml_for_hovernet(xml_path)
            if not polygons:
                return None

            # Label images first, then the targets derived from them.
            inst_map, type_map = create_hovernet_maps(polygons, class_ids, height, width)
            hv_map = compute_hv_maps(inst_map)
            np_map = (inst_map > 0).astype(np.uint8)  # binary nucleus / background

            # Per-image min-max scaling to [0, 1], then HWC -> CHW.
            lowest, highest = image.min(), image.max()
            image = (image - lowest) / (highest - lowest + 1e-6)
            image = torch.from_numpy(image).permute(2, 0, 1).float()

            targets = {
                "np_map": torch.from_numpy(np_map).long(),      # (H, W)
                "hv_map": torch.from_numpy(hv_map).float(),     # (H, W, 2)
                "type_map": torch.from_numpy(type_map).long(),  # (H, W)
            }
            return image, targets

        except Exception as e:
            # Best effort: report the failure and let the collate step drop the sample.
            print(f"Error loading {img_path}: {e}. Skipping.")
            return None

def custom_collate_fn(batch):
    """
    Collates a batch after dropping the samples that are None.

    Returns (None, None) when every sample was dropped, so callers can
    unpack the result and skip the batch; otherwise returns whatever the
    default PyTorch collate produces for the remaining samples.
    """
    kept = [sample for sample in batch if sample is not None]
    if kept:
        return torch.utils.data.dataloader.default_collate(kept)
    return None, None

# -----------------------------------------------------------------------------
# MANUAL TRAINING FUNCTIONS
# -----------------------------------------------------------------------------

def train_one_epoch(model, loader, optimizer, criterion, device):
    """
    Performs one full training epoch.

    Args:
        model: Network to train; called as model(images).
        loader: Iterable of (images, targets) pairs. A pair whose images
            entry is None marks an empty batch and is skipped.
        optimizer: Optimizer that owns the model parameters.
        criterion: Callable as criterion(outputs, targets) returning a
            scalar loss tensor.
        device: Device that the images and every target tensor are moved to.

    Returns:
        float: Mean loss over the batches that were actually processed, or
        NaN when no batch was processed (empty loader or all batches
        skipped).
    """
    model.train()  # Set model to training mode
    running_loss = 0.0
    processed = 0

    # Use tqdm for a progress bar
    for images, targets in tqdm(loader, desc="Training"):
        # Handle empty batches from collate_fn
        if images is None:
            continue

        # Move the batch to the training device
        images = images.to(device)
        targets = {k: v.to(device) for k, v in targets.items()}

        # 1. Clear gradients
        optimizer.zero_grad()

        # 2. Forward pass
        outputs = model(images)

        # 3. Calculate loss
        loss = criterion(outputs, targets)

        # 4. Backward pass
        loss.backward()

        # 5. Update weights
        optimizer.step()

        running_loss += loss.item()
        processed += 1

    # Average over processed batches only: dividing by len(loader) would
    # count skipped batches and fail on an empty loader.
    if processed == 0:
        return float("nan")
    return running_loss / processed

def evaluate(model, loader, criterion, device):
    """
    Performs one full validation epoch without gradient tracking.

    Args:
        model: Network to evaluate; called as model(images).
        loader: Iterable of (images, targets) pairs. A pair whose images
            entry is None marks an empty batch and is skipped.
        criterion: Callable as criterion(outputs, targets) returning a
            scalar loss tensor.
        device: Device that the images and every target tensor are moved to.

    Returns:
        float: Mean loss over the batches that were actually processed, or
        NaN when no batch was processed. NaN never compares as smaller than
        a previous best, so an empty validation pass cannot be mistaken for
        an improvement.
    """
    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    processed = 0

    with torch.no_grad():  # Disable gradient calculations
        for images, targets in tqdm(loader, desc="Validation"):
            # Handle empty batches
            if images is None:
                continue

            # Move the batch to the evaluation device
            images = images.to(device)
            targets = {k: v.to(device) for k, v in targets.items()}

            # 1. Forward pass
            outputs = model(images)

            # 2. Calculate loss
            loss = criterion(outputs, targets)

            running_loss += loss.item()
            processed += 1

    # Average over processed batches only: dividing by len(loader) would
    # count skipped batches and fail on an empty loader.
    if processed == 0:
        return float("nan")
    return running_loss / processed

# -----------------------------------------------------------------------------
# MAIN FUNCTION
# -----------------------------------------------------------------------------

def main():
    """
    Main training function for HoVer-Net using MPS.

    Builds the train/validation loaders from the CONFIG folders, creates the
    model, loss and optimizer, then runs CONFIG["n_epochs"] epochs. The
    model weights are written to CONFIG["model_dir"]/CONFIG["model_name"]
    every time the validation loss improves. Ctrl-C and errors raised
    inside the epoch loop end training early without re-raising; the final
    message states whether a checkpoint was written during this run.
    """
    print("üöÄ Starting Phase 1: Training the HoVer-Net Model (with MPS)")
    os.makedirs(CONFIG["model_dir"], exist_ok=True)

    # 1. Set up Device (MPS for Apple Silicon)
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("‚úÖ Using Apple MPS for GPU acceleration.")
    else:
        device = torch.device("cpu")
        print("‚ö†Ô∏è MPS not found. Using CPU.")

    # 2. Create Datasets and DataLoaders
    print("\n--- Step 1: Loading Datasets ---")
    train_ds = NucleiDataset(data_dir=os.path.join(CONFIG["base_dir"], CONFIG["train_dir"]))
    val_ds = NucleiDataset(data_dir=os.path.join(CONFIG["base_dir"], CONFIG["val_dir"]))

    # We use num_workers=0 to avoid multiprocessing issues on Mac
    train_loader = DataLoader(
        train_ds,
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=custom_collate_fn,
        num_workers=0
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=CONFIG["batch_size"],
        shuffle=False,
        collate_fn=custom_collate_fn,
        num_workers=0
    )
    print("DataLoaders created.")

    # 3. Initialize Model, Loss, and Optimizer
    print("\n--- Step 2: Initializing Model ---")

    # NOTE(review): the recorded run of this cell stopped here with
    # "TypeError: HoverNet.__init__() got an unexpected keyword argument
    # 'num_classes'". The constructor arguments must be aligned with the
    # installed cellseg_models_pytorch version before this can train.
    model = hovernet.HoverNet(
        num_classes=CONFIG["num_classes"],
        pretrained_encoder="imagenet" # Use a pre-trained backbone
    ).to(device)

    # NOTE(review): not verified that `losses.mss_loss` exists with this
    # signature in the installed library -- confirm.
    criterion = losses.mss_loss(num_classes=CONFIG["num_classes"])

    optimizer = AdamW(model.parameters(), lr=CONFIG["lr"])

    print("Model, Loss, and Optimizer are ready.")

    # 4. Set up the MANUAL Training Loop
    print("\n--- Step 3: Starting Model Training ---")

    best_val_loss = float('inf')
    save_path = os.path.join(CONFIG["model_dir"], CONFIG["model_name"])
    # Tracks whether THIS run wrote a checkpoint, so the closing message
    # does not claim a saved model after an early failure.
    checkpoint_saved = False

    try:
        for epoch in range(CONFIG["n_epochs"]):
            print(f"\n--- Epoch {epoch+1}/{CONFIG['n_epochs']} ---")

            # Run one training epoch
            train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)

            # Run one validation epoch
            val_loss = evaluate(model, val_loader, criterion, device)

            print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

            # 5. Save the best model
            if val_loss < best_val_loss:
                print(f"New best model! Val loss improved from {best_val_loss:.4f} to {val_loss:.4f}.")
                print(f"Saving to {save_path}")
                best_val_loss = val_loss
                torch.save(model.state_dict(), save_path)
                checkpoint_saved = True

    except KeyboardInterrupt:
        print("\n--- Training interrupted by user. ---")
    except Exception as e:
        print(f"\n--- An error occurred during training: {e} ---")
        logging.exception("Training error")

    if checkpoint_saved:
        print(f"\n‚úÖ Phase 1 Complete. Best model saved to: {save_path}")
    else:
        print("\n--- Phase 1 ended without saving a model: no epoch finished with an improved validation loss. ---")

# Script entry point: enable INFO-level logging, then start training.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()

üöÄ Starting Phase 1: Training the HoVer-Net Model (with MPS)
‚úÖ Using Apple MPS for GPU acceleration.

--- Step 1: Loading Datasets ---
Dataset for train found 209 images.
Dataset for val found 45 images.
DataLoaders created.

--- Step 2: Initializing Model ---


TypeError: HoverNet.__init__() got an unexpected keyword argument 'num_classes'

In [None]:
import os
import numpy as np
import cv2
from glob import glob
from tifffile import imread
import logging
import torch
import pandas as pd
from tqdm import tqdm

# --- HoVer-Net Library Imports ---
from cellseg_models_pytorch import models
from cellseg_models_pytorch.post_process import hovernet_post_process

# --- Configuration ---
# Settings for the inference run: where the test images and the trained
# checkpoint live, the class vocabulary, and the submission file name.
CONFIG = {
    # Dataset root; main() joins "test_dir" onto it.
    "base_dir": "kaggle-data",
    "test_dir": "test_final",
    # Folder and file name of the checkpoint to load.
    "model_dir": os.path.join("kaggle-data", "hovernet_model_mps"),
    "model_name": "hovernet_best.pth",
    # One submission column per entry; list position + 1 is the class ID.
    "class_names": ["Epithelial", "Lymphocyte", "Macrophage", "Neutrophil"],
    "num_classes": 5, # 4 classes + 1 background
    "submission_file": "submission_hovernet.csv"
}

def rle_encode(img):
    """
    Run-length encodes a binary mask.

    The mask is flattened in row-major order. The result is a
    space-separated string of "start length" pairs, with 1-based start
    positions; an all-zero mask gives an empty string.
    """
    # Zero padding on both ends guarantees every run has a start and an end.
    padded = np.concatenate([[0], img.flatten(), [0]])
    # 1-based positions at which the value changes: start, end, start, end, ...
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each end position into the length of its run.
    boundaries[1::2] -= boundaries[::2]
    return ' '.join(map(str, boundaries))

def main():
    """
    Main inference function for HoVer-Net using MPS.

    Loads the checkpoint named in CONFIG, runs the model on every '*.tif'
    under CONFIG["base_dir"]/CONFIG["test_dir"], converts the predicted type
    map into one run-length-encoded mask per class, and writes all rows to
    CONFIG["submission_file"]. Returns early (writing nothing) when the
    checkpoint file does not exist. An image that raises during processing
    still gets a row, with empty strings for every class.
    """
    print("üöÄ Starting Inference with trained HoVer-Net Model (with MPS)")

    # 1. Set up Device (MPS for Apple Silicon)
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("‚úÖ Using Apple MPS for GPU acceleration.")
    else:
        device = torch.device("cpu")
        print("‚ö†Ô∏è MPS not found. Using CPU.")

    # 2. Initialize and Load Model
    print("\n--- Step 1: Loading Trained Model ---")
    model_path = os.path.join(CONFIG["model_dir"], CONFIG["model_name"])
    if not os.path.exists(model_path):
        print(f"FATAL ERROR: Model file not found at {model_path}")
        print("Please run train_hovernet_mps.py first.")
        return

    # NOTE(review): the training cell builds the network with
    # `hovernet.HoverNet(...)`, whereas this line calls `models.hovernet(...)`.
    # load_state_dict only succeeds if both produce the same architecture --
    # confirm which constructor the installed library actually provides.
    model = models.hovernet(num_classes=CONFIG["num_classes"]).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval() # Set model to evaluation mode
    print(f"Model loaded successfully from {model_path}")

    # 3. Find Test Images
    test_files = sorted(glob(os.path.join(CONFIG["base_dir"], CONFIG["test_dir"], '*.tif')))
    print(f"Found {len(test_files)} test images.")

    # 4. Run Inference Loop
    print("\n--- Step 2: Running Inference on Test Images ---")
    results = []
    class_map = {i + 1: name for i, name in enumerate(CONFIG["class_names"])} # 1-indexed

    for img_path in tqdm(test_files, desc="Generating predictions"):
        base_name = os.path.splitext(os.path.basename(img_path))[0]

        try:
            # 1. Load and prepare image (drop a fourth/alpha channel if present)
            image = imread(img_path)
            if image.ndim == 3 and image.shape[-1] == 4:
                image = image[..., :3]

            # Per-image min-max scaling to [0, 1]
            image_norm = (image - image.min()) / (image.max() - image.min() + 1e-6)

            # Convert to Tensor (HWC -> CHW), add batch dim, send to the selected device
            image_tensor = torch.from_numpy(image_norm).permute(2, 0, 1).float()
            image_tensor = image_tensor.unsqueeze(0).to(device) # BCHW

            # 2. Run Model
            with torch.no_grad():
                # The output is consumed below as a dict of tensors
                output = model(image_tensor)

            # Move outputs to CPU as NumPy arrays. squeeze() without an axis
            # removes every size-1 dimension, not only the batch dimension.
            output = {k: v.cpu().numpy().squeeze() for k, v in output.items()}

            # 3. Post-processing
            # presumably returns (instance map, per-pixel class map) -- TODO
            # confirm the name, arguments and return value against the library.
            inst_map, type_map = hovernet_post_process(
                output,
                nms_thresh=0.4,
                type_thresh=0.5,
                inst_thresh=0.3
            )

            # 4. Generate RLEs for submission (inst_map is not used here)
            rle_masks = {"image_id": base_name}
            for class_id, class_name in class_map.items():
                # Create a binary mask for *this class only*
                binary_mask = (type_map == class_id).astype(np.uint8)

                # Encode the mask
                rle_masks[class_name] = rle_encode(binary_mask)

            results.append(rle_masks)

        except Exception as e:
            # Best effort: keep one row per image so the submission stays complete.
            print(f"Error processing {img_path}: {e}. Appending empty RLEs.")
            results.append({
                "image_id": base_name,
                **{name: "" for name in CONFIG["class_names"]}
            })

    # 5. Save Submission File
    print("\n--- Step 3: Saving Submission File ---")
    df = pd.DataFrame(
        results,
        columns=["image_id"] + CONFIG["class_names"]
    )
    df.to_csv(CONFIG["submission_file"], index=False)
    print(f"‚úÖ Inference complete. Submission file saved to: {CONFIG['submission_file']}")

# Script entry point: enable INFO-level logging, then run inference.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
