In [12]:

from utils.agent_helper import *
from utils.models import REASONING_LLM, RAG_LLM, EMBEDDING_MODEL
from utils.rag import *
from utils.vector_store import *
from utils.prompts import *


from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_community.tools.semanticscholar.tool import SemanticScholarQueryRun
from langchain_community.utilities.semanticscholar import SemanticScholarAPIWrapper
from langgraph.graph import END, StateGraph
from langchain_core.tools import tool



import functools
import operator
from typing import Annotated, List, TypedDict
import os

In [13]:
from dotenv import load_dotenv

# Load variables from a local .env file into os.environ (presumably the API keys
# for the LLM / embedding / Semantic Scholar services -- confirm).
# NOTE(review): the first cell imports utils.models before this cell runs; if that
# module reads credentials at import time, load_dotenv() must run before it.
load_dotenv()

True

### Semantic Scholar search agent

In [14]:
# Semantic Scholar search tool; the wrapper is limited to 2 results / 2 loaded docs.
api_wrapper = SemanticScholarAPIWrapper(top_k_results=2, load_max_docs=2)
semantic_query_tool = SemanticScholarQueryRun(api_wrapper=api_wrapper)

# "ScholarQuery" worker: the reasoning LLM with the Semantic Scholar tool as its only
# tool, wrapped into a graph-node callable under that name.
query_agent = create_agent(
    REASONING_LLM,
    [semantic_query_tool],
    QUERY_AGENT_PROMPT,
)
query_node = functools.partial(agent_node, name="ScholarQuery", agent=query_agent)


### Research agent

In [4]:
from langchain_core.output_parsers import StrOutputParser
from operator import itemgetter


def create_rag_chain(rag_prompt_template, vector_store, llm, k: int = 5):
    """Build a question-answering RAG chain over ``vector_store``.

    The chain expects an input mapping with a ``"question"`` key. The question
    string is sent to the vector-store retriever (top ``k`` matches) to produce
    ``context``, and is also passed through as ``question``. Both are rendered
    with ``rag_prompt_template``, answered by ``llm`` and parsed to a plain string.

    Args:
        rag_prompt_template: prompt runnable fed ``context`` and ``question``
            (presumably a ChatPromptTemplate -- confirm against callers).
        vector_store: any store exposing ``as_retriever(search_kwargs=...)``.
        llm: LLM / chat-model runnable that produces the answer.
        k: number of chunks to retrieve. The default of 5 keeps the previous
            hard-coded behaviour.

    Returns:
        A runnable whose ``invoke({"question": ...})`` returns a ``str``.

    NOTE(review): the first cell does ``from utils.rag import *``. In the saved
    session that cell ran *after* this one (In [12] vs In [4]). If ``utils.rag``
    also exports ``create_rag_chain``, it may have shadowed this definition for
    the recorded outputs. On a fresh Restart & Run All, this local version wins.
    """
    retriever = vector_store.as_retriever(search_kwargs={"k": k})
    # The retriever gets only the question string, not the whole input mapping.
    retrieval_inputs = {
        "context": itemgetter("question") | retriever,
        "question": itemgetter("question"),
    }
    return retrieval_inputs | rag_prompt_template | llm | StrOutputParser()

In [15]:
# Chunk the (shortened) diabetes textbook for retrieval.
textbook_documents = get_markdown_documents(
    'data/text_books/Textbook-of-Diabetes-2024-shortened.pdf',
    chunk_size=500,
    chunk_overlap=50,
)

# Prompt + vector store (own collection) + LLM for the textbook RAG chain.
rag_runnables = RAGRunnables(
    rag_prompt_template=ChatPromptTemplate.from_template(TEXTBOOK_RAG_PROMPT),
    vector_store=get_vector_store(
        textbook_documents,
        EMBEDDING_MODEL,
        emb_dim=384,
        collection_name='textbook_collection',
    ),
    llm=RAG_LLM,
)
textbook_chain = create_rag_chain(
    rag_runnables.rag_prompt_template,
    rag_runnables.vector_store,
    rag_runnables.llm,
)


# NOTE(review): a later cell redefines this tool under the same name; that later
# definition is the one passed to the research agent.
@tool
def retrieve_textbook_information(
    query: Annotated[str, "query to ask the retrieve information tool"]
):
    """Use Retrieval Augmented Generation to retrieve information about the book 'Textbook of Diabetes'."""
    chain_input = {"question": query}
    return textbook_chain.invoke(chain_input)


