diff --git a/dsp/modules/azure_openai.py b/dsp/modules/azure_openai.py
index 35e8977740..c80f820883 100644
--- a/dsp/modules/azure_openai.py
+++ b/dsp/modules/azure_openai.py
@@ -56,11 +56,14 @@ def __init__(
         model: str = "gpt-3.5-turbo-instruct",
         api_key: Optional[str] = None,
         model_type: Literal["chat", "text"] = "chat",
+        system_prompt: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(model)
         self.provider = "openai"
 
+        self.system_prompt = system_prompt
+
         # Define Client
         if OPENAI_LEGACY:
             # Assert that all variables are available
@@ -132,7 +135,11 @@ def basic_request(self, prompt: str, **kwargs):
         kwargs = {**self.kwargs, **kwargs}
         if self.model_type == "chat":
             # caching mechanism requires hashable kwargs
-            kwargs["messages"] = [{"role": "user", "content": prompt}]
+            messages = [{"role": "user", "content": prompt}]
+            if self.system_prompt:
+                messages.insert(0, {"role": "system", "content": self.system_prompt})
+
+            kwargs["messages"] = messages
             kwargs = {"stringify_request": json.dumps(kwargs)}
             response = chat_request(self.client, **kwargs)
 
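
Usage sketch (not part of the diff): a minimal example of how the new system_prompt argument could be passed through, assuming the class in this module is exposed as dspy.AzureOpenAI and that the resource name, deployment name, and API version below are substituted with real values.

# Usage sketch: the dspy.AzureOpenAI export, resource name, deployment name,
# and API version are assumptions/placeholders, not taken from the diff.
import dspy

lm = dspy.AzureOpenAI(
    api_base="https://<your-resource>.openai.azure.com/",
    api_version="2023-05-15",
    model="gpt-35-turbo",  # Azure deployment name (placeholder)
    model_type="chat",
    system_prompt="You are a concise assistant that answers in one sentence.",
)
dspy.settings.configure(lm=lm)

# With model_type="chat" and a system_prompt set, basic_request now sends
#   [{"role": "system", "content": system_prompt},
#    {"role": "user", "content": prompt}]
# instead of only the user message.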