In [1]:
# --- 1. Setup and Imports ---

import sys
import subprocess
import importlib

# No installation is attempted in this cell; the line below only prints a notice.
print("Assuming packages are pre-installed from previous notebooks. Skipping installation step.")

import os
import gc
import glob
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import segmentation_models_pytorch as smp
from albumentations import ToTensorV2
from albumentations.pytorch import ToTensorV2

Assuming packages are pre-installed from previous notebooks. Skipping installation step.


In [2]:
# --- 2. Configuration ---
import os

def _list_fragment_dirs(root):
    """Return the sorted names of the directories located directly inside `root`.

    Kept as a module-level helper because a comprehension written inside the
    class body cannot see class-level names such as `TEST_PATH`.

    Raises:
        OSError: whatever `os.listdir` raises (e.g. FileNotFoundError) when
            `root` does not exist.
    """
    return sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )


class CFG:
    """Configuration namespace for calibration and inference (read via class attributes)."""

    # General
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data Paths
    TEST_PATH = 'test'
    VALID_PATH = 'train'
    # Discovered from TEST_PATH itself so the two can never drift apart.
    TEST_FRAGMENTS = _list_fragment_dirs(TEST_PATH)
    CALIBRATION_FRAGMENT_ID = '2' # Use fragment 2 from train set for calibration

    # Data Reading (must match training)
    Z_START = 20
    Z_END = 44
    # get_img_stack reads slices range(Z_START, Z_END), i.e. Z_END - Z_START
    # slices, and appends one IR channel -- hence the "+ 1".
    IN_CHANS = (Z_END - Z_START) + 1

    # Tiling (must match training)
    TILE_SIZE = 256
    STRIDE = TILE_SIZE // 2  # neighbouring tiles overlap by half a tile

    # Model
    BACKBONE = 'timm-efficientnet-b4'
    MODEL_PATH = 'best_robust_model.pth' # Using the new robust model

    # Inference Strategy
    USE_TTA = True
    USE_CALIBRATION = True
    BATCH_SIZE = 16 # Can be adjusted based on memory

    # Defaults; overwritten by the calibration step when USE_CALIBRATION is True
    BEST_THRESHOLD = 0.45
    MIN_AREA_SIZE = 100

# Echo the active configuration, one setting per line, so the run log records it.
print(
    f"Device: {CFG.DEVICE}",
    f"Model Path: {CFG.MODEL_PATH}",
    f"Input Channels: {CFG.IN_CHANS}",
    f"Tile Size: {CFG.TILE_SIZE}",
    f"Stride for tiling: {CFG.STRIDE}",
    f"Discovered test fragments: {CFG.TEST_FRAGMENTS}",
    f"Calibration enabled: {CFG.USE_CALIBRATION} on fragment {CFG.CALIBRATION_FRAGMENT_ID}",
    f"TTA enabled: {CFG.USE_TTA}",
    sep="\n",
)

Device: cuda
Model Path: best_robust_model.pth
Input Channels: 25
Tile Size: 256
Stride for tiling: 128
Discovered test fragments: ['a']
Calibration enabled: True on fragment 2
TTA enabled: True


In [3]:
# --- 3. Advanced Helper Functions & Dataset ---

def get_hann_window(size):
    """Build a square 2D Hann window of shape (size, size).

    The result is the outer product of ``np.hanning(size)`` with itself, so
    entry (i, j) equals ``hanning[i] * hanning[j]``.
    """
    taper = np.hanning(size)
    # Column vector times row vector == outer product.
    return taper[:, np.newaxis] * taper[np.newaxis, :]

def remove_small_components(mask, min_size):
    """Removes small connected components from a binary mask.

    Components are found with 8-connectivity; every component whose pixel
    count is below `min_size` is set to 0.

    Args:
        mask: 2D binary mask in a dtype accepted by
            `cv2.connectedComponentsWithStats` (the callers pass uint8).
        min_size: Minimum area, in pixels, a component needs to be kept.

    Returns:
        The same array object as `mask`; it is modified in place.
    """
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    if num_labels <= 1:
        # Only the background label exists; nothing to remove.
        return mask

    # One boolean flag per label, then a single lookup through the label image.
    # This replaces a full-image scan per component with one pass overall.
    too_small = stats[:, cv2.CC_STAT_AREA] < min_size
    too_small[0] = False  # label 0 is the background, never a removable component
    mask[too_small[labels]] = 0
    return mask

