In [1]:
from math import sqrt, log
import torch
import torch.nn as nn
import torch.functional as F
import numpy as np

import sys
ROOT_DIR = "../"
sys.path.append(ROOT_DIR)

#### torch.einsum
[documentation](https://pytorch.org/docs/stable/generated/torch.einsum.html)

In [2]:
BATCH_SIZE = 1
seq_len = 2          # sequence length (here we assume the same for input/output, but it can be different)
heads = 3            # number of heads (multi-head attention)
d_model = 12         # Embedding dimension (?)
d_query_key = 9      # Dimension for the query and key vectors
d_values = 10        # Dimension for the value vector

x = torch.ones(BATCH_SIZE,seq_len,d_model) # d_model 

## Multihead detail

In [3]:
# single-head VS multi-head
x = torch.ones(BATCH_SIZE,seq_len,d_model)
one_projection = nn.Linear(d_model, d_query_key, bias=False)
multi_projection = nn.Linear(d_model, d_query_key * heads, bias=False)

# set the same weights
same_weights = multi_projection.weight[:d_query_key,:]
one_projection.weight = nn.Parameter(same_weights)

if (one_projection.weight-same_weights).all() == 0:
    print("The weights are the same")

The weights are the same


### Projection to Query-Key dimension
The input is projected from its embedding dimension `embed_dim` to the query-key dimension `d_queries_keys`. In case of multi-head attention, the latter is multiplied by the number of heads.

In [4]:
one_x = one_projection(x)
multi_x = multi_projection(x)
print(f"x shape: {x.shape}, single head shape: {one_x.shape}, multi head shape: {multi_x.shape}")

x shape: torch.Size([1, 2, 12]), single head shape: torch.Size([1, 2, 9]), multi head shape: torch.Size([1, 2, 27])


In [5]:
one_x

tensor([[[ 0.0128,  0.2430, -0.7559,  0.4436, -0.3671,  1.3778, -0.5534,
           1.1814,  0.5273],
         [ 0.0128,  0.2430, -0.7559,  0.4436, -0.3671,  1.3778, -0.5534,
           1.1814,  0.5273]]], grad_fn=<UnsafeViewBackward0>)

In [6]:
multi_x

tensor([[[ 0.0128,  0.2430, -0.7559,  0.4436, -0.3671,  1.3778, -0.5534,
           1.1814,  0.5273, -0.2040,  0.7508, -0.1226,  0.3260, -0.9922,
          -0.1010, -0.4803, -0.0498, -1.0036,  0.5103, -0.1023,  0.3187,
          -0.2637,  0.0280,  0.0977,  0.3184,  0.2986, -0.9544],
         [ 0.0128,  0.2430, -0.7559,  0.4436, -0.3671,  1.3778, -0.5534,
           1.1814,  0.5273, -0.2040,  0.7508, -0.1226,  0.3260, -0.9922,
          -0.1010, -0.4803, -0.0498, -1.0036,  0.5103, -0.1023,  0.3187,
          -0.2637,  0.0280,  0.0977,  0.3184,  0.2986, -0.9544]]],
       grad_fn=<UnsafeViewBackward0>)

Check that the first `d_queries_keys` are the same for both tensors. The successive heads are concatenated along the last dimension.

*Q: How do I know that different heads learn different patterns?*

# Implementation
$$Attention = Softmax\Bigg(\frac{Q\cdot K^T}{\sqrt{d}}\Bigg)V$$
### 1. Query and Key

In [7]:
# input shape (BATCH_SIZE,seq_len,d_model) 
# not sure it's correct...where are the features? It's all in the embedding?

x = torch.rand(BATCH_SIZE,seq_len,d_model) # d_model = embed_dim (?)

query_projection = nn.Linear(d_model, d_query_key * heads)
key_projection = nn.Linear(d_model, d_query_key * heads)

B, L, _ = x.shape
query = query_projection(x)
key = key_projection(x)

query.shape

torch.Size([1, 2, 27])

### 2. Separate heads
So far the heads are all concatenated along the last dimension. Using [view](https://pytorch.org/docs/stable/generated/torch.Tensor.view.html), we can separate separate the heads, each one of them with its `d_query_key` components.

In [8]:
query_view = query.view((B, L, heads, -1))
key_view = key.view((B, L, heads, -1))

print(f"Old shape: BATCH_SIZE ({BATCH_SIZE}), seq_len ({seq_len}), d_query_key*heads ({d_query_key*heads}) --> {query.shape}")
print(f"New shape: BATCH_SIZE ({BATCH_SIZE}), seq_len ({seq_len}), heads ({heads}), d_query_key ({d_query_key*heads}) --> {query_view.shape}")

Old shape: BATCH_SIZE (1), seq_len (2), d_query_key*heads (27) --> torch.Size([1, 2, 27])
New shape: BATCH_SIZE (1), seq_len (2), heads (3), d_query_key (27) --> torch.Size([1, 2, 3, 9])


### 3. Scaling
The "scaled" dot-product attention is scaled by $\sqrt{d_{query, key}}$

In [9]:
scale = 1.0 / sqrt(d_query_key)

### 4. Attention scores
This is the center of the dot-product attention, where the similarity between query and keys are calculated. We obtain the **attention matrix**, which gives insights on the degree of correlation between query and key (*cross-attention* input/output, *self-attention* input/input).

*Q: does it works also for different input-output sizes? Yes because the dot product is taken along the key-queries dimension*

In [10]:
scores = torch.einsum("blhe,bshe->bhls", query_view, key_view)
print(f"Score shapes: BATCH_SIZE ({BATCH_SIZE}), heads ({heads}), seq_len ({seq_len}), seq_len ({seq_len}) --> {scores.shape}")

Score shapes: BATCH_SIZE (1), heads (3), seq_len (2), seq_len (2) --> torch.Size([1, 3, 2, 2])


### 5. Scale and apply Softmax
Softmax converts velues into probability distribution, since it normalize each columns into values $\in [0,1]$. It is apply to the last dimension of our tensor, since it correspond to the columns of the attention sub-matrix.

> This step is the reason why previous attempts to reduce the square-complexity of attention fail.

In [11]:
A = torch.softmax(scale * scores, dim=-1)
print(f"This step doesn't change the shape: {A.shape}")

This step doesn't change the shape: torch.Size([1, 3, 2, 2])


### 6. Multiply by the values

The multiplication by values represent the feedback for the embedding, based on what was calculated by the attention. It should be considered as a unique linear map, from the current embedding to the new suggestions from the attention.

`d_model` $\rightarrow$ `d_model`

But, in practice, it's implemented in two steps

`d_model`$\rightarrow$`d_values`$\times$`heads` $\rightarrow$ `d_values`$\times$`heads`$\rightarrow$ `d_model`

In [12]:
# just for comparison 
# query_projection.shape = (d_model, d_query_key * heads)

value_projection = nn.Linear(d_model, d_values *heads)
out_projection = nn.Linear(d_values*heads, d_model)

values = value_projection(x)
values.shape

torch.Size([1, 2, 30])

#### 6.1 Separate heads
As done before for query and key...

In [13]:
# N.B. In case query and keys have different seq_lengths, 
# it has to be set = to keys (columns)

values_view = values.view(BATCH_SIZE, seq_len, heads, -1)
values_view.shape

torch.Size([1, 2, 3, 10])

#### 6.2 Multiply with einsum

In [14]:
V = torch.einsum("bhls,bshd->blhd", A, values_view)
V = V.contiguous()
V.shape

torch.Size([1, 2, 3, 10])

### 7. Concatenate the heads back

In [15]:
V_view = V.view(BATCH_SIZE,seq_len,-1)
V_view.shape

torch.Size([1, 2, 30])

### 8. Multiply by the Output
This last multiplication, maps the output back to the initial shape. 

In [16]:
out = out_projection(V_view)
out.shape

torch.Size([1, 2, 12])

# Kernel Attention
$$A(x_q)=\sum_{x_k} \frac{k(x_q,x_k)}{Z}v(x_k)$$
With the normalization constant $Z = \sum_{x_k} k(x_q,x_k')$


- The kernel $k$ from *Tsai (i) Positional Embeddings*
$$k(x_q,x_k)=k_\text{exp}(f_q+t_q,f_k+t_k)$$ 

where $k_\text{exp}$ is the exponential kernel 

$k_\text{exp}(q,k)=\exp{\bigg(\frac{\langle qW_q,kW_k\rangle}{\sqrt{d_k}}\bigg)}$

- The value function $v(x_k)=(f_k+t_k)W_v$ is pretty much tha same as the dot-product attention


In [17]:
def k_exp(Q,K,d_k):
    return torch.exp(torch.einsum("ble,bke->blk", Q, K)/sqrt(d_k))
    
q_size = 2
k_size = 4
v_size = q_size
BATCH_SIZE = 3

x_q = torch.randn(BATCH_SIZE,q_size, d_model)
x_k = torch.randn(BATCH_SIZE,k_size, d_model)

W_k = torch.randn(BATCH_SIZE,d_model, d_query_key)
W_q = torch.randn(BATCH_SIZE,d_model, d_query_key)
W_v = torch.randn(BATCH_SIZE,d_model, d_model)


K = x_k@W_k
Q = x_q@W_q
V = x_k@W_v  

ker = k_exp(Q,K,d_query_key)

# normalization constant
Z = ker.sum(axis=2).unsqueeze(-1).expand(-1,-1,k_size)

k_cross_att = (ker/Z)
att_out = k_cross_att@V
k_cross_att

tensor([[[1.8513e-04, 4.1025e-01, 4.0401e-07, 5.8956e-01],
         [3.0236e-08, 9.6120e-07, 9.9875e-01, 1.2447e-03]],

        [[4.4123e-10, 2.2233e-04, 2.1314e-11, 9.9978e-01],
         [9.9636e-01, 1.7948e-05, 7.2879e-05, 3.5478e-03]],

        [[3.0797e-01, 9.8793e-09, 1.2706e-04, 6.9190e-01],
         [4.5116e-02, 1.2714e-10, 1.0660e-09, 9.5488e-01]]])

In [18]:
print("Shapes:\n")
print(f"x_keys: {x_k.shape}")
print(f"x_queries: {x_q.shape}")
print(f"Keys: {K.shape}")
print(f"Queries: {Q.shape}")
print(f"kernel: {ker.shape}")
print(f"Cross_attention_coefficients: {k_cross_att.shape}")
print(f"Attention output: {att_out.shape}")

Shapes:

x_keys: torch.Size([3, 4, 12])
x_queries: torch.Size([3, 2, 12])
Keys: torch.Size([3, 4, 9])
Queries: torch.Size([3, 2, 9])
kernel: torch.Size([3, 2, 4])
Cross_attention_coefficients: torch.Size([3, 2, 4])
Attention output: torch.Size([3, 2, 12])


In [19]:
# Normalization
k_cross_att.sum(axis=2)

tensor([[1.0000, 1.0000],
        [1.0000, 1.0000],
        [1.0000, 1.0000]])

In [20]:
k_cross_att.shape

torch.Size([3, 2, 4])

In [21]:
mask = torch.rand(k_size)>0.5
mask = mask.unsqueeze(0).unsqueeze(0)
mask=mask.expand_as(k_cross_att)
mask

tensor([[[False,  True, False, False],
         [False,  True, False, False]],

        [[False,  True, False, False],
         [False,  True, False, False]],

        [[False,  True, False, False],
         [False,  True, False, False]]])

# One to One Head

In [79]:
d_emb1 = 3
d_emb2 = 5

d_emb_list = [d_emb1,d_emb2]

X_emb1 = torch.randn(BATCH_SIZE,seq_len,d_emb1)
X_emb1[0,0,0] = float('nan')
X_emb2 = torch.zeros(BATCH_SIZE,seq_len,d_emb2).fill_(float('nan'))
X_emb1

tensor([[[    nan,  0.1039, -0.3482],
         [-1.4673,  0.8101, -0.7187]],

        [[-0.6928, -0.4300,  0.3263],
         [ 0.4130, -0.9167,  1.4718]],

        [[ 0.4810, -0.0921,  0.2670],
         [ 0.5295, -1.0155,  1.2435]]])

In [81]:

X_emb = torch.cat((X_emb1,X_emb2), dim=-1)
X_emb_split = torch.split(X_emb,d_emb_list,dim=-1)
X_emb_split[0]

tensor([[[    nan,  0.1039, -0.3482],
         [-1.4673,  0.8101, -0.7187]],

        [[-0.6928, -0.4300,  0.3263],
         [ 0.4130, -0.9167,  1.4718]],

        [[ 0.4810, -0.0921,  0.2670],
         [ 0.5295, -1.0155,  1.2435]]])

In [88]:
query_projections = nn.ModuleList([nn.Linear(d_emb, d_query_key)for d_emb in d_emb_list])
query_projections


ModuleList(
  (0): Linear(in_features=3, out_features=9, bias=True)
  (1): Linear(in_features=5, out_features=9, bias=True)
)

In [91]:
queries = [proj(emb) for proj, emb in zip(query_projections, X_emb_split)]
keys = [proj(emb) for proj, emb in zip(query_projections, X_emb_split)]
queries

[tensor([[[    nan,     nan,     nan,     nan,     nan,     nan,     nan,
               nan,     nan],
          [-0.4774, -0.7712,  0.7957, -0.1341, -0.2469, -0.9042, -0.4872,
           -0.7373,  0.3509]],
 
         [[ 0.2119, -0.1436,  0.6397, -0.0292,  0.3975, -0.4922,  0.1334,
           -0.8205, -0.0744],
          [ 0.6324,  0.1488,  0.1885,  0.6374,  0.9332,  0.2457,  0.3951,
           -0.3664, -0.1188]],
 
         [[ 0.7270, -0.1419,  0.1202,  0.6506,  0.1419,  0.1347, -0.2498,
           -0.0755,  0.3711],
          [ 0.8513,  0.2342,  0.1881,  0.5598,  0.8071,  0.2293,  0.3608,
           -0.3808, -0.1034]]], grad_fn=<AddBackward0>),
 tensor([[[nan, nan, nan, nan, nan, nan, nan, nan, nan],
          [nan, nan, nan, nan, nan, nan, nan, nan, nan]],
 
         [[nan, nan, nan, nan, nan, nan, nan, nan, nan],
          [nan, nan, nan, nan, nan, nan, nan, nan, nan]],
 
         [[nan, nan, nan, nan, nan, nan, nan, nan, nan],
          [nan, nan, nan, nan, nan, nan, nan, nan, n

In [109]:
scores = [torch.nan_to_num(torch.matmul(q, k.transpose(-2, -1))) for q,k in zip(queries,keys)]
sum(scores)

tensor([[[0.0000, 0.0000],
         [0.0000, 3.2566]],

        [[1.5726, 0.8269],
         [0.8269, 2.0997]],

        [[1.2305, 1.0181],
         [1.0181, 2.1182]]], grad_fn=<AddBackward0>)

In [93]:
mask = torch.isnan(scores)
scores.masked_fill_(mask,-torch.inf)
scores

TypeError: isnan(): argument 'input' (position 1) must be Tensor, not list

# Attention masks

In [114]:
from typing import Any


class UniformAttentionMask(nn.Module):
    def __init__(self) -> None:
        super(UniformAttentionMask,self).__init__()
    
    def forward(self, attention_scores:torch.Tensor, mask:torch.Tensor):
        """
        Applies masking to the attention scores.
        
        Args:
        - attention_scores: Tensor of shape (batch_size, N_queries, N_keys).
        - mask: Boolean tensor of shape (N_keys), where False means the corresponding key should be masked (zeroed).
        
        Returns:
        - masked_attention_scores: Tensor with masked attention scores.
        """

        assert attention_scores.shape[-1] == len(mask), AssertionError(f"Got mask of length {len(mask)}, expected {attention_scores.shape[-1]}")
        
        # Ensure the mask is a torch tensor
        if not isinstance(mask, torch.Tensor):
            mask = torch.tensor(mask)
        
        # Ensure the mask is on the same device as the attention scores
        if mask.device != attention_scores.device:
            mask = mask.to(attention_scores.device)
        
        # Convert boolean mask to float and expand it to match attention_scores
        mask = mask.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, N_keys)
        mask=mask.expand_as(attention_scores)
        # Apply the mask to zero out the attention scores where mask is False
        
        return attention_scores.masked_fill(mask, -torch.inf)
    
uniform_mask = UniformAttentionMask()

mask = np.random.rand(k_size)>0.5

uniform_mask(k_cross_att,mask)

tensor([[[1.8513e-04,       -inf, 4.0401e-07, 5.8956e-01],
         [3.0236e-08,       -inf, 9.9875e-01, 1.2447e-03]],

        [[4.4123e-10,       -inf, 2.1314e-11, 9.9978e-01],
         [9.9636e-01,       -inf, 7.2879e-05, 3.5478e-03]],

        [[3.0797e-01,       -inf, 1.2706e-04, 6.9190e-01],
         [4.5116e-02,       -inf, 1.0660e-09, 9.5488e-01]]])

array([False,  True, False, False])

In [23]:
a = {"a":0}
b = {"b":1}

a | b


{'a': 0, 'b': 1}

In [24]:
from prochain_transformer.modules.extra_layers import UniformAttentionMask

In [25]:
mask_layer = UniformAttentionMask()
scores = torch.einsum("blhe,bshe->bhls", query_view, key_view)
A = torch.nan_to_num(torch.softmax(scores, dim=-1))
A.shape

torch.Size([1, 3, 2, 2])