Skip to content

Commit

Permalink
Merge pull request #35 from ffyuanda/master
Browse files Browse the repository at this point in the history
fix: support default delimiter for sanitize_group_headers()
  • Loading branch information
leeqvip committed Jul 4, 2021
2 parents c6f28cd + 0683ca3 commit 0a20a31
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 18 deletions.
26 changes: 10 additions & 16 deletions flask_authz/casbin_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def owner_loader(self, callback):
self._owner_loader = callback
return callback

def enforcer(self, func):
def enforcer(self, func, delimiter=','):
@wraps(func)
def wrapper(*args, **kwargs):
if self.e.watcher and self.e.watcher.should_reload():
Expand Down Expand Up @@ -117,7 +117,8 @@ def wrapper(*args, **kwargs):
# Split header by ',' in case of groups when groups are
# sent "group1,group2,group3,..." in the header
for owner in self.sanitize_group_headers(
request.headers.get(header)
request.headers.get(header),
delimiter
):
self.app.logger.debug(
"Enforce against owner: %s header: %s"
Expand Down Expand Up @@ -149,26 +150,19 @@ def wrapper(*args, **kwargs):
return wrapper

@staticmethod
def sanitize_group_headers(headers_str):
def sanitize_group_headers(headers_str, delimiter=',') -> list:
"""
Sanitizes group header string so that it is easily parsable by enforcer
removes extra spaces, and converts comma delimited or white space
delimited list into a list.
Default delimiter: "," (comma)
Returns:
str
list
"""
# If there are commas and white space in the string,
# remove the whitespace
if " " in headers_str and "," in headers_str:
headers_str = headers_str.replace(" ", "")
# If there are no commas in the string, return a list
# delimited by whitespace
if " " in headers_str and "," not in headers_str:
return headers_str.split(" ")
# There are commas and no whitespace in the string, return a list
# delimited by commas
else:
return headers_str.split(",")

return [string.strip() for string in headers_str.split(delimiter) if string != ""]

def manager(self, func):
"""Get the Casbin Enforcer Object to manager Casbin"""
Expand Down
32 changes: 31 additions & 1 deletion tests/test_casbin_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def update_callback(self):
("X-Idp-Groups", "admin", "GET", 401, "X-User"),
("X-Idp-Groups", "users", "GET", 200, None),
("X-Idp-Groups", "noexist,testnoexist,users", "GET", 200, None),
("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None),
# ("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None),
("X-Idp-Groups", "noexist, testnoexist, users", "GET", 200, None),
("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200, "Authorization"),
("Authorization", "Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZGVudGl0eSI6ImJvYiJ9."
Expand Down Expand Up @@ -205,3 +205,33 @@ def owner_loader():
caller = getattr(c, method.lower())
rv = caller("/item")
assert rv.status_code == status


@pytest.mark.parametrize(
"header_string, expected_list",
[
("noexist,testnoexist,users ", ["noexist", "testnoexist", "users"]),
("noexist, testnoexist, users", ["noexist", "testnoexist", "users"]),
("noexist, testnoexist, users", ["noexist", "testnoexist", "users"]),
("somegroup, group with space", ["somegroup", "group with space"]),
("group with space", ["group with space"])
]
)
def test_sanitize_group_headers(header_string, expected_list):
header_list = CasbinEnforcer.sanitize_group_headers(header_string)
assert header_list == expected_list


@pytest.mark.parametrize(
"header_string, expected_list",
[
("noexist testnoexist users ", ["noexist", "testnoexist", "users"]),
("noexist testnoexist users", ["noexist", "testnoexist", "users"]),
("noexist, testnoexist, users", ["noexist,", "testnoexist,", "users"]),
("somegroup, group with space", ["somegroup,", "group", "with", "space"]),
("group with space", ["group", "with", "space"])
]
)
def test_sanitize_group_headers_with_whitespace(header_string, expected_list):
header_list = CasbinEnforcer.sanitize_group_headers(header_string, ' ')
assert header_list == expected_list
2 changes: 1 addition & 1 deletion tests/test_casbin_enforcer_init_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def update_callback(self):
("X-Idp-Groups", "admin", "GET", 401, "X-User"),
("X-Idp-Groups", "users", "GET", 200, None),
("X-Idp-Groups", "noexist,testnoexist,users", "GET", 200, None),
("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None),
# ("X-Idp-Groups", "noexist testnoexist users", "GET", 200, None),
("X-Idp-Groups", "noexist, testnoexist, users", "GET", 200, None),
("Authorization", "Basic Ym9iOnBhc3N3b3Jk", "GET", 200, "Authorization"),
("Authorization", "Unsupported Ym9iOnBhc3N3b3Jk", "GET", 401, None),
Expand Down

0 comments on commit 0a20a31

Please sign in to comment.