Skip to content

Commit

Permalink
[REF-2787] add_hooks supports Var-wrapped hooks (#3248)
Browse files Browse the repository at this point in the history
* [REF-2787] add_hooks supports Var-wrapped hooks

* Fix VarData definition in .pyi file to allow removal of type ignore comments
* Var.create and Var.create_safe accept _var_data parameter
* Replace instances where a set of imports was being passed to VarData
* Update code throughout reduce use of `._replace` to add VarData

* Fixup: user hooks _var_data.imports will never be iterable, just a single ImportDict
  • Loading branch information
masenf committed May 15, 2024
1 parent d96baac commit c5f32db
Show file tree
Hide file tree
Showing 14 changed files with 158 additions and 63 deletions.
57 changes: 46 additions & 11 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
"""
return {}

def add_hooks(self) -> list[str]:
def add_hooks(self) -> list[str | Var]:
"""Add hooks inside the component function.
Hooks are pieces of literal Javascript code that is inserted inside the
Expand Down Expand Up @@ -1265,11 +1265,20 @@ def _get_hooks_imports(self) -> imports.ImportDict:
},
)

other_imports = []
user_hooks = self._get_hooks()
if user_hooks is not None and isinstance(user_hooks, Var):
_imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore
if (
user_hooks is not None
and isinstance(user_hooks, Var)
and user_hooks._var_data is not None
and user_hooks._var_data.imports
):
other_imports.append(user_hooks._var_data.imports)
other_imports.extend(
hook_imports for hook_imports in self._get_added_hooks().values()
)

return _imports
return imports.merge_imports(_imports, *other_imports)

def _get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component.
Expand Down Expand Up @@ -1416,6 +1425,36 @@ def _get_hooks_internal(self) -> dict[str, None]:
**self._get_special_hooks(),
}

def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
"""Get the hooks added via `add_hooks` method.
Returns:
The deduplicated hooks and imports added by the component and parent components.
"""
code = {}

def extract_var_hooks(hook: Var):
_imports = {}
if hook._var_data is not None:
for sub_hook in hook._var_data.hooks:
code[sub_hook] = {}
if hook._var_data.imports:
_imports = hook._var_data.imports
if str(hook) in code:
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
else:
code[str(hook)] = _imports

# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
for hook in clz.add_hooks(self):
if isinstance(hook, Var):
extract_var_hooks(hook)
else:
code[hook] = {}
return code

