In [None]:
# Start a Hugging Face text-embeddings-inference container for the embedding model.
# $model is the model path as seen from inside the container: the host directory
# in $volume is mounted at /data by the -v flag below.
model=/data/snowflake-arctic-embed-l-v2.0
volume=/Users/vinhnguyen/Projects/ext-chatbot/model_hub # share a volume with the Docker container to avoid downloading weights every run

# Host port 8080 is mapped to container port 80.
# NOTE(review): "--gpus all" is passed but the image tag is "cpu-1.8" - confirm which one is intended.
docker run --gpus all -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-embeddings-inference:cpu-1.8 --model-id $model

In [None]:
"""
Complete AI Query Processing Flow
Consolidated version of the chatbot's query processing pipeline
"""

import os
import re
import json
import requests
from typing import List, Dict, Optional, Tuple, Union
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import Runnable
import torch
from transformers import AutoModel, AutoTokenizer
from langchain_core.embeddings import Embeddings
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient

# ============================================================================
# Configuration
# ============================================================================

# All settings are read once, at import time, from environment variables and
# fall back to the defaults below when a variable is not set.
# NOTE(review): the defaults include internal endpoints and credentials that
# are committed in source - consider requiring them from the environment.


def _env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean environment variable.

    ``os.getenv`` always returns a string when the variable is set, so a value
    such as ``"false"`` or ``"0"`` would be truthy if used directly.

    Args:
        name: Environment variable name.
        default: Value returned when the variable is not set.

    Returns:
        True when the variable is set to 1/true/yes/y/on (case-insensitive),
        False for any other value, ``default`` when it is unset.
    """
    raw_value = os.getenv(name)
    if raw_value is None:
        return default
    return raw_value.strip().lower() in {"1", "true", "yes", "y", "on"}


# LLM Configuration
LLM_BASE_URL = os.getenv("LLM_BASE_URL", "http://10.164.84.47:8814/v1")
LLM_MODEL = os.getenv("LLM_MODEL", "LLM-ver01")
API_KEY = os.getenv("API_KEY", "xxx")

# Embedding Configuration
EMBED_BASE_URL = os.getenv("EMBED_BASE_URL", "http://10.164.84.47:7893")
EMBED_MODEL = os.getenv("EMBED_MODEL", "embed-ver01")

# Vector Store Configuration
QDRANT_ENDPOINT = os.getenv("QDRANT_ENDPOINT", "http://10.164.84.47:6933")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "ailab2024")
QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", None)
# Parsed explicitly so the types are the same whether or not the variable is set.
QDRANT_HTTPS = _env_flag("QDRANT_HTTPS", False)
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6933"))

# Reranker Configuration
RERANKER_BASE_URL = os.getenv("RERANKER_BASE_URL", "http://10.164.84.47:7893/rerank")
RERANKER_MODEL = os.getenv("RERANKER_MODEL", "reranker-ver01")

# Thresholds
RETRIEVAL_SCORE_THRESHOLD = float(os.getenv("RETRIEVAL_SCORE_THRESHOLD", "0.4"))
RERANKER_SCORE_THRESHOLD = float(os.getenv("RERANKER_SCORE_THRESHOLD", "0.25"))
RELATED_SCORE_THRESHOLD = float(os.getenv("RELATED_SCORE_THRESHOLD", "3.5"))

# Business Configuration (note: these environment variable names are lowercase)
BOT_NAME = os.getenv("bot_name", "Yuta")
BUSINESS_NAME = os.getenv("business_name", "Sushi Hokkaido Sachi")
INTRO_DOC = os.getenv("intro_doc", "Nhà hàng Nhật hàng đầu tại Việt Nam...")
WEB_LINK = os.getenv("web_link", "https://sushihokkaidosachi.com.vn/")
HOTLINE = os.getenv("hotline", "")

# Defaults: suggestion questions used when the model does not provide any.
DEFAULT_SUGGESTION_QUESTIONS = [
    "Bạn là ai?",
    "Bạn cung cấp sản phẩm gì?",
    "Làm sao để liên hệ?",
]

# ============================================================================
# Helper Classes and Functions
# ============================================================================

class ReasoningExtractor(Runnable):
    """Split an LLM reply into its reasoning section and its final content.

    The reasoning section is the text between ``start_token`` and
    ``end_token``; the final content is whatever follows the end token.
    """

    # NOTE(review): the default tokens include literal backticks. If the model
    # emits plain <think>...</think> tags they will not match these defaults -
    # confirm the intended token text before changing them.
    def __init__(
        self,
        start_token: str = "`<think>`",
        end_token: str = "`</think>`",
        return_full_output_only: bool = False
    ):
        """
        Args:
            start_token: Marker that opens the reasoning section.
            end_token: Marker that closes the reasoning section.
            return_full_output_only: When True, ``invoke`` returns only the
                final content as a string instead of a dict.
        """
        self.start_token = start_token
        self.end_token = end_token
        self.return_full_output_only = return_full_output_only
        # Non-greedy + DOTALL: capture the first reasoning section, which may
        # span several lines.
        self.reasoning_regex = re.compile(
            rf"{re.escape(start_token)}(.*?){re.escape(end_token)}",
            re.DOTALL
        )

    def invoke(self, input: AIMessage, config: Optional[dict] = None) -> Union[str, dict]:
        """Extract reasoning and content from an ``AIMessage``.

        Args:
            input: The model reply.
            config: Accepted for interface compatibility; not used here.

        Returns:
            The final content string (``""`` when there is none) if
            ``return_full_output_only`` is set, otherwise a dict with the keys
            ``"reasoning"`` and ``"content"`` (either value may be None).

        Raises:
            TypeError: If ``input`` is not an ``AIMessage``.
        """
        if not isinstance(input, AIMessage):
            raise TypeError("ReasoningExtractor only supports AIMessage as input")

        # An empty reply is treated as an empty reasoning section with no content.
        model_output = input.content or f"{self.start_token}{self.end_token}"
        reasoning, content = self.extract_reasoning(model_output)

        if self.return_full_output_only:
            return content or ""

        return {
            "reasoning": reasoning,
            "content": content
        }

    def extract_reasoning(self, model_output: str) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(reasoning, content)`` for a raw model output string.

        * No end token: ``(None, model_output)``.
        * End token without a start token: the output is treated as if it
          started with the start token.
        * Otherwise the first reasoning section is returned together with the
          text that follows it; content is None when that text is blank.
          Text that precedes the start token is discarded.
        """
        if self.end_token not in model_output:
            return None, model_output
        if self.start_token not in model_output:
            model_output = f"{self.start_token}{model_output}"

        match = self.reasoning_regex.search(model_output)
        if match is None:
            # e.g. the start token only appears after the end token
            return None, model_output

        reasoning_content = match.group(1)
        # Cut at the real end of the matched section. Deriving the offset from
        # the token lengths is only correct when the section starts at index 0.
        final_output = model_output[match.end():]

        if not final_output.strip():
            return reasoning_content, None
        return reasoning_content, final_output


