In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [16]:
class SlidingWindowAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = embed_dim // num_heads

        assert (
            self.head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.scaling = self.head_dim ** -0.5

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape
        qkv_x = self.qkv_proj(x)
        qkv = qkv_x.reshape(B, T, 3, self.num_heads, self.head_dim)
        print("qkv: ", qkv.shape)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).chunk(3)
        print("qkv: ", qkv.permute(2, 0, 3, 1, 4).shape)

        q = q * self.scaling

        # Create padded k and v tensors
        k_padded = F.pad(k, (0, 0, self.window_size // 2, self.window_size // 2))
        v_padded = F.pad(v, (0, 0, self.window_size // 2, self.window_size // 2))
        print(k_padded.shape,)

        # Use unfold to create sliding windows
        k_unfold = k_padded.unfold(3, self.window_size, 1)
        v_unfold = v_padded.unfold(3, self.window_size, 1)

        # Adjusted einsum for batched matrix multiplication
        attn_weights = torch.einsum('bhqd,bhqkl->bhkl', q, k_unfold)
        attn_weights = F.softmax(attn_weights, dim=-1)

        attn = torch.einsum('bhkl,bhqkv->bhqv', attn_weights, v_unfold)

        attn = attn.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(attn)

# Example usage
embed_dim = 128
num_heads = 8
window_size = 10
model = SlidingWindowAttention(embed_dim, num_heads, window_size)
input_tensor = torch.rand(1, 50, embed_dim) # Batch

In [17]:
output = model(input_tensor)


qkv:  torch.Size([1, 50, 3, 8, 16])
qkv:  torch.Size([3, 1, 8, 50, 16])
torch.Size([1, 1, 8, 60, 16])


RuntimeError: einsum(): the number of subscripts in the equation (4) does not match the number of dimensions (5) for operand 0 and no ellipsis was given

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvSlidingWindowAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = embed_dim // num_heads

        assert (
            self.head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.q_conv = nn.Conv1d(embed_dim, embed_dim, window_size, padding=window_size//2)
        self.k_conv = nn.Conv1d(embed_dim, embed_dim, window_size, padding=window_size//2)
        self.v_conv = nn.Conv1d(embed_dim, embed_dim, window_size, padding=window_size//2)

        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape

        # Reshape for convolution
        x = x.permute(0, 2, 1)

        q = self.q_conv(x).permute(0, 2, 1)
        k = self.k_conv(x).permute(0, 2, 1)
        v = self.v_conv(x).permute(0, 2, 1)
        
        # Scaling for dot product attention
        q = q * (self.head_dim ** -0.5)

        # Reshape q, k, v for multi-head attention
        q = q.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Dot product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1))
        attn_weights = F.softmax(attn_weights, dim=-1)

        # Apply attention to values
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)

        # Output projection
        output = self.out_proj(attn_output)
        return output



In [22]:
# Example usage
embed_dim = 128
num_heads = 8
window_size = 10
model = ConvSlidingWindowAttention(embed_dim, num_heads, window_size)
input_tensor = torch.rand(1, 50, embed_dim) # Batch

output = model(input_tensor)


RuntimeError: shape '[1, 50, 8, 16]' is invalid for input of size 6528

In [21]:
input_tensor.shape

torch.Size([1, 50, 128])

In [23]:
class MistralAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: MistralConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = MistralRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

NameError: name 'MistralConfig' is not defined