#### Importing relevant libraries

In [None]:
import warnings
# ignore warnings
warnings.filterwarnings("ignore")

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.llms import Ollama
from langchain_text_splitters import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_chroma import Chroma
from custom_directory_loaders import DocxDirectoryLoader
from sentence_transformers import CrossEncoder
from textwrap import dedent
from dotenv import load_dotenv
from hashlib import sha256
from typing import List
from IPython.display import display, Markdown
import chromadb
import numpy as np
import os
import shutil

# Load variables from a local .env file into the environment so that the
# os.getenv("HF_TOKEN") calls further down can find the Hugging Face token.
load_dotenv()

#### Helper functions for loading, chunking and indexing documents

In [None]:
def load_documents(file_directory: str, loaders: List) -> List[Document]:
    """Load every document under ``file_directory`` with each of the given loaders.

    Args:
        file_directory: Directory handed to each loader's constructor.
        loaders: Loader classes (or factories); each is called as
            ``loader(file_directory)`` and must expose a ``.load()`` method.

    Returns:
        The concatenation of all loaders' results, in loader order.
    """
    return [
        document
        for loader_cls in loaders
        for document in loader_cls(file_directory).load()
    ]


def prepare_documents(documents: List[Document], chunk_size: int = 500, chunk_overlap: int = 200,
                      raw_dir: str = "./files_to_process/raw",
                      saved_dir: str = "./files_to_process/saved") -> List[Document]:
    """Split documents into chunks and archive the processed source files.

    Args:
        documents: Documents to split.
        chunk_size: Maximum chunk size passed to ``CharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks.
        raw_dir: Directory whose entries are archived after splitting
            (defaults to the notebook's raw input directory).
        saved_dir: Archive directory; created if it does not exist.

    Returns:
        The chunked documents.

    Note:
        Every entry in ``raw_dir`` is moved, not only the files that produced
        ``documents``, and the move happens before the chunks are indexed.
    """
    splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    docs = splitter.split_documents(documents)

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(saved_dir, exist_ok=True)

    # Move the processed inputs out of raw_dir so the loading cell does not pick them up again.
    raw_path = os.path.abspath(raw_dir)
    saved_path = os.path.abspath(saved_dir)
    for entry in os.listdir(raw_path):
        shutil.move(os.path.join(raw_path, entry), os.path.join(saved_path, entry))

    return docs


def add_to_vector_db(docs: List[Document], collection_name: str,
                     embeddings: HuggingFaceInferenceAPIEmbeddings) -> Chroma:
    """Embed ``docs`` and store them in a persistent Chroma collection.

    Args:
        docs: Chunked documents to index.
        collection_name: Name of the Chroma collection to add to (created if missing).
        embeddings: Embedding model used to embed the chunks. This is the object
            the caller passes; it is no longer replaced by a hard-coded model.

    Returns:
        A LangChain ``Chroma`` vector store bound to the collection.
    """
    # Persistent on-disk index shared with get_vector_db.
    os.makedirs("./chroma_db_index", exist_ok=True)
    chroma_client = chromadb.PersistentClient(path="./chroma_db_index")

    # Get/Create the collection.
    chroma_client.get_or_create_collection(
        name=collection_name,
    )

    # Ids are the SHA-256 of each chunk's text. Chunks with identical text get
    # identical ids, and Chroma rejects a batch containing duplicate ids, so
    # keep only the first chunk for each distinct text (dicts keep insertion order).
    unique_docs = {}
    for doc in docs:
        doc_id = sha256(doc.page_content.encode("utf-8")).hexdigest()
        unique_docs.setdefault(doc_id, doc)

    db = Chroma.from_documents(
        client=chroma_client,
        collection_name=collection_name,
        documents=list(unique_docs.values()),
        embedding=embeddings,
        ids=list(unique_docs.keys()),
    )

    return db


def get_vector_db(collection_name: str, embeddings: HuggingFaceInferenceAPIEmbeddings) -> Chroma:
    """Open the persisted Chroma index when there are no new documents to add.

    Args:
        collection_name: Chroma collection to attach to.
        embeddings: Embedding model used to embed queries against the collection.

    Returns:
        A LangChain ``Chroma`` vector store bound to the collection.

    Raises:
        FileNotFoundError: If the ``./chroma_db_index`` directory does not exist.
    """
    index_path = "./chroma_db_index"
    if os.path.exists(index_path):
        # NOTE(review): if the directory exists but the collection does not,
        # this appears to yield an empty collection rather than an error — confirm.
        return Chroma(
            client=chromadb.PersistentClient(path=index_path),
            collection_name=collection_name,
            embedding_function=embeddings,
        )
    raise FileNotFoundError("No vector database found. Please add documents to the vector database.")

#### Loading our documents and storing them in a vector database (ChromaDB) 

In [None]:
# Loading our documents from the directory
loaders = [DocxDirectoryLoader, PyPDFDirectoryLoader]

