In [1]:
import torch
from transformers import BitsAndBytesConfig # Corrected import
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

In [2]:
# Transformers-format checkpoint. get_hf_llm loads it with
# AutoModelForCausalLM.from_pretrained + a BitsAndBytesConfig, which needs regular
# safetensors/bin weights; the "-GGUF" variant of this repo is meant for GGUF
# runtimes and cannot be loaded that way without a `gguf_file=` argument.
model_name:str = "vilm/vinallama-7b-chat"

In [3]:
# Build the LLM loading function (4-bit quantization config + loader)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers import BitsAndBytesConfig
from langchain_huggingface import HuggingFacePipeline

# 4-bit bitsandbytes quantization settings passed to from_pretrained in get_hf_llm.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

In [4]:
def get_hf_llm(model_name=model_name, max_new_token=1024, **kwargs):
  """Load a 4-bit (NF4) causal LM and wrap it as a LangChain LLM.

  Args:
    model_name: Hub repo id or local path of a transformers-format checkpoint;
      defaults to the notebook-level ``model_name``.
    max_new_token: Maximum number of new tokens generated per call.
    **kwargs: Extra text-generation pipeline / generation options such as
      ``temperature``, ``top_p``, ``do_sample`` or ``return_full_text``.
      They are forwarded to the transformers pipeline and override the
      defaults above on key collisions.

  Returns:
    A ``HuggingFacePipeline`` wrapping the text-generation pipeline.
  """
  model = AutoModelForCausalLM.from_pretrained(
      model_name,
      quantization_config=nf4_config,
      low_cpu_mem_usage=True,
      # Place the weights at load time: the pipeline below receives an
      # already-instantiated model, so device placement belongs here.
      device_map="auto",
  )
  tokenizer = AutoTokenizer.from_pretrained(model_name)

  # Generation options must reach the transformers pipeline itself;
  # HuggingFacePipeline's `model_kwargs` only affects models it loads via
  # `from_model_id`, so passing them there silently dropped e.g. temperature.
  pipeline_kwargs = {
      "max_new_tokens": max_new_token,
      "pad_token_id": tokenizer.eos_token_id,
      **kwargs,
  }
  # temperature/top_p/top_k only take effect in sampling mode; honour the
  # caller's intent unless they chose do_sample explicitly.
  if any(key in pipeline_kwargs for key in ("temperature", "top_p", "top_k")):
    pipeline_kwargs.setdefault("do_sample", True)

  model_pipeline = pipeline(
      "text-generation",
      model=model,
      tokenizer=tokenizer,
      **pipeline_kwargs,
  )

  llm = HuggingFacePipeline(pipeline=model_pipeline)
  return llm

In [5]:
# offline_rag.py
import re

from langchain_classic import hub
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

class Str_OutputParser(StrOutputParser):
    """``StrOutputParser`` that keeps only the text after an ``Answer:`` marker.

    When no marker is found, the raw model output is returned unchanged.
    """

    def __init__(self) -> None:
        super().__init__()

    def parse(self, text: str) -> str:
        """Parse raw LLM output by delegating to ``extract_answer``."""
        return self.extract_answer(text)

    def extract_answer(self,
                       text_response: str,
                       pattern: str = r"Answer:\s*(.*)"
                       ) -> str:
        """Return capture group 1 of the first ``pattern`` match, stripped.

        Args:
            text_response: Raw text produced by the LLM.
            pattern: Regex (searched with ``re.DOTALL``, so ``.`` spans
                newlines) whose first group captures the answer.

        Returns:
            The stripped captured text, or ``text_response`` unchanged when
            ``pattern`` does not match.
        """
        # NOTE(review): the *first* match is used. If the LLM output echoes a
        # prompt whose retrieved context itself contains "Answer:", the
        # extracted text will start inside that context.
        found = re.search(pattern, text_response, re.DOTALL)
        if found is None:
            return text_response
        return found.group(1).strip()

