In [2]:
%load_ext autoreload
%autoreload 2

import dspy
from dsp.utils import deduplicate
from dspy.datasets import HotPotQA
from dspy.predict.retry import Retry
from dspy.teleprompt import BootstrapFewShot, BootstrapFewShotWithRandomSearch
from dspy.evaluate.evaluate import Evaluate

from dspy.primitives.assertions import assert_transform_module, backtrack_handler

In [3]:
import os
import openai
openai.api_key = os.getenv('OPENAI_API_KEY')

In [4]:
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
dspy.settings.configure(rm=colbertv2_wiki17_abstracts)
turbo = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=500)
dspy.settings.configure(lm=turbo, trace=[], temperature=0.7)

In [5]:
dataset = HotPotQA(train_seed=1, train_size=300, eval_seed=2023, dev_size=300, test_size=0)
trainset = [x.with_inputs('question') for x in dataset.train]
devset = [x.with_inputs('question') for x in dataset.dev]

In [6]:
# Suggestion helper functions and Teleprompter metric

def validate_query_distinction_local(previous_queries, query):
    """check if query is distinct from previous queries"""
    if previous_queries == []:
        return True
    if dspy.evaluate.answer_exact_match_str(query, previous_queries, frac=0.8):
        return False
    return True


def validate_context_and_answer_and_hops(example, pred, trace=None):
    if not dspy.evaluate.answer_exact_match(example, pred):
        return False

    if not dspy.evaluate.answer_passage_match(example, pred):
        return False

    return True


In [7]:
# Extrinsic metrics

def gold_passages_retrieved(example, pred, trace=None):
    gold_titles = set(map(dspy.evaluate.normalize_text, example['gold_titles']))
    found_titles = set(map(dspy.evaluate.normalize_text, [c.split(' | ')[0] for c in pred.context]))

    return gold_titles.issubset(found_titles)

In [8]:
# signatures of dspy modules
class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


