Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF-2787] add_hooks supports Var-wrapped hooks #3248

Merged
merged 3 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 46 additions & 11 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]:
"""
return {}

def add_hooks(self) -> list[str]:
def add_hooks(self) -> list[str | Var]:
"""Add hooks inside the component function.

Hooks are pieces of literal Javascript code that is inserted inside the
Expand Down Expand Up @@ -1265,11 +1265,20 @@ def _get_hooks_imports(self) -> imports.ImportDict:
},
)

other_imports = []
user_hooks = self._get_hooks()
if user_hooks is not None and isinstance(user_hooks, Var):
_imports = imports.merge_imports(_imports, user_hooks._var_data.imports) # type: ignore
if (
user_hooks is not None
and isinstance(user_hooks, Var)
and user_hooks._var_data is not None
and user_hooks._var_data.imports
):
other_imports.append(user_hooks._var_data.imports)
other_imports.extend(
hook_imports for hook_imports in self._get_added_hooks().values()
)

return _imports
return imports.merge_imports(_imports, *other_imports)

def _get_imports(self) -> imports.ImportDict:
"""Get all the libraries and fields that are used by the component.
Expand Down Expand Up @@ -1416,6 +1425,36 @@ def _get_hooks_internal(self) -> dict[str, None]:
**self._get_special_hooks(),
}

def _get_added_hooks(self) -> dict[str, imports.ImportDict]:
"""Get the hooks added via `add_hooks` method.

Returns:
The deduplicated hooks and imports added by the component and parent components.
"""
code = {}

def extract_var_hooks(hook: Var):
_imports = {}
if hook._var_data is not None:
for sub_hook in hook._var_data.hooks:
code[sub_hook] = {}
if hook._var_data.imports:
_imports = hook._var_data.imports
if str(hook) in code:
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
else:
code[str(hook)] = _imports

# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
for hook in clz.add_hooks(self):
if isinstance(hook, Var):
extract_var_hooks(hook)
else:
code[hook] = {}
return code

def _get_hooks(self) -> str | None:
"""Get the React hooks for this component.

Expand Down Expand Up @@ -1454,11 +1493,7 @@ def _get_all_hooks(self) -> dict[str, None]:
if hooks is not None:
code[hooks] = None

# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
for hook in clz.add_hooks(self):
code[hook] = None
code.update(self._get_added_hooks())

# Add the hook code for the children.
for child in self.children:
Expand Down Expand Up @@ -2092,8 +2127,8 @@ def _get_memoized_event_triggers(
var_deps.extend(cls._get_hook_deps(hook))
memo_var_data = VarData.merge(
*[var._var_data for var in event_args],
VarData( # type: ignore
imports={"react": {ImportVar(tag="useCallback")}},
VarData(
imports={"react": [ImportVar(tag="useCallback")]},
),
)

Expand Down
12 changes: 8 additions & 4 deletions reflex/components/core/banner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,27 @@
value="(connectErrors.length > 0) ? connectErrors[connectErrors.length - 1].message : ''",
_var_is_local=False,
_var_is_string=False,
)._replace(merge_var_data=connect_error_var_data)
_var_data=connect_error_var_data,
)

connection_errors_count: Var = Var.create_safe(
value="connectErrors.length",
_var_is_string=False,
_var_is_local=False,
)._replace(merge_var_data=connect_error_var_data)
_var_data=connect_error_var_data,
)

has_connection_errors: Var = Var.create_safe(
value="connectErrors.length > 0",
_var_is_string=False,
)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
_var_data=connect_error_var_data,
).to(bool)

has_too_many_connection_errors: Var = Var.create_safe(
value="connectErrors.length >= 2",
_var_is_string=False,
)._replace(_var_type=bool, merge_var_data=connect_error_var_data)
_var_data=connect_error_var_data,
).to(bool)


class WebsocketTargetURL(Bare):
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from reflex.vars import BaseVar, Var, VarData

_IS_TRUE_IMPORT = {
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="isTrue")},
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")],
}


