Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Temporarily have /st-allowed-message-origins double as a healthcheck #5642

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions frontend/src/lib/WebsocketConnection.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,32 @@ describe("doInitPings", () => {
Promise.all = originalPromiseAll
})

// NOTE: Temporary test until we're able to rename the /healthz endpoint
it("does not call the /healthz endpoint when pinging server", async () => {
axios.get = jest.fn().mockImplementation(url => {
if (url.endsWith("/healthz")) {
throw Error("kaboom")
}
if (url.endsWith("/st-allowed-message-origins")) {
return MOCK_ALLOWED_ORIGINS_RESPONSE
}
return {}
})

const uriIndex = await doInitPings(
MOCK_PING_DATA.uri,
MOCK_PING_DATA.timeoutMs,
MOCK_PING_DATA.maxTimeoutMs,
MOCK_PING_DATA.retryCallback,
MOCK_PING_DATA.setHostAllowedOrigins,
MOCK_PING_DATA.userCommandLine
)
expect(uriIndex).toEqual(0)
expect(MOCK_PING_DATA.setHostAllowedOrigins).toHaveBeenCalledWith(
MOCK_ALLOWED_ORIGINS_RESPONSE.data.allowedOrigins
)
})

it("returns the uri index and sets allowedOrigins for the first successful ping (0)", async () => {
Promise.all = jest
.fn()
Expand Down
16 changes: 15 additions & 1 deletion frontend/src/lib/WebsocketConnection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,21 @@ export function doInitPings(
// not to do so as it's semantically cleaner to not give the healthcheck
// endpoint additional responsibilities.
Promise.all([
axios.get(healthzUri, { timeout: minimumTimeoutMs }),
// NOTE: We temporarily avoid hitting the healthz endpoint for now
// because certain environments (notably GCP App Engine and Cloud Run)
// reserve the endpoint name, and the new /st-allowed-message-origins
// can be used as a healthcheck at the relatively cheap cost of some
// semantic clarity.
//
// We keep the Promise.all and just return a resolved promise in the
// first element of the array instead of actually pinging the /healthz
// endpoint to avoid having to change the structure of the code for this
// temporary change.
//
// Once we're able to pick up work on https://github.com/streamlit/streamlit/pull/5534
// again, our endpoints can be re-split into their original dedicated
// roles.
Promise.resolve(), // axios.get(healthzUri, { timeout: minimumTimeoutMs }),
axios.get(allowedOriginsUri, { timeout: minimumTimeoutMs }),
])
.then(([_, originsResp]) => {
Expand Down
44 changes: 37 additions & 7 deletions lib/streamlit/web/server/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,44 @@ async def get(self):
]


# NOTE: We're temporarily having this endpoint duplicate much of the code that also
# lives in HealthHandler because the /healthz endpoint name is giving us trouble in
# certain environments (in particular, GCP products like App Engine and Cloud Run reserve
# the healthz endpoint).
#
# In the future, we'll be prefixing all of our endpoints, which will allow us to have
# this endpoint and the healthcheck endpoint return to their dedicated roles, but having
# this endpoint double as a healthcheck is fine in the meantime.
class AllowedMessageOriginsHandler(_SpecialRequestHandler):
def get(self) -> None:
# ALLOWED_MESSAGE_ORIGINS must be wrapped in a dictionary because Tornado
# disallows writing lists directly into responses due to potential XSS
# vulnerabilities.
# See https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write
self.write({"allowedOrigins": ALLOWED_MESSAGE_ORIGINS})
self.set_status(200)
def initialize(self, callback):
"""Initialize the handler

Parameters
----------
callback : callable
A function that returns True if the server is healthy

"""
self._callback = callback

async def get(self) -> None:
ok, msg = await self._callback()

if ok:
# ALLOWED_MESSAGE_ORIGINS must be wrapped in a dictionary because Tornado
# disallows writing lists directly into responses due to potential XSS
# vulnerabilities.
# See https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write
self.write({"allowedOrigins": ALLOWED_MESSAGE_ORIGINS})
self.set_status(200)

if config.get_option("server.enableXsrfProtection"):
self.set_cookie("_xsrf", self.xsrf_token)

else:
# 503 = SERVICE_UNAVAILABLE
self.set_status(503)
self.write(msg)


class MessageCacheHandler(tornado.web.RequestHandler):
Expand Down
1 change: 1 addition & 0 deletions lib/streamlit/web/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def _create_app(self) -> tornado.web.Application:
(
make_url_path_regex(base, "st-allowed-message-origins"),
AllowedMessageOriginsHandler,
dict(callback=lambda: self._runtime.is_ready_for_browser_connection),
),
(
make_url_path_regex(
Expand Down
37 changes: 34 additions & 3 deletions lib/tests/streamlit/web/server/routes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import tornado.web
import tornado.websocket

from streamlit import config
from streamlit.logger import get_logger
from streamlit.runtime.forward_msg_cache import ForwardMsgCache, populate_hash_if_needed
from streamlit.runtime.runtime_util import serialize_forward_msg
Expand All @@ -34,6 +33,7 @@
StaticFileHandler,
)
from tests.streamlit.message_mocks import create_dataframe_msg
from tests.testutil import patch_config_options

LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -62,15 +62,15 @@ def test_healthz(self):
response = self.fetch("/healthz")
self.assertEqual(503, response.code)

@patch_config_options({"server.enableXsrfProtection": False})
def test_healthz_without_csrf(self):
config._set_option("server.enableXsrfProtection", False, "test")
response = self.fetch("/healthz")
self.assertEqual(200, response.code)
self.assertEqual(b"ok", response.body)
self.assertNotIn("Set-Cookie", response.headers)

@patch_config_options({"server.enableXsrfProtection": True})
def test_healthz_with_csrf(self):
config._set_option("server.enableXsrfProtection", True, "test")
response = self.fetch("/healthz")
self.assertEqual(200, response.code)
self.assertEqual(b"ok", response.body)
Expand Down Expand Up @@ -156,12 +156,20 @@ def test_parse_url_path_404(self):


class AllowedMessageOriginsHandlerTest(tornado.testing.AsyncHTTPTestCase):
def setUp(self):
super(AllowedMessageOriginsHandlerTest, self).setUp()
self._is_healthy = True

async def is_healthy(self):
return self._is_healthy, "ok"

def get_app(self):
return tornado.web.Application(
[
(
r"/st-allowed-message-origins",
AllowedMessageOriginsHandler,
dict(callback=self.is_healthy),
)
]
)
Expand All @@ -172,3 +180,26 @@ def test_allowed_message_origins(self):
self.assertEqual(
{"allowedOrigins": ALLOWED_MESSAGE_ORIGINS}, json.loads(response.body)
)

# NOTE: Temporary tests to verify this endpoint can also act as a healthcheck
# endpoint while we need it to. These tests are more or less copy-paste from the
# HealthHandlerTest class above.
def test_healthcheck_responsibilities(self):
response = self.fetch("/st-allowed-message-origins")
self.assertEqual(200, response.code)

self._is_healthy = False
response = self.fetch("/st-allowed-message-origins")
self.assertEqual(503, response.code)

@patch_config_options({"server.enableXsrfProtection": False})
def test_healthcheck_responsibilities_without_csrf(self):
response = self.fetch("/st-allowed-message-origins")
self.assertEqual(200, response.code)
self.assertNotIn("Set-Cookie", response.headers)

@patch_config_options({"server.enableXsrfProtection": True})
def test_healthcheck_responsibilities_with_csrf(self):
response = self.fetch("/st-allowed-message-origins")
self.assertEqual(200, response.code)
self.assertIn("Set-Cookie", response.headers)