## YOLO26 Nano Training for MIAP Person Detection

This notebook performs the following steps:
1.  **Setup**: Installs necessary libraries and connects to Google Drive and Google Cloud for data access.
2.  **Data Preparation**: 
    - Downloads annotations for MIAP (More Inclusive Annotations for People, the person subset of Google's Open Images dataset) from a GCS bucket.
    - Downloads the corresponding images from the public Open Images dataset bucket.
    - Filters out bounding boxes that are very small (width or height < 6 pixels once the image is scaled to the training resolution) and drops images left with no valid boxes.
    - Converts the dataset into the YOLO format required by Ultralytics.
3.  **Training**: Trains a `yolo26n` (nano) model on the prepared dataset.
4.  **Export**: 
    - Exports the trained model to ONNX format, including preprocessing and NMS post-processing.
    - Creates a Float32 version.
    - Creates an INT8 quantized version using static calibration.
5.  **Save & Verify**: Saves the final models to Google Drive and runs a quick verification with ONNX Runtime.

In [None]:
%pip install ultralytics onnx onnxruntime onnxsim pandas gcsfs tqdm -q

### 1. Setup and Authentication

Mount Google Drive to save the final models and authenticate with Google Cloud to access the dataset annotations.

In [None]:
from google.colab import drive, auth
import os

# Mount Drive first: GDRIVE_SAVE_DIR below lives under the '/content/drive' mount point.
print("Mounting Google Drive...")
drive.mount('/content/drive')

# Authenticate this Colab session with Google Cloud; later cells read the
# annotations CSV and images from GCS.
print("Authenticating with Google Cloud...")
auth.authenticate_user()

# Define a directory in your Google Drive to save the models
GDRIVE_SAVE_DIR = '/content/drive/MyDrive/miap_yolov26_models'
# exist_ok=True keeps this cell safe to re-run.
os.makedirs(GDRIVE_SAVE_DIR, exist_ok=True)
print(f"Models will be saved to: {GDRIVE_SAVE_DIR}")

In [None]:
import onnx
import numpy as np
from onnx import helper, numpy_helper, TensorProto
import onnxruntime as ort
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantFormat, QuantType, CalibrationMethod

def embed_uint8_preprocess_into_onnx(in_onnx, out_onnx, imgsz, input_scale=1.0/255.0):
    """
    Prepend uint8 normalization and a resize to an ONNX graph.

    The model's first graph input is replaced by a new input named
    ``images_uint8`` with shape [1, 3, H, W] (uint8, dynamic H and W). Three
    nodes are placed in front of the existing graph:

      Cast (uint8 -> float32) -> Mul (by ``input_scale``) -> Resize (linear)

    The Resize node writes to the name of the original input, so the rest of
    the graph is left untouched. The Resize target is the fixed shape
    [1, 3, imgsz, imgsz]; it is a plain stretch, so the aspect ratio of the
    incoming image is not preserved (no letterboxing).

    Args:
        in_onnx: Path of the ONNX model to read.
        out_onnx: Path the modified model is written to.
        imgsz: Side length of the square tensor the original graph receives.
        input_scale: Factor applied after the cast (default 1/255).

    Returns:
        ``out_onnx``, for convenience.

    Raises:
        Whatever ``onnx.checker.check_model`` raises if the resulting graph is
        invalid.
    """
    model = onnx.load(in_onnx)
    graph = model.graph
    orig_in_name = graph.input[0].name

    new_input_name = "images_uint8"
    new_input = helper.make_tensor_value_info(new_input_name, TensorProto.UINT8, [1, 3, "H", "W"])

    cast_out = f"{orig_in_name}__cast_f32"
    scaled_out = f"{orig_in_name}__scaled"
    # The Resize output reuses the original input name so downstream nodes need no rewiring.
    resize_out = orig_in_name

    # Preprocessing initializers. ROI and scales are passed as empty tensors so
    # that Resize is driven by the explicit output sizes.
    scale_tensor = numpy_helper.from_array(np.array([input_scale], dtype=np.float32), name="Preprocess_Scale_Val")
    roi = numpy_helper.from_array(np.array([], dtype=np.float32), name="Preprocess_ROI")
    scales = numpy_helper.from_array(np.array([], dtype=np.float32), name="Preprocess_Scales")
    sizes = numpy_helper.from_array(np.array([1, 3, imgsz, imgsz], dtype=np.int64), name="Preprocess_Sizes")
    graph.initializer.extend([scale_tensor, roi, scales, sizes])

    preprocess_nodes = [
        helper.make_node("Cast", inputs=[new_input_name], outputs=[cast_out], to=TensorProto.FLOAT),
        helper.make_node("Mul", inputs=[cast_out, "Preprocess_Scale_Val"], outputs=[scaled_out]),
        helper.make_node("Resize",
                         inputs=[scaled_out, "Preprocess_ROI", "Preprocess_Scales", "Preprocess_Sizes"],
                         outputs=[resize_out], mode="linear"),
    ]

    # Swap the graph input for the uint8 one.
    del graph.input[0]
    graph.input.insert(0, new_input)

    # Repeated *message* fields in the protobuf Python API reject item/slice
    # assignment (`graph.node[:0] = ...` raises TypeError), so prepend with
    # insert(). Inserting in reverse keeps Cast -> Mul -> Resize order.
    for node in reversed(preprocess_nodes):
        graph.node.insert(0, node)

    onnx.checker.check_model(model)
    onnx.save(model, out_onnx)
    print(f"Preprocessed ONNX saved to {out_onnx}")
    return out_onnx

def fix_onnx_outputs(onnx_path):
    """Rename the model's primary (first) graph output to ``detections``.

    The file at ``onnx_path`` is rewritten in place. Only the tensor whose
    name exactly equals the first graph output is renamed; a substring match
    on "output" would also hit intermediate tensors (exporters commonly name
    them ``..._output_0``) and collapse them all onto one name, producing an
    invalid graph. Any additional graph outputs are left unchanged so output
    names stay unique.

    Args:
        onnx_path: Path of the ONNX model to modify in place.

    Raises:
        Whatever ``onnx.checker.check_model`` raises if the renamed graph is
        invalid.
    """
    target_name = "detections"
    model = onnx.load(onnx_path)
    graph = model.graph

    # Nothing to rename: no outputs, or the primary output already has the name.
    if len(graph.output) == 0:
        return
    old_name = graph.output[0].name
    if old_name == target_name:
        return

    graph.output[0].name = target_name

    # Keep producers and any consumers of the tensor in sync with the new name.
    for node in graph.node:
        for idx, name in enumerate(node.output):
            if name == old_name:
                node.output[idx] = target_name
        for idx, name in enumerate(node.input):
            if name == old_name:
                node.input[idx] = target_name

    # Shape/type annotations are keyed by tensor name as well.
    for info in graph.value_info:
        if info.name == old_name:
            info.name = target_name

    onnx.checker.check_model(model)
    onnx.save(model, onnx_path)


### 2. Data Preparation

Download the `vertex_miap_import.csv` file, which contains GCS paths to the images and their corresponding bounding box annotations.

In [None]:
import pandas as pd
import os
import gcsfs
from PIL import Image
from tqdm.notebook import tqdm
import numpy as np
import random

# CONFIGURATION
GCS_BUCKET = 'colin-miap-madness'
CSV_FILENAME = 'vertex_miap_import.csv'
GCS_CSV_PATH = f'gs://{GCS_BUCKET}/{CSV_FILENAME}'
IMAGE_SIZE = 320          # square training / export resolution
MIN_BOX_PIXEL_SIZE = 6    # minimum box side, measured after scaling to IMAGE_SIZE
VAL_SPLIT = 0.07  # 7% for validation
n_samples = 10000         # stop once this many training images have been kept

# Initialize GCS FileSystem (uses default project/auth from Colab environment)
fs = gcsfs.GCSFileSystem()

DATASET_ROOT = '/content/datasets/miap_single_class'
for split in ['train', 'val']:
    os.makedirs(os.path.join(DATASET_ROOT, 'images', split), exist_ok=True)
    os.makedirs(os.path.join(DATASET_ROOT, 'labels', split), exist_ok=True)


def _yolo_labels_for_group(group, img_w, img_h, imgsz, min_box_px):
    """Convert one image's annotation rows into YOLO label lines.

    Args:
        group: DataFrame rows for a single image with normalized
            ``x_min``, ``y_min``, ``x_max``, ``y_max`` columns.
        img_w, img_h: Image size in pixels.
        imgsz: Target training resolution used to measure box size.
        min_box_px: Boxes narrower or shorter than this (at ``imgsz``
            resolution) are discarded.

    Returns:
        List of ``'0 cx cy w h'`` strings (class 0, normalized centre/size).
    """
    boxes = group[['x_min', 'y_min', 'x_max', 'y_max']].to_numpy(dtype=np.float64)
    # Rows with missing coordinates would otherwise slip past the size filter
    # (NaN comparisons are False) and be written as 'nan' into the label file.
    boxes = boxes[~np.isnan(boxes).any(axis=1)]
    # Keep coordinates inside the image so labels stay within the normalized range.
    boxes = np.clip(boxes, 0.0, 1.0)

    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]

    # Box size is judged at the resolution the model trains on: the image is
    # scaled so that its longer side fits imgsz.
    scale = min(imgsz / img_w, imgsz / img_h)
    keep = (widths * img_w * scale >= min_box_px) & (heights * img_h * scale >= min_box_px)

    centers_x = (boxes[:, 0] + boxes[:, 2]) / 2.0
    centers_y = (boxes[:, 1] + boxes[:, 3]) / 2.0
    return [
        f'0 {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}'
        for cx, cy, w, h in zip(centers_x[keep], centers_y[keep], widths[keep], heights[keep])
    ]