Expand Down
6 changes: 2 additions & 4 deletions reflex/components/core/debounce.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,11 @@ def create(cls, *children: Component, **props: Any) -> Component:
"{%s}" % (child.alias or child.tag),
_var_is_local=False,
_var_is_string=False,
)._replace(
_var_type=Type[Component],
merge_var_data=VarData( # type: ignore
_var_data=VarData(
imports=child._get_imports(),
hooks=child._get_hooks_internal(),
),
),
).to(Type[Component]),
)

component = super().create(**props)
Expand Down
19 changes: 9 additions & 10 deletions reflex/components/core/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@

DEFAULT_UPLOAD_ID: str = "default"

upload_files_context_var_data: VarData = VarData( # type: ignore
upload_files_context_var_data: VarData = VarData(
imports={
"react": {imports.ImportVar(tag="useContext")},
f"/{Dirs.CONTEXTS_PATH}": {
"react": [imports.ImportVar(tag="useContext")],
f"/{Dirs.CONTEXTS_PATH}": [
imports.ImportVar(tag="UploadFilesContext"),
},
],
},
hooks={
"const [filesById, setFilesById] = useContext(UploadFilesContext);": None,
Expand Down Expand Up @@ -118,14 +118,13 @@ def get_upload_dir() -> Path:


uploaded_files_url_prefix: Var = Var.create_safe(
"${getBackendURL(env.UPLOAD)}"
)._replace(
merge_var_data=VarData( # type: ignore
"${getBackendURL(env.UPLOAD)}",
_var_data=VarData(
imports={
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")},
"/env.json": {imports.ImportVar(tag="env", is_default=True)},
f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="getBackendURL")],
"/env.json": [imports.ImportVar(tag="env", is_default=True)],
}
)
),
)


Expand Down
18 changes: 12 additions & 6 deletions reflex/components/el/elements/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,17 @@ def _get_form_refs(self) -> Dict[str, Any]:
if ref.startswith("refs_"):
ref_var = Var.create_safe(ref[:-3]).as_ref()
form_refs[ref[5:-3]] = Var.create_safe(
f"getRefValues({str(ref_var)})", _var_is_local=False
)._replace(merge_var_data=ref_var._var_data)
f"getRefValues({str(ref_var)})",
_var_is_local=False,
_var_data=ref_var._var_data,
)
else:
ref_var = Var.create_safe(ref).as_ref()
form_refs[ref[4:]] = Var.create_safe(
f"getRefValue({str(ref_var)})", _var_is_local=False
)._replace(merge_var_data=ref_var._var_data)
f"getRefValue({str(ref_var)})",
_var_is_local=False,
_var_data=ref_var._var_data,
)
return form_refs

def _get_vars(self, include_children: bool = True) -> Iterator[Var]:
Expand Down Expand Up @@ -619,14 +623,16 @@ def _render(self) -> Tag:
on_key_down=Var.create_safe(
f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})",
_var_is_local=False,
)._replace(merge_var_data=self.enter_key_submit._var_data),
_var_data=self.enter_key_submit._var_data,
)
)
if self.auto_height is not None:
tag.add_props(
on_input=Var.create_safe(
f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})",
_var_is_local=False,
)._replace(merge_var_data=self.auto_height._var_data),
_var_data=self.auto_height._var_data,
)
)
return tag

Expand Down
6 changes: 4 additions & 2 deletions reflex/components/gridjs/datatable.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ def _render(self) -> Tag:
_var_name=f"{self.data._var_name}.columns",
_var_type=List[Any],
_var_full_name_needs_state_prefix=True,
)._replace(merge_var_data=self.data._var_data)
_var_data=self.data._var_data,
)
self.data = BaseVar(
_var_name=f"{self.data._var_name}.data",
_var_type=List[List[Any]],
_var_full_name_needs_state_prefix=True,
)._replace(merge_var_data=self.data._var_data)
_var_data=self.data._var_data,
)
if types.is_dataframe(type(self.data)):
# If given a pandas df break up the data and columns
data = serialize(self.data)
Expand Down
2 changes: 1 addition & 1 deletion reflex/components/radix/themes/components/tabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class TabsTrigger(RadixThemesComponent):
_valid_parents: List[str] = ["TabsList"]

