## Inference Pipeline Overview

This notebook runs YOLO-based detection of bacterial flagellar motors on the 2D slices of each tomogram, using Test-Time Augmentation (TTA).

**Main steps:**
1. **Setup & Configuration** – Install dependencies, set paths, configure device and parameters.
2. **Utility Functions** – Profiling, normalization, and image loading helpers.
3. **TTA Inference** – Run YOLO on each slice as-is, flipped horizontally, and flipped vertically; map detections back to original-image coordinates.
4. **3D NMS & Processing** – Merge per-slice detections with a distance-based 3D NMS and keep the most confident detection per tomogram.
5. **Submission Generation** – Process all tomograms, apply TTA + 3D NMS, and save results to `submission.csv`.
6. **Main Execution** – Run the full pipeline and report runtime.

# Part 1 — Setup, Installs, Imports & Global Config

In [None]:
# =========================================
# Part 1 — Setup, Installs, Imports & Global Config
# =========================================

# (Kaggle) Optional installs - keep as comments or enable if needed.
# !tar xfvz /kaggle/input/ultralytics-for-offline-install/archive.tar.gz
# !pip install --no-index --find-links=./packages ultralytics
!pip install -q plotly scikit-learn
!rm -rf ./packages
!pip install -q /kaggle/input/ultralytics-timm/ultralytics-8.3.133-py3-none-any.whl --no-deps

import os, time, threading, random
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from pathlib import Path

import numpy as np
import pandas as pd
import cv2
from PIL import Image
import torch
from ultralytics import YOLO

# ---- Paths & output ----
# Competition data root and its standard sub-paths.
DATA_DIR = '/kaggle/input/byu-locating-bacterial-flagellar-motors-2025'
TRAIN_CSV = os.path.join(DATA_DIR, 'train_labels.csv')
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
TEST_DIR  = os.path.join(DATA_DIR, 'test')
OUTPUT_DIR = './'
MODEL_DIR  = './models'

os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# ---- Device & seeds ----
# NOTE(review): in the code shown here DEVICE (a torch.device) is only printed;
# the inference functions check the string `device` defined further below.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)
    # Overridden to False (together with benchmark=True) in the batch-sizing
    # section below whenever CUDA is available, so cuDNN ends up non-deterministic.
    torch.backends.cudnn.deterministic = True

# (Some duplication kept from original for safety)
# Re-seeds with the same value as RANDOM_SEED; no random numbers are drawn in
# between, so this leaves the RNG state unchanged.
np.random.seed(42)
torch.manual_seed(42)

# ---- Submission & model paths ----
# `data_path` / `test_dir` duplicate DATA_DIR / TEST_DIR; the inference code reads `test_dir`.
data_path       = DATA_DIR
test_dir        = os.path.join(data_path, "test")
submission_path = "/kaggle/working/submission.csv"
# YOLO checkpoint loaded by generate_submission() and debug_image_loading().
model_path      = "/kaggle/input/yolov10b_trust2/pytorch/default/3/best_10m_add_new_dataset.pt"

# ---- Inference params ----
# Boxes below CONFIDENCE_THRESHOLD are dropped after each YOLO call.
CONFIDENCE_THRESHOLD   = 0.45
# NOTE(review): not referenced by the inference code shown here; only the single
# most confident detection per tomogram is written to the submission.
MAX_DETECTIONS_PER_TOMO = 3
# Used as a scale factor: perform_3d_nms suppresses detections within 24 px * this value.
NMS_IOU_THRESHOLD      = 0.2
CONCENTRATION          = 1  # Fraction of slices to process (1 = all)

# ---- TTA configuration ----
ENABLE_TTA = True
# Supported modes: 'orig', 'hflip', 'vflip'
TTA_MODES = ['orig', 'hflip', 'vflip']

# ---- Dynamic batch sizing (kept for future use; TTA runs per image) ----
# `device` (a string) is what the inference code tests with device.startswith('cuda').
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 8
if device.startswith('cuda'):
    # Favor speed: cuDNN autotuning and TF32 math. This replaces the
    # deterministic=True setting made during seeding above.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem  = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_mem:.2f} GB)")
    # Rough estimate: total memory minus what PyTorch has allocated for tensors
    # in this process (memory held by other processes is presumably not counted).
    free_mem = gpu_mem - torch.cuda.memory_allocated(0) / 1e9
    # Heuristic of ~4 images per GB, clamped to [8, 32]. In the code shown here
    # BATCH_SIZE only sizes the CUDA stream pool in process_tomogram.
    BATCH_SIZE = max(8, min(32, int(free_mem * 4)))
    print(f"Dynamic batch size = {BATCH_SIZE} (free ~{free_mem:.2f} GB)")