def _get_hooks(self) -> str | None:
"""Get the React hooks for this component.
Expand Down Expand Up @@ -1454,11 +1493,7 @@ def _get_all_hooks(self) -> dict[str, None]:
if hooks is not None:
code[hooks] = None

# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
for hook in clz.add_hooks(self):
code[hook] = None
code.update(self._get_added_hooks())

# Add the hook code for the children.
for child in self.children:
Expand Down Expand Up @@ -2092,8 +2127,8 @@ def _get_memoized_event_triggers(
var_deps.extend(cls._get_hook_deps(hook))
memo_var_data = VarData.merge(
*[var._var_data for var in event_args],
VarData( # type: ignore
imports={"react": {ImportVar(tag="useCallback")}},
VarData(
imports={"react": [ImportVar(tag="useCallback")]},
),
)

Expand Down
12 changes: 8 additions & 4 deletions reflex/components/core/banner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,27 @@
value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''",
_var_is_local=False,
_var_is_string=False,
)._replace(merge_var_data=connect_error_var_data)
_var_data=connect_error_var_data,
)

connection_errors_count: Var = Var.create_safe(
value="connectErrors.length",
_var_is_string=False,
_var_is_local=False,
)._replace(merge_var_data=connect_error_var_data)
_var_data=connect_error_var_data,
)

has_connection_errors: Var = Var.create_safe(
value="connectErrors.length > 0",
_var_is_string=False,
)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
_var_data=connect_error_var_data,
).to(bool)

has_too_many_connection_errors: Var = Var.create_safe(
value="connectErrors.length >= 2",
_var_is_string=False,
)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
_var_data=connect_error_var_data,
).to(bool)


class WebsocketTargetURL(Bare):
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from reflex.vars import BaseVar, Var, VarData

_IS_TRUE_IMPORT = {
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
}


Expand Down
6 changes: 2 additions & 4 deletions reflex/components/core/debounce.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,11 @@ def create(cls, *children: Component, **props: Any) -> Component:
"{%s}" % (child.alias or child.tag),
_var_is_local=False,
_var_is_string=False,
)._replace(
_var_type=Type[Component],
merge_var_data=VarData( # type: ignore
_var_data=VarData(
imports=child._get_imports(),
hooks=child._get_hooks_internal(),
),
),
).to(Type[Component]),
)

component = super().create(**props)
Expand Down
19 changes: 9 additions & 10 deletions reflex/components/core/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@

DEFAULT_UPLOAD_ID: str = "default"

upload_files_context_var_data: VarData = VarData( # type: ignore
upload_files_context_var_data: VarData = VarData(
imports={
"react": {imports.ImportVar(tag="useContext")},
f"/{Dirs.CONTEXTS_PATH}": {
"react": [imports.ImportVar(tag="useContext")],
f"/{Dirs.CONTEXTS_PATH}": [
imports.ImportVar(tag="UploadFilesContext"),
},
],
},
hooks={
"const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
Expand Down Expand Up @@ -118,14 +118,13 @@ def get_upload_dir() -> Path:


uploaded_files_url_prefix: Var = Var.create_safe(
"${getBackendURL(env.UPLOAD)}"
)._replace(
merge_var_data=VarData( # type: ignore
"${getBackendURL(env.UPLOAD)}",
_var_data=VarData(
imports={
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")},
"/env.json": {imports.ImportVar(tag="env", is_default=True)},
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
"/env.json": [imports.ImportVar(tag="env", is_default=True)],
}
)
),
)


Expand Down
18 changes: 12 additions & 6 deletions reflex/components/el/elements/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,17 @@ def _get_form_refs(self) -> Dict[str, Any]:
if ref.startswith("refs_"):
ref_var = Var.create_safe(ref[:-3]).as_ref()
form_refs[ref[5:-3]] = Var.create_safe(
f"getRefValues({str(ref_var)})", _var_is_local=False
)._replace(merge_var_data=ref_var._var_data)
f"getRefValues({str(ref_var)})",
_var_is_local=False,
_var_data=ref_var._var_data,
)
else:
ref_var = Var.create_safe(ref).as_ref()
form_refs[ref[4:]] = Var.create_safe(
f"getRefValue({str(ref_var)})", _var_is_local=False
)._replace(merge_var_data=ref_var._var_data)
f"getRefValue({str(ref_var)})",
_var_is_local=False,
_var_data=ref_var._var_data,
)
return form_refs

def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
Expand Down Expand Up @@ -619,14 +623,16 @@ def _render(self) -> Tag:
on_key_down=Var.create_safe(
f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})",
_var_is_local=False,
)._replace(merge_var_data=self.enter_key_submit._var_data),
_var_data=self.enter_key_submit._var_data,
)
)
if self.auto_height is not None:
tag.add_props(
on_input=Var.create_safe(
f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})",
_var_is_local=False,
)._replace(merge_var_data=self.auto_height._var_data),
_var_data=self.auto_height._var_data,
)
)
return tag

Expand Down
6 changes: 4 additions & 2 deletions reflex/components/gridjs/datatable.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ def _render(self) -> Tag:
_var_name=f"{self.data._var_name}.columns",
_var_type=List[Any],
_var_full_name_needs_state_prefix=True,
)._replace(merge_var_data=self.data._var_data)
_var_data=self.data._var_data,
)
self.data = BaseVar(
_var_name=f"{self.data._var_name}.data",
_var_type=List[List[Any]],
_var_full_name_needs_state_prefix=True,
)._replace(merge_var_data=self.data._var_data)
_var_data=self.data._var_data,
)
if types.is_dataframe(type(self.data)):
# If given a pandas df break up the data and columns
data = serialize(self.data)
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/radix/themes/components/tabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class TabsTrigger(RadixThemesComponent):
_valid_parents: List[str] = ["TabsList"]

@classmethod
def create(self, *children, **props) -> Component:
def create(cls, *children, **props) -> Component:
"""Create a TabsTrigger component.
Args:
Expand Down
17 changes: 10 additions & 7 deletions reflex/components/sonner/toast.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def dict(self, *args, **kwargs) -> dict:
class Toaster(Component):
"""A Toaster Component for displaying toast notifications."""

