Skip to content

Commit

Permalink
Merge pull request #802 from phospho-app/dev
Browse files Browse the repository at this point in the history
feat: Separate human evals, adds the phospho CLI
  • Loading branch information
fred3105 committed Aug 2, 2024
2 parents f67c706 + fbe6a1a commit d0b799d
Show file tree
Hide file tree
Showing 65 changed files with 2,376 additions and 1,437 deletions.
6 changes: 3 additions & 3 deletions backend/app/api/platform/endpoints/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def post_confirm_event(
Confirm an event that was detected.
"""

org_id = await verify_if_propelauth_user_can_access_project(user, project_id)
await verify_if_propelauth_user_can_access_project(user, project_id)
event = await confirm_event(project_id=project_id, event_id=event_id)
return event

Expand All @@ -120,7 +120,7 @@ async def post_change_label_event(
Change the label of an event.
"""
logger.debug(f"Changing label of event {event_id} to {request.new_label}")
org_id = await verify_if_propelauth_user_can_access_project(user, project_id)
await verify_if_propelauth_user_can_access_project(user, project_id)
event = await change_label_event(
project_id=project_id, event_id=event_id, new_label=request.new_label
)
Expand All @@ -142,7 +142,7 @@ async def post_change_value_event(
Change the value of a range event.
"""
logger.debug(f"Changing value of a range event {event_id} to {request.new_value}")
org_id = await verify_if_propelauth_user_can_access_project(user, project_id)
await verify_if_propelauth_user_can_access_project(user, project_id)
event = await change_value_event(
project_id=project_id, event_id=event_id, new_value=request.new_value
)
Expand Down
3 changes: 1 addition & 2 deletions backend/app/api/platform/endpoints/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from propelauth_py.user import User
import argilla as rg

from app.api.platform.models.integrations import (
DatasetCreationRequest,
Expand Down Expand Up @@ -134,7 +133,7 @@ async def post_pull_dataset(
raise HTTPException(
status_code=400, detail="The dataset name does not exist for this project."
)
argilla_dataset = await pull_dataset_from_argilla(request)
await pull_dataset_from_argilla(request)

return {"status": "ok"}

Expand Down
1 change: 0 additions & 1 deletion backend/app/api/platform/endpoints/metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
from typing import Dict, List
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger

from phospho.models import ProjectDataFilters
from propelauth_fastapi import User
Expand Down
1 change: 0 additions & 1 deletion backend/app/api/platform/endpoints/onboarding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from fastapi import APIRouter, Depends
from loguru import logger
from propelauth_fastapi import User

from app.api.platform.models import (
Expand Down
9 changes: 1 addition & 8 deletions backend/app/api/platform/endpoints/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,20 +252,13 @@ async def post_search_sessions(
@router.post(
"/projects/{project_id}/tasks",
response_model=Tasks,
description="Get all the tasks of a project",
description="Fetch all the tasks of a project",
)
async def post_tasks(
project_id: str,
query: Optional[QuerySessionsTasksRequest] = None,
user: User = Depends(propelauth.require_user),
):
"""
Get all the tasks of a project.
Args:
project_id: The id of the project
limit: The maximum number of tasks to return
"""
project = await get_project_by_id(project_id)
propelauth.require_org_member(user, project.org_id)

Expand Down
32 changes: 31 additions & 1 deletion backend/app/api/platform/endpoints/sessions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from propelauth_fastapi import User

from app.api.platform.models import (
Session,
SessionUpdateRequest,
Tasks,
SessionHumanEvalRequest,
)
from app.security import verify_if_propelauth_user_can_access_project
from app.security.authentification import propelauth
Expand All @@ -14,6 +15,7 @@
format_session_transcript,
get_session_by_id,
event_suggestion,
human_eval_session,
)
from app.api.platform.models import AddEventRequest, RemoveEventRequest
from app.services.mongo.sessions import add_event_to_session, remove_event_from_session
Expand Down Expand Up @@ -140,3 +142,31 @@ async def post_remove_event_from_session(
event_name=remove_event.event_name,
)
return updated_session


@router.post(
"/sessions/{session_id}/human-eval",
response_model=Session,
description="Update the human eval of a session and the flag",
)
async def post_human_eval_session(
session_id: str,
sessionHumanEvalRequest: SessionHumanEvalRequest,
user: User = Depends(propelauth.require_user),
) -> Session:
"""
Update the human eval of a session and the session_flag with "success" or "failure"
Also signs the origin of the flag with owner
"""
if sessionHumanEvalRequest.human_eval not in ["success", "failure"]:
raise HTTPException(
status_code=400,
detail="The human eval must be either 'success' or 'failure'",
)
session = await get_session_by_id(session_id)
await verify_if_propelauth_user_can_access_project(user, session.project_id)
updated_task = await human_eval_session(
session_model=session,
human_eval=sessionHumanEvalRequest.human_eval,
)
return updated_task
31 changes: 31 additions & 0 deletions backend/app/api/platform/endpoints/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
TaskUpdateRequest,
AddEventRequest,
RemoveEventRequest,
TaskHumanEvalRequest,
)
from app.security import verify_if_propelauth_user_can_access_project
from app.security.authentification import propelauth
Expand All @@ -16,7 +17,9 @@
update_task,
add_event_to_task,
remove_event_from_task,
human_eval_task,
)
from loguru import logger

router = APIRouter(tags=["Tasks"])

Expand Down Expand Up @@ -61,6 +64,34 @@ async def post_flag_task(
return updated_task


@router.post(
"/tasks/{task_id}/human-eval",
response_model=Task,
description="Update the human eval of a task and the flag",
)
async def post_human_eval_task(
task_id: str,
taskHumanEvalRequest: TaskHumanEvalRequest,
user: User = Depends(propelauth.require_user),
) -> Task:
"""
Update the human eval of a task and the flag with "success" or "failure"
Also signs the origin of the flag with owner
"""
if taskHumanEvalRequest.human_eval not in ["success", "failure"]:
raise HTTPException(
status_code=400,
detail="The human eval must be either 'success' or 'failure'",
)
task = await get_task_by_id(task_id)
await verify_if_propelauth_user_can_access_project(user, task.project_id)
updated_task = await human_eval_task(
task_model=task,
human_eval=taskHumanEvalRequest.human_eval,
)
return updated_task


@router.post(
"/tasks/{task_id}",
response_model=Task,
Expand Down
2 changes: 2 additions & 0 deletions backend/app/api/platform/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Tests,
UserMetadata,
Users,
TaskHumanEvalRequest,
)

from .abtests import ABTest, ABTests
Expand All @@ -46,3 +47,4 @@
)
from .recipes import RunRecipeRequest
from .tasks import AddEventRequest, RemoveEventRequest
from .sessions import SessionHumanEvalRequest
2 changes: 1 addition & 1 deletion backend/app/api/platform/models/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
from typing import Optional, Union
from typing import Optional


class EventBackfillRequest(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion backend/app/api/platform/models/integrations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from phospho.models import ProjectDataFilters
from typing import Literal, Optional, List
from typing import Literal, Optional


class DatasetSamplingParameters(BaseModel):
Expand Down
7 changes: 7 additions & 0 deletions backend/app/api/platform/models/sessions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import BaseModel
from app.db.models import EventDefinition
from typing import Literal, Optional


class AddEventRequest(BaseModel):
Expand All @@ -8,3 +9,9 @@ class AddEventRequest(BaseModel):

class RemoveEventRequest(BaseModel):
event_name: str


class SessionHumanEvalRequest(BaseModel):
human_eval: Optional[Literal["success", "failure"]] = None
project_id: Optional[str] = None
source: Optional[str] = "owner"
68 changes: 33 additions & 35 deletions backend/app/api/v2/endpoints/projects.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,34 @@
from typing import Optional
from loguru import logger

from fastapi import APIRouter, Depends, BackgroundTasks
from fastapi import APIRouter, BackgroundTasks, Depends
from loguru import logger

from app.api.v2.models import (
ComputeJobsRequest,
Sessions,
Tasks,
FlattenedTasks,
FlattenedTasksRequest,
QuerySessionsTasksRequest,
Sessions,
Tasks,
)

from app.api.platform.models.explore import ProjectDataFilters

from app.security import authenticate_org_key, verify_propelauth_org_owns_project_id
from app.services.mongo.projects import (
get_all_sessions,
backcompute_recipes,
)
from app.services.mongo.tasks import get_all_tasks

from app.services.mongo.explore import (
fetch_flattened_tasks,
update_from_flattened_tasks,
)
from app.services.mongo.projects import (
backcompute_recipes,
get_all_sessions,
)
from app.services.mongo.tasks import get_all_tasks

router = APIRouter(tags=["Projects"])


@router.get(
@router.post(
"/projects/{project_id}/sessions",
response_model=Sessions,
description="Get all the sessions of a project",
description="Fetch all the sessions of a project",
)
async def get_sessions(
project_id: str,
Expand All @@ -43,32 +40,37 @@ async def get_sessions(
return Sessions(sessions=sessions)


@router.get(
@router.post(
"/projects/{project_id}/tasks",
response_model=Tasks,
description="Get all the tasks of a project",
description="Fetch all the tasks of a project",
)
async def get_tasks(
async def post_tasks(
project_id: str,
limit: int = 1000,
filters: Optional[ProjectDataFilters] = None,
query: Optional[QuerySessionsTasksRequest] = None,
org: dict = Depends(authenticate_org_key),
) -> Tasks:
"""
Get all the tasks of a project. If filters is specified, the tasks will be filtered according to the filter.
Fetch all the tasks of a project.
Args:
project_id: The id of the project
limit: The maximum number of tasks to return
filters: This model is used to filter tasks in the get_tasks endpoint. The filters are applied as AND filters.
The filters are combined as AND conditions on the different fields.
"""
await verify_propelauth_org_owns_project_id(org, project_id)
if filters is None:
filters = ProjectDataFilters()
if isinstance(filters.event_name, str):
filters.event_name = [filters.event_name]

tasks = await get_all_tasks(project_id=project_id, limit=limit, filters=filters)
if query is None:
query = QuerySessionsTasksRequest()
if query.filters.user_id is not None:
if query.filters.metadata is None:
query.filters.metadata = {}
query.filters.metadata["user_id"] = query.filters.user_id

tasks = await get_all_tasks(
project_id=project_id,
limit=None,
validate_metadata=True,
filters=query.filters,
sorting=query.sorting,
pagination=query.pagination,
)
return Tasks(tasks=tasks)


Expand All @@ -84,10 +86,6 @@ async def get_flattened_tasks(
) -> FlattenedTasks:
"""
Get all the tasks of a project in a flattened format.
Args:
project_id: The id of the project
limit: The maximum number of tasks to return
"""
await verify_propelauth_org_owns_project_id(org, project_id)

Expand Down
Loading

0 comments on commit d0b799d

Please sign in to comment.