This notebook demonstrates how to use JAX/Flax for LLM pretraining via data and tensor parallelism. It is a scaled-up version of the miniGPT tutorial.

We will use TPU and [SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data) to train a bigger language model. You can use free TPU on Colab or Kaggle.

## Setup

Install JAX and Flax first. Confirm we have TPUs set up.

In [1]:
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -U flax orbax tiktoken wandb datasets

import jax, os
jax.devices()
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html


Get the SlimPajama-6B dataset from HF.

In [2]:
if not os.path.exists('/content/drive'):
  from google.colab import drive
  drive.mount('/content/drive')

REPO_ID = "DKYoon/SlimPajama-6B"
# Cache the dataset (24GB) to gDrive; make sure your gDrive is big enough
CACHE_DIR = "/content/drive/MyDrive/LLM-pretraining/SlimPajama-6B"
os.environ["HF_DATASETS_CACHE"] = "/content/drive/MyDrive/LLM-pretraining/.cache"

# check if CACHE_DIR exists
if not os.path.exists(CACHE_DIR):
    !pip install huggingface_hub
    from huggingface_hub import snapshot_download
    snapshot_download(repo_id="DKYoon/SlimPajama-6B", local_dir=CACHE_DIR, repo_type='dataset')

# Get the [TinyStories dataset](https://www.kaggle.com/datasets/thedevastator/tinystories-narrative-classification) from Kaggle.
# !pip install kaggle
# from google.colab import userdata
# import os
# os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
# os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

# if not os.path.exists("TinyStories.csv"):
#   !kaggle datasets download -d thedevastator/tinystories-narrative-classification && unzip tinystories-narrative-classification.zip && ln -s train.csv TinyStories.csv

Take care of the imports.

In [3]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax, orbax.checkpoint as orbax
from typing import Any
import os
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import tiktoken, time
# from datasets import load_dataset
from functools import partial
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

## Build the model

Define the mesh. We are going to use 4-way data parallel and 2-way tensor parallel. See this [JAX tutorial](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#way-batch-data-parallelism-and-2-way-model-tensor-parallelism) for more information.

In [4]:
mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
# mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

Defne the model architecture. We are going to the GPT-2 tokenizer via [Tiktoken](https://github.com/openai/tiktoken) and we shard model layers along the `model` axis.

In [5]:
tokenizer = tiktoken.get_encoding("gpt2")

def causal_attention_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len)))

class TransformerBlock(nnx.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1):
        self.mha = nnx.MultiHeadAttention(num_heads=num_heads,
                                          in_features=embed_dim,
                                          kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                          bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                          rngs=rngs)
        self.dropout1 = nnx.Dropout(rate=rate)
        self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                         rngs=rngs)
        self.linear1 = nnx.Linear(in_features=embed_dim,
                                  out_features=ff_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                  rngs=rngs)
        self.linear2 = nnx.Linear(in_features=ff_dim,
                                  out_features=embed_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                  rngs=rngs)
        self.dropout2 = nnx.Dropout(rate=rate)
        self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),
                                         rngs=rngs)


    def __call__(self, inputs, training: bool = False):
        input_shape = inputs.shape
        _, seq_len, _ = input_shape

        # Create causal mask
        mask = causal_attention_mask(seq_len)

        # Apply MultiHeadAttention with causal mask
        attention_output = self.mha(
            inputs_q=inputs,
            mask=mask,
            decode=False
        )
        attention_output = self.dropout1(attention_output, deterministic=not training)
        out1 = self.layer_norm1(inputs + attention_output)

        # Feed-forward network
        ffn_output = self.linear1(out1)
        ffn_output = nnx.relu(ffn_output)
        ffn_output = self.linear2(ffn_output)
        ffn_output = self.dropout2(ffn_output, deterministic=not training)

        return self.layer_norm2(out1 + ffn_output)


class TokenAndPositionEmbedding(nnx.Module):

    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)
        self.pos_emb = nnx.Embed(num_embeddings=maxlen, features=embed_dim, rngs=rngs)

    def __call__(self, x):
        positions = jnp.arange(0, x.shape[1])[None, :]
        position_embedding = self.pos_emb(positions)
        token_embedding = self.token_emb(x)
        return token_embedding + position_embedding


class MiniGPT(nnx.Module):
    def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
        self.embedding_layer = TokenAndPositionEmbedding(
                    maxlen, vocab_size, embed_dim, rngs=rngs
                )
        self.transformer_blocks = [TransformerBlock(
            embed_dim, num_heads, feed_forward_dim, rngs=rngs
        ) for _ in range(num_transformer_blocks)]

        self.output_layer = nnx.Linear(in_features=embed_dim,
                                       out_features=vocab_size,
                                       kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                       bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),
                                       rngs=rngs)

    def __call__(self, inputs, training: bool = False):
        x = self.embedding_layer(inputs)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)
        outputs = self.output_layer(x)
        return outputs

