enable sensitive log mode #415

Merged (8 commits, Jan 4, 2024)
Changes from 6 commits
1 change: 1 addition & 0 deletions charts/model-engine/values_circleci.yaml
@@ -146,6 +146,7 @@ config:
s3_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET/fine_tune_repository"
dd_trace_enabled: false
istio_enabled: true
sensitive_log_mode: false
tgi_repository: "text-generation-inference"
vllm_repository: "vllm"
lightllm_repository: "lightllm"
1 change: 1 addition & 0 deletions charts/model-engine/values_sample.yaml
@@ -163,6 +163,7 @@ config:
# dd_trace_enabled specifies whether to enable datadog tracing, datadog must be installed in the cluster
dd_trace_enabled: false
istio_enabled: true
sensitive_log_mode: false

# Asynchronous endpoints configs (coming soon)
sqs_profile: default
15 changes: 9 additions & 6 deletions model-engine/model_engine_server/api/llms_v1.py
@@ -12,6 +12,7 @@
get_external_interfaces_read_only,
verify_authentication,
)
from model_engine_server.common.config import hmi_config
from model_engine_server.common.dtos.llms import (
CancelFineTuneResponse,
CompletionStreamV1Request,
@@ -307,9 +308,10 @@ async def create_completion_sync_task(
"""
Runs a sync prompt completion on an LLM.
"""
logger.info(
f"POST /completion_sync with {request} to endpoint {model_endpoint_name} for {auth}"
)
if not hmi_config.sensitive_log_mode:
logger.info(
f"POST /completion_sync with {request} to endpoint {model_endpoint_name} for {auth}"
)
try:
use_case = CompletionSyncV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
@@ -369,9 +371,10 @@ async def create_completion_stream_task(
"""
Runs a stream prompt completion on an LLM.
"""
logger.info(
f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
)
if not hmi_config.sensitive_log_mode: # pragma: no cover
logger.info(
f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
)
use_case = CompletionStreamV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
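The API-layer change above is the entire mechanism: the existing logger.info calls are wrapped in a flag check so request payloads never reach the logs when sensitive mode is on. Below is a minimal self-contained sketch of the same pattern; the Settings class and handle_completion_sync function are illustrative stand-ins, not the project's real objects:

import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("llm_api")


@dataclass
class Settings:
    # Illustrative stand-in for hmi_config; only the flag added in this PR.
    sensitive_log_mode: bool = False


def handle_completion_sync(settings: Settings, request_payload: dict, endpoint: str) -> None:
    # Only echo the full request body into the logs when sensitive mode is off.
    if not settings.sensitive_log_mode:
        logger.info(
            "POST /completion_sync with %s to endpoint %s", request_payload, endpoint
        )
    # ... the completion use case would run here ...


# With the flag on, the prompt below never appears in the log output.
handle_completion_sync(Settings(sensitive_log_mode=True), {"prompt": "do not log me"}, "my-llm")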
1 change: 1 addition & 0 deletions model-engine/model_engine_server/common/config.py
@@ -62,6 +62,7 @@ class HostedModelInferenceServiceConfig:
    user_inference_pytorch_repository: str
    user_inference_tensorflow_repository: str
    docker_image_layer_cache_repository: str
    sensitive_log_mode: bool

    @classmethod
    def from_yaml(cls, yaml_path):
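Adding the attribute above is enough on the code side because the config object is populated from the values YAML shown earlier, which is presumably why the sample and CircleCI values files add the key as well. The following is a rough sketch of how a dataclass-style config can pick up a boolean flag from YAML; it is an assumption about the general shape of from_yaml, not the project's actual implementation of HostedModelInferenceServiceConfig:

# Sketch only: a YAML-backed service config with the new boolean flag.
from dataclasses import dataclass, fields

import yaml  # PyYAML


@dataclass
class ServiceConfigSketch:
    istio_enabled: bool
    sensitive_log_mode: bool

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "ServiceConfigSketch":
        with open(yaml_path, "r") as f:
            raw = yaml.safe_load(f)
        # Keep only the keys the dataclass declares; extra YAML keys are ignored.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in raw.items() if k in known})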
@@ -682,6 +682,9 @@ async def create_vllm_bundle(
        else:
            raise InvalidRequestException(f"Quantization {quantize} is not supported by vLLM.")

        if hmi_config.sensitive_log_mode:
            subcommands[-1] = subcommands[-1] + " --disable-log-requests"

        command = [
            "/bin/bash",
            "-c",
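On the inference side, the flag is handled at command-assembly time: the last entry in subcommands is the vLLM server launch line, and --disable-log-requests (vLLM's switch for suppressing per-request logging) is appended to it. Here is a small sketch under that assumption; the launch fragments below are placeholders, not the real bundle command:

sensitive_log_mode = True

subcommands = [
    "ray start --head",  # placeholder fragment
    "python -m vllm_server --model /workspace/model --port 5005",  # hypothetical entrypoint
]

if sensitive_log_mode:
    # Append the flag to the final (server) command so prompts are not logged.
    subcommands[-1] = subcommands[-1] + " --disable-log-requests"

command = ["/bin/bash", "-c", ";".join(subcommands)]
print(command)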
1 change: 1 addition & 0 deletions model-engine/service_configs/service_config_circleci.yaml
@@ -54,6 +54,7 @@ s3_file_llm_fine_tune_repository: "s3://model-engine-integration-tests/fine_tune

dd_trace_enabled: false
istio_enabled: true
sensitive_log_mode: false
tgi_repository: "text-generation-inference"
vllm_repository: "vllm"
lightllm_repository: "lightllm"
33 changes: 22 additions & 11 deletions model-engine/tests/unit/api/test_llms.py
@@ -1,6 +1,6 @@
import json
import re
from typing import Any, Dict, Tuple
from unittest import mock

import pytest
from model_engine_server.common.dtos.llms import GetLLMModelEndpointV1Response
@@ -156,6 +156,8 @@ def test_completion_sync_endpoint_not_found_returns_404(
    assert response_1.status_code == 404


# When enabling this test, other tests fail with "RuntimeError: got Future <Future pending> attached to a different loop"
# https://github.com/encode/starlette/issues/1315#issuecomment-980784457
Comment (Collaborator, Author):
To reproduce, run

GIT_TAG=$(git rev-parse HEAD) WORKSPACE=..  pytest -v -s tests/unit/api/test_tasks.py::test_create_streaming_task_success tests/unit/api/test_llms.py::test_completion_stream_success

When running these tests individually, we don't have an issue. Only when running both does this error appear. I put id(asyncio.get_running_loop()) throughout the create_completion_stream_task method and everything lined up and didn't error, so there must be something up with the TestClient library itself. Time-boxing this for now.
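For reference, the loop-id instrumentation described above can be as simple as the following hypothetical helper (not the exact code used during the investigation):

import asyncio
import logging

logger = logging.getLogger(__name__)


async def log_loop_id(tag: str) -> None:
    # If two call sites report different ids, a Future created on one event
    # loop is being awaited from another, which produces the error above.
    logger.info("%s running on loop id=%s", tag, id(asyncio.get_running_loop()))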

Comment (Collaborator, Author):
I tried the small fixes in encode/starlette#1315 (comment); none of them seemed to work. I think we may need to downgrade fastapi, as folks suggested.
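One generic workaround pattern for this class of failure is to share a single event loop across the test session by overriding pytest-asyncio's event_loop fixture. This is a hedged sketch only; it is not confirmed to fix these tests and may not match the fixes discussed in the linked issue:

# conftest.py sketch: one event loop for the whole test session (pytest-asyncio).
import asyncio

import pytest


@pytest.fixture(scope="session")
def event_loop():
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()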

@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
def test_completion_stream_success(
    llm_model_endpoint_streaming: ModelEndpoint,
@@ -175,19 +177,28 @@ def test_completion_stream_success(
        fake_batch_job_progress_gateway_contents={},
        fake_docker_image_batch_job_bundle_repository_contents={},
    )
    response_1 = client.post(
        f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
        auth=("no_user", ""),
        json=completion_stream_request,
        stream=True,
    )
    with mock.patch(
        "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens",
        return_value=5,
    ):
        response_1 = client.post(
            f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
            auth=("no_user", ""),
            json=completion_stream_request,
            stream=True,
        )
    assert response_1.status_code == 200
    count = 0
    for message in response_1:
        assert re.fullmatch(
            'data: {"request_id"}: ".*", "output": null}\r\n\r\n',
            message.decode("utf-8"),
        )
        decoded_message = message.decode("utf-8")
        assert decoded_message.startswith("data: "), "SSE does not start with 'data: '"

        # strip 'data: ' prefix from Server-sent events format
        json_str = decoded_message[len("data: ") :]
        parsed_data = json.loads(json_str.strip())
        assert parsed_data["request_id"] is not None
        assert parsed_data["output"] is None
        assert parsed_data["error"] is None
        count += 1
    assert count == 1
