diff --git a/CHANGELOG.md b/CHANGELOG.md index e3b5c6e..8c796f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Added +- Possibility to whitelist routes with the `add_public_routes` utility function, the routes should follow Flask route syntax +- NOTE: If you are using server-side callbacks on your public routes, you should use dash_auth's new `public_callback` rather than the default Dash callback + ## [2.1.0] - 2024-01-24 ### Changed - Uses flask `before_request` to protect all endpoints rather than protecting routes present at instantiation time diff --git a/README.md b/README.md index 545b841..ead912b 100644 --- a/README.md +++ b/README.md @@ -52,3 +52,109 @@ def authorization_function(username, password): app = Dash(__name__) BasicAuth(app, auth_func = authorization_function) ``` + +### Public routes + +You can whitelist routes from authentication with the `add_public_routes` utility function, +or by passing a `public_routes` argument to the Auth constructor. +The public routes should follow [Flask's route syntax](https://flask.palletsprojects.com/en/2.3.x/quickstart/#routing). + +```python +from dash import Dash +from dash_auth import BasicAuth, add_public_routes + +app = Dash(__name__) +USER_PWD = { + "username": "password", + "user2": "useSomethingMoreSecurePlease", +} +BasicAuth(app, USER_PWD, public_routes=["/"]) + +add_public_routes(app, public_routes=["/user//public"]) +``` + +NOTE: If you are using server-side callbacks on your public routes, you should also use dash_auth's new `public_callback` rather than the default Dash callback. +Below is an example of a public route and callbacks on a multi-page Dash app using Dash's pages API: + +*app.py* +```python +from dash import Dash, html, dcc, page_container +from dash_auth import BasicAuth + +app = Dash(__name__, use_pages=True, suppress_callback_exceptions=True) +USER_PWD = { + "username": "password", + "user2": "useSomethingMoreSecurePlease", +} +BasicAuth(app, USER_PWD, public_routes=["/", "/user//public"]) + +app.layout = html.Div( + [ + html.Div( + [ + dcc.Link("Home", href="/"), + dcc.Link("John Doe", href="/user/john_doe/public"), + ], + style={"display": "flex", "gap": "1rem", "background": "lightgray", "padding": "0.5rem 1rem"}, + ), + page_container, + ], + style={"display": "flex", "flexDirection": "column"}, +) + +if __name__ == "__main__": + app.run_server(debug=True) +``` + +--- +*pages/home.py* +```python +from dash import Input, Output, html, register_page +from dash_auth import public_callback + +register_page(__name__, "/") + +layout = [ + html.H1("Home Page"), + html.Button("Click me", id="home-button"), + html.Div(id="home-contents"), +] + +# Note the use of public callback here rather than the default Dash callback +@public_callback( + Output("home-contents", "children"), + Input("home-button", "n_clicks"), +) +def home(n_clicks): + if not n_clicks: + return "You haven't clicked the button." + return "You clicked the button {} times".format(n_clicks) +``` + +--- +*pages/public_user.py* +```python +from dash import html, dcc, register_page + +register_page(__name__, path_template="/user//public") + +def layout(user_id: str): + return [ + html.H1(f"User {user_id} (public)"), + dcc.Link("Authenticated user content", href=f"/user/{user_id}/private"), + ] +``` + +--- +*pages/private_user.py* +```python +from dash import html, register_page + +register_page(__name__, path_template="/user//private") + +def layout(user_id: str): + return [ + html.H1(f"User {user_id} (authenticated only)"), + html.Div("Members-only information"), + ] +``` \ No newline at end of file diff --git a/dash_auth/__init__.py b/dash_auth/__init__.py index 6ca5cec..9031fc3 100644 --- a/dash_auth/__init__.py +++ b/dash_auth/__init__.py @@ -1,5 +1,6 @@ +from .public_routes import add_public_routes, public_callback from .basic_auth import BasicAuth from .version import __version__ -__all__ = ["BasicAuth", "__version__"] +__all__ = ["add_public_routes", "public_callback", "BasicAuth", "__version__"] diff --git a/dash_auth/auth.py b/dash_auth/auth.py index f5fdcb3..e2598fe 100644 --- a/dash_auth/auth.py +++ b/dash_auth/auth.py @@ -1,14 +1,27 @@ from __future__ import absolute_import from abc import ABC, abstractmethod +from typing import Optional from dash import Dash +from flask import request + +from .public_routes import ( + add_public_routes, get_public_callbacks, get_public_routes +) class Auth(ABC): - def __init__(self, app: Dash, **obsolete): + def __init__( + self, + app: Dash, + public_routes: Optional[list] = None, + **obsolete + ): """Auth base class for authentication in Dash. :param app: Dash app + :param public_routes: list of public routes, routes should follow the + Flask route syntax """ # Deprecated arguments @@ -19,12 +32,15 @@ def __init__(self, app: Dash, **obsolete): self.app = app self._protect() + if public_routes is not None: + add_public_routes(app, public_routes) def _protect(self): """Add a before_request authentication check on all routes. - The authentication check will pass if the request - is authorised by `Auth.is_authorised` + The authentication check will pass if either + * The endpoint is marked as public via `add_public_routes` + * The request is authorised by `Auth.is_authorised` """ server = self.app.server @@ -32,8 +48,35 @@ def _protect(self): @server.before_request def before_request_auth(): - # Check whether the request is authorised - if self.is_authorized(): + public_routes = get_public_routes(self.app) + public_callbacks = get_public_callbacks(self.app) + # Handle Dash's callback route: + # * Check whether the callback is marked as public + # * Check whether the callback is performed on route change in + # which case the path should be checked against the public routes + if request.path == "/_dash-update-component": + body = request.get_json() + + # Check whether the callback is marked as public + if body["output"] in public_callbacks: + return None + + # Check whether the callback has an input using the pathname, + # such a callback will be a routing callback and the pathname + # should be checked against the public routes + pathname = next( + ( + inp["value"] for inp in body["inputs"] + if inp["property"] == "pathname" + ), + None, + ) + if pathname and public_routes.test(pathname): + return None + + # If the route is not a callback route, check whether the path + # matches a public route, or whether the request is authorised + if public_routes.test(request.path) or self.is_authorized(): return None # Otherwise, ask the user to log in diff --git a/dash_auth/basic_auth.py b/dash_auth/basic_auth.py index 66dc36f..b594935 100644 --- a/dash_auth/basic_auth.py +++ b/dash_auth/basic_auth.py @@ -1,5 +1,5 @@ import base64 -from typing import Union, Callable +from typing import Optional, Union, Callable import flask from dash import Dash @@ -11,7 +11,8 @@ def __init__( self, app: Dash, username_password_list: Union[list, dict] = None, - auth_func: Callable = None + auth_func: Callable = None, + public_routes: Optional[list] = None, ): """Add basic authentication to Dash. @@ -21,18 +22,24 @@ def __init__( :param auth_func: python function accepting two string arguments (username, password) and returning a boolean (True if the user has access otherwise False). + :param public_routes: list of public routes, routes should follow the + Flask route syntax """ - Auth.__init__(self, app) + Auth.__init__(self, app, public_routes=public_routes) self._auth_func = auth_func if self._auth_func is not None: if username_password_list is not None: - raise ValueError("BasicAuth can only use authorization " - "function (auth_func kwarg) or " - "username_password_list, it cannot use both.") + raise ValueError( + "BasicAuth can only use authorization function " + "(auth_func kwarg) or username_password_list, " + "it cannot use both." + ) else: if username_password_list is None: - raise ValueError("BasicAuth requires username/password map " - "or user-defined authorization function.") + raise ValueError( + "BasicAuth requires username/password map " + "or user-defined authorization function." + ) else: self._users = ( username_password_list diff --git a/dash_auth/public_routes.py b/dash_auth/public_routes.py new file mode 100644 index 0000000..5c9540c --- /dev/null +++ b/dash_auth/public_routes.py @@ -0,0 +1,106 @@ +import inspect +import os + +from dash import Dash, callback +from dash._callback import GLOBAL_CALLBACK_MAP +from dash import get_app +from werkzeug.routing import Map, MapAdapter, Rule + + +DASH_PUBLIC_ASSETS_EXTENSIONS = "js,css" +BASE_PUBLIC_ROUTES = [ + f"/assets/.{ext}" + for ext in os.getenv( + "DASH_PUBLIC_ASSETS_EXTENSIONS", + DASH_PUBLIC_ASSETS_EXTENSIONS, + ).split(",") +] + [ + "/_dash-component-suites/", + "/_dash-layout", + "/_dash-dependencies", + "/_favicon.ico", + "/_reload-hash", +] +PUBLIC_ROUTES = "PUBLIC_ROUTES" +PUBLIC_CALLBACKS = "PUBLIC_CALLBACKS" + + +def add_public_routes(app: Dash, routes: list): + """Add routes to the public routes list. + + The routes passed should follow the Flask route syntax. + e.g. "/login", "/user//public" + + Some routes are made public by default: + * All dash scripts (_dash-dependencies, _dash-component-suites/**) + * All dash mechanics routes (_dash-layout, _reload-hash) + * All assets with extension .css, .js, .svg, .jpg, .png, .gif, .webp + Note: you can modify the extension by setting the + `DASH_ASSETS_PUBLIC_EXTENSIONS` envvar (comma-separated list of + extensions, e.g. "js,css,svg"). + * The favicon + + If you use callbacks on your public routes, you should use dash_auth's + `public_callback` rather than the standard dash callback. + + :param app: Dash app + :param routes: list of public routes to be added + """ + + public_routes = get_public_routes(app) + + if not public_routes.map._rules: + routes = BASE_PUBLIC_ROUTES + routes + + for route in routes: + public_routes.map.add(Rule(route)) + + app.server.config[PUBLIC_ROUTES] = public_routes + + +def public_callback(*callback_args, **callback_kwargs): + """Public Dash callback. + + This works by adding the callback id (from the callback map) to a list + of whitelisted callbacks in the Flask server's config. + + :param **: all args and kwargs passed to a dash callback + """ + + def decorator(func): + + wrapped_func = callback(*callback_args, **callback_kwargs)(func) + callback_id = next( + ( + k for k, v in GLOBAL_CALLBACK_MAP.items() + if inspect.getsource(v["callback"]) == inspect.getsource(func) + ), + None, + ) + try: + app = get_app() + app.server.config[PUBLIC_CALLBACKS] = ( + get_public_callbacks(app) + [callback_id] + ) + except Exception: + print( + "Could not set up the public callback as the Dash object " + "has not yet been instantiated." + ) + + def wrap(*args, **kwargs): + return wrapped_func(*args, **kwargs) + + return wrap + + return decorator + + +def get_public_routes(app: Dash) -> MapAdapter: + """Retrieve the public routes.""" + return app.server.config.get(PUBLIC_ROUTES, Map([]).bind("")) + + +def get_public_callbacks(app: Dash) -> list: + """Retrieve the public callbacks ids.""" + return app.server.config.get(PUBLIC_CALLBACKS, []) diff --git a/dev-requirements.txt b/dev-requirements.txt index e93cf42..10e3822 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -2,4 +2,5 @@ dash[testing]>=2 requests[security] flake8 flask +werkzeug pytest diff --git a/setup.py b/setup.py index ab6096f..3a71423 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,8 @@ long_description_content_type="text/markdown", install_requires=[ 'dash>=1.1.1', - "flask", + 'flask', + 'werkzeug', ], python_requires=">=3.8", include_package_data=True, diff --git a/tests/test_basic_auth_integration.py b/tests/test_basic_auth_integration.py index 3dc049c..bb2cac6 100644 --- a/tests/test_basic_auth_integration.py +++ b/tests/test_basic_auth_integration.py @@ -1,7 +1,7 @@ from dash import Dash, Input, Output, dcc, html import requests -from dash_auth import basic_auth +from dash_auth import basic_auth, add_public_routes TEST_USERS = { @@ -26,19 +26,26 @@ def test_ba001_basic_auth_login_flow(dash_br, dash_thread_server): def update_output(new_value): return new_value - basic_auth.BasicAuth(app, TEST_USERS["valid"]) + basic_auth.BasicAuth(app, TEST_USERS["valid"], public_routes=["/home"]) + add_public_routes(app, ["/user//public"]) dash_thread_server(app) base_url = dash_thread_server.url def test_failed_views(url): assert requests.get(url).status_code == 401 - assert requests.get(url.strip("/") + "/_dash-layout").status_code == 401 + + def test_successful_views(url): + assert requests.get(url.strip("/") + "/_dash-layout").status_code == 200 + assert requests.get(url.strip("/") + "/home").status_code == 200 + assert requests.get(url.strip("/") + "/user/john123/public").status_code == 200 test_failed_views(base_url) + test_successful_views(base_url) for user, password in TEST_USERS["invalid"]: test_failed_views(base_url.replace("//", f"//{user}:{password}@")) + test_successful_views(base_url.replace("//", f"//{user}:{password}@")) # Test login for each user: for user, password in TEST_USERS["valid"]: