In [1]:
import tkinter as tk
from tkinter import scrolledtext
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
# Specify the path to your model folder
model_path = "Model-4-bit"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)

# device_map = {
#     "transformer.h.0": "cpu",  # Load the first layer on the CPU
#     "transformer.h.1": "cuda:0",  # Load the second layer on the GPU
#     "lm_head": "cpu",  # Load the output head on the CPU
#     "default": "cuda:0"  # Load everything else on the GPU
# }

# Load the model in 4-bit precision
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,  # You can use torch.bfloat16 if supported and needed
    load_in_4bit=True,  # This is important for loading the 4-bit model
    device_map="cuda:0",  # Automatically selects the available device (GPU/CPU)
)


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Unused kwargs: ['quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
# Alpaca prompt template
alpaca_prompt = """Respond to the given text the way Rick would in the show Rick and Morty using as much context from the input to harbor a response

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN

# Function to format the prompts
def formatting_prompts_func(examples):
    inputs = examples["Input"]
    outputs = examples["Output"]
    texts = []
    for input_text, output_text in zip(inputs, outputs):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = alpaca_prompt.format(input_text, output_text) + EOS_TOKEN
        texts.append(text)
    return {"text": texts,}

In [3]:
conversation_history = []

def format_conversation(history):
    # Concatenate the conversation history with prompts and responses
    formatted_history = ""
    for prompt, response in history:
        formatted_history += alpaca_prompt.format(prompt, response)
    return formatted_history

def chat(model, tokenizer, max_history=5):
    def send_message():
        user_input = input_box.get()
        if user_input.lower() == "exit":
            root.quit()

        # Append the user input to the conversation history
        conversation_history.append((user_input, ""))
        
        # Keep only the last `max_history` turns
        conversation_history[:] = conversation_history[-max_history:]

        # Format the conversation history for the model
        formatted_history = format_conversation(conversation_history)
        
        # Tokenize the formatted history
        inputs = tokenizer([formatted_history], return_tensors="pt").to("cuda:0" if torch.cuda.is_available() else "cpu")
        
        # Generate a response
        outputs = model.generate(**inputs, max_new_tokens=128)
        
        # Decode the generated response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract only the bot's response after the last prompt
        bot_response = response.split(alpaca_prompt.format(user_input, ""))[-1].strip()
        
        # Update the last user input with the model response in the conversation history
        conversation_history[-1] = (user_input, bot_response)
        
        # Display the conversation
        chat_history.config(state=tk.NORMAL)
        chat_history.insert(tk.END, f"You: {user_input}\nBot: {bot_response}\n")
        chat_history.config(state=tk.DISABLED)
        chat_history.yview(tk.END)
        
        # Clear the input box
        input_box.delete(0, tk.END)

    # Setup tkinter GUI
    root = tk.Tk()
    root.title("Chat with LLM")

    chat_history = scrolledtext.ScrolledText(root, wrap=tk.WORD, state=tk.DISABLED)
    chat_history.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)

    input_box = tk.Entry(root, width=100)
    input_box.pack(padx=10, pady=10, fill=tk.X, expand=True)
    input_box.bind("<Return>", lambda event: send_message())

    send_button = tk.Button(root, text="Send", command=send_message)
    send_button.pack(padx=10, pady=10)

    root.mainloop()