# 🏛️ Vesuvius Challenge - Surface Detection with nnUNet (Runpods)

**Purpose:** 3D semantic segmentation using nnUNetv2 for detecting papyrus surfaces in CT scan volumes

**Runpods Version:** Modified from Kaggle notebook for Runpods environment

### Key Features
1. **Kaggle Data Download** - Automatic dataset download via Kaggle API
2. **Native TIFF support** - Custom SimpleTiffIO reader, no NIfTI conversion needed
3. **Pre-processed data caching** - Skip 1-2 hour preprocessing
4. **Multi-GPU support** - DDP training with auto-detection
5. **Runpods optimized paths** - Uses `/workspace` for persistent storage

### Environment Requirements
- Runpods with Network Volume mounted at `/workspace`
- CUDA-enabled GPU (RTX 3080/4090/A6000 recommended)
- 50GB+ storage space

In [None]:
import os
import json
import shutil
import subprocess
from functools import partial
from multiprocessing import Pool
from pathlib import Path
from typing import Optional, Tuple, List, Literal, Union

# Clear display function
def print_header(text: str, emoji: str = "üìã", width: int = 80):
    """Print a section banner.

    Emits a blank line, a rule of ``width`` '=' characters, the title line
    ``"<emoji> <text>"`` and a closing rule of the same width.
    """
    rule = "=" * width
    print(f"\n{rule}")
    print(f"{emoji} {text}")
    print(rule)

def print_status(key: str, value: str, indent: int = 2):
    """Print one bulleted ``key: value`` line, indented by ``indent`` spaces."""
    padding = " " * indent
    print(padding + f"‚Ä¢ {key}: {value}")

# Cap the thread pools of the native math libraries at 4 threads each
print_header("ENVIRONMENT SETUP", "‚öôÔ∏è")

print("\nüîß Fixing OpenBLAS thread issues...")
os.environ.update(dict.fromkeys(
    (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    ),
    "4",
))
print_status("Thread limit", "4 (prevents pthread_create errors)")

# Runpods-specific paths: everything lives under the /workspace volume
print("\nüìÅ Setting up directories...")
WORKSPACE = Path("/workspace")
INPUT_DIR, WORKING_DIR, OUTPUT_DIR = (
    WORKSPACE / sub for sub in ("vesuvius_data", "temp", "results")
)

print_status("Workspace", f"{WORKSPACE}")
print_status("Input", f"{INPUT_DIR}")
print_status("Output", f"{OUTPUT_DIR}")

# nnUNet directory structure (raw + preprocessed share one base; results go to OUTPUT_DIR)
print("\nüèóÔ∏è nnUNet directory structure...")
NNUNET_BASE = WORKSPACE / "nnUNet_data"
NNUNET_RAW, NNUNET_PREPROCESSED = (
    NNUNET_BASE / sub for sub in ("nnUNet_raw", "nnUNet_preprocessed")
)
NNUNET_RESULTS = OUTPUT_DIR / "nnUNet_results"

print_status("Raw data", f"{NNUNET_RAW}")
print_status("Preprocessed", f"{NNUNET_PREPROCESSED}")
print_status("Results", f"{NNUNET_RESULTS}")

# Dataset configuration: the folder name embeds the zero-padded dataset id
print("\nüìä Dataset configuration...")
DATASET_ID = 100
DATASET_NAME = "Dataset{:03d}_VesuviusSurface".format(DATASET_ID)
print_status("Dataset ID", f"{DATASET_ID}")
print_status("Dataset name", DATASET_NAME)

# Training configuration - DEFAULT VALUES (the auto-config code below may replace them)
print("\nüéØ Default training configuration...")
FOLD = "all"
CONFIGURATION = "3d_fullres"  # Start with fullres by default
PLANNER = "nnUNetPlannerResEncM"  # Medium planner for compatibility
PLANS_NAME = "nnUNetResEncUNetMPlans"  # Match the planner
EPOCHS = 250
# Never more than 4 workers; assume 4 cores when the CPU count is unknown
NUM_WORKERS = min(os.cpu_count() or 4, 4)

print_status("Configuration", f"{CONFIGURATION} (default)")
print_status("Planner", f"{PLANNER} (default)")
print_status("Epochs", f"{EPOCHS}")
print_status("Workers", f"{NUM_WORKERS}")
print_status("Fold", FOLD)
# GPU detection
def _get_gpu_count() -> int:
    try:
        import torch
        return torch.cuda.device_count() if torch.cuda.is_available() else 1
    except ImportError:
        return 0

NUM_GPUS = _get_gpu_count()

print("\nüéÆ Hardware detection...")
print_status("GPUs detected", f"{NUM_GPUS}")
print_status("CPU cores", f"{os.cpu_count()}")

# AUTO-CONFIGURE BASED ON GPU AND PREPROCESSED DATA
# (print_header emits the same blank line / rule / title / rule sequence)
print_header("AUTO-CONFIGURING FOR GPU AND AVAILABLE DATA...", "üéÆ", 80)

# Check what plans are actually available
def check_available_plans():
    """Return the names of the plans files preprocessed for this dataset.

    Looks at ``NNUNET_PREPROCESSED / DATASET_NAME`` and returns the stem of
    every ``*.json`` file found there, excluding the JSON files nnU-Net stores
    next to the plans that are not plans themselves (dataset description,
    fingerprint and cross-validation splits).

    Returns:
        Sorted list of plans names (e.g. ``"nnUNetPlans"``); empty when the
        dataset has not been preprocessed yet.
    """
    preprocessed_path = NNUNET_PREPROCESSED / DATASET_NAME
    if not preprocessed_path.is_dir():
        return []

    # Without this filter these files were listed as selectable "plans".
    non_plan_files = {"dataset", "dataset_fingerprint", "splits_final"}
    return sorted(
        plan_file.stem
        for plan_file in preprocessed_path.glob("*.json")
        if plan_file.stem not in non_plan_files
    )

# Choose CONFIGURATION / PLANNER / PLANS_NAME from the detected GPU memory and
# from the plans that are already preprocessed. Any failure falls back to the
# medium ResEnc defaults.
try:
    import torch
    available_plans = check_available_plans()
    
    print(f"\nüìã Available plan files:")
    for plan in available_plans:
        print_status("Plan", plan)
    
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        
        print(f"\nüìä GPU Information:")
        print_status("Model", gpu_name)
        print_status("VRAM", f"{vram_gb:.1f}GB")
        
        # nnUNetPlannerResEncL writes "nnUNetResEncUNetLPlans" (same pattern as
        # the medium preset). The shorter spelling checked before never matched
        # that file; it is still accepted in case plans were saved under it.
        large_plans = next(
            (name for name in ("nnUNetResEncUNetLPlans", "nnUNetResEncLPlans")
             if name in available_plans),
            None,
        )
        
        # Use available plans or fallback to compatible settings
        if large_plans is not None and vram_gb >= 40:
            # Large model available and high-end GPU
            CONFIGURATION = "3d_fullres"
            PLANNER = "nnUNetPlannerResEncL"
            PLANS_NAME = large_plans
            print("\n‚úÖ HIGH-END GPU + LARGE PLANS AVAILABLE")
            print_status("Configuration", "3d_fullres")
            print_status("Planner", "nnUNetPlannerResEncL (Large)")
            print_status("Estimated training time", "8-12 hours")
            
        elif "nnUNetResEncUNetMPlans" in available_plans and vram_gb >= 20:
            # Medium model available and sufficient GPU
            CONFIGURATION = "3d_fullres"
            PLANNER = "nnUNetPlannerResEncM"
            PLANS_NAME = "nnUNetResEncUNetMPlans"
            print("\n‚úÖ MID-RANGE GPU + MEDIUM PLANS AVAILABLE")
            print_status("Configuration", "3d_fullres")
            print_status("Planner", "nnUNetPlannerResEncM (Medium)")
            print_status("Estimated training time", "12-18 hours")
            
        elif "nnUNetPlans" in available_plans:
            # Default plans available
            CONFIGURATION = "3d_fullres" if vram_gb >= 20 else "3d_lowres"
            # "nnUNetPlans" is produced by nnU-Net's default planner class,
            # ExperimentPlanner; "nnUNetPlanner" is not a planner that
            # nnUNetv2_plan_and_preprocess -pl can resolve.
            PLANNER = "ExperimentPlanner"
            PLANS_NAME = "nnUNetPlans"
            print("\n‚úÖ STANDARD PLANS AVAILABLE")
            print_status("Configuration", CONFIGURATION)
            print_status("Planner", "ExperimentPlanner (Standard)")
            
        else:
            # No usable preprocessed plans for this GPU - need to preprocess first
            print("\n‚ö†Ô∏è NO PREPROCESSED DATA FOUND")
            print_status("Action required", "Run preprocessing first")
            
            # Set safe defaults for preprocessing
            CONFIGURATION = "3d_fullres" if vram_gb >= 20 else "3d_lowres"
            PLANNER = "nnUNetPlannerResEncM"
            PLANS_NAME = "nnUNetResEncUNetMPlans"
            
    else:
        print("\n‚ö†Ô∏è No GPU detected - using CPU fallback")
        CONFIGURATION = "3d_lowres"
        PLANNER = "nnUNetPlannerResEncM"
        PLANS_NAME = "nnUNetResEncUNetMPlans"
        
