diff --git a/src/django_ai_core/llm/base.py b/src/django_ai_core/llm/base.py index bb723aa..11d78f4 100644 --- a/src/django_ai_core/llm/base.py +++ b/src/django_ai_core/llm/base.py @@ -21,12 +21,10 @@ def service_id(self) -> str: return f"{self.__class__.__name__}:{self.client.PROVIDER_NAME}:{self.model}" def completion(self, messages, **kwargs): - return self.client.completion(model_id=self.model, messages=messages, **kwargs) + return self.client.completion(model=self.model, messages=messages, **kwargs) def responses(self, input_data, **kwargs): - return self.client.responses( - model_id=self.model, input_data=input_data, **kwargs - ) + return self.client.responses(model=self.model, input_data=input_data, **kwargs) def embedding(self, inputs, **kwargs): - return self.client._embedding(model_id=self.model, inputs=inputs, **kwargs) + return self.client._embedding(model=self.model, inputs=inputs, **kwargs) diff --git a/tests/testapp/indexes.py b/tests/testapp/indexes.py index 5e71897..43057ac 100644 --- a/tests/testapp/indexes.py +++ b/tests/testapp/indexes.py @@ -15,13 +15,13 @@ class MockAnyLLM(AnyLLM): - def completion(self, *, model_id, messages): + def completion(self, *, model, messages): return "completion" - def responses(self, *, model_id, input_data): + def responses(self, *, model, input_data): return "responses" - def _embedding(self, *, model_id, inputs): + def _embedding(self, *, model, inputs): return [0, 1, 2] diff --git a/tests/unit/llm/test_llm_service.py b/tests/unit/llm/test_llm_service.py index 84f0fa4..49819c4 100644 --- a/tests/unit/llm/test_llm_service.py +++ b/tests/unit/llm/test_llm_service.py @@ -25,9 +25,8 @@ def test_llm_service_completion_wraps_anyllm(mock_any_llm): ] service = LLMService(client=mock_any_llm, model="mock-model") service.completion(messages) - print(mock_any_llm.completion.call_args_list) mock_any_llm.completion.assert_called_once_with( - model_id="mock-model", messages=messages + model="mock-model", messages=messages ) @@ -36,7 +35,7 @@ def test_llm_service_responses_wraps_anyllm(mock_any_llm): service = LLMService(client=mock_any_llm, model="mock-model") service.responses(prompt) mock_any_llm.responses.assert_called_once_with( - model_id="mock-model", input_data=prompt + model="mock-model", input_data=prompt ) @@ -44,6 +43,4 @@ def test_llm_service_embedding_wraps_anyllm(mock_any_llm): prompt = "What is the airspeed velocity of an unladen swallow?" service = LLMService(client=mock_any_llm, model="mock-model") service.embedding(prompt) - mock_any_llm._embedding.assert_called_once_with( - model_id="mock-model", inputs=prompt - ) + mock_any_llm._embedding.assert_called_once_with(model="mock-model", inputs=prompt)