diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 26990bcc..031fce19 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -48,8 +48,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh, identity_claim: identity, 'fresh': fresh, 'type': 'access', - 'user_claims': user_claims, } + + # Add `user_claims` only is not empty or None. + if user_claims: + token_data['user_claims'] = user_claims + if csrf: token_data['csrf'] = _create_csrf_token() return _encode_jwt(token_data, expires_delta, secret, algorithm) @@ -104,7 +108,7 @@ def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim): if 'fresh' not in data: raise JWTDecodeError("Missing claim: fresh") if 'user_claims' not in data: - raise JWTDecodeError("Missing claim: user_claims") + data['user_claims'] = {} if csrf: if 'csrf' not in data: raise JWTDecodeError("Missing claim: csrf") diff --git a/tests/test_jwt_encode_decode.py b/tests/test_jwt_encode_decode.py index 2d31884e..73d166c3 100644 --- a/tests/test_jwt_encode_decode.py +++ b/tests/test_jwt_encode_decode.py @@ -83,6 +83,35 @@ def test_encode_access_token(self): self.assertLessEqual(exp_seconds, 60 * 5) self.assertGreater(exp_seconds, 60 * 4) + def test_encode_access_token__no_user_claims(self): + ''' + To make JWT shorter, do not add `user_claims` if empty. + ''' + secret = 'super-totally-secret-key' + algorithm = 'HS256' + token_expire_delta = timedelta(minutes=5) + identity_claim = 'sub' + + # `user_claims` is empty dict + with self.app.test_request_context(): + identity = 'user1' + token = encode_access_token(identity, secret, algorithm, token_expire_delta, + fresh=False, user_claims={}, csrf=False, + identity_claim=identity_claim) + + data = jwt.decode(token, secret, algorithms=[algorithm]) + self.assertNotIn('user_claims', data) + + # `user_claims` is None + with self.app.test_request_context(): + identity = 'user1' + token = encode_access_token(identity, secret, algorithm, token_expire_delta, + fresh=False, user_claims=None, csrf=False, + identity_claim=identity_claim) + + data = jwt.decode(token, secret, algorithms=[algorithm]) + self.assertNotIn('user_claims', data) + def test_encode_invalid_access_token(self): # Check with non-serializable json with self.app.test_request_context(): @@ -212,6 +241,29 @@ def test_decode_jwt(self): self.assertEqual(data[identity_claim], 'banana') self.assertEqual(data['type'], 'refresh') + def test_decode_access_token__no_user_claims(self): + ''' + Test decoding a valid access token without `user_claims`. + ''' + identity_claim = 'sub' + with self.app.test_request_context(): + now = datetime.utcnow() + token_data = { + 'exp': now + timedelta(minutes=5), + 'iat': now, + 'nbf': now, + 'jti': 'banana', + identity_claim: 'banana', + 'fresh': True, + 'type': 'access', + } + encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') + data = decode_jwt(encoded_token, 'secret', 'HS256', + csrf=False, identity_claim=identity_claim) + + self.assertIn('user_claims', data) + self.assertEqual(data['user_claims'], {}) + def test_decode_invalid_jwt(self): with self.app.test_request_context(): identity_claim = 'identity' @@ -284,19 +336,6 @@ def test_decode_invalid_jwt(self): decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim=identity_claim) - # Missing user claims in access token - with self.assertRaises(JWTDecodeError): - token_data = { - 'jti': 'banana', - identity_claim: 'banana', - 'exp': datetime.utcnow() + timedelta(minutes=5), - 'type': 'access', - 'fresh': True - } - encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8') - decode_jwt(encoded_token, 'secret', 'HS256', - csrf=False, identity_claim=identity_claim) - # Bad token type with self.assertRaises(JWTDecodeError): token_data = {