In [1]:
import jax
import jax.numpy as jnp
from jax import random, jit
from jax.scipy.special import logsumexp
from tqdm import tqdm

from attention import softmax, scaled_dot_product_attention

import matplotlib.pyplot as plt



# JIT compile the attention function for faster execution
fast_attention = jit(scaled_dot_product_attention)

# Generate random data
key_len = 512
feature_dim = 256
rng = random.PRNGKey(0)
Q = random.normal(rng, (key_len, feature_dim))
K = random.normal(rng, (key_len, feature_dim))
V = random.normal(rng, (key_len, feature_dim))

# Timing loop
import time

num_iterations = 1000
# Warm up 
_ = fast_attention(Q, K, V)
start_time = time.time()

for _ in tqdm(range(num_iterations)):
    _ = fast_attention(Q, K, V)

end_time = time.time()
print(f"Average time per attention computation: {(end_time - start_time) / num_iterations:.6f} seconds")


100%|██████████| 1000/1000 [00:01<00:00, 693.55it/s]

Average time per attention computation: 0.001457 seconds





In [None]:
# Compute the attention and get the scores
_, attention_weights, scores = scaled_dot_product_attention(Q, K, V, return_scores=True)

# Plot the attention scores
plt.figure(figsize=(10, 10))
plt.imshow(scores, cmap="Blues", aspect='auto')
plt.colorbar(label='Attention Scores', orientation='vertical')
plt.title('Attention Scores Visualization')
plt.xlabel('Keys')
plt.ylabel('Queries')
plt.tight_layout()
plt.show()

In [2]:
from attention import sliding_window_attention

# JIT compile the sliding window attention function
fast_sliding_window_attention = jit(sliding_window_attention)

# Warm up
_ = fast_sliding_window_attention(Q, K, V, window_size=64)

# Timing the sliding window attention
start_time = time.time()
for _ in tqdm(range(num_iterations)):
    _ = fast_sliding_window_attention(Q, K, V, window_size=64)
end_time = time.time()
print(f"Average time per sliding window attention computation: {(end_time - start_time) / num_iterations:.6f} seconds")


100%|██████████| 1000/1000 [00:01<00:00, 767.14it/s]

Average time per sliding window attention computation: 0.001305 seconds





In [None]:
# Compute the attention and get the scores
_, scores = sliding_window_attention(Q, K, V, window_size=64)

# Plot the attention scores
plt.figure(figsize=(10, 10))
plt.imshow(scores, cmap="Blues", aspect='auto')
plt.colorbar(label='Attention Scores', orientation='vertical')
plt.title('Attention Scores Visualization')
plt.xlabel('Keys')
plt.ylabel('Queries')
plt.tight_layout()
plt.show()

In [11]:
def split_heads(x, num_heads):
    # x.shape = (seq_len, d_model)
    # After reshaping: (seq_len, num_heads, depth)
    return jnp.reshape(x, (x.shape[0], num_heads, -1))

def multi_head_attention(Q, K, V, num_heads, attention_fn, window_size=None):
    d_model = Q.shape[-1]
    # print(f"d_model type: {type(d_model)}, value: {d_model}")
    # print(f"num_heads type: {type(num_heads)}, value: {num_heads}")
    depth = d_model // num_heads
    
    # Split into multiple heads
    Q_heads = split_heads(Q, num_heads)
    K_heads = split_heads(K, num_heads)
    V_heads = split_heads(V, num_heads)
    
    # Apply attention to each head
    if attention_fn == sliding_window_attention:
        assert window_size is not None, "window_size must be provided for sliding window attention"
        outputs = [attention_fn(Q_heads[:, h], K_heads[:, h], V_heads[:, h], window_size) for h in range(num_heads)]
    else:
        outputs = [attention_fn(Q_heads[:, h], K_heads[:, h], V_heads[:, h]) for h in range(num_heads)]
    
    # Concatenate and project
    concatenated = jnp.concatenate(outputs, axis=-1)
    # Typically, you'd have an additional linear layer here, but for simplicity, we'll skip it
    return concatenated

# Update JIT compiled functions to support multi-head
fast_multi_head_attention = jit(multi_head_attention, static_argnums=(4, 3))

# Time multi-head vanilla attention
num_heads = 8
start_time = time.time()
for _ in tqdm(range(num_iterations)):
    _ = fast_multi_head_attention(Q=Q, K=K, V=V, num_heads=num_heads, attention_fn=scaled_dot_product_attention)
end_time = time.time()
print(f"Average time per multi-head vanilla attention: {(end_time - start_time) / num_iterations:.6f} seconds")



100%|██████████| 1000/1000 [00:07<00:00, 126.07it/s]

Average time per multi-head vanilla attention: 0.007934 seconds





In [15]:
def group_heads(x, num_groups):
    # x.shape = (seq_len, d_model)
    # After reshaping: (seq_len, num_groups, depth * heads_per_group)
    return jnp.reshape(x, (x.shape[0], num_groups, -1))

def grouped_query_attention(Q, K, V, num_heads, num_groups, attention_fn, window_size=None):
    assert num_heads % num_groups == 0, "num_heads should be divisible by num_groups"
    heads_per_group = num_heads // num_groups

    d_model = Q.shape[-1]
    depth = d_model // num_heads

    # Split keys and values into multiple heads, but queries into groups
    Q_groups = group_heads(Q, num_groups)
    K_heads = split_heads(K, num_heads)
    V_heads = split_heads(V, num_heads)

    outputs = []
    for g in range(num_groups):
        # For each group of queries, compute attention with all key and value heads
        Q_group = Q_groups[:, g]
        output_heads = [attention_fn(Q_group, K_heads[:, h], V_heads[:, h], window_size) if attention_fn == sliding_window_attention else attention_fn(Q_group, K_heads[:, h], V_heads[:, h]) for h in range(num_heads)]
        # Concatenate output heads for this query group
        outputs.append(jnp.concatenate(output_heads, axis=-1))
    
    # Concatenate outputs for all query groups
    concatenated = jnp.concatenate(outputs, axis=-1)
    return concatenated

# JIT compile for speed
fast_grouped_query_attention = jit(grouped_query_attention, static_argnums=(5,))

# Time grouped query multi-head attention
num_groups = 4
# Make sure num_heads is divisible by num_groups before calling the JIT-compiled function
if num_heads % num_groups != 0:
    raise ValueError("num_heads should be divisible by num_groups")
start_time = time.time()
for _ in range(num_iterations):
    _ = fast_grouped_query_attention(Q=Q, K=K, V=V, num_heads=num_heads, num_groups=num_groups, attention_fn=scaled_dot_product_attention)
end_time = time.time()
print(f"Average time per grouped query vanilla attention: {(end_time - start_time) / num_iterations:.6f} seconds")



TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function grouped_query_attention at /var/folders/yd/npt3q5rj1mvdmlw309mn5d240000gp/T/ipykernel_74694/2317686948.py:6 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = eq b c
    from line /var/folders/yd/npt3q5rj1mvdmlw309mn5d240000gp/T/ipykernel_74694/2317686948.py:7:11 (grouped_query_attention)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError