In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%reload_ext autoreload
%autoreload 2

In [7]:
import flax.nnx as nn
import jax.numpy as jnp
import jax


import tiktoken
enc = tiktoken.get_encoding('gpt2')

# q = "Hello, I'm a language model,"
q = "Capital of India is"
tokens = enc.encode(q)
tokens = jnp.expand_dims(jnp.array(tokens), axis=0)
tokens = jnp.repeat(tokens, 5, axis=0)

from jax_gpt2 import GPT, GPTConfig
# model = GPT.from_pretrained_flax('gpt2')
model = GPT.from_pretrained('gpt2')

step_key = jax.random.key(0)

while tokens.shape[1] < 30: # max_length=30
    # forward the model to get the logits
    logits = model(tokens) # (B, T, vocab_size) 
    # take the logits at the last position
    logits = logits[:, -1, :] # (B, vocab_size)
    # keep the raw logits: jax.random.categorical expects unnormalized log-probabilities,
    # so applying nn.softmax first and sampling from the result causes poor generations
    # do top-k sampling of 50 (huggingface pipeline default)
    # top_logits and top_tokens are both (5, 50)
    top_logits, top_tokens = jax.lax.top_k(logits, min(50, logits.shape[-1]))
    step_key, subkey = jax.random.split(step_key)
    token_idx = jax.random.categorical(subkey, top_logits, axis=-1)
    # map the sampled position within the top-k back to a vocabulary id
    next_token = jnp.take_along_axis(top_tokens, token_idx[:, None], axis=-1) # (B, 1)
    # append to the sequence
    tokens = jnp.concatenate((tokens, next_token), axis=1) # (B, T+1)
    # print(f"Updated value of tokens.shape[1]: {tokens.shape[1]}")

# print the generated text

for i in range(5):
    x = tokens[i, :30].tolist()
    decoded = enc.decode(x)
    print(">", decoded)

loading weights from pretrained gpt: gpt2


Length of pytorch state dict: 149
Length of prepared JAX modules dict: 76
Total JAX matrices: 149
Transposing  lm_head
> Capital of India is set to launch its first full-service Indian bank service next month.

The bank, started in 2007, will provide loans
> Capital of India is on a collision course with Pakistan to ensure that it will not be forced to defend the rights of its citizens with weapons. It continues
> Capital of India is not a government funded company, and its executives are private. The company has just received money by selling shares in two state companies that
> Capital of India is now expanding ahead of the financial year in February, according to a regulatory report released today. According to the report, India's growth
> Capital of India is an integral step in our country's economic prosperity and has also helped drive India's economic development in several areas, such as investment,
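
A note on the commented-out softmax in the loop above: `jax.random.categorical` takes logits and normalizes them itself, so handing it probabilities amounts to applying softmax twice. The sketch below uses made-up toy logits (not GPT-2 outputs) to show how much flatter the sampling distribution becomes in that case.

```python
import jax
import jax.numpy as jnp

# Toy logits for a 4-token vocabulary (illustrative values, not from GPT-2).
logits = jnp.array([4.0, 2.0, 0.0, -2.0])
probs = jax.nn.softmax(logits)

# categorical() normalizes its input itself, so the effective distributions are:
correct = jax.nn.softmax(logits)   # logits passed in directly
flattened = jax.nn.softmax(probs)  # probabilities mistakenly passed as logits

print("sampling distribution from logits:", correct)
print("sampling distribution from probs: ", flattened)

# Empirical check: sample many times each way and count how often token 0 wins.
k1, k2 = jax.random.split(jax.random.key(0))
n = 20_000
s_logits = jax.random.categorical(k1, logits, shape=(n,))
s_probs = jax.random.categorical(k2, probs, shape=(n,))
frac_logits = float(jnp.mean(s_logits == 0))
frac_probs = float(jnp.mean(s_probs == 0))
print("token 0 picked (logits):", frac_logits)
print("token 0 picked (probs): ", frac_probs)
```

Because probabilities all lie in [0, 1], treating them as logits squeezes the differences between tokens, so unlikely tokens get sampled far more often than they should.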


# Issues faced

* Initially, loading the Hugging Face PyTorch weights did not work at all.
* Ran into an infinite loop and mistook it for JAX being slow! The actual cause was that tokens.shape was not getting updated in the generate loop, so the while condition never became false.
* After much struggle, realized that I was returning x instead of the processed y in the self-attention code!
* The attention code from https://github.com/cgarciae/nanoGPT-jax and https://github.com/jenkspt/gpt-jax helped get a first working version.
* The attention version using dot_product_attention and make_causal_mask from flax.nnx was fixed next by correcting the transpose.
* The version with weights from HF Flax was still not working. The model result comparison code in GPT._compare helped narrow the issue down to the Attention block.
* The problem was a missing transpose when loading the attn.c_proj weights.
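
One plausible reason the last issue was hard to spot: assuming attn.c_proj is a square (n_embd, n_embd) matrix, as in GPT-2, a kernel loaded without its transpose has exactly the same shape as the correct one, so no shape check fails. The sketch below illustrates this with toy sizes and hypothetical variable names; it is not the loader from jax_gpt2, and the checkpoint layout is an assumption made for the example.

```python
import jax
import jax.numpy as jnp

n_embd = 8  # toy size for illustration; GPT-2 small uses 768
k_w, k_x = jax.random.split(jax.random.key(0))

# Assumption: the checkpoint stores this weight as (out_features, in_features).
w_checkpoint = jax.random.normal(k_w, (n_embd, n_embd))
x = jax.random.normal(k_x, (2, n_embd))  # (batch, n_embd)

# A layer that computes x @ kernel needs the kernel as (in_features, out_features).
kernel_ok = w_checkpoint.T
kernel_bad = w_checkpoint  # transpose forgotten

# For a square matrix both kernels have the same shape, so shape checks pass.
same_shape = kernel_ok.shape == kernel_bad.shape

# The outputs, however, differ.
y_ok = x @ kernel_ok
y_bad = x @ kernel_bad
max_diff = float(jnp.max(jnp.abs(y_ok - y_bad)))
print("same shape:", same_shape, "| max output difference:", max_diff)
```

A non-square weight loaded the wrong way round would raise a shape error immediately, which is why comparing intermediate outputs against a reference model, block by block, is a more reliable check than shapes alone.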