# DSPy Playground: Companion Notebook

This notebook mirrors the narrative in dspy_blog.md so you can execute every experiment while reading the story.

**How to use it**

1. Run the setup cells below (installation + language model configuration).
2. If you have an OpenAI key, set os.environ[\"OPENAI_API_KEY\"] = \"sk-...\" before re-running the configuration cell. Without a key, the notebook defaults to a playful simulated model so every block still prints output.
3. Keep the blog open side-by-side and tinker with the code to explore variations.


In [1]:
# Install DSPy and supporting libraries
!pip install -q dspy-ai fastapi uvicorn transformers accelerate datasets


[notice] A new release of pip is available: 25.0.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
import dspy
import os
os.environ["OPENAI_API_KEY"] = "api-key"
lm = dspy.LM(provider="openai", model="gpt-3.5-turbo", max_tokens=1000)
response = lm("Explain what Retrieval-Augmented Generation (RAG) is.")
print(response)

['Retrieval-Augmented Generation (RAG) is a natural language processing model that combines elements of both retrieval-based and generation-based approaches to improve the quality of text generation. In RAG, a retrieval component is used to search for relevant information from a large database of text, which is then used to augment the generation process. This allows the model to incorporate factual information and context from the retrieved text into the generated output, resulting in more accurate and coherent text generation. RAG has been shown to outperform traditional generation models in tasks such as question answering and text summarization by leveraging the benefits of both retrieval and generation techniques.']


In [3]:
lm = dspy.LM(provider="openai", model="gpt-3.5-turbo", max_tokens=512)
response = lm("In two sentences, explain how retrieval-augmented generation strengthens ESG risk analysis for institutional investors.")
print(response)


['Retrieval-augmented generation allows institutional investors to access a vast amount of data and information related to ESG risks, enabling them to make more informed investment decisions. By combining retrieval of relevant data with generation of insights and analysis, investors can better understand the potential impact of ESG risks on their portfolios and take proactive measures to mitigate them.']


In [4]:
import os
import re
import json
import types
import asyncio
import logging
import pickle
import hashlib
import textwrap
from datetime import datetime
from typing import List, Optional
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
import dspy
#from dspy.models import openai
#from dspy import openai
from dspy.teleprompt import BootstrapFewShot, BootstrapFewShotWithRandomSearch, LabeledFewShot, BootstrapFinetune
from dspy.evaluate import Evaluate

TRACE_LOG: List[str] = []


class EchoLM(dspy.BaseLM):
    """Fallback LM that fabricates structured outputs so every cell prints."""

    def __init__(self):
        super().__init__(model="echo-playground", temperature=0.0, max_tokens=512, cache=False)

    def forward(self, prompt, messages=None, **kwargs):
        system_content = ""
        if messages:
            for message in messages:
                if isinstance(message, dict) and message.get("role") == "system":
                    system_content = message.get("content", "")
        output_fields = []
        if "Your output fields are:" in system_content:
            after = system_content.split("Your output fields are:")[1]
            before = after.split("All interactions")[0]
            # Use simpler string splitting instead of regex
            output_fields = [line.strip() for line in before.split('\n') if line.strip()]
        if not output_fields:
            output_fields = ["answer"]

        if messages:
            last_message = messages[-1]
            if isinstance(last_message, dict):
                user_content = last_message.get("content", "")
            elif hasattr(last_message, "content"):
                user_content = last_message.content
            else:
                user_content = str(last_message)
        else:
            user_content = prompt or ""

        snippet = (user_content or "").replace("\n", " ").strip()
        if len(snippet) > 140:
            snippet = snippet[:137] + "..."

        lines = []
        for field in output_fields:
            label = field.replace("_", " ")
            lines.append(f"[[ ## {field} ## ]]")
            lines.append(f"Simulated {label} for: {snippet or '<empty>'}")
            lines.append("")
        lines.append("[[ ## completed ## ]]")
        content = "\n".join(lines)

        message = types.SimpleNamespace(content=content, role="assistant")
        choice = types.SimpleNamespace(message=message, finish_reason="stop")
        return types.SimpleNamespace(
            choices=[choice],
            usage={
                "prompt_tokens": len(snippet.split()),
                "completion_tokens": len(content.split()),
                "total_tokens": len(snippet.split()) + len(content.split()),
            },
            model=self.model,
        )


def configure_language_model() -> bool:
    """Try to wire up a real LM, then fall back to EchoLM."""

    openai_key = os.getenv("OPENAI_API_KEY")
    model_name = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")

    if openai_key:
        try:
            lm = dspy.LM(provider="openai", model=model_name, temperature=0.7, max_tokens=800)
            dspy.settings.configure(lm=lm, trace=TRACE_LOG)
            print(f"✅ Connected to OpenAI model: {model_name}")
            return True
        except Exception as exc:
            print(f"⚠️ OpenAI setup failed: {exc}")

    fallback_lm = EchoLM()
    dspy.settings.configure(lm=fallback_lm, trace=TRACE_LOG)
    print("🔄 Using the simulated EchoLM fallback. Set OPENAI_API_KEY to talk to a real model.")
    return False


HAVE_REAL_LM = configure_language_model()
print(f"LM ready: {HAVE_REAL_LM}")

✅ Connected to OpenAI model: gpt-3.5-turbo
LM ready: True


In [5]:
# The Old Way: Fragile prompt engineering
def extract_info_traditional(text):
    prompt = f"""Extract the following information from the text:
    - Mission or initiative name
    - Key commitments (list them)
    - Sentiment (positive/negative/neutral)

    Text: {text}

    Format your response as JSON.
    Make sure to include all fields.
    Be concise but comprehensive.
    Double-check your JSON formatting.
    """

    response = call_llm(prompt)
    # Hope the LLM follows our format...
    # Parse JSON (might fail!)
    # Handle edge cases manually
    return parse_somehow(response)


def call_llm(prompt: str) -> str:
    """Pretend to call an LLM so the classic example prints output."""

    preview = (prompt[:180] + "...") if len(prompt) > 180 else prompt
    print("Prompt preview:\n" + textwrap.indent(preview, "  "))
    payload = {
        "Mission or initiative name": "TRUTHS climate mission",
        "Key commitments": [
            "Calibrate satellite climate observations with traceable radiometry",
            "Sustain a partnership between ESA, the UKSA, and the Canadian Space Agency",
            "Launch on Vega-C in 2030 to provide authoritative greenhouse gas baselines"
        ],
        "Sentiment": "positive"
    }
    return json.dumps(payload)


def parse_somehow(response: str) -> dict:
    print(f"Raw response: {response}")
    return {"parsed": json.loads(response)}


sample_text = """In March 2024, the European Space Agency formally approved the TRUTHS climate mission, an observatory designed to calibrate satellite measurements of Earth's radiative balance. The programme is led by the UK Space Agency in partnership with the Canadian Space Agency and national metrology laboratories. Its goals include delivering traceable radiometric data for greenhouse gas inventories, validating commercial climate analytics, and accelerating climate model tuning. Launch is targeted for 2030 aboard a Vega-C rocket from French Guiana."""

sample_traditional = extract_info_traditional(sample_text)
print("Parsed output ->", sample_traditional)


Prompt preview:
  Extract the following information from the text:
      - Mission or initiative name
      - Key commitments (list them)
      - Sentiment (positive/negative/neutral)

      Text: In March ...
Raw response: {"Mission or initiative name": "TRUTHS climate mission", "Key commitments": ["Calibrate satellite climate observations with traceable radiometry", "Sustain a partnership between ESA, the UKSA, and the Canadian Space Agency", "Launch on Vega-C in 2030 to provide authoritative greenhouse gas baselines"], "Sentiment": "positive"}
Parsed output -> {'parsed': {'Mission or initiative name': 'TRUTHS climate mission', 'Key commitments': ['Calibrate satellite climate observations with traceable radiometry', 'Sustain a partnership between ESA, the UKSA, and the Canadian Space Agency', 'Launch on Vega-C in 2030 to provide authoritative greenhouse gas baselines'], 'Sentiment': 'positive'}}


In [6]:
# Signatures: Defining What, Not How
class BasicQA(dspy.Signature):
    """Answer questions with short factual answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


class DocumentQA(dspy.Signature):
    """Answer question based on given context."""

    context = dspy.InputField(desc="relevant background information")
    question = dspy.InputField(desc="question to be answered")
    answer = dspy.OutputField(desc="detailed answer based on context")

print("✅ BasicQA and DocumentQA signatures registered.")


✅ BasicQA and DocumentQA signatures registered.


In [7]:
# Modules: Composable LLM Programs
class SimpleQAModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.Predict(BasicQA)

    def forward(self, question):
        prediction = self.generate_answer(question=question)
        return prediction.answer


qa = SimpleQAModule()
result = qa(question="Which country hosts the ITER fusion reactor construction site?")
print(f"Answer: {result}")


Answer: France


In [8]:
# Building a Real Application: Research Paper Analyzer
class ExtractMainClaim(dspy.Signature):
    """Extract the main claim or thesis from an academic abstract."""

    abstract = dspy.InputField(desc="academic paper abstract")
    main_claim = dspy.OutputField(desc="the primary claim or contribution in one sentence")


class IdentifyMethods(dspy.Signature):
    """Identify research methods used in the paper."""

    abstract = dspy.InputField()
    methods = dspy.OutputField(desc="list of research methods, separated by semicolons")


class AssessNovelty(dspy.Signature):
    """Assess the novelty of the research contribution."""

    abstract = dspy.InputField()
    main_claim = dspy.InputField()
    novelty_score = dspy.OutputField(desc="integer from 1 to 10")
    novelty_explanation = dspy.OutputField(desc="brief explanation of score")


class GenerateSummary(dspy.Signature):
    """Generate a structured summary of the research."""

    abstract = dspy.InputField()
    main_claim = dspy.InputField()
    methods = dspy.InputField()
    novelty_score = dspy.InputField()
    summary = dspy.OutputField(desc="3-4 sentence summary for a general audience")


class ResearchPaperAnalyzer(dspy.Module):
    def __init__(self):
        super().__init__()
        # Initialize our predictors
        self.extract_claim = dspy.Predict(ExtractMainClaim)
        self.identify_methods = dspy.Predict(IdentifyMethods)
        self.assess_novelty = dspy.Predict(AssessNovelty)
        self.generate_summary = dspy.Predict(GenerateSummary)

    def forward(self, abstract):
        # Extract main claim
        claim = self.extract_claim(abstract=abstract).main_claim

        # Identify methods
        methods = self.identify_methods(abstract=abstract).methods

        # Assess novelty
        novelty_assessment = self.assess_novelty(
            abstract=abstract,
            main_claim=claim
        )

        # Generate final summary
        summary = self.generate_summary(
            abstract=abstract,
            main_claim=claim,
            methods=methods,
            novelty_score=novelty_assessment.novelty_score
        ).summary

        return dspy.Prediction(
            main_claim=claim,
            methods=methods,
            novelty_score=novelty_assessment.novelty_score,
            novelty_explanation=novelty_assessment.novelty_explanation,
            summary=summary
        )


analyzer = ResearchPaperAnalyzer()
test_abstract = """
We present GridSketch, an orchestration framework for hybrid microgrids combining battery storage, electrolyzers, and adaptive demand forecasting. GridSketch learns hierarchical control policies that coordinate solar, wind, and hydrogen subsystems using graph neural optimizers trained on dispatch logs from three island grids. Field trials in Barbados, Martinique, and Guadeloupe cut diesel peaker usage by 38% while maintaining reserve margins during cyclone-driven outages. Modular APIs let operators plug in market signals and carbon pricing objectives without rewriting controllers.
"""

result = analyzer(abstract=test_abstract)
print(f"Main Claim: {result.main_claim}\n")
print(f"Methods: {result.methods}\n")
print(f"Novelty Score: {result.novelty_score}/10\n")
print(f"Explanation: {result.novelty_explanation}\n")
print(f"Summary: {result.summary}")


Main Claim: GridSketch is an orchestration framework for hybrid microgrids that reduces diesel peaker usage by 38% in island grids while maintaining reserve margins during cyclone-driven outages.

Methods: machine learning; hierarchical control; field trials; dispatch logs analysis; modular API integration

Novelty Score: 7/10

Explanation: The research contribution of GridSketch lies in its development as an orchestration framework for hybrid microgrids that effectively reduces diesel peaker usage by 38% in island grids while ensuring reserve margins during cyclone-driven outages. This novel approach combines battery storage, electrolyzers, and adaptive demand forecasting, along with hierarchical control policies and graph neural optimizers trained on dispatch logs. The modular APIs further enhance its adaptability by allowing operators to incorporate market signals and carbon pricing objectives without the need to rewrite controllers.

Summary: GridSketch is a novel orchestration fra

In [9]:
# The Power of Optimization: Bootstrap Few-Shot Learning
def create_training_examples():
    """Create some training data for optimization"""
    examples = []

    examples.append(dspy.Example(
        abstract="""We introduce GridSketch, an orchestration layer for hybrid microgrids that fuses battery scheduling, electrolyzer control, and load forecasting through graph neural optimizers trained on island utility data.""",
        main_claim="GridSketch coordinates hybrid microgrids using hierarchical graph control.",
        novelty_score="8"
    ).with_inputs('abstract'))

    examples.append(dspy.Example(
        abstract="""HarborCast couples AIS shipping telemetry with ensemble weather forecasts to generate 48-hour energy demand scenarios for ports, guiding crane electrification and hydrogen storage commitments.""",
        main_claim="HarborCast fuses shipping traffic and weather ensembles to forecast port energy demand.",
        novelty_score="7"
    ).with_inputs('abstract'))

    examples.append(dspy.Example(
        abstract="""CryoSense deploys fiber Bragg grating arrays across fusion cryostats and learns thermal fingerprints that anticipate disruptive heat loads minutes before quench events.""",
        main_claim="CryoSense maps fusion cryostat heat loads with fiber Bragg gratings to prevent quenches.",
        novelty_score="9"
    ).with_inputs('abstract'))

    return examples


def validate_analysis(example, pred, trace=None):
    """Check if the analysis is reasonable"""
    claim_valid = len(pred.main_claim) > 10 and len(pred.main_claim) < 200

    try:
        score = int(pred.novelty_score)
        score_valid = 1 <= score <= 10
    except Exception:
        score_valid = False

    summary_valid = len(pred.summary) > 50 and len(pred.summary) < 500

    return claim_valid and score_valid and summary_valid


trainset = create_training_examples()
optimizer = BootstrapFewShotWithRandomSearch(
    metric=validate_analysis,
    max_bootstrapped_demos=2,
    max_labeled_demos=2,
    num_candidate_programs=5,
    num_threads=1
)

compiled_analyzer = optimizer.compile(
    ResearchPaperAnalyzer(),
    trainset=trainset
)

print("Optimization complete! Let's test the improved analyzer:")
optimized = compiled_analyzer(abstract=test_abstract)
print(f"\nOptimized Results:")
print(f"Main Claim: {optimized.main_claim}")
print(f"Novelty Score: {optimized.novelty_score}/10")
print(f"Summary: {optimized.summary}")


Going to sample between 1 and 2 traces per predictor.
Will attempt to bootstrap 5 candidate sets.



  0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 1.00 / 1 (100.0%):   0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 2.00 / 2 (100.0%):  33%|███▎      | 1/3 [00:00<00:00, 11.94it/s]


Average Metric: 3.00 / 3 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 13.98it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 20.68it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 20.39it/s]

2025/09/21 06:06:34 INFO dspy.evaluate.evaluate: Average Metric: 3 / 3 (100.0%)



New best score: 100.0 for seed -3
Scores so far: [100.0]
Best score so far: 100.0



  0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 1.00 / 1 (100.0%):   0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 2.00 / 2 (100.0%):  33%|███▎      | 1/3 [00:00<00:00,  6.58it/s]


Average Metric: 2.00 / 2 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 12.98it/s]


Average Metric: 3.00 / 3 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 12.98it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 12.87it/s]

2025/09/21 06:06:35 INFO dspy.evaluate.evaluate: Average Metric: 3 / 3 (100.0%)



Scores so far: [100.0, 100.0]
Best score so far: 100.0



  0%|          | 0/3 [00:00<?, ?it/s]


 33%|███▎      | 1/3 [00:00<00:00,  7.08it/s]


 67%|██████▋   | 2/3 [00:00<00:00,  5.26it/s]


 67%|██████▋   | 2/3 [00:00<00:00,  5.44it/s]




Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.



  0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 1.00 / 1 (100.0%):   0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 2.00 / 2 (100.0%):  33%|███▎      | 1/3 [00:00<00:00,  4.85it/s]


Average Metric: 2.00 / 2 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00,  9.65it/s]


Average Metric: 3.00 / 3 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00,  9.65it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00,  6.40it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00,  6.83it/s]

2025/09/21 06:06:35 INFO dspy.evaluate.evaluate: Average Metric: 3 / 3 (100.0%)



Scores so far: [100.0, 100.0, 100.0]
Best score so far: 100.0



  0%|          | 0/3 [00:00<?, ?it/s]


 33%|███▎      | 1/3 [00:00<00:00,  7.94it/s]


 67%|██████▋   | 2/3 [00:00<00:00,  7.72it/s]


 67%|██████▋   | 2/3 [00:00<00:00,  7.72it/s]




Bootstrapped 2 full traces after 2 examples for up to 1 rounds, amounting to 2 attempts.



  0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 1.00 / 1 (100.0%):   0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 2.00 / 2 (100.0%):  33%|███▎      | 1/3 [00:00<00:00,  9.28it/s]


Average Metric: 2.00 / 2 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 18.39it/s]


Average Metric: 3.00 / 3 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 18.39it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 18.44it/s]

2025/09/21 06:06:36 INFO dspy.evaluate.evaluate: Average Metric: 3 / 3 (100.0%)



Scores so far: [100.0, 100.0, 100.0, 100.0]
Best score so far: 100.0



  0%|          | 0/3 [00:00<?, ?it/s]


 33%|███▎      | 1/3 [00:00<00:00,  9.19it/s]


 33%|███▎      | 1/3 [00:00<00:00,  9.11it/s]




Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.



  0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 1.00 / 1 (100.0%):   0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 2.00 / 2 (100.0%):  33%|███▎      | 1/3 [00:00<00:00,  9.85it/s]


Average Metric: 2.00 / 2 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 19.51it/s]


Average Metric: 3.00 / 3 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 19.51it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 19.95it/s]

2025/09/21 06:06:36 INFO dspy.evaluate.evaluate: Average Metric: 3 / 3 (100.0%)



Scores so far: [100.0, 100.0, 100.0, 100.0, 100.0]
Best score so far: 100.0



  0%|          | 0/3 [00:00<?, ?it/s]


 33%|███▎      | 1/3 [00:00<00:00,  9.97it/s]


 33%|███▎      | 1/3 [00:00<00:00,  9.87it/s]




Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.



  0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 1.00 / 1 (100.0%):   0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 2.00 / 2 (100.0%):  33%|███▎      | 1/3 [00:00<00:00, 11.77it/s]


Average Metric: 3.00 / 3 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 14.56it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 21.84it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 20.82it/s]

2025/09/21 06:06:36 INFO dspy.evaluate.evaluate: Average Metric: 3 / 3 (100.0%)



Scores so far: [100.0, 100.0, 100.0, 100.0, 100.0, 100.0]
Best score so far: 100.0



  0%|          | 0/3 [00:00<?, ?it/s]


 33%|███▎      | 1/3 [00:00<00:00,  8.01it/s]


 33%|███▎      | 1/3 [00:00<00:00,  7.95it/s]




Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.



  0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 1.00 / 1 (100.0%):   0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 2.00 / 2 (100.0%):  33%|███▎      | 1/3 [00:00<00:00, 10.87it/s]


Average Metric: 3.00 / 3 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 14.10it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 20.86it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 20.71it/s]

2025/09/21 06:06:37 INFO dspy.evaluate.evaluate: Average Metric: 3 / 3 (100.0%)



Scores so far: [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0]
Best score so far: 100.0



  0%|          | 0/3 [00:00<?, ?it/s]


 33%|███▎      | 1/3 [00:00<00:00,  7.73it/s]


 33%|███▎      | 1/3 [00:00<00:00,  7.67it/s]




Bootstrapped 1 full traces after 1 examples for up to 1 rounds, amounting to 1 attempts.



  0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 1.00 / 1 (100.0%):   0%|          | 0/3 [00:00<?, ?it/s]


Average Metric: 2.00 / 2 (100.0%):  33%|███▎      | 1/3 [00:00<00:00,  7.51it/s]


Average Metric: 2.00 / 2 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 14.80it/s]


Average Metric: 3.00 / 3 (100.0%):  67%|██████▋   | 2/3 [00:00<00:00, 14.80it/s]


Average Metric: 3.00 / 3 (100.0%): 100%|██████████| 3/3 [00:00<00:00, 16.46it/s]

2025/09/21 06:06:37 INFO dspy.evaluate.evaluate: Average Metric: 3 / 3 (100.0%)



Scores so far: [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0]
Best score so far: 100.0
8 candidate programs found.
Optimization complete! Let's test the improved analyzer:



Optimized Results:
Main Claim: GridSketch is an orchestration framework for hybrid microgrids that reduces diesel peaker usage by 38% in island grids while maintaining reserve margins during cyclone-driven outages.
Novelty Score: 7/10
Summary: GridSketch is a novel orchestration framework for hybrid microgrids that effectively reduces diesel peaker usage by 38% in island grids, while ensuring reserve margins are maintained during cyclone-driven outages. The framework utilizes machine learning techniques, hierarchical control policies, and field trials in Barbados, Martinique, and Guadeloupe to achieve this significant reduction. Additionally, its modular APIs allow for easy integration of market signals and carbon pricing objectives without the need for rewriting controllers.


In [10]:
# Advanced Pattern: Chain of Thought Reasoning
class ChainOfThoughtQA(dspy.Module):
    def __init__(self):
        super().__init__()
        # Use Chain of Thought for complex reasoning
        self.generate_answer = dspy.ChainOfThought("question -> reasoning, answer")

    def forward(self, question):
        result = self.generate_answer(question=question)
        return result


cot_qa = ChainOfThoughtQA()
complex_question = """
A climate observatory stores 120 GB of raw satellite data per day. Adding a hyperspectral sensor increases the volume by 35%, and automated filtering removes 12% of the combined data. How many gigabytes are archived each day after these changes?
"""

cot_result = cot_qa(question=complex_question)
print(f"Reasoning: {cot_result.reasoning}\n")
print(f"Answer: {cot_result.answer}")


Reasoning: To find the total volume of data archived each day after the changes, we need to follow these steps:
1. Calculate the increase in volume after adding the hyperspectral sensor (35% of 120 GB).
2. Determine the total volume after adding the hyperspectral sensor.
3. Calculate the data removed after filtering (12% of the total volume).
4. Subtract the data removed from the total volume to find the final amount of data archived each day.

Answer: Let's calculate it step by step:
1. Increase after adding hyperspectral sensor: 35% of 120 GB = 0.35 * 120 = 42 GB
2. Total volume after adding sensor: 120 GB + 42 GB = 162 GB
3. Data removed after filtering: 12% of 162 GB = 0.12 * 162 = 19.44 GB
4. Final data archived each day: 162 GB - 19.44 GB = 142.56 GB


In [11]:
# Building a Multi-Hop Question Answering System
class MultiHopQA(dspy.Module):
    def __init__(self, passages_per_hop=3):
        super().__init__()
        self.passages_per_hop = passages_per_hop

        self.generate_query = dspy.ChainOfThought("context, question -> reasoning, query")
        self.retrieve = dspy.Retrieve(k=passages_per_hop)
        self.generate_answer = dspy.ChainOfThought("context, question -> reasoning, answer")

    def forward(self, question):
        context = []
        for hop in range(2):
            query_result = self.generate_query(
                context="\n".join(context) if context else "No context yet",
                question=question
            )
            passages = self.retrieve(query_result.query).passages
            context.extend(passages[:2])

        final_answer = self.generate_answer(
            context="\n".join(context),
            question=question
        )

        return dspy.Prediction(
            answer=final_answer.answer,
            reasoning=final_answer.reasoning,
            supporting_passages=context
        )

print("✅ MultiHopQA module skeleton ready. Configure a retriever before calling it.")


✅ MultiHopQA module skeleton ready. Configure a retriever before calling it.


In [12]:
# Practical Example: Building a Code Documentation Generator
class ExtractFunctionInfo(dspy.Signature):
    """Extract key information from Python function code."""

    code = dspy.InputField(desc="Python function code")
    function_name = dspy.OutputField()
    parameters = dspy.OutputField(desc="list of parameters with types if available")
    return_type = dspy.OutputField(desc="return type if specified, otherwise 'inferred'")


class GenerateDocstring(dspy.Signature):
    """Generate a comprehensive docstring for a Python function."""

    code = dspy.InputField()
    function_name = dspy.InputField()
    parameters = dspy.InputField()
    return_type = dspy.InputField()
    docstring = dspy.OutputField(desc="Google-style docstring without the triple quotes")


class GenerateUsageExample(dspy.Signature):
    """Generate a usage example for the function."""

    function_name = dspy.InputField()
    parameters = dspy.InputField()
    docstring = dspy.InputField()
    example_code = dspy.OutputField(desc="simple example showing function usage")


class CodeDocumentationGenerator(dspy.Module):
    def __init__(self):
        super().__init__()
        self.extract_info = dspy.Predict(ExtractFunctionInfo)
        self.generate_docstring = dspy.Predict(GenerateDocstring)
        self.generate_example = dspy.Predict(GenerateUsageExample)

    def forward(self, code):
        info = self.extract_info(code=code)

        docstring = self.generate_docstring(
            code=code,
            function_name=info.function_name,
            parameters=info.parameters,
            return_type=info.return_type
        ).docstring

        example = self.generate_example(
            function_name=info.function_name,
            parameters=info.parameters,
            docstring=docstring
        ).example_code

        documentation = (
            '"""\n'
            + docstring
            + '\n\nExample:\n    >>> '
            + example
            + '\n"""'
        )

        return dspy.Prediction(
            function_name=info.function_name,
            documentation=documentation,
            docstring=docstring,
            example=example
        )


