from __future__ import annotations
import os
import sys
from torch import Tensor, nn, empty, softmax
from unittest.mock import MagicMock

# Append the parent of this file's directory to the module search path.
# NOTE(review): presumably this is what lets the `src.*` imports below resolve
# when the file is loaded from outside the project root — confirm.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.nn.modules.embedding import Embedding
from src.nn.modules.normalization import RMSNorm
from src.nn.modules.linear import Linear
from src.nn.modules.embedding import RotaryPositionalEncoding
from src.nn.modules.attention import CausalMaskedSelfAttention, MultiHeadSelfAttention


def run_linear(
    d_in: int,
    d_out: int,
    weights: Tensor,  # [d_out, d_in]
    in_features: Tensor,  # [..., d_in]
) -> Tensor:
    """
    Apply a Linear layer, loaded with the supplied weights, to a batched input.

    Args:
        d_in (int): The size of the input dimension
        d_out (int): The size of the output dimension
        weights (Tensor): The linear weights to use
        in_features (Tensor): The input tensor to apply the layer to

    Returns:
        Tensor: The result of calling the Linear module on `in_features`.
    """
    layer = Linear(d_in, d_out)
    # Replace whatever weight the constructor created with the caller's tensor.
    layer.weight = nn.Parameter(weights)
    transformed = layer(in_features)
    return transformed


def run_embedding(
    vocab_size: int,
    d_model: int,
    weights: Tensor,  # [vocab_size, d_model]
    token_ids: Tensor
) -> Tensor:
    """
    Look up embeddings for a batch of token ids using the supplied embedding table.

    Args:
        vocab_size (int): The number of embeddings in the vocabulary
        d_model (int): The size of the embedding dimension
        weights (Tensor): The embedding vectors to fetch from
        token_ids (Tensor): The token ids to look up in the Embedding layer

    Returns:
        Tensor: The result of calling the Embedding layer on `token_ids`.
    """
    table = Embedding(vocab_size, d_model)
    # Replace whatever weight the constructor created with the caller's tensor.
    table.weight = nn.Parameter(weights)
    looked_up = table(token_ids)
    return looked_up


def run_rmsnorm(
    d_model: int,
    eps: float,
    weights: Tensor,  # [d_model]
    in_features: Tensor,  # [..., d_model]
) -> Tensor:
    """Run RMSNorm, loaded with the supplied affine weights, on the input features.

    Args:
        d_model (int): The dimensionality of the RMSNorm input.
        eps (float): A value added to the denominator for numerical stability.
        weights (Tensor): RMSNorm weights.
        in_features (Tensor): Input features to run RMSNorm on. Can have arbitrary
            leading dimensions.

    Returns:
        Tensor: Tensor with the same shape as `in_features`, holding the output of
        running RMSNorm on `in_features`.
    """
    norm = RMSNorm(d_model, eps)
    # Replace whatever weight the constructor created with the caller's tensor.
    norm.weight = nn.Parameter(weights)
    normalized = norm(in_features)
    return normalized


def run_rope(
    d_model: int,
    theta: float,
    seq_len: int,
    in_query_or_key: Tensor  # [..., seq_len, d_model]
) -> Tensor:
    """
    Apply rotary positional encoding (RoPE) to a query or key tensor.

    Args:
        d_model (int): Embedding dimension size for the query or key tensor.
        theta (float): RoPE parameter.
        seq_len (int): Sequence length of the input tensor.
        in_query_or_key (Tensor): Input tensor to run RoPE on.
    Returns:
        Tensor: The result of calling the RoPE module on `in_query_or_key`.
    """
    # The module is constructed positionally as (seq_len, d_model, theta), which
    # is NOT the order of this wrapper's own parameters.
    encoder = RotaryPositionalEncoding(seq_len, d_model, theta)
    rotated = encoder(in_query_or_key)
    return rotated


def run_scaled_dot_product_attention(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    mask: Tensor
) -> Tensor:
    """
    Compute scaled dot product attention from pre-computed query, key and value tensors.

    The attention module normally derives Q, K and V from its input through its
    own projection layers. Here the `forward` of each projection is replaced by
    a mock that returns the caller's tensor, so the supplied Q, K and V are used
    in place of any projected values.

    Args:
        Q (Tensor): Query tensor.
        K (Tensor): Key tensor.
        V (Tensor): Value tensor.
        mask (Tensor): Mask tensor. It is inverted with `~` before being handed
            to the attention module.

    Returns:
        Tensor: Scaled dot product attention output.
    """
    # The last dimension of Q is used for the embedding, key and value sizes alike.
    head_dim = Q.shape[-1]
    attn = CausalMaskedSelfAttention(embed_size=head_dim, d_k=head_dim, d_v=head_dim)

    # Short-circuit each projection so it yields the caller-provided tensor.
    for projection, fixed_output in ((attn.W_Q, Q), (attn.W_K, K), (attn.W_V, V)):
        projection.forward = MagicMock(return_value=fixed_output)

    # Uninitialised placeholder shaped like Q; the mocked projections ignore it.
    # NOTE(review): assumes the module only consumes its input via W_Q/W_K/W_V — confirm.
    placeholder = empty(Q.shape)
    # NOTE(review): the inversion suggests the module's mask convention is the
    # opposite of the caller's — confirm against CausalMaskedSelfAttention.
    inverted_mask = ~mask
    return attn(placeholder, inverted_mask)


def run_multihead_self_attention(
    d_model: int,
    num_heads: int,
    q_proj_weight: Tensor,
    k_proj_weight: Tensor,
    v_proj_weight: Tensor,
    o_proj_weight: Tensor,
    in_features: Tensor
) -> Tensor:
    """
    Run multi-head self-attention, loaded with the supplied projection weights, on the input.

    Args:
        d_model (int): Embedding dimension size of the input features.
        num_heads (int): Number of heads.
        q_proj_weight (Tensor): Query projection weights [d_k, d_in].
        k_proj_weight (Tensor): Key projection weights [d_k, d_in].
        v_proj_weight (Tensor): Value projection weights [d_v, d_in].
        o_proj_weight (Tensor): Output projection weights [d_out, d_in].
        in_features (Tensor): Input features to run MultiHeadSelfAttention on.

    Returns:
        Tensor: The result of calling MultiHeadSelfAttention on `in_features`.
    """
    # Key and value sizes are both the model width split evenly across the
    # heads (integer division).
    per_head_dim = d_model // num_heads
    mhsa = MultiHeadSelfAttention(
        embed_size=d_model,
        d_k=per_head_dim,
        d_v=per_head_dim,
        num_heads=num_heads,
    )

    # Load each caller-provided matrix into the matching projection layer.
    weight_by_projection = (
        (mhsa.W_Q, q_proj_weight),
        (mhsa.W_K, k_proj_weight),
        (mhsa.W_V, v_proj_weight),
        (mhsa.W_O, o_proj_weight),
    )
    for projection, matrix in weight_by_projection:
        projection.weight = nn.Parameter(matrix)

    return mhsa(in_features)
    
