# GPT2 DPO finetuning

This notebook demonstrates how to finetune a instruction-tuned GPT2(124M) model with [Direct Preference Optimization](https://arxiv.org/pdf/2305.18290). Note that this notebook only works on Kaggle TPU v3 or Clout TPU v3+ (Colab TPU v2 simply does not have enough HBM).

## Determine platform

In [1]:
import os
if os.path.exists('/kaggle/'):
  platform = "Kaggle"
else:
  # Assume using Cloud TPU otherwise
  platform = "GCP"

## Setup

Install JAX and Flax first.

In [2]:
!pip install -q jax-ai-stack[grain]
if platform == "Colab": # temp workaround on Colab (https://github.com/jax-ml/jax-ai-stack/issues/149)
  !pip install -Uq "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -Uq tiktoken matplotlib kaggle wandb tpu-info datasets

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


Confirm we have TPUs set up.

In [3]:
import jax
jax.devices()

E0000 00:00:1752653446.244167    3695 common_lib.cc:612] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:230


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

Take care of the imports.

In [4]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax, orbax
from collections import Counter
from dataclasses import dataclass
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import numpy as np
import tiktoken, time, wandb
from huggingface_hub import snapshot_download
from safetensors import safe_open
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


## Build the model

Define the device mesh.


In [5]:
mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))

We are going to use the GPT-2 tokenizer via OpenAI's [Tiktoken](https://github.com/openai/tiktoken) library.

In [6]:
tokenizer = tiktoken.get_encoding("gpt2")

Set some hyperparameters.

In [7]:
vocab_size = tokenizer.n_vocab
GPT2_variant = "GPT2"

num_transformer_blocks = 12
seqlen = 1024
embed_dim = 768
num_heads = 12
feed_forward_dim = 4 * embed_dim
batch_size = 64
dropout_rate = 0.1

init_learning_rate = 1e-5 #5e-4
weight_decay = 1e-1
top_k = 10
sampling_temp = 2
dtype = jnp.bfloat16
param_dtype = jnp.float32
beta = 0.1
max_steps = 200

Now define the model architecture, which is the same as in our previous instruction tuning notebook.

In [8]:
def causal_attention_mask(seq_len):
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))

class CustomMHA(nnx.Module):
    def __init__(self, embed_dim, num_heads, dropout_rate, layer_idx, rngs):
        self.num_heads = num_heads
        self.head_dim = embed_dim // self.num_heads
        self.embed_dim = embed_dim

        kernel_init = nnx.with_partitioning(
            nnx.initializers.xavier_uniform(), (P(None, "model"),)
        )

        self.query = nnx.Linear(
            embed_dim, embed_dim, rngs=rngs, use_bias=False, kernel_init=kernel_init
        )
        self.key = nnx.Linear(
            embed_dim, embed_dim, rngs=rngs, use_bias=False, kernel_init=kernel_init
        )
        self.value = nnx.Linear(
            embed_dim, embed_dim, rngs=rngs, use_bias=False, kernel_init=kernel_init
        )
        self.out = nnx.Linear(
            embed_dim, embed_dim, rngs=rngs, use_bias=False, kernel_init=kernel_init
        )

        self.q_bias = nnx.Param(
            jnp.zeros((embed_dim,), dtype=param_dtype), sharding=P("model")
        )
        self.k_bias = nnx.Param(
            jnp.zeros((embed_dim,), dtype=param_dtype), sharding=P("model")
        )
        self.v_bias = nnx.Param(
            jnp.zeros((embed_dim,), dtype=param_dtype), sharding=P("model")
        )

        self.out_bias = nnx.Param(
            jnp.zeros((embed_dim,), dtype=param_dtype), sharding=P("model")
        )

        self.dropout = nnx.Dropout(dropout_rate)

    def __call__(
        self, x, mask, padding_mask=None, training: bool = False, rngs: nnx.Rngs = None
    ):
        batch_size, seq_len, _ = x.shape

        q = self.query(x) + self.q_bias
        k = self.key(x) + self.k_bias
        v = self.value(x) + self.v_bias

        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )
        k = k.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )
        v = v.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(
            (0, 2, 1, 3)
        )

        attn_weights = jnp.matmul(q, k.transpose((0, 1, 3, 2))) / jnp.sqrt(
            self.head_dim
        )

        combined_mask = mask
        if padding_mask is not None:
            combined_mask = jnp.logical_and(mask, padding_mask)

        if combined_mask is not None:
            attn_weights = jnp.where(combined_mask, attn_weights, -jnp.inf)

        attn_weights = nnx.softmax(attn_weights, axis=-1)
        attn_weights = self.dropout(attn_weights, deterministic=not training, rngs=rngs)

        attn_output = jnp.matmul(attn_weights, v)
        attn_output = attn_output.transpose((0, 2, 1, 3)).reshape(
            (batch_size, seq_len, self.embed_dim)
        )

        output = self.out(attn_output) + self.out_bias
        return output