except Exception as e:
    # Covers a missing torch install as well as any CUDA query failure
    print(f"\n‚ö†Ô∏è Auto-configuration failed: {e}")
    print("Using safe default configuration")
    CONFIGURATION = "3d_fullres"
    PLANNER = "nnUNetPlannerResEncM"  
    PLANS_NAME = "nnUNetResEncUNetMPlans"

# Final configuration summary (print_header prints the same rule/title/rule banner)
print_header("FINAL CONFIGURATION", "‚úÖ", 80)
print_status("Mode", CONFIGURATION)
print_status("Planner", PLANNER)
print_status("Plans", PLANS_NAME)
print_status("Epochs", f"{EPOCHS}")
print_status("Fold", FOLD)
print_status("Workers", f"{NUM_WORKERS}")

# Show manual override options: one ready-to-paste line per plans file found
print("\nüí° Manual override (if needed):")
print("   # Use existing preprocessed plans:")
available_plans = check_available_plans()
for plan in available_plans[:3]:  # Show first 3 options
    # Names containing "fullres", "M" or "L" are suggested with the fullres configuration
    if any(tag in plan for tag in ("fullres", "M", "L")):
        config_type = "3d_fullres"
    else:
        config_type = "3d_lowres"
    print(f"   CONFIGURATION = '{config_type}'; PLANS_NAME = '{plan}'")

print_header("Environment setup complete!", "‚úÖ", 80)

In [None]:
# A6000 optimized configuration cell
def configure_for_gpu():
    """Re-select CONFIGURATION, PLANNER and PLANS_NAME from the GPU's VRAM.

    Updates the module-level settings in place:

    * ``CONFIGURATION`` follows the VRAM tier (``3d_lowres`` below 20 GB or
      when no GPU is present, ``3d_fullres`` otherwise).
    * ``PLANNER`` and ``PLANS_NAME`` are always changed together so that
      preprocessing (which uses ``PLANNER``) and training (which uses
      ``PLANS_NAME``) refer to the same plans.
    * If the currently selected plans are already preprocessed and the plans
      of the VRAM-preferred planner are not, the current pair is kept.

    Any exception leaves the settings as they were and is only reported.
    """
    global CONFIGURATION, PLANNER, PLANS_NAME
    
    # Plans name written by each ResEnc planner this function can choose
    plans_for_planner = {
        "nnUNetPlannerResEncL": "nnUNetResEncUNetLPlans",
        "nnUNetPlannerResEncM": "nnUNetResEncUNetMPlans",
    }
    
    print("\nüéÆ AUTO-CONFIGURING FOR GPU...")
    
    try:
        import torch
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            
            print(f"   GPU detected: {gpu_name}")
            print(f"   VRAM: {vram_gb:.1f}GB")
            
            # Auto-select configuration based on VRAM
            if vram_gb >= 40:  # A6000, A100
                CONFIGURATION = "3d_fullres"
                preferred_planner = "nnUNetPlannerResEncL"  # Large model for A6000
                print("   ‚úÖ Using 3d_fullres (High-end GPU)")
                print("   üí° Patch size: up to 160x160x160")
                print("   üí° Batch size: 2-4")
            elif vram_gb >= 20:  # RTX 4090, 3090
                CONFIGURATION = "3d_fullres"
                preferred_planner = "nnUNetPlannerResEncM"  # Medium model
                print("   ‚úÖ Using 3d_fullres (Mid-range GPU)")
                print("   üí° Patch size: 96x96x96")
                print("   üí° Batch size: 2")
            else:  # T4, RTX 3080
                CONFIGURATION = "3d_lowres"
                preferred_planner = "nnUNetPlannerResEncM"
                print("   ‚úÖ Using 3d_lowres (Entry-level GPU)")
                print("   üí° Patch size: 64x64x64")
                print("   üí° Batch size: 1-2")
            
            preferred_plans = plans_for_planner[preferred_planner]
            preprocessed = set(check_available_plans())
            if PLANS_NAME in preprocessed and preferred_plans not in preprocessed:
                # Switching would point training at plans that do not exist yet
                print(f"   Keeping already preprocessed plans: {PLANS_NAME}")
            else:
                # Change both together; updating only PLANNER left training
                # pointing at plans the chosen planner never creates.
                PLANNER = preferred_planner
                PLANS_NAME = preferred_plans
        else:
            # Make the setting match what the message announces
            CONFIGURATION = "3d_lowres"
            print("   ‚ö†Ô∏è No GPU detected, using default 3d_lowres")
    except Exception as e:
        print(f"   ‚ö†Ô∏è Auto-config failed: {e}")
        print("   Using default configuration")
    
    print(f"\n   Final configuration:")
    print(f"   ‚Ä¢ Mode: {CONFIGURATION}")
    print(f"   ‚Ä¢ Planner: {PLANNER}")
    print(f"   ‚Ä¢ Plans: {PLANS_NAME}")
    print(f"   ‚Ä¢ Epochs: {EPOCHS}")

# Run auto-configuration
configure_for_gpu()

In [None]:
# Install required packages
!pip install kaggle nnunetv2 nibabel tifffile tqdm -q

# Create directories
for directory in [INPUT_DIR, WORKING_DIR, OUTPUT_DIR, NNUNET_RAW, NNUNET_PREPROCESSED, NNUNET_RESULTS]:
    directory.mkdir(parents=True, exist_ok=True)

print("‚úÖ Packages installed and directories created")

In [None]:
def setup_kaggle_auth() -> bool:
    """
    Locate kaggle.json, install it in ~/.kaggle and authenticate.

    The first existing file among ~/.kaggle/kaggle.json,
    WORKSPACE/.kaggle/kaggle.json, WORKSPACE/kaggle.json and ./kaggle.json is
    used. It is copied to ~/.kaggle/kaggle.json when it lives elsewhere, the
    copy's mode is set to 0o600, and the Kaggle API client is authenticated.

    Returns:
        True when authentication succeeds; False when no config file is found
        or importing/authenticating the Kaggle client raises.
    """
    kaggle_dir = Path.home() / ".kaggle"
    standard_path = kaggle_dir / "kaggle.json"
    possible_paths = [
        Path.home() / ".kaggle" / "kaggle.json",
        WORKSPACE / ".kaggle" / "kaggle.json",
        WORKSPACE / "kaggle.json",
        Path("./kaggle.json"),
    ]

    found = next((candidate for candidate in possible_paths if candidate.exists()), None)
    if found is None:
        print("‚ùå Kaggle config not found. Place kaggle.json in one of:")
        for candidate in possible_paths:
            print(f"  ‚Ä¢ {candidate}")
        return False

    print(f"‚úÖ Kaggle config found: {found}")

    # Copy to standard location
    kaggle_dir.mkdir(exist_ok=True)
    if found != standard_path:
        shutil.copy2(found, standard_path)
        print(f"üìÅ Copied to standard location: {standard_path}")

    os.chmod(standard_path, 0o600)

    try:
        import kaggle
        kaggle.api.authenticate()
    except Exception as e:
        print(f"‚ùå Kaggle authentication failed: {e}")
        return False

    print("‚úÖ Kaggle authentication successful")
    return True

kaggle_ready = setup_kaggle_auth()

