diff --git a/dsp/modules/azure_openai.py b/dsp/modules/azure_openai.py index 6ff774e767..1c7cef2255 100644 --- a/dsp/modules/azure_openai.py +++ b/dsp/modules/azure_openai.py @@ -108,6 +108,10 @@ def __init__( **kwargs, } # TODO: add kwargs above for + self.api_base = api_base + self.api_version = api_version + self.api_key = api_key + self.history: list[dict[str, Any]] = [] def _openai_client(self): @@ -219,6 +223,19 @@ def __call__( completions = [c for _, c in scored_completions] return completions + + def copy(self, **kwargs): + """Returns a copy of the language model with the same parameters.""" + kwargs = {**self.kwargs, **kwargs} + model = kwargs.pop("model") + + return self.__class__( + model=model, + api_key=self.api_key, + api_version=self.api_version, + api_base=self.api_base, + **kwargs, + ) @CacheMemory.cache