# Initial installs

In [None]:
!pip install -i https://pypi.org/simple/ bitsandbytes --upgrade --quiet

In [None]:
!pip install accelerate --upgrade --quiet

In [None]:
!pip install PyMuPDF --quiet

In [None]:
!pip install pyngrok --quiet

# Imports

In [None]:
import bitsandbytes
import accelerate

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, BertTokenizer, BertForSequenceClassification

In [None]:
from flask import Flask, request, jsonify, send_file
from pyngrok import ngrok
import requests
import gc
import os

In [None]:
from huggingface_hub import notebook_login
from google.colab import drive, userdata

In [None]:
import llama_index
import llama_index.readers
import llama_index.readers.file
from llama_index.readers.file import PyMuPDFReader
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, download_loader
from llama_index.core.prompts.prompts import SimpleInputPrompt
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.langchain import LangchainEmbedding
from langchain.embeddings import HuggingFaceEmbeddings

from pathlib import Path

# Hugging Face login and Google Drive mount

In [None]:
# Log in to the Hugging Face Hub for this notebook session.
notebook_login()

In [None]:
# Mount Google Drive at /content/drive; the model, classifier and PDF corpus
# paths used in later cells all live under /content/drive/MyDrive.
drive.mount('/content/drive')

# GPU

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Free GPU Memory

In [None]:
def free_gpu_memory():
    """Release memory held by unreferenced objects and PyTorch's CUDA cache.

    Runs Python's garbage collector first, so tensors that are no longer
    referenced are actually freed. It then asks PyTorch's CUDA caching
    allocator to release its unoccupied cached memory.

    Returns:
        None.
    """
    # Order matters: collect garbage first so that freed tensors are back in the
    # allocator's cache before empty_cache() releases it.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

# Load the fine-tuned model

In [None]:
# Location of the fine-tuned chat model on Drive. Only the path is set here;
# the tokenizer and model themselves are loaded in the following cells.
model_directory = "/content/drive/MyDrive/saved_models/LLama2-7B-chat-PT1-v2"
# Hugging Face token read from Colab's secret store.
# NOTE(review): not referenced by any visible cell (the loads below use
# local_files_only=True) — confirm it is still needed.
auth_token = userdata.get('HF_TOKEN')

In [None]:
# Load the tokenizer saved alongside the fine-tuned model. local_files_only
# restricts loading to the files in model_directory (no Hub download).
tokenizer = AutoTokenizer.from_pretrained(model_directory, local_files_only=True)

In [None]:
# 4-bit NF4 quantization with nested ("double") quantization enabled; the
# quantized layers compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the fine-tuned model from Drive only (no Hub access).
# NOTE(review): torch_dtype is float16 while the 4-bit compute dtype is
# bfloat16, so non-quantized modules and quantized matmuls use different
# half-precision formats — confirm this mix is intended.
model = AutoModelForCausalLM.from_pretrained(
    model_directory,
    local_files_only=True,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    # "dynamic" RoPE scaling with factor 2, intended to stretch the usable
    # context length beyond the base model's.
    rope_scaling={"type": "dynamic", "factor": 2},
)

In [None]:
# Switch the LLM to inference mode (e.g. disables dropout); it is used for
# generation through knowledge_llm further down.
model.eval()
print("Model loaded successfully.")

# Load initial classifier

In [None]:
# Stage-1 query router: a BERT sequence classifier whose predictions are mapped
# to branch names via label_dict.
# NOTE(review): the tokenizer is fetched from the Hub ('dmis-lab/biobert-v1.1')
# while the fine-tuned weights come from Drive — confirm the classifier was
# trained with this exact tokenizer, and note this line needs network access.
classifier_tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-v1.1')
stage1_classifier = BertForSequenceClassification.from_pretrained('/content/drive/MyDrive/classifiers/v1')

In [None]:
# Move the classifier to the selected device (classify() moves its inputs to
# the same device) and put it in inference mode.
stage1_classifier.to(device)
stage1_classifier.eval()
print("Classifier loaded successfully.")

In [None]:
def encode_data(tokenizer, texts, max_len=256):
    """Tokenize *texts* into fixed-length, batched model inputs.

    Each text is encoded separately with special tokens, truncated or padded
    to exactly ``max_len`` tokens, and returned as PyTorch tensors. The
    per-text tensors are then concatenated along dimension 0.

    Args:
        tokenizer: Object exposing ``encode_plus`` (e.g. the BertTokenizer
            loaded above).
        texts: Iterable of strings to encode.
        max_len: Target sequence length for every text.

    Returns:
        ``(input_ids, attention_masks)`` — two tensors with one row per text.

    Raises:
        RuntimeError: from ``torch.cat`` if *texts* is empty.
    """
    encodings = [
        tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_len,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        for text in texts
    ]

    # Stack the per-text rows into single batch tensors.
    input_ids = torch.cat([enc['input_ids'] for enc in encodings], dim=0)
    attention_masks = torch.cat([enc['attention_mask'] for enc in encodings], dim=0)
    return input_ids, attention_masks

