In [19]:
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate, ChatPromptTemplate,MessagesPlaceholder
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.memory import ChatMessageHistory
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor

import numpy as np

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, AIMessage
from pydantic import BaseModel, Field
from dotenv import load_dotenv

In [None]:
OPENAI_API_KEY = "YOUR_OPENAI"

In [5]:
loader = DirectoryLoader("/Users/yogeshagrawal/Desktop/Gen AI/25.Medical_Chat_bot/Data", glob="*.pdf", loader_cls=PyPDFLoader)
docs = loader.load()


text_splitter = RecursiveCharacterTextSplitter(chunk_size = 500, chunk_overlap = 50)
texts = text_splitter.split_documents(docs)

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

vectorstore = FAISS.from_documents(texts, embeddings)

llm_model = ChatOpenAI()


parser = StrOutputParser()

In [20]:
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k":20})
compressor = LLMChainExtractor.from_llm(llm_model)

compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=retriever  # your existing retriever
)

def format_docs(retrieved_docs):
    metadata = []
    for docs in retrieved_docs:
        source = docs.metadata['source']
        page_label = docs.metadata['page_label']
        metadata.append({"source":source,"page_label": page_label})

    display(metadata)
    
    return "\n\n".join([doc.page_content for doc in retrieved_docs])


def dot_product(a, b):
    return np.dot(a, b)

def rerank_openai(query: str, docs, embeddings_model, top_k=5):
    query_embedding = embeddings_model.embed_query(query)
    
    doc_embeddings = [embeddings_model.embed_query(doc.page_content) for doc in docs]
    scores = [dot_product(query_embedding, doc_emb) for doc_emb in doc_embeddings]
    
    ranked_docs = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
    print(ranked_docs)
    return [doc for _, doc in ranked_docs[:top_k]]

def retrieve_and_rerank(inputs):
    raw_docs = compression_retriever.invoke(inputs["question"])
    reranked = rerank_openai(inputs["question"], raw_docs,embeddings)
    return reranked


reranked_retriever = RunnableLambda(retrieve_and_rerank)

In [18]:
system_prompt = (
    """
    You are a helpful medical assistant for question answering task.use the following pieces of context to answer the answer.If you don't know the answer, just say "I don't know".
    \n\n
    {context}
    """
)

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("user", "{chat_history}"),
        ("human", "{input}")
    ]
)

In [12]:

# Dictionary to hold memory per session
message_histories = {}

def get_session_history(session_id):
    return message_histories.setdefault(session_id, ChatMessageHistory())

# class InMemoryHistory(BaseChatMessageHistory, BaseModel):
#     """In memory implementation of chat message history."""

#     messages: list[BaseMessage] = Field(default_factory=list)

#     def add_messages(self, messages: list[BaseMessage]) -> None:
#         """Add a list of messages to the store"""
#         self.messages.extend(messages)

#     def clear(self) -> None:
#         self.messages = []

# store = {}
# def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
#     if session_id not in store:
#         store[session_id] = InMemoryHistory()
#     return store[session_id]


In [21]:
# parallel_chain = RunnableParallel({
#    "context": RunnableLambda(lambda x: x["question"]) | retriever | RunnableLambda(format_docs),
#     "input": RunnableLambda(lambda x: x["question"]),
#     "chat_history": RunnableLambda(lambda x: x.get("chat_history", ""))
# })

parallel_chain = RunnableParallel({
   "context": reranked_retriever | RunnableLambda(format_docs),
    "input": RunnableLambda(lambda x: x["question"]),
    "chat_history": RunnableLambda(lambda x: x.get("chat_history", ""))
})

# result = parallel_chain.invoke('what is LLama 2 ?')

main_chain = parallel_chain | prompt | llm_model | parser



# Wrap the chain with memory management
rag_with_memory = RunnableWithMessageHistory(
    main_chain,
    get_session_history,
    input_messages_key="question",
    history_messages_key="chat_history",
)

response = rag_with_memory.invoke(
    {"question": "tell me about aids ?"},
    config={"configurable": {"session_id": "user-001"}}
)

print(response)

# # # --- Run the chain ---
# print(main_chain.invoke("What is ?"))

[(np.float64(0.5918293882849612), Document(metadata={'producer': 'PDFlib+PDI 5.0.0 (SunOS)', 'creator': 'PyPDF', 'creationdate': '2004-12-18T17:00:02-05:00', 'moddate': '2004-12-18T16:15:31-06:00', 'source': '/Users/yogeshagrawal/Desktop/Gen AI/25.Medical_Chat_bot/Data/Medical_book.pdf', 'total_pages': 637, 'page': 89, 'page_label': '90'}, page_content='AIDS is usually marked by a very low number of CD4+ lymphocytes, followed by a rise in the frequency of opportunistic infections and cancers. Doctors monitor the number and proportion of CD4+ lymphocytes in the patient’s blood in order to assess the progression of the disease and the effectiveness of different medications.')), (np.float64(0.5862299739266871), Document(metadata={'producer': 'PDFlib+PDI 5.0.0 (SunOS)', 'creator': 'PyPDF', 'creationdate': '2004-12-18T17:00:02-05:00', 'moddate': '2004-12-18T16:15:31-06:00', 'source': '/Users/yogeshagrawal/Desktop/Gen AI/25.Medical_Chat_bot/Data/Medical_book.pdf', 'total_pages': 637, 'page':

[{'source': '/Users/yogeshagrawal/Desktop/Gen AI/25.Medical_Chat_bot/Data/Medical_book.pdf',
  'page_label': '90'},
 {'source': '/Users/yogeshagrawal/Desktop/Gen AI/25.Medical_Chat_bot/Data/Medical_book.pdf',
  'page_label': '87'},
 {'source': '/Users/yogeshagrawal/Desktop/Gen AI/25.Medical_Chat_bot/Data/Medical_book.pdf',
  'page_label': '95'},
 {'source': '/Users/yogeshagrawal/Desktop/Gen AI/25.Medical_Chat_bot/Data/Medical_book.pdf',
  'page_label': '101'},
 {'source': '/Users/yogeshagrawal/Desktop/Gen AI/25.Medical_Chat_bot/Data/Medical_book.pdf',
  'page_label': '96'}]

AIDS (Acquired Immune Deficiency Syndrome) is an infectious disease caused by the human immunodeficiency virus (HIV). It is the advanced stage of HIV infection, characterized by a severe depletion of CD4+ lymphocytes, leading to opportunistic infections and cancers. Regular monitoring of CD4+ cell count and viral load is essential for managing the disease progression and treatment effectiveness.
