In [1]:
import sys 
sys.path.append('/workspaces/AI_Chatbot')

In [2]:
import logging
import os
import torch

from langchain_community.vectorstores.chroma import Chroma
from langchain_community.chat_models import ChatOllama
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.embeddings.huggingface import HuggingFaceInstructEmbeddings
from langchain.prompts import MessagesPlaceholder
from langchain.prompts import ChatPromptTemplate
from langchain.vectorstores.utils import filter_complex_metadata
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from constants import CHROMA_SETTINGS, PERSIST_DIRECTORY, SOURCE_DIRECTORY
from callback_logger import CallbackLogger
import performance_logger 

from langchain.globals import set_debug
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers import ContextualCompressionRetriever
from semantic_chunking_helper import SematicChunkingHelper
from langchain.retrievers.document_compressors import CrossEncoderReranker

15.51GB


In [3]:
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger('__file__')

source_dir = SOURCE_DIRECTORY

model_name='hkunlp/instructor-xl'
persist_dir=PERSIST_DIRECTORY + "_" + model_name
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
print(persist_dir)

/workspaces/AI_Chatbot/selfcheck/.DB_hkunlp/instructor-xl


In [5]:
embeddings = HuggingFaceInstructEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device},
)

vector_store = Chroma(
    persist_directory=persist_dir,
    embedding_function=embeddings,
    client_settings=CHROMA_SETTINGS,
)


base_retriever = vector_store.as_retriever(search_kwargs={"k": 10})
model = HuggingFaceCrossEncoder(
    model_name="BAAI/bge-reranker-base", model_kwargs={"device": device}
)
reranker = CrossEncoderReranker(model=model, top_n=2)
retriever = ContextualCompressionRetriever(
    base_compressor=reranker, base_retriever=base_retriever
)

  from tqdm.autonotebook import trange
  _torch_pytree._register_pytree_node(
2024-10-01 18:11:09,315 - INFO - SentenceTransformer.py:66 - Load pretrained SentenceTransformer: hkunlp/instructor-xl
  _torch_pytree._register_pytree_node(


load INSTRUCTOR_Transformer


  _torch_pytree._register_pytree_node(
  return torch.load(checkpoint_file, map_location=map_location)


max_seq_length  512


  model.load_state_dict(torch.load(os.path.join(input_path, 'pytorch_model.bin'), map_location=torch.device('cpu')))


In [6]:
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

In [7]:
def create_conversational_chain(model: ChatOllama):
    """Create chat history"""
    contextualize_q_system_prompt = """Given a chat history and the latest user question \
                                        which might reference context in the chat history, formulate a standalone question \
                                        which can be understood without the chat history. Do NOT answer the question, \
                                        just reformulate it if needed and otherwise return it as is.
                                    """
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("user", "{input}"),
        ]
    )

    history_aware_retriever = create_history_aware_retriever(
        model, retriever, contextualize_q_prompt
    )

    qa_system_prompt = """
                        You are a helpful DEK assistant for question-answering DEK policies. \
                        Do not give me any information outside of PROVIDED CONTEXT. \
                        If you don't know the answer, just say that you don't know. \
                        You have to answer the question in Vietnamese. \
                        {context}
                        """
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("user", "{input}"),
        ]
    )

    chain_from_docs = (
        RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
        | qa_prompt
        | model
        | StrOutputParser()
    )

    question_answer_chain = create_stuff_documents_chain(model, qa_prompt)

    rag_chain = create_retrieval_chain(
        history_aware_retriever, question_answer_chain
    ).assign(answer=chain_from_docs)

    return rag_chain

In [8]:
chain = create_conversational_chain(model)

TypeError: Expected a Runnable, callable or dict.Instead got an unsupported type: <class 'langchain_community.cross_encoders.huggingface.HuggingFaceCrossEncoder'>

In [None]:
def filter_answer_from_response(response: dict):
    """filter_answer_from_response"""
    if response["answer"] or len(response["answer"]) > 0:
        response["answer"] = response["answer"].replace("</s> [INST]", "")
        response["answer"] = response["answer"].replace("</s>", "")
        response["answer"] = response["answer"].replace("<s>", "")
        response["answer"] = response["answer"].replace("[ANSW]", "")
        response["answer"] = response["answer"].replace("[ANS]", "")
        response["answer"] = response["answer"].replace("[/ANSW]", "")
        response["answer"] = response["answer"].replace("[INST]", "")
        response["answer"] = response["answer"].replace("[/INST]", "")

def ask(query: str, chat_history: list):
    """Retrieve answer from LLM"""
    if not chain:
        return "Please, add a document first."

    performance_logger.open(query, model=model_name)
    callback_handler = CallbackLogger(logger=performance_logger)
    config = {"callbacks": [callback_handler]}

    result = chain.invoke(
        {"input": query, "chat_history": chat_history}, config=config
    )
    performance_logger.close(result)

    filter_answer_from_response(result)

    return result