diff --git a/CHANGELOG.md b/CHANGELOG.md index 4171da9b58..0dbf505395 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ This project adheres to [Semantic Versioning](https://semver.org/). ### Added - [#1702](https://github.com/plotly/dash/pull/1702) Added a new `@app.long_callback` decorator to support callback functions that take a long time to run. See the PR and documentation for more information. - [#1514](https://github.com/plotly/dash/pull/1514) Perform json encoding using the active plotly JSON engine. This will default to the faster orjson encoder if the `orjson` package is installed. +- [#1736](https://github.com/plotly/dash/pull/1736) Add support for `request_refresh_jwt` hook and retry requests that used expired JWT tokens. ### Changed - [#1679](https://github.com/plotly/dash/pull/1679) Restructure `dash`, `dash-core-components`, `dash-html-components`, and `dash-table` into a singular monorepo and move component packages into `dash`. This change makes the component modules available for import within the `dash` namespace, and simplifies the import pattern for a Dash app. From a development standpoint, all future changes to component modules will be made within the `components` directory, and relevant packages updated with the `dash-update-components` CLI command. diff --git a/dash/dash-renderer/src/AppContainer.react.js b/dash/dash-renderer/src/AppContainer.react.js index 92f4325bbf..0ef7af6412 100644 --- a/dash/dash-renderer/src/AppContainer.react.js +++ b/dash/dash-renderer/src/AppContainer.react.js @@ -6,16 +6,29 @@ import Loading from './components/core/Loading.react'; import Toolbar from './components/core/Toolbar.react'; import Reloader from './components/core/Reloader.react'; import {setHooks, setConfig} from './actions/index'; -import {type} from 'ramda'; +import {type, memoizeWith, identity} from 'ramda'; class UnconnectedAppContainer extends React.Component { constructor(props) { super(props); if ( props.hooks.request_pre !== null || - props.hooks.request_post !== null + props.hooks.request_post !== null || + props.hooks.request_refresh_jwt !== null ) { - props.dispatch(setHooks(props.hooks)); + let hooks = props.hooks; + + if (hooks.request_refresh_jwt) { + hooks = { + ...hooks, + request_refresh_jwt: memoizeWith( + identity, + hooks.request_refresh_jwt + ) + }; + } + + props.dispatch(setHooks(hooks)); } } diff --git a/dash/dash-renderer/src/AppProvider.react.tsx b/dash/dash-renderer/src/AppProvider.react.tsx index 4d811eb130..4069715c04 100644 --- a/dash/dash-renderer/src/AppProvider.react.tsx +++ b/dash/dash-renderer/src/AppProvider.react.tsx @@ -18,14 +18,16 @@ const AppProvider = ({hooks}: any) => { AppProvider.propTypes = { hooks: PropTypes.shape({ request_pre: PropTypes.func, - request_post: PropTypes.func + request_post: PropTypes.func, + request_refresh_jwt: PropTypes.func }) }; AppProvider.defaultProps = { hooks: { request_pre: null, - request_post: null + request_post: null, + request_refresh_jwt: null } }; diff --git a/dash/dash-renderer/src/actions/api.js b/dash/dash-renderer/src/actions/api.js index a2f422aa3a..3f617d5c11 100644 --- a/dash/dash-renderer/src/actions/api.js +++ b/dash/dash-renderer/src/actions/api.js @@ -1,6 +1,8 @@ import {mergeDeepRight, once} from 'ramda'; -import {handleAsyncError, getCSRFHeader} from '../actions'; +import {getCSRFHeader, handleAsyncError, addHttpHeaders} from '../actions'; import {urlBase} from './utils'; +import {MAX_AUTH_RETRIES} from './constants'; +import {JWT_EXPIRED_MESSAGE, STATUS} from '../constants/constants'; /* eslint-disable-next-line no-console */ const logWarningOnce = once(console.warn); @@ -29,8 +31,10 @@ function POST(path, fetchConfig, body = {}) { const request = {GET, POST}; export default function apiThunk(endpoint, method, store, id, body) { - return (dispatch, getState) => { - const {config} = getState(); + return async (dispatch, getState) => { + let {config, hooks} = getState(); + let newHeaders = null; + const url = `${urlBase(config)}${endpoint}`; function setConnectionStatus(connected) { @@ -46,48 +50,81 @@ export default function apiThunk(endpoint, method, store, id, body) { type: store, payload: {id, status: 'loading'} }); - return request[method](url, config.fetch, body) - .then( - res => { - setConnectionStatus(true); - const contentType = res.headers.get('content-type'); - if ( - contentType && - contentType.indexOf('application/json') !== -1 - ) { - return res.json().then(json => { - dispatch({ - type: store, - payload: { - status: res.status, - content: json, - id - } - }); - return json; - }); + + try { + let res; + for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) { + try { + res = await request[method](url, config.fetch, body); + } catch (e) { + // fetch rejection - this means the request didn't return, + // we don't get here from 400/500 errors, only network + // errors or unresponsive servers. + console.log('fetch error', res); + setConnectionStatus(false); + return; + } + + if (res.status === STATUS.UNAUTHORIZED) { + if (hooks.request_refresh_jwt) { + const body = await res.text(); + if (body.includes(JWT_EXPIRED_MESSAGE)) { + const newJwt = await hooks.request_refresh_jwt( + config.fetch.headers.Authorization.substr( + 'Bearer '.length + ) + ); + if (newJwt) { + newHeaders = { + Authorization: `Bearer ${newJwt}` + }; + + config = mergeDeepRight(config, { + fetch: { + headers: newHeaders + } + }); + + continue; + } + } } - logWarningOnce( - 'Response is missing header: content-type: application/json' - ); - return dispatch({ + } + break; + } + + const contentType = res.headers.get('content-type'); + + if (newHeaders) { + dispatch(addHttpHeaders(newHeaders)); + } + setConnectionStatus(true); + if (contentType && contentType.indexOf('application/json') !== -1) { + return res.json().then(json => { + dispatch({ type: store, payload: { - id, - status: res.status + status: res.status, + content: json, + id } }); - }, - () => { - // fetch rejection - this means the request didn't return, - // we don't get here from 400/500 errors, only network - // errors or unresponsive servers. - setConnectionStatus(false); + return json; + }); + } + logWarningOnce( + 'Response is missing header: content-type: application/json' + ); + return dispatch({ + type: store, + payload: { + id, + status: res.status } - ) - .catch(err => { - const message = 'Error from API call: ' + endpoint; - handleAsyncError(err, message, dispatch); }); + } catch (err) { + const message = 'Error from API call: ' + endpoint; + handleAsyncError(err, message, dispatch); + } }; } diff --git a/dash/dash-renderer/src/actions/callbacks.ts b/dash/dash-renderer/src/actions/callbacks.ts index 1a59cfa240..2769d4c5eb 100644 --- a/dash/dash-renderer/src/actions/callbacks.ts +++ b/dash/dash-renderer/src/actions/callbacks.ts @@ -10,7 +10,8 @@ import { zip } from 'ramda'; -import {STATUS} from '../constants/constants'; +import {STATUS, JWT_EXPIRED_MESSAGE} from '../constants/constants'; +import {MAX_AUTH_RETRIES} from './constants'; import { CallbackActionType, CallbackAggregateActionType @@ -29,6 +30,7 @@ import {isMultiValued, stringifyId, isMultiOutputProp} from './dependencies'; import {urlBase} from './utils'; import {getCSRFHeader} from '.'; import {createAction, Action} from 'redux-actions'; +import {addHttpHeaders} from '../actions'; export const addBlockedCallbacks = createAction( CallbackActionType.AddBlocked @@ -306,7 +308,7 @@ function handleServerside( config: any, payload: any ): Promise { - if (hooks.request_pre !== null) { + if (hooks.request_pre) { hooks.request_pre(payload); } @@ -364,7 +366,7 @@ function handleServerside( if (status === STATUS.OK) { return res.json().then((data: any) => { const {multi, response} = data; - if (hooks.request_post !== null) { + if (hooks.request_post) { hooks.request_post(payload, response); } @@ -488,7 +490,7 @@ export function executeCallback( }; } - const __promise = new Promise(resolve => { + const __execute = async (): Promise => { try { const payload: ICallbackPayload = { output, @@ -502,7 +504,7 @@ export function executeCallback( if (clientside_function) { try { - resolve({ + return { data: handleClientside( dispatch, clientside_function, @@ -510,24 +512,81 @@ export function executeCallback( payload ), payload - }); + }; } catch (error) { - resolve({error, payload}); + return {error, payload}; } - return null; } - handleServerside(dispatch, hooks, config, payload) - .then(data => resolve({data, payload})) - .catch(error => resolve({error, payload})); + let newConfig = config; + let newHeaders: Record | null = null; + let lastError: any; + + for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) { + try { + const data = await handleServerside( + dispatch, + hooks, + newConfig, + payload + ); + + if (newHeaders) { + dispatch(addHttpHeaders(newHeaders)); + } + + return {data, payload}; + } catch (res) { + lastError = res; + if ( + retry <= MAX_AUTH_RETRIES && + res.status === STATUS.UNAUTHORIZED + ) { + const body = await res.text(); + + if (body.includes(JWT_EXPIRED_MESSAGE)) { + if (hooks.request_refresh_jwt !== null) { + let oldJwt = null; + if (config.fetch.headers.Authorization) { + oldJwt = + config.fetch.headers.Authorization.substr( + 'Bearer '.length + ); + } + + const newJwt = + await hooks.request_refresh_jwt(oldJwt); + if (newJwt) { + newHeaders = { + Authorization: `Bearer ${newJwt}` + }; + + newConfig = mergeDeepRight(config, { + fetch: { + headers: newHeaders + } + }); + + continue; + } + } + } + } + + break; + } + } + + // we reach here when we run out of retries. + return {error: lastError, payload: null}; } catch (error) { - resolve({error, payload: null}); + return {error, payload: null}; } - }); + }; const newCb = { ...cb, - executionPromise: __promise + executionPromise: __execute() }; return newCb; diff --git a/dash/dash-renderer/src/actions/constants.js b/dash/dash-renderer/src/actions/constants.js index 3d61debf6c..352f25be88 100644 --- a/dash/dash-renderer/src/actions/constants.js +++ b/dash/dash-renderer/src/actions/constants.js @@ -6,6 +6,7 @@ const actionList = { SET_LAYOUT: 1, SET_APP_LIFECYCLE: 1, SET_CONFIG: 1, + ADD_HTTP_HEADERS: 1, ON_ERROR: 1, SET_HOOKS: 1 }; @@ -16,3 +17,5 @@ export const getAction = action => { } throw new Error(`${action} is not defined.`); }; + +export const MAX_AUTH_RETRIES = 1; diff --git a/dash/dash-renderer/src/actions/index.js b/dash/dash-renderer/src/actions/index.js index 10bf413fb0..2b1dd51324 100644 --- a/dash/dash-renderer/src/actions/index.js +++ b/dash/dash-renderer/src/actions/index.js @@ -11,6 +11,7 @@ import {getPath} from './paths'; export const onError = createAction(getAction('ON_ERROR')); export const setAppLifecycle = createAction(getAction('SET_APP_LIFECYCLE')); export const setConfig = createAction(getAction('SET_CONFIG')); +export const addHttpHeaders = createAction(getAction('ADD_HTTP_HEADERS')); export const setGraphs = createAction(getAction('SET_GRAPHS')); export const setHooks = createAction(getAction('SET_HOOKS')); export const setLayout = createAction(getAction('SET_LAYOUT')); diff --git a/dash/dash-renderer/src/constants/constants.js b/dash/dash-renderer/src/constants/constants.js index c37b3b1f75..80d50fd164 100644 --- a/dash/dash-renderer/src/constants/constants.js +++ b/dash/dash-renderer/src/constants/constants.js @@ -1,9 +1,11 @@ export const REDIRECT_URI_PATHNAME = '/_oauth2/callback'; export const OAUTH_COOKIE_NAME = 'plotly_oauth_token'; +export const JWT_EXPIRED_MESSAGE = 'JWT Expired'; export const STATUS = { OK: 200, PREVENT_UPDATE: 204, + UNAUTHORIZED: 401, CLIENTSIDE_ERROR: 'CLIENTSIDE_ERROR', NO_RESPONSE: 'NO_RESPONSE' }; diff --git a/dash/dash-renderer/src/reducers/config.js b/dash/dash-renderer/src/reducers/config.js index 97c5fb2a1d..208944533c 100644 --- a/dash/dash-renderer/src/reducers/config.js +++ b/dash/dash-renderer/src/reducers/config.js @@ -1,8 +1,15 @@ import {getAction} from '../actions/constants'; +import {mergeDeepRight} from 'ramda'; export default function config(state = null, action) { if (action.type === getAction('SET_CONFIG')) { return action.payload; + } else if (action.type === getAction('ADD_HTTP_HEADERS')) { + return mergeDeepRight(state, { + fetch: { + headers: action.payload + } + }); } return state; } diff --git a/dash/dash-renderer/src/reducers/hooks.js b/dash/dash-renderer/src/reducers/hooks.js index 21f5f5a9ae..2fdb665d68 100644 --- a/dash/dash-renderer/src/reducers/hooks.js +++ b/dash/dash-renderer/src/reducers/hooks.js @@ -1,5 +1,10 @@ const customHooks = ( - state = {request_pre: null, request_post: null, bear: false}, + state = { + request_pre: null, + request_post: null, + request_refresh_jwt: null, + bear: false + }, action ) => { switch (action.type) { diff --git a/tests/integration/renderer/test_request_hooks.py b/tests/integration/renderer/test_request_hooks.py index 72ed12e486..0afaf56ea8 100644 --- a/tests/integration/renderer/test_request_hooks.py +++ b/tests/integration/renderer/test_request_hooks.py @@ -1,6 +1,9 @@ import json +import functools +import flask from dash import Dash, Output, Input, html, dcc +from werkzeug.exceptions import HTTPException def test_rdrh001_request_hooks(dash_duo): @@ -185,3 +188,104 @@ def update_output(value): assert dash_duo.find_element("#output-post").text == "request_post!!!" dash_duo.percy_snapshot(name="request-hooks interpolated") + + +def test_rdrh003_refresh_jwt(dash_duo): + + app = Dash(__name__) + + app.index_string = """ + + + {%metas%} + {%title%} + {%favicon%} + {%css%} + + +
Testing custom DashRenderer
+ {%app_entry%} +
+ {%config%} + {%scripts%} + +
+
With request hooks
+ + """ + + app.layout = html.Div( + [ + dcc.Input(id="input", value="initial value"), + html.Div(html.Div([html.Div(id="output-1"), html.Div(id="output-token")])), + ] + ) + + @app.callback(Output("output-1", "children"), [Input("input", "value")]) + def update_output(value): + return value + + required_jwt_len = 0 + + # test with an auth layer that requires a JWT with a certain length + def protect_route(func): + @functools.wraps(func) + def wrap(*args, **kwargs): + try: + if flask.request.method == "OPTIONS": + return func(*args, **kwargs) + token = ( + flask.request.authorization + or flask.request.headers.environ.get("HTTP_AUTHORIZATION") + ) + if required_jwt_len and ( + not token or len(token) != required_jwt_len + len("Bearer ") + ): + flask.abort(401, description="JWT Expired " + str(token)) + except HTTPException as e: + return e + return func(*args, **kwargs) + + return wrap + + # wrap all API calls with auth. + for name, method in ( + (x, app.server.view_functions[x]) + for x in app.routes + if x in app.server.view_functions + ): + app.server.view_functions[name] = protect_route(method) + + dash_duo.start_server(app) + + _in = dash_duo.find_element("#input") + dash_duo.clear_input(_in) + + required_jwt_len = 1 + + _in.send_keys("fired request") + + dash_duo.wait_for_text_to_equal("#output-1", "fired request") + dash_duo.wait_for_text_to_equal("#output-token", ".") + + required_jwt_len = 2 + + dash_duo.clear_input(_in) + _in.send_keys("fired request again") + + dash_duo.wait_for_text_to_equal("#output-1", "fired request again") + dash_duo.wait_for_text_to_equal("#output-token", "..") + + dash_duo.percy_snapshot(name="request-hooks jwt-refresh")