In [None]:
def download_vesuvius_data():
    """
    Download and unpack the Vesuvius Challenge data through the Kaggle API.

    Two sources are handled: the competition files (required) and a dataset of
    preprocessed nnU-Net data (optional). A source whose target folder already
    holds files is skipped. Downloaded ``*.zip`` archives are extracted in
    place and then deleted.

    Returns:
        bool: True when the required competition data is available. A failure
        of the optional dataset is reported but does not turn the result into
        False, so dataset preparation can still proceed. False when Kaggle
        authentication is missing or the competition download failed.
    """
    if not kaggle_ready:
        print("‚ùå Kaggle authentication required")
        return False
    
    import kaggle
    import zipfile
    import time
    
    print_header("DATA DOWNLOAD", "üì•")
    
    # "size_mb" is only the estimate shown to the user before downloading
    DATASETS = {
        "competition": {
            "name": "vesuvius-challenge-surface-detection",
            "type": "competition",
            "size_mb": 500,
            "path": INPUT_DIR / "competition"
        },
        "preprocessed": {
            "name": "jirkaborovec/vesuvius-surface-nnunet-preprocessed",
            "type": "dataset",
            "size_mb": 2000,
            "path": INPUT_DIR / "preprocessed"
        }
    }
    # Keys of DATASETS whose failure aborts the whole download
    REQUIRED = {"competition"}
    
    def format_size(mb: float) -> str:
        """Format size in MB/GB"""
        if mb >= 1024:
            return f"{mb/1024:.1f}GB"
        return f"{mb:.0f}MB"
    
    def extract_archives(folder: Path) -> None:
        """Extract every zip archive in ``folder`` in place, then delete it."""
        for zip_file in sorted(folder.glob("*.zip")):
            print(f"   üìÇ Extracting: {zip_file.name}")
            
            with zipfile.ZipFile(zip_file, 'r') as zip_ref:
                members = zip_ref.namelist()
                total_files = len(members)
                print(f"   Files to extract: {total_files}")
                
                # Report roughly every 10% of the members
                report_every = max(1, total_files // 10)
                for i, member in enumerate(members):
                    if i % report_every == 0:
                        percent = (i / total_files) * 100
                        print(f"   Progress: {percent:.0f}% ({i}/{total_files} files)")
                    zip_ref.extract(member, folder)
            
            print(f"   ‚úÖ Extraction complete")
            zip_file.unlink()  # Delete zip
    
    def download_dataset(dataset_info: dict) -> bool:
        """Download a single dataset with progress"""
        name = dataset_info["name"]
        dtype = dataset_info["type"]
        path = dataset_info["path"]
        
        # Check if already exists
        check_path = path / "train_images" if dtype == "competition" else path
        if check_path.is_dir() and any(check_path.iterdir()):
            print(f"\n‚úÖ Already downloaded: {name}")
            return True
        
        print(f"\nüì¶ Downloading: {name}")
        print(f"   Expected size: {format_size(dataset_info['size_mb'])}")
        print(f"   Destination: {path}")
        
        path.mkdir(parents=True, exist_ok=True)
        
        try:
            start_time = time.time()
            
            if dtype == "competition":
                print("   Status: Downloading competition data...")
                # The full slug is the competition identifier; stripping the
                # "vesuvius-challenge-" prefix requested a competition that
                # does not exist.
                kaggle.api.competition_download_files(
                    name,
                    path=str(path),
                    quiet=False
                )
            else:
                print("   Status: Downloading dataset...")
                kaggle.api.dataset_download_files(
                    name,
                    path=str(path),
                    quiet=False,
                    unzip=False
                )
            
            elapsed = time.time() - start_time
            print(f"   ‚úÖ Download complete in {int(elapsed//60):02d}:{int(elapsed%60):02d}")
            
            extract_archives(path)
            return True
            
        except Exception as e:
            print(f"   ‚ùå Error: {e}")
            return False
    
    # Download each dataset
    failed = []
    total_datasets = len(DATASETS)
    
    for i, (key, dataset_info) in enumerate(DATASETS.items(), 1):
        print(f"\n{'='*60}")
        print(f"üìä Dataset {i}/{total_datasets}: {key.upper()}")
        print(f"{'='*60}")
        
        if download_dataset(dataset_info):
            continue
        
        failed.append(key)
        if key in REQUIRED:
            print("‚ö†Ô∏è Competition data is required. Stopping download.")
            break
        print("‚ö†Ô∏è Optional dataset failed. Continuing...")
    
    # Only a failed required dataset makes the overall result False
    required_ok = not any(key in REQUIRED for key in failed)
    
    # Summary
    print("\n" + "="*80)
    if failed:
        print("‚ö†Ô∏è PARTIAL DOWNLOAD - Some datasets failed")
    else:
        print("‚úÖ DATA DOWNLOAD COMPLETE")
    
    if required_ok:
        # Show what was downloaded
        print("\nüìÅ Downloaded files:")
        for key, dataset_info in DATASETS.items():
            path = dataset_info["path"]
            if path.exists():
                size = sum(f.stat().st_size for f in path.rglob("*") if f.is_file()) / (1024*1024)
                print(f"   ‚Ä¢ {key}: {format_size(size)} at {path}")
    print("="*80)
    
    return required_ok

# Download data
if kaggle_ready:
    data_downloaded = download_vesuvius_data()
else:
    print("\n‚ö†Ô∏è Skipping data download - Kaggle authentication not available")
    data_downloaded = False

In [None]:
# Export the nnUNet settings as environment variables in a single update
os.environ.update({
    "nnUNet_raw": str(NNUNET_RAW),
    "nnUNet_preprocessed": str(NNUNET_PREPROCESSED),
    "nnUNet_results": str(NNUNET_RESULTS),
    "nnUNet_USE_BLOSC2": "1",  # Faster compression
    "nnUNet_compile": "true",  # Enable torch.compile
})

print("üîß nnUNet environment configured:")
for _label, _location in (
    ("Raw", NNUNET_RAW),
    ("Preprocessed", NNUNET_PREPROCESSED),
    ("Results", NNUNET_RESULTS),
):
    print(f"  {_label}: {_location}")
del _label, _location  # keep the notebook namespace as it was

# Import after setting environment
import nibabel as nib
import numpy as np
import tifffile
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

print(f"‚úÖ Environment ready with {NUM_GPUS} GPU(s)")

In [None]:
def fix_dataset_json_ioclass(dataset_path: Path) -> bool:
    """Strip reader overrides from ``dataset_path / "dataset.json"``.

    Removes the ``ioclass`` key, and removes ``overwrite_image_reader_writer``
    only when its value is ``"SimpleTiffIO"``. The file is rewritten (indent 4)
    only if something was removed.

    Returns:
        False when dataset.json does not exist or reading/writing it raises;
        True otherwise, whether or not the file needed changes.
    """
    json_path = dataset_path / "dataset.json"
    if not json_path.exists():
        return False

    reader_key = 'overwrite_image_reader_writer'
    try:
        with open(json_path, 'r') as source:
            config = json.load(source)

        changed = False

        # Remove problematic settings
        if 'ioclass' in config:
            del config['ioclass']
            changed = True
            print("  ‚úÖ Removed 'ioclass' setting")

        if reader_key in config and config[reader_key] == 'SimpleTiffIO':
            del config[reader_key]
            changed = True
            print("  ‚úÖ Removed SimpleTiffIO override")

        if not changed:
            return True

        with open(json_path, 'w') as target:
            json.dump(config, target, indent=4)
        print(f"  ‚úÖ Fixed dataset.json at {json_path}")
        return True
    except Exception as e:
        print(f"  ‚ùå Error fixing dataset.json: {e}")
        return False

def create_spacing_json(output_path: Path, shape: tuple, spacing: tuple = (1.0, 1.0, 1.0)):
    """Write a JSON sidecar ``{"spacing": [...]}`` to ``output_path``.

    ``shape`` is accepted for the caller's convenience but is not written to
    the file.
    """
    sidecar = {"spacing": [*spacing]}
    with open(output_path, "w") as handle:
        json.dump(sidecar, handle)

def create_dataset_json(output_dir: Path, num_training: int, file_ending: str = ".tif") -> dict:
    """Write ``output_dir / "dataset.json"`` and return its content.

    The description has one channel ("CT"), the labels background=0,
    surface=1 and ignore=2, the number of training cases and the file ending.
    No ``overwrite_image_reader_writer`` entry is written (left out to avoid
    the SimpleTiffIO error).
    """
    dataset_json = dict(
        channel_names={"0": "CT"},
        labels=dict(background=0, surface=1, ignore=2),
        numTraining=num_training,
        file_ending=file_ending,
    )

    with open(output_dir / "dataset.json", "w") as handle:
        json.dump(dataset_json, handle, indent=4)

    print(f"Created dataset.json: {num_training} training cases")
    return dataset_json

def prepare_single_case(src_path: Path, dest_path: Path, json_path: Path, use_symlinks: bool = True) -> bool:
    """Place one TIFF at ``dest_path`` and write its spacing sidecar.

    Args:
        src_path: Source TIFF file.
        dest_path: Where the file is linked (``use_symlinks=True``) or copied.
        json_path: Path of the JSON sidecar created with the default spacing.
        use_symlinks: Symlink to the resolved source instead of copying.

    Returns:
        True on success; False when any step raises (the error is printed).
    """
    try:
        # Get shape for JSON: single-page files keep the page shape, multi-page
        # files get the page count prepended
        with tifffile.TiffFile(src_path) as tif:
            shape = tif.pages[0].shape if len(tif.pages) == 1 else (len(tif.pages), *tif.pages[0].shape)
        
        # Link or copy file
        if use_symlinks:
            # exists() follows links, so a dangling link left by an earlier run
            # looks absent and symlink_to() would then fail with
            # FileExistsError; drop it first.
            if dest_path.is_symlink() and not dest_path.exists():
                dest_path.unlink()
            if not dest_path.exists():
                dest_path.symlink_to(src_path.resolve())
        else:
            # copy2() would write through an existing link into its target
            # (or raise SameFileError when it points at src_path).
            if dest_path.is_symlink():
                dest_path.unlink()
            shutil.copy2(src_path, dest_path)
        
        # Create JSON sidecar
        create_spacing_json(json_path, shape)
        return True
    
    except Exception as e:
        print(f"Error processing {src_path.name}: {e}")
        return False

print("‚úÖ Utility functions loaded (SimpleTiffIO issue fixed)")

In [None]:
def prepare_dataset(input_base_dir: Path, max_cases: Optional[int] = None, use_symlinks: bool = True):
    """
    Lay out the competition training data in nnUNet raw format.

    Reads ``<input_base_dir>/competition/train_images/*.tif`` with same-named
    labels from ``train_labels`` and populates ``imagesTr`` (``<case>_0000.tif``)
    and ``labelsTr`` (``<case>.tif``) under ``NNUNET_RAW / DATASET_NAME``, each
    with a JSON spacing sidecar, then writes dataset.json.

    Args:
        input_base_dir: Folder containing the ``competition`` directory.
        max_cases: Keep only the first N images (sorted); a falsy value keeps all.
        use_symlinks: Link instead of copying the TIFF files.

    Returns:
        The dataset directory, or None when the training images folder is missing.
    """
    # NOTE(review): the image sidecar is written as <case>_0000.json; confirm
    # that the nnU-Net TIFF reader in use looks for that name.
    dataset_dir = NNUNET_RAW / DATASET_NAME
    images_dir, labels_dir = dataset_dir / "imagesTr", dataset_dir / "labelsTr"
    for folder in (images_dir, labels_dir):
        folder.mkdir(parents=True, exist_ok=True)

    # Look for competition data
    competition_dir = input_base_dir / "competition"
    train_images_dir = competition_dir / "train_images"
    train_labels_dir = competition_dir / "train_labels"

    if not train_images_dir.exists():
        print(f"‚ùå Training images not found: {train_images_dir}")
        return None

    image_files = sorted(train_images_dir.glob("*.tif"))
    if max_cases:
        image_files = image_files[:max_cases]

    transfer_mode = 'symlinks' if use_symlinks else 'copy'
    print(f"Found {len(image_files)} training cases")
    print(f"Using {transfer_mode}")

    success_count = 0
    for img_path in tqdm(image_files, desc="Preparing dataset"):
        case_id = img_path.stem
        label_path = train_labels_dir / img_path.name

        # Cases without a label are reported and left out entirely
        if not label_path.exists():
            print(f"Warning: No label for {case_id}")
            continue

        image_stem = f"{case_id}_0000"
        img_ok = prepare_single_case(
            img_path,
            images_dir / f"{image_stem}.tif",
            images_dir / f"{image_stem}.json",
            use_symlinks,
        )
        label_ok = prepare_single_case(
            label_path,
            labels_dir / f"{case_id}.tif",
            labels_dir / f"{case_id}.json",
            use_symlinks,
        )

        # A case only counts when both its image and its label were prepared
        if img_ok and label_ok:
            success_count = success_count + 1

    create_dataset_json(dataset_dir, success_count, file_ending=".tif")
    print(f"‚úÖ Dataset prepared: {success_count} cases")
    return dataset_dir

# Prepare dataset if data was downloaded
if not data_downloaded:
    print("Skipping dataset preparation - no data available")
    dataset_dir = None
else:
    dataset_dir = prepare_dataset(INPUT_DIR, max_cases=50)  # Limit for testing

In [None]:
def run_command(cmd: str, name: str = "Command", timeout: Optional[int] = None) -> bool:
    """Run a shell command and render its output as a text progress display.

    The command's combined stdout/stderr is read line by line. Each line is
    scanned with heuristic patterns for epoch, batch, Dice and loss values,
    which feed the progress bars; lines containing keywords such as "error"
    or "saved" are echoed immediately.

    Args:
        cmd: Command line, executed through the shell.
        name: Label used in the printed messages.
        timeout: Maximum run time in seconds. When it elapses the command
            (on POSIX, its whole process group) is killed and False is
            returned. None disables the limit.

    Returns:
        True if the command exited with status 0, otherwise False (non-zero
        exit, timeout, or an error while running or monitoring the command).
    """
    import re
    import signal
    import subprocess
    import threading
    import time
    from collections import deque
    
    print(f"üöÄ Starting: {name}")
    print(f"üìù Command: {cmd[:100]}..." if len(cmd) > 100 else f"üìù Command: {cmd}")
    print("-" * 80)
    
    # Patterns are applied to the lower-cased line and compiled once.
    # The number pattern only captures text float() can parse: a trailing
    # period ("dice: 0.85.") or a bare "..." used to raise ValueError and
    # abort monitoring.
    number = r'(\d*\.?\d+)'
    epoch_re = re.compile(r'epoch[:\s]+(\d+)')
    batch_res = [
        re.compile(r'batch[:\s]+(\d+)[/\s]+(\d+)'),  # batch: 45/200
        re.compile(r'(\d+)[/](\d+)'),  # 45/200
        re.compile(r'step[:\s]+(\d+)[/\s]+(\d+)'),  # step: 45/200
    ]
    total_epochs_re = re.compile(r'(\d+)\s*epochs?')
    # "mean_dice", "validation_dice" and "train_loss" contain these, too
    dice_re = re.compile(r'dice[:\s]+' + number)
    loss_re = re.compile(r'loss[:\s]+' + number)
    important_keywords = ('error', 'failed', 'saved', 'completed', 'best', 'validation')
    
    # Progress tracking variables
    current_epoch = 0
    total_epochs = 250
    current_batch = 0
    total_batches = 0
    best_dice = 0.0
    current_loss = 0.0
    validation_dice = 0.0
    start_time = time.time()
    epoch_start_time = time.time()
    
    def format_time(seconds: float) -> str:
        """Format seconds to HH:MM:SS"""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        if hours > 0:
            return f"{hours:02d}:{minutes:02d}:{secs:02d}"
        return f"{minutes:02d}:{secs:02d}"
    
    def print_epoch_progress():
        """Print overall epoch progress"""
        elapsed = time.time() - start_time
        
        if current_epoch > 0:
            progress_pct = (current_epoch / total_epochs) * 100
            time_per_epoch = elapsed / current_epoch
            eta = (total_epochs - current_epoch) * time_per_epoch
        else:
            progress_pct = 0
            eta = 0
        
        # Epoch progress bar
        bar_width = 30
        filled = int(bar_width * progress_pct / 100)
        epoch_bar = "‚ñà" * filled + "‚ñë" * (bar_width - filled)
        
        print(f"\rüìä Epoch {current_epoch:3d}/{total_epochs} [{epoch_bar}] {progress_pct:5.1f}% | "
              f"Best Dice: {best_dice:.3f} | Elapsed: {format_time(elapsed)} | ETA: {format_time(eta)}", 
              end="", flush=True)
    
    def print_batch_progress():
        """Print current epoch batch progress"""
        if total_batches > 0:
            batch_pct = (current_batch / total_batches) * 100
            epoch_elapsed = time.time() - epoch_start_time
            
            if current_batch > 0:
                time_per_batch = epoch_elapsed / current_batch
                batch_eta = (total_batches - current_batch) * time_per_batch
            else:
                batch_eta = 0
            
            # Batch progress bar
            bar_width = 25
            filled = int(bar_width * batch_pct / 100)
            batch_bar = "‚ñà" * filled + "‚ñë" * (bar_width - filled)
            
            print(f"\r  ‚è≥ Batch {current_batch:3d}/{total_batches} [{batch_bar}] {batch_pct:5.1f}% | "
                  f"Loss: {current_loss:.3f} | Batch ETA: {format_time(batch_eta)}", 
                  end="", flush=True)
    
    def kill_process_tree(proc) -> None:
        """Kill the command; on POSIX also everything in its process group."""
        if proc.poll() is not None:
            return
        try:
            if os.name == "posix":
                # The child leads its own session, so its pid is the group id
                os.killpg(proc.pid, signal.SIGKILL)
            else:
                proc.kill()
        except (ProcessLookupError, PermissionError):
            proc.kill()
    
    process = None
    watchdog = None
    timed_out = threading.Event()
    
    def on_timeout():
        """Watchdog callback: record the timeout and stop the command."""
        timed_out.set()
        if process is not None:
            kill_process_tree(process)
    
    try:
        # Start process in its own session (POSIX) so the shell and the
        # workers it spawns can be stopped together.
        process = subprocess.Popen(
            cmd, shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # undecodable bytes must not abort monitoring
            bufsize=1,
            env={**os.environ, 'PYTHONUNBUFFERED': '1'},
            start_new_session=(os.name == "posix"),
        )
        
        if timeout is not None:
            watchdog = threading.Timer(timeout, on_timeout)
            watchdog.daemon = True
            watchdog.start()
        
        # Only the tail is shown on failure, so only the tail is kept
        output_lines = deque(maxlen=5)
        last_update = time.time()
        
        for line in process.stdout:
            line = line.strip()
            if not line:
                continue
                
            output_lines.append(line)
            line_lower = line.lower()
            
            # Extract epoch information
            if 'epoch' in line_lower:
                epoch_match = epoch_re.search(line_lower)
                if epoch_match:
                    new_epoch = int(epoch_match.group(1))
                    if new_epoch > current_epoch:
                        if current_epoch > 0:  # Finish previous epoch display
                            print()  # New line after batch progress
                        current_epoch = new_epoch
                        current_batch = 0
                        total_batches = 0
                        epoch_start_time = time.time()
                        print_epoch_progress()
            
            # Extract batch information (first matching pattern wins)
            for batch_re in batch_res:
                batch_match = batch_re.search(line_lower)
                if batch_match:
                    current_batch = int(batch_match.group(1))
                    total_batches = int(batch_match.group(2))
                    break
            
            # Extract total epochs
            if 'training' in line_lower and 'epoch' in line_lower:
                epochs_match = total_epochs_re.search(line_lower)
                if epochs_match:
                    parsed_total = int(epochs_match.group(1))
                    # A total of 0 would divide by zero in the progress display
                    if parsed_total > 0:
                        total_epochs = parsed_total
            
            # Extract metrics
            dice_match = dice_re.search(line_lower)
            if dice_match:
                dice_val = float(dice_match.group(1))
                if 'val' in line_lower:
                    validation_dice = dice_val
                if dice_val > best_dice:
                    best_dice = dice_val
            
            loss_match = loss_re.search(line_lower)
            if loss_match:
                current_loss = float(loss_match.group(1))
            
            # Update progress displays
            now = time.time()
            if now - last_update > 5:  # Update every 5 seconds for batch progress
                if current_batch > 0 and total_batches > 0:
                    print()  # New line after epoch progress
                    print_batch_progress()
                else:
                    print_epoch_progress()
                last_update = now
            
            # Show important messages immediately
            if any(keyword in line_lower for keyword in important_keywords):
                print(f"\nüìù {line}")
        
        # Wait for process to complete
        process.wait()
        elapsed_total = time.time() - start_time
        
        print()  # New line after progress bar
        print("-" * 80)
        
        if process.returncode == 0:
            print(f"‚úÖ {name} completed successfully!")
            print(f"üìä Final: {current_epoch} epochs | Best Dice: {best_dice:.4f} | Val Dice: {validation_dice:.4f}")
            print(f"‚è∞ Duration: {format_time(elapsed_total)}")
            
            if best_dice > 0:
                print(f"üéâ Model saved with best Dice score: {best_dice:.4f}")
            
            return True
        else:
            print(f"‚ùå {name} failed (exit code: {process.returncode})")
            if timed_out.is_set():
                print(f"   Timed out after {timeout} seconds; the process was killed")
            print(f"üìä Progress: {current_epoch}/{total_epochs} epochs")
            print(f"‚è∞ Duration: {format_time(elapsed_total)}")
            
            # Show last few lines of output for debugging
            if output_lines:
                print("\nüîç Last output lines:")
                for line in output_lines:
                    print(f"   {line}")
            
            return False
            
    except Exception as e:
        print(f"\nüí• {name} ERROR: {e}")
        return False
    
    finally:
        # Runs on every exit path (including KeyboardInterrupt) so that no
        # watchdog thread, child process or pipe is left behind.
        if watchdog is not None:
            watchdog.cancel()
        if process is not None:
            if process.poll() is None:
                kill_process_tree(process)
                process.wait()
            if process.stdout is not None:
                process.stdout.close()

# Keep the same function signatures for preprocessing, training, and inference
def run_preprocessing(dataset_id: int = DATASET_ID, planner: str = PLANNER, num_workers: int = None) -> bool:
    """Run ``nnUNetv2_plan_and_preprocess`` for one dataset via run_command.

    Args:
        dataset_id: Dataset number, passed zero-padded to three digits.
        planner: Planner class name passed with ``-pl``.
        num_workers: Worker processes; defaults to NUM_WORKERS and is capped at 4.

    Returns:
        The result of run_command. The configuration passed with ``-c`` is the
        module-level CONFIGURATION at call time.
    """
    requested = NUM_WORKERS if num_workers is None else num_workers
    worker_count = min(requested, 4)

    cmd = " ".join([
        "nnUNetv2_plan_and_preprocess",
        f"-d {dataset_id:03d}",
        f"-np {worker_count}",
        f"-pl {planner}",
        f"-c {CONFIGURATION}",
    ])
    return run_command(cmd, f"Preprocessing (Dataset {dataset_id:03d})", timeout=7200)

def run_training(dataset_id: int = DATASET_ID, config: str = CONFIGURATION, 
                fold: Union[int, str] = FOLD, plans: str = PLANS_NAME, 
                epochs: Optional[int] = EPOCHS, num_gpus: int = NUM_GPUS) -> bool:
    """Run ``nnUNetv2_train`` via run_command.

    The trainer is ``nnUNetTrainer`` when ``epochs`` is None or 1000, otherwise
    ``nnUNetTrainer_<epochs>epochs``. ``-num_gpus`` is added only when more
    than one GPU is requested.

    Returns:
        The result of run_command.
    """
    uses_default_trainer = epochs is None or epochs == 1000
    trainer = "nnUNetTrainer" if uses_default_trainer else f"nnUNetTrainer_{epochs}epochs"

    parts = [
        f"PYTHONUNBUFFERED=1 nnUNetv2_train {dataset_id:03d} {config} {fold}",
        f"-p {plans} -tr {trainer}",
    ]
    if num_gpus > 1:
        parts.append(f"-num_gpus {num_gpus}")

    return run_command(" ".join(parts), f"Training ({epochs} epochs, {config})", timeout=86400)

def run_inference(input_dir: Path, output_dir: Path, dataset_id: int = DATASET_ID,
                 config: str = CONFIGURATION, fold: Union[int, str] = FOLD,
                 plans: str = PLANS_NAME, epochs: Optional[int] = EPOCHS) -> bool:
    """Build and launch the ``nnUNetv2_predict`` command.

    Args:
        input_dir: Directory handed to ``-i``.
        output_dir: Directory handed to ``-o``; created if missing.
        dataset_id: Numeric dataset id; rendered zero-padded to three digits.
        config: nnUNet configuration name.
        fold: Fold identifier placed on the command line as-is.
        plans: Plans identifier forwarded through ``-p``.
        epochs: ``None`` or ``1000`` selects ``nnUNetTrainer``; any other
            value selects ``nnUNetTrainer_<epochs>epochs``.

    Returns:
        The result of ``run_command`` for the assembled command (one-hour
        timeout).
    """
    import shlex  # local import: keeps the block self-contained

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    trainer = "nnUNetTrainer" if epochs is None or epochs == 1000 else f"nnUNetTrainer_{epochs}epochs"

    # Quote the paths so directories containing spaces or shell
    # metacharacters survive the command string; plain paths are unchanged.
    quoted_input = shlex.quote(str(input_dir))
    quoted_output = shlex.quote(str(output_dir))

    cmd = f"nnUNetv2_predict -d {dataset_id:03d} -c {config} -f {fold}"
    cmd += f" -i {quoted_input} -o {quoted_output} -p {plans} -tr {trainer}"
    cmd += " --save_probabilities --verbose"

    return run_command(cmd, "Inference", timeout=3600)

# Cell-load confirmation: printed once the command helpers above are defined.
print("‚úÖ Lightweight progress with epoch AND batch progress ready")

In [None]:
def prepare_test_data(input_base_dir: Path, output_dir: Path, use_symlinks: bool = True) -> Path:
    """Stage the competition test volumes for nnUNet inference.

    Looks for ``*.tif`` files under ``<input_base_dir>/competition/test_images``
    and passes each one to ``prepare_single_case`` together with the target
    ``<case>_0000.tif`` and ``<case>_0000.json`` paths inside ``output_dir``.

    Args:
        input_base_dir: Root directory that contains ``competition/test_images``.
        output_dir: Destination directory; created if missing.
        use_symlinks: Forwarded unchanged to ``prepare_single_case``.

    Returns:
        ``output_dir``, also when the test image directory does not exist.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    source_dir = input_base_dir / "competition" / "test_images"
    if source_dir.exists():
        volumes = sorted(source_dir.glob("*.tif"))
        print(f"Found {len(volumes)} test cases")

        for volume in tqdm(volumes, desc="Preparing test data"):
            stem = volume.stem
            image_target = output_dir / f"{stem}_0000.tif"
            sidecar_target = output_dir / f"{stem}_0000.json"
            prepare_single_case(volume, image_target, sidecar_target, use_symlinks)
    else:
        print(f"‚ùå Test images not found: {source_dir}")

    return output_dir

def predictions_to_tiff(pred_dir: Path, output_dir: Path):
    """Write the predictions found in ``pred_dir`` as uint8 TIFF files.

    ``*.npz`` files take precedence: the ``'probabilities'`` array is reduced
    with an argmax over axis 0 and saved as ``<stem>.tif``. When no ``.npz``
    file exists, ``*.tif`` files are re-read, cast to uint8 and written to
    ``output_dir``. If neither kind is present a message is printed.

    Args:
        pred_dir: Directory containing the prediction files.
        output_dir: Destination directory; created if missing.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so the processing order is stable from run to run.
    npz_files = sorted(pred_dir.glob("*.npz"))
    tif_files = sorted(pred_dir.glob("*.tif"))

    if npz_files:
        print(f"Converting {len(npz_files)} NPZ files to TIFF...")
        for npz_path in tqdm(npz_files, desc="Converting"):
            # np.load on an .npz keeps the archive open; the context manager
            # closes it once the array has been read into memory.
            with np.load(npz_path) as data:
                probs = data['probabilities']
            pred = np.argmax(probs, axis=0).astype(np.uint8)
            tifffile.imwrite(output_dir / f"{npz_path.stem}.tif", pred)
    elif tif_files:
        print(f"Copying {len(tif_files)} TIFF files...")
        for tif_path in tqdm(tif_files, desc="Copying"):
            pred = tifffile.imread(str(tif_path)).astype(np.uint8)
            tifffile.imwrite(output_dir / f"{tif_path.stem}.tif", pred)
    else:
        print(f"‚ùå No prediction files found in {pred_dir}")

# Cell-load confirmation: printed once the two helpers above are defined.
print("‚úÖ Test data and conversion functions ready")

In [None]:
def full_pipeline(do_preprocess: bool = True, do_train: bool = True, 
                 do_inference: bool = True, max_cases: Optional[int] = None):
    """Run data preparation, preprocessing, training and inference in order.

    Data preparation always runs; the other three stages are toggled by the
    flags. Progress and a final summary are printed to stdout.

    Args:
        do_preprocess: Run ``run_preprocessing`` (retried once with 2 workers
            if the first attempt fails).
        do_train: Run ``run_training`` with its default arguments.
        do_inference: Prepare test data, run ``run_inference`` and convert the
            predictions to TIFF.
        max_cases: Forwarded to ``prepare_dataset`` when the dataset has to be
            built; a falsy value is displayed as "All".

    Returns:
        ``False`` as soon as a stage fails or no data is available for dataset
        preparation; ``True`` otherwise, including the early exit taken when
        inference is requested but ``data_downloaded`` is false.
    """
    
    # Pipeline header
    print("\n" + "="*80)
    print("üèõÔ∏è VESUVIUS CHALLENGE - NNUNET PIPELINE")
    print("="*80)
    
    # Show configuration
    print("\nüìã PIPELINE CONFIGURATION")
    print("-"*40)
    print(f"  ‚ñ∂ Preprocessing: {'‚úÖ Yes' if do_preprocess else '‚≠ï Skip'}")
    print(f"  ‚ñ∂ Training:      {'‚úÖ Yes' if do_train else '‚≠ï Skip'}")
    print(f"  ‚ñ∂ Inference:     {'‚úÖ Yes' if do_inference else '‚≠ï Skip'}")
    print(f"  ‚ñ∂ Max cases:     {max_cases if max_cases else 'All'}")
    print(f"  ‚ñ∂ Configuration: {CONFIGURATION}")
    print(f"  ‚ñ∂ Epochs:        {EPOCHS}")
    print(f"  ‚ñ∂ GPUs:          {NUM_GPUS}")
    print("-"*40)
    
    # Step counter: booleans count as 0/1, the leading 1 is data preparation.
    current_step = 0
    total_steps = sum([1, do_preprocess, do_train, do_inference])
    
    def print_step(step_name: str, status: str = "STARTING"):
        """Advance the step counter and print a banner for the new step."""
        nonlocal current_step
        current_step += 1
        print(f"\n{'='*80}")
        print(f"üìç STEP {current_step}/{total_steps}: {step_name}")
        print(f"   Status: {status}")
        print(f"{'='*80}")
    
    # 1. Data preparation
    print_step("DATA PREPARATION")
    dataset_path = NNUNET_RAW / DATASET_NAME
    
    if not dataset_path.exists():
        if not data_downloaded:
            print("   ‚ùå No data available")
            print("   üí° Please set up Kaggle authentication and download data")
            return False
        
        print("   üìÅ Preparing dataset from raw data...")
        dataset_dir = prepare_dataset(INPUT_DIR, max_cases=max_cases)
        if not dataset_dir:
            print("   ‚ùå Dataset preparation failed")
            return False
        print("   ‚úÖ Dataset prepared successfully")
    else:
        print("   ‚úÖ Dataset already exists")
        dataset_dir = dataset_path
        
        # Count files
        images_dir = dataset_dir / "imagesTr"
        if images_dir.exists():
            nifti_count = len(list(images_dir.glob("*.nii.gz")))
            tiff_count = len(list(images_dir.glob("*.tif")))
            print(f"   üìä Files: {nifti_count} NIfTI, {tiff_count} TIFF")
    
    # Fix issues if dataset exists
    if dataset_dir and dataset_dir.exists():
        print("\n   üîß Checking for issues...")
        
        # Fix SimpleTiffIO
        fix_dataset_json_ioclass(dataset_dir)
        
        # Convert TIFF to NIfTI if needed.
        # NOTE(review): this rewrites dataset.json to '.nii.gz', while the
        # inference step below stages '*_0000.tif' inputs - confirm that the
        # predictor still picks those files up.
        print("   üîÑ Checking file formats...")
        if convert_tiff_dataset_to_nifti(dataset_dir):
            print("   ‚úÖ File formats verified")
    
    # 2. Preprocessing
    if do_preprocess:
        print_step("PREPROCESSING", "Running nnUNet planning and preprocessing")
        
        print(f"   Configuration: {CONFIGURATION}")
        print(f"   Planner: {PLANNER}")
        print(f"   Workers: {NUM_WORKERS}")
        
        if not run_preprocessing(num_workers=NUM_WORKERS):
            print("\n   ‚ùå Preprocessing failed")
            print("   üîÑ Retrying with fewer workers...")
            
            # One retry only; a second failure aborts the pipeline.
            if not run_preprocessing(num_workers=2):
                print("   ‚ùå Preprocessing failed again")
                return False
        
        print("   ‚úÖ Preprocessing completed successfully")
    else:
        print("\n   ‚≠ï Skipping preprocessing")
    
    # 3. Training
    if do_train:
        print_step("TRAINING", f"Starting {EPOCHS} epoch training")
        
        # Informational text only; these values are not read from the plans.
        print(f"   Model: ResidualEncoderUNet")
        print(f"   Patch size: 128x128x128")
        print(f"   Batch size: 2")
        print(f"   Learning rate: PolyLR schedule")
        print(f"\n   üìä Training progress will be shown below:")
        print("   " + "-"*40)
        
        if not run_training():
            print("   ‚ùå Training failed")
            return False
        
        print("   ‚úÖ Training completed successfully")
    else:
        print("\n   ‚≠ï Skipping training")
    
    # 4. Inference
    if do_inference:
        print_step("INFERENCE", "Running predictions on test data")
        
        # Prepare test data
        test_input_dir = WORKING_DIR / "test_input"
        
        if data_downloaded:
            print("   üìÅ Preparing test data...")
            prepare_test_data(INPUT_DIR, test_input_dir)
            
            # Count test files
            test_files = list(test_input_dir.glob("*.tif")) + list(test_input_dir.glob("*.nii.gz"))
            print(f"   üìä Test cases: {len(test_files)}")
        else:
            # Nothing to predict on: treated as success, summary is skipped.
            print("   ‚ö†Ô∏è No test data available")
            return True
        
        # Run inference
        print("   üîÆ Running inference...")
        predictions_dir = WORKING_DIR / "predictions"
        
        if not run_inference(test_input_dir, predictions_dir):
            print("   ‚ùå Inference failed")
            return False
        
        # Convert predictions
        print("   üìÑ Converting predictions to TIFF...")
        tiff_output_dir = OUTPUT_DIR / "predictions_tiff"
        predictions_to_tiff(predictions_dir, tiff_output_dir)
        
        print(f"   ‚úÖ Predictions saved to: {tiff_output_dir}")
    else:
        print("\n   ‚≠ï Skipping inference")
    
    # Final summary
    print("\n" + "="*80)
    print("üéâ PIPELINE COMPLETED SUCCESSFULLY!")
    print("="*80)
    
    print("\nüìä SUMMARY:")
    print("-"*40)
    
    if do_preprocess:
        preprocessed_dir = NNUNET_PREPROCESSED / DATASET_NAME
        if preprocessed_dir.exists():
            print(f"  ‚úÖ Preprocessed data: {preprocessed_dir}")
    
    if do_train:
        # NOTE(review): nnUNet v2 usually writes results under
        # '<trainer>__<plans>__<configuration>' - confirm this folder name.
        results_dir = NNUNET_RESULTS / f"Dataset{DATASET_ID:03d}_{DATASET_NAME}" / f"{PLANNER}__{CONFIGURATION}"
        if results_dir.exists():
            print(f"  ‚úÖ Training results: {results_dir}")
            
            # Look in the fold that was actually trained (run_training's
            # default), not a hard-coded 'fold_all'.
            best_model = results_dir / f"fold_{FOLD}" / "checkpoint_best.pth"
            if best_model.exists():
                size_mb = best_model.stat().st_size / (1024*1024)
                print(f"  ‚úÖ Best model saved: {size_mb:.1f}MB")
    
    if do_inference:
        # tiff_output_dir is always bound here: the only inference path that
        # leaves it undefined returned early above.
        if tiff_output_dir.exists():
            pred_count = len(list(tiff_output_dir.glob("*.tif")))
            print(f"  ‚úÖ Predictions: {pred_count} files")
    
    print("-"*40)
    print("\n‚ú® All tasks completed successfully!")
    
    return True

# Cell-load confirmation: printed once full_pipeline is defined.
print("‚úÖ Pipeline function ready with clear output formatting")

In [None]:
# Add this cell BEFORE full-pipeline cell
def check_saved_models():
    """Report which checkpoints and auxiliary files exist for the current fold.

    Prints size and modification time for ``checkpoint_best.pth``,
    ``checkpoint_final.pth`` and ``checkpoint_latest.pth``, then lists
    ``progress.png``, ``training_log.txt`` (or any ``training_log_*.txt``)
    and ``plans.json``.

    Returns:
        The model directory when at least one checkpoint exists, else ``None``
        (also when the directory itself is missing).
    """
    import time

    print_header("CHECKPOINT MANAGER", "üíæ")

    # Model directory
    model_dir = NNUNET_RESULTS / f"Dataset{DATASET_ID:03d}_{DATASET_NAME}" / f"{PLANNER}__{CONFIGURATION}" / f"fold_{FOLD}"

    if not model_dir.exists():
        print("üìÅ Model directory not found")
        print(f"   Expected: {model_dir}")
        print("   üîÑ Models will be saved here after training")
        return None

    print(f"üìÅ Model directory: {model_dir}")

    mib = 1024 * 1024
    found = []
    total_size = 0

    print("\nüíæ Available checkpoints:")
    print("-" * 50)

    # Checkpoint kinds in the order they are reported.
    for kind, basename in (
        ("best", "checkpoint_best.pth"),
        ("final", "checkpoint_final.pth"),
        ("latest", "checkpoint_latest.pth"),
    ):
        ckpt = model_dir / basename
        if not ckpt.exists():
            print(f"  ‚ùå {kind.upper()}: Not found")
            continue

        info = ckpt.stat()
        size_mb = info.st_size / mib
        stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(info.st_mtime))

        print(f"  ‚úÖ {kind.upper()}: {ckpt.name}")
        print(f"     Size: {size_mb:.1f}MB | Modified: {stamp}")

        found.append(kind)
        total_size += size_mb

    # Check for additional files
    print("\nüìä Additional files:")
    print("-" * 50)

    for fname, blurb in (
        ("progress.png", "Learning curves"),
        ("training_log.txt", "Training log"),
        ("plans.json", "Model configuration"),
    ):
        candidate = model_dir / fname
        if candidate.exists():
            size_kb = candidate.stat().st_size / 1024
            print(f"  ‚úÖ {fname}: {blurb} ({size_kb:.1f}KB)")
            continue

        # The training log may only exist under a suffixed name.
        stamped_logs = []
        if fname == "training_log.txt":
            stamped_logs = list(model_dir.glob("training_log_*.txt"))

        if not stamped_logs:
            print(f"  ‚ùå {fname}: Not found")
            continue

        for log_path in stamped_logs:
            size_kb = log_path.stat().st_size / 1024
            print(f"  ‚úÖ {log_path.name}: Training log ({size_kb:.1f}KB)")

    if not found:
        return None

    has_best = 'best' in found
    readiness = '‚úÖ Yes' if has_best else '‚ùå No (need training)'

    print(f"\nüìà Summary:")
    print(f"  ‚Ä¢ Available models: {', '.join(found)}")
    print(f"  ‚Ä¢ Total size: {total_size:.1f}MB")
    print(f"  ‚Ä¢ Ready for inference: {readiness}")

    # Show inference hint
    if has_best:
        print(f"\nüöÄ Ready to use for inference!")
        print(f"   The 'best' model will be automatically used for predictions.")

    return model_dir