def convert_to_messages(history: List[Dict]) -> List[BaseMessage]:
    """Turn OpenAI-style ``{"role", "content"}`` dicts into LangChain messages.

    Entries whose role is not "user", "assistant" or "system" are skipped.
    A missing "content" key is treated as an empty string.
    """
    role_classes = (
        ("user", HumanMessage),
        ("assistant", AIMessage),
        ("system", SystemMessage),
    )

    converted = []
    for entry in history:
        entry_role = entry.get("role", "")
        entry_text = entry.get("content", "")
        for known_role, message_cls in role_classes:
            if entry_role == known_role:
                converted.append(message_cls(content=entry_text))
                break
    return converted


def is_valid_url(url: str) -> bool:
    """Return True when ``url`` looks like an http(s)/ftp URL with a dotted host name."""
    pattern_parts = (
        r'^(https?|ftp)://',                        # scheme
        r'(www\.)?',                                # optional "www."
        r'(?:(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,})',    # dotted labels ending in a letters-only TLD
        r'(?:\/[^\s]*)?$',                          # optional path without whitespace
    )
    return re.match("".join(pattern_parts), url) is not None


def extract_reference_from_response(response: str) -> Tuple[str, Dict]:
    """Pull markdown links out of a response and describe them as reference buttons.

    Every ``[title](http...)`` link is replaced by its title in the returned
    text. Each distinct (title, url) pair becomes one reference entry, in
    order of first appearance.

    Args:
        response: Response text that may contain markdown links.

    Returns:
        ``(text_without_links, {"reference_url": [entry, ...]})``. On failure
        (for example a non-string ``response``) the input is returned
        unchanged together with an empty reference list.
    """
    pattern = r"\[([^\]]*)\]\((https?://[^\)]+)\)"
    # Link texts that just mean "here" are replaced by a neutral label.
    generic_titles = ("đây", "tại đây")
    tracking_param = "utm_source=chatbot_AILab"

    def add_tracking(url: str) -> str:
        # Keep the fragment last, and join with "&" when the URL already has
        # a query string, so the tracking parameter never corrupts the URL.
        base, hash_mark, fragment = url.partition("#")
        if "?" not in base:
            joiner = "?"
        elif base.endswith(("?", "&")):
            joiner = ""
        else:
            joiner = "&"
        return f"{base}{joiner}{tracking_param}{hash_mark}{fragment}"

    try:
        # dict.fromkeys de-duplicates while keeping first-seen order; a set
        # would make the order of the references vary between runs.
        unique_links = list(dict.fromkeys(re.findall(pattern, response)))
        final_response = re.sub(pattern, lambda link: link.group(1), response)
    except Exception as e:
        # Deliberately best-effort: a bad input must not break the reply.
        print(f"Error extracting reference: {e}")
        return response, {"reference_url": []}

    reference_jsons = {
        "reference_url": [
            {
                "title": title if title.lower() not in generic_titles else "Link",
                "payload": {
                    "url": add_tracking(url)
                },
                "type": "oa.open.url",
            }
            for title, url in unique_links
        ]
    }
    return final_response, reference_jsons


