Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 9 additions & 19 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
AsyncIterator,
Callable,
ClassVar,
DefaultDict,
Dict,
List,
Optional,
Expand Down Expand Up @@ -178,10 +177,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {}

# A mapping of classes and corresponding subclassses.
class_subclasses: ClassVar[
DefaultDict[Type[BaseState], Set[Type[BaseState]]]
] = defaultdict(set)
# A set of subclassses of this class.
class_subclasses: ClassVar[Set[Type[BaseState]]] = set()

# Mapping of var name to set of computed variables that depend on it
_computed_var_dependencies: ClassVar[Dict[str, Set[str]]] = {}
Expand Down Expand Up @@ -286,32 +283,25 @@ def __init_subclass__(cls, **kwargs):
# Event handlers should not shadow builtin state methods.
cls._check_overridden_methods()

# Reset subclass tracking for this class.
cls.class_subclasses = set()

# Get the parent vars.
parent_state = cls.get_parent_state()
if parent_state is not None:
cls.inherited_vars = parent_state.vars
cls.inherited_backend_vars = parent_state.backend_vars

if (
cls.__name__
in set(c.__name__ for c in cls.class_subclasses[parent_state])
cls.__name__ in set(c.__name__ for c in parent_state.class_subclasses)
and not is_testing_env
):
raise ValueError(
f"The substate class '{cls.__name__}' has been defined multiple times. Shadowing "
f"substate classes is not allowed."
)
# clear all existing subclasses when app is reloaded via
# utils.prerequisites.get_app(reload=True)
cls.class_subclasses[parent_state] = set(
[
c
for c in cls.class_subclasses[parent_state]
if c.__name__ != cls.__name__
]
)
# fix up parent class_substates
cls.class_subclasses[parent_state].add(cls)
# Track this new subclass in the parent state's subclasses set.
parent_state.class_subclasses.add(cls)

cls.new_backend_vars = {
name: value
Expand Down Expand Up @@ -479,7 +469,7 @@ def get_substates(cls) -> set[Type[BaseState]]:
Returns:
The substates of the state.
"""
return cls.class_subclasses[cls]
return cls.class_subclasses

@classmethod
@functools.lru_cache()
Expand Down
18 changes: 5 additions & 13 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from .states import (
ChildFileUploadState,
FileStateBase1,
FileStateBase2,
FileUploadState,
GenState,
GrandChildFileUploadState,
Expand Down Expand Up @@ -736,16 +735,10 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
token: a Token.
mocker: pytest mocker object.
"""
class_subclasses = {
State: {state if state is FileUploadState else FileStateBase1},
FileUploadState: set(),
FileStateBase1: {ChildFileUploadState, FileStateBase2},
FileStateBase2: {GrandChildFileUploadState},
GrandChildFileUploadState: set(),
ChildFileUploadState: set(),
}

mocker.patch("reflex.state.State.class_subclasses", class_subclasses)
mocker.patch(
"reflex.state.State.class_subclasses",
{state if state is FileUploadState else FileStateBase1},
)
state._tmp_path = tmp_path
# The App state must be the "root" of the state tree
app = App(state=State)
Expand Down Expand Up @@ -916,8 +909,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
token: a Token.
mocker: pytest mocker object.
"""
class_subclasses = {State: {DynamicState}}
mocker.patch("reflex.state.State.class_subclasses", class_subclasses)
mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
DynamicState.add_var(
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
)
Expand Down