In [None]:
from utils.agent_helper import *
from utils.models import REASONING_LLM, RAG_LLM, EMBEDDING_MODEL
from utils.rag import *
from utils.vector_store import *
from utils.prompts import *

from langchain_community.tools.google_scholar import GoogleScholarQueryRun
from langchain_community.utilities.google_scholar import GoogleScholarAPIWrapper
from langchain_core.prompts import ChatPromptTemplate

import functools
from typing import Annotated
import os

### Google Scholar search agent

In [None]:
# Google Scholar query tool backed by the default API wrapper.
# NOTE(review): the wrapper is constructed with no arguments, so any credentials
# must come from the environment (LangChain documents SERP_API_KEY) — confirm it
# is set before running this cell.
g_scholar_tool = GoogleScholarQueryRun(api_wrapper=GoogleScholarAPIWrapper())

# Agent built from the reasoning LLM, a one-element tool list and a system prompt.
# `create_agent` comes from one of the star-imported `utils` modules; its exact
# contract is not visible here.
search_agent = create_agent(
    REASONING_LLM,
    [g_scholar_tool],
    "You are a research assistant who can search for the top 3 most relevant publications in the last one month using the google scholar search engine.",
)
# Pre-bind the agent and its display name so the node only needs the remaining
# arguments of `agent_node` (presumably the graph state — TODO confirm).
search_node = functools.partial(agent_node, agent=search_agent, name="Search")

### Textbook agent

In [None]:
# Load the textbook PDF through the project helper (returns the documents to index).
textbook_documents = get_markdown_documents('data/text_books/Textbook-of-Diabetes-2024.pdf')

# Bundle the three pieces the RAG chain needs: prompt, vector store and LLM.
# NOTE(review): emb_dim=384 presumably has to match the output size of
# EMBEDDING_MODEL — confirm if the embedding model is ever swapped.
rag_runnables = RAGRunnables(
    rag_prompt_template=ChatPromptTemplate.from_template(RAG_PROMPT),
    vector_store=get_vector_store(textbook_documents, EMBEDDING_MODEL, emb_dim=384),
    llm=RAG_LLM,
)

# Assemble the retrieval-augmented chain over the textbook index.
textbook_chain = create_rag_chain(
    rag_runnables.rag_prompt_template,
    rag_runnables.vector_store,
    rag_runnables.llm,
)

@tool
def retrieve_information(
    query: Annotated[str, "query to ask the retrieve information tool"],
):
    """Use Retrieval Augmented Generation to retrieve information about the book 'Textbook of Diabetes'."""
    # NOTE(review): under @tool the docstring above presumably doubles as the
    # tool description shown to the agent LLM, so it is kept verbatim.
    # Run the textbook RAG chain and hand back only its 'response' entry.
    rag_output = textbook_chain.invoke({"question": query})
    return rag_output['response']


# Build the textbook agent around the `retrieve_information` tool defined in
# this cell. The second argument of `create_agent` is the tool list (the search
# agent passes `[g_scholar_tool]` in the same position), so the @tool-wrapped
# function is passed here rather than the raw `textbook_chain`: the raw chain
# carries no tool name/description and left the tool above unused.
# The tool object is captured now; the paper cell later rebinds the module-level
# name `retrieve_information`, which does not affect this agent.
textbook_agent = create_agent(
    REASONING_LLM,
    [retrieve_information],
    TEXTBOOK_AGENT_PROMPT,
)

# Pre-bind the agent and its display name; the remaining `agent_node` arguments
# are supplied by whoever calls the node (presumably the graph — TODO confirm).
textbook_node = functools.partial(agent_node, agent=textbook_agent, name="TextbookInformationRetriever")

### Paper agent

In [None]:
# Folder holding the saved papers to index.
folder = 'data/literature'

# os.listdir() returns entries in arbitrary order and includes every directory
# entry. Sort for a reproducible ingestion order, and keep only regular,
# non-hidden files so that sub-directories or OS metadata files (e.g. .DS_Store)
# are never handed to the document loader.
paths = sorted(
    os.path.join(folder, file)
    for file in os.listdir(folder)
    if not file.startswith('.') and os.path.isfile(os.path.join(folder, file))
)

# Load each paper and flatten the per-file results into one list.
# `get_markdown_documents` is assumed to return an iterable of documents per
# file (it is used with list.extend here).
paper_documents = []
for path in paths:
    document = get_markdown_documents(path)
    paper_documents.extend(document)

# Bundle prompt, vector store and LLM for the saved-paper index.
# This reassigns the module-level name `rag_runnables`.
# NOTE(review): emb_dim=384 presumably has to match the output size of
# EMBEDDING_MODEL — confirm if the embedding model is ever swapped.
rag_runnables = RAGRunnables(
    rag_prompt_template=ChatPromptTemplate.from_template(RAG_PROMPT),
    vector_store=get_vector_store(paper_documents, EMBEDDING_MODEL, emb_dim=384),
    llm=RAG_LLM,
)

# Assemble the retrieval-augmented chain over the paper index.
paper_chain = create_rag_chain(
    rag_runnables.rag_prompt_template,
    rag_runnables.vector_store,
    rag_runnables.llm,
)

# NOTE: this definition rebinds the module-level name `retrieve_information`,
# which the textbook cell also defines. It is therefore handed to `create_agent`
# further down in this same cell, while the name still refers to the paper tool.
@tool
def retrieve_information(
    query: Annotated[str, "query to ask the retrieve information tool"]
    ):
    """Use Retrieval Augmented Generation to retrieve information about the papers provided."""
    # Query the paper index built in this cell. The original invoked
    # `textbook_chain` here (copy-paste slip), so paper questions were answered
    # from the textbook instead of the saved papers.
    return paper_chain.invoke({"question" : query})['response']


# The second argument of `create_agent` is the tool list (the search agent
# passes `[g_scholar_tool]` in the same position), so the @tool-wrapped function
# is passed rather than a raw RAG chain.
paper_agent = create_agent(
    REASONING_LLM,
    [retrieve_information],
    SAVED_PAPER_AGENT_PROMPT,
)

# Give this node its own name: the original reused "TextbookInformationRetriever",
# leaving the two agents with the same name.
# NOTE(review): if later cells match on this name, update them accordingly.
paper_node = functools.partial(agent_node, agent=paper_agent, name="PaperInformationRetriever")