# Retrieval Agent (Minimal)

LLM-controlled query refinement + deterministic hybrid retrieval (Qdrant) + lexical table scoring.


In [1]:
%load_ext autoreload
%autoreload 2

import sys
from typing import Any, Dict, List, Optional

sys.path.insert(0, "../src")

from pydantic import BaseModel, Field, field_validator
from langchain_ollama import ChatOllama
from qdrant_client import QdrantClient

from retrieval.pipeline import FinanceRAGPipeline, PipelineConfig, RetrievalConfig
from retrieval.evaluator import score_and_select_tables


In [2]:
TABLES_DIR = "../data/chunked"
COLLECTION_NAME = "sec_docs_hybrid"
DEFAULT_DOC_TYPES = ["table", "table_row"]

client = QdrantClient(host="localhost", port=6333)
config = PipelineConfig(
    retrieval=RetrievalConfig(collection_name=COLLECTION_NAME),
)
pipeline = FinanceRAGPipeline(client, config)


Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [3]:
class RetrievalRefinement(BaseModel):
    queries: List[str] = Field(
        ...,
        description="List of 1-4 retrieval queries (short accounting-term style strings).",
    )

    @field_validator("queries")
    @classmethod
    def _validate_queries(cls, v):
        v = [str(x).strip() for x in v if str(x).strip()]
        if not v:
            raise ValueError("queries must be non-empty")
        return v[:4]


retrieval_llm = ChatOllama(model="qwen3:4b-instruct", temperature=0)
structured_refiner = retrieval_llm.with_structured_output(RetrievalRefinement)


def _summarize_scored_tables(scored_tables: List[Dict[str, Any]], limit: int = 3) -> str:
    if not scored_tables:
        return "No tables scored."
    lines = []
    for t in scored_tables[:limit]:
        name = t.get("table_name")
        score = t.get("total_score")
        headers = t.get("row_headers") or []
        headers_preview = ", ".join([str(h) for h in headers[:20]])
        lines.append(f"table_name={name} total_score={score} row_headers=[{headers_preview}]")
    return "\n".join(lines)


In [4]:
def retrieval_agent_search(
    query: str,
    *,
    ticker: str,
    fiscal_year: int,
    form_type: str = "10-K",
    doc_types: Optional[List[str]] = None,
    min_total_score: int = 15,
    max_attempts: int = 2,
) -> Dict[str, Any]:
    doc_types = doc_types or DEFAULT_DOC_TYPES

    queries: List[str] = [query]
    best_scored: List[Dict[str, Any]] = []
    best_rerank_query = ""

    for attempt in range(max_attempts):
        rerank_query, _fused, reranked = pipeline.run_hybrid_search_pipeline(
            queries,
            ticker=ticker,
            fiscal_year=fiscal_year,
            form_type=form_type,
            doc_types=doc_types,
        )
        
        scored = score_and_select_tables(
            reranked,
            queries,
            str(fiscal_year),
            tables_dir=TABLES_DIR,
        )
        
        # if scored:
        #     best_scored = scored
        #     best_rerank_query = rerank_query
        #     if scored[0].get("total_score", -1) >= min_total_score:
        #         break

        if attempt == max_attempts - 1:
            break

        summary = _summarize_scored_tables(scored)
        prompt = f"""
        You are a retrieval specialist for SEC financial tables.
        
        User query: {query!r}
        Ticker: {ticker} | Fiscal year: {fiscal_year} | Form: {form_type}
        
        Current retrieval queries: {queries}
        Current top table signals (may be weak):
        {summary}
        
        Task: Propose 1-4 improved retrieval queries focused on exact accounting terms/row labels.
        Rules:
        - Keep queries short and specific (e.g., "Term debt", "Commercial paper", "Total liabilities").
        - Do not include the company name or the year unless necessary.
        """
        refined = structured_refiner.invoke(prompt)
        queries = refined.queries

    max_score = best_scored[0]["total_score"] if best_scored else None
    return {
        "queries_used": queries,
        "rerank_query": best_rerank_query,
        "top_tables": best_scored[:3],
        "max_total_score": max_score,
        "metadata_used": {"ticker": ticker, "fiscal_year": fiscal_year, "form_type": form_type},
    }


In [5]:
# Example
result = retrieval_agent_search(
    "What was the total debt as of 2024 year end?",
    ticker="AAPL",
    fiscal_year=2024,
    form_type="10-K",
)
result