doc_generator = CodeDocumentationGenerator()
test_code = '''
import collections

def aggregate_power_windows(readings: list[dict], window_minutes: int = 15) -> dict[str, float]:
    """Collapse high-frequency sensor readings into rolling demand windows."""
    buckets: dict[str, list[float]] = collections.defaultdict(list)
    for entry in readings:
        buckets[entry["meter_id"]].append(entry["kilowatts"])
    return {
        meter: sum(values) / len(values)
        for meter, values in buckets.items()
        if values
    }
'''

documentation_result = doc_generator(code=test_code)
print(f"Function: {documentation_result.function_name}\n")
print("Generated Documentation:")
print(documentation_result.documentation)


Function: aggregate_power_windows

Generated Documentation:
"""
"""
Collapse high-frequency sensor readings into rolling demand windows.

Args:
    readings (list[dict]): A list of dictionaries containing sensor readings.
    window_minutes (int, optional): The size of the rolling demand windows in minutes. Defaults to 15.

Returns:
    dict[str, float]: A dictionary where keys are meter IDs and values are the average kilowatt readings within the window.
"""

Example:
    >>> # Example of using aggregate_power_windows function
readings = [
    {"meter_id": "A", "kilowatt": 10.5},
    {"meter_id": "B", "kilowatt": 20.3},
    {"meter_id": "A", "kilowatt": 9.2},
    {"meter_id": "B", "kilowatt": 21.1},
    {"meter_id": "A", "kilowatt": 11.7}
]

window_minutes = 10
result = aggregate_power_windows(readings, window_minutes)
print(result)
# Output: {'A': 9.85, 'B': 20.7}
"""


