Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ __pycache__/
*.py[cod]
*$py.class

# Vim
*.swp

# Jupyter Notebook
.ipynb_checkpoints
# notebooks/
Expand Down Expand Up @@ -42,4 +45,4 @@ finetuning_ckpts/
.idea
assertion.log
*.log
*.db
*.db
2 changes: 1 addition & 1 deletion dsp/templates/template_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def query(self, example: Example, is_demo: bool = False) -> str:
if field.input_variable in self.format_handlers:
format_handler = self.format_handlers[field.input_variable]
else:

def format_handler(x):
assert type(x) == str, f"Need format_handler for {field.input_variable} of type {type(x)}"
return " ".join(x.split())

formatted_value = format_handler(example[field.input_variable])
Expand Down
1 change: 1 addition & 0 deletions dspy/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .functional import cot, predictor, FunctionalModule, TypedPredictor
327 changes: 327 additions & 0 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
import inspect, os, openai, dspy, typing, pydantic
from typing import Annotated
import typing
from dsp.templates import passages2text
import json


MAX_RETRIES = 3


def predictor(func):
signature = _func_to_signature(func)
return TypedPredictor(signature, chain_of_thought=False, simple_output=True)


def cot(func):
signature = _func_to_signature(func)
return TypedPredictor(signature, chain_of_thought=True, simple_output=True)


class FunctionalModule(dspy.Module):
def __init__(self):
super().__init__()
for name in dir(self):
attr = getattr(self, name)
if isinstance(attr, dspy.Module):
self.__dict__[name] = attr.copy()


class TypedPredictor(dspy.Module):
def __init__(self, signature, chain_of_thought=False, simple_output=False):
super().__init__()
self.signature = signature
self.predictor = dspy.Predict(signature)
self.chain_of_thought = chain_of_thought
self.simple_output = simple_output

def copy(self):
return TypedPredictor(self.signature, self.chain_of_thought, self.simple_output)

def _prepare_signature(self):
"""Add formats and parsers to the signature fields, based on the type
annotations of the fields."""
signature = self.signature
for name, field in self.signature.fields.items():
is_output = field.json_schema_extra["__dspy_field_type"] == "output"
type_ = field.annotation
if is_output:
if type_ in (str, int, float, bool):
signature = signature.with_updated_fields(
name,
desc=field.json_schema_extra.get("desc", "")
+ (f". Respond with a single {type_.__name__} value"),
format=lambda x: x if isinstance(x, str) else str(x),
parser=type_,
)
else:
# Anything else we wrap in a pydantic object
unwrap = lambda x: x
if not inspect.isclass(type_) or not issubclass(
type_, pydantic.BaseModel
):
type_ = pydantic.create_model(
"Output", value=(type_, ...), __base__=pydantic.BaseModel
)
unwrap = lambda x: x.value
signature = signature.with_updated_fields(
name,
desc=field.json_schema_extra.get("desc", "")
+ (
f". Respond with a single JSON object using the schema "
+ json.dumps(type_.model_json_schema())
),
format=lambda x: x if isinstance(x, str) else x.model_dump_json(),
parser=lambda x: unwrap(
type_.model_validate_json(_unwrap_json(x))
),
)
else: # If input field
format = lambda x: x if isinstance(x, str) else str(x)
if type_ in (list[str], tuple[str]):
format = passages2text
elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
format = lambda x: x if isinstance(x, str) else x.model_dump_json()
signature = signature.with_updated_fields(name, format=format)

if self.chain_of_thought:
output_keys = ", ".join(signature.output_fields.keys())
signature = signature.prepend(
"reasoning",
dspy.OutputField(
prefix="Reasoning: Let's think step by step in order to",
desc="${produce the " + output_keys + "}. We ...",
),
)
return signature

def forward(self, **kwargs):
modified_kwargs = kwargs.copy()
signature = self._prepare_signature()
for try_i in range(MAX_RETRIES):
result = self.predictor(**modified_kwargs, new_signature=signature)
errors = {}
parsed_results = {}
# Parse the outputs
for name, field in signature.output_fields.items():
try:
value = getattr(result, name)
parser = field.json_schema_extra.get("parser", lambda x: x)
parsed_results[name] = parser(value)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = e
if errors:
# Add new fields for each error
for name, error in errors.items():
modified_kwargs[f"error_{name}_{try_i}"] = str(error)
signature = signature.append(
f"error_{name}_{try_i}",
dspy.InputField(
prefix=f"Past Error "
+ (f"({name}):" if try_i == 0 else f"({name}, {try_i+1}):"),
desc="An error to avoid in the future",
),
)
else:
# If there are no errors, we return the parsed results
for name, value in parsed_results.items():
setattr(result, name, value)
if self.simple_output:
*_, last_output = signature.output_fields.keys()
return result[last_output]
return result
raise ValueError("Too many retries")


def _func_to_signature(func):
"""Make a dspy.Signature based on a function definition."""
sig = inspect.signature(func)
annotations = typing.get_type_hints(func)
output_key = func.__name__
instructions = func.__doc__
fields = {}

# Input fields
for param in sig.parameters.values():
if param.name == "self":
continue
# We default to str as the type of the input
annotation = annotations.get(param.name, str)
kwargs = {}
if typing.get_origin(annotation) is Annotated:
annotation, kwargs["desc"] = typing.get_args(annotation)
fields[param.name] = (annotation, dspy.InputField(**kwargs))

# Output field
kwargs = {}
annotation = annotations.get("return", str)
if typing.get_origin(annotation) is Annotated:
annotation, kwargs["desc"] = typing.get_args(annotation)
fields[output_key] = (annotation, dspy.OutputField(**kwargs))

