From aecc5e6e2f104ec21a504baba245ba40bc4c0c2b Mon Sep 17 00:00:00 2001 From: Elijah Date: Wed, 8 Nov 2023 12:19:37 +0000 Subject: [PATCH 01/24] rx.State as base state --- reflex/app.py | 28 ++++----- reflex/compiler/compiler.py | 9 +-- reflex/compiler/utils.py | 8 +-- reflex/event.py | 4 +- reflex/middleware/hydrate_middleware.py | 4 +- reflex/middleware/middleware.py | 6 +- reflex/state.py | 64 ++++++++++++--------- reflex/testing.py | 8 ++- reflex/vars.py | 14 ++--- reflex/vars.pyi | 5 +- tests/__init__.py | 7 +++ tests/components/base/test_script.py | 4 +- tests/components/datadisplay/conftest.py | 8 +-- tests/components/datadisplay/test_table.py | 4 +- tests/components/forms/test_debounce.py | 2 +- tests/components/layout/test_cond.py | 2 +- tests/components/layout/test_foreach.py | 4 +- tests/components/test_component.py | 6 +- tests/conftest.py | 2 +- tests/middleware/test_hydrate_middleware.py | 10 ++-- tests/states/__init__.py | 4 +- tests/states/mutation.py | 6 +- tests/states/upload.py | 8 +-- tests/test_app.py | 20 +++---- tests/test_event.py | 4 +- tests/test_state.py | 42 +++++++------- tests/test_var.py | 17 +++--- tests/utils/test_utils.py | 4 +- 28 files changed, 164 insertions(+), 140 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index c66a3dfd897..42d9389e2a5 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -57,6 +57,7 @@ verify_route_validity, ) from reflex.state import ( + BaseState, RouterData, State, StateManager, @@ -98,7 +99,7 @@ class App(Base): socket_app: Optional[ASGIApp] = None # The state class to use for the app. - state: Optional[Type[State]] = None + state: Optional[Type[BaseState]] = None # Class to manage many client states. _state_manager: Optional[StateManager] = None @@ -149,25 +150,24 @@ def __init__(self, *args, **kwargs): "`connect_error_component` is deprecated, use `overlay_component` instead" ) super().__init__(*args, **kwargs) - state_subclasses = State.__subclasses__() - inferred_state = state_subclasses[-1] if state_subclasses else None + state_subclasses = BaseState.__subclasses__() is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ - # Special case to allow test cases have multiple subclasses of rx.State. + # Special case to allow test cases have multiple subclasses of rx.BaseState. if not is_testing_env: - # Only one State class is allowed. + # Only one Base State class is allowed. if len(state_subclasses) > 1: raise ValueError( - "rx.State has been subclassed multiple times. Only one subclass is allowed" + "rx.BaseState cannot be subclassed multiple times. use rx.State instead" ) # verify that provided state is valid - if self.state and inferred_state and self.state is not inferred_state: + if self.state and self.state is not State: console.warn( f"Using substate ({self.state.__name__}) as root state in `rx.App` is currently not supported." - f" Defaulting to root state: ({inferred_state.__name__})" + f" Defaulting to root state: ({State.__name__})" ) - self.state = inferred_state + self.state = State # Get the config config = get_config() @@ -265,7 +265,7 @@ def state_manager(self) -> StateManager: raise ValueError("The state manager has not been initialized.") return self._state_manager - async def preprocess(self, state: State, event: Event) -> StateUpdate | None: + async def preprocess(self, state: BaseState, event: Event) -> StateUpdate | None: """Preprocess the event. This is where middleware can modify the event before it is processed. @@ -290,7 +290,7 @@ async def preprocess(self, state: State, event: Event) -> StateUpdate | None: return out # type: ignore async def postprocess( - self, state: State, event: Event, update: StateUpdate + self, state: BaseState, event: Event, update: StateUpdate ) -> StateUpdate: """Postprocess the event. @@ -762,7 +762,7 @@ def submit_work(fn, *args, **kwargs): future.result() @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[State]: + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state out of band. Args: @@ -790,7 +790,9 @@ async def modify_state(self, token: str) -> AsyncIterator[State]: sid=state.router.session.session_id, ) - def _process_background(self, state: State, event: Event) -> asyncio.Task | None: + def _process_background( + self, state: BaseState, event: Event + ) -> asyncio.Task | None: """Process an event in the background and emit updates as they arrive. Args: diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 1e5215872ba..e9eea12e373 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -16,6 +16,7 @@ ) from reflex.config import get_config from reflex.state import State +from reflex.state import BaseState from reflex.utils.imports import ImportVar @@ -63,7 +64,7 @@ def _compile_theme(theme: dict) -> str: return templates.THEME.render(theme=theme) -def _compile_contexts(state: Optional[Type[State]]) -> str: +def _compile_contexts(state: Optional[Type[BaseState]]) -> str: """Compile the initial state and contexts. Args: @@ -87,7 +88,7 @@ def _compile_contexts(state: Optional[Type[State]]) -> str: def _compile_page( component: Component, - state: Type[State], + state: Type[BaseState], ) -> str: """Compile the component given the app state. @@ -337,7 +338,7 @@ def compile_theme(style: ComponentStyle) -> tuple[str, str]: return output_path, code -def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]: +def compile_contexts(state: Optional[Type[BaseState]]) -> tuple[str, str]: """Compile the initial state / context. Args: @@ -353,7 +354,7 @@ def compile_contexts(state: Optional[Type[State]]) -> tuple[str, str]: def compile_page( - path: str, component: Component, state: Type[State] + path: str, component: Component, state: Type[BaseState] ) -> tuple[str, str]: """Compile a single page. diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 3a866051ab9..0a4df7a4ee9 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -21,7 +21,7 @@ Title, ) from reflex.components.component import Component, ComponentStyle, CustomComponent -from reflex.state import Cookie, LocalStorage, State +from reflex.state import BaseState, Cookie, LocalStorage from reflex.style import Style from reflex.utils import console, format, imports, path_ops @@ -128,7 +128,7 @@ def get_import_dict(lib: str, default: str = "", rest: list[str] | None = None) } -def compile_state(state: Type[State]) -> dict: +def compile_state(state: Type[BaseState]) -> dict: """Compile the state of the app. Args: @@ -170,7 +170,7 @@ def _compile_client_storage_field( def _compile_client_storage_recursive( - state: Type[State], + state: Type[BaseState], ) -> tuple[dict[str, dict], dict[str, dict[str, str]]]: """Compile the client-side storage for the given state recursively. @@ -208,7 +208,7 @@ def _compile_client_storage_recursive( return cookies, local_storage -def compile_client_storage(state: Type[State]) -> dict[str, dict]: +def compile_client_storage(state: Type[BaseState]) -> dict[str, dict]: """Compile the client-side storage for the given state. Args: diff --git a/reflex/event.py b/reflex/event.py index cbaf65b0ce1..f0ad0957b0d 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -22,7 +22,7 @@ from reflex.vars import BaseVar, Var if TYPE_CHECKING: - from reflex.state import State + from reflex.state import BaseState class Event(Base): @@ -64,7 +64,7 @@ def background(fn): def _no_chain_background_task( - state_cls: Type["State"], name: str, fn: Callable + state_cls: Type["BaseState"], name: str, fn: Callable ) -> Callable: """Protect against directly chaining a background task from another event handler. diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 38d5fb14f57..6108a90c49b 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -6,7 +6,7 @@ from reflex import constants from reflex.event import Event, fix_events, get_hydrate_event from reflex.middleware.middleware import Middleware -from reflex.state import State, StateUpdate +from reflex.state import BaseState, StateUpdate from reflex.utils import format if TYPE_CHECKING: @@ -17,7 +17,7 @@ class HydrateMiddleware(Middleware): """Middleware to handle initial app hydration.""" async def preprocess( - self, app: App, state: State, event: Event + self, app: App, state: BaseState, event: Event ) -> Optional[StateUpdate]: """Preprocess the event. diff --git a/reflex/middleware/middleware.py b/reflex/middleware/middleware.py index 726d8162175..f522ff86162 100644 --- a/reflex/middleware/middleware.py +++ b/reflex/middleware/middleware.py @@ -6,7 +6,7 @@ from reflex.base import Base from reflex.event import Event -from reflex.state import State, StateUpdate +from reflex.state import BaseState, StateUpdate if TYPE_CHECKING: from reflex.app import App @@ -16,7 +16,7 @@ class Middleware(Base, ABC): """Middleware to preprocess and postprocess requests.""" async def preprocess( - self, app: App, state: State, event: Event + self, app: App, state: BaseState, event: Event ) -> Optional[StateUpdate]: """Preprocess the event. @@ -31,7 +31,7 @@ async def preprocess( return None async def postprocess( - self, app: App, state: State, event: Event, update: StateUpdate + self, app: App, state: BaseState, event: Event, update: StateUpdate ) -> StateUpdate: """Postprocess the event. diff --git a/reflex/state.py b/reflex/state.py index 881a6b47c7b..389cdf0253d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -81,7 +81,7 @@ def __init__(self, router_data: Optional[dict] = None): class PageData(Base): """An object containing page data.""" - host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) + host: str = "" # repeated with self.headers.origin (remove or keep the duplicate?) path: str = "" raw_path: str = "" full_path: str = "" @@ -152,7 +152,7 @@ def __init__(self, router_data: Optional[dict] = None): } -class State(Base, ABC, extra=pydantic.Extra.allow): +class BaseState(Base, ABC, extra=pydantic.Extra.allow): """The state of the app.""" # A map from the var name to the var. @@ -189,10 +189,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow): _always_dirty_substates: ClassVar[Set[str]] = set() # The parent state. - parent_state: Optional[State] = None + parent_state: Optional[BaseState] = None # The substates of the state. - substates: Dict[str, State] = {} + substates: Dict[str, BaseState] = {} # The set of dirty vars. dirty_vars: Set[str] = set() @@ -212,7 +212,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # The hydrated bool. is_hydrated: bool = False - def __init__(self, *args, parent_state: State | None = None, **kwargs): + def __init__(self, *args, parent_state: BaseState | None = None, **kwargs): """Initialize the state. Args: @@ -241,7 +241,7 @@ def __init__(self, *args, parent_state: State | None = None, **kwargs): # Create a fresh copy of the backend variables for this instance self._backend_vars = copy.deepcopy(self.backend_vars) - def _init_event_handlers(self, state: State | None = None): + def _init_event_handlers(self, state: BaseState | None = None): """Initialize event handlers. Allow event handlers to be called directly on the instance. This is @@ -437,7 +437,7 @@ def get_skip_vars(cls) -> set[str]: @classmethod @functools.lru_cache() - def get_parent_state(cls) -> Type[State] | None: + def get_parent_state(cls) -> Type[BaseState] | None: """Get the parent state. Returns: @@ -446,14 +446,14 @@ def get_parent_state(cls) -> Type[State] | None: parent_states = [ base for base in cls.__bases__ - if types._issubclass(base, State) and base is not State + if types._issubclass(base, BaseState) and base is not BaseState ] assert len(parent_states) < 2, "Only one parent state is allowed." return parent_states[0] if len(parent_states) == 1 else None # type: ignore @classmethod @functools.lru_cache() - def get_substates(cls) -> set[Type[State]]: + def get_substates(cls) -> set[Type[BaseState]]: """Get the substates of the state. Returns: @@ -487,7 +487,7 @@ def get_full_name(cls) -> str: @classmethod @functools.lru_cache() - def get_class_substate(cls, path: Sequence[str]) -> Type[State]: + def get_class_substate(cls, path: Sequence[str]) -> Type[BaseState]: """Get the class substate. Args: @@ -643,7 +643,7 @@ def _get_base_functions() -> dict[str, FunctionType]: """ return { func[0]: func[1] - for func in inspect.getmembers(State, predicate=inspect.isfunction) + for func in inspect.getmembers(BaseState, predicate=inspect.isfunction) if not func[0].startswith("__") } @@ -909,7 +909,7 @@ def _reset_client_storage(self): for substate in self.substates.values(): substate._reset_client_storage() - def get_substate(self, path: Sequence[str]) -> State | None: + def get_substate(self, path: Sequence[str]) -> BaseState | None: """Get the substate. Args: @@ -933,7 +933,7 @@ def get_substate(self, path: Sequence[str]) -> State | None: def _get_event_handler( self, event: Event - ) -> tuple[State | StateProxy, EventHandler]: + ) -> tuple[BaseState | StateProxy, EventHandler]: """Get the event handler for the given event. Args: @@ -1050,7 +1050,7 @@ def _as_state_update( ) async def _process_event( - self, handler: EventHandler, state: State | StateProxy, payload: Dict + self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict ) -> AsyncIterator[StateUpdate]: """Process event. @@ -1263,7 +1263,7 @@ def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]: d.update(substate_d) return d - async def __aenter__(self) -> State: + async def __aenter__(self) -> BaseState: """Enter the async context manager protocol. This should not be used for the State class, but exists for @@ -1288,6 +1288,12 @@ async def __aexit__(self, *exc_info: Any) -> None: pass +class State(BaseState): + """The app Base State.""" + + is_hydrated: bool = False + + class StateProxy(wrapt.ObjectProxy): """Proxy of a state instance to control mutability of vars for a background task. @@ -1455,10 +1461,10 @@ class StateManager(Base, ABC): """A class to manage many client states.""" # The state class to use. - state: Type[State] + state: Type[BaseState] @classmethod - def create(cls, state: Type[State]): + def create(cls, state: Type[BaseState]): """Create a new state manager. Args: @@ -1473,7 +1479,7 @@ def create(cls, state: Type[State]): return StateManagerMemory(state=state) @abstractmethod - async def get_state(self, token: str) -> State: + async def get_state(self, token: str) -> BaseState: """Get the state for a token. Args: @@ -1485,7 +1491,7 @@ async def get_state(self, token: str) -> State: pass @abstractmethod - async def set_state(self, token: str, state: State): + async def set_state(self, token: str, state: BaseState): """Set the state for a token. Args: @@ -1496,7 +1502,7 @@ async def set_state(self, token: str, state: State): @abstractmethod @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[State]: + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. Args: @@ -1512,7 +1518,7 @@ class StateManagerMemory(StateManager): """A state manager that stores states in memory.""" # The mapping of client ids to states. - states: Dict[str, State] = {} + states: Dict[str, BaseState] = {} # The mutex ensures the dict of mutexes is updated exclusively _state_manager_lock = asyncio.Lock() @@ -1527,7 +1533,7 @@ class Config: "_states_locks": {"exclude": True}, } - async def get_state(self, token: str) -> State: + async def get_state(self, token: str) -> BaseState: """Get the state for a token. Args: @@ -1540,7 +1546,7 @@ async def get_state(self, token: str) -> State: self.states[token] = self.state() return self.states[token] - async def set_state(self, token: str, state: State): + async def set_state(self, token: str, state: BaseState): """Set the state for a token. Args: @@ -1550,7 +1556,7 @@ async def set_state(self, token: str, state: State): pass @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[State]: + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. Args: @@ -1598,7 +1604,7 @@ class StateManagerRedis(StateManager): b"evicted", } - async def get_state(self, token: str) -> State: + async def get_state(self, token: str) -> BaseState: """Get the state for a token. Args: @@ -1613,7 +1619,9 @@ async def get_state(self, token: str) -> State: return await self.get_state(token) return cloudpickle.loads(redis_state) - async def set_state(self, token: str, state: State, lock_id: bytes | None = None): + async def set_state( + self, token: str, state: BaseState, lock_id: bytes | None = None + ): """Set the state for a token. Args: @@ -1637,7 +1645,7 @@ async def set_state(self, token: str, state: State, lock_id: bytes | None = None await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration) @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[State]: + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. Args: @@ -1879,7 +1887,7 @@ class MutableProxy(wrapt.ObjectProxy): __mutable_types__ = (list, dict, set, Base) - def __init__(self, wrapped: Any, state: State, field_name: str): + def __init__(self, wrapped: Any, state: BaseState, field_name: str): """Create a proxy for a mutable object that tracks changes. Args: diff --git a/reflex/testing.py b/reflex/testing.py index a6bb3d400d2..b2ba8bd49a8 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -38,7 +38,7 @@ import reflex.utils.exec import reflex.utils.prerequisites import reflex.utils.processes -from reflex.state import State, StateManagerMemory, StateManagerRedis +from reflex.state import BaseState, State, StateManagerMemory, StateManagerRedis try: from selenium import webdriver # pyright: ignore [reportMissingImports] @@ -434,7 +434,7 @@ def frontend(self, driver_clz: Optional[Type["WebDriver"]] = None) -> "WebDriver self._frontends.append(driver) return driver - async def get_state(self, token: str) -> State: + async def get_state(self, token: str) -> BaseState: """Get the state associated with the given token. Args: @@ -561,7 +561,9 @@ def poll_for_value( ) return element.get_attribute("value") - def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, reflex.State]: + def poll_for_clients( + self, timeout: TimeoutType = None + ) -> dict[str, reflex.BaseState]: """Poll app state_manager for any connected clients. Args: diff --git a/reflex/vars.py b/reflex/vars.py index caec26f87a2..6bfd6c57dd6 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -41,7 +41,7 @@ from reflex.utils.imports import ImportDict, ImportVar if TYPE_CHECKING: - from reflex.state import State + from reflex.state import BaseState # Set of unique variable names. USED_VARIABLES = set() @@ -1472,7 +1472,7 @@ def _var_full_name(self) -> str: ) ) - def _var_set_state(self, state: Type[State] | str) -> Any: + def _var_set_state(self, state: Type[BaseState] | str) -> Any: """Set the state of the var. Args: @@ -1604,14 +1604,14 @@ def get_setter_name(self, include_state: bool = True) -> str: return setter return ".".join((self._var_data.state, setter)) - def get_setter(self) -> Callable[[State, Any], None]: + def get_setter(self) -> Callable[[BaseState, Any], None]: """Get the var's setter function. Returns: A function that that creates a setter for the var. """ - def setter(state: State, value: Any): + def setter(state: BaseState, value: Any): """Get the setter for the var. Args: @@ -1643,9 +1643,9 @@ class ComputedVar(Var, property): def __init__( self, - fget: Callable[[State], Any], - fset: Callable[[State, Any], None] | None = None, - fdel: Callable[[State], Any] | None = None, + fget: Callable[[BaseState], Any], + fset: Callable[[BaseState, Any], None] | None = None, + fdel: Callable[[BaseState], Any] | None = None, doc: str | None = None, **kwargs, ): diff --git a/reflex/vars.pyi b/reflex/vars.pyi index c4c96af90f8..6ec1bd98798 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -5,6 +5,7 @@ from _typeshed import Incomplete from reflex import constants as constants from reflex.base import Base as Base from reflex.state import State as State +from reflex.state import BaseState as BaseState from reflex.utils import console as console, format as format, types as types from reflex.utils.imports import ImportVar from types import FunctionType @@ -110,7 +111,7 @@ class Var: def as_ref(self) -> Var: ... @property def _var_full_name(self) -> str: ... - def _var_set_state(self, state: Type[State] | str) -> Any: ... + def _var_set_state(self, state: Type[BaseState] | str) -> Any: ... @dataclass(eq=False) class BaseVar(Var): @@ -123,7 +124,7 @@ class BaseVar(Var): def __hash__(self) -> int: ... def get_default_value(self) -> Any: ... def get_setter_name(self, include_state: bool = ...) -> str: ... - def get_setter(self) -> Callable[[State, Any], None]: ... + def get_setter(self) -> Callable[[BaseState, Any], None]: ... @dataclass(init=False) class ComputedVar(Var): diff --git a/tests/__init__.py b/tests/__init__.py index b318bba6307..41bc9f99857 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1,8 @@ """Root directory for tests.""" +import os + +from reflex import constants + +os.environ[constants.PYTEST_CURRENT_TEST] = "true" +print("Gegining testing-------------\n") +print(constants.PYTEST_CURRENT_TEST in os.environ) diff --git a/tests/components/base/test_script.py b/tests/components/base/test_script.py index cc16ab7185a..7ccdc06346b 100644 --- a/tests/components/base/test_script.py +++ b/tests/components/base/test_script.py @@ -2,7 +2,7 @@ import pytest from reflex.components.base.script import Script -from reflex.state import State +from reflex.state import BaseState def test_script_inline(): @@ -31,7 +31,7 @@ def test_script_neither(): Script.create() -class EvState(State): +class EvState(BaseState): """State for testing event handlers.""" def on_ready(self): diff --git a/tests/components/datadisplay/conftest.py b/tests/components/datadisplay/conftest.py index c0e61c437d9..18eafb17c24 100644 --- a/tests/components/datadisplay/conftest.py +++ b/tests/components/datadisplay/conftest.py @@ -18,7 +18,7 @@ def data_table_state(request): The data table state class. """ - class DataTableState(rx.State): + class DataTableState(rx.BaseState): data = request.param["data"] columns = ["column1", "column2"] @@ -33,7 +33,7 @@ def data_table_state2(): The data table state class. """ - class DataTableState(rx.State): + class DataTableState(rx.BaseState): _data = pd.DataFrame() @rx.var @@ -51,7 +51,7 @@ def data_table_state3(): The data table state class. """ - class DataTableState(rx.State): + class DataTableState(rx.BaseState): _data: List = [] _columns: List = ["col1", "col2"] @@ -74,7 +74,7 @@ def data_table_state4(): The data table state class. """ - class DataTableState(rx.State): + class DataTableState(rx.BaseState): _data: List = [] _columns: List = ["col1", "col2"] diff --git a/tests/components/datadisplay/test_table.py b/tests/components/datadisplay/test_table.py index 94e39f6e5f4..212f1fb59f7 100644 --- a/tests/components/datadisplay/test_table.py +++ b/tests/components/datadisplay/test_table.py @@ -4,12 +4,12 @@ import pytest from reflex.components.datadisplay.table import Tbody, Tfoot, Thead -from reflex.state import State +from reflex.state import BaseState PYTHON_GT_V38 = sys.version_info.major >= 3 and sys.version_info.minor > 8 -class TableState(State): +class TableState(BaseState): """Test State class.""" rows_List_List_str: List[List[str]] = [["random", "row"]] diff --git a/tests/components/forms/test_debounce.py b/tests/components/forms/test_debounce.py index 1965fabf586..91a2d938e53 100644 --- a/tests/components/forms/test_debounce.py +++ b/tests/components/forms/test_debounce.py @@ -24,7 +24,7 @@ def test_render_many_child(): _ = rx.debounce_input("foo", "bar").render() -class S(rx.State): +class S(rx.BaseState): """Example state for debounce tests.""" value: str = "" diff --git a/tests/components/layout/test_cond.py b/tests/components/layout/test_cond.py index 00cf4de7d53..27f204e4cbf 100644 --- a/tests/components/layout/test_cond.py +++ b/tests/components/layout/test_cond.py @@ -20,7 +20,7 @@ @pytest.fixture def cond_state(request): - class CondState(rx.State): + class CondState(rx.BaseState): value: request.param["value_type"] = request.param["value"] # noqa return CondState diff --git a/tests/components/layout/test_foreach.py b/tests/components/layout/test_foreach.py index aacdf363826..71ae36b23a2 100644 --- a/tests/components/layout/test_foreach.py +++ b/tests/components/layout/test_foreach.py @@ -4,10 +4,10 @@ from reflex.components import box, foreach, text from reflex.components.layout import Foreach -from reflex.state import State +from reflex.state import BaseState -class ForEachState(State): +class ForEachState(BaseState): """A state for testing the ForEach component.""" colors_list: List[str] = ["red", "yellow"] diff --git a/tests/components/test_component.py b/tests/components/test_component.py index d90efd658d4..a576f04842b 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -15,6 +15,8 @@ from reflex.constants import EventTriggers from reflex.event import EventChain, EventHandler from reflex.state import State +from reflex.event import EventHandler +from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports from reflex.utils.imports import ImportVar @@ -23,7 +25,7 @@ @pytest.fixture def test_state(): - class TestState(State): + class TestState(BaseState): num: int def do_something(self): @@ -400,7 +402,7 @@ def test_get_event_triggers(component1, component2): ) -class C1State(State): +class C1State(BaseState): """State for testing C1 component.""" def mock_handler(self, _e, _bravo, _charlie): diff --git a/tests/conftest.py b/tests/conftest.py index d2dd301f468..040b1221918 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -235,7 +235,7 @@ def duplicate_substate(): The test state. """ - class TestState(rx.State): + class TestState(rx.BaseState): pass class ChildTestState(TestState): # type: ignore # noqa diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 7767dcf8b2d..357aa4401d0 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -5,10 +5,10 @@ from reflex.app import App from reflex.constants import CompileVars from reflex.middleware.hydrate_middleware import HydrateMiddleware -from reflex.state import State, StateUpdate +from reflex.state import BaseState, StateUpdate -def exp_is_hydrated(state: State) -> Dict[str, Any]: +def exp_is_hydrated(state: BaseState) -> Dict[str, Any]: """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. Args: @@ -20,7 +20,7 @@ def exp_is_hydrated(state: State) -> Dict[str, Any]: return {state.get_name(): {CompileVars.IS_HYDRATED: True}} -class TestState(State): +class TestState(BaseState): """A test state with no return in handler.""" __test__ = False @@ -32,7 +32,7 @@ def test_handler(self): self.num += 1 -class TestState2(State): +class TestState2(BaseState): """A test state with return in handler.""" __test__ = False @@ -54,7 +54,7 @@ def change_name(self): self.name = "random" -class TestState3(State): +class TestState3(BaseState): """A test state with async handler.""" __test__ = False diff --git a/tests/states/__init__.py b/tests/states/__init__.py index 2c007172a67..d8baf53e4e1 100644 --- a/tests/states/__init__.py +++ b/tests/states/__init__.py @@ -1,4 +1,4 @@ -"""Common rx.State subclasses for use in tests.""" +"""Common rx.BaseState subclasses for use in tests.""" import reflex as rx from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState @@ -12,7 +12,7 @@ ) -class GenState(rx.State): +class GenState(rx.BaseState): """A state with event handlers that generate multiple updates.""" value: int diff --git a/tests/states/mutation.py b/tests/states/mutation.py index b3d98301ffc..f6d050295a1 100644 --- a/tests/states/mutation.py +++ b/tests/states/mutation.py @@ -5,7 +5,7 @@ import reflex as rx -class DictMutationTestState(rx.State): +class DictMutationTestState(rx.BaseState): """A state for testing ReflexDict mutation.""" # plain dict @@ -62,7 +62,7 @@ def add_friend_age(self): self.friend_in_nested_dict["friend"]["age"] = 30 -class ListMutationTestState(rx.State): +class ListMutationTestState(rx.BaseState): """A state for testing ReflexList mutation.""" # plain list @@ -144,7 +144,7 @@ class CustomVar(rx.Base): custom: OtherBase = OtherBase() -class MutableTestState(rx.State): +class MutableTestState(rx.BaseState): """A test state.""" array: List[Union[str, List, Dict[str, str]]] = [ diff --git a/tests/states/upload.py b/tests/states/upload.py index ec2585dd161..2f5ba950dcb 100644 --- a/tests/states/upload.py +++ b/tests/states/upload.py @@ -5,7 +5,7 @@ import reflex as rx -class UploadState(rx.State): +class UploadState(rx.BaseState): """The base state for uploading a file.""" async def handle_upload1(self, files: List[rx.UploadFile]): @@ -17,7 +17,7 @@ async def handle_upload1(self, files: List[rx.UploadFile]): pass -class BaseState(rx.State): +class BaseState(rx.BaseState): """The test base state.""" pass @@ -37,7 +37,7 @@ async def handle_upload(self, files: List[rx.UploadFile]): pass -class FileUploadState(rx.State): +class FileUploadState(rx.BaseState): """The base state for uploading a file.""" img_list: List[str] @@ -79,7 +79,7 @@ async def bg_upload(self, files: List[rx.UploadFile]): pass -class FileStateBase1(rx.State): +class FileStateBase1(rx.BaseState): """The base state for a child FileUploadState.""" pass diff --git a/tests/test_app.py b/tests/test_app.py index d0f53bec7f9..787eeedb691 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -28,7 +28,7 @@ from reflex.event import Event, get_hydrate_event from reflex.middleware import HydrateMiddleware from reflex.model import Model -from reflex.state import RouterData, State, StateManagerRedis, StateUpdate +from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate from reflex.style import Style from reflex.utils import format from reflex.vars import ComputedVar @@ -43,7 +43,7 @@ ) -class EmptyState(State): +class EmptyState(BaseState): """An empty state.""" pass @@ -77,14 +77,14 @@ def about(): return about -class ATestState(State): +class ATestState(BaseState): """A simple state for testing.""" var: int @pytest.fixture() -def test_state() -> Type[State]: +def test_state() -> Type[BaseState]: """A default state. Returns: @@ -94,14 +94,14 @@ def test_state() -> Type[State]: @pytest.fixture() -def redundant_test_state() -> Type[State]: +def redundant_test_state() -> Type[BaseState]: """A default state. Returns: A default state. """ - class RedundantTestState(State): + class RedundantTestState(BaseState): var: int return RedundantTestState @@ -198,12 +198,12 @@ def test_default_app(app: App): def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): - """Test that an error is thrown when multiple classes subclass rx.State. + """Test that an error is thrown when multiple classes subclass rx.BaseState. Args: monkeypatch: Pytest monkeypatch object. - test_state: A test state subclassing rx.State. - redundant_test_state: Another test state subclassing rx.State. + test_state: A test state subclassing rx.BaseState. + redundant_test_state: Another test state subclassing rx.BaseState. """ monkeypatch.delenv(constants.PYTEST_CURRENT_TEST) with pytest.raises(ValueError): @@ -851,7 +851,7 @@ async def test_upload_file_background(state, tmp_path, token): await app.state_manager.redis.close() -class DynamicState(State): +class DynamicState(BaseState): """State class for testing dynamic route var. This is defined at module level because event handlers cannot be addressed diff --git a/tests/test_event.py b/tests/test_event.py index 284012b1311..2326f092026 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -4,7 +4,7 @@ from reflex import event from reflex.event import Event, EventHandler, EventSpec, fix_events -from reflex.state import State +from reflex.state import BaseState from reflex.utils import format from reflex.vars import Var @@ -303,7 +303,7 @@ def test_event_actions(): def test_event_actions_on_state(): - class EventActionState(State): + class EventActionState(BaseState): def handler(self): pass diff --git a/tests/test_state.py b/tests/test_state.py index 591026d2cc1..8879af67bfe 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -19,11 +19,11 @@ from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.event import Event, EventHandler from reflex.state import ( + BaseState, ImmutableStateError, LockExpiredError, MutableProxy, RouterData, - State, StateManager, StateManagerMemory, StateManagerRedis, @@ -75,7 +75,7 @@ class Object(Base): prop2: str = "hello" -class TestState(State): +class TestState(BaseState): """A test state.""" # Set this class as not test one @@ -148,7 +148,7 @@ def do_nothing(self): pass -class DateTimeState(State): +class DateTimeState(BaseState): """A State with some datetime fields.""" d: datetime.date = datetime.date.fromisoformat("1989-11-09") @@ -837,7 +837,7 @@ def test_get_query_params(test_state): def test_add_var(): - class DynamicState(State): + class DynamicState(BaseState): pass ds1 = DynamicState() @@ -870,7 +870,7 @@ def test_add_var_default_handlers(test_state): assert isinstance(test_state.event_handlers["set_rand_int"], EventHandler) -class InterdependentState(State): +class InterdependentState(BaseState): """A state with 3 vars and 3 computed vars. x: a variable that no computed var depends on @@ -915,7 +915,7 @@ def v1x2x2(self) -> int: @pytest.fixture -def interdependent_state() -> State: +def interdependent_state() -> BaseState: """A state with varying dependency between vars. Returns: @@ -988,7 +988,7 @@ def test_per_state_backend_var(interdependent_state): def test_child_state(): """Test that the child state computed vars can reference parent state vars.""" - class MainState(State): + class MainState(BaseState): v: int = 2 class ChildState(MainState): @@ -1006,7 +1006,7 @@ def rendered_var(self): def test_conditional_computed_vars(): """Test that computed vars can have conditionals.""" - class MainState(State): + class MainState(BaseState): flag: bool = False t1: str = "a" t2: str = "b" @@ -1051,7 +1051,7 @@ def test_event_handlers_convert_to_fns(test_state, child_state): def test_event_handlers_call_other_handlers(): """Test that event handlers can call other event handlers.""" - class MainState(State): + class MainState(BaseState): v: int = 0 def set_v(self, v: int): @@ -1077,7 +1077,7 @@ def test_computed_var_cached(): """Test that a ComputedVar doesn't recalculate when accessed.""" comp_v_calls = 0 - class ComputedState(State): + class ComputedState(BaseState): v: int = 0 @rx.cached_var @@ -1102,7 +1102,7 @@ def comp_v(self) -> int: def test_computed_var_cached_depends_on_non_cached(): """Test that a cached_var is recalculated if it depends on non-cached ComputedVar.""" - class ComputedState(State): + class ComputedState(BaseState): v: int = 0 @rx.var @@ -1144,7 +1144,7 @@ def test_computed_var_depends_on_parent_non_cached(): """Child state cached_var that depends on parent state un cached var is always recalculated.""" counter = 0 - class ParentState(State): + class ParentState(BaseState): @rx.var def no_cache_v(self) -> int: nonlocal counter @@ -1195,7 +1195,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): """ counter = 0 - class HandlerState(State): + class HandlerState(BaseState): x: int = 42 def handler(self): @@ -1226,7 +1226,7 @@ def cached_x_side_effect(self) -> int: def test_computed_var_dependencies(): """Test that a ComputedVar correctly tracks its dependencies.""" - class ComputedState(State): + class ComputedState(BaseState): v: int = 0 w: int = 0 x: int = 0 @@ -1293,7 +1293,7 @@ def comp_z(self) -> List[bool]: def test_backend_method(): """A method with leading underscore should be callable from event handler.""" - class BackendMethodState(State): + class BackendMethodState(BaseState): def _be_method(self): return True @@ -1369,7 +1369,7 @@ def test_error_on_state_method_shadow(): """Test that an error is thrown when an event handler shadows a state method.""" with pytest.raises(NameError) as err: - class InvalidTest(rx.State): + class InvalidTest(rx.BaseState): def reset(self): pass @@ -1382,7 +1382,7 @@ def reset(self): def test_state_with_invalid_yield(): """Test that an error is thrown when a state yields an invalid value.""" - class StateWithInvalidYield(rx.State): + class StateWithInvalidYield(rx.BaseState): """A state that yields an invalid value.""" def invalid_handler(self): @@ -1666,7 +1666,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): assert mcall.kwargs["to"] == grandchild_state.get_sid() -class BackgroundTaskState(State): +class BackgroundTaskState(BaseState): """A state with a background task.""" order: List[str] = [] @@ -2206,7 +2206,7 @@ class Foo(Base): def test_json_dumps_with_mutables(): """Test that json.dumps works with Base vars inside mutable types.""" - class MutableContainsBase(State): + class MutableContainsBase(BaseState): items: List[Foo] = [Foo()] dict_val = MutableContainsBase().dict() @@ -2225,7 +2225,7 @@ def test_reset_with_mutables(): default = [[0, 0], [0, 1], [1, 1]] copied_default = copy.deepcopy(default) - class MutableResetState(State): + class MutableResetState(BaseState): items: List[List[int]] = default instance = MutableResetState() @@ -2273,7 +2273,7 @@ class Custom3(Base): def test_state_union_optional(): """Test that state can be defined with Union and Optional vars.""" - class UnionState(State): + class UnionState(BaseState): int_float: Union[int, float] = 0 opt_int: Optional[int] c3: Optional[Custom3] diff --git a/tests/test_var.py b/tests/test_var.py index c57c7eb235e..787915a9662 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -7,6 +7,7 @@ from reflex.base import Base from reflex.state import State +from reflex.state import BaseState from reflex.vars import ( BaseVar, ComputedVar, @@ -24,10 +25,10 @@ ] -class BaseState(State): - """A Test State.""" - - val: str = "key" +# class BaseState(State): +# """A Test State.""" +# +# val: str = "key" @pytest.fixture @@ -41,7 +42,7 @@ class TestObj(Base): @pytest.fixture def ParentState(TestObj): - class ParentState(State): + class ParentState(BaseState): foo: int bar: int @@ -74,7 +75,7 @@ def var_without_annotation(self): @pytest.fixture def StateWithAnyVar(TestObj): - class StateWithAnyVar(State): + class StateWithAnyVar(BaseState): @ComputedVar def var_without_annotation(self) -> typing.Any: return TestObj @@ -84,7 +85,7 @@ def var_without_annotation(self) -> typing.Any: @pytest.fixture def StateWithCorrectVarAnnotation(): - class StateWithCorrectVarAnnotation(State): + class StateWithCorrectVarAnnotation(BaseState): @ComputedVar def var_with_annotation(self) -> str: return "Correct annotation" @@ -94,7 +95,7 @@ def var_with_annotation(self) -> str: @pytest.fixture def StateWithWrongVarAnnotation(TestObj): - class StateWithWrongVarAnnotation(State): + class StateWithWrongVarAnnotation(BaseState): @ComputedVar def var_with_annotation(self) -> str: return TestObj diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index b64f5f1889f..82085dd742c 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -10,7 +10,7 @@ from reflex import constants from reflex.base import Base from reflex.event import EventHandler -from reflex.state import State +from reflex.state import BaseState from reflex.utils import ( build, prerequisites, @@ -43,7 +43,7 @@ def get_above_max_version(): VMAXPLUS1 = version.parse(get_above_max_version()) -class ExampleTestState(State): +class ExampleTestState(BaseState): """Test state class.""" def test_event_handler(self): From 18b6447fbae11a6ce37aab60aba15f04ddf00d0a Mon Sep 17 00:00:00 2001 From: Elijah Date: Fri, 10 Nov 2023 13:56:20 +0000 Subject: [PATCH 02/24] fix app harness tests --- integration/test_background_task.py | 2 +- integration/test_call_script.py | 2 +- integration/test_client_storage.py | 44 +++++++++++++-------------- integration/test_connection_banner.py | 2 +- integration/test_dynamic_routes.py | 2 +- integration/test_event_actions.py | 2 +- integration/test_event_chain.py | 14 ++++----- integration/test_form_submit.py | 4 +-- integration/test_input.py | 12 ++++---- integration/test_login_flow.py | 4 +-- integration/test_radix_themes.py | 2 +- integration/test_server_side_event.py | 2 +- integration/test_table.py | 2 +- integration/test_upload.py | 6 ++-- integration/test_var_operations.py | 2 +- 15 files changed, 51 insertions(+), 51 deletions(-) diff --git a/integration/test_background_task.py b/integration/test_background_task.py index 493dc518016..bc70ff01fba 100644 --- a/integration/test_background_task.py +++ b/integration/test_background_task.py @@ -93,7 +93,7 @@ def index() -> rx.Component: rx.button("Reset", on_click=State.reset_counter, id="reset"), ) - app = rx.App(state=State) + app = rx.App(state=rx.State) app.add_page(index) app.compile() diff --git a/integration/test_call_script.py b/integration/test_call_script.py index 8b24ce19b7c..95bdf37c998 100644 --- a/integration/test_call_script.py +++ b/integration/test_call_script.py @@ -135,7 +135,7 @@ def reset_(self): yield rx.call_script("inline_counter = 0; external_counter = 0") self.reset() - app = rx.App(state=CallScriptState) + app = rx.App(state=rx.State) with open("assets/external.js", "w") as f: f.write(external_scripts) diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 90b98f16dcf..5efea8411de 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -97,7 +97,7 @@ def index(): rx.box(ClientSideSubSubState.l1s, id="l1s"), ) - app = rx.App(state=ClientSideState) + app = rx.App(state=rx.State) app.add_page(index) app.add_page(index, route="/foo") app.compile() @@ -285,28 +285,28 @@ async def test_client_side_state( set_sub_sub_state_button.click() exp_cookies = { - "client_side_state.client_side_sub_state.c1": { + "state.client_side_state.client_side_sub_state.c1": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c1", + "name": "state.client_side_state.client_side_sub_state.c1", "path": "/", "sameSite": "Lax", "secure": False, "value": "c1%20value", }, - "client_side_state.client_side_sub_state.c2": { + "state.client_side_state.client_side_sub_state.c2": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c2", + "name": "state.client_side_state.client_side_sub_state.c2", "path": "/", "sameSite": "Lax", "secure": False, "value": "c2%20value", }, - "client_side_state.client_side_sub_state.c4": { + "state.client_side_state.client_side_sub_state.c4": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c4", + "name": "state.client_side_state.client_side_sub_state.c4", "path": "/", "sameSite": "Strict", "secure": False, @@ -321,19 +321,19 @@ async def test_client_side_state( "secure": False, "value": "c6%20value", }, - "client_side_state.client_side_sub_state.c7": { + "state.client_side_state.client_side_sub_state.c7": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c7", + "name": "state.client_side_state.client_side_sub_state.c7", "path": "/", "sameSite": "Lax", "secure": False, "value": "c7%20value", }, - "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": { + "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s": { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s", + "name": "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.c1s", "path": "/", "sameSite": "Lax", "secure": False, @@ -354,40 +354,40 @@ async def test_client_side_state( input_value_input.send_keys("c3 value") set_sub_state_button.click() AppHarness._poll_for( - lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver) + lambda: "state.client_side_state.client_side_sub_state.c3" in cookie_info_map(driver) ) - c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"] + c3_cookie = cookie_info_map(driver)["state.client_side_state.client_side_sub_state.c3"] assert c3_cookie.pop("expiry") is not None assert c3_cookie == { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c3", + "name": "state.client_side_state.client_side_sub_state.c3", "path": "/", "sameSite": "Lax", "secure": False, "value": "c3%20value", } time.sleep(2) # wait for c3 to expire - assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver) + assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver) local_storage_items = local_storage.items() local_storage_items.pop("chakra-ui-color-mode", None) assert ( - local_storage_items.pop("client_side_state.client_side_sub_state.l1") + local_storage_items.pop("state.client_side_state.client_side_sub_state.l1") == "l1 value" ) assert ( - local_storage_items.pop("client_side_state.client_side_sub_state.l2") + local_storage_items.pop("state.client_side_state.client_side_sub_state.l2") == "l2 value" ) assert local_storage_items.pop("l3") == "l3 value" assert ( - local_storage_items.pop("client_side_state.client_side_sub_state.l4") + local_storage_items.pop("state.client_side_state.client_side_sub_state.l4") == "l4 value" ) assert ( local_storage_items.pop( - "client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s" + "state.client_side_state.client_side_sub_state.client_side_sub_sub_state.l1s" ) == "l1s value" ) @@ -482,12 +482,12 @@ async def test_client_side_state( # make sure c5 cookie shows up on the `/foo` route AppHarness._poll_for( - lambda: "client_side_state.client_side_sub_state.c5" in cookie_info_map(driver) + lambda: "state.client_side_state.client_side_sub_state.c5" in cookie_info_map(driver) ) - assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == { + assert cookie_info_map(driver)["state.client_side_state.client_side_sub_state.c5"] == { "domain": "localhost", "httpOnly": False, - "name": "client_side_state.client_side_sub_state.c5", + "name": "state.client_side_state.client_side_sub_state.c5", "path": "/foo/", "sameSite": "Lax", "secure": False, diff --git a/integration/test_connection_banner.py b/integration/test_connection_banner.py index c078df1f70a..0468e6b535f 100644 --- a/integration/test_connection_banner.py +++ b/integration/test_connection_banner.py @@ -19,7 +19,7 @@ class State(rx.State): def index(): return rx.text("Hello World") - app = rx.App(state=State) + app = rx.App(state=rx.State) app.add_page(index) app.compile() diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index 5ab11b69359..ff0afdebabc 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -56,7 +56,7 @@ def index(): def redirect_page(): return rx.fragment(rx.text("redirecting...")) - app = rx.App(state=DynamicState) + app = rx.App(state=rx.State ) app.add_page(index) app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore diff --git a/integration/test_event_actions.py b/integration/test_event_actions.py index 8f5e0788b44..9bb0ab1893e 100644 --- a/integration/test_event_actions.py +++ b/integration/test_event_actions.py @@ -130,7 +130,7 @@ def index(): on_click=EventActionState.on_click("outer"), # type: ignore ) - app = rx.App(state=EventActionState) + app = rx.App(state=rx.State) app.add_page(index) app.compile() diff --git a/integration/test_event_chain.py b/integration/test_event_chain.py index 5fbf7cc14b2..b89df7c1b09 100644 --- a/integration/test_event_chain.py +++ b/integration/test_event_chain.py @@ -122,7 +122,7 @@ def click_yield_interim_value(self): time.sleep(0.5) self.interim_value = "final" - app = rx.App(state=State) + app = rx.App(state=rx.State) token_input = rx.input( value=State.router.session.client_token, is_read_only=True, id="token" @@ -401,12 +401,12 @@ async def test_event_chain_click( btn.click() async def _has_all_events(): - return len((await event_chain.get_state(token)).event_order) == len( + return len((await event_chain.get_state(token)).substates["state"].event_order) == len( exp_event_order ) await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).event_order + event_order = (await event_chain.get_state(token)).substates["state"].event_order assert event_order == exp_event_order @@ -453,12 +453,12 @@ async def test_event_chain_on_load( token = assert_token(event_chain, driver) async def _has_all_events(): - return len((await event_chain.get_state(token)).event_order) == len( + return len((await event_chain.get_state(token)).substates["state"].event_order) == len( exp_event_order ) await AppHarness._poll_for_async(_has_all_events) - backend_state = await event_chain.get_state(token) + backend_state = (await event_chain.get_state(token)).substates["state"] assert backend_state.event_order == exp_event_order assert backend_state.is_hydrated is True @@ -529,12 +529,12 @@ async def test_event_chain_on_mount( unmount_button.click() async def _has_all_events(): - return len((await event_chain.get_state(token)).event_order) == len( + return len((await event_chain.get_state(token)).substates["state"].event_order) == len( exp_event_order ) await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).event_order + event_order = (await event_chain.get_state(token)).substates["state"].event_order assert event_order == exp_event_order diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py index 4a9f6c2d10a..c5b9f795dd3 100644 --- a/integration/test_form_submit.py +++ b/integration/test_form_submit.py @@ -22,7 +22,7 @@ class FormState(rx.State): def form_submit(self, form_data: dict): self.form_data = form_data - app = rx.App(state=FormState) + app = rx.App(state=rx.State) @app.add_page def index(): @@ -210,7 +210,7 @@ async def test_submit(driver, form_submit: AppHarness): submit_input.click() async def get_form_data(): - return (await form_submit.get_state(token)).form_data + return (await form_submit.get_state(token)).substates["form_state"].form_data # wait for the form data to arrive at the backend form_data = await AppHarness._poll_for_async(get_form_data) diff --git a/integration/test_input.py b/integration/test_input.py index d24f7681582..3c75e2ff169 100644 --- a/integration/test_input.py +++ b/integration/test_input.py @@ -16,7 +16,7 @@ def FullyControlledInput(): class State(rx.State): text: str = "initial" - app = rx.App(state=State) + app = rx.App(state=rx.State) @app.add_page def index(): @@ -85,13 +85,13 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): debounce_input.send_keys("foo") time.sleep(0.5) assert debounce_input.get_attribute("value") == "ifoonitial" - assert (await fully_controlled_input.get_state(token)).text == "ifoonitial" + assert (await fully_controlled_input.get_state(token)).substates["state"].text == "ifoonitial" assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial" # clear the input on the backend async with fully_controlled_input.modify_state(token) as state: - state.text = "" - assert (await fully_controlled_input.get_state(token)).text == "" + state.substates["state"].text = "" + assert (await fully_controlled_input.get_state(token)).substates["state"].text=="" assert ( fully_controlled_input.poll_for_value( debounce_input, exp_not_equal="ifoonitial" @@ -105,7 +105,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): assert debounce_input.get_attribute("value") == "getting testing done" assert ( await fully_controlled_input.get_state(token) - ).text == "getting testing done" + ).substates["state"].text == "getting testing done" assert fully_controlled_input.poll_for_value(value_input) == "getting testing done" # type into the on_change input @@ -113,7 +113,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): time.sleep(0.5) assert debounce_input.get_attribute("value") == "overwrite the state" assert on_change_input.get_attribute("value") == "overwrite the state" - assert (await fully_controlled_input.get_state(token)).text == "overwrite the state" + assert (await fully_controlled_input.get_state(token)).substates["state"].text == "overwrite the state" assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state" clear_button.click() diff --git a/integration/test_login_flow.py b/integration/test_login_flow.py index 68e0a864d95..f5363574378 100644 --- a/integration/test_login_flow.py +++ b/integration/test_login_flow.py @@ -42,7 +42,7 @@ def login(): rx.button("Do it", on_click=State.login, id="doit"), ) - app = rx.App(state=State) + app = rx.App(state=rx.State) app.add_page(index) app.add_page(login) app.compile() @@ -137,6 +137,6 @@ def check_auth_token_header(): logout_button = driver.find_element(By.ID, "logout") logout_button.click() - assert login_sample._poll_for(lambda: local_storage["state.auth_token"] == "") + assert login_sample._poll_for(lambda: local_storage["state.state.auth_token"] == "") with pytest.raises(NoSuchElementException): driver.find_element(By.ID, "auth-token") diff --git a/integration/test_radix_themes.py b/integration/test_radix_themes.py index 8f07786c59e..d4731d20e1a 100644 --- a/integration/test_radix_themes.py +++ b/integration/test_radix_themes.py @@ -81,7 +81,7 @@ def index() -> rx.Component: ) app = rx.App( - state=State, + state=rx.State, theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"), ) app.add_page(index) diff --git a/integration/test_server_side_event.py b/integration/test_server_side_event.py index b24f47368a8..31a38ba36c8 100644 --- a/integration/test_server_side_event.py +++ b/integration/test_server_side_event.py @@ -33,7 +33,7 @@ def set_value_return(self): def set_value_return_c(self): return rx.set_value("c", "") - app = rx.App(state=SSState) + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/integration/test_table.py b/integration/test_table.py index e947514a063..00e6a8b22d1 100644 --- a/integration/test_table.py +++ b/integration/test_table.py @@ -26,7 +26,7 @@ class TableState(rx.State): caption: str = "random caption" - app = rx.App(state=TableState) + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/integration/test_upload.py b/integration/test_upload.py index 648c68be550..04e8f3205a8 100644 --- a/integration/test_upload.py +++ b/integration/test_upload.py @@ -113,7 +113,7 @@ def index(): ), ) - app = rx.App(state=UploadState) + app = rx.App(state=rx.State) app.add_page(index) app.compile() @@ -192,7 +192,7 @@ async def test_upload_file( # look up the backend state and assert on uploaded contents async def get_file_data(): - return (await upload_file.get_state(token))._file_data + return (await upload_file.get_state(token)).substates["upload_state"]._file_data file_data = await AppHarness._poll_for_async(get_file_data) assert isinstance(file_data, dict) @@ -251,7 +251,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): # look up the backend state and assert on uploaded contents async def get_file_data(): - return (await upload_file.get_state(token))._file_data + return (await upload_file.get_state(token)).substates["upload_state"]._file_data file_data = await AppHarness._poll_for_async(get_file_data) assert isinstance(file_data, dict) diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py index f327b662c68..374344a87ae 100644 --- a/integration/test_var_operations.py +++ b/integration/test_var_operations.py @@ -30,7 +30,7 @@ class VarOperationState(rx.State): dict2: dict = {3: 4} html_str: str = "
hello
" - app = rx.App(state=VarOperationState) + app = rx.App(state=rx.State) @app.add_page def index(): From b27db358ea1b9c68b25207c35767e8b1abbbd1e0 Mon Sep 17 00:00:00 2001 From: Elijah Date: Wed, 15 Nov 2023 16:09:08 +0000 Subject: [PATCH 03/24] fix app harness(except dynamic routes) --- integration/test_background_task.py | 2 +- integration/test_call_script.py | 2 +- integration/test_client_storage.py | 20 ++++++++++++++------ integration/test_connection_banner.py | 2 +- integration/test_event_actions.py | 10 +++++++--- integration/test_event_chain.py | 20 ++++++++++---------- integration/test_form_submit.py | 4 ++-- integration/test_input.py | 18 +++++++++++------- integration/test_login_flow.py | 2 +- integration/test_radix_themes.py | 1 - integration/test_server_side_event.py | 2 +- integration/test_table.py | 2 +- integration/test_upload.py | 2 +- integration/test_var_operations.py | 2 +- reflex/app.py | 4 ++++ reflex/app.pyi | 3 ++- reflex/state.py | 14 ++++++++++++-- reflex/testing.py | 9 ++++++--- tests/components/datadisplay/conftest.py | 9 +++++---- tests/components/forms/test_debounce.py | 3 ++- tests/components/layout/test_cond.py | 3 ++- tests/conftest.py | 4 ++-- tests/states/__init__.py | 3 ++- tests/states/mutation.py | 7 ++++--- tests/states/upload.py | 9 +++++---- tests/test_state.py | 4 ++-- 26 files changed, 100 insertions(+), 61 deletions(-) diff --git a/integration/test_background_task.py b/integration/test_background_task.py index bc70ff01fba..a1d2d3ba744 100644 --- a/integration/test_background_task.py +++ b/integration/test_background_task.py @@ -93,7 +93,7 @@ def index() -> rx.Component: rx.button("Reset", on_click=State.reset_counter, id="reset"), ) - app = rx.App(state=rx.State) + app = rx.App() app.add_page(index) app.compile() diff --git a/integration/test_call_script.py b/integration/test_call_script.py index 95bdf37c998..6f965f04ee1 100644 --- a/integration/test_call_script.py +++ b/integration/test_call_script.py @@ -135,7 +135,7 @@ def reset_(self): yield rx.call_script("inline_counter = 0; external_counter = 0") self.reset() - app = rx.App(state=rx.State) + app = rx.App() with open("assets/external.js", "w") as f: f.write(external_scripts) diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 5efea8411de..9bceafcc20a 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -97,7 +97,7 @@ def index(): rx.box(ClientSideSubSubState.l1s, id="l1s"), ) - app = rx.App(state=rx.State) + app = rx.App() app.add_page(index) app.add_page(index, route="/foo") app.compile() @@ -354,9 +354,12 @@ async def test_client_side_state( input_value_input.send_keys("c3 value") set_sub_state_button.click() AppHarness._poll_for( - lambda: "state.client_side_state.client_side_sub_state.c3" in cookie_info_map(driver) + lambda: "state.client_side_state.client_side_sub_state.c3" + in cookie_info_map(driver) ) - c3_cookie = cookie_info_map(driver)["state.client_side_state.client_side_sub_state.c3"] + c3_cookie = cookie_info_map(driver)[ + "state.client_side_state.client_side_sub_state.c3" + ] assert c3_cookie.pop("expiry") is not None assert c3_cookie == { "domain": "localhost", @@ -368,7 +371,9 @@ async def test_client_side_state( "value": "c3%20value", } time.sleep(2) # wait for c3 to expire - assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver) + assert "state.client_side_state.client_side_sub_state.c3" not in cookie_info_map( + driver + ) local_storage_items = local_storage.items() local_storage_items.pop("chakra-ui-color-mode", None) @@ -482,9 +487,12 @@ async def test_client_side_state( # make sure c5 cookie shows up on the `/foo` route AppHarness._poll_for( - lambda: "state.client_side_state.client_side_sub_state.c5" in cookie_info_map(driver) + lambda: "state.client_side_state.client_side_sub_state.c5" + in cookie_info_map(driver) ) - assert cookie_info_map(driver)["state.client_side_state.client_side_sub_state.c5"] == { + assert cookie_info_map(driver)[ + "state.client_side_state.client_side_sub_state.c5" + ] == { "domain": "localhost", "httpOnly": False, "name": "state.client_side_state.client_side_sub_state.c5", diff --git a/integration/test_connection_banner.py b/integration/test_connection_banner.py index 0468e6b535f..dd4f5de1976 100644 --- a/integration/test_connection_banner.py +++ b/integration/test_connection_banner.py @@ -19,7 +19,7 @@ class State(rx.State): def index(): return rx.text("Hello World") - app = rx.App(state=rx.State) + app = rx.App() app.add_page(index) app.compile() diff --git a/integration/test_event_actions.py b/integration/test_event_actions.py index 9bb0ab1893e..28853a7c38e 100644 --- a/integration/test_event_actions.py +++ b/integration/test_event_actions.py @@ -130,7 +130,7 @@ def index(): on_click=EventActionState.on_click("outer"), # type: ignore ) - app = rx.App(state=rx.State) + app = rx.App() app.add_page(index) app.compile() @@ -211,10 +211,14 @@ async def _backend_state(): return await event_action.get_state(token) async def _check(): - return (await _backend_state()).order == exp_order + return (await _backend_state()).substates[ + "event_action_state" + ].order == exp_order await AppHarness._poll_for_async(_check) - assert (await _backend_state()).order == exp_order + assert (await _backend_state()).substates[ + "event_action_state" + ].order == exp_order return _poll_for_order diff --git a/integration/test_event_chain.py b/integration/test_event_chain.py index b89df7c1b09..75054f3b1b1 100644 --- a/integration/test_event_chain.py +++ b/integration/test_event_chain.py @@ -122,7 +122,7 @@ def click_yield_interim_value(self): time.sleep(0.5) self.interim_value = "final" - app = rx.App(state=rx.State) + app = rx.App() token_input = rx.input( value=State.router.session.client_token, is_read_only=True, id="token" @@ -401,9 +401,9 @@ async def test_event_chain_click( btn.click() async def _has_all_events(): - return len((await event_chain.get_state(token)).substates["state"].event_order) == len( - exp_event_order - ) + return len( + (await event_chain.get_state(token)).substates["state"].event_order + ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) event_order = (await event_chain.get_state(token)).substates["state"].event_order @@ -453,9 +453,9 @@ async def test_event_chain_on_load( token = assert_token(event_chain, driver) async def _has_all_events(): - return len((await event_chain.get_state(token)).substates["state"].event_order) == len( - exp_event_order - ) + return len( + (await event_chain.get_state(token)).substates["state"].event_order + ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) backend_state = (await event_chain.get_state(token)).substates["state"] @@ -529,9 +529,9 @@ async def test_event_chain_on_mount( unmount_button.click() async def _has_all_events(): - return len((await event_chain.get_state(token)).substates["state"].event_order) == len( - exp_event_order - ) + return len( + (await event_chain.get_state(token)).substates["state"].event_order + ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) event_order = (await event_chain.get_state(token)).substates["state"].event_order diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py index c5b9f795dd3..1bf614fcc42 100644 --- a/integration/test_form_submit.py +++ b/integration/test_form_submit.py @@ -22,7 +22,7 @@ class FormState(rx.State): def form_submit(self, form_data: dict): self.form_data = form_data - app = rx.App(state=rx.State) + app = rx.App() @app.add_page def index(): @@ -75,7 +75,7 @@ class FormState(rx.State): def form_submit(self, form_data: dict): self.form_data = form_data - app = rx.App(state=FormState) + app = rx.App() @app.add_page def index(): diff --git a/integration/test_input.py b/integration/test_input.py index 3c75e2ff169..7c6fe264c5e 100644 --- a/integration/test_input.py +++ b/integration/test_input.py @@ -16,7 +16,7 @@ def FullyControlledInput(): class State(rx.State): text: str = "initial" - app = rx.App(state=rx.State) + app = rx.App() @app.add_page def index(): @@ -85,13 +85,15 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): debounce_input.send_keys("foo") time.sleep(0.5) assert debounce_input.get_attribute("value") == "ifoonitial" - assert (await fully_controlled_input.get_state(token)).substates["state"].text == "ifoonitial" + assert (await fully_controlled_input.get_state(token)).substates[ + "state" + ].text == "ifoonitial" assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial" # clear the input on the backend async with fully_controlled_input.modify_state(token) as state: state.substates["state"].text = "" - assert (await fully_controlled_input.get_state(token)).substates["state"].text=="" + assert (await fully_controlled_input.get_state(token)).substates["state"].text == "" assert ( fully_controlled_input.poll_for_value( debounce_input, exp_not_equal="ifoonitial" @@ -103,9 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): debounce_input.send_keys("getting testing done") time.sleep(0.5) assert debounce_input.get_attribute("value") == "getting testing done" - assert ( - await fully_controlled_input.get_state(token) - ).substates["state"].text == "getting testing done" + assert (await fully_controlled_input.get_state(token)).substates[ + "state" + ].text == "getting testing done" assert fully_controlled_input.poll_for_value(value_input) == "getting testing done" # type into the on_change input @@ -113,7 +115,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): time.sleep(0.5) assert debounce_input.get_attribute("value") == "overwrite the state" assert on_change_input.get_attribute("value") == "overwrite the state" - assert (await fully_controlled_input.get_state(token)).substates["state"].text == "overwrite the state" + assert (await fully_controlled_input.get_state(token)).substates[ + "state" + ].text == "overwrite the state" assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state" clear_button.click() diff --git a/integration/test_login_flow.py b/integration/test_login_flow.py index f5363574378..499fbab4b75 100644 --- a/integration/test_login_flow.py +++ b/integration/test_login_flow.py @@ -42,7 +42,7 @@ def login(): rx.button("Do it", on_click=State.login, id="doit"), ) - app = rx.App(state=rx.State) + app = rx.App() app.add_page(index) app.add_page(login) app.compile() diff --git a/integration/test_radix_themes.py b/integration/test_radix_themes.py index d4731d20e1a..4283560231e 100644 --- a/integration/test_radix_themes.py +++ b/integration/test_radix_themes.py @@ -81,7 +81,6 @@ def index() -> rx.Component: ) app = rx.App( - state=rx.State, theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"), ) app.add_page(index) diff --git a/integration/test_server_side_event.py b/integration/test_server_side_event.py index 31a38ba36c8..57e8ea9e277 100644 --- a/integration/test_server_side_event.py +++ b/integration/test_server_side_event.py @@ -33,7 +33,7 @@ def set_value_return(self): def set_value_return_c(self): return rx.set_value("c", "") - app = rx.App(state=rx.State) + app = rx.App() @app.add_page def index(): diff --git a/integration/test_table.py b/integration/test_table.py index 00e6a8b22d1..0bd7500863d 100644 --- a/integration/test_table.py +++ b/integration/test_table.py @@ -26,7 +26,7 @@ class TableState(rx.State): caption: str = "random caption" - app = rx.App(state=rx.State) + app = rx.App() @app.add_page def index(): diff --git a/integration/test_upload.py b/integration/test_upload.py index 04e8f3205a8..c906271e85a 100644 --- a/integration/test_upload.py +++ b/integration/test_upload.py @@ -113,7 +113,7 @@ def index(): ), ) - app = rx.App(state=rx.State) + app = rx.App() app.add_page(index) app.compile() diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py index 374344a87ae..693547c6aca 100644 --- a/integration/test_var_operations.py +++ b/integration/test_var_operations.py @@ -30,7 +30,7 @@ class VarOperationState(rx.State): dict2: dict = {3: 4} html_str: str = "
hello
" - app = rx.App(state=rx.State) + app = rx.App() @app.add_page def index(): diff --git a/reflex/app.py b/reflex/app.py index 42d9389e2a5..bff1549fa96 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -153,6 +153,10 @@ def __init__(self, *args, **kwargs): state_subclasses = BaseState.__subclasses__() is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ + # For tests, make rx.State the state if not provided(Typically useful for app harness tests) + if is_testing_env and not self.state: + self.state = State + # Special case to allow test cases have multiple subclasses of rx.BaseState. if not is_testing_env: # Only one Base State class is allowed. diff --git a/reflex/app.pyi b/reflex/app.pyi index f7e63727ae0..667ebf52a11 100644 --- a/reflex/app.pyi +++ b/reflex/app.pyi @@ -33,6 +33,7 @@ from reflex.route import ( ) from reflex.state import ( State as State, + BaseState as BaseState, StateManager as StateManager, StateUpdate as StateUpdate, ) @@ -69,7 +70,7 @@ class App(Base): api: FastAPI sio: Optional[AsyncServer] socket_app: Optional[ASGIApp] - state: Type[State] + state: Type[BaseState] state_manager: StateManager style: ComponentStyle middleware: List[Middleware] diff --git a/reflex/state.py b/reflex/state.py index 389cdf0253d..b7c1427c605 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -18,6 +18,7 @@ AsyncIterator, Callable, ClassVar, + DefaultDict, Dict, List, Optional, @@ -176,6 +177,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The event handlers. event_handlers: ClassVar[Dict[str, EventHandler]] = {} + # A mapping of classes and corresponding subclassses. + class_subclasses: ClassVar[ + DefaultDict[Type[BaseState], Set[Type[BaseState]]] + ] = defaultdict(set) + # Mapping of var name to set of computed variables that depend on it _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} @@ -291,6 +297,10 @@ def __init_subclass__(cls, **kwargs): if parent_state is not None: cls.inherited_vars = parent_state.vars cls.inherited_backend_vars = parent_state.backend_vars + # fix up parent class_substates + cls.class_subclasses[parent_state].add(cls) if not any( + [c.__name__ == cls.__name__ for c in cls.class_subclasses[parent_state]] + ) else None cls.new_backend_vars = { name: value @@ -452,14 +462,14 @@ def get_parent_state(cls) -> Type[BaseState] | None: return parent_states[0] if len(parent_states) == 1 else None # type: ignore @classmethod - @functools.lru_cache() def get_substates(cls) -> set[Type[BaseState]]: """Get the substates of the state. Returns: The substates of the state. """ - return set(cls.__subclasses__()) + # return set(cls.__subclasses__()) + return cls.class_subclasses[cls] @classmethod @functools.lru_cache() diff --git a/reflex/testing.py b/reflex/testing.py index b2ba8bd49a8..aa3d387336f 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -162,6 +162,9 @@ def _initialize_app(self): with chdir(self.app_path): # ensure config and app are reloaded when testing different app reflex.config.get_config(reload=True) + # reset rx.State subclasses + State.class_subclasses.clear() + # self.app_module.app. self.app_module = reflex.utils.prerequisites.get_app(reload=True) self.app_instance = self.app_module.app if isinstance(self.app_instance.state_manager, StateManagerRedis): @@ -169,6 +172,8 @@ def _initialize_app(self): self.state_manager = StateManagerRedis.create(self.app_instance.state) else: self.state_manager = self.app_instance.state_manager + breakpoint() + print() def _get_backend_shutdown_handler(self): if self.backend is None: @@ -561,9 +566,7 @@ def poll_for_value( ) return element.get_attribute("value") - def poll_for_clients( - self, timeout: TimeoutType = None - ) -> dict[str, reflex.BaseState]: + def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]: """Poll app state_manager for any connected clients. Args: diff --git a/tests/components/datadisplay/conftest.py b/tests/components/datadisplay/conftest.py index 18eafb17c24..93796ed23a8 100644 --- a/tests/components/datadisplay/conftest.py +++ b/tests/components/datadisplay/conftest.py @@ -5,6 +5,7 @@ import pytest import reflex as rx +from reflex.state import BaseState @pytest.fixture @@ -18,7 +19,7 @@ def data_table_state(request): The data table state class. """ - class DataTableState(rx.BaseState): + class DataTableState(BaseState): data = request.param["data"] columns = ["column1", "column2"] @@ -33,7 +34,7 @@ def data_table_state2(): The data table state class. """ - class DataTableState(rx.BaseState): + class DataTableState(BaseState): _data = pd.DataFrame() @rx.var @@ -51,7 +52,7 @@ def data_table_state3(): The data table state class. """ - class DataTableState(rx.BaseState): + class DataTableState(BaseState): _data: List = [] _columns: List = ["col1", "col2"] @@ -74,7 +75,7 @@ def data_table_state4(): The data table state class. """ - class DataTableState(rx.BaseState): + class DataTableState(BaseState): _data: List = [] _columns: List = ["col1", "col2"] diff --git a/tests/components/forms/test_debounce.py b/tests/components/forms/test_debounce.py index 91a2d938e53..97cfa86484e 100644 --- a/tests/components/forms/test_debounce.py +++ b/tests/components/forms/test_debounce.py @@ -3,6 +3,7 @@ import pytest import reflex as rx +from reflex.state import BaseState from reflex.vars import BaseVar @@ -24,7 +25,7 @@ def test_render_many_child(): _ = rx.debounce_input("foo", "bar").render() -class S(rx.BaseState): +class S(BaseState): """Example state for debounce tests.""" value: str = "" diff --git a/tests/components/layout/test_cond.py b/tests/components/layout/test_cond.py index 27f204e4cbf..08c6883c448 100644 --- a/tests/components/layout/test_cond.py +++ b/tests/components/layout/test_cond.py @@ -15,12 +15,13 @@ tablet_only, ) from reflex.components.typography.text import Text +from reflex.state import BaseState from reflex.vars import Var @pytest.fixture def cond_state(request): - class CondState(rx.BaseState): + class CondState(BaseState): value: request.param["value_type"] = request.param["value"] # noqa return CondState diff --git a/tests/conftest.py b/tests/conftest.py index 040b1221918..a49b17a3df3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,9 +8,9 @@ import pytest -import reflex as rx from reflex.app import App from reflex.event import EventSpec +from reflex.state import BaseState from .states import ( DictMutationTestState, @@ -235,7 +235,7 @@ def duplicate_substate(): The test state. """ - class TestState(rx.BaseState): + class TestState(BaseState): pass class ChildTestState(TestState): # type: ignore # noqa diff --git a/tests/states/__init__.py b/tests/states/__init__.py index d8baf53e4e1..b825bbffdcd 100644 --- a/tests/states/__init__.py +++ b/tests/states/__init__.py @@ -1,5 +1,6 @@ """Common rx.BaseState subclasses for use in tests.""" import reflex as rx +from reflex.state import BaseState from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState from .upload import ( @@ -12,7 +13,7 @@ ) -class GenState(rx.BaseState): +class GenState(BaseState): """A state with event handlers that generate multiple updates.""" value: int diff --git a/tests/states/mutation.py b/tests/states/mutation.py index f6d050295a1..5825b6d12bd 100644 --- a/tests/states/mutation.py +++ b/tests/states/mutation.py @@ -3,9 +3,10 @@ from typing import Dict, List, Set, Union import reflex as rx +from reflex.state import BaseState -class DictMutationTestState(rx.BaseState): +class DictMutationTestState(BaseState): """A state for testing ReflexDict mutation.""" # plain dict @@ -62,7 +63,7 @@ def add_friend_age(self): self.friend_in_nested_dict["friend"]["age"] = 30 -class ListMutationTestState(rx.BaseState): +class ListMutationTestState(BaseState): """A state for testing ReflexList mutation.""" # plain list @@ -144,7 +145,7 @@ class CustomVar(rx.Base): custom: OtherBase = OtherBase() -class MutableTestState(rx.BaseState): +class MutableTestState(BaseState): """A test state.""" array: List[Union[str, List, Dict[str, str]]] = [ diff --git a/tests/states/upload.py b/tests/states/upload.py index 2f5ba950dcb..fa0e086d519 100644 --- a/tests/states/upload.py +++ b/tests/states/upload.py @@ -3,9 +3,10 @@ from typing import ClassVar, List import reflex as rx +from reflex.state import BaseState -class UploadState(rx.BaseState): +class UploadState(BaseState): """The base state for uploading a file.""" async def handle_upload1(self, files: List[rx.UploadFile]): @@ -17,7 +18,7 @@ async def handle_upload1(self, files: List[rx.UploadFile]): pass -class BaseState(rx.BaseState): +class BaseState(BaseState): """The test base state.""" pass @@ -37,7 +38,7 @@ async def handle_upload(self, files: List[rx.UploadFile]): pass -class FileUploadState(rx.BaseState): +class FileUploadState(BaseState): """The base state for uploading a file.""" img_list: List[str] @@ -79,7 +80,7 @@ async def bg_upload(self, files: List[rx.UploadFile]): pass -class FileStateBase1(rx.BaseState): +class FileStateBase1(BaseState): """The base state for a child FileUploadState.""" pass diff --git a/tests/test_state.py b/tests/test_state.py index 8879af67bfe..4d9264f159a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1369,7 +1369,7 @@ def test_error_on_state_method_shadow(): """Test that an error is thrown when an event handler shadows a state method.""" with pytest.raises(NameError) as err: - class InvalidTest(rx.BaseState): + class InvalidTest(BaseState): def reset(self): pass @@ -1382,7 +1382,7 @@ def reset(self): def test_state_with_invalid_yield(): """Test that an error is thrown when a state yields an invalid value.""" - class StateWithInvalidYield(rx.BaseState): + class StateWithInvalidYield(BaseState): """A state that yields an invalid value.""" def invalid_handler(self): From 9205a16088bcb4bcceaef3fc29a5fccddf4842e6 Mon Sep 17 00:00:00 2001 From: Elijah Date: Wed, 15 Nov 2023 16:11:18 +0000 Subject: [PATCH 04/24] remove debug lines --- integration/test_dynamic_routes.py | 259 ++++++++++++++--------------- 1 file changed, 128 insertions(+), 131 deletions(-) diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index ff0afdebabc..21c519969d3 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -1,13 +1,10 @@ """Integration tests for dynamic route page behavior.""" from typing import Callable, Coroutine, Generator, Type -from urllib.parse import urlsplit import pytest from selenium.webdriver.common.by import By -from reflex.testing import AppHarness, AppHarnessProd, WebDriver - -from .utils import poll_for_navigation +from reflex.testing import AppHarness, WebDriver def DynamicRoute(): @@ -56,7 +53,7 @@ def index(): def redirect_page(): return rx.fragment(rx.text("redirecting...")) - app = rx.App(state=rx.State ) + app = rx.App() app.add_page(index) app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore @@ -151,129 +148,129 @@ async def _check(): return _poll_for_order -@pytest.mark.asyncio -async def test_on_load_navigate( - dynamic_route: AppHarness, - driver: WebDriver, - token: str, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], -): - """Click links to navigate between dynamic pages with on_load event. - - Args: - dynamic_route: harness for DynamicRoute app. - driver: WebDriver instance. - token: The token visible in the driver browser. - poll_for_order: function that polls for the order list to match the expected order. - """ - assert dynamic_route.app_instance is not None - is_prod = isinstance(dynamic_route, AppHarnessProd) - link = driver.find_element(By.ID, "link_page_next") - assert link - - exp_order = [f"/page/[page_id]-{ix}" for ix in range(10)] - # click the link a few times - for ix in range(10): - # wait for navigation, then assert on url - with poll_for_navigation(driver): - link.click() - assert urlsplit(driver.current_url).path == f"/page/{ix}/" - - link = driver.find_element(By.ID, "link_page_next") - page_id_input = driver.find_element(By.ID, "page_id") - - assert link - assert page_id_input - - assert dynamic_route.poll_for_value(page_id_input) == str(ix) - await poll_for_order(exp_order) - - # manually load the next page to trigger client side routing in prod mode - if is_prod: - exp_order += ["/404-no page id"] - exp_order += ["/page/[page_id]-10"] - with poll_for_navigation(driver): - driver.get(f"{dynamic_route.frontend_url}/page/10/") - await poll_for_order(exp_order) - - # make sure internal nav still hydrates after redirect - exp_order += ["/page/[page_id]-11"] - link = driver.find_element(By.ID, "link_page_next") - with poll_for_navigation(driver): - link.click() - await poll_for_order(exp_order) - - # load same page with a query param and make sure it passes through - if is_prod: - exp_order += ["/404-no page id"] - exp_order += ["/page/[page_id]-11"] - with poll_for_navigation(driver): - driver.get(f"{driver.current_url}?foo=bar") - await poll_for_order(exp_order) - assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar" - - # hit a 404 and ensure we still hydrate - exp_order += ["/404-no page id"] - with poll_for_navigation(driver): - driver.get(f"{dynamic_route.frontend_url}/missing") - await poll_for_order(exp_order) - - # browser nav should still trigger hydration - if is_prod: - exp_order += ["/404-no page id"] - exp_order += ["/page/[page_id]-11"] - with poll_for_navigation(driver): - driver.back() - await poll_for_order(exp_order) - - # next/link to a 404 and ensure we still hydrate - exp_order += ["/404-no page id"] - link = driver.find_element(By.ID, "link_missing") - with poll_for_navigation(driver): - link.click() - await poll_for_order(exp_order) - - # hit a page that redirects back to dynamic page - if is_prod: - exp_order += ["/404-no page id"] - exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"] - with poll_for_navigation(driver): - driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar") - await poll_for_order(exp_order) - # should have redirected back to page 0 - assert urlsplit(driver.current_url).path == "/page/0/" - - -@pytest.mark.asyncio -async def test_on_load_navigate_non_dynamic( - dynamic_route: AppHarness, - driver: WebDriver, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], -): - """Click links to navigate between static pages with on_load event. - - Args: - dynamic_route: harness for DynamicRoute app. - driver: WebDriver instance. - poll_for_order: function that polls for the order list to match the expected order. - """ - assert dynamic_route.app_instance is not None - link = driver.find_element(By.ID, "link_page_x") - assert link - - with poll_for_navigation(driver): - link.click() - assert urlsplit(driver.current_url).path == "/static/x/" - await poll_for_order(["/static/x-no page id"]) - - # go back to the index and navigate back to the static route - link = driver.find_element(By.ID, "link_index") - with poll_for_navigation(driver): - link.click() - assert urlsplit(driver.current_url).path == "/" - - link = driver.find_element(By.ID, "link_page_x") - with poll_for_navigation(driver): - link.click() - assert urlsplit(driver.current_url).path == "/static/x/" - await poll_for_order(["/static/x-no page id", "/static/x-no page id"]) +# @pytest.mark.asyncio +# async def test_on_load_navigate( +# dynamic_route: AppHarness, +# driver: WebDriver, +# token: str, +# poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], +# ): +# """Click links to navigate between dynamic pages with on_load event. +# +# Args: +# dynamic_route: harness for DynamicRoute app. +# driver: WebDriver instance. +# token: The token visible in the driver browser. +# poll_for_order: function that polls for the order list to match the expected order. +# """ +# assert dynamic_route.app_instance is not None +# is_prod = isinstance(dynamic_route, AppHarnessProd) +# link = driver.find_element(By.ID, "link_page_next") +# assert link +# +# exp_order = [f"/page/[page_id]-{ix}" for ix in range(10)] +# # click the link a few times +# for ix in range(10): +# # wait for navigation, then assert on url +# with poll_for_navigation(driver): +# link.click() +# assert urlsplit(driver.current_url).path == f"/page/{ix}/" +# +# link = driver.find_element(By.ID, "link_page_next") +# page_id_input = driver.find_element(By.ID, "page_id") +# +# assert link +# assert page_id_input +# +# assert dynamic_route.poll_for_value(page_id_input) == str(ix) +# await poll_for_order(exp_order) +# +# # manually load the next page to trigger client side routing in prod mode +# if is_prod: +# exp_order += ["/404-no page id"] +# exp_order += ["/page/[page_id]-10"] +# with poll_for_navigation(driver): +# driver.get(f"{dynamic_route.frontend_url}/page/10/") +# await poll_for_order(exp_order) +# +# # make sure internal nav still hydrates after redirect +# exp_order += ["/page/[page_id]-11"] +# link = driver.find_element(By.ID, "link_page_next") +# with poll_for_navigation(driver): +# link.click() +# await poll_for_order(exp_order) +# +# # load same page with a query param and make sure it passes through +# if is_prod: +# exp_order += ["/404-no page id"] +# exp_order += ["/page/[page_id]-11"] +# with poll_for_navigation(driver): +# driver.get(f"{driver.current_url}?foo=bar") +# await poll_for_order(exp_order) +# assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar" +# +# # hit a 404 and ensure we still hydrate +# exp_order += ["/404-no page id"] +# with poll_for_navigation(driver): +# driver.get(f"{dynamic_route.frontend_url}/missing") +# await poll_for_order(exp_order) +# +# # browser nav should still trigger hydration +# if is_prod: +# exp_order += ["/404-no page id"] +# exp_order += ["/page/[page_id]-11"] +# with poll_for_navigation(driver): +# driver.back() +# await poll_for_order(exp_order) +# +# # next/link to a 404 and ensure we still hydrate +# exp_order += ["/404-no page id"] +# link = driver.find_element(By.ID, "link_missing") +# with poll_for_navigation(driver): +# link.click() +# await poll_for_order(exp_order) +# +# # hit a page that redirects back to dynamic page +# if is_prod: +# exp_order += ["/404-no page id"] +# exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"] +# with poll_for_navigation(driver): +# driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar") +# await poll_for_order(exp_order) +# # should have redirected back to page 0 +# assert urlsplit(driver.current_url).path == "/page/0/" +# +# +# @pytest.mark.asyncio +# async def test_on_load_navigate_non_dynamic( +# dynamic_route: AppHarness, +# driver: WebDriver, +# poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], +# ): +# """Click links to navigate between static pages with on_load event. +# +# Args: +# dynamic_route: harness for DynamicRoute app. +# driver: WebDriver instance. +# poll_for_order: function that polls for the order list to match the expected order. +# """ +# assert dynamic_route.app_instance is not None +# link = driver.find_element(By.ID, "link_page_x") +# assert link +# +# with poll_for_navigation(driver): +# link.click() +# assert urlsplit(driver.current_url).path == "/static/x/" +# await poll_for_order(["/static/x-no page id"]) +# +# # go back to the index and navigate back to the static route +# link = driver.find_element(By.ID, "link_index") +# with poll_for_navigation(driver): +# link.click() +# assert urlsplit(driver.current_url).path == "/" +# +# link = driver.find_element(By.ID, "link_page_x") +# with poll_for_navigation(driver): +# link.click() +# assert urlsplit(driver.current_url).path == "/static/x/" +# await poll_for_order(["/static/x-no page id", "/static/x-no page id"]) From 7ae9b4401094ca4ec1ba4551bb43e7cbc2b9482d Mon Sep 17 00:00:00 2001 From: Elijah Date: Wed, 15 Nov 2023 16:12:31 +0000 Subject: [PATCH 05/24] remove debug lines --- reflex/testing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/reflex/testing.py b/reflex/testing.py index aa3d387336f..2cc20e6a4d9 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -172,8 +172,6 @@ def _initialize_app(self): self.state_manager = StateManagerRedis.create(self.app_instance.state) else: self.state_manager = self.app_instance.state_manager - breakpoint() - print() def _get_backend_shutdown_handler(self): if self.backend is None: From 2f119805c55c57a3c624ef522f534bcee946b678 Mon Sep 17 00:00:00 2001 From: Elijah Date: Wed, 15 Nov 2023 17:06:56 +0000 Subject: [PATCH 06/24] fix for reflex web --- reflex/state.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index b7c1427c605..4c843d5fe6d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -215,8 +215,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The router data for the current page router: RouterData = RouterData() - # The hydrated bool. - is_hydrated: bool = False def __init__(self, *args, parent_state: BaseState | None = None, **kwargs): """Initialize the state. @@ -1301,6 +1299,7 @@ async def __aexit__(self, *exc_info: Any) -> None: class State(BaseState): """The app Base State.""" + # The hydrated bool. is_hydrated: bool = False From 68a5df8a3bc25522c0709515f1e09d1e90724bac Mon Sep 17 00:00:00 2001 From: Elijah Date: Wed, 15 Nov 2023 19:34:54 +0000 Subject: [PATCH 07/24] uncomment dynamic route state --- integration/test_dynamic_routes.py | 252 ++++++++++++++--------------- 1 file changed, 126 insertions(+), 126 deletions(-) diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index 21c519969d3..55bae2018da 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -148,129 +148,129 @@ async def _check(): return _poll_for_order -# @pytest.mark.asyncio -# async def test_on_load_navigate( -# dynamic_route: AppHarness, -# driver: WebDriver, -# token: str, -# poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], -# ): -# """Click links to navigate between dynamic pages with on_load event. -# -# Args: -# dynamic_route: harness for DynamicRoute app. -# driver: WebDriver instance. -# token: The token visible in the driver browser. -# poll_for_order: function that polls for the order list to match the expected order. -# """ -# assert dynamic_route.app_instance is not None -# is_prod = isinstance(dynamic_route, AppHarnessProd) -# link = driver.find_element(By.ID, "link_page_next") -# assert link -# -# exp_order = [f"/page/[page_id]-{ix}" for ix in range(10)] -# # click the link a few times -# for ix in range(10): -# # wait for navigation, then assert on url -# with poll_for_navigation(driver): -# link.click() -# assert urlsplit(driver.current_url).path == f"/page/{ix}/" -# -# link = driver.find_element(By.ID, "link_page_next") -# page_id_input = driver.find_element(By.ID, "page_id") -# -# assert link -# assert page_id_input -# -# assert dynamic_route.poll_for_value(page_id_input) == str(ix) -# await poll_for_order(exp_order) -# -# # manually load the next page to trigger client side routing in prod mode -# if is_prod: -# exp_order += ["/404-no page id"] -# exp_order += ["/page/[page_id]-10"] -# with poll_for_navigation(driver): -# driver.get(f"{dynamic_route.frontend_url}/page/10/") -# await poll_for_order(exp_order) -# -# # make sure internal nav still hydrates after redirect -# exp_order += ["/page/[page_id]-11"] -# link = driver.find_element(By.ID, "link_page_next") -# with poll_for_navigation(driver): -# link.click() -# await poll_for_order(exp_order) -# -# # load same page with a query param and make sure it passes through -# if is_prod: -# exp_order += ["/404-no page id"] -# exp_order += ["/page/[page_id]-11"] -# with poll_for_navigation(driver): -# driver.get(f"{driver.current_url}?foo=bar") -# await poll_for_order(exp_order) -# assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar" -# -# # hit a 404 and ensure we still hydrate -# exp_order += ["/404-no page id"] -# with poll_for_navigation(driver): -# driver.get(f"{dynamic_route.frontend_url}/missing") -# await poll_for_order(exp_order) -# -# # browser nav should still trigger hydration -# if is_prod: -# exp_order += ["/404-no page id"] -# exp_order += ["/page/[page_id]-11"] -# with poll_for_navigation(driver): -# driver.back() -# await poll_for_order(exp_order) -# -# # next/link to a 404 and ensure we still hydrate -# exp_order += ["/404-no page id"] -# link = driver.find_element(By.ID, "link_missing") -# with poll_for_navigation(driver): -# link.click() -# await poll_for_order(exp_order) -# -# # hit a page that redirects back to dynamic page -# if is_prod: -# exp_order += ["/404-no page id"] -# exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"] -# with poll_for_navigation(driver): -# driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar") -# await poll_for_order(exp_order) -# # should have redirected back to page 0 -# assert urlsplit(driver.current_url).path == "/page/0/" -# -# -# @pytest.mark.asyncio -# async def test_on_load_navigate_non_dynamic( -# dynamic_route: AppHarness, -# driver: WebDriver, -# poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], -# ): -# """Click links to navigate between static pages with on_load event. -# -# Args: -# dynamic_route: harness for DynamicRoute app. -# driver: WebDriver instance. -# poll_for_order: function that polls for the order list to match the expected order. -# """ -# assert dynamic_route.app_instance is not None -# link = driver.find_element(By.ID, "link_page_x") -# assert link -# -# with poll_for_navigation(driver): -# link.click() -# assert urlsplit(driver.current_url).path == "/static/x/" -# await poll_for_order(["/static/x-no page id"]) -# -# # go back to the index and navigate back to the static route -# link = driver.find_element(By.ID, "link_index") -# with poll_for_navigation(driver): -# link.click() -# assert urlsplit(driver.current_url).path == "/" -# -# link = driver.find_element(By.ID, "link_page_x") -# with poll_for_navigation(driver): -# link.click() -# assert urlsplit(driver.current_url).path == "/static/x/" -# await poll_for_order(["/static/x-no page id", "/static/x-no page id"]) +@pytest.mark.asyncio +async def test_on_load_navigate( + dynamic_route: AppHarness, + driver: WebDriver, + token: str, + poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], +): + """Click links to navigate between dynamic pages with on_load event. + + Args: + dynamic_route: harness for DynamicRoute app. + driver: WebDriver instance. + token: The token visible in the driver browser. + poll_for_order: function that polls for the order list to match the expected order. + """ + assert dynamic_route.app_instance is not None + is_prod = isinstance(dynamic_route, AppHarnessProd) + link = driver.find_element(By.ID, "link_page_next") + assert link + + exp_order = [f"/page/[page_id]-{ix}" for ix in range(10)] + # click the link a few times + for ix in range(10): + # wait for navigation, then assert on url + with poll_for_navigation(driver): + link.click() + assert urlsplit(driver.current_url).path == f"/page/{ix}/" + + link = driver.find_element(By.ID, "link_page_next") + page_id_input = driver.find_element(By.ID, "page_id") + + assert link + assert page_id_input + + assert dynamic_route.poll_for_value(page_id_input) == str(ix) + await poll_for_order(exp_order) + + # manually load the next page to trigger client side routing in prod mode + if is_prod: + exp_order += ["/404-no page id"] + exp_order += ["/page/[page_id]-10"] + with poll_for_navigation(driver): + driver.get(f"{dynamic_route.frontend_url}/page/10/") + await poll_for_order(exp_order) + + # make sure internal nav still hydrates after redirect + exp_order += ["/page/[page_id]-11"] + link = driver.find_element(By.ID, "link_page_next") + with poll_for_navigation(driver): + link.click() + await poll_for_order(exp_order) + + # load same page with a query param and make sure it passes through + if is_prod: + exp_order += ["/404-no page id"] + exp_order += ["/page/[page_id]-11"] + with poll_for_navigation(driver): + driver.get(f"{driver.current_url}?foo=bar") + await poll_for_order(exp_order) + assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar" + + # hit a 404 and ensure we still hydrate + exp_order += ["/404-no page id"] + with poll_for_navigation(driver): + driver.get(f"{dynamic_route.frontend_url}/missing") + await poll_for_order(exp_order) + + # browser nav should still trigger hydration + if is_prod: + exp_order += ["/404-no page id"] + exp_order += ["/page/[page_id]-11"] + with poll_for_navigation(driver): + driver.back() + await poll_for_order(exp_order) + + # next/link to a 404 and ensure we still hydrate + exp_order += ["/404-no page id"] + link = driver.find_element(By.ID, "link_missing") + with poll_for_navigation(driver): + link.click() + await poll_for_order(exp_order) + + # hit a page that redirects back to dynamic page + if is_prod: + exp_order += ["/404-no page id"] + exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"] + with poll_for_navigation(driver): + driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar") + await poll_for_order(exp_order) + # should have redirected back to page 0 + assert urlsplit(driver.current_url).path == "/page/0/" + + +@pytest.mark.asyncio +async def test_on_load_navigate_non_dynamic( + dynamic_route: AppHarness, + driver: WebDriver, + poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], +): + """Click links to navigate between static pages with on_load event. + + Args: + dynamic_route: harness for DynamicRoute app. + driver: WebDriver instance. + poll_for_order: function that polls for the order list to match the expected order. + """ + assert dynamic_route.app_instance is not None + link = driver.find_element(By.ID, "link_page_x") + assert link + + with poll_for_navigation(driver): + link.click() + assert urlsplit(driver.current_url).path == "/static/x/" + await poll_for_order(["/static/x-no page id"]) + + # go back to the index and navigate back to the static route + link = driver.find_element(By.ID, "link_index") + with poll_for_navigation(driver): + link.click() + assert urlsplit(driver.current_url).path == "/" + + link = driver.find_element(By.ID, "link_page_x") + with poll_for_navigation(driver): + link.click() + assert urlsplit(driver.current_url).path == "/static/x/" + await poll_for_order(["/static/x-no page id", "/static/x-no page id"]) From 7bc3c712a2498ce0fb11ac8ffa15468854182ba9 Mon Sep 17 00:00:00 2001 From: Elijah Date: Wed, 15 Nov 2023 19:56:11 +0000 Subject: [PATCH 08/24] remove debug lines --- tests/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 41bc9f99857..b4d8570e5a3 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,7 +2,3 @@ import os from reflex import constants - -os.environ[constants.PYTEST_CURRENT_TEST] = "true" -print("Gegining testing-------------\n") -print(constants.PYTEST_CURRENT_TEST in os.environ) From ba6cfd39c32cf159ee001eef4e14fdfe733c4aa2 Mon Sep 17 00:00:00 2001 From: Elijah Date: Tue, 21 Nov 2023 06:29:23 +0000 Subject: [PATCH 09/24] monkeypatch validate_field_name in pydantic --- integration/test_dynamic_routes.py | 11 ++++++++--- reflex/base.py | 31 +++++++++++++++++++++++++++++- reflex/constants/__init__.py | 2 ++ reflex/constants/base.py | 1 + reflex/state.py | 1 - reflex/utils/prerequisites.py | 2 ++ 6 files changed, 43 insertions(+), 5 deletions(-) diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index 55bae2018da..84673bf869b 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -1,10 +1,13 @@ """Integration tests for dynamic route page behavior.""" from typing import Callable, Coroutine, Generator, Type +from urllib.parse import urlsplit import pytest from selenium.webdriver.common.by import By -from reflex.testing import AppHarness, WebDriver +from reflex.testing import AppHarness, AppHarnessProd, WebDriver + +from .utils import poll_for_navigation def DynamicRoute(): @@ -140,10 +143,12 @@ async def _backend_state(): return await dynamic_route.get_state(token) async def _check(): - return (await _backend_state()).order == exp_order + return (await _backend_state()).substates[ + "dynamic_state" + ].order == exp_order await AppHarness._poll_for_async(_check) - assert (await _backend_state()).order == exp_order + assert (await _backend_state()).substates["dynamic_state"].order == exp_order return _poll_for_order diff --git a/reflex/base.py b/reflex/base.py index 62bc70855ce..e1d8dd298f7 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -1,11 +1,40 @@ """Define the base Reflex class.""" from __future__ import annotations -from typing import Any +import os +from typing import Any, List, Type import pydantic +from pydantic import BaseModel from pydantic.fields import ModelField +from reflex import constants + + +def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None: + """Ensure that the field's name does not shadow an existing attribute of the model. + + Args: + bases: List of base models to check for shadowed attrs. + field_name: name of attribute + + Raises: + NameError: If state var field shadows another in its parent state + """ + reload = os.getenv(constants.RELOAD_CONFIG) == "True" + for base in bases: + try: + if not reload and getattr(base, field_name, None): + pass + except TypeError as te: + raise NameError( + f'State var "{field_name}" in {base} has been shadowed by a substate var; ' + f'use a different field name instead".' + ) from te + + +pydantic.main.validate_field_name = validate_field_name # type: ignore + class Base(pydantic.BaseModel): """The base class subclassed by all Reflex classes. diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index ca408c6a657..6a4a341edd5 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -6,6 +6,7 @@ LOCAL_STORAGE, POLLING_MAX_HTTP_BUFFER_SIZE, PYTEST_CURRENT_TEST, + RELOAD_CONFIG, SKIP_COMPILE_ENV_VAR, ColorMode, Dirs, @@ -87,6 +88,7 @@ PYTEST_CURRENT_TEST, PRODUCTION_BACKEND_URL, Reflex, + RELOAD_CONFIG, RouteArgType, RouteRegex, RouteVar, diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 0b28e18cb26..1781aa0ce13 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -173,3 +173,4 @@ class Ping(SimpleNamespace): # Testing variables. # Testing os env set by pytest when running a test case. PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" +RELOAD_CONFIG = "RELOAD_CONFIG" diff --git a/reflex/state.py b/reflex/state.py index 4c843d5fe6d..4877a615f99 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -215,7 +215,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The router data for the current page router: RouterData = RouterData() - def __init__(self, *args, parent_state: BaseState | None = None, **kwargs): """Initialize the state. diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 17195f2242e..f0e53b2d20b 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -124,11 +124,13 @@ def get_app(reload: bool = False) -> ModuleType: Returns: The app based on the default config. """ + os.environ[constants.RELOAD_CONFIG] = str(reload) config = get_config() module = ".".join([config.app_name, config.app_name]) sys.path.insert(0, os.getcwd()) app = __import__(module, fromlist=(constants.CompileVars.APP,)) if reload: + importlib.reload(app) return app From 646439175e14794657d2e7fbbe3197f1d0b8a705 Mon Sep 17 00:00:00 2001 From: Elijah Date: Tue, 21 Nov 2023 17:00:57 +0000 Subject: [PATCH 10/24] partial unit test fix --- tests/middleware/test_hydrate_middleware.py | 3 ++- tests/test_app.py | 8 +++++++- tests/test_state.py | 7 +------ tests/utils/test_format.py | 2 -- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 357aa4401d0..956ed1fd40f 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -6,7 +6,7 @@ from reflex.constants import CompileVars from reflex.middleware.hydrate_middleware import HydrateMiddleware from reflex.state import BaseState, StateUpdate - +from reflex import constants def exp_is_hydrated(state: BaseState) -> Dict[str, Any]: """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. @@ -97,6 +97,7 @@ async def test_preprocess( event_fixture: The event fixture(an Event). expected: Expected delta. """ + test_state.add_var(constants.CompileVars.IS_HYDRATED,type_=bool, default_value=False) app = App(state=test_state, load_events={"index": [test_state.test_handler]}) state = test_state() diff --git a/tests/test_app.py b/tests/test_app.py index 787eeedb691..01dace37bd7 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -749,7 +749,7 @@ async def test_upload_file(tmp_path, state, delta, token: str): request_mock = unittest.mock.Mock() request_mock.headers = { "reflex-client-token": token, - "reflex-event-handler": f"{state_name}.multi_handle_upload", + "reflex-event-handler": f"state.{state_name}.multi_handle_upload", } file1 = UploadFile( @@ -894,6 +894,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( index_page, windows_platform: bool, token: str, + mocker ): """Create app with dynamic route var, and simulate navigation. @@ -905,6 +906,11 @@ async def test_dynamic_route_var_route_change_completed_on_load( windows_platform: Whether the system is windows. token: a Token. """ + class_subclasses = { + State: {DynamicState} + } + mocker.patch("reflex.state.State.class_subclasses", class_subclasses) + DynamicState.add_var(constants.CompileVars.IS_HYDRATED,type_=bool, default_value=False) arg_name = "dynamic" route = f"/test/[{arg_name}]" if windows_platform: diff --git a/tests/test_state.py b/tests/test_state.py index 4d9264f159a..421551f28c8 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -253,7 +253,6 @@ def test_class_vars(test_state): """ cls = type(test_state) assert set(cls.vars.keys()) == { - CompileVars.IS_HYDRATED, # added by hydrate_middleware to all State "router", "num1", "num2", @@ -641,7 +640,6 @@ def test_reset(test_state, child_state): "obj", "upper", "complex", - "is_hydrated", "fig", "key", "sum", @@ -1165,21 +1163,18 @@ def dep_v(self) -> int: dict1 = ps.dict() assert dict1[ps.get_full_name()] == { "no_cache_v": 1, - CompileVars.IS_HYDRATED: False, "router": formatted_router, } assert dict1[cs.get_full_name()] == {"dep_v": 2} dict2 = ps.dict() assert dict2[ps.get_full_name()] == { "no_cache_v": 3, - CompileVars.IS_HYDRATED: False, "router": formatted_router, } assert dict2[cs.get_full_name()] == {"dep_v": 4} dict3 = ps.dict() assert dict3[ps.get_full_name()] == { "no_cache_v": 5, - CompileVars.IS_HYDRATED: False, "router": formatted_router, } assert dict3[cs.get_full_name()] == {"dep_v": 6} @@ -2216,7 +2211,7 @@ class MutableContainsBase(BaseState): f_formatted_router = str(formatted_router).replace("'", '"') assert ( val - == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}' + == f'{{"{MutableContainsBase.get_full_name()}": {{"items": {f_items}, "router": {f_formatted_router}}}}}' ) diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 257f375142e..e528536a14b 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -528,7 +528,6 @@ def test_format_query_params(input, output): }, "dt": "1989-11-09 18:53:00+01:00", "fig": [], - "is_hydrated": False, "key": "", "map_key": "a", "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]}, @@ -553,7 +552,6 @@ def test_format_query_params(input, output): DateTimeState.get_full_name(): { "d": "1989-11-09", "dt": "1989-11-09 18:53:00+01:00", - "is_hydrated": False, "t": "18:53:00+01:00", "td": "11 days, 0:11:00", "router": formatted_router, From 90519fb5e7fa7f4fb829ea9755eac4ce1bfee7b4 Mon Sep 17 00:00:00 2001 From: Elijah Date: Thu, 23 Nov 2023 23:51:34 +0000 Subject: [PATCH 11/24] fix unit tests --- integration/test_background_task.py | 2 +- integration/test_call_script.py | 2 +- integration/test_client_storage.py | 2 +- integration/test_connection_banner.py | 2 +- integration/test_dynamic_routes.py | 2 +- integration/test_event_actions.py | 2 +- integration/test_event_chain.py | 2 +- integration/test_form_submit.py | 4 +- integration/test_input.py | 2 +- integration/test_login_flow.py | 2 +- integration/test_radix_themes.py | 2 +- integration/test_server_side_event.py | 2 +- integration/test_table.py | 2 +- integration/test_upload.py | 2 +- integration/test_var_operations.py | 2 +- reflex/app.py | 4 - tests/states/__init__.py | 1 + tests/states/upload.py | 6 +- tests/test_app.py | 104 ++++++++++++++++++++++++-- 19 files changed, 118 insertions(+), 29 deletions(-) diff --git a/integration/test_background_task.py b/integration/test_background_task.py index a1d2d3ba744..bc70ff01fba 100644 --- a/integration/test_background_task.py +++ b/integration/test_background_task.py @@ -93,7 +93,7 @@ def index() -> rx.Component: rx.button("Reset", on_click=State.reset_counter, id="reset"), ) - app = rx.App() + app = rx.App(state=rx.State) app.add_page(index) app.compile() diff --git a/integration/test_call_script.py b/integration/test_call_script.py index 6f965f04ee1..95bdf37c998 100644 --- a/integration/test_call_script.py +++ b/integration/test_call_script.py @@ -135,7 +135,7 @@ def reset_(self): yield rx.call_script("inline_counter = 0; external_counter = 0") self.reset() - app = rx.App() + app = rx.App(state=rx.State) with open("assets/external.js", "w") as f: f.write(external_scripts) diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 9bceafcc20a..8c7db167d8f 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -97,7 +97,7 @@ def index(): rx.box(ClientSideSubSubState.l1s, id="l1s"), ) - app = rx.App() + app = rx.App(state=rx.State) app.add_page(index) app.add_page(index, route="/foo") app.compile() diff --git a/integration/test_connection_banner.py b/integration/test_connection_banner.py index dd4f5de1976..0468e6b535f 100644 --- a/integration/test_connection_banner.py +++ b/integration/test_connection_banner.py @@ -19,7 +19,7 @@ class State(rx.State): def index(): return rx.text("Hello World") - app = rx.App() + app = rx.App(state=rx.State) app.add_page(index) app.compile() diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index 84673bf869b..782e625ee8a 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -56,7 +56,7 @@ def index(): def redirect_page(): return rx.fragment(rx.text("redirecting...")) - app = rx.App() + app = rx.App(state=rx.State) app.add_page(index) app.add_page(index, route="/page/[page_id]", on_load=DynamicState.on_load) # type: ignore app.add_page(index, route="/static/x", on_load=DynamicState.on_load) # type: ignore diff --git a/integration/test_event_actions.py b/integration/test_event_actions.py index 28853a7c38e..da444a5cfe7 100644 --- a/integration/test_event_actions.py +++ b/integration/test_event_actions.py @@ -130,7 +130,7 @@ def index(): on_click=EventActionState.on_click("outer"), # type: ignore ) - app = rx.App() + app = rx.App(state=rx.State) app.add_page(index) app.compile() diff --git a/integration/test_event_chain.py b/integration/test_event_chain.py index 75054f3b1b1..7d003635e2c 100644 --- a/integration/test_event_chain.py +++ b/integration/test_event_chain.py @@ -122,7 +122,7 @@ def click_yield_interim_value(self): time.sleep(0.5) self.interim_value = "final" - app = rx.App() + app = rx.App(state=rx.State) token_input = rx.input( value=State.router.session.client_token, is_read_only=True, id="token" diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py index 1bf614fcc42..cc36b5f252c 100644 --- a/integration/test_form_submit.py +++ b/integration/test_form_submit.py @@ -22,7 +22,7 @@ class FormState(rx.State): def form_submit(self, form_data: dict): self.form_data = form_data - app = rx.App() + app = rx.App(state=rx.State) @app.add_page def index(): @@ -75,7 +75,7 @@ class FormState(rx.State): def form_submit(self, form_data: dict): self.form_data = form_data - app = rx.App() + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/integration/test_input.py b/integration/test_input.py index 7c6fe264c5e..4a517985018 100644 --- a/integration/test_input.py +++ b/integration/test_input.py @@ -16,7 +16,7 @@ def FullyControlledInput(): class State(rx.State): text: str = "initial" - app = rx.App() + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/integration/test_login_flow.py b/integration/test_login_flow.py index 499fbab4b75..f5363574378 100644 --- a/integration/test_login_flow.py +++ b/integration/test_login_flow.py @@ -42,7 +42,7 @@ def login(): rx.button("Do it", on_click=State.login, id="doit"), ) - app = rx.App() + app = rx.App(state=rx.State) app.add_page(index) app.add_page(login) app.compile() diff --git a/integration/test_radix_themes.py b/integration/test_radix_themes.py index 4283560231e..1f29aad20bd 100644 --- a/integration/test_radix_themes.py +++ b/integration/test_radix_themes.py @@ -81,7 +81,7 @@ def index() -> rx.Component: ) app = rx.App( - theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"), + state=rx.State, theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"), ) app.add_page(index) app.compile() diff --git a/integration/test_server_side_event.py b/integration/test_server_side_event.py index 57e8ea9e277..31a38ba36c8 100644 --- a/integration/test_server_side_event.py +++ b/integration/test_server_side_event.py @@ -33,7 +33,7 @@ def set_value_return(self): def set_value_return_c(self): return rx.set_value("c", "") - app = rx.App() + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/integration/test_table.py b/integration/test_table.py index 0bd7500863d..00e6a8b22d1 100644 --- a/integration/test_table.py +++ b/integration/test_table.py @@ -26,7 +26,7 @@ class TableState(rx.State): caption: str = "random caption" - app = rx.App() + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/integration/test_upload.py b/integration/test_upload.py index c906271e85a..04e8f3205a8 100644 --- a/integration/test_upload.py +++ b/integration/test_upload.py @@ -113,7 +113,7 @@ def index(): ), ) - app = rx.App() + app = rx.App(state=rx.State) app.add_page(index) app.compile() diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py index 693547c6aca..374344a87ae 100644 --- a/integration/test_var_operations.py +++ b/integration/test_var_operations.py @@ -30,7 +30,7 @@ class VarOperationState(rx.State): dict2: dict = {3: 4} html_str: str = "
hello
" - app = rx.App() + app = rx.App(state=rx.State) @app.add_page def index(): diff --git a/reflex/app.py b/reflex/app.py index bff1549fa96..42d9389e2a5 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -153,10 +153,6 @@ def __init__(self, *args, **kwargs): state_subclasses = BaseState.__subclasses__() is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ - # For tests, make rx.State the state if not provided(Typically useful for app harness tests) - if is_testing_env and not self.state: - self.state = State - # Special case to allow test cases have multiple subclasses of rx.BaseState. if not is_testing_env: # Only one Base State class is allowed. diff --git a/tests/states/__init__.py b/tests/states/__init__.py index b825bbffdcd..11e891ab4ea 100644 --- a/tests/states/__init__.py +++ b/tests/states/__init__.py @@ -6,6 +6,7 @@ from .upload import ( ChildFileUploadState, FileStateBase1, + FileStateBase2, FileUploadState, GrandChildFileUploadState, SubUploadState, diff --git a/tests/states/upload.py b/tests/states/upload.py index fa0e086d519..8abe11c2435 100644 --- a/tests/states/upload.py +++ b/tests/states/upload.py @@ -3,7 +3,7 @@ from typing import ClassVar, List import reflex as rx -from reflex.state import BaseState +from reflex.state import BaseState, State class UploadState(BaseState): @@ -38,7 +38,7 @@ async def handle_upload(self, files: List[rx.UploadFile]): pass -class FileUploadState(BaseState): +class FileUploadState(State): """The base state for uploading a file.""" img_list: List[str] @@ -80,7 +80,7 @@ async def bg_upload(self, files: List[rx.UploadFile]): pass -class FileStateBase1(BaseState): +class FileStateBase1(State): """The base state for a child FileUploadState.""" pass diff --git a/tests/test_app.py b/tests/test_app.py index 01dace37bd7..233de1f15cd 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -37,6 +37,7 @@ from .states import ( ChildFileUploadState, FileStateBase1, + FileStateBase2, FileUploadState, GenState, GrandChildFileUploadState, @@ -699,18 +700,100 @@ async def test_dict_mutation_detection__plain_list( assert result.delta == expected_delta +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# ("state", "delta"), +# [ +# ( +# FileUploadState, +# {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}}, +# ), +# ( +# ChildFileUploadState, +# { +# "file_state_base1.child_file_upload_state": { +# "img_list": ["image1.jpg", "image2.jpg"] +# } +# }, +# ), +# ( +# GrandChildFileUploadState, +# { +# "file_state_base1.file_state_base2.grand_child_file_upload_state": { +# "img_list": ["image1.jpg", "image2.jpg"] +# } +# }, +# ), +# ], +# ) +# async def test_upload_file(tmp_path, state, delta, token: str): +# """Test that file upload works correctly. +# +# Args: +# tmp_path: Temporary path. +# state: The state class. +# delta: Expected delta +# token: a Token. +# """ +# state._tmp_path = tmp_path +# # The App state must be the "root" of the state tree +# app = App(state=state if state is FileUploadState else FileStateBase1) +# app.event_namespace.emit = AsyncMock() # type: ignore +# current_state = await app.state_manager.get_state(token) +# data = b"This is binary data" +# +# # Create a binary IO object and write data to it +# bio = io.BytesIO() +# bio.write(data) +# +# state_name = state.get_full_name().partition(".")[2] or state.get_name() +# request_mock = unittest.mock.Mock() +# request_mock.headers = { +# "reflex-client-token": token, +# "reflex-event-handler": f"{state_name}.multi_handle_upload", +# } +# +# file1 = UploadFile( +# filename=f"image1.jpg", +# file=bio, +# ) +# file2 = UploadFile( +# filename=f"image2.jpg", +# file=bio, +# ) +# upload_fn = upload(app) +# streaming_response = await upload_fn(request_mock, [file1, file2]) +# async for state_update in streaming_response.body_iterator: +# assert ( +# state_update +# == StateUpdate(delta=delta, events=[], final=True).json() + "\n" +# ) +# +# current_state = await app.state_manager.get_state(token) +# state_dict = current_state.dict() +# for substate in state.get_full_name().split(".")[1:]: +# state_dict = state_dict[substate] +# assert state_dict["img_list"] == [ +# "image1.jpg", +# "image2.jpg", +# ] +# +# if isinstance(app.state_manager, StateManagerRedis): +# await app.state_manager.redis.close() + + @pytest.mark.asyncio @pytest.mark.parametrize( ("state", "delta"), [ ( FileUploadState, - {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}}, + {"state.file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}}, ), ( ChildFileUploadState, { - "file_state_base1.child_file_upload_state": { + "state.file_state_base1.child_file_upload_state": { "img_list": ["image1.jpg", "image2.jpg"] } }, @@ -718,14 +801,14 @@ async def test_dict_mutation_detection__plain_list( ( GrandChildFileUploadState, { - "file_state_base1.file_state_base2.grand_child_file_upload_state": { + "state.file_state_base1.file_state_base2.grand_child_file_upload_state": { "img_list": ["image1.jpg", "image2.jpg"] } }, ), ], ) -async def test_upload_file(tmp_path, state, delta, token: str): +async def test_upload_file(tmp_path, state, delta, token: str, mocker): """Test that file upload works correctly. Args: @@ -734,9 +817,19 @@ async def test_upload_file(tmp_path, state, delta, token: str): delta: Expected delta token: a Token. """ + class_subclasses = { + State: {state if state is FileUploadState else FileStateBase1}, + FileUploadState : set(), + FileStateBase1: {ChildFileUploadState, FileStateBase2}, + FileStateBase2: {GrandChildFileUploadState}, + GrandChildFileUploadState: set(), + ChildFileUploadState: set(), + } + + mocker.patch("reflex.state.State.class_subclasses", class_subclasses) state._tmp_path = tmp_path # The App state must be the "root" of the state tree - app = App(state=state if state is FileUploadState else FileStateBase1) + app = App(state=State) app.event_namespace.emit = AsyncMock() # type: ignore current_state = await app.state_manager.get_state(token) data = b"This is binary data" @@ -778,7 +871,6 @@ async def test_upload_file(tmp_path, state, delta, token: str): if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.redis.close() - @pytest.mark.asyncio @pytest.mark.parametrize( "state", From b282f6b4df26c4bf020c5e04798747849b088e6d Mon Sep 17 00:00:00 2001 From: Elijah Date: Thu, 23 Nov 2023 23:55:49 +0000 Subject: [PATCH 12/24] fix upload integration test --- integration/test_upload.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integration/test_upload.py b/integration/test_upload.py index 04e8f3205a8..13fb28d36a2 100644 --- a/integration/test_upload.py +++ b/integration/test_upload.py @@ -205,8 +205,8 @@ async def get_file_data(): state = await upload_file.get_state(token) if secondary: # only the secondary form tracks progress and chain events - assert state.event_order.count("upload_progress") == 1 - assert state.event_order.count("chain_event") == 1 + assert state.substates["upload_state"].event_order.count("upload_progress") == 1 + assert state.substates["upload_state"].event_order.count("chain_event") == 1 @pytest.mark.asyncio @@ -349,7 +349,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive # look up the backend state and assert on progress state = await upload_file.get_state(token) - assert state.progress_dicts - assert exp_name not in state._file_data + assert state.substates["upload_state"].progress_dicts + assert exp_name not in state.substates["upload_state"]._file_data target_file.unlink() From 4324f22348dbc22f36ce42db98ccbcd7cb7aa9dc Mon Sep 17 00:00:00 2001 From: Elijah Date: Fri, 24 Nov 2023 00:20:01 +0000 Subject: [PATCH 13/24] fix all unit tests --- integration/test_radix_themes.py | 3 ++- reflex/state.py | 14 +++++++++++--- tests/middleware/test_hydrate_middleware.py | 7 +++++-- tests/test_app.py | 18 +++++++++--------- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/integration/test_radix_themes.py b/integration/test_radix_themes.py index 1f29aad20bd..d4731d20e1a 100644 --- a/integration/test_radix_themes.py +++ b/integration/test_radix_themes.py @@ -81,7 +81,8 @@ def index() -> rx.Component: ) app = rx.App( - state=rx.State, theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"), + state=rx.State, + theme=rdxt.theme(rdxt.theme_panel(), accent_color="grass"), ) app.add_page(index) app.compile() diff --git a/reflex/state.py b/reflex/state.py index 4877a615f99..70cce27f126 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -7,6 +7,7 @@ import functools import inspect import json +import os import traceback import urllib.parse import uuid @@ -285,6 +286,7 @@ def __init_subclass__(cls, **kwargs): Args: **kwargs: The kwargs to pass to the pydantic init_subclass method. """ + reload = os.getenv(constants.RELOAD_CONFIG) == "True" super().__init_subclass__(**kwargs) # Event handlers should not shadow builtin state methods. cls._check_overridden_methods() @@ -295,9 +297,15 @@ def __init_subclass__(cls, **kwargs): cls.inherited_vars = parent_state.vars cls.inherited_backend_vars = parent_state.backend_vars # fix up parent class_substates - cls.class_subclasses[parent_state].add(cls) if not any( - [c.__name__ == cls.__name__ for c in cls.class_subclasses[parent_state]] - ) else None + if not reload: + cls.class_subclasses[parent_state].add(cls) + else: + cls.class_subclasses[parent_state].add(cls) if not any( + [ + c.__name__ == cls.__name__ + for c in cls.class_subclasses[parent_state] + ] + ) else None cls.new_backend_vars = { name: value diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 956ed1fd40f..2f21557f0a5 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -2,11 +2,12 @@ import pytest +from reflex import constants from reflex.app import App from reflex.constants import CompileVars from reflex.middleware.hydrate_middleware import HydrateMiddleware from reflex.state import BaseState, StateUpdate -from reflex import constants + def exp_is_hydrated(state: BaseState) -> Dict[str, Any]: """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. @@ -97,7 +98,9 @@ async def test_preprocess( event_fixture: The event fixture(an Event). expected: Expected delta. """ - test_state.add_var(constants.CompileVars.IS_HYDRATED,type_=bool, default_value=False) + test_state.add_var( + constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False + ) app = App(state=test_state, load_events={"index": [test_state.test_handler]}) state = test_state() diff --git a/tests/test_app.py b/tests/test_app.py index 233de1f15cd..f626a5eeace 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -816,10 +816,11 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): state: The state class. delta: Expected delta token: a Token. + mocker: pytest mocker object. """ class_subclasses = { State: {state if state is FileUploadState else FileStateBase1}, - FileUploadState : set(), + FileUploadState: set(), FileStateBase1: {ChildFileUploadState, FileStateBase2}, FileStateBase2: {GrandChildFileUploadState}, GrandChildFileUploadState: set(), @@ -871,6 +872,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.redis.close() + @pytest.mark.asyncio @pytest.mark.parametrize( "state", @@ -983,10 +985,7 @@ def comp_dynamic(self) -> str: @pytest.mark.asyncio async def test_dynamic_route_var_route_change_completed_on_load( - index_page, - windows_platform: bool, - token: str, - mocker + index_page, windows_platform: bool, token: str, mocker ): """Create app with dynamic route var, and simulate navigation. @@ -997,12 +996,13 @@ async def test_dynamic_route_var_route_change_completed_on_load( index_page: The index page. windows_platform: Whether the system is windows. token: a Token. + mocker: pytest mocker object. """ - class_subclasses = { - State: {DynamicState} - } + class_subclasses = {State: {DynamicState}} mocker.patch("reflex.state.State.class_subclasses", class_subclasses) - DynamicState.add_var(constants.CompileVars.IS_HYDRATED,type_=bool, default_value=False) + DynamicState.add_var( + constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False + ) arg_name = "dynamic" route = f"/test/[{arg_name}]" if windows_platform: From e1983ea1e4657141ee3559246f86e99f8bd5bf0f Mon Sep 17 00:00:00 2001 From: Elijah Date: Fri, 24 Nov 2023 00:36:30 +0000 Subject: [PATCH 14/24] fix merge conflicts --- reflex/compiler/compiler.py | 1 - tests/components/test_component.py | 2 -- tests/test_var.py | 1 - 3 files changed, 4 deletions(-) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index e9eea12e373..1740af50acd 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -15,7 +15,6 @@ StatefulComponent, ) from reflex.config import get_config -from reflex.state import State from reflex.state import BaseState from reflex.utils.imports import ImportVar diff --git a/tests/components/test_component.py b/tests/components/test_component.py index a576f04842b..ac5b23130d6 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -14,8 +14,6 @@ from reflex.components.layout.box import Box from reflex.constants import EventTriggers from reflex.event import EventChain, EventHandler -from reflex.state import State -from reflex.event import EventHandler from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports diff --git a/tests/test_var.py b/tests/test_var.py index 787915a9662..901b40423f4 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -6,7 +6,6 @@ from pandas import DataFrame from reflex.base import Base -from reflex.state import State from reflex.state import BaseState from reflex.vars import ( BaseVar, From 79569ae4caed59e8bb3232162ad1ff0415aeba69 Mon Sep 17 00:00:00 2001 From: Elijah Date: Mon, 27 Nov 2023 15:18:01 +0000 Subject: [PATCH 15/24] fix redis app harness test --- integration/test_client_storage.py | 14 ++++++++++++-- reflex/state.py | 21 ++++++++++----------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 8c7db167d8f..47756d9a0b3 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -245,41 +245,51 @@ async def test_client_side_state( state_var_input.send_keys("c1") input_value_input.send_keys("c1 value") set_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("c2") input_value_input.send_keys("c2 value") set_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("c4") input_value_input.send_keys("c4 value") set_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("c5") input_value_input.send_keys("c5 value") set_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("c6") input_value_input.send_keys("c6 throwaway value") set_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("c6") input_value_input.send_keys("c6 value") set_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("c7") input_value_input.send_keys("c7 value") set_sub_state_button.click() - + time.sleep(0.5) state_var_input.send_keys("l1") input_value_input.send_keys("l1 value") set_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("l2") input_value_input.send_keys("l2 value") set_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("l3") input_value_input.send_keys("l3 value") set_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("l4") input_value_input.send_keys("l4 value") set_sub_state_button.click() - + time.sleep(0.5) state_var_input.send_keys("c1s") input_value_input.send_keys("c1s value") set_sub_sub_state_button.click() + time.sleep(0.5) state_var_input.send_keys("l1s") input_value_input.send_keys("l1s value") set_sub_sub_state_button.click() diff --git a/reflex/state.py b/reflex/state.py index 70cce27f126..b1f3c977390 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -7,7 +7,6 @@ import functools import inspect import json -import os import traceback import urllib.parse import uuid @@ -286,7 +285,6 @@ def __init_subclass__(cls, **kwargs): Args: **kwargs: The kwargs to pass to the pydantic init_subclass method. """ - reload = os.getenv(constants.RELOAD_CONFIG) == "True" super().__init_subclass__(**kwargs) # Event handlers should not shadow builtin state methods. cls._check_overridden_methods() @@ -297,15 +295,16 @@ def __init_subclass__(cls, **kwargs): cls.inherited_vars = parent_state.vars cls.inherited_backend_vars = parent_state.backend_vars # fix up parent class_substates - if not reload: - cls.class_subclasses[parent_state].add(cls) - else: - cls.class_subclasses[parent_state].add(cls) if not any( - [ - c.__name__ == cls.__name__ - for c in cls.class_subclasses[parent_state] - ] - ) else None + + cls.class_subclasses[parent_state] = set( + [ + x + for x in cls.class_subclasses[parent_state] + if x.__name__ != cls.__name__ + ] + ) + + cls.class_subclasses[parent_state].add(cls) cls.new_backend_vars = { name: value From e7a26abaea78fb359346654a1cfee4b5349dd278 Mon Sep 17 00:00:00 2001 From: Elijah Date: Mon, 27 Nov 2023 15:56:23 +0000 Subject: [PATCH 16/24] fix unit tests --- reflex/state.py | 29 ++++++++++++++++++++--------- tests/conftest.py | 21 --------------------- tests/test_state.py | 15 +++++++++++++-- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index b1f3c977390..33c2a442699 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -7,6 +7,7 @@ import functools import inspect import json +import os import traceback import urllib.parse import uuid @@ -223,21 +224,13 @@ def __init__(self, *args, parent_state: BaseState | None = None, **kwargs): parent_state: The parent state. **kwargs: The kwargs to pass to the Pydantic init method. - Raises: - ValueError: If a substate class shadows another. """ kwargs["parent_state"] = parent_state super().__init__(*args, **kwargs) # Setup the substates. for substate in self.get_substates(): - substate_name = substate.get_name() - if substate_name in self.substates: - raise ValueError( - f"The substate class '{substate_name}' has been defined multiple times. Shadowing " - f"substate classes is not allowed." - ) - self.substates[substate_name] = substate(parent_state=self) + self.substates[substate.get_name()] = substate(parent_state=self) # Convert the event handlers to functions. self._init_event_handlers() @@ -284,7 +277,11 @@ def __init_subclass__(cls, **kwargs): Args: **kwargs: The kwargs to pass to the pydantic init_subclass method. + + Raises: + ValueError: If a substate class shadows another. """ + is_testing_env = constants.PYTEST_CURRENT_TEST in os.environ super().__init_subclass__(**kwargs) # Event handlers should not shadow builtin state methods. cls._check_overridden_methods() @@ -295,6 +292,20 @@ def __init_subclass__(cls, **kwargs): cls.inherited_vars = parent_state.vars cls.inherited_backend_vars = parent_state.backend_vars # fix up parent class_substates + if ( + any( + [ + x + for x in cls.class_subclasses[parent_state] + if x.__name__ == cls.__name__ + ] + ) + and not is_testing_env + ): + raise ValueError( + f"The substate class '{cls.__name__}' has been defined multiple times. Shadowing " + f"substate classes is not allowed." + ) cls.class_subclasses[parent_state] = set( [ diff --git a/tests/conftest.py b/tests/conftest.py index a49b17a3df3..e5dddc47015 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,6 @@ from reflex.app import App from reflex.event import EventSpec -from reflex.state import BaseState from .states import ( DictMutationTestState, @@ -225,23 +224,3 @@ def token() -> str: A fresh/unique token string. """ return str(uuid.uuid4()) - - -@pytest.fixture -def duplicate_substate(): - """Create a Test state that has duplicate child substates. - - Returns: - The test state. - """ - - class TestState(BaseState): - pass - - class ChildTestState(TestState): # type: ignore # noqa - pass - - class ChildTestState(TestState): # type: ignore # noqa - pass - - return TestState diff --git a/tests/test_state.py b/tests/test_state.py index 421551f28c8..b4ca0c87a49 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -2187,9 +2187,20 @@ def test_mutable_copy_vars(mutable_state, copy_func): assert not isinstance(var_copy, MutableProxy) -def test_duplicate_substate_class(duplicate_substate): +def test_duplicate_substate_class(mocker): + mocker.patch("reflex.state.os.environ", {}) with pytest.raises(ValueError): - duplicate_substate() + + class TestState(BaseState): + pass + + class ChildTestState(TestState): # type: ignore # noqa + pass + + class ChildTestState(TestState): # type: ignore # noqa + pass + + return TestState class Foo(Base): From 1aee58aa034607d825c8a4d00b1e48dd60a89dce Mon Sep 17 00:00:00 2001 From: Elijah Date: Mon, 27 Nov 2023 16:15:54 +0000 Subject: [PATCH 17/24] add comments --- reflex/base.py | 2 ++ reflex/state.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/reflex/base.py b/reflex/base.py index e1d8dd298f7..9efea60e0ef 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -33,6 +33,8 @@ def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None ) from te +# monkeypatch pydantic validate_field_name method to skip validating +# shadowed state vars when reloading app via utils.prerequisites.get_app(reload=True) pydantic.main.validate_field_name = validate_field_name # type: ignore diff --git a/reflex/state.py b/reflex/state.py index 33c2a442699..351fa222329 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -291,13 +291,13 @@ def __init_subclass__(cls, **kwargs): if parent_state is not None: cls.inherited_vars = parent_state.vars cls.inherited_backend_vars = parent_state.backend_vars - # fix up parent class_substates + if ( any( [ - x - for x in cls.class_subclasses[parent_state] - if x.__name__ == cls.__name__ + c + for c in cls.class_subclasses[parent_state] + if c.__name__ == cls.__name__ ] ) and not is_testing_env @@ -306,15 +306,16 @@ def __init_subclass__(cls, **kwargs): f"The substate class '{cls.__name__}' has been defined multiple times. Shadowing " f"substate classes is not allowed." ) - + # clear all existing subclasses when app is reloaded via + # utils.prerequisites.get_app(reload=True) cls.class_subclasses[parent_state] = set( [ - x - for x in cls.class_subclasses[parent_state] - if x.__name__ != cls.__name__ + c + for c in cls.class_subclasses[parent_state] + if c.__name__ != cls.__name__ ] ) - + # fix up parent class_substates cls.class_subclasses[parent_state].add(cls) cls.new_backend_vars = { @@ -483,7 +484,6 @@ def get_substates(cls) -> set[Type[BaseState]]: Returns: The substates of the state. """ - # return set(cls.__subclasses__()) return cls.class_subclasses[cls] @classmethod From b6e465af88977ddcda7bd96d8af18f7979a9ddbe Mon Sep 17 00:00:00 2001 From: Elijah Date: Mon, 27 Nov 2023 16:42:43 +0000 Subject: [PATCH 18/24] remove debug lines --- tests/test_app.py | 82 ----------------------------------------------- 1 file changed, 82 deletions(-) diff --git a/tests/test_app.py b/tests/test_app.py index f626a5eeace..136e20b87ce 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -700,88 +700,6 @@ async def test_dict_mutation_detection__plain_list( assert result.delta == expected_delta -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# ("state", "delta"), -# [ -# ( -# FileUploadState, -# {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}}, -# ), -# ( -# ChildFileUploadState, -# { -# "file_state_base1.child_file_upload_state": { -# "img_list": ["image1.jpg", "image2.jpg"] -# } -# }, -# ), -# ( -# GrandChildFileUploadState, -# { -# "file_state_base1.file_state_base2.grand_child_file_upload_state": { -# "img_list": ["image1.jpg", "image2.jpg"] -# } -# }, -# ), -# ], -# ) -# async def test_upload_file(tmp_path, state, delta, token: str): -# """Test that file upload works correctly. -# -# Args: -# tmp_path: Temporary path. -# state: The state class. -# delta: Expected delta -# token: a Token. -# """ -# state._tmp_path = tmp_path -# # The App state must be the "root" of the state tree -# app = App(state=state if state is FileUploadState else FileStateBase1) -# app.event_namespace.emit = AsyncMock() # type: ignore -# current_state = await app.state_manager.get_state(token) -# data = b"This is binary data" -# -# # Create a binary IO object and write data to it -# bio = io.BytesIO() -# bio.write(data) -# -# state_name = state.get_full_name().partition(".")[2] or state.get_name() -# request_mock = unittest.mock.Mock() -# request_mock.headers = { -# "reflex-client-token": token, -# "reflex-event-handler": f"{state_name}.multi_handle_upload", -# } -# -# file1 = UploadFile( -# filename=f"image1.jpg", -# file=bio, -# ) -# file2 = UploadFile( -# filename=f"image2.jpg", -# file=bio, -# ) -# upload_fn = upload(app) -# streaming_response = await upload_fn(request_mock, [file1, file2]) -# async for state_update in streaming_response.body_iterator: -# assert ( -# state_update -# == StateUpdate(delta=delta, events=[], final=True).json() + "\n" -# ) -# -# current_state = await app.state_manager.get_state(token) -# state_dict = current_state.dict() -# for substate in state.get_full_name().split(".")[1:]: -# state_dict = state_dict[substate] -# assert state_dict["img_list"] == [ -# "image1.jpg", -# "image2.jpg", -# ] -# -# if isinstance(app.state_manager, StateManagerRedis): -# await app.state_manager.redis.close() - - @pytest.mark.asyncio @pytest.mark.parametrize( ("state", "delta"), From ca82705fd6c8b84a56b686b0fc3fab8a3846df3a Mon Sep 17 00:00:00 2001 From: Elijah Date: Mon, 27 Nov 2023 16:45:19 +0000 Subject: [PATCH 19/24] remove dead code --- tests/test_var.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_var.py b/tests/test_var.py index 901b40423f4..8e7358bf59d 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -24,12 +24,6 @@ ] -# class BaseState(State): -# """A Test State.""" -# -# val: str = "key" - - @pytest.fixture def TestObj(): class TestObj(Base): From 535c48b0bb7ce3b2125af619869a0014b183b038 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Tue, 28 Nov 2023 10:55:20 +0000 Subject: [PATCH 20/24] Update reflex/constants/base.py Co-authored-by: Masen Furer --- reflex/constants/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 1781aa0ce13..82cb39da8b8 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -173,4 +173,4 @@ class Ping(SimpleNamespace): # Testing variables. # Testing os env set by pytest when running a test case. PYTEST_CURRENT_TEST = "PYTEST_CURRENT_TEST" -RELOAD_CONFIG = "RELOAD_CONFIG" +RELOAD_CONFIG = "__REFLEX_RELOAD_CONFIG" From 7fcc0f2fea692fcb1aa959b512bd3a10c008c0a0 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Tue, 28 Nov 2023 10:55:36 +0000 Subject: [PATCH 21/24] Update reflex/state.py Co-authored-by: Masen Furer --- reflex/state.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 351fa222329..4d3858f07ef 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -293,13 +293,7 @@ def __init_subclass__(cls, **kwargs): cls.inherited_backend_vars = parent_state.backend_vars if ( - any( - [ - c - for c in cls.class_subclasses[parent_state] - if c.__name__ == cls.__name__ - ] - ) + cls.__name__ in set(c.__name__ for c in cls.class_subclasses[parent_state]) and not is_testing_env ): raise ValueError( From 0ba0c8cb38c9fe51914fea806b6e80357170e205 Mon Sep 17 00:00:00 2001 From: Elijah Date: Tue, 28 Nov 2023 11:11:49 +0000 Subject: [PATCH 22/24] fix precommit --- integration/test_client_storage.py | 12 ------------ reflex/state.py | 3 ++- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 47756d9a0b3..a9a311d8c16 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -245,51 +245,39 @@ async def test_client_side_state( state_var_input.send_keys("c1") input_value_input.send_keys("c1 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("c2") input_value_input.send_keys("c2 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("c4") input_value_input.send_keys("c4 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("c5") input_value_input.send_keys("c5 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("c6") input_value_input.send_keys("c6 throwaway value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("c6") input_value_input.send_keys("c6 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("c7") input_value_input.send_keys("c7 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("l1") input_value_input.send_keys("l1 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("l2") input_value_input.send_keys("l2 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("l3") input_value_input.send_keys("l3 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("l4") input_value_input.send_keys("l4 value") set_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("c1s") input_value_input.send_keys("c1s value") set_sub_sub_state_button.click() - time.sleep(0.5) state_var_input.send_keys("l1s") input_value_input.send_keys("l1s value") set_sub_sub_state_button.click() diff --git a/reflex/state.py b/reflex/state.py index 4d3858f07ef..c993be8bd00 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -293,7 +293,8 @@ def __init_subclass__(cls, **kwargs): cls.inherited_backend_vars = parent_state.backend_vars if ( - cls.__name__ in set(c.__name__ for c in cls.class_subclasses[parent_state]) + cls.__name__ + in set(c.__name__ for c in cls.class_subclasses[parent_state]) and not is_testing_env ): raise ValueError( From c36a6dfedaf31180c72c5ae4679bfdc06fc9a3cb Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 28 Nov 2023 10:00:10 -0800 Subject: [PATCH 23/24] Only track direct subclasses of each BaseState class --- reflex/state.py | 28 +++++++++------------------- tests/test_app.py | 18 +++++------------- 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index c993be8bd00..39104c14577 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -19,7 +19,6 @@ AsyncIterator, Callable, ClassVar, - DefaultDict, Dict, List, Optional, @@ -178,10 +177,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The event handlers. event_handlers: ClassVar[Dict[str, EventHandler]] = {} - # A mapping of classes and corresponding subclassses. - class_subclasses: ClassVar[ - DefaultDict[Type[BaseState], Set[Type[BaseState]]] - ] = defaultdict(set) + # A set of subclassses of this class. + class_subclasses: ClassVar[Set[Type[BaseState]]] = set() # Mapping of var name to set of computed variables that depend on it _computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {} @@ -286,6 +283,9 @@ def __init_subclass__(cls, **kwargs): # Event handlers should not shadow builtin state methods. cls._check_overridden_methods() + # Reset subclass tracking for this class. + cls.class_subclasses = set() + # Get the parent vars. parent_state = cls.get_parent_state() if parent_state is not None: @@ -293,25 +293,15 @@ def __init_subclass__(cls, **kwargs): cls.inherited_backend_vars = parent_state.backend_vars if ( - cls.__name__ - in set(c.__name__ for c in cls.class_subclasses[parent_state]) + cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses) and not is_testing_env ): raise ValueError( f"The substate class '{cls.__name__}' has been defined multiple times. Shadowing " f"substate classes is not allowed." ) - # clear all existing subclasses when app is reloaded via - # utils.prerequisites.get_app(reload=True) - cls.class_subclasses[parent_state] = set( - [ - c - for c in cls.class_subclasses[parent_state] - if c.__name__ != cls.__name__ - ] - ) - # fix up parent class_substates - cls.class_subclasses[parent_state].add(cls) + # Track this new subclass in the parent state's subclasses set. + parent_state.class_subclasses.add(cls) cls.new_backend_vars = { name: value @@ -479,7 +469,7 @@ def get_substates(cls) -> set[Type[BaseState]]: Returns: The substates of the state. """ - return cls.class_subclasses[cls] + return cls.class_subclasses @classmethod @functools.lru_cache() diff --git a/tests/test_app.py b/tests/test_app.py index 136e20b87ce..8b35edbfb5c 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -37,7 +37,6 @@ from .states import ( ChildFileUploadState, FileStateBase1, - FileStateBase2, FileUploadState, GenState, GrandChildFileUploadState, @@ -736,16 +735,10 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): token: a Token. mocker: pytest mocker object. """ - class_subclasses = { - State: {state if state is FileUploadState else FileStateBase1}, - FileUploadState: set(), - FileStateBase1: {ChildFileUploadState, FileStateBase2}, - FileStateBase2: {GrandChildFileUploadState}, - GrandChildFileUploadState: set(), - ChildFileUploadState: set(), - } - - mocker.patch("reflex.state.State.class_subclasses", class_subclasses) + mocker.patch( + "reflex.state.State.class_subclasses", + {state if state is FileUploadState else FileStateBase1}, + ) state._tmp_path = tmp_path # The App state must be the "root" of the state tree app = App(state=State) @@ -916,8 +909,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( token: a Token. mocker: pytest mocker object. """ - class_subclasses = {State: {DynamicState}} - mocker.patch("reflex.state.State.class_subclasses", class_subclasses) + mocker.patch("reflex.state.State.class_subclasses", {DynamicState}) DynamicState.add_var( constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False ) From 68d83527e70480f9c565f673c8192588e91a5aec Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 28 Nov 2023 22:16:37 -0800 Subject: [PATCH 24/24] state: remove duplicate subclass names when testing When running AppHarness or unit tests via pytest, delete duplicate state names (and basically redefine the state tree with the new class). In a previous patch i removed this bit of code (https://github.com/reflex-dev/reflex/pull/2227/files#diff-704efe440192781e072231490ac1d12defd1aba5f6debafbc7907eebcc2abcc7L304-L314). But it turns out to be necessary. So I restructured the duplicate name check a bit to indicate that this subclass resetting should only occur in testing modes, and expand a bit on why we're doing it that way. --- reflex/state.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 39104c14577..f6b10e84964 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -292,14 +292,23 @@ def __init_subclass__(cls, **kwargs): cls.inherited_vars = parent_state.vars cls.inherited_backend_vars = parent_state.backend_vars - if ( - cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses) - and not is_testing_env - ): - raise ValueError( - f"The substate class '{cls.__name__}' has been defined multiple times. Shadowing " - f"substate classes is not allowed." - ) + # Check if another substate class with the same name has already been defined. + if cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses): + if is_testing_env: + # Clear existing subclass with same name when app is reloaded via + # utils.prerequisites.get_app(reload=True) + parent_state.class_subclasses = set( + c + for c in parent_state.class_subclasses + if c.__name__ != cls.__name__ + ) + else: + # During normal operation, subclasses cannot have the same name, even if they are + # defined in different modules. + raise ValueError( + f"The substate class '{cls.__name__}' has been defined multiple times. " + "Shadowing substate classes is not allowed." + ) # Track this new subclass in the parent state's subclasses set. parent_state.class_subclasses.add(cls)