In [None]:
!pip install -U xformers --index-url https://download.pytorch.org/whl/cu121
!pip install langchain optimum qdrant-client wikipedia FastAPI uvicorn pyngrok
!pip install --upgrade pydantic
!pip install vllm

In [2]:
import os
import warnings


def read_secret(name: str):
    """Return the secret ``name`` from the environment, or ``None`` when it is unset/empty.

    Emits a warning instead of raising so the rest of the configuration can still
    load; whichever call actually needs the secret will fail later with its own error.
    """
    value = os.environ.get(name)
    if not value:
        warnings.warn(f"Environment variable {name!r} is not set.", stacklevel=2)
        return None
    return value


# --- Models -----------------------------------------------------------------
# Alternative AWQ-quantized checkpoint:
# GENERATE_MODEL_NAME = "phatjk/vietcuna-7b-v3-AWQ"
GENERATE_MODEL_NAME = "vilm/vietcuna-3b-v2"
EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# --- Services (not secret; may be overridden through the environment) --------
QDRANT_URL = os.environ.get(
    "QDRANT_URL",
    "https://d3966086-8c65-4b03-895a-6926e1f83994.us-east4-0.gcp.cloud.qdrant.io",
)
QDRANT_COLLECTION_NAME = "Luat_vectordb"
NGROK_STATIC_DOMAIN = os.environ.get("NGROK_STATIC_DOMAIN", "briefly-knowing-treefrog.ngrok-free.app")

# --- Secrets ------------------------------------------------------------------
# Never hard-code these. Provide them as environment variables (e.g. `%env NAME=...`
# in a private cell, or export them before starting Jupyter). The tokens that were
# previously written here were published with the notebook and must be revoked
# and re-issued.
NGROK_TOKEN = read_secret("NGROK_TOKEN")
HUGGINGFACE_API_KEY = read_secret("HUGGINGFACE_API_KEY")
QDRANT_API_KEY = read_secret("QDRANT_API_KEY")

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.bettertransformer import BetterTransformer
import torch
# Run on the GPU when one is visible; the reranker and its tokenized inputs are
# both moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cross-encoder that scores (query, passage) pairs; used to rerank retrieved documents.
RERANK_MODEL_NAME = 'amberoad/bert-multilingual-passage-reranking-msmarco'
model_rerank = AutoModelForSequenceClassification.from_pretrained(RERANK_MODEL_NAME).to(device)
# Optional speed-up (disabled): model_rerank = BetterTransformer.transform(model_rerank)
tokenizer_rerank = AutoTokenizer.from_pretrained(RERANK_MODEL_NAME)

In [None]:
!pip install -U langchain-community

In [None]:
# from langchain.schema.document import Document
# from langchain_core.vectorstores import VectorStoreRetriever
# from langchain.retrievers import WikipediaRetriever
# from typing import List
# class RerankRetriever(VectorStoreRetriever):
#     vectorstore: VectorStoreRetriever
#     def get_relevant_documents(self, query: str) -> List[Document]:
#         docs = self.vectorstore.get_relevant_documents(query=query)
#         candidates = [doc.page_content for doc in docs]
#         queries = [query]*len(candidates)
#         features = tokenizer_rerank(queries, candidates,  padding=True, truncation=True, return_tensors="pt").to(device)
#         with torch.no_grad():
#             scores = model_rerank(**features).logits
#             values, indices = torch.sum(scores, dim=1).sort()
#             # relevant_docs = docs[indices[0]]
#         return [docs[indices[0]],docs[indices[1]]]
# class RerankWikiRetriever(VectorStoreRetriever):
#     vectorstore: WikipediaRetriever
#     def get_relevant_documents(self, query: str) -> List[Document]:
#         docs = self.vectorstore.get_relevant_documents(query=query)
#         candidates = [doc.page_content for doc in docs]
#         queries = [query]*len(candidates)
#         features = tokenizer_rerank(queries, candidates,  padding=True, truncation=True, return_tensors="pt").to(device)
#         with torch.no_grad():
#             scores = model_rerank(**features).logits
#             values, indices = torch.sum(scores, dim=1).sort()
#             # relevant_docs = docs[indices[0]]
#         return [docs[indices[0]],docs[indices[1]]]

In [32]:
from langchain.schema.document import Document
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.retrievers import WikipediaRetriever
from langchain_core.retrievers import BaseRetriever
from typing import List


