Tested with free Google Compute Engine Backend. No GPU required.

# Imports

In [None]:
import jax
import flax.linen as nn
import jax.numpy as jnp
from flax.training import train_state
import optax
import numpy as np
import matplotlib.pyplot as pp
import tqdm
import unittest
import time
import functools
import math

In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
#@title Helper functions
# Batched jax.lax.dynamic_slice: the operand and the slice sizes are shared
# across the batch, while the start indices are mapped over their leading axis.
dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))

def get_batch(random_key, data, batch_size, block_size):
  """Generate a batch of data of inputs x and targets y.

  Each row of x is a random contiguous window of `data`; the matching row of y
  is the same window shifted forward by one token.

  Args:
    random_key (jax.random.PRNGKey): Random number generator key.
    data (array-like): 1d JAX array of integer tokens. Must contain at least
      block_size + 1 tokens.
    batch_size (int): Batch size.
    block_size (int): The maximum input context length.

  Returns:
    x (array-like): 2d JAX array of shape (batch_size, block_size).
    y (array-like): 2d JAX array of shape (batch_size, block_size).
        x[i, j] == y[i, j-1] where j > 0.

  Raises:
    ValueError: If data has fewer than block_size + 1 tokens.
  """
  num_tokens = len(data)
  # A window of block_size inputs plus its one-step-shifted targets needs
  # block_size + 1 tokens. dynamic_slice clamps out-of-range start indices
  # instead of failing, so without this check the targets could silently stop
  # being shifted relative to the inputs.
  if num_tokens <= block_size:
    raise ValueError(
        f"data must contain at least block_size + 1 = {block_size + 1} tokens, "
        f"got {num_tokens}")
  # maxval is exclusive, so the last valid start leaves room for the shifted y.
  ix = jax.random.randint(random_key, shape=(batch_size, 1), minval=0, maxval=num_tokens-block_size)
  x = dynamic_slice_vmap(data, ix, (block_size,))
  y = dynamic_slice_vmap(data, ix+1, (block_size,))
  return x, y

def load_shakespeare_dataset():
  """Read input.txt, tokenize it, and split it into train and eval arrays.

  Uses the module-level `encode` function to turn the text into integer ids.

  Returns:
    A (train_data, eval_data) tuple of 1d JAX arrays holding the first 90% of
    the tokens and the remaining tokens, respectively.
  """
  with open('input.txt', 'r', encoding='utf-8') as corpus_file:
    tokens = jnp.array(encode(corpus_file.read()))
  split_at = int(0.9*len(tokens))  # boundary between the train and eval parts
  return tokens[:split_at], tokens[split_at:]

def init_train_state(
    model,
    params,
    learning_rate=1e-4,
):
  """Build a TrainState that optimizes `params` of `model` with Adam.

  Args:
    model: Module whose `apply` method becomes the state's apply_fn.
    params: Initial parameters to store in the state.
    learning_rate (float): Learning rate passed to optax.adam.

  Returns:
    A freshly created train_state.TrainState.
  """
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=params,
      tx=optax.adam(learning_rate),
  )

@jax.jit
def train_step(state, x, y):
  """Perform a single optimization step on one batch.

  Args:
    state (flax.training.train_state.TrainState): Train state holding the
      weights and optimizer state.
    x (array-like): 2d JAX int array of shape (batch_size, block_size).
    y (array-like): 2d JAX int array of shape (batch_size, block_size).

  Returns:
    state (flax.training.train_state.TrainState): The train state after the
      gradients have been applied.
    loss (float): Mean cross-entropy loss of this batch, computed before the
      update.
  """
  def mean_cross_entropy(params):
    logits = state.apply_fn(params, x)  # B, T, vocab_size
    per_token = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return per_token.mean()

  grad_fn = jax.value_and_grad(mean_cross_entropy)
  loss, grads = grad_fn(state.params)
  return state.apply_gradients(grads=grads), loss

@jax.jit
def eval_step(state, x, y):
  """Return the mean cross-entropy loss of `state` on one batch.

  The state is only read; no gradients are computed or applied.
  """
  logits = state.apply_fn(state.params, x)
  per_token_loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
  return per_token_loss.mean()