print('Reading annotations CSV from GCS...')
col_names = ['ml_use', 'gcs_path', 'label', 'x_min', 'y_min', 'c1', 'c2', 'x_max', 'y_max', 'c3', 'c4']
df = pd.read_csv(GCS_CSV_PATH, header=None, names=col_names)
print(f'Found {len(df)} annotations.')

# One (gcs_path, rows) pair per image, shuffled with a fixed seed for a reproducible order.
grouped = list(df.groupby('gcs_path'))
random.seed(42)
random.shuffle(grouped)

images_processed = {'train': 0, 'val': 0}
images_dropped = 0
images_failed = 0  # download or decode failures (previously skipped silently)

# Every val_period-th kept image goes to validation.
val_period = int(1 / VAL_SPLIT)

print(f'Processing and filtering images (min_box={MIN_BOX_PIXEL_SIZE}px)...')
for gcs_path, group in tqdm(grouped):
    kept_so_far = images_processed['train'] + images_processed['val']
    split = 'val' if kept_so_far % val_period == 0 else 'train'

    image_id = gcs_path.split('/')[-1].replace('.jpg', '')
    local_image_path = os.path.join(DATASET_ROOT, 'images', split, f'{image_id}.jpg')
    local_label_path = os.path.join(DATASET_ROOT, 'labels', split, f'{image_id}.txt')

    if not os.path.exists(local_image_path):
        try:
            # Pull directly from GCS bucket to stay on internal network
            fs.get(gcs_path, local_image_path)
        except Exception:
            # Best-effort download: skip the image, but count it and remove any
            # partial file so a re-run does not mistake it for a finished download.
            if os.path.exists(local_image_path):
                os.remove(local_image_path)
            images_failed += 1
            continue

    try:
        with Image.open(local_image_path) as img:
            img_w, img_h = img.size
    except Exception:
        # Unreadable image: delete it so it cannot end up in the dataset.
        if os.path.exists(local_image_path):
            os.remove(local_image_path)
        images_failed += 1
        continue

    yolo_labels = _yolo_labels_for_group(group, img_w, img_h, IMAGE_SIZE, MIN_BOX_PIXEL_SIZE)

    if yolo_labels:
        with open(local_label_path, 'w') as f:
            f.write('\n'.join(yolo_labels))
        images_processed[split] += 1
    else:
        # No box survived the filter: the image is removed from the dataset.
        if os.path.exists(local_image_path):
            os.remove(local_image_path)
        images_dropped += 1
    if images_processed["train"] == n_samples:
        break

