Skip to content

Commit

Permalink
Merge 35f72d9 into 5269842
Browse files Browse the repository at this point in the history
  • Loading branch information
dmulter committed Dec 7, 2020
2 parents 5269842 + 35f72d9 commit d966516
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 92 deletions.
28 changes: 0 additions & 28 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,6 @@ def set_access_cookies(response, encoded_access_token, max_age=None):
JWT_SESSION_COOKIE option will be ignored. Values should be
the number of seconds (as an integer).
"""
if not config.jwt_in_cookies:
raise RuntimeWarning(
"set_access_cookies() called without "
"'JWT_TOKEN_LOCATION' configured to use cookies"
)

# Set the access JWT in the cookie
response.set_cookie(
config.access_cookie_name,
value=encoded_access_token,
Expand All @@ -263,7 +256,6 @@ def set_access_cookies(response, encoded_access_token, max_age=None):
samesite=config.cookie_samesite,
)

# If enabled, set the csrf double submit access cookie
if config.csrf_protect and config.csrf_in_cookies:
response.set_cookie(
config.access_csrf_cookie_name,
Expand Down Expand Up @@ -292,13 +284,6 @@ def set_refresh_cookies(response, encoded_refresh_token, max_age=None):
JWT_SESSION_COOKIE option will be ignored. Values should be
the number of seconds (as an integer).
"""
if not config.jwt_in_cookies:
raise RuntimeWarning(
"set_refresh_cookies() called without "
"'JWT_TOKEN_LOCATION' configured to use cookies"
)

# Set the refresh JWT in the cookie
response.set_cookie(
config.refresh_cookie_name,
value=encoded_refresh_token,
Expand All @@ -310,7 +295,6 @@ def set_refresh_cookies(response, encoded_refresh_token, max_age=None):
samesite=config.cookie_samesite,
)

