## Load "Vector DB" and run RAG query on pre-trained instruct LLM
Simply prepends the retrieved chunks to the query. No fancy pipeline steps or dedicated Vector DB.

In [1]:
import os
import pickle
import sys
from threading import Thread
import time
from unsloth import FastLanguageModel

import faiss
from sentence_transformers import SentenceTransformer, util
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
    TextStreamer,
    set_seed,
)

set_seed(1234)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


### Load chunk contents and embedding index

In [2]:
# TODO: Database for chunks
# NOTE(review): only the embedding index is loaded here. `chunks`, which
# run_rag_query_streamed indexes into, is not defined in the visible cells and
# has to be loaded before a RAG query can run.
# Read the prebuilt index from a path relative to the working directory.
index = faiss.read_index("wikipedia-en.index")

In [3]:
# Sanity check: show how many entries the loaded index reports.
index.ntotal

49522046

### Load LLM to run queries

In [4]:
# Alternative that was tried earlier: query_model = "google/gemma-2b-it"
# NOTE(review): `query_model` is assigned here but is not used by the load
# call below, which reads a separate 4-bit checkpoint path.
query_model = "/home/stefanwebb/models/llms/Mistral-7B-Instruct-v0.3"

# Settings forwarded to the loader. `dtype = None` presumably leaves the
# choice of dtype to Unsloth -- confirm against the Unsloth documentation.
max_seq_length, dtype, load_in_4bit = 2048, None, True

# Load the model together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/home/stefanwebb/models/llms/mistral-7b-instruct-v0.3-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

==((====))==  Unsloth 2024.8: Fast Mistral patching. Transformers = 4.43.4.
   \\   /|    GPU: NVIDIA GeForce RTX 4090. Max memory: 23.988 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.5.0.dev20240829+cu124. CUDA = 8.9. CUDA Toolkit = 12.4.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28+d444815.d20240829. FA2 = True]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth




### Construct RAG Query and Run

In [5]:
# Switch the loaded model to Unsloth's inference mode before generating
# (see the Unsloth documentation for exactly what this toggles).
FastLanguageModel.for_inference(model)

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32768, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (

In [6]:
def run_query_streamed(query, query_model, max_new_tokens=512):
    """
    Generate an answer to `query` and stream it as it is produced.

    The query is wrapped in a two-message chat (a fixed system prompt followed
    by the user turn), rendered with the tokenizer's chat template, tokenized,
    and passed to `generate` with a `TextStreamer` attached.

    Args:
        query: Text placed in the user message.
        query_model: Model to generate with. An object exposing `generate` is
            used directly; anything else (for example a model-name string)
            falls back to the module-level `model`, which is what this
            function always used before.
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 512, the value that used to be hard-coded.

    Returns:
        None. The generated text is emitted through the streamer only.
    """
    system_prompt = "You are a helpful assistant who answers question truthfully to the best of your knowledge. You decline to answer if you do not know the answer."

    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{query}"},
    ]

    # Render the chat to text only; `return_tensors` has no effect when
    # tokenize=False, so it is not passed here.
    formatted_prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    # The rendered template already carries the special tokens it needs.
    # Tokenizing with the default add_special_tokens=True can prepend a second
    # BOS token, so special-token insertion is disabled here.
    inputs = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    ).to("cuda")

    streamer = TextStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Honour the model passed by the caller when it really is a model.
    generator = query_model if hasattr(query_model, "generate") else model

    _ = generator.generate(
        **inputs, streamer=streamer, max_new_tokens=max_new_tokens
    )

In [7]:
# Sentence-embedding model used to encode queries before searching the index.
# NOTE(review): presumably the same model that produced the vectors stored in
# the index -- confirm, since mismatched encoders would make the search meaningless.
query_encoder = SentenceTransformer("/home/stefanwebb/models/llms/multi-qa-MiniLM-L6-cos-v1")

