<a href="https://colab.research.google.com/github/s4piru/GPT2-JAX/blob/main/gpt2_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## GPT-2 (Decoder-only Transformer)

Some Transformers, such as machine translation models, have both an encoder and a decoder. GPT-2, however, is an autoregressive language model that uses only the decoder. Because its task is to predict the next token probabilistically from the preceding context, it has no separate input sequence to encode and therefore needs no encoder.

## Self-Attention

Self-attention is the core mechanism of Transformers. In self-attention, each word in a sequence calculates its relevance to other words and integrates information accordingly.

* Specifically, for each input word vector $\mathbf{x}$, three vectors are generated: Query, Key, and Value. These are computed using learnable weight matrices:  
  $\mathbf{q} = \mathbf{x} \mathbf{W}^Q$,  
  $\mathbf{k} = \mathbf{x} \mathbf{W}^K$,  
  $\mathbf{v} = \mathbf{x} \mathbf{W}^V$.

* Calculate attention scores: The dot product of a Query vector $\mathbf{q}$ with each Key vector $\mathbf{k}$, scaled by $1/\sqrt{d_k}$ (where $d_k$ is the Key dimension), determines how much attention is paid to each word. A softmax function normalizes these scores into weights that sum to 1.

* Weight/Aggregate Values: The obtained attention weights are applied to the Value vectors, and their weighted sum creates a new word representation. This process updates a word's representation to reflect information from related words.

For example, if the word "run" assigns a high attention weight to "dog," its updated representation carries the information that "the one running is the dog."
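The three steps above can be sketched for a single attention head as follows. This is a minimal illustration rather than the notebook's implementation: the toy sizes, the random weights, and the helper name `self_attention` are assumptions made for the example, and the scores use the standard $1/\sqrt{d_k}$ scaling.

```python
import jax
import jax.numpy as jnp

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention over a sequence x of shape (T, d_model)."""
    q = x @ w_q  # (T, d_k)
    k = x @ w_k  # (T, d_k)
    v = x @ w_v  # (T, d_k)
    # Scaled dot-product scores: one row per query, one column per key.
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    # Each output row is a weighted sum of the value vectors.
    return weights @ v, weights

# Toy sizes and random weights, chosen for illustration only.
T, d_model, d_k = 4, 8, 8
k_x, k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (T, d_model))
w_q = jax.random.normal(k_q, (d_model, d_k))
w_k = jax.random.normal(k_k, (d_model, d_k))
w_v = jax.random.normal(k_v, (d_model, d_k))

out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)
```

Row $i$ of `weights` holds how much token $i$ attends to every token in the sequence, including itself.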

## Causal Mask

In the decoder's self-attention, a causal mask is essential to prevent information leaking from the future. Specifically, the token at position $t$ may attend to positions up to and including $t$, but not to any position after $t$. This is done by adding a mask value of $-\infty$ (in practice a large negative number such as $-10^{10}$) to the disallowed positions in the attention score matrix. After softmax, future tokens then receive zero attention weight, which lets the Transformer decoder generate text sequentially from left to right.
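A small sketch of the masking step, assuming a toy sequence length of 4 and all-zero placeholder scores so that the effect of the mask is easy to read off. The value `-1e10` stands in for $-\infty$.

```python
import jax
import jax.numpy as jnp

T = 4
scores = jnp.zeros((T, T))  # placeholder scores; real ones come from q @ k.T

# Lower-triangular mask: row t may attend to columns 0..t (itself and the past).
causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))

# Future positions get a large negative score, so softmax gives them zero weight.
masked_scores = jnp.where(causal_mask, scores, -1e10)
weights = jax.nn.softmax(masked_scores, axis=-1)
print(weights)
```

With equal scores, row $t$ spreads its attention uniformly over the $t + 1$ visible positions and assigns nothing to later ones.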

## Positional Encoding

Self-attention by itself is order-agnostic: it relates tokens through their content, so without extra information it cannot distinguish one word order from another. Position therefore has to be encoded explicitly.

* Fixed positional encoding: Uses sin and cos waves at different frequencies to encode position. Close positions have similar wave patterns, while distant positions show greater phase shifts.

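* Learnable positional encoding: A later approach lets the model learn an embedding vector for each position index. GPT-2 uses this approach.

* Relative positional encoding: Some models, such as Transformer-XL, T5, and BERT variants like DeBERTa, encode the relative distance between tokens rather than absolute positions.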
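The fixed variant can be sketched as below, assuming the sin/cos formulation from the original Transformer paper with base 10000. The function name `sinusoidal_positions` and the toy sizes are illustrative. GPT-2 itself does not use this table, since its position embeddings are learned.

```python
import jax.numpy as jnp

def sinusoidal_positions(max_len, d_model):
    """Fixed sin/cos position table of shape (max_len, d_model); d_model must be even."""
    positions = jnp.arange(max_len)[:, None]                 # (max_len, 1)
    dims = jnp.arange(0, d_model, 2)[None, :]                # (1, d_model / 2)
    angles = positions / jnp.power(10000.0, dims / d_model)  # (max_len, d_model / 2)
    table = jnp.zeros((max_len, d_model))
    table = table.at[:, 0::2].set(jnp.sin(angles))  # even dimensions: sine
    table = table.at[:, 1::2].set(jnp.cos(angles))  # odd dimensions: cosine
    return table

pe = sinusoidal_positions(max_len=16, d_model=8)
print(pe.shape)
```

Each pair of dimensions oscillates at a different frequency, which is why nearby positions get similar rows and distant positions differ more.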

## Multi-Head Attention

Instead of a single attention mechanism, Transformers use multiple parallel attention heads. This is called multi-head attention. Each head learns different weight matrices $\mathbf{W}_i^Q, \mathbf{W}_i^K, \mathbf{W}_i^V$ and captures different aspects of word relationships.

GPT-2 Small has 12 attention heads. Each head projects the 768-dimensional hidden vector into a 64-dimensional space ($768 / 12 = 64$) with its own weight matrices. The head outputs are then concatenated back into a 768-dimensional vector and passed through an output projection, so the combined result captures diverse relationships.
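The head split and merge can be sketched with reshapes and `einsum`. This is a simplified illustration that assumes random stand-ins for the projected queries, keys, and values, and it omits the causal mask and the final output projection.

```python
import jax
import jax.numpy as jnp

B, T, n_embd, n_head = 1, 5, 768, 12
head_dim = n_embd // n_head  # 768 // 12 = 64

# Random stand-ins for the projected queries, keys, and values: (B, T, n_embd).
k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k_q, (B, T, n_embd))
k = jax.random.normal(k_k, (B, T, n_embd))
v = jax.random.normal(k_v, (B, T, n_embd))

def split_heads(x):
    # (B, T, n_embd) -> (B, T, n_head, head_dim)
    return x.reshape(B, T, n_head, head_dim)

q_h, k_h, v_h = split_heads(q), split_heads(k), split_heads(v)

# Each head computes its own (T, T) attention pattern.
scores = jnp.einsum("bqhd,bkhd->bhqk", q_h, k_h) / jnp.sqrt(head_dim)
weights = jax.nn.softmax(scores, axis=-1)                # (B, n_head, T, T)
per_head = jnp.einsum("bhqk,bkhd->bqhd", weights, v_h)   # (B, T, n_head, head_dim)

# Concatenate the heads back into one 768-dimensional vector per token.
merged = per_head.reshape(B, T, n_embd)
print(weights.shape, merged.shape)
```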

## Feed Forward Network

In each Transformer layer, after multi-head self-attention, a Feed Forward Network (FFN) is applied. The FFN captures nonlinear relationships.

$\mathrm{FFN}(\mathbf{h}) = \mathrm{GELU}(\mathbf{h}\mathbf{W}_1 + \mathbf{b}_1)\mathbf{W}_2 + \mathbf{b}_2$

The FFN transforms each token's vector independently to enhance word-level features. It first expands the hidden dimension (from 768 to 3072 in GPT-2 Small), applies the nonlinear GELU activation, and then projects back to the original size.
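The formula translates almost directly into code. The sketch below assumes GPT-2 Small's sizes and randomly initialized weights (the 0.02 scale is an arbitrary choice for the example), so it only demonstrates the shapes and the order of operations.

```python
import jax
import jax.numpy as jnp

def ffn(h, w1, b1, w2, b2):
    """Position-wise feed-forward network: expand, apply GELU, project back."""
    return jax.nn.gelu(h @ w1 + b1) @ w2 + b2

n_embd = 768
k_h, k_1, k_2 = jax.random.split(jax.random.PRNGKey(0), 3)
h = jax.random.normal(k_h, (5, n_embd))                    # 5 token vectors
w1 = jax.random.normal(k_1, (n_embd, 4 * n_embd)) * 0.02   # 768 -> 3072
b1 = jnp.zeros(4 * n_embd)
w2 = jax.random.normal(k_2, (4 * n_embd, n_embd)) * 0.02   # 3072 -> 768
b2 = jnp.zeros(n_embd)

out = ffn(h, w1, b1, w2, b2)
print(out.shape)
```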

## LayerNorm

The original Transformer applies Layer Normalization after each sublayer (Post-LN). GPT-2 instead uses a Pre-LN architecture, where LayerNorm is applied to the input of each sublayer. LayerNorm normalizes the elements of each token's hidden vector to zero mean and unit variance, then applies a learned scale and bias. This stabilizes training by helping to keep gradients from vanishing or diverging.
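A minimal sketch of the normalization itself, assuming a hand-written `layer_norm` helper with a learned `scale` and `bias` (set here to ones and zeros) and the usual small `eps` for numerical stability.

```python
import jax.numpy as jnp

def layer_norm(x, scale, bias, eps=1e-5):
    """Normalize the last axis to zero mean and unit variance, then rescale."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale + bias

# Two vectors that differ only by a factor of 10.
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 20.0, 30.0, 40.0]])
y = layer_norm(x, scale=jnp.ones(4), bias=jnp.zeros(4))
print(y)
```

Both rows normalize to nearly the same values, which shows that LayerNorm removes differences in overall magnitude between token vectors.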

## Residual Connections

In deep neural networks like Transformers, stacking more layers increases representational power. However, deeper networks also suffer from vanishing gradients, which make it harder for early layers to learn. Residual connections help preserve gradients.

If $\mathbf{X}$ is the input to a layer and $\mathrm{Sublayer}(\mathbf{X})$ represents its processing, the residual connection computes

$\mathbf{Y} = \mathbf{X} + \mathrm{Sublayer}(\mathbf{X})$

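Because the identity term contributes a direct gradient path ($\partial \mathbf{Y} / \partial \mathbf{X} = \mathbf{I} + \partial\,\mathrm{Sublayer}(\mathbf{X}) / \partial \mathbf{X}$), the gradient with respect to the input is preserved even when the sublayer's own gradient is very small, which enables stable training.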
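The effect on gradients can be checked numerically with `jax.grad`. The `sublayer` below is a hypothetical stand-in whose gradient is deliberately tiny. It is not a real Transformer sublayer.

```python
import jax
import jax.numpy as jnp

def sublayer(x):
    # Hypothetical sublayer whose own gradient is tiny (scaled by 1e-3).
    return 1e-3 * jnp.tanh(x)

def without_residual(x):
    return sublayer(x).sum()

def with_residual(x):
    return (x + sublayer(x)).sum()

x = jnp.array([0.5, -1.0, 2.0])
g_plain = jax.grad(without_residual)(x)  # at most 1e-3 per element
g_resid = jax.grad(with_residual)(x)     # 1 plus the sublayer's small gradient
print(g_plain, g_resid)
```

Without the skip connection the gradient reaching `x` is on the order of `1e-3`. With it, each element stays close to 1.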

## Text Generation Methods

### 1. Greedy Search
Greedy Search is the simplest way to generate text. It selects the token with the highest probability from the softmax distribution when generating the next word. At each time step $t$, it picks the most probable next word.

**Pros**
* It has the lowest computational cost, making it suitable for real-time inference.
* It is deterministic, allowing high reproducibility.

**Cons**
* It easily gets stuck in local optima, leading to repetitive text generation.
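The repetition problem is easy to reproduce on a toy example. The sketch below assumes a hypothetical three-token "model" given as a fixed logit table (row = current token, column = next token), not GPT-2.

```python
import jax.numpy as jnp

# Hypothetical toy "model": row = current token, column = logits for the next token.
TOY_LOGITS = jnp.array([
    [0.1, 2.0, 0.3],
    [1.5, 0.2, 1.4],
    [0.5, 0.5, 0.5],
])

def greedy_decode(start_token, steps):
    tokens = [start_token]
    for _ in range(steps):
        # Always take the single most probable next token.
        next_token = int(jnp.argmax(TOY_LOGITS[tokens[-1]]))
        tokens.append(next_token)
    return tokens

sequence = greedy_decode(start_token=0, steps=6)
print(sequence)  # [0, 1, 0, 1, 0, 1, 0]
```

Token 1 is only slightly preferred over token 2 after token 1's row is consulted, yet greedy decoding never explores the alternative and loops between tokens 0 and 1.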

### 2. Temperature Sampling
Temperature Sampling randomly draws the next word from the probability distribution, dividing the logits by a temperature parameter before the softmax to control diversity. A temperature below 1 sharpens the distribution so that high-probability tokens are chosen even more often, while a temperature above 1 flattens it, allowing more diverse word selection.

**Pros**
* It allows control over randomness, producing outputs with a balanced level of diversity.

**Cons**
* If the temperature is too high, low-probability words may be selected, increasing the risk of meaningless text.
* If the temperature is too low, the output becomes similar to Greedy Search.
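The effect of the temperature on a distribution can be seen on a handful of toy logits. The values and the helper name `temperature_probs` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 1.0, 0.5, -1.0])  # toy next-token logits

def temperature_probs(logits, temperature):
    # Dividing by T < 1 sharpens the distribution; T > 1 flattens it.
    return jax.nn.softmax(logits / temperature)

p_cold = temperature_probs(logits, 0.5)
p_base = temperature_probs(logits, 1.0)
p_hot = temperature_probs(logits, 2.0)
print(p_cold, p_base, p_hot)

# Draw one token id from the flattened (T = 2) distribution.
token = jax.random.categorical(jax.random.PRNGKey(0), logits / 2.0)
```

The probability of the top token shrinks as the temperature rises, and the tail tokens gain mass accordingly.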

### 3. Top-k Sampling
Top-k Sampling keeps only the $k$ most probable tokens, renormalizes their probabilities, and samples from them.

**Pros**
* It allows control over randomness.
* It helps avoid selecting very unlikely, semantically inappropriate words.

**Cons**
* If $k$ is too small, it behaves similarly to Greedy Search, reducing diversity.
* If $k$ is too large, it includes low-probability tokens, which may reduce coherence.
* Choosing an appropriate $k$ value is difficult.
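A minimal sketch of the filtering step on toy logits, using `jax.lax.top_k` to find the $k$-th largest value. The helper name `top_k_filter` is an assumption for the example.

```python
import jax
import jax.numpy as jnp

def top_k_filter(logits, k):
    """Keep the k largest logits; push the rest to a large negative value."""
    top_values, _ = jax.lax.top_k(logits, k)  # values sorted in descending order
    kth_value = top_values[-1]
    return jnp.where(logits < kth_value, -1e10, logits)

logits = jnp.array([2.0, 1.0, 0.5, -1.0, -3.0])
probs = jax.nn.softmax(top_k_filter(logits, k=2))
print(probs)
```

After the softmax, only the two surviving tokens carry probability mass, and sampling proceeds from that renormalized distribution.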

### 4. Top-p Sampling
Top-p Sampling (also called nucleus sampling) selects the smallest set of words whose cumulative probability exceeds $p$, renormalizes their probabilities, and samples from them.

**Pros**
* Unlike Top-k, this method dynamically adjusts the number of candidates based on the probability distribution.

**Cons**
* Choosing an appropriate $p$ value is difficult.
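The nucleus can be computed with a sort and a cumulative sum. This sketch works on a toy probability vector, and `top_p_filter` is a hypothetical helper written for the example. It keeps every token up to and including the one that pushes the cumulative mass past `top_p`.

```python
import jax.numpy as jnp

def top_p_filter(probs, top_p):
    """Zero out tokens outside the nucleus and renormalize."""
    order = jnp.argsort(probs)[::-1]   # token ids, most probable first
    sorted_probs = probs[order]
    cumulative = jnp.cumsum(sorted_probs)
    # Keep a token if the mass before it is still below top_p,
    # so the token that crosses the threshold is included.
    keep_sorted = (cumulative - sorted_probs) < top_p
    keep = jnp.zeros_like(probs, dtype=bool).at[order].set(keep_sorted)
    filtered = jnp.where(keep, probs, 0.0)
    return filtered / filtered.sum()

probs = jnp.array([0.5, 0.3, 0.15, 0.05])
nucleus = top_p_filter(probs, top_p=0.7)
print(nucleus)
```

Here the first two tokens already cover 0.8 of the mass, so the nucleus has size 2. A flatter distribution would yield a larger nucleus for the same `top_p`.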

### 5. Beam Search
Beam Search explores multiple candidate sequences in parallel and selects the one with the highest overall probability. Unlike Greedy Search, which commits to the locally best word at each step, Beam Search keeps the best partial sequences (as many as the beam width) at every step, so a locally weaker word can still win if its continuation is more probable.

**Pros**
* It finds more contextually probable sequences compared to Greedy Search.
* It is deterministic, allowing high reproducibility.

**Cons**
* A larger beam width increases computational cost.
* All beams may converge to similar results.
* Since it maximizes cumulative probability, and every additional token multiplies in a probability of at most 1, shorter sequences tend to be preferred over longer ones.
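The difference from Greedy Search shows up even on a tiny example. The sketch below assumes a hypothetical three-token model given as a fixed probability table (row = last token, column = next token), and `beam_search` is a simplified helper with no EOS handling or length penalty.

```python
import jax.numpy as jnp

# Hypothetical next-token table; token 0 is used as the start symbol.
NEXT_PROBS = jnp.array([
    [0.20, 0.50, 0.30],  # after token 0: token 1 looks best locally
    [0.34, 0.33, 0.33],  # after token 1: no confident continuation
    [0.05, 0.05, 0.90],  # after token 2: a very confident continuation
])

def beam_search(beam_width, steps, start=0):
    beams = [([start], 0.0)]  # (token sequence, cumulative log-probability)
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            log_probs = jnp.log(NEXT_PROBS[seq[-1]])
            for token in range(NEXT_PROBS.shape[1]):
                candidates.append((seq + [token], score + float(log_probs[token])))
        # Keep only the beam_width highest-scoring sequences.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0]

greedy_seq, greedy_score = beam_search(beam_width=1, steps=2)  # same as greedy
beam_seq, beam_score = beam_search(beam_width=2, steps=2)
print(greedy_seq, beam_seq)
```

With a beam width of 1 the search commits to token 1 and ends with probability $0.5 \times 0.34 = 0.17$, while a beam width of 2 keeps token 2 alive and finds the sequence with probability $0.3 \times 0.9 = 0.27$. Because every step adds a non-positive log-probability, scores only decrease with length. Dividing the score by the sequence length is one common way to counter that bias.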


Let's implement the GPT-2 architecture using JAX/Flax and apply pretrained parameters from Hugging Face.

In [None]:
!nvidia-smi
!pip install --upgrade pip
!pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install flax transformers

In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import freeze, unfreeze
from flax.linen.attention import dot_product_attention
from transformers import FlaxGPT2LMHeadModel, GPT2Tokenizer
import numpy as np

In [None]:
# A custom dense (fully connected) layer similar to Flax's nn.Dense.
# Its kernel has shape (in_features, features), so Hugging Face kernels are transposed on load.
from typing import Any

class DenseT(nn.Module):
    features: int
    use_bias: bool = True
    kernel_init: Any = nn.initializers.lecun_normal()

    @nn.compact
    def __call__(self, inputs):
        in_features = inputs.shape[-1]
        kernel = self.param("kernel", self.kernel_init, (in_features, self.features))
        y = jnp.dot(inputs, kernel)
        if self.use_bias:
            bias = self.param("bias", nn.initializers.zeros, (self.features,))
            y = y + bias
        return y

# Apply a single linear layer (c_attn) to x to produce the concatenated query, key, and value vectors.
# Reshape each of them from (B, T, C) to (B, T, n_head, head_dim) for multi-head attention via dot_product_attention.
# After attention, reshape the output back to (B, T, C) and apply the output projection (c_proj).
# The causal mask is a lower-triangular boolean matrix of shape (T, T), broadcast over batch and heads.
class GPT2SelfAttention(nn.Module):
    n_embd: int
    n_head: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, mask=None, deterministic=True):
        B, T, C = x.shape
        head_dim = C // self.n_head

        # In the Hugging Face Flax GPT-2 parameters, the kernel of the fused query/key/value layer has shape (3*C, C).
        # DenseT expects a kernel of shape (C, 3*C), so the kernel is transposed during parameter conversion.
        x_3c = DenseT(3 * C, use_bias=True, name="c_attn")(x)
        x_3c = x_3c.reshape(B, T, 3, C)
        q, k, v = jnp.split(x_3c, 3, axis=2)
        q = jnp.squeeze(q, axis=2)
        k = jnp.squeeze(k, axis=2)
        v = jnp.squeeze(v, axis=2)

        # Multi-head splitting: Reshape query, key, and value tensors to (B, T, n_head, head_dim)
        q = q.reshape(B, T, self.n_head, head_dim)
        k = k.reshape(B, T, self.n_head, head_dim)
        v = v.reshape(B, T, self.n_head, head_dim)

        # Causal mask: Expand its dimensions to (B, n_head, T, T) and create a bias tensor.
        # Positions disallowed by the mask are set to a very negative value (-1e10), effectively blocking attention.
        if mask is not None:
            expanded_mask = jnp.broadcast_to(mask, (B, self.n_head, T, T))
            bias = jnp.where(expanded_mask, 0.0, -1e10)
        else:
            bias = None

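        # Compute dot-product attention using the queries, keys, values, and bias.
        # Dropout is applied only when not in deterministic mode; it then needs a PRNG key
        # from the "dropout" stream (pass rngs={"dropout": key} when applying the model).
        apply_dropout = not deterministic and self.dropout_rate > 0.0
        attn_out = dot_product_attention(
            query=q,
            key=k,
            value=v,
            bias=bias,
            dropout_rng=self.make_rng("dropout") if apply_dropout else None,
            dropout_rate=self.dropout_rate if apply_dropout else 0.0,
            deterministic=deterministic,
        )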
        # Reshape the attention output back to the original shape (B, T, C).
        attn_out = attn_out.reshape(B, T, C)

        # Apply the final output projection (c_proj).
        out = nn.Dense(C, use_bias=True, name="c_proj")(attn_out)
        return out

# Position-wise feed-forward network (MLP):
# 1. Expand the hidden dimension to 4*C with a fully connected layer.
# 2. Apply the GELU activation function.
# 3. Apply dropout for regularization.
# 4. Project back to the original dimension (C) with another fully connected layer.
# The module name is passed by the caller (name="ffn"); Flax already reserves the `name` field.
class FeedForward(nn.Module):
    n_embd: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic=True):
        hidden = DenseT(4 * self.n_embd, use_bias=True, name="fc")(x)
        hidden = nn.gelu(hidden)
        hidden = nn.Dropout(rate=self.dropout_rate)(hidden, deterministic=deterministic)
        out = DenseT(self.n_embd, use_bias=True, name="proj")(hidden)
        return out

# The output is computed as the element-wise sum of the token and positional embeddings.
class GPT2Embed(nn.Module):
    vocab_size: int
    max_length: int
    n_embd: int

    def setup(self):
        self.token_embed = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embd)
        self.pos_embed = self.param(
            "pos_embedding",
            nn.initializers.normal(stddev=0.02),
            (self.max_length, self.n_embd),
        )

    def __call__(self, input_ids):
        B, T = input_ids.shape
        # Compute token embeddings; the resulting tensor: (B, T, n_embd).
        token_emb = self.token_embed(input_ids)
        # Slice the positional embeddings to match the sequence length T: (T, n_embd).
        pos_emb = self.pos_embed[:T, :]
        # Expand dimensions of positional embeddings to (1, T, n_embd) for broadcasting.
        pos_emb = jnp.expand_dims(pos_emb, axis=0)
        # Combine token and positional embeddings.
        emb_out = token_emb + pos_emb
        return emb_out

class TransformerBlock(nn.Module):
    n_embd: int
    n_head: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, mask=None, deterministic=True):
        # Self-Attention Sub-layer
        residual = x  # Save input for the residual (skip) connection.
        x = nn.LayerNorm(epsilon=1e-5)(x)  # Apply LayerNorm before self-attention.
        x = GPT2SelfAttention(
            n_embd=self.n_embd,
            n_head=self.n_head,
            dropout_rate=self.dropout_rate,
            name="attn"
        )(x, mask=mask, deterministic=deterministic)
        x = residual + x  # Add the residual connection.

        # Feed-Forward (MLP) Sub-layer:
        residual = x  # Save current tensor for the residual connection.
        x = nn.LayerNorm(epsilon=1e-5)(x)  # Apply LayerNorm before the MLP.
        x = FeedForward(
            self.n_embd,
            self.dropout_rate,
            name="ffn"
        )(x, deterministic=deterministic)
        x = residual + x  # Add the residual connection.
        return x

class GPT2LMModel(nn.Module):
    vocab_size: int
    max_length: int
    n_embd: int
    n_head: int
    n_layer: int
    dropout_rate: float = 0.1

    def setup(self):
        # Initialize the embedding module.
        self.embed = GPT2Embed(
            vocab_size=self.vocab_size,
            max_length=self.max_length,
            n_embd=self.n_embd,
            name="embed"
        )
        # Create a list of Transformer blocks.
        self.blocks = [
            TransformerBlock(
                n_embd=self.n_embd,
                n_head=self.n_head,
                dropout_rate=self.dropout_rate,
                name=f"block_{i}"
            )
            for i in range(self.n_layer)
        ]
        # Final LayerNorm applied after the Transformer blocks.
        self.ln_f = nn.LayerNorm(epsilon=1e-5)
        # LM Head: A fully-connected layer that projects the hidden state to the vocabulary size.
        # Weight tying is performed by later setting lm_head.kernel equal to the transpose of token_embed.embedding.
        self.lm_head = nn.Dense(self.vocab_size, use_bias=False, name="lm_head")

    def __call__(self, input_ids, deterministic=True):
        B, T = input_ids.shape
        x = self.embed(input_ids)
        # Create a causal mask: generate a lower triangular matrix (T, T) to prevent attention to future tokens.
        causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        # Expand the mask dimensions to (B, 1, T, T) for broadcasting over batches and attention heads.
        causal_mask = causal_mask[None, None, :, :]
        causal_mask = jnp.broadcast_to(causal_mask, (B, 1, T, T))

        # Pass the input through each Transformer block with the causal mask applied.
        for blk in self.blocks:
            x = blk(x, mask=causal_mask, deterministic=deterministic)

        x = self.ln_f(x)  # Apply the final LayerNorm.
        logits = self.lm_head(x)  # Project to vocabulary size.
        return logits

    def forward_intermediate(self, input_ids, deterministic=True):
        """
        Extract intermediate outputs (hidden states) from each layer.
        This function is intended for debugging and comparison purposes.
        """
        B, T = input_ids.shape
        x = self.embed(input_ids)
        hidden_states = [x]  # Store the embedding layer output.
        # Create a causal mask.
        causal_mask = jnp.tril(jnp.ones((T, T), dtype=bool))[None, None, :, :]
        causal_mask = jnp.broadcast_to(causal_mask, (B, 1, T, T))

        # Collect the output of each Transformer block.
        for blk in self.blocks:
            x = blk(x, mask=causal_mask, deterministic=deterministic)
            hidden_states.append(x)
        x = self.ln_f(x)  # Apply the final LayerNorm.
        hidden_states[-1] = x  # Replace the last hidden state with the post-LayerNorm output.
        return hidden_states


# Parameter Conversion:
# Transpose c_attn weights from (3*C, C) to (C, 3*C).
# Transpose c_proj weights.
# Transpose c_fc weights from (4*C, C) to (C, 4*C) and c_proj weights from (C, 4*C) to (4*C, C).
def convert_hf_params_to_my_model(hf_params, my_params, n_layer=12):
    mp = unfreeze(my_params)

    # Copy token embeddings with shape (vocab_size, n_embd)
    mp["embed"]["token_embed"]["embedding"] = hf_params["transformer"]["wte"]["embedding"]
    # Copy positional embeddings with shape (max_length, n_embd)
    mp["embed"]["pos_embedding"] = hf_params["transformer"]["wpe"]["embedding"]

    # Copy parameters for each Transformer block
    for i in range(n_layer):
        hf_block = hf_params["transformer"]["h"][str(i)]
        block_key = f"block_{i}"

        # Pre-self-attention LayerNorm (ln_1): Copy bias and scale parameters.
        mp[block_key]["LayerNorm_0"]["bias"] = hf_block["ln_1"]["bias"]
        mp[block_key]["LayerNorm_0"]["scale"] = hf_block["ln_1"]["scale"]

        # Self-Attention:
        # For c_attn, transpose the kernel from (3*C, C) to (C, 3*C) and copy the bias.
        hf_c_attn_kernel = hf_block["attn"]["c_attn"]["kernel"]
        hf_c_attn_bias = hf_block["attn"]["c_attn"]["bias"]
        mp[block_key]["attn"]["c_attn"]["kernel"] = hf_c_attn_kernel.T
        mp[block_key]["attn"]["c_attn"]["bias"] = hf_c_attn_bias

        # For c_proj in self-attention, the kernel is square (C, C), so its shape does not change,
        # but it still has to be transposed so that rows index input features. The bias is copied as-is.
        hf_c_proj_kernel = hf_block["attn"]["c_proj"]["kernel"]
        hf_c_proj_bias = hf_block["attn"]["c_proj"]["bias"]
        mp[block_key]["attn"]["c_proj"]["kernel"] = hf_c_proj_kernel.T
        mp[block_key]["attn"]["c_proj"]["bias"] = hf_c_proj_bias

        # Pre-MLP LayerNorm (ln_2): Copy bias and scale parameters.
        mp[block_key]["LayerNorm_1"]["bias"] = hf_block["ln_2"]["bias"]
        mp[block_key]["LayerNorm_1"]["scale"] = hf_block["ln_2"]["scale"]

        # MLP: For the first linear layer (c_fc), transpose the kernel from (4*C, C) to (C, 4*C) and copy the bias.
        c_fc_kernel = hf_block["mlp"]["c_fc"]["kernel"]
        c_fc_bias = hf_block["mlp"]["c_fc"]["bias"]
        mp[block_key]["ffn"]["fc"]["kernel"] = c_fc_kernel.T
        mp[block_key]["ffn"]["fc"]["bias"] = c_fc_bias

        # For the second linear layer (c_proj) in the MLP, transpose the kernel from (C, 4*C) to (4*C, C) and copy the bias.
        c_proj_kernel = hf_block["mlp"]["c_proj"]["kernel"]
        c_proj_bias = hf_block["mlp"]["c_proj"]["bias"]
        mp[block_key]["ffn"]["proj"]["kernel"] = c_proj_kernel.T
        mp[block_key]["ffn"]["proj"]["bias"] = c_proj_bias

    # Final LayerNorm (ln_f): Copy bias and scale parameters.
    mp["ln_f"]["bias"] = hf_params["transformer"]["ln_f"]["bias"]
    mp["ln_f"]["scale"] = hf_params["transformer"]["ln_f"]["scale"]

    # LM Head: Tie the LM head to the token embeddings by transposing the token embedding.
    wte = hf_params["transformer"]["wte"]["embedding"]  # (vocab_size, n_embd)
    mp["lm_head"]["kernel"] = wte.T # (n_embd, vocab_size)

    return freeze(mp)

# Generation Functions
def greedy_generate(bound_model, tokenizer, prompt: str, max_new_tokens: int = 30):
    """
    Greedy generation:
    - At each step, selects the token with the highest probability.
    - Stops generation if the end-of-sequence token is produced.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="np")
    input_ids = jnp.array(input_ids, dtype=jnp.int32)
    for _ in range(max_new_tokens):
        logits = bound_model(input_ids)
        next_token_logits = logits[:, -1, :]
        next_token_id = jnp.argmax(next_token_logits, axis=-1)
        input_ids = jnp.concatenate([input_ids, next_token_id[:, None]], axis=1)
        # If the EOS token is generated, exit early.
        if next_token_id[0] == tokenizer.eos_token_id:
            break
    output_ids = np.array(input_ids[0])
    return tokenizer.decode(output_ids)

def sample_generate(bound_model, tokenizer, prompt: str, max_new_tokens: int = 30, temperature: float = 1.0, rng=jax.random.PRNGKey(0)):
    """
    Temperature Sampling:
    - At each step, scales logits by temperature and samples from the probability distribution.
    - Continues until max_new_tokens are generated or the EOS token is produced.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="np")
    input_ids = jnp.array(input_ids, dtype=jnp.int32)
    for _ in range(max_new_tokens):
        logits = bound_model(input_ids)
        next_token_logits = logits[:, -1, :] / temperature
        next_token_probs = jax.nn.softmax(next_token_logits)
        rng, subkey = jax.random.split(rng)
        next_token_id = jax.random.choice(subkey, next_token_probs.shape[-1], p=next_token_probs[0])
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token_id]], dtype=jnp.int32)], axis=1)
        if next_token_id == tokenizer.eos_token_id:
            break
    output_ids = np.array(input_ids[0])
    return tokenizer.decode(output_ids)

def top_k_generate(bound_model, tokenizer, prompt: str, max_new_tokens: int = 30, k: int = 50, temperature: float = 1.0, rng=jax.random.PRNGKey(0)):
    """
    Top-k generation:
    - At each decoding step, retains only the top k tokens with the highest logits.
    - All other token logits are set to a very low value (-1e10) to exclude them from sampling.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="np")
    input_ids = jnp.array(input_ids, dtype=jnp.int32)
    for _ in range(max_new_tokens):
        logits = bound_model(input_ids)
        next_token_logits = logits[:, -1, :] / temperature
        kth_value = jnp.sort(next_token_logits, axis=-1)[:, -k]
        filtered_logits = jnp.where(next_token_logits < kth_value[:, None], -1e10, next_token_logits)
        next_token_probs = jax.nn.softmax(filtered_logits)
        rng, subkey = jax.random.split(rng)
        next_token_id = jax.random.choice(subkey, next_token_probs.shape[-1], p=next_token_probs[0])
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token_id]], dtype=jnp.int32)], axis=1)
        if next_token_id == tokenizer.eos_token_id:
            break
    output_ids = np.array(input_ids[0])
    return tokenizer.decode(output_ids)

def top_p_generate(bound_model, tokenizer, prompt: str, max_new_tokens: int = 30,
                   top_p: float = 0.9, temperature: float = 1.0, rng=jax.random.PRNGKey(0)):
    """
    Top-p generation:
    - At each step, sorts token probabilities in descending order and computes their cumulative sum.
    - Retains the smallest set of tokens whose cumulative probability exceeds top_p.
    - Tokens with a probability below that of the last retained token are suppressed.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="np")
    input_ids = jnp.array(input_ids, dtype=jnp.int32)
    for _ in range(max_new_tokens):
        logits = bound_model(input_ids)
        next_token_logits = logits[:, -1, :] / temperature
        next_token_probs = jax.nn.softmax(next_token_logits)
        sorted_probs = jnp.sort(next_token_probs, axis=-1)[:, ::-1]
        sorted_indices = jnp.argsort(next_token_probs, axis=-1)[:, ::-1]
        cumulative_probs = jnp.cumsum(sorted_probs, axis=-1)
        cutoff = cumulative_probs > top_p
        cutoff_indices = jnp.argmax(cutoff, axis=-1)
        threshold = sorted_probs.at[jnp.arange(sorted_probs.shape[0]), cutoff_indices].get()
        filtered_logits = jnp.where(next_token_probs < threshold[:, None], -1e10, next_token_logits)
        next_token_probs = jax.nn.softmax(filtered_logits)
        p_1d = next_token_probs[0]
        rng, subkey = jax.random.split(rng)
        next_token_id = jax.random.choice(
            subkey,
            p_1d.shape[0],
            p=p_1d
        )
        input_ids = jnp.concatenate(
            [input_ids, jnp.array([[next_token_id]], dtype=jnp.int32)],
            axis=1
        )
        if next_token_id == tokenizer.eos_token_id:
            break
    output_ids = np.array(input_ids[0])
    return tokenizer.decode(output_ids)

def beam_search_generate(bound_model, tokenizer, prompt: str, beam_width: int = 3, max_new_tokens: int = 30):
    """
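    Beam Search:
    - At each step, expands each beam with its beam_width most probable next tokens.
    - Retains only the best beam_width sequences based on cumulative log-probabilities.
    - Stops as soon as any retained beam ends with the EOS token, then returns the
      highest-scoring beam (which is not necessarily the one that produced EOS).
    - Note: This implementation does not support length penalty adjustments.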
    """
    input_ids = tokenizer.encode(prompt, return_tensors="np")
    input_ids = jnp.array(input_ids, dtype=jnp.int32)
    beams = [(input_ids, 0.0)]
    for _ in range(max_new_tokens):
        new_beams = []
        for seq, score in beams:
            logits = bound_model(seq)
            next_token_logits = logits[:, -1, :]
            log_probs = jax.nn.log_softmax(next_token_logits)
            topk_log_probs, topk_indices = jax.lax.top_k(log_probs, beam_width)
            for i in range(beam_width):
                new_seq = jnp.concatenate([seq, topk_indices[:, i:i+1]], axis=1)
                new_score = score + float(topk_log_probs[0, i])
                new_beams.append((new_seq, new_score))
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]
        if any(b[0][0, -1] == tokenizer.eos_token_id for b in beams):
            break
    best_seq = beams[0][0]
    return tokenizer.decode(np.array(best_seq[0]))


# Functions for Comparing Intermediate Layer Outputs
# Run the Hugging Face model and retrieve all hidden states.
def forward_intermediate_hf(hf_model, input_ids):
    outputs = hf_model(input_ids, output_hidden_states=True)
    return outputs.hidden_states

# Run the custom model and retrieve intermediate outputs from each layer.
def forward_intermediate_my(bound_model, input_ids):
    return bound_model.forward_intermediate(input_ids, deterministic=True)

# Compare and print the L2 norm differences between corresponding hidden states from both models.
def compare_intermediate_states(hf_states, my_states):
    n = min(len(hf_states), len(my_states))
    print(f"\n[Compare intermediate states] HF layers: {len(hf_states)}, MY layers: {len(my_states)}")
    for i in range(n):
        diff = hf_states[i] - my_states[i]
        diff_norm = jnp.linalg.norm(diff)
        print(f"Layer {i}: shape={diff.shape}, L2 diff = {float(diff_norm):.4f}")


In [None]:
hf_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
hf_params = hf_model.params

my_model = GPT2LMModel(
    vocab_size=tokenizer.vocab_size,
    max_length=1024,  # Default maximum sequence length for GPT-2.
    n_embd=768,
    n_head=12,
    n_layer=12,
    dropout_rate=0.1
)
rng = jax.random.PRNGKey(0)
dummy_input_ids = jnp.zeros((1, 1), dtype=jnp.int32)
variables = my_model.init(rng, dummy_input_ids)
my_params = variables["params"]

# Convert the Hugging Face parameters to the custom model's format.
converted_params = convert_hf_params_to_my_model(hf_params, my_params, n_layer=12)

# Bind the converted parameters to my model.
bound_model = my_model.bind({"params": converted_params})

prompt_text = "I am a dog"
output_greedy = greedy_generate(bound_model, tokenizer, prompt_text)
output_sample = sample_generate(bound_model, tokenizer, prompt_text, rng=jax.random.PRNGKey(42))
output_top_k = top_k_generate(bound_model, tokenizer, prompt_text, k=40, rng=jax.random.PRNGKey(123))
output_top_p = top_p_generate(bound_model, tokenizer, prompt_text, top_p=0.9, rng=jax.random.PRNGKey(321))
output_beam = beam_search_generate(bound_model, tokenizer, prompt_text, beam_width=5, max_new_tokens=30)

print("===== Prompt =====")
print(prompt_text)
print("===== Greedy Generation =====")
print(output_greedy)
print("===== Temperature Sampling =====")
print(output_sample)
print("===== Top-k Generation =====")
print(output_top_k)
print("===== Top-p Generation =====")
print(output_top_p)
print("===== Beam Search Generation =====")
print(output_beam)

# Compare intermediate layer outputs for validation.
test_input_ids = tokenizer(prompt_text, return_tensors="np")["input_ids"]
hf_states = forward_intermediate_hf(hf_model, test_input_ids)
my_states = forward_intermediate_my(bound_model, test_input_ids)
compare_intermediate_states(hf_states, my_states)

===== Prompt =====
I am a dog
===== Greedy Generation =====
I am a dog lover. I love to play with my dog and I love to play with my dog. I love to play with my dog and I love to play
===== Temperature Sampling =====
I am a dogist and don't like to confront someone for no ludicrous reason. We all have different walks of life and while I believe that some of them do make
===== Top-k Generation =====
I am a dog lover", she says, with a laugh. "So I was really looking forward to watching my dog play, and I had a friend come by the
===== Top-p Generation =====
I am a dog trainer. To the outside world I am rude and not real nice and to the people inside I am am crazy, disrespectful, cold, domineering
===== Beam Search Generation =====
I am a dog lover. I love to play with dogs. I love to play with dogs. I love to play with dogs. I love to play with dogs.

[Compare intermediate states] HF layers: 13, MY layers: 13
Layer 0: shape=(1, 4, 768), L2 diff = 0.0000
Layer 1: shape=(1, 4, 768), L2 dif