else:
    print("GPU not available, using CPU")
    BATCH_SIZE = 4


# Part 2 — Utilities (profiling, IO helpers, normalization)


In [None]:
# Part 2 — Utilities (profiling, IO helpers, normalization)
# =========================================

class GPUProfiler:
    """Context manager that prints the elapsed time of a code section.

    When CUDA is available the device is synchronized on entry and exit, so
    GPU work that was queued asynchronously inside the section is included
    in the measurement.

    Args:
        name: Label printed next to the elapsed time.

    Attributes:
        start_time: ``time.perf_counter()`` value taken on entry.
        elapsed: Seconds spent inside the section, set on exit.
    """
    def __init__(self, name: str):
        self.name = name
        self.start_time = None
        self.elapsed = None

    def __enter__(self):
        if torch.cuda.is_available():
            # Drain previously queued kernels so they are not attributed to this section.
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            # Wait for kernels launched inside the section to finish.
            torch.cuda.synchronize()
        self.elapsed = time.perf_counter() - self.start_time
        print(f"[PROFILE] {self.name}: {self.elapsed:.3f}s")
        # Falsy return: exceptions raised inside the section propagate.
        return False

def normalize_slice(slice_data: np.ndarray) -> np.ndarray:
    """Linearly stretch a slice's 2nd..98th percentile range onto uint8 [0, 255].

    Values at or below the 2nd percentile map to 0 and values at or above the
    98th percentile map to 255; fractional results are truncated by the uint8
    cast. The percentile span is floored at 1e-6 so a flat slice does not
    divide by zero (it maps to all zeros).
    """
    lo, hi = np.percentile(slice_data, [2, 98])
    span = max(1e-6, (hi - lo))
    shifted = np.clip(slice_data, lo, hi) - lo
    return np.uint8(255 * shifted / span)

def preload_image_batch(file_paths):
    """Load images into CPU memory as 3-channel BGR arrays (unused with TTA per-image inference).

    Each file is read with OpenCV first. If OpenCV cannot decode it, PIL is
    used as a fallback; the PIL image is converted to RGB and then to BGR so
    every returned array has the same channel layout as ``cv2.imread``.

    Args:
        file_paths: Iterable of image file paths.

    Returns:
        List of numpy arrays, one per path, in input order.
    """
    images = []
    for path in file_paths:
        img = cv2.imread(path)
        if img is None:
            # PIL may yield RGB, grayscale or RGBA; normalize to 3-channel BGR to
            # match cv2.imread. The context manager closes the file handle.
            with Image.open(path) as pil_img:
                img = cv2.cvtColor(np.array(pil_img.convert('RGB')), cv2.COLOR_RGB2BGR)
        images.append(img)
    return images


# Part 3 — TTA Inference (run YOLO with flips & map boxes back)


In [None]:
def _apply_tta(image_bgr: np.ndarray, mode: str) -> np.ndarray:
    """Return augmented image for a given TTA mode."""
    if mode == 'orig':
        return image_bgr
    if mode == 'hflip':
        return cv2.flip(image_bgr, 1)
    if mode == 'vflip':
        return cv2.flip(image_bgr, 0)
    raise ValueError(f"Unsupported TTA mode: {mode}")

def _map_back_xyxy(xyxy: np.ndarray, mode: str, w: int, h: int) -> np.ndarray:
    """
    Map [x1, y1, x2, y2] from TTA'ed image back to original orientation.
    """
    x1, y1, x2, y2 = xyxy
    if mode == 'orig':
        return np.array([x1, y1, x2, y2], dtype=np.float32)
    if mode == 'hflip':
        # x' = w-1-x ; swap after flipping box ends
        nx1, nx2 = (w - 1 - x2), (w - 1 - x1)
        return np.array([nx1, y1, nx2, y2], dtype=np.float32)
    if mode == 'vflip':
        ny1, ny2 = (h - 1 - y2), (h - 1 - y1)
        return np.array([x1, ny1, x2, ny2], dtype=np.float32)
    raise ValueError(f"Unsupported TTA mode: {mode}")

