In [1]:
!union create login --auth device-flow --host demo.hosted.unionai.cloud

Login successful into demo.hosted.unionai.cloud


In [2]:
from typing import Optional, Annotated

from pydantic import BaseModel
from pathlib import Path
import requests
from urllib.parse import urljoin

import flytekit as fl
from flytekit.core.artifact import Artifact
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from union.actor import ActorEnvironment

# Key (name) of the secret that holds the Together AI API key. This is the lookup
# key passed to fl.Secret / secrets.get, not the credential itself.
TOGETHER_API_KEY = "samhita-together-api-key"

# Actor environment used by the tasks below through the @actor.task decorator.
# It requests the Together secret so the tasks can read it at run time.
actor = ActorEnvironment(
    name="contextual-rag",
    replica_count=50,
    ttl_seconds=120,
    container_image="ghcr.io/unionai-oss/contextual-rag:0.0.1",  # TODO: Map tasks + actors don't seem to work with the latest version of union; replace with an ImageSpec when the fix is in.
    secret_requests=[fl.Secret(key=TOGETHER_API_KEY)],
)

# Artifacts used to annotate the two outputs of build_indices_wf.
BM25Index = Artifact(name="bm25s-index")
ContextualChunksJSON = Artifact(name="contextual-chunks-json")


# One essay as it moves through the pipeline. parse_main_page creates it with
# idx/title/url; later tasks fill in the optional fields and pass it on.
class Document(BaseModel):
    idx: int  # running number assigned by parse_main_page, starting at 0
    title: str  # link text of the essay on the index page
    url: str  # absolute URL of the essay
    content: Optional[str] = None  # whitespace-normalised essay text (scrape_pg_essays)
    chunks: Optional[list[str]] = None  # overlapping character slices of content (create_chunks)
    prompts: Optional[list[str]] = None  # NOTE(review): not written by any task visible in this notebook
    contextual_chunks: Optional[list[str]] = None  # LLM explanation + " " + chunk (generate_context)
    tokens: Optional[list[list[int]]] = None  # NOTE(review): not written by any task visible in this notebook

In [3]:
@actor.task
def parse_main_page(base_url: str, articles_url: str) -> list[Document]:
    """Scrape the article index page and return one Document per linked essay.

    The page at urljoin(base_url, articles_url) is downloaded and parsed. A table
    cell counts as an essay entry when it holds an image whose width and height
    are both at most 15 (a missing attribute counts as 0) and a link nested in a
    font tag. Each entry becomes a Document with a running idx starting at 0,
    the link text as title and the link target resolved against base_url as url.

    base_url has to end with "/" (enforced with an assertion) so that relative
    links resolve against it. A 4xx/5xx answer for the index page raises
    requests.HTTPError. Cells whose image size is not a plain integer (for
    example "100%") are skipped instead of failing the task.
    """
    from bs4 import BeautifulSoup

    def _dimension(img, attribute: str) -> Optional[int]:
        # Integer value of the attribute (0 when absent); None when it is not numeric.
        try:
            return int(img.get(attribute, 0))
        except (TypeError, ValueError):
            return None

    assert base_url.endswith("/"), f"Base URL must end with a slash: {base_url}"
    response = requests.get(urljoin(base_url, articles_url))
    # Same policy as scrape_pg_essays: fail loudly rather than parse an error
    # page into an empty document list.
    response.raise_for_status()
    soup = BeautifulSoup(response.text, "html.parser")

    documents = []
    for td in soup.select("table > tr > td > table > tr > td"):
        img = td.find("img")
        if not img:
            continue

        width = _dimension(img, "width")
        height = _dimension(img, "height")
        if width is None or height is None or width > 15 or height > 15:
            continue

        font = td.find("font")
        a_tag = font.find("a") if font else None
        if a_tag:
            # len(documents) is the running index of the entry being added
            documents.append(
                Document(
                    idx=len(documents),
                    title=a_tag.text,
                    url=urljoin(base_url, a_tag["href"]),
                )
            )

    return documents

In [4]:
@actor.task
def scrape_pg_essays(document: Document) -> Document:
    """Download one essay and store its text on the document.

    The page at document.url is fetched (an HTTP error status raises), the first
    font tag is taken as the essay body, and its text is stored in
    document.content with every run of whitespace collapsed to a single space.
    When the page has no font tag, content is set to None. The same document
    object is returned.
    """
    from bs4 import BeautifulSoup

    page = requests.get(document.url)
    page.raise_for_status()
    body = BeautifulSoup(page.text, "html.parser").find("font")

    # split() + join() normalises newlines, tabs and repeated spaces in one go
    document.content = " ".join(body.get_text().split()) if body else None
    return document