def create_model(rngs):
    return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks=4, rngs=rngs)


Set some hyperparameters. The model is much bigger with a lot more transformer layers and attention heads.

In [6]:
# vocab_size = tokenizer.n_vocab
# num_transformer_blocks = 12
# maxlen = 200
# embed_dim = 512
# num_heads = 8
# feed_forward_dim = 512
# batch_size = 304
# num_epochs = 1
# learning_rate = 5e-4


# TPU v2
vocab_size = tokenizer.n_vocab
num_transformer_blocks = 12
maxlen = 200
embed_dim = 512
num_heads = 8
feed_forward_dim = 512
batch_size = 600
num_epochs = 2
init_learning_rate = 5e-4

In [7]:
from google.colab import userdata
os.environ["WANDB_API_KEY"] = userdata.get('WANDB_API_KEY')
!wandb login

import wandb
wandb.init(
    # set the wandb project where this run will be logged
    project="LLM-pretraining",

    # track hyperparameters and run metadata
    config={
    "architecture": "miniGPT",
    "dataset": "TinyStories",
    "init_learning_rate": init_learning_rate,
    "num_transformer_blocks": num_transformer_blocks,
    "maxlen": maxlen,
    "embed_dim": embed_dim,
    "num_heads": num_heads,
    "feed_forward_dim": feed_forward_dim,
    "batch_size": batch_size,
    "num_epochs": num_epochs
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mwindmaple[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mwindmaple[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Prepare data

Data loading and preprocessing with [Grain](https://github.com/google/grain).

In [8]:
# TODO: make a streaming dataset
class TextDataset(Dataset):
    def __init__(self, file_path, maxlen):
        if filename.lower().endswith('.csv'):
            self.data = pd.read_csv(file_path).dropna()
        elif filename.lower().endswith('.parquet'):
            self.data = pd.read_parquet(file_path).dropna()
        else
            self.data = pd.read_parquet(file_path).dropna()
        self.maxlen = maxlen

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data.iloc[idx]['text']
        encoding = tokenizer.encode(text, allowed_special={'<|endoftext|>'})[:maxlen]
        padding = [0] * (self.maxlen - len(encoding))
        # return np.array(encoding + padding, dtype=np.int32)

        input_sequence = np.array(encoding + padding, dtype=np.int32)
        target_sequence = np.roll(input_sequence, -1)
        target_sequence[-1] = 0  # Set the last token to 0 (padding)

        return input_sequence, target_sequence

def load_and_preprocess_data(file_path, batch_size, maxlen):
    dataset = TextDataset(file_path, maxlen)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=10, prefetch_factor=2)


# dataset = load_and_preprocess_data('TinyStories.csv', batch_size, maxlen)
dataloader = load_and_preprocess_data(os.path.join(CACHE_DIR, 'data/'), batch_size, maxlen)
# dataloader = load_and_preprocess_data(os.path.join(CACHE_DIR, 'data/train-00000-of-00048-ab2b35705f029d94.parquet'), batch_size, maxlen)

## Train the model

Define a helper function for generating text given a model and prompt.

In [9]:
def generate_text(model: MiniGPT, max_tokens: int, start_tokens: [int], top_k=10):
    def sample_from(logits):
        logits, indices = jax.lax.top_k(logits, k=top_k)
        logits = nnx.softmax(logits)
        return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)

    def generate_step(start_tokens):
        pad_len = maxlen - len(start_tokens)
        sample_index = len(start_tokens) - 1
        if pad_len < 0:
            x = jnp.array(start_tokens[:maxlen])
            sample_index = maxlen - 1
        elif pad_len > 0:
            x = jnp.array(start_tokens + [0] * pad_len)
        else:
            x = jnp.array(start_tokens)

        x = x[None, :]
        logits = model(x)
        next_token = sample_from(logits[0][sample_index])
        return next_token

    generated = []
    for _ in range(max_tokens):
        next_token = generate_step(start_tokens + generated)
        generated.append(int(next_token))
    return tokenizer.decode(start_tokens + generated)

Define loss function and training step function.

In [10]:
def loss_fn(model, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss, logits

@nnx.jit
def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, lables=batch[1])
    optimizer.update(grads)

Start training.

In [None]:
model = create_model(rngs=nnx.Rngs(0))
# optimizer = nnx.Optimizer(model, optax.adam(1e-3))

schedule = optax.cosine_decay_schedule(
  init_value=init_learning_rate,
  decay_steps=num_epochs*len(dataloader),
)
optax_chain = optax.chain(
  optax.adamw(learning_rate=schedule),
)
optimizer = nnx.Optimizer(model, optax_chain)

metrics = nnx.MultiMetric(
  loss=nnx.metrics.Average('loss'),
)
rng = jax.random.PRNGKey(0)

start_prompt = "The story begins with a little boy who has a dream"
start_tokens = tokenizer.encode(start_prompt)[:maxlen]
generated_text = generate_text(
    model, maxlen, start_tokens,
)
print(f"Initial generated text:\n{generated_text}")


metrics_history = {
  'train_loss': [],
}

for epoch in range(num_epochs):
    start_time = time.time()

    step = 0
    for input_batch, target_batch in dataloader:
        train_step(model, optimizer, metrics, jax.device_put((np.array(input_batch), np.array(target_batch)), NamedSharding(mesh, P('batch', None))))

        if (step + 1) % 100 == 0:  # Check if the batch index is a multiple of 100
            for metric, value in metrics.compute().items():  # compute metrics
                metrics_history[f'train_{metric}'].append(value)  # record metrics
                wandb.log({metric: value})  # log metrics to wandb
            metrics.reset()

            elapsed_time = time.time() - start_time  # Calculate elapsed time
            print(f"Epoch {epoch + 1}, Step {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds")
            start_time = time.time()

            generated_text = generate_text(
                model, maxlen, start_tokens,
            )
            print(f"Generated text after batch {step + 1}:\n{generated_text}\n")
            !tpu-info
        step += 1


# Final text generation
generated_text = generate_text(
    model, maxlen, start_tokens,
)
print(f"Final generated text:\n{generated_text}")


Initial generated text:
The story begins in the midcentury RAW Bugatomic Transfer Glekar Voltoutegently principle prohibition652clear NEO intermedi barr Volt dischargedPrivgently Jeanne comeONE Rac五 virgingentlygently Jeanne Incident capt indicationTile Riding revisitoux suddenlyprints conventionssoft roughly EuropeansInvestcentral whims GleDay NG NGOs () () roughly roughly Idaho windoux proposes suddenlyocalStretch Firefox五 assistants Jeanne Discordplate Board491 ()ishing Firefoxancer Lift NEO him certify readable diagnostic五brance Jeannecentralium Positionainment652 shield naming REG NGOs terrifyingfoot accommodation Lana triggeringTel PP Volt brainstorm removed incent barrgentlyHR riots wildfire Pv substclear barr Volt Hussain Less substance Firefox precarious HTC Kers Gle fallout Australian ()gently testify Gle Flat incons criteriaMathTel=""五 underwentjustice Yards flash Meow Discord appalled EgOne amuletlustdrivingENCElargest Historic consultation ashamed Dungeon habitgently Commo

  self.pid = os.fork()


Epoch 1, Step 100, Loss: 7.411874771118164, Elapsed Time: 57.59 seconds
Generated text after batch 100:
The story begins in the midcentury..!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

[3mTPU Chips                                      [0m
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃[1m [0m[1mChip       [0m[1m [0m┃[1m [0m[1mType       [0m[1m [0m┃[1m [0m[1mDevices[0m[1m [0m┃[1m [0m[1mPID  [0m[1m [0m┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ /dev/accel0 │ TPU v2 chip │ 2       │ 55695 │
│ /dev/accel1 │ TPU v2 chip │ 2       │ 55695 │
│ /dev/accel2 │ TPU v2 chip │ 2       │ 55695 │
│ /dev/accel3 │ TPU v2 chip │ 2       │ 55695 │
└─────────────┴─────────────┴─────────┴───────┘
Connected to libtpu at grpc://localhost:8431...
[3mTPU Utilization                              [0m
┏━━━━━━━━┳━━━━━━━━━━━━

  self.pid = os.fork()
  self.pid = os.fork()


Epoch 2, Step 100, Loss: 3.584099292755127, Elapsed Time: 34.47 seconds
Generated text after batch 100:
The story begins in the midcentury of the American Revolution, where it is, the people of the American Revolution are born and lived. In the midcentury, the Americans are living with a common belief in the world, the Americans who have the right to live their lives, the people who are living with their lives and their lives. The story begins in the early 1700s, where the Americans who live with the Americans are living with their lives. The American Revolution, which began with the American Revolution, was created in the midcentury.
In the early 1700s, the American Revolution began to grow. In the midcentury, the American Revolution was a small and medium-sized American. In the early 1700s, the Americans had the American Revolution, which began to spread over the years, were created to become a large and powerful force in the United States.
The American Revolution, the American Revol

As you can see, the model goes from generating completely random words at the beginning to generating tiny stories at the end of the training. So essentially we have pretrained a small LLM to write tiny stories for us.

# Saving
Colab TPU v2 has a problem when saving the model weights. Kaggle TPU v3 works.

In [None]:
# Don't do this on Colab.

# import orbax.checkpoint as orbax

# state = nnx.state(model)

# checkpointer = orbax.PyTreeCheckpointer()
# checkpointer.save('/kaggle/working/state/', state)

# Disconnect the Colab runtime

In [None]:
from google.colab import runtime
runtime.unassign()

wandb.finish()