class Offline_RAG:
    """Assemble a retrieval-augmented generation chain around ``llm``.

    The prompt template ``rlm/rag-prompt`` is pulled from the LangChain hub at
    construction time (NOTE(review): presumably needs network access), and the
    LLM output is post-processed by ``Str_OutputParser``.
    """

    def __init__(self, llm) -> None:
        self.llm = llm
        self.prompt = hub.pull("rlm/rag-prompt")
        self.str_parser = Str_OutputParser()

    def format_docs(self, docs):
        """Join the ``page_content`` of every document, separated by a blank line."""
        contents = [doc.page_content for doc in docs]
        return "\n\n".join(contents)

    def get_chain(self, retriever):
        """Build ``{context, question} | prompt | llm | parser``.

        ``context`` is the retriever's output rendered by ``format_docs``;
        ``question`` is the chain input passed through unchanged.
        """
        context_step = retriever | self.format_docs
        inputs = {"context": context_step, "question": RunnablePassthrough()}
        return inputs | self.prompt | self.llm | self.str_parser

In [None]:
# vectorstore.py
from typing import Union
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS


class VectorDB:
  """Index documents in a LangChain vector store and expose a retriever.

  Args:
    documents: Documents to index, passed to ``vector_db.from_documents``.
    vector_db: Vector-store *class* (e.g. ``Chroma`` or ``FAISS``) whose
      ``from_documents`` builds the index; defaults to ``Chroma``.
    embedding: Embedding model. When omitted, one shared
      ``sentence-transformers/all-MiniLM-L6-v2`` model is created on first use.

  Raises:
    ValueError: If ``documents`` is ``None``.
  """

  # Shared default embedding model, created lazily by the first instance that
  # needs it. (As a default argument it was built as soon as this cell ran.)
  _default_embedding = None

  def __init__(self,
               documents = None,
               vector_db: Union[Chroma, FAISS] = Chroma,
               embedding = None,
               ) -> None:
    if documents is None:
      raise ValueError("documents must be provided to build the vector store")
    if embedding is None:
      if VectorDB._default_embedding is None:
        VectorDB._default_embedding = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2")
      embedding = VectorDB._default_embedding
    self.vector_db = vector_db
    self.embedding = embedding
    self.db = self._build_db(documents)

  def _build_db(self, documents):
    """Build the store via the vector-store class's ``from_documents``."""
    return self.vector_db.from_documents(
                                        documents = documents,
                                        embedding = self.embedding
                                      )

  def get_retriever(self,
                    search_type: str = "similarity",
                    search_kwargs: Union[dict, None] = None):
    """Return a retriever over the indexed documents.

    Args:
      search_type: Retrieval strategy passed to ``as_retriever``.
      search_kwargs: Search options; defaults to ``{"k": 10}``.
    """
    # Fresh dict per call, so a retriever that keeps a reference to it cannot
    # leak changes into later calls through a shared default.
    if search_kwargs is None:
      search_kwargs = {"k": 10}
    return self.db.as_retriever(
                                search_type = search_type,
                                search_kwargs = search_kwargs
                              )

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

In [None]:
# file_loader
from typing import Union, List, Literal
import glob
from tqdm import tqdm
import multiprocessing
from langchain_community.document_loaders import PyPDFLoader, PyPDFDirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

def remove_non_utf8_characters(text):
    """Drop characters that cannot be encoded as UTF-8.

    Only unencodable code points (in a Python ``str`` these are lone surrogates,
    which text extraction can occasionally produce) are removed. All other text
    is kept, including non-ASCII letters such as Vietnamese diacritics; the
    previous ``ord(char) < 128`` filter stripped those as well.

    Args:
        text: Text to clean.

    Returns:
        ``text`` with unencodable characters removed.
    """
    return text.encode("utf-8", errors="ignore").decode("utf-8")