In [5]:
@actor.task(cache=True, cache_version="0.2")
def create_chunks(document: Document, chunk_size: int, overlap: int) -> Document:
    """Split document.content into overlapping character chunks.

    Chunks are slices of chunk_size characters whose start positions advance by
    chunk_size - overlap, so consecutive chunks share overlap characters. The
    list is stored in document.chunks and the same document is returned. A
    document without content is returned unchanged.

    chunk_size must be positive and overlap must satisfy
    0 <= overlap < chunk_size; anything else raises ValueError. Without this
    check an overlap equal to chunk_size fails inside range() with an unrelated
    message, and a larger overlap silently produces no chunks at all.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )

    if document.content:
        stride = chunk_size - overlap  # guaranteed >= 1 by the checks above
        document.chunks = [
            document.content[start : start + chunk_size]
            for start in range(0, len(document.content), stride)
        ]
    return document

In [6]:
@actor.task(cache=True, cache_version="0.4")
def generate_context(document: Document, model: str) -> Document:
    """Prefix every chunk with an LLM-written explanation of its place in the essay.

    For each entry of document.chunks one chat completion is requested from
    Together with the given model, using a prompt that contains the whole
    document and the chunk. The answer and the chunk, joined by a single space,
    form the contextual chunk. The resulting list is stored in
    document.contextual_chunks (None when there are no chunks) and the same
    document is returned.
    """
    from together import Together

    prompt_template = """
Given the document below, we want to explain what the chunk captures in the document.

{WHOLE_DOCUMENT}

Here is the chunk we want to explain:

{CHUNK_CONTENT}

