Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF-889] useContext per substate #2149

Merged
merged 36 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
27cd95c
Each substate has its own context
masenf Nov 7, 2023
be6fa14
Fix delta checking in test_app
masenf Nov 7, 2023
139c01a
Use _var_data throughout
masenf Nov 7, 2023
9c81f8a
Account for all imports and hooks where they are used
masenf Nov 7, 2023
e8de0e9
Fixup hydrate middleware and _var_state
masenf Nov 7, 2023
27f527c
Fixup tests
masenf Nov 7, 2023
da6a367
Do not always create new vars in style.py
masenf Nov 7, 2023
c9ebed4
Fixup change from State.dict returning flat structure
masenf Nov 8, 2023
aa574e5
Cond var must be _var_is_local=False
masenf Nov 8, 2023
846fe51
Component: do not mutate event_triggers
masenf Nov 8, 2023
b9960d3
Fixup static issues
masenf Nov 8, 2023
4b63a8b
Merge remote-tracking branch 'origin/main' into masenf/context-per-su…
masenf Nov 9, 2023
5157ed2
client_side_routing: use `rx.cond` instead of the `Cond` component
masenf Nov 13, 2023
06d8fa2
Fixup color_mode_toggle var
masenf Nov 13, 2023
4afd4e8
Merge remote-tracking branch 'origin/main' into masenf/context-per-su…
masenf Nov 13, 2023
34c3c2c
Merge remote-tracking branch 'origin/main' into masenf/context-per-su…
masenf Nov 14, 2023
a29ca2e
Merge remote-tracking branch 'origin/main' into masenf/context-per-su…
masenf Nov 14, 2023
78ddc79
Add tests for Component._get_vars
masenf Nov 15, 2023
62b9332
Apply CR feedback (from myself)
masenf Nov 15, 2023
454226f
Merge remote-tracking branch 'origin/main' into masenf/context-per-su…
masenf Nov 15, 2023
a356bcb
ensure 'from reflex.var import ImportVar' keeps working
masenf Nov 15, 2023
9329e18
Merge remote-tracking branch 'origin/main' into masenf/context-per-su…
masenf Nov 19, 2023
90159b0
cond: Carry VarData before `format_cond`
masenf Nov 19, 2023
1ee5702
upload: uploadFiles is not actually a required import
masenf Nov 19, 2023
1ef7dee
Update pyi files
masenf Nov 19, 2023
006ca66
Move Bare special case into Bare class
masenf Nov 21, 2023
3c54c66
use the | operator instead of union
masenf Nov 21, 2023
475694b
remove weird indentation
masenf Nov 21, 2023
0f36629
cond: remove unnecessary `elif`
masenf Nov 21, 2023
96b6090
remove relative imports
masenf Nov 21, 2023
70cb517
style: improve comments and explainations
masenf Nov 21, 2023
569ff86
component: add comments and give the code some space to breathe
masenf Nov 21, 2023
00d137e
Remove unused NoRenderImportVar
masenf Nov 21, 2023
bf385fb
client_side_routing.pyi: remove relative imports
masenf Nov 21, 2023
2230d10
Merge remote-tracking branch 'origin/main' into masenf/context-per-su…
masenf Nov 21, 2023
1d78788
Performance optimizations
masenf Nov 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 21 additions & 0 deletions integration/test_var_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class VarOperationState(rx.State):
str_var4: str = "a long string"
dict1: dict = {1: 2}
dict2: dict = {3: 4}
html_str: str = "<div>hello</div>"

app = rx.App(state=VarOperationState)

Expand Down Expand Up @@ -522,6 +523,19 @@ def index():
rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
rx.text(VarOperationState.list3.join(""), id="list_join"),
rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
# Index from an op var
rx.text(
VarOperationState.list3[VarOperationState.int_var1 % 3],
id="list_index_mod",
),
rx.html(
VarOperationState.html_str,
id="html_str",
),
rx.highlight(
"second",
query=[VarOperationState.str_var2],
),
rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"),
rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"),
rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"),
Expand Down Expand Up @@ -713,7 +727,14 @@ def test_var_operations(driver, var_operations: AppHarness):
("dict_eq_dict", "false"),
("dict_neq_dict", "true"),
("dict_contains", "true"),
# index from an op var
("list_index_mod", "second"),
# html component with var
("html_str", "hello"),
]

