In [None]:
# Load environment variables from the project-level .env file; the returned
# flag is not needed, so it is discarded.
from dotenv import load_dotenv

_ = load_dotenv(dotenv_path="../.env")

In [None]:
import sys

# Make the project root importable (the `src.*` tool imports depend on it).
# Guarded so that re-running this cell does not append duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama

# Local Ollama model name; only the JSON-mode router model uses it.
local_llm = "qwen2.5:7b"

# Main chat model, used for RAG generation and (via bind_tools) tool calling.
# To run fully locally instead: llm = ChatOllama(model=local_llm, temperature=0)
llm = ChatAnthropic(temperature=0, model="claude-3-5-sonnet-20241022")

# JSON-constrained local model used for routing decisions.
llm_json_mode = ChatOllama(format="json", temperature=0, model=local_llm)

### VectorDB

In [None]:
from langchain_ollama.embeddings import OllamaEmbeddings
from langchain_chroma import Chroma

In [None]:
# Embeddings come from a local Ollama model; the Chroma store is backed by the
# persisted directory on disk.
embedder = OllamaEmbeddings(model="nomic-embed-text")
db = Chroma(embedding_function=embedder, persist_directory="../data/chroma_db")

### Router

In [None]:
import json
from langchain_core.messages import HumanMessage, SystemMessage

# Prompt: asks the JSON-mode model to pick a route for the user question.
router_instructions = """You are an expert at routing a user question to a vectorstore or tool call.
The vectorstore contains details about datasets from the World Resources Institute (WRI).
Use the vectorstore for questions on topics related to searching datasets.
For specific questions on forest fires, use the tool call.
Return JSON with a single key, route, that is 'vectorstore' or 'glad-tool' depending on the question."""

# Sample questions for smoke-testing the router. Per the prompt above, the
# forest-fire questions are expected to go to the tool; the rest to the vectorstore.
queries = [
    "I am interested in biodiversity conservation in Argentina",
    "I would like to explore helping with forest loss in Amazon",
    "show datasets related to mangroves",
    "find forest fires in milan for the year 2022",
    "show stats on forest fires over Ihorombe for 2021",
]

In [None]:
# Smoke test: print the route the JSON-mode model picks for each sample query.
for query in queries:
    router_messages = [
        SystemMessage(content=router_instructions),
        HumanMessage(content=query),
    ]
    raw_reply = llm_json_mode.invoke(router_messages)
    decision = json.loads(raw_reply.content)
    print(query, " ---> ", decision["route"])

### RAG

In [None]:
# RAG prompt template, filled with str.format: `{context}` receives the
# make_context(...) output and `{question}` the user's question.
rag_prompt = """You are a World Resources Institute (WRI) assistant specializing in dataset recommendations.

Instructions:
1. Use the following context to inform your response:
{context}

2. User Question:
{question}

3. Response Format:
   - Only use information from the provided context
   - For each recommended dataset:
     - Dataset URL
     - Two-line explanation of why this dataset is relevant to the user's problem
"""

In [None]:
# Return the top-4 most similar chunks. The result count must be passed via
# `search_kwargs`; a bare `k=` keyword is not a retriever field and is not
# forwarded to the similarity search.
retriever = db.as_retriever(search_kwargs={"k": 4})

In [None]:
# Try the retriever on a sample question.
question = "I am interested in biodiversity conservation in Argentina"
docs = retriever.invoke(input=question)

In [None]:
def make_context(docs, *, base_url="https://data-api.globalforestwatch.org/dataset"):
    """Format retrieved documents into one context string for the RAG prompt.

    Each document is rendered as a ``URL: <base_url>/<dataset>`` line followed
    by its page content; documents are separated by a blank line.

    Args:
        docs: Iterable of retrieved documents. Each must expose
            ``page_content`` and a ``metadata`` mapping with a ``"dataset"`` key.
        base_url: Prefix for the dataset URL (a trailing slash is tolerated).
            Defaults to the Global Forest Watch data API dataset endpoint.

    Returns:
        The formatted context string ("" when ``docs`` is empty).

    Raises:
        KeyError: If a document's metadata has no ``"dataset"`` entry.
    """
    prefix = base_url.rstrip("/")
    return "\n\n".join(
        f"URL: {prefix}/{doc.metadata['dataset']}\n{doc.page_content}"
        for doc in docs
    )