def run_training_loop(
    num_iterations,
    batch_size,
    block_size,
    learning_rate,
    eval_data,
    train_data,
    model,
    eval_interval=100,
):
  """
  Runs the training loop for the specified model.

  Parameters are initialized from a fixed PRNG seed (0). Every `eval_interval`
  iterations the current state is scored on one random eval batch, the losses
  are printed, and the state is remembered if its eval loss is the lowest so
  far.

  Args:
      num_iterations (int): The number of training iterations.
      batch_size (int): The number of samples in each batch.
      block_size (int): The size of each block (sequence length).
      learning_rate (float): The learning rate for the optimizer.
      eval_data (array-like): 1d JAX array of integer tokens, consisting of evaluation data.
      train_data (array-like): 1d JAX array of integer tokens, consisting of training data.
      model (nn.Module): A Flax model taking a (batch_size, block_size) int array.
      eval_interval (int, optional): Number of iterations between evaluations.
        Defaults to 100.

  Returns:
      state: The train state that achieved the lowest eval loss among the
        evaluated iterations (the initial state if num_iterations is 0).

  Example:
      >>> final_state = run_training_loop(
      >>>     num_iterations=1000,
      >>>     batch_size=16,
      >>>     block_size=32,
      >>>     learning_rate=0.001,
      >>>     eval_data=eval_data,
      >>>     train_data=train_data,
      >>>     model=mini_gpt
      >>> )
  """
  random_key = jax.random.PRNGKey(0)
  # Dummy batch used only to trace shapes while initializing the parameters.
  x = jnp.ones((batch_size, block_size), dtype=jnp.int16)
  random_key, random_subkey = jax.random.split(random_key)
  params = model.init(random_subkey, x)
  state = init_train_state(
      model, params, learning_rate=learning_rate)
  best_state = state
  best_eval_loss = math.inf
  for i in range(num_iterations):
    random_key, random_subkey = jax.random.split(random_key)
    x, y = get_batch(random_subkey, train_data, batch_size=batch_size, block_size=block_size)
    state, loss = train_step(state, x, y)

    if i % eval_interval == 0:
      # Use a fresh subkey so the eval batch is independent of the train batch.
      random_key, random_subkey = jax.random.split(random_key)
      eval_loss = eval_step(state, *get_batch(random_subkey, eval_data, batch_size=batch_size, block_size=block_size))
      print(f"Step: {i}\t train loss: {loss}\t eval loss: {eval_loss}")
      if eval_loss < best_eval_loss:
        best_eval_loss = eval_loss
        best_state = state
  return best_state

## Load and tokenize dataset

In [None]:
# Load the whole downloaded corpus (input.txt) into a single Python string.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print("length of dataset in characters: ", len(text))

In [None]:
# The vocabulary is the sorted set of distinct characters in the corpus.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

# Lookup tables between characters and their integer ids.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))

def encode(s):
  """Encoder: take a string, output a list of integers."""
  return [stoi[c] for c in s]

def decode(l):
  """Decoder: take a list of integers, output a string."""
  return "".join(itos[i] for i in l)

# Tokenize the corpus; the first 90% is used for training, the rest for
# validation.
data = jnp.array(encode(text))
n = int(0.9*len(data))
train_data, eval_data = data[:n], data[n:]

# Warm up - Check the performance of a simple text decoder model

The SimpleDecoder below will predict the next token given a single token.

In [None]:
class SimpleDecoder(nn.Module):
  """Predicts next-token logits from the current token alone.

  The model is a single embedding table with vocab_size rows of vocab_size
  features, so looking up a token id directly yields logits over the
  vocabulary.
  """
  vocab_size: int

  def setup(self):
    self.token_embedding = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.vocab_size)

  def __call__(self, x):
    B, T = x.shape  # unpacking requires x to be 2d: (B, T)
    return self.token_embedding(x) # B, T, vocab_size

  def generate(self, start_token, max_length=20, end_token=None):
    """Greedily generate a token sequence beginning with `start_token`.

    Args:
      start_token (int): First token of the sequence.
      max_length (int): Maximum length of the returned sequence, including
        the start token.
      end_token (int, optional): If given, generation stops right after this
        token has been produced (it is kept in the output).

    Returns:
      A list of int token ids.
    """
    tokens = [start_token]
    previous = start_token

    # The start token already occupies one slot of max_length.
    for _ in range(max_length - 1):
      # Score the next token from a (1, 1) batch holding the previous token.
      logits = self(jnp.array([[previous]]))

      # Greedy choice: the highest-scoring token id, as an array of shape (1,).
      best = jnp.argmax(logits, axis=-1)[0]
      tokens.append(int(best[0]))

      # Stop once the end token has been emitted.
      if end_token is not None and best[0] == end_token:
          break

      previous = int(best[0])

    return tokens

# Build an untrained decoder and initialize its parameters from a dummy (4, 8)
# integer batch.
decoder = SimpleDecoder(vocab_size=vocab_size)
start_token = 23  # token id used as the first element of the generated sequence
dummy = jnp.ones((4, 8), dtype=jnp.int16)
params = decoder.init(jax.random.PRNGKey(0), dummy)

# Generate a sequence
generated_sequence = decoder.apply(params, start_token, method=decoder.generate, max_length=20)
print("Generated sequence:", decode(generated_sequence))


The generated sequence is gibberish. Let's see if it gets better when we train it.

In [None]:
# You can play around the parameters here to see how that affects loss.
num_iterations = 7000
learning_rate = 1e-3
# NOTE: num_layers, num_heads and hidden_dim are not passed to anything in this
# cell; SimpleDecoder only takes vocab_size.
num_layers = 4
batch_size = 16
block_size = 32
num_heads = 4
hidden_dim = 64

decoder = SimpleDecoder(vocab_size=vocab_size)

# Train the simple decoder and keep the train state returned by the loop.
simple_decoder_state = run_training_loop(
    num_iterations = num_iterations,
    learning_rate = learning_rate,
    batch_size = batch_size,
    block_size = block_size,
    eval_data = eval_data,
    train_data = train_data,
    model = decoder
)