def filter_url_in_response(response: str) -> str:
    """Replace every markdown ``[title](http...)`` link in ``response`` by its title.

    If the substitution fails, the error is printed and ``response`` is
    returned unchanged.
    """
    markdown_link = r"\[([^\]]*)\]\((https?://[^\)]+)\)"
    try:
        return re.sub(markdown_link, lambda link: link.group(1), response)
    except Exception as e:
        print(f"Error filtering URL: {e}")
        return response


def check_prefix_tokens_with_questions(
    response_string: str,
    answer_prefix_string: str = "Final Answer:",
    question_prefix_string: str = "Bonus questions:",
) -> Tuple[bool, List[str], Dict, Dict]:
    """Split a model reply into its final answer and its follow-up questions.

    Both prefixes are searched case-insensitively. The questions section may
    come before or after the answer; it is expected to contain a JSON object
    and is never included in the answer text.

    Args:
        response_string: Raw model reply.
        answer_prefix_string: Marker that introduces the answer.
        question_prefix_string: Marker that introduces the questions JSON.

    Returns:
        ``(found, [answer_text], question_json, {"reference_url": []})``.
        ``found`` is False and the answer list is empty when the answer
        prefix is missing. ``question_json`` falls back to
        ``{"question_list": DEFAULT_SUGGESTION_QUESTIONS}`` when no valid
        JSON object is found.
    """

    def locate(prefix: str):
        # Search the original string: offsets taken from a lower-cased copy
        # can be wrong because lowering may change the length of some text.
        return re.search(re.escape(prefix), response_string, re.IGNORECASE)

    def parse_questions(section: str) -> Optional[Dict]:
        # The JSON object is taken from the first "{" to the last "}".
        json_start = section.find('{')
        json_end = section.rfind('}') + 1
        if json_start == -1 or json_end <= json_start:
            return None
        try:
            parsed = json.loads(section[json_start:json_end])
        except ValueError:
            return None
        return parsed if isinstance(parsed, dict) else None

    answer_match = locate(answer_prefix_string)
    if answer_match is None:
        return False, [], {"question_list": list(DEFAULT_SUGGESTION_QUESTIONS)}, {"reference_url": []}

    question_match = locate(question_prefix_string)
    answer_end = len(response_string)
    question_section = None
    if question_match is not None:
        if question_match.start() < answer_match.start():
            # Questions come before the answer.
            question_section = response_string[question_match.end():answer_match.start()]
        elif question_match.start() >= answer_match.end():
            # Questions come after the answer: keep them out of the answer text.
            question_section = response_string[question_match.end():]
            answer_end = question_match.start()

    answer_text = response_string[answer_match.end():answer_end].strip()

    question_json = parse_questions(question_section) if question_section is not None else None
    if question_json is None:
        question_json = {"question_list": list(DEFAULT_SUGGESTION_QUESTIONS)}

    return True, [answer_text], question_json, {"reference_url": []}


