# UniFace: MLX vs ONNX Benchmark

This notebook compares the performance and accuracy of MLX and ONNX backends for RetinaFace face detection.

## What we're testing:
- **MLX**: Native Apple Silicon implementation using fused weights (BatchNorm folded into Conv)
- **ONNX**: Standard ONNX Runtime, using the CoreML execution provider when it is available

## Requirements:
- Apple Silicon Mac (M1/M2/M3/M4)
- Both `mlx` and `onnxruntime` installed
- Fused weights generated via `scripts/convert_onnx_to_mlx.py`

In [None]:
import sys
import time
import warnings
from pathlib import Path

import cv2
import numpy as np
import matplotlib.pyplot as plt

# Add the project root to sys.path so the local `uniface` package can be imported.
# NOTE(review): assumes the notebook is launched from a direct subdirectory of the repo root.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Silence library warnings so the benchmark output stays readable.
warnings.filterwarnings('ignore')

print(f"Project root: {project_root}")

# Check that the fused MLX weights (produced by scripts/convert_onnx_to_mlx.py) exist.
fused_weights_path = project_root / "weights_mlx_fused" / "retinaface_mnet_v2.safetensors"
if fused_weights_path.exists():
    print(f"✓ Fused weights found: {fused_weights_path}")
else:
    print("✗ Fused weights not found. Run: python scripts/convert_onnx_to_mlx.py")

In [None]:
# Check which inference backends are importable and report basic system info.
import platform

print("System Information:")
print(f"  Platform: {platform.platform()}")
print(f"  Machine: {platform.machine()}")

# Check MLX
try:
    import mlx.core as mx
    # The version string lives on mlx.core; fall back gracefully if it is absent.
    mlx_version = getattr(mx, '__version__', 'unknown')
    print(f"\n✓ MLX available: {mlx_version}")
    mlx_available = True
except ImportError:
    print("\n✗ MLX not available. Install with: pip install mlx")
    mlx_available = False

# Check ONNX Runtime and list the execution providers this build offers.
try:
    import onnxruntime as ort
    providers = ort.get_available_providers()
    print(f"✓ ONNX Runtime available: {ort.__version__}")
    print(f"  Providers: {providers}")
    onnx_available = True
except ImportError:
    print("✗ ONNX not available. Install with: pip install onnxruntime")
    onnx_available = False

## 1. Load Test Image

In [None]:
# Load the test image: prefer assets/test.jpg, otherwise fall back to another image in assets/.
assets_dir = project_root / "assets"
test_image_path = assets_dir / "test.jpg"

if not test_image_path.exists() and assets_dir.exists():
    # Figures saved later by this notebook also go to assets/; never use them as input.
    generated_outputs = {"benchmark_retinaface.png", "detection_comparison.png"}
    # Sorted so the fallback choice is deterministic; .jpg files are still preferred over .png.
    candidates = [
        path
        for pattern in ("*.jpg", "*.png")
        for path in sorted(assets_dir.glob(pattern))
        if path.name not in generated_outputs
    ]
    if candidates:
        test_image_path = candidates[0]

# cv2.imread returns None (instead of raising) when a file cannot be decoded.
image = cv2.imread(str(test_image_path)) if test_image_path.exists() else None

if image is not None:
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    print(f"Image loaded: {test_image_path.name}")
    print(f"Image shape: {image.shape}")

    plt.figure(figsize=(10, 8))
    plt.imshow(image_rgb)
    plt.title("Test Image")
    plt.axis('off')
    plt.show()
elif test_image_path.exists():
    print(f"Could not read image: {test_image_path}")
else:
    print("No test image available. Please add an image to assets/test.jpg")

## 2. Setup Detection Functions

In [None]:
from uniface.common import (
    decode_boxes,
    decode_landmarks,
    generate_anchors,
    non_max_suppression,
    resize_image,
)
from uniface.constants import RetinaFaceWeights
from uniface.model_store import verify_model_weights
from uniface.onnx_utils import create_onnx_session

