Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def identity_claim_key(self):
def user_claims_key(self):
return current_app.config['JWT_USER_CLAIMS']

@property
def additional_claim_keys(self):
return current_app.config['JWT_ADDITIONAL_CLAIMS']

@property
def user_claims_in_refresh_token(self):
return current_app.config['JWT_CLAIMS_IN_REFRESH_TOKEN']
Expand Down
22 changes: 22 additions & 0 deletions flask_jwt_extended/default_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ def default_user_claims_callback(userdata):
return {}


def default_additional_claims_callback(userdata):
"""
By default, we add no additional claims to the access token

:param userdata: data passed in as the ```identity``` argument to the
```create_access_token``` and ```create_refresh_token```
functions
"""
return {}

def default_user_identity_callback(userdata):
"""
By default, we use the passed in object directly as the jwt identity.
Expand Down Expand Up @@ -91,10 +101,22 @@ def default_claims_verification_callback(user_claims):
"""
return True

def default_additional_claims_verification_callback(additional_claims):
"""
By default, we do not do any verification of the additional claims.
"""
return True

def default_claims_verification_failed_callback():
"""
By default, if the user claims verification failed, we return a generic
error message with a 400 status code
"""
return jsonify({'msg': 'User claims verification failed'}), 400

def default_additional_claims_verification_failed_callback():
"""
By default, if the additional claims verification failed, we return a generic
error message with a 400 status code
"""
return jsonify({'msg': 'Additional claims verification failed'}), 400
12 changes: 12 additions & 0 deletions flask_jwt_extended/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ class JWTDecodeError(JWTExtendedException):
"""
pass

class JWTEncodeError(JWTExtendedException):
"""
An error encoding a JWT
"""
pass

class InvalidHeaderError(JWTExtendedException):
"""
Expand Down Expand Up @@ -70,3 +75,10 @@ class UserClaimsVerificationError(JWTExtendedException):
indicating that the expected user claims are invalid
"""
pass

class AdditionalClaimsVerificationError(JWTExtendedException):
"""
Error raised when the additional_claims_verification_callback function returns False,
indicating that the expected user claims are invalid
"""
pass
72 changes: 67 additions & 5 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@

from flask_jwt_extended.config import config
from flask_jwt_extended.exceptions import (
JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError,
JWTDecodeError, JWTEncodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError,
RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError,
UserClaimsVerificationError
UserClaimsVerificationError, AdditionalClaimsVerificationError
)
from flask_jwt_extended.default_callbacks import (
default_expired_token_callback, default_user_claims_callback,
default_user_identity_callback, default_invalid_token_callback,
default_unauthorized_callback, default_needs_fresh_token_callback,
default_revoked_token_callback, default_user_loader_error_callback,
default_claims_verification_callback,
default_claims_verification_failed_callback
default_claims_verification_failed_callback,
default_additional_claims_callback,
default_additional_claims_verification_callback,
default_additional_claims_verification_failed_callback
)
from flask_jwt_extended.tokens import (
encode_refresh_token, encode_access_token
Expand Down Expand Up @@ -43,6 +46,7 @@ def __init__(self, app=None):
# Register the default error handler callback methods. These can be
# overridden with the appropriate loader decorators
self._user_claims_callback = default_user_claims_callback
self._additonal_claims_callback = default_additional_claims_callback
self._user_identity_callback = default_user_identity_callback
self._expired_token_callback = default_expired_token_callback
self._invalid_token_callback = default_invalid_token_callback
Expand All @@ -54,6 +58,8 @@ def __init__(self, app=None):
self._token_in_blacklist_callback = None
self._claims_verification_callback = default_claims_verification_callback
self._claims_verification_failed_callback = default_claims_verification_failed_callback
self._additonal_claims_verification_callback = default_additional_claims_verification_callback
self._additonal_claims_verification_failed_callback = default_additional_claims_verification_failed_callback

# Register this extension with the flask app now (if it is provided)
if app is not None:
Expand Down Expand Up @@ -101,7 +107,7 @@ def handle_invalid_token_error(e):
@app.errorhandler(JWTDecodeError)
def handle_jwt_decode_error(e):
return self._invalid_token_callback(str(e))

@app.errorhandler(WrongTokenError)
def handle_wrong_token_error(e):
return self._invalid_token_callback(str(e))
Expand All @@ -126,6 +132,10 @@ def handler_user_load_error(e):
def handle_failed_user_claims_verification(e):
return self._claims_verification_failed_callback()

@app.errorhandler(AdditionalClaimsVerificationError)
def handle_failed_additional_claims_verification(e):
return self._additonal_claims_verification_failed_callback()

@staticmethod
def _set_default_configuration_options(app):
"""
Expand Down Expand Up @@ -186,6 +196,7 @@ def _set_default_configuration_options(app):

app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')
app.config.setdefault('JWT_ADDITIONAL_CLAIMS', [])

app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)

Expand All @@ -204,6 +215,21 @@ def user_claims_loader(self, callback):
self._user_claims_callback = callback
return callback

def additional_claims_loader(self, callback):
"""
This decorator sets the callback function for adding additional claims
the access token when :func:`~flask_jwt_extended.create_access_token` is
called. By defailt, no additional claims will be added.

The callback function must be a function that takes only one argument,
which is the object passed into
:func:`~flask_jwt_extended.create_access_token`, and returns the custom
claims you want included in the access tokens. This returned claims
must be JSON serializable.
"""
self._additonal_claims_callback = callback
return callback

