# Create an Agentic RAG Pipeline

<a target="_blank" href="https://colab.research.google.com/github/unionai-oss/agentic-rag-workshop/blob/main/agentic-rag.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In this workshop, we'll build a simple agentic RAG workflow on Union:

1. First, we'll create a vector store containing biomedical research documents.
2. Then, we'll implement a simple RAG workflow.
3. Finally, we'll extend the RAG workflow to be agentic.

In [None]:
print("Hello World")

## 🛠️ Workshop Setup

Install the Python libraries by running the code cell below:

In [None]:
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    !git clone https://github.com/unionai-oss/agentic-rag-workshop.git
    %cd agentic-rag-workshop
    %pip install -r requirements.txt

While the libraries are being installed:
- Join the slack channel for support: https://flyte-org.slack.com/archives/C07R0QU6Y2H
- Sign up for a free Union account: https://signup.union.ai/
- Go to the Union dashboard: https://serverless.union.ai/

Log in to Union in this notebook session:

In [None]:
!union create login --auth device-flow --serverless

## 🔀 Build a Simple Workflow

A Union task is a containerized function that takes an input and produces an output. A workflow is a collection of tasks that are executed in a specific sequence.

In [None]:
%%writefile simple_wf.py
from flytekit import task, workflow

@task
def hello_world(name: str) -> str:
    return f"Hello, {name}"

@workflow
def main(name: str) -> str:
    return hello_world(name=name)

In [None]:
# Run locally
!union run simple_wf.py main --name "Workshop Attendee"

In [None]:
# Run on Union
!union run --remote simple_wf.py main --name "Workshop Attendee"

In [None]:
# Inspect the main workflow inputs
!union run simple_wf.py main --help

## 🔑 Create OpenAI API Key Secret on Union

First, go to https://platform.openai.com/account/api-keys and create an OpenAI API key.

Then, run the following command to make the secret accessible on Union:

In [None]:
!union create secret openai_api_key

In [None]:
!union get secret

If you have issues with the secret, you can delete it by uncommenting the code cell below:

In [None]:
#!union delete secret openai_api_key

## 🗂️ A Simple RAG Workflow

> To run, we must first learn how to walk.

In the first part of this workshop, we'll build a simple RAG workflow.

<img src="./static/rag-workflow.png" alt="Simple RAG Workflow" width="300"/>

### Define an Image Spec

An `ImageSpec` is an easy-to-use interface for specifying a container image.

In [None]:
%%writefile custom_image.py
from flytekit import ImageSpec

image = ImageSpec(
    packages=[
        "beautifulsoup4==4.12.3",
        "chromadb==0.5.3",
        "langchain==0.3.2",
        "langchain-community==0.3.1",
        "langchain-openai==0.2.2",
        "langchain-text-splitters==0.3.0",
        "tiktoken==0.7.0",
        "xmltodict==0.13.0",
    ],
)

### 🗃️ Create a Vector Store

The first step in building a RAG workflow is to create a vector store of documents. In the code below, we'll create a vector store of [PubMed](https://pubmed.ncbi.nlm.nih.gov/) documents. Which documents are loaded depends on the `query` parameter that we pass into the task.

In [None]:
%%writefile vector_store.py
from typing import Annotated, Optional

from flytekit import task, Deck, Secret
from flytekit.deck import MarkdownRenderer
from flytekit.types.directory import FlyteDirectory
from union.artifacts import Artifact, DataCard

from custom_image import image
from utils import get_pubmed_loader, parse_doc, generate_data_card, set_openai_api_key


# Define the vector store artifact
VectorStore = Artifact(name="vector-store")


@task(
    container_image=image,
    cache=True,
    cache_version="0",
    secret_requests=[Secret(key="openai_api_key")],
    enable_deck=True,
    deck_fields=[],
)
def create_vector_store(
    query: str,
    load_max_docs: Optional[int] = None,
    chunk_size: int = 100,
    chunk_overlap: int = 50,
) -> Annotated[FlyteDirectory, VectorStore]:
    """Create a vector store of pubmed documents based on a query."""

    from langchain_community.vectorstores import Chroma
    from langchain_openai import OpenAIEmbeddings
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    set_openai_api_key()

    load_max_docs = load_max_docs or 10

    # load the documents
    loader = get_pubmed_loader(
        query,
        load_max_docs=load_max_docs,
        max_retry=200,
        sleep_time=1.0,
    )
    docs = [parse_doc(doc) for doc in loader.load()]

    # split the documents into chunks
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    doc_splits = text_splitter.split_documents(docs)

    # create a Chroma vector store
    vector_store = Chroma.from_documents(
        documents=doc_splits,
        collection_name="rag-chroma",
        embedding=OpenAIEmbeddings(),
        persist_directory="./chroma_db",
    )

    # create a data card
    data_card = generate_data_card(docs)
    Deck("Data Card", MarkdownRenderer().to_html(data_card))

    return VectorStore.create_from(
        FlyteDirectory(path=vector_store._persist_directory),
        DataCard(data_card),
    )

In [None]:
# Create the vector store
!union run --remote vector_store.py create_vector_store --query "CRISPR therapy" --load_max_docs 10

If you visit the link produced by the `union run` command, you'll see a task execution that creates the vector store.

- Click on the `vector_store.create_vector_store` item in the `Nodes` list view
- The right-hand sidebar has a **Flyte Deck** button that shows a preview of the contents of the vector store.
- The **Outputs** tab shows the vector store artifact that we created, which we'll use in the next step of this workshop.
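
The `chunk_size` and `chunk_overlap` parameters of `create_vector_store` control how each document is split before it is embedded. The sketch below illustrates the idea of overlapping chunks only. It assumes whitespace-separated words in place of `tiktoken` tokens, and `chunk_tokens` is a hypothetical helper, not part of LangChain. The real `RecursiveCharacterTextSplitter` also prefers to break on separators such as paragraphs and newlines, so its chunk boundaries will differ.

```python
def chunk_tokens(tokens: list[str], chunk_size: int, chunk_overlap: int) -> list[list[str]]:
    """Split tokens into windows of chunk_size that share chunk_overlap tokens."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")

    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        # stop once a window has reached the end of the token list
        if start + chunk_size >= len(tokens):
            break
    return chunks


# Assumption: whitespace "tokens" stand in for tiktoken tokens
tokens = "CRISPR gene editing enables precise changes to DNA in living cells".split()
chunks = chunk_tokens(tokens, chunk_size=4, chunk_overlap=2)
for chunk in chunks:
    print(" ".join(chunk))
```

Because consecutive chunks share tokens, a sentence that straddles a chunk boundary still appears intact in at least one chunk, at the cost of storing and embedding more text.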

### 🔀 Create a RAG Workflow

Next, we'll create a simple implementation of a RAG workflow. It will:

- Take a user question as an input
- Retrieve relevant documents from the vector store
- Generate an answer to the user question based on the retrieved documents.
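
The retrieve-then-generate steps above can be sketched without any external services. This is a toy illustration under loud assumptions: `embed` uses word counts in place of OpenAI embeddings, the in-memory list stands in for the Chroma vector store, and the final step only builds the prompt rather than calling an LLM. All function names here are hypothetical.

```python
import math
from collections import Counter


def embed(text: str) -> Counter:
    """Toy embedding: lowercase word counts (stand-in for a real embedding model)."""
    return Counter(text.lower().split())


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[word] * b[word] for word in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve_top_k(question: str, documents: list[str], k: int = 2) -> list[str]:
    """Return the k documents most similar to the question."""
    q = embed(question)
    ranked = sorted(documents, key=lambda doc: cosine(q, embed(doc)), reverse=True)
    return ranked[:k]


documents = [
    "crispr therapy for sickle cell disease",
    "statins lower cholesterol levels",
    "base editing is a crispr technique",
]
question = "crispr therapy"
top_docs = retrieve_top_k(question, documents, k=2)

# The retrieved documents become the context of the generation prompt
context = "\n\n".join(top_docs)
prompt = f"Question: {question}\nContext: {context}\nAnswer:"
print(prompt)
```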

In [None]:
%%writefile simple_rag.py
from typing import Optional

from flytekit import workflow, Deck, Resources, Secret
from flytekit.deck import MarkdownRenderer
from flytekit.types.directory import FlyteDirectory
from union.actor import ActorEnvironment
from union.artifacts import Artifact

from custom_image import image
from utils import set_openai_api_key


DEFAULT_PROMPT_TEMPLATE = """
You are an assistant for question-answering tasks in the biomedical domain.
Use only the following pieces of retrieved context to answer the question. 
If you don't know the answer, just say that you don't know. Make the answer as
detailed as possible. If the answer contains acronyms, make sure to expand on them.

Question: {question}
Context: {context}
Answer:
"""


actor = ActorEnvironment(
    name="simple-rag",
    ttl_seconds=180,
    container_image=image,
    requests=Resources(cpu="2", mem="8Gi"),
    secret_requests=[Secret(key="openai_api_key")],
)

VectorStore = Artifact(name="vector-store")


@actor.task(enable_deck=True, deck_fields=[])
def retrieve(
    question: str,
    vector_store: FlyteDirectory,
) -> str:
    from langchain_community.vectorstores import Chroma
    from langchain_openai import OpenAIEmbeddings

    set_openai_api_key()

    vector_store.download()
    vector_store = Chroma(
        collection_name="rag-chroma",
        persist_directory=vector_store.path,
        embedding_function=OpenAIEmbeddings(),
    )
    retriever = vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 8},
    )
    context = "\n\n".join(doc.page_content for doc in retriever.invoke(question))
    Deck("Context", MarkdownRenderer().to_html(context))
    return context


@actor.task(enable_deck=True, deck_fields=[])
def generate(
    question: str,
    context: str,
    prompt_template: Optional[str] = None,
) -> str:
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import PromptTemplate
    from langchain_openai import ChatOpenAI

    set_openai_api_key()

    prompt = PromptTemplate.from_template(prompt_template or DEFAULT_PROMPT_TEMPLATE)
    llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0.9)

    chain = prompt | llm | StrOutputParser()
    answer = chain.invoke({"question": question, "context": context})
    Deck("Answer", MarkdownRenderer().to_html(answer))
    return answer


@workflow
def run(
    question: str,
    vector_store: FlyteDirectory = VectorStore.query(),  # 👈 this uses the vector store artifact by default
    prompt_template: Optional[str] = None,
) -> str:
    context = retrieve(question, vector_store)
    return generate(
        question=question,
        context=context,
        prompt_template=prompt_template,
    )

In [None]:
# Run the simple RAG workflow
!union run --remote simple_rag.py run --question "What are the latest CRISPR therapies?"

The link produced by `union run` takes you to a workflow execution that runs the simple RAG workflow. As with the vector store task, the `generate` task creates a Flyte Deck that shows the answer to the question we posed.

## 🤖 Making RAG Agentic

At a high level, an agent is an entity that gets a higher-level job done by taking a series of autonomous actions. For example, an AI agent could order food from a restaurant or book a flight for you.

In this workshop, we'll talk about these kinds of systems at a more fundamental level, i.e. in terms of execution graphs and workflows. An "agentic" system is one that uses an AI model (typically a language model) to modify the state of an execution graph, determine the traversal pattern of the graph, or even modify the shape of the graph itself.

An "agentic" system lies on a spectrum of capabilities, which may be a combination of:

1. 🧰 Having access to tools and memory, e.g. vector stores, functions or APIs that an LLM can call.
2. 🦾 Producing actions that determine the traversal of the execution graph.
3. 🤔 Re-processing the application state through "self-reflection" / "reasoning".
4. 🗓️ Breaking down the job into smaller subtasks through planning.

An agentic system can have one or more agents that implement one or more of these capabilities.
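
To make the idea of a model steering an execution graph concrete, here is a minimal sketch in plain Python. It assumes a hypothetical `fake_model_decision` function standing in for an LLM call, and the node names and `State` shape are invented for illustration. The point is only that the next node is chosen from the current state at runtime rather than fixed in advance.

```python
from typing import Callable

State = dict  # shared application state passed between nodes


def fake_model_decision(state: State) -> str:
    """Hypothetical stand-in for an LLM call that picks the next node."""
    if "context" not in state:
        return "retrieve"  # decide to use a tool
    if "answer" not in state:
        return "generate"
    return "end"


def retrieve(state: State) -> State:
    """Tool node: adds retrieved context to the state."""
    return {**state, "context": f"documents about {state['question']}"}


def generate(state: State) -> State:
    """Generation node: adds an answer derived from the context."""
    return {**state, "answer": f"Answer based on: {state['context']}"}


NODES: dict[str, Callable[[State], State]] = {"retrieve": retrieve, "generate": generate}


def run_graph(question: str, max_steps: int = 10) -> tuple[State, list[str]]:
    """Traverse the graph, letting the decision function pick each next node."""
    state: State = {"question": question}
    visited = []
    for _ in range(max_steps):  # hard cap so a bad decision cannot loop forever
        next_node = fake_model_decision(state)
        if next_node == "end":
            break
        visited.append(next_node)
        state = NODES[next_node](state)
    return state, visited


final_state, path = run_graph("CRISPR therapy")
print(path)
```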

In this workshop, we'll build an agentic system that implements the first three of these capabilities:

<img src="./static/agentic-rag-workflow.png" alt="Agentic RAG Workflow" width="700"/>

First, we'll define the basic data structures that we'll use to implement agentic RAG:

In [None]:
%%writefile agentic_types.py
import json
from dataclasses import dataclass
from enum import Enum


class RetrieverAction(Enum):
    tools = "tools"
    end = "end"


class GraderAction(Enum):
    generate = "generate"
    rewrite = "rewrite"
    end = "end"


@dataclass
class Message:
    """Json-encoded message."""

    data: str

    def to_langchain(self):
        from langchain_core.messages import AIMessage, ToolMessage, HumanMessage

        data = json.loads(self.data)
        message_type = data.get("type", data.get("role"))
        return {"ai": AIMessage, "tool": ToolMessage, "human": HumanMessage}[message_type](**data)

    @classmethod
    def from_langchain(cls, message):
        return cls(data=json.dumps(message.dict()))


@dataclass
class AgentState:
    """A list of messages capturing the state of the RAG execution graph."""

    messages: list[Message]

    def to_langchain(self) -> dict:
        return {"messages": [message.to_langchain() for message in self.messages]}

    def append(self, message):
        self.messages.append(Message.from_langchain(message))

    def __getitem__(self, index):
        message: Message = self.messages[index]
        return message.to_langchain()


Then, we implement the collection of nodes that will form the basis of the agentic RAG workflow.
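
Each node receives and returns an `AgentState`, whose messages are stored as JSON strings so that the state stays a simple dataclass of strings as it moves between tasks. The round trip can be sketched as follows, assuming plain dicts in place of LangChain message objects. `JsonMessage` and `State` are hypothetical stand-ins for the `Message` and `AgentState` classes above.

```python
import json
from dataclasses import dataclass, field


@dataclass
class JsonMessage:
    """Stores one message as a JSON string, like `Message.data` above."""

    data: str

    @classmethod
    def from_dict(cls, message: dict) -> "JsonMessage":
        return cls(data=json.dumps(message))

    def to_dict(self) -> dict:
        return json.loads(self.data)


@dataclass
class State:
    """A list of JSON-encoded messages."""

    messages: list[JsonMessage] = field(default_factory=list)

    def append(self, message: dict) -> None:
        self.messages.append(JsonMessage.from_dict(message))

    def __getitem__(self, index: int) -> dict:
        return self.messages[index].to_dict()


state = State()
state.append({"type": "human", "content": "What are the latest CRISPR therapies?"})
state.append({"type": "ai", "content": "", "tool_calls": [{"name": "retriever", "args": {"query": "CRISPR"}}]})

# Nodes look up the latest message of a given type, e.g. the current question
decoded = [message.to_dict() for message in state.messages]
last_human = [m for m in decoded if m["type"] == "human"][-1]
print(last_human["content"])
```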

In [None]:
%%writefile agentic_nodes.py
from flytekit import Deck, Secret
from flytekit.deck import MarkdownRenderer
from flytekit.types.directory import FlyteDirectory
from union.actor import ActorEnvironment

from agentic_types import RetrieverAction, GraderAction, Message, AgentState
from custom_image import image
from utils import get_vector_store_retriever, set_openai_api_key


actor = ActorEnvironment(
    name="agentic-rag",
    ttl_seconds=180,
    container_image=image,
    secret_requests=[Secret(key="openai_api_key")],
)


@actor.task(cache=True, cache_version="0")
def init_state(user_message: str) -> AgentState:
    """Initialize the AgentState with the user's message."""
    from langchain_core.messages import HumanMessage

    return AgentState(messages=[Message.from_langchain(HumanMessage(user_message))])


@actor.task
def retriever_agent(
    state: AgentState,
    vector_store: FlyteDirectory,
) -> tuple[AgentState, RetrieverAction]:
    """Invokes the agent to either end the loop or call the retrieval tool."""

    from langchain_openai import ChatOpenAI
    from langchain_core.prompts import PromptTemplate
    
    set_openai_api_key()

    vector_store.download()
    retriever_tool = get_vector_store_retriever(vector_store.path)

    prompt = PromptTemplate(
        template="""You are a biomedical research assistant that can retrieve
        documents and answer questions based on those documents.

        Here is the user question: {question} \n

        If the question is related to biomedical research, call the relevant
        tool that you have access to. If the question is not related to
        biomedical research, end the loop with a response that the question
        is not relevant.""",
        input_variables=["question"],
    )

    question_message = state[-1]
    assert question_message.type == "human"

    model = ChatOpenAI(temperature=0, streaming=True, model="gpt-4-turbo").bind_tools([retriever_tool])
    chain = prompt | model
    response = chain.invoke({"question": question_message.content})

    # Get agent's decision to call the retrieval tool or end the loop
    action = RetrieverAction.end
    if hasattr(response, "tool_calls") and len(response.tool_calls) > 0:
        action = RetrieverAction.tools

    state.append(response)
    return state, action


@actor.task
def retrieve(
    state: AgentState,
    vector_store: FlyteDirectory,
) -> AgentState:
    """Retrieves documents from the vector store."""

    from langchain_core.messages import AIMessage, ToolMessage

    set_openai_api_key()

    vector_store.download()
    retriever_tool = get_vector_store_retriever(vector_store.path)

    agent_message = state[-1]
    assert isinstance(agent_message, AIMessage)
    assert len(agent_message.tool_calls) == 1

    # invoke the tool to retrieve documents from the vector store
    tool_call = agent_message.tool_calls[0]
    content = retriever_tool.invoke(tool_call["args"])
    response = ToolMessage(content=content, tool_call_id=tool_call["id"])
    state.append(response)
    return state


@actor.task
def grader_agent(state: AgentState) -> GraderAction:
    """Determines whether the retrieved documents are relevant to the question."""

    from langchain_core.prompts import PromptTemplate
    from langchain_core.pydantic_v1 import BaseModel, Field
    from langchain_openai import ChatOpenAI

    set_openai_api_key()

    # Restrict the LLM's output to be a binary "yes" or "no"
    class Grade(BaseModel):
        """Binary score for relevance check."""

        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    # LLM with tool and validation
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    llm = model.with_structured_output(Grade)

    # Prompt
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved 
        document to a user question. \n 
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the
        user question, grade it as relevant. \n
        Give a binary score, 'yes' or 'no', to indicate whether the
        document is relevant to the question.""",
        input_variables=["context", "question"],
    )

    # Chain
    chain = prompt | llm

    messages = state.to_langchain()["messages"]

    # get the last "human" and "tool" message, which contains the question and
    # retrieval tool context, respectively
    questions = [m for m in messages if m.type == "human"]
    contexts = [m for m in messages if m.type == "tool"]
    question = questions[-1]
    context = contexts[-1]

    scored_result = chain.invoke({"question": question.content, "context": context.content})
    # Normalize the score; anything other than "yes" is treated as not relevant
    score = scored_result.binary_score.strip().lower()
    return GraderAction.generate if score == "yes" else GraderAction.rewrite


@actor.task
def rewrite(state: AgentState) -> AgentState:
    """Transform the query to produce a better question."""

    from langchain_core.messages import HumanMessage
    from langchain_core.pydantic_v1 import BaseModel, Field
    from langchain_openai import ChatOpenAI

    set_openai_api_key()

    messages = state.to_langchain()["messages"]

    # get the last "human" message, which contains the user question
    questions = [m for m in messages if m.type == "human"]
    question = questions[-1].content

    class rewritten_question(BaseModel):
        """Rewritten question and the reasoning behind the rewrite."""

        question: str = Field(description="Rewritten question")
        reason: str = Field(description="Reasoning for the rewrite")

    rewrite_prompt = f"""
    Look at the input and try to reason about the underlying semantic
    intent / meaning. \n
    Here is the initial question:
    \n ------- \n
    {question} 
    \n ------- \n
    Formulate an improved question and provide your reasoning.
    """

    # define model with structured output for the question rewrite
    model = ChatOpenAI(temperature=0, model="gpt-4-0125-preview", streaming=True)
    rewriter_model = model.with_structured_output(rewritten_question)

    response = rewriter_model.invoke([HumanMessage(content=rewrite_prompt)])
    message = HumanMessage(
        content=response.question,
        response_metadata={"rewrite_reason": response.reason},
    )
    state.append(message)
    return state


@actor.task
def generate(state: AgentState) -> AgentState:
    """Generate an answer based on the state."""

    from langchain_openai import ChatOpenAI
    from langchain_core.messages import AIMessage
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate

    set_openai_api_key()

    messages = state.to_langchain()["messages"]

    # get the last "human" and "tool" message, which contains the question and
    # retrieval tool context, respectively
    questions = [m for m in messages if m.type == "human"]
    contexts = [m for m in messages if m.type == "tool"]
    question = questions[-1]
    context = contexts[-1]

    system_message = """
    You are an assistant for question-answering tasks in the biomedical domain.
    Use the following pieces of retrieved context to answer the question. If you
    don't know the answer, just say that you don't know. Make the answer as
    detailed as possible. If the answer contains acronyms, make sure to expand
    them.

    Question: {question}

    Context: {context}

    Answer:
    """

    prompt = ChatPromptTemplate.from_messages([("human", system_message)])
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)
    rag_chain = prompt | llm | StrOutputParser()

    response = rag_chain.invoke({"context": context.content, "question": question.content})
    if isinstance(response, str):
        response = AIMessage(response)

    state.append(response)
    return state


@actor.task(enable_deck=True, deck_fields=[])
def return_answer(state: AgentState) -> str:
    """Finalize the answer to return a string to the user."""

    if len(state.messages) == 1:
        answer = f"I'm sorry, I don't understand: '{state.messages}'"
    else:
        data = state.messages[-1].to_langchain()
        answer = data.content
    Deck("Answer", MarkdownRenderer().to_html(answer))
    return answer

Finally, we put everything together into a single workflow.

We're going to define a `retrieval_router` and `rewrite_or_generate_router` to implement the conditional branching logic that we saw in the workflow diagram earlier. This allows the agent node actions to affect the flow of execution.

Then, we wrap it all into a `run` workflow.
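
Before reading the Flyte version, it may help to see the control flow of the two routers in plain Python. This sketch assumes hypothetical `grade` and `rewrite` functions with hard-coded behavior in place of LLM calls, and it passes a question string around instead of an `AgentState`. It shows how the two routers call each other and why the rewrite budget guarantees termination.

```python
MAX_REWRITES = 3  # smaller than the workshop's value, to keep the trace short


def grade(question: str) -> str:
    """Hypothetical grader: only a correctly spelled question counts as relevant."""
    return "generate" if "CRISPR" in question else "rewrite"


def rewrite(question: str) -> str:
    """Hypothetical rewriter: fixes one known typo."""
    return question.replace("CRESPR", "CRISPR")


def retrieval_router(question: str, action: str, n_rewrites: int) -> str:
    """First branch: end early, or retrieve and hand over to the grader."""
    if action == "end":
        return "Question is not relevant."
    return rewrite_or_generate_router(question, grade(question), n_rewrites)


def rewrite_or_generate_router(question: str, grader_action: str, n_rewrites: int) -> str:
    """Second branch: generate an answer, or rewrite the question and retry."""
    # Generate when the context is relevant OR the rewrite budget is spent,
    # which guarantees that the mutual recursion terminates.
    if grader_action == "generate" or n_rewrites >= MAX_REWRITES:
        return f"Answer to '{question}' after {n_rewrites} rewrite(s)"
    return retrieval_router(rewrite(question), "tools", n_rewrites + 1)


print(retrieval_router("What are CRESPR treatments?", "tools", 0))
print(retrieval_router("What is the weather?", "tools", 0))
```

In the second call the grader never approves the question, so the answer is generated only once the rewrite budget runs out.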

In [None]:
%%writefile agentic_rag.py
from flytekit import dynamic, workflow, Secret
from flytekit.types.directory import FlyteDirectory
from union.actor import ActorEnvironment
from union.artifacts import Artifact

from agentic_types import RetrieverAction, GraderAction, AgentState
from agentic_nodes import init_state, retriever_agent, retrieve, grader_agent, rewrite, generate, return_answer
from custom_image import image


actor = ActorEnvironment(
    name="agentic-rag",
    ttl_seconds=180,
    container_image=image,
    secret_requests=[Secret(key="openai_api_key")],
)

VectorStore = Artifact(name="vector-store")

MAX_REWRITES = 10  # 👈 maximum number of question rewrites


@dynamic
def retrieval_router(
    state: AgentState,
    action: RetrieverAction,
    vector_store: FlyteDirectory,
    n_rewrites: int,
) -> AgentState:
    """
    The first conditional branch in the RAG workflow. This determines whether
    the execution graph should end or call the retrieval tool for grading.
    """

    if action == RetrieverAction.end:
        return state
    elif action == RetrieverAction.tools:
        state = retrieve(state=state, vector_store=vector_store)
        grader_action = grader_agent(state=state)
        return rewrite_or_generate_router(state, grader_action, vector_store, n_rewrites)
    else:
        raise RuntimeError(f"Invalid action '{action}'")


@dynamic
def rewrite_or_generate_router(
    state: AgentState,
    grader_action: GraderAction,
    vector_store: FlyteDirectory,
    n_rewrites: int,
) -> AgentState:
    """
    The second conditional branch in the RAG workflow. This determines whether
    to rewrite the user's original query or generate the final answer.
    """
    if grader_action == GraderAction.generate or n_rewrites >= MAX_REWRITES:
        return generate(state=state)
    elif grader_action == GraderAction.rewrite:
        state = rewrite(state=state)
        state, action = retriever_agent(state=state, vector_store=vector_store)
        n_rewrites += 1
        return retrieval_router(
            state=state,
            action=action,
            vector_store=vector_store,
            n_rewrites=n_rewrites,
        )
    else:
        raise RuntimeError(f"Invalid action '{grader_action}'")



@workflow
def run(
    question: str,
    vector_store: FlyteDirectory = VectorStore.query(),
) -> str:
    """An agentic retrieval augmented generation workflow."""
    state = init_state(user_message=question)
    state, action = retriever_agent(state=state, vector_store=vector_store)
    state = retrieval_router(state, action, vector_store, n_rewrites=0)
    return return_answer(state=state)

In [None]:
# Run the agentic RAG workflow
!union run --remote agentic_rag.py run --question "What are the latest CRISPR therapies?"

In [None]:
# Trigger the workflow to end early
!union run --remote agentic_rag.py run --question "What's the weather in Atlanta today"

In [None]:
# Trigger the workflow to rewrite the question
!union run --remote agentic_rag.py run --question "Wat are CRESPR treatment?"

🎉 Congratulations! You've completed the workshop.

You've learned about:
- The basic structure and requirements of a RAG workflow.
- How to create a vector store for contextual retrieval.
- How to implement a simple RAG workflow.
- How to make a RAG workflow agentic by implementing:
  - Decision-making nodes
  - Tools
  - Application state updates.


## Bonus Exercise: Implement a RAG workflow with a web search tool

If you want to test your knowledge of agentic RAG, try implementing a RAG workflow that uses a web search tool. It should take the output of the retrieval tool, reason about the relevance of the retrieved context relative to the question, and produce two actions:

- `websearch`: call the web search tool and use that as the context instead of the retrieved documents from the vector store.
- `passthrough`: pass through the retrieved documents from the vector store as the context.
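
One possible starting point for the two actions, offered as a sketch rather than a full solution: the enum mirrors the style of `RetrieverAction` and `GraderAction`, while `SearchAction`, `choose_context`, and `fake_web_search` are hypothetical names. A real solution would replace the placeholder with an actual search tool and wrap the logic in Union tasks.

```python
from enum import Enum
from typing import Callable


class SearchAction(Enum):
    websearch = "websearch"
    passthrough = "passthrough"


def choose_context(
    action: SearchAction,
    retrieved: str,
    question: str,
    web_search: Callable[[str], str],
) -> str:
    """Return the context to generate from, depending on the chosen action."""
    if action == SearchAction.websearch:
        return web_search(question)
    if action == SearchAction.passthrough:
        return retrieved
    raise RuntimeError(f"Invalid action '{action}'")


def fake_web_search(query: str) -> str:
    """Hypothetical placeholder; swap in a real web search tool."""
    return f"web results for: {query}"


print(choose_context(SearchAction.passthrough, "retrieved docs", "CRISPR", fake_web_search))
print(choose_context(SearchAction.websearch, "retrieved docs", "CRISPR", fake_web_search))
```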

Here's the workflow at a high level:

<img src="./static/agentic-rag-with-search.png" alt="Agentic RAG Workflow with Web Search" width="600"/>

In [None]:
# Your solution goes here!