In [5]:
import collections
import math
import string

import numpy as np
import tensorflow.compat.v2 as tf
# import tensorflow as tf

from keras import constraints
from keras import initializers
from keras import regularizers
# from keras.engine.base_layer import Layer
# from keras.layers import activation
# from keras.layers import core
# from keras.layers import regularization
# from keras.utils import tf_utils

#  einsum projection

In [8]:
# Simulated query input for the projection demo (random values; only the shapes matter here).
Q = tf.random.uniform(shape=[2, 8, 512])  # Shape: [batch_size, seq_length, embedding_dim]

# Projection kernel holding one [embedding_dim, key_dim] slice per attention head.
# Shape: [embedding_dim, num_heads, key_dim]
W = tf.random.uniform(shape=[512, 8, 64])

# Project Q onto all heads with a single einsum call.
# Equation: 'abc,cde->abde'
# Where:
# 'a' represents the batch size
# 'b' represents the sequence length
# 'c' represents the embedding dimension (shared by both operands and absent
#     from the output, so it is summed over)
# 'd' represents the number of attention heads
# 'e' represents the key dimension
Q_proj = tf.einsum('abc,cde->abde', Q, W)

print(Q_proj.shape)  # Expected output shape: [2, 8, 8, 64]

(2, 8, 8, 64)


In [17]:
def _build_proj_equation(free_dims, bound_dims, output_dims):
    # Assign letters to dimensions for the einsum equation
    # Starting letters for free dimensions
    letters = 'abcdefghijklmnopqrstuvwxyz'
    free_letters = letters[:free_dims]
    
    # Letters for bound (contracted) dimensions
    bound_letter = letters[free_dims:free_dims + bound_dims]
    
    # Letters for output dimensions
    output_letters = letters[free_dims + bound_dims:free_dims + bound_dims + output_dims]
    
    # Construct the input part of the equation (before "->")
    input_str = free_letters + bound_letter
    
    # Construct the projection weights part of the equation
    weights_str = bound_letter + output_letters
    
    # Construct the output part of the equation (after "->")
    output_str = free_letters + output_letters
    
    # Combine into full einsum equation string
    equation = f"{input_str},{weights_str}->{output_str}"
    return equation


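For the query projection shown earlier, the three arguments are:

- free_dims = 2 (for batch_size and seq_length, which pass through unchanged)
- bound_dims = 1 (for embedding_dim, the dimension that is contracted with the weights)
- output_dims = 2 (for num_heads and key_dim, the dimensions added by the weights)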

In [18]:
# Build the equation for the query projection: two free axes (batch, sequence),
# one contracted axis (embedding) and two new axes (heads, key_dim).
equation = _build_proj_equation(free_dims=2, bound_dims=1, output_dims=2)
print("Einsum equation:", equation)
# Expected output: "abc,cde->abde"


Einsum equation: abc,cde->abde


### Reference implementation from TensorFlow/Keras `MultiHeadAttention`

In [19]:
_CHR_IDX = string.ascii_lowercase
print(_CHR_IDX)

def _build_proj_equation(free_dims, bound_dims, output_dims):
    """Builds an einsum equation for projections inside multi-head attention."""
    input_str = ""
    kernel_str = ""
    output_str = ""
    bias_axes = ""
    letter_offset = 0
    for i in range(free_dims):
        char = _CHR_IDX[i + letter_offset]
        input_str += char
        output_str += char

    letter_offset += free_dims
    for i in range(bound_dims):
        char = _CHR_IDX[i + letter_offset]
        input_str += char
        kernel_str += char

    letter_offset += bound_dims
    for i in range(output_dims):
        char = _CHR_IDX[i + letter_offset]
        kernel_str += char
        output_str += char
        bias_axes += char
    equation = f"{input_str},{kernel_str}->{output_str}"

    return equation, bias_axes, len(output_str)
# Same projection case as before: batch and sequence axes are free, the
# embedding axis is contracted, and the kernel adds the heads and key_dim axes.
free_dims = 2
bound_dims = 1
output_dims = 2
# Bare call as the last expression, so the notebook displays the returned tuple.
_build_proj_equation(free_dims, bound_dims, output_dims)

abcdefghijklmnopqrstuvwxyz


('abc,cde->abde', 'de', 4)

#  scaled dot product attention 

In [22]:


# Q and K stand in for already-projected query and key tensors, both shaped [2, 8, 10, 64].
# NOTE(review): this rebinds Q from the projection demo, now with the heads axis
# before the sequence axis.
Q = tf.random.uniform(shape=[2, 8, 10, 64])  # Shape: [batch_size, num_heads, seq_length_q, depth]
K = tf.random.uniform(shape=[2, 8, 10, 64])  # Shape: [batch_size, num_heads, seq_length_k, depth]

# Dot product between Q and K^T via einsum: 'd' is summed over, 'b' and 'h'
# are kept as shared batch axes, and 'q'/'k' index query and key positions.
attention_scores = tf.einsum('bhqd,bhkd->bhqk', Q, K)

# Scale by 1/sqrt(depth).
# NOTE(review): depth is hard-coded and must be kept equal to the last axis of Q and K.
depth = 64
scaled_attention_scores = attention_scores / tf.math.sqrt(tf.cast(depth, tf.float32))

print("Scaled attention scores shape:", scaled_attention_scores.shape)
# Expected shape: [2, 8, 10, 10] matching [batch_size, num_heads, seq_length_q, seq_length_k]


Scaled attention scores shape: (2, 8, 10, 10)


In [26]:
def _build_attention_equation(rank_q, rank_k):
    # Correctly identifying each part's role in the einsum equation
    letters = 'abcdefghijklmnopqrstuvwxyz'
    
    # Assume the batch and heads dimensions are the same for Q and K and are the first two dimensions
    base_letters = letters[:2]  # This covers batch (b) and heads (h) dimensions
    
    # The next two letters represent sequence length of Q (q) and sequence length of K (k)
    seq_letter_q = letters[2]  # Third dimension for Q
    seq_letter_k = letters[3]  # Assume next letter for K's sequence length
    
    # The last shared letter represents the depth dimension (d)
    depth_letter = letters[4]  # Shared depth dimension
    
    # Constructing the einsum equation for the dot product attention mechanism
    equation = f"{base_letters}{seq_letter_q}{depth_letter}," \
               f"{base_letters}{seq_letter_k}{depth_letter}->" \
               f"{base_letters}{seq_letter_q}{seq_letter_k}"
    
    return equation


In [27]:
# Q and K in the scaled dot-product demo each have four axes:
# batch, heads, sequence and depth.
rank_q = 4  # Rank for query tensor Q
rank_k = 4  # Rank for key tensor K

einsum_equation = _build_attention_equation(rank_q, rank_k)
print("Einsum equation for attention scores:", einsum_equation)


Einsum equation for attention scores: abce,abde->abcd


### Reference implementation from TensorFlow/Keras `MultiHeadAttention`

In [None]:
def _build_attention_equation(rank, attn_axes):
    """Builds einsum equations for the attention computation.

    Query, key, value inputs after projection are expected to have the shape as:
    `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.
    `bs` and `<non-attention dims>` are treated as `<batch dims>`.
    For sequence data, `<non-attention dims>` may be empty. For more complex
    data such as images or 3D volumes, it can hold spatial dimensions (height,
    width) or other dimensions that are not directly involved in attention.

    The attention operations can be generalized:
    (1) Query-key dot product:
    `(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
    <key attention dims>, num_heads, channels) -> (<batch dims>,
    num_heads, <query attention dims>, <key attention dims>)`
    (2) Combination:
    `(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
    (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch
    dims>, <query attention dims>, num_heads, channels)`

    Args:
        rank: Rank of query, key, value tensors.
        attn_axes: List/tuple of axes, `[-1, rank)`,
            that attention will be applied to.

    Returns:
        A tuple `(dot_product_equation, combine_equation, attn_scores_rank)`:
        the equation for the query-key dot product (its first operand is
        labelled with the source notation, its second with the target
        notation), the equation that combines the scores with the
        source-labelled tensor, and the rank of the attention scores.
        For example, `rank=4, attn_axes=(1,)` yields
        `("aecd,abcd->acbe", "acbe,aecd->abcd", 4)`.
    """
    # The concatenation below needs a tuple; without this conversion a list,
    # which the documented contract allows, raises a TypeError.
    attn_axes = tuple(attn_axes)

    target_notation = _CHR_IDX[:rank]
    # `batch_dims` includes the head dim.
    batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))

    # The source tensor shares the batch/head labels and the channel label
    # with the target; each of its attention axes gets a fresh label.
    letter_offset = rank
    source_notation = ""
    for i in range(rank):
        if i in batch_dims or i == rank - 1:
            source_notation += target_notation[i]
        else:
            source_notation += _CHR_IDX[letter_offset]
            letter_offset += 1

    # Scores are laid out as: batch dims (incl. heads), target attention axes,
    # source attention axes. The channel axis is contracted away.
    product_notation = "".join(
        [target_notation[i] for i in batch_dims]
        + [target_notation[i] for i in attn_axes]
        + [source_notation[i] for i in attn_axes]
    )
    dot_product_equation = (
        f"{source_notation},{target_notation}->{product_notation}"
    )
    attn_scores_rank = len(product_notation)
    combine_equation = (
        f"{product_notation},{source_notation}->{target_notation}"
    )
    return dot_product_equation, combine_equation, attn_scores_rank

In linear algebra, the rank of a matrix is the maximum number of linearly independent column vectors in the matrix or, equivalently, the maximum number of linearly independent row vectors. In other words, it is the dimension of the vector space spanned by its columns or rows. The rank can never exceed the smaller of the number of rows and the number of columns, and it reaches that bound only for a full-rank matrix; in general, it is determined by linear independence:

- **Full rank:** A matrix has full rank if its rank equals the smaller of its number of rows and its number of columns.
- **Rank deficiency:** A matrix is rank-deficient if it does not have full rank, meaning that some of its rows or columns can be expressed as a linear combination of the others.
### Rank in Tensor Operations
In the context of tensor operations and deep learning frameworks like TensorFlow or PyTorch, the term rank is used differently. Here, the rank of a tensor simply refers to the number of dimensions (also called axes) that the tensor has. For example:

- A scalar (a single number) has a rank of 0.
- A vector (a 1D array of numbers) has a rank of 1.
- A matrix (a 2D array of numbers) has a rank of 2.
- A 3D array of numbers has a rank of 3, and so on.
This usage aligns with the notion of an n-dimensional array, where "n" is the rank.

### Clarification for Tensor Shape Arguments
When referring to the rank of query, key, value tensors in the context of building attention mechanisms or neural network layers, we're talking about how many dimensions these tensors have. Each dimension (or axis) of these tensors has a certain size (or length), which represents the extent of the tensor along that dimension. For example, a tensor shape [2, 8, 10, 64] has a rank of 4, with each number representing the size of each dimension:

- 2 in the batch dimension,
- 8 in the heads dimension,
- 10 in the sequence length dimension,
- 64 in the channels (or features) dimension.
This distinction is important to keep in mind when transitioning between mathematical discussions of linear algebra and practical implementations of neural networks and tensor operations.


