In [1]:
# --- 1. Setup and Configuration ---

import os
import gc
import numpy as np
import pandas as pd
import cv2
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2

class CFG:
    """Central, class-level configuration read by every cell of this notebook."""

    # ---- Runtime ----
    # Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ---- Fragments ----
    TEST_FRAGMENTS = ['a', 'b', 'c']  # fragments that end up in the submission
    VAL_FRAGMENT_ID = 2               # training fragment used for calibration

    # ---- Z slices ----
    # Slices Z_START .. Z_END - 1 are loaded; the range is wide enough to
    # serve both models (the narrower one takes a sub-range of it).
    Z_START = 16
    Z_END = 48

    # ---- Sliding window ----
    TILE_SIZE = 256
    STRIDE = TILE_SIZE // 2  # neighbouring tiles overlap by half a tile

    # ---- Checkpoints ----
    # The two ensemble members use different encoders: b4 for the robust
    # model, b0 for the newer TIF-only model.
    ENCODER_NAME_ROBUST = 'efficientnet-b4'
    ENCODER_NAME_TIF_V2 = 'efficientnet-b0'
    MODEL_ROBUST_PATH = 'best_robust_model.pth'
    MODEL_TIF_V2_PATH = 'best_tif_only_model_v4_b0.pth'  # produced by 07_training

    # ---- Input channels per model ----
    IN_CHANNELS_ROBUST = 32  # trained on Z-range 16-48
    IN_CHANNELS_TIF_V2 = 25  # trained on Z-range 20-45

    # ---- Ensembling ----
    ENSEMBLE_WEIGHTS = [0.4, 0.6]  # order: [robust, tif_v2]
    USE_TTA = True

    # ---- Post-processing defaults (replaced by calibration) ----
    THRESHOLD = 0.4
    MIN_AREA = 100


print(f"Device: {CFG.DEVICE}")

Device: cuda


In [2]:
# --- 2. Model and Data Loading ---

def load_model(model_path, in_channels, encoder_name):
    """Build a U-Net and fill it with the weights stored at ``model_path``.

    Args:
        model_path: Path to a file holding a state dict for the network.
        in_channels: Number of input channels of the first convolution.
        encoder_name: Name of the encoder backbone passed to ``smp.Unet``.

    Returns:
        The model, moved to ``CFG.DEVICE`` and switched to evaluation mode.
    """
    # encoder_weights=None: every parameter comes from the checkpoint below.
    # activation=None: the network returns raw logits.
    unet = smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=None,
        in_channels=in_channels,
        classes=1,
        activation=None,
    )

    state_dict = torch.load(model_path, map_location=CFG.DEVICE)
    unet.load_state_dict(state_dict)

    return unet.to(CFG.DEVICE).eval()

def get_img_stack(fragment_id, is_test=True):
    """Read and normalise the surface-volume slices of one fragment.

    Slices ``CFG.Z_START`` .. ``CFG.Z_END - 1`` are read from
    ``{test|train}/{fragment_id}/surface_volume/NN.tif``. Each slice is
    rescaled on its own so that its 1st percentile maps to 0 and its 99th
    percentile maps to 1, then clipped to ``[0, 1]``.

    Args:
        fragment_id: Name of the fragment directory.
        is_test: Read from ``test/`` when True, from ``train/`` otherwise.

    Returns:
        A float32 array with the slices stacked along the last axis.

    Raises:
        FileNotFoundError: If a slice cannot be read.
        ValueError: If the configured Z range is empty or the slices do not
            all have the same shape.
    """
    data_folder = 'test' if is_test else 'train'
    fragment_path = f"{data_folder}/{fragment_id}/surface_volume"

    z_indices = range(CFG.Z_START, CFG.Z_END)
    if len(z_indices) == 0:
        raise ValueError(
            f"Empty Z range: Z_START={CFG.Z_START}, Z_END={CFG.Z_END}"
        )

    # Allocated once the first slice reveals the image shape. Writing slices
    # straight into the final array avoids holding a list of slices AND the
    # stacked copy in memory at the same time.
    stack = None

    for channel, z in enumerate(z_indices):
        image_path = os.path.join(fragment_path, f'{z:02}.tif')
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise FileNotFoundError(f"Image not found: {image_path}")

        # np.percentile returns float64 scalars. Under NumPy >= 2 promotion
        # rules, mixing those with a float32 array yields float64, which
        # doubles memory and no longer matches float32 model weights.
        # Plain Python floats keep the arithmetic in float32.
        p_low, p_high = (float(v) for v in np.percentile(image, [1, 99]))
        normalized = np.clip(
            (image.astype(np.float32) - p_low) / (p_high - p_low + 1e-6), 0, 1
        )

        if stack is None:
            stack = np.empty(image.shape + (len(z_indices),), dtype=np.float32)
        elif image.shape != stack.shape[:-1]:
            raise ValueError(
                f"Slice shape mismatch in {image_path}: "
                f"{image.shape} != {stack.shape[:-1]}"
            )
        stack[..., channel] = normalized

    return stack