@classmethod
def create(self, *children, **props) -> Component:
def create(cls, *children, **props) -> Component:
"""Create a TabsTrigger component.

Args:
Expand Down
17 changes: 10 additions & 7 deletions reflex/components/sonner/toast.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def dict(self, *args, **kwargs) -> dict:
class Toaster(Component):
"""A Toaster Component for displaying toast notifications."""

library = "sonner@1.4.41"
library: str = "sonner@1.4.41"

tag = "Toaster"

Expand Down Expand Up @@ -209,12 +209,15 @@ class Toaster(Component):
pause_when_page_is_hidden: Var[bool]

def _get_hooks(self) -> Var[str]:
hook = Var.create_safe(f"{toast_ref} = toast", _var_is_local=True)
hook._var_data = VarData( # type: ignore
imports={
"/utils/state": [ImportVar(tag="refs")],
self.library: [ImportVar(tag="toast", install=False)],
}
hook = Var.create_safe(
f"{toast_ref} = toast",
_var_is_local=True,
_var_data=VarData(
imports={
"/utils/state": [ImportVar(tag="refs")],
self.library: [ImportVar(tag="toast", install=False)],
}
),
)
return hook

Expand Down
6 changes: 3 additions & 3 deletions reflex/constants/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ class Imports(SimpleNamespace):
"""Common sets of import vars."""

EVENTS = {
"react": {ImportVar(tag="useContext")},
f"/{Dirs.CONTEXTS_PATH}": {ImportVar(tag="EventLoopContext")},
f"/{Dirs.STATE_PATH}": {ImportVar(tag=CompileVars.TO_EVENT)},
"react": [ImportVar(tag="useContext")],
f"/{Dirs.CONTEXTS_PATH}": [ImportVar(tag="EventLoopContext")],
f"/{Dirs.STATE_PATH}": [ImportVar(tag=CompileVars.TO_EVENT)],
}


Expand Down
6 changes: 3 additions & 3 deletions reflex/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
DARK_COLOR_MODE: str = "dark"

# Reference the global ColorModeContext
color_mode_var_data = VarData( # type: ignore
color_mode_var_data = VarData(
imports={
f"/{constants.Dirs.CONTEXTS_PATH}": {ImportVar(tag="ColorModeContext")},
"react": {ImportVar(tag="useContext")},
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="ColorModeContext")],
"react": [ImportVar(tag="useContext")],
},
hooks={
f"const [ {constants.ColorMode.NAME}, {constants.ColorMode.TOGGLE} ] = useContext(ColorModeContext)": None,
Expand Down
18 changes: 14 additions & 4 deletions reflex/vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,19 @@ class Var:

@classmethod
def create(
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
cls,
value: Any,
_var_is_local: bool = True,
_var_is_string: bool = False,
_var_data: Optional[VarData] = None,
) -> Var | None:
"""Create a var from a value.

Args:
value: The value to create the var from.
_var_is_local: Whether the var is local.
_var_is_string: Whether the var is a string literal.
_var_data: Additional hooks and imports associated with the Var.

Returns:
The var.
Expand All @@ -365,9 +370,8 @@ def create(
return value

# Try to pull the imports and hooks from contained values.
_var_data = None
if not isinstance(value, str):
_var_data = VarData.merge(*_extract_var_data(value))
_var_data = VarData.merge(*_extract_var_data(value), _var_data)

# Try to serialize the value.
type_ = type(value)
Expand All @@ -388,14 +392,19 @@ def create(

@classmethod
def create_safe(
cls, value: Any, _var_is_local: bool = True, _var_is_string: bool = False
cls,
value: Any,
_var_is_local: bool = True,
_var_is_string: bool = False,
_var_data: Optional[VarData] = None,
) -> Var:
"""Create a var from a value, asserting that it is not None.

Args:
value: The value to create the var from.
_var_is_local: Whether the var is local.
_var_is_string: Whether the var is a string literal.
_var_data: Additional hooks and imports associated with the Var.

Returns:
The var.
Expand All @@ -404,6 +413,7 @@ def create_safe(
value,
_var_is_local=_var_is_local,
_var_is_string=_var_is_string,
_var_data=_var_data,
)
assert var is not None
return var
Expand Down
Loading
Loading