## Local execution


In [1]:
from typing import Optional, Annotated

from pydantic import BaseModel
from pathlib import Path
import requests
from urllib.parse import urljoin

import flytekit as fl
from flytekit.core.artifact import Artifact
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from union.actor import ActorEnvironment

# Name under which the Together API key is registered as a secret. This is the
# lookup key passed to `fl.Secret(...)` and `secrets.get(...)`, not the key itself.
TOGETHER_API_KEY = "samhita-together-api-key"

# Shared actor environment used by every `@actor.task` in this notebook.
actor = ActorEnvironment(
    name="contextual-rag",
    # NOTE(review): presumably the number of actor workers and their idle
    # time-to-live in seconds — confirm against the Union actor documentation.
    replica_count=50,
    ttl_seconds=120,
    container_image=fl.ImageSpec(
        name="contextual-rag",
        registry="ghcr.io/unionai-oss",
        # Dependencies installed in the task image; most are pinned exactly.
        packages=[
            "together==1.3.10",
            "beautifulsoup4==4.12.3",
            "bm25s==0.2.5",
            "pydantic>2",
            "chromadb==0.5.23",
            "union>=0.1.117",
        ],
    ),
    # Makes the Together API key available to tasks through the Flyte context.
    secret_requests=[fl.Secret(key=TOGETHER_API_KEY)],
)


# One scraped essay as it moves through the pipeline. Each stage below fills in
# another optional field and returns the same model.
class Document(BaseModel):
    # Sequential position assigned by `parse_main_page`; also used to build
    # the per-chunk IDs in `create_vector_index`.
    idx: int
    # Link text and absolute URL taken from the articles index page.
    title: str
    url: str
    # Whitespace-normalized essay text, set by `scrape_pg_essays`.
    content: Optional[str] = None
    # Fixed-size character windows of `content`, set by `create_chunks`.
    chunks: Optional[list[str]] = None
    # Not written by any code visible in this notebook.
    prompts: Optional[list[str]] = None
    # "<LLM explanation> <chunk>" strings, set by `generate_context`.
    contextual_chunks: Optional[list[str]] = None
    # Not written by any code visible in this notebook.
    tokens: Optional[list[list[int]]] = None

In [2]:
@actor.task
def parse_main_page(
    base_url: str, articles_url: str, local: bool = False
) -> list[Document]:
    """Scrape the articles index page and return one `Document` per essay link.

    Args:
        base_url: Site root; must end with "/" so relative links resolve
            correctly with `urljoin`.
        articles_url: Path of the index page, relative to `base_url`.
        local: When True, only the first two documents are returned to keep
            local runs fast.

    Returns:
        Documents with `idx`, `title` and absolute `url` populated, in page order.

    Raises:
        AssertionError: If `base_url` does not end with a slash.
        requests.HTTPError: If the index page responds with an error status.
    """
    from bs4 import BeautifulSoup

    assert base_url.endswith("/"), f"Base URL must end with a slash: {base_url}"
    response = requests.get(urljoin(base_url, articles_url))
    # Fail loudly on 4xx/5xx instead of parsing an error page into an empty list
    # (matches the handling in `scrape_pg_essays`).
    response.raise_for_status()
    soup = BeautifulSoup(response.text, "html.parser")

    td_cells = soup.select("table > tr > td > table > tr > td")
    documents = []

    idx = 0
    for td in td_cells:
        # Essay links sit in cells that contain a small (<= 15x15) image.
        img = td.find("img")
        if img and int(img.get("width", 0)) <= 15 and int(img.get("height", 0)) <= 15:
            font = td.find("font")
            a_tag = font.find("a") if font else None
            if a_tag:
                documents.append(
                    Document(
                        idx=idx, title=a_tag.text, url=urljoin(base_url, a_tag["href"])
                    )
                )
                idx += 1

    return documents[:2] if local else documents

In [3]:
@actor.task
def scrape_pg_essays(document: Document) -> Document:
    """Download one essay and store its whitespace-normalized text on the document.

    The text is taken from the first `<font>` element of the page; when the page
    has no such element, `content` is set to None.

    Raises:
        requests.HTTPError: If the essay page responds with an error status.
    """
    from bs4 import BeautifulSoup

    page = requests.get(document.url)
    page.raise_for_status()

    font_tag = BeautifulSoup(page.text, "html.parser").find("font")
    # `split()` without arguments collapses every run of whitespace, so the
    # joined result is single-spaced.
    document.content = " ".join(font_tag.get_text().split()) if font_tag else None
    return document