def rle_encode(mask):
    """Run-length encode a binary mask.

    The mask is flattened in row-major order and every run of non-zero
    pixels is reported as a ``start length`` pair, with 1-based starts.

    Args:
        mask: Array of 0/1 values.

    Returns:
        The pairs joined by single spaces; an empty string for an all-zero
        mask.
    """
    # Zero sentinels on both ends guarantee that every run has a visible
    # start and a visible end.
    padded = np.concatenate([[0], mask.flatten(), [0]])

    # 1-based positions where the value changes: run starts and run ends
    # alternate in this array.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1

    # Turn each end position into a length by subtracting its start.
    boundaries[1::2] -= boundaries[::2]

    return ' '.join(map(str, boundaries))

# Define TTA transforms
def _rot90_ccw(image, **kwargs):
    """Rotate an HWC image by one quarter turn, from axis 0 towards axis 1."""
    # np.rot90 returns a view with negative strides, which cannot be wrapped
    # by torch.from_numpy; hand a contiguous copy to the next transform.
    return np.ascontiguousarray(np.rot90(image, k=1, axes=(0, 1)))


def get_tta_transforms():
    """Return the test-time-augmentation pipelines.

    Every entry is an albumentations ``Compose`` called as
    ``transform(image=hwc_array)['image']`` and yields a CHW tensor. The
    entries are, in order: identity, horizontal flip, vertical flip and a
    fixed 90-degree rotation.

    All four pipelines are deterministic. ``A.RandomRotate90`` is not used
    on purpose: it picks a random number of quarter turns on every call, so
    a fixed inverse rotation would leave most predictions mis-aligned with
    the tile they are averaged into. One ``np.rot90(k=1)`` turn is undone
    exactly by ``torch.rot90(pred, k=-1, dims=[-2, -1])``.
    """
    return [
        A.Compose([ToTensorV2(transpose_mask=True)]),
        A.Compose([A.HorizontalFlip(p=1.0), ToTensorV2(transpose_mask=True)]),
        A.Compose([A.VerticalFlip(p=1.0), ToTensorV2(transpose_mask=True)]),
        A.Compose([
            A.Lambda(image=_rot90_ccw, name='rot90_ccw', p=1.0),
            ToTensorV2(transpose_mask=True),
        ]),
    ]

# Inverse TTA transforms
def get_inverse_tta_transforms():
    """Return one callable per TTA slot that maps a prediction back.

    Each callable takes a tensor whose last two dimensions are the spatial
    ones and returns a tensor; the list has four entries.
    """

    def _keep(pred):
        # Slot 0: nothing to undo.
        return pred

    def _flip_last_axis(pred):
        # Slot 1: mirror along the last (width) dimension.
        return torch.flip(pred, dims=[-1])

    def _flip_second_last_axis(pred):
        # Slot 2: mirror along the second-to-last (height) dimension.
        return torch.flip(pred, dims=[-2])

    def _rotate_back(pred):
        # Slot 3: one quarter turn with k=-1 in the spatial plane.
        return torch.rot90(pred, k=-1, dims=[-2, -1])

    return [_keep, _flip_last_axis, _flip_second_last_axis, _rotate_back]


print("Helper functions and TTA defined.")

Helper functions and TTA defined.


In [3]:
# --- 3. Main Inference Loop ---

def _tile_starts(length, tile_size, stride):
    """Return tile start offsets along one axis that cover ``[0, length)``.

    Requires ``length >= tile_size``.
    """
    starts = list(range(0, length - tile_size + 1, stride))
    # The regular grid stops short of the far border whenever
    # (length - tile_size) is not a multiple of the stride. One extra tile,
    # flush with the border, makes sure those pixels are predicted as well.
    if starts[-1] + tile_size < length:
        starts.append(length - tile_size)
    return starts


def predict_fragment(fragment_id, model_robust, model_tif_v2, is_test=True):
    """Run tiled ensemble inference on one fragment.

    The fragment is scanned with ``CFG.TILE_SIZE`` tiles every ``CFG.STRIDE``
    pixels. For each tile and each TTA transform, both models are evaluated,
    their sigmoid outputs are blended with ``CFG.ENSEMBLE_WEIGHTS`` and the
    result is mapped back with the matching inverse transform. Overlapping
    tile predictions are averaged.

    Args:
        fragment_id: Name of the fragment directory.
        model_robust: Model fed with every loaded Z slice.
        model_tif_v2: Model fed with ``CFG.IN_CHANNELS_TIF_V2`` slices that
            start at Z index 20.
        is_test: Read the fragment from ``test/`` when True, ``train/``
            otherwise.

    Returns:
        A float32 array of shape ``(height, width)`` with the averaged
        probabilities for every pixel of the fragment.

    Raises:
        ValueError: If the loaded Z range does not contain the slices needed
            by ``model_tif_v2``.
    """
    print(f"\nProcessing fragment {fragment_id}...")

    # Load data
    images = get_img_stack(fragment_id, is_test=is_test)
    height, width, _ = images.shape
    tile_size = CFG.TILE_SIZE

    # Channel window of the TIF v2 model inside the loaded stack. With the
    # default Z_START of 16 and 25 input channels this is 4:29.
    TIF_V2_Z_START = 20
    tif_v2_first = TIF_V2_Z_START - CFG.Z_START
    tif_v2_last = tif_v2_first + CFG.IN_CHANNELS_TIF_V2
    if tif_v2_first < 0 or tif_v2_last > images.shape[2]:
        raise ValueError(
            f"Loaded Z range {CFG.Z_START}-{CFG.Z_END} does not contain the "
            f"{CFG.IN_CHANNELS_TIF_V2} slices starting at {TIF_V2_Z_START}"
        )

    # A fragment smaller than one tile would otherwise produce no tiles at
    # all; zero-pad it up to the tile size and crop the result at the end.
    pad_h = max(tile_size - height, 0)
    pad_w = max(tile_size - width, 0)
    if pad_h or pad_w:
        images = np.pad(images, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')
    padded_h, padded_w = images.shape[:2]

    # Accumulated probabilities and per-pixel tile counts
    pred_mask = np.zeros((padded_h, padded_w), dtype=np.float32)
    norm_mask = np.zeros((padded_h, padded_w), dtype=np.float32)

    # TTA transforms and their inverses; only the identity pair without TTA
    tta_transforms = get_tta_transforms()
    inverse_tta_transforms = get_inverse_tta_transforms()
    if not CFG.USE_TTA:
        tta_transforms = tta_transforms[:1]
        inverse_tta_transforms = inverse_tta_transforms[:1]

    y_starts = _tile_starts(padded_h, tile_size, CFG.STRIDE)
    x_starts = _tile_starts(padded_w, tile_size, CFG.STRIDE)

    with torch.no_grad():
        for y in tqdm(y_starts, desc=f"Inferring on {fragment_id}"):
            for x in x_starts:
                tile = images[y:y + tile_size, x:x + tile_size, :]

                ensembled_preds = torch.zeros(
                    (1, 1, tile_size, tile_size), device=CFG.DEVICE
                )

                for transform, inv_transform in zip(tta_transforms, inverse_tta_transforms):
                    # Augment the tile once and slice the channels afterwards:
                    # both models are guaranteed to see the same geometric
                    # transform, and the augmentation is not computed twice.
                    # .float() keeps the input dtype aligned with the weights.
                    tile_robust = (
                        transform(image=tile)['image']
                        .unsqueeze(0)
                        .to(CFG.DEVICE)
                        .float()
                    )
                    tile_tif_v2 = tile_robust[:, tif_v2_first:tif_v2_last]

                    # Predict, then bring both outputs back to tile orientation
                    pred_robust = inv_transform(model_robust(tile_robust))
                    pred_tif_v2 = inv_transform(model_tif_v2(tile_tif_v2))

                    # Weighted average of the two probability maps
                    ensembled_preds += (
                        CFG.ENSEMBLE_WEIGHTS[0] * torch.sigmoid(pred_robust)
                        + CFG.ENSEMBLE_WEIGHTS[1] * torch.sigmoid(pred_tif_v2)
                    )

                # Average over the TTA variants and add to the full mask
                avg_preds = ensembled_preds / len(tta_transforms)
                pred_mask[y:y + tile_size, x:x + tile_size] += avg_preds.squeeze().cpu().numpy()
                norm_mask[y:y + tile_size, x:x + tile_size] += 1

    # Normalize overlapping predictions
    pred_mask /= (norm_mask + 1e-6)

    # Drop the padding (a no-op when nothing was padded)
    return pred_mask[:height, :width]

def main():
    """Predict every test fragment and write ``submission.csv``.

    For each fragment in ``CFG.TEST_FRAGMENTS`` the ensemble probabilities
    are thresholded at ``CFG.THRESHOLD``, connected components smaller than
    ``CFG.MIN_AREA`` pixels are removed, the result is restricted to the
    fragment's ``mask.png`` and finally run-length encoded.

    Raises:
        FileNotFoundError: If a fragment's ``mask.png`` cannot be read.
        ValueError: If the ROI mask and the prediction differ in shape.
    """
    # Load models
    print("Loading models...")
    model_robust = load_model(CFG.MODEL_ROBUST_PATH, CFG.IN_CHANNELS_ROBUST, CFG.ENCODER_NAME_ROBUST)
    model_tif_v2 = load_model(CFG.MODEL_TIF_V2_PATH, CFG.IN_CHANNELS_TIF_V2, CFG.ENCODER_NAME_TIF_V2)
    print("Models loaded.")

    results = []
    for fragment_id in CFG.TEST_FRAGMENTS:
        # Read the ROI mask first so a missing file fails before the
        # expensive inference instead of after it.
        roi_path = f"test/{fragment_id}/mask.png"
        roi_image = cv2.imread(roi_path, cv2.IMREAD_GRAYSCALE)
        if roi_image is None:
            raise FileNotFoundError(f"ROI mask not found: {roi_path}")
        roi_mask = roi_image > 0

        # Predict
        pred_mask = predict_fragment(fragment_id, model_robust, model_tif_v2, is_test=True)
        if roi_mask.shape != pred_mask.shape:
            raise ValueError(
                f"ROI mask shape {roi_mask.shape} does not match prediction "
                f"shape {pred_mask.shape} for fragment {fragment_id}"
            )

        # Post-process
        binary_mask = (pred_mask > CFG.THRESHOLD).astype(np.uint8)

        # Remove small components. A per-label lookup table indexed with the
        # label image clears all of them in one pass, instead of scanning the
        # whole image once per component.
        _, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
        too_small = stats[:, cv2.CC_STAT_AREA] < CFG.MIN_AREA
        too_small[0] = False  # label 0 is the background, already zero
        binary_mask[too_small[labels]] = 0

        # Apply original ROI mask from test set
        final_mask = binary_mask * roi_mask

        # RLE encode
        rle = rle_encode(final_mask)
        results.append({'Id': fragment_id, 'Predicted': rle})

    # Create submission file
    submission_df = pd.DataFrame(results)
    submission_df.to_csv('submission.csv', index=False)
    print("\nSubmission file created: submission.csv")

# Note: Run this notebook only once the training in 07 is complete and the model is saved.
# main() is deliberately not called in this cell; section 6 calls it after
# the calibration in section 5 has updated CFG.

In [4]:
# --- 4. Calibration on Validation Set ---

def fbeta_score(y_true, y_pred, beta=0.5):
    """Compute the F-beta score of a binary prediction.

    Args:
        y_true: Array of 0/1 ground-truth labels.
        y_pred: Array of 0/1 predictions, same shape as ``y_true``.
        beta: Weight of recall relative to precision.

    Returns:
        The F-beta score. Every division carries a ``1e-6`` term in the
        denominator, so an empty prediction scores 0 rather than dividing
        by zero.
    """
    true_pos = (y_true * y_pred).sum()
    false_pos = ((1 - y_true) * y_pred).sum()
    false_neg = (y_true * (1 - y_pred)).sum()

    precision = true_pos / (true_pos + false_pos + 1e-6)
    recall = true_pos / (true_pos + false_neg + 1e-6)

    beta_sq = beta**2
    numerator = (1 + beta_sq) * (precision * recall)
    denominator = (beta_sq * precision) + recall + 1e-6
    return numerator / denominator

def calibrate_params(model_robust, model_tif_v2):
    """Grid-search the threshold and minimum component area on the validation fragment.

    The ensemble is run once on ``train/{CFG.VAL_FRAGMENT_ID}``. Every
    (threshold, min_area) pair of the search grid is then scored with
    ``fbeta_score`` against ``inklabels.png``. The best pair is written to
    ``CFG.THRESHOLD`` and ``CFG.MIN_AREA``; when no pair scores above 0 the
    current CFG values are left untouched.

    Args:
        model_robust: First ensemble member, passed to ``predict_fragment``.
        model_tif_v2: Second ensemble member, passed to ``predict_fragment``.

    Raises:
        FileNotFoundError: If the ground-truth label image cannot be read.
    """
    print(f"\n--- Starting Calibration on Fragment {CFG.VAL_FRAGMENT_ID} ---")

    # Load ground truth mask first so a missing file fails before the
    # expensive inference instead of after it.
    gt_path = f"train/{CFG.VAL_FRAGMENT_ID}/inklabels.png"
    gt_image = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    if gt_image is None:
        raise FileNotFoundError(f"Ground-truth mask not found: {gt_path}")
    gt_mask = (gt_image / 255).astype(np.uint8)

    # Predict on validation fragment (is_test=False)
    pred_mask_val = predict_fragment(CFG.VAL_FRAGMENT_ID, model_robust, model_tif_v2, is_test=False)

    # Define search space
    thresholds = np.arange(0.2, 0.7, 0.05)
    min_areas = [25, 50, 75, 100, 125, 150]

    # Start from the current settings so an unsuccessful search cannot
    # replace them with a degenerate threshold of 0 / area of 0.
    best_score = 0
    best_threshold = CFG.THRESHOLD
    best_min_area = CFG.MIN_AREA

    for threshold in tqdm(thresholds, desc="Calibrating Thresholds"):
        binary_mask = (pred_mask_val > threshold).astype(np.uint8)

        # The labelling depends only on the threshold, so compute it once
        # and reuse it for every candidate area.
        _, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
        areas = stats[:, cv2.CC_STAT_AREA]

        for min_area in min_areas:
            # Per-label keep table, expanded to image size through the
            # label image; label 0 is the background.
            keep = areas >= min_area
            keep[0] = False
            processed_mask = keep[labels].astype(np.uint8)

            score = fbeta_score(gt_mask, processed_mask)

            if score > best_score:
                best_score = score
                best_threshold = float(threshold)
                best_min_area = min_area
                print(f"New best score: {best_score:.4f} at T={best_threshold:.2f}, A={best_min_area}")

    print(f"\n--- Calibration Complete ---")
    print(f"Best F0.5 Score: {best_score:.4f}")
    print(f"Optimal Threshold: {best_threshold:.2f}")
    print(f"Optimal Min Area: {best_min_area}")

    if best_score <= 0:
        print("No configuration scored above 0; keeping the existing CFG values.")

    # Update CFG with optimal values
    CFG.THRESHOLD = best_threshold
    CFG.MIN_AREA = best_min_area

# To run calibration, load models and then call this function.
# Note: This requires 'best_tif_only_model_v4_b0.pth' (CFG.MODEL_TIF_V2_PATH) to exist.
# model_r = load_model(CFG.MODEL_ROBUST_PATH, CFG.IN_CHANNELS_ROBUST, CFG.ENCODER_NAME_ROBUST)
# model_t = load_model(CFG.MODEL_TIF_V2_PATH, CFG.IN_CHANNELS_TIF_V2, CFG.ENCODER_NAME_TIF_V2)
# calibrate_params(model_r, model_t)

In [None]:
# --- 5. Run Calibration ---
# Loads both ensemble members, tunes the post-processing parameters on the
# validation fragment and then releases the models again.
print("Loading models for calibration...")
model_r = load_model(CFG.MODEL_ROBUST_PATH, CFG.IN_CHANNELS_ROBUST, CFG.ENCODER_NAME_ROBUST)
model_t = load_model(CFG.MODEL_TIF_V2_PATH, CFG.IN_CHANNELS_TIF_V2, CFG.ENCODER_NAME_TIF_V2)
print("Models loaded.")

# Writes the selected values to CFG.THRESHOLD and CFG.MIN_AREA.
calibrate_params(model_r, model_t)

# Clean up memory: drop the model references, collect garbage and ask
# PyTorch to release its cached CUDA memory.
del model_r, model_t
gc.collect()
torch.cuda.empty_cache()

print("\nCFG updated with calibrated parameters:")
print(f"CFG.THRESHOLD = {CFG.THRESHOLD}")
print(f"CFG.MIN_AREA = {CFG.MIN_AREA}")

In [None]:
# --- 6. Run Final Inference & Generate Submission ---
# main() loads both models itself and reads CFG.THRESHOLD / CFG.MIN_AREA at
# call time, so it uses whatever values CFG holds when this cell runs.
print("\nStarting final inference on test fragments...")
main()
print("\nInference complete. submission.csv is ready.")