In [None]:
# Repeat the greedy generation, this time with the trained parameters.
generated_sequence = decoder.apply(simple_decoder_state.params, start_token, method=decoder.generate, max_length=20)
print("Generated sequence:", decode(generated_sequence))


# Task 1 - Implement MiniGPT.

* You can use off-the-shelf Flax modules like Dense, LayerNorm. You may not use Flax's SelfAttention. Instead, use AttentionTask1 provided below.
* Note that block_size, T, input context window length are different ways to refer to the same thing.

In [None]:
# B == batch_size.
# T == number of tokens in sequence.
# C == hidden_dim == hidden dimension of transformer.
# head_dim == Head dimension for each Attention head. head_dim * num_heads == C.

# You can use this class for solving Task 1. We will revisit this class in Task 2.
class AttentionTask1(nn.Module):
  """Single-head causal self-attention built on Flax's attention module.

  The input is projected to queries, keys and values of size head_dim, which
  are then passed to a one-head nn.MultiHeadDotProductAttention together with
  a lower-triangular mask so that a position never attends to later positions.
  """
  head_dim: int

  def setup(self):
    self.query = nn.Dense(features=self.head_dim, use_bias=False)
    self.key = nn.Dense(features=self.head_dim, use_bias=False)
    self.value = nn.Dense(features=self.head_dim, use_bias=False)
    self.attention_impl = nn.MultiHeadDotProductAttention(
        num_heads=1, qkv_features=self.head_dim, dropout_rate=0.)

  def __call__(self, x):
    # x is of shape B, T, C.
    B, T, _ = x.shape
    q = self.query(x)  # B, T, head_dim
    k = self.key(x)  # B, T, head_dim
    v = self.value(x)  # B, T, head_dim
    # Causal mask of shape (B, 1, T, T): entry [t, s] is non-zero only when
    # s <= t. Without it every token could see the tokens it is trained to
    # predict, which makes the language-modeling loss meaningless.
    mask = jnp.tril(jnp.ones((B, 1, T, T)))
    return self.attention_impl(inputs_q=q, inputs_k=k, inputs_v=v, mask=mask)  # B, T, head_dim

# FeedForward is given to you for free.
class FeedForward(nn.Module):
  """Two Dense layers with a ReLU in between: hidden_dim -> 4*hidden_dim -> hidden_dim."""
  hidden_dim: int

  def setup(self):
    self.f1 = nn.Dense(features=4 * self.hidden_dim)  # expansion
    self.f2 = nn.Dense(features=self.hidden_dim)  # projection back

  def __call__(self, x):
    expanded = self.f1(x)  # B, T, 4 * hidden_dim
    activated = nn.relu(expanded)
    return self.f2(activated)  # B, T, hidden_dim

class MultiHeadAttention(nn.Module):
  """Exercise skeleton (Task 1) for multi-head attention.

  setup() creates num_heads AttentionTask1 heads and an output Dense layer;
  combining them in __call__ is left as a TODO.
  """
  num_heads: int
  head_dim: int

  def setup(self):
    # One independent single-head attention module per head.
    self.heads = [AttentionTask1(self.head_dim) for _ in range(self.num_heads)]
    # Output projection of width num_heads * head_dim.
    self.dense = nn.Dense(self.num_heads*self.head_dim)

  def __call__(self, x):
    # TODO: Implement multi-head attention.
    # NOTE: the placeholder below only applies the output projection;
    # self.heads is not used yet.
    return self.dense(x)  # B, T, hidden_dim

class DecoderBlock(nn.Module):
  """Exercise skeleton (Task 1) for one transformer decoder block.

  As provided, setup() creates no submodules and __call__ returns its input
  unchanged.
  """
  hidden_dim: int
  num_heads: int

  def setup(self):
    # head_dim * num_heads == hidden_dim should always hold true.
    head_dim = self.hidden_dim // self.num_heads
    # TODO: Fill out the rest of setup function.

  def __call__(self, x):
    # TODO: Implement this function.
    return x  # B, T, hidden_dim

class MiniGPT(nn.Module):
  """Exercise skeleton (Task 1) for a small decoder-only transformer LM.

  As provided, __call__ only embeds the tokens and projects them to vocabulary
  logits; position_encoding is created but not used yet. The TODOs mark the
  parts left to implement.
  """
  vocab_size: int
  hidden_dim: int
  block_size: int
  num_layers: int
  num_heads: int

  def setup(self):
    self.token_embedding = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.hidden_dim)
    # One learned embedding per position of the context window.
    self.position_encoding = nn.Embed(
        num_embeddings=self.block_size,
        features=self.hidden_dim
    )
    # Projects hidden states to per-token vocabulary logits.
    self.final_dense = nn.Dense(features=self.vocab_size)
    # TODO: Fill out the rest of this function.

  def __call__(self, x):
    B, T = x.shape
    x = self.token_embedding(x)  # B, T, hidden_dim

    # TODO: Fill in missing functionalities here.

    return self.final_dense(x)

  def generate(self, random_key, params, x, max_new_tokens=50):
    """Sample max_new_tokens tokens autoregressively and append them to x.

    Args:
      random_key: JAX PRNG key; it is split once per generated token.
      params: Model parameters passed to self.apply.
      x: 2d int array of shape (B, T) holding the prompt tokens.
      max_new_tokens (int): Number of tokens to append.

    Returns:
      2d array of shape (B, T + max_new_tokens).
    """
    for _ in range(max_new_tokens):
      # Only the last block_size tokens are fed to the model.
      logits = self.apply(params, x[:, -self.block_size:])
      random_key, random_subkey = jax.random.split(random_key)
      # Sample the next token from the logits at the last position.
      new_token = jax.random.categorical(random_subkey, logits[:, -1, :], axis=-1, shape=None)
      x = jnp.concatenate([x, new_token[:, None]], axis=1)
    return x

