In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import os
import logging

from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import SentenceTransformerEmbeddings
from elasticsearch import Elasticsearch

from redbox.models import Settings
from redbox.models.settings import ElasticLocalSettings
from redbox.storage import ElasticsearchStorageHandler
from redbox.transform import bedrock_tokeniser

from dotenv import find_dotenv, load_dotenv

# Repository root: the notebook is expected to run from a direct subdirectory of it.
ROOT = Path().resolve().parent

_ = load_dotenv(find_dotenv(ROOT / ".env"))

logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

# Settings read from the repo .env, with hosts overridden so the notebook talks to
# MinIO / Elasticsearch on localhost.
env = Settings(
    _env_file=(ROOT / ".env"),
    minio_host="localhost",
    object_store="minio",
    elastic=ElasticLocalSettings(host="localhost"),
)

# Same location as the old relative "../models/", but independent of later cwd changes.
embedding_model = SentenceTransformerEmbeddings(model_name=env.embedding_model, cache_folder=str(ROOT / "models"))

# Take the host from the settings (set to "localhost" above) instead of repeating it.
es = Elasticsearch(
    hosts=[
        {
            "host": env.elastic.host,
            "port": env.elastic.port,
            "scheme": env.elastic.scheme,
        }
    ],
    basic_auth=(env.elastic.user, env.elastic.password),
)

# See core_api.dependencies for details on this hack
os.environ["AZURE_API_VERSION"] = env.openai_api_version

# max_tokens must match env.llm_max_tokens: the prompt runnables below reserve exactly
# that many tokens of the context window for the completion
# (input_token_budget = context_window_size - llm_max_tokens).
llm = ChatLiteLLM(
    model=env.azure_openai_model,
    streaming=True,
    azure_key=env.azure_openai_api_key,
    api_base=env.azure_openai_endpoint,
    max_tokens=env.llm_max_tokens,
)

storage_handler = ElasticsearchStorageHandler(es_client=es, root_index=env.elastic_root_index)

tokeniser = bedrock_tokeniser

# Summarisation scratch

In [None]:
from core_api.retriever import ParameterisedElasticsearchRetriever

# Search parameters for the parameterised retriever; "size": 1 caps the result at one hit.
retriever_params = {
    "size": 1,
    "num_candidates": 100,
    "match_boost": 1,
    "knn_boost": 2,
    "similarity_threshold": 0.7,
}

retriever = ParameterisedElasticsearchRetriever(
    es_client=es,
    index_name=f"{env.elastic_root_index}-chunk",
    embedding_model=embedding_model,
    params=retriever_params,
)

# Uncomment a file UUID to include that document in the search.
mamba_request = {
    "question": "",
    "file_uuids": [
        # "36ed2f1a-57a5-489c-a4cb-fbdd25e2b038",  # KAN paper
        "1a9d18a7-9499-47b6-abcc-4e82370028ee",  # MAMBA paper
        # "450a972c-356a-4fdb-b080-3af4fa9b0b74",  # backend notes
    ],
    "user_uuid": "5c37bf4c-002c-458d-9e68-03042f76a5b1",
}

retriever.invoke(input=mamba_request)

In [None]:
from typing import Any
from elasticsearch.helpers import scan
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_elasticsearch import ElasticsearchRetriever

# from core_api.retriever import get_all_chunks_query
from core_api.retriever.base import ESQuery


def get_all_chunks_query(query: "ESQuery") -> dict[str, Any]:
    """Build an Elasticsearch search body matching every chunk a user owns.

    Args:
        query: mapping with ``user_uuid`` and, optionally, ``file_uuids``. UUIDs may be
            ``uuid.UUID`` objects or strings; they are stringified for the term filters.
            A missing, ``None`` or empty ``file_uuids`` applies no file filter.

    Returns:
        A search body that excludes the ``embedding`` field from ``_source`` and whose
        ``bool`` filter restricts hits to ``user_uuid`` and, when file UUIDs are given,
        to those parent files. Each filter accepts the field either at the top level
        or under ``metadata.`` (presumably because chunks were stored in both layouts;
        confirm against the indexer).
    """
    user_uuid = str(query["user_uuid"])
    query_filter: list[dict[str, Any]] = [
        {
            "bool": {
                "should": [
                    {"term": {"creator_user_uuid.keyword": user_uuid}},
                    {"term": {"metadata.creator_user_uuid.keyword": user_uuid}},
                ]
            }
        }
    ]

    # Stringify once; tolerate None / missing the same way as an empty list.
    file_uuids = [str(file_uuid) for file_uuid in query.get("file_uuids") or []]
    if file_uuids:
        query_filter.append(
            {
                "bool": {
                    "should": [
                        {"terms": {"parent_file_uuid.keyword": file_uuids}},
                        {"terms": {"metadata.parent_file_uuid.keyword": file_uuids}},
                    ]
                }
            }
        )

    return {
        "_source": {"excludes": ["embedding"]},
        "query": {"bool": {"must": {"match_all": {}}, "filter": query_filter}},
    }


