diff --git a/weave/tests/test_client_feedback.py b/weave/tests/test_client_feedback.py index 3436beca24..a79692e499 100644 --- a/weave/tests/test_client_feedback.py +++ b/weave/tests/test_client_feedback.py @@ -10,7 +10,7 @@ def test_feedback_apis(client): project_id = client._project_id() # Emoji from Jamie - req = tsi.FeedbackCreateReqForInsert( + req = tsi.FeedbackCreateReq( project_id=project_id, wb_user_id="VXNlcjo0NTI1NDQ=", weave_ref="weave:///entity/project/object/name:digest", @@ -23,7 +23,7 @@ def test_feedback_apis(client): id_emoji_1 = res.id # Another emoji from Jamie - req = tsi.FeedbackCreateReqForInsert( + req = tsi.FeedbackCreateReq( project_id=project_id, wb_user_id="VXNlcjo0NTI1NDQ=", weave_ref="weave:///entity/project/object/name:digest", @@ -36,7 +36,7 @@ def test_feedback_apis(client): id_emoji_2 = res.id # Emoji from Shawn - req = tsi.FeedbackCreateReqForInsert( + req = tsi.FeedbackCreateReq( project_id=project_id, wb_user_id="VXNlcjoxOQ==", weave_ref="weave:///entity/project/object/name:digest", @@ -49,7 +49,7 @@ def test_feedback_apis(client): id_emoji_3 = res.id # Note from Jamie - req = tsi.FeedbackCreateReqForInsert( + req = tsi.FeedbackCreateReq( project_id=project_id, wb_user_id="VXNlcjo0NTI1NDQ=", weave_ref="weave:///entity/project/object/name:digest", @@ -61,7 +61,7 @@ def test_feedback_apis(client): id_note = res.id # Custom from Jamie - req = tsi.FeedbackCreateReqForInsert( + req = tsi.FeedbackCreateReq( project_id=project_id, wb_user_id="VXNlcjo0NTI1NDQ=", weave_ref="weave:///entity/project/object/name:digest", @@ -73,7 +73,7 @@ def test_feedback_apis(client): id_custom_1 = res.id # Custom on another object - req = tsi.FeedbackCreateReqForInsert( + req = tsi.FeedbackCreateReq( project_id=project_id, wb_user_id="VXNlcjo0NTI1NDQ=", weave_ref="weave:///entity/project/object/name2:digest", @@ -205,7 +205,7 @@ def test_feedback_create_too_large(client): project_id = client._project_id() value = "a" * 10000 - req = tsi.FeedbackCreateReqForInsert( + req = tsi.FeedbackCreateReq( project_id=project_id, wb_user_id="VXNlcjo0NTI1NDQ=", weave_ref="weave:///entity/project/object/name:digest", diff --git a/weave/tests/test_weave_client.py b/weave/tests/test_weave_client.py index fb151a6a9f..ab5abe6789 100644 --- a/weave/tests/test_weave_client.py +++ b/weave/tests/test_weave_client.py @@ -250,7 +250,7 @@ def test_calls_delete(client): original_calls_delete = client.server.calls_delete def patched_delete(req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: - post_auth_req = tsi.CallsDeleteReqForInsert( + post_auth_req = tsi.CallsDeleteReq( project_id=req.project_id, call_ids=req.call_ids, wb_user_id="test-user-id", @@ -294,7 +294,7 @@ def test_calls_delete_cascade(client): original_calls_delete = client.server.calls_delete def patched_delete(req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: - post_auth_req = tsi.CallsDeleteReqForInsert( + post_auth_req = tsi.CallsDeleteReq( project_id=req.project_id, call_ids=req.call_ids, wb_user_id="test-user-id", diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index c394750232..f011bd9999 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -53,6 +53,7 @@ from .orm import Table, Column, ParamBuilder, Row from .trace_server_interface_util import ( + assert_non_null_wb_user_id, extract_refs_from_values, generate_id, str_digest, @@ -223,7 +224,7 @@ class SelectableCHObjSchema(BaseModel): required_obj_select_columns = list(set(all_obj_select_columns) - set([])) -class ClickHouseTraceServer(tsi.TraceServerInterfacePostAuth): +class ClickHouseTraceServer(tsi.TraceServerInterface): def __init__( self, *, @@ -396,7 +397,8 @@ def calls_query_stream( _ch_call_dict_to_call_schema_dict(ch_dict) ) - def calls_delete(self, req: tsi.CallsDeleteReqForInsert) -> tsi.CallsDeleteRes: + def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: + assert_non_null_wb_user_id(req) if len(req.call_ids) > MAX_DELETE_CALLS_COUNT: raise RequestTooLarge( f"Cannot delete more than {MAX_DELETE_CALLS_COUNT} calls at once" @@ -448,7 +450,7 @@ def calls_delete(self, req: tsi.CallsDeleteReqForInsert) -> tsi.CallsDeleteRes: return tsi.CallsDeleteRes() - def _ensure_valid_update_field(self, req: tsi.CallUpdateReqForInsert) -> None: + def _ensure_valid_update_field(self, req: tsi.CallUpdateReq) -> None: valid_update_fields = ["display_name"] for field in valid_update_fields: if getattr(req, field, None) is not None: @@ -458,7 +460,8 @@ def _ensure_valid_update_field(self, req: tsi.CallUpdateReqForInsert) -> None: f"One of [{', '.join(valid_update_fields)}] is required for call update" ) - def call_update(self, req: tsi.CallUpdateReqForInsert) -> tsi.CallUpdateRes: + def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: + assert_non_null_wb_user_id(req) self._ensure_valid_update_field(req) renamed_insertable = CallUpdateCHInsertable( project_id=req.project_id, @@ -961,9 +964,8 @@ def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadR raise ValueError("Missing chunks") return tsi.FileContentReadRes(content=b"".join(chunks)) - def feedback_create( - self, req: tsi.FeedbackCreateReqForInsert - ) -> tsi.FeedbackCreateRes: + def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: + assert_non_null_wb_user_id(req) validate_feedback_create_req(req) # Augment emoji with alias. diff --git a/weave/trace_server/feedback.py b/weave/trace_server/feedback.py index 07e3a16d37..aad7421418 100644 --- a/weave/trace_server/feedback.py +++ b/weave/trace_server/feedback.py @@ -28,7 +28,7 @@ } -def validate_feedback_create_req(req: tsi.FeedbackCreateReqForInsert) -> None: +def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None: payload_schema = FEEDBACK_PAYLOAD_SCHEMAS.get(req.feedback_type) if payload_schema: try: diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index cb9b1923bb..e2f5b00a97 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -14,6 +14,7 @@ import emoji from .trace_server_interface_util import ( + assert_non_null_wb_user_id, extract_refs_from_values, str_digest, bytes_digest, @@ -65,7 +66,7 @@ def get_conn_cursor(db_path: str) -> tuple[sqlite3.Connection, sqlite3.Cursor]: return conn_cursor -class SqliteTraceServer(tsi.TraceServerInterfacePostAuth): +class SqliteTraceServer(tsi.TraceServerInterface): def __init__(self, db_path: str): self.lock = threading.Lock() self.db_path = db_path @@ -462,6 +463,7 @@ def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsR ) def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: + assert_non_null_wb_user_id(req) # update row with a deleted_at field set to now conn, cursor = get_conn_cursor(self.db_path) with self.lock: @@ -504,7 +506,8 @@ def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: return tsi.CallsDeleteRes() - def call_update(self, req: tsi.CallUpdateReqForInsert) -> tsi.CallUpdateRes: + def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: + assert_non_null_wb_user_id(req) if req.display_name is None: raise ValueError("One of [display_name] is required for call update") @@ -707,9 +710,8 @@ def read_ref(r: refs.ObjectRef) -> Any: return tsi.RefsReadBatchRes(vals=[read_ref(r) for r in parsed_obj_refs]) - def feedback_create( - self, req: tsi.FeedbackCreateReqForInsert - ) -> tsi.FeedbackCreateRes: + def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: + assert_non_null_wb_user_id(req) validate_feedback_create_req(req) # Augment emoji with alias. diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 24cc40e2f5..1c9f8d92c3 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -5,6 +5,10 @@ from .interface.query import Query +WB_USER_ID_DESCRIPTION = ( + "Do not set directly. Server will automatically populate this field." +) + class CallSchema(BaseModel): id: str @@ -73,7 +77,7 @@ class StartedCallSchemaForInsert(BaseModel): inputs: typing.Dict[str, typing.Any] # WB Metadata - wb_user_id: typing.Optional[str] = None + wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) wb_run_id: typing.Optional[str] = None @@ -148,9 +152,8 @@ class CallsDeleteReq(BaseModel): project_id: str call_ids: typing.List[str] - -class CallsDeleteReqForInsert(CallsDeleteReq): - wb_user_id: str + # wb_user_id is automatically populated by the server + wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) class CallsDeleteRes(BaseModel): @@ -210,9 +213,8 @@ class CallUpdateReq(BaseModel): # optional update fields display_name: typing.Optional[str] = None - -class CallUpdateReqForInsert(CallUpdateReq): - wb_user_id: str + # wb_user_id is automatically populated by the server + wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) class CallUpdateRes(BaseModel): @@ -343,9 +345,8 @@ class FeedbackCreateReq(BaseModel): ] ) - -class FeedbackCreateReqForInsert(FeedbackCreateReq): - wb_user_id: str + # wb_user_id is automatically populated by the server + wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION) # The response provides the additional fields needed to convert a request @@ -357,7 +358,7 @@ class FeedbackCreateRes(BaseModel): payload: typing.Dict[str, typing.Any] # If not empty, replace payload -class Feedback(FeedbackCreateReqForInsert): +class Feedback(FeedbackCreateReq): id: str created_at: datetime.datetime @@ -501,15 +502,11 @@ def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes: raise NotImplementedError() -class TraceServerInterfacePostAuth(TraceServerInterface): - @abc.abstractmethod - def call_update(self, req: CallUpdateReqForInsert) -> CallUpdateRes: - raise NotImplementedError() - - @abc.abstractmethod - def calls_delete(self, req: CallsDeleteReqForInsert) -> CallsDeleteRes: - raise NotImplementedError() - - @abc.abstractmethod - def feedback_create(self, req: FeedbackCreateReqForInsert) -> FeedbackCreateRes: - raise NotImplementedError() +# These symbols are used in the WB Trace Server and it is not safe +# to remove them, else it will break the server. Once the server +# is updated to use the new symbols, these can be removed. +# +# Remove once https://github.com/wandb/core/pull/22040 lands +CallsDeleteReqForInsert = CallsDeleteReq +CallUpdateReqForInsert = CallUpdateReq +FeedbackCreateReqForInsert = FeedbackCreateReq diff --git a/weave/trace_server/trace_server_interface_util.py b/weave/trace_server/trace_server_interface_util.py index e358b6b793..54f84c2935 100644 --- a/weave/trace_server/trace_server_interface_util.py +++ b/weave/trace_server/trace_server_interface_util.py @@ -81,3 +81,13 @@ def _visit(val: typing.Any) -> typing.Any: _visit(vals) return refs + + +def assert_non_null_wb_user_id(obj: typing.Any) -> None: + if not hasattr(obj, "wb_user_id") or obj.wb_user_id is None: + raise ValueError("wb_user_id cannot be None") + + +def assert_null_wb_user_id(obj: typing.Any) -> None: + if not hasattr(obj, "wb_user_id") or obj.wb_user_id is not None: + raise ValueError("wb_user_id must be be None")