def _relevance_scores(logits: torch.Tensor) -> torch.Tensor:
    """Collapse cross-encoder logits to one relevance score per (query, passage) pair.

    Accepts logits of shape ``(n,)``, ``(n, 1)`` or ``(n, C>=2)``; higher means more relevant.
    """
    if logits.dim() == 1:
        return logits
    if logits.size(-1) == 1:
        return logits[:, 0]
    # Classification head: column 1 is taken as the "relevant" class (the MS MARCO
    # binary-relevance convention). Summing both logits, as before, mixes the
    # "relevant" and "not relevant" evidence. NOTE(review): confirm on the model card.
    return torch.log_softmax(logits, dim=-1)[:, 1]


def _rerank_top_docs(query: str, docs: List[Document], top_n: int = 2) -> List[Document]:
    """Score each of ``docs`` against ``query`` and return the ``top_n`` best, best first.

    Uses the module-level ``tokenizer_rerank``, ``model_rerank`` and ``device``.
    Returns ``[]`` when there are no candidates, and fewer than ``top_n`` documents
    when fewer candidates were retrieved.
    """
    if not docs or top_n <= 0:
        return []
    candidates = [doc.page_content for doc in docs]
    features = tokenizer_rerank(
        [query] * len(candidates), candidates,
        padding=True, truncation=True, return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = model_rerank(**features).logits
    scores = _relevance_scores(logits)
    # topk yields the highest scores first; the previous ascending .sort() kept the
    # two *least* relevant documents.
    best = torch.topk(scores, k=min(top_n, len(docs))).indices.tolist()
    return [docs[i] for i in best]


class RerankRetriever(BaseRetriever):
    """Fetch candidates from a vector-store retriever, then keep the best ones by cross-encoder score.

    Subclasses ``BaseRetriever`` rather than ``VectorStoreRetriever``: this class wraps
    a retriever, not a vector store, and VectorStoreRetriever's tracing hooks expect
    ``self.vectorstore`` to be a real vector store.
    """

    vectorstore: VectorStoreRetriever
    # Number of reranked documents returned to the chain.
    top_n: int = 2

    @property
    def embeddings(self):
        """Embeddings of the wrapped retriever, or of the vector store underneath it."""
        for holder in (self.vectorstore, getattr(self.vectorstore, "vectorstore", None)):
            found = getattr(holder, "embeddings", None)
            if found is not None:
                return found
        raise AttributeError("VectorStoreRetriever does not have 'embeddings' attribute.")

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Retrieve candidates for ``query`` and return the ``top_n`` most relevant."""
        docs = self.vectorstore.invoke(query)
        return _rerank_top_docs(query, docs, self.top_n)


class RerankWikiRetriever(BaseRetriever):
    """Fetch Wikipedia pages, then keep the best ones by cross-encoder score."""

    vectorstore: WikipediaRetriever
    # Number of reranked documents returned to the chain.
    top_n: int = 2

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Retrieve Wikipedia candidates for ``query`` and return the ``top_n`` most relevant."""
        docs = self.vectorstore.invoke(query)
        return _rerank_top_docs(query, docs, self.top_n)

In [None]:
!pip install accelerate bitsandbytes

In [7]:
from langchain.retrievers import WikipediaRetriever
from langchain.vectorstores import Qdrant
from langchain.llms import HuggingFacePipeline
from qdrant_client import QdrantClient
from langchain.prompts import PromptTemplate
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.chains import RetrievalQA,MultiRetrievalQAChain
from langchain.llms import VLLM
from langchain.llms import HuggingFaceHub

class LLMServe:
    """Serve a LangChain RetrievalQA ("stuff") chain over one of two knowledge sources.

    ``"wiki"`` retrieves from Vietnamese Wikipedia; any other source name retrieves
    from the Qdrant collection ``QDRANT_COLLECTION_NAME``. Both retrievers rerank
    their candidates before the vLLM-hosted model answers. The LLM, embeddings and
    prompt are built once; each source's retriever and chain are built on first use
    and then cached.
    """

    # Candidate-pool sizes handed to the reranker.
    WIKI_TOP_K_RESULTS = 15
    QDRANT_SEARCH_K = 15
    # Max characters of page content kept per Wikipedia document.
    WIKI_DOC_CHARS_MAX = 800

    def __init__(self) -> None:
        self.embeddings = self.load_embeddings()
        self.current_source = "wiki"
        self.retriever = self.load_retriever(retriever_name=self.current_source,
                                             embeddings=self.embeddings)
        self.pipe = self.load_model_pipeline(max_new_tokens=300)
        self.prompt = self.load_prompt_template()
        self.rag_pipeline = self.load_rag_pipeline(llm=self.pipe,
                                                   retriever=self.retriever,
                                                   prompt=self.prompt)
        # source -> (retriever, chain), so switching sources back and forth does
        # not rebuild the retriever (and open a new QdrantClient) each time.
        self._rag_cache = {self.current_source: (self.retriever, self.rag_pipeline)}

    def load_embeddings(self):
        """Create the HuggingFace Inference-API embeddings client for EMBEDDINGS_MODEL_NAME."""
        return HuggingFaceInferenceAPIEmbeddings(
            model_name=EMBEDDINGS_MODEL_NAME,
            api_key=HUGGINGFACE_API_KEY,
        )

    def load_retriever(self, retriever_name, embeddings):
        """Build a reranking retriever: Wikipedia for ``"wiki"``, otherwise the Qdrant collection."""
        if retriever_name == "wiki":
            return RerankWikiRetriever(vectorstore=WikipediaRetriever(
                lang="vi",
                doc_content_chars_max=self.WIKI_DOC_CHARS_MAX,
                top_k_results=self.WIKI_TOP_K_RESULTS,
            ))
        client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, prefer_grpc=False)
        db = Qdrant(client=client,
                    embeddings=embeddings,
                    collection_name=QDRANT_COLLECTION_NAME)
        return RerankRetriever(vectorstore=db.as_retriever(search_kwargs={"k": self.QDRANT_SEARCH_K}))

    def load_model_pipeline(self, max_new_tokens=100):
        """Create the vLLM-backed LangChain LLM for GENERATE_MODEL_NAME."""
        return VLLM(
            model=GENERATE_MODEL_NAME,
            trust_remote_code=True,  # mandatory for hf models
            max_new_tokens=max_new_tokens,
            top_k=10,
            top_p=0.95,
            temperature=0.4,
            dtype="half",
            # For an AWQ checkpoint also pass: vllm_kwargs={"quantization": "awq"}
        )

    def load_prompt_template(self):
        """Prompt with ``context`` and ``question`` slots (system line: answer from the given context)."""
        query_template = "Bạn là một chatbot thông minh trả lời câu hỏi dựa trên ngữ cảnh (context).\n\n### Context:{context} \n\n### Human: {question}\n\n### Assistant:"
        return PromptTemplate(template=query_template,
                              input_variables=["context", "question"])

    def load_rag_pipeline(self, llm, retriever, prompt):
        """Build a "stuff" RetrievalQA chain that also returns its source documents."""
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type='stuff',
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt},
            return_source_documents=True,
        )

    def rag(self, source):
        """Return the RAG chain for ``source``, building and caching it on first use.

        Also points ``current_source``, ``retriever`` and ``rag_pipeline`` at the
        selected source. If building the new retriever or chain raises, none of
        them are changed.
        """
        if source == self.current_source:
            return self.rag_pipeline
        if not hasattr(self, "_rag_cache"):
            # Tolerate instances that were not built through __init__.
            self._rag_cache = {}
        # Remember the outgoing source so switching back to it is free.
        self._rag_cache.setdefault(self.current_source, (self.retriever, self.rag_pipeline))
        if source not in self._rag_cache:
            retriever = self.load_retriever(retriever_name=source, embeddings=self.embeddings)
            chain = self.load_rag_pipeline(llm=self.pipe, retriever=retriever, prompt=self.prompt)
            self._rag_cache[source] = (retriever, chain)
        self.retriever, self.rag_pipeline = self._rag_cache[source]
        self.current_source = source
        return self.rag_pipeline

