In [25]:
import math
from typing import Callable, Optional

import torch as t
from torch import nn
from torch import einsum
from einops import rearrange, reduce, repeat

import bert_tests

### Raw attention pattern

In [16]:
def raw_attention_pattern(
    token_activations: t.Tensor,
    num_heads: int,
    project_query: Callable[[t.Tensor], t.Tensor],
    project_key: Callable[[t.Tensor], t.Tensor],
) -> t.Tensor:
    """
    Compute the scaled (pre-softmax) dot-product attention scores for every head.

    token_activations: Tensor[batch_size, seq_length, hidden_size]
    num_heads: number of heads; the projected feature dimension is split evenly
        into num_heads chunks of width head_dim
    project_query: function( (Tensor[..., hidden_size]) -> Tensor[..., num_heads * head_dim] )
    project_key:   function( (Tensor[..., hidden_size]) -> Tensor[..., num_heads * head_dim] )
    return: Tensor[batch_size, num_heads, key_token: seq_length, query_token: seq_length]
        with entry [b, h, i, j] = <key_i, query_j> / sqrt(head_dim).
        The key axis comes BEFORE the query axis, so normalising over keys
        means a softmax over dim=-2.
    """
    queries = rearrange(
        project_query(token_activations), "b s (head d) -> b head s d", head=num_heads
    )
    keys = rearrange(
        project_key(token_activations), "b s (head d) -> b head s d", head=num_heads
    )

    # Scale with a plain Python float rather than t.sqrt(t.tensor(int)):
    # no per-call tensor allocation, and no precision loss for float64 inputs
    # (t.sqrt promotes an integer tensor to the default dtype, normally float32).
    head_dim = keys.shape[-1]
    scores = einsum("bhid, bhjd -> bhij", keys, queries)
    return scores / math.sqrt(head_dim)


# Check raw_attention_pattern with the external bert_tests harness; it reports
# "MATCH" plus shape/mean/std of the result, as in the captured output below.
bert_tests.test_attention_pattern_fn(raw_attention_pattern)


attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: -0.00134 STD: 0.1129 VALS [-0.01475 0.08565 -0.0173 0.08945 0.1001 -0.2143 -0.05152 -0.08566 0.03025 -0.003722...]


In [17]:
def bert_attention(
    token_activations: t.Tensor,
    num_heads: int,
    attention_pattern: t.Tensor,
    project_value: Callable[[t.Tensor], t.Tensor],
    project_output: Callable[[t.Tensor], t.Tensor],
) -> t.Tensor:
    """
    Apply raw attention scores to the value vectors, then merge the heads
    through the output projection.

    token_activations: Tensor[batch_size, seq_length, hidden_size]
    num_heads: number of attention heads
    attention_pattern: Tensor[batch_size, num_heads, key_token, query_token]
        raw (pre-softmax) scores; normalised here over the key axis (dim -2)
    project_value: function( (Tensor[..., hidden_size]) -> Tensor[..., num_heads * head_dim] )
    project_output: function( (Tensor[..., num_heads * head_dim]) -> Tensor[..., out_dim] )
    return: project_output applied to the concatenated heads, i.e.
        Tensor[batch_size, seq_length, out_dim] (hidden_size when project_output
        maps back to the hidden size)
    """
    # Softmax over the key axis: for each query, the weights over keys sum to 1.
    weights = attention_pattern.softmax(dim=-2)

    per_head_values = rearrange(
        project_value(token_activations),
        "batch seq (heads dim) -> batch heads seq dim",
        heads=num_heads,
    )

    # Sum over keys k: weights[b, h, k, q] * values[b, h, k, d] -> [b, h, q, d]
    mixed = einsum("bhkq, bhkd -> bhqd", weights, per_head_values)

    merged_heads = rearrange(mixed, "batch heads seq dim -> batch seq (heads dim)")
    return project_output(merged_heads)


# Check bert_attention with the external bert_tests harness; it reports "MATCH"
# plus shape/mean/std of the result, as in the captured output below.
bert_tests.test_attention_fn(bert_attention)


attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001894 STD: 0.1235 VALS [0.1045 -0.07578 0.009482 -0.2152 -0.0599 0.08476 -0.2925 -0.02358 -0.1737 0.05641...]


In [26]:
class MultiHeadedSelfAttention(nn.Module):
    """
    Multi-head self-attention assembled from raw_attention_pattern and
    bert_attention, with learned (biased) linear Q/K/V/O projections.

    Query/key heads are attention_dim wide, value heads are per_head_output_dim
    wide, and the output projection maps the concatenated value heads to
    output_dim (hidden_size when not given). Head widths are independent of
    hidden_size, so num_heads need not divide hidden_size.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_size: int,
        attention_dim: int = 64,
        per_head_output_dim: int = 64,
        output_dim: Optional[int] = None,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.attention_dim = attention_dim
        self.per_head_output_dim = per_head_output_dim
        self.output_dim: int = hidden_size if output_dim is None else output_dim

        qk_width = num_heads * attention_dim
        v_width = num_heads * per_head_output_dim

        # Registration order Q, K, V, O is kept fixed (affects parameter
        # initialisation order and state_dict layout).
        self.Q = nn.Linear(hidden_size, qk_width)
        self.K = nn.Linear(hidden_size, qk_width)
        self.V = nn.Linear(hidden_size, v_width)
        self.O = nn.Linear(v_width, self.output_dim)

    def forward(self, input: t.Tensor) -> t.Tensor:
        """
        input: Tensor[batch_size, seq_length, hidden_size]
        return: Tensor[batch_size, seq_length, output_dim]
        """
        scores = raw_attention_pattern(
            token_activations=input,
            num_heads=self.num_heads,
            project_query=self.Q,
            project_key=self.K,
        )
        # Positional order: token_activations, num_heads, attention_pattern,
        # project_value, project_output.
        return bert_attention(input, self.num_heads, scores, self.V, self.O)


# Check the MultiHeadedSelfAttention class (the class itself is passed; the
# harness presumably instantiates it) with bert_tests; it reports "MATCH" plus
# shape/mean/std of the output, as in the captured output below.
bert_tests.test_bert_attention(MultiHeadedSelfAttention)


bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001554 STD: 0.1736 VALS [-0.08316 -0.09165 -0.03188 -0.03013 0.1001 0.09549 -0.1046 0.07742 0.0424 0.05553...]


In [28]:
# Sanity check with a head count that does NOT divide hidden_size (768 / 17 is
# not an integer): this still works because each head's width comes from
# attention_dim / per_head_output_dim (default 64), not hidden_size // num_heads.
mhsa = MultiHeadedSelfAttention(
    num_heads=17,
    hidden_size=768,
)
dummy_input = t.ones((10, 117, 768))
with t.no_grad():  # shape check only; no autograd graph needed
    mhsa_output = mhsa(dummy_input)
# output_dim defaults to hidden_size, so the output shape must match the input.
assert mhsa_output.shape == dummy_input.shape, mhsa_output.shape
mhsa_output.shape


torch.Size([10, 117, 768])