<a href="https://colab.research.google.com/github/peremartra/FinLLMOpt/blob/FinChat-XS-Instruct/FinChat-XS/Finchat_Gradio_Interface.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q gradio

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.3/62.3 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m322.1/322.1 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.9/94.9 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.2/11.2 MB[0m [31m42.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.0/72.0 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.3/62.3 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
# Pick the compute backend: CUDA if present, otherwise Apple-Silicon MPS
# (when this torch build exposes it), otherwise plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Already-loaded (tokenizer, model) pairs keyed by model id, so a model is
# only loaded once per session instead of on every request.
model_cache = {}

In [4]:
# Hugging Face Hub id of the FinChat model; used as a key in the per-model
# config dicts and as the first choice in the model dropdown.
FINCHAT_NAME = "oopere/FinChat-XS"

In [5]:
# System prompts optimized for each model.
# A single shared prompt; every supported model currently uses it unchanged.
default_system_prompt = "You are FinChat, a helpful AI assistant with expertise in finance. For general questions, respond naturally and concisely. Only use your financial knowledge when questions specifically relate to markets, investments, or financial concepts."
MODEL_SYSTEM_PROMPTS = dict.fromkeys(
    [
        FINCHAT_NAME,
        "HuggingFaceTB/SmolLM2-360M-Instruct",
        "meta-llama/Llama-3.2-1B-Instruct",
    ],
    default_system_prompt,
)

# Generation parameters optimized for each model, forwarded as keyword
# arguments to model.generate(). Only FinChat adds a repetition penalty.
MODEL_PARAMS = {
    FINCHAT_NAME: dict(max_new_tokens=600, temperature=0.1, repetition_penalty=1.2, top_p=1),
    "HuggingFaceTB/SmolLM2-360M-Instruct": dict(max_new_tokens=600, temperature=0.1, top_p=1),
    "meta-llama/Llama-3.2-1B-Instruct": dict(max_new_tokens=600, temperature=0.1, top_p=0.9),
}

In [6]:
def load_model(model_name):
    """
    Return the (tokenizer, model) pair for `model_name`, loading it on first use.

    Loaded pairs are stored in the module-level `model_cache`, so repeated calls
    with the same name reuse the objects already in memory. The model weights are
    requested as float16 when CUDA is available (torch_dtype="auto" otherwise) and
    then moved to `device`. If the tokenizer has no pad token, its EOS token is
    used as the pad token.

    Raises:
        Exception: anything raised by the Hugging Face loaders or by moving the
        model to the device; the error is printed before being re-raised.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    try:
        print(f"Loading {model_name} on {device} ...")
        tok = AutoTokenizer.from_pretrained(model_name)
        # Some causal-LM tokenizers ship without a pad token; reuse EOS for padding.
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        weights_dtype = torch.float16 if torch.cuda.is_available() else "auto"
        mdl = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=weights_dtype)
        mdl.to(device)
        model_cache[model_name] = (tok, mdl)
    except Exception as err:
        print(f"Error loading model: {str(err)}")
        raise err

    return tok, mdl

def load_model_action(model_choice):
    """
    Handler for the 'Load Model' button.

    Loads (or fetches from the cache) the selected model and returns a pair
    (status message, empty list). The empty list is meant to reset the stored
    conversation. Loading errors are reported in the status message rather
    than raised.
    """
    try:
        load_model(model_choice)
    except Exception as e:
        status = f"Error loading model '{model_choice}': {str(e)}"
    else:
        status = f"Model '{model_choice}' loaded successfully."
    return (status, [])

def clean_assistant_response(text):
    """
    Strip role prefixes (e.g. "assistant", "AI:", "FinChat:") from a model reply.

    Prefixes are checked case-insensitively, in order, after removing leading
    and trailing whitespace, so several stacked prefixes can be removed.
    A bare word prefix such as "assistant" is only removed when it stands alone.
    That means it must be followed by a non-alphanumeric character or the end
    of the text. A genuine reply like "Assistants can ..." is therefore left
    intact.

    Args:
        text: Decoded model output.

    Returns:
        The reply without role prefixes and surrounding whitespace.
    """
    prefixes = ['assistant:', 'assistant', '<assistant>:', '<assistant>', 'AI:', 'FinChat:']

    # Strip first so a prefix preceded by whitespace/newlines is still detected.
    text = text.strip()

    for prefix in prefixes:
        if not text.lower().startswith(prefix.lower()):
            continue
        rest = text[len(prefix):]
        # A prefix ending in a letter must end at a word boundary; otherwise it is
        # just the start of a longer word (e.g. "Assistants") and must be kept.
        if prefix[-1].isalnum() and rest[:1].isalnum():
            continue
        text = rest.strip()

    return text


In [7]:
def chat(model_choice, message, history):
    """
    Generate the assistant's reply to `message` and append the turn to the history.

    Args:
        model_choice: Hugging Face model id to chat with (loaded/cached via load_model).
        message: The user's new message.
        history: Earlier turns as (user_text, assistant_text) pairs; None or empty
            for a new conversation. The caller's list is not modified.

    Returns:
        (history, history): the updated list of (user_text, assistant_text) pairs,
        returned twice so Gradio can update both the Chatbot and the State.
        Any exception during loading or generation is reported as the assistant's
        reply ("Error: ...") instead of being raised.
    """
    # Copy so the caller's (Gradio session) list is never mutated in place.
    history = list(history) if history else []

    # Exit early if there is nothing to send.
    if not message or not message.strip():
        return history, history

    try:
        tokenizer, model = load_model(model_choice)

        # System prompt for the selected model (or the default). Only the plain-text
        # fallback uses it: the chat-template path deliberately sends no system turn.
        # NOTE(review): some chat templates inject their own default system prompt
        # when none is given — confirm that is acceptable for each model.
        system_prompt = MODEL_SYSTEM_PROMPTS.get(model_choice, default_system_prompt)

        # Generation parameters for the selected model (or a conservative default).
        gen_params = MODEL_PARAMS.get(
            model_choice,
            {"max_new_tokens": 150, "temperature": 0.2, "top_p": 0.9}
        )

        # Rebuild the whole conversation in the role/content format chat templates expect.
        conversation = []
        for user_text, bot_text in history:
            conversation.append({"role": "user", "content": user_text})
            conversation.append({"role": "assistant", "content": bot_text})
        conversation.append({"role": "user", "content": message})

        try:
            # add_generation_prompt=True appends the assistant header, so the model
            # starts writing the answer instead of first emitting the role marker.
            input_text = tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True
            )
            # The rendered template already contains its special tokens (e.g. BOS);
            # adding them again while encoding would duplicate them.
            add_special_tokens = False
        except Exception as e:
            print(f"Error applying chat template: {e}")
            # Fallback for models without a usable chat template: a plain-text
            # transcript that still includes the earlier turns.
            transcript = "".join(
                f"User: {user_text}\nAssistant: {bot_text}\n"
                for user_text, bot_text in history
            )
            input_text = f"{system_prompt}\n\n{transcript}User: {message}\nAssistant:"
            add_special_tokens = True

        inputs = tokenizer.encode(
            input_text, return_tensors="pt", add_special_tokens=add_special_tokens
        ).to(device)

        # Keep only the most recent tokens, leaving room for the reply inside the
        # model's context window.
        max_input_len = max(
            1, tokenizer.model_max_length - gen_params.get("max_new_tokens", 0)
        )
        if inputs.shape[1] > max_input_len:
            inputs = inputs[:, -max_input_len:]

        with torch.no_grad():
            output_ids = model.generate(
                inputs,
                # Single unpadded sequence: every position is real input. Passing the
                # mask explicitly avoids ambiguity because pad_token may equal eos_token.
                attention_mask=torch.ones_like(inputs),
                pad_token_id=tokenizer.pad_token_id,
                **gen_params,
                do_sample=True,
            )

        # Decode only the newly generated tokens.
        output_text = tokenizer.decode(
            output_ids[0][inputs.shape[1]:], skip_special_tokens=True
        )

        # Strip any role prefix the model may still emit.
        output_text = clean_assistant_response(output_text)

        history.append((message, output_text))
        return history, history

    except Exception as e:
        # Report the failure inside the chat instead of crashing the UI callback.
        error_message = f"Error: {str(e)}"
        history.append((message, error_message))
        return history, history

def clear_history():
    """
    Reset the conversation.

    Returns two independent empty lists: one for the Chatbot display and one
    for the stored conversation state.
    """
    chatbot_value = list()
    state_value = list()
    return chatbot_value, state_value

# Build the Gradio interface using Blocks
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("# FinChat Personal Finance Assistant")

    with gr.Row():
        with gr.Column(scale=3):
            model_choice = gr.Dropdown(
                choices=[
                    FINCHAT_NAME,
                    "HuggingFaceTB/SmolLM2-360M-Instruct",
                    "meta-llama/Llama-3.2-1B-Instruct"
                ],
                # Pre-select FinChat so a message sent before choosing a model does
                # not call load_model(None). Loading itself stays lazy: it happens on
                # 'Load Model' or on the first message.
                value=FINCHAT_NAME,
                label="Select Model"
            )

        with gr.Column(scale=1):
            load_button = gr.Button("Load Model")

        with gr.Column(scale=3):
            load_status = gr.Textbox(label="Model Status", interactive=False)

    chatbot = gr.Chatbot(label="Chat Conversation", height=200)

    with gr.Row():
        message = gr.Textbox(
            label="Your Message",
            placeholder="Ask me about personal finance, budgeting, investments, etc.",
            lines=2
        )

    with gr.Row():
        clear_btn = gr.Button("Clear Conversation")
        send_button = gr.Button("Send", variant="primary")

    # State to store conversation history
    state = gr.State([])

    def _load_and_reset(choice):
        """Load the chosen model and clear both the stored history and the visible chat."""
        status, empty_history = load_model_action(choice)
        # Also clear the Chatbot so the display matches the reset state.
        return status, empty_history, []

    def _send(model_name, text, history):
        """Run one chat turn, then empty the input box for the next message."""
        chat_display, new_history = chat(model_name, text, history)
        return chat_display, new_history, ""

    # Link buttons to functions
    load_button.click(
        fn=_load_and_reset,
        inputs=model_choice,
        outputs=[load_status, state, chatbot]
    )

    clear_btn.click(
        fn=clear_history,
        inputs=[],
        outputs=[chatbot, state]
    )

    # Clicking Send and pressing Enter in the textbox both send the message.
    for trigger in (send_button.click, message.submit):
        trigger(
            fn=_send,
            inputs=[model_choice, message, state],
            outputs=[chatbot, state, message]
        )




In [8]:
# Launch the interface.
# share=True also creates a temporary public *.gradio.live URL (72-hour expiry per
# the output below); anyone with that link can use the loaded models, so drop it
# for local-only use.
demo.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://b6672befeb76c0c751.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