In [13]:
# Evaluation and Testing
def create_test_set():
    """Create test examples for our documentation generator"""
    test_examples = []

    test_examples.append(dspy.Example(
        code='''def compute_capacity_factor(energy_mwh: float, rated_power_mw: float) -> float:\n    return energy_mwh / (rated_power_mw * 8760)''',
        function_name="compute_capacity_factor"
    ).with_inputs('code'))

    test_examples.append(dspy.Example(
        code='''def rolling_mean(series: list[float], window: int) -> list[float]:\n    return [sum(series[i:i+window]) / window for i in range(len(series) - window + 1)]''',
        function_name="rolling_mean"
    ).with_inputs('code'))

    test_examples.append(dspy.Example(
        code='''def fuse_weather_signals(temperature: float, wind_speed: float, humidity: float) -> float:\n    return 0.4 * temperature + 0.35 * wind_speed + 0.25 * humidity''',
        function_name="fuse_weather_signals"
    ).with_inputs('code'))

    return test_examples


def documentation_metric(example, pred, trace=None):
    """Evaluate documentation quality"""
    name_correct = pred.function_name == example.function_name
    doc_exists = len(pred.documentation) > 50
    example_exists = len(pred.example) > 10
    return name_correct and doc_exists and example_exists


evaluator = Evaluate(
    devset=create_test_set(),
    num_threads=1,
    display_progress=False,
    display_table=5
)

