diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index fd3f7c4a..00000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "docs/_themes"] - path = docs/_themes - url = https://github.com/pallets/flask-sphinx-themes.git diff --git a/README.md b/README.md index 867c666d..92938567 100644 --- a/README.md +++ b/README.md @@ -37,12 +37,7 @@ $ tox ``` ### Generating Documentation -You can generate a local copy of the documentation. First, make sure you have -the flask sphinx theme cloned -``` -$ git submodule update --init -``` - +You can generate a local copy of the documentation. In the `docs` directory, run: Then, in the docs directory, run ``` $ make clean && make html diff --git a/docs/_themes b/docs/_themes deleted file mode 160000 index d5b65706..00000000 --- a/docs/_themes +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d5b65706937214e98ce08c17c87439b3f8369c8c diff --git a/docs/blacklist_and_token_revoking.rst b/docs/blacklist_and_token_revoking.rst index bc352d4b..6bcd307a 100644 --- a/docs/blacklist_and_token_revoking.rst +++ b/docs/blacklist_and_token_revoking.rst @@ -3,35 +3,39 @@ Blacklist and Token Revoking This extension supports optional token revoking out of the box. This will allow you to revoke a specific token so that it can no longer access your endpoints. -In order to revoke a token, we need some storage where we can save a list of all -the tokens we have created, as well as if they have been revoked or not. In order -to make the underlying storage as agnostic as possible, we use `simplekv -`_ to provide assess to a variety of backends. - -In production, it is important to use a backend that can have some sort of -persistent storage, so we don't 'forget' that we revoked a token if the flask -process is restarted. We also need something that can be safely used by the -multiple thread and processes running your application. At present we believe -redis is a good fit for this. It has the added benefit of removing expired tokens -from the store automatically, so it wont blow up into something huge. - -We also have to choose what tokens we want to check against the blacklist. We could -check all tokens (refresh and access), or only the refresh tokens. There are pros -and cons to either way, namely extra overhead on jwt_required endpoints vs someone -being able to use an access token freely until it expires. In this example, we are -looking at all tokens: + +You will have to choose what tokens you want to check against the blacklist. In +most cases, you will probably want to check both refresh and access tokens, which +is the default behavior. However, if the extra overhead of checking tokens is a +concern you could instead only check the refresh tokens, and set the access +tokens to have a short expires time so any damage a compromised token could +cause is minimal. + +Blacklisting works by is providing a callback function to this extension, using the +**@jwt.token_in_blacklist_loader** decorator. This method will be called whenever the +specified tokens (``'access'`` and/or ``'refresh'``) are used to access a protected endpoint. +If the callback function says that the token is revoked, we will not allow the +call to continue, otherwise we will allow the call to access the endpoint as normal. + + +Here is a basic example of this in action. + .. literalinclude:: ../examples/blacklist.py -If you want better performance (ie, not having to check the blacklist store -with every request), you could check only the refresh tokens. This makes it -so any call to a jwt_required endpoint does not need to check the blacklist -store, but on the flip side would allow a compromised access token to be used -until it expired. If using the approach, you should set the access tokens to -have a very short lifetime to help combat this. - -It's worth noting that if your selected backend support the `time to live mixin -`_ (such as redis), -keys will be automatically deleted from the store at some point after they have -expired. This prevents your store from blowing up with old keys without you having -to do any work to prune it back down. +In production, you will likely want to use either a database or in memory store +(such as redis) to store your tokens. In memory stores are great if you are wanting +to revoke a token when the users logs out, as they are blazing fast. A downside +to using redis is that in the case of a power outage or other such event, it's +possible that you might 'forget' that some tokens have been revoked, depending +on if the redis data was synced to disk. + +In contrast to that, databases are great if the data persistance is of the highest +importance (for example, if you have very long lived tokens that other developers +use to access your api), or if you want to add some addition features like showing +users all of their active tokens, and letting them revoke and unrevoke those tokens. + +For more in depth examples of these, check out: + +- https://github.com/vimalloc/flask-jwt-extended/examples/redis_blacklist.py +- https://github.com/vimalloc/flask-jwt-extended/examples/database_blacklist diff --git a/docs/options.rst b/docs/options.rst index d56b7631..3e23bbb0 100644 --- a/docs/options.rst +++ b/docs/options.rst @@ -111,10 +111,8 @@ Blacklist Options: ================================= ========================================= ``JWT_BLACKLIST_ENABLED`` Enable/disable token blacklisting and revoking. Defaults to ``False`` -``JWT_BLACKLIST_STORE`` Where to save created and revoked tokens. `See here - `_ for options. - Only used if blacklisting is enabled. -``JWT_BLACKLIST_TOKEN_CHECKS`` What token types to check against the blacklist. Options are - ``'refresh'`` or ``'all'``. Defaults to ``'refresh'``. +``JWT_BLACKLIST_TOKEN_CHECKS`` What token types to check against the blacklist. The options are + ``'refresh'`` or ``'access'``. You can pass in a list to check + more then one type. Defaults to ``['access', 'refresh']``. Only used if blacklisting is enabled. ================================= ========================================= diff --git a/examples/blacklist.py b/examples/blacklist.py index aac7dcd5..bc016096 100644 --- a/examples/blacklist.py +++ b/examples/blacklist.py @@ -1,34 +1,48 @@ -import datetime - -import simplekv.memory from flask import Flask, request, jsonify -from flask_jwt_extended import JWTManager, jwt_required, \ - get_jwt_identity, revoke_token, unrevoke_token, \ - get_stored_tokens, get_all_stored_tokens, create_access_token, \ - create_refresh_token, jwt_refresh_token_required, \ - get_raw_jwt, get_stored_token +from flask_jwt_extended import ( + JWTManager, jwt_required, get_jwt_identity, + create_access_token, create_refresh_token, + jwt_refresh_token_required, get_raw_jwt +) # Setup flask app = Flask(__name__) -app.secret_key = 'super-secret' +app.secret_key = 'ChangeMe!' -# Enable and configure the JWT blacklist / token revoke. We are using -# an in memory store for this example. In production, you should -# use something persistent (such as redis, memcached, sqlalchemy). -# See here for options: http://pythonhosted.org/simplekv/ +# Enable blacklisting and specify what kind of tokens to check +# against the blacklist app.config['JWT_BLACKLIST_ENABLED'] = True -app.config['JWT_BLACKLIST_STORE'] = simplekv.memory.DictStore() - -# Check all tokens (access and refresh) to see if they have been revoked. -# You can alternately check only the refresh tokens here, by setting this -# to 'refresh' instead of 'all' -app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'all' -app.config['JWT_ACCESS_TOKEN_EXPIRES'] = datetime.timedelta(minutes=5) - +app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ['access', 'refresh'] jwt = JWTManager(app) +# A storage engine to save revoked tokens. In production if +# speed is the primary concern, redis is a good bet. If data +# persistence is more important for you, postgres is another +# great option. In this example, we will be using an in memory +# store, just to show you how this might work. For more +# complete examples, check out these: +# https://github.com/vimalloc/flask-jwt-extended/examples/redis_blacklist.py +# https://github.com/vimalloc/flask-jwt-extended/examples/database_blacklist +blacklist = set() + + +# For this example, we are just checking if the tokens jti +# (unique identifier) is in the blacklist set. This could +# be made more complex, for example storing all tokens +# into the blacklist with a revoked status when created, +# and returning the revoked status in this call. This +# would allow you to have a list of all created tokens, +# and to consider tokens that aren't in the blacklist +# (aka tokens you didn't create) as revoked. These are +# just two options, and this can be tailored to whatever +# your application needs. +@jwt.token_in_blacklist_loader +def check_if_token_in_blacklist(decrypted_token): + jti = decrypted_token['jti'] + return jti in blacklist + # Standard login endpoint @app.route('/login', methods=['POST']) @@ -45,7 +59,8 @@ def login(): return jsonify(ret), 200 -# Standard refresh endpoint +# Standard refresh endpoint. A blacklisted refresh token +# will not be able to access this endpoint @app.route('/refresh', methods=['POST']) @jwt_refresh_token_required def refresh(): @@ -56,87 +71,26 @@ def refresh(): return jsonify(ret), 200 -# Helper method to revoke the current token used to access -# a protected endpoint -def _revoke_current_token(): - current_token = get_raw_jwt() - jti = current_token['jti'] - revoke_token(jti) - - # Endpoint for revoking the current users access token -@app.route('/logout', methods=['POST']) +@app.route('/logout', methods=['DELETE']) @jwt_required def logout(): - try: - _revoke_current_token() - except KeyError: - return jsonify({ - 'msg': 'Access token not found in the blacklist store' - }), 500 + jti = get_raw_jwt()['jti'] + blacklist.add(jti) return jsonify({"msg": "Successfully logged out"}), 200 # Endpoint for revoking the current users refresh token -@app.route('/logout2', methods=['POST']) +@app.route('/logout2', methods=['DELETE']) @jwt_refresh_token_required def logout2(): - try: - _revoke_current_token() - except KeyError: - return jsonify({ - 'msg': 'Refresh token not found in the blacklist store' - }), 500 + jti = get_raw_jwt()['jti'] + blacklist.add(jti) return jsonify({"msg": "Successfully logged out"}), 200 -# Endpoint for listing tokens that have the same identity as you -# NOTE: This is currently very inefficient. -@app.route('/auth/tokens', methods=['GET']) -@jwt_required -def list_identity_tokens(): - username = get_jwt_identity() - return jsonify(get_stored_tokens(username)), 200 - - -# Endpoint for listing all tokens. In your app, you should either -# not expose this endpoint, or put some addition security on top -# of it so only trusted users (administrators, etc) can access it -@app.route('/auth/all-tokens') -def list_all_tokens(): - return jsonify(get_all_stored_tokens()), 200 - - -# Endpoint for allowing users to revoke their own tokens. -@app.route('/auth/tokens/revoke/', methods=['PUT']) -@jwt_required -def change_jwt_revoke_state(jti): - username = get_jwt_identity() - try: - token_data = get_stored_token(jti) - if token_data['token']['identity'] != username: - raise KeyError - revoke_token(jti) - return jsonify({"msg": "Token successfully revoked"}), 200 - except KeyError: - return jsonify({'msg': 'Token not found'}), 404 - - -# Endpoint for allowing users to un-revoke their own tokens. -@app.route('/auth/tokens/unrevoke/', methods=['PUT']) -@jwt_required -def change_jwt_unrevoke_state(jti): - username = get_jwt_identity() - try: - token_data = get_stored_token(jti) - if token_data['token']['identity'] != username: - raise KeyError - unrevoke_token(jti) - return jsonify({"msg": "Token successfully unrevoked"}), 200 - except KeyError: - return jsonify({'msg': 'Token not found'}), 404 - - +# This will now prevent users with blacklisted tokens from +# accessing this endpoint @app.route('/protected', methods=['GET']) @jwt_required def protected(): diff --git a/examples/database_blacklist/README.md b/examples/database_blacklist/README.md new file mode 100644 index 00000000..7e057dee --- /dev/null +++ b/examples/database_blacklist/README.md @@ -0,0 +1,23 @@ +# Blacklist with a database +Database are a common choice for storing blacklist tokens. It has many +benefits over an in memory store, like redis. The most obvious benefit of +using a database is data consistency. If you add something to the database, +you don't need to worry about it vanishing in an event like a power outage. +This is huge if you need to revoke long lived keys (for example, keys that +you give to another developer so they can access your API). Another advantage +of using a database is that you have easy access to all of the relational +data stored in there. You can easily and efficiently get a list of all tokens +that belong to a given user, and revoke or unrevoke those tokens with ease. +This is very handy if you want to provide a user with a way to see all the +active tokens they have with your service. + +Databases also have some cons compared to an in memory store, namely that +they are potentially slower, and they may grow huge over time and need to be +manually pruned back down. + +This project contains example code for you you might implement a blacklist +using a database, with some more complex features that might benefit your +application. For ease of use, we will use flask-sqlalchey with an in +memory data store, but in production I would highly recommend using postgres. +Please note that this code is only an example, and although I do my best to +insure its quality, it has not been thoroughly tested. diff --git a/examples/database_blacklist/__init__.py b/examples/database_blacklist/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/database_blacklist/app.py b/examples/database_blacklist/app.py new file mode 100644 index 00000000..f562a178 --- /dev/null +++ b/examples/database_blacklist/app.py @@ -0,0 +1,116 @@ +from flask import Flask, request, jsonify + +from extensions import jwt, db +from exceptions import TokenNotFound +from flask_jwt_extended import ( + jwt_refresh_token_required, get_jwt_identity, create_access_token, + create_refresh_token, jwt_required +) +from blacklist_helpers import ( + is_token_revoked, add_token_to_database, get_user_tokens, + revoke_token, unrevoke_token, + prune_database +) + + +# We will use an in memory sqlite database for this example. In production, +# I would recommend postgres. +def create_app(): + app = Flask(__name__) + + app.secret_key = 'ChangeMe!' + app.config['JWT_BLACKLIST_ENABLED'] = True + app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ['access', 'refresh'] + app.config['SQLALCHEMY_DATABASE_URI'] = "sqlite://" + app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False + + db.init_app(app) + jwt.init_app(app) + + # In a real application, these would likely be blueprints + register_endpoints(app) + + return app + + +def register_endpoints(app): + # Make sure the sqlalchemy database is created + @app.before_first_request + def setup_sqlalchemy(): + db.create_all() + + # Define our callback function to check if a token has been revoked or not + @jwt.token_in_blacklist_loader + def check_if_token_revoked(decoded_token): + return is_token_revoked(decoded_token) + + @app.route('/auth/login', methods=['POST']) + def login(): + username = request.json.get('username', None) + password = request.json.get('password', None) + if username != 'test' or password != 'test': + return jsonify({"msg": "Bad username or password"}), 401 + + # Create our JWTs + access_token = create_access_token(identity=username) + refresh_token = create_refresh_token(identity=username) + + # Store the tokens in our store with a status of not currently revoked. + add_token_to_database(access_token) + add_token_to_database(refresh_token) + + ret = { + 'access_token': access_token, + 'refresh_token': refresh_token + } + return jsonify(ret), 201 + + # A revoked refresh tokens will not be able to access this endpoint + @app.route('/auth/refresh', methods=['POST']) + @jwt_refresh_token_required + def refresh(): + # Do the same thing that we did in the login endpoint here + current_user = get_jwt_identity() + access_token = create_access_token(identity=current_user) + add_token_to_database(access_token) + return jsonify({'access_token': access_token}), 201 + + # Provide a way for a user to look at their tokens + @app.route('/auth/token', methods=['GET']) + @jwt_required + def get_tokens(): + user_identity = get_jwt_identity() + all_tokens = get_user_tokens(user_identity) + ret = [token.to_dict() for token in all_tokens] + return jsonify(ret), 200 + + # Provide a way for a user to revoke/unrevoke their tokens + @app.route('/auth/token/', methods=['PUT']) + @jwt_required + def modify_token(token_id): + # Get and verify the desired revoked status from the body + json_data = request.get_json(silent=True) + if not json_data: + return jsonify({"msg": "Missing 'revoke' in body"}), 400 + revoke = json_data.get('revoke', None) + if revoke is None: + return jsonify({"msg": "Missing 'revoke' in body"}), 400 + if not isinstance(revoke, bool): + return jsonify({"msg": "'revoke' must be a boolean"}), 400 + + # Revoke or unrevoke the token based on what was passed to this function + user_identity = get_jwt_identity() + try: + if revoke: + revoke_token(token_id, user_identity) + return jsonify({'msg': 'Token revoked'}), 200 + else: + unrevoke_token(token_id, user_identity) + return jsonify({'msg': 'Token unrevoked'}), 200 + except TokenNotFound: + return jsonify({'msg': 'The specified token was not found'}), 404 + + +if __name__ == '__main__': + app = create_app() + app.run(debug=True) diff --git a/examples/database_blacklist/blacklist_helpers.py b/examples/database_blacklist/blacklist_helpers.py new file mode 100644 index 00000000..90cb3e04 --- /dev/null +++ b/examples/database_blacklist/blacklist_helpers.py @@ -0,0 +1,102 @@ +from datetime import datetime + +from sqlalchemy.orm.exc import NoResultFound +from flask_jwt_extended import decode_token + +from exceptions import TokenNotFound +from database import TokenBlacklist +from extensions import db + + +def _epoch_utc_to_datetime(epoch_utc): + """ + Helper function for converting epoch timestamps (as stored in JWTs) into + python datetime objects (which are easier to use with sqlalchemy). + """ + return datetime.fromtimestamp(epoch_utc) + + +def add_token_to_database(encoded_token): + """ + Adds a new token to the database. It is not revoked when it is added. + """ + decoded_token = decode_token(encoded_token) + jti = decoded_token['jti'] + token_type = decoded_token['type'] + user_identity = decoded_token['identity'] + expires = _epoch_utc_to_datetime(decoded_token['exp']) + revoked = False + + db_token = TokenBlacklist( + jti=jti, + token_type=token_type, + user_identity=user_identity, + expires=expires, + revoked=revoked, + ) + db.session.add(db_token) + db.session.commit() + + +def is_token_revoked(decoded_token): + """ + Checks if the given token is revoked or not. Because we are adding all the + tokens that we create into this database, if the token is not present + in the database we are going to consider it revoked, as we don't know where + it was created. + """ + jti = decoded_token['jti'] + try: + token = TokenBlacklist.query.filter_by(jti=jti).one() + return token.revoked + except NoResultFound: + return True + + +def get_user_tokens(user_identity): + """ + Returns all of the tokens, revoked and unrevoked, that are stored for the + given user + """ + return TokenBlacklist.query.filter_by(user_identity=user_identity).all() + + +def revoke_token(token_id, user): + """ + Revokes the given token. Raises a TokenNotFound error if the token does + not exist in the database + """ + try: + token = TokenBlacklist.query.filter_by(id=token_id, user_identity=user).one() + token.revoked = True + db.session.commit() + except NoResultFound: + raise TokenNotFound("Could not find the token {}".format(token_id)) + + +def unrevoke_token(token_id, user): + """ + Unrevokes the given token. Raises a TokenNotFound error if the token does + not exist in the database + """ + try: + token = TokenBlacklist.query.filter_by(id=token_id, user_identity=user).one() + token.revoked = False + db.session.commit() + except NoResultFound: + raise TokenNotFound("Could not find the token {}".format(token_id)) + + +def prune_database(): + """ + Delete tokens that have expired from the database. + + How (and if) you call this is entirely up you. You could expose it to an + endpoint that only administrators could call, you could run it as a cron, + set it up with flask cli, etc. + """ + now = datetime.now() + expired = TokenBlacklist.query.filter(TokenBlacklist.expires < now).all() + for token in expired: + db.session.delete(token) + db.session.commit() diff --git a/examples/database_blacklist/database.py b/examples/database_blacklist/database.py new file mode 100644 index 00000000..672473f5 --- /dev/null +++ b/examples/database_blacklist/database.py @@ -0,0 +1,20 @@ +from extensions import db + + +class TokenBlacklist(db.Model): + id = db.Column(db.Integer, primary_key=True) + jti = db.Column(db.String(36), nullable=False) + token_type = db.Column(db.String(10), nullable=False) + user_identity = db.Column(db.String(50), nullable=False) + revoked = db.Column(db.Boolean, nullable=False) + expires = db.Column(db.DateTime, nullable=False) + + def to_dict(self): + return { + 'token_id': self.id, + 'jti': self.jti, + 'token_type': self.token_type, + 'user_identity': self.user_identity, + 'revoked': self.revoked, + 'expires': self.expires + } diff --git a/examples/database_blacklist/exceptions.py b/examples/database_blacklist/exceptions.py new file mode 100644 index 00000000..5eb74acb --- /dev/null +++ b/examples/database_blacklist/exceptions.py @@ -0,0 +1,7 @@ + + +class TokenNotFound(Exception): + """ + Indicates that a token could not be found in the database + """ + pass diff --git a/examples/database_blacklist/extensions.py b/examples/database_blacklist/extensions.py new file mode 100644 index 00000000..77a19aff --- /dev/null +++ b/examples/database_blacklist/extensions.py @@ -0,0 +1,6 @@ +from flask_jwt_extended import JWTManager +from flask_sqlalchemy import SQLAlchemy + +jwt = JWTManager() +db = SQLAlchemy() + diff --git a/examples/redis_blacklist.py b/examples/redis_blacklist.py new file mode 100644 index 00000000..a6316fbd --- /dev/null +++ b/examples/redis_blacklist.py @@ -0,0 +1,137 @@ +# Redis is a very quick in memory store. The benefits of using redis is that +# things will generally speedy, and it can be (mostly) persistent by dumping +# the data to disk (see: https://redis.io/topics/persistence). The drawbacks +# to using redis is you have a higher chance of encountering data loss (in +# this case, 'forgetting' that a token was revoked), due to events like +# power outages in between making a change to redis and that change being +# dumped for a disk. +# +# So when does it make sense to use redis for a blacklist? If you are blacklist +# every token on logout but doing nothing besides that (not keeping track of +# what tokens are blacklisted, not providing the option un-revoke blacklisted +# tokens, or view tokens that are currently active for a given user), then redis +# is a great choice. Worst case, a few tokens might slip between the cracks in +# the case of a power outage or other such event, but 99.999% of the time tokens +# will be properly blacklisted, and the security of your application should be +# peachy. +# +# Redis also has the benefit of supporting an expires time when storing data. +# Utilizing this, you will not need to manually prune back down the data +# store to keep it from blowing up on you over time. We will show how this +# could work in this example. +# +# If you intend to use some of the other features in your blacklist (tracking +# what tokens are currently active, option to revoke or unrevoke specific +# tokens, etc), data integrity is probably more important to your app then +# raw performance, in which case a sql base solution (such as postgres) is +# probably a better fit for your blacklist. Check out the "sql_blacklist.py" +# example for how that might work. +import redis +from datetime import timedelta +from flask import Flask, request, jsonify +from flask_jwt_extended import ( + JWTManager, create_access_token, create_refresh_token, get_jti, + jwt_refresh_token_required, get_jwt_identity, jwt_required, get_raw_jwt +) + +app = Flask(__name__) +app.secret_key = 'ChangeMe!' + +# Setup the flask-jwt-extended extension. See: +# http://flask-jwt-extended.readthedocs.io/en/latest/options.html +ACCESS_EXPIRES = timedelta(minutes=15) +REFRESH_EXPIRES = timedelta(days=30) +app.config['JWT_ACCESS_TOKEN_EXPIRES'] = ACCESS_EXPIRES +app.config['JWT_REFRESH_TOKEN_EXPIRES'] = REFRESH_EXPIRES +app.config['JWT_BLACKLIST_ENABLED'] = True +app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ['access', 'refresh'] +jwt = JWTManager(app) + +# Setup our redis connection for storing the blacklisted tokens +revoked_store = redis.StrictRedis(host='localhost', port=6379, db=0, + decode_responses=True) + + +# Create our function to check if a token has been blacklisted. In this simple +# case, we will just store the tokens jti (unique identifier) in the redis +# store whenever we create it with a revoked status of False. This function +# will grab the revoked status from the store and return it. If a token doesn't +# exist in our store, we don't know where it came from (as we are adding newly +# created # tokens to our store), so we are going to considered to be a +# revoked token for safety purposes. This is obviously optional. +@jwt.token_in_blacklist_loader +def check_if_token_in_blacklist(decrypted_token): + jti = decrypted_token['jti'] + entry = revoked_store.get(jti) + if entry is None: + return False + return entry == 'true' + + +@app.route('/auth/login', methods=['POST']) +def login(): + username = request.json.get('username', None) + password = request.json.get('password', None) + if username != 'test' or password != 'test': + return jsonify({"msg": "Bad username or password"}), 401 + + # Create our JWTs + access_token = create_access_token(identity=username) + refresh_token = create_refresh_token(identity=username) + + # Store the tokens in our store with a status of not currently revoked. We + # can use the `get_jti()` method to get the unique identifier string for + # each token. We can also set an expires time on these tokens in redis, + # so they will get automatically removed after they expire. We will set + # everything to be automatically removed shortly after the token expires + access_jti = get_jti(encoded_token=access_token) + refresh_jti = get_jti(encoded_token=refresh_token) + revoked_store.set(access_jti, 'false', ACCESS_EXPIRES * 1.2) + revoked_store.set(refresh_jti, 'false', REFRESH_EXPIRES * 1.2) + + ret = { + 'access_token': access_token, + 'refresh_token': refresh_token + } + return jsonify(ret), 201 + + +# A blacklisted refresh tokens will not be able to access this endpoint +@app.route('/auth/refresh', methods=['POST']) +@jwt_refresh_token_required +def refresh(): + # Do the same thing that we did in the login endpoint here + current_user = get_jwt_identity() + access_token = create_access_token(identity=current_user) + access_jti = get_jti(encoded_token=access_token) + revoked_store.set(access_jti, 'false', ACCESS_EXPIRES * 1.2) + ret = {'access_token': access_token} + return jsonify(ret), 201 + + +# Endpoint for revoking the current users access token +@app.route('/auth/access_revoke', methods=['DELETE']) +@jwt_required +def logout(): + jti = get_raw_jwt()['jti'] + revoked_store.set(jti, 'true', ACCESS_EXPIRES * 1.2) + return jsonify({"msg": "Access token revoked"}), 200 + + +# Endpoint for revoking the current users refresh token +@app.route('/auth/refresh_revoke', methods=['DELETE']) +@jwt_refresh_token_required +def logout2(): + jti = get_raw_jwt()['jti'] + revoked_store.set(jti, 'true', REFRESH_EXPIRES * 1.2) + return jsonify({"msg": "Refresh token revoked"}), 200 + + +# A blacklisted access token will not be able to access this any more +@app.route('/protected', methods=['GET']) +@jwt_required +def protected(): + return jsonify({'hello': 'world'}) + +if __name__ == '__main__': + app.run() diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index 3504b748..3348cc20 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -1,15 +1,10 @@ from .jwt_manager import JWTManager from .view_decorators import ( - jwt_required, fresh_jwt_required, jwt_refresh_token_required, - jwt_optional + jwt_required, fresh_jwt_required, jwt_refresh_token_required, jwt_optional ) from .utils import ( create_refresh_token, create_access_token, get_jwt_identity, get_jwt_claims, set_access_cookies, set_refresh_cookies, unset_jwt_cookies, get_raw_jwt, get_current_user, current_user, - get_jti + get_jti, decode_token ) -from .blacklist import ( - revoke_token, unrevoke_token, get_stored_tokens, get_all_stored_tokens, - get_stored_token -) \ No newline at end of file diff --git a/flask_jwt_extended/blacklist.py b/flask_jwt_extended/blacklist.py deleted file mode 100644 index 30a0acd4..00000000 --- a/flask_jwt_extended/blacklist.py +++ /dev/null @@ -1,170 +0,0 @@ -# Collection of code deals with storing and revoking tokens -import datetime -import json -from functools import wraps - -from flask_jwt_extended.config import config -from flask_jwt_extended.exceptions import RevokedTokenError -from flask_jwt_extended.utils import get_jti - -# TODO make simplekv an optional dependency if blacklist is disabled - - -def _verify_blacklist_enabled(fn): - """ - Helper decorator that verifies the blacklist is enabled on any function - that requires it - """ - @wraps(fn) - def wrapper(*args, **kwargs): - if not config.blacklist_enabled: - err = 'JWT_BLACKLIST_ENABLED must be True to access this functionality' - raise RuntimeError(err) - return fn(*args, **kwargs) - return wrapper - - -def _ts_to_utc_datetime(ts): - return datetime.datetime.utcfromtimestamp(ts) - - -def _store_supports_ttl(store): - """ - Checks if this store supports a TTL on its keys, for automatic removal - after the token has expired. For more info on this, see: - http://pythonhosted.org/simplekv/#simplekv.TimeToLiveMixin - """ - return getattr(store, 'ttl_support', False) - - -def _get_token_ttl(token): - """ - Returns a datetime.timdelta() of how long this token has left to live before - it is expired - """ - expires = _ts_to_utc_datetime(token['exp']) - now = datetime.datetime.utcnow() - delta = expires - now - - # If the token is already expired, return that it has a ttl of 0 - if delta.total_seconds() < 0: - return datetime.timedelta(0) - return delta - - -def _get_token_from_store(jti): - store = config.blacklist_store - stored_str = store.get(jti).decode('utf-8') - stored_data = json.loads(stored_str) - return stored_data - - -def _update_token(jti, revoked): - # Raises a KeyError if the token is not found in the store - stored_data = _get_token_from_store(jti) - token = stored_data['token'] - store_token(token, revoked) - - -@_verify_blacklist_enabled -def revoke_token(jti): - """ - Revoke a token - - :param jti: The jti of the token to revoke - """ - _update_token(jti, revoked=True) - - -@_verify_blacklist_enabled -def unrevoke_token(jti): - """ - Revoke a token - - :param jti: The jti of the token to unrevoke - """ - _update_token(jti, revoked=False) - - -@_verify_blacklist_enabled -def get_stored_token(jti=None, encoded_token=None): - """ - Get the stored token for the passed in jti or encoded_token - - :param jti: The jti of the token - :param encoded_token: The encoded JWT string - :return: Python dictionary with the token information - """ - if jti is None and encoded_token is not None: - jti = get_jti(encoded_token) - elif jti is None and encoded_token is None: - raise ValueError('Either jti or encoded_token is required') - return _get_token_from_store(jti) - - -@_verify_blacklist_enabled -def get_stored_tokens(identity): - """ - Get a list of stored tokens for this identity. Each token will look like: - - TODO - """ - # TODO this is *super* inefficient. Come up with a better way - store = config.blacklist_store - data = [json.loads(store.get(jti).decode('utf-8')) for jti in store.iter_keys()] - return [d for d in data if d['token']['identity'] == identity] - - -@_verify_blacklist_enabled -def get_all_stored_tokens(): - """ - Get a list of stored tokens for every identity. Each token will look like: - - TODO - """ - store = config.blacklist_store - return [json.loads(store.get(jti).decode('utf-8')) for jti in store.iter_keys()] - - -@_verify_blacklist_enabled -def check_if_token_revoked(token): - """ - Checks if the given token has been revoked. - """ - store = config.blacklist_store - check_type = config.blacklist_checks - token_type = token['type'] - jti = token['jti'] - - # Only check access tokens if BLACKLIST_TOKEN_CHECKS is set to 'all` - if token_type == 'access' and check_type == 'all': - stored_data = json.loads(store.get(jti).decode('utf-8')) - if stored_data['revoked']: - raise RevokedTokenError('Token has been revoked') - - # Always check refresh tokens - if token_type == 'refresh': - stored_data = json.loads(store.get(jti).decode('utf-8')) - if stored_data['revoked']: - raise RevokedTokenError('Token has been revoked') - - -@_verify_blacklist_enabled -def store_token(token, revoked): - """ - Stores this token in our key-value store, with the given revoked status - """ - data_to_store = json.dumps({ - 'token': token, - 'revoked': revoked - }).encode('utf-8') - - store = config.blacklist_store - - if _store_supports_ttl(store): # pragma: no cover - # Add 15 minutes to ttl to account for possible time drift - ttl = _get_token_ttl(token) + datetime.timedelta(minutes=15) - ttl_secs = ttl.total_seconds() - store.put(token['jti'], data_to_store, ttl_secs=ttl_secs) - else: - store.put(token['jti'], data_to_store) diff --git a/flask_jwt_extended/config.py b/flask_jwt_extended/config.py index 2110dfec..bf522d2e 100644 --- a/flask_jwt_extended/config.py +++ b/flask_jwt_extended/config.py @@ -1,7 +1,6 @@ import datetime from warnings import warn -import simplekv from flask import current_app # Older versions of pyjwt do not have the requires_cryptography set. Also, @@ -33,11 +32,11 @@ def is_asymmetric(self): @property def encode_key(self): - return self.private_key if self.is_asymmetric else self.secret_key + return self._private_key if self.is_asymmetric else self._secret_key @property def decode_key(self): - return self.public_key if self.is_asymmetric else self.secret_key + return self._public_key if self.is_asymmetric else self._secret_key @property def token_location(self): @@ -170,27 +169,26 @@ def algorithm(self): def blacklist_enabled(self): return current_app.config['JWT_BLACKLIST_ENABLED'] - @property - def blacklist_store(self): - # simplekv object: https://pypi.python.org/pypi/simplekv/ - store = current_app.config['JWT_BLACKLIST_STORE'] - if not isinstance(store, simplekv.KeyValueStore): - raise RuntimeError("JWT_BLACKLIST_STORE must be a simplekv KeyValueStore") - return store - @property def blacklist_checks(self): check_type = current_app.config['JWT_BLACKLIST_TOKEN_CHECKS'] - if check_type not in ('all', 'refresh'): - raise RuntimeError('JWT_BLACKLIST_TOKEN_CHECKS must be "all" or "refresh"') + if not isinstance(check_type, list): + check_type = [check_type] + for item in check_type: + if item not in ('access', 'refresh'): + raise RuntimeError('JWT_BLACKLIST_TOKEN_CHECKS must be "access" or "refresh"') return check_type @property def blacklist_access_tokens(self): - return 'all' in self.blacklist_checks + return 'access' in self.blacklist_checks + + @property + def blacklist_refresh_tokens(self): + return 'refresh' in self.blacklist_checks @property - def secret_key(self): + def _secret_key(self): key = current_app.config['JWT_SECRET_KEY'] if not key: key = current_app.config.get('SECRET_KEY', None) @@ -201,7 +199,7 @@ def secret_key(self): return key @property - def public_key(self): + def _public_key(self): key = current_app.config['JWT_PUBLIC_KEY'] if not key: raise RuntimeError('JWT_PUBLIC_KEY must be set to use ' @@ -210,7 +208,7 @@ def public_key(self): return key @property - def private_key(self): + def _private_key(self): key = current_app.config['JWT_PRIVATE_KEY'] if not key: raise RuntimeError('JWT_PRIVATE_KEY must be set to use ' diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index 710d0f67..a230c7d2 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -2,7 +2,6 @@ from jwt import ExpiredSignatureError, InvalidTokenError -from flask_jwt_extended.blacklist import store_token from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import ( JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError, @@ -15,7 +14,7 @@ default_revoked_token_callback, default_user_loader_error_callback ) from flask_jwt_extended.tokens import ( - encode_refresh_token, decode_jwt, encode_access_token + encode_refresh_token, encode_access_token ) from flask_jwt_extended.utils import get_jwt_identity @@ -40,6 +39,7 @@ def __init__(self, app=None): self._revoked_token_callback = default_revoked_token_callback self._user_loader_callback = None self._user_loader_error_callback = default_user_loader_error_callback + self._token_in_blacklist_callback = None # Register this extension with the flask app now (if it is provided) if app is not None: @@ -162,8 +162,7 @@ def _set_default_configuration_options(app): # Options for blacklisting/revoking tokens app.config.setdefault('JWT_BLACKLIST_ENABLED', False) - app.config.setdefault('JWT_BLACKLIST_STORE', None) - app.config.setdefault('JWT_BLACKLIST_TOKEN_CHECKS', 'refresh') + app.config.setdefault('JWT_BLACKLIST_TOKEN_CHECKS', ['access', 'refresh']) def user_claims_loader(self, callback): """ @@ -283,19 +282,17 @@ def user_loader_error_loader(self, callback): self._user_loader_error_callback = callback return callback - def has_user_loader(self): + def token_in_blacklist_loader(self, callback): """ - Returns True if a user_loader_callback has been defined in this - application, False otherwise - """ - return self._user_loader_callback is not None + Sets the callback function for checking if a token has been revoked. - def user_loader(self, identity): - """ - Calls the _user_loader_callback function (if it is defined) and returns - the resulting user from this callback. + This callback function must take one paramater, which is the full + decoded token dictionary. This should return True if the token has been + blacklisted (or is otherwise considered revoked, or an invalid token), + False otherwise. """ - return self._user_loader_callback(identity) + self._token_in_blacklist_callback = callback + return callback def create_refresh_token(self, identity, expires_delta=None): """ @@ -324,12 +321,6 @@ def create_refresh_token(self, identity, expires_delta=None): expires_delta=expires_delta, csrf=config.csrf_protect ) - - # If blacklisting is enabled, store this token in our key-value store - if config.blacklist_enabled: - decoded_token = decode_jwt(refresh_token, config.decode_key, - config.algorithm, csrf=config.csrf_protect) - store_token(decoded_token, revoked=False) return refresh_token def create_access_token(self, identity, fresh=False, expires_delta=None): @@ -363,9 +354,5 @@ def create_access_token(self, identity, fresh=False, expires_delta=None): user_claims=self._user_claims_callback(identity), csrf=config.csrf_protect ) - if config.blacklist_enabled and config.blacklist_access_tokens: - decoded_token = decode_jwt(access_token, config.decode_key, - config.algorithm, csrf=config.csrf_protect) - store_token(decoded_token, revoked=False) return access_token diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 65d94ce9..8f93753d 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -50,18 +50,28 @@ def get_current_user(): def get_jti(encoded_token): """ Returns the JTI given the JWT encoded token + """ + return decode_token(encoded_token).get('jti') + - :param encoded_token: The encoded JWT string - :return: The JTI of the token +def decode_token(encoded_token): + """ + Returns the decoded token from an encoded one. This does all the checks + to insure that the decoded token is valid before returning it. """ - return decode_jwt(encoded_token, config.secret_key, config.algorithm, config.csrf_protect).get('jti') + return decode_jwt( + encoded_token=encoded_token, + secret=config.decode_key, + algorithm=config.algorithm, + csrf=config.csrf_protect + ) def _get_jwt_manager(): try: return current_app.jwt_manager except AttributeError: # pragma: no cover - raise RuntimeError("You must initialize a JWTManager with this flask" + raise RuntimeError("You must initialize a JWTManager with this flask " "application before using this method") @@ -75,14 +85,24 @@ def create_refresh_token(*args, **kwargs): return jwt_manager.create_refresh_token(*args, **kwargs) +def has_user_loader(): + jwt_manager = _get_jwt_manager() + return jwt_manager._user_loader_callback is not None + + def user_loader(*args, **kwargs): jwt_manager = _get_jwt_manager() - return jwt_manager.user_loader(*args, **kwargs) + return jwt_manager._user_loader_callback(*args, **kwargs) + + +def has_token_in_blacklist_callback(): + jwt_manager = _get_jwt_manager() + return jwt_manager._token_in_blacklist_callback is not None -def has_user_loader(*args, **kwargs): +def token_in_blacklist(*args, **kwargs): jwt_manager = _get_jwt_manager() - return jwt_manager.has_user_loader(*args, **kwargs) + return jwt_manager._token_in_blacklist_callback(*args, **kwargs) def get_csrf_token(encoded_token): diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 7b0a756a..386e458a 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -7,14 +7,16 @@ except ImportError: # pragma: no cover from flask import _request_ctx_stack as ctx_stack -from flask_jwt_extended.blacklist import check_if_token_revoked from flask_jwt_extended.config import config from flask_jwt_extended.exceptions import ( InvalidHeaderError, NoAuthorizationError, WrongTokenError, - FreshTokenRequired, CSRFError, UserLoadError + FreshTokenRequired, CSRFError, UserLoadError, RevokedTokenError ) from flask_jwt_extended.tokens import decode_jwt -from flask_jwt_extended.utils import has_user_loader, user_loader +from flask_jwt_extended.utils import ( + has_user_loader, user_loader, token_in_blacklist, + has_token_in_blacklist_callback +) def jwt_required(fn): @@ -104,6 +106,21 @@ def _load_user(identity): ctx_stack.top.jwt_user = user +def _token_blacklisted(decoded_token, request_type): + if not config.blacklist_enabled: + return False + if not has_token_in_blacklist_callback(): + raise RuntimeError("A token_in_blacklist_callback must be provided via " + "the '@token_in_blacklist_loader' if " + "JWT_BLACKLIST_ENABLED is True") + + if config.blacklist_access_tokens and request_type == 'access': + return token_in_blacklist(decoded_token) + if config.blacklist_refresh_tokens and request_type == 'refresh': + return token_in_blacklist(decoded_token) + return False + + def _decode_jwt_from_headers(): header_name = config.header_name header_type = config.header_type @@ -184,8 +201,7 @@ def _decode_jwt_from_request(request_type): raise WrongTokenError('Only {} tokens can access this endpoint'.format(request_type)) # If blacklisting is enabled, see if this token has been revoked - if config.blacklist_enabled: - check_if_token_revoked(decoded_token) + if _token_blacklisted(decoded_token, request_type): + raise RevokedTokenError('Token has been revoked') return decoded_token - diff --git a/requirements.txt b/requirements.txt index de842e2f..463d610d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,32 @@ -alabaster==0.7.9 -Babel==2.3.4 -click==6.6 -coverage==4.2 -cryptography==1.8.1 -docutils==0.12 -Flask==0.11.1 +alabaster==0.7.10 +asn1crypto==0.22.0 +Babel==2.4.0 +certifi==2017.4.17 +cffi==1.10.0 +chardet==3.0.4 +click==6.7 +coverage==4.4.1 +cryptography==1.9 +docutils==0.13.1 +Flask==0.12.2 +Flask-Sphinx-Themes==1.0.1 +idna==2.5 imagesize==0.7.1 itsdangerous==0.24 -Jinja2==2.8 -MarkupSafe==0.23 -pluggy==0.3.1 -py==1.4.31 -Pygments==2.1.3 -PyJWT==1.5.0 -pytz==2016.7 -simplekv==0.10.0 +Jinja2==2.9.6 +MarkupSafe==1.0 +pluggy==0.4.0 +py==1.4.34 +pycparser==2.18 +Pygments==2.2.0 +PyJWT==1.5.2 +pytz==2017.2 +requests==2.18.1 six==1.10.0 snowballstemmer==1.2.1 -Sphinx==1.4.8 -tox==2.3.1 -virtualenv==15.0.3 -Werkzeug==0.11.11 +Sphinx==1.6.3 +sphinxcontrib-websupport==1.0.1 +tox==2.7.0 +urllib3==1.21.1 +virtualenv==15.1.0 +Werkzeug==0.12.2 diff --git a/setup.py b/setup.py index 9147b5b9..c1ca1486 100644 --- a/setup.py +++ b/setup.py @@ -17,12 +17,12 @@ packages=['flask_jwt_extended'], zip_safe=False, platforms='any', - install_requires=['Flask', 'PyJWT', 'simplekv'], + install_requires=['Flask', 'PyJWT'], extras_require={ 'asymmetric_crypto': ["cryptography"] }, classifiers=[ - 'Development Status :: 4 - Beta', + 'Development Status :: 5 - Production/Stable', 'Environment :: Web Environment', 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', diff --git a/tests/test_blacklist.py b/tests/test_blacklist.py index cdd3db84..6fd4a137 100644 --- a/tests/test_blacklist.py +++ b/tests/test_blacklist.py @@ -1,19 +1,12 @@ -import time import unittest import json -from datetime import timedelta -import simplekv.memory from flask import Flask, jsonify, request -from flask_jwt_extended.blacklist import _get_token_ttl, get_stored_token -from flask_jwt_extended.tokens import encode_refresh_token, decode_jwt -from flask_jwt_extended.utils import get_jwt_identity, get_raw_jwt +from flask_jwt_extended.utils import get_jwt_identity, get_jti from flask_jwt_extended import ( - JWTManager, create_access_token, - get_all_stored_tokens, get_stored_tokens, revoke_token, unrevoke_token, - jwt_required, create_refresh_token, jwt_refresh_token_required, - fresh_jwt_required + JWTManager, create_access_token, jwt_required, create_refresh_token, + jwt_refresh_token_required, fresh_jwt_required ) @@ -23,9 +16,14 @@ def setUp(self): self.app = Flask(__name__) self.app.secret_key = 'super=secret' self.app.config['JWT_BLACKLIST_ENABLED'] = True - self.app.config['JWT_BLACKLIST_STORE'] = simplekv.memory.DictStore() self.jwt_manager = JWTManager(self.app) self.client = self.app.test_client() + self.blacklist = set() + + @self.jwt_manager.token_in_blacklist_loader + def token_in_blacklist(decoded_token): + jti = decoded_token['jti'] + return jti in self.blacklist @self.app.route('/auth/login', methods=['POST']) def login(): @@ -36,42 +34,6 @@ def login(): } return jsonify(ret), 200 - @self.app.route('/auth/token/jti/', methods=['GET']) - @self.app.route('/auth/token/encoded_token/', methods=['GET']) - @self.app.route('/auth/token/encoded_token/', methods=['GET']) - def get_single_token(jti=None, encoded_token=None): - try: - if jti is not None: - return jsonify(get_stored_token(jti=jti)), 200 - else: - return jsonify(get_stored_token(encoded_token=encoded_token)), 200 - except KeyError: - return jsonify({"msg": "token not found"}), 404 - - @self.app.route('/auth/tokens/', methods=['GET']) - def list_identity_tokens(identity): - return jsonify(get_stored_tokens(identity)), 200 - - @self.app.route('/auth/tokens', methods=['GET']) - def list_all_tokens(): - return jsonify(get_all_stored_tokens()), 200 - - @self.app.route('/auth/revoke/', methods=['POST']) - def revoke(jti): - try: - revoke_token(jti) - return jsonify({"msg": "Token revoked"}) - except KeyError: - return jsonify({"msg": "Token not found"}), 404 - - @self.app.route('/auth/unrevoke/', methods=['POST']) - def unrevoke(jti): - try: - unrevoke_token(jti) - return jsonify({"msg": "Token unrevoked"}) - except KeyError: - return jsonify({"msg": "Token not found"}), 404 - @self.app.route('/auth/refresh', methods=['POST']) @jwt_refresh_token_required def refresh(): @@ -79,13 +41,15 @@ def refresh(): ret = {'access_token': create_access_token(username, fresh=False)} return jsonify(ret), 200 - @self.app.route('/auth/logout', methods=['POST']) - @jwt_required - def logout(): - jti = get_raw_jwt()['jti'] - revoke_token(jti) - ret = {"msg": "Successfully logged out"} - return jsonify(ret), 200 + @self.app.route('/auth/revoke/', methods=['POST']) + def revoke(jti): + self.blacklist.add(jti) + return jsonify({"msg": "Token revoked"}) + + @self.app.route('/auth/unrevoke/', methods=['POST']) + def unrevoke(jti): + self.blacklist.remove(jti) + return jsonify({"msg": "Token unrevoked"}) @self.app.route('/protected', methods=['POST']) @jwt_required @@ -114,325 +78,90 @@ def _jwt_post(self, url, jwt=None): data = json.loads(response.get_data(as_text=True)) return status_code, data - def test_revoke_unrevoke_all_token(self): + def test_revoke_access_token(self): # Check access and refresh tokens - self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'all' + self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ['access', 'refresh'] - # No tokens initially - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(response.status_code, 200) - self.assertEqual(data, []) + # Generate our tokens + access_token, _ = self._login('user') + with self.app.app_context(): + access_jti = get_jti(access_token) - # Login, now should have two tokens (access and refresh) that are not revoked - self._login('test1') - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data), 2) - self.assertFalse(data[0]['revoked']) - self.assertFalse(data[1]['revoked']) + # Make sure we can access a protected endpoint + status_code, data = self._jwt_post('/protected', access_token) + self.assertEqual(status_code, 200) + self.assertEqual(data, {'hello': 'world'}) - # Revoke the access token - access_jti = [x['token']['jti'] for x in data if x['token']['type'] == 'access'][0] + # Revoke our access token status, data = self._jwt_post('/auth/revoke/{}'.format(access_jti)) self.assertEqual(status, 200) - self.assertIn('msg', data) + self.assertEqual(data, {'msg': 'Token revoked'}) - # Verify the access token has been revoked on new lookup - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data), 2) - if data[0]['token']['jti'] == access_jti: - self.assertTrue(data[0]['revoked']) - self.assertFalse(data[1]['revoked']) - else: - self.assertFalse(data[0]['revoked']) - self.assertTrue(data[1]['revoked']) - - # Unrevoke the access token - status, data = self._jwt_post('/auth/unrevoke/{}'.format(access_jti)) - self.assertEqual(status, 200) - self.assertIn('msg', data) - - # Make sure token is marked as unrevoked - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data), 2) - self.assertFalse(data[0]['revoked']) - self.assertFalse(data[1]['revoked']) + # Verify the access token can no longer access a protected endpoint + status_code, data = self._jwt_post('/protected', access_token) + self.assertEqual(status_code, 401) + self.assertEqual(data, {'msg': 'Token has been revoked'}) - def test_revoke_unrevoke_refresh_token(self): - # Check only refresh tokens - self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'refresh' + def test_revoke_refresh_token(self): + # Check access and refresh tokens + self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ['access', 'refresh'] - # No tokens initially - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(response.status_code, 200) - self.assertEqual(data, []) + # Generate our tokens + _, refresh_token = self._login('user') + with self.app.app_context(): + refresh_jti = get_jti(refresh_token) - # Login, now should have one token that is not revoked - self._login('test1') - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data), 1) - self.assertFalse(data[0]['revoked']) + # Make sure we can access a protected endpoint + status_code, data = self._jwt_post('/auth/refresh', refresh_token) + self.assertEqual(status_code, 200) + self.assertIn('access_token', data) - # Revoke the token - refresh_jti = data[0]['token']['jti'] + # Revoke our access token status, data = self._jwt_post('/auth/revoke/{}'.format(refresh_jti)) self.assertEqual(status, 200) - self.assertIn('msg', data) - - # Verify the token has been revoked on new lookup - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data), 1) - self.assertTrue(data[0]['revoked']) - - # Unrevoke the token - status, data = self._jwt_post('/auth/unrevoke/{}'.format(refresh_jti)) - self.assertEqual(status, 200) - self.assertIn('msg', data) - - # Make sure token is marked as unrevoked - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(response.status_code, 200) - self.assertEqual(len(data), 1) - self.assertFalse(data[0]['revoked']) - - def test_revoked_access_token_enabled(self): - # Check access and refresh tokens - self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'all' + self.assertEqual(data, {'msg': 'Token revoked'}) - # Login - access_token, refresh_token = self._login('test1') + # Verify the access token can no longer access a protected endpoint + status_code, data = self._jwt_post('/auth/refresh', refresh_token) + self.assertEqual(status_code, 401) + self.assertEqual(data, {'msg': 'Token has been revoked'}) - # Get the access jti - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - access_jti = [x['token']['jti'] for x in data if x['token']['type'] == 'access'][0] + def test_revoked_token_with_access_blacklist_only(self): + # Setup to only revoke refresh tokens + self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ['refresh'] - # Verify we can initially access the endpoint - status, data = self._jwt_post('/protected', access_token) - self.assertEqual(status, 200) - self.assertEqual(data, {'hello': 'world'}) - status, data = self._jwt_post('/protected-fresh', access_token) - self.assertEqual(status, 200) - self.assertEqual(data, {'hello': 'world'}) + # Generate our tokens + access_token, refresh_token = self._login('user') + with self.app.app_context(): + access_jti = get_jti(access_token) + refresh_jti = get_jti(refresh_token) - # Verify we can no longer access endpoint after revoking + # Revoke both tokens (even though app is only configured to look + # at revoked refresh tokens) self._jwt_post('/auth/revoke/{}'.format(access_jti)) - status, data = self._jwt_post('/protected', access_token) - self.assertEqual(status, 401) - self.assertIn('msg', data) - status, data = self._jwt_post('/protected-fresh', access_token) - self.assertEqual(status, 401) - self.assertIn('msg', data) - - # Verify refresh token works, and new token can access endpoint - _, data = self._jwt_post('/auth/refresh', refresh_token) - new_access_token = data['access_token'] - status, data = self._jwt_post('/protected', new_access_token) - self.assertEqual(status, 200) - self.assertEqual(data, {'hello': 'world'}) - - # Verify original token can access endpoint after unrevoking - self._jwt_post('/auth/unrevoke/{}'.format(access_jti)) - status, data = self._jwt_post('/protected', access_token) - self.assertEqual(status, 200) - self.assertEqual(data, {'hello': 'world'}) - status, data = self._jwt_post('/protected-fresh', access_token) - self.assertEqual(status, 200) - self.assertEqual(data, {'hello': 'world'}) - - def test_revoked_access_token_disabled(self): - # Check only refresh tokens - self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'refresh' - - # Login - access_token, refresh_token = self._login('test1') - - # Nothing should be returned, as this token wasn't saved - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - access_jti = [x for x in data if x['token']['type'] == 'access'] - self.assertEqual(len(access_jti), 0) - - # Verify we can access the endpoint - status, data = self._jwt_post('/protected', access_token) - self.assertEqual(status, 200) - self.assertEqual(data, {'hello': 'world'}) - - def test_revoked_refresh_token(self): - # Check only refresh tokens - self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'refresh' - - # Login - access_token, refresh_token = self._login('test1') - - # Get the access jti - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - refresh_jti = [x['token']['jti'] for x in data - if x['token']['type'] == 'refresh'][0] - - # Verify we can initially access the refresh endpoint - status, data = self._jwt_post('/auth/refresh', refresh_token) - self.assertEqual(status, 200) - self.assertIn('access_token', data) - - # Verify we can no longer access the refresh endpoint after revoking self._jwt_post('/auth/revoke/{}'.format(refresh_jti)) - status, data = self._jwt_post('/auth/refresh', refresh_token) - self.assertEqual(status, 401) - self.assertIn('msg', data) - - # Verify we can access again after unrevoking - self._jwt_post('/auth/unrevoke/{}'.format(refresh_jti)) - status, data = self._jwt_post('/auth/refresh', refresh_token) - self.assertEqual(status, 200) - self.assertIn('access_token', data) - - def test_login_logout(self): - # Check access and refresh tokens - self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'all' - # Login - access_token, refresh_token = self._login('test12345') - - # Verify we can access the protected endpoint - status, data = self._jwt_post('/protected', access_token) - self.assertEqual(status, 200) + # Make sure we can still access a protected endpoint with the access token + status_code, data = self._jwt_post('/protected', access_token) + self.assertEqual(status_code, 200) self.assertEqual(data, {'hello': 'world'}) - # Logout - status, data = self._jwt_post('/auth/logout', access_token) - self.assertEqual(status, 200) - self.assertEqual(data, {'msg': 'Successfully logged out'}) - - # Verify that we cannot access the protected endpoint anymore - status, data = self._jwt_post('/protected', access_token) - self.assertEqual(status, 401) + # Make sure that the refresh token kicks us back out + status_code, data = self._jwt_post('/auth/refresh', refresh_token) + self.assertEqual(status_code, 401) self.assertEqual(data, {'msg': 'Token has been revoked'}) def test_bad_blacklist_settings(self): - app = Flask(__name__) - app.testing = True # Propagate exceptions - JWTManager(app) - client = app.test_client() - - @app.route('/list-tokens') - def list_tokens(): - return jsonify(get_all_stored_tokens()) + # Disable the token in blacklist check function + self.jwt_manager.token_in_blacklist_loader(None) - # Check calling blacklist function if blacklist is disabled - app.config['JWT_BLACKLIST_ENABLED'] = False - with self.assertRaises(RuntimeError): - client.get('/list-tokens') + access_token, _ = self._login('user') - # Check calling blacklist function if store is not set - app.config['JWT_BLACKLIST_ENABLED'] = True - app.config['JWT_BLACKLIST_STORE'] = None + # Check that accessing a jwt_required endpoint raises a runtime error with self.assertRaises(RuntimeError): - client.get('/list-tokens') + self._jwt_post('/protected', access_token) # Check calling blacklist function if invalid blacklist check type - app.config['JWT_BLACKLIST_ENABLED'] = True - app.config['JWT_BLACKLIST_STORE'] = {} + self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = ['access', 'banana'] with self.assertRaises(RuntimeError): - client.get('/list-tokens') - - def test_get_token_ttl(self): - # This is called when using a simplekv backend that supports ttl (such - # as redis or memcached). Because I do not want to require having those - # installed to run the unit tests, I'm going to fiat that the code for - # them works, and manually test the helper methods they call for correctness. - - # Test token ttl - with self.app.test_request_context(): - token_str = encode_refresh_token('foo', 'secret', 'HS256', - timedelta(minutes=5), csrf=False) - token = decode_jwt(token_str, 'secret', 'HS256', csrf=False) - time.sleep(2) - token_ttl = _get_token_ttl(token).total_seconds() - self.assertGreater(token_ttl, 296) - self.assertLessEqual(token_ttl, 298) - - # Test ttl is 0 if token is already expired - with self.app.test_request_context(): - token_str = encode_refresh_token('foo', 'secret', 'HS256', - timedelta(seconds=0), csrf=False) - token = decode_jwt(token_str, 'secret', 'HS256', csrf=False) - time.sleep(2) - token_ttl = _get_token_ttl(token).total_seconds() - self.assertEqual(token_ttl, 0) - - def test_revoke_invalid_token(self): - status, data = self._jwt_post('/auth/revoke/404_token_not_found') - self.assertEqual(status, 404) - self.assertIn('msg', data) - - def test_get_specific_identity(self): - self._login('test1') - self._login('test1') - self._login('test1') - self._login('test2') - - response = self.client.get('/auth/tokens/test1') - status_code = response.status_code - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(status_code, 200) - self.assertEqual(len(data), 3) - - response = self.client.get('/auth/tokens/test2') - status_code = response.status_code - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(status_code, 200) - self.assertEqual(len(data), 1) - - response = self.client.get('/auth/tokens/test3') - status_code = response.status_code - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(status_code, 200) - self.assertEqual(len(data), 0) - - def test_get_stored_token(self): - access_token, refresh_token = self._login('test1') - response = self.client.get('/auth/tokens') - data = json.loads(response.get_data(as_text=True)) - refresh_jti = data[0]['token']['jti'] - - # Test getting the token by passing in JTI - response = self.client.get('/auth/token/jti/{}'.format(refresh_jti)) - status_code = response.status_code - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(status_code, 200) - self.assertIn('token', data) - self.assertIn('revoked', data) - self.assertEqual(len(data), 2) - - # Test getting the token by passing in the encoded token - response = self.client.get('/auth/token/encoded_token/{}'.format(refresh_token)) - status_code = response.status_code - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(status_code, 200) - self.assertIn('token', data) - self.assertIn('revoked', data) - self.assertEqual(len(data), 2) - - # Test passing neither throws an exception - with self.assertRaises(ValueError): - self.client.get('/auth/token/encoded_token/') - - response = self.client.get('/auth/token/jti/404notokenfound') - status_code = response.status_code - data = json.loads(response.get_data(as_text=True)) - self.assertEqual(status_code, 404) - self.assertIn('msg', data) + self._jwt_post('/protected', access_token) diff --git a/tests/test_config.py b/tests/test_config.py index 54f02248..fece2933 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,7 +2,6 @@ import warnings from datetime import timedelta -import simplekv.memory from flask import Flask from flask_jwt_extended.config import config @@ -47,24 +46,15 @@ def test_default_configs(self): self.assertEqual(config.algorithm, 'HS256') self.assertEqual(config.is_asymmetric, False) self.assertEqual(config.blacklist_enabled, False) - self.assertEqual(config.blacklist_checks, 'refresh') - self.assertEqual(config.blacklist_access_tokens, False) + self.assertEqual(config.blacklist_checks, ['access', 'refresh']) + self.assertEqual(config.blacklist_access_tokens, True) + self.assertEqual(config.blacklist_refresh_tokens, True) - self.assertEqual(config.secret_key, self.app.secret_key) self.assertEqual(config.encode_key, self.app.secret_key) self.assertEqual(config.decode_key, self.app.secret_key) self.assertEqual(config.cookie_max_age, None) - with self.assertRaises(RuntimeError): - config.blacklist_store - with self.assertRaises(RuntimeError): - config.public_key - with self.assertRaises(RuntimeError): - config.private_key - def test_override_configs(self): - sample_store = simplekv.memory.DictStore() - self.app.config['JWT_TOKEN_LOCATION'] = ['cookies'] self.app.config['JWT_HEADER_NAME'] = 'TestHeader' self.app.config['JWT_HEADER_TYPE'] = 'TestType' @@ -92,8 +82,7 @@ def test_override_configs(self): self.app.config['JWT_ALGORITHM'] = 'HS512' self.app.config['JWT_BLACKLIST_ENABLED'] = True - self.app.config['JWT_BLACKLIST_STORE'] = sample_store - self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'all' + self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'refresh' self.app.secret_key = 'banana' @@ -127,11 +116,10 @@ def test_override_configs(self): self.assertEqual(config.algorithm, 'HS512') self.assertEqual(config.blacklist_enabled, True) - self.assertEqual(config.blacklist_store, sample_store) - self.assertEqual(config.blacklist_checks, 'all') - self.assertEqual(config.blacklist_access_tokens, True) + self.assertEqual(config.blacklist_checks, ['refresh']) + self.assertEqual(config.blacklist_access_tokens, False) + self.assertEqual(config.blacklist_refresh_tokens, True) - self.assertEqual(config.secret_key, 'banana') self.assertEqual(config.cookie_max_age, 2147483647) def test_invalid_config_options(self): @@ -157,25 +145,21 @@ def test_invalid_config_options(self): with self.assertRaises(RuntimeError): config.refresh_expires - self.app.config['JWT_BLACKLIST_STORE'] = {} - with self.assertRaises(RuntimeError): - config.blacklist_store - self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'banana' with self.assertRaises(RuntimeError): config.blacklist_checks self.app.secret_key = None with self.assertRaises(RuntimeError): - config.secret_key + config.decode_key self.app.secret_key = '' with self.assertRaises(RuntimeError): - config.secret_key + config.decode_key self.app.secret_key = None with self.assertRaises(RuntimeError): - config.encode_key + config.decode_key self.app.config['JWT_ALGORITHM'] = 'RS256' self.app.config['JWT_PUBLIC_KEY'] = None @@ -234,5 +218,3 @@ def test_asymmetric_encryption_key_handling(self): self.assertEqual(config.is_asymmetric, True) self.assertEqual(config.encode_key, 'MOCK_RSA_PRIVATE_KEY') self.assertEqual(config.decode_key, 'MOCK_RSA_PUBLIC_KEY') - self.assertEqual(config.private_key, 'MOCK_RSA_PRIVATE_KEY') - self.assertEqual(config.public_key, 'MOCK_RSA_PUBLIC_KEY') diff --git a/tests/test_jwt_encode_decode.py b/tests/test_jwt_encode_decode.py index 10e2eaf5..b65caad3 100644 --- a/tests/test_jwt_encode_decode.py +++ b/tests/test_jwt_encode_decode.py @@ -36,7 +36,7 @@ def test_encode_access_token(self): identity = 'user1' token = encode_access_token(identity, secret, algorithm, token_expire_delta, fresh=True, user_claims=user_claims, csrf=False) - data = jwt.decode(token, secret, algorithm=algorithm) + data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) @@ -60,7 +60,7 @@ def test_encode_access_token(self): identity = 12345 # identity can be anything json serializable token = encode_access_token(identity, secret, algorithm, token_expire_delta, fresh=False, user_claims=user_claims, csrf=True) - data = jwt.decode(token, secret, algorithm=algorithm) + data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) @@ -105,7 +105,7 @@ def test_encode_refresh_token(self): identity = 'user1' token = encode_refresh_token(identity, secret, algorithm, token_expire_delta, csrf=False) - data = jwt.decode(token, secret, algorithm=algorithm) + data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) @@ -125,7 +125,7 @@ def test_encode_refresh_token(self): identity = 12345 # identity can be anything json serializable token = encode_refresh_token(identity, secret, algorithm, token_expire_delta, csrf=True) - data = jwt.decode(token, secret, algorithm=algorithm) + data = jwt.decode(token, secret, algorithms=[algorithm]) self.assertIn('exp', data) self.assertIn('iat', data) self.assertIn('nbf', data) diff --git a/tests/test_jwt_manager.py b/tests/test_jwt_manager.py index 067ee56f..283bad09 100644 --- a/tests/test_jwt_manager.py +++ b/tests/test_jwt_manager.py @@ -3,6 +3,7 @@ from flask import Flask, jsonify from flask_jwt_extended import JWTManager +from flask_jwt_extended.utils import has_user_loader class TestJWTManager(unittest.TestCase): @@ -101,7 +102,8 @@ def test_default_user_loader_error_callback(self): def test_default_has_user_loader(self): m = JWTManager(self.app) - self.assertEqual(m.has_user_loader(), False) + with self.app.app_context(): + self.assertEqual(has_user_loader(), False) def test_custom_user_claims_callback(self): identity = 'foobar' @@ -196,7 +198,7 @@ def custom_user_loader(identity): identity = 'foobar' result = m._user_loader_callback(identity) self.assertEqual(result, identity) - self.assertEqual(m.has_user_loader(), True) + self.assertEqual(has_user_loader(), True) def test_custom_user_loader_error_callback(self): with self.app.test_request_context():