From 56d728deb0ea4be23e295bb33005adf379d72a40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Magnus=20Hartvig=20Gr=C3=B8nbech?= Date: Tue, 3 Jul 2018 01:45:26 +0200 Subject: [PATCH 1/3] Added additional claims functionality to lib --- flask_jwt_extended/config.py | 4 + flask_jwt_extended/default_callbacks.py | 22 +++ flask_jwt_extended/exceptions.py | 12 ++ flask_jwt_extended/jwt_manager.py | 75 ++++++++- flask_jwt_extended/tokens.py | 49 ++++-- flask_jwt_extended/utils.py | 15 +- tests/test_additional_claims_loader.py | 165 ++++++++++++++++++++ tests/test_additional_claims_verificaion.py | 114 ++++++++++++++ tests/test_config.py | 1 + 9 files changed, 442 insertions(+), 15 deletions(-) create mode 100644 tests/test_additional_claims_loader.py create mode 100644 tests/test_additional_claims_verificaion.py diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index 9d71f492..93c968e7 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -247,6 +247,10 @@ def identity_claim_key(self): def user_claims_key(self): return current_app.config['JWT_USER_CLAIMS'] + @property + def additional_claim_keys(self): + return current_app.config['JWT_ADDITIONAL_CLAIMS'] + @property def user_claims_in_refresh_token(self): return current_app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index c6d31826..085aa3f5 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -20,6 +20,16 @@ def default_user_claims_callback(userdata): return {} +def default_additional_claims_callback(userdata): + """ + By default, we add no additional claims to the access token + + :param userdata: data passed in as the ```identity``` argument to the + ```create_access_token``` and ```create_refresh_token``` + functions + """ + return {} + def default_user_identity_callback(userdata): """ By default, we use the passed in object directly as the jwt identity. @@ -91,6 +101,11 @@ def default_claims_verification_callback(user_claims): """ return True +def default_additional_claims_verification_callback(additional_claims): + """ + By default, we do not do any verification of the additional claims. + """ + return True def default_claims_verification_failed_callback(): """ @@ -98,3 +113,10 @@ def default_claims_verification_failed_callback(): error message with a 400 status code """ return jsonify({'msg': 'User claims verification failed'}), 400 + +def default_additional_claims_verification_failed_callback(): + """ + By default, if the additional claims verification failed, we return a generic + error message with a 400 status code + """ + return jsonify({'msg': 'Additional claims verification failed'}), 400 diff --git a/flask_jwt_extended/exceptions.py b/flask_jwt_extended/exceptions.py index 921d98e6..6396047f 100644 --- a/flask_jwt_extended/exceptions.py +++ b/flask_jwt_extended/exceptions.py @@ -11,6 +11,11 @@ class JWTDecodeError(JWTExtendedException): """ pass +class JWTEncodeError(JWTExtendedException): + """ + An error encoding a JWT + """ + pass class InvalidHeaderError(JWTExtendedException): """ @@ -70,3 +75,10 @@ class UserClaimsVerificationError(JWTExtendedException): indicating that the expected user claims are invalid """ pass + +class AdditionalClaimsVerificationError(JWTExtendedException): + """ + Error raised when the additional_claims_verification_callback function returns False, + indicating that the expected user claims are invalid + """ + pass diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index 586400e8..eed1a11c 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -4,9 +4,9 @@ from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import ( - JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError, + JWTDecodeError, JWTEncodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError, RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError, - UserClaimsVerificationError + UserClaimsVerificationError, AdditionalClaimsVerificationError ) from flask_jwt_extended.default_callbacks import ( default_expired_token_callback, default_user_claims_callback, @@ -14,7 +14,10 @@ default_unauthorized_callback, default_needs_fresh_token_callback, default_revoked_token_callback, default_user_loader_error_callback, default_claims_verification_callback, - default_claims_verification_failed_callback + default_claims_verification_failed_callback, + default_additional_claims_callback, + default_additional_claims_verification_callback, + default_additional_claims_verification_failed_callback ) from flask_jwt_extended.tokens import ( encode_refresh_token, encode_access_token @@ -43,6 +46,7 @@ def __init__(self, app=None): # Register the default error handler callback methods. These can be # overridden with the appropriate loader decorators self._user_claims_callback = default_user_claims_callback + self._additonal_claims_callback = default_additional_claims_callback self._user_identity_callback = default_user_identity_callback self._expired_token_callback = default_expired_token_callback self._invalid_token_callback = default_invalid_token_callback @@ -54,6 +58,8 @@ def __init__(self, app=None): self._token_in_blacklist_callback = None self._claims_verification_callback = default_claims_verification_callback self._claims_verification_failed_callback = default_claims_verification_failed_callback + self._additonal_claims_verification_callback = default_additional_claims_verification_callback + self._additonal_claims_verification_failed_callback = default_additional_claims_verification_failed_callback # Register this extension with the flask app now (if it is provided) if app is not None: @@ -101,7 +107,7 @@ def handle_invalid_token_error(e): @app.errorhandler(JWTDecodeError) def handle_jwt_decode_error(e): return self._invalid_token_callback(str(e)) - + @app.errorhandler(WrongTokenError) def handle_wrong_token_error(e): return self._invalid_token_callback(str(e)) @@ -126,6 +132,10 @@ def handler_user_load_error(e): def handle_failed_user_claims_verification(e): return self._claims_verification_failed_callback() + @app.errorhandler(AdditionalClaimsVerificationError) + def handle_failed_additional_claims_verification(e): + return self._additonal_claims_verification_failed_callback() + @staticmethod def _set_default_configuration_options(app): """ @@ -186,6 +196,7 @@ def _set_default_configuration_options(app): app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity') app.config.setdefault('JWT_USER_CLAIMS', 'user_claims') + app.config.setdefault('JWT_ADDITIONAL_CLAIMS', []) app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False) @@ -204,6 +215,21 @@ def user_claims_loader(self, callback): self._user_claims_callback = callback return callback + def additional_claims_loader(self, callback): + """ + This decorator sets the callback function for adding additional claims + the access token when :func:`~flask_jwt_extended.create_access_token` is + called. By defailt, no additional claims will be added. + + The callback function must be a function that takes only one argument, + which is the object passed into + :func:`~flask_jwt_extended.create_access_token`, and returns the custom + claims you want included in the access tokens. This returned claims + must be JSON serializable. + """ + self._additonal_claims_callback = callback + return callback + def user_identity_loader(self, callback): """ This decorator sets the callback function for getting the JSON @@ -373,6 +399,39 @@ def claims_verification_failed_loader(self, callback): self._claims_verification_failed_callback = callback return callback + def additonal_claims_verification_loader(self, callback): + """ + This decorator sets the callback function that will be called when + a protected endpoint is accessed, and will check if the custom claims + in the JWT are valid. By default, this callback is not used. The + error returned if the claims are invalid can be controlled via the + :meth:`~flask_jwt_extended.JWTManager.additonal_claims_verification_loader` + decorator. + + This callback must be a function that takes one argument, which is the + custom claims (python dict) present in the JWT, and returns `True` if the + claims are valid, or `False` otherwise. + """ + + self._additonal_claims_verification_callback = callback + return callback + + def additonal_claims_verification_failed_loader(self, callback): + """ + This decorator sets the callback function that will be called if + the :meth:`~flask_jwt_extended.JWTManager.additonal_claims_verification_loader` + callback returns False, indicating that the user claims are not valid. + The default implementation will return a 400 status code with the JSON: + + {"msg": "User claims verification failed"} + + This callback must be a function that takes no arguments, and returns + a Flask response. + """ + self._additonal_claims_verification_failed_callback = callback + return callback + + def _create_refresh_token(self, identity, expires_delta=None): if expires_delta is None: expires_delta = config.refresh_expires @@ -382,12 +441,18 @@ def _create_refresh_token(self, identity, expires_delta=None): else: user_claims = None + if config.user_claims_in_refresh_token: + additional_claims = self._additonal_claims_callback(identity) + else: + additional_claims = None + refresh_token = encode_refresh_token( identity=self._user_identity_callback(identity), secret=config.encode_key, algorithm=config.algorithm, expires_delta=expires_delta, user_claims=user_claims, + additional_claims=additional_claims, csrf=config.csrf_protect, identity_claim_key=config.identity_claim_key, user_claims_key=config.user_claims_key, @@ -406,10 +471,10 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None): expires_delta=expires_delta, fresh=fresh, user_claims=self._user_claims_callback(identity), + additional_claims=self._additonal_claims_callback(identity), csrf=config.csrf_protect, identity_claim_key=config.identity_claim_key, user_claims_key=config.user_claims_key, json_encoder=config.json_encoder ) return access_token - diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 7f500d9b..4c12e481 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -6,15 +6,19 @@ import jwt from werkzeug.security import safe_str_cmp -from flask_jwt_extended.exceptions import JWTDecodeError, CSRFError - +from flask_jwt_extended.exceptions import JWTDecodeError, JWTEncodeError, CSRFError +from flask_jwt_extended.config import config def _create_csrf_token(): return str(uuid.uuid4()) +def _check_claims(default_claims, additional_claims): + for claim in default_claims: + if claim in additional_claims: + raise JWTEncodeError("Claim %s in conflict with default claims" % str(claim)) def _encode_jwt(additional_token_data, expires_delta, secret, algorithm, - json_encoder=None): + claim_key, json_encoder=None): uid = str(uuid.uuid4()) now = datetime.datetime.utcnow() token_data = { @@ -26,15 +30,19 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm, # and the 'exp' claim is not set. if expires_delta: token_data['exp'] = now + expires_delta + + # Make sure additional_token_data is in conflict with default claims + _check_claims(['iat', 'nbf', 'jti', 'exp'], additional_token_data) token_data.update(additional_token_data) + encoded_token = jwt.encode(token_data, secret, algorithm, json_encoder=json_encoder).decode('utf-8') return encoded_token def encode_access_token(identity, secret, algorithm, expires_delta, fresh, - user_claims, csrf, identity_claim_key, user_claims_key, - json_encoder=None): + user_claims, additional_claims, csrf, identity_claim_key, + user_claims_key, json_encoder=None): """ Creates a new encoded (utf-8) access token. @@ -50,6 +58,8 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, token will remain fresh. :param user_claims: Custom claims to include in this token. This data must be json serializable + :param additional_claims: Custom claims to include in this token. Object + must be json serializable :param csrf: Whether to include a csrf double submit claim in this token (boolean) :param identity_claim_key: Which key should be used to store the identity @@ -67,18 +77,25 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, 'type': 'access', } + # Make sure additional_token_data is in conflict with default claims + _check_claims(['fresh', 'type', identity_claim_key], additional_claims) + # Don't add extra data to the token if user_claims is empty. if user_claims: token_data[user_claims_key] = user_claims + # Make sure additional claims is a dict before merge + if additional_claims and isinstance(additional_claims, dict): + token_data.update(additional_claims) + if csrf: token_data['csrf'] = _create_csrf_token() return _encode_jwt(token_data, expires_delta, secret, algorithm, - json_encoder=json_encoder) + claim_key=identity_claim_key, json_encoder=json_encoder) def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims, - csrf, identity_claim_key, user_claims_key, + additional_claims, csrf, identity_claim_key, user_claims_key, json_encoder=None): """ Creates a new encoded (utf-8) refresh token. @@ -91,6 +108,8 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims :type expires_delta: datetime.timedelta or False :param user_claims: Custom claims to include in this token. This data must be json serializable + :param additional_claims: Custom claims to include in this token. Object + must be json serializable :param csrf: Whether to include a csrf double submit claim in this token (boolean) :param identity_claim_key: Which key should be used to store the identity @@ -102,6 +121,11 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims 'type': 'refresh', } + # Make sure additional_token_data is in conflict with default claims + if additional_claims and isinstance(additional_claims, dict): + _check_claims(['type', identity_claim_key], additional_claims) + token_data.update(additional_claims) + # Don't add extra data to the token if user_claims is empty. if user_claims: token_data[user_claims_key] = user_claims @@ -109,11 +133,11 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims if csrf: token_data['csrf'] = _create_csrf_token() return _encode_jwt(token_data, expires_delta, secret, algorithm, - json_encoder=json_encoder) + claim_key=identity_claim_key, json_encoder=json_encoder) def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, - user_claims_key, csrf_value=None): + user_claims_key, additional_claim_keys, csrf_value=None): """ Decodes an encoded JWT @@ -138,8 +162,15 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, if data['type'] == 'access': if 'fresh' not in data: raise JWTDecodeError("Missing claim: fresh") + if user_claims_key not in data: data[user_claims_key] = {} + + if data['type'] != 'refresh' or config.user_claims_in_refresh_token: + for claim in additional_claim_keys: + if claim not in data: + raise JWTDecodeError("Missing claim %s" % str(claim)) + if csrf_value: if 'csrf' not in data: raise JWTDecodeError("Missing claim: csrf") diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 1e13f35c..d2a24aed 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -8,7 +8,7 @@ from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import ( - RevokedTokenError, UserClaimsVerificationError, WrongTokenError + RevokedTokenError, UserClaimsVerificationError, WrongTokenError, AdditionalClaimsVerificationError ) from flask_jwt_extended.tokens import decode_jwt @@ -42,6 +42,15 @@ def get_jwt_claims(): """ return get_raw_jwt().get(config.user_claims_key, {}) +def get_jwt_additional_claims(): + """ + In a protected endpoint, this will return the dictionary of custom claims + in the JWT that is accessing the endpoint. If no custom user claims are + present, an empty dict is returned instead. + """ + claims = get_raw_jwt() + return dict((k, claims.get(k)) for k in config.additional_claim_keys if k in claims) + def get_current_user(): """ @@ -77,6 +86,7 @@ def decode_token(encoded_token, csrf_value=None): algorithm=config.algorithm, identity_claim_key=config.identity_claim_key, user_claims_key=config.user_claims_key, + additional_claim_keys=config.additional_claim_keys, csrf_value=csrf_value ) @@ -182,6 +192,9 @@ def verify_token_claims(jwt_data): if not jwt_manager._claims_verification_callback(user_claims): raise UserClaimsVerificationError('User claims verification failed') + additional_claims = dict((k, jwt_data[k]) for k in config.additional_claim_keys if k in jwt_data) + if not jwt_manager._additonal_claims_verification_callback(additional_claims): + raise AdditionalClaimsVerificationError('Additional claims verification failed') def get_csrf_token(encoded_token): """ diff --git a/tests/test_additional_claims_loader.py b/tests/test_additional_claims_loader.py new file mode 100644 index 00000000..2e2b2d6e --- /dev/null +++ b/tests/test_additional_claims_loader.py @@ -0,0 +1,165 @@ +import pytest +from flask import Flask, jsonify + +from flask_jwt_extended import ( + JWTManager, create_access_token, jwt_required, + decode_token, jwt_refresh_token_required, create_refresh_token +) + +from flask_jwt_extended.utils import get_jwt_additional_claims +from flask_jwt_extended.exceptions import JWTEncodeError +from tests.utils import get_jwt_manager, make_headers + + +@pytest.fixture(scope='function') +def app(): + app = Flask(__name__) + app.config['JWT_SECRET_KEY'] = 'foobarbaz' + app.config['JWT_ADDITIONAL_CLAIMS'] = ['foo'] + JWTManager(app) + + @app.route('/protected', methods=['GET']) + @jwt_required + def get_claims(): + return jsonify(get_jwt_additional_claims()) + + @app.route('/protected2', methods=['GET']) + @jwt_refresh_token_required + def get_refresh_claims(): + return jsonify(get_jwt_additional_claims()) + + return app + + +def test_additional_claim_in_access_token(app): + jwt = get_jwt_manager(app) + + @jwt.additional_claims_loader + def add_claims(identity): + return {'foo': 'bar'} + + with app.test_request_context(): + access_token = create_access_token('username') + + test_client = app.test_client() + response = test_client.get('/protected', headers=make_headers(access_token)) + assert response.get_json() == {'foo': 'bar'} + assert response.status_code == 200 + + +def test_non_serializable_additional_claims(app): + jwt = get_jwt_manager(app) + + @jwt.additional_claims_loader + def add_claims(identity): + return app + + with pytest.raises(TypeError): + with app.test_request_context(): + create_access_token('username') + + +def test_default_additional_claims_error(app): + jwt = get_jwt_manager(app) + + @jwt.additional_claims_loader + def add_claims(identity): + return {'exp': 1} + + with pytest.raises(JWTEncodeError): + with app.test_request_context(): + access_token = create_access_token('username') + +def test_missing_additional_claims_error(app): + jwt = get_jwt_manager(app) + + with app.test_request_context(): + access_token = create_access_token('username') + + test_client = app.test_client() + response = test_client.get('/protected', headers=make_headers(access_token)) + assert response.get_json() == {'msg': 'Missing claim foo'} + assert response.status_code == 422 + +def test_token_from_complex_object(app): + class TestObject: + def __init__(self, username): + self.username = username + + jwt = get_jwt_manager(app) + app.config['JWT_ADDITIONAL_CLAIMS'] = ['username'] + + @jwt.additional_claims_loader + def add_claims(test_obj): + return {'username': test_obj.username} + + @jwt.user_identity_loader + def add_claims(test_obj): + return test_obj.username + + with app.test_request_context(): + access_token = create_access_token(TestObject('username')) + + # Make sure the changes appear in the token + decoded_token = decode_token(access_token) + assert decoded_token['identity'] == 'username' + assert decoded_token['username'] == 'username' + + test_client = app.test_client() + response = test_client.get('/protected', headers=make_headers(access_token)) + assert response.get_json() == {'username': 'username'} + assert response.status_code == 200 + + +def test_additional_claims_with_different_name(app): + jwt = get_jwt_manager(app) + + @jwt.additional_claims_loader + def add_claims(identity): + return {'foo': 'bar'} + + with app.test_request_context(): + access_token = create_access_token('username') + + # Make sure the name is actually different in the token + decoded_token = decode_token(access_token) + assert decoded_token['foo'] == 'bar' + + # Make sure the correct data is returned to us from the full call + test_client = app.test_client() + response = test_client.get('/protected', headers=make_headers(access_token)) + assert response.get_json() == {'foo': 'bar'} + assert response.status_code == 200 + + +def test_additional_claim_not_in_refresh_token(app): + jwt = get_jwt_manager(app) + + @jwt.additional_claims_loader + def add_claims(identity): + return {'foo': 'bar'} + + with app.test_request_context(): + refresh_token = create_refresh_token('username') + + test_client = app.test_client() + response = test_client.get('/protected2', headers=make_headers(refresh_token)) + assert response.get_json() == {} + assert response.status_code == 200 + + +def test_additional_claim_in_refresh_token(app): + app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True + jwt = get_jwt_manager(app) + + @jwt.additional_claims_loader + def add_claims(identity): + return {'foo': 'bar'} + + with app.test_request_context(): + refresh_token = create_refresh_token('username') + + test_client = app.test_client() + response = test_client.get('/protected2', headers=make_headers(refresh_token)) + assert response.get_json() == {'foo': 'bar'} + assert response.status_code == 200 diff --git a/tests/test_additional_claims_verificaion.py b/tests/test_additional_claims_verificaion.py new file mode 100644 index 00000000..7014a34e --- /dev/null +++ b/tests/test_additional_claims_verificaion.py @@ -0,0 +1,114 @@ +import pytest +from flask import Flask, jsonify + +from flask_jwt_extended import ( + JWTManager, jwt_required, create_access_token, get_jwt_identity, + fresh_jwt_required, jwt_optional +) +from tests.utils import get_jwt_manager, make_headers + + +@pytest.fixture(scope='function') +def app(): + app = Flask(__name__) + app.config['JWT_SECRET_KEY'] = 'foobarbaz' + app.config['JWT_ADDITIONAL_CLAIMS'] = ['foo'] + jwt = JWTManager(app) + + @jwt.additional_claims_loader + def add_user_claims(identity): + return {'foo': 'bar'} + + @app.route('/protected1', methods=['GET']) + @jwt_required + def protected1(): + return jsonify(foo='bar') + + @app.route('/protected2', methods=['GET']) + @fresh_jwt_required + def protected2(): + return jsonify(foo='bar') + + @app.route('/protected3', methods=['GET']) + @jwt_optional + def protected3(): + return jsonify(foo='bar') + + return app + + +@pytest.mark.parametrize("url", ['/protected1', '/protected2', '/protected3']) +def test_successful_claims_validation(app, url): + jwt = get_jwt_manager(app) + + @jwt.additonal_claims_verification_loader + def user_load_callback(user_claims): + return user_claims == {'foo': 'bar'} + + test_client = app.test_client() + with app.test_request_context(): + access_token = create_access_token('username', fresh=True) + + response = test_client.get(url, headers=make_headers(access_token)) + assert response.get_json() == {'foo': 'bar'} + assert response.status_code == 200 + + +@pytest.mark.parametrize("url", ['/protected1', '/protected2', '/protected3']) +def test_unsuccessful_claims_validation(app, url): + jwt = get_jwt_manager(app) + + @jwt.additonal_claims_verification_loader + def user_load_callback(user_claims): + return False + + test_client = app.test_client() + with app.test_request_context(): + access_token = create_access_token('username', fresh=True) + + response = test_client.get(url, headers=make_headers(access_token)) + assert response.get_json() == {'msg': 'Additional claims verification failed'} + assert response.status_code == 400 + + +@pytest.mark.parametrize("url", ['/protected1', '/protected2', '/protected3']) +def test_claims_validation_custom_error(app, url): + jwt = get_jwt_manager(app) + + @jwt.additonal_claims_verification_loader + def user_load_callback(user_claims): + return False + + @jwt.additonal_claims_verification_failed_loader + def custom_error(): + # Make sure that we can get the jwt identity in here if we need it. + user = get_jwt_identity() + return jsonify(msg='claims failed for {}'.format(user)), 404 + + test_client = app.test_client() + with app.test_request_context(): + access_token = create_access_token('username', fresh=True) + + response = test_client.get(url, headers=make_headers(access_token)) + assert response.get_json() == {'msg': 'claims failed for username'} + assert response.status_code == 404 + + +@pytest.mark.parametrize("url", ['/protected1', '/protected2', '/protected3']) +def test_get_jwt_identity_in_verification_method(app, url): + jwt = get_jwt_manager(app) + + + @jwt.additonal_claims_verification_loader + def user_load_callback(user_claims): + # Make sure that we can get the jwt identity in here if we need it. + user = get_jwt_identity() + return user == 'username' + + test_client = app.test_client() + with app.test_request_context(): + access_token = create_access_token('username', fresh=True) + + response = test_client.get(url, headers=make_headers(access_token)) + assert response.get_json() == {'foo': 'bar'} + assert response.status_code == 200 diff --git a/tests/test_config.py b/tests/test_config.py index e1ef5ccf..043d725b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -101,6 +101,7 @@ def test_override_configs(app): app.config['JWT_IDENTITY_CLAIM'] = 'foo' app.config['JWT_USER_CLAIMS'] = 'bar' + app.config['JWT_ADDITIONAL_CLAIMS'] = [] app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True From 0aaefefdf692e69302e4d48a730b8465b4062a8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Magnus=20Hartvig=20Gr=C3=B8nbech?= Date: Tue, 3 Jul 2018 11:43:03 +0200 Subject: [PATCH 2/3] review of additional claims feature --- flask_jwt_extended/jwt_manager.py | 7 ++----- flask_jwt_extended/tokens.py | 2 +- tests/test_additional_claims_verificaion.py | 12 ++++++++++++ 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index eed1a11c..d8f38b82 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -423,7 +423,7 @@ def additonal_claims_verification_failed_loader(self, callback): callback returns False, indicating that the user claims are not valid. The default implementation will return a 400 status code with the JSON: - {"msg": "User claims verification failed"} + {"msg": "Additional claims verification failed"} This callback must be a function that takes no arguments, and returns a Flask response. @@ -438,12 +438,9 @@ def _create_refresh_token(self, identity, expires_delta=None): if config.user_claims_in_refresh_token: user_claims = self._user_claims_callback(identity) - else: - user_claims = None - - if config.user_claims_in_refresh_token: additional_claims = self._additonal_claims_callback(identity) else: + user_claims = None additional_claims = None refresh_token = encode_refresh_token( diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 4c12e481..686cbdbd 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -162,7 +162,7 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key, if data['type'] == 'access': if 'fresh' not in data: raise JWTDecodeError("Missing claim: fresh") - + if user_claims_key not in data: data[user_claims_key] = {} diff --git a/tests/test_additional_claims_verificaion.py b/tests/test_additional_claims_verificaion.py index 7014a34e..bd16a856 100644 --- a/tests/test_additional_claims_verificaion.py +++ b/tests/test_additional_claims_verificaion.py @@ -35,6 +35,18 @@ def protected3(): return jsonify(foo='bar') return app +@pytest.mark.parametrize("url", ['/protected1', '/protected2', '/protected3']) +def test_successful_no_claims(app, url): + jwt = get_jwt_manager(app) + + test_client = app.test_client() + with app.test_request_context(): + access_token = create_access_token('username', fresh=True) + + response = test_client.get(url, headers=make_headers(access_token)) + assert response.get_json() == {'foo': 'bar'} + assert response.status_code == 200 + @pytest.mark.parametrize("url", ['/protected1', '/protected2', '/protected3']) From b708a17e0f3ee7f9c339450a3d744faebab1a69d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Magnus=20Hartvig=20Gr=C3=B8nbech?= Date: Tue, 3 Jul 2018 11:45:26 +0200 Subject: [PATCH 3/3] review v2 of additional claims feature --- flask_jwt_extended/tokens.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 686cbdbd..6b311fb3 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -18,7 +18,7 @@ def _check_claims(default_claims, additional_claims): raise JWTEncodeError("Claim %s in conflict with default claims" % str(claim)) def _encode_jwt(additional_token_data, expires_delta, secret, algorithm, - claim_key, json_encoder=None): + json_encoder=None): uid = str(uuid.uuid4()) now = datetime.datetime.utcnow() token_data = { @@ -91,7 +91,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, if csrf: token_data['csrf'] = _create_csrf_token() return _encode_jwt(token_data, expires_delta, secret, algorithm, - claim_key=identity_claim_key, json_encoder=json_encoder) + json_encoder=json_encoder) def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims, @@ -133,7 +133,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims if csrf: token_data['csrf'] = _create_csrf_token() return _encode_jwt(token_data, expires_delta, secret, algorithm, - claim_key=identity_claim_key, json_encoder=json_encoder) + json_encoder=json_encoder) def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,