### Imports and Setup

In [14]:
from llama_index.core.schema import NodeWithScore
from llama_index.core.workflow import Event


class RetrieverEvent(Event):
    """Result of running retrieval"""

    nodes: list[NodeWithScore]


class RerankEvent(Event):
    """Result of running reranking on retrieved nodes"""

    nodes: list[NodeWithScore]


In [15]:
from llama_index.core import (
    Document,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.response_synthesizers import CompactAndRefine
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step

RETRIEVER_TOP_N = 5
RERANKER_TOP_N = 3

class SpeechRAGWorkflow(Workflow):

    def __init__(self, llm, embed_model, reranker_model, index_path=None):
        super().__init__()
        self.index_id = "vector_index_for_speech"
        self.llm = llm
        self.embed_model = embed_model
        self.reranker_model = reranker_model
        self.index_path = index_path
        if self.index_path is None:
            self.index = None
        else:
            storage_context = StorageContext.from_defaults(persist_dir=self.index_path)
            self.index = load_index_from_storage(
                storage_context, index_id=self.index_id
            )

        self.url_mapping = {
            "fy2024_budget_statement.pdf": "https://www.mof.gov.sg/singaporebudget/budget-2024/budget-statement",
            "fy2024_budget_debate_round_up_speech.pdf": "https://www.mof.gov.sg/singaporebudget/budget-2024/budget-debate-round-up-speech",
        }

    @step
    async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None:
        """Entry point to ingest a document, triggered by a StartEvent with `dir_name`."""
        dir_name = ev.get("dir_name")
        save_dir = ev.get("save_dir")
        if not dir_name or not save_dir:
            return None

        # Make index
        documents = SimpleDirectoryReader(dir_name).load_data()
        for document in documents:
            document.metadata["url"] = self.url_mapping.get(
                document.metadata["file_name"]
            )
        for document in documents:
            print(document.metadata)

        self.index = VectorStoreIndex.from_documents(
            documents=documents,
            embed_model=self.embed_model,
        )

        # Save index to disk
        self.index.set_index_id(self.index_id)
        self.index.storage_context.persist(save_dir)

        # Set attributes and return results
        self.index_path = save_dir
        return StopEvent(result=(self.index, self.index_path))

    @step
    async def retrieve(self, ctx: Context, ev: StartEvent) -> RetrieverEvent | None:
        "Entry point for RAG, triggered by a StartEvent with `query`."
        query = ev.get("query")

        if not query:
            return None

        print(f"Query the database with: {query}")

        # store the query in the global context
        await ctx.set("query", query)

        if self.index is None:
            print("Index is empty, load some documents before querying!")
            return None

        retriever = self.index.as_retriever(
            similarity_top_k=RETRIEVER_TOP_N,
            embed_model=self.embed_model,
        )
        nodes = await retriever.aretrieve(query)
        print(f"Retrieved {len(nodes)} nodes.")
        return RetrieverEvent(nodes=nodes)

    @step
    async def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
        # Rerank the nodes
        ranker = self.reranker_model
        print(await ctx.get("query", default=None), flush=True)
        new_nodes = ranker.postprocess_nodes(
            ev.nodes, query_str=await ctx.get("query", default=None)
        )
        print(f"Reranked nodes to {len(new_nodes)}")

        # TODO: Remove later once debugging is done
        # response = None
        print(ev.nodes)
        print("---")

        for node in ev.nodes:
            print(node.metadata)
            print("---------------")
        print()

        print(new_nodes)
        print("---")
        print(new_nodes[0].text)
        print()
        
        return RerankEvent(nodes=new_nodes)

    @step
    async def synthesize(self, ctx: Context, ev: RerankEvent) -> StopEvent:
        """Return a streaming response using reranked nodes."""
        summarizer = CompactAndRefine(llm=self.llm, streaming=True, verbose=True)
        query = await ctx.get("query", default=None)

        response = await summarizer.asynthesize(query, nodes=ev.nodes)
        return StopEvent(result=response)


def get_default_workflow():
    from llama_index.core.postprocessor.llm_rerank import LLMRerank
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI

    llm = OpenAI(model_name="gpt-4o-mini", temperature=0)
    embed_model = OpenAIEmbedding(model_name="text-embedding-3-small")
    reranker_model = LLMRerank(
        choice_batch_size=5, top_n=3, llm=OpenAI(model="gpt-4o-mini", temperature=0)
    )

    workflow = SpeechRAGWorkflow(
        llm=llm,
        embed_model=embed_model,
        reranker_model=reranker_model,
        index_path="data/index_storage_for_speech",
    )


def get_default_workflow():
    from llama_index.core.postprocessor.llm_rerank import LLMRerank
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI

    llm = OpenAI(model_name="gpt-4o-mini", temperature=0)
    embed_model = OpenAIEmbedding(model_name="text-embedding-3-small")
    reranker_model = LLMRerank(
        choice_batch_size=5,
        top_n=RERANKER_TOP_N,
        llm=OpenAI(model="gpt-4o-mini", temperature=0),
    )

    workflow = SpeechRAGWorkflow(
        llm=llm,
        embed_model=embed_model,
        reranker_model=reranker_model,
        index_path="../data/index_storage_for_speech",
    )
    return workflow


In [16]:
workflow = get_default_workflow()

In [17]:
index = workflow.index
retriever = index.as_retriever(similarity_top_k=RETRIEVER_TOP_N)
print(len(retriever._node_ids))

142


### Run Queries

In [18]:
# Run a query
result = await workflow.run(query="What are the key reasons for high inflation over the last two years?")
async for chunk in result.async_response_gen():
    print(chunk, end="", flush=True)

Query the database with: What are the key reasons for high inflation over the last two years?
Retrieved 5 nodes.
What are the key reasons for high inflation over the last two years?
Reranked nodes to 3
[NodeWithScore(node=TextNode(id_='4e5f2999-f56a-42f7-b5bb-1747bd667d43', embedding=None, metadata={'page_label': '3', 'file_name': 'fy2024_budget_debate_round_up_speech.pdf', 'file_path': 'C:\\Users\\weilu\\Desktop\\AI Codes\\budget-rag\\data\\budget_statement_and_speech\\fy2024_budget_debate_round_up_speech.pdf', 'file_type': 'application/pdf', 'file_size': 483783, 'creation_date': '2024-11-20', 'last_modified_date': '2024-11-20', 'url': 'https://www.mof.gov.sg/singaporebudget/budget-2024/budget-debate-round-up-speech'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationsh

In [25]:
result = await workflow.run(query="Which members of parliaments spoke about inflation and cost pressures? State the section number where this info can be found.")
async for chunk in result.async_response_gen():
    print(chunk, end="", flush=True)

Query the database with: Which members of parliaments spoke about inflation and cost pressures? State the section number where this info can be found.
Retrieved 5 nodes.
Which members of parliaments spoke about inflation and cost pressures? State the section number where this info can be found.
Reranked nodes to 3
[NodeWithScore(node=TextNode(id_='4e5f2999-f56a-42f7-b5bb-1747bd667d43', embedding=None, metadata={'page_label': '3', 'file_name': 'fy2024_budget_debate_round_up_speech.pdf', 'file_path': 'C:\\Users\\weilu\\Desktop\\AI Codes\\budget-rag\\data\\budget_statement_and_speech\\fy2024_budget_debate_round_up_speech.pdf', 'file_type': 'application/pdf', 'file_size': 483783, 'creation_date': '2024-11-20', 'last_modified_date': '2024-11-20', 'url': 'https://www.mof.gov.sg/singaporebudget/budget-2024/budget-debate-round-up-speech'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_key

In [26]:
result = await workflow.run(query="Which members of parliaments spoke about housing and transport costs? State the section number where this info can be found.")
async for chunk in result.async_response_gen():
    print(chunk, end="", flush=True)

Query the database with: Which members of parliaments spoke about housing and transport costs? State the section number where this info can be found.
Retrieved 5 nodes.
Which members of parliaments spoke about housing and transport costs? State the section number where this info can be found.
Reranked nodes to 3
[NodeWithScore(node=TextNode(id_='2a866236-f144-47c9-8f60-2dc1bbcb9160', embedding=None, metadata={'page_label': '6', 'file_name': 'fy2024_budget_debate_round_up_speech.pdf', 'file_path': 'C:\\Users\\weilu\\Desktop\\AI Codes\\budget-rag\\data\\budget_statement_and_speech\\fy2024_budget_debate_round_up_speech.pdf', 'file_type': 'application/pdf', 'file_size': 483783, 'creation_date': '2024-11-20', 'last_modified_date': '2024-11-20', 'url': 'https://www.mof.gov.sg/singaporebudget/budget-2024/budget-debate-round-up-speech'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=