# ============================================================================
# Core Components
# ============================================================================

class RewriterOutputJsonFormat(BaseModel):
    """Output format for query rewriter"""
    # This model is handed to JsonOutputParser(pydantic_object=...) in Rewriter,
    # and the parser's format instructions are inserted into the rewrite prompt.
    # The Field descriptions (and presumably the class docstring) are therefore
    # text the LLM reads, not just documentation - edit them with care.

    # 0-5 relatedness score; Rewriter.calculate_result divides it by 5.
    is_related_score: int = Field(description="Score from 0-5 indicating how related the query is to conversation")
    # Keywords reported by the model; not read elsewhere in this file.
    key_words: List[str] = Field(description="Important keywords from conversation and question")
    # The rewritten query returned by Rewriter.rewrite.
    new_question: str = Field(description="Rewritten question that is more general")
    # The two ratios below are blended with is_related_score in calculate_result.
    word_dup: float = Field(description="Ratio of keyword overlap between old and new question (0-1)")
    word_rewrite: float = Field(description="Semantic similarity between old and new question (0-1)")


class Rewriter:
    """Rewrites user queries to be more searchable.

    Builds a prompt -> LLM -> reasoning stripper -> JSON parser chain and
    turns the parsed JSON into a relatedness score plus a rewritten question.
    """

    def __init__(self):
        """Create the LLM client, the parsers, the prompt and the chain."""
        self.llm = ChatOpenAI(
            base_url=LLM_BASE_URL,
            model=LLM_MODEL,
            api_key=API_KEY,
            temperature=0.0,
            max_tokens=4096,
            extra_body={
                'repetition_penalty': 1.15,
                'chat_template_kwargs': {
                    'enable_thinking': False
                }
            },
        )
        # return_full_output_only=True makes the extractor emit a plain string,
        # which is what the JSON parser consumes next in the chain.
        self.reason_parser = ReasoningExtractor(return_full_output_only=True)
        self.parser = JsonOutputParser(pydantic_object=RewriterOutputJsonFormat)

        # Simplified prompts (you can load from Langfuse in production)
        system_prompt = """You are a query rewriter. Your task is to rewrite user questions to be more searchable while maintaining the original intent."""

        user_prompt_template = """Given the conversation history and the last question, rewrite the question to be more general and searchable.

Conversation:
{conversation}

Last Question: {question}

{format_instructions}

Return a JSON object with the required fields."""

        self.rewrite_chat_prompt_template = ChatPromptTemplate.from_messages([
            SystemMessage(content=system_prompt),
            HumanMessagePromptTemplate(
                prompt=PromptTemplate(
                    template=user_prompt_template,
                    input_variables=["conversation", "question"],
                    # get_format_instructions() takes no arguments; passing one
                    # raises TypeError with langchain_core's JsonOutputParser.
                    partial_variables={"format_instructions": self.parser.get_format_instructions()}
                )
            )
        ])

        self.chain = self.rewrite_chat_prompt_template | self.llm | self.reason_parser | self.parser

    def rewrite(self, query: str, history: List[BaseMessage]) -> Tuple[float, str]:
        """Rewrite ``query`` using the conversation history.

        Args:
            query: The user's latest question.
            history: Previous messages; only human and AI messages are used.

        Returns:
            ``(relatedness_score, rewritten_question)`` as produced by
            ``calculate_result``.
        """
        transcript_lines = []
        for message in history:
            if isinstance(message, HumanMessage):
                transcript_lines.append(f"user: {message.content}")
            elif isinstance(message, AIMessage):
                transcript_lines.append(f"assistant: {message.content}")

        result = self.chain.invoke({
            "conversation": "\n".join(transcript_lines),
            "question": query
        })

        return self.calculate_result(result)

    def calculate_result(self, result_json: Dict) -> Tuple[float, str]:
        """Blend the model's scores into one relatedness score on a 0-5 scale.

        The 0-5 ``is_related_score`` is scaled to 0-1, weighted twice,
        averaged with ``word_dup`` and ``word_rewrite``, and scaled back by 5.
        Missing, null or non-numeric values count as 0 because the JSON comes
        from an LLM and cannot be trusted to follow the schema.

        Args:
            result_json: Parsed JSON produced by the chain.

        Returns:
            ``(score, new_question)``; ``(0.0, "")`` when ``result_json`` is
            not a dict.
        """
        def as_float(value) -> float:
            try:
                return float(value)
            except (TypeError, ValueError):
                return 0.0

        if not isinstance(result_json, dict):
            return 0.0, ""

        is_related_score = as_float(result_json.get("is_related_score", 0)) / 5
        word_dup = as_float(result_json.get("word_dup", 0))
        word_rewrite = as_float(result_json.get("word_rewrite", 0))
        is_related_score_calculate = ((is_related_score * 2 + word_dup + word_rewrite) / 4) * 5
        return is_related_score_calculate, result_json.get("new_question") or ""


