Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flask_jwt_extended/default_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from flask_jwt_extended.config import config


def default_user_claims_callback(userdata):
def default_user_claims_callback(userdata, extra_info_for_claims={}):
"""
By default, we add no additional claims to the access tokens.

Expand Down
8 changes: 4 additions & 4 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,12 @@ def encode_key_loader(self, callback):
self._encode_key_callback = callback
return callback

def _create_refresh_token(self, identity, expires_delta=None):
def _create_refresh_token(self, identity, extra_info_for_claims, expires_delta=None):
if expires_delta is None:
expires_delta = config.refresh_expires

if config.user_claims_in_refresh_token:
user_claims = self._user_claims_callback(identity)
user_claims = self._user_claims_callback(identity, extra_info_for_claims)
else:
user_claims = None

Expand All @@ -443,7 +443,7 @@ def _create_refresh_token(self, identity, expires_delta=None):
)
return refresh_token

def _create_access_token(self, identity, fresh=False, expires_delta=None):
def _create_access_token(self, identity, extra_info_for_claims={}, fresh=False, expires_delta=None):
if expires_delta is None:
expires_delta = config.access_expires

Expand All @@ -453,7 +453,7 @@ def _create_access_token(self, identity, fresh=False, expires_delta=None):
algorithm=config.algorithm,
expires_delta=expires_delta,
fresh=fresh,
user_claims=self._user_claims_callback(identity),
user_claims=self._user_claims_callback(identity, extra_info_for_claims),
csrf=config.csrf_protect,
identity_claim_key=config.identity_claim_key,
user_claims_key=config.user_claims_key,
Expand Down
8 changes: 4 additions & 4 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _get_jwt_manager():
"application before using this method")


def create_access_token(identity, fresh=False, expires_delta=None):
def create_access_token(identity, extra_info_for_claims={}, fresh=False, expires_delta=None):
"""
Create a new access token.

Expand All @@ -133,10 +133,10 @@ def create_access_token(identity, fresh=False, expires_delta=None):
:return: An encoded access token
"""
jwt_manager = _get_jwt_manager()
return jwt_manager._create_access_token(identity, fresh, expires_delta)
return jwt_manager._create_access_token(identity, extra_info_for_claims, fresh, expires_delta)


def create_refresh_token(identity, expires_delta=None):
def create_refresh_token(identity, extra_info_for_claims={}, expires_delta=None):
"""
Creates a new refresh token.

Expand All @@ -154,7 +154,7 @@ def create_refresh_token(identity, expires_delta=None):
:return: An encoded refresh token
"""
jwt_manager = _get_jwt_manager()
return jwt_manager._create_refresh_token(identity, expires_delta)
return jwt_manager._create_refresh_token(identity, extra_info_for_claims, expires_delta)


def has_user_loader():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_claims_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def app():
jwt = JWTManager(app)

@jwt.user_claims_loader
def add_user_claims(identity):
def add_user_claims(identity, extra_info_for_claims={}):
return {'foo': 'bar'}

@app.route('/protected1', methods=['GET'])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_decode_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_no_user_claims(app, user_loader_return):
jwtM = get_jwt_manager(app)

@jwtM.user_claims_loader
def empty_user_loader_return(identity):
def empty_user_loader_return(identity, extra_info_for_claims={}):
return user_loader_return

# Identity should not be in the actual token, but should be in the data
Expand Down
14 changes: 7 additions & 7 deletions tests/test_user_claims_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_user_claim_in_access_token(app):
jwt = get_jwt_manager(app)

@jwt.user_claims_loader
def add_claims(identity):
def add_claims(identity, extra_info_for_claims={}):
return {'foo': 'bar'}

with app.test_request_context():
Expand All @@ -47,7 +47,7 @@ def test_non_serializable_user_claims(app):
jwt = get_jwt_manager(app)

@jwt.user_claims_loader
def add_claims(identity):
def add_claims(identity, extra_info_for_claims={}):
return app

with pytest.raises(TypeError):
Expand All @@ -63,11 +63,11 @@ def __init__(self, username):
jwt = get_jwt_manager(app)

@jwt.user_claims_loader
def add_claims(test_obj):
def add_claims(test_obj, extra_info_for_claims={}):
return {'username': test_obj.username}

@jwt.user_identity_loader
def add_claims(test_obj):
def add_claims(test_obj, extra_info_for_claims={}):
return test_obj.username

with app.test_request_context():
Expand All @@ -89,7 +89,7 @@ def test_user_claims_with_different_name(app):
app.config['JWT_USER_CLAIMS'] = 'banana'

@jwt.user_claims_loader
def add_claims(identity):
def add_claims(identity, extra_info_for_claims={}):
return {'foo': 'bar'}

with app.test_request_context():
Expand All @@ -110,7 +110,7 @@ def test_user_claim_not_in_refresh_token(app):
jwt = get_jwt_manager(app)

@jwt.user_claims_loader
def add_claims(identity):
def add_claims(identity, extra_info_for_claims={}):
return {'foo': 'bar'}

with app.test_request_context():
Expand All @@ -127,7 +127,7 @@ def test_user_claim_in_refresh_token(app):
jwt = get_jwt_manager(app)

@jwt.user_claims_loader
def add_claims(identity):
def add_claims(identity, extra_info_for_claims={}):
return {'foo': 'bar'}

with app.test_request_context():
Expand Down