# Initial installs

In [None]:
!pip install -i https://pypi.org/simple/ bitsandbytes --upgrade --quiet

In [None]:
!pip install accelerate --upgrade --quiet

In [None]:
!pip install PyMuPDF --quiet

In [None]:
!pip install pyngrok --quiet

# Initial imports

In [None]:
import bitsandbytes
import accelerate

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, BertTokenizer, BertForSequenceClassification

In [None]:
from flask import Flask, request, jsonify, send_file
from pyngrok import ngrok
import requests
import gc
import os

In [None]:
from huggingface_hub import notebook_login
from google.colab import drive, userdata

In [None]:
from llama_index.core.prompts.prompts import SimpleInputPrompt
from llama_index.llms.huggingface import HuggingFaceLLM
from langchain.embeddings import HuggingFaceEmbeddings
from llama_index.embeddings.langchain import LangchainEmbedding

# HuggingFace/Drive interfacing

In [None]:
# Authenticate this notebook session with the Hugging Face Hub
notebook_login()

In [None]:
# Mount Google Drive at /content/drive; the saved model and classifier paths used later live under it
drive.mount('/content/drive')

# GPU

In [None]:
# Select the compute device: CUDA when a GPU is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Free GPU Memory

In [None]:
def free_gpu_memory():
    """Run the Python garbage collector, then empty PyTorch's CUDA cache."""
    # Steps run in the listed order: garbage collection first, cache release second
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

# Load in the model

In [None]:
# Path of the fine-tuned model on the mounted Drive (used for both tokenizer and model loading)
model_directory = "/content/drive/MyDrive/saved_models/LLama2-7B-chat-PT1-v2"
# Hugging Face token read from the Colab user data store
# NOTE(review): auth_token is not referenced by the local-only loads in this section — confirm it is still needed
auth_token = userdata.get('HF_TOKEN')

In [None]:
# Load the tokenizer from the saved model directory; local_files_only keeps the load off the network
tokenizer = AutoTokenizer.from_pretrained(
    model_directory,
    local_files_only = True
)

In [None]:
# 4-bit NF4 quantization with double quantization enabled.
# NOTE(review): the 4-bit compute dtype is bfloat16 while the model is loaded with
# torch_dtype=float16 below — confirm this mismatch is intentional.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the fine-tuned model
model = AutoModelForCausalLM.from_pretrained(
    model_directory,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    # The RoPE scaling factor is documented as a float; some transformers releases
    # reject or warn about an integer here, so pass 2.0 rather than 2.
    rope_scaling={"type": "dynamic", "factor": 2.0},
    local_files_only = True,
    quantization_config=quantization_config
)

In [None]:
# Switch the model to evaluation mode before it is used for generation
model.eval()
print("Model loaded successfully.")

# Load initial classifier

In [None]:
# Stage-1 query classifier: the tokenizer is loaded by the name 'dmis-lab/biobert-v1.1',
# the classification weights from the saved classifier directory on Drive
classifier_tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-v1.1')
stage1_classifier = BertForSequenceClassification.from_pretrained('/content/drive/MyDrive/classifiers/v1')

In [None]:
# Move the classifier to the selected device and switch it to evaluation mode
stage1_classifier.to(device)
stage1_classifier.eval()
print("Classifier loaded successfully.")

In [None]:
def encode_data(tokenizer, texts, max_len=256):
    """Tokenize texts into batched input-id and attention-mask tensors.

    Args:
        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method
            that returns ``'input_ids'`` and ``'attention_mask'`` tensors.
        texts: Iterable of strings to encode, one row per string.
        max_len: Length each sequence is padded or truncated to.

    Returns:
        Tuple ``(input_ids, attention_masks)`` where the per-text tensors are
        concatenated along dimension 0. When ``texts`` is empty, both are
        empty integer tensors of shape ``(0, max_len)``.
    """
    encodings = [
        tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        for text in texts
    ]

    # torch.cat raises on an empty list, so return empty batches explicitly
    if not encodings:
        return (
            torch.empty((0, max_len), dtype=torch.long),
            torch.empty((0, max_len), dtype=torch.long),
        )

    input_ids = torch.cat([enc['input_ids'] for enc in encodings], dim=0)
    attention_masks = torch.cat([enc['attention_mask'] for enc in encodings], dim=0)

    return input_ids, attention_masks

In [None]:
def classify(model, tokenizer, text, label_dict):
    """Return the label the classifier predicts for a single text.

    Args:
        model: Sequence-classification model whose output exposes ``logits``.
        tokenizer: Tokenizer handed to ``encode_data``.
        text: The string to classify.
        label_dict: Mapping from predicted class index to label.

    Returns:
        ``label_dict`` entry for the highest-scoring class index.
    """
    # Encode as a batch of one and place the tensors on the module-level device
    ids, masks = encode_data(tokenizer, [text])
    ids, masks = ids.to(device), masks.to(device)

    # Forward pass without gradient tracking
    with torch.no_grad():
        logits = model(ids, token_type_ids=None, attention_mask=masks).logits

    best_index = torch.argmax(logits, dim=1).item()
    return label_dict[best_index]

