diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py
index 8a1c3525..f77d96ea 100644
--- a/integration_tests/rest_api_utils.py
+++ b/integration_tests/rest_api_utils.py
@@ -168,7 +168,7 @@ def my_model(**keyword_args):
 
 CREATE_LLM_MODEL_ENDPOINT_REQUEST: Dict[str, Any] = {
     "name": format_name("llama-2-7b-test"),
-    "model_name": "llama-2-7b",
+    "model_name": "llama-2-7b-chat",
     "source": "hugging_face",
     "inference_framework": "vllm",
     "inference_framework_image_tag": "latest",
@@ -802,7 +802,7 @@ async def create_llm_streaming_task(
         timeout=LONG_NETWORK_TIMEOUT_SEC,
     ) as response:
         assert response.status == 200, (await response.read()).decode()
-        return await response.json()
+        return (await response.read()).decode()
 
 
 async def create_sync_tasks(
@@ -987,6 +987,27 @@ def ensure_llm_task_response_is_correct(
         assert re.search(response_text_regex, response["output"]["text"])
 
 
+def ensure_llm_task_stream_response_is_correct(
+    response: str,
+    required_output_fields: Optional[List[str]],
+    response_text_regex: Optional[str],
+):
+    # parse response
+    # data has format "data: <json>\n\ndata: <json>\n\n"
+    # We want to get a list of dictionaries parsing out the 'data:' field
+    parsed_response = [
+        json.loads(r.split("data: ")[1]) for r in response.split("\n") if "data:" in r.strip()
+    ]
+
+    # Join the text field of the response
+    response_text = "".join([r["output"]["text"] for r in parsed_response])
+    print("response text: ", response_text)
+    assert response_text is not None
+
+    if response_text_regex is not None:
+        assert re.search(response_text_regex, response_text)
+
+
 # Wait up to 30 seconds for the tasks to be returned.
 @retry(
     stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError)
diff --git a/integration_tests/test_completions.py b/integration_tests/test_completions.py
index e2530963..aac6b213 100644
--- a/integration_tests/test_completions.py
+++ b/integration_tests/test_completions.py
@@ -13,6 +13,7 @@
     delete_llm_model_endpoint,
     ensure_launch_gateway_healthy,
     ensure_llm_task_response_is_correct,
+    ensure_llm_task_stream_response_is_correct,
     ensure_n_ready_private_llm_endpoints_short,
     ensure_nonzero_available_llm_workers,
 )
@@ -86,7 +87,7 @@ def test_completions(capsys):
             )
         )
         for response in task_responses:
-            ensure_llm_task_response_is_correct(
+            ensure_llm_task_stream_response_is_correct(
                 response, required_output_fields, response_text_regex
            )
    except Exception as e:
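
For reference, a minimal sketch of the parsing that the new ensure_llm_task_stream_response_is_correct helper performs, run against a hypothetical two-chunk server-sent-events payload (the sample_stream contents below are invented for illustration; only the "data: <json>\n\n" framing and the ["output"]["text"] field shape come from the diff above):

    import json
    import re

    # Hypothetical SSE payload in the shape the helper expects: each chunk
    # is a "data: <json>" line followed by a blank line.
    sample_stream = (
        'data: {"output": {"text": "Deep"}}\n\n'
        'data: {"output": {"text": " learning"}}\n\n'
    )

    # Same parsing strategy as the helper: keep only lines carrying a
    # "data:" prefix and JSON-decode the payload after it.
    chunks = [
        json.loads(line.split("data: ")[1])
        for line in sample_stream.split("\n")
        if "data:" in line.strip()
    ]

    # Concatenate the per-chunk text fields into the full completion,
    # then check it against a regex as the helper does.
    full_text = "".join(chunk["output"]["text"] for chunk in chunks)
    assert full_text == "Deep learning"
    assert re.search(r"learning", full_text)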