import torch
from transformers import AutoModelForCausalLM
from types import SimpleNamespace

# --- mamba_ssm import ---
# The Mamba2 model from the official mamba_ssm library is used directly.
from mamba_ssm.modules.mamba2 import Mamba2
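# NOTE: this script assumes a mamba_ssm version whose Mamba2.forward accepts the
# variable-length arguments `cu_seqlens` and `seq_idx`; older releases may not.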

PADDED_SEQ_LEN = 128  # Total token budget of the padded batch used in both scenarios.

def demonstrate_padding_error(model: Mamba2):
    """
    Shows that an error occurs if `hidden_states` contains padding tokens that
    are not accounted for in `cu_seqlens`.
    """
    print(f"\n{'='*20} 🧪 SCENARIO 1: Ignoring Padding Tokens {'='*20}")
    print(
        "Goal: Show that an error occurs if `hidden_states` contains padding tokens "
        "that are not accounted for in `cu_seqlens`.\n"
    )

    device = model.in_proj.weight.device
    dtype = model.in_proj.weight.dtype

    # Define real requests; the remaining tokens up to PADDED_SEQ_LEN act as padding.
    real_request_lengths = [10, 5, 22]
    total_padded_len = PADDED_SEQ_LEN

    # 1. Create an input tensor that includes the padding tokens.
    # The shape is (1, total_padded_len, d_model) to represent a single large batch item.
    hidden_states_padded = torch.randn(
        1, total_padded_len, model.d_model, device=device, dtype=dtype
    )
    print(f"Input `hidden_states` shape: {hidden_states_padded.shape}")

    # 2. Create metadata that *only* describes the real requests, ignoring the padding.
    cu_seqlens = torch.cumsum(
        torch.tensor([0] + real_request_lengths, device=device), dim=0
    ).to(torch.int32)
    request_indices = torch.arange(len(real_request_lengths), device=device)
    lengths_tensor = torch.tensor(real_request_lengths, device=device)
    seq_idx = torch.repeat_interleave(request_indices, lengths_tensor).unsqueeze(0).to(torch.int32)
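    # Concretely, for real_request_lengths = [10, 5, 22]:
    #   cu_seqlens = [0, 10, 15, 37]            (prefix sums of the request lengths)
    #   seq_idx    = [[0]*10 + [1]*5 + [2]*22]  (per-token request index, shape (1, 37))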

    print(f"`cu_seqlens` created for real requests: {cu_seqlens}")
    print(
        f"The last value in `cu_seqlens` is {cu_seqlens[-1]}, but the input tensor "
        f"has {total_padded_len} tokens. This mismatch is expected to cause an error."
    )

    # Allocate inference cache for only the real requests
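    # A SimpleNamespace serves as a lightweight stand-in for mamba_ssm's
    # InferenceParams; it carries just the attributes this forward path reads
    # (`seqlen_offset` and `key_value_memory_dict`).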
    inference_params = SimpleNamespace(
        seqlen_offset=0,
        key_value_memory_dict={
            model.layer_idx: model.allocate_inference_cache(
                batch_size=len(real_request_lengths), max_seqlen=1
            )
        },
    )

    print("\nAttempting forward pass with mismatched metadata...")
    try:
        model.forward(
            hidden_states_padded,
            cu_seqlens=cu_seqlens,
            seq_idx=seq_idx,
            inference_params=inference_params,
        )
        print("🟢 SUCCESS: The forward pass succeeded!.")
    except Exception as e:
        print("🔴 FAILED: The forward pass threw an error.")
        print(f"   Error Type: {type(e).__name__}")
        # The error message from the CUDA kernel clearly states the mismatch.
        print(f"   Error Message: {e}")
        print(
            "\nThis confirms that the total length of the input tensor must exactly "
            "match the token count described by `cu_seqlens`."
        )


def demonstrate_padding_influence(model: Mamba2):
    """
    Shows that adding a padding sequence (as a dummy request) affects the output
    of the real, non-padding tokens.
    """
    print(f"\n{'='*20} 🧪 SCENARIO 2: Padding Tokens Influence Output {'='*20}")
    print(
        "Goal: Compare the model's output for a set of requests with and without "
        "an additional padding sequence to show that padding influences the result.\n"
    )

    device = model.in_proj.weight.device
    dtype = model.in_proj.weight.dtype

    # Define real requests and a padding sequence
    real_request_lengths = [10, 5, 22]
    real_tokens_count = sum(real_request_lengths)
    padding_len = PADDED_SEQ_LEN - real_tokens_count
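    # With real_request_lengths = [10, 5, 22]: real_tokens_count = 37 and
    # padding_len = 128 - 37 = 91.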

    # --- Run 1: No Padding ---
    print("\n--- 🔬 Run 1: Forward pass without any padding ---")

    # 1a. Create input tensor containing only the real requests.
    hidden_states_no_padding = torch.randn(
        1, real_tokens_count, model.d_model, device=device, dtype=dtype
    )
    print(f"Input shape (no padding): {hidden_states_no_padding.shape}")

    # 1b. Construct metadata for only the real requests.
    cu_seqlens_no_padding = torch.cumsum(
        torch.tensor([0] + real_request_lengths, device=device), dim=0
    ).to(torch.int32)
    request_indices_no_padding = torch.arange(len(real_request_lengths), device=device)
    lengths_tensor_no_padding = torch.tensor(real_request_lengths, device=device)
    seq_idx_no_padding = torch.repeat_interleave(
        request_indices_no_padding, lengths_tensor_no_padding
    ).unsqueeze(0).to(torch.int32)

    print(f"`cu_seqlens` (no padding): {cu_seqlens_no_padding}")

    # 1c. Run the forward pass.
    inference_params_1 = SimpleNamespace(
        seqlen_offset=0,
        key_value_memory_dict={
            model.layer_idx: model.allocate_inference_cache(
                batch_size=len(real_request_lengths), max_seqlen=1
            )
        },
    )
    out_no_padding = model.forward(
        hidden_states_no_padding,
        cu_seqlens=cu_seqlens_no_padding,
        seq_idx=seq_idx_no_padding,
        inference_params=inference_params_1,
    ).squeeze(0)
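    # The module returns shape (1, real_tokens_count, d_model); squeeze(0) leaves
    # (real_tokens_count, d_model) for the later comparison.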
    print("Forward pass without padding complete.")

    # --- Run 2: With Padding ---
    print("\n--- 🔭 Run 2: Forward pass with a padding sequence ---")

    # 2a. Create an input tensor with the *same* real request data, plus padding.
    padding_tensors = torch.randn(
        1, padding_len, model.d_model, device=device, dtype=dtype
    )
    hidden_states_with_padding = torch.cat(
        [hidden_states_no_padding, padding_tensors], dim=1
    )
    print(f"Input shape (with padding): {hidden_states_with_padding.shape}")

    # 2b. Construct metadata that accounts for padding as a dummy request.
    all_request_lengths = real_request_lengths + [padding_len]
    cu_seqlens_with_padding = torch.cumsum(
        torch.tensor([0] + all_request_lengths, device=device), dim=0
    ).to(torch.int32)
    request_indices_with_padding = torch.arange(len(all_request_lengths), device=device)
    lengths_tensor_with_padding = torch.tensor(all_request_lengths, device=device)
    seq_idx_with_padding = torch.repeat_interleave(
        request_indices_with_padding, lengths_tensor_with_padding
    ).unsqueeze(0).to(torch.int32)
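    # Concretely, all_request_lengths = [10, 5, 22, 91], so
    # cu_seqlens_with_padding = [0, 10, 15, 37, 128] and seq_idx gains 91 trailing 3s.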

    print(f"`cu_seqlens` (with padding): {cu_seqlens_with_padding}")

    # 2c. Run the forward pass.
    inference_params_2 = SimpleNamespace(
        seqlen_offset=0,
        key_value_memory_dict={
            model.layer_idx: model.allocate_inference_cache(
                batch_size=len(all_request_lengths), max_seqlen=1
            )
        },
    )
    out_batch_with_padding = model.forward(
        hidden_states_with_padding,
        cu_seqlens=cu_seqlens_with_padding,
        seq_idx=seq_idx_with_padding,
        inference_params=inference_params_2,
    )
    # Extract the output for only the real (non-padding) tokens
    out_with_padding_real_tokens = out_batch_with_padding.squeeze(0)[:real_tokens_count]
    print("Forward pass with padding complete.")

    # --- Comparison ---
    print("\n--- ✅ Comparison Results ---")
    # Use a tight tolerance to check for any difference
    outputs_match = torch.allclose(
        out_no_padding, out_with_padding_real_tokens, atol=1e-5, rtol=1e-5
    )

    if not outputs_match:
        print("🔴 ERROR: The outputs for the real tokens are DIFFERENT.")
        diff = torch.max(
            torch.abs(out_no_padding - out_with_padding_real_tokens)
        )
        print(f"   -> Max absolute difference: {diff.item():.6f}")
        print(
            "\nThis demonstrates that the presence of a padding sequence (even when treated as a separate request) "
            "influences the computation for the real tokens."
        )
    else:
        print("🟢 SUCCESS: The outputs for the real tokens are identical.")


@torch.inference_mode()
def main():
    """Main function to instantiate the Mamba2 model and run padding tests."""
    torch.manual_seed(42)

    # --- Model Configuration ---
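    # These hyperparameters are intended to match the Mamba2 mixer configuration of
    # NVIDIA-Nemotron-Nano-12B-v2-Base, so that the pretrained weights loaded below
    # fit without shape mismatches.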
    D_MODEL = 5120
    D_STATE = 128
    D_CONV = 4
    EXPAND = 2
    HEADDIM = 80
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.bfloat16

    if DEVICE == "cpu":
        print("🔴 This script requires a CUDA-enabled GPU to run the Mamba2 kernels.")
        return

    print("🔧 Initializing Mamba2 model...")
    model = Mamba2(
        d_model=D_MODEL,
        d_state=D_STATE,
        d_conv=D_CONV,
        expand=EXPAND,
        headdim=HEADDIM,
        ngroups=8,
        layer_idx=0,
        device=DEVICE,
        dtype=DTYPE,
    ).eval()
    print(f"✅ Model initialized on {DEVICE} with {DTYPE} dtype.")

    print("🔧 Pulling weights from NVIDIA-Nemotron-Nano-12B-v2-Base...")
    # We load weights from a real model to ensure realistic behavior.
    try:
        hf_model = AutoModelForCausalLM.from_pretrained(
            "nvidia/NVIDIA-Nemotron-Nano-12B-v2-Base", trust_remote_code=True
        )
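        # Keep only the parameters of the first layer's Mamba2 mixer; their names
        # (e.g. in_proj.weight, out_proj.weight) should line up with the parameter
        # names of the standalone Mamba2 module created above.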
        state_dict = {
            name: param
            for (name, param) in hf_model.backbone.layers[0].mixer.named_parameters()
        }
        model.load_state_dict(state_dict)
        print("✅ Weights loaded successfully.")
    except Exception as e:
        print(f"🔴 Could not load weights from Hugging Face: {e}")
        print("Proceeding with randomly initialized weights.")


    # --- RUN DEMONSTRATIONS ---
    demonstrate_padding_error(model)
    demonstrate_padding_influence(model)


if __name__ == "__main__":
    main()
