Skip to content

Commit

Permalink
added support to pass additional headers in JWT encoding the token (#276
Browse files Browse the repository at this point in the history
)
  • Loading branch information
iamajay authored and vimalloc committed Oct 4, 2019
1 parent 7f39a44 commit e182953
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 34 deletions.
13 changes: 7 additions & 6 deletions flask_jwt_extended/__init__.py
@@ -1,14 +1,15 @@
from .jwt_manager import JWTManager
from .view_decorators import (
fresh_jwt_required, jwt_optional, jwt_refresh_token_required, jwt_required,
verify_fresh_jwt_in_request, verify_jwt_in_request,
verify_jwt_in_request_optional, verify_jwt_refresh_token_in_request
)
from .utils import (
create_access_token, create_refresh_token, current_user, decode_token,
get_csrf_token, get_current_user, get_jti, get_jwt_claims, get_jwt_identity,
get_raw_jwt, set_access_cookies, set_refresh_cookies, unset_access_cookies,
unset_jwt_cookies, unset_refresh_cookies
unset_jwt_cookies, unset_refresh_cookies, get_unverified_jwt_headers,
get_raw_jwt_header
)
from .view_decorators import (
fresh_jwt_required, jwt_optional, jwt_refresh_token_required, jwt_required,
verify_fresh_jwt_in_request, verify_jwt_in_request,
verify_jwt_in_request_optional, verify_jwt_refresh_token_in_request
)

__version__ = '3.23.0'
11 changes: 11 additions & 0 deletions flask_jwt_extended/default_callbacks.py
Expand Up @@ -22,6 +22,17 @@ def default_user_claims_callback(userdata):
return {}


def default_jwt_headers_callback(default_headers):
"""
By default header typically consists of two parts: the type of the token,
which is JWT, and the signing algorithm being used, such as HMAC SHA256
or RSA. But we don't set the default header here we set it as empty which
further by default set while encoding the token
:return: default we set None here
"""
return None