In [None]:
# You can play around the parameters here to see how that affects loss.
num_iterations = 4000
learning_rate = 1e-3
num_layers = 4
batch_size = 16
block_size = 32
num_heads = 4
hidden_dim = 128

# Build the model from the hyperparameters above.
mini_gpt = MiniGPT(
    vocab_size=vocab_size,
    hidden_dim=hidden_dim,
    block_size=block_size,
    num_layers=num_layers,
    num_heads=num_heads
)

# Train it and keep the train state returned by the loop.
mini_gpt_state = run_training_loop(
    num_iterations=num_iterations,
    learning_rate=learning_rate,
    batch_size=batch_size,
    block_size=block_size,
    eval_data=eval_data,
    train_data=train_data,
    model=mini_gpt
)

In [None]:
# Uncomment below to print predictions:
# x = jnp.zeros((1, 1), dtype=jnp.int32)
# random_key = jax.random.PRNGKey(0)
# tokens = mini_gpt.generate(random_key, params=mini_gpt_state.params, x=x)
# print(decode(tokens[0].tolist()))

In [None]:
# Pass this test before moving on to Task 2.
class TestTask1(unittest.TestCase):
  """End-to-end check that a trained MiniGPT reaches an eval loss below 1.9."""

  def test_minigpt(self):
    """Train MiniGPT with fixed hyperparameters and bound its mean eval loss.

    The mean is taken over 100 random eval batches drawn with PRNG seed 42.
    """
    # Do not change these parameters.
    num_iterations = 4000
    learning_rate = 1e-3
    num_layers = 4
    batch_size = 16
    block_size = 32
    num_heads = 4
    hidden_dim = 128
    random_key = jax.random.PRNGKey(42)

    mini_gpt = MiniGPT(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        block_size=block_size,
        num_layers=num_layers,
        num_heads=num_heads
    )

    train_data, eval_data = load_shakespeare_dataset()
    mini_gpt_state = run_training_loop(
        num_iterations = num_iterations,
        learning_rate = learning_rate,
        batch_size = batch_size,
        block_size = block_size,
        eval_data = eval_data,
        train_data = train_data,
        model = mini_gpt
    )
    eval_losses = []
    for _ in tqdm.tqdm(range(100)):
      random_key, random_subkey = jax.random.split(random_key)
      x, y = get_batch(
          random_subkey, eval_data, batch_size=batch_size, block_size=block_size)
      batch_eval_loss = eval_step(mini_gpt_state, x, y)
      eval_losses.append(batch_eval_loss)
    mean_eval_loss = np.mean(eval_losses)
    print(f"Average eval loss: {mean_eval_loss}")
    # assertLess reports both values on failure, unlike assertTrue(a < b).
    self.assertLess(mean_eval_loss, 1.9)

# Uncomment the test below.
# TestTask1().test_minigpt()

# Task 2 - implement the Self-Attention Jax Module

Your task is to implement Attention without using Flax's built-in nn.MultiHeadDotProductAttention module. Fill in the TODO section below.

Things to keep in mind:

* We are implementing a decoder-only transformer. This means that each token can only attend to previous tokens, but not future tokens.

In [None]:
class AttentionTask2(nn.Module):
  """Exercise skeleton (Task 2) for single-head self-attention.

  The query/key/value projections are provided; __call__ currently returns a
  placeholder (the sum of the three projections) that must be replaced.
  """
  head_dim: int

  def setup(self):
    # Don't change the setup function.
    self.query = nn.Dense(features=self.head_dim, use_bias=False)
    self.key = nn.Dense(features=self.head_dim, use_bias=False)
    self.value = nn.Dense(features=self.head_dim, use_bias=False)

  def __call__(self, x):
    B, T, C = x.shape

    # TODO: This function contains an incorrect attention implementation. Change
    # its definition below:
    B, T, C = x.shape
    return self.query(x) + self.key(x) + self.value(x) # B, T, head_dim