In [None]:
# Build the context block from the sample question's retrieved documents.
docs_txt = make_context(docs=docs)

In [None]:
# Fill the RAG template with the sample question and its context.
rag_prompt_fmt = rag_prompt.format(question=question, context=docs_txt)

In [None]:
# Single-turn RAG call: the whole filled prompt goes in one human message.
generation = llm.invoke(input=[HumanMessage(content=rag_prompt_fmt)])

In [None]:
# Show the model's answer text.
print(generation.content, end="\n")

# Agent

In [None]:
import operator
from typing import Annotated, Any, List

from IPython.display import Image, display, Markdown
from langchain_core.messages import AnyMessage, ToolMessage
from langgraph.graph import START, MessagesState, StateGraph, END, add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from typing_extensions import TypedDict

from src.tools.glad.weekly_alerts_tool import glad_weekly_alerts_tool
from src.tools.location.tool import location_tool

In [None]:
# Tools available to the assistant; binding them lets the model emit tool calls.
tools = [location_tool, glad_weekly_alerts_tool]
llm_with_tools = llm.bind_tools(tools=tools)

In [None]:
class GraphState(TypedDict):
    """Shared state passed between the agent graph's nodes.

    Keys without a reducer are overwritten by each node update; `messages`
    and `loop_step` are merged with their annotated reducers instead.
    """

    # Conversation history for the tool-calling branch, merged by `add_messages`.
    messages: Annotated[list[AnyMessage], add_messages]
    question: str  # User question
    # Final RAG answer: `generate` stores the chat model's response message
    # (read via `.content`), not a plain string.
    generation: Any
    # NOTE(review): no node in this notebook writes this key.
    answers: int  # Number of answers generated
    # Summed across updates (operator.add), so nodes must return increments.
    loop_step: Annotated[int, operator.add]
    # Documents returned by `retriever.invoke` (document objects, not str).
    documents: List[Any]

### Nodes

In [None]:
def retrieve(state):
    """Graph node: fetch documents relevant to ``state["question"]``.

    Returns a partial state update holding the retrieved ``documents``.
    """
    print("---RETRIEVE---")
    return {"documents": retriever.invoke(state["question"])}

