From cddab06079498f87817460e555864979249aae7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20D=C3=B6rr?= Date: Sat, 18 May 2024 18:20:21 +0200 Subject: [PATCH 1/2] Add support for logprobs in hf_client --- dsp/modules/hf_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsp/modules/hf_client.py b/dsp/modules/hf_client.py index c4bf83e558..d303cb14f1 100644 --- a/dsp/modules/hf_client.py +++ b/dsp/modules/hf_client.py @@ -206,7 +206,7 @@ def _generate(self, prompt, **kwargs): completions = json_response["choices"] response = { "prompt": prompt, - "choices": [{"text": c["text"]} for c in completions], + "choices": [{"text": c["text"], "logprobs": c["logprobs"]} if c["logprobs"] else {"text": c["text"]} for c in completions], } return response From e32cb00301bae8702a3ad19cbe71e1ad9a8ced63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20D=C3=B6rr?= Date: Sat, 18 May 2024 18:21:10 +0200 Subject: [PATCH 2/2] Add support for logprobs in hf.py --- dsp/modules/hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsp/modules/hf.py b/dsp/modules/hf.py index cefcee9290..db6d7f54eb 100644 --- a/dsp/modules/hf.py +++ b/dsp/modules/hf.py @@ -188,7 +188,7 @@ def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs): kwargs["do_sample"] = True response = self.request(prompt, **kwargs) - return [c["text"] for c in response["choices"]] + return [{"text": c["text"], "logprobs": c["logprobs"]} if 'logprobs' in c else c["text"] for c in response["choices"]] # @functools.lru_cache(maxsize=None if cache_turn_on else 0)