<a href="https://colab.research.google.com/github/zhenyiqi/rawLLM/blob/main/TransformerJax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

This notebook follows strictly the paper ["Attenion is all you need"](https://arxiv.org/pdf/1706.03762.pdf) (or it intends to do so), althought the architecture has countless variants right now. For different architectures, please refer to other notebooks (coming soon).

# Imports

In [1]:
import time
import math
from functools import partial

In [2]:
import numpy as np
# import torch for getting the data
from torch.utils import data
from torchvision.datasets import MNIST

In [3]:
import jax.numpy as jnp
from jax.scipy.special import logsumexp

from jax import grad, jit, vmap
# random is used to generate random matrix and manage random keys.
from jax import random
import jax
from flax.training import train_state, checkpoints

In [4]:
# ! pip install flax --quiet

In [5]:
import flax
from flax import linen as nn

In [6]:
## Optax (Optimizers in JAX)
try:
    import optax
except ModuleNotFoundError: # Install optax if missing
    !pip install --quiet optax
    import optax

# Define constants

In [7]:
rng_main = random.PRNGKey(0)

In [8]:
# batch_size = 16
# n_targets = 10
# num_epochs = 5

# layer_sizes = [784, 512, 512, 10]
# step_size = 0.01

# Training Prep

In [8]:
rng = random.PRNGKey(42)

## Utils for initalizing parameters

In [9]:
rng

Array([ 0, 42], dtype=uint32)

In [10]:
# from traitlets.traitlets import Tuple

# # A helper function to randomly initialize weights and biases
# # for a dense neural network layer.
# def random_layer_params(input_dim, output_dim, key, scale=1e-2):
#   w_key, b_key = random.split(key)
#   # random.normal(w_key, (n, m)) generates a random matrix of dimension (n, m)
#   return (scale * random.normal(w_key, (output_dim, input_dim)),
#           scale * random.normal(b_key, (output_dim,)))

# # Initialize all layers for a fully-connected neural network with sizes "sizes"
# def init_network_params(sizes, key: ...):
#   keys = random.split(key, len(sizes))
#   return [random_layer_params(m, n, k)
#           for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

## Core training utils

### Builds the model

#### Scaled dot-product attention

The scaled dot-product attention is calculated as  


$$\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

$Q\in\mathbb{R}^{T\times d_k}$, keys $K\in\mathbb{R}^{T\times d_k}$ and values $V\in\mathbb{R}^{T\times d_v}$, where $T$ is the sequence length, and $d_k$ is the hidden dimension. In practice, we can add make all of them tensor of dimension $\mathbb{R}^{B\times T \times d_k}$, where $B$ is batch size.

In [11]:
def scaled_dot_product_attention(q, k, v, mask = None):
  """Q attend to the K-V pair.

  q, k, v are matrices of size (batch_size * ) T * k_q, where, T is the sequence length,
  and k_q is the hidden dimension.
  We may add dimensions in the front for parallelized computation."""
  k_q = q.shape[-1]

  # attention_logits is of dimension (batch_size * ) T * T
  attention_logits = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / math.sqrt(k_q)
  if mask is not None:
    attention_logits = jnp.where(mask == 0, -9e15, attention_logits)
  attention = nn.softmax(attention_logits)

  # values is of dimension (batch_size * ) T * k_q again.
  values = jnp.matmul(attention, v)
  return values

#### test scaled dot-product attention

In [12]:
# assuming we have a tensor of (batch_size=2, sequence_length=2, k_q=3)
q = jnp.array([[[1, 2, 3], [4, 5, 6]], [[2, 2, 2], [3, 3, 3]]])
v = q.copy()
k = q.copy()

In [13]:
q.shape == (2, 2, 3)

True

In [14]:
values = scaled_dot_product_attention(q, k, v)
# final values should have the same dimension
values.shape == (2, 2, 3)

True

In [15]:
# now assuming that we have a tensor of (num_heads=3, batch_size=2, sequence_length=2, k_q = 3)
q = jnp.array([
    [[[1, 2, 3], [4, 5, 6]], [[2, 2, 2], [3, 3, 3]]],
    [[[1, 2, 3], [4, 5, 6]], [[2, 2, 2], [3, 3, 3]]],
    [[[1, 2, 3], [4, 5, 6]], [[2, 2, 2], [3, 3, 3]]],
    [[[1, 2, 3], [4, 5, 6]], [[2, 2, 2], [3, 3, 3]]],
])
k = q.copy()
v = q.copy()
q.shape == (4, 2, 2, 3)

True

In [16]:
values = scaled_dot_product_attention(q, k, v)

In [17]:
values.shape == (4, 2, 2, 3)

True

#### MLP Layer

In [18]:
def relu(x: jnp.ndarray):
  """Relu operation on a matrix"""
  return jnp.maximum(0, x)

class MLPLayer(nn.Module):
  """A Multi-layer perceptor + a layer norm with skipped add."""
  input_dim: int
  hidden_dim: int

  def setup(self):
    self.layer_norm = nn.LayerNorm()
    self.input_layer = nn.Dense(self.hidden_dim, use_bias=False)
    self.output_layer = nn.Dense(self.input_dim, use_bias=False)

  def __call__(self, input):
    x = input
    x = self.input_layer(x)
    x = self.output_layer(x)
    x = self.layer_norm(x + input)
    return x

#### MultiHeadAttention Layer

In [19]:
class MultiHeadAttnLayer(nn.Module):
  """A Multi-Head attention layer.

  This is the most explicit translation of the paper.
  """
  output_dim: int
  num_heads: int  # also notated as h.

  def setup(self):
    if self.output_dim % (self.num_heads) != 0:
      raise ValueError(
          'output_dim for a MultiHeadAttnLayer must be multiples of '
          'num_heads.')

    # Projection layers:
    # output dimension k_q of each connected layer is
    # (total output dimension) / num_heads
    # there are 3 * num_heads of such layers, because for each head,
    # we need to apply such layer to q, k and v matrices. The q, k, v matrices
    # are of dimension (batch_size, sequence_length, embed_dim), therefore the
    # the weight matrix of each layer needs to be (embed_dim, k_v). And there
    # are 3 * num_heads of them.

    # 1) [0, num_heads) are applied to q
    # 2) [num_heads, num_heads * 2) are applied to k
    # 3) [2 * num_heads, 3 * num_heads) are applied to v
    self.qkv_projs = [nn.Dense(self.output_dim // self.num_heads,
                               kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
                               bias_init=nn.initializers.zeros) for _ in range(3 * self.num_heads)]

    self.layer_norm = nn.LayerNorm()


  def __call__(self,
               x, y,
               mask: jnp.ndarray | None = None):
    batch_size, sequence_length, embed_dim = x.shape  # embed_dim is also d_{model}

    # after the code block below, the dimension of Q, K, V becomes
    # [batch_size, T, output_dim / (3 * num_heads), 3]
    # q, k and v are matrices of dimension (num_heads, batch_size, T, k_q).
    # num_heads is the new dimension after stacking.
    q = jnp.stack(tuple(self.qkv_projs[i](x) for i in range(self.num_heads)))
    k = jnp.stack(tuple(self.qkv_projs[i](y) for i in range(self.num_heads, 2 * self.num_heads)))
    v = jnp.stack(tuple(self.qkv_projs[i](y) for i in range(2 * self.num_heads, 3 * self.num_heads)))

    # values will be a tensor of dimension (num_heads, batch_size, T, k_q)
    values = scaled_dot_product_attention(q, k, v, mask=mask)

    # concatenate the values of different heads
    # values will be a tensor of dimension (batch_size, T, embed_dim)
    values = jnp.moveaxis(values, 0, -1)
    values = values.reshape(batch_size, sequence_length, -1)

    # skip-add operation
    values = self.layer_norm(values + x)

    return values

#### test MultiHeadAttention Layer

In [27]:
# The three dimensions correspond to (batch_size, sequence_length, input_dim/output_dim)
x = random.normal(random.PRNGKey(42), (3, 128, 512))
y = random.normal(random.PRNGKey(0), (3, 128, 512))

In [28]:
multi_head_attention = MultiHeadAttnLayer(output_dim=512, num_heads=8)

In [29]:
params = multi_head_attention.init(random.PRNGKey(42), x, x)

In [30]:
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes

{'params': {'layer_norm': {'bias': (512,), 'scale': (512,)},
  'qkv_projs_0': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_1': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_10': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_11': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_12': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_13': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_14': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_15': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_16': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_17': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_18': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_19': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_2': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_20': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_21': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_22': {'bias': (64,), 'kernel': (512, 64)},
  'qkv_projs_23': {'bias': (64,), 'kernel': (5

In [31]:
out = multi_head_attention.apply(params, x, x)

In [32]:
out.shape == x.shape

True

#### Encoder Block

In [46]:
class EncoderBlock(nn.Module):
  """Encoder block that contains multi-head self-attention layer and MLP."""
  input_dim: int
  num_heads: int
  def setup(self):
    self.multi_head_attn_layer = MultiHeadAttnLayer(output_dim=self.input_dim,
                                                    num_heads=self.num_heads)

    self.mlp_layer = MLPLayer(input_dim=self.input_dim, hidden_dim=4 * self.input_dim)

  def __call__(self, x):
    values = self.multi_head_attn_layer(x, x)
    output = self.mlp_layer(values)
    return output

#### Encoder

In [47]:
class Encoder(nn.Module):
  """Encoder that repeats the Encoder layer N times."""
  num_layers: int # N
  input_dim: int
  num_heads: int

  def setup(self):
    # TODO: add the embedding layer
    self.encoder_blocks = [
        EncoderBlock(input_dim=self.input_dim, num_heads=self.num_heads) for _ in range(self.num_layers)]

  def __call__(self, x):
    # TODO: add the embedding layer
    for encoder_block in self.encoder_blocks:
      x = encoder_block(x)
    return x

#### Decoder Block

In [22]:
class DecoderBlock(nn.Module):
  """Decoder block that contains Multi-head attention layer, cross-attention layer and MLP."""
  input_dim: int
  num_heads: int

  def setup(self):
    self.multi_head_attn_layer = MultiHeadAttnLayer(
        input_dim=self.input_dim, num_heads=self.num_heads)
    self.cross_attn_layer = MultiHeadAttnLayer(input_dim=self.input_dim,
                                               num_heads=self.num_heads)
    self.mlp_layer = MLPLayer(input_dim=self.input_dim, hidden_dim=self.input_dim * 4)


  def __call__(self, y, x):
    """y is either the initial input of decoder or output of last DecoderBlock.
    x is the output from the Encoder."""
    y = self.multi_head_attn_layer(y, y)
    y = self.cross_attn_layer(y, x)
    y = self.mlp_layer(y)
    return y

#### Decoder

In [23]:
class Decoder(nn.Module):
  num_layers: int
  input_dim: int
  num_heads: int

  def setup(self):
    self.decoder_block = DecoderBlock(input_dim=self.input_dim,
                                      num_heads=self.num_heads)
    self.projection = nn.Dense(features=self.input_dim)
    self.softmax = nn.softmax(features=self.input_dim)

  def __call__(self, x, y, is_train: bool = False):
    for i in range(self.num_layers):
      y = self.decoder_block(y, x)

    return y

#### Transformer

In [None]:
class Transformer(nn.Module):
  num_encoder_block: int
  num_decoder_block: int
  # input_dim is not sequence length, it's the embedding dimension
  input_dim: int
  num_heads: int
  def setup(self):
    self.encoder = Encoder(num_layers=self.num_encoder_block,
                           input_dim=self.input_dim,
                           num_heads=self.num_heads)
    self.decoder = Decoder(num_layers=self.num_decoder_block,
                           input_dim=self.input_dim,
                           num_heads=self.num_heads)

  def __call__(self, x, y):
    x = self.encoder(x)
    output = self.decoder(x, y)
    return output

### Initialize Transformer

In [None]:
transformer = Transformer(num_encoder_block=2,
                          num_decoder_block=3,
                          input_dim=10,
                          num_heads=2)

#### Create a dummy example as the input to initialize the linen module.


In [None]:
rng_params, rng_x, rng_y = random.split(rng_main, 3)

In [None]:
example_x = random.normal(rng_x, (64, 20, 10))
example_y = random.normal(rng_y, (64, 20, 10))

#### Initialize the linen module

In [None]:
params = transformer.init(random.PRNGKey(0), example_x, example_y)['params']

#### Try applying the model on the dummy input and see the output format

In [None]:
out = transformer.apply({'params': params}, example_x, example_y)
print('Out', out.shape)

Out (64, 20, 10)


In [None]:
out

DeviceArray([[[-0.81314814,  1.1046953 ,  0.4776487 , ...,  1.2827947 ,
               -0.8177259 ,  0.68436754],
              [-0.9457575 ,  1.0223197 ,  0.44467473, ...,  1.2053069 ,
               -0.87547475,  0.7679628 ],
              [ 0.8291353 , -1.0935264 , -0.3079091 , ..., -1.2200329 ,
                0.7818424 , -0.8695053 ],
              ...,
              [-0.94948304,  1.0248268 ,  0.45207748, ...,  1.207757  ,
               -0.868755  ,  0.7732253 ],
              [ 0.72673786, -1.1195886 , -0.19017659, ..., -1.2199332 ,
                0.7187557 , -0.9032288 ],
              [-0.9352995 ,  1.0279328 ,  0.44403443, ...,  1.2102603 ,
               -0.87217087,  0.7617321 ]],

             [[-0.63446015,  1.1582543 ,  0.1800934 , ...,  1.2679591 ,
               -0.7362134 ,  0.84023964],
              [ 0.76549387, -1.1043112 , -0.28653726, ..., -1.2490824 ,
                0.7787767 , -0.8253995 ],
              [-0.5244601 ,  1.173768  ,  0.05919825, ...,  1.26000

In [None]:
transformer.init(random.PRNGKey(0), example_x, example_y).keys()

frozen_dict_keys(['params'])

### Training functions

In [None]:
from optax._src.base import Updates

def predict_logits(params, input_sequences):
  # TODO: The second input parameter should not be input_sequences
  return transformer.apply({'params': params}, input_sequences, input_sequences)


def loss_and_accuracy(params, input_sequences, targets):
  """Computes ce loss."""
  logits = predict_logits(params, input_sequences)
  vocab_size = logits.shape[-1]
  target_onehot = jax.nn.one_hot(n_targets, num_classes=vocab_size)
  loss = optax.softmax_cross_entropy(logits, target_onehot).mean()
  accuracy = (logits.argmax(axis=-1) == targets).astype(jnp.float32).mean()
  return loss, accuracy

@jax.jit
def train_step(params, opt_state, batch):
  _, targets = batch
  input_sequences = batch_to_input(batch)
  # Calculate loss value and its gradients by the value_and_grad function
  loss_fn = lambda params: loss_and_accuracy(params, input_sequences, targets)
  ret, grads = jax.value_and_grad(
      loss_fn,
      has_aux=True)(params)
  loss, acc = ret[0], ret[1]
  # Update the parameters
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss, acc

def batch_to_input(batch):
  inp_data, _ = batch
  # There are 10 digits (0, 1, ..., 9) and therefore when we make the "embedding",
  # num_classes is 10.
  inp_data = jax.nn.one_hot(inp_data, num_classes=10)
  return inp_data

def train_epoch(train_loader, epoch_idx: int, opt_state, params):
  accs, losses = [], []
  for batch in train_loader:
    params, opt_state, loss, accuracy = train_step(params, opt_state, batch)
    losses.append(loss)
    accs.append(accuracy)
  avg_loss = np.stack(jax.device_get(losses)).mean()
  avg_acc = np.stack(jax.device_get(accs)).mean()

def train_model(train_loader, val_loader, opt_state, params, num_epochs: int = 2):
  # Train model for defined number of epochs
  # best_acc = 0.0
  for epoch_idx in range(1, num_epochs+1):
    train_epoch(train_loader, epoch_idx=epoch_idx, opt_state=opt_state,
                params=params)
    # if epoch_idx % 5 == 0:
    #   eval_acc = eval_model(val_loader)
    #   logger.add_scalar('val/accuracy', eval_acc, global_step=epoch_idx)
    #   if eval_acc >= best_acc:
    #     best_acc = eval_acc
    #     save_model(step=epoch_idx)
    #   self.logger.flush()

NameError: ignored

# Preparing the dataset

## Dataset 1 Reversed Sequence

In [None]:
# Make a map-style PyTorch dataset
# (see more documentation at https://pytorch.org/docs/stable/data.html)
class ReverseDataset(data.Dataset):

    def __init__(self,
                 num_categories: int,
                 seq_len: int,
                 size: int,
                 np_rng: ...):
        super().__init__()
        self.num_categories = num_categories
        self.seq_len = seq_len
        self.size = size
        self.np_rng = np_rng

        self.data = self.np_rng.integers(self.num_categories, size=(self.size, self.seq_len))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        input_data = self.data[idx]
        labels = np.flip(input_data, axis=0)
        return input_data, labels

In [None]:
# Combine batch elements (all numpy) by stacking
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

dataset = partial(ReverseDataset, 20, 10)
rev_train_loader = data.DataLoader(dataset(50000, np_rng=np.random.default_rng(42)),
                                   batch_size=64,
                                   shuffle=True,
                                   drop_last=True,
                                   collate_fn=numpy_collate)
rev_val_loader   = data.DataLoader(dataset(1000, np_rng=np.random.default_rng(43)),
                                   batch_size=64,
                                   collate_fn=numpy_collate)
rev_test_loader  = data.DataLoader(dataset(10000, np_rng=np.random.default_rng(44)),
                                   batch_size=64,
                                   collate_fn=numpy_collate)

In [None]:
inp_data, labels = rev_train_loader.dataset[0]
print("Input data:", inp_data)
print("Labels:    ", labels)

Input data: [ 1 15 13  8  8 17  1 13  4  1]
Labels:     [ 1  4 13  1 17  8  8 13 15  1]


# Training

## Training loop

In [None]:
optimizer = optax.adam(learning_rate=1e-2)

In [None]:
opt_state = optimizer.init(params)

In [None]:
train_model(rev_train_loader, rev_val_loader, params=params, opt_state=opt_state)