In [None]:
# --- Notebook configuration ---------------------------------------------------
# Language of the RAG system prompt: "ja" (Japanese) or "en" (English)
lang = "ja"

# Model identifiers (name or local path): the chat LLM and the embedding model
model_name_or_path = "sbintuitions/sarashina2.2-3b-instruct-v0.1"
embedding_model_name_or_path = "sbintuitions/sarashina-embedding-v1-1b"

# Which files are used as RAG documents: a glob pattern inside a directory.
# NOTE: the directory below is a placeholder -- change it to match your environment
documents_pattern = "*.md"
documents_dir = "/mnt/c/Users/your_username/Documents/your_documents_directory/"

In [None]:
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from transformers.generation.streamers import TextStreamer
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.utils.quantization_config import BitsAndBytesConfig

# Load documents: every file in `documents_dir` matching `documents_pattern`,
# read as UTF-8 text
loader = DirectoryLoader(
    documents_dir,
    glob=documents_pattern,
    loader_cls=TextLoader,
    loader_kwargs={"encoding": "utf-8"},
)
documents = loader.load()

# Split documents into overlapping chunks so each retrieved passage stays small
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.split_documents(documents)

# Fail early with an actionable message: building a FAISS index from an empty
# chunk list would otherwise fail further down with a much less helpful error
if not docs:
    raise ValueError(
        f"No documents matching {documents_pattern!r} were found in {documents_dir!r}; "
        "check documents_dir / documents_pattern in the configuration cell."
    )

# Prepare embedding model
# https://python.langchain.com/api_reference/huggingface/embeddings/langchain_huggingface.embeddings.huggingface.HuggingFaceEmbeddings.html
# Note1: Not compatible with quantization by BitsAndBytes
# Note2: With only 8GB GPU, running the embedding model on GPU is tight, so running on CPU
embedding_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name_or_path,
    model_kwargs={"device": "cpu"},
)

# Store in vector DB (FAISS): embeds every chunk with the model above
db = FAISS.from_documents(docs, embedding_model)

# Prepare LLM
# Note: Without 8bit quantization by BitsAndBytes, it overflows my 8GB GPU :P and becomes extremely slow
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=quant_config,
)

In [None]:
def _rag(retriever, user_prompt):
    """Search for similar documents and create a prompt for RAG"""
    relevant_docs = retriever.invoke(user_prompt)
    context = "\n\n".join([doc.page_content for doc in relevant_docs])

    rag_prompt_en = f"""Answer the question using only the conversation with the user and the information provided below.
    ==========================
    {context}
    ==========================
    """

    rag_prompt_ja = f"""ユーザーとのやり取りと、次の記載にある情報のみを使って質問に答えてください。
    ==========================
    {context}
    ==========================
    """

    if lang == "en":
        return rag_prompt_en
    elif lang == "ja":
        return rag_prompt_ja


def _build_token(tokenizer, prompts, device):
    """Build token for input to model"""

    formatted_text = tokenizer.apply_chat_template(
        prompts,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(
        formatted_text,
        padding=True,
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    ).to(device)

    return inputs


def _extract_response(text: str) -> str:
    """Extract response from AI output."""
    text = text.split("<|assistant|>")[-1]
    text = text.split("</s>")[0]
    text = text.strip()
    return text


def generate(streamer, tokenizer, model, retriever, chat_history, user_prompt) -> str:
    """
    Generate the AI response for the current conversation using RAG.

    The retrieved context becomes a system message placed in front of
    ``chat_history``. ``user_prompt`` is only used as the retrieval query and is
    not added to the messages here, so ``chat_history`` should already contain
    the user's turn.

    :param streamer: Object for streaming output
    :param tokenizer: Tokenizer
    :param model: Model
    :param retriever: Retriever for vector DB
    :param chat_history: Chat history (not modified by this function)
    :param user_prompt: User input
    :return: Response text extracted from the decoded model output
    """
    system_message = {"role": "system", "content": _rag(retriever, user_prompt)}
    encoded = _build_token(tokenizer, [system_message] + chat_history, model.device)

    # Sampling settings: use randomness when choosing words
    sampling_options = {
        # If False, the model always picks the most likely next word. Default False
        "do_sample": True,
        # Higher temperature = more randomness. Default 1.0
        "temperature": 1.0,
        # Limits the selection of next words. Default 50
        "top_k": 50,
        # top_p=1.0 means no limit; top_p=0.9 would restrict to the top words
        # that make up 90% of the probability. Default 1.0
        "top_p": 1.0,
    }

    # The streamer receives the text while it is generated; the same content is
    # also returned by model.generate and kept in `generated`
    generated = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
        max_new_tokens=512,
        use_cache=True,
        **sampling_options,
    )

    # Special tokens are kept in the decoded text because _extract_response
    # receives the whole decoded sequence and has to locate the reply in it
    decoded = tokenizer.decode(generated[0], skip_special_tokens=False)
    return _extract_response(decoded)

In [None]:
# Conversation state shared by all chat turns; re-running this cell starts a
# new, empty conversation
chat_history = []

# Retriever created from the vector DB, used to look up context for each question
retriever = db.as_retriever()

# Streams the model's output while it is generated, skipping the prompt and
# special tokens
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

In [None]:
def chat(user_prompt):
    """Run one chat turn and record it in the module-level ``chat_history``.

    Prints the user prompt, appends it to ``chat_history``, generates the reply
    with ``generate`` (using the module-level streamer, tokenizer, model and
    retriever) and appends the reply as the assistant turn.

    :param user_prompt: User input for this turn
    :return: None; the conversation is kept in ``chat_history``
    """
    print(f"USER: {user_prompt}\n")
    user_message = {"role": "user", "content": user_prompt}
    chat_history.append(user_message)

    print("AI:")
    try:
        response = generate(streamer, tokenizer, model, retriever, chat_history, user_prompt)
    except BaseException:
        # Generation failed or was interrupted: drop the unanswered user turn so
        # the next call does not send two consecutive user messages, then
        # re-raise so the error stays visible
        if chat_history and chat_history[-1] is user_message:
            chat_history.pop()
        raise
    chat_history.append({"role": "assistant", "content": response})

In [None]:
# Example usage: one chat turn with a Japanese prompt (matches lang = "ja").
# English gloss of the prompt: "You once tried to open an account at Japan Post
# Bank, right? How did that go?"
chat("ゆうちょ銀行で口座開設しようとしたことあったよね？どうなった？")
# With lang = "en", ask in English instead, for example:
# chat("What happened when you went to San Francisco?")