From 9c286f3166ea8fe3c4cc3db46caf80d3dbaefac6 Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Sun, 28 Apr 2024 08:29:01 +0000 Subject: [PATCH 1/8] remove n=1 in fallback generate logic when not all fields have been generated --- dsp/primitives/predict.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index ffd4f8caa3..509757f3d7 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -112,7 +112,6 @@ def do_generate( new_kwargs = { **kwargs, max_tokens_key: max_tokens, - "n": 1, "temperature": 0.0, } From a198363ddf46e06d2febff73411837ffc1baa34d Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Sun, 28 Apr 2024 13:21:20 +0000 Subject: [PATCH 2/8] extend the generation for all candidate generations until all have required fields --- dsp/primitives/predict.py | 91 +++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 43 deletions(-) diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index 509757f3d7..9378d4732d 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -61,6 +61,45 @@ def _generate(template: Template, **kwargs) -> Callable: generator = dsp.settings.lm + def extend_generation(completion: Example, field_names: list[str], stage:str, max_depth: int, original_example:Example): + """If the required fields are not present in the completion, extend the generation.""" + # remove content of last field to avoid half-completed content + for field_idx, key in enumerate(field_names): + if key not in completion: + break + if completion[key] is None or completion[key] == "": + break + completion[field_names[field_idx]] = "" + + # Recurse with greedy decoding and a shorter length. + max_tokens = (kwargs.get("max_tokens") or + kwargs.get("max_output_tokens") or + dsp.settings.lm.kwargs.get("max_tokens") or + dsp.settings.lm.kwargs.get('max_output_tokens')) + + + if max_tokens is None: + raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.") + max_tokens = min(max(75, max_tokens // 2), max_tokens) + keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys()) + max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens" + new_kwargs = { + **kwargs, + max_tokens_key: max_tokens, + "n": 1, + "temperature": 0.0, + } + + assert max_depth > 0, "Max depth exceeded - failed to complete in one pass - increase max_tokens" + _, finished_completion = generate(template, **new_kwargs)( + completion, + stage=stage, + max_depth=max_depth - 1, + original_example=original_example, + ) + return finished_completion.data[0] + + def do_generate( example: Example, stage: str, max_depth: int = 2, original_example=None, ): @@ -77,53 +116,19 @@ def do_generate( completions: list[dict[str, Any]] = generator(prompt, **kwargs) completions: list[Example] = [template.extract(example, p) for p in completions] - # Find the completions that are most complete. + # Find the completions that are unfinished. field_names: list[str] = [field.input_variable for field in template.fields] - last_field_idx = 0 - for field_idx, key in enumerate(field_names): - completions_ = [ - c for c in completions if key in c.keys() and c[key] is not None - ] - - # Filter out completions that are missing fields that are present in at least one completion. - if len(completions_): - completions = completions_ - last_field_idx = field_idx + 1 - - # If none of the completions is completed (i.e., none has the final field set). - if last_field_idx < len(field_names): - # Pick the first completion that has gone farthest. - completion = completions[0] - completion[field_names[last_field_idx]] = "" - - # Recurse with greedy decoding and a shorter length. - max_tokens = (kwargs.get("max_tokens") or - kwargs.get("max_output_tokens") or - dsp.settings.lm.kwargs.get("max_tokens") or - dsp.settings.lm.kwargs.get('max_output_tokens')) - - - if max_tokens is None: - raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.") - max_tokens = min(max(75, max_tokens // 2), max_tokens) - keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys()) - max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens" - new_kwargs = { - **kwargs, - max_tokens_key: max_tokens, - "temperature": 0.0, - } - - assert max_depth > 0 - return generate(template, **new_kwargs)( - completion, - stage=stage, - max_depth=max_depth - 1, - original_example=original_example, + finished_completions = [] + for completion in completions: + if all(key in completion for key in field_names): + finished_completions.append(completion) + continue + finished_completions.append( + extend_generation(completion, field_names, stage, max_depth, original_example) ) - completions = Completions(completions, template=template) + completions = Completions(finished_completions, template=template) example = example.copy(completions=completions) if len(completions) == 1: From 70e630a2c2ab68d2dcb86ba65fba7fb15db1c8e3 Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Sun, 28 Apr 2024 13:34:30 +0000 Subject: [PATCH 3/8] remove break clauses --- dsp/primitives/predict.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index 9378d4732d..b94d63e403 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -63,13 +63,9 @@ def _generate(template: Template, **kwargs) -> Callable: def extend_generation(completion: Example, field_names: list[str], stage:str, max_depth: int, original_example:Example): """If the required fields are not present in the completion, extend the generation.""" + # remove content of last field to avoid half-completed content - for field_idx, key in enumerate(field_names): - if key not in completion: - break - if completion[key] is None or completion[key] == "": - break - completion[field_names[field_idx]] = "" + completion[get_last_field(completion, field_names)] = "" # Recurse with greedy decoding and a shorter length. max_tokens = (kwargs.get("max_tokens") or @@ -165,6 +161,12 @@ def do_generate( return do_generate +def get_last_field(completion, field_names): + for key in field_names: + if key not in completion: + return key + if completion[key] is None or completion[key] == "": + return key def generate_sc( example, prompt, normalize=True, extract=None, prediction_field=None, **kwargs, From 5b16f10fe971bd8c4dad240ae3553e9cc83cff7c Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Sun, 28 Apr 2024 13:37:42 +0000 Subject: [PATCH 4/8] add types --- dsp/primitives/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index b94d63e403..f8db9c47e0 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -161,7 +161,7 @@ def do_generate( return do_generate -def get_last_field(completion, field_names): +def get_last_field(completion: Example, field_names: list[str]): for key in field_names: if key not in completion: return key From 991d713d22b9bc2a6035a7a5747bf908dc0208ce Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Sun, 28 Apr 2024 20:51:27 +0000 Subject: [PATCH 5/8] pick up completion from first missing field --- dsp/primitives/predict.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index f8db9c47e0..029865571d 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -63,9 +63,9 @@ def _generate(template: Template, **kwargs) -> Callable: def extend_generation(completion: Example, field_names: list[str], stage:str, max_depth: int, original_example:Example): """If the required fields are not present in the completion, extend the generation.""" - # remove content of last field to avoid half-completed content - completion[get_last_field(completion, field_names)] = "" + for field_name in get_all_fields_following_missing_field(completion, field_names): + completion[field_name] = "" # Recurse with greedy decoding and a shorter length. max_tokens = (kwargs.get("max_tokens") or @@ -117,7 +117,11 @@ def do_generate( finished_completions = [] for completion in completions: - if all(key in completion for key in field_names): + all_keys_valid = all( + (completion.get(key, None) is not None) and + (completion.get(key, None) !="") for key in field_names + ) + if all_keys_valid: finished_completions.append(completion) continue finished_completions.append( @@ -161,12 +165,15 @@ def do_generate( return do_generate -def get_last_field(completion: Example, field_names: list[str]): - for key in field_names: - if key not in completion: - return key - if completion[key] is None or completion[key] == "": - return key +def get_all_fields_following_missing_field(completion: Example, field_names: list[str]) -> list[str]: + """Returns every field following the first missing field""" + for i, field_name in enumerate(field_names): + if field_name not in completion: + return field_names[i:] + if completion[field_name] == "": + return field_names[i:] + return [] + def generate_sc( example, prompt, normalize=True, extract=None, prediction_field=None, **kwargs, From 67bb1a00425d30e4596e5e668b53bd05202a34a1 Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Sun, 28 Apr 2024 21:18:48 +0000 Subject: [PATCH 6/8] pick up completion from first missing field - and pop empty fields --- dsp/primitives/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index 029865571d..62b889a98a 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -65,7 +65,7 @@ def extend_generation(completion: Example, field_names: list[str], stage:str, ma """If the required fields are not present in the completion, extend the generation.""" # remove content of last field to avoid half-completed content for field_name in get_all_fields_following_missing_field(completion, field_names): - completion[field_name] = "" + completion.pop(field_name, None) # Recurse with greedy decoding and a shorter length. max_tokens = (kwargs.get("max_tokens") or From cffbfffc7a8310dd3784497239c431e10f441d69 Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Sun, 28 Apr 2024 21:33:44 +0000 Subject: [PATCH 7/8] pick up completion from first missing field - and pop empty fields --- dsp/primitives/predict.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index 62b889a98a..565394194e 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -63,6 +63,7 @@ def _generate(template: Template, **kwargs) -> Callable: def extend_generation(completion: Example, field_names: list[str], stage:str, max_depth: int, original_example:Example): """If the required fields are not present in the completion, extend the generation.""" + assert max_depth > 0, "Max depth exceeded - failed to complete in one pass - increase max_tokens" # remove content of last field to avoid half-completed content for field_name in get_all_fields_following_missing_field(completion, field_names): completion.pop(field_name, None) @@ -86,7 +87,6 @@ def extend_generation(completion: Example, field_names: list[str], stage:str, ma "temperature": 0.0, } - assert max_depth > 0, "Max depth exceeded - failed to complete in one pass - increase max_tokens" _, finished_completion = generate(template, **new_kwargs)( completion, stage=stage, @@ -117,11 +117,7 @@ def do_generate( finished_completions = [] for completion in completions: - all_keys_valid = all( - (completion.get(key, None) is not None) and - (completion.get(key, None) !="") for key in field_names - ) - if all_keys_valid: + if all((completion.get(key, "") != "") for key in field_names): finished_completions.append(completion) continue finished_completions.append( From b2a9258744a03124b671e32c417bad1c413630b2 Mon Sep 17 00:00:00 2001 From: JONEMI19 Date: Mon, 6 May 2024 07:34:40 +0000 Subject: [PATCH 8/8] linting --- dsp/primitives/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsp/primitives/predict.py b/dsp/primitives/predict.py index 565394194e..3ee615574b 100644 --- a/dsp/primitives/predict.py +++ b/dsp/primitives/predict.py @@ -121,7 +121,7 @@ def do_generate( finished_completions.append(completion) continue finished_completions.append( - extend_generation(completion, field_names, stage, max_depth, original_example) + extend_generation(completion, field_names, stage, max_depth, original_example), ) completions = Completions(finished_completions, template=template)