From df68d65a2b28427130cfb9d4fabadd5b1b23cf9c Mon Sep 17 00:00:00 2001 From: Pau Ruiz i Safont Date: Wed, 12 Jul 2017 15:12:18 +0100 Subject: [PATCH 1/2] Allow changing subject claim when decoding Related to issue #65 --- docs/options.rst | 3 ++ flask_jwt_extended/config.py | 4 +++ flask_jwt_extended/jwt_manager.py | 2 ++ flask_jwt_extended/tokens.py | 7 +++-- flask_jwt_extended/utils.py | 13 +++++++-- flask_jwt_extended/view_decorators.py | 11 +++++-- tests/test_config.py | 6 ++++ tests/test_jwt_encode_decode.py | 41 ++++++++++++++++++--------- 8 files changed, 65 insertions(+), 22 deletions(-) diff --git a/docs/options.rst b/docs/options.rst index 3e23bbb0..98fb94ab 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -30,6 +30,9 @@ General Options: such as ``RS*`` or ``ES*``. PEM format expected. ``JWT_PRIVATE_KEY`` The private key needed for asymmetric based signing algorithms, such as ``RS*`` or ``ES*``. PEM format expected. +``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used on decoding as source of identity. + For interoperativity, the JWT RFC recommends using ``'sub'``. + Defaults to ``'identity'``. ================================= ========================================= diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index bf522d2e..990d24e7 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -223,6 +223,10 @@ def cookie_max_age(self): # seconds a long ways in the future return None if self.session_cookie else 2147483647 # 2^31 + @property + def identity_claim(self): + return current_app.config['JWT_IDENTITY_CLAIM'] + config = _Config() diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index a230c7d2..da72c171 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -164,6 +164,8 @@ def _set_default_configuration_options(app): app.config.setdefault('JWT_BLACKLIST_ENABLED', False) app.config.setdefault('JWT_BLACKLIST_TOKEN_CHECKS', ['access', 'refresh']) + app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity') + def user_claims_loader(self, callback): """ This sets the callback method for adding custom user claims to a JWT. diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index e6d6bf11..b7ee6885 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -76,7 +76,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf): return _encode_jwt(token_data, expires_delta, secret, algorithm) -def decode_jwt(encoded_token, secret, algorithm, csrf): +def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim): """ Decodes an encoded JWT @@ -85,6 +85,7 @@ def decode_jwt(encoded_token, secret, algorithm, csrf): :param algorithm: Algorithm used to encode the JWT :param csrf: If this token is expected to have a CSRF double submit value present (boolean) + :param identity_claim: expected claim that is used to identify the subject :return: Dictionary containing contents of the JWT """ # This call verifies the ext, iat, and nbf claims @@ -93,8 +94,8 @@ def decode_jwt(encoded_token, secret, algorithm, csrf): # Make sure that any custom claims we expect in the token are present if 'jti' not in data: raise JWTDecodeError("Missing claim: jti") - if 'identity' not in data: - raise JWTDecodeError("Missing claim: identity") + if identity_claim not in data: + raise JWTDecodeError("Missing claim: {}".format(identity_claim)) if 'type' not in data or data['type'] not in ('refresh', 'access'): raise JWTDecodeError("Missing or invalid claim: type") if data['type'] == 'access': diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 8f93753d..85f49f9c 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -27,7 +27,7 @@ def get_jwt_identity(): Returns the identity of the JWT in this context. If no JWT is present, None is returned. """ - return get_raw_jwt().get('identity', None) + return get_raw_jwt().get(config.identity_claim, None) def get_jwt_claims(): @@ -63,7 +63,8 @@ def decode_token(encoded_token): encoded_token=encoded_token, secret=config.decode_key, algorithm=config.algorithm, - csrf=config.csrf_protect + csrf=config.csrf_protect, + identity_claim=config.identity_claim ) @@ -106,7 +107,13 @@ def token_in_blacklist(*args, **kwargs): def get_csrf_token(encoded_token): - token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True) + token = decode_jwt( + encoded_token, + config.decode_key, + config.algorithm, + csrf=True, + identity_claim=config.identity_claim + ) return token['csrf'] diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 386e458a..65253598 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -144,7 +144,13 @@ def _decode_jwt_from_headers(): raise InvalidHeaderError(msg) token = parts[1] - return decode_jwt(token, config.decode_key, config.algorithm, csrf=False) + return decode_jwt( + encoded_token=token, + secret=config.decode_key, + algorithm=config.algorithm, + csrf=False, + identity_claim=config.identity_claim + ) def _decode_jwt_from_cookies(request_type): @@ -163,7 +169,8 @@ def _decode_jwt_from_cookies(request_type): encoded_token=encoded_token, secret=config.decode_key, algorithm=config.algorithm, - csrf=config.csrf_protect + csrf=config.csrf_protect, + identity_claim=config.identity_claim ) # Verify csrf double submit tokens match if required diff --git a/tests/test_config.py b/tests/test_config.py index fece2933..f67aad03 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -54,6 +54,8 @@ def test_default_configs(self): self.assertEqual(config.decode_key, self.app.secret_key) self.assertEqual(config.cookie_max_age, None) + self.assertEqual(config.identity_claim, 'identity') + def test_override_configs(self): self.app.config['JWT_TOKEN_LOCATION'] = ['cookies'] self.app.config['JWT_HEADER_NAME'] = 'TestHeader' @@ -86,6 +88,8 @@ def test_override_configs(self): self.app.secret_key = 'banana' + self.app.config['JWT_IDENTITY_CLAIM'] = 'foo' + with self.app.test_request_context(): self.assertEqual(config.token_location, ['cookies']) self.assertEqual(config.jwt_in_cookies, True) @@ -122,6 +126,8 @@ def test_override_configs(self): self.assertEqual(config.cookie_max_age, 2147483647) + self.assertEqual(config.identity_claim, 'foo') + def test_invalid_config_options(self): with self.app.test_request_context(): self.app.config['JWT_TOKEN_LOCATION'] = 'banana' diff --git a/tests/test_jwt_encode_decode.py b/tests/test_jwt_encode_decode.py index b65caad3..9d2ef651 100644 --- a/tests/test_jwt_encode_decode.py +++ b/tests/test_jwt_encode_decode.py @@ -157,7 +157,7 @@ def test_decode_jwt(self): 'user_claims': {'foo': 'bar'}, } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False) + data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) @@ -188,7 +188,7 @@ def test_decode_jwt(self): 'type': 'refresh', } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False) + data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) @@ -210,7 +210,7 @@ def test_decode_invalid_jwt(self): 'exp': datetime.utcnow() - timedelta(minutes=5), } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False) + decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') # Missing jti with self.assertRaises(JWTDecodeError): @@ -220,7 +220,7 @@ def test_decode_invalid_jwt(self): 'type': 'refresh' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False) + decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') # Missing identity with self.assertRaises(JWTDecodeError): @@ -230,7 +230,17 @@ def test_decode_invalid_jwt(self): 'type': 'refresh' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False) + decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + + # Non-matching identity claim + with self.assertRaises(JWTDecodeError): + token_data = { + 'exp': datetime.utcnow() + timedelta(minutes=5), + 'identity': 'banana', + 'type': 'refresh' + } + encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') + decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='sub') # Missing type with self.assertRaises(JWTDecodeError): @@ -240,7 +250,7 @@ def test_decode_invalid_jwt(self): 'exp': datetime.utcnow() + timedelta(minutes=5), } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False) + decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') # Missing fresh in access token with self.assertRaises(JWTDecodeError): @@ -252,7 +262,7 @@ def test_decode_invalid_jwt(self): 'user_claims': {} } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False) + decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') # Missing user claims in access token with self.assertRaises(JWTDecodeError): @@ -264,7 +274,7 @@ def test_decode_invalid_jwt(self): 'fresh': True } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False) + decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') # Bad token type with self.assertRaises(JWTDecodeError): @@ -277,7 +287,7 @@ def test_decode_invalid_jwt(self): 'user_claims': 'banana' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False) + decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') # Missing csrf in csrf enabled token with self.assertRaises(JWTDecodeError): @@ -290,7 +300,7 @@ def test_decode_invalid_jwt(self): 'user_claims': 'banana' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=True) + decode_jwt(encoded_token, 'secret', 'HS256', csrf=True, identity_claim='identity') def test_create_jwt_with_object(self): # Complex object to test building a JWT from. Normally if you are using @@ -322,12 +332,15 @@ def user_identity_lookup(user): user = TestUser(username='foo', roles=['bar', 'baz']) access_token = create_access_token(identity=user) refresh_token = create_refresh_token(identity=user) + identity = 'identity' # Decode the tokens and make sure the values are set properly access_token_data = decode_jwt(access_token, app.secret_key, - app.config['JWT_ALGORITHM'], csrf=False) + app.config['JWT_ALGORITHM'], csrf=False, + identity_claim=identity) refresh_token_data = decode_jwt(refresh_token, app.secret_key, - app.config['JWT_ALGORITHM'], csrf=False) - self.assertEqual(access_token_data['identity'], 'foo') + app.config['JWT_ALGORITHM'], csrf=False, + identity_claim=identity) + self.assertEqual(access_token_data[identity], 'foo') self.assertEqual(access_token_data['user_claims']['roles'], ['bar', 'baz']) - self.assertEqual(refresh_token_data['identity'], 'foo') + self.assertEqual(refresh_token_data[identity], 'foo') From f8d83f22fcddef1c4e5a5a79e9cbf3dbc001b80b Mon Sep 17 00:00:00 2001 From: Pau Ruiz i Safont Date: Wed, 12 Jul 2017 16:17:41 +0100 Subject: [PATCH 2/2] Use JWT_IDENTITY_CLAIM for encoding too --- docs/options.rst | 2 +- flask_jwt_extended/jwt_manager.py | 6 ++- flask_jwt_extended/tokens.py | 10 ++-- tests/test_jwt_encode_decode.py | 82 +++++++++++++++++++------------ tests/test_protected_endpoints.py | 11 +++-- 5 files changed, 68 insertions(+), 43 deletions(-) diff --git a/docs/options.rst b/docs/options.rst index 98fb94ab..6d6b4ed5 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -30,7 +30,7 @@ General Options: such as ``RS*`` or ``ES*``. PEM format expected. ``JWT_PRIVATE_KEY`` The private key needed for asymmetric based signing algorithms, such as ``RS*`` or ``ES*``. PEM format expected. -``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used on decoding as source of identity. +``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used as source of identity. For interoperativity, the JWT RFC recommends using ``'sub'``. Defaults to ``'identity'``. ================================= ========================================= diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index da72c171..fff481c3 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -321,7 +321,8 @@ def create_refresh_token(self, identity, expires_delta=None): secret=config.encode_key, algorithm=config.algorithm, expires_delta=expires_delta, - csrf=config.csrf_protect + csrf=config.csrf_protect, + identity_claim=config.identity_claim ) return refresh_token @@ -354,7 +355,8 @@ def create_access_token(self, identity, fresh=False, expires_delta=None): expires_delta=expires_delta, fresh=fresh, user_claims=self._user_claims_callback(identity), - csrf=config.csrf_protect + csrf=config.csrf_protect, + identity_claim=config.identity_claim ) return access_token diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index b7ee6885..26990bcc 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -25,7 +25,7 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm): def encode_access_token(identity, secret, algorithm, expires_delta, fresh, - user_claims, csrf): + user_claims, csrf, identity_claim): """ Creates a new encoded (utf-8) access token. @@ -40,11 +40,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, be json serializable :param csrf: Whether to include a csrf double submit claim in this token (boolean) + :param identity_claim: Which claim should be used to store the identity in :return: Encoded access token """ # Create the jwt token_data = { - 'identity': identity, + identity_claim: identity, 'fresh': fresh, 'type': 'access', 'user_claims': user_claims, @@ -54,7 +55,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, return _encode_jwt(token_data, expires_delta, secret, algorithm) -def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf): +def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, identity_claim): """ Creates a new encoded (utf-8) refresh token. @@ -65,10 +66,11 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf): (datetime.timedelta) :param csrf: Whether to include a csrf double submit claim in this token (boolean) + :param identity_claim: Which claim should be used to store the identity in :return: Encoded refresh token """ token_data = { - 'identity': identity, + identity_claim: identity, 'type': 'refresh', } if csrf: diff --git a/tests/test_jwt_encode_decode.py b/tests/test_jwt_encode_decode.py index 9d2ef651..d363eda5 100644 --- a/tests/test_jwt_encode_decode.py +++ b/tests/test_jwt_encode_decode.py @@ -35,7 +35,8 @@ def test_encode_access_token(self): with self.app.test_request_context(): identity = 'user1' token = encode_access_token(identity, secret, algorithm, token_expire_delta, - fresh=True, user_claims=user_claims, csrf=False) + fresh=True, user_claims=user_claims, csrf=False, + identity_claim='identity') data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) @@ -59,7 +60,8 @@ def test_encode_access_token(self): # Check with a non-fresh token identity = 12345 # identity can be anything json serializable token = encode_access_token(identity, secret, algorithm, token_expire_delta, - fresh=False, user_claims=user_claims, csrf=True) + fresh=False, user_claims=user_claims, csrf=True, + identity_claim='identity') data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) @@ -87,33 +89,35 @@ def test_encode_invalid_access_token(self): with self.assertRaises(Exception): encode_access_token('user1', 'secret', 'HS256', timedelta(hours=1), True, user_claims, - csrf=True) + csrf=True, identity_claim='identity') user_claims = {'foo': timedelta(hours=4)} with self.assertRaises(Exception): encode_access_token('user1', 'secret', 'HS256', timedelta(hours=1), True, user_claims, - csrf=True) + csrf=True, identity_claim='identity') def test_encode_refresh_token(self): secret = 'super-totally-secret-key' algorithm = 'HS256' token_expire_delta = timedelta(minutes=5) + identity_claim = 'sub' # Check with a fresh token with self.app.test_request_context(): identity = 'user1' token = encode_refresh_token(identity, secret, algorithm, - token_expire_delta, csrf=False) + token_expire_delta, csrf=False, + identity_claim=identity_claim) data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) self.assertIn('jti', data) self.assertIn('type', data) - self.assertIn('identity', data) + self.assertIn(identity_claim, data) self.assertNotIn('csrf', data) - self.assertEqual(data['identity'], identity) + self.assertEqual(data[identity_claim], identity) self.assertEqual(data['type'], 'refresh') self.assertEqual(data['iat'], data['nbf']) now_ts = calendar.timegm(datetime.utcnow().utctimetuple()) @@ -124,7 +128,8 @@ def test_encode_refresh_token(self): # Check with a csrf token identity = 12345 # identity can be anything json serializable token = encode_refresh_token(identity, secret, algorithm, - token_expire_delta, csrf=True) + token_expire_delta, csrf=True, + identity_claim=identity_claim) data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) @@ -132,8 +137,8 @@ def test_encode_refresh_token(self): self.assertIn('jti', data) self.assertIn('type', data) self.assertIn('csrf', data) - self.assertIn('identity', data) - self.assertEqual(data['identity'], identity) + self.assertIn(identity_claim, data) + self.assertEqual(data[identity_claim], identity) self.assertEqual(data['type'], 'refresh') self.assertEqual(data['iat'], data['nbf']) now_ts = calendar.timegm(datetime.utcnow().utctimetuple()) @@ -142,6 +147,7 @@ def test_encode_refresh_token(self): self.assertGreater(exp_seconds, 60 * 4) def test_decode_jwt(self): + identity_claim = 'sub' # Test decoding a valid access token with self.app.test_request_context(): now = datetime.utcnow() @@ -151,18 +157,19 @@ def test_decode_jwt(self): 'iat': now, 'nbf': now, 'jti': 'banana', - 'identity': 'banana', + identity_claim: 'banana', 'fresh': True, 'type': 'access', 'user_claims': {'foo': 'bar'}, } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + data = decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim=identity_claim) self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) self.assertIn('jti', data) - self.assertIn('identity', data) + self.assertIn(identity_claim, data) self.assertIn('fresh', data) self.assertIn('type', data) self.assertIn('user_claims', data) @@ -170,7 +177,7 @@ def test_decode_jwt(self): self.assertEqual(data['iat'], now_ts) self.assertEqual(data['nbf'], now_ts) self.assertEqual(data['jti'], 'banana') - self.assertEqual(data['identity'], 'banana') + self.assertEqual(data[identity_claim], 'banana') self.assertEqual(data['fresh'], True) self.assertEqual(data['type'], 'access') self.assertEqual(data['user_claims'], {'foo': 'bar'}) @@ -184,22 +191,23 @@ def test_decode_jwt(self): 'iat': now, 'nbf': now, 'jti': 'banana', - 'identity': 'banana', + identity_claim: 'banana', 'type': 'refresh', } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + data = decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim=identity_claim) self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) self.assertIn('jti', data) - self.assertIn('identity', data) + self.assertIn(identity_claim, data) self.assertIn('type', data) self.assertEqual(data['exp'], now_ts + (5 * 60)) self.assertEqual(data['iat'], now_ts) self.assertEqual(data['nbf'], now_ts) self.assertEqual(data['jti'], 'banana') - self.assertEqual(data['identity'], 'banana') + self.assertEqual(data[identity_claim], 'banana') self.assertEqual(data['type'], 'refresh') def test_decode_invalid_jwt(self): @@ -210,7 +218,8 @@ def test_decode_invalid_jwt(self): 'exp': datetime.utcnow() - timedelta(minutes=5), } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim='identity') # Missing jti with self.assertRaises(JWTDecodeError): @@ -220,7 +229,8 @@ def test_decode_invalid_jwt(self): 'type': 'refresh' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim='identity') # Missing identity with self.assertRaises(JWTDecodeError): @@ -230,7 +240,8 @@ def test_decode_invalid_jwt(self): 'type': 'refresh' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim='identity') # Non-matching identity claim with self.assertRaises(JWTDecodeError): @@ -240,7 +251,8 @@ def test_decode_invalid_jwt(self): 'type': 'refresh' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='sub') + decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim='sub') # Missing type with self.assertRaises(JWTDecodeError): @@ -250,7 +262,8 @@ def test_decode_invalid_jwt(self): 'exp': datetime.utcnow() + timedelta(minutes=5), } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim='identity') # Missing fresh in access token with self.assertRaises(JWTDecodeError): @@ -262,7 +275,8 @@ def test_decode_invalid_jwt(self): 'user_claims': {} } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim='identity') # Missing user claims in access token with self.assertRaises(JWTDecodeError): @@ -274,7 +288,8 @@ def test_decode_invalid_jwt(self): 'fresh': True } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim='identity') # Bad token type with self.assertRaises(JWTDecodeError): @@ -287,7 +302,8 @@ def test_decode_invalid_jwt(self): 'user_claims': 'banana' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity') + decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim='identity') # Missing csrf in csrf enabled token with self.assertRaises(JWTDecodeError): @@ -300,7 +316,8 @@ def test_decode_invalid_jwt(self): 'user_claims': 'banana' } encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', csrf=True, identity_claim='identity') + decode_jwt(encoded_token, 'secret', 'HS256', csrf=True, + identity_claim='identity') def test_create_jwt_with_object(self): # Complex object to test building a JWT from. Normally if you are using @@ -329,18 +346,19 @@ def user_identity_lookup(user): # Create the token using the complex object with app.test_request_context(): + identity_claim = 'sub' + app.config['JWT_IDENTITY_CLAIM'] = identity_claim user = TestUser(username='foo', roles=['bar', 'baz']) access_token = create_access_token(identity=user) refresh_token = create_refresh_token(identity=user) - identity = 'identity' # Decode the tokens and make sure the values are set properly access_token_data = decode_jwt(access_token, app.secret_key, app.config['JWT_ALGORITHM'], csrf=False, - identity_claim=identity) + identity_claim=identity_claim) refresh_token_data = decode_jwt(refresh_token, app.secret_key, app.config['JWT_ALGORITHM'], csrf=False, - identity_claim=identity) - self.assertEqual(access_token_data[identity], 'foo') + identity_claim=identity_claim) + self.assertEqual(access_token_data[identity_claim], 'foo') self.assertEqual(access_token_data['user_claims']['roles'], ['bar', 'baz']) - self.assertEqual(refresh_token_data[identity], 'foo') + self.assertEqual(refresh_token_data[identity_claim], 'foo') diff --git a/tests/test_protected_endpoints.py b/tests/test_protected_endpoints.py index 6cbb80ca..165ce505 100644 --- a/tests/test_protected_endpoints.py +++ b/tests/test_protected_endpoints.py @@ -331,7 +331,8 @@ def test_bad_tokens(self): # Test token that was signed with a different key with self.app.test_request_context(): token = encode_access_token('foo', 'newsecret', 'HS256', - timedelta(minutes=5), True, {}, csrf=False) + timedelta(minutes=5), True, {}, csrf=False, + identity_claim='identity') auth_header = "Bearer {}".format(token) response = self.client.get('/protected', headers={'Authorization': auth_header}) data = json.loads(response.get_data(as_text=True)) @@ -397,7 +398,7 @@ def test_optional_jwt_bad_tokens(self): with self.app.test_request_context(): token = encode_access_token('foo', 'newsecret', 'HS256', timedelta(minutes=5), True, {}, - csrf=False) + csrf=False, identity_claim='identity') auth_header = "Bearer {}".format(token) response = self.client.get('/partially-protected', headers={'Authorization': auth_header}) @@ -584,7 +585,8 @@ def test_jwt_with_different_algorithm(self): expires_delta=timedelta(minutes=5), fresh=True, user_claims={}, - csrf=False + csrf=False, + identity_claim='identity' ) status, data = self._jwt_get('/protected', access_token) self.assertEqual(status, 422) @@ -600,7 +602,8 @@ def test_optional_jwt_with_different_algorithm(self): expires_delta=timedelta(minutes=5), fresh=True, user_claims={}, - csrf=False + csrf=False, + identity_claim='identity' ) status, data = self._jwt_get('/partially-protected', access_token) self.assertEqual(status, 422)