def generate(state):
    """Graph node: answer the question with the RAG prompt over retrieved docs.

    Args:
        state: Graph state containing ``question`` and ``documents``.

    Returns:
        Partial update with ``generation`` (the chat model's response message)
        and a ``loop_step`` increment of 1.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    docs_txt = make_context(documents)
    rag_prompt_fmt = rag_prompt.format(context=docs_txt, question=question)
    generation = llm.invoke([HumanMessage(content=rag_prompt_fmt)])
    # `loop_step` has an operator.add reducer, so return only the increment;
    # returning `loop_step + 1` would add the previous count a second time.
    return {"generation": generation, "loop_step": 1}

def assistant(state):
    """Graph node: let the tool-bound chat model answer or request tool calls.

    When the history is empty, the user question is wrapped in a
    ``HumanMessage`` and sent first. That seed message is returned together
    with the model's reply, so the ``add_messages`` reducer stores it. Writing
    into ``state`` directly would not be persisted, and later assistant turns
    (after tool results) would no longer see the question.

    Returns:
        Partial update ``{"messages": [...]}`` with the new message(s).
    """
    # NOTE(review): the prompt names the location tool both as `location-tool`
    # and `location_tool`; confirm which matches the registered tool name.
    sys_msg = SystemMessage(content="""You are a helpful assistant tasked with answering the user queries for WRI data API.
        Use the `location-tool` to get iso, adm1 & adm2 of any region or place.
        Use the `glad-weekly-alerts-tool` to get forest fire information for a particular year. Think through the solution step-by-step first and then execute.
        
        For eg: If the query is "Find forest fires in Milan for the year 2024"
        Steps
        1. Use the `location_tool` to get iso, adm1, adm2 for place `Milan` by passing `query=Milan`
        2. Pass iso, adm1, adm2 along with year `2024` as args to `glad-weekly-alerts-tool` to get information about forest fire alerts.
        """)
    history = list(state.get("messages") or [])
    new_messages = []
    if not history:
        seed = HumanMessage(state["question"])
        history.append(seed)
        new_messages.append(seed)
    response = llm_with_tools.invoke([sys_msg] + history)
    new_messages.append(response)
    return {"messages": new_messages}

# Prebuilt node that executes the tool calls requested by the assistant.
tool_node = ToolNode(tools=tools)

In [None]:
# tools_by_name = {tool.name: tool for tool in tools}
# def tool_node(state: dict):
#     result = []
#     for tool_call in state["messages"][-1].tool_calls:
#         tool = tools_by_name[tool_call["name"]]
#         print(tool)
#         observation = tool.invoke(tool_call["args"])
#         result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
#     return {"messages": result}

### Edges

In [None]:
def router(state):
    """Entry-point edge: choose the first node for ``state["question"]``.

    Asks the JSON-mode model to classify the question using
    ``router_instructions`` and maps its ``route`` value to a node name.

    Returns:
        ``"retrieve"`` for the vectorstore route, ``"assistant"`` for the
        GLAD tool route.

    Raises:
        ValueError: If the reply is not a JSON object with a ``route`` key,
            or names a route other than the two above. Returning ``None``
            instead would leave the graph unable to pick a node.
    """
    print("---ROUTER---")
    response = llm_json_mode.invoke(
        [
            SystemMessage(content=router_instructions),
            HumanMessage(content=state["question"]),
        ]
    )
    try:
        route = json.loads(response.content)["route"]
    except (json.JSONDecodeError, KeyError, TypeError) as exc:
        raise ValueError(f"Router returned an unusable reply: {response.content!r}") from exc
    # Tolerate casing/whitespace drift in the model's answer.
    route = str(route).strip().lower()
    if route == "vectorstore":
        print("---ROUTING-TO-RAG---")
        return "retrieve"
    if route == "glad-tool":
        print("---ROUTING-TO-TOOLS---")
        return "assistant"
    raise ValueError(f"Unknown route from router: {route!r}")

### Graph

In [None]:
from IPython.core.debugger import set_trace

In [None]:
def call_tool(state):
    """Edge condition: go to ``"tools"`` if the last message requests tool calls.

    NOTE(review): the graph below uses ``tools_condition`` instead; this is a
    hand-rolled equivalent that is not currently wired in.

    Args:
        state: Mapping with an optional ``messages`` list.

    Returns:
        ``"tools"`` when the last message has non-empty ``tool_calls``;
        otherwise ``"__end__"``. This includes an empty or missing history,
        and a last message without a ``tool_calls`` attribute.
    """
    messages = state.get("messages") or []
    if messages and getattr(messages[-1], "tool_calls", None):
        return "tools"
    return "__end__"

In [None]:
# Assemble the agent graph: the router picks either the RAG branch
# (retrieve -> generate) or the tool-calling loop (assistant <-> tools).
wf = StateGraph(GraphState)

wf.add_node("retrieve", retrieve)
wf.add_node("generate", generate)
wf.add_node("assistant", assistant)
wf.add_node("tools", tool_node)

# Conditional entry: `router` returns the name of the first node to run.
wf.add_conditional_edges(
    START,
    router,
    {
        "retrieve": "retrieve",
        "assistant": "assistant",
    },
)
wf.add_edge("retrieve", "generate")
wf.add_edge("generate", END)
# `tools_condition` already routes to END when the assistant makes no tool
# calls, so no separate assistant -> END edge is added (it would be redundant
# and show up as a spurious direct edge in the rendered graph).
wf.add_conditional_edges(
    "assistant",
    tools_condition,
    {"tools": "tools", END: END},
)
wf.add_edge("tools", "assistant")

graph = wf.compile()

In [None]:
# Render the compiled graph's structure as a Mermaid PNG.
display(Image(data=graph.get_graph(xray=False).draw_mermaid_png()))

In [None]:
# Example expected to take the tool-calling route (forest-fire question).
result = graph.invoke(input={"question": "show stats on forest fires over Ihorombe for 2021"})

In [None]:
# Walk through the tool-calling conversation message by message.
for message in result["messages"]:
    message.pretty_print()

In [None]:
# Example expected to take the RAG route (dataset-discovery question).
result = graph.invoke(input={"question": "I am interested in tree cover loss over amazon"})

In [None]:
Markdown(data=result["generation"].content)