Skip to content

Commit

Permalink
enh: Make client's handling of error responses more robust and user-f…
Browse files Browse the repository at this point in the history
…riendly (#418)
  • Loading branch information
jeffreyftang committed Apr 17, 2024
1 parent cc2e0a9 commit 2d33ee9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 17 deletions.
11 changes: 9 additions & 2 deletions clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,15 @@ def generate(
timeout=self.timeout,
)

# TODO: expose better error messages for 422 and similar errors
payload = resp.json()
try:
payload = resp.json()
except requests.JSONDecodeError as e:
# If the status code is success-like, reset it to 500 since the server is sending an invalid response.
if 200 <= resp.status_code < 400:
resp.status_code = 500

payload = {"message": e.msg}

if resp.status_code != 200:
raise parse_error(resp.status_code, payload)

Expand Down
35 changes: 21 additions & 14 deletions clients/python/lorax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,17 @@ def __init__(self, model_id: str):
super(NotSupportedError, self).__init__(message)


# Unknown error
class UnknownError(Exception):
class UnprocessableEntityError(Exception):
    """Error carrying the server's message for an HTTP 422 response."""

    def __init__(self, message: str):
        super().__init__(message)


# Unknown error
class UnknownError(Exception):
    """Fallback error whose message embeds the HTTP status code.

    Args:
        message: Error text taken from the response payload.
        code: HTTP status code of the response.
    """

    def __init__(self, message: str, code: int):
        super().__init__(f"Error status {code}: {message}")
        # Keep the raw values so callers can branch on the status code
        # without parsing the formatted exception text.
        self.message = message
        self.code = code


def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
"""
Parse error given an HTTP status code and a json payload
Expand All @@ -75,17 +80,17 @@ def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
"""
# Try to parse a LoRAX error
message = payload["error"]
if "error_type" in payload:
error_type = payload["error_type"]
if error_type == "generation":
return GenerationError(message)
if error_type == "incomplete_generation":
return IncompleteGenerationError(message)
if error_type == "overloaded":
return OverloadedError(message)
if error_type == "validation":
return ValidationError(message)
message = payload.get("error", "")

error_type = payload.get("error_type", "")
if error_type == "generation":
return GenerationError(message)
if error_type == "incomplete_generation":
return IncompleteGenerationError(message)
if error_type == "overloaded":
return OverloadedError(message)
if error_type == "validation":
return ValidationError(message)

# Try to parse a APIInference error
if status_code == 400:
Expand All @@ -98,6 +103,8 @@ def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
return NotFoundError(message)
if status_code == 429:
return RateLimitExceededError(message)
if status_code == 422:
return UnprocessableEntityError(message)

# Fallback to an unknown error
return UnknownError(message)
return UnknownError(message, status_code)
7 changes: 6 additions & 1 deletion clients/python/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ShardTimeoutError,
NotFoundError,
RateLimitExceededError,
UnknownError,
UnknownError, UnprocessableEntityError,
)


Expand Down Expand Up @@ -59,6 +59,11 @@ def test_rate_limit_exceeded_error():
assert isinstance(parse_error(429, payload), RateLimitExceededError)


def test_unprocessable_entity_error():
    """A 422 status with no error_type maps to UnprocessableEntityError."""
    payload = {"error": "test"}
    error = parse_error(422, payload)
    assert isinstance(error, UnprocessableEntityError)
    # The payload's "error" text is passed through unchanged as the message.
    assert str(error) == "test"


def test_unknown_error():
    """A status code with no dedicated error class falls back to UnknownError."""
    payload = {"error": "test"}
    error = parse_error(500, payload)
    assert isinstance(error, UnknownError)
    # The fallback error prefixes the payload message with the status code.
    assert str(error) == "Error status 500: test"

0 comments on commit 2d33ee9

Please sign in to comment.