From 3f0c914743d669c19bf8d04b89ab1ad6ec572d55 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Wed, 5 Mar 2025 16:42:36 +0000 Subject: [PATCH 01/23] endpoint to get full workspace config + free --- src/codegate/api/v1.py | 54 +++++++-- src/codegate/db/connection.py | 9 +- src/codegate/workspaces/crud.py | 4 +- tests/api/test_v1_workspaces.py | 200 ++++++++++++++++++++++++++++++++ 4 files changed, 250 insertions(+), 17 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index bba6ab8eb..7206ff332 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -3,7 +3,7 @@ import requests import structlog -from fastapi import APIRouter, Depends, HTTPException, Response +from fastapi import APIRouter, Depends, HTTPException, Query, Response from fastapi.responses import StreamingResponse from fastapi.routing import APIRoute from pydantic import BaseModel, ValidationError @@ -12,7 +12,7 @@ from codegate import __version__ from codegate.api import v1_models, v1_processing from codegate.db.connection import AlreadyExistsError, DbReader -from codegate.db.models import AlertSeverity, WorkspaceWithModel +from codegate.db.models import AlertSeverity from codegate.providers import crud as provendcrud from codegate.workspaces import crud @@ -209,13 +209,32 @@ async def delete_provider_endpoint( @v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name) -async def list_workspaces() -> v1_models.ListWorkspacesResponse: - """List all workspaces.""" - wslist = await wscrud.get_workspaces() +async def list_workspaces( + provider_id: Optional[UUID] = Query(None), +) -> v1_models.ListWorkspacesResponse: + """ + List all workspaces. - resp = v1_models.ListWorkspacesResponse.from_db_workspaces_with_sessioninfo(wslist) + Args: + provider_id (Optional[UUID]): Filter workspaces by provider ID. If provided, + will return workspaces where models from the specified provider (e.g., OpenAI, + Anthropic) have been used in workspace muxing rules. Note that you must + refer to a provider by ID, not by name. - return resp + Returns: + ListWorkspacesResponse: A response object containing the list of workspaces. + """ + try: + if provider_id: + wslist = await wscrud.workspaces_by_provider(provider_id) + resp = v1_models.ListWorkspacesResponse.from_db_workspaces(wslist) + return resp + else: + wslist = await wscrud.get_workspaces() + resp = v1_models.ListWorkspacesResponse.from_db_workspaces_with_sessioninfo(wslist) + return resp + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) @v1.get("/workspaces/active", tags=["Workspaces"], generate_unique_id_function=uniq_name) @@ -584,17 +603,28 @@ async def set_workspace_muxes( @v1.get( - "/workspaces/{provider_id}", + "/workspaces/{workspace_name}", tags=["Workspaces"], generate_unique_id_function=uniq_name, ) -async def list_workspaces_by_provider( - provider_id: UUID, -) -> List[WorkspaceWithModel]: +async def get_workspace_by_name( + workspace_name: str, +) -> v1_models.FullWorkspace: """List workspaces by provider ID.""" try: - return await wscrud.workspaces_by_provider(provider_id) + ws = await wscrud.get_workspace_by_name(workspace_name) + muxes = await wscrud.get_muxes(workspace_name) + + return v1_models.FullWorkspace( + name=ws.name, + config=v1_models.WorkspaceConfig( + custom_instructions=ws.custom_instructions or "", + muxing_rules=muxes, + ), + ) + except crud.WorkspaceDoesNotExistError: + raise HTTPException(status_code=404, detail="Workspace does not exist") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 420f27e8a..5375febd7 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -35,7 +35,6 @@ ProviderModel, Session, WorkspaceRow, - WorkspaceWithModel, WorkspaceWithSessionInfo, ) from codegate.db.token_usage import TokenUsageParser @@ -820,11 +819,13 @@ async def get_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]: ) return workspaces[0] if workspaces else None - async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceWithModel]: + async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceRow]: sql = text( """ SELECT - w.id, w.name, m.provider_model_name + w.id, + w.name, + w.custom_instructions FROM workspaces w JOIN muxes m ON w.id = m.workspace_id WHERE m.provider_endpoint_id = :provider_id @@ -833,7 +834,7 @@ async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceWi ) conditions = {"provider_id": provider_id} workspaces = await self._exec_select_conditions_to_pydantic( - WorkspaceWithModel, sql, conditions, should_raise=True + WorkspaceRow, sql, conditions, should_raise=True ) return workspaces diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index fbaf5b994..508dd03d1 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -281,7 +281,9 @@ async def get_workspace_by_name(self, workspace_name: str) -> db_models.Workspac raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.") return workspace - async def workspaces_by_provider(self, provider_id: uuid) -> List[db_models.WorkspaceWithModel]: + async def workspaces_by_provider( + self, provider_id: uuid + ) -> List[db_models.WorkspaceWithSessionInfo]: """Get the workspaces by provider.""" workspaces = await self._db_reader.get_workspaces_by_provider(str(provider_id)) diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py index 8bfcbfaf3..a961fa7b3 100644 --- a/tests/api/test_v1_workspaces.py +++ b/tests/api/test_v1_workspaces.py @@ -70,6 +70,184 @@ def mock_pipeline_factory(): return mock_factory +@pytest.mark.asyncio +async def test_get_workspaces( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test getting all workspaces.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create a provider for muxing rules + provider_payload = { + "name": "test-provider", + "description": "", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xzy", + } + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + provider = response.json() + + # Create first workspace + name_1 = str(uuid()) + workspace_1 = { + "name": name_1, + "config": { + "custom_instructions": "Respond in haiku format", + "muxing_rules": [ + { + "provider_id": provider["id"], + "model": "foo-bar-001", + "matcher": "*.py", + "matcher_type": "filename_match", + } + ], + }, + } + response = await ac.post("/api/v1/workspaces", json=workspace_1) + assert response.status_code == 201 + + # Create second workspace + name_2 = str(uuid()) + workspace_2 = { + "name": name_2, + "config": { + "custom_instructions": "Respond in prose", + "muxing_rules": [ + { + "provider_id": provider["id"], + "model": "foo-bar-002", + "matcher": "*.js", + "matcher_type": "filename_match", + } + ], + }, + } + response = await ac.post("/api/v1/workspaces", json=workspace_2) + assert response.status_code == 201 + + response = await ac.get("/api/v1/workspaces") + assert response.status_code == 200 + workspaces = response.json()["workspaces"] + + # Verify response structure + assert isinstance(workspaces, list) + assert len(workspaces) >= 2 + + workspace_names = [w["name"] for w in workspaces] + assert name_1 in workspace_names + assert name_2 in workspace_names + assert len([n for n in workspace_names if n in [name_1, name_2]]) == 2 + + +@pytest.mark.asyncio +async def test_get_workspaces_filter_by_provider( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["foo-bar-001", "foo-bar-002"], + ), + ): + """Test filtering workspaces by provider ID.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create first provider + provider_payload_1 = { + "name": "provider-1", + "description": "", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + provider_1 = response.json() + + # Create second provider + provider_payload_2 = { + "name": "provider-2", + "description": "", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-456-xyz", + } + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + provider_2 = response.json() + + # Create workspace using provider 1 + workspace_1 = { + "name": str(uuid()), + "config": { + "custom_instructions": "Instructions 1", + "muxing_rules": [ + { + "provider_id": provider_1["id"], + "model": "foo-bar-001", + "matcher": "*.py", + "matcher_type": "filename_match", + } + ], + }, + } + response = await ac.post("/api/v1/workspaces", json=workspace_1) + assert response.status_code == 201 + + # Create workspace using provider 2 + workspace_2 = { + "name": str(uuid()), + "config": { + "custom_instructions": "Instructions 2", + "muxing_rules": [ + { + "provider_id": provider_2["id"], + "model": "foo-bar-002", + "matcher": "*.js", + "matcher_type": "filename_match", + } + ], + }, + } + response = await ac.post("/api/v1/workspaces", json=workspace_2) + assert response.status_code == 201 + + # Test filtering by provider 1 + response = await ac.get(f"/api/v1/workspaces?provider_id={provider_1['id']}") + assert response.status_code == 200 + workspaces = response.json()["workspaces"] + assert len(workspaces) == 1 + assert workspaces[0]["name"] == workspace_1["name"] + + # Test filtering by provider 2 + response = await ac.get(f"/api/v1/workspaces?provider_id={provider_2['id']}") + assert response.status_code == 200 + workspaces = response.json()["workspaces"] + assert len(workspaces) == 1 + assert workspaces[0]["name"] == workspace_2["name"] + + @pytest.mark.asyncio async def test_create_update_workspace_happy_path( mock_pipeline_factory, mock_workspace_crud, mock_provider_crud @@ -146,6 +324,10 @@ async def test_create_update_workspace_happy_path( response = await ac.post("/api/v1/workspaces", json=payload_create) assert response.status_code == 201 + + # Verify created workspace + response = await ac.get(f"/api/v1/workspaces/{name_1}") + assert response.status_code == 200 response_body = response.json() assert response_body["name"] == name_1 @@ -184,6 +366,10 @@ async def test_create_update_workspace_happy_path( response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) assert response.status_code == 201 + + # Verify updated workspace + response = await ac.get(f"/api/v1/workspaces/{name_2}") + assert response.status_code == 200 response_body = response.json() assert response_body["name"] == name_2 @@ -222,8 +408,15 @@ async def test_create_update_workspace_name_only( response = await ac.post("/api/v1/workspaces", json=payload_create) assert response.status_code == 201 response_body = response.json() + assert response_body["name"] == name_1 + # Verify created workspace + response = await ac.get(f"/api/v1/workspaces/{name_1}") + assert response.status_code == 200 + response_body = response.json() assert response_body["name"] == name_1 + assert response_body["config"]["custom_instructions"] == "" + assert response_body["config"]["muxing_rules"] == [] name_2: str = str(uuid()) @@ -234,8 +427,15 @@ async def test_create_update_workspace_name_only( response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) assert response.status_code == 201 response_body = response.json() + assert response_body["name"] == name_2 + # Verify updated workspace + response = await ac.get(f"/api/v1/workspaces/{name_2}") + assert response.status_code == 200 + response_body = response.json() assert response_body["name"] == name_2 + assert response_body["config"]["custom_instructions"] == "" + assert response_body["config"]["muxing_rules"] == [] @pytest.mark.asyncio From f2e2a2c85057a02121be5a1d68fd4a4ce5a8b04c Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Thu, 6 Mar 2025 13:08:59 +0000 Subject: [PATCH 02/23] add `provider_endpoint_type` to muxes table --- ...992_add_provider_endpoint_type_to_muxes.py | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py diff --git a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py new file mode 100644 index 000000000..f510dd181 --- /dev/null +++ b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py @@ -0,0 +1,97 @@ +"""add provider_endpoint_type to muxes + +Revision ID: 769f09b6d992 +Revises: 3ec2b4ab569c +Create Date: 2025-03-06 11:30:11.647216+00:00 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "769f09b6d992" +down_revision: Union[str, None] = "3ec2b4ab569c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + # Add the new column + op.execute( + """ + ALTER TABLE muxes + ADD COLUMN provider_endpoint_type TEXT; + """ + ) + + # Update the new column with data from provider_endpoints + op.execute( + """ + UPDATE muxes + SET provider_endpoint_type = ( + SELECT provider_type + FROM provider_endpoints + WHERE provider_endpoints.id = muxes.provider_endpoint_id + ); + """ + ) + + # Make the column NOT NULL after populating it + # SQLite is funny about altering columns, so we actually need to clone & + # swap the table + op.execute("CREATE TABLE muxes_new AS SELECT * FROM muxes;") + op.execute("DROP TABLE muxes;") + op.execute(""" + CREATE TABLE muxes ( + id TEXT PRIMARY KEY, + provider_endpoint_id TEXT NOT NULL, + provider_model_name TEXT NOT NULL, + workspace_id TEXT NOT NULL, + matcher_type TEXT NOT NULL, + matcher_blob TEXT NOT NULL, + priority INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + provider_endpoint_type TEXT NOT NULL, + FOREIGN KEY(provider_endpoint_id) REFERENCES provider_endpoints(id) + );""") + op.execute("INSERT INTO muxes SELECT * FROM muxes_new;") + op.execute("DROP TABLE muxes_new;") + + # Finish transaction + op.execute("COMMIT;") + + +def downgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + try: + # Check if the column exists + op.execute( + """ + SELECT provider_endpoint_type + FROM muxes + LIMIT 1; + """ + ) + + # Drop the column only if it exists + op.execute( + """ + ALTER TABLE muxes + DROP COLUMN provider_endpoint_type; + """ + ) + except Exception: + # If there's an error (column doesn't exist), rollback and continue + op.execute("ROLLBACK;") + return + + # Finish transaction + op.execute("COMMIT;") From 1bce00b954a81e1c570dfaac40fad48c7218fce0 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Thu, 6 Mar 2025 14:52:06 +0000 Subject: [PATCH 03/23] add `provider_endpoint_name` to muxes table --- ...da6_add_provider_endpoint_name_to_muxes.py | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py diff --git a/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py b/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py new file mode 100644 index 000000000..397bb924f --- /dev/null +++ b/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py @@ -0,0 +1,98 @@ +"""add provider_endpoint_name to muxes + +Revision ID: 4b81c45b5da6 +Revises: 769f09b6d992 +Create Date: 2025-03-06 13:24:41.123857+00:00 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "4b81c45b5da6" +down_revision: Union[str, None] = "769f09b6d992" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + # Add the new column + op.execute( + """ + ALTER TABLE muxes + ADD COLUMN provider_endpoint_name TEXT; + """ + ) + + # Update the new column with data from provider_endpoints + op.execute( + """ + UPDATE muxes + SET provider_endpoint_name = ( + SELECT name + FROM provider_endpoints + WHERE provider_endpoints.id = muxes.provider_endpoint_id + ); + """ + ) + + # Make the column NOT NULL after populating it + # SQLite is funny about altering columns, so we actually need to clone & + # swap the table + op.execute("CREATE TABLE muxes_new AS SELECT * FROM muxes;") + op.execute("DROP TABLE muxes;") + op.execute(""" + CREATE TABLE muxes ( + id TEXT PRIMARY KEY, + provider_endpoint_id TEXT NOT NULL, + provider_model_name TEXT NOT NULL, + workspace_id TEXT NOT NULL, + matcher_type TEXT NOT NULL, + matcher_blob TEXT NOT NULL, + priority INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + provider_endpoint_type TEXT NOT NULL, + provider_endpoint_name TEXT NOT NULL, + FOREIGN KEY(provider_endpoint_id) REFERENCES provider_endpoints(id) + );""") + op.execute("INSERT INTO muxes SELECT * FROM muxes_new;") + op.execute("DROP TABLE muxes_new;") + + # Finish transaction + op.execute("COMMIT;") + + +def downgrade() -> None: + # Begin transaction + op.execute("BEGIN TRANSACTION;") + + try: + # Check if the column exists + op.execute( + """ + SELECT provider_endpoint_name + FROM muxes + LIMIT 1; + """ + ) + + # Drop the column only if it exists + op.execute( + """ + ALTER TABLE muxes + DROP COLUMN provider_endpoint_name; + """ + ) + except Exception: + # If there's an error (column doesn't exist), rollback and continue + op.execute("ROLLBACK;") + return + + # Finish transaction + op.execute("COMMIT;") From a96dbb3d565d82c1bb93fbb24c362f1fd41f0f5c Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Thu, 6 Mar 2025 22:51:47 +0000 Subject: [PATCH 04/23] allow mux CRUD without knowledge of provider IDs --- src/codegate/api/v1.py | 53 +++- src/codegate/api/v1_models.py | 12 +- src/codegate/db/connection.py | 98 +++++-- src/codegate/db/models.py | 8 + src/codegate/muxing/models.py | 45 +++- src/codegate/providers/crud/crud.py | 17 +- src/codegate/workspaces/crud.py | 59 +++-- tests/api/test_v1_providers.py | 384 ++++++++++++++++++++++++++++ tests/muxing/test_rulematcher.py | 11 +- 9 files changed, 620 insertions(+), 67 deletions(-) create mode 100644 tests/api/test_v1_providers.py diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index ce470def7..2a0617076 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -36,6 +36,25 @@ def uniq_name(route: APIRoute): return f"v1_{route.name}" +async def _add_provider_id_to_mux_rule( + mux_rule: mux_models.MuxRule, +) -> mux_models.MuxRuleWithProviderId: + """ + Convert a `MuxRule` to `MuxRuleWithProviderId` by looking up the provider ID. + Extracts provider name and type from the MuxRule and queries the database to get the ID. + """ + provider = await dbreader.try_get_provider_endpoint_by_name_and_type( + mux_rule.provider_name, + mux_rule.provider_type, + ) + if provider is None: + raise crud.WorkspaceCrudError( + f'Provider "{mux_rule.provider_name}" of type "{mux_rule.provider_type}" not found' # noqa: E501 + ) + + return mux_models.MuxRuleWithProviderId(**mux_rule.model_dump(), provider_id=provider.id) + + class FilterByNameParams(BaseModel): name: Optional[str] = None @@ -94,14 +113,14 @@ async def list_models_by_provider( @v1.get( - "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name + "/provider-endpoints/{provider_name}", tags=["Providers"], generate_unique_id_function=uniq_name ) async def get_provider_endpoint( - provider_id: UUID, + provider_name: str, ) -> v1_models.ProviderEndpoint: - """Get a provider endpoint by ID.""" + """Get a provider endpoint by name.""" try: - provend = await pcrud.get_endpoint_by_id(provider_id) + provend = await pcrud.get_endpoint_by_name(provider_name) except Exception: raise HTTPException(status_code=500, detail="Internal server error") @@ -278,7 +297,11 @@ async def create_workspace( """Create a new workspace.""" try: custom_instructions = request.config.custom_instructions if request.config else None - muxing_rules = request.config.muxing_rules if request.config else None + muxing_rules: List[mux_models.MuxRuleWithProviderId] = [] + if request.config and request.config.muxing_rules: + for rule in request.config.muxing_rules: + mux_rule_with_provider = await _add_provider_id_to_mux_rule(rule) + muxing_rules.append(mux_rule_with_provider) workspace_row, mux_rules = await wscrud.add_workspace( request.name, custom_instructions, muxing_rules @@ -320,7 +343,11 @@ async def update_workspace( """Update a workspace.""" try: custom_instructions = request.config.custom_instructions if request.config else None - muxing_rules = request.config.muxing_rules if request.config else None + muxing_rules: List[mux_models.MuxRuleWithProviderId] = [] + if request.config and request.config.muxing_rules: + for rule in request.config.muxing_rules: + mux_rule_with_provider = await _add_provider_id_to_mux_rule(rule) + muxing_rules.append(mux_rule_with_provider) workspace_row, mux_rules = await wscrud.update_workspace( workspace_name, @@ -581,7 +608,7 @@ async def get_workspace_muxes( logger.exception("Error while getting workspace") raise HTTPException(status_code=500, detail="Internal server error") - return muxes + return [mux_models.MuxRule.from_mux_rule_with_provider_id(mux) for mux in muxes] @v1.put( @@ -596,7 +623,12 @@ async def set_workspace_muxes( ): """Set the mux rules of a workspace.""" try: - await wscrud.set_muxes(workspace_name, request) + mux_rules = [] + for rule in request: + mux_rule_with_provider = await _add_provider_id_to_mux_rule(rule) + mux_rules.append(mux_rule_with_provider) + + await wscrud.set_muxes(workspace_name, mux_rules) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except crud.WorkspaceCrudError as e: @@ -619,7 +651,10 @@ async def get_workspace_by_name( """List workspaces by provider ID.""" try: ws = await wscrud.get_workspace_by_name(workspace_name) - muxes = await wscrud.get_muxes(workspace_name) + muxes = [ + mux_models.MuxRule.from_mux_rule_with_provider_id(mux) + for mux in await wscrud.get_muxes(workspace_name) + ] return v1_models.FullWorkspace( name=ws.name, diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index dff26489e..bc7eaecec 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -256,19 +256,24 @@ class ProviderEndpoint(pydantic.BaseModel): id: Optional[str] = "" name: str description: str = "" - provider_type: db_models.ProviderType + provider_type: str endpoint: str = "" # Some providers have defaults we can leverage auth_type: ProviderAuthType = ProviderAuthType.none @staticmethod def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint": + auth_type = ( + ProviderAuthType.none + if not db_model.auth_type + else ProviderAuthType(db_model.auth_type) + ) return ProviderEndpoint( id=db_model.id, name=db_model.name, description=db_model.description, provider_type=db_model.provider_type, endpoint=db_model.endpoint, - auth_type=db_model.auth_type, + auth_type=auth_type, ) def to_db_model(self) -> db_models.ProviderEndpoint: @@ -309,8 +314,9 @@ class ModelByProvider(pydantic.BaseModel): Note that these are auto-discovered by the provider. """ - name: str provider_id: str + name: str + provider_type: db_models.ProviderType provider_name: str def __str__(self): diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 02158d3f7..97302ee08 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -33,6 +33,7 @@ ProviderAuthMaterial, ProviderEndpoint, ProviderModel, + ProviderModelIntermediate, Session, WorkspaceRow, WorkspaceWithSessionInfo, @@ -494,7 +495,9 @@ async def push_provider_auth_material(self, auth_material: ProviderAuthMaterial) _ = await self._execute_update_pydantic_model(auth_material, sql, should_raise=True) return - async def add_provider_model(self, model: ProviderModel) -> ProviderModel: + async def add_provider_model( + self, model: ProviderModelIntermediate + ) -> ProviderModelIntermediate: sql = text( """ INSERT INTO provider_models (provider_endpoint_id, name) @@ -533,11 +536,13 @@ async def add_mux(self, mux: MuxRule) -> MuxRule: """ INSERT INTO muxes ( id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type, - matcher_blob, priority, created_at, updated_at + matcher_blob, priority, created_at, updated_at, + provider_endpoint_type, provider_endpoint_name ) VALUES ( :id, :provider_endpoint_id, :provider_model_name, :workspace_id, - :matcher_type, :matcher_blob, :priority, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP + :matcher_type, :matcher_blob, :priority, CURRENT_TIMESTAMP, + CURRENT_TIMESTAMP, :provider_endpoint_type, :provider_endpoint_name ) RETURNING * """ @@ -709,10 +714,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( # If trigger category is None we want to get all alerts trigger_category = trigger_category if trigger_category else "%" conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True - ) + rows: List[ + IntermediatePromptWithOutputUsageAlerts + ] = await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in rows: @@ -922,11 +927,64 @@ async def get_provider_endpoint_by_name(self, provider_name: str) -> Optional[Pr ) return provider[0] if provider else None - async def get_provider_endpoint_by_id(self, provider_id: str) -> Optional[ProviderEndpoint]: + async def try_get_provider_endpoint_by_name_and_type( + self, provider_name: str, provider_type: str + ) -> Optional[ProviderEndpoint]: + """ + Best effort attempt to find a provider endpoint matching name and type. + + With shareable workspaces, a user may share a workspace with mux rules + that refer to a provider name & type. + + Another user may want to consume those rules, but may not have the exact + same provider names configured. + + This makes the shareable workspace feature a little more robust. + """ + # First try exact match on both name and type + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at + FROM provider_endpoints + WHERE name = :name AND provider_type = :provider_type + LIMIT 1 + """ + ) + conditions = {"name": provider_name, "provider_type": provider_type} + provider = await self._exec_select_conditions_to_pydantic( + ProviderEndpoint, sql, conditions, should_raise=True + ) + if provider: + logger.debug( + f'Found provider "{provider[0].name}" by name "{provider_name}" and type "{provider_type}"' # noqa: E501 + ) + return provider[0] + + # If no exact match, try matching just provider_type sql = text( """ SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at FROM provider_endpoints + WHERE provider_type = :provider_type + LIMIT 1 + """ + ) + conditions = {"provider_type": provider_type} + provider = await self._exec_select_conditions_to_pydantic( + ProviderEndpoint, sql, conditions, should_raise=True + ) + if provider: + logger.debug( + f'Found provider "{provider[0].name}" by type {provider_type}. Name "{provider_name}" did not match any providers.' # noqa: E501 + ) + return provider[0] + return None + + async def get_provider_endpoint_by_id(self, provider_id: str) -> Optional[ProviderEndpoint]: + sql = text( + """ + SELECT id, name, description, provider_type, endpoint, auth_type + FROM provider_endpoints WHERE id = :id """ ) @@ -965,10 +1023,11 @@ async def get_provider_endpoints(self) -> List[ProviderEndpoint]: async def get_provider_models_by_provider_id(self, provider_id: str) -> List[ProviderModel]: sql = text( """ - SELECT provider_endpoint_id, name - FROM provider_models - WHERE provider_endpoint_id = :provider_endpoint_id - """ + SELECT pm.provider_endpoint_id, pm.name, pe.name as provider_endpoint_name, pe.provider_type as provider_endpoint_type + FROM provider_models pm + INNER JOIN provider_endpoints pe ON pm.provider_endpoint_id = pe.id + WHERE pm.provider_endpoint_id = :provider_endpoint_id + """ # noqa: E501 ) conditions = {"provider_endpoint_id": provider_id} models = await self._exec_select_conditions_to_pydantic( @@ -981,10 +1040,11 @@ async def get_provider_model_by_provider_id_and_name( ) -> Optional[ProviderModel]: sql = text( """ - SELECT provider_endpoint_id, name - FROM provider_models - WHERE provider_endpoint_id = :provider_endpoint_id AND name = :name - """ + SELECT pm.provider_endpoint_id, pm.name, pe.name as provider_endpoint_name, pe.provider_type as provider_endpoint_type + FROM provider_models pm + INNER JOIN provider_endpoints pe ON pm.provider_endpoint_id = pe.id + WHERE pm.provider_endpoint_id = :provider_endpoint_id AND pm.name = :name + """ # noqa: E501 ) conditions = {"provider_endpoint_id": provider_id, "name": model_name} models = await self._exec_select_conditions_to_pydantic( @@ -995,7 +1055,8 @@ async def get_provider_model_by_provider_id_and_name( async def get_all_provider_models(self) -> List[ProviderModel]: sql = text( """ - SELECT pm.provider_endpoint_id, pm.name, pe.name as provider_endpoint_name + SELECT pm.provider_endpoint_id, pm.name, pe.name as + provider_endpoint_name, pe.provider_type as provider_endpoint_type FROM provider_models pm INNER JOIN provider_endpoints pe ON pm.provider_endpoint_id = pe.id """ @@ -1007,7 +1068,8 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]: sql = text( """ SELECT id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type, - matcher_blob, priority, created_at, updated_at + matcher_blob, priority, created_at, updated_at, + provider_endpoint_type, provider_endpoint_name FROM muxes WHERE workspace_id = :workspace_id ORDER BY priority ASC diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 6f146b34b..b0bc2282c 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -225,8 +225,14 @@ class ProviderAuthMaterial(BaseModel): auth_blob: str +class ProviderModelIntermediate(BaseModel): + provider_endpoint_id: str + name: str + + class ProviderModel(BaseModel): provider_endpoint_id: str + provider_endpoint_type: str provider_endpoint_name: Optional[str] = None name: str @@ -234,6 +240,8 @@ class ProviderModel(BaseModel): class MuxRule(BaseModel): id: str provider_endpoint_id: str + provider_endpoint_type: ProviderType + provider_endpoint_name: str provider_model_name: str workspace_id: str matcher_type: str diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py index 5637c5b8c..a2aefe943 100644 --- a/src/codegate/muxing/models.py +++ b/src/codegate/muxing/models.py @@ -5,6 +5,7 @@ from codegate.clients.clients import ClientType from codegate.db.models import MuxRule as DBMuxRule +from codegate.db.models import ProviderType class MuxMatcherType(str, Enum): @@ -39,9 +40,8 @@ class MuxRule(pydantic.BaseModel): Represents a mux rule for a provider. """ - # Used for exportable workspaces - provider_name: Optional[str] = None - provider_id: str + provider_name: str + provider_type: ProviderType model: str # The type of matcher to use matcher_type: MuxMatcherType @@ -54,13 +54,46 @@ def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self: """ Convert a DBMuxRule to a MuxRule. """ - return MuxRule( - provider_id=db_mux_rule.id, + return cls( + provider_name=db_mux_rule.provider_endpoint_name, + provider_type=db_mux_rule.provider_endpoint_type, model=db_mux_rule.provider_model_name, - matcher_type=db_mux_rule.matcher_type, + matcher_type=MuxMatcherType(db_mux_rule.matcher_type), matcher=db_mux_rule.matcher_blob, ) + @classmethod + def from_mux_rule_with_provider_id(cls, rule: "MuxRuleWithProviderId") -> Self: + """ + Convert a MuxRuleWithProviderId to a MuxRule. + """ + return cls( + provider_name=rule.provider_name, + provider_type=rule.provider_type, + model=rule.model, + matcher_type=rule.matcher_type, + matcher=rule.matcher, + ) + + +class MuxRuleWithProviderId(MuxRule): + """ + Represents a mux rule for a provider with provider ID. + Used internally for referring to a mux rule. + """ + + provider_id: str + + @classmethod + def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self: + """ + Convert a DBMuxRule to a MuxRuleWithProviderId. + """ + return cls( + **MuxRule.from_db_mux_rule(db_mux_rule).model_dump(), + provider_id=db_mux_rule.provider_endpoint_id, + ) + class ThingToMatchMux(pydantic.BaseModel): """ diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index 8bba52b87..fed6b0058 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -13,6 +13,7 @@ from codegate.providers.base import BaseProvider from codegate.providers.registry import ProviderRegistry, get_provider_registry from codegate.workspaces import crud as workspace_crud +from src.codegate.db.models import ProviderType logger = structlog.get_logger("codegate") @@ -114,9 +115,9 @@ async def add_endpoint( for model in models: await self._db_writer.add_provider_model( - dbmodels.ProviderModel( - provider_endpoint_id=dbendpoint.id, + dbmodels.ProviderModelIntermediate( name=model, + provider_endpoint_id=dbendpoint.id, ) ) return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) @@ -236,9 +237,9 @@ async def _update_models_for_provider( # Add the models that are in the provider but not in the DB for model in models_set - models_in_db_set: await self._db_writer.add_provider_model( - dbmodels.ProviderModel( - provider_endpoint_id=dbendpoint.id, + dbmodels.ProviderModelIntermediate( name=model, + provider_endpoint_id=dbendpoint.id, ) ) @@ -274,8 +275,8 @@ async def models_by_provider(self, provider_id: UUID) -> List[apimodelsv1.ModelB outmodels.append( apimodelsv1.ModelByProvider( name=dbmodel.name, - provider_id=dbmodel.provider_endpoint_id, provider_name=dbendpoint.name, + provider_type=dbendpoint.provider_type, ) ) @@ -290,9 +291,10 @@ async def get_all_models(self) -> List[apimodelsv1.ModelByProvider]: ename = dbmodel.provider_endpoint_name if dbmodel.provider_endpoint_name else "" outmodels.append( apimodelsv1.ModelByProvider( - name=dbmodel.name, provider_id=dbmodel.provider_endpoint_id, + name=dbmodel.name, provider_name=ename, + provider_type=dbmodel.provider_endpoint_type, ) ) @@ -383,6 +385,8 @@ async def try_initialize_provider_endpoints( dbmodels.ProviderModel( provider_endpoint_id=provend.id, name=model, + provider_endpoint_type=provend.provider_type, + provider_endpoint_name=provend.name, ) ) ) @@ -393,7 +397,6 @@ async def try_initialize_provider_endpoints( async def try_update_to_provider( provcrud: ProviderCrud, prov: BaseProvider, dbprovend: dbmodels.ProviderEndpoint ): - authm = await provcrud._db_reader.get_auth_material_by_provider_id(str(dbprovend.id)) try: diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 508dd03d1..a4cd803bd 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -2,11 +2,15 @@ from typing import List, Optional, Tuple from uuid import uuid4 as uuid +import structlog + from codegate.db import models as db_models from codegate.db.connection import AlreadyExistsError, DbReader, DbRecorder, DbTransaction from codegate.muxing import models as mux_models from codegate.muxing import rulematcher +logger = structlog.get_logger("codegate") + class WorkspaceCrudError(Exception): pass @@ -43,7 +47,7 @@ async def add_workspace( self, new_workspace_name: str, custom_instructions: Optional[str] = None, - muxing_rules: Optional[List[mux_models.MuxRule]] = None, + muxing_rules: Optional[List[mux_models.MuxRuleWithProviderId]] = None, ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]: """ Add a workspace @@ -51,7 +55,7 @@ async def add_workspace( Args: new_workspace_name (str): The name of the workspace system_prompt (Optional[str]): The system prompt for the workspace - muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace + muxing_rules (Optional[List[mux_models.MuxRuleWithProviderId]]): The muxing rules for the workspace """ if new_workspace_name == "": raise WorkspaceCrudError("Workspace name cannot be empty.") @@ -92,7 +96,7 @@ async def update_workspace( old_workspace_name: str, new_workspace_name: str, custom_instructions: Optional[str] = None, - muxing_rules: Optional[List[mux_models.MuxRule]] = None, + muxing_rules: Optional[List[mux_models.MuxRuleWithProviderId]] = None, ) -> Tuple[db_models.WorkspaceRow, List[db_models.MuxRule]]: """ Update a workspace @@ -101,8 +105,8 @@ async def update_workspace( old_workspace_name (str): The old name of the workspace new_workspace_name (str): The new name of the workspace system_prompt (Optional[str]): The system prompt for the workspace - muxing_rules (Optional[List[mux_models.MuxRule]]): The muxing rules for the workspace - """ + muxing_rules (Optional[List[mux_models.MuxRuleWithProviderId]]): The muxing rules for the workspace + """ # noqa: E501 if new_workspace_name == "": raise WorkspaceCrudError("Workspace name cannot be empty.") if old_workspace_name == "": @@ -111,8 +115,6 @@ async def update_workspace( raise WorkspaceCrudError("Cannot rename default workspace.") if new_workspace_name in RESERVED_WORKSPACE_KEYWORDS: raise WorkspaceCrudError(f"Workspace name {new_workspace_name} is reserved.") - if old_workspace_name == new_workspace_name: - raise WorkspaceCrudError("Old and new workspace names are the same.") async with DbTransaction() as transaction: try: @@ -122,11 +124,12 @@ async def update_workspace( f"Workspace {old_workspace_name} does not exist." ) - existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name) - if existing_ws: - raise WorkspaceNameAlreadyInUseError( - f"Workspace name {new_workspace_name} is already in use." - ) + if old_workspace_name != new_workspace_name: + existing_ws = await self._db_reader.get_workspace_by_name(new_workspace_name) + if existing_ws: + raise WorkspaceNameAlreadyInUseError( + f"Workspace name {new_workspace_name} is already in use." + ) new_ws = db_models.WorkspaceRow( id=ws.id, name=new_workspace_name, custom_instructions=ws.custom_instructions @@ -143,7 +146,7 @@ async def update_workspace( await transaction.commit() return workspace_renamed, mux_rules - except (WorkspaceNameAlreadyInUseError, WorkspaceDoesNotExistError) as e: + except WorkspaceDoesNotExistError as e: raise e except Exception as e: raise WorkspaceCrudError(f"Error updating workspace {old_workspace_name}: {str(e)}") @@ -290,7 +293,7 @@ async def workspaces_by_provider( return workspaces - async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]: + async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRuleWithProviderId]: # Verify if workspace exists workspace = await self._db_reader.get_workspace_by_name(workspace_name) if not workspace: @@ -302,7 +305,9 @@ async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]: # These are already sorted by priority for dbmux in dbmuxes: muxes.append( - mux_models.MuxRule( + mux_models.MuxRuleWithProviderId( + provider_name=dbmux.provider_endpoint_name, + provider_type=dbmux.provider_endpoint_type, provider_id=dbmux.provider_endpoint_id, model=dbmux.provider_model_name, matcher_type=dbmux.matcher_type, @@ -313,7 +318,7 @@ async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRule]: return muxes async def set_muxes( - self, workspace_name: str, muxes: List[mux_models.MuxRule] + self, workspace_name: str, muxes: List[mux_models.MuxRuleWithProviderId] ) -> List[db_models.MuxRule]: # Verify if workspace exists workspace = await self._db_reader.get_workspace_by_name(workspace_name) @@ -326,7 +331,9 @@ async def set_muxes( # Add the new muxes priority = 0 - muxes_with_routes: List[Tuple[mux_models.MuxRule, rulematcher.ModelRoute]] = [] + muxes_with_routes: List[ + Tuple[mux_models.MuxRuleWithProviderId, rulematcher.ModelRoute] + ] = [] # Verify all models are valid for mux in muxes: @@ -340,6 +347,8 @@ async def set_muxes( new_mux = db_models.MuxRule( id=str(uuid()), provider_endpoint_id=mux.provider_id, + provider_endpoint_type=mux.provider_type, + provider_endpoint_name=mux.provider_name, provider_model_name=mux.model, workspace_id=workspace.id, matcher_type=mux.matcher_type, @@ -359,7 +368,9 @@ async def set_muxes( return dbmuxes - async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.ModelRoute: + async def get_routing_for_mux( + self, mux: mux_models.MuxRuleWithProviderId + ) -> rulematcher.ModelRoute: """Get the routing for a mux Note that this particular mux object is the API model, not the database model. @@ -367,7 +378,7 @@ async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.Mode """ dbprov = await self._db_reader.get_provider_endpoint_by_id(mux.provider_id) if not dbprov: - raise WorkspaceCrudError(f"Provider {mux.provider_id} does not exist") + raise WorkspaceCrudError(f'Provider "{mux.provider_name}" does not exist') dbm = await self._db_reader.get_provider_model_by_provider_id_and_name( mux.provider_id, @@ -375,11 +386,13 @@ async def get_routing_for_mux(self, mux: mux_models.MuxRule) -> rulematcher.Mode ) if not dbm: raise WorkspaceCrudError( - f"Model {mux.model} does not exist for provider {mux.provider_id}" + f'Model "{mux.model}" does not exist for provider "{mux.provider_name}"' ) dbauth = await self._db_reader.get_auth_material_by_provider_id(mux.provider_id) if not dbauth: - raise WorkspaceCrudError(f"Auth material for provider {mux.provider_id} does not exist") + raise WorkspaceCrudError( + f'Auth material for provider "{mux.provider_name}" does not exist' + ) return rulematcher.ModelRoute( endpoint=dbprov, @@ -395,7 +408,7 @@ async def get_routing_for_db_mux(self, mux: db_models.MuxRule) -> rulematcher.Mo """ dbprov = await self._db_reader.get_provider_endpoint_by_id(mux.provider_endpoint_id) if not dbprov: - raise WorkspaceCrudError(f"Provider {mux.provider_endpoint_id} does not exist") + raise WorkspaceCrudError(f'Provider "{mux.provider_endpoint_name}" does not exist') dbm = await self._db_reader.get_provider_model_by_provider_id_and_name( mux.provider_endpoint_id, @@ -409,7 +422,7 @@ async def get_routing_for_db_mux(self, mux: db_models.MuxRule) -> rulematcher.Mo dbauth = await self._db_reader.get_auth_material_by_provider_id(mux.provider_endpoint_id) if not dbauth: raise WorkspaceCrudError( - f"Auth material for provider {mux.provider_endpoint_id} does not exist" + f'Auth material for provider "{mux.provider_endpoint_name}" does not exist' ) return rulematcher.ModelRoute( diff --git a/tests/api/test_v1_providers.py b/tests/api/test_v1_providers.py new file mode 100644 index 000000000..a4bceec09 --- /dev/null +++ b/tests/api/test_v1_providers.py @@ -0,0 +1,384 @@ +from pathlib import Path +from unittest.mock import MagicMock, patch +from uuid import uuid4 as uuid + +import httpx +import pytest +import structlog +from httpx import AsyncClient + +from codegate.db import connection +from codegate.pipeline.factory import PipelineFactory +from codegate.providers.crud.crud import ProviderCrud +from codegate.server import init_app +from codegate.workspaces.crud import WorkspaceCrud + +logger = structlog.get_logger("codegate") + +# TODO: Abstract the mock DB setup + + +@pytest.fixture +def db_path(): + """Creates a temporary database file path.""" + current_test_dir = Path(__file__).parent + db_filepath = current_test_dir / f"codegate_test_{uuid()}.db" + db_fullpath = db_filepath.absolute() + connection.init_db_sync(str(db_fullpath)) + yield db_fullpath + if db_fullpath.is_file(): + db_fullpath.unlink() + + +@pytest.fixture() +def db_recorder(db_path) -> connection.DbRecorder: + """Creates a DbRecorder instance with test database.""" + return connection.DbRecorder(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def db_reader(db_path) -> connection.DbReader: + """Creates a DbReader instance with test database.""" + return connection.DbReader(sqlite_path=db_path, _no_singleton=True) + + +@pytest.fixture() +def mock_workspace_crud(db_recorder, db_reader) -> WorkspaceCrud: + """Creates a WorkspaceCrud instance with test database.""" + ws_crud = WorkspaceCrud() + ws_crud._db_reader = db_reader + ws_crud._db_recorder = db_recorder + return ws_crud + + +@pytest.fixture() +def mock_provider_crud(db_recorder, db_reader, mock_workspace_crud) -> ProviderCrud: + """Creates a ProviderCrud instance with test database.""" + p_crud = ProviderCrud() + p_crud._db_reader = db_reader + p_crud._db_writer = db_recorder + p_crud._ws_crud = mock_workspace_crud + return p_crud + + +@pytest.fixture +def mock_pipeline_factory(): + """Create a mock pipeline factory.""" + mock_factory = MagicMock(spec=PipelineFactory) + mock_factory.create_input_pipeline.return_value = MagicMock() + mock_factory.create_fim_pipeline.return_value = MagicMock() + mock_factory.create_output_pipeline.return_value = MagicMock() + mock_factory.create_fim_output_pipeline.return_value = MagicMock() + return mock_factory + + +@pytest.mark.asyncio +async def test_providers_crud( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test creating multiple providers and listing them.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create first provider (OpenAI) + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + provider1_response = response.json() + assert provider1_response["name"] == provider_payload_1["name"] + assert provider1_response["description"] == provider_payload_1["description"] + assert provider1_response["auth_type"] == provider_payload_1["auth_type"] + assert provider1_response["provider_type"] == provider_payload_1["provider_type"] + assert provider1_response["endpoint"] == provider_payload_1["endpoint"] + assert isinstance(provider1_response.get("id", ""), str) and provider1_response["id"] + + # Create second provider (OpenRouter) + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + provider2_response = response.json() + assert provider2_response["name"] == provider_payload_2["name"] + assert provider2_response["description"] == provider_payload_2["description"] + assert provider2_response["auth_type"] == provider_payload_2["auth_type"] + assert provider2_response["provider_type"] == provider_payload_2["provider_type"] + assert provider2_response["endpoint"] == provider_payload_2["endpoint"] + assert isinstance(provider2_response.get("id", ""), str) and provider2_response["id"] + + # List all providers + response = await ac.get("/api/v1/provider-endpoints") + assert response.status_code == 200 + providers = response.json() + + # Verify both providers exist in the list + assert isinstance(providers, list) + assert len(providers) == 2 + + # Verify fields for first provider + provider1 = next(p for p in providers if p["name"] == "openai-provider") + assert provider1["description"] == provider_payload_1["description"] + assert provider1["auth_type"] == provider_payload_1["auth_type"] + assert provider1["provider_type"] == provider_payload_1["provider_type"] + assert provider1["endpoint"] == provider_payload_1["endpoint"] + assert isinstance(provider1.get("id", ""), str) and provider1["id"] + + # Verify fields for second provider + provider2 = next(p for p in providers if p["name"] == "openrouter-provider") + assert provider2["description"] == provider_payload_2["description"] + assert provider2["auth_type"] == provider_payload_2["auth_type"] + assert provider2["provider_type"] == provider_payload_2["provider_type"] + assert provider2["endpoint"] == provider_payload_2["endpoint"] + assert isinstance(provider2.get("id", ""), str) and provider2["id"] + + # Get OpenAI provider by name + response = await ac.get("/api/v1/provider-endpoints/openai-provider") + assert response.status_code == 200 + provider = response.json() + assert provider["name"] == provider_payload_1["name"] + assert provider["description"] == provider_payload_1["description"] + assert provider["auth_type"] == provider_payload_1["auth_type"] + assert provider["provider_type"] == provider_payload_1["provider_type"] + assert provider["endpoint"] == provider_payload_1["endpoint"] + assert isinstance(provider["id"], str) and provider["id"] + + # Get OpenRouter provider by name + response = await ac.get("/api/v1/provider-endpoints/openrouter-provider") + assert response.status_code == 200 + provider = response.json() + assert provider["name"] == provider_payload_2["name"] + assert provider["description"] == provider_payload_2["description"] + assert provider["auth_type"] == provider_payload_2["auth_type"] + assert provider["provider_type"] == provider_payload_2["provider_type"] + assert provider["endpoint"] == provider_payload_2["endpoint"] + assert isinstance(provider["id"], str) and provider["id"] + + # Test getting non-existent provider + response = await ac.get("/api/v1/provider-endpoints/non-existent") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + +@pytest.mark.asyncio +async def test_list_providers_by_name( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test creating multiple providers and listing them by name.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create first provider (OpenAI) + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + # Create second provider (OpenRouter) + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Test querying providers by name + response = await ac.get("/api/v1/provider-endpoints?name=openai-provider") + assert response.status_code == 200 + providers = response.json() + assert len(providers) == 1 + assert providers[0]["name"] == "openai-provider" + assert isinstance(providers[0]["id"], str) and providers[0]["id"] + + response = await ac.get("/api/v1/provider-endpoints?name=openrouter-provider") + assert response.status_code == 200 + providers = response.json() + assert len(providers) == 1 + assert providers[0]["name"] == "openrouter-provider" + assert isinstance(providers[0]["id"], str) and providers[0]["id"] + + +@pytest.mark.asyncio +async def test_list_all_provider_models( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test listing all models from all providers.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create OpenAI provider + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + # Create OpenRouter provider + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Get all models + response = await ac.get("/api/v1/provider-endpoints/models") + assert response.status_code == 200 + models = response.json() + + # Verify response structure and content + assert isinstance(models, list) + assert len(models) == 4 + + # Verify models list structure + assert all(isinstance(model, dict) for model in models) + assert all("name" in model for model in models) + assert all("provider_type" in model for model in models) + assert all("provider_name" in model for model in models) + + # Verify OpenAI provider models + openai_models = [m for m in models if m["provider_name"] == "openai-provider"] + assert len(openai_models) == 2 + assert all(m["provider_type"] == "openai" for m in openai_models) + + # Verify OpenRouter provider models + openrouter_models = [m for m in models if m["provider_name"] == "openrouter-provider"] + assert len(openrouter_models) == 2 + assert all(m["provider_type"] == "openrouter" for m in openrouter_models) + + +@pytest.mark.asyncio +async def test_list_models_by_provider( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test listing models for a specific provider.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create OpenAI provider + provider_payload = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + provider = response.json() + provider_id = provider["id"] + + # Get models for the provider + response = await ac.get(f"/api/v1/provider-endpoints/{provider_id}/models") + assert response.status_code == 200 + models = response.json() + + # Verify response structure and content + assert isinstance(models, list) + assert len(models) == 2 + assert all(isinstance(model, dict) for model in models) + assert all("name" in model for model in models) + assert all("provider_type" in model for model in models) + assert all("provider_name" in model for model in models) + assert all(model["provider_type"] == "openai" for model in models) + assert all(model["provider_name"] == "openai-provider" for model in models) + + # Test with non-existent provider ID + fake_uuid = str(uuid()) + response = await ac.get(f"/api/v1/provider-endpoints/{fake_uuid}/models") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider not found" diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py index 7e551525c..6feec7cb4 100644 --- a/tests/muxing/test_rulematcher.py +++ b/tests/muxing/test_rulematcher.py @@ -8,7 +8,10 @@ mocked_route_openai = rulematcher.ModelRoute( db_models.ProviderModel( - provider_endpoint_id="1", provider_endpoint_name="fake-openai", name="fake-gpt" + provider_endpoint_id="1", + provider_endpoint_name="fake-openai", + provider_endpoint_type=db_models.ProviderType.openai, + name="fake-gpt", ), db_models.ProviderEndpoint( id="1", @@ -70,6 +73,8 @@ def test_file_matcher( model="fake-gpt", matcher_type="filename_match", matcher=matcher, + provider_name="fake-openai", + provider_type=db_models.ProviderType.openai, ) muxing_rule_matcher = rulematcher.FileMuxingRuleMatcher(mocked_route_openai, mux_rule) # We mock the _extract_request_filenames method to return a list of filenames @@ -120,6 +125,8 @@ def test_request_file_matcher( model="fake-gpt", matcher_type=matcher_type, matcher=matcher, + provider_name="fake-openai", + provider_type=db_models.ProviderType.openai, ) muxing_rule_matcher = rulematcher.RequestTypeAndFileMuxingRuleMatcher( mocked_route_openai, mux_rule @@ -167,6 +174,8 @@ def test_muxing_matcher_factory(matcher_type, expected_class): matcher_type=matcher_type, matcher_blob="fake-matcher", priority=1, + provider_endpoint_name="fake-openai", + provider_endpoint_type=db_models.ProviderType.openai, ) if expected_class: assert isinstance( From 5719cac1a3287985a060d42f85318417faf4d6cf Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Thu, 6 Mar 2025 23:14:50 +0000 Subject: [PATCH 05/23] tests & tidy ups --- ...992_add_provider_endpoint_type_to_muxes.py | 6 +- ...da6_add_provider_endpoint_name_to_muxes.py | 6 +- src/codegate/api/v1.py | 4 +- src/codegate/api/v1_models.py | 1 - src/codegate/db/connection.py | 8 +- src/codegate/providers/crud/crud.py | 2 - src/codegate/workspaces/crud.py | 10 +- tests/api/test_v1_workspaces.py | 349 ++++++------------ 8 files changed, 143 insertions(+), 243 deletions(-) diff --git a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py index f510dd181..2ce7928de 100644 --- a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py +++ b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py @@ -46,7 +46,8 @@ def upgrade() -> None: # swap the table op.execute("CREATE TABLE muxes_new AS SELECT * FROM muxes;") op.execute("DROP TABLE muxes;") - op.execute(""" + op.execute( + """ CREATE TABLE muxes ( id TEXT PRIMARY KEY, provider_endpoint_id TEXT NOT NULL, @@ -59,7 +60,8 @@ def upgrade() -> None: updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, provider_endpoint_type TEXT NOT NULL, FOREIGN KEY(provider_endpoint_id) REFERENCES provider_endpoints(id) - );""") + );""" + ) op.execute("INSERT INTO muxes SELECT * FROM muxes_new;") op.execute("DROP TABLE muxes_new;") diff --git a/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py b/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py index 397bb924f..15cf26837 100644 --- a/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py +++ b/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py @@ -46,7 +46,8 @@ def upgrade() -> None: # swap the table op.execute("CREATE TABLE muxes_new AS SELECT * FROM muxes;") op.execute("DROP TABLE muxes;") - op.execute(""" + op.execute( + """ CREATE TABLE muxes ( id TEXT PRIMARY KEY, provider_endpoint_id TEXT NOT NULL, @@ -60,7 +61,8 @@ def upgrade() -> None: provider_endpoint_type TEXT NOT NULL, provider_endpoint_name TEXT NOT NULL, FOREIGN KEY(provider_endpoint_id) REFERENCES provider_endpoints(id) - );""") + );""" + ) op.execute("INSERT INTO muxes SELECT * FROM muxes_new;") op.execute("DROP TABLE muxes_new;") diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 2a0617076..dd69daf2f 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -109,6 +109,7 @@ async def list_models_by_provider( except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider not found") except Exception as e: + logger.debug(f"Error listing models by provider, {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -334,7 +335,7 @@ async def create_workspace( "/workspaces/{workspace_name}", tags=["Workspaces"], generate_unique_id_function=uniq_name, - status_code=201, + status_code=200, ) async def update_workspace( workspace_name: str, @@ -368,6 +369,7 @@ async def update_workspace( ), ) except crud.WorkspaceCrudError as e: + logger.debug(f"Could not update workspace: {e}") raise HTTPException(status_code=400, detail=str(e)) except Exception: raise HTTPException(status_code=500, detail="Internal server error") diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index bc7eaecec..d79fb98c2 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -314,7 +314,6 @@ class ModelByProvider(pydantic.BaseModel): Note that these are auto-discovered by the provider. """ - provider_id: str name: str provider_type: db_models.ProviderType provider_name: str diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 97302ee08..54d610650 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -714,10 +714,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( # If trigger category is None we want to get all alerts trigger_category = trigger_category if trigger_category else "%" conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[ - IntermediatePromptWithOutputUsageAlerts - ] = await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + rows: List[IntermediatePromptWithOutputUsageAlerts] = ( + await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + ) ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in rows: diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index fed6b0058..1e6142ee5 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -13,7 +13,6 @@ from codegate.providers.base import BaseProvider from codegate.providers.registry import ProviderRegistry, get_provider_registry from codegate.workspaces import crud as workspace_crud -from src.codegate.db.models import ProviderType logger = structlog.get_logger("codegate") @@ -291,7 +290,6 @@ async def get_all_models(self) -> List[apimodelsv1.ModelByProvider]: ename = dbmodel.provider_endpoint_name if dbmodel.provider_endpoint_name else "" outmodels.append( apimodelsv1.ModelByProvider( - provider_id=dbmodel.provider_endpoint_id, name=dbmodel.name, provider_name=ename, provider_type=dbmodel.provider_endpoint_type, diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index a4cd803bd..22859fe19 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -56,7 +56,7 @@ async def add_workspace( new_workspace_name (str): The name of the workspace system_prompt (Optional[str]): The system prompt for the workspace muxing_rules (Optional[List[mux_models.MuxRuleWithProviderId]]): The muxing rules for the workspace - """ + """ # noqa: E501 if new_workspace_name == "": raise WorkspaceCrudError("Workspace name cannot be empty.") if new_workspace_name in RESERVED_WORKSPACE_KEYWORDS: @@ -146,7 +146,7 @@ async def update_workspace( await transaction.commit() return workspace_renamed, mux_rules - except WorkspaceDoesNotExistError as e: + except (WorkspaceDoesNotExistError, WorkspaceNameAlreadyInUseError) as e: raise e except Exception as e: raise WorkspaceCrudError(f"Error updating workspace {old_workspace_name}: {str(e)}") @@ -331,9 +331,9 @@ async def set_muxes( # Add the new muxes priority = 0 - muxes_with_routes: List[ - Tuple[mux_models.MuxRuleWithProviderId, rulematcher.ModelRoute] - ] = [] + muxes_with_routes: List[Tuple[mux_models.MuxRuleWithProviderId, rulematcher.ModelRoute]] = ( + [] + ) # Verify all models are valid for mux in muxes: diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py index a961fa7b3..e7f410795 100644 --- a/tests/api/test_v1_workspaces.py +++ b/tests/api/test_v1_workspaces.py @@ -71,244 +71,69 @@ def mock_pipeline_factory(): @pytest.mark.asyncio -async def test_get_workspaces( - mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +async def test_workspace_crud( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader ) -> None: with ( + patch("codegate.api.v1.dbreader", db_reader), patch("codegate.api.v1.wscrud", mock_workspace_crud), patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], ), - ): - """Test getting all workspaces.""" - app = init_app(mock_pipeline_factory) - - async with AsyncClient( - transport=httpx.ASGITransport(app=app), base_url="http://test" - ) as ac: - # Create a provider for muxing rules - provider_payload = { - "name": "test-provider", - "description": "", - "auth_type": "none", - "provider_type": "openai", - "endpoint": "https://api.openai.com", - "api_key": "sk-proj-foo-bar-123-xzy", - } - response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) - assert response.status_code == 201 - provider = response.json() - - # Create first workspace - name_1 = str(uuid()) - workspace_1 = { - "name": name_1, - "config": { - "custom_instructions": "Respond in haiku format", - "muxing_rules": [ - { - "provider_id": provider["id"], - "model": "foo-bar-001", - "matcher": "*.py", - "matcher_type": "filename_match", - } - ], - }, - } - response = await ac.post("/api/v1/workspaces", json=workspace_1) - assert response.status_code == 201 - - # Create second workspace - name_2 = str(uuid()) - workspace_2 = { - "name": name_2, - "config": { - "custom_instructions": "Respond in prose", - "muxing_rules": [ - { - "provider_id": provider["id"], - "model": "foo-bar-002", - "matcher": "*.js", - "matcher_type": "filename_match", - } - ], - }, - } - response = await ac.post("/api/v1/workspaces", json=workspace_2) - assert response.status_code == 201 - - response = await ac.get("/api/v1/workspaces") - assert response.status_code == 200 - workspaces = response.json()["workspaces"] - - # Verify response structure - assert isinstance(workspaces, list) - assert len(workspaces) >= 2 - - workspace_names = [w["name"] for w in workspaces] - assert name_1 in workspace_names - assert name_2 in workspace_names - assert len([n for n in workspace_names if n in [name_1, name_2]]) == 2 - - -@pytest.mark.asyncio -async def test_get_workspaces_filter_by_provider( - mock_pipeline_factory, mock_workspace_crud, mock_provider_crud -) -> None: - with ( - patch("codegate.api.v1.wscrud", mock_workspace_crud), - patch("codegate.api.v1.pcrud", mock_provider_crud), patch( - "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], ), ): - """Test filtering workspaces by provider ID.""" - app = init_app(mock_pipeline_factory) - - async with AsyncClient( - transport=httpx.ASGITransport(app=app), base_url="http://test" - ) as ac: - # Create first provider - provider_payload_1 = { - "name": "provider-1", - "description": "", - "auth_type": "none", - "provider_type": "openai", - "endpoint": "https://api.openai.com", - "api_key": "sk-proj-foo-bar-123-xyz", - } - response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) - assert response.status_code == 201 - provider_1 = response.json() - - # Create second provider - provider_payload_2 = { - "name": "provider-2", - "description": "", - "auth_type": "none", - "provider_type": "openai", - "endpoint": "https://api.openai.com", - "api_key": "sk-proj-foo-bar-456-xyz", - } - response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) - assert response.status_code == 201 - provider_2 = response.json() - - # Create workspace using provider 1 - workspace_1 = { - "name": str(uuid()), - "config": { - "custom_instructions": "Instructions 1", - "muxing_rules": [ - { - "provider_id": provider_1["id"], - "model": "foo-bar-001", - "matcher": "*.py", - "matcher_type": "filename_match", - } - ], - }, - } - response = await ac.post("/api/v1/workspaces", json=workspace_1) - assert response.status_code == 201 - - # Create workspace using provider 2 - workspace_2 = { - "name": str(uuid()), - "config": { - "custom_instructions": "Instructions 2", - "muxing_rules": [ - { - "provider_id": provider_2["id"], - "model": "foo-bar-002", - "matcher": "*.js", - "matcher_type": "filename_match", - } - ], - }, - } - response = await ac.post("/api/v1/workspaces", json=workspace_2) - assert response.status_code == 201 - - # Test filtering by provider 1 - response = await ac.get(f"/api/v1/workspaces?provider_id={provider_1['id']}") - assert response.status_code == 200 - workspaces = response.json()["workspaces"] - assert len(workspaces) == 1 - assert workspaces[0]["name"] == workspace_1["name"] - - # Test filtering by provider 2 - response = await ac.get(f"/api/v1/workspaces?provider_id={provider_2['id']}") - assert response.status_code == 200 - workspaces = response.json()["workspaces"] - assert len(workspaces) == 1 - assert workspaces[0]["name"] == workspace_2["name"] - - -@pytest.mark.asyncio -async def test_create_update_workspace_happy_path( - mock_pipeline_factory, mock_workspace_crud, mock_provider_crud -) -> None: - with ( - patch("codegate.api.v1.wscrud", mock_workspace_crud), - patch("codegate.api.v1.pcrud", mock_provider_crud), - patch( - "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], - ), - ): - """Test creating & updating a workspace (happy path).""" + """Test creating, updating and reading a workspace.""" app = init_app(mock_pipeline_factory) provider_payload_1 = { - "name": "foo", - "description": "", + "name": "openai-provider", + "description": "OpenAI provider description", "auth_type": "none", "provider_type": "openai", "endpoint": "https://api.openai.com", - "api_key": "sk-proj-foo-bar-123-xzy", + "api_key": "sk-proj-foo-bar-123-xyz", } provider_payload_2 = { - "name": "bar", - "description": "", + "name": "openrouter-provider", + "description": "OpenRouter provider description", "auth_type": "none", - "provider_type": "openai", - "endpoint": "https://api.openai.com", - "api_key": "sk-proj-foo-bar-123-xzy", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", } async with AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://test" ) as ac: - # Create the first provider response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) assert response.status_code == 201 - provider_1 = response.json() - # Create the second provider response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) assert response.status_code == 201 - provider_2 = response.json() + + # Create workspace name_1: str = str(uuid()) custom_instructions_1: str = "Respond to every request in iambic pentameter" muxing_rules_1 = [ { - "provider_name": None, # optional & not implemented yet - "provider_id": provider_1["id"], - "model": "foo-bar-001", + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-4", "matcher": "*.ts", "matcher_type": "filename_match", }, { - "provider_name": None, # optional & not implemented yet - "provider_id": provider_2["id"], - "model": "foo-bar-002", + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-3.5-turbo", "matcher_type": "catch_all", "matcher": "", }, @@ -333,6 +158,8 @@ async def test_create_update_workspace_happy_path( assert response_body["name"] == name_1 assert response_body["config"]["custom_instructions"] == custom_instructions_1 for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["provider_name"] == muxing_rules_1[i]["provider_name"] + assert rule["provider_type"] == muxing_rules_1[i]["provider_type"] assert rule["model"] == muxing_rules_1[i]["model"] assert rule["matcher"] == muxing_rules_1[i]["matcher"] assert rule["matcher_type"] == muxing_rules_1[i]["matcher_type"] @@ -341,16 +168,16 @@ async def test_create_update_workspace_happy_path( custom_instructions_2: str = "Respond to every request in cockney rhyming slang" muxing_rules_2 = [ { - "provider_name": None, # optional & not implemented yet - "provider_id": provider_2["id"], - "model": "foo-bar-002", + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "anthropic/claude-2", "matcher": "*.ts", "matcher_type": "filename_match", }, { - "provider_name": None, # optional & not implemented yet - "provider_id": provider_1["id"], - "model": "foo-bar-001", + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "deepseek/deepseek-r1", "matcher_type": "catch_all", "matcher": "", }, @@ -365,7 +192,7 @@ async def test_create_update_workspace_happy_path( } response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) - assert response.status_code == 201 + assert response.status_code == 200 # Verify updated workspace response = await ac.get(f"/api/v1/workspaces/{name_2}") @@ -375,34 +202,88 @@ async def test_create_update_workspace_happy_path( assert response_body["name"] == name_2 assert response_body["config"]["custom_instructions"] == custom_instructions_2 for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["provider_name"] == muxing_rules_2[i]["provider_name"] + assert rule["provider_type"] == muxing_rules_2[i]["provider_type"] assert rule["model"] == muxing_rules_2[i]["model"] assert rule["matcher"] == muxing_rules_2[i]["matcher"] assert rule["matcher_type"] == muxing_rules_2[i]["matcher_type"] @pytest.mark.asyncio -async def test_create_update_workspace_name_only( - mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +async def test_rename_workspace( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader ) -> None: with ( + patch("codegate.api.v1.dbreader", db_reader), patch("codegate.api.v1.wscrud", mock_workspace_crud), patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], ), ): - """Test creating & updating a workspace (happy path).""" + """Test renaming a workspace.""" app = init_app(mock_pipeline_factory) + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + async with AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://test" ) as ac: + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Create workspace + name_1: str = str(uuid()) + custom_instructions: str = "Respond to every request in iambic pentameter" + muxing_rules = [ + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-4", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-3.5-turbo", + "matcher_type": "catch_all", + "matcher": "", + }, + ] payload_create = { "name": name_1, + "config": { + "custom_instructions": custom_instructions, + "muxing_rules": muxing_rules, + }, } response = await ac.post("/api/v1/workspaces", json=payload_create) @@ -415,8 +296,6 @@ async def test_create_update_workspace_name_only( assert response.status_code == 200 response_body = response.json() assert response_body["name"] == name_1 - assert response_body["config"]["custom_instructions"] == "" - assert response_body["config"]["muxing_rules"] == [] name_2: str = str(uuid()) @@ -425,17 +304,24 @@ async def test_create_update_workspace_name_only( } response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update) - assert response.status_code == 201 + assert response.status_code == 200 response_body = response.json() assert response_body["name"] == name_2 + # other fields shouldn't have been touched + assert response_body["config"]["custom_instructions"] == custom_instructions + for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["provider_name"] == muxing_rules[i]["provider_name"] + assert rule["provider_type"] == muxing_rules[i]["provider_type"] + assert rule["model"] == muxing_rules[i]["model"] + assert rule["matcher"] == muxing_rules[i]["matcher"] + assert rule["matcher_type"] == muxing_rules[i]["matcher_type"] + # Verify updated workspace response = await ac.get(f"/api/v1/workspaces/{name_2}") assert response.status_code == 200 response_body = response.json() assert response_body["name"] == name_2 - assert response_body["config"]["custom_instructions"] == "" - assert response_body["config"]["muxing_rules"] == [] @pytest.mark.asyncio @@ -447,7 +333,11 @@ async def test_create_workspace_name_already_in_use( patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], ), ): """Test creating a workspace when the name is already in use.""" @@ -482,7 +372,11 @@ async def test_rename_workspace_name_already_in_use( patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], ), ): """Test renaming a workspace when the new name is already in use.""" @@ -522,14 +416,19 @@ async def test_rename_workspace_name_already_in_use( @pytest.mark.asyncio async def test_create_workspace_with_nonexistent_model_in_muxing_rule( - mock_pipeline_factory, mock_workspace_crud, mock_provider_crud + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader ) -> None: with ( + patch("codegate.api.v1.dbreader", db_reader), patch("codegate.api.v1.wscrud", mock_workspace_crud), patch("codegate.api.v1.pcrud", mock_provider_crud), patch( "codegate.providers.openai.provider.OpenAIProvider.models", - return_value=["foo-bar-001", "foo-bar-002"], + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], ), ): """Test creating a workspace with a muxing rule that uses a nonexistent model.""" @@ -537,28 +436,26 @@ async def test_create_workspace_with_nonexistent_model_in_muxing_rule( app = init_app(mock_pipeline_factory) provider_payload = { - "name": "foo", - "description": "", + "name": "openai-provider", + "description": "OpenAI provider description", "auth_type": "none", "provider_type": "openai", "endpoint": "https://api.openai.com", - "api_key": "sk-proj-foo-bar-123-xzy", + "api_key": "sk-proj-foo-bar-123-xyz", } async with AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://test" ) as ac: - # Create the first provider response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) assert response.status_code == 201 - provider = response.json() name: str = str(uuid()) custom_instructions: str = "Respond to every request in iambic pentameter" muxing_rules = [ { - "provider_name": None, - "provider_id": provider["id"], + "provider_name": "openai-provider", + "provider_type": "openai", "model": "nonexistent-model", "matcher": "*.ts", "matcher_type": "filename_match", @@ -575,4 +472,4 @@ async def test_create_workspace_with_nonexistent_model_in_muxing_rule( response = await ac.post("/api/v1/workspaces", json=payload_create) assert response.status_code == 400 - assert "Model nonexistent-model does not exist" in response.json()["detail"] + assert "does not exist" in response.json()["detail"] From 844421ac3c95051cf615b190e39751f5069e600a Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Thu, 6 Mar 2025 23:24:34 +0000 Subject: [PATCH 06/23] fix type nit --- src/codegate/api/v1_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index d79fb98c2..7976d38c5 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -256,7 +256,7 @@ class ProviderEndpoint(pydantic.BaseModel): id: Optional[str] = "" name: str description: str = "" - provider_type: str + provider_type: db_models.ProviderType endpoint: str = "" # Some providers have defaults we can leverage auth_type: ProviderAuthType = ProviderAuthType.none From 80a9471a649f47b1f397cf21121e6f4a6a3ef050 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 00:34:46 +0000 Subject: [PATCH 07/23] bug fixes and tests --- api/openapi.json | 263 +++++++++++++++----------------- src/codegate/api/v1.py | 23 ++- src/codegate/db/connection.py | 12 +- src/codegate/workspaces/crud.py | 20 ++- tests/api/test_v1_providers.py | 29 ++++ tests/api/test_v1_workspaces.py | 189 ++++++++++++++++++++++- 6 files changed, 379 insertions(+), 157 deletions(-) diff --git a/api/openapi.json b/api/openapi.json index 4ea57d0e1..ad8823642 100644 --- a/api/openapi.json +++ b/api/openapi.json @@ -197,24 +197,23 @@ } } }, - "/api/v1/provider-endpoints/{provider_id}": { + "/api/v1/provider-endpoints/{provider_name}": { "get": { "tags": [ "CodeGate API", "Providers" ], "summary": "Get Provider Endpoint", - "description": "Get a provider endpoint by ID.", + "description": "Get a provider endpoint by name.", "operationId": "v1_get_provider_endpoint", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], @@ -241,44 +240,31 @@ } } }, - "put": { + "delete": { "tags": [ "CodeGate API", "Providers" ], - "summary": "Update Provider Endpoint", - "description": "Update a provider endpoint by ID.", - "operationId": "v1_update_provider_endpoint", + "summary": "Delete Provider Endpoint", + "description": "Delete a provider endpoint by id.", + "operationId": "v1_delete_provider_endpoint", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], - "requestBody": { - "required": true, - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ProviderEndpoint" - } - } - } - }, "responses": { "200": { "description": "Successful Response", "content": { "application/json": { - "schema": { - "$ref": "#/components/schemas/ProviderEndpoint" - } + "schema": {} } } }, @@ -293,15 +279,17 @@ } } } - }, - "delete": { + } + }, + "/api/v1/provider-endpoints/{provider_id}/auth-material": { + "put": { "tags": [ "CodeGate API", "Providers" ], - "summary": "Delete Provider Endpoint", - "description": "Delete a provider endpoint by id.", - "operationId": "v1_delete_provider_endpoint", + "summary": "Configure Auth Material", + "description": "Configure auth material for a provider.", + "operationId": "v1_configure_auth_material", "parameters": [ { "name": "provider_id", @@ -314,14 +302,19 @@ } } ], - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": {} + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ConfigureAuthMaterial" } } + } + }, + "responses": { + "204": { + "description": "Successful Response" }, "422": { "description": "Validation Error", @@ -336,15 +329,15 @@ } } }, - "/api/v1/provider-endpoints/{provider_id}/auth-material": { + "/api/v1/provider-endpoints/{provider_id}": { "put": { "tags": [ "CodeGate API", "Providers" ], - "summary": "Configure Auth Material", - "description": "Configure auth material for a provider.", - "operationId": "v1_configure_auth_material", + "summary": "Update Provider Endpoint", + "description": "Update a provider endpoint by ID.", + "operationId": "v1_update_provider_endpoint", "parameters": [ { "name": "provider_id", @@ -362,14 +355,21 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ConfigureAuthMaterial" + "$ref": "#/components/schemas/ProviderEndpoint" } } } }, "responses": { - "204": { - "description": "Successful Response" + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProviderEndpoint" + } + } + } }, "422": { "description": "Validation Error", @@ -391,8 +391,27 @@ "Workspaces" ], "summary": "List Workspaces", - "description": "List all workspaces.", + "description": "List all workspaces.\n\nArgs:\n provider_id (Optional[UUID]): Filter workspaces by provider ID. If provided,\n will return workspaces where models from the specified provider (e.g., OpenAI,\n Anthropic) have been used in workspace muxing rules. Note that you must\n refer to a provider by ID, not by name.\n\nReturns:\n ListWorkspacesResponse: A response object containing the list of workspaces.", "operationId": "v1_list_workspaces", + "parameters": [ + { + "name": "provider_id", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string", + "format": "uuid" + }, + { + "type": "null" + } + ], + "title": "Provider Id" + } + } + ], "responses": { "200": { "description": "Successful Response", @@ -403,6 +422,16 @@ } } } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } } } }, @@ -415,14 +444,14 @@ "description": "Create a new workspace.", "operationId": "v1_create_workspace", "requestBody": { + "required": true, "content": { "application/json": { "schema": { "$ref": "#/components/schemas/FullWorkspace-Input" } } - }, - "required": true + } }, "responses": { "201": { @@ -552,7 +581,7 @@ } }, "responses": { - "201": { + "200": { "description": "Successful Response", "content": { "application/json": { @@ -613,6 +642,48 @@ } } } + }, + "get": { + "tags": [ + "CodeGate API", + "Workspaces" + ], + "summary": "Get Workspace By Name", + "description": "List workspaces by provider ID.", + "operationId": "v1_get_workspace_by_name", + "parameters": [ + { + "name": "workspace_name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workspace Name" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/FullWorkspace-Output" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } } }, "/api/v1/workspaces/archive": { @@ -1085,55 +1156,6 @@ } } }, - "/api/v1/workspaces/{provider_id}": { - "get": { - "tags": [ - "CodeGate API", - "Workspaces" - ], - "summary": "List Workspaces By Provider", - "description": "List workspaces by provider ID.", - "operationId": "v1_list_workspaces_by_provider", - "parameters": [ - { - "name": "provider_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "format": "uuid", - "title": "Provider Id" - } - } - ], - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "type": "array", - "items": { - "$ref": "#/components/schemas/WorkspaceWithModel" - }, - "title": "Response V1 List Workspaces By Provider" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - } - }, "/api/v1/alerts_notification": { "get": { "tags": [ @@ -1950,9 +1972,8 @@ "type": "string", "title": "Name" }, - "provider_id": { - "type": "string", - "title": "Provider Id" + "provider_type": { + "$ref": "#/components/schemas/ProviderType" }, "provider_name": { "type": "string", @@ -1962,7 +1983,7 @@ "type": "object", "required": [ "name", - "provider_id", + "provider_type", "provider_name" ], "title": "ModelByProvider", @@ -1982,19 +2003,11 @@ "MuxRule": { "properties": { "provider_name": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], + "type": "string", "title": "Provider Name" }, - "provider_id": { - "type": "string", - "title": "Provider Id" + "provider_type": { + "$ref": "#/components/schemas/ProviderType" }, "model": { "type": "string", @@ -2017,7 +2030,8 @@ }, "type": "object", "required": [ - "provider_id", + "provider_name", + "provider_type", "model", "matcher_type" ], @@ -2348,31 +2362,6 @@ "muxing_rules" ], "title": "WorkspaceConfig" - }, - "WorkspaceWithModel": { - "properties": { - "id": { - "type": "string", - "title": "Id" - }, - "name": { - "type": "string", - "pattern": "^[a-zA-Z0-9_-]+$", - "title": "Name" - }, - "provider_model_name": { - "type": "string", - "title": "Provider Model Name" - } - }, - "type": "object", - "required": [ - "id", - "name", - "provider_model_name" - ], - "title": "WorkspaceWithModel", - "description": "Returns a workspace ID with model name" } } } diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index dd69daf2f..43d22c16d 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -219,14 +219,17 @@ async def update_provider_endpoint( @v1.delete( - "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name + "/provider-endpoints/{provider_name}", tags=["Providers"], generate_unique_id_function=uniq_name ) async def delete_provider_endpoint( - provider_id: UUID, + provider_name: str, ): - """Delete a provider endpoint by id.""" + """Delete a provider endpoint by name.""" try: - await pcrud.delete_endpoint(provider_id) + provider = await pcrud.get_endpoint_by_name(provider_name) + if provider is None: + raise provendcrud.ProviderNotFoundError + await pcrud.delete_endpoint(provider.id) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") except Exception: @@ -236,7 +239,7 @@ async def delete_provider_endpoint( @v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name) async def list_workspaces( - provider_id: Optional[UUID] = Query(None), + provider_name: Optional[str] = Query(None), ) -> v1_models.ListWorkspacesResponse: """ List all workspaces. @@ -251,8 +254,9 @@ async def list_workspaces( ListWorkspacesResponse: A response object containing the list of workspaces. """ try: - if provider_id: - wslist = await wscrud.workspaces_by_provider(provider_id) + if provider_name: + provider = await pcrud.get_endpoint_by_name(provider_name) + wslist = await wscrud.workspaces_by_provider(provider.id) resp = v1_models.ListWorkspacesResponse.from_db_workspaces(wslist) return resp else: @@ -394,9 +398,12 @@ async def delete_workspace(workspace_name: str): _ = await wscrud.soft_delete_workspace(workspace_name) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") + except crud.DeleteMuxesFromRegistryError: + raise HTTPException(status_code=500, detail="Internal server error") except crud.WorkspaceCrudError as e: raise HTTPException(status_code=400, detail=str(e)) - except Exception: + except crud.DeleteMuxesFromRegistryError as e: + logger.debug(f"Error deleting workspace {e}") raise HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 54d610650..24a5392a5 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -714,10 +714,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( # If trigger category is None we want to get all alerts trigger_category = trigger_category if trigger_category else "%" conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True - ) + rows: List[ + IntermediatePromptWithOutputUsageAlerts + ] = await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in rows: @@ -859,7 +859,7 @@ async def get_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]: async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceRow]: sql = text( """ - SELECT + SELECT DISTINCT w.id, w.name, w.custom_instructions @@ -928,7 +928,7 @@ async def get_provider_endpoint_by_name(self, provider_name: str) -> Optional[Pr return provider[0] if provider else None async def try_get_provider_endpoint_by_name_and_type( - self, provider_name: str, provider_type: str + self, provider_name: str, provider_type: Optional[str] ) -> Optional[ProviderEndpoint]: """ Best effort attempt to find a provider endpoint matching name and type. diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 22859fe19..26c557246 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -32,6 +32,10 @@ class WorkspaceMuxRuleDoesNotExistError(WorkspaceCrudError): pass +class DeleteMuxesFromRegistryError(WorkspaceCrudError): + pass + + DEFAULT_WORKSPACE_NAME = "default" # These are reserved keywords that cannot be used for workspaces @@ -237,6 +241,7 @@ async def soft_delete_workspace(self, workspace_name: str): """ Soft delete a workspace """ + if workspace_name == "": raise WorkspaceCrudError("Workspace name cannot be empty.") if workspace_name == DEFAULT_WORKSPACE_NAME: @@ -257,8 +262,13 @@ async def soft_delete_workspace(self, workspace_name: str): raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}") # Remove the muxes from the registry - mux_registry = await rulematcher.get_muxing_rules_registry() - await mux_registry.delete_ws_rules(workspace_name) + try: + mux_registry = await rulematcher.get_muxing_rules_registry() + await mux_registry.delete_ws_rules(workspace_name) + except Exception: + raise DeleteMuxesFromRegistryError( + f"Error deleting mux rules for workspace {workspace_name}" + ) return async def hard_delete_workspace(self, workspace_name: str): @@ -331,9 +341,9 @@ async def set_muxes( # Add the new muxes priority = 0 - muxes_with_routes: List[Tuple[mux_models.MuxRuleWithProviderId, rulematcher.ModelRoute]] = ( - [] - ) + muxes_with_routes: List[ + Tuple[mux_models.MuxRuleWithProviderId, rulematcher.ModelRoute] + ] = [] # Verify all models are valid for mux in muxes: diff --git a/tests/api/test_v1_providers.py b/tests/api/test_v1_providers.py index a4bceec09..f45ac2226 100644 --- a/tests/api/test_v1_providers.py +++ b/tests/api/test_v1_providers.py @@ -186,6 +186,35 @@ async def test_providers_crud( assert response.status_code == 404 assert response.json()["detail"] == "Provider endpoint not found" + # Test deleting providers + response = await ac.delete("/api/v1/provider-endpoints/openai-provider") + assert response.status_code == 204 + + # Verify provider was deleted by trying to get it + response = await ac.get("/api/v1/provider-endpoints/openai-provider") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + # Delete second provider + response = await ac.delete("/api/v1/provider-endpoints/openrouter-provider") + assert response.status_code == 204 + + # Verify second provider was deleted + response = await ac.get("/api/v1/provider-endpoints/openrouter-provider") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + # Test deleting non-existent provider + response = await ac.delete("/api/v1/provider-endpoints/non-existent") + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + # Verify providers list is empty + response = await ac.get("/api/v1/provider-endpoints") + assert response.status_code == 200 + providers = response.json() + assert len(providers) == 0 + @pytest.mark.asyncio async def test_list_providers_by_name( diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py index e7f410795..105ff2d38 100644 --- a/tests/api/test_v1_workspaces.py +++ b/tests/api/test_v1_workspaces.py @@ -1,5 +1,5 @@ from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 as uuid import httpx @@ -8,6 +8,7 @@ from httpx import AsyncClient from codegate.db import connection +from codegate.muxing.rulematcher import MuxingRulesinWorkspaces from codegate.pipeline.factory import PipelineFactory from codegate.providers.crud.crud import ProviderCrud from codegate.server import init_app @@ -70,6 +71,13 @@ def mock_pipeline_factory(): return mock_factory +@pytest.fixture +def mock_muxing_rules_registry(): + """Creates a mock for the muxing rules registry.""" + mock_registry = AsyncMock(spec=MuxingRulesinWorkspaces) + return mock_registry + + @pytest.mark.asyncio async def test_workspace_crud( mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader @@ -473,3 +481,182 @@ async def test_create_workspace_with_nonexistent_model_in_muxing_rule( response = await ac.post("/api/v1/workspaces", json=payload_create) assert response.status_code == 400 assert "does not exist" in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_list_workspaces_by_provider_name( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader +) -> None: + with ( + patch("codegate.api.v1.dbreader", db_reader), + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test listing workspaces filtered by provider name.""" + + app = init_app(mock_pipeline_factory) + + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create providers + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Create workspace + + name_1: str = str(uuid()) + custom_instructions_1: str = "Respond to every request in iambic pentameter" + muxing_rules_1 = [ + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-4", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-3.5-turbo", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + payload_create_1 = { + "name": name_1, + "config": { + "custom_instructions": custom_instructions_1, + "muxing_rules": muxing_rules_1, + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create_1) + assert response.status_code == 201 + + name_2: str = str(uuid()) + custom_instructions_2: str = "Respond to every request in cockney rhyming slang" + muxing_rules_2 = [ + { + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "anthropic/claude-2", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "deepseek/deepseek-r1", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + payload_create_2 = { + "name": name_2, + "config": { + "custom_instructions": custom_instructions_2, + "muxing_rules": muxing_rules_2, + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create_2) + assert response.status_code == 201 + + # List workspaces filtered by openai provider + response = await ac.get("/api/v1/workspaces?provider_name=openai-provider") + assert response.status_code == 200 + response_body = response.json() + assert len(response_body["workspaces"]) == 1 + assert response_body["workspaces"][0]["name"] == name_1 + + # List workspaces filtered by openrouter provider + response = await ac.get("/api/v1/workspaces?provider_name=openrouter-provider") + assert response.status_code == 200 + response_body = response.json() + assert len(response_body["workspaces"]) == 1 + assert response_body["workspaces"][0]["name"] == name_2 + + # List workspaces unfiltered + response = await ac.get("/api/v1/workspaces") + assert response.status_code == 200 + response_body = response.json() + assert len(response_body["workspaces"]) == 3 # 2 created in test + default + + +@pytest.mark.asyncio +async def test_delete_workspace( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, mock_muxing_rules_registry +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.muxing.rulematcher.get_muxing_rules_registry", + return_value=mock_muxing_rules_registry, + ), + ): + """Test deleting a workspace.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name: str = str(uuid()) + payload_create = { + "name": name, + } + + # Create workspace + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Verify workspace exists + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 200 + assert response.json()["name"] == name + + # Delete workspace + response = await ac.delete(f"/api/v1/workspaces/{name}") + assert response.status_code == 204 + + # Verify workspace no longer exists + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 404 + + # Try to delete non-existent workspace + response = await ac.delete("/api/v1/workspaces/nonexistent") + assert response.status_code == 404 + assert response.json()["detail"] == "Workspace does not exist" From fc40772d01d2c23808502e23cddf43f98ccdf16f Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 08:33:27 +0000 Subject: [PATCH 08/23] update any remaining endpoints referring to providers by name --- src/codegate/api/v1.py | 39 ++++++---- tests/api/test_v1_providers.py | 130 ++++++++++++++++++++++++++++++++- 2 files changed, 150 insertions(+), 19 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 43d22c16d..5dbbdbab6 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -1,5 +1,4 @@ from typing import List, Optional -from uuid import UUID import requests import structlog @@ -80,7 +79,7 @@ async def list_provider_endpoints( return [provend] -# This needs to be above /provider-endpoints/{provider_id} to avoid conflict +# This needs to be above /provider-endpoints/{provider_name} to avoid conflict @v1.get( "/provider-endpoints/models", tags=["Providers"], @@ -95,17 +94,20 @@ async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider] @v1.get( - "/provider-endpoints/{provider_id}/models", + "/provider-endpoints/{provider_name}/models", tags=["Providers"], generate_unique_id_function=uniq_name, ) async def list_models_by_provider( - provider_id: UUID, + provider_name: str, ) -> List[v1_models.ModelByProvider]: """List models by provider.""" try: - return await pcrud.models_by_provider(provider_id) + provider = await pcrud.get_endpoint_by_name(provider_name) + if provider is None: + raise provendcrud.ProviderNotFoundError + return await pcrud.models_by_provider(provider.id) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider not found") except Exception as e: @@ -167,18 +169,21 @@ async def add_provider_endpoint( @v1.put( - "/provider-endpoints/{provider_id}/auth-material", + "/provider-endpoints/{provider_name}/auth-material", tags=["Providers"], generate_unique_id_function=uniq_name, status_code=204, ) async def configure_auth_material( - provider_id: UUID, + provider_name: str, request: v1_models.ConfigureAuthMaterial, ): """Configure auth material for a provider.""" try: - await pcrud.configure_auth_material(provider_id, request) + provider = await pcrud.get_endpoint_by_name(provider_name) + if provider is None: + raise provendcrud.ProviderNotFoundError + await pcrud.configure_auth_material(provider.id, request) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") except provendcrud.ProviderModelsNotFoundError: @@ -192,15 +197,18 @@ async def configure_auth_material( @v1.put( - "/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name + "/provider-endpoints/{provider_name}", tags=["Providers"], generate_unique_id_function=uniq_name ) async def update_provider_endpoint( - provider_id: UUID, + provider_name: str, request: v1_models.ProviderEndpoint, ) -> v1_models.ProviderEndpoint: - """Update a provider endpoint by ID.""" + """Update a provider endpoint by name.""" try: - request.id = str(provider_id) + provider = await pcrud.get_endpoint_by_name(provider_name) + if provider is None: + raise provendcrud.ProviderNotFoundError + request.id = str(provider.id) provend = await pcrud.update_endpoint(request) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") @@ -245,10 +253,9 @@ async def list_workspaces( List all workspaces. Args: - provider_id (Optional[UUID]): Filter workspaces by provider ID. If provided, + provider_name (Optional[str]): Filter workspaces by provider name. If provided, will return workspaces where models from the specified provider (e.g., OpenAI, - Anthropic) have been used in workspace muxing rules. Note that you must - refer to a provider by ID, not by name. + Anthropic) have been used in workspace muxing rules. Returns: ListWorkspacesResponse: A response object containing the list of workspaces. @@ -256,6 +263,8 @@ async def list_workspaces( try: if provider_name: provider = await pcrud.get_endpoint_by_name(provider_name) + if provider is None: + raise provendcrud.ProviderNotFoundError wslist = await wscrud.workspaces_by_provider(provider.id) resp = v1_models.ListWorkspacesResponse.from_db_workspaces(wslist) return resp diff --git a/tests/api/test_v1_providers.py b/tests/api/test_v1_providers.py index f45ac2226..fc0ef6ace 100644 --- a/tests/api/test_v1_providers.py +++ b/tests/api/test_v1_providers.py @@ -216,6 +216,81 @@ async def test_providers_crud( assert len(providers) == 0 +@pytest.mark.asyncio +async def test_update_provider_endpoint( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + ): + """Test updating a provider endpoint.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create initial provider + provider_payload = { + "name": "test-provider", + "description": "Initial description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.initial.com", + "api_key": "initial-key", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + initial_provider = response.json() + + # Update the provider + updated_payload = { + "name": "test-provider-updated", + "description": "Updated description", + "auth_type": "api_key", + "provider_type": "openai", + "endpoint": "https://api.updated.com", + "api_key": "updated-key", + } + + response = await ac.put( + "/api/v1/provider-endpoints/test-provider", json=updated_payload + ) + assert response.status_code == 200 + updated_provider = response.json() + + # Verify fields were updated + assert updated_provider["name"] == updated_payload["name"] + assert updated_provider["description"] == updated_payload["description"] + assert updated_provider["auth_type"] == updated_payload["auth_type"] + assert updated_provider["provider_type"] == updated_payload["provider_type"] + assert updated_provider["endpoint"] == updated_payload["endpoint"] + assert updated_provider["id"] == initial_provider["id"] + + # Get OpenRouter provider by name + response = await ac.get("/api/v1/provider-endpoints/test-provider-updated") + assert response.status_code == 200 + provider = response.json() + assert provider["name"] == updated_payload["name"] + assert provider["description"] == updated_payload["description"] + assert provider["auth_type"] == updated_payload["auth_type"] + assert provider["provider_type"] == updated_payload["provider_type"] + assert provider["endpoint"] == updated_payload["endpoint"] + assert isinstance(provider["id"], str) and provider["id"] + + # Test updating non-existent provider + response = await ac.put( + "/api/v1/provider-endpoints/fake-provider", json=updated_payload + ) + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" + + @pytest.mark.asyncio async def test_list_providers_by_name( mock_pipeline_factory, mock_workspace_crud, mock_provider_crud @@ -389,10 +464,10 @@ async def test_list_models_by_provider( response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) assert response.status_code == 201 provider = response.json() - provider_id = provider["id"] + provider_name = provider["name"] # Get models for the provider - response = await ac.get(f"/api/v1/provider-endpoints/{provider_id}/models") + response = await ac.get(f"/api/v1/provider-endpoints/{provider_name}/models") assert response.status_code == 200 models = response.json() @@ -407,7 +482,54 @@ async def test_list_models_by_provider( assert all(model["provider_name"] == "openai-provider" for model in models) # Test with non-existent provider ID - fake_uuid = str(uuid()) - response = await ac.get(f"/api/v1/provider-endpoints/{fake_uuid}/models") + fake_name = "foo-bar" + response = await ac.get(f"/api/v1/provider-endpoints/{fake_name}/models") assert response.status_code == 404 assert response.json()["detail"] == "Provider not found" + + +@pytest.mark.asyncio +async def test_configure_auth_material( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + ): + """Test configuring auth material for a provider.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create provider + provider_payload = { + "name": "test-provider", + "description": "Test provider", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.test.com", + "api_key": "test-key", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + + # Configure auth material + auth_material = {"api_key": "sk-proj-foo-bar-123-xyz", "auth_type": "api_key"} + + response = await ac.put( + "/api/v1/provider-endpoints/test-provider/auth-material", json=auth_material + ) + assert response.status_code == 204 + + # Test with non-existent provider + response = await ac.put( + "/api/v1/provider-endpoints/fake-provider/auth-material", json=auth_material + ) + assert response.status_code == 404 + assert response.json()["detail"] == "Provider endpoint not found" From 2184910ce0977fac2b37b427b0b30768281121c4 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 09:03:40 +0000 Subject: [PATCH 09/23] fix alembic head conflict --- .../versions/2025_03_07_0902-1ee1be2156f7_.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 migrations/versions/2025_03_07_0902-1ee1be2156f7_.py diff --git a/migrations/versions/2025_03_07_0902-1ee1be2156f7_.py b/migrations/versions/2025_03_07_0902-1ee1be2156f7_.py new file mode 100644 index 000000000..fc273fff7 --- /dev/null +++ b/migrations/versions/2025_03_07_0902-1ee1be2156f7_.py @@ -0,0 +1,23 @@ +"""empty message + +Revision ID: 1ee1be2156f7 +Revises: e4c05d7591a8, 4b81c45b5da6 +Create Date: 2025-03-07 09:02:54.636452+00:00 + +""" + +from typing import Sequence, Union + +# revision identifiers, used by Alembic. +revision: str = "1ee1be2156f7" +down_revision: Union[str, None] = ("e4c05d7591a8", "4b81c45b5da6") +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass From c46ddcba02c89da19c305d7b330219714c601a6f Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 10:32:23 +0000 Subject: [PATCH 10/23] bug fixes & testing --- api/openapi.json | 98 ++++++------ ...126-e4c05d7591a8_add_installation_table.py | 2 - src/codegate/api/v1.py | 14 +- src/codegate/cli.py | 2 +- src/codegate/db/connection.py | 35 ++++- src/codegate/workspaces/crud.py | 6 +- tests/api/test_v1_workspaces.py | 143 ++++++++++++++++++ 7 files changed, 237 insertions(+), 63 deletions(-) diff --git a/api/openapi.json b/api/openapi.json index ad8823642..ba98677d4 100644 --- a/api/openapi.json +++ b/api/openapi.json @@ -148,7 +148,7 @@ } } }, - "/api/v1/provider-endpoints/{provider_id}/models": { + "/api/v1/provider-endpoints/{provider_name}/models": { "get": { "tags": [ "CodeGate API", @@ -159,13 +159,12 @@ "operationId": "v1_list_models_by_provider", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], @@ -240,14 +239,14 @@ } } }, - "delete": { + "put": { "tags": [ "CodeGate API", "Providers" ], - "summary": "Delete Provider Endpoint", - "description": "Delete a provider endpoint by id.", - "operationId": "v1_delete_provider_endpoint", + "summary": "Update Provider Endpoint", + "description": "Update a provider endpoint by name.", + "operationId": "v1_update_provider_endpoint", "parameters": [ { "name": "provider_name", @@ -259,12 +258,24 @@ } } ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ProviderEndpoint" + } + } + } + }, "responses": { "200": { "description": "Successful Response", "content": { "application/json": { - "schema": {} + "schema": { + "$ref": "#/components/schemas/ProviderEndpoint" + } } } }, @@ -279,42 +290,34 @@ } } } - } - }, - "/api/v1/provider-endpoints/{provider_id}/auth-material": { - "put": { + }, + "delete": { "tags": [ "CodeGate API", "Providers" ], - "summary": "Configure Auth Material", - "description": "Configure auth material for a provider.", - "operationId": "v1_configure_auth_material", + "summary": "Delete Provider Endpoint", + "description": "Delete a provider endpoint by name.", + "operationId": "v1_delete_provider_endpoint", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], - "requestBody": { - "required": true, - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ConfigureAuthMaterial" + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} } } - } - }, - "responses": { - "204": { - "description": "Successful Response" }, "422": { "description": "Validation Error", @@ -329,24 +332,23 @@ } } }, - "/api/v1/provider-endpoints/{provider_id}": { + "/api/v1/provider-endpoints/{provider_name}/auth-material": { "put": { "tags": [ "CodeGate API", "Providers" ], - "summary": "Update Provider Endpoint", - "description": "Update a provider endpoint by ID.", - "operationId": "v1_update_provider_endpoint", + "summary": "Configure Auth Material", + "description": "Configure auth material for a provider.", + "operationId": "v1_configure_auth_material", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "path", "required": true, "schema": { "type": "string", - "format": "uuid", - "title": "Provider Id" + "title": "Provider Name" } } ], @@ -355,21 +357,14 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ProviderEndpoint" + "$ref": "#/components/schemas/ConfigureAuthMaterial" } } } }, "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ProviderEndpoint" - } - } - } + "204": { + "description": "Successful Response" }, "422": { "description": "Validation Error", @@ -391,24 +386,23 @@ "Workspaces" ], "summary": "List Workspaces", - "description": "List all workspaces.\n\nArgs:\n provider_id (Optional[UUID]): Filter workspaces by provider ID. If provided,\n will return workspaces where models from the specified provider (e.g., OpenAI,\n Anthropic) have been used in workspace muxing rules. Note that you must\n refer to a provider by ID, not by name.\n\nReturns:\n ListWorkspacesResponse: A response object containing the list of workspaces.", + "description": "List all workspaces.\n\nArgs:\n provider_name (Optional[str]): Filter workspaces by provider name. If provided,\n will return workspaces where models from the specified provider (e.g., OpenAI,\n Anthropic) have been used in workspace muxing rules.\n\nReturns:\n ListWorkspacesResponse: A response object containing the list of workspaces.", "operationId": "v1_list_workspaces", "parameters": [ { - "name": "provider_id", + "name": "provider_name", "in": "query", "required": false, "schema": { "anyOf": [ { - "type": "string", - "format": "uuid" + "type": "string" }, { "type": "null" } ], - "title": "Provider Id" + "title": "Provider Name" } } ], diff --git a/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py index 775e3967b..9e2b6c130 100644 --- a/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py +++ b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py @@ -9,8 +9,6 @@ from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. revision: str = "e4c05d7591a8" diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index cc7c447a5..195672e96 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -51,7 +51,14 @@ async def _add_provider_id_to_mux_rule( f'Provider "{mux_rule.provider_name}" of type "{mux_rule.provider_type}" not found' # noqa: E501 ) - return mux_models.MuxRuleWithProviderId(**mux_rule.model_dump(), provider_id=provider.id) + return mux_models.MuxRuleWithProviderId( + matcher=mux_rule.matcher, + matcher_type=mux_rule.matcher_type, + model=mux_rule.model, + provider_type=provider.provider_type, + provider_id=provider.id, + provider_name=provider.name, + ) class FilterByNameParams(BaseModel): @@ -272,6 +279,8 @@ async def list_workspaces( wslist = await wscrud.get_workspaces() resp = v1_models.ListWorkspacesResponse.from_db_workspaces_with_sessioninfo(wslist) return resp + except provendcrud.ProviderNotFoundError: + return v1_models.ListWorkspacesResponse(workspaces=[]) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -332,7 +341,8 @@ async def create_workspace( ) except crud.WorkspaceCrudError as e: raise HTTPException(status_code=400, detail=str(e)) - except Exception: + except Exception as e: + logger.debug(f"Error creating workspace: {e}") raise HTTPException(status_code=500, detail="Internal server error") return v1_models.FullWorkspace( diff --git a/src/codegate/cli.py b/src/codegate/cli.py index 1ae3f9c22..5c08821c5 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -16,8 +16,8 @@ from codegate.config import Config, ConfigurationError from codegate.db.connection import ( init_db_sync, - init_session_if_not_exists, init_instance, + init_session_if_not_exists, ) from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.sensitive_data.manager import SensitiveDataManager diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 1507ed083..004a2c231 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -414,6 +414,16 @@ async def soft_delete_workspace(self, workspace: WorkspaceRow) -> Optional[Works return deleted_workspace async def hard_delete_workspace(self, workspace: WorkspaceRow) -> Optional[WorkspaceRow]: + # First delete associated muxes + sql_delete_muxes = text( + """ + DELETE FROM muxes + WHERE workspace_id = :id + """ + ) + await self._execute_with_no_return(sql_delete_muxes, {"id": workspace.id}) + + # Then delete the workspace sql = text( """ DELETE FROM workspaces @@ -472,7 +482,26 @@ async def delete_provider_endpoint( self, provider: ProviderEndpoint, ) -> Optional[ProviderEndpoint]: - sql = text( + # Delete from provider_models + sql_delete_provider_models = text( + """ + DELETE FROM provider_models + WHERE provider_endpoint_id = :id + """ + ) + await self._execute_with_no_return(sql_delete_provider_models, {"id": provider.id}) + + # Delete from muxes + sql_delete_muxes = text( + """ + DELETE FROM muxes + WHERE provider_endpoint_id = :id + """ + ) + await self._execute_with_no_return(sql_delete_muxes, {"id": provider.id}) + + # Delete from provider_endpoints + sql_delete_provider_endpoints = text( """ DELETE FROM provider_endpoints WHERE id = :id @@ -480,7 +509,7 @@ async def delete_provider_endpoint( """ ) deleted_provider = await self._execute_update_pydantic_model( - provider, sql, should_raise=True + provider, sql_delete_provider_endpoints, should_raise=True ) return deleted_provider @@ -621,7 +650,7 @@ async def init_instance(self) -> None: await self._execute_with_no_return(sql, instance.model_dump()) except IntegrityError as e: logger.debug(f"Exception type: {type(e)}") - raise AlreadyExistsError(f"Instance already initialized.") + raise AlreadyExistsError("Instance already initialized.") class DbReader(DbCodeGate): diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 26c557246..eef2ce475 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -341,9 +341,9 @@ async def set_muxes( # Add the new muxes priority = 0 - muxes_with_routes: List[ - Tuple[mux_models.MuxRuleWithProviderId, rulematcher.ModelRoute] - ] = [] + muxes_with_routes: List[Tuple[mux_models.MuxRuleWithProviderId, rulematcher.ModelRoute]] = ( + [] + ) # Verify all models are valid for mux in muxes: diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py index 105ff2d38..73565ebc9 100644 --- a/tests/api/test_v1_workspaces.py +++ b/tests/api/test_v1_workspaces.py @@ -217,6 +217,143 @@ async def test_workspace_crud( assert rule["matcher_type"] == muxing_rules_2[i]["matcher_type"] +@pytest.mark.asyncio +async def test_create_workspace_with_mux_different_provider_name( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader +) -> None: + with ( + patch("codegate.api.v1.dbreader", db_reader), + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + ): + """Test creating a workspace with mux rules, then recreating it after renaming the provider.""" + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + # Create initial provider + provider_payload = { + "name": "test-provider-1", + "description": "Test provider", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.test.com", + "api_key": "test-key", + } + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload) + assert response.status_code == 201 + + # Create workspace with mux rules + workspace_name = str(uuid()) + muxing_rules = [ + { + "provider_name": "test-provider-1", + "provider_type": "openai", + "model": "gpt-4", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "test-provider-1", + "provider_type": "openai", + "model": "gpt-3.5-turbo", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + workspace_payload = { + "name": workspace_name, + "config": { + "custom_instructions": "Test instructions", + "muxing_rules": muxing_rules, + }, + } + + response = await ac.post("/api/v1/workspaces", json=workspace_payload) + assert response.status_code == 201 + + # Get workspace config as JSON blob + response = await ac.get(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 200 + workspace_blob = response.json() + + # Delete workspace + response = await ac.delete(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 204 + response = await ac.delete(f"/api/v1/workspaces/archive/{workspace_name}") + assert response.status_code == 204 + + # Verify workspace is deleted + response = await ac.get(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 404 + + # Update provider name + rename_provider_payload = { + "name": "test-provider-2", + "description": "Test provider", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.test.com", + "api_key": "test-key", + } + + response = await ac.put( + "/api/v1/provider-endpoints/test-provider-1", json=rename_provider_payload + ) + assert response.status_code == 200 + + # Verify old provider name no longer exists + response = await ac.get("/api/v1/provider-endpoints/test-provider-1") + assert response.status_code == 404 + + # Verify provider exists under new name + response = await ac.get("/api/v1/provider-endpoints/test-provider-2") + assert response.status_code == 200 + provider = response.json() + assert provider["name"] == "test-provider-2" + assert provider["description"] == "Test provider" + assert provider["auth_type"] == "none" + assert provider["provider_type"] == "openai" + assert provider["endpoint"] == "https://api.test.com" + + # re-upload the workspace that we have previously downloaded + + response = await ac.post("/api/v1/workspaces", json=workspace_blob) + assert response.status_code == 201 + + # Verify new workspace config + response = await ac.get(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 200 + new_workspace = response.json() + + assert new_workspace["name"] == workspace_name + assert ( + new_workspace["config"]["custom_instructions"] + == workspace_blob["config"]["custom_instructions"] + ) + + # Verify muxing rules are correct with updated provider name + for i, rule in enumerate(new_workspace["config"]["muxing_rules"]): + assert rule["provider_name"] == "test-provider-2" + assert ( + rule["provider_type"] + == workspace_blob["config"]["muxing_rules"][i]["provider_type"] + ) + assert rule["model"] == workspace_blob["config"]["muxing_rules"][i]["model"] + assert rule["matcher"] == workspace_blob["config"]["muxing_rules"][i]["matcher"] + assert ( + rule["matcher_type"] + == workspace_blob["config"]["muxing_rules"][i]["matcher_type"] + ) + + @pytest.mark.asyncio async def test_rename_workspace( mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader @@ -608,6 +745,12 @@ async def test_list_workspaces_by_provider_name( assert len(response_body["workspaces"]) == 1 assert response_body["workspaces"][0]["name"] == name_2 + # List workspaces filtered by non-existent provider + response = await ac.get("/api/v1/workspaces?provider_name=foo-bar-123") + assert response.status_code == 200 + response_body = response.json() + assert len(response_body["workspaces"]) == 0 + # List workspaces unfiltered response = await ac.get("/api/v1/workspaces") assert response.status_code == 200 From 3919db8478df290c2af3ac02eb8e4cf45344fd40 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 10:33:34 +0000 Subject: [PATCH 11/23] lint fix --- src/codegate/db/connection.py | 8 ++++---- tests/api/test_v1_workspaces.py | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 004a2c231..1df2ab67f 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -766,10 +766,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( # If trigger category is None we want to get all alerts trigger_category = trigger_category if trigger_category else "%" conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[ - IntermediatePromptWithOutputUsageAlerts - ] = await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + rows: List[IntermediatePromptWithOutputUsageAlerts] = ( + await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + ) ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in rows: diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py index 73565ebc9..afafe7ff9 100644 --- a/tests/api/test_v1_workspaces.py +++ b/tests/api/test_v1_workspaces.py @@ -230,7 +230,10 @@ async def test_create_workspace_with_mux_different_provider_name( return_value=["gpt-4", "gpt-3.5-turbo"], ), ): - """Test creating a workspace with mux rules, then recreating it after renaming the provider.""" + """ + Test creating a workspace with mux rules, then recreating it after + renaming the provider. + """ app = init_app(mock_pipeline_factory) async with AsyncClient( From 68431531181966b67fe24c6cbaaef304a04049f8 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 11:03:04 +0000 Subject: [PATCH 12/23] fix integration tests --- tests/integration/anthropic/testcases.yaml | 2 ++ tests/integration/llamacpp/testcases.yaml | 2 ++ tests/integration/ollama/testcases.yaml | 2 ++ tests/integration/openai/testcases.yaml | 2 ++ tests/integration/openrouter/testcases.yaml | 2 ++ tests/integration/vllm/testcases.yaml | 2 ++ 6 files changed, 12 insertions(+) diff --git a/tests/integration/anthropic/testcases.yaml b/tests/integration/anthropic/testcases.yaml index 03f8f6667..1b50ea79d 100644 --- a/tests/integration/anthropic/testcases.yaml +++ b/tests/integration/anthropic/testcases.yaml @@ -24,6 +24,8 @@ muxing: Content-Type: application/json rules: - model: claude-3-5-haiku-20241022 + provider_name: anthropic_muxing + provider_type: anthropic matcher_type: catch_all matcher: "" diff --git a/tests/integration/llamacpp/testcases.yaml b/tests/integration/llamacpp/testcases.yaml index 69ec72df6..f7422991d 100644 --- a/tests/integration/llamacpp/testcases.yaml +++ b/tests/integration/llamacpp/testcases.yaml @@ -23,6 +23,8 @@ muxing: Content-Type: application/json rules: - model: qwen2.5-coder-0.5b-instruct-q5_k_m + provider_name: llamacpp_muxing + provider_type: llamacpp matcher_type: catch_all matcher: "" diff --git a/tests/integration/ollama/testcases.yaml b/tests/integration/ollama/testcases.yaml index 56a13b571..691fe4faf 100644 --- a/tests/integration/ollama/testcases.yaml +++ b/tests/integration/ollama/testcases.yaml @@ -24,6 +24,8 @@ muxing: rules: - model: qwen2.5-coder:1.5b matcher_type: catch_all + provider_name: ollama_muxing + provider_type: ollama matcher: "" testcases: diff --git a/tests/integration/openai/testcases.yaml b/tests/integration/openai/testcases.yaml index 452dcce6f..fb3730798 100644 --- a/tests/integration/openai/testcases.yaml +++ b/tests/integration/openai/testcases.yaml @@ -24,6 +24,8 @@ muxing: Content-Type: application/json rules: - model: gpt-4o-mini + provider_name: openai_muxing + provider_type: openai matcher_type: catch_all matcher: "" diff --git a/tests/integration/openrouter/testcases.yaml b/tests/integration/openrouter/testcases.yaml index d64e0266a..818acd6a5 100644 --- a/tests/integration/openrouter/testcases.yaml +++ b/tests/integration/openrouter/testcases.yaml @@ -24,6 +24,8 @@ muxing: Content-Type: application/json rules: - model: anthropic/claude-3.5-haiku + provider_name: openrouter_muxing + provider_type: openrouter matcher_type: catch_all matcher: "" diff --git a/tests/integration/vllm/testcases.yaml b/tests/integration/vllm/testcases.yaml index 52df95984..adba751e2 100644 --- a/tests/integration/vllm/testcases.yaml +++ b/tests/integration/vllm/testcases.yaml @@ -23,6 +23,8 @@ muxing: Content-Type: application/json rules: - model: Qwen/Qwen2.5-Coder-0.5B-Instruct + provider_name: vllm_muxing + provider_type: vllm matcher_type: catch_all matcher: "" From b5f193ad976857286f4ee930d556f607ab07fabb Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 14:38:17 +0000 Subject: [PATCH 13/23] fix bug where provider name not updated in muxes table after rename --- src/codegate/db/connection.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 1df2ab67f..ff3405291 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -476,6 +476,26 @@ async def update_provider_endpoint(self, provider: ProviderEndpoint) -> Provider updated_provider = await self._execute_update_pydantic_model( provider, sql, should_raise=True ) + + # Update dependent tables + update_muxes_sql = text( + """ + UPDATE muxes + SET provider_endpoint_name = :name, provider_endpoint_type = :provider_type + WHERE provider_endpoint_id = :id + """ + ) + await self._execute_with_no_return(update_muxes_sql, provider.model_dump()) + + update_provider_models_sql = text( + """ + UPDATE provider_models + SET provider_endpoint_name = :name, provider_endpoint_type = :provider_type + WHERE provider_endpoint_id = :id + """ + ) + await self._execute_with_no_return(update_provider_models_sql, provider.model_dump()) + return updated_provider async def delete_provider_endpoint( @@ -766,10 +786,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( # If trigger category is None we want to get all alerts trigger_category = trigger_category if trigger_category else "%" conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[IntermediatePromptWithOutputUsageAlerts] = ( - await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True - ) + rows: List[ + IntermediatePromptWithOutputUsageAlerts + ] = await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in rows: From 5f405f0e172a27895d8866bec29ffccad28626d2 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 14:53:18 +0000 Subject: [PATCH 14/23] address logger feedback --- src/codegate/api/v1.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 195672e96..cdb9e1246 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -118,7 +118,7 @@ async def list_models_by_provider( except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider not found") except Exception as e: - logger.debug(f"Error listing models by provider, {e}") + logger.exception("Error while listing models by provider") raise HTTPException(status_code=500, detail=str(e)) @@ -342,7 +342,7 @@ async def create_workspace( except crud.WorkspaceCrudError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: - logger.debug(f"Error creating workspace: {e}") + logger.exception("Error while creating workspace") raise HTTPException(status_code=500, detail="Internal server error") return v1_models.FullWorkspace( @@ -392,7 +392,7 @@ async def update_workspace( ), ) except crud.WorkspaceCrudError as e: - logger.debug(f"Could not update workspace: {e}") + logger.exception("Error while updating workspace") raise HTTPException(status_code=400, detail=str(e)) except Exception: raise HTTPException(status_code=500, detail="Internal server error") @@ -421,8 +421,8 @@ async def delete_workspace(workspace_name: str): raise HTTPException(status_code=500, detail="Internal server error") except crud.WorkspaceCrudError as e: raise HTTPException(status_code=400, detail=str(e)) - except crud.DeleteMuxesFromRegistryError as e: - logger.debug(f"Error deleting workspace {e}") + except crud.DeleteMuxesFromRegistryError: + logger.exception("Error deleting muxes while deleting workspace") raise HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) From d533e7f7bf63587f06c5eac00fb45e424d216c15 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 14:53:37 +0000 Subject: [PATCH 15/23] move `raise ProviderNotFoundError` into crud method --- src/codegate/api/v1.py | 22 +++++----------------- src/codegate/providers/crud/crud.py | 2 +- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index cdb9e1246..7925c3ec9 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -78,13 +78,12 @@ async def list_provider_endpoints( try: provend = await pcrud.get_endpoint_by_name(filter_query.name) + return [provend] + except pcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider endpoint not found") except Exception: raise HTTPException(status_code=500, detail="Internal server error") - if provend is None: - raise HTTPException(status_code=404, detail="Provider endpoint not found") - return [provend] - # This needs to be above /provider-endpoints/{provider_name} to avoid conflict @v1.get( @@ -112,8 +111,6 @@ async def list_models_by_provider( try: provider = await pcrud.get_endpoint_by_name(provider_name) - if provider is None: - raise provendcrud.ProviderNotFoundError return await pcrud.models_by_provider(provider.id) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider not found") @@ -131,11 +128,10 @@ async def get_provider_endpoint( """Get a provider endpoint by name.""" try: provend = await pcrud.get_endpoint_by_name(provider_name) + except pcrud.ProviderNotFoundError: + raise HTTPException(status_code=404, detail="Provider endpoint not found") except Exception: raise HTTPException(status_code=500, detail="Internal server error") - - if provend is None: - raise HTTPException(status_code=404, detail="Provider endpoint not found") return provend @@ -188,8 +184,6 @@ async def configure_auth_material( """Configure auth material for a provider.""" try: provider = await pcrud.get_endpoint_by_name(provider_name) - if provider is None: - raise provendcrud.ProviderNotFoundError await pcrud.configure_auth_material(provider.id, request) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") @@ -213,8 +207,6 @@ async def update_provider_endpoint( """Update a provider endpoint by name.""" try: provider = await pcrud.get_endpoint_by_name(provider_name) - if provider is None: - raise provendcrud.ProviderNotFoundError request.id = str(provider.id) provend = await pcrud.update_endpoint(request) except provendcrud.ProviderNotFoundError: @@ -242,8 +234,6 @@ async def delete_provider_endpoint( """Delete a provider endpoint by name.""" try: provider = await pcrud.get_endpoint_by_name(provider_name) - if provider is None: - raise provendcrud.ProviderNotFoundError await pcrud.delete_endpoint(provider.id) except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") @@ -270,8 +260,6 @@ async def list_workspaces( try: if provider_name: provider = await pcrud.get_endpoint_by_name(provider_name) - if provider is None: - raise provendcrud.ProviderNotFoundError wslist = await wscrud.workspaces_by_provider(provider.id) resp = v1_models.ListWorkspacesResponse.from_db_workspaces(wslist) return resp diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index 1e6142ee5..5f0ee22ab 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -67,7 +67,7 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider dbendpoint = await self._db_reader.get_provider_endpoint_by_name(name) if dbendpoint is None: - return None + raise ProviderNotFoundError(f'Provider "{name}" not found') return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) From 5f894bd04cace2386e8311a313946085b354a7a5 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 15:53:04 +0000 Subject: [PATCH 16/23] clean up converting API MuxRule to internal representation --- src/codegate/api/v1.py | 62 ++++++++++------------------- src/codegate/db/connection.py | 17 ++------ src/codegate/providers/crud/crud.py | 38 ++++++++++++++++++ 3 files changed, 62 insertions(+), 55 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index 7925c3ec9..b08ce40a0 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -35,32 +35,6 @@ def uniq_name(route: APIRoute): return f"v1_{route.name}" -async def _add_provider_id_to_mux_rule( - mux_rule: mux_models.MuxRule, -) -> mux_models.MuxRuleWithProviderId: - """ - Convert a `MuxRule` to `MuxRuleWithProviderId` by looking up the provider ID. - Extracts provider name and type from the MuxRule and queries the database to get the ID. - """ - provider = await dbreader.try_get_provider_endpoint_by_name_and_type( - mux_rule.provider_name, - mux_rule.provider_type, - ) - if provider is None: - raise crud.WorkspaceCrudError( - f'Provider "{mux_rule.provider_name}" of type "{mux_rule.provider_type}" not found' # noqa: E501 - ) - - return mux_models.MuxRuleWithProviderId( - matcher=mux_rule.matcher, - matcher_type=mux_rule.matcher_type, - model=mux_rule.model, - provider_type=provider.provider_type, - provider_id=provider.id, - provider_name=provider.name, - ) - - class FilterByNameParams(BaseModel): name: Optional[str] = None @@ -79,7 +53,7 @@ async def list_provider_endpoints( try: provend = await pcrud.get_endpoint_by_name(filter_query.name) return [provend] - except pcrud.ProviderNotFoundError: + except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") except Exception: raise HTTPException(status_code=500, detail="Internal server error") @@ -128,7 +102,7 @@ async def get_provider_endpoint( """Get a provider endpoint by name.""" try: provend = await pcrud.get_endpoint_by_name(provider_name) - except pcrud.ProviderNotFoundError: + except provendcrud.ProviderNotFoundError: raise HTTPException(status_code=404, detail="Provider endpoint not found") except Exception: raise HTTPException(status_code=500, detail="Internal server error") @@ -220,6 +194,7 @@ async def update_provider_endpoint( detail=str(e), ) except Exception as e: + logger.exception("Error while updating provider endpoint") raise HTTPException(status_code=500, detail=str(e)) return provend @@ -308,14 +283,12 @@ async def create_workspace( """Create a new workspace.""" try: custom_instructions = request.config.custom_instructions if request.config else None - muxing_rules: List[mux_models.MuxRuleWithProviderId] = [] + mux_rules = [] if request.config and request.config.muxing_rules: - for rule in request.config.muxing_rules: - mux_rule_with_provider = await _add_provider_id_to_mux_rule(rule) - muxing_rules.append(mux_rule_with_provider) + mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request.config.muxing_rules) workspace_row, mux_rules = await wscrud.add_workspace( - request.name, custom_instructions, muxing_rules + request.name, custom_instructions, mux_rules ) except crud.WorkspaceNameAlreadyInUseError: raise HTTPException(status_code=409, detail="Workspace name already in use") @@ -327,9 +300,13 @@ async def create_workspace( "Please use only alphanumeric characters, hyphens, or underscores." ), ) + except provendcrud.ProviderNotFoundError as e: + logger.exception("Error matching a provider for a mux rule while creating a workspace") + raise HTTPException(status_code=400, detail=str(e)) except crud.WorkspaceCrudError as e: + logger.exception("Error while create a workspace") raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: + except Exception: logger.exception("Error while creating workspace") raise HTTPException(status_code=500, detail="Internal server error") @@ -355,18 +332,18 @@ async def update_workspace( """Update a workspace.""" try: custom_instructions = request.config.custom_instructions if request.config else None - muxing_rules: List[mux_models.MuxRuleWithProviderId] = [] + mux_rules = [] if request.config and request.config.muxing_rules: - for rule in request.config.muxing_rules: - mux_rule_with_provider = await _add_provider_id_to_mux_rule(rule) - muxing_rules.append(mux_rule_with_provider) + mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request.config.muxing_rules) workspace_row, mux_rules = await wscrud.update_workspace( workspace_name, request.name, custom_instructions, - muxing_rules, + mux_rules, ) + except provendcrud.ProviderNotFoundError as e: + raise HTTPException(status_code=400, detail=str(e)) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except crud.WorkspaceNameAlreadyInUseError: @@ -640,11 +617,12 @@ async def set_workspace_muxes( """Set the mux rules of a workspace.""" try: mux_rules = [] - for rule in request: - mux_rule_with_provider = await _add_provider_id_to_mux_rule(rule) - mux_rules.append(mux_rule_with_provider) + if request.config and request.config.muxing_rules: + mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request.config.muxing_rules) await wscrud.set_muxes(workspace_name, mux_rules) + except provendcrud.ProviderNotFoundError as e: + raise HTTPException(status_code=400, detail=str(e)) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except crud.WorkspaceCrudError as e: diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index ff3405291..d3ce3470d 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -487,15 +487,6 @@ async def update_provider_endpoint(self, provider: ProviderEndpoint) -> Provider ) await self._execute_with_no_return(update_muxes_sql, provider.model_dump()) - update_provider_models_sql = text( - """ - UPDATE provider_models - SET provider_endpoint_name = :name, provider_endpoint_type = :provider_type - WHERE provider_endpoint_id = :id - """ - ) - await self._execute_with_no_return(update_provider_models_sql, provider.model_dump()) - return updated_provider async def delete_provider_endpoint( @@ -786,10 +777,10 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id( # If trigger category is None we want to get all alerts trigger_category = trigger_category if trigger_category else "%" conditions = {"workspace_id": workspace_id, "trigger_category": trigger_category} - rows: List[ - IntermediatePromptWithOutputUsageAlerts - ] = await self._exec_select_conditions_to_pydantic( - IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + rows: List[IntermediatePromptWithOutputUsageAlerts] = ( + await self._exec_select_conditions_to_pydantic( + IntermediatePromptWithOutputUsageAlerts, sql, conditions, should_raise=True + ) ) prompts_dict: Dict[str, GetPromptWithOutputsRow] = {} for row in rows: diff --git a/src/codegate/providers/crud/crud.py b/src/codegate/providers/crud/crud.py index 5f0ee22ab..56ba63089 100644 --- a/src/codegate/providers/crud/crud.py +++ b/src/codegate/providers/crud/crud.py @@ -10,6 +10,7 @@ from codegate.config import Config from codegate.db import models as dbmodels from codegate.db.connection import DbReader, DbRecorder +from codegate.muxing import models as mux_models from codegate.providers.base import BaseProvider from codegate.providers.registry import ProviderRegistry, get_provider_registry from codegate.workspaces import crud as workspace_crud @@ -71,6 +72,43 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + async def _try_get_endpoint_by_name_and_type( + self, name: str, type: Optional[str] + ) -> Optional[apimodelsv1.ProviderEndpoint]: + """ + Try to get an endpoint by name & type, + falling back to a "best effort" match by type. + """ + + dbendpoint = await self._db_reader.try_get_provider_endpoint_by_name_and_type(name, type) + if dbendpoint is None: + raise ProviderNotFoundError(f'Provider "{name}" not found') + + return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) + + async def add_provider_id_to_mux_rule( + self, rule: mux_models.MuxRule + ) -> mux_models.MuxRuleWithProviderId: + endpoint = await self._try_get_endpoint_by_name_and_type( + rule.provider_name, rule.provider_type + ) + return mux_models.MuxRuleWithProviderId( + model=rule.model, + matcher=rule.matcher, + matcher_type=rule.matcher_type, + provider_name=endpoint.name, + provider_type=endpoint.provider_type, + provider_id=endpoint.id, + ) + + async def add_provider_ids_to_mux_rule_list( + self, rules: List[mux_models.MuxRule] + ) -> List[mux_models.MuxRuleWithProviderId]: + rules_with_ids = [] + for rule in rules: + rules_with_ids.append(await self.add_provider_id_to_mux_rule(rule)) + return rules_with_ids + async def add_endpoint( self, endpoint: apimodelsv1.AddProviderEndpointRequest ) -> apimodelsv1.ProviderEndpoint: From ba0a2f624c92e900bd9fe5fbbc48e10ee8cc0833 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 7 Mar 2025 16:03:32 +0000 Subject: [PATCH 17/23] flatten migrations --- ..._add_provider_endpoint_fields_to_muxes.py} | 52 ++++++--- ...da6_add_provider_endpoint_name_to_muxes.py | 100 ------------------ .../versions/2025_03_07_0902-1ee1be2156f7_.py | 23 ---- 3 files changed, 35 insertions(+), 140 deletions(-) rename migrations/versions/{2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py => 2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py} (60%) delete mode 100644 migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py delete mode 100644 migrations/versions/2025_03_07_0902-1ee1be2156f7_.py diff --git a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py similarity index 60% rename from migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py rename to migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py index 2ce7928de..347bd11a0 100644 --- a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_type_to_muxes.py +++ b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py @@ -1,7 +1,7 @@ -"""add provider_endpoint_type to muxes +"""add provider endpoint fields to muxes Revision ID: 769f09b6d992 -Revises: 3ec2b4ab569c +Revises: e4c05d7591a8 Create Date: 2025-03-06 11:30:11.647216+00:00 """ @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. revision: str = "769f09b6d992" -down_revision: Union[str, None] = "3ec2b4ab569c" +down_revision: Union[str, None] = "e4c05d7591a8" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,29 +21,40 @@ def upgrade() -> None: # Begin transaction op.execute("BEGIN TRANSACTION;") - # Add the new column + # Add the new columns op.execute( """ ALTER TABLE muxes ADD COLUMN provider_endpoint_type TEXT; """ ) + op.execute( + """ + ALTER TABLE muxes + ADD COLUMN provider_endpoint_name TEXT; + """ + ) - # Update the new column with data from provider_endpoints + # Update both new columns with data from provider_endpoints op.execute( """ UPDATE muxes - SET provider_endpoint_type = ( - SELECT provider_type - FROM provider_endpoints - WHERE provider_endpoints.id = muxes.provider_endpoint_id - ); + SET + provider_endpoint_type = ( + SELECT provider_type + FROM provider_endpoints + WHERE provider_endpoints.id = muxes.provider_endpoint_id + ), + provider_endpoint_name = ( + SELECT name + FROM provider_endpoints + WHERE provider_endpoints.id = muxes.provider_endpoint_id + ); """ ) - # Make the column NOT NULL after populating it - # SQLite is funny about altering columns, so we actually need to clone & - # swap the table + # Make the columns NOT NULL after populating them + # SQLite requires table recreation for this op.execute("CREATE TABLE muxes_new AS SELECT * FROM muxes;") op.execute("DROP TABLE muxes;") op.execute( @@ -59,6 +70,7 @@ def upgrade() -> None: created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, provider_endpoint_type TEXT NOT NULL, + provider_endpoint_name TEXT NOT NULL, FOREIGN KEY(provider_endpoint_id) REFERENCES provider_endpoints(id) );""" ) @@ -74,24 +86,30 @@ def downgrade() -> None: op.execute("BEGIN TRANSACTION;") try: - # Check if the column exists + # Check if the columns exist op.execute( """ - SELECT provider_endpoint_type + SELECT provider_endpoint_type, provider_endpoint_name FROM muxes LIMIT 1; """ ) - # Drop the column only if it exists + # Drop both columns if they exist op.execute( """ ALTER TABLE muxes DROP COLUMN provider_endpoint_type; """ ) + op.execute( + """ + ALTER TABLE muxes + DROP COLUMN provider_endpoint_name; + """ + ) except Exception: - # If there's an error (column doesn't exist), rollback and continue + # If there's an error (columns don't exist), rollback and continue op.execute("ROLLBACK;") return diff --git a/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py b/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py deleted file mode 100644 index 15cf26837..000000000 --- a/migrations/versions/2025_03_06_1324-4b81c45b5da6_add_provider_endpoint_name_to_muxes.py +++ /dev/null @@ -1,100 +0,0 @@ -"""add provider_endpoint_name to muxes - -Revision ID: 4b81c45b5da6 -Revises: 769f09b6d992 -Create Date: 2025-03-06 13:24:41.123857+00:00 - -""" - -from typing import Sequence, Union - -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "4b81c45b5da6" -down_revision: Union[str, None] = "769f09b6d992" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # Begin transaction - op.execute("BEGIN TRANSACTION;") - - # Add the new column - op.execute( - """ - ALTER TABLE muxes - ADD COLUMN provider_endpoint_name TEXT; - """ - ) - - # Update the new column with data from provider_endpoints - op.execute( - """ - UPDATE muxes - SET provider_endpoint_name = ( - SELECT name - FROM provider_endpoints - WHERE provider_endpoints.id = muxes.provider_endpoint_id - ); - """ - ) - - # Make the column NOT NULL after populating it - # SQLite is funny about altering columns, so we actually need to clone & - # swap the table - op.execute("CREATE TABLE muxes_new AS SELECT * FROM muxes;") - op.execute("DROP TABLE muxes;") - op.execute( - """ - CREATE TABLE muxes ( - id TEXT PRIMARY KEY, - provider_endpoint_id TEXT NOT NULL, - provider_model_name TEXT NOT NULL, - workspace_id TEXT NOT NULL, - matcher_type TEXT NOT NULL, - matcher_blob TEXT NOT NULL, - priority INTEGER NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - provider_endpoint_type TEXT NOT NULL, - provider_endpoint_name TEXT NOT NULL, - FOREIGN KEY(provider_endpoint_id) REFERENCES provider_endpoints(id) - );""" - ) - op.execute("INSERT INTO muxes SELECT * FROM muxes_new;") - op.execute("DROP TABLE muxes_new;") - - # Finish transaction - op.execute("COMMIT;") - - -def downgrade() -> None: - # Begin transaction - op.execute("BEGIN TRANSACTION;") - - try: - # Check if the column exists - op.execute( - """ - SELECT provider_endpoint_name - FROM muxes - LIMIT 1; - """ - ) - - # Drop the column only if it exists - op.execute( - """ - ALTER TABLE muxes - DROP COLUMN provider_endpoint_name; - """ - ) - except Exception: - # If there's an error (column doesn't exist), rollback and continue - op.execute("ROLLBACK;") - return - - # Finish transaction - op.execute("COMMIT;") diff --git a/migrations/versions/2025_03_07_0902-1ee1be2156f7_.py b/migrations/versions/2025_03_07_0902-1ee1be2156f7_.py deleted file mode 100644 index fc273fff7..000000000 --- a/migrations/versions/2025_03_07_0902-1ee1be2156f7_.py +++ /dev/null @@ -1,23 +0,0 @@ -"""empty message - -Revision ID: 1ee1be2156f7 -Revises: e4c05d7591a8, 4b81c45b5da6 -Create Date: 2025-03-07 09:02:54.636452+00:00 - -""" - -from typing import Sequence, Union - -# revision identifiers, used by Alembic. -revision: str = "1ee1be2156f7" -down_revision: Union[str, None] = ("e4c05d7591a8", "4b81c45b5da6") -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - pass - - -def downgrade() -> None: - pass From 5234e53a3a802c52b960e4e1647dc6822433f835 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Mon, 10 Mar 2025 08:44:36 +0000 Subject: [PATCH 18/23] fix 500 error when deleting workspace w. no mux rules --- src/codegate/api/v1.py | 10 +- src/codegate/workspaces/crud.py | 4 +- tests/api/test_v1_workspaces.py | 175 +++++++++++++++++++++++++++++++- 3 files changed, 181 insertions(+), 8 deletions(-) diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index b08ce40a0..8d42a5a95 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -382,13 +382,14 @@ async def delete_workspace(workspace_name: str): _ = await wscrud.soft_delete_workspace(workspace_name) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") - except crud.DeleteMuxesFromRegistryError: - raise HTTPException(status_code=500, detail="Internal server error") except crud.WorkspaceCrudError as e: raise HTTPException(status_code=400, detail=str(e)) except crud.DeleteMuxesFromRegistryError: logger.exception("Error deleting muxes while deleting workspace") raise HTTPException(status_code=500, detail="Internal server error") + except Exception: + logger.exception("Error while deleting workspace") + raise HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) @@ -616,10 +617,7 @@ async def set_workspace_muxes( ): """Set the mux rules of a workspace.""" try: - mux_rules = [] - if request.config and request.config.muxing_rules: - mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request.config.muxing_rules) - + mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request) await wscrud.set_muxes(workspace_name, mux_rules) except provendcrud.ProviderNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index eef2ce475..4d0fb2a96 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -264,7 +264,9 @@ async def soft_delete_workspace(self, workspace_name: str): # Remove the muxes from the registry try: mux_registry = await rulematcher.get_muxing_rules_registry() - await mux_registry.delete_ws_rules(workspace_name) + rules = await mux_registry.get_ws_rules(workspace_name) + if rules: + await mux_registry.delete_ws_rules(workspace_name) except Exception: raise DeleteMuxesFromRegistryError( f"Error deleting mux rules for workspace {workspace_name}" diff --git a/tests/api/test_v1_workspaces.py b/tests/api/test_v1_workspaces.py index afafe7ff9..24db9f238 100644 --- a/tests/api/test_v1_workspaces.py +++ b/tests/api/test_v1_workspaces.py @@ -79,7 +79,180 @@ def mock_muxing_rules_registry(): @pytest.mark.asyncio -async def test_workspace_crud( +async def test_workspace_crud_name_only( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + ): + """Test creating and deleting a workspace by name only.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name: str = str(uuid()) + + # Create workspace + payload_create = {"name": name} + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Verify workspace exists + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 200 + assert response.json()["name"] == name + + # Delete workspace + response = await ac.delete(f"/api/v1/workspaces/{name}") + assert response.status_code == 204 + + # Verify workspace no longer exists + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_muxes_crud( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader +) -> None: + with ( + patch("codegate.api.v1.dbreader", db_reader), + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + patch( + "codegate.providers.openai.provider.OpenAIProvider.models", + return_value=["gpt-4", "gpt-3.5-turbo"], + ), + patch( + "codegate.providers.openrouter.provider.OpenRouterProvider.models", + return_value=["anthropic/claude-2", "deepseek/deepseek-r1"], + ), + ): + """Test creating and validating mux rules on a workspace.""" + + app = init_app(mock_pipeline_factory) + + provider_payload_1 = { + "name": "openai-provider", + "description": "OpenAI provider description", + "auth_type": "none", + "provider_type": "openai", + "endpoint": "https://api.openai.com", + "api_key": "sk-proj-foo-bar-123-xyz", + } + + provider_payload_2 = { + "name": "openrouter-provider", + "description": "OpenRouter provider description", + "auth_type": "none", + "provider_type": "openrouter", + "endpoint": "https://openrouter.ai/api", + "api_key": "sk-or-foo-bar-456-xyz", + } + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1) + assert response.status_code == 201 + + response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2) + assert response.status_code == 201 + + # Create workspace + workspace_name: str = str(uuid()) + custom_instructions: str = "Respond to every request in iambic pentameter" + payload_create = { + "name": workspace_name, + "config": { + "custom_instructions": custom_instructions, + "muxing_rules": [], + }, + } + + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Set mux rules + muxing_rules = [ + { + "provider_name": "openai-provider", + "provider_type": "openai", + "model": "gpt-4", + "matcher": "*.ts", + "matcher_type": "filename_match", + }, + { + "provider_name": "openrouter-provider", + "provider_type": "openrouter", + "model": "anthropic/claude-2", + "matcher_type": "catch_all", + "matcher": "", + }, + ] + + response = await ac.put(f"/api/v1/workspaces/{workspace_name}/muxes", json=muxing_rules) + assert response.status_code == 204 + + # Verify mux rules + response = await ac.get(f"/api/v1/workspaces/{workspace_name}") + assert response.status_code == 200 + response_body = response.json() + for i, rule in enumerate(response_body["config"]["muxing_rules"]): + assert rule["provider_name"] == muxing_rules[i]["provider_name"] + assert rule["provider_type"] == muxing_rules[i]["provider_type"] + assert rule["model"] == muxing_rules[i]["model"] + assert rule["matcher"] == muxing_rules[i]["matcher"] + assert rule["matcher_type"] == muxing_rules[i]["matcher_type"] + + +@pytest.mark.asyncio +async def test_create_workspace_and_add_custom_instructions( + mock_pipeline_factory, mock_workspace_crud, mock_provider_crud +) -> None: + with ( + patch("codegate.api.v1.wscrud", mock_workspace_crud), + patch("codegate.api.v1.pcrud", mock_provider_crud), + ): + """Test creating a workspace, adding custom + instructions, and validating them.""" + + app = init_app(mock_pipeline_factory) + + async with AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as ac: + name: str = str(uuid()) + + # Create workspace + payload_create = {"name": name} + response = await ac.post("/api/v1/workspaces", json=payload_create) + assert response.status_code == 201 + + # Add custom instructions + custom_instructions = "Respond to every request in iambic pentameter" + payload_instructions = {"prompt": custom_instructions} + response = await ac.put( + f"/api/v1/workspaces/{name}/custom-instructions", json=payload_instructions + ) + assert response.status_code == 204 + + # Validate custom instructions by getting the workspace + response = await ac.get(f"/api/v1/workspaces/{name}") + assert response.status_code == 200 + assert response.json()["config"]["custom_instructions"] == custom_instructions + + # Validate custom instructions by getting the custom instructions endpoint + response = await ac.get(f"/api/v1/workspaces/{name}/custom-instructions") + assert response.status_code == 200 + assert response.json()["prompt"] == custom_instructions + + +@pytest.mark.asyncio +async def test_workspace_crud_full_workspace( mock_pipeline_factory, mock_workspace_crud, mock_provider_crud, db_reader ) -> None: with ( From e5e4e77d64f7a2528440327128f05ad8ebd1951a Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 11 Mar 2025 15:46:43 +0000 Subject: [PATCH 19/23] fix possible inconsistent db state when muxes are deleted --- ...6d992_add_provider_endpoint_fields_to_muxes.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py index 347bd11a0..9aed5e0cd 100644 --- a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py +++ b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py @@ -35,7 +35,20 @@ def upgrade() -> None: """ ) - # Update both new columns with data from provider_endpoints + # Delete mux rules where provider_endpoint_id doesn't match a provider in + # the database + # This may seem extreme, but if the provider doesn't exist, the mux rule is + # invalid and will error anyway if we try to use it, so we should prevent this invalid state from ever existing. + # There is work on this branch to ensure that when a provider is deleted, + # the associated mux rules are also deleted. + op.execute( + """ + DELETE FROM muxes + WHERE provider_endpoint_id NOT IN (SELECT id FROM provider_endpoints); + """ + ) + + # Update remaining rows with provider endpoint details op.execute( """ UPDATE muxes From 7ff7a571fdd8610b6929944133d59ed4b8051407 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Tue, 11 Mar 2025 15:51:28 +0000 Subject: [PATCH 20/23] linter --- ...0-769f09b6d992_add_provider_endpoint_fields_to_muxes.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py index 9aed5e0cd..2c78f806c 100644 --- a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py +++ b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py @@ -37,13 +37,16 @@ def upgrade() -> None: # Delete mux rules where provider_endpoint_id doesn't match a provider in # the database + # # This may seem extreme, but if the provider doesn't exist, the mux rule is - # invalid and will error anyway if we try to use it, so we should prevent this invalid state from ever existing. + # invalid and will error anyway if we try to use it, so we should prevent + # this invalid state from ever existing. + # # There is work on this branch to ensure that when a provider is deleted, # the associated mux rules are also deleted. op.execute( """ - DELETE FROM muxes + DELETE FROM muxes WHERE provider_endpoint_id NOT IN (SELECT id FROM provider_endpoints); """ ) From c01216c7a490921b2994763b37085a98c05ecdc9 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Wed, 12 Mar 2025 14:05:19 +0000 Subject: [PATCH 21/23] address feedback on DB schema changes --- ...2_add_provider_endpoint_fields_to_muxes.py | 133 ------------------ src/codegate/api/v1.py | 47 +++++-- src/codegate/db/connection.py | 30 ++-- src/codegate/db/models.py | 2 - src/codegate/muxing/models.py | 19 ++- src/codegate/muxing/rulematcher.py | 11 +- src/codegate/workspaces/crud.py | 39 ++--- tests/muxing/test_rulematcher.py | 19 ++- 8 files changed, 91 insertions(+), 209 deletions(-) delete mode 100644 migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py diff --git a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py b/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py deleted file mode 100644 index 2c78f806c..000000000 --- a/migrations/versions/2025_03_06_1130-769f09b6d992_add_provider_endpoint_fields_to_muxes.py +++ /dev/null @@ -1,133 +0,0 @@ -"""add provider endpoint fields to muxes - -Revision ID: 769f09b6d992 -Revises: e4c05d7591a8 -Create Date: 2025-03-06 11:30:11.647216+00:00 - -""" - -from typing import Sequence, Union - -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "769f09b6d992" -down_revision: Union[str, None] = "e4c05d7591a8" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # Begin transaction - op.execute("BEGIN TRANSACTION;") - - # Add the new columns - op.execute( - """ - ALTER TABLE muxes - ADD COLUMN provider_endpoint_type TEXT; - """ - ) - op.execute( - """ - ALTER TABLE muxes - ADD COLUMN provider_endpoint_name TEXT; - """ - ) - - # Delete mux rules where provider_endpoint_id doesn't match a provider in - # the database - # - # This may seem extreme, but if the provider doesn't exist, the mux rule is - # invalid and will error anyway if we try to use it, so we should prevent - # this invalid state from ever existing. - # - # There is work on this branch to ensure that when a provider is deleted, - # the associated mux rules are also deleted. - op.execute( - """ - DELETE FROM muxes - WHERE provider_endpoint_id NOT IN (SELECT id FROM provider_endpoints); - """ - ) - - # Update remaining rows with provider endpoint details - op.execute( - """ - UPDATE muxes - SET - provider_endpoint_type = ( - SELECT provider_type - FROM provider_endpoints - WHERE provider_endpoints.id = muxes.provider_endpoint_id - ), - provider_endpoint_name = ( - SELECT name - FROM provider_endpoints - WHERE provider_endpoints.id = muxes.provider_endpoint_id - ); - """ - ) - - # Make the columns NOT NULL after populating them - # SQLite requires table recreation for this - op.execute("CREATE TABLE muxes_new AS SELECT * FROM muxes;") - op.execute("DROP TABLE muxes;") - op.execute( - """ - CREATE TABLE muxes ( - id TEXT PRIMARY KEY, - provider_endpoint_id TEXT NOT NULL, - provider_model_name TEXT NOT NULL, - workspace_id TEXT NOT NULL, - matcher_type TEXT NOT NULL, - matcher_blob TEXT NOT NULL, - priority INTEGER NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - provider_endpoint_type TEXT NOT NULL, - provider_endpoint_name TEXT NOT NULL, - FOREIGN KEY(provider_endpoint_id) REFERENCES provider_endpoints(id) - );""" - ) - op.execute("INSERT INTO muxes SELECT * FROM muxes_new;") - op.execute("DROP TABLE muxes_new;") - - # Finish transaction - op.execute("COMMIT;") - - -def downgrade() -> None: - # Begin transaction - op.execute("BEGIN TRANSACTION;") - - try: - # Check if the columns exist - op.execute( - """ - SELECT provider_endpoint_type, provider_endpoint_name - FROM muxes - LIMIT 1; - """ - ) - - # Drop both columns if they exist - op.execute( - """ - ALTER TABLE muxes - DROP COLUMN provider_endpoint_type; - """ - ) - op.execute( - """ - ALTER TABLE muxes - DROP COLUMN provider_endpoint_name; - """ - ) - except Exception: - # If there's an error (columns don't exist), rollback and continue - op.execute("ROLLBACK;") - return - - # Finish transaction - op.execute("COMMIT;") diff --git a/src/codegate/api/v1.py b/src/codegate/api/v1.py index ed781c5cb..651b2b543 100644 --- a/src/codegate/api/v1.py +++ b/src/codegate/api/v1.py @@ -290,9 +290,16 @@ async def create_workspace( if request.config and request.config.muxing_rules: mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request.config.muxing_rules) - workspace_row, mux_rules = await wscrud.add_workspace( + workspace_row, created_mux_rules = await wscrud.add_workspace( request.name, custom_instructions, mux_rules ) + + created_muxes_with_name_type = [ + mux_models.MuxRule.from_db_models( + mux_rule, await pcrud.get_endpoint_by_id(mux_rule.provider_endpoint_id) + ) + for mux_rule in created_mux_rules + ] except crud.WorkspaceNameAlreadyInUseError: raise HTTPException(status_code=409, detail="Workspace name already in use") except ValidationError: @@ -317,7 +324,7 @@ async def create_workspace( name=workspace_row.name, config=v1_models.WorkspaceConfig( custom_instructions=workspace_row.custom_instructions or "", - muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules], + muxing_rules=created_muxes_with_name_type, ), ) @@ -339,12 +346,20 @@ async def update_workspace( if request.config and request.config.muxing_rules: mux_rules = await pcrud.add_provider_ids_to_mux_rule_list(request.config.muxing_rules) - workspace_row, mux_rules = await wscrud.update_workspace( + workspace_row, updated_muxes = await wscrud.update_workspace( workspace_name, request.name, custom_instructions, mux_rules, ) + + updated_muxes_with_name_type = [ + mux_models.MuxRule.from_db_models( + mux_rule, await pcrud.get_endpoint_by_id(mux_rule.provider_endpoint_id) + ) + for mux_rule in updated_muxes + ] + except provendcrud.ProviderNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) except crud.WorkspaceDoesNotExistError: @@ -369,7 +384,7 @@ async def update_workspace( name=workspace_row.name, config=v1_models.WorkspaceConfig( custom_instructions=workspace_row.custom_instructions or "", - muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules], + muxing_rules=updated_muxes_with_name_type, ), ) @@ -707,7 +722,13 @@ async def get_workspace_muxes( The list is ordered in order of priority. That is, the first rule in the list has the highest priority.""" try: - muxes = await wscrud.get_muxes(workspace_name) + db_muxes = await wscrud.get_muxes(workspace_name) + + muxes = [] + for db_mux in db_muxes: + db_endpoint = await pcrud.get_endpoint_by_id(db_mux.provider_endpoint_id) + mux_rule = mux_models.MuxRule.from_db_models(db_mux, db_endpoint) + muxes.append(mux_rule) except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except Exception: @@ -737,8 +758,8 @@ async def set_workspace_muxes( raise HTTPException(status_code=404, detail="Workspace does not exist") except crud.WorkspaceCrudError as e: raise HTTPException(status_code=400, detail=str(e)) - except Exception: - logger.exception("Error while setting muxes") + except Exception as e: + logger.exception(f"Error while setting muxes {e}") raise HTTPException(status_code=500, detail="Internal server error") return Response(status_code=204) @@ -755,10 +776,13 @@ async def get_workspace_by_name( """List workspaces by provider ID.""" try: ws = await wscrud.get_workspace_by_name(workspace_name) - muxes = [ - mux_models.MuxRule.from_mux_rule_with_provider_id(mux) - for mux in await wscrud.get_muxes(workspace_name) - ] + db_muxes = await wscrud.get_muxes(workspace_name) + + muxes = [] + for db_mux in db_muxes: + db_endpoint = await pcrud.get_endpoint_by_id(db_mux.provider_endpoint_id) + mux_rule = mux_models.MuxRule.from_db_models(db_mux, db_endpoint) + muxes.append(mux_rule) return v1_models.FullWorkspace( name=ws.name, @@ -771,6 +795,7 @@ async def get_workspace_by_name( except crud.WorkspaceDoesNotExistError: raise HTTPException(status_code=404, detail="Workspace does not exist") except Exception as e: + logger.exception(f"Error while getting workspace {e}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index a55dc3313..f99a5c8c5 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -479,16 +479,6 @@ async def update_provider_endpoint(self, provider: ProviderEndpoint) -> Provider provider, sql, should_raise=True ) - # Update dependent tables - update_muxes_sql = text( - """ - UPDATE muxes - SET provider_endpoint_name = :name, provider_endpoint_type = :provider_type - WHERE provider_endpoint_id = :id - """ - ) - await self._execute_with_no_return(update_muxes_sql, provider.model_dump()) - return updated_provider async def delete_provider_endpoint( @@ -579,14 +569,14 @@ async def add_mux(self, mux: MuxRule) -> MuxRule: sql = text( """ INSERT INTO muxes ( - id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type, - matcher_blob, priority, created_at, updated_at, - provider_endpoint_type, provider_endpoint_name + id, provider_endpoint_id, provider_model_name, + workspace_id, matcher_type, matcher_blob, + priority, created_at, updated_at ) VALUES ( - :id, :provider_endpoint_id, :provider_model_name, :workspace_id, - :matcher_type, :matcher_blob, :priority, CURRENT_TIMESTAMP, - CURRENT_TIMESTAMP, :provider_endpoint_type, :provider_endpoint_name + :id, :provider_endpoint_id, :provider_model_name, + :workspace_id, :matcher_type, :matcher_blob, + :priority, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP ) RETURNING * """ @@ -1139,7 +1129,6 @@ async def try_get_provider_endpoint_by_name_and_type( SELECT id, name, description, provider_type, endpoint, auth_type, created_at, updated_at FROM provider_endpoints WHERE name = :name AND provider_type = :provider_type - LIMIT 1 """ ) conditions = {"name": provider_name, "provider_type": provider_type} @@ -1259,9 +1248,10 @@ async def get_all_provider_models(self) -> List[ProviderModel]: async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]: sql = text( """ - SELECT id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type, - matcher_blob, priority, created_at, updated_at, - provider_endpoint_type, provider_endpoint_name + SELECT + id, provider_endpoint_id, provider_model_name, + workspace_id, matcher_type, matcher_blob, + priority, created_at, updated_at FROM muxes WHERE workspace_id = :workspace_id ORDER BY priority ASC diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 75577cc9c..5b3b95e2f 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -268,8 +268,6 @@ class ProviderModel(BaseModel): class MuxRule(BaseModel): id: str provider_endpoint_id: str - provider_endpoint_type: ProviderType - provider_endpoint_name: str provider_model_name: str workspace_id: str matcher_type: str diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py index a2aefe943..5e74db2e2 100644 --- a/src/codegate/muxing/models.py +++ b/src/codegate/muxing/models.py @@ -5,6 +5,7 @@ from codegate.clients.clients import ClientType from codegate.db.models import MuxRule as DBMuxRule +from codegate.db.models import ProviderEndpoint as DBProviderEndpoint from codegate.db.models import ProviderType @@ -50,13 +51,15 @@ class MuxRule(pydantic.BaseModel): matcher: Optional[str] = None @classmethod - def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self: + def from_db_models( + cls, db_mux_rule: DBMuxRule, db_provider_endpoint: DBProviderEndpoint + ) -> Self: """ - Convert a DBMuxRule to a MuxRule. + Convert a DBMuxRule and DBProviderEndpoint to a MuxRule. """ return cls( - provider_name=db_mux_rule.provider_endpoint_name, - provider_type=db_mux_rule.provider_endpoint_type, + provider_name=db_provider_endpoint.name, + provider_type=db_provider_endpoint.provider_type, model=db_mux_rule.provider_model_name, matcher_type=MuxMatcherType(db_mux_rule.matcher_type), matcher=db_mux_rule.matcher_blob, @@ -85,12 +88,14 @@ class MuxRuleWithProviderId(MuxRule): provider_id: str @classmethod - def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self: + def from_db_models( + cls, db_mux_rule: DBMuxRule, db_provider_endpoint: DBProviderEndpoint + ) -> Self: """ - Convert a DBMuxRule to a MuxRuleWithProviderId. + Convert a DBMuxRule and DBProviderEndpoint to a MuxRuleWithProviderId. """ return cls( - **MuxRule.from_db_mux_rule(db_mux_rule).model_dump(), + **MuxRule.from_db_models(db_mux_rule, db_provider_endpoint).model_dump(), provider_id=db_mux_rule.provider_endpoint_id, ) diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index d41eb2ce0..7f154df7a 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -74,7 +74,11 @@ class MuxingMatcherFactory: """Factory for creating muxing matchers.""" @staticmethod - def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher: + def create( + db_mux_rule: db_models.MuxRule, + db_provider_endpoint: db_models.ProviderEndpoint, + route: ModelRoute, + ) -> MuxingRuleMatcher: """Create a muxing matcher for the given endpoint and model.""" factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = { @@ -86,7 +90,7 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch try: # Initialize the MuxingRuleMatcher - mux_rule = mux_models.MuxRule.from_db_mux_rule(db_mux_rule) + mux_rule = mux_models.MuxRule.from_db_models(db_mux_rule, db_provider_endpoint) return factory[mux_rule.matcher_type](route, mux_rule) except KeyError: raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}") @@ -193,7 +197,8 @@ async def set_ws_rules(self, workspace_name: str, rules: List[MuxingRuleMatcher] async def delete_ws_rules(self, workspace_name: str) -> None: """Delete the rules for the given workspace.""" async with self._lock: - del self._ws_rules[workspace_name] + if workspace_name in self._ws_rules: + del self._ws_rules[workspace_name] async def set_active_workspace(self, workspace_name: str) -> None: """Set the active workspace.""" diff --git a/src/codegate/workspaces/crud.py b/src/codegate/workspaces/crud.py index 4d0fb2a96..1dba3a871 100644 --- a/src/codegate/workspaces/crud.py +++ b/src/codegate/workspaces/crud.py @@ -262,15 +262,8 @@ async def soft_delete_workspace(self, workspace_name: str): raise WorkspaceCrudError(f"Error deleting workspace {workspace_name}") # Remove the muxes from the registry - try: - mux_registry = await rulematcher.get_muxing_rules_registry() - rules = await mux_registry.get_ws_rules(workspace_name) - if rules: - await mux_registry.delete_ws_rules(workspace_name) - except Exception: - raise DeleteMuxesFromRegistryError( - f"Error deleting mux rules for workspace {workspace_name}" - ) + mux_registry = await rulematcher.get_muxing_rules_registry() + await mux_registry.delete_ws_rules(workspace_name) return async def hard_delete_workspace(self, workspace_name: str): @@ -305,7 +298,7 @@ async def workspaces_by_provider( return workspaces - async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRuleWithProviderId]: + async def get_muxes(self, workspace_name: str) -> List[db_models.MuxRule]: # Verify if workspace exists workspace = await self._db_reader.get_workspace_by_name(workspace_name) if not workspace: @@ -313,21 +306,7 @@ async def get_muxes(self, workspace_name: str) -> List[mux_models.MuxRuleWithPro dbmuxes = await self._db_reader.get_muxes_by_workspace(workspace.id) - muxes = [] - # These are already sorted by priority - for dbmux in dbmuxes: - muxes.append( - mux_models.MuxRuleWithProviderId( - provider_name=dbmux.provider_endpoint_name, - provider_type=dbmux.provider_endpoint_type, - provider_id=dbmux.provider_endpoint_id, - model=dbmux.provider_model_name, - matcher_type=dbmux.matcher_type, - matcher=dbmux.matcher_blob, - ) - ) - - return muxes + return dbmuxes async def set_muxes( self, workspace_name: str, muxes: List[mux_models.MuxRuleWithProviderId] @@ -359,8 +338,6 @@ async def set_muxes( new_mux = db_models.MuxRule( id=str(uuid()), provider_endpoint_id=mux.provider_id, - provider_endpoint_type=mux.provider_type, - provider_endpoint_name=mux.provider_name, provider_model_name=mux.model, workspace_id=workspace.id, matcher_type=mux.matcher_type, @@ -370,7 +347,8 @@ async def set_muxes( dbmux = await self._db_recorder.add_mux(new_mux) dbmuxes.append(dbmux) - matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, route)) + provider = await self._db_reader.get_provider_endpoint_by_id(mux.provider_id) + matchers.append(rulematcher.MuxingMatcherFactory.create(dbmux, provider, route)) priority += 1 @@ -475,7 +453,10 @@ async def repopulate_mux_cache(self) -> None: matchers: List[rulematcher.MuxingRuleMatcher] = [] for mux in muxes: + provider = await self._db_reader.get_provider_endpoint_by_id( + mux.provider_endpoint_id + ) route = await self.get_routing_for_db_mux(mux) - matchers.append(rulematcher.MuxingMatcherFactory.create(mux, route)) + matchers.append(rulematcher.MuxingMatcherFactory.create(mux, provider, route)) await mux_registry.set_ws_rules(ws.name, matchers) diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py index 6feec7cb4..2edd1f975 100644 --- a/tests/muxing/test_rulematcher.py +++ b/tests/muxing/test_rulematcher.py @@ -174,13 +174,24 @@ def test_muxing_matcher_factory(matcher_type, expected_class): matcher_type=matcher_type, matcher_blob="fake-matcher", priority=1, - provider_endpoint_name="fake-openai", - provider_endpoint_type=db_models.ProviderType.openai, + ) + provider_endpoint = db_models.ProviderEndpoint( + id="1", + auth_type="none", + description="", + endpoint="http://localhost:11434", + name="fake-openai", + provider_type="openai", ) if expected_class: assert isinstance( - rulematcher.MuxingMatcherFactory.create(mux_rule, mocked_route_openai), expected_class + rulematcher.MuxingMatcherFactory.create( + mux_rule, provider_endpoint, mocked_route_openai + ), + expected_class, ) else: with pytest.raises(ValueError): - rulematcher.MuxingMatcherFactory.create(mux_rule, mocked_route_openai) + rulematcher.MuxingMatcherFactory.create( + mux_rule, provider_endpoint, mocked_route_openai + ) From f215dd29db6cb6abdbaa8e3e19916cd689b03e17 Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Thu, 13 Mar 2025 15:24:33 +0000 Subject: [PATCH 22/23] address unnecessary manual deletions feedback --- src/codegate/db/connection.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index f99a5c8c5..4a19a1998 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -416,16 +416,6 @@ async def soft_delete_workspace(self, workspace: WorkspaceRow) -> Optional[Works return deleted_workspace async def hard_delete_workspace(self, workspace: WorkspaceRow) -> Optional[WorkspaceRow]: - # First delete associated muxes - sql_delete_muxes = text( - """ - DELETE FROM muxes - WHERE workspace_id = :id - """ - ) - await self._execute_with_no_return(sql_delete_muxes, {"id": workspace.id}) - - # Then delete the workspace sql = text( """ DELETE FROM workspaces @@ -485,25 +475,6 @@ async def delete_provider_endpoint( self, provider: ProviderEndpoint, ) -> Optional[ProviderEndpoint]: - # Delete from provider_models - sql_delete_provider_models = text( - """ - DELETE FROM provider_models - WHERE provider_endpoint_id = :id - """ - ) - await self._execute_with_no_return(sql_delete_provider_models, {"id": provider.id}) - - # Delete from muxes - sql_delete_muxes = text( - """ - DELETE FROM muxes - WHERE provider_endpoint_id = :id - """ - ) - await self._execute_with_no_return(sql_delete_muxes, {"id": provider.id}) - - # Delete from provider_endpoints sql_delete_provider_endpoints = text( """ DELETE FROM provider_endpoints From 5e716fc5d80b0f56dafc03cd233751597edda7ff Mon Sep 17 00:00:00 2001 From: alex-mcgovern Date: Fri, 14 Mar 2025 09:02:38 +0000 Subject: [PATCH 23/23] tidy ups --- src/codegate/db/connection.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 35d6806d6..973a4a1b3 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -475,7 +475,7 @@ async def delete_provider_endpoint( self, provider: ProviderEndpoint, ) -> Optional[ProviderEndpoint]: - sql_delete_provider_endpoints = text( + sql = text( """ DELETE FROM provider_endpoints WHERE id = :id @@ -483,7 +483,7 @@ async def delete_provider_endpoint( """ ) deleted_provider = await self._execute_update_pydantic_model( - provider, sql_delete_provider_endpoints, should_raise=True + provider, sql, should_raise=True ) return deleted_provider @@ -540,14 +540,12 @@ async def add_mux(self, mux: MuxRule) -> MuxRule: sql = text( """ INSERT INTO muxes ( - id, provider_endpoint_id, provider_model_name, - workspace_id, matcher_type, matcher_blob, - priority, created_at, updated_at + id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type, + matcher_blob, priority, created_at, updated_at ) VALUES ( - :id, :provider_endpoint_id, :provider_model_name, - :workspace_id, :matcher_type, :matcher_blob, - :priority, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP + :id, :provider_endpoint_id, :provider_model_name, :workspace_id, + :matcher_type, :matcher_blob, :priority, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP ) RETURNING * """ @@ -1221,10 +1219,8 @@ async def get_all_provider_models(self) -> List[ProviderModel]: async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]: sql = text( """ - SELECT - id, provider_endpoint_id, provider_model_name, - workspace_id, matcher_type, matcher_blob, - priority, created_at, updated_at + SELECT id, provider_endpoint_id, provider_model_name, workspace_id, matcher_type, + matcher_blob, priority, created_at, updated_at FROM muxes WHERE workspace_id = :workspace_id ORDER BY priority ASC