library = "sonner@1.4.41"
library: str = "sonner@1.4.41"

tag = "Toaster"

Expand Down Expand Up @@ -209,12 +209,15 @@ class Toaster(Component):
pause_when_page_is_hidden: Var[bool]

def _get_hooks(self) -> Var[str]:
hook = Var.create_safe(f"{toast_ref} = toast", _var_is_local=True)
hook._var_data = VarData( # type: ignore
imports={
"/utils/state": [ImportVar(tag="refs")],
self.library: [ImportVar(tag="toast", install=False)],
}
hook = Var.create_safe(
f"{toast_ref} = toast",
_var_is_local=True,
_var_data=VarData(
imports={
"/utils/state": [ImportVar(tag="refs")],
self.library: [ImportVar(tag="toast", install=False)],
}
),
)
return hook

Expand Down
6 changes: 3 additions & 3 deletions reflex/constants/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ class Imports(SimpleNamespace):
"""Common sets of import vars."""

EVENTS = {
"react": {ImportVar(tag="useContext")},
f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
"react": [ImportVar(tag="useContext")],
f"/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")],
f"/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)],
}


Expand Down
6 changes: 3 additions & 3 deletions reflex/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
DARK_COLOR_MODE: str = "dark"

# Reference the global ColorModeContext
color_mode_var_data = VarData( # type: ignore
color_mode_var_data = VarData(
imports={
f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
"react": {ImportVar(tag="useContext")},
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
"react": [ImportVar(tag="useContext")],
},
hooks={
f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,
Expand Down
18 changes: 14 additions & 4 deletions reflex/vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,19 @@ class Var:

@classmethod
def create(
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
cls,
value: Any,
_var_is_local: bool = True,
_var_is_string: bool = False,
_var_data: Optional[VarData] = None,
) -> Var | None:
"""Create a var from a value.
Args:
value: The value to create the var from.
_var_is_local: Whether the var is local.
_var_is_string: Whether the var is a string literal.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
Expand All @@ -365,9 +370,8 @@ def create(
return value

# Try to pull the imports and hooks from contained values.
_var_data = None
if not isinstance(value, str):
_var_data = VarData.merge(*_extract_var_data(value))
_var_data = VarData.merge(*_extract_var_data(value), _var_data)

# Try to serialize the value.
type_ = type(value)
Expand All @@ -388,14 +392,19 @@ def create(

@classmethod
def create_safe(
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
cls,
value: Any,
_var_is_local: bool = True,
_var_is_string: bool = False,
_var_data: Optional[VarData] = None,
) -> Var:
"""Create a var from a value, asserting that it is not None.
Args:
value: The value to create the var from.
_var_is_local: Whether the var is local.
_var_is_string: Whether the var is a string literal.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
Expand All @@ -404,6 +413,7 @@ def create_safe(
value,
_var_is_local=_var_is_local,
_var_is_string=_var_is_string,
_var_data=_var_data,
)
assert var is not None
return var
Expand Down
Loading

0 comments on commit c5f32db

Please sign in to comment.