Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions integration_tests/rest_api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def my_model(**keyword_args):

CREATE_LLM_MODEL_ENDPOINT_REQUEST: Dict[str, Any] = {
"name": format_name("llama-2-7b-test"),
"model_name": "llama-2-7b",
"model_name": "llama-2-7b-chat",
"source": "hugging_face",
"inference_framework": "vllm",
"inference_framework_image_tag": "latest",
Expand Down Expand Up @@ -802,7 +802,7 @@ async def create_llm_streaming_task(
timeout=LONG_NETWORK_TIMEOUT_SEC,
) as response:
assert response.status == 200, (await response.read()).decode()
return await response.json()
return (await response.read()).decode()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, how did this work before?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm pretty sure it didn't



async def create_sync_tasks(
Expand Down Expand Up @@ -987,6 +987,27 @@ def ensure_llm_task_response_is_correct(
assert re.search(response_text_regex, response["output"]["text"])


def ensure_llm_task_stream_response_is_correct(
response: str,
required_output_fields: Optional[List[str]],
response_text_regex: Optional[str],
):
# parse response
# data has format "data: <data>\n\ndata: <data>\n\n"
# We want to get a list of dictionaries parsing out the 'data:' field
parsed_response = [
json.loads(r.split("data: ")[1]) for r in response.split("\n") if "data:" in r.strip()
]

# Join the text field of the response
response_text = "".join([r["output"]["text"] for r in parsed_response])
print("response text: ", response_text)
assert response_text is not None

if response_text_regex is not None:
assert re.search(response_text_regex, response_text)


# Wait up to 30 seconds for the tasks to be returned.
@retry(
stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError)
Expand Down
3 changes: 2 additions & 1 deletion integration_tests/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
delete_llm_model_endpoint,
ensure_launch_gateway_healthy,
ensure_llm_task_response_is_correct,
ensure_llm_task_stream_response_is_correct,
ensure_n_ready_private_llm_endpoints_short,
ensure_nonzero_available_llm_workers,
)
Expand Down Expand Up @@ -86,7 +87,7 @@ def test_completions(capsys):
)
)
for response in task_responses:
ensure_llm_task_response_is_correct(
ensure_llm_task_stream_response_is_correct(
response, required_output_fields, response_text_regex
)
except Exception as e:
Expand Down