class TransformerBlock(nnx.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        ff_dim: int,
        dropout_rate: float,
        rngs: nnx.Rngs,
        layer_idx: int,
    ):
        self.layer_norm1 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P("model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.mha = CustomMHA(embed_dim, num_heads, dropout_rate, layer_idx, rngs)
        self.dropout1 = nnx.Dropout(rate=dropout_rate)
        self.layer_norm2 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P("model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.linear1 = nnx.Linear(
            in_features=embed_dim,
            out_features=ff_dim,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, "model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.linear2 = nnx.Linear(
            in_features=ff_dim,
            out_features=embed_dim,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, "model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.dropout2 = nnx.Dropout(rate=dropout_rate)

    def __call__(
        self, inputs, padding_mask=None, training: bool = False, rngs: nnx.Rngs = None
    ):
        input_shape = inputs.shape
        bs, seq_len, emb_sz = input_shape

        attention_output = self.mha(
            self.layer_norm1(inputs),
            mask=causal_attention_mask(seq_len),
            padding_mask=padding_mask,
            training=training,
            rngs=rngs,
        )
        x = inputs + self.dropout1(
            attention_output, deterministic=not training, rngs=rngs
        )

        # MLP
        mlp_output = self.linear1(self.layer_norm2(x))
        mlp_output = nnx.gelu(mlp_output)
        mlp_output = self.linear2(mlp_output)
        mlp_output = self.dropout2(mlp_output, deterministic=not training, rngs=rngs)

        return x + mlp_output


class TokenAndPositionEmbedding(nnx.Module):
    def __init__(
        self,
        seqlen: int,
        vocab_size: int,
        embed_dim: int,
        rngs: nnx.Rngs,
    ):
        self.token_emb = nnx.Embed(
            num_embeddings=vocab_size,
            features=embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.pos_emb = nnx.Embed(
            num_embeddings=seqlen,
            features=embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(self, x):
        positions = jnp.arange(0, x.shape[1])[None, :]
        position_embedding = self.pos_emb(positions)
        token_embedding = self.token_emb(x)
        return self.token_emb, token_embedding + position_embedding


class GPT2(nnx.Module):
    def __init__(
        self,
        seqlen: int,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        rate: float,
        feed_forward_dim: int,
        num_transformer_blocks: int,
        rngs: nnx.Rngs,
    ):
        self.embedding_layer = TokenAndPositionEmbedding(
            seqlen, vocab_size, embed_dim, rngs=rngs
        )
        self.dropout = nnx.Dropout(rate=rate)

        self.transformer_blocks = [
            TransformerBlock(
                embed_dim,
                num_heads,
                feed_forward_dim,
                dropout_rate,
                rngs=rngs,
                layer_idx=i,
            )
            for i in range(num_transformer_blocks)
        ]

        self.layer_norm = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(
                nnx.initializers.ones_init(), NamedSharding(mesh, P("model"))
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), NamedSharding(mesh, P("model"))
            ),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self, inputs, padding_mask=None, training: bool = False, rngs: nnx.Rngs = None
    ):
        token_embedding, x = self.embedding_layer(inputs)
        x = self.dropout(x, deterministic=not training, rngs=rngs)
        for transformer_block in self.transformer_blocks:
            x = transformer_block(
                x, padding_mask=padding_mask, training=training, rngs=rngs
            )
        x = self.layer_norm(x)
        outputs = token_embedding.attend(x)
        return outputs

    @staticmethod
    @nnx.jit
    def sample_from(logits, key):
        logits, indices = jax.lax.top_k(logits, k=top_k)
        logits = nnx.softmax(logits / sampling_temp)
        return jax.random.choice(key, indices, p=logits)

    @staticmethod
    @nnx.jit
    def generate_step_static(params, static_def, padded_tokens, length, key):
        padding_mask = jnp.arange(seqlen) < length
        padding_mask = padding_mask.reshape(1, 1, 1, seqlen)

        model = nnx.merge(params, static_def)
        logits = model(padded_tokens, padding_mask=padding_mask, training=False)
        last_token_logits = logits[:, length - 1, :]

        key, subkey = jax.random.split(key)
        next_token = GPT2.sample_from(jnp.squeeze(last_token_logits), subkey)
        return next_token

    def generate_text(self, max_tokens, start_tokens):
        key = jax.random.PRNGKey(int(time.time()))

        params, static_def = nnx.split(self)

        tokens = jnp.array(start_tokens, dtype=jnp.int32)[None, :]
        end_token = tokenizer.encode(
            "<|endoftext|>", allowed_special={"<|endoftext|>"}
        )[0]

        current_len = tokens.shape[1]
        padded_tokens = jnp.pad(tokens, ((0, 0), (0, seqlen - current_len)), "constant")

        print(tokenizer.decode(tokens[0]), end="", flush=True)

        for i in range(max_tokens):
            key, subkey = jax.random.split(key)

            next_token = self.generate_step_static(
                params, static_def, padded_tokens, current_len, subkey
            )

            if next_token.item() == end_token:
                break

            print(tokenizer.decode([next_token.item()]), end="", flush=True)

            padded_tokens = padded_tokens.at[:, current_len].set(next_token.item())
            current_len += 1

        final_tokens = padded_tokens[0, :current_len]
        return tokenizer.decode(final_tokens.tolist())


def create_model(rngs):
    return GPT2(
        seqlen,
        vocab_size,
        embed_dim,
        num_heads,
        dropout_rate,
        feed_forward_dim,
        num_transformer_blocks,
        rngs=rngs,
    )

Use Weights and Biases to track training progress.

In [9]:
if platform == "Colab":
  from google.colab import userdata
  os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')
  os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
  os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')
elif platform == "Kaggle":
  from kaggle_secrets import UserSecretsClient
  user_secrets = UserSecretsClient()
  os.environ['WANDB_API_KEY'] = user_secrets.get_secret('WANDB_API_KEY')
else:
  print("Please set the WANDB_API_KEY env variable manually") #input()

wandb.login()

import wandb

wandb.init(
    # set the wandb project where this run will be logged
    project='GPT2-DPO',

    # track hyperparameters and run metadata
    config={
      'architecture': GPT2_variant,
      'dataset': 'OpenWebText',
      'platform': platform,
      'dtype': dtype,
      'param_dtype': param_dtype,
      'init_learning_rate': init_learning_rate,
      'num_transformer_blocks': num_transformer_blocks,
      'seqlen': seqlen,
      'embed_dim': embed_dim,
      'num_heads': num_heads,
      'feed_forward_dim': feed_forward_dim,
      'max_steps': 'unknown',
      'batch_size': batch_size,
      'weight_decay': weight_decay,
      'beta': beta
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mwindmaple[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## DPO

DPO training requires a model to be trained and a separate reference model to compute loss. We are going to initialize them from our previous instruct tuned 124M GPT2 model. On Kaggle, you need to manually add the [model](https://www.kaggle.com/models/windmaple/gpt2/jax/124m-it) as input.

In [10]:
import orbax.checkpoint as orbax
from orbax.checkpoint import PyTreeCheckpointer
import numpy as np
from datasets import load_dataset

# Create the model and load the pretrained weights
model = create_model(rngs=nnx.Rngs(0))
state = nnx.state(model)
checkpointer = PyTreeCheckpointer()

checkpoint_path = '/kaggle/input/gpt2/jax/124m-it/1'
state = checkpointer.restore(checkpoint_path, item=state)
nnx.update(model, state)

# Create a reference model with the same pretrained weights
ref_model = create_model(rngs=nnx.Rngs(1))
ref_state = nnx.state(ref_model)
ref_state = checkpointer.restore(checkpoint_path, item=ref_state)
nnx.update(ref_model, ref_state)



We are going to use a [pairwise preference dataset from Argilla](https://huggingface.co/datasets/argilla/ultrafeedback-binarized-preferences-cleaned). And create a preprocessing helper function and a dataloader.

In [11]:
from datasets import load_dataset
ds = load_dataset(
    "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
)

max_steps = ds.num_rows // batch_size

# Define the template for the chat messages
template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}"

# Define a helper function to preprocess the dataset
def preprocess_function(examples):
    def format_and_tokenize(messages):
        prompts = []
        for msg_pair in messages:
            instruction = msg_pair[0]["content"]
            output = msg_pair[1]["content"] if len(msg_pair) > 1 else ""
            prompts.append(
                template.format(instruction=instruction, input="", output=output)
            )

        tokenized_prompts = [tokenizer.encode(p) for p in prompts]

        input_ids = []
        attention_masks = []

        for tokens in tokenized_prompts:
            if len(tokens) > seqlen:
                tokens = tokens[:seqlen]

            padding_len = seqlen - len(tokens)
            input_ids.append(tokens + [50256] * padding_len)
            attention_masks.append([1] * len(tokens) + [0] * padding_len)

        return np.array(input_ids), np.array(attention_masks)

    chosen_input_ids, chosen_attention_mask = format_and_tokenize(examples["chosen"])
    rejected_input_ids, rejected_attention_mask = format_and_tokenize(
        examples["rejected"]
    )

    return {
        "chosen_input_ids": chosen_input_ids,
        "chosen_attention_mask": chosen_attention_mask,
        "rejected_input_ids": rejected_input_ids,
        "rejected_attention_mask": rejected_attention_mask,
    }

# Create a data loader.
def data_loader(dataset, batch_size):
    while True:
        for i in range(0, len(dataset["chosen"]), batch_size):
            batch = {
                "chosen": dataset["chosen"][i : i + batch_size],
                "rejected": dataset["rejected"][i : i + batch_size],
            }
            processed_batch = preprocess_function(batch)
            yield processed_batch

Now we can define functions to calculate DPO loss.

In [12]:
# Define the DPO loss function
def dpo_loss(
    policy_chosen_logps,
    policy_rejected_logps,
    ref_chosen_logps,
    ref_rejected_logps,
    beta,
):
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    return -jax.nn.log_sigmoid(beta * (pi_logratios - ref_logratios))


# Define a function to get the log probabilities of the sequences
def get_log_probs(logits, labels, attention_mask):
    batch_size, seq_len = labels.shape
    assert logits.shape[:2] == (batch_size, seq_len), f"Shape mismatch: {logits.shape} vs {labels.shape}"

    # Get the log probabilities from the logits.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Get the log probabilities of the labels.
    log_probs_labels = jnp.squeeze(
        jnp.take_along_axis(log_probs, labels[:, :, None], axis=-1), -1
    )
    # Set the log probabilities of the padding tokens to 0.
    return (log_probs_labels * attention_mask).sum(axis=-1)


def calculate_loss(model, ref_model, batch, rngs):
    # Get the logits from the policy model.
    policy_chosen_logits = model(batch["chosen_input_ids"], training=True, rngs=rngs)
    policy_rejected_logits = model(
        batch["rejected_input_ids"], training=True, rngs=rngs
    )

    # Get the log probabilities from the policy model.
    policy_chosen_logps = get_log_probs(
        policy_chosen_logits,
        batch["chosen_input_ids"],
        batch["chosen_attention_mask"],
    )
    policy_rejected_logps = get_log_probs(
        policy_rejected_logits,
        batch["rejected_input_ids"],
        batch["rejected_attention_mask"],
    )

    # Get the logits from the reference model.
    ref_chosen_logits = jax.lax.stop_gradient(ref_model(batch["chosen_input_ids"], training=False))
    ref_rejected_logits = jax.lax.stop_gradient(ref_model(batch["rejected_input_ids"], training=False))

    # Get the log probabilities from the reference model.
    ref_chosen_logps = get_log_probs(
        ref_chosen_logits,
        batch["chosen_input_ids"],
        batch["chosen_attention_mask"],
    )
    ref_rejected_logps = get_log_probs(
        ref_rejected_logits,
        batch["rejected_input_ids"],
        batch["rejected_attention_mask"],
    )

    # Calculate the DPO loss.
    loss = dpo_loss(
        policy_chosen_logps,
        policy_rejected_logps,
        ref_chosen_logps,
        ref_rejected_logps,
        beta,
    )
    return jnp.mean(loss)

Now we run the finetuning.

In [13]:
start_prompt = template.format(
    instruction="What is the future for human?",
    input="",
    output="",
)
start_tokens = tokenizer.encode(start_prompt)[:seqlen]
print(f"***Generated text before DPO:")
generated_text = model.generate_text(seqlen // 5, start_tokens)

# Define the training step
@nnx.jit
def train_step(model, optimizer, ref_model, batch, rngs):
    # Calculate the loss and gradients with respect to the model's parameters.
    loss, grads = nnx.value_and_grad(calculate_loss, argnums=0)(
        model, ref_model, batch, rngs
    )
    # Update the model's parameters.
    optimizer.update(grads)
    return loss

optimizer = nnx.Optimizer(
    model, 
    optax.chain(
        optax.clip_by_global_norm(1.0),  # Add gradient clipping
        optax.adamw(learning_rate=init_learning_rate, weight_decay=weight_decay)
    )
)

data_gen = data_loader(ds, batch_size)
rngs = nnx.Rngs(0)

# Train the model
for step in range(max_steps):
    batch = next(data_gen)
    batch = jax.device_put(batch, NamedSharding(mesh, P("batch")))
    loss = train_step(model, optimizer, ref_model, batch, rngs=rngs)
    if step % 50 == 0:
        print(f"Step {step}, Loss: {loss}")        
        wandb.log(data={'Loss': loss}, step=step)

# Generate text after DPO
print(f"***Generated text after DPO:")
generated_text = model.generate_text(seqlen // 5, start_tokens)


***Generated text before DPO:
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
What is the future for human?

### Input:


### Response:
As human beings progress to their next evolutionary stages on a much larger and more advanced level and become even faster-Evolutron. In the next evolutionary stage they will be able to survive for billions of years and will become a species with a special form. Their genetic material is made up of cells that have unique features such as wings, hair, and a single organelle within the wings of a human. The cells will eventually have evolved over time, and their genetic structure and function has been modified drastically as human evolution takes a far more complex journey.Step 0, Loss: 5.59375
Step 50, Loss: 2.59375
Step 100, Loss: 2.09375
Step 150, Loss: 1.10156
Step 200, Loss: 2.0625
Step 250, Loss: 1.96094
Step 300, Loss: 2.07812
Step 350, Loss: 2.07812
Step 400, Loss: 1.9375