From e85808b9e7432f4e5b782032e32ad8cc3939e7d1 Mon Sep 17 00:00:00 2001 From: Nuno Santos Date: Fri, 13 Feb 2015 15:38:55 +0100 Subject: [PATCH] Abstract protect_csrf() into a separate method. Right now, CsrfProtect._csrf_protect() does 1) check if this view should be checked for a CSRF token, and 2) validate the actual token. This commit abstracts 2) into a separate method so we can manually call this method (for example on a before_request callback). This makes it possible to do further checks before validating the CSRF (e.g. skip the check for REST calls using OAuth). This commit also adds a configuration parameter WTF_CSRF_CHECK_DEFAULT, which will determine whether to check all views by default or not. It defaults to True. --- flask_wtf/csrf.py | 67 +++++++++++++++++++++++++++------------------- tests/test_csrf.py | 27 +++++++++++++++++++ 2 files changed, 66 insertions(+), 28 deletions(-) diff --git a/flask_wtf/csrf.py b/flask_wtf/csrf.py index 6ae5e017..45877615 100644 --- a/flask_wtf/csrf.py +++ b/flask_wtf/csrf.py @@ -138,30 +138,16 @@ def __init__(self, app=None): self.init_app(app) def init_app(self, app): + self._app = app app.jinja_env.globals['csrf_token'] = generate_csrf app.config.setdefault( 'WTF_CSRF_HEADERS', ['X-CSRFToken', 'X-CSRF-Token'] ) app.config.setdefault('WTF_CSRF_SSL_STRICT', True) app.config.setdefault('WTF_CSRF_ENABLED', True) + app.config.setdefault('WTF_CSRF_CHECK_DEFAULT', True) app.config.setdefault('WTF_CSRF_METHODS', ['POST', 'PUT', 'PATCH']) - def _get_csrf_token(): - # find the ``csrf_token`` field in the subitted form - # if the form had a prefix, the name will be - # ``{prefix}-csrf_token`` - for key in request.form: - if key.endswith('csrf_token'): - csrf_token = request.form[key] - if csrf_token: - return csrf_token - - for header_name in app.config['WTF_CSRF_HEADERS']: - csrf_token = request.headers.get(header_name) - if csrf_token: - return csrf_token - return None - # expose csrf_token as a helper in all templates @app.context_processor def csrf_token(): @@ -173,6 +159,9 @@ def _csrf_protect(): if not app.config['WTF_CSRF_ENABLED']: return + if not app.config['WTF_CSRF_CHECK_DEFAULT']: + return + if request.method not in app.config['WTF_CSRF_METHODS']: return @@ -190,21 +179,43 @@ def _csrf_protect(): if request.blueprint in self._exempt_blueprints: return - if not validate_csrf(_get_csrf_token()): - reason = 'CSRF token missing or incorrect.' - return self._error_response(reason) + self.protect() + + def _get_csrf_token(self): + # find the ``csrf_token`` field in the subitted form + # if the form had a prefix, the name will be + # ``{prefix}-csrf_token`` + for key in request.form: + if key.endswith('csrf_token'): + csrf_token = request.form[key] + if csrf_token: + return csrf_token + + for header_name in self._app.config['WTF_CSRF_HEADERS']: + csrf_token = request.headers.get(header_name) + if csrf_token: + return csrf_token + return None + + def protect(self): + if request.method not in self._app.config['WTF_CSRF_METHODS']: + return - if request.is_secure and app.config['WTF_CSRF_SSL_STRICT']: - if not request.referrer: - reason = 'Referrer checking failed - no Referrer.' - return self._error_response(reason) + if not validate_csrf(self._get_csrf_token()): + reason = 'CSRF token missing or incorrect.' + return self._error_response(reason) - good_referrer = 'https://%s/' % request.host - if not same_origin(request.referrer, good_referrer): - reason = 'Referrer checking failed - origin not match.' - return self._error_response(reason) + if request.is_secure and self._app.config['WTF_CSRF_SSL_STRICT']: + if not request.referrer: + reason = 'Referrer checking failed - no Referrer.' + return self._error_response(reason) + + good_referrer = 'https://%s/' % request.host + if not same_origin(request.referrer, good_referrer): + reason = 'Referrer checking failed - origin does not match.' + return self._error_response(reason) - request.csrf_valid = True # mark this request is csrf valid + request.csrf_valid = True # mark this request is csrf valid def exempt(self, view): """A decorator that can exclude a view from csrf protection. diff --git a/tests/test_csrf.py b/tests/test_csrf.py index b7c1ee79..5732644c 100644 --- a/tests/test_csrf.py +++ b/tests/test_csrf.py @@ -38,6 +38,12 @@ def csrf_exempt(): "index.html", form=form, name=name ) + @csrf.exempt + @app.route('/csrf-protect-method', methods=['GET', 'POST']) + def csrf_protect_method(): + csrf.protect() + return 'protected' + bp = Blueprint('csrf', __name__) @bp.route('/foo', methods=['GET', 'POST']) @@ -170,6 +176,27 @@ def test_valid_secure_csrf(self): ) assert response.status_code == 200 + def test_valid_csrf_method(self): + response = self.client.get("/") + csrf_token = get_csrf_token(response.data) + + response = self.client.post("/csrf-protect-method", data={ + "csrf_token": csrf_token + }) + assert response.status_code == 200 + + def test_invalid_csrf_method(self): + response = self.client.post("/csrf-protect-method", data={"name": "danny"}) + assert response.status_code == 400 + + @self.csrf.error_handler + def invalid(reason): + return reason + + response = self.client.post("/", data={"name": "danny"}) + assert response.status_code == 200 + assert b'token missing' in response.data + def test_empty_csrf_headers(self): response = self.client.get("/", base_url='https://localhost/') csrf_token = get_csrf_token(response.data)