In [1]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
%reload_ext autoreload
%autoreload 2

In [2]:
from jax_gpt2 import GPT, GPTConfig
model = GPT.from_pretrained('gpt2')
model

  from .autonotebook import tqdm as notebook_tqdm


loading weights from pretrained gpt: gpt2
Length of pytorch state dict: 149
Length of prepared JAX modules dict: 76
Transposing  lm_head


GPT(
  config=GPTConfig(block_size=1024, vocab_size=50257, n_layer=12, n_head=12, n_embd=768),
  wte=Embed(
    embedding=Param(
      value=Array(shape=(50257, 768), dtype=float32)
    ),
    num_embeddings=50257,
    features=768,
    dtype=dtype('float32'),
    param_dtype=<class 'jax.numpy.float32'>,
    embedding_init=<function variance_scaling.<locals>.init at 0x7f7b84c03e50>
  ),
  wpe=Embed(
    embedding=Param(
      value=Array(shape=(1024, 768), dtype=float32)
    ),
    num_embeddings=1024,
    features=768,
    dtype=dtype('float32'),
    param_dtype=<class 'jax.numpy.float32'>,
    embedding_init=<function variance_scaling.<locals>.init at 0x7f7b84c03e50>
  ),
  h=[Block(
    ln_1=LayerNorm(
      scale=Param(
        value=Array(shape=(768,), dtype=float32)
      ),
      bias=Param(
        value=Array(shape=(768,), dtype=float32)
      ),
      num_features=768,
      epsilon=1e-05,
      dtype=None,
      param_dtype=<class 'jax.numpy.float32'>,
      use_bias=True,
 

In [4]:
import jax.numpy as jnp

type(jnp.array([[10,11]]))
jnp.array([[10,11]]).shape

model(jnp.array([[10,11]])).shape

jaxlib.xla_extension.ArrayImpl

(1, 2)

(1, 2, 50257)

# Model load success

In [107]:
import tiktoken
enc = tiktoken.get_encoding('gpt2')

q = "Hello, I'm a language model,"
tokens = enc.encode(q)
tokens = jnp.expand_dims(jnp.array(tokens), axis=0)
tokens = jnp.repeat(tokens, 5, axis=0)
tokens.shape

(5, 8)

In [108]:
preds = model(tokens)
preds.shape
preds.shape[1]

(5, 8, 50257)

8

# Impatient generate attempt by converting JAX model outputs to Pytorch

In [1]:
import flax.nnx as nn
import torch
from torch.nn import functional as F
import numpy as np
import jax.numpy as jnp

import tiktoken
enc = tiktoken.get_encoding('gpt2')

q = "Hello, I'm a language model,"
tokens = enc.encode(q)
tokens = jnp.expand_dims(jnp.array(tokens), axis=0)
tokens = jnp.repeat(tokens, 5, axis=0)

from jax_gpt2 import GPT, GPTConfig
model = GPT.from_pretrained('gpt2')

# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("gpt2")
# q = "Hello, I'm a language model,"
# tokens = tokenizer.encode(q)
# tokens = jnp.expand_dims(jnp.array(tokens), axis=0)
# tokens = jnp.repeat(tokens, 5, axis=0)

torch.manual_seed(42)

while tokens.shape[1] < 30: # max_length=30
    # forward the model to get the logits
    logits = model(tokens) # (B, T, vocab_size) 
    logits = torch.from_numpy(np.array(logits))
    # take the logits at the last position
    logits = logits[:, -1, :] # (B, vocab_size)
    # get the probabilities
    probs = F.softmax(logits, dim=-1)
    # do top-k sampling of 50 (huggingface pipeline default)
    # topk_probs here becomes (5, 50), topk_indices is (5, 50)
    topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
    # select a token from the top-k probabilities
    # note: multinomial does not demand the input to sum to 1
    ix = torch.multinomial(topk_probs, 1) # (B, 1)
    # gather the corresponding indices
    xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
    # append to the sequence
    tokens = torch.from_numpy(np.asarray(tokens))
    tokens = torch.cat((tokens, xcol), dim=1)
    tokens = jnp.array(tokens.cpu().numpy())

# print the generated text

tokens = torch.from_numpy(np.array(tokens))
for i in range(5):
    x = tokens[i, :30].tolist()
    decoded = enc.decode(x)
    print(">", decoded)

  from .autonotebook import tqdm as notebook_tqdm


loading weights from pretrained gpt: gpt2
Length of pytorch state dict: 149
Length of prepared JAX modules dict: 76
Transposing  lm_head


  tokens = torch.from_numpy(np.asarray(tokens))


> Hello, I'm a language model, you want to work the first one important time,
With I'm't only 1-level A 'R
> Hello, I'm a language model, that you can't be asked this a big difference or the latest to the most time, he won't ask
> Hello, I'm a language model, we must be
I's now in the current and an final final, for more than The two. It
> Hello, I'm a language model, (This is still need to try, the same thing like in a good news in that may take them "
> Hello, I'm a language model, they could then have to the very small number of of 'In his new in the most a few days after


# Argmax attempts

In [112]:
import tiktoken
enc = tiktoken.get_encoding('gpt2')

q = "Hello, I'm a language model,"
tokens = enc.encode(q)
tokens = jnp.expand_dims(jnp.array(tokens), axis=0)
tokens = jnp.repeat(tokens, 5, axis=0)

while tokens.shape[1] < 30: # max_length=30
    # forward the model to get the logits
    logits = model(tokens) # (B, T, vocab_size) 
    # take the logits at the last position
    logits = logits[:, -1, :] # (B, vocab_size)
    # get the probabilities
    probs = nn.softmax(logits, axis=-1)
    amax = probs.argmax(axis=-1)
    tokens = jnp.concatenate((tokens, jnp.vstack(amax)), axis=1)

# print the generated text

for i in range(5):
    x = tokens[i, :30].tolist()
    decoded = enc.decode(x)
    print(">", decoded)

> Hello, I'm a language model, the same, the same, the same, the same, the same, the same, the same, the
> Hello, I'm a language model, the same, the same, the same, the same, the same, the same, the same, the
> Hello, I'm a language model, the same, the same, the same, the same, the same, the same, the same, the
> Hello, I'm a language model, the same, the same, the same, the same, the same, the same, the same, the
> Hello, I'm a language model, the same, the same, the same, the same, the same, the same, the same, the


In [113]:
import torch
from torch.nn import functional as F
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2") # 124M
model.eval()
torch.manual_seed(42)
tokens = [15496, 11, 314, 1101, 257, 3303, 2746, 11] # "Hello, I'm a language model,"
tokens = torch.tensor(tokens, dtype=torch.long) # (8,)
tokens = tokens.unsqueeze(0).repeat(5, 1) # (5, 8)
x = tokens

# generate!
while x.size(1) < 30: # max_length=30
    # forward the model to get the logits
    with torch.no_grad():
        logits = model(x)[0] # (B, T, vocab_size)
        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)
        amax = probs.argmax(dim=-1).reshape(x.size(0),1)
        # append to the sequence
        x = torch.cat((x, amax), dim=1)

# print the generated text
import tiktoken
enc = tiktoken.get_encoding('gpt2')
for i in range(5):
    tokens = x[i, :30].tolist()
    decoded = enc.decode(tokens)
    print(">", decoded)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

<torch._C.Generator at 0x7f426f0da290>

> Hello, I'm a language model, not a programming language. I'm a language model. I'm a language model. I'm a language model
> Hello, I'm a language model, not a programming language. I'm a language model. I'm a language model. I'm a language model
> Hello, I'm a language model, not a programming language. I'm a language model. I'm a language model. I'm a language model
> Hello, I'm a language model, not a programming language. I'm a language model. I'm a language model. I'm a language model
> Hello, I'm a language model, not a programming language. I'm a language model. I'm a language model. I'm a language model


# Test a Pytorch GPT2, does it really work? It does

In [115]:
import torch
from torch.nn import functional as F
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2") # 124M
model.eval()
torch.manual_seed(42)
tokens = [15496, 11, 314, 1101, 257, 3303, 2746, 11] # "Hello, I'm a language model,"
tokens = torch.tensor(tokens, dtype=torch.long) # (8,)
tokens = tokens.unsqueeze(0).repeat(5, 1) # (5, 8)
x = tokens

# generate!
while x.size(1) < 30: # max_length=30
    # forward the model to get the logits
    with torch.no_grad():
        logits = model(x)[0] # (B, T, vocab_size)
        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)        # do top-k sampling of 50 (huggingface pipeline default)
        # topk_probs here becomes (5, 50), topk_indices is (5, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1) # (B, 1)
        # gather the corresponding indices
        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
        # append to the sequence
        x = torch.cat((x, xcol), dim=1)

# print the generated text
import tiktoken
enc = tiktoken.get_encoding('gpt2')
for i in range(5):
    tokens = x[i, :30].tolist()
    decoded = enc.decode(tokens)
    print(">", decoded)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

<torch._C.Generator at 0x7f426f0da290>

> Hello, I'm a language model, not a programming platform! I just make decisions based on other projects. I try to do that."


> Hello, I'm a language model, a kind of a "first class citizen" of the world and a person that comes from a much more egalitarian
> Hello, I'm a language model, and I'm starting to talk about the notion of the syntax, and I'm also working on an extension that
> Hello, I'm a language model, because I'm writing real-time. I'm writing all languages. And I'm working with languages for me
> Hello, I'm a language model, I don't know where to begin but I know there is a big deal going on with our society. What


# Is it a datatype thing? No by the looks

In [118]:
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2") 
model.state_dict()["transformer.wte.weight"].dtype

torch.float32

# Tried a pure JAX generate and ran into a forever loop!! Thought it was slowness!!

# Continuing in troubleshoot_gibberish.ipynb