class LocalEmbeddings(Embeddings):
    """Local embeddings using Hugging Face transformers.

    Documents are embedded as-is; queries get ``query_prefix`` prepended.
    Both go through the same encoding routine, which takes the first-token
    output of the model and L2-normalizes it.
    """

    def __init__(
        self,
        model_name: str = 'Snowflake/snowflake-arctic-embed-l-v2.0',
        query_prefix: str = 'query: ',
        batch_size: int = 32,
        max_length: int = 8192,
    ):
        """
        Initialize local embeddings model

        Args:
            model_name: Hugging Face model name or local path
            query_prefix: Prefix to add to queries (not documents)
            batch_size: Number of texts encoded per forward pass
            max_length: Token limit passed to the tokenizer; longer inputs
                are truncated

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        self.model_name = model_name
        self.query_prefix = query_prefix
        self.batch_size = batch_size
        self.max_length = max_length
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name, add_pooling_layer=False)
        self.model.to(self.device)
        self.model.eval()

    def _encode(self, texts: List[str]) -> List[List[float]]:
        """Encode ``texts`` in batches and return one normalized vector per text."""
        vectors: List[List[float]] = []
        # Encoding in bounded batches keeps memory use independent of how many
        # texts are passed in.
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            tokens = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=self.max_length
            )
            # Move tokens to the same device as the model
            tokens = {k: v.to(self.device) for k, v in tokens.items()}

            with torch.no_grad():
                batch_vectors = self.model(**tokens)[0][:, 0]  # Take first token embedding

            batch_vectors = torch.nn.functional.normalize(batch_vectors, p=2, dim=1)
            vectors.extend(batch_vectors.cpu().tolist())
        return vectors

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents (without query prefix).

        Returns an empty list when ``texts`` is empty.
        """
        if not texts:
            return []
        return self._encode(list(texts))

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query (with query prefix)"""
        return self._encode([f"{self.query_prefix}{text}"])[0]


class VectorStore:
    """Vector store for document retrieval, backed by Qdrant and local embeddings."""

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize vector store with local embeddings

        Args:
            model_name: Hugging Face model name or local path.
                       Defaults to 'Snowflake/snowflake-arctic-embed-l-v2.0'
                       or local path if model exists in model_hub/
        """
        # Check for local model first, then use HuggingFace model name.
        # The path is relative, so it is resolved against the current working
        # directory.
        local_model_path = 'model_hub/snowflake-arctic-embed-l-v2.0'
        if model_name is None:
            if os.path.exists(local_model_path):
                model_name = local_model_path
            else:
                model_name = 'Snowflake/snowflake-arctic-embed-l-v2.0'

        # Use local embeddings instead of service
        embedding_api = LocalEmbeddings(model_name=model_name)
        self.client = QdrantClient(
            url=QDRANT_ENDPOINT,
            api_key=QDRANT_API_KEY,
            https=QDRANT_HTTPS,
            port=QDRANT_PORT
        )
        # NOTE(review): QDRANT_COLLECTION defaults to None when the environment
        # variable is unset - confirm a collection name is always configured.
        self.vector_store = Qdrant(
            embeddings=embedding_api,
            collection_name=QDRANT_COLLECTION,
            client=self.client
        )

    def search_docs_with_score(self, query: str, top_k: int = 30) -> List[Document]:
        """Search documents with similarity scores.

        Args:
            query: Query text.
            top_k: Maximum number of documents to return.

        Returns:
            Matching documents scoring at least RETRIEVAL_SCORE_THRESHOLD,
            each with its similarity stored in ``metadata["score"]``.
        """
        # The store expects the query text itself. Wrapping it in a Document
        # would make the embedder embed the Document's string representation.
        results = self.vector_store.similarity_search_with_score(
            query,
            k=top_k,
            score_threshold=RETRIEVAL_SCORE_THRESHOLD,
        )
        matched_docs = []
        for doc, score in results:
            doc.metadata["score"] = score
            matched_docs.append(doc)
        return matched_docs