return dspy.Signature(fields, instructions)


def _unwrap_json(output):
output = output.strip()
if output.startswith("```"):
if not output.startswith("```json"):
raise ValueError("json output should start with ```json")
if not output.endswith("```"):
raise ValueError("json output should end with ```")
output = output[7:-3].strip()
if not output.startswith("{") or not output.endswith("}"):
raise ValueError("json output should start and end with { and }")
return output


################################################################################
# Example usage
################################################################################


def main():
class Answer(pydantic.BaseModel):
value: float
certainty: float
comments: list[str] = pydantic.Field(
description="At least two comments about the answer"
)

class QA(dspy.Module):
@predictor
def hard_question(self, topic: str) -> str:
"""Think of a hard factual question about a topic. It should be answerable with a number."""

@cot
def answer(self, question: Annotated[str, "Question to answer"]) -> Answer:
pass

def forward(self, **kwargs):
question = self.hard_question(**kwargs)
return (question, self.answer(question=question))

openai.api_key = os.getenv("OPENAI_API_KEY")
lm = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000)
# lm = dspy.OpenAI(model="gpt-4", max_tokens=4000)
# lm = dspy.OpenAI(model="gpt-4-preview-1106", max_tokens=4000)
with dspy.context(lm=lm):
qa = QA()
question, answer = qa(topic="Physics")
# lm.inspect_history(n=5)

print("Question:", question)
print("Answer:", answer)


################################################################################
# HotpotQA example with SimpleBaleen
################################################################################


def validate_context_and_answer_and_hops(example, pred, trace=None):
if not dspy.evaluate.answer_exact_match(example, pred):
return False
if not dspy.evaluate.answer_passage_match(example, pred):
return False

hops = [example.question] + [
outputs.query for *_, outputs in trace if "query" in outputs
]

if max([len(h) for h in hops]) > 100:
return False
if any(
dspy.evaluate.answer_exact_match_str(hops[idx], hops[:idx], frac=0.8)
for idx in range(2, len(hops))
):
return False

return True


def gold_passages_retrieved(example, pred, trace=None):
gold_titles = set(map(dspy.evaluate.normalize_text, example["gold_titles"]))
found_titles = set(
map(dspy.evaluate.normalize_text, [c.split(" | ")[0] for c in pred.context])
)

return gold_titles.issubset(found_titles)


def hotpot():
from dsp.utils import deduplicate
import dspy.evaluate
from dspy.datasets import HotPotQA
from dspy.evaluate.evaluate import Evaluate
from dspy.teleprompt.bootstrap import BootstrapFewShot

print("Load the dataset.")
dataset = HotPotQA(
train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0
)
trainset = [x.with_inputs("question") for x in dataset.train]
devset = [x.with_inputs("question") for x in dataset.dev]
print("Done")

class SimplifiedBaleen(FunctionalModule):
def __init__(self, passages_per_hop=3, max_hops=1):
super().__init__()
self.retrieve = dspy.Retrieve(k=passages_per_hop)
self.max_hops = max_hops

@cot
def generate_query(self, context: list[str], question) -> str:
"""Write a simple search query that will help answer a complex question."""
pass

@cot
def generate_answer(self, context: list[str], question) -> str:
"""Answer questions with short factoid answers."""
pass

def forward(self, question):
context = []

for hop in range(self.max_hops):
query = self.generate_query(context=context, question=question)
passages = self.retrieve(query).passages
context = deduplicate(context + passages)

answer = self.generate_answer(context=context, question=question)
return dspy.Prediction(context=context, answer=answer)

openai.api_key = os.getenv("OPENAI_API_KEY")
rm = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")
lm = dspy.OpenAI(model="gpt-3.5-turbo", max_tokens=4000)
dspy.settings.configure(lm=lm, rm=rm, trace=[])

evaluate_on_hotpotqa = Evaluate(
devset=devset, num_threads=10, display_progress=True, display_table=5
)

# uncompiled (i.e., zero-shot) program
uncompiled_baleen = SimplifiedBaleen()
print(
"Uncompiled Baleen retrieval score:",
evaluate_on_hotpotqa(uncompiled_baleen, metric=gold_passages_retrieved),
)

# compiled (i.e., few-shot) program
compiled_baleen = BootstrapFewShot(
metric=validate_context_and_answer_and_hops
).compile(
SimplifiedBaleen(),
teacher=SimplifiedBaleen(passages_per_hop=2),
trainset=trainset,
)
print(
"Compiled Baleen retrieval score:",
evaluate_on_hotpotqa(compiled_baleen, metric=gold_passages_retrieved),
)
# lm.inspect_history(n=5)


if __name__ == "__main__":
# main()
hotpot()
1 change: 1 addition & 0 deletions dspy/predict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .aggregation import majority
from .program_of_thought import ProgramOfThought
from .retry import Retry
from .knn import KNN
10 changes: 5 additions & 5 deletions dspy/predict/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ def majority(prediction_or_completions, normalize=default_normalize, field=None)
except:
signature = None

try:
field = field if field else signature.fields[-1].output_variable
except:
field = field if field else list(completions[0].keys())[-1]
if not field:
if signature:
field = signature.output_fields[-1]
else:
field = list(completions[0].keys())[-1]

# Normalize
normalize = normalize if normalize else lambda x: x
Expand All @@ -51,5 +52,4 @@ def majority(prediction_or_completions, normalize=default_normalize, field=None)
# if input_type == Prediction:
return Prediction.from_completions([completion], signature=signature)

return Completions([completion], signature=signature)

Loading