In [1]:
import time

import jax
import jax.numpy as jnp
import numpy as np
from flax import jax_utils
from flax.training.common_utils import shard

from transformers import FlaxMistralForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local checkpoint directory, plus the fixed prompt length (inputs are padded /
# trimmed to this many tokens) and the number of new tokens to generate.
model_id = "./Mistral-7B-Instruct-v0.2"
max_input_tokens, max_new_tokens = 256, 256

In [3]:
# Load the pre-trained Mistral checkpoint from the local directory. With
# _do_init=False, from_pretrained hands back the module and its weights as a
# pair, so the params can be replicated and passed explicitly later on.
model, params = FlaxMistralForCausalLM.from_pretrained(
    model_id,
    _do_init=False,
    dtype=jnp.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

tcmalloc: large alloc 9942974464 bytes == 0x12aeaa000 @  0x7faf9fcdd680 0x7faf9fcfe824 0x5d93d1 0x634ae1 0x5a23da 0x4c8bce 0x63afe8 0x4db8d3 0x547447 0x4e1a5e 0x54c8a9 0x54552a 0x4e1bd0 0x5483b6 0x54552a 0x684327 0x5e1514 0x5a27d0 0x547265 0x4d71f8 0x548c6b 0x4d71f8 0x548c6b 0x4d71f8 0x4daf8a 0x547447 0x5d5846 0x547265 0x5d5846 0x547447 0x54552a
tcmalloc: large alloc 4540514304 bytes == 0x12aeaa000 @  0x7faf9fcdd680 0x7faf9fcfe824 0x5d93d1 0x634ae1 0x5a23da 0x4c8bce 0x63afe8 0x4db8d3 0x547447 0x4e1a5e 0x54c8a9 0x54552a 0x4e1bd0 0x5483b6 0x54552a 0x684327 0x5e1514 0x5a27d0 0x547265 0x4d71f8 0x548c6b 0x4d71f8 0x548c6b 0x4d71f8 0x4daf8a 0x547447 0x5d5846 0x547265 0x5d5846 0x547447 0x54552a
Some of the weights of FlaxMistralForCausalLM were initialized in bfloat16 precision from the model checkpoint at ./Mistral-7B-Instruct-v0.2:
[('lm_head', 'kernel'), ('model', 'embed_tokens', 'embedding'), ('model', 'layers', '0', 'input_layernorm', 'weight'), ('model', 'layers', '0', 'mlp', 'down_proj'

In [4]:
# Reuse the EOS token for padding so batched tokenization with
# padding="max_length" works (presumably the checkpoint's tokenizer defines no
# pad token of its own — TODO confirm). Side effect: pad and EOS now share one
# id, so any token counting that drops pad ids also drops genuine EOS tokens.
tokenizer.pad_token = tokenizer.eos_token

In [5]:
# Smoke-test batch: the same plain-text prompt four times, padded to a fixed
# length of max_input_tokens.
input_text = ["The capital of France is"] * 4
inputs = tokenizer(
    input_text,
    return_tensors="np",
    return_attention_mask=True,
    padding="max_length",
    max_length=max_input_tokens,
)

# Put a copy of the weights on every local device for pmap.
# NOTE: not idempotent — re-running this cell would replicate the already
# replicated tree again; re-run the model-loading cell first instead.
params = jax_utils.replicate(params)
# Split the batch's leading axis across the local devices.
inputs = shard(inputs.data)

In [6]:
# Same smoke-test prompt, this time formatted with the model's chat template.
messages = [{"role": "user", "content": "The capital of France is"}]
input_text = tokenizer.apply_chat_template(messages, tokenize=False)
# Per the transformers chat-template docs, the rendered template already
# contains the model's special tokens (for Mistral, the leading BOS). Tokenize
# it with add_special_tokens=False, otherwise a second BOS gets prepended.
inputs = tokenizer(
    4 * [input_text],
    return_tensors="np",
    return_attention_mask=True,
    padding="max_length",
    max_length=max_input_tokens,
    add_special_tokens=False,
)
inputs = shard(inputs.data)

In [7]:
def generate(inputs, params, max_new_tokens):
    """Sample a continuation for one device's shard of the batch.

    Intended to be wrapped with ``jax.pmap``: ``inputs`` is a dict holding this
    device's ``"input_ids"`` and ``"attention_mask"``, and ``params`` is this
    device's copy of the model weights.

    Returns:
        The ``sequences`` field of ``model.generate``'s output.

    NOTE(review): no ``prng_key`` is passed, so sampling presumably falls back
    to the library's default key. If so, identical inputs yield identical
    samples on every call and on every device — confirm against the
    transformers Flax ``generate`` docs.
    """
    output = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        params=params,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    return output.sequences

# Run `generate` on every local device: inputs and params are split along
# their leading (device) axis, while max_new_tokens is a static Python int
# broadcast to all devices.
p_generate = jax.pmap(
    generate,
    axis_name="inputs",
    in_axes=(0, 0, None),
    out_axes=0,
    static_broadcasted_argnums=(2,),
)
# The first call traces and compiles; later calls with the same input shapes
# reuse the compiled program.
gen_ids = p_generate(inputs, params, max_new_tokens)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [8]:
def compute_tok_per_s(input_ids, generated_ids, runtime, pad_token_id=None):
    """Return generation throughput in new (non-pad) tokens per second.

    Pad tokens are ignored in both arrays. The prompt tokens are then
    subtracted from the generated ones, because the generated sequences are
    expected to contain the prompt followed by the new tokens.

    Args:
        input_ids: prompt token ids, any shape (e.g. sharded
            ``(devices, batch, seq)``).
        generated_ids: generated sequences for the same prompts, any shape.
        runtime: wall-clock generation time in seconds.
        pad_token_id: id to ignore when counting. Defaults to
            ``tokenizer.pad_token_id``.

    Returns:
        Tokens generated per second as a Python float, or ``nan`` when
        ``runtime`` is not positive (the rate is undefined then).
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    if runtime <= 0:
        # Display-only metric: report "undefined" rather than inf/negative or crash.
        return float("nan")

    # In this notebook pad_token is set to eos_token, so genuine EOS tokens are
    # excluded from both counts as well.
    total_inputs = np.count_nonzero(np.asarray(input_ids) != pad_token_id)
    total_outputs = np.count_nonzero(np.asarray(generated_ids) != pad_token_id)

    tokens_generated = total_outputs - total_inputs
    return float(tokens_generated / runtime)

def chat_function(message, chat_history):
    """Gradio ``ChatInterface`` callback: generate a reply to ``message``.

    The processing steps are:

    1. Tokenize the history plus the new message with the chat template.
    2. Trim it to the most recent ``max_input_tokens`` tokens, showing a Gradio
       warning when that happens.
    3. Left-pad it to exactly that length.
    4. Replicate the prompt 4 times and run it through the pmapped
       ``p_generate``.

    Throughput is reported with ``gr.Info``.

    Args:
        message: the new user message.
        chat_history: earlier turns as ``(user, assistant)`` pairs.

    Returns:
        The decoded new tokens of the first generated sequence.
    """
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # Tokenize with the chat template directly. Rendering it to text and then
    # re-tokenizing with tokenizer(...) would prepend a second BOS token.
    prompt_ids = tokenizer.apply_chat_template(conversation, return_tensors="np")
    # In case the conversation exceeds the maximum length, keep only the most
    # recent tokens. A longer input would also change the shape seen by
    # p_generate and force it to re-trace and recompile.
    if prompt_ids.shape[1] > max_input_tokens:
        prompt_ids = prompt_ids[:, -max_input_tokens:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {max_input_tokens} tokens.")

    # Left-pad the (possibly trimmed) prompt to exactly max_input_tokens so the
    # new tokens follow the prompt directly.
    offset = max_input_tokens - prompt_ids.shape[1]
    padded_ids = np.full((1, max_input_tokens), tokenizer.pad_token_id, dtype=prompt_ids.dtype)
    padded_ids[:, offset:] = prompt_ids
    attention_mask = np.zeros_like(padded_ids)
    attention_mask[:, offset:] = 1

    batch_size = 4  # same prompt replicated; shard() splits this axis across local devices
    inputs = shard({
        "input_ids": np.repeat(padded_ids, batch_size, axis=0),
        "attention_mask": np.repeat(attention_mask, batch_size, axis=0),
    })
    input_ids = inputs["input_ids"]

    start = time.perf_counter()
    pred_ids = p_generate(inputs, params, max_new_tokens)
    # JAX dispatches asynchronously: wait for the result so that runtime covers
    # the generation itself, not just its dispatch.
    pred_ids.block_until_ready()
    runtime = time.perf_counter() - start

    pred_ids = jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))
    # The sequences start with the padded prompt; decode only the new tokens.
    pred_text = tokenizer.decode(np.array(pred_ids[0])[input_ids.shape[-1]:], skip_special_tokens=True)

    tok_per_s = compute_tok_per_s(input_ids, pred_ids, runtime)
    gr.Info(f"Tok/s: {round(tok_per_s, 2)}")
    return pred_text

In [9]:
import gradio as gr

# Serve chat_function as a Gradio chat UI. queue() enables request queuing.
# share=True also opens a temporary public *.gradio.live URL, and anyone who
# has that link can send prompts to the model.
chat_interface = gr.ChatInterface(fn=chat_function)
chat_interface.queue().launch(share=True)

Running on local URL:  http://127.0.0.1:7860


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Running on public URL: https://29f71c69eb6de75b87.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




In [10]:
def chat_function():
    """Run an interactive chat with the model from ``input()``.

    Each line read is appended to the conversation as a user turn. The whole
    conversation is then:

    1. tokenized with the chat template,
    2. trimmed to the most recent ``max_input_tokens`` tokens,
    3. left-padded to exactly that length,
    4. run through ``p_generate``.

    The reply is printed and kept as an assistant turn. Blank lines are
    ignored; ``exit``/``quit`` or EOF ends the loop.

    Note: this shadows the Gradio ``chat_function`` defined earlier. The
    Gradio interface keeps the function object it was created with.
    """
    batch_size = 4  # same prompt replicated; shard() splits this axis across local devices
    conversation = []
    while True:
        try:
            message = input()
        except EOFError:
            break
        if message.strip().lower() in ("exit", "quit"):
            break
        if not message.strip():
            continue
        conversation.append({"role": "user", "content": message})

        # Tokenize with the chat template directly: re-tokenizing the rendered
        # template text with tokenizer(...) would prepend a second BOS token.
        prompt_ids = tokenizer.apply_chat_template(conversation, return_tensors="np")
        # Keep only the most recent tokens. A longer input would change the
        # input shape and force p_generate to re-trace and recompile.
        if prompt_ids.shape[1] > max_input_tokens:
            prompt_ids = prompt_ids[:, -max_input_tokens:]
            print(f"Trimmed input from conversation as it was longer than {max_input_tokens} tokens.")

        # Left-pad to exactly max_input_tokens so new tokens follow the prompt directly.
        offset = max_input_tokens - prompt_ids.shape[1]
        padded_ids = np.full((1, max_input_tokens), tokenizer.pad_token_id, dtype=prompt_ids.dtype)
        padded_ids[:, offset:] = prompt_ids
        attention_mask = np.zeros_like(padded_ids)
        attention_mask[:, offset:] = 1
        inputs = shard({
            "input_ids": np.repeat(padded_ids, batch_size, axis=0),
            "attention_mask": np.repeat(attention_mask, batch_size, axis=0),
        })
        input_ids = inputs["input_ids"]

        pred_ids = p_generate(inputs, params, max_new_tokens)
        pred_ids = jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1]))
        # The sequences start with the padded prompt; decode only the new tokens.
        pred_text = tokenizer.decode(np.array(pred_ids[0])[input_ids.shape[-1]:], skip_special_tokens=True)

        print("Response:", pred_text)
        conversation.append({"role": "assistant", "content": pred_text})

In [11]:
# Start the interactive terminal chat; this cell blocks until the loop returns
# or is interrupted.
chat_function()

 Hello!


Response: Hello! How can I help you today? If you have any questions or need assistance with a particular topic, feel free to ask! I'll do my best to provide you with accurate and helpful information. Additionally, if you'd just like to chat or share some thoughts, I'd be happy to listen! So, how can I help you today?


 Sounds great


Response: I'm glad you think so! I'm here to help answer any questions you might have, or to provide assistance with any topics you'd like to explore. Additionally, if you'd just like to chat or share some thoughts, I'd be happy to listen! So, how can I help you today? Let me know if you have any specific questions or topics in mind, and I'll do my best to provide you with accurate and helpful information! If you don't have any specific questions or topics in mind, that's perfectly fine too! Feel free to share any thoughts or ideas you might have, or just chat with me about whatever topic you'd like! I'm here to help and engage with you in a positive and productive way! So, how can I help you today? Let me know if you have any specific questions or topics in mind, and I'll do my best to provide you with accurate and helpful information! If you don't have any specific questions or topics in mind, that's perfectly fine too! Feel free to share any thoughts or ideas you might have, or just

 What is the hottest country in the world?


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.

KeyboardInterrupt

