<img src="../../docs/docs/static/img/dspy_logo.png" alt="DSPy7 Image" height="150"/>

## **DSPy Assertions**: Asserting Computational Constraints on Foundation Models

### **TweetGen**: Generating tweets to answer questions

[<img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />](https://colab.research.google.com/github/stanfordnlp/dspy/blob/main/examples/tweets/tweets_assertions.ipynb)


This notebook highlights an example of [**DSPy Assertions**](https://dspy-docs.vercel.app/docs/building-blocks/assertions), which allow you to declare computational constraints within DSPy programs.


This notebook builds upon the foundational concepts of the **DSPy** framework. Before following it, you should have gone through the [DSPy tutorial](../../intro.ipynb), the [**DSPy Assertions documentation**](https://dspy-docs.vercel.app/docs/building-blocks/assertions), and the introductory DSPy Assertions [tutorial on LongFormQA](../longformqa/longformqa_assertions.ipynb).


In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
import regex as re

try: # When on google Colab, let's clone the notebook so we download the cache.
    import google.colab  # noqa: F401
    repo_path = 'dspy'
    
    !git -C $repo_path pull origin || git clone https://github.com/stanfordnlp/dspy $repo_path
except ImportError:  # Not running on Colab, so use the local checkout.
    repo_path = '.'

if repo_path not in sys.path:
    sys.path.append(repo_path)


import pkg_resources # Install the package if it's not installed
if "dspy-ai" not in {pkg.key for pkg in pkg_resources.working_set}:
    !pip install -U pip
    !pip install dspy-ai==2.4.17
    !pip install openai~=0.28.1
    !pip install -e $repo_path

import dspy
from dspy.predict import Retry
from dspy.datasets import HotPotQA
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
from dsp.utils import deduplicate
from dspy.evaluate.evaluate import Evaluate
from dspy.primitives.assertions import assert_transform_module, backtrack_handler

In [None]:
import openai
openai.api_key = os.getenv('OPENAI_API_KEY')

In [None]:
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
dspy.settings.configure(rm=colbertv2_wiki17_abstracts)
turbo = dspy.OpenAI(model='gpt-4o-mini', max_tokens=500)
dspy.settings.configure(lm=turbo, trace=[], temperature=0.7)

In [None]:
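# dev_size must be at least 119 because later cells inspect devset[118]; 150 is an arbitrary choice that satisfies this.
dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=150, test_size=0, keep_details=True)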
trainset = [x.with_inputs('question', 'answer') for x in dataset.train]
devset = [x.with_inputs('question', 'answer') for x in dataset.dev]

### 3] TweetGen

Let's introduce a new task: TweetGen. We extend the `Multi-Hop QA` program, but now aim to present the answer generation in the form of a tweet. 

The `Tweeter` module captures the iterative multi-hop generation process from `Multi-Hop QA` in query generation, passage retrieval, and context assembly. The `GenerateTweet` layer now utilizes the context alongside the question to generate a tweet that effectively answers the question.
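
The multi-hop loop described above can be sketched without any LM or retriever. This is a minimal illustration only: `toy_generate_query`, `toy_retrieve`, and `TOY_INDEX` are hypothetical stand-ins for the query generator and retriever, and the local `deduplicate` is an order-preserving substitute written for this sketch, not the library helper.

```python
from typing import Callable, Dict, List


def deduplicate(items: List[str]) -> List[str]:
    """Drop repeated passages while keeping first-seen order."""
    seen = set()
    unique = []
    for item in items:
        if item not in seen:
            seen.add(item)
            unique.append(item)
    return unique


def multi_hop_context(
    question: str,
    generate_query: Callable[[List[str], str], str],
    retrieve: Callable[[str], List[str]],
    max_hops: int = 2,
) -> List[str]:
    """Alternate query generation and retrieval, accumulating context."""
    context: List[str] = []
    for _ in range(max_hops):
        # Each hop sees the passages gathered so far.
        query = generate_query(context, question)
        passages = retrieve(query)
        context = deduplicate(context + passages)
    return context


# Hypothetical stand-ins for the query generator and the retriever.
TOY_INDEX: Dict[str, List[str]] = {
    "hop-0": ["passage A", "passage B"],
    "hop-1": ["passage B", "passage C"],
}


def toy_generate_query(context: List[str], question: str) -> str:
    return "hop-0" if not context else "hop-1"


def toy_retrieve(query: str) -> List[str]:
    return TOY_INDEX[query]


context = multi_hop_context("toy question", toy_generate_query, toy_retrieve)
print(context)  # ['passage A', 'passage B', 'passage C']
```

The second hop retrieves `passage B` again, and deduplication keeps the assembled context free of repeats before it is handed to the tweet generator.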

With this program, we aim to generate tweets that adhere to the following guidelines:
1. The tweet has no hashtags.
2. The tweet includes the correct answer.
3. The tweet is within a character limit.
4. The tweet is engaging.
5. The tweet is faithful.

In [None]:
class GenerateSearchQuery(dspy.Signature):
    """Write a simple search query that will help answer a complex question."""
    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    query = dspy.OutputField()

class GenerateTweet(dspy.Signature):
    """Generate an engaging tweet that effectively answers a question staying faithful to the context, is less than 280 characters, and has no hashtags."""
    question = dspy.InputField()
    context = dspy.InputField(desc="may contain relevant facts")
    tweet = dspy.OutputField()

class Tweeter(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_tweet = dspy.ChainOfThought(GenerateTweet)

    def forward(self, question, answer):
        context = []
        max_hops=2
        passages_per_hop=3
        generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
        retrieve = dspy.Retrieve(k=passages_per_hop)
        for hop in range(max_hops):
            query = generate_query[hop](context=context, question=question).query
            passages = retrieve(query).passages
            context = deduplicate(context + passages)
        generated_tweet = self.generate_tweet(question=question, context=context).tweet
        return dspy.Prediction(generated_tweet=generated_tweet, context=context)
    
tweeter = Tweeter()

### 4] Evaluation - Intrinsic and Extrinsic

#### Intrinsic Metrics: passing internal computational constraints is the goal 

**No Hashtags** - This is a user-personalized constraint to test how well the model can follow a specific, yet simple guideline of not including any hashtags within the generated tweet.

**Correct Answer Inclusion** - This is a general check to ensure the tweet indeed has the correct answer to the question.

**Within Length** - This check follows Twitter's limit of 280 characters per tweet.

**Engagement** - To verify the engagement quality of the tweet, we define and call another **DSPy** program: ``Predict`` on ``AssessTweet``, relying on the same LM to answer the question: `"Does the assessed text make for a self-contained, engaging tweet? Say no if it is not engaging."`

**Faithfulness** - To verify the faithfulness of the tweet to its referenced context, we similarly use `AssessTweet` as above but prompt it with the question: `"Is the assessed text grounded in the context? Say no if it includes significant facts not in the context."`
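
Both LM-judged checks come down to reading a `Yes` or `No` verdict from free text. Comparing the first whitespace-separated word to `'yes'` treats a reply such as `Yes.` as a no and raises an `IndexError` on an empty reply. The sketch below shows one more tolerant way to parse the verdict; `parse_yes_no` is a hypothetical helper introduced here for illustration and is not part of the notebook or of DSPy.

```python
def parse_yes_no(assessment_answer: str) -> bool:
    """Return True only when the first word of the answer is 'yes'.

    Surrounding punctuation and markdown emphasis are ignored, and an
    empty answer counts as 'no' instead of raising an error.
    """
    words = assessment_answer.strip().split()
    if not words:
        return False
    first_word = words[0].strip(".,!:;\"'*").lower()
    return first_word == "yes"


print(parse_yes_no("Yes"))                      # True
print(parse_yes_no("yes, it is engaging."))     # True
print(parse_yes_no("No. It only states facts."))  # False
print(parse_yes_no(""))                         # False
```

Treating an unparseable or empty verdict as a failed check is a conservative design choice: the metric then under-counts rather than over-counts passing tweets.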


In [None]:
def has_no_hashtags(text):
    return len(re.findall(r"#\w+", text)) == 0

def is_within_length_limit(text, length_limit=280):
    return len(text) <= length_limit

def is_assessment_yes(assessment_answer):
    """Check if the first word of the assessment answer is 'yes'."""
    return assessment_answer.split()[0].lower() == 'yes'

def has_correct_answer(text, answer):
    return answer in text


class AssessTweet(dspy.Signature):
    """Assess the quality of a tweet along the specified dimension."""

    context = dspy.InputField(desc='ignore if N/A')
    assessed_text = dspy.InputField()
    assessment_question = dspy.InputField()
    assessment_answer = dspy.OutputField(desc="Yes or No")

def no_hashtags_metric(gold, pred, trace=None):
    tweet = pred.generated_tweet
    no_hashtags = has_no_hashtags(tweet)
    score = no_hashtags
    return score

def is_correct_metric(gold, pred, trace=None):
    answer, tweet = gold.answer, pred.generated_tweet
    correct = has_correct_answer(tweet, answer)
    score = correct
    return score

def within_length_metric(gold, pred, trace=None):
    tweet = pred.generated_tweet
    within_length_limit = is_within_length_limit(tweet, 280)
    score = within_length_limit
    return score

def engaging_metric(gold, pred, trace=None):
    tweet = pred.generated_tweet
    engaging = "Does the assessed text make for a self-contained, engaging tweet? Say no if it is not engaging."
    engaging = dspy.Predict(AssessTweet)(context='N/A', assessed_text=tweet, assessment_question=engaging)
    engaging = engaging.assessment_answer.split()[0].lower() == 'yes'
    score = engaging
    return score

def faithful_metric(gold, pred, trace=None):
    context, tweet = pred.context, pred.generated_tweet
    faithful = "Is the assessed text grounded in the context? Say no if it includes significant facts not in the context."   
    faithful = dspy.Predict(AssessTweet)(context=context, assessed_text=tweet, assessment_question=faithful)
    faithful = faithful.assessment_answer.split()[0].lower() == 'yes'
    score = faithful
    return score

#### Extrinsic Metrics: Assess the overall quality and effectiveness of generated output on downstream task

The extrinsic metric measures the overall quality of the generated tweet with respect to the constraints above, and it is evaluated with a composite metric.

The composite metric treats correctness and length as hard requirements for a valid tweet: if either check fails, the score is 0. Otherwise, it returns the average over the 5 intrinsic metrics.
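
As a worked example of the composite score, the sketch below takes the five check results as plain booleans and applies the same gating and averaging. The function name `composite_score` is hypothetical and used only for this illustration; no LM calls are involved.

```python
def composite_score(
    correct: bool,
    within_length: bool,
    no_hashtags: bool,
    engaging: bool,
    faithful: bool,
) -> float:
    """Average the five checks, gated on correctness and length."""
    # A tweet that is wrong or too long is not a valid tweet at all.
    if not (correct and within_length):
        return 0.0
    passed = sum([correct, within_length, no_hashtags, engaging, faithful])
    return passed / 5.0


# Correct and short enough, but has a hashtag and is not engaging: 3 of 5.
print(composite_score(True, True, False, False, True))  # 0.6
# Missing the answer zeroes the score even if every other check passes.
print(composite_score(False, True, True, True, True))   # 0.0
# All five checks pass.
print(composite_score(True, True, True, True, True))    # 1.0
```

Because of the gate, the score can only take the values 0, 0.4, 0.6, 0.8, or 1.0: a valid tweet already passes two of the five checks.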

In [None]:
def overall_metric(gold, pred, trace=None):
    answer, context, tweet = gold.answer, pred.context, pred.generated_tweet
    no_hashtags = has_no_hashtags(tweet)
    within_length_limit = is_within_length_limit(tweet, 280)
    correct = has_correct_answer(tweet, answer)
    engaging = "Does the assessed text make for a self-contained, engaging tweet? Say no if it is not engaging."
    faithful = "Is the assessed text grounded in the context? Say no if it includes significant facts not in the context."   
    faithful = dspy.Predict(AssessTweet)(context=context, assessed_text=tweet, assessment_question=faithful)
    engaging = dspy.Predict(AssessTweet)(context='N/A', assessed_text=tweet, assessment_question=engaging)
    engaging, faithful = [m.assessment_answer.split()[0].lower() == 'yes' for m in [engaging, faithful]]
    score = (correct + engaging + faithful + no_hashtags + within_length_limit) if correct and within_length_limit else 0
    return score / 5.0

We hence define the evaluation as follows:

In [None]:
metrics = [no_hashtags_metric, is_correct_metric, within_length_metric, engaging_metric, faithful_metric, overall_metric]

for metric in metrics:
    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(tweeter)

Let's take a look at an example tweet generation:

In [None]:
example = devset[118]
tweet = tweeter(question=example.question, answer = example.answer)
print('Generated Tweet: ', tweet.generated_tweet)
tweet.context

In [None]:
for metric in metrics:
    evaluate = Evaluate(metric=metric, devset=devset[118:119], num_threads=1, display_progress=True, display_table=5)
    evaluate(tweeter)

In this example, we see that the generated tweet is within the length of 280 characters at 151 characters. It does in fact include the correct answer `Hooke`.

However, it violates the no-hashtags constraint: `#knowledge` appears at the end of the tweet. Additionally, the tweet was judged not to be engaging, which matches a quick read, as it simply states the answer and nothing more.

Let's try to fix this and produce tweets using DSPy Assertions. 

### 5] Introducing Assertions: TweeterWithAssertions

To correct these errors, let's include assertions that restate our computational constraints using DSPy Assertion semantics.

In the first **Assertion**, we use a regex to check whether the generated tweet contains any hashtags and, if it does, send the feedback message: **"Please revise the tweet to remove hashtag phrases following it."**

Similarly, we check the tweet length and, if it exceeds 280 characters, send the feedback message: **"Please ensure the tweet is within {280} characters."**

We check whether the generated tweet contains the answer and, if it does not, send: **"The tweet does not include the correct answer to the question. Please revise accordingly."**

For the engagement and faithfulness checks, we reuse the setup from above and check whether the respective assessment is `Yes` or `No`.
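
To build intuition for how a failed check feeds back into generation, here is a simplified, library-free sketch of a retry loop driven by feedback messages. It is not DSPy's implementation: `generate_with_feedback`, `first_failure`, and `toy_generate` are hypothetical names, the tweet text is made up, and the real backtracking mechanism may differ in how it selects and passes feedback.

```python
from typing import Callable, List, Optional, Tuple

# Each constraint pairs a check with the feedback sent when the check fails.
Constraint = Tuple[Callable[[str], bool], str]


def first_failure(output: str, constraints: List[Constraint]) -> Optional[str]:
    """Return the feedback message of the first violated constraint."""
    for check, message in constraints:
        if not check(output):
            return message
    return None


def generate_with_feedback(
    generate: Callable[[List[str]], str],
    constraints: List[Constraint],
    max_retries: int = 2,
) -> str:
    """Regenerate with accumulated feedback until checks pass or retries run out."""
    feedback: List[str] = []
    output = generate(feedback)
    for _ in range(max_retries):
        message = first_failure(output, constraints)
        if message is None:
            break
        feedback.append(message)
        output = generate(feedback)
    # Soft constraint: the last output is returned even if it still fails.
    return output


def toy_generate(feedback: List[str]) -> str:
    """Stand-in for the LM: drops the hashtag once told to."""
    draft = "The answer is Hooke. #knowledge"
    if any("hashtag" in message for message in feedback):
        draft = draft.replace(" #knowledge", "")
    return draft


constraints: List[Constraint] = [
    (lambda text: "#" not in text, "Please revise the tweet to remove hashtag phrases following it."),
    (lambda text: len(text) <= 280, "Please ensure the tweet is within 280 characters."),
]

print(generate_with_feedback(toy_generate, constraints))  # The answer is Hooke.
```

The loop returns the last attempt even when a constraint still fails, mirroring the idea that a suggestion guides the model but does not halt the program.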


In [None]:
class TweeterWithAssertions(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_tweet = dspy.ChainOfThought(GenerateTweet)

    def forward(self, question, answer):
        context = []
        max_hops=2
        passages_per_hop=3
        generate_query = [dspy.ChainOfThought(GenerateSearchQuery) for _ in range(max_hops)]
        retrieve = dspy.Retrieve(k=passages_per_hop)
        for hop in range(max_hops):
            query = generate_query[hop](context=context, question=question).query
            passages = retrieve(query).passages
            context = deduplicate(context + passages)
        generated_tweet = self.generate_tweet(question=question, context=context).tweet
        dspy.Suggest(has_no_hashtags(generated_tweet), "Please revise the tweet to remove hashtag phrases following it.", target_module=self.generate_tweet)
        dspy.Suggest(is_within_length_limit(generated_tweet, 280), f"Please ensure the tweet is within {280} characters.", target_module=self.generate_tweet)
        dspy.Suggest(has_correct_answer(generated_tweet, answer), "The tweet does not include the correct answer to the question. Please revise accordingly.", target_module=self.generate_tweet)
        # Engagement is judged on the tweet alone, so no context is passed (matches engaging_metric).
        engaging_question = "Does the assessed text make for a self-contained, engaging tweet? Say no if it is not engaging."
        engaging_assessment = dspy.Predict(AssessTweet)(context='N/A', assessed_text=generated_tweet, assessment_question=engaging_question)
        dspy.Suggest(is_assessment_yes(engaging_assessment.assessment_answer), "The text is not engaging enough. Please revise to make it more captivating.", target_module=self.generate_tweet)
        # Faithfulness must be judged against the retrieved context (matches faithful_metric).
        faithful_question = "Is the assessed text grounded in the context? Say no if it includes significant facts not in the context."
        faithful_assessment = dspy.Predict(AssessTweet)(context=context, assessed_text=generated_tweet, assessment_question=faithful_question)
        dspy.Suggest(is_assessment_yes(faithful_assessment.assessment_answer), "The text contains unfaithful elements or significant facts not in the context. Please revise for accuracy.", target_module=self.generate_tweet)
        return dspy.Prediction(generated_tweet=generated_tweet, context=context)

tweeter_with_assertions = assert_transform_module(TweeterWithAssertions().map_named_predictors(Retry), backtrack_handler) 

Let's evaluate the `TweeterWithAssertions` now over the devset.

In [None]:
metrics = [no_hashtags_metric, is_correct_metric, within_length_metric, engaging_metric, faithful_metric, overall_metric]

for metric in metrics:
    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(tweeter_with_assertions)

Now let's take a look at how our generated tweet has improved with the addition of assertions.

In [None]:
example = devset[118]
tweet = tweeter_with_assertions(question=example.question, answer = example.answer)
print('Generated Tweet: ', tweet.generated_tweet)
tweet.context

In [None]:
for metric in metrics:
    evaluate = Evaluate(metric=metric, devset=devset[118:119], num_threads=1, display_progress=True, display_table=5)
    evaluate(tweeter_with_assertions)

We see that the tweet has improved significantly, following all of our set constraints! 

It no longer has hashtags, and is both engaging and faithful, while maintaining the inclusion of the correct answer within 280 characters. Exciting!

### 6] Compilation With Assertions

We can leverage **DSPy**'s `BootstrapFewShotWithRandomSearch` optimizer to automatically generate few-shot demonstrations and conduct a random search over the candidates to output the best compiled program. We evaluate this over the `overall_metric` composite metric.
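
The "random search over candidates" idea can be illustrated with a toy, library-free sketch: sample a few candidate demonstration sets, score each on a validation function, and keep the best. This is only an assumption-laden illustration and not the optimizer's actual algorithm, which also bootstraps the demonstrations by running a teacher program. The names `random_search_demos`, `toy_pool`, and `toy_score` are hypothetical, and `max_demos` and `num_candidates` only loosely mirror `max_bootstrapped_demos` and `num_candidate_programs`.

```python
import random
from typing import Callable, List, Sequence, Tuple


def random_search_demos(
    demo_pool: Sequence[str],
    score_fn: Callable[[List[str]], float],
    max_demos: int = 2,
    num_candidates: int = 6,
    seed: int = 0,
) -> Tuple[List[str], float]:
    """Sample candidate demo sets at random and keep the best-scoring one."""
    rng = random.Random(seed)
    best_demos: List[str] = []
    best_score = score_fn(best_demos)  # Baseline: no demonstrations.
    for _ in range(num_candidates):
        k = rng.randint(1, max_demos)
        candidate = rng.sample(list(demo_pool), k)
        score = score_fn(candidate)
        if score > best_score:
            best_demos, best_score = candidate, score
    return best_demos, best_score


# Hypothetical pool of demonstrations; two of them violate the no-hashtag rule.
toy_pool = ["plain demo one", "plain demo two", "demo with #tag", "another #tagged demo"]


def toy_score(demos: List[str]) -> float:
    """Toy validation score: share of the two demo slots filled by hashtag-free demos."""
    return sum("#" not in demo for demo in demos) / 2


best_demos, best_score = random_search_demos(toy_pool, toy_score)
print(best_demos, best_score)
```

In the real setting, the scoring function would run the candidate program over the validation set with `overall_metric`, which is what makes each candidate expensive to evaluate.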

We can first evaluate this on `Tweeter` to see how compilation performs without the inclusion of assertions. 

In [None]:
teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)
compiled_tweeter = teleprompter.compile(student = tweeter, teacher = tweeter, trainset=trainset, valset=devset[:25])

for metric in metrics:
    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(compiled_tweeter)

Now we test the compilation on 2 settings with assertions:

**Compilation with Assertions**: assertion-driven example bootstrapping and counterexample bootstrapping during compilation. The teacher has assertions while the student does not; the student learns from the teacher's assertion-driven bootstrapped examples.

**Compilation + Inference with Assertions**: assertion-driven optimizations for both the teacher and student to offer enhanced assertion-driven outputs during both compilation and inference.

In [None]:
teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6)
compiled_with_assertions_tweeter = teleprompter.compile(student=tweeter, teacher = tweeter_with_assertions, trainset=trainset, valset=devset[:25])


for metric in metrics:
    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(compiled_with_assertions_tweeter)

In [None]:
teleprompter = BootstrapFewShotWithRandomSearch(metric = overall_metric, max_bootstrapped_demos=2, num_candidate_programs=6, num_threads=1)
compiled_tweeter_with_assertions = teleprompter.compile(student=tweeter_with_assertions, teacher = tweeter_with_assertions, trainset=trainset, valset=devset[:25])

for metric in metrics:
    evaluate = Evaluate(metric=metric, devset=devset, num_threads=1, display_progress=True, display_table=5)
    evaluate(compiled_tweeter_with_assertions)