From aa8f3c77be5c41b9002a1771ff0d58c81e50301f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 7 May 2024 15:49:24 -0700 Subject: [PATCH 1/2] [REF-2787] add_hooks supports Var-wrapped hooks * Fix VarData definition in .pyi file to allow removal of type ignore comments * Var.create and Var.create_safe accept _var_data parameter * Replace instances where a set of imports was being passed to VarData * Update code throughout reduce use of `._replace` to add VarData --- reflex/components/component.py | 57 +++++++++++++++---- reflex/components/core/banner.py | 12 ++-- reflex/components/core/cond.py | 2 +- reflex/components/core/debounce.py | 6 +- reflex/components/core/upload.py | 19 +++---- reflex/components/el/elements/forms.py | 18 ++++-- reflex/components/gridjs/datatable.py | 6 +- .../radix/themes/components/tabs.py | 2 +- reflex/components/sonner/toast.py | 17 +++--- reflex/constants/compiler.py | 6 +- reflex/style.py | 6 +- reflex/vars.py | 18 ++++-- reflex/vars.pyi | 12 ++-- tests/components/test_component.py | 40 ++++++++++++- 14 files changed, 158 insertions(+), 63 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 65b2cfa772..cd2f4644a0 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -241,7 +241,7 @@ def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: """ return {} - def add_hooks(self) -> list[str]: + def add_hooks(self) -> list[str | Var]: """Add hooks inside the component function. Hooks are pieces of literal Javascript code that is inserted inside the @@ -1268,11 +1268,20 @@ def _get_hooks_imports(self) -> imports.ImportDict: }, ) + other_imports = [] user_hooks = self._get_hooks() - if user_hooks is not None and isinstance(user_hooks, Var): - _imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore + if ( + user_hooks is not None + and isinstance(user_hooks, Var) + and user_hooks._var_data is not None + and user_hooks._var_data.imports + ): + other_imports.extend(user_hooks._var_data.imports) + other_imports.extend( + hook_imports for hook_imports in self._get_added_hooks().values() + ) - return _imports + return imports.merge_imports(_imports, *other_imports) def _get_imports(self) -> imports.ImportDict: """Get all the libraries and fields that are used by the component. @@ -1419,6 +1428,36 @@ def _get_hooks_internal(self) -> dict[str, None]: **self._get_special_hooks(), } + def _get_added_hooks(self) -> dict[str, imports.ImportDict]: + """Get the hooks added via `add_hooks` method. + + Returns: + The deduplicated hooks and imports added by the component and parent components. + """ + code = {} + + def extract_var_hooks(hook: Var): + _imports = {} + if hook._var_data is not None: + for sub_hook in hook._var_data.hooks: + code[sub_hook] = {} + if hook._var_data.imports: + _imports = hook._var_data.imports + if str(hook) in code: + code[str(hook)] = imports.merge_imports(code[str(hook)], _imports) + else: + code[str(hook)] = _imports + + # Add the hook code from add_hooks for each parent class (this is reversed to preserve + # the order of the hooks in the final output) + for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))): + for hook in clz.add_hooks(self): + if isinstance(hook, Var): + extract_var_hooks(hook) + else: + code[hook] = {} + return code + def _get_hooks(self) -> str | None: """Get the React hooks for this component. @@ -1457,11 +1496,7 @@ def _get_all_hooks(self) -> dict[str, None]: if hooks is not None: code[hooks] = None - # Add the hook code from add_hooks for each parent class (this is reversed to preserve - # the order of the hooks in the final output) - for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))): - for hook in clz.add_hooks(self): - code[hook] = None + code.update(self._get_added_hooks()) # Add the hook code for the children. for child in self.children: @@ -2095,8 +2130,8 @@ def _get_memoized_event_triggers( var_deps.extend(cls._get_hook_deps(hook)) memo_var_data = VarData.merge( *[var._var_data for var in event_args], - VarData( # type: ignore - imports={"react": {ImportVar(tag="useCallback")}}, + VarData( + imports={"react": [ImportVar(tag="useCallback")]}, ), ) diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py index 0c781fba8b..c2fe3e6886 100644 --- a/reflex/components/core/banner.py +++ b/reflex/components/core/banner.py @@ -29,23 +29,27 @@ value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''", _var_is_local=False, _var_is_string=False, -)._replace(merge_var_data=connect_error_var_data) + _var_data=connect_error_var_data, +) connection_errors_count: Var = Var.create_safe( value="connectErrors.length", _var_is_string=False, _var_is_local=False, -)._replace(merge_var_data=connect_error_var_data) + _var_data=connect_error_var_data, +) has_connection_errors: Var = Var.create_safe( value="connectErrors.length > 0", _var_is_string=False, -)._replace(_var_type=bool, merge_var_data=connect_error_var_data) + _var_data=connect_error_var_data, +).to(bool) has_too_many_connection_errors: Var = Var.create_safe( value="connectErrors.length >= 2", _var_is_string=False, -)._replace(_var_type=bool, merge_var_data=connect_error_var_data) + _var_data=connect_error_var_data, +).to(bool) class WebsocketTargetURL(Bare): diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 0e3e436725..9ace92b98b 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -13,7 +13,7 @@ from reflex.vars import BaseVar, Var, VarData _IS_TRUE_IMPORT = { - f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")}, + f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")], } diff --git a/reflex/components/core/debounce.py b/reflex/components/core/debounce.py index e24a6563db..7b3fd60180 100644 --- a/reflex/components/core/debounce.py +++ b/reflex/components/core/debounce.py @@ -109,13 +109,11 @@ def create(cls, *children: Component, **props: Any) -> Component: "{%s}" % (child.alias or child.tag), _var_is_local=False, _var_is_string=False, - )._replace( - _var_type=Type[Component], - merge_var_data=VarData( # type: ignore + _var_data=VarData( imports=child._get_imports(), hooks=child._get_hooks_internal(), ), - ), + ).to(Type[Component]), ) component = super().create(**props) diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index b3ac37c15d..65c441924c 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -24,12 +24,12 @@ DEFAULT_UPLOAD_ID: str = "default" -upload_files_context_var_data: VarData = VarData( # type: ignore +upload_files_context_var_data: VarData = VarData( imports={ - "react": {imports.ImportVar(tag="useContext")}, - f"/{Dirs.CONTEXTS_PATH}": { + "react": [imports.ImportVar(tag="useContext")], + f"/{Dirs.CONTEXTS_PATH}": [ imports.ImportVar(tag="UploadFilesContext"), - }, + ], }, hooks={ "const [filesById, setFilesById] = useContext(UploadFilesContext);": None, @@ -118,14 +118,13 @@ def get_upload_dir() -> Path: uploaded_files_url_prefix: Var = Var.create_safe( - "${getBackendURL(env.UPLOAD)}" -)._replace( - merge_var_data=VarData( # type: ignore + "${getBackendURL(env.UPLOAD)}", + _var_data=VarData( imports={ - f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")}, - "/env.json": {imports.ImportVar(tag="env", is_default=True)}, + f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")], + "/env.json": [imports.ImportVar(tag="env", is_default=True)], } - ) + ), ) diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index a98bd47c7f..37051b2797 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -216,13 +216,17 @@ def _get_form_refs(self) -> Dict[str, Any]: if ref.startswith("refs_"): ref_var = Var.create_safe(ref[:-3]).as_ref() form_refs[ref[5:-3]] = Var.create_safe( - f"getRefValues({str(ref_var)})", _var_is_local=False - )._replace(merge_var_data=ref_var._var_data) + f"getRefValues({str(ref_var)})", + _var_is_local=False, + _var_data=ref_var._var_data, + ) else: ref_var = Var.create_safe(ref).as_ref() form_refs[ref[4:]] = Var.create_safe( - f"getRefValue({str(ref_var)})", _var_is_local=False - )._replace(merge_var_data=ref_var._var_data) + f"getRefValue({str(ref_var)})", + _var_is_local=False, + _var_data=ref_var._var_data, + ) return form_refs def _get_vars(self, include_children: bool = True) -> Iterator[Var]: @@ -619,14 +623,16 @@ def _render(self) -> Tag: on_key_down=Var.create_safe( f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})", _var_is_local=False, - )._replace(merge_var_data=self.enter_key_submit._var_data), + _var_data=self.enter_key_submit._var_data, + ) ) if self.auto_height is not None: tag.add_props( on_input=Var.create_safe( f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})", _var_is_local=False, - )._replace(merge_var_data=self.auto_height._var_data), + _var_data=self.auto_height._var_data, + ) ) return tag diff --git a/reflex/components/gridjs/datatable.py b/reflex/components/gridjs/datatable.py index 6c05dfd811..fd0a220212 100644 --- a/reflex/components/gridjs/datatable.py +++ b/reflex/components/gridjs/datatable.py @@ -114,12 +114,14 @@ def _render(self) -> Tag: _var_name=f"{self.data._var_name}.columns", _var_type=List[Any], _var_full_name_needs_state_prefix=True, - )._replace(merge_var_data=self.data._var_data) + _var_data=self.data._var_data, + ) self.data = BaseVar( _var_name=f"{self.data._var_name}.data", _var_type=List[List[Any]], _var_full_name_needs_state_prefix=True, - )._replace(merge_var_data=self.data._var_data) + _var_data=self.data._var_data, + ) if types.is_dataframe(type(self.data)): # If given a pandas df break up the data and columns data = serialize(self.data) diff --git a/reflex/components/radix/themes/components/tabs.py b/reflex/components/radix/themes/components/tabs.py index af1b6b5218..130cfd166a 100644 --- a/reflex/components/radix/themes/components/tabs.py +++ b/reflex/components/radix/themes/components/tabs.py @@ -68,7 +68,7 @@ class TabsTrigger(RadixThemesComponent): _valid_parents: List[str] = ["TabsList"] @classmethod - def create(self, *children, **props) -> Component: + def create(cls, *children, **props) -> Component: """Create a TabsTrigger component. Args: diff --git a/reflex/components/sonner/toast.py b/reflex/components/sonner/toast.py index 8a6c59a54f..5781f1ee44 100644 --- a/reflex/components/sonner/toast.py +++ b/reflex/components/sonner/toast.py @@ -98,7 +98,7 @@ class ToastProps(PropsBase): class Toaster(Component): """A Toaster Component for displaying toast notifications.""" - library = "sonner@1.4.41" + library: str = "sonner@1.4.41" tag = "Toaster" @@ -145,12 +145,15 @@ class Toaster(Component): pause_when_page_is_hidden: Var[bool] def _get_hooks(self) -> Var[str]: - hook = Var.create_safe(f"{toast_ref} = toast", _var_is_local=True) - hook._var_data = VarData( # type: ignore - imports={ - "/utils/state": [ImportVar(tag="refs")], - self.library: [ImportVar(tag="toast", install=False)], - } + hook = Var.create_safe( + f"{toast_ref} = toast", + _var_is_local=True, + _var_data=VarData( + imports={ + "/utils/state": [ImportVar(tag="refs")], + self.library: [ImportVar(tag="toast", install=False)], + } + ), ) return hook diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index b99e31e8c7..96e8b03ba7 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -103,9 +103,9 @@ class Imports(SimpleNamespace): """Common sets of import vars.""" EVENTS = { - "react": {ImportVar(tag="useContext")}, - f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")}, - f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)}, + "react": [ImportVar(tag="useContext")], + f"/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")], + f"/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)], } diff --git a/reflex/style.py b/reflex/style.py index d77c2bb7c7..0216eedd9c 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -16,10 +16,10 @@ DARK_COLOR_MODE: str = "dark" # Reference the global ColorModeContext -color_mode_var_data = VarData( # type: ignore +color_mode_var_data = VarData( imports={ - f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")}, - "react": {ImportVar(tag="useContext")}, + f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")], + "react": [ImportVar(tag="useContext")], }, hooks={ f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None, diff --git a/reflex/vars.py b/reflex/vars.py index 4a8e6b30f9..244d131c07 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -340,7 +340,11 @@ class Var: @classmethod def create( - cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False + cls, + value: Any, + _var_is_local: bool = True, + _var_is_string: bool = False, + _var_data: Optional[VarData] = None, ) -> Var | None: """Create a var from a value. @@ -348,6 +352,7 @@ def create( value: The value to create the var from. _var_is_local: Whether the var is local. _var_is_string: Whether the var is a string literal. + _var_data: Additional hooks and imports associated with the Var. Returns: The var. @@ -364,9 +369,8 @@ def create( return value # Try to pull the imports and hooks from contained values. - _var_data = None if not isinstance(value, str): - _var_data = VarData.merge(*_extract_var_data(value)) + _var_data = VarData.merge(*_extract_var_data(value), _var_data) # Try to serialize the value. type_ = type(value) @@ -387,7 +391,11 @@ def create( @classmethod def create_safe( - cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False + cls, + value: Any, + _var_is_local: bool = True, + _var_is_string: bool = False, + _var_data: Optional[VarData] = None, ) -> Var: """Create a var from a value, asserting that it is not None. @@ -395,6 +403,7 @@ def create_safe( value: The value to create the var from. _var_is_local: Whether the var is local. _var_is_string: Whether the var is a string literal. + _var_data: Additional hooks and imports associated with the Var. Returns: The var. @@ -403,6 +412,7 @@ def create_safe( value, _var_is_local=_var_is_local, _var_is_string=_var_is_string, + _var_data=_var_data, ) assert var is not None return var diff --git a/reflex/vars.pyi b/reflex/vars.pyi index fb2ed46573..8251563f86 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -34,10 +34,10 @@ def _decode_var(value: str) -> tuple[VarData, str]: ... def _extract_var_data(value: Iterable) -> list[VarData | None]: ... class VarData(Base): - state: str - imports: dict[str, set[ImportVar]] - hooks: Dict[str, None] - interpolations: List[Tuple[int, int]] + state: str = "" + imports: dict[str, List[ImportVar]] = {} + hooks: Dict[str, None] = {} + interpolations: List[Tuple[int, int]] = [] @classmethod def merge(cls, *others: VarData | None) -> VarData | None: ... @@ -50,11 +50,11 @@ class Var: _var_data: VarData | None = None @classmethod def create( - cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False + cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None, ) -> Optional[Var]: ... @classmethod def create_safe( - cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False + cls, value: Any, _var_is_local: bool = False, _var_is_string: bool = False, _var_data: VarData | None = None, ) -> Var: ... @classmethod def __class_getitem__(cls, type_: Type) -> _GenericAlias: ... diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 96c1b69629..cfa8af7359 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -1063,7 +1063,7 @@ def test_stateful_banner(): TEST_VAR = Var.create_safe("test")._replace( merge_var_data=VarData( hooks={"useTest": None}, - imports={"test": {ImportVar(tag="test")}}, + imports={"test": [ImportVar(tag="test")]}, state="Test", interpolations=[], ) @@ -1951,3 +1951,41 @@ def add_custom_code(self): "const custom_code5 = 46", "const custom_code6 = 47", } + + +def test_component_add_hooks_var(): + class HookComponent(Component): + def add_hooks(self): + return [ + "const hook3 = useRef(null)", + "const hook1 = 42", + Var.create( + "useEffect(() => () => {}, [])", + _var_data=VarData( + hooks={ + "const hook2 = 43": None, + "const hook3 = useRef(null)": None, + }, + imports={"react": [ImportVar(tag="useEffect")]}, + ), + ), + Var.create( + "const hook3 = useRef(null)", + _var_data=VarData( + imports={"react": [ImportVar(tag="useRef")]}, + ), + ), + ] + + assert list(HookComponent()._get_all_hooks()) == [ + "const hook3 = useRef(null)", + "const hook1 = 42", + "const hook2 = 43", + "useEffect(() => () => {}, [])", + ] + imports = HookComponent()._get_all_imports() + assert len(imports) == 1 + assert "react" in imports + assert len(imports["react"]) == 2 + assert ImportVar(tag="useRef") in imports["react"] + assert ImportVar(tag="useEffect") in imports["react"] From f16572436f5bbeee6ad133a577e33dcd2686e09f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 8 May 2024 13:17:25 -0700 Subject: [PATCH 2/2] Fixup: user hooks _var_data.imports will never be iterable, just a single ImportDict --- reflex/components/component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index cd2f4644a0..592e04c3c4 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1276,7 +1276,7 @@ def _get_hooks_imports(self) -> imports.ImportDict: and user_hooks._var_data is not None and user_hooks._var_data.imports ): - other_imports.extend(user_hooks._var_data.imports) + other_imports.append(user_hooks._var_data.imports) other_imports.extend( hook_imports for hook_imports in self._get_added_hooks().values() )