In [None]:
import glob
import os
from typing import List

import chromadb
import langchain
from langchain.callbacks import get_openai_callback
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.embeddings import CacheBackedEmbeddings, OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain.prompts.chat import SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.storage import LocalFileStore, RedisStore
from langchain.vectorstores import Chroma
import redis
import spacy

In [None]:
# --- Service endpoints -------------------------------------------------------
# Hostnames of the Redis and Chroma servers the notebook connects to
# (presumably container/service names -- confirm against the deployment).
REDIS_HOST = "redis"
CHROMA_HOST = "chroma"
CHROMA_PERSIST_DIRECTORY = "/chroma"

# --- Embeddings and knowledge base -------------------------------------------
# The model name is also reused as the cache namespace for stored embeddings.
EMBEDDING_MODEL = "text-embedding-ada-002"
# Root folder; each collection is read from the sub-folder bearing its name.
KNOWLEDGE_BASE_DIR = "./knowledge_base"

# --- Retrieval ---------------------------------------------------------------
# collection name -> list of retrievers to build for that collection.
# "name" selects the retriever type ("bm25" or "semantic"); "k" is handed to
# that retriever (presumably the number of documents it returns).
RETRIEVER_COLLECTION_SETTINGS = {
    "info": [
        {"name": "bm25", "k": 1},
        {"name": "semantic", "k": 3},
    ],
    "links": [
        {"name": "semantic", "k": 1},
    ],
}

In [None]:
class CachedEmbeddings(CacheBackedEmbeddings):
    """Cache-backed embedder whose query path is routed through the document path."""

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector.

        The text is wrapped in a one-element batch and sent through
        ``embed_documents`` (presumably so query vectors go through the same
        cache as document vectors -- confirm against the base class).
        """
        vectors = self.embed_documents([text])
        return vectors[0]


# Chroma is reached over HTTP; the persistence flags are set directly on the
# client's private settings object after construction.
chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=8000)
chroma_settings = chroma_client._settings
chroma_settings.is_persistent = True
chroma_settings.persist_directory = CHROMA_PERSIST_DIRECTORY

# Redis-backed byte store used as the embedding cache.
redis_client = redis.Redis(host=REDIS_HOST, port=6379, db=0)
redis_ada_emb_store = RedisStore(client=redis_client, namespace=EMBEDDING_MODEL)

# OpenAI embedder wrapped with the Redis cache defined above.
openai_embedder = OpenAIEmbeddings(model=EMBEDDING_MODEL)
cached_embedder = CachedEmbeddings.from_bytes_store(
    openai_embedder,
    redis_ada_emb_store,
    namespace=EMBEDDING_MODEL,
)

In [None]:
def load_documents(dirpath):
    """Load every ``.txt`` file directly inside ``dirpath``.

    Args:
        dirpath: Directory to scan (not recursive).

    Returns:
        A flat list with the documents produced by ``TextLoader`` for each
        file, read as UTF-8, in sorted filename order. Empty when the
        directory holds no ``.txt`` files.
    """
    documents = []
    # "*.txt" (with the dot) so that names merely ending in "txt" -- or a
    # directory such as "oldtxt" -- are not picked up. sorted() because glob
    # yields entries in arbitrary order, which would make the document order
    # differ between runs and machines.
    for filepath in sorted(glob.glob(os.path.join(dirpath, "*.txt"))):
        documents.extend(TextLoader(filepath, encoding="utf-8").load())
    return documents

def load_texts(dirpath):
    """Return the raw text (``page_content``) of every document under ``dirpath``."""
    texts = []
    for document in load_documents(dirpath):
        texts.append(document.page_content)
    return texts


def create_knowledge_vectordb(dirpath, embedder):
    """Index each sub-folder of ``dirpath`` as its own Chroma collection.

    The collection is named after the folder and filled with the documents
    returned by ``load_documents`` for that folder.

    Args:
        dirpath: Root directory whose immediate sub-folders are the collections.
        embedder: Embedding object passed to ``Chroma.from_documents``.

    Returns:
        Names of the collections that were written, in sorted folder order.
    """
    collection_names = []
    for folder in sorted(glob.glob(os.path.join(dirpath, "*"))):
        # Only directories are collections; a plain file lying next to them
        # must not be turned into a collection of its own.
        if not os.path.isdir(folder):
            continue

        documents = load_documents(folder)
        if not documents:
            # Nothing to embed for this folder, so no collection is written.
            continue

        db = Chroma.from_documents(documents=documents, embedding=embedder, collection_name=os.path.basename(folder),
                                   client=chroma_client, persist_directory=CHROMA_PERSIST_DIRECTORY)
        collection_names.append(db._collection.name)
    return collection_names


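# One-off indexing step: uncomment and run to (re)build the Chroma collections from KNOWLEDGE_BASE_DIR.
# collection_names = create_knowledge_vectordb(KNOWLEDGE_BASE_DIR, cached_embedder)
# print(collection_names)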

In [None]:
# Ukrainian spaCy pipeline; used to lemmatise text for the BM25 retriever.
spacy_nlp = spacy.load("uk_core_news_sm")


def _lemmatize(text):
    """Run ``text`` through the spaCy pipeline and return the lemma of each token."""
    return [token.lemma_ for token in spacy_nlp(text)]


# One retriever per configured collection (an ensemble when a collection has
# several retrievers); these are then merged into a single context retriever.
retrievers = []
for collection_name, collection_config in RETRIEVER_COLLECTION_SETTINGS.items():
    collection_retrievers = []

    for retriever_info in collection_config:
        retriever_name = retriever_info["name"]
        # "name" only selects the retriever type. It is stripped so that just
        # the real options (e.g. "k") are forwarded as constructor/search
        # keyword arguments instead of the whole config entry.
        retriever_options = {key: value for key, value in retriever_info.items() if key != "name"}

        if retriever_name == "bm25":
            collection_texts = load_texts(os.path.join(KNOWLEDGE_BASE_DIR, collection_name))
            bm25 = BM25Retriever.from_texts(collection_texts, preprocess_func=_lemmatize, **retriever_options)
            collection_retrievers.append(bm25)
        elif retriever_name == "semantic":
            collection_db = Chroma(embedding_function=cached_embedder, collection_name=collection_name,
                                   client=chroma_client, persist_directory=CHROMA_PERSIST_DIRECTORY)
            semantic_retriever = collection_db.as_retriever(search_type="similarity", search_kwargs=retriever_options)
            collection_retrievers.append(semantic_retriever)
        else:
            # Fail loudly on a misspelled retriever type instead of silently
            # building the collection without it.
            raise ValueError(f"Unknown retriever {retriever_name!r} configured for collection {collection_name!r}")

    if not collection_retrievers:
        raise ValueError(f"No retrievers configured for collection {collection_name!r}")

    if len(collection_retrievers) > 1:
        retrievers.append(EnsembleRetriever(retrievers=collection_retrievers))
    else:
        retrievers.append(collection_retrievers[0])

if not retrievers:
    raise ValueError("RETRIEVER_COLLECTION_SETTINGS does not define any collection")

context_retriever = MergerRetriever(retrievers=retrievers) if len(retrievers) > 1 else retrievers[0]

In [None]:
# Smoke test of the combined retriever ("Where is the nearest branch?").
smoke_query = "Де найближче відділення?"
context_retriever.get_relevant_documents(smoke_query)

In [None]:
# Chat model that writes the final answer; temperature=0 is the lowest
# sampling temperature.
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo", verbose=True)

# Single system message carrying the instructions, the retrieved context
# ({context}) and the user's question ({question}). The trailing cue must end
# right after the colon: a stray apostrophe that used to follow it was sent to
# the model as part of the prompt.
rqa_prompt_template = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template(
        template=("You are an AI assistant who answers customer questions about the services and processes "
                  "of the postal company Nova Poshta. Use the following pieces of context to answer the question. "
                  "Answer only in Ukrainian, regardless of the question language.\n\nCONTEXT:\n{context}\n\n"
                  "USER QUESTION: {question}\n\n"
                  "If the question is not related to the context, tell to contact the support. If the answer is not "
                  "contained in the context, tell to contact support. Don't make up the answer. If the question is not "
                  "related to the postal services or it doesn't make sense, tell that you can't answer it.\n\n"
                  "ANSWER IN UKRAINIAN:")
    )]
)

# Retrieval-QA chain over the combined retriever; return_source_documents=True
# asks the chain to return the retrieved documents alongside the answer.
rqa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=context_retriever, return_source_documents=True,
                                        chain_type_kwargs={"prompt": rqa_prompt_template})

In [None]:
# Uncomment to switch on LangChain's debug output for the run below.
# langchain.debug = True

question = "Де найближче відділення?"

# The callback object is printed right after the chain call, inside the
# context manager, to show what it recorded for this question.
with get_openai_callback() as cb:
    result = rqa_chain(question)
    print(cb)

In [None]:
# Show only the generated answer text from the chain output.
answer_text = result["result"]
print(answer_text)