class AllElasticsearchRetriever(ElasticsearchRetriever):
    """Retriever that returns every chunk matching ``body_func(query)``.

    Rather than issuing one search request, it iterates the whole hit set with
    ``elasticsearch.helpers.scan``. It builds a ``Document`` from each hit's
    ``_source["text"]`` and ``_source["metadata"]`` and returns the chunks grouped by
    parent file, in chunk-index order within each file.
    """

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> list[Document]:
        """Scan all hits for ``body_func(query)`` and return them in reading order.

        Raises:
            ValueError: if the retriever has no Elasticsearch client or document mapper.
        """
        if not self.es_client or not self.document_mapper:
            raise ValueError("faulty configuration")  # should not happen

        # NOTE(review): the parent class annotates ``query`` as ``str``, but in this
        # notebook ``body_func`` (get_all_chunks_query) is fed the request dict.
        body = self.body_func(query)
        hits = scan(client=self.es_client, index=self.index_name, query=body, source=True)

        documents = [
            Document(page_content=hit["_source"]["text"], metadata=hit["_source"]["metadata"]) for hit in hits
        ]

        # Sorting on the chunk index alone interleaves different files when several are
        # requested (chunk 0 of A, chunk 0 of B, chunk 1 of A, ...). Group by parent file
        # first so each file's chunks stay contiguous. Chunks without a
        # ``parent_file_uuid`` in their metadata share one group, which reproduces the
        # plain index ordering for them.
        return sorted(
            documents,
            key=lambda doc: (str(doc.metadata.get("parent_file_uuid", "")), doc.metadata["index"]),
        )


chunk_index = f"{env.elastic_root_index}-chunk"

all_chunks_retriever = AllElasticsearchRetriever(
    es_client=es,
    index_name=chunk_index,
    body_func=get_all_chunks_query,
    content_field="text",
)

# Reference: invoking the retriever directly.
# docs = all_chunks_retriever.invoke(
#     input={
#         "question": "",
#         "file_uuids": [
#             # "36ed2f1a-57a5-489c-a4cb-fbdd25e2b038", # KAN paper
#             "1a9d18a7-9499-47b6-abcc-4e82370028ee" # MAMBA paper,
#             # "450a972c-356a-4fdb-b080-3af4fa9b0b74", #backendnotes
#         ],
#         "user_uuid": "5c37bf4c-002c-458d-9e68-03042f76a5b1"
#     }
# )
# docs[:3]
# [doc.metadata["index"] for doc in docs]

# Run the raw scan by hand to inspect exactly what the retriever sees.
mamba_chunks_request = {
    "question": "",
    "file_uuids": [
        # "36ed2f1a-57a5-489c-a4cb-fbdd25e2b038",  # KAN paper
        "1a9d18a7-9499-47b6-abcc-4e82370028ee",  # MAMBA paper
        # "450a972c-356a-4fdb-b080-3af4fa9b0b74",  # backend notes
    ],
    "user_uuid": "5c37bf4c-002c-458d-9e68-03042f76a5b1",
}
body = get_all_chunks_query(mamba_chunks_request)
results = [*scan(client=es, index=chunk_index, query=body, source=True)]

[Document(page_content=hit["_source"]["text"], metadata=hit["_source"]["metadata"]) for hit in results[:3]]

In [None]:
from langchain_core.runnables import (
    Runnable,
    RunnableLambda,
    RunnablePassthrough,
    chain,
)
from langchain_core.runnables.config import RunnableConfig
from langchain.schema import StrOutputParser
from langchain.prompts import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate

from redbox.models import ChatRoute
from redbox.models.errors import AIError

from core_api.runnables import (
    make_chat_prompt_from_messages_runnable,
    resize_documents,
)