In [None]:
class TestAttention(unittest.TestCase):

  EXPECTED_ATTENTION_ARRAY = np.array([[
      [-2.3660736,  -1.0994253,   0.54647386,  1.663486,    1.0262686,
      0.50164324, -0.40740347, -0.86529493,  1.6112939,  -0.46789974,
      1.3150474,   0.9799258,  -0.5418715,  -1.2731858,  -0.7926506,
      -0.8737542],
      [-1.2626604,   1.3287369,   0.96550566,  0.4553011,   0.6900299,
      -0.6283262,  -0.44400188,  0.18089633, -0.6977915,  -0.49270085,
      0.1377207,   0.19912332, -0.02095406, -1.0335875,  -0.13449836,
      -0.9766264],
      [-2.1633344,  -0.7197231,   0.59619266,  1.4519494,   0.9575919,
      0.33423916, -0.39966965, -0.69272554,  1.2452503,  -0.46386948,
      1.1163911,   0.84217906, -0.45448378, -1.2067673,  -0.6738551,
      -0.87238634],
      [ 0.21197765,  0.2127537,  -0.27920845, -0.4683921,  -0.22381224,
      0.49012795,  0.44253582,  0.2606917,   0.03008281,  0.06132472,
      -0.28707987, -0.4550119,   0.16932811,  0.7396863,   0.54958737,
      0.23469326]
  ]])

  def test_attention(self):
    attention = AttentionTask2(head_dim=16)
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    params = attention.init(subkey, jnp.ones((1, 4, 8)))
    x = jax.random.normal(key=key, shape=(1, 4, 8), dtype=jnp.float32)
    y = attention.apply(params, x)
    self.assertTrue(np.allclose(y, self.EXPECTED_ATTENTION_ARRAY))

# Run the test below once AttentionTask2 is implemented (comment it out to skip).
TestAttention().test_attention()

# Task 3 Speed up MultiheadAttention with Einsum.

Please finish Task 2 before starting this task.

In [None]:
class MultiHeadAttentionTask3(nn.Module):
  """Exercise skeleton (Task 3) for multi-head attention using einsum.

  A single query/key/value projection of width num_heads * head_dim covers all
  heads. __call__ currently returns a placeholder (the sum of the three
  projections) and does not use self.dense.
  """
  num_heads: int
  head_dim: int

  def setup(self):
    self.query = nn.Dense(features=self.num_heads * self.head_dim, use_bias=False)
    self.key = nn.Dense(features=self.num_heads * self.head_dim, use_bias=False)
    self.value = nn.Dense(features=self.num_heads * self.head_dim, use_bias=False)
    # Output projection.
    self.dense = nn.Dense(features=self.num_heads * self.head_dim)

  def __call__(self, x):
    B, T, C = x.shape
    # TODO: Implement this using Einsum.

    return self.query(x) + self.key(x) + self.value(x) # B, T, num_heads * head_dim


In [None]:
class TestMultiHeadEinsum(unittest.TestCase):
  EXPECTED_ATTENTION_ARRAY = np.array([[
      [-2.3660736,  -1.0994253,   0.54647386,  1.663486,    1.0262686,
      0.50164324, -0.40740347, -0.86529493,  1.6112939,  -0.46789974,
      1.3150474,   0.9799258,  -0.5418715,  -1.2731858,  -0.7926506,
      -0.8737542],
      [-1.2626604,   1.3287369,   0.96550566,  0.4553011,   0.6900299,
      -0.6283262,  -0.44400188,  0.18089633, -0.6977915,  -0.49270085,
      0.1377207,   0.19912332, -0.02095406, -1.0335875,  -0.13449836,
      -0.9766264],
      [-2.1633344,  -0.7197231,   0.59619266,  1.4519494,   0.9575919,
      0.33423916, -0.39966965, -0.69272554,  1.2452503,  -0.46386948,
      1.1163911,   0.84217906, -0.45448378, -1.2067673,  -0.6738551,
      -0.87238634],
      [ 0.21197765,  0.2127537,  -0.27920845, -0.4683921,  -0.22381224,
      0.49012795,  0.44253582,  0.2606917,   0.03008281,  0.06132472,
      -0.28707987, -0.4550119,   0.16932811,  0.7396863,   0.54958737,
      0.23469326]
  ]])

  def test_multihead_einsum(self):
    attention_einsum = MultiHeadAttentionTask3(num_heads=2, head_dim=8)
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    params = attention_einsum.init(subkey, jnp.ones((1, 4, 16)))
    x = jax.random.normal(key=key, shape=(1, 4, 16), dtype=jnp.float32)
    y = attention_einsum.apply(params, x)
    self.assertTrue(np.allclose(y, self.EXPECTED_ATTENTION_ARRAY))

# TestMultiHeadEinsum().test_multihead_einsum()

In [None]:
# TODO: Rerun the training loop using MultiHeadAttentionTask3. Can you still achieve the same eval results?

# Appendix

* Run the Import section at the beginning of this colab before running the Solution.
* The Solutions below need to be executed sequentially.

In [None]:
#@title Solution for task 1.
class ReferenceAttention(nn.Module):
  """Single-head causal self-attention built on Flax's attention module."""
  head_dim: int

  def setup(self):
    self.query = nn.Dense(features=self.head_dim, use_bias=False)
    self.key = nn.Dense(features=self.head_dim, use_bias=False)
    self.value = nn.Dense(features=self.head_dim, use_bias=False)
    self.attention_impl = nn.MultiHeadDotProductAttention(
        num_heads=1, qkv_features=self.head_dim, dropout_rate=0.)

  def __call__(self, x):
    batch, seq_len, _ = x.shape  # x is B, T, C
    # Lower-triangular mask: each position only attends to itself and to
    # earlier positions.
    causal_mask = jnp.tril(jnp.ones((batch, 1, seq_len, seq_len)))
    return self.attention_impl(
        inputs_q=self.query(x),  # B, T, head_dim
        inputs_k=self.key(x),  # B, T, head_dim
        inputs_v=self.value(x),  # B, T, head_dim
        mask=causal_mask)  # B, T, head_dim