def cleanup_old_checkpoints(keep_best: bool = True, keep_latest: bool = True):
    """Delete checkpoint files from the current fold directory to free space.

    Every ``checkpoint_epoch_*.pth`` file is always removed.
    ``checkpoint_best.pth`` and ``checkpoint_latest.pth`` are removed only
    when the matching ``keep_*`` flag is ``False``. ``checkpoint_final.pth``
    is never touched.

    Args:
        keep_best: Keep ``checkpoint_best.pth`` (default). ``False`` deletes it.
        keep_latest: Keep ``checkpoint_latest.pth`` (default). ``False``
            deletes it.
    """
    model_dir = NNUNET_RESULTS / f"Dataset{DATASET_ID:03d}_{DATASET_NAME}" / f"{PLANNER}__{CONFIGURATION}" / f"fold_{FOLD}"
    
    if not model_dir.exists():
        print("üìÅ No model directory found")
        return
    
    print_header("CHECKPOINT CLEANUP", "üßπ")
    
    total_freed = 0
    files_removed = 0
    
    # Files to potentially clean up
    cleanup_candidates = []
    
    # Honour keep_best=False: previously the flag only changed the
    # "Protected files" printout and the file was never queued for removal.
    if not keep_best:
        cleanup_candidates.append("checkpoint_best.pth")
    
    if not keep_latest:
        cleanup_candidates.append("checkpoint_latest.pth")
    
    # Per-epoch checkpoints, sorted so the report order is stable.
    cleanup_candidates.extend(
        sorted(path.name for path in model_dir.glob("checkpoint_epoch_*.pth"))
    )
    
    print(f"üîç Scanning: {model_dir}")
    print(f"üõ°Ô∏è Protected files: ", end="")
    protected = []
    if keep_best:
        protected.append("checkpoint_best.pth")
    if keep_latest:
        protected.append("checkpoint_latest.pth")
    protected.append("checkpoint_final.pth")
    print(", ".join(protected))
    
    print(f"\nüßπ Cleanup candidates:")
    
    for filename in cleanup_candidates:
        file_path = model_dir / filename
        if file_path.exists():
            size_mb = file_path.stat().st_size / (1024 * 1024)
            print(f"  üóëÔ∏è {filename}: {size_mb:.1f}MB")
            try:
                file_path.unlink()
            except OSError as e:
                # Best effort: report and continue with the remaining files.
                print(f"     ‚ùå Error: {e}")
            else:
                total_freed += size_mb
                files_removed += 1
                print(f"     ‚úÖ Removed")
        else:
            print(f"  ‚ö™ {filename}: Not found")
    
    print(f"\nüìä Cleanup summary:")
    print(f"  ‚Ä¢ Files removed: {files_removed}")
    print(f"  ‚Ä¢ Space freed: {total_freed:.1f}MB")

