From 27cd95c0a9d7c2957fbde8bd721754e5be1368c8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 7 Nov 2023 09:49:49 -0800 Subject: [PATCH 01/29] Each substate has its own context Var has _var_hooks and _var_imports Instead of relying on the page/component having the appropriate imports statically, allow these to be dynamically calculated for greater flexibility --- integration/test_var_operations.py | 21 ++ .../.templates/jinja/web/pages/_app.js.jinja2 | 12 +- .../jinja/web/pages/index.js.jinja2 | 17 - .../jinja/web/utils/context.js.jinja2 | 47 ++- reflex/.templates/web/utils/state.js | 43 ++- reflex/compiler/compiler.py | 8 - reflex/compiler/templates.py | 3 +- reflex/compiler/utils.py | 4 +- reflex/components/base/bare.py | 6 +- reflex/components/component.py | 98 +++++- reflex/components/datadisplay/datatable.py | 6 +- reflex/components/layout/cond.py | 3 +- reflex/components/layout/html.py | 4 +- reflex/constants/base.py | 2 + reflex/state.py | 14 +- reflex/style.py | 70 +++- reflex/utils/format.py | 15 + reflex/vars.py | 325 ++++++++++++++---- reflex/vars.pyi | 3 +- tests/test_style.py | 3 +- tests/test_var.py | 60 +++- 21 files changed, 593 insertions(+), 171 deletions(-) diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py index 4514527312..4934acf001 100644 --- a/integration/test_var_operations.py +++ b/integration/test_var_operations.py @@ -28,6 +28,7 @@ class VarOperationState(rx.State): str_var4: str = "a long string" dict1: dict = {1: 2} dict2: dict = {3: 4} + html_str: str = "
hello
" app = rx.App(state=VarOperationState) @@ -522,6 +523,19 @@ def index(): rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"), rx.text(VarOperationState.list3.join(""), id="list_join"), rx.text(VarOperationState.list3.join(","), id="list_join_comma"), + # Index from an op var + rx.text( + VarOperationState.list3[VarOperationState.int_var1 % 3], + id="list_index_mod", + ), + rx.html( + VarOperationState.html_str, + id="html_str", + ), + rx.highlight( + "second", + query=[VarOperationState.str_var2], + ), ) app.compile() @@ -705,7 +719,14 @@ def test_var_operations(driver, var_operations: AppHarness): ("dict_eq_dict", "false"), ("dict_neq_dict", "true"), ("dict_contains", "true"), + # index from an op var + ("list_index_mod", "second"), + # html component with var + ("html_str", "hello"), ] for tag, expected in tests: assert driver.find_element(By.ID, tag).text == expected + + # Highlight component with var query (does not plumb ID) + assert driver.find_element(By.TAG_NAME, "mark").text == "second" diff --git a/reflex/.templates/jinja/web/pages/_app.js.jinja2 b/reflex/.templates/jinja/web/pages/_app.js.jinja2 index 4d3dff89ad..deaf1a02be 100644 --- a/reflex/.templates/jinja/web/pages/_app.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/_app.js.jinja2 @@ -1,7 +1,7 @@ {% extends "web/pages/base_page.js.jinja2" %} {% block declaration %} -import { EventLoopProvider } from "/utils/context.js"; +import { EventLoopProvider, StateProvider } from "/utils/context.js"; import { ThemeProvider } from 'next-themes' {% for custom_code in custom_codes %} @@ -25,12 +25,14 @@ export default function MyApp({ Component, pageProps }) { return ( - - - + + + + + ); } -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2 index 6f73c70c4a..56323d5a73 100644 --- a/reflex/.templates/jinja/web/pages/index.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/index.js.jinja2 @@ -8,15 +8,7 @@ {% block export %} export default function Component() { -{% if state_name %} - const {{state_name}} = useContext(StateContext) -{% endif %} - const {{const.router}} = useRouter() - const [ {{const.color_mode}}, {{const.toggle_color_mode}} ] = useContext(ColorModeContext) const focusRef = useRef(); - - // Main event loop. - const [addEvents, connectError] = useContext(EventLoopContext) // Set focus to the specified element. useEffect(() => { @@ -25,15 +17,6 @@ export default function Component() { } }) - // Route after the initial page hydration. - useEffect(() => { - const change_complete = () => addEvents(initialEvents()) - {{const.router}}.events.on('routeChangeComplete', change_complete) - return () => { - {{const.router}}.events.off('routeChangeComplete', change_complete) - } - }, [{{const.router}}]) - {% for hook in hooks %} {{ hook }} {% endfor %} diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2 index c931b75154..53d7d4e58d 100644 --- a/reflex/.templates/jinja/web/utils/context.js.jinja2 +++ b/reflex/.templates/jinja/web/utils/context.js.jinja2 @@ -1,5 +1,5 @@ -import { createContext, useState } from "react" -import { Event, hydrateClientStorage, useEventLoop } from "/utils/state.js" +import { createContext, useContext, useMemo, useReducer, useState } from "react" +import { applyDelta, Event, hydrateClientStorage, useEventLoop } from "/utils/state.js" {% if initial_state %} export const initialState = {{ initial_state|json_dumps }} @@ -8,7 +8,12 @@ export const initialState = {} {% endif %} export const ColorModeContext = createContext(null); -export const StateContext = createContext(null); +export const DispatchContext = createContext(null); +export const StateContexts = { + {% for state_name in initial_state %} + {{state_name|var_name}}: createContext(null), + {% endfor %} +} export const EventLoopContext = createContext(null); {% if client_storage %} export const clientStorage = {{ client_storage|json_dumps }} @@ -27,16 +32,40 @@ export const initialEvents = () => [] export const isDevMode = {{ is_dev_mode|json_dumps }} export function EventLoopProvider({ children }) { - const [state, addEvents, connectError] = useEventLoop( - initialState, + const dispatch = useContext(DispatchContext) + const [addEvents, connectError] = useEventLoop( + dispatch, initialEvents, clientStorage, ) return ( - - {children} - + {children} ) -} \ No newline at end of file +} + +export function StateProvider({ children }) { + {% for state_name in initial_state %} + const [{{state_name|var_name}}, dispatch_{{state_name|var_name}}] = useReducer(applyDelta, initialState["{{state_name}}"]) + {% endfor %} + const dispatchers = useMemo(() => { + return { + {% for state_name in initial_state %} + "{{state_name}}": dispatch_{{state_name|var_name}}, + {% endfor %} + } + }, []) + + return ( + {% for state_name in initial_state %} + + {% endfor %} + + {children} + + {% for state_name in initial_state|reverse %} + + {% endfor %} + ) +} diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 6ffe72ded4..1c7410a2af 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -6,7 +6,7 @@ import env from "env.json"; import Cookies from "universal-cookie"; import { useEffect, useReducer, useRef, useState } from "react"; import Router, { useRouter } from "next/router"; -import { initialEvents } from "utils/context.js" +import { initialEvents, initialState } from "utils/context.js" // Endpoint URLs. const EVENTURL = env.EVENT @@ -95,18 +95,7 @@ export const getEventURL = () => { * @param delta The delta to apply. */ export const applyDelta = (state, delta) => { - const new_state = { ...state } - for (const substate in delta) { - let s = new_state; - const path = substate.split(".").slice(1); - while (path.length > 0) { - s = s[path.shift()]; - } - for (const key in delta[substate]) { - s[key] = delta[substate][key]; - } - } - return new_state + return { ...state, ...delta } }; @@ -333,7 +322,9 @@ export const connect = async ( // On each received message, queue the updates and events. socket.current.on("event", message => { const update = JSON5.parse(message) - dispatch(update.delta) + for (const substate in update.delta) { + dispatch[substate](update.delta[substate]) + } applyClientStorageDelta(client_storage, update.delta) event_processing = !update.final if (update.events) { @@ -475,23 +466,21 @@ const applyClientStorageDelta = (client_storage, delta) => { /** * Establish websocket event loop for a NextJS page. - * @param initial_state The initial app state. - * @param initial_events Function that returns the initial app events. + * @param dispatch The reducer dispatch function to update state. + * @param initial_events The initial app events. * @param client_storage The client storage object from context.js * - * @returns [state, addEvents, connectError] - - * state is a reactive dict, + * @returns [addEvents, connectError] - * addEvents is used to queue an event, and * connectError is a reactive js error from the websocket connection (or null if connected). */ export const useEventLoop = ( - initial_state = {}, + dispatch, initial_events = () => [], client_storage = {}, ) => { const socket = useRef(null) const router = useRouter() - const [state, dispatch] = useReducer(applyDelta, initial_state) const [connectError, setConnectError] = useState(null) // Function to add new events to the event queue. @@ -521,7 +510,7 @@ export const useEventLoop = ( return; } // only use websockets if state is present - if (Object.keys(state).length > 0) { + if (Object.keys(initialState).length > 0) { // Initialize the websocket connection. if (!socket.current) { connect(socket, dispatch, ['websocket', 'polling'], setConnectError, client_storage) @@ -534,7 +523,17 @@ export const useEventLoop = ( })() } }) - return [state, addEvents, connectError] + + // Route after the initial page hydration. + useEffect(() => { + const change_complete = () => addEvents(initialEvents()) + router.events.on('routeChangeComplete', change_complete) + return () => { + router.events.off('routeChangeComplete', change_complete) + } + }, [router]) + + return [addEvents, connectError] } /*** diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 7ee5c6ded1..8e3f87e695 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -25,7 +25,6 @@ "next/router": {ImportVar(tag="useRouter")}, f"/{constants.Dirs.STATE_PATH}": { ImportVar(tag="uploadFiles"), - ImportVar(tag="Event"), ImportVar(tag="isTrue"), ImportVar(tag="spreadArraysOrObjects"), ImportVar(tag="preventDefault"), @@ -33,13 +32,6 @@ ImportVar(tag="getRefValue"), ImportVar(tag="getRefValues"), ImportVar(tag="getAllLocalStorageItems"), - ImportVar(tag="useEventLoop"), - }, - "/utils/context.js": { - ImportVar(tag="EventLoopContext"), - ImportVar(tag="initialEvents"), - ImportVar(tag="StateContext"), - ImportVar(tag="ColorModeContext"), }, "": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)}, } diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index f2d1272aa5..57bbb44b8c 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -3,7 +3,7 @@ from jinja2 import Environment, FileSystemLoader, Template from reflex import constants -from reflex.utils.format import json_dumps +from reflex.utils.format import format_state_name, json_dumps class ReflexJinjaEnvironment(Environment): @@ -19,6 +19,7 @@ def __init__(self) -> None: ) self.filters["json_dumps"] = json_dumps self.filters["react_setter"] = lambda state: f"set{state.capitalize()}" + self.filters["var_name"] = format_state_name self.loader = FileSystemLoader(constants.Templates.Dirs.JINJA_TEMPLATE) self.globals["const"] = { "socket": constants.CompileVars.SOCKET, diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index dc7c384a4a..b07412812a 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -343,7 +343,9 @@ def get_context_path() -> str: Returns: The path of the context module. """ - return os.path.join(constants.Dirs.WEB_UTILS, "context" + constants.Ext.JS) + return os.path.join( + constants.Dirs.WEB, constants.Dirs.CONTEXTS_PATH + constants.Ext.JS + ) def get_components_path() -> str: diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 190e95e6a0..00ba0a81de 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -24,7 +24,11 @@ def create(cls, contents: Any) -> Component: Returns: The component. """ - return cls(contents=str(contents)) # type: ignore + if isinstance(contents, Var) and (contents._var_imports or contents._var_hooks): + contents = contents.to(str) + else: + contents = str(contents) + return cls(contents=contents) # type: ignore def _render(self) -> Tag: return Tagless(contents=str(self.contents)) diff --git a/reflex/components/component.py b/reflex/components/component.py index 2c9b7e6ccc..7f224cad71 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -5,7 +5,7 @@ import typing from abc import ABC from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Set, Type, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type, Union from reflex.base import Base from reflex.components.tags import Tag @@ -553,6 +553,56 @@ def validate_valid_child(child_name): if self.valid_children: validate_valid_child(name) + @staticmethod + def _get_vars_from_event_triggers( + event_triggers: dict[str, EventChain | Var], + ) -> Iterator[str, list[Var]]: + """Get the Vars associated with each event trigger. + + Args: + event_triggers: The event triggers from the component instance. + + Yields: + tuple of (event_name, event_vars) + """ + for event_trigger, event in event_triggers.items(): + if isinstance(event, Var): + yield event_trigger, [event] + elif isinstance(event, EventChain): + event_args = [] + for spec in event.events: + for args in spec.args: + event_args.extend(args) + yield event_trigger, event_args + + def _get_vars(self) -> Iterator[Var]: + """Walk all Vars used in this component. + + Yields: + Each var referenced by the component (props, styles, event handlers). + """ + from reflex.components.base.bare import Bare + + if isinstance(self, Bare): + if isinstance(self.contents, Var): + yield self.contents + else: + for _, vars in self._get_vars_from_event_triggers(self.event_triggers): + yield from vars + + for prop in self.get_props(): + prop_var = getattr(self, prop) + if isinstance(prop_var, Var): + yield prop_var + + if self.style: + yield BaseVar( + _var_name="style", + _var_type=str, + _var_imports=self.style._var_imports, + _var_hooks=self.style._var_hooks, + ) + def _get_custom_code(self) -> str | None: """Get custom code for the component. @@ -644,11 +694,21 @@ def _get_imports(self) -> imports.ImportDict: _imports = {} if self.library is not None and self.tag is not None: _imports[self.library] = {self.import_var} - + event_imports = {} + if self.event_triggers: + event_imports = { + f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")}, + f"/{Dirs.STATE_PATH}": {ImportVar(tag="Event")}, + "react": {ImportVar(tag="useContext")}, + } + # determine imports from Vars + var_imports = [var._var_imports for var in self._get_vars()] return imports.merge_imports( self._get_props_imports(), self._get_dependencies_imports(), _imports, + event_imports, + *var_imports, ) def get_imports(self) -> imports.ImportDict: @@ -694,6 +754,28 @@ def _get_ref_hook(self) -> str | None: if ref is not None: return f"const {ref} = useRef(null); refs['{ref}'] = {ref};" + def _get_vars_hooks(self) -> set[str]: + """Get the hooks required by vars referenced in this component. + + Returns: + The hooks for the vars. + """ + vars_hooks = set() + for var in self._get_vars(): + vars_hooks.update(var._var_hooks) + return vars_hooks + + def _get_events_hooks(self) -> str[str]: + """Get the hooks required by events referenced in this component. + + Returns: + The hooks for the events. + """ + # TODO: use constants here for better indirection + if self.event_triggers: + return {"const [addEvents, connectError] = useContext(EventLoopContext);"} + return set() + def _get_hooks_internal(self) -> Set[str]: """Get the React hooks for this component managed by the framework. @@ -703,10 +785,14 @@ def _get_hooks_internal(self) -> Set[str]: Returns: Set of internally managed hooks. """ - return set( - hook - for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()] - if hook + return ( + set( + hook + for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()] + if hook + ) + .union(self._get_vars_hooks()) + .union(self._get_events_hooks()) ) def _get_hooks(self) -> str | None: diff --git a/reflex/components/datadisplay/datatable.py b/reflex/components/datadisplay/datatable.py index 52bd45282a..d9fe49f30e 100644 --- a/reflex/components/datadisplay/datatable.py +++ b/reflex/components/datadisplay/datatable.py @@ -113,13 +113,11 @@ def _render(self) -> Tag: self.columns = BaseVar( _var_name=f"{self.data._var_name}.columns", _var_type=List[Any], - _var_state=self.data._var_state, - ) + )._var_set_state(self.data._var_state) self.data = BaseVar( _var_name=f"{self.data._var_name}.data", _var_type=List[List[Any]], - _var_state=self.data._var_state, - ) + )._var_set_state(self.data._var_state) if types.is_dataframe(type(self.data)): # If given a pandas df break up the data and columns data = serialize(self.data) diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py index b584b0f0cb..4439902dea 100644 --- a/reflex/components/layout/cond.py +++ b/reflex/components/layout/cond.py @@ -117,7 +117,7 @@ def cond(condition: Any, c1: Any, c2: Any = None): raise ValueError("For conditional vars, the second argument must be set.") # Create the conditional var. - return BaseVar( + return cond_var._replace( _var_name=format.format_cond( cond=cond_var._var_full_name, true_value=c1, @@ -125,4 +125,5 @@ def cond(condition: Any, c1: Any, c2: Any = None): is_prop=True, ), _var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1), + _var_full_name_needs_state_prefix=False, ) diff --git a/reflex/components/layout/html.py b/reflex/components/layout/html.py index 893e2ecca2..df155ffde0 100644 --- a/reflex/components/layout/html.py +++ b/reflex/components/layout/html.py @@ -1,8 +1,8 @@ """A html component.""" -from typing import Any from reflex.components.layout.box import Box +from reflex.vars import Var class Html(Box): @@ -13,7 +13,7 @@ class Html(Box): """ # The HTML to render. - dangerouslySetInnerHTML: Any + dangerouslySetInnerHTML: Var[dict[str, str]] @classmethod def create(cls, *children, **props): diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 8957edfb77..0b28e18cb2 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -29,6 +29,8 @@ class Dirs(SimpleNamespace): STATE_PATH = "/".join([UTILS, "state"]) # The name of the components file. COMPONENTS_PATH = "/".join([UTILS, "components"]) + # The name of the contexts file. + CONTEXTS_PATH = "/".join([UTILS, "context"]) # The directory where the app pages are compiled to. WEB_PAGES = os.path.join(WEB, "pages") # The directory where the static build is located. diff --git a/reflex/state.py b/reflex/state.py index 9462c901b2..73b3638920 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1206,12 +1206,16 @@ def dict(self, include_computed: bool = True, **kwargs) -> dict[str, Any]: if include_computed else {} ) - substate_vars = { - k: v.dict(include_computed=include_computed, **kwargs) - for k, v in self.substates.items() + variables = {**base_vars, **computed_vars} + d = { + self.get_full_name(): {k: variables[k] for k in sorted(variables)}, } - variables = {**base_vars, **computed_vars, **substate_vars} - return {k: variables[k] for k in sorted(variables)} + for substate_d in [ + v.dict(include_computed=include_computed, **kwargs) + for v in self.substates.values() + ]: + d.update(substate_d) + return d async def __aenter__(self) -> State: """Enter the async context manager protocol. diff --git a/reflex/style.py b/reflex/style.py index bb8320163b..b2f4c67943 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -2,13 +2,31 @@ from __future__ import annotations +from typing import Any + from reflex import constants from reflex.event import EventChain -from reflex.utils import format -from reflex.vars import BaseVar, Var +from reflex.utils import format, imports +from reflex.vars import BaseVar, ImportVar, Var -color_mode = BaseVar(_var_name=constants.ColorMode.NAME, _var_type="str") -toggle_color_mode = BaseVar(_var_name=constants.ColorMode.TOGGLE, _var_type=EventChain) +color_mode_imports = { + f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, +} +color_mode_hooks = { + f"const [ {{{constants.ColorMode.NAME}}}, {{{constants.ColorMode.TOGGLE}}} ] = useContext(ColorModeContext)", +} +color_mode = BaseVar( + _var_name=constants.ColorMode.NAME, + _var_type="str", + _var_imports=color_mode_imports, + _var_hooks=color_mode_hooks, +) +toggle_color_mode = BaseVar( + _var_name=constants.ColorMode.TOGGLE, + _var_type=EventChain, + _var_imports=color_mode_imports, + _var_hooks=color_mode_hooks, +) def convert(style_dict): @@ -20,16 +38,19 @@ def convert(style_dict): Returns: The formatted style dictionary. """ + var_data = Var.create("") out = {} for key, value in style_dict.items(): key = format.to_camel_case(key) if isinstance(value, dict): - out[key] = convert(value) - elif isinstance(value, Var): - out[key] = str(value) + out[key], new_var_data = convert(value) else: - out[key] = value - return out + new_var_data = Var.create(value, _var_is_string=True) + out[key] = str(new_var_data) + var_data = var_data._replace( + add_imports=new_var_data._var_imports, add_hooks=new_var_data._var_hooks + ) + return out, var_data class Style(dict): @@ -41,5 +62,32 @@ def __init__(self, style_dict: dict | None = None): Args: style_dict: The style dictionary. """ - style_dict = style_dict or {} - super().__init__(convert(style_dict)) + style_dict, var_data = convert(style_dict or {}) + self._var_imports = var_data._var_imports + self._var_hooks = var_data._var_hooks + super().__init__(style_dict) + + def update(self, style_dict: dict | None = None): + """Update the style. + + Args: + style_dict: The style dictionary. + """ + converted_dict = type(self)(style_dict) + self._var_imports = imports.merge_imports( + self._var_imports, converted_dict._var_imports + ) + self._var_hooks.update(converted_dict._var_hooks) + super().update(converted_dict) + + def __setitem__(self, key: str, value: Any): + """Set an item in the style. + + Args: + key: The key to set. + value: The value to set. + """ + _var = Var.create(value) + self._var_imports = imports.merge_imports(self._var_imports, _var._var_imports) + self._var_hooks.update(_var._var_hooks) + super().__setitem__(key, value) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 7d18460b5e..10ac78d266 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -520,6 +520,21 @@ def format_state(value: Any) -> Any: raise TypeError(f"No JSON serializer found for var {value} of type {type(value)}.") +def format_state_name(state_name: str) -> str: + """Format a state name, replacing dots with double underscore. + + This allows individual substates to be accessed independently as javascript vars + without using dot notation. + + Args: + state_name: The state name to format. + + Returns: + The formatted state name. + """ + return state_name.replace(".", "__") + + def format_ref(ref: str) -> str: """Format a ref. diff --git a/reflex/vars.py b/reflex/vars.py index e5a7357e75..6083bf067a 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -6,6 +6,7 @@ import dis import json import random +import re import string import sys from types import CodeType, FunctionType @@ -14,6 +15,7 @@ Any, Callable, Dict, + Iterable, List, Literal, Optional, @@ -92,6 +94,122 @@ def get_unique_variable_name() -> str: return get_unique_variable_name() +def _merge_imports( + *import_dicts: dict[str, set[ImportVar]] +) -> dict[str, set[ImportVar]]: + """Merge multiple import dicts. + + This is the same as reflex.utils.imports.merge_imports, except + that causes a circular import, so define this helper to avoid + inline imports all over this module. + + Args: + *import_dicts: The import dicts to merge. + + Returns: + The merged import dict. + """ + from reflex.utils import imports + + return imports.merge_imports(*import_dicts) + + +def _encode_var(value: Var) -> str: + """Encode the state name into a formatted var. + + Args: + value: The value to encode the state name into. + + Returns: + The encoded var. + """ + if value._var_imports or value._var_hooks: + data = { + "_var_imports": { + lib: [iv.dict() for iv in imports] + for lib, imports in value._var_imports.items() + }, + "_var_hooks": list(value._var_hooks), + } + return f"{json.dumps(data)}" + str(value) + return str(value) + + +def _decode_var(value: str) -> tuple[list[str], str]: + """Decode the state name from a formatted var. + + Args: + value: The value to extract the state name from. + + Returns: + The extracted state name and the value without the state name. + """ + _var_data = {} + if isinstance(value, str): + # Extract the state name from a formatted var + while m := re.match(r"(.*)(.*)(.*)", value): + value = m.group(1) + m.group(3) + _var_data = json.loads(m.group(2)) + _var_data["_var_hooks"] = set(_var_data["_var_hooks"]) + _var_data["_var_imports"] = { + lib: set(ImportVar(**iv_json) for iv_json in ivs_json) + for lib, ivs_json in _var_data["_var_imports"].items() + } + return _var_data, value + + +def _merge_var_data(*var_datas: dict[str, Any]) -> dict[str, Any]: + """Merge multiple var data dicts. + + Args: + *var_datas: The var data dicts to merge. + + Returns: + The merged var data dict. + """ + _var_imports = {} + _var_hooks = set() + for _var_data in var_datas: + _var_imports = _merge_imports(_var_imports, _var_data.get("_var_imports", {})) + _var_hooks.update(_var_data.get("_var_hooks", set())) + return {"_var_imports": _var_imports, "_var_hooks": _var_hooks} + + +def _extract_var_data(value: Iterable) -> list[str]: + """Extract the var imports and hooks from an iterable containing a Var. + + Args: + value: The iterable to extract the state name from. + + Returns: + The extracted state name. + """ + _var_data = {} + for sub in value: + if isinstance(sub, Var): + _var_data = _merge_var_data( + _var_data, + { + "_var_imports": sub._var_imports, + "_var_hooks": sub._var_hooks, + }, + ) + elif not isinstance(sub, str): + # Recurse into dict values + if hasattr(sub, "values") and callable(sub.values): + _var_data = _merge_var_data( + _var_data, + _extract_var_data(sub.values()), + ) + # Recurse into iterable values (or dict keys) + with contextlib.suppress(TypeError): + _var_data = _merge_var_data( + _var_data, + _extract_var_data(sub), + ) + return _var_data + + class Var: """An abstract var.""" @@ -110,6 +228,15 @@ class Var: # Whether the var is a string literal. _var_is_string: bool + # _var_full_name should be prefixed with _var_state + _var_full_name_needs_state_prefix: bool + + # Imports needed to render this var + _var_imports: dict[str, set[ImportVar]] + + # All substates that this var depends on + _var_hooks: set[str] + @classmethod def create( cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False @@ -135,6 +262,12 @@ def create( if isinstance(value, Var): return value + # Try to pull the imports and hooks from contained values. + _var_data = {} + if not isinstance(value, str): + with contextlib.suppress(TypeError): + _var_data = _extract_var_data(value) + # Try to serialize the value. type_ = type(value) name = serializers.serialize(value) @@ -149,6 +282,8 @@ def create( _var_type=type_, _var_is_local=_var_is_local, _var_is_string=_var_is_string, + _var_imports=_var_data.get("_var_imports", {}), + _var_hooks=_var_data.get("_var_hooks", set()), ) @classmethod @@ -185,6 +320,41 @@ def __class_getitem__(cls, type_: str) -> _GenericAlias: """ return _GenericAlias(cls, type_) + def __post_init__(self) -> None: + """Post-initialize the var.""" + # Decode any inline Var markup and apply it to the instance + _var_data, _var_name = _decode_var(self._var_name) + if _var_data: + self._var_name = _var_name + self._var_hooks.update(_var_data.get("_var_hooks", set())) + self._var_imports = _merge_imports( + self._var_imports, + _var_data.get("_var_imports", {}), + ) + + def _replace(self, add_imports=None, add_hooks=None, **kwargs: Any) -> Var: + # Cannot use dataclasses.replace because ComputedVar uses multiple inheritance + # and it's __init__ has a required fget argument + _var_imports = _merge_imports( + kwargs.pop("_var_imports", self._var_imports), + add_imports or {}, + ) + _var_hooks = kwargs.pop("_var_hooks", self._var_hooks).union(add_hooks or set()) + field_values = dict( + _var_name=kwargs.pop("_var_name", self._var_name), + _var_type=kwargs.pop("_var_type", self._var_type), + _var_state=kwargs.pop("_var_state", self._var_state), + _var_is_local=kwargs.pop("_var_is_local", self._var_is_local), + _var_is_string=kwargs.pop("_var_is_string", self._var_is_string), + _var_full_name_needs_state_prefix=kwargs.pop( + "_var_full_name_needs_state_prefix", + self._var_full_name_needs_state_prefix, + ), + _var_imports=_var_imports, + _var_hooks=_var_hooks, + ) + return BaseVar(**field_values) + def _decode(self) -> Any: """Decode Var as a python value. @@ -217,6 +387,10 @@ def equals(self, other: Var) -> bool: and self._var_type == other._var_type and self._var_state == other._var_state and self._var_is_local == other._var_is_local + and self._var_full_name_needs_state_prefix + == other._var_full_name_needs_state_prefix + and self._var_imports == other._var_imports + and self._var_hooks == other._var_hooks ) def to_string(self, json: bool = True) -> Var: @@ -284,9 +458,11 @@ def __format__(self, format_spec: str) -> str: Returns: The formatted var. """ + # Encode the _var_imports and _var_hooks into the formatted output for tracking purposes. + str_self = _encode_var(self) if self._var_is_local: - return str(self) - return f"${str(self)}" + return str_self + return f"${str_self}" def __getitem__(self, i: Any) -> Var: """Index into a var. @@ -319,12 +495,7 @@ def __getitem__(self, i: Any) -> Var: # Convert any vars to local vars. if isinstance(i, Var): - i = BaseVar( - _var_name=i._var_name, - _var_type=i._var_type, - _var_state=i._var_state, - _var_is_local=True, - ) + i = i._replace(_var_is_local=True) # Handle list/tuple/str indexing. if types._issubclass(self._var_type, Union[List, Tuple, str]): @@ -343,11 +514,9 @@ def __getitem__(self, i: Any) -> Var: stop = i.stop or "undefined" # Use the slice function. - return BaseVar( + return self._replace( _var_name=f"{self._var_name}.slice({start}, {stop})", - _var_type=self._var_type, - _var_state=self._var_state, - _var_is_local=self._var_is_local, + _var_is_string=False, ) # Get the type of the indexed var. @@ -358,11 +527,10 @@ def __getitem__(self, i: Any) -> Var: ) # Use `at` to support negative indices. - return BaseVar( + return self._replace( _var_name=f"{self._var_name}.at({i})", _var_type=type_, - _var_state=self._var_state, - _var_is_local=self._var_is_local, + _var_is_string=False, ) # Dictionary / dataframe indexing. @@ -392,11 +560,10 @@ def __getitem__(self, i: Any) -> Var: ) # Use normal indexing here. - return BaseVar( + return self._replace( _var_name=f"{self._var_name}[{i}]", _var_type=type_, - _var_state=self._var_state, - _var_is_local=self._var_is_local, + _var_is_string=False, ) def __getattr__(self, name: str) -> Var: @@ -422,11 +589,10 @@ def __getattr__(self, name: str) -> Var: type_ = types.get_attribute_access_type(self._var_type, name) if type_ is not None: - return BaseVar( + return self._replace( _var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}", _var_type=type_, - _var_state=self._var_state, - _var_is_local=self._var_is_local, + _var_is_string=False, ) if name in REPLACED_NAMES: @@ -518,10 +684,13 @@ def operation( else f"{self._var_full_name}.{fn}()" ) - return BaseVar( + return self._replace( _var_name=operation_name, _var_type=type_, - _var_is_local=self._var_is_local, + _var_is_string=False, + _var_full_name_needs_state_prefix=False, + add_imports=other._var_imports if other is not None else None, + add_hooks=other._var_hooks if other is not None else None, ) @staticmethod @@ -601,10 +770,10 @@ def length(self) -> Var: """ if not types._issubclass(self._var_type, List): raise TypeError(f"Cannot get length of non-list var {self}.") - return BaseVar( - _var_name=f"{self._var_full_name}.length", + return self._replace( + _var_name=f"{self._var_name}.length", _var_type=int, - _var_is_local=self._var_is_local, + _var_is_string=False, ) def __eq__(self, other: Var) -> Var: @@ -754,10 +923,11 @@ def __mul__(self, other: Var, flip=True) -> Var: ]: other_name = other._var_full_name if isinstance(other, Var) else other name = f"Array({other_name}).fill().map(() => {self._var_full_name}).flat()" - return BaseVar( + return self._replace( _var_name=name, _var_type=str, - _var_is_local=self._var_is_local, + _var_is_string=False, + _var_full_name_needs_state_prefix=False, ) return self.operation("*", other) @@ -1002,10 +1172,12 @@ def contains(self, other: Any) -> Var: elif not isinstance(other, Var): other = Var.create(other) if types._issubclass(self._var_type, Dict): - return BaseVar( - _var_name=f"{self._var_full_name}.{method}({other._var_full_name})", + return self._replace( + _var_name=f"{self._var_name}.{method}({other._var_full_name})", _var_type=bool, - _var_is_local=self._var_is_local, + _var_is_string=False, + add_imports=other._var_imports, + add_hooks=other._var_hooks, ) else: # str, list, tuple # For strings, the left operand must be a string. @@ -1015,10 +1187,12 @@ def contains(self, other: Any) -> Var: raise TypeError( f"'in ' requires string as left operand, not {other._var_type}" ) - return BaseVar( - _var_name=f"{self._var_full_name}.includes({other._var_full_name})", + return self._replace( + _var_name=f"{self._var_name}.includes({other._var_full_name})", _var_type=bool, - _var_is_local=self._var_is_local, + _var_is_string=False, + add_imports=other._var_imports, + add_hooks=other._var_hooks, ) def reverse(self) -> Var: @@ -1033,10 +1207,10 @@ def reverse(self) -> Var: if not types._issubclass(self._var_type, list): raise TypeError(f"Cannot reverse non-list var {self._var_full_name}.") - return BaseVar( + return self._replace( _var_name=f"[...{self._var_full_name}].reverse()", - _var_type=self._var_type, - _var_is_local=self._var_is_local, + _var_is_string=False, + _var_full_name_needs_state_prefix=False, ) def lower(self) -> Var: @@ -1053,10 +1227,10 @@ def lower(self) -> Var: f"Cannot convert non-string var {self._var_full_name} to lowercase." ) - return BaseVar( - _var_name=f"{self._var_full_name}.toLowerCase()", + return self._replace( + _var_name=f"{self._var_name}.toLowerCase()", + _var_is_string=False, _var_type=str, - _var_is_local=self._var_is_local, ) def upper(self) -> Var: @@ -1073,10 +1247,10 @@ def upper(self) -> Var: f"Cannot convert non-string var {self._var_full_name} to uppercase." ) - return BaseVar( - _var_name=f"{self._var_full_name}.toUpperCase()", + return self._replace( + _var_name=f"{self._var_name}.toUpperCase()", + _var_is_string=False, _var_type=str, - _var_is_local=self._var_is_local, ) def split(self, other: str | Var[str] = " ") -> Var: @@ -1096,10 +1270,12 @@ def split(self, other: str | Var[str] = " ") -> Var: other = Var.create_safe(json.dumps(other)) if isinstance(other, str) else other - return BaseVar( - _var_name=f"{self._var_full_name}.split({other._var_full_name})", + return self._replace( + _var_name=f"{self._var_name}.split({other._var_full_name})", + _var_is_string=False, _var_type=list[str], - _var_is_local=self._var_is_local, + add_imports=other._var_imports, + add_hooks=other._var_hooks, ) def join(self, other: str | Var[str] | None = None) -> Var: @@ -1124,10 +1300,12 @@ def join(self, other: str | Var[str] | None = None) -> Var: else: other = Var.create_safe(other) - return BaseVar( - _var_name=f"{self._var_full_name}.join({other._var_full_name})", + return self._replace( + _var_name=f"{self._var_name}.join({other._var_full_name})", + _var_is_string=False, _var_type=str, - _var_is_local=self._var_is_local, + add_imports=other._var_imports, + add_hooks=other._var_hooks, ) def foreach(self, fn: Callable) -> Var: @@ -1143,10 +1321,9 @@ def foreach(self, fn: Callable) -> Var: _var_name=get_unique_variable_name(), _var_type=self._var_type, ) - return BaseVar( - _var_name=f"{self._var_full_name}.map(({arg._var_name}, i) => {fn(arg, key='i')})", - _var_type=self._var_type, - _var_is_local=self._var_is_local, + return self._replace( + _var_name=f"{self._var_name}.map(({arg._var_name}, i) => {fn(arg, key='i')})", + _var_is_string=False, ) def to(self, type_: Type) -> Var: @@ -1158,12 +1335,7 @@ def to(self, type_: Type) -> Var: Returns: The converted var. """ - return BaseVar( - _var_name=self._var_name, - _var_type=type_, - _var_state=self._var_state, - _var_is_local=self._var_is_local, - ) + return self._replace(_var_type=type_) @property def _var_full_name(self) -> str: @@ -1172,22 +1344,40 @@ def _var_full_name(self) -> str: Returns: The full name of the var. """ + if not self._var_full_name_needs_state_prefix: + return self._var_name return ( self._var_name if self._var_state == "" - else ".".join([self._var_state, self._var_name]) + else ".".join([format.format_state_name(self._var_state), self._var_name]) ) - def _var_set_state(self, state: Type[State]) -> Any: + def _var_set_state(self, state: Type[State] | str) -> Any: """Set the state of the var. Args: - state: The state to set. + state: The state to set or the full name of the state. Returns: The var with the set state. """ - self._var_state = state.get_full_name() + if isinstance(state, str): + self._var_state = state + else: + self._var_state = state.get_full_name() + self._var_hooks.add( + "const {0} = useContext(StateContexts.{0})".format( + format.format_state_name(self._var_state) + ) + ) + self._var_imports = _merge_imports( + self._var_imports, + { + f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="StateContexts")}, + "react": {ImportVar(tag="useContext")}, + }, + ) + self._var_full_name_needs_state_prefix = True return self @@ -1213,6 +1403,15 @@ class BaseVar(Var): # Whether the var is a string literal. _var_is_string: bool = dataclasses.field(default=False) + # _var_full_name should be prefixed with _var_state + _var_full_name_needs_state_prefix: bool = dataclasses.field(default=False) + + # Imports needed to render this var + _var_imports: dict[str, set[ImportVar]] = dataclasses.field(default_factory=dict) + + # All substates that this var depends on + _var_hooks: set[str] = dataclasses.field(default_factory=set) + def __hash__(self) -> int: """Define a hash function for a var. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 1040e202bc..750420b577 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -22,6 +22,7 @@ from typing import ( USED_VARIABLES: Incomplete def get_unique_variable_name() -> str: ... +def _decode_var_state(value: str) -> tuple[str, str]: ... class Var: _var_name: str @@ -88,7 +89,7 @@ class Var: def to(self, type_: Type) -> Var: ... @property def _var_full_name(self) -> str: ... - def _var_set_state(self, state: Type[State]) -> Any: ... + def _var_set_state(self, state: Type[State] | str) -> Any: ... @dataclass(eq=False) class BaseVar(Var): diff --git a/tests/test_style.py b/tests/test_style.py index 8b09f9ac09..a8fcf68398 100644 --- a/tests/test_style.py +++ b/tests/test_style.py @@ -22,7 +22,8 @@ def test_convert(style_dict, expected): style_dict: The style to check. expected: The expected formatted style. """ - assert style.convert(style_dict) == expected + converted_dict, _var_data = style.convert(style_dict) + assert converted_dict == expected @pytest.mark.parametrize( diff --git a/tests/test_var.py b/tests/test_var.py index 9efb5fb783..04af19b008 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -17,8 +17,10 @@ test_vars = [ BaseVar(_var_name="prop1", _var_type=int), BaseVar(_var_name="key", _var_type=str), - BaseVar(_var_name="value", _var_type=str, _var_state="state"), - BaseVar(_var_name="local", _var_type=str, _var_state="state", _var_is_local=True), + BaseVar(_var_name="value", _var_type=str)._var_set_state("state"), + BaseVar(_var_name="local", _var_type=str, _var_is_local=True)._var_set_state( + "state" + ), BaseVar(_var_name="local2", _var_type=str, _var_is_local=True), ] @@ -263,7 +265,7 @@ def test_basic_operations(TestObj): assert str(v([1, 2, 3])[v(0)]) == "{[1, 2, 3].at(0)}" assert str(v({"a": 1, "b": 2})["a"]) == '{{"a": 1, "b": 2}["a"]}' assert ( - str(BaseVar(_var_name="foo", _var_state="state", _var_type=TestObj).bar) + str(BaseVar(_var_name="foo", _var_type=TestObj)._var_set_state("state").bar) == "{state.foo.bar}" ) assert str(abs(v(1))) == "{Math.abs(1)}" @@ -274,7 +276,7 @@ def test_basic_operations(TestObj): assert str(v([1, 2, 3]).reverse()) == "{[...[1, 2, 3]].reverse()}" assert str(v(["1", "2", "3"]).reverse()) == '{[...["1", "2", "3"]].reverse()}' assert ( - str(BaseVar(_var_name="foo", _var_state="state", _var_type=list).reverse()) + str(BaseVar(_var_name="foo", _var_type=list)._var_set_state("state").reverse()) == "{[...state.foo].reverse()}" ) assert ( @@ -288,11 +290,14 @@ def test_basic_operations(TestObj): [ (v([1, 2, 3]), "[1, 2, 3]"), (v(["1", "2", "3"]), '["1", "2", "3"]'), - (BaseVar(_var_name="foo", _var_state="state", _var_type=list), "state.foo"), + (BaseVar(_var_name="foo", _var_type=list)._var_set_state("state"), "state.foo"), (BaseVar(_var_name="foo", _var_type=list), "foo"), (v((1, 2, 3)), "[1, 2, 3]"), (v(("1", "2", "3")), '["1", "2", "3"]'), - (BaseVar(_var_name="foo", _var_state="state", _var_type=tuple), "state.foo"), + ( + BaseVar(_var_name="foo", _var_type=tuple)._var_set_state("state"), + "state.foo", + ), (BaseVar(_var_name="foo", _var_type=tuple), "foo"), ], ) @@ -301,7 +306,7 @@ def test_list_tuple_contains(var, expected): assert str(var.contains("1")) == f'{{{expected}.includes("1")}}' assert str(var.contains(v(1))) == f"{{{expected}.includes(1)}}" assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}' - other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str) + other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state") other_var = BaseVar(_var_name="other", _var_type=str) assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}" assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}" @@ -311,14 +316,14 @@ def test_list_tuple_contains(var, expected): "var, expected", [ (v("123"), json.dumps("123")), - (BaseVar(_var_name="foo", _var_state="state", _var_type=str), "state.foo"), + (BaseVar(_var_name="foo", _var_type=str)._var_set_state("state"), "state.foo"), (BaseVar(_var_name="foo", _var_type=str), "foo"), ], ) def test_str_contains(var, expected): assert str(var.contains("1")) == f'{{{expected}.includes("1")}}' assert str(var.contains(v("1"))) == f'{{{expected}.includes("1")}}' - other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str) + other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state") other_var = BaseVar(_var_name="other", _var_type=str) assert str(var.contains(other_state_var)) == f"{{{expected}.includes(state.other)}}" assert str(var.contains(other_var)) == f"{{{expected}.includes(other)}}" @@ -328,7 +333,7 @@ def test_str_contains(var, expected): "var, expected", [ (v({"a": 1, "b": 2}), '{"a": 1, "b": 2}'), - (BaseVar(_var_name="foo", _var_state="state", _var_type=dict), "state.foo"), + (BaseVar(_var_name="foo", _var_type=dict)._var_set_state("state"), "state.foo"), (BaseVar(_var_name="foo", _var_type=dict), "foo"), ], ) @@ -337,7 +342,7 @@ def test_dict_contains(var, expected): assert str(var.contains("1")) == f'{{{expected}.hasOwnProperty("1")}}' assert str(var.contains(v(1))) == f"{{{expected}.hasOwnProperty(1)}}" assert str(var.contains(v("1"))) == f'{{{expected}.hasOwnProperty("1")}}' - other_state_var = BaseVar(_var_name="other", _var_state="state", _var_type=str) + other_state_var = BaseVar(_var_name="other", _var_type=str)._var_set_state("state") other_var = BaseVar(_var_name="other", _var_type=str) assert ( str(var.contains(other_state_var)) @@ -630,8 +635,8 @@ def test_import_var(import_var, expected): [ (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"), ( - f"testing f-string with {BaseVar(_var_name='myvar', _var_state='state', _var_type=int)}", - "testing f-string with ${state.myvar}", + f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}", + "testing f-string with $_var_state=state{state.myvar}", ), ( f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}", @@ -643,6 +648,35 @@ def test_fstrings(out, expected): assert out == expected +@pytest.mark.parametrize( + ("value", "expect_state"), + [ + ([1], ""), + ({"a": 1}, ""), + ([Var.create_safe(1)._var_set_state("foo")], "foo"), + ({"a": Var.create_safe(1)._var_set_state("foo")}, "foo"), + ], +) +def test_extract_state_from_container(value, expect_state): + """Test that _var_state is extracted from containers containing BaseVar. + + Args: + value: The value to create a var from. + expect_state: The expected state. + """ + assert Var.create_safe(value)._var_state == expect_state + + +def test_fstring_roundtrip(): + """Test that f-string roundtrip carries state.""" + var = BaseVar.create_safe("var")._var_set_state("state") + rt_var = Var.create_safe(f"{var}") + assert var._var_state == rt_var._var_state + assert var._var_full_name_needs_state_prefix + assert not rt_var._var_full_name_needs_state_prefix + assert rt_var._var_name == var._var_full_name + + @pytest.mark.parametrize( "var", [ From be6fa14ed6c90990f2cd38c51d32731473bd5faa Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 7 Nov 2023 10:41:52 -0800 Subject: [PATCH 02/29] Fix delta checking in test_app --- tests/test_app.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_app.py b/tests/test_app.py index 619b2775f4..188f2a8983 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -764,9 +764,7 @@ async def test_upload_file(tmp_path, state, delta, token: str): "event", state_update.json(), to=current_state.router.session.session_id ) current_state = await app.state_manager.get_state(token) - state_dict = current_state.dict() - for substate in state.get_full_name().split(".")[1:]: - state_dict = state_dict[substate] + state_dict = current_state.dict()[state.get_full_name()] assert state_dict["img_list"] == [ "image1.jpg", "image2.jpg", From 139c01ad49d0a4e70452fa641f5a2b6c3b9c9478 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 7 Nov 2023 12:17:22 -0800 Subject: [PATCH 03/29] Use _var_data throughout Ensure that adding additional Var metadata in the future does not require lots of changes everywhere, by placing all carryable metadata into _var_data --- reflex/components/base/bare.py | 2 +- reflex/components/component.py | 12 +- reflex/style.py | 54 ++++--- reflex/vars.py | 254 ++++++++++++++------------------- reflex/vars.pyi | 16 ++- 5 files changed, 157 insertions(+), 181 deletions(-) diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 00ba0a81de..021cd4cf05 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -24,7 +24,7 @@ def create(cls, contents: Any) -> Component: Returns: The component. """ - if isinstance(contents, Var) and (contents._var_imports or contents._var_hooks): + if isinstance(contents, Var) and contents._var_data: contents = contents.to(str) else: contents = str(contents) diff --git a/reflex/components/component.py b/reflex/components/component.py index 7f224cad71..466c1686ee 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -599,8 +599,7 @@ def _get_vars(self) -> Iterator[Var]: yield BaseVar( _var_name="style", _var_type=str, - _var_imports=self.style._var_imports, - _var_hooks=self.style._var_hooks, + _var_data=self.style._var_data, ) def _get_custom_code(self) -> str | None: @@ -702,7 +701,9 @@ def _get_imports(self) -> imports.ImportDict: "react": {ImportVar(tag="useContext")}, } # determine imports from Vars - var_imports = [var._var_imports for var in self._get_vars()] + var_imports = [ + var._var_data.imports for var in self._get_vars() if var._var_data + ] return imports.merge_imports( self._get_props_imports(), self._get_dependencies_imports(), @@ -762,10 +763,11 @@ def _get_vars_hooks(self) -> set[str]: """ vars_hooks = set() for var in self._get_vars(): - vars_hooks.update(var._var_hooks) + if var._var_data: + vars_hooks.update(var._var_data.hooks) return vars_hooks - def _get_events_hooks(self) -> str[str]: + def _get_events_hooks(self) -> set[str]: """Get the hooks required by events referenced in this component. Returns: diff --git a/reflex/style.py b/reflex/style.py index b2f4c67943..a6db990c94 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -6,26 +6,27 @@ from reflex import constants from reflex.event import EventChain -from reflex.utils import format, imports -from reflex.vars import BaseVar, ImportVar, Var - -color_mode_imports = { - f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, -} -color_mode_hooks = { - f"const [ {{{constants.ColorMode.NAME}}}, {{{constants.ColorMode.TOGGLE}}} ] = useContext(ColorModeContext)", -} +from reflex.utils import format +from reflex.vars import BaseVar, ImportVar, Var, VarData + +VarData.update_forward_refs() +color_mode_var_data = VarData( + imports={ + f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, + }, + hooks={ + f"const [ {{{constants.ColorMode.NAME}}}, {{{constants.ColorMode.TOGGLE}}} ] = useContext(ColorModeContext)", + }, +) color_mode = BaseVar( _var_name=constants.ColorMode.NAME, _var_type="str", - _var_imports=color_mode_imports, - _var_hooks=color_mode_hooks, + _var_data=color_mode_var_data, ) toggle_color_mode = BaseVar( _var_name=constants.ColorMode.TOGGLE, _var_type=EventChain, - _var_imports=color_mode_imports, - _var_hooks=color_mode_hooks, + _var_data=color_mode_var_data, ) @@ -38,18 +39,17 @@ def convert(style_dict): Returns: The formatted style dictionary. """ - var_data = Var.create("") + var_data = None out = {} for key, value in style_dict.items(): key = format.to_camel_case(key) if isinstance(value, dict): out[key], new_var_data = convert(value) else: - new_var_data = Var.create(value, _var_is_string=True) - out[key] = str(new_var_data) - var_data = var_data._replace( - add_imports=new_var_data._var_imports, add_hooks=new_var_data._var_hooks - ) + new_var = Var.create(value, _var_is_string=True) + out[key] = str(new_var) + new_var_data = new_var._var_data + var_data = VarData.merge(var_data, new_var_data) return out, var_data @@ -62,22 +62,19 @@ def __init__(self, style_dict: dict | None = None): Args: style_dict: The style dictionary. """ - style_dict, var_data = convert(style_dict or {}) - self._var_imports = var_data._var_imports - self._var_hooks = var_data._var_hooks + style_dict, self._var_data = convert(style_dict or {}) super().__init__(style_dict) - def update(self, style_dict: dict | None = None): + def update(self, style_dict: dict | None, **kwargs): """Update the style. Args: style_dict: The style dictionary. """ + if kwargs: + style_dict = {**style_dict, **kwargs} converted_dict = type(self)(style_dict) - self._var_imports = imports.merge_imports( - self._var_imports, converted_dict._var_imports - ) - self._var_hooks.update(converted_dict._var_hooks) + self._var_data = VarData.merge(self._var_data, converted_dict._var_data) super().update(converted_dict) def __setitem__(self, key: str, value: Any): @@ -88,6 +85,5 @@ def __setitem__(self, key: str, value: Any): value: The value to set. """ _var = Var.create(value) - self._var_imports = imports.merge_imports(self._var_imports, _var._var_imports) - self._var_hooks.update(_var._var_hooks) + self._var_data = VarData.merge(self._var_data, _var._var_data) super().__setitem__(key, value) diff --git a/reflex/vars.py b/reflex/vars.py index 6083bf067a..e7c253c395 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -72,7 +72,7 @@ REPLACED_NAMES = { "full_name": "_var_full_name", "name": "_var_name", - "state": "_var_state", + "state": "_var_data.state", "type_": "_var_type", "is_local": "_var_is_local", "is_string": "_var_is_string", @@ -94,24 +94,63 @@ def get_unique_variable_name() -> str: return get_unique_variable_name() -def _merge_imports( - *import_dicts: dict[str, set[ImportVar]] -) -> dict[str, set[ImportVar]]: - """Merge multiple import dicts. +class VarData(Base): + # The name of the enclosing state. + state: str = "" - This is the same as reflex.utils.imports.merge_imports, except - that causes a circular import, so define this helper to avoid - inline imports all over this module. + # Imports needed to render this var + imports: dict[str, set[ImportVar]] = {} - Args: - *import_dicts: The import dicts to merge. + # Hooks that need to be present in the component to render this var + hooks: set[str] = set() - Returns: - The merged import dict. - """ - from reflex.utils import imports + @classmethod + def merge(cls, *others: VarData | None) -> VarData | None: + """Merge multiple var data objects. + + Args: + *others: The var data objects to merge. - return imports.merge_imports(*import_dicts) + Returns: + The merged var data object. + """ + from reflex.utils.imports import merge_imports + + state = "" + imports = {} + hooks = set() + for var_data in others: + if var_data is None: + continue + state = state or var_data.state + imports = merge_imports(imports, var_data.imports) + hooks.update(var_data.hooks) + return ( + cls( + state=state, + imports=imports, + hooks=hooks, + ) + or None + ) + + def __bool__(self) -> bool: + return bool(self.state or self.imports or self.hooks) + + def dict(self) -> dict: + """Convert the var data to a dictionary. + + Returns: + The var data dictionary. + """ + return { + "state": self.state, + "imports": { + lib: [import_var.dict() for import_var in import_vars] + for lib, import_vars in self.imports.items() + }, + "hooks": list(self.hooks), + } def _encode_var(value: Var) -> str: @@ -123,19 +162,12 @@ def _encode_var(value: Var) -> str: Returns: The encoded var. """ - if value._var_imports or value._var_hooks: - data = { - "_var_imports": { - lib: [iv.dict() for iv in imports] - for lib, imports in value._var_imports.items() - }, - "_var_hooks": list(value._var_hooks), - } - return f"{json.dumps(data)}" + str(value) + if value._var_data: + return f"{value._var_data.json()}" + str(value) return str(value) -def _decode_var(value: str) -> tuple[list[str], str]: +def _decode_var(value: str) -> tuple[VarData, str]: """Decode the state name from a formatted var. Args: @@ -144,70 +176,36 @@ def _decode_var(value: str) -> tuple[list[str], str]: Returns: The extracted state name and the value without the state name. """ - _var_data = {} + var_datas = [] if isinstance(value, str): # Extract the state name from a formatted var while m := re.match(r"(.*)(.*)(.*)", value): value = m.group(1) + m.group(3) - _var_data = json.loads(m.group(2)) - _var_data["_var_hooks"] = set(_var_data["_var_hooks"]) - _var_data["_var_imports"] = { - lib: set(ImportVar(**iv_json) for iv_json in ivs_json) - for lib, ivs_json in _var_data["_var_imports"].items() - } - return _var_data, value - + var_datas.append(VarData.parse_raw(m.group(2))) + return VarData.merge(*var_datas), value -def _merge_var_data(*var_datas: dict[str, Any]) -> dict[str, Any]: - """Merge multiple var data dicts. - Args: - *var_datas: The var data dicts to merge. - - Returns: - The merged var data dict. - """ - _var_imports = {} - _var_hooks = set() - for _var_data in var_datas: - _var_imports = _merge_imports(_var_imports, _var_data.get("_var_imports", {})) - _var_hooks.update(_var_data.get("_var_hooks", set())) - return {"_var_imports": _var_imports, "_var_hooks": _var_hooks} - - -def _extract_var_data(value: Iterable) -> list[str]: +def _extract_var_data(value: Iterable) -> VarData | None: """Extract the var imports and hooks from an iterable containing a Var. Args: - value: The iterable to extract the state name from. + value: The iterable to extract the VarData from Returns: - The extracted state name. + The extracted VarData. """ - _var_data = {} - for sub in value: - if isinstance(sub, Var): - _var_data = _merge_var_data( - _var_data, - { - "_var_imports": sub._var_imports, - "_var_hooks": sub._var_hooks, - }, - ) - elif not isinstance(sub, str): - # Recurse into dict values - if hasattr(sub, "values") and callable(sub.values): - _var_data = _merge_var_data( - _var_data, - _extract_var_data(sub.values()), - ) - # Recurse into iterable values (or dict keys) - with contextlib.suppress(TypeError): - _var_data = _merge_var_data( - _var_data, - _extract_var_data(sub), - ) - return _var_data + var_data = None + with contextlib.suppress(TypeError): + for sub in value: + if isinstance(sub, Var): + var_data = VarData.merge(var_data, sub._var_data) + elif not isinstance(sub, str): + # Recurse into dict values + if hasattr(sub, "values") and callable(sub.values): + var_data = VarData.merge(var_data, _extract_var_data(sub.values())) + # Recurse into iterable values (or dict keys) + var_data = VarData.merge(var_data, _extract_var_data(sub)) + return var_data class Var: @@ -219,9 +217,6 @@ class Var: # The type of the var. _var_type: Type - # The name of the enclosing state. - _var_state: str - # Whether this is a local javascript variable. _var_is_local: bool @@ -231,11 +226,8 @@ class Var: # _var_full_name should be prefixed with _var_state _var_full_name_needs_state_prefix: bool - # Imports needed to render this var - _var_imports: dict[str, set[ImportVar]] - - # All substates that this var depends on - _var_hooks: set[str] + # Extra metadata associated with the Var + _var_data: Optional[VarData] @classmethod def create( @@ -263,10 +255,9 @@ def create( return value # Try to pull the imports and hooks from contained values. - _var_data = {} + _var_data = None if not isinstance(value, str): - with contextlib.suppress(TypeError): - _var_data = _extract_var_data(value) + _var_data = _extract_var_data(value) # Try to serialize the value. type_ = type(value) @@ -282,8 +273,7 @@ def create( _var_type=type_, _var_is_local=_var_is_local, _var_is_string=_var_is_string, - _var_imports=_var_data.get("_var_imports", {}), - _var_hooks=_var_data.get("_var_hooks", set()), + _var_data=_var_data, ) @classmethod @@ -326,32 +316,23 @@ def __post_init__(self) -> None: _var_data, _var_name = _decode_var(self._var_name) if _var_data: self._var_name = _var_name - self._var_hooks.update(_var_data.get("_var_hooks", set())) - self._var_imports = _merge_imports( - self._var_imports, - _var_data.get("_var_imports", {}), - ) + self._var_data = VarData.merge(self._var_data, _var_data) - def _replace(self, add_imports=None, add_hooks=None, **kwargs: Any) -> Var: + def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: # Cannot use dataclasses.replace because ComputedVar uses multiple inheritance # and it's __init__ has a required fget argument - _var_imports = _merge_imports( - kwargs.pop("_var_imports", self._var_imports), - add_imports or {}, - ) - _var_hooks = kwargs.pop("_var_hooks", self._var_hooks).union(add_hooks or set()) field_values = dict( _var_name=kwargs.pop("_var_name", self._var_name), _var_type=kwargs.pop("_var_type", self._var_type), - _var_state=kwargs.pop("_var_state", self._var_state), _var_is_local=kwargs.pop("_var_is_local", self._var_is_local), _var_is_string=kwargs.pop("_var_is_string", self._var_is_string), _var_full_name_needs_state_prefix=kwargs.pop( "_var_full_name_needs_state_prefix", self._var_full_name_needs_state_prefix, ), - _var_imports=_var_imports, - _var_hooks=_var_hooks, + _var_data=VarData.merge( + kwargs.get("_var_data", self._var_data), merge_var_data + ), ) return BaseVar(**field_values) @@ -364,8 +345,6 @@ def _decode(self) -> Any: Returns: The decoded value or the Var name. """ - if self._var_state: - return self._var_full_name if self._var_is_string: return self._var_name try: @@ -385,12 +364,10 @@ def equals(self, other: Var) -> bool: return ( self._var_name == other._var_name and self._var_type == other._var_type - and self._var_state == other._var_state and self._var_is_local == other._var_is_local and self._var_full_name_needs_state_prefix == other._var_full_name_needs_state_prefix - and self._var_imports == other._var_imports - and self._var_hooks == other._var_hooks + and self._var_data == other._var_data ) def to_string(self, json: bool = True) -> Var: @@ -689,8 +666,7 @@ def operation( _var_type=type_, _var_is_string=False, _var_full_name_needs_state_prefix=False, - add_imports=other._var_imports if other is not None else None, - add_hooks=other._var_hooks if other is not None else None, + merge_var_data=other._var_data if other is not None else None, ) @staticmethod @@ -1176,8 +1152,7 @@ def contains(self, other: Any) -> Var: _var_name=f"{self._var_name}.{method}({other._var_full_name})", _var_type=bool, _var_is_string=False, - add_imports=other._var_imports, - add_hooks=other._var_hooks, + merge_var_data=other._var_data, ) else: # str, list, tuple # For strings, the left operand must be a string. @@ -1191,8 +1166,7 @@ def contains(self, other: Any) -> Var: _var_name=f"{self._var_name}.includes({other._var_full_name})", _var_type=bool, _var_is_string=False, - add_imports=other._var_imports, - add_hooks=other._var_hooks, + merge_var_data=other._var_data, ) def reverse(self) -> Var: @@ -1274,8 +1248,7 @@ def split(self, other: str | Var[str] = " ") -> Var: _var_name=f"{self._var_name}.split({other._var_full_name})", _var_is_string=False, _var_type=list[str], - add_imports=other._var_imports, - add_hooks=other._var_hooks, + merge_var_data=other._var_data, ) def join(self, other: str | Var[str] | None = None) -> Var: @@ -1304,8 +1277,7 @@ def join(self, other: str | Var[str] | None = None) -> Var: _var_name=f"{self._var_name}.join({other._var_full_name})", _var_is_string=False, _var_type=str, - add_imports=other._var_imports, - add_hooks=other._var_hooks, + merge_var_data=other._var_data, ) def foreach(self, fn: Callable) -> Var: @@ -1348,8 +1320,10 @@ def _var_full_name(self) -> str: return self._var_name return ( self._var_name - if self._var_state == "" - else ".".join([format.format_state_name(self._var_state), self._var_name]) + if self._var_data.state == "" + else ".".join( + [format.format_state_name(self._var_data.state), self._var_name] + ) ) def _var_set_state(self, state: Type[State] | str) -> Any: @@ -1361,22 +1335,20 @@ def _var_set_state(self, state: Type[State] | str) -> Any: Returns: The var with the set state. """ - if isinstance(state, str): - self._var_state = state - else: - self._var_state = state.get_full_name() - self._var_hooks.add( - "const {0} = useContext(StateContexts.{0})".format( - format.format_state_name(self._var_state) - ) - ) - self._var_imports = _merge_imports( - self._var_imports, - { + state_name = state if isinstance(state, str) else state.get_full_name() + new_var_data = VarData( + state=state_name, + hooks={ + "const {0} = useContext(StateContexts.{0})".format( + format.format_state_name(state_name) + ) + }, + imports={ f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="StateContexts")}, "react": {ImportVar(tag="useContext")}, }, ) + self._var_data = VarData.merge(self._var_data, new_var_data) self._var_full_name_needs_state_prefix = True return self @@ -1394,9 +1366,6 @@ class BaseVar(Var): # The type of the var. _var_type: Type = dataclasses.field(default=Any) - # The name of the enclosing state. - _var_state: str = dataclasses.field(default="") - # Whether this is a local javascript variable. _var_is_local: bool = dataclasses.field(default=False) @@ -1406,11 +1375,8 @@ class BaseVar(Var): # _var_full_name should be prefixed with _var_state _var_full_name_needs_state_prefix: bool = dataclasses.field(default=False) - # Imports needed to render this var - _var_imports: dict[str, set[ImportVar]] = dataclasses.field(default_factory=dict) - - # All substates that this var depends on - _var_hooks: set[str] = dataclasses.field(default_factory=set) + # Extra metadata associated with the Var + _var_data: Optional[VarData] = dataclasses.field(default=None) def __hash__(self) -> int: """Define a hash function for a var. @@ -1473,9 +1439,11 @@ def get_setter_name(self, include_state: bool = True) -> str: The name of the setter function. """ setter = constants.SETTER_PREFIX + self._var_name - if not include_state or self._var_state == "": + if self._var_data is None: + return setter + if not include_state or self._var_data.state == "": return setter - return ".".join((self._var_state, setter)) + return ".".join((self._var_data.state, setter)) def get_setter(self) -> Callable[[State, Any], None]: """Get the var's setter function. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 750420b577..37d5767fbc 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -11,6 +11,7 @@ from typing import ( Any, Callable, Dict, + Iterable, List, Optional, Set, @@ -22,14 +23,23 @@ from typing import ( USED_VARIABLES: Incomplete def get_unique_variable_name() -> str: ... -def _decode_var_state(value: str) -> tuple[str, str]: ... +def _encode_var(value: Var) -> str: ... +def _decode_var(value: str) -> tuple[VarData, str]: ... +def _extract_var_data(value: Iterable) -> VarData | None: ... + +class VarData(Base): + state: str + imports: dict[str, set[ImportVar]] + hooks: set[str] + @classmethod + def merge(cls, *others: VarData | None) -> VarData | None: ... class Var: _var_name: str _var_type: Type - _var_state: str = "" _var_is_local: bool = False _var_is_string: bool = False + _var_data: VarData | None = None @classmethod def create( cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False @@ -95,9 +105,9 @@ class Var: class BaseVar(Var): _var_name: str _var_type: Any - _var_state: str = "" _var_is_local: bool = False _var_is_string: bool = False + _var_data: VarData | None = None def __hash__(self) -> int: ... def get_default_value(self) -> Any: ... def get_setter_name(self, include_state: bool = ...) -> str: ... From 9c81f8ac8457fe95b934236bcbd666b2072d2044 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 7 Nov 2023 14:01:17 -0800 Subject: [PATCH 04/29] Account for all imports and hooks where they are used Remove all* default imports for a page, instead relying on the components and vars that use a particular import to name it explicitly with an ImportVar Remove hardcoded hooks, instead format these in for components that depend on them. Move ImportVar to reflex.utils.imports so it can live next to the ImportDict and avoid circular import with reflex.vars --- .../jinja/web/pages/index.js.jinja2 | 9 --- reflex/.templates/web/utils/state.js | 16 ---- reflex/app.py | 2 +- reflex/compiler/compiler.py | 23 +----- reflex/compiler/utils.py | 3 +- reflex/components/component.py | 59 +++++++++++---- reflex/components/datadisplay/code.py | 4 +- reflex/components/datadisplay/dataeditor.py | 3 +- reflex/components/datadisplay/dataeditor.pyi | 3 +- reflex/components/datadisplay/datatable.py | 4 +- reflex/components/datadisplay/moment.py | 4 +- reflex/components/forms/debounce.py | 14 +++- reflex/components/forms/editor.py | 3 +- reflex/components/forms/form.py | 14 +++- reflex/components/forms/input.py | 4 +- reflex/components/forms/upload.py | 14 +++- reflex/components/layout/cond.py | 22 +++++- reflex/components/libs/chakra.py | 18 ++--- reflex/components/overlay/banner.py | 21 ++++-- reflex/components/radix/themes/base.py | 4 +- reflex/components/typography/markdown.py | 3 +- reflex/constants/__init__.py | 4 + reflex/constants/compiler.py | 25 +++++++ reflex/style.py | 4 +- reflex/utils/imports.py | 51 ++++++++++++- reflex/vars.py | 73 +++++++------------ reflex/vars.pyi | 18 +---- 27 files changed, 254 insertions(+), 168 deletions(-) diff --git a/reflex/.templates/jinja/web/pages/index.js.jinja2 b/reflex/.templates/jinja/web/pages/index.js.jinja2 index 56323d5a73..efb086ef5e 100644 --- a/reflex/.templates/jinja/web/pages/index.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/index.js.jinja2 @@ -8,15 +8,6 @@ {% block export %} export default function Component() { - const focusRef = useRef(); - - // Set focus to the specified element. - useEffect(() => { - if (focusRef.current) { - focusRef.current.focus(); - } - }) - {% for hook in hooks %} {{ hook }} {% endfor %} diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 1c7410a2af..6b91c64e1f 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -99,22 +99,6 @@ export const applyDelta = (state, delta) => { }; -/** - * Get all local storage items in a key-value object. - * @returns object of items in local storage. - */ -export const getAllLocalStorageItems = () => { - var localStorageItems = {}; - - for (var i = 0, len = localStorage.length; i < len; i++) { - var key = localStorage.key(i); - localStorageItems[key] = localStorage.getItem(key); - } - - return localStorageItems; -} - - /** * Handle frontend event or send the event to the backend via Websocket. * @param event The event to send. diff --git a/reflex/app.py b/reflex/app.py index ab53672a56..07ff781592 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -59,7 +59,7 @@ StateUpdate, ) from reflex.utils import console, format, prerequisites, types -from reflex.vars import ImportVar +from reflex.utils.imports import ImportVar # Define custom types. ComponentCallable = Callable[[], Component] diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 8e3f87e695..88a5f053f6 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -10,29 +10,10 @@ from reflex.components.component import Component, ComponentStyle, CustomComponent from reflex.config import get_config from reflex.state import State -from reflex.utils import imports -from reflex.vars import ImportVar +from reflex.utils.imports import ImportDict, ImportVar # Imports to be included in every Reflex app. -DEFAULT_IMPORTS: imports.ImportDict = { - "react": { - ImportVar(tag="Fragment"), - ImportVar(tag="useEffect"), - ImportVar(tag="useRef"), - ImportVar(tag="useState"), - ImportVar(tag="useContext"), - }, - "next/router": {ImportVar(tag="useRouter")}, - f"/{constants.Dirs.STATE_PATH}": { - ImportVar(tag="uploadFiles"), - ImportVar(tag="isTrue"), - ImportVar(tag="spreadArraysOrObjects"), - ImportVar(tag="preventDefault"), - ImportVar(tag="refs"), - ImportVar(tag="getRefValue"), - ImportVar(tag="getRefValues"), - ImportVar(tag="getAllLocalStorageItems"), - }, +DEFAULT_IMPORTS: ImportDict = { "": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)}, } diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index b07412812a..36156559ea 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -24,13 +24,12 @@ from reflex.state import Cookie, LocalStorage, State from reflex.style import Style from reflex.utils import console, format, imports, path_ops -from reflex.vars import ImportVar # To re-export this function. merge_imports = imports.merge_imports -def compile_import_statement(fields: set[ImportVar]) -> tuple[str, set[str]]: +def compile_import_statement(fields: set[imports.ImportVar]) -> tuple[str, set[str]]: """Compile an import statement. Args: diff --git a/reflex/components/component.py b/reflex/components/component.py index 466c1686ee..d0e40b47dc 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -9,7 +9,7 @@ from reflex.base import Base from reflex.components.tags import Tag -from reflex.constants import Dirs, EventTriggers +from reflex.constants import Dirs, EventTriggers, Hooks, Imports from reflex.event import ( EventChain, EventHandler, @@ -20,8 +20,9 @@ ) from reflex.style import Style from reflex.utils import console, format, imports, types +from reflex.utils.imports import ImportVar from reflex.utils.serializers import serializer -from reflex.vars import BaseVar, ImportVar, Var +from reflex.vars import BaseVar, Var class Component(Base, ABC): @@ -684,6 +685,23 @@ def _get_dependencies_imports(self) -> imports.ImportDict: {dep: {ImportVar(tag=None, render=False)} for dep in self.lib_dependencies} ) + def _get_hooks_imports(self) -> imports.ImportDict: + """Get the imports required by certain hooks.""" + _imports = {} + if self._get_ref_hook(): + _imports.setdefault("react", set()).add(ImportVar(tag="useRef")) + _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs")) + if self._get_mount_lifecycle_hook(): + _imports.setdefault("react", set()).add(ImportVar(tag="useEffect")) + if self._get_special_hooks(): + _imports.setdefault("react", set()).update( + { + ImportVar(tag="useRef"), + ImportVar(tag="useEffect"), + }, + ) + return _imports + def _get_imports(self) -> imports.ImportDict: """Get all the libraries and fields that are used by the component. @@ -693,13 +711,7 @@ def _get_imports(self) -> imports.ImportDict: _imports = {} if self.library is not None and self.tag is not None: _imports[self.library] = {self.import_var} - event_imports = {} - if self.event_triggers: - event_imports = { - f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")}, - f"/{Dirs.STATE_PATH}": {ImportVar(tag="Event")}, - "react": {ImportVar(tag="useContext")}, - } + event_imports = Imports.EVENTS if self.event_triggers else {} # determine imports from Vars var_imports = [ var._var_data.imports for var in self._get_vars() if var._var_data @@ -707,6 +719,7 @@ def _get_imports(self) -> imports.ImportDict: return imports.merge_imports( self._get_props_imports(), self._get_dependencies_imports(), + self._get_hooks_imports(), _imports, event_imports, *var_imports, @@ -773,9 +786,27 @@ def _get_events_hooks(self) -> set[str]: Returns: The hooks for the events. """ - # TODO: use constants here for better indirection if self.event_triggers: - return {"const [addEvents, connectError] = useContext(EventLoopContext);"} + return {Hooks.EVENTS} + return set() + + def _get_special_hooks(self) -> set[str]: + """Get the hooks required by special actions referenced in this component. + + Returns: + The hooks for special actions. + """ + if self.autofocus: + return { + """ + // Set focus to the specified element. + const focusRef = useRef(null) + useEffect(() => { + if (focusRef.current) { + focusRef.current.focus(); + } + })""", + } return set() def _get_hooks_internal(self) -> Set[str]: @@ -795,6 +826,7 @@ def _get_hooks_internal(self) -> Set[str]: ) .union(self._get_vars_hooks()) .union(self._get_events_hooks()) + .union(self._get_special_hooks()) ) def _get_hooks(self) -> str | None: @@ -1093,10 +1125,11 @@ class NoSSRComponent(Component): def _get_imports(self) -> imports.ImportDict: dynamic_import = {"next/dynamic": {ImportVar(tag="dynamic", is_default=True)}} - + _imports = super()._get_imports() + _imports[self.library] = {ImportVar(tag=None, render=False)} return imports.merge_imports( dynamic_import, - {self.library: {ImportVar(tag=None, render=False)}}, + _imports, self._get_dependencies_imports(), ) diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py index ec3f4d3167..2d2f4596e5 100644 --- a/reflex/components/datadisplay/code.py +++ b/reflex/components/datadisplay/code.py @@ -10,7 +10,7 @@ from reflex.event import set_clipboard from reflex.style import Style from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var # Path to the prism styles. PRISM_STYLES_PATH: str = "/styles/code/prism" @@ -49,7 +49,7 @@ def _get_imports(self) -> imports.ImportDict: if self.theme is not None: merged_imports = imports.merge_imports( merged_imports, - {PRISM_STYLES_PATH: {ImportVar(tag=self.theme._var_name)}}, + {PRISM_STYLES_PATH: {imports.ImportVar(tag=self.theme._var_name)}}, ) return merged_imports diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index ecef03b385..edec813a5a 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -8,8 +8,9 @@ from reflex.components.component import Component, NoSSRComponent from reflex.components.literals import LiteralRowMarker from reflex.utils import console, format, imports, types +from reflex.utils.imports import ImportVar from reflex.utils.serializers import serializer -from reflex.vars import ImportVar, Var, get_unique_variable_name +from reflex.vars import Var, get_unique_variable_name # TODO: Fix the serialization issue for custom types. diff --git a/reflex/components/datadisplay/dataeditor.pyi b/reflex/components/datadisplay/dataeditor.pyi index cddc7ec86d..7f3621f508 100644 --- a/reflex/components/datadisplay/dataeditor.pyi +++ b/reflex/components/datadisplay/dataeditor.pyi @@ -13,8 +13,9 @@ from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.components.literals import LiteralRowMarker from reflex.utils import console, format, imports, types +from reflex.utils.imports import ImportVar from reflex.utils.serializers import serializer -from reflex.vars import ImportVar, Var, get_unique_variable_name +from reflex.vars import Var, get_unique_variable_name class GridColumnIcons(Enum): ... diff --git a/reflex/components/datadisplay/datatable.py b/reflex/components/datadisplay/datatable.py index d9fe49f30e..a5a097a6c0 100644 --- a/reflex/components/datadisplay/datatable.py +++ b/reflex/components/datadisplay/datatable.py @@ -8,7 +8,7 @@ from reflex.components.tags import Tag from reflex.utils import imports, types from reflex.utils.serializers import serialize, serializer -from reflex.vars import BaseVar, ComputedVar, ImportVar, Var +from reflex.vars import BaseVar, ComputedVar, Var class Gridjs(Component): @@ -105,7 +105,7 @@ def create(cls, *children, **props): def _get_imports(self) -> imports.ImportDict: return imports.merge_imports( super()._get_imports(), - {"": {ImportVar(tag="gridjs/dist/theme/mermaid.css")}}, + {"": {imports.ImportVar(tag="gridjs/dist/theme/mermaid.css")}}, ) def _render(self) -> Tag: diff --git a/reflex/components/datadisplay/moment.py b/reflex/components/datadisplay/moment.py index 9ae8381bc2..31ffb5ffad 100644 --- a/reflex/components/datadisplay/moment.py +++ b/reflex/components/datadisplay/moment.py @@ -4,7 +4,7 @@ from reflex.components.component import Component, NoSSRComponent from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class Moment(NoSSRComponent): @@ -78,7 +78,7 @@ def _get_imports(self) -> imports.ImportDict: if self.tz is not None: merged_imports = imports.merge_imports( merged_imports, - {"moment-timezone": {ImportVar(tag="")}}, + {"moment-timezone": {imports.ImportVar(tag="")}}, ) return merged_imports diff --git a/reflex/components/forms/debounce.py b/reflex/components/forms/debounce.py index da318a1cba..1618d29ba6 100644 --- a/reflex/components/forms/debounce.py +++ b/reflex/components/forms/debounce.py @@ -1,10 +1,11 @@ """Wrapper around react-debounce-input.""" from __future__ import annotations -from typing import Any +from typing import Any, Set from reflex.components import Component from reflex.components.tags import Tag +from reflex.utils import imports from reflex.vars import Var @@ -77,6 +78,17 @@ def _render(self) -> Tag: object.__setattr__(child, "render", lambda: "") return tag + def _get_imports(self) -> imports.ImportDict: + return imports.merge_imports( + super()._get_imports(), *[c._get_imports() for c in self.children] + ) + + def _get_hooks_internal(self) -> Set[str]: + hooks = super()._get_hooks_internal() + for child in self.children: + hooks.update(child._get_hooks_internal()) + return hooks + def props_not_none(c: Component) -> dict[str, Any]: """Get all properties of the component that are not None. diff --git a/reflex/components/forms/editor.py b/reflex/components/forms/editor.py index 3f9bf34ffe..9616f9aa0f 100644 --- a/reflex/components/forms/editor.py +++ b/reflex/components/forms/editor.py @@ -8,7 +8,8 @@ from reflex.components.component import Component, NoSSRComponent from reflex.constants import EventTriggers from reflex.utils.format import to_camel_case -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var class EditorButtonList(list, enum.Enum): diff --git a/reflex/components/forms/form.py b/reflex/components/forms/form.py index f17642ece7..7c6855617a 100644 --- a/reflex/components/forms/form.py +++ b/reflex/components/forms/form.py @@ -5,8 +5,9 @@ from reflex.components.component import Component from reflex.components.libs.chakra import ChakraComponent -from reflex.constants import EventTriggers +from reflex.constants import Dirs, EventTriggers from reflex.event import EventChain, EventHandler, EventSpec +from reflex.utils import imports from reflex.vars import Var @@ -66,6 +67,17 @@ def get_event_triggers(self) -> Dict[str, Any]: EventTriggers.ON_SUBMIT: lambda e0: [form_refs], } + def _get_imports(self) -> imports.ImportDict: + return imports.merge_imports( + super()._get_imports(), + { + f"/{Dirs.STATE_PATH}": { + imports.ImportVar(tag="getRefValue"), + imports.ImportVar(tag="getRefValues"), + }, + }, + ) + class FormControl(ChakraComponent): """Provide context to form components.""" diff --git a/reflex/components/forms/input.py b/reflex/components/forms/input.py index 75bbac4cd2..9773ecda8a 100644 --- a/reflex/components/forms/input.py +++ b/reflex/components/forms/input.py @@ -11,7 +11,7 @@ ) from reflex.constants import EventTriggers from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class Input(ChakraComponent): @@ -58,7 +58,7 @@ class Input(ChakraComponent): def _get_imports(self) -> imports.ImportDict: return imports.merge_imports( super()._get_imports(), - {"/utils/state": {ImportVar(tag="set_val")}}, + {"/utils/state": {imports.ImportVar(tag="set_val")}}, ) def get_event_triggers(self) -> Dict[str, Any]: diff --git a/reflex/components/forms/upload.py b/reflex/components/forms/upload.py index d1a91adc75..ab73fa483c 100644 --- a/reflex/components/forms/upload.py +++ b/reflex/components/forms/upload.py @@ -6,8 +6,9 @@ from reflex.components.component import Component from reflex.components.forms.input import Input from reflex.components.layout.box import Box -from reflex.constants import EventTriggers +from reflex.constants import Dirs, EventTriggers from reflex.event import EventChain +from reflex.utils import imports from reflex.vars import BaseVar, Var files_state: str = "const [files, setFiles] = useState([]);" @@ -114,3 +115,14 @@ def _render(self): def _get_hooks(self) -> str | None: return (super()._get_hooks() or "") + files_state + + def _get_imports(self) -> imports.ImportDict: + return imports.merge_imports( + super()._get_imports(), + { + "react": {imports.ImportVar(tag="useState")}, + f"/{Dirs.STATE_PATH}": { + imports.ImportVar(tag="uploadFiles"), + }, + }, + ) diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py index 4439902dea..75395779af 100644 --- a/reflex/components/layout/cond.py +++ b/reflex/components/layout/cond.py @@ -6,8 +6,9 @@ from reflex.components.component import Component from reflex.components.layout.fragment import Fragment from reflex.components.tags import CondTag, Tag -from reflex.utils import format -from reflex.vars import Var +from reflex.constants import Dirs +from reflex.utils import format, imports +from reflex.vars import Var, VarData class Cond(Component): @@ -80,6 +81,12 @@ def render(self) -> Dict: cond_state=f"isTrue({self.cond._var_full_name})", ) + def _get_imports(self) -> imports.ImportDict: + return imports.merge_imports( + super()._get_imports(), + getattr(self.cond._var_data, "imports", {}), + ) + def cond(condition: Any, c1: Any, c2: Any = None): """Create a conditional component or Prop. @@ -101,6 +108,15 @@ def cond(condition: Any, c1: Any, c2: Any = None): # Convert the condition to a Var. cond_var = Var.create(condition) assert cond_var is not None, "The condition must be set." + cond_var = cond_var._replace( + merge_var_data=VarData( + imports={ + f"/{Dirs.STATE_PATH}": { + imports.ImportVar(tag="isTrue"), + }, + }, + ), + ) # If the first component is a component, create a Cond component. if isinstance(c1, Component): @@ -109,7 +125,7 @@ def cond(condition: Any, c1: Any, c2: Any = None): ), "Both arguments must be components." return Cond.create(cond_var, c1, c2) - # Otherwise, create a conditionl Var. + # Otherwise, create a conditional Var. # Check that the second argument is valid. if isinstance(c2, Component): raise ValueError("Both arguments must be props.") diff --git a/reflex/components/libs/chakra.py b/reflex/components/libs/chakra.py index 1664d49df6..ff46f3d48d 100644 --- a/reflex/components/libs/chakra.py +++ b/reflex/components/libs/chakra.py @@ -5,7 +5,7 @@ from reflex.components.component import Component from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class ChakraComponent(Component): @@ -57,17 +57,17 @@ def create(cls) -> Component: ) def _get_imports(self) -> imports.ImportDict: - imports = super()._get_imports() - imports.setdefault(self.__fields__["library"].default, set()).add( - ImportVar(tag="extendTheme", is_default=False), + _imports = super()._get_imports() + _imports.setdefault(self.__fields__["library"].default, set()).add( + imports.ImportVar(tag="extendTheme", is_default=False), ) - imports.setdefault("/utils/theme.js", set()).add( - ImportVar(tag="theme", is_default=True), + _imports.setdefault("/utils/theme.js", set()).add( + imports.ImportVar(tag="theme", is_default=True), ) - imports.setdefault(Global.__fields__["library"].default, set()).add( - ImportVar(tag="css", is_default=False), + _imports.setdefault(Global.__fields__["library"].default, set()).add( + imports.ImportVar(tag="css", is_default=False), ) - return imports + return _imports def _get_custom_code(self) -> str | None: return """ diff --git a/reflex/components/overlay/banner.py b/reflex/components/overlay/banner.py index 7b9a31eb92..db3f7dcb8e 100644 --- a/reflex/components/overlay/banner.py +++ b/reflex/components/overlay/banner.py @@ -5,22 +5,27 @@ from reflex.components.base.bare import Bare from reflex.components.component import Component -from reflex.components.layout import Box, Cond +from reflex.components.layout import Box, cond from reflex.components.overlay.modal import Modal from reflex.components.typography import Text +from reflex.constants import Hooks, Imports from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var, VarData + +connect_error_var_data = VarData( + imports=Imports.EVENTS, + hooks={Hooks.EVENTS}, +) connection_error: Var = Var.create_safe( value="(connectError !== null) ? connectError.message : ''", _var_is_local=False, _var_is_string=False, -) +)._replace(merge_var_data=connect_error_var_data) has_connection_error: Var = Var.create_safe( value="connectError !== null", _var_is_string=False, -) -has_connection_error._var_type = bool +)._replace(_var_type=bool, merge_var_data=connect_error_var_data) class WebsocketTargetURL(Bare): @@ -28,7 +33,7 @@ class WebsocketTargetURL(Bare): def _get_imports(self) -> imports.ImportDict: return { - "/utils/state.js": {ImportVar(tag="getEventURL")}, + "/utils/state.js": {imports.ImportVar(tag="getEventURL")}, } @classmethod @@ -78,7 +83,7 @@ def create(cls, comp: Optional[Component] = None) -> Component: textAlign="center", ) - return Cond.create(has_connection_error, comp) + return cond(has_connection_error, comp) class ConnectionModal(Component): @@ -96,7 +101,7 @@ def create(cls, comp: Optional[Component] = None) -> Component: """ if not comp: comp = Text.create(*default_connection_error()) - return Cond.create( + return cond( has_connection_error, Modal.create( header="Connection Error", diff --git a/reflex/components/radix/themes/base.py b/reflex/components/radix/themes/base.py index b24b1c0e8f..b6765ff357 100644 --- a/reflex/components/radix/themes/base.py +++ b/reflex/components/radix/themes/base.py @@ -6,7 +6,7 @@ from reflex.components import Component from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"] LiteralJustify = Literal["start", "center", "end", "between"] @@ -147,7 +147,7 @@ class Theme(RadixThemesComponent): def _get_imports(self) -> imports.ImportDict: return { **super()._get_imports(), - "": {ImportVar(tag="@radix-ui/themes/styles.css", install=False)}, + "": {imports.ImportVar(tag="@radix-ui/themes/styles.css", install=False)}, } diff --git a/reflex/components/typography/markdown.py b/reflex/components/typography/markdown.py index 993d4f48b1..5d494075c1 100644 --- a/reflex/components/typography/markdown.py +++ b/reflex/components/typography/markdown.py @@ -14,7 +14,8 @@ from reflex.components.typography.text import Text from reflex.style import Style from reflex.utils import console, imports, types -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var # Special vars used in the component map. _CHILDREN = Var.create_safe("children", _var_is_local=False) diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index 628f7511b5..ca408c6a65 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -22,6 +22,8 @@ CompileVars, ComponentName, Ext, + Hooks, + Imports, PageNames, ) from .config import ( @@ -68,7 +70,9 @@ Ext, Fnm, GitIgnore, + Hooks, RequirementsTxt, + Imports, IS_WINDOWS, LOCAL_STORAGE, LogLevel, diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 4a9d09d4ca..e309c5d4a9 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -2,6 +2,9 @@ from enum import Enum from types import SimpleNamespace +from reflex.constants import Dirs +from reflex.utils.imports import ImportVar + # The prefix used to create setters for state vars. SETTER_PREFIX = "set_" @@ -47,6 +50,12 @@ class CompileVars(SimpleNamespace): HYDRATE = "hydrate" # The name of the is_hydrated variable. IS_HYDRATED = "is_hydrated" + # The name of the function to add events to the queue. + ADD_EVENTS = "addEvents" + # The name of the var storing any connection error. + CONNECT_ERROR = "connectError" + # The name of the function for converting a dict to an event. + TO_EVENT = "Event" class PageNames(SimpleNamespace): @@ -77,3 +86,19 @@ def zip(self): The lower-case filename with zip extension. """ return self.value.lower() + Ext.ZIP + + +class Imports(SimpleNamespace): + """Common sets of import vars.""" + + EVENTS = { + "react": {ImportVar(tag="useContext")}, + f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")}, + f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)}, + } + + +class Hooks(SimpleNamespace): + """Common sets of hook declarations.""" + + EVENTS = f"const [{CompileVars.ADD_EVENTS}, {CompileVars.CONNECT_ERROR}] = useContext(EventLoopContext);" diff --git a/reflex/style.py b/reflex/style.py index a6db990c94..da0bcbd5db 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -7,7 +7,8 @@ from reflex import constants from reflex.event import EventChain from reflex.utils import format -from reflex.vars import BaseVar, ImportVar, Var, VarData +from reflex.utils.imports import ImportVar +from reflex.vars import BaseVar, Var, VarData VarData.update_forward_refs() color_mode_var_data = VarData( @@ -70,6 +71,7 @@ def update(self, style_dict: dict | None, **kwargs): Args: style_dict: The style dictionary. + kwargs: Other key value pairs to apply to the dict update. """ if kwargs: style_dict = {**style_dict, **kwargs} diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index d1e37ae5d8..0a4dd589ba 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -3,11 +3,9 @@ from __future__ import annotations from collections import defaultdict -from typing import Dict, Set +from typing import Dict, Optional, Set -from reflex.vars import ImportVar - -ImportDict = Dict[str, Set[ImportVar]] +from reflex.base import Base def merge_imports(*imports) -> ImportDict: @@ -25,3 +23,48 @@ def merge_imports(*imports) -> ImportDict: for field in fields: all_imports[lib].add(field) return all_imports + + +class ImportVar(Base): + """An import var.""" + + # The name of the import tag. + tag: Optional[str] + + # whether the import is default or named. + is_default: Optional[bool] = False + + # The tag alias. + alias: Optional[str] = None + + # Whether this import need to install the associated lib + install: Optional[bool] = True + + # whether this import should be rendered or not + render: Optional[bool] = True + + @property + def name(self) -> str: + """The name of the import. + + Returns: + The name(tag name with alias) of tag. + """ + return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore + + def __hash__(self) -> int: + """Define a hash function for the import var. + + Returns: + The hash of the var. + """ + return hash((self.tag, self.is_default, self.alias, self.install, self.render)) + + +class NoRenderImportVar(ImportVar): + """A import that doesn't need to be rendered.""" + + render: Optional[bool] = False + + +ImportDict = Dict[str, Set[ImportVar]] diff --git a/reflex/vars.py b/reflex/vars.py index e7c253c395..41635eea20 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -31,7 +31,7 @@ from reflex import constants from reflex.base import Base -from reflex.utils import console, format, serializers, types +from reflex.utils import console, format, imports, serializers, types if TYPE_CHECKING: from reflex.state import State @@ -95,11 +95,13 @@ def get_unique_variable_name() -> str: class VarData(Base): + """Metadata associated with a Var.""" + # The name of the enclosing state. state: str = "" # Imports needed to render this var - imports: dict[str, set[ImportVar]] = {} + imports: dict[str, set[imports.ImportVar]] = {} # Hooks that need to be present in the component to render this var hooks: set[str] = set() @@ -135,6 +137,11 @@ def merge(cls, *others: VarData | None) -> VarData | None: ) def __bool__(self) -> bool: + """Check if the var data is non-empty. + + Returns: + True if any field is set to a non-default value. + """ return bool(self.state or self.imports or self.hooks) def dict(self) -> dict: @@ -435,7 +442,7 @@ def __format__(self, format_spec: str) -> str: Returns: The formatted var. """ - # Encode the _var_imports and _var_hooks into the formatted output for tracking purposes. + # Encode the _var_data into the formatted output for tracking purposes. str_self = _encode_var(self) if self._var_is_local: return str_self @@ -836,7 +843,17 @@ def __add__(self, other: Var, flip=False) -> Var: types.get_base_class(self._var_type) == list and types.get_base_class(other_type) == list ): - return self.operation(",", other, fn="spreadArraysOrObjects", flip=flip) + return self.operation( + ",", other, fn="spreadArraysOrObjects", flip=flip + )._replace( + merge_var_data=VarData( + imports={ + f"/{constants.Dirs.STATE_PATH}": { + imports.ImportVar(tag="spreadArraysOrObjects") + } + }, + ), + ) return self.operation("+", other, flip=flip) def __radd__(self, other: Var) -> Var: @@ -1344,8 +1361,10 @@ def _var_set_state(self, state: Type[State] | str) -> Any: ) }, imports={ - f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="StateContexts")}, - "react": {ImportVar(tag="useContext")}, + f"/{constants.Dirs.CONTEXTS_PATH}": { + imports.ImportVar(tag="StateContexts") + }, + "react": {imports.ImportVar(tag="useContext")}, }, ) self._var_data = VarData.merge(self._var_data, new_var_data) @@ -1655,45 +1674,3 @@ def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: cvar = ComputedVar(fget=fget) cvar._cache = True return cvar - - -class ImportVar(Base): - """An import var.""" - - # The name of the import tag. - tag: Optional[str] - - # whether the import is default or named. - is_default: Optional[bool] = False - - # The tag alias. - alias: Optional[str] = None - - # Whether this import need to install the associated lib - install: Optional[bool] = True - - # whether this import should be rendered or not - render: Optional[bool] = True - - @property - def name(self) -> str: - """The name of the import. - - Returns: - The name(tag name with alias) of tag. - """ - return self.tag if not self.alias else " as ".join([self.tag, self.alias]) # type: ignore - - def __hash__(self) -> int: - """Define a hash function for the import var. - - Returns: - The hash of the var. - """ - return hash((self.tag, self.is_default, self.alias, self.install, self.render)) - - -class NoRenderImportVar(ImportVar): - """A import that doesn't need to be rendered.""" - - render: Optional[bool] = False diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 37d5767fbc..fbe9efe13b 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -49,7 +49,8 @@ class Var: cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False ) -> Var: ... @classmethod - def __class_getitem__(cls, type_: str) -> _GenericAlias: ... + def __class_getitem__(cls, type_: Type) -> _GenericAlias: ... + def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: ... def equals(self, other: Var) -> bool: ... def to_string(self) -> Var: ... def __hash__(self) -> int: ... @@ -126,18 +127,3 @@ class ComputedVar(Var): def __init__(self, func) -> None: ... def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ... - -class ImportVar(Base): - tag: Optional[str] - is_default: Optional[bool] = False - alias: Optional[str] = None - install: Optional[bool] = True - render: Optional[bool] = True - @property - def name(self) -> str: ... - def __hash__(self) -> int: ... - -class NoRenderImportVar(ImportVar): - """A import that doesn't need to be rendered.""" - -def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ... From e8de0e93cc8f129db0a3d41bcf9275e8b3154ba8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 7 Nov 2023 15:51:32 -0800 Subject: [PATCH 05/29] Fixup hydrate middleware and _var_state --- reflex/middleware/hydrate_middleware.py | 2 +- reflex/vars.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 062ff33b3f..26467efd8e 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -51,7 +51,7 @@ async def preprocess( setattr(var_state, var_name, value) # Get the initial state. - delta = format.format_state({state.get_name(): state.dict()}) + delta = format.format_state(state.dict()) # since a full dict was captured, clean any dirtiness state._clean() diff --git a/reflex/vars.py b/reflex/vars.py index 41635eea20..cd46adcf4c 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -212,6 +212,8 @@ def _extract_var_data(value: Iterable) -> VarData | None: var_data = VarData.merge(var_data, _extract_var_data(sub.values())) # Recurse into iterable values (or dict keys) var_data = VarData.merge(var_data, _extract_var_data(sub)) + if hasattr(value, "values") and callable(value.values): + var_data = VarData.merge(var_data, _extract_var_data(value.values())) return var_data @@ -1371,6 +1373,15 @@ def _var_set_state(self, state: Type[State] | str) -> Any: self._var_full_name_needs_state_prefix = True return self + @property + def _var_state(self) -> str: + """Compat method for getting the state. + + Returns: + The state name associated with the var. + """ + return self._var_data.state if self._var_data else "" + @dataclasses.dataclass( eq=False, From 27f527caa16389095897f2437cbbf5b320ddec8b Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 7 Nov 2023 15:54:40 -0800 Subject: [PATCH 06/29] Fixup tests --- tests/compiler/test_compiler.py | 2 +- tests/components/test_component.py | 3 +- tests/middleware/test_hydrate_middleware.py | 10 +++--- tests/test_state.py | 39 +++++++++++++-------- tests/test_var.py | 8 ++--- tests/utils/test_format.py | 1 - 6 files changed, 36 insertions(+), 27 deletions(-) diff --git a/tests/compiler/test_compiler.py b/tests/compiler/test_compiler.py index 5329b97782..1d1e29df27 100644 --- a/tests/compiler/test_compiler.py +++ b/tests/compiler/test_compiler.py @@ -5,7 +5,7 @@ from reflex.compiler import compiler, utils from reflex.utils import imports -from reflex.vars import ImportVar +from reflex.utils.imports import ImportVar @pytest.mark.parametrize( diff --git a/tests/components/test_component.py b/tests/components/test_component.py index c8f83fc42d..e6dfe78b48 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -11,7 +11,8 @@ from reflex.state import State from reflex.style import Style from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var @pytest.fixture diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 150083bd59..269be789db 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -104,7 +104,7 @@ async def test_preprocess( app=app, event=request.getfixturevalue(event_fixture), state=state ) assert isinstance(update, StateUpdate) - assert update.delta == {state.get_name(): state.dict()} + assert update.delta == state.dict() events = update.events assert len(events) == 2 @@ -133,16 +133,16 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1): update = await hydrate_middleware.preprocess(app=app, event=event1, state=state) assert isinstance(update, StateUpdate) - assert update.delta == {"test_state": state.dict()} + assert update.delta == state.dict() assert len(update.events) == 3 # Apply the events. events = update.events update = await state._process(events[0]).__anext__() - assert update.delta == {"test_state": {"num": 1}} + assert update.delta == {"num": 1} update = await state._process(events[1]).__anext__() - assert update.delta == {"test_state": {"num": 2}} + assert update.delta == {"num": 2} update = await state._process(events[2]).__anext__() assert update.delta == exp_is_hydrated(state) @@ -163,7 +163,7 @@ async def test_preprocess_no_events(hydrate_middleware, event1): state=state, ) assert isinstance(update, StateUpdate) - assert update.delta == {"test_state": state.dict()} + assert update.delta == state.dict() assert len(update.events) == 1 assert isinstance(update, StateUpdate) diff --git a/tests/test_state.py b/tests/test_state.py index b8673fb867..7e17214f97 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -324,11 +324,17 @@ def test_dict(test_state): Args: test_state: A state. """ - substates = {"child_state", "child_state2"} - assert set(test_state.dict().keys()) == set(test_state.vars.keys()) | substates - assert ( - set(test_state.dict(include_computed=False).keys()) - == set(test_state.base_vars) | substates + substates = { + "test_state", + "test_state.child_state", + "test_state.child_state.grandchild_state", + "test_state.child_state2", + } + test_state_dict = test_state.dict() + assert set(test_state_dict) == substates + assert set(test_state_dict[test_state.get_name()]) == set(test_state.vars) + assert set(test_state.dict(include_computed=False)[test_state.get_name()]) == set( + test_state.base_vars ) @@ -1081,9 +1087,9 @@ def comp_v(self) -> int: return self.v cs = ComputedState() - assert cs.dict()["v"] == 0 + assert cs.dict()[cs.get_full_name()]["v"] == 0 assert comp_v_calls == 1 - assert cs.dict()["comp_v"] == 0 + assert cs.dict()[cs.get_full_name()]["comp_v"] == 0 assert comp_v_calls == 1 assert cs.comp_v == 0 assert comp_v_calls == 1 @@ -1156,24 +1162,27 @@ def dep_v(self) -> int: assert ps.dirty_vars == set() assert cs.dirty_vars == set() - assert ps.dict() == { - cs.get_name(): {"dep_v": 2}, + dict1 = ps.dict() + assert dict1[ps.get_full_name()] == { "no_cache_v": 1, CompileVars.IS_HYDRATED: False, "router": formatted_router, } - assert ps.dict() == { - cs.get_name(): {"dep_v": 4}, + assert dict1[cs.get_full_name()] == {"dep_v": 2} + dict2 = ps.dict() + assert dict2[ps.get_full_name()] == { "no_cache_v": 3, CompileVars.IS_HYDRATED: False, "router": formatted_router, } - assert ps.dict() == { - cs.get_name(): {"dep_v": 6}, + assert dict2[cs.get_full_name()] == {"dep_v": 4} + dict3 = ps.dict() + assert dict3[ps.get_full_name()] == { "no_cache_v": 5, CompileVars.IS_HYDRATED: False, "router": formatted_router, } + assert dict3[cs.get_full_name()] == {"dep_v": 6} assert counter == 6 @@ -2201,13 +2210,13 @@ class MutableContainsBase(State): items: List[Foo] = [Foo()] dict_val = MutableContainsBase().dict() - assert isinstance(dict_val["items"][0], dict) + assert isinstance(dict_val[MutableContainsBase.get_full_name()]["items"][0], dict) val = json_dumps(dict_val) f_items = '[{"tags": ["123", "456"]}]' f_formatted_router = str(formatted_router).replace("'", '"') assert ( val - == f'{{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}' + == f'{{"{MutableContainsBase.get_full_name()}": {{"is_hydrated": false, "items": {f_items}, "router": {f_formatted_router}}}}}' ) diff --git a/tests/test_var.py b/tests/test_var.py index 04af19b008..e3cb28a1c7 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -7,10 +7,10 @@ from reflex.base import Base from reflex.state import State +from reflex.utils.imports import ImportVar from reflex.vars import ( BaseVar, ComputedVar, - ImportVar, Var, ) @@ -553,10 +553,10 @@ def test_var_unsupported_indexing_dicts(var, index): "fixture,full_name", [ ("ParentState", "parent_state.var_without_annotation"), - ("ChildState", "parent_state.child_state.var_without_annotation"), + ("ChildState", "parent_state__child_state.var_without_annotation"), ( "GrandChildState", - "parent_state.child_state.grand_child_state.var_without_annotation", + "parent_state__child_state__grand_child_state.var_without_annotation", ), ("StateWithAnyVar", "state_with_any_var.var_without_annotation"), ], @@ -636,7 +636,7 @@ def test_import_var(import_var, expected): (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"), ( f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}", - "testing f-string with $_var_state=state{state.myvar}", + 'testing f-string with ${"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}{state.myvar}', ), ( f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}", diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 8e66999626..994c4243e8 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -341,7 +341,6 @@ def test_format_cond(condition: str, true_value: str, false_value: str, expected BaseVar( _var_name="_", _var_type=Any, - _var_state="", _var_is_local=True, _var_is_string=False, ), From da6a3676a49229d56143ccfa56f341fc8589c834 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 7 Nov 2023 15:54:57 -0800 Subject: [PATCH 07/29] Do not always create new vars in style.py --- reflex/style.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/reflex/style.py b/reflex/style.py index da0bcbd5db..6f2ed77544 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -46,10 +46,12 @@ def convert(style_dict): key = format.to_camel_case(key) if isinstance(value, dict): out[key], new_var_data = convert(value) + elif isinstance(value, Var): + new_var_data = value._var_data + out[key] = str(value) else: - new_var = Var.create(value, _var_is_string=True) - out[key] = str(new_var) - new_var_data = new_var._var_data + new_var_data = Var.create(value)._var_data + out[key] = value var_data = VarData.merge(var_data, new_var_data) return out, var_data From c9ebed47a8eb984e16648273016a651f12feed94 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 7 Nov 2023 16:00:18 -0800 Subject: [PATCH 08/29] Fixup change from State.dict returning flat structure --- tests/middleware/test_hydrate_middleware.py | 4 +- tests/utils/test_format.py | 58 +++++++++++---------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 269be789db..94f6e98cf2 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -139,10 +139,10 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1): # Apply the events. events = update.events update = await state._process(events[0]).__anext__() - assert update.delta == {"num": 1} + assert update.delta == {'test_state': {'num': 1}} update = await state._process(events[1]).__anext__() - assert update.delta == {"num": 2} + assert update.delta == {'test_state': {'num': 2}} update = await state._process(events[2]).__anext__() assert update.delta == exp_is_hydrated(state) diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 994c4243e8..9b5dc6505e 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -8,7 +8,7 @@ from reflex.style import Style from reflex.utils import format from reflex.vars import BaseVar, Var -from tests.test_state import ChildState, DateTimeState, GrandchildState, TestState +from tests.test_state import ChildState, ChildState2, DateTimeState, GrandchildState, TestState def mock_event(arg): @@ -506,40 +506,44 @@ def test_format_query_params(input, output): ( TestState().dict(), # type: ignore { - "array": [1, 2, 3.14], - "child_state": { + TestState.get_full_name(): { + "array": [1, 2, 3.14], + "complex": { + 1: {"prop1": 42, "prop2": "hello"}, + 2: {"prop1": 42, "prop2": "hello"}, + }, + "dt": "1989-11-09 18:53:00+01:00", + "fig": [], + "is_hydrated": False, + "key": "", + "map_key": "a", + "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]}, + "num1": 0, + "num2": 3.14, + "obj": {"prop1": 42, "prop2": "hello"}, + "sum": 3.14, + "upper": "", + "router": formatted_router, + }, + ChildState.get_full_name(): { "count": 23, - "grandchild_state": {"value2": ""}, "value": "", }, - "child_state2": {"value": ""}, - "complex": { - 1: {"prop1": 42, "prop2": "hello"}, - 2: {"prop1": 42, "prop2": "hello"}, - }, - "dt": "1989-11-09 18:53:00+01:00", - "fig": [], - "is_hydrated": False, - "key": "", - "map_key": "a", - "mapping": {"a": [1, 2, 3], "b": [4, 5, 6]}, - "num1": 0, - "num2": 3.14, - "obj": {"prop1": 42, "prop2": "hello"}, - "sum": 3.14, - "upper": "", - "router": formatted_router, + ChildState2.get_full_name(): {"value": ""}, + GrandchildState.get_full_name(): {"value2": ""}, }, ), ( DateTimeState().dict(), { - "d": "1989-11-09", - "dt": "1989-11-09 18:53:00+01:00", - "is_hydrated": False, - "t": "18:53:00+01:00", - "td": "11 days, 0:11:00", - "router": formatted_router, + DateTimeState.get_full_name(): { + "d": "1989-11-09", + "dt": "1989-11-09 18:53:00+01:00", + "is_hydrated": False, + "t": "18:53:00+01:00", + "td": "11 days, 0:11:00", + "router": formatted_router, + }, }, ), ], From aa574e557620a990465e5982e1b4252d5fbafc14 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 8 Nov 2023 09:57:41 -0800 Subject: [PATCH 09/29] Cond var must be _var_is_local=False It needs to render with curly braces wrapped around it --- reflex/components/layout/cond.py | 1 + 1 file changed, 1 insertion(+) diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py index 75395779af..5ceb5248f6 100644 --- a/reflex/components/layout/cond.py +++ b/reflex/components/layout/cond.py @@ -141,5 +141,6 @@ def cond(condition: Any, c1: Any, c2: Any = None): is_prop=True, ), _var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1), + _var_is_local=False, _var_full_name_needs_state_prefix=False, ) From 846fe51ebc0b70c33eef8f0ba4fd3d5af82b8d68 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 8 Nov 2023 09:58:50 -0800 Subject: [PATCH 10/29] Component: do not mutate event_triggers Instead avoid rendering triggers in Component.render, but otherwise leave them alone so that components may be rendered multiple times without changing the output. --- reflex/components/component.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index d0e40b47dc..531081851e 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -503,7 +503,11 @@ def render(self) -> Dict: tag = self._render() rendered_dict = dict( tag.add_props( - **self.event_triggers, + **{ + trigger: handler + for trigger, handler in self.event_triggers.items() + if trigger not in {EventTriggers.ON_MOUNT, EventTriggers.ON_UNMOUNT} + }, key=self.key, id=self.id, class_name=self.class_name, @@ -743,13 +747,13 @@ def _get_mount_lifecycle_hook(self) -> str | None: """ # pop on_mount and on_unmount from event_triggers since these are handled by # hooks, not as actually props in the component - on_mount = self.event_triggers.pop(EventTriggers.ON_MOUNT, None) - on_unmount = self.event_triggers.pop(EventTriggers.ON_UNMOUNT, None) - if on_mount: + on_mount = self.event_triggers.get(EventTriggers.ON_MOUNT, None) + on_unmount = self.event_triggers.get(EventTriggers.ON_UNMOUNT, None) + if on_mount is not None: on_mount = format.format_event_chain(on_mount) - if on_unmount: + if on_unmount is not None: on_unmount = format.format_event_chain(on_unmount) - if on_mount or on_unmount: + if on_mount is not None or on_unmount is not None: return f""" useEffect(() => {{ {on_mount or ""} From b9960d377b792db891adebc4a3c723e57af95926 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 8 Nov 2023 12:59:56 -0800 Subject: [PATCH 11/29] Fixup static issues --- reflex/components/component.py | 17 ++++++++++++----- reflex/components/datadisplay/datatable.py | 6 ++++-- reflex/components/forms/colormodeswitch.py | 4 ++-- reflex/components/layout/cond.py | 19 +++++++++++++++++-- reflex/components/layout/html.py | 4 ++-- reflex/components/overlay/banner.py | 2 +- reflex/style.py | 12 ++++++++---- reflex/vars.py | 19 +++++++++++-------- reflex/vars.pyi | 3 +++ tests/components/layout/test_cond.py | 2 +- tests/middleware/test_hydrate_middleware.py | 4 ++-- tests/utils/test_format.py | 8 +++++++- 12 files changed, 70 insertions(+), 30 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 531081851e..a89ed4ffe9 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -561,7 +561,7 @@ def validate_valid_child(child_name): @staticmethod def _get_vars_from_event_triggers( event_triggers: dict[str, EventChain | Var], - ) -> Iterator[str, list[Var]]: + ) -> Iterator[tuple[str, list[Var]]]: """Get the Vars associated with each event trigger. Args: @@ -592,8 +592,10 @@ def _get_vars(self) -> Iterator[Var]: if isinstance(self.contents, Var): yield self.contents else: - for _, vars in self._get_vars_from_event_triggers(self.event_triggers): - yield from vars + for _, event_vars in self._get_vars_from_event_triggers( + self.event_triggers + ): + yield from event_vars for prop in self.get_props(): prop_var = getattr(self, prop) @@ -690,7 +692,11 @@ def _get_dependencies_imports(self) -> imports.ImportDict: ) def _get_hooks_imports(self) -> imports.ImportDict: - """Get the imports required by certain hooks.""" + """Get the imports required by certain hooks. + + Returns: + The imports required for all selected hooks. + """ _imports = {} if self._get_ref_hook(): _imports.setdefault("react", set()).add(ImportVar(tag="useRef")) @@ -1130,7 +1136,8 @@ class NoSSRComponent(Component): def _get_imports(self) -> imports.ImportDict: dynamic_import = {"next/dynamic": {ImportVar(tag="dynamic", is_default=True)}} _imports = super()._get_imports() - _imports[self.library] = {ImportVar(tag=None, render=False)} + if self.library is not None: + _imports[self.library] = {ImportVar(tag=None, render=False)} return imports.merge_imports( dynamic_import, _imports, diff --git a/reflex/components/datadisplay/datatable.py b/reflex/components/datadisplay/datatable.py index a5a097a6c0..5b66fa54b9 100644 --- a/reflex/components/datadisplay/datatable.py +++ b/reflex/components/datadisplay/datatable.py @@ -113,11 +113,13 @@ def _render(self) -> Tag: self.columns = BaseVar( _var_name=f"{self.data._var_name}.columns", _var_type=List[Any], - )._var_set_state(self.data._var_state) + _var_full_name_needs_state_prefix=True, + )._replace(merge_var_data=self.data._var_data) self.data = BaseVar( _var_name=f"{self.data._var_name}.data", _var_type=List[List[Any]], - )._var_set_state(self.data._var_state) + _var_full_name_needs_state_prefix=True, + )._replace(merge_var_data=self.data._var_data) if types.is_dataframe(type(self.data)): # If given a pandas df break up the data and columns data = serialize(self.data) diff --git a/reflex/components/forms/colormodeswitch.py b/reflex/components/forms/colormodeswitch.py index 0aeefcd33a..ca071a2a7a 100644 --- a/reflex/components/forms/colormodeswitch.py +++ b/reflex/components/forms/colormodeswitch.py @@ -22,7 +22,7 @@ from reflex.components.layout.cond import Cond, cond from reflex.components.media.icon import Icon from reflex.style import color_mode, toggle_color_mode -from reflex.vars import BaseVar +from reflex.vars import Var from .button import Button from .switch import Switch @@ -32,7 +32,7 @@ DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon") -def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component: +def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: """Create a component or Prop based on color_mode. Args: diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py index 5ceb5248f6..4ef2320557 100644 --- a/reflex/components/layout/cond.py +++ b/reflex/components/layout/cond.py @@ -1,7 +1,7 @@ """Create a list of components from an iterable.""" from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, overload from reflex.components.component import Component from reflex.components.layout.fragment import Fragment @@ -88,6 +88,21 @@ def _get_imports(self) -> imports.ImportDict: ) +@overload +def cond(condition: Any, c1: Component, c2: Any) -> Component: + ... + + +@overload +def cond(condition: Any, c1: Component) -> Component: + ... + + +@overload +def cond(condition: Any, c1: Any, c2: Any) -> Var: + ... + + def cond(condition: Any, c1: Any, c2: Any = None): """Create a conditional component or Prop. @@ -109,7 +124,7 @@ def cond(condition: Any, c1: Any, c2: Any = None): cond_var = Var.create(condition) assert cond_var is not None, "The condition must be set." cond_var = cond_var._replace( - merge_var_data=VarData( + merge_var_data=VarData( # type: ignore imports={ f"/{Dirs.STATE_PATH}": { imports.ImportVar(tag="isTrue"), diff --git a/reflex/components/layout/html.py b/reflex/components/layout/html.py index df155ffde0..3a4ba76ad5 100644 --- a/reflex/components/layout/html.py +++ b/reflex/components/layout/html.py @@ -1,5 +1,5 @@ """A html component.""" - +from typing import Dict from reflex.components.layout.box import Box from reflex.vars import Var @@ -13,7 +13,7 @@ class Html(Box): """ # The HTML to render. - dangerouslySetInnerHTML: Var[dict[str, str]] + dangerouslySetInnerHTML: Var[Dict[str, str]] @classmethod def create(cls, *children, **props): diff --git a/reflex/components/overlay/banner.py b/reflex/components/overlay/banner.py index db3f7dcb8e..86c07885bc 100644 --- a/reflex/components/overlay/banner.py +++ b/reflex/components/overlay/banner.py @@ -12,7 +12,7 @@ from reflex.utils import imports from reflex.vars import Var, VarData -connect_error_var_data = VarData( +connect_error_var_data = VarData( # type: ignore imports=Imports.EVENTS, hooks={Hooks.EVENTS}, ) diff --git a/reflex/style.py b/reflex/style.py index 6f2ed77544..fe957f4c8b 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -11,7 +11,7 @@ from reflex.vars import BaseVar, Var, VarData VarData.update_forward_refs() -color_mode_var_data = VarData( +color_mode_var_data = VarData( # type: ignore imports={ f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, }, @@ -44,13 +44,16 @@ def convert(style_dict): out = {} for key, value in style_dict.items(): key = format.to_camel_case(key) + new_var_data = None if isinstance(value, dict): out[key], new_var_data = convert(value) elif isinstance(value, Var): new_var_data = value._var_data out[key] = str(value) else: - new_var_data = Var.create(value)._var_data + new_var = Var.create(value) + if new_var is not None: + new_var_data = new_var._var_data out[key] = value var_data = VarData.merge(var_data, new_var_data) return out, var_data @@ -76,7 +79,7 @@ def update(self, style_dict: dict | None, **kwargs): kwargs: Other key value pairs to apply to the dict update. """ if kwargs: - style_dict = {**style_dict, **kwargs} + style_dict = {**(style_dict or {}), **kwargs} converted_dict = type(self)(style_dict) self._var_data = VarData.merge(self._var_data, converted_dict._var_data) super().update(converted_dict) @@ -89,5 +92,6 @@ def __setitem__(self, key: str, value: Any): value: The value to set. """ _var = Var.create(value) - self._var_data = VarData.merge(self._var_data, _var._var_data) + if _var is not None: + self._var_data = VarData.merge(self._var_data, _var._var_data) super().__setitem__(key, value) diff --git a/reflex/vars.py b/reflex/vars.py index cd46adcf4c..2a5db2917e 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -19,6 +19,7 @@ List, Literal, Optional, + Set, Tuple, Type, Union, @@ -101,10 +102,10 @@ class VarData(Base): state: str = "" # Imports needed to render this var - imports: dict[str, set[imports.ImportVar]] = {} + imports: Dict[str, Set[imports.ImportVar]] = {} # Hooks that need to be present in the component to render this var - hooks: set[str] = set() + hooks: Set[str] = set() @classmethod def merge(cls, *others: VarData | None) -> VarData | None: @@ -174,7 +175,7 @@ def _encode_var(value: Var) -> str: return str(value) -def _decode_var(value: str) -> tuple[VarData, str]: +def _decode_var(value: str) -> tuple[VarData | None, str]: """Decode the state name from a formatted var. Args: @@ -207,13 +208,15 @@ def _extract_var_data(value: Iterable) -> VarData | None: if isinstance(sub, Var): var_data = VarData.merge(var_data, sub._var_data) elif not isinstance(sub, str): - # Recurse into dict values + # Recurse into dict values. if hasattr(sub, "values") and callable(sub.values): var_data = VarData.merge(var_data, _extract_var_data(sub.values())) - # Recurse into iterable values (or dict keys) + # Recurse into iterable values (or dict keys). var_data = VarData.merge(var_data, _extract_var_data(sub)) - if hasattr(value, "values") and callable(value.values): - var_data = VarData.merge(var_data, _extract_var_data(value.values())) + # Recurse when value is a dict itself. + values = getattr(value, "values", None) + if callable(values): + var_data = VarData.merge(var_data, _extract_var_data(values())) return var_data @@ -1339,7 +1342,7 @@ def _var_full_name(self) -> str: return self._var_name return ( self._var_name - if self._var_data.state == "" + if self._var_data is None or self._var_data.state == "" else ".".join( [format.format_state_name(self._var_data.state), self._var_name] ) diff --git a/reflex/vars.pyi b/reflex/vars.pyi index fbe9efe13b..15fb3f19b0 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -6,6 +6,7 @@ from reflex import constants as constants from reflex.base import Base as Base from reflex.state import State as State from reflex.utils import console as console, format as format, types as types +from reflex.utils.imports import ImportVar from types import FunctionType from typing import ( Any, @@ -39,6 +40,7 @@ class Var: _var_type: Type _var_is_local: bool = False _var_is_string: bool = False + _var_full_name_needs_state_prefix: bool = False _var_data: VarData | None = None @classmethod def create( @@ -108,6 +110,7 @@ class BaseVar(Var): _var_type: Any _var_is_local: bool = False _var_is_string: bool = False + _var_full_name_needs_state_prefix: bool = False _var_data: VarData | None = None def __hash__(self) -> int: ... def get_default_value(self) -> Any: ... diff --git a/tests/components/layout/test_cond.py b/tests/components/layout/test_cond.py index 3bf373bb4d..00cf4de7d5 100644 --- a/tests/components/layout/test_cond.py +++ b/tests/components/layout/test_cond.py @@ -110,7 +110,7 @@ def test_cond_no_else(): # Props do not support the use of cond without else with pytest.raises(ValueError): - cond(True, "hello") + cond(True, "hello") # type: ignore def test_mobile_only(): diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 94f6e98cf2..7767dcf8b2 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -139,10 +139,10 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1): # Apply the events. events = update.events update = await state._process(events[0]).__anext__() - assert update.delta == {'test_state': {'num': 1}} + assert update.delta == {"test_state": {"num": 1}} update = await state._process(events[1]).__anext__() - assert update.delta == {'test_state': {'num': 2}} + assert update.delta == {"test_state": {"num": 2}} update = await state._process(events[2]).__anext__() assert update.delta == exp_is_hydrated(state) diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 9b5dc6505e..dc642d0594 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -8,7 +8,13 @@ from reflex.style import Style from reflex.utils import format from reflex.vars import BaseVar, Var -from tests.test_state import ChildState, ChildState2, DateTimeState, GrandchildState, TestState +from tests.test_state import ( + ChildState, + ChildState2, + DateTimeState, + GrandchildState, + TestState, +) def mock_event(arg): From 5157ed239889130e5bf83d73b645b5fab5f77b8e Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 13 Nov 2023 11:57:00 -0800 Subject: [PATCH 12/29] client_side_routing: use `rx.cond` instead of the `Cond` component Using `rx.cond` helper ensures that the condition carries the `isTrue` import. --- reflex/components/navigation/client_side_routing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/reflex/components/navigation/client_side_routing.py b/reflex/components/navigation/client_side_routing.py index 22620e54b2..177c69bb42 100644 --- a/reflex/components/navigation/client_side_routing.py +++ b/reflex/components/navigation/client_side_routing.py @@ -13,7 +13,7 @@ from ...vars import Var from ..component import Component -from ..layout.cond import Cond +from ..layout.cond import cond route_not_found: Var = Var.create_safe(constants.ROUTE_NOT_FOUND) @@ -52,10 +52,10 @@ def wait_for_client_redirect(component) -> Component: Returns: The conditionally rendered component. """ - return Cond.create( - cond=route_not_found, - comp1=component, - comp2=ClientSideRouting.create(), + return cond( + condition=route_not_found, + c1=component, + c2=ClientSideRouting.create(), ) From 06d8fa23ba94b6eb8c0a4e02eaf4b9cfe45b4122 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 13 Nov 2023 11:58:20 -0800 Subject: [PATCH 13/29] Fixup color_mode_toggle var Do not surround the formatted values with extra curly braces --- reflex/style.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/style.py b/reflex/style.py index fe957f4c8b..98e65894c3 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -16,7 +16,7 @@ f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, }, hooks={ - f"const [ {{{constants.ColorMode.NAME}}}, {{{constants.ColorMode.TOGGLE}}} ] = useContext(ColorModeContext)", + f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)", }, ) color_mode = BaseVar( From 78ddc79765bbe134e201c35db46df76ed15479ea Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 15 Nov 2023 13:49:54 -0800 Subject: [PATCH 14/29] Add tests for Component._get_vars Ensure that Vars used in different contexts in a Component are correctly identified and returned. --- tests/components/test_component.py | 163 ++++++++++++++++++++++++++++- 1 file changed, 161 insertions(+), 2 deletions(-) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index e6dfe78b48..9ffe74699d 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -4,15 +4,16 @@ import reflex as rx from reflex.base import Base +from reflex.components.base.bare import Bare from reflex.components.component import Component, CustomComponent, custom_component from reflex.components.layout.box import Box from reflex.constants import EventTriggers -from reflex.event import EventHandler +from reflex.event import EventChain, EventHandler from reflex.state import State from reflex.style import Style from reflex.utils import imports from reflex.utils.imports import ImportVar -from reflex.vars import Var +from reflex.vars import Var, VarData @pytest.fixture @@ -601,3 +602,161 @@ def test_format_component(component, rendered): rendered: The expected rendered component. """ assert str(component) == rendered + + +TEST_VAR = Var.create_safe("test")._replace( + merge_var_data=VarData( + hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test" + ) +) +FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar") +STYLE_VAR = TEST_VAR._replace(_var_name="style", _var_is_local=False) +EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain) +ARG_VAR = Var.create("arg") + + +class EventState(rx.State): + """State for testing event handlers with _get_vars.""" + + v: int = 42 + + def handler(self): + """A handler that does nothing.""" + + def handler2(self, arg): + """A handler that takes an arg. + + Args: + arg: An arg. + """ + + +@pytest.mark.parametrize( + ("component", "exp_vars"), + ( + pytest.param( + Bare.create(TEST_VAR), + [TEST_VAR], + id="direct-bare", + ), + pytest.param( + Bare.create(f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-bare", + ), + pytest.param( + rx.text(as_=TEST_VAR), + [TEST_VAR], + id="direct-prop", + ), + pytest.param( + rx.text(as_=f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-prop", + ), + pytest.param( + rx.fragment(id=TEST_VAR), + [TEST_VAR], + id="direct-id", + ), + pytest.param( + rx.fragment(id=f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-id", + ), + pytest.param( + rx.fragment(key=TEST_VAR), + [TEST_VAR], + id="direct-key", + ), + pytest.param( + rx.fragment(key=f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-key", + ), + pytest.param( + rx.fragment(class_name=TEST_VAR), + [TEST_VAR], + id="direct-class_name", + ), + pytest.param( + rx.fragment(class_name=f"foo{TEST_VAR}bar"), + [FORMATTED_TEST_VAR], + id="fstring-class_name", + ), + pytest.param( + rx.fragment(special_props={TEST_VAR}), + [TEST_VAR], + id="direct-special_props", + ), + pytest.param( + rx.fragment(special_props={Var.create(f"foo{TEST_VAR}bar")}), + [FORMATTED_TEST_VAR], + id="fstring-special_props", + ), + pytest.param( + # custom_attrs cannot accept a Var directly as a value + rx.fragment(custom_attrs={"href": f"{TEST_VAR}"}), + [TEST_VAR], + id="fstring-custom_attrs-nofmt", + ), + pytest.param( + rx.fragment(custom_attrs={"href": f"foo{TEST_VAR}bar"}), + [FORMATTED_TEST_VAR], + id="fstring-custom_attrs", + ), + pytest.param( + rx.fragment(background_color=TEST_VAR), + [STYLE_VAR], + id="direct-background_color", + ), + pytest.param( + rx.fragment(background_color=f"foo{TEST_VAR}bar"), + [STYLE_VAR], + id="fstring-background_color", + ), + pytest.param( + rx.fragment(style={"background_color": TEST_VAR}), + [STYLE_VAR], + id="direct-style-background_color", + ), + pytest.param( + rx.fragment(style={"background_color": f"foo{TEST_VAR}bar"}), + [STYLE_VAR], + id="fstring-style-background_color", + ), + pytest.param( + rx.fragment(on_click=EVENT_CHAIN_VAR), + [EVENT_CHAIN_VAR], + id="direct-event-chain", + ), + pytest.param( + rx.fragment(on_click=EventState.handler), + [], + id="direct-event-handler", + ), + pytest.param( + rx.fragment(on_click=EventState.handler2(TEST_VAR)), # type: ignore + [ARG_VAR, TEST_VAR], + id="direct-event-handler-arg", + ), + pytest.param( + rx.fragment(on_click=EventState.handler2(EventState.v)), # type: ignore + [ARG_VAR, EventState.v], + id="direct-event-handler-arg2", + ), + pytest.param( + rx.fragment(on_click=lambda: EventState.handler2(TEST_VAR)), # type: ignore + [ARG_VAR, TEST_VAR], + id="direct-event-handler-lambda", + ), + ), +) +def test_get_vars(component, exp_vars): + comp_vars = sorted(component._get_vars(), key=lambda v: v._var_name) + assert len(comp_vars) == len(exp_vars) + for comp_var, exp_var in zip( + comp_vars, + sorted(exp_vars, key=lambda v: v._var_name), + ): + assert comp_var.equals(exp_var) From 62b93320964594c2f15ba8ff88bb5efa42503b63 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 15 Nov 2023 13:50:42 -0800 Subject: [PATCH 15/29] Apply CR feedback (from myself) --- reflex/.templates/web/utils/state.js | 2 +- reflex/components/component.py | 17 +++++++++++++++++ reflex/components/layout/cond.py | 11 ++++++----- reflex/vars.py | 11 +++++++++-- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index fe821b131e..aa818bf63a 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -510,7 +510,7 @@ export const useEventLoop = ( // Route after the initial page hydration. useEffect(() => { - const change_complete = () => addEvents(initialEvents()) + const change_complete = () => addEvents(initial_events()) router.events.on('routeChangeComplete', change_complete) return () => { router.events.off('routeChangeComplete', change_complete) diff --git a/reflex/components/component.py b/reflex/components/component.py index 64ca5218c2..48546d5746 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -615,6 +615,23 @@ def _get_vars(self) -> Iterator[Var]: _var_data=self.style._var_data, ) + yield from self.special_props + + for comp_prop in ( + self.class_name, + self.id, + self.key, + self.autofocus, + *self.custom_attrs.values(), + ): + if isinstance(comp_prop, Var): + yield comp_prop + elif isinstance(comp_prop, str): + # catch f-strings containing Vars + var = Var.create_safe(comp_prop) + if var._var_data is not None: + yield var + def _get_custom_code(self) -> str | None: """Get custom code for the component. diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py index 4ef2320557..2b883d8161 100644 --- a/reflex/components/layout/cond.py +++ b/reflex/components/layout/cond.py @@ -10,6 +10,10 @@ from reflex.utils import format, imports from reflex.vars import Var, VarData +_IS_TRUE_IMPORT = { + f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")}, +} + class Cond(Component): """Render one of two components based on a condition.""" @@ -85,6 +89,7 @@ def _get_imports(self) -> imports.ImportDict: return imports.merge_imports( super()._get_imports(), getattr(self.cond._var_data, "imports", {}), + _IS_TRUE_IMPORT, ) @@ -125,11 +130,7 @@ def cond(condition: Any, c1: Any, c2: Any = None): assert cond_var is not None, "The condition must be set." cond_var = cond_var._replace( merge_var_data=VarData( # type: ignore - imports={ - f"/{Dirs.STATE_PATH}": { - imports.ImportVar(tag="isTrue"), - }, - }, + imports=_IS_TRUE_IMPORT, ), ) diff --git a/reflex/vars.py b/reflex/vars.py index 75f6f0be52..7020cad827 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -332,8 +332,15 @@ def __post_init__(self) -> None: self._var_data = VarData.merge(self._var_data, _var_data) def _replace(self, merge_var_data=None, **kwargs: Any) -> Var: - # Cannot use dataclasses.replace because ComputedVar uses multiple inheritance - # and it's __init__ has a required fget argument + """Make a copy of this Var with updated fields. + + Args: + merge_var_data: VarData to merge into the existing VarData. + **kwargs: Var fields to update. + + Returns: + A new BaseVar with the updated fields overwriting the corresponding fields in this Var. + """ field_values = dict( _var_name=kwargs.pop("_var_name", self._var_name), _var_type=kwargs.pop("_var_type", self._var_type), From a356bcb86422d328ed3b01be5d9f8a63b640f058 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 15 Nov 2023 13:56:11 -0800 Subject: [PATCH 16/29] ensure 'from reflex.var import ImportVar' keeps working --- reflex/vars.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/reflex/vars.py b/reflex/vars.py index 7020cad827..95a11265be 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -35,6 +35,9 @@ from reflex.base import Base from reflex.utils import console, format, imports, serializers, types +# This module used to export ImportVar itself, so we still import it for export here +from reflex.utils.imports import ImportVar as ImportVar + if TYPE_CHECKING: from reflex.state import State From 90159b0b7ec199ef7d5c2b14c2a7edd15be1caee Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Sat, 18 Nov 2023 21:59:38 -0800 Subject: [PATCH 17/29] cond: Carry VarData before `format_cond` `format_cond` doesn't really work with f-string semantics due to a blind `{`, `}` removal which breaks the JSON-encoding of VarData. So instead, we carry the VarData explicitly on the cond_var and change `format_cond` to `str` the Var before formatting to ensure it does not emit a VarData-encoded f-string. --- reflex/components/layout/cond.py | 19 +++++++++++-------- reflex/utils/format.py | 10 ++++------ 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py index 2b883d8161..2abe3219bf 100644 --- a/reflex/components/layout/cond.py +++ b/reflex/components/layout/cond.py @@ -8,7 +8,7 @@ from reflex.components.tags import CondTag, Tag from reflex.constants import Dirs from reflex.utils import format, imports -from reflex.vars import Var, VarData +from reflex.vars import BaseVar, Var, VarData _IS_TRUE_IMPORT = { f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")}, @@ -122,17 +122,15 @@ def cond(condition: Any, c1: Any, c2: Any = None): Raises: ValueError: If the arguments are invalid. """ - # Import here to avoid circular imports. - from reflex.vars import BaseVar, Var + var_datas: list[VarData | None] = [ + VarData( # type: ignore + imports=_IS_TRUE_IMPORT, + ), + ] # Convert the condition to a Var. cond_var = Var.create(condition) assert cond_var is not None, "The condition must be set." - cond_var = cond_var._replace( - merge_var_data=VarData( # type: ignore - imports=_IS_TRUE_IMPORT, - ), - ) # If the first component is a component, create a Cond component. if isinstance(c1, Component): @@ -140,6 +138,8 @@ def cond(condition: Any, c1: Any, c2: Any = None): c2, Component ), "Both arguments must be components." return Cond.create(cond_var, c1, c2) + elif isinstance(c1, Var): + var_datas.append(c1._var_data) # Otherwise, create a conditional Var. # Check that the second argument is valid. @@ -147,6 +147,8 @@ def cond(condition: Any, c1: Any, c2: Any = None): raise ValueError("Both arguments must be props.") if c2 is None: raise ValueError("For conditional vars, the second argument must be set.") + elif isinstance(c2, Var): + var_datas.append(c2._var_data) # Create the conditional var. return cond_var._replace( @@ -159,4 +161,5 @@ def cond(condition: Any, c1: Any, c2: Any = None): _var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1), _var_is_local=False, _var_full_name_needs_state_prefix=False, + merge_var_data=VarData.merge(*var_datas), ) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 15a73249fe..089fc9d800 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -230,9 +230,9 @@ def format_route(route: str, format_case=True) -> str: def format_cond( - cond: str, - true_value: str, - false_value: str = '""', + cond: str | Var, + true_value: str | Var, + false_value: str | Var = '""', is_prop=False, ) -> str: """Format a conditional expression. @@ -246,9 +246,6 @@ def format_cond( Returns: The formatted conditional expression. """ - # Import here to avoid circular imports. - from reflex.vars import Var - # Use Python truthiness. cond = f"isTrue({cond})" @@ -264,6 +261,7 @@ def format_cond( _var_is_string=type(false_value) is str, ) prop2._var_is_local = True + prop1, prop2 = str(prop1), str(prop2) # avoid f-string semantics for Var return f"{cond} ? {prop1} : {prop2}".replace("{", "").replace("}", "") # Format component conds. From 1ee57023b1508765b118036f157d0e168e64e113 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Sat, 18 Nov 2023 22:14:47 -0800 Subject: [PATCH 18/29] upload: uploadFiles is not actually a required import --- reflex/components/forms/upload.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/reflex/components/forms/upload.py b/reflex/components/forms/upload.py index d69649f9f4..b7e4370e52 100644 --- a/reflex/components/forms/upload.py +++ b/reflex/components/forms/upload.py @@ -192,8 +192,5 @@ def _get_imports(self) -> imports.ImportDict: super()._get_imports(), { "react": {imports.ImportVar(tag="useState")}, - f"/{Dirs.STATE_PATH}": { - imports.ImportVar(tag="uploadFiles"), - }, }, ) From 1ef7dee472663f4a3f46a2f3e26d3b6864939d7a Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Sat, 18 Nov 2023 22:23:14 -0800 Subject: [PATCH 19/29] Update pyi files --- reflex/components/datadisplay/code.pyi | 3 ++- reflex/components/datadisplay/datatable.pyi | 2 +- reflex/components/datadisplay/moment.pyi | 2 +- reflex/components/forms/colormodeswitch.pyi | 4 ++-- reflex/components/forms/debounce.pyi | 3 ++- reflex/components/forms/editor.pyi | 3 ++- reflex/components/forms/form.pyi | 2 +- reflex/components/forms/input.pyi | 2 +- reflex/components/forms/upload.pyi | 3 ++- reflex/components/layout/html.pyi | 7 +++++-- reflex/components/libs/chakra.pyi | 2 +- reflex/components/navigation/client_side_routing.pyi | 2 +- reflex/components/overlay/banner.py | 2 +- reflex/components/overlay/banner.pyi | 7 ++++--- reflex/components/radix/themes/base.pyi | 2 +- reflex/components/typography/markdown.pyi | 3 ++- 16 files changed, 29 insertions(+), 20 deletions(-) diff --git a/reflex/components/datadisplay/code.pyi b/reflex/components/datadisplay/code.pyi index 79e756ed0a..5a88faeb3c 100644 --- a/reflex/components/datadisplay/code.pyi +++ b/reflex/components/datadisplay/code.pyi @@ -16,7 +16,8 @@ from reflex.components.media import Icon from reflex.event import set_clipboard from reflex.style import Style from reflex.utils import format, imports -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var LiteralCodeBlockTheme = Literal[ "a11y-dark", diff --git a/reflex/components/datadisplay/datatable.pyi b/reflex/components/datadisplay/datatable.pyi index 1a119dcc95..49e3ad7528 100644 --- a/reflex/components/datadisplay/datatable.pyi +++ b/reflex/components/datadisplay/datatable.pyi @@ -12,7 +12,7 @@ from reflex.components.component import Component from reflex.components.tags import Tag from reflex.utils import imports, types from reflex.utils.serializers import serialize, serializer -from reflex.vars import BaseVar, ComputedVar, ImportVar, Var +from reflex.vars import BaseVar, ComputedVar, Var class Gridjs(Component): @overload diff --git a/reflex/components/datadisplay/moment.pyi b/reflex/components/datadisplay/moment.pyi index 0e9fcc4c4c..2c76960372 100644 --- a/reflex/components/datadisplay/moment.pyi +++ b/reflex/components/datadisplay/moment.pyi @@ -10,7 +10,7 @@ from reflex.style import Style from typing import Any, Dict, List from reflex.components.component import Component, NoSSRComponent from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class Moment(NoSSRComponent): def get_event_triggers(self) -> Dict[str, Any]: ... diff --git a/reflex/components/forms/colormodeswitch.pyi b/reflex/components/forms/colormodeswitch.pyi index 478bf5c37c..83af8f20ce 100644 --- a/reflex/components/forms/colormodeswitch.pyi +++ b/reflex/components/forms/colormodeswitch.pyi @@ -12,7 +12,7 @@ from reflex.components.component import Component from reflex.components.layout.cond import Cond, cond from reflex.components.media.icon import Icon from reflex.style import color_mode, toggle_color_mode -from reflex.vars import BaseVar +from reflex.vars import Var from .button import Button from .switch import Switch @@ -20,7 +20,7 @@ DEFAULT_COLOR_MODE: str DEFAULT_LIGHT_ICON: Icon DEFAULT_DARK_ICON: Icon -def color_mode_cond(light: Any, dark: Any = None) -> BaseVar | Component: ... +def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: ... class ColorModeIcon(Cond): @overload diff --git a/reflex/components/forms/debounce.pyi b/reflex/components/forms/debounce.pyi index 8c2688f942..975aa8be23 100644 --- a/reflex/components/forms/debounce.pyi +++ b/reflex/components/forms/debounce.pyi @@ -7,9 +7,10 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style -from typing import Any +from typing import Any, Set from reflex.components import Component from reflex.components.tags import Tag +from reflex.utils import imports from reflex.vars import Var class DebounceInput(Component): diff --git a/reflex/components/forms/editor.pyi b/reflex/components/forms/editor.pyi index 2b806fdef3..1eaeea0384 100644 --- a/reflex/components/forms/editor.pyi +++ b/reflex/components/forms/editor.pyi @@ -13,7 +13,8 @@ from reflex.base import Base from reflex.components.component import Component, NoSSRComponent from reflex.constants import EventTriggers from reflex.utils.format import to_camel_case -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var class EditorButtonList(list, enum.Enum): BASIC = [["font", "fontSize"], ["fontColor"], ["horizontalRule"], ["link", "image"]] diff --git a/reflex/components/forms/form.pyi b/reflex/components/forms/form.pyi index 827b2273d6..e1a9c325a0 100644 --- a/reflex/components/forms/form.pyi +++ b/reflex/components/forms/form.pyi @@ -12,7 +12,7 @@ from jinja2 import Environment from reflex.components.component import Component from reflex.components.libs.chakra import ChakraComponent from reflex.components.tags import Tag -from reflex.constants import EventTriggers +from reflex.constants import Dirs, EventTriggers from reflex.event import EventChain from reflex.utils import imports from reflex.utils.format import format_event_chain, to_camel_case diff --git a/reflex/components/forms/input.pyi b/reflex/components/forms/input.pyi index 6ec79fe1f5..d49a6a6dfd 100644 --- a/reflex/components/forms/input.pyi +++ b/reflex/components/forms/input.pyi @@ -17,7 +17,7 @@ from reflex.components.libs.chakra import ( ) from reflex.constants import EventTriggers from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class Input(ChakraComponent): def get_event_triggers(self) -> Dict[str, Any]: ... diff --git a/reflex/components/forms/upload.pyi b/reflex/components/forms/upload.pyi index 87e5d58c75..5486ee90d8 100644 --- a/reflex/components/forms/upload.pyi +++ b/reflex/components/forms/upload.pyi @@ -12,9 +12,10 @@ from reflex import constants from reflex.components.component import Component from reflex.components.forms.input import Input from reflex.components.layout.box import Box +from reflex.constants import Dirs from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script from reflex.utils import imports -from reflex.vars import BaseVar, CallableVar, ImportVar, Var +from reflex.vars import BaseVar, CallableVar, Var, VarData DEFAULT_UPLOAD_ID: str diff --git a/reflex/components/layout/html.pyi b/reflex/components/layout/html.pyi index ec513ca29d..aec46e2f97 100644 --- a/reflex/components/layout/html.pyi +++ b/reflex/components/layout/html.pyi @@ -7,8 +7,9 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style -from typing import Any +from typing import Dict from reflex.components.layout.box import Box +from reflex.vars import Var class Html(Box): @overload @@ -16,7 +17,9 @@ class Html(Box): def create( # type: ignore cls, *children, - dangerouslySetInnerHTML: Optional[Any] = None, + dangerouslySetInnerHTML: Optional[ + Union[Var[Dict[str, str]], Dict[str, str]] + ] = None, element: Optional[Union[Var[str], str]] = None, src: Optional[Union[Var[str], str]] = None, alt: Optional[Union[Var[str], str]] = None, diff --git a/reflex/components/libs/chakra.pyi b/reflex/components/libs/chakra.pyi index 4679c3f8d6..031ec2c48e 100644 --- a/reflex/components/libs/chakra.pyi +++ b/reflex/components/libs/chakra.pyi @@ -10,7 +10,7 @@ from reflex.style import Style from typing import List, Literal from reflex.components.component import Component from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var class ChakraComponent(Component): @overload diff --git a/reflex/components/navigation/client_side_routing.pyi b/reflex/components/navigation/client_side_routing.pyi index 100d0adb21..92f8cd611b 100644 --- a/reflex/components/navigation/client_side_routing.pyi +++ b/reflex/components/navigation/client_side_routing.pyi @@ -10,7 +10,7 @@ from reflex.style import Style from reflex import constants from ...vars import Var from ..component import Component -from ..layout.cond import Cond +from ..layout.cond import cond route_not_found: Var diff --git a/reflex/components/overlay/banner.py b/reflex/components/overlay/banner.py index 86c07885bc..1da9a6749e 100644 --- a/reflex/components/overlay/banner.py +++ b/reflex/components/overlay/banner.py @@ -12,7 +12,7 @@ from reflex.utils import imports from reflex.vars import Var, VarData -connect_error_var_data = VarData( # type: ignore +connect_error_var_data: VarData = VarData( # type: ignore imports=Imports.EVENTS, hooks={Hooks.EVENTS}, ) diff --git a/reflex/components/overlay/banner.pyi b/reflex/components/overlay/banner.pyi index 4e855ae1d7..db3cafd04e 100644 --- a/reflex/components/overlay/banner.pyi +++ b/reflex/components/overlay/banner.pyi @@ -10,15 +10,16 @@ from reflex.style import Style from typing import Optional from reflex.components.base.bare import Bare from reflex.components.component import Component -from reflex.components.layout import Box, Cond +from reflex.components.layout import Box, cond from reflex.components.overlay.modal import Modal from reflex.components.typography import Text +from reflex.constants import Hooks, Imports from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var, VarData +connect_error_var_data: VarData connection_error: Var has_connection_error: Var -has_connection_error._var_type = bool class WebsocketTargetURL(Bare): @overload diff --git a/reflex/components/radix/themes/base.pyi b/reflex/components/radix/themes/base.pyi index eb8b8cb30e..9f840f2a88 100644 --- a/reflex/components/radix/themes/base.pyi +++ b/reflex/components/radix/themes/base.pyi @@ -10,7 +10,7 @@ from reflex.style import Style from typing import Literal from reflex.components import Component from reflex.utils import imports -from reflex.vars import ImportVar, Var +from reflex.vars import Var LiteralAlign = Literal["start", "center", "end", "baseline", "stretch"] LiteralJustify = Literal["start", "center", "end", "between"] diff --git a/reflex/components/typography/markdown.pyi b/reflex/components/typography/markdown.pyi index ee11bf2557..cb9140435a 100644 --- a/reflex/components/typography/markdown.pyi +++ b/reflex/components/typography/markdown.pyi @@ -18,7 +18,8 @@ from reflex.components.typography.heading import Heading from reflex.components.typography.text import Text from reflex.style import Style from reflex.utils import console, imports, types -from reflex.vars import ImportVar, Var +from reflex.utils.imports import ImportVar +from reflex.vars import Var _CHILDREN = Var.create_safe("children", _var_is_local=False) _PROPS = Var.create_safe("...props", _var_is_local=False) From 006ca6691a933e6abe07b857b3e0eae66182b43d Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 16:27:03 -0800 Subject: [PATCH 20/29] Move Bare special case into Bare class --- reflex/components/base/bare.py | 12 +++++++++++- reflex/components/component.py | 20 ++++++-------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 021cd4cf05..ee66ecd4d5 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -1,7 +1,7 @@ """A bare component.""" from __future__ import annotations -from typing import Any +from typing import Any, Iterator from reflex.components.component import Component from reflex.components.tags import Tag @@ -32,3 +32,13 @@ def create(cls, contents: Any) -> Component: def _render(self) -> Tag: return Tagless(contents=str(self.contents)) + + def _get_vars(self) -> Iterator[Var]: + """Walk all Vars used in this component. + + Yields: + The contents if it is a Var, otherwise nothing. + """ + if isinstance(self.contents, Var): + # Fast path for Bare text components. + yield self.contents diff --git a/reflex/components/component.py b/reflex/components/component.py index 243e57e65b..1bacdb0083 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -594,21 +594,13 @@ def _get_vars(self) -> Iterator[Var]: Yields: Each var referenced by the component (props, styles, event handlers). """ - from reflex.components.base.bare import Bare + for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers): + yield from event_vars - if isinstance(self, Bare): - if isinstance(self.contents, Var): - yield self.contents - else: - for _, event_vars in self._get_vars_from_event_triggers( - self.event_triggers - ): - yield from event_vars - - for prop in self.get_props(): - prop_var = getattr(self, prop) - if isinstance(prop_var, Var): - yield prop_var + for prop in self.get_props(): + prop_var = getattr(self, prop) + if isinstance(prop_var, Var): + yield prop_var if self.style: yield BaseVar( From 3c54c669f274e658e39fb2cc09e06194ff0f8727 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 16:31:58 -0800 Subject: [PATCH 21/29] use the | operator instead of union --- reflex/components/component.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 1bacdb0083..59dbbb44e6 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -851,9 +851,9 @@ def _get_hooks_internal(self) -> Set[str]: for hook in [self._get_mount_lifecycle_hook(), self._get_ref_hook()] if hook ) - .union(self._get_vars_hooks()) - .union(self._get_events_hooks()) - .union(self._get_special_hooks()) + | self._get_vars_hooks() + | self._get_events_hooks() + | self._get_special_hooks() ) def _get_hooks(self) -> str | None: From 475694b2505714f4e1528921310c5415189464d0 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 16:45:07 -0800 Subject: [PATCH 22/29] remove weird indentation --- reflex/components/forms/upload.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/reflex/components/forms/upload.py b/reflex/components/forms/upload.py index b7e4370e52..89954442cb 100644 --- a/reflex/components/forms/upload.py +++ b/reflex/components/forms/upload.py @@ -181,11 +181,8 @@ def _render(self): def _get_hooks(self) -> str | None: return ( - (super()._get_hooks() or "") - + f""" - upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]); - """ - ) + super()._get_hooks() or "" + ) + f"upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);" def _get_imports(self) -> imports.ImportDict: return imports.merge_imports( From 0f36629483397ca5283a8a83e999d1814e612741 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 16:48:23 -0800 Subject: [PATCH 23/29] cond: remove unnecessary `elif` --- reflex/components/layout/cond.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reflex/components/layout/cond.py b/reflex/components/layout/cond.py index 2abe3219bf..1a5cb24998 100644 --- a/reflex/components/layout/cond.py +++ b/reflex/components/layout/cond.py @@ -138,7 +138,7 @@ def cond(condition: Any, c1: Any, c2: Any = None): c2, Component ), "Both arguments must be components." return Cond.create(cond_var, c1, c2) - elif isinstance(c1, Var): + if isinstance(c1, Var): var_datas.append(c1._var_data) # Otherwise, create a conditional Var. @@ -147,7 +147,7 @@ def cond(condition: Any, c1: Any, c2: Any = None): raise ValueError("Both arguments must be props.") if c2 is None: raise ValueError("For conditional vars, the second argument must be set.") - elif isinstance(c2, Var): + if isinstance(c2, Var): var_datas.append(c2._var_data) # Create the conditional var. From 96b6090ae4124c62023908b70629cbf6fe404c8e Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 16:50:23 -0800 Subject: [PATCH 24/29] remove relative imports --- reflex/components/navigation/client_side_routing.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/reflex/components/navigation/client_side_routing.py b/reflex/components/navigation/client_side_routing.py index 177c69bb42..99f0f6cfd6 100644 --- a/reflex/components/navigation/client_side_routing.py +++ b/reflex/components/navigation/client_side_routing.py @@ -10,10 +10,9 @@ from __future__ import annotations from reflex import constants - -from ...vars import Var -from ..component import Component -from ..layout.cond import cond +from reflex.components.component import Component +from reflex.components.layout.cond import cond +from reflex.vars import Var route_not_found: Var = Var.create_safe(constants.ROUTE_NOT_FOUND) From 70cb5174bdf4d5c994b70de53a7aab0cf93eabc8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 17:02:40 -0800 Subject: [PATCH 25/29] style: improve comments and explainations --- reflex/style.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/reflex/style.py b/reflex/style.py index 98e65894c3..d67029992c 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -10,20 +10,25 @@ from reflex.utils.imports import ImportVar from reflex.vars import BaseVar, Var, VarData -VarData.update_forward_refs() +VarData.update_forward_refs() # Ensure all type definitions are resolved + +# Reference the global ColorModeContext color_mode_var_data = VarData( # type: ignore imports={ f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, + "react": {ImportVar(tag="useContext")}, }, hooks={ f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)", }, ) +# Var resolves to the current color mode for the app ("light" or "dark") color_mode = BaseVar( _var_name=constants.ColorMode.NAME, _var_type="str", _var_data=color_mode_var_data, ) +# Var resolves to a function invocation that toggles the color mode toggle_color_mode = BaseVar( _var_name=constants.ColorMode.TOGGLE, _var_type=EventChain, @@ -40,21 +45,25 @@ def convert(style_dict): Returns: The formatted style dictionary. """ - var_data = None + var_data = None # Track import/hook data from any Vars in the style dict. out = {} for key, value in style_dict.items(): key = format.to_camel_case(key) new_var_data = None if isinstance(value, dict): + # Recursively format nested style dictionaries. out[key], new_var_data = convert(value) elif isinstance(value, Var): + # If the value is a Var, extract the var_data and cast as str. new_var_data = value._var_data out[key] = str(value) else: + # Otherwise, convert to Var to collapse VarData encoded in f-string. new_var = Var.create(value) if new_var is not None: new_var_data = new_var._var_data out[key] = value + # Combine all the collected VarData instances. var_data = VarData.merge(var_data, new_var_data) return out, var_data @@ -81,6 +90,7 @@ def update(self, style_dict: dict | None, **kwargs): if kwargs: style_dict = {**(style_dict or {}), **kwargs} converted_dict = type(self)(style_dict) + # Combine our VarData with that of any Vars in the style_dict that was passed. self._var_data = VarData.merge(self._var_data, converted_dict._var_data) super().update(converted_dict) @@ -91,7 +101,9 @@ def __setitem__(self, key: str, value: Any): key: The key to set. value: The value to set. """ + # Create a Var to collapse VarData encoded in f-string. _var = Var.create(value) if _var is not None: + # Carry the imports/hooks when setting a Var as a value. self._var_data = VarData.merge(self._var_data, _var._var_data) super().__setitem__(key, value) From 569ff864d519fafb4093e0179862965690a053f8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 17:20:01 -0800 Subject: [PATCH 26/29] component: add comments and give the code some space to breathe --- reflex/components/component.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 59dbbb44e6..210c7e16e5 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -594,14 +594,17 @@ def _get_vars(self) -> Iterator[Var]: Yields: Each var referenced by the component (props, styles, event handlers). """ + # Get Vars associated with event trigger arguments. for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers): yield from event_vars + # Get Vars associated with component props. for prop in self.get_props(): prop_var = getattr(self, prop) if isinstance(prop_var, Var): yield prop_var + # Style keeps track of its own VarData instance, so embed in a temp Var that is yielded. if self.style: yield BaseVar( _var_name="style", @@ -609,8 +612,10 @@ def _get_vars(self) -> Iterator[Var]: _var_data=self.style._var_data, ) + # Special props are always Var instances. yield from self.special_props + # Get Vars associated with common Component props. for comp_prop in ( self.class_name, self.id, @@ -621,7 +626,7 @@ def _get_vars(self) -> Iterator[Var]: if isinstance(comp_prop, Var): yield comp_prop elif isinstance(comp_prop, str): - # catch f-strings containing Vars + # Collapse VarData encoded in f-strings. var = Var.create_safe(comp_prop) if var._var_data is not None: yield var @@ -715,12 +720,18 @@ def _get_hooks_imports(self) -> imports.ImportDict: The imports required for all selected hooks. """ _imports = {} + if self._get_ref_hook(): + # Handle hooks needed for attaching react refs to DOM nodes. _imports.setdefault("react", set()).add(ImportVar(tag="useRef")) _imports.setdefault(f"/{Dirs.STATE_PATH}", set()).add(ImportVar(tag="refs")) + if self._get_mount_lifecycle_hook(): + # Handle hooks for `on_mount` / `on_unmount`. _imports.setdefault("react", set()).add(ImportVar(tag="useEffect")) + if self._get_special_hooks(): + # Handle additional internal hooks (autofocus, etc). _imports.setdefault("react", set()).update( { ImportVar(tag="useRef"), @@ -736,13 +747,19 @@ def _get_imports(self) -> imports.ImportDict: The imports needed by the component. """ _imports = {} + + # Import this component's tag from the main library. if self.library is not None and self.tag is not None: _imports[self.library] = {self.import_var} + + # Get static imports required for event processing. event_imports = Imports.EVENTS if self.event_triggers else {} - # determine imports from Vars + + # Collect imports from Vars used directly by this component. var_imports = [ var._var_data.imports for var in self._get_vars() if var._var_data ] + return imports.merge_imports( self._get_props_imports(), self._get_dependencies_imports(), @@ -1152,10 +1169,21 @@ class NoSSRComponent(Component): """A dynamic component that is not rendered on the server.""" def _get_imports(self) -> imports.ImportDict: + """Get the imports for the component. + + Returns: + The imports for dynamically importing the component at module load time. + """ + # Next.js dynamic import mechanism. dynamic_import = {"next/dynamic": {ImportVar(tag="dynamic", is_default=True)}} + + # The normal imports for this component. _imports = super()._get_imports() + + # Do NOT import the main library/tag statically. if self.library is not None: _imports[self.library] = {ImportVar(tag=None, render=False)} + return imports.merge_imports( dynamic_import, _imports, From 00d137ef61a485867c016697d3f0ef200832cc92 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 17:46:40 -0800 Subject: [PATCH 27/29] Remove unused NoRenderImportVar --- reflex/utils/imports.py | 6 ------ reflex/vars.py | 6 ------ reflex/vars.pyi | 3 --- 3 files changed, 15 deletions(-) diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index 0a4dd589ba..03f006917e 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -61,10 +61,4 @@ def __hash__(self) -> int: return hash((self.tag, self.is_default, self.alias, self.install, self.render)) -class NoRenderImportVar(ImportVar): - """A import that doesn't need to be rendered.""" - - render: Optional[bool] = False - - ImportDict = Dict[str, Set[ImportVar]] diff --git a/reflex/vars.py b/reflex/vars.py index 552ffb52a5..f2e1b63598 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1772,12 +1772,6 @@ def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: return cvar -class NoRenderImportVar(ImportVar): - """A import that doesn't need to be rendered.""" - - render: Optional[bool] = False - - class CallableVar(BaseVar): """Decorate a Var-returning function to act as both a Var and a function. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index ef4b154a8c..8105aef1c5 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -138,9 +138,6 @@ class ComputedVar(Var): def cached_var(fget: Callable[[Any], Any]) -> ComputedVar: ... -class NoRenderImportVar(ImportVar): - """A import that doesn't need to be rendered.""" - class CallableVar(BaseVar): def __init__(self, fn: Callable[..., BaseVar]): ... def __call__(self, *args, **kwargs) -> BaseVar: ... From bf385fb10e898260e619f1507b1406e46d861acc Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 17:50:31 -0800 Subject: [PATCH 28/29] client_side_routing.pyi: remove relative imports --- reflex/components/navigation/client_side_routing.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/reflex/components/navigation/client_side_routing.pyi b/reflex/components/navigation/client_side_routing.pyi index 92f8cd611b..b7801246e5 100644 --- a/reflex/components/navigation/client_side_routing.pyi +++ b/reflex/components/navigation/client_side_routing.pyi @@ -8,9 +8,9 @@ from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style from reflex import constants -from ...vars import Var -from ..component import Component -from ..layout.cond import cond +from reflex.components.component import Component +from reflex.components.layout.cond import cond +from reflex.vars import Var route_not_found: Var From 1d7878834daec70b61fa1153435628fa99dc2bc2 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 20 Nov 2023 22:49:50 -0800 Subject: [PATCH 29/29] Performance optimizations * Do not cast to Style() early in `Component.add_style` * Memoize return value of `Component._get_vars` * Defer `VarData.merge` for most operations -- call it once * Avoid `serializers.serialize` for primitive JSON types --- reflex/components/component.py | 31 +++++++++++++++++++------------ reflex/utils/types.py | 1 + reflex/vars.py | 24 +++++++++++++----------- reflex/vars.pyi | 2 +- 4 files changed, 34 insertions(+), 24 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index c269448fc9..70097cf604 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -493,7 +493,7 @@ def add_style(self, style: ComponentStyle) -> Component: """ if type(self) in style: # Extract the style for this component. - component_style = Style(style[type(self)]) + component_style = style[type(self)] # Only add style props that are not overridden. component_style = { @@ -591,32 +591,38 @@ def _get_vars_from_event_triggers( event_args.extend(args) yield event_trigger, event_args - def _get_vars(self) -> Iterator[Var]: + def _get_vars(self) -> list[Var]: """Walk all Vars used in this component. - Yields: + Returns: Each var referenced by the component (props, styles, event handlers). """ + vars = getattr(self, "__vars", None) + if vars is not None: + return vars + vars = self.__vars = [] # Get Vars associated with event trigger arguments. for _, event_vars in self._get_vars_from_event_triggers(self.event_triggers): - yield from event_vars + vars.extend(event_vars) # Get Vars associated with component props. for prop in self.get_props(): prop_var = getattr(self, prop) if isinstance(prop_var, Var): - yield prop_var + vars.append(prop_var) # Style keeps track of its own VarData instance, so embed in a temp Var that is yielded. if self.style: - yield BaseVar( - _var_name="style", - _var_type=str, - _var_data=self.style._var_data, + vars.append( + BaseVar( + _var_name="style", + _var_type=str, + _var_data=self.style._var_data, + ) ) # Special props are always Var instances. - yield from self.special_props + vars.extend(self.special_props) # Get Vars associated with common Component props. for comp_prop in ( @@ -627,12 +633,13 @@ def _get_vars(self) -> Iterator[Var]: *self.custom_attrs.values(), ): if isinstance(comp_prop, Var): - yield comp_prop + vars.append(comp_prop) elif isinstance(comp_prop, str): # Collapse VarData encoded in f-strings. var = Var.create_safe(comp_prop) if var._var_data is not None: - yield var + vars.append(var) + return vars def _get_custom_code(self) -> str | None: """Get custom code for the component. diff --git a/reflex/utils/types.py b/reflex/utils/types.py index c114e94ac7..8066849549 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -27,6 +27,7 @@ GenericType = Union[Type, _GenericAlias] # Valid state var types. +JSONType = {str, int, float, bool} PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple] StateVar = Union[PrimitiveType, Base, None] StateIterVar = Union[list, set, tuple] diff --git a/reflex/vars.py b/reflex/vars.py index 922ce4cb92..d4d785b7ab 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -192,34 +192,36 @@ def _decode_var(value: str) -> tuple[VarData | None, str]: while m := re.match(r"(.*)(.*)(.*)", value): value = m.group(1) + m.group(3) var_datas.append(VarData.parse_raw(m.group(2))) - return VarData.merge(*var_datas), value + if var_datas: + return VarData.merge(*var_datas), value + return None, value -def _extract_var_data(value: Iterable) -> VarData | None: +def _extract_var_data(value: Iterable) -> list[VarData | None]: """Extract the var imports and hooks from an iterable containing a Var. Args: value: The iterable to extract the VarData from Returns: - The extracted VarData. + The extracted VarDatas. """ - var_data = None + var_datas = [] with contextlib.suppress(TypeError): for sub in value: if isinstance(sub, Var): - var_data = VarData.merge(var_data, sub._var_data) + var_datas.append(sub._var_data) elif not isinstance(sub, str): # Recurse into dict values. if hasattr(sub, "values") and callable(sub.values): - var_data = VarData.merge(var_data, _extract_var_data(sub.values())) + var_datas.extend(_extract_var_data(sub.values())) # Recurse into iterable values (or dict keys). - var_data = VarData.merge(var_data, _extract_var_data(sub)) + var_datas.extend(_extract_var_data(sub)) # Recurse when value is a dict itself. values = getattr(value, "values", None) if callable(values): - var_data = VarData.merge(var_data, _extract_var_data(values())) - return var_data + var_datas.extend(_extract_var_data(values())) + return var_datas class Var: @@ -271,11 +273,11 @@ def create( # Try to pull the imports and hooks from contained values. _var_data = None if not isinstance(value, str): - _var_data = _extract_var_data(value) + _var_data = VarData.merge(*_extract_var_data(value)) # Try to serialize the value. type_ = type(value) - name = serializers.serialize(value) + name = value if type_ in types.JSONType else serializers.serialize(value) if name is None: raise TypeError( f"No JSON serializer found for var {value} of type {type_}." diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 8105aef1c5..9208013db6 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -26,7 +26,7 @@ USED_VARIABLES: Incomplete def get_unique_variable_name() -> str: ... def _encode_var(value: Var) -> str: ... def _decode_var(value: str) -> tuple[VarData, str]: ... -def _extract_var_data(value: Iterable) -> VarData | None: ... +def _extract_var_data(value: Iterable) -> list[VarData | None]: ... class VarData(Base): state: str