In [1]:
import dspy
from dsp.utils import deduplicate
from dspy.datasets import HotPotQA
from dspy.predict.retry import Retry
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
from dspy.evaluate.evaluate import Evaluate

from dspy.primitives.assertions import assert_transform_module, backtrack_handler

In [2]:
import os
import openai
# Take the OpenAI key from the environment so no credential is stored in the notebook;
# the value is None when OPENAI_API_KEY is not set.
openai.api_key = os.environ.get('OPENAI_API_KEY')

In [3]:
# Build the retriever and the language-model client first, then register both
# (together with the trace and temperature settings) in a single configure call.
colbertv2_wiki17_abstracts = dspy.ColBERTv2(
    url='http://20.102.90.50:2017/wiki17_abstracts',
)
turbo = dspy.OpenAI(model='gpt-4o-mini', max_tokens=500)
dspy.settings.configure(
    rm=colbertv2_wiki17_abstracts,
    lm=turbo,
    trace=[],
    temperature=0.7,
)

In [4]:
# Small HotPotQA sample: 20 training examples, 50 dev examples, no test split.
dataset = HotPotQA(
    train_seed=1,
    train_size=20,
    eval_seed=2023,
    dev_size=50,
    test_size=0,
    keep_details=True,
)
# Mark `question` as the input field on every example.
trainset = [example.with_inputs('question') for example in dataset.train]
devset = [example.with_inputs('question') for example in dataset.dev]

In [5]:
# Suggestion helper functions and Teleprompter metric

def validate_query_distinction_local(previous_queries, query):
    """Report whether `query` is distinct from the queries issued so far.

    Returns True when `previous_queries` is an empty list. Otherwise returns
    the negation of
    `dspy.evaluate.answer_exact_match_str(query, previous_queries, frac=0.8)`.
    """
    if previous_queries == []:
        # Nothing has been issued yet, so there is nothing to collide with.
        return True

    matches_earlier_query = dspy.evaluate.answer_exact_match_str(
        query, previous_queries, frac=0.8
    )
    return not matches_earlier_query


def validate_context_and_answer_and_hops(example, pred, trace=None):
    """Teleprompter metric: accept `pred` only if answer and passages both match.

    `dspy.evaluate.answer_exact_match(example, pred)` is checked first; the
    passage check `dspy.evaluate.answer_passage_match(example, pred)` runs only
    when the answer check succeeds. `trace` is accepted but not used, and no
    property of the individual hops is inspected here.

    Returns:
        True when both checks are truthy, False otherwise.
    """
    answer_ok = dspy.evaluate.answer_exact_match(example, pred)
    if answer_ok and dspy.evaluate.answer_passage_match(example, pred):
        return True
    return False


In [6]:
# Extrinsic metrics

def gold_passages_retrieved(example, pred, trace=None):
    """Return True when every gold title is among the retrieved passage titles.

    The title of a retrieved passage is the text before the first ' | ' in each
    entry of `pred.context`. Gold and retrieved titles are both passed through
    `dspy.evaluate.normalize_text` before comparison. `trace` is not used.
    """
    normalize = dspy.evaluate.normalize_text
    expected_titles = {normalize(title) for title in example['gold_titles']}
    retrieved_titles = {normalize(passage.split(' | ')[0]) for passage in pred.context}
    return retrieved_titles.issuperset(expected_titles)

In [7]:
# signatures of dspy modules
# Signature for the answering step: inputs `context` and `question`, output `answer`.
# NOTE(review): DSPy presumably uses the class docstring and the `desc` strings as
# prompt text, so treat them as runtime strings rather than comments — confirm
# against the DSPy Signature documentation before rewording them.
class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


