From 21568655b06ed546cb775f02276ad303a65c8f92 Mon Sep 17 00:00:00 2001 From: yasinfakhar Date: Mon, 10 Mar 2025 17:46:23 +0330 Subject: [PATCH] fix[api]: access check added --- .../routers/playground_conversation_router.py | 11 +++++++-- .../services/conversation_service.py | 24 +++++++++++++++++++ api/src/project/services/project_service.py | 20 ++++++++++++++-- .../workspace/services/workspace_service.py | 6 +++++ 4 files changed, 57 insertions(+), 4 deletions(-) diff --git a/api/src/conversation/routers/playground_conversation_router.py b/api/src/conversation/routers/playground_conversation_router.py index a730315..fc248b3 100644 --- a/api/src/conversation/routers/playground_conversation_router.py +++ b/api/src/conversation/routers/playground_conversation_router.py @@ -1,6 +1,6 @@ from uuid import UUID from enum import Enum -from typing import List +from typing import List, Dict from pydantic import BaseModel from sqlalchemy.orm import Session from fastapi.responses import StreamingResponse @@ -59,6 +59,7 @@ class ConversationOutput(BaseModel): id: UUID name: str project_id: UUID + chat_history: List[Dict] class Config: from_attributes = True @@ -255,7 +256,7 @@ def delete_conversation( "/{project_id}/{conversation_id}/chat", summary="Send Message", description="Sends a message in the conversation chat for the authenticated user.", - response_description="The message response data along with intermediate steps metadata." + response_description="The message response data along with chat history in metadata." ) async def send_message( input: MessageInput, @@ -277,9 +278,13 @@ async def send_message( - **service**: Conversation service handling business logic. - **user_id**: The authenticated user's ID. + - **metadata**: The chat history metadata. + Returns: StreamingResponse: If `stream` is true. dict: A JSON response containing the complete message data if `stream` is false. + + metadata: The chat history metadata. """ try: @@ -298,6 +303,7 @@ async def send_message( input.message, user_id, conversation_id, + project_id, input.message_type, input.stream), media_type="text/plain" @@ -309,6 +315,7 @@ async def send_message( input.message, user_id, conversation_id, + project_id, input.message_type, input.stream ): diff --git a/api/src/conversation/services/conversation_service.py b/api/src/conversation/services/conversation_service.py index bfa5ddd..b29c2d0 100644 --- a/api/src/conversation/services/conversation_service.py +++ b/api/src/conversation/services/conversation_service.py @@ -12,6 +12,7 @@ from src.agent.services.memory_service import MemoryService from src.project.services.project_service import ProjectService from src.agent.services.agent_service import RouterAgentService +from src.project.repositories.project_repository import ProjectRepository from src.query_usage.services.query_usage_service import QueryUsageService from src.conversation.repositories.conversation_repository import ConversationRepository @@ -23,6 +24,7 @@ class ConversationService: def __init__(self): self.repository = ConversationRepository() + self.project_repository = ProjectRepository() self.memory_service = MemoryService() self.project_service = ProjectService() self.query_usage_service = QueryUsageService() @@ -43,6 +45,12 @@ def create_conversation( Returns: Conversation: The created conversation instance. """ + if name.strip() == "": + raise Exception("Name is required") + + if not self.project_repository.get_by_id(db_session, project_id, user_id): + raise Exception("Project not exists or not owned by user") + conversation = Conversation( name=name, project_id=project_id, user_id=user_id) return self.repository.create(db_session, conversation) @@ -106,6 +114,12 @@ def update_conversation( Raises: Exception: If the conversation is not found. """ + if data["name"].strip() == "": + raise Exception("Name is required") + + if not self.project_repository.get_by_id(db_session, data["project_id"], user_id): + raise Exception("Project not exists or not owned by user") + return self.repository.update(db_session, conversation_id, data, user_id) def delete_conversation( @@ -143,6 +157,7 @@ async def send_message( message: str, user_id: UUID, conversation_id: UUID, + project_id: UUID, message_type: str, stream: bool ): @@ -154,6 +169,7 @@ async def send_message( message (str): The message to be processed user_id (UUID): ID of the user sending the message conversation_id (UUID): ID of the conversation the message belongs to + project_id (UUID): ID of the project associated with the conversation message_type (str): Type of the message stream (bool): If True, streams response tokens. If False, returns complete response @@ -168,6 +184,14 @@ async def send_message( Raises: Exception: If associated project is not found """ + conversation = self.repository.get_by_id( + db_session, conversation_id, user_id) + + if not conversation: + raise Exception("Conversation not found or is not owned by user") + + if conversation.project_id != project_id: + raise Exception("Project not found or is not owned by user") config = Config.get_config() diff --git a/api/src/project/services/project_service.py b/api/src/project/services/project_service.py index a354ec3..f778510 100644 --- a/api/src/project/services/project_service.py +++ b/api/src/project/services/project_service.py @@ -1,10 +1,10 @@ from uuid import UUID -from typing import Set from typing import List from sqlalchemy.orm import Session from src.db.models import Project from src.project.repositories.project_repository import ProjectRepository +from src.workspace.repositories.workspace_repository import WorkspaceRepository class ProjectService: @@ -14,6 +14,7 @@ class ProjectService: def __init__(self): self.repository = ProjectRepository() + self.workspace_repository = WorkspaceRepository() def create_project( self, db_session: Session, name: str, user_id: UUID, workspace_id: UUID @@ -30,6 +31,14 @@ def create_project( Returns: Project: The created project instance. """ + if name.strip() == "": + raise Exception("Name is required") + + workspace = self.workspace_repository.get_by_id( + db_session, workspace_id, user_id) + if not workspace: + raise Exception("Workspace not found or not owned by user") + project = Project(name=name, user_id=user_id, workspace_id=workspace_id) return self.repository.create(db_session, project) @@ -84,6 +93,13 @@ def update_project( Returns: Project: The updated project instance. """ + if data["name"].strip() == "": + raise Exception("Name is required") + + workspace = self.workspace_repository.get_by_id( + db_session, data["workspace_id"], user_id) + if not workspace: + raise Exception("Workspace not found or not owned by user") return self.repository.update(db_session, project_id, data, user_id) def delete_project( @@ -100,4 +116,4 @@ def delete_project( Returns: None """ - self.repository.delete(db_session, project_id, user_id) \ No newline at end of file + self.repository.delete(db_session, project_id, user_id) diff --git a/api/src/workspace/services/workspace_service.py b/api/src/workspace/services/workspace_service.py index 9ec184c..f6b33b3 100644 --- a/api/src/workspace/services/workspace_service.py +++ b/api/src/workspace/services/workspace_service.py @@ -28,6 +28,9 @@ def create_workspace( Returns: Workspace: The created Workspace instance. """ + if name.strip() == "": + raise Exception("Name is required") + workspace = Workspace(name=name, user_id=user_id) return self.repository.create(db_session, workspace) @@ -97,6 +100,9 @@ def update_workspace( Returns: Workspace: The updated Workspace instance. """ + if data["name"].strip() == "": + raise Exception("Name is required") + return self.repository.update(db_session, workspace_id, data, user_id) def delete_workspace(