In [None]:
def classify(model, tokenizer, text, label_dict):
    """Predict the routing label for a single piece of text.

    The text is encoded with ``encode_data`` and moved to the module-level
    ``device``. It is run through *model* without gradient tracking, and the
    highest-scoring class index is looked up in *label_dict*.

    Args:
        model: Sequence classifier whose output exposes ``.logits``.
        tokenizer: Tokenizer passed through to ``encode_data``.
        text: The query to classify.
        label_dict: Mapping from class index to label.

    Returns:
        ``label_dict[argmax of the logits]``.
    """
    ids, masks = encode_data(tokenizer, [text])

    with torch.no_grad():
        logits = model(
            ids.to(device),
            token_type_ids=None,
            attention_mask=masks.to(device),
        ).logits

    best_index = logits.argmax(dim=1).item()
    return label_dict[best_index]

In [None]:
# Maps the stage-1 classifier's output index to the branch name that
# interact() dispatches on.
# NOTE(review): this order must match the label ids used when training the
# classifier — confirm against the training code.
label_dict = {0: 'quantitative analysis', 1: 'general information', 2: 'miscellaneous'}

# Webpage layout interface

In [None]:
# Folder Flask serves static assets from (passed as static_folder when the app
# is created). The path is already absolute, so abspath() only normalizes it.
STATIC_DIR = os.path.abspath('/content/interface/static')

# Branches

## Knowledge branch

In [None]:
# System prompt for the knowledge (RAG) branch. It opens with the "[INST]
# <<SYS>> ... <</SYS>>" markers, presumably the Llama-2 chat template that the
# fine-tuned model expects — confirm against the training format.
knowledge_system_prompt = """[INST] <<SYS>>
You are an informative assistant called "Assistant". Your goal is to provide accurate and relevant information about Cardiovascular disease and adjacent topics in response to the user's queries.
Please ensure that your responses are informative, helpful, direct, dispassionate, and factual. Respond in plain English, and aim for your response to be at least 3 sentences in length.
If you're uncertain about a question, it's better to admit it rather than provide inaccurate information.
<</SYS>>
"""

# Template applied to each query: {query_str} is filled with the query text,
# and the trailing "[/INST]" closes the instruction block opened by the system
# prompt above.
query_wrapper_prompt = SimpleInputPrompt("{query_str}\nAssistant: [/INST]")

In [None]:
# Wrap the already-loaded quantized model and its tokenizer in llama_index's
# HuggingFaceLLM, so the query engine built in serve_knowledge() can drive it.
# NOTE(review): transformers only honours `temperature` when sampling is
# enabled (do_sample=True here or in the saved generation_config) — confirm.
knowledge_llm = HuggingFaceLLM(
    model=model,
    tokenizer=tokenizer,
    system_prompt=knowledge_system_prompt,
    query_wrapper_prompt=query_wrapper_prompt,
    context_window=4096,
    max_new_tokens=512,
    generate_kwargs={"temperature": 0.6},
)

In [None]:
# Sentence-embedding model used to embed the PDF documents when the vector
# index is built (passed as embed_model in load_and_index_documents). The
# model is downloaded on first use.
# NOTE(review): the `langchain.embeddings` import path is deprecated in newer
# langchain releases — consider langchain_community or llama_index's own
# HuggingFace embedding class.
embeddings = LangchainEmbedding(
    HuggingFaceEmbeddings(model_name = "all-MiniLM-L6-v2")
)

In [None]:
# Function to load and index multiple PDF documents
def load_and_index_documents(directory_path):
    """Build one vector index over every PDF found under *directory_path*.

    The directory is searched recursively for ``*.pdf`` files. On POSIX
    filesystems such as Colab's this glob is case-sensitive, so ``.PDF``
    files are skipped. Each file is parsed with PyMuPDFReader (with metadata),
    and all resulting documents are embedded with the module-level
    ``embeddings`` model into a single VectorStoreIndex.

    Args:
        directory_path: Root folder to search.

    Returns:
        The VectorStoreIndex covering all loaded documents.
    """
    reader = PyMuPDFReader()
    pdf_documents = [
        document
        for pdf_path in Path(directory_path).rglob('*.pdf')
        for document in reader.load(file_path=pdf_path, metadata=True)
    ]
    return VectorStoreIndex.from_documents(pdf_documents, embed_model=embeddings)

In [None]:
# Load and index documents from a specified directory. This runs when the cell
# executes: every PDF under the folder is parsed and embedded once, and
# serve_knowledge() reuses the resulting doc_index for every request.
directory_path = '/content/drive/MyDrive/data/'
doc_index = load_and_index_documents(directory_path)