def load_pdf(pdf_file):
  """Load every page of ``pdf_file`` and clean each page's text in place.

  Pages are read with ``PyPDFLoader(..., extract_images=True)`` and each
  ``page_content`` is passed through ``remove_non_utf8_characters``.

  Returns:
    The list of page documents returned by the loader.
  """
  loader = PyPDFLoader(pdf_file, extract_images=True)
  pages = loader.load()
  for page in pages:
    cleaned = remove_non_utf8_characters(page.page_content)
    page.page_content = cleaned
  return pages

def get_num_cpu():
  """Return the CPU count reported by ``multiprocessing.cpu_count()``."""
  cpu_total = multiprocessing.cpu_count()
  return cpu_total

class BaseLoader:
  """Base class for callable document loaders.

  Attributes:
    num_processes: CPU count from ``get_num_cpu()``; subclasses may use it as
      an upper bound on worker processes.
  """

  def __init__(self) -> None:
    self.num_processes = get_num_cpu()

  def __call__(self, files: List[str], **kwargs):
    """Load ``files``; subclasses must override this.

    Raises:
      NotImplementedError: Always. The base class cannot load anything, and
        silently returning ``None`` would hide a missing override.
    """
    raise NotImplementedError(f"{type(self).__name__} must implement __call__")

class PDFLoader(BaseLoader):
  """Load PDF files in parallel with a ``multiprocessing.Pool``.

  Each file is loaded by ``load_pdf``. The pages of all files are concatenated
  in the order the files were given.
  """

  def __init__(self) -> None:
    super().__init__()

  def __call__(self, pdf_file: List[str], **kwargs):
    """Load all pages from the PDF paths in ``pdf_file``.

    Args:
      pdf_file: PDF paths to load.
      **kwargs: ``workers`` (optional) caps the number of processes; it
        defaults to ``self.num_processes``.

    Returns:
      A flat list of page documents, grouped by file in input order.
    """
    if not pdf_file:
      return []
    workers = kwargs.get("workers", self.num_processes)
    # Never more processes than CPUs or files, and always at least one
    # (Pool(processes=0) raises).
    num_processes = max(1, min(self.num_processes, workers, len(pdf_file)))
    doc_loaded = []
    # NOTE(review): load_pdf lives in the notebook's __main__; under the
    # "spawn" start method (macOS/Windows default) workers may not find it.
    with multiprocessing.Pool(processes=num_processes) as pool:
      with tqdm(total=len(pdf_file), desc="Loading PDFs", unit="file") as pbar:
        # imap (not imap_unordered) keeps results in input order, so the
        # resulting chunk order is reproducible across runs.
        for result in pool.imap(load_pdf, pdf_file):
          doc_loaded.extend(result)
          pbar.update(1)
    return doc_loaded

class TextSplitter:
  """Callable wrapper around ``RecursiveCharacterTextSplitter``.

  ``splitter(documents)`` returns the documents split into chunks.

  Args:
    separators: Separators tried in order; defaults to
      ``["\\n\\n", "\\n", " ", ""]``.
    chunk_size: Maximum chunk size passed to the splitter.
    chunk_overlap: Overlap between consecutive chunks.
  """

  def __init__(self,
               separators: Union[List[str], None] = None,
               chunk_size: int = 300,
               chunk_overlap: int = 0
               ) -> None:
    if separators is None:
      # Fresh list per instance instead of a shared mutable default.
      separators = ["\n\n", "\n", " ", ""]
    self.splitter = RecursiveCharacterTextSplitter(
                                        separators = separators,
                                        chunk_size = chunk_size,
                                        chunk_overlap = chunk_overlap
                                      )

  def __call__(self, documents):
    """Split ``documents`` into chunks.

    Must be spelled ``__call__``: the former ``__call_`` was an ordinary
    (name-mangled) method, so calling an instance raised TypeError.
    """
    return self.splitter.split_documents(documents)

