"""
Toy example demonstrating PyTorch to OpenVINO conversion workflow.

This script:
1. Loads a pre-trained ResNet18 model from PyTorch
2. Converts it to OpenVINO format
3. Exports the converted model to a directory
4. Re-imports the model from that directory
5. Verifies the model works by running inference
"""

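# Requires torch, torchvision, numpy, and openvino; the ov.convert_model /
# ov.save_model calls below assume the openvino >= 2023.1 Python API.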
import os
import sys
import torch
import torchvision.models as models
import numpy as np
import openvino as ov

# Make the project's src directory importable (provides tiny_std)
sys.path.insert(0, '/home/algo-ipc/inspekto/inspekto/defect-predictor/src')

from tiny_std.logger import Logger


def create_dummy_input(batch_size=1, channels=3, height=224, width=224):
    """Create a dummy input tensor for testing."""
    return torch.randn(batch_size, channels, height, width)


def convert_pytorch_to_openvino(pytorch_model, example_input, output_path, logger):
    """
    Convert PyTorch model to OpenVINO format.
    
    :param pytorch_model: PyTorch model to convert
    :param example_input: Example input tensor for tracing
    :param output_path: Destination .xml path (ov.save_model writes the weights to a sibling .bin)
    :param logger: Logger instance
    :return: Path to the saved OpenVINO model (.xml file)
    """
    try:
        logger.info("Converting PyTorch model to OpenVINO format...")
        
        # Set model to evaluation mode
        pytorch_model.eval()
        
        # Convert PyTorch model to OpenVINO
        ov_model = ov.convert_model(pytorch_model, example_input=example_input)
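        # Note: convert_model traces the graph with example_input, so
        # data-dependent control flow is frozen to the traced path; a static
        # shape can also be pinned via the converter's optional `input`
        # argument (exact accepted forms vary by OpenVINO version).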
        
        # Save the model
        ov.save_model(ov_model, output_path)
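        # Note: ov.save_model compresses weights to FP16 by default
        # (compress_to_fp16=True); pass compress_to_fp16=False if closer
        # numerical parity with the FP32 PyTorch baseline is needed.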
        
        logger.info(f"✓ Model converted and saved to {output_path}")
        return output_path
        
    except Exception as e:
        logger.error(f"✗ Failed to convert model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None


def load_and_compile_model(core, model_path, logger):
    """
    Load an OpenVINO model from disk and compile it.
    This is the initial compilation step after conversion.
    
    :param core: OpenVINO Core instance
    :param model_path: Path to the .xml file of the OpenVINO model
    :param logger: Logger instance
    :return: Compiled OpenVINO model or None if failed
    """
    try:
        logger.info(f"Loading and compiling OpenVINO model from {model_path}...")
        
        # Read the model
        model = core.read_model(model_path)
        
        # Compile the model for CPU
        compiled_model = core.compile_model(model, "CPU")
        
        logger.info(f"✓ Model loaded and compiled successfully")
        
        # Get input names safely
        input_names = []
        for i, inp in enumerate(compiled_model.inputs):
            try:
                input_names.append(inp.any_name)
            except:
                input_names.append(f"input_{i}")
        logger.info(f"  Model inputs: {input_names}")
        
        # Get output names safely
        output_names = []
        for i, out in enumerate(compiled_model.outputs):
            try:
                output_names.append(out.any_name)
            except:
                output_names.append(f"output_{i}")
        logger.info(f"  Model outputs: {output_names}")
        
        return compiled_model
        
    except Exception as e:
        logger.error(f"✗ Failed to load and compile model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None


def run_inference(compiled_model, input_tensor, logger):
    """
    Run inference on the compiled OpenVINO model.
    
    :param compiled_model: Compiled OpenVINO model
    :param input_tensor: Input tensor (numpy array or torch tensor)
    :param logger: Logger instance
    :return: Output numpy array or None if failed
    """
    try:
        logger.info("Running inference...")
        
        # Convert torch tensor to numpy if needed
        if isinstance(input_tensor, torch.Tensor):
            input_array = input_tensor.detach().cpu().numpy()
        else:
            input_array = input_tensor
        
        # Run inference
        output = compiled_model([input_array])[0]
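        # The call returns a dict-like result indexable by output position
        # (as here) or by output tensor; for explicit control over requests,
        # compiled_model.create_infer_request() is also available.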
        
        logger.info(f"✓ Inference completed")
        logger.info(f"  Input shape: {input_array.shape}")
        logger.info(f"  Output shape: {output.shape}")
        logger.info(f"  Output sample (first 5 values): {output[0, :5]}")
        
        return output
        
    except Exception as e:
        logger.error(f"✗ Failed to run inference: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None


def export_compiled_model(compiled_model, destination_dir, model_name, logger):
    """
    Export (save) compiled OpenVINO model to a directory.
    Similar to OpenVinoManager.export_compiled_models() pattern.
    
    Note: This uses compiled_model.export_model() to save the compiled state,
    which can then be imported without recompilation.
    
    :param compiled_model: Compiled OpenVINO model
    :param destination_dir: Destination directory (like profile_version_path/ov_models)
    :param model_name: Name for the model file (without extension)
    :param logger: Logger instance
    :return: Path to the exported .bin file or None if failed
    """
    try:
        logger.info(f"Exporting compiled model to {destination_dir}...")
        
        # Create destination directory (like ov_models folder)
        os.makedirs(destination_dir, exist_ok=True)
        
        # Define destination path for the model (compiled models use .bin extension)
        dest_bin = os.path.join(destination_dir, f"{model_name}_compiled.bin")
        tmp_dest_bin = dest_bin + ".tmp"
        
        # Export the compiled model (this saves the compiled state).
        # export_model() returns the blob as bytes in recent OpenVINO
        # releases; older releases returned a BytesIO-like stream, so
        # both are handled here.
        user_stream = compiled_model.export_model()
        blob = user_stream.getvalue() if hasattr(user_stream, "getvalue") else user_stream
        with open(tmp_dest_bin, 'wb') as f:
            f.write(blob)
            f.flush()
            os.fsync(f.fileno())
        
        # Atomic operation to prevent partial writes
        os.replace(tmp_dest_bin, dest_bin)
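        # Unlike the IR (.xml/.bin pair), this blob holds a device-specific
        # compiled state and is generally only loadable by the same device
        # type and OpenVINO runtime version that produced it.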
        
        logger.info(f"✓ Compiled model exported successfully")
        logger.info(f"  Exported to: {dest_bin}")
        
        return dest_bin
        
    except Exception as e:
        logger.error(f"✗ Failed to export compiled model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None


def import_compiled_model(core, model_path, logger):
    """
    Import (load) a compiled OpenVINO model from disk.
    Similar to OpenVinoManager.import_compiled_models() pattern.
    
    Note: This uses core.import_model() to load the pre-compiled model state,
    so no recompilation is needed.
    
    :param core: OpenVINO Core instance
    :param model_path: Path to the .bin file of the compiled model
    :param logger: Logger instance
    :return: Loaded compiled OpenVINO model or None if failed
    """
    try:
        logger.info(f"Importing compiled model from {model_path}...")
        
        # Read the compiled model binary
        with open(model_path, "rb") as f:
            user_stream = f.read()
        
        # Import the compiled model (no recompilation needed)
        compiled_model = core.import_model(user_stream, "CPU")
        
        logger.info(f"✓ Compiled model imported successfully (no recompilation)")
        
        # Get input names safely
        input_names = []
        for i, inp in enumerate(compiled_model.inputs):
            try:
                input_names.append(inp.any_name)
            except:
                input_names.append(f"input_{i}")
        logger.info(f"  Model inputs: {input_names}")
        
        # Get output names safely
        output_names = []
        for i, out in enumerate(compiled_model.outputs):
            try:
                output_names.append(out.any_name)
            except:
                output_names.append(f"output_{i}")
        logger.info(f"  Model outputs: {output_names}")
        
        return compiled_model
        
    except Exception as e:
        logger.error(f"✗ Failed to import compiled model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None


def verify_outputs_match(output1, output2, logger, tolerance=1e-5):
    """
    Verify that two outputs are close enough (within tolerance).
    
    :param output1: First output array
    :param output2: Second output array
    :param logger: Logger instance
    :param tolerance: Maximum allowed difference
    :return: True if outputs match, False otherwise
    """
    try:
        max_diff = np.max(np.abs(output1 - output2))
        mean_diff = np.mean(np.abs(output1 - output2))
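        # np.allclose(output1, output2, atol=tolerance) would be the one-line
        # equivalent; the explicit diffs are kept for logging.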
        
        logger.info(f"Output comparison:")
        logger.info(f"  Max difference: {max_diff}")
        logger.info(f"  Mean difference: {mean_diff}")
        logger.info(f"  Tolerance: {tolerance}")
        
        if max_diff < tolerance:
            logger.info(f"✓ Outputs match within tolerance")
            return True
        else:
            logger.warning(f"⚠ Outputs differ by more than tolerance")
            return False
            
    except Exception as e:
        logger.error(f"✗ Failed to compare outputs: {e}")
        return False


def main():
    """Main function demonstrating the full workflow."""
    
    # Setup logger
    logger = Logger(
        log_title='pytorch_to_openvino_toy_example',
        file_name='pytorch_to_openvino_toy_example',
        file_logging_level=Logger.DEBUG,
        console_logging_level=Logger.INFO
    )
    
    logger.info("="*80)
    logger.info("PyTorch to OpenVINO Conversion - Toy Example")
    logger.info("="*80)
    
    # Configuration
    temp_dir = "/tmp/openvino_toy_example"
    initial_export_dir = os.path.join(temp_dir, "initial_export")
    reimport_export_dir = os.path.join(temp_dir, "reimport_export")
    ov_models_folder = "ov_models"  # Like OpenVinoManager.ov_models_folder
    model_name = "resnet18"  # Without extension
    
    # Create directories
    os.makedirs(temp_dir, exist_ok=True)
    
    # Initialize OpenVINO Core (reused throughout)
    core = ov.Core()
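    # One Core instance can read, compile, export, and import any number of
    # models; core.available_devices lists the plugins it discovered.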
    
    # Step 1: Load PyTorch model
    logger.info("\n" + "="*80)
    logger.info("STEP 1: Loading PyTorch ResNet18 model")
    logger.info("="*80)
    
    try:
        # torchvision deprecated the `pretrained=True` flag; the weights enum
        # is the current API
        pytorch_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        pytorch_model.eval()
        logger.info("✓ PyTorch model loaded successfully")
    except Exception as e:
        logger.error(f"✗ Failed to load PyTorch model: {e}")
        return
    
    # Step 2: Create example input
    logger.info("\n" + "="*80)
    logger.info("STEP 2: Creating example input tensor")
    logger.info("="*80)
    
    example_input = create_dummy_input()
    logger.info(f"✓ Created input tensor with shape: {example_input.shape}")
    
    # Step 3: Run inference on PyTorch model (for comparison)
    logger.info("\n" + "="*80)
    logger.info("STEP 3: Running inference on PyTorch model (baseline)")
    logger.info("="*80)
    
    try:
        with torch.no_grad():
            pytorch_output = pytorch_model(example_input).detach().cpu().numpy()
        logger.info(f"✓ PyTorch inference completed")
        logger.info(f"  Output shape: {pytorch_output.shape}")
        logger.info(f"  Output sample (first 5 values): {pytorch_output[0, :5]}")
    except Exception as e:
        logger.error(f"✗ Failed PyTorch inference: {e}")
        return
    
    # Step 4: Convert to OpenVINO
    logger.info("\n" + "="*80)
    logger.info("STEP 4: Converting PyTorch model to OpenVINO format")
    logger.info("="*80)
    
    # Save converted model to initial directory
    initial_model_path = os.path.join(initial_export_dir, f"{model_name}.xml")
    converted_model_path = convert_pytorch_to_openvino(
        pytorch_model, 
        example_input, 
        initial_model_path, 
        logger
    )
    
    if not converted_model_path:
        logger.error("Conversion failed, stopping")
        return
    
    # Step 5: Load and compile the converted model
    logger.info("\n" + "="*80)
    logger.info("STEP 5: Loading and compiling the converted OpenVINO model")
    logger.info("="*80)
    
    compiled_model_1 = load_and_compile_model(core, converted_model_path, logger)
    
    if not compiled_model_1:
        logger.error("Load and compile failed, stopping")
        return
    
    # Step 6: Run inference on compiled model
    logger.info("\n" + "="*80)
    logger.info("STEP 6: Running inference on compiled OpenVINO model")
    logger.info("="*80)
    
    ov_output_1 = run_inference(compiled_model_1, example_input, logger)
    
    if ov_output_1 is None:
        logger.error("Inference failed, stopping")
        return
    
    # Verify outputs match
    logger.info("\nVerifying OpenVINO output matches PyTorch output...")
    verify_outputs_match(pytorch_output, ov_output_1, logger, tolerance=1e-4)
    
    # Step 7: Export compiled model to a different directory (like compile_and_save -> export_compiled_models)
    logger.info("\n" + "="*80)
    logger.info("STEP 7: Exporting compiled model to a different directory")
    logger.info("="*80)
    logger.info("(Similar to OpenVinoManager.export_compiled_models)")
    
    # Create ov_models subdirectory like in OpenVinoManager
    export_ov_models_dir = os.path.join(reimport_export_dir, ov_models_folder)
    exported_model_path = export_compiled_model(
        compiled_model_1,
        export_ov_models_dir,
        model_name,
        logger
    )
    
    if not exported_model_path:
        logger.error("Export failed, stopping")
        return
    
    # Step 8: Import compiled model from the new directory (like import_compiled_models)
    logger.info("\n" + "="*80)
    logger.info("STEP 8: Importing compiled model from exported directory")
    logger.info("="*80)
    logger.info("(Similar to OpenVinoManager.import_compiled_models)")
    logger.info("Note: Using import_model() - no recompilation needed")
    
    compiled_model_2 = import_compiled_model(core, exported_model_path, logger)
    
    if not compiled_model_2:
        logger.error("Import of compiled model failed, stopping")
        return
    
    # Step 9: Run inference on imported compiled model
    logger.info("\n" + "="*80)
    logger.info("STEP 9: Running inference on imported compiled OpenVINO model")
    logger.info("="*80)
    
    ov_output_2 = run_inference(compiled_model_2, example_input, logger)
    
    if ov_output_2 is None:
        logger.error("Imported model inference failed, stopping")
        return
    
    # Verify outputs match
    logger.info("\nComparing imported compiled model output with original compiled model output...")
    match_1_2 = verify_outputs_match(ov_output_1, ov_output_2, logger, tolerance=1e-6)
    
    logger.info("\nComparing imported compiled model output with PyTorch output...")
    match_2_pytorch = verify_outputs_match(pytorch_output, ov_output_2, logger, tolerance=1e-4)
    
    # Step 10: Export compiled model again to a third directory
    logger.info("\n" + "="*80)
    logger.info("STEP 10: Exporting compiled model again to a third directory")
    logger.info("="*80)
    logger.info("Testing export after import - compiled model should be re-exportable")
    
    third_export_dir = os.path.join(temp_dir, "third_export")
    third_ov_models_dir = os.path.join(third_export_dir, ov_models_folder)
    exported_model_path_3 = export_compiled_model(
        compiled_model_2,
        third_ov_models_dir,
        model_name,
        logger
    )
    
    if not exported_model_path_3:
        logger.error("Second export failed, stopping")
        return
    
    # Step 11: Import compiled model from the third directory
    logger.info("\n" + "="*80)
    logger.info("STEP 11: Importing compiled model from third export directory")
    logger.info("="*80)
    logger.info("Testing reimport after re-export")
    
    compiled_model_3 = import_compiled_model(core, exported_model_path_3, logger)
    
    if not compiled_model_3:
        logger.error("Second import failed, stopping")
        return
    
    # Step 12: Run inference on second imported compiled model
    logger.info("\n" + "="*80)
    logger.info("STEP 12: Running inference on second imported compiled OpenVINO model")
    logger.info("="*80)
    
    ov_output_3 = run_inference(compiled_model_3, example_input, logger)
    
    if ov_output_3 is None:
        logger.error("Second imported model inference failed, stopping")
        return
    
    # Step 13: Final verification - comparing all outputs
    logger.info("\n" + "="*80)
    logger.info("STEP 13: Final verification - comparing all outputs")
    logger.info("="*80)
    
    logger.info("\nComparing second imported model output with first imported model output...")
    match_2_3 = verify_outputs_match(ov_output_2, ov_output_3, logger, tolerance=1e-6)
    
    logger.info("\nComparing second imported model output with PyTorch output...")
    match_3_pytorch = verify_outputs_match(pytorch_output, ov_output_3, logger, tolerance=1e-4)
    
    # Final summary
    logger.info("\n" + "="*80)
    logger.info("FINAL SUMMARY")
    logger.info("="*80)
    
    all_steps_passed = (
        converted_model_path is not None and
        compiled_model_1 is not None and
        ov_output_1 is not None and
        exported_model_path is not None and
        compiled_model_2 is not None and
        ov_output_2 is not None and
        match_1_2 and
        match_2_pytorch and
        exported_model_path_3 is not None and
        compiled_model_3 is not None and
        ov_output_3 is not None and
        match_2_3 and
        match_3_pytorch
    )
    
    if all_steps_passed:
        logger.info("✓✓✓ ALL STEPS COMPLETED SUCCESSFULLY! ✓✓✓")
        logger.info("\nWorkflow verified (following OpenVinoManager pattern):")
        logger.info("  1. ✓ PyTorch model loaded")
        logger.info("  2. ✓ Converted to OpenVINO IR format")
        logger.info("  3. ✓ Loaded and compiled the converted model")
        logger.info("  4. ✓ Inference produces correct results")
        logger.info("  5. ✓ Exported compiled model binary (like export_compiled_models)")
        logger.info("  6. ✓ Imported compiled model (like import_compiled_models - no recompilation)")
        logger.info("  7. ✓ First imported model produces identical results")
        logger.info("  8. ✓ Re-exported compiled model to third directory")
        logger.info("  9. ✓ Re-imported compiled model from third directory")
        logger.info(" 10. ✓ Second imported model produces identical results")
        logger.info("\nVerified: export → import → export → import chain works perfectly")
        logger.info(f"\nModel files are in:")
        logger.info(f"  Initial conversion: {initial_export_dir}")
        logger.info(f"  First export: {reimport_export_dir}/{ov_models_folder}")
        logger.info(f"  Second export: {third_export_dir}/{ov_models_folder}")
    else:
        logger.warning("⚠ SOME STEPS FAILED - CHECK LOGS ABOVE ⚠")
    
    logger.info("="*80)


if __name__ == "__main__":
    main()
