Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to Client Credentials Flow (OAuth2) #5052

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions fastapi/security/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
from .api_key import APIKeyQuery as APIKeyQuery
from .http import HTTPAuthorizationCredentials as HTTPAuthorizationCredentials
from .http import HTTPBasic as HTTPBasic
from .http import HTTPBasicClientCredentials as HTTPBasicClientCredentials
from .http import HTTPBasicCredentials as HTTPBasicCredentials
from .http import HTTPBearer as HTTPBearer
from .http import HTTPClientCredentials as HTTPClientCredentials
from .http import HTTPDigest as HTTPDigest
from .oauth2 import OAuth2 as OAuth2
from .oauth2 import OAuth2AuthorizationCodeBearer as OAuth2AuthorizationCodeBearer
from .oauth2 import OAuth2ClientCredentials as OAuth2ClientCredentials
from .oauth2 import (
OAuth2ClientCredentialsRequestForm as OAuth2ClientCredentialsRequestForm,
)
from .oauth2 import OAuth2PasswordBearer as OAuth2PasswordBearer
from .oauth2 import OAuth2PasswordRequestForm as OAuth2PasswordRequestForm
from .oauth2 import OAuth2PasswordRequestFormStrict as OAuth2PasswordRequestFormStrict
Expand Down
78 changes: 59 additions & 19 deletions fastapi/security/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import binascii
from base64 import b64decode
from typing import Optional
from typing import Optional, Tuple

from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
Expand All @@ -22,18 +22,25 @@ class HTTPAuthorizationCredentials(BaseModel):
credentials: str


class HTTPClientCredentials(BaseModel):
client_id: str
client_secret: str


class HTTPBase(SecurityBase):
def __init__(
self,
*,
scheme: str,
scheme_name: Optional[str] = None,
realm: Optional[str] = None,
description: Optional[str] = None,
auto_error: bool = True,
):
self.model = HTTPBaseModel(scheme=scheme, description=description)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
self.realm = realm

