import os
import sys
from torch import Tensor
from typing import Any, Dict, Tuple

# Make the repository root importable when this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from src.tests.adapters import (
    run_embedding,
    run_rmsnorm,
    run_linear,
    run_rope,
    run_scaled_dot_product_attention,
    run_multihead_self_attention,
)
from src.tests.conftest import NumpySnapshot


def test_linear(
    numpy_snapshot: NumpySnapshot,
    ts_state_dict: Tuple[Dict[str, Tensor], Any],
    in_embeddings: Tensor,
    d_model: int,
    d_ff: int,
):
    state_dict, _ = ts_state_dict
    w1_weight = state_dict["layers.0.ffn.w1.weight"]
    output = run_linear(
        d_in=d_model,
        d_out=d_ff,
        weights=w1_weight,
        in_features=in_embeddings,
    )
    numpy_snapshot.assert_match(output)
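
# Illustrative only (not called by the tests): a minimal sketch of the affine map
# run_linear is assumed to compute, with the weight stored as (d_out, d_in) in
# nn.Linear convention and no bias term. The helper name is hypothetical.
def _linear_sketch(weights: Tensor, in_features: Tensor) -> Tensor:
    # (..., d_in) @ (d_in, d_out) -> (..., d_out)
    return in_features @ weights.t()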

def test_embedding(
    numpy_snapshot: NumpySnapshot,
    ts_state_dict: Tuple[Dict[str, Tensor], Any],
    in_indices: Tensor,
    vocab_size: int,
    d_model: int,
):
    state_dict, _ = ts_state_dict
    embedding_weight = state_dict["token_embeddings.weight"]
    output: Tensor = run_embedding(
        vocab_size=vocab_size,
        d_model=d_model,
        weights=embedding_weight,
        token_ids=in_indices,
    )
    numpy_snapshot.assert_match(output)
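
# Illustrative only: a sketch of the table lookup run_embedding is assumed to
# perform, where each token id selects a row of the (vocab_size, d_model) weight.
# The helper name is hypothetical.
def _embedding_sketch(weights: Tensor, token_ids: Tensor) -> Tensor:
    # Integer indexing broadcasts over the leading batch/sequence dims of token_ids.
    return weights[token_ids]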

def test_rmsnorm(
    numpy_snapshot: NumpySnapshot,
    ts_state_dict: Tuple[Dict[str, Tensor], Any],
    in_embeddings: Tensor,
):
    state_dict, _ = ts_state_dict
    reference_weights = state_dict["layers.1.ln1.weight"]
    d_model = reference_weights.shape[0]
    actual_output = run_rmsnorm(
        d_model=d_model, eps=1e-5, weights=reference_weights, in_features=in_embeddings
    )
    numpy_snapshot.assert_match(actual_output, atol=1e-6)
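
# Illustrative only: a sketch of the normalization run_rmsnorm is assumed to apply,
# x / sqrt(mean(x^2) + eps) scaled by a learned per-dimension gain. Dtype upcasting
# that a real implementation might do is omitted; the helper name is hypothetical.
def _rmsnorm_sketch(weights: Tensor, in_features: Tensor, eps: float = 1e-5) -> Tensor:
    rms = (in_features.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
    return in_features / rms * weights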

def test_rope(
    numpy_snapshot: NumpySnapshot,
    in_embeddings: Tensor,
    d_model: int,
    theta: float,
    pos_ids: Tensor,
):
    output = run_rope(
        d_model=d_model, theta=theta, seq_len=pos_ids.shape[0], in_query_or_key=in_embeddings
    )
    numpy_snapshot.assert_match(output, atol=1e-6)
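
# Illustrative only: a sketch of rotary position embeddings as run_rope is assumed
# to apply them, rotating adjacent feature pairs (2i, 2i + 1) by position * theta**(-2i/d).
# The adjacent-pair convention and the helper name/signature are assumptions.
def _rope_sketch(theta: float, in_query_or_key: Tensor, positions: Tensor) -> Tensor:
    d = in_query_or_key.shape[-1]
    freqs = in_query_or_key.new_tensor([theta ** (-2 * i / d) for i in range(d // 2)])
    angles = positions.unsqueeze(-1).to(freqs.dtype) * freqs  # (seq_len, d // 2)
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = in_query_or_key[..., 0::2], in_query_or_key[..., 1::2]
    rotated = in_query_or_key.clone()
    rotated[..., 0::2] = x_even * cos - x_odd * sin
    rotated[..., 1::2] = x_even * sin + x_odd * cos
    return rotated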

def test_scaled_dot_product_attention(
    numpy_snapshot: NumpySnapshot,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mask: Tensor,
):
    actual_output = run_scaled_dot_product_attention(Q=q, K=k, V=v, mask=mask)
    numpy_snapshot.assert_match(
        actual_output,
        atol=1e-6,
    )
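
# Illustrative only: a sketch of softmax(Q K^T / sqrt(d_k)) V as
# run_scaled_dot_product_attention is assumed to compute it. That `mask` is boolean
# with True meaning "may attend" is an assumption; the helper name is hypothetical.
def _sdpa_sketch(Q: Tensor, K: Tensor, V: Tensor, mask: Tensor) -> Tensor:
    scores = Q @ K.transpose(-2, -1) / (K.shape[-1] ** 0.5)
    scores = scores.masked_fill(~mask, float("-inf"))
    return scores.softmax(dim=-1) @ V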

def test_multihead_self_attention(
    numpy_snapshot: NumpySnapshot,
    in_embeddings: Tensor,
    d_model: int,
    n_heads: int,
    ts_state_dict: Tuple[Dict[str, Tensor], Any],
):
    state_dict, _ = ts_state_dict
    q_proj_weight, k_proj_weight, v_proj_weight, o_proj_weight = [
        state_dict[f"layers.0.attn.{name}_proj.weight"] for name in ["q", "k", "v", "output"]
    ]
    actual_output = run_multihead_self_attention(
        d_model=d_model,
        num_heads=n_heads,
        q_proj_weight=q_proj_weight,
        k_proj_weight=k_proj_weight,
        v_proj_weight=v_proj_weight,
        o_proj_weight=o_proj_weight,
        in_features=in_embeddings,
    )
    numpy_snapshot.assert_match(actual_output, atol=1e-6)
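
# Illustrative only: a sketch of the multi-head attention path run_multihead_self_attention
# is assumed to follow: project, split into heads, attend per head, merge, then apply the
# output projection. The causal mask and the (d_out, d_in) weight layout are assumptions,
# as is the helper name.
def _mhsa_sketch(
    num_heads: int,
    q_proj_weight: Tensor,
    k_proj_weight: Tensor,
    v_proj_weight: Tensor,
    o_proj_weight: Tensor,
    in_features: Tensor,
) -> Tensor:
    *batch, seq_len, d_model = in_features.shape
    d_head = d_model // num_heads

    def split_heads(x: Tensor) -> Tensor:
        # (..., seq, d_model) -> (..., num_heads, seq, d_head)
        return x.view(*batch, seq_len, num_heads, d_head).transpose(-3, -2)

    q = split_heads(in_features @ q_proj_weight.t())
    k = split_heads(in_features @ k_proj_weight.t())
    v = split_heads(in_features @ v_proj_weight.t())
    scores = q @ k.transpose(-2, -1) / (d_head ** 0.5)
    causal = in_features.new_ones(seq_len, seq_len).tril().bool()
    scores = scores.masked_fill(~causal, float("-inf"))
    merged = (scores.softmax(dim=-1) @ v).transpose(-3, -2).contiguous()
    return merged.view(*batch, seq_len, d_model) @ o_proj_weight.t()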