Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 32 additions & 41 deletions dsp/templates/template_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def query(self, example: Example, is_demo: bool = False) -> str:
"""Retrieves the input variables from the example and formats them into a query string."""
result: list[str] = []

# If not a demo, find the last field that doesn't have a value set in `example` and set it to ""
# This creates the "Output:" prefix at the end of the prompt.
if not is_demo:
has_value = [
field.input_variable in example
Expand All @@ -80,40 +82,40 @@ def query(self, example: Example, is_demo: bool = False) -> str:
for field in self.fields
]

for i in range(1, len(has_value)):
if has_value[i - 1] and not any(has_value[i:]):
example[self.fields[i].input_variable] = ""
break
# If there are no inputs, set the first field to ""
if not any(has_value):
example[self.fields[0].input_variable] = ""
# Otherwise find the first field without a value.
else:
for i in range(1, len(has_value)):
if has_value[i - 1] and not any(has_value[i:]):
example[self.fields[i].input_variable] = ""
break

for field in self.fields:
if (
field.input_variable in example
and example[field.input_variable] is not None
):
if field.input_variable in example and example[field.input_variable] is not None:
if field.input_variable in self.format_handlers:
format_handler = self.format_handlers[field.input_variable]
else:

def format_handler(x):
assert type(x) == str, f"Need format_handler for {field.input_variable} of type {type(x)}"
return " ".join(x.split())

formatted_value = format_handler(example[field.input_variable])
separator = '\n' if field.separator == ' ' and '\n' in formatted_value else field.separator
separator = "\n" if field.separator == " " and "\n" in formatted_value else field.separator

result.append(
f"{field.name}{separator}{formatted_value}",
)

if self._has_augmented_guidelines() and (example.get('augmented', False)):
if self._has_augmented_guidelines() and (example.get("augmented", False)):
return "\n\n".join([r for r in result if r])
return "\n".join([r for r in result if r])

def guidelines(self, show_guidelines=True) -> str:
"""Returns the task guidelines as described in the lm prompt"""
if (not show_guidelines) or (
hasattr(dsp.settings, "show_guidelines")
and not dsp.settings.show_guidelines
):
if (not show_guidelines) or (hasattr(dsp.settings, "show_guidelines") and not dsp.settings.show_guidelines):
return ""

result = "Follow the following format.\n\n"
Expand All @@ -128,11 +130,13 @@ def guidelines(self, show_guidelines=True) -> str:

def _has_augmented_guidelines(self):
return len(self.fields) > 3 or any(
("\n" in field.separator) or ('\n' in field.description) for field in self.fields
("\n" in field.separator) or ("\n" in field.description) for field in self.fields
)

def extract(
self, example: Union[Example, dict[str, Any]], raw_pred: str,
self,
example: Union[Example, dict[str, Any]],
raw_pred: str,
) -> Example:
"""Extracts the answer from the LM raw prediction using the template structure

Expand All @@ -149,10 +153,7 @@ def extract(

idx = 0
while idx < len(self.fields):
if (
self.fields[idx].input_variable not in example
or example[self.fields[idx].input_variable] is None
):
if self.fields[idx].input_variable not in example or example[self.fields[idx].input_variable] is None:
break
idx += 1

Expand All @@ -166,16 +167,16 @@ def extract(

if offset >= 0:
if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip('---').strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip('---').strip()
example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip("---").strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip("---").strip()
else:
example[self.fields[idx].output_variable] = raw_pred[:offset].strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip()

idx += 1
else:
if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip()
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()
else:
example[self.fields[idx].output_variable] = raw_pred.strip()

Expand All @@ -187,7 +188,7 @@ def extract(
assert idx == len(self.fields) - 1, (idx, len(self.fields))

if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip()
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()
else:
example[self.fields[idx].output_variable] = raw_pred.strip()

Expand All @@ -198,7 +199,7 @@ def extract(
def __call__(self, example, show_guidelines=True) -> str:
example = dsp.Example(example)

if hasattr(dsp.settings, 'query_only') and dsp.settings.query_only:
if hasattr(dsp.settings, "query_only") and dsp.settings.query_only:
return self.query(example)

# The training data should not contain the output variable
Expand All @@ -209,29 +210,20 @@ def __call__(self, example, show_guidelines=True) -> str:
self.query(demo, is_demo=True)
for demo in example.demos
if (
(not demo.get('augmented', False))
(not demo.get("augmented", False))
and ( # validate that the training example has the same primitive input var as the template
self.fields[-1].input_variable in demo
and demo[self.fields[-1].input_variable] is not None
self.fields[-1].input_variable in demo and demo[self.fields[-1].input_variable] is not None
)
)
]

ademos = [
self.query(demo, is_demo=True)
for demo in example.demos
if demo.get('augmented', False)
]
ademos = [self.query(demo, is_demo=True) for demo in example.demos if demo.get("augmented", False)]

# Move the rdemos to ademos if rdemo has all the fields filled in
rdemos_ = []
new_ademos = []
for rdemo in rdemos:
if all(
(field.name in rdemo)
for field in self.fields
if field.input_variable in example
):
if all((field.name in rdemo) for field in self.fields if field.input_variable in example):
import dspy

if dspy.settings.release >= 20230928:
Expand All @@ -244,7 +236,6 @@ def __call__(self, example, show_guidelines=True) -> str:
ademos = new_ademos + ademos
rdemos = rdemos_


long_query = self._has_augmented_guidelines()

if long_query:
Expand All @@ -253,10 +244,10 @@ def __call__(self, example, show_guidelines=True) -> str:
query = self.query(example)

# if it has more lines than fields
if len(query.split('\n')) > len(self.fields):
if len(query.split("\n")) > len(self.fields):
long_query = True

if not example.get('augmented', False):
if not example.get("augmented", False):
example["augmented"] = True
query = self.query(example)

Expand Down
5 changes: 3 additions & 2 deletions dspy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# from .evaluation import *
# FIXME:
import dsp
from dsp.modules.hf_client import ChatModuleClient, HFClientSGLang, HFClientVLLM, HFServerTGI

Expand All @@ -8,6 +6,9 @@
from .retrieve import *
from .signatures import *

# Functional must be imported after primitives, predict and signatures
from .functional import * # isort: skip

settings = dsp.settings

AzureOpenAI = dsp.AzureOpenAI
Expand Down
Loading