In [2]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%reload_ext autoreload
%autoreload 2

# Testing the generation code using huggingface models

In [3]:
import flax.nnx as nn
import numpy as np
import jax.numpy as jnp
import jax

import tiktoken
enc = tiktoken.get_encoding('gpt2')

# q = "Hello, I'm a language model,"
q = "Capital of India is"
tokens = enc.encode(q)
tokens = jnp.expand_dims(jnp.array(tokens), axis=0)
tokens = jnp.repeat(tokens, 5, axis=0)

from transformers import GPT2Tokenizer, FlaxGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# model = FlaxGPT2Model.from_pretrained('gpt2')
model = FlaxGPT2LMHeadModel.from_pretrained('gpt2')

step_key = jax.random.key(0)

while tokens.shape[1] < 30: # max_length=30
    # forward the model to get the logits
    logits = model(tokens).logits # (B, T, vocab_size) 
    # take the logits at the last position
    logits = logits[:, -1, :] # (B, vocab_size)
    # get the probabilities
    probs = nn.softmax(logits, axis=-1)
    # do top-k sampling of 50 (huggingface pipeline default)
    # topk_probs here becomes (5, 50), topk_indices is (5, 50)
    # top_logits, top_tokens = jax.lax.top_k(logits, min(50, logits.shape[-1]))
    top_logits, top_tokens = jax.lax.top_k(logits, min(50, logits.shape[-1]))
    step_key, subkey = jax.random.split(step_key)
    token_idx = jax.random.categorical(step_key, top_logits, axis=-1)
    next_token = jnp.take_along_axis(top_tokens, token_idx[:, None], axis=-1).squeeze(-1)
    tokens = jnp.concatenate((tokens, jnp.vstack(next_token)), axis=1)
    # print(f"Updated value of tokens.shape[1]: {tokens.shape[1]}")

# print the generated text

for i in range(5):
    x = tokens[i, :30].tolist()
    decoded = enc.decode(x)
    print(">", decoded)

> Capital of India is also developing a program called AOID."AOID is a unique approach to digital rights management. It will become available to
> Capital of India is not a sovereign or democratic government. A sovereign or democratically elected government is not a sovereign state or democratic corporation.

India has
> Capital of India is facing rising cost pressures and rising borrowing costs, and an ever-shrinking middle class. It seems clear that Narendra Modi is one
> Capital of India is also the most expensive financial institution we have seen in our history. Even if BIC had been in the financial sector in 1989,
> Capital of India is currently in transition to create a more diversified economy. India's growth has been very sluggish since the middle of 2007 but its share


In [4]:
import torch
from torch.nn import functional as F
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2") # 124M
model.eval()
torch.manual_seed(42)

import tiktoken
enc = tiktoken.get_encoding('gpt2')

# q = "Hello, I'm a language model,"
q = "Capital of India is"
tokens = enc.encode(q)
tokens = torch.tensor(tokens, dtype=torch.long) # (8,)
tokens = tokens.unsqueeze(0).repeat(5, 1) # (5, 8)
x = tokens

# generate!
while x.size(1) < 30: # max_length=30
    # forward the model to get the logits
    with torch.no_grad():
        logits = model(x)[0] # (B, T, vocab_size)
        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)        # do top-k sampling of 50 (huggingface pipeline default)
        # topk_probs here becomes (5, 50), topk_indices is (5, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1) # (B, 1)
        # gather the corresponding indices
        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
        # append to the sequence
        x = torch.cat((x, xcol), dim=1)

# print the generated text
for i in range(5):
    tokens = x[i, :30].tolist()
    decoded = enc.decode(tokens)
    print(">", decoded)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

<torch._C.Generator at 0x7ffb107b18d0>

> Capital of India is very interested in providing its own private development project. It took nearly a year or longer before he announced the acquisition of TPG.
> Capital of India is under siege, and its financial losses have exceeded Rs1.8 billion. In response, the US Treasury has asked all major states
> Capital of India is India's largest investment bank. In fact, it has nearly all of India's major cities. In 2012-13, a total
> Capital of India is on the verge of defaulting on its debt, which it believes is too high.<|endoftext|>Gardiner Express has confirmed the news
> Capital of India is going through great decline."

Mr. Modi will likely do what is in his best interests for the country, as he will


# GPT2 itself is not that capable by the looks? But generation process is OK

# Comparing results from HF Pytorch model weights and HF Flax model weights - Flax was failing initially(gibberish)