def rerank_documents(
    query: str,
    documents: List[Document],
    top_n: int = 8,
    timeout: float = 30.0,
) -> List[Document]:
    """Rerank documents using reranker API.

    Sends the query and the documents' text to RERANKER_BASE_URL, keeps the
    ``top_n`` highest-scoring results and drops those scoring below
    RERANKER_SCORE_THRESHOLD.

    Args:
        query: Query text.
        documents: Candidate documents. Their ``metadata`` is updated in
            place with a ``"relevance_score"`` entry.
        top_n: Maximum number of documents to return.
        timeout: Seconds to wait for the reranker before giving up.

    Returns:
        The selected documents, best first. Empty when ``documents`` is empty.

    Raises:
        requests.RequestException: On connection errors, timeouts or a
            non-success HTTP status.
    """
    if not documents:
        return []

    res = requests.post(
        RERANKER_BASE_URL,
        json={
            "model": RERANKER_MODEL,
            "query": query,
            "documents": [d.page_content for d in documents],
        },
        # Without a timeout an unresponsive reranker would block forever.
        timeout=timeout,
    )
    # Fail with a clear HTTP error instead of a KeyError on an error payload.
    res.raise_for_status()

    res_data = res.json()['results']
    sorted_res = sorted(res_data, key=lambda x: x['relevance_score'], reverse=True)[:top_n]

    ranked_docs = []
    for item in sorted_res:
        score = item["relevance_score"]
        # "index" refers to the position in the documents list that was sent.
        target_doc = documents[item["index"]]
        target_doc.metadata["relevance_score"] = score
        if score >= RERANKER_SCORE_THRESHOLD:
            ranked_docs.append(target_doc)

    return ranked_docs


class StructuredResponse(BaseModel):
    """Structured output for LLM response"""
    # Schema describing the sections of a structured model reply. The Field
    # descriptions are runtime strings, so they are kept exactly as written.
    rule: str = Field(description="Whether you are allowed to answer")
    thought: str = Field(description="Think about the user question")
    language: str = Field(description="The language that user uses")
    # Description is in Vietnamese; presumably it tells the model to refer to
    # itself as "em" and to address the user as "Anh/chị" - TODO confirm.
    addressing_way: str = Field(description="Xưng em với Anh/chị")
    action: str = Field(description="What you gonna do to answer")
    bonus_questions: list[str] = Field(description="Follow-up questions")
    observation: str = Field(description="What you see user ask")
    reference_url: str = Field(description="Reference URL")
    final_answer: str = Field(description="The final answer to user")


# ============================================================================
# Main Query Processing Function
# ============================================================================

# Components that are expensive to build (the embedding model in particular)
# are created on first use and reused by later calls in the same process.
_PIPELINE_COMPONENTS: Dict[str, object] = {}

# "**Final Answer**" in any letter case, with or without a trailing colon.
_BOLD_FINAL_ANSWER_RE = re.compile(r"\*\*final answer\*\*:?", re.IGNORECASE)
# Upper-case "FINAL ANSWER", optionally written as a "## " heading.
_UPPER_FINAL_ANSWER_RE = re.compile(r"(?:## )?FINAL ANSWER:?")


