# Getting Started with RAG in DSPy

This notebook will show you how to use DSPy to compile a RAG program! DSPy compilation is a fairly new tool for LLM developers, so let's start with an overview of the concept. By `compiling`, we mean finding the prompts that elicit the behavior we want from LLMs when connected in some kind of pipeline.

For example, RAG is a very common LLM pipeline. In its simplest form, RAG consists of two steps: (1) Retrieve and (2) Answer a Question. Step (2), Answering a Question, has an associated prompt. For example, people generally use:

```
--

Please answer the question based on the following context.

context  {context}

question {question}

--
```
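
As a minimal sketch of how a template like this gets filled in before it is sent to the model, the snippet below uses plain Python string formatting. The helper name `build_prompt` and the sample passage are illustrative assumptions, not part of DSPy.

```python
# Hypothetical helper: fills the generic RAG prompt with retrieved passages and a question.
PROMPT_TEMPLATE = (
    "Please answer the question based on the following context.\n\n"
    "context  {context}\n\n"
    "question {question}\n"
)

def build_prompt(passages, question):
    # Join the retrieved passages into one context string, one passage per line.
    context = "\n".join(passages)
    return PROMPT_TEMPLATE.format(context=context, question=question)

prompt = build_prompt(
    ["Cross encoders score a query and a document together."],  # made-up passage
    "What do cross encoders do?",
)
print(prompt)
```

The instruction line at the top of the template is the part an optimizer can rewrite, while the `{context}` and `{question}` slots stay fixed.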

This prompt is a reasonable starting point for conveying the task to an LLM. However, it is rarely the *optimal* prompt. DSPy optimizes the prompt for you by jointly (1) tweaking the instructions, for example rewriting an initial prompt like:

```
Please answer the question based on the following context.
```

to 

```
Assess the context and answer the given questions that are predominantly about software usage, process optimization, and troubleshooting. Focus on providing accurate information related to tech or software-related queries.
```

Further, DSPy (2) adds examples of desired input-output behavior to the prompt to further improve performance, a technique known as `In-Context Learning`. In this example, we will begin with the simple prompt: `Please answer the question based on the following context.` and end up with:

```

```

In order to leverage black-box optimization techniques like random search, Bayesian optimization, or evolutionary algorithms, we need a metric. Coming up with metrics to describe desired system behavior has been a longstanding challenge in Machine Learning research. LLMs make this much easier, because they can act as judges of free-text output. For example, we can evaluate a RAG answer by prompting an LLM with, `Is the assessed text grounded in the context? Say no if it includes significant facts not in the context`. We then optimize the RAG program to increase the metric LLM's assessment of answer quality.
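
To make the role of the metric concrete, here is a small sketch of random search over candidate instructions. Everything in it is an assumption for illustration: the candidate list is hand-written, and `toy_metric` is a keyword counter standing in for the LLM judge. It is not how DSPy's optimizers are implemented.

```python
import random

# Hypothetical candidate instructions; in practice an optimizer would propose these.
CANDIDATES = [
    "Please answer the question based on the following context.",
    "Answer the question using only facts found in the context.",
    "Assess the context and answer the question accurately.",
]

def toy_metric(instruction):
    # Stand-in for an LLM judge: counts how many grounding-related words appear.
    keywords = ("context", "facts", "accurately")
    return sum(word in instruction for word in keywords)

def random_search(candidates, metric, trials, seed=0):
    # Sample candidates at random and keep the best-scoring one seen so far.
    rng = random.Random(seed)
    best, best_score = None, float("-inf")
    for _ in range(trials):
        candidate = rng.choice(candidates)
        score = metric(candidate)
        if score > best_score:
            best, best_score = candidate, score
    return best, best_score

best_instruction, best_score = random_search(CANDIDATES, toy_metric, trials=10)
print(best_instruction, best_score)
```

The search only ever calls `metric(candidate)`, which is why any black-box optimizer can be swapped in once a metric exists.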

This example contains 5 parts:

- 0: DSPy Settings and Installation
- 1: DSPy Datasets with `dspy.Example`
- 2: LLM Metrics in DSPy
- 3: LLM Programming with `dspy.Module`
- 4: Optimization with `BootstrapFewShot`, `BootstrapFewShotWithRandomSearch`, and `BayesianSignatureOptimizer`.


We are using 2 datasets for this example. First, we have an index of the Weaviate Blog Posts. We will use the Weaviate Blog Posts as the retrieved context to help with our second dataset, the Weaviate FAQs. The Weaviate FAQs consist of 44 question-answer pairs of frequently asked Weaviate questions such as: `Do I need to know about Docker (Compose) to use Weaviate?`

We use 20 examples for training and 10 for development, and hold out the remaining 14 as our test set.

Our uncompiled RAG program achieves a score of 270 on the held-out test set.

Our RAG program compiled with the `BayesianSignatureOptimizer` achieves a score of 340! That is a ~26% improvement (340 / 270 ≈ 1.26)!

# 0. DSPy Settings and Installation

In [None]:
# !pip install dspy-ai==2.1.9 weaviate-client==3.26.2 > /dev/null

In [None]:
import openai

In [None]:
# Connect to Weaviate Retriever and configure LLM
import dspy
from dspy.retrieve.weaviate_rm import WeaviateRM
import weaviate
import openai


llm = dspy.OpenAI(model="gpt-3.5-turbo")

# ollamaLLM = dspy.OpenAI(api_base="http://localhost:11434/v1/", api_key="ollama", model="mistral-7b-instruct-v0.2-q6_K", stop='\n\n', model_type='chat')
# Thanks Knox! https://twitter.com/JR_Knox1977/status/1756018720818794916/photo/1

weaviate_client = weaviate.Client("http://localhost:8080")
retriever_model = WeaviateRM("WeaviateBlogChunk", weaviate_client=weaviate_client)
# Assumes the Weaviate collection has a text key `content`
dspy.settings.configure(lm=llm, rm=retriever_model)

In [None]:
print(dspy.settings.lm("Write a 3 line poem about neural networks."))
context_example = dspy.OpenAI(model="gpt-4")

# Temporarily switch the default LM to GPT-4 inside this block
with dspy.context(lm=context_example):
    print(dspy.settings.lm("Write a 3 line poem about neural networks."))

# 1. DSPy Datasets with `dspy.Example`

Our retrieval engine is filled with chunks from Weaviate Blog posts.

Please see weaviate/recipes/integrations/dspy/Weaviate-Import.ipynb for a full tutorial.

# Import FAQs from a markdown file

In [None]:
# Load FAQs
import re

# Use a context manager so the file is closed after reading
with open("faq.md") as f:
    markdown_content = f.read()

def parse_questions(markdown_content):
    # Regular expression pattern for finding questions
    question_pattern = r'#### Q: (.+?)\n'

    # Finding all questions
    questions = re.findall(question_pattern, markdown_content, re.DOTALL)

    return questions

# Parsing the markdown content to get only questions
questions = parse_questions(markdown_content)

# Displaying the first few extracted questions
questions[:5]  # Displaying only the first few for brevity

In [None]:
len(questions)
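
To see what the pattern captures without needing `faq.md`, here is a small sketch run on an inline string. The sample text, including the `A:` answer lines, is made up for illustration; only the `#### Q: ` prefix is taken from the parser above.

```python
import re

# Made-up FAQ snippet; the "A:" answer layout is an assumption for illustration.
sample_markdown = (
    "#### Q: Do I need Docker to use Weaviate?\n"
    "A: Illustrative answer text.\n\n"
    "#### Q: What is a vector index?\n"
    "A: Another illustrative answer.\n"
)

# Same pattern as above: capture everything after "#### Q: " up to the end of the line.
sample_questions = re.findall(r"#### Q: (.+?)\n", sample_markdown)
print(sample_questions)
# → ['Do I need Docker to use Weaviate?', 'What is a vector index?']
```

Because the lazy `(.+?)` stops at the first newline, only the question line is captured and the answer text is ignored.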

# Wrap each FAQ into an `Example` object

The DSPy `Example` object optionally lets you attach metadata, or additional labels, to input/output pairs.

For example, you may want to jointly supervise the answer as well as the context the retrieval system produced to feed into the answer generator.

In [None]:
# Load into dspy datasets
import dspy

# ToDo, add random splitting -- maybe wrap this entire thing in a cross-validation loop
trainset = questions[:20] # 20 examples for training
devset = questions[20:30] # 10 examples for development
testset = questions[30:] # 14 examples for testing

trainset = [dspy.Example(question=question).with_inputs("question") for question in trainset]
devset = [dspy.Example(question=question).with_inputs("question") for question in devset]
testset = [dspy.Example(question=question).with_inputs("question") for question in testset]

In [None]:
devset[0]
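
The split above is positional, so it depends on the order of questions in the file. One possible way to address the random-splitting ToDo is sketched below; the helper name `random_split` and the fixed seed are assumptions, and the placeholder strings stand in for the parsed questions.

```python
import random

def random_split(items, n_train, n_dev, seed=0):
    # Shuffle a copy so the caller's list is left untouched, then slice it.
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    train = shuffled[:n_train]
    dev = shuffled[n_train:n_train + n_dev]
    test = shuffled[n_train + n_dev:]
    return train, dev, test

# Placeholder questions standing in for the 44 parsed FAQs.
toy_questions = [f"question {i}" for i in range(44)]
train_q, dev_q, test_q = random_split(toy_questions, n_train=20, n_dev=10)
print(len(train_q), len(dev_q), len(test_q))
# → 20 10 14
```

Fixing the seed keeps the split reproducible across runs, which matters when comparing optimizers against each other.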

# 2. LLM Metrics

Define a Metric for Performance.

In [None]:
# This is a WIP, the next step is to optimize this metric as itself a DSPy module (pretty meta)

# Reference - https://github.com/stanfordnlp/dspy/blob/main/examples/tweets/tweet_metric.py

metricLM = dspy.OpenAI(model='gpt-4', max_tokens=1000, model_type='chat')

# Signature for LLM assessments.

class Assess(dspy.Signature):
    """Assess the quality of an answer to a question."""
    
    context = dspy.InputField(desc="The context for answering the question.")
    assessed_question = dspy.InputField(desc="The evaluation criterion.")
    assessed_answer = dspy.InputField(desc="The answer to the question.")
    assessment_answer = dspy.OutputField(desc="A rating between 1 and 5. Only output the rating and nothing else.")

def llm_metric(gold, pred, trace=None):
    predicted_answer = pred.answer
    question = gold.question
    
    print(f"Test Question: {question}")
    print(f"Predicted Answer: {predicted_answer}")
    
    detail = "Is the assessed answer detailed?"
    faithful = "Is the assessed text grounded in the context? Say no if it includes significant facts not in the context."
    overall = f"Please rate how well this answer answers the question, `{question}` based on the context.\n `{predicted_answer}`"
    
    with dspy.context(lm=metricLM):
        context = dspy.Retrieve(k=5)(question).passages
        detail = dspy.ChainOfThought(Assess)(context="N/A", assessed_question=detail, assessed_answer=predicted_answer)
        faithful = dspy.ChainOfThought(Assess)(context=context, assessed_question=faithful, assessed_answer=predicted_answer)
        overall = dspy.ChainOfThought(Assess)(context=context, assessed_question=overall, assessed_answer=predicted_answer)
    
    print(f"Faithful: {faithful.assessment_answer}")
    print(f"Detail: {detail.assessment_answer}")
    print(f"Overall: {overall.assessment_answer}")
    
    
    total = float(detail.assessment_answer) + float(faithful.assessment_answer)*2 + float(overall.assessment_answer)
    
    return total / 5.0
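
The `float(...)` calls above assume the judge returns a bare number. As a hedged sketch of a more forgiving variant, the helpers below extract the first number from the reply and apply the same weighting. The names `parse_rating` and `aggregate`, the fallback of 1.0, and the clamping to the 1-5 range are all assumptions, not part of the metric above.

```python
import re

def parse_rating(text, default=1.0):
    # Pull the first number out of the judge's reply; fall back to `default` if none is found.
    match = re.search(r"\d+(?:\.\d+)?", text)
    if match is None:
        return default
    # Clamp to the 1-5 scale requested in the signature.
    return min(5.0, max(1.0, float(match.group())))

def aggregate(detail, faithful, overall):
    # Same weighting as llm_metric: faithfulness counts double, the sum is divided by 5.
    return (parse_rating(detail) + 2 * parse_rating(faithful) + parse_rating(overall)) / 5.0

print(aggregate("4", "Rating: 5", "3"))
```

With ratings of 4, 5, and 3, the weighted sum is 4 + 10 + 3 = 17, so the example evaluates to 17 / 5.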

## Inspect the metric

In [None]:
test_example = dspy.Example(question="What do cross encoders do?")
test_pred = dspy.Example(answer="They re-rank documents.")

type(llm_metric(test_example, test_pred))

In [None]:
test_example = dspy.Example(question="What do cross encoders do?")
test_pred = dspy.Example(answer="They index data.")

type(llm_metric(test_example, test_pred))

In [None]:
metricLM.inspect_history(n=3)

# 3. The DSPy Programming Model

This first block of code will initialize the `GenerateAnswer` signature.

Then we will compose a `dspy.Module` consisting of:
- Retrieve
- GenerateAnswer

The DSPy programming model is one of the most powerful aspects of DSPy. We get:
- An intuitive interface to compose prompts into programs.
- A clean way to organize prompts into Signatures.
- Structured output parsing with `dspy.OutputField`
- Built-in prompt extensions such as `ChainOfThought`, `ReAct`, and more!

In [None]:
class GenerateAnswer(dspy.Signature):
    """Answer questions based on the context."""
    
    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField()

In [None]:
class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()
        
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question):
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(answer=prediction.answer)
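
The retrieve-then-answer control flow in `forward` can be sketched without any LLM or vector database. In the toy version below, the corpus text, the word-overlap retriever, and the echoing generator are all stand-ins invented for illustration. They only mirror the shape of the `RAG` module, not its behavior.

```python
# Toy corpus standing in for the Weaviate blog chunks (made-up text).
CORPUS = [
    "Cross encoders score a query and a document together to re-rank results.",
    "Docker Compose is one way to run Weaviate locally.",
    "Vector indexes speed up nearest-neighbor search.",
]

def toy_retrieve(question, k=3):
    # Rank passages by how many lowercase words they share with the question.
    question_words = set(question.lower().replace("?", "").split())
    ranked = sorted(
        CORPUS,
        key=lambda passage: len(question_words & set(passage.lower().split())),
        reverse=True,
    )
    return ranked[:k]

def toy_generate(context, question):
    # Stand-in for the LLM call: echoes the top passage instead of generating text.
    return context[0] if context else "I don't know."

def toy_rag(question, num_passages=3):
    # Step 1: retrieve. Step 2: answer using the retrieved context.
    context = toy_retrieve(question, k=num_passages)
    return toy_generate(context, question)

print(toy_rag("What do cross encoders do?"))
```

Swapping `toy_retrieve` for `dspy.Retrieve` and `toy_generate` for a `dspy.ChainOfThought` call gives back the structure of the module above.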

# A little more info on built-in dspy modules

The DSPy programming model gives you a lot of cool features out of the box. Observe how different modules implement signatures with additional prompting techniques like `ChainOfThought` and `ReAct`. `Predict` is the base module, so use it to see what a standard prompt looks like without the module extensions.

### dspy.Predict

In [None]:
dspy.Predict(GenerateAnswer)(question="What are Cross Encoders?")
llm.inspect_history(n=1)

### dspy.ChainOfThought

In [None]:
dspy.ChainOfThought(GenerateAnswer)(question="What are Cross Encoders?")
llm.inspect_history(n=1)

### dspy.ReAct

In [None]:
dspy.ReAct(GenerateAnswer, tools=[dspy.settings.rm])(question="What are cross encoders?")
llm.inspect_history(n=1)

# Initialize DSPy Program

In [None]:
uncompiled_rag = RAG()

# Test uncompiled inference 

In [None]:
print(uncompiled_rag("What are re-rankers in search engines?").answer)

# Check the last call to the LLM

In [None]:
llm.inspect_history(n=1)

# 4. DSPy Optimization

# Evaluate our RAG Program before it is compiled

In [None]:
# Reminder our dataset looks like this:

devset[0]

In [None]:
from dspy.evaluate.evaluate import Evaluate

evaluate = Evaluate(devset=devset, num_threads=1, display_progress=True, display_table=5)

evaluate(RAG(), metric=llm_metric)

# Metric Analysis

The maximum score per example is (5 + 5*2 + 5) / 5 = 4.

Summed over the 10 development questions, the maximum is 4 * 10 = 40.
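
The same arithmetic, written out as a short sketch. The weights mirror `llm_metric` (faithfulness counted twice); the constant names are illustrative, and the lower bound assumes the judge always respects the 1 to 5 scale.

```python
MAX_RATING = 5  # each judged criterion is rated from 1 to 5
WEIGHTS = {"detail": 1, "faithful": 2, "overall": 1}

# Best possible per-question score under the weighting used by llm_metric.
max_per_question = sum(MAX_RATING * w for w in WEIGHTS.values()) / 5.0
# Worst possible per-question score (every rating is 1).
min_per_question = sum(1 * w for w in WEIGHTS.values()) / 5.0

num_questions = 10
print(max_per_question, min_per_question, max_per_question * num_questions)
# → 4.0 0.8 40.0
```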

In [None]:
llm.inspect_history(n=1)

In [None]:
metricLM.inspect_history(n=3)

# BootstrapFewShot

In [None]:
from dspy.teleprompt import BootstrapFewShot

teleprompter = BootstrapFewShot(metric=llm_metric, max_labeled_demos=8, max_rounds=3)

# It is also common to initialize the program here, e.g. RAG()
compiled_rag = teleprompter.compile(uncompiled_rag, trainset=trainset)
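
Roughly, the idea behind bootstrapping few-shot demonstrations is to run the program on training inputs and keep the input-output pairs that the metric accepts. The sketch below is a simplified illustration of that idea only, not DSPy's actual implementation. The function `bootstrap_demos`, the `threshold` parameter, and both stubs are hypothetical.

```python
def bootstrap_demos(program, metric, trainset, max_demos=4, threshold=0.5):
    # Run the program on each training question and keep the outputs the metric accepts.
    demos = []
    for question in trainset:
        answer = program(question)
        if metric(question, answer) >= threshold:
            demos.append({"question": question, "answer": answer})
        if len(demos) == max_demos:
            break
    return demos

# Stubs standing in for the RAG program and the LLM judge.
def stub_program(question):
    # Pretend the program fails on "hard" questions by returning an empty answer.
    return "" if "hard" in question else f"Answer to: {question}"

def stub_metric(question, answer):
    # Accept any non-empty answer.
    return 1.0 if answer else 0.0

demos = bootstrap_demos(stub_program, stub_metric, ["easy q1", "hard q2", "easy q3"])
print(demos)
```

The accepted pairs would then be placed in the prompt as in-context examples, which is why the quality of the metric directly shapes the compiled prompt.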

### Inspect the compiled prompt

In [None]:
compiled_rag("What do cross encoders do?").answer

In [None]:
llm.inspect_history(n=1)

### Evaluate the Compiled RAG Program

In [None]:
evaluate(compiled_rag, metric=llm_metric)

# BootstrapFewShotWithRandomSearch

In [None]:
# Accidentally spent $12 on this with `num_candidate_programs=20`, caution!

In [None]:
from dspy.teleprompt import BootstrapFewShotWithRandomSearch

teleprompter = BootstrapFewShotWithRandomSearch(metric=llm_metric, 
                                                max_bootstrapped_demos=4,
                                                max_labeled_demos=4, 
                                                max_rounds=1,
                                                num_candidate_programs=2,
                                                num_threads=2)

# It is also common to initialize the program here, e.g. RAG()
second_compiled_rag = teleprompter.compile(uncompiled_rag, trainset=trainset)

In [None]:
second_compiled_rag("What do cross encoders do?")

In [None]:
llm.inspect_history(n=1)

In [None]:
evaluate(second_compiled_rag, metric=llm_metric)

# BayesianSignatureOptimizer

In [None]:
from dspy.teleprompt import BayesianSignatureOptimizer

llm_prompter = dspy.OpenAI(model='gpt-4', max_tokens=2000, model_type='chat')

teleprompter = BayesianSignatureOptimizer(task_model=dspy.settings.lm,
                                          metric=llm_metric,
                                          prompt_model=llm_prompter,
                                          n=5,
                                          verbose=False)

kwargs = dict(num_threads=1, display_progress=True, display_table=0)
third_compiled_rag = teleprompter.compile(RAG(), devset=devset,
                                         optuna_trials_num=3,
                                         max_bootstrapped_demos=4,
                                         max_labeled_demos=4,
                                         eval_kwargs=kwargs)

In [None]:
third_compiled_rag("What do cross encoders do?")

# Check this out!!

Below you can see how the BayesianSignatureOptimizer jointly (1) optimizes the task instruction to:

```
Assess the context and answer the given questions that are predominantly about software usage, process optimization, and troubleshooting. Focus on providing accurate information related to tech or software-related queries.
```

and (2) sources input-output examples for the prompt!

In [None]:
llm.inspect_history(n=1)

In [None]:
evaluate(third_compiled_rag, metric=llm_metric)

# Test Set Eval

In [None]:
# Evaluate Uncompiled
from dspy.evaluate.evaluate import Evaluate

# Set up the evaluator on the held-out test set. We'll reuse it for each program below.
evaluate = Evaluate(devset=testset, num_threads=1, display_progress=True, display_table=5)

In [None]:
evaluate(uncompiled_rag, metric=llm_metric)

In [None]:
evaluate(compiled_rag, metric=llm_metric)

In [None]:
evaluate(second_compiled_rag, metric=llm_metric)

In [None]:
evaluate(third_compiled_rag, metric=llm_metric)