In [1]:
import os
import sys
import shutil
import numpy as np
import torch
import pydicom
import pandas as pd
import polars as pl
from pathlib import Path
from collections import defaultdict
import cv2
import kaggle_evaluation.rsna_inference_server

# Add model source to path
sys.path.append('/kaggle/input/rnsa-aneurysm-detection')
sys.path.append('/kaggle/input/rnsa-aneurysm-detection/src')

# Import model
from model import Model

  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model


The evaluation API requires that you set up a server which will respond to inference requests. We have already defined the server; you just need to write the `predict` function. When we evaluate your submission on the hidden test set, the client defined in `rsna_gateway` will run in a different container with direct access to the hidden test set and hand off the data series by series.

Your code will always have access to the published copies of the files.

In [2]:
# Constants
# Name of the series-identifier column. predict() builds its result frame with
# this column first and drops it again before returning.
ID_COL = 'SeriesInstanceUID'
# Output columns, in the order predict() pairs them with the model scores.
# NOTE(review): assumes the classifier's 14 outputs follow exactly this
# order — confirm against the training code.
LABEL_COLS = [
    'Left Infraclinoid Internal Carotid Artery',
    'Right Infraclinoid Internal Carotid Artery',
    'Left Supraclinoid Internal Carotid Artery',
    'Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery',
    'Right Middle Cerebral Artery',
    'Anterior Communicating Artery',
    'Left Anterior Cerebral Artery',
    'Right Anterior Cerebral Artery',
    'Left Posterior Communicating Artery',
    'Right Posterior Communicating Artery',
    'Basilar Tip',
    'Other Posterior Circulation',
    'Aneurysm Present',
]

# ImageNet normalization
# Per-channel mean / std that normalize_slab() applies to the three slices of
# a slab (one value per slab channel).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# List of checkpoint paths
# One ensemble member is loaded per entry; predict() averages their outputs.
CHECKPOINT_PATHS = [
    "/kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp6_ep44.pt",
    "/kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp5_ep6.pt",
    "/kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp8_fold1_ep16.pt",
    "/kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp9_fold2_ep10.pt",
    "/kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp9_fold3_ep25.pt",
]

# Initialize models globally - loads once when module is imported
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Loading models...")
# Ensemble members, in CHECKPOINT_PATHS order; read by predict().
models = []
for checkpoint_path in CHECKPOINT_PATHS:
    # NOTE(review): the meaning of pre / ps / mask_head is defined by the
    # external `model` module and is not visible here — confirm these match
    # the settings the checkpoints were trained with.
    model = Model(
        pre=None,
        num_classes=14,  # same count as len(LABEL_COLS)
        ps=0.1,
        mask_head=False
    ).to(device)
    
    # The loaded checkpoint is indexed by key; the weights are stored under
    # 'model_state_dict'. Tensors are mapped straight onto `device`.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Switch to evaluation mode before the model is used for inference.
    model.eval()
    models.append(model)
    print(f"Model loaded successfully from {checkpoint_path}")

print(f"Total models loaded: {len(models)}")

def get_slice_position(ds):
    """Return a float sort key for a slice, taken from the best usable tag.

    Tags are tried in this order: the third component of
    ``ImagePositionPatient``, then ``SliceLocation``, then ``InstanceNumber``.
    A tag that is missing, ``None``/empty, too short, or not convertible to
    ``float`` is skipped rather than raising, so a single malformed tag falls
    through to the next candidate.

    Args:
        ds: Object exposing DICOM tags as attributes (e.g. a pydicom dataset).

    Returns:
        float: The position value, or ``0.0`` when no tag is usable.
    """
    # Each reader is evaluated lazily so a broken earlier tag cannot prevent
    # a later, valid one from being used.
    tag_readers = (
        lambda: ds.ImagePositionPatient[2],
        lambda: ds.SliceLocation,
        lambda: ds.InstanceNumber,
    )
    for read_tag in tag_readers:
        try:
            return float(read_tag())
        except (AttributeError, TypeError, ValueError, IndexError, KeyError):
            # Missing attribute, None value, non-numeric text or short
            # sequence: try the next tag.
            continue
    return 0.0

def should_rescale_ct(ds, pixel_array):
    """Decide whether a CT slice needs its rescale slope/intercept applied.

    Returns ``True`` only when all of the following hold:
      * the dataset's ``Modality`` is ``'CT'``;
      * both ``RescaleSlope`` and ``RescaleIntercept`` are present;
      * the smallest pixel value is at least -100, or is exactly -2000.

    NOTE(review): the pixel-value thresholds are presumably a heuristic for
    "values are not yet in Hounsfield units" — confirm against the training
    preprocessing.
    """
    is_ct = ds.get('Modality', '') == 'CT'
    if not is_ct:
        return False

    has_rescale_tags = hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept')
    if not has_rescale_tags:
        return False

    lowest = pixel_array.min()
    return bool(lowest >= -100 or lowest == -2000)

def resize_slice_to_384(slice_array):
    """Resize a 2-D slice to 384x384 using cubic interpolation.

    A slice that is already 384x384 is returned unchanged (same object, same
    dtype). Any other slice is converted to float32 and resized with OpenCV;
    if that raises an ordinary exception, SciPy's order-3 ``zoom`` is used
    instead.

    Args:
        slice_array: 2-D numpy array indexed (rows, cols).

    Returns:
        numpy array of shape (384, 384).
    """
    if slice_array.shape == (384, 384):
        return slice_array

    as_float = slice_array.astype(np.float32)
    try:
        return cv2.resize(as_float, (384, 384), interpolation=cv2.INTER_CUBIC)
    except Exception:
        # `except Exception` (not a bare `except:`) so KeyboardInterrupt and
        # SystemExit still propagate instead of silently triggering the
        # fallback path.
        from scipy.ndimage import zoom
        zoom_factors = (384 / slice_array.shape[0], 384 / slice_array.shape[1])
        return zoom(as_float, zoom_factors, order=3)

def normalize_slab(slab):
    """Standardise the slab's three channels with the ImageNet statistics.

    Channel ``c`` is overwritten with
    ``(slab[c] - IMAGENET_MEAN[c]) / IMAGENET_STD[c]``. The slab is modified
    in place and the same object is returned, so callers that need the
    original values must pass a copy.
    """
    for channel, (mean, std) in enumerate(zip(IMAGENET_MEAN, IMAGENET_STD)):
        slab[channel] = (slab[channel] - mean) / std
    return slab

def process_dicom_series(series_path):
    """Load every DICOM file under ``series_path`` into one uint8 volume.

    Steps:
      1. Recursively collect files whose name ends in ``.dcm``.
      2. Read each file. A 3-D pixel array contributes one entry per frame
         (position = the file's position + frame index); any other array
         contributes a single entry. Files that fail to read are reported
         with a printed message and skipped.
      3. Sort the entries by position, apply the rescale slope/intercept
         (for CT only when ``should_rescale_ct`` agrees; for other modalities
         always, defaulting to slope 1 / intercept 0) and resize each slice
         with ``resize_slice_to_384``.
      4. Min-max scale the stacked volume to 0-255 using its global extrema;
         a constant volume becomes all zeros.

    Returns:
        np.ndarray: uint8 volume with slices stacked along axis 0.

    Raises:
        ValueError: If no ``.dcm`` file is found, or none could be read.
    """
    dcm_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(series_path)
        for file in files
        if file.endswith('.dcm')
    ]
    if not dcm_files:
        raise ValueError("No DICOM files found")

    # Each entry is (position, dataset, 2-D pixel array).
    entries = []
    for dcm_file in dcm_files:
        try:
            ds = pydicom.dcmread(dcm_file, force=True)
            pixels = ds.pixel_array

            if len(pixels.shape) != 3:
                entries.append((get_slice_position(ds), ds, pixels))
                continue

            # Multi-frame file: offset the file position by the frame index
            # so frames keep their relative order after sorting.
            for frame_idx in range(pixels.shape[0]):
                entries.append(
                    (get_slice_position(ds) + frame_idx, ds, pixels[frame_idx])
                )
        except Exception as e:
            print(f"Error reading {dcm_file}: {e}")
            continue

    if not entries:
        raise ValueError("No valid slices found")

    # Sort on position only (stable), never on the dataset or pixel data.
    entries.sort(key=lambda entry: entry[0])

    # The modality of the first sorted slice decides the rescale policy for
    # the whole series.
    modality = entries[0][1].get('Modality', 'Unknown')
    is_ct = modality == 'CT'

    processed_slices = []
    for _, ds, pixels in entries:
        if is_ct and not should_rescale_ct(ds, pixels):
            values = pixels.astype(np.float32)
        else:
            if is_ct:
                slope = float(ds.RescaleSlope)
                intercept = float(ds.RescaleIntercept)
            else:
                slope = float(getattr(ds, 'RescaleSlope', 1.0))
                intercept = float(getattr(ds, 'RescaleIntercept', 0.0))
            values = pixels * slope + intercept

        processed_slices.append(resize_slice_to_384(values))

    volume = np.stack(processed_slices, axis=0)

    lowest = volume.min()
    value_range = volume.max() - lowest
    if value_range == 0:
        return np.zeros_like(volume, dtype=np.uint8)

    return ((volume - lowest) / value_range * 255).astype(np.uint8)