# Cell-load confirmation: printed once the checkpoint helpers are defined.
print("‚úÖ Checkpoint management functions ready")

In [None]:
# Add this cell BEFORE full-pipeline cell
def _load_tifffile():
    """Import ``tifffile``, attempting a pip install first if it is missing."""
    try:
        import tifffile
    except ImportError:
        print("  Installing tifffile...")
        import subprocess
        import sys
        subprocess.run([sys.executable, "-m", "pip", "install", "tifffile", "-q"])
        import tifffile
    return tifffile


def _nifti_path_for(tiff_file: Path) -> Path:
    """Return the sibling ``<stem>.nii.gz`` path, keeping dots inside the stem."""
    return tiff_file.with_name(tiff_file.stem + ".nii.gz")


def _convert_tiff_files(tiff_files: List[Path], dtype: str, desc: str, tifffile_module) -> Tuple[int, int]:
    """Convert each TIFF to NIfTI with the given dtype; return (converted, failed)."""
    converted = 0
    failed = 0

    for tiff_file in tqdm(tiff_files, desc=desc):
        nifti_file = _nifti_path_for(tiff_file)

        # An existing target is left alone (and so is its TIFF).
        if nifti_file.exists():
            continue

        try:
            volume = tifffile_module.imread(str(tiff_file))

            # Promote a single 2D slice to a one-slice 3D volume.
            if volume.ndim == 2:
                volume = volume[np.newaxis, :, :]

            # Identity affine: no spacing/orientation is taken from the TIFF.
            nifti_img = nib.Nifti1Image(volume.astype(dtype), np.eye(4))
            nib.save(nifti_img, str(nifti_file))

            # Only drop the source once the NIfTI has been written.
            tiff_file.unlink()
            converted += 1
        except Exception as e:
            # Best effort per file: report, count, and keep going.
            failed += 1
            print(f"    ‚ùå Error: {tiff_file.name} - {e}")

    return converted, failed