print('\n--- Data Preparation Summary ---')
print(f"Training images: {images_processed['train']}")
print(f"Validation images: {images_processed['val']}")
print(f'Images dropped (no valid boxes): {images_dropped}')
print(f'Images skipped (download/decode failure): {images_failed}')


#### Create Dataset YAML File

In [None]:
import yaml
# Dataset descriptor consumed by the training cell: dataset root, the
# image sub-directories for each split, and the single class name.
dataset_yaml_path = os.path.join(DATASET_ROOT, 'data.yaml')
yaml_content = dict(
    path=os.path.abspath(DATASET_ROOT),
    train='images/train',
    val='images/val',
    names={0: 'person'},
)
with open(dataset_yaml_path, 'w') as yaml_file:
    yaml.dump(yaml_content, yaml_file)
print(f'Dataset YAML created at: {dataset_yaml_path}')


### 3. Model Training

In [None]:
from ultralytics import YOLO

MODEL_VARIANT = 'yolo26n.pt'

# Training arguments gathered in one place; the run is written under runs/miap_person_detector.
train_settings = dict(
    data=dataset_yaml_path,
    imgsz=IMAGE_SIZE,
    epochs=50,
    batch=32,
    name='miap_person_detector',
    project='runs',
)

model = YOLO(MODEL_VARIANT)
results = model.train(**train_settings)
print('\nTraining complete!')