raw_directory = "./files_to_process/raw/"
os.makedirs(raw_directory, exist_ok=True)
documents = load_documents(file_directory=raw_directory, loaders=loaders)

# Embeddings for indexing
embeddings = HuggingFaceInferenceAPIEmbeddings(model="sentence-transformers/all-MiniLM-l6-v2",
                                               api_key=os.getenv("HF_TOKEN"))
# CHROMA Collection Name
collection_name = "test_collection"

# Files can load but still yield no chunks (e.g. no extractable text), so
# decide on the chunks rather than the raw documents. This avoids an
# IndexError on docs[0] and an empty batch being sent to Chroma.
docs = prepare_documents(documents) if documents else []

if docs:
    print(f"Number of documents: {len(docs)}")
    print("Example of a document: \n", docs[0].page_content)
    db = add_to_vector_db(docs, collection_name=collection_name, embeddings=embeddings)
else:
    # Nothing new to index: reuse the existing persisted vector database.
    db = get_vector_db(collection_name=collection_name, embeddings=embeddings)

# Similarity search returning the 10 closest chunks; the reranker narrows these down later.
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 10})

In [None]:
# Quick sanity check: show the raw (not yet reranked) chunks the retriever returns.
sample_question = "Give me a brief summary of the document."
response = retriever.invoke(sample_question)
response

#### Reranking the retrieved chunks with a cross-encoder

In [None]:
# Cross-encoder that scores (question, passage) pairs inside `reranker`.
# The model id is passed positionally as the constructor's first argument.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [None]:
# Number of documents to use as context in the prompt
num_docs = 3

def reranker(retrieved_documents: List[Document], question: "str | None" = None,
             top_n: "int | None" = None) -> List[Document]:
    """Re-score retrieved documents with the cross-encoder and keep the best ones.

    Args:
        retrieved_documents: Candidate documents from the retriever.
        question: Text to score the documents against. When omitted, the
            notebook-level global ``query`` is used, as before, so
            ``RunnableLambda(reranker)`` keeps working unchanged.
        top_n: Number of documents to keep. Defaults to the global ``num_docs``.

    Returns:
        Up to ``top_n`` documents, ordered from highest to lowest cross-encoder score.
    """
    if not retrieved_documents:
        # Nothing to rank; also avoids calling the cross-encoder on an empty batch.
        return []

    # `query` here is the module-level global (not shadowed by a parameter).
    text = question if question is not None else query
    keep = num_docs if top_n is None else top_n

    pairs = [[text, doc.page_content] for doc in retrieved_documents]
    scores = cross_encoder.predict(pairs)

    # argsort is ascending: reverse it to get indices from best to worst score,
    # then take the first `keep` indices. The previous comprehension selected the
    # documents at the *ranks* of the originally-first items, not the best-scored ones.
    best_first = np.argsort(scores)[::-1][:keep]
    return [retrieved_documents[i] for i in best_first]

#### Preparing our prompt and LLM

In [None]:
# Prompt text for the LLM. The opening backslash puts the first line at the
# same indentation as the rest. Without it, the unindented first line leaves
# `dedent` no common prefix to strip, and every other line keeps its leading spaces.
template = dedent(
    """\
    You are an assistant for question-answering tasks. Use the following pieces of retrieved context to aid in answering the question.
    Keep the answer clear and concise, and support with examples if possible. Return the answer in a markdown format.
    Question: {question}
    Context: {context}
    Answer:"""
)

# The whole template is sent as a single human message (no separate system message).
human_turn = ("human", template)
prompt = ChatPromptTemplate.from_messages([human_turn])

##### Option 1: running on a hosted LLM API (Google Gemini)

In [None]:
# Hosted Gemini chat model: temperature 0 and at most 2048 output tokens.
model = ChatGoogleGenerativeAI(
    model="gemini-pro",
    temperature=0,
    max_output_tokens=2048,
)

##### Option 2: running locally with Ollama (run only one of these two cells; whichever runs last defines `model`)

In [None]:
# Local alternative: llama3 through Ollama. Whichever model cell runs last
# defines `model` for the chain below.
model = Ollama(
    model="llama3",
    temperature=0,
)

#### Creating our chain

In [None]:
def format_docs(docs: List[Document]) -> str:
    """Join the page contents of ``docs`` with blank lines into one context string for the prompt."""
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)

# Context branch: retrieve -> rerank -> join into a single string.
# NOTE: `reranker` receives only the documents here, so it scores them against
# the notebook-level `query` variable, not the text passed to rag_chain.invoke.
# Keep the two in sync.
context_pipeline = retriever | RunnableLambda(reranker) | format_docs
rag_chain = (
    {"context": context_pipeline, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

In [None]:
# `reranker` reads this module-level `query`, so define it before invoking the chain.
query = """
What is Message Passing Interface?
"""

response = rag_chain.invoke(query)
rendered = Markdown(response)
display(rendered)