# For MLX
if mlx_available:
    from uniface.detection.retinaface_mlx import RetinaFaceNetworkFused
    from uniface.mlx_utils import load_mlx_fused_weights, synchronize


# Per-channel (B, G, R) means subtracted before inference; images come from cv2 in BGR order.
_MEAN_BGR = np.array([104, 117, 123], dtype=np.float32)
# Network input resolution used by both backends.
_INPUT_SIZE = (640, 640)
# Maximum number of highest-scoring candidates kept before NMS.
_PRE_NMS_TOPK = 5000


def _preprocess(image, input_size):
    """Resize *image* to *input_size* and subtract the channel means.

    Returns:
        (processed, resize_factor, width, height): the mean-subtracted float32 HWC
        array, the factor returned by ``resize_image``, and the resized image's
        width and height.
    """
    image_resized, resize_factor = resize_image(image, target_shape=input_size)
    height, width, _ = image_resized.shape
    processed = np.float32(image_resized) - _MEAN_BGR
    return processed, resize_factor, width, height


def _postprocess(loc, conf, landmarks, input_size, width, height, resize_factor,
                 conf_thresh, nms_thresh):
    """Decode raw RetinaFace head outputs into face dicts in original-image pixels.

    Shared by both backends so the accuracy comparison only reflects inference
    differences. ``conf`` must hold per-anchor class probabilities with the face
    score in column 1 (the MLX path applies softmax first; the ONNX output is
    used as-is).
    """
    priors = generate_anchors(image_size=input_size)
    boxes = decode_boxes(loc, priors)
    landmarks_decoded = decode_landmarks(landmarks, priors)

    # Scale decoded coordinates to resized-image pixels, then undo the resize.
    boxes = boxes * np.array([width, height] * 2) / resize_factor
    landmarks_decoded = landmarks_decoded * np.array([width, height] * 5) / resize_factor

    scores = conf[:, 1]
    mask = scores > conf_thresh
    if not mask.any():
        # Nothing above threshold: skip sorting/NMS on empty arrays.
        return []
    boxes, landmarks_decoded, scores = boxes[mask], landmarks_decoded[mask], scores[mask]

    order = scores.argsort()[::-1][:_PRE_NMS_TOPK]
    boxes, landmarks_decoded, scores = boxes[order], landmarks_decoded[order], scores[order]

    detections = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32)
    keep = non_max_suppression(detections, nms_thresh)
    detections, landmarks_decoded = detections[keep], landmarks_decoded[keep]

    return [
        {
            'bbox': det[:4],
            'confidence': float(det[4]),
            'landmarks': lmk.reshape(5, 2),
        }
        for det, lmk in zip(detections, landmarks_decoded)
    ]


def detect_faces_onnx(session, image, conf_thresh=0.5, nms_thresh=0.4):
    """Run face detection using ONNX model.

    Args:
        session: ONNX Runtime session for the RetinaFace model.
        image: BGR image array (e.g. from cv2.imread).
        conf_thresh: minimum face score kept before NMS.
        nms_thresh: overlap threshold passed to ``non_max_suppression``.

    Returns:
        List of dicts with 'bbox' (4 values), 'confidence' (float) and
        'landmarks' (5x2 array), in original-image pixel coordinates.
    """
    processed, resize_factor, width, height = _preprocess(image, _INPUT_SIZE)
    # The ONNX model takes NCHW input.
    input_tensor = np.expand_dims(processed.transpose(2, 0, 1), 0)

    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})
    loc = outputs[0].squeeze(0)
    conf = outputs[1].squeeze(0)
    landmarks = outputs[2].squeeze(0)

    return _postprocess(loc, conf, landmarks, _INPUT_SIZE, width, height,
                        resize_factor, conf_thresh, nms_thresh)