def default_user_identity_callback(userdata):
"""
By default, we use the passed in object directly as the jwt identity.
Expand Down
40 changes: 34 additions & 6 deletions flask_jwt_extended/jwt_manager.py
Expand Up @@ -5,6 +5,7 @@
ExpiredSignatureError, InvalidTokenError, InvalidAudienceError,
InvalidIssuerError, DecodeError
)

try:
from flask import _app_ctx_stack as ctx_stack
except ImportError: # pragma: no cover
Expand All @@ -22,8 +23,8 @@
default_unauthorized_callback, default_needs_fresh_token_callback,
default_revoked_token_callback, default_user_loader_error_callback,
default_claims_verification_callback, default_verify_claims_failed_callback,
default_decode_key_callback, default_encode_key_callback
)
default_decode_key_callback, default_encode_key_callback,
default_jwt_headers_callback)
from flask_jwt_extended.tokens import (
encode_refresh_token, encode_access_token
)
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(self, app=None):
self._verify_claims_failed_callback = default_verify_claims_failed_callback
self._decode_key_callback = default_decode_key_callback
self._encode_key_callback = default_encode_key_callback
self._jwt_additional_header_callback = default_jwt_headers_callback

# Register this extension with the flask app now (if it is provided)
if app is not None:
Expand Down Expand Up @@ -454,13 +456,33 @@ def encode_key_loader(self, callback):
self._encode_key_callback = callback
return callback

def _create_refresh_token(self, identity, expires_delta=None, user_claims=None):
def additional_headers_loader(self, callback):
"""
This decorator sets the callback function for adding custom headers to an
access token when :func:`~flask_jwt_extended.create_access_token` is
called. By default, two headers will be added the type of the token, which is JWT,
and the signing algorithm being used, such as HMAC SHA256 or RSA.
*HINT*: The callback function must be a function that takes **no** argument,
which is the object passed into
:func:`~flask_jwt_extended.create_access_token`, and returns the custom
claims you want included in the access tokens. This returned claims
must be *JSON serializable*.
"""
self._jwt_additional_header_callback = callback
return callback

def _create_refresh_token(self, identity, expires_delta=None, user_claims=None,
headers=None):
if expires_delta is None:
expires_delta = config.refresh_expires

if user_claims is None and config.user_claims_in_refresh_token:
user_claims = self._user_claims_callback(identity)

if headers is None:
headers = self._jwt_additional_header_callback(identity)

refresh_token = encode_refresh_token(
identity=self._user_identity_callback(identity),
secret=self._encode_key_callback(identity),
Expand All @@ -470,17 +492,22 @@ def _create_refresh_token(self, identity, expires_delta=None, user_claims=None):
csrf=config.csrf_protect,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
json_encoder=config.json_encoder
json_encoder=config.json_encoder,
headers=headers
)
return refresh_token

def _create_access_token(self, identity, fresh=False, expires_delta=None, user_claims=None):
def _create_access_token(self, identity, fresh=False, expires_delta=None,
user_claims=None, headers=None):
if expires_delta is None:
expires_delta = config.access_expires

if user_claims is None:
user_claims = self._user_claims_callback(identity)

if headers is None:
headers = self._jwt_additional_header_callback(identity)

access_token = encode_access_token(
identity=self._user_identity_callback(identity),
secret=self._encode_key_callback(identity),
Expand All @@ -491,6 +518,7 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None, user_c
csrf=config.csrf_protect,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
json_encoder=config.json_encoder
json_encoder=config.json_encoder,
headers=headers
)
return access_token
15 changes: 8 additions & 7 deletions flask_jwt_extended/tokens.py
@@ -1,6 +1,5 @@
import datetime
import uuid

from calendar import timegm

import jwt
Expand All @@ -14,7 +13,7 @@ def _create_csrf_token():


def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
json_encoder=None):
json_encoder=None, headers=None):
uid = _create_csrf_token()
now = datetime.datetime.utcnow()
token_data = {
Expand All @@ -28,13 +27,13 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm,
token_data['exp'] = now + expires_delta
token_data.update(additional_token_data)
encoded_token = jwt.encode(token_data, secret, algorithm,
json_encoder=json_encoder).decode('utf-8')
json_encoder=json_encoder, headers=headers).decode('utf-8')
return encoded_token


def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
user_claims, csrf, identity_claim_key, user_claims_key,
json_encoder=None):
json_encoder=None, headers=None):
"""
Creates a new encoded (utf-8) access token.
Expand All @@ -54,6 +53,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
(boolean)
:param identity_claim_key: Which key should be used to store the identity
:param user_claims_key: Which key should be used to store the user claims
:param headers: valid dict for specifying additional headers in JWT header section
:return: Encoded access token
"""

Expand All @@ -74,12 +74,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
if csrf:
token_data['csrf'] = _create_csrf_token()
return _encode_jwt(token_data, expires_delta, secret, algorithm,
json_encoder=json_encoder)
json_encoder=json_encoder, headers=headers)


def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims,
csrf, identity_claim_key, user_claims_key,
json_encoder=None):
json_encoder=None, headers=None):
"""
Creates a new encoded (utf-8) refresh token.
Expand All @@ -95,6 +95,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
(boolean)
:param identity_claim_key: Which key should be used to store the identity
:param user_claims_key: Which key should be used to store the user claims
:param headers: valid dict for specifying additional headers in JWT header section
:return: Encoded refresh token
"""
token_data = {
Expand All @@ -109,7 +110,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
if csrf:
token_data['csrf'] = _create_csrf_token()
return _encode_jwt(token_data, expires_delta, secret, algorithm,
json_encoder=json_encoder)
json_encoder=json_encoder, headers=headers)


def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,
Expand Down
42 changes: 36 additions & 6 deletions flask_jwt_extended/utils.py
@@ -1,7 +1,8 @@
from warnings import warn

from flask import current_app
from werkzeug.local import LocalProxy
from jwt import ExpiredSignatureError
from warnings import warn
from werkzeug.local import LocalProxy

try:
from flask import _app_ctx_stack as ctx_stack
Expand Down Expand Up @@ -29,6 +30,15 @@ def get_raw_jwt():
return getattr(ctx_stack.top, 'jwt', {})


def get_raw_jwt_header():
"""
In a protected endpoint, this will return the python dictionary which has
the JWT headers values. If no
JWT is currently present, an empty dict is returned instead.
"""
return getattr(ctx_stack.top, 'jwt_header', {})


def get_jwt_identity():
"""
In a protected endpoint, this will return the identity of the JWT that is
Expand Down Expand Up @@ -132,7 +142,8 @@ def _get_jwt_manager():
"application before using this method")


def create_access_token(identity, fresh=False, expires_delta=None, user_claims=None):
def create_access_token(identity, fresh=False, expires_delta=None, user_claims=None,
headers=None):
"""
Create a new access token.
Expand All @@ -153,13 +164,17 @@ def create_access_token(identity, fresh=False, expires_delta=None, user_claims=N
'JWT_ACCESS_TOKEN_EXPIRES` config value
(see :ref:`Configuration Options`)
:param user_claims: Optional JSON serializable to override user claims.
:param headers: Optional, valid dict for specifying additional headers in JWT
header section
:return: An encoded access token
"""
jwt_manager = _get_jwt_manager()
return jwt_manager._create_access_token(identity, fresh, expires_delta, user_claims)
return jwt_manager._create_access_token(identity, fresh, expires_delta, user_claims,
headers=headers)


def create_refresh_token(identity, expires_delta=None, user_claims=None):
def create_refresh_token(identity, expires_delta=None, user_claims=None,
headers=None):
"""
Creates a new refresh token.
Expand All @@ -175,10 +190,13 @@ def create_refresh_token(identity, expires_delta=None, user_claims=None):
'JWT_REFRESH_TOKEN_EXPIRES` config value
(see :ref:`Configuration Options`)
:param user_claims: Optional JSON serializable to override user claims.
:param headers: Optional, valid dict for specifying additional headers in JWT
header section
:return: An encoded refresh token
"""
jwt_manager = _get_jwt_manager()
return jwt_manager._create_refresh_token(identity, expires_delta, user_claims)
return jwt_manager._create_refresh_token(identity, expires_delta, user_claims,
headers=headers)


