From 1144f3387af330e7ecba278a5340e16d7bd3a7dd Mon Sep 17 00:00:00 2001 From: waled meselhy Date: Sat, 19 Jan 2019 16:49:18 +0200 Subject: [PATCH] add extra_info_for_claims argument to the function that add claims --- flask_jwt_extended/default_callbacks.py | 2 +- flask_jwt_extended/jwt_manager.py | 8 ++++---- flask_jwt_extended/utils.py | 8 ++++---- tests/test_claims_verification.py | 2 +- tests/test_decode_tokens.py | 2 +- tests/test_user_claims_loader.py | 14 +++++++------- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/flask_jwt_extended/default_callbacks.py b/flask_jwt_extended/default_callbacks.py index 7c85e8cc..4ed87c9a 100644 --- a/flask_jwt_extended/default_callbacks.py +++ b/flask_jwt_extended/default_callbacks.py @@ -11,7 +11,7 @@ from flask_jwt_extended.config import config -def default_user_claims_callback(userdata): +def default_user_claims_callback(userdata, extra_info_for_claims={}): """ By default, we add no additional claims to the access tokens. diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index 8cd743cd..c47fb913 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -421,12 +421,12 @@ def encode_key_loader(self, callback): self._encode_key_callback = callback return callback - def _create_refresh_token(self, identity, expires_delta=None): + def _create_refresh_token(self, identity, extra_info_for_claims, expires_delta=None): if expires_delta is None: expires_delta = config.refresh_expires if config.user_claims_in_refresh_token: - user_claims = self._user_claims_callback(identity) + user_claims = self._user_claims_callback(identity, extra_info_for_claims) else: user_claims = None @@ -443,7 +443,7 @@ def _create_refresh_token(self, identity, expires_delta=None): ) return refresh_token - def _create_access_token(self, identity, fresh=False, expires_delta=None): + def _create_access_token(self, identity, extra_info_for_claims={}, fresh=False, expires_delta=None): if expires_delta is None: expires_delta = config.access_expires @@ -453,7 +453,7 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None): algorithm=config.algorithm, expires_delta=expires_delta, fresh=fresh, - user_claims=self._user_claims_callback(identity), + user_claims=self._user_claims_callback(identity, extra_info_for_claims), csrf=config.csrf_protect, identity_claim_key=config.identity_claim_key, user_claims_key=config.user_claims_key, diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 8e27c348..0d49f557 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -110,7 +110,7 @@ def _get_jwt_manager(): "application before using this method") -def create_access_token(identity, fresh=False, expires_delta=None): +def create_access_token(identity, extra_info_for_claims={}, fresh=False, expires_delta=None): """ Create a new access token. @@ -133,10 +133,10 @@ def create_access_token(identity, fresh=False, expires_delta=None): :return: An encoded access token """ jwt_manager = _get_jwt_manager() - return jwt_manager._create_access_token(identity, fresh, expires_delta) + return jwt_manager._create_access_token(identity, extra_info_for_claims, fresh, expires_delta) -def create_refresh_token(identity, expires_delta=None): +def create_refresh_token(identity, extra_info_for_claims={}, expires_delta=None): """ Creates a new refresh token. @@ -154,7 +154,7 @@ def create_refresh_token(identity, expires_delta=None): :return: An encoded refresh token """ jwt_manager = _get_jwt_manager() - return jwt_manager._create_refresh_token(identity, expires_delta) + return jwt_manager._create_refresh_token(identity, extra_info_for_claims, expires_delta) def has_user_loader(): diff --git a/tests/test_claims_verification.py b/tests/test_claims_verification.py index 1893a7a8..d43105d8 100644 --- a/tests/test_claims_verification.py +++ b/tests/test_claims_verification.py @@ -15,7 +15,7 @@ def app(): jwt = JWTManager(app) @jwt.user_claims_loader - def add_user_claims(identity): + def add_user_claims(identity, extra_info_for_claims={}): return {'foo': 'bar'} @app.route('/protected1', methods=['GET']) diff --git a/tests/test_decode_tokens.py b/tests/test_decode_tokens.py index 2ca15dcb..dfba79ca 100644 --- a/tests/test_decode_tokens.py +++ b/tests/test_decode_tokens.py @@ -60,7 +60,7 @@ def test_no_user_claims(app, user_loader_return): jwtM = get_jwt_manager(app) @jwtM.user_claims_loader - def empty_user_loader_return(identity): + def empty_user_loader_return(identity, extra_info_for_claims={}): return user_loader_return # Identity should not be in the actual token, but should be in the data diff --git a/tests/test_user_claims_loader.py b/tests/test_user_claims_loader.py index 61990fef..1bbf7fad 100644 --- a/tests/test_user_claims_loader.py +++ b/tests/test_user_claims_loader.py @@ -31,7 +31,7 @@ def test_user_claim_in_access_token(app): jwt = get_jwt_manager(app) @jwt.user_claims_loader - def add_claims(identity): + def add_claims(identity, extra_info_for_claims={}): return {'foo': 'bar'} with app.test_request_context(): @@ -47,7 +47,7 @@ def test_non_serializable_user_claims(app): jwt = get_jwt_manager(app) @jwt.user_claims_loader - def add_claims(identity): + def add_claims(identity, extra_info_for_claims={}): return app with pytest.raises(TypeError): @@ -63,11 +63,11 @@ def __init__(self, username): jwt = get_jwt_manager(app) @jwt.user_claims_loader - def add_claims(test_obj): + def add_claims(test_obj, extra_info_for_claims={}): return {'username': test_obj.username} @jwt.user_identity_loader - def add_claims(test_obj): + def add_claims(test_obj, extra_info_for_claims={}): return test_obj.username with app.test_request_context(): @@ -89,7 +89,7 @@ def test_user_claims_with_different_name(app): app.config['JWT_USER_CLAIMS'] = 'banana' @jwt.user_claims_loader - def add_claims(identity): + def add_claims(identity, extra_info_for_claims={}): return {'foo': 'bar'} with app.test_request_context(): @@ -110,7 +110,7 @@ def test_user_claim_not_in_refresh_token(app): jwt = get_jwt_manager(app) @jwt.user_claims_loader - def add_claims(identity): + def add_claims(identity, extra_info_for_claims={}): return {'foo': 'bar'} with app.test_request_context(): @@ -127,7 +127,7 @@ def test_user_claim_in_refresh_token(app): jwt = get_jwt_manager(app) @jwt.user_claims_loader - def add_claims(identity): + def add_claims(identity, extra_info_for_claims={}): return {'foo': 'bar'} with app.test_request_context():