Skip to content

Commit

Permalink
fix: add pkce
Browse files Browse the repository at this point in the history
  • Loading branch information
joel@joellee.org committed Sep 25, 2023
1 parent 152ed06 commit 02f7c05
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 0 deletions.
32 changes: 32 additions & 0 deletions gotrue/_async/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..types import (
AuthChangeEvent,
AuthenticatorAssuranceLevels,
AuthFlowType,
AuthMFAChallengeResponse,
AuthMFAEnrollResponse,
AuthMFAGetAuthenticatorAssuranceLevelResponse,
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
persist_session: bool = True,
storage: Union[AsyncSupportedStorage, None] = None,
http_client: Union[AsyncClient, None] = None,
flow_type: AuthFlowType = "implicit",
) -> None:
AsyncGoTrueBaseAPI.__init__(
self,
Expand All @@ -92,6 +94,7 @@ def __init__(
self._refresh_token_timer: Union[Timer, None] = None
self._network_retries = 0
self._state_change_emitters: Dict[str, Subscription] = {}
self._flow_type = flow_type

self.admin = AsyncGoTrueAdminAPI(
url=self._url,
Expand Down Expand Up @@ -252,6 +255,7 @@ async def sign_in_with_oauth(
Log in an existing user via a third-party provider.
"""
await self._remove_session()

provider = credentials.get("provider")
options = credentials.get("options", {})
redirect_to = options.get("redirect_to")
Expand Down Expand Up @@ -839,6 +843,16 @@ def _get_url_for_provider(
provider: Provider,
params: Dict[str, str],
) -> str:
if self.flow_type == "pkce":
code_verifier = generate_pkce_verifier()
code_challenge = generate_pkce_challenge(code_verifier)
self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier)
code_challenge_method = (
"plain" if code_verifier == code_challenge else "s256"
)
params["code_challenge"] = code_challenge
params["code_challenge_method"] = code_challenge_method

params["provider"] = provider
query = urlencode(params)
return f"{self._url}/authorize?{query}"
Expand All @@ -848,3 +862,21 @@ def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
Decodes a JWT (without performing any validation).
"""
return decode_jwt_payload(jwt)

def exchange_code_for_session(auth_code: str):
code_verifier = get_item(f"{self._storage_key}-code-verifier")
response = self._request(
"POST",
"token?grant_type=pkce",
body={
"auth_code": email,
"code_verifier": code_verifier,
},
redirect_to=redirect_to,
xform=parse_auth_response,
)
self._storage.remove_item(f"{self._storage_key}-code-verifier")
if response.session:
self._save_session(response.session)
self._notify_all_subscribers("SIGNED_IN", response.session)
return response
32 changes: 32 additions & 0 deletions gotrue/_sync/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..types import (
AuthChangeEvent,
AuthenticatorAssuranceLevels,
AuthFlowType,
AuthMFAChallengeResponse,
AuthMFAEnrollResponse,
AuthMFAGetAuthenticatorAssuranceLevelResponse,
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
persist_session: bool = True,
storage: Union[SyncSupportedStorage, None] = None,
http_client: Union[SyncClient, None] = None,
flow_type: AuthFlowType = "implicit",
) -> None:
SyncGoTrueBaseAPI.__init__(
self,
Expand All @@ -92,6 +94,7 @@ def __init__(
self._refresh_token_timer: Union[Timer, None] = None
self._network_retries = 0
self._state_change_emitters: Dict[str, Subscription] = {}
self._flow_type = flow_type

self.admin = SyncGoTrueAdminAPI(
url=self._url,
Expand Down Expand Up @@ -252,6 +255,7 @@ def sign_in_with_oauth(
Log in an existing user via a third-party provider.
"""
self._remove_session()

provider = credentials.get("provider")
options = credentials.get("options", {})
redirect_to = options.get("redirect_to")
Expand Down Expand Up @@ -837,6 +841,16 @@ def _get_url_for_provider(
provider: Provider,
params: Dict[str, str],
) -> str:
if self.flow_type == "pkce":
code_verifier = generate_pkce_verifier()
code_challenge = generate_pkce_challenge(code_verifier)
self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier)
code_challenge_method = (
"plain" if code_verifier == code_challenge else "s256"
)
params["code_challenge"] = code_challenge
params["code_challenge_method"] = code_challenge_method

params["provider"] = provider
query = urlencode(params)
return f"{self._url}/authorize?{query}"
Expand All @@ -846,3 +860,21 @@ def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
Decodes a JWT (without performing any validation).
"""
return decode_jwt_payload(jwt)

def exchange_code_for_session(auth_code: str):
code_verifier = get_item(f"{self._storage_key}-code-verifier")
response = self._request(
"POST",
"token?grant_type=pkce",
body={
"auth_code": email,
"code_verifier": code_verifier,
},
redirect_to=redirect_to,
xform=parse_auth_response,
)
self._storage.remove_item(f"{self._storage_key}-code-verifier")
if response.session:
self._save_session(response.session)
self._notify_all_subscribers("SIGNED_IN", response.session)
return response
29 changes: 29 additions & 0 deletions gotrue/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import base64
import hashlib
import secrets
from base64 import b64decode
from json import loads
from typing import Any, Dict, Type, TypeVar, Union, cast
Expand Down Expand Up @@ -124,3 +127,29 @@ def decode_jwt_payload(token: str) -> Any:
# binascii.Error: Incorrect padding
base64UrlWithPadding = base64Url + "=" * (-len(base64Url) % 4)
return loads(b64decode(base64UrlWithPadding).decode("utf-8"))


def generate_pkce_verifier(length=64):
"""Generate a random PKCE verifier of the specified length."""
if length < 43 or length > 128:
raise ValueError("PKCE verifier length must be between 43 and 128 characters")

# Define characters that can be used in the PKCE verifier
charset = string.ascii_letters + string.digits + "-._~"

# Generate a random PKCE verifier using the secrets module
verifier = "".join(secrets.choice(charset) for _ in range(length))

return verifier


def generate_pkce_challenge(code_verifier):
"""Generate a code challenge from a PKCE verifier."""
# Hash the verifier using SHA-256
verifier_bytes = verifier.encode("utf-8")
sha256_hash = hashlib.sha256(verifier_bytes).digest()

# Encode the hash as URL-safe base64
base64_encoded = base64.urlsafe_b64encode(sha256_hash).rstrip(b"=").decode("utf-8")

return base64_encoded
2 changes: 2 additions & 0 deletions gotrue/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@

AuthChangeEventMFA = Literal["MFA_CHALLENGE_VERIFIED"]

AuthFlowType = Literal["pkce", "implicit"]

AuthChangeEvent = Literal[
"PASSWORD_RECOVERY",
"SIGNED_IN",
Expand Down

0 comments on commit 02f7c05

Please sign in to comment.