def detect_faces_mlx(model, image, conf_thresh=0.5, nms_thresh=0.4):
    """Run face detection using MLX model with fused weights.

    Args:
        model: MLX RetinaFace network returning (cls, bbox, landmark) predictions.
        image: BGR image array (e.g. from cv2.imread).
        conf_thresh: minimum face score kept before NMS.
        nms_thresh: overlap threshold passed to ``non_max_suppression``.

    Returns:
        Same structure as ``detect_faces_onnx``.
    """
    processed, resize_factor, width, height = _preprocess(image, _INPUT_SIZE)
    input_tensor = mx.array(np.expand_dims(processed, 0))  # the MLX model takes NHWC input

    cls_preds, bbox_preds, landmark_preds = model(input_tensor)
    # Presumably forces MLX's lazy graph to run, so timings cover the real inference.
    synchronize(cls_preds, bbox_preds, landmark_preds)

    # The MLX head emits logits; convert them to probabilities to match the ONNX output.
    cls_probs = mx.softmax(cls_preds, axis=-1)

    loc = np.array(bbox_preds).squeeze(0)
    conf = np.array(cls_probs).squeeze(0)
    landmarks = np.array(landmark_preds).squeeze(0)

    return _postprocess(loc, conf, landmarks, _INPUT_SIZE, width, height,
                        resize_factor, conf_thresh, nms_thresh)


print("Detection functions defined ✓")

## 3. Load Models

In [None]:
# Load the ONNX model: resolve its weights through the model store, then open a session.
print("Loading ONNX model...")
onnx_path = verify_model_weights(RetinaFaceWeights.MNET_V2)
onnx_session = create_onnx_session(onnx_path)
print("✓ ONNX model loaded")

# Load the MLX model with fused weights (only if MLX imported and the weights file exists).
if mlx_available and fused_weights_path.exists():
    print("\nLoading MLX model with fused weights...")
    mlx_model = RetinaFaceNetworkFused(backbone_type='mobilenetv2', width_mult=1.0)
    load_mlx_fused_weights(mlx_model, str(fused_weights_path))
    mlx_model.train(False)  # inference mode
    print("✓ MLX model loaded with fused weights")
else:
    # Later cells check `mlx_model is not None` to decide whether to run MLX.
    mlx_model = None
    if not mlx_available:
        print("✗ MLX not available")
    else:
        print("✗ Fused weights not found")

## 4. Accuracy Comparison

In [None]:
def _compare_detections(reference, candidate):
    """Return the worst-case (bbox, landmark, confidence) differences between two detection lists.

    Each face in *reference* is paired with the face in *candidate* whose bounding
    box is closest (smallest maximum per-coordinate difference). The comparison
    therefore does not depend on both backends ranking near-tied faces in the same
    order. Both lists must be non-empty.
    """
    ref_boxes = np.stack([face['bbox'] for face in reference])
    cand_boxes = np.stack([face['bbox'] for face in candidate])
    # pairwise[i, j] = max |ref_boxes[i] - cand_boxes[j]| over the 4 box coordinates
    pairwise = np.abs(ref_boxes[:, None, :] - cand_boxes[None, :, :]).max(axis=-1)
    nearest = pairwise.argmin(axis=1)

    bbox_diff = float(pairwise[np.arange(len(reference)), nearest].max())
    lmk_diff = max(
        float(np.abs(ref['landmarks'] - candidate[j]['landmarks']).max())
        for ref, j in zip(reference, nearest)
    )
    conf_diff = max(
        abs(ref['confidence'] - candidate[j]['confidence'])
        for ref, j in zip(reference, nearest)
    )
    return bbox_diff, lmk_diff, conf_diff