raw_score = evaluator(doc_generator, metric=documentation_metric)
score_value = getattr(raw_score, "score", raw_score)
score_fraction = score_value / 100 if score_value > 1 else score_value
print(f"\nEvaluation Score: {score_fraction:.2%}")


2025/09/21 06:06:38 INFO dspy.evaluate.evaluate: Average Metric: 3 / 3 (100.0%)


Unnamed: 0,code,example_function_name,pred_function_name,documentation,docstring,example,documentation_metric
0,"def compute_capacity_factor(energy_mwh: float, rated_power_mw: flo...",compute_capacity_factor,compute_capacity_factor,""""""" Calculates the capacity factor of a power plant based on the e...",Calculates the capacity factor of a power plant based on the energ...,# Example usage of the compute_capacity_factor function\nenergy_mw...,✔️ [True]
1,"def rolling_mean(series: list[float], window: int) -> list[float]:...",rolling_mean,rolling_mean,""""""" Calculates the rolling mean of a series with a specified windo...",Calculates the rolling mean of a series with a specified window si...,"# Example of using rolling_mean function\nseries = [1.2, 2.3, 3.4,...",✔️ [True]
2,"def fuse_weather_signals(temperature: float, wind_speed: float, hu...",fuse_weather_signals,fuse_weather_signals,""""""" Calculates a fused weather signal based on temperature, wind s...","Calculates a fused weather signal based on temperature, wind speed...",# Example usage of fuse_weather_signals function\nfused_signal = f...,✔️ [True]