def infer_image_with_tta(model, img_path: str, conf_thr: float):
    """Run YOLO on one image under each enabled TTA view and pool the detections.

    The views come from ``TTA_MODES`` when ``ENABLE_TTA`` is true; otherwise
    only the original image is used. Boxes from every view are mapped back to
    the original orientation with ``_map_back_xyxy``, and boxes whose
    confidence is below ``conf_thr`` are dropped. Detections from different
    views are not merged here, so the same object can appear once per view.

    Args:
        model: Ultralytics YOLO model, called as ``model([img], verbose=False)``.
        img_path: Path to the image file.
        conf_thr: Minimum confidence for a box to be kept.

    Returns:
        List of dicts ``{'x': center_x, 'y': center_y, 'confidence': float}``
        in original-image pixel coordinates.
    """
    # Load a 3-channel BGR image; YOLO expects color input.
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        # Fall back to PIL. convert('RGB') makes grayscale/RGBA input 3-channel,
        # so the RGB->BGR conversion cannot fail; the context manager closes the file.
        with Image.open(img_path) as pil_img:
            img = cv2.cvtColor(np.array(pil_img.convert('RGB')), cv2.COLOR_RGB2BGR)
    h, w = img.shape[:2]

    detections = []
    modes = TTA_MODES if ENABLE_TTA else ['orig']
    for mode in modes:
        aug = _apply_tta(img, mode)

        # Ultralytics accepts numpy images directly.
        res = model([aug], verbose=False)[0]
        if res.boxes is None or len(res.boxes) == 0:
            continue

        # Move all scores/boxes to host once, instead of one transfer per box.
        confs = res.boxes.conf.cpu().numpy()
        boxes = res.boxes.xyxy.cpu().numpy()
        for xyxy, raw_conf in zip(boxes, confs):
            conf = float(raw_conf)
            if conf < conf_thr:
                continue
            x1, y1, x2, y2 = _map_back_xyxy(xyxy, mode, w, h)
            detections.append({
                'x': float((x1 + x2) / 2.0),
                'y': float((y1 + y2) / 2.0),
                'confidence': conf,
            })

    return detections

# Part 4 — 3D NMS & Per-Tomogram Processing


In [None]:
def perform_3d_nms(detections, iou_threshold: float, box_size: float = 24):
    """Greedy distance-based suppression of 3D point detections.

    Despite the parameter name this is not IoU-based. Two detections count as
    duplicates when their Euclidean distance in (z, y, x) is at most
    ``box_size * iou_threshold``. Detections are visited in descending
    confidence order, with ties keeping their input order. Each kept
    detection removes every remaining detection within that radius.

    Args:
        detections: Iterable of dicts with keys 'z', 'y', 'x', 'confidence'.
            The input is not modified.
        iou_threshold: Scale factor applied to ``box_size`` to get the
            suppression radius.
        box_size: Approximate lateral object size in pixels (default 24, the
            previously hard-coded value).

    Returns:
        New list of the kept detection dicts, sorted by descending confidence.
    """
    if not detections:
        return []

    def distance_3d(d1, d2):
        return np.sqrt((d1['z'] - d2['z'])**2 + (d1['y'] - d2['y'])**2 + (d1['x'] - d2['x'])**2)

    dist_thr = box_size * iou_threshold
    # sorted() is stable, so equal-confidence detections keep their input order.
    remaining = sorted(detections, key=lambda d: d['confidence'], reverse=True)

    kept = []
    while remaining:
        best = remaining[0]
        kept.append(best)
        # Keep only detections strictly farther than the radius from the new survivor.
        remaining = [d for d in remaining[1:] if distance_3d(d, best) > dist_thr]

    return kept

def _subsample_evenly(items, fraction):
    """Return ``max(1, int(len(items) * fraction))`` evenly spaced elements of non-empty ``items``."""
    count = max(1, int(len(items) * fraction))
    # linspace over the full index range, rounded to integer positions; index 0 is always included.
    picks = np.round(np.linspace(0, len(items) - 1, count)).astype(int)
    return [items[i] for i in picks]