def rle_encode(mask):
    """Run-length encode a binary mask as a space-separated "start length ..." string.

    The mask is flattened in column-major order and positions are 1-indexed.
    An all-zero mask yields an empty string.

    NOTE(review): the original comment states that the competition requires
    column-major order -- confirm against the competition's submission spec.
    """
    # Fortran-order flatten == flatten of the transposed mask.
    flat = mask.flatten(order='F')
    # Zero sentinels on both ends so runs touching the borders are closed.
    padded = np.concatenate([[0], flat, [0]])
    # 1-indexed positions at which the value changes: start, end, start, end, ...
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn every "end" position into a run length.
    boundaries[1::2] = boundaries[1::2] - boundaries[::2]
    return ' '.join(map(str, boundaries))

def get_img_stack(fragment_id, z_start, z_end, data_path, simulate_ir_absence=False):
    """
    Loads a stack of TIF slices plus one IR channel for a fragment and applies
    per-channel percentile normalization.

    Args:
        fragment_id: Name of the fragment directory inside `data_path`.
        z_start: Index of the first slice to read (inclusive).
        z_end: Index one past the last slice to read (exclusive).
        data_path: Root directory that contains the fragment directories.
        simulate_ir_absence: If True, any `ir.png` on disk is ignored and the
            IR channel is replaced by the mean of the loaded slices.

    Returns:
        float32 array with the channels stacked on the last axis: the
        `z_end - z_start` slices followed by the IR channel, each clipped
        to [0, 1].

    Raises:
        ValueError: if the slice range is empty.
        FileNotFoundError: if a TIF slice is missing or cannot be read.
    """
    if z_end <= z_start:
        raise ValueError(f"Empty slice range: z_start={z_start}, z_end={z_end}")

    images = []

    # Load TIF slices
    for i in range(z_start, z_end):
        image_path = os.path.join(data_path, fragment_id, 'surface_volume', f'{i:02}.tif')
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise FileNotFoundError(f"TIF file not found or failed to read: {image_path}")
        images.append(image.astype(np.float32))

    # Load IR image
    ir_path = os.path.join(data_path, fragment_id, 'ir.png')
    ir_image = cv2.imread(ir_path, cv2.IMREAD_UNCHANGED)

    # Handle missing or simulated-missing IR
    if ir_image is None or simulate_ir_absence:
        if ir_image is None:
            print(f"Warning: IR file not found at '{ir_path}'.")
        print("IR Fallback: Using mean of TIF slices as IR channel.")
        # Stay in float32 so the fallback channel has the same dtype as the slices.
        ir_image = np.mean(np.stack(images, axis=0), axis=0).astype(np.float32)
    else:
        ir_image = ir_image.astype(np.float32)

    # Ensure IR image has the same dimensions as the TIF slices
    if ir_image.shape != images[0].shape:
        print(f"Warning: IR image shape {ir_image.shape} differs from TIF shape {images[0].shape}. Resizing IR to match.")
        # cv2.resize takes (width, height), hence the swapped shape indices.
        ir_image = cv2.resize(ir_image, (images[0].shape[1], images[0].shape[0]), interpolation=cv2.INTER_NEAREST)

    images.append(ir_image)

    # Per-channel percentile normalization
    normalized_images = []
    for img in images:
        # np.percentile returns NumPy scalars; under NumPy 2 promotion rules a
        # float64 scalar combined with a float32 array can promote the whole
        # result to float64 and double the memory of the stack. Plain Python
        # floats keep the arithmetic in the array's own float32 dtype.
        p1, p99 = (float(p) for p in np.percentile(img, [1, 99]))
        img_normalized = (img - p1) / (p99 - p1 + 1e-6)  # epsilon guards flat channels
        img_normalized = np.clip(img_normalized, 0, 1)
        normalized_images.append(img_normalized.astype(np.float32, copy=False))

    return np.stack(normalized_images, axis=-1)

class VesuviusTestDataset(Dataset):
    """
    Inference dataset that serves square tiles cut from one fragment.

    The fragment array is used as given (no normalization happens here); each
    item is the tile at a stored (y, x) origin, returned channel-first as a
    float32 tensor.
    """

    def __init__(self, tiles, fragment_images, tile_size):
        # tiles: sequence of (y, x) top-left corners; fragment_images: HWC array.
        self.tiles, self.fragment_images, self.tile_size = tiles, fragment_images, tile_size

    def __len__(self):
        return len(self.tiles)

    def __getitem__(self, idx):
        top, left = self.tiles[idx]
        bottom, right = top + self.tile_size, left + self.tile_size

        # Cut the tile out of the fragment (still height, width, channels).
        tile_hwc = self.fragment_images[top:bottom, left:right, :]

        # Channels first, as expected by the model.
        tile_chw = np.moveaxis(tile_hwc, -1, 0)

        return torch.from_numpy(tile_chw).to(torch.float32)

