Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
identity_claim: identity,
'fresh': fresh,
'type': 'access',
'user_claims': user_claims,
}

# Add `user_claims` only is not empty or None.
if user_claims:
token_data['user_claims'] = user_claims

if csrf:
token_data['csrf'] = _create_csrf_token()
return _encode_jwt(token_data, expires_delta, secret, algorithm)
Expand Down Expand Up @@ -104,7 +108,7 @@ def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim):
if 'fresh' not in data:
raise JWTDecodeError("Missing claim: fresh")
if 'user_claims' not in data:
raise JWTDecodeError("Missing claim: user_claims")
data['user_claims'] = {}
if csrf:
if 'csrf' not in data:
raise JWTDecodeError("Missing claim: csrf")
Expand Down
65 changes: 52 additions & 13 deletions tests/test_jwt_encode_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,35 @@ def test_encode_access_token(self):
self.assertLessEqual(exp_seconds, 60 * 5)
self.assertGreater(exp_seconds, 60 * 4)

def test_encode_access_token__no_user_claims(self):
'''
To make JWT shorter, do not add `user_claims` if empty.
'''
secret = 'super-totally-secret-key'
algorithm = 'HS256'
token_expire_delta = timedelta(minutes=5)
identity_claim = 'sub'

# `user_claims` is empty dict
with self.app.test_request_context():
identity = 'user1'
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
fresh=False, user_claims={}, csrf=False,
identity_claim=identity_claim)

data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertNotIn('user_claims', data)

# `user_claims` is None
with self.app.test_request_context():
identity = 'user1'
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
fresh=False, user_claims=None, csrf=False,
identity_claim=identity_claim)

data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertNotIn('user_claims', data)

def test_encode_invalid_access_token(self):
# Check with non-serializable json
with self.app.test_request_context():
Expand Down Expand Up @@ -212,6 +241,29 @@ def test_decode_jwt(self):
self.assertEqual(data[identity_claim], 'banana')
self.assertEqual(data['type'], 'refresh')

def test_decode_access_token__no_user_claims(self):
'''
Test decoding a valid access token without `user_claims`.
'''
identity_claim = 'sub'
with self.app.test_request_context():
now = datetime.utcnow()
token_data = {
'exp': now + timedelta(minutes=5),
'iat': now,
'nbf': now,
'jti': 'banana',
identity_claim: 'banana',
'fresh': True,
'type': 'access',
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
data = decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim=identity_claim)

self.assertIn('user_claims', data)
self.assertEqual(data['user_claims'], {})

def test_decode_invalid_jwt(self):
with self.app.test_request_context():
identity_claim = 'identity'
Expand Down Expand Up @@ -284,19 +336,6 @@ def test_decode_invalid_jwt(self):
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim=identity_claim)

# Missing user claims in access token
with self.assertRaises(JWTDecodeError):
token_data = {
'jti': 'banana',
identity_claim: 'banana',
'exp': datetime.utcnow() + timedelta(minutes=5),
'type': 'access',
'fresh': True
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim=identity_claim)

# Bad token type
with self.assertRaises(JWTDecodeError):
token_data = {
Expand Down