class FeedForward(nn.Module):
  """Dense(4 * hidden_dim) -> ReLU -> Dense(hidden_dim)."""
  hidden_dim: int

  def setup(self):
    self.f1 = nn.Dense(features=4 * self.hidden_dim)  # widen
    self.f2 = nn.Dense(features=self.hidden_dim)  # narrow back

  def __call__(self, x):
    widened = self.f1(x)  # B, T, 4 * hidden_dim
    return self.f2(nn.relu(widened))  # B, T, hidden_dim

class MultiHeadAttentionSolution(nn.Module):
  """Runs num_heads ReferenceAttention heads and mixes them with a Dense layer."""
  num_heads: int
  head_dim: int

  def setup(self):
    self.heads = [ReferenceAttention(self.head_dim) for _ in range(self.num_heads)]
    self.dense = nn.Dense(self.num_heads*self.head_dim)

  def __call__(self, x):
    per_head = [head(x) for head in self.heads]  # each B, T, head_dim
    merged = jnp.concatenate(per_head, axis=-1)  # B, T, num_heads * head_dim
    return self.dense(merged)  # B, T, hidden_dim

class DecoderBlockSolution(nn.Module):
  """Decoder block: attention and feed-forward sublayers with residual adds.

  Each sublayer receives a LayerNorm-ed copy of its input and its output is
  added back onto the un-normalized input.
  """
  hidden_dim: int
  num_heads: int

  def setup(self):
    # hidden_dim is split evenly across the heads.
    per_head_dim = self.hidden_dim // self.num_heads
    self.mha = MultiHeadAttentionSolution(
        num_heads=self.num_heads,
        head_dim=per_head_dim)
    self.ff = FeedForward(self.hidden_dim)
    self.ff_norm = nn.LayerNorm()
    self.mha_norm = nn.LayerNorm()

  def __call__(self, x):
    attended = x + self.mha(self.mha_norm(x))
    return attended + self.ff(self.ff_norm(attended))


class MiniGPTSolution(nn.Module):
  """Reference decoder-only transformer language model for Task 1.

  Token and position embeddings are summed, passed through num_layers decoder
  blocks and a final LayerNorm, then projected to vocabulary logits.
  """
  vocab_size: int
  hidden_dim: int
  block_size: int
  num_layers: int
  num_heads: int

  def setup(self):
    self.token_embedding = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.hidden_dim)
    self.position_encoding = nn.Embed(
        num_embeddings=self.block_size,
        features=self.hidden_dim
    )
    self.decoder_blocks = [
        DecoderBlockSolution(self.hidden_dim, self.num_heads) for _ in range(self.num_layers)
    ]
    self.final_norm = nn.LayerNorm()
    self.final_dense = nn.Dense(features=self.vocab_size)

  def __call__(self, x):
    _, seq_len = x.shape  # x is B, T
    token_emb = self.token_embedding(x)  # B, T, hidden_dim
    position_emb = self.position_encoding(jnp.arange(seq_len))  # T, hidden_dim
    hidden = token_emb + position_emb  # broadcasts over the batch axis
    for block in self.decoder_blocks:
      hidden = block(hidden)
    return self.final_dense(self.final_norm(hidden))

  def generate(self, random_key, params, x, max_new_tokens=50):
    """Sample max_new_tokens tokens autoregressively and append them to x.

    Args:
      random_key: JAX PRNG key; it is split once per generated token.
      params: Model parameters passed to self.apply.
      x: 2d int array of shape (B, T) holding the prompt tokens.
      max_new_tokens (int): Number of tokens to append.

    Returns:
      2d array of shape (B, T + max_new_tokens).
    """
    tokens = x
    for _ in range(max_new_tokens):
      # The model only sees the most recent block_size tokens.
      context = tokens[:, -self.block_size:]
      logits = self.apply(params, context)
      random_key, step_key = jax.random.split(random_key)
      # Sample from the distribution predicted at the last position.
      sampled = jax.random.categorical(step_key, logits[:, -1, :], axis=-1, shape=None)
      tokens = jnp.concatenate([tokens, sampled[:, None]], axis=1)
    return tokens

