In [1]:
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def create_model(
    lm_model_name = "gemma-3-4b-it", 
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
):
    """Load a causal LM and its tokenizer for inference.

    Args:
        lm_model_name: Hugging Face hub id or local path given to ``from_pretrained``.
        device: device the weights are loaded onto. The default (``'cuda'`` when
            available, else ``'cpu'``) is evaluated once, when this cell runs.

    Returns:
        ``(model, tokenizer)``, with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(lm_model_name)
    # Load straight onto `device`. Combining device_map="auto" with a later
    # .to(device) is contradictory: "auto" may shard or offload the model
    # across devices (ignoring `device`), and moving an accelerate-dispatched
    # model afterwards warns or raises.
    model = AutoModelForCausalLM.from_pretrained(
        lm_model_name,
        device_map=device,
    ).eval()

    return model, tokenizer

In [3]:
def lm_template(
    text: str, 
    system_prompt: str = "You are a helpful assistant.", 
):
    """Wrap `text` in a two-turn chat (system, then user).

    Each message's content is a list holding a single ``{"type": "text"}``
    part, which is the message format that ``generate`` hands to
    ``tokenizer.apply_chat_template``.
    """
    def _single_text_message(role, body):
        # One chat message whose content is a single text part.
        return {"role": role, "content": [{"type": "text", "text": body}]}

    system_turn = _single_text_message("system", system_prompt)
    user_turn = _single_text_message("user", text)
    return [system_turn, user_turn]

In [4]:
@torch.inference_mode()
def generate(
    prompt, 
    tokenizer, 
    model, 
    max_new_tokens: int = 256, 
    temperature: float = 1,
):
    """Run one chat turn through `model` and return only the newly generated text.

    Args:
        prompt: chat messages accepted by ``tokenizer.apply_chat_template``
            (for example the output of ``lm_template``).
        tokenizer: tokenizer providing ``apply_chat_template`` and ``decode``.
        model: generation model exposing ``device``, ``dtype``, ``config`` and
            ``generate``.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: sampling temperature. A value ``<= 0`` (or ``None``) switches
            to greedy decoding instead of sampling.

    Returns:
        The decoded completion with special tokens removed. The prompt is excluded.

    Raises:
        ValueError: if the tokenized prompt is longer than the model's
            ``max_position_embeddings``.
    """
    inputs = tokenizer.apply_chat_template(
        prompt, 
        add_generation_prompt=True, 
        tokenize=True,
        return_dict=True, 
        return_tensors="pt",
    )
    # Move every tensor to the model's device. Floating-point inputs also take
    # the model's dtype; integer ids and masks keep theirs.
    inputs = {
        k: (
            v.to(model.device, dtype=model.dtype)
            if v.dtype.is_floating_point else v.to(model.device)
        )
        for k, v in inputs.items()
    }

    input_len = inputs["input_ids"].shape[-1]

    # Multimodal configs (e.g. Gemma 3) nest the language-model settings under
    # `text_config`; text-only configs keep them at the top level.
    text_config = getattr(model.config, "text_config", None)
    if text_config is None:
        text_config = model.config
    max_len = int(text_config.max_position_embeddings)
    if input_len > max_len:
        raise ValueError(
            f"Input length {input_len} exceeds maximum allowed length of {max_len} tokens."
        )

    if temperature is not None and temperature > 0:
        # HF's temperature warper only accepts a float (an int such as 2 is
        # rejected), so normalize the type here.
        decoding_kwargs = {"do_sample": True, "temperature": float(temperature)}
    else:
        # Sampling with temperature <= 0 is rejected by HF; treat it as greedy decoding.
        decoding_kwargs = {"do_sample": False}

    generation = model.generate(
        **inputs, 
        max_new_tokens=max_new_tokens, 
        **decoding_kwargs,
    )
    # Keep only the tokens produced after the prompt, for the single sequence.
    new_tokens = generation[0][input_len:]

    response = tokenizer.decode(
        new_tokens, 
        skip_special_tokens=True
    )

    return response

In [5]:
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from dataclasses import dataclass
import json
from typing import List

In [6]:
def create_embedding_model(
    emb_model_name: str = "all-MiniLM-L6-v2", 
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
):
    """Load a SentenceTransformer encoder, place it on `device`, and return it in eval mode.

    The default `device` is chosen once, when this cell runs: 'cuda' if
    available, otherwise 'cpu'.
    """
    encoder = SentenceTransformer(emb_model_name)
    encoder = encoder.to(device)
    return encoder.eval()

In [14]:
def query_template(
    query, 
):
    """Return the retrieval text for `query`: its question with one newline on each side."""
    return "\n{}\n".format(query.question)

def prompt_template(
    context, 
    query, 
):
    """Build the user message: the references block, the question, then answer-format instructions.

    The result starts and ends with a newline.
    """
    prompt_lines = [
        "",
        "References:",
        f"{context}",
        "Question:",
        f"{query.question}",
        'Do not use markdown syntax to answer and put the answer after "Answer:"',
        "",
    ]
    return "\n".join(prompt_lines)

@dataclass
class DocChunk:
    """One retrievable piece of the corpus, optionally with its embedding.

    Instances are created by ``preprocess_pages2chunks`` with a sequential
    ``idx`` and the chunk text. The embedding and score are optional extras.
    """
    # Sequential id of the chunk within the corpus.
    idx: int
    # Chunk text that is embedded and inserted into the prompt.
    content: str
    # Embedding vector attached after encoding; None until then.
    embedding: Optional[np.ndarray] = None
    # Optional retrieval similarity score. It must be annotated: a bare
    # `score = None` is ignored by @dataclass (absent from __init__/repr/eq).
    score: Optional[float] = None

@dataclass
class lookupQuery:
    """A user question to answer with retrieval-augmented generation."""

    # Raw user text: formatted by query_template for retrieval and by prompt_template for the LLM.
    question: str

In [8]:
def preprocess_pages2chunks(
    pages_list, 
    max_chars: Optional[int] = None,
    overlap: int = 0,
):
    """Convert parsed JSONL page records into sequentially numbered DocChunks.

    Args:
        pages_list: iterable of dicts, each with a ``'text'`` entry. The text is
            stripped of surrounding whitespace.
        max_chars: ``None`` (default) keeps exactly one chunk per page. A positive
            int splits each page's text into consecutive windows of at most
            ``max_chars`` characters. This is useful because sentence encoders
            truncate inputs beyond their ``max_seq_length``, so the tail of a long
            page would otherwise never be embedded.
        overlap: characters shared by consecutive windows. It is only used when
            ``max_chars`` is set and must satisfy ``0 <= overlap < max_chars``.

    Returns:
        list of DocChunk whose ``idx`` equals its position in the list. That
        position is also the row of its embedding in the index built by
        ``load_doc_from_path``.

    Raises:
        ValueError: if ``max_chars`` or ``overlap`` is out of range.
    """
    if max_chars is not None:
        if max_chars <= 0:
            raise ValueError(f"max_chars must be a positive integer, got {max_chars}.")
        if not 0 <= overlap < max_chars:
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < max_chars ({max_chars}), got {overlap}."
            )

    all_chunks = []
    for page in pages_list:
        for piece in _split_text(page['text'].strip(), max_chars, overlap):
            all_chunks.append(DocChunk(idx=len(all_chunks), content=piece))

    return all_chunks


def _split_text(text, max_chars, overlap):
    """Return `text` as-is when it fits, or as windows of `max_chars` chars sharing `overlap` chars."""
    if max_chars is None or len(text) <= max_chars:
        return [text]
    step = max_chars - overlap
    windows = []
    for start in range(0, len(text), step):
        windows.append(text[start:start + max_chars])
        if start + max_chars >= len(text):
            # This window already reaches the end; further ones would be pure overlap.
            break
    return windows

def load_doc_from_path(
    documents_path: str, 
    embedding_model, 
):
    """Read a JSONL corpus, embed every chunk, and build a cosine-similarity FAISS index.

    Args:
        documents_path: path to a UTF-8 JSONL file. Each non-blank line is a
            JSON object with a ``'text'`` entry.
        embedding_model: encoder exposing ``encode(list_of_str)`` that returns a
            2-D array with one row per input.

    Returns:
        ``(chunk_list, index)``. Row ``i`` of ``index`` is the L2-normalized
        embedding of ``chunk_list[i]``, which is also stored on that chunk's
        ``embedding`` attribute.

    Raises:
        ValueError: if the file contains no non-blank lines.
    """
    pages_list = []
    with open(documents_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if line:
                pages_list.append(json.loads(line))

    chunk_list = preprocess_pages2chunks(
        pages_list=pages_list
    )
    if not chunk_list:
        raise ValueError(f"No documents found in {documents_path!r}.")
    contents = [doc.content for doc in chunk_list]

    # faiss.normalize_L2 works in place and only accepts C-contiguous float32
    # data, so convert *before* normalizing.
    embeddings = np.ascontiguousarray(
        embedding_model.encode(contents), dtype=np.float32
    )
    faiss.normalize_L2(embeddings)
    # Inner product of unit vectors == cosine similarity.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)

    # Attach each normalized embedding to its chunk (same order as the index rows).
    for doc, embedding in zip(chunk_list, embeddings):
        doc.embedding = embedding

    return chunk_list, index

def retrieve_relevant_docs(
    search_query: str, 
    embedding_model, 
    index, 
    chunk_list, 
    top_k: int = 3
) -> List[DocChunk]:
    """Return up to `top_k` chunks most similar to `search_query`, best first.

    Args:
        search_query: text to embed and search for.
        embedding_model: encoder exposing ``encode(list_of_str)``.
        index: FAISS index whose row ``i`` corresponds to ``chunk_list[i]``.
        chunk_list: chunks in index order.
        top_k: maximum number of results. Values ``<= 0`` return an empty list.

    Returns:
        Distinct DocChunks (deduplicated by ``idx``) in FAISS rank order.
        Similarity scores are not returned, and the shared chunk objects are
        not modified.
    """
    if top_k <= 0:
        # FAISS rejects k <= 0.
        return []

    # normalize_L2 needs C-contiguous float32 and works in place, so cast first.
    query_embedding = np.ascontiguousarray(
        embedding_model.encode([search_query]), dtype=np.float32
    )
    faiss.normalize_L2(query_embedding)
    _scores, indices = index.search(query_embedding, top_k)

    unique_docs = []
    seen = set()
    for position in indices[0]:
        # FAISS pads with -1 when the index holds fewer than top_k vectors.
        # A plain `position < len(...)` check lets -1 through and wrongly
        # returns chunk_list[-1], so check both bounds.
        if not 0 <= position < len(chunk_list):
            continue
        doc = chunk_list[position]
        if doc.idx not in seen:
            seen.add(doc.idx)
            unique_docs.append(doc)

    return unique_docs

In [15]:
def rag_ask(
    user_query: str, 
    model,
    tokenizer,
    embedding_model, 
    index, 
    chunk_list, 
    top_k: int = 3, 
    max_new_tokens: int = 256,
    temperature: float = 1,
) -> dict:
    """Answer `user_query` with retrieval-augmented generation.

    The function retrieves up to `top_k` chunks for the query and lists them
    as numbered references in the prompt, then asks the LLM to answer.

    Returns:
        dict with keys ``'response'`` (generated text), ``'prompt'`` (chat
        messages sent to the model) and ``'relevant_docs'`` (retrieved DocChunks).
    """
    query = lookupQuery(
        question=user_query, 
    )
    search_query = query_template(
        query=query, 
    )
    relevant_docs = retrieve_relevant_docs(
        search_query=search_query, 
        embedding_model=embedding_model, 
        index=index, 
        chunk_list=chunk_list, 
        top_k=top_k, 
    )

    # One numbered line per retrieved chunk: "References 1:<content>\n", ...
    context = "".join(
        f"References {i}:{doc.content}\n"
        for i, doc in enumerate(relevant_docs, start=1)
    )

    text = prompt_template(
        context=context, 
        query=query, 
    )
    prompt = lm_template(
        text=text
    )
    response = generate(
        prompt=prompt,
        tokenizer=tokenizer,
        model=model,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )

    return {
        'response': response, 
        'prompt': prompt, 
        'relevant_docs': relevant_docs, 
    }

In [10]:
# Generator LLM and the sentence encoder used for retrieval.
model, tokenizer = create_model(lm_model_name="gemma-3-4b-it")
embedding_model = create_embedding_model(emb_model_name="all-MiniLM-L6-v2")

# Retrieval corpus: chunks from the JSONL file plus a FAISS index over their embeddings.
documents_path = './dataset/ncku_wikipedia_2510080406.jsonl'
chunk_list, index = load_doc_from_path(
    documents_path=documents_path,
    embedding_model=embedding_model,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.19s/it]


In [12]:
# Demo inputs: one small-talk message, then three factual questions about NCKU.
test_cases = [
    "Hello! How are you today?",
    "When was National Cheng Kung University established?",
    "How many colleges does National Cheng Kung University have?",
    "Where is National Cheng Kung University located?",
]

In [16]:
print("Start Demo!\n")

# Run every demo input through the full RAG pipeline and show the answer.
for i, query in enumerate(test_cases, start=1):
    print(f"Test Case ({i}) {'=' * 50}\nuser input: {query}")

    response = rag_ask(
        user_query=query,
        model=model,
        tokenizer=tokenizer,
        embedding_model=embedding_model,
        index=index,
        chunk_list=chunk_list,
    )["response"]
    print(f"Model response: {response}\n{'=' * 64}\n")

print("\nDemo Completed!")

Start Demo!

user input: Hello! How are you today?
Model response: Answer:I’m doing well, thank you for asking! How about you?

user input: When was National Cheng Kung University established?
Model response: Answer:National Cheng Kung University was established in January 1931.

user input: How many colleges does National Cheng Kung University have?
Model response: National Cheng Kung University has nine colleges.
Answer: Nine

user input: Where is National Cheng Kung University located?
Model response: National Cheng Kung University is located in the East District of Tainan, Taiwan.
Answer: National Cheng Kung University is located in the East District of Tainan, Taiwan.


Demo Completed!