In [None]:
!pip install triton

In [None]:
# Build the RAG service once: embeddings client, default "wiki" retriever, vLLM
# model, prompt and chain. The API handlers use this object as `app`; the FastAPI
# application itself is `app_api`.
app = LLMServe()

In [10]:
# from typing import Union
# from fastapi.middleware.cors import CORSMiddleware
# from fastapi.responses import JSONResponse
# from fastapi.encoders import jsonable_encoder
# from fastapi import FastAPI
# origins = ["*"]
# app_api = FastAPI()
# app_api.add_middleware(
#     CORSMiddleware,
#     allow_origins=origins,
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

# @app_api.get("/")
# def read_root():
#     return "API RAG"

# @app_api.get("/rag/{source}")
# async def read_item(source: str, q: str | None = None):
#     if q:
#         data = app.rag(source=source)(q)
#         sources = []
#         for docs in data["source_documents"]:
#             sources.append(docs.to_json()["kwargs"])
#         res = {
#             "result" : data["result"],
#             "source_documents":sources
#         }
#         return JSONResponse(content=jsonable_encoder(res))
#     return None


In [17]:
from typing import Union, Optional
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from fastapi import FastAPI, HTTPException

import logging

# Sources accepted by /rag/{source}. "wiki" uses the Wikipedia retriever; "nttu" is
# served from the Qdrant collection (LLMServe.load_retriever sends every non-"wiki"
# name there).
VALID_SOURCES = ["nttu", "wiki"]

# uvicorn configures a handler for this logger, so records show in the server output.
rag_logger = logging.getLogger("uvicorn.error")

origins = ["*"]  # allow every origin
app_api = FastAPI()
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# website call this API with the visitor's cookies; narrow `origins` if the API
# ever relies on cookies or authentication.
app_api.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app_api.get("/")
def read_root():
    """Liveness check."""
    return {"message": "API RAG is running"}


@app_api.get("/rag/{source}")
async def read_item(source: str, q: Optional[str] = None):
    """Answer ``q`` with the RAG chain for ``source``.

    Returns:
        JSON ``{"result": <answer text>, "source_documents": [<Document kwargs>, ...]}``.

    Raises:
        HTTPException(400): ``source`` is not in ``VALID_SOURCES`` or ``q`` is missing/empty.
        HTTPException(500): retrieval or generation failed; the traceback is logged server-side.
    """
    if source not in VALID_SOURCES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid source. Must be one of: {VALID_SOURCES}"
        )

    if not q:
        raise HTTPException(
            status_code=400,
            detail="Query parameter 'q' is required"
        )

    try:
        # NOTE(review): invoke() is synchronous, so this coroutine blocks the event
        # loop for the whole retrieval + generation.
        data = app.rag(source=source).invoke(q)

        sources = [doc.to_json()["kwargs"] for doc in data["source_documents"]]
        res = {
            "result": data["result"],
            "source_documents": sources
        }
        return JSONResponse(content=jsonable_encoder(res))

    except Exception as e:
        # Converting to HTTPException alone leaves only a bare "500 Internal Server
        # Error" access-log line; log the traceback so failures can be diagnosed.
        rag_logger.exception("RAG request failed (source=%r)", source)
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred: {str(e)}"
        ) from e

In [None]:
import nest_asyncio
from pyngrok import ngrok
import uvicorn
API_PORT = 8000  # local uvicorn port; the ngrok tunnel forwards to it

if not NGROK_TOKEN:
    raise RuntimeError("NGROK_TOKEN is not set; cannot open the ngrok tunnel.")
ngrok.set_auth_token(NGROK_TOKEN)
ngrok_tunnel = ngrok.connect(API_PORT, domain=NGROK_STATIC_DOMAIN)
print('Public URL:', ngrok_tunnel.public_url)
# Allow uvicorn to run its event loop inside the notebook's already-running loop.
nest_asyncio.apply()
try:
    uvicorn.run(app_api, port=API_PORT)
finally:
    # Close the tunnel when the server stops or the cell is interrupted; a tunnel
    # still holding the static domain can make the next connect() on it fail.
    ngrok.disconnect(ngrok_tunnel.public_url)

## Debugging: inspect the Qdrant collection

In [None]:
from qdrant_client import QdrantClient

# Connect to Qdrant with the same URL and API key the RAG service uses.
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

# Collection to inspect.
# NOTE(review): this differs from QDRANT_COLLECTION_NAME, which is the collection
# the "nttu" RAG source actually queries; confirm which one is intended.
INSPECT_COLLECTION_NAME = "nttu_sotay_vector_db_v1"

# Collection metadata.
collection_info = client.get_collection(INSPECT_COLLECTION_NAME)
print("Collection info:", collection_info)

# scroll() returns (records, next_page_offset); print the first two records.
# NOTE(review): LangChain's Qdrant wrapper reads documents from the payload keys
# "page_content" and "metadata" by default; check the payloads below use them.
points = client.scroll(
    collection_name=INSPECT_COLLECTION_NAME,
    limit=2
)
for point in points[0]:
    print("\nPoint ID:", point.id)
    print("Payload:", point.payload)

In [None]:
# # Truy vấn
# query = "Xếp loại học bổng ở trường"

# # Chuyển đổi truy vấn thành vector
# query_vector = embeddings.embed_query(query)

# # Tìm kiếm trong collection
# search_results = client.search(
#     collection_name="nttu_sotay_vector_db_v1",
#     query_vector=query_vector,
#     limit=5  # Số lượng kết quả muốn lấy
# )

# # In kết quả
# for result in search_results:
#     print("\nPoint ID:", result.id)
#     print("Score:", result.score)  # Điểm số tương tự
#     print("Payload:", result.payload)  # Nội dung và metadata