# This is a duplicate of TestTask1.
class TestTask1Solution(unittest.TestCase):
  """End-to-end check that a trained MiniGPTSolution reaches eval loss < 1.9."""

  def test_minigpt(self):
    """Train MiniGPTSolution with fixed hyperparameters and bound its eval loss.

    The mean is taken over 100 random eval batches drawn with PRNG seed 42.
    """
    num_iterations = 4000
    learning_rate = 1e-3
    num_layers = 4
    batch_size = 16
    block_size = 32
    num_heads = 4
    hidden_dim = 128
    random_key = jax.random.PRNGKey(42)

    model = MiniGPTSolution(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        block_size=block_size,
        num_layers=num_layers,
        num_heads=num_heads
    )
    train_data, eval_data = load_shakespeare_dataset()
    mini_gpt_state = run_training_loop(
        num_iterations = num_iterations,
        learning_rate = learning_rate,
        batch_size = batch_size,
        block_size = block_size,
        eval_data = eval_data,
        train_data = train_data,
        model = model
    )
    eval_losses = []
    for _ in tqdm.tqdm(range(100)):
      random_key, random_subkey = jax.random.split(random_key)
      x, y = get_batch(
          random_subkey, eval_data, batch_size=batch_size, block_size=block_size)
      batch_eval_loss = eval_step(mini_gpt_state, x, y)
      eval_losses.append(batch_eval_loss)
    mean_eval_loss = np.mean(eval_losses)
    print(f"Average eval loss: {mean_eval_loss}")
    # assertLess reports both values on failure, unlike assertTrue(a < b).
    self.assertLess(mean_eval_loss, 1.9)

# Execute the test. It trains for 4000 iterations; comment the call out to skip it.
TestTask1Solution().test_minigpt()

In [None]:
#@title Solution for task 2.

class AttentionTask2Solution(nn.Module):
  """Single-head causal scaled dot-product self-attention, written by hand."""
  head_dim: int

  def setup(self):
    self.query = nn.Dense(features=self.head_dim, use_bias=False)
    self.key = nn.Dense(features=self.head_dim, use_bias=False)
    self.value = nn.Dense(features=self.head_dim, use_bias=False)

  def __call__(self, x):
    _, seq_len, _ = x.shape  # x is B, T, C
    queries = self.query(x) # B, T, head_dim
    keys = self.key(x) # B, T, head_dim
    scores = queries @ jnp.swapaxes(keys, 1, 2) # B, T, T
    # Positions above the diagonal (future tokens) get -inf so that softmax
    # assigns them zero weight.
    causal = jnp.tril(jnp.ones((seq_len, seq_len)))
    scores = jnp.where(causal, scores, -jnp.inf)
    weights = nn.softmax(scores / jnp.sqrt(self.head_dim), axis=-1) # B, T, T
    return weights @ self.value(x) # B, T, head_dim

# This mirrors TestAttention from Task 2 but runs AttentionTask2Solution. Note
# that it also uses a different PRNG key setup and different expected values.
class TestAttention(unittest.TestCase):

  EXPECTED_ATTENTION_ARRAY = np.array([
    [[-0.3368626, 0.1565489, 0.96250117, 0.7116083, 0.48668504,
      0.3070267, -0.49149823, 0.7827484, 0.4131582, 0.7505922,
      0.90185213, -0.34802976, 1.2631372, 0.8314824, 0.45534268,
      0.11072167],
     [0.355573, 0.36409345, 0.19864899, 0.58222437, -0.01833684,
      0.8821246, 0.26334122, 0.10999514, 0.69409794, 0.3437622,
      -0.71399987, 0.6530971, 0.00235165, -0.5397035, 0.55874693,
      -0.4885986],
     [0.6003635, 0.34785143, -0.25671193, 0.3002994, -0.31720588,
      1.2125036, 0.6570689, -0.22460055, 0.9200514, -0.01703957,
      -1.5395278, 1.1767541, -0.7460983, -1.3350787, 0.61231965,
      -1.0458561],
     [-0.7845163, -0.5571454, 0.39112994, -0.63247937, -0.2971205,
      0.19273886, -0.25068092, 0.5804176, 0.3952121, 0.24023446,
      1.1744585, -1.0228857, 1.0987606, 0.90741533, 0.19215004,
      -0.98253024]]
    ]
  )

  def test_attention(self):
    attention = AttentionTask2Solution(head_dim=16)
    params = attention.init(jax.random.key(0), jnp.ones((1, 4, 8)))
    x = jax.random.normal(key=jax.random.key(0), shape=(1, 4, 8), dtype=jnp.float32)
    y = attention.apply(params, x)
    self.assertTrue(np.allclose(y, self.EXPECTED_ATTENTION_ARRAY))

# Run the Task 2 solution test.
TestAttention().test_attention()

In [None]:
#@title Solution for task 3
# This is only used for the next task.
class MultiHeadAttentionReferenceTask3(nn.Module):
  """Baseline multi-head attention: separate AttentionTask2Solution heads.

  The head outputs are concatenated along the feature axis and passed through
  a Dense layer.
  """
  num_heads: int
  head_dim: int

  def setup(self):
    self.heads = [AttentionTask2Solution(self.head_dim) for _ in range(self.num_heads)]
    self.dense = nn.Dense(self.num_heads*self.head_dim)

  def __call__(self, x):
    head_outputs = [head(x) for head in self.heads]  # each B, T, head_dim
    stacked = jnp.concatenate(head_outputs, axis=-1)  # B, T, num_heads * head_dim
    return self.dense(stacked)  # B, T, hidden_dim

