"""
Minimal reproducible example for tf2onnx bug with BatchNorm + custom layer in channels_first.

Bug: tf2onnx fails to convert models with multiple BatchNormalization layers followed by
a custom layer that performs element-wise multiplication in channels_first data format.

Error: "Check failed: size >= 0 (0 vs. -2)" during ONNX conversion
Exit code: 134 (SIGABRT)

This script demonstrates:
1. SUCCESSFUL conversion with channels_last (WORKS ✓)
2. FAILED conversion with channels_first (FAILS ✗)

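Note: When the bug reproduces, the process aborts with SIGABRT during the second
conversion. The abort is not a Python exception, so the try/except around the
conversion cannot catch it and the summary at the end of the script is never printed.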
"""

import keras
import tensorflow as tf
import tf2onnx

# Hide every GPU from TensorFlow so the run does not depend on local accelerators.
tf.config.set_visible_devices([], 'GPU')

# Banner: report title followed by the versions of the libraries involved.
_header_rule = "=" * 70
print(_header_rule, flush=True)
print("tf2onnx Bug Report - Minimal Reproducible Example", flush=True)
print(_header_rule, flush=True)
for _label, _module in (("TensorFlow", tf), ("Keras", keras), ("tf2onnx", tf2onnx)):
    # Versions are looked up one at a time, immediately before being printed.
    print(f"{_label} version: {_module.__version__}", flush=True)
print(_header_rule, flush=True)
print(flush=True)


# Define a simple custom layer that scales its input (similar to LayerScale)
@keras.utils.register_keras_serializable(package='BugReport')
class SimpleScaleLayer(keras.layers.Layer):
    """Per-channel learnable scaling of the input (LayerScale-like).

    The layer owns one trainable scalar per channel and multiplies the input
    by it element-wise. Which axis counts as the channel axis is decided by
    the global Keras image data format at the moment ``build`` runs.

    Args:
        init_value: Constant every channel scale is initialised to.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, init_value=1e-5, **kwargs):
        super().__init__(**kwargs)
        self._init_value = init_value

    def build(self, input_shape):
        """Create the 1-D ``scale`` weight sized to the channel dimension."""
        active_format = keras.backend.image_data_format()

        # channels_last keeps channels on the trailing axis; otherwise the
        # layout is treated as channels_first (channels on axis 1).
        channel_axis = -1 if active_format == 'channels_last' else 1
        num_channels = input_shape[channel_axis]

        # call() must use the layout that was active when the weight was made.
        self._data_format = active_format

        self._scale = self.add_weight(
            name='scale',
            shape=(num_channels,),
            initializer=keras.initializers.Constant(self._init_value),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        """Return ``x`` multiplied by the per-channel scale.

        The scale is reshaped to a 4-element shape, so ``x`` is expected to
        be a rank-4 tensor ([B, H, W, C] or [B, C, H, W]).
        """
        if self._data_format == 'channels_last':
            broadcast_shape = (1, 1, 1, -1)  # [C] -> [1, 1, 1, C]
        else:
            broadcast_shape = (1, -1, 1, 1)  # [C] -> [1, C, 1, 1]

        # Cast first, then reshape, so the scale matches x's dtype and layout.
        scale = tf.reshape(tf.cast(self._scale, x.dtype), broadcast_shape)
        return x * scale

    def get_config(self):
        """Return the base layer config extended with ``init_value``."""
        return dict(super().get_config(), init_value=self._init_value)


# Build a test model with the problematic pattern
class BugTestModel(keras.layers.Layer):
    """Composite layer reproducing the failing pattern.

    Three 1x1 Conv2D + BatchNormalization pairs (64, 80 and 80 filters, no
    conv bias) are followed by a ``SimpleScaleLayer``. All sub-layers are
    created in ``build`` using the global Keras image data format.
    """

    # Filter counts for conv1/bn1, conv2/bn2 and conv3/bn3, in order.
    _STAGE_FILTERS = (64, 80, 80)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Instantiate the conv/batch-norm stages and the final scale layer."""
        data_format = keras.backend.image_data_format()
        if data_format == 'channels_first':
            bn_axis = 1
        else:
            bn_axis = -1

        # Sub-layers are created in the order conv1, bn1, conv2, bn2, conv3,
        # bn3 and stored under exactly those attribute names.
        for stage, filters in enumerate(self._STAGE_FILTERS, start=1):
            setattr(
                self,
                f"conv{stage}",
                keras.layers.Conv2D(
                    filters, kernel_size=1, data_format=data_format, use_bias=False
                ),
            )
            setattr(
                self,
                f"bn{stage}",
                keras.layers.BatchNormalization(axis=bn_axis),
            )

        # Custom scale layer - this is where the bug manifests.
        self.scale = SimpleScaleLayer(init_value=1e-5)

        super().build(input_shape)

    def call(self, x):
        """Apply conv -> batch norm three times, then the per-channel scale."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, batch_norm in stages:
            x = batch_norm(conv(x))
        # BUG: ONNX conversion fails at this step with channels_first.
        return self.scale(x)


def build_and_convert_model(data_format):
    """Build the test model for ``data_format`` and try to export it to ONNX.

    The global Keras image data format is set to ``data_format`` and left
    that way. A functional model wrapping ``BugTestModel`` is built over a
    48-channel 128x128 input and passed to ``tf2onnx.convert.from_keras``,
    which writes ``bug_test_<data_format>.onnx``.

    Args:
        data_format: ``'channels_last'`` selects an (H, W, C) input; any
            other value is laid out as (C, H, W).

    Returns:
        True when the conversion call returns normally, False when it raises
        an ``Exception`` (the error and its traceback are printed).
    """
    import traceback

    rule = '=' * 70
    print(f"\n{rule}", flush=True)
    print(f"TEST: {data_format.upper()}", flush=True)
    print(rule, flush=True)

    # Switch the global layout; the layers read it when they are built.
    keras.backend.set_image_data_format(data_format)
    print(f"Data format: {keras.backend.image_data_format()}", flush=True)

    # Same image in both cases, only the axis order differs.
    channels, height, width = 48, 128, 128
    if data_format == 'channels_last':
        image_shape = (height, width, channels)
    else:  # channels_first
        image_shape = (channels, height, width)

    print(f"Input shape: {image_shape}", flush=True)
    print("\nBuilding model...", flush=True)

    inputs = keras.layers.Input(shape=image_shape, name='input')
    model = keras.models.Model(inputs=inputs, outputs=BugTestModel()(inputs))

    print("✓ Model built successfully", flush=True)
    print(f"  Output shape: {model.output.shape}", flush=True)

    print("\nAttempting ONNX conversion...", flush=True)
    output_filename = f"bug_test_{data_format}.onnx"

    # Leading None leaves the batch dimension unspecified.
    signature = [
        tf.TensorSpec(shape=(None, *image_shape), dtype=tf.float32, name='input')
    ]

    try:
        tf2onnx.convert.from_keras(
            model,
            input_signature=signature,
            output_path=output_filename,
        )
        print("✓ ONNX conversion SUCCESSFUL!", flush=True)
        print(f"  Saved to: {output_filename}", flush=True)
    except Exception as exc:
        # Only Python-level failures land here.
        print("✗ ONNX conversion FAILED", flush=True)
        print(f"\nError: {exc}", flush=True)
        traceback.print_exc()
        return False

    return True


# ----------------------------------------------------------------------------
# Part 1: channels_last (expected to succeed)
# ----------------------------------------------------------------------------
success_channels_last = build_and_convert_model('channels_last')

# ----------------------------------------------------------------------------
# Part 2: channels_first (expected to fail)
# ----------------------------------------------------------------------------
success_channels_first = build_and_convert_model('channels_first')

# ----------------------------------------------------------------------------
# Summary
# ----------------------------------------------------------------------------
_summary_rule = "=" * 70
_pass_text, _fail_text = '✓ PASS', '✗ FAIL'

print(f"\n{_summary_rule}", flush=True)
print("SUMMARY", flush=True)
print(_summary_rule, flush=True)
print(
    f"channels_last:  {_pass_text if success_channels_last else _fail_text}",
    flush=True,
)
print(
    f"channels_first: {_pass_text if success_channels_first else _fail_text}",
    flush=True,
)
print(_summary_rule, flush=True)

# Choose the verdict text. Any outcome where channels_last itself failed is
# treated as an unexpected environment problem.
if not success_channels_last:
    _verdict_lines = (
        "\nUNEXPECTED RESULT",
        "Please check the environment and dependencies.",
    )
elif success_channels_first:
    _verdict_lines = (
        "\nBUG NOT REPRODUCED",
        "Both conversions succeeded. The bug may have been fixed.",
    )
else:
    _verdict_lines = (
        "\nBUG CONFIRMED!",
        "The same model architecture converts successfully with channels_last",
        "but fails with channels_first, indicating a tf2onnx bug.",
    )

for _line in _verdict_lines:
    print(_line, flush=True)

print(_summary_rule, flush=True)