Evaluation Score: 100.00%


In [14]:
class ValidatedQA(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        pred = self.generate_answer(question=question)

        dspy.Suggest(
            len(pred.answer) > 10,
            "Answer should be at least 10 characters long"
        )

        dspy.Suggest(
            not pred.answer.lower().startswith("i don't know"),
            "Answer should be informative"
        )

        return pred


validated_qa = ValidatedQA()
try:
    answer = validated_qa(question="What is the primary purpose of a synthetic aperture radar constellation in climate monitoring?")
    print(f"Answer: {answer.answer}")
except Exception as exc:
    print(f"Validation failed: {exc}")


Validation failed: module 'dspy' has no attribute 'Suggest'


In [15]:
# Custom Teleprompters for Optimization
class CustomOptimizer:
    def __init__(self, metric, examples):
        self.metric = metric
        self.examples = examples

    def compile(self, module):
        strategies = []

        if len(self.examples) > 0:
            labeled = LabeledFewShot(k=min(3, len(self.examples)))
            strategies.append(labeled.compile(module, trainset=self.examples))

        bootstrap = BootstrapFewShot(
            metric=self.metric,
            max_bootstrapped_demos=4
        )
        strategies.append(bootstrap.compile(module, trainset=self.examples))

        best_score = 0
        best_module = module

        for strategy in strategies:
            score = self.evaluate_module(strategy)
            if score > best_score:
                best_score = score
                best_module = strategy

        return best_module

    def evaluate_module(self, module):
        correct = 0
        for example in self.examples[:3]:
            try:
                pred = module(example.inputs())
                if self.metric(example, pred):
                    correct += 1
            except Exception:
                pass
        denominator = min(3, len(self.examples)) or 1
        return correct / denominator


custom_optimizer = CustomOptimizer(metric=validate_analysis, examples=create_training_examples())
optimized_module = custom_optimizer.compile(ResearchPaperAnalyzer())
print("Custom optimizer selected:", type(optimized_module).__name__)



  0%|          | 0/3 [00:00<?, ?it/s]


 33%|███▎      | 1/3 [00:00<00:00,  9.41it/s]


 67%|██████▋   | 2/3 [00:00<00:00,  9.67it/s]


100%|██████████| 3/3 [00:00<00:00,  9.32it/s]


100%|██████████| 3/3 [00:00<00:00,  9.26it/s]




Bootstrapped 3 full traces after 2 examples for up to 1 rounds, amounting to 3 attempts.


Custom optimizer selected: ResearchPaperAnalyzer


In [16]:
# Parallel Processing for Batch Operations
class BatchProcessor(dspy.Module):
    def __init__(self, base_module, max_workers=4):
        super().__init__()
        self.base_module = base_module
        self.max_workers = max_workers

    def forward(self, items):
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [executor.submit(self.base_module, item) for item in items]
            results = [future.result() for future in futures]
        return results


qa_module = SimpleQAModule()
batch_processor = BatchProcessor(qa_module)

questions = [
    "Which country operates the Atacama Large Millimeter/submillimeter Array?",
    "Where is the headquarters of the International Renewable Energy Agency located?",
    "Which city hosts Europe's largest battery gigafactory Northvolt Ett?",
    "Which country launched the Himawari-9 weather satellite?"
]

batch_results = batch_processor(questions)
for q, r in zip(questions, batch_results):
    print(f"Q: {q}\nA: {r}\n")


Q: Which country operates the Atacama Large Millimeter/submillimeter Array?
A: Chile

Q: Where is the headquarters of the International Renewable Energy Agency located?
A: Abu Dhabi, United Arab Emirates

Q: Which city hosts Europe's largest battery gigafactory Northvolt Ett?
A: Skellefteå

Q: Which country launched the Himawari-9 weather satellite?
A: Japan



In [17]:
# Real-World Integration: Building an API
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()
doc_generator_api = CodeDocumentationGenerator()


class CodeRequest(BaseModel):
    code: str


class DocumentationResponse(BaseModel):
    function_name: str
    documentation: str
    example: str


@app.post("/generate-docs", response_model=DocumentationResponse)
async def generate_documentation(request: CodeRequest):
    try:
        result = doc_generator_api(code=request.code)
        return DocumentationResponse(
            function_name=result.function_name,
            documentation=result.documentation,
            example=result.example
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))


@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "DSPy Documentation Generator"}

print("✅ FastAPI app configured. Run uvicorn main:app --reload in a terminal to try it out.")


✅ FastAPI app configured. Run uvicorn main:app --reload in a terminal to try it out.


In [18]:
# Performance Optimization Tips
def save_compiled_module(module, filename):
    """Save a compiled DSPy module to disk"""
    with open(filename, 'wb') as f:
        pickle.dump(module, f)


def load_compiled_module(filename):
    """Load a pre-compiled DSPy module"""
    with open(filename, 'rb') as f:
        return pickle.load(f)


