Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def identity_claim_key(self):
def user_claims_key(self):
return current_app.config['JWT_USER_CLAIMS']

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this config option is necessary

@property
def additional_claim_keys(self):
return current_app.config['JWT_ADDITIONAL_CLAIMS']

@property
def user_claims_in_refresh_token(self):
return current_app.config['JWT_CLAIMS_IN_REFRESH_TOKEN']
Expand Down
22 changes: 22 additions & 0 deletions flask_jwt_extended/default_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ def default_user_claims_callback(userdata):
return {}


def default_additional_claims_callback(userdata):
"""
By default, we add no additional claims to the access token

:param userdata: data passed in as the ```identity``` argument to the
```create_access_token``` and ```create_refresh_token```
functions
"""
return {}

def default_user_identity_callback(userdata):
"""
By default, we use the passed in object directly as the jwt identity.
Expand Down Expand Up @@ -91,10 +101,22 @@ def default_claims_verification_callback(user_claims):
"""
return True

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would prefer to get rid of this and catch the underlying pyjwt errors

def default_additional_claims_verification_callback(additional_claims):
"""
By default, we do not do any verification of the additional claims.
"""
return True

def default_claims_verification_failed_callback():
"""
By default, if the user claims verification failed, we return a generic
error message with a 400 status code
"""
return jsonify({'msg': 'User claims verification failed'}), 400

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think we should get rid of this as well. Whatever callback handles the invalid token can handle if any additional claims do not validate as well.

def default_additional_claims_verification_failed_callback():
"""
By default, if the additional claims verification failed, we return a generic
error message with a 400 status code
"""
return jsonify({'msg': 'Additional claims verification failed'}), 400
12 changes: 12 additions & 0 deletions flask_jwt_extended/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ class JWTDecodeError(JWTExtendedException):
"""
pass

class JWTEncodeError(JWTExtendedException):
"""
An error encoding a JWT
"""
pass

class InvalidHeaderError(JWTExtendedException):
"""
Expand Down Expand Up @@ -70,3 +75,10 @@ class UserClaimsVerificationError(JWTExtendedException):
indicating that the expected user claims are invalid
"""
pass

class AdditionalClaimsVerificationError(JWTExtendedException):
"""
Error raised when the additional_claims_verification_callback function returns False,
indicating that the expected user claims are invalid
"""
pass
72 changes: 67 additions & 5 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@

from flask_jwt_extended.config import config
from flask_jwt_extended.exceptions import (
JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError,
JWTDecodeError, JWTEncodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError,
RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError,
UserClaimsVerificationError
UserClaimsVerificationError, AdditionalClaimsVerificationError
)
from flask_jwt_extended.default_callbacks import (
default_expired_token_callback, default_user_claims_callback,
default_user_identity_callback, default_invalid_token_callback,
default_unauthorized_callback, default_needs_fresh_token_callback,
default_revoked_token_callback, default_user_loader_error_callback,
default_claims_verification_callback,
default_claims_verification_failed_callback
default_claims_verification_failed_callback,
default_additional_claims_callback,
default_additional_claims_verification_callback,
default_additional_claims_verification_failed_callback
)
from flask_jwt_extended.tokens import (
encode_refresh_token, encode_access_token
Expand Down Expand Up @@ -43,6 +46,7 @@ def __init__(self, app=None):
# Register the default error handler callback methods. These can be
# overridden with the appropriate loader decorators
self._user_claims_callback = default_user_claims_callback
self._additonal_claims_callback = default_additional_claims_callback
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Misspelling on additional here

self._user_identity_callback = default_user_identity_callback
self._expired_token_callback = default_expired_token_callback
self._invalid_token_callback = default_invalid_token_callback
Expand All @@ -54,6 +58,8 @@ def __init__(self, app=None):
self._token_in_blacklist_callback = None
self._claims_verification_callback = default_claims_verification_callback
self._claims_verification_failed_callback = default_claims_verification_failed_callback
self._additonal_claims_verification_callback = default_additional_claims_verification_callback
self._additonal_claims_verification_failed_callback = default_additional_claims_verification_failed_callback

# Register this extension with the flask app now (if it is provided)
if app is not None:
Expand Down Expand Up @@ -101,7 +107,7 @@ def handle_invalid_token_error(e):
@app.errorhandler(JWTDecodeError)
def handle_jwt_decode_error(e):
return self._invalid_token_callback(str(e))

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whitespace

@app.errorhandler(WrongTokenError)
def handle_wrong_token_error(e):
return self._invalid_token_callback(str(e))
Expand All @@ -126,6 +132,10 @@ def handler_user_load_error(e):
def handle_failed_user_claims_verification(e):
return self._claims_verification_failed_callback()

@app.errorhandler(AdditionalClaimsVerificationError)
def handle_failed_additional_claims_verification(e):
return self._additonal_claims_verification_failed_callback()

@staticmethod
def _set_default_configuration_options(app):
"""
Expand Down Expand Up @@ -186,6 +196,7 @@ def _set_default_configuration_options(app):