In [4]:
@actor.task(cache=True, cache_version="0.2")
def create_chunks(document: Document, chunk_size: int, overlap: int) -> Document:
    """Split `document.content` into overlapping fixed-size character windows.

    Consecutive chunks start `chunk_size - overlap` characters apart, so each
    chunk shares its first `overlap` characters with the end of the previous one.
    Documents without content are returned unchanged.

    Args:
        document: Document whose `content` should be chunked.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by neighbouring chunks; must
            satisfy `0 <= overlap < chunk_size`.

    Returns:
        The same document with `chunks` populated when it has content.

    Raises:
        ValueError: If `chunk_size` or `overlap` is out of range.
    """
    # Without these checks an overlap equal to chunk_size makes the range step
    # zero (opaque ValueError), a larger overlap silently yields no chunks, and
    # a negative overlap skips text between chunks.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )

    if document.content:
        stride = chunk_size - overlap
        document.chunks = [
            document.content[start : start + chunk_size]
            for start in range(0, len(document.content), stride)
        ]
    return document

In [5]:
@actor.task(cache=True, cache_version="0.4")
def generate_context(document: Document, model: str) -> Document:
    """Prefix every chunk with an LLM-written explanation of its role in the document.

    For each chunk, the whole document and the chunk are sent to the Together
    chat-completions API; the reply is joined to the chunk as
    "<explanation> <chunk>". `contextual_chunks` is left as None when the
    document has no chunks.

    Args:
        document: Document with `content` and (optionally) `chunks` populated.
        model: Together model identifier used for the chat completion.

    Returns:
        The same document with `contextual_chunks` populated.
    """
    from together import Together

    CONTEXTUAL_RAG_PROMPT = """
Given the document below, we want to explain what the chunk captures in the document.

{WHOLE_DOCUMENT}

Here is the chunk we want to explain:

{CHUNK_CONTENT}

Answer ONLY with a succinct explanation of the meaning of the chunk in the context of the whole document above.
"""

    client = Together(api_key=fl.current_context().secrets.get(key=TOGETHER_API_KEY))

    contextual_chunks = []
    for chunk in document.chunks or []:
        prompt = CONTEXTUAL_RAG_PROMPT.format(
            WHOLE_DOCUMENT=document.content,
            CHUNK_CONTENT=chunk,
        )
        # One request per chunk, issued sequentially.
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=1,
        )
        contextual_chunks.append(f"{response.choices[0].message.content} {chunk}")

    # An empty result is stored as None rather than an empty list.
    document.contextual_chunks = contextual_chunks or None
    return document

In [6]:
from chromadb import Documents, EmbeddingFunction, Embeddings
from together import Together


class TogetherEmbedding(EmbeddingFunction):
    """Chroma embedding function that delegates to the Together embeddings API.

    The API key is read from the Flyte secret named by `TOGETHER_API_KEY` when
    the instance is constructed.
    """

    def __init__(self, model_name: str):
        api_key = fl.current_context().secrets.get(key=TOGETHER_API_KEY)
        self.client = Together(api_key=api_key)
        self.model = model_name

    def __call__(self, input: Documents) -> Embeddings:
        """Embed `input` with the configured model, one vector per text, in order."""
        response = self.client.embeddings.create(input=input, model=self.model)
        vectors = []
        for item in response.data:
            vectors.append(item.embedding)
        return vectors

