# GPT2 instruction tuning

This notebook demonstrates how to finetune a pretrained GPT2 (124M) model to follow user instructions. We are going to reuse a lot of code from the previous GPT2 pretraining notebook; in fact, the only major change is in data preparation.

## Determine platform

In [1]:
import os
if os.path.exists('/content/'):
  platform = "Colab"
elif os.path.exists('/kaggle/'):
  platform = "Kaggle"
else:
  # Assume using Cloud TPU otherwise
  platform = "GCP"

## Setup

Install JAX and Flax first.

In [2]:
!pip install -q jax-ai-stack[grain]
if platform == "Colab": # temp workaround on Colab (https://github.com/jax-ml/jax-ai-stack/issues/149)
  !pip install -Uq "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -Uq tiktoken matplotlib kaggle wandb tpu-info datasets


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jax-ai-stack 2025.4.9 requires jax==0.5.3, but you have jax 0.6.2 which is incompatible.

Confirm we have TPUs set up.

In [3]:
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Take care of the imports.

In [4]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax, orbax
from collections import Counter
from dataclasses import dataclass
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import numpy as np
import tiktoken, time, wandb

## Build the model

Define the device mesh.


In [5]:
### Alternative data and model parallel
# mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))

We are going to use the GPT-2 tokenizer via OpenAI's [Tiktoken](https://github.com/openai/tiktoken) library.

In [6]:
tokenizer = tiktoken.get_encoding("gpt2")

Set some hyperparameters.

In [7]:
vocab_size = tokenizer.n_vocab
GPT2_variant = "GPT2" # "GPT2-medium"
if GPT2_variant == "GPT2-medium":
  num_transformer_blocks = 24
  seqlen = 1024
  embed_dim = 1024
  num_heads = 16
  feed_forward_dim = 4 * embed_dim
  batch_size = 16  # Can only run on TPU v3+
else: ## Assume GPT2 otherwise
  num_transformer_blocks = 12
  seqlen = 1024
  embed_dim = 768
  num_heads = 12
  feed_forward_dim = 4 * embed_dim
  if platform == "Colab":
      batch_size = 24 # TPU v2
  else:
      batch_size = 72 # TPU v3

dropout_rate = 0.1

max_steps = 600000*12//batch_size
# Kaggle TPU limit per session is 9 hours, which is ~95K steps for GPT2
if platform == "Kaggle":
  max_steps = 90000
init_learning_rate = 5e-4
weight_decay = 1e-1
top_k = 10
sampling_temp = 2
dtype = jnp.bfloat16
param_dtype = jnp.float32
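The `max_steps` formula above is carried over from pretraining, so it is worth seeing what it evaluates to for each batch size. The sketch below reproduces that arithmetic in plain Python and, as a rough comparison, counts the full batches in one pass over a dataset. The helper names and the dataset size of about 52,000 examples are assumptions made for illustration, not values taken from this notebook.

```python
def pretraining_max_steps(batch_size: int) -> int:
    """Same arithmetic as the notebook's max_steps: 600000 * 12 // batch_size."""
    return 600000 * 12 // batch_size


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Number of full batches in one pass when the last partial batch is dropped."""
    return num_examples // batch_size


# Step budgets for the two batch sizes used by the plain GPT2 variant.
print(pretraining_max_steps(24))  # 300000
print(pretraining_max_steps(72))  # 100000

# ASSUMPTION: an instruction dataset with roughly 52,000 examples.
ASSUMED_NUM_EXAMPLES = 52_000
print(steps_per_epoch(ASSUMED_NUM_EXAMPLES, 72))  # 722
print(steps_per_epoch(ASSUMED_NUM_EXAMPLES, 24))  # 2166
```

Under that assumption, a single epoch is far shorter than the pretraining step budget, so a learning-rate schedule that decays over `max_steps` would stay close to its initial value throughout the run.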

Now define the model architecture, which is the same as in our previous pretraining notebook.

In [8]:
def causal_attention_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len)))

class TransformerBlock(nnx.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, dropout_rate: float, rngs: nnx.Rngs):
        self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                         dtype=dtype,
                                         param_dtype=param_dtype,
                                         rngs=rngs)
        self.mha = nnx.MultiHeadAttention(num_heads=num_heads,
                                          in_features=embed_dim,
                                          kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                          bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                          dtype=dtype,
                                          param_dtype=param_dtype,
                                          rngs=rngs)
        self.dropout1 = nnx.Dropout(rate=dropout_rate)  # Added dropout layer after MHA
        self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6,
                                         num_features=embed_dim,
                                         scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
                                         bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),
                                         dtype=dtype,
                                         param_dtype=param_dtype,
                                         rngs=rngs)
        self.linear1 = nnx.Linear(in_features=embed_dim,
                                  out_features=ff_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                  dtype=dtype,
                                  param_dtype=param_dtype,
                                  rngs=rngs)
        self.linear2 = nnx.Linear(in_features=ff_dim,
                                  out_features=embed_dim,
                                  kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                                  bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                  dtype=dtype,
                                  param_dtype=param_dtype,
                                  rngs=rngs)
        self.dropout2 = nnx.Dropout(rate=dropout_rate)

    def __call__(self, inputs, training: bool = False):
        input_shape = inputs.shape
        bs, seq_len, emb_sz = input_shape

        attention_output = self.mha(
            inputs_q=self.layer_norm1(inputs),
            mask=causal_attention_mask(seq_len),
            decode=False,
        )
        x = inputs + self.dropout1(attention_output, deterministic=not training)

        # MLP
        mlp_output = self.linear1(self.layer_norm2(x))
        mlp_output = nnx.gelu(mlp_output)
        mlp_output = self.linear2(mlp_output)
        mlp_output = self.dropout2(mlp_output, deterministic=not training)

        return x + mlp_output


class TokenAndPositionEmbedding(nnx.Module):

    def __init__(self, seqlen: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs):
        self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, dtype=dtype, param_dtype=param_dtype, rngs=rngs)
        self.pos_emb = nnx.Embed(num_embeddings=seqlen, features=embed_dim, dtype=dtype, param_dtype=param_dtype, rngs=rngs)

    def __call__(self, x):
        positions = jnp.arange(0, x.shape[1])[None, :]
        position_embedding = self.pos_emb(positions)
        token_embedding = self.token_emb(x)
        return self.token_emb, token_embedding+position_embedding


class GPT2(nnx.Module):
    def __init__(self, seqlen: int, vocab_size: int, embed_dim: int, num_heads: int, rate: float, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs):
        self.embedding_layer = TokenAndPositionEmbedding(
                    seqlen, vocab_size, embed_dim, rngs=rngs
                )
        self.dropout = nnx.Dropout(rate=rate)

        self.transformer_blocks = [TransformerBlock(
            embed_dim, num_heads, feed_forward_dim, dropout_rate, rngs=rngs
        ) for _ in range(num_transformer_blocks)]

        self.layer_norm = nnx.LayerNorm(epsilon=1e-6,
                                    num_features=embed_dim,
                                    scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),
                                    bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                    dtype=dtype,
                                    param_dtype=param_dtype,
                                    rngs=rngs)

    def __call__(self, inputs, training: bool = False):
        token_embedding, x = self.embedding_layer(inputs)
        x = self.dropout(x, deterministic=not training)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x, training=training)
        x = self.layer_norm(x)
        # Weights tying
        outputs = token_embedding.attend(x)
        return outputs

    @nnx.jit
    def sample_from(self, logits):
        logits, indices = jax.lax.top_k(logits, k=top_k)
        logits = nnx.softmax(logits/sampling_temp)
        return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)

    @nnx.jit
    def generate_step(self, padded_tokens, sample_index):
        logits = self(padded_tokens)
        next_token = self.sample_from(logits[0][sample_index])
        return next_token

    def generate_text(self, max_tokens, start_tokens):
        generated = []
        print(tokenizer.decode(start_tokens), flush=True, end='')
        for i in range(max_tokens):
            sample_index = len(start_tokens) + len(generated) - 1
            # TODO: use attention masking for better efficiency
            padded_tokens = jnp.array((start_tokens + generated + [0] * (seqlen - len(start_tokens) - len(generated))))[None, :]
            next_token = int(self.generate_step(padded_tokens, sample_index))
            if next_token == tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]:
              break
            generated.append(next_token)
            # decode and print next_token
            print(tokenizer.decode([next_token]), flush=True, end='')
        return tokenizer.decode(start_tokens + generated)

def create_model(rngs):
    return GPT2(seqlen, vocab_size, embed_dim, num_heads, dropout_rate, feed_forward_dim, num_transformer_blocks, rngs=rngs)
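The `sample_from` method above keeps the `top_k` highest logits, divides them by `sampling_temp`, and samples from the resulting softmax. Here is a minimal NumPy sketch of the same idea on a toy logit vector. The function name `sample_top_k` and the toy values are made up for illustration, and the random generator is passed in explicitly rather than fixed inside the function.

```python
import numpy as np


def sample_top_k(logits, k, temperature, rng):
    """Sample one token id from the k highest-scoring logits."""
    logits = np.asarray(logits, dtype=np.float64)
    # Ids of the k largest logits, highest first.
    top_ids = np.argsort(logits)[-k:][::-1]
    # Temperature > 1 flattens the distribution, < 1 sharpens it.
    scaled = logits[top_ids] / temperature
    scaled = scaled - scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return int(rng.choice(top_ids, p=probs))


rng = np.random.default_rng(0)
toy_logits = [2.0, -1.0, 0.5, 3.0, 0.0]  # hypothetical logits for a 5-token vocabulary

# With k=2 only ids 3 and 0 (the two largest logits) can ever be returned.
token = sample_top_k(toy_logits, k=2, temperature=2.0, rng=rng)
print(token)

# With k=1 sampling reduces to greedy decoding.
greedy = sample_top_k(toy_logits, k=1, temperature=2.0, rng=rng)
print(greedy)  # 3
```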

Use Weights and Biases to track training progress.

In [9]:
if platform == "Colab":
  from google.colab import userdata
  os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')
  os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
  os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')
elif platform == "Kaggle":
  from kaggle_secrets import UserSecretsClient
  user_secrets = UserSecretsClient()
  os.environ['WANDB_API_KEY'] = user_secrets.get_secret('WANDB_API_KEY')
else:
  print("Please set the WANDB_API_KEY env variable manually") #input()

wandb.login()

wandb.init(
    # set the wandb project where this run will be logged
    project='GPT2-it',

    # track hyperparameters and run metadata
    config={
      'architecture': GPT2_variant,
      'dataset': 'Alpaca',
      'platform': platform,
      'max_steps': max_steps,
      'batch_size': batch_size,
      'dtype': dtype,
      'param_dtype': param_dtype,
      'init_learning_rate': init_learning_rate,
      'num_transformer_blocks': num_transformer_blocks,
      'seqlen': seqlen,
      'embed_dim': embed_dim,
      'num_heads': num_heads,
      'feed_forward_dim': feed_forward_dim,
      'weight_decay': weight_decay
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mwindmaple[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Instruct tune

We are going to use the [Alpaca dataset](https://huggingface.co/datasets/tatsu-lab/alpaca) from Stanford.

In [10]:
import grain.python as pygrain
import pandas as pd
from datasets import load_dataset

@dataclass
class TextDataset:
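    data_df: pd.Series  # the "text" column of the Alpaca DataFrame
    seqlen: int

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx: int):
        # Use Tiktoken for tokenization and truncate to at most seqlen tokens
        encoding = tokenizer.encode(
            self.data_df.iloc[idx], allowed_special={"<|endoftext|>"}
        )[:self.seqlen]
        # Pad to seqlen with 50256, the id of <|endoftext|> in the GPT-2 vocabulary
        return encoding + [50256] * (self.seqlen - len(encoding))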


def load_and_preprocess_data(alpaca_data, batch_size, seqlen):
    alpaca_data_df = pd.DataFrame(alpaca_data)
    dataset = TextDataset(alpaca_data_df["text"], seqlen)
    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=True,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=1,
    )
    dl = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )
    return dl


alpaca_data = load_dataset("tatsu-lab/alpaca", split="train")
text_dl = load_and_preprocess_data(alpaca_data, batch_size, seqlen)
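Each item returned by `TextDataset.__getitem__` is a list of exactly `seqlen` token ids: longer examples are truncated and shorter ones are padded with 50256, the id of `<|endoftext|>`. The sketch below isolates that step on hand-written token ids so it can be checked without loading the tokenizer. The helper name `pad_or_truncate` and the ids are hypothetical.

```python
PAD_ID = 50256  # id of <|endoftext|> in the GPT-2 vocabulary, used as padding above


def pad_or_truncate(token_ids, seqlen, pad_id=PAD_ID):
    """Return a list of exactly seqlen ids, clipping or right-padding as needed."""
    clipped = list(token_ids)[:seqlen]
    return clipped + [pad_id] * (seqlen - len(clipped))


# A short example is padded on the right.
print(pad_or_truncate([10, 11, 12], 5))  # [10, 11, 12, 50256, 50256]

# A long example is cut off at seqlen.
print(pad_or_truncate([10, 11, 12, 13, 14, 15], 5))  # [10, 11, 12, 13, 14]
```

Fixed-length rows are what allow the data loader to stack examples into a rectangular batch.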

Define the loss and training step function.

In [11]:
@nnx.jit
def loss_fn(model, batch):
    logits = model(batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch[1]
    ).mean()
    return loss, logits


@nnx.jit
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch)
    metrics.update(loss=loss, logits=logits, labels=batch[1])
    optimizer.update(grads)
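The loss compares the logits at each position with the token that comes next, so the labels are the inputs shifted left by one position (the training loop further below builds them this way, filling the last slot with 0). The loss above averages over every position, padding included. The NumPy sketch below shows the shift on a toy batch and, as a variation this notebook does not use, a mean that ignores positions whose input is padding. The toy vocabulary of 6 ids, the stand-in pad id 5, and both function names are assumptions for illustration.

```python
import numpy as np

TOY_PAD_ID = 5  # hypothetical stand-in for 50256 in a 6-token toy vocabulary


def shift_targets(inputs, fill_id=0):
    """Targets are the inputs moved one position left; the last slot gets fill_id."""
    inputs = np.asarray(inputs)
    fill = np.full((inputs.shape[0], 1), fill_id, dtype=inputs.dtype)
    return np.concatenate([inputs[:, 1:], fill], axis=1)


def masked_cross_entropy(logits, targets, mask):
    """Mean negative log-likelihood over the positions where mask == 1."""
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return (nll * mask).sum() / mask.sum()


inputs = np.array([[1, 2, 3, TOY_PAD_ID, TOY_PAD_ID]])
targets = shift_targets(inputs)
print(targets)  # [[2 3 5 5 0]]

# Count only positions whose input is a real token; the first pad id stays a
# target, so the model is still trained to emit the end-of-text token.
mask = (inputs != TOY_PAD_ID).astype(np.float64)

# Uniform logits give a loss of log(vocab_size) regardless of the targets.
logits = np.zeros((1, 5, 6))
loss = masked_cross_entropy(logits, targets, mask)
print(np.isclose(loss, np.log(6)))  # True
```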

Load the checkpoint from Kaggle. This checkpoint was saved from our previous pretraining.

In [12]:
import kagglehub
import orbax.checkpoint as orbax

model = nnx.eval_shape(lambda: create_model(rngs=nnx.Rngs(0)))
state = nnx.state(model)
checkpointer = orbax.PyTreeCheckpointer()
checkpoint_path = kagglehub.model_download("windmaple/gpt2/jax/124m")
state = checkpointer.restore(checkpoint_path, item=state)
nnx.update(model, state)



Define optimizer and metrics.

In [13]:
schedule = optax.cosine_decay_schedule(
    init_value=init_learning_rate, decay_steps=max_steps
)
optax_chain = optax.chain(
    optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)

metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
)

metrics_history = {
    "train_loss": [],
}

Do a test run on our pretrained model to see how it responds to an instruction. Note how we use a template to format the prompt; it needs to be consistent with the training data.

In [14]:
template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}"

start_prompt = template.format(
    instruction="What is the future for human?",
    input="",
    output="",
)
start_tokens = tokenizer.encode(start_prompt)[:seqlen]
print(f"***Initial generated text:")
generated_text = model.generate_text(seqlen//5, start_tokens)

***Initial generated text:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:

## Input:FIGURING_HANDLER

### Response:


As you can see, the pretrained model generates garbage: it keeps repeating fragments that look like the prompt template instead of answering. This is not surprising, given that we have not yet trained it to follow instructions.
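The prompt used in the test run above is the template with the `output` slot left empty, so the text ends right after `### Response:` and the model is expected to continue from there. A small helper makes that convention explicit. The name `build_prompt` and the sample instruction are hypothetical; the template string is the one defined in the cell above.

```python
TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:\n{output}"
)


def build_prompt(instruction: str, input_text: str = "") -> str:
    """Fill the template for inference: the response slot is left empty."""
    return TEMPLATE.format(instruction=instruction, input=input_text, output="")


# Hypothetical instruction, used only to show the resulting layout.
prompt = build_prompt("Name three primary colors.")
print(prompt)
```

The resulting string can then be tokenized and passed to `generate_text` in the same way as `start_prompt`.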

Now let's do the instruction tuning.

In [15]:
prep_target_batch = jax.vmap(
    lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0])))
)

step = 0
start_time = time.time()
for batch in text_dl:
    if len(batch) % len(jax.devices()) != 0:
        continue  # skip the remaining elements
    input_batch = jnp.array(batch).T
    target_batch = prep_target_batch(input_batch)
    train_step(
        model,
        optimizer,
        metrics,
        jax.device_put(
            (input_batch, target_batch), NamedSharding(mesh, P("batch", None))
        ),
    )

    if (step + 1) % 20 == 0:
        for metric, value in metrics.compute().items():
            metrics_history[f"train_{metric}"].append(value)
        metrics.reset()

        elapsed_time = time.time() - start_time
        print(
            f"\nStep {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds"
        )
        wandb.log(data={'train_loss': metrics_history['train_loss'][-1]}, step=step)
        start_time = time.time()
        print(f"\n***Intermediate generated text:")
        generated_text = model.generate_text(seqlen // 5, start_tokens)
    step += 1

# Final text generation
print(f"\n***Final generated text:")
generated_text = model.generate_text(seqlen // 5, start_tokens)


Step 20, Loss: 1.5087890625, Elapsed Time: 46.40 seconds

***Intermediate generated text:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:
###:
###:
###.
Step 40, Loss: 0.39106446504592896, Elapsed Time: 24.25 seconds

***Intermediate generated text:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:
A future is the future, the future will be a new future.
Step 60, Loss: 0.28657227754592896, Elapsed Time: 4.74 seconds

***Intermediate generated text:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:
Human will have the ability to learn from the future and the future will be the same for human lif

As you can see, by the end of finetuning the model is able to generate a somewhat sensible answer to the simple user question. This shows that our instruction tuning works.

However, GPT2 is fundamentally a very small (and weak) model, especially the 124M variant we are using here. If you ask the model a more challenging question, it will easily break down.

But hopefully you have learned the key concept behind instruction tuning through this simple notebook.