class GenerateSearchQuery(dspy.Signature):
    """Write a simple search query that will help answer a complex question."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    query = dspy.OutputField()

In [9]:
class SimplifiedBaleen(dspy.Module):
    def __init__(self, passages_per_hop=2, max_hops=2):
        super().__init__()

        self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
        self.max_hops = max_hops

    def forward(self, question):
        context = []
        for hop in range(self.max_hops):
            query = self.generate_query[hop](context=context, question=question).query
            passages = self.retrieve(query).passages
            context = deduplicate(context + passages)
        pred = self.generate_answer(context=context, question=question)
        pred = dspy.Prediction(context=context, answer=pred.answer)
        return pred

In [20]:
class SimplifiedBaleenAssertions(dspy.Module):
    def __init__(self, passages_per_hop=2, max_hops=2):
        super().__init__()
        self.generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
        self.max_hops = max_hops

    def forward(self, question):
        context = []
        prev_queries = [question]

        for hop in range(self.max_hops):
            query = self.generate_query[hop](context=context, question=question).query

            dspy.Suggest(
                len(query) <= 100,
                "Query should be short and less than 100 characters",
            )

            dspy.Suggest(
                validate_query_distinction_local(prev_queries, query),
                "Query should be distinct from: "
                + "; ".join(f"{i+1}) {q}" for i, q in enumerate(prev_queries)),
                is_metric=True,
            )

            prev_queries.append(query)
            passages = self.retrieve(query).passages
            context = deduplicate(context + passages)

        pred = self.generate_answer(context=context, question=question)
        pred = dspy.Prediction(context=context, answer=pred.answer)
        return pred


In [11]:
evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=25, display_progress=True, display_table=False)

In [12]:
def evaluate(module):
    retrieval_score = evaluate_on_hotpotqa(
        module, metric=gold_passages_retrieved
    )
    
    accuracy_score = evaluate_on_hotpotqa(
        module, metric=dspy.evaluate.answer_exact_match
    )

    print(f"## Retrieval Score: {retrieval_score}")
    print(f"## Accuracy Score: {accuracy_score}")

In [13]:
# # No Compilation + No Assertion
# baleen = SimplifiedBaleen()
# evaluate(baleen)


In [14]:
# # No Compilation + Yes Assertion
# baleen_with_assertions = assert_transform_module(SimplifiedBaleenAssertions().map_named_predictors(Retry), suggest_backtrack_handler) 
# evaluate(baleen_with_assertions)

In [15]:
# Yes Compilation + No Assertion
# baleen = SimplifiedBaleen()
# teleprompter = BootstrapFewShotWithRandomSearch(
#     metric=validate_context_and_answer_and_hops,
#     max_bootstrapped_demos=2,
#     num_candidate_programs=6,
# )

# compiled_baleen = teleprompter.compile(student = baleen, teacher = baleen, trainset = trainset, valset = devset)
# evaluate(compiled_baleen)

In [21]:
# Yes Compilation + Yes Assertion
baleen = SimplifiedBaleen()
teleprompter = BootstrapFewShotWithRandomSearch(
    metric=validate_context_and_answer_and_hops,
    max_bootstrapped_demos=2,
    num_candidate_programs=6,
)
compiled_baleen = teleprompter.compile(
    student=assert_transform_module(
        SimplifiedBaleenAssertions().map_named_predictors(Retry),
        backtrack_handler,
    ),
    teacher=baleen,
    trainset=trainset,
    valset=devset[:100]
)
evaluate(compiled_baleen)

Going to sample between 1 and 2 traces per predictor.
Will attempt to train 6 candidate sets.


Average Metric: 43 / 100  (43.0): 100%|██████████| 100/100 [00:00<00:00, 524.07it/s]
  df = df.applymap(truncate_cell)


Average Metric: 43 / 100  (43.0%)
Score: 43.0 for set: [0, 0, 0]
New best score: 43.0 for seed -3
Scores so far: [43.0]
Best score: 43.0


Average Metric: 42 / 100  (42.0): 100%|██████████| 100/100 [00:00<00:00, 514.39it/s]


Average Metric: 42 / 100  (42.0%)
Score: 42.0 for set: [16, 16, 16]
Scores so far: [43.0, 42.0]
Best score: 43.0


  1%|▏         | 4/300 [00:00<00:00, 1176.85it/s]


Bootstrapped 2 full traces after 5 examples in round 0.
# of suggestion failures during bootstrapping: 4


Average Metric: 45 / 100  (45.0): 100%|██████████| 100/100 [00:00<00:00, 517.62it/s]


Average Metric: 45 / 100  (45.0%)
Score: 44.2 for set: [16, 16, 16]
New best score: 44.2 for seed -1
Scores so far: [43.0, 42.0, 44.2]
Best score: 44.2
Average of max per entry across top 1 scores: 0.45
Average of max per entry across top 2 scores: 0.51
Average of max per entry across top 3 scores: 0.54
Average of max per entry across top 5 scores: 0.54
Average of max per entry across top 8 scores: 0.54
Average of max per entry across top 9999 scores: 0.54


  1%|          | 3/300 [00:00<00:00, 1135.03it/s]


Bootstrapped 2 full traces after 4 examples in round 0.
# of suggestion failures during bootstrapping: 4


Average Metric: 44 / 100  (44.0): 100%|██████████| 100/100 [00:00<00:00, 486.05it/s]


Average Metric: 44 / 100  (44.0%)
Score: 43.2 for set: [16, 16, 16]
Scores so far: [43.0, 42.0, 44.2, 43.2]
Best score: 44.2
Average of max per entry across top 1 scores: 0.45
Average of max per entry across top 2 scores: 0.52
Average of max per entry across top 3 scores: 0.55
Average of max per entry across top 5 scores: 0.56
Average of max per entry across top 8 scores: 0.56
Average of max per entry across top 9999 scores: 0.56


  1%|          | 2/300 [00:00<00:00, 1088.72it/s]


Bootstrapped 1 full traces after 3 examples in round 0.
# of suggestion failures during bootstrapping: 0


Average Metric: 40 / 100  (40.0): 100%|██████████| 100/100 [00:00<00:00, 468.94it/s]


Average Metric: 40 / 100  (40.0%)
Score: 40.0 for set: [16, 16, 16]
Scores so far: [43.0, 42.0, 44.2, 43.2, 40.0]
Best score: 44.2
Average of max per entry across top 1 scores: 0.45
Average of max per entry across top 2 scores: 0.52
Average of max per entry across top 3 scores: 0.55
Average of max per entry across top 5 scores: 0.58
Average of max per entry across top 8 scores: 0.58
Average of max per entry across top 9999 scores: 0.58


  1%|          | 2/300 [00:00<00:00, 1038.84it/s]


Bootstrapped 1 full traces after 3 examples in round 0.
# of suggestion failures during bootstrapping: 2


Average Metric: 38 / 100  (38.0): 100%|██████████| 100/100 [00:00<00:00, 537.95it/s]


Average Metric: 38 / 100  (38.0%)
Score: 37.6 for set: [16, 16, 16]
Scores so far: [43.0, 42.0, 44.2, 43.2, 40.0, 37.6]
Best score: 44.2
Average of max per entry across top 1 scores: 0.45
Average of max per entry across top 2 scores: 0.52
Average of max per entry across top 3 scores: 0.55
Average of max per entry across top 5 scores: 0.58
Average of max per entry across top 8 scores: 0.58
Average of max per entry across top 9999 scores: 0.58


  1%|▏         | 4/300 [00:00<00:00, 1060.71it/s]


Bootstrapped 1 full traces after 5 examples in round 0.
# of suggestion failures during bootstrapping: 1


Average Metric: 37 / 100  (37.0): 100%|██████████| 100/100 [00:00<00:00, 428.19it/s]


Average Metric: 37 / 100  (37.0%)
Score: 36.8 for set: [16, 16, 16]
Scores so far: [43.0, 42.0, 44.2, 43.2, 40.0, 37.6, 36.8]
Best score: 44.2
Average of max per entry across top 1 scores: 0.45
Average of max per entry across top 2 scores: 0.52
Average of max per entry across top 3 scores: 0.55
Average of max per entry across top 5 scores: 0.58
Average of max per entry across top 8 scores: 0.59
Average of max per entry across top 9999 scores: 0.59


  1%|          | 3/300 [00:00<00:00, 1050.50it/s]


Bootstrapped 1 full traces after 4 examples in round 0.
# of suggestion failures during bootstrapping: 1


Average Metric: 43 / 100  (43.0): 100%|██████████| 100/100 [00:00<00:00, 555.81it/s]


Average Metric: 43 / 100  (43.0%)
Score: 42.8 for set: [16, 16, 16]
Scores so far: [43.0, 42.0, 44.2, 43.2, 40.0, 37.6, 36.8, 42.8]
Best score: 44.2
Average of max per entry across top 1 scores: 0.45
Average of max per entry across top 2 scores: 0.52
Average of max per entry across top 3 scores: 0.55
Average of max per entry across top 5 scores: 0.57
Average of max per entry across top 8 scores: 0.6
Average of max per entry across top 9999 scores: 0.6


  1%|          | 2/300 [00:00<00:00, 616.54it/s]


Bootstrapped 2 full traces after 3 examples in round 0.
# of suggestion failures during bootstrapping: 2


Average Metric: 40 / 100  (40.0): 100%|██████████| 100/100 [00:00<00:00, 460.90it/s]


Average Metric: 40 / 100  (40.0%)
Score: 39.6 for set: [16, 16, 16]
Scores so far: [43.0, 42.0, 44.2, 43.2, 40.0, 37.6, 36.8, 42.8, 39.6]
Best score: 44.2
Average of max per entry across top 1 scores: 0.45
Average of max per entry across top 2 scores: 0.52
Average of max per entry across top 3 scores: 0.55
Average of max per entry across top 5 scores: 0.57
Average of max per entry across top 8 scores: 0.6
Average of max per entry across top 9999 scores: 0.61
9 candidate programs found.


Average Metric: 126 / 300  (42.0): 100%|██████████| 300/300 [00:00<00:00, 595.84it/s]


Average Metric: 126 / 300  (42.0%)


Average Metric: 150 / 300  (50.0): 100%|██████████| 300/300 [00:00<00:00, 618.44it/s]

Average Metric: 150 / 300  (50.0%)
## Retrieval Score: 42.0
## Accuracy Score: 50.0



