Skip to content

Commit

Permalink
chore(weave): Simplify TraceServerInterface definition and fix inheri…
Browse files Browse the repository at this point in the history
…tance (#1763)

* init

* removed other uses

* removed other uses

* init

* added description

* fixed description
  • Loading branch information
tssweeney committed Jun 12, 2024
1 parent 8d8855e commit d4ba75d
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 45 deletions.
14 changes: 7 additions & 7 deletions weave/tests/test_client_feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_feedback_apis(client):
project_id = client._project_id()

# Emoji from Jamie
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -23,7 +23,7 @@ def test_feedback_apis(client):
id_emoji_1 = res.id

# Another emoji from Jamie
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -36,7 +36,7 @@ def test_feedback_apis(client):
id_emoji_2 = res.id

# Emoji from Shawn
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjoxOQ==",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -49,7 +49,7 @@ def test_feedback_apis(client):
id_emoji_3 = res.id

# Note from Jamie
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -61,7 +61,7 @@ def test_feedback_apis(client):
id_note = res.id

# Custom from Jamie
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -73,7 +73,7 @@ def test_feedback_apis(client):
id_custom_1 = res.id

# Custom on another object
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name2:digest",
Expand Down Expand Up @@ -205,7 +205,7 @@ def test_feedback_create_too_large(client):
project_id = client._project_id()

value = "a" * 10000
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand Down
4 changes: 2 additions & 2 deletions weave/tests/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_calls_delete(client):
original_calls_delete = client.server.calls_delete

def patched_delete(req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
post_auth_req = tsi.CallsDeleteReqForInsert(
post_auth_req = tsi.CallsDeleteReq(
project_id=req.project_id,
call_ids=req.call_ids,
wb_user_id="test-user-id",
Expand Down Expand Up @@ -294,7 +294,7 @@ def test_calls_delete_cascade(client):
original_calls_delete = client.server.calls_delete

def patched_delete(req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
post_auth_req = tsi.CallsDeleteReqForInsert(
post_auth_req = tsi.CallsDeleteReq(
project_id=req.project_id,
call_ids=req.call_ids,
wb_user_id="test-user-id",
Expand Down
16 changes: 9 additions & 7 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .orm import Table, Column, ParamBuilder, Row

from .trace_server_interface_util import (
assert_non_null_wb_user_id,
extract_refs_from_values,
generate_id,
str_digest,
Expand Down Expand Up @@ -223,7 +224,7 @@ class SelectableCHObjSchema(BaseModel):
required_obj_select_columns = list(set(all_obj_select_columns) - set([]))


class ClickHouseTraceServer(tsi.TraceServerInterfacePostAuth):
class ClickHouseTraceServer(tsi.TraceServerInterface):
def __init__(
self,
*,
Expand Down Expand Up @@ -396,7 +397,8 @@ def calls_query_stream(
_ch_call_dict_to_call_schema_dict(ch_dict)
)

def calls_delete(self, req: tsi.CallsDeleteReqForInsert) -> tsi.CallsDeleteRes:
def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
assert_non_null_wb_user_id(req)
if len(req.call_ids) > MAX_DELETE_CALLS_COUNT:
raise RequestTooLarge(
f"Cannot delete more than {MAX_DELETE_CALLS_COUNT} calls at once"
Expand Down Expand Up @@ -448,7 +450,7 @@ def calls_delete(self, req: tsi.CallsDeleteReqForInsert) -> tsi.CallsDeleteRes:

return tsi.CallsDeleteRes()

def _ensure_valid_update_field(self, req: tsi.CallUpdateReqForInsert) -> None:
def _ensure_valid_update_field(self, req: tsi.CallUpdateReq) -> None:
valid_update_fields = ["display_name"]
for field in valid_update_fields:
if getattr(req, field, None) is not None:
Expand All @@ -458,7 +460,8 @@ def _ensure_valid_update_field(self, req: tsi.CallUpdateReqForInsert) -> None:
f"One of [{', '.join(valid_update_fields)}] is required for call update"
)

def call_update(self, req: tsi.CallUpdateReqForInsert) -> tsi.CallUpdateRes:
def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
assert_non_null_wb_user_id(req)
self._ensure_valid_update_field(req)
renamed_insertable = CallUpdateCHInsertable(
project_id=req.project_id,
Expand Down Expand Up @@ -961,9 +964,8 @@ def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadR
raise ValueError("Missing chunks")
return tsi.FileContentReadRes(content=b"".join(chunks))

def feedback_create(
self, req: tsi.FeedbackCreateReqForInsert
) -> tsi.FeedbackCreateRes:
def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
assert_non_null_wb_user_id(req)
validate_feedback_create_req(req)

# Augment emoji with alias.
Expand Down
2 changes: 1 addition & 1 deletion weave/trace_server/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
}


def validate_feedback_create_req(req: tsi.FeedbackCreateReqForInsert) -> None:
def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None:
payload_schema = FEEDBACK_PAYLOAD_SCHEMAS.get(req.feedback_type)
if payload_schema:
try:
Expand Down
12 changes: 7 additions & 5 deletions weave/trace_server/sqlite_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import emoji

from .trace_server_interface_util import (
assert_non_null_wb_user_id,
extract_refs_from_values,
str_digest,
bytes_digest,
Expand Down Expand Up @@ -65,7 +66,7 @@ def get_conn_cursor(db_path: str) -> tuple[sqlite3.Connection, sqlite3.Cursor]:
return conn_cursor


class SqliteTraceServer(tsi.TraceServerInterfacePostAuth):
class SqliteTraceServer(tsi.TraceServerInterface):
def __init__(self, db_path: str):
self.lock = threading.Lock()
self.db_path = db_path
Expand Down Expand Up @@ -462,6 +463,7 @@ def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsR
)

def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
assert_non_null_wb_user_id(req)
# update row with a deleted_at field set to now
conn, cursor = get_conn_cursor(self.db_path)
with self.lock:
Expand Down Expand Up @@ -504,7 +506,8 @@ def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:

return tsi.CallsDeleteRes()

def call_update(self, req: tsi.CallUpdateReqForInsert) -> tsi.CallUpdateRes:
def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
assert_non_null_wb_user_id(req)
if req.display_name is None:
raise ValueError("One of [display_name] is required for call update")

Expand Down Expand Up @@ -707,9 +710,8 @@ def read_ref(r: refs.ObjectRef) -> Any:

return tsi.RefsReadBatchRes(vals=[read_ref(r) for r in parsed_obj_refs])

def feedback_create(
self, req: tsi.FeedbackCreateReqForInsert
) -> tsi.FeedbackCreateRes:
def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
assert_non_null_wb_user_id(req)
validate_feedback_create_req(req)

# Augment emoji with alias.
Expand Down
43 changes: 20 additions & 23 deletions weave/trace_server/trace_server_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

from .interface.query import Query

WB_USER_ID_DESCRIPTION = (
"Do not set directly. Server will automatically populate this field."
)


class CallSchema(BaseModel):
id: str
Expand Down Expand Up @@ -73,7 +77,7 @@ class StartedCallSchemaForInsert(BaseModel):
inputs: typing.Dict[str, typing.Any]

# WB Metadata
wb_user_id: typing.Optional[str] = None
wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
wb_run_id: typing.Optional[str] = None


Expand Down Expand Up @@ -148,9 +152,8 @@ class CallsDeleteReq(BaseModel):
project_id: str
call_ids: typing.List[str]


class CallsDeleteReqForInsert(CallsDeleteReq):
wb_user_id: str
# wb_user_id is automatically populated by the server
wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)


class CallsDeleteRes(BaseModel):
Expand Down Expand Up @@ -210,9 +213,8 @@ class CallUpdateReq(BaseModel):
# optional update fields
display_name: typing.Optional[str] = None


class CallUpdateReqForInsert(CallUpdateReq):
wb_user_id: str
# wb_user_id is automatically populated by the server
wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)


class CallUpdateRes(BaseModel):
Expand Down Expand Up @@ -343,9 +345,8 @@ class FeedbackCreateReq(BaseModel):
]
)


class FeedbackCreateReqForInsert(FeedbackCreateReq):
wb_user_id: str
# wb_user_id is automatically populated by the server
wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)


# The response provides the additional fields needed to convert a request
Expand All @@ -357,7 +358,7 @@ class FeedbackCreateRes(BaseModel):
payload: typing.Dict[str, typing.Any] # If not empty, replace payload


class Feedback(FeedbackCreateReqForInsert):
class Feedback(FeedbackCreateReq):
id: str
created_at: datetime.datetime

Expand Down Expand Up @@ -501,15 +502,11 @@ def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes:
raise NotImplementedError()


class TraceServerInterfacePostAuth(TraceServerInterface):
@abc.abstractmethod
def call_update(self, req: CallUpdateReqForInsert) -> CallUpdateRes:
raise NotImplementedError()

@abc.abstractmethod
def calls_delete(self, req: CallsDeleteReqForInsert) -> CallsDeleteRes:
raise NotImplementedError()

@abc.abstractmethod
def feedback_create(self, req: FeedbackCreateReqForInsert) -> FeedbackCreateRes:
raise NotImplementedError()
# These symbols are used in the WB Trace Server and it is not safe
# to remove them, else it will break the server. Once the server
# is updated to use the new symbols, these can be removed.
#
# Remove once https://github.com/wandb/core/pull/22040 lands
CallsDeleteReqForInsert = CallsDeleteReq
CallUpdateReqForInsert = CallUpdateReq
FeedbackCreateReqForInsert = FeedbackCreateReq
10 changes: 10 additions & 0 deletions weave/trace_server/trace_server_interface_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,13 @@ def _visit(val: typing.Any) -> typing.Any:

_visit(vals)
return refs


def assert_non_null_wb_user_id(obj: typing.Any) -> None:
if not hasattr(obj, "wb_user_id") or obj.wb_user_id is None:
raise ValueError("wb_user_id cannot be None")


def assert_null_wb_user_id(obj: typing.Any) -> None:
if not hasattr(obj, "wb_user_id") or obj.wb_user_id is not None:
raise ValueError("wb_user_id must be be None")

0 comments on commit d4ba75d

Please sign in to comment.