def serve_knowledge(prompt):
    """Answer *prompt* with retrieval-augmented generation over ``doc_index``.

    A query engine is created on the module-level PDF index with
    ``knowledge_llm`` as its LLM, and *prompt* is sent through it.

    Args:
        prompt: Text to answer; interact() passes the joined recent
            conversation.

    Returns:
        ``{"answer": <reply text with surrounding whitespace stripped>}``.
    """
    engine = doc_index.as_query_engine(llm=knowledge_llm)
    result = engine.query(prompt)
    answer = result.response
    return {"answer": answer.strip()}

## Quantitative branch (not yet implemented; the Flask handler returns a placeholder)

# Flask app and request routing

In [None]:
# Initialize the Flask app (static assets served from STATIC_DIR) and the
# conversation log. context_history is a module-level list of strings that
# interact() appends "User: ..." / "Assistant: ..." entries to.
# NOTE(review): a single global history means all clients share one
# conversation — confirm this is acceptable for the intended deployment.
app = Flask("expert-bot", static_folder = STATIC_DIR)
context_history = []

In [None]:
@app.route("/")
def home():
    """Serve the chat UI page.

    Reads ``/content/interface/index.html`` on each request and returns its
    contents. Flask sends a returned ``str`` as an HTML response.

    Raises:
        FileNotFoundError: if the interface file is not present at that path.
    """
    html_file_path = '/content/interface/index.html'
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which can garble non-ASCII characters outside Colab.
    with open(html_file_path, 'r', encoding='utf-8') as file:
        return file.read()

In [None]:
# Define a system prompt to guide the responses of the chatbot. The next cell
# stores it as the first entry of context_history, so it starts every
# conversation string that interact() passes to serve_knowledge().
# NOTE(review): knowledge_llm also carries knowledge_system_prompt, so the model
# sees two system-style instructions — confirm this is intended.
system_prompt = """You are a helpful and informative assistant called "Assistant". Your goal is to provide accurate and relevant information to the user's queries.
Please ensure that your responses are succinct, respectful, and factual. Refrain from emoting.
If you're uncertain about a question, it's better to admit it rather than provide inaccurate information.
Respond to the User's question ONLY. Do not impersonate the User and do not include followup questions in your response unless prompted."""

In [None]:
# Seed the conversation log with the system prompt. interact()'s history
# trimming treats entry 0 as this prompt.
# NOTE(review): re-running this cell appends the prompt again; re-run the cell
# that creates context_history first.
context_history += [system_prompt]

In [None]:
@app.route("/interact", methods=["POST"])
def interact():
    """Route a user query to the matching answer branch and return JSON.

    Expects a JSON body ``{"query": "<user text>"}``. The query is classified
    by ``stage1_classifier`` into one of the ``label_dict`` labels:

    * ``'general information'``: answered by ``serve_knowledge`` with the
      recent conversation prepended. The exchange is then recorded in
      ``context_history``.
    * ``'quantitative analysis'``: placeholder reply; this branch is not
      implemented yet.
    * anything else: a fixed "cannot help" message.

    Returns:
        A JSON response whose ``"answer"`` key holds the reply text, or a 400
        JSON error when the body has no string ``"query"`` field.
    """
    global context_history

    # silent=True yields None instead of raising on a missing/invalid JSON body,
    # so malformed requests get a clean 400 rather than a 500 from a KeyError.
    data = request.get_json(silent=True)
    user_input = data.get('query') if isinstance(data, dict) else None
    if not isinstance(user_input, str):
        return jsonify({"error": "Request body must be JSON with a string 'query' field."}), 400

    branch = classify(stage1_classifier, classifier_tokenizer, user_input, label_dict)

    if branch == 'general information':
        user_turn = f"User: {user_input}"

        # Knowledge-based questions keep a context for the conversation.
        conversation = '\n'.join(context_history + [user_turn])
        response_text = serve_knowledge(conversation)

        # Record the exchange only after the answer succeeded, so a failed call
        # does not leave a dangling user turn in the history. serve_knowledge
        # returns its reply under the "answer" key.
        context_history.append(user_turn)
        context_history.append(f"Assistant: {response_text['answer']}")

        # Maintain a recent context window to avoid stale conversation artifacts:
        # keep the system prompt (entry 0) plus the last 8 entries
        # (4 User/Assistant pairs), i.e. at most 9 entries.
        if len(context_history) > 9:
            # Entry 0 must be wrapped in a list; str + list raises TypeError.
            context_history = [context_history[0]] + context_history[-8:]

        return jsonify(response_text)

    elif branch == 'quantitative analysis':
        return jsonify({"answer": "quantitative_placeholder"})

    else:
        return jsonify({"answer": "Sorry, I'm not able to help you with that. Please either rephrase the question or ask a different question."})

In [None]:
if __name__ == '__main__':
    port = 7000

    # Expose the local Flask server through an ngrok tunnel so it is reachable
    # from outside the notebook VM; the printed value is the tunnel object.
    public_url = ngrok.connect(port)
    print(f"Flask app is running at {public_url}")

    # Serve on all interfaces; app.run() blocks until the server is stopped.
    app.run(host='0.0.0.0', port=port)