def _get_component(name: str, factory):
    """Return the cached component called ``name``, building it on first use."""
    if name not in _PIPELINE_COMPONENTS:
        _PIPELINE_COMPONENTS[name] = factory()
    return _PIPELINE_COMPONENTS[name]


def _build_answer_llm() -> ChatOpenAI:
    """Create the chat model used for the final answer (thinking enabled)."""
    return ChatOpenAI(
        base_url=LLM_BASE_URL,
        model=LLM_MODEL,
        api_key=API_KEY,
        temperature=0.0,
        max_tokens=4096,
        extra_body={
            'repetition_penalty': 1.15,
            'chat_template_kwargs': {
                'enable_thinking': True
            }
        },
    )


def _build_contact_info() -> str:
    """Describe the configured website and/or hotline; empty when neither is set."""
    if HOTLINE and WEB_LINK:
        return f"The official website of {BUSINESS_NAME} is {WEB_LINK}, the Hotline is {HOTLINE}"
    if WEB_LINK:
        return f"The official website of {BUSINESS_NAME} is {WEB_LINK}"
    if HOTLINE:
        return f"The Hotline of {BUSINESS_NAME} is {HOTLINE}"
    return ""


def _normalize_final_answer_marker(text: str) -> str:
    """Rewrite common variants of the answer marker to exactly "Final Answer:"."""
    text = _BOLD_FINAL_ANSWER_RE.sub("Final Answer:", text)
    return _UPPER_FINAL_ANSWER_RE.sub("Final Answer:", text)