In [7]:
@actor.task(cache=True, cache_version="0.19")
def create_vector_index(
    document: Document, model_api_string: str, local: bool = False
) -> Document:
    """Upsert a document's contextual chunks into the Chroma collection.

    Args:
        document: Document whose `contextual_chunks` should be indexed.
        model_api_string: Together embedding model used by the collection.
        local: When True, use an on-disk `PersistentClient`; otherwise connect
            to the in-cluster Chroma app over HTTP.

    Returns:
        The input document, unchanged.

    Raises:
        RuntimeError: If running remotely and the Flyte project/domain
            environment variables needed to build the Chroma host are missing.
    """
    import os
    import chromadb

    # Nothing to index: return before opening a Chroma connection or creating
    # the Together client, both of which are only needed for the upsert.
    if not document.contextual_chunks:
        return document

    if local:
        client = chromadb.PersistentClient()
    else:
        project = os.getenv("FLYTE_INTERNAL_TASK_PROJECT")
        domain = os.getenv("FLYTE_INTERNAL_TASK_DOMAIN")
        if not project or not domain:
            # Otherwise the host would be built as "...None-None..." and fail
            # later with an unrelated-looking connection error.
            raise RuntimeError(
                "FLYTE_INTERNAL_TASK_PROJECT and FLYTE_INTERNAL_TASK_DOMAIN must be "
                "set to locate the Chroma DB app"
            )
        # NOTE: Hard-coding the value for now; dynamic endpoint retrieval will be supported soon.
        client = chromadb.HttpClient(
            host=f"http://contextual-rag-chroma-db-app.{project}-{domain}.svc.cluster.local",
        )

    collection = client.get_or_create_collection(
        name="paul-graham-collection",
        metadata={"hnsw:space": "cosine", "hnsw:search_ef": 50},
        embedding_function=TogetherEmbedding(model_name=model_api_string),
    )

    # IDs combine the document index with the chunk position, so re-running the
    # task overwrites the same entries instead of duplicating them.
    ids = [
        f"id{document.idx}_{chunk_idx}"
        for chunk_idx, _ in enumerate(document.contextual_chunks)
    ]
    documents = [
        chunk[:512]  # NOTE: Trimming the chunk for the embedding model's context window
        for chunk in document.contextual_chunks
    ]
    metadatas = [{"title": document.title} for _ in document.contextual_chunks]

    # Add to the collection
    collection.upsert(ids=ids, documents=documents, metadatas=metadatas)

    return document

In [8]:
@actor.task(cache=True, cache_version="0.5")
def create_bm25s_index(documents: list[Document]) -> tuple[FlyteDirectory, FlyteFile]:
    """Build a BM25S index over all contextual chunks and dump them to JSON.

    Args:
        documents: Documents whose `contextual_chunks` should be indexed;
            documents without contextual chunks are skipped.

    Returns:
        A tuple of (directory containing the saved BM25S index, JSON file
        mapping chunk ID to contextual chunk text).
    """
    import json
    import bm25s

    # Key each chunk by the document's own `idx`, not its position in this
    # list, so the IDs match the ones `create_vector_index` writes to Chroma
    # even if the list is ever filtered or reordered.
    data = {
        f"id{document.idx}_{chunk_idx}": contextual_chunk
        for document in documents
        if document.contextual_chunks
        for chunk_idx, contextual_chunk in enumerate(document.contextual_chunks)
    }
    # Dicts preserve insertion order, so corpus positions line up with the
    # order of the JSON entries.
    corpus = list(data.values())

    retriever = bm25s.BM25(corpus=corpus)
    retriever.index(bm25s.tokenize(corpus))

    ctx = fl.current_context()
    working_dir = Path(ctx.working_directory)
    bm25s_index_dir = working_dir / "bm25s_index"
    contextual_chunks_json = working_dir / "contextual_chunks.json"

    retriever.save(str(bm25s_index_dir))

    # Write the data to a JSON file
    with open(contextual_chunks_json, "w", encoding="utf-8") as json_file:
        json.dump(data, json_file, indent=4, ensure_ascii=False)

    return FlyteDirectory(path=bm25s_index_dir), FlyteFile(contextual_chunks_json)

In [9]:
import functools
from dataclasses import dataclass

from dotenv import load_dotenv

# Loads variables from a local .env file into the process environment.
load_dotenv()  # Ensure the secret (together API key) is present in the .env file

# Named artifacts attached to the workflow outputs below and queried again by
# the FastAPI app inputs (`BM25Index.query()` / `ContextualChunksJSON.query()`).
BM25Index = Artifact(name="bm25s-index")
ContextualChunksJSON = Artifact(name="contextual-chunks-json")


