<a href="https://colab.research.google.com/github/zhenyiqi/rawLLM/blob/main/TransformerJax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
import time
import math
from functools import partial

In [None]:
import numpy as np

from torch.utils import data
from torchvision.datasets import MNIST

In [None]:
import jax.numpy as jnp
from jax.scipy.special import logsumexp

from jax import grad, jit, vmap
from jax import random
import jax
from flax.training import train_state, checkpoints

In [None]:
! pip install flax --quiet

[?25l[K     |█▋                              | 10 kB 20.2 MB/s eta 0:00:01[K     |███▎                            | 20 kB 6.9 MB/s eta 0:00:01[K     |█████                           | 30 kB 9.5 MB/s eta 0:00:01[K     |██████▋                         | 40 kB 4.2 MB/s eta 0:00:01[K     |████████▎                       | 51 kB 4.3 MB/s eta 0:00:01[K     |██████████                      | 61 kB 5.1 MB/s eta 0:00:01[K     |███████████▋                    | 71 kB 5.6 MB/s eta 0:00:01[K     |█████████████▎                  | 81 kB 5.3 MB/s eta 0:00:01[K     |███████████████                 | 92 kB 5.9 MB/s eta 0:00:01[K     |████████████████▋               | 102 kB 4.8 MB/s eta 0:00:01[K     |██████████████████▎             | 112 kB 4.8 MB/s eta 0:00:01[K     |████████████████████            | 122 kB 4.8 MB/s eta 0:00:01[K     |█████████████████████▋          | 133 kB 4.8 MB/s eta 0:00:01[K     |███████████████████████▎        | 143 kB 4.8 MB/s eta 0:00:01[K    

In [None]:
import flax
from flax import linen as nn

In [None]:
## Optax (Optimizers in JAX)
try:
    import optax
except ModuleNotFoundError: # Install optax if missing
    !pip install --quiet optax
    import optax

# Define constants

In [None]:
# Root PRNG key; it is split further down into parameter-init and dummy-input keys.
rng_main = random.PRNGKey(seed=0)

In [None]:
# Hyperparameters.
# NOTE(review): these look carried over from an MNIST MLP setup (784 inputs,
# 10 classes); confirm which ones the transformer code below still relies on.
batch_size, n_targets, num_epochs = 16, 10, 5

layer_sizes = [784, 512, 512, 10]  # consecutive layer widths for init_network_params
step_size = 0.01

# Training Prep

In [None]:
# NOTE(review): `rng` is not referenced by any later cell in this notebook.
rng = random.PRNGKey(seed=1)

## Utilities for initializing parameters

In [None]:
from traitlets.traitlets import Tuple

def random_layer_params(input_dim, output_dim, key, scale=1e-2):
  """Randomly initializes the weights and bias of one dense layer.

  Args:
    input_dim: number of input features.
    output_dim: number of output features.
    key: JAX PRNG key; split into one key for the weights and one for the bias.
    scale: multiplier applied to the standard-normal samples.

  Returns:
    (weights, bias) with shapes (output_dim, input_dim) and (output_dim,).
  """
  weight_key, bias_key = random.split(key)
  weights = random.normal(weight_key, (output_dim, input_dim))
  bias = random.normal(bias_key, (output_dim,))
  return scale * weights, scale * bias

def init_network_params(sizes, key: ...):
  """Initializes every layer of a fully-connected network.

  Args:
    sizes: layer widths; layer i maps sizes[i] features to sizes[i + 1].
    key: JAX PRNG key, split into len(sizes) sub-keys (the last one is unused).

  Returns:
    A list of (weights, bias) tuples, one per consecutive pair in `sizes`.
  """
  layer_keys = random.split(key, len(sizes))
  params = []
  for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
    params.append(random_layer_params(fan_in, fan_out, layer_key))
  return params

## Core training utils

### Build the model

#### Scaled dot-product attention

In [None]:
def scaled_dot_product_attention(q, k, v, mask = None):
  """Lets every query in `q` attend over the key/value pairs (`k`, `v`).

  The inputs have shape (..., T, d): any number of leading axes (e.g. batch
  and/or heads), then the sequence axis T and the feature axis d. Queries and
  keys must share the same d.

  Args:
    q: queries, (..., T_q, d).
    k: keys, (..., T_k, d).
    v: values, (..., T_k, d_v).
    mask: optional array broadcastable to (..., T_q, T_k); wherever
      mask == 0 the logit is replaced by -9e15, which drives that attention
      weight to (numerically) zero.

  Returns:
    The attention-weighted sum of `v`, shape (..., T_q, d_v).
  """
  depth = q.shape[-1]
  # Similarity of every query with every key, scaled by sqrt(depth).
  scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(depth)
  masked_scores = scores if mask is None else jnp.where(mask == 0, -9e15, scores)
  weights = nn.softmax(masked_scores, axis=-1)
  return jnp.matmul(weights, v)

#### MLP Layer

In [None]:
def relu(x):
  """Elementwise rectified linear unit: max(x, 0)."""
  return jnp.maximum(x, 0)

class MLPLayer(nn.Module):
  """Stack of Dense layers with ReLU between them, followed by LayerNorm.

  Attributes:
    hidden_dim: width of every Dense layer, and therefore of the output.
    num_hidden: number of Dense layers in the stack.
  """
  hidden_dim: int
  num_hidden: int

  def setup(self):
    self.layer_norm = nn.LayerNorm()
    self.dense_layers = [nn.Dense(
        self.hidden_dim) for _ in range(self.num_hidden)]

  def __call__(self, input):
    """Applies the MLP to `input` along its last axis.

    ReLU is applied between consecutive Dense layers but not after the last
    one. Without a nonlinearity, the stacked Dense layers would collapse into
    a single affine map.

    NOTE(review): unlike MultiHeadAttnLayer, there is no residual add before
    the LayerNorm here; confirm that is intended.
    """
    x = input
    last_idx = self.num_hidden - 1
    for idx, layer in enumerate(self.dense_layers):
      x = layer(x)
      if idx < last_idx:
        x = relu(x)
    x = self.layer_norm(x)
    return x

#### MultiHeadAttention Layer

In [None]:
class MultiHeadAttnLayer(nn.Module):
  """Multi-head scaled dot-product attention, residual add, then LayerNorm.

  Attributes:
    output_dim: total projection width across all heads; each head projects
      to output_dim // num_heads features. Must be divisible by num_heads.
    num_heads: number of attention heads.
  """
  output_dim: int  # output_dim == embedding_dim
  num_heads: int = 1

  def setup(self):
    if self.output_dim % (self.num_heads) != 0:
      raise ValueError(
          'output_dim for a MultiHeadAttnLayer must be multiples of '
          'num_heads.')

    # Projection layers:
    # 1) [0, num_heads) are applied to q
    # 2) [num_heads, num_heads * 2) are applied to k
    # 3) [2 *num_heads, 3 * num_heads) are applied to v
    self.qkv_projs = [nn.Dense(self.output_dim // self.num_heads,
                               kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
                               bias_init=nn.initializers.zeros) for _ in range(3 * self.num_heads)]

    self.layer_norm = nn.LayerNorm()

  def _project_heads(self, x, offset):
    """Applies projections [offset, offset + num_heads) to `x`.

    Maps (..., T, embed_dim) to (..., num_heads, T, output_dim // num_heads),
    keeping the head axis separate from the batch and sequence axes.
    """
    return jnp.stack(
        [self.qkv_projs[offset + h](x) for h in range(self.num_heads)],
        axis=-3)

  def _merge_heads(self, x):
    """Maps (..., num_heads, T, head_dim) to (..., T, num_heads * head_dim)."""
    x = jnp.swapaxes(x, -3, -2)
    return x.reshape(x.shape[:-2] + (-1,))

  def __call__(self, q, k, v, mask = None):
    """Attends from `q` over (`k`, `v`) and normalizes the result.

    Args:
      q: queries, (..., T_q, embed_dim).
      k: keys, (..., T_k, embed_dim).
      v: values, (..., T_k, embed_dim). T_k may differ from T_q (cross-attention).
      mask: optional; positions where mask == 0 are not attended to. A mask
        with the same rank as `q`, i.e. (..., T_q, T_k), is broadcast over heads.

    Returns:
      (..., T_q, output_dim).
    """
    q_heads = self._project_heads(q, 0)
    k_heads = self._project_heads(k, self.num_heads)
    # The value projections must be applied to `v`, not to the keys.
    v_heads = self._project_heads(v, 2 * self.num_heads)

    if mask is not None and jnp.ndim(mask) == jnp.ndim(q):
      # Insert the head axis so the mask lines up with (..., H, T_q, T_k) logits.
      mask = jnp.expand_dims(mask, axis=-3)

    # Each head attends independently: (..., num_heads, T_q, head_dim).
    values = scaled_dot_product_attention(q_heads, k_heads, v_heads, mask=mask)
    values = self._merge_heads(values)
    # skip-add operation (on the projected queries, as before).
    # NOTE(review): a standard post-LN transformer adds the un-projected input.
    values = values + self._merge_heads(q_heads)
    values = self.layer_norm(values)

    return values

#### Encoder Block

In [None]:

class EncoderBlock(nn.Module):
  """One encoder layer: attention over (q, k, v) followed by an MLP.

  Attributes:
    input_dim: embedding width; used as the attention output_dim and as the
      MLP hidden width.
  """
  input_dim: int

  def setup(self):
    self.multi_head_attn_layer = MultiHeadAttnLayer(output_dim=self.input_dim)
    self.mlp_layer = MLPLayer(hidden_dim=self.input_dim, num_hidden=5)

  def __call__(self, q, k, v):
    """Runs attention, then feeds the attended values through the MLP."""
    attended = self.multi_head_attn_layer(q, k, v)
    return self.mlp_layer(attended)

#### Encoder

In [None]:
class Encoder(nn.Module):
  """Stack of `num_encoder_block` self-attention EncoderBlocks.

  Attributes:
    num_encoder_block: number of encoder blocks; each has its own parameters.
    input_dim: embedding width used by every block.
  """
  num_encoder_block: int
  input_dim: int

  def setup(self):
    # One independent block per layer. Reusing a single instance, as before,
    # would tie the weights of every layer together.
    self.encoder_blocks = [EncoderBlock(input_dim=self.input_dim)
                           for _ in range(self.num_encoder_block)]

  def __call__(self, x):
    """Applies each block as self-attention (x used as q, k and v)."""
    for block in self.encoder_blocks:
      x = block(x, x, x)
    return x

#### Decoder Block

In [None]:

class DecoderBlock(nn.Module):
  """One decoder layer: self-attention, cross-attention, then an MLP.

  Attributes:
    input_dim: embedding width shared by both attention layers and the MLP.
  """
  input_dim: int

  def setup(self):
    self.multi_head_attn_layer = MultiHeadAttnLayer(output_dim=self.input_dim)
    self.cross_attn_layer = MultiHeadAttnLayer(output_dim=self.input_dim)
    self.mlp_layer = MLPLayer(hidden_dim=self.input_dim, num_hidden=5)

  def __call__(self, y, x):
    """Decodes `y` while attending to the encoder output `x`.

    Args:
      y: the initial decoder input, or the output of the previous DecoderBlock.
      x: the Encoder output; used as both keys and values for cross-attention.
    """
    queries = self.multi_head_attn_layer(y, y, y)
    attended = self.cross_attn_layer(queries, x, x)
    return self.mlp_layer(attended)

#### Decoder

In [None]:
class Decoder(nn.Module):
  """Stack of `num_decoder_block` DecoderBlocks followed by a Dense projection.

  Attributes:
    num_decoder_block: number of decoder blocks; each has its own parameters.
    input_dim: embedding width used inside every decoder block.
    output_dim: width of the final projection (the model's output features).
  """
  num_decoder_block: int
  input_dim: int
  output_dim: int

  def setup(self):
    # One independent block per layer; a single shared instance would tie the
    # weights of all layers together.
    self.decoder_blocks = [DecoderBlock(input_dim=self.input_dim)
                           for _ in range(self.num_decoder_block)]
    self.projection = nn.Dense(features=self.output_dim)

  def __call__(self, x, y, is_train: bool = False):
    """Decodes `y` while attending to the encoder output `x`.

    Args:
      x: encoder output, used as keys/values for cross-attention.
      y: decoder input.
      is_train: accepted for backward compatibility; it does not change the
        computation (every decoder block always runs).

    Returns:
      The decoder output projected to `output_dim` features.
    """
    for block in self.decoder_blocks:
      y = block(y, x)
    return self.projection(y)

#### Transformer

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder model: `x` is encoded, then `y` is decoded against it.

  Attributes:
    num_encoder_block: number of encoder blocks.
    num_decoder_block: number of decoder blocks.
    input_dim: embedding dimension (not the sequence length).
    output_dim: output feature width of the decoder.
  """
  num_encoder_block: int
  num_decoder_block: int
  # input_dim is not sequence length, it's the embedding dimension
  input_dim: int
  output_dim: int

  def setup(self):
    self.encoder = Encoder(input_dim=self.input_dim,
                           num_encoder_block=self.num_encoder_block)
    self.decoder = Decoder(input_dim=self.input_dim,
                           output_dim=self.output_dim,
                           num_decoder_block=self.num_decoder_block)

  def __call__(self, x, y):
    encoded = self.encoder(x)
    return self.decoder(encoded, y)

### Initialize Transformer

In [None]:
# Two encoder blocks and three decoder blocks; input_dim is the embedding
# width and output_dim the decoder's output width.
transformer = Transformer(
    num_encoder_block=2,
    num_decoder_block=3,
    input_dim=10,
    output_dim=10,
)

#### Create a dummy example as the input to initialize the linen module.


In [None]:
# Independent keys for parameter init and the two dummy inputs.
rng_params, rng_x, rng_y = random.split(rng_main, num=3)

In [None]:
# Dummy encoder/decoder inputs shaped (batch, sequence length, embedding dim).
example_x = random.normal(rng_x, shape=(64, 20, 10))
example_y = random.normal(rng_y, shape=(64, 20, 10))

#### Initialize the linen module

In [None]:
# Initialize with the key split off for this purpose (rng_params) instead of a
# separate hard-coded PRNGKey(0), which left rng_params unused.
params = transformer.init(rng_params, example_x, example_y)['params']

#### Try applying the model on the dummy input and see the output format

In [None]:
# Forward pass on the dummy inputs; only the output shape is reported.
out = transformer.apply({'params': params}, example_x, example_y)
print(f'Out {out.shape}')

Out (64, 20, 10)


In [None]:
# Preview a small slice instead of dumping the whole output array.
out[0, :3]

DeviceArray([[[-0.81314814,  1.1046953 ,  0.4776487 , ...,  1.2827947 ,
               -0.8177259 ,  0.68436754],
              [-0.9457575 ,  1.0223197 ,  0.44467473, ...,  1.2053069 ,
               -0.87547475,  0.7679628 ],
              [ 0.8291353 , -1.0935264 , -0.3079091 , ..., -1.2200329 ,
                0.7818424 , -0.8695053 ],
              ...,
              [-0.94948304,  1.0248268 ,  0.45207748, ...,  1.207757  ,
               -0.868755  ,  0.7732253 ],
              [ 0.72673786, -1.1195886 , -0.19017659, ..., -1.2199332 ,
                0.7187557 , -0.9032288 ],
              [-0.9352995 ,  1.0279328 ,  0.44403443, ...,  1.2102603 ,
               -0.87217087,  0.7617321 ]],

             [[-0.63446015,  1.1582543 ,  0.1800934 , ...,  1.2679591 ,
               -0.7362134 ,  0.84023964],
              [ 0.76549387, -1.1043112 , -0.28653726, ..., -1.2490824 ,
                0.7787767 , -0.8253995 ],
              [-0.5244601 ,  1.173768  ,  0.05919825, ...,  1.26000

In [None]:
# Inspect which variable collections the model creates; reuse the dedicated
# init key rather than a hard-coded one.
transformer.init(rng_params, example_x, example_y).keys()

frozen_dict_keys(['params'])

### Training functions

In [None]:
from optax._src.base import Updates

def predict_logits(params, input_sequences):
  """Runs `transformer` on a batch and returns its raw outputs (logits).

  The same tensor is currently fed as both the encoder input and the decoder
  input.
  """
  # TODO: The second input parameter should not be input_sequences
  variables = {'params': params}
  encoder_input = decoder_input = input_sequences
  return transformer.apply(variables, encoder_input, decoder_input)


def loss_and_accuracy(params, input_sequences, targets):
  """Computes mean softmax cross-entropy loss and accuracy of the predictions.

  Args:
    params: Flax parameter tree for `transformer`.
    input_sequences: model input (in train_step, the output of batch_to_input).
    targets: integer class ids, one per output position, i.e. the same shape
      as the logits without their last axis.

  Returns:
    (loss, accuracy): the mean cross-entropy, and the fraction of positions
    where the argmax prediction equals the target.
  """
  logits = predict_logits(params, input_sequences)
  vocab_size = logits.shape[-1]
  # One-hot encode the per-position *targets*. Encoding a constant here would
  # make the loss independent of the labels. Ids outside [0, vocab_size)
  # encode to all-zero rows.
  target_onehot = jax.nn.one_hot(targets, num_classes=vocab_size)
  loss = optax.softmax_cross_entropy(logits, target_onehot).mean()
  accuracy = (logits.argmax(axis=-1) == targets).astype(jnp.float32).mean()
  return loss, accuracy

@jax.jit
def train_step(params, opt_state, batch):
  """Performs one jitted optimizer update on an (inputs, targets) batch.

  NOTE: `optimizer` is read from module scope when the function is traced.

  Returns:
    (params, opt_state, loss, accuracy): the updated parameters and optimizer
    state, plus the loss/accuracy computed with the pre-update parameters.
  """
  _, targets = batch
  model_input = batch_to_input(batch)

  def objective(p):
    # Returns (loss, accuracy); only the loss is differentiated.
    return loss_and_accuracy(p, model_input, targets)

  (loss, acc), grads = jax.value_and_grad(objective, has_aux=True)(params)
  updates, new_opt_state = optimizer.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state, loss, acc

def batch_to_input(batch):
  """One-hot encodes the input tokens of an (inputs, labels) batch.

  Each token id becomes a 10-wide one-hot vector; this serves as the
  "embedding". Ids outside [0, 10) become all-zero vectors.
  """
  tokens, _labels = batch
  return jax.nn.one_hot(tokens, num_classes=10)

def train_epoch(train_loader, epoch_idx: int, opt_state, params):
  """Runs one pass over `train_loader`, updating params and optimizer state.

  Args:
    train_loader: iterable of batches accepted by `train_step`.
    epoch_idx: epoch number (1-based when called from train_model); not used
      in the computation.
    opt_state: optimizer state at the start of the epoch.
    params: model parameters at the start of the epoch.

  Returns:
    (params, opt_state, avg_loss, avg_acc): the state after the last batch,
    plus the mean per-batch loss and accuracy (NaN if the loader is empty).
  """
  accs, losses = [], []
  for batch in train_loader:
    params, opt_state, loss, accuracy = train_step(params, opt_state, batch)
    losses.append(loss)
    accs.append(accuracy)
  if not losses:
    # np.stack([]) raises; report NaN metrics instead of crashing.
    return params, opt_state, float('nan'), float('nan')
  avg_loss = float(np.stack(jax.device_get(losses)).mean())
  avg_acc = float(np.stack(jax.device_get(accs)).mean())
  return params, opt_state, avg_loss, avg_acc


def train_model(train_loader, val_loader, opt_state, params, num_epochs: int = 2):
  """Trains for `num_epochs` epochs, carrying params/opt_state across epochs.

  Args:
    train_loader: iterable of training batches.
    val_loader: currently unused; evaluation is not implemented yet.
    opt_state: initial optimizer state.
    params: initial model parameters.
    num_epochs: number of passes over `train_loader`.

  Returns:
    (params, opt_state) after the final epoch.
  """
  for epoch_idx in range(1, num_epochs + 1):
    params, opt_state, avg_loss, avg_acc = train_epoch(
        train_loader, epoch_idx=epoch_idx, opt_state=opt_state, params=params)
    print(f'Epoch {epoch_idx}: loss={avg_loss:.4f}, accuracy={avg_acc:.4f}')
  # TODO: periodically evaluate on val_loader and checkpoint the best model.
  return params, opt_state

NameError: ignored

# Preparing the dataset

## Dataset 1: Reversed sequence

In [None]:
# Make a map-style PyTorch dataset
# (see more documentation at https://pytorch.org/docs/stable/data.html)
class ReverseDataset(data.Dataset):
    """Map-style dataset of random integer sequences labelled by their reversal.

    All `size` sequences are drawn once at construction time from `np_rng`:
    each has `seq_len` integers in [0, num_categories).
    """

    def __init__(self,
                 num_categories: int,
                 seq_len: int,
                 size: int,
                 np_rng: ...):
        super().__init__()
        self.num_categories = num_categories
        self.seq_len = seq_len
        self.size = size
        self.np_rng = np_rng

        sample_shape = (size, seq_len)
        self.data = np_rng.integers(num_categories, size=sample_shape)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        sequence = self.data[idx]
        # The label is the same sequence read back to front.
        return sequence, sequence[::-1]

In [None]:
# Combine batch elements (all numpy) by stacking
def numpy_collate(batch):
    """Collates a list of samples into numpy arrays (DataLoader collate_fn).

    Arrays are stacked along a new leading axis. Tuples/lists are collated
    field by field, and the fields are returned as a list. Anything else goes
    through np.array.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)

# num_categories must match the model's vocabulary. batch_to_input one-hot
# encodes tokens with num_classes=10, and the Transformer is built with
# input_dim=output_dim=10. Tokens >= 10 would become all-zero inputs and
# targets that can never be predicted. seq_len=20 matches the (64, 20, 10)
# dummy inputs used to initialize the model; the positional arguments
# (20, 10) appeared to be swapped.
dataset = partial(ReverseDataset, num_categories=10, seq_len=20)
rev_train_loader = data.DataLoader(dataset(size=50000, np_rng=np.random.default_rng(42)),
                                   batch_size=64,
                                   shuffle=True,
                                   drop_last=True,
                                   collate_fn=numpy_collate)
rev_val_loader   = data.DataLoader(dataset(size=1000, np_rng=np.random.default_rng(43)),
                                   batch_size=64,
                                   collate_fn=numpy_collate)
rev_test_loader  = data.DataLoader(dataset(size=10000, np_rng=np.random.default_rng(44)),
                                   batch_size=64,
                                   collate_fn=numpy_collate)

In [None]:
# Peek at the first training example and its reversed label.
inp_data, labels = rev_train_loader.dataset[0]
print(f"Input data: {inp_data}")
print(f"Labels:     {labels}")

Input data: [ 1 15 13  8  8 17  1 13  4  1]
Labels:     [ 1  4 13  1 17  8  8 13 15  1]


# Training

## Training loop

In [None]:
# Adam with a constant learning rate of 1e-2 (first positional argument).
optimizer = optax.adam(1e-2)

In [None]:
# Optimizer state (Adam moments etc.) shaped like the model parameters.
opt_state = optimizer.init(params=params)

In [None]:
# Train for the configured `num_epochs` (defined with the other constants).
# The result is assigned so that a potentially large return value is not
# echoed as cell output.
trained_state = train_model(rev_train_loader, rev_val_loader, params=params,
                            opt_state=opt_state, num_epochs=num_epochs)