## Load Mistral and HOTPOT Dataset 🍲

In [None]:
!pip install faiss-cpu==1.7.4 mistralai==0.0.12
!pip install -q datasets

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from mistralai.client import MistralClient, ChatMessage
from IPython.display import clear_output
from datasets import load_dataset
from getpass import getpass
from tqdm import tqdm
import numpy as np
import requests
import faiss
import torch
import time
import os

# Build the Mistral API client. The key is intentionally blank here and has to
# be filled in before any API call is made.
client = MistralClient(api_key="") # TODO: Paste API key here
# Load the "distractor" configuration of the HotpotQA dataset.
dataset = load_dataset("hotpot_qa", "distractor")
# Wipe everything this cell has printed so far.
clear_output()

## Helper Functions 🤝

In [None]:
# Name of the chat model passed to every `client.chat` call in this notebook.
model = 'open-mistral-7b'

def get_text_embedding(input):
    """Embed a piece of text with the ``mistral-embed`` model.

    Args:
        input: Value forwarded unchanged as the ``input`` of the embeddings
            request.

    Returns:
        The ``embedding`` of the first item in the response's ``data``.
    """
    response = client.embeddings(model="mistral-embed", input=input)
    first_item = response.data[0]
    return first_item.embedding

def chunk_context(context):
    """Turn a context record into one text chunk per titled paragraph.

    Args:
        context: Mapping with parallel ``"title"`` and ``"sentences"``
            sequences; ``sentences[i]`` holds the sentence strings that
            belong to ``title[i]``.

    Returns:
        A list of strings, one per title, each formatted as a ``# title``
        header line followed by that paragraph's sentences concatenated
        without a separator.
    """
    headings = context["title"]
    paragraphs = context["sentences"]

    formatted = []
    for position, heading in enumerate(headings):
        body = "".join(paragraphs[position])
        formatted.append("# {}\n{}".format(heading, body))
    return formatted

def make_vector_database(chunks):
    """Build a flat L2 FAISS index holding one embedding per chunk.

    Args:
        chunks: Iterable of chunk texts; each is embedded with a separate
            ``get_text_embedding`` call.

    Returns:
        A ``faiss.IndexFlatL2`` to which the chunk embeddings were added in
        the order of ``chunks``.
    """
    # Stack the per-chunk embeddings into a single array.
    vectors = np.array(list(map(get_text_embedding, chunks)))

    # The index dimensionality comes from the embedding width.
    database = faiss.IndexFlatL2(vectors.shape[1])
    database.add(vectors)
    return database

def retrieve_relevant_chunks(chunks, index, question, k=2):
    """Return the chunks whose embeddings are nearest to the question.

    Args:
        chunks: Chunk texts, in the same order their embeddings were added
            to ``index``.
        index: FAISS index to search.
        question: Query; converted with ``str()`` before being embedded.
        k: Number of neighbours to request from the index.

    Returns:
        A list of at most ``k`` chunks in the order the index returned them.
        Fewer than ``k`` are returned when the index cannot supply ``k``
        neighbours.
    """
    question_embeddings = np.array([get_text_embedding(str(question))])
    _, indices = index.search(question_embeddings, k)
    # FAISS fills missing neighbours with -1. Used as a list index that would
    # silently return the last chunk, so those placeholders are dropped.
    return [chunks[i] for i in indices[0].tolist() if i >= 0]

def run_mistral(user_message):
    """Send one user message to the chat model and return the reply text.

    Args:
        user_message: Content of the single ``user`` message in the request.

    Returns:
        The message content of the first choice in the chat response.
    """
    conversation = [ChatMessage(role="user", content=user_message)]
    reply = client.chat(model=model, messages=conversation)
    first_choice = reply.choices[0]
    return first_choice.message.content


In [None]:
# Filler paragraph (lorem ipsum) used as a stand-in for unrelated content.
junk_context = """
Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt ut laoreet dolore magna aliquam erat volutpat. Ut wisi enim ad minim veniam, quis nostrud exerci tation ullamcorper suscipit lobortis nisl ut aliquip ex ea commodo consequat. Duis autem vel eum iriure dolor in hendrerit in vulputate velit esse molestie consequat, vel illum dolore eu feugiat nulla facilisis at vero eros et accumsan.
"""