In [4]:
# --- 4. Model Loading ---

# Define the model architecture (must match training)
model = smp.FPN(
    encoder_name=CFG.BACKBONE,
    encoder_weights=None,  # Weights will be loaded from file
    in_channels=CFG.IN_CHANS,
    classes=1,
    activation=None,  # raw logits; the sigmoid is applied by the callers
)

# Load the trained weights.
# map_location='cpu' lets a checkpoint that was saved from a GPU be restored
# when CFG.DEVICE falls back to CPU; the model is moved to the target device
# right after.
model.load_state_dict(torch.load(CFG.MODEL_PATH, map_location='cpu'))
model.to(CFG.DEVICE)
model.eval()

print(f"Model loaded from {CFG.MODEL_PATH} and moved to {CFG.DEVICE}.")

Model loaded from best_robust_model.pth and moved to cuda.


In [5]:
# --- 5. TTA and Prediction Functions ---

def tta_predict(model, image_batch):
    """Performs 8-way Test-Time Augmentation and returns averaged logits.

    The eight views are the four 90-degree rotations of the input, each with
    and without a horizontal flip. Every prediction is mapped back to the
    input orientation before averaging.

    Args:
        model: Callable mapping a (B, C, H, W) batch to logits. The
            accumulator is allocated as (B, 1, H, W), so the output must be
            compatible with that shape.
        image_batch: Input tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, 1, H, W) with the mean logits over the 8 views.
    """
    B, C, H, W = image_batch.shape
    logits_tta = torch.zeros((B, 1, H, W), device=image_batch.device, dtype=image_batch.dtype)

    for k in range(4):
        # k = 0 is the unrotated input.
        img_rot = torch.rot90(image_batch, k, [2, 3])

        # Rotated view: undo the rotation.
        logits_tta += torch.rot90(model(img_rot), -k, [2, 3])

        # Rotated + flipped view. The input was rotated first and flipped
        # second, so the inverse must undo the flip FIRST and the rotation
        # second. Doing it the other way round leaves the k=1 and k=3
        # predictions rotated by 180 degrees.
        logits_flipped = model(torch.flip(img_rot, dims=[3]))
        logits_tta += torch.rot90(torch.flip(logits_flipped, dims=[3]), -k, [2, 3])

    return logits_tta / 8.0

