import os
import sys
import torch
from torch.nn.functional import softmax

# Put the project root (two directories above the one containing this file) on
# the import path so the absolute ``src.`` import below resolves even when this
# module is loaded from outside the package, e.g. executed directly as a script.
SCRIPT_DIR = os.path.normpath(os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir))
sys.path.append(SCRIPT_DIR)

from src.nn.modules.linear import Linear


class SelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are linear projections of the same input, and the
    result is ``dropout(softmax(Q @ K^T / sqrt(d_k))) @ V``.

    Args:
        embed_size: Size of the last dimension of the input.
        d_k: Width of the query and key projections.
        d_v: Width of the value projection (and of the output).
        dropout: Dropout probability applied to the attention weights.
        device: Forwarded to every ``Linear`` projection.
        dtype: Forwarded to every ``Linear`` projection.
    """

    def __init__(self, embed_size: int, d_k: int, d_v: int, dropout: float=0.0, device: torch.device=None, dtype: torch.dtype=None):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.W_Q = Linear(in_features=embed_size, out_features=d_k, **factory_kwargs)
        self.W_K = Linear(in_features=embed_size, out_features=d_k, **factory_kwargs)
        self.W_V = Linear(in_features=embed_size, out_features=d_v, **factory_kwargs)
        self.dropout = torch.nn.Dropout(dropout)

    def _get_normalized_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Return query/key similarities scaled by ``1 / sqrt(d_k)``.

        x: [batch_size, seq_len, embed_size] -> [batch_size, seq_len, seq_len]
        """
        queries = self.W_Q(x)  # [batch_size, seq_len, d_k]
        keys = self.W_K(x)     # [batch_size, seq_len, d_k]
        scale = keys.shape[-1] ** 0.5
        return (queries @ keys.transpose(-2, -1)) / scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend every position of ``x`` to every other position.

        x: [batch_size, seq_len, embed_size] -> [batch_size, seq_len, d_v]
        """
        values = self.W_V(x)  # [batch_size, seq_len, d_v]
        probabilities = softmax(self._get_normalized_scores(x), dim=-1)
        return self.dropout(probabilities) @ values


class CausalMaskedSelfAttention(SelfAttention):
    """Single-head self-attention that is causally masked by default.

    Takes exactly the constructor arguments of :class:`SelfAttention`. Unless a
    ``mask`` is passed to :meth:`forward`, query position ``i`` may only attend
    to key positions ``j <= i``.
    """

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Compute masked self-attention over ``x``.

        Args:
            x: [batch_size, seq_len, embed_size] input.
            mask: Optional boolean tensor in which ``True`` marks a
                (query, key) pair that must NOT be attended to. It must
                broadcast against the [batch_size, seq_len, seq_len] scores,
                e.g. [seq_len, seq_len] or [batch_size, seq_len, seq_len].
                Defaults to the causal (strictly upper-triangular) mask.

        Returns:
            [batch_size, seq_len, d_v] attended values.
        """
        # [batch_size, seq_len, d_v]
        V: torch.Tensor = self.W_V(x)
        # [batch_size, seq_len, seq_len]
        attention_scores: torch.Tensor = self._get_normalized_scores(x)
        if mask is None:
            # Build the boolean mask directly rather than allocating a float
            # matrix and casting it: True strictly above the diagonal, i.e. every
            # key that lies in the query's future.
            mask = torch.triu(
                torch.ones(attention_scores.shape[-2:], dtype=torch.bool, device=attention_scores.device),
                diagonal=1,
            )
        # Blocked pairs get -inf so softmax assigns them zero weight. A query row
        # whose keys are ALL blocked (only possible with a caller-supplied mask)
        # yields NaN after softmax.
        attention_scores = attention_scores.masked_fill(mask, -torch.inf)
        # [batch_size, seq_len, seq_len]
        attention_weights: torch.Tensor = self.dropout(softmax(attention_scores, dim=-1))
        # [batch_size, seq_len, d_v]
        return attention_weights @ V


class MultiHeadSelfAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected into ``num_heads`` query/key heads of width ``d_k``
    and value heads of width ``d_v``. Each head attends independently; the
    per-head outputs are concatenated and projected back to ``embed_size`` by
    ``W_O``.

    Inputs may be batched ``[..., seq_len, embed_size]`` (any number of leading
    dims) or unbatched ``[seq_len, embed_size]``.

    Args:
        embed_size: Size of the last dimension of the input and the output.
        d_k: Per-head width of the query and key projections.
        d_v: Per-head width of the value projection.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        device: Forwarded to every ``Linear`` projection.
        dtype: Forwarded to every ``Linear`` projection.
    """

    def __init__(self, embed_size: int, d_k: int, d_v: int, num_heads: int, dropout: float=0.0, device: torch.device=None, dtype: torch.dtype=None):
        super().__init__()
        self.num_heads: int = num_heads
        self.W_Q = Linear(in_features=embed_size, out_features=num_heads*d_k, device=device, dtype=dtype)
        self.W_K = Linear(in_features=embed_size, out_features=num_heads*d_k, device=device, dtype=dtype)
        self.W_V = Linear(in_features=embed_size, out_features=num_heads*d_v, device=device, dtype=dtype)
        self.W_O = Linear(in_features=num_heads*d_v, out_features=embed_size, device=device, dtype=dtype)
        self.dropout = torch.nn.Dropout(dropout)
        self.d_k: int = d_k
        self.d_v: int = d_v

    def _split_heads(self, t: torch.Tensor, head_dim: int) -> torch.Tensor:
        """[..., seq_len, num_heads*head_dim] -> [..., num_heads, seq_len, head_dim]."""
        # reshape (not view) so a non-contiguous projection output does not raise;
        # negative dims keep this independent of how many leading dims there are.
        t = t.reshape(*t.shape[:-1], self.num_heads, head_dim)
        return t.transpose(-3, -2)

    @staticmethod
    def _merge_heads(t: torch.Tensor) -> torch.Tensor:
        """[..., num_heads, seq_len, head_dim] -> [..., seq_len, num_heads*head_dim]."""
        return t.transpose(-3, -2).flatten(-2, -1)

    def _get_normalized_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-head query/key similarities scaled by ``1 / sqrt(d_k)``.

        x: [..., seq_len, embed_size] -> [..., num_heads, seq_len, seq_len]
        """
        # [..., num_heads, seq_len, d_k]
        Q: torch.Tensor = self._split_heads(self.W_Q(x), self.d_k)
        # [..., num_heads, seq_len, d_k]
        K: torch.Tensor = self._split_heads(self.W_K(x), self.d_k)
        # [..., num_heads, seq_len, seq_len]
        return (Q @ K.transpose(-2, -1)) / (K.shape[-1] ** 0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend every position of ``x`` to every other position, per head.

        x: [..., seq_len, embed_size] -> [..., seq_len, embed_size]
        """
        # [..., num_heads, seq_len, d_v]
        V: torch.Tensor = self._split_heads(self.W_V(x), self.d_v)
        # [..., num_heads, seq_len, seq_len]
        attention_scores: torch.Tensor = self._get_normalized_scores(x)
        # [..., num_heads, seq_len, seq_len]
        attention_weights: torch.Tensor = self.dropout(softmax(attention_scores, dim=-1))
        # [..., num_heads, seq_len, d_v]
        context: torch.Tensor = attention_weights @ V
        # [..., seq_len, num_heads*d_v] -> [..., seq_len, embed_size]
        return self.W_O(self._merge_heads(context))


class MultiHeadCausalMaskedSelfAttention(MultiHeadSelfAttention):
    """Multi-head self-attention that is causally masked by default.

    Takes exactly the constructor arguments of :class:`MultiHeadSelfAttention`.
    Unless a ``mask`` is passed to :meth:`forward`, query position ``i`` may
    only attend to key positions ``j <= i`` in every head.
    """

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Compute masked multi-head self-attention over ``x``.

        Args:
            x: [batch_size, seq_len, embed_size] input.
            mask: Optional boolean tensor in which ``True`` marks a
                (query, key) pair that must NOT be attended to. It must
                broadcast against the [batch_size, num_heads, seq_len, seq_len]
                scores: use [seq_len, seq_len] for one mask shared by all
                sequences, or [batch_size, 1, seq_len, seq_len] for per-sequence
                masks. A [batch_size, seq_len, seq_len] mask would be aligned with
                the head dimension instead of the batch dimension. Defaults to
                the causal (strictly upper-triangular) mask.

        Returns:
            [batch_size, seq_len, embed_size] output of ``W_O``.
        """
        # [batch_size, seq_len, num_heads*d_v] -> [batch_size, num_heads, seq_len, d_v]
        # (reshape rather than view so a non-contiguous projection does not raise)
        V: torch.Tensor = self.W_V(x)
        V = V.reshape(*V.shape[:-1], self.num_heads, self.d_v).transpose(-3, -2)
        # [batch_size, num_heads, seq_len, seq_len]
        attention_scores: torch.Tensor = self._get_normalized_scores(x)
        if mask is None:
            # Build the boolean mask directly rather than allocating a float
            # matrix and casting it: True strictly above the diagonal, i.e. every
            # key that lies in the query's future. Broadcasts over batch and heads.
            mask = torch.triu(
                torch.ones(attention_scores.shape[-2:], dtype=torch.bool, device=attention_scores.device),
                diagonal=1,
            )
        # Blocked pairs get -inf so softmax assigns them zero weight. A query row
        # whose keys are ALL blocked (only possible with a caller-supplied mask)
        # yields NaN after softmax.
        attention_scores = attention_scores.masked_fill(mask, -torch.inf)
        # [batch_size, num_heads, seq_len, seq_len]
        attention_weights: torch.Tensor = self.dropout(softmax(attention_scores, dim=-1))
        # [batch_size, num_heads, seq_len, d_v]
        context: torch.Tensor = attention_weights @ V
        # [batch_size, seq_len, num_heads*d_v]
        context = context.transpose(-3, -2).flatten(-2, -1)
        # [batch_size, seq_len, embed_size]
        return self.W_O(context)

        
        
        

        
        