def create_slabs(volume, slab_size=3):
    """Split a volume into slabs of ``slab_size`` consecutive slices.

    * A volume with more than ``slab_size`` slices is cut into consecutive,
      non-overlapping slabs; if slices are left over, one extra slab made of
      the last ``slab_size`` slices is appended (it overlaps the previous
      slab).
    * A volume with exactly ``slab_size`` slices is returned as the only slab.
    * A shorter volume is padded to ``slab_size`` by repeating its last slice,
      for any ``slab_size`` (for the default of 3 this gives ``[s0, s0, s0]``
      and ``[s0, s1, s1]``).
    * An empty volume yields an empty list.

    Args:
        volume: numpy array whose first axis indexes slices.
        slab_size: Number of slices per slab; must be at least 1.

    Returns:
        list of numpy arrays, each with ``slab_size`` slices along axis 0.

    Raises:
        ValueError: If ``slab_size`` is smaller than 1.
    """
    if slab_size < 1:
        raise ValueError("slab_size must be a positive integer")

    num_slices = volume.shape[0]
    if num_slices == 0:
        # Nothing to slice; avoid emitting a slab with zero channels.
        return []

    if num_slices == slab_size:
        return [volume]

    if num_slices < slab_size:
        # Pad with copies of the last slice so the slab always has
        # slab_size channels, whatever slab_size is.
        padding = np.repeat(volume[-1:], slab_size - num_slices, axis=0)
        return [np.concatenate([volume, padding], axis=0)]

    slabs = [
        volume[index * slab_size:(index + 1) * slab_size]
        for index in range(num_slices // slab_size)
    ]
    if num_slices % slab_size != 0:
        # Cover the trailing slices with one slab aligned to the end.
        slabs.append(volume[-slab_size:])
    return slabs

def predict(series_path: str) -> pl.DataFrame:
    """Predict the label scores for one DICOM series with the model ensemble.

    Pipeline: load the series into a uint8 volume, cut it into 3-slice slabs,
    score the slabs in batches of 32 with every model in ``models``, take each
    model's per-label maximum over all slabs, then average those maxima
    across models.

    Each batch is normalised and moved to ``device`` once and then fed to
    every model, so preprocessing is not repeated per ensemble member. Only a
    running per-model maximum is kept, which gives the same result as taking
    the maximum over all slab predictions at the end.

    Args:
        series_path: Directory containing the series' ``.dcm`` files. Its
            base name is used as the series ID in log messages.

    Returns:
        pl.DataFrame: A single row with one column per entry of
        ``LABEL_COLS`` (the ID column is dropped before returning). If any
        step fails, the error is printed and every label is set to 0.5
        instead of raising.
    """
    # Bind series_id before the try block: the except branch reads it, and it
    # must exist even when the very first statement inside the try fails.
    series_id = str(series_path)
    try:
        # Use global models directly (already loaded at module level)
        series_id = os.path.basename(series_path)

        volume_uint8 = process_dicom_series(series_path)
        print(f"Processed volume shape: {volume_uint8.shape}")

        slabs = create_slabs(volume_uint8, slab_size=3)
        print(f"Created {len(slabs)} slabs")

        batch_size = 32
        # Per-model running maximum over all slabs seen so far; None until
        # the model has scored at least one batch.
        running_max = [None] * len(models)

        with torch.no_grad():
            for start in range(0, len(slabs), batch_size):
                # Scale to [0, 1] and apply ImageNet normalisation. The
                # division already produces a fresh array, so normalize_slab's
                # in-place update cannot touch the source volume.
                batch_tensor = torch.stack(
                    [
                        torch.from_numpy(
                            normalize_slab(slab.astype(np.float32) / 255.0)
                        ).float()
                        for slab in slabs[start:start + batch_size]
                    ],
                    dim=0,
                ).to(device)

                for model_idx, model in enumerate(models):
                    # The model returns a pair; only the first element (the
                    # classification logits) is used here.
                    cls_output, _ = model(batch_tensor)
                    batch_max = torch.sigmoid(cls_output).max(dim=0)[0].cpu()
                    if running_max[model_idx] is None:
                        running_max[model_idx] = batch_max
                    else:
                        running_max[model_idx] = torch.maximum(
                            running_max[model_idx], batch_max
                        )

        # Store predictions from all models
        model_predictions = []
        for model_idx, best in enumerate(running_max):
            if best is not None:
                final_predictions = best.numpy()
            else:
                # No slab was scored (empty slab list): fall back to zeros.
                final_predictions = np.zeros(len(LABEL_COLS))

            model_predictions.append(final_predictions)
            print(f"Model {model_idx + 1} predictions for {series_id}: {final_predictions}")

        # Calculate mean prediction across all models
        final_predictions = np.mean(model_predictions, axis=0)
        print(f"Mean ensemble predictions for {series_id}: {final_predictions}")

        result_data = [[series_id] + final_predictions.tolist()]
        predictions_df = pl.DataFrame(
            data=result_data,
            schema=[ID_COL, *LABEL_COLS],
            orient='row',
        )

    except Exception as e:
        # Deliberate best-effort: a failed series gets a neutral 0.5 for every
        # label rather than aborting the whole run.
        print(f"Error processing {series_path}: {e}")
        predictions_df = pl.DataFrame(
            data=[[series_id] + [0.5] * len(LABEL_COLS)],
            schema=[ID_COL, *LABEL_COLS],
            orient='row',
        )

    # NOTE(review): presumably clears files staged for this series under
    # /kaggle/shared to free disk space — confirm against the gateway docs.
    shutil.rmtree('/kaggle/shared', ignore_errors=True)

    return predictions_df.drop(ID_COL)

Loading models...
Model loaded successfully from /kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp6_ep44.pt
Model loaded successfully from /kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp5_ep6.pt
Model loaded successfully from /kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp8_fold1_ep16.pt
Model loaded successfully from /kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp9_fold2_ep10.pt
Model loaded successfully from /kaggle/input/rnsa-aneurysm-detection/ckpt/ckpt_exp9_fold3_ep25.pt
Total models loaded: 5


When your notebook is run on the hidden test set, `inference_server.serve` must be called within 15 minutes of the notebook starting or the gateway will throw an error. If you need more than 15 minutes to load your model you can do so during the very first `predict` call.

In [3]:
# Wrap predict() in the competition's inference server.
inference_server = kaggle_evaluation.rsna_inference_server.RSNAInferenceServer(predict)

if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    # Interactive / local run: drive predict() through the local gateway,
    # then show the submission file if one was written.
    inference_server.run_local_gateway()
    if os.path.exists('/kaggle/working/submission.parquet'):
        display(pl.read_parquet('/kaggle/working/submission.parquet'))
else:
    # Competition rerun: serve requests coming from the hidden-test gateway.
    inference_server.serve()

Processed volume shape: (176, 384, 384)
Created 59 slabs
Model 1 predictions for 1.2.826.0.1.3680043.8.498.10028406715369553772267826812576760572: [1.08479544e-04 7.98903685e-03 3.95630836e-04 1.94816440e-02
 9.43214595e-01 2.86105368e-03 1.07355630e-02 3.43987660e-04
 9.03945824e-04 2.18437955e-04 1.03904679e-03 4.22077865e-05
 6.82897747e-01 9.63146687e-01]
Model 2 predictions for 1.2.826.0.1.3680043.8.498.10028406715369553772267826812576760572: [0.00876737 0.01225688 0.27784508 0.07556773 0.2073177  0.24037255
 0.04116516 0.0195063  0.05652323 0.23192322 0.07079834 0.06500311
 0.06092188 0.7181832 ]
Model 3 predictions for 1.2.826.0.1.3680043.8.498.10028406715369553772267826812576760572: [6.8103801e-04 1.0048762e-03 4.1038734e-03 1.9241272e-02 6.5669268e-03
 8.9743900e-01 3.1658872e-03 1.8624212e-03 1.9396362e-03 1.6632354e-02
 3.4094849e-03 7.4304151e-03 2.6800539e-03 7.5367957e-01]
Model 4 predictions for 1.2.826.0.1.3680043.8.498.10028406715369553772267826812576760572: [2.1193529

SeriesInstanceUID,Left Infraclinoid Internal Carotid Artery,Right Infraclinoid Internal Carotid Artery,Left Supraclinoid Internal Carotid Artery,Right Supraclinoid Internal Carotid Artery,Left Middle Cerebral Artery,Right Middle Cerebral Artery,Anterior Communicating Artery,Left Anterior Cerebral Artery,Right Anterior Cerebral Artery,Left Posterior Communicating Artery,Right Posterior Communicating Artery,Basilar Tip,Other Posterior Circulation,Aneurysm Present
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""1.2.826.0.1.3680043.8.498.1002…",0.006189,0.004943,0.089363,0.043797,0.234382,0.430363,0.031402,0.010978,0.078192,0.052234,0.064905,0.014573,0.199948,0.864224
"""1.2.826.0.1.3680043.8.498.1007…",0.010304,0.008139,0.00182,0.007932,0.015152,0.011231,0.227326,0.002435,0.00308,0.000868,0.020028,0.068763,0.069783,0.388374
"""1.2.826.0.1.3680043.8.498.1005…",0.113353,0.268212,0.176721,0.181126,0.120893,0.207127,0.131469,0.007427,0.024737,0.015305,0.182111,0.167448,0.035449,0.628091