class CachedModule(dspy.Module):
    def __init__(self, base_module):
        super().__init__()
        self.base_module = base_module
        self.cache = {}

    def _hash_input(self, input_data):
        """Create a hash of the input for caching"""
        return hashlib.md5(str(input_data).encode()).hexdigest()

    def forward(self, **kwargs):
        cache_key = self._hash_input(kwargs)
        if cache_key not in self.cache:
            self.cache[cache_key] = self.base_module(**kwargs)
        return self.cache[cache_key]


def process_in_batches(module, items, batch_size=10):
    """Process items in batches to optimize LLM calls"""
    results = []
    for i in range(0, len(items), batch_size):
        batch = items[i:i + batch_size]
        batch_results = [module(item) for item in batch]
        results.extend(batch_results)
    return results


if 'compiled_analyzer' in globals():
    save_compiled_module(compiled_analyzer, 'compiled_analyzer.pkl')
    restored = load_compiled_module('compiled_analyzer.pkl')
    print("📦 Saved and reloaded the compiled analyzer from disk.")
else:
    print("ℹ️ Compile a module first, then call save_compiled_module to persist it.")


📦 Saved and reloaded the compiled analyzer from disk.


In [19]:
# Debugging DSPy Programs
dspy.settings.configure(lm=dspy.settings.lm, trace=TRACE_LOG)

class DebugModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict("question -> answer")

    def forward(self, question):
        print(f"Input Question: {question}")
        result = self.predict(question=question)
        print("\nNote: Cannot directly print prompt using .get_prompt() in this DSPy version.")
        print(f"\nResult: {result.answer}")
        return result


debug_module = DebugModule()
debug_result = debug_module(question="What is a geospatial digital twin and how is it used in climate resilience planning?")

print("\nTrace events collected:")
if TRACE_LOG:
    for item in TRACE_LOG[-5:]:
        print(f"- {item}")
else:
    print("(Trace log is empty in this simulated run.)")


Input Question: What is a geospatial digital twin and how is it used in climate resilience planning?

Note: Cannot directly print prompt using .get_prompt() in this DSPy version.

Result: A geospatial digital twin is a digital representation of a physical asset, process, or system that includes geospatial data. It is used in climate resilience planning to simulate and analyze potential impacts of climate change on infrastructure, natural resources, and communities. By incorporating geospatial data into the digital twin, planners can better understand vulnerabilities, identify adaptation strategies, and optimize resilience measures to mitigate the effects of climate change.