def build_summary_chain(
    llm,
    all_chunks_retriever,
    tokeniser,
    env,
) -> Runnable:
    """Build a runnable that summarises the chunks of the requested files.

    The returned chain takes a dict with at least ``question`` and ``file_uuids`` (plus
    whatever ``all_chunks_retriever`` and the prompt runnables read, e.g. ``user_uuid``).
    It adds a ``documents`` key holding the resized chunk texts, then routes:

    * exactly one document -> ``stuff_chain`` (single prompt);
    * more than one -> ``map_reduce_chain`` (summarise each, then combine);
    * none -> raises ``AIError``.

    Each route returns a dict with ``response`` (the LLM text) and ``route_name``.
    """

    def make_document_context(input_dict: dict) -> list[str]:
        """Fetch every chunk of the requested files and return the resized chunk texts."""
        # With no file UUIDs the retriever's query would match all of the user's chunks;
        # return nothing so the router raises "No documents to summarise".
        if not input_dict["file_uuids"]:
            return []
        # The retriever already returns the chunks of *all* requested files, so resize
        # them once. (A dict of one resize branch per file UUID hands every branch the
        # same full list, duplicating each chunk once per requested file.)
        return (
            all_chunks_retriever
            | resize_documents(env.ai.summarisation_chunk_max_tokens)
            | RunnableLambda(lambda chunks: [chunk.page_content for chunk in chunks])
        ).invoke(input_dict)

    # The stuff chain no longer has a RunnableLambda that formats the chunks: the plain
    # strings from make_document_context are handed to the prompt runnable as-is.
    stuff_chain = (
        make_chat_prompt_from_messages_runnable(
            system_prompt=env.ai.summarisation_system_prompt,
            question_prompt=env.ai.summarisation_question_prompt,
            input_token_budget=env.ai.context_window_size - env.llm_max_tokens,
            tokeniser=tokeniser,
        )
        | llm
        | {
            "response": StrOutputParser(),
            "route_name": RunnableLambda(lambda _: ChatRoute.stuff_summarise.value),
        }
    )

    @chain
    def map_operation(input_dict):
        """Summarise each document independently; return the input plus ``summaries``."""
        prompt_template = PromptTemplate.from_template(env.ai.chat_map_question_prompt)
        formatted_map_question_prompt = prompt_template.format(question=input_dict["question"])
        # The formatted text is parsed as a template again below, so escape its braces;
        # otherwise a question containing e.g. "{x}" becomes a template variable and the
        # map prompt fails or is hijacked. Only map_document_prompt should carry variables.
        escaped_map_question_prompt = formatted_map_question_prompt.replace("{", "{{").replace("}", "}}")

        map_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", env.ai.map_system_prompt),
                ("human", escaped_map_question_prompt + env.ai.map_document_prompt),
            ]
        )

        map_summaries = (map_prompt | llm | StrOutputParser()).batch(
            input_dict["documents"],
            config=RunnableConfig(max_concurrency=env.ai.summarisation_max_concurrency),
        )

        # Return a new dict instead of mutating the chain's input in place.
        return {**input_dict, "summaries": " ; ".join(map_summaries)}

    map_reduce_chain = (
        map_operation
        | make_chat_prompt_from_messages_runnable(
            system_prompt=env.ai.reduce_system_prompt,
            question_prompt=env.ai.reduce_question_prompt,
            input_token_budget=env.ai.context_window_size - env.llm_max_tokens,
            tokeniser=tokeniser,
        )
        | llm
        | {
            "response": StrOutputParser(),
            "route_name": RunnableLambda(lambda _: ChatRoute.map_reduce_summarise.value),
        }
    )

    @chain
    def summarisation_route(input_dict):
        """Choose stuff for a single document and map-reduce for several; raise if none."""
        document_count = len(input_dict["documents"])
        if document_count == 1:
            return stuff_chain
        if document_count > 1:
            return map_reduce_chain
        raise AIError("No documents to summarise")

    return RunnablePassthrough.assign(documents=make_document_context) | summarisation_route

In [None]:
# Minimal "stuff" summarisation: attach every retrieved chunk as `documents`, build the
# prompt, and send it to the LLM. There is no output parser, so the LLM's message is shown.
stuff_prompt = make_chat_prompt_from_messages_runnable(
    system_prompt=env.ai.summarisation_system_prompt,
    question_prompt=env.ai.summarisation_question_prompt,
    input_token_budget=env.ai.context_window_size - env.llm_max_tokens,
    tokeniser=tokeniser,
)
stuff_chain = RunnablePassthrough.assign(documents=all_chunks_retriever) | stuff_prompt | llm

backend_notes_request = {
    "question": "Summarise this paper",
    "file_uuids": [
        # "36ed2f1a-57a5-489c-a4cb-fbdd25e2b038",  # KAN paper
        # "1a9d18a7-9499-47b6-abcc-4e82370028ee",  # MAMBA paper
        "450a972c-356a-4fdb-b080-3af4fa9b0b74",  # backend notes
    ],
    "user_uuid": "5c37bf4c-002c-458d-9e68-03042f76a5b1",
    "chat_history": [],
}

stuff_chain.invoke(input=backend_notes_request)