# Adding in Position Embeddings

This code is part of a Transformer-based model (like T5) and deals with position embeddings. Position embeddings are used to inject information about the position of tokens (words) in a sequence so that the model can take the order of tokens into account. The goal of this code is to add relative position embeddings to a Transformer’s attention mechanism, which helps the model understand how far apart tokens are from each other.

From Memorizing Transformers paper:

    "Position bias. For dense attention within the local context, we use the T5 relative position bias (Raffel
    et al., 2020). As noted by Dai et al. (2019), adding a global position encoding to each token does not
    work well when processing long documents. We don’t use a position bias for the retrieved memories.
    Experiments on the PG19 dataset (Sun et al., 2021) have shown that relative position does not appear
    to matter at long range, and the T5 relative bias puts all long-range tokens in the same bucket anyway."

From T5 paper:

    "Since self-attention is order-independent (i.e. it is an operation on sets), it is common
    to provide an explicit position signal to the Transformer. While the original Transformer
    used a sinusoidal position signal or learned position embeddings, it has recently become
    more common to use relative position embeddings (Shaw et al., 2018; Huang et al., 2018a).
    Instead of using a fixed embedding for each position, relative position embeddings produce
    a different learned embedding according to the offset between the “key” and “query” being
    compared in the self-attention mechanism. We use a simplified form of position embeddings
    where each “embedding” is simply a scalar that is added to the corresponding logit used
    for computing the attention weights. For efficiency, we also share the position embedding
    parameters across all layers in our model, though within a given layer each attention head
    uses a different learned position embedding. Typically, a fixed number of embeddings are
    learned, each corresponding to a range of possible key-query offsets. In this work, we use 32
    embeddings for all of our models with ranges that increase in size logarithmically up to an
    offset of 128 beyond which we assign all relative positions to the same embedding. Note
    that a given layer is insensitive to relative position beyond 128 tokens, but subsequent layers
    can build a sensitivity to larger offsets by combining local information from previous layers."

In [None]:
# RELATIVE POSITION MATRIX
# tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13],
#         [ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12],
#         [ -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11],
#         [ -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10],
#         [ -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
#         [ -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8],
#         [ -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7],
#         [ -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6],
#         [ -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5],
#         [ -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4],
#         [-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3],
#         [-11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2],
#         [-12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1],
#         [-13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0]])

### Idea

- The relative position bias is added to the attention logits (the QK^T scores) before the softmax.
- Relative positions describe, for each query token, how far away every key token is from it.
- Instead of giving a token that is n positions away its own index n, T5's relative position scheme "buckets" offsets: small offsets each get their own index, while larger offsets share indices.
- First we build the matrix of bucket indices. The indices then look up learned values in an embedding layer (one scalar per bucket per attention head), and these values are added to the attention logits. The position biases are trained along with the rest of the network.
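A minimal sketch of where the bias enters, assuming standard scaled dot-product attention. The shapes (`batch`, `heads`, `seq_len`, `ctx_len`, `head_dim`) and the random `bias` are placeholders for illustration, not values produced by this notebook:

```python
import math
import torch

torch.manual_seed(0)
batch, heads, seq_len, ctx_len, head_dim = 1, 4, 14, 14, 8

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, ctx_len, head_dim)
v = torch.randn(batch, heads, ctx_len, head_dim)

# Placeholder for the relative position bias: one scalar per (head, query, key).
# In this notebook it comes from the bucketed nn.Embedding lookup built below.
bias = torch.randn(1, heads, seq_len, ctx_len)

scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (batch, heads, seq_len, ctx_len)
scores = scores + bias                                  # bias is added to the logits, before softmax
attn = scores.softmax(dim=-1)                           # each query's weights sum to 1
out = attn @ v                                          # (batch, heads, seq_len, head_dim)
```

The bias broadcasts over the batch dimension, so one bias tensor serves every example in the batch.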

### Recipe

- Construct a relative position matrix.
- Map small offsets to buckets exactly, and spread larger offsets logarithmically over the remaining, finite number of buckets. At and beyond a maximum distance (128 in T5, 20 in this notebook's toy example) every offset maps to the last bucket.
- Initialize the embedding weights that the bucket indices will look up.
- Map the bucket-index matrix to these weights.
- Add the resulting matrix to the attention logits during self-attention, so that attention now incorporates the relative positions between tokens.



### Imports and Setup

- numpy: numerical array library (imported, but not used in this section).
- torch: the deep learning framework, used here to create tensors and perform operations.
- torch.nn: contains layers such as nn.Embedding, which stores the position biases.
- einsum: Einstein-summation function for expressing tensor contractions (imported, but not used in this section).
- torch.nn.functional: functional versions of activations, losses, etc. (imported, but not used here).
- math: Python's math library, used here for math.log.

In [1]:
import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
import math

### Set Parameters for Position Embedding

- num_buckets: how many buckets (categories) the relative positions are grouped into. For example, with the settings below, offsets 6 to 10 all share one bucket.
- max_distance: the offset at which the logarithmic buckets run out; every offset at or beyond it shares the last bucket.
- sequence_length and max_context_length: the lengths of the query (input) and key sequences.

In [2]:
num_buckets = 6 # the total number of index buckets we'll use
max_distance = 20 # offsets at or beyond this distance all share the last bucket

sequence_length = 14 # query length / input sequence length
max_context_length = 14 # key length: can be equal to sequence_length or greater if recurrence/memory is concatenated

### Creating the Relative Position Matrix

- q_pos: positions 0 to sequence_length - 1 of the query tokens, reshaped into a column.
- k_pos: positions 0 to max_context_length - 1 of the key tokens.
- rel_pos: k_pos - q_pos, computed by broadcasting the row against the column. Entry [i, j] is the offset of key j relative to query i: negative for keys before the query, positive for keys after it.


In [3]:
q_pos = torch.arange(sequence_length, dtype=torch.long)
print(q_pos)
print(q_pos.shape)

q_pos = q_pos.reshape(q_pos.shape[0], 1)
print(q_pos)
print(q_pos.shape)

k_pos = torch.arange(max_context_length, dtype=torch.long)
print(k_pos)
print(k_pos.shape)

rel_pos = k_pos - q_pos
rel_pos

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])
torch.Size([14])
tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13]])
torch.Size([14, 1])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])
torch.Size([14])


tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13],
        [ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12],
        [ -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11],
        [ -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10],
        [ -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
        [ -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8],
        [ -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7],
        [ -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6],
        [ -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5],
        [ -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3,   4],
        [-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,   3],
        [-11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2],
        [-12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1],
        [-13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0]])
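The same broadcasting works when the key sequence is longer than the query sequence (here 4 queries and 6 keys, sizes chosen for illustration). The second variant, which starts the query positions at `ctx_len - seq_len` so the queries line up with the last keys, is an assumption about how prepended memory might be handled; the notebook itself starts both ranges at 0.

```python
import torch

seq_len, ctx_len = 4, 6

q_pos = torch.arange(seq_len).reshape(-1, 1)   # column: (seq_len, 1)
k_pos = torch.arange(ctx_len)                  # row:    (ctx_len,)
rel_pos = k_pos - q_pos                        # broadcasts to (seq_len, ctx_len)

# Assumed alternative: queries are the last seq_len tokens of the context,
# so the final query sits at offset 0 from the final key.
q_pos_aligned = torch.arange(ctx_len - seq_len, ctx_len).reshape(-1, 1)
rel_pos_aligned = k_pos - q_pos_aligned

print(rel_pos)
print(rel_pos_aligned)
```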

### Handling Negative and Large Values in Positions

- n = -rel_pos: Flips the sign so that keys to the left of the query (earlier tokens) get positive values: n[i, j] = i - j is how many positions key j lies before query i.
- torch.max(n, torch.zeros_like(n)): Sets every negative entry to zero. Negative entries are keys to the right of the query (later tokens), so all of them share offset 0 with the query itself. This is the unidirectional (causal) form of T5 bucketing: it only distinguishes how far back a key is.

In [4]:
n = -rel_pos
n

tensor([[  0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13],
        [  1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12],
        [  2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11],
        [  3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10],
        [  4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9],
        [  5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8],
        [  6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6,  -7],
        [  7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5,  -6],
        [  8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4,  -5],
        [  9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3,  -4],
        [ 10,   9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2,  -3],
        [ 11,  10,   9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1,  -2],
        [ 12,  11,  10,   9,   8,   7,   6,   5,   4,   3,   2,   1,   0,  -1],
        [ 13,  12,  11,  10,   9,   8,   7,   6,   5,   4,   3,   2,   1,   0]])

In [5]:
n = torch.max(n, torch.zeros_like(n))
n

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0,  0],
        [ 8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0,  0],
        [ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0,  0],
        [10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0,  0],
        [11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0,  0],
        [12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0,  0],
        [13, 12, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2,  1,  0]])
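`torch.max(n, torch.zeros_like(n))` is an elementwise clamp at zero, so `clamp(min=0)` is an equivalent, slightly shorter spelling. A small check on a 5 × 5 matrix (size chosen arbitrarily):

```python
import torch

size = 5
rel_pos = torch.arange(size) - torch.arange(size).reshape(-1, 1)  # key pos - query pos

n_max = torch.max(-rel_pos, torch.zeros_like(rel_pos))  # form used in the notebook
n_clamp = (-rel_pos).clamp(min=0)                      # equivalent clamp

# Row i counts how many positions each key lies before query i;
# keys after the query all collapse to 0.
print(n_clamp)
```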

### Handling Small and Large Offsets

- max_exact: The first half of the buckets each hold exactly one offset: offsets 0, 1, ..., max_exact - 1 (here 0, 1, and 2).
- is_small: A boolean mask identifying which relative positions are small enough to be assigned exactly.

In [6]:
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
max_exact

3

In [8]:
is_small = n < max_exact
is_small

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True,
          True,  True,  True,  True],
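        [False, False, False, False, False, False,  True,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True,
          True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True,
          True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False,
          True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False,
         False,  True,  True,  True]])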

The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.

So offsets below max_exact are mapped exactly (offset 0, offset 1, offset 2, ...), but beyond that there are more possible offsets than remaining buckets, so larger offsets are spread logarithmically over the fixed number of buckets. With num_buckets = 6 and max_distance = 20, offsets 0 to 13 map to [0, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5] (the first column of position_bucket_indices below).

### Handling Large Offsets Logarithmically

For large offsets the model uses logarithmic scaling: instead of giving each offset its own bucket, larger distances are grouped together, and each group shares one bucket. The formula takes the log of the offset, normalizes it so that an offset of max_distance maps to num_buckets, and shifts it by max_exact so the logarithmic buckets start right after the exact ones.
- torch.log(n.float() / max_exact): the log of the offset divided by max_exact (0 when n == max_exact).
- math.log(max_distance / max_exact): normalizes the log so that it reaches 1 when n == max_distance.
- torch.min(val_if_large, ...): caps the result at num_buckets - 1, so every offset at or beyond max_distance shares the last bucket. The notebook cells below skip this step because the largest offset here (13) is below max_distance (20); the RelativePosition class at the end applies it.
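To see what the formula does offset by offset, here is a plain-Python version of the same bucketing for a single non-negative offset `n` (the function name `causal_bucket` is hypothetical). It includes the cap at `num_buckets - 1` and, with the notebook's settings, reproduces the first column of `position_bucket_indices` shown further down:

```python
import math

def causal_bucket(n: int, num_buckets: int = 6, max_distance: int = 20) -> int:
    """Bucket for a key that lies n >= 0 positions before the query (hypothetical helper)."""
    max_exact = num_buckets // 2
    if n < max_exact:
        return n  # exact buckets: 0, 1, ..., max_exact - 1
    # 0.0 at n == max_exact, 1.0 at n == max_distance
    frac = math.log(n / max_exact) / math.log(max_distance / max_exact)
    # int() truncates toward zero, like .long() in the tensor version
    bucket = max_exact + int(frac * (num_buckets - max_exact))
    return min(bucket, num_buckets - 1)  # offsets >= max_distance share the last bucket

print([causal_bucket(n) for n in range(14)])
print(causal_bucket(20), causal_bucket(1000))
```

Without the final `min`, an offset of exactly `max_distance` would produce bucket `num_buckets`, one past the end of the embedding table.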

In [9]:
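# Map offsets n >= max_exact onto buckets max_exact .. num_buckets - 1 on a log scale.
# log(n / max_exact) / log(max_distance / max_exact) is 0 at n == max_exact and 1 at n == max_distance,
# so scaling by (num_buckets - max_exact) and adding max_exact spreads those offsets over the
# remaining buckets. Entries with n < max_exact (including log(0) = -inf at n == 0) are
# discarded later by torch.where; the cast to integers happens in a later cell.
val_if_large = max_exact + (
    torch.log(n.float() / max_exact)         # log of the matrix divided by a scalar
    / math.log(max_distance / max_exact)     # normalize: 1.0 at n == max_distance
    * (num_buckets - max_exact)              # scale to the number of log-spaced buckets
)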

In [10]:
val_if_large

tensor([[  -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [4.0961, 3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
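        [4.3399, 4.0961, 3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [4.5510, 4.3399, 4.0961, 3.8078, 3.4549, 3.0000, 2.3588, 1.2627,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [4.7373, 4.5510, 4.3399, 4.0961, 3.8078, 3.4549, 3.0000, 2.3588, 1.2627,
           -inf,   -inf,   -inf,   -inf,   -inf],
        [4.9039, 4.7373, 4.5510, 4.3399, 4.0961, 3.8078, 3.4549, 3.0000, 2.3588,
         1.2627,   -inf,   -inf,   -inf,   -inf],
        [5.0546, 4.9039, 4.7373, 4.5510, 4.3399, 4.0961, 3.8078, 3.4549, 3.0000,
         2.3588, 1.2627,   -inf,   -inf,   -inf],
        [5.1922, 5.0546, 4.9039, 4.7373, 4.5510, 4.3399, 4.0961, 3.8078, 3.4549,
         3.0000, 2.3588, 1.2627,   -inf,   -inf],
        [5.3188, 5.1922, 5.0546, 4.9039, 4.7373, 4.5510, 4.3399, 4.0961, 3.8078,
         3.4549, 3.0000, 2.3588, 1.2627,   -inf]])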

In [11]:
val_if_large = val_if_large.long()
val_if_large

tensor([[-9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808],
        [                   1, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808],
        [                   2,                    1, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854775808,
         -9223372036854775808, -9223372036854775808, -9223372036854
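The huge negative numbers come from `log(0) = -inf` at `n = 0` being cast to an integer. They are harmless here because `torch.where` discards them, but the integer value that `-inf` turns into is not something to rely on. One variant (an alternative, not what this notebook does) clamps the input of the log to at least `max_exact`, so every intermediate value is finite and the final buckets are unchanged:

```python
import math
import torch

num_buckets, max_distance = 6, 20
max_exact = num_buckets // 2
size = 14

rel_pos = torch.arange(size) - torch.arange(size).reshape(-1, 1)
n = (-rel_pos).clamp(min=0)
is_small = n < max_exact

# Clamp only the input to the log: entries with n < max_exact are replaced by
# torch.where below anyway, so changing them here does not affect the result.
n_for_log = n.clamp(min=max_exact).float()
val = max_exact + (
    torch.log(n_for_log / max_exact)
    / math.log(max_distance / max_exact)
    * (num_buckets - max_exact)
)
assert torch.isfinite(val).all()  # no -inf reaches the integer cast
val_if_large = val.long().clamp(max=num_buckets - 1)

buckets = torch.where(is_small, n, val_if_large)
print(buckets[:, 0])
```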

### Creating Position Bucket Indices

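This line chooses, entry by entry, between the exact bucket (n) and the logarithmically scaled bucket (val_if_large): torch.where takes n where is_small is True and val_if_large elsewhere. Every entry with n = 0 is small, so the -inf / -9223372036854775808 values above never reach the result.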

In [12]:
position_bucket_indices = torch.where(is_small, n, val_if_large)
position_bucket_indices

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0],
        [4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0],
        [4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0, 0],
        [5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0, 0],
        [5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0, 0],
        [5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 2, 1, 0]])

### Initializing the Position Embedding Layer

This line creates a learnable embedding table for the position biases, with num_buckets rows (one per bucket) and heads columns, so each bucket holds one scalar bias per attention head. These values are learned together with the rest of the model and used in the attention calculation.

In [13]:
heads = 4
relative_position_bias = nn.Embedding(num_buckets, heads)
relative_position_bias

Embedding(6, 4)

In [14]:
relative_position_bias.weight

Parameter containing:
tensor([[-0.3185,  0.7457, -0.3274, -0.2323],
        [-0.7818, -0.3635, -1.0704,  0.7617],
        [-0.0420,  0.5257,  0.8047,  0.7205],
        [-1.0042, -0.6981, -1.2029,  1.3093],
        [-0.5956, -0.1140,  0.5519,  0.1065],
        [ 0.9892, -0.0077,  0.5691, -0.5393]], requires_grad=True)

Here the relative_position_bias embedding layer maps each entry of position_bucket_indices to its row of the weight table, i.e. a vector of heads scalars (one bias per head). The (14, 14) index matrix therefore becomes a (14, 14, 4) tensor.

In [15]:
relative_position_values = relative_position_bias(position_bucket_indices)
relative_position_values.shape

torch.Size([14, 14, 4])
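An `nn.Embedding` lookup is row indexing into its weight matrix, which makes it easy to check what the layer returns. A small sketch with a toy index matrix (the values of `buckets` are made up for illustration):

```python
import torch
from torch import nn

torch.manual_seed(0)
num_buckets, heads = 6, 4
bias_table = nn.Embedding(num_buckets, heads)

buckets = torch.tensor([[0, 0, 0],
                        [1, 0, 0],
                        [2, 1, 0]])       # (query, key) bucket indices

looked_up = bias_table(buckets)          # (3, 3, heads)
indexed = bias_table.weight[buckets]     # same values via plain indexing

# Entry [i, j] holds one bias per head for bucket buckets[i, j]
print(looked_up.shape)
```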

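- transpose(0, 2): swaps axis 0 (sequence) with axis 2 (heads), turning (sequence, context, heads) into (heads, context, sequence). This puts the heads first, but it also swaps the sequence and context axes. The shape only looks right because sequence_length == max_context_length here, and in the output below row 0 of head 0 holds the biases for column 0 of position_bucket_indices rather than row 0. permute(2, 0, 1) gives the intended (heads, sequence, context) layout.
- unsqueeze(0): adds a batch dimension (models typically expect a batch of data, even if there is only one example).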

In [16]:
# transpose(0, 2): (sequence, context, heads) -> (heads, context, sequence); unsqueeze(0) adds the batch axis.
# permute(2, 0, 1) would give (heads, sequence, context) instead.
relative_position_values = relative_position_values.transpose(0,2).unsqueeze(0)
relative_position_values.shape

torch.Size([1, 4, 14, 14])

In [18]:
relative_position_values

tensor([[[[-0.3185, -0.7818, -0.0420, -1.0042, -1.0042, -1.0042, -0.5956,
           -0.5956, -0.5956, -0.5956, -0.5956,  0.9892,  0.9892,  0.9892],
          [-0.3185, -0.3185, -0.7818, -0.0420, -1.0042, -1.0042, -1.0042,
           -0.5956, -0.5956, -0.5956, -0.5956, -0.5956,  0.9892,  0.9892],
          [-0.3185, -0.3185, -0.3185, -0.7818, -0.0420, -1.0042, -1.0042,
           -1.0042, -0.5956, -0.5956, -0.5956, -0.5956, -0.5956,  0.9892],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.7818, -0.0420, -1.0042,
           -1.0042, -1.0042, -0.5956, -0.5956, -0.5956, -0.5956, -0.5956],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.7818, -0.0420,
           -1.0042, -1.0042, -1.0042, -0.5956, -0.5956, -0.5956, -0.5956],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.7818,
           -0.0420, -1.0042, -1.0042, -1.0042, -0.5956, -0.5956, -0.5956],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185,
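           -0.7818, -0.0420, -1.0042, -1.0042, -1.0042, -0.5956, -0.5956],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185,
           -0.3185, -0.7818, -0.0420, -1.0042, -1.0042, -1.0042, -0.5956],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185,
           -0.3185, -0.3185, -0.7818, -0.0420, -1.0042, -1.0042, -1.0042],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185,
           -0.3185, -0.3185, -0.3185, -0.7818, -0.0420, -1.0042, -1.0042],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185,
           -0.3185, -0.3185, -0.3185, -0.3185, -0.7818, -0.0420, -1.0042],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185,
           -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.7818, -0.0420],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185,
           -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.7818],
          [-0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185,
           -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185, -0.3185]],

         [[ 0.7457, -0.3635,  0.5257, -0.6981, -0.6981, -0.6981, -0.1140,
           -0.1140, -0.1140, -0.1140, -0.1140, -0.0077, -0.0077, -0.0077],
          [ 0.7457,  0.7457, -0.3635,  0.5257, -0.6981, -0.6981, -0.6981,
           -0.1140, -0.1140, -0.1140, -0.1140, -0.1140, -0.0077, -0.0077],
          [ 0.7457,  0.7457,  0.7457, -0.3635,  0.5257, -0.6981, -0.6981,
           -0.6981, -0.1140, -0.1140, -0.1140, -0.1140, -0.1140, -0.0077],
          [ 0.7457,  0.7457,  0.7457,  0.7457, -0.3635,  0.5257, -0.6981,
           -0.6981, -0.6981, -0.1140, -0.1140, -0.1140, -0.1140, -0.1140],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457, -0.3635,  0.5257,
           -0.6981, -0.6981, -0.6981, -0.1140, -0.1140, -0.1140, -0.1140],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457, -0.3635,
            0.5257, -0.6981, -0.6981, -0.6981, -0.1140, -0.1140, -0.1140],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,
           -0.3635,  0.5257, -0.6981, -0.6981, -0.6981, -0.1140, -0.1140],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,
            0.7457, -0.3635,  0.5257, -0.6981, -0.6981, -0.6981, -0.1140],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,
            0.7457,  0.7457, -0.3635,  0.5257, -0.6981, -0.6981, -0.6981],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,
            0.7457,  0.7457,  0.7457, -0.3635,  0.5257, -0.6981, -0.6981],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,
            0.7457,  0.7457,  0.7457,  0.7457, -0.3635,  0.5257, -0.6981],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,
            0.7457,  0.7457,  0.7457,  0.7457,  0.7457, -0.3635,  0.5257],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,
            0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457, -0.3635],
          [ 0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,
            0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457,  0.7457]],

         [[-0.3274, -1.0704,  0.8047, -1.2029, -1.2029, -1.2029,  0.5519,
            0.5519,  0.5519,  0.5519,  0.5519,  0.5691,  0.5691,  0.5691],
          [-0.3274, -0.3274, -1.0704,  0.8047, -1.2029, -1.2029, -1.2029,
            0.5519,  0.5519,  0.5519,  0.5519,  0.5519,  0.5691,  0.5691],
          [-0.3274, -0.3274, -0.3274, -1.0704,  0.8047, -1.2029, -1.2029,
           -1.2029,  0.5519,  0.5519,  0.5519,  0.5519,  0.5519,  0.5691],
          [-0.3274, -0.3274, -0.3274, -0.3274, -1.0704,  0.8047, -1.2029,
           -1.2029, -1.2029,  0.5519,  0.5519,  0.5519,  0.5519,  0.5519],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -1.0704,  0.8047,
           -1.2029, -1.2029, -1.2029,  0.5519,  0.5519,  0.5519,  0.5519],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -1.0704,
            0.8047, -1.2029, -1.2029, -1.2029,  0.5519,  0.5519,  0.5519],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274,
           -1.0704,  0.8047, -1.2029, -1.2029, -1.2029,  0.5519,  0.5519],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274,
           -0.3274, -1.0704,  0.8047, -1.2029, -1.2029, -1.2029,  0.5519],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274,
           -0.3274, -0.3274, -1.0704,  0.8047, -1.2029, -1.2029, -1.2029],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274,
           -0.3274, -0.3274, -0.3274, -1.0704,  0.8047, -1.2029, -1.2029],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274,
           -0.3274, -0.3274, -0.3274, -0.3274, -1.0704,  0.8047, -1.2029],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274,
           -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -1.0704,  0.8047],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274,
           -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -1.0704],
          [-0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274,
           -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274, -0.3274]],

         [[-0.2323,  0.7617,  0.7205,  1.3093,  1.3093,  1.3093,  0.1065,
            0.1065,  0.1065,  0.1065,  0.1065, -0.5393, -0.5393, -0.5393],
          [-0.2323, -0.2323,  0.7617,  0.7205,  1.3093,  1.3093,  1.3093,
            0.1065,  0.1065,  0.1065,  0.1065,  0.1065, -0.5393, -0.5393],
          [-0.2323, -0.2323, -0.2323,  0.7617,  0.7205,  1.3093,  1.3093,
            1.3093,  0.1065,  0.1065,  0.1065,  0.1065,  0.1065, -0.5393],
          [-0.2323, -0.2323, -0.2323, -0.2323,  0.7617,  0.7205,  1.3093,
            1.3093,  1.3093,  0.1065,  0.1065,  0.1065,  0.1065,  0.1065],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323,  0.7617,  0.7205,
            1.3093,  1.3093,  1.3093,  0.1065,  0.1065,  0.1065,  0.1065],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,  0.7617,
            0.7205,  1.3093,  1.3093,  1.3093,  0.1065,  0.1065,  0.1065],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,
            0.7617,  0.7205,  1.3093,  1.3093,  1.3093,  0.1065,  0.1065],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,
           -0.2323,  0.7617,  0.7205,  1.3093,  1.3093,  1.3093,  0.1065],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,
           -0.2323, -0.2323,  0.7617,  0.7205,  1.3093,  1.3093,  1.3093],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,
           -0.2323, -0.2323, -0.2323,  0.7617,  0.7205,  1.3093,  1.3093],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,
           -0.2323, -0.2323, -0.2323, -0.2323,  0.7617,  0.7205,  1.3093],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,
           -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,  0.7617,  0.7205],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,
           -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,  0.7617],
          [-0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323,
           -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323, -0.2323]]]],
       grad_fn=<UnsqueezeBackward0>)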
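`transpose(0, 2)` swaps the sequence and heads axes, while `permute(2, 0, 1)` moves the heads to the front and keeps the (sequence, context) order. A small sketch contrasting the two on a rectangular stand-in tensor (4 queries, 6 keys, 2 heads; the values are just `arange` so the entries are easy to trace):

```python
import torch

seq_len, ctx_len, heads = 4, 6, 2
# Stand-in for the embedding output: shape (sequence, context, heads)
values = torch.arange(seq_len * ctx_len * heads).reshape(seq_len, ctx_len, heads)

swapped = values.transpose(0, 2).unsqueeze(0)    # (1, heads, context, sequence)
permuted = values.permute(2, 0, 1).unsqueeze(0)  # (1, heads, sequence, context)

print(swapped.shape)   # torch.Size([1, 2, 6, 4])
print(permuted.shape)  # torch.Size([1, 2, 4, 6])

# permuted[0, h, i, j] is the bias of head h for query i and key j
assert torch.equal(permuted[0, 1, 2, 5], values[2, 5, 1])
```

With unequal lengths, only the permuted tensor broadcasts against attention logits of shape (batch, heads, sequence, context).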


### Creating the RelativePosition Class
- RelativePosition inherits from nn.Module and is responsible for creating and managing the relative position biases within the model.
- The constructor stores rp_scale, num_buckets, and rp_max_distance, and creates an nn.Embedding(num_buckets, heads) table.

### Computing Position Buckets in the forward Method
- sequence_pos: a column of query positions.
- context_pos: a row of context (key) positions.
- rel_pos: the relative positions between the query and context sequences (context positions minus query positions).

### Computing the Final Relative Position Embedding
- rp_values: the position biases looked up for position_bucket_indices.
- rp_values is then rearranged from (sequence, context, heads) to (heads, sequence, context) to match the attention mechanism.
- rp_values.unsqueeze(0): adds a batch dimension.
- return rp_values * self.scale: the biases are scaled by rp_scale and returned.

The **RelativePosition** module encodes positional relationships between tokens in a sequence for use in the attention mechanism of a transformer. This is especially useful when the model must handle varying sequence lengths: distances are represented with a small, fixed number of learned parameters (num_buckets × heads) instead of one parameter per possible position or offset.

### Key Ideas Behind Relative Position Embeddings

1. **Relative Positions**:
   - Instead of assigning each position in a sequence an absolute value (e.g., 1, 2, 3, ...), relative position embeddings focus on the distance between tokens.
   - For example, if you're at position `i`, the token at position `j` has a relative position of `j - i`.

2. **Position Bucketing**:
   - Giving every possible relative distance its own parameter does not scale to long sequences, so relative distances are **grouped into buckets**.
   - Small distances each get their own bucket; larger distances share buckets whose width grows logarithmically. With the class defaults below (`num_buckets = 32`, `rp_max_distance = 128`), distances 0 to 15 get one bucket each, distances 16 to 127 share the remaining 16 buckets, and every distance of 128 or more is put into the last bucket.

3. **Relative Position Embeddings**:
   - Each bucket is assigned one learned scalar per attention head (a row of an `nn.Embedding(num_buckets, heads)` table).
   - These scalars are added as bias terms to the attention logits, giving the model information about how far apart tokens are.

---

### **Code Explanation**

#### Key Components:

1. **Initialization**:
   ```python
   self.relative_attention_embedding = nn.Embedding(num_buckets, heads)
   ```
   - An embedding matrix of shape `(num_buckets, heads)` is created. Each bucket gets one learned scalar bias per attention head.

2. **Position Bucketing**:
   ```python
   def relative_position_bucket(self, relative_position_matrix):
       n = -relative_position_matrix
       n = torch.max(n, torch.zeros_like(n))
       ...
   ```
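   - Relative positions (`rel_pos`) are mapped into buckets. Keys after the query (positive `rel_pos`) are clamped to offset 0, small backward distances (`< max_exact`) keep their own bucket, and larger ones are grouped logarithmically and capped at `num_buckets - 1`.

   Example (class defaults `num_buckets = 32`, `rp_max_distance = 128`):
   - `rel_pos = [-50, -10, -2, -1, 0, 1, 2]`
   - maps to buckets `[24, 10, 2, 1, 0, 0, 0]`.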

3. **Forward Pass**:
   ```python
   sequence_pos = torch.arange(sequence_length, dtype=torch.long).reshape(-1, 1)
   context_pos = torch.arange(max_context_length, dtype=torch.long)
   rel_pos = context_pos - sequence_pos
   ```
   - A relative position matrix (`rel_pos`) is calculated for all token pairs in the sequence.

   ```python
   position_bucket_indices = self.relative_position_bucket(rel_pos)
   rp_values = self.relative_attention_embedding(position_bucket_indices)
   ```
   - The relative positions are converted to bucket indices and replaced with their corresponding embeddings.

4. **Output**:
   ```python
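   rp_values = rp_values.permute(2, 0, 1).unsqueeze(0)
   return rp_values * self.scale
   ```
   - The embeddings are rearranged from `(sequence, context, heads)` to `(1, heads, sequence, context)` for use in the attention mechanism, and a scaling factor (`self.scale`) adjusts their magnitude. `permute(2, 0, 1)` is used rather than `transpose(0, 2)`, which would also swap the sequence and context axes.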
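A standalone, vectorized version of the bucketing (written as a free function, `bucket_matrix`, which is a hypothetical name) can be used to check bucket assignments directly. It follows `relative_position_bucket` except that the input of the log is clamped to at least `max_exact`, an assumption made here so that no `-inf` is cast to an integer; the affected entries are small offsets, which `torch.where` replaces anyway. Note how keys after the query (positive offsets) all land in bucket 0:

```python
import math
import torch

def bucket_matrix(rel_pos: torch.Tensor, num_buckets: int = 32, max_distance: int = 128) -> torch.Tensor:
    """Causal T5-style bucketing of key-minus-query offsets (hypothetical helper)."""
    n = (-rel_pos).clamp(min=0)            # how far back each key is; later keys -> 0
    max_exact = num_buckets // 2
    is_small = n < max_exact
    n_for_log = n.clamp(min=max_exact).float()  # keeps every log finite
    val_if_large = max_exact + (
        torch.log(n_for_log / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
    return torch.where(is_small, n, val_if_large)

rel_pos = torch.tensor([-50, -10, -2, -1, 0, 1, 2])
print(bucket_matrix(rel_pos).tolist())
```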

---

### **Example**

#### Input:
- Sequence length = 4
- Context length = 6
- Relative positions:
  ```
  [[0, 1, 2, 3, 4, 5],
   [-1, 0, 1, 2, 3, 4],
   [-2, -1, 0, 1, 2, 3],
   ...
  ]
  ```

#### Bucketing:
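- Exact for small distances (keys after the query are clamped to offset 0):
  - `[-2, -1, 0, 1, 2]` → `[2, 1, 0, 0, 0]`
- Logarithmic for larger distances (with the notebook's `num_buckets = 6`, `max_distance = 20`, as in `position_bucket_indices` above):
  - offsets `-3` to `-5` → bucket `3`, `-6` to `-10` → bucket `4`, `-11` to `-13` → bucket `5`.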

#### Attention Impact:
The embeddings guide the model to:
- Distinguish nearby tokens precisely, since each small offset has its own bias (e.g., neighboring words).
- Treat distant tokens coarsely, since many large offsets share one bias, which keeps the number of position parameters fixed for long sequences.

---

### Why Use Relative Position Embeddings?

1. **Efficiency for Long Sequences**:
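   - The number of learned position parameters is `num_buckets × heads`, independent of sequence length; buckets and logarithmic scaling let a small table cover long offsets. (The bias matrix itself is still `sequence × context` per head.)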

2. **Better Performance on Variable-Length Sequences**:
   - Unlike absolute embeddings, relative biases depend only on the offset between tokens, so the same parameters apply at every position; this tends to help when sequence lengths differ between training and inference.

3. **Context Awareness**:
   - Tokens are enriched with relative positional context, improving model performance in tasks like language modeling and translation.

In [17]:
class RelativePosition(nn.Module):
  def __init__(
      self,
      rp_scale,
      num_buckets = 32,
      rp_max_distance = 128,
      heads = 8
  ):
      super().__init__()
      self.scale = rp_scale
      self.num_buckets = num_buckets
      self.rp_max_distance = rp_max_distance
      self.relative_attention_embedding = nn.Embedding(num_buckets, heads)

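  def relative_position_bucket(self, relative_position_matrix):
      # Flip the sign so keys before the query get positive offsets,
      # then clamp: keys after the query all map to offset 0
      n = -relative_position_matrix
      n = torch.max(n, torch.zeros_like(n))

      # Half of the buckets hold one exact offset each
      max_exact = self.num_buckets // 2

      is_small = n < max_exact
      # The other half cover max_exact .. rp_max_distance on a log scale
      # (n == 0 gives log(0) = -inf here, but torch.where below discards it)
      val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(self.rp_max_distance / max_exact) * (self.num_buckets - max_exact)).long()
      # Offsets at or beyond rp_max_distance share the last bucket
      val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, self.num_buckets - 1))

      return torch.where(is_small, n, val_if_large)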

  def forward(self, sequence_length, max_context_length):
      # Build positions on the same device as the embedding weights
      device = self.relative_attention_embedding.weight.device
      sequence_pos = torch.arange(sequence_length, dtype=torch.long, device=device)
      context_pos = torch.arange(max_context_length, dtype=torch.long, device=device)
      sequence_pos = sequence_pos.reshape(sequence_pos.shape[0], 1)
      rel_pos = context_pos - sequence_pos  # (sequence, context)

      position_bucket_indices = self.relative_position_bucket(rel_pos)

      rp_values = self.relative_attention_embedding(position_bucket_indices)
      # Rearrange (sequence, context, heads) -> (1, heads, sequence, context)
      rp_values = rp_values.permute(2, 0, 1)
      rp_values = rp_values.unsqueeze(0)
      return rp_values * self.scale

This part of the code creates relative position biases that are added to the attention logits (the query-key scores) in the attention mechanism.

Position buckets are used to group relative positions into a fixed number of categories (buckets).

Logarithmic scaling is applied to large distances to ensure the model can handle long-range dependencies efficiently.

The RelativePosition class calculates and returns the relative position embeddings, which are then used in the attention mechanism of the model.