if image is not None:
    print("=" * 60)
    print("ACCURACY COMPARISON")
    print("=" * 60)

    # Run ONNX detection (the reference)
    print("\nRunning ONNX detection...")
    onnx_faces = detect_faces_onnx(onnx_session, image)
    print(f"ONNX detected: {len(onnx_faces)} faces")

    # Run MLX detection
    if mlx_model is not None:
        print("Running MLX detection...")
        mlx_faces = detect_faces_mlx(mlx_model, image)
        print(f"MLX detected: {len(mlx_faces)} faces")

        same_count = len(onnx_faces) == len(mlx_faces)
        if not same_count:
            print(f"\n⚠ Face counts differ: ONNX={len(onnx_faces)}, MLX={len(mlx_faces)}")

        # Compare every ONNX face against its nearest MLX face
        if onnx_faces and mlx_faces:
            print("\n--- Comparison (all faces, matched by nearest bbox) ---")
            bbox_diff, lmk_diff, conf_diff = _compare_detections(onnx_faces, mlx_faces)

            print(f"BBox max diff: {bbox_diff:.4f} pixels")
            print(f"Landmark max diff: {lmk_diff:.4f} pixels")
            print(f"Confidence diff: {conf_diff:.6f}")

            # A sub-pixel tolerance, not bit-exact equality.
            if same_count and bbox_diff < 1.0 and lmk_diff < 1.0:
                print("\n✓ MATCH: MLX and ONNX agree to within 1 pixel")
            else:
                print("\n⚠ Results differ slightly")
    else:
        mlx_faces = None

## 5. Performance Benchmark

In [None]:
def _time_calls(fn, num_runs):
    """Call ``fn()`` *num_runs* times and return per-call wall-clock latency in ms."""
    latencies = np.empty(num_runs, dtype=np.float64)
    for i in range(num_runs):
        start = time.perf_counter()
        fn()
        latencies[i] = (time.perf_counter() - start) * 1000
    return latencies


if image is not None and mlx_model is not None:
    print("=" * 60)
    print("PERFORMANCE BENCHMARK")
    print("=" * 60)

    NUM_WARMUP = 5
    NUM_RUNS = 50

    # End-to-end runs: preprocessing + inference + decoding + NMS.
    def run_onnx():
        return detect_faces_onnx(onnx_session, image)

    def run_mlx():
        return detect_faces_mlx(mlx_model, image)

    # Warm-up runs (session init, kernel compilation, caches); timings are discarded.
    print(f"\nWarming up ONNX ({NUM_WARMUP} runs)...")
    _time_calls(run_onnx, NUM_WARMUP)
    print(f"Warming up MLX ({NUM_WARMUP} runs)...")
    _time_calls(run_mlx, NUM_WARMUP)

    # Benchmark ONNX
    print(f"\nBenchmarking ONNX ({NUM_RUNS} runs)...")
    onnx_times = _time_calls(run_onnx, NUM_RUNS)
    onnx_mean, onnx_std = onnx_times.mean(), onnx_times.std()
    onnx_min, onnx_max = onnx_times.min(), onnx_times.max()

    # Benchmark MLX
    print(f"Benchmarking MLX ({NUM_RUNS} runs)...")
    mlx_times = _time_calls(run_mlx, NUM_RUNS)
    mlx_mean, mlx_std = mlx_times.mean(), mlx_times.std()
    mlx_min, mlx_max = mlx_times.min(), mlx_times.max()

    # Print results
    print("\n" + "-" * 60)
    print("RESULTS")
    print("-" * 60)

    print("\nONNX Runtime (CoreML):")
    # Report the providers actually in use; CoreML may not be among them.
    print(f"  Providers: {onnx_session.get_providers()}")
    print(f"  Mean: {onnx_mean:.2f} ± {onnx_std:.2f} ms")
    print(f"  Min/Max: {onnx_min:.2f} / {onnx_max:.2f} ms")
    print(f"  FPS: {1000/onnx_mean:.1f}")

    print("\nMLX (Native Apple Silicon):")
    print(f"  Mean: {mlx_mean:.2f} ± {mlx_std:.2f} ms")
    print(f"  Min/Max: {mlx_min:.2f} / {mlx_max:.2f} ms")
    print(f"  FPS: {1000/mlx_mean:.1f}")

    # Speedup > 1 means MLX has the lower mean latency.
    speedup = onnx_mean / mlx_mean
    print(f"\n{'=' * 60}")
    if speedup > 1:
        print(f"🚀 MLX is {speedup:.2f}x FASTER than ONNX!")
    else:
        print(f"📊 ONNX is {1/speedup:.2f}x faster than MLX")
    print(f"{'=' * 60}")

