Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): Simplify TraceServerInterface definition and fix inheritance #1763

Merged
merged 6 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions weave/tests/test_client_feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_feedback_apis(client):
project_id = client._project_id()

# Emoji from Jamie
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -23,7 +23,7 @@ def test_feedback_apis(client):
id_emoji_1 = res.id

# Another emoji from Jamie
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -36,7 +36,7 @@ def test_feedback_apis(client):
id_emoji_2 = res.id

# Emoji from Shawn
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjoxOQ==",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -49,7 +49,7 @@ def test_feedback_apis(client):
id_emoji_3 = res.id

# Note from Jamie
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -61,7 +61,7 @@ def test_feedback_apis(client):
id_note = res.id

# Custom from Jamie
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand All @@ -73,7 +73,7 @@ def test_feedback_apis(client):
id_custom_1 = res.id

# Custom on another object
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name2:digest",
Expand Down Expand Up @@ -205,7 +205,7 @@ def test_feedback_create_too_large(client):
project_id = client._project_id()

value = "a" * 10000
req = tsi.FeedbackCreateReqForInsert(
req = tsi.FeedbackCreateReq(
project_id=project_id,
wb_user_id="VXNlcjo0NTI1NDQ=",
weave_ref="weave:///entity/project/object/name:digest",
Expand Down
4 changes: 2 additions & 2 deletions weave/tests/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_calls_delete(client):
original_calls_delete = client.server.calls_delete

def patched_delete(req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
post_auth_req = tsi.CallsDeleteReqForInsert(
post_auth_req = tsi.CallsDeleteReq(
project_id=req.project_id,
call_ids=req.call_ids,
wb_user_id="test-user-id",
Expand Down Expand Up @@ -294,7 +294,7 @@ def test_calls_delete_cascade(client):
original_calls_delete = client.server.calls_delete

def patched_delete(req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
post_auth_req = tsi.CallsDeleteReqForInsert(
post_auth_req = tsi.CallsDeleteReq(
project_id=req.project_id,
call_ids=req.call_ids,
wb_user_id="test-user-id",
Expand Down
16 changes: 9 additions & 7 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .orm import Table, Column, ParamBuilder, Row

from .trace_server_interface_util import (
assert_non_null_wb_user_id,
extract_refs_from_values,
generate_id,
str_digest,
Expand Down Expand Up @@ -223,7 +224,7 @@ class SelectableCHObjSchema(BaseModel):
required_obj_select_columns = list(set(all_obj_select_columns) - set([]))


class ClickHouseTraceServer(tsi.TraceServerInterfacePostAuth):
class ClickHouseTraceServer(tsi.TraceServerInterface):
def __init__(
self,
*,
Expand Down Expand Up @@ -396,7 +397,8 @@ def calls_query_stream(
_ch_call_dict_to_call_schema_dict(ch_dict)
)

def calls_delete(self, req: tsi.CallsDeleteReqForInsert) -> tsi.CallsDeleteRes:
def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
assert_non_null_wb_user_id(req)
if len(req.call_ids) > MAX_DELETE_CALLS_COUNT:
raise RequestTooLarge(
f"Cannot delete more than {MAX_DELETE_CALLS_COUNT} calls at once"
Expand Down Expand Up @@ -448,7 +450,7 @@ def calls_delete(self, req: tsi.CallsDeleteReqForInsert) -> tsi.CallsDeleteRes:

return tsi.CallsDeleteRes()

def _ensure_valid_update_field(self, req: tsi.CallUpdateReqForInsert) -> None:
def _ensure_valid_update_field(self, req: tsi.CallUpdateReq) -> None:
valid_update_fields = ["display_name"]
for field in valid_update_fields:
if getattr(req, field, None) is not None:
Expand All @@ -458,7 +460,8 @@ def _ensure_valid_update_field(self, req: tsi.CallUpdateReqForInsert) -> None:
f"One of [{', '.join(valid_update_fields)}] is required for call update"
)

def call_update(self, req: tsi.CallUpdateReqForInsert) -> tsi.CallUpdateRes:
def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
assert_non_null_wb_user_id(req)
self._ensure_valid_update_field(req)
renamed_insertable = CallUpdateCHInsertable(
project_id=req.project_id,
Expand Down Expand Up @@ -961,9 +964,8 @@ def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadR
raise ValueError("Missing chunks")
return tsi.FileContentReadRes(content=b"".join(chunks))

def feedback_create(
self, req: tsi.FeedbackCreateReqForInsert
) -> tsi.FeedbackCreateRes:
def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
assert_non_null_wb_user_id(req)
validate_feedback_create_req(req)

# Augment emoji with alias.
Expand Down
2 changes: 1 addition & 1 deletion weave/trace_server/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
}


def validate_feedback_create_req(req: tsi.FeedbackCreateReqForInsert) -> None:
def validate_feedback_create_req(req: tsi.FeedbackCreateReq) -> None:
payload_schema = FEEDBACK_PAYLOAD_SCHEMAS.get(req.feedback_type)
if payload_schema:
try:
Expand Down
12 changes: 7 additions & 5 deletions weave/trace_server/sqlite_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import emoji

from .trace_server_interface_util import (
assert_non_null_wb_user_id,
extract_refs_from_values,
str_digest,
bytes_digest,
Expand Down Expand Up @@ -65,7 +66,7 @@ def get_conn_cursor(db_path: str) -> tuple[sqlite3.Connection, sqlite3.Cursor]:
return conn_cursor


class SqliteTraceServer(tsi.TraceServerInterfacePostAuth):
class SqliteTraceServer(tsi.TraceServerInterface):
def __init__(self, db_path: str):
self.lock = threading.Lock()
self.db_path = db_path
Expand Down Expand Up @@ -462,6 +463,7 @@ def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsR
)

def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
assert_non_null_wb_user_id(req)
# update row with a deleted_at field set to now
conn, cursor = get_conn_cursor(self.db_path)
with self.lock:
Expand Down Expand Up @@ -504,7 +506,8 @@ def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:

return tsi.CallsDeleteRes()

def call_update(self, req: tsi.CallUpdateReqForInsert) -> tsi.CallUpdateRes:
def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes:
assert_non_null_wb_user_id(req)
if req.display_name is None:
raise ValueError("One of [display_name] is required for call update")

Expand Down Expand Up @@ -707,9 +710,8 @@ def read_ref(r: refs.ObjectRef) -> Any:

return tsi.RefsReadBatchRes(vals=[read_ref(r) for r in parsed_obj_refs])

def feedback_create(
self, req: tsi.FeedbackCreateReqForInsert
) -> tsi.FeedbackCreateRes:
def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes:
assert_non_null_wb_user_id(req)
validate_feedback_create_req(req)

# Augment emoji with alias.
Expand Down
43 changes: 20 additions & 23 deletions weave/trace_server/trace_server_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

from .interface.query import Query

WB_USER_ID_DESCRIPTION = (
"Do not set directly. Server will automatically populate this field."
)


class CallSchema(BaseModel):
id: str
Expand Down Expand Up @@ -73,7 +77,7 @@ class StartedCallSchemaForInsert(BaseModel):
inputs: typing.Dict[str, typing.Any]

# WB Metadata
wb_user_id: typing.Optional[str] = None
wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)
wb_run_id: typing.Optional[str] = None


Expand Down Expand Up @@ -148,9 +152,8 @@ class CallsDeleteReq(BaseModel):
project_id: str
call_ids: typing.List[str]


class CallsDeleteReqForInsert(CallsDeleteReq):
wb_user_id: str
# wb_user_id is automatically populated by the server
wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)


class CallsDeleteRes(BaseModel):
Expand Down Expand Up @@ -210,9 +213,8 @@ class CallUpdateReq(BaseModel):
# optional update fields
display_name: typing.Optional[str] = None


class CallUpdateReqForInsert(CallUpdateReq):
wb_user_id: str
# wb_user_id is automatically populated by the server
wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)


class CallUpdateRes(BaseModel):
Expand Down Expand Up @@ -343,9 +345,8 @@ class FeedbackCreateReq(BaseModel):
]
)


class FeedbackCreateReqForInsert(FeedbackCreateReq):
wb_user_id: str
# wb_user_id is automatically populated by the server
wb_user_id: typing.Optional[str] = Field(None, description=WB_USER_ID_DESCRIPTION)


# The response provides the additional fields needed to convert a request
Expand All @@ -357,7 +358,7 @@ class FeedbackCreateRes(BaseModel):
payload: typing.Dict[str, typing.Any] # If not empty, replace payload


class Feedback(FeedbackCreateReqForInsert):
class Feedback(FeedbackCreateReq):
id: str
created_at: datetime.datetime

Expand Down Expand Up @@ -501,15 +502,11 @@ def feedback_purge(self, req: FeedbackPurgeReq) -> FeedbackPurgeRes:
raise NotImplementedError()


class TraceServerInterfacePostAuth(TraceServerInterface):
@abc.abstractmethod
def call_update(self, req: CallUpdateReqForInsert) -> CallUpdateRes:
raise NotImplementedError()

@abc.abstractmethod
def calls_delete(self, req: CallsDeleteReqForInsert) -> CallsDeleteRes:
raise NotImplementedError()

@abc.abstractmethod
def feedback_create(self, req: FeedbackCreateReqForInsert) -> FeedbackCreateRes:
raise NotImplementedError()
# These symbols are used in the WB Trace Server and it is not safe
# to remove them, else it will break the server. Once the server
# is updated to use the new symbols, these can be removed.
#
# Remove once https://github.com/wandb/core/pull/22040 lands
CallsDeleteReqForInsert = CallsDeleteReq
CallUpdateReqForInsert = CallUpdateReq
FeedbackCreateReqForInsert = FeedbackCreateReq
10 changes: 10 additions & 0 deletions weave/trace_server/trace_server_interface_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,13 @@ def _visit(val: typing.Any) -> typing.Any:

_visit(vals)
return refs


def assert_non_null_wb_user_id(obj: typing.Any) -> None:
if not hasattr(obj, "wb_user_id") or obj.wb_user_id is None:
raise ValueError("wb_user_id cannot be None")


def assert_null_wb_user_id(obj: typing.Any) -> None:
if not hasattr(obj, "wb_user_id") or obj.wb_user_id is not None:
raise ValueError("wb_user_id must be be None")
Loading