def convert_tiff_dataset_to_nifti(dataset_path: Path) -> bool:
    """Convert a TIFF-based nnUNet dataset folder to NIfTI in place.

    ``*.tif`` files in ``imagesTr`` are rewritten as float32 ``.nii.gz`` and
    those in ``labelsTr`` as uint8 ``.nii.gz``; each source TIFF is deleted
    after its NIfTI has been saved. ``dataset.json`` (if present) gets
    ``file_ending = '.nii.gz'`` and loses its ``overwrite_image_reader_writer``
    and ``ioclass`` keys.

    ``tifffile`` is imported (and pip-installed if missing) only when there is
    at least one TIFF to convert.

    Args:
        dataset_path: Dataset root containing ``imagesTr``, ``labelsTr`` and
            ``dataset.json``.

    Returns:
        ``True`` when nothing failed (including "already converted" and
        "nothing to do"); ``False`` when at least one file could not be
        converted.
    """
    
    print("üîÑ Converting TIFF files to NIfTI format...")
    
    # Check if already converted
    images_dir = dataset_path / "imagesTr"
    if images_dir.exists():
        tiff_count = len(list(images_dir.glob("*.tif")))
        nifti_count = len(list(images_dir.glob("*.nii.gz")))
        
        if tiff_count == 0 and nifti_count > 0:
            print(f"  ‚úÖ Already converted: {nifti_count} NIfTI files found")
            return True
    
    labels_dir = dataset_path / "labelsTr"
    
    # (directory, output dtype, noun for the message, progress-bar label)
    targets = (
        (images_dir, "float32", "image", "Images"),
        (labels_dir, "uint8", "label", "Labels"),
    )
    
    converted_count = 0
    failed_count = 0
    tifffile_module = None
    
    for directory, dtype, noun, desc in targets:
        if not directory.exists():
            continue
        
        tiff_files = list(directory.glob("*.tif"))
        if not tiff_files:
            continue
        
        # Import lazily so a dataset with nothing to convert needs no tifffile.
        if tifffile_module is None:
            tifffile_module = _load_tifffile()
        
        print(f"  Converting {len(tiff_files)} {noun} files...")
        converted, failed = _convert_tiff_files(tiff_files, dtype, desc, tifffile_module)
        converted_count += converted
        failed_count += failed
    
    # Update dataset.json
    json_path = dataset_path / "dataset.json"
    if json_path.exists():
        with open(json_path, 'r', encoding="utf-8") as f:
            config = json.load(f)
        
        config['file_ending'] = '.nii.gz'
        
        # Remove SimpleTiffIO references
        config.pop('overwrite_image_reader_writer', None)
        config.pop('ioclass', None)
        
        with open(json_path, 'w', encoding="utf-8") as f:
            json.dump(config, f, indent=4)
        
        print("  ‚úÖ Updated dataset.json for NIfTI format")
    
    if converted_count > 0:
        print(f"‚úÖ Converted {converted_count} files to NIfTI format")
    
    if failed_count > 0:
        print(f"  Warning: {failed_count} file(s) could not be converted")
        return False
    
    return True

