From 5cf4d6a309a0034e2dbb0bc0383da6ea443cd82f Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Wed, 29 Oct 2014 17:20:19 +0800 Subject: [PATCH] Use a function to get csrf token. https://github.com/lepture/flask-wtf/pull/159#discussion_r19521736 --- flask_wtf/csrf.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/flask_wtf/csrf.py b/flask_wtf/csrf.py index 9411c6c8..6ae5e017 100644 --- a/flask_wtf/csrf.py +++ b/flask_wtf/csrf.py @@ -146,6 +146,22 @@ def init_app(self, app): app.config.setdefault('WTF_CSRF_ENABLED', True) app.config.setdefault('WTF_CSRF_METHODS', ['POST', 'PUT', 'PATCH']) + def _get_csrf_token(): + # find the ``csrf_token`` field in the subitted form + # if the form had a prefix, the name will be + # ``{prefix}-csrf_token`` + for key in request.form: + if key.endswith('csrf_token'): + csrf_token = request.form[key] + if csrf_token: + return csrf_token + + for header_name in app.config['WTF_CSRF_HEADERS']: + csrf_token = request.headers.get(header_name) + if csrf_token: + return csrf_token + return None + # expose csrf_token as a helper in all templates @app.context_processor def csrf_token(): @@ -157,7 +173,7 @@ def _csrf_protect(): if not app.config['WTF_CSRF_ENABLED']: return - if request.method in ('GET', 'HEAD', 'OPTIONS', 'TRACE'): + if request.method not in app.config['WTF_CSRF_METHODS']: return if self._exempt_views or self._exempt_blueprints: @@ -174,18 +190,7 @@ def _csrf_protect(): if request.blueprint in self._exempt_blueprints: return - csrf_token = None - if request.method in app.config['WTF_CSRF_METHODS']: - # find the ``csrf_token`` field in the subitted form - # if the form had a prefix, the name will be ``{prefix}-csrf_token`` - for key in request.form: - if key.endswith('csrf_token'): - csrf_token = request.form[key] - for header_name in app.config['WTF_CSRF_HEADERS']: - if csrf_token is not None: - break - csrf_token = request.headers.get(header_name) - if not validate_csrf(csrf_token): + if not validate_csrf(_get_csrf_token()): reason = 'CSRF token missing or incorrect.' return self._error_response(reason)