## 6. Performance Visualization

In [None]:
# Only plot if the benchmark cell actually ran both backends.
if 'onnx_times' in dir() and 'mlx_times' in dir():
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    colors = ['#2196F3', '#4CAF50']

    # Box plot comparison. Tick labels are set separately because the `labels=`
    # keyword of Axes.boxplot was renamed to `tick_labels` in Matplotlib 3.9.
    ax1 = axes[0]
    ax1.boxplot([onnx_times, mlx_times])
    ax1.set_xticklabels(['ONNX\n(CoreML)', 'MLX\n(Native)'])
    ax1.set_ylabel('Inference Time (ms)')
    ax1.set_title('Inference Time Distribution')
    ax1.grid(axis='y', alpha=0.3)

    # Bar chart of mean latency with std-dev error bars
    ax2 = axes[1]
    means = [onnx_mean, mlx_mean]
    stds = [onnx_std, mlx_std]
    bars = ax2.bar(['ONNX', 'MLX'], means, yerr=stds, color=colors, capsize=5)
    ax2.set_ylabel('Inference Time (ms)')
    ax2.set_title('Mean Inference Time')
    ax2.grid(axis='y', alpha=0.3)

    # Value labels sit just above the error bar; the gap scales with the data.
    time_pad = 0.02 * max(m + s for m, s in zip(means, stds))
    for bar, mean, std in zip(bars, means, stds):
        ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + std + time_pad,
                 f'{mean:.1f}ms', ha='center', va='bottom', fontweight='bold')

    # FPS comparison
    ax3 = axes[2]
    fps_values = [1000 / onnx_mean, 1000 / mlx_mean]
    bars = ax3.bar(['ONNX', 'MLX'], fps_values, color=colors)
    ax3.set_ylabel('Frames Per Second')
    ax3.set_title('Throughput (FPS)')
    ax3.grid(axis='y', alpha=0.3)

    fps_pad = 0.02 * max(fps_values)
    for bar, fps in zip(bars, fps_values):
        ax3.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + fps_pad,
                 f'{fps:.1f}', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()

    # Save the figure (create assets/ if it does not exist yet)
    output_path = project_root / 'assets' / 'benchmark_retinaface.png'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"Saved benchmark visualization to: {output_path}")
    plt.show()

## 7. Visual Detection Results

