Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF-1885] Shard Substates when serializing to Redis #2574

Merged
merged 14 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integration/test_client_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def set_sub_sub(var: str, value: str):
assert l1s.text == "l1s value"

# reset the backend state to force refresh from client storage
async with client_side.modify_state(token) as state:
async with client_side.modify_state(f"{token}_state.client_side_state") as state:
state.reset()
driver.refresh()

Expand Down
11 changes: 8 additions & 3 deletions integration/test_dynamic_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def dynamic_route(
"""
with app_harness_env.create(
root=tmp_path_factory.mktemp(f"dynamic_route"),
app_name=f"dynamicroute_{app_harness_env.__name__.lower()}",
app_source=DynamicRoute, # type: ignore
) as harness:
yield harness
Expand Down Expand Up @@ -146,7 +147,7 @@ def poll_for_order(

async def _poll_for_order(exp_order: list[str]):
async def _backend_state():
return await dynamic_route.get_state(token)
return await dynamic_route.get_state(f"{token}_state.dynamic_state")

async def _check():
return (await _backend_state()).substates[
Expand Down Expand Up @@ -194,7 +195,9 @@ async def test_on_load_navigate(
assert link
assert page_id_input

assert dynamic_route.poll_for_value(page_id_input) == str(ix)
assert dynamic_route.poll_for_value(
page_id_input, exp_not_equal=str(ix - 1)
) == str(ix)
assert dynamic_route.poll_for_value(raw_path_input) == f"/page/{ix}/"
await poll_for_order(exp_order)

Expand All @@ -220,7 +223,9 @@ async def test_on_load_navigate(
with poll_for_navigation(driver):
driver.get(f"{driver.current_url}?foo=bar")
await poll_for_order(exp_order)
assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar"
assert (
await dynamic_route.get_state(f"{token}_state.dynamic_state")
).router.page.params["foo"] == "bar"

# hit a 404 and ensure we still hydrate
exp_order += ["/404-no page id"]
Expand Down
2 changes: 1 addition & 1 deletion integration/test_event_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def poll_for_order(

async def _poll_for_order(exp_order: list[str]):
async def _backend_state():
return await event_action.get_state(token)
return await event_action.get_state(f"{token}_state.event_action_state")

async def _check():
return (await _backend_state()).substates[
Expand Down
2 changes: 1 addition & 1 deletion integration/test_event_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def assert_token(event_chain: AppHarness, driver: WebDriver) -> str:
token = event_chain.poll_for_value(token_input)
assert token is not None

return token
return f"{token}_state.state"


@pytest.mark.parametrize(
Expand Down
6 changes: 5 additions & 1 deletion integration/test_form_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,11 @@ async def test_submit(driver, form_submit: AppHarness):
submit_input.click()

async def get_form_data():
return (await form_submit.get_state(token)).substates["form_state"].form_data
return (
(await form_submit.get_state(f"{token}_state.form_state"))
.substates["form_state"]
.form_data
)

# wait for the form data to arrive at the backend
form_data = await AppHarness._poll_for_async(get_form_data)
Expand Down
20 changes: 9 additions & 11 deletions integration/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
token = fully_controlled_input.poll_for_value(token_input)
assert token

async def get_state_text():
state = await fully_controlled_input.get_state(f"{token}_state.state")
return state.substates["state"].text

# find the input and wait for it to have the initial state value
debounce_input = driver.find_element(By.ID, "debounce_input_input")
value_input = driver.find_element(By.ID, "value_input")
Expand All @@ -95,16 +99,14 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("foo")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "ifoonitial"
assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "ifoonitial"
assert await get_state_text() == "ifoonitial"
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
assert fully_controlled_input.poll_for_value(plain_value_input) == "ifoonitial"

# clear the input on the backend
async with fully_controlled_input.modify_state(token) as state:
async with fully_controlled_input.modify_state(f"{token}_state.state") as state:
state.substates["state"].text = ""
assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
assert await get_state_text() == ""
assert (
fully_controlled_input.poll_for_value(
debounce_input, exp_not_equal="ifoonitial"
Expand All @@ -116,9 +118,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("getting testing done")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "getting testing done"
assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "getting testing done"
assert await get_state_text() == "getting testing done"
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
assert (
fully_controlled_input.poll_for_value(plain_value_input)
Expand All @@ -130,9 +130,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "overwrite the state"
assert on_change_input.get_attribute("value") == "overwrite the state"
assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "overwrite the state"
assert await get_state_text() == "overwrite the state"
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
assert (
fully_controlled_input.poll_for_value(plain_value_input)
Expand Down
19 changes: 15 additions & 4 deletions integration/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ async def test_upload_file(
# wait for the backend connection to send the token
token = upload_file.poll_for_value(token_input)
assert token is not None
substate_token = f"{token}_state.upload_state"

suffix = "_secondary" if secondary else ""

Expand All @@ -191,7 +192,11 @@ async def test_upload_file(

# look up the backend state and assert on uploaded contents
async def get_file_data():
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
return (
(await upload_file.get_state(substate_token))
.substates["upload_state"]
._file_data
)

file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
Expand All @@ -201,7 +206,7 @@ async def get_file_data():
selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
assert selected_files.text == exp_name

state = await upload_file.get_state(token)
state = await upload_file.get_state(substate_token)
if secondary:
# only the secondary form tracks progress and chain events
assert state.substates["upload_state"].event_order.count("upload_progress") == 1
Expand All @@ -223,6 +228,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
# wait for the backend connection to send the token
token = upload_file.poll_for_value(token_input)
assert token is not None
substate_token = f"{token}_state.upload_state"

upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
assert upload_box
Expand Down Expand Up @@ -250,7 +256,11 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):

# look up the backend state and assert on uploaded contents
async def get_file_data():
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
return (
(await upload_file.get_state(substate_token))
.substates["upload_state"]
._file_data
)

file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
Expand Down Expand Up @@ -330,6 +340,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
# wait for the backend connection to send the token
token = upload_file.poll_for_value(token_input)
assert token is not None
substate_token = f"{token}_state.upload_state"

upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1]
upload_button = driver.find_element(By.ID, f"upload_button_secondary")
Expand All @@ -347,7 +358,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
cancel_button.click()

# look up the backend state and assert on progress
state = await upload_file.get_state(token)
state = await upload_file.get_state(substate_token)
assert state.substates["upload_state"].progress_dicts
assert exp_name not in state.substates["upload_state"]._file_data

Expand Down
7 changes: 4 additions & 3 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ async def process(
}
)
# Get the state for the session exclusively.
async with app.state_manager.modify_state(event.token) as state:
async with app.state_manager.modify_state(event.substate_token) as state:
# re-assign only when the value is different
if state.router_data != router_data:
# assignment will recurse into substates and force recalculation of
Expand Down Expand Up @@ -1002,7 +1002,8 @@ async def upload_file(request: Request, files: List[UploadFile]):
)

# Get the state for the session.
state = await app.state_manager.get_state(token)
substate_token = token + "_" + handler.rpartition(".")[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe we can pull this out into a function get_substater_token(token, handler)

state = await app.state_manager.get_state(substate_token)

# get the current session ID
# get the current state(parent state/substate)
Expand Down Expand Up @@ -1049,7 +1050,7 @@ async def _ndjson_updates():
Each state update as JSON followed by a new line.
"""
# Process the event.
async with app.state_manager.modify_state(token) as state:
async with app.state_manager.modify_state(event.substate_token) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app.postprocess(state, event, update)
Expand Down
10 changes: 10 additions & 0 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ class Event(Base):
# The event payload.
payload: Dict[str, Any] = {}

@property
def substate_token(self) -> str:
"""Get the substate token for the event.

Returns:
The substate token.
"""
substate = self.name.rpartition(".")[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we pull it out, then we can reuse the function here

return f"{self.token}_{substate}"


BACKGROUND_TASK_MARKER = "_reflex_background_task"

Expand Down
Loading
Loading