From ba4c09e8a7f0d13fd85ff013602c60208eac578b Mon Sep 17 00:00:00 2001 From: Misha Smirnov Date: Wed, 17 Apr 2024 19:37:51 +0000 Subject: [PATCH] fix(dspy/modules/aws_models): properly copy kwargs so temporary changes don't propagate to base model --- dsp/modules/aws_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dsp/modules/aws_models.py b/dsp/modules/aws_models.py index 94b07cad43..34cfff088e 100644 --- a/dsp/modules/aws_models.py +++ b/dsp/modules/aws_models.py @@ -148,7 +148,7 @@ def _format_prompt(self, raw_prompt: str) -> str: return " [INST] Human: " + raw_prompt + " [/INST] Assistant: " def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: - base_args: dict[str, Any] = self.kwargs + base_args: dict[str, Any] = self.kwargs.copy() for k, v in kwargs.items(): base_args[k] = v @@ -211,7 +211,7 @@ def __init__( self.kwargs[k] = v def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: - base_args: dict[str, Any] = self.kwargs + base_args: dict[str, Any] = self.kwargs.copy() for k, v in kwargs.items(): base_args[k] = v @@ -275,7 +275,7 @@ def __init__( self.kwargs["max_gen_len"] = self.kwargs.pop("max_tokens") def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]: - base_args: dict[str, Any] = self.kwargs + base_args: dict[str, Any] = self.kwargs.copy() for k, v in kwargs.items(): base_args[k] = v