class Loader:
  """Load documents of a supported type and split them into chunks.

  Args:
    file_type: Document type to load; only ``"pdf"`` is supported.
    split_kwargs: Keyword arguments for ``TextSplitter``; defaults to
      ``{"chunk_size": 300, "chunk_overlap": 0}``.

  Raises:
    ValueError: If ``file_type`` is not supported.
  """

  def __init__(self,
            file_type: Literal["pdf"] = "pdf",
            split_kwargs: Union[dict, None] = None
            ) -> None:
    # Validate before building anything. The old default was the typing object
    # Literal["pdf"] itself, which made Loader() fail its own check.
    if file_type != "pdf":
      raise ValueError("file_type must be pdf")
    if split_kwargs is None:
      split_kwargs = {"chunk_size": 300, "chunk_overlap": 0}
    self.file_type = file_type
    self.doc_loader = PDFLoader()
    # Attribute name (sic) kept for compatibility with existing callers.
    self.doc_spltter = TextSplitter(**split_kwargs)

  def load(self, pdf_files: Union[str, List[str]], workers: int = 1):
    """Load and split one PDF path or a list of PDF paths.

    Args:
      pdf_files: A single path or a list of paths.
      workers: Maximum number of loader processes.

    Returns:
      The chunked documents produced by the splitter.
    """
    if isinstance(pdf_files, str):
      pdf_files = [pdf_files]
    doc_loaded = self.doc_loader(pdf_files, workers=workers)
    return self.doc_spltter(doc_loaded)

  def load_dir(self, dir_path: str, workers: int = 1):
    """Load and split every ``*.pdf`` file directly inside ``dir_path``.

    Raises:
      ValueError: If ``file_type`` is not supported.
      FileNotFoundError: If ``dir_path`` contains no matching files.
    """
    import os

    if self.file_type != "pdf":
      raise ValueError("file_type must be pdf")
    # The old pattern "*. pdf" (stray space) never matched a real file. The
    # result is sorted so document order, and thus the index, is reproducible.
    pattern = os.path.join(glob.escape(dir_path), "*.pdf")
    files = sorted(glob.glob(pattern))
    if not files:
      raise FileNotFoundError(f"No {self.file_type} files found in {dir_path}")
    return self.load(files, workers=workers)

In [None]:
# build rag source code: main.py
from pydantic import BaseModel, Field

class InputQA(BaseModel):
  # Input schema: the user's question (required field).
  question: str = Field(default=..., title="Question to ask the model")

class OutputQA(BaseModel):
  # Output schema: the model's answer (required field).
  answer: str = Field(default=..., title="Answer from the model")

def build_rag_chain(llm, data_dir, data_type, workers: int = 2):
  """Build a RAG chain over the documents found in ``data_dir``.

  Args:
    llm: LLM the chain calls to generate answers.
    data_dir: Directory containing the source documents.
    data_type: File type passed to ``Loader`` (e.g. ``"pdf"``).
    workers: Number of processes used to load files (previously fixed at 2).

  Returns:
    The runnable chain returned by ``Offline_RAG.get_chain``.
  """
  doc_loaded = Loader(file_type=data_type).load_dir(data_dir, workers=workers)
  retriever = VectorDB(documents=doc_loaded).get_retriever()
  return Offline_RAG(llm).get_chain(retriever)

In [None]:
# Trial run: load the LLM and build the RAG chain over the PDFs in genai_docs.
genai_docs = "/content/data_source/"  # Colab-style path to the PDF folder

llm = get_hf_llm(temperature=0.9)
genai_chain = build_rag_chain(llm, data_dir=genai_docs, data_type="pdf")

In [None]:
# Interactive Q&A: type a question; submit an empty line, "exit" or "quit" to stop.
while True:
  user_text = input("Enter your question: ")
  if user_text.strip().lower() in {"", "exit", "quit"}:
    break
  # Named `query`, not `input`: rebinding `input` shadows the builtin and makes
  # the next input() call fail with "object is not callable".
  query = InputQA(question=user_text)
  answer = genai_chain.invoke(query.question)
  print(answer)