You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'queries_used': ['Total term debt',
  'Total non-current debt',
  'Debt, less current portion',
  'Total liabilities'],
 'rerank_query': '',
 'top_tables': [],
 'max_total_score': None,
 'metadata_used': {'ticker': 'AAPL', 'fiscal_year': 2024, 'form_type': '10-K'}}

## Retrieval Agent - Enhanced

### Retrieval Tool

In [36]:
# tool
from __future__ import annotations

from typing import Any, Dict, List, Optional
import os

from pydantic import BaseModel, Field, field_validator

# Import your existing functions/modules
from retrieval.pipeline import FinanceRAGPipeline  # expects pipeline.run_hybrid_search_pipeline(...)
from retrieval.evaluator import score_and_select_tables

TABLES_DIR = "../data/chunked"
# os.getenv("TABLES_DIR", "PATH/TO/YOUR/TABLES_DIR")
DEFAULT_DOC_TYPES = ["table"]  # replace with your real defaults
COLLECTION_NAME = "sec_docs_hybrid"

client = QdrantClient(host="localhost", port=6333)
config = PipelineConfig(
retrieval=RetrievalConfig(collection_name=COLLECTION_NAME),
)
pipeline = FinanceRAGPipeline(client, config)

class RetrieveTablesResponse(BaseModel):
    ok: bool = True
    queries_used: List[str]
    rerank_query: str
    top_tables: List[Dict[str, Any]] = Field(default_factory=list)
    max_total_score: Optional[int] = None
    metadata_used: Dict[str, Any] = Field(default_factory=dict)
    error: Optional[str] = None
    trace: Dict[str, Any] = None

class RetrievalQueries(BaseModel):
    queries: List[str] = Field(..., description="1-4 short retrieval queries")

    @field_validator("queries")
    @classmethod
    def _validate_queries(cls, v: List[str]) -> List[str]:
        v = [str(x).strip() for x in v if str(x).strip()]
        if not v:
            raise ValueError("queries must be non-empty")
        return v[:4]

def sec_retrieve_tables(
    *,
    queries: List[str],
    ticker: str,
    fiscal_year: int,
    form_type: str = "10-K",
    doc_types: Optional[List[str]] = None,
    top_k: int = 3,
    min_total_score: int = 0,
) -> RetrieveTablesResponse:
    """
    Deterministic SEC table retrieval:
    hybrid retrieval + rerank + lexical scoring (score_and_select_tables).
    """
    try:
        doc_types = doc_types or DEFAULT_DOC_TYPES
        queries = RetrievalQueries(queries=queries).queries

        t0 = time.time()
        rerank_query, _fused, reranked = pipeline.run_hybrid_search_pipeline(
            queries=queries,
            ticker=ticker,
            fiscal_year=fiscal_year,
            form_type=form_type,
            doc_types=doc_types,
        )
        t1 = time.time()
        scored = score_and_select_tables(
            reranked,
            queries,
            str(fiscal_year),
            tables_dir=TABLES_DIR,
        )

        # apply min score + top_k
        scored = [t for t in scored if (t.get("total_score") or 0) >= min_total_score]
        top_tables = scored[:top_k]
        max_score = (top_tables[0].get("total_score") if top_tables else None)
        t2 = time.time()
        
        return RetrieveTablesResponse(
            ok=True,
            queries_used=queries,
            rerank_query=rerank_query,
            top_tables=top_tables,
            max_total_score=max_score,
            metadata_used={"ticker": ticker, "fiscal_year": fiscal_year, "form_type": form_type},
            trace = {
                "timing_ms": {
                    "hybrid_plus_rerank": int((t1 - t0) * 1000),
                    "lexical_scoring": int((t2 - t1) * 1000),
                    "total": int((t2 - t0) * 1000),
                },
                "counts": {
                    "fused_candidates": len(_fused) if _fused is not None else None,
                    "reranked": len(reranked) if reranked is not None else None,
                    "scored": len(scored),
                },
            }
        )
    except Exception as e:
        return RetrieveTablesResponse(
            ok=False,
            queries_used=queries[:4] if isinstance(queries, list) else [],
            rerank_query="",
            top_tables=[],
            max_total_score=None,
            metadata_used={"ticker": ticker, "fiscal_year": fiscal_year, "form_type": form_type},
            error=str(e),
        )


