Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_python(session: Session) -> None:
session.install(".[all]")
else:
install_idom_dev(session, extras="all")
pytest_args += ["--reruns", "5"]
pytest_args += ["--reruns", "1"]

session.run("pytest", "tests", *pytest_args)

Expand Down
31 changes: 29 additions & 2 deletions src/idom/client/app/src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function defaultWebSocketEndpoint() {
protocol = "ws:";
}

return protocol + "//" + url.join("/") + window.location.search;
return protocol + "//" + url.join("/") + "?" + queryParams.user.toString();
}

export function mountLayoutWithWebSocket(
Expand All @@ -48,7 +48,7 @@ export function mountLayoutWithWebSocket(
});

socket.onopen = (event) => {
console.log(`Connected to ${endpoint}`);
console.log(`Connected.`);
if (mountState.everMounted) {
unmountComponentAtNode(element);
}
Expand All @@ -69,6 +69,10 @@ export function mountLayoutWithWebSocket(
};

socket.onclose = (event) => {
if (!shouldReconnect()) {
console.log(`Connection lost.`);
return;
}
const reconnectTimeout = _nextReconnectTimeout(mountState);
console.log(`Connection lost, reconnecting in ${reconnectTimeout} seconds`);
setTimeout(function () {
Expand All @@ -95,3 +99,26 @@ function _nextReconnectTimeout(mountState) {
}
return timeout;
}

function shouldReconnect() {
return queryParams.reserved.get("noReconnect") === null;
}

const queryParams = (() => {
const reservedParams = new URLSearchParams();
const userParams = new URLSearchParams(window.location.search);

const reservedParamNames = ["noReconnect"];
reservedParamNames.forEach((name) => {
const value = userParams.get(name);
if (value !== null) {
reservedParams.append(name, userParams.get(name));
userParams.delete(name);
}
});

return {
reserved: reservedParams,
user: userParams,
};
})();
7 changes: 7 additions & 0 deletions src/idom/client/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


def web_module_path(package_name: str, must_exist: bool = False) -> Path:
"""Get the :class:`Path` to a web module's source"""
path = _private.web_modules_dir().joinpath(*(package_name + ".js").split("/"))
if must_exist and not path.exists():
raise ValueError(
Expand All @@ -23,13 +24,18 @@ def web_module_path(package_name: str, must_exist: bool = False) -> Path:


def web_module_exports(package_name: str) -> List[str]:
"""Get a list of names this module exports"""
web_module_path(package_name, must_exist=True)
return _private.find_js_module_exports_in_source(
web_module_path(package_name).read_text(encoding="utf-8")
)


def web_module_url(package_name: str) -> str:
"""Get the URL the where the web module should reside

If this URL is relative, then the base URL is determined by the client
"""
web_module_path(package_name, must_exist=True)
return (
IDOM_CLIENT_IMPORT_SOURCE_URL.get()
Expand All @@ -38,6 +44,7 @@ def web_module_url(package_name: str) -> str:


def web_module_exists(package_name: str) -> bool:
"""Whether a web module with a given name exists"""
return web_module_path(package_name).exists()


Expand Down
78 changes: 36 additions & 42 deletions src/idom/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,50 +92,13 @@ def log_records(self) -> List[logging.LogRecord]:
"""A list of captured log records"""
return self._log_handler.records

def assert_logged_exception(
self,
error_type: Type[Exception],
error_pattern: str,
clear_after: bool = True,
) -> None:
"""Assert that a given error type and message were logged"""
try:
re_pattern = re.compile(error_pattern)
for record in self.log_records:
if record.exc_info is not None:
error = record.exc_info[1]
if isinstance(error, error_type) and re_pattern.search(str(error)):
break
else: # pragma: no cover
assert False, f"did not raise {error_type} matching {error_pattern!r}"
finally:
if clear_after:
self.log_records.clear()

def raise_if_logged_exception(
self,
log_level: int = logging.ERROR,
exclude_exc_types: Union[Type[Exception], Tuple[Type[Exception], ...]] = (),
clear_after: bool = True,
) -> None:
"""Raise the first logged exception (if any)
def url(self, path: str = "", query: Optional[Any] = None) -> str:
"""Return a URL string pointing to the host and point of the server

Args:
log_level: The level of log to check
exclude_exc_types: Any exception types to ignore
clear_after: Whether to clear logs after check
path: the path to a resource on the server
query: a dictionary or list of query parameters
"""
try:
for record in self._log_handler.records:
if record.levelno >= log_level and record.exc_info is not None:
error = record.exc_info[1]
if error is not None and not isinstance(error, exclude_exc_types):
raise error
finally:
if clear_after:
self.log_records.clear()

def url(self, path: str = "", query: Optional[Any] = None) -> str:
return urlunparse(
[
"http",
Expand All @@ -147,6 +110,35 @@ def url(self, path: str = "", query: Optional[Any] = None) -> str:
]
)

def list_logged_exceptions(
self,
pattern: str = "",
types: Union[Type[Any], Tuple[Type[Any], ...]] = Exception,
log_level: int = logging.ERROR,
del_log_records: bool = True,
) -> List[BaseException]:
"""Return a list of logged exception matching the given criteria

Args:
log_level: The level of log to check
exclude_exc_types: Any exception types to ignore
del_log_records: Whether to delete the log records for yielded exceptions
"""
found: List[BaseException] = []
compiled_pattern = re.compile(pattern)
for index, record in enumerate(self.log_records):
if record.levelno >= log_level and record.exc_info is not None:
error = record.exc_info[1]
if (
error is not None
and isinstance(error, types)
and compiled_pattern.search(str(error))
):
if del_log_records:
del self.log_records[index - len(found)]
found.append(error)
return found

def __enter__(self: _Self) -> _Self:
self._log_handler = _LogRecordCaptor()
logging.getLogger().addHandler(self._log_handler)
Expand All @@ -161,8 +153,10 @@ def __exit__(
) -> None:
self.server.stop()
logging.getLogger().removeHandler(self._log_handler)
self.raise_if_logged_exception()
del self.mount, self.server
logged_errors = self.list_logged_exceptions(del_log_records=False)
if logged_errors: # pragma: no cover
raise logged_errors[0]
return None


Expand Down
19 changes: 0 additions & 19 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from __future__ import annotations

import inspect
import logging
import os
from typing import Any, Iterator, List

import pyalect.builtins.pytest # noqa
import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.logging import LogCaptureFixture
from _pytest.logging import caplog as _caplog # noqa
from selenium.webdriver import Chrome, ChromeOptions
from selenium.webdriver.support.ui import WebDriverWait

Expand Down Expand Up @@ -107,22 +104,6 @@ def driver_is_headless(pytestconfig: Config):
return bool(pytestconfig.option.headless)


@pytest.fixture(autouse=True)
def caplog(_caplog: LogCaptureFixture) -> Iterator[LogCaptureFixture]:
_caplog.set_level(logging.DEBUG)
yield _caplog
# check that there are no ERROR level log messages
for record in _caplog.records:
if record.exc_info:
raise record.exc_info[1]
assert record.levelno < logging.ERROR


class _PropogateHandler(logging.Handler):
def emit(self, record):
logging.getLogger(record.name).handle(record)


@pytest.fixture(scope="session", autouse=True)
def _restore_client(pytestconfig: Config) -> Iterator[None]:
"""Restore the client's state before and after testing
Expand Down
23 changes: 18 additions & 5 deletions tests/test_server/test_common/test_shared_state_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,25 @@ def Counter(count):

def test_shared_client_state_server_does_not_support_per_client_parameters(
driver_get,
driver_wait,
server_mount_point,
):
driver_get({"per_client_param": 1})
driver_get(
{
"per_client_param": 1,
# we need to stop reconnect attempts to prevent the error from happening
# more than once
"noReconnect": True,
}
)

server_mount_point.assert_logged_exception(
ValueError,
"does not support per-client view parameters",
clear_after=True,
driver_wait.until(
lambda driver: (
len(
server_mount_point.list_logged_exceptions(
"does not support per-client view parameters", ValueError
)
)
== 1
)
)