In [8]:
def top_k_chunks(query: str, k=1) -> tuple:
    """
    Look up the `k` index entries nearest to `query`.

    The query is embedded with the module-level `query_encoder` and the
    resulting vector is searched against the module-level `index`.

    Args:
        query: Query text to embed.
        k: Number of neighbours to request from the index.

    Returns:
        The `(D, I)` pair exactly as returned by `index.search`: `D` holds the
        scores/distances and `I` the ids of the matching entries. Only one
        query is encoded, so callers read the results from row 0.
    """
    query_embedding = query_encoder.encode([query])
    return index.search(query_embedding, k)


def run_rag_query_streamed(query, query_model, k=3):
    """
    Answer `query` with retrieved context and stream the response.

    The `k` nearest chunks are looked up with `top_k_chunks`, joined with
    spaces, and placed in front of the question; the combined prompt is then
    handed to `run_query_streamed`.

    Relies on a module-level `chunks` container that can be indexed by the ids
    stored in the index. It is not defined in the visible cells and must be
    loaded before this function is called.

    Args:
        query: The user's question.
        query_model: Passed through unchanged to `run_query_streamed`.
        k: Number of chunks to retrieve as context.

    Returns:
        None. The answer is streamed by `run_query_streamed`.
    """
    # Retrieve most similar chunks (the scores are not needed here).
    _, neighbour_ids = top_k_chunks(query, k=k)

    # FAISS fills the id row with -1 when it finds fewer than k results. Used
    # as an index, -1 would silently pull in an unrelated chunk, so drop those.
    hit_ids = [int(i) for i in neighbour_ids[0] if i >= 0]
    formatted_chunks = ' '.join(chunks[i] for i in hit_ids)

    # Alternative prompt layout that was tried (not used):
    # formatted_chunks = '\n\n'.join(["Document: " + chunks[i] for i in I[0]])
    # rag_query = f"Answer the query below and ground your answer in facts contained in the documents below:\n\nQuery: {query}\n\n{formatted_chunks}"
    rag_query = f"{formatted_chunks}\n\nAnswer the following question: {query}"

    run_query_streamed(rag_query, query_model)

### Debug Examples
Compare answers to questions, with and without context from most similar document chunks.

In [9]:
k = 3
query = "Why did Abraham Lincoln grow a beard?"
# Reuse the retrieval helper rather than repeating the encode + search steps,
# so this debug cell exercises the same code path as the RAG query.
D, I = top_k_chunks(query, k=k)

In [11]:
# Row 0 holds the retrieved ids for the single query searched above.
I[0]

array([40692232, 47828934, 13314770])

In [11]:
query = "Why did Abraham Lincoln grow a beard?"

# NOTE(review): both comparison calls below are commented out, so re-running
# this cell prints nothing; the saved output was presumably produced while
# they were enabled. The RAG call also needs `chunks`, which is not defined in
# the visible cells.

# print("Mistral-7B-Instruct-v0.3")
# run_query_streamed(query, model)
# print()

# print("Mistral-7B-Instruct-v0.3 + RAG")
# run_rag_query_streamed(query, model)
# print()

Mistral-7B-Instruct-v0.3
Abraham Lincoln grew a beard primarily for practical reasons. In the mid-19th century, beards were common among men, but Lincoln did not have a beard during his first term as a U.S. Representative. After losing the Senate race in 1858, he grew a beard as a way to differentiate himself from his opponents in his presidential campaign in 1860. The beard also helped to hide a condition he had called "tuberculosis laryngitis," which caused his voice to be weak and hoarse. Additionally, the beard may have helped to make him appear more mature and serious, which could have been beneficial in a time of national crisis.

Mistral-7B-Instruct-v0.3 + RAG
Abraham Lincoln grew a beard in 1860, at the suggestion of 11-year-old Grace Bedell, who wrote him a letter advising that a beard would improve his appearance. He was the first of five presidents to do so. This decision was not only influenced by the advice of a young girl but also as a political strategy, as it helped to 

## Scratchpad

In [26]:
# Scratch check: capture the class of the object the tokenizer returns.
# NOTE(review): despite its name, `inputs` holds a type here, not tokenizer output.
inputs = type(tokenizer("Hello", return_tensors="pt"))

In [27]:
# Display whatever `inputs` currently refers to.
inputs

transformers.tokenization_utils_base.BatchEncoding