In [34]:
expanded_queries = ['What was Apple’s total debt (short-term plus long-term) at year-end 2024?', 'Total liabilities', 'Total term debt', 'Commercial paper', 'Less: Current portion of term debt', 'Total non-current liabilities']
result = sec_retrieve_tables(
                        queries= expanded_queries,
                        ticker='AAPL',
                        fiscal_year=2024,
                        form_type="10-K",
                            )

In [35]:
result

RetrieveTablesResponse(ok=True, queries_used=['What was Apple’s total debt (short-term plus long-term) at year-end 2024?', 'Total liabilities', 'Total term debt', 'Commercial paper'], rerank_query='What was Apple’s total debt (short-term plus long-term) at year-end 2024? (Total liabilities, Total term debt, Commercial paper)', top_tables=[{'table': ScoredPoint(id='44965772-e87d-5be6-b709-21ff3e346590', version=1, score=-0.8220280408859253, payload={'doc_id': 'AAPL_10-K_2024::table::12', 'content': 'Table summary: The table presents consolidated balance sheet information for Apple Inc. as of September 28, 2024 and September 30, 2023, including assets, liabilities, and shareholders’ equity.\nRows: September 28, 2024: Date for the balance sheet as of September 28, 2024. September 30, 2023: Date for the balance sheet as of September 30, 2023. ASSETS: – Current assets: Section header indicating current assets for both fiscal years. Cash and cash equivalents – Current assets: Current assets 

### Retrieval Agent

In [47]:
import asyncio
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client


@dataclass
class SecRetrievalMCPClient:
    server_command: str = "python"
    server_args: Optional[List[str]] = None

    _session: Optional[ClientSession] = None
    _read = None
    _write = None
    _stdio_cm = None
    _session_cm = None

    async def __aenter__(self):
        if self.server_args is None:
            # If kernel cwd is repo root, use src/tools/server.py.
            # If cwd is notebooks/, use ../src/tools/server.py.
            p1 = Path("src/tools/server.py")
            p2 = Path("../src/tools/server.py")
            server_path = p1 if p1.exists() else p2
            self.server_args = [str(server_path)]

        server_params = StdioServerParameters(
            command=self.server_command,
            args=self.server_args,
        )

        self._stdio_cm = stdio_client(server_params)
        self._read, self._write = await self._stdio_cm.__aenter__()

        self._session_cm = ClientSession(self._read, self._write)
        self._session = await self._session_cm.__aenter__()
        await self._session.initialize()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self._session_cm:
            await self._session_cm.__aexit__(exc_type, exc, tb)
        if self._stdio_cm:
            await self._stdio_cm.__aexit__(exc_type, exc, tb)

    async def retrieve_tables(
        self,
        *,
        queries: List[str],
        ticker: str,
        fiscal_year: int,
        form_type: str = "10-K",
        doc_types: Optional[List[str]] = None,
        top_k: int = 3,
        min_total_score: int = 0,
        timeout_s: float = 120.0,
    ) -> Dict[str, Any]:
        assert self._session is not None, "Client not initialized. Use 'async with'."

        args = {
            "queries": queries,
            "ticker": ticker,
            "fiscal_year": fiscal_year,
            "form_type": form_type,
            "doc_types": doc_types,
            "top_k": top_k,
            "min_total_score": min_total_score,
        }

        async def _call():
            result = await self._session.call_tool("sec_retrieve_tables", arguments=args)

            # MCP SDK compatibility: some versions expose camelCase fields.
            structured = getattr(result, "structured_content", None)
            if structured is None:
                structured = getattr(result, "structuredContent", None)
            if structured is not None:
                return structured

            is_error = bool(
                getattr(result, "is_error", False) or getattr(result, "isError", False)
            )

            out_text = []
            for block in getattr(result, "content", []) or []:
                if isinstance(block, types.TextContent):
                    out_text.append(block.text)
                    # Some servers return JSON as text; parse first valid dict/list.
                    try:
                        parsed = json.loads(block.text)
                        if isinstance(parsed, (dict, list)):
                            return parsed
                    except Exception:
                        pass

            return {
                "ok": not is_error,
                "unstructured": out_text,
                "args": args,
            }

        try:
            return await asyncio.wait_for(_call(), timeout=timeout_s)
        except asyncio.TimeoutError:
            return {
                "ok": False,
                "error": (
                    f"MCP tool call timed out after {timeout_s:.0f}s. "
                    "First run may need to load embedding/reranker models."
                ),
                "args": args,
            }



In [48]:
from typing import Any, Dict

async def retrieval_agent(state: Dict[str, Any], client) -> Dict[str, Any]:
    """
    Crawl version: call tool once using orchestrator-produced queries+metadata.
    """
    resp = await client.retrieve_tables(
        queries=state["queries"],
        ticker=state["ticker"],
        fiscal_year=state["fiscal_year"],
        form_type=state.get("form_type", "10-K"),
        doc_types=state.get("doc_types"),
        top_k=3,
        min_total_score=0,
    )
    # Attach results back to state (or return a sub-dict if your orchestrator merges outputs)
    return {**state, "retrieval": resp}


In [50]:
import asyncio

async def main():
  state = {
      "queries": ['What was Apple’s total debt (short-term plus long-term) at year-end 2024?', 'Total liabilities', 'Total term debt', 'Commercial paper', 'Less: Current portion of term debt', 'Total non-current liabilities'],
      "ticker": "AAPL",
      "fiscal_year": 2024,
      "form_type": "10-K",
  }

  async with SecRetrievalMCPClient() as client:
      out = await retrieval_agent(state, client)
      print(out["retrieval"])

await main()  # for notebook

{'ok': True, 'queries_used': ['What was Apple’s total debt (short-term plus long-term) at year-end 2024?', 'Total liabilities', 'Total term debt', 'Commercial paper'], 'rerank_query': 'What was Apple’s total debt (short-term plus long-term) at year-end 2024? (Total liabilities, Total term debt, Commercial paper)', 'top_tables': [{'table': {'id': '44965772-e87d-5be6-b709-21ff3e346590', 'version': 1, 'score': -0.8220280408859253, 'payload': {'doc_id': 'AAPL_10-K_2024::table::12', 'content': 'Table summary: The table presents consolidated balance sheet information for Apple Inc. as of September 28, 2024 and September 30, 2023, including assets, liabilities, and shareholders’ equity.\nRows: September 28, 2024: Date for the balance sheet as of September 28, 2024. September 30, 2023: Date for the balance sheet as of September 30, 2023. ASSETS: – Current assets: Section header indicating current assets for both fiscal years. Cash and cash equivalents – Current assets: Current assets consistin

### Calling our tool / agent

In [3]:
import sys
sys.path.insert(0, "../src")  # use "src" if notebook kernel cwd is repo root

from agents.retrieval.mcp_client import SecRetrievalMCPClient
from agents.retrieval.agent import retrieval_agent

state = {
  "queries": ['What was Apple’s total debt (short-term plus long-term) at year-end 2024?', 'Total liabilities', 'Total term debt', 'Commercial paper', 'Less: Current portion of term debt', 'Total non-current liabilities'],
  "ticker": "AAPL",
  "fiscal_year": 2024,
  "form_type": "10-K",          # optional (defaults to 10-K)
  # "doc_types": ["table"],      # optional
}

async def run_retrieval(state):
  async with SecRetrievalMCPClient() as client:
      return await retrieval_agent(state, client)

out = await run_retrieval(state)
out["retrieval"]

{'ok': True,
 'queries_used': ['What was Apple’s total debt (short-term plus long-term) at year-end 2024?',
  'Total liabilities',
  'Total term debt',
  'Commercial paper'],
 'rerank_query': 'What was Apple’s total debt (short-term plus long-term) at year-end 2024? (Total liabilities, Total term debt, Commercial paper)',
 'top_tables': [{'table': {'id': '44965772-e87d-5be6-b709-21ff3e346590',
    'version': 1,
    'score': -0.8220280408859253,
    'payload': {'doc_id': 'AAPL_10-K_2024::table::12',
     'content': 'Table summary: The table presents consolidated balance sheet information for Apple Inc. as of September 28, 2024 and September 30, 2023, including assets, liabilities, and shareholders’ equity.\nRows: September 28, 2024: Date for the balance sheet as of September 28, 2024. September 30, 2023: Date for the balance sheet as of September 30, 2023. ASSETS: – Current assets: Section header indicating current assets for both fiscal years. Cash and cash equivalents – Current assets