def predict_fragment(model, fragment_images, roi_mask):
    """
    Runs full-image inference on a fragment using overlapping tiles,
    Hanning window blending, and optional TTA. Returns the final logit map.

    Args:
        model: Segmentation model producing one logit channel per tile.
        fragment_images: Pre-normalized fragment array indexed as
            (height, width, channels).
        roi_mask: 2D array; values > 0 mark the region of interest. A tile is
            predicted only when more than 10% of it lies inside the ROI.

    Returns:
        float32 array of shape (height, width) with blended logits. Pixels
        that received no blending weight (skipped tiles, and the outermost
        image ring where the Hann window is exactly zero) are filled with a
        strongly negative logit, so they map to a probability near 0 rather
        than 0.5.

    Raises:
        ValueError: if `roi_mask` is None or the fragment is smaller than
            one tile.
    """
    if roi_mask is None:
        raise ValueError("roi_mask is None (mask image missing or unreadable).")

    img_height, img_width, _ = fragment_images.shape
    tile_size = CFG.TILE_SIZE
    if img_height < tile_size or img_width < tile_size:
        raise ValueError(
            f"Fragment of size {img_height}x{img_width} is smaller than the tile size {tile_size}."
        )

    # Logit given to pixels no tile contributed to; sigmoid(-30) is ~1e-13.
    uncovered_logit = -30.0

    # Canvases for blending
    logit_canvas = np.zeros((img_height, img_width), dtype=np.float32)
    weight_canvas = np.zeros((img_height, img_width), dtype=np.float32)

    # Hanning window for smooth blending
    hann_window = get_hann_window(tile_size)

    # Generate tile coordinates with full coverage. Origins are clamped so the
    # last row/column of tiles ends exactly at the image border; the set makes
    # the resulting duplicate check O(1).
    tiles = []
    seen_origins = set()
    for y in range(0, img_height, CFG.STRIDE):
        y_start = min(y, img_height - tile_size)
        for x in range(0, img_width, CFG.STRIDE):
            x_start = min(x, img_width - tile_size)
            if (y_start, x_start) in seen_origins:
                continue
            seen_origins.add((y_start, x_start))
            # Only predict on tiles that have some overlap with the ROI
            roi_tile = roi_mask[y_start:y_start + tile_size, x_start:x_start + tile_size]
            if (roi_tile > 0).mean() > 0.1:
                tiles.append((y_start, x_start))

    print(f"Generated {len(tiles)} tiles for prediction.")

    # Create dataset and dataloader (shuffle must stay off: stitching relies on order)
    dataset = VesuviusTestDataset(tiles, fragment_images, tile_size)
    dataloader = DataLoader(dataset, batch_size=CFG.BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    # Inference loop
    with torch.no_grad():
        for batch_idx, images_batch in enumerate(tqdm(dataloader, desc="Predicting tiles")):
            images_batch = images_batch.to(CFG.DEVICE)

            if CFG.USE_TTA:
                logits_batch = tta_predict(model, images_batch)
            else:
                logits_batch = model(images_batch)

            logits_batch = logits_batch.float().cpu().numpy()

            # Stitch logits back with Hanning blending
            first_tile = batch_idx * CFG.BATCH_SIZE
            for j in range(logits_batch.shape[0]):
                y, x = tiles[first_tile + j]
                logit_canvas[y:y + tile_size, x:x + tile_size] += logits_batch[j, 0, :, :] * hann_window
                weight_canvas[y:y + tile_size, x:x + tile_size] += hann_window

    # Normalize by the accumulated weight wherever there is any. Adding an
    # epsilon to the denominator instead would pull low-weight pixels towards
    # logit 0 (probability 0.5) and leave uncovered pixels at exactly 0.5.
    covered = weight_canvas > 0
    np.divide(logit_canvas, weight_canvas, out=logit_canvas, where=covered)
    logit_canvas[~covered] = uncovered_logit

    return logit_canvas

In [7]:
# --- 6. Calibration Step ---
from scipy.ndimage import mean as ndimage_mean
def fbeta_score(y_true, y_pred, beta=0.5):
    """Compute the F-beta score of `y_pred` against `y_true`.

    Both inputs are element-wise comparable arrays of 0/1 values. A small
    epsilon is added to every denominator so empty inputs give 0 instead of
    a division by zero.

    Returns:
        Tuple (fbeta, precision, recall).
    """
    eps = 1e-6

    true_pos = np.sum(y_true * y_pred)
    false_pos = np.sum((1 - y_true) * y_pred)
    false_neg = np.sum(y_true * (1 - y_pred))

    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)

    beta_sq = beta**2
    numerator = (1 + beta_sq) * (precision * recall)
    denominator = (beta_sq * precision) + recall + eps
    return numerator / denominator, precision, recall

def _fbeta_from_counts(tp, fp, fn, beta=0.5):
    """F-beta from pixel counts, using the same epsilon-guarded formula as fbeta_score."""
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    return (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall + 1e-6)


