Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use process pool to compile faster #2377

Merged
merged 4 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 103 additions & 81 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import contextlib
import copy
import functools
import multiprocessing
import os
import platform
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -35,6 +37,7 @@
from reflex.base import Base
from reflex.compiler import compiler
from reflex.compiler import utils as compiler_utils
from reflex.compiler.compiler import ExecutorSafeFunctions
from reflex.components import connection_modal
from reflex.components.base.app_wrap import AppWrap
from reflex.components.base.fragment import Fragment
Expand Down Expand Up @@ -661,15 +664,24 @@ def compile_(self):
TimeElapsedColumn(),
)

# try to be somewhat accurate - but still not 100%
adhoc_steps_without_executor = 6
fixed_pages_within_executor = 7
progress.start()
task = progress.add_task(
"Compiling:",
total=len(self.pages)
+ fixed_pages_within_executor
+ adhoc_steps_without_executor,
)

# Get the env mode.
config = get_config()

# Store the compile results.
compile_results = []

# Compile the pages in parallel.
custom_components = set()
# TODO Anecdotally, processes=2 works 10% faster (cpu_count=12)
all_imports = {}
app_wrappers: Dict[tuple[int, str], Component] = {
# Default app wrap component renders {children}
Expand All @@ -679,127 +691,137 @@ def compile_(self):
# If a theme component was provided, wrap the app with it
app_wrappers[(20, "Theme")] = self.theme

with progress, concurrent.futures.ThreadPoolExecutor() as thread_pool:
fixed_pages = 7
task = progress.add_task("Compiling:", total=len(self.pages) + fixed_pages)
progress.advance(task)

def mark_complete(_=None):
progress.advance(task)
for _route, component in self.pages.items():
# Merge the component style with the app style.
component.add_style(self.style)

for _route, component in self.pages.items():
# Merge the component style with the app style.
component.add_style(self.style)
component.apply_theme(self.theme)

component.apply_theme(self.theme)
# Add component.get_imports() to all_imports.
all_imports.update(component.get_imports())

# Add component.get_imports() to all_imports.
all_imports.update(component.get_imports())
# Add the app wrappers from this component.
app_wrappers.update(component.get_app_wrap_components())

# Add the app wrappers from this component.
app_wrappers.update(component.get_app_wrap_components())
# Add the custom components from the page to the set.
custom_components |= component.get_custom_components()

# Add the custom components from the page to the set.
custom_components |= component.get_custom_components()
progress.advance(task)

# Perform auto-memoization of stateful components.
(
stateful_components_path,
stateful_components_code,
page_components,
) = compiler.compile_stateful_components(self.pages.values())
# Perform auto-memoization of stateful components.
(
stateful_components_path,
stateful_components_code,
page_components,
) = compiler.compile_stateful_components(self.pages.values())

# Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State.
if (
code_uses_state_contexts(stateful_components_code)
and self.state is None
):
raise RuntimeError(
"To access rx.State in frontend components, at least one "
"subclass of rx.State must be defined in the app."
)
compile_results.append((stateful_components_path, stateful_components_code))
progress.advance(task)

result_futures = []
# Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State.
if code_uses_state_contexts(stateful_components_code) and self.state is None:
raise RuntimeError(
"To access rx.State in frontend components, at least one "
"subclass of rx.State must be defined in the app."
)
compile_results.append((stateful_components_path, stateful_components_code))

app_root = self._app_root(app_wrappers=app_wrappers)

progress.advance(task)

# Prepopulate the global ExecutorSafeFunctions class with input data required by the compile functions.
# This is required for multiprocessing to work, in presence of non-picklable inputs.
for route, component in zip(self.pages, page_components):
ExecutorSafeFunctions.COMPILE_PAGE_ARGS_BY_ROUTE[route] = (
route,
component,
self.state,
)

def submit_work(fn, *args, **kwargs):
"""Submit work to the thread pool and add a callback to mark the task as complete.
ExecutorSafeFunctions.COMPILE_APP_APP_ROOT = app_root
ExecutorSafeFunctions.CUSTOM_COMPONENTS = custom_components
ExecutorSafeFunctions.HEAD_COMPONENTS = self.head_components
ExecutorSafeFunctions.STYLE = self.style
ExecutorSafeFunctions.STATE = self.state

# Use a forking process pool, if possible. Much faster, especially for large sites.
# Fallback to ThreadPoolExecutor as something that will always work.
executor = None
if platform.system() in ("Linux", "Darwin"):
executor = concurrent.futures.ProcessPoolExecutor(
mp_context=multiprocessing.get_context("fork")
)
else:
executor = concurrent.futures.ThreadPoolExecutor()

The Future will be added to the `result_futures` list.
with executor:
result_futures = []

Args:
fn: The function to submit.
*args: The args to submit.
**kwargs: The kwargs to submit.
"""
f = thread_pool.submit(fn, *args, **kwargs)
f.add_done_callback(mark_complete)
def _mark_complete(_=None):
progress.advance(task)

def _submit_work(fn, *args, **kwargs):
f = executor.submit(fn, *args, **kwargs)
f.add_done_callback(_mark_complete)
result_futures.append(f)

# Compile all page components.
for route, component in zip(self.pages, page_components):
submit_work(
compiler.compile_page,
route,
component,
self.state,
)
for route in self.pages:
_submit_work(ExecutorSafeFunctions.compile_page, route)

# Compile the app wrapper.
app_root = self._app_root(app_wrappers=app_wrappers)
submit_work(compiler.compile_app, app_root)
_submit_work(ExecutorSafeFunctions.compile_app)

# Compile the custom components.
submit_work(compiler.compile_components, custom_components)
_submit_work(ExecutorSafeFunctions.compile_custom_components)

# Compile the root stylesheet with base styles.
submit_work(compiler.compile_root_stylesheet, self.stylesheets)
_submit_work(compiler.compile_root_stylesheet, self.stylesheets)

# Compile the root document.
submit_work(compiler.compile_document_root, self.head_components)
_submit_work(ExecutorSafeFunctions.compile_document_root)

# Compile the theme.
submit_work(compiler.compile_theme, style=self.style)
_submit_work(ExecutorSafeFunctions.compile_theme)

# Compile the contexts.
submit_work(compiler.compile_contexts, self.state)
_submit_work(ExecutorSafeFunctions.compile_contexts)

# Compile the Tailwind config.
if config.tailwind is not None:
config.tailwind["content"] = config.tailwind.get(
"content", constants.Tailwind.CONTENT
)
submit_work(compiler.compile_tailwind, config.tailwind)
_submit_work(compiler.compile_tailwind, config.tailwind)
else:
submit_work(compiler.remove_tailwind_from_postcss)

# Get imports from AppWrap components.
all_imports.update(app_root.get_imports())

# Iterate through all the custom components and add their imports to the all_imports.
for component in custom_components:
all_imports.update(component.get_imports())
_submit_work(compiler.remove_tailwind_from_postcss)

# Wait for all compilation tasks to complete.
for future in concurrent.futures.as_completed(result_futures):
compile_results.append(future.result())

# Empty the .web pages directory.
compiler.purge_web_pages_dir()
# Get imports from AppWrap components.
all_imports.update(app_root.get_imports())

# Avoid flickering when installing frontend packages
progress.stop()
# Iterate through all the custom components and add their imports to the all_imports.
for component in custom_components:
all_imports.update(component.get_imports())

# Install frontend packages.
self.get_frontend_packages(all_imports)
progress.advance(task)

# Write the pages at the end to trigger the NextJS hot reload only once.
write_page_futures = []
for output_path, code in compile_results:
write_page_futures.append(
thread_pool.submit(compiler_utils.write_page, output_path, code)
)
for future in concurrent.futures.as_completed(write_page_futures):
future.result()
# Empty the .web pages directory.
compiler.purge_web_pages_dir()

progress.advance(task)
progress.stop()

# Install frontend packages.
self.get_frontend_packages(all_imports)

for output_path, code in compile_results:
compiler_utils.write_page(output_path, code)

@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
Expand Down
110 changes: 110 additions & 0 deletions reflex/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,113 @@ def remove_tailwind_from_postcss() -> tuple[str, str]:
def purge_web_pages_dir():
    """Empty the web pages directory, preserving ``_app.js``.

    The directory cleared is ``constants.Dirs.WEB_PAGES``.
    """
    # _app.js is the only file that is kept when the directory is emptied.
    preserved_files = ["_app.js"]
    utils.empty_dir(constants.Dirs.WEB_PAGES, keep_files=preserved_files)


class ExecutorSafeFunctions:
    """Compile entry points that are safe to hand to a pool executor.

    The class and its class attributes live at global scope. When work is
    run in a FORKED child process (e.g. through a ProcessPoolExecutor), the
    child receives a logical copy of this class, including whatever was
    stored on it before the fork.

    Usage pattern:
        * Before the child process is forked, stash every input that a later
          function call in the child will need on the class attributes below.
        * After the fork, the child sees its own copy of those attributes.
        * A submitted task then only has to say which stashed input it
          wants (for example a route name), not carry the input itself.

    Why this exists: passing input data directly to a child process is often
    not possible because the input data is not picklable. Stashing the data
    before the fork removes the need to pickle the inputs at all.

    Limitations:
        * Unpicklable OUTPUT data can never be returned this way.
        * Mutations made by the child process do not propagate back to the
          parent process (fork goes one way!).
    """

    # Maps a route to the positional arguments forwarded to compile_page.
    COMPILE_PAGE_ARGS_BY_ROUTE = {}
    # Input for compile_app.
    COMPILE_APP_APP_ROOT: Component | None = None
    # Input for compile_components.
    CUSTOM_COMPONENTS: set[CustomComponent] | None = None
    # Input for compile_document_root.
    HEAD_COMPONENTS: list[Component] | None = None
    # Input for compile_theme.
    STYLE: ComponentStyle | None = None
    # Input for compile_contexts.
    STATE: type[BaseState] | None = None

    @classmethod
    def compile_page(cls, route: str):
        """Compile the page registered under a route.

        Args:
            route: The route whose stashed arguments should be compiled.

        Returns:
            The path and code of the compiled page.
        """
        # A route that was never stashed fails the dict lookup with KeyError.
        page_args = cls.COMPILE_PAGE_ARGS_BY_ROUTE[route]
        return compile_page(*page_args)

    @classmethod
    def compile_app(cls):
        """Compile the app from the stashed app root.

        Returns:
            The path and code of the compiled app.

        Raises:
            ValueError: If the app root is not set.
        """
        app_root = cls.COMPILE_APP_APP_ROOT
        if app_root is not None:
            return compile_app(app_root)
        raise ValueError("COMPILE_APP_APP_ROOT should be set")

    @classmethod
    def compile_custom_components(cls):
        """Compile the stashed custom components.

        Returns:
            The path and code of the compiled custom components.

        Raises:
            ValueError: If the custom components are not set.
        """
        components = cls.CUSTOM_COMPONENTS
        if components is not None:
            return compile_components(components)
        raise ValueError("CUSTOM_COMPONENTS should be set")

    @classmethod
    def compile_document_root(cls):
        """Compile the document root from the stashed head components.

        Returns:
            The path and code of the compiled document root.

        Raises:
            ValueError: If the head components are not set.
        """
        head_components = cls.HEAD_COMPONENTS
        if head_components is not None:
            return compile_document_root(head_components)
        raise ValueError("HEAD_COMPONENTS should be set")

    @classmethod
    def compile_theme(cls):
        """Compile the theme from the stashed style.

        Returns:
            The path and code of the compiled theme.

        Raises:
            ValueError: If the style is not set.
        """
        style = cls.STYLE
        if style is not None:
            return compile_theme(style)
        raise ValueError("STYLE should be set")

    @classmethod
    def compile_contexts(cls):
        """Compile the contexts from the stashed state.

        Returns:
            The path and code of the compiled contexts.
        """
        # STATE is forwarded as-is; there is no None check on this path.
        state = cls.STATE
        return compile_contexts(state)
Loading