In [1]:
import torch
from torch import nn
from torchtune.modules.attention import CausalSelfAttention

  from .autonotebook import tqdm as notebook_tqdm


In [28]:
from torchtune.modules import RotaryPositionalEmbeddings
from torch import nn

# Assuming head_dim = 72 (as per your input tensor shape)

pos_embeddings = RotaryPositionalEmbeddings(dim=head_dim)


In [27]:
embed_dim = 72
head_dim = 72 // 12
num_heads = 12
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)
output_proj = nn.Linear(embed_dim, embed_dim)


In [29]:
attention_layer = CausalSelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_kv_heads=num_heads,  # Use the same number of heads for key/value
            head_dim=head_dim,  # Correctly use head_dim here
            q_proj=q_proj,
            k_proj=k_proj,
            v_proj=v_proj,
            output_proj=output_proj,
            pos_embeddings=pos_embeddings,
            attn_dropout=0.1
        )

In [30]:
x = torch.rand([4,24,56,72])

In [33]:
x_shape = x.shape
output = attention_layer(x.view(x.shape[0] * x.shape[1], -1, embed_dim))
output = output.view(x_shape)
print(output.shape)

torch.Size([4, 24, 56, 72])


In [19]:
import torch
from torch import nn, Tensor
from torchtune.modules import RotaryPositionalEmbeddings
from torchtune.modules.kv_cache import KVCache
from typing import Optional

class CausalSelfAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        output_proj: nn.Module,
        pos_embeddings: nn.Module,  # Pass the RotaryPositionalEmbeddings here
        kv_cache: Optional[KVCache] = None,
        max_seq_len: int = 4096,
        attn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.embed_dim = embed_dim
        self.attn_dropout = attn_dropout
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len

        # Set layers
        self.kv_cache = kv_cache
        self.q_proj = q_proj
        self.k_proj = k_proj
        self.v_proj = v_proj
        self.output_proj = output_proj
        self.pos_embeddings = pos_embeddings  # Store the positional embeddings

    def forward(
        self,
        x: Tensor,
        *,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seq_len, _ = x.shape

        # Project input to queries, keys, and values
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape queries, keys, and values to include heads
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)  # [4, 24, 56, 72]
        k = k.view(bsz, seq_len, self.num_kv_heads, self.head_dim)  # [4, 24, 28, 72]
        v = v.view(bsz, seq_len, self.num_kv_heads, self.head_dim)  # [4, 24, 28, 72]

        # Apply positional embeddings to queries and keys
        q = self.pos_embeddings(q, input_pos=input_pos)
        k = self.pos_embeddings(k, input_pos=input_pos)

        # Transpose for scaled dot-product attention
        q = q.transpose(1, 2)  # [4, 56, 24, 72]
        k = k.transpose(1, 2)  # [4, 28, 24, 72]
        v = v.transpose(1, 2)  # [4, 28, 24, 72]

        # Update key-value cache
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        if mask is not None:
            mask = mask[:, None, :, :]

        # Perform scaled dot-product attention
        output = nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=self.attn_dropout,
            is_causal=self.kv_cache is None and mask is None,
        )

        # Reshape the output to be the same shape as the input
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.output_proj(output)

# Example input tensor of shape [batch_size, seq_length, embed_dim]
input_tensor = torch.randn(4, 24, 4032)  # [4, 24, 56 * 72]

# Define the projection layers and the output projection layer
q_proj = nn.Linear(4032, 4032)
k_proj = nn.Linear(4032, 4032)
v_proj = nn.Linear(4032, 4032)
output_proj = nn.Linear(4032, 4032)

# Initialize the RotaryPositionalEmbeddings and CausalSelfAttention
pos_embeddings = RotaryPositionalEmbeddings(dim=72)

causal_self_attention = CausalSelfAttention(
    embed_dim=4032,
    num_heads=56,
    num_kv_heads=28,  # GQA configuration with 28 key/value heads
    head_dim=72,
    q_proj=q_proj,
    k_proj=k_proj,
    v_proj=v_proj,
    output_proj=output_proj,
    pos_embeddings=pos_embeddings,
)

# Apply the CausalSelfAttention to the input tensor
output = causal_self_attention(input_tensor)
print(output.shape)  # Expected output shape: [4, 24, 4032]


RuntimeError: shape '[4, 24, 28, 72]' is invalid for input of size 387072

In [18]:
# Example input tensor of shape [batch_size, seq_length, embed_dim]
input_tensor = torch.randn(4, 24, 56 * 72)  # [4, 24, 4032]

# Define the projection layers and the output projection layer
q_proj = nn.Linear(56 * 72, 56 * 72)
k_proj = nn.Linear(56 * 72, 56 * 72)
v_proj = nn.Linear(56 * 72, 56 * 72)
output_proj = nn.Linear(56 * 72, 56 * 72)

# Initialize the CausalSelfAttention module
causal_self_attention = CausalSelfAttention(
    embed_dim=56 * 72,
    num_heads=56,
    num_kv_heads=28,  # For example, if you want GQA
    head_dim=72,
    q_proj=q_proj,
    k_proj=k_proj,
    v_proj=v_proj,
    output_proj=output_proj,
    pos_embeddings=pos_embeddings,
)

# Apply the CausalSelfAttention to the input tensor
output = causal_self_attention(input_tensor)


RuntimeError: shape '[4, 24, 28, 72]' is invalid for input of size 387072