In [1]:
# Default setup cell — copy it into other notebooks.
# It enables module auto-reload, clears the interactive namespace, imports
# `devtools.debug` and `rich.print` (the latter shadows the built-in `print`),
# and loads environment variables from a `.env` file.
# Note: this cell does not set PYTHONPATH; `src.` imports presumably resolve
# from the kernel's working directory — TODO confirm.

%load_ext autoreload
# Re-import edited modules automatically before each cell runs.
%autoreload 2
# Wipe variables left over from a previous run (-f: no confirmation prompt).
%reset -f

from devtools import debug  # noqa: F401  # noqa: F811
from dotenv import load_dotenv
from rich import print  # noqa: F401

# Fail fast if load_dotenv reports that nothing was loaded (e.g. no .env file found).
assert load_dotenv(verbose=True)

In [2]:
from src.utils.config_mngr import global_config

# `schema_name` of the entry in the `Document_extractor_demo` list
# (config/schemas/document_extractor.yaml) that this notebook works with.
SCHEMA_NAME = "Rainbow File"

list_demos = (
    global_config(reload=True).merge_with("config/schemas/document_extractor.yaml").get_list("Document_extractor_demo")
)

# Pick the demo entry for SCHEMA_NAME; fail with an explicit message (listing the
# available names) instead of an opaque StopIteration when it is missing.
test_schema = next((item for item in list_demos if item.get("schema_name") == SCHEMA_NAME), None)
if test_schema is None:
    available_names = [item.get("schema_name") for item in list_demos]
    raise ValueError(
        f"No entry with schema_name={SCHEMA_NAME!r} in 'Document_extractor_demo'; available: {available_names}"
    )

In [None]:
from src.demos.ekg.struct_rag_doc_processing import StructuredRagConfig, StructuredRagDocProcessor
from src.demos.ekg.struct_rag_tool_factory import StructuredRagToolFactory

# Identifier of the key-value store backend passed to the RAG config.
KV_STORE = "file"


# One config object is shared by the document processor and the tool factory,
# so both are built from the same schema, vector store factory and KV store.
vector_store_factory = StructuredRagConfig.get_vector_store_factory()
struct_rag_conf = StructuredRagConfig(
    model_definition=test_schema,
    vector_store_factory=vector_store_factory,
    kvstore_id=KV_STORE,
    llm_id=None,  # NOTE(review): None presumably selects the default LLM — confirm
)
rainbow_rag_processor, rainbow_tool_factory = (
    StructuredRagDocProcessor(rag_conf=struct_rag_conf),
    StructuredRagToolFactory(rag_conf=struct_rag_conf),
)

In [None]:
import os
from pathlib import Path

doc_id = "03.RESM-SOL-9000559500_CNES_TMA_VENUS_VIP_PEPS_THEIA_MUSCATE-v0.2"

# Root of the local OneDrive folder. Without it the path below would silently
# become relative to the current directory, so require it explicitly.
onedrive_root = os.getenv("ONEDRIVE")
assert onedrive_root, "Environment variable ONEDRIVE is not set (add it to your .env file)"

file1 = Path(onedrive_root) / "prj/atos-kg/rainbow-json" / f"{doc_id}_extracted.json"
assert file1.exists(), f"Extracted document not found: {file1}"
# Explicit encoding: read_text() otherwise uses the platform locale encoding
# (e.g. cp1252 on Windows). Assumes the extracted JSON is UTF-8 (RFC 8259).
doc_text = file1.read_text(encoding="utf-8")

# NOTE(review): the raw JSON text is passed as `markdown` — confirm this is intended.
rainbow_report = rainbow_rag_processor.analyze_document(
    document_id=doc_id,
    markdown=doc_text,
)

print("Structured result:", rainbow_report)

assert rainbow_report, "analyze_document returned an empty result"

In [None]:
# Split the structured report into chunks for indexing (the processor method is
# spelled `chunck`). Inspect the result with `debug(chunks)` if needed.
chunks = rainbow_rag_processor.chunck(rainbow_report)

In [None]:
from langchain_core.utils.function_calling import convert_to_openai_tool

# Build the LangChain vector-search tool and print the OpenAI tool schema it
# converts to, i.e. roughly what an OpenAI-style LLM sees when the tool is bound.
dyn_tool = rainbow_tool_factory.create_vector_search_lc_tool()
debug(convert_to_openai_tool(dyn_tool))

In [None]:
# Direct tool call (no agent): search the "team" section for "CNES" with an
# empty entity_keys list.
# NOTE(review): this cell runs before store_chunks() further down, so on an empty
# vector store it presumably returns no hits — confirm the intended cell order.
search_args = {
    "query": "CNES",
    "selected_sections": ["team"],
    "entity_keys": [],
}
r = dyn_tool.invoke(search_args)
print(r)

In [None]:
# Copy what is held in the key-value store into the vector store
# (inferred from the method name — TODO confirm what kv_to_vector_store does).
rainbow_rag_processor.kv_to_vector_store()

In [None]:
# Index the document: store the chunks computed above.
# NOTE(review): presumably writes them to the vector store queried in the next cells — confirm.
rainbow_rag_processor.store_chunks(chunks)
print("Document stored.")

In [None]:
# Query the vector store with free text only (k=2 results, no metadata filter).
hits = rainbow_tool_factory.query_vectorstore("e-mail address", k=2)
print("Vector hits:", hits)

In [None]:
# Query the vector store with a metadata filter: only chunks whose `field_name`
# equals "financials" (`$eq` filter syntax) are considered, k=2 results.
hits = rainbow_tool_factory.query_vectorstore("revenue", k=2, filter={"field_name": {"$eq": "financials"}})
print("Vector hits:", hits)

In [None]:
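# Optional cleanup (destructive): uncomment to delete the vector-store collection,
# presumably discarding every indexed chunk.
# vector_store_factory.delete_collection()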

In [None]:
from smolagents import CodeAgent, LiteLLMModel, Tool

from src.ai_core.llm_factory import LlmFactory

# NOTE(review): None presumably selects the project's default LLM — confirm.
MODEL_ID = None
TEMPERATURE = 0.7

llm_factory = LlmFactory(llm_id=MODEL_ID, llm_params={"temperature": TEMPERATURE})
# smolagents calls the model through LiteLLM: reuse the factory's LiteLLM model name and params.
llm = LiteLLMModel(model_id=llm_factory.get_litellm_model_name(), **llm_factory.llm_params)

# Re-create the LangChain tool here, then adapt it to a smolagents Tool.
dyn_tool = rainbow_tool_factory.create_vector_search_lc_tool()
sa_tool = Tool.from_langchain(dyn_tool)

agent = CodeAgent(tools=[sa_tool], model=llm)

In [None]:
# Show the tool description as plain text (print renders any embedded newlines).
print(dyn_tool.description)

In [None]:
# End-to-end check: let the agent decide how to call the vector-search tool.
agent.run("What is the bif manager for opportunity 9000559500 ?")

In [None]:
# Same question as a direct tool call: restricted to the "team" section, with the
# opportunity id passed as entity key, for comparison with the agent's answer.
debug(
    dyn_tool.invoke(
        {
            "query": "BIF manager for opportunity 9000559500",
            "selected_sections": ["team"],
            "entity_keys": ["9000559500"],
        }
    )
)

In [None]:
def from_langchain(langchain_tool):
    """Wrap a LangChain tool as a smolagents ``Tool`` (local variant of ``Tool.from_langchain``).

    The wrapper lower-cases the tool name, copies its description, and exposes the
    LangChain argument schema as smolagents ``inputs`` with each field's ``title``
    key removed. The output type is always ``"string"``.

    Args:
        langchain_tool: A LangChain tool exposing ``name``, ``description``,
            ``args`` (mapping of argument name -> schema dict) and ``run()``.

    Returns:
        A ``LangChainToolWrapper`` instance whose ``forward`` delegates to
        ``langchain_tool.run``.
    """

    class LangChainToolWrapper(Tool):
        # NOTE(review): presumably tells smolagents not to validate forward()'s
        # signature against `inputs` (forward takes *args/**kwargs) — confirm.
        skip_forward_signature_validation = True

        def __init__(self, _langchain_tool):
            self.name = _langchain_tool.name.lower()
            self.description = _langchain_tool.description
            # Build fresh per-argument dicts rather than popping keys from a shallow
            # copy, which would mutate the LangChain tool's own schema dicts.
            self.inputs = {
                arg_name: {key: value for key, value in arg_schema.items() if key != "title"}
                for arg_name, arg_schema in _langchain_tool.args.items()
            }
            debug(_langchain_tool.args)
            self.output_type = "string"
            self.langchain_tool = _langchain_tool
            # NOTE(review): presumably marks the tool as ready so smolagents skips
            # its lazy setup() — confirm.
            self.is_initialized = True

        def forward(self, *args, **kwargs):
            """Map positional args onto input names in declaration order, then run the LangChain tool."""
            tool_input = kwargs.copy()
            # The i-th positional argument goes to the i-th declared input;
            # positional arguments beyond the declared inputs are ignored.
            for input_key, argument in zip(self.inputs, args):
                tool_input[input_key] = argument
            return self.langchain_tool.run(tool_input)

    return LangChainToolWrapper(langchain_tool)


# Wrap the current LangChain tool with the local adapter and display the
# smolagents input schema it produced.
sa_tool_1 = from_langchain(dyn_tool)
sa_tool_1.inputs