def process_tomogram(tomo_id: str, model, index=0, total=1):
    """Detect the motor in one tomogram and return its submission row.

    The function works in these steps:
    1. Read the ``*.jpg`` slices in ``test_dir/<tomo_id>``.
    2. Keep an evenly spaced subset of them, controlled by ``CONCENTRATION``.
    3. Run TTA inference on each kept slice.
    4. Merge all detections with ``perform_3d_nms`` and keep the most confident one.

    Args:
        tomo_id: Tomogram folder name under ``test_dir``.
        model: Loaded YOLO model, passed through to ``infer_image_with_tta``.
        index: Position of this tomogram, used only in the progress message.
        total: Number of tomograms, used only in the progress message.

    Returns:
        Dict with keys 'tomo_id', 'Motor axis 0' (z), 'Motor axis 1' (y) and
        'Motor axis 2' (x). All three axes are -1 when nothing is detected,
        including when the folder contains no ``.jpg`` slices.
    """
    print(f"Processing tomogram {tomo_id} ({index}/{total})")
    tomo_dir = os.path.join(test_dir, tomo_id)
    all_slice_files = sorted([f for f in os.listdir(tomo_dir) if f.endswith('.jpg')])
    no_motor = {'tomo_id': tomo_id, 'Motor axis 0': -1, 'Motor axis 1': -1, 'Motor axis 2': -1}

    # Without this guard the subsampling below would index an empty list (IndexError).
    if not all_slice_files:
        print(f"No .jpg slices found in {tomo_dir}")
        return no_motor

    # Subsample by CONCENTRATION (for quick submissions)
    slice_files = _subsample_evenly(all_slice_files, CONCENTRATION)
    print(f"Using {len(slice_files)} / {len(all_slice_files)} slices (CONCENTRATION={CONCENTRATION})")

    use_cuda = device.startswith('cuda')
    # Round-robin pool of up to 4 CUDA streams, kept for parity; TTA still runs per image.
    streams = [torch.cuda.Stream() for _ in range(min(4, BATCH_SIZE))] if use_cuda else [None]

    all_detections = []
    for k, slice_file in enumerate(slice_files):
        stream = streams[k % len(streams)]
        with torch.cuda.stream(stream) if (stream and use_cuda) else nullcontext():
            img_path = os.path.join(tomo_dir, slice_file)
            # Assumes names like 'slice_0123.jpg': z is the token after the first '_' up to '.'.
            slice_num = int(slice_file.split('_')[1].split('.')[0])

            with GPUProfiler(f"Inference (TTA) slice {k+1}/{len(slice_files)}"):
                dets = infer_image_with_tta(model, img_path, CONFIDENCE_THRESHOLD)

            # Attach the slice index and round centers to integer pixel coordinates.
            all_detections.extend(
                {
                    'z': slice_num,
                    'y': round(d['y']),
                    'x': round(d['x']),
                    'confidence': float(d['confidence']),
                }
                for d in dets
            )

    if use_cuda:
        torch.cuda.synchronize()

    # 3D NMS across the whole volume, then keep the most confident survivor.
    final_dets = perform_3d_nms(all_detections, NMS_IOU_THRESHOLD)
    if not final_dets:
        return no_motor

    # max() returns the first of equally confident detections, like a stable descending sort.
    best = max(final_dets, key=lambda d: d['confidence'])
    return {
        'tomo_id': tomo_id,
        'Motor axis 0': int(best['z']),
        'Motor axis 1': int(best['y']),
        'Motor axis 2': int(best['x']),
    }


# Part 5 — Debug helpers & Submission generation


