<a href="https://colab.research.google.com/github/sheldonkemper/bank_of_england/blob/main/notebooks/modelling/sk_gen_ai_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


In [1]:
"""
===================================================
Author: Sheldon Kemper
Role: Data Engineering Lead, Bank of England Employer Project (Quant Collective)
LinkedIn: https://www.linkedin.com/in/sheldon-kemper
Date: 2025-02-04
Version: 1.1

Description:
    This notebook contains a class-based implementation of a Retrieval Augmented Generation (RAG) engine
    designed to analyze bank quarterly earnings call transcripts (in PDF format) stored on Google Drive.
    The code performs the following tasks:

    1. Configures an LLM pipeline using a Flan-T5-based model for text summarization.
    2. Sets up sentence-transformer based embeddings for document vectorization.
    3. Loads and splits PDF documents from a specified directory.
    4. Chunks the documents and builds a vector index using Chroma.
    5. Retrieves context relevant to user queries from the vector index.
    6. Implements a fallback mechanism for queries unrelated to the provided data.
    7. Maintains conversation memory for interactive sessions.
    8. Supports both interactive and programmatic prompt-based querying.
===================================================
"""



In [2]:
# install langchain-community
!pip install -q langchain-community pypdf tiktoken chromadb sentence-transformers > /dev/null 2>&1

In [3]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from google.colab import drive

In [4]:

# Mount Google Drive to the root location with force_remount
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


# A class-based implementation of an LLM Retrieval Augmented Generation (RAG) engine

In [5]:
class BankEarningsChatbot:
    """
    A class-based implementation of an LLM Retrieval Augmented Generation (RAG) engine
    designed to analyze bank quarterly earnings call transcripts. It loads PDF documents
    from one or more specified folders, builds a Chroma vector index, and sets up an interactive
    conversational chain for prompt-based queries.

    Parameters:
        pdf_folders (list or str): A list of folder paths containing PDF files, or a single folder path as a string.
        persist_directory (str): Directory path to persist the vector index and model outputs.
        max_length (int): Maximum output length for the T5 model.
        test_mode (bool): If True, only loads one PDF (from the first folder) for quick testing.
    """
    def __init__(self, pdf_folders, persist_directory="/content/drive/MyDrive/BOE/bank_of_england/data/model_outputs", max_length=256, test_mode=False):
        # Allow a single folder (string) or a list of folders.
        if isinstance(pdf_folders, str):
            self.pdf_folders = [pdf_folders]
        else:
            self.pdf_folders = pdf_folders

        self.persist_directory = persist_directory
        self.test_mode = test_mode

        # Set up the LLM pipeline using the Flan-T5 model.
        self._setup_llm(max_length)

        # Configure embeddings.
        self._setup_embeddings()

        # Load documents from all specified PDF folders.
        self._load_documents()

        # Build the vector index.
        self._build_vector_index()

        # Configure retriever from the persisted vector database.
        self._setup_retriever()

        # Initialize conversation memory.
        self.memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history", return_messages=True)

        # Set up the conversational retrieval chain.
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.retriever,
            memory=self.memory,
            verbose=False
        )

    def _setup_llm(self, max_length):
        """
        Configures the language model pipeline using a T5 model.
        """
        self.model_name = "google/flan-t5-large"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.model_name,
            device_map="auto",
            torch_dtype=torch.float16
        )
        # Using the text2text-generation pipeline for T5
        self.pipe = pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=max_length,
            temperature=0.5,
            top_p=0.8,
            do_sample=True
        )
        self.llm = HuggingFacePipeline(pipeline=self.pipe)

    def _setup_embeddings(self):
        """
        Initializes the sentence-transformer based embeddings.
        """
        self.embedding_model = "sentence-transformers/all-mpnet-base-v2"
        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

    def _load_documents(self):
        """
        Loads and splits PDF documents from the specified folders.
        In test_mode, only the first PDF (from the first folder) is loaded.
        """
        self.documents = []
        for folder in self.pdf_folders:
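            # Sort so test_mode picks the same file on every run (os.listdir order is arbitrary),
            # and match the extension case-insensitively so ".PDF" files are not skipped.
            file_names = sorted(f for f in os.listdir(folder) if f.lower().endswith(".pdf"))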
            if not file_names:
                continue
            # If in test_mode, load only the first PDF file from this folder.
            if self.test_mode:
                file_names = file_names[:1]
            for pdf_file in file_names:
                pdf_path = os.path.join(folder, pdf_file)
                try:
                    loader = PyPDFLoader(pdf_path, extract_images=False)
                    self.documents.extend(loader.load_and_split())
                    print(f"Loaded: {pdf_file} from {folder}")
                except Exception as e:
                    print(f"Error loading {pdf_file} from {folder}: {e}")
            # In test mode, break after processing the first folder.
            if self.test_mode:
                break

    def _build_vector_index(self):
        """
        Chunks documents and builds a Chroma vector index.
        """
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = text_splitter.split_documents(self.documents)
        # Remove duplicate chunks.
        chunks = self.remove_duplicate_chunks(chunks)
        self.chunks = chunks
        self.db = Chroma.from_documents(chunks, embedding=self.embeddings, persist_directory=self.persist_directory)
        self.db.persist()

    def _setup_retriever(self):
        """
        Initializes the retriever from the persisted vector database.
        """
        self.vectordb = Chroma(persist_directory=self.persist_directory, embedding_function=self.embeddings)
        self.retriever = self.vectordb.as_retriever(search_kwargs={"k": 3})

    @staticmethod
    def remove_duplicate_chunks(chunks):
        """
        Eliminates duplicate document chunks based on their content.
        """
        seen = set()
        unique_chunks = []
        for chunk in chunks:
            chunk_text = chunk.page_content.strip()
            if chunk_text not in seen:
                seen.add(chunk_text)
                unique_chunks.append(chunk)
        return unique_chunks

    def truncate_context(self, context_list, max_tokens=800):
        """
        Truncates the retrieved context to avoid overloading the model's input.
        """
        truncated_docs = []
        current_tokens = 0
        for doc in context_list:
            doc_tokens = len(self.tokenizer.encode(doc.page_content))
            if current_tokens + doc_tokens <= max_tokens:
                truncated_docs.append(doc)
                current_tokens += doc_tokens
            else:
                break
        return truncated_docs

    @staticmethod
    def clean_user_input(user_input):
        """
        Cleans and standardizes user input.
        """
        return user_input.strip().replace("\n", " ").replace("\t", " ")

    def reset_memory_if_needed(self):
        """
        Clears the conversation history once more than 6 messages
        (three question/answer pairs) have accumulated.
        """
        if len(self.memory.chat_memory.messages) > 6:
            print("\nMemory Full: Resetting Conversation History...\n")
            self.memory.clear()

    def format_response(self, question, response):
        """
        Formats the output to clearly present both the question and the answer.
        """
        response_text = response.strip()
        unwanted_phrases = [
            "Use the following pieces of context",
            "If you don't know the answer, just say that you don't know",
            "Don't try to make up an answer."
        ]
        for phrase in unwanted_phrases:
            if phrase in response_text:
                response_text = response_text.split(phrase)[-1].strip()
        return f"Question: {question}\nHelpful Answer: {response_text}"

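    def trim_final_input(self, question, context, max_tokens=512):
        """
        Builds the final prompt within max_tokens. Only the context text is
        truncated, so the instructions and the question are never cut off.
        Document metadata is not included in the prompt.
        """
        system_message = (
            "You are analyzing a bank's quarterly earnings call transcript.\n"
            "Extract and summarize key financial insights, avoiding unnecessary details.\n"
            "If the answer isn't found, respond with 'I don't know.'\n"
            "Provide sources for your answers at the end."
        )
        # Join the context documents into one coherent string.
        context_str = "\n".join(doc.page_content for doc in context)
        # Count the tokens used by everything except the context.
        scaffold = f"{system_message}\n\nContext:\n\n\nQuestion: {question}"
        budget = max(max_tokens - len(self.tokenizer.encode(scaffold)), 0)
        # Truncate the context alone, without special tokens such as "</s>".
        context_ids = self.tokenizer.encode(context_str, add_special_tokens=False)[:budget]
        context_str = self.tokenizer.decode(context_ids, skip_special_tokens=True)
        return f"{system_message}\n\nContext:\n{context_str}\n\nQuestion: {question}"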

    def answer_question(self, question):
        """
        Processes the user query: retrieves context, prepares the prompt,
        and returns a formatted answer. If no relevant documents are retrieved,
        a fallback message is returned.
        """
        question = self.clean_user_input(question)
        self.reset_memory_if_needed()

        # Retrieve and process context.
        context = self.retriever.get_relevant_documents(question)
        context = self.remove_duplicate_chunks(context)
        context = self.truncate_context(context, max_tokens=800)

        # Fallback: if no relevant context is found.
        if not context:
            return f"Question: {question}\nHelpful Answer: I don't have information regarding that query."

        print("\nRetrieved Context:")
        for doc in context:
            source = doc.metadata.get('source', 'Unknown Source')
            page = doc.metadata.get('page', 'Unknown Page')
            print(f"- Source: {source}, Page: {page}")

        # Enforce a 512-token limit for the final prompt.
        formatted_input = self.trim_final_input(question, context, max_tokens=512)
        response = self.qa_chain({"question": formatted_input})
        return self.format_response(question, response['answer'])

    def run_chatbot(self):
        """
        Initiates an interactive loop for prompt-based queries.
        """
        print("\n💬 Bank Earnings Chatbot (Type 'exit' to stop)")
        while True:
            user_input = input("\nYou: ")
            if user_input.strip().lower() == "exit":
                print("\nExiting Chatbot. Have a great day!")
                break
            answer = self.answer_question(user_input)
            print("\n" + answer)
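
The greedy budgeting in `truncate_context` can be tried outside Colab with a stand-in tokenizer. The sketch below is an illustration only: `select_within_budget` and `count_words` are hypothetical names that do not appear in the notebook, and whitespace word counts stand in for the Flan-T5 tokenizer, so the numbers will differ from real token counts.

```python
from typing import Callable, List


def count_words(text: str) -> int:
    """Stand-in token counter (assumption): one token per whitespace-separated word."""
    return len(text.split())


def select_within_budget(
    texts: List[str],
    max_tokens: int,
    count_tokens: Callable[[str], int] = count_words,
) -> List[str]:
    """Keep texts in order until adding the next one would exceed max_tokens."""
    kept: List[str] = []
    used = 0
    for text in texts:
        cost = count_tokens(text)
        if used + cost > max_tokens:
            break  # same stop-at-first-overflow rule as truncate_context
        kept.append(text)
        used += cost
    return kept


# Hypothetical chunk texts costing 4, 6, and 3 "tokens" respectively.
chunks = [
    "net interest income rose",
    "expenses were flat year on year",
    "capital ratio improved",
]
print(select_within_budget(chunks, max_tokens=10))
```

Because the loop stops at the first overflow, a long chunk early in the ranking can exclude shorter chunks ranked after it. Skipping the oversized chunk and continuing would be an alternative, at the cost of changing the retrieval order.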




The following prompt-writing recommendations tend to yield more accurate, domain-specific responses from this retrieval-augmented setup:

- **Be Specific About the Timeframe:**  
  Instead of asking “What were the key insights?” specify the quarter or transcript you’re interested in. For example:  
  - "What were the key financial insights from the Q4 2023 earnings call?"  
  - "Summarize the main drivers of revenue in the Q1 2023 transcript."

- **Target Specific Financial Metrics or Themes:**  
  Focus on particular areas the transcripts cover, such as revenue trends, expense drivers, or capital performance. For example:  
  - "How did revenue change compared to the previous quarter in the Q4 2023 earnings call?"  
  - "What were the primary expense drivers discussed in the Q4 2023 transcript?"

- **Incorporate Domain-Specific Language:**  
  Use terminology that reflects the financial domain to guide the model. For example:  
  - "What risk factors and forward-looking statements were highlighted in the Q3 2023 transcript?"  
  - "Outline the key operational challenges and strategic responses mentioned in the earnings call."

- **Prompt for Summaries and Insights:**  
  Asking for summaries can help the model focus on extracting concise information from large volumes of text. For example:  
  - "Provide a concise summary of the key financial insights from the Q4 2023 earnings transcript, including revenue, expenses, and capital allocation."  
  - "What are the overall sentiments and key management strategies discussed in the transcript?"

Naming the quarter, the financial metric, and the relevant industry terms in a query guides both retrieval and summarization, which should lead to more precise and contextually relevant responses.
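
One way to apply these recommendations consistently is to assemble queries from their parts. The helper below is a minimal sketch, not part of the notebook: `build_prompt` and its parameters are hypothetical, and the phrasing it produces is only one possible template.

```python
from typing import Optional


def build_prompt(task: str, quarter: Optional[str] = None, metric: Optional[str] = None) -> str:
    """Compose a query naming the task and, optionally, a metric and a quarter."""
    task = task.strip()
    # Preserve the original sentence type: questions stay questions.
    ending = "?" if task.endswith("?") else "."
    parts = [task.rstrip("?. ")]
    if metric:
        parts.append(f"focusing on {metric}")
    if quarter:
        parts.append(f"in the {quarter} earnings call")
    return " ".join(parts) + ending


print(build_prompt("What were the primary expense drivers?", quarter="Q4 2023"))
print(build_prompt("Summarize the key financial insights", quarter="Q1 2023", metric="revenue"))
```

The resulting strings can be passed to `chatbot.answer_question` like any hand-written prompt.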

Interactive Chatbot Session:
Calling `chatbot.run_chatbot()` launches an interactive loop that waits for user input, processes each question as it is entered, and prints the response. Typing `exit` ends the session. This mode suits a live, conversational session in which the operator drives the dialogue manually.

Programmatic Prompt Processing:
Instead of an interactive loop, you can supply a list of predefined prompts, as in the debugging cell below. The code iterates over the list, calls `chatbot.answer_question(prompt)` for each query, and prints both the prompt and the corresponding answer. This approach is useful for batch testing, automated evaluations, or processing a fixed set of queries without manual intervention.
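
For automated evaluation it can help to collect the answers rather than only print them. The following is a rough sketch under stated assumptions: `run_batch`, `Answerer`, and `EchoBot` are hypothetical names introduced here, and `EchoBot` is a stub standing in for `BankEarningsChatbot` so the example runs without loading any model.

```python
from typing import Dict, Iterable, List, Protocol


class Answerer(Protocol):
    """Anything exposing answer_question(question) -> str, e.g. the chatbot object."""

    def answer_question(self, question: str) -> str: ...


def run_batch(bot: Answerer, prompts: Iterable[str]) -> List[Dict[str, str]]:
    """Ask each prompt in turn; record failures instead of aborting the batch."""
    results: List[Dict[str, str]] = []
    for prompt in prompts:
        try:
            answer = bot.answer_question(prompt)
            results.append({"prompt": prompt, "answer": answer, "error": ""})
        except Exception as exc:  # keep going so one bad query does not end the run
            results.append({"prompt": prompt, "answer": "", "error": str(exc)})
    return results


class EchoBot:
    """Stub (assumption) that mimics the chatbot's answer format."""

    def answer_question(self, question: str) -> str:
        if not question:
            raise ValueError("empty question")
        return f"Question: {question}\nHelpful Answer: (stub)"


for row in run_batch(EchoBot(), ["How did revenue change?", ""]):
    print(row)
```

With the real object, `run_batch(chatbot, prompts)` would return one record per prompt, which can then be written to a file or compared against expected answers.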

# Instantiate the chatbot object

In [None]:
# Define your PDF folder paths (ensure these paths contain your earnings transcripts in PDF format).
pdf_folders = [
    "/content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan",
    "/content/drive/MyDrive/BOE/bank_of_england/data/raw/ubs"
]

# Define the persistence directory for model outputs and the vector index.
persist_directory = "/content/drive/MyDrive/BOE/bank_of_england/data/model_outputs"

# Instantiate the chatbot object with test_mode=True to load only a single PDF.
chatbot = BankEarningsChatbot(pdf_folders, persist_directory=persist_directory, test_mode=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



# For debugging, process a list of prompts programmatically

In [None]:

# Process a list of prompts programmatically.
prompts = [
    "What were the key insights from the latest earnings call?",
    "How did revenue change compared to the previous quarter?",
    "What risk factors were identified in the transcript?",
    "What is the overall sentiment of the earnings call?"  # The fallback applies only if retrieval returns no documents.
]

for prompt in prompts:
    response = chatbot.answer_question(prompt)
    print("Question:", prompt)
    print("Response:", response)
    print("-" * 60)

# Launch an interactive chatbot session

In [None]:
chatbot.run_chatbot()