@fl.workflow
def build_indices_wf(
    base_url: str = "https://paulgraham.com/",
    articles_url: str = "articles.html",
    model_api_string: str = "BAAI/bge-large-en-v1.5",
    chunk_size: int = 250,
    overlap: int = 30,
    model: str = "meta-llama/Llama-3.2-3B-Instruct-Turbo",
    local: bool = True,
) -> tuple[
    Annotated[FlyteDirectory, BM25Index], Annotated[FlyteFile, ContextualChunksJSON]
]:
    """Scrape the essays, contextualize their chunks, and build both search indices.

    Pipeline: parse index page -> scrape each essay -> chunk -> generate context
    per chunk -> upsert into Chroma and build the BM25S index.

    Args:
        base_url: Site root (must end with "/").
        articles_url: Index page path relative to `base_url`.
        model_api_string: Together embedding model for the vector index.
        chunk_size: Characters per chunk.
        overlap: Characters shared between neighbouring chunks.
        model: Together chat model used to generate chunk context.
        local: Forwarded to `parse_main_page` (limits the run to two essays)
            and to `create_vector_index` (on-disk Chroma instead of the
            in-cluster app).

    Returns:
        The BM25S index directory and the chunk-ID -> text JSON file, annotated
        as the `BM25Index` and `ContextualChunksJSON` artifacts.
    """
    tocs = parse_main_page(base_url=base_url, articles_url=articles_url, local=local)
    # Each per-document stage is fanned out with map_task; scalar settings are
    # bound up front with functools.partial so only `document` is mapped.
    scraped_content = fl.map_task(scrape_pg_essays)(document=tocs)
    chunks = fl.map_task(
        functools.partial(create_chunks, chunk_size=chunk_size, overlap=overlap)
    )(document=scraped_content)
    contextual_chunks = fl.map_task(functools.partial(generate_context, model=model))(
        document=chunks
    )
    # The vector-index step is run for its upsert side effect; its output
    # (`documents`) is not consumed by any later step in this workflow.
    documents = fl.map_task(
        functools.partial(
            create_vector_index, model_api_string=model_api_string, local=local
        )
    )(document=contextual_chunks)
    # The BM25S index is built from `contextual_chunks`, i.e. the same input
    # the vector-index step receives.
    bm25s_index, contextual_chunks_json_file = create_bm25s_index(
        documents=contextual_chunks
    )
    return bm25s_index, contextual_chunks_json_file


# Container returned by `retrieve`: for each query, the texts found by the
# vector search and by the BM25S search respectively.
@dataclass
class RetrievalResults:
    # Taken from the "documents" entry of the Chroma query result.
    vector_results: list[list[str]]
    # Taken from the `.documents` array of the BM25S retrieval result.
    bm25s_results: list[list[str]]


@fl.task
def retrieve(
    bm25s_index: FlyteDirectory,
    contextual_chunks_data: FlyteFile,
    model_api_string: str = "BAAI/bge-large-en-v1.5",
    queries: list[str] = [
        "What to do in the face of uncertainty?",
        "Why won't people write?",
    ],
) -> RetrievalResults:
    """Run each query against the local Chroma collection and the BM25S index.

    Args:
        bm25s_index: Directory containing the saved BM25S index.
        contextual_chunks_data: JSON file mapping chunk ID to contextual chunk.
        model_api_string: Together embedding model; should match the one used
            when the collection was built.
        queries: Query strings to search for. The default list is only read,
            never mutated.

    Returns:
        The top five documents per query from each retriever.
    """
    import json

    import bm25s
    import chromadb
    import numpy as np

    # Number of hits requested from both retrievers.
    top_k = 5

    # Initialize ChromaDB client
    client = chromadb.PersistentClient()

    # Look the collection up by the name `create_vector_index` writes to,
    # instead of taking whichever collection happens to be listed first
    # (which is wrong with several collections and an IndexError with none).
    collection = client.get_collection(
        "paul-graham-collection",
        embedding_function=TogetherEmbedding(model_name=model_api_string),
    )

    # Perform vector-based retrieval
    vector_idx_result = collection.query(
        query_texts=queries,
        n_results=top_k,
    )

    # Load BM25S index; `load` builds the retriever from the saved files, so no
    # throwaway instance is needed.
    bm25_index = bm25s.BM25.load(save_dir=bm25s_index.download())

    # Load contextual chunk data
    with open(contextual_chunks_data, "r", encoding="utf-8") as json_file:
        contextual_chunks_data_dict = json.load(json_file)

    # Perform BM25S-based retrieval; passing the corpus makes the result carry
    # chunk texts rather than bare indices.
    bm25s_idx_result = bm25_index.retrieve(
        query_tokens=bm25s.tokenize(queries),
        k=top_k,
        corpus=np.array(list(contextual_chunks_data_dict.values())),
    )

    # Return results as a dataclass
    return RetrievalResults(
        vector_results=vector_idx_result["documents"],
        bm25s_results=bm25s_idx_result.documents.tolist(),
    )


