<a href="https://colab.research.google.com/github/sheldonkemper/bank_of_england/blob/main/notebooks/modelling/sk_gen_ai_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
===================================================
Author: Sheldon Kemper
Role: Data Engineering Lead, Bank of England Employer Project (Quant Collective)
LinkedIn: https://www.linkedin.com/in/sheldon-kemper
Date: 2025-02-04
Version: 1.2

Description:
    This notebook contains a class-based implementation of a Retrieval Augmented Generation (RAG) engine
    designed to analyze bank quarterly earnings call transcripts (in PDF format) stored on Google Drive.
    The code performs the following tasks:

    1. Configures an LLM pipeline using a Flan-T5 model (google/flan-t5-large) for text summarization; the two-stage variant uses LED (allenai/led-base-16384) instead.
    2. Sets up sentence-transformer based embeddings for document vectorization.
    3. Loads and splits PDF documents from one or more specified directories.
    4. Chunks the documents and builds a vector index using Chroma, persisting the index to Google Drive.
    5. Optionally loads an existing persisted vector index to avoid re-indexing, via the 'rebuild_index' parameter.
    6. Retrieves context relevant to user queries from the vector index with token truncation to enforce input limits.
    7. Maintains conversation memory for interactive sessions.
    8. Supports both interactive and programmatic prompt-based querying.
    9. Includes a 'test_mode' option for quick testing with a single PDF.
===================================================
"""
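
The retrieval steps listed above (chunk, embed, retrieve the most similar chunks, then prompt the model) can be illustrated without any model downloads. This is a toy sketch, not the notebook's implementation: it assumes a bag-of-words vector in place of the `all-mpnet-base-v2` embeddings, a plain Python list in place of Chroma, and made-up sentences in place of real transcript text. `embed`, `cosine`, `retrieve`, and `build_prompt` are hypothetical helper names.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    """Toy embedding: lower-cased word counts (a stand-in for a sentence-transformer)."""
    return Counter(re.findall(r"[a-z0-9%]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(query: str, chunks: list, k: int = 2) -> list:
    """Return the k chunks most similar to the query (the vector-store lookup step)."""
    q = embed(query)
    return sorted(chunks, key=lambda c: cosine(q, embed(c)), reverse=True)[:k]


def build_prompt(query: str, context: list) -> str:
    """Stuff the retrieved chunks into a prompt ahead of the question."""
    return "Context:\n" + "\n".join(context) + f"\n\nQuestion: {query}"


# Made-up sentences, not taken from any real transcript.
chunks = [
    "Net interest income rose 8% year on year.",
    "Operating expenses increased due to technology investment.",
    "Management expects loan growth to moderate next quarter.",
]
question = "What drove expenses?"
print(build_prompt(question, retrieve(question, chunks, k=1)))
```

Swapping `embed` for a sentence-transformer model and the list for a persisted vector store gives the general shape of the engine below, which retrieves `k=5` chunks rather than one.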



In [10]:
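# Install LangChain community integrations plus PDF parsing, Chroma, embedding, model, and evaluation dependencies (output suppressed)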
!pip install -q langchain-community pypdf tiktoken chromadb sentence-transformers datasets rouge-score huggingface_hub torch transformers > /dev/null 2>&1
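
The install line above does not pin versions, and LangChain has moved integrations such as PDF loaders, embeddings, and vector stores between packages across releases, so import paths that work in one release may warn or fail in another. A small sketch for recording what was actually installed, using only the standard library's `importlib.metadata`; `installed_versions` is a hypothetical helper, and the package list is an assumption based on the imports used below.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


for name, ver in installed_versions(
    ["langchain", "langchain-community", "chromadb", "sentence-transformers", "transformers", "torch"]
).items():
    print(f"{name}: {ver or 'not installed'}")
```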


In [3]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from google.colab import drive
# For evaluation metrics (ROUGE)
from rouge_score import rouge_scorer
from huggingface_hub import hf_hub_download
import warnings

In [4]:
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Mount Google Drive at /content/drive, forcing a remount if it is already mounted
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [5]:
import os
from google.colab import userdata

# Read the Hugging Face token from the Colab secret named 'HF' (stored in Colab's Secrets panel).
os.environ["HF"] = userdata.get('HF')
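
Note that `huggingface_hub` (and therefore `transformers`) looks for a token in the `HF_TOKEN` environment variable, not in `HF`, which is simply the name of the Colab secret used here. The models in this notebook are public, so a token is optional; if one is wanted for authenticated downloads, a minimal sketch is to copy it across. `resolve_hf_token` is a hypothetical helper, not part of either library.

```python
import os


def resolve_hf_token(env=os.environ):
    """Return the first non-empty token among HF_TOKEN and HF, or None."""
    for name in ("HF_TOKEN", "HF"):
        value = env.get(name)
        if value:
            return value
    return None


token = resolve_hf_token()
if token:
    # Expose the token under the variable name huggingface_hub reads.
    os.environ.setdefault("HF_TOKEN", token)
```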

# A class-based implementation of an LLM Retrieval Augmented Generation (RAG) engine
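
The constructor below rebuilds the Chroma index when `rebuild_index=True`, when the persistence directory does not exist, or when it is empty; otherwise it reloads the persisted index. That decision can be sketched as a standalone predicate. `needs_rebuild` is a hypothetical helper name, and it uses `os.path.isdir` rather than `os.path.exists` so that a stray file at that path does not make `os.listdir` raise.

```python
import os
import tempfile


def needs_rebuild(persist_directory: str, rebuild_index: bool = False) -> bool:
    """Return True if the vector index should be (re)built rather than loaded from disk."""
    return (
        rebuild_index                              # explicitly forced
        or not os.path.isdir(persist_directory)    # nothing persisted yet
        or not os.listdir(persist_directory)       # directory exists but is empty
    )


with tempfile.TemporaryDirectory() as tmp:
    print(needs_rebuild(tmp))                      # empty directory -> True
    open(os.path.join(tmp, "placeholder"), "w").close()
    print(needs_rebuild(tmp))                      # has content -> False
    print(needs_rebuild(tmp, rebuild_index=True))  # forced -> True
```

The placeholder file name is arbitrary; any content in the directory counts as a persisted index under this check.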

In [6]:


# class BankEarningsChatbot:
#     """
#     A class-based implementation of an LLM Retrieval Augmented Generation (RAG) engine
#     designed to analyze bank quarterly earnings call transcripts. It loads PDF documents
#     from one or more specified folders, builds a Chroma vector index, and sets up an interactive
#     conversational chain for prompt-based queries.

#     Parameters:
#         pdf_folders (list or str): A list of folder paths containing PDF files, or a single folder path as a string.
#         persist_directory (str): Directory path to persist the vector index and model outputs.
#         max_length (int): Maximum output length for the T5 model.
#         test_mode (bool): If True, only loads one PDF (from the first folder) for quick testing.
#         rebuild_index (bool): If True, reprocess all PDFs and rebuild the index even if a persisted index exists.
#     """
#     def __init__(self, pdf_folders,
#                  persist_directory="/content/drive/MyDrive/BOE/bank_of_england/data/model_outputs",
#                  max_length=256, test_mode=False, rebuild_index=False):
#         # Allow a single folder (string) or a list of folders.
#         if isinstance(pdf_folders, str):
#             self.pdf_folders = [pdf_folders]
#         else:
#             self.pdf_folders = pdf_folders

#         self.persist_directory = persist_directory
#         self.test_mode = test_mode

#         # Set up the LLM pipeline using the Flan-T5 model.
#         self._setup_llm(max_length)

#         # Configure embeddings.
#         self._setup_embeddings()

#         # Check if we need to rebuild the vector index.
#         if rebuild_index or (not os.path.exists(self.persist_directory)) or (not os.listdir(self.persist_directory)):
#             # Load documents and build the vector index.
#             self._load_documents()
#             self._build_vector_index()
#         else:
#             print("Loading existing vector index from persistence directory.")
#             self.db = Chroma(persist_directory=self.persist_directory, embedding_function=self.embeddings)
#             # Chroma reloads the collection persisted in this directory, so documents are not re-embedded.

#         # Configure retriever from the persisted vector database.
#         self._setup_retriever()

#         # Initialize conversation memory.
#         self.memory = ConversationBufferWindowMemory(k=10, memory_key="chat_history", return_messages=True)

#         # Set up the conversational retrieval chain.
#         self.qa_chain = ConversationalRetrievalChain.from_llm(
#             llm=self.llm,
#             retriever=self.retriever,
#             memory=self.memory,
#             verbose=False
#         )

#     def _setup_llm(self, max_length):
#         """
#         Configures the language model pipeline using a T5 model.
#         """
#         self.model_name = "google/flan-t5-large"
#         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
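#         self.model = AutoModelForSeq2SeqLM.from_pretrained(
#             self.model_name,
#             device_map="auto",
#             # Half precision on GPU only; float16 support on CPU is limited and version-dependent.
#             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
#         )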
#         # Using the text2text-generation pipeline for T5.
#         self.pipe = pipeline(
#             "text2text-generation",
#             model=self.model,
#             tokenizer=self.tokenizer,
#             max_length=max_length,
#             temperature=0.5,
#             top_p=0.8,
#             do_sample=True
#         )
#         self.llm = HuggingFacePipeline(pipeline=self.pipe)

#     def _setup_embeddings(self):
#         """
#         Initializes the sentence-transformer based embeddings.
#         """
#         self.embedding_model = "sentence-transformers/all-mpnet-base-v2"
#         self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

#     def _load_documents(self):
#         """
#         Loads and splits PDF documents from the specified folders.
#         In test_mode, only the first PDF (from the first folder) is loaded.
#         """
#         self.documents = []
#         for folder in self.pdf_folders:
#             file_names = [f for f in os.listdir(folder) if f.endswith(".pdf")]
#             if not file_names:
#                 continue
#             # If in test_mode, load only the first PDF file from this folder.
#             if self.test_mode:
#                 file_names = file_names[:1]
#             for pdf_file in file_names:
#                 pdf_path = os.path.join(folder, pdf_file)
#                 try:
#                     loader = PyPDFLoader(pdf_path, extract_images=False)
#                     self.documents.extend(loader.load_and_split())
#                     print(f"Loaded: {pdf_file} from {folder}")
#                 except Exception as e:
#                     print(f"Error loading {pdf_file} from {folder}: {e}")
#             # In test mode, break after processing the first folder.
#             if self.test_mode:
#                 break

#     def _build_vector_index(self):
#         """
#         Chunks documents and builds a Chroma vector index.
#         """
#         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
#         chunks = text_splitter.split_documents(self.documents)
#         # Remove duplicate chunks.
#         chunks = self.remove_duplicate_chunks(chunks)
#         self.chunks = chunks
#         self.db = Chroma.from_documents(chunks, embedding=self.embeddings, persist_directory=self.persist_directory)
#         self.db.persist()

#     def _setup_retriever(self):
#         """
#         Initializes the retriever from the persisted vector database.
#         """
#         self.vectordb = Chroma(persist_directory=self.persist_directory, embedding_function=self.embeddings)
#         self.retriever = self.vectordb.as_retriever(search_kwargs={"k": 5})

#     @staticmethod
#     def remove_duplicate_chunks(chunks):
#         """
#         Eliminates duplicate document chunks based on their content.
#         """
#         seen = set()
#         unique_chunks = []
#         for chunk in chunks:
#             chunk_text = chunk.page_content.strip()
#             if chunk_text not in seen:
#                 seen.add(chunk_text)
#                 unique_chunks.append(chunk)
#         return unique_chunks

#     def truncate_context(self, context_list, max_tokens=800):
#         """
#         Truncates the retrieved context to avoid overloading the model's input.
#         """
#         truncated_docs = []
#         current_tokens = 0
#         for doc in context_list:
#             doc_tokens = len(self.tokenizer.encode(doc.page_content))
#             if current_tokens + doc_tokens <= max_tokens:
#                 truncated_docs.append(doc)
#                 current_tokens += doc_tokens
#             else:
#                 break
#         return truncated_docs

#     @staticmethod
#     def clean_user_input(user_input):
#         """
#         Cleans and standardizes user input.
#         """
#         return user_input.strip().replace("\n", " ").replace("\t", " ")

#     def reset_memory_if_needed(self):
#         """
#         Clears conversation history once it holds more than six messages (more than three question/answer exchanges).
#         """
#         if len(self.memory.chat_memory.messages) > 6:
#             print("\nMemory Full: Resetting Conversation History...\n")
#             self.memory.clear()

#     def format_response(self, question, response):
#         """
#         Formats the output to clearly present both the question and the answer.
#         """
#         response_text = response.strip()
#         unwanted_phrases = [
#             "Use the following pieces of context",
#             "If you don't know the answer, just say that you don't know",
#             "Don't try to make up an answer."
#         ]
#         for phrase in unwanted_phrases:
#             if phrase in response_text:
#                 response_text = response_text.split(phrase)[-1].strip()
#         return f"Question: {question}\nHelpful Answer: {response_text}"

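#     def trim_final_input(self, question, context, max_tokens=512):
#         """
#         Builds the final prompt and fits it within max_tokens. The system message and the
#         question are kept intact; only the joined context is trimmed.
#         """
#         system_message = (
#             "You are analyzing a bank's quarterly earnings call transcript. "
#             "Provide a bullet-point summary of the most important takeaways with specific details: "
#             "list key revenue trends (include any percentage changes if available), major expense drivers, "
#             "and management's outlook for the future. If numerical details are not available, provide qualitative insights."
#         )
#         prefix = f"{system_message}\n\nContext:\n"
#         suffix = f"\n\nQuestion: {question}"
#         # Reserve room for the fixed parts (including the end-of-sequence token) so that truncation
#         # cannot cut off the question at the end of the prompt. The budget is approximate because
#         # token boundaries can shift slightly when the pieces are joined.
#         fixed_tokens = len(self.tokenizer.encode(prefix + suffix))
#         context_budget = max(0, max_tokens - fixed_tokens)
#         # Join the context documents into one string and keep only as many tokens as fit.
#         context_str = "\n".join(doc.page_content for doc in context)
#         context_ids = self.tokenizer.encode(context_str, add_special_tokens=False)[:context_budget]
#         context_str = self.tokenizer.decode(context_ids, skip_special_tokens=True)
#         return f"{prefix}{context_str}{suffix}"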


#     def answer_question(self, question):
#         """
#         Processes the user query: retrieves context, prepares the prompt,
#         and returns a formatted answer. If no relevant documents are retrieved,
#         a fallback message is returned.
#         """
#         question = self.clean_user_input(question)
#         self.reset_memory_if_needed()

#         # Retrieve and process context.
#         context = self.retriever.get_relevant_documents(question)
#         context = self.remove_duplicate_chunks(context)
#         context = self.truncate_context(context, max_tokens=800)

#         # Fallback: if no relevant context is found.
#         if not context:
#             return f"Question: {question}\nHelpful Answer: I don't have information regarding that query."

#         print("\nRetrieved Context:")
#         for doc in context:
#             source = doc.metadata.get('source', 'Unknown Source')
#             page = doc.metadata.get('page', 'Unknown Page')
#             print(f"- Source: {source}, Page: {page}")

#         # Enforce a 512-token limit for the final prompt.
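#         formatted_input = self.trim_final_input(question, context, max_tokens=512)
#         # ConversationalRetrievalChain runs its own retrieval on this input and stuffs the documents it
#         # finds into its default QA prompt, so retrieved context can reach the model twice.
#         response = self.qa_chain({"question": formatted_input})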
#         return self.format_response(question, response['answer'])

#     def run_chatbot(self):
#         """
#         Initiates an interactive loop for prompt-based queries.
#         """
#         print("\n💬 Bank Earnings Chatbot (Type 'exit' to stop)")
#         while True:
#             user_input = input("\nYou: ")
#             if user_input.lower() == "exit":
#                 print("\nExiting Chatbot. Have a great day!")
#                 break
#             answer = self.answer_question(user_input)
#             print("\n" + answer)
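
`remove_duplicate_chunks` and `truncate_context` above can be exercised without loading any models. In this sketch, a small dataclass stands in for LangChain's `Document`, whitespace-separated words stand in for T5 tokens, and the passages are made up. It shows one consequence of the `break` in `truncate_context`: once a document does not fit, later (lower-ranked) documents are skipped even if they are short enough.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Minimal stand-in for a LangChain Document."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def remove_duplicate_chunks(chunks):
    """Drop chunks whose stripped text has already been seen, keeping the first occurrence."""
    seen, unique = set(), []
    for chunk in chunks:
        text = chunk.page_content.strip()
        if text not in seen:
            seen.add(text)
            unique.append(chunk)
    return unique


def truncate_context(docs, max_tokens, count_tokens=lambda s: len(s.split())):
    """Greedily keep documents in ranked order until the next one would exceed max_tokens."""
    kept, used = [], 0
    for doc in docs:
        n = count_tokens(doc.page_content)
        if used + n > max_tokens:
            break  # mirrors the class method: stop at the first document that does not fit
        kept.append(doc)
        used += n
    return kept


# Made-up passages in retrieval (relevance) order.
docs = [
    Doc("revenue up 5%"),
    Doc("  revenue up 5%  "),
    Doc("costs rose on higher headcount and technology spend"),
    Doc("outlook stable"),
]
unique = remove_duplicate_chunks(docs)
kept = truncate_context(unique, max_tokens=6)
print(len(unique), [d.page_content for d in kept])  # 3 ['revenue up 5%']
```

Replacing `break` with `continue` would also pack in the short final passage; `break` keeps the context to a strict prefix of the relevance ranking. Which is preferable is a design choice.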


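`trim_final_input` has to fit the system message, the retrieved context, and the question into 512 tokens, the prompt limit this notebook enforces for Flan-T5. Because the question sits at the end of the prompt, truncating the assembled string from the right removes the question first. A sketch of the alternative, which reserves room for the fixed parts and trims only the context, again counts whitespace-separated words as a stand-in for tokenizer tokens; `naive_truncate` and `build_budgeted_prompt` are hypothetical helpers.

```python
def count_tokens(text: str) -> int:
    """Stand-in tokenizer: one token per whitespace-separated word."""
    return len(text.split())


def naive_truncate(text: str, max_tokens: int) -> str:
    """Keep the first max_tokens words of the whole prompt (truncation from the right)."""
    return " ".join(text.split()[:max_tokens])


def build_budgeted_prompt(system_message, context, question, max_tokens):
    """Keep the system message and question intact; trim only the context to fit."""
    prefix = f"{system_message}\n\nContext:\n"
    suffix = f"\n\nQuestion: {question}"
    budget = max(0, max_tokens - count_tokens(prefix) - count_tokens(suffix))
    context_words = "\n".join(context).split()[:budget]
    return prefix + " ".join(context_words) + suffix


system = "Summarize the call."
context = ["one two three four", "five"]
question = "What changed?"

full = f"{system}\n\nContext:\n" + "\n".join(context) + f"\n\nQuestion: {question}"
print(naive_truncate(full, 10))                            # question words are cut off
print(build_budgeted_prompt(system, context, question, 10))  # question survives, context trimmed
```

If the system message and question alone exceed the limit, the budgeted prompt will still be too long; a real implementation would need a policy for that case.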


In [7]:
# # ----------------------------
# # Example usage of the BankEarningsChatbot class with Flan-T5 and multiple data sources (set test_mode=True to load a single PDF)
# # ----------------------------

# # Define your PDF folder paths (ensure these paths contain your earnings transcripts in PDF format).
# pdf_folders = [
#     "/content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan",
#     "/content/drive/MyDrive/BOE/bank_of_england/data/raw/ubs"
# ]

# # Define the persistence directory for model outputs and the vector index.
# persist_directory = "/content/drive/MyDrive/BOE/bank_of_england/data/model_outputs"

# # Instantiate the chatbot object.
# # Set rebuild_index=True if you want to force re-indexing, otherwise it will load the persisted index if it exists.
# chatbot = BankEarningsChatbot(pdf_folders, persist_directory=persist_directory, test_mode=False, rebuild_index=True)
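
The two-stage class in the next cell groups summaries by bank and quarter by parsing each PDF file name with a regular expression. A standalone sketch of that parsing, using the same pattern, shows which names it recognises; `quarter_key` is a hypothetical helper and the file names are made up.

```python
import re

QUARTER_PATTERN = re.compile(
    r'((\d{1}[qQ]|[qQ]\d)|((first|second|third|fourth)[-_ ]?quarter))[-_ ]?(\d{2,4})',
    re.IGNORECASE,
)
WORD_TO_QUARTER = {"first": "1", "second": "2", "third": "3", "fourth": "4"}


def quarter_key(file_name: str) -> str:
    """Map a file name to 'YYYY-Qn', or 'Unknown' if no quarter/year pattern is found."""
    match = QUARTER_PATTERN.search(file_name)
    if not match:
        return "Unknown"
    if match.group(2):  # forms like '1q' or 'Q3'
        quarter = re.search(r"\d", match.group(2)).group(0)
    else:               # forms like 'fourth_quarter'
        quarter = WORD_TO_QUARTER[match.group(4).lower()]
    year = int(match.group(5))
    if year < 100:      # two-digit years are assumed to be 20xx
        year += 2000
    return f"{year}-Q{quarter}"


for name in ["jpm-1q23.pdf", "ubs_fourth_quarter_2024.pdf", "Q3-2022_transcript.pdf", "jpm_2023_q1.pdf"]:
    print(name, "->", quarter_key(name))
```

The pattern expects the quarter before the year, so a name such as `jpm_2023_q1.pdf` falls into the `Unknown` group.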

In [8]:
# import os
# import re
# import torch
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# from datasets import Dataset  # For batch processing
# from langchain_community.document_loaders import PyPDFLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_community.vectorstores import Chroma
# from langchain.memory import ConversationBufferWindowMemory
# from langchain_community.llms import HuggingFacePipeline
# from langchain.chains import ConversationalRetrievalChain
# from langchain.docstore.document import Document

# # For evaluation metrics (ROUGE)
# from rouge_score import rouge_scorer

# class BankEarningsChatbotTwoStage:
#     """
#     A two-stage RAG approach using LED (long-context model):
#       - Stage 1: Summarize each document (or chunk, if needed) and store summaries in a separate vector store.
#       - Stage 2: For each user query, retrieve relevant summaries and perform final Q&A.
#       - Also includes an evaluation method for summarization quality (using ROUGE).
#     """
#     def __init__(
#         self,
#         pdf_folders,
#         persist_directory="/content/drive/MyDrive/BOE/bank_of_england/data/model_outputs",
#         max_length=1024,  # Maximum generated length; LED accepts inputs of up to 16,384 tokens.
#         test_mode=False,
#         rebuild_index=False,
#         verbose=False,
#         chunk_size=1000,
#         chunk_overlap=100,
#         chunk_threshold=1024  # If a doc's token count <= threshold, do not split it.
#     ):
#         if isinstance(pdf_folders, str):
#             self.pdf_folders = [pdf_folders]
#         else:
#             self.pdf_folders = pdf_folders

#         self.persist_directory = persist_directory
#         self.test_mode = test_mode
#         self.verbose = verbose
#         self.chunk_size = chunk_size
#         self.chunk_overlap = chunk_overlap
#         self.chunk_threshold = chunk_threshold  # Documents longer than this (in tokens) are split.

#         # 1) Set up the LLM pipeline using LED.
#         self._setup_llm(max_length)
#         # 2) Set up embeddings.
#         self._setup_embeddings()

#         # 3) Load raw documents and (optionally) build a raw vector index.
#         if rebuild_index or (not os.path.exists(self.persist_directory)) or (not os.listdir(self.persist_directory)):
#             self._load_documents()
#             self._build_raw_vector_index()
#         else:
#             if self.verbose:
#                 print("Loading existing raw vector index from persistence directory.")
#             self.raw_db = Chroma(persist_directory=self.persist_directory, embedding_function=self.embeddings)

#         # 4) Build the summary vector index (Stage 1).
#         self.summary_persist_dir = os.path.join(self.persist_directory, "summaries")
#         if rebuild_index or (not os.path.exists(self.summary_persist_dir)) or (not os.listdir(self.summary_persist_dir)):
#             if not hasattr(self, "documents"):
#                 self._load_documents()
#             self._build_summary_vector_index()
#         else:
#             if self.verbose:
#                 print("Loading existing summary vector index from 'summaries' directory.")
#             self.summary_db = Chroma(persist_directory=self.summary_persist_dir, embedding_function=self.embeddings)

#         # 5) Create a retriever for the summary DB.
#         self._setup_summary_retriever()

#         # 6) Create conversation memory.
#         self.memory = ConversationBufferWindowMemory(k=10, memory_key="chat_history", return_messages=True)

#         # 7) Create the final Q&A chain (Stage 2) using the summary retriever.
#         self.qa_chain = ConversationalRetrievalChain.from_llm(
#             llm=self.llm,
#             retriever=self.summary_retriever,
#             memory=self.memory,
#             verbose=False
#         )

#     def _setup_llm(self, max_length):
#         # Use LED which supports a larger context window.
#         self.model_name = "allenai/led-base-16384"
#         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
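#         use_cuda = torch.cuda.is_available()
#         # Half precision on GPU only; float16 support on CPU is limited and version-dependent.
#         self.model = AutoModelForSeq2SeqLM.from_pretrained(
#             self.model_name,
#             torch_dtype=torch.float16 if use_cuda else torch.float32
#         )
#         if use_cuda:
#             self.model.to("cuda")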
#         self.pipe = pipeline(
#             "text2text-generation",
#             model=self.model,
#             tokenizer=self.tokenizer,
#             max_length=max_length,
#             temperature=0.1,
#             top_p=0.8,
#             do_sample=True,
#             batch_size=8
#         )
#         self.llm = HuggingFacePipeline(pipeline=self.pipe)

#     def _setup_embeddings(self):
#         self.embedding_model = "sentence-transformers/all-mpnet-base-v2"
#         self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

#     def _load_documents(self):
#         self.documents = []
#         total_files_loaded = 0
#         for folder in self.pdf_folders:
#             bank_name = os.path.basename(folder).lower()
#             file_names = [f for f in os.listdir(folder) if f.endswith(".pdf")]
#             if not file_names:
#                 continue
#             if self.test_mode:
#                 file_names = file_names[:1]
#             for pdf_file in file_names:
#                 pdf_path = os.path.join(folder, pdf_file)
#                 try:
#                     loader = PyPDFLoader(pdf_path, extract_images=False)
#                     docs = loader.load_and_split()
#                     for doc in docs:
#                         doc.metadata["bank"] = bank_name
#                         doc.metadata["source_pdf"] = pdf_file
#                     self.documents.extend(docs)
#                     total_files_loaded += 1
#                     if self.verbose:
#                         print(f"Loaded: {pdf_file} from {folder}")
#                 except Exception as e:
#                     if self.verbose:
#                         print(f"Error loading {pdf_file} from {folder}: {e}")
#             if self.test_mode:
#                 break
#         if self.verbose:
#             print(f"Total PDF files loaded: {total_files_loaded}")

#     def _build_raw_vector_index(self):
#         # Instead of always chunking, check if a document’s token length exceeds our threshold.
#         processed_docs = []
#         for doc in self.documents:
#             tokens = self.tokenizer.encode(doc.page_content)
#             if len(tokens) > self.chunk_threshold:
#                 # Use chunking
#                 splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
#                 chunks = splitter.split_documents([doc])
#                 processed_docs.extend(self.remove_duplicate_chunks(chunks))
#             else:
#                 processed_docs.append(doc)
#         self.raw_db = Chroma.from_documents(
#             processed_docs,
#             embedding=self.embeddings,
#             persist_directory=self.persist_directory
#         )
#         self.raw_db.persist()

#     def _build_summary_vector_index(self):
#         # Stage 1: Chunk (if necessary) the documents.
#         processed_docs = []
#         for doc in self.documents:
#             tokens = self.tokenizer.encode(doc.page_content)
#             if len(tokens) > self.chunk_threshold:
#                 splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
#                 chunks = splitter.split_documents([doc])
#                 processed_docs.extend(self.remove_duplicate_chunks(chunks))
#             else:
#                 processed_docs.append(doc)

#         # Stage 2: Summarize each processed document/chunk.
#         prompts = []
#         for d in processed_docs:
#             prompt = (
#                 "Summarize the following text in bullet points, focusing on key financials, sentiment, "
#                 "and forward-looking statements:\n\n" + d.page_content
#             )
#             prompts.append(prompt)

#         if self.verbose:
#             print(f"Summarizing {len(prompts)} documents/chunks...")

#         # generate() returns an LLMResult: result.generations[i] holds the outputs for prompts[i].
#         # HuggingFacePipeline splits the prompt list into batches itself.
#         result = self.llm.generate(prompts)

#         summary_docs = []
#         for orig_doc, generations in zip(processed_docs, result.generations):
#             summary_text = generations[0].text.strip()
#             summary_doc = Document(
#                 page_content=summary_text,
#                 metadata={
#                     "source_pdf": orig_doc.metadata.get("source_pdf", "Unknown"),
#                     "bank": orig_doc.metadata.get("bank", "unknown"),
#                     "orig_chunk": orig_doc.page_content[:50]
#                 }
#             )
#             summary_docs.append(summary_doc)

#         self.summary_db = Chroma.from_documents(
#             summary_docs,
#             embedding=self.embeddings,
#             persist_directory=os.path.join(self.persist_directory, "summaries")
#         )
#         self.summary_db.persist()
#         if self.verbose:
#             print(f"Built summary vector index with {len(summary_docs)} summarized docs.")

#     def _setup_summary_retriever(self):
#         self.summary_retriever = self.summary_db.as_retriever(search_kwargs={"k": 5})

#     @staticmethod
#     def remove_duplicate_chunks(chunks):
#         seen = set()
#         unique = []
#         for chunk in chunks:
#             text = chunk.page_content.strip()
#             if text not in seen:
#                 seen.add(text)
#                 unique.append(chunk)
#         return unique

#     def truncate_context(self, context_list, max_tokens=800):
#         truncated = []
#         current = 0
#         for doc in context_list:
#             doc_tokens = len(self.tokenizer.encode(doc.page_content))
#             if current + doc_tokens <= max_tokens:
#                 truncated.append(doc)
#                 current += doc_tokens
#             else:
#                 break
#         return truncated

#     @staticmethod
#     def clean_user_input(user_input):
#         return user_input.strip().replace("\n", " ").replace("\t", " ")

#     def reset_memory_if_needed(self):
#         if len(self.memory.chat_memory.messages) > 6:
#             if self.verbose:
#                 print("\nMemory Full: Resetting Conversation History...\n")
#             self.memory.clear()

#     def format_response(self, question, response):
#         if isinstance(response, dict) and 'answer' in response:
#             resp_text = response['answer'].strip()
#         else:
#             resp_text = str(response).strip()
#         for phrase in [
#             "Use the following pieces of context",
#             "If you don't know the answer, just say that you don't know",
#             "Don't try to make up an answer."
#         ]:
#             if phrase in resp_text:
#                 resp_text = resp_text.split(phrase)[-1].strip()
#         return f"Question: {question}\nHelpful Answer: {resp_text}"

#     def trim_final_input(self, question, context, max_tokens=512):
#         system_message = (
#             "You are analyzing a bank's quarterly earnings call transcript. "
#             "Provide a bullet-point summary of the most important takeaways with specific details: "
#             "list key revenue trends (include any percentage changes if available), major expense drivers, "
#             "and management's outlook for the future. If numerical details are not available, provide qualitative insights."
#         )
#         # Batch intermediate summarization for each context document.
#         summarized_context = []
#         per_doc_limit = max_tokens // max(1, len(context))
#         batch_prompts = []
#         for doc in context:
#             doc_text = doc.page_content
#             tokens = self.tokenizer.encode(doc_text)
#             if len(tokens) > per_doc_limit:
#                 summary_prompt = f"Summarize the following text in a concise bullet-point format:\n\n{doc_text}"
#                 batch_prompts.append(summary_prompt)
#             else:
#                 summarized_context.append(doc_text)
#         if batch_prompts:
#             # Take the first generation for each summarization prompt from the LLMResult.
#             result = self.llm.generate(batch_prompts)
#             for generations in result.generations:
#                 summarized_context.append(generations[0].text.strip())
#         context_str = "\n".join(summarized_context)
#         input_text = f"{system_message}\n\nContext:\n{context_str}\n\nQuestion: {question}"
#         # Encode once, truncating to max_tokens; truncation removes tokens from the end,
#         # so the question is cut first if the prompt is too long.
#         tokens = self.tokenizer.encode(input_text, truncation=True, max_length=max_tokens)
#         if self.verbose:
#             print(f"Final input token length: {len(tokens)}")
#         return self.tokenizer.decode(tokens)

#     def answer_question(self, question: str) -> str:
#         self.reset_memory_if_needed()
#         response_dict = self.qa_chain({"question": question})
#         final_answer = response_dict["answer"].strip()
#         return f"Question: {question}\nHelpful Answer: {final_answer}"

#     def run_chatbot(self):
#         print("\n💬 Bank Earnings Chatbot - Two Stage (Type 'exit' to stop)")
#         while True:
#             user_input = input("\nYou: ")
#             if user_input.lower() == "exit":
#                 print("\nExiting Chatbot. Have a great day!")
#                 break
#             answer = self.answer_question(user_input)
#             print("\n" + answer)

#     # ---------------- Evaluation Methods ----------------
#     def evaluate_summaries(self, generated_summaries, reference_summaries, verbose=False):
#         scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
#         eval_scores = {}
#         for key, gen_dict in generated_summaries.items():
#             gen_summary = gen_dict.get("summary", "")
#             if key in reference_summaries:
#                 ref_summary = reference_summaries[key]
#                 scores = scorer.score(ref_summary, gen_summary)
#                 eval_scores[key] = scores
#                 if verbose:
#                     print(f"Evaluation for {key}: {scores}")
#             else:
#                 if verbose:
#                     print(f"No reference summary provided for {key}.")
#         return eval_scores

#     def summarize_individual_transcripts(self):
#         prompts = []
#         sources = []
#         for doc in self.documents:
#             source = doc.metadata.get('source_pdf', 'Unknown Source')
#             bank = doc.metadata.get('bank', 'unknown')
#             transcript_text = doc.page_content
#             prompt = f"Please provide a bullet-point sentiment summary for the following transcript:\n\n{transcript_text}"
#             prompts.append(prompt)
#             sources.append((source, bank))
#         if self.verbose:
#             print(f"Batch summarizing {len(prompts)} transcript chunks...")
#         # generate() returns an LLMResult; take the first generation for each prompt.
#         result = self.llm.generate(prompts)
#         summaries = {}
#         for (source, bank), generations in zip(sources, result.generations):
#             answer_text = generations[0].text.strip()
#             # Each PDF is split into several documents, so collect every chunk summary
#             # for a source instead of keeping only the last one.
#             if source not in summaries:
#                 summaries[source] = {"bank": bank, "summary": answer_text}
#             else:
#                 summaries[source]["summary"] += "\n" + answer_text
#             if self.verbose:
#                 print(f"Summarized a chunk of {source}")
#         return summaries

#     def group_summaries_by_bank_and_quarter(self, summaries):
#         grouped = {}
#         for source, info in summaries.items():
#             bank = info["bank"]
#             summary = info["summary"]
#             match = re.search(
#                 r'((\d{1}[qQ]|[qQ]\d)|((first|second|third|fourth)[-_ ]?quarter))[-_ ]?(\d{2,4})',
#                 source,
#                 re.IGNORECASE
#             )
#             if match:
#                 if match.group(2):
#                     quarter_raw = match.group(2).lower()
#                     quarter_digit_match = re.search(r'\d', quarter_raw)
#                     quarter_num = quarter_digit_match.group(0) if quarter_digit_match else "1"
#                 elif match.group(4):
#                     word = match.group(4).lower()
#                     mapping = {"first": "1", "second": "2", "third": "3", "fourth": "4"}
#                     quarter_num = mapping.get(word, "1")
#                 else:
#                     quarter_num = "1"
#                 year_str = match.group(5)
#                 year_val = int(year_str)
#                 if year_val < 100:
#                     year_val += 2000
#                 key = f"{year_val}-Q{quarter_num}"
#             else:
#                 key = "Unknown"
#             if bank not in grouped:
#                 grouped[bank] = {}
#             if key not in grouped[bank]:
#                 grouped[bank][key] = []
#             grouped[bank][key].append(summary)
#         return grouped

#     def aggregate_quarterly_summaries_by_bank(self, grouped_quarterly):
#         quarterly_aggregates = {}
#         for bank, quarters in grouped_quarterly.items():
#             quarterly_aggregates[bank] = {}
#             for key, summaries in quarters.items():
#                 combined = "\n".join(summaries)
#                 prompt = (f"Based on the following quarterly sentiment summaries for {bank.upper()} ({key}), "
#                           "provide a concise bullet-point overview of the overall sentiment for that quarter:\n\n" + combined)
#                 response = self.llm(prompt)
#                 if isinstance(response, dict) and 'answer' in response:
#                     answer_text = response['answer']
#                 else:
#                     answer_text = str(response)
#                 quarterly_aggregates[bank][key] = answer_text
#         return quarterly_aggregates

#     def forecast_next_quarter_sentiment(self, bank, historical_quarterly):
#         combined = "\n".join(historical_quarterly)
#         prompt = (
#             f"Based on the following historical quarterly sentiment summaries for {bank.upper()}, "
#             "forecast the overall sentiment for the next quarter. Provide a bullet-point summary of the expected trends, "
#             "including any changes in tone, risk factors, or optimism:\n\n" + combined
#         )
#         response = self.llm(prompt)
#         if isinstance(response, dict) and 'answer' in response:
#             return response['answer']
#         return str(response)

#     @staticmethod
#     def parse_quarter_key(key):
#         try:
#             parts = key.split("-")
#             year = int(parts[0])
#             quarter = int(re.search(r'\d', parts[1]).group(0))
#             return year, quarter
#         except Exception:
#             return (0, 0)

#     def analyze_and_forecast_sentiment_by_bank(self):
#         print("Generating individual transcript summaries...")
#         summaries = self.summarize_individual_transcripts()
#         print("Grouping summaries by bank and quarter...")
#         grouped_quarterly = self.group_summaries_by_bank_and_quarter(summaries)
#         print("Aggregating quarterly summaries...")
#         quarterly_aggregates = self.aggregate_quarterly_summaries_by_bank(grouped_quarterly)

#         analysis = {}
#         for bank, quarters in quarterly_aggregates.items():
#             analysis[bank] = {}
#             valid_keys = [k for k in quarters.keys() if k != "Unknown"]
#             if not valid_keys:
#                 print(f"No valid quarter keys found for {bank}: {valid_keys}")
#                 continue
#             sorted_keys = sorted(valid_keys, key=lambda k: self.parse_quarter_key(k))
#             most_recent_key = sorted_keys[-1]
#             current_year, current_quarter = self.parse_quarter_key(most_recent_key)

#             years = sorted({self.parse_quarter_key(k)[0] for k in quarters if k != "Unknown"})
#             previous_year = max([y for y in years if y < current_year], default=None)
#             previous_year_summaries = []
#             if previous_year is not None:
#                 for key in quarters:
#                     year, _ = self.parse_quarter_key(key)
#                     if year == previous_year:
#                         previous_year_summaries.extend(quarters[key])
#                 combined_prev = "\n".join(previous_year_summaries)
#                 prompt_prev = (f"Based on the following sentiment summaries for all quarters in {previous_year} for {bank.upper()}, "
#                                "provide a bullet-point summary of the overall sentiment trends for that year:\n\n" + combined_prev)
#                 response_prev = self.llm(prompt_prev)
#                 if isinstance(response_prev, dict) and 'answer' in response_prev:
#                     previous_year_summary = response_prev['answer']
#                 else:
#                     previous_year_summary = str(response_prev)
#             else:
#                 previous_year_summary = "Not available"

#             current_summary = quarters.get(most_recent_key, "Not available")

#             historical = []
#             for key in sorted_keys:
#                 historical.extend(quarters[key])
#             forecast = self.forecast_next_quarter_sentiment(bank, historical) if historical else "Not available"

#             analysis[bank] = {
#                 "previous_year_summary": previous_year_summary,
#                 "current_quarter_summary": current_summary,
#                 "forecast_next_quarter": forecast,
#                 "most_recent_key": most_recent_key
#             }
#         return analysis

#     # ---------------- Example Usage and Evaluation ----------------
#     def evaluate_summaries(self, generated_summaries, reference_summaries, verbose=False):
#         scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
#         eval_scores = {}
#         for key, gen_dict in generated_summaries.items():
#             gen_summary = gen_dict.get("summary", "")
#             if key in reference_summaries:
#                 ref_summary = reference_summaries[key]
#                 scores = scorer.score(ref_summary, gen_summary)
#                 eval_scores[key] = scores
#                 if verbose:
#                     print(f"Evaluation for {key}: {scores}")
#             else:
#                 if verbose:
#                     print(f"No reference summary provided for {key}.")
#         return eval_scores


# # ----------------------------
# # Example usage:
# pdf_folders = [
#     "/content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan",
#     "/content/drive/MyDrive/BOE/bank_of_england/data/raw/ubs"
# ]
# persist_directory = "/content/drive/MyDrive/BOE/bank_of_england/data/model_outputs"

# chatbot = BankEarningsChatbotTwoStage(
#     pdf_folders,
#     persist_directory=persist_directory,
#     test_mode=False,
#     rebuild_index=True,
#     verbose=True
# )

# analysis = chatbot.analyze_and_forecast_sentiment_by_bank()

# print("Yearly and Quarterly Sentiment Analysis by Bank:")
# for bank, data in analysis.items():
#     print(f"{bank.upper()} Analysis:")
#     print(f"Most Recent Quarter Key: {data.get('most_recent_key', 'N/A')}")
#     print("Previous Year Sentiment Summary:")
#     print(data["previous_year_summary"])
#     print("Current Quarter Sentiment Summary:")
#     print(data["current_quarter_summary"])
#     print("Forecast for Next Quarter:")
#     print(data["forecast_next_quarter"])
#     print("-" * 60)

# # ----------------------------
# # Evaluation Example:
# # Suppose you have reference summaries in a dictionary:
# reference_summaries = {
#     "1q23-earnings-transcript.pdf": "Reference summary for 1Q23 transcript...",
#     "4q24-earnings-call-remarks.pdf": "Reference summary for 4Q24 transcript..."
#     # Add more as available.
# }

# generated_summaries = chatbot.summarize_individual_transcripts()
# evaluation_results = chatbot.evaluate_summaries(generated_summaries, reference_summaries, verbose=True)
# print("Evaluation Results (ROUGE):")
# for key, scores in evaluation_results.items():
#     print(f"{key}: {scores}")
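The evaluation example above scores generated summaries against reference summaries with `rouge_score`'s ROUGE-1, ROUGE-2 and ROUGE-L. To show roughly what the ROUGE-1 F-measure captures, here is a simplified, standard-library-only sketch that counts clipped unigram overlap between a reference and a candidate summary. Unlike `rouge_scorer.RougeScorer(..., use_stemmer=True)`, it applies no Porter stemming, and its tokenizer (lowercase alphanumeric runs) is an assumption, so its scores will not match the library exactly. The function name `rouge1_f1` is hypothetical.

```python
import re
from collections import Counter


def rouge1_f1(reference: str, candidate: str) -> float:
    """Simplified ROUGE-1 F1: clipped unigram overlap, no stemming (assumption)."""
    ref_tokens = re.findall(r"[a-z0-9]+", reference.lower())
    cand_tokens = re.findall(r"[a-z0-9]+", candidate.lower())
    if not ref_tokens or not cand_tokens:
        return 0.0
    # Counter intersection keeps the minimum count of each shared token (clipping)
    overlap = sum((Counter(ref_tokens) & Counter(cand_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(cand_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(rouge1_f1("net income rose 10 percent", "net income rose sharply"))
```

Here three of the four candidate words appear in the five-word reference, so precision is 3/4, recall is 3/5, and F1 is 2/3.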

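The commented-out `group_summaries_by_bank_and_quarter` above, which the class in the next cell keeps, derives a `YYYY-Qn` key from each PDF filename, and `parse_quarter_key` later sorts those keys chronologically. The stand-alone sketch below reproduces that parsing with the same regular expression (written `\d` instead of the equivalent `\d{1}`) so it can be checked on sample filenames. The function names and any filenames other than the two in the reference-summary example are hypothetical. As in the original code, two-digit years are assumed to mean 20xx.

```python
import re

# Same pattern as group_summaries_by_bank_and_quarter: "1q23", "Q3_2024", "third-quarter-2023", ...
QUARTER_PATTERN = re.compile(
    r'((\d[qQ]|[qQ]\d)|((first|second|third|fourth)[-_ ]?quarter))[-_ ]?(\d{2,4})',
    re.IGNORECASE,
)
WORD_TO_QUARTER = {"first": "1", "second": "2", "third": "3", "fourth": "4"}


def quarter_key_from_filename(filename: str) -> str:
    """Return a 'YYYY-Qn' key, or 'Unknown' if no quarter/year pattern is found."""
    match = QUARTER_PATTERN.search(filename)
    if not match:
        return "Unknown"
    if match.group(2):
        # "1q" or "q1": the single digit is the quarter number
        quarter = re.search(r"\d", match.group(2)).group(0)
    else:
        # "first".."fourth" captured by group 4
        quarter = WORD_TO_QUARTER[match.group(4).lower()]
    year = int(match.group(5))
    if year < 100:  # two-digit years are assumed to be 20xx
        year += 2000
    return f"{year}-Q{quarter}"


def sort_key(key: str) -> tuple[int, int]:
    """(year, quarter) tuple so keys sort chronologically; 'Unknown' sorts first."""
    if key == "Unknown":
        return (0, 0)
    year, quarter = key.split("-Q")
    return (int(year), int(quarter))


files = ["4q24-earnings-call-remarks.pdf", "1q23-earnings-transcript.pdf", "JPM_Q3_2024.pdf"]
keys = [quarter_key_from_filename(f) for f in files]
print(sorted(keys, key=sort_key))
```

Sorting on `(year, quarter)` tuples rather than on the raw strings keeps the order correct even if a key format ever differs in width.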

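The `truncate_context` method in the next cell packs retrieved documents into a fixed token budget greedily, in retrieval order, and stops at the first document that does not fit. A later, shorter document is therefore dropped even when it would still fit. Below is a minimal sketch of the same greedy packing. It assumes a whitespace word count as a stand-in for `self.tokenizer.encode`, and the function `pack_into_budget` is hypothetical.

```python
from typing import Callable


def pack_into_budget(
    texts: list[str],
    max_tokens: int,
    count_tokens: Callable[[str], int] = lambda text: len(text.split()),
) -> list[str]:
    """Greedily keep texts in retrieval order until the next one would exceed max_tokens."""
    kept: list[str] = []
    used = 0
    for text in texts:
        n_tokens = count_tokens(text)  # whitespace word count stands in for model tokens
        if used + n_tokens > max_tokens:
            break  # like truncate_context: stop rather than skip ahead to a shorter text
        kept.append(text)
        used += n_tokens
    return kept


retrieved = ["revenue up 4 percent", "expenses rose on higher compensation costs", "outlook stable"]
print(pack_into_budget(retrieved, max_tokens=6))
```

With a budget of 6, the first text (4 words) is kept, the second (6 words) would overflow it, and the loop stops, so "outlook stable" is dropped even though 4 + 2 fits. Replacing `break` with `continue` would instead skip oversized documents and keep filling the budget. Which behaviour is preferable depends on how strictly retrieval order should be respected.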
In [None]:
import os
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from datasets import Dataset  # For batch processing
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain_core.documents import Document

# For evaluation metrics (ROUGE)
from rouge_score import rouge_scorer

# --- Configuration Section ---
CHATBOT_CONFIG = {
    "pdf_folders": [  # Example folders, replace with your actual paths
        "/content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan",
        "/content/drive/MyDrive/BOE/bank_of_england/data/raw/ubs"
    ],
    "persist_directory": "/content/drive/MyDrive/BOE/bank_of_england/data/model_outputs",
    "llm_model_name": "google/bigbird-pegasus-large-arxiv",  # Updated to BigBird-Pegasus
    "embedding_model_name": "sentence-transformers/all-mpnet-base-v2",
    "text_generation_pipeline_task": "text2text-generation",
    "max_length": 1024,
    "temperature": 0.1,
    "top_p": 0.8,
    "batch_size": 8,
    "chunk_size": 1000,
    "chunk_overlap": 100,
    "chunk_threshold": 1024,
    "memory_window_k": 10,
    "retriever_search_k": 5
}


class BankEarningsChatbotTwoStage:
    """
    A two-stage RAG chatbot for analyzing bank earnings call transcripts using BigBird-Pegasus.

    Stage 1: Summarizes each document (or chunk if needed) and stores summaries in a vector store.
    Stage 2: Retrieves relevant summaries for user queries and performs final Q&A.

    Includes evaluation methods for summarization quality using ROUGE.
    """

    BIGBIRD_PEGASUS_MODEL_NAME = CHATBOT_CONFIG["llm_model_name"]
    EMBEDDING_MODEL_NAME = CHATBOT_CONFIG["embedding_model_name"]
    TEXT_GENERATION_TASK = CHATBOT_CONFIG["text_generation_pipeline_task"]

    def __init__(
        self,
        pdf_folders: list[str],
        persist_directory: str = CHATBOT_CONFIG["persist_directory"],
        max_length: int = CHATBOT_CONFIG["max_length"],
        test_mode: bool = False,
        rebuild_index: bool = False,
        verbose: bool = False,
        chunk_size: int = CHATBOT_CONFIG["chunk_size"],
        chunk_overlap: int = CHATBOT_CONFIG["chunk_overlap"],
        chunk_threshold: int = CHATBOT_CONFIG["chunk_threshold"]
    ) -> None:
        if not isinstance(pdf_folders, list):
            raise TypeError("pdf_folders must be a list of strings.")
        self.pdf_folders = pdf_folders
        self.persist_directory = persist_directory
        self.test_mode = test_mode
        self.verbose = verbose
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.chunk_threshold = chunk_threshold

        self._setup_llm(max_length)
        self._setup_embeddings()

        if rebuild_index or not os.path.exists(self.persist_directory) or not os.listdir(self.persist_directory):
            self._load_documents()
            self._build_raw_vector_index()
        else:
            if self.verbose:
                print("Loading existing raw vector index from persistence directory.")
            self.raw_db = Chroma(persist_directory=self.persist_directory, embedding_function=self.embeddings)

        self.summary_persist_dir = os.path.join(self.persist_directory, "summaries")
        if rebuild_index or not os.path.exists(self.summary_persist_dir) or not os.listdir(self.summary_persist_dir):
            if not hasattr(self, "documents"):
                self._load_documents()
            self._build_summary_vector_index()
        else:
            if self.verbose:
                print("Loading existing summary vector index from 'summaries' directory.")
            self.summary_db = Chroma(persist_directory=self.summary_persist_dir, embedding_function=self.embeddings)

        self._setup_summary_retriever()
        self.memory = ConversationBufferWindowMemory(k=CHATBOT_CONFIG["memory_window_k"], memory_key="chat_history", return_messages=True)
        self._setup_qa_chain()

    def _setup_llm(self, max_length: int) -> None:
        """
        Sets up the Language Model (LLM) pipeline using Hugging Face Transformers (BigBird-Pegasus model).
        """
        self.model_name = self.BIGBIRD_PEGASUS_MODEL_NAME
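        # Hugging Face access token; None falls back to any cached login or public access
        hf_token = os.environ.get("HF")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=hf_token)
        # float16 is only reliable on GPU; fall back to float32 on CPU
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.model_name,
            torch_dtype=dtype,
            token=hf_token
        )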
        if torch.cuda.is_available():
            self.model.to("cuda")
        self.pipe = pipeline(
            self.TEXT_GENERATION_TASK,
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=max_length,
            temperature=CHATBOT_CONFIG["temperature"],
            top_p=CHATBOT_CONFIG["top_p"],
            do_sample=True,
            batch_size=CHATBOT_CONFIG["batch_size"]
        )
        self.llm = HuggingFacePipeline(pipeline=self.pipe)

    def _setup_embeddings(self) -> None:
        """
        Sets up the Hugging Face Embeddings.
        """
        self.embedding_model = self.EMBEDDING_MODEL_NAME
        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

    def _load_documents(self) -> None:
        """
        Loads PDF documents from specified folders.
        """
        self.documents = []
        total_files_loaded = 0
        for folder in self.pdf_folders:
            bank_name = os.path.basename(folder).lower()
            file_names = [f for f in os.listdir(folder) if f.endswith(".pdf")]
            if not file_names:
                continue
            if self.test_mode:
                file_names = file_names[:1]
            for pdf_file in file_names:
                pdf_path = os.path.join(folder, pdf_file)
                try:
                    loader = PyPDFLoader(pdf_path, extract_images=False)
                    docs = loader.load_and_split()
                    for doc in docs:
                        doc.metadata["bank"] = bank_name
                        doc.metadata["source_pdf"] = pdf_file
                    self.documents.extend(docs)
                    total_files_loaded += 1
                    if self.verbose:
                        print(f"Loaded: {pdf_file} from {folder}")
                except Exception as e:
                    print(f"Error loading {pdf_file} from {folder}: {e}")
            if self.test_mode:
                break
        if self.verbose:
            print(f"Total PDF files loaded: {total_files_loaded}")

    def _chunk_document_if_needed(self, doc: Document) -> list[Document]:
        # chunk_threshold is measured in model tokens, while chunk_size/chunk_overlap are characters
        tokens = self.tokenizer.encode(doc.page_content)
        if len(tokens) > self.chunk_threshold:
            splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
            chunks = splitter.split_documents([doc])
            return self.remove_duplicate_chunks(chunks)
        return [doc]

    def _build_raw_vector_index(self) -> None:
        processed_docs = []
        for doc in self.documents:
            processed_docs.extend(self._chunk_document_if_needed(doc))
        self.raw_db = Chroma.from_documents(
            processed_docs,
            embedding=self.embeddings,
            persist_directory=self.persist_directory
        )
        self.raw_db.persist()
        if self.verbose:
            print(f"Built raw vector index with {len(processed_docs)} documents/chunks.")

    def _summarize_document_batch(self, prompts: list[str]) -> list[str]:
        ds = Dataset.from_dict({"text": prompts})
        if self.verbose:
            print("Dataset created for summarization prompts:")
            print(ds)
        try:
            # generate() returns an LLMResult; .generations holds one list of candidates per prompt
            result = self.llm.generate(prompts)
            return [gens[0].text if gens else "" for gens in result.generations]
        except Exception as e:
            print(f"Error during LLM summarization batch: {e}")
            return [""] * len(prompts)

    def _create_summary_documents(self, processed_docs: list[Document], summary_texts: list[str]) -> list[Document]:
        summary_docs = []
        for orig_doc, summary_text in zip(processed_docs, summary_texts):
            summary_doc = Document(
                page_content=summary_text,
                metadata={
                    "source_pdf": orig_doc.metadata.get("source_pdf", "Unknown"),
                    "bank": orig_doc.metadata.get("bank", "unknown"),
                    "orig_chunk": orig_doc.page_content[:50]
                }
            )
            summary_docs.append(summary_doc)
        return summary_docs

    def _build_summary_vector_index(self) -> None:
        processed_docs = []
        for doc in self.documents:
            processed_docs.extend(self._chunk_document_if_needed(doc))

        prompts = [
            "Summarize the following text in bullet points, focusing on key financials, sentiment, and forward-looking statements:\n\n" + d.page_content
            for d in processed_docs
        ]

        if self.verbose:
            print(f"Summarizing {len(prompts)} documents/chunks...")

        summary_texts = self._summarize_document_batch(prompts)
        summary_docs = self._create_summary_documents(processed_docs, summary_texts)

        self.summary_db = Chroma.from_documents(
            summary_docs,
            embedding=self.embeddings,
            persist_directory=self.summary_persist_dir
        )
        self.summary_db.persist()
        if self.verbose:
            print(f"Built summary vector index with {len(summary_docs)} summarized docs.")

    def _setup_summary_retriever(self) -> None:
        self.summary_retriever = self.summary_db.as_retriever(search_kwargs={"k": CHATBOT_CONFIG["retriever_search_k"]})

    def _setup_qa_chain(self) -> None:
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.summary_retriever,
            memory=self.memory,
            verbose=False
        )

    @staticmethod
    def remove_duplicate_chunks(chunks: list[Document]) -> list[Document]:
        seen = set()
        unique = []
        for chunk in chunks:
            text = chunk.page_content.strip()
            if text not in seen:
                seen.add(text)
                unique.append(chunk)
        return unique

    def truncate_context(self, context_list: list[Document], max_tokens: int = 800) -> list[Document]:
        truncated = []
        current = 0
        for doc in context_list:
            doc_tokens = len(self.tokenizer.encode(doc.page_content))
            if current + doc_tokens <= max_tokens:
                truncated.append(doc)
                current += doc_tokens
            else:
                break
        return truncated

    @staticmethod
    def clean_user_input(user_input: str) -> str:
        return user_input.strip().replace("\n", " ").replace("\t", " ")

    def reset_memory_if_needed(self) -> None:
        if len(self.memory.chat_memory.messages) > 6:
            if self.verbose:
                print("\nMemory Full: Resetting Conversation History...\n")
            self.memory.clear()

    def format_response(self, question: str, response: dict) -> str:
        if isinstance(response, dict) and 'answer' in response:
            resp_text = response['answer'].strip()
        else:
            resp_text = str(response).strip()
        for phrase in [
            "Use the following pieces of context",
            "If you don't know the answer, just say that you don't know",
            "Don't try to make up an answer."
        ]:
            if phrase in resp_text:
                resp_text = resp_text.split(phrase)[-1].strip()
        return f"Question: {question}\nHelpful Answer: {resp_text}"

    def _batch_summarize_context_docs(self, context: list[Document], per_doc_limit: int) -> list[str]:
        summarized_context = []
        batch_prompts = []
        for doc in context:
            doc_text = doc.page_content
            tokens = self.tokenizer.encode(doc_text)
            if len(tokens) > per_doc_limit:
                summary_prompt = f"Summarize the following text in a concise bullet-point format:\n\n{doc_text}"
                batch_prompts.append(summary_prompt)
            else:
                summarized_context.append(doc_text)
        if batch_prompts:
            ds = Dataset.from_dict({"text": batch_prompts})
            try:
                summary_responses_list = self.llm.generate(ds["text"])
                batch_summaries = [resp[0].page_content if isinstance(resp, list) and resp else str(resp) for resp in summary_responses_list]
                summarized_context.extend(batch_summaries)
            except Exception as e:
                print(f"Error during batch summarization of context documents: {e}")
                summarized_context.extend([""] * len(batch_prompts))
        return summarized_context

    def trim_final_input(self, question: str, context: list[Document], max_tokens: int = 512) -> str:
        system_message = (
            "You are analyzing a bank's quarterly earnings call transcript. "
            "Provide a bullet-point summary of the most important takeaways with specific details: "
            "list key revenue trends (include any percentage changes if available), major expense drivers, "
            "and management's outlook for the future. If numerical details are not available, provide qualitative insights."
        )
        per_doc_limit = max_tokens // max(1, len(context))
        summarized_context = self._batch_summarize_context_docs(context, per_doc_limit)
        context_str = "\n".join(summarized_context)
        input_text = f"{system_message}\n\nContext:\n{context_str}\n\nQuestion: {question}"
        # Encode once with truncation; skip_special_tokens keeps "</s>" etc. out of the returned text
        tokens = self.tokenizer.encode(input_text, truncation=True, max_length=max_tokens)
        if self.verbose:
            print(f"Final input token length: {len(tokens)}")
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def answer_question(self, question: str) -> str:
        self.reset_memory_if_needed()
        response_dict = self.qa_chain.invoke({"question": question})
        final_answer = response_dict["answer"].strip()
        return f"Question: {question}\nHelpful Answer: {final_answer}"

    def run_chatbot(self) -> None:
        print("\nBank Earnings Chatbot - Two Stage (type 'exit' to stop)")
        while True:
            user_input = input("\nYou: ")
            if user_input.strip().lower() == "exit":
                print("\nExiting Chatbot. Have a great day!")
                break
            answer = self.answer_question(user_input)
            print("\n" + answer)

    # ---------------- Evaluation Methods ----------------
    def evaluate_summaries(self, generated_summaries: dict, reference_summaries: dict, verbose: bool = False) -> dict:
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        eval_scores = {}
        for key, gen_dict in generated_summaries.items():
            gen_summary = gen_dict.get("summary", "")
            if key in reference_summaries:
                ref_summary = reference_summaries[key]
                scores = scorer.score(ref_summary, gen_summary)
                eval_scores[key] = scores
                if verbose:
                    print(f"Evaluation for {key}: {scores}")
            else:
                if verbose:
                    print(f"No reference summary provided for {key}.")
        return eval_scores

    def summarize_individual_transcripts(self) -> dict:
        # self.documents is only populated when an index was (re)built in __init__;
        # load the PDFs on demand when both indexes were reloaded from disk
        if not hasattr(self, "documents"):
            self._load_documents()
        prompts = []
        sources = []
        for doc in self.documents:
            source = doc.metadata.get('source_pdf', 'Unknown Source')
            bank = doc.metadata.get('bank', 'unknown')
            transcript_text = doc.page_content
            prompt = f"Please provide a bullet-point sentiment summary for the following transcript:\n\n{transcript_text}"
            prompts.append(prompt)
            sources.append((source, bank))
        if self.verbose:
            print(f"Batch summarizing {len(prompts)} transcript chunks...")
        ds = Dataset.from_dict({"text": prompts})
        if self.verbose:
            print("Dataset created for transcript summarization:")
            print(ds)
        try:
            # generate() returns an LLMResult; take the first candidate's text for each prompt
            result = self.llm.generate(prompts)
            responses = [gens[0].text if gens else "" for gens in result.generations]
        except Exception as e:
            print(f"Error during batch summarization of transcripts: {e}")
            responses = [""] * len(prompts)
        summaries = {}
        for (source, bank), response in zip(sources, responses):
            # Each PDF is split into several documents: concatenate their summaries
            # so later chunks do not overwrite earlier ones under the same source key
            if source in summaries:
                summaries[source]["summary"] += "\n" + response
            else:
                summaries[source] = {"bank": bank, "summary": response}
            if self.verbose:
                print(f"Summarized chunk from {source}")
        return summaries

    @staticmethod
    def group_summaries_by_bank_and_quarter(summaries: dict) -> dict:
        grouped = {}
        for source, info in summaries.items():
            bank = info["bank"]
            summary = info["summary"]
            match = re.search(
                r'((\d{1}[qQ]|[qQ]\d)|((first|second|third|fourth)[-_ ]?quarter))[-_ ]?(\d{2,4})',
                source,
                re.IGNORECASE
            )
            if match:
                if match.group(2):
                    quarter_raw = match.group(2).lower()
                    quarter_digit_match = re.search(r'\d', quarter_raw)
                    quarter_num = quarter_digit_match.group(0) if quarter_digit_match else "1"
                elif match.group(4):
                    word = match.group(4).lower()
                    mapping = {"first": "1", "second": "2", "third": "3", "fourth": "4"}
                    quarter_num = mapping.get(word, "1")
                else:
                    quarter_num = "1"
                year_str = match.group(5)
                year_val = int(year_str)
                if year_val < 100:
                    year_val += 2000
                key = f"{year_val}-Q{quarter_num}"
            else:
                key = "Unknown"
            if bank not in grouped:
                grouped[bank] = {}
            if key not in grouped[bank]:
                grouped[bank][key] = []
            grouped[bank][key].append(summary)
        return grouped

    def aggregate_quarterly_summaries_by_bank(self, grouped_quarterly: dict) -> dict:
        quarterly_aggregates = {}
        for bank, quarters in grouped_quarterly.items():
            quarterly_aggregates[bank] = {}
            for key, summaries in quarters.items():
                combined = "\n".join(summaries)
                prompt = (f"Based on the following quarterly sentiment summaries for {bank.upper()} ({key}), "
                          "provide a concise bullet-point overview of the overall sentiment for that quarter:\n\n" + combined)
                try:
                    # invoke() on a LangChain LLM returns the generated text as a plain string
                    answer_text = self.llm.invoke(prompt)
                except Exception as e:
                    print(f"Error during quarterly aggregation for {bank} - {key}: {e}")
                    answer_text = "Error in aggregation."
                quarterly_aggregates[bank][key] = answer_text
        return quarterly_aggregates

    def forecast_next_quarter_sentiment(self, bank: str, historical_quarterly: list[str]) -> str:
        combined = "\n".join(historical_quarterly)
        prompt = (
            f"Based on the following historical quarterly sentiment summaries for {bank.upper()}, "
            "forecast the overall sentiment for the next quarter. Provide a bullet-point summary of the expected trends, "
            "including any changes in tone, risk factors, or optimism:\n\n" + combined
        )
        try:
            # invoke() on a LangChain LLM returns the generated text as a plain string
            return self.llm.invoke(prompt)
        except Exception as e:
            print(f"Error forecasting next quarter sentiment for {bank}: {e}")
            return "Forecast unavailable due to an error."

    @staticmethod
    def parse_quarter_key(key: str) -> tuple[int, int]:
        try:
            parts = key.split("-")
            year = int(parts[0])
            quarter = int(re.search(r'\d', parts[1]).group(0))
            return year, quarter
        except Exception:
            return (0, 0)

    def analyze_and_forecast_sentiment_by_bank(self) -> dict:
        print("Generating individual transcript summaries...")
        summaries = self.summarize_individual_transcripts()
        print("Grouping summaries by bank and quarter...")
        grouped_quarterly = self.group_summaries_by_bank_and_quarter(summaries)
        print("Aggregating quarterly summaries...")
        quarterly_aggregates = self.aggregate_quarterly_summaries_by_bank(grouped_quarterly)

        analysis = {}
        for bank, quarters in quarterly_aggregates.items():
            valid_keys = [k for k in quarters.keys() if k != "Unknown"]
            if not valid_keys:
                # Skip the bank entirely so callers never receive an entry without its summary fields
                print(f"No valid quarter keys found for {bank}: {list(quarters.keys())}")
                continue

            sorted_keys = sorted(valid_keys, key=lambda k: self.parse_quarter_key(k))
            most_recent_key = sorted_keys[-1]
            current_year, current_quarter = self.parse_quarter_key(most_recent_key)

            years = sorted({self.parse_quarter_key(k)[0] for k in quarters if k != "Unknown"})
            previous_year = max([y for y in years if y < current_year], default=None)

            previous_year_summaries = []
            if previous_year is not None:
                for key in quarters:
                    year, _ = self.parse_quarter_key(key)
                    if year == previous_year:
                        # Each quarterly aggregate is one string; extend() would add it character by character
                        previous_year_summaries.append(quarters[key])
                combined_prev = "\n".join(previous_year_summaries)
                prompt_prev = (f"Based on the following sentiment summaries for all quarters in {previous_year} for {bank.upper()}, "
                               "provide a bullet-point summary of the overall sentiment trends for that year:\n\n" + combined_prev)
                try:
                    ds = Dataset.from_dict({"text": [prompt_prev]})
                    response_list_prev = self.llm.generate(ds["text"])
                    response_prev = response_list_prev[0] if response_list_prev else ""
                    if isinstance(response_prev, list) and response_prev and 'answer' in response_prev[0]:
                        previous_year_summary = response_prev[0]['answer']
                    elif isinstance(response_prev, dict) and 'answer' in response_prev:
                        previous_year_summary = response_prev['answer']
                    else:
                        previous_year_summary = str(response_prev)
                except Exception as e:
                    print(f"Error summarizing previous year sentiment for {bank}: {e}")
                    previous_year_summary = "Error summarizing previous year."
            else:
                previous_year_summary = "Not available (no previous year data)."

            # Each quarter holds a list of summaries; join them into one string for display.
            current_summary = "\n".join(quarters.get(most_recent_key, [])) or "Not available"

            historical = []
            for key in sorted_keys:
                historical.extend(quarters[key])
            try:
                forecast = self.forecast_next_quarter_sentiment(bank, historical) if historical else "Not available (no historical data for forecast)."
            except Exception as e:
                print(f"Error forecasting next quarter sentiment for {bank}: {e}")
                forecast = "Forecast unavailable due to an error."

            analysis[bank] = {
                "previous_year_summary": previous_year_summary,
                "current_quarter_summary": current_summary,
                "forecast_next_quarter": forecast,
                "most_recent_key": most_recent_key
            }
        return analysis


if __name__ == "__main__":
    pdf_folders = CHATBOT_CONFIG["pdf_folders"]
    persist_directory = CHATBOT_CONFIG["persist_directory"]

    chatbot = BankEarningsChatbotTwoStage(
        pdf_folders,
        persist_directory=persist_directory,
        test_mode=False,
        rebuild_index=True,
        verbose=True
    )

    analysis = chatbot.analyze_and_forecast_sentiment_by_bank()

    print("\nYearly and Quarterly Sentiment Analysis by Bank:")
    for bank, data in analysis.items():
        print(f"\n{bank.upper()} Analysis:")
        print(f"Most Recent Quarter Key: {data.get('most_recent_key', 'N/A')}")
        print("Previous Year Sentiment Summary:")
        print(data["previous_year_summary"])
        print("Current Quarter Sentiment Summary:")
        print(data["current_quarter_summary"])
        print("Forecast for Next Quarter:")
        print(data["forecast_next_quarter"])
        print("-" * 60)

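    # ----------------------------
    # Evaluation example: score the generated summaries against reference summaries with ROUGE.
    # The reference texts below are placeholders; replace them with real summaries before use.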
    reference_summaries = {
        "1q23-earnings-transcript.pdf": "Reference summary for 1Q23 transcript...",
        "4q24-earnings-call-remarks.pdf": "Reference summary for 4Q24 transcript..."
    }

    generated_summaries = chatbot.summarize_individual_transcripts()
    evaluation_results = chatbot.evaluate_summaries(generated_summaries, reference_summaries, verbose=True)
    print("\nEvaluation Results (ROUGE):")
    for key, scores in evaluation_results.items():
        print(f"{key}: {scores}")

    # To run the chatbot interactively:
    # chatbot.run_chatbot()


Device set to use cuda:0
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loaded: 1q23-earnings-transcript.pdf from /content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan
Loaded: 2q23-earnings-transcript.pdf from /content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan
Loaded: 4q24-earnings-transcript.pdf from /content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan
Loaded: jpm-1q24-earnings-call-transcript.pdf from /content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan
Loaded: jpm-2q24-earnings-call-transcript-final.pdf from /content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan
Loaded: jpm-3q23-earnings-call-transcript.pdf from /content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan
Loaded: jpm-4q23-earnings-call-transcript.pdf from /content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan
Loaded: jpmc-third-quarter-2024-earnings-conference-call-transcript.pdf from /content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan
Loaded: 1q23-earnings-call-remarks.pdf from /content/drive/MyDrive/BOE/bank_of_england/data/raw/ubs

Input ids are automatically padded from 818 to 832 to be a multiple of `config.block_size`: 64


Summarizing 480 documents/chunks...
Dataset created for summarization prompts:
Dataset({
    features: ['text'],
    num_rows: 480
})


Attention type 'block_sparse' is not possible if sequence_length: 659 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
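
The two attention messages in the log above follow from simple arithmetic on the values they print (`config.block_size` = 64, `config.num_random_blocks` = 3). They appear to come from a Hugging Face model that uses BigBird-style block-sparse attention. The sketch below only reproduces the numbers shown in the log; `padded_length` and `block_sparse_min_length` are illustrative names, not library functions.

```python
import math

# Values taken from the log output above.
BLOCK_SIZE = 64
NUM_RANDOM_BLOCKS = 3


def padded_length(seq_len: int, block_size: int = BLOCK_SIZE) -> int:
    """Round seq_len up to the next multiple of block_size."""
    return math.ceil(seq_len / block_size) * block_size


def block_sparse_min_length(block_size: int = BLOCK_SIZE,
                            num_random_blocks: int = NUM_RANDOM_BLOCKS) -> int:
    """Add up the terms listed in the 'block_sparse is not possible' warning."""
    global_tokens = 2 * block_size
    sliding_tokens = 3 * block_size + num_random_blocks * block_size
    buffer = num_random_blocks * block_size
    return global_tokens + sliding_tokens + buffer


print(padded_length(818))                # 832
print(block_sparse_min_length())         # 704
print(659 <= block_sparse_min_length())  # True, so attention falls back to 'original_full'
```

With these settings, any input of 704 tokens or fewer cannot use block-sparse attention and runs with full attention instead, while the 818-token input is padded to 832 because that is the smallest multiple of 64 that is at least 818.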

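The year-over-year logic in `analyze_and_forecast_sentiment_by_bank` (drop the "Unknown" bucket, sort keys chronologically, take the latest quarter, then gather every summary from the closest earlier year) can be sketched in isolation. The notebook's `parse_quarter_key` is not shown in this section, so the version below is a hypothetical stand-in that assumes keys of the form `YYYY-Qn` and returns a `(year, quarter)` tuple; the real key format may differ. `collect_previous_year` is likewise a hypothetical helper.

```python
from typing import Dict, List, Optional, Tuple


def parse_quarter_key(key: str) -> Tuple[int, int]:
    """Hypothetical parser for keys such as '2024-Q3' -> (2024, 3)."""
    year_part, quarter_part = key.split("-Q")
    return int(year_part), int(quarter_part)


def collect_previous_year(
    quarters: Dict[str, List[str]]
) -> Tuple[Optional[str], Optional[int], List[str]]:
    """Return (most recent key, previous year, all summaries from that previous year)."""
    valid_keys = [k for k in quarters if k != "Unknown"]
    if not valid_keys:
        return None, None, []
    # Sorting on the (year, quarter) tuple gives chronological order.
    sorted_keys = sorted(valid_keys, key=parse_quarter_key)
    most_recent = sorted_keys[-1]
    current_year, _ = parse_quarter_key(most_recent)
    earlier_years = [parse_quarter_key(k)[0] for k in valid_keys
                     if parse_quarter_key(k)[0] < current_year]
    prev_year = max(earlier_years, default=None)
    summaries: List[str] = []
    for key in sorted_keys:
        if parse_quarter_key(key)[0] == prev_year:
            summaries.extend(quarters[key])
    return most_recent, prev_year, summaries


example = {
    "2023-Q4": ["Q4 2023 tone: cautious"],
    "2024-Q1": ["Q1 2024 tone: stable"],
    "2023-Q1": ["Q1 2023 tone: mixed"],
    "Unknown": ["unparsed chunk"],
}
print(collect_previous_year(example))
# ('2024-Q1', 2023, ['Q1 2023 tone: mixed', 'Q4 2023 tone: cautious'])
```

Sorting on the parsed tuple rather than the raw string keeps the order chronological even if the key format does not sort lexically.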

In [None]:


# pdf_folders = [
#     "/content/drive/MyDrive/BOE/bank_of_england/data/raw/jpmorgan",
#     "/content/drive/MyDrive/BOE/bank_of_england/data/raw/ubs"
# ]

# persist_directory = "/content/drive/MyDrive/BOE/bank_of_england/data/model_outputs"

# chatbot = BankEarningsChatbotTwoStage(
#     pdf_folders=pdf_folders,
#     persist_directory=persist_directory,
#     max_length=256,
#     test_mode=False,
#     rebuild_index=True,    # If True, it reprocesses everything
#     verbose=True,
#     chunk_size=1000,
#     chunk_overlap=100
# )

# # Now you have your two-stage chatbot.


# The commented-out cell above shows how to instantiate the chatbot object with explicit
# chunking parameters. The next cell launches an interactive chatbot session.

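The commented-out configuration sets `chunk_size=1000` and `chunk_overlap=100`. As a rough illustration of what those two parameters mean, the sketch below splits text into fixed-size windows, each starting `chunk_size - chunk_overlap` characters after the previous one. It assumes lengths are measured in characters (LangChain's default `length_function` is `len`) and is a simplification: LangChain's splitters also try to break on separators such as paragraph boundaries, so real chunk lengths vary. `sliding_window_chunks` is a hypothetical helper, not the splitter the notebook uses.

```python
from typing import List


def sliding_window_chunks(text: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> List[str]:
    """Split text into fixed-size character windows that overlap by chunk_overlap."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    # With the defaults, each new chunk starts 900 characters after the previous one.
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


print([len(c) for c in sliding_window_chunks("x" * 2500)])            # [1000, 1000, 700]
print(sliding_window_chunks("abcdefghij", chunk_size=4, chunk_overlap=1))  # ['abcd', 'defg', 'ghij']
```

The overlap means a sentence that straddles a chunk boundary still appears intact in at least one chunk more often, at the cost of storing some text twice in the vector index.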
In [None]:
chatbot.run_chatbot()
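
The evaluation step prints ROUGE scores, and the install cell pulls in the `rouge-score` package. To show what a ROUGE-1 score measures, here is a simplified pure-Python version that counts clipped unigram overlap between a reference and a generated summary. It only lowercases and splits on whitespace, whereas `rouge-score` applies its own tokenizer and optional stemming, so its numbers will not match exactly. `rouge1` is a hypothetical helper for illustration.

```python
from collections import Counter
from typing import Dict


def rouge1(reference: str, candidate: str) -> Dict[str, float]:
    """Simplified ROUGE-1 precision, recall and F1 from unigram overlap."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    # Clipped overlap: each token counts at most as often as it appears in both texts.
    overlap = sum((ref_counts & cand_counts).values())
    if overlap == 0:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    f1 = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}


scores = rouge1("net interest income rose", "net interest income fell sharply")
print(scores)  # 3 shared unigrams: precision 3/5, recall 3/4
```

Precision penalises padding in the generated summary and recall penalises omissions, which is why F1 is the figure usually reported when comparing summarizers.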