app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')
app.config.setdefault('JWT_ADDITIONAL_CLAIMS', [])

app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)

Expand All @@ -204,6 +215,21 @@ def user_claims_loader(self, callback):
self._user_claims_callback = callback
return callback

def additional_claims_loader(self, callback):
"""
This decorator sets the callback function for adding additional claims
the access token when :func:`~flask_jwt_extended.create_access_token` is
called. By defailt, no additional claims will be added.

The callback function must be a function that takes only one argument,
which is the object passed into
:func:`~flask_jwt_extended.create_access_token`, and returns the custom
claims you want included in the access tokens. This returned claims
must be JSON serializable.
"""
self._additonal_claims_callback = callback
return callback

def user_identity_loader(self, callback):
"""
This decorator sets the callback function for getting the JSON
Expand Down Expand Up @@ -373,21 +399,57 @@ def claims_verification_failed_loader(self, callback):
self._claims_verification_failed_callback = callback
return callback

def additonal_claims_verification_loader(self, callback):
"""
This decorator sets the callback function that will be called when
a protected endpoint is accessed, and will check if the custom claims
in the JWT are valid. By default, this callback is not used. The
error returned if the claims are invalid can be controlled via the
:meth:`~flask_jwt_extended.JWTManager.additonal_claims_verification_loader`
decorator.

This callback must be a function that takes one argument, which is the
custom claims (python dict) present in the JWT, and returns `True` if the
claims are valid, or `False` otherwise.
"""

self._additonal_claims_verification_callback = callback
return callback

def additonal_claims_verification_failed_loader(self, callback):
"""
This decorator sets the callback function that will be called if
the :meth:`~flask_jwt_extended.JWTManager.additonal_claims_verification_loader`
callback returns False, indicating that the user claims are not valid.
The default implementation will return a 400 status code with the JSON:

{"msg": "Additional claims verification failed"}

This callback must be a function that takes no arguments, and returns
a Flask response.
"""
self._additonal_claims_verification_failed_callback = callback
return callback


def _create_refresh_token(self, identity, expires_delta=None):
if expires_delta is None:
expires_delta = config.refresh_expires

if config.user_claims_in_refresh_token:
user_claims = self._user_claims_callback(identity)
additional_claims = self._additonal_claims_callback(identity)
else:
user_claims = None
additional_claims = None
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The additional claims (aud for example) should probably be in the refresh token regardless of the user_claims_in_refresh_token option.


refresh_token = encode_refresh_token(
identity=self._user_identity_callback(identity),
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=expires_delta,
user_claims=user_claims,
additional_claims=additional_claims,
csrf=config.csrf_protect,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
Expand All @@ -406,10 +468,10 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None):
expires_delta=expires_delta,
fresh=fresh,
user_claims=self._user_claims_callback(identity),
additional_claims=self._additonal_claims_callback(identity),
csrf=config.csrf_protect,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
json_encoder=config.json_encoder
)
return access_token

50 changes: 43 additions & 7 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
import jwt
from werkzeug.security import safe_str_cmp

from flask_jwt_extended.exceptions import JWTDecodeError, CSRFError

from flask_jwt_extended.exceptions import JWTDecodeError, JWTEncodeError, CSRFError
from flask_jwt_extended.config import config

def _create_csrf_token():
return str(uuid.uuid4())

def _check_claims(default_claims, additional_claims):
for claim in default_claims:
if claim in additional_claims:
raise JWTEncodeError("Claim %s in conflict with default claims" % str(claim))

def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
json_encoder=None):
Expand All @@ -26,15 +30,19 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
# and the 'exp' claim is not set.
if expires_delta:
token_data['exp'] = now + expires_delta

# Make sure additional_token_data is in conflict with default claims
_check_claims(['iat', 'nbf', 'jti', 'exp'], additional_token_data)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be less prone to breaking with future updates if this actually compared the keys in the newly created token with the additional token data, so that these hard coded lists could go away.

token_data.update(additional_token_data)

encoded_token = jwt.encode(token_data, secret, algorithm,
json_encoder=json_encoder).decode('utf-8')
return encoded_token


def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
user_claims, csrf, identity_claim_key, user_claims_key,
json_encoder=None):
user_claims, additional_claims, csrf, identity_claim_key,
user_claims_key, json_encoder=None):
"""
Creates a new encoded (utf-8) access token.