### 4. Export to ONNX

In [None]:
import shutil
from ultralytics import YOLO

# Reload the best checkpoint written by the training run.
best_weights_path = os.path.join(results.save_dir, 'weights/best.pt')
model = YOLO(best_weights_path)

print('\n1. Exporting Raw FP32 ONNX model...')
# Export at the training resolution (IMAGE_SIZE is a single int, i.e. a square
# input) so it matches the Resize target embedded in step 2.
fp32_raw_path = model.export(format='onnx', imgsz=IMAGE_SIZE, opset=17, simplify=True)

print('\n2. Embedding Preprocessing (uint8 -> float32 -> resize) into model...')
# Derive the new name from the file extension only; str.replace('.onnx', ...)
# would also rewrite any '.onnx' occurring elsewhere in the path.
raw_stem, raw_ext = os.path.splitext(fp32_raw_path)
fp32_pre_path = f'{raw_stem}_pre_u8{raw_ext}'
embed_uint8_preprocess_into_onnx(fp32_raw_path, fp32_pre_path, IMAGE_SIZE)

print('\n3. Finalizing FP32 model labels...')
fix_onnx_outputs(fp32_pre_path)
print(f'FP32 ONNX (with preprocessing) saved to: {fp32_pre_path}')


### 5. Advanced Static Quantization (Improved)

YOLO models can be sensitive to static quantization. We use ONNX Runtime's advanced quantization tools directly, using **Entropy (KL Divergence)** calibration and a larger calibration set.

In [None]:
import glob