# Signature for the per-hop query step: inputs `context` and `question`, output `query`.
# NOTE(review): DSPy presumably uses the class docstring and the `desc` string as
# prompt text, so treat them as runtime strings rather than comments — confirm
# against the DSPy Signature documentation before rewording them.
class GenerateSearchQuery(dspy.Signature):
    """Write a simple search query that will help answer a complex question."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    query = dspy.OutputField()

In [8]:
def all_queries_distinct(prev_queries):
    """Return True when no entry of `prev_queries` repeats an earlier entry.

    Each entry is checked with `validate_query_distinction_local` against the
    entries that precede it in the list. Checking stops at the first entry
    that is not distinct.

    Args:
        prev_queries: ordered list of query strings; may be empty.

    Returns:
        True if every entry passes the distinctness check (also for an empty
        list), False otherwise.
    """
    # The helper's result is used by truthiness rather than compared with
    # `== False`; it returns plain booleans, so the outcome is unchanged.
    return all(
        validate_query_distinction_local(prev_queries[:position], candidate)
        for position, candidate in enumerate(prev_queries)
    )

In [9]:
class SimplifiedBaleen(dspy.Module):
    """Multi-hop question answering pipeline without any Suggest checks.

    Every hop generates a search query from the context gathered so far,
    retrieves passages for it, and merges them into the context. The answer is
    then generated from the accumulated context.
    """

    def __init__(self, passages_per_hop=2, max_hops=2):
        """Create one query generator per hop, a retriever, and an answer generator.

        Args:
            passages_per_hop: passed as `k` to `dspy.Retrieve`.
            max_hops: number of query/retrieve rounds run by `forward`.
        """
        super().__init__()

        # Attribute assignment order is kept as-is on purpose.
        self.generate_query = [
            dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)
        ]
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
        self.max_hops = max_hops

        # Evaluation bookkeeping only: number of forward() calls whose queries
        # were all distinct. Incremented without any locking.
        self.passed_suggestions = 0

    def forward(self, question):
        """Answer `question`; returns a `dspy.Prediction` with `context` and `answer`."""
        gathered_context = []
        # The question is the first entry, so generated queries are also
        # compared against the question in the distinctness check below.
        issued_queries = [question]

        for hop_index in range(self.max_hops):
            hop_output = self.generate_query[hop_index](
                context=gathered_context, question=question
            )
            search_query = hop_output.query
            issued_queries.append(search_query)
            retrieved_passages = self.retrieve(search_query).passages
            gathered_context = deduplicate(gathered_context + retrieved_passages)

        if all_queries_distinct(issued_queries):
            self.passed_suggestions += 1

        answer_output = self.generate_answer(context=gathered_context, question=question)
        return dspy.Prediction(context=gathered_context, answer=answer_output.answer)

In [10]:
class SimplifiedBaleenAssertions(dspy.Module):
    """Multi-hop question answering pipeline with two `dspy.Suggest` checks per hop.

    Each generated query is checked for length (at most 100 characters) and
    for distinctness from the question and the queries generated in earlier
    hops. Otherwise the flow is: generate a query, retrieve passages, merge
    them into the context, and finally generate the answer.
    """

    def __init__(self, passages_per_hop=2, max_hops=2):
        """Create one query generator per hop, a retriever, and an answer generator.

        Args:
            passages_per_hop: passed as `k` to `dspy.Retrieve`.
            max_hops: number of query/retrieve rounds run by `forward`.
        """
        super().__init__()
        self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
        self.max_hops = max_hops

        # Evaluation bookkeeping only: number of forward() calls whose queries
        # were all distinct. Incremented without any locking.
        self.passed_suggestions = 0

    def forward(self, question):
        """Answer `question`; returns a `dspy.Prediction` with `context` and `answer`."""
        context = []
        # Seeded with the question, so the first query is also checked against it.
        prev_queries = [question]

        for hop in range(self.max_hops):
            query = self.generate_query[hop](context=context, question=question).query

            # NOTE(review): the condition accepts exactly 100 characters while the
            # message says "less than 100" — confirm which is intended.
            # NOTE(review): target_module receives the whole list of query
            # generators, not the one used for this hop — confirm this is what
            # dspy.Suggest expects.
            dspy.Suggest(
                len(query) <= 100,
                "Query should be short and less than 100 characters",
                target_module=self.generate_query
            )

            # The message enumerates the question and all earlier queries, numbered from 1.
            dspy.Suggest(
                validate_query_distinction_local(prev_queries, query),
                "Query should be distinct from: "
                + "; ".join(f"{i+1}) {q}" for i, q in enumerate(prev_queries)),
                target_module=self.generate_query
            )

            prev_queries.append(query)
            passages = self.retrieve(query).passages
            context = deduplicate(context + passages)

        # Counted after all hops and before the answer is generated.
        if all_queries_distinct(prev_queries):
            self.passed_suggestions += 1

        pred = self.generate_answer(context=context, question=question)
        pred = dspy.Prediction(context=context, answer=pred.answer)
        return pred


In [11]:
# Shared evaluator over the dev split, constructed without a metric.
evaluate_on_hotpotqa = Evaluate(
    devset=devset,
    num_threads=10,
    display_progress=True,
    display_table=False,
)

In [12]:
def evaluate(module):
    """Run `module` over the dev set twice and print three scores.

    The first pass scores retrieval with `gold_passages_retrieved`; the second
    scores answers with `dspy.evaluate.answer_exact_match`. The suggestions
    score is `module.passed_suggestions` as a percentage of `len(devset)`.
    `module.passed_suggestions` is reset to 0 on entry. Returns None.
    """
    module.passed_suggestions = 0

    retrieval = evaluate_on_hotpotqa(module, metric=gold_passages_retrieved)

    # Read the counter right after the first pass: the second pass below runs
    # the module again and keeps incrementing it.
    suggestions = module.passed_suggestions / len(devset) * 100

    accuracy = evaluate_on_hotpotqa(module, metric=dspy.evaluate.answer_exact_match)

    print(
        f"## Suggestions Score: {suggestions}",
        f"## Retrieval Score: {retrieval}",
        f"## Accuracy Score: {accuracy}",
        sep="\n",
    )

In [None]:
# Baseline: uncompiled program without assertions.
baleen = SimplifiedBaleen()
evaluate(baleen)


In [None]:
# Uncompiled program with assertions: predictors are mapped through Retry and the
# module is wrapped with the backtracking handler.
baleen_with_assertions = assert_transform_module(
    SimplifiedBaleenAssertions().map_named_predictors(Retry),
    backtrack_handler,
)
evaluate(baleen_with_assertions)

In [None]:
# Passed as `max_bootstrapped_demos` to both BootstrapFewShotWithRandomSearch
# teleprompters in the compilation cells below.
max_bootstrapped_demos = 2

In [None]:
# Compiled program without assertions.
# `baleen` is rebound here but not passed to compile(); student and teacher are
# fresh SimplifiedBaleen instances.
baleen = SimplifiedBaleen()
teleprompter = BootstrapFewShotWithRandomSearch(
    num_candidate_programs=6,
    max_bootstrapped_demos=max_bootstrapped_demos,
    metric=validate_context_and_answer_and_hops,
)

compiled_baleen = teleprompter.compile(
    student=SimplifiedBaleen(),
    teacher=SimplifiedBaleen(),
    trainset=trainset,
    valset=devset,
)
evaluate(compiled_baleen)

In [None]:
# Compiled program with assertions: the student carries the Retry/backtracking
# wrapper, while the teacher is the plain SimplifiedBaleen.
baleen = SimplifiedBaleen()
teleprompter = BootstrapFewShotWithRandomSearch(
    num_candidate_programs=6,
    max_bootstrapped_demos=max_bootstrapped_demos,
    metric=validate_context_and_answer_and_hops,
)
compiled_baleen = teleprompter.compile(
    student=assert_transform_module(
        SimplifiedBaleenAssertions().map_named_predictors(Retry),
        backtrack_handler,
    ),
    teacher=baleen,
    valset=devset,
    trainset=trainset,
)
evaluate(compiled_baleen)
