In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer

model = AutoModelForCausalLM.from_pretrained("../models/final_model")
tokenizer = AutoTokenizer.from_pretrained("baiges/CatGPT")

# Ensure model is on GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Function to generate a response
# Function to generate a response
def generate_response(input_text, context=None):
    # Prepare the input text with special tokens
    input_text = f"<s> {input_text} </s> <s>"

    if context:
        input_text = context + " " + input_text

    # Tokenize input text
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Generate the output
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_length=512,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            do_sample=True,
            top_k=2,
            temperature=0.7,
            repetition_penalty=1.2,
        )
    print('GENERATED')

    # Separate the input tokens from the generated tokens
    generated_ids = output_ids[0][input_ids.size(-1):]

    # Decode the generated tokens to text
    decoded_output = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Stop generation if the model predicts the end of text token
    if "</s>" in decoded_output:
        decoded_output = decoded_output.split("</s>")[0].strip()

    return decoded_output

# Function to manage the context length
def update_context(context, user_input, response, max_length=750):
    # Add the new interaction to the context
    new_interaction = f"<s> {user_input} </s> <s> {response} </s> <s>"
    context += new_interaction

    # Tokenize the context to manage its length
    context_ids = tokenizer.encode(context)

    # If the context is too long, truncate it
    if len(context_ids) > max_length:
        # Keep the last `max_length` tokens
        context_ids = context_ids[-max_length:]

    # Decode back to string
    context = tokenizer.decode(context_ids, skip_special_tokens=True)

    return context

# Main chatbot loop
def chatbot():
    context = ""
    print("Bot is ready. Type 'exit' to stop.")

    while True:
        user_input = input("User: ")
        if user_input.lower() == "exit":
            break

        # Generate a response from the model
        response = generate_response(user_input, context)

        # Display the response
        print(f"Bot: {response}")

        # Update context and manage its length
        context = update_context(context, user_input, response)

if __name__ == "__main__":
    chatbot()

Bot is ready. Type 'exit' to stop.
GENERATED
Bot:  Bé, gràcies. 
GENERATED
Bot:  No, no ho sóc. 
GENERATED
Bot:  No, no ho sóc. 
GENERATED
Bot:  No ho sé, no puc suportar-ho, però no puc suportar-ho. 
GENERATED
Bot:  No, no puc suportar-ho, però no puc suportar-ho. 


KeyboardInterrupt: Interrupted by user