<a href="https://colab.research.google.com/github/sheldonkemper/bank_of_england/blob/main/notebooks/modelling/sk_gen_ai_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain

class BankEarningsChatbot:
    """
    Retrieval Augmented Generation (RAG) chatbot over bank earnings-call PDFs.

    On construction the object sets up a Hugging Face text-generation pipeline,
    sentence-transformer embeddings, loads every PDF found in ``pdf_folder``,
    builds a Chroma vector index persisted under ``persist_directory``, and
    wires a ``ConversationalRetrievalChain`` with a windowed conversation memory.

    Args:
        pdf_folder: Directory that contains the PDF transcripts.
        persist_directory: Directory in which the Chroma index is persisted.
        max_new_tokens: Generation length limit passed to the pipeline.
    """

    def __init__(self, pdf_folder, persist_directory="test_index", max_new_tokens=300):
        self.pdf_folder = pdf_folder
        self.persist_directory = persist_directory

        # Order matters: the index needs embeddings and documents, the
        # retriever needs the persisted index, and the chain needs all of them.
        self._setup_llm(max_new_tokens)
        self._setup_embeddings()
        self._load_documents()
        self._build_vector_index()
        self._setup_retriever()

        # Conversation memory exposed to the chain under "chat_history".
        self.memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history", return_messages=True)

        # Conversational retrieval chain combining LLM, retriever and memory.
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.retriever,
            memory=self.memory,
            verbose=False
        )

    def _setup_llm(self, max_new_tokens):
        """
        Configures the language model pipeline.

        Loads the SmolLM instruct model and tokenizer and wraps a sampling
        text-generation pipeline in a LangChain ``HuggingFacePipeline``.

        Args:
            max_new_tokens: Maximum number of tokens generated per call.
        """
        self.model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # float16 is kept when a CUDA device is available; on CPU-only hosts
        # half-precision inference can fail in some PyTorch builds, so float32
        # is used there instead.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=dtype,
            device_map="auto"
        )
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=0.5,
            top_p=0.8,
            repetition_penalty=1.3,
            do_sample=True
        )
        self.llm = HuggingFacePipeline(pipeline=self.pipe)

    def _setup_embeddings(self):
        """
        Initializes the sentence-transformer based embeddings.
        """
        self.embedding_model = "sentence-transformers/all-mpnet-base-v2"
        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

    def _load_documents(self):
        """
        Loads and splits PDF documents from the specified folder.

        Files are matched on a case-insensitive ``.pdf`` suffix and processed
        in sorted order so repeated runs see the documents in the same order.
        A file that fails to load is reported and skipped; the remaining files
        are still processed.
        """
        self.documents = []
        file_names = sorted(
            f for f in os.listdir(self.pdf_folder) if f.lower().endswith(".pdf")
        )
        for pdf_file in file_names:
            pdf_path = os.path.join(self.pdf_folder, pdf_file)
            try:
                loader = PyPDFLoader(pdf_path, extract_images=False)
                self.documents.extend(loader.load_and_split())
                print(f"Loaded: {pdf_file}")
            except Exception as e:
                # Best effort: one unreadable PDF must not abort the whole load.
                print(f"Error loading {pdf_file}: {e}")

    def _build_vector_index(self):
        """
        Chunks documents and builds a Chroma vector index.

        The loaded documents are split into 1000-character chunks with a
        100-character overlap, de-duplicated by content, embedded, and
        persisted to ``self.persist_directory``.
        """
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = text_splitter.split_documents(self.documents)
        # Remove duplicate chunks.
        chunks = self.remove_duplicate_chunks(chunks)
        self.chunks = chunks
        self.db = Chroma.from_documents(chunks, embedding=self.embeddings, persist_directory=self.persist_directory)
        self.db.persist()

    def _setup_retriever(self):
        """
        Initializes the retriever from the persisted vector database.

        The retriever is configured with ``k=3`` results per query.
        """
        self.vectordb = Chroma(persist_directory=self.persist_directory, embedding_function=self.embeddings)
        self.retriever = self.vectordb.as_retriever(search_kwargs={"k": 3})

    @staticmethod
    def remove_duplicate_chunks(chunks):
        """
        Eliminates duplicate document chunks based on their content.

        Two chunks are duplicates when their ``page_content`` is equal after
        stripping surrounding whitespace. The first occurrence is kept and the
        input order is preserved.

        Args:
            chunks: Iterable of objects exposing ``page_content``.

        Returns:
            List of the unique chunks.
        """
        seen = set()
        unique_chunks = []
        for chunk in chunks:
            chunk_text = chunk.page_content.strip()
            if chunk_text not in seen:
                seen.add(chunk_text)
                unique_chunks.append(chunk)
        return unique_chunks

    def truncate_context(self, context_list, max_tokens=1000):
        """
        Truncates the retrieved context to avoid overloading the model's input.

        Documents are accepted in order while their cumulative token count
        (measured with ``self.tokenizer``) stays within ``max_tokens``;
        processing stops at the first document that would exceed the budget.

        Args:
            context_list: Documents in retrieval order.
            max_tokens: Total token budget for the kept documents.

        Returns:
            The leading documents that fit within the budget.
        """
        truncated_docs = []
        current_tokens = 0
        for doc in context_list:
            doc_tokens = len(self.tokenizer.encode(doc.page_content))
            if current_tokens + doc_tokens > max_tokens:
                break
            truncated_docs.append(doc)
            current_tokens += doc_tokens
        return truncated_docs

    @staticmethod
    def clean_user_input(user_input):
        """
        Cleans and standardizes user input.

        Strips surrounding whitespace and replaces newlines and tabs with
        single spaces.
        """
        return user_input.strip().replace("\n", " ").replace("\t", " ")

    def reset_memory_if_needed(self):
        """
        Clears conversation history once more than six messages are stored.
        """
        if len(self.memory.chat_memory.messages) > 6:
            print("\nMemory Full: Resetting Conversation History...\n")
            self.memory.clear()

    def format_response(self, question, response):
        """
        Formats the output to clearly present both the question and the answer.

        If the model output echoes any known prompt-template phrase, only the
        text after the last occurrence of that phrase is kept.

        Args:
            question: The (cleaned) user question.
            response: Raw answer text produced by the chain.

        Returns:
            A string of the form ``"Question: ...\\nHelpful Answer: ..."``.
        """
        response_text = response.strip()
        unwanted_phrases = [
            "Use the following pieces of context",
            "If you don't know the answer, just say that you don't know",
            "Don't try to make up an answer."
        ]
        for phrase in unwanted_phrases:
            if phrase in response_text:
                response_text = response_text.split(phrase)[-1].strip()
        return f"Question: {question}\nHelpful Answer: {response_text}"

    @staticmethod
    def _render_context(context):
        """
        Renders retrieved context as plain text, one labelled section per document.

        Strings (and other non-iterable values) are passed through ``str`` so
        callers that already supply text keep working. Items without a
        ``page_content`` attribute are rendered with ``str`` as well.
        """
        if isinstance(context, str) or not hasattr(context, "__iter__"):
            return str(context)

        sections = []
        for doc in context:
            text = getattr(doc, "page_content", None)
            if text is None:
                sections.append(str(doc))
                continue
            metadata = getattr(doc, "metadata", None) or {}
            source = metadata.get("source", "Unknown Source")
            page = metadata.get("page", "Unknown Page")
            sections.append(f"[Source: {source}, Page: {page}]\n{text}")
        return "\n\n".join(sections)

    def trim_final_input(self, question, context, max_tokens=1024):
        """
        Builds the final prompt and keeps it within the token limit,
        preserving document metadata.

        The context documents are rendered as text with their source and page
        (rather than the ``repr`` of a list of objects). When the prompt is too
        long, only the context section is shortened so that the instructions
        and the trailing question are kept intact.

        Args:
            question: The user question placed at the end of the prompt.
            context: Retrieved documents (objects with ``page_content`` and
                ``metadata``) or an already formatted string.
            max_tokens: Token budget for the whole prompt, measured with
                ``self.tokenizer``.

        Returns:
            The prompt text.
        """
        system_message = (
            "You are analyzing a bank's quarterly earnings call transcript.\n"
            "Extract and summarize key financial insights, avoiding unnecessary details.\n"
            "If the answer isn't found, respond with 'I don't know.'\n"
            "Provide sources for your answers at the end."
        )
        prefix = f"{system_message}\n\nContext:\n"
        suffix = f"\n\nQuestion: {question}"
        context_text = self._render_context(context)
        input_text = f"{prefix}{context_text}{suffix}"

        # Fast path: nothing to trim.
        if len(self.tokenizer.encode(input_text)) <= max_tokens:
            return input_text

        # Reserve room for the instructions and the question, then give the
        # remainder of the budget to the context.
        overhead = len(self.tokenizer.encode(prefix + suffix))
        budget = max_tokens - overhead
        if budget <= 0:
            # The fixed parts alone exceed the limit; fall back to plain
            # right-truncation of the whole prompt.
            tokens = self.tokenizer.encode(input_text, truncation=True, max_length=max_tokens)
            return self.tokenizer.decode(tokens)

        context_tokens = self.tokenizer.encode(
            context_text,
            add_special_tokens=False,
            truncation=True,
            max_length=budget,
        )
        trimmed_context = self.tokenizer.decode(context_tokens, skip_special_tokens=True)
        return f"{prefix}{trimmed_context}{suffix}"

    def answer_question(self, question):
        """
        Processes the user query: retrieves context, prepares the prompt,
        and returns a formatted answer. If no relevant documents are retrieved,
        a fallback message is returned.

        Args:
            question: Raw user question.

        Returns:
            A ``"Question: ...\\nHelpful Answer: ..."`` string.
        """
        question = self.clean_user_input(question)
        self.reset_memory_if_needed()

        # Retrieve and process context.
        context = self.retriever.get_relevant_documents(question)
        context = self.remove_duplicate_chunks(context)
        context = self.truncate_context(context, max_tokens=800)

        # Fallback: if no relevant context is found.
        if not context:
            return f"Question: {question}\nHelpful Answer: I don't have information regarding that query."

        print("\nRetrieved Context:")
        for doc in context:
            source = doc.metadata.get('source', 'Unknown Source')
            page = doc.metadata.get('page', 'Unknown Page')
            print(f"- Source: {source}, Page: {page}")

        formatted_input = self.trim_final_input(question, context, max_tokens=1024)
        # NOTE(review): the chain was also constructed with the retriever, so
        # it presumably performs its own retrieval using this full prompt as
        # the query - confirm whether passing only `question` is intended.
        response = self.qa_chain({"question": formatted_input})
        return self.format_response(question, response['answer'])

    def run_chatbot(self):
        """
        Initiates an interactive loop for prompt-based queries.

        The loop ends when the user types ``exit`` (case-insensitive,
        surrounding whitespace ignored). Blank input is skipped.
        """
        print("\n💬 Bank Earnings Chatbot (Type 'exit' to stop)")
        while True:
            user_input = input("\nYou: ")
            command = user_input.strip()
            if command.lower() == "exit":
                print("\nExiting Chatbot. Have a great day!")
                break
            if not command:
                # Nothing to ask; prompt again instead of querying the retriever.
                continue
            answer = self.answer_question(user_input)
            print("\n" + answer)


In [None]:
# Demonstration of the BankEarningsChatbot class.

# Folder holding the earnings-call transcripts as PDF files; adjust to your environment.
pdf_folder = "/content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan"

# Build the chatbot (loads the model, the PDFs and the vector index).
chatbot = BankEarningsChatbot(pdf_folder)

# Option 1: chat interactively until 'exit' is typed.
chatbot.run_chatbot()

# Option 2: run a fixed batch of prompts.
prompts = [
    "What were the key insights from the latest earnings call?",
    "How did revenue change compared to the previous quarter?",
    "What risk factors were identified in the transcript?",
    "What is the overall sentiment of the earnings call?",  # May hit the fallback answer if nothing relevant is retrieved.
]

for prompt in prompts:
    response = chatbot.answer_question(prompt)
    # One print call emits the same three lines as before.
    print(f"Question: {prompt}", f"Response: {response}", "-" * 60, sep="\n")
