In [8]:
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
config = GPT2Config()
# Apparently GPT2 ties the last linear layer to the initial word embeddings
# so the final layer is wte.embedding.T (768, 50257)
# config.tie_word_embeddings = False
hf_model = FlaxGPT2LMHeadModel.from_pretrained("openai-community/gpt2")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = hf_model(**inputs)

# retrieve logts for next token
next_token_logits = outputs.logits[:, -1]

In [102]:
next_token_logits.shape

(1, 50257)

In [1]:
def get_keys(params):
    out = []
    def print_keys_flat(params, prefix=""):
        for p in params:
            if isinstance(params[p], dict):
                print_keys_flat(params[p], f"{prefix}.{p}")
            else:
                out.append((f"{prefix}.{p}", params[p].shape))
    print_keys_flat(params)
    return out
            
def print_keys(params, offset=""):
    for p in params:
        if not isinstance(params[p], dict):
            print(offset+p, params[p].shape)
        else:
            print(offset+p)
            print_keys(params[p], offset + "  ")

In [95]:
from transformers import pipeline, set_seed

generator = pipeline('text-generation', model='gpt2')
set_seed(42)
generator("Hello, I'm a language model,", max_length=30, num_return_sequences=5)

Device set to use mps:0
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': "Hello, I'm a language model, and my project will get better with time, but I think there are a lot more things that can help you"},
 {'generated_text': "Hello, I'm a language model, not a language model, so if I don't have a problem, I can fix it by creating new words"},
 {'generated_text': "Hello, I'm a language model, and I'm trying to learn some stuff. I'll try to do some basic programming and just learn better ways"},
 {'generated_text': "Hello, I'm a language model, but I don't believe in grammar. This will work for every language model. You can define it very quickly"},
 {'generated_text': 'Hello, I\'m a language model, a model of how things should be, and then we look at different things as well." I\'d like to'}]

In [2]:
import gpt2 as nn
from importlib import reload
reload(nn)
import jax.numpy as jnp
import jax

model = nn.GPT2(nn.Config())
params = model.from_pretrained('gpt2')

# model_blank = nn.GPT2(nn.Config())
key = jax.random.PRNGKey(42069)
# dummy = jnp.ones((1,1), dtype=int)
# params = model.init(key,dummy)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from functools import partial
block_size = nn.Config().block_size

def _gen_step(carry, rng):
    window, = carry  # window: (B, block_size)
    logits = model.apply(params, window)
    next_token = jax.random.categorical(rng, logits[:, -1, :])  # (B,)
    new_window = jnp.concatenate([window[:, 1:], next_token[:,None]], axis=1) #add next_token[:,None] here instead
    return (new_window,), next_token #return next_token without adding extra dimension


# @partial(jax.jit, static_argnums=(3,))
def generate_batch(params, init_idx, key, max_new_tokens: int):
    """
    params:        your model params pytree
    init_idx:      int32 array of shape (B, T0) with T0 <= block_size
    key:           a PRNGKey
    max_new_tokens: number of new tokens to sample
    returns: full_seq (B, T0+max_new_tokens), all_new_tokens (max_new_tokens, B)
    """
    B, T0 = init_idx.shape
    assert T0 <= block_size, f"Context length must be ≤ block_size ({block_size}), got {T0}"

    # left-pad init_idx up to block_size so our carry window is fixed-size
    pad_len     = block_size - T0
    init_window = jnp.pad(init_idx, ((0,0), (pad_len,0)), constant_values=0)  # (B, block_size)

    # split RNG into one key per token
    keys = jax.random.split(key, max_new_tokens)

    # run the scan
    (final_window,), new_tokens = jax.lax.scan(
        _gen_step,
        (init_window,),  # initial carry
        keys             # scan over these RNGs
    )
    # new_tokens: (max_new_tokens, B)

    # rebuild the full generated sequence
    #   - take the tail of the init_window to recover the original context
    #   - concatenate with the newly sampled tokens
    context = init_window[:, pad_len:]               # (B, T0)
    #gen_seq = jnp.transpose(new_tokens, (1,0))       # (B, max_new_tokens)
    gen_seq = new_tokens.reshape(new_tokens.shape[1], new_tokens.shape[0]) # Reshape new_tokens to (B, max_new_tokens)
    #gen_seq = new_tokens.squeeze()
    full_seq = jnp.concatenate([context, gen_seq], axis=1)  # (B, T0 + max_new_tokens)

    return full_seq, new_tokens

In [7]:
#inputs = jnp.array(inputs['input_ids'])
import tiktoken
import numpy as np
enc = tiktoken.get_encoding('gpt2')
key, gen_key = jax.random.split(key)
start = "Hello, I'm a language model,"
inputs = jnp.array(enc.encode(start)).reshape(1,-1)
out = model.generate(key, params, inputs, 100)
# out, tokens = generate_batch(params, inputs, gen_key, 2)
out

Array([[15496,    11,   314,  1101,   257,  3303,  2746,    11,   290,
          406, 12582,  8053,   318,   257,  3303,  2746,    13,   770,
          318,   262,  3061,   286,   428,  1492,    13,  3914,   338,
          766,   352,    13, 14365, 43506,   279,  6619,   546,   362,
           13, 14365, 10888, 36883, 27992,   513,    13, 14365, 32448,
         2420, 32096,   604,    13, 13610,   617,  3621,   513,    35,
         5563,   642,    13, 10934,  1223,   422,   262,  2323,   510,
          198, 29800,    12, 16129,  1486,   198, 20570, 31026,   198,
        27871,   320,  5612,   286,  3797,  6570, 23154,   290,  2296,
          315,  1799,   287,  7386, 34175,   198,   464, 16585, 15806,
         1781,    13,   198, 41730,  3498, 32144,   198, 41730,  3498,
          481,  6486,   287,   262,  4569,   314,    14,    46,  1490]],      dtype=int32)

In [8]:
print(enc.decode(out[0]))

Hello, I'm a language model, and LazyList is a language model. This is the goal of this book. Let's see 1. Learn ruby p talked about 2. Learn Ruby scripting tutorials 3. Learn Scheme text parsing 4. Create some nice 3D objects 5. Build something from the ground up
Multi-language design
Getting Started
Optimisation of cataphysics and immutability in domain constructs
The chemistry classroom course.
Learning Lab Steps
Learning Lab will lie in the traditional I/O vis
