In [None]:
from dotenv import load_dotenv

# Load settings from the .env file one directory up (presumably the API
# credentials for the model clients below — confirm against the .env contents).
# The result is bound to `_` so the cell does not echo it.
_ = load_dotenv("../.env")

In [None]:
import sys

# Put the parent directory on the import path (presumably so project modules
# next to this notebook folder can be imported — confirm). The membership check
# keeps the cell idempotent: re-running it no longer stacks duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
import json

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage
from langchain_ollama import ChatOllama

# Tag of the model served by the local Ollama instance.
local_llm = "qwen2.5:7b"

# Main chat model, sampled at temperature 0.
# To stay fully local, swap in: ChatOllama(model=local_llm, temperature=0)
llm = ChatAnthropic(
    temperature=0,
    model="claude-3-5-sonnet-20241022",
)

# Local Ollama client configured with format="json".
llm_json_mode = ChatOllama(
    format="json",
    temperature=0,
    model=local_llm,
)

In [None]:
from langchain_chroma import Chroma
from langchain_ollama.embeddings import OllamaEmbeddings

In [None]:
# Embedding model handed to the vector store below.
embedder = OllamaEmbeddings(model="nomic-embed-text")

# Chroma store backed by the on-disk directory ../data/chroma_layers.
db = Chroma(
    embedding_function=embedder,
    persist_directory="../data/chroma_layers",
)

In [None]:
# Prompt template for the dataset-recommendation step. It is rendered with
# str.format(), which fills the {context} and {question} slots; the doubled
# braces ({{ and }}) are escapes that come out as literal { and } in the JSON
# skeleton shown to the model.
rag_prompt = """You are a World Resources Institute (WRI) assistant specializing in dataset recommendations.

Instructions:
1. Use the following context to inform your response:
{context}

2. User Question:
{question}

3. Response Format to be a valid JSON with list of datasets in the following format:
    {{
        "datasets": [
            {{
                "dataset": The slug of the dataset,
                "explanation": A two-line explanation of why this dataset is relevant to the user's problem
            }},
            ...
        ]
    }}
"""

In [None]:
# The number of hits is a search parameter, so it has to be passed through
# search_kwargs; a bare k= keyword on as_retriever() is not a retriever setting.
retriever = db.as_retriever(search_kwargs={"k": 4})

In [None]:
# Sample query used to exercise the retriever on its own.
question = "I am interested in biodiversity conservation in Argentina"
# Documents the retriever returns for that query.
docs = retriever.invoke(question)

In [None]:
# Show the text of the first returned document as a quick sanity check.
print(docs[0].page_content)

In [None]:
def make_context(docs):
    """Build the prompt context string from retrieved documents.

    Each document contributes a ``Dataset: <name>`` header line, where the
    name is read from ``doc.metadata["dataset"]``, followed by the document's
    ``page_content``. Entries are separated by one blank line.
    """
    return "\n\n".join(
        f"Dataset: {doc.metadata['dataset']}\n{doc.page_content}" for doc in docs
    )

In [None]:
# Collapse the retrieved documents into one context string for the prompt.
docs_txt = make_context(docs)

In [None]:
# Fill the {context} and {question} slots of the prompt template.
rag_prompt_fmt = rag_prompt.format(context=docs_txt, question=question)

In [None]:
# Show the fully rendered prompt before it is sent to the model.
print(rag_prompt_fmt)

In [None]:
# Send the rendered prompt to the chat model as a single human message.
generation = llm.invoke([HumanMessage(content=rag_prompt_fmt)])

In [None]:
# Decode the reply as JSON and show its "datasets" list. json.loads raises if
# the reply contains anything besides a JSON document.
json.loads(generation.content)["datasets"]

In [None]:
import operator
from typing import Annotated, List

from IPython.display import Image, display
from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict

In [None]:
class GraphState(TypedDict):
    """State passed between the nodes of the retrieval/generation graph.

    The annotations describe what the nodes in this notebook actually store:
    ``retrieve`` writes the retriever's document objects and ``generate``
    writes a list of dataset dicts.
    """

    question: str  # User question
    generation: List[dict]  # Dataset recommendations written by the generate node
    loop_step: Annotated[int, operator.add]  # Counter; node updates are summed into it
    documents: list  # Retrieved documents (objects exposing .metadata and .page_content)

In [None]:
def retrieve(state):
    """Graph node: look up documents for the question held in the state.

    Reads ``state["question"]``, queries the module-level ``retriever`` and
    returns the hits as the ``"documents"`` state update.
    """
    print("---RETRIEVE---")
    return {"documents": retriever.invoke(state["question"])}


def _parse_llm_json(raw):
    """Decode a model reply as JSON, tolerating text wrapped around the object."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Chat models sometimes wrap the object in a Markdown code fence or a
        # lead-in sentence; retry on the outermost {...} span before giving up.
        start, end = raw.find("{"), raw.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(raw[start : end + 1])


def generate(state):
    """Graph node: ask the LLM for dataset recommendations.

    Reads ``state["question"]`` and ``state["documents"]``, renders
    ``rag_prompt`` with the formatted documents, and parses the model reply
    as JSON. Each recommended dataset dict is extended in place with a
    ``"uri"`` and a ``"tilelayer"`` URL built from its ``"dataset"`` slug.

    Returns:
        A state update with the dataset list under ``"generation"`` and an
        increment of 1 for ``"loop_step"``.

    Raises:
        json.JSONDecodeError: if no JSON object can be decoded from the reply.
        KeyError: if the reply lacks ``"datasets"`` or an entry lacks ``"dataset"``.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    docs_txt = make_context(documents)
    rag_prompt_fmt = rag_prompt.format(context=docs_txt, question=question)
    generation = llm.invoke([HumanMessage(content=rag_prompt_fmt)])
    datasets = _parse_llm_json(generation.content)["datasets"]
    for dataset in datasets:
        slug = dataset["dataset"]
        dataset["uri"] = f"https://data-api.globalforestwatch.org/dataset/{slug}"
        # Doubled braces keep {z}/{x}/{y} as literal placeholders in the tile URL.
        dataset["tilelayer"] = (
            f"https://tiles.globalforestwatch.org/{slug}/latest/dynamic/{{z}}/{{x}}/{{y}}.png"
        )

    # loop_step is declared with an operator.add reducer, so the returned value
    # is added to the current count. Return the increment, not count + 1,
    # otherwise repeated passes would double-count.
    return {"generation": datasets, "loop_step": 1}

In [None]:
# Two-step pipeline: START -> retrieve -> generate -> END.
wf = StateGraph(GraphState)

# Register the nodes, then wire them in a straight line.
for _node_name, _node_fn in (("retrieve", retrieve), ("generate", generate)):
    wf.add_node(_node_name, _node_fn)

for _src, _dst in ((START, "retrieve"), ("retrieve", "generate"), ("generate", END)):
    wf.add_edge(_src, _dst)

graph = wf.compile()

In [None]:
# Render the compiled graph as a Mermaid PNG and show it inline.
display(Image(graph.get_graph(xray=False).draw_mermaid_png()))

In [None]:
# Run the whole graph on a sample question; only "question" is supplied in the
# initial state.
result = graph.invoke({"question": "I am interested in tree cover loss over amazon"})

In [None]:
# Value stored under "generation" in the final graph state.
result["generation"]

In [None]:
# The generate node already stores the decoded list of dataset dicts under
# "generation", so there is no message object (and no `.content`) to decode.
response = result["generation"]

In [None]:
# Display the value bound to `response` by the previous cell.
response