def user_identity_loader(self, callback):
"""
This decorator sets the callback function for getting the JSON
Expand Down Expand Up @@ -373,21 +399,57 @@ def claims_verification_failed_loader(self, callback):
self._claims_verification_failed_callback = callback
return callback

def additonal_claims_verification_loader(self, callback):
"""
This decorator sets the callback function that will be called when
a protected endpoint is accessed, and will check if the custom claims
in the JWT are valid. By default, this callback is not used. The
error returned if the claims are invalid can be controlled via the
:meth:`~flask_jwt_extended.JWTManager.additonal_claims_verification_loader`
decorator.

This callback must be a function that takes one argument, which is the
custom claims (python dict) present in the JWT, and returns `True` if the
claims are valid, or `False` otherwise.
"""

self._additonal_claims_verification_callback = callback
return callback

def additonal_claims_verification_failed_loader(self, callback):
"""
This decorator sets the callback function that will be called if
the :meth:`~flask_jwt_extended.JWTManager.additonal_claims_verification_loader`
callback returns False, indicating that the user claims are not valid.
The default implementation will return a 400 status code with the JSON:

{"msg": "Additional claims verification failed"}

This callback must be a function that takes no arguments, and returns
a Flask response.
"""
self._additonal_claims_verification_failed_callback = callback
return callback


def _create_refresh_token(self, identity, expires_delta=None):
if expires_delta is None:
expires_delta = config.refresh_expires

if config.user_claims_in_refresh_token:
user_claims = self._user_claims_callback(identity)
additional_claims = self._additonal_claims_callback(identity)
else:
user_claims = None
additional_claims = None

refresh_token = encode_refresh_token(
identity=self._user_identity_callback(identity),
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=expires_delta,
user_claims=user_claims,
additional_claims=additional_claims,
csrf=config.csrf_protect,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
Expand All @@ -406,10 +468,10 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None):
expires_delta=expires_delta,
fresh=fresh,
user_claims=self._user_claims_callback(identity),
additional_claims=self._additonal_claims_callback(identity),
csrf=config.csrf_protect,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
json_encoder=config.json_encoder
)
return access_token

43 changes: 37 additions & 6 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
import jwt
from werkzeug.security import safe_str_cmp

from flask_jwt_extended.exceptions import JWTDecodeError, CSRFError

from flask_jwt_extended.exceptions import JWTDecodeError, JWTEncodeError, CSRFError
from flask_jwt_extended.config import config

def _create_csrf_token():
return str(uuid.uuid4())

def _check_claims(default_claims, additional_claims):
for claim in default_claims:
if claim in additional_claims:
raise JWTEncodeError("Claim %s in conflict with default claims" % str(claim))

def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
json_encoder=None):
Expand All @@ -26,15 +30,19 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
# and the 'exp' claim is not set.
if expires_delta:
token_data['exp'] = now + expires_delta

# Make sure additional_token_data is in conflict with default claims
_check_claims(['iat', 'nbf', 'jti', 'exp'], additional_token_data)
token_data.update(additional_token_data)

encoded_token = jwt.encode(token_data, secret, algorithm,
json_encoder=json_encoder).decode('utf-8')
return encoded_token


def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
user_claims, csrf, identity_claim_key, user_claims_key,
json_encoder=None):
user_claims, additional_claims, csrf, identity_claim_key,
user_claims_key, json_encoder=None):
"""
Creates a new encoded (utf-8) access token.

Expand All @@ -50,6 +58,8 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
token will remain fresh.
:param user_claims: Custom claims to include in this token. This data must
be json serializable
:param additional_claims: Custom claims to include in this token. Object
must be json serializable
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim_key: Which key should be used to store the identity
Expand All @@ -67,18 +77,25 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
'type': 'access',
}

# Make sure additional_token_data is in conflict with default claims
_check_claims(['fresh', 'type', identity_claim_key], additional_claims)

# Don't add extra data to the token if user_claims is empty.
if user_claims:
token_data[user_claims_key] = user_claims

# Make sure additional claims is a dict before merge
if additional_claims and isinstance(additional_claims, dict):
token_data.update(additional_claims)

if csrf:
token_data['csrf'] = _create_csrf_token()
return _encode_jwt(token_data, expires_delta, secret, algorithm,
json_encoder=json_encoder)


def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims,
csrf, identity_claim_key, user_claims_key,
additional_claims, csrf, identity_claim_key, user_claims_key,
json_encoder=None):
"""
Creates a new encoded (utf-8) refresh token.
Expand All @@ -91,6 +108,8 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
:type expires_delta: datetime.timedelta or False
:param user_claims: Custom claims to include in this token. This data must
be json serializable
:param additional_claims: Custom claims to include in this token. Object
must be json serializable
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim_key: Which key should be used to store the identity
Expand All @@ -102,6 +121,11 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
'type': 'refresh',
}

# Make sure additional_token_data is in conflict with default claims
if additional_claims and isinstance(additional_claims, dict):
_check_claims(['type', identity_claim_key], additional_claims)
token_data.update(additional_claims)

# Don't add extra data to the token if user_claims is empty.
if user_claims:
token_data[user_claims_key] = user_claims
Expand All @@ -113,7 +137,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims


def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
user_claims_key, csrf_value=None):
user_claims_key, additional_claim_keys, csrf_value=None):
"""
Decodes an encoded JWT

Expand All @@ -138,8 +162,15 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
if data['type'] == 'access':
if 'fresh' not in data:
raise JWTDecodeError("Missing claim: fresh")

if user_claims_key not in data:
data[user_claims_key] = {}

if data['type'] != 'refresh' or config.user_claims_in_refresh_token:
for claim in additional_claim_keys:
if claim not in data:
raise JWTDecodeError("Missing claim %s" % str(claim))

if csrf_value:
if 'csrf' not in data:
raise JWTDecodeError("Missing claim: csrf")
Expand Down
Loading