In [1]:
import time

import jax
import jax.numpy as jnp
import numpy as np
from flax import jax_utils
from flax.training.common_utils import shard

from transformers import FlaxGemmaForCausalLM, GemmaTokenizerFast

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint to serve, plus the token budget for each request.
model_id = "google/gemma-7b-it"
# Prompts are padded (and, in the chat callback, trimmed) to max_input_tokens;
# max_new_tokens is how many tokens are generated per request.
max_input_tokens, max_new_tokens = 1024, 256

In [3]:
# Load the pre-trained Flax checkpoint. With _do_init=False the weights come back
# separately from the module, hence the (model, params) unpacking.
model, params = FlaxGemmaForCausalLM.from_pretrained(
    model_id,
    dtype=jnp.bfloat16,
    revision="flax",
    _do_init=False,
)
tokenizer = GemmaTokenizerFast.from_pretrained(model_id)

tcmalloc: large alloc 4894842880 bytes == 0x11df08000 @  0x7fb8842d3680 0x7fb8842f4824 0x5d93d1 0x634ae1 0x5a23da 0x4c8bce 0x63afe8 0x4db8d3 0x547447 0x4e1a5e 0x54c8a9 0x54552a 0x4e1bd0 0x5483b6 0x54552a 0x684327 0x5e1514 0x5a27d0 0x547265 0x4d71f8 0x548c6b 0x4d71f8 0x548c6b 0x4d71f8 0x4daf8a 0x547447 0x5d5846 0x547265 0x5d5846 0x547447 0x54552a
tcmalloc: large alloc 1572864000 bytes == 0x369b44000 @  0x7fb8842d3680 0x7fb8842f4824 0x7fb8728c1994 0x7fb8728c212f 0x7fb8729208f5 0x7fb8729c4329 0x7fb8729c4a77 0x7fb8729c4bcc 0x6af68d 0x7fb872909854 0x5d553a 0x5d6066 0x54ca58 0x54552a 0x5d5a23 0x54c8a9 0x5d5846 0x547265 0x5d5846 0x547265 0x5d5846 0x547265 0x5d5846 0x547265 0x5d5846 0x547265 0x5d5846 0x547265 0x4e1a5e 0x54c8a9 0x54552a
tcmalloc: large alloc 4982947840 bytes == 0x3c7744000 @  0x7fb8842d3680 0x7fb8842f4824 0x5d93d1 0x634ae1 0x5a23da 0x4c8bce 0x63afe8 0x4db8d3 0x547447 0x4e1a5e 0x54c8a9 0x54552a 0x4e1bd0 0x5483b6 0x54552a 0x684327 0x5e1514 0x5a27d0 0x547265 0x4d71f8 0x548c6b 0x4d

In [4]:
# Warm-up batch: the same prompt four times, padded to the fixed input length so
# the compiled program can be reused by the chat callback.
input_text = ["The capital of France is"] * 4
input_ids = tokenizer(
    input_text,
    return_tensors="np",
    padding="max_length",
    max_length=max_input_tokens,
).input_ids

# Split the batch across local devices and give every device its own copy of the weights.
# NOTE(review): this cell is not idempotent -- re-running it replicates the already
# replicated `params` again (presumably adding another device axis); re-run from the load cell.
input_ids = shard(input_ids)
params = jax_utils.replicate(params)

In [5]:
def generate(input_ids, params, max_new_tokens):
    """Sample continuations for one device's slice of the prompt batch.

    Meant to be wrapped in ``jax.pmap``. Inside the mapped call ``input_ids`` is the
    per-device slice of the batch and ``params`` is that device's replica of the weights.

    Args:
        input_ids: Padded prompt token ids (assumed ``(per_device_batch, seq_len)``
            inside pmap -- TODO confirm against the caller).
        params: Model parameters for this device.
        max_new_tokens: Number of tokens to sample after the prompt.

    Returns:
        The ``sequences`` field of ``model.generate``'s output (prompt plus sampled tokens).
    """
    generate_kwargs = {}
    pad_token_id = model.config.pad_token_id
    if pad_token_id is not None:
        # Prompts are padded to a fixed length. Without a mask, the model would attend to
        # the pad tokens and count them when building position ids.
        generate_kwargs["attention_mask"] = (input_ids != pad_token_id).astype(jnp.int32)
    generated_ids = model.generate(
        input_ids,
        params=params,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        **generate_kwargs,
    )
    return generated_ids.sequences