In [5]:
import jax.numpy as jnp

import tiktoken
enc = tiktoken.get_encoding('gpt2')

# q = "Hello, I'm a language model,"
q = "Capital of India is"
tokens = enc.encode(q)
tokens = jnp.expand_dims(jnp.array(tokens), axis=0)

from jax_gpt2 import GPT
model_flax = GPT.from_pretrained_flax('gpt2')  
model_hf = GPT.from_pretrained('gpt2')

model_hf._compare(model_flax, tokens)

loading weights from pretrained gpt: gpt2


Length of prepared JAX modules dict: 76
loading weights from pretrained gpt: gpt2
Length of pytorch state dict: 149
Length of prepared JAX modules dict: 76
Transposing  lm_head
Checking wpe: True
Checking wte: True
Checking post token embedding + position embedding: True
Checking block0 - layernorm 1: True
Checking block0 - self attention: True
Checking block0 - post residual: True
Checking block0 - layernorm 2: True
Checking block0 - MLP: True
Checking block0 - post residual: True
Checking block1 - layernorm 1: True
Checking block1 - self attention: True
Checking block1 - post residual: True
Checking block1 - layernorm 2: True
Checking block1 - MLP: True
Checking block1 - post residual: True
Checking block2 - layernorm 1: True
Checking block2 - self attention: True
Checking block2 - post residual: True
Checking block2 - layernorm 2: True
Checking block2 - MLP: True
Checking block2 - post residual: True
Checking block3 - layernorm 1: True
Checking block3 - self attention: True
Checking

# Had a doubt if the weights loaded from HF Flax model were somehow not getting updated - not entirely sensible

In [6]:
from transformers import FlaxGPT2LMHeadModel
model = FlaxGPT2LMHeadModel.from_pretrained('gpt2')
from flax.core import unfreeze
from flax.traverse_util import flatten_dict
params = unfreeze(model.params['transformer'])
params = flatten_dict(params, sep='.')

from jax_gpt2 import GPT
model_flax = GPT.from_pretrained_flax('gpt2')  
jax_modules_dict = {}
for module_pair in model_flax.iter_modules():
    if type(module_pair[1]).__name__  in ['Block', 'CausalSelfAttention', 'GPT', 'MLP']:
        continue
    module_path = '.'.join([str(x) for x in module_pair[0]])
    module = module_pair[1]
    jax_modules_dict[module_path] = module

for param in params:
    t = param.split('.')[-1]    # Inner key Eg. wpe.embedding
    jax_key = '.'.join(param.split('.')[:-1:])
    if params[param].shape == jax_modules_dict[jax_key].__dict__[t].value.shape:
        print(f"Checking {param}: {jnp.allclose(params[param], jax_modules_dict[jax_key].__dict__[t].value)}")
    elif params[param].T.shape == jax_modules_dict[jax_key].__dict__[t].value.shape:
        print(f"Checking {param}: {jnp.allclose(params[param].T, jax_modules_dict[jax_key].__dict__[t].value)}")
    else:
        print(f"Shape mismatch for {param}")    

loading weights from pretrained gpt: gpt2
Length of prepared JAX modules dict: 76
Checking h.0.attn.c_attn.bias: True
Checking h.0.attn.c_attn.kernel: True
Checking h.0.attn.c_proj.bias: True
Checking h.0.attn.c_proj.kernel: False
Checking h.0.ln_1.bias: True
Checking h.0.ln_1.scale: True
Checking h.0.ln_2.bias: True
Checking h.0.ln_2.scale: True
Checking h.0.mlp.c_fc.bias: True
Checking h.0.mlp.c_fc.kernel: True
Checking h.0.mlp.c_proj.bias: True
Checking h.0.mlp.c_proj.kernel: True
Checking h.1.attn.c_attn.bias: True
Checking h.1.attn.c_attn.kernel: True
Checking h.1.attn.c_proj.bias: True
Checking h.1.attn.c_proj.kernel: False
Checking h.1.ln_1.bias: True
Checking h.1.ln_1.scale: True
Checking h.1.ln_2.bias: True
Checking h.1.ln_2.scale: True
Checking h.1.mlp.c_fc.bias: True
Checking h.1.mlp.c_fc.kernel: True
Checking h.1.mlp.c_proj.bias: True
Checking h.1.mlp.c_proj.kernel: True
Checking h.10.attn.c_attn.bias: True
Checking h.10.attn.c_attn.kernel: True
Checking h.10.attn.c_proj.bi