Skip to content

Commit

Permalink
Merge 64a6add into b0ae545
Browse files Browse the repository at this point in the history
  • Loading branch information
yance-dev committed Apr 19, 2022
2 parents b0ae545 + 64a6add commit fdaf019
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Expand Up @@ -113,6 +113,8 @@ It used the casbin config from `examples` folder, and you can find this demo in

You can also view the unit tests to understand this middleware.

Besides, there is another example for `CasbinMiddleware` which is designed to work with JWT authentication. You can find it in `demo/jwt_test.py`.

## Development

### Run unit tests
Expand Down
106 changes: 106 additions & 0 deletions demo/jwt_test.py
@@ -0,0 +1,106 @@
from datetime import datetime, timedelta
from typing import Optional, Tuple, Union

import casbin
import jwt
import uvicorn
from fastapi import FastAPI
from starlette.authentication import (
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials)
from starlette.middleware.authentication import AuthenticationMiddleware

from fastapi_authz import CasbinMiddleware

JWT_SECRET_KEY = "secret"
app = FastAPI()


class JWTUser(BaseUser):
def __init__(self, username: str, token: str, payload: dict) -> None:
self.username = username
self.token = token
self.payload = payload

@property
def is_authenticated(self) -> bool:
return True

@property
def display_name(self) -> str:
return self.username


class JWTAuthenticationBackend(AuthenticationBackend):

def __init__(self,
secret_key: str,
algorithm: str = 'HS256',
prefix: str = 'Bearer',
username_field: str = 'username',
audience: Optional[str] = None,
options: Optional[dict] = None) -> None:
self.secret_key = secret_key
self.algorithm = algorithm
self.prefix = prefix
self.username_field = username_field
self.audience = audience
self.options = options or dict()

@classmethod
def get_token_from_header(cls, authorization: str, prefix: str) -> str:
"""Parses the Authorization header and returns only the token"""
try:
scheme, token = authorization.split()
except ValueError as e:
raise AuthenticationError('Could not separate Authorization scheme and token') from e

if scheme.lower() != prefix.lower():
raise AuthenticationError(f'Authorization scheme {scheme} is not supported')
return token

async def authenticate(self, request) -> Union[None, Tuple[AuthCredentials, BaseUser]]:
if "Authorization" not in request.headers:
return None

auth = request.headers["Authorization"]
token = self.get_token_from_header(authorization=auth, prefix=self.prefix)
try:
payload = jwt.decode(token, key=self.secret_key, algorithms=self.algorithm, audience=self.audience,
options=self.options)
except jwt.InvalidTokenError as e:
raise AuthenticationError(str(e)) from e
return AuthCredentials(["authenticated"]), JWTUser(username=payload[self.username_field], token=token,
payload=payload)


enforcer = casbin.Enforcer('../examples/rbac_model.conf', '../examples/rbac_policy.csv')
app.add_middleware(CasbinMiddleware, enforcer=enforcer)

app.add_middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend(secret_key=JWT_SECRET_KEY))


def create_access_token(subject: str, expires_delta: timedelta = None) -> str:
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=60
)
to_encode = {"exp": expire, "username": subject}
return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm="HS256")


@app.get('/')
async def index():
return "If you see this, you have been authenticated."


@app.get('/dataset1/protected')
async def auth_test():
return "You must be alice to see this."


if __name__ == '__main__':
print("alice:", create_access_token("alice", expires_delta=timedelta(minutes=60)))
print("mark:", create_access_token("mark", expires_delta=timedelta(minutes=60)))
uvicorn.run(app, debug=True)
3 changes: 2 additions & 1 deletion dev-requirements.in
Expand Up @@ -8,4 +8,5 @@ uvicorn
starlette-auth-toolkit
requests
twine
build
build
pyjwt
1 change: 1 addition & 0 deletions dev-requirements.txt
Expand Up @@ -25,6 +25,7 @@ pluggy==0.13.1
py==1.10.0
pydantic==1.7.3
pygments==2.8.0
pyjwt==2.3.0
pyparsing==2.4.7
pytest-cov==2.11.1
pytest==6.2.2
Expand Down
84 changes: 84 additions & 0 deletions tests/conftest.py
Expand Up @@ -7,6 +7,17 @@
from fastapi import FastAPI
from starlette.authentication import AuthenticationBackend, AuthenticationError, AuthCredentials, SimpleUser
from starlette.middleware.authentication import AuthenticationMiddleware
import os
from typing import Optional, Tuple, Union

