## Load "Vector DB" and run RAG query on pre-trained instruct LLM

In [31]:
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import pickle
import sys
from threading import Thread
import time

import faiss
from sentence_transformers import SentenceTransformer, util
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
    TextStreamer,
    set_seed,
)

set_seed(1234)

### Load chunk contents and embedding index

In [44]:
with open("data/lincoln_chunks.pkl", "rb") as file:
    chunks = pickle.load(file)

index = faiss.read_index("data/lincoln_chunks.index")

In [45]:
len(chunks), index.ntotal

(153, 153)

### Load LLM to run queries

In [47]:
# query_model = "google/gemma-2b-it"
query_model = "google/gemma-7b-it"

tokenizer = AutoTokenizer.from_pretrained(query_model)
model = AutoModelForCausalLM.from_pretrained(
    query_model, torch_dtype=torch.bfloat16, device_map="mps"
)

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.11G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

### Construct RAG Query and Run

In [48]:
def run_query_streamed(query, query_model):
    system_prompt = "You are a helpful assistant who answers question truthfully to the best of your knowledge. You decline to answer if you do not know the answer."

    chat = [
        # {
        #     "role": "system",
        #     "content": system_prompt,
        # },
        
        {

            "role": "user",
            "content": f"{system_prompt}\n\n{query}",  # Pick up the book
        },
    ]

    formatted_prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True, return_tensors="pt"
    )
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to("mps")

    streamer = TextStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    _ = model.generate(**inputs, streamer=streamer, max_new_tokens=512)

In [49]:
query_encoder = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")

In [50]:
def top_k_chunks(query: str, k=1) -> str:
    """
    Find closest chunk for a given query.
    """
    embeddings = query_encoder.encode([query])
    D, I = index.search(embeddings, k)
    return D, I


def run_rag_query_streamed(query, query_model, k=3):
    # Retrieve most similar chunks
    D, I = top_k_chunks(query, k=k)
    # formatted_chunks = '\n\n'.join(["Document: " + chunks[i] for i in I[0]])
    formatted_chunks = ' '.join([chunks[i] for i in I[0]])
    
    # rag_query = f"Answer the query below and ground your answer in facts contained in the documents below:\n\nQuery: {query}\n\n{formatted_chunks}"

    rag_query = f"{formatted_chunks}\n\nAnswer the following question: {query}"

    print(rag_query)
    print("\n")

    run_query_streamed(rag_query, query_model)

    # for i, d in zip(I[0], D[0]):
    #     print(d, chunks[i])
    #     print("")

### Debug Examples
Compare answers to questions, with and without context from most similar document chunks.

In [51]:
query = "Why did Abraham Lincoln grow a beard?"

print("google/gemma-2b-it")
run_query_streamed(query, query_model)
print()

print("google/gemma-2b-it + RAG")
run_rag_query_streamed(query, query_model)
print()

google/gemma-2b-it
I 