# Data-parallel generate: input_ids and params are mapped over axis 0 (one slice per
# device). max_new_tokens is a static Python int broadcast to every device, and changing
# it triggers a recompile. No collectives are used, so no axis name is needed.
p_generate = jax.pmap(generate, in_axes=(0, 0, None), out_axes=0, static_broadcasted_argnums=(2,))
# The warm-up call compiles the program. JAX dispatches execution asynchronously, so wait
# for it to finish; otherwise the tail of this run lands in the next timed call.
_ = p_generate(input_ids, params, max_new_tokens).block_until_ready()



In [56]:
def compute_tok_per_s(input_ids, generated_ids, runtime):
    """Return throughput as newly produced token positions per second.

    Every element of ``generated_ids`` beyond the element count of ``input_ids`` counts as
    a generated token, summed over the whole batch. Positions after an early EOS are
    therefore counted too.

    Args:
        input_ids: Prompt array (any shape; only its element count is used).
        generated_ids: Output array containing prompt plus generated tokens.
        runtime: Wall-clock seconds spent generating.

    Returns:
        ``(elements(generated_ids) - elements(input_ids)) / runtime``.
    """
    produced, consumed = (np.prod(arr.shape) for arr in (generated_ids, input_ids))
    return (produced - consumed) / runtime

def chat_function(message, chat_history):
    """Gradio ChatInterface callback: generate one assistant reply.

    Builds the conversation and tokenizes it with the chat template. The prompt keeps its
    most recent ``max_input_tokens`` tokens and is padded to exactly ``max_input_tokens``,
    so its shape matches the warmed-up ``p_generate`` program. The callback then samples
    a reply and reports throughput via ``gr.Info``.

    Args:
        message: The new user message.
        chat_history: Iterable of ``(user, assistant)`` message pairs from earlier turns.

    Returns:
        The decoded continuation of the first sampled sequence, with special tokens removed.
    """
    # The conversation contains only user/assistant turns; no system prompt is sent.
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # Tokenize once without padding so an over-long prompt can be trimmed before padding.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="np")
    if input_ids.shape[1] > max_input_tokens:
        # Keep the most recent tokens.
        # NOTE(review): left-trimming may drop the template's leading special tokens.
        input_ids = input_ids[:, -max_input_tokens:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {max_input_tokens} tokens.")

    # Pad to exactly max_input_tokens on the tokenizer's configured side, as the
    # padding="max_length" tokenization did, so the compiled program is reused.
    pad_len = max_input_tokens - input_ids.shape[1]
    if tokenizer.padding_side == "left":
        pad_width = ((0, 0), (pad_len, 0))
    else:
        pad_width = ((0, 0), (0, pad_len))
    input_ids = np.pad(input_ids, pad_width, constant_values=tokenizer.pad_token_id)

    # 4 copies, matching the 4-prompt warm-up batch, then split across devices.
    input_ids = np.vstack(4 * [input_ids])
    input_ids = shard(input_ids)

    start = time.perf_counter()
    pred_ids = p_generate(input_ids, params, max_new_tokens)
    # JAX dispatches asynchronously; wait for the result so the timing covers generation.
    pred_ids.block_until_ready()
    runtime = time.perf_counter() - start

    # Collapse the (device, per-device batch) axes and bring the tokens to the host.
    pred_ids = jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))
    prompt_len = input_ids.shape[-1]
    pred_text = tokenizer.decode(np.array(pred_ids[0])[prompt_len:], skip_special_tokens=True)

    tok_per_s = compute_tok_per_s(input_ids, pred_ids, runtime)
    gr.Info(f"Tok/s: {round(tok_per_s, 2)}")
    return pred_text

In [57]:
import gradio as gr

# Wrap the callback in a chat UI, enable request queuing, and serve it
# (share=True also creates a public link).
chat_interface = gr.ChatInterface(chat_function)
chat_interface.queue()
chat_interface.launch(share=True)

Running on local URL:  http://127.0.0.1:7869


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Running on public URL: https://cde34eac7b5cc9a554.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