# Local smoke run: build both indices with the workflow defaults (local=True),
# then run the default sample queries against them and print the results.
if __name__ == "__main__":
    bm25s_index, contextual_chunks_data = build_indices_wf()
    results = retrieve(
        bm25s_index=bm25s_index, contextual_chunks_data=contextual_chunks_data
    )
    print(results)

  from .autonotebook import tqdm as notebook_tqdm


odict_keys(['self', 'input'])
odict_keys(['self', 'input'])


                                                     

RetrievalResults(vector_results=[['In the face of uncertainty, make choices that give you more options in the future, providing "uncertainty-proof" future flexibility. do in the face of uncertainty is to make choices that are uncertainty-proof. The less sure you are about what to do, the more important it is to choose options that give you more options in the future. I call this "staying upwind." If you\'re unsure w', 'The chunk advises considering options that are uncertain-proof by choosing "upwind" options - essentially, options that will give you more options or flexibility in the future, rather than limiting your options as much. This can be thought of as investing in areas with less่าง overlap or commitment, so you can more easily switch or pivot later if needed. ng upwind." If you\'re unsure whether to major in math or economics, for example, choose math; math is upwind of economics in the sense that it will be easi', 'The speaker is criticizing the conventional advice of attend



## Remote execution


In [10]:
# Log the Union CLI in to the demo host using the device-flow auth mode.
!union create login --auth device-flow --host demo.hosted.unionai.cloud

Login successful into demo.hosted.unionai.cloud


In [11]:
import os

from dotenv import load_dotenv
from flytekit.configuration import Config
from union.app import Endpoint
from union.remote._app_remote import AppRemote

# NOTE(review): if REGISTRY is missing, os.getenv("REGISTRY") below returns None
# — confirm that is acceptable for ImageSpec or fail fast here.
load_dotenv()  # Ensure you add REGISTRY to the .env file.


# Chroma DB server app. Its name must stay in sync with the host that
# `create_vector_index` builds ("contextual-rag-chroma-db-app.<project>-<domain>...").
chroma_app = Endpoint(
    name="contextual-rag-chroma-db-app",
    container_image=fl.ImageSpec(
        name="contextual-rag-chroma-db",
        registry=os.getenv("REGISTRY"),
        # NOTE(review): chromadb is unpinned here while the task image pins
        # chromadb==0.5.23 — consider pinning to keep client and server aligned.
        packages=["union-runtime>=0.1.5", "chromadb"],
    ),
    limits=fl.Resources(cpu="3", mem="5Gi"),
    port=8080,
    # Exactly one replica is requested (min == max == 1).
    min_replicas=1,
    max_replicas=1,
    command=["chroma", "run", "--port", "8080"],
)

# Client used to deploy apps into the demo/development project-domain.
app_remote = AppRemote(
    config=Config.for_endpoint(endpoint="demo.hosted.unionai.cloud"),
    project="demo",
    domain="development",
)

# Deploy the Chroma app (or update it if it already exists).
app_remote.create_or_update(chroma_app)

[34mImage samhitaalla/contextual-rag-chroma-db:O2I7a4vG4wQykMJAZVwR_A found. Skip building.[0m


In [13]:
from union.remote import UnionRemote
from flytekit.configuration import Config

# Remote handle used below to execute the workflow and register the launch
# plan; targets the same host/project/domain as `app_remote`.
remote = UnionRemote(
    config=Config.for_endpoint(endpoint="demo.hosted.unionai.cloud"),
    default_project="demo",
    default_domain="development",
)

In [None]:
# Run the indexing workflow on the cluster. local=False overrides the workflow
# default so all essays are processed and the in-cluster Chroma app is used.
indices_execution = remote.execute(build_indices_wf, inputs={"local": False})
# Link to the execution in the Union console.
print(indices_execution.execution_url)

[34mImage ghcr.io/unionai-oss/contextual-rag:wFn1_Qnfqo7_7HcL041hyA found. Skip building.[0m
https://demo.hosted.unionai.cloud/console/projects/demo/domains/development/executions/a4wcn6p2tgkqbzf88cvr


In [15]:
# Scheduled launch plan for the indexing workflow. The cron expression
# "0 1 * * *" fires once a day at 01:00.
# NOTE(review): the schedule's time zone is not set here — confirm the default.
# NOTE(review): no inputs are overridden, so scheduled runs use the workflow
# defaults, including local=True — confirm this is intended for remote runs.
lp = fl.LaunchPlan.get_or_create(
    build_indices_wf,
    name="vector_db_ingestion",
    schedule=fl.CronSchedule(
        schedule="0 1 * * *"
    ),  # Run every day to update the databases
)

# Register under an explicit version, then activate so the schedule takes effect.
registered_lp = remote.register_launch_plan(
    entity=lp, version="v1"
)  # Issue: https://github.com/flyteorg/flyte/issues/6062
remote.activate_launchplan(registered_lp.id)

## Deploy apps

In [None]:
from union.app import App, Endpoint, Input

# FastAPI retrieval service. It receives the two index artifacts produced by
# `build_indices_wf` and the internal endpoint of the Chroma app, each exposed
# to the container through the environment variable named in `env_name`.
fastapi_app = Endpoint(
    name="contextual-rag-fastapi-app",
    inputs=[
        Input(
            name="bm25s_index",
            value=BM25Index.query(),
            auto_download=True,
            env_name="BM25S_INDEX",
        ),
        Input(
            name="contextual_chunks_json",
            value=ContextualChunksJSON.query(),
            auto_download=True,
            env_name="CONTEXTUAL_CHUNKS_JSON",
        ),
        Input(
            name="chroma_db_endpoint",
            value=chroma_app.query_endpoint(public=False),
            env_name="CHROMA_DB_ENDPOINT",
        ),
    ],
    container_image=fl.ImageSpec(
        name="contextual-rag-fastapi",
        registry=os.getenv("REGISTRY"),
        # NOTE(review): these packages are unpinned, unlike the task image —
        # consider pinning for reproducible builds.
        packages=[
            "together",
            "bm25s",
            "chromadb",
            "fastapi[standard]",
            "union-runtime>=0.1.5",
        ],
    ),
    limits=fl.Resources(cpu="3", mem="10Gi"),
    port=8080,
    include=["./fastapi_app.py"],
    # NOTE(review): `fastapi dev` is the development server and is given no
    # file path — confirm it discovers fastapi_app.py and is the intended mode.
    command=["fastapi", "dev", "--port", "8080"],
    min_replicas=1,
    max_replicas=1,
)


# Gradio front end; it reaches the FastAPI service through the internal
# endpoint injected as FASTAPI_ENDPOINT.
gradio_app = App(
    name="contextual-rag-gradio-app",
    inputs=[
        Input(
            name="fastapi_endpoint",
            value=fastapi_app.query_endpoint(public=False),
            env_name="FASTAPI_ENDPOINT",
        )
    ],
    container_image=fl.ImageSpec(
        name="contextual-rag-gradio",
        registry=os.getenv("REGISTRY"),
        packages=["gradio", "union-runtime>=0.1.5"],
    ),
    limits=fl.Resources(cpu="1", mem="1Gi"),
    port=8080,
    include=["./gradio_app.py"],
    command=[
        "python",
        "gradio_app.py",
    ],
    min_replicas=1,
    max_replicas=1,
)

# Deploy the backend first, then the UI that depends on its endpoint.
app_remote.create_or_update(fastapi_app)
app_remote.create_or_update(gradio_app)

[34mImage samhitaalla/contextual-rag-fastapi:_xIA2c7PUskwr4MKatBNTQ found. Skip building.[0m


[34mImage samhitaalla/contextual-rag-gradio:OeZ_3noh7WUu8i4JAnEzVQ found. Skip building.[0m


In [17]:
# app_remote.stop(name="contextual-rag-fastapi-app")
# app_remote.stop(name="contextual-rag-gradio-app")