Answer ONLY with a succinct explanation of the meaning of the chunk in the context of the whole document above.
"""

    client = Together(api_key=fl.current_context().secrets.get(key=TOGETHER_API_KEY))

    contextual_chunks = []
    for chunk in document.chunks or []:
        prompt = prompt_template.format(
            WHOLE_DOCUMENT=document.content,
            CHUNK_CONTENT=chunk,
        )
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=1,
        )
        explanation = response.choices[0].message.content
        contextual_chunks.append(f"{explanation} {chunk}")

    # An empty result is stored as None so later tasks can skip the document
    document.contextual_chunks = contextual_chunks or None
    return document

In [31]:
@actor.task(cache=True, cache_version="0.18")
def create_vector_index(document: Document, model_api_string: str) -> Document:
    """Upsert the document's contextual chunks into the Chroma collection.

    The task connects to the Chroma server, gets or creates the
    "paul-graham-collection" collection (cosine space, embeddings computed by
    Together with the model named by model_api_string) and, when the document
    has contextual chunks, upserts one entry per chunk. Entry ids have the form
    id<document.idx>_<chunk position>, the stored text is the first 512
    characters of the chunk and the metadata carries the document title. The
    document itself is returned unchanged.
    """
    import chromadb
    from chromadb import Documents, EmbeddingFunction, Embeddings
    from together import Together

    class TogetherEmbedding(EmbeddingFunction):
        # Chroma embedding function that delegates to the Together embeddings API.

        def __init__(self, model_name: str):
            self.model = model_name
            self.client = Together(
                api_key=fl.current_context().secrets.get(key=TOGETHER_API_KEY)
            )

        def __call__(self, input: Documents) -> Embeddings:
            result = self.client.embeddings.create(input=input, model=self.model)
            return [item.embedding for item in result.data]

    chroma = chromadb.HttpClient(
        host='http://contextual-rag-chroma-db-app.demo-development.svc.cluster.local',
        port=8080,
    )  # TODO: Remove when the fix is in.
    embedder = TogetherEmbedding(model_name=model_api_string)
    collection = chroma.get_or_create_collection(
        name="paul-graham-collection",
        metadata={"hnsw:space": "cosine", "hnsw:search_ef": 50},
        embedding_function=embedder,
    )

    chunks = document.contextual_chunks
    if not chunks:
        # Nothing to index; the collection has still been created above
        return document

    chunk_count = len(chunks)
    collection.upsert(
        ids=[f"id{document.idx}_{position}" for position in range(chunk_count)],
        # NOTE: Trimming the chunk for the embedding model's context window
        documents=[text[:512] for text in chunks],
        metadatas=[{"title": document.title} for _ in range(chunk_count)],
    )

    return document

In [32]:
@actor.task(cache=True, cache_version="0.5")
def create_bm25s_index(documents: list[Document]) -> tuple[FlyteDirectory, FlyteFile]:
    """Build a BM25 index over all contextual chunks and export the chunks as JSON.

    Every contextual chunk of every document is collected into a mapping from
    chunk id to chunk text; documents without contextual chunks are skipped.
    Chunk ids have the form id<document.idx>_<chunk position>, the same scheme
    create_vector_index uses, so an id refers to the same chunk in both
    indices. A bm25s index is built over the chunk texts and saved to the
    "bm25s_index" directory inside the task's working directory; the mapping is
    written next to it as "contextual_chunks.json". The task returns the index
    directory and the JSON file, in that order.
    """
    import json
    import bm25s

    # Key by document.idx rather than list position so the ids cannot drift
    # away from the ones stored in the vector index.
    data = {
        f"id{document.idx}_{chunk_idx}": contextual_chunk
        for document in documents
        if document.contextual_chunks
        for chunk_idx, contextual_chunk in enumerate(document.contextual_chunks)
    }

    corpus = list(data.values())
    retriever = bm25s.BM25(corpus=corpus)
    retriever.index(bm25s.tokenize(corpus))

    ctx = fl.current_context()
    working_dir = Path(ctx.working_directory)
    bm25s_index_dir = working_dir / "bm25s_index"
    contextual_chunks_json = working_dir / "contextual_chunks.json"

    retriever.save(str(bm25s_index_dir))

    # ensure_ascii=False keeps non-ASCII characters readable in the JSON file
    with open(contextual_chunks_json, "w", encoding="utf-8") as json_file:
        json.dump(data, json_file, indent=4, ensure_ascii=False)

    return FlyteDirectory(path=bm25s_index_dir), FlyteFile(contextual_chunks_json)

In [33]:
import functools


@fl.workflow
def build_indices_wf(
    base_url: str = "https://paulgraham.com/",
    articles_url: str = "articles.html",
    model_api_string: str = "BAAI/bge-large-en-v1.5",
    chunk_size: int = 250,
    overlap: int = 30,
    model: str = "meta-llama/Llama-3.2-3B-Instruct-Turbo",
) -> tuple[
    Annotated[FlyteDirectory, BM25Index],
    Annotated[FlyteFile, ContextualChunksJSON],
]:
    """Scrape the essays, contextualise their chunks and build both search indices.

    The index page is parsed into documents; each document is then scraped,
    chunked, given LLM context and written to the vector index through map
    tasks. The BM25 index and the chunk JSON are built from the contextualised
    documents and returned as the BM25Index and ContextualChunksJSON artifacts.
    """
    essay_links = parse_main_page(base_url=base_url, articles_url=articles_url)
    essays = fl.map_task(scrape_pg_essays)(document=essay_links)

    chunk_essay = functools.partial(create_chunks, chunk_size=chunk_size, overlap=overlap)
    chunked_essays = fl.map_task(chunk_essay)(document=essays)

    add_context = functools.partial(generate_context, model=model)
    contextualised_essays = fl.map_task(add_context)(document=chunked_essays)

    index_essay = functools.partial(create_vector_index, model_api_string=model_api_string)
    # The output of the vector-index step is not consumed below; the BM25 step
    # reads the generate_context output directly.
    indexed_essays = fl.map_task(index_essay)(document=contextualised_essays)

    bm25s_index, contextual_chunks_json_file = create_bm25s_index(
        documents=contextualised_essays
    )
    return bm25s_index, contextual_chunks_json_file

In [34]:
from union.remote import UnionRemote
from flytekit.configuration import Config

# Client for the Union deployment used by the cells below to run the workflow
# and register the launch plan; it defaults to project "demo", domain "development".
remote = UnionRemote(
    config=Config.for_endpoint(endpoint="demo.hosted.unionai.cloud"),
    default_project="demo",
    default_domain="development",
)

In [35]:
# Launch the workflow through the remote with no input overrides (the workflow's
# defaults apply) and print the URL of the resulting execution.
indices_execution = remote.execute(build_indices_wf, inputs={})
print(indices_execution.execution_url)

https://demo.hosted.unionai.cloud/console/projects/demo/domains/development/executions/ahlg4vzmcs4sshkhs5vb


In [14]:
# Launch plan that runs the ingestion workflow on a cron schedule.
lp = fl.LaunchPlan.get_or_create(
    build_indices_wf,
    name="vector_db_ingestion",
    schedule=fl.CronSchedule(schedule="0 1 * * *"),  # Daily at 01:00 (presumably UTC - confirm) to refresh the databases
)

# Register the launch plan explicitly and then activate it.
registered_lp = remote.register_launch_plan(entity=lp, version="v1")  # Issue: https://github.com/flyteorg/flyte/issues/6062
remote.activate_launchplan(registered_lp.id)