Trace events collected:
- (Predict(BasicQA(question -> answer
    instructions='Answer questions with short factual answers.'
    question = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Question:', 'desc': '${question}'})
    answer = Field(annotation=str required=Tr

In [20]:
# Lessons Learned: Start Simple, Then Optimize
simple = dspy.Predict("question -> answer")
complex = dspy.ChainOfThought("context, question -> reasoning, answer")
print("Simple predictor ready:", simple)
print("Chain-of-thought predictor ready:", complex)


Simple predictor ready: Predict(StringSignature(question -> answer
    instructions='Given the fields `question`, produce the fields `answer`.'
    question = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Question:', 'desc': '${question}'})
    answer = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'output', 'prefix': 'Answer:', 'desc': '${answer}'})
))
Chain-of-thought predictor ready: predict = Predict(StringSignature(context, question -> reasoning, answer
    instructions='Given the fields `context`, `question`, produce the fields `reasoning`, `answer`.'
    context = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Context:', 'desc': '${context}'})
    question = Field(annotation=str required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Question:', 'desc': '${question}'})
    reasoning = Field(annotation=str required=True json_schema_extra=

In [21]:
# Lessons Learned: Invest in Good Training Data
def create_high_quality_example(input_data, expected_output):
    """Create well-structured training examples"""
    example = dspy.Example(
        **input_data,
        **expected_output
    ).with_inputs(*input_data.keys())

    assert all(v is not None for v in input_data.values())
    assert all(v is not None for v in expected_output.values())

    return example


example = create_high_quality_example(
    {"question": "Which ports are most exposed to extreme heat disruptions by 2030?"},
    {"answer": "Ports in the Arabian Gulf and Southeast Asia face the highest combined heat risk."}
)
print(example.inputs())


Example({'question': 'Which ports are most exposed to extreme heat disruptions by 2030?'}) (input_keys={'question'})


In [22]:
# Lessons Learned: Use Assertions Wisely
class AssertiveModule(dspy.Module):
    def forward(self, input_text):
        result = self.process(input_text)

        dspy.Suggest(len(result) > 10, "Output too short")
        dspy.Assert(result is not None, "Output cannot be None")

        return result

print("✳️ AssertiveModule defined as a template. Plug in your own process() implementation before calling it.")


✳️ AssertiveModule defined as a template. Plug in your own process() implementation before calling it.


In [23]:
# Lessons Learned: Monitor and Log Performance
class MonitoredModule(dspy.Module):
    def __init__(self, base_module):
        super().__init__()
        self.base_module = base_module
        self.metrics = []

    def forward(self, **kwargs):
        start_time = datetime.now()

        try:
            result = self.base_module(**kwargs)
            success = True
            error = None
        except Exception as exc:
            success = False
            error = str(exc)
            result = None

        elapsed = (datetime.now() - start_time).total_seconds()

        self.metrics.append({
            'timestamp': datetime.now(),
            'success': success,
            'elapsed_time': elapsed,
            'error': error
        })

        if not success:
            raise Exception(error)

        return result

    def get_metrics_summary(self):
        if not self.metrics:
            return "No metrics collected"

        success_rate = sum(m['success'] for m in self.metrics) / len(self.metrics)
        avg_time = sum(m['elapsed_time'] for m in self.metrics) / len(self.metrics)

        return f"Success Rate: {success_rate:.2%}, Avg Time: {avg_time:.2f}s"


monitored = MonitoredModule(lambda **kwargs: f"Synthesised payload: {kwargs}")
demo_output = monitored(payload="Monitor turbine cluster")
print("Monitored output:", demo_output)
print("Metrics summary:", monitored.get_metrics_summary())


Monitored output: Synthesised payload: {'payload': 'Monitor turbine cluster'}
Metrics summary: Success Rate: 100.00%, Avg Time: 0.00s


## Fine-Tuning and Distillation

The last step in the workflow is to distil richer reasoning traces into a lighter module. We start by letting the chain-of-thought teacher answer a set of energy-transition questions, then compile a prompt-optimised student. Finally, we show how a weight-level fine-tune would be wired up when a trainable local LM is available.

In [24]:

QUESTION_BANK = [
    {
        "canonical_question": "Which country hosts the ITER fusion reactor construction site?",
        "answer": "France",
        "aliases": ["France", "Southern France", "South of France"],
        "paraphrases": [
            "In which European country is the ITER tokamak being assembled?",
            "Which nation hosts the construction of the ITER fusion reactor?",
            "Where is the ITER fusion megaproject located?"
        ],
        "context": "The ITER experimental fusion reactor is being constructed in southern France near the Cadarache research center."
    },
    {
        "canonical_question": "What is the capital of the country that operates the ALMA observatory?",
        "answer": "Santiago",
        "aliases": ["Santiago", "Santiago de Chile"],
        "paraphrases": [
            "Name the capital city of the nation that operates the ALMA observatory.",
            "What city is the capital of Chile, the country running ALMA?",
            "Which city is the capital of the country that owns the ALMA array?"
        ],
        "context": "ALMA is operated by Chile; Chile's capital city is Santiago."
    },
    {
        "canonical_question": "Which nation leads the world's installed offshore wind capacity in 2024?",
        "answer": "China",
        "aliases": ["China", "People's Republic of China", "PRC"],
        "paraphrases": [
            "Which nation tops global installed offshore wind capacity in 2024?",
            "Who leads the world in offshore wind installations going into 2024?",
            "Identify the country with the most offshore wind megawatts in 2024."
        ],
        "context": "By 2024 China has the world's largest installed offshore wind capacity."
    },
    {
        "canonical_question": "Which city is home to the International Renewable Energy Agency headquarters?",
        "answer": "Abu Dhabi",
        "aliases": ["Abu Dhabi", "Abu Dhabi, UAE", "Abu Dhabi (UAE)"],
        "paraphrases": [
            "Which city hosts the headquarters of the International Renewable Energy Agency?",
            "Where is IRENA's headquarters located?",
            "Name the UAE city that houses the International Renewable Energy Agency."
        ],
        "context": "The International Renewable Energy Agency (IRENA) maintains its headquarters in Abu Dhabi, United Arab Emirates."
    },
    {
        "canonical_question": "Which company operates the Hornsea offshore wind complex?",
        "answer": "?rsted",
        "aliases": ["?rsted", "Orsted", "Orsted A/S"],
        "paraphrases": [
            "Which company runs the Hornsea offshore wind farm complex?",
            "Name the developer responsible for operating the Hornsea wind farms.",
            "Who manages the Hornsea offshore wind projects in the North Sea?"
        ],
        "context": "The Hornsea offshore wind farms are developed and operated by ?rsted."
    },
    {
        "canonical_question": "Which river valley hosts the Three Gorges Dam?",
        "answer": "Yangtze River",
        "aliases": ["Yangtze River", "Yangtze", "Chang Jiang"],
        "paraphrases": [
            "The Three Gorges Dam spans which river valley?",
            "Identify the river valley home to the Three Gorges Dam.",
            "Which river does the Three Gorges Dam harness?"
        ],
        "context": "China's Three Gorges Dam spans the Yangtze River valley."
    },
    {
        "canonical_question": "Which desert houses the Noor Ouarzazate solar complex?",
        "answer": "Sahara Desert",
        "aliases": ["Sahara Desert", "Sahara", "Moroccan Sahara"],
        "paraphrases": [
            "In which desert is the Noor Ouarzazate solar complex located?",
            "Name the desert that hosts Morocco's Noor solar complex.",
            "Which desert houses the Noor Ouarzazate CSP facility?"
        ],
        "context": "Morocco's Noor Ouarzazate solar complex is located on the edge of the Sahara Desert."
    },
    {
        "canonical_question": "Which country launched the Himawari-9 weather satellite?",
        "answer": "Japan",
        "aliases": ["Japan", "Japanese government", "Japan Meteorological Agency"],
        "paraphrases": [
            "Which country launched the Himawari-9 weather satellite?",
            "Identify the nation responsible for placing Himawari-9 into orbit.",
            "Himawari-9 was launched by which country's space agency?"
        ],
        "context": "The Himawari weather satellites are operated and launched by Japan's meteorological agency."
    },
    {
        "canonical_question": "Which U.S. state hosts the National Renewable Energy Laboratory?",
        "answer": "Colorado",
        "aliases": ["Colorado", "State of Colorado", "Colorado State"],
        "paraphrases": [
            "Which U.S. state is home to the National Renewable Energy Laboratory?",
            "Identify the U.S. state where NREL is headquartered.",
            "In which state can you find the National Renewable Energy Laboratory campus?"
        ],
        "context": "The U.S. National Renewable Energy Laboratory (NREL) is headquartered in Golden, Colorado."
    },
    {
        "canonical_question": "Which city hosted the COP28 climate summit in 2023?",
        "answer": "Dubai",
        "aliases": ["Dubai", "Dubai, UAE", "Dubai in the UAE"],
        "paraphrases": [
            "Which city hosted the COP28 climate summit in 2023?",
            "Name the city that welcomed COP28 in 2023.",
            "Where was the 2023 COP28 climate conference held?"
        ],
        "context": "The COP28 climate summit in 2023 was hosted in Dubai, United Arab Emirates."
    },
    {
        "canonical_question": "Which company built the battery gigafactory Northvolt Ett?",
        "answer": "Northvolt",
        "aliases": ["Northvolt", "Northvolt AB", "Northvolt Ett"],
        "paraphrases": [
            "Which company built the Northvolt Ett battery gigafactory?",
            "Name the firm that constructed the Northvolt Ett plant in Sweden.",
            "Who developed the Northvolt Ett battery factory?"
        ],
        "context": "The Northvolt Ett battery gigafactory in Sweden was built by the company Northvolt."
    },
    {
        "canonical_question": "Which sea hosts the Dogger Bank offshore wind project?",
        "answer": "North Sea",
        "aliases": ["North Sea", "The North Sea"],
        "paraphrases": [
            "The Dogger Bank offshore wind farm is located in which sea?",
            "Identify the sea that hosts the Dogger Bank wind project.",
            "Dogger Bank wind arrays operate in what sea?"
        ],
        "context": "The Dogger Bank offshore wind arrays are located in the North Sea."
    }
]

CANONICAL_SYNONYMS = {}
QUESTION_TO_CANONICAL = {}
QUESTION_CONTEXT = {}
teacher_trace = []
distillation_examples = []

def _register_answer(answer, aliases):
    canonical = answer.strip().lower()
    synonyms = CANONICAL_SYNONYMS.setdefault(canonical, set())
    synonyms.add(canonical)
    for alias in aliases:
        synonyms.add(alias.strip().lower())

for entry in QUESTION_BANK:
    _register_answer(entry["answer"], entry.get("aliases", []))
    variations = [entry["canonical_question"]] + entry.get("paraphrases", [])
    context = entry.get("context", "").strip()
    for variant in variations:
        key = variant.strip().lower()
        QUESTION_TO_CANONICAL[key] = entry["answer"]
        if context:
            QUESTION_CONTEXT[key] = context
        pred = cot_qa(question=variant)
        teacher_trace.append({
            "question": variant,
            "teacher_answer": pred.answer,
            "reasoning": pred.reasoning
        })
        distillation_examples.append(
            dspy.Example(question=variant, answer=entry["answer"]).with_inputs('question')
        )

print(f"Generated {len(distillation_examples)} distillation pairs across {len(QUESTION_BANK)} canonical questions.")
print("Sample teacher trace:")
for sample in teacher_trace[:5]:
    print(f"- Q: {sample['question']}")
    print(f"  Teacher A: {sample['teacher_answer']}")
    print(f"  Reasoning: {sample['reasoning']}")


Generated 48 distillation pairs across 12 canonical questions.
Sample teacher trace:
- Q: Which country hosts the ITER fusion reactor construction site?
  Teacher A: France
  Reasoning: The ITER fusion reactor construction site is located in the southern region of France.
- Q: In which European country is the ITER tokamak being assembled?
  Teacher A: France
  Reasoning: The ITER tokamak is being assembled in France. The project is a collaboration between 35 countries, including the European Union, China, India, Japan, Russia, South Korea, and the United States. The construction site is located in Cadarache, a commune in the Provence-Alpes-Côte d'Azur region in southeastern France.
- Q: Which nation hosts the construction of the ITER fusion reactor?
  Teacher A: France
  Reasoning: The ITER fusion reactor is being constructed in France.
- Q: Where is the ITER fusion megaproject located?
  Teacher A: Southern France
  Reasoning: The ITER fusion megaproject is located in southern France.

In [25]:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

INFERENCE_PROMPT = "Question: {question}\nAnswer in one short phrase:"

def build_inference_prompt(question: str, use_context: bool = True) -> str:
    if use_context:
        key = question.strip().lower()
        context = QUESTION_CONTEXT.get(key)
        if context:
            return f"Context: {context}\n" + INFERENCE_PROMPT.format(question=question)
    return INFERENCE_PROMPT.format(question=question)

def canonicalize_answer(text):
    normalized = text.strip().lower()
    for canonical, aliases in CANONICAL_SYNONYMS.items():
        if normalized in aliases:
            return canonical
    return normalized


def compute_accuracy(model, tokenizer, examples, use_context: bool = True):
    correct = 0
    predictions = []

    for example in examples:
        prompt = build_inference_prompt(example.question, use_context=use_context)
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=48)
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        predictions.append((example.question, prediction))

        predicted_canonical = canonicalize_answer(prediction)
        target_canonical = canonicalize_answer(example.answer)
        if predicted_canonical == target_canonical:
            correct += 1

    return correct / len(examples), predictions


base_model_name = "google/flan-t5-small"
baseline_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
baseline_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

baseline_raw_accuracy, _ = compute_accuracy(
    baseline_model,
    baseline_tokenizer,
    distillation_examples,
    use_context=False
)

baseline_context_accuracy, baseline_predictions = compute_accuracy(
    baseline_model,
    baseline_tokenizer,
    distillation_examples,
    use_context=True
)

print(f"Baseline student accuracy without context: {baseline_raw_accuracy:.2%}")
print(f"Baseline student accuracy with context: {baseline_context_accuracy:.2%}")
for question, prediction in baseline_predictions[:10]:
    print(f"- {question} -> {prediction} (canonical: {canonicalize_answer(prediction)})")


Baseline student accuracy without context: 0.00%
Baseline student accuracy with context: 87.50%
- Which country hosts the ITER fusion reactor construction site? -> France (canonical: france)
- In which European country is the ITER tokamak being assembled? -> France (canonical: france)
- Which nation hosts the construction of the ITER fusion reactor? -> France (canonical: france)
- Where is the ITER fusion megaproject located? -> southern France (canonical: france)
- What is the capital of the country that operates the ALMA observatory? -> Santiago (canonical: santiago)
- Name the capital city of the nation that operates the ALMA observatory. -> Santiago (canonical: santiago)
- What city is the capital of Chile, the country running ALMA? -> Santiago (canonical: santiago)
- Which city is the capital of the country that owns the ALMA array? -> Santiago (canonical: santiago)
- Which nation leads the world's installed offshore wind capacity in 2024? -> China (canonical: china)
- Which natio

In [26]:

from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments
)
import os

DISTILLATION_PROMPT = "Question: {question}\nTeacher reasoning: {reasoning}\nAnswer in one short phrase:"
MODEL_DIR = "ft_flan_t5_simpleqa"


def build_training_prompt(question: str, reasoning: str) -> str:
    key = question.strip().lower()
    context = QUESTION_CONTEXT.get(key)
    base = DISTILLATION_PROMPT.format(question=question, reasoning=reasoning)
    if context:
        return f"Context: {context}\n" + base
    return base

train_inputs = []
train_targets = []
for example, trace in zip(distillation_examples, teacher_trace):
    reasoning = trace["reasoning"].strip()
    prompt = build_training_prompt(trace["question"], reasoning)
    train_inputs.append(prompt)
    train_targets.append(example.answer)

train_pairs = Dataset.from_dict({
    "input_text": train_inputs,
    "target_text": train_targets,
})

model_name = "google/flan-t5-small"
ft_tokenizer = AutoTokenizer.from_pretrained(model_name)
ft_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


def preprocess(batch):
    model_inputs = ft_tokenizer(batch["input_text"], max_length=200, truncation=True, padding="max_length")
    with ft_tokenizer.as_target_tokenizer():
        labels = ft_tokenizer(batch["target_text"], max_length=48, truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

training_dataset = train_pairs.map(
    preprocess,
    batched=True,
    remove_columns=train_pairs.column_names
)
collator = DataCollatorForSeq2Seq(ft_tokenizer, model=ft_model)

warmup_steps = max(1, len(train_inputs))

training_args = TrainingArguments(
    output_dir=MODEL_DIR,
    num_train_epochs=30,
    learning_rate=8e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    weight_decay=0.02,
    warmup_steps=warmup_steps,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_strategy="no",
    report_to=[],
    remove_unused_columns=False,
    label_smoothing_factor=0.1,
    seed=13
)

if not os.path.exists(MODEL_DIR):
    trainer = Trainer(
        model=ft_model,
        args=training_args,
        train_dataset=training_dataset,
        tokenizer=ft_tokenizer,
        data_collator=collator,
    )
    trainer.train()
    trainer.save_model(MODEL_DIR)
    ft_tokenizer.save_pretrained(MODEL_DIR)

finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
finetuned_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
raw_accuracy, _ = compute_accuracy(
    finetuned_model,
    finetuned_tokenizer,
    distillation_examples,
    use_context=False
)
context_accuracy, finetuned_predictions = compute_accuracy(
    finetuned_model,
    finetuned_tokenizer,
    distillation_examples,
    use_context=True
)

print(f"Fine-tuned student accuracy without context: {raw_accuracy:.2%}")
print(f"Fine-tuned student accuracy with context: {context_accuracy:.2%}")
for question, prediction in finetuned_predictions[:10]:
    print(f"- {question} -> {prediction} (canonical: {canonicalize_answer(prediction)})")


Map:   0%|          | 0/48 [00:00<?, ? examples/s]



Fine-tuned student accuracy without context: 0.00%
Fine-tuned student accuracy with context: 66.67%
- Which country hosts the ITER fusion reactor construction site? -> France (canonical: france)
- In which European country is the ITER tokamak being assembled? -> France (canonical: france)
- Which nation hosts the construction of the ITER fusion reactor? -> France (canonical: france)
- Where is the ITER fusion megaproject located? -> Cadarache (canonical: cadarache)
- What is the capital of the country that operates the ALMA observatory? -> Chile (canonical: chile)
- Name the capital city of the nation that operates the ALMA observatory. -> Chile (canonical: chile)
- What city is the capital of Chile, the country running ALMA? -> Santiago (canonical: santiago)
- Which city is the capital of the country that owns the ALMA array? -> Chile (canonical: chile)
- Which nation leads the world's installed offshore wind capacity in 2024? -> China (canonical: china)
- Which nation tops global ins