def knowledge_base_chat(
    user_query: str,
    history: List[Dict],
    force_disclaimer: bool = False
) -> Dict:
    """
    Main function for processing user queries through the complete AI pipeline

    Steps: rewrite the query, search the vector store, rerank, build the
    prompt, generate (with retries until a "Final Answer:" marker is found),
    then strip markdown links and collect them as references.

    Args:
        user_query: User's question
        history: Conversation history in OpenAI format [{"role": "user/assistant", "content": "..."}]
        force_disclaimer: Whether to force disclaimer in response

    Returns:
        Dictionary with user_msg, raw_llm_msg, question_list, reference_list
    """

    # Initialize components (built once, then reused)
    rewriter = _get_component("rewriter", Rewriter)
    vector_store = _get_component("vector_store", VectorStore)
    llm_model = _get_component("answer_llm", _build_answer_llm)
    reason_extractor = ReasoningExtractor()

    # Convert history
    history_messages = convert_to_messages(history[-10:])  # Keep last 10 messages

    # Step 1: Query Rewriting
    is_related_score, rewrited_query = rewriter.rewrite(user_query, history_messages)
    is_related = is_related_score >= RELATED_SCORE_THRESHOLD
    print(f"Rewriter: {is_related} with score: {is_related_score}")
    # Never search with an empty rewrite, even if the score says "related".
    search_query = rewrited_query if is_related and rewrited_query else user_query

    # Step 2: Vector Search
    matched_docs = vector_store.search_docs_with_score(search_query, top_k=30)

    # Step 3: Reranking
    reranked_docs = rerank_documents(search_query, matched_docs, top_n=8)

    # Step 4: Context Construction
    enable_context = len(reranked_docs) != 0
    if enable_context:
        context_prompt = "".join(f"{i}: {doc.page_content}\n\n" for i, doc in enumerate(reranked_docs))
    else:
        context_prompt = "No information found."

    # Step 5: Prompt Selection (simplified - you can load from Langfuse)
    # NOTE(review): these simplified prompts never ask the model to emit the
    # "Final Answer:" marker that the retry loop below looks for - confirm the
    # production prompts do, otherwise every attempt ends in the apology text.
    if not enable_context:
        vib_system_prompt = """You are a helpful assistant named {bot_name} for {business_name}. 
        {intro_doc}
        You should be polite and use "em" to refer to yourself and "Anh/Chị" to refer to the user.
        If you don't have information, politely say so."""
        query_instruction = """Answer the following question based on your knowledge:

Question: {question}

Provide a helpful answer."""
    else:
        vib_system_prompt = """You are a helpful assistant named {bot_name} for {business_name}. 
        {intro_doc}
        You should be polite and use "em" to refer to yourself and "Anh/Chị" to refer to the user.
        Answer questions based on the provided context."""
        query_instruction = """Answer the following question based on the provided context:

Context:
{context}

Question: {question}

Provide a detailed answer based on the context. Include reference URLs if mentioned in the context."""

    if force_disclaimer:
        query_instruction += "\n\nInclude a disclaimer about information accuracy."

    # Step 6: LLM Generation
    final_chat_prompt_template = ChatPromptTemplate.from_messages([
        # A ("system", template) pair is formatted with the chain input. A
        # SystemMessage instance would be sent verbatim, placeholders included.
        ("system", vib_system_prompt),
        MessagesPlaceholder(variable_name="history_messages"),
        HumanMessagePromptTemplate.from_template(query_instruction),
    ])

    chain = final_chat_prompt_template | llm_model | reason_extractor

    # Each retry rephrases the question slightly by prepending a lead-in.
    retry_attempt_questions = [
        "",
        "Trả lời câu này: ",
        "Cho tôi hỏi: ",
        "Tôi muốn hỏi là: ",
    ]

    # Contact info setup. It is passed to the chain but no template above
    # references {contact_info}, so it currently has no effect on the prompt.
    contact_info = _build_contact_info()

    # Retry loop for LLM generation
    response_final = ""
    question_list = []
    check_final = False

    for retry_question in retry_attempt_questions:
        chain_input = {
            "question": retry_question + user_query,
            "history_messages": history_messages,
            "bot_name": BOT_NAME,
            "business_name": BUSINESS_NAME,
            "web_link": WEB_LINK,
            "hotline": HOTLINE,
            "intro_doc": INTRO_DOC,
            "contact_info": contact_info,
        }
        if enable_context:
            chain_input["context"] = context_prompt

        response = chain.invoke(input=chain_input)
        # The extractor returns None as content when the reply holds only
        # reasoning; treat that as an empty answer rather than crashing.
        final_output = _normalize_final_answer_marker(response['content'] or "")

        try:
            check_final, response_final_tokens, question_list_json, _ = check_prefix_tokens_with_questions(final_output)
            response_final = ''.join(response_final_tokens)
            question_list = question_list_json.get("question_list", [])
        except Exception as e:
            print(f"Error at extract final answer: {e}")
            check_final = False

        if check_final:
            break

    if not check_final:
        response_final = "Em xin lỗi, em chưa hiểu ý của Anh/Chị. Mong Anh/Chị có thể hỏi lại lần nữa ạ."
        question_list = []

    # The questions come from model-written JSON; ignore anything that is not a list.
    if not isinstance(question_list, list):
        question_list = []

    # Step 7: Response Processing
    final_response, reference_json_valid = extract_reference_from_response(response_final)
    final_reference_json_list = []
    seen_links = set()

    for item in reference_json_valid.get('reference_url', []):
        link = item['payload']['url']
        # Keep one entry per URL and skip entries whose title is itself a URL.
        if link in seen_links or is_valid_url(item['title']):
            continue
        final_reference_json_list.append(item)
        seen_links.add(link)

    # Filter URLs from response
    final_response = filter_url_in_response(final_response)

    # Final output: always three suggestions, padded with the defaults.
    return {
        "user_msg": final_response.lstrip(),
        "raw_llm_msg": response_final.lstrip(),
        "question_list": (question_list + DEFAULT_SUGGESTION_QUESTIONS)[0:3],
        "reference_list": final_reference_json_list,
    }


# ============================================================================
# Example Usage
# ============================================================================

if __name__ == "__main__":
    # Demo: one question asked after a short greeting exchange.
    sample_history = [
        {"role": "user", "content": "Xin chào"},
        {"role": "assistant", "content": "Chào Anh/Chị! Em có thể giúp gì cho Anh/Chị?"},
    ]
    result = knowledge_base_chat(
        user_query="Giá món sushi là bao nhiêu?",
        history=sample_history,
        force_disclaimer=False,
    )

    # Print the answer, the suggested follow-ups and the references.
    separator = "=" * 50
    print(separator)
    print("FINAL RESPONSE:")
    print(separator)
    user_message = result['user_msg']
    suggestions = result['question_list']
    references = result['reference_list']
    print(f"User Message: {user_message}")
    print(f"\nSuggested Questions: {suggestions}")
    print(f"\nReferences: {references}")