Processing data/text_books/Textbook-of-Diabetes-2024-shortened.pdf...


In [16]:
folder = 'data/literature'
# os.listdir() returns entries in arbitrary order. Sort them so ingestion order is
# reproducible across machines and runs. Also skip hidden files and
# sub-directories (e.g. .DS_Store, .ipynb_checkpoints) so that only real
# documents reach the loader.
paths = [
    os.path.join(folder, file)
    for file in sorted(os.listdir(folder))
    if not file.startswith('.') and os.path.isfile(os.path.join(folder, file))
]

# Chunk every paper and pool all chunks into one list for a single collection.
paper_documents = []
for path in paths:
    document = get_markdown_documents(path, chunk_size=500, chunk_overlap=50)
    paper_documents.extend(document)

# Prompt + vector store ('paper_collection') + LLM for the paper RAG chain.
rag_runnables = RAGRunnables(
    rag_prompt_template=ChatPromptTemplate.from_template(PAPER_RAG_PROMPT),
    vector_store=get_vector_store(paper_documents, EMBEDDING_MODEL, emb_dim=384, collection_name='paper_collection'),
    llm=RAG_LLM,
)

paper_chain = create_rag_chain(
    rag_runnables.rag_prompt_template,
    rag_runnables.vector_store,
    rag_runnables.llm,
)


# NOTE(review): a later cell redefines this tool under the same name; that later
# definition is the one passed to the research agent.
@tool
def retrieve_paper_information(
    query: Annotated[str, "query to ask the retrieve information tool"]
):
    """Use Retrieval Augmented Generation to retrieve information about the papers provided."""
    return paper_chain.invoke({"question" : query})


Processing data/literature/s41591-023-02278-8.pdf...
Processing data/literature/PIIS1550413121006318.pdf...


In [62]:
# System prompt for the local-document research agent. It is read when
# `research_agent` is created, so this cell must run before that one. In the
# saved session it ran later (In [62] vs In [19]). The recorded agent may
# therefore have used a different RESEARCH_AGENT_PROMPT (possibly one exported
# by `from utils.prompts import *` -- TODO confirm).
RESEARCH_AGENT_PROMPT = """You are a research assistant who can provide specific information on the documents received. You must only respond with information about the documents related to the request. Make sure every document is covered."""


In [18]:
def _rag_answer(rag_chain, query):
    """Invoke ``rag_chain`` with ``query`` and return the answer as text.

    Handles both chain output shapes this notebook can end up with:

    * a plain ``str``: the local ``create_rag_chain`` ends in ``StrOutputParser``;
    * a mapping with a ``"response"`` entry: presumably what a ``utils.rag``
      variant returns (TODO confirm).

    Indexing a ``str`` with ``['response']`` raises ``TypeError``, so the key is
    only read from mapping results.
    """
    result = rag_chain.invoke({"question": query})
    if isinstance(result, dict):
        result = result["response"]
    # Chat-model messages carry their text in ``.content``; plain strings pass through.
    return getattr(result, "content", result)


# These definitions replace the same-named tools from the RAG cells above; these
# are the versions handed to the research agent.
@tool
def retrieve_paper_information(
    query: Annotated[str, "query to ask the retrieve information tool"]
):
    """Use Retrieval Augmented Generation to retrieve information about the papers provided."""
    return _rag_answer(paper_chain, query)


@tool
def retrieve_textbook_information(
    query: Annotated[str, "query to ask the retrieve information tool"]
):
    """Use Retrieval Augmented Generation to retrieve information about the book 'Textbook of Diabetes'."""
    return _rag_answer(textbook_chain, query)

In [19]:
# Local-document researcher: answers using the paper and textbook RAG tools.
research_agent = create_agent(REASONING_LLM, [retrieve_paper_information, retrieve_textbook_information], RESEARCH_AGENT_PROMPT)

# Graph-node callable. "LocalInformationRetriever" must match the supervisor's
# member list and the routing map used when wiring the graph.
research_node = functools.partial(
    agent_node,
    name="LocalInformationRetriever",
    agent=research_agent,
)

### Supervisor agent