for tag, expected in tests:
assert driver.find_element(By.ID, tag).text == expected

# Highlight component with var query (does not plumb ID)
assert driver.find_element(By.TAG_NAME, "mark").text == "second"
12 changes: 7 additions & 5 deletions reflex/.templates/jinja/web/pages/_app.js.jinja2
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{% extends "web/pages/base_page.js.jinja2" %}

{% block declaration %}
import { EventLoopProvider } from "/utils/context.js";
import { EventLoopProvider, StateProvider } from "/utils/context.js";
import { ThemeProvider } from 'next-themes'

{% for custom_code in custom_codes %}
Expand All @@ -25,12 +25,14 @@ export default function MyApp({ Component, pageProps }) {
return (
<ThemeProvider defaultTheme="light" storageKey="chakra-ui-color-mode" attribute="class">
<AppWrap>
<EventLoopProvider>
<Component {...pageProps} />
</EventLoopProvider>
<StateProvider>
<EventLoopProvider>
<Component {...pageProps} />
</EventLoopProvider>
</StateProvider>
</AppWrap>
</ThemeProvider>
);
}

{% endblock %}
{% endblock %}
26 changes: 0 additions & 26 deletions reflex/.templates/jinja/web/pages/index.js.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,6 @@

{% block export %}
export default function Component() {
{% if state_name %}
const {{state_name}} = useContext(StateContext)
{% endif %}
const {{const.router}} = useRouter()
const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext)
const focusRef = useRef();

// Main event loop.
const [addEvents, connectError] = useContext(EventLoopContext)

// Set focus to the specified element.
useEffect(() => {
if (focusRef.current) {
focusRef.current.focus();
}
})

// Route after the initial page hydration.
useEffect(() => {
const change_complete = () => addEvents(initialEvents())
{{const.router}}.events.on('routeChangeComplete', change_complete)
return () => {
{{const.router}}.events.off('routeChangeComplete', change_complete)
}
}, [{{const.router}}])

{% for hook in hooks %}
{{ hook }}
{% endfor %}
Expand Down
47 changes: 38 additions & 9 deletions reflex/.templates/jinja/web/utils/context.js.jinja2
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { createContext, useState } from "react"
import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"
import { createContext, useContext, useMemo, useReducer, useState } from "react"
import { applyDelta, Event, hydrateClientStorage, useEventLoop } from "/utils/state.js"

{% if initial_state %}
export const initialState = {{ initial_state|json_dumps }}
Expand All @@ -8,7 +8,12 @@ export const initialState = {}
{% endif %}

export const ColorModeContext = createContext(null);
export const StateContext = createContext(null);
export const DispatchContext = createContext(null);
export const StateContexts = {
{% for state_name in initial_state %}
{{state_name|var_name}}: createContext(null),
{% endfor %}
}
export const EventLoopContext = createContext(null);
{% if client_storage %}
export const clientStorage = {{ client_storage|json_dumps }}
Expand All @@ -27,16 +32,40 @@ export const initialEvents = () => []
export const isDevMode = {{ is_dev_mode|json_dumps }}

export function EventLoopProvider({ children }) {
const [state, addEvents, connectError] = useEventLoop(
initialState,
const dispatch = useContext(DispatchContext)
const [addEvents, connectError] = useEventLoop(
dispatch,
initialEvents,
clientStorage,
)
return (
<EventLoopContext.Provider value={[addEvents, connectError]}>
<StateContext.Provider value={state}>
{children}
</StateContext.Provider>
{children}
</EventLoopContext.Provider>
)
}
}

export function StateProvider({ children }) {
{% for state_name in initial_state %}
const [{{state_name|var_name}}, dispatch_{{state_name|var_name}}] = useReducer(applyDelta, initialState["{{state_name}}"])
{% endfor %}
const dispatchers = useMemo(() => {
return {
{% for state_name in initial_state %}
"{{state_name}}": dispatch_{{state_name|var_name}},
{% endfor %}
}
}, [])

return (
{% for state_name in initial_state %}
<StateContexts.{{state_name|var_name}}.Provider value={ {{state_name|var_name}} }>
{% endfor %}
<DispatchContext.Provider value={dispatchers}>
{children}
</DispatchContext.Provider>
{% for state_name in initial_state|reverse %}
</StateContexts.{{state_name|var_name}}.Provider>
{% endfor %}
)
}
59 changes: 21 additions & 38 deletions reflex/.templates/web/utils/state.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import env from "env.json";
import Cookies from "universal-cookie";
import { useEffect, useReducer, useRef, useState } from "react";
import Router, { useRouter } from "next/router";
import { initialEvents } from "utils/context.js"
import { initialEvents, initialState } from "utils/context.js"

// Endpoint URLs.
const EVENTURL = env.EVENT
Expand Down Expand Up @@ -100,37 +100,10 @@ export const getEventURL = () => {
* @param delta The delta to apply.
*/
export const applyDelta = (state, delta) => {
const new_state = { ...state }
for (const substate in delta) {
let s = new_state;
const path = substate.split(".").slice(1);
while (path.length > 0) {
s = s[path.shift()];
}
for (const key in delta[substate]) {
s[key] = delta[substate][key];
}
}
return new_state
return { ...state, ...delta }
};


/**
* Get all local storage items in a key-value object.
* @returns object of items in local storage.
*/
export const getAllLocalStorageItems = () => {
var localStorageItems = {};

for (var i = 0, len = localStorage.length; i < len; i++) {
var key = localStorage.key(i);
localStorageItems[key] = localStorage.getItem(key);
}

return localStorageItems;
}


/**
* Handle frontend event or send the event to the backend via Websocket.
* @param event The event to send.
Expand Down Expand Up @@ -346,7 +319,9 @@ export const connect = async (
// On each received message, queue the updates and events.
socket.current.on("event", message => {
const update = JSON5.parse(message)
dispatch(update.delta)
for (const substate in update.delta) {
dispatch[substate](update.delta[substate])
}
applyClientStorageDelta(client_storage, update.delta)
event_processing = !update.final
if (update.events) {
Expand Down Expand Up @@ -524,23 +499,21 @@ const applyClientStorageDelta = (client_storage, delta) => {

/**
* Establish websocket event loop for a NextJS page.
* @param initial_state The initial app state.
* @param initial_events Function that returns the initial app events.
* @param dispatch The reducer dispatch function to update state.
* @param initial_events The initial app events.
* @param client_storage The client storage object from context.js
*
* @returns [state, addEvents, connectError] -
* state is a reactive dict,
* @returns [addEvents, connectError] -
* addEvents is used to queue an event, and
* connectError is a reactive js error from the websocket connection (or null if connected).
*/
export const useEventLoop = (
initial_state = {},
dispatch,
initial_events = () => [],
client_storage = {},
) => {
const socket = useRef(null)
const router = useRouter()
const [state, dispatch] = useReducer(applyDelta, initial_state)
const [connectError, setConnectError] = useState(null)

// Function to add new events to the event queue.
Expand Down Expand Up @@ -570,7 +543,7 @@ export const useEventLoop = (
return;
}
// only use websockets if state is present
if (Object.keys(state).length > 0) {
if (Object.keys(initialState).length > 0) {
// Initialize the websocket connection.
if (!socket.current) {
connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage)
Expand All @@ -583,7 +556,17 @@ export const useEventLoop = (
})()
}
})
return [state, addEvents, connectError]

// Route after the initial page hydration.
useEffect(() => {
const change_complete = () => addEvents(initial_events())
router.events.on('routeChangeComplete', change_complete)
return () => {
router.events.off('routeChangeComplete', change_complete)
}
}, [router])

return [addEvents, connectError]
}

/***
Expand Down
2 changes: 1 addition & 1 deletion reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
StateUpdate,
)
from reflex.utils import console, format, prerequisites, types
from reflex.vars import ImportVar
from reflex.utils.imports import ImportVar

# Define custom types.
ComponentCallable = Callable[[], Component]
Expand Down
34 changes: 2 additions & 32 deletions reflex/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,10 @@
from reflex.components.component import Component, ComponentStyle, CustomComponent
from reflex.config import get_config
from reflex.state import State
from reflex.utils import imports
from reflex.vars import ImportVar
from reflex.utils.imports import ImportDict, ImportVar

# Imports to be included in every Reflex app.
DEFAULT_IMPORTS: imports.ImportDict = {
"react": {
ImportVar(tag="Fragment"),
ImportVar(tag="useEffect"),
ImportVar(tag="useRef"),
ImportVar(tag="useState"),
ImportVar(tag="useContext"),
},
"next/router": {ImportVar(tag="useRouter")},
f"/{constants.Dirs.STATE_PATH}": {
ImportVar(tag="uploadFiles"),
ImportVar(tag="Event"),
ImportVar(tag="isTrue"),
ImportVar(tag="spreadArraysOrObjects"),
ImportVar(tag="preventDefault"),
ImportVar(tag="refs"),
ImportVar(tag="getRefValue"),
ImportVar(tag="getRefValues"),
ImportVar(tag="getAllLocalStorageItems"),
ImportVar(tag="useEventLoop"),
},
"/utils/context.js": {
ImportVar(tag="EventLoopContext"),
ImportVar(tag="initialEvents"),
ImportVar(tag="StateContext"),
ImportVar(tag="ColorModeContext"),
},
"/utils/helpers/range.js": {
ImportVar(tag="range", is_default=True),
},
DEFAULT_IMPORTS: ImportDict = {
picklelo marked this conversation as resolved.
Show resolved Hide resolved
"": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)},
}

Expand Down
3 changes: 2 additions & 1 deletion reflex/compiler/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from jinja2 import Environment, FileSystemLoader, Template

from reflex import constants
from reflex.utils.format import json_dumps
from reflex.utils.format import format_state_name, json_dumps


class ReflexJinjaEnvironment(Environment):
Expand All @@ -19,6 +19,7 @@ def __init__(self) -> None:
)
self.filters["json_dumps"] = json_dumps
self.filters["react_setter"] = lambda state: f"set{state.capitalize()}"
self.filters["var_name"] = format_state_name
self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE)
self.globals["const"] = {
"socket": constants.CompileVars.SOCKET,
Expand Down
7 changes: 4 additions & 3 deletions reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@
from reflex.state import Cookie, LocalStorage, State
from reflex.style import Style
from reflex.utils import console, format, imports, path_ops
from reflex.vars import ImportVar

# To re-export this function.
merge_imports = imports.merge_imports


def compile_import_statement(fields: set[ImportVar]) -> tuple[str, set[str]]:
def compile_import_statement(fields: set[imports.ImportVar]) -> tuple[str, set[str]]:
"""Compile an import statement.

Args:
Expand Down Expand Up @@ -343,7 +342,9 @@ def get_context_path() -> str:
Returns:
The path of the context module.
"""
return os.path.join(constants.Dirs.WEB_UTILS, "context" + constants.Ext.JS)
return os.path.join(
constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS
)


def get_components_path() -> str:
Expand Down
6 changes: 5 additions & 1 deletion reflex/components/base/bare.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ def create(cls, contents: Any) -> Component:
Returns:
The component.
"""
return cls(contents=str(contents)) # type: ignore
if isinstance(contents, Var) and contents._var_data:
contents = contents.to(str)
else:
contents = str(contents)
return cls(contents=contents) # type: ignore

def _render(self) -> Tag:
return Tagless(contents=str(self.contents))