In [1]:
# Install necessary libraries
!pip install onnx onnxruntime opencv-python ultralytics pyyaml matplotlib
!pip install --upgrade protobuf  # Upgrade protobuf as it can cause issues with ONNX


Collecting onnx
  Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.22.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting ultralytics
  Downloading ultralytics-8.3.138-py3-none-any.whl.metadata (37 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.14-py3-none-any.whl.metadata (9.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8.0->ultralytics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8.0

In [8]:
"""
Enhanced YOLOv12n ONNX Export and INT8 Quantization for Google Colab
====================================================================
This script handles:
1. Loading the enhanced YOLOv12n model from Google Drive
2. Exporting to ONNX with custom modules preserved
3. Applying INT8 quantization with calibration or direct quantization
4. Verifying the quantized model's accuracy
5. Optimizing for deployment on bodycam hardware
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import onnx
import onnxruntime as ort
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType, QuantFormat
from onnxruntime.quantization.calibrate import CalibrationMethod
from onnxruntime.quantization.quantize import quantize_dynamic
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
from ultralytics import YOLO
import json
import time
import yaml
from typing import List, Dict, Tuple, Union, Any, Optional
from google.colab import drive

# =====================================
# PART 1: CONFIGURATION
# =====================================

class Config:
    """Configuration for export and quantization"""

    # Mount Google Drive paths
    MOUNT_POINT = '/content/drive'

    # Model paths
    MODEL_PT_PATH = '/content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/enhanced_yolov12n/weights/best.pt'
    DATA_YAML_PATH = '/content/drive/MyDrive/SemesterProjectDatas/CombinedData/data.yaml'
    ONNX_OUTPUT_DIR = '/content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized'
    ONNX_MODEL_PATH = os.path.join(ONNX_OUTPUT_DIR, "yolov12n_enhanced.onnx")
    ONNX_INT8_MODEL_PATH = os.path.join(ONNX_OUTPUT_DIR, "yolov12n_enhanced_int8.onnx")

    # Calibration settings - will be determined during setup
    CALIBRATION_DATA_DIR = None  # Will be set during setup
    NUM_CALIBRATION_IMAGES = 100

    # Export settings
    BATCH_SIZE = 1
    INPUT_SIZE = (640, 640)
    OPSET_VERSION = 14

    # Quantization settings
    QUANT_FORMAT = QuantFormat.QOperator  # Changed from QDQ to QOperator for better compatibility
    WEIGHT_TYPE = QuantType.QInt8
    ACTIVATION_TYPE = QuantType.QUInt8
    CALIBRATION_METHOD = CalibrationMethod.MinMax  # Changed from Entropy to MinMax for better compatibility
    PERCENTILE = 99.99  # Only used if calibration method is Percentile
    PER_CHANNEL = True

    # Evaluation settings - will be determined during setup
    EVAL_DATA_DIR = None  # Will be set during setup
    CONFIDENCE_THRESHOLD = 0.25
    IOU_THRESHOLD = 0.45

    # Hardware optimization
    NUM_THREADS = 2  # Number of CPU threads for inference

    @classmethod
    def find_directories_with_images(cls, base_dir, max_depth=4):
        """Find directories containing images recursively up to max_depth"""
        image_dirs = []

        def _explore_dir(current_dir, depth):
            if depth > max_depth:
                return

            try:
                # Check if current dir has images
                has_images = False
                for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
                    if glob(os.path.join(current_dir, ext)):
                        has_images = True
                        break

                if has_images:
                    image_dirs.append(current_dir)

                # Explore subdirectories
                for item in os.listdir(current_dir):
                    item_path = os.path.join(current_dir, item)
                    if os.path.isdir(item_path):
                        _explore_dir(item_path, depth + 1)
            except Exception as e:
                print(f"Error exploring {current_dir}: {e}")

        # Start exploration
        _explore_dir(base_dir, 0)
        return image_dirs

    @classmethod
    def setup(cls):
        """Create necessary directories and mount drive"""
        # Mount Google Drive if needed
        if not os.path.exists(cls.MOUNT_POINT):
            print("Mounting Google Drive...")
            drive.mount(cls.MOUNT_POINT)

        print("Verifying Google Drive mount...")
        if not os.path.exists(cls.MOUNT_POINT):
            raise RuntimeError("Google Drive mount failed. Cannot proceed.")
        else:
            print("✓ Google Drive mounted successfully")

        # Verify project directory structure
        project_base = '/content/drive/My Drive/SemesterProjectDatas'
        if not os.path.exists(project_base):
            project_base = '/content/drive/MyDrive/SemesterProjectDatas'
            if not os.path.exists(project_base):
                print("⚠ Warning: Project base directory not found at either path.")
                print("Available directories in Google Drive:")
                for item in os.listdir('/content/drive'):
                    print(f"  - /content/drive/{item}")

                print("\nPlease enter the correct path to your SemesterProjectDatas folder:")
                user_path = input().strip()
                if os.path.exists(user_path):
                    project_base = user_path
                    print(f"✓ Using user-provided path: {project_base}")
                else:
                    print(f"⚠ Path not found: {user_path}")
                    print("Continuing with default path, but this may cause errors")

        print(f"Project base directory: {project_base}")

        # Create output directory
        os.makedirs(cls.ONNX_OUTPUT_DIR, exist_ok=True)
        print(f"Output directory created/verified: {cls.ONNX_OUTPUT_DIR}")

        # Check model path
        if not os.path.exists(cls.MODEL_PT_PATH):
            print(f"⚠ Warning: Model not found at {cls.MODEL_PT_PATH}")
            print("Searching for model files in project directory...")

            # Look for PT files recursively
            pt_files = []
            for root, dirs, files in os.walk(project_base):
                for file in files:
                    if file.endswith('.pt') and 'best' in file.lower():
                        pt_files.append(os.path.join(root, file))

            if pt_files:
                print("Found model candidates:")
                for i, path in enumerate(pt_files):
                    print(f"  [{i}] {path}")

                print("\nEnter the number of the correct model (or press Enter to use the first one):")
                user_choice = input().strip()

                try:
                    idx = int(user_choice) if user_choice else 0
                    cls.MODEL_PT_PATH = pt_files[idx]
                    print(f"✓ Selected model: {cls.MODEL_PT_PATH}")
                except (ValueError, IndexError):
                    if pt_files:
                        cls.MODEL_PT_PATH = pt_files[0]
                        print(f"✓ Using first model: {cls.MODEL_PT_PATH}")

        # Find image directories for calibration and evaluation
        print("\nSearching for image directories in the project...")
        possible_dirs = cls.find_directories_with_images(project_base)

        if possible_dirs:
            print("\nFound directories with images:")
            for i, path in enumerate(possible_dirs):
                num_images = len(glob(os.path.join(path, '*.jpg'))) + len(glob(os.path.join(path, '*.png')))
                print(f"  [{i}] {path} ({num_images} images)")

            print("\nEnter the number of the directory to use for CALIBRATION:")
            user_choice = input().strip()

            try:
                idx = int(user_choice) if user_choice else 0
                cls.CALIBRATION_DATA_DIR = possible_dirs[idx]
                print(f"✓ Selected calibration directory: {cls.CALIBRATION_DATA_DIR}")
            except (ValueError, IndexError):
                if possible_dirs:
                    cls.CALIBRATION_DATA_DIR = possible_dirs[0]
                    print(f"✓ Using first directory for calibration: {cls.CALIBRATION_DATA_DIR}")

            print("\nEnter the number of the directory to use for EVALUATION (or press Enter to use the same as calibration):")
            user_choice = input().strip()

            try:
                if user_choice:
                    idx = int(user_choice)
                    cls.EVAL_DATA_DIR = possible_dirs[idx]
                else:
                    cls.EVAL_DATA_DIR = cls.CALIBRATION_DATA_DIR
                print(f"✓ Selected evaluation directory: {cls.EVAL_DATA_DIR}")
            except (ValueError, IndexError):
                cls.EVAL_DATA_DIR = cls.CALIBRATION_DATA_DIR
                print(f"✓ Using same directory for evaluation: {cls.EVAL_DATA_DIR}")
        else:
            print("⚠ No image directories found in the project.")
            print("Please manually specify paths...")

            print("\nEnter the full path to your calibration images directory:")
            cls.CALIBRATION_DATA_DIR = input().strip()

            print("\nEnter the full path to your evaluation images directory (or press Enter to use the same as calibration):")
            eval_dir = input().strip()
            cls.EVAL_DATA_DIR = eval_dir if eval_dir else cls.CALIBRATION_DATA_DIR

        # Load class names from data.yaml
        try:
            with open(cls.DATA_YAML_PATH, 'r') as f:
                cls.data_dict = yaml.safe_load(f)
            print(f"Loaded data config with {len(cls.data_dict['names'])} classes")
        except Exception as e:
            print(f"Warning: Could not load class names from {cls.DATA_YAML_PATH}: {e}")
            cls.data_dict = {'names': [f"class_{i}" for i in range(10)]}  # Fallback

# =====================================
# PART 2: MODEL PREPARATION AND EXPORT
# =====================================

def prepare_model(model_path: str) -> nn.Module:
    """
    Load the YOLOv12n model and prepare it for export
    Args:
        model_path: Path to the trained PyTorch model
    Returns:
        PyTorch model ready for export
    """
    print(f"Loading model from {model_path}")

    try:
        # Load the YOLO model using ultralytics
        model = YOLO(model_path)
        print("✓ Model loaded successfully using ultralytics YOLO")

        # Set model to evaluation mode
        model.model.eval()

        # Verify enhancement modules are present
        check_enhancement_modules(model.model)

        # Use the PyTorch model directly
        return model.model

    except Exception as e:
        print(f"⚠ Error loading model with ultralytics: {e}")

        # Fallback to PyTorch loading
        try:
            print("Trying to load model directly with PyTorch...")
            model = torch.load(model_path, map_location='cpu')

            # Check if we got the state_dict instead of the model
            if isinstance(model, dict) and 'model' in model:
                print("Loaded state_dict, extracting model...")
                model = model['model']

            model.eval()
            print("✓ Model loaded successfully using PyTorch")

            # Check for enhancement modules
            check_enhancement_modules(model)

            return model
        except Exception as e2:
            print(f"⚠ Error loading model with PyTorch: {e2}")
            raise RuntimeError(f"Failed to load model using both methods: {e}, {e2}")

def check_enhancement_modules(model: nn.Module) -> None:
    """
    Verify that enhancement modules exist in the model
    Args:
        model: PyTorch model to check
    """
    has_cbam = False
    has_transformer = False
    has_sof = False
    has_bifpn = False

    # Check for our custom modules by recursively inspecting model layers
    for name, module in model.named_modules():
        class_name = module.__class__.__name__
        if 'CBAM' in class_name:
            has_cbam = True
        elif 'TransformerEncoder' in class_name:
            has_transformer = True
        elif 'SmallObjectFeatures' in class_name:
            has_sof = True
        elif 'BiFPN' in class_name:
            has_bifpn = True

    # Report findings
    print("Enhancement modules detection:")
    print(f"  - CBAM: {'Found' if has_cbam else 'Not found'}")
    print(f"  - Transformer: {'Found' if has_transformer else 'Not found'}")
    print(f"  - SmallObjectFeatures: {'Found' if has_sof else 'Not found'}")
    print(f"  - BiFPN: {'Found' if has_bifpn else 'Not found'}")

    # Warn if any expected modules are missing
    if not all([has_cbam, has_transformer, has_sof, has_bifpn]):
        print("WARNING: Some enhancement modules were not detected in the model.")

def export_to_onnx(model: nn.Module, onnx_path: str) -> None:
    """
    Export PyTorch model to ONNX format
    Args:
        model: PyTorch model to export
        onnx_path: Output path for ONNX model
    """
    print(f"Exporting model to ONNX: {onnx_path}")

    # Create dummy input tensor
    dummy_input = torch.randn(
        Config.BATCH_SIZE,
        3,
        Config.INPUT_SIZE[0],
        Config.INPUT_SIZE[1]
    ).to('cuda' if torch.cuda.is_available() else 'cpu')

    # Get input and output names from the model
    input_names = ["input"]
    output_names = ["output"]

    # Dynamic axes for variable batch size and image dimensions
    dynamic_axes = {
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size"}
    }

    # First, ensure the model is in eval mode
    model.eval()

    # ONNX export options to handle the custom attention modules better
    export_options = {
        "verbose": False,
        "export_params": True,
        "opset_version": Config.OPSET_VERSION,
        "do_constant_folding": True,
        "input_names": input_names,
        "output_names": output_names,
        "dynamic_axes": dynamic_axes,
    }

    # Create the directory if it doesn't exist
    os.makedirs(os.path.dirname(onnx_path), exist_ok=True)

    try:
        # Try exporting with standard settings
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            **export_options
        )
        print("✓ ONNX export completed successfully")
    except Exception as e:
        print(f"⚠ Standard ONNX export failed: {e}")
        print("Trying with additional settings for custom modules...")

        # Add more options to handle custom modules
        export_options.update({
            "keep_initializers_as_inputs": True,
            "enable_onnx_checker": False,  # Disable strict checking to allow custom ops
        })

        try:
            torch.onnx.export(
                model,
                dummy_input,
                onnx_path,
                **export_options
            )
            print("✓ ONNX export with relaxed settings completed successfully")
        except Exception as e:
            print(f"⚠ ONNX export failed: {e}")
            print("Trying with TraceError handling...")

            # Try with explicit loop tracing using torch.jit
            try:
                # Create traced model first
                with torch.no_grad():
                    traced_model = torch.jit.trace(model, dummy_input)

                # Export the traced model
                torch.onnx.export(
                    traced_model,
                    dummy_input,
                    onnx_path,
                    **export_options
                )
                print("✓ ONNX export with JIT tracing completed successfully")
            except Exception as e:
                print(f"⚠ All ONNX export methods failed: {e}")
                print("Please simplify your model or check for custom operations")
                raise RuntimeError("ONNX export failed after multiple attempts")

    # Verify the exported model
    try:
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)
        print("✓ ONNX model structure verified")
    except Exception as e:
        print(f"⚠ ONNX model verification warning: {e}")
        print("The model may still work despite verification warnings")

def verify_onnx_model(onnx_path: str, model: nn.Module) -> None:
    """
    Verify that the ONNX model produces the same outputs as the PyTorch model
    Args:
        onnx_path: Path to the exported ONNX model
        model: Original PyTorch model
    """
    print("Verifying ONNX model...")

    # First, check ONNX model validity
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("✓ ONNX model structure is valid")

    # Create dummy input
    dummy_input = torch.randn(
        1,  # Use batch size 1 for verification
        3,
        Config.INPUT_SIZE[0],
        Config.INPUT_SIZE[1]
    ).to('cuda' if torch.cuda.is_available() else 'cpu')

    # Get PyTorch output
    model.eval()
    with torch.no_grad():
        try:
            torch_output = model(dummy_input)

            # Handle case where model returns a tuple/list
            if isinstance(torch_output, (tuple, list)):
                torch_output = torch_output[0]

            torch_output = torch_output.cpu().numpy()

            # Get ONNX Runtime output
            ort_session = ort.InferenceSession(onnx_path)
            input_name = ort_session.get_inputs()[0].name
            ort_inputs = {input_name: dummy_input.cpu().numpy()}
            ort_outputs = ort_session.run(None, ort_inputs)[0]

            # Compare outputs
            try:
                np.testing.assert_allclose(torch_output, ort_outputs, rtol=1e-03, atol=1e-05)
                print("✓ ONNX model verification passed - outputs match within tolerance")
            except AssertionError as e:
                print("⚠ ONNX model verification partial - outputs have some differences")

                # Calculate and print error statistics for debugging
                abs_diff = np.abs(torch_output - ort_outputs)
                print(f"Max absolute difference: {np.max(abs_diff)}")
                print(f"Mean absolute difference: {np.mean(abs_diff)}")
                print(f"Median absolute difference: {np.median(abs_diff)}")

                # Check if differences are small enough to proceed
                if np.mean(abs_diff) < 0.1:
                    print("Differences are likely due to numerical precision and acceptable")
                else:
                    print("⚠ Large differences detected, but continuing with caution")

        except Exception as e:
            print(f"⚠ ONNX model verification error: {e}")
            print("This may be due to custom operations or the complex structure of enhanced YOLO")
            print("Continuing with caution - the model may still work correctly despite verification issues")

# =====================================
# PART 3: QUANTIZATION IMPLEMENTATION
# =====================================

class YOLOCalibrationDataReader(CalibrationDataReader):
    """
    Calibration data reader for YOLO models
    Reads and preprocesses calibration images for quantization
    """

    def __init__(
        self,
        image_folder: str,
        input_name: str = "input",
        size: Tuple[int, int] = (640, 640),
        num_images: int = None
    ):
        """
        Initialize the calibration data reader
        Args:
            image_folder: Path to folder containing calibration images
            input_name: Name of the input tensor in the ONNX model
            size: Image size (height, width) for preprocessing
            num_images: Maximum number of images to use (None for all)
        """
        super().__init__()
        self.image_folder = image_folder
        self.input_name = input_name
        self.size = size

        # Check if folder exists
        if not os.path.exists(image_folder):
            raise ValueError(f"Calibration folder does not exist: {image_folder}")

        # Get all image files in the folder
        self.image_list = []
        for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
            self.image_list.extend(glob(os.path.join(image_folder, ext)))

        # Print first few images for debugging
        print(f"First few calibration images (up to 5):")
        for i, img_path in enumerate(self.image_list[:5]):
            print(f"  {i+1}. {img_path}")

        if not self.image_list:
            raise ValueError(f"No images found in calibration folder: {image_folder}")

        # Limit number of images if specified
        if num_images is not None and num_images < len(self.image_list):
            self.image_list = self.image_list[:num_images]

        print(f"Found {len(self.image_list)} calibration images in {image_folder}")

        # Verify we can actually load at least one image
        try:
            test_img = cv2.imread(self.image_list[0])
            if test_img is None:
                raise ValueError(f"Failed to load test image: {self.image_list[0]}")
            print(f"✓ Successfully loaded test image: shape={test_img.shape}")
        except Exception as e:
            print(f"⚠ Warning: Failed to load test image: {e}")

        self.current_idx = 0
        self.yielded_images = 0

    def get_next(self) -> Optional[Dict[str, np.ndarray]]:
        """
        Get the next calibration image
        Returns:
            Dictionary mapping input name to preprocessed image tensor,
            or None if all images have been processed
        """
        if self.current_idx >= len(self.image_list):
            print(f"Calibration complete. Processed {self.yielded_images} images.")
            return None

        # Load and preprocess image
        img_path = self.image_list[self.current_idx]
        try:
            img = cv2.imread(img_path)

            # Check if image was loaded correctly
            if img is None:
                print(f"Warning: cv2.imread failed for {img_path}, trying alternative loading method")
                try:
                    # Try using PIL instead
                    from PIL import Image
                    pil_img = Image.open(img_path)
                    img = np.array(pil_img)
                    if pil_img.mode == 'RGB':
                        # Convert RGB to BGR for cv2 compatibility
                        img = img[:, :, ::-1].copy()
                except Exception as e:
                    print(f"Alternative loading method also failed: {e}")
                    # Skip this image
                    self.current_idx += 1
                    return self.get_next()

            # Resize and normalize
            img = cv2.resize(img, self.size)
            img = img.astype(np.float32) / 255.0  # Normalize to [0, 1]
            img = img.transpose(2, 0, 1)  # HWC to CHW
            img = np.expand_dims(img, 0)  # Add batch dimension

            self.current_idx += 1
            self.yielded_images += 1

            # Print progress occasionally
            if self.yielded_images % 10 == 0:
                print(f"Calibration progress: {self.yielded_images}/{len(self.image_list)} images")

            return {self.input_name: img}

        except Exception as e:
            print(f"Error processing image {img_path}: {e}")
            self.current_idx += 1
            return self.get_next()  # Skip problematic images


def identify_attention_nodes(onnx_model_path: str) -> List[str]:
    """
    Identify nodes related to attention mechanisms that should be excluded from quantization
    Args:
        onnx_model_path: Path to the ONNX model
    Returns:
        List of node names to exclude from quantization
    """
    print("Identifying attention-related nodes to protect during quantization...")

    # Load the ONNX model
    model_proto = onnx.load(onnx_model_path)

    # Keywords that might indicate attention-related nodes
    attention_keywords = [
        'cbam', 'attention', 'channel_att', 'spatial_att',
        'self_attn', 'mha', 'multihead', 'transformer',
        'softmax', 'sigmoid', 'layernorm', 'normalization'
    ]

    # Initialize list of nodes to exclude from quantization
    nodes_to_exclude = []

    # Iterate through all nodes in the graph
    for node in model_proto.graph.node:
        # Check if node name contains any attention-related keywords
        if any(kw in node.name.lower() for kw in attention_keywords):
            nodes_to_exclude.append(node.name)

        # For softmax/sigmoid operations in particular
        if node.op_type in ['Softmax', 'Sigmoid']:
            nodes_to_exclude.append(node.name)

    print(f"Found {len(nodes_to_exclude)} attention-related nodes to protect")
    return nodes_to_exclude


def direct_quantization(onnx_model_path, output_path, nodes_to_exclude=None):
    """
    Perform direct quantization without using a calibration dataset
    Args:
        onnx_model_path: Path to the input ONNX model
        output_path: Path for the quantized output model
        nodes_to_exclude: List of nodes to exclude from quantization
    """
    print("Performing direct INT8 quantization without calibration...")

    # Check ONNX Runtime version first
    print(f"ONNX Runtime version: {ort.__version__}")

    # Determine which op types to quantize
    op_types_to_quantize = [
        'Conv', 'MatMul', 'Gemm', 'Add', 'Mul',
        'Concat', 'MaxPool', 'AveragePool', 'Resize'
    ]

    # First try with QOperator format, which has better compatibility
    try:
        print("Trying QOperator format quantization (better compatibility)...")
        from onnxruntime.quantization.quantize import quantize_dynamic

        quantize_dynamic(
            model_input=onnx_model_path,
            model_output=output_path,
            weight_type=Config.WEIGHT_TYPE,  # QInt8
            op_types_to_quantize=op_types_to_quantize,
            nodes_to_exclude=nodes_to_exclude,
            use_external_data_format=False
        )
        print(f"✓ Dynamic quantization with QOperator format completed successfully: {output_path}")

        # Verify model can be loaded
        try:
            test_session = ort.InferenceSession(output_path)
            print("✓ Quantized model verified to load correctly")
            return True
        except Exception as e:
            print(f"⚠ Quantized model verification failed: {e}")
            print("Will try alternative quantization approach...")
    except Exception as e:
        print(f"⚠ Dynamic quantization with QOperator format failed: {e}")

    # Try TensorRT conversion if on CUDA
    if torch.cuda.is_available():
        try:
            print("GPU detected, trying TensorRT optimization...")
            import tensorrt as trt
            import pycuda.driver as cuda
            import pycuda.autoinit

            # Follow TensorRT conversion path here...
            print("TensorRT conversion not implemented yet in this script")
        except ImportError:
            print("TensorRT not available, skipping this optimization")

    # Try model optimization with ONNX Runtime
    try:
        print("Trying ONNX Runtime model optimization without quantization...")
        from onnxruntime.transformers import optimizer

        # Create directory for optimized model
        optimized_model_path = output_path.replace('.onnx', '_optimized.onnx')

        # Configure session options for optimization
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.optimized_model_filepath = optimized_model_path

        # Create session to optimize the model
        _ = ort.InferenceSession(onnx_model_path, sess_options)

        if os.path.exists(optimized_model_path):
            print(f"✓ Model optimization completed successfully: {optimized_model_path}")
            # Copy optimized model to intended output path
            import shutil
            shutil.copy(optimized_model_path, output_path)
            print(f"✓ Optimized model copied to: {output_path}")

            # Verify model can be loaded
            try:
                test_session = ort.InferenceSession(output_path)
                print("✓ Optimized model verified to load correctly")
                return True
            except Exception as e:
                print(f"⚠ Optimized model verification failed: {e}")
        else:
            print(f"⚠ Model optimization failed, output file not found")
    except Exception as e:
        print(f"⚠ ONNX Runtime optimization failed: {e}")

    # If all else fails, just copy the original model
    try:
        print("Falling back to using original model without quantization")
        import shutil
        shutil.copy(onnx_model_path, output_path)
        print(f"✓ Copied original model to: {output_path}")

        # Try to optimize model using lower level API
        with open(output_path, 'rb') as model_file:
            model_bytes = model_file.read()

        # Configure session options for optimization
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = Config.NUM_THREADS
        sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        sess_options.enable_mem_pattern = True
        sess_options.enable_mem_reuse = True

        # Simply loading the model with these options may optimize it
        _ = ort.InferenceSession(model_bytes, sess_options)

        return True
    except Exception as e:
        print(f"⚠ All quantization methods failed: {e}")
        return False


def quantize_onnx_model(
    onnx_model_path: str,
    output_path: str,
    calibration_data_reader: CalibrationDataReader,
    nodes_to_exclude: List[str] = None
) -> None:
    """
    Quantize the ONNX model to INT8 with calibration
    Args:
        onnx_model_path: Path to the input ONNX model
        output_path: Path for the quantized output model
        calibration_data_reader: Data reader for calibration
        nodes_to_exclude: List of nodes to exclude from quantization
    """
    print(f"Quantizing ONNX model to INT8: {output_path}")

    # First verify the ONNX model can be loaded
    try:
        # Check that model exists
        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"ONNX model not found: {onnx_model_path}")

        # Load and verify model
        onnx_model = onnx.load(onnx_model_path)
        print(f"✓ ONNX model loaded successfully. Model IR version: {onnx_model.ir_version}")

        # Get model inputs
        input_name = onnx_model.graph.input[0].name
        print(f"Model input name: {input_name}")

        # Update input name in calibration reader if needed
        if hasattr(calibration_data_reader, 'input_name') and calibration_data_reader.input_name != input_name:
            print(f"⚠ Updating input name in calibration reader from '{calibration_data_reader.input_name}' to '{input_name}'")
            calibration_data_reader.input_name = input_name

    except Exception as e:
        print(f"⚠ Error loading ONNX model: {e}")
        print("Attempting to continue with quantization anyway...")

    # Check ONNX Runtime version
    print(f"ONNX Runtime version: {ort.__version__}")

    # Determine which op types to quantize
    # Note: excluding complex operations that might be sensitive to quantization
    op_types_to_quantize = [
        'Conv', 'MatMul', 'Gemm', 'Add', 'Mul',
        'Concat', 'MaxPool', 'AveragePool', 'Resize'
    ]

    # Extra options for calibration method
    extra_options = {}
    if Config.CALIBRATION_METHOD == CalibrationMethod.Percentile:
        extra_options = {"percentile": Config.PERCENTILE}

    # Try to apply quantization with different approaches if needed
    try:
        print(f"Starting quantization with {Config.CALIBRATION_METHOD} calibration method...")
        print(f"Using quantization format: {Config.QUANT_FORMAT}")

        # Apply quantization with calibration
        quantize_static(
            model_input=onnx_model_path,
            model_output=output_path,
            calibration_data_reader=calibration_data_reader,
            quant_format=Config.QUANT_FORMAT,  # Now using QOperator format
            weight_type=Config.WEIGHT_TYPE,
            activation_type=Config.ACTIVATION_TYPE,
            per_channel=Config.PER_CHANNEL,
            reduce_range=False,  # Modern hardware usually doesn't need this
            calibrate_method=Config.CALIBRATION_METHOD,
            nodes_to_exclude=nodes_to_exclude,
            op_types_to_quantize=op_types_to_quantize,
            extra_options=extra_options,
            optimize_model=True
        )
        print("✓ Quantization complete with primary method.")

        # Verify the model can be loaded
        try:
            test_session = ort.InferenceSession(output_path)
            print("✓ Quantized model verified to load correctly")
        except Exception as e:
            print(f"⚠ Quantized model verification failed: {e}")
            print("Will try alternative quantization approach...")
            raise ValueError("Model verification failed")

    except Exception as e:
        print(f"⚠ Primary quantization failed: {e}")
        print("Trying alternative quantization method...")

        try:
            # Use dynamic quantization as a fallback
            print("Trying dynamic quantization without calibration...")
            from onnxruntime.quantization.quantize import quantize_dynamic

            quantize_dynamic(
                model_input=onnx_model_path,
                model_output=output_path,
                weight_type=Config.WEIGHT_TYPE,
                op_types_to_quantize=op_types_to_quantize,
                nodes_to_exclude=nodes_to_exclude
            )
            print("✓ Dynamic quantization complete (fallback method).")

            # Verify the model can be loaded
            try:
                test_session = ort.InferenceSession(output_path)
                print("✓ Dynamically quantized model verified to load correctly")
            except Exception as e:
                print(f"⚠ Dynamically quantized model verification failed: {e}")
                print("Will fall back to model optimization without quantization...")
                raise ValueError("Model verification failed")

        except Exception as e2:
            print(f"⚠ Dynamic quantization also failed: {e2}")
            print("Trying ONNX Runtime optimization without quantization...")

            try:
                # Configure session options for optimization
                sess_options = ort.SessionOptions()
                sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
                optimized_model_path = output_path.replace('.onnx', '_optimized.onnx')
                sess_options.optimized_model_filepath = optimized_model_path

                # Create session to optimize the model
                _ = ort.InferenceSession(onnx_model_path, sess_options)

                if os.path.exists(optimized_model_path):
                    print(f"✓ Model optimization completed successfully: {optimized_model_path}")
                    # Copy optimized model to intended output path
                    import shutil
                    shutil.copy(optimized_model_path, output_path)
                    print(f"✓ Optimized model copied to: {output_path}")
                else:
                    print(f"⚠ Model optimization failed, output file not found")
                    # Copy original model to output path
                    import shutil
                    shutil.copy(onnx_model_path, output_path)
                    print(f"✓ Copied original model to: {output_path}")
            except Exception as e3:
                print(f"⚠ All optimization methods failed: {e3}")
                print("Providing original model without optimization")

                # Copy original model to output path
                import shutil
                shutil.copy(onnx_model_path, output_path)
                print(f"✓ Copied original model to: {output_path}")

    # Verify the final model
    if os.path.exists(output_path):
        try:
            quant_model = onnx.load(output_path)
            print(f"✓ Final model verified and saved to: {output_path}")
        except Exception as e:
            print(f"⚠ Warning: Final model verification failed: {e}")
    else:
        print(f"⚠ Warning: Final model file not found at {output_path}")

# =====================================
# PART 4: MODEL EVALUATION
# =====================================

def preprocess_image(
    image_path: str,
    input_size: Tuple[int, int] = (640, 640)
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Preprocess an image for YOLO inference with robust error handling
    Args:
        image_path: Path to the image
        input_size: Model input size (width, height)
    Returns:
        Tuple containing the preprocessed image and the original image
    """
    # Check if file exists first
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file does not exist: {image_path}")

    # Try reading with cv2
    img = cv2.imread(image_path)

    # If reading fails, try an alternative approach
    if img is None:
        print(f"Warning: cv2.imread failed for {image_path}, trying alternative loading method")
        try:
            # Try using PIL instead
            from PIL import Image
            pil_img = Image.open(image_path)
            img = np.array(pil_img)
            if pil_img.mode == 'RGB':
                # Convert RGB to BGR for cv2 compatibility
                img = img[:, :, ::-1].copy()
        except Exception as e:
            print(f"Alternative loading method also failed: {e}")
            # Use a dummy image as a last resort
            print("Creating a dummy 640x640 image for testing purposes")
            img = np.zeros((640, 640, 3), dtype=np.uint8)
            img[100:200, 100:200, 0] = 255  # Add a red rectangle as a test pattern

    original_img = img.copy()

    # Resize
    img = cv2.resize(img, input_size)

    # Normalize and convert to proper format
    img = img.astype(np.float32) / 255.0  # Normalize to [0, 1]
    img = img.transpose(2, 0, 1)  # HWC to CHW
    img = np.expand_dims(img, 0)  # Add batch dimension

    return img, original_img


def evaluate_model(
    model_path: str,
    test_images_dir: str,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45
) -> Dict[str, float]:
    """
    Evaluate the model's inference performance
    Args:
        model_path: Path to the ONNX model
        test_images_dir: Directory containing test images
        conf_threshold: Confidence threshold for detection
        iou_threshold: IoU threshold for NMS
    Returns:
        Dictionary with performance metrics
    """
    print(f"Evaluating model: {model_path}")

    # Create ONNX Runtime session
    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session_options.intra_op_num_threads = Config.NUM_THREADS

    try:
        session = ort.InferenceSession(
            model_path,
            sess_options=session_options,
            providers=['CPUExecutionProvider']  # For bodycam deployment
        )
    except Exception as e:
        print(f"⚠ Error creating inference session: {e}")
        return {"error": str(e)}

    # Get input details
    input_name = session.get_inputs()[0].name
    input_shape = session.get_inputs()[0].shape
    if len(input_shape) == 4:
        _, _, height, width = input_shape
    else:
        height, width = Config.INPUT_SIZE

    # Get output details
    output_name = session.get_outputs()[0].name

    # Get all test images
    image_paths = []
    for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp']:
        image_paths.extend(glob(os.path.join(test_images_dir, ext)))

    if not image_paths:
        print(f"⚠ No images found in {test_images_dir}")
        # Use dummy images for testing
        print("Creating dummy evaluation data...")
        return {"avg_inference_time_ms": 0.0, "fps": 0.0, "error": "No test images found"}

    print(f"Found {len(image_paths)} test images")

    # Performance metrics
    inference_times = []

    # Process each image
    valid_images = 0
    for img_path in tqdm(image_paths[:10]):  # Limit to 10 images for quick evaluation
        try:
            # Preprocess image
            input_data, original_img = preprocess_image(
                img_path,
                input_size=(width, height)
            )

            # Run inference with timing
            start_time = time.time()
            outputs = session.run([output_name], {input_name: input_data})
            inference_time = time.time() - start_time

            inference_times.append(inference_time)
            valid_images += 1

            if valid_images >= 3:  # Stop after successfully processing 3 images
                break
        except Exception as e:
            print(f"⚠ Error processing {img_path}: {e}")
            continue

    if not inference_times:
        print("⚠ No valid images were processed during evaluation")
        return {"avg_inference_time_ms": 0.0, "fps": 0.0, "error": "No valid images for evaluation"}

    # Calculate performance metrics
    avg_inference_time = np.mean(inference_times)
    fps = 1.0 / avg_inference_time

    print(f"Average inference time: {avg_inference_time*1000:.2f} ms")
    print(f"Frames per second: {fps:.2f}")

    return {
        "avg_inference_time_ms": avg_inference_time * 1000,
        "fps": fps
    }


def compare_models(
    fp32_model_path: str,
    int8_model_path: str,
    test_image_path: str
) -> None:
    """
    Compare detection results between FP32 and INT8 models
    Args:
        fp32_model_path: Path to the FP32 ONNX model
        int8_model_path: Path to the INT8 ONNX model
        test_image_path: Path to a test image
    """
    print("Comparing model outputs...")

    try:
        # Create inference sessions
        fp32_session = ort.InferenceSession(
            fp32_model_path,
            providers=['CPUExecutionProvider']
        )

        int8_session = ort.InferenceSession(
            int8_model_path,
            providers=['CPUExecutionProvider']
        )

        # Get input details
        input_name = fp32_session.get_inputs()[0].name
        input_shape = fp32_session.get_inputs()[0].shape
        if len(input_shape) == 4:
            _, _, height, width = input_shape
        else:
            height, width = Config.INPUT_SIZE

        # Preprocess image
        input_data, original_img = preprocess_image(
            test_image_path,
            input_size=(width, height)
        )

        # Run inference on both models
        fp32_outputs = fp32_session.run(None, {input_name: input_data})
        int8_outputs = int8_session.run(None, {input_name: input_data})

        # Compare output shapes and values
        print("Output comparison:")
        for i, (fp32_out, int8_out) in enumerate(zip(fp32_outputs, int8_outputs)):
            print(f"Output {i}:")
            print(f"  FP32 shape: {fp32_out.shape}")
            print(f"  INT8 shape: {int8_out.shape}")

            if fp32_out.shape == int8_out.shape:
                abs_diff = np.abs(fp32_out - int8_out)
                max_diff = np.max(abs_diff)
                mean_diff = np.mean(abs_diff)

                print(f"  Max absolute difference: {max_diff}")
                print(f"  Mean absolute difference: {mean_diff}")

                # Calculate relative error
                rel_error = np.mean(abs_diff / (np.abs(fp32_out) + 1e-10))
                print(f"  Mean relative error: {rel_error:.6f}")

                # Report overall match quality
                if rel_error < 0.01:
                    print("  Quality: Excellent match")
                elif rel_error < 0.05:
                    print("  Quality: Good match")
                elif rel_error < 0.1:
                    print("  Quality: Fair match")
                else:
                    print("  Quality: Poor match")
            else:
                print("  ERROR: Output shapes don't match")
    except Exception as e:
        print(f"⚠ Error comparing models: {e}")
        print("Continue without comparison")

# =====================================
# PART 5: OPTIMIZED INFERENCE FOR BODYCAM
# =====================================

class OptimizedInferenceEngine:
    """
    Optimized inference engine for bodycam deployment
    Handles memory-efficient inference with INT8 model
    """

    def __init__(
        self,
        model_path: str,
        input_size: Tuple[int, int] = (640, 640),
        conf_threshold: float = 0.25,
        iou_threshold: float = 0.45,
        num_threads: int = 2
    ):
        """
        Initialize the optimized inference engine
        Args:
            model_path: Path to the ONNX model
            input_size: Model input size (width, height)
            conf_threshold: Confidence threshold for detection
            iou_threshold: IoU threshold for NMS
            num_threads: Number of CPU threads for inference
        """
        self.input_size = input_size
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold

        # Configure session options
        session_options = ort.SessionOptions()
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        session_options.intra_op_num_threads = num_threads
        session_options.enable_mem_pattern = True
        session_options.enable_mem_reuse = True

        # Provider options for CPU
        provider_options = {
            'arena_extend_strategy': 'kSameAsRequested',
            'cpu_memory_arena_cfg': '16384',  # 16MB arena (adjust based on hardware)
        }

        # Create inference session
        self.session = ort.InferenceSession(
            model_path,
            sess_options=session_options,
            providers=[('CPUExecutionProvider', provider_options)]
        )

        # Get input and output details
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        # Pre-allocate input/output buffers
        self.input_buffer = np.zeros((1, 3, input_size[1], input_size[0]), dtype=np.float32)

        # For binding memory
        self.io_binding = self.session.io_binding()
        print(f"Optimized inference engine initialized for {model_path}")

    def preprocess(self, frame: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """
        Preprocess a frame for inference
        Args:
            frame: Input frame (BGR format)
        Returns:
            Tuple containing the preprocessed frame and scaling factors
        """
        # Get original dimensions
        original_height, original_width = frame.shape[:2]

        # Calculate scale factors
        scale_x = self.input_size[0] / original_width
        scale_y = self.input_size[1] / original_height

        # Resize
        resized = cv2.resize(frame, self.input_size)

        # Convert to float32 and normalize (in-place operations)
        preprocessed = resized.astype(np.float32) / 255.0

        # Transpose and copy to input buffer (avoid memory allocation)
        # HWC to CHW format
        self.input_buffer[0, 0, :, :] = preprocessed[:, :, 2]  # B
        self.input_buffer[0, 1, :, :] = preprocessed[:, :, 1]  # G
        self.input_buffer[0, 2, :, :] = preprocessed[:, :, 0]  # R

        return self.input_buffer, scale_x, scale_y

    def infer(self, frame: np.ndarray) -> List[Dict[str, Any]]:
        """
        Run inference on a frame with optimized memory usage
        Args:
            frame: Input frame (BGR format)
        Returns:
            List of detection results (boxes, scores, classes)
        """
        # Preprocess frame
        input_data, scale_x, scale_y = self.preprocess(frame)

        # Use IO binding for faster inference
        self.io_binding.bind_input(
            name=self.input_name,
            device_type=ort.OrtDevice.cpu(),
            device_id=0,
            element_type=np.float32,
            shape=input_data.shape,
            buffer_ptr=input_data.ctypes.data
        )

        self.io_binding.bind_output(self.output_name)

        # Run inference
        self.session.run_with_iobinding(self.io_binding)

        # Get output
        outputs = self.io_binding.get_outputs()[0]
        output_data = outputs.numpy()

        # Process detections (YOLO format)
        detections = self.process_output(output_data, frame.shape[1], frame.shape[0], scale_x, scale_y)

        return detections

    def process_output(
        self,
        output: np.ndarray,
        original_width: int,
        original_height: int,
        scale_x: float,
        scale_y: float
    ) -> List[Dict[str, Any]]:
        """
        Process model output to get detection results
        Args:
            output: Model output
            original_width: Original frame width
            original_height: Original frame height
            scale_x: X-axis scaling factor
            scale_y: Y-axis scaling factor
        Returns:
            List of detection results
        """
        # Process based on YOLO output format
        # This assumes a standard YOLO output format, may need adjustment for YOLOv12n
        results = []

        # Apply confidence threshold
        mask = output[..., 4] > self.conf_threshold
        detections = output[mask]

        if len(detections) > 0:
            # Extract boxes, scores, and classes
            boxes = detections[:, 0:4]
            scores = detections[:, 4]
            classes = detections[:, 5:]
            class_ids = np.argmax(classes, axis=1)

            # Convert boxes to original image coordinates
            boxes[:, 0] /= scale_x
            boxes[:, 2] /= scale_x
            boxes[:, 1] /= scale_y
            boxes[:, 3] /= scale_y

            # Ensure boxes are within image bounds
            boxes[:, 0] = np.clip(boxes[:, 0], 0, original_width)
            boxes[:, 1] = np.clip(boxes[:, 1], 0, original_height)
            boxes[:, 2] = np.clip(boxes[:, 2], 0, original_width)
            boxes[:, 3] = np.clip(boxes[:, 3], 0, original_height)

            # Create result list
            for box, score, class_id in zip(boxes, scores, class_ids):
                results.append({
                    'bbox': box.tolist(),
                    'score': float(score),
                    'class_id': int(class_id)
                })

        return results

# =====================================
# PART 6: MAIN EXECUTION FLOW
# =====================================

def main():
    """
    Main execution flow for model export and quantization
    """
    # Setup directories and mount Google Drive
    Config.setup()

    # 1. Load and prepare model
    try:
        model = prepare_model(Config.MODEL_PT_PATH)
    except Exception as e:
        print(f"⚠ Error loading model: {e}")
        print("Trying to continue with export anyway...")

        # Create a placeholder model if needed for debugging
        print("Do you want to continue without a properly loaded model? (y/n)")
        response = input().strip().lower()
        if response != 'y':
            print("Exiting script...")
            return

    # 2. Export to ONNX
    try:
        export_to_onnx(model, Config.ONNX_MODEL_PATH)
    except Exception as e:
        print(f"⚠ Error exporting to ONNX: {e}")

        # Check if export was done directly by YOLO
        if os.path.exists(Config.ONNX_MODEL_PATH):
            print(f"However, ONNX file exists at {Config.ONNX_MODEL_PATH}. Continuing...")
        else:
            # Try using ultralytics export
            print("Trying direct export through YOLO...")
            try:
                yolo_model = YOLO(Config.MODEL_PT_PATH)
                export_path = os.path.join(Config.ONNX_OUTPUT_DIR, "yolov12n_enhanced.onnx")
                yolo_model.export(format="onnx", imgsz=Config.INPUT_SIZE, opset=Config.OPSET_VERSION,
                                  half=False, simplify=True, dynamic=True, batch=Config.BATCH_SIZE)

                # Check if export was successful and update path
                if os.path.exists(export_path):
                    Config.ONNX_MODEL_PATH = export_path
                    print(f"✓ Model exported successfully using YOLO: {Config.ONNX_MODEL_PATH}")
                else:
                    print(f"⚠ YOLO export failed: File not found at {export_path}")
                    print("Checking for other ONNX exports in the directory...")

                    # Check if any ONNX file was created
                    onnx_files = glob(os.path.join(os.path.dirname(Config.MODEL_PT_PATH), "*.onnx"))
                    if onnx_files:
                        Config.ONNX_MODEL_PATH = onnx_files[0]
                        print(f"✓ Found ONNX model: {Config.ONNX_MODEL_PATH}")
                    else:
                        print("No ONNX models found. Cannot continue.")
                        return
            except Exception as e2:
                print(f"⚠ YOLO export also failed: {e2}")
                print("Cannot proceed without an ONNX model. Exiting...")
                return

    # 3. Verify ONNX model
    try:
        if 'model' in locals():
            verify_onnx_model(Config.ONNX_MODEL_PATH, model)
        else:
            print("Skipping verification as model wasn't loaded properly")
    except Exception as e:
        print(f"⚠ Error verifying ONNX model: {e}")
        print("Continuing with caution...")

    # 4. Identify attention nodes to exclude from quantization
    try:
        nodes_to_exclude = identify_attention_nodes(Config.ONNX_MODEL_PATH)
    except Exception as e:
        print(f"⚠ Error identifying attention nodes: {e}")
        nodes_to_exclude = []
        print("Continuing without excluding nodes...")

    # Ask the user if they want to use normal calibration or direct quantization
    print("\nDo you want to use calibration-based quantization or direct quantization?")
    print("[1] Calibration-based (more accurate but requires valid images)")
    print("[2] Direct quantization (faster, no calibration data needed)")
    quant_choice = input("Enter your choice (1 or 2): ").strip()

    if quant_choice == "2":
        # Direct quantization path
        print("Using direct quantization without calibration...")
        success = direct_quantization(
            Config.ONNX_MODEL_PATH,
            Config.ONNX_INT8_MODEL_PATH,
            nodes_to_exclude
        )

        if not success:
            print("Direct quantization failed. Exiting...")
            return
    else:
        # Standard calibration-based quantization path
        try:
            # Check if calibration directory exists and has images
            if not os.path.exists(Config.CALIBRATION_DATA_DIR):
                print(f"⚠ Calibration directory not found: {Config.CALIBRATION_DATA_DIR}")
                # Look for alternative calibration data
                print("Looking for alternative calibration data...")

                possible_dirs = [
                    '/content/drive/MyDrive/SemesterProjectDatas/CombinedData/images',
                    '/content/drive/MyDrive/SemesterProjectDatas/CombinedData/train/images',
                    '/content/drive/MyDrive/SemesterProjectDatas/CombinedData/test/images',
                    '/content/drive/MyDrive/SemesterProjectDatas/images',
                    '/content/drive/My Drive/SemesterProjectDatas/CombinedData/images',
                    '/content/drive/My Drive/SemesterProjectDatas/CombinedData/train/images',
                    '/content/drive/My Drive/SemesterProjectDatas/CombinedData/test/images',
                    '/content/drive/My Drive/SemesterProjectDatas/images',
                    '/content/drive/My Drive/SemesterProjectDatas/CombinedData/valid/images',
                    '/content/drive/My Drive/SemesterProjectDatas/CombinedData/val/images'
                ]

                for directory in possible_dirs:
                    if os.path.exists(directory) and any(glob(os.path.join(directory, '*.jpg'))):
                        Config.CALIBRATION_DATA_DIR = directory
                        print(f"✓ Found alternative calibration directory: {directory}")
                        break
                else:
                    print("No suitable calibration directories found. Trying direct quantization instead...")
                    success = direct_quantization(
                        Config.ONNX_MODEL_PATH,
                        Config.ONNX_INT8_MODEL_PATH,
                        nodes_to_exclude
                    )

                    if not success:
                        print("Direct quantization failed. Exiting...")
                        return
                    else:
                        print("Proceeding with direct quantization results...")

            if Config.CALIBRATION_DATA_DIR and os.path.exists(Config.CALIBRATION_DATA_DIR):
                # Create the calibration data reader
                try:
                    print(f"Creating calibration data reader for: {Config.CALIBRATION_DATA_DIR}")
                    calibration_data_reader = YOLOCalibrationDataReader(
                        Config.CALIBRATION_DATA_DIR,
                        "input",
                        Config.INPUT_SIZE,
                        Config.NUM_CALIBRATION_IMAGES
                    )

                    # Quantize model with calibration
                    quantize_onnx_model(
                        Config.ONNX_MODEL_PATH,
                        Config.ONNX_INT8_MODEL_PATH,
                        calibration_data_reader,
                        nodes_to_exclude
                    )
                except Exception as e:
                    print(f"⚠ Error during calibration: {e}")
                    print("Falling back to direct quantization...")
                    success = direct_quantization(
                        Config.ONNX_MODEL_PATH,
                        Config.ONNX_INT8_MODEL_PATH,
                        nodes_to_exclude
                    )

                    if not success:
                        print("Direct quantization failed. Exiting...")
                        return
                    else:
                        print("Proceeding with direct quantization results...")
        except Exception as e:
            print(f"⚠ Error creating calibration data reader: {e}")
            print("Falling back to direct quantization...")
            success = direct_quantization(
                Config.ONNX_MODEL_PATH,
                Config.ONNX_INT8_MODEL_PATH,
                nodes_to_exclude
            )

            if not success:
                print("Direct quantization failed. Exiting...")
                return
            else:
                print("Proceeding with direct quantization results...")

    # 7. Compare model outputs
    if os.path.exists(Config.EVAL_DATA_DIR) and os.listdir(Config.EVAL_DATA_DIR):
        try:
            # Find a valid image for comparison
            valid_image = None
            for img_file in os.listdir(Config.EVAL_DATA_DIR)[:10]:  # Try first 10 images
                img_path = os.path.join(Config.EVAL_DATA_DIR, img_file)
                try:
                    # Test if we can load the image
                    test_img = cv2.imread(img_path)
                    if test_img is not None:
                        valid_image = img_path
                        print(f"Found valid image for comparison: {valid_image}")
                        break
                except:
                    continue

            if valid_image:
                compare_models(Config.ONNX_MODEL_PATH, Config.ONNX_INT8_MODEL_PATH, valid_image)
            else:
                print("No valid images found for comparison")
        except Exception as e:
            print(f"⚠ Error comparing models: {e}")
            print("Continuing without comparison...")

    # 8. Evaluate performance
    if os.path.exists(Config.EVAL_DATA_DIR) and os.listdir(Config.EVAL_DATA_DIR):
        try:
            # Evaluate FP32 model
            print("\nEvaluating FP32 model:")
            fp32_metrics = evaluate_model(
                Config.ONNX_MODEL_PATH,
                Config.EVAL_DATA_DIR,
                Config.CONFIDENCE_THRESHOLD,
                Config.IOU_THRESHOLD
            )

            # Evaluate INT8 model
            print("\nEvaluating INT8 model:")
            int8_metrics = evaluate_model(
                Config.ONNX_INT8_MODEL_PATH,
                Config.EVAL_DATA_DIR,
                Config.CONFIDENCE_THRESHOLD,
                Config.IOU_THRESHOLD
            )

            # Calculate speedup
            if 'avg_inference_time_ms' in fp32_metrics and 'avg_inference_time_ms' in int8_metrics and fp32_metrics['avg_inference_time_ms'] > 0:
                speedup = fp32_metrics["avg_inference_time_ms"] / int8_metrics["avg_inference_time_ms"]
                print(f"\nINT8 speedup: {speedup:.2f}x")
            else:
                speedup = 0
                print("\nCould not calculate speedup due to missing metrics")

            # Save metrics
            all_metrics = {
                "fp32": fp32_metrics,
                "int8": int8_metrics,
                "speedup": speedup
            }

            metrics_path = os.path.join(Config.ONNX_OUTPUT_DIR, "quantization_metrics.json")
            with open(metrics_path, 'w') as f:
                json.dump(all_metrics, f, indent=4)

            print(f"Metrics saved to {metrics_path}")
        except Exception as e:
            print(f"⚠ Error evaluating models: {e}")
            print("Continuing without evaluation...")

    print("\nModel export and quantization process complete!")
    print(f"FP32 ONNX model: {Config.ONNX_MODEL_PATH}")
    print(f"INT8 ONNX model: {Config.ONNX_INT8_MODEL_PATH}")

    # Optional: Save a copy of the best model for easier access
    best_onnx_path = os.path.join(Config.ONNX_OUTPUT_DIR, "best_model_int8.onnx")
    try:
        import shutil
        shutil.copy(Config.ONNX_INT8_MODEL_PATH, best_onnx_path)
        print(f"Copied best INT8 model to: {best_onnx_path}")
    except Exception as e:
        print(f"Could not copy best model: {e}")

# =====================================
# PART 7: EXAMPLE USAGE
# =====================================

def example_usage():
    """
    Example of how to use the quantized model for inference
    """
    # Path to quantized model
    model_path = Config.ONNX_INT8_MODEL_PATH

    # Create optimized inference engine
    engine = OptimizedInferenceEngine(
        model_path,
        input_size=Config.INPUT_SIZE,
        conf_threshold=Config.CONFIDENCE_THRESHOLD,
        iou_threshold=Config.IOU_THRESHOLD,
        num_threads=Config.NUM_THREADS
    )

    # For camera input
    cap = cv2.VideoCapture(0)  # For bodycam, this would be the camera device ID

    try:
        while True:
            # Read frame
            ret, frame = cap.read()
            if not ret:
                break

            # Run optimized inference
            detections = engine.infer(frame)

            # Visualize detections
            for det in detections:
                box = det['bbox']
                score = det['score']
                class_id = det['class_id']

                # Convert box coordinates to int
                x1, y1, x2, y2 = map(int, box)

                # Draw bounding box
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

                # Draw label
                label = f"Class {class_id}: {score:.2f}"
                cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            # Display frame
            cv2.imshow('YOLO Detections', frame)

            # Exit on 'q' press
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    finally:
        cap.release()
        cv2.destroyAllWindows()

# Complete standalone script for Colab usage - just run this cell
if __name__ == "__main__":
    # Run the full export and quantization pipeline
    main()

    # Print a final summary for easy copy-paste of model paths
    print("\n" + "="*50)
    print("SUMMARY")
    print("="*50)
    print(f"Original PyTorch model: {Config.MODEL_PT_PATH}")
    print(f"FP32 ONNX model: {Config.ONNX_MODEL_PATH}")
    print(f"INT8 ONNX model: {Config.ONNX_INT8_MODEL_PATH}")
    print("="*50)
    print("To use the INT8 model in your application, use the following:")
    print("```python")
    print("import onnxruntime as ort")
    print(f"session = ort.InferenceSession(\"{Config.ONNX_INT8_MODEL_PATH}\")")
    print("# Run inference with:")
    print("# outputs = session.run(None, {\"input\": preprocessed_image})")
    print("```")
    print("="*50)

    # Optionally, try to load and validate the model (quietly skipped if it fails)
    try:
        # Verify that the quantized model can be loaded
        quantized_session = ort.InferenceSession(Config.ONNX_INT8_MODEL_PATH)
        print("✓ Successfully verified that the quantized model can be loaded")
    except Exception as e:
        print(f"⚠ Could not verify the quantized model: {e}")


Verifying Google Drive mount...
✓ Google Drive mounted successfully
Project base directory: /content/drive/My Drive/SemesterProjectDatas
Output directory created/verified: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized

Searching for image directories in the project...

Found directories with images:
  [0] /content/drive/My Drive/SemesterProjectDatas/ManualRecording (1 images)
  [1] /content/drive/My Drive/SemesterProjectDatas/TACO/train/images (3146 images)
  [2] /content/drive/My Drive/SemesterProjectDatas/TACO/test/images (150 images)
  [3] /content/drive/My Drive/SemesterProjectDatas/TACO/valid/images (300 images)
  [4] /content/drive/My Drive/SemesterProjectDatas/ManualData/test/images (146 images)
  [5] /content/drive/My Drive/SemesterProjectDatas/ManualData/train/images (3381 images)
  [6] /content/drive/My Drive/SemesterProjectDatas/ManualData/valid/images (291 images)
  [7] /content/drive/My Drive/SemesterProjectDatas/Model/Yolo12n/yol



✓ Dynamic quantization complete (fallback method).
⚠ Dynamically quantized model verification failed: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for ConvInteger(10) node with name '/model.0/conv/Conv_quant'
Will fall back to model optimization without quantization...
⚠ Dynamic quantization also failed: Model verification failed
Trying ONNX Runtime optimization without quantization...
✓ Model optimization completed successfully: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/yolov12n_enhanced_int8_optimized.onnx
✓ Optimized model copied to: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/yolov12n_enhanced_int8.onnx
✓ Final model verified and saved to: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/yolov12n_enhanced_int8.onnx
No valid images found for comparison

Evaluating FP32 model:
Evaluating model: /content/drive/My Drive/Semes

100%|██████████| 1/1 [00:00<00:00,  7.40it/s]

⚠ Error processing /content/drive/My Drive/SemesterProjectDatas/ManualRecording/IMG_20250323_160539.jpg: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'resize'
> Overload resolution failed:
>  - Can't parse 'dsize'. Sequence item with index 0 has a wrong type
>  - Can't parse 'dsize'. Sequence item with index 0 has a wrong type

⚠ No valid images were processed during evaluation

Evaluating INT8 model:
Evaluating model: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/yolov12n_enhanced_int8.onnx





Found 1 test images


100%|██████████| 1/1 [00:00<00:00,  7.32it/s]

⚠ Error processing /content/drive/My Drive/SemesterProjectDatas/ManualRecording/IMG_20250323_160539.jpg: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'resize'
> Overload resolution failed:
>  - Can't parse 'dsize'. Sequence item with index 0 has a wrong type
>  - Can't parse 'dsize'. Sequence item with index 0 has a wrong type

⚠ No valid images were processed during evaluation

Could not calculate speedup due to missing metrics
Metrics saved to /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/quantization_metrics.json

Model export and quantization process complete!
FP32 ONNX model: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/yolov12n_enhanced.onnx
INT8 ONNX model: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/yolov12n_enhanced_int8.onnx





Copied best INT8 model to: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/best_model_int8.onnx

SUMMARY
Original PyTorch model: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/enhanced_yolov12n/weights/best.pt
FP32 ONNX model: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/yolov12n_enhanced.onnx
INT8 ONNX model: /content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/yolov12n_enhanced_int8.onnx
To use the INT8 model in your application, use the following:
```python
import onnxruntime as ort
session = ort.InferenceSession("/content/drive/My Drive/SemesterProjectDatas/Model/NewEnhancedYolo12nModule/NewQuantized/yolov12n_enhanced_int8.onnx")
# Run inference with:
# outputs = session.run(None, {"input": preprocessed_image})
```
✓ Successfully verified that the quantized model can be loaded