# Cell-load confirmation: printed once the conversion helper is defined.
print("‚úÖ TIFF to NIfTI conversion function ready")

In [None]:
# Launch the pipeline. Stages are toggled with do_preprocess / do_train /
# do_inference; max_cases is forwarded to dataset preparation.
#
# Quick check (few cases, no training or inference):
# success = full_pipeline(do_train=False, do_inference=False, max_cases=5)

# All stages enabled, with max_cases set to 50.
success = full_pipeline(max_cases=50)

print(
    "\nüéä All done! Check /workspace/results/ for outputs"
    if success
    else "\n‚ùå Pipeline failed. Check error messages above."
)

In [None]:
def create_submission_zip():
    """Bundle the converted prediction TIFFs into ``submission.zip``.

    Reads ``*.tif`` files from ``OUTPUT_DIR / "predictions_tiff"`` and stores
    them, by bare file name and deflate-compressed, in
    ``OUTPUT_DIR / "submission.zip"``.

    Returns:
        Path of the archive, or ``None`` when the predictions directory is
        missing or holds no ``.tif`` files.
    """
    import zipfile

    source_dir = OUTPUT_DIR / "predictions_tiff"
    archive_path = OUTPUT_DIR / "submission.zip"

    if not source_dir.exists():
        print(f"‚ùå No predictions found at {source_dir}")
        return None

    members = list(source_dir.glob("*.tif"))
    if len(members) == 0:
        print("‚ùå No TIFF files found")
        return None

    print(f"üì¶ Creating submission with {len(members)} files...")

    with zipfile.ZipFile(archive_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for member in tqdm(members, desc="Zipping"):
            # Store flat: archive entry is the file name without directories.
            archive.write(member, member.name)

    archive_bytes = archive_path.stat().st_size
    size_mb = archive_bytes / (1024 * 1024)
    print(f"‚úÖ Submission created: {archive_path} ({size_mb:.1f} MB)")

    return archive_path

# Build the submission archive only after a successful pipeline run that
# left a predictions_tiff directory behind.
predictions_ready = success and (OUTPUT_DIR / "predictions_tiff").exists()

if not predictions_ready:
    print("\n‚ö†Ô∏è No predictions available for submission")
else:
    submission_zip = create_submission_zip()
    if submission_zip:
        print(f"\nüèÜ Ready for submission: {submission_zip}")