In [20]:
# Supervisor that chooses which worker acts next. The member names listed here
# must match the keys of the routing map passed to add_conditional_edges when
# the graph is wired ("ScholarQuery", "LocalInformationRetriever"). That map also
# has a "FINISH" key; presumably the supervisor emits it to end the run -- confirm
# against create_team_supervisor.
supervisor_agent = create_team_supervisor(
    REASONING_LLM,
    SUPERVISOR_PROMPT,
    ["ScholarQuery", "LocalInformationRetriever"],
)

  | llm.bind_functions(functions=[function_def], function_call="route")


### Graph state

In [21]:
class ResearchTeamState(TypedDict):
    """Shared state passed between the Supervisor and worker nodes of the graph."""

    # Conversation so far. The operator.add reducer makes LangGraph concatenate
    # each node's returned messages onto this list instead of replacing it.
    messages: Annotated[List[BaseMessage], operator.add]
    # Worker names. NOTE(review): enter_chain does not set this, and no visible
    # code reads it -- confirm it is still needed.
    team_members: List[str]
    # Route chosen by the Supervisor; next_step() returns it to pick the next node.
    next: str

In [22]:
def next_step(state):
    """Conditional-edge router: return the route name stored under ``state['next']``.

    The Supervisor's output sets ``next``; the value is looked up in the routing
    map of ``add_conditional_edges``. Raises ``KeyError`` if ``next`` is missing.
    """
    chosen_route = state["next"]
    return chosen_route

In [23]:
# Wire the graph: both workers report back to the Supervisor. The Supervisor then
# routes to a worker or ends the run, based on state["next"] (via next_step).
graph = StateGraph(ResearchTeamState)
graph.add_node("Research", research_node)
graph.add_node("Query", query_node)
graph.add_node("Supervisor", supervisor_agent)

graph.add_edge("Query", "Supervisor")
graph.add_edge("Research", "Supervisor")
# Translate the Supervisor's route names (member names + "FINISH") into graph
# node names; "FINISH" maps to END.
graph.add_conditional_edges(
    "Supervisor",
    next_step,
    {"ScholarQuery": "Query", "LocalInformationRetriever": "Research", "FINISH": END},
)

In [24]:
# Every run starts at the Supervisor. Newer LangGraph releases express this as
# graph.add_edge(START, "Supervisor").
graph.set_entry_point("Supervisor")
# The compiled graph is the runnable used below (stream, get_graph, piping).
chain = graph.compile()

In [None]:
from IPython.display import Image, display

# Rendering the diagram is best-effort. draw_mermaid_png() may need network access
# (by default it uses a remote Mermaid rendering service) or optional
# dependencies. Keep the notebook running if it fails, but report the reason
# instead of hiding it. A bare `except:` would also swallow KeyboardInterrupt.
try:
    display(Image(chain.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    print(f"Could not render the graph diagram: {exc!r}")

In [25]:
def enter_chain(message: str):
    """Wrap a raw user question into the initial graph state.

    Only ``messages`` is populated; ``team_members`` and ``next`` are not set here.
    """
    initial_messages = [HumanMessage(content=message)]
    return {"messages": initial_messages}


# Piping a plain function into the compiled graph lets callers pass a bare string.
# LangChain coerces `enter_chain` into a runnable whose output feeds `chain`.
research_chain = enter_chain | chain

In [26]:
# Stream the run one step at a time. Each item maps the node that just ran to the
# state update it produced. recursion_limit caps how many steps the graph may take.
for s in research_chain.stream("what caused diabetes?", {"recursion_limit": 100}):
    if "__end__" in s:
        continue
    print(s)
    print("---")

{'Supervisor': {'next': 'ScholarQuery'}}
---
{'Query': {'messages': [HumanMessage(content='The research paper titled "Social Determinants of Health and Structural Inequities-Root Causes of Diabetes Disparities" discusses how social determinants of health (SDOH) and structural inequities are the root causes of diabetes disparities. The paper emphasizes that historically marginalized groups, such as racial and ethnic minorities and those with lower socioeconomic status, bear a disproportionate burden of diabetes and its associated complications. Factors such as socioeconomic status, neighborhood and physical environment, food environment, health care, and social context are highlighted as key contributors to diabetes outcomes. The paper calls for addressing SDOH at the structural and systems level to achieve health equity and improve outcomes, particularly for marginalized communities.', additional_kwargs={}, response_metadata={}, name='ScholarQuery')]}}
---
{'Supervisor': {'next': 'Loca