Expand All @@ -50,6 +58,8 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
token will remain fresh.
:param user_claims: Custom claims to include in this token. This data must
be json serializable
:param additional_claims: Custom claims to include in this token. Object
must be json serializable
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim_key: Which key should be used to store the identity
Expand All @@ -67,18 +77,25 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
'type': 'access',
}

# Make sure additional_token_data is in conflict with default claims
_check_claims(['fresh', 'type', identity_claim_key], additional_claims)

# Don't add extra data to the token if user_claims is empty.
if user_claims:
token_data[user_claims_key] = user_claims

# Make sure additional claims is a dict before merge
if additional_claims and isinstance(additional_claims, dict):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think an error should be raised here instead of silently doing nothing if additional claims is not a dict.

token_data.update(additional_claims)

if csrf:
token_data['csrf'] = _create_csrf_token()
return _encode_jwt(token_data, expires_delta, secret, algorithm,
json_encoder=json_encoder)


def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims,
csrf, identity_claim_key, user_claims_key,
additional_claims, csrf, identity_claim_key, user_claims_key,
json_encoder=None):
"""
Creates a new encoded (utf-8) refresh token.
Expand All @@ -91,6 +108,8 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
:type expires_delta: datetime.timedelta or False
:param user_claims: Custom claims to include in this token. This data must
be json serializable
:param additional_claims: Custom claims to include in this token. Object
must be json serializable
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not only must it be json serializable, it must be a dictionary. Should update the wording here to match.

:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim_key: Which key should be used to store the identity
Expand All @@ -102,6 +121,11 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
'type': 'refresh',
}

# Make sure additional_token_data is in conflict with default claims
if additional_claims and isinstance(additional_claims, dict):
_check_claims(['type', identity_claim_key], additional_claims)
token_data.update(additional_claims)

# Don't add extra data to the token if user_claims is empty.
if user_claims:
token_data[user_claims_key] = user_claims
Expand All @@ -113,7 +137,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims


def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
user_claims_key, csrf_value=None):
user_claims_key, additional_claim_keys, csrf_value=None):
"""
Decodes an encoded JWT

Expand All @@ -125,8 +149,13 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
:param csrf_value: Expected double submit csrf value
:return: Dictionary containing contents of the JWT
"""
# The validation decorator for additional claims must evaluate these!
options = {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this, simply checking if these claims exist in the JWT is not enough to verify them. Instead I think we need a way to pass the expected verify options to the underlying pyjwt decode function.

'verify_aud': False,
'verify_iss': False,
}
# This call verifies the ext, iat, and nbf claims
data = jwt.decode(encoded_token, secret, algorithms=[algorithm])
data = jwt.decode(encoded_token, secret, algorithms=[algorithm], options=options)

# Make sure that any custom claims we expect in the token are present
if 'jti' not in data:
Expand All @@ -138,8 +167,15 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
if data['type'] == 'access':
if 'fresh' not in data:
raise JWTDecodeError("Missing claim: fresh")

if user_claims_key not in data:
data[user_claims_key] = {}

if data['type'] != 'refresh' or config.user_claims_in_refresh_token:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to above, I think we shouldn't have this code path, and instead rely on the existing code path for verifying custom data in user_claims, and pyjwt.decode errors for standard JWT claims.

for claim in additional_claim_keys:
if claim not in data:
raise JWTDecodeError("Missing claim %s" % str(claim))

if csrf_value:
if 'csrf' not in data:
raise JWTDecodeError("Missing claim: csrf")
Expand Down
Loading