def has_user_loader():
Expand Down Expand Up @@ -396,3 +414,15 @@ def unset_refresh_cookies(response):
domain=config.cookie_domain,
path=config.refresh_csrf_cookie_path,
samesite=config.cookie_samesite)


def get_unverified_jwt_headers(encoded_token):
"""
Returns the Headers of an encoded JWT without verifying the actual signature of JWT.
Note: The signature is not verified so the header parameters
should not be fully trusted until signature verification is complete
:param encoded_token: The encoded JWT to get the Header from.
:return: JWT header parameters as python dict()
"""
return jwt.get_unverified_header(encoded_token)
18 changes: 12 additions & 6 deletions flask_jwt_extended/view_decorators.py
Expand Up @@ -18,7 +18,7 @@
)
from flask_jwt_extended.utils import (
decode_token, has_user_loader, user_loader, verify_token_claims,
verify_token_not_blacklisted, verify_token_type
verify_token_not_blacklisted, verify_token_type, get_unverified_jwt_headers
)


Expand All @@ -29,8 +29,9 @@ def verify_jwt_in_request():
no token or if the token is invalid.
"""
if request.method not in config.exempt_methods:
jwt_data = _decode_jwt_from_request(request_type='access')
jwt_data, jwt_header = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
ctx_stack.top.jwt_header = jwt_header
verify_token_claims(jwt_data)
_load_user(jwt_data[config.identity_claim_key])

Expand All @@ -48,8 +49,9 @@ def verify_jwt_in_request_optional():
"""
try:
if request.method not in config.exempt_methods:
jwt_data = _decode_jwt_from_request(request_type='access')
jwt_data, jwt_header = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
ctx_stack.top.jwt_header = jwt_header
verify_token_claims(jwt_data)
_load_user(jwt_data[config.identity_claim_key])
except (NoAuthorizationError, InvalidHeaderError):
Expand All @@ -63,8 +65,9 @@ def verify_fresh_jwt_in_request():
token is not marked as fresh.
"""
if request.method not in config.exempt_methods:
jwt_data = _decode_jwt_from_request(request_type='access')
jwt_data, jwt_header = _decode_jwt_from_request(request_type='access')
ctx_stack.top.jwt = jwt_data
ctx_stack.top.jwt_header = jwt_header
fresh = jwt_data['fresh']
if isinstance(fresh, bool):
if not fresh:
Expand All @@ -83,8 +86,9 @@ def verify_jwt_refresh_token_in_request():
exception if there is no token or the token is invalid.
"""
if request.method not in config.exempt_methods:
jwt_data = _decode_jwt_from_request(request_type='refresh')
jwt_data, jwt_header = _decode_jwt_from_request(request_type='refresh')
ctx_stack.top.jwt = jwt_data
ctx_stack.top.jwt_header = jwt_header
_load_user(jwt_data[config.identity_claim_key])


Expand Down Expand Up @@ -283,10 +287,12 @@ def _decode_jwt_from_request(request_type):
# in one place to be valid (not every location).
errors = []
decoded_token = None
jwt_header = None
for get_encoded_token_function in get_encoded_token_functions:
try:
encoded_token, csrf_token = get_encoded_token_function()
decoded_token = decode_token(encoded_token, csrf_token)
jwt_header = get_unverified_jwt_headers(encoded_token)
break
except NoAuthorizationError as e:
errors.append(str(e))
Expand All @@ -309,4 +315,4 @@ def _decode_jwt_from_request(request_type):

verify_token_type(decoded_token, expected_type=request_type)
verify_token_not_blacklisted(decoded_token, request_type)
return decoded_token
return decoded_token, jwt_header
11 changes: 10 additions & 1 deletion tests/test_decode_tokens.py
Expand Up @@ -13,7 +13,7 @@

from flask_jwt_extended import (
JWTManager, create_access_token, decode_token, create_refresh_token,
get_jti
get_jti, get_unverified_jwt_headers
)
from flask_jwt_extended.config import config
from flask_jwt_extended.exceptions import JWTDecodeError
Expand Down Expand Up @@ -286,3 +286,12 @@ def test_malformed_token(app):
with pytest.raises(DecodeError):
with app.test_request_context():
decode_token(invalid_token)


def test_jwt_headers(app):
jwt_header = {"foo": "bar"}
with app.test_request_context():
access_token = create_access_token('username', headers=jwt_header)
refresh_token = create_refresh_token('username', headers=jwt_header)
assert get_unverified_jwt_headers(access_token)["foo"] == "bar"
assert get_unverified_jwt_headers(refresh_token)["foo"] == "bar"

0 comments on commit e182953

Please sign in to comment.