def calibrate_parameters(model):
    """
    Calibrates the threshold and min_area_size on a validation fragment
    by simulating test conditions (missing IR).

    Each (threshold, min_area) candidate is scored with the same
    post-processing the final inference applies: pixel-wise thresholding of
    the probability map, masking with the ROI, then removal of 8-connected
    components smaller than `min_area`. The F0.5 score is computed over the
    ROI pixels only.

    Args:
        model: Model passed through to `predict_fragment`.

    Returns:
        Tuple (best_threshold, best_min_area). Falls back to the current
        CFG defaults if no candidate scores above -1.

    Raises:
        FileNotFoundError: if the ROI mask or the ink labels cannot be read.
    """
    print("\n--- Starting Parameter Calibration ---")
    fragment_id = CFG.CALIBRATION_FRAGMENT_ID

    # Load validation data
    print(f"Loading validation fragment {fragment_id} for calibration...")
    fragment_dir = os.path.join(CFG.VALID_PATH, fragment_id)
    roi_mask = cv2.imread(os.path.join(fragment_dir, 'mask.png'), cv2.IMREAD_GRAYSCALE)
    gt_mask = cv2.imread(os.path.join(fragment_dir, 'inklabels.png'), cv2.IMREAD_GRAYSCALE)
    if roi_mask is None or gt_mask is None:
        raise FileNotFoundError(f"mask.png or inklabels.png not found or failed to read in: {fragment_dir}")

    inside_roi = roi_mask > 0
    gt_in_roi = (gt_mask > 0) & inside_roi
    total_gt = int(gt_in_roi.sum())

    # Load images and get full-fragment logit predictions
    fragment_images = get_img_stack(fragment_id, CFG.Z_START, CFG.Z_END, data_path=CFG.VALID_PATH, simulate_ir_absence=True)
    logit_map = predict_fragment(model, fragment_images, roi_mask)
    del fragment_images  # the stack is large and no longer needed
    prob_map = 1 / (1 + np.exp(-logit_map))

    print("Performing grid search (same post-processing as final inference)...")
    # linspace gives an explicit, reproducible grid (0.20, 0.225, ..., 0.80);
    # np.arange with a float step has a length that depends on rounding.
    thresholds = np.linspace(0.20, 0.80, 25)
    min_areas = [64, 96, 128, 160, 196, 256, 300]
    best_score = -1
    best_params = (CFG.BEST_THRESHOLD, CFG.MIN_AREA_SIZE)

    for threshold in tqdm(thresholds, desc="Thresholds"):
        # Same order as inference: threshold, then ROI mask, then components.
        pred_mask = ((prob_map > threshold) & inside_roi).astype(np.uint8)
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(pred_mask, connectivity=8)

        # Per-component pixel count and per-component true-positive count.
        # With these, every min_area candidate is scored without rebuilding
        # a full-size mask.
        component_areas = stats[:, cv2.CC_STAT_AREA].astype(np.int64)
        component_tp = np.bincount(labels[gt_in_roi], minlength=num_labels)

        for min_area in min_areas:
            kept = component_areas >= min_area
            kept[0] = False  # label 0 is the background
            tp = int(component_tp[kept].sum())
            predicted = int(component_areas[kept].sum())
            score = _fbeta_from_counts(tp, predicted - tp, total_gt - tp)

            if score > best_score:
                best_score = score
                best_params = (float(threshold), min_area)

    print(f"Calibration complete. Best F0.5 score: {best_score:.4f}")
    # Three decimals: the grid step is 0.025, which two decimals cannot show.
    print(f"Best parameters found: Threshold={best_params[0]:.3f}, Min Area={best_params[1]}")

    return best_params

# Calibrate only when requested; otherwise CFG keeps its default parameters.
if not CFG.USE_CALIBRATION:
    print("Skipping calibration. Using default parameters.")
else:
    best_threshold, best_min_area = calibrate_parameters(model)
    CFG.BEST_THRESHOLD, CFG.MIN_AREA_SIZE = best_threshold, best_min_area

In [6]:
# --- 7. Final Inference and Submission ---

print("\n--- Starting Final Inference on Test Set ---")
# Three decimals: calibrated thresholds lie on a 0.025 grid.
print(f"Using calibrated parameters: Threshold={CFG.BEST_THRESHOLD:.3f}, Min Area={CFG.MIN_AREA_SIZE}")

submission_data = []

for fragment_id in CFG.TEST_FRAGMENTS:
    print(f"\nProcessing fragment: {fragment_id}")

    # Load data for the test fragment
    print("Step 1/5: Loading images...")
    # Read the small ROI mask first so a missing file fails before the
    # expensive slice loading, and with a clear message.
    mask_path = os.path.join(CFG.TEST_PATH, fragment_id, 'mask.png')
    roi_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if roi_mask is None:
        raise FileNotFoundError(f"ROI mask not found or failed to read: {mask_path}")
    fragment_images = get_img_stack(fragment_id, CFG.Z_START, CFG.Z_END, data_path=CFG.TEST_PATH)

    # Get full fragment predictions
    print("Step 2/5: Predicting logits...")
    logit_map = predict_fragment(model, fragment_images, roi_mask)

    # Convert to probabilities and apply threshold
    print("Step 3/5: Applying threshold...")
    prob_map = 1 / (1 + np.exp(-logit_map))
    pred_mask = (prob_map > CFG.BEST_THRESHOLD).astype(np.uint8)

    # Apply ROI mask
    pred_mask *= (roi_mask > 0).astype(np.uint8)

    # Post-processing: remove small components (modifies pred_mask in place)
    print("Step 4/5: Removing small components...")
    final_mask = remove_small_components(pred_mask, CFG.MIN_AREA_SIZE)

    # RLE encode for submission
    print("Step 5/5: RLE encoding...")
    rle = rle_encode(final_mask)
    submission_data.append({'Id': fragment_id, 'Predicted': rle})

    # Clean up memory
    del fragment_images, roi_mask, logit_map, prob_map, pred_mask, final_mask
    gc.collect()

# Create and save submission file
print("\nCreating submission file...")
# Explicit columns keep the CSV header even if no fragment was processed.
submission_df = pd.DataFrame(submission_data, columns=['Id', 'Predicted'])
submission_df.to_csv('submission.csv', index=False)
print("\u2705 submission.csv created successfully!")