Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ General Options:
such as ``RS*`` or ``ES*``. PEM format expected.
``JWT_PRIVATE_KEY`` The private key needed for asymmetric based signing algorithms,
such as ``RS*`` or ``ES*``. PEM format expected.
``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used as source of identity.
For interoperativity, the JWT RFC recommends using ``'sub'``.
Defaults to ``'identity'``.
================================= =========================================


Expand Down
4 changes: 4 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ def cookie_max_age(self):
# seconds a long ways in the future
return None if self.session_cookie else 2147483647 # 2^31

@property
def identity_claim(self):
return current_app.config['JWT_IDENTITY_CLAIM']

config = _Config()


8 changes: 6 additions & 2 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def _set_default_configuration_options(app):
app.config.setdefault('JWT_BLACKLIST_ENABLED', False)
app.config.setdefault('JWT_BLACKLIST_TOKEN_CHECKS', ['access', 'refresh'])

app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')

def user_claims_loader(self, callback):
"""
This sets the callback method for adding custom user claims to a JWT.
Expand Down Expand Up @@ -319,7 +321,8 @@ def create_refresh_token(self, identity, expires_delta=None):
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=expires_delta,
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)
return refresh_token

Expand Down Expand Up @@ -352,7 +355,8 @@ def create_access_token(self, identity, fresh=False, expires_delta=None):
expires_delta=expires_delta,
fresh=fresh,
user_claims=self._user_claims_callback(identity),
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)
return access_token

17 changes: 10 additions & 7 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm):


def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
user_claims, csrf):
user_claims, csrf, identity_claim):
"""
Creates a new encoded (utf-8) access token.

Expand All @@ -40,11 +40,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
be json serializable
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim: Which claim should be used to store the identity in
:return: Encoded access token
"""
# Create the jwt
token_data = {
'identity': identity,
identity_claim: identity,
'fresh': fresh,
'type': 'access',
'user_claims': user_claims,
Expand All @@ -54,7 +55,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
return _encode_jwt(token_data, expires_delta, secret, algorithm)


def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, identity_claim):
"""
Creates a new encoded (utf-8) refresh token.

Expand All @@ -65,18 +66,19 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
(datetime.timedelta)
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim: Which claim should be used to store the identity in
:return: Encoded refresh token
"""
token_data = {
'identity': identity,
identity_claim: identity,
'type': 'refresh',
}
if csrf:
token_data['csrf'] = _create_csrf_token()
return _encode_jwt(token_data, expires_delta, secret, algorithm)


def decode_jwt(encoded_token, secret, algorithm, csrf):
def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim):
"""
Decodes an encoded JWT

Expand All @@ -85,6 +87,7 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
:param algorithm: Algorithm used to encode the JWT
:param csrf: If this token is expected to have a CSRF double submit
value present (boolean)
:param identity_claim: expected claim that is used to identify the subject
:return: Dictionary containing contents of the JWT
"""
# This call verifies the ext, iat, and nbf claims
Expand All @@ -93,8 +96,8 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
# Make sure that any custom claims we expect in the token are present
if 'jti' not in data:
raise JWTDecodeError("Missing claim: jti")
if 'identity' not in data:
raise JWTDecodeError("Missing claim: identity")
if identity_claim not in data:
raise JWTDecodeError("Missing claim: {}".format(identity_claim))
if 'type' not in data or data['type'] not in ('refresh', 'access'):
raise JWTDecodeError("Missing or invalid claim: type")
if data['type'] == 'access':
Expand Down
13 changes: 10 additions & 3 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_jwt_identity():
Returns the identity of the JWT in this context. If no JWT is present,
None is returned.
"""
return get_raw_jwt().get('identity', None)
return get_raw_jwt().get(config.identity_claim, None)


def get_jwt_claims():
Expand Down Expand Up @@ -63,7 +63,8 @@ def decode_token(encoded_token):
encoded_token=encoded_token,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)


Expand Down Expand Up @@ -106,7 +107,13 @@ def token_in_blacklist(*args, **kwargs):


def get_csrf_token(encoded_token):
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
token = decode_jwt(
encoded_token,
config.decode_key,
config.algorithm,
csrf=True,
identity_claim=config.identity_claim
)
return token['csrf']


Expand Down
11 changes: 9 additions & 2 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,13 @@ def _decode_jwt_from_headers():
raise InvalidHeaderError(msg)
token = parts[1]

return decode_jwt(token, config.decode_key, config.algorithm, csrf=False)
return decode_jwt(
encoded_token=token,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=False,
identity_claim=config.identity_claim
)


def _decode_jwt_from_cookies(request_type):
Expand All @@ -163,7 +169,8 @@ def _decode_jwt_from_cookies(request_type):
encoded_token=encoded_token,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)

# Verify csrf double submit tokens match if required
Expand Down
6 changes: 6 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def test_default_configs(self):
self.assertEqual(config.decode_key, self.app.secret_key)
self.assertEqual(config.cookie_max_age, None)

self.assertEqual(config.identity_claim, 'identity')

def test_override_configs(self):
self.app.config['JWT_TOKEN_LOCATION'] = ['cookies']
self.app.config['JWT_HEADER_NAME'] = 'TestHeader'
Expand Down Expand Up @@ -86,6 +88,8 @@ def test_override_configs(self):

self.app.secret_key = 'banana'

self.app.config['JWT_IDENTITY_CLAIM'] = 'foo'

with self.app.test_request_context():
self.assertEqual(config.token_location, ['cookies'])
self.assertEqual(config.jwt_in_cookies, True)
Expand Down Expand Up @@ -122,6 +126,8 @@ def test_override_configs(self):

self.assertEqual(config.cookie_max_age, 2147483647)

self.assertEqual(config.identity_claim, 'foo')

def test_invalid_config_options(self):
with self.app.test_request_context():
self.app.config['JWT_TOKEN_LOCATION'] = 'banana'
Expand Down
Loading