class MIAPCalibrationDataReader(CalibrationDataReader):
    """Feeds uint8 calibration images to ONNX Runtime static quantization.

    Each call to ``get_next`` yields one image resized to ``imgsz`` x ``imgsz``
    as a [1, 3, imgsz, imgsz] uint8 array keyed by the model's first input name.

    Args:
        image_dir: Directory searched (non-recursively) for ``*.jpg`` files.
        imgsz: Side length the images are resized to.
        max_images: Upper bound on the number of calibration images.
        model_path: ONNX model whose first input name is used as the feed key.
            Defaults to the notebook-level ``fp32_pre_path``.
        seed: Seed for choosing the calibration subset, so repeated runs
            calibrate on the same images.
    """

    def __init__(self, image_dir, imgsz, max_images=1000, model_path=None, seed=42):
        # glob order depends on the filesystem; sort first, then shuffle with a
        # private seeded RNG so the selected subset is reproducible.
        candidates = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))
        random.Random(seed).shuffle(candidates)
        self.image_paths = candidates[:max_images]
        self.imgsz = imgsz
        self.index = 0

        # Get input name from model
        if model_path is None:
            model_path = fp32_pre_path
        session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        self.input_name = session.get_inputs()[0].name

    def get_next(self):
        """Return the next ``{input_name: array}`` feed, or None when exhausted."""
        if self.index >= len(self.image_paths):
            return None

        # Load and resize to exact imgsz in uint8 to match the new uint8 input.
        # The context manager closes the file handle; resize() returns an
        # independent in-memory image, so it remains usable afterwards.
        with Image.open(self.image_paths[self.index]) as img:
            resized = img.convert('RGB').resize((self.imgsz, self.imgsz), Image.BILINEAR)
        # HWC -> CHW, then add the batch dimension.
        input_data = np.array(resized).transpose(2, 0, 1)[None, ...].astype(np.uint8)

        self.index += 1
        return {self.input_name: input_data}

print('Starting Static Quantization...')
# Calibration draws from the training images; 2000 of them is used here as a
# balance between calibration coverage and run time.
calibration_image_dir = os.path.join(DATASET_ROOT, 'images', 'train')
dr = MIAPCalibrationDataReader(calibration_image_dir, IMAGE_SIZE, max_images=2000)

int8_path = fp32_pre_path.replace('.onnx', '_int8.onnx')

# QDQ format, uint8 activations, per-channel int8 weights, entropy (KL divergence) calibration.
quantization_settings = dict(
    model_input=fp32_pre_path,
    model_output=int8_path,
    calibration_data_reader=dr,
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    per_channel=True,
    reduce_range=False,
    calibrate_method=CalibrationMethod.Entropy,
)
quantize_static(**quantization_settings)

print(f'\nINT8 ONNX model saved to: {int8_path}')


### 6. Save & Inspect ONNX Models

In [None]:
def inspect_onnx_model(model_path):
    """Print an ONNX model's inputs/outputs and run one dummy inference.

    The dummy input is [1, 3, IMAGE_SIZE, IMAGE_SIZE] and its dtype follows
    the model's first input: random uint8 pixels for a uint8 input, random
    float32 values in [0, 1) otherwise.

    Args:
        model_path: Path of the ONNX model to load with ONNX Runtime (CPU).
    """
    print(f'\n--- Inspecting: {os.path.basename(model_path)} ---')
    sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    input_nodes = sess.get_inputs()
    print(f'Inputs: {[(n.name, n.shape, n.type) for n in input_nodes]}')
    print(f'Outputs: {[(n.name, n.shape, n.type) for n in sess.get_outputs()]}')

    first_input = input_nodes[0]
    # ONNX Runtime reports element types as lowercase strings such as
    # 'tensor(uint8)', so the comparison must not be case-sensitive on 'Uint8'.
    if "uint8" in first_input.type.lower():
        # randint's upper bound is exclusive; 256 makes the full 0..255 range reachable.
        dummy_input = np.random.randint(0, 256, size=(1, 3, IMAGE_SIZE, IMAGE_SIZE), dtype=np.uint8)
    else:
        dummy_input = np.random.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).astype(np.float32)

    outputs = sess.run(None, {first_input.name: dummy_input})
    print(f'Output shapes: {[o.shape for o in outputs]}')

# Destination paths in Drive keep the local file names.
final_fp32 = os.path.join(GDRIVE_SAVE_DIR, os.path.basename(fp32_pre_path))
final_int8 = os.path.join(GDRIVE_SAVE_DIR, os.path.basename(int8_path))

# Copy both models (with metadata) before inspecting the copies.
for local_model, drive_model in ((fp32_pre_path, final_fp32), (int8_path, final_int8)):
    shutil.copy2(local_model, drive_model)

print(f'\nModels saved to Drive: {GDRIVE_SAVE_DIR}')
for drive_model in (final_fp32, final_int8):
    inspect_onnx_model(drive_model)