In [None]:
# Maps the stage-1 classifier's predicted class index to the branch label the /interact handler dispatches on
label_dict = {0: 'quantitative analysis', 1: 'general information', 2: 'miscellaneous'}

# Webpage layout interface

In [None]:
# Absolute path of the directory passed to Flask as its static folder
STATIC_DIR = os.path.abspath('/content/interface/static')

# Main stream

In [None]:
# Initialize the Flask app and the context history
app = Flask("expert-bot", static_folder = STATIC_DIR)
# Module-level conversation transcript: a list of strings that the /interact handler
# joins into the generation prompt. It is a single list shared by all requests.
context_history = []

In [None]:
@app.route("/")
def home():
    """Serve the chat interface page.

    Returns:
        The contents of ``/content/interface/index.html`` as a string, used
        as the response body.

    Raises:
        OSError: If the HTML file cannot be opened.
    """
    html_file_path = '/content/interface/index.html'
    # Decode explicitly as UTF-8 so the page does not depend on the runtime's locale encoding
    with open(html_file_path, 'r', encoding='utf-8') as file:
        return file.read()

In [None]:
# Define a system prompt to guide the responses of the chatbot.
# It is added to context_history, so it becomes part of the text the model is prompted with.
system_prompt = """You are a helpful and informative assistant called "Assistant". Your goal is to provide accurate and relevant information to the user's queries.
Please ensure that your responses are succinct, respectful, and factual. Refrain from emoting.
If you're uncertain about a question, it's better to admit it rather than provide inaccurate information.
Respond to the User's question ONLY. Do not impersonate the User and do not include followup questions in your response unless prompted."""

In [None]:
# Seed the conversation with the system prompt (re-running this cell adds it again)
context_history.append(system_prompt)

In [None]:
@app.route("/interact", methods=["POST"])
def interact():
    """Answer one user query posted as JSON ``{"query": "<text>"}``.

    The query is routed by the stage-1 classifier:

    * ``'general information'`` - the language model generates a reply from
      the shared conversation history plus the new query; both the query and
      the reply are then stored in the history.
    * ``'quantitative analysis'`` - a placeholder answer is returned.
    * any other label - a fixed apology message is returned.

    Returns:
        A JSON response of the form ``{"answer": "<text>"}``. The status is
        HTTP 400 when the body is not a JSON object with a non-blank string
        under ``query``.
    """
    global context_history

    # Number of user/assistant entries kept after the pinned first entry
    max_turn_entries = 10
    # Upper bound on the tokens generated for a single reply
    max_new_tokens = 350

    # silent=True returns None for a missing or malformed JSON body instead of raising
    data = request.get_json(silent=True)
    user_input = data.get('query') if isinstance(data, dict) else None
    if not isinstance(user_input, str) or not user_input.strip():
        return jsonify({"answer": "Please enter a question."}), 400

    branch = classify(stage1_classifier, classifier_tokenizer, user_input, label_dict)

    if branch == 'general information':
        # Build the prompt from the stored history plus the new user turn. The history
        # itself is only updated once a reply exists, so a failed generation does not
        # leave a dangling user entry behind.
        user_entry = f"User: {user_input}"
        conversation = "\n".join(context_history + [user_entry])

        prompt = f"{conversation}\n Assistant: "

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[1]

        outputs = model.generate(
            **inputs,
            # max_length would count the prompt tokens as well, so a growing history
            # would leave no room for the reply; bound only the new tokens instead.
            max_new_tokens=max_new_tokens,
            # temperature/top_p are sampling parameters; enable sampling explicitly
            do_sample=True,
            temperature=0.5,
            top_p=0.75
        )

        # Decode only the newly generated tokens rather than re-parsing the echoed prompt
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        # Drop anything after a generated "User:" marker so invented turns are not returned or stored
        assistant_response = generated_text.split('User:')[0].strip()

        context_history.append(user_entry)
        context_history.append(f"Assistant: {assistant_response}")

        # Keep the first entry (the system prompt seeded at startup) and the most
        # recent turn entries, so trimming never discards the system prompt.
        if len(context_history) > max_turn_entries + 1:
            context_history = context_history[:1] + context_history[-max_turn_entries:]

        # Only the Assistant's response is returned to the client, not the entire context
        return jsonify({"answer": assistant_response})

    elif branch == 'quantitative analysis':
        return jsonify({"answer": "quantitative_placeholder"})

    else:
        return jsonify({"answer": "Sorry, I'm not able to help you with that. Please either rephrase the question or ask a different question."})

In [None]:
if __name__ == '__main__':
    # Open an ngrok tunnel to local port 7000 (the same port the Flask app listens on below)
    public_url = ngrok.connect(7000)

    # Print the tunnel address before starting the server
    print(f"Flask app is running at {public_url}")

    # Run the Flask app
    app.run(host='0.0.0.0', port=7000)