# In [None]:
import accelerate
import gradio as gr
import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from unsloth import FastLanguageModel


# Set the number of threads for Torch
torch.set_num_threads(1)

# Get Hugging Face token from environment variable.
# Reading HF_TOKEN from the environment keeps the secret out of the source;
# the placeholder is only a fallback for when the variable is not set.
HF_TOKEN = os.environ.get('HF_TOKEN', '< your hf token >')

# Set device to CUDA if available, otherwise to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load tokenizer and model from Hugging Face's model hub
# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favour of `token` - confirm against the installed version before switching.
MODEL_NAME = "< your model name >"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN).to(device)

# Hand the loaded model to Unsloth's inference preparation hook.
FastLanguageModel.for_inference(model)


# Function to count tokens
def count_tokens(text):
    """Return how many tokens the module-level tokenizer splits ``text`` into."""
    pieces = tokenizer.tokenize(text)
    return len(pieces)


# In [None]:
# Function to generate model predictions
def predict(message, history):
    """Stream the model's reply to a chat message wrapped in an Alpaca prompt.

    The message may carry an instruction and an optional input separated by
    the first ``###``. Without a separator the whole message is used as the
    instruction and the input is left blank.

    Args:
        message: Text typed by the user.
        history: Earlier chat turns supplied by the caller. It is not used
            here, so every request is answered without conversation context.

    Yields:
        The response text accumulated so far; each yielded value extends the
        previous one with the newest piece produced by the streamer.
    """
    # NOTE(review): the continuation lines of this template carry the
    # function's 4-space indentation into the prompt - confirm this matches
    # the format the model was tuned on before changing it.
    alpaca_format = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    ### Instruction:
    {}

    ### Input:
    {}

    ### Response:
    {}"""

    # Token budget shared by the prompt and the generated reply.
    context_window = 2048

    # If no input is present, append an empty one so the prompt still formats.
    if '###' not in message:
        message = message + ' ### '

    # Split on the FIRST separator only, so any later '###' stays part of the
    # input instead of being silently dropped.
    instruction, _, context = message.partition('###')
    prompt = alpaca_format.format(instruction, context, "")
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Budget the reply from the length of the real encoded prompt (not the
    # unfilled template), and never ask for fewer than one new token.
    prompt_tokens = model_inputs["input_ids"].shape[-1]
    max_new_tokens = max(1, context_window - prompt_tokens)

    # Initialize TextIteratorStreamer
    streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)

    # Generate model kwargs
    # NOTE(review): per the transformers generation docs `length_penalty` only
    # applies to beam-based generation, so with num_beams=1 it is presumably
    # inert; top_p/top_k/temperature likewise only take effect when sampling
    # is enabled (do_sample) - confirm against the model's generation config.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        top_p=0.2,
        top_k=20,
        temperature=0.1,
        repetition_penalty=2.0,
        length_penalty=-0.5,
        num_beams=1
    )

    # Start generation in a separate thread so this generator can consume the
    # streamer while the model is still producing tokens.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield partial message
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message


# Build the Gradio chat UI around `predict`, then serve it.
chat_ui = gr.ChatInterface(
    fn=predict,
    title="Gemma 2b Instruct Chat",
    description=None,
)
# share=True is passed through to launch() exactly as before.
chat_ui.launch(share=True)
