Skip to content

Commit

Permalink
fix some tests
Browse files Browse the repository at this point in the history
Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
technillogue committed Mar 24, 2024
1 parent 7ef9ff5 commit 4b85361
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 24 deletions.
2 changes: 0 additions & 2 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,6 @@ def _set_completed_at(self) -> None:
def _send_webhook(self, event: schema.WebhookEvent) -> None:
if self._webhook_sender is not None:
self._webhook_sender(self.response, event)
# dict_response = jsonable_encoder(self.response.dict(exclude_unset=True))
# self._webhook_sender(dict_response, event)

def _upload_files(self, output: Any) -> Any:
if self._file_uploader is None:
Expand Down
21 changes: 12 additions & 9 deletions python/tests/server/test_response_throttler.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,38 @@
import time

from cog.schema import Status
from cog.schema import PredictionResponse, Status
from cog.server.response_throttler import ResponseThrottler

processing = PredictionResponse(input={}, status=Status.PROCESSING)
succeeded = PredictionResponse(input={}, status=Status.SUCCEEDED)


def test_zero_interval():
throttler = ResponseThrottler(response_interval=0)

assert throttler.should_send_response({"status": Status.PROCESSING})
assert throttler.should_send_response(processing)
throttler.update_last_sent_response_time()
assert throttler.should_send_response({"status": Status.SUCCEEDED})
assert throttler.should_send_response(succeeded)


def test_terminal_status():
throttler = ResponseThrottler(response_interval=10)

assert throttler.should_send_response({"status": Status.PROCESSING})
assert throttler.should_send_response(processing)
throttler.update_last_sent_response_time()
assert not throttler.should_send_response({"status": Status.PROCESSING})
assert not throttler.should_send_response(processing)
throttler.update_last_sent_response_time()
assert throttler.should_send_response({"status": Status.SUCCEEDED})
assert throttler.should_send_response(succeeded)


def test_nonzero_internal():
throttler = ResponseThrottler(response_interval=0.2)

assert throttler.should_send_response({"status": Status.PROCESSING})
assert throttler.should_send_response(processing)
throttler.update_last_sent_response_time()
assert not throttler.should_send_response({"status": Status.PROCESSING})
assert not throttler.should_send_response(processing)
throttler.update_last_sent_response_time()

time.sleep(0.3)

assert throttler.should_send_response({"status": Status.PROCESSING})
assert throttler.should_send_response(processing)
4 changes: 2 additions & 2 deletions python/tests/server/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ def test_prediction_event_handler_webhook_sender(match):
h.succeeded()

s.assert_called_once_with(
match(
{
PredictionResponse(
**{
"input": {"hello": "there"},
"output": ["elephant", "duck"],
"logs": "running a prediction\nstill running\n",
Expand Down
26 changes: 15 additions & 11 deletions python/tests/server/test_webhook.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import requests
import responses
from cog.schema import WebhookEvent
from cog.schema import WebhookEvent, PredictionResponse
from cog.server.webhook import webhook_caller, webhook_caller_filtered
from responses import registries

processing = PredictionResponse(input={}, output=status=Status.PROCESSING)
payload = {"status": "processing", "logs": "giraffe", "input": {}}
processing = PredictionResponse(**payload)


@responses.activate
def test_webhook_caller_basic():
c = webhook_caller("https://example.com/webhook/123")

responses.post(
"https://example.com/webhook/123",
json={"status": "processing", "animal": "giraffe"},
json=payload,
status=200,
)

c({"status": "processing", "animal": "giraffe"})
c(processing)


@responses.activate
Expand All @@ -24,11 +28,11 @@ def test_webhook_caller_non_terminal_does_not_retry():

responses.post(
"https://example.com/webhook/123",
json={"status": "processing", "animal": "giraffe"},
json=payload,
status=429,
)

c({"status": "processing", "animal": "giraffe"})
c(processing)


@responses.activate(registry=registries.OrderedRegistry)
Expand Down Expand Up @@ -63,11 +67,11 @@ def test_webhook_includes_user_agent():

responses.post(
"https://example.com/webhook/123",
json={"status": "processing", "animal": "giraffe"},
json=payload,
status=200,
)

c({"status": "processing", "animal": "giraffe"})
c(processing)

assert len(responses.calls) == 1
user_agent = responses.calls[0].request.headers["user-agent"]
Expand All @@ -81,19 +85,19 @@ def test_webhook_caller_filtered_basic():

responses.post(
"https://example.com/webhook/123",
json={"status": "processing", "animal": "giraffe"},
json=payload,
status=200,
)

c({"status": "processing", "animal": "giraffe"}, WebhookEvent.LOGS)
c(processing, WebhookEvent.LOGS)


@responses.activate
def test_webhook_caller_filtered_omits_filtered_events():
events = {WebhookEvent.COMPLETED}
c = webhook_caller_filtered("https://example.com/webhook/123", events)

c({"status": "processing", "animal": "giraffe"}, WebhookEvent.LOGS)
c(processing, WebhookEvent.LOGS)


@responses.activate
Expand All @@ -110,4 +114,4 @@ def test_webhook_caller_connection_errors():

c = webhook_caller("https://example.com/webhook/123")
# this should not raise an error
c({"status": "processing", "animal": "giraffe"})
c(processing)

0 comments on commit 4b85361

Please sign in to comment.