Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 53 additions & 44 deletions dsp/primitives/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,41 @@ def _generate(template: Template, **kwargs) -> Callable:

generator = dsp.settings.lm

def extend_generation(completion: Example, field_names: list[str], stage: str, max_depth: int, original_example: Example):
    """Extend a completion that is missing one or more required fields.

    Every field from the first missing/empty one onwards is removed from
    ``completion`` (in place), then generation is re-run on the trimmed
    completion with greedy decoding (``n=1``, ``temperature=0.0``) and a
    reduced token budget.

    Args:
        completion: The partially filled completion; mutated in place.
        field_names: Field names to check, in order.
        stage: Forwarded unchanged to the nested generate call.
        max_depth: Remaining extension attempts; must be greater than zero.
        original_example: Forwarded unchanged to the nested generate call.

    Returns:
        The first entry of ``.data`` on the completions object returned by
        the nested generate call.

    Raises:
        AssertionError: If ``max_depth`` is not greater than zero.
        ValueError: If neither ``max_tokens`` nor ``max_output_tokens`` is
            set in the call kwargs or in ``dsp.settings.lm.kwargs``.
    """
    # Raised explicitly (instead of via `assert`) so that the recursion guard
    # is still enforced when Python runs with -O, which strips assert
    # statements. The exception type and message are unchanged.
    if not max_depth > 0:
        raise AssertionError("Max depth exceeded - failed to complete in one pass - increase max_tokens")

    # Drop the first missing/empty field and everything after it, so that
    # generation restarts from that field.
    for field_name in get_all_fields_following_missing_field(completion, field_names):
        completion.pop(field_name, None)

    # Recurse with greedy decoding and a shorter length.
    max_tokens = (
        kwargs.get("max_tokens")
        or kwargs.get("max_output_tokens")
        or dsp.settings.lm.kwargs.get("max_tokens")
        or dsp.settings.lm.kwargs.get("max_output_tokens")
    )

    if max_tokens is None:
        raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.")

    # Halve the budget but keep a floor of 75 tokens, without ever exceeding
    # the budget that was originally configured.
    max_tokens = min(max(75, max_tokens // 2), max_tokens)

    # Reuse whichever spelling of the token-limit option is already in use.
    keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys())
    max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens"
    new_kwargs = {
        **kwargs,
        max_tokens_key: max_tokens,
        "n": 1,
        "temperature": 0.0,
    }

    _, finished_completion = generate(template, **new_kwargs)(
        completion,
        stage=stage,
        max_depth=max_depth - 1,
        original_example=original_example,
    )
    return finished_completion.data[0]


def do_generate(
example: Example, stage: str, max_depth: int = 2, original_example=None,
):
Expand All @@ -77,54 +112,19 @@ def do_generate(
completions: list[dict[str, Any]] = generator(prompt, **kwargs)
completions: list[Example] = [template.extract(example, p) for p in completions]

# Find the completions that are most complete.
# Find the completions that are unfinished.
field_names: list[str] = [field.input_variable for field in template.fields]

last_field_idx = 0
for field_idx, key in enumerate(field_names):
completions_ = [
c for c in completions if key in c.keys() and c[key] is not None
]

# Filter out completions that are missing fields that are present in at least one completion.
if len(completions_):
completions = completions_
last_field_idx = field_idx + 1

# If none of the completions is completed (i.e., none has the final field set).
if last_field_idx < len(field_names):
# Pick the first completion that has gone farthest.
completion = completions[0]
completion[field_names[last_field_idx]] = ""

# Recurse with greedy decoding and a shorter length.
max_tokens = (kwargs.get("max_tokens") or
kwargs.get("max_output_tokens") or
dsp.settings.lm.kwargs.get("max_tokens") or
dsp.settings.lm.kwargs.get('max_output_tokens'))


if max_tokens is None:
raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.")
max_tokens = min(max(75, max_tokens // 2), max_tokens)
keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys())
max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens"
new_kwargs = {
**kwargs,
max_tokens_key: max_tokens,
"n": 1,
"temperature": 0.0,
}

assert max_depth > 0
return generate(template, **new_kwargs)(
completion,
stage=stage,
max_depth=max_depth - 1,
original_example=original_example,
finished_completions = []
for completion in completions:
if all((completion.get(key, "") != "") for key in field_names):
finished_completions.append(completion)
continue
finished_completions.append(
extend_generation(completion, field_names, stage, max_depth, original_example),
)

completions = Completions(completions, template=template)
completions = Completions(finished_completions, template=template)
example = example.copy(completions=completions)

if len(completions) == 1:
Expand Down Expand Up @@ -161,6 +161,15 @@ def do_generate(

return do_generate

def get_all_fields_following_missing_field(completion: Example, field_names: list[str]) -> list[str]:
    """Return the tail of ``field_names`` starting at the first unfilled field.

    A field counts as unfilled when it is absent from ``completion`` or when
    its value equals the empty string. The unfilled field itself is included
    in the result. An empty list is returned when every field is filled.
    """
    first_gap = next(
        (
            position
            for position, name in enumerate(field_names)
            if name not in completion or completion[name] == ""
        ),
        None,
    )
    if first_gap is None:
        return []
    return field_names[first_gap:]


def generate_sc(
example, prompt, normalize=True, extract=None, prediction_field=None, **kwargs,
Expand Down