diff --git a/.circleci/config.yml b/.circleci/config.yml index 2312c8b..b30a55a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,7 +1,7 @@ version: 2.1 orbs: - browser-tools: circleci/browser-tools@1.4.6 + browser-tools: circleci/browser-tools@1.4.8 jobs: python-38: &test-template diff --git a/CHANGELOG.md b/CHANGELOG.md index d6a69bd..b1cb54b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [Unreleased] +### Added +- OIDCAuth allows to authenticate via OIDC +- BasicAuth saves the current user in the session +- Ability to define user groups in BasicAuth +- Group-based permission and protection functions + ## [2.2.1] - 2024-03-01 ### Fixed - Fix when looking for callback inputs that are not in the right format when checking for whitelisted routes diff --git a/README.md b/README.md index ead912b..33e9016 100644 --- a/README.md +++ b/README.md @@ -157,4 +157,155 @@ def layout(user_id: str): html.H1(f"User {user_id} (authenticated only)"), html.Div("Members-only information"), ] -``` \ No newline at end of file +``` + +### OIDC Authentication + +To add authentication with OpenID Connect, you will first need to set up an OpenID Connect provider (IDP). +This typically requires creating +* An application in your IDP +* Defining the redirect URI for your application, for testing locally you can use http://localhost:8050/oidc/callback +* A client ID and secret for the application + +Once you have set up your IDP, you can add it to your Dash app as follows: + +```python +from dash import Dash +from dash_auth import OIDCAuth + +app = Dash(__name__) + +auth = OIDCAuth(app, secret_key="aStaticSecretKey!") +auth.register_provider( + "idp", + token_endpoint_auth_method="client_secret_post", + # Replace the below values with your own + # NOTE: Do not hardcode your client secret! + client_id="", + client_secret="", + server_metadata_url="", +) +``` + +Once this is done, connecting to your app will automatically redirect to the IDP login page. + +#### Multiple OIDC Providers + +For multiple OIDC providers, you can use `register_provider` to add new ones after the OIDCAuth has been instantiated. + +```python +from dash import Dash, html +from dash_auth import OIDCAuth +from flask import request, redirect, url_for + +app = Dash(__name__) + +app.layout = html.Div([ + html.Div("Hello world!"), + html.A("Logout", href="/oidc/logout"), +]) + +auth = OIDCAuth( + app, + secret_key="aStaticSecretKey!", + # Set the route at which the user will select the IDP they wish to login with + idp_selection_route="/login", +) +auth.register_provider( + "IDP 1", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url="", +) +auth.register_provider( + "IDP 2", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url="", +) + +@app.server.route("/login", methods=["GET", "POST"]) +def login_handler(): + if request.method == "POST": + idp = request.form.get("idp") + else: + idp = request.args.get("idp") + + if idp is not None: + return redirect(url_for("oidc_login", idp=idp)) + + return """
+
+
How do you wish to sign in:
+ + +
+
""" + + +if __name__ == "__main__": + app.run_server(debug=True) +``` + +### User-group-based permissions + +`dash_auth` provides a convenient way to secure parts of your app based on user groups. + +The following utilities are defined: +* `list_groups`: Returns the groups of the current user, or None if the user is not authenticated. +* `check_groups`: Checks the current user groups against the provided list of groups. + Available group checks are `one_of`, `all_of` and `none_of`. + The function returns None if the user is not authenticated. +* `protected`: A function decorator that modifies the output if the user is unauthenticated + or missing group permission. +* `protected_callback`: A callback that only runs if the user is authenticated + and with the right group permissions. + +NOTE: user info is stored in the session so make sure you define a secret_key on the Flask server +to use this feature. + +If you wish to use this feature with BasicAuth, you will need to define the groups for individual +basicauth users: + +```python +from dash_auth import BasicAuth + +app = Dash(__name__) +USER_PWD = { + "username": "password", + "user2": "useSomethingMoreSecurePlease", +} +BasicAuth( + app, + USER_PWD, + user_groups={"user1": ["group1", "group2"], "user2": ["group2"]}, + secret_key="Test!", +) + +# You can also use a function to get user groups +def check_user(username, password): + if username == "user1" and password == "password": + return True + if username == "user2" and password == "useSomethingMoreSecurePlease": + return True + return False + +def get_user_groups(user): + if user == "user1": + return ["group1", "group2"] + elif user == "user2": + return ["group2"] + return [] + +BasicAuth( + app, + auth_func=check_user, + user_groups=get_user_groups, + secret_key="Test!", +) +``` diff --git a/dash_auth/__init__.py b/dash_auth/__init__.py index 9031fc3..b3cc0ab 100644 --- a/dash_auth/__init__.py +++ b/dash_auth/__init__.py @@ -1,6 +1,25 @@ from .public_routes import add_public_routes, public_callback from .basic_auth import BasicAuth +from .group_protection import ( + list_groups, check_groups, protected, protected_callback +) +# oidc auth requires authlib, install with `pip install dash-auth[oidc]` +try: + from .oidc_auth import OIDCAuth, get_oauth +except ModuleNotFoundError: + pass from .version import __version__ -__all__ = ["add_public_routes", "public_callback", "BasicAuth", "__version__"] +__all__ = [ + "add_public_routes", + "check_groups", + "list_groups", + "get_oauth", + "protected", + "protected_callback", + "public_callback", + "BasicAuth", + "OIDCAuth", + "__version__", +] diff --git a/dash_auth/auth.py b/dash_auth/auth.py index 6409e57..862bb5e 100644 --- a/dash_auth/auth.py +++ b/dash_auth/auth.py @@ -83,22 +83,10 @@ def before_request_auth(): # Otherwise, ask the user to log in return self.login_request() - def is_authorized_hook(self, func): - self._auth_hooks.append(func) - return func - @abstractmethod def is_authorized(self): pass - @abstractmethod - def auth_wrapper(self, f): - pass - - @abstractmethod - def index_auth_wrapper(self, f): - pass - @abstractmethod def login_request(self): pass diff --git a/dash_auth/basic_auth.py b/dash_auth/basic_auth.py index b594935..c281f70 100644 --- a/dash_auth/basic_auth.py +++ b/dash_auth/basic_auth.py @@ -1,10 +1,13 @@ import base64 -from typing import Optional, Union, Callable +import logging +from typing import Dict, List, Optional, Union, Callable import flask from dash import Dash from .auth import Auth +UserGroups = Dict[str, List[str]] + class BasicAuth(Auth): def __init__( @@ -13,6 +16,10 @@ def __init__( username_password_list: Union[list, dict] = None, auth_func: Callable = None, public_routes: Optional[list] = None, + user_groups: Optional[ + Union[UserGroups, Callable[[str], UserGroups]] + ] = None, + secret_key: str = None ): """Add basic authentication to Dash. @@ -24,9 +31,28 @@ def __init__( boolean (True if the user has access otherwise False). :param public_routes: list of public routes, routes should follow the Flask route syntax + :param user_groups: a dict or a function returning a dict + Optional group for each user, allowing to protect routes and + callbacks depending on user groups + :param secret_key: Flask secret key + A string to protect the Flask session, by default None. + It is required if you need to store the current user + in the session. + Generate a secret key in your Python session + with the following commands: + >>> import os + >>> import base64 + >>> base64.b64encode(os.urandom(30)).decode('utf-8') + Note that you should not do this dynamically: + you should create a key and then assign the value of + that key in your code. """ - Auth.__init__(self, app, public_routes=public_routes) + super().__init__(app, public_routes=public_routes) self._auth_func = auth_func + self._user_groups = user_groups + if secret_key is not None: + app.server.secret_key = secret_key + if self._auth_func is not None: if username_password_list is not None: raise ValueError( @@ -54,14 +80,31 @@ def is_authorized(self): username_password = base64.b64decode(header.split('Basic ')[1]) username_password_utf8 = username_password.decode('utf-8') username, password = username_password_utf8.split(':', 1) + authorized = False if self._auth_func is not None: try: - return self._auth_func(username, password) - except Exception as e: - print(e) + authorized = self._auth_func(username, password) + except Exception: + logging.exception("Error in authorization function.") return False else: - return self._users.get(username) == password + authorized = self._users.get(username) == password + if authorized: + try: + flask.session["user"] = {"email": username, "groups": []} + if callable(self._user_groups): + flask.session["user"]["groups"] = self._user_groups( + username + ) + elif self._user_groups: + flask.session["user"]["groups"] = self._user_groups.get( + username, [] + ) + except RuntimeError: + logging.warning( + "Session is not available. Have you set a secret key?" + ) + return authorized def login_request(self): return flask.Response( @@ -69,20 +112,3 @@ def login_request(self): headers={'WWW-Authenticate': 'Basic realm="User Visible Realm"'}, status=401 ) - - def auth_wrapper(self, f): - def wrap(*args, **kwargs): - if not self.is_authorized(): - return flask.Response(status=403) - - response = f(*args, **kwargs) - return response - return wrap - - def index_auth_wrapper(self, original_index): - def wrap(*args, **kwargs): - if self.is_authorized(): - return original_index(*args, **kwargs) - else: - return self.login_request() - return wrap diff --git a/dash_auth/group_protection.py b/dash_auth/group_protection.py new file mode 100644 index 0000000..072f154 --- /dev/null +++ b/dash_auth/group_protection.py @@ -0,0 +1,214 @@ +import logging +import re +from typing import Any, Callable, List, Literal, Optional, Union + +import dash +from dash.exceptions import PreventUpdate +from flask import session, has_request_context + + +OutputVal = Union[Callable[[], Any], Any] +CheckType = Literal["one_of", "all_of", "none_of"] + + +def list_groups( + *, + groups_key: str = "groups", + groups_str_split: str = None, +) -> Optional[List[str]]: + """List all the groups the user belongs to. + + :param groups_key: Groups key in the user data saved in the Flask session + e.g. session["user"] == {"email": "a.b@mail.com", "groups": ["admin"]} + :param groups_str_split: Used to split groups if provided as a string + :return: None or list[str]: + * None if the user is not authenticated + * list[str] otherwise + """ + if not has_request_context() or "user" not in session: + return None + + user_groups = session.get("user", {}).get(groups_key, []) + # Handle cases where groups are ,- or ;-separated string, + # may depend on OIDC provider + if isinstance(user_groups, str) and groups_str_split is not None: + user_groups = re.split(groups_str_split, user_groups) + return user_groups + + +def check_groups( + groups: Optional[List[str]] = None, + *, + groups_key: str = "groups", + groups_str_split: str = None, + check_type: CheckType = "one_of", +) -> Optional[bool]: + """Check whether the current user is authenticated + and has the specified groups. + + :param groups: List of groups to check for with check_type + :param groups_key: Groups key in the user data saved in the Flask session + e.g. session["user"] == {"email": "a.b@mail.com", "groups": ["admin"]} + :param groups_str_split: Used to split groups if provided as a string + :param check_type: Type of check to perform. + Either "one_of", "all_of" or "none_of" + :return: None or boolean: + * None if the user is not authenticated + * True if the user is authenticated and has the right permissions + * False if the user is authenticated but does not have + the right permissions + """ + user_groups = list_groups( + groups_key=groups_key, + groups_str_split=groups_str_split, + ) + + if user_groups is None: + # User is not authenticated + return None + + if groups is None: + return True + + if check_type == "one_of": + return bool(set(user_groups).intersection(groups)) + if check_type == "all_of": + return all(group in user_groups for group in groups) + if check_type == "none_of": + return not any(group in user_groups for group in groups) + + raise ValueError(f"Invalid check_type: {check_type}") + + +def protected( + unauthenticated_output: OutputVal, + *, + missing_permissions_output: Optional[OutputVal] = None, + groups: Optional[List[str]] = None, + groups_key: str = "groups", + groups_str_split: str = None, + check_type: CheckType = "one_of", +) -> Callable: + """Decorate a function or output to alter it depending on the state + of authentication and permissions. + + :param unauthenticated_output: Output when the user is not authenticated. + Note: needs to be a function with no argument or static outputs. + :param missing_permissions_output: Output when the user is authenticated + but does not have the right permissions. + It defaults to unauthenticated_output when not set. + Note: needs to be a function with no argument or static outputs. + :param groups: List of authorized user groups. If no groups are passed, + the decorator will only check whether the user is authenticated. + :param groups_key: Groups key in the user data saved in the Flask session + e.g. session["user"] == {"email": "a.b@mail.com", "groups": ["admin"]} + :param groups_str_split: Used to split groups if provided as a string + :param check_type: Type of check to perform. + Either "one_of", "all_of" or "none_of" + """ + + if missing_permissions_output is None: + missing_permissions_output = unauthenticated_output + + def decorator(output: OutputVal): + def wrap(*args, **kwargs): + def process_output(output, *args, **kwargs): + if isinstance(output, Callable): + return output(*args, **kwargs) + return output + + authorized = check_groups( + groups=groups, + groups_key=groups_key, + groups_str_split=groups_str_split, + check_type=check_type, + ) + if authorized is None: + return process_output(unauthenticated_output) + if authorized: + return process_output(output, *args, **kwargs) + return process_output(missing_permissions_output) + + if isinstance(output, Callable): + return wrap + return wrap() + + return decorator + + +def protected_callback( + *callback_args, + unauthenticated_output: Optional[OutputVal] = None, + missing_permissions_output: Optional[OutputVal] = None, + groups: List[str] = None, + groups_key: str = "groups", + groups_str_split: str = None, + check_type: CheckType = "one_of", + **callback_kwargs, +) -> Callable: + """Protected Dash callback. + + :param **: all args and kwargs passed to a Dash callback + :param unauthenticated_output: Output when the user is not authenticated. + **Note**: Needs to be a function with no argument or static outputs. + You can access the Dash callback context within the function call if + you need to use some of the inputs/states of the callback. + If left as None, it will simply raise PreventUpdate, stopping the + callback from processing. + :param missing_permissions_output: Output when the user is authenticated + but does not have the right permissions. + It defaults to unauthenticated_output when not set. + **Note**: Needs to be a function with no argument or static outputs. + You can access the Dash callback context within the function call if + you need to use some of the inputs/states of the callback. + If left as None, it will simply raise PreventUpdate, stopping the + callback from processing. + :param groups: List of authorized user groups + :param groups_key: Groups key in the user data saved in the Flask session + e.g. session["user"] == {"email": "a.b@mail.com", "groups": ["admin"]} + :param groups_str_split: Used to split groups if provided as a string + :param check_type: Type of check to perform. + Either "one_of", "all_of" or "none_of" + """ + + def decorator(func): + def prevent_unauthenticated(): + logging.info( + "A user tried to run %s without being authenticated.", + func.__name__, + ) + raise PreventUpdate + + def prevent_unauthorised(): + logging.info( + "%s tried to run %s but did not have the right permissions.", + session["user"]["email"], + func.__name__, + ) + raise PreventUpdate + + wrapped_func = dash.callback(*callback_args, **callback_kwargs)( + protected( + unauthenticated_output=( + unauthenticated_output + if unauthenticated_output is not None + else prevent_unauthenticated + ), + missing_permissions_output=( + missing_permissions_output + if missing_permissions_output is not None + else prevent_unauthorised + ), + groups=groups, + groups_key=groups_key, + groups_str_split=groups_str_split, + check_type=check_type, + )(func) + ) + + def wrap(*args, **kwargs): + return wrapped_func(*args, **kwargs) + + return wrap + + return decorator diff --git a/dash_auth/oidc_auth.py b/dash_auth/oidc_auth.py new file mode 100644 index 0000000..f462ec9 --- /dev/null +++ b/dash_auth/oidc_auth.py @@ -0,0 +1,316 @@ +import logging +import os +import re +from typing import Optional, Union, TYPE_CHECKING + +import dash +from authlib.integrations.base_client import OAuthError +from authlib.integrations.flask_client import OAuth +from dash_auth.auth import Auth +from flask import Response, redirect, request, session, url_for +from werkzeug.routing import Map, Rule + +if TYPE_CHECKING: + from authlib.integrations.flask_client.apps import ( + FlaskOAuth1App, FlaskOAuth2App + ) + + +class OIDCAuth(Auth): + """Implements auth via OpenID.""" + + def __init__( + self, + app: dash.Dash, + secret_key: str = Optional[None], + force_https_callback: Optional[Union[bool, str]] = None, + login_route: str = "/oidc//login", + logout_route: str = "/oidc/logout", + callback_route: str = "/oidc//callback", + idp_selection_route: str = None, + log_signins: bool = False, + public_routes: Optional[list] = None, + logout_page: Union[str, Response] = None, + secure_session: bool = False, + ): + """Secure a Dash app through OpenID Connect. + + Parameters + ---------- + app : Dash + The Dash app to secure + secret_key : str, optional + A string to protect the Flask session, by default None. + Generate a secret key in your Python session + with the following commands: + >>> import os + >>> import base64 + >>> base64.b64encode(os.urandom(30)).decode('utf-8') + Note that you should not do this dynamically: + you should create a key and then assign the value of + that key in your code. + force_https_callback : Union[bool, str], optional + Whether to force redirection to https, by default None + This is useful when the HTTPS termination is upstream of the server + If a string is passed, this will check for the existence of + an envvar with that name and force https callback if it exists. + login_route : str, optional + The route for the login function, it requires a + placeholder, by default "/oidc//login". + logout_route : str, optional + The route for the logout function, by default "/oidc/logout". + callback_route : str, optional + The route for the OIDC redirect URI, it requires a + placeholder, by default "/oidc//callback". + idp_selection_route : str, optional + The route for the IDP selection function, by default None + log_signins : bool, optional + Whether to log signins, by default False + public_routes : list, optional + List of public routes, routes should follow the + Flask route syntax + logout_page : str or Response, optional + Page seen by the user after logging out, + by default None which will default to a simple logged out message + secure_session: bool, optional + Whether to ensure the session is secure, setting the flasck config + SESSION_COOKIE_SECURE and SESSION_COOKIE_HTTPONLY to True, + by default False + + Raises + ------ + Exception + Raise an exception if the app.server.secret_key is not defined + """ + super().__init__(app, public_routes=public_routes) + + if isinstance(force_https_callback, str): + self.force_https_callback = force_https_callback in os.environ + elif force_https_callback is not None: + self.force_https_callback = force_https_callback + else: + self.force_https_callback = False + + self.login_route = login_route + self.logout_route = logout_route + self.callback_route = callback_route + self.log_signins = log_signins + self.idp_selection_route = idp_selection_route + self.logout_page = logout_page + + if secret_key is not None: + app.server.secret_key = secret_key + + if app.server.secret_key is None: + raise RuntimeError( + """ + app.server.secret_key is missing. + Generate a secret key in your Python session + with the following commands: + >>> import os + >>> import base64 + >>> base64.b64encode(os.urandom(30)).decode('utf-8') + and assign it to the property app.server.secret_key + (where app is your dash app instance), or pass is as + the secret_key argument to OIDCAuth.__init__. + Note that you should not do this dynamically: + you should create a key and then assign the value of + that key in your code/via a secret. + """ + ) + + if secure_session: + app.server.config["SESSION_COOKIE_SECURE"] = True + app.server.config["SESSION_COOKIE_HTTPONLY"] = True + + self.oauth = OAuth(app.server) + + # Check that the login and callback rules have an placeholder + if not re.findall(r"/(?=/|$)", login_route): + raise Exception( + "The login route must contain a placeholder." + ) + if not re.findall(r"/(?=/|$)", callback_route): + raise Exception( + "The callback route must contain a placeholder." + ) + + app.server.add_url_rule( + login_route, + endpoint="oidc_login", + view_func=self.login_request, + methods=["GET"], + ) + app.server.add_url_rule( + logout_route, + endpoint="oidc_logout", + view_func=self.logout, + methods=["GET"], + ) + app.server.add_url_rule( + callback_route, + endpoint="oidc_callback", + view_func=self.callback, + methods=["GET"], + ) + + def register_provider(self, idp_name: str, **kwargs): + """Register an OpenID Connect provider. + + :param idp_name: The name of the provider + :param kwargs: Keyword arguments passed to OAuth.register. + See https://docs.authlib.org/en/latest/client/flask.html for + additional details. + Typical keyword arguments for OIDC include: + * client_id + * client_secret + * server_metadata_url + * token_endpoint_auth_method + * client_kwargs (defaults to {"scope": "openid email"}) + """ + if not re.match(r"^[\w\-\. ]+$", idp_name): + raise ValueError( + "`idp_name` should only contain letters, numbers, hyphens, " + "underscores, periods and spaces" + ) + client_kwargs = kwargs.pop("client_kwargs", {}) + client_kwargs.setdefault("scope", "openid email") + self.oauth.register( + idp_name, client_kwargs=client_kwargs, **kwargs + ) + + def get_oauth_client(self, idp: str): + """Get the OAuth client.""" + if idp not in self.oauth._registry: + raise ValueError(f"'{idp}' is not a valid registered idp") + + client: Union[FlaskOAuth1App, FlaskOAuth2App] = ( + self.oauth.create_client(idp) + ) + return client + + def get_oauth_kwargs(self, idp: str): + """Get the OAuth kwargs.""" + if idp not in self.oauth._registry: + raise ValueError(f"'{idp}' is not a valid registered idp") + + kwargs: dict = ( + self.oauth._registry[idp][1] + ) + return kwargs + + def _create_redirect_uri(self, idp: str): + """Create the redirect uri based on callback endpoint and idp.""" + kwargs = {"_external": True} + if self.force_https_callback: + kwargs["_scheme"] = "https" + redirect_uri = url_for("oidc_callback", idp=idp, **kwargs) + if request.headers.get("X-Forwarded-Host"): + host = request.headers.get("X-Forwarded-Host") + redirect_uri = redirect_uri.replace(request.host, host, 1) + return redirect_uri + + def login_request(self, idp: str = None): + """Start the login process.""" + + # `idp` can be none here as login_request is called + # without arguments in the before_request hook + if idp not in self.oauth._registry: + # If only one provider is registered, we don't need to + # ask the user to pick one, just use the one + if len(self.oauth._registry) == 1: + idp = next(iter(self.oauth._clients)) + # If there are several providers and a `idp_selection_route` + # was provided, redirect to it. + elif self.idp_selection_route: + return redirect(self.idp_selection_route) + else: + return ( + "Several OAuth providers are registered. " + "Please choose one.", + 400, + ) + + redirect_uri = self._create_redirect_uri(idp) + oauth_client = self.get_oauth_client(idp) + oauth_kwargs = self.get_oauth_kwargs(idp) + return oauth_client.authorize_redirect( + redirect_uri, + **oauth_kwargs.get("authorize_redirect_kwargs", {}), + ) + + def logout(self): # pylint: disable=C0116 + """Logout the user.""" + session.clear() + base_url = self.app.config.get("url_base_pathname") or "/" + page = self.logout_page or f""" +
+
Logged out successfully
+ +
+ """ + return page + + def callback(self, idp: str): # pylint: disable=C0116 + """Do the OIDC dance.""" + if idp not in self.oauth._registry: + return f"'{idp}' is not a valid registered idp", 400 + + oauth_client = self.get_oauth_client(idp) + oauth_kwargs = self.get_oauth_kwargs(idp) + try: + token = oauth_client.authorize_access_token( + **oauth_kwargs.get("authorize_token_kwargs", {}), + ) + except OAuthError as err: + return str(err), 401 + user = token.get("userinfo") + if user: + session["user"] = user + session["idp"] = idp + if "offline_access" in oauth_client.client_kwargs["scope"]: + session["refresh_token"] = token.get("refresh_token") + if self.log_signins: + logging.info("User %s is logging in.", user.get("email")) + + return redirect(self.app.config.get("url_base_pathname") or "/") + + def is_authorized(self): # pylint: disable=C0116 + """Check whether ther user is authenticated.""" + + map_adapter = Map( + [ + Rule(x) + for x in [ + self.login_route, + self.logout_route, + self.callback_route, + self.idp_selection_route, + ] + if x + ] + ).bind("") + return map_adapter.test(request.path) or "user" in session + + +def get_oauth(app: dash.Dash = None) -> OAuth: + """Retrieve the OAuth object. + + :param app: dash.Dash + Dash app or None, if None the current app is used + calling `dash.get_app()` + """ + if app is None: + app = dash.get_app() + + oauth = getattr(app.server, "extensions", {}).get( + "authlib.integrations.flask_client" + ) + if oauth is not None: + return oauth + + raise RuntimeError( + "OAuth object is not yet defined. `OIDCAuth(app, **kwargs)` needs " + "to be run before `get_oauth` is called." + ) diff --git a/dev-requirements.txt b/dev-requirements.txt index 10e3822..bc7984c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -4,3 +4,4 @@ flake8 flask werkzeug pytest +authlib diff --git a/setup.py b/setup.py index 3a71423..8ba759b 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,9 @@ 'flask', 'werkzeug', ], + extras_require={ + "oidc": ["authlib"], + }, python_requires=">=3.8", include_package_data=True, url='https://plotly.com/dash', diff --git a/tests/test_basic_auth_integration.py b/tests/test_basic_auth_integration.py index bb2cac6..1a6a534 100644 --- a/tests/test_basic_auth_integration.py +++ b/tests/test_basic_auth_integration.py @@ -1,7 +1,7 @@ from dash import Dash, Input, Output, dcc, html import requests -from dash_auth import basic_auth, add_public_routes +from dash_auth import BasicAuth, add_public_routes, protected TEST_USERS = { @@ -26,7 +26,7 @@ def test_ba001_basic_auth_login_flow(dash_br, dash_thread_server): def update_output(new_value): return new_value - basic_auth.BasicAuth(app, TEST_USERS["valid"], public_routes=["/home"]) + BasicAuth(app, TEST_USERS["valid"], public_routes=["/home"]) add_public_routes(app, ["/user//public"]) dash_thread_server(app) @@ -58,3 +58,47 @@ def test_successful_views(url): # visiting the page again will use the saved credentials dash_br.driver.get(base_url) dash_br.wait_for_text_to_equal("#output", "initial value") + + +def test_ba002_basic_auth_groups(dash_br, dash_thread_server): + app = Dash(__name__) + app.layout = html.Div([ + dcc.Input(id="input", value="initial value"), + html.Div(id="output") + ]) + + @app.callback( + Output("output", "children"), + Input("input", "value"), + groups=["admin"], + ) + @protected( + unauthenticated_output="unauthenticated", + missing_permissions_output="forbidden", + groups=["admin"], + ) + def update_output(new_value): + return new_value + + BasicAuth( + app, + TEST_USERS["valid"], + public_routes=["/home"], + user_groups={"hello": ["admin"]}, + secret_key="Test!", + ) + + dash_thread_server(app) + base_url = dash_thread_server.url + + for user, password in TEST_USERS["valid"]: + # login using the URL instead of the alert popup + # selenium has no way of accessing the alert popup + dash_br.driver.get(base_url.replace("//", f"//{user}:{password}@")) + + # the username:password@host url doesn"t work right now for dash + # routes, but it saves the credentials as part of the browser. + # visiting the page again will use the saved credentials + dash_br.driver.get(base_url) + expected = "initial value" if user == "hello" else "forbidden" + dash_br.wait_for_text_to_equal("#output", expected) diff --git a/tests/test_group_protection.py b/tests/test_group_protection.py new file mode 100644 index 0000000..f8be30d --- /dev/null +++ b/tests/test_group_protection.py @@ -0,0 +1,52 @@ +from dash_auth import list_groups, check_groups, protected +from flask import Flask, session + + +def test_gp001_list_groups(): + app = Flask(__name__) + app.secret_key = "Test!" + with app.test_request_context("/", method="GET"): + session["user"] = {"email": "a.b@mail.com", "groups": ["default"], "tenant": "ABC"} + assert list_groups() == ["default"] + assert list_groups(groups_key="tenant", groups_str_split=",") == ["ABC"] + + +def test_gp002_check_groups(): + app = Flask(__name__) + app.secret_key = "Test!" + with app.test_request_context("/", method="GET"): + session["user"] = {"email": "a.b@mail.com", "groups": ["default"], "tenant": "ABC"} + assert check_groups(["default"]) is True + assert check_groups(["other"]) is False + assert check_groups(["default", "other"]) is True + assert check_groups(["other", "default"], check_type="all_of") is False + assert check_groups(["default"], check_type="all_of") is True + assert check_groups(["other", "default"], check_type="none_of") is False + assert check_groups(["other"], check_type="none_of") is True + + +def test_gp003_protected(): + app = Flask(__name__) + app.secret_key = "Test!" + + def func(): + return "success" + + with app.test_request_context("/", method="GET"): + session["user"] = {"email": "a.b@mail.com", "groups": ["default"], "tenant": "ABC"} + f0 = protected( + unauthenticated_output="unauthenticated", + missing_permissions_output="forbidden", + groups=["default"], + )(func) + assert f0() == "success" + + f1 = protected( + unauthenticated_output="unauthenticated", + missing_permissions_output="forbidden", + groups=["admin"], + )(func) + assert f1() == "forbidden" + + del session["user"] + assert f1() == "unauthenticated" diff --git a/tests/test_oidc_auth.py b/tests/test_oidc_auth.py new file mode 100644 index 0000000..5442a67 --- /dev/null +++ b/tests/test_oidc_auth.py @@ -0,0 +1,188 @@ +import os +from unittest.mock import patch + +import requests +from dash import Dash, Input, Output, dcc, html +from flask import redirect + +from dash_auth import ( + protected_callback, + OIDCAuth, +) + + +def valid_authorize_redirect(_, redirect_uri, *args, **kwargs): + return redirect("/" + redirect_uri.split("/", maxsplit=3)[-1]) + + +def invalid_authorize_redirect(_, redirect_uri, *args, **kwargs): + base_url = "/" + redirect_uri.split("/", maxsplit=3)[-1] + return redirect(f"{base_url}?error=Unauthorized&error_description=something went wrong") + + +def valid_authorize_access_token(*args, **kwargs): + return { + "userinfo": {"email": "a.b@mail.com", "groups": ["viewer", "editor"]}, + "refresh_token": "ABCDEF", + } + + +@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect) +@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token) +def test_oa001_oidc_auth_login_flow_success(dash_br, dash_thread_server): + app = Dash(__name__) + app.layout = html.Div([ + dcc.Input(id="input", value="initial value"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div("static", id="output3"), + html.Div("static", id="output4"), + html.Div("not static", id="output5"), + ]) + + @app.callback(Output("output1", "children"), Input("input", "value")) + def update_output1(new_value): + return new_value + + @protected_callback( + Output("output2", "children"), + Input("input", "value"), + groups=["editor"], + check_type="one_of", + ) + def update_output2(new_value): + return new_value + + @protected_callback( + Output("output3", "children"), + Input("input", "value"), + groups=["admin"], + check_type="one_of", + ) + def update_output3(new_value): + return new_value + + @protected_callback( + Output("output4", "children"), + Input("input", "value"), + groups=["viewer"], + check_type="none_of", + ) + def update_output4(new_value): + return new_value + + @protected_callback( + Output("output5", "children"), + Input("input", "value"), + groups=["viewer", "editor"], + check_type="all_of", + ) + def update_output5(new_value): + return new_value + + oidc = OIDCAuth(app, secret_key="Test") + oidc.register_provider( + "oidc", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration", + ) + dash_thread_server(app) + base_url = dash_thread_server.url + + assert requests.get(base_url).status_code == 200 + + dash_br.driver.get(base_url) + dash_br.wait_for_text_to_equal("#output1", "initial value") + dash_br.wait_for_text_to_equal("#output2", "initial value") + dash_br.wait_for_text_to_equal("#output3", "static") + dash_br.wait_for_text_to_equal("#output4", "static") + dash_br.wait_for_text_to_equal("#output5", "initial value") + + +@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", invalid_authorize_redirect) +def test_oa002_oidc_auth_login_fail(dash_thread_server): + app = Dash(__name__) + app.layout = html.Div([ + dcc.Input(id="input", value="initial value"), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(new_value): + return new_value + + oidc = OIDCAuth(app, public_routes=["/public"], secret_key="Test") + oidc.register_provider( + "oidc", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration", + ) + dash_thread_server(app) + base_url = dash_thread_server.url + + def test_unauthorized(url): + r = requests.get(url) + assert r.status_code == 401 + assert r.text == "Unauthorized: something went wrong" + + def test_authorized(url): + assert requests.get(url).status_code == 200 + + test_unauthorized(base_url) + test_authorized(os.path.join(base_url, "public")) + + +@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect) +@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token) +def test_oa003_oidc_auth_login_several_idp(dash_br, dash_thread_server): + app = Dash(__name__) + app.layout = html.Div([ + dcc.Input(id="input", value="initial value"), + html.Div(id="output1"), + ]) + + @app.callback(Output("output1", "children"), Input("input", "value")) + def update_output1(new_value): + return new_value + + oidc = OIDCAuth(app, secret_key="Test") + # Add a first provider + oidc.register_provider( + "idp1", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration", + ) + # Add a second provider + oidc.register_provider( + "idp2", + token_endpoint_auth_method="client_secret_post", + client_id="", + client_secret="", + server_metadata_url="https://idp2.com/oidc/2/.well-known/openid-configuration", + ) + + dash_thread_server(app) + base_url = dash_thread_server.url + + assert requests.get(base_url).status_code == 400 + + # Login with IDP1 + assert requests.get(os.path.join(base_url, "oidc/idp1/login")).status_code == 200 + + # Logout + assert requests.get(os.path.join(base_url, "oidc/logout")).status_code == 200 + + assert requests.get(base_url).status_code == 400 + + # Login with IDP2 + assert requests.get(os.path.join(base_url, "oidc/idp2/login")).status_code == 200 + + dash_br.driver.get(os.path.join(base_url, "oidc/idp2/login")) + dash_br.driver.get(base_url) + dash_br.wait_for_text_to_equal("#output1", "initial value")