From c457ab7bee35457a2ecc1536beed49ad6f033bb7 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 08:24:23 -0400 Subject: [PATCH 01/74] initial push --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a6a2224907..4120ea8f93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ All notable changes to `dash` will be documented in this file. This project adheres to [Semantic Versioning](https://semver.org/). +## [bringyourownserver] +- [#3430] Adds support to bring your own server, eg (Quart, FastAPI, etc). + ## [UNRELEASED] ## Added From 4ebc657a49e6440d9977bc08c1afcc8916945230 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:00:32 -0400 Subject: [PATCH 02/74] work to modularize the dash eco-system and decouple from Flask --- dash/_callback.py | 19 +- dash/dash.py | 257 ++++++----------------- dash/server_factories/__init__.py | 10 + dash/server_factories/base_factory.py | 47 +++++ dash/server_factories/fastapi_factory.py | 226 ++++++++++++++++++++ dash/server_factories/flask_factory.py | 188 +++++++++++++++++ 6 files changed, 550 insertions(+), 197 deletions(-) create mode 100644 dash/server_factories/__init__.py create mode 100644 dash/server_factories/base_factory.py create mode 100644 dash/server_factories/fastapi_factory.py create mode 100644 dash/server_factories/flask_factory.py diff --git a/dash/_callback.py b/dash/_callback.py index aacb8dbdde..bca8027fdd 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -6,7 +6,7 @@ import asyncio -import flask +from dash.server_factories import get_request_adapter from .dependencies import ( handle_callback_args, @@ -376,7 +376,7 @@ def _get_callback_manager( " and store results on redis.\n" ) - old_job = flask.request.args.getlist("oldJob") + old_job = get_request_adapter().get_args().getlist("oldJob") if old_job: for job in old_job: @@ -436,7 +436,7 @@ def _setup_background_callback( def _progress_background_callback(response, callback_manager, background): progress_outputs = background.get("progress") - cache_key = flask.request.args.get("cacheKey") + cache_key = get_request_adapter().get_args().get("cacheKey") if progress_outputs: # Get the progress before the result as it would be erased after the results. @@ -453,8 +453,8 @@ def _update_background_callback( """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - cache_key = flask.request.args.get("cacheKey") - job_id = flask.request.args.get("job") + cache_key = get_request_adapter().get_args().get("cacheKey") + job_id = get_request_adapter().get_args().get("job") _progress_background_callback(response, callback_manager, background) @@ -474,8 +474,8 @@ def _handle_rest_background_callback( multi, has_update=False, ): - cache_key = flask.request.args.get("cacheKey") - job_id = flask.request.args.get("job") + cache_key = get_request_adapter().get_args().get("cacheKey") + job_id = get_request_adapter().get_args().get("job") # Must get job_running after get_result since get_results terminates it. job_running = callback_manager.job_running(job_id) if not job_running and output_value is callback_manager.UNDEFINED: @@ -688,11 +688,10 @@ def add_context(*args, **kwargs): ) response: dict = {"multi": True} - jsonResponse = None try: if background is not None: - if not flask.request.args.get("cacheKey"): + if not get_request_adapter().get_args().get("cacheKey"): return _setup_background_callback( kwargs, background, @@ -763,7 +762,7 @@ async def async_add_context(*args, **kwargs): try: if background is not None: - if not flask.request.args.get("cacheKey"): + if not get_request_adapter().get_args().get("cacheKey"): return _setup_background_callback( kwargs, background, diff --git a/dash/dash.py b/dash/dash.py index 8430259c27..56bf65c9e6 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -17,6 +17,7 @@ import hashlib import base64 import traceback +import inspect from urllib.parse import urlparse from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List @@ -67,6 +68,8 @@ from . import _validate from . import _watch from . import _get_app +from .server_factories.flask_factory import FlaskServerFactory +from .server_factories.base_factory import BaseServerFactory from ._get_app import with_app_context, with_app_context_async, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group @@ -421,7 +424,7 @@ class Dash(ObsoleteChecker): _plotlyjs_url: str STARTUP_ROUTES: list = [] - server: flask.Flask + server: Any # Layout is a complex type which can be many things _layout: Any @@ -430,7 +433,7 @@ class Dash(ObsoleteChecker): def __init__( # pylint: disable=too-many-statements self, name: Optional[str] = None, - server: Union[bool, flask.Flask] = True, + server: Union[bool, Callable[[], Any]] = True, assets_folder: str = "assets", pages_folder: str = "pages", use_pages: Optional[bool] = None, @@ -466,6 +469,7 @@ def __init__( # pylint: disable=too-many-statements description: Optional[str] = None, on_error: Optional[Callable[[Exception], Any]] = None, use_async: Optional[bool] = None, + server_factory: Optional[BaseServerFactory] = None, **obsolete, ): @@ -488,16 +492,23 @@ def __init__( # pylint: disable=too-many-statements caller_name: str = name if name is not None else get_caller_name() + self.server_factory = server_factory or FlaskServerFactory() + # We have 3 cases: server is either True (we create the server), False # (defer server creation) or a Flask app instance (we use their server) - if isinstance(server, flask.Flask): + if callable(server) and not (hasattr(server, 'route') and hasattr(server, 'run')): + # Server factory function + self.server = server() + if name is None: + caller_name = getattr(self.server, "name", caller_name) + elif hasattr(server, 'route') and hasattr(server, 'run'): self.server = server if name is None: caller_name = getattr(server, "name", caller_name) elif isinstance(server, bool): - self.server = flask.Flask(caller_name) if server else None # type: ignore + self.server = self.server_factory.create_app(caller_name) if server else None else: - raise ValueError("server must be a Flask app or a boolean") + raise ValueError("server must be a Flask app, a boolean, or a server factory function") base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -671,11 +682,8 @@ def _setup_hooks(self): if self._hooks.get_hooks("error"): self._on_error = self._hooks.HookErrorHandler(self._on_error) - def init_app(self, app: Optional[flask.Flask] = None, **kwargs) -> None: - """Initialize the parts of Dash that require a flask app.""" - + def init_app(self, app: Optional[Any] = None, **kwargs) -> None: config = self.config - config.update(kwargs) config.set_read_only( [ @@ -685,89 +693,58 @@ def init_app(self, app: Optional[flask.Flask] = None, **kwargs) -> None: ], "Read-only: can only be set in the Dash constructor or during init_app()", ) - if app is not None: self.server = app - bp_prefix = config.routes_pathname_prefix.replace("/", "_").replace(".", "_") assets_blueprint_name = f"{bp_prefix}dash_assets" - - self.server.register_blueprint( - flask.Blueprint( - assets_blueprint_name, - config.name, - static_folder=self.config.assets_folder, - static_url_path=config.routes_pathname_prefix - + self.config.assets_url_path.lstrip("/"), - ) + self.server_factory.register_assets_blueprint( + self.server, + assets_blueprint_name, + config.routes_pathname_prefix + self.config.assets_url_path.lstrip("/"), + self.config.assets_folder, ) - if config.compress: try: - # pylint: disable=import-outside-toplevel - from flask_compress import Compress # type: ignore[reportMissingImports] - - # gzip + from flask_compress import Compress Compress(self.server) - _flask_compress_version = parse_version( _get_distribution_version("flask_compress") ) - if not hasattr( self.server.config, "COMPRESS_ALGORITHM" ) and _flask_compress_version >= parse_version("1.6.0"): - # flask-compress==1.6.0 changed default to ['br', 'gzip'] - # and non-overridable default compression with Brotli is - # causing performance issues self.server.config["COMPRESS_ALGORITHM"] = ["gzip"] except ImportError as error: raise ImportError( "To use the compress option, you need to install dash[compress]" ) from error - - @self.server.errorhandler(PreventUpdate) - def _handle_error(_): - """Handle a halted callback and return an empty 204 response.""" - return "", 204 - - self.server.before_request(self._setup_server) - - # add a handler for components suites errors to return 404 - self.server.errorhandler(InvalidResourceError)(self._invalid_resources_handler) - + self.server_factory.register_error_handlers(self.server) + self.server_factory.before_request(self.server, self._setup_server) self._setup_routes() - _get_app.APP = self self.enable_pages() - self._setup_plotlyjs() def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> None: full_name = self.config.routes_pathname_prefix + name - - self.server.add_url_rule( - full_name, view_func=view_func, endpoint=full_name, methods=list(methods) + self.server_factory.add_url_rule( + self.server, + full_name, + view_func=view_func, + endpoint=full_name, + methods=list(methods), ) - - # record the url in Dash.routes so that it can be accessed later - # e.g. for adding authentication with flask_login self.routes.append(full_name) def _setup_routes(self): - self._add_url( - "_dash-component-suites//", - self.serve_component_suites, - ) + self.server_factory.setup_component_suites(self.server, self) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) - if self._use_async: - self._add_url("_dash-update-component", self.async_dispatch, ["POST"]) - else: - self._add_url("_dash-update-component", self.dispatch, ["POST"]) + self._add_url("_dash-update-component", self.server_factory.dispatch(self.server, self, self._use_async), ["POST"]) self._add_url("_reload-hash", self.serve_reload_hash) - self._add_url("_favicon.ico", self._serve_default_favicon) - self._add_url("", self.index) + self._add_url("_favicon.ico", self.server_factory._serve_default_favicon) + self.server_factory.setup_index(self.server, self) + self.server_factory.setup_catchall(self.server, self) if jupyter_dash.active: self._add_url( @@ -781,8 +758,6 @@ def _setup_routes(self): hook.data["methods"], ) - # catch-all for front-end routes, used by dcc.Location - self._add_url("", self.index) def setup_apis(self): """ @@ -902,7 +877,7 @@ def serve_layout(self): layout = hook(layout) # TODO - Set browser cache limit - pass hash into frontend - return flask.Response( + return self.server_factory.make_response( to_json(layout), mimetype="application/json", ) @@ -966,7 +941,7 @@ def serve_reload_hash(self): _reload.hard = False _reload.changed_assets = [] - return flask.jsonify( + return self.server_factory.jsonify( { "reloadHash": _hash, "hard": hard, @@ -1159,54 +1134,12 @@ def _generate_meta(self): return meta_tags + self.config.meta_tags - # Serve the JS bundles for each package - def serve_component_suites(self, package_name, fingerprinted_path): - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - - _validate.validate_js_path(self.registered_paths, package_name, path_in_pkg) - - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - - package = sys.modules[package_name] - self.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - package.__version__, - path_in_pkg, - package.__path__, - ) - - response = flask.Response( - pkgutil.get_data(package_name, path_in_pkg), mimetype=mimetype - ) - - if has_fingerprint: - # Fingerprinted resources are good forever (1 year) - # No need for ETag as the fingerprint changes with each build - response.cache_control.max_age = 31536000 # 1 year - else: - # Non-fingerprinted resources are given an ETag that - # will be used / check on future requests - response.add_etag() - tag = response.get_etag()[0] - - request_etag = flask.request.headers.get("If-None-Match") - - if f'"{tag}"' == request_etag: - response = flask.Response(None, status=304) - - return response - - @with_app_context - def index(self, *args, **kwargs): # pylint: disable=unused-argument + def render_index(self, *args, **kwargs): scripts = self._generate_scripts_html() css = self._generate_css_dist_html() config = self._generate_config_html() metas = self._generate_meta() renderer = self._generate_renderer() - - # use self.title instead of app.config.title for backwards compatibility title = self.title if self.use_pages and self.config.include_pages_meta: @@ -1314,7 +1247,7 @@ def interpolate_index(self, **kwargs): @with_app_context def dependencies(self): - return flask.Response( + return self.server_factory.make_response( to_json(self._callback_list), content_type="application/json", ) @@ -1417,8 +1350,11 @@ def callback(self, *_args, **_kwargs) -> Callable[..., Any]: **_kwargs, ) + def _inputs_to_vals(self, inputs): + return inputs_to_vals(inputs) + # pylint: disable=R0915 - def _initialize_context(self, body): + def _initialize_context(self, body, adapter): """Initialize the global context for the request.""" g = AttributeDict({}) g.inputs_list = body.get("inputs", []) @@ -1430,12 +1366,12 @@ def _initialize_context(self, body): {"prop_id": x, "value": g.input_values.get(x)} for x in body.get("changedPropIds", []) ] - g.dash_response = flask.Response(mimetype="application/json") - g.cookies = dict(**flask.request.cookies) - g.headers = dict(**flask.request.headers) - g.path = flask.request.full_path - g.remote = flask.request.remote_addr - g.origin = flask.request.origin + g.dash_response = self.server_factory.make_response(mimetype="application/json", data=None) + g.cookies = dict(adapter.get_cookies()) + g.headers = dict(adapter.get_headers()) + g.path = adapter.get_full_path() + g.remote = adapter.get_remote_addr() + g.origin = adapter.get_origin() g.updated_props = {} return g @@ -1499,11 +1435,6 @@ def _prepare_grouping(self, data_list, indices): def _execute_callback(self, func, args, outputs_list, g): """Execute the callback with the prepared arguments.""" - g.cookies = dict(**flask.request.cookies) - g.headers = dict(**flask.request.headers) - g.path = flask.request.full_path - g.remote = flask.request.remote_addr - g.origin = flask.request.origin g.custom_data = AttributeDict({}) for hook in self._hooks.get_hooks("custom_data"): @@ -1522,47 +1453,6 @@ def _execute_callback(self, func, args, outputs_list, g): ) return partial_func - @with_app_context_async - async def async_dispatch(self): - body = flask.request.get_json() - g = self._initialize_context(body) - func = self._prepare_callback(g, body) - args = inputs_to_vals(g.inputs_list + g.states_list) - - ctx = copy_context() - partial_func = self._execute_callback(func, args, g.outputs_list, g) - if asyncio.iscoroutine(func): - response_data = await ctx.run(partial_func) - else: - response_data = ctx.run(partial_func) - - if asyncio.iscoroutine(response_data): - response_data = await response_data - - g.dash_response.set_data(response_data) - return g.dash_response - - @with_app_context - def dispatch(self): - body = flask.request.get_json() - g = self._initialize_context(body) - func = self._prepare_callback(g, body) - args = inputs_to_vals(g.inputs_list + g.states_list) - - ctx = copy_context() - partial_func = self._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - - if asyncio.iscoroutine(response_data): - raise Exception( - "You are trying to use a coroutine without dash[async]. " - "Please install the dependencies via `pip install dash[async]` and ensure " - "that `use_async=False` is not being passed to the app." - ) - - g.dash_response.set_data(response_data) - return g.dash_response - def _setup_server(self): if self._got_first_request["setup_server"]: return @@ -1695,12 +1585,6 @@ def _walk_assets_directory(self): def _invalid_resources_handler(err): return err.args[0], 404 - @staticmethod - def _serve_default_favicon(): - return flask.Response( - pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" - ) - def csp_hashes(self, hash_algorithm="sha256") -> Sequence[str]: """Calculates CSP hashes (sha + base64) of all inline scripts, such that one of the biggest benefits of CSP (disallowing general inline scripts) @@ -2112,14 +1996,19 @@ def enable_dev_tools( elif dev_tools.prune_errors: secret = gen_salt(20) - @self.server.errorhandler(Exception) - def _wrap_errors(error): - # find the callback invocation, if the error is from a callback - # and skip the traceback up to that point - # if the error didn't come from inside a callback, we won't - # skip anything. - tb = _get_traceback(secret, error) - return tb, 500 + if hasattr(self.server, "errorhandler"): + # Flask + @self.server.errorhandler(Exception) + def _wrap_errors(error): + tb = _get_traceback(secret, error) + return tb, 500 + elif hasattr(self.server, "exception_handler"): + # FastAPI + @self.server.exception_handler(Exception) + async def _wrap_errors(request, error): + tb = _get_traceback(secret, error) + from fastapi.responses import PlainTextResponse + return PlainTextResponse(tb, status_code=500) if debug and dev_tools.ui: @@ -2149,9 +2038,8 @@ def _after_request(response): return response - self.server.before_request(_before_request) - - self.server.after_request(_after_request) + self.server_factory.before_request(self.server, _before_request) + self.server_factory.after_request(self.server, _after_request) if ( debug @@ -2435,7 +2323,7 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.server.run(host=host, port=port, debug=debug, **flask_run_options) + self.server_factory.run(self.server, host=host, port=port, debug=debug, **flask_run_options) def enable_pages(self) -> None: if not self.use_pages: @@ -2495,7 +2383,7 @@ async def update(pathname_, search_, **states): ) if callable(title): title = await execute_async_function( - title, **(path_variables or {}) + title, **{**(path_variables or {})} ) return layout, {"title": title} @@ -2559,7 +2447,7 @@ def update(pathname_, search_, **states): **{**(path_variables or {}), **query_parameters, **states} ) if callable(title): - title = title(**(path_variables or {})) + title = title(**{**(path_variables or {})}) return layout, {"title": title} @@ -2599,10 +2487,5 @@ def update(pathname_, search_, **states): Input(_ID_STORE, "data"), ) - def __call__(self, environ, start_response): - """ - This method makes instances of Dash WSGI-compliant callables. - It delegates the actual WSGI handling to the internal Flask app's - __call__ method. - """ - return self.server(environ, start_response) + def __call__(self, *args, **kwargs): + return self.server_factory.__call__(self.server, *args, **kwargs) diff --git a/dash/server_factories/__init__.py b/dash/server_factories/__init__.py new file mode 100644 index 0000000000..7d9874ec7a --- /dev/null +++ b/dash/server_factories/__init__.py @@ -0,0 +1,10 @@ +# python +import contextvars + +_request_adapter_var = contextvars.ContextVar("request_adapter") + +def set_request_adapter(adapter): + _request_adapter_var.set(adapter) + +def get_request_adapter(): + return _request_adapter_var.get() diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py new file mode 100644 index 0000000000..f429357e03 --- /dev/null +++ b/dash/server_factories/base_factory.py @@ -0,0 +1,47 @@ +from abc import ABC, abstractmethod + +class BaseServerFactory(ABC): + def __call__(self, server, *args, **kwargs): + # Default: WSGI + return server(*args, **kwargs) + + @abstractmethod + def create_app(self, name="__main__", config=None): + pass + + @abstractmethod + def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + pass + + @abstractmethod + def register_error_handlers(self, app): + pass + + @abstractmethod + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + pass + + @abstractmethod + def before_request(self, app, func): + pass + + @abstractmethod + def after_request(self, app, func): + pass + + @abstractmethod + def run(self, app, host, port, debug, **kwargs): + pass + + @abstractmethod + def make_response(self, data, mimetype=None, content_type=None): + pass + + @abstractmethod + def jsonify(self, obj): + pass + + @abstractmethod + def get_request_adapter(self): + pass + diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py new file mode 100644 index 0000000000..7592a51ce3 --- /dev/null +++ b/dash/server_factories/fastapi_factory.py @@ -0,0 +1,226 @@ +import traceback + +from fastapi import FastAPI, Request, Response, APIRouter +from fastapi.responses import JSONResponse +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter, get_request_adapter +from .base_factory import BaseServerFactory +import inspect +import pkgutil +from contextvars import copy_context + +class FastAPIServerFactory(BaseServerFactory): + def __call__(self, server, *args, **kwargs): + # ASGI: (scope, receive, send) + if ( + len(args) == 3 + and isinstance(args[0], dict) + and "type" in args[0] + ): + return server(*args, **kwargs) + raise TypeError("FastAPI app must be called with (scope, receive, send)") + + + def create_app(self, name="__main__", config=None): + app = FastAPI() + if config: + for key, value in config.items(): + setattr(app.state, key, value) + return app + + def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + from fastapi.staticfiles import StaticFiles + try: + app.mount(assets_url_path, StaticFiles(directory=assets_folder), name=blueprint_name) + except RuntimeError: + # directory doesnt exist + pass + + def register_error_handlers(self, app): + @app.exception_handler(PreventUpdate) + async def _handle_error(request: Request, exc: PreventUpdate): + return Response(status_code=204) + + @app.exception_handler(InvalidResourceError) + async def _invalid_resources_handler(request: Request, exc: InvalidResourceError): + return Response(content=exc.args[0], status_code=404) + + def _html_response_wrapper(self, view_func): + async def wrapped(*args, **kwargs): + # If view_func is a function, call it; if it's a string, use it directly + html = view_func() if callable(view_func) else view_func + return Response(content=html, media_type="text/html") + + return wrapped + + def setup_index(self, app, dash_app): + async def index(): + return Response(content=dash_app.render_index(), media_type="text/html") + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + + def setup_catchall(self, app, dash_app): + async def catchall(path: str): + return Response(content=dash_app.render_index(), media_type="text/html") + + # self.add_url_rule(app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"]) + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + if rule == "": + rule = "/" + if isinstance(view_func, str): + # Wrap string or sync function to async FastAPI handler + view_func = self._html_response_wrapper(view_func) + app.add_api_route(rule, view_func, methods=methods or ["GET"], name=endpoint, include_in_schema=False) + + def before_request(self, app, func): + # FastAPI does not have before_request, but we can use middleware + app.middleware("http")(self._make_before_middleware(func)) + + def after_request(self, app, func): + # FastAPI does not have after_request, but we can use middleware + app.middleware("http")(self._make_after_middleware(func)) + + def run(self, app, host, port, debug, **kwargs): + import uvicorn + reload = debug + if reload: + # Assume app is created in 'main.py' as 'app' + # Adjust 'main:app' as needed for your project structure + uvicorn.run("app:app", host=host, port=port, reload=reload, **kwargs) + else: + uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) + + def make_response(self, data, mimetype=None, content_type=None): + headers = {} + if mimetype: + headers["content-type"] = mimetype + if content_type: + headers["content-type"] = content_type + return Response(content=data, headers=headers) + + def jsonify(self, obj): + return JSONResponse(content=obj) + + def get_request_adapter(self): + return FastAPIRequestAdapter + + def _make_before_middleware(self, func): + pass + async def middleware(request, call_next): + if func is not None: + if inspect.iscoroutinefunction(func): + await func() + else: + func() + response = await call_next(request) + return response + + return middleware + + def _make_after_middleware(self, func): + pass + async def middleware(request, call_next): + response = await call_next(request) + await func() + return response + return middleware + + def serve_component_suites(self, dash_app, package_name, fingerprinted_path, request): + import sys + import mimetypes + import pkgutil + from dash.fingerprint import check_fingerprint + from dash import _validate + + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + from starlette.responses import Response as StarletteResponse + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + else: + import hashlib + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if request.headers.get("if-none-match") == etag: + return StarletteResponse(status_code=304) + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + + def setup_component_suites(self, app, dash_app): + from fastapi import Request + async def serve(request: Request, package_name: str, fingerprinted_path: str): + return self.serve_component_suites(dash_app, package_name, fingerprinted_path, request) + + self.add_url_rule( + app, + "/_dash-component-suites/{package_name}/{fingerprinted_path:path}", + serve, + ) + + def dispatch(self, app, dash_app, use_async): + async def _dispatch(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + body = await request.json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): + response_data = await response_data + # Instead of set_data, return a new Response + return Response(content=response_data, media_type="application/json") + + return _dispatch + + def _serve_default_favicon(self): + return Response( + content=pkgutil.get_data("dash", "favicon.ico"), + media_type="image/x-icon" + ) + +class FastAPIRequestAdapter: + def __init__(self): + self._request = None + + def set_request(self, request: Request): + self._request = request + + def get_args(self): + return self._request.query_params + + async def get_json(self): + return await self._request.json() + + def is_json(self): + return self._request.headers.get("content-type", "").startswith("application/json") + + def get_cookies(self, request=None): + return self._request.cookies + + def get_headers(self): + return self._request.headers + + def get_full_path(self): + return str(self._request.url) + + def get_remote_addr(self): + return self._request.client.host if self._request.client else None + + def get_origin(self): + return self._request.headers.get("origin") diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py new file mode 100644 index 0000000000..82b6b266a8 --- /dev/null +++ b/dash/server_factories/flask_factory.py @@ -0,0 +1,188 @@ +import flask +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter, get_request_adapter +from .base_factory import BaseServerFactory +from contextvars import copy_context +import asyncio +import pkgutil + +class FlaskServerFactory(BaseServerFactory): + def __call__(self, server, *args, **kwargs): + # Always WSGI + return server(*args, **kwargs) + + def create_app(self, name="__main__", config=None): + app = flask.Flask(name) + if config: + app.config.update(config) + return app + + def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + bp = flask.Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + app.register_blueprint(bp) + + def register_error_handlers(self, app): + @app.errorhandler(PreventUpdate) + def _handle_error(_): + return "", 204 + + @app.errorhandler(InvalidResourceError) + def _invalid_resources_handler(err): + return err.args[0], 404 + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + app.add_url_rule(rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"]) + + def before_request(self, app, func): + app.before_request(func) + + def after_request(self, app, func): + app.after_request(func) + + def run(self, app, host, port, debug, **kwargs): + app.run(host=host, port=port, debug=debug, **kwargs) + + def make_response(self, data, mimetype=None, content_type=None): + return flask.Response(data, mimetype=mimetype, content_type=content_type) + + def jsonify(self, obj): + return flask.jsonify(obj) + + def get_request_adapter(self): + return FlaskRequestAdapter + + def setup_catchall(self, app, dash_app): + def catchall(path, *args, **kwargs): + return dash_app.index(*args, **kwargs) + self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) + + def setup_index(self, app, dash_app): + def index(*args, **kwargs): + return dash_app.render_index(dash_app, *args, **kwargs) + + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + + def serve_component_suites(self, dash_app, package_name, fingerprinted_path, request=None): + import sys + import mimetypes + import pkgutil + from dash.fingerprint import check_fingerprint + from dash import _validate + import flask + + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + response = flask.Response(data, mimetype=mimetype) + if has_fingerprint: + response.cache_control.max_age = 31536000 # 1 year + else: + response.add_etag() + tag = response.get_etag()[0] + request_etag = flask.request.headers.get("If-None-Match") + if f'"{tag}"' == request_etag: + response = flask.Response(None, status=304) + return response + + def setup_component_suites(self, app, dash_app): + def serve(package_name, fingerprinted_path): + return self.serve_component_suites(dash_app, package_name, fingerprinted_path, flask.request) + + self.add_url_rule( + app, + "/_dash-component-suites//", + serve, + ) + + def dispatch(self, app, dash_app, use_async=False): + def _dispatch(): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + body = flask.request.get_json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + raise Exception( + "You are trying to use a coroutine without dash[async]. " + "Please install the dependencies via `pip install dash[async]` and ensure " + "that `use_async=False` is not being passed to the app." + ) + g.dash_response.set_data(response_data) + return g.dash_response + + async def _dispatch_async(): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + body = flask.request.get_json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + response_data = await response_data + g.dash_response.set_data(response_data) + return g.dash_response + + if use_async: + _dispatch = _dispatch_async + return _dispatch + + def _serve_default_favicon(): + import flask + return flask.Response( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + +class FlaskRequestAdapter: + @staticmethod + def get_args(): + return flask.request.args + + @staticmethod + def get_json(): + return flask.request.get_json() + + @staticmethod + def is_json(): + return flask.request.is_json + + @staticmethod + def get_cookies(): + return flask.request.cookies + + @staticmethod + def get_headers(): + return flask.request.headers + + @staticmethod + def get_full_path(): + return flask.request.full_path + + @staticmethod + def get_remote_addr(): + return flask.request.remote_addr + + @staticmethod + def get_origin(): + return getattr(flask.request, 'origin', None) From 9dff79140b93e65b2076e61ff821bc324a936f00 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:28:31 -0400 Subject: [PATCH 03/74] fix favicon --- dash/server_factories/flask_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 82b6b266a8..bdcc6aef87 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -148,7 +148,7 @@ async def _dispatch_async(): _dispatch = _dispatch_async return _dispatch - def _serve_default_favicon(): + def _serve_default_favicon(self): import flask return flask.Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" From c319b18a18044bbbf1c2081731feeafbeff5fd2a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:29:37 -0400 Subject: [PATCH 04/74] removing changelog entry --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4120ea8f93..a6a2224907 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,6 @@ All notable changes to `dash` will be documented in this file. This project adheres to [Semantic Versioning](https://semver.org/). -## [bringyourownserver] -- [#3430] Adds support to bring your own server, eg (Quart, FastAPI, etc). - ## [UNRELEASED] ## Added From 7de2a41017a37ee54a7a2d6219cb4211e71b3a11 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:51:45 -0400 Subject: [PATCH 05/74] fixing issue with debug true for FastAPI --- dash/server_factories/fastapi_factory.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 7592a51ce3..f893e61bc6 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -8,6 +8,7 @@ import inspect import pkgutil from contextvars import copy_context +import importlib.util class FastAPIServerFactory(BaseServerFactory): def __call__(self, server, *args, **kwargs): @@ -81,12 +82,15 @@ def after_request(self, app, func): app.middleware("http")(self._make_after_middleware(func)) def run(self, app, host, port, debug, **kwargs): + frame = inspect.stack()[2] import uvicorn + reload = debug if reload: - # Assume app is created in 'main.py' as 'app' - # Adjust 'main:app' as needed for your project structure - uvicorn.run("app:app", host=host, port=port, reload=reload, **kwargs) + # Dynamically determine the module name from the file path + file_path = frame.filename + module_name = importlib.util.spec_from_file_location("app", file_path).name + uvicorn.run(f"{module_name}:app.server", host=host, port=port, reload=reload, **kwargs) else: uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) From 2cd769e51f2050aedcfd030feb3a2c4bed09938e Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 16:04:10 -0400 Subject: [PATCH 06/74] fixing `catchall` for path routes --- dash/server_factories/flask_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index bdcc6aef87..4748fa317e 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -58,7 +58,7 @@ def get_request_adapter(self): def setup_catchall(self, app, dash_app): def catchall(path, *args, **kwargs): - return dash_app.index(*args, **kwargs) + return dash_app.render_index(*args, **kwargs) self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) def setup_index(self, app, dash_app): From 686f32f64e45904ab13059dbf1b352df28a02601 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 16:37:30 -0400 Subject: [PATCH 07/74] fixing pages for use with `fastapi` --- dash/_pages.py | 10 +++++----- dash/dash.py | 7 +++++-- dash/server_factories/fastapi_factory.py | 23 +++++++++++++++++++---- dash/server_factories/flask_factory.py | 10 ++++++++++ 4 files changed, 39 insertions(+), 11 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 45538546e8..b1cd0cbe69 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -389,15 +389,15 @@ def _path_to_page(path_id): return {}, None -def _page_meta_tags(app): - start_page, path_variables = _path_to_page(flask.request.path.strip("/")) +def _page_meta_tags(app, request): + request_url = request.get_path() + start_page, path_variables = _path_to_page(request_url.strip("/")) - # use the supplied image_url or create url based on image in the assets folder image = start_page.get("image", "") if image: image = app.get_asset_url(image) assets_image_url = ( - "".join([flask.request.url_root, image.lstrip("/")]) if image else None + "".join([request.url_root, image.lstrip("/")]) if image else None ) supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url @@ -413,7 +413,7 @@ def _page_meta_tags(app): return [ {"name": "description", "content": description}, {"property": "twitter:card", "content": "summary_large_image"}, - {"property": "twitter:url", "content": flask.request.url}, + {"property": "twitter:url", "content": request_url}, {"property": "twitter:title", "content": title}, {"property": "twitter:description", "content": description}, {"property": "twitter:image", "content": image_url or ""}, diff --git a/dash/dash.py b/dash/dash.py index 56bf65c9e6..18fe56acca 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -44,6 +44,7 @@ ProxyError, DuplicateCallback, ) +from .server_factories import get_request_adapter from .version import __version__ from ._configs import get_combined_config, pathname_configs, pages_folder_config from ._utils import ( @@ -1141,9 +1142,10 @@ def render_index(self, *args, **kwargs): metas = self._generate_meta() renderer = self._generate_renderer() title = self.title + request = get_request_adapter() if self.use_pages and self.config.include_pages_meta: - metas = _page_meta_tags(self) + metas + metas = _page_meta_tags(self, request) + metas if self._favicon: favicon_mod_time = os.path.getmtime( @@ -2331,7 +2333,7 @@ def enable_pages(self) -> None: if self.pages_folder: _import_layouts_from_pages(self.config.pages_folder) - @self.server.before_request + def router(): if self._got_first_request["pages"]: return @@ -2487,5 +2489,6 @@ def update(pathname_, search_, **states): Input(_ID_STORE, "data"), ) + self.server_factory.before_request(self.server, router) def __call__(self, *args, **kwargs): return self.server_factory.__call__(self.server, *args, **kwargs) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index f893e61bc6..aa4ff5d523 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -55,15 +55,27 @@ async def wrapped(*args, **kwargs): return wrapped def setup_index(self, app, dash_app): - async def index(): + async def index(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) def setup_catchall(self, app, dash_app): - async def catchall(path: str): - return Response(content=dash_app.render_index(), media_type="text/html") + @dash_app.server.on_event("startup") + def _setup_catchall(): + from fastapi import Request, Response - # self.add_url_rule(app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"]) + async def catchall(path: str, request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + return Response(content=dash_app.render_index(), media_type="text/html") + + self.add_url_rule(app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"]) + + pass # catchall needs to be last to not override other routes def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if rule == "": @@ -228,3 +240,6 @@ def get_remote_addr(self): def get_origin(self): return self._request.headers.get("origin") + + def get_path(self): + return self._request.url.path # <-- Add this method diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 4748fa317e..1c748e01ed 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -58,11 +58,17 @@ def get_request_adapter(self): def setup_catchall(self, app, dash_app): def catchall(path, *args, **kwargs): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(flask.request) return dash_app.render_index(*args, **kwargs) self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) def setup_index(self, app, dash_app): def index(*args, **kwargs): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(flask.request) return dash_app.render_index(dash_app, *args, **kwargs) self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) @@ -186,3 +192,7 @@ def get_remote_addr(): @staticmethod def get_origin(): return getattr(flask.request, 'origin', None) + + @staticmethod + def get_path(): + return flask.request.path From 660e257604bc5e95681e6b7d495830c5ed5686ac Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 10 Sep 2025 10:04:52 -0400 Subject: [PATCH 08/74] fixing issue with flask pages --- dash/server_factories/flask_factory.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 1c748e01ed..9bc7929685 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -60,7 +60,6 @@ def setup_catchall(self, app, dash_app): def catchall(path, *args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - adapter.set_request(flask.request) return dash_app.render_index(*args, **kwargs) self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) @@ -68,7 +67,6 @@ def setup_index(self, app, dash_app): def index(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - adapter.set_request(flask.request) return dash_app.render_index(dash_app, *args, **kwargs) self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) From 0fa5c99de789f9161be40bfe72c05e4140906281 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 10:20:21 -0400 Subject: [PATCH 09/74] fixing for lint --- dash/_pages.py | 4 +- dash/dash.py | 34 ++++++++---- dash/server_factories/__init__.py | 2 + dash/server_factories/base_factory.py | 6 ++- dash/server_factories/fastapi_factory.py | 68 +++++++++++++++++------- dash/server_factories/flask_factory.py | 26 ++++++--- 6 files changed, 102 insertions(+), 38 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index b1cd0cbe69..2a3a116324 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -396,9 +396,7 @@ def _page_meta_tags(app, request): image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = ( - "".join([request.url_root, image.lstrip("/")]) if image else None - ) + assets_image_url = "".join([request.url_root, image.lstrip("/")]) if image else None supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url diff --git a/dash/dash.py b/dash/dash.py index 18fe56acca..f6f6e76e01 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -497,19 +497,25 @@ def __init__( # pylint: disable=too-many-statements # We have 3 cases: server is either True (we create the server), False # (defer server creation) or a Flask app instance (we use their server) - if callable(server) and not (hasattr(server, 'route') and hasattr(server, 'run')): + if callable(server) and not ( + hasattr(server, "route") and hasattr(server, "run") + ): # Server factory function self.server = server() if name is None: caller_name = getattr(self.server, "name", caller_name) - elif hasattr(server, 'route') and hasattr(server, 'run'): + elif hasattr(server, "route") and hasattr(server, "run"): self.server = server if name is None: caller_name = getattr(server, "name", caller_name) elif isinstance(server, bool): - self.server = self.server_factory.create_app(caller_name) if server else None + self.server = ( + self.server_factory.create_app(caller_name) if server else None + ) else: - raise ValueError("server must be a Flask app, a boolean, or a server factory function") + raise ValueError( + "server must be a Flask app, a boolean, or a server factory function" + ) base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -707,6 +713,7 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: if config.compress: try: from flask_compress import Compress + Compress(self.server) _flask_compress_version = parse_version( _get_distribution_version("flask_compress") @@ -741,7 +748,11 @@ def _setup_routes(self): self.server_factory.setup_component_suites(self.server, self) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) - self._add_url("_dash-update-component", self.server_factory.dispatch(self.server, self, self._use_async), ["POST"]) + self._add_url( + "_dash-update-component", + self.server_factory.dispatch(self.server, self, self._use_async), + ["POST"], + ) self._add_url("_reload-hash", self.serve_reload_hash) self._add_url("_favicon.ico", self.server_factory._serve_default_favicon) self.server_factory.setup_index(self.server, self) @@ -759,7 +770,6 @@ def _setup_routes(self): hook.data["methods"], ) - def setup_apis(self): """ Register API endpoints for all callbacks defined using `dash.callback`. @@ -1368,7 +1378,9 @@ def _initialize_context(self, body, adapter): {"prop_id": x, "value": g.input_values.get(x)} for x in body.get("changedPropIds", []) ] - g.dash_response = self.server_factory.make_response(mimetype="application/json", data=None) + g.dash_response = self.server_factory.make_response( + mimetype="application/json", data=None + ) g.cookies = dict(adapter.get_cookies()) g.headers = dict(adapter.get_headers()) g.path = adapter.get_full_path() @@ -2004,12 +2016,14 @@ def enable_dev_tools( def _wrap_errors(error): tb = _get_traceback(secret, error) return tb, 500 + elif hasattr(self.server, "exception_handler"): # FastAPI @self.server.exception_handler(Exception) async def _wrap_errors(request, error): tb = _get_traceback(secret, error) from fastapi.responses import PlainTextResponse + return PlainTextResponse(tb, status_code=500) if debug and dev_tools.ui: @@ -2325,7 +2339,9 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.server_factory.run(self.server, host=host, port=port, debug=debug, **flask_run_options) + self.server_factory.run( + self.server, host=host, port=port, debug=debug, **flask_run_options + ) def enable_pages(self) -> None: if not self.use_pages: @@ -2333,7 +2349,6 @@ def enable_pages(self) -> None: if self.pages_folder: _import_layouts_from_pages(self.config.pages_folder) - def router(): if self._got_first_request["pages"]: return @@ -2490,5 +2505,6 @@ def update(pathname_, search_, **states): ) self.server_factory.before_request(self.server, router) + def __call__(self, *args, **kwargs): return self.server_factory.__call__(self.server, *args, **kwargs) diff --git a/dash/server_factories/__init__.py b/dash/server_factories/__init__.py index 7d9874ec7a..1bfd497935 100644 --- a/dash/server_factories/__init__.py +++ b/dash/server_factories/__init__.py @@ -3,8 +3,10 @@ _request_adapter_var = contextvars.ContextVar("request_adapter") + def set_request_adapter(adapter): _request_adapter_var.set(adapter) + def get_request_adapter(): return _request_adapter_var.get() diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py index f429357e03..b44f6888cb 100644 --- a/dash/server_factories/base_factory.py +++ b/dash/server_factories/base_factory.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod + class BaseServerFactory(ABC): def __call__(self, server, *args, **kwargs): # Default: WSGI @@ -10,7 +11,9 @@ def create_app(self, name="__main__", config=None): pass @abstractmethod - def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): pass @abstractmethod @@ -44,4 +47,3 @@ def jsonify(self, obj): @abstractmethod def get_request_adapter(self): pass - diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index aa4ff5d523..8d9efb2416 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -10,18 +10,14 @@ from contextvars import copy_context import importlib.util + class FastAPIServerFactory(BaseServerFactory): def __call__(self, server, *args, **kwargs): # ASGI: (scope, receive, send) - if ( - len(args) == 3 - and isinstance(args[0], dict) - and "type" in args[0] - ): + if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: return server(*args, **kwargs) raise TypeError("FastAPI app must be called with (scope, receive, send)") - def create_app(self, name="__main__", config=None): app = FastAPI() if config: @@ -29,10 +25,17 @@ def create_app(self, name="__main__", config=None): setattr(app.state, key, value) return app - def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): from fastapi.staticfiles import StaticFiles + try: - app.mount(assets_url_path, StaticFiles(directory=assets_folder), name=blueprint_name) + app.mount( + assets_url_path, + StaticFiles(directory=assets_folder), + name=blueprint_name, + ) except RuntimeError: # directory doesnt exist pass @@ -43,7 +46,9 @@ async def _handle_error(request: Request, exc: PreventUpdate): return Response(status_code=204) @app.exception_handler(InvalidResourceError) - async def _invalid_resources_handler(request: Request, exc: InvalidResourceError): + async def _invalid_resources_handler( + request: Request, exc: InvalidResourceError + ): return Response(content=exc.args[0], status_code=404) def _html_response_wrapper(self, view_func): @@ -60,6 +65,7 @@ async def index(request: Request): set_request_adapter(adapter) adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) def setup_catchall(self, app, dash_app): @@ -73,9 +79,11 @@ async def catchall(path: str, request: Request): adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") - self.add_url_rule(app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"]) + self.add_url_rule( + app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"] + ) - pass # catchall needs to be last to not override other routes + pass # catchall needs to be last to not override other routes def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if rule == "": @@ -83,7 +91,13 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if isinstance(view_func, str): # Wrap string or sync function to async FastAPI handler view_func = self._html_response_wrapper(view_func) - app.add_api_route(rule, view_func, methods=methods or ["GET"], name=endpoint, include_in_schema=False) + app.add_api_route( + rule, + view_func, + methods=methods or ["GET"], + name=endpoint, + include_in_schema=False, + ) def before_request(self, app, func): # FastAPI does not have before_request, but we can use middleware @@ -102,7 +116,13 @@ def run(self, app, host, port, debug, **kwargs): # Dynamically determine the module name from the file path file_path = frame.filename module_name = importlib.util.spec_from_file_location("app", file_path).name - uvicorn.run(f"{module_name}:app.server", host=host, port=port, reload=reload, **kwargs) + uvicorn.run( + f"{module_name}:app.server", + host=host, + port=port, + reload=reload, + **kwargs, + ) else: uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) @@ -122,6 +142,7 @@ def get_request_adapter(self): def _make_before_middleware(self, func): pass + async def middleware(request, call_next): if func is not None: if inspect.iscoroutinefunction(func): @@ -135,13 +156,17 @@ async def middleware(request, call_next): def _make_after_middleware(self, func): pass + async def middleware(request, call_next): response = await call_next(request) await func() return response + return middleware - def serve_component_suites(self, dash_app, package_name, fingerprinted_path, request): + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, request + ): import sys import mimetypes import pkgutil @@ -162,12 +187,14 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req ) data = pkgutil.get_data(package_name, path_in_pkg) from starlette.responses import Response as StarletteResponse + headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" return StarletteResponse(content=data, media_type=mimetype, headers=headers) else: import hashlib + etag = hashlib.md5(data).hexdigest() if data else "" headers["ETag"] = etag if request.headers.get("if-none-match") == etag: @@ -176,8 +203,11 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req def setup_component_suites(self, app, dash_app): from fastapi import Request + async def serve(request: Request, package_name: str, fingerprinted_path: str): - return self.serve_component_suites(dash_app, package_name, fingerprinted_path, request) + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) self.add_url_rule( app, @@ -206,10 +236,10 @@ async def _dispatch(request: Request): def _serve_default_favicon(self): return Response( - content=pkgutil.get_data("dash", "favicon.ico"), - media_type="image/x-icon" + content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" ) + class FastAPIRequestAdapter: def __init__(self): self._request = None @@ -224,7 +254,9 @@ async def get_json(self): return await self._request.json() def is_json(self): - return self._request.headers.get("content-type", "").startswith("application/json") + return self._request.headers.get("content-type", "").startswith( + "application/json" + ) def get_cookies(self, request=None): return self._request.cookies diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 9bc7929685..5eaaf44a36 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -6,6 +6,7 @@ import asyncio import pkgutil + class FlaskServerFactory(BaseServerFactory): def __call__(self, server, *args, **kwargs): # Always WSGI @@ -17,7 +18,9 @@ def create_app(self, name="__main__", config=None): app.config.update(config) return app - def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): bp = flask.Blueprint( blueprint_name, __name__, @@ -36,7 +39,9 @@ def _invalid_resources_handler(err): return err.args[0], 404 def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule(rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"]) + app.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) def before_request(self, app, func): app.before_request(func) @@ -61,7 +66,10 @@ def catchall(path, *args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) - self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) + + self.add_url_rule( + app, "/", catchall, endpoint="catchall", methods=["GET"] + ) def setup_index(self, app, dash_app): def index(*args, **kwargs): @@ -71,7 +79,9 @@ def index(*args, **kwargs): self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) - def serve_component_suites(self, dash_app, package_name, fingerprinted_path, request=None): + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, request=None + ): import sys import mimetypes import pkgutil @@ -105,7 +115,9 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req def setup_component_suites(self, app, dash_app): def serve(package_name, fingerprinted_path): - return self.serve_component_suites(dash_app, package_name, fingerprinted_path, flask.request) + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, flask.request + ) self.add_url_rule( app, @@ -154,10 +166,12 @@ async def _dispatch_async(): def _serve_default_favicon(self): import flask + return flask.Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) + class FlaskRequestAdapter: @staticmethod def get_args(): @@ -189,7 +203,7 @@ def get_remote_addr(): @staticmethod def get_origin(): - return getattr(flask.request, 'origin', None) + return getattr(flask.request, "origin", None) @staticmethod def get_path(): From 1088331323380bd58c27f3776d21d4e4132bfd3d Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 13:33:16 -0400 Subject: [PATCH 10/74] fixing issue with failing test due to `endpoint` name --- dash/server_factories/flask_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 5eaaf44a36..69173516b2 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -77,7 +77,7 @@ def index(*args, **kwargs): set_request_adapter(adapter) return dash_app.render_index(dash_app, *args, **kwargs) - self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + self.add_url_rule(app, "/", index, endpoint="/", methods=["GET"]) def serve_component_suites( self, dash_app, package_name, fingerprinted_path, request=None From 4920e33cf68651bebc699f72f6da1e8eadbf925e Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 13:33:43 -0400 Subject: [PATCH 11/74] fixing `run` command to trigger `devtools` properly --- dash/server_factories/fastapi_factory.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 8d9efb2416..1fa07a6ac6 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -12,6 +12,10 @@ class FastAPIServerFactory(BaseServerFactory): + def __init__(self): + self.config = {} + super().__init__() + def __call__(self, server, *args, **kwargs): # ASGI: (scope, receive, send) if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: @@ -71,6 +75,7 @@ async def index(request: Request): def setup_catchall(self, app, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): + dash_app.enable_dev_tools(**self.config) # do this to make sure dev tools are enabled from fastapi import Request, Response async def catchall(path: str, request: Request): @@ -111,6 +116,9 @@ def run(self, app, host, port, debug, **kwargs): frame = inspect.stack()[2] import uvicorn + self.config = dict({'debug': debug} if debug else {}, **kwargs) + + reload = debug if reload: # Dynamically determine the module name from the file path From 9ffba5a58652cc7b28d253e46662ad8cbe0fb8bd Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:51:02 -0400 Subject: [PATCH 12/74] fixing issue with lint and debug ui --- dash/dash.py | 69 +++--------- dash/server_factories/fastapi_factory.py | 136 ++++++++++++++--------- dash/server_factories/flask_factory.py | 68 ++++++++---- 3 files changed, 144 insertions(+), 129 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index f6f6e76e01..2151f31f77 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -4,7 +4,6 @@ import collections import importlib import warnings -from contextvars import copy_context from importlib.machinery import ModuleSpec from importlib.util import find_spec from importlib import metadata @@ -12,12 +11,10 @@ import threading import re import logging -import time import mimetypes import hashlib import base64 import traceback -import inspect from urllib.parse import urlparse from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List @@ -30,7 +27,7 @@ from dash import html from dash import dash_table -from .fingerprint import build_fingerprint, check_fingerprint +from .fingerprint import build_fingerprint from .resources import Scripts, Css from .dependencies import ( Input, @@ -39,8 +36,6 @@ ) from .development.base_component import ComponentRegistry from .exceptions import ( - PreventUpdate, - InvalidResourceError, ProxyError, DuplicateCallback, ) @@ -72,7 +67,7 @@ from .server_factories.flask_factory import FlaskServerFactory from .server_factories.base_factory import BaseServerFactory -from ._get_app import with_app_context, with_app_context_async, with_app_context_factory +from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group from ._obsolete import ObsoleteChecker @@ -712,8 +707,9 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: ) if config.compress: try: - from flask_compress import Compress + import flask_compress # pylint: disable=import-outside-toplevel + Compress = flask_compress.Compress Compress(self.server) _flask_compress_version = parse_version( _get_distribution_version("flask_compress") @@ -754,7 +750,10 @@ def _setup_routes(self): ["POST"], ) self._add_url("_reload-hash", self.serve_reload_hash) - self._add_url("_favicon.ico", self.server_factory._serve_default_favicon) + self._add_url( + "_favicon.ico", + self.server_factory._serve_default_favicon, # pylint: disable=protected-access + ) self.server_factory.setup_index(self.server, self) self.server_factory.setup_catchall(self.server, self) @@ -1145,7 +1144,7 @@ def _generate_meta(self): return meta_tags + self.config.meta_tags - def render_index(self, *args, **kwargs): + def render_index(self, *_args, **_kwargs): scripts = self._generate_scripts_html() css = self._generate_css_dist_html() config = self._generate_config_html() @@ -1845,6 +1844,7 @@ def enable_dev_tools( dev_tools_silence_routes_logging: Optional[bool] = None, dev_tools_disable_version_check: Optional[bool] = None, dev_tools_prune_errors: Optional[bool] = None, + first_run: bool = True, ) -> bool: """Activate the dev tools, called by `run`. If your application is served by wsgi and you want to activate the dev tools, you can call @@ -2009,53 +2009,12 @@ def enable_dev_tools( ) elif dev_tools.prune_errors: secret = gen_salt(20) - - if hasattr(self.server, "errorhandler"): - # Flask - @self.server.errorhandler(Exception) - def _wrap_errors(error): - tb = _get_traceback(secret, error) - return tb, 500 - - elif hasattr(self.server, "exception_handler"): - # FastAPI - @self.server.exception_handler(Exception) - async def _wrap_errors(request, error): - tb = _get_traceback(secret, error) - from fastapi.responses import PlainTextResponse - - return PlainTextResponse(tb, status_code=500) + self.server_factory.register_prune_error_handler( + self.server, secret, _get_traceback + ) if debug and dev_tools.ui: - - def _before_request(): - flask.g.timing_information = { # pylint: disable=assigning-non-slot - "__dash_server": {"dur": time.time(), "desc": None} - } - - def _after_request(response): - timing_information = flask.g.get("timing_information", None) - if timing_information is None: - return response - - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - - if info.get("dur") is not None: - value += f";dur={info['dur']}" - - response.headers.add("Server-Timing", value) - - return response - - self.server_factory.before_request(self.server, _before_request) - self.server_factory.after_request(self.server, _after_request) + self.server_factory.register_timing_hooks(self.server, first_run) if ( debug diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 1fa07a6ac6..918ca2175f 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -1,14 +1,22 @@ -import traceback - -from fastapi import FastAPI, Request, Response, APIRouter -from fastapi.responses import JSONResponse -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter, get_request_adapter -from .base_factory import BaseServerFactory +import sys +import mimetypes +import hashlib import inspect import pkgutil from contextvars import copy_context import importlib.util +import time +import uvicorn +from fastapi import FastAPI, Request, Response +from fastapi.responses import JSONResponse, PlainTextResponse +from fastapi.staticfiles import StaticFiles +from starlette.responses import Response as StarletteResponse +from starlette.datastructures import MutableHeaders +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter +from .base_factory import BaseServerFactory class FastAPIServerFactory(BaseServerFactory): @@ -32,8 +40,6 @@ def create_app(self, name="__main__", config=None): def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - from fastapi.staticfiles import StaticFiles - try: app.mount( assets_url_path, @@ -46,17 +52,21 @@ def register_assets_blueprint( def register_error_handlers(self, app): @app.exception_handler(PreventUpdate) - async def _handle_error(request: Request, exc: PreventUpdate): + async def _handle_error(_request, _exc): return Response(status_code=204) @app.exception_handler(InvalidResourceError) - async def _invalid_resources_handler( - request: Request, exc: InvalidResourceError - ): + async def _invalid_resources_handler(_request, exc): return Response(content=exc.args[0], status_code=404) + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.exception_handler(Exception) + async def _wrap_errors(_error_request, error): + tb = get_traceback_func(secret, error) + return PlainTextResponse(tb, status_code=500) + def _html_response_wrapper(self, view_func): - async def wrapped(*args, **kwargs): + async def wrapped(*_args, **_kwargs): # If view_func is a function, call it; if it's a string, use it directly html = view_func() if callable(view_func) else view_func return Response(content=html, media_type="text/html") @@ -75,10 +85,11 @@ async def index(request: Request): def setup_catchall(self, app, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): - dash_app.enable_dev_tools(**self.config) # do this to make sure dev tools are enabled - from fastapi import Request, Response + dash_app.enable_dev_tools( + **self.config, first_run=False + ) # do this to make sure dev tools are enabled - async def catchall(path: str, request: Request): + async def catchall(_path: str, request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) @@ -88,8 +99,6 @@ async def catchall(path: str, request: Request): app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"] ) - pass # catchall needs to be last to not override other routes - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if rule == "": rule = "/" @@ -114,11 +123,7 @@ def after_request(self, app, func): def run(self, app, host, port, debug, **kwargs): frame = inspect.stack()[2] - import uvicorn - - self.config = dict({'debug': debug} if debug else {}, **kwargs) - - + self.config = dict({"debug": debug} if debug else {}, **kwargs) reload = debug if reload: # Dynamically determine the module name from the file path @@ -149,8 +154,6 @@ def get_request_adapter(self): return FastAPIRequestAdapter def _make_before_middleware(self, func): - pass - async def middleware(request, call_next): if func is not None: if inspect.iscoroutinefunction(func): @@ -163,11 +166,13 @@ async def middleware(request, call_next): return middleware def _make_after_middleware(self, func): - pass - async def middleware(request, call_next): response = await call_next(request) - await func() + if func is not None: + if inspect.iscoroutinefunction(func): + await func() + else: + func() return response return middleware @@ -175,12 +180,6 @@ async def middleware(request, call_next): def serve_component_suites( self, dash_app, package_name, fingerprinted_path, request ): - import sys - import mimetypes - import pkgutil - from dash.fingerprint import check_fingerprint - from dash import _validate - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -194,24 +193,17 @@ def serve_component_suites( package.__path__, ) data = pkgutil.get_data(package_name, path_in_pkg) - from starlette.responses import Response as StarletteResponse - headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" return StarletteResponse(content=data, media_type=mimetype, headers=headers) - else: - import hashlib - - etag = hashlib.md5(data).hexdigest() if data else "" - headers["ETag"] = etag - if request.headers.get("if-none-match") == etag: - return StarletteResponse(status_code=304) - return StarletteResponse(content=data, media_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if request.headers.get("if-none-match") == etag: + return StarletteResponse(status_code=304) + return StarletteResponse(content=data, media_type=mimetype, headers=headers) def setup_component_suites(self, app, dash_app): - from fastapi import Request - async def serve(request: Request, package_name: str, fingerprinted_path: str): return self.serve_component_suites( dash_app, package_name, fingerprinted_path, request @@ -223,17 +215,26 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): serve, ) - def dispatch(self, app, dash_app, use_async): + def dispatch(self, _app, dash_app, _use_async): async def _dispatch(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) + # pylint: disable=protected-access body = await request.json() - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + g = dash_app._initialize_context( + body, adapter + ) # pylint: disable=protected-access + func = dash_app._prepare_callback( + g, body + ) # pylint: disable=protected-access + args = dash_app._inputs_to_vals( + g.inputs_list + g.states_list + ) # pylint: disable=protected-access ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + partial_func = dash_app._execute_callback( + func, args, g.outputs_list, g + ) # pylint: disable=protected-access response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): response_data = await response_data @@ -247,6 +248,33 @@ def _serve_default_favicon(self): content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" ) + def register_timing_hooks(self, app, first_run): + if not first_run: + return + + @app.middleware("http") + async def timing_middleware(request, call_next): + # Before request + request.state.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + response = await call_next(request) + # After request + timing_information = getattr(request.state, "timing_information", None) + if timing_information is not None: + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + headers = MutableHeaders(response.headers) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + headers.append("Server-Timing", value) + return response + class FastAPIRequestAdapter: def __init__(self): @@ -266,7 +294,7 @@ def is_json(self): "application/json" ) - def get_cookies(self, request=None): + def get_cookies(self, _request=None): return self._request.cookies def get_headers(self): @@ -282,4 +310,4 @@ def get_origin(self): return self._request.headers.get("origin") def get_path(self): - return self._request.url.path # <-- Add this method + return self._request.url.path diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 69173516b2..dafa4b24b4 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -1,10 +1,15 @@ -import flask -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter, get_request_adapter -from .base_factory import BaseServerFactory from contextvars import copy_context import asyncio import pkgutil +import sys +import mimetypes +import time +import flask +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter +from .base_factory import BaseServerFactory class FlaskServerFactory(BaseServerFactory): @@ -38,6 +43,12 @@ def _handle_error(_): def _invalid_resources_handler(err): return err.args[0], 404 + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.errorhandler(Exception) + def _wrap_errors(error): + tb = get_traceback_func(secret, error) + return tb, 500 + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): app.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] @@ -62,7 +73,7 @@ def get_request_adapter(self): return FlaskRequestAdapter def setup_catchall(self, app, dash_app): - def catchall(path, *args, **kwargs): + def catchall(_path, *args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) @@ -75,20 +86,11 @@ def setup_index(self, app, dash_app): def index(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - return dash_app.render_index(dash_app, *args, **kwargs) + return dash_app.render_index(*args, **kwargs) self.add_url_rule(app, "/", index, endpoint="/", methods=["GET"]) - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path, request=None - ): - import sys - import mimetypes - import pkgutil - from dash.fingerprint import check_fingerprint - from dash import _validate - import flask - + def serve_component_suites(self, dash_app, package_name, fingerprinted_path): path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -116,7 +118,7 @@ def serve_component_suites( def setup_component_suites(self, app, dash_app): def serve(package_name, fingerprinted_path): return self.serve_component_suites( - dash_app, package_name, fingerprinted_path, flask.request + dash_app, package_name, fingerprinted_path ) self.add_url_rule( @@ -125,11 +127,12 @@ def serve(package_name, fingerprinted_path): serve, ) - def dispatch(self, app, dash_app, use_async=False): + def dispatch(self, _app, dash_app, use_async=False): def _dispatch(): adapter = FlaskRequestAdapter() set_request_adapter(adapter) body = flask.request.get_json() + # pylint: disable=protected-access g = dash_app._initialize_context(body, adapter) func = dash_app._prepare_callback(g, body) args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) @@ -149,6 +152,7 @@ async def _dispatch_async(): adapter = FlaskRequestAdapter() set_request_adapter(adapter) body = flask.request.get_json() + # pylint: disable=protected-access g = dash_app._initialize_context(body, adapter) func = dash_app._prepare_callback(g, body) args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) @@ -161,16 +165,40 @@ async def _dispatch_async(): return g.dash_response if use_async: - _dispatch = _dispatch_async + return _dispatch_async return _dispatch def _serve_default_favicon(self): - import flask return flask.Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) + def register_timing_hooks(self, app, _first_run): + def _before_request(): + flask.g.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + + def _after_request(response): + timing_information = flask.g.get("timing_information", None) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + response.headers.add("Server-Timing", value) + return response + + self.before_request(app, _before_request) + self.after_request(app, _after_request) + class FlaskRequestAdapter: @staticmethod From 908aacd729695fd2ef8d79a6343d0ef21b6cea84 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:39:56 -0400 Subject: [PATCH 13/74] fixing issue with `_app` when using dispatch, need to keep in context --- dash/server_factories/fastapi_factory.py | 4 +++- dash/server_factories/flask_factory.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 918ca2175f..ff21e61d72 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -215,7 +215,9 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): serve, ) - def dispatch(self, _app, dash_app, _use_async): + def dispatch( + self, app, dash_app, use_async=False + ): # pylint: disable=unused-argument async def _dispatch(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index dafa4b24b4..b16135cfff 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -127,7 +127,9 @@ def serve(package_name, fingerprinted_path): serve, ) - def dispatch(self, _app, dash_app, use_async=False): + def dispatch( + self, app, dash_app, use_async=False + ): # pylint: disable=unused-argument def _dispatch(): adapter = FlaskRequestAdapter() set_request_adapter(adapter) From 9491c7fbfbc637029092413ffee155f56bcf4988 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:58:55 -0400 Subject: [PATCH 14/74] fixing issue with catchall --- dash/server_factories/fastapi_factory.py | 2 +- dash/server_factories/flask_factory.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index ff21e61d72..0853972d1f 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -89,7 +89,7 @@ def _setup_catchall(): **self.config, first_run=False ) # do this to make sure dev tools are enabled - async def catchall(_path: str, request: Request): + async def catchall(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index b16135cfff..bb1204af19 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -73,7 +73,7 @@ def get_request_adapter(self): return FlaskRequestAdapter def setup_catchall(self, app, dash_app): - def catchall(_path, *args, **kwargs): + def catchall(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) From 39ad7bd9c837699393e05ebdcc12c0c95119bc8f Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:18:17 -0400 Subject: [PATCH 15/74] fixing issue with args and cancelling callbacks --- dash/_callback_context.py | 8 ++++++++ dash/dash.py | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index f64865c464..72b92e09e2 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -288,6 +288,14 @@ def path(self): """ return _get_from_context("path", "") + @property + @has_context + def args(self): + """ + Query parameters of the callback request as a dictionary-like object. + """ + return _get_from_context("args", "") + @property @has_context def remote(self): diff --git a/dash/dash.py b/dash/dash.py index 2151f31f77..d20672453c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -70,6 +70,7 @@ from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group from ._obsolete import ObsoleteChecker +from ._callback_context import callback_context from . import _pages from ._pages import ( @@ -1382,6 +1383,7 @@ def _initialize_context(self, body, adapter): ) g.cookies = dict(adapter.get_cookies()) g.headers = dict(adapter.get_headers()) + g.args = adapter.get_args() g.path = adapter.get_full_path() g.remote = adapter.get_remote_addr() g.origin = adapter.get_origin() @@ -1529,7 +1531,7 @@ def _setup_server(self): manager=manager, ) def cancel_call(*_): - job_ids = flask.request.args.getlist("cancelJob") + job_ids = callback_context.args.getlist("cancelJob") executor = _callback.context_value.get().background_callback_manager if job_ids: for job_id in job_ids: From 7bf69a7583e1b216de132d6a622ded05d85f1ce8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:33:33 -0400 Subject: [PATCH 16/74] fixing issues with pages metadata and flaky tests --- dash/_pages.py | 4 +-- dash/server_factories/fastapi_factory.py | 28 +++++++++++++++---- dash/server_factories/flask_factory.py | 8 ++++++ .../multi_page/test_pages_relative_path.py | 3 +- 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 2a3a116324..3fab86eb99 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -396,7 +396,7 @@ def _page_meta_tags(app, request): image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = "".join([request.url_root, image.lstrip("/")]) if image else None + assets_image_url = "".join([request.get_root(), image.lstrip("/")]) if image else None supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url @@ -411,7 +411,7 @@ def _page_meta_tags(app, request): return [ {"name": "description", "content": description}, {"property": "twitter:card", "content": "summary_large_image"}, - {"property": "twitter:url", "content": request_url}, + {"property": "twitter:url", "content": request.get_url()}, {"property": "twitter:title", "content": title}, {"property": "twitter:description", "content": description}, {"property": "twitter:image", "content": image_url or ""}, diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 0853972d1f..19d70b022e 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -6,12 +6,22 @@ from contextvars import copy_context import importlib.util import time -import uvicorn -from fastapi import FastAPI, Request, Response -from fastapi.responses import JSONResponse, PlainTextResponse -from fastapi.staticfiles import StaticFiles -from starlette.responses import Response as StarletteResponse -from starlette.datastructures import MutableHeaders + +try: + import uvicorn + from fastapi import FastAPI, Request, Response + from fastapi.responses import JSONResponse, PlainTextResponse + from fastapi.staticfiles import StaticFiles + from starlette.responses import Response as StarletteResponse + from starlette.datastructures import MutableHeaders +except ImportError: + uvicorn = None + FastAPI = Request = Response = None + JSONResponse = PlainTextResponse = None + StaticFiles = None + StarletteResponse = None + MutableHeaders = None + from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import PreventUpdate, InvalidResourceError @@ -285,6 +295,9 @@ def __init__(self): def set_request(self, request: Request): self._request = request + def get_root(self): + return str(self._request.base_url) + def get_args(self): return self._request.query_params @@ -305,6 +318,9 @@ def get_headers(self): def get_full_path(self): return str(self._request.url) + def get_url(self): + return str(self._request.url) + def get_remote_addr(self): return self._request.client.host if self._request.client else None diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index bb1204af19..8153ec4f92 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -207,6 +207,10 @@ class FlaskRequestAdapter: def get_args(): return flask.request.args + @staticmethod + def get_root(): + return flask.request.url_root + @staticmethod def get_json(): return flask.request.get_json() @@ -223,6 +227,10 @@ def get_cookies(): def get_headers(): return flask.request.headers + @staticmethod + def get_url(): + return flask.request.url + @staticmethod def get_full_path(): return flask.request.full_path diff --git a/tests/integration/multi_page/test_pages_relative_path.py b/tests/integration/multi_page/test_pages_relative_path.py index 6c505ac3f5..6fcbb6c6e0 100644 --- a/tests/integration/multi_page/test_pages_relative_path.py +++ b/tests/integration/multi_page/test_pages_relative_path.py @@ -2,6 +2,7 @@ import dash from dash import Dash, dcc, html +from dash.testing.wait import until def get_app(app): @@ -83,6 +84,6 @@ def test_pare003_absolute_path(dash_duo, clear_pages_state): for page in dash.page_registry.values(): dash_duo.find_element("#" + page["id"]).click() dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - assert dash_duo.driver.title == page["title"], "check that page title updates" + until(lambda: dash_duo.driver.title == page["title"],timeout=3) assert dash_duo.get_logs() == [], "browser console should contain no error" From 10681dccfef11be7425b8fe11c2b8a21c148f20f Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:53:27 -0400 Subject: [PATCH 17/74] fixing issues with relativate paths --- dash/server_factories/fastapi_factory.py | 5 ++--- dash/server_factories/flask_factory.py | 3 +-- tests/integration/multi_page/test_pages_relative_path.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 19d70b022e..914f591e17 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -219,9 +219,8 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): dash_app, package_name, fingerprinted_path, request ) - self.add_url_rule( - app, - "/_dash-component-suites/{package_name}/{fingerprinted_path:path}", + dash_app._add_url( + "/_dash-component-suites//", serve, ) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 8153ec4f92..684596ac23 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -121,8 +121,7 @@ def serve(package_name, fingerprinted_path): dash_app, package_name, fingerprinted_path ) - self.add_url_rule( - app, + dash_app._add_url( "/_dash-component-suites//", serve, ) diff --git a/tests/integration/multi_page/test_pages_relative_path.py b/tests/integration/multi_page/test_pages_relative_path.py index 6fcbb6c6e0..696ecc39a4 100644 --- a/tests/integration/multi_page/test_pages_relative_path.py +++ b/tests/integration/multi_page/test_pages_relative_path.py @@ -71,7 +71,7 @@ def test_pare002_relative_path_with_url_base_pathname( for page in dash.page_registry.values(): dash_br.find_element("#" + page["id"]).click() dash_br.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - assert dash_br.driver.title == page["title"], "check that page title updates" + until(lambda: dash_br.driver.title == page["title"], timeout=3) assert dash_br.get_logs() == [], "browser console should contain no error" From 4944d6d2b3060d43489f71c82bb7bdaf69242471 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Thu, 11 Sep 2025 21:24:55 +0200 Subject: [PATCH 18/74] =?UTF-8?q?=E2=88=99=20-=20initial=20quart=20factory?= =?UTF-8?q?=20=E2=88=99=20-=20added=20types=20to=20BaseFactory=20to=20remo?= =?UTF-8?q?ve=20linting=20errors=20on=20create=20app=20in=20Flask=20and=20?= =?UTF-8?q?Quart=20Factory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dash/server_factories/base_factory.py | 25 +-- dash/server_factories/quart_factory.py | 238 +++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 12 deletions(-) create mode 100644 dash/server_factories/quart_factory.py diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py index b44f6888cb..12088947c2 100644 --- a/dash/server_factories/base_factory.py +++ b/dash/server_factories/base_factory.py @@ -1,49 +1,50 @@ from abc import ABC, abstractmethod +from typing import Any class BaseServerFactory(ABC): - def __call__(self, server, *args, **kwargs): + def __call__(self, server, *args, **kwargs) -> Any: # Default: WSGI return server(*args, **kwargs) @abstractmethod - def create_app(self, name="__main__", config=None): + def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface pass @abstractmethod def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): + self, app, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface pass @abstractmethod - def register_error_handlers(self, app): + def register_error_handlers(self, app) -> None: # pragma: no cover - interface pass @abstractmethod - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface pass @abstractmethod - def before_request(self, app, func): + def before_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def after_request(self, app, func): + def after_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def run(self, app, host, port, debug, **kwargs): + def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface pass @abstractmethod - def make_response(self, data, mimetype=None, content_type=None): + def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface pass @abstractmethod - def jsonify(self, obj): + def jsonify(self, obj) -> Any: # pragma: no cover - interface pass @abstractmethod - def get_request_adapter(self): + def get_request_adapter(self) -> Any: # pragma: no cover - interface pass diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py new file mode 100644 index 0000000000..977c9aea4c --- /dev/null +++ b/dash/server_factories/quart_factory.py @@ -0,0 +1,238 @@ +from .base_factory import BaseServerFactory +from quart import Quart, request, Response as QuartResponse, jsonify, send_from_directory +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter +from dash.fingerprint import check_fingerprint +from dash import _validate +from contextvars import copy_context +import inspect +import os +import pkgutil +import mimetypes +import hashlib +import sys + + +class QuartAPIServerFactory(BaseServerFactory): + """Quart implementation of the Dash server factory. + + All Quart/async specific imports are at the top-level (per user request) so + Quart must be installed when this module is imported. + """ + + def __init__(self) -> None: + self.config = {} + super().__init__() + + def __call__(self, server, *args, **kwargs): + # ASGI style (scope, receive, send) or standard call-through handled by BaseServerFactory + return super().__call__(server, *args, **kwargs) + + def create_app(self, name="__main__", config=None): + app = Quart(name) + if config: + for key, value in config.items(): + # Mirror Flask usage of config dict + app.config[key] = value + return app + + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): + if os.path.isdir(assets_folder): + route = f"{assets_url_path}/" + + @app.route(route) + async def serve_asset(filename): # pragma: no cover - simple passthrough + return await send_from_directory(assets_folder, filename) + + def register_error_handlers(self, app): + @app.errorhandler(PreventUpdate) + async def _prevent_update(_): + return "", 204 + + @app.errorhandler(InvalidResourceError) + async def _invalid_resource(err): + return err.args[0], 404 + + def _html_response_wrapper(self, view_func): + async def wrapped(*args, **kwargs): + html_val = view_func() if callable(view_func) else view_func + if inspect.iscoroutine(html_val): # handle async function returning html + html_val = await html_val + html = str(html_val) + return QuartResponse(html, content_type="text/html") + + return wrapped + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + if rule == "": + rule = "/" + if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): + # Wrap plain strings or sync callables in async handler returning HTML + if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): + view_func = self._html_response_wrapper(view_func) + app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) + + # ---- Index & Catchall ------------------------------------------------ + def setup_index(self, app, dash_app): + async def index(): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + return QuartResponse(dash_app.render_index(), content_type="text/html") + + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + + def setup_catchall(self, app, dash_app): + @app.before_serving + async def _enable_dev_tools(): # pragma: no cover - environmental + dash_app.enable_dev_tools(**self.config) + + async def catchall(path): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + return QuartResponse(dash_app.render_index(), content_type="text/html") + + # Must be added after other routes + self.add_url_rule( + app, "/", catchall, endpoint="catchall", methods=["GET"] + ) + + # ---- Middleware-esque hooks ----------------------------------------- + def before_request(self, app, func): + app.before_request(func) + + def after_request(self, app, func): + # Quart after_request expects (response) -> response + @app.after_request + async def _after(response): + if func is not None: + result = func() + if inspect.iscoroutine(result): # Allow async hooks + await result + return response + + # ---- Running --------------------------------------------------------- + def run(self, app, host, port, debug, **kwargs): + self.config = dict({'debug': debug} if debug else {}, **kwargs) + app.run(host=host, port=port, debug=debug, **kwargs) + + # ---- Responses / JSON ------------------------------------------------ + def make_response(self, data, mimetype=None, content_type=None): + headers = {} + if mimetype: + headers["Content-Type"] = mimetype + if content_type: + headers["Content-Type"] = content_type + return QuartResponse(data, headers=headers) + + def jsonify(self, obj): + return jsonify(obj) + + def get_request_adapter(self): + return QuartRequestAdapter + + # ---- Component Suites ------------------------------------------------ + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, req + ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + getattr(package, "__version__", "unknown"), + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return QuartResponse(data, content_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if req.headers.get("If-None-Match") == etag: + return QuartResponse(None, status=304) + return QuartResponse(data, content_type=mimetype, headers=headers) + + def setup_component_suites(self, app, dash_app): + async def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) + + self.add_url_rule( + app, + "/_dash-component-suites//", + serve, + methods=["GET"], + ) + + # ---- Dispatch (Callbacks) ------------------------------------------- + def dispatch(self, app, dash_app, use_async=True): # Quart always async + async def _dispatch(): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + body = await request.get_json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): # if user callback is async + response_data = await response_data + return QuartResponse(response_data, content_type="application/json") + + return _dispatch + + # ---- Favicon --------------------------------------------------------- + def _serve_default_favicon(self): + return QuartResponse( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + + +class QuartRequestAdapter: + """Adapter that normalizes Quart's request API to what Dash expects.""" + + @staticmethod + def get_args(): + return request.args + + @staticmethod + async def get_json(): + return await request.get_json() + + @staticmethod + def is_json(): + return request.is_json + + @staticmethod + def get_cookies(): + return request.cookies + + @staticmethod + def get_headers(): + return request.headers + + @staticmethod + def get_full_path(): + return request.full_path + + @staticmethod + def get_remote_addr(): + return request.remote_addr + + @staticmethod + def get_origin(): + return request.headers.get("Origin") + + @staticmethod + def get_path(): + return request.path + From 3b0f47e37d465a01ce6acd35a779458509d53aa1 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 15:28:49 +0200 Subject: [PATCH 19/74] Quart factory ready --- dash/server_factories/quart_factory.py | 154 +++++++++++++++++-------- 1 file changed, 107 insertions(+), 47 deletions(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 977c9aea4c..75376fcd7a 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -1,5 +1,5 @@ from .base_factory import BaseServerFactory -from quart import Quart, request, Response as QuartResponse, jsonify, send_from_directory +from quart import Quart, request, Response, jsonify, send_from_directory from dash.exceptions import PreventUpdate, InvalidResourceError from dash.server_factories import set_request_adapter from dash.fingerprint import check_fingerprint @@ -11,6 +11,7 @@ import mimetypes import hashlib import sys +import time class QuartAPIServerFactory(BaseServerFactory): @@ -39,12 +40,50 @@ def create_app(self, name="__main__", config=None): def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - if os.path.isdir(assets_folder): - route = f"{assets_url_path}/" + # Mirror Flask implementation using a blueprint serving static files + from quart import Blueprint + + bp = Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + app.register_blueprint(bp) + + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.errorhandler(Exception) + async def _wrap_errors(_error_request, error): + tb = get_traceback_func(secret, error) + return tb, 500 + + def register_timing_hooks(self, app, _first_run): # parity with Flask factory + from quart import g + + @app.before_request + async def _before_request(): # pragma: no cover - timing infra + g.timing_information = {"__dash_server": {"dur": time.time(), "desc": None}} - @app.route(route) - async def serve_asset(filename): # pragma: no cover - simple passthrough - return await send_from_directory(assets_folder, filename) + @app.after_request + async def _after_request(response): # pragma: no cover - timing infra + timing_information = getattr(g, "timing_information", None) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + # Quart/Werkzeug headers expose 'add' (not 'append') + if hasattr(response.headers, "add"): + response.headers.add("Server-Timing", value) + else: # fallback just in case + response.headers["Server-Timing"] = value + return response def register_error_handlers(self, app): @app.errorhandler(PreventUpdate) @@ -61,49 +100,72 @@ async def wrapped(*args, **kwargs): if inspect.iscoroutine(html_val): # handle async function returning html html_val = await html_val html = str(html_val) - return QuartResponse(html, content_type="text/html") + return Response(html, content_type="text/html") return wrapped def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - if rule == "": - rule = "/" - if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): - # Wrap plain strings or sync callables in async handler returning HTML - if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): - view_func = self._html_response_wrapper(view_func) - app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) - - # ---- Index & Catchall ------------------------------------------------ + app.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + # def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + # if rule == "": + # rule = "/" + # if isinstance(view_func, str): + # # Literal HTML content + # view_func = self._html_response_wrapper(view_func) + # elif not inspect.iscoroutinefunction(view_func): + # # Sync function: wrap to make async but preserve Response objects + # original = view_func + + # async def _async_adapter(*args, **kwargs): + # result = original(*args, **kwargs) + # # Pass through existing Response (Quart/Flask style) + # if isinstance(result, Response) or ( + # hasattr(result, "status_code") + # and hasattr(result, "headers") + # and hasattr(result, "get_data") + # ): + # return result + # # If it's bytes or str treat as HTML + # if isinstance(result, (str, bytes)): + # return Response(result, content_type="text/html") + # # Fallback: JSON encode arbitrary python objects + # try: + # import json + + # return Response( + # json.dumps(result), content_type="application/json" + # ) + # except Exception: # pragma: no cover + # return Response(str(result), content_type="text/plain") + + # view_func = _async_adapter + # app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) + def setup_index(self, app, dash_app): async def index(): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return QuartResponse(dash_app.render_index(), content_type="text/html") + return Response(dash_app.render_index(), content_type="text/html") self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) def setup_catchall(self, app, dash_app): - @app.before_serving - async def _enable_dev_tools(): # pragma: no cover - environmental - dash_app.enable_dev_tools(**self.config) - async def catchall(path): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return QuartResponse(dash_app.render_index(), content_type="text/html") + return Response(dash_app.render_index(), content_type="text/html") - # Must be added after other routes self.add_url_rule( app, "/", catchall, endpoint="catchall", methods=["GET"] ) - # ---- Middleware-esque hooks ----------------------------------------- def before_request(self, app, func): app.before_request(func) def after_request(self, app, func): - # Quart after_request expects (response) -> response @app.after_request async def _after(response): if func is not None: @@ -112,19 +174,24 @@ async def _after(response): await result return response - # ---- Running --------------------------------------------------------- def run(self, app, host, port, debug, **kwargs): - self.config = dict({'debug': debug} if debug else {}, **kwargs) - app.run(host=host, port=port, debug=debug, **kwargs) + # Store only dev tools related configuration (exclude server-only kwargs unsupported by Quart) + # Quart's run does NOT accept 'threaded' (Flask-specific). Drop silently (or log) if present. + unsupported = {"threaded", "processes"} + filtered_kwargs = {} + for k, v in kwargs.items(): + if k in unsupported: + continue + filtered_kwargs[k] = v + + # Keep a slim config for potential future use (dev tools already enabled in Dash.run) + self.config = {'debug': debug} + self.config.update({k: v for k, v in filtered_kwargs.items() if k.startswith('dev_tools_')}) + + app.run(host=host, port=port, debug=debug, **filtered_kwargs) - # ---- Responses / JSON ------------------------------------------------ def make_response(self, data, mimetype=None, content_type=None): - headers = {} - if mimetype: - headers["Content-Type"] = mimetype - if content_type: - headers["Content-Type"] = content_type - return QuartResponse(data, headers=headers) + return Response(data, mimetype=mimetype, content_type=content_type) def jsonify(self, obj): return jsonify(obj) @@ -132,7 +199,6 @@ def jsonify(self, obj): def get_request_adapter(self): return QuartRequestAdapter - # ---- Component Suites ------------------------------------------------ def serve_component_suites( self, dash_app, package_name, fingerprinted_path, req ): @@ -152,12 +218,9 @@ def serve_component_suites( headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" - return QuartResponse(data, content_type=mimetype, headers=headers) - etag = hashlib.md5(data).hexdigest() if data else "" - headers["ETag"] = etag - if req.headers.get("If-None-Match") == etag: - return QuartResponse(None, status=304) - return QuartResponse(data, content_type=mimetype, headers=headers) + return Response(data, content_type=mimetype, headers=headers) + + return Response(data, content_type=mimetype, headers=headers) def setup_component_suites(self, app, dash_app): async def serve(package_name, fingerprinted_path): @@ -172,7 +235,6 @@ async def serve(package_name, fingerprinted_path): methods=["GET"], ) - # ---- Dispatch (Callbacks) ------------------------------------------- def dispatch(self, app, dash_app, use_async=True): # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() @@ -186,13 +248,12 @@ async def _dispatch(): response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data - return QuartResponse(response_data, content_type="application/json") + return Response(response_data, content_type="application/json") return _dispatch - # ---- Favicon --------------------------------------------------------- def _serve_default_favicon(self): - return QuartResponse( + return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) @@ -235,4 +296,3 @@ def get_origin(): @staticmethod def get_path(): return request.path - From 1112f7743c7a91791f0720ee505959a80fd0c0cd Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:05:26 -0400 Subject: [PATCH 20/74] fixing for lint --- dash/_pages.py | 4 +++- dash/server_factories/fastapi_factory.py | 6 +++--- dash/server_factories/flask_factory.py | 6 +++--- tests/integration/multi_page/test_pages_relative_path.py | 2 +- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 3fab86eb99..6c00e656c7 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -396,7 +396,9 @@ def _page_meta_tags(app, request): image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = "".join([request.get_root(), image.lstrip("/")]) if image else None + assets_image_url = ( + "".join([request.get_root(), image.lstrip("/")]) if image else None + ) supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 914f591e17..eb4a9392f5 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -219,14 +219,14 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): dash_app, package_name, fingerprinted_path, request ) + # pylint: disable=protected-access dash_app._add_url( "/_dash-component-suites//", serve, ) - def dispatch( - self, app, dash_app, use_async=False - ): # pylint: disable=unused-argument + # pylint: disable=unused-argument + def dispatch(self, app, dash_app, use_async=False): async def _dispatch(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 684596ac23..c2221469fc 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -121,14 +121,14 @@ def serve(package_name, fingerprinted_path): dash_app, package_name, fingerprinted_path ) + # pylint: disable=protected-access dash_app._add_url( "/_dash-component-suites//", serve, ) - def dispatch( - self, app, dash_app, use_async=False - ): # pylint: disable=unused-argument + # pylint: disable=unused-argument + def dispatch(self, app, dash_app, use_async=False): def _dispatch(): adapter = FlaskRequestAdapter() set_request_adapter(adapter) diff --git a/tests/integration/multi_page/test_pages_relative_path.py b/tests/integration/multi_page/test_pages_relative_path.py index 696ecc39a4..24e7209a70 100644 --- a/tests/integration/multi_page/test_pages_relative_path.py +++ b/tests/integration/multi_page/test_pages_relative_path.py @@ -84,6 +84,6 @@ def test_pare003_absolute_path(dash_duo, clear_pages_state): for page in dash.page_registry.values(): dash_duo.find_element("#" + page["id"]).click() dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - until(lambda: dash_duo.driver.title == page["title"],timeout=3) + until(lambda: dash_duo.driver.title == page["title"], timeout=3) assert dash_duo.get_logs() == [], "browser console should contain no error" From 8c52bbb9033588df0c764d6e3fd61e6c2defebd5 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:28:06 -0400 Subject: [PATCH 21/74] fixing issue with apps overwriting other paths --- dash/dash.py | 6 +++--- dash/server_factories/fastapi_factory.py | 16 ++++++++-------- dash/server_factories/flask_factory.py | 17 +++++++++-------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index d20672453c..4a06e9216e 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -742,7 +742,7 @@ def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> Non self.routes.append(full_name) def _setup_routes(self): - self.server_factory.setup_component_suites(self.server, self) + self.server_factory.setup_component_suites(self) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) self._add_url( @@ -755,8 +755,8 @@ def _setup_routes(self): "_favicon.ico", self.server_factory._serve_default_favicon, # pylint: disable=protected-access ) - self.server_factory.setup_index(self.server, self) - self.server_factory.setup_catchall(self.server, self) + self.server_factory.setup_index(self) + self.server_factory.setup_catchall(self) if jupyter_dash.active: self._add_url( diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index eb4a9392f5..de1caf451c 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -83,16 +83,17 @@ async def wrapped(*_args, **_kwargs): return wrapped - def setup_index(self, app, dash_app): + def setup_index(self, dash_app): async def index(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") - self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) - def setup_catchall(self, app, dash_app): + def setup_catchall(self, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( @@ -105,9 +106,8 @@ async def catchall(request: Request): adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") - self.add_url_rule( - app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"] - ) + # pylint: disable=protected-access + dash_app._add_url("{path:path}", catchall, methods=["GET"]) def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if rule == "": @@ -213,7 +213,7 @@ def serve_component_suites( return StarletteResponse(status_code=304) return StarletteResponse(content=data, media_type=mimetype, headers=headers) - def setup_component_suites(self, app, dash_app): + def setup_component_suites(self, dash_app): async def serve(request: Request, package_name: str, fingerprinted_path: str): return self.serve_component_suites( dash_app, package_name, fingerprinted_path, request @@ -221,7 +221,7 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): # pylint: disable=protected-access dash_app._add_url( - "/_dash-component-suites//", + "_dash-component-suites//", serve, ) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index c2221469fc..1ea561b076 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -50,6 +50,7 @@ def _wrap_errors(error): return tb, 500 def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + print(rule, endpoint, methods) app.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) @@ -72,23 +73,23 @@ def jsonify(self, obj): def get_request_adapter(self): return FlaskRequestAdapter - def setup_catchall(self, app, dash_app): + def setup_catchall(self, dash_app): def catchall(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) - self.add_url_rule( - app, "/", catchall, endpoint="catchall", methods=["GET"] - ) + # pylint: disable=protected-access + dash_app._add_url("", catchall, methods=["GET"]) - def setup_index(self, app, dash_app): + def setup_index(self, dash_app): def index(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) - self.add_url_rule(app, "/", index, endpoint="/", methods=["GET"]) + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) def serve_component_suites(self, dash_app, package_name, fingerprinted_path): path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) @@ -115,7 +116,7 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path): response = flask.Response(None, status=304) return response - def setup_component_suites(self, app, dash_app): + def setup_component_suites(self, dash_app): def serve(package_name, fingerprinted_path): return self.serve_component_suites( dash_app, package_name, fingerprinted_path @@ -123,7 +124,7 @@ def serve(package_name, fingerprinted_path): # pylint: disable=protected-access dash_app._add_url( - "/_dash-component-suites//", + "_dash-component-suites//", serve, ) From aabeeb7801f47a4c2f1e54bb3278c20ec308f417 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:35:10 -0400 Subject: [PATCH 22/74] removing print --- dash/server_factories/flask_factory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 1ea561b076..6eebd735ac 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -50,7 +50,6 @@ def _wrap_errors(error): return tb, 500 def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - print(rule, endpoint, methods) app.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) From 5659cd73e98bda4ebb907480a56cf24207fb5f3a Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 16:37:39 +0200 Subject: [PATCH 23/74] cleanup --- dash/server_factories/quart_factory.py | 58 ++------------------------ 1 file changed, 3 insertions(+), 55 deletions(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 75376fcd7a..685b8d70e4 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -1,15 +1,13 @@ from .base_factory import BaseServerFactory -from quart import Quart, request, Response, jsonify, send_from_directory +from quart import Quart, request, Response, jsonify from dash.exceptions import PreventUpdate, InvalidResourceError from dash.server_factories import set_request_adapter from dash.fingerprint import check_fingerprint from dash import _validate from contextvars import copy_context import inspect -import os import pkgutil import mimetypes -import hashlib import sys import time @@ -26,21 +24,18 @@ def __init__(self) -> None: super().__init__() def __call__(self, server, *args, **kwargs): - # ASGI style (scope, receive, send) or standard call-through handled by BaseServerFactory return super().__call__(server, *args, **kwargs) def create_app(self, name="__main__", config=None): app = Quart(name) if config: for key, value in config.items(): - # Mirror Flask usage of config dict app.config[key] = value return app def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - # Mirror Flask implementation using a blueprint serving static files from quart import Blueprint bp = Blueprint( @@ -109,41 +104,6 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) - # def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - # if rule == "": - # rule = "/" - # if isinstance(view_func, str): - # # Literal HTML content - # view_func = self._html_response_wrapper(view_func) - # elif not inspect.iscoroutinefunction(view_func): - # # Sync function: wrap to make async but preserve Response objects - # original = view_func - - # async def _async_adapter(*args, **kwargs): - # result = original(*args, **kwargs) - # # Pass through existing Response (Quart/Flask style) - # if isinstance(result, Response) or ( - # hasattr(result, "status_code") - # and hasattr(result, "headers") - # and hasattr(result, "get_data") - # ): - # return result - # # If it's bytes or str treat as HTML - # if isinstance(result, (str, bytes)): - # return Response(result, content_type="text/html") - # # Fallback: JSON encode arbitrary python objects - # try: - # import json - - # return Response( - # json.dumps(result), content_type="application/json" - # ) - # except Exception: # pragma: no cover - # return Response(str(result), content_type="text/plain") - - # view_func = _async_adapter - # app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) - def setup_index(self, app, dash_app): async def index(): adapter = QuartRequestAdapter() @@ -175,20 +135,8 @@ async def _after(response): return response def run(self, app, host, port, debug, **kwargs): - # Store only dev tools related configuration (exclude server-only kwargs unsupported by Quart) - # Quart's run does NOT accept 'threaded' (Flask-specific). Drop silently (or log) if present. - unsupported = {"threaded", "processes"} - filtered_kwargs = {} - for k, v in kwargs.items(): - if k in unsupported: - continue - filtered_kwargs[k] = v - - # Keep a slim config for potential future use (dev tools already enabled in Dash.run) - self.config = {'debug': debug} - self.config.update({k: v for k, v in filtered_kwargs.items() if k.startswith('dev_tools_')}) - - app.run(host=host, port=port, debug=debug, **filtered_kwargs) + self.config = dict({'debug': debug} if debug else {}, **kwargs) + app.run(host=host, port=port, debug=debug, **kwargs) def make_response(self, data, mimetype=None, content_type=None): return Response(data, mimetype=mimetype, content_type=content_type) From b05e37654f0f36eaa709514effe5900c6d94cf5a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 11:07:10 -0400 Subject: [PATCH 24/74] reverting `render_index` -> `index` and making catch for outside of a `request` context --- dash/dash.py | 10 +++++++--- dash/server_factories/fastapi_factory.py | 4 ++-- dash/server_factories/flask_factory.py | 4 ++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 4a06e9216e..973f1ec579 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1145,16 +1145,20 @@ def _generate_meta(self): return meta_tags + self.config.meta_tags - def render_index(self, *_args, **_kwargs): + def index(self, *_args, **_kwargs): scripts = self._generate_scripts_html() css = self._generate_css_dist_html() config = self._generate_config_html() metas = self._generate_meta() renderer = self._generate_renderer() title = self.title - request = get_request_adapter() + try: + request = get_request_adapter() + except LookupError: + # no request context + request = None - if self.use_pages and self.config.include_pages_meta: + if self.use_pages and self.config.include_pages_meta and request: metas = _page_meta_tags(self, request) + metas if self._favicon: diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index de1caf451c..cf08f85d7f 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -88,7 +88,7 @@ async def index(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) - return Response(content=dash_app.render_index(), media_type="text/html") + return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) @@ -104,7 +104,7 @@ async def catchall(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) - return Response(content=dash_app.render_index(), media_type="text/html") + return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access dash_app._add_url("{path:path}", catchall, methods=["GET"]) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 6eebd735ac..9ba8a5017c 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -76,7 +76,7 @@ def setup_catchall(self, dash_app): def catchall(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - return dash_app.render_index(*args, **kwargs) + return dash_app.index(*args, **kwargs) # pylint: disable=protected-access dash_app._add_url("", catchall, methods=["GET"]) @@ -85,7 +85,7 @@ def setup_index(self, dash_app): def index(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - return dash_app.render_index(*args, **kwargs) + return dash_app.index(*args, **kwargs) # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) From ed0dc3b4fbdf78ea68ca1d21e510aca1bf4a3320 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Thu, 11 Sep 2025 21:24:55 +0200 Subject: [PATCH 25/74] =?UTF-8?q?=E2=88=99=20-=20initial=20quart=20factory?= =?UTF-8?q?=20=E2=88=99=20-=20added=20types=20to=20BaseFactory=20to=20remo?= =?UTF-8?q?ve=20linting=20errors=20on=20create=20app=20in=20Flask=20and=20?= =?UTF-8?q?Quart=20Factory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dash/server_factories/base_factory.py | 25 +-- dash/server_factories/quart_factory.py | 238 +++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 12 deletions(-) create mode 100644 dash/server_factories/quart_factory.py diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py index b44f6888cb..12088947c2 100644 --- a/dash/server_factories/base_factory.py +++ b/dash/server_factories/base_factory.py @@ -1,49 +1,50 @@ from abc import ABC, abstractmethod +from typing import Any class BaseServerFactory(ABC): - def __call__(self, server, *args, **kwargs): + def __call__(self, server, *args, **kwargs) -> Any: # Default: WSGI return server(*args, **kwargs) @abstractmethod - def create_app(self, name="__main__", config=None): + def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface pass @abstractmethod def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): + self, app, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface pass @abstractmethod - def register_error_handlers(self, app): + def register_error_handlers(self, app) -> None: # pragma: no cover - interface pass @abstractmethod - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface pass @abstractmethod - def before_request(self, app, func): + def before_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def after_request(self, app, func): + def after_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def run(self, app, host, port, debug, **kwargs): + def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface pass @abstractmethod - def make_response(self, data, mimetype=None, content_type=None): + def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface pass @abstractmethod - def jsonify(self, obj): + def jsonify(self, obj) -> Any: # pragma: no cover - interface pass @abstractmethod - def get_request_adapter(self): + def get_request_adapter(self) -> Any: # pragma: no cover - interface pass diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py new file mode 100644 index 0000000000..977c9aea4c --- /dev/null +++ b/dash/server_factories/quart_factory.py @@ -0,0 +1,238 @@ +from .base_factory import BaseServerFactory +from quart import Quart, request, Response as QuartResponse, jsonify, send_from_directory +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter +from dash.fingerprint import check_fingerprint +from dash import _validate +from contextvars import copy_context +import inspect +import os +import pkgutil +import mimetypes +import hashlib +import sys + + +class QuartAPIServerFactory(BaseServerFactory): + """Quart implementation of the Dash server factory. + + All Quart/async specific imports are at the top-level (per user request) so + Quart must be installed when this module is imported. + """ + + def __init__(self) -> None: + self.config = {} + super().__init__() + + def __call__(self, server, *args, **kwargs): + # ASGI style (scope, receive, send) or standard call-through handled by BaseServerFactory + return super().__call__(server, *args, **kwargs) + + def create_app(self, name="__main__", config=None): + app = Quart(name) + if config: + for key, value in config.items(): + # Mirror Flask usage of config dict + app.config[key] = value + return app + + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): + if os.path.isdir(assets_folder): + route = f"{assets_url_path}/" + + @app.route(route) + async def serve_asset(filename): # pragma: no cover - simple passthrough + return await send_from_directory(assets_folder, filename) + + def register_error_handlers(self, app): + @app.errorhandler(PreventUpdate) + async def _prevent_update(_): + return "", 204 + + @app.errorhandler(InvalidResourceError) + async def _invalid_resource(err): + return err.args[0], 404 + + def _html_response_wrapper(self, view_func): + async def wrapped(*args, **kwargs): + html_val = view_func() if callable(view_func) else view_func + if inspect.iscoroutine(html_val): # handle async function returning html + html_val = await html_val + html = str(html_val) + return QuartResponse(html, content_type="text/html") + + return wrapped + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + if rule == "": + rule = "/" + if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): + # Wrap plain strings or sync callables in async handler returning HTML + if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): + view_func = self._html_response_wrapper(view_func) + app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) + + # ---- Index & Catchall ------------------------------------------------ + def setup_index(self, app, dash_app): + async def index(): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + return QuartResponse(dash_app.render_index(), content_type="text/html") + + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + + def setup_catchall(self, app, dash_app): + @app.before_serving + async def _enable_dev_tools(): # pragma: no cover - environmental + dash_app.enable_dev_tools(**self.config) + + async def catchall(path): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + return QuartResponse(dash_app.render_index(), content_type="text/html") + + # Must be added after other routes + self.add_url_rule( + app, "/", catchall, endpoint="catchall", methods=["GET"] + ) + + # ---- Middleware-esque hooks ----------------------------------------- + def before_request(self, app, func): + app.before_request(func) + + def after_request(self, app, func): + # Quart after_request expects (response) -> response + @app.after_request + async def _after(response): + if func is not None: + result = func() + if inspect.iscoroutine(result): # Allow async hooks + await result + return response + + # ---- Running --------------------------------------------------------- + def run(self, app, host, port, debug, **kwargs): + self.config = dict({'debug': debug} if debug else {}, **kwargs) + app.run(host=host, port=port, debug=debug, **kwargs) + + # ---- Responses / JSON ------------------------------------------------ + def make_response(self, data, mimetype=None, content_type=None): + headers = {} + if mimetype: + headers["Content-Type"] = mimetype + if content_type: + headers["Content-Type"] = content_type + return QuartResponse(data, headers=headers) + + def jsonify(self, obj): + return jsonify(obj) + + def get_request_adapter(self): + return QuartRequestAdapter + + # ---- Component Suites ------------------------------------------------ + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, req + ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + getattr(package, "__version__", "unknown"), + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return QuartResponse(data, content_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if req.headers.get("If-None-Match") == etag: + return QuartResponse(None, status=304) + return QuartResponse(data, content_type=mimetype, headers=headers) + + def setup_component_suites(self, app, dash_app): + async def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) + + self.add_url_rule( + app, + "/_dash-component-suites//", + serve, + methods=["GET"], + ) + + # ---- Dispatch (Callbacks) ------------------------------------------- + def dispatch(self, app, dash_app, use_async=True): # Quart always async + async def _dispatch(): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + body = await request.get_json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): # if user callback is async + response_data = await response_data + return QuartResponse(response_data, content_type="application/json") + + return _dispatch + + # ---- Favicon --------------------------------------------------------- + def _serve_default_favicon(self): + return QuartResponse( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + + +class QuartRequestAdapter: + """Adapter that normalizes Quart's request API to what Dash expects.""" + + @staticmethod + def get_args(): + return request.args + + @staticmethod + async def get_json(): + return await request.get_json() + + @staticmethod + def is_json(): + return request.is_json + + @staticmethod + def get_cookies(): + return request.cookies + + @staticmethod + def get_headers(): + return request.headers + + @staticmethod + def get_full_path(): + return request.full_path + + @staticmethod + def get_remote_addr(): + return request.remote_addr + + @staticmethod + def get_origin(): + return request.headers.get("Origin") + + @staticmethod + def get_path(): + return request.path + From 141527c8b8c25c7b1a47f7dc9eded53135e2ce94 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 15:28:49 +0200 Subject: [PATCH 26/74] Quart factory ready --- dash/server_factories/quart_factory.py | 154 +++++++++++++++++-------- 1 file changed, 107 insertions(+), 47 deletions(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 977c9aea4c..75376fcd7a 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -1,5 +1,5 @@ from .base_factory import BaseServerFactory -from quart import Quart, request, Response as QuartResponse, jsonify, send_from_directory +from quart import Quart, request, Response, jsonify, send_from_directory from dash.exceptions import PreventUpdate, InvalidResourceError from dash.server_factories import set_request_adapter from dash.fingerprint import check_fingerprint @@ -11,6 +11,7 @@ import mimetypes import hashlib import sys +import time class QuartAPIServerFactory(BaseServerFactory): @@ -39,12 +40,50 @@ def create_app(self, name="__main__", config=None): def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - if os.path.isdir(assets_folder): - route = f"{assets_url_path}/" + # Mirror Flask implementation using a blueprint serving static files + from quart import Blueprint + + bp = Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + app.register_blueprint(bp) + + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.errorhandler(Exception) + async def _wrap_errors(_error_request, error): + tb = get_traceback_func(secret, error) + return tb, 500 + + def register_timing_hooks(self, app, _first_run): # parity with Flask factory + from quart import g + + @app.before_request + async def _before_request(): # pragma: no cover - timing infra + g.timing_information = {"__dash_server": {"dur": time.time(), "desc": None}} - @app.route(route) - async def serve_asset(filename): # pragma: no cover - simple passthrough - return await send_from_directory(assets_folder, filename) + @app.after_request + async def _after_request(response): # pragma: no cover - timing infra + timing_information = getattr(g, "timing_information", None) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + # Quart/Werkzeug headers expose 'add' (not 'append') + if hasattr(response.headers, "add"): + response.headers.add("Server-Timing", value) + else: # fallback just in case + response.headers["Server-Timing"] = value + return response def register_error_handlers(self, app): @app.errorhandler(PreventUpdate) @@ -61,49 +100,72 @@ async def wrapped(*args, **kwargs): if inspect.iscoroutine(html_val): # handle async function returning html html_val = await html_val html = str(html_val) - return QuartResponse(html, content_type="text/html") + return Response(html, content_type="text/html") return wrapped def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - if rule == "": - rule = "/" - if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): - # Wrap plain strings or sync callables in async handler returning HTML - if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): - view_func = self._html_response_wrapper(view_func) - app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) - - # ---- Index & Catchall ------------------------------------------------ + app.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + # def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + # if rule == "": + # rule = "/" + # if isinstance(view_func, str): + # # Literal HTML content + # view_func = self._html_response_wrapper(view_func) + # elif not inspect.iscoroutinefunction(view_func): + # # Sync function: wrap to make async but preserve Response objects + # original = view_func + + # async def _async_adapter(*args, **kwargs): + # result = original(*args, **kwargs) + # # Pass through existing Response (Quart/Flask style) + # if isinstance(result, Response) or ( + # hasattr(result, "status_code") + # and hasattr(result, "headers") + # and hasattr(result, "get_data") + # ): + # return result + # # If it's bytes or str treat as HTML + # if isinstance(result, (str, bytes)): + # return Response(result, content_type="text/html") + # # Fallback: JSON encode arbitrary python objects + # try: + # import json + + # return Response( + # json.dumps(result), content_type="application/json" + # ) + # except Exception: # pragma: no cover + # return Response(str(result), content_type="text/plain") + + # view_func = _async_adapter + # app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) + def setup_index(self, app, dash_app): async def index(): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return QuartResponse(dash_app.render_index(), content_type="text/html") + return Response(dash_app.render_index(), content_type="text/html") self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) def setup_catchall(self, app, dash_app): - @app.before_serving - async def _enable_dev_tools(): # pragma: no cover - environmental - dash_app.enable_dev_tools(**self.config) - async def catchall(path): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return QuartResponse(dash_app.render_index(), content_type="text/html") + return Response(dash_app.render_index(), content_type="text/html") - # Must be added after other routes self.add_url_rule( app, "/", catchall, endpoint="catchall", methods=["GET"] ) - # ---- Middleware-esque hooks ----------------------------------------- def before_request(self, app, func): app.before_request(func) def after_request(self, app, func): - # Quart after_request expects (response) -> response @app.after_request async def _after(response): if func is not None: @@ -112,19 +174,24 @@ async def _after(response): await result return response - # ---- Running --------------------------------------------------------- def run(self, app, host, port, debug, **kwargs): - self.config = dict({'debug': debug} if debug else {}, **kwargs) - app.run(host=host, port=port, debug=debug, **kwargs) + # Store only dev tools related configuration (exclude server-only kwargs unsupported by Quart) + # Quart's run does NOT accept 'threaded' (Flask-specific). Drop silently (or log) if present. + unsupported = {"threaded", "processes"} + filtered_kwargs = {} + for k, v in kwargs.items(): + if k in unsupported: + continue + filtered_kwargs[k] = v + + # Keep a slim config for potential future use (dev tools already enabled in Dash.run) + self.config = {'debug': debug} + self.config.update({k: v for k, v in filtered_kwargs.items() if k.startswith('dev_tools_')}) + + app.run(host=host, port=port, debug=debug, **filtered_kwargs) - # ---- Responses / JSON ------------------------------------------------ def make_response(self, data, mimetype=None, content_type=None): - headers = {} - if mimetype: - headers["Content-Type"] = mimetype - if content_type: - headers["Content-Type"] = content_type - return QuartResponse(data, headers=headers) + return Response(data, mimetype=mimetype, content_type=content_type) def jsonify(self, obj): return jsonify(obj) @@ -132,7 +199,6 @@ def jsonify(self, obj): def get_request_adapter(self): return QuartRequestAdapter - # ---- Component Suites ------------------------------------------------ def serve_component_suites( self, dash_app, package_name, fingerprinted_path, req ): @@ -152,12 +218,9 @@ def serve_component_suites( headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" - return QuartResponse(data, content_type=mimetype, headers=headers) - etag = hashlib.md5(data).hexdigest() if data else "" - headers["ETag"] = etag - if req.headers.get("If-None-Match") == etag: - return QuartResponse(None, status=304) - return QuartResponse(data, content_type=mimetype, headers=headers) + return Response(data, content_type=mimetype, headers=headers) + + return Response(data, content_type=mimetype, headers=headers) def setup_component_suites(self, app, dash_app): async def serve(package_name, fingerprinted_path): @@ -172,7 +235,6 @@ async def serve(package_name, fingerprinted_path): methods=["GET"], ) - # ---- Dispatch (Callbacks) ------------------------------------------- def dispatch(self, app, dash_app, use_async=True): # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() @@ -186,13 +248,12 @@ async def _dispatch(): response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data - return QuartResponse(response_data, content_type="application/json") + return Response(response_data, content_type="application/json") return _dispatch - # ---- Favicon --------------------------------------------------------- def _serve_default_favicon(self): - return QuartResponse( + return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) @@ -235,4 +296,3 @@ def get_origin(): @staticmethod def get_path(): return request.path - From 3e38d4151414bca27449d03acf323b67f958e282 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 11:15:49 -0400 Subject: [PATCH 27/74] fixing `prune_errors` test --- tests/integration/devtools/test_devtools_error_handling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/devtools/test_devtools_error_handling.py b/tests/integration/devtools/test_devtools_error_handling.py index 40d5731202..b481ef2fad 100644 --- a/tests/integration/devtools/test_devtools_error_handling.py +++ b/tests/integration/devtools/test_devtools_error_handling.py @@ -109,14 +109,14 @@ def test_dveh006_long_python_errors(dash_duo): assert "in bad_sub" not in error0 # dash and flask part of the traceback ARE included # since we set dev_tools_prune_errors=False - assert "dash.py" in error0 + assert "factory.py" in error0 assert "self.wsgi_app" in error0 error1 = get_error_html(dash_duo, 1) assert "in update_output" in error1 assert "in bad_sub" in error1 assert "ZeroDivisionError" in error1 - assert "dash.py" in error1 + assert "factory.py" in error1 assert "self.wsgi_app" in error1 From 381fb0c135b0f3b7a7106afe307d7e8a5866c65a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 15:40:34 -0400 Subject: [PATCH 28/74] adjustments for flask api_endpoint declared in callback defs --- dash/dash.py | 26 ++------------------------ dash/server_factories/flask_factory.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 973f1ec579..bed7ab43a4 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -793,30 +793,8 @@ def setup_apis(self): ) self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k) - def make_parse_body(func): - def _parse_body(): - if flask.request.is_json: - data = flask.request.get_json() - return flask.jsonify(func(**data)) - return flask.jsonify({}) - - return _parse_body - - def make_parse_body_async(func): - async def _parse_body_async(): - if flask.request.is_json: - data = flask.request.get_json() - result = await func(**data) - return flask.jsonify(result) - return flask.jsonify({}) - - return _parse_body_async - - for path, func in self.callback_api_paths.items(): - if asyncio.iscoroutinefunction(func): - self._add_url(path, make_parse_body_async(func), ["POST"]) - else: - self._add_url(path, make_parse_body(func), ["POST"]) + # Delegate to the server factory for route registration + self.server_factory.register_callback_api_routes(self.server, self.callback_api_paths) def _setup_plotlyjs(self): # pylint: disable=import-outside-toplevel diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 9ba8a5017c..a488a070e1 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -5,6 +5,7 @@ import mimetypes import time import flask +import inspect from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import PreventUpdate, InvalidResourceError @@ -200,6 +201,31 @@ def _after_request(response): self.before_request(app, _before_request) self.after_request(app, _after_request) + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the Flask app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + if inspect.iscoroutinefunction(handler): + async def view_func(*args, handler=handler, **kwargs): + data = flask.request.get_json() + result = await handler(**data) if data else await handler() + return flask.jsonify(result) + else: + def view_func(*args, handler=handler, **kwargs): + data = flask.request.get_json() + result = handler(**data) if data else handler() + return flask.jsonify(result) + + # Flask 2.x+ supports async views natively + app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + class FlaskRequestAdapter: @staticmethod From a27927a2dfde473d5eb2215fbb035a4b6597c44b Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 21:49:49 +0200 Subject: [PATCH 29/74] updated QuartRequestAdapter & QuartFactory to latest changes --- dash/server_factories/quart_factory.py | 147 ++++++++----------------- 1 file changed, 47 insertions(+), 100 deletions(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 75376fcd7a..99c9c2e5a0 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -1,15 +1,13 @@ from .base_factory import BaseServerFactory -from quart import Quart, request, Response, jsonify, send_from_directory +from quart import Quart, Request, Response, jsonify, request from dash.exceptions import PreventUpdate, InvalidResourceError from dash.server_factories import set_request_adapter from dash.fingerprint import check_fingerprint from dash import _validate from contextvars import copy_context import inspect -import os import pkgutil import mimetypes -import hashlib import sys import time @@ -26,21 +24,18 @@ def __init__(self) -> None: super().__init__() def __call__(self, server, *args, **kwargs): - # ASGI style (scope, receive, send) or standard call-through handled by BaseServerFactory return super().__call__(server, *args, **kwargs) def create_app(self, name="__main__", config=None): app = Quart(name) if config: for key, value in config.items(): - # Mirror Flask usage of config dict app.config[key] = value return app def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - # Mirror Flask implementation using a blueprint serving static files from quart import Blueprint bp = Blueprint( @@ -109,58 +104,23 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) - # def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - # if rule == "": - # rule = "/" - # if isinstance(view_func, str): - # # Literal HTML content - # view_func = self._html_response_wrapper(view_func) - # elif not inspect.iscoroutinefunction(view_func): - # # Sync function: wrap to make async but preserve Response objects - # original = view_func - - # async def _async_adapter(*args, **kwargs): - # result = original(*args, **kwargs) - # # Pass through existing Response (Quart/Flask style) - # if isinstance(result, Response) or ( - # hasattr(result, "status_code") - # and hasattr(result, "headers") - # and hasattr(result, "get_data") - # ): - # return result - # # If it's bytes or str treat as HTML - # if isinstance(result, (str, bytes)): - # return Response(result, content_type="text/html") - # # Fallback: JSON encode arbitrary python objects - # try: - # import json - - # return Response( - # json.dumps(result), content_type="application/json" - # ) - # except Exception: # pragma: no cover - # return Response(str(result), content_type="text/plain") - - # view_func = _async_adapter - # app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) - - def setup_index(self, app, dash_app): + def setup_index(self, dash_app): async def index(): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return Response(dash_app.render_index(), content_type="text/html") + adapter.set_request(request) + return Response(dash_app.index(), content_type="text/html") - self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + dash_app._add_url("", index, methods=["GET"]) - def setup_catchall(self, app, dash_app): - async def catchall(path): + def setup_catchall(self, dash_app): + async def catchall(path): # noqa: ARG001 - path is unused but kept for route signature adapter = QuartRequestAdapter() set_request_adapter(adapter) - return Response(dash_app.render_index(), content_type="text/html") + adapter.set_request(request) + return Response(dash_app.index(), content_type="text/html") - self.add_url_rule( - app, "/", catchall, endpoint="catchall", methods=["GET"] - ) + dash_app._add_url("", catchall, methods=["GET"]) def before_request(self, app, func): app.before_request(func) @@ -175,20 +135,8 @@ async def _after(response): return response def run(self, app, host, port, debug, **kwargs): - # Store only dev tools related configuration (exclude server-only kwargs unsupported by Quart) - # Quart's run does NOT accept 'threaded' (Flask-specific). Drop silently (or log) if present. - unsupported = {"threaded", "processes"} - filtered_kwargs = {} - for k, v in kwargs.items(): - if k in unsupported: - continue - filtered_kwargs[k] = v - - # Keep a slim config for potential future use (dev tools already enabled in Dash.run) - self.config = {'debug': debug} - self.config.update({k: v for k, v in filtered_kwargs.items() if k.startswith('dev_tools_')}) - - app.run(host=host, port=port, debug=debug, **filtered_kwargs) + self.config = {'debug': debug, **kwargs} if debug else kwargs + app.run(host=host, port=port, debug=debug, **kwargs) def make_response(self, data, mimetype=None, content_type=None): return Response(data, mimetype=mimetype, content_type=content_type) @@ -199,9 +147,7 @@ def jsonify(self, obj): def get_request_adapter(self): return QuartRequestAdapter - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path, req - ): + def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req): # noqa: ARG002 unused req preserved for interface parity path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -222,23 +168,22 @@ def serve_component_suites( return Response(data, content_type=mimetype, headers=headers) - def setup_component_suites(self, app, dash_app): + def setup_component_suites(self, dash_app): async def serve(package_name, fingerprinted_path): return self.serve_component_suites( dash_app, package_name, fingerprinted_path, request ) - self.add_url_rule( - app, - "/_dash-component-suites//", + dash_app._add_url( + "_dash-component-suites//", serve, - methods=["GET"], ) def dispatch(self, app, dash_app, use_async=True): # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() set_request_adapter(adapter) + adapter.set_request(request) body = await request.get_json() g = dash_app._initialize_context(body, adapter) func = dash_app._prepare_callback(g, body) @@ -259,40 +204,42 @@ def _serve_default_favicon(self): class QuartRequestAdapter: - """Adapter that normalizes Quart's request API to what Dash expects.""" + def __init__(self) -> None: + self._request = None + + def set_request(self, request: Request) -> None: + self._request = request + + # Accessors (instance-based) + def get_root(self): + return self._request.root_url + + def get_args(self): + return self._request.args - @staticmethod - def get_args(): - return request.args + async def get_json(self): + return await self._request.get_json() - @staticmethod - async def get_json(): - return await request.get_json() + def is_json(self): + return self._request.is_json - @staticmethod - def is_json(): - return request.is_json + def get_cookies(self): + return self._request.cookies - @staticmethod - def get_cookies(): - return request.cookies + def get_headers(self): + return self._request.headers - @staticmethod - def get_headers(): - return request.headers + def get_full_path(self): + return self._request.full_path - @staticmethod - def get_full_path(): - return request.full_path + def get_url(self): + return str(self._request.url) - @staticmethod - def get_remote_addr(): - return request.remote_addr + def get_remote_addr(self): + return self._request.remote_addr - @staticmethod - def get_origin(): - return request.headers.get("Origin") + def get_origin(self): + return self._request.headers.get("origin") - @staticmethod - def get_path(): - return request.path + def get_path(self): + return self._request.path From 1824e110327740252dd944988763e0972371f7e9 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 22:06:43 +0200 Subject: [PATCH 30/74] Removed redundant Response return --- dash/server_factories/quart_factory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 99c9c2e5a0..53fa1cc469 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -164,7 +164,6 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" - return Response(data, content_type=mimetype, headers=headers) return Response(data, content_type=mimetype, headers=headers) From b14f6d276f039239725c7554decabd250c0f8975 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:21:18 -0400 Subject: [PATCH 31/74] fix for fastapi `api_endpoint` registering --- dash/server_factories/fastapi_factory.py | 48 ++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index cf08f85d7f..1090c85050 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -14,13 +14,21 @@ from fastapi.staticfiles import StaticFiles from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders + from pydantic import create_model + from typing import Any, Optional except ImportError: uvicorn = None - FastAPI = Request = Response = None - JSONResponse = PlainTextResponse = None + FastAPI = None + Request = None + Response = None + JSONResponse = None + PlainTextResponse = None StaticFiles = None StarletteResponse = None MutableHeaders = None + create_model = None + Any = None + Optional = None from dash.fingerprint import check_fingerprint from dash import _validate @@ -109,7 +117,7 @@ async def catchall(request: Request): # pylint: disable=protected-access dash_app._add_url("{path:path}", catchall, methods=["GET"]) - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False): if rule == "": rule = "/" if isinstance(view_func, str): @@ -120,7 +128,7 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): view_func, methods=methods or ["GET"], name=endpoint, - include_in_schema=False, + include_in_schema=include_in_schema, ) def before_request(self, app, func): @@ -286,6 +294,38 @@ async def timing_middleware(request, call_next): headers.append("Server-Timing", value) return response + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the FastAPI app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + sig = inspect.signature(handler) + param_names = list(sig.parameters.keys()) + fields = {name: (Optional[Any], None) for name in param_names} + Model = create_model(f"Payload_{endpoint}", **fields) + + async def view_func(request: Request, body: Model): + kwargs = body.dict(exclude_unset=True) + if inspect.iscoroutinefunction(handler): + result = await handler(**kwargs) + else: + result = handler(**kwargs) + return JSONResponse(content=result) + + + app.add_api_route( + route, + view_func, + methods=methods, + name=endpoint, + include_in_schema=True, + ) + class FastAPIRequestAdapter: def __init__(self): From 5ef796bf7614bdf5aeb4a5bb3ad61c78f2b4bfb4 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:08:00 -0400 Subject: [PATCH 32/74] shifting from `server_factory` to `backend` --- dash/_callback.py | 2 +- .../quart_factory.py => backend/quart.py} | 30 +- dash/dash.py | 120 +++--- dash/server_factories/__init__.py | 12 - dash/server_factories/base_factory.py | 50 --- dash/server_factories/fastapi_factory.py | 370 ------------------ dash/server_factories/flask_factory.py | 273 ------------- 7 files changed, 103 insertions(+), 754 deletions(-) rename dash/{server_factories/quart_factory.py => backend/quart.py} (86%) delete mode 100644 dash/server_factories/__init__.py delete mode 100644 dash/server_factories/base_factory.py delete mode 100644 dash/server_factories/fastapi_factory.py delete mode 100644 dash/server_factories/flask_factory.py diff --git a/dash/_callback.py b/dash/_callback.py index bca8027fdd..6cc55b9162 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -6,7 +6,7 @@ import asyncio -from dash.server_factories import get_request_adapter +from dash.backend import get_request_adapter from .dependencies import ( handle_callback_args, diff --git a/dash/server_factories/quart_factory.py b/dash/backend/quart.py similarity index 86% rename from dash/server_factories/quart_factory.py rename to dash/backend/quart.py index 53fa1cc469..a2437811a4 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/backend/quart.py @@ -1,7 +1,7 @@ -from .base_factory import BaseServerFactory +from .base_server import BaseDashServer from quart import Quart, Request, Response, jsonify, request from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter +from dash.backend import set_request_adapter from dash.fingerprint import check_fingerprint from dash import _validate from contextvars import copy_context @@ -12,7 +12,7 @@ import time -class QuartAPIServerFactory(BaseServerFactory): +class QuartDashServer(BaseDashServer): """Quart implementation of the Dash server factory. All Quart/async specific imports are at the top-level (per user request) so @@ -196,6 +196,30 @@ async def _dispatch(): return _dispatch + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the Quart app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + if inspect.iscoroutinefunction(handler): + async def view_func(*args, handler=handler, **kwargs): + data = await request.get_json() + result = await handler(**data) if data else await handler() + return jsonify(result) + else: + async def view_func(*args, handler=handler, **kwargs): + data = await request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) + + app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + def _serve_default_favicon(self): return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" diff --git a/dash/dash.py b/dash/dash.py index bed7ab43a4..0e7cbb25fa 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -39,7 +39,7 @@ ProxyError, DuplicateCallback, ) -from .server_factories import get_request_adapter +from .backend import get_request_adapter, get_backend from .version import __version__ from ._configs import get_combined_config, pathname_configs, pages_folder_config from ._utils import ( @@ -64,8 +64,7 @@ from . import _validate from . import _watch from . import _get_app -from .server_factories.flask_factory import FlaskServerFactory -from .server_factories.base_factory import BaseServerFactory +from .backend.flask import FlaskDashServer from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group @@ -156,6 +155,27 @@ except: # noqa: E722 page_container = None +def _is_flask_instance(obj): + try: + from flask import Flask + return isinstance(obj, Flask) + except ImportError: + return False + +def _is_fastapi_instance(obj): + try: + from fastapi import FastAPI + return isinstance(obj, FastAPI) + except ImportError: + return False + +def _is_quart_instance(obj): + try: + from quart import Quart + return isinstance(obj, Quart) + except ImportError: + return False + def _get_traceback(secret, error: Exception): try: @@ -249,6 +269,12 @@ class Dash(ObsoleteChecker): ``flask.Flask``: use this pre-existing Flask server. :type server: boolean or flask.Flask + :param backend: The backend to use for the Dash app. Can be a string + (name of the backend) or a backend class. Default is None, which + selects the Flask backend. Currently, "flask" and "fastapi" backends + are supported. + :type backend: string or type + :param assets_folder: a path, relative to the current working directory, for extra files to be used in the browser. Default ``'assets'``. All .js and .css files will be loaded immediately unless excluded by @@ -431,6 +457,7 @@ def __init__( # pylint: disable=too-many-statements self, name: Optional[str] = None, server: Union[bool, Callable[[], Any]] = True, + backend: Union[str, type, None] = None, assets_folder: str = "assets", pages_folder: str = "pages", use_pages: Optional[bool] = None, @@ -466,7 +493,6 @@ def __init__( # pylint: disable=too-many-statements description: Optional[str] = None, on_error: Optional[Callable[[Exception], Any]] = None, use_async: Optional[bool] = None, - server_factory: Optional[BaseServerFactory] = None, **obsolete, ): @@ -489,29 +515,33 @@ def __init__( # pylint: disable=too-many-statements caller_name: str = name if name is not None else get_caller_name() - self.server_factory = server_factory or FlaskServerFactory() - - # We have 3 cases: server is either True (we create the server), False - # (defer server creation) or a Flask app instance (we use their server) - if callable(server) and not ( - hasattr(server, "route") and hasattr(server, "run") - ): - # Server factory function - self.server = server() - if name is None: - caller_name = getattr(self.server, "name", caller_name) - elif hasattr(server, "route") and hasattr(server, "run"): + # Determine backend + if backend is None: + backend_cls = FlaskDashServer + elif isinstance(backend, str): + backend_cls = get_backend(backend) + elif isinstance(backend, type): + backend_cls = backend + else: + raise ValueError("Invalid backend argument") + + # Determine server and backend instance + if server is not None and server is not True and server is not False: + # User provided a server instance (e.g., Flask, Quart, FastAPI) + if _is_flask_instance(server): + backend_cls = get_backend("flask") + elif _is_quart_instance(server): + backend_cls = get_backend("quart") + elif _is_fastapi_instance(server): + backend_cls = get_backend("fastapi") + else: + raise ValueError("Unsupported server type") + self.backend = backend_cls() self.server = server - if name is None: - caller_name = getattr(server, "name", caller_name) - elif isinstance(server, bool): - self.server = ( - self.server_factory.create_app(caller_name) if server else None - ) else: - raise ValueError( - "server must be a Flask app, a boolean, or a server factory function" - ) + # No server instance provided, create backend and let backend create server + self.backend = backend_cls() + self.server = server base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -700,7 +730,7 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: self.server = app bp_prefix = config.routes_pathname_prefix.replace("/", "_").replace(".", "_") assets_blueprint_name = f"{bp_prefix}dash_assets" - self.server_factory.register_assets_blueprint( + self.backend.register_assets_blueprint( self.server, assets_blueprint_name, config.routes_pathname_prefix + self.config.assets_url_path.lstrip("/"), @@ -723,8 +753,8 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: raise ImportError( "To use the compress option, you need to install dash[compress]" ) from error - self.server_factory.register_error_handlers(self.server) - self.server_factory.before_request(self.server, self._setup_server) + self.backend.register_error_handlers(self.server) + self.backend.before_request(self.server, self._setup_server) self._setup_routes() _get_app.APP = self self.enable_pages() @@ -732,7 +762,7 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> None: full_name = self.config.routes_pathname_prefix + name - self.server_factory.add_url_rule( + self.backend.add_url_rule( self.server, full_name, view_func=view_func, @@ -742,21 +772,21 @@ def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> Non self.routes.append(full_name) def _setup_routes(self): - self.server_factory.setup_component_suites(self) + self.backend.setup_component_suites(self) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) self._add_url( "_dash-update-component", - self.server_factory.dispatch(self.server, self, self._use_async), + self.backend.dispatch(self.server, self, self._use_async), ["POST"], ) self._add_url("_reload-hash", self.serve_reload_hash) self._add_url( "_favicon.ico", - self.server_factory._serve_default_favicon, # pylint: disable=protected-access + self.backend._serve_default_favicon, # pylint: disable=protected-access ) - self.server_factory.setup_index(self) - self.server_factory.setup_catchall(self) + self.backend.setup_index(self) + self.backend.setup_catchall(self) if jupyter_dash.active: self._add_url( @@ -794,7 +824,7 @@ def setup_apis(self): self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k) # Delegate to the server factory for route registration - self.server_factory.register_callback_api_routes(self.server, self.callback_api_paths) + self.backend.register_callback_api_routes(self.server, self.callback_api_paths) def _setup_plotlyjs(self): # pylint: disable=import-outside-toplevel @@ -866,7 +896,7 @@ def serve_layout(self): layout = hook(layout) # TODO - Set browser cache limit - pass hash into frontend - return self.server_factory.make_response( + return self.backend.make_response( to_json(layout), mimetype="application/json", ) @@ -930,7 +960,7 @@ def serve_reload_hash(self): _reload.hard = False _reload.changed_assets = [] - return self.server_factory.jsonify( + return self.backend.jsonify( { "reloadHash": _hash, "hard": hard, @@ -1241,7 +1271,7 @@ def interpolate_index(self, **kwargs): @with_app_context def dependencies(self): - return self.server_factory.make_response( + return self.backend.make_response( to_json(self._callback_list), content_type="application/json", ) @@ -1360,7 +1390,7 @@ def _initialize_context(self, body, adapter): {"prop_id": x, "value": g.input_values.get(x)} for x in body.get("changedPropIds", []) ] - g.dash_response = self.server_factory.make_response( + g.dash_response = self.backend.make_response( mimetype="application/json", data=None ) g.cookies = dict(adapter.get_cookies()) @@ -1736,7 +1766,7 @@ def display_content(path): For nested URLs, slashes are still included: `app.strip_relative_path('/page-1/sub-page-1/')` will return - `page-1/sub-page-1` + `page-1/sub-page-1 ``` """ return _get_paths.app_strip_relative_path( @@ -1993,12 +2023,12 @@ def enable_dev_tools( ) elif dev_tools.prune_errors: secret = gen_salt(20) - self.server_factory.register_prune_error_handler( + self.backend.register_prune_error_handler( self.server, secret, _get_traceback ) if debug and dev_tools.ui: - self.server_factory.register_timing_hooks(self.server, first_run) + self.backend.register_timing_hooks(self.server, first_run) if ( debug @@ -2282,7 +2312,7 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.server_factory.run( + self.backend.run( self.server, host=host, port=port, debug=debug, **flask_run_options ) @@ -2447,7 +2477,7 @@ def update(pathname_, search_, **states): Input(_ID_STORE, "data"), ) - self.server_factory.before_request(self.server, router) + self.backend.before_request(self.server, router) def __call__(self, *args, **kwargs): - return self.server_factory.__call__(self.server, *args, **kwargs) + return self.backend.__call__(self.server, *args, **kwargs) diff --git a/dash/server_factories/__init__.py b/dash/server_factories/__init__.py deleted file mode 100644 index 1bfd497935..0000000000 --- a/dash/server_factories/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# python -import contextvars - -_request_adapter_var = contextvars.ContextVar("request_adapter") - - -def set_request_adapter(adapter): - _request_adapter_var.set(adapter) - - -def get_request_adapter(): - return _request_adapter_var.get() diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py deleted file mode 100644 index 12088947c2..0000000000 --- a/dash/server_factories/base_factory.py +++ /dev/null @@ -1,50 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class BaseServerFactory(ABC): - def __call__(self, server, *args, **kwargs) -> Any: - # Default: WSGI - return server(*args, **kwargs) - - @abstractmethod - def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def register_assets_blueprint( - self, app, blueprint_name: str, assets_url_path: str, assets_folder: str - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def register_error_handlers(self, app) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def before_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def after_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def jsonify(self, obj) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def get_request_adapter(self) -> Any: # pragma: no cover - interface - pass diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py deleted file mode 100644 index 1090c85050..0000000000 --- a/dash/server_factories/fastapi_factory.py +++ /dev/null @@ -1,370 +0,0 @@ -import sys -import mimetypes -import hashlib -import inspect -import pkgutil -from contextvars import copy_context -import importlib.util -import time - -try: - import uvicorn - from fastapi import FastAPI, Request, Response - from fastapi.responses import JSONResponse, PlainTextResponse - from fastapi.staticfiles import StaticFiles - from starlette.responses import Response as StarletteResponse - from starlette.datastructures import MutableHeaders - from pydantic import create_model - from typing import Any, Optional -except ImportError: - uvicorn = None - FastAPI = None - Request = None - Response = None - JSONResponse = None - PlainTextResponse = None - StaticFiles = None - StarletteResponse = None - MutableHeaders = None - create_model = None - Any = None - Optional = None - -from dash.fingerprint import check_fingerprint -from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter -from .base_factory import BaseServerFactory - - -class FastAPIServerFactory(BaseServerFactory): - def __init__(self): - self.config = {} - super().__init__() - - def __call__(self, server, *args, **kwargs): - # ASGI: (scope, receive, send) - if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: - return server(*args, **kwargs) - raise TypeError("FastAPI app must be called with (scope, receive, send)") - - def create_app(self, name="__main__", config=None): - app = FastAPI() - if config: - for key, value in config.items(): - setattr(app.state, key, value) - return app - - def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): - try: - app.mount( - assets_url_path, - StaticFiles(directory=assets_folder), - name=blueprint_name, - ) - except RuntimeError: - # directory doesnt exist - pass - - def register_error_handlers(self, app): - @app.exception_handler(PreventUpdate) - async def _handle_error(_request, _exc): - return Response(status_code=204) - - @app.exception_handler(InvalidResourceError) - async def _invalid_resources_handler(_request, exc): - return Response(content=exc.args[0], status_code=404) - - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.exception_handler(Exception) - async def _wrap_errors(_error_request, error): - tb = get_traceback_func(secret, error) - return PlainTextResponse(tb, status_code=500) - - def _html_response_wrapper(self, view_func): - async def wrapped(*_args, **_kwargs): - # If view_func is a function, call it; if it's a string, use it directly - html = view_func() if callable(view_func) else view_func - return Response(content=html, media_type="text/html") - - return wrapped - - def setup_index(self, dash_app): - async def index(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) - return Response(content=dash_app.index(), media_type="text/html") - - # pylint: disable=protected-access - dash_app._add_url("", index, methods=["GET"]) - - def setup_catchall(self, dash_app): - @dash_app.server.on_event("startup") - def _setup_catchall(): - dash_app.enable_dev_tools( - **self.config, first_run=False - ) # do this to make sure dev tools are enabled - - async def catchall(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) - return Response(content=dash_app.index(), media_type="text/html") - - # pylint: disable=protected-access - dash_app._add_url("{path:path}", catchall, methods=["GET"]) - - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False): - if rule == "": - rule = "/" - if isinstance(view_func, str): - # Wrap string or sync function to async FastAPI handler - view_func = self._html_response_wrapper(view_func) - app.add_api_route( - rule, - view_func, - methods=methods or ["GET"], - name=endpoint, - include_in_schema=include_in_schema, - ) - - def before_request(self, app, func): - # FastAPI does not have before_request, but we can use middleware - app.middleware("http")(self._make_before_middleware(func)) - - def after_request(self, app, func): - # FastAPI does not have after_request, but we can use middleware - app.middleware("http")(self._make_after_middleware(func)) - - def run(self, app, host, port, debug, **kwargs): - frame = inspect.stack()[2] - self.config = dict({"debug": debug} if debug else {}, **kwargs) - reload = debug - if reload: - # Dynamically determine the module name from the file path - file_path = frame.filename - module_name = importlib.util.spec_from_file_location("app", file_path).name - uvicorn.run( - f"{module_name}:app.server", - host=host, - port=port, - reload=reload, - **kwargs, - ) - else: - uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) - - def make_response(self, data, mimetype=None, content_type=None): - headers = {} - if mimetype: - headers["content-type"] = mimetype - if content_type: - headers["content-type"] = content_type - return Response(content=data, headers=headers) - - def jsonify(self, obj): - return JSONResponse(content=obj) - - def get_request_adapter(self): - return FastAPIRequestAdapter - - def _make_before_middleware(self, func): - async def middleware(request, call_next): - if func is not None: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - response = await call_next(request) - return response - - return middleware - - def _make_after_middleware(self, func): - async def middleware(request, call_next): - response = await call_next(request) - if func is not None: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - return response - - return middleware - - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path, request - ): - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - package = sys.modules[package_name] - dash_app.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - package.__version__, - path_in_pkg, - package.__path__, - ) - data = pkgutil.get_data(package_name, path_in_pkg) - headers = {} - if has_fingerprint: - headers["Cache-Control"] = "public, max-age=31536000" - return StarletteResponse(content=data, media_type=mimetype, headers=headers) - etag = hashlib.md5(data).hexdigest() if data else "" - headers["ETag"] = etag - if request.headers.get("if-none-match") == etag: - return StarletteResponse(status_code=304) - return StarletteResponse(content=data, media_type=mimetype, headers=headers) - - def setup_component_suites(self, dash_app): - async def serve(request: Request, package_name: str, fingerprinted_path: str): - return self.serve_component_suites( - dash_app, package_name, fingerprinted_path, request - ) - - # pylint: disable=protected-access - dash_app._add_url( - "_dash-component-suites//", - serve, - ) - - # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): - async def _dispatch(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) - # pylint: disable=protected-access - body = await request.json() - g = dash_app._initialize_context( - body, adapter - ) # pylint: disable=protected-access - func = dash_app._prepare_callback( - g, body - ) # pylint: disable=protected-access - args = dash_app._inputs_to_vals( - g.inputs_list + g.states_list - ) # pylint: disable=protected-access - ctx = copy_context() - partial_func = dash_app._execute_callback( - func, args, g.outputs_list, g - ) # pylint: disable=protected-access - response_data = ctx.run(partial_func) - if inspect.iscoroutine(response_data): - response_data = await response_data - # Instead of set_data, return a new Response - return Response(content=response_data, media_type="application/json") - - return _dispatch - - def _serve_default_favicon(self): - return Response( - content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" - ) - - def register_timing_hooks(self, app, first_run): - if not first_run: - return - - @app.middleware("http") - async def timing_middleware(request, call_next): - # Before request - request.state.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } - response = await call_next(request) - # After request - timing_information = getattr(request.state, "timing_information", None) - if timing_information is not None: - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - headers = MutableHeaders(response.headers) - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - if info.get("dur") is not None: - value += f";dur={info['dur']}" - headers.append("Server-Timing", value) - return response - - def register_callback_api_routes(self, app, callback_api_paths): - """ - Register callback API endpoints on the FastAPI app. - Each key in callback_api_paths is a route, each value is a handler (sync or async). - Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter. - """ - for path, handler in callback_api_paths.items(): - endpoint = f"dash_callback_api_{path}" - route = path if path.startswith("/") else f"/{path}" - methods = ["POST"] - sig = inspect.signature(handler) - param_names = list(sig.parameters.keys()) - fields = {name: (Optional[Any], None) for name in param_names} - Model = create_model(f"Payload_{endpoint}", **fields) - - async def view_func(request: Request, body: Model): - kwargs = body.dict(exclude_unset=True) - if inspect.iscoroutinefunction(handler): - result = await handler(**kwargs) - else: - result = handler(**kwargs) - return JSONResponse(content=result) - - - app.add_api_route( - route, - view_func, - methods=methods, - name=endpoint, - include_in_schema=True, - ) - - -class FastAPIRequestAdapter: - def __init__(self): - self._request = None - - def set_request(self, request: Request): - self._request = request - - def get_root(self): - return str(self._request.base_url) - - def get_args(self): - return self._request.query_params - - async def get_json(self): - return await self._request.json() - - def is_json(self): - return self._request.headers.get("content-type", "").startswith( - "application/json" - ) - - def get_cookies(self, _request=None): - return self._request.cookies - - def get_headers(self): - return self._request.headers - - def get_full_path(self): - return str(self._request.url) - - def get_url(self): - return str(self._request.url) - - def get_remote_addr(self): - return self._request.client.host if self._request.client else None - - def get_origin(self): - return self._request.headers.get("origin") - - def get_path(self): - return self._request.url.path diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py deleted file mode 100644 index a488a070e1..0000000000 --- a/dash/server_factories/flask_factory.py +++ /dev/null @@ -1,273 +0,0 @@ -from contextvars import copy_context -import asyncio -import pkgutil -import sys -import mimetypes -import time -import flask -import inspect -from dash.fingerprint import check_fingerprint -from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter -from .base_factory import BaseServerFactory - - -class FlaskServerFactory(BaseServerFactory): - def __call__(self, server, *args, **kwargs): - # Always WSGI - return server(*args, **kwargs) - - def create_app(self, name="__main__", config=None): - app = flask.Flask(name) - if config: - app.config.update(config) - return app - - def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): - bp = flask.Blueprint( - blueprint_name, - __name__, - static_folder=assets_folder, - static_url_path=assets_url_path, - ) - app.register_blueprint(bp) - - def register_error_handlers(self, app): - @app.errorhandler(PreventUpdate) - def _handle_error(_): - return "", 204 - - @app.errorhandler(InvalidResourceError) - def _invalid_resources_handler(err): - return err.args[0], 404 - - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.errorhandler(Exception) - def _wrap_errors(error): - tb = get_traceback_func(secret, error) - return tb, 500 - - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule( - rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] - ) - - def before_request(self, app, func): - app.before_request(func) - - def after_request(self, app, func): - app.after_request(func) - - def run(self, app, host, port, debug, **kwargs): - app.run(host=host, port=port, debug=debug, **kwargs) - - def make_response(self, data, mimetype=None, content_type=None): - return flask.Response(data, mimetype=mimetype, content_type=content_type) - - def jsonify(self, obj): - return flask.jsonify(obj) - - def get_request_adapter(self): - return FlaskRequestAdapter - - def setup_catchall(self, dash_app): - def catchall(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - return dash_app.index(*args, **kwargs) - - # pylint: disable=protected-access - dash_app._add_url("", catchall, methods=["GET"]) - - def setup_index(self, dash_app): - def index(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - return dash_app.index(*args, **kwargs) - - # pylint: disable=protected-access - dash_app._add_url("", index, methods=["GET"]) - - def serve_component_suites(self, dash_app, package_name, fingerprinted_path): - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - package = sys.modules[package_name] - dash_app.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - package.__version__, - path_in_pkg, - package.__path__, - ) - data = pkgutil.get_data(package_name, path_in_pkg) - response = flask.Response(data, mimetype=mimetype) - if has_fingerprint: - response.cache_control.max_age = 31536000 # 1 year - else: - response.add_etag() - tag = response.get_etag()[0] - request_etag = flask.request.headers.get("If-None-Match") - if f'"{tag}"' == request_etag: - response = flask.Response(None, status=304) - return response - - def setup_component_suites(self, dash_app): - def serve(package_name, fingerprinted_path): - return self.serve_component_suites( - dash_app, package_name, fingerprinted_path - ) - - # pylint: disable=protected-access - dash_app._add_url( - "_dash-component-suites//", - serve, - ) - - # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): - def _dispatch(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() - # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) - ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - if asyncio.iscoroutine(response_data): - raise Exception( - "You are trying to use a coroutine without dash[async]. " - "Please install the dependencies via `pip install dash[async]` and ensure " - "that `use_async=False` is not being passed to the app." - ) - g.dash_response.set_data(response_data) - return g.dash_response - - async def _dispatch_async(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() - # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) - ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - if asyncio.iscoroutine(response_data): - response_data = await response_data - g.dash_response.set_data(response_data) - return g.dash_response - - if use_async: - return _dispatch_async - return _dispatch - - def _serve_default_favicon(self): - - return flask.Response( - pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" - ) - - def register_timing_hooks(self, app, _first_run): - def _before_request(): - flask.g.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } - - def _after_request(response): - timing_information = flask.g.get("timing_information", None) - if timing_information is None: - return response - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - if info.get("dur") is not None: - value += f";dur={info['dur']}" - response.headers.add("Server-Timing", value) - return response - - self.before_request(app, _before_request) - self.after_request(app, _after_request) - - def register_callback_api_routes(self, app, callback_api_paths): - """ - Register callback API endpoints on the Flask app. - Each key in callback_api_paths is a route, each value is a handler (sync or async). - The view function parses the JSON body and passes it to the handler. - """ - for path, handler in callback_api_paths.items(): - endpoint = f"dash_callback_api_{path}" - route = path if path.startswith("/") else f"/{path}" - methods = ["POST"] - - if inspect.iscoroutinefunction(handler): - async def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() - result = await handler(**data) if data else await handler() - return flask.jsonify(result) - else: - def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() - result = handler(**data) if data else handler() - return flask.jsonify(result) - - # Flask 2.x+ supports async views natively - app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) - - -class FlaskRequestAdapter: - @staticmethod - def get_args(): - return flask.request.args - - @staticmethod - def get_root(): - return flask.request.url_root - - @staticmethod - def get_json(): - return flask.request.get_json() - - @staticmethod - def is_json(): - return flask.request.is_json - - @staticmethod - def get_cookies(): - return flask.request.cookies - - @staticmethod - def get_headers(): - return flask.request.headers - - @staticmethod - def get_url(): - return flask.request.url - - @staticmethod - def get_full_path(): - return flask.request.full_path - - @staticmethod - def get_remote_addr(): - return flask.request.remote_addr - - @staticmethod - def get_origin(): - return getattr(flask.request, "origin", None) - - @staticmethod - def get_path(): - return flask.request.path From a4ca566d6810cce00ced20ea5d1b975c39cdc36a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:13:23 -0400 Subject: [PATCH 33/74] adding missing files --- dash/backend/__init__.py | 13 ++ dash/backend/base_server.py | 50 +++++ dash/backend/fastapi.py | 370 ++++++++++++++++++++++++++++++++++++ dash/backend/flask.py | 273 ++++++++++++++++++++++++++ dash/backend/registry.py | 22 +++ 5 files changed, 728 insertions(+) create mode 100644 dash/backend/__init__.py create mode 100644 dash/backend/base_server.py create mode 100644 dash/backend/fastapi.py create mode 100644 dash/backend/flask.py create mode 100644 dash/backend/registry.py diff --git a/dash/backend/__init__.py b/dash/backend/__init__.py new file mode 100644 index 0000000000..497f2dca2d --- /dev/null +++ b/dash/backend/__init__.py @@ -0,0 +1,13 @@ +# python +import contextvars +from .registry import * + +_request_adapter_var = contextvars.ContextVar("request_adapter") + + +def set_request_adapter(adapter): + _request_adapter_var.set(adapter) + + +def get_request_adapter(): + return _request_adapter_var.get() diff --git a/dash/backend/base_server.py b/dash/backend/base_server.py new file mode 100644 index 0000000000..8c902f4248 --- /dev/null +++ b/dash/backend/base_server.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class BaseDashServer(ABC): + def __call__(self, server, *args, **kwargs) -> Any: + # Default: WSGI + return server(*args, **kwargs) + + @abstractmethod + def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def register_assets_blueprint( + self, app, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def register_error_handlers(self, app) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def before_request(self, app, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def after_request(self, app, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def jsonify(self, obj) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def get_request_adapter(self) -> Any: # pragma: no cover - interface + pass diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py new file mode 100644 index 0000000000..b2feeb446f --- /dev/null +++ b/dash/backend/fastapi.py @@ -0,0 +1,370 @@ +import sys +import mimetypes +import hashlib +import inspect +import pkgutil +from contextvars import copy_context +import importlib.util +import time + +try: + import uvicorn + from fastapi import FastAPI, Request, Response + from fastapi.responses import JSONResponse, PlainTextResponse + from fastapi.staticfiles import StaticFiles + from starlette.responses import Response as StarletteResponse + from starlette.datastructures import MutableHeaders + from pydantic import create_model + from typing import Any, Optional +except ImportError: + uvicorn = None + FastAPI = None + Request = None + Response = None + JSONResponse = None + PlainTextResponse = None + StaticFiles = None + StarletteResponse = None + MutableHeaders = None + create_model = None + Any = None + Optional = None + +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.backend import set_request_adapter +from .base_server import BaseDashServer + + +class FastAPIDashServer(BaseDashServer): + def __init__(self): + self.config = {} + super().__init__() + + def __call__(self, server, *args, **kwargs): + # ASGI: (scope, receive, send) + if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: + return server(*args, **kwargs) + raise TypeError("FastAPI app must be called with (scope, receive, send)") + + def create_app(self, name="__main__", config=None): + app = FastAPI() + if config: + for key, value in config.items(): + setattr(app.state, key, value) + return app + + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): + try: + app.mount( + assets_url_path, + StaticFiles(directory=assets_folder), + name=blueprint_name, + ) + except RuntimeError: + # directory doesnt exist + pass + + def register_error_handlers(self, app): + @app.exception_handler(PreventUpdate) + async def _handle_error(_request, _exc): + return Response(status_code=204) + + @app.exception_handler(InvalidResourceError) + async def _invalid_resources_handler(_request, exc): + return Response(content=exc.args[0], status_code=404) + + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.exception_handler(Exception) + async def _wrap_errors(_error_request, error): + tb = get_traceback_func(secret, error) + return PlainTextResponse(tb, status_code=500) + + def _html_response_wrapper(self, view_func): + async def wrapped(*_args, **_kwargs): + # If view_func is a function, call it; if it's a string, use it directly + html = view_func() if callable(view_func) else view_func + return Response(content=html, media_type="text/html") + + return wrapped + + def setup_index(self, dash_app): + async def index(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def setup_catchall(self, dash_app): + @dash_app.server.on_event("startup") + def _setup_catchall(): + dash_app.enable_dev_tools( + **self.config, first_run=False + ) # do this to make sure dev tools are enabled + + async def catchall(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + dash_app._add_url("{path:path}", catchall, methods=["GET"]) + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False): + if rule == "": + rule = "/" + if isinstance(view_func, str): + # Wrap string or sync function to async FastAPI handler + view_func = self._html_response_wrapper(view_func) + app.add_api_route( + rule, + view_func, + methods=methods or ["GET"], + name=endpoint, + include_in_schema=include_in_schema, + ) + + def before_request(self, app, func): + # FastAPI does not have before_request, but we can use middleware + app.middleware("http")(self._make_before_middleware(func)) + + def after_request(self, app, func): + # FastAPI does not have after_request, but we can use middleware + app.middleware("http")(self._make_after_middleware(func)) + + def run(self, app, host, port, debug, **kwargs): + frame = inspect.stack()[2] + self.config = dict({"debug": debug} if debug else {}, **kwargs) + reload = debug + if reload: + # Dynamically determine the module name from the file path + file_path = frame.filename + module_name = importlib.util.spec_from_file_location("app", file_path).name + uvicorn.run( + f"{module_name}:app.server", + host=host, + port=port, + reload=reload, + **kwargs, + ) + else: + uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) + + def make_response(self, data, mimetype=None, content_type=None): + headers = {} + if mimetype: + headers["content-type"] = mimetype + if content_type: + headers["content-type"] = content_type + return Response(content=data, headers=headers) + + def jsonify(self, obj): + return JSONResponse(content=obj) + + def get_request_adapter(self): + return FastAPIRequestAdapter + + def _make_before_middleware(self, func): + async def middleware(request, call_next): + if func is not None: + if inspect.iscoroutinefunction(func): + await func() + else: + func() + response = await call_next(request) + return response + + return middleware + + def _make_after_middleware(self, func): + async def middleware(request, call_next): + response = await call_next(request) + if func is not None: + if inspect.iscoroutinefunction(func): + await func() + else: + func() + return response + + return middleware + + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, request + ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if request.headers.get("if-none-match") == etag: + return StarletteResponse(status_code=304) + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + + def setup_component_suites(self, dash_app): + async def serve(request: Request, package_name: str, fingerprinted_path: str): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) + + # pylint: disable=protected-access + dash_app._add_url( + "_dash-component-suites//", + serve, + ) + + # pylint: disable=unused-argument + def dispatch(self, app, dash_app, use_async=False): + async def _dispatch(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + # pylint: disable=protected-access + body = await request.json() + g = dash_app._initialize_context( + body, adapter + ) # pylint: disable=protected-access + func = dash_app._prepare_callback( + g, body + ) # pylint: disable=protected-access + args = dash_app._inputs_to_vals( + g.inputs_list + g.states_list + ) # pylint: disable=protected-access + ctx = copy_context() + partial_func = dash_app._execute_callback( + func, args, g.outputs_list, g + ) # pylint: disable=protected-access + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): + response_data = await response_data + # Instead of set_data, return a new Response + return Response(content=response_data, media_type="application/json") + + return _dispatch + + def _serve_default_favicon(self): + return Response( + content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" + ) + + def register_timing_hooks(self, app, first_run): + if not first_run: + return + + @app.middleware("http") + async def timing_middleware(request, call_next): + # Before request + request.state.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + response = await call_next(request) + # After request + timing_information = getattr(request.state, "timing_information", None) + if timing_information is not None: + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + headers = MutableHeaders(response.headers) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + headers.append("Server-Timing", value) + return response + + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the FastAPI app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + sig = inspect.signature(handler) + param_names = list(sig.parameters.keys()) + fields = {name: (Optional[Any], None) for name in param_names} + Model = create_model(f"Payload_{endpoint}", **fields) + + async def view_func(request: Request, body: Model): + kwargs = body.dict(exclude_unset=True) + if inspect.iscoroutinefunction(handler): + result = await handler(**kwargs) + else: + result = handler(**kwargs) + return JSONResponse(content=result) + + + app.add_api_route( + route, + view_func, + methods=methods, + name=endpoint, + include_in_schema=True, + ) + + +class FastAPIRequestAdapter: + def __init__(self): + self._request = None + + def set_request(self, request: Request): + self._request = request + + def get_root(self): + return str(self._request.base_url) + + def get_args(self): + return self._request.query_params + + async def get_json(self): + return await self._request.json() + + def is_json(self): + return self._request.headers.get("content-type", "").startswith( + "application/json" + ) + + def get_cookies(self, _request=None): + return self._request.cookies + + def get_headers(self): + return self._request.headers + + def get_full_path(self): + return str(self._request.url) + + def get_url(self): + return str(self._request.url) + + def get_remote_addr(self): + return self._request.client.host if self._request.client else None + + def get_origin(self): + return self._request.headers.get("origin") + + def get_path(self): + return self._request.url.path diff --git a/dash/backend/flask.py b/dash/backend/flask.py new file mode 100644 index 0000000000..2d7d01af32 --- /dev/null +++ b/dash/backend/flask.py @@ -0,0 +1,273 @@ +from contextvars import copy_context +import asyncio +import pkgutil +import sys +import mimetypes +import time +import flask +import inspect +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.backend import set_request_adapter +from .base_server import BaseDashServer + + +class FlaskDashServer(BaseDashServer): + def __call__(self, server, *args, **kwargs): + # Always WSGI + return server(*args, **kwargs) + + def create_app(self, name="__main__", config=None): + app = flask.Flask(name) + if config: + app.config.update(config) + return app + + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): + bp = flask.Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + app.register_blueprint(bp) + + def register_error_handlers(self, app): + @app.errorhandler(PreventUpdate) + def _handle_error(_): + return "", 204 + + @app.errorhandler(InvalidResourceError) + def _invalid_resources_handler(err): + return err.args[0], 404 + + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.errorhandler(Exception) + def _wrap_errors(error): + tb = get_traceback_func(secret, error) + return tb, 500 + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + app.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + def before_request(self, app, func): + app.before_request(func) + + def after_request(self, app, func): + app.after_request(func) + + def run(self, app, host, port, debug, **kwargs): + app.run(host=host, port=port, debug=debug, **kwargs) + + def make_response(self, data, mimetype=None, content_type=None): + return flask.Response(data, mimetype=mimetype, content_type=content_type) + + def jsonify(self, obj): + return flask.jsonify(obj) + + def get_request_adapter(self): + return FlaskRequestAdapter + + def setup_catchall(self, dash_app): + def catchall(*args, **kwargs): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + return dash_app.index(*args, **kwargs) + + # pylint: disable=protected-access + dash_app._add_url("", catchall, methods=["GET"]) + + def setup_index(self, dash_app): + def index(*args, **kwargs): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + return dash_app.index(*args, **kwargs) + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def serve_component_suites(self, dash_app, package_name, fingerprinted_path): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + response = flask.Response(data, mimetype=mimetype) + if has_fingerprint: + response.cache_control.max_age = 31536000 # 1 year + else: + response.add_etag() + tag = response.get_etag()[0] + request_etag = flask.request.headers.get("If-None-Match") + if f'"{tag}"' == request_etag: + response = flask.Response(None, status=304) + return response + + def setup_component_suites(self, dash_app): + def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path + ) + + # pylint: disable=protected-access + dash_app._add_url( + "_dash-component-suites//", + serve, + ) + + # pylint: disable=unused-argument + def dispatch(self, app, dash_app, use_async=False): + def _dispatch(): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + body = flask.request.get_json() + # pylint: disable=protected-access + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + raise Exception( + "You are trying to use a coroutine without dash[async]. " + "Please install the dependencies via `pip install dash[async]` and ensure " + "that `use_async=False` is not being passed to the app." + ) + g.dash_response.set_data(response_data) + return g.dash_response + + async def _dispatch_async(): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + body = flask.request.get_json() + # pylint: disable=protected-access + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + response_data = await response_data + g.dash_response.set_data(response_data) + return g.dash_response + + if use_async: + return _dispatch_async + return _dispatch + + def _serve_default_favicon(self): + + return flask.Response( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + + def register_timing_hooks(self, app, _first_run): + def _before_request(): + flask.g.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + + def _after_request(response): + timing_information = flask.g.get("timing_information", None) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + response.headers.add("Server-Timing", value) + return response + + self.before_request(app, _before_request) + self.after_request(app, _after_request) + + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the Flask app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + if inspect.iscoroutinefunction(handler): + async def view_func(*args, handler=handler, **kwargs): + data = flask.request.get_json() + result = await handler(**data) if data else await handler() + return flask.jsonify(result) + else: + def view_func(*args, handler=handler, **kwargs): + data = flask.request.get_json() + result = handler(**data) if data else handler() + return flask.jsonify(result) + + # Flask 2.x+ supports async views natively + app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + + +class FlaskRequestAdapter: + @staticmethod + def get_args(): + return flask.request.args + + @staticmethod + def get_root(): + return flask.request.url_root + + @staticmethod + def get_json(): + return flask.request.get_json() + + @staticmethod + def is_json(): + return flask.request.is_json + + @staticmethod + def get_cookies(): + return flask.request.cookies + + @staticmethod + def get_headers(): + return flask.request.headers + + @staticmethod + def get_url(): + return flask.request.url + + @staticmethod + def get_full_path(): + return flask.request.full_path + + @staticmethod + def get_remote_addr(): + return flask.request.remote_addr + + @staticmethod + def get_origin(): + return getattr(flask.request, "origin", None) + + @staticmethod + def get_path(): + return flask.request.path diff --git a/dash/backend/registry.py b/dash/backend/registry.py new file mode 100644 index 0000000000..1b80da879f --- /dev/null +++ b/dash/backend/registry.py @@ -0,0 +1,22 @@ +import importlib + +_backend_imports = { + 'flask': ('dash.backend.flask', 'FlaskDashServer'), + 'fastapi': ('dash.backend.fastapi', 'FastAPIDashServer'), + 'quart': ('dash.backend.quart', 'QuartDashServer'), +} + +def register_backend(name, module_path, class_name): + """Register a new backend by name.""" + _backend_imports[name.lower()] = (module_path, class_name) + +def get_backend(name): + try: + module_name, class_name = _backend_imports[name.lower()] + module = importlib.import_module(module_name) + return getattr(module, class_name) + except KeyError: + raise ValueError(f"Unknown backend: {name}") + except (ImportError, AttributeError) as e: + raise ImportError(f"Could not import backend '{name}': {e}") + From 708773f3d4f21cc1ef61fdc244010959c9c8567b Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:15:24 -0400 Subject: [PATCH 34/74] fixing issue with server not declared --- dash/dash.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 0e7cbb25fa..22f79873bb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -540,8 +540,10 @@ def __init__( # pylint: disable=too-many-statements self.server = server else: # No server instance provided, create backend and let backend create server + if server is True and backend_cls is None: + backend_cls = FlaskDashServer self.backend = backend_cls() - self.server = server + self.server = self.backend.create_app(caller_name) # type: ignore base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix From b7bcebaf442e10455987dd72b02c98a1e680578f Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:16:50 -0400 Subject: [PATCH 35/74] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 22f79873bb..2000114067 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -526,7 +526,7 @@ def __init__( # pylint: disable=too-many-statements raise ValueError("Invalid backend argument") # Determine server and backend instance - if server is not None and server is not True and server is not False: + if server not in (None, True, False): # User provided a server instance (e.g., Flask, Quart, FastAPI) if _is_flask_instance(server): backend_cls = get_backend("flask") From 9873079800f773ef09581159312b7f0b48209f67 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:17:01 -0400 Subject: [PATCH 36/74] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 2000114067..af9ab03139 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1768,7 +1768,7 @@ def display_content(path): For nested URLs, slashes are still included: `app.strip_relative_path('/page-1/sub-page-1/')` will return - `page-1/sub-page-1 + `page-1/sub-page-1` ``` """ return _get_paths.app_strip_relative_path( From 9f4d291689c05cada98f4ad9fe380aa718ae56ac Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:17:16 -0400 Subject: [PATCH 37/74] Update dash/backend/quart.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/backend/quart.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/dash/backend/quart.py b/dash/backend/quart.py index a2437811a4..5bb568fe72 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -207,17 +207,21 @@ def register_callback_api_routes(self, app, callback_api_paths): route = path if path.startswith("/") else f"/{path}" methods = ["POST"] - if inspect.iscoroutinefunction(handler): - async def view_func(*args, handler=handler, **kwargs): - data = await request.get_json() - result = await handler(**data) if data else await handler() - return jsonify(result) - else: - async def view_func(*args, handler=handler, **kwargs): - data = await request.get_json() - result = handler(**data) if data else handler() - return jsonify(result) - + def _make_view_func(handler): + if inspect.iscoroutinefunction(handler): + async def async_view_func(*args, **kwargs): + data = await request.get_json() + result = await handler(**data) if data else await handler() + return jsonify(result) + return async_view_func + else: + async def sync_view_func(*args, **kwargs): + data = await request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) + return sync_view_func + + view_func = _make_view_func(handler) app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) def _serve_default_favicon(self): From da86e8666b731375bf251a0f76fd5eb8d360b6a9 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:17:52 -0400 Subject: [PATCH 38/74] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index af9ab03139..242f4cbd91 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -540,8 +540,6 @@ def __init__( # pylint: disable=too-many-statements self.server = server else: # No server instance provided, create backend and let backend create server - if server is True and backend_cls is None: - backend_cls = FlaskDashServer self.backend = backend_cls() self.server = self.backend.create_app(caller_name) # type: ignore From 4c60740a5f98e2e2612cbb8c0dbf22d324345255 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:22:44 -0400 Subject: [PATCH 39/74] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 242f4cbd91..0e7be84128 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -529,13 +529,33 @@ def __init__( # pylint: disable=too-many-statements if server not in (None, True, False): # User provided a server instance (e.g., Flask, Quart, FastAPI) if _is_flask_instance(server): - backend_cls = get_backend("flask") + inferred_backend = "flask" elif _is_quart_instance(server): - backend_cls = get_backend("quart") + inferred_backend = "quart" elif _is_fastapi_instance(server): - backend_cls = get_backend("fastapi") + inferred_backend = "fastapi" else: raise ValueError("Unsupported server type") + # Validate that backend matches server type if both are provided + if backend is not None: + if isinstance(backend, str): + requested_backend = backend + elif isinstance(backend, type): + # get_backend returns the backend class for a string + # So we compare the class names + requested_backend = get_backend(inferred_backend).__name__.lower() + backend_name = backend.__name__.lower() + if backend_name != requested_backend: + raise ValueError( + f"Conflict between provided backend '{backend_name}' and server type '{inferred_backend}'." + ) + else: + raise ValueError("Invalid backend argument") + if isinstance(backend, str) and backend.lower() != inferred_backend: + raise ValueError( + f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." + ) + backend_cls = get_backend(inferred_backend) self.backend = backend_cls() self.server = server else: From 84cb5e52de9e5046009ce4913f160cbaf435cac0 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:25:37 -0400 Subject: [PATCH 40/74] update for caller_name --- dash/dash.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dash/dash.py b/dash/dash.py index 0e7be84128..4aa18bd9c0 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -558,6 +558,9 @@ def __init__( # pylint: disable=too-many-statements backend_cls = get_backend(inferred_backend) self.backend = backend_cls() self.server = server + # Update caller_name from server's name attribute if available + if hasattr(server, "name"): + caller_name = server.name else: # No server instance provided, create backend and let backend create server self.backend = backend_cls() From 29cf8232684cd881034ad5986c8de51bc580a9e4 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:27:30 -0400 Subject: [PATCH 41/74] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 4aa18bd9c0..05249fe583 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -543,11 +543,10 @@ def __init__( # pylint: disable=too-many-statements elif isinstance(backend, type): # get_backend returns the backend class for a string # So we compare the class names - requested_backend = get_backend(inferred_backend).__name__.lower() - backend_name = backend.__name__.lower() - if backend_name != requested_backend: + expected_backend_cls = get_backend(inferred_backend) + if backend is not expected_backend_cls: raise ValueError( - f"Conflict between provided backend '{backend_name}' and server type '{inferred_backend}'." + f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." ) else: raise ValueError("Invalid backend argument") From 5d0f4dced2c362eb89c1211ca31a296aa69eda26 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:30:45 -0400 Subject: [PATCH 42/74] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 05249fe583..ce44c936ef 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -271,7 +271,7 @@ class Dash(ObsoleteChecker): :param backend: The backend to use for the Dash app. Can be a string (name of the backend) or a backend class. Default is None, which - selects the Flask backend. Currently, "flask" and "fastapi" backends + selects the Flask backend. Currently, "flask", "fastapi", and "quart" backends are supported. :type backend: string or type From 86f452873a4c8c9068296154578e7a05121d5f4d Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:33:10 -0400 Subject: [PATCH 43/74] adjustments for matching types --- dash/dash.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 05249fe583..db77f3f4fb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -538,9 +538,7 @@ def __init__( # pylint: disable=too-many-statements raise ValueError("Unsupported server type") # Validate that backend matches server type if both are provided if backend is not None: - if isinstance(backend, str): - requested_backend = backend - elif isinstance(backend, type): + if isinstance(backend, type): # get_backend returns the backend class for a string # So we compare the class names expected_backend_cls = get_backend(inferred_backend) @@ -548,9 +546,9 @@ def __init__( # pylint: disable=too-many-statements raise ValueError( f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." ) - else: + elif not isinstance(backend, str): raise ValueError("Invalid backend argument") - if isinstance(backend, str) and backend.lower() != inferred_backend: + elif backend.lower() != inferred_backend: raise ValueError( f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." ) From 2a88385f46fa57758f2ba1023526db892cbcadf7 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:33:38 -0400 Subject: [PATCH 44/74] Update dash/backend/registry.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/backend/registry.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dash/backend/registry.py b/dash/backend/registry.py index 1b80da879f..4aac7142ef 100644 --- a/dash/backend/registry.py +++ b/dash/backend/registry.py @@ -17,6 +17,8 @@ def get_backend(name): return getattr(module, class_name) except KeyError: raise ValueError(f"Unknown backend: {name}") - except (ImportError, AttributeError) as e: - raise ImportError(f"Could not import backend '{name}': {e}") + except ImportError as e: + raise ImportError(f"Could not import module '{module_name}' for backend '{name}': {e}") + except AttributeError as e: + raise AttributeError(f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}") From bc51c0d0269cf0fba3d25e4f07a26cfad0f3bf06 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:35:54 -0400 Subject: [PATCH 45/74] Update dash/backend/registry.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/backend/registry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dash/backend/registry.py b/dash/backend/registry.py index 4aac7142ef..fb9c99cc2d 100644 --- a/dash/backend/registry.py +++ b/dash/backend/registry.py @@ -15,10 +15,10 @@ def get_backend(name): module_name, class_name = _backend_imports[name.lower()] module = importlib.import_module(module_name) return getattr(module, class_name) - except KeyError: - raise ValueError(f"Unknown backend: {name}") + except KeyError as e: + raise ValueError(f"Unknown backend: {name}") from e except ImportError as e: - raise ImportError(f"Could not import module '{module_name}' for backend '{name}': {e}") + raise ImportError(f"Could not import module '{module_name}' for backend '{name}': {e}") from e except AttributeError as e: - raise AttributeError(f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}") + raise AttributeError(f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}") from e From 1b4d0d3f767ab0ca6ffd2829d5b75c03642fbcef Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:38:03 -0400 Subject: [PATCH 46/74] fixing another type check --- dash/dash.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index aa9e4c51ca..a95f969faa 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -542,7 +542,10 @@ def __init__( # pylint: disable=too-many-statements # get_backend returns the backend class for a string # So we compare the class names expected_backend_cls = get_backend(inferred_backend) - if backend is not expected_backend_cls: + if ( + backend.__module__ != expected_backend_cls.__module__ + or backend.__name__ != expected_backend_cls.__name__ + ): raise ValueError( f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." ) From f867f98fd791ec790d755f75eb0a6e5b8a986117 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 21:03:26 -0400 Subject: [PATCH 47/74] fixing for lint --- dash/backend/__init__.py | 4 +- dash/backend/base_server.py | 16 +++++-- dash/backend/fastapi.py | 10 ++-- dash/backend/flask.py | 9 +++- dash/backend/quart.py | 91 +++++++++++++++++++++++-------------- dash/backend/registry.py | 17 ++++--- dash/dash.py | 11 ++++- 7 files changed, 108 insertions(+), 50 deletions(-) diff --git a/dash/backend/__init__.py b/dash/backend/__init__.py index 497f2dca2d..eb1d47bc3f 100644 --- a/dash/backend/__init__.py +++ b/dash/backend/__init__.py @@ -1,6 +1,8 @@ # python import contextvars -from .registry import * +from .registry import get_backend # pylint: disable=unused-import + +__all__ = ["set_request_adapter", "get_request_adapter", "get_backend"] _request_adapter_var = contextvars.ContextVar("request_adapter") diff --git a/dash/backend/base_server.py b/dash/backend/base_server.py index 8c902f4248..4855f86ad6 100644 --- a/dash/backend/base_server.py +++ b/dash/backend/base_server.py @@ -8,7 +8,9 @@ def __call__(self, server, *args, **kwargs) -> Any: return server(*args, **kwargs) @abstractmethod - def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface + def create_app( + self, name: str = "__main__", config=None + ) -> Any: # pragma: no cover - interface pass @abstractmethod @@ -22,7 +24,9 @@ def register_error_handlers(self, app) -> None: # pragma: no cover - interface pass @abstractmethod - def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface + def add_url_rule( + self, app, rule: str, view_func, endpoint=None, methods=None + ) -> None: # pragma: no cover - interface pass @abstractmethod @@ -34,11 +38,15 @@ def after_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface + def run( + self, app, host: str, port: int, debug: bool, **kwargs + ) -> None: # pragma: no cover - interface pass @abstractmethod - def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface + def make_response( + self, data, mimetype=None, content_type=None + ) -> Any: # pragma: no cover - interface pass @abstractmethod diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index b2feeb446f..d283e90346 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -117,7 +117,9 @@ async def catchall(request: Request): # pylint: disable=protected-access dash_app._add_url("{path:path}", catchall, methods=["GET"]) - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False): + def add_url_rule( + self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False + ): if rule == "": rule = "/" if isinstance(view_func, str): @@ -307,8 +309,11 @@ def register_callback_api_routes(self, app, callback_api_paths): sig = inspect.signature(handler) param_names = list(sig.parameters.keys()) fields = {name: (Optional[Any], None) for name in param_names} - Model = create_model(f"Payload_{endpoint}", **fields) + Model = create_model( + f"Payload_{endpoint}", **fields + ) # pylint: disable=cell-var-from-loop + # pylint: disable=cell-var-from-loop async def view_func(request: Request, body: Model): kwargs = body.dict(exclude_unset=True) if inspect.iscoroutinefunction(handler): @@ -317,7 +322,6 @@ async def view_func(request: Request, body: Model): result = handler(**kwargs) return JSONResponse(content=result) - app.add_api_route( route, view_func, diff --git a/dash/backend/flask.py b/dash/backend/flask.py index 2d7d01af32..b48225a3c5 100644 --- a/dash/backend/flask.py +++ b/dash/backend/flask.py @@ -4,8 +4,8 @@ import sys import mimetypes import time -import flask import inspect +import flask from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import PreventUpdate, InvalidResourceError @@ -213,18 +213,23 @@ def register_callback_api_routes(self, app, callback_api_paths): methods = ["POST"] if inspect.iscoroutinefunction(handler): + async def view_func(*args, handler=handler, **kwargs): data = flask.request.get_json() result = await handler(**data) if data else await handler() return flask.jsonify(result) + else: + def view_func(*args, handler=handler, **kwargs): data = flask.request.get_json() result = handler(**data) if data else handler() return flask.jsonify(result) # Flask 2.x+ supports async views natively - app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + app.add_url_rule( + route, endpoint=endpoint, view_func=view_func, methods=methods + ) class FlaskRequestAdapter: diff --git a/dash/backend/quart.py b/dash/backend/quart.py index 5bb568fe72..c3d42dadee 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -1,15 +1,25 @@ -from .base_server import BaseDashServer -from quart import Quart, Request, Response, jsonify, request -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.backend import set_request_adapter -from dash.fingerprint import check_fingerprint -from dash import _validate -from contextvars import copy_context import inspect import pkgutil import mimetypes import sys import time +from contextvars import copy_context + +try: + import quart + from quart import Quart, Response, jsonify, request, Blueprint +except ImportError: + quart = None + Quart = None + Response = None + jsonify = None + request = None + Blueprint = None +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.backend import set_request_adapter +from dash.fingerprint import check_fingerprint +from dash import _validate +from .base_server import BaseDashServer class QuartDashServer(BaseDashServer): @@ -24,7 +34,7 @@ def __init__(self) -> None: super().__init__() def __call__(self, server, *args, **kwargs): - return super().__call__(server, *args, **kwargs) + return server(*args, **kwargs) def create_app(self, name="__main__", config=None): app = Quart(name) @@ -36,8 +46,6 @@ def create_app(self, name="__main__", config=None): def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - from quart import Blueprint - bp = Blueprint( blueprint_name, __name__, @@ -53,15 +61,15 @@ async def _wrap_errors(_error_request, error): return tb, 500 def register_timing_hooks(self, app, _first_run): # parity with Flask factory - from quart import g - @app.before_request async def _before_request(): # pragma: no cover - timing infra - g.timing_information = {"__dash_server": {"dur": time.time(), "desc": None}} + quart.g.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } @app.after_request async def _after_request(response): # pragma: no cover - timing infra - timing_information = getattr(g, "timing_information", None) + timing_information = getattr(quart.g, "timing_information", None) if timing_information is None: return response dash_total = timing_information.get("__dash_server", None) @@ -90,7 +98,7 @@ async def _invalid_resource(err): return err.args[0], 404 def _html_response_wrapper(self, view_func): - async def wrapped(*args, **kwargs): + async def wrapped(*_args, **_kwargs): html_val = view_func() if callable(view_func) else view_func if inspect.iscoroutine(html_val): # handle async function returning html html_val = await html_val @@ -105,21 +113,25 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): ) def setup_index(self, dash_app): - async def index(): + async def index(*args, **kwargs): adapter = QuartRequestAdapter() set_request_adapter(adapter) - adapter.set_request(request) - return Response(dash_app.index(), content_type="text/html") + adapter.set_request() + return Response(dash_app.index(*args, **kwargs), content_type="text/html") + # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app): - async def catchall(path): # noqa: ARG001 - path is unused but kept for route signature + async def catchall( + path, *args, **kwargs + ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument adapter = QuartRequestAdapter() set_request_adapter(adapter) - adapter.set_request(request) - return Response(dash_app.index(), content_type="text/html") + adapter.set_request() + return Response(dash_app.index(*args, **kwargs), content_type="text/html") + # pylint: disable=protected-access dash_app._add_url("", catchall, methods=["GET"]) def before_request(self, app, func): @@ -135,7 +147,7 @@ async def _after(response): return response def run(self, app, host, port, debug, **kwargs): - self.config = {'debug': debug, **kwargs} if debug else kwargs + self.config = {"debug": debug, **kwargs} if debug else kwargs app.run(host=host, port=port, debug=debug, **kwargs) def make_response(self, data, mimetype=None, content_type=None): @@ -147,7 +159,9 @@ def jsonify(self, obj): def get_request_adapter(self): return QuartRequestAdapter - def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req): # noqa: ARG002 unused req preserved for interface parity + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path + ): # noqa: ARG002 unused req preserved for interface parity path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -170,24 +184,30 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req def setup_component_suites(self, dash_app): async def serve(package_name, fingerprinted_path): return self.serve_component_suites( - dash_app, package_name, fingerprinted_path, request + dash_app, package_name, fingerprinted_path ) + # pylint: disable=protected-access dash_app._add_url( "_dash-component-suites//", serve, ) + # pylint: disable=unused-argument def dispatch(self, app, dash_app, use_async=True): # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() set_request_adapter(adapter) - adapter.set_request(request) + adapter.set_request() body = await request.get_json() + # pylint: disable=protected-access g = dash_app._initialize_context(body, adapter) + # pylint: disable=protected-access func = dash_app._prepare_callback(g, body) + # pylint: disable=protected-access args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) ctx = copy_context() + # pylint: disable=protected-access partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async @@ -209,20 +229,25 @@ def register_callback_api_routes(self, app, callback_api_paths): def _make_view_func(handler): if inspect.iscoroutinefunction(handler): + async def async_view_func(*args, **kwargs): data = await request.get_json() result = await handler(**data) if data else await handler() return jsonify(result) + return async_view_func - else: - async def sync_view_func(*args, **kwargs): - data = await request.get_json() - result = handler(**data) if data else handler() - return jsonify(result) - return sync_view_func + + async def sync_view_func(*args, **kwargs): + data = await request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) + + return sync_view_func view_func = _make_view_func(handler) - app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + app.add_url_rule( + route, endpoint=endpoint, view_func=view_func, methods=methods + ) def _serve_default_favicon(self): return Response( @@ -234,7 +259,7 @@ class QuartRequestAdapter: def __init__(self) -> None: self._request = None - def set_request(self, request: Request) -> None: + def set_request(self) -> None: self._request = request # Accessors (instance-based) diff --git a/dash/backend/registry.py b/dash/backend/registry.py index fb9c99cc2d..4aae9fafc5 100644 --- a/dash/backend/registry.py +++ b/dash/backend/registry.py @@ -1,15 +1,17 @@ import importlib _backend_imports = { - 'flask': ('dash.backend.flask', 'FlaskDashServer'), - 'fastapi': ('dash.backend.fastapi', 'FastAPIDashServer'), - 'quart': ('dash.backend.quart', 'QuartDashServer'), + "flask": ("dash.backend.flask", "FlaskDashServer"), + "fastapi": ("dash.backend.fastapi", "FastAPIDashServer"), + "quart": ("dash.backend.quart", "QuartDashServer"), } + def register_backend(name, module_path, class_name): """Register a new backend by name.""" _backend_imports[name.lower()] = (module_path, class_name) + def get_backend(name): try: module_name, class_name = _backend_imports[name.lower()] @@ -18,7 +20,10 @@ def get_backend(name): except KeyError as e: raise ValueError(f"Unknown backend: {name}") from e except ImportError as e: - raise ImportError(f"Could not import module '{module_name}' for backend '{name}': {e}") from e + raise ImportError( + f"Could not import module '{module_name}' for backend '{name}': {e}" + ) from e except AttributeError as e: - raise AttributeError(f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}") from e - + raise AttributeError( + f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}" + ) from e diff --git a/dash/dash.py b/dash/dash.py index a95f969faa..18c933f08c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -155,23 +155,32 @@ except: # noqa: E722 page_container = None + def _is_flask_instance(obj): try: + # pylint: disable=import-outside-toplevel from flask import Flask + return isinstance(obj, Flask) except ImportError: return False + def _is_fastapi_instance(obj): try: + # pylint: disable=import-outside-toplevel from fastapi import FastAPI + return isinstance(obj, FastAPI) except ImportError: return False + def _is_quart_instance(obj): try: + # pylint: disable=import-outside-toplevel from quart import Quart + return isinstance(obj, Quart) except ImportError: return False @@ -453,7 +462,7 @@ class Dash(ObsoleteChecker): _layout: Any _extra_components: Any - def __init__( # pylint: disable=too-many-statements + def __init__( # pylint: disable=too-many-statements, too-many-branches self, name: Optional[str] = None, server: Union[bool, Callable[[], Any]] = True, From 0ed81ce67c38bb49056b9ea90564217349ba2825 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 21:06:32 -0400 Subject: [PATCH 48/74] fixing failing test --- tests/integration/devtools/test_devtools_error_handling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/devtools/test_devtools_error_handling.py b/tests/integration/devtools/test_devtools_error_handling.py index b481ef2fad..005bf8c335 100644 --- a/tests/integration/devtools/test_devtools_error_handling.py +++ b/tests/integration/devtools/test_devtools_error_handling.py @@ -109,14 +109,14 @@ def test_dveh006_long_python_errors(dash_duo): assert "in bad_sub" not in error0 # dash and flask part of the traceback ARE included # since we set dev_tools_prune_errors=False - assert "factory.py" in error0 + assert "backend" in error0 and "flask.py" in error0 assert "self.wsgi_app" in error0 error1 = get_error_html(dash_duo, 1) assert "in update_output" in error1 assert "in bad_sub" in error1 assert "ZeroDivisionError" in error1 - assert "factory.py" in error1 + assert "backend" in error1 and "flask.py" in error1 assert "self.wsgi_app" in error1 From 6bd342a4feed571dcb87a3eeba1e96c06be35226 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:20:12 -0400 Subject: [PATCH 49/74] fixing issue with fastapi and component suites --- dash/backend/fastapi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index d283e90346..56f2761a3d 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -231,7 +231,7 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): # pylint: disable=protected-access dash_app._add_url( - "_dash-component-suites//", + "_dash-component-suites/{package_name}/{fingerprinted_path:path}", serve, ) From b1c99537c08a1d71ee544d5defaf9f021b2895b8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:29:10 -0400 Subject: [PATCH 50/74] adjustments to fix issues with caller_name and init the app a couple of times --- dash/dash.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 18c933f08c..19f789e7c0 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -567,9 +567,6 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches backend_cls = get_backend(inferred_backend) self.backend = backend_cls() self.server = server - # Update caller_name from server's name attribute if available - if hasattr(server, "name"): - caller_name = server.name else: # No server instance provided, create backend and let backend create server self.backend = backend_cls() @@ -703,9 +700,6 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # tracks internally if a function already handled at least one request. self._got_first_request = {"pages": False, "setup_server": False} - if self.server is not None: - self.init_app() - self.logger.setLevel(logging.INFO) if self.__class__.__name__ == "JupyterDash": From bd40b56c9dd797d365d840945b69c749010d4ec8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 13 Sep 2025 07:52:35 -0400 Subject: [PATCH 51/74] adjustments for failing tests --- dash/_pages.py | 4 ++-- dash/_utils.py | 5 +++++ dash/dash.py | 10 ++++++++++ tests/integration/multi_page/test_pages_layout.py | 3 ++- 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 6c00e656c7..acb26e8791 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -390,8 +390,8 @@ def _path_to_page(path_id): def _page_meta_tags(app, request): - request_url = request.get_path() - start_page, path_variables = _path_to_page(request_url.strip("/")) + request_path = request.get_path() + start_page, path_variables = _path_to_page(request_path.strip("/")) image = start_page.get("image", "") if image: diff --git a/dash/_utils.py b/dash/_utils.py index f118e61538..ef6c63c281 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -104,6 +104,11 @@ def set_read_only(self, names, msg="Attribute is read-only"): else: object.__setattr__(self, "_read_only", new_read_only) + def unset_read_only(self, keys): + if hasattr(self, "_read_only"): + for key in keys: + self._read_only.pop(key, None) + def finalize(self, msg="Object is final: No new keys may be added."): """Prevent any new keys being set.""" object.__setattr__(self, "_final", msg) diff --git a/dash/dash.py b/dash/dash.py index 19f789e7c0..c84b4476df 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -565,6 +565,8 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." ) backend_cls = get_backend(inferred_backend) + if name is None: + caller_name = getattr(server, "name", caller_name) self.backend = backend_cls() self.server = server else: @@ -700,6 +702,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # tracks internally if a function already handled at least one request. self._got_first_request = {"pages": False, "setup_server": False} + if self.server is not None: + self.init_app() + self.logger.setLevel(logging.INFO) if self.__class__.__name__ == "JupyterDash": @@ -743,6 +748,11 @@ def _setup_hooks(self): def init_app(self, app: Optional[Any] = None, **kwargs) -> None: config = self.config + config.unset_read_only([ + "url_base_pathname", + "routes_pathname_prefix", + "requests_pathname_prefix", + ]) config.update(kwargs) config.set_read_only( [ diff --git a/tests/integration/multi_page/test_pages_layout.py b/tests/integration/multi_page/test_pages_layout.py index 48751021b9..a209ae4517 100644 --- a/tests/integration/multi_page/test_pages_layout.py +++ b/tests/integration/multi_page/test_pages_layout.py @@ -3,6 +3,7 @@ from dash import Dash, Input, State, dcc, html, Output from dash.dash import _ID_LOCATION from dash.exceptions import NoLayoutException +from dash.testing.wait import until def get_app(path1="/", path2="/layout2"): @@ -57,7 +58,7 @@ def test_pala001_layout(dash_duo, clear_pages_state): for page in dash.page_registry.values(): dash_duo.find_element("#" + page["id"]).click() dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - assert dash_duo.driver.title == page["title"], "check that page title updates" + until(lambda: dash_duo.driver.title == page["title"], timeout=3) # test redirects dash_duo.wait_for_page(url=f"{dash_duo.server_url}/v2") From 4e50430bd5abf0772afcaeb82aef2b08b4881642 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 13 Sep 2025 08:18:59 -0400 Subject: [PATCH 52/74] format dash --- dash/dash.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index c84b4476df..747901bd9a 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -748,11 +748,13 @@ def _setup_hooks(self): def init_app(self, app: Optional[Any] = None, **kwargs) -> None: config = self.config - config.unset_read_only([ - "url_base_pathname", - "routes_pathname_prefix", - "requests_pathname_prefix", - ]) + config.unset_read_only( + [ + "url_base_pathname", + "routes_pathname_prefix", + "requests_pathname_prefix", + ] + ) config.update(kwargs) config.set_read_only( [ From 0d32e651e3d88cf9b2874422bfb5a6925d6d3518 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sun, 14 Sep 2025 09:01:33 -0400 Subject: [PATCH 53/74] removing `FlaskDashServer` from import and using `get_backend('flask')` instead --- dash/dash.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 747901bd9a..d9ac42bddf 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -64,7 +64,6 @@ from . import _validate from . import _watch from . import _get_app -from .backend.flask import FlaskDashServer from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group @@ -526,7 +525,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # Determine backend if backend is None: - backend_cls = FlaskDashServer + backend_cls = get_backend('flask') elif isinstance(backend, str): backend_cls = get_backend(backend) elif isinstance(backend, type): From 1b3f61ea5d924015f4f1959b6f7cff56ccc134d6 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sun, 14 Sep 2025 09:07:53 -0400 Subject: [PATCH 54/74] reverting change to callable(title) process --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index d9ac42bddf..1963884072 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -2474,7 +2474,7 @@ def update(pathname_, search_, **states): **{**(path_variables or {}), **query_parameters, **states} ) if callable(title): - title = title(**{**(path_variables or {})}) + title = title(**(path_variables or {})) return layout, {"title": title} From c6805b5b6ac70b05ef0a73ce697fcf77f8a2d753 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sun, 14 Sep 2025 17:11:11 -0400 Subject: [PATCH 55/74] fixing for lint --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 1963884072..18ad1c2367 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -525,7 +525,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # Determine backend if backend is None: - backend_cls = get_backend('flask') + backend_cls = get_backend("flask") elif isinstance(backend, str): backend_cls = get_backend(backend) elif isinstance(backend, type): From 8c7808962c3c7914ab82df2ea74d5691704a35c7 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:18:30 -0400 Subject: [PATCH 56/74] adding custom error handling per backend, tests and adjustments to the flow. Made endpoints for downloading the reqs --- .github/workflows/testing.yml | 103 +++++++++ dash/backend/fastapi.py | 184 ++++++++++++--- dash/backend/flask.py | 54 ++++- dash/backend/quart.py | 121 +++++++++- .../error/FrontEnd/FrontEndError.react.js | 44 ++-- dash/dash.py | 68 +----- dash/testing/application_runners.py | 20 +- package.json | 2 +- requirements/fastapi.txt | 2 + requirements/quart.txt | 1 + tests/backend_tests/__init__.py | 0 .../backend_tests/test_preconfig_backends.py | 211 ++++++++++++++++++ 12 files changed, 688 insertions(+), 122 deletions(-) create mode 100644 requirements/fastapi.txt create mode 100644 requirements/quart.txt create mode 100644 tests/backend_tests/__init__.py create mode 100644 tests/backend_tests/test_preconfig_backends.py diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 1fc0df1845..068fe777d1 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -271,6 +271,109 @@ jobs: cd bgtests pytest --headless --nopercyfinalize tests/async_tests -v -s + backend-tests: + name: Run Backend Callback Tests (Python ${{ matrix.python-version }}) + needs: [build, changes_filter] + if: | + (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) || + needs.changes_filter.outputs.backend_tests_changed == 'true' + timeout-minutes: 30 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.12"] + + services: + redis: + image: redis:6 + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + env: + REDIS_URL: redis://localhost:6379 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + cache: 'npm' + + - name: Install Node.js dependencies + run: npm ci + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Download built Dash packages + uses: actions/download-artifact@v4 + with: + name: dash-packages + path: packages/ + + - name: Install Dash packages + run: | + python -m pip install --upgrade pip wheel + python -m pip install "setuptools<78.0.0" + python -m pip install "selenium==4.32.0" + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache, fastapi, quart]"' \; + + - name: Install Google Chrome + run: | + sudo apt-get update + sudo apt-get install -y google-chrome-stable + + - name: Install ChromeDriver + run: | + echo "Determining Chrome version..." + CHROME_BROWSER_VERSION=$(google-chrome --version) + echo "Installed Chrome Browser version: $CHROME_BROWSER_VERSION" + CHROME_MAJOR_VERSION=$(echo "$CHROME_BROWSER_VERSION" | cut -f 3 -d ' ' | cut -f 1 -d '.') + echo "Detected Chrome Major version: $CHROME_MAJOR_VERSION" + if [ "$CHROME_MAJOR_VERSION" -ge 115 ]; then + echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using CfT endpoint..." + CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://googlechromelabs.github.io/chrome-for-testing/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") + if [ -z "$CHROMEDRIVER_VERSION_STRING" ]; then + echo "Could not automatically find ChromeDriver version for Chrome $CHROME_MAJOR_VERSION via LATEST_RELEASE. Please check CfT endpoints." + exit 1 + fi + CHROMEDRIVER_URL="https://edgedl.me.gvt1.com/edgedl/chrome/chrome-for-testing/${CHROMEDRIVER_VERSION_STRING}/linux64/chromedriver-linux64.zip" + else + echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using older method..." + CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://chromedriver.storage.googleapis.com/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") + CHROMEDRIVER_URL="https://chromedriver.storage.googleapis.com/${CHROMEDRIVER_VERSION_STRING}/chromedriver_linux64.zip" + fi + echo "Using ChromeDriver version string: $CHROMEDRIVER_VERSION_STRING" + echo "Downloading ChromeDriver from: $CHROMEDRIVER_URL" + wget -q -O chromedriver.zip "$CHROMEDRIVER_URL" + unzip -o chromedriver.zip -d /tmp/ + sudo mv /tmp/chromedriver-linux64/chromedriver /usr/local/bin/chromedriver || sudo mv /tmp/chromedriver /usr/local/bin/chromedriver + sudo chmod +x /usr/local/bin/chromedriver + echo "/usr/local/bin" >> $GITHUB_PATH + shell: bash + + - name: Build/Setup test components + run: npm run setup-tests.py + + - name: Run Backend Callback Tests + run: | + mkdir bgtests + cp -r tests bgtests/tests + cd bgtests + touch __init__.py + pytest --headless --nopercyfinalize tests/backend_tests -v -s + table-unit: name: Table Unit/Lint Tests (Python ${{ matrix.python-version }}) needs: [build, changes_filter] diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index 56f2761a3d..0afcfabd07 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -6,6 +6,7 @@ from contextvars import copy_context import importlib.util import time +import traceback try: import uvicorn @@ -32,14 +33,28 @@ from dash.fingerprint import check_fingerprint from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.exceptions import PreventUpdate, InvalidResourceError, InvalidCallbackReturnValue, BackgroundCallbackError from dash.backend import set_request_adapter from .base_server import BaseDashServer +import json +import os + +CONFIG_PATH = "dash_config.json" + +def save_config(config): + with open(CONFIG_PATH, "w") as f: + json.dump(config, f) + +def load_config(): + if os.path.exists(CONFIG_PATH): + with open(CONFIG_PATH, "r") as f: + return json.load(f) + return {} class FastAPIDashServer(BaseDashServer): def __init__(self): - self.config = {} + self.error_handling_mode = "prune" super().__init__() def __call__(self, server, *args, **kwargs): @@ -69,19 +84,120 @@ def register_assets_blueprint( pass def register_error_handlers(self, app): - @app.exception_handler(PreventUpdate) - async def _handle_error(_request, _exc): - return Response(status_code=204) + self.error_handling_mode = "prune" + # FastAPI uses exception handlers, but we will handle errors in middleware + pass + + def _get_traceback(self, secret, error: Exception): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if self.error_handling_mode == "prune": + if not callback_handled: + if 'callback invoked' in str(err) and '_callback.py' in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + + # Parse traceback lines to group by file + import re + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split('\n') + current_file = None + card_lines = [] + + for i, line in enumerate(lines[:-1]): # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = f'{match.group(1)} (line {match.group(2)}, in {match.group(3)})' + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + + cards_html = "" + for filename, card in file_cards: + cards_html += f""" +
+
{filename}
+
"""+ '\n'.join(card) + """
+
+ """ + + html = f""" + + + + {error_type}: {error_msg} // FastAPI Debugger + + + +
+

{error_type}

+
+

{error_type}: {error_msg}

+
+

Traceback (most recent call last)

+ {cards_html} +
{error_type}: {error_msg}
+
+

This is the Copy/Paste friendly version of the traceback.

+ +
+
+ The debugger caught an exception in your ASGI application. You can now + look at the traceback which led to the error. +
+ +
+ + + """ + return html - @app.exception_handler(InvalidResourceError) - async def _invalid_resources_handler(_request, exc): - return Response(content=exc.args[0], status_code=404) + def register_prune_error_handler(self, _app, _secret, prune_errors): + if prune_errors: + self.error_handling_mode = "prune" + else: + self.error_handling_mode = "raise" - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.exception_handler(Exception) - async def _wrap_errors(_error_request, error): - tb = get_traceback_func(secret, error) - return PlainTextResponse(tb, status_code=500) def _html_response_wrapper(self, view_func): async def wrapped(*_args, **_kwargs): @@ -104,9 +220,10 @@ async def index(request: Request): def setup_catchall(self, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): + config = load_config() dash_app.enable_dev_tools( - **self.config, first_run=False - ) # do this to make sure dev tools are enabled + **config, first_run=False + ) async def catchall(request: Request): adapter = FastAPIRequestAdapter() @@ -141,11 +258,15 @@ def after_request(self, app, func): # FastAPI does not have after_request, but we can use middleware app.middleware("http")(self._make_after_middleware(func)) - def run(self, app, host, port, debug, **kwargs): + def run(self, dash_app, app, host, port, debug, **kwargs): frame = inspect.stack()[2] - self.config = dict({"debug": debug} if debug else {}, **kwargs) - reload = debug - if reload: + config = dict({"debug": debug} if debug else {}, **{ + f'dev_tools_{k}': v for k, v in dash_app._dev_tools.items()}) + save_config(config) + if debug: + if kwargs.get('reload') is None: + kwargs['reload'] = True + if kwargs.get('reload'): # Dynamically determine the module name from the file path file_path = frame.filename module_name = importlib.util.spec_from_file_location("app", file_path).name @@ -153,11 +274,10 @@ def run(self, app, host, port, debug, **kwargs): f"{module_name}:app.server", host=host, port=port, - reload=reload, **kwargs, ) else: - uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) + uvicorn.run(app, host=host, port=port, **kwargs) def make_response(self, data, mimetype=None, content_type=None): headers = {} @@ -175,13 +295,21 @@ def get_request_adapter(self): def _make_before_middleware(self, func): async def middleware(request, call_next): - if func is not None: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - response = await call_next(request) - return response + try: + response = await call_next(request) + return response + except PreventUpdate: + # No content, nothing to update + return Response(status_code=204) + except Exception as e: + if self.error_handling_mode in ["raise", "prune"]: + # Prune the traceback to remove internal Dash calls + tb = self._get_traceback(None, e) + return Response(content=tb, media_type='text/html', status_code=500) + return JSONResponse( + status_code=500, + content={"error": "InternalServerError", "message": str(e.args[0])}, + ) return middleware diff --git a/dash/backend/flask.py b/dash/backend/flask.py index b48225a3c5..75526e6feb 100644 --- a/dash/backend/flask.py +++ b/dash/backend/flask.py @@ -11,6 +11,7 @@ from dash.exceptions import PreventUpdate, InvalidResourceError from dash.backend import set_request_adapter from .base_server import BaseDashServer +import traceback class FlaskDashServer(BaseDashServer): @@ -44,11 +45,52 @@ def _handle_error(_): def _invalid_resources_handler(err): return err.args[0], 404 - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.errorhandler(Exception) - def _wrap_errors(error): - tb = get_traceback_func(secret, error) - return tb, 500 + def _get_traceback(self, secret, error: Exception): + try: + from werkzeug.debug import tbtools + except ImportError: + tbtools = None + + def _get_skip(error): + from dash._callback import _invoke_callback, _async_invoke_callback + + tb = error.__traceback__ + skip = 1 + while tb.tb_next is not None: + skip += 1 + tb = tb.tb_next + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: + return skip + return skip + + def _do_skip(error): + from dash._callback import _invoke_callback, _async_invoke_callback + + tb = error.__traceback__ + while tb.tb_next is not None: + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: + return tb.tb_next + tb = tb.tb_next + return error.__traceback__ + + if hasattr(tbtools, "get_current_traceback"): + return tbtools.get_current_traceback(skip=_get_skip(error)).render_full() + if hasattr(tbtools, "DebugTraceback"): + return tbtools.DebugTraceback(error, skip=_get_skip(error)).render_debugger_html(True, secret, True) + return "".join(traceback.format_exception(type(error), error, _do_skip(error))) + + def register_prune_error_handler(self, app, secret, prune_errors): + if prune_errors: + @app.errorhandler(Exception) + def _wrap_errors(error): + tb = self._get_traceback(secret, error) + return tb, 500 def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): app.add_url_rule( @@ -61,7 +103,7 @@ def before_request(self, app, func): def after_request(self, app, func): app.after_request(func) - def run(self, app, host, port, debug, **kwargs): + def run(self, _dash_app, app, host, port, debug, **kwargs): app.run(host=host, port=port, debug=debug, **kwargs) def make_response(self, data, mimetype=None, content_type=None): diff --git a/dash/backend/quart.py b/dash/backend/quart.py index c3d42dadee..40f30108b2 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -4,6 +4,7 @@ import sys import time from contextvars import copy_context +import traceback try: import quart @@ -31,6 +32,7 @@ class QuartDashServer(BaseDashServer): def __init__(self) -> None: self.config = {} + self.error_handling_mode = "prune" super().__init__() def __call__(self, server, *args, **kwargs): @@ -54,11 +56,120 @@ def register_assets_blueprint( ) app.register_blueprint(bp) - def register_prune_error_handler(self, app, secret, get_traceback_func): + def _get_traceback(self, secret, error: Exception): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if self.error_handling_mode == "prune": + if not callback_handled: + if 'callback invoked' in str(err) and '_callback.py' in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + + # Parse traceback lines to group by file + import re + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split('\n') + current_file = None + card_lines = [] + + for i, line in enumerate(lines[:-1]): # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = f'{match.group(1)} (line {match.group(2)}, in {match.group(3)})' + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + + cards_html = "" + for filename, card in file_cards: + cards_html += f""" +
+
{filename}
+
""" + '\n'.join(card) + """
+
+ """ + + html = f""" + + + + {error_type}: {error_msg} // Quart Debugger + + + +
+

{error_type}

+
+

{error_type}: {error_msg}

+
+

Traceback (most recent call last)

+ {cards_html} +
{error_type}: {error_msg}
+
+

This is the Copy/Paste friendly version of the traceback.

+ +
+
+ The debugger caught an exception in your ASGI application. You can now + look at the traceback which led to the error. +
+ +
+ + + """ + return html + + def register_prune_error_handler(self, app, secret, prune_errors): + if prune_errors: + self.error_handling_mode = "prune" + else: + self.error_handling_mode = "raise" + @app.errorhandler(Exception) - async def _wrap_errors(_error_request, error): - tb = get_traceback_func(secret, error) - return tb, 500 + async def _wrap_errors(error): + tb = self._get_traceback(secret, error) + return Response(tb, status=500, content_type='text/html') def register_timing_hooks(self, app, _first_run): # parity with Flask factory @app.before_request @@ -146,7 +257,7 @@ async def _after(response): await result return response - def run(self, app, host, port, debug, **kwargs): + def run(self, _dash_app, app, host, port, debug, **kwargs): self.config = {"debug": debug, **kwargs} if debug else kwargs app.run(host=host, port=port, debug=debug, **kwargs) diff --git a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js index 176cb2c6f8..ab5430e7da 100644 --- a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js +++ b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js @@ -121,13 +121,17 @@ function BackendError({error, base}) { const MAX_MESSAGE_LENGTH = 40; /* eslint-disable no-inline-comments */ function UnconnectedErrorContent({error, base}) { + // Helper to detect full HTML document + const isFullHtmlDoc = typeof error.html === 'string' && + error.html.trim().toLowerCase().startsWith(' - {/* - * 40 is a rough heuristic - if longer than 40 then the - * message might overflow into ellipses in the title above & - * will need to be displayed in full in this error body - */} + {/* Frontend error message */} {typeof error.message !== 'string' || error.message.length < MAX_MESSAGE_LENGTH ? null : (
@@ -137,6 +141,7 @@ function UnconnectedErrorContent({error, base}) {
)} + {/* Frontend stack trace */} {typeof error.stack !== 'string' ? null : (
@@ -149,7 +154,6 @@ function UnconnectedErrorContent({error, base}) { browser's console.) - {error.stack.split('\n').map((line, i) => (

{line}

))} @@ -157,24 +161,30 @@ function UnconnectedErrorContent({error, base}) {
)} - {/* Backend Error */} - {typeof error.html !== 'string' ? null : error.html - .substring(0, '
- {/* Embed werkzeug debugger in an iframe to prevent - CSS leaking - werkzeug HTML includes a bunch - of CSS on base html elements like `` - */}
- ) : ( + ) : isHtmlFragment ? ( + // Backend error: HTML fragment +
+
+
+ ) : typeof error.html === 'string' ? ( + // Backend error: plain text
-
{error.html}
+
+
{error.html}
+
- )} + ) : null}
); } diff --git a/dash/dash.py b/dash/dash.py index 18ad1c2367..fa1aa45ea5 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -185,63 +185,6 @@ def _is_quart_instance(obj): return False -def _get_traceback(secret, error: Exception): - try: - # pylint: disable=import-outside-toplevel - from werkzeug.debug import tbtools - except ImportError: - tbtools = None - - def _get_skip(error): - from dash._callback import ( # pylint: disable=import-outside-toplevel - _invoke_callback, - _async_invoke_callback, - ) - - tb = error.__traceback__ - skip = 1 - while tb.tb_next is not None: - skip += 1 - tb = tb.tb_next - if tb.tb_frame.f_code in [ - _invoke_callback.__code__, - _async_invoke_callback.__code__, - ]: - return skip - - return skip - - def _do_skip(error): - from dash._callback import ( # pylint: disable=import-outside-toplevel - _invoke_callback, - _async_invoke_callback, - ) - - tb = error.__traceback__ - while tb.tb_next is not None: - if tb.tb_frame.f_code in [ - _invoke_callback.__code__, - _async_invoke_callback.__code__, - ]: - return tb.tb_next - tb = tb.tb_next - return error.__traceback__ - - # werkzeug<2.1.0 - if hasattr(tbtools, "get_current_traceback"): - return tbtools.get_current_traceback( # type: ignore - skip=_get_skip(error) - ).render_full() - - if hasattr(tbtools, "DebugTraceback"): - # pylint: disable=no-member - return tbtools.DebugTraceback( # type: ignore - error, skip=_get_skip(error) - ).render_debugger_html(True, secret, True) - - return "".join(traceback.format_exception(type(error), error, _do_skip(error))) - - # Singleton signal to not update an output, alternative to PreventUpdate no_update = _callback.NoUpdate() # pylint: disable=protected-access @@ -2058,11 +2001,10 @@ def enable_dev_tools( jupyter_dash.configure_callback_exception_handling( self, dev_tools.prune_errors ) - elif dev_tools.prune_errors: - secret = gen_salt(20) - self.backend.register_prune_error_handler( - self.server, secret, _get_traceback - ) + secret = gen_salt(20) + self.backend.register_prune_error_handler( + self.server, secret, dev_tools.prune_errors + ) if debug and dev_tools.ui: self.backend.register_timing_hooks(self.server, first_run) @@ -2350,7 +2292,7 @@ def verify_url_part(served_part, url_part, part_name): ) else: self.backend.run( - self.server, host=host, port=port, debug=debug, **flask_run_options + self, self.server, host=host, port=port, debug=debug, **flask_run_options ) def enable_pages(self) -> None: diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index dc88afe844..df036aabfa 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -171,7 +171,15 @@ def run(): self.port = options["port"] try: - app.run(threaded=True, **options) + module = app.server.__class__.__module__ + # FastAPI support + if not module.startswith("flask"): + app.run( + **options + ) + # Dash/Flask fallback + else: + app.run(threaded=True, **options) except SystemExit: logger.info("Server stopped") except Exception as error: @@ -229,7 +237,15 @@ def target(): options = kwargs.copy() try: - app.run(threaded=True, **options) + module = app.server.__class__.__module__ + # FastAPI support + if not module.startswith("flask"): + app.run( + **options + ) + # Dash/Flask fallback + else: + app.run(threaded=True, **options) except SystemExit: logger.info("Server stopped") raise diff --git a/package.json b/package.json index e78e279c1b..b7416dbb34 100644 --- a/package.json +++ b/package.json @@ -44,7 +44,7 @@ "setup-tests.R": "run-s private::test.R.deploy-*", "citest.integration": "run-s setup-tests.py private::test.integration-*", "citest.unit": "run-s private::test.unit-**", - "test": "pytest && cd dash/dash-renderer && npm run test", + "test": "pytest --ignore=tests/backend_tests && cd dash/dash-renderer && npm run test", "first-build": "cd dash/dash-renderer && npm i && cd ../../ && cd components/dash-html-components && npm i && npm run extract && cd ../../ && npm run build" }, "devDependencies": { diff --git a/requirements/fastapi.txt b/requirements/fastapi.txt new file mode 100644 index 0000000000..97dc7cd8c1 --- /dev/null +++ b/requirements/fastapi.txt @@ -0,0 +1,2 @@ +fastapi +uvicorn diff --git a/requirements/quart.txt b/requirements/quart.txt new file mode 100644 index 0000000000..60af440c9c --- /dev/null +++ b/requirements/quart.txt @@ -0,0 +1 @@ +quart diff --git a/tests/backend_tests/__init__.py b/tests/backend_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py new file mode 100644 index 0000000000..4868406814 --- /dev/null +++ b/tests/backend_tests/test_preconfig_backends.py @@ -0,0 +1,211 @@ +import pytest +from dash import Dash, Input, Output, html, dcc + +@pytest.mark.parametrize( + "backend,fixture,input_value", + [ + ( + "fastapi", + "dash_duo", + "Hello FastAPI!" + ), + ( + "quart", + "dash_duo_mp", + "Hello Quart!" + ), + ] +) +def test_backend_basic_callback(request, backend, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + if backend == "fastapi": + from fastapi import FastAPI + server = FastAPI() + else: + import quart + server = quart.Quart(__name__) + app = Dash(__name__, server=server) + app.layout = html.Div([ + dcc.Input(id="input", value=input_value, type="text"), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(value): + return f"You typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}") + dash_duo.find_element("#input").clear() + dash_duo.find_element("#input").send_keys(f"{backend.title()} Test") + dash_duo.wait_for_text_to_equal("#output", f"You typed: {backend.title()} Test") + assert dash_duo.get_logs() == [] + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs", + [ + ( + "fastapi", + "dash_duo", + {"debug": True, "reload": False, "dev_tools_ui": True}, + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + }, + ), + ] +) +def test_backend_error_handling(request, backend, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div([ + html.Button(id="btn", children="Error", n_clicks=0), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + +def get_error_html(dash_duo, index): + # error is in an iframe so is annoying to read out - get it from the store + return dash_duo.driver.execute_script( + "return store.getState().error.backEnd[{}].error.html;".format(index) + ) + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs, error_msg", + [ + ( + "fastapi", + "dash_duo", + {"debug": True, "dev_tools_ui": True, "dev_tools_prune_errors": False, + "reload": False}, + "fastapi.py" + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + "dev_tools_prune_errors": False, + }, + "quart.py" + ), + ] +) +def test_backend_error_handling_no_prune(request, backend, fixture, start_server_kwargs, error_msg): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div([ + html.Button(id="btn", children="Error", n_clicks=0), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "backend" in error0 and error_msg in error0 + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs, error_msg", + [ + ( + "fastapi", + "dash_duo", + {"debug": True, + "reload": False}, + "fastapi.py" + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + }, + "quart.py" + ), + ] +) +def test_backend_error_handling_prune(request, backend, fixture, start_server_kwargs, error_msg): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div([ + html.Button(id="btn", children="Error", n_clicks=0), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "dash/backend" not in error0 and error_msg not in error0 + +@pytest.mark.parametrize( + "backend,fixture,input_value", + [ + ("fastapi", "dash_duo", "Background FastAPI!"), + ("quart", "dash_duo_mp", "Background Quart!"), + ] +) +def test_backend_background_callback(request, backend, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + import diskcache + cache = diskcache.Cache("./cache") + from dash.background_callback import DiskcacheManager + background_callback_manager = DiskcacheManager(cache) + + + app = Dash(__name__, backend=backend, background_callback_manager=background_callback_manager) + app.layout = html.Div([ + dcc.Input(id="input", value=input_value, type="text"), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("input", "value"), background=True) + def update_output_bg(value): + return f"Background typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") + dash_duo.find_element("#input").clear() + dash_duo.find_element("#input").send_keys(f"{backend.title()} BG Test") + dash_duo.wait_for_text_to_equal("#output", f"Background typed: {backend.title()} BG Test") + assert dash_duo.get_logs() == [] From 5211f6fb43f335b5c99e37859f5f2f5ec2dbe729 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:20:56 -0400 Subject: [PATCH 57/74] adjusments for formatting --- dash/backend/fastapi.py | 46 ++++--- dash/backend/flask.py | 5 +- dash/backend/quart.py | 19 ++- .../error/FrontEnd/FrontEndError.react.js | 7 +- dash/dash.py | 7 +- dash/testing/application_runners.py | 8 +- .../backend_tests/test_preconfig_backends.py | 112 +++++++++--------- 7 files changed, 118 insertions(+), 86 deletions(-) diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index 0afcfabd07..a76a5a47ec 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -33,7 +33,12 @@ from dash.fingerprint import check_fingerprint from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError, InvalidCallbackReturnValue, BackgroundCallbackError +from dash.exceptions import ( + PreventUpdate, + InvalidResourceError, + InvalidCallbackReturnValue, + BackgroundCallbackError, +) from dash.backend import set_request_adapter from .base_server import BaseDashServer @@ -42,16 +47,19 @@ CONFIG_PATH = "dash_config.json" + def save_config(config): with open(CONFIG_PATH, "w") as f: json.dump(config, f) + def load_config(): if os.path.exists(CONFIG_PATH): with open(CONFIG_PATH, "r") as f: return json.load(f) return {} + class FastAPIDashServer(BaseDashServer): def __init__(self): self.error_handling_mode = "prune" @@ -96,7 +104,7 @@ def _get_traceback(self, secret, error: Exception): for err in errors: if self.error_handling_mode == "prune": if not callback_handled: - if 'callback invoked' in str(err) and '_callback.py' in str(err): + if "callback invoked" in str(err) and "_callback.py" in str(err): callback_handled = True continue pass_errs.append(err) @@ -106,9 +114,10 @@ def _get_traceback(self, secret, error: Exception): # Parse traceback lines to group by file import re + file_cards = [] pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') - lines = formatted_tb.split('\n') + lines = formatted_tb.split("\n") current_file = None card_lines = [] @@ -117,7 +126,9 @@ def _get_traceback(self, secret, error: Exception): if match: if current_file and card_lines: file_cards.append((current_file, card_lines)) - current_file = f'{match.group(1)} (line {match.group(2)}, in {match.group(3)})' + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) card_lines = [line] elif current_file: card_lines.append(line) @@ -126,12 +137,16 @@ def _get_traceback(self, secret, error: Exception): cards_html = "" for filename, card in file_cards: - cards_html += f""" + cards_html += ( + f"""
{filename}
-
"""+ '\n'.join(card) + """
+
"""
+                + "\n".join(card)
+                + """
""" + ) html = f""" @@ -198,7 +213,6 @@ def register_prune_error_handler(self, _app, _secret, prune_errors): else: self.error_handling_mode = "raise" - def _html_response_wrapper(self, view_func): async def wrapped(*_args, **_kwargs): # If view_func is a function, call it; if it's a string, use it directly @@ -221,9 +235,7 @@ def setup_catchall(self, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): config = load_config() - dash_app.enable_dev_tools( - **config, first_run=False - ) + dash_app.enable_dev_tools(**config, first_run=False) async def catchall(request: Request): adapter = FastAPIRequestAdapter() @@ -260,13 +272,15 @@ def after_request(self, app, func): def run(self, dash_app, app, host, port, debug, **kwargs): frame = inspect.stack()[2] - config = dict({"debug": debug} if debug else {}, **{ - f'dev_tools_{k}': v for k, v in dash_app._dev_tools.items()}) + config = dict( + {"debug": debug} if debug else {}, + **{f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items()}, + ) save_config(config) if debug: - if kwargs.get('reload') is None: - kwargs['reload'] = True - if kwargs.get('reload'): + if kwargs.get("reload") is None: + kwargs["reload"] = True + if kwargs.get("reload"): # Dynamically determine the module name from the file path file_path = frame.filename module_name = importlib.util.spec_from_file_location("app", file_path).name @@ -305,7 +319,7 @@ async def middleware(request, call_next): if self.error_handling_mode in ["raise", "prune"]: # Prune the traceback to remove internal Dash calls tb = self._get_traceback(None, e) - return Response(content=tb, media_type='text/html', status_code=500) + return Response(content=tb, media_type="text/html", status_code=500) return JSONResponse( status_code=500, content={"error": "InternalServerError", "message": str(e.args[0])}, diff --git a/dash/backend/flask.py b/dash/backend/flask.py index 75526e6feb..542da93129 100644 --- a/dash/backend/flask.py +++ b/dash/backend/flask.py @@ -82,11 +82,14 @@ def _do_skip(error): if hasattr(tbtools, "get_current_traceback"): return tbtools.get_current_traceback(skip=_get_skip(error)).render_full() if hasattr(tbtools, "DebugTraceback"): - return tbtools.DebugTraceback(error, skip=_get_skip(error)).render_debugger_html(True, secret, True) + return tbtools.DebugTraceback( + error, skip=_get_skip(error) + ).render_debugger_html(True, secret, True) return "".join(traceback.format_exception(type(error), error, _do_skip(error))) def register_prune_error_handler(self, app, secret, prune_errors): if prune_errors: + @app.errorhandler(Exception) def _wrap_errors(error): tb = self._get_traceback(secret, error) diff --git a/dash/backend/quart.py b/dash/backend/quart.py index 40f30108b2..71a2053a61 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -64,7 +64,7 @@ def _get_traceback(self, secret, error: Exception): for err in errors: if self.error_handling_mode == "prune": if not callback_handled: - if 'callback invoked' in str(err) and '_callback.py' in str(err): + if "callback invoked" in str(err) and "_callback.py" in str(err): callback_handled = True continue pass_errs.append(err) @@ -74,9 +74,10 @@ def _get_traceback(self, secret, error: Exception): # Parse traceback lines to group by file import re + file_cards = [] pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') - lines = formatted_tb.split('\n') + lines = formatted_tb.split("\n") current_file = None card_lines = [] @@ -85,7 +86,9 @@ def _get_traceback(self, secret, error: Exception): if match: if current_file and card_lines: file_cards.append((current_file, card_lines)) - current_file = f'{match.group(1)} (line {match.group(2)}, in {match.group(3)})' + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) card_lines = [line] elif current_file: card_lines.append(line) @@ -94,12 +97,16 @@ def _get_traceback(self, secret, error: Exception): cards_html = "" for filename, card in file_cards: - cards_html += f""" + cards_html += ( + f"""
{filename}
-
""" + '\n'.join(card) + """
+
"""
+                + "\n".join(card)
+                + """
""" + ) html = f""" @@ -169,7 +176,7 @@ def register_prune_error_handler(self, app, secret, prune_errors): @app.errorhandler(Exception) async def _wrap_errors(error): tb = self._get_traceback(secret, error) - return Response(tb, status=500, content_type='text/html') + return Response(tb, status=500, content_type="text/html") def register_timing_hooks(self, app, _first_run): # parity with Flask factory @app.before_request diff --git a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js index ab5430e7da..db4c6ddd2b 100644 --- a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js +++ b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js @@ -122,12 +122,13 @@ const MAX_MESSAGE_LENGTH = 40; /* eslint-disable no-inline-comments */ function UnconnectedErrorContent({error, base}) { // Helper to detect full HTML document - const isFullHtmlDoc = typeof error.html === 'string' && + const isFullHtmlDoc = + typeof error.html === 'string' && error.html.trim().toLowerCase().startsWith(' diff --git a/dash/dash.py b/dash/dash.py index fa1aa45ea5..994453f4a2 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -2292,7 +2292,12 @@ def verify_url_part(served_part, url_part, part_name): ) else: self.backend.run( - self, self.server, host=host, port=port, debug=debug, **flask_run_options + self, + self.server, + host=host, + port=port, + debug=debug, + **flask_run_options, ) def enable_pages(self) -> None: diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index df036aabfa..2956f1a4c0 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -174,9 +174,7 @@ def run(): module = app.server.__class__.__module__ # FastAPI support if not module.startswith("flask"): - app.run( - **options - ) + app.run(**options) # Dash/Flask fallback else: app.run(threaded=True, **options) @@ -240,9 +238,7 @@ def target(): module = app.server.__class__.__module__ # FastAPI support if not module.startswith("flask"): - app.run( - **options - ) + app.run(**options) # Dash/Flask fallback else: app.run(threaded=True, **options) diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py index 4868406814..5fbd28dfd9 100644 --- a/tests/backend_tests/test_preconfig_backends.py +++ b/tests/backend_tests/test_preconfig_backends.py @@ -1,34 +1,28 @@ import pytest from dash import Dash, Input, Output, html, dcc + @pytest.mark.parametrize( "backend,fixture,input_value", [ - ( - "fastapi", - "dash_duo", - "Hello FastAPI!" - ), - ( - "quart", - "dash_duo_mp", - "Hello Quart!" - ), - ] + ("fastapi", "dash_duo", "Hello FastAPI!"), + ("quart", "dash_duo_mp", "Hello Quart!"), + ], ) def test_backend_basic_callback(request, backend, fixture, input_value): dash_duo = request.getfixturevalue(fixture) if backend == "fastapi": from fastapi import FastAPI + server = FastAPI() else: import quart + server = quart.Quart(__name__) app = Dash(__name__, server=server) - app.layout = html.Div([ - dcc.Input(id="input", value=input_value, type="text"), - html.Div(id="output") - ]) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("input", "value")) def update_output(value): @@ -41,6 +35,7 @@ def update_output(value): dash_duo.wait_for_text_to_equal("#output", f"You typed: {backend.title()} Test") assert dash_duo.get_logs() == [] + @pytest.mark.parametrize( "backend,fixture,start_server_kwargs", [ @@ -58,15 +53,14 @@ def update_output(value): "dev_tools_hot_reload": False, }, ), - ] + ], ) def test_backend_error_handling(request, backend, fixture, start_server_kwargs): dash_duo = request.getfixturevalue(fixture) app = Dash(__name__, backend=backend) - app.layout = html.Div([ - html.Button(id="btn", children="Error", n_clicks=0), - html.Div(id="output") - ]) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("btn", "n_clicks")) def error_callback(n): @@ -79,21 +73,27 @@ def error_callback(n): dash_duo.find_element("#btn").click() dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + def get_error_html(dash_duo, index): # error is in an iframe so is annoying to read out - get it from the store return dash_duo.driver.execute_script( "return store.getState().error.backEnd[{}].error.html;".format(index) ) + @pytest.mark.parametrize( "backend,fixture,start_server_kwargs, error_msg", [ ( "fastapi", "dash_duo", - {"debug": True, "dev_tools_ui": True, "dev_tools_prune_errors": False, - "reload": False}, - "fastapi.py" + { + "debug": True, + "dev_tools_ui": True, + "dev_tools_prune_errors": False, + "reload": False, + }, + "fastapi.py", ), ( "quart", @@ -104,17 +104,18 @@ def get_error_html(dash_duo, index): "dev_tools_hot_reload": False, "dev_tools_prune_errors": False, }, - "quart.py" + "quart.py", ), - ] + ], ) -def test_backend_error_handling_no_prune(request, backend, fixture, start_server_kwargs, error_msg): +def test_backend_error_handling_no_prune( + request, backend, fixture, start_server_kwargs, error_msg +): dash_duo = request.getfixturevalue(fixture) app = Dash(__name__, backend=backend) - app.layout = html.Div([ - html.Button(id="btn", children="Error", n_clicks=0), - html.Div(id="output") - ]) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("btn", "n_clicks")) def error_callback(n): @@ -132,16 +133,11 @@ def error_callback(n): assert "ZeroDivisionError" in error0 assert "backend" in error0 and error_msg in error0 + @pytest.mark.parametrize( "backend,fixture,start_server_kwargs, error_msg", [ - ( - "fastapi", - "dash_duo", - {"debug": True, - "reload": False}, - "fastapi.py" - ), + ("fastapi", "dash_duo", {"debug": True, "reload": False}, "fastapi.py"), ( "quart", "dash_duo_mp", @@ -150,17 +146,18 @@ def error_callback(n): "use_reloader": False, "dev_tools_hot_reload": False, }, - "quart.py" + "quart.py", ), - ] + ], ) -def test_backend_error_handling_prune(request, backend, fixture, start_server_kwargs, error_msg): +def test_backend_error_handling_prune( + request, backend, fixture, start_server_kwargs, error_msg +): dash_duo = request.getfixturevalue(fixture) app = Dash(__name__, backend=backend) - app.layout = html.Div([ - html.Button(id="btn", children="Error", n_clicks=0), - html.Div(id="output") - ]) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("btn", "n_clicks")) def error_callback(n): @@ -178,28 +175,35 @@ def error_callback(n): assert "ZeroDivisionError" in error0 assert "dash/backend" not in error0 and error_msg not in error0 + @pytest.mark.parametrize( "backend,fixture,input_value", [ ("fastapi", "dash_duo", "Background FastAPI!"), ("quart", "dash_duo_mp", "Background Quart!"), - ] + ], ) def test_backend_background_callback(request, backend, fixture, input_value): dash_duo = request.getfixturevalue(fixture) import diskcache + cache = diskcache.Cache("./cache") from dash.background_callback import DiskcacheManager - background_callback_manager = DiskcacheManager(cache) + background_callback_manager = DiskcacheManager(cache) - app = Dash(__name__, backend=backend, background_callback_manager=background_callback_manager) - app.layout = html.Div([ - dcc.Input(id="input", value=input_value, type="text"), - html.Div(id="output") - ]) + app = Dash( + __name__, + backend=backend, + background_callback_manager=background_callback_manager, + ) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) - @app.callback(Output("output", "children"), Input("input", "value"), background=True) + @app.callback( + Output("output", "children"), Input("input", "value"), background=True + ) def update_output_bg(value): return f"Background typed: {value}" @@ -207,5 +211,7 @@ def update_output_bg(value): dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") dash_duo.find_element("#input").clear() dash_duo.find_element("#input").send_keys(f"{backend.title()} BG Test") - dash_duo.wait_for_text_to_equal("#output", f"Background typed: {backend.title()} BG Test") + dash_duo.wait_for_text_to_equal( + "#output", f"Background typed: {backend.title()} BG Test" + ) assert dash_duo.get_logs() == [] From 6a34208f92d20cf3a7283407c1ac68528d5c9d8a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:36:28 -0400 Subject: [PATCH 58/74] adjustment to retest backend --- .github/workflows/testing.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 068fe777d1..48bfe0c305 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -17,6 +17,7 @@ jobs: # This output will be 'true' if files in the 'table_related_paths' list changed, 'false' otherwise. table_paths_changed: ${{ steps.filter.outputs.table_related_paths }} background_cb_changed: ${{ steps.filter.outputs.background_paths }} + backend_cb_changed: ${{ steps.filter.outputs.backend_paths }} steps: - name: Checkout repository uses: actions/checkout@v4 @@ -37,6 +38,9 @@ jobs: - 'tests/background_callback/**' - 'tests/async_tests/**' - 'requirements/**' + backend_paths: + - 'dash/backend/**' + - 'tests/backend/**' build: name: Build Dash Package @@ -276,7 +280,7 @@ jobs: needs: [build, changes_filter] if: | (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) || - needs.changes_filter.outputs.backend_tests_changed == 'true' + needs.changes_filter.outputs.backend_cb_changed == 'true' timeout-minutes: 30 runs-on: ubuntu-latest strategy: From 1a2b53124b11b16b014ed941d822377659d01a5a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:44:12 -0400 Subject: [PATCH 59/74] adding missing reqs association --- .github/workflows/testing.yml | 2 +- setup.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 48bfe0c305..be5caf4929 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -331,7 +331,7 @@ jobs: python -m pip install --upgrade pip wheel python -m pip install "setuptools<78.0.0" python -m pip install "selenium==4.32.0" - find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache, fastapi, quart]"' \; + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache,fastapi,quart]"' \; - name: Install Google Chrome run: | diff --git a/setup.py b/setup.py index 7ed781c20d..950bcbe14d 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,9 @@ def read_req_file(req_type): "testing": read_req_file("testing"), "celery": read_req_file("celery"), "diskcache": read_req_file("diskcache"), - "compress": read_req_file("compress") + "compress": read_req_file("compress"), + "fastapi": read_req_file("fastapi"), + "quart": read_req_file("quart"), }, entry_points={ "console_scripts": [ From 465e45e469324a25498f32fc5979ba190205f328 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:59:36 -0400 Subject: [PATCH 60/74] fixing minor linting issues --- dash/backend/fastapi.py | 27 +++++++++++---------------- dash/backend/flask.py | 11 +++++------ dash/backend/quart.py | 7 +++---- dash/dash.py | 1 - 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index a76a5a47ec..8c402cb187 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -7,11 +7,12 @@ import importlib.util import time import traceback +import re try: import uvicorn from fastapi import FastAPI, Request, Response - from fastapi.responses import JSONResponse, PlainTextResponse + from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders @@ -23,7 +24,6 @@ Request = None Response = None JSONResponse = None - PlainTextResponse = None StaticFiles = None StarletteResponse = None MutableHeaders = None @@ -31,20 +31,17 @@ Any = None Optional = None + +import json +import os from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import ( PreventUpdate, - InvalidResourceError, - InvalidCallbackReturnValue, - BackgroundCallbackError, ) from dash.backend import set_request_adapter from .base_server import BaseDashServer -import json -import os - CONFIG_PATH = "dash_config.json" @@ -93,10 +90,8 @@ def register_assets_blueprint( def register_error_handlers(self, app): self.error_handling_mode = "prune" - # FastAPI uses exception handlers, but we will handle errors in middleware - pass - def _get_traceback(self, secret, error: Exception): + def _get_traceback(self, _secret, error: Exception): tb = error.__traceback__ errors = traceback.format_exception(type(error), error, tb) pass_errs = [] @@ -113,15 +108,13 @@ def _get_traceback(self, secret, error: Exception): error_msg = str(error) # Parse traceback lines to group by file - import re - file_cards = [] pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') lines = formatted_tb.split("\n") current_file = None card_lines = [] - for i, line in enumerate(lines[:-1]): # Skip the last line (error message) + for line in lines[:-1]: # Skip the last line (error message) match = pattern.match(line) if match: if current_file and card_lines: @@ -274,7 +267,9 @@ def run(self, dash_app, app, host, port, debug, **kwargs): frame = inspect.stack()[2] config = dict( {"debug": debug} if debug else {}, - **{f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items()}, + **{ + f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items() + }, # pylint: disable=protected-access ) save_config(config) if debug: @@ -307,7 +302,7 @@ def jsonify(self, obj): def get_request_adapter(self): return FastAPIRequestAdapter - def _make_before_middleware(self, func): + def _make_before_middleware(self, _func): async def middleware(request, call_next): try: response = await call_next(request) diff --git a/dash/backend/flask.py b/dash/backend/flask.py index 542da93129..cf544ef5bc 100644 --- a/dash/backend/flask.py +++ b/dash/backend/flask.py @@ -5,13 +5,14 @@ import mimetypes import time import inspect +import traceback import flask from dash.fingerprint import check_fingerprint from dash import _validate +from dash._callback import _invoke_callback, _async_invoke_callback from dash.exceptions import PreventUpdate, InvalidResourceError from dash.backend import set_request_adapter from .base_server import BaseDashServer -import traceback class FlaskDashServer(BaseDashServer): @@ -47,13 +48,13 @@ def _invalid_resources_handler(err): def _get_traceback(self, secret, error: Exception): try: - from werkzeug.debug import tbtools + from werkzeug.debug import ( + tbtools, + ) # pylint: disable=import-outside-toplevel except ImportError: tbtools = None def _get_skip(error): - from dash._callback import _invoke_callback, _async_invoke_callback - tb = error.__traceback__ skip = 1 while tb.tb_next is not None: @@ -67,8 +68,6 @@ def _get_skip(error): return skip def _do_skip(error): - from dash._callback import _invoke_callback, _async_invoke_callback - tb = error.__traceback__ while tb.tb_next is not None: if tb.tb_frame.f_code in [ diff --git a/dash/backend/quart.py b/dash/backend/quart.py index 71a2053a61..830d7dd3b9 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -5,6 +5,7 @@ import time from contextvars import copy_context import traceback +import re try: import quart @@ -56,7 +57,7 @@ def register_assets_blueprint( ) app.register_blueprint(bp) - def _get_traceback(self, secret, error: Exception): + def _get_traceback(self, _secret, error: Exception): tb = error.__traceback__ errors = traceback.format_exception(type(error), error, tb) pass_errs = [] @@ -73,15 +74,13 @@ def _get_traceback(self, secret, error: Exception): error_msg = str(error) # Parse traceback lines to group by file - import re - file_cards = [] pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') lines = formatted_tb.split("\n") current_file = None card_lines = [] - for i, line in enumerate(lines[:-1]): # Skip the last line (error message) + for line in lines[:-1]: # Skip the last line (error message) match = pattern.match(line) if match: if current_file and card_lines: diff --git a/dash/dash.py b/dash/dash.py index 994453f4a2..6bba3aadfd 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -14,7 +14,6 @@ import mimetypes import hashlib import base64 -import traceback from urllib.parse import urlparse from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List From c43a5835d78dff075e80e8df420f53aa9c37e18c Mon Sep 17 00:00:00 2001 From: chgiesse <83552131+chgiesse@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:16:39 +0200 Subject: [PATCH 61/74] Add global Request Adapter (#6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ∙ - remove contextvar from flask and quart only FastApi now relies on that ∙ - backend __init__ now holds the global request adapter and backend which get set on app initialisation ∙ request adapter and server can now be call from everywhere after the app initialised ∙ - added normal top level imports because the modules get matching loaded - but bad Import Error message when quart or equivilent are not installed ∙ - added _ as prefix to backends to avoid importing errors with their underlying ∙ - Can now move to remove unnecessary passing of the server object ∙ * Moved get_server_type to backends * ∙ moved async validation to validation ∙ replaced request.get_path with request.path ∙ * Moved custom backend check to _validation.py * Removed server injection of server methods - they use self.server now * removed use_async from dispatch server methods and use dash_app._use_async removed remaining set request process from flask * adding custom error handling per backend, tests and adjustments to the flow. Made endpoints for downloading the reqs * adjusments for formatting * adjustment to retest backend * Added Dash app as type to servers * adding missing reqs association * Addedd basic typing to servers * fixing minor linting issues * Fixed weird AI shit * Cleanup before heavy pull * Merged latest changes * f rebase * f rebase * Added Dash app as type to servers * Addedd basic typing to servers --------- Co-authored-by: Christian Giessel Co-authored-by: BSd3v <82055130+BSd3v@users.noreply.github.com> --- dash/_callback.py | 41 +-- dash/_pages.py | 32 ++- dash/_validate.py | 39 +++ dash/backend/__init__.py | 15 - dash/backend/base_server.py | 58 ---- dash/backend/registry.py | 29 -- dash/backends/__init__.py | 88 ++++++ .../fastapi.py => backends/_fastapi.py} | 264 +++++++++++------- dash/{backend/flask.py => backends/_flask.py} | 264 ++++++++++-------- dash/{backend/quart.py => backends/_quart.py} | 240 +++++++++------- dash/backends/base_server.py | 119 ++++++++ dash/dash.py | 210 +++++--------- dash_config.json | 1 + quart_app.py | 23 ++ 14 files changed, 831 insertions(+), 592 deletions(-) delete mode 100644 dash/backend/__init__.py delete mode 100644 dash/backend/base_server.py delete mode 100644 dash/backend/registry.py create mode 100644 dash/backends/__init__.py rename dash/{backend/fastapi.py => backends/_fastapi.py} (72%) rename dash/{backend/flask.py => backends/_flask.py} (55%) rename dash/{backend/quart.py => backends/_quart.py} (68%) create mode 100644 dash/backends/base_server.py create mode 100644 dash_config.json create mode 100644 quart_app.py diff --git a/dash/_callback.py b/dash/_callback.py index 6cc55b9162..4a714caeac 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,12 +1,8 @@ +from typing import Callable, Optional, Any, List, Tuple, Union +from functools import wraps import collections import hashlib -from functools import wraps - -from typing import Callable, Optional, Any, List, Tuple, Union - - import asyncio -from dash.backend import get_request_adapter from .dependencies import ( handle_callback_args, @@ -39,10 +35,11 @@ clean_property_name, ) -from . import _validate from .background_callback.managers import BaseBackgroundCallbackManager from ._callback_context import context_value from ._no_update import NoUpdate +from . import _validate +from . import backends async def _async_invoke_callback( @@ -176,7 +173,6 @@ def callback( Note that the endpoint will not appear in the list of registered callbacks in the Dash devtools. """ - background_spec = None config_prevent_initial_callbacks = _kwargs.pop( @@ -376,7 +372,8 @@ def _get_callback_manager( " and store results on redis.\n" ) - old_job = get_request_adapter().get_args().getlist("oldJob") + adapter = backends.request_adapter() + old_job = adapter.args.getlist("oldJob") if hasattr(adapter.args, "getlist") else [] if old_job: for job in old_job: @@ -390,6 +387,8 @@ def _setup_background_callback( ): """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) + if not callback_manager: + return to_json({"error": "No background callback manager configured"}) progress_outputs = background.get("progress") @@ -397,14 +396,11 @@ def _setup_background_callback( cache_key = callback_manager.build_cache_key( func, - # Inputs provided as dict is kwargs. func_args if func_args else func_kwargs, background.get("cache_args_to_ignore", []), None if cache_ignore_triggered else callback_ctx.get("triggered_inputs", []), ) - job_fn = callback_manager.func_registry.get(background_key) - ctx_value = AttributeDict(**context_value.get()) ctx_value.ignore_register_page = True ctx_value.pop("background_callback_manager") @@ -436,7 +432,8 @@ def _setup_background_callback( def _progress_background_callback(response, callback_manager, background): progress_outputs = background.get("progress") - cache_key = get_request_adapter().get_args().get("cacheKey") + adapter = backends.request_adapter() + cache_key = adapter.args.get("cacheKey") if progress_outputs: # Get the progress before the result as it would be erased after the results. @@ -453,8 +450,9 @@ def _update_background_callback( """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - cache_key = get_request_adapter().get_args().get("cacheKey") - job_id = get_request_adapter().get_args().get("job") + adapter = backends.request_adapter() + cache_key = adapter.args.get("cacheKey") if adapter else None + job_id = adapter.args.get("job") if adapter else None _progress_background_callback(response, callback_manager, background) @@ -474,8 +472,9 @@ def _handle_rest_background_callback( multi, has_update=False, ): - cache_key = get_request_adapter().get_args().get("cacheKey") - job_id = get_request_adapter().get_args().get("job") + adapter = backends.request_adapter() + cache_key = adapter.args.get("cacheKey") if adapter else None + job_id = adapter.args.get("job") if adapter else None # Must get job_running after get_result since get_results terminates it. job_running = callback_manager.job_running(job_id) if not job_running and output_value is callback_manager.UNDEFINED: @@ -688,10 +687,11 @@ def add_context(*args, **kwargs): ) response: dict = {"multi": True} - jsonResponse = None + jsonResponse: Optional[str] = None try: if background is not None: - if not get_request_adapter().get_args().get("cacheKey"): + adapter = backends.request_adapter() + if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, background, @@ -762,7 +762,8 @@ async def async_add_context(*args, **kwargs): try: if background is not None: - if not get_request_adapter().get_args().get("cacheKey"): + adapter = backends.request_adapter() + if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, background, diff --git a/dash/_pages.py b/dash/_pages.py index acb26e8791..19a797bcf2 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -318,18 +318,22 @@ def register_page( ) page.update( supplied_title=title, - title=title - if title is not None - else CONFIG.title - if CONFIG.title != "Dash" - else page["name"], + title=( + title + if title is not None + else CONFIG.title + if CONFIG.title != "Dash" + else page["name"] + ), ) page.update( - description=description - if description - else CONFIG.description - if CONFIG.description - else "", + description=( + description + if description + else CONFIG.description + if CONFIG.description + else "" + ), order=order, supplied_order=order, supplied_layout=layout, @@ -390,15 +394,13 @@ def _path_to_page(path_id): def _page_meta_tags(app, request): - request_path = request.get_path() + request_path = request.path start_page, path_variables = _path_to_page(request_path.strip("/")) image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = ( - "".join([request.get_root(), image.lstrip("/")]) if image else None - ) + assets_image_url = "".join([request.root, image.lstrip("/")]) if image else None supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url @@ -413,7 +415,7 @@ def _page_meta_tags(app, request): return [ {"name": "description", "content": description}, {"property": "twitter:card", "content": "summary_large_image"}, - {"property": "twitter:url", "content": request.get_url()}, + {"property": "twitter:url", "content": request.url}, {"property": "twitter:title", "content": title}, {"property": "twitter:description", "content": description}, {"property": "twitter:image", "content": image_url or ""}, diff --git a/dash/_validate.py b/dash/_validate.py index dea19d64c2..76661cef6b 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -8,6 +8,7 @@ from ._grouping import grouping_len, map_grouping from ._no_update import NoUpdate from .development.base_component import Component +from . import backends from . import exceptions from ._utils import ( patch_collections_abc, @@ -585,3 +586,41 @@ def _valid(out): return _valid(output) + + +def check_async(use_async): + if use_async is None: + try: + import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + + use_async = True + except ImportError: + pass + elif use_async: + try: + import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + except ImportError as exc: + raise Exception( + "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" + ) from exc + + +def check_backend(backend, inferred_backend): + if backend is not None: + if isinstance(backend, type): + # get_backend returns the backend class for a string + # So we compare the class names + expected_backend_cls, _ = backends.get_backend(inferred_backend) + if ( + backend.__module__ != expected_backend_cls.__module__ + or backend.__name__ != expected_backend_cls.__name__ + ): + raise ValueError( + f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." + ) + elif not isinstance(backend, str): + raise ValueError("Invalid backend argument") + elif backend.lower() != inferred_backend: + raise ValueError( + f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." + ) diff --git a/dash/backend/__init__.py b/dash/backend/__init__.py deleted file mode 100644 index eb1d47bc3f..0000000000 --- a/dash/backend/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# python -import contextvars -from .registry import get_backend # pylint: disable=unused-import - -__all__ = ["set_request_adapter", "get_request_adapter", "get_backend"] - -_request_adapter_var = contextvars.ContextVar("request_adapter") - - -def set_request_adapter(adapter): - _request_adapter_var.set(adapter) - - -def get_request_adapter(): - return _request_adapter_var.get() diff --git a/dash/backend/base_server.py b/dash/backend/base_server.py deleted file mode 100644 index 4855f86ad6..0000000000 --- a/dash/backend/base_server.py +++ /dev/null @@ -1,58 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class BaseDashServer(ABC): - def __call__(self, server, *args, **kwargs) -> Any: - # Default: WSGI - return server(*args, **kwargs) - - @abstractmethod - def create_app( - self, name: str = "__main__", config=None - ) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def register_assets_blueprint( - self, app, blueprint_name: str, assets_url_path: str, assets_folder: str - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def register_error_handlers(self, app) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def add_url_rule( - self, app, rule: str, view_func, endpoint=None, methods=None - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def before_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def after_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def run( - self, app, host: str, port: int, debug: bool, **kwargs - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def make_response( - self, data, mimetype=None, content_type=None - ) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def jsonify(self, obj) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def get_request_adapter(self) -> Any: # pragma: no cover - interface - pass diff --git a/dash/backend/registry.py b/dash/backend/registry.py deleted file mode 100644 index 4aae9fafc5..0000000000 --- a/dash/backend/registry.py +++ /dev/null @@ -1,29 +0,0 @@ -import importlib - -_backend_imports = { - "flask": ("dash.backend.flask", "FlaskDashServer"), - "fastapi": ("dash.backend.fastapi", "FastAPIDashServer"), - "quart": ("dash.backend.quart", "QuartDashServer"), -} - - -def register_backend(name, module_path, class_name): - """Register a new backend by name.""" - _backend_imports[name.lower()] = (module_path, class_name) - - -def get_backend(name): - try: - module_name, class_name = _backend_imports[name.lower()] - module = importlib.import_module(module_name) - return getattr(module, class_name) - except KeyError as e: - raise ValueError(f"Unknown backend: {name}") from e - except ImportError as e: - raise ImportError( - f"Could not import module '{module_name}' for backend '{name}': {e}" - ) from e - except AttributeError as e: - raise AttributeError( - f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}" - ) from e diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py new file mode 100644 index 0000000000..940c8f18bd --- /dev/null +++ b/dash/backends/__init__.py @@ -0,0 +1,88 @@ +from .base_server import BaseDashServer, RequestAdapter + +from typing import Literal, Any +import importlib + + +request_adapter: RequestAdapter +backend: BaseDashServer + + +_backend_imports = { + "flask": ("dash.backends._flask", "FlaskDashServer", "FlaskRequestAdapter"), + "fastapi": ("dash.backends._fastapi", "FastAPIDashServer", "FastAPIRequestAdapter"), + "quart": ("dash.backends._quart", "QuartDashServer", "QuartRequestAdapter"), +} + + +request_adapter: RequestAdapter +backend: BaseDashServer + + +def get_backend( + name: Literal["flask", "fastapi", "quart"] | str +) -> tuple[BaseDashServer, RequestAdapter]: + module_name, server_class, request_class = _backend_imports[name.lower()] + try: + module = importlib.import_module(module_name) + server = getattr(module, server_class) + request_adapter = getattr(module, request_class) + return server, request_adapter + except KeyError as e: + raise ValueError(f"Unknown backend: {name}") from e + except ImportError as e: + raise ImportError( + f"Could not import module '{module_name}' for backend '{name}': {e}" + ) from e + except AttributeError as e: + raise AttributeError( + f"Module '{module_name}' does not have class '{server_class}' for backend '{name}': {e}" + ) from e + + +def _is_flask_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from flask import Flask + + return isinstance(obj, Flask) + except ImportError: + return False + + +def _is_fastapi_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from fastapi import FastAPI + + return isinstance(obj, FastAPI) + except ImportError: + return False + + +def _is_quart_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from quart import Quart + + return isinstance(obj, Quart) + except ImportError: + return False + + +def get_server_type(server): + if _is_flask_instance(server): + return "flask" + if _is_quart_instance(server): + return "quart" + if _is_fastapi_instance(server): + return "fastapi" + raise ValueError("Invalid backend argument") + + +__all__ = [ + "get_backend", + "request_adapter", + "backend", + "get_server_type", +] diff --git a/dash/backend/fastapi.py b/dash/backends/_fastapi.py similarity index 72% rename from dash/backend/fastapi.py rename to dash/backends/_fastapi.py index 8c402cb187..f3f9f2df33 100644 --- a/dash/backend/fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,46 +1,71 @@ +from __future__ import annotations + +from contextvars import copy_context, ContextVar +from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes import hashlib import inspect import pkgutil -from contextvars import copy_context -import importlib.util import time import traceback -import re - -try: - import uvicorn - from fastapi import FastAPI, Request, Response - from fastapi.responses import JSONResponse - from fastapi.staticfiles import StaticFiles - from starlette.responses import Response as StarletteResponse - from starlette.datastructures import MutableHeaders - from pydantic import create_model - from typing import Any, Optional -except ImportError: - uvicorn = None - FastAPI = None - Request = None - Response = None - JSONResponse = None - StaticFiles = None - StarletteResponse = None - MutableHeaders = None - create_model = None - Any = None - Optional = None - - +from importlib.util import spec_from_file_location import json import os +import re + from dash.fingerprint import check_fingerprint from dash import _validate -from dash.exceptions import ( - PreventUpdate, -) -from dash.backend import set_request_adapter -from .base_server import BaseDashServer +from dash.exceptions import PreventUpdate +from .base_server import BaseDashServer, RequestAdapter + +from fastapi import FastAPI, Request, Response, Body +from fastapi.responses import JSONResponse +from fastapi.staticfiles import StaticFiles +from starlette.responses import Response as StarletteResponse +from starlette.datastructures import MutableHeaders +from starlette.types import ASGIApp, Scope, Receive, Send +import uvicorn + +if TYPE_CHECKING: # pragma: no cover - typing only + from dash.dash import Dash + + +_current_request_var = ContextVar("dash_current_request", default=None) + + +def set_current_request(req): + return _current_request_var.set(req) + + +def reset_current_request(token): + _current_request_var.reset(token) + + +def get_current_request() -> Request: + req = _current_request_var.get() + if req is None: + raise RuntimeError("No active request in context") + return req + + +class CurrentRequestMiddleware: + def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] + # non-http/ws scopes pass through (lifespan etc.) + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + request = Request(scope, receive=receive) + token = set_current_request(request) + try: + await self.app(scope, receive, send) + finally: + reset_current_request(token) + CONFIG_PATH = "dash_config.json" @@ -58,28 +83,35 @@ def load_config(): class FastAPIDashServer(BaseDashServer): - def __init__(self): + + def __init__(self, server: FastAPI): + self.config = {} + self.server_type = "fastapi" + self.server: FastAPI = server self.error_handling_mode = "prune" super().__init__() - def __call__(self, server, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any): # ASGI: (scope, receive, send) if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: - return server(*args, **kwargs) + return self.server(*args, **kwargs) raise TypeError("FastAPI app must be called with (scope, receive, send)") - def create_app(self, name="__main__", config=None): + @staticmethod + def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): app = FastAPI() + app.add_middleware(CurrentRequestMiddleware) + if config: for key, value in config.items(): setattr(app.state, key, value) return app def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder + self, blueprint_name: str, assets_url_path: str, assets_folder: str ): try: - app.mount( + self.server.mount( assets_url_path, StaticFiles(directory=assets_folder), name=blueprint_name, @@ -88,7 +120,7 @@ def register_assets_blueprint( # directory doesnt exist pass - def register_error_handlers(self, app): + def register_error_handlers(self): self.error_handling_mode = "prune" def _get_traceback(self, _secret, error: Exception): @@ -200,13 +232,13 @@ def _get_traceback(self, _secret, error: Exception): """ return html - def register_prune_error_handler(self, _app, _secret, prune_errors): + def register_prune_error_handler(self, _secret, prune_errors): if prune_errors: self.error_handling_mode = "prune" else: self.error_handling_mode = "raise" - def _html_response_wrapper(self, view_func): + def _html_response_wrapper(self, view_func: Callable[..., Any] | str): async def wrapped(*_args, **_kwargs): # If view_func is a function, call it; if it's a string, use it directly html = view_func() if callable(view_func) else view_func @@ -214,40 +246,40 @@ async def wrapped(*_args, **_kwargs): return wrapped - def setup_index(self, dash_app): + def setup_index(self, dash_app: Dash): async def index(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) - def setup_catchall(self, dash_app): - @dash_app.server.on_event("startup") + def setup_catchall(self, dash_app: Dash): + @self.server.on_event("startup") def _setup_catchall(): - config = load_config() - dash_app.enable_dev_tools(**config, first_run=False) + dash_app.enable_dev_tools( + **self.config, first_run=False + ) # do this to make sure dev tools are enabled async def catchall(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access dash_app._add_url("{path:path}", catchall, methods=["GET"]) def add_url_rule( - self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False + self, + rule: str, + view_func: Callable[..., Any] | str, + endpoint: str | None = None, + methods: list[str] | None = None, + include_in_schema: bool = False, ): if rule == "": rule = "/" if isinstance(view_func, str): # Wrap string or sync function to async FastAPI handler view_func = self._html_response_wrapper(view_func) - app.add_api_route( + self.server.add_api_route( rule, view_func, methods=methods or ["GET"], @@ -255,15 +287,15 @@ def add_url_rule( include_in_schema=include_in_schema, ) - def before_request(self, app, func): + def before_request(self, func: Callable[[], Any] | None): # FastAPI does not have before_request, but we can use middleware - app.middleware("http")(self._make_before_middleware(func)) + self.server.middleware("http")(self._make_before_middleware(func)) - def after_request(self, app, func): + def after_request(self, func: Callable[[], Any] | None): # FastAPI does not have after_request, but we can use middleware - app.middleware("http")(self._make_after_middleware(func)) + self.server.middleware("http")(self._make_after_middleware(func)) - def run(self, dash_app, app, host, port, debug, **kwargs): + def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] config = dict( {"debug": debug} if debug else {}, @@ -278,7 +310,8 @@ def run(self, dash_app, app, host, port, debug, **kwargs): if kwargs.get("reload"): # Dynamically determine the module name from the file path file_path = frame.filename - module_name = importlib.util.spec_from_file_location("app", file_path).name + spec = spec_from_file_location("app", file_path) + module_name = spec.name if spec and getattr(spec, "name", None) else "app" uvicorn.run( f"{module_name}:app.server", host=host, @@ -286,9 +319,14 @@ def run(self, dash_app, app, host, port, debug, **kwargs): **kwargs, ) else: - uvicorn.run(app, host=host, port=port, **kwargs) + uvicorn.run(self.server, host=host, port=port, **kwargs) - def make_response(self, data, mimetype=None, content_type=None): + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + ): headers = {} if mimetype: headers["content-type"] = mimetype @@ -296,13 +334,10 @@ def make_response(self, data, mimetype=None, content_type=None): headers["content-type"] = content_type return Response(content=data, headers=headers) - def jsonify(self, obj): + def jsonify(self, obj: Any): return JSONResponse(content=obj) - def get_request_adapter(self): - return FastAPIRequestAdapter - - def _make_before_middleware(self, _func): + def _make_before_middleware(self, func: Callable[[], Any] | None): async def middleware(request, call_next): try: response = await call_next(request) @@ -322,7 +357,7 @@ async def middleware(request, call_next): return middleware - def _make_after_middleware(self, func): + def _make_after_middleware(self, func: Callable[[], Any] | None): async def middleware(request, call_next): response = await call_next(request) if func is not None: @@ -335,8 +370,13 @@ async def middleware(request, call_next): return middleware def serve_component_suites( - self, dash_app, package_name, fingerprinted_path, request + self, + dash_app: Dash, + package_name: str, + fingerprinted_path: str, + request: Request, ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -360,7 +400,7 @@ def serve_component_suites( return StarletteResponse(status_code=304) return StarletteResponse(content=data, media_type=mimetype, headers=headers) - def setup_component_suites(self, dash_app): + def setup_component_suites(self, dash_app: Dash): async def serve(request: Request, package_name: str, fingerprinted_path: str): return self.serve_component_suites( dash_app, package_name, fingerprinted_path, request @@ -373,16 +413,12 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): ) # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): + def dispatch(self, dash_app: Dash): + async def _dispatch(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) # pylint: disable=protected-access body = await request.json() - g = dash_app._initialize_context( - body, adapter - ) # pylint: disable=protected-access + g = dash_app._initialize_context(body) # pylint: disable=protected-access func = dash_app._prepare_callback( g, body ) # pylint: disable=protected-access @@ -406,12 +442,12 @@ def _serve_default_favicon(self): content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" ) - def register_timing_hooks(self, app, first_run): + def register_timing_hooks(self, first_run: bool): if not first_run: return - @app.middleware("http") - async def timing_middleware(request, call_next): + @self.server.middleware("http") + async def timing_middleware(request: Request, call_next): # Before request request.state.timing_information = { "__dash_server": {"dur": time.time(), "desc": None} @@ -433,11 +469,11 @@ async def timing_middleware(request, call_next): headers.append("Server-Timing", value) return response - def register_callback_api_routes(self, app, callback_api_paths): + def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): """ Register callback API endpoints on the FastAPI app. Each key in callback_api_paths is a route, each value is a handler (sync or async). - Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter. + Accepts a JSON body (dict) and filters keys based on the handler's signature. """ for path, handler in callback_api_paths.items(): endpoint = f"dash_callback_api_{path}" @@ -445,21 +481,19 @@ def register_callback_api_routes(self, app, callback_api_paths): methods = ["POST"] sig = inspect.signature(handler) param_names = list(sig.parameters.keys()) - fields = {name: (Optional[Any], None) for name in param_names} - Model = create_model( - f"Payload_{endpoint}", **fields - ) # pylint: disable=cell-var-from-loop - - # pylint: disable=cell-var-from-loop - async def view_func(request: Request, body: Model): - kwargs = body.dict(exclude_unset=True) + + async def view_func(request: Request, body: dict = Body(...)): + # Only pass expected params; ignore extras + kwargs = { + k: v for k, v in body.items() if k in param_names and v is not None + } if inspect.iscoroutinefunction(handler): result = await handler(**kwargs) else: result = handler(**kwargs) return JSONResponse(content=result) - app.add_api_route( + self.server.add_api_route( route, view_func, methods=methods, @@ -468,44 +502,58 @@ async def view_func(request: Request, body: Model): ) -class FastAPIRequestAdapter: +class FastAPIRequestAdapter(RequestAdapter): + def __init__(self): - self._request = None + self._request: Request = get_current_request() + super().__init__() - def set_request(self, request: Request): - self._request = request + def __call__(self): + self._request = get_current_request() + return self - def get_root(self): + @property + def root(self): return str(self._request.base_url) - def get_args(self): + @property + def args(self): return self._request.query_params - async def get_json(self): - return await self._request.json() - + @property def is_json(self): return self._request.headers.get("content-type", "").startswith( "application/json" ) - def get_cookies(self, _request=None): + @property + def cookies(self): return self._request.cookies - def get_headers(self): + @property + def headers(self): return self._request.headers - def get_full_path(self): + @property + def full_path(self): return str(self._request.url) - def get_url(self): + @property + def url(self): return str(self._request.url) - def get_remote_addr(self): - return self._request.client.host if self._request.client else None + @property + def remote_addr(self): + client = getattr(self._request, "client", None) + return getattr(client, "host", None) - def get_origin(self): + @property + def origin(self): return self._request.headers.get("origin") - def get_path(self): + @property + def path(self): return self._request.url.path + + async def get_json(self): # async method retained + return await self._request.json() diff --git a/dash/backend/flask.py b/dash/backends/_flask.py similarity index 55% rename from dash/backend/flask.py rename to dash/backends/_flask.py index cf544ef5bc..5a1385d574 100644 --- a/dash/backend/flask.py +++ b/dash/backends/_flask.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from contextvars import copy_context +from typing import TYPE_CHECKING, Any, Callable, Dict import asyncio import pkgutil import sys @@ -6,43 +9,60 @@ import time import inspect import traceback -import flask +from flask import ( + Flask, + Blueprint, + Response, + request, + jsonify, + g as flask_g, +) + from dash.fingerprint import check_fingerprint from dash import _validate -from dash._callback import _invoke_callback, _async_invoke_callback from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.backend import set_request_adapter -from .base_server import BaseDashServer +from dash._callback import _invoke_callback, _async_invoke_callback +from .base_server import BaseDashServer, RequestAdapter + +if TYPE_CHECKING: # pragma: no cover - typing only + from dash import Dash class FlaskDashServer(BaseDashServer): - def __call__(self, server, *args, **kwargs): + + def __init__(self, server: Flask) -> None: + self.server: Flask = server + self.server_type = "flask" + super().__init__() + + def __call__(self, *args: Any, **kwargs: Any): # Always WSGI - return server(*args, **kwargs) + return self.server(*args, **kwargs) - def create_app(self, name="__main__", config=None): - app = flask.Flask(name) + @staticmethod + def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): + app = Flask(name) if config: app.config.update(config) return app def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder + self, blueprint_name: str, assets_url_path: str, assets_folder: str ): - bp = flask.Blueprint( + bp = Blueprint( blueprint_name, __name__, static_folder=assets_folder, static_url_path=assets_url_path, ) - app.register_blueprint(bp) + self.server.register_blueprint(bp) - def register_error_handlers(self, app): - @app.errorhandler(PreventUpdate) + def register_error_handlers(self): + @self.server.errorhandler(PreventUpdate) def _handle_error(_): return "", 204 - @app.errorhandler(InvalidResourceError) + @self.server.errorhandler(InvalidResourceError) def _invalid_resources_handler(err): return err.args[0], 404 @@ -86,56 +106,64 @@ def _do_skip(error): ).render_debugger_html(True, secret, True) return "".join(traceback.format_exception(type(error), error, _do_skip(error))) - def register_prune_error_handler(self, app, secret, prune_errors): + def register_prune_error_handler(self, secret, prune_errors): if prune_errors: - @app.errorhandler(Exception) + @self.server.errorhandler(Exception) def _wrap_errors(error): tb = self._get_traceback(secret, error) return tb, 500 - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule( + def add_url_rule( + self, + rule: str, + view_func: Callable[..., Any], + endpoint: str | None = None, + methods: list[str] | None = None, + ): + self.server.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) - def before_request(self, app, func): - app.before_request(func) - - def after_request(self, app, func): - app.after_request(func) + def before_request(self, func: Callable[[], Any]): + # Flask expects a callable; user responsibility not to pass None + self.server.before_request(func) - def run(self, _dash_app, app, host, port, debug, **kwargs): - app.run(host=host, port=port, debug=debug, **kwargs) + def after_request(self, func: Callable[[Any], Any]): + # Flask after_request expects a function(response) -> response + self.server.after_request(func) - def make_response(self, data, mimetype=None, content_type=None): - return flask.Response(data, mimetype=mimetype, content_type=content_type) + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: Any): + self.server.run(host=host, port=port, debug=debug, **kwargs) - def jsonify(self, obj): - return flask.jsonify(obj) + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + ): + return Response(data, mimetype=mimetype, content_type=content_type) - def get_request_adapter(self): - return FlaskRequestAdapter + def jsonify(self, obj: Any): + return jsonify(obj) - def setup_catchall(self, dash_app): + def setup_catchall(self, dash_app: Dash): def catchall(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) return dash_app.index(*args, **kwargs) # pylint: disable=protected-access dash_app._add_url("", catchall, methods=["GET"]) - def setup_index(self, dash_app): + def setup_index(self, dash_app: Dash): def index(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) return dash_app.index(*args, **kwargs) # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) - def serve_component_suites(self, dash_app, package_name, fingerprinted_path): + def serve_component_suites( + self, dash_app: Dash, package_name: str, fingerprinted_path: str + ): path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -149,18 +177,18 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path): package.__path__, ) data = pkgutil.get_data(package_name, path_in_pkg) - response = flask.Response(data, mimetype=mimetype) + response = Response(data, mimetype=mimetype) if has_fingerprint: response.cache_control.max_age = 31536000 # 1 year else: response.add_etag() tag = response.get_etag()[0] - request_etag = flask.request.headers.get("If-None-Match") + request_etag = request.headers.get("If-None-Match") if f'"{tag}"' == request_etag: - response = flask.Response(None, status=304) + response = Response(None, status=304) return response - def setup_component_suites(self, dash_app): + def setup_component_suites(self, dash_app: Dash): def serve(package_name, fingerprinted_path): return self.serve_component_suites( dash_app, package_name, fingerprinted_path @@ -173,17 +201,15 @@ def serve(package_name, fingerprinted_path): ) # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): + def dispatch(self, dash_app: Dash): def _dispatch(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() + body = request.get_json() # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + cb_ctx = dash_app._initialize_context(body) + func = dash_app._prepare_callback(cb_ctx, body) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) response_data = ctx.run(partial_func) if asyncio.iscoroutine(response_data): raise Exception( @@ -191,43 +217,41 @@ def _dispatch(): "Please install the dependencies via `pip install dash[async]` and ensure " "that `use_async=False` is not being passed to the app." ) - g.dash_response.set_data(response_data) - return g.dash_response + cb_ctx.dash_response.set_data(response_data) + return cb_ctx.dash_response async def _dispatch_async(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() + body = request.get_json() # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + cb_ctx = dash_app._initialize_context(body) + func = dash_app._prepare_callback(cb_ctx, body) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) response_data = ctx.run(partial_func) if asyncio.iscoroutine(response_data): response_data = await response_data - g.dash_response.set_data(response_data) - return g.dash_response + cb_ctx.dash_response.set_data(response_data) + return cb_ctx.dash_response - if use_async: + if dash_app._use_async: return _dispatch_async return _dispatch def _serve_default_favicon(self): - - return flask.Response( + return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) - def register_timing_hooks(self, app, _first_run): - def _before_request(): - flask.g.timing_information = { + def register_timing_hooks(self, _first_run: bool): + # Define timing hooks inside method scope and register them + def _before_request() -> None: + flask_g.timing_information = { # type: ignore[attr-defined] "__dash_server": {"dur": time.time(), "desc": None} } - def _after_request(response): - timing_information = flask.g.get("timing_information", None) + def _after_request(response: Response): # type: ignore[name-defined] + timing_information = flask_g.get("timing_information", None) # type: ignore[attr-defined] if timing_information is None: return response dash_total = timing_information.get("__dash_server", None) @@ -242,10 +266,10 @@ def _after_request(response): response.headers.add("Server-Timing", value) return response - self.before_request(app, _before_request) - self.after_request(app, _after_request) + self.before_request(_before_request) + self.after_request(_after_request) - def register_callback_api_routes(self, app, callback_api_paths): + def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): """ Register callback API endpoints on the Flask app. Each key in callback_api_paths is a route, each value is a handler (sync or async). @@ -258,65 +282,79 @@ def register_callback_api_routes(self, app, callback_api_paths): if inspect.iscoroutinefunction(handler): - async def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() + async def _async_view_func(*args, handler=handler, **kwargs): + data = request.get_json() result = await handler(**data) if data else await handler() - return flask.jsonify(result) + return jsonify(result) + view_func = _async_view_func else: - def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() + def _sync_view_func(*args, handler=handler, **kwargs): + data = request.get_json() result = handler(**data) if data else handler() - return flask.jsonify(result) + return jsonify(result) + + view_func = _sync_view_func + + view_func = _sync_view_func # Flask 2.x+ supports async views natively - app.add_url_rule( + self.server.add_url_rule( route, endpoint=endpoint, view_func=view_func, methods=methods ) -class FlaskRequestAdapter: - @staticmethod - def get_args(): - return flask.request.args +class FlaskRequestAdapter(RequestAdapter): + """Flask implementation using property-based accessors.""" - @staticmethod - def get_root(): - return flask.request.url_root + def __init__(self) -> None: + # Store the request LocalProxy so we can reference it consistently + self._request = request + super().__init__() - @staticmethod - def get_json(): - return flask.request.get_json() + def __call__(self, *args: Any, **kwds: Any): + return self - @staticmethod - def is_json(): - return flask.request.is_json + @property + def args(self): + return self._request.args - @staticmethod - def get_cookies(): - return flask.request.cookies + @property + def root(self): + return self._request.url_root - @staticmethod - def get_headers(): - return flask.request.headers + def get_json(self): # kept as method + return self._request.get_json() - @staticmethod - def get_url(): - return flask.request.url + @property + def is_json(self): + return self._request.is_json - @staticmethod - def get_full_path(): - return flask.request.full_path + @property + def cookies(self): + return self._request.cookies - @staticmethod - def get_remote_addr(): - return flask.request.remote_addr + @property + def headers(self): + return self._request.headers - @staticmethod - def get_origin(): - return getattr(flask.request, "origin", None) + @property + def url(self): + return self._request.url - @staticmethod - def get_path(): - return flask.request.path + @property + def full_path(self): + return self._request.full_path + + @property + def remote_addr(self): + return self._request.remote_addr + + @property + def origin(self): + return getattr(self._request, "origin", None) + + @property + def path(self): + return self._request.path diff --git a/dash/backend/quart.py b/dash/backends/_quart.py similarity index 68% rename from dash/backend/quart.py rename to dash/backends/_quart.py index 830d7dd3b9..a462d07af6 100644 --- a/dash/backend/quart.py +++ b/dash/backends/_quart.py @@ -1,61 +1,68 @@ +from __future__ import annotations +from contextvars import copy_context +import typing as _t +import traceback +import mimetypes import inspect import pkgutil -import mimetypes -import sys import time -from contextvars import copy_context -import traceback +import sys import re -try: - import quart - from quart import Quart, Response, jsonify, request, Blueprint -except ImportError: - quart = None - Quart = None - Response = None - jsonify = None - request = None - Blueprint = None +# Attempt top-level Quart imports; allow absence if user not using quart backend +from quart import ( + Quart, + Response, + jsonify, + request, + Blueprint, + g, +) + +if _t.TYPE_CHECKING: + from dash import Dash + from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.backend import set_request_adapter from dash.fingerprint import check_fingerprint from dash import _validate from .base_server import BaseDashServer class QuartDashServer(BaseDashServer): - """Quart implementation of the Dash server factory. - - All Quart/async specific imports are at the top-level (per user request) so - Quart must be installed when this module is imported. - """ - def __init__(self) -> None: + def __init__(self, server: Quart) -> None: + self.server_type = "quart" + self.server: Quart = server self.config = {} self.error_handling_mode = "prune" super().__init__() - def __call__(self, server, *args, **kwargs): - return server(*args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] + return self.server(*args, **kwargs) - def create_app(self, name="__main__", config=None): - app = Quart(name) + @staticmethod + def create_app(name: str = "__main__", config: _t.Optional[_t.Dict[str, _t.Any]] = None): + if Quart is None: + raise RuntimeError( + "Quart is not installed. Install with 'pip install quart' to use the quart backend." + ) + app = Quart(name) # type: ignore if config: for key, value in config.items(): app.config[key] = value return app def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder + self, blueprint_name: str, assets_url_path: str, assets_folder: str # type: ignore[name-defined] ): + bp = Blueprint( blueprint_name, __name__, static_folder=assets_folder, static_url_path=assets_url_path, ) - app.register_blueprint(bp) + self.server.register_blueprint(bp) def _get_traceback(self, _secret, error: Exception): tb = error.__traceback__ @@ -166,27 +173,30 @@ def _get_traceback(self, _secret, error: Exception): """ return html - def register_prune_error_handler(self, app, secret, prune_errors): + def register_prune_error_handler(self, secret, prune_errors): if prune_errors: self.error_handling_mode = "prune" else: self.error_handling_mode = "raise" - @app.errorhandler(Exception) + @self.server.errorhandler(Exception) async def _wrap_errors(error): tb = self._get_traceback(secret, error) return Response(tb, status=500, content_type="text/html") - def register_timing_hooks(self, app, _first_run): # parity with Flask factory - @app.before_request + def register_timing_hooks(self, _first_run: bool): # type: ignore[name-defined] parity with Flask factory + @self.server.before_request async def _before_request(): # pragma: no cover - timing infra - quart.g.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } + if g is not None: + g.timing_information = { # type: ignore[attr-defined] + "__dash_server": {"dur": time.time(), "desc": None} + } - @app.after_request + @self.server.after_request async def _after_request(response): # pragma: no cover - timing infra - timing_information = getattr(quart.g, "timing_information", None) + timing_information = ( + getattr(g, "timing_information", None) if g is not None else None + ) if timing_information is None: return response dash_total = timing_information.get("__dash_server", None) @@ -205,16 +215,17 @@ async def _after_request(response): # pragma: no cover - timing infra response.headers["Server-Timing"] = value return response - def register_error_handlers(self, app): - @app.errorhandler(PreventUpdate) + def register_error_handlers(self): # type: ignore[name-defined] + @self.server.errorhandler(PreventUpdate) async def _prevent_update(_): return "", 204 - @app.errorhandler(InvalidResourceError) + @self.server.errorhandler(InvalidResourceError) async def _invalid_resource(err): return err.args[0], 404 - def _html_response_wrapper(self, view_func): + def _html_response_wrapper(self, view_func: _t.Callable[..., _t.Any] | str): + async def wrapped(*_args, **_kwargs): html_val = view_func() if callable(view_func) else view_func if inspect.iscoroutine(html_val): # handle async function returning html @@ -224,38 +235,40 @@ async def wrapped(*_args, **_kwargs): return wrapped - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule( + def add_url_rule( + self, + rule: str, + view_func: _t.Callable[..., _t.Any], + endpoint: str | None = None, + methods: list[str] | None = None, + ): + self.server.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) - def setup_index(self, dash_app): + def setup_index(self, dash_app: Dash): # type: ignore[name-defined] + async def index(*args, **kwargs): - adapter = QuartRequestAdapter() - set_request_adapter(adapter) - adapter.set_request() - return Response(dash_app.index(*args, **kwargs), content_type="text/html") + return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) - def setup_catchall(self, dash_app): + def setup_catchall(self, dash_app: Dash): + async def catchall( - path, *args, **kwargs + path: str, *args, **kwargs ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument - adapter = QuartRequestAdapter() - set_request_adapter(adapter) - adapter.set_request() - return Response(dash_app.index(*args, **kwargs), content_type="text/html") + return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] # pylint: disable=protected-access dash_app._add_url("", catchall, methods=["GET"]) - def before_request(self, app, func): - app.before_request(func) + def before_request(self, func: _t.Callable[[], _t.Any]): + self.server.before_request(func) - def after_request(self, app, func): - @app.after_request + def after_request(self, func: _t.Callable[[], _t.Any]): + @self.server.after_request async def _after(response): if func is not None: result = func() @@ -263,21 +276,25 @@ async def _after(response): await result return response - def run(self, _dash_app, app, host, port, debug, **kwargs): + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): self.config = {"debug": debug, **kwargs} if debug else kwargs - app.run(host=host, port=port, debug=debug, **kwargs) + self.server.run(host=host, port=port, debug=debug, **kwargs) - def make_response(self, data, mimetype=None, content_type=None): + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + ): + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") return Response(data, mimetype=mimetype, content_type=content_type) def jsonify(self, obj): return jsonify(obj) - def get_request_adapter(self): - return QuartRequestAdapter - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path + self, dash_app: Dash, package_name: str, fingerprinted_path: str ): # noqa: ARG002 unused req preserved for interface parity path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) @@ -296,9 +313,11 @@ def serve_component_suites( if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") return Response(data, content_type=mimetype, headers=headers) - def setup_component_suites(self, dash_app): + def setup_component_suites(self, dash_app: Dash): async def serve(package_name, fingerprinted_path): return self.serve_component_suites( dash_app, package_name, fingerprinted_path @@ -311,14 +330,13 @@ async def serve(package_name, fingerprinted_path): ) # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=True): # Quart always async + def dispatch(self, dash_app: Dash): # type: ignore[name-defined] Quart always async + async def _dispatch(): adapter = QuartRequestAdapter() - set_request_adapter(adapter) - adapter.set_request() - body = await request.get_json() + body = await adapter.get_json() # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) + g = dash_app._initialize_context(body) # pylint: disable=protected-access func = dash_app._prepare_callback(g, body) # pylint: disable=protected-access @@ -329,11 +347,11 @@ async def _dispatch(): response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data - return Response(response_data, content_type="application/json") + return Response(response_data, content_type="application/json") # type: ignore[arg-type] return _dispatch - def register_callback_api_routes(self, app, callback_api_paths): + def register_callback_api_routes(self, callback_api_paths: _t.Dict[str, _t.Callable[..., _t.Any]]): """ Register callback API endpoints on the Quart app. Each key in callback_api_paths is a route, each value is a handler (sync or async). @@ -348,25 +366,33 @@ def _make_view_func(handler): if inspect.iscoroutinefunction(handler): async def async_view_func(*args, **kwargs): + if request is None: + raise RuntimeError( + "Quart not installed; request unavailable" + ) data = await request.get_json() result = await handler(**data) if data else await handler() - return jsonify(result) + return jsonify(result) # type: ignore[arg-type] return async_view_func async def sync_view_func(*args, **kwargs): + if request is None: + raise RuntimeError("Quart not installed; request unavailable") data = await request.get_json() result = handler(**data) if data else handler() - return jsonify(result) + return jsonify(result) # type: ignore[arg-type] return sync_view_func view_func = _make_view_func(handler) - app.add_url_rule( + self.server.add_url_rule( route, endpoint=endpoint, view_func=view_func, methods=methods ) def _serve_default_favicon(self): + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) @@ -374,41 +400,53 @@ def _serve_default_favicon(self): class QuartRequestAdapter: def __init__(self) -> None: - self._request = None + self._request = request # type: ignore[assignment] + if self._request is None: + raise RuntimeError("Quart not installed; cannot access request context") - def set_request(self) -> None: - self._request = request + @property + def request(self) -> _t.Any: + return self._request - # Accessors (instance-based) - def get_root(self): - return self._request.root_url + @property + def root(self): + return self.request.root_url - def get_args(self): - return self._request.args - - async def get_json(self): - return await self._request.get_json() + @property + def args(self): + return self.request.args + @property def is_json(self): - return self._request.is_json + return self.request.is_json + + @property + def cookies(self): + return self.request.cookies - def get_cookies(self): - return self._request.cookies + @property + def headers(self): + return self.request.headers - def get_headers(self): - return self._request.headers + @property + def full_path(self): + return self.request.full_path - def get_full_path(self): - return self._request.full_path + @property + def url(self): + return str(self.request.url) - def get_url(self): - return str(self._request.url) + @property + def remote_addr(self): + return self.request.remote_addr - def get_remote_addr(self): - return self._request.remote_addr + @property + def origin(self): + return self.request.headers.get("origin") - def get_origin(self): - return self._request.headers.get("origin") + @property + def path(self): + return self.request.path - def get_path(self): - return self._request.path + async def get_json(self): + return await self.request.get_json() diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py new file mode 100644 index 0000000000..1c47548ad0 --- /dev/null +++ b/dash/backends/base_server.py @@ -0,0 +1,119 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class BaseDashServer(ABC): + server_type: str + server: Any + config: dict[str, Any] + + def __call__(self, *args, **kwargs) -> Any: + # Default: WSGI + return self.server(*args, **kwargs) + + @staticmethod + @abstractmethod + def create_app( + name: str = "__main__", config=None + ) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def register_error_handlers(self) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def add_url_rule( + self, rule: str, view_func, endpoint=None, methods=None + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def before_request(self, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def after_request(self, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def run( + self, dash_app, host: str, port: int, debug: bool, **kwargs + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def make_response( + self, data, mimetype=None, content_type=None + ) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def jsonify(self, obj) -> Any: # pragma: no cover - interface + pass + + +class RequestAdapter(ABC): + def __call__(self) -> Any: + return self + + # Properties to be implemented in concrete adapters + @property # pragma: no cover - interface + @abstractmethod + def root(self) -> str: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def args(self): + raise NotImplementedError() + + @abstractmethod # kept as method (may be sync or async) + def get_json(self): # pragma: no cover - interface + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def is_json(self) -> bool: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def cookies(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def headers(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def full_path(self) -> str: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def url(self) -> str: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def remote_addr(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def origin(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def path(self) -> str: + raise NotImplementedError() diff --git a/dash/dash.py b/dash/dash.py index 6bba3aadfd..1ed05657dc 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -25,7 +25,6 @@ from dash import dcc from dash import html from dash import dash_table - from .fingerprint import build_fingerprint from .resources import Scripts, Css from .dependencies import ( @@ -38,7 +37,7 @@ ProxyError, DuplicateCallback, ) -from .backend import get_request_adapter, get_backend +from .backends import get_backend from .version import __version__ from ._configs import get_combined_config, pathname_configs, pages_folder_config from ._utils import ( @@ -63,6 +62,7 @@ from . import _validate from . import _watch from . import _get_app +from . import backends from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group @@ -154,36 +154,6 @@ page_container = None -def _is_flask_instance(obj): - try: - # pylint: disable=import-outside-toplevel - from flask import Flask - - return isinstance(obj, Flask) - except ImportError: - return False - - -def _is_fastapi_instance(obj): - try: - # pylint: disable=import-outside-toplevel - from fastapi import FastAPI - - return isinstance(obj, FastAPI) - except ImportError: - return False - - -def _is_quart_instance(obj): - try: - # pylint: disable=import-outside-toplevel - from quart import Quart - - return isinstance(obj, Quart) - except ImportError: - return False - - # Singleton signal to not update an output, alternative to PreventUpdate no_update = _callback.NoUpdate() # pylint: disable=protected-access @@ -446,74 +416,41 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches **obsolete, ): - if use_async is None: - try: - import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa - - use_async = True - except ImportError: - pass - elif use_async: - try: - import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa - except ImportError as exc: - raise Exception( - "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" - ) from exc - + _validate.check_async(use_async) _validate.check_obsolete(obsolete) caller_name: str = name if name is not None else get_caller_name() # Determine backend if backend is None: - backend_cls = get_backend("flask") + backend_cls, request_cls = get_backend("flask") elif isinstance(backend, str): - backend_cls = get_backend(backend) + backend_cls, request_cls = get_backend(backend) elif isinstance(backend, type): backend_cls = backend + _, request_cls = get_backend(backend.server_type) else: raise ValueError("Invalid backend argument") # Determine server and backend instance if server not in (None, True, False): # User provided a server instance (e.g., Flask, Quart, FastAPI) - if _is_flask_instance(server): - inferred_backend = "flask" - elif _is_quart_instance(server): - inferred_backend = "quart" - elif _is_fastapi_instance(server): - inferred_backend = "fastapi" - else: - raise ValueError("Unsupported server type") - # Validate that backend matches server type if both are provided - if backend is not None: - if isinstance(backend, type): - # get_backend returns the backend class for a string - # So we compare the class names - expected_backend_cls = get_backend(inferred_backend) - if ( - backend.__module__ != expected_backend_cls.__module__ - or backend.__name__ != expected_backend_cls.__name__ - ): - raise ValueError( - f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." - ) - elif not isinstance(backend, str): - raise ValueError("Invalid backend argument") - elif backend.lower() != inferred_backend: - raise ValueError( - f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." - ) - backend_cls = get_backend(inferred_backend) + inferred_backend = backends.get_server_type(server) + _validate.check_backend(backend, inferred_backend) + backend_cls, request_cls = get_backend(inferred_backend) if name is None: caller_name = getattr(server, "name", caller_name) - self.backend = backend_cls() + + self.backend = backend_cls(server) self.server = server + backends.backend = self.backend # type: ignore + backends.request_adapter = request_cls else: # No server instance provided, create backend and let backend create server - self.backend = backend_cls() - self.server = self.backend.create_app(caller_name) # type: ignore + self.server = backend_cls.create_app(caller_name) # type: ignore + self.backend = backend_cls(self.server) + backends.backend = self.backend + backends.request_adapter = request_cls base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -710,7 +647,6 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: bp_prefix = config.routes_pathname_prefix.replace("/", "_").replace(".", "_") assets_blueprint_name = f"{bp_prefix}dash_assets" self.backend.register_assets_blueprint( - self.server, assets_blueprint_name, config.routes_pathname_prefix + self.config.assets_url_path.lstrip("/"), self.config.assets_folder, @@ -732,8 +668,9 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: raise ImportError( "To use the compress option, you need to install dash[compress]" ) from error - self.backend.register_error_handlers(self.server) - self.backend.before_request(self.server, self._setup_server) + + self.backend.register_error_handlers() + self.backend.before_request(self._setup_server) self._setup_routes() _get_app.APP = self self.enable_pages() @@ -742,7 +679,6 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> None: full_name = self.config.routes_pathname_prefix + name self.backend.add_url_rule( - self.server, full_name, view_func=view_func, endpoint=full_name, @@ -756,7 +692,7 @@ def _setup_routes(self): self._add_url("_dash-dependencies", self.dependencies) self._add_url( "_dash-update-component", - self.backend.dispatch(self.server, self, self._use_async), + self.backend.dispatch(self), ["POST"], ) self._add_url("_reload-hash", self.serve_reload_hash) @@ -803,7 +739,7 @@ def setup_apis(self): self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k) # Delegate to the server factory for route registration - self.backend.register_callback_api_routes(self.server, self.callback_api_paths) + self.backend.register_callback_api_routes(self.callback_api_paths) def _setup_plotlyjs(self): # pylint: disable=import-outside-toplevel @@ -1043,9 +979,11 @@ def _generate_css_dist_html(self): return "\n".join( [ - format_tag("link", link, opened=True) - if isinstance(link, dict) - else f'' + ( + format_tag("link", link, opened=True) + if isinstance(link, dict) + else f'' + ) for link in (external_links + links) ] ) @@ -1099,9 +1037,11 @@ def _generate_scripts_html(self) -> str: return "\n".join( [ - format_tag("script", src) - if isinstance(src, dict) - else f'' + ( + format_tag("script", src) + if isinstance(src, dict) + else f'' + ) for src in srcs ] + [f"" for src in self._inline_scripts] @@ -1139,11 +1079,8 @@ def index(self, *_args, **_kwargs): metas = self._generate_meta() renderer = self._generate_renderer() title = self.title - try: - request = get_request_adapter() - except LookupError: - # no request context - request = None + # Refactored: direct access to global request adapter + request = backends.request_adapter() if self.use_pages and self.config.include_pages_meta and request: metas = _page_meta_tags(self, request) + metas @@ -1357,8 +1294,9 @@ def _inputs_to_vals(self, inputs): return inputs_to_vals(inputs) # pylint: disable=R0915 - def _initialize_context(self, body, adapter): + def _initialize_context(self, body): """Initialize the global context for the request.""" + adapter = backends.request_adapter() g = AttributeDict({}) g.inputs_list = body.get("inputs", []) g.states_list = body.get("state", []) @@ -1372,12 +1310,12 @@ def _initialize_context(self, body, adapter): g.dash_response = self.backend.make_response( mimetype="application/json", data=None ) - g.cookies = dict(adapter.get_cookies()) - g.headers = dict(adapter.get_headers()) - g.args = adapter.get_args() - g.path = adapter.get_full_path() - g.remote = adapter.get_remote_addr() - g.origin = adapter.get_origin() + g.cookies = dict(adapter.cookies) + g.headers = dict(adapter.headers) + g.args = adapter.args + g.path = adapter.full_path + g.remote = adapter.remote_addr + g.origin = adapter.origin g.updated_props = {} return g @@ -1964,15 +1902,21 @@ def enable_dev_tools( packages[index] = dash_spec component_packages_dist = [ - dash_test_path # type: ignore[reportPossiblyUnboundVariable] - if isinstance(package, ModuleSpec) - else os.path.dirname(package.path) # type: ignore[reportAttributeAccessIssue] - if hasattr(package, "path") - else os.path.dirname( - package._path[0] # type: ignore[reportAttributeAccessIssue]; pylint: disable=protected-access - ) - if hasattr(package, "_path") - else package.filename # type: ignore[reportAttributeAccessIssue] + ( + dash_test_path # type: ignore[reportPossiblyUnboundVariable] + if isinstance(package, ModuleSpec) + else ( + os.path.dirname(package.path) # type: ignore[reportAttributeAccessIssue] + if hasattr(package, "path") + else ( + os.path.dirname( + package._path[0] # type: ignore[reportAttributeAccessIssue]; pylint: disable=protected-access + ) + if hasattr(package, "_path") + else package.filename + ) + ) + ) # type: ignore[reportAttributeAccessIssue] for package in packages ] @@ -2000,13 +1944,14 @@ def enable_dev_tools( jupyter_dash.configure_callback_exception_handling( self, dev_tools.prune_errors ) - secret = gen_salt(20) - self.backend.register_prune_error_handler( - self.server, secret, dev_tools.prune_errors - ) + elif dev_tools.prune_errors: + secret = gen_salt(20) + self.backend.register_prune_error_handler( + secret, dev_tools.prune_errors + ) if debug and dev_tools.ui: - self.backend.register_timing_hooks(self.server, first_run) + self.backend.register_timing_hooks(first_run) if ( debug @@ -2290,13 +2235,8 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.backend.run( - self, - self.server, - host=host, - port=port, - debug=debug, - **flask_run_options, + backends.backend.run( + dash_app=self, host=host, port=port, debug=debug, **flask_run_options ) def enable_pages(self) -> None: @@ -2368,9 +2308,11 @@ async def update(pathname_, search_, **states): if not self.config.suppress_callback_exceptions: self.validation_layout = html.Div( [ - asyncio.run(execute_async_function(page["layout"])) - if callable(page["layout"]) - else page["layout"] + ( + asyncio.run(execute_async_function(page["layout"])) + if callable(page["layout"]) + else page["layout"] + ) for page in _pages.PAGE_REGISTRY.values() ] + [ @@ -2439,9 +2381,11 @@ def update(pathname_, search_, **states): ] self.validation_layout = html.Div( [ - page["layout"]() - if callable(page["layout"]) - else page["layout"] + ( + page["layout"]() + if callable(page["layout"]) + else page["layout"] + ) for page in _pages.PAGE_REGISTRY.values() ] + layout @@ -2460,7 +2404,7 @@ def update(pathname_, search_, **states): Input(_ID_STORE, "data"), ) - self.backend.before_request(self.server, router) + self.backend.before_request(router) def __call__(self, *args, **kwargs): - return self.backend.__call__(self.server, *args, **kwargs) + return self.backend.__call__(*args, **kwargs) diff --git a/dash_config.json b/dash_config.json new file mode 100644 index 0000000000..3afa0d11f1 --- /dev/null +++ b/dash_config.json @@ -0,0 +1 @@ +{"debug": true, "dev_tools_ui": true, "dev_tools_props_check": true, "dev_tools_serve_dev_bundles": true, "dev_tools_hot_reload": true, "dev_tools_silence_routes_logging": true, "dev_tools_prune_errors": true, "dev_tools_hot_reload_interval": 3.0, "dev_tools_hot_reload_watch_interval": 0.5, "dev_tools_hot_reload_max_retry": 8, "dev_tools_disable_version_check": false} \ No newline at end of file diff --git a/quart_app.py b/quart_app.py new file mode 100644 index 0000000000..54d40add56 --- /dev/null +++ b/quart_app.py @@ -0,0 +1,23 @@ +from dash import Dash, html, Input, Output +from dash import dcc +from dash import backends + +app = Dash(__name__, backend="quart") + +app.layout = html.Div( + [ + html.H2("Quart Server Factory Example"), + html.Div("Type below to see async callback update."), + dcc.Input(id="text", value="hello", autoComplete="off"), + html.Div(id="echo"), + ] +) + + +@app.callback(Output("echo", "children"), Input("text", "value")) +def update_echo(val): + return f"You typed: {val}" if val else "Type something" + + +if __name__ == "__main__": + app.run(debug=True) From c4795ed3b544964c259fe21cb81746911fb7e6aa Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:58:06 -0400 Subject: [PATCH 62/74] fixes for failing tests --- .gitignore | 1 + dash/backends/_fastapi.py | 5 +++-- dash/dash.py | 9 ++++----- dash_config.json | 2 +- tests/backend_tests/test_preconfig_backends.py | 12 ++++++------ 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 89029448fe..06e855e2dc 100644 --- a/.gitignore +++ b/.gitignore @@ -93,3 +93,4 @@ packages/ !components/dash-core-components/tests/integration/upload/upload-assets/upft001.csv !components/dash-table/tests/assets/*.csv !components/dash-table/tests/selenium/assets/*.csv +dash_config.json diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index f3f9f2df33..57cf18b6ec 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -52,6 +52,7 @@ def get_current_request() -> Request: class CurrentRequestMiddleware: def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] self.app = app + print('loaded CurrentRequestMiddleware') async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] # non-http/ws scopes pass through (lifespan etc.) @@ -100,7 +101,6 @@ def __call__(self, *args: Any, **kwargs: Any): @staticmethod def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): app = FastAPI() - app.add_middleware(CurrentRequestMiddleware) if config: for key, value in config.items(): @@ -257,7 +257,7 @@ def setup_catchall(self, dash_app: Dash): @self.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( - **self.config, first_run=False + **load_config(), first_run=False ) # do this to make sure dev tools are enabled async def catchall(request: Request): @@ -289,6 +289,7 @@ def add_url_rule( def before_request(self, func: Callable[[], Any] | None): # FastAPI does not have before_request, but we can use middleware + self.server.add_middleware(CurrentRequestMiddleware) self.server.middleware("http")(self._make_before_middleware(func)) def after_request(self, func: Callable[[], Any] | None): diff --git a/dash/dash.py b/dash/dash.py index 1ed05657dc..3ab830e8a3 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1944,11 +1944,10 @@ def enable_dev_tools( jupyter_dash.configure_callback_exception_handling( self, dev_tools.prune_errors ) - elif dev_tools.prune_errors: - secret = gen_salt(20) - self.backend.register_prune_error_handler( - secret, dev_tools.prune_errors - ) + secret = gen_salt(20) + self.backend.register_prune_error_handler( + secret, dev_tools.prune_errors + ) if debug and dev_tools.ui: self.backend.register_timing_hooks(first_run) diff --git a/dash_config.json b/dash_config.json index 3afa0d11f1..e4af4373cb 100644 --- a/dash_config.json +++ b/dash_config.json @@ -1 +1 @@ -{"debug": true, "dev_tools_ui": true, "dev_tools_props_check": true, "dev_tools_serve_dev_bundles": true, "dev_tools_hot_reload": true, "dev_tools_silence_routes_logging": true, "dev_tools_prune_errors": true, "dev_tools_hot_reload_interval": 3.0, "dev_tools_hot_reload_watch_interval": 0.5, "dev_tools_hot_reload_max_retry": 8, "dev_tools_disable_version_check": false} \ No newline at end of file +{"dev_tools_ui": false, "dev_tools_props_check": false, "dev_tools_serve_dev_bundles": false, "dev_tools_hot_reload": false, "dev_tools_silence_routes_logging": false, "dev_tools_prune_errors": false, "dev_tools_hot_reload_interval": 3.0, "dev_tools_hot_reload_watch_interval": 0.5, "dev_tools_hot_reload_max_retry": 8, "dev_tools_disable_version_check": true} \ No newline at end of file diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py index 5fbd28dfd9..4c4ccc7083 100644 --- a/tests/backend_tests/test_preconfig_backends.py +++ b/tests/backend_tests/test_preconfig_backends.py @@ -30,7 +30,7 @@ def update_output(value): dash_duo.start_server(app) dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}") - dash_duo.find_element("#input").clear() + dash_duo.clear_input(dash_duo.find_element("#input")) dash_duo.find_element("#input").send_keys(f"{backend.title()} Test") dash_duo.wait_for_text_to_equal("#output", f"You typed: {backend.title()} Test") assert dash_duo.get_logs() == [] @@ -93,7 +93,7 @@ def get_error_html(dash_duo, index): "dev_tools_prune_errors": False, "reload": False, }, - "fastapi.py", + "_fastapi.py", ), ( "quart", @@ -104,7 +104,7 @@ def get_error_html(dash_duo, index): "dev_tools_hot_reload": False, "dev_tools_prune_errors": False, }, - "quart.py", + "_quart.py", ), ], ) @@ -131,7 +131,7 @@ def error_callback(n): error0 = get_error_html(dash_duo, 0) assert "in error_callback" in error0 assert "ZeroDivisionError" in error0 - assert "backend" in error0 and error_msg in error0 + assert "backends/" in error0 and error_msg in error0 @pytest.mark.parametrize( @@ -173,7 +173,7 @@ def error_callback(n): error0 = get_error_html(dash_duo, 0) assert "in error_callback" in error0 assert "ZeroDivisionError" in error0 - assert "dash/backend" not in error0 and error_msg not in error0 + assert "dash/backends/" not in error0 and error_msg not in error0 @pytest.mark.parametrize( @@ -209,7 +209,7 @@ def update_output_bg(value): dash_duo.start_server(app) dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") - dash_duo.find_element("#input").clear() + dash_duo.clear_input(dash_duo.find_element("#input")) dash_duo.find_element("#input").send_keys(f"{backend.title()} BG Test") dash_duo.wait_for_text_to_equal( "#output", f"Background typed: {backend.title()} BG Test" From 567d0f8d592e4281047794100c80a3728aa6b128 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:05:33 -0400 Subject: [PATCH 63/74] fixing formatting --- dash/backends/_fastapi.py | 9 ++++----- dash/backends/_flask.py | 13 +++++++++---- dash/backends/_quart.py | 13 ++++++------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 57cf18b6ec..540238f727 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -52,7 +52,7 @@ def get_current_request() -> Request: class CurrentRequestMiddleware: def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] self.app = app - print('loaded CurrentRequestMiddleware') + print("loaded CurrentRequestMiddleware") async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] # non-http/ws scopes pass through (lifespan etc.) @@ -84,7 +84,6 @@ def load_config(): class FastAPIDashServer(BaseDashServer): - def __init__(self, server: FastAPI): self.config = {} self.server_type = "fastapi" @@ -415,7 +414,6 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): - async def _dispatch(request: Request): # pylint: disable=protected-access body = await request.json() @@ -470,7 +468,9 @@ async def timing_middleware(request: Request, call_next): headers.append("Server-Timing", value) return response - def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): + def register_callback_api_routes( + self, callback_api_paths: Dict[str, Callable[..., Any]] + ): """ Register callback API endpoints on the FastAPI app. Each key in callback_api_paths is a route, each value is a handler (sync or async). @@ -504,7 +504,6 @@ async def view_func(request: Request, body: dict = Body(...)): class FastAPIRequestAdapter(RequestAdapter): - def __init__(self): self._request: Request = get_current_request() super().__init__() diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index 5a1385d574..138234a4bc 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -29,7 +29,6 @@ class FlaskDashServer(BaseDashServer): - def __init__(self, server: Flask) -> None: self.server: Flask = server self.server_type = "flask" @@ -209,7 +208,9 @@ def _dispatch(): func = dash_app._prepare_callback(cb_ctx, body) args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) response_data = ctx.run(partial_func) if asyncio.iscoroutine(response_data): raise Exception( @@ -227,7 +228,9 @@ async def _dispatch_async(): func = dash_app._prepare_callback(cb_ctx, body) args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) response_data = ctx.run(partial_func) if asyncio.iscoroutine(response_data): response_data = await response_data @@ -269,7 +272,9 @@ def _after_request(response: Response): # type: ignore[name-defined] self.before_request(_before_request) self.after_request(_after_request) - def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): + def register_callback_api_routes( + self, callback_api_paths: Dict[str, Callable[..., Any]] + ): """ Register callback API endpoints on the Flask app. Each key in callback_api_paths is a route, each value is a handler (sync or async). diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index a462d07af6..ff544c2c91 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -29,7 +29,6 @@ class QuartDashServer(BaseDashServer): - def __init__(self, server: Quart) -> None: self.server_type = "quart" self.server: Quart = server @@ -41,7 +40,9 @@ def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] return self.server(*args, **kwargs) @staticmethod - def create_app(name: str = "__main__", config: _t.Optional[_t.Dict[str, _t.Any]] = None): + def create_app( + name: str = "__main__", config: _t.Optional[_t.Dict[str, _t.Any]] = None + ): if Quart is None: raise RuntimeError( "Quart is not installed. Install with 'pip install quart' to use the quart backend." @@ -225,7 +226,6 @@ async def _invalid_resource(err): return err.args[0], 404 def _html_response_wrapper(self, view_func: _t.Callable[..., _t.Any] | str): - async def wrapped(*_args, **_kwargs): html_val = view_func() if callable(view_func) else view_func if inspect.iscoroutine(html_val): # handle async function returning html @@ -247,7 +247,6 @@ def add_url_rule( ) def setup_index(self, dash_app: Dash): # type: ignore[name-defined] - async def index(*args, **kwargs): return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] @@ -255,7 +254,6 @@ async def index(*args, **kwargs): dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app: Dash): - async def catchall( path: str, *args, **kwargs ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument @@ -331,7 +329,6 @@ async def serve(package_name, fingerprinted_path): # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): # type: ignore[name-defined] Quart always async - async def _dispatch(): adapter = QuartRequestAdapter() body = await adapter.get_json() @@ -351,7 +348,9 @@ async def _dispatch(): return _dispatch - def register_callback_api_routes(self, callback_api_paths: _t.Dict[str, _t.Callable[..., _t.Any]]): + def register_callback_api_routes( + self, callback_api_paths: _t.Dict[str, _t.Callable[..., _t.Any]] + ): """ Register callback API endpoints on the Quart app. Each key in callback_api_paths is a route, each value is a handler (sync or async). From a855c6db89e167ca01d87e79fc394e4bfd58c280 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:26:05 -0400 Subject: [PATCH 64/74] fixing issues --- dash/backends/__init__.py | 6 +----- dash/backends/_fastapi.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index 940c8f18bd..e4d4141bb8 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -15,12 +15,8 @@ } -request_adapter: RequestAdapter -backend: BaseDashServer - - def get_backend( - name: Literal["flask", "fastapi", "quart"] | str + name: str ) -> tuple[BaseDashServer, RequestAdapter]: module_name, server_class, request_class = _backend_imports[name.lower()] try: diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 540238f727..be2308d5f5 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -52,7 +52,6 @@ def get_current_request() -> Request: class CurrentRequestMiddleware: def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] self.app = app - print("loaded CurrentRequestMiddleware") async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] # non-http/ws scopes pass through (lifespan etc.) From 79afb0bab2d7d058cc8770b6de65eb7dcab656dd Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:54:41 -0400 Subject: [PATCH 65/74] fixing async validation --- .github/workflows/testing.yml | 4 ++-- dash/_validate.py | 1 + dash/backends/__init__.py | 5 +---- dash/backends/_quart.py | 1 + dash/dash.py | 6 ++---- 5 files changed, 7 insertions(+), 10 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index be5caf4929..c47e188222 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -39,8 +39,8 @@ jobs: - 'tests/async_tests/**' - 'requirements/**' backend_paths: - - 'dash/backend/**' - - 'tests/backend/**' + - 'dash/backends/**' + - 'tests/backend_tests/**' build: name: Build Dash Package diff --git a/dash/_validate.py b/dash/_validate.py index 76661cef6b..d595cba0fc 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -603,6 +603,7 @@ def check_async(use_async): raise Exception( "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" ) from exc + return use_async or False def check_backend(backend, inferred_backend): diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index e4d4141bb8..b845abb1ad 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -1,6 +1,5 @@ from .base_server import BaseDashServer, RequestAdapter -from typing import Literal, Any import importlib @@ -15,9 +14,7 @@ } -def get_backend( - name: str -) -> tuple[BaseDashServer, RequestAdapter]: +def get_backend(name: str) -> tuple[BaseDashServer, RequestAdapter]: module_name, server_class, request_class = _backend_imports[name.lower()] try: module = importlib.import_module(module_name) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index ff544c2c91..c5759026d4 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -26,6 +26,7 @@ from dash.fingerprint import check_fingerprint from dash import _validate from .base_server import BaseDashServer +from typing import Any class QuartDashServer(BaseDashServer): diff --git a/dash/dash.py b/dash/dash.py index 3ab830e8a3..52dd219627 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -416,7 +416,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches **obsolete, ): - _validate.check_async(use_async) + use_async = _validate.check_async(use_async) _validate.check_obsolete(obsolete) caller_name: str = name if name is not None else get_caller_name() @@ -1945,9 +1945,7 @@ def enable_dev_tools( self, dev_tools.prune_errors ) secret = gen_salt(20) - self.backend.register_prune_error_handler( - secret, dev_tools.prune_errors - ) + self.backend.register_prune_error_handler(secret, dev_tools.prune_errors) if debug and dev_tools.ui: self.backend.register_timing_hooks(first_run) From 77e22a3ca21ddc4f5d1a0f50964204f7dc92fb46 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 12:29:37 -0400 Subject: [PATCH 66/74] adjustments for request_adapter --- dash/backends/__init__.py | 2 +- dash/backends/_fastapi.py | 1 + dash/backends/_flask.py | 1 + dash/backends/_quart.py | 1 + dash/backends/base_server.py | 1 + dash/dash.py | 11 +++++------ 6 files changed, 10 insertions(+), 7 deletions(-) diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index b845abb1ad..c8ac9321d0 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -19,7 +19,7 @@ def get_backend(name: str) -> tuple[BaseDashServer, RequestAdapter]: try: module = importlib.import_module(module_name) server = getattr(module, server_class) - request_adapter = getattr(module, request_class) + request_adapter = server.request_adapter # type: ignore return server, request_adapter except KeyError as e: raise ValueError(f"Unknown backend: {name}") from e diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index be2308d5f5..3f50f96f57 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -88,6 +88,7 @@ def __init__(self, server: FastAPI): self.server_type = "fastapi" self.server: FastAPI = server self.error_handling_mode = "prune" + self.request_adapter = FastAPIRequestAdapter super().__init__() def __call__(self, *args: Any, **kwargs: Any): diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index 138234a4bc..b4bab46ff2 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -32,6 +32,7 @@ class FlaskDashServer(BaseDashServer): def __init__(self, server: Flask) -> None: self.server: Flask = server self.server_type = "flask" + self.request_adapter = FlaskRequestAdapter super().__init__() def __call__(self, *args: Any, **kwargs: Any): diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index c5759026d4..8e509a08e0 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -35,6 +35,7 @@ def __init__(self, server: Quart) -> None: self.server: Quart = server self.config = {} self.error_handling_mode = "prune" + self.request_adapter = QuartRequestAdapter super().__init__() def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 1c47548ad0..cf2b62c2e7 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -6,6 +6,7 @@ class BaseDashServer(ABC): server_type: str server: Any config: dict[str, Any] + request_adapter: Any def __call__(self, *args, **kwargs) -> Any: # Default: WSGI diff --git a/dash/dash.py b/dash/dash.py index 52dd219627..2d71766b4e 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -423,12 +423,11 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # Determine backend if backend is None: - backend_cls, request_cls = get_backend("flask") + backend_cls = get_backend("flask") elif isinstance(backend, str): - backend_cls, request_cls = get_backend(backend) + backend_cls = get_backend(backend) elif isinstance(backend, type): backend_cls = backend - _, request_cls = get_backend(backend.server_type) else: raise ValueError("Invalid backend argument") @@ -437,20 +436,20 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # User provided a server instance (e.g., Flask, Quart, FastAPI) inferred_backend = backends.get_server_type(server) _validate.check_backend(backend, inferred_backend) - backend_cls, request_cls = get_backend(inferred_backend) + backend_cls = get_backend(inferred_backend) if name is None: caller_name = getattr(server, "name", caller_name) self.backend = backend_cls(server) self.server = server backends.backend = self.backend # type: ignore - backends.request_adapter = request_cls + backends.request_adapter = self.backend.request_adapter # type: ignore else: # No server instance provided, create backend and let backend create server self.server = backend_cls.create_app(caller_name) # type: ignore self.backend = backend_cls(self.server) backends.backend = self.backend - backends.request_adapter = request_cls + backends.request_adapter = self.backend.request_adapter # type: ignore base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix From f7331d3f7e97dc52ebc1289d71e95111d49b3bb8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 12:31:35 -0400 Subject: [PATCH 67/74] adding test for custom dash server --- tests/backend_tests/test_custom_backend.py | 243 +++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 tests/backend_tests/test_custom_backend.py diff --git a/tests/backend_tests/test_custom_backend.py b/tests/backend_tests/test_custom_backend.py new file mode 100644 index 0000000000..aa590599b5 --- /dev/null +++ b/tests/backend_tests/test_custom_backend.py @@ -0,0 +1,243 @@ +import pytest +from dash import Dash, Input, Output, html, dcc +from fastapi import FastAPI +import traceback +import re +from dash.backends._fastapi import FastAPIDashServer + + +class CustomDashServer(FastAPIDashServer): + def _get_traceback(self, _secret, error: Exception): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if self.error_handling_mode == "prune": + if not callback_handled: + if "callback invoked" in str(err) and "_callback.py" in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + # Parse traceback lines to group by file + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split("\n") + current_file = None + card_lines = [] + for line in lines[:-1]: # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + cards_html = "" + for filename, card in file_cards: + cards_html += ( + f""" +
+
{filename}
+
"""
+                + "\n".join(card)
+                + """
+
+ """ + ) + html = f""" + + + + {error_type}: {error_msg} // Custom Debugger + + + +
+

{error_type}: {error_msg}

+ {cards_html} +
+ + + """ + return html + + +@pytest.mark.parametrize( + "fixture,input_value", + [ + ("dash_duo", "Hello CustomBackend!"), + ], +) +def test_custom_backend_basic_callback(request, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(value): + return f"You typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}") + dash_duo.clear_input(dash_duo.find_element("#input")) + dash_duo.find_element("#input").send_keys("CustomBackend Test") + dash_duo.wait_for_text_to_equal("#output", "You typed: CustomBackend Test") + assert dash_duo.get_logs() == [] + + +@pytest.mark.parametrize( + "fixture,start_server_kwargs", + [ + ("dash_duo", {"debug": True, "reload": False, "dev_tools_ui": True}), + ], +) +def test_custom_backend_error_handling(request, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + +def get_error_html(dash_duo, index): + # error is in an iframe so is annoying to read out - get it from the store + return dash_duo.driver.execute_script( + "return store.getState().error.backEnd[{}].error.html;".format(index) + ) + + +@pytest.mark.parametrize( + "fixture,start_server_kwargs", + [ + ( + "dash_duo", + { + "debug": True, + "dev_tools_ui": True, + "dev_tools_prune_errors": False, + "reload": False, + }, + ), + ], +) +def test_custom_backend_error_handling_no_prune(request, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "Custom Debugger" in error0 + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "_callback.py" in error0 + + +@pytest.mark.parametrize( + "fixture,start_server_kwargs, error_msg", + [ + ("dash_duo", {"debug": True, "reload": False}, "custombackend.py"), + ], +) +def test_custom_backend_error_handling_prune( + request, fixture, start_server_kwargs, error_msg +): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "Custom Debugger" in error0 + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "_callback.py" not in error0 + + +@pytest.mark.parametrize( + "fixture,input_value", + [ + ("dash_duo", "Background CustomBackend!"), + ], +) +def test_custom_backend_background_callback(request, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + import diskcache + + cache = diskcache.Cache("./cache") + from dash.background_callback import DiskcacheManager + + background_callback_manager = DiskcacheManager(cache) + + app = Dash( + __name__, + backend=CustomDashServer, + background_callback_manager=background_callback_manager, + ) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback( + Output("output", "children"), Input("input", "value"), background=True + ) + def update_output_bg(value): + return f"Background typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") + dash_duo.clear_input(dash_duo.find_element("#input")) + dash_duo.find_element("#input").send_keys("CustomBackend BG Test") + dash_duo.wait_for_text_to_equal( + "#output", "Background typed: CustomBackend BG Test" + ) + assert dash_duo.get_logs() == [] From 8b58cf4e10b4a2f6c04bf988e5a752fc123757e8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 12:50:15 -0400 Subject: [PATCH 68/74] fixing issue with `request_adapter` --- dash/backends/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index c8ac9321d0..c264af4824 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -19,8 +19,7 @@ def get_backend(name: str) -> tuple[BaseDashServer, RequestAdapter]: try: module = importlib.import_module(module_name) server = getattr(module, server_class) - request_adapter = server.request_adapter # type: ignore - return server, request_adapter + return server except KeyError as e: raise ValueError(f"Unknown backend: {name}") from e except ImportError as e: From b7d4af2bc744a9f58a31a94b005acb969bf7d6b9 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 13:30:00 -0400 Subject: [PATCH 69/74] adjusting error handling for fastapi --- dash/backends/_fastapi.py | 33 +++++++++++++++++++-------------- dash/backends/_quart.py | 4 +++- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 3f50f96f57..8fc08b7a89 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -70,24 +70,29 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # CONFIG_PATH = "dash_config.json" -def save_config(config): - with open(CONFIG_PATH, "w") as f: - json.dump(config, f) +# Internal config helpers (local to this file) +_CONFIG_PATH = "dash_config.json" +def _save_config(config): + with open(_CONFIG_PATH, "w") as f: + json.dump(config, f) -def load_config(): - if os.path.exists(CONFIG_PATH): - with open(CONFIG_PATH, "r") as f: - return json.load(f) +def _load_config(): + try: + if os.path.exists(_CONFIG_PATH): + with open(_CONFIG_PATH, "r") as f: + return json.load(f) + except Exception: + pass # ignore errors return {} class FastAPIDashServer(BaseDashServer): def __init__(self, server: FastAPI): - self.config = {} + _save_config({"debug": False}) # ensure config file exists self.server_type = "fastapi" self.server: FastAPI = server - self.error_handling_mode = "prune" + self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter super().__init__() @@ -120,7 +125,7 @@ def register_assets_blueprint( pass def register_error_handlers(self): - self.error_handling_mode = "prune" + self.error_handling_mode = "ignore" def _get_traceback(self, _secret, error: Exception): tb = error.__traceback__ @@ -256,7 +261,7 @@ def setup_catchall(self, dash_app: Dash): @self.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( - **load_config(), first_run=False + **_load_config(), first_run=False ) # do this to make sure dev tools are enabled async def catchall(request: Request): @@ -298,12 +303,12 @@ def after_request(self, func: Callable[[], Any] | None): def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] config = dict( - {"debug": debug} if debug else {}, + {"debug": debug} if debug else {"debug": False}, **{ f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items() }, # pylint: disable=protected-access ) - save_config(config) + _save_config(config) if debug: if kwargs.get("reload") is None: kwargs["reload"] = True @@ -352,7 +357,7 @@ async def middleware(request, call_next): return Response(content=tb, media_type="text/html", status_code=500) return JSONResponse( status_code=500, - content={"error": "InternalServerError", "message": str(e.args[0])}, + content={"error": "InternalServerError", "message": "An internal server error occurred."}, ) return middleware diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 8e509a08e0..eae8df9117 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -34,7 +34,7 @@ def __init__(self, server: Quart) -> None: self.server_type = "quart" self.server: Quart = server self.config = {} - self.error_handling_mode = "prune" + self.error_handling_mode = "ignore" self.request_adapter = QuartRequestAdapter super().__init__() @@ -184,6 +184,8 @@ def register_prune_error_handler(self, secret, prune_errors): @self.server.errorhandler(Exception) async def _wrap_errors(error): + if self.error_handling_mode == "ignore": + return Response("Internal server error.", status=500, content_type="text/plain") tb = self._get_traceback(secret, error) return Response(tb, status=500, content_type="text/html") From 4cf4686f4f62aa93210b907bed7c360e7e883f1c Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 14:02:45 -0400 Subject: [PATCH 70/74] adjustments for handling issues with `debug` for `fastapi` --- dash/backends/_fastapi.py | 22 ++++++++++++++-------- dash_config.json | 1 - 2 files changed, 14 insertions(+), 9 deletions(-) delete mode 100644 dash_config.json diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 8fc08b7a89..36249fbf8c 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -66,30 +66,32 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # finally: reset_current_request(token) - -CONFIG_PATH = "dash_config.json" - - # Internal config helpers (local to this file) -_CONFIG_PATH = "dash_config.json" +_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "dash_config.json") def _save_config(config): with open(_CONFIG_PATH, "w") as f: json.dump(config, f) def _load_config(): + resp = {"debug": False} try: if os.path.exists(_CONFIG_PATH): with open(_CONFIG_PATH, "r") as f: - return json.load(f) + resp = json.load(f) except Exception: pass # ignore errors - return {} + return resp + +def _remove_config(): + try: + os.remove(_CONFIG_PATH) + except FileNotFoundError: + pass class FastAPIDashServer(BaseDashServer): def __init__(self, server: FastAPI): - _save_config({"debug": False}) # ensure config file exists self.server_type = "fastapi" self.server: FastAPI = server self.error_handling_mode = "ignore" @@ -258,6 +260,10 @@ async def index(request: Request): dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app: Dash): + @self.server.on_event("shutdown") + def cleanup_config(): + _remove_config() + @self.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( diff --git a/dash_config.json b/dash_config.json deleted file mode 100644 index e4af4373cb..0000000000 --- a/dash_config.json +++ /dev/null @@ -1 +0,0 @@ -{"dev_tools_ui": false, "dev_tools_props_check": false, "dev_tools_serve_dev_bundles": false, "dev_tools_hot_reload": false, "dev_tools_silence_routes_logging": false, "dev_tools_prune_errors": false, "dev_tools_hot_reload_interval": 3.0, "dev_tools_hot_reload_watch_interval": 0.5, "dev_tools_hot_reload_max_retry": 8, "dev_tools_disable_version_check": true} \ No newline at end of file From dfe0ac7f106dd1ea300c36adb3faf922a665b406 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:17:50 -0400 Subject: [PATCH 71/74] fixing for lint --- dash/backends/__init__.py | 15 +- dash/backends/_fastapi.py | 226 +++++++-------------- dash/backends/_flask.py | 10 +- dash/backends/_quart.py | 158 +++----------- dash/backends/_utils.py | 108 ++++++++++ dash/testing/application_runners.py | 8 +- tests/backend_tests/test_custom_backend.py | 7 +- 7 files changed, 229 insertions(+), 303 deletions(-) create mode 100644 dash/backends/_utils.py diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index c264af4824..e8b007a50b 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -1,21 +1,19 @@ -from .base_server import BaseDashServer, RequestAdapter - import importlib +from .base_server import BaseDashServer -request_adapter: RequestAdapter backend: BaseDashServer _backend_imports = { - "flask": ("dash.backends._flask", "FlaskDashServer", "FlaskRequestAdapter"), - "fastapi": ("dash.backends._fastapi", "FastAPIDashServer", "FastAPIRequestAdapter"), - "quart": ("dash.backends._quart", "QuartDashServer", "QuartRequestAdapter"), + "flask": ("dash.backends._flask", "FlaskDashServer"), + "fastapi": ("dash.backends._fastapi", "FastAPIDashServer"), + "quart": ("dash.backends._quart", "QuartDashServer"), } -def get_backend(name: str) -> tuple[BaseDashServer, RequestAdapter]: - module_name, server_class, request_class = _backend_imports[name.lower()] +def get_backend(name: str) -> BaseDashServer: + module_name, server_class = _backend_imports[name.lower()] try: module = importlib.import_module(module_name) server = getattr(module, server_class) @@ -74,7 +72,6 @@ def get_server_type(server): __all__ = [ "get_backend", - "request_adapter", "backend", "get_server_type", ] diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 36249fbf8c..dc7805501b 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from contextvars import copy_context, ContextVar from typing import TYPE_CHECKING, Any, Callable, Dict import sys @@ -8,27 +9,41 @@ import inspect import pkgutil import time -import traceback from importlib.util import spec_from_file_location import json import os -import re + +try: + from fastapi import FastAPI, Request, Response, Body + from fastapi.responses import JSONResponse + from fastapi.staticfiles import StaticFiles + from starlette.responses import Response as StarletteResponse + from starlette.datastructures import MutableHeaders + from starlette.types import ASGIApp, Scope, Receive, Send + import uvicorn +except ImportError: + FastAPI = None + Request = None + Response = None + Body = None + JSONResponse = None + StaticFiles = None + StarletteResponse = None + MutableHeaders = None + ASGIApp = None + Scope = None + Receive = None + Send = None + uvicorn = None from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import PreventUpdate from .base_server import BaseDashServer, RequestAdapter - -from fastapi import FastAPI, Request, Response, Body -from fastapi.responses import JSONResponse -from fastapi.staticfiles import StaticFiles -from starlette.responses import Response as StarletteResponse -from starlette.datastructures import MutableHeaders -from starlette.types import ASGIApp, Scope, Receive, Send -import uvicorn +from ._utils import format_traceback_html if TYPE_CHECKING: # pragma: no cover - typing only - from dash.dash import Dash + from dash import Dash _current_request_var = ContextVar("dash_current_request", default=None) @@ -49,7 +64,7 @@ def get_current_request() -> Request: return req -class CurrentRequestMiddleware: +class CurrentRequestMiddleware: # pylint: disable=too-few-public-methods def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] self.app = app @@ -66,23 +81,27 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # finally: reset_current_request(token) + # Internal config helpers (local to this file) _CONFIG_PATH = os.path.join(os.path.dirname(__file__), "dash_config.json") + def _save_config(config): - with open(_CONFIG_PATH, "w") as f: + with open(_CONFIG_PATH, "w", encoding="utf-8") as f: json.dump(config, f) + def _load_config(): resp = {"debug": False} try: if os.path.exists(_CONFIG_PATH): - with open(_CONFIG_PATH, "r") as f: + with open(_CONFIG_PATH, "r", encoding="utf-8") as f: resp = json.load(f) - except Exception: + except (json.JSONDecodeError, OSError): pass # ignore errors return resp + def _remove_config(): try: os.remove(_CONFIG_PATH) @@ -130,113 +149,9 @@ def register_error_handlers(self): self.error_handling_mode = "ignore" def _get_traceback(self, _secret, error: Exception): - tb = error.__traceback__ - errors = traceback.format_exception(type(error), error, tb) - pass_errs = [] - callback_handled = False - for err in errors: - if self.error_handling_mode == "prune": - if not callback_handled: - if "callback invoked" in str(err) and "_callback.py" in str(err): - callback_handled = True - continue - pass_errs.append(err) - formatted_tb = "".join(pass_errs) - error_type = type(error).__name__ - error_msg = str(error) - - # Parse traceback lines to group by file - file_cards = [] - pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') - lines = formatted_tb.split("\n") - current_file = None - card_lines = [] - - for line in lines[:-1]: # Skip the last line (error message) - match = pattern.match(line) - if match: - if current_file and card_lines: - file_cards.append((current_file, card_lines)) - current_file = ( - f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" - ) - card_lines = [line] - elif current_file: - card_lines.append(line) - if current_file and card_lines: - file_cards.append((current_file, card_lines)) - - cards_html = "" - for filename, card in file_cards: - cards_html += ( - f""" -
-
{filename}
-
"""
-                + "\n".join(card)
-                + """
-
- """ - ) - - html = f""" - - - - {error_type}: {error_msg} // FastAPI Debugger - - - -
-

{error_type}

-
-

{error_type}: {error_msg}

-
-

Traceback (most recent call last)

- {cards_html} -
{error_type}: {error_msg}
-
-

This is the Copy/Paste friendly version of the traceback.

- -
-
- The debugger caught an exception in your ASGI application. You can now - look at the traceback which led to the error. -
- -
- - - """ - return html + return format_traceback_html( + error, self.error_handling_mode, "FastAPI Debugger", "FastAPI" + ) def register_prune_error_handler(self, _secret, prune_errors): if prune_errors: @@ -253,7 +168,7 @@ async def wrapped(*_args, **_kwargs): return wrapped def setup_index(self, dash_app: Dash): - async def index(request: Request): + async def index(_request: Request): return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access @@ -270,7 +185,7 @@ def _setup_catchall(): **_load_config(), first_run=False ) # do this to make sure dev tools are enabled - async def catchall(request: Request): + async def catchall(_request: Request): return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access @@ -308,11 +223,10 @@ def after_request(self, func: Callable[[], Any] | None): def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] + dev_tools = dash_app._dev_tools # pylint: disable=protected-access config = dict( {"debug": debug} if debug else {"debug": False}, - **{ - f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items() - }, # pylint: disable=protected-access + **{f"dev_tools_{k}": v for k, v in dev_tools.items()}, ) _save_config(config) if debug: @@ -348,7 +262,7 @@ def make_response( def jsonify(self, obj: Any): return JSONResponse(content=obj) - def _make_before_middleware(self, func: Callable[[], Any] | None): + def _make_before_middleware(self, _func: Callable[[], Any] | None): async def middleware(request, call_next): try: response = await call_next(request) @@ -356,14 +270,18 @@ async def middleware(request, call_next): except PreventUpdate: # No content, nothing to update return Response(status_code=204) - except Exception as e: + except (Exception) as e: # pylint: disable=broad-except + # Handle exceptions based on error_handling_mode if self.error_handling_mode in ["raise", "prune"]: # Prune the traceback to remove internal Dash calls tb = self._get_traceback(None, e) return Response(content=tb, media_type="text/html", status_code=500) return JSONResponse( status_code=500, - content={"error": "InternalServerError", "message": "An internal server error occurred."}, + content={ + "error": "InternalServerError", + "message": "An internal server error occurred.", + }, ) return middleware @@ -417,27 +335,25 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): dash_app, package_name, fingerprinted_path, request ) - # pylint: disable=protected-access - dash_app._add_url( - "_dash-component-suites/{package_name}/{fingerprinted_path:path}", - serve, - ) + name = "_dash-component-suites/{package_name}/{fingerprinted_path:path}" + dash_app._add_url(name, serve) # pylint: disable=protected-access - # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): async def _dispatch(request: Request): # pylint: disable=protected-access body = await request.json() - g = dash_app._initialize_context(body) # pylint: disable=protected-access + cb_ctx = dash_app._initialize_context( + body + ) # pylint: disable=protected-access func = dash_app._prepare_callback( - g, body + cb_ctx, body ) # pylint: disable=protected-access args = dash_app._inputs_to_vals( - g.inputs_list + g.states_list + cb_ctx.inputs_list + cb_ctx.states_list ) # pylint: disable=protected-access ctx = copy_context() partial_func = dash_app._execute_callback( - func, args, g.outputs_list, g + func, args, cb_ctx.outputs_list, cb_ctx ) # pylint: disable=protected-access response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): @@ -494,20 +410,24 @@ def register_callback_api_routes( sig = inspect.signature(handler) param_names = list(sig.parameters.keys()) - async def view_func(request: Request, body: dict = Body(...)): - # Only pass expected params; ignore extras - kwargs = { - k: v for k, v in body.items() if k in param_names and v is not None - } - if inspect.iscoroutinefunction(handler): - result = await handler(**kwargs) - else: - result = handler(**kwargs) - return JSONResponse(content=result) + def make_view_func(handler, param_names): + async def view_func(_request: Request, body: dict = Body(...)): + kwargs = { + k: v + for k, v in body.items() + if k in param_names and v is not None + } + if inspect.iscoroutinefunction(handler): + result = await handler(**kwargs) + else: + result = handler(**kwargs) + return JSONResponse(content=result) + + return view_func self.server.add_api_route( route, - view_func, + make_view_func(handler, param_names), methods=methods, name=endpoint, include_in_schema=True, @@ -566,5 +486,5 @@ def origin(self): def path(self): return self._request.url.path - async def get_json(self): # async method retained - return await self._request.json() + def get_json(self): + return asyncio.run(self._request.json()) diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index b4bab46ff2..d9abe9c7ed 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -17,6 +17,7 @@ jsonify, g as flask_g, ) +from werkzeug.debug import tbtools from dash.fingerprint import check_fingerprint from dash import _validate @@ -67,13 +68,6 @@ def _invalid_resources_handler(err): return err.args[0], 404 def _get_traceback(self, secret, error: Exception): - try: - from werkzeug.debug import ( - tbtools, - ) # pylint: disable=import-outside-toplevel - except ImportError: - tbtools = None - def _get_skip(error): tb = error.__traceback__ skip = 1 @@ -238,7 +232,7 @@ async def _dispatch_async(): cb_ctx.dash_response.set_data(response_data) return cb_ctx.dash_response - if dash_app._use_async: + if dash_app._use_async: # pylint: disable=protected-access return _dispatch_async return _dispatch diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index eae8df9117..f417bc0d2e 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -1,32 +1,36 @@ from __future__ import annotations from contextvars import copy_context import typing as _t -import traceback import mimetypes import inspect import pkgutil import time import sys -import re +from typing import Any # Attempt top-level Quart imports; allow absence if user not using quart backend -from quart import ( - Quart, - Response, - jsonify, - request, - Blueprint, - g, -) - -if _t.TYPE_CHECKING: - from dash import Dash +try: + from quart import ( + Quart, + Response, + jsonify, + request, + Blueprint, + g, + ) +except ImportError: + Quart = None + Response = None + jsonify = None + request = None + Blueprint = None + g = None from dash.exceptions import PreventUpdate, InvalidResourceError from dash.fingerprint import check_fingerprint -from dash import _validate +from dash import _validate, Dash from .base_server import BaseDashServer -from typing import Any +from ._utils import format_traceback_html class QuartDashServer(BaseDashServer): @@ -68,113 +72,9 @@ def register_assets_blueprint( self.server.register_blueprint(bp) def _get_traceback(self, _secret, error: Exception): - tb = error.__traceback__ - errors = traceback.format_exception(type(error), error, tb) - pass_errs = [] - callback_handled = False - for err in errors: - if self.error_handling_mode == "prune": - if not callback_handled: - if "callback invoked" in str(err) and "_callback.py" in str(err): - callback_handled = True - continue - pass_errs.append(err) - formatted_tb = "".join(pass_errs) - error_type = type(error).__name__ - error_msg = str(error) - - # Parse traceback lines to group by file - file_cards = [] - pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') - lines = formatted_tb.split("\n") - current_file = None - card_lines = [] - - for line in lines[:-1]: # Skip the last line (error message) - match = pattern.match(line) - if match: - if current_file and card_lines: - file_cards.append((current_file, card_lines)) - current_file = ( - f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" - ) - card_lines = [line] - elif current_file: - card_lines.append(line) - if current_file and card_lines: - file_cards.append((current_file, card_lines)) - - cards_html = "" - for filename, card in file_cards: - cards_html += ( - f""" -
-
{filename}
-
"""
-                + "\n".join(card)
-                + """
-
- """ - ) - - html = f""" - - - - {error_type}: {error_msg} // Quart Debugger - - - -
-

{error_type}

-
-

{error_type}: {error_msg}

-
-

Traceback (most recent call last)

- {cards_html} -
{error_type}: {error_msg}
-
-

This is the Copy/Paste friendly version of the traceback.

- -
-
- The debugger caught an exception in your ASGI application. You can now - look at the traceback which led to the error. -
- -
- - - """ - return html + return format_traceback_html( + error, self.error_handling_mode, "Quart Debugger", "Quart" + ) def register_prune_error_handler(self, secret, prune_errors): if prune_errors: @@ -185,7 +85,9 @@ def register_prune_error_handler(self, secret, prune_errors): @self.server.errorhandler(Exception) async def _wrap_errors(error): if self.error_handling_mode == "ignore": - return Response("Internal server error.", status=500, content_type="text/plain") + return Response( + "Internal server error.", status=500, content_type="text/plain" + ) tb = self._get_traceback(secret, error) return Response(tb, status=500, content_type="text/html") @@ -337,14 +239,16 @@ async def _dispatch(): adapter = QuartRequestAdapter() body = await adapter.get_json() # pylint: disable=protected-access - g = dash_app._initialize_context(body) + cb_ctx = dash_app._initialize_context(body) # pylint: disable=protected-access - func = dash_app._prepare_callback(g, body) + func = dash_app._prepare_callback(cb_ctx, body) # pylint: disable=protected-access - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() # pylint: disable=protected-access - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data diff --git a/dash/backends/_utils.py b/dash/backends/_utils.py new file mode 100644 index 0000000000..0a5f4b0e76 --- /dev/null +++ b/dash/backends/_utils.py @@ -0,0 +1,108 @@ +import traceback +import re + + +def format_traceback_html(error, error_handling_mode, title, backend): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if error_handling_mode == "prune": + if not callback_handled: + if "callback invoked" in str(err) and "_callback.py" in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + # Parse traceback lines to group by file + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split("\n") + current_file = None + card_lines = [] + for line in lines[:-1]: # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + cards_html = "" + for filename, card in file_cards: + cards_html += ( + f""" +
+
{filename}
+
"""
+            + "\n".join(card)
+            + """
+
+ """ + ) + html = f""" + + + + {error_type}: {error_msg} // {title} + + + +
+

{error_type}

+
+

{error_type}: {error_msg}

+
+

Traceback (most recent call last)

+ {cards_html} +
{error_type}: {error_msg}
+
+

This is the Copy/Paste friendly version of the traceback.

+ +
+
+ The debugger caught an exception in your ASGI application. You can now + look at the traceback which led to the error. +
+
+ Brought to you by DON'T PANIC, your + friendly {backend} powered traceback interpreter. +
+
+ + + """ + return html diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index 2956f1a4c0..6e6cc8b810 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -173,9 +173,9 @@ def run(): try: module = app.server.__class__.__module__ # FastAPI support - if not module.startswith("flask"): + if module.startswith("fastapi"): app.run(**options) - # Dash/Flask fallback + # Dash/Flask/Quart fallback else: app.run(threaded=True, **options) except SystemExit: @@ -237,9 +237,9 @@ def target(): try: module = app.server.__class__.__module__ # FastAPI support - if not module.startswith("flask"): + if module.startswith("fastapi"): app.run(**options) - # Dash/Flask fallback + # Dash/Flask/Quart fallback else: app.run(threaded=True, **options) except SystemExit: diff --git a/tests/backend_tests/test_custom_backend.py b/tests/backend_tests/test_custom_backend.py index aa590599b5..befff7734b 100644 --- a/tests/backend_tests/test_custom_backend.py +++ b/tests/backend_tests/test_custom_backend.py @@ -1,9 +1,12 @@ import pytest from dash import Dash, Input, Output, html, dcc -from fastapi import FastAPI import traceback import re -from dash.backends._fastapi import FastAPIDashServer + +try: + from dash.backends._fastapi import FastAPIDashServer +except ImportError: + FastAPIDashServer = None class CustomDashServer(FastAPIDashServer): From cd02cc5cbefefb7db331db625deee14897bf88b2 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:05:28 -0400 Subject: [PATCH 72/74] adjustment for delayed config --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 2d71766b4e..6b022108b6 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -579,7 +579,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # tracks internally if a function already handled at least one request. self._got_first_request = {"pages": False, "setup_server": False} - if self.server is not None: + if server: self.init_app() self.logger.setLevel(logging.INFO) From 16b3c9e08743918b4e7af26cf43186955af1494f Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:27:20 -0400 Subject: [PATCH 73/74] fix typing error --- dash/backends/base_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index cf2b62c2e7..2b11bc763b 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict class BaseDashServer(ABC): server_type: str server: Any - config: dict[str, Any] + config: Dict[str, Any] request_adapter: Any def __call__(self, *args, **kwargs) -> Any: From 493d1503630f3d87f5be95cd4a8dc641a84adad0 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Mon, 22 Sep 2025 08:30:16 -0400 Subject: [PATCH 74/74] fixes for pages --- dash/_pages.py | 13 +- dash/backends/_fastapi.py | 41 +++++-- dash/backends/_flask.py | 14 +++ dash/backends/_quart.py | 14 +++ dash/dash.py | 251 ++++++++++++++++++-------------------- 5 files changed, 181 insertions(+), 152 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 19a797bcf2..0a5f9d8c06 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -150,22 +150,13 @@ def _parse_path_variables(pathname, path_template): return dict(zip(var_names, variables)) -def _create_redirect_function(redirect_to): - def redirect(): - return flask.redirect(redirect_to, code=301) - - return redirect - - def _set_redirect(redirect_from, path): app = get_app() if redirect_from and len(redirect_from): for redirect in redirect_from: fullname = app.get_relative_path(redirect) - app.server.add_url_rule( - fullname, - fullname, - _create_redirect_function(app.get_relative_path(path)), + app.backend.add_redirect_rule( + app, fullname, app.get_relative_path(path) ) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index dc7805501b..61b2d65a8f 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -15,7 +15,7 @@ try: from fastapi import FastAPI, Request, Response, Body - from fastapi.responses import JSONResponse + from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders @@ -27,6 +27,7 @@ Response = None Body = None JSONResponse = None + RedirectResponse = None StaticFiles = None StarletteResponse = None MutableHeaders = None @@ -115,6 +116,7 @@ def __init__(self, server: FastAPI): self.server: FastAPI = server self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter + self._before_request_funcs = [] super().__init__() def __call__(self, *args: Any, **kwargs: Any): @@ -213,9 +215,13 @@ def add_url_rule( ) def before_request(self, func: Callable[[], Any] | None): - # FastAPI does not have before_request, but we can use middleware - self.server.add_middleware(CurrentRequestMiddleware) - self.server.middleware("http")(self._make_before_middleware(func)) + if func is not None: + self._before_request_funcs.append(func) + # Only add the middleware once + if not hasattr(self, "_before_middleware_added"): + self.server.add_middleware(CurrentRequestMiddleware) + self.server.middleware("http")(self._make_before_middleware()) + self._before_middleware_added = True def after_request(self, func: Callable[[], Any] | None): # FastAPI does not have after_request, but we can use middleware @@ -262,18 +268,20 @@ def make_response( def jsonify(self, obj: Any): return JSONResponse(content=obj) - def _make_before_middleware(self, _func: Callable[[], Any] | None): + def _make_before_middleware(self): async def middleware(request, call_next): + for func in self._before_request_funcs: + if inspect.iscoroutinefunction(func): + await func() + else: + func() try: response = await call_next(request) return response except PreventUpdate: - # No content, nothing to update return Response(status_code=204) - except (Exception) as e: # pylint: disable=broad-except - # Handle exceptions based on error_handling_mode + except Exception as e: if self.error_handling_mode in ["raise", "prune"]: - # Prune the traceback to remove internal Dash calls tb = self._get_traceback(None, e) return Response(content=tb, media_type="text/html", status_code=500) return JSONResponse( @@ -338,6 +346,21 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): name = "_dash-component-suites/{package_name}/{fingerprinted_path:path}" dash_app._add_url(name, serve) # pylint: disable=protected-access + def _create_redirect_function(self, redirect_to): + def _redirect(): + return RedirectResponse(url=redirect_to, status_code=301) + + return _redirect + + def add_redirect_rule(self, app, fullname, path): + self.server.add_api_route( + fullname, + self._create_redirect_function(app.get_relative_path(path)), + methods=["GET"], + name=fullname, + include_in_schema=False, + ) + def dispatch(self, dash_app: Dash): async def _dispatch(request: Request): # pylint: disable=protected-access diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index d9abe9c7ed..2f7e08acf5 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -16,6 +16,7 @@ request, jsonify, g as flask_g, + redirect, ) from werkzeug.debug import tbtools @@ -194,6 +195,19 @@ def serve(package_name, fingerprinted_path): serve, ) + def _create_redirect_function(self, redirect_to): + def _redirect(): + return redirect(redirect_to, code=301) + + return _redirect + + def add_redirect_rule(self, app, fullname, path): + self.server.add_url_rule( + fullname, + fullname, + self._create_redirect_function(app.get_relative_path(path)), + ) + # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): def _dispatch(): diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index f417bc0d2e..c08a165234 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -17,6 +17,7 @@ request, Blueprint, g, + redirect, ) except ImportError: Quart = None @@ -233,6 +234,19 @@ async def serve(package_name, fingerprinted_path): serve, ) + def _create_redirect_function(self, redirect_to): + def _redirect(): + return redirect(redirect_to, code=301) + + return _redirect + + def add_redirect_rule(self, app, fullname, path): + self.server.add_url_rule( + fullname, + fullname, + self._create_redirect_function(app.get_relative_path(path)), + ) + # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): # type: ignore[name-defined] Quart always async async def _dispatch(): diff --git a/dash/dash.py b/dash/dash.py index 6b022108b6..53818cb5fb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -2241,7 +2241,8 @@ def enable_pages(self) -> None: if self.pages_folder: _import_layouts_from_pages(self.config.pages_folder) - def router(): + # Async version + async def router_async(): if self._got_first_request["pages"]: return self._got_first_request["pages"] = True @@ -2250,157 +2251,143 @@ def router(): "pathname_": Input(_ID_LOCATION, "pathname"), "search_": Input(_ID_LOCATION, "search"), } - inputs.update(self.routing_callback_inputs) # type: ignore[reportCallIssue] + inputs.update(self.routing_callback_inputs) - if self._use_async: + @self.callback( + Output(_ID_CONTENT, "children"), + Output(_ID_STORE, "data"), + inputs=inputs, + prevent_initial_call=True, + ) + async def update(pathname_, search_, **states): + query_parameters = _parse_query_string(search_) + page, path_variables = _path_to_page(self.strip_relative_path(pathname_)) + if page == {}: + for module, page in _pages.PAGE_REGISTRY.items(): + if module.split(".")[-1] == "not_found_404": + layout = page["layout"] + title = page["title"] + break + else: + layout = html.H1("404 - Page not found") + title = self.title + else: + layout = page.get("layout", "") + title = page["title"] - @self.callback( - Output(_ID_CONTENT, "children"), - Output(_ID_STORE, "data"), - inputs=inputs, - prevent_initial_call=True, - ) - async def update(pathname_, search_, **states): - """ - Updates dash.page_container layout on page navigation. - Updates the stored page title which will trigger the clientside callback to update the app title - """ - - query_parameters = _parse_query_string(search_) - page, path_variables = _path_to_page( - self.strip_relative_path(pathname_) + if callable(layout): + layout = await execute_async_function( + layout, + **{**(path_variables or {}), **query_parameters, **states}, ) + if callable(title): + title = await execute_async_function( + title, **{**(path_variables or {})} + ) + return layout, {"title": title} - # get layout - if page == {}: - for module, page in _pages.PAGE_REGISTRY.items(): - if module.split(".")[-1] == "not_found_404": - layout = page["layout"] - title = page["title"] - break - else: - layout = html.H1("404 - Page not found") - title = self.title - else: - layout = page.get("layout", "") - title = page["title"] - - if callable(layout): - layout = await execute_async_function( - layout, - **{**(path_variables or {}), **query_parameters, **states}, - ) - if callable(title): - title = await execute_async_function( - title, **{**(path_variables or {})} - ) + _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) + _validate.validate_registry(_pages.PAGE_REGISTRY) - return layout, {"title": title} + if not self.config.suppress_callback_exceptions: + async def get_layouts(): + return [ + await execute_async_function(page["layout"]) + if callable(page["layout"]) else page["layout"] + for page in _pages.PAGE_REGISTRY.values() + ] + layouts = await get_layouts() + layouts += [ + self.layout() if callable(self.layout) else self.layout + ] + self.validation_layout = html.Div(layouts) + if _ID_CONTENT not in self.validation_layout: + raise Exception("`dash.page_container` not found in the layout") - _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) - _validate.validate_registry(_pages.PAGE_REGISTRY) + self.clientside_callback( + """ + function(data) { + document.title = data.title + } + """, + Output(_ID_DUMMY, "children"), + Input(_ID_STORE, "data"), + ) - # Set validation_layout - if not self.config.suppress_callback_exceptions: - self.validation_layout = html.Div( - [ - ( - asyncio.run(execute_async_function(page["layout"])) - if callable(page["layout"]) - else page["layout"] - ) - for page in _pages.PAGE_REGISTRY.values() - ] - + [ - # pylint: disable=not-callable - self.layout() - if callable(self.layout) - else self.layout - ] - ) - if _ID_CONTENT not in self.validation_layout: - raise Exception("`dash.page_container` not found in the layout") - else: + # Sync version + def router_sync(): + if self._got_first_request["pages"]: + return + self._got_first_request["pages"] = True - @self.callback( - Output(_ID_CONTENT, "children"), - Output(_ID_STORE, "data"), - inputs=inputs, - prevent_initial_call=True, - ) - def update(pathname_, search_, **states): - """ - Updates dash.page_container layout on page navigation. - Updates the stored page title which will trigger the clientside callback to update the app title - """ - - query_parameters = _parse_query_string(search_) - page, path_variables = _path_to_page( - self.strip_relative_path(pathname_) - ) + inputs = { + "pathname_": Input(_ID_LOCATION, "pathname"), + "search_": Input(_ID_LOCATION, "search"), + } + inputs.update(self.routing_callback_inputs) - # get layout - if page == {}: - for module, page in _pages.PAGE_REGISTRY.items(): - if module.split(".")[-1] == "not_found_404": - layout = page["layout"] - title = page["title"] - break - else: - layout = html.H1("404 - Page not found") - title = self.title + @self.callback( + Output(_ID_CONTENT, "children"), + Output(_ID_STORE, "data"), + inputs=inputs, + prevent_initial_call=True, + ) + def update(pathname_, search_, **states): + query_parameters = _parse_query_string(search_) + page, path_variables = _path_to_page(self.strip_relative_path(pathname_)) + if page == {}: + for module, page in _pages.PAGE_REGISTRY.items(): + if module.split(".")[-1] == "not_found_404": + layout = page["layout"] + title = page["title"] + break else: - layout = page.get("layout", "") - title = page["title"] + layout = html.H1("404 - Page not found") + title = self.title + else: + layout = page.get("layout", "") + title = page["title"] - if callable(layout): - layout = layout( - **{**(path_variables or {}), **query_parameters, **states} - ) - if callable(title): - title = title(**(path_variables or {})) - - return layout, {"title": title} - - _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) - _validate.validate_registry(_pages.PAGE_REGISTRY) - - # Set validation_layout - if not self.config.suppress_callback_exceptions: - layout = self.layout - if not isinstance(layout, list): - layout = [ - # pylint: disable=not-callable - self.layout() - if callable(self.layout) - else self.layout - ] - self.validation_layout = html.Div( - [ - ( - page["layout"]() - if callable(page["layout"]) - else page["layout"] - ) - for page in _pages.PAGE_REGISTRY.values() - ] - + layout - ) - if _ID_CONTENT not in self.validation_layout: - raise Exception("`dash.page_container` not found in the layout") + if callable(layout): + layout = layout( + **{**(path_variables or {}), **query_parameters, **states} + ) + if callable(title): + title = title(**(path_variables or {})) + return layout, {"title": title} + + _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) + _validate.validate_registry(_pages.PAGE_REGISTRY) + + if not self.config.suppress_callback_exceptions: + layout = self.layout + if not isinstance(layout, list): + layout = [ + self.layout() if callable(self.layout) else self.layout + ] + self.validation_layout = html.Div( + [ + page["layout"]() if callable(page["layout"]) else page["layout"] + for page in _pages.PAGE_REGISTRY.values() + ] + layout + ) + if _ID_CONTENT not in self.validation_layout: + raise Exception("`dash.page_container` not found in the layout") - # Update the page title on page navigation self.clientside_callback( """ - function(data) {{ + function(data) { document.title = data.title - }} + } """, Output(_ID_DUMMY, "children"), Input(_ID_STORE, "data"), ) - self.backend.before_request(router) + if self._use_async: + self.backend.before_request(router_async) + else: + self.backend.before_request(router_sync) def __call__(self, *args, **kwargs): return self.backend.__call__(*args, **kwargs)