In [None]:
def draw_detections(image, detections, color=(0, 255, 0)):
    """Return a copy of *image* annotated with face boxes, scores and landmarks.

    Args:
        image: BGR image array (as loaded by cv2.imread); it is not modified.
        detections: iterable of dicts with 'bbox' (x1, y1, x2, y2), 'confidence'
            and 'landmarks' (5x2) entries, as returned by the detect_faces_*
            functions. ``None`` is treated as "no detections".
        color: BGR color for the box and score text. Landmarks are always drawn red.

    Returns:
        The annotated copy of *image*.
    """
    img = image.copy()
    if detections is None:
        return img

    for det in detections:
        # int() truncates toward zero, matching the previous astype(int).
        x1, y1, x2, y2 = (int(v) for v in det['bbox'][:4])
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

        # Put the score just above the box, but keep it on-canvas when the box
        # touches the top edge (12 px is roughly the text height at scale 0.5).
        text_y = max(y1 - 5, 12)
        cv2.putText(img, f"{det['confidence']:.2f}", (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

        for lx, ly in det['landmarks']:
            cv2.circle(img, (int(lx), int(ly)), 3, (0, 0, 255), -1)

    return img


# Side-by-side detections. mlx_faces exists but is None when the MLX model was unavailable,
# so check it explicitly (short-circuit guarantees it is defined before the `is not None`).
if (
    image is not None
    and 'onnx_faces' in dir()
    and 'mlx_faces' in dir()
    and mlx_faces is not None
):
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    # ONNX Detection (boxes drawn in BGR (255, 0, 0))
    img_onnx = draw_detections(image, onnx_faces, color=(255, 0, 0))
    axes[0].imshow(cv2.cvtColor(img_onnx, cv2.COLOR_BGR2RGB))
    axes[0].set_title(f"ONNX Runtime ({len(onnx_faces)} faces)", fontsize=14)
    axes[0].axis('off')

    # MLX Detection (boxes drawn in BGR (0, 255, 0))
    img_mlx = draw_detections(image, mlx_faces, color=(0, 255, 0))
    axes[1].imshow(cv2.cvtColor(img_mlx, cv2.COLOR_BGR2RGB))
    axes[1].set_title(f"MLX Native ({len(mlx_faces)} faces)", fontsize=14)
    axes[1].axis('off')

    plt.suptitle("Detection Comparison: ONNX vs MLX", fontsize=16, fontweight='bold')
    plt.tight_layout()

    # Save the figure (create assets/ if it does not exist yet)
    output_path = project_root / 'assets' / 'detection_comparison.png'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"Saved detection comparison to: {output_path}")
    plt.show()

## 8. System Information

In [None]:
# Report the environment the benchmark ran in. The stdlib imports are repeated so this
# cell does not depend on earlier cells for them.
import platform
import subprocess
import sys

print("=" * 60)
print("SYSTEM INFORMATION")
print("=" * 60)
print(f"Python: {sys.version}")
print(f"Platform: {platform.platform()}")
print(f"Processor: {platform.processor()}")
print(f"Machine: {platform.machine()}")

# Check for Apple Silicon
if platform.machine() == 'arm64' and platform.system() == 'Darwin':
    print("\n✓ Running on Apple Silicon")
    try:
        result = subprocess.run(['sysctl', '-n', 'machdep.cpu.brand_string'],
                                capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError):
        pass  # best-effort: the CPU brand string is informational only
    else:
        print(f"  CPU: {result.stdout.strip()}")

# Package versions
print("\nPackage Versions:")
try:
    import mlx.core as mx
    mlx_version = getattr(mx, '__version__', 'unknown')
    print(f"  MLX: {mlx_version}")
except ImportError:
    print("  MLX: Not installed")

try:
    import onnxruntime
    print(f"  ONNX Runtime: {onnxruntime.__version__}")
    print(f"  Providers: {onnxruntime.get_available_providers()}")
except ImportError:
    print("  ONNX Runtime: Not installed")

print(f"  NumPy: {np.__version__}")
print(f"  OpenCV: {cv2.__version__}")

## Conclusion

This benchmark compares MLX and ONNX Runtime backends for RetinaFace face detection on Apple Silicon.

### Key Findings:

1. **Numerical Parity**: MLX matches ONNX to within numerical tolerance (the accuracy check above requires bounding-box and landmark differences below 1 pixel)
2. **Performance**: Both achieve real-time inference (exact speedup depends on hardware)

### Why MLX on Apple Silicon?

- **Unified Memory**: No CPU-GPU data transfer overhead
- **Native Acceleration**: Runs on the Apple Silicon GPU (via Metal) as well as the CPU
- **Lazy Evaluation**: Arrays are only computed when their values are needed, so unused work is never executed

### When to use ONNX?

- Cross-platform deployment (Linux, Windows)
- NVIDIA GPU acceleration (CUDA)
- Non-Apple Silicon Macs