In [None]:
def debug_image_loading(tomo_id: str):
    """Quick check: can we read images with PIL/CV2 and run the model on one frame.

    The check proceeds as follows:
    1. Pick the middle ``*.jpg`` slice of ``test_dir/<tomo_id>``.
    2. Print its shape and dtype as read by PIL, by OpenCV in grayscale, and by
       OpenCV in color converted to RGB.
    3. Load a separate YOLO model from ``model_path`` and run it on that slice.

    Failures are printed rather than raised, so the caller is never aborted.
    """
    tomo_dir = os.path.join(test_dir, tomo_id)
    slice_files = sorted([f for f in os.listdir(tomo_dir) if f.endswith('.jpg')])
    if not slice_files:
        print(f"No image files found in {tomo_dir}")
        return

    print(f"Found {len(slice_files)} image files in {tomo_dir}")
    sample = slice_files[len(slice_files) // 2]
    img_path = os.path.join(tomo_dir, sample)

    try:
        # Convert once, inside a context manager so the file handle is closed.
        with Image.open(img_path) as img_pil:
            arr_pil = np.array(img_pil)
        print(f"PIL shape={arr_pil.shape}, dtype={arr_pil.dtype}")

        # cv2.imread returns None instead of raising; fail with a clear message.
        img_cv2 = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img_cv2 is None:
            raise OSError(f"cv2.imread could not decode {img_path}")
        print(f"CV2 gray shape={img_cv2.shape}, dtype={img_cv2.dtype}")

        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            raise OSError(f"cv2.imread could not decode {img_path}")
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        print(f"CV2 RGB shape={img_rgb.shape}, dtype={img_rgb.dtype}")
        print("Image loading OK.")
    except Exception as e:
        print(f"Image loading error: {e}")

    try:
        test_model = YOLO(model_path)
        _ = test_model([cv2.imread(img_path)], verbose=False)
        print("YOLO processed the test image OK.")
    except Exception as e:
        print(f"YOLO test error: {e}")

def generate_submission():
    """
    Orchestrate end-to-end inference over all tomograms and save submission.csv.

    The steps are:
    1. List the tomogram folders in ``test_dir``.
    2. Smoke-test image loading on the first folder.
    3. Load the YOLO model from ``model_path`` onto ``device``. On CUDA this
       also attempts layer fusion and FP16.
    4. Run ``process_tomogram`` on each folder.
    5. Write one row per tomogram to ``submission_path``.

    A tomogram that raises is recorded as "no motor" (-1 on every axis)
    instead of aborting the run.

    Returns:
        The submission DataFrame with columns
        ['tomo_id', 'Motor axis 0', 'Motor axis 1', 'Motor axis 2'].
    """
    columns = ['tomo_id', 'Motor axis 0', 'Motor axis 1', 'Motor axis 2']

    test_tomos = sorted([d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d))])
    total = len(test_tomos)
    print(f"Found {total} tomograms in: {test_dir}")

    if total > 0:
        debug_image_loading(test_tomos[0])

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"Loading YOLO model from: {model_path}")
    model = YOLO(model_path)
    model.to(device)

    # Optional speed/precision tweaks
    if device.startswith('cuda'):
        try:
            model.fuse()
        except Exception as e:
            # Fusion is only a speed-up; keep the unfused model, but say so.
            print(f"Layer fusion skipped: {e}")
        if torch.cuda.get_device_capability(0)[0] >= 7:
            try:
                # NOTE(review): the Ultralytics predictor may cast the model back to
                # FP32 unless predictions are requested with half=True — confirm FP16
                # is actually in effect during inference.
                model.model.half()
                print("Using half precision (FP16) for inference")
            except Exception:
                print("FP16 not applied (model backend may not support).")

    results = []
    motors_found = 0

    # Simple single-thread orchestration (safe with per-image TTA)
    for i, tomo_id in enumerate(test_tomos, 1):
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            out = process_tomogram(tomo_id, model, i, total)
            results.append(out)

            # In this pipeline, -1 indicates "no detection"
            has_motor = (out['Motor axis 0'] >= 0)
            if has_motor:
                motors_found += 1
                print(f"Motor in {tomo_id}: z={out['Motor axis 0']}, y={out['Motor axis 1']}, x={out['Motor axis 2']}")
            else:
                print(f"No motor detected in {tomo_id}")

            rate = 100.0 * motors_found / len(results)
            print(f"Current detection rate: {motors_found}/{len(results)} ({rate:.1f}%)")

        except Exception as e:
            print(f"Error processing {tomo_id}: {e}")
            results.append({'tomo_id': tomo_id, 'Motor axis 0': -1, 'Motor axis 1': -1, 'Motor axis 2': -1})

    # Passing `columns` keeps the header when `results` is empty; selecting
    # columns from pd.DataFrame([]) would raise KeyError.
    submission_df = pd.DataFrame(results, columns=columns)
    submission_df.to_csv(submission_path, index=False)

    print("\nSubmission complete!")
    print(f"Motors detected: {motors_found}/{total} ({(100.0*motors_found/max(1,total)):.1f}%)")
    print(f"Saved to: {submission_path}")
    print("\nPreview:")
    print(submission_df.head())
    return submission_df


# Part 6 — Main


In [None]:
if __name__ == "__main__":
    # perf_counter is monotonic, so the reported duration is unaffected by
    # system clock adjustments during the (long) run.
    start = time.perf_counter()
    submission = generate_submission()
    elapsed = time.perf_counter() - start
    print(f"\nTotal execution time: {elapsed:.2f} s ({elapsed/60:.2f} min)")