# 500 identical copies of the filler paragraph, later prepended to real chunks
# to inflate the vector index.
irrelevant_chunks = [junk_context] * 500

In [None]:
irrelevant_chunks

['\nLorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt ut laoreet dolore magna aliquam erat volutpat. Ut wisi enim ad minim veniam, quis nostrud exerci tation ullamcorper suscipit lobortis nisl ut aliquip ex ea commodo consequat. Duis autem vel eum iriure dolor in hendrerit in vulputate velit esse molestie consequat, vel illum dolore eu feugiat nulla facilisis at vero eros et accumsan.\n',
 '\nLorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt ut laoreet dolore magna aliquam erat volutpat. Ut wisi enim ad minim veniam, quis nostrud exerci tation ullamcorper suscipit lobortis nisl ut aliquip ex ea commodo consequat. Duis autem vel eum iriure dolor in hendrerit in vulputate velit esse molestie consequat, vel illum dolore eu feugiat nulla facilisis at vero eros et accumsan.\n',
 '\nLorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt ut laoreet dolore 

## Define RAG Chain  🔗

In [None]:
def no_context_inference(input):
    """Ask the model the question directly, without any context.

    Args:
        input: Mapping with a ``"question"`` entry.

    Returns:
        The reply text from ``run_mistral``.
    """
    template = "Answer the following question with a one-word response. Question: {}"
    return run_mistral(template.format(input["question"]))

def full_context_inference(input):
    """Answer the question with every context paragraph placed in the prompt.

    Args:
        input: Mapping with ``"context"`` and ``"question"`` entries.

    Returns:
        The reply text from ``run_mistral``.
    """
    chunks = chunk_context(input["context"])
    # Join the chunks into the prompt text. Calling `"".format(chunks)` here
    # would always evaluate to the empty string and hide the context from the
    # model.
    context = "\n".join(chunks)
    prompt = "Answer the following question with a one-word response based only on the provided context. Context: {} Question: {}".format(context, input["question"])
    return run_mistral(prompt)

def single_hop_rag(input):
    """Answer the question from the chunks retrieved in a single lookup.

    Args:
        input: Mapping with ``"context"`` and ``"question"`` entries.

    Returns:
        ``(answer, index_time)`` where ``answer`` is the model's reply text and
        ``index_time`` is the wall-clock seconds spent in
        ``make_vector_database``.
    """
    passages = chunk_context(input["context"])

    # Only the index construction is timed.
    build_started = time.time()
    index = make_vector_database(passages)
    build_finished = time.time()
    index_time = build_finished - build_started

    retrieved = retrieve_relevant_chunks(passages, index, input["question"])
    template = "Answer the following question with a one-word response based only on the provided context. Context: {} Question: {}"
    answer = run_mistral(template.format(retrieved, input["question"]))
    return answer, index_time

# Tests scaling by adding irrelevant context to the vector database
def single_hop_with_irrelevant(index, input):
    """Run single-hop RAG against a pre-built index padded with filler chunks.

    Args:
        index: FAISS index expected to have been built from
            ``irrelevant_chunks + chunk_context(input["context"])``, in that
            order, so that row numbers line up with the list searched here.
        input: Mapping with ``"context"`` and ``"question"`` entries.

    Returns:
        The reply text from ``run_mistral``.
    """
    # Derive the real chunks from the example itself rather than from a
    # notebook-level variable, so the lookup list always matches `input`.
    chunks = chunk_context(input["context"])
    context = retrieve_relevant_chunks(irrelevant_chunks + chunks, index, input["question"])
    prompt = "Answer the following question with a one-word response based only on the provided context. Context: {} Question: {}".format(context, input["question"])
    return run_mistral(prompt)


### Check that Multi-hop Works on a Single Question ✅
`multi_hop_rag` returns a tuple of (answer, time spent creating the index).

In [None]:
# Few-shot instructions for the multi-hop chain: reason one sentence at a time
# and finish with `the answer is:` once the answer is known.
# NOTE: this text is inserted verbatim into the prompt by apply_system_prompt,
# so its exact wording (typos included) is part of runtime behaviour.
sys_prompt = """Answer the following question based only on the provided context by explaining each step of your reasoning in a seperate sentence. If you are able to find the answer, end with `the answer is:` followed by the answer. If not, answer the sub-question.

Example 1:```
Question: What is color fur does Jasper's sister have?
Context: # Jasper - Jasper is a dog who has a sister named Lily. # Lily - Lily has purple fur.
Response: Jasper's sister is Lily. Lily has purple fur. Jasper's sister therefore has purple fur. The answer is: purple.
```

Example 2:```
Question: What city is the director of the film `Two Great Houses` from?
Context: # Two Great Houses - Two Great Houses is a film by directory Andy Wharl that was made in 1975. # Something irrelveant - irrelevant context.
Response: The director of `Two Great Houses` is Andy Wharl.
```

Example 3:```
Question: What is the favorite color of the women who invented `Troopers`?
Context: # Troopers - Troopers is a clothing brand invented by Sierra Wiltons. # Sierra Wiltons - Sierra Wiltons's favorite color is orange.
Response: The women who invented `Troopers` is Sierra Wiltons. Sierra Wilton's favorite color is orange. The answer is: orange.
```
"""
def apply_system_prompt(prompt):
    """Wrap ``prompt`` in a system/user/assistant transcript.

    The result is the module-level ``sys_prompt`` as a ``system`` turn, then
    ``prompt`` as a ``user`` turn, then an opened ``assistant`` turn, each
    delimited by ``<|im_start|>`` / ``<|im_end|>`` markers.

    Args:
        prompt: Text of the user turn.

    Returns:
        The assembled transcript as a single string.
    """
    open_tag = "<|im_start|>"
    close_tag = "<|im_end|>\n"
    sections = [
        open_tag, "system\n", sys_prompt, close_tag,
        open_tag, "user\n", prompt, close_tag,
        open_tag, "assistant\n",
    ]
    return "".join(sections)

def multi_hop_rag(input, max_hops=2, verbose=False):
    """Answer a question by alternating model reasoning with retrieval.

    Each hop asks the model to reason over the context gathered so far. The
    first sentence of the reply is kept as the next reasoning step and is also
    used as a retrieval query to pull in more context for the following hop.

    Args:
        input: Mapping with ``"context"`` and ``"question"`` entries, plus
            ``"answer"`` when ``verbose`` is true.
        max_hops: Maximum number of model calls.
        verbose: When true, print the question, target, prompts and
            intermediate reasoning.

    Returns:
        ``(answer, index_time)``. ``answer`` is the text following the last
        ``"answer is: "`` in the first reply that contains it, or ``""`` when
        no hop produced one. ``index_time`` is the wall-clock seconds spent
        chunking the context and building the vector index.
    """
    if verbose:
        print("Question:", input["question"])
        print("Target:", input["answer"])
        print("\n")

    # Obtain relevant context
    t0 = time.time()
    chunks = chunk_context(input["context"])
    index = make_vector_database(chunks)
    t1 = time.time()
    index_time = t1 - t0
    context = set(retrieve_relevant_chunks(chunks, index, input["question"]))
    reasoning = ""

    for hop in range(max_hops):
        if verbose:
            print("Hop: {}, context len: {}".format(hop, len(context)))
        prompt = f"""
        Context: {context}
        Question: {input["question"]}
        {reasoning}
        """
        prompt = apply_system_prompt(prompt)
        chat_response = client.chat(model=model, messages=[ChatMessage(role="user", content=prompt)])

        # Only the first sentence of the reply is kept as this hop's step.
        reasoning_steps = chat_response.choices[0].message.content
        curr_step = reasoning_steps.split(".")[0].strip()

        if verbose:
            print("PROMPT", prompt)
            print("REASONING_STEPS:", reasoning_steps)
            print("CURR_STEP", curr_step)

        # Return as soon as the model commits to an answer, before spending
        # another embedding request on context that would never be used.
        if "answer is: " in reasoning_steps:
            return reasoning_steps.split("answer is: ")[-1], index_time

        if curr_step:
            # Restore the period removed by split(".") and add a space so
            # successive steps stay separate sentences in the next prompt.
            reasoning += curr_step + ". "

            # New context is only useful if another hop will read it.
            if hop + 1 < max_hops:
                context.update(retrieve_relevant_chunks(chunks, index, curr_step))

    return "", index_time

# Run the multi-hop chain on the validation example at index 3, printing each hop.
multi_hop_rag(dataset["validation"][3], verbose=True)

### Test Index Scaling 📈
See how adding junk data to the index influences the retrieval time

In [None]:
# Time spent making the index with junk added
# Work on the first 250 training examples and use the first one as the probe.
train_dataset = dataset["train"].select(range(250))
example = train_dataset[0]
t0 = time.time()
chunks = chunk_context(example["context"])
# The filler chunks come first, so the example's real chunks occupy the last
# rows of the index.
big_index = make_vector_database(irrelevant_chunks + chunks)
t1 = time.time()
# Elapsed seconds for chunking plus index construction (last expression of the cell).
t1-t0

In [None]:
# Time spent running single-hop with a larger index
# The index is already built (big_index), so its construction is not part of
# this measurement.
t0 = time.time()
output = single_hop_with_irrelevant(big_index, train_dataset[0])
# Bare expression; it has no effect because it is not the last line of the cell.
output
t1 = time.time()
# Elapsed seconds for retrieval plus the model call.
t1-t0

In [None]:
# Time spent running single-hop with a small index
t0 = time.time()
output, index_time = single_hop_rag(example)
t1 = time.time()
# single_hop_rag builds its own index, so that build time is subtracted to
# leave chunking, retrieval and the model call.
(t1 - t0) - index_time

## Test RAG 🧪

In [None]:
# Evaluation loop: run one inference strategy over a validation subset and
# report substring-match accuracy plus average latency figures.
val_dataset = dataset["validation"].select(range(500)) # Evaluate on subset of validation set

# Positions of the entries answered incorrectly, and (target, output) pairs for
# manual inspection.
incorrect_indicies = []
check_these = []
num_correct = 0
total_latency = 0
total_index_time = 0
total_router_time = 0
# Seconds spent in failed attempts (including the back-off sleep); removed from
# the latency total below.
stop_time = 0
import time
with tqdm(total=len(val_dataset)) as pbar:
    t0 = time.time()
    for i,entry in enumerate(val_dataset):
        # Retry the same entry until an attempt completes without raising.
        # NOTE(review): every Exception is retried with no attempt limit, so a
        # persistent error (e.g. a programming bug) loops forever.
        while True:
            try:
                start = time.time()
                index_time, router_time = 0, 0
                # Exactly one strategy is active; the alternatives are kept
                # commented out for switching by hand.
                output = no_context_inference(entry)
                # output, index_time = single_hop_rag(entry)
                #output, index_time = multi_hop_rag(entry, max_hops=1)
                #output, index_time, router_time = adaptive_rag(entry)
                # An answer counts as correct when the target appears anywhere
                # in the output, ignoring case.
                if entry["answer"].lower() in output.lower():
                    num_correct += 1
                else:
                  incorrect_indicies.append(i)
                  check_these.append((entry["answer"],output.lower()))
                # pbar.n is the number of entries finished before this one.
                accuracy = num_correct / (pbar.n + 1)
                pbar.set_postfix({'accuracy': accuracy})
                pbar.update(1)
                break
            except Exception as e:
                print(e)
                time.sleep(5)
                end = time.time()
                # Charge the whole failed attempt, sleep included, to stop_time.
                stop_time += end-start


        total_index_time += index_time
        total_router_time += router_time
    t1 = time.time()
    # Wall-clock time for the run minus the time lost to failed attempts.
    total_latency += t1-t0 - stop_time
    print(f"\nAverage Latency w/Index {total_latency/len(val_dataset)}")
    print(f"\nAverage Index Time {total_index_time/len(val_dataset)}")
    print(f"\nAverage Router Time {total_router_time/len(val_dataset)}")
    print(f"\nAverage Latency w/o Index {(total_latency - total_index_time)/len(val_dataset)}")

print(f"\nFinal Accuracy: {num_correct/len(val_dataset)}")

# Adaptive RAG
## Create classifier data for training router

In [None]:
# Create labels for training the router: for every training question, record
# the first strategy (tried in the order plain model, 1-hop RAG, 3-hop RAG)
# whose output contains the target answer.

num_correct = 0
# One label per question, in dataset order. This list has to exist before the
# loop: without it the first `labels.append` raises NameError, which the retry
# handler below would swallow and repeat forever.
labels = []
train_dataset = dataset["train"].select(range(1000))
import time
total_latency = 0
total_index_time = 0
with tqdm(total=len(train_dataset)) as pbar:
    t0 = time.time()
    for i, entry in enumerate(train_dataset):
        # Retry the same entry until one attempt runs to a `break`.
        while True:
            try:
                begin = time.time()
                # Strategy 1: the model on its own, given only the question.
                output = run_mistral(entry["question"])
                if entry["answer"].lower() in output.lower():
                    num_correct += 1
                    labels.append("mistral")
                    break

                # Strategy 2: multi-hop RAG limited to a single hop.
                output, _ = multi_hop_rag(entry, max_hops=1)
                if entry["answer"].lower() in output.lower():
                    num_correct += 1
                    labels.append("rag")
                    break

                # Strategy 3: multi-hop RAG with up to three hops; anything
                # still unanswered is marked "no-label".
                output, _ = multi_hop_rag(entry, max_hops=3)
                if entry["answer"].lower() in output.lower():
                    num_correct += 1
                    labels.append("3hop-rag")
                else:
                    labels.append("no-label")
                break
            except Exception as e:
                # Most likely a transient API failure: report it, back off,
                # and try the entry again.
                print(e)
                time.sleep(5)
                end = time.time()


        accuracy = num_correct / (pbar.n + 1)
        pbar.set_postfix({'accuracy': accuracy})
        pbar.update(1)

print(f"\nFinal Accuracy: {num_correct/len(train_dataset)}")

In [None]:
class QuestionClassifierBERT:
    '''
    Custom Model Finetuned on HotPotQA + Custom data

    Wraps a sequence-classification checkpoint and maps its predicted class
    index to a difficulty label.
    '''
    def __init__(self, path="scott-routledge/bert-hotpotqa-classifier-2", classes=("Easy", "Hard")):
        """Load the tokenizer and classifier found at ``path``.

        Args:
            path: Identifier or directory passed to ``from_pretrained`` for
                both the tokenizer and the model.
            classes: Label names indexed by the model's predicted class id.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        # Store a private copy so instances never share one list object, as
        # they would with a mutable list default.
        self.classes = list(classes)


    def predict_difficulty(self, texts):
        """Return the difficulty label predicted for the first text.

        Args:
            texts: List of question strings. All of them are lower-cased,
                tokenized and run through the model, but only the prediction
                for ``texts[0]`` is returned.

        Returns:
            The entry of ``self.classes`` at the arg-max class index.
        """
        encoded_texts = self.tokenizer(
            [text.lower() for text in texts],
            padding='max_length',
            truncation=True,
            max_length=128,
            return_tensors="pt"
            )
        # Follow the model to whichever device it lives on instead of assuming
        # a GPU, so the classifier also works when the model is on the CPU.
        device = self.model.device
        encoded_texts = {k: v.to(device) for k, v in encoded_texts.items()}
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            output = self.model(**encoded_texts)
        prediction = torch.argmax(output.logits, dim=1)
        return self.classes[prediction[0].item()]

In [None]:
# Build the question router with its default checkpoint path and class labels.
router = QuestionClassifierBERT()

In [None]:
# Move the router's model to the first CUDA device (requires a GPU runtime).
router.model.to("cuda:0")

In [None]:
# Per-route tally. adaptive_rag does not update it as written.
model_counts = {"mistral": 0, "rag": 0}
def adaptive_rag(entry):
    """Route a question to a strategy based on its predicted difficulty.

    "Easy" questions are answered by the model alone; "Hard" questions go
    through single-hop RAG.

    Args:
        entry: Mapping with a ``"question"`` entry, plus ``"context"`` for the
            RAG route.

    Returns:
        ``(output, index_time, router_time)``. ``index_time`` is 0 on the
        no-context route. ``router_time`` is the wall-clock seconds spent in
        ``router.predict_difficulty``.

    Raises:
        ValueError: If the router returns a label other than "Easy" or "Hard".
    """
    t0 = time.time()
    difficulty = router.predict_difficulty([entry["question"]])
    t1 = time.time()
    router_time = t1 - t0

    if difficulty == "Easy":
        return no_context_inference(entry), 0, router_time
    if difficulty == "Hard":
        output, index_time = single_hop_rag(entry)
        return output, index_time, router_time
    # Fail with a clear message rather than an UnboundLocalError on `output`
    # when the router is configured with other class names.
    raise ValueError(f"Unexpected difficulty label from router: {difficulty!r}")

