From 231b120abaa755fc0ce33b1dd0cb59476f2f4e3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20D=C3=B6rr?= Date: Thu, 9 May 2024 13:05:45 +0200 Subject: [PATCH 1/2] Add support for logprob output --- dsp/modules/gpt3.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/dsp/modules/gpt3.py b/dsp/modules/gpt3.py index 8bead74147..ee34908733 100644 --- a/dsp/modules/gpt3.py +++ b/dsp/modules/gpt3.py @@ -185,7 +185,12 @@ def __call__( if only_completed and len(completed_choices): choices = completed_choices - completions = [self._get_choice_text(c) for c in choices] + # if logprobs: + if kwargs.get("logprobs", False): + completions = [{'text': self._get_choice_text(c), 'logprobs': c["logprobs"]} for c in choices] + else: + completions = [self._get_choice_text(c) for c in choices] + if return_sorted and kwargs.get("n", 1) > 1: scored_completions = [] @@ -200,10 +205,12 @@ def __call__( tokens, logprobs = tokens[:index], logprobs[:index] avglog = sum(logprobs) / len(logprobs) - scored_completions.append((avglog, self._get_choice_text(c))) - + scored_completions.append((avglog, self._get_choice_text(c), logprobs)) scored_completions = sorted(scored_completions, reverse=True) - completions = [c for _, c in scored_completions] + if logprobs: + completions = [{'text': c, 'logprobs': lp} for _, c, lp in scored_completions] + else: + completions = [c for _, c in scored_completions] return completions From 4b48281ff439964a9c6bc842fd11c9d741d8669a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20D=C3=B6rr?= Date: Thu, 9 May 2024 13:11:32 +0200 Subject: [PATCH 2/2] Remove unnecessary comment in GPT3 class. --- dsp/modules/gpt3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dsp/modules/gpt3.py b/dsp/modules/gpt3.py index ee34908733..54c0af0123 100644 --- a/dsp/modules/gpt3.py +++ b/dsp/modules/gpt3.py @@ -185,7 +185,6 @@ def __call__( if only_completed and len(completed_choices): choices = completed_choices - # if logprobs: if kwargs.get("logprobs", False): completions = [{'text': self._get_choice_text(c), 'logprobs': c["logprobs"]} for c in choices] else: