diff --git a/flask_authz/casbin_enforcer.py b/flask_authz/casbin_enforcer.py index d874471..59841fa 100644 --- a/flask_authz/casbin_enforcer.py +++ b/flask_authz/casbin_enforcer.py @@ -59,7 +59,7 @@ def owner_loader(self, callback): self._owner_loader = callback return callback - def enforcer(self, func): + def enforcer(self, func, delimiter=','): @wraps(func) def wrapper(*args, **kwargs): if self.e.watcher and self.e.watcher.should_reload(): @@ -117,7 +117,8 @@ def wrapper(*args, **kwargs): # Split header by ',' in case of groups when groups are # sent "group1,group2,group3,..." in the header for owner in self.sanitize_group_headers( - request.headers.get(header) + request.headers.get(header), + delimiter ): self.app.logger.debug( "Enforce against owner: %s header: %s" @@ -149,26 +150,19 @@ def wrapper(*args, **kwargs): return wrapper @staticmethod - def sanitize_group_headers(headers_str): + def sanitize_group_headers(headers_str, delimiter=',') -> list: """ Sanitizes group header string so that it is easily parsable by enforcer removes extra spaces, and converts comma delimited or white space delimited list into a list. + + Default delimiter: "," (comma) + Returns: - str + list """ - # If there are commas and white space in the string, - # remove the whitespace - if " " in headers_str and "," in headers_str: - headers_str = headers_str.replace(" ", "") - # If there are no commas in the string, return a list - # delimited by whitespace - if " " in headers_str and "," not in headers_str: - return headers_str.split(" ") - # There are commas and no whitespace in the string, return a list - # delimited by commas - else: - return headers_str.split(",") + + return [string.strip() for string in headers_str.split(delimiter) if string != ""] def manager(self, func): """Get the Casbin Enforcer Object to manager Casbin""" diff --git a/tests/test_casbin_enforcer.py b/tests/test_casbin_enforcer.py index 4449d9f..12aa420 100644 --- a/tests/test_casbin_enforcer.py +++ b/tests/test_casbin_enforcer.py @@ -59,7 +59,7 @@ def update_callback(self): ("X-Idp-Groups", "admin", "GET", 401, "X-User"), ("X-Idp-Groups", "users", "GET", 200, None), ("X-Idp-Groups", "noexist,testnoexist,users", "GET", 200, None), - ("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None), + # ("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None), ("X-Idp-Groups", "noexist, testnoexist, users", "GET", 200, None), ("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200, "Authorization"), ("Authorization", "Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZGVudGl0eSI6ImJvYiJ9." @@ -205,3 +205,33 @@ def owner_loader(): caller = getattr(c, method.lower()) rv = caller("/item") assert rv.status_code == status + + +@pytest.mark.parametrize( + "header_string, expected_list", + [ + ("noexist,testnoexist,users ", ["noexist", "testnoexist", "users"]), + ("noexist, testnoexist, users", ["noexist", "testnoexist", "users"]), + ("noexist, testnoexist, users", ["noexist", "testnoexist", "users"]), + ("somegroup, group with space", ["somegroup", "group with space"]), + ("group with space", ["group with space"]) + ] +) +def test_sanitize_group_headers(header_string, expected_list): + header_list = CasbinEnforcer.sanitize_group_headers(header_string) + assert header_list == expected_list + + +@pytest.mark.parametrize( + "header_string, expected_list", + [ + ("noexist testnoexist users ", ["noexist", "testnoexist", "users"]), + ("noexist testnoexist users", ["noexist", "testnoexist", "users"]), + ("noexist, testnoexist, users", ["noexist,", "testnoexist,", "users"]), + ("somegroup, group with space", ["somegroup,", "group", "with", "space"]), + ("group with space", ["group", "with", "space"]) + ] +) +def test_sanitize_group_headers_with_whitespace(header_string, expected_list): + header_list = CasbinEnforcer.sanitize_group_headers(header_string, ' ') + assert header_list == expected_list diff --git a/tests/test_casbin_enforcer_init_app.py b/tests/test_casbin_enforcer_init_app.py index 321f2c9..a496616 100644 --- a/tests/test_casbin_enforcer_init_app.py +++ b/tests/test_casbin_enforcer_init_app.py @@ -65,7 +65,7 @@ def update_callback(self): ("X-Idp-Groups", "admin", "GET", 401, "X-User"), ("X-Idp-Groups", "users", "GET", 200, None), ("X-Idp-Groups", "noexist,testnoexist,users", "GET", 200, None), - ("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None), + # ("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None), ("X-Idp-Groups", "noexist, testnoexist, users", "GET", 200, None), ("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200, "Authorization"), ("Authorization", "Unsupported Ym9iOnBhc3N3b3Jk", "GET", 401, None),