class MultiHeadAttentionTask3Solution(nn.Module):
  """Causal multi-head self-attention computed for all heads at once with einsum."""
  num_heads: int
  head_dim: int

  def setup(self):
    self.query = nn.Dense(features=self.num_heads * self.head_dim, use_bias=False)
    self.key = nn.Dense(features=self.num_heads * self.head_dim, use_bias=False)
    self.value = nn.Dense(features=self.num_heads * self.head_dim, use_bias=False)
    self.dense = nn.Dense(features=self.num_heads * self.head_dim)

  def __call__(self, x):
    batch, seq_len, _ = x.shape  # x is B, T, C

    def split_heads(projected):
      # (B, T, num_heads*head_dim) -> (B, num_heads, T, head_dim)
      per_head = projected.reshape(batch, seq_len, self.num_heads, self.head_dim)
      return per_head.transpose((0, 2, 1, 3))

    q = split_heads(self.query(x))
    k = split_heads(self.key(x))
    v = split_heads(self.value(x))

    # Scaled dot-product scores between every query/key pair of each head.
    scores = jnp.einsum('bnth,bnsh->bnts', q, k) / jnp.sqrt(self.head_dim)  # (B, num_heads, T, T)
    # Future positions get -inf so that softmax assigns them zero weight.
    causal = jnp.tril(jnp.ones((seq_len, seq_len)))
    weights = nn.softmax(jnp.where(causal, scores, -jnp.inf), axis=-1)  # (B, num_heads, T, T)
    context = jnp.einsum('bnts,bnsh->bnth', weights, v)  # (B, num_heads, T, head_dim)
    # Move the heads back next to their features and flatten them.
    merged = context.transpose((0, 2, 1, 3)).reshape(
        batch, seq_len, self.num_heads * self.head_dim)
    return self.dense(merged)  # B, T, C

def measure_attention_time(
    batch_size, seq_length, num_heads, head_dim, model, num_iterations=20):
  """Measure the average wall-clock time of one jitted forward pass of `model`.

  Args:
    batch_size (int): Batch dimension of the random input.
    seq_length (int): Sequence dimension of the random input.
    num_heads (int): Number of heads; with head_dim it sets the input width.
    head_dim (int): Per-head dimension; the input width is num_heads * head_dim.
    model: Module exposing init(rng, x) and apply(params, x).
    num_iterations (int): Number of timed runs to average over.

  Returns:
    float: Average time per forward pass in seconds.
  """
  total_time = 0.0
  # Separate keys for the input data and for the parameter initialization.
  data_key, init_key = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(data_key, (batch_size, seq_length, num_heads * head_dim))
  params = model.init(init_key, x)
  jitted_apply = jax.jit(model.apply)
  # The first run is for compiling Jax. We'll ignore it.
  jax.block_until_ready(jitted_apply(params, x))

  for _ in tqdm.tqdm(range(num_iterations)):
    start_time = time.perf_counter()
    # JAX dispatches work asynchronously. Without blocking on the result, only
    # the time to enqueue the computation would be measured.
    jax.block_until_ready(jitted_apply(params, x))
    total_time += time.perf_counter() - start_time
  average_attention_time = total_time / num_iterations
  print(f"Average inference time: {average_attention_time} seconds")
  return average_attention_time


In [None]:
class TestMultiHeadEinsum(unittest.TestCase):
  """Checks the einsum attention against the per-head baseline implementation."""

  def test_multihead_einsum(self):
    """Both models must agree once they share the same weights.

    The baseline's per-head query/key/value kernels are overwritten with the
    matching column slices of the einsum model's fused kernels, and the output
    Dense parameters are copied over.
    """
    head_dim = 4
    num_heads = 2
    batch_size = 1
    seq_length = 2
    hidden_dim = head_dim * num_heads
    init_key, data_key = jax.random.split(jax.random.PRNGKey(0))
    init_input = jnp.ones((batch_size, seq_length, hidden_dim))

    new_model = MultiHeadAttentionTask3Solution(
        head_dim=head_dim,
        num_heads=num_heads
    )
    new_params = new_model.init(init_key, init_input)

    baseline_model = MultiHeadAttentionReferenceTask3(
        head_dim=head_dim,
        num_heads=num_heads
    )
    baseline_params = baseline_model.init(init_key, init_input)

    # Head h of the fused projections owns columns [h*head_dim, (h+1)*head_dim).
    for head in range(num_heads):
      cols = slice(head * head_dim, (head + 1) * head_dim)
      for proj in ('query', 'key', 'value'):
        baseline_params['params'][f'heads_{head}'][proj]['kernel'] = (
            new_params['params'][proj]['kernel'][:, cols].copy())
    baseline_params['params']['dense']['kernel'] = new_params['params']['dense']['kernel'].copy()
    baseline_params['params']['dense']['bias'] = new_params['params']['dense']['bias'].copy()

    # Use a random input. With a constant input every position has identical
    # queries, keys and values, so the attention output does not depend on the
    # attention weights and mask or scaling bugs would go unnoticed.
    x = jax.random.normal(data_key, (batch_size, seq_length, hidden_dim))
    baseline_res = baseline_model.apply(baseline_params, x)
    new_res = new_model.apply(new_params, x)
    # The two implementations order their float32 operations differently, so
    # allow for small rounding differences.
    np.testing.assert_allclose(new_res, baseline_res, rtol=1e-4, atol=1e-5)

# Run the Task 3 solution test.
TestMultiHeadEinsum().test_multihead_einsum()