diff --git a/dsp/modules/gpt3.py b/dsp/modules/gpt3.py index 8bead74147..54c0af0123 100644 --- a/dsp/modules/gpt3.py +++ b/dsp/modules/gpt3.py @@ -185,7 +185,11 @@ def __call__( if only_completed and len(completed_choices): choices = completed_choices - completions = [self._get_choice_text(c) for c in choices] + if kwargs.get("logprobs", False): + completions = [{'text': self._get_choice_text(c), 'logprobs': c["logprobs"]} for c in choices] + else: + completions = [self._get_choice_text(c) for c in choices] + if return_sorted and kwargs.get("n", 1) > 1: scored_completions = [] @@ -200,10 +204,12 @@ def __call__( tokens, logprobs = tokens[:index], logprobs[:index] avglog = sum(logprobs) / len(logprobs) - scored_completions.append((avglog, self._get_choice_text(c))) - + scored_completions.append((avglog, self._get_choice_text(c), logprobs)) scored_completions = sorted(scored_completions, reverse=True) - completions = [c for _, c in scored_completions] + if logprobs: + completions = [{'text': c, 'logprobs': lp} for _, c, lp in scored_completions] + else: + completions = [c for _, c in scored_completions] return completions