async def __call__(
self, request: Request
Expand All @@ -49,24 +56,9 @@ async def __call__(
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)


class HTTPBasic(HTTPBase):
def __init__(
self,
*,
scheme_name: Optional[str] = None,
realm: Optional[str] = None,
description: Optional[str] = None,
auto_error: bool = True,
):
self.model = HTTPBaseModel(scheme="basic", description=description)
self.scheme_name = scheme_name or self.__class__.__name__
self.realm = realm
self.auto_error = auto_error

async def __call__( # type: ignore
def get_http_basic_authorization_credentials(
self, request: Request
) -> Optional[HTTPBasicCredentials]:
) -> Optional[Tuple[str, str]]:
authorization: str = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if self.realm:
Expand All @@ -91,9 +83,33 @@ async def __call__( # type: ignore
data = b64decode(param).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise invalid_user_credentials_exc
username, separator, password = data.partition(":")
first_credential, separator, second_credential = data.partition(":")
if not separator:
raise invalid_user_credentials_exc
return (first_credential, second_credential)


class HTTPBasic(HTTPBase):
def __init__(
self,
*,
scheme_name: Optional[str] = None,
realm: Optional[str] = None,
description: Optional[str] = None,
auto_error: bool = True,
):
self.model = HTTPBaseModel(scheme="basic", description=description)
self.scheme_name = scheme_name or self.__class__.__name__
self.realm = realm
self.auto_error = auto_error

async def __call__( # type: ignore
self, request: Request
) -> Optional[HTTPBasicCredentials]:
credentials = self.get_http_basic_authorization_credentials(request)
if not credentials:
return None
username, password = credentials
return HTTPBasicCredentials(username=username, password=password)


Expand Down Expand Up @@ -133,6 +149,30 @@ async def __call__(
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)


class HTTPBasicClientCredentials(HTTPBase):
def __init__(
self,
*,
scheme_name: Optional[str] = None,
realm: Optional[str] = None,
description: Optional[str] = None,
auto_error: bool = True,
):
self.model = HTTPBaseModel(scheme="basic", description=description)
self.scheme_name = scheme_name or self.__class__.__name__
self.realm = realm
self.auto_error = auto_error

async def __call__( # type:ignore
self, request: Request
) -> Optional[HTTPClientCredentials]:
credentials = self.get_http_basic_authorization_credentials(request)
if not credentials:
return None
client_id, client_secret = credentials
return HTTPClientCredentials(client_id=client_id, client_secret=client_secret)


class HTTPDigest(HTTPBase):
def __init__(
self,
Expand Down
84 changes: 84 additions & 0 deletions fastapi/security/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from fastapi.exceptions import HTTPException
Expand All @@ -10,6 +11,10 @@
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN


class SchemeName(str, Enum):
OAUTH2_CLIENT_CREDENTIALS = "oAuth2ClientCredentials"


class OAuth2PasswordRequestForm:
"""
This is a dependency class, use it like:
Expand Down Expand Up @@ -112,6 +117,52 @@ def __init__(
)


class OAuth2ClientCredentialsRequestForm:
"""
This is a dependency class, use it like:
token_scheme = HTTPBasicClientCredentials(
auto_error=False, scheme_name="oAuth2ClientCredentials"
)

@router.post("/token")
def create_access_token(
form: OAuth2ClientCredentialsRequestForm = Depends(),
basic_credentials: Optional[HTTPClientCredentials] = Depends(token_scheme),
):
if form.client_id and form.client_secret:
client_id = form.client_id
client_secret = form.client_secret
elif basic_credentials:
client_id = basic_credentials.client_id
client_secret = basic_credentials.client_secret
else:
HTTPException(status_code=400, detail="Client credentials not provided")
pass

This will allow the client to send its credentials either via headers or body with the request for a token.

grant_type: the OAuth2 spec says it is required and MUST be the fixed string "client_credentials".
scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
"items:read items:write users:read profile openid"
client_id: optional string. OAuth2 recommends sending the client_id and client_secret
using HTTP Basic auth, as: client_id:client_secret
client_secret: optional string. OAuth2 recommends sending the client_id and client_secret
using HTTP Basic auth, as: client_id:client_secret
"""

def __init__(
self,
grant_type: str = Form(None, regex="client_credentials"),
scope: str = Form(""),
client_id: Optional[str] = Form(None),
client_secret: Optional[str] = Form(None),
):
self.grant_type = grant_type
self.scopes = scope.split()
self.client_id = client_id
self.client_secret = client_secret


class OAuth2(SecurityBase):
def __init__(
self,
Expand Down Expand Up @@ -214,6 +265,39 @@ async def __call__(self, request: Request) -> Optional[str]:
return param


class OAuth2ClientCredentials(OAuth2):
def __init__(
self,
tokenUrl: str,
scopes: Optional[Dict[str, str]] = None,
auto_error: bool = True,
):
if not scopes:
scopes = {}
flows = OAuthFlowsModel(
clientCredentials={"tokenUrl": tokenUrl, "scopes": scopes}
)
super().__init__(
flows=flows,
scheme_name=SchemeName.OAUTH2_CLIENT_CREDENTIALS,
auto_error=auto_error,
)

async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
else:
return None
return param


class SecurityScopes:
def __init__(self, scopes: Optional[List[str]] = None):
self.scopes = scopes or []
Expand Down
95 changes: 95 additions & 0 deletions tests/test_security_http_client_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from base64 import b64encode

from fastapi import FastAPI, Security
from fastapi.security import HTTPBasicClientCredentials, HTTPClientCredentials
from fastapi.testclient import TestClient

app = FastAPI()

security = HTTPBasicClientCredentials(auto_error=True, scheme_name="basic")


class ClientCredentialsAuthMock:
def __call__(self, r):
auth_mock = b64encode(b"max:powersecret").decode("ascii")
r.headers["Authorization"] = f"Basic {auth_mock}"
return r


@app.get("/users/me")
def read_current_user(credentials: HTTPClientCredentials = Security(security)):
if credentials:
return {
"client_id": credentials.client_id,
"client_secret": credentials.client_secret,
}


client = TestClient(app)

openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"summary": "Read Current User",
"operationId": "read_current_user_users_me_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
}
},
"security": [{"basic": []}],
}
}
},
"components": {"securitySchemes": {"basic": {"type": "http", "scheme": "basic"}}},
}


def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema


def test_security_http_basic():
auth = ClientCredentialsAuthMock()
response = client.get("/users/me", auth=auth)
assert response.status_code == 200, response.text
assert response.json() == {"client_id": "max", "client_secret": "powersecret"}


def test_security_http_basic_no_credentials():
response = client.get("/users/me")
assert response.json() == {"detail": "Not authenticated"}
assert response.status_code == 401, response.text
assert response.headers["WWW-Authenticate"] == "Basic"


def test_security_http_basic_invalid_credentials():
response = client.get(
"/users/me", headers={"Authorization": "Basic notabase64token"}
)
assert response.status_code == 401, response.text
assert response.headers["WWW-Authenticate"] == "Basic"
assert response.json() == {"detail": "Invalid authentication credentials"}


def test_security_http_basic_non_basic_credentials():
payload = b64encode(b"johnsecret").decode("ascii")
auth_header = f"Basic {payload}"
response = client.get("/users/me", headers={"Authorization": auth_header})
assert response.status_code == 401, response.text
assert response.headers["WWW-Authenticate"] == "Basic"
assert response.json() == {"detail": "Invalid authentication credentials"}


def test_no_return_none():
security.auto_error = False
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() is None
security.auto_error = True