# If enabled, set the csrf double submit refresh cookie
if config.csrf_protect and config.csrf_in_cookies:
response.set_cookie(
config.refresh_csrf_cookie_name,
Expand Down Expand Up @@ -344,12 +328,6 @@ def unset_access_cookies(response):
:param response: the flask response object to delete the jwt cookies in.
"""
if not config.jwt_in_cookies:
raise RuntimeWarning(
"unset_refresh_cookies() called without "
"'JWT_TOKEN_LOCATION' configured to use cookies"
)

response.set_cookie(
config.access_cookie_name,
value="",
Expand Down Expand Up @@ -383,12 +361,6 @@ def unset_refresh_cookies(response):
:param response: the flask response object to delete the jwt cookies in.
"""
if not config.jwt_in_cookies:
raise RuntimeWarning(
"unset_refresh_cookies() called without "
"'JWT_TOKEN_LOCATION' configured to use cookies"
)

response.set_cookie(
config.refresh_cookie_name,
value="",
Expand Down
30 changes: 16 additions & 14 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _verify_token_is_fresh(jwt_header, jwt_data):
raise FreshTokenRequired("Fresh token required", jwt_header, jwt_data)


def verify_jwt_in_request(optional=False, fresh=False, refresh=False):
def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=None):
"""
Ensure that the requester has a valid access token. This does not check the
freshness of the access token. Raises an appropiate exception there is
Expand All @@ -44,9 +44,9 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False):

try:
if refresh:
jwt_data, jwt_header = _decode_jwt_from_request("refresh")
jwt_data, jwt_header = _decode_jwt_from_request("refresh", locations)
else:
jwt_data, jwt_header = _decode_jwt_from_request("access")
jwt_data, jwt_header = _decode_jwt_from_request("access", locations)
except (NoAuthorizationError, InvalidHeaderError):
if not optional:
raise
Expand All @@ -68,7 +68,7 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False):
return jwt_header, jwt_data


def jwt_required(optional=False, fresh=False, refresh=False):
def jwt_required(optional=False, fresh=False, refresh=False, locations=None):
"""
A decorator to protect a Flask endpoint.
Expand All @@ -82,7 +82,7 @@ def jwt_required(optional=False, fresh=False, refresh=False):
def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
verify_jwt_in_request(optional, fresh, refresh)
verify_jwt_in_request(optional, fresh, refresh, locations)
return fn(*args, **kwargs)

return decorator
Expand Down Expand Up @@ -196,12 +196,17 @@ def _decode_jwt_from_json(token_type):
return encoded_token, None


def _decode_jwt_from_request(token_type):
def _decode_jwt_from_request(token_type, locations):
# All the places we can get a JWT from in this request
get_encoded_token_functions = []

# add the functions in the order specified in JWT_TOKEN_LOCATION
for location in config.token_location:
# Get locations in the order specified by the decorator or JWT_TOKEN_LOCATION
# configuration.
if not locations:
locations = config.token_location

# Add the functions in the order specified by locations.
for location in locations:
if location == "cookies":
get_encoded_token_functions.append(
lambda: _decode_jwt_from_cookies(token_type)
Expand Down Expand Up @@ -232,13 +237,10 @@ def _decode_jwt_from_request(token_type):
# Do some work to make a helpful and human readable error message if no
# token was found in any of the expected locations.
if not decoded_token:
token_locations = config.token_location
multiple_jwt_locations = len(token_locations) != 1

if multiple_jwt_locations:
if len(locations) > 1:
err_msg = "Missing JWT in {start_locs} or {end_locs} ({details})".format(
start_locs=", ".join(token_locations[:-1]),
end_locs=token_locations[-1],
start_locs=", ".join(locations[:-1]),
end_locs=locations[-1],
details="; ".join(errors),
)
raise NoAuthorizationError(err_msg)
Expand Down
16 changes: 0 additions & 16 deletions tests/test_cookies.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,22 +312,6 @@ def test_custom_csrf_methods(app, options):
assert response.get_json() == {"foo": "bar"}


def test_setting_cookies_wihout_cookies_enabled(app):
app.config["JWT_TOKEN_LOCATION"] = ["headers"]
test_client = app.test_client()

response = test_client.get("/access_token")
assert response.status_code == 500
response = test_client.get("/refresh_token")
assert response.status_code == 500
response = test_client.get("/delete_tokens")
assert response.status_code == 500
response = test_client.get("/delete_access_tokens")
assert response.status_code == 500
response = test_client.get("/delete_refresh_tokens")
assert response.status_code == 500


def test_default_cookie_options(app):
test_client = app.test_client()

Expand Down
93 changes: 59 additions & 34 deletions tests/test_multiple_token_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,46 +30,71 @@ def access_protected():
return app


def test_header_access(app):
test_client = app.test_client()
with app.test_request_context():
access_token = create_access_token("username")

access_headers = {"Authorization": "Bearer {}".format(access_token)}
response = test_client.get("/protected", headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}


def test_cookie_access(app):
test_client = app.test_client()
test_client.get("/cookie_login")
response = test_client.get("/protected")
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}

@pytest.fixture(scope="function")
def app_with_locations():
app = Flask(__name__)
app.config["JWT_SECRET_KEY"] = "foobarbaz"
app.config["JWT_TOKEN_LOCATION"] = ["headers"]
locations = ["headers", "cookies", "query_string", "json"]
JWTManager(app)

def test_query_string_access(app):
test_client = app.test_client()
with app.test_request_context():
@app.route("/cookie_login", methods=["GET"])
def cookie_login():
resp = jsonify(login=True)
access_token = create_access_token("username")
set_access_cookies(resp, access_token)
return resp

url = "/protected?jwt={}".format(access_token)
response = test_client.get(url)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}

@app.route("/protected", methods=["GET", "POST"])
@jwt_required(locations=locations)
def access_protected():
return jsonify(foo="bar")

def test_json_access(app):
test_client = app.test_client()
return app

with app.test_request_context():
access_token = create_access_token("username")

data = {"access_token": access_token}
response = test_client.post("/protected", json=data)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}
def test_header_access(app, app_with_locations):
for app in (app, app_with_locations):
test_client = app.test_client()
with app.test_request_context():
access_token = create_access_token("username")

access_headers = {"Authorization": "Bearer {}".format(access_token)}
response = test_client.get("/protected", headers=access_headers)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}


def test_cookie_access(app, app_with_locations):
for app in (app, app_with_locations):
test_client = app.test_client()
test_client.get("/cookie_login")
response = test_client.get("/protected")
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}


def test_query_string_access(app, app_with_locations):
for app in (app, app_with_locations):
test_client = app.test_client()
with app.test_request_context():
access_token = create_access_token("username")

url = "/protected?jwt={}".format(access_token)
response = test_client.get(url)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}


def test_json_access(app, app_with_locations):
for app in (app, app_with_locations):
test_client = app.test_client()
with app.test_request_context():
access_token = create_access_token("username")
data = {"access_token": access_token}
response = test_client.post("/protected", json=data)
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}


@pytest.mark.parametrize(
Expand Down

0 comments on commit d966516

Please sign in to comment.