T5 position encoding:
1. pos relative;
2. pos bucket;
3. embedding mapping;



In [1]:
import torch
import torch.nn as nn
import math

def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

    Translate relative position to a bucket number for relative attention. The relative position is defined as
    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
    small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
    positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
    This should allow for more graceful generalization to longer sequences than the model has been trained on

    Args:
        relative_position: an int32 Tensor
        bidirectional: a boolean - whether the attention is bidirectional
        num_buckets: an integer
        max_distance: an integer

    Returns:
        a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
    """
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # now relative_position is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets

relative_attention_bias = nn.Embedding(32, 12)
context_position = torch.arange(20, dtype=torch.long)[:, None]
memory_position = torch.arange(20, dtype=torch.long)[None, :]

print(context_position)
print(memory_position)

relative_position = (memory_position - context_position)
print(relative_position)

relative_position_bucket = _relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=32,
            max_distance=128,
        )
print(relative_position_bucket)


values = relative_attention_bias(relative_position_bucket)

values = values.permute([2, 0, 1]).unsqueeze(0)
print(values.shape)
print(values[0, 0, :, :])


tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14],
        [15],
        [16],
        [17],
        [18],
        [19]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19]])
tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19],
        [ -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
          13,  14,  15,  16,  17,  18],
        [ -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,
          12,  13,  14,  15,  16,  17],
        [ -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,
          11,  12,  13,  14,  15,  16],
        [ -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
          10,  11,  12,  13,  14,  15],
        [ -5,  -4,  -3, 

**AliBi**

In [8]:
context_position = torch.arange(10, dtype=torch.long)[:, None]
memory_position = torch.arange(10, dtype=torch.long)[None, :]

# print(context_position)
# print(memory_position)

relative_position = (context_position - memory_position)
# relative_position = -1 * torch.abs(relative_position)
# print(relative_position)

left_right = (context_position - memory_position)
left_right = torch.where(left_right > 0, -0.5, 2)
# print(left_right)

relative_position = relative_position * left_right

print(relative_position)




tensor([[  0.0000,  -2.0000,  -4.0000,  -6.0000,  -8.0000, -10.0000, -12.0000,
         -14.0000, -16.0000, -18.0000],
        [ -0.5000,   0.0000,  -2.0000,  -4.0000,  -6.0000,  -8.0000, -10.0000,
         -12.0000, -14.0000, -16.0000],
        [ -1.0000,  -0.5000,   0.0000,  -2.0000,  -4.0000,  -6.0000,  -8.0000,
         -10.0000, -12.0000, -14.0000],
        [ -1.5000,  -1.0000,  -0.5000,   0.0000,  -2.0000,  -4.0000,  -6.0000,
          -8.0000, -10.0000, -12.0000],
        [ -2.0000,  -1.5000,  -1.0000,  -0.5000,   0.0000,  -2.0000,  -4.0000,
          -6.0000,  -8.0000, -10.0000],
        [ -2.5000,  -2.0000,  -1.5000,  -1.0000,  -0.5000,   0.0000,  -2.0000,
          -4.0000,  -6.0000,  -8.0000],
        [ -3.0000,  -2.5000,  -2.0000,  -1.5000,  -1.0000,  -0.5000,   0.0000,
          -2.0000,  -4.0000,  -6.0000],
        [ -3.5000,  -3.0000,  -2.5000,  -2.0000,  -1.5000,  -1.0000,  -0.5000,
           0.0000,  -2.0000,  -4.0000],
        [ -4.0000,  -3.5000,  -3.0000,  -2.5000,

In [None]:
def get_slopes(n):
  def get_slopes_power_of_2(n):
      start = (2**(-2**-(math.log2(n)-3)))
      ratio = start
      return [start*ratio**i for i in range(n)]

  if math.log2(n).is_integer():
    print(n)
    r = get_slopes_power_of_2(n)
    print(r)
    return r                  #In the paper, we only train models that have 2^a heads for some a. This function has
  else:                                                 #some good properties that only occur when the input is a power of 2. To maintain that even
      closest_power_of_2 = 2**math.floor(math.log2(n))  #when the number of heads is not a power of 2, we use this workaround.
      return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]

In [None]:
n = 6
slopes = get_slopes(n)
slopes = torch.tensor(slopes).reshape([1, n, 1, 1])
print(slopes)


# relative_position = torch.ones([10, 10])

print(relative_position)

headers_relative_position = slopes * relative_position
print(headers_relative_position.shape)

print(headers_relative_position[0][0])

8
[0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
tensor([[[[0.2500]],

         [[0.0625]],

         [[0.0156]],

         [[0.0039]],

         [[0.5000]],

         [[0.1250]]]])
tensor([[  0.0000,  -2.0000,  -4.0000,  -6.0000,  -8.0000, -10.0000, -12.0000,
         -14.0000, -16.0000, -18.0000],
        [ -0.5000,   0.0000,  -2.0000,  -4.0000,  -6.0000,  -8.0000, -10.0000,
         -12.0000, -14.0000, -16.0000],
        [ -1.0000,  -0.5000,   0.0000,  -2.0000,  -4.0000,  -6.0000,  -8.0000,
         -10.0000, -12.0000, -14.0000],
        [ -1.5000,  -1.0000,  -0.5000,   0.0000,  -2.0000,  -4.0000,  -6.0000,
          -8.0000, -10.0000, -12.0000],
        [ -2.0000,  -1.5000,  -1.0000,  -0.5000,   0.0000,  -2.0000,  -4.0000,
          -6.0000,  -8.0000, -10.0000],
        [ -2.5000,  -2.0000,  -1.5000,  -1.0000,  -0.5000,   0.0000,  -2.0000,
          -4.0000,  -6.0000,  -8.0000],
        [ -3.0000,  -2.5000,  -2.0000,  -1.5000,  -1.0000,  -0.5000,   0.0000,
   

In [None]:
 def get_attention_mask(seq_length):
        seq_ids = torch.arange(seq_length)
        causal_mask = seq_ids[None, :].repeat(1, seq_length, 1) <= seq_ids[None, :, None]
        causal_mask = causal_mask.to(torch.float32).unsqueeze(0)
        causal_mask = (1.0 - causal_mask) * torch.finfo(torch.float32).min
        return causal_mask

get_attention_mask(10)



tensor([[[[-0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-0.0000e+00, -0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -3.4028e+38, -3.4028e+38,
           -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -3.4028e+38,
           -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
           -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
           -0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
      