import casbin
import jwt
import pytest
from fastapi import FastAPI
from starlette.authentication import (
AuthenticationBackend, AuthenticationError, BaseUser, AuthCredentials)
from starlette.middleware.authentication import AuthenticationMiddleware
from datetime import datetime, timedelta

from fastapi_authz import CasbinMiddleware

Expand Down Expand Up @@ -42,3 +53,76 @@ def app_fixture():
app.add_middleware(AuthenticationMiddleware, backend=BasicAuth())

yield app




class JWTUser(BaseUser):
def __init__(self, username: str, token: str, payload: dict) -> None:
self.username = username
self.token = token
self.payload = payload

@property
def is_authenticated(self) -> bool:
return True

@property
def display_name(self) -> str:
return self.username


class JWTAuthenticationBackend(AuthenticationBackend):

def __init__(self,
secret_key: str,
algorithm: str = 'HS256',
prefix: str = 'Bearer',
username_field: str = 'username',
audience: Optional[str] = None,
options: Optional[dict] = None) -> None:
self.secret_key = secret_key
self.algorithm = algorithm
self.prefix = prefix
self.username_field = username_field
self.audience = audience
self.options = options or dict()

@classmethod
def get_token_from_header(cls, authorization: str, prefix: str) -> str:
"""Parses the Authorization header and returns only the token"""
try:
scheme, token = authorization.split()
except ValueError as e:
raise AuthenticationError('Could not separate Authorization scheme and token') from e

if scheme.lower() != prefix.lower():
raise AuthenticationError(f'Authorization scheme {scheme} is not supported')
return token

async def authenticate(self, request) -> Union[None, Tuple[AuthCredentials, BaseUser]]:
if "Authorization" not in request.headers:
return None

auth = request.headers["Authorization"]
token = self.get_token_from_header(authorization=auth, prefix=self.prefix)
try:
payload = jwt.decode(token, key=self.secret_key, algorithms=self.algorithm, audience=self.audience,
options=self.options)
except jwt.InvalidTokenError as e:
raise AuthenticationError(str(e)) from e
return AuthCredentials(["authenticated"]), JWTUser(username=payload[self.username_field], token=token,
payload=payload)


@pytest.fixture
def jwt_app_fixture():
JWT_SECRET_KEY = "secret"
enforcer = casbin.Enforcer(get_examples("rbac_model.conf"), get_examples("rbac_policy.csv"))

app = FastAPI()

app.add_middleware(CasbinMiddleware, enforcer=enforcer)
app.add_middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend(secret_key=JWT_SECRET_KEY))

yield app
55 changes: 55 additions & 0 deletions tests/test_jwt.py
@@ -0,0 +1,55 @@
import pytest
from starlette.testclient import TestClient
import jwt
from datetime import datetime, timedelta


@pytest.mark.parametrize(
"test_server_path, test_client_path, method, status_code, user, response_body", [
('/dataset1/resource2', '/dataset1/resource2', 'GET', 200, 'alice', 'ok'),
('/dataset1/resource2', '/dataset1/resource2', 'GET', 403, 'notalice', 'Forbidden'),
('/dataset1/resource2', '/dataset1/resource2', 'OPTIONS', 200, 'notalice', 'ok'),
('/dataset1/resource1', '/dataset1/resource1', 'POST', 200, 'alice', 'ok'),
]
)
def test_jwt_middleware_authed(jwt_app_fixture, test_server_path, test_client_path, method, status_code, user,
response_body):
@getattr(jwt_app_fixture, method.lower())(test_server_path)
async def index():
return 'ok'

JWT_SECRET_KEY = "secret"
test_client = TestClient(jwt_app_fixture)
expire = datetime.utcnow() + timedelta(
minutes=60
)
token = jwt.encode({"exp": expire, "username": user}, JWT_SECRET_KEY, algorithm="HS256")

test_response = getattr(test_client, method.lower())(test_client_path, headers={'Authorization': 'Bearer ' + token})

assert test_response.status_code == status_code
assert test_response.json() == response_body


@pytest.mark.parametrize(
"test_server_path, test_client_path, method, status_code, response_body", [
('/login', '/login', 'GET', 200, 'ok'),
('/', '/', 'GET', 200, 'ok')
]
)
def test_jwt_middleware_not_authed(jwt_app_fixture, test_server_path, test_client_path, method, status_code,
response_body):
@getattr(jwt_app_fixture, method.lower())(test_server_path)
async def index():
return 'ok'

test_client = TestClient(jwt_app_fixture)

test_response = getattr(test_client, method.lower())(test_client_path)

assert test_response.status_code == status_code
assert test_response.json() == response_body


if __name__ == '__main__':
pytest.main()

0 comments on commit fdaf019

Please sign in to comment.