In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# A/B testing scratch

A place to check that your tests work: an unsacred space for development.

In [None]:
from uuid import UUID

from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_elasticsearch import ApproxRetrievalStrategy, ElasticsearchStore
from elasticsearch import Elasticsearch


from redbox.models import Settings
from redbox.models.chat import ChatRequest
from redbox.models.settings import ElasticLocalSettings
from redbox.storage import ElasticsearchStorageHandler

env = Settings(_env_file="../.env", minio_host="localhost", elastic=ElasticLocalSettings(host="localhost"))

# Embedding model used to vectorise queries; model files are cached under ../models/.
embedding_model = SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder="../models/")

# Client for the Elasticsearch instance configured on ``env.elastic`` above
# (host taken from settings rather than hardcoded, so the two cannot disagree).
es = Elasticsearch(
    hosts=[
        {
            "host": env.elastic.host,
            "port": env.elastic.port,
            "scheme": env.elastic.scheme,
        }
    ],
    basic_auth=(env.elastic.user, env.elastic.password),
)

# Hybrid retrieval is only enabled for platinum/enterprise subscriptions;
# basic uses plain approximate kNN.
if env.elastic.subscription_level == "basic":
    strategy = ApproxRetrievalStrategy(hybrid=False)
elif env.elastic.subscription_level in ["platinum", "enterprise"]:
    strategy = ApproxRetrievalStrategy(hybrid=True)
else:
    # Fail fast: otherwise ``strategy`` is unbound (NameError below) or, on a
    # re-run, silently reuses a stale value from a previous execution.
    msg = f"Unsupported Elastic subscription level: {env.elastic.subscription_level!r}"
    raise ValueError(msg)

# NOTE(review): index name is hardcoded here, while get_es_retriever uses
# f"{env.elastic_root_index}-chunk" — confirm the two refer to the same index.
vector_store = ElasticsearchStore(
    es_connection=es,
    index_name="redbox-data-chunk",
    embedding=embedding_model,
    strategy=strategy,
    vector_query_field="embedding",
)

# NOTE(review): ChatLiteLLM declares an ``azure_api_key`` field, not ``azure_key``,
# and has no ``api_version`` field — confirm these kwargs are actually forwarded
# rather than ignored (LiteLLM may be falling back to environment variables).
llm = ChatLiteLLM(
    model=env.azure_openai_model,
    streaming=True,
    azure_key=env.azure_openai_api_key,
    api_version=env.openai_api_version,
    api_base=env.azure_openai_endpoint,
    max_tokens=4_096,
)

storage_handler = ElasticsearchStorageHandler(es_client=es, root_index=env.elastic_root_index)

## K-retrieval

In [None]:
from core_api.build_chains import build_k_retrieval_chain

In [None]:
# Baseline check: LangChain's stock similarity retriever, returning the top 5
# chunks restricted (via an Elasticsearch ``terms`` filter) to two files.
retriever = vector_store.as_retriever(
    search_type="similarity",
    search_kwargs=dict(
        k=5,
        filter={
            "terms": {
                "parent_file_uuid.keyword": [
                    "7bcc6d44-6bf3-4c45-b598-8f421531daa2",
                    "a28c04e2-8a1c-41b0-8d29-74ae41aa2e0f",
                ],
            },
        },
    ),
)

retriever.invoke("yest")

In [None]:
from langchain_elasticsearch import ElasticsearchRetriever
from langchain_core.runnables import ConfigurableField
from typing import Any, TypedDict, Callable
from redbox.models import Chunk
from langchain_core.retrievers import BaseRetriever
from functools import partial


# Input dict accepted by the retriever from get_es_retriever:
#   question   - free-text query, used for both the text match and the embedding
#   file_uuids - files to restrict the search to (empty list = no file filter)
#   user_uuid  - owner whose chunks are searched
# Values are passed through str() when the ES query is built, so plain
# strings work at runtime as well as UUID objects.
ESQuery = TypedDict(
    "ESQuery",
    {
        "question": str,
        "file_uuids": list[UUID],
        "user_uuid": UUID,
    },
)


# Tunable search parameters for the retriever from get_es_retriever:
#   size                 - number of hits returned
#   num_candidates       - kNN candidate pool size
#   match_boost          - boost on the full-text ``match`` clause
#   knn_boost            - boost on the ``knn`` clause
#   similarity_threshold - minimum vector similarity for a kNN hit
ESParams = TypedDict(
    "ESParams",
    {
        "size": int,
        "num_candidates": int,
        "match_boost": float,
        "knn_boost": float,
        "similarity_threshold": float,
    },
)


def get_es_retriever(env, es, embedding_model=None) -> BaseRetriever:
    """Create a configurable Elasticsearch retriever runnable.

    The runnable takes an ``ESQuery``-shaped dict (``question``, ``file_uuids``,
    ``user_uuid``) and returns a list of ``Chunk`` objects. Hits are scored by a
    ``bool.should`` combining a full-text ``match`` on ``text`` with a ``knn``
    clause on ``embedding``; results are always restricted to chunks created by
    ``user_uuid`` and, when ``file_uuids`` is non-empty, to those files.

    Search parameters (``ESParams``) default to ``env.ai.rag_k`` and
    ``env.ai.rag_num_candidates`` with unit boosts and a similarity threshold of
    0; override them per call with ``.with_config(configurable={"params": {...}})``.

    Args:
        env: Settings object; ``ai.rag_k``, ``ai.rag_num_candidates`` and
            ``elastic_root_index`` are read.
        es: Elasticsearch client used to run the searches.
        embedding_model: Embeddings used to vectorise the question. Defaults to
            the notebook-level ``embedding_model`` so existing calls keep working.

    Returns:
        The retriever wrapped by ``configurable_fields`` so that ``params`` is
        configurable. NOTE(review): that wrapper may not be a ``BaseRetriever``
        subclass despite the annotation.
    """
    if embedding_model is None:
        # Make the dependency on the notebook global explicit instead of a hidden lookup.
        embedding_model = globals()["embedding_model"]

    def es_query(query: ESQuery, params: ESParams) -> dict[str, Any]:
        """Build the Elasticsearch search body for ``query`` using ``params``."""
        vector = embedding_model.embed_query(query["question"])

        # Always restrict to the requesting user's chunks. The same filter list is
        # applied both inside the knn clause and on the outer bool query.
        query_filter = [{"term": {"creator_user_uuid.keyword": str(query["user_uuid"])}}]

        # An empty or missing file list means "search all of the user's chunks".
        file_uuids = query.get("file_uuids") or []
        if file_uuids:
            query_filter.append({"terms": {"parent_file_uuid.keyword": [str(uuid) for uuid in file_uuids]}})

        return {
            "size": params["size"],
            "query": {
                "bool": {
                    "should": [
                        {
                            "match": {
                                "text": {
                                    "query": query["question"],
                                    "boost": params["match_boost"],
                                }
                            }
                        },
                        {
                            "knn": {
                                "field": "embedding",
                                "query_vector": vector,
                                "num_candidates": params["num_candidates"],
                                "filter": query_filter,
                                "boost": params["knn_boost"],
                                "similarity": params["similarity_threshold"],
                            }
                        },
                    ],
                    "filter": query_filter,
                }
            },
        }

    def chunk_mapper(hit: dict[str, Any]) -> Chunk:
        # Returns redbox Chunks rather than LangChain Documents; callers read
        # Chunk attributes (e.g. ``parent_file_uuid``, ``text``) off the results.
        return Chunk(**hit["_source"])

    class ParameterisedElasticsearchRetriever(ElasticsearchRetriever):
        """ElasticsearchRetriever whose body function also receives ``params``."""

        params: ESParams
        # Before __init__ finishes this is es_query(query, params); afterwards it is
        # a partial taking only the query, as ElasticsearchRetriever expects.
        body_func: Callable[..., dict]

        def __init__(self, **kwargs: Any) -> None:
            super().__init__(**kwargs)
            # Bind the current params. If configurable_fields rebuilds the retriever
            # from an existing instance, body_func may already be a partial; since
            # functools.partial flattens nested partials, the newer ``params`` wins.
            self.body_func = partial(self.body_func, params=self.params)

    default_params: ESParams = {
        "size": env.ai.rag_k,
        "num_candidates": env.ai.rag_num_candidates,
        "match_boost": 1,
        "knn_boost": 1,
        "similarity_threshold": 0,
    }

    return ParameterisedElasticsearchRetriever(
        es_client=es,
        index_name=f"{env.elastic_root_index}-chunk",
        body_func=es_query,
        document_mapper=chunk_mapper,
        params=default_params,
    ).configurable_fields(
        params=ConfigurableField(
            id="params", name="Retriever parameters", description="A dictionary of parameters to use for the retriever."
        )
    )


retriever = get_es_retriever(env=env, es=es)

# Per-call override of the default ESParams: return 10 hits from a pool of 100
# kNN candidates, weight the kNN clause twice as heavily as the text match, and
# drop kNN hits whose vector similarity is below 0.7.
response = retriever.with_config(
    configurable={
        "params": {
            "size": 10,
            "num_candidates": 100,
            "match_boost": 1,
            "knn_boost": 2,
            "similarity_threshold": 0.7,
        }
    }
).invoke(
    {
        "question": "BEIS statistical publications Energy Trends",
        "file_uuids": ["718dfb9c-3f0c-4942-a0c1-e0458a7a53c6", "a28c04e2-8a1c-41b0-8d29-74ae41aa2e0f"],
        "user_uuid": "b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5",
    }
)

# Source file and text of each retrieved chunk.
[(res.parent_file_uuid, res.text) for res in response]

In [None]:
chat_request = ChatRequest(
    **{
        "message_history": [
            {"text": "Tell me about energy", "role": "user"},
        ],
        "selected_files": [
            {"uuid": "718dfb9c-3f0c-4942-a0c1-e0458a7a53c6"},
            {"uuid": "a28c04e2-8a1c-41b0-8d29-74ae41aa2e0f"},
        ],
    }
)

chain, params = await build_k_retrieval_chain(
    chat_request=chat_request,
    user_uuid=UUID("b92ebddb-a77e-4ed7-81b9-a2f7ce814ef5"),
    llm=llm,
    embedding_model=embedding_model,
    storage_handler=storage_handler,
    k=20,
)

response = chain.invoke(params)
len(response["sources"]), response["reponse"]