Skip to content

Commit

Permalink
feat: add mfa
Browse files Browse the repository at this point in the history
  • Loading branch information
leynier committed Oct 21, 2022
1 parent c3ff22a commit 4c2b443
Show file tree
Hide file tree
Showing 11 changed files with 843 additions and 73 deletions.
28 changes: 28 additions & 0 deletions gotrue/_async/gotrue_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
from ..http_clients import AsyncClient
from ..types import (
AdminUserAttributes,
AuthMFAAdminDeleteFactorParams,
AuthMFAAdminDeleteFactorResponse,
AuthMFAAdminListFactorsParams,
AuthMFAAdminListFactorsResponse,
GenerateLinkParams,
GenerateLinkResponse,
Options,
User,
UserResponse,
)
from .gotrue_admin_mfa_api import AsyncGoTrueAdminMFAAPI
from .gotrue_base_api import AsyncGoTrueBaseAPI


Expand All @@ -29,6 +34,9 @@ def __init__(
headers=headers,
http_client=http_client,
)
self.mfa = AsyncGoTrueAdminMFAAPI()
self.mfa.list_factors = self._list_factors
self.mfa.delete_factor = self._delete_factor

async def sign_out(self, jwt: str) -> None:
"""
Expand Down Expand Up @@ -142,3 +150,23 @@ async def delete_user(self, id: str) -> UserResponse:
f"admin/users/{id}",
xform=parse_user_response,
)

async def _list_factors(
self,
params: AuthMFAAdminListFactorsParams,
) -> AuthMFAAdminListFactorsResponse:
return await self._request(
"GET",
f"admin/users/{params.get('user_id')}/factors",
xform=AuthMFAAdminListFactorsResponse.parse_obj,
)

async def _delete_factor(
self,
params: AuthMFAAdminDeleteFactorParams,
) -> AuthMFAAdminDeleteFactorResponse:
return await self._request(
"DELETE",
f"admin/users/{params.get('user_id')}/factors/{params.get('factor_id')}",
xform=AuthMFAAdminDeleteFactorResponse.parse_obj,
)
32 changes: 32 additions & 0 deletions gotrue/_async/gotrue_admin_mfa_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from ..types import (
AuthMFAAdminDeleteFactorParams,
AuthMFAAdminDeleteFactorResponse,
AuthMFAAdminListFactorsParams,
AuthMFAAdminListFactorsResponse,
)


class AsyncGoTrueAdminMFAAPI:
"""
Contains the full multi-factor authentication administration API.
"""

async def list_factors(
self,
params: AuthMFAAdminListFactorsParams,
) -> AuthMFAAdminListFactorsResponse:
"""
Lists all factors attached to a user.
"""
raise NotImplementedError()

async def delete_factor(
self,
params: AuthMFAAdminDeleteFactorParams,
) -> AuthMFAAdminDeleteFactorResponse:
"""
Deletes a factor on a user. This will log the user out of all active
sessions (if the deleted factor was verified). There's no need to delete
unverified factors.
"""
raise NotImplementedError()
23 changes: 20 additions & 3 deletions gotrue/_async/gotrue_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any, Callable, Dict, Literal, TypeVar, Union, overload

from httpx import Response
from pydantic import BaseModel
from typing_extensions import Self

Expand Down Expand Up @@ -43,7 +44,7 @@ async def _request(
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
no_resolve_json: Union[bool, None] = None,
no_resolve_json: Literal[False] = False,
xform: Callable[[Any], T],
) -> T:
...
Expand All @@ -59,7 +60,23 @@ async def _request(
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
no_resolve_json: Union[bool, None] = None,
no_resolve_json: Literal[True],
xform: Callable[[Response], T],
) -> T:
...

@overload
async def _request(
self,
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
path: str,
*,
jwt: Union[str, None] = None,
redirect_to: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
no_resolve_json: bool = False,
) -> None:
...

Expand All @@ -73,7 +90,7 @@ async def _request(
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
no_resolve_json: Union[bool, None] = None,
no_resolve_json: bool = False,
xform: Union[Callable[[Any], T], None] = None,
) -> Union[T, None]:
url = f"{self._url}/{path}"
Expand Down
130 changes: 120 additions & 10 deletions gotrue/_async/gotrue_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from base64 import b64decode
from json import loads
from time import time
from typing import Callable, Dict, List, Tuple, Union
Expand All @@ -21,12 +20,24 @@
AuthRetryableError,
AuthSessionMissingError,
)
from ..helpers import parse_auth_response, parse_user_response
from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response
from ..http_clients import AsyncClient
from ..timer import Timer
from ..types import (
AuthChangeEvent,
AuthenticatorAssuranceLevels,
AuthMFAChallengeResponse,
AuthMFAEnrollResponse,
AuthMFAGetAuthenticatorAssuranceLevelResponse,
AuthMFAListFactorsResponse,
AuthMFAUnenrollResponse,
AuthMFAVerifyResponse,
AuthResponse,
DecodedJWTDict,
MFAChallengeParams,
MFAEnrollParams,
MFAUnenrollParams,
MFAVerifyParams,
OAuthResponse,
Options,
Provider,
Expand All @@ -42,6 +53,7 @@
)
from .gotrue_admin_api import AsyncGoTrueAdminAPI
from .gotrue_base_api import AsyncGoTrueBaseAPI
from .gotrue_mfa_api import AsyncGoTrueMFAAPI
from .storage import AsyncMemoryStorage, AsyncSupportedStorage


Expand Down Expand Up @@ -77,6 +89,15 @@ def __init__(
headers=self._headers,
http_client=self._http_client,
)
self.mfa = AsyncGoTrueMFAAPI()
self.mfa.challenge = self._challenge
self.mfa.enroll = self._enroll
self.mfa.get_authenticator_assurance_level = (
self._get_authenticator_assurance_level
)
self.mfa.list_factors = self._list_factors
self.mfa.unenroll = self._unenroll
self.mfa.verify = self._verify

# Initializations

Expand Down Expand Up @@ -389,10 +410,10 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon
has_expired = True
session: Union[Session, None] = None
if access_token and access_token.split(".")[1]:
json_raw = b64decode(access_token.split(".")[1] + "===").decode("utf-8")
payload = loads(json_raw)
if payload.get("exp"):
expires_at = int(payload.get("exp"))
payload = self._decode_jwt(access_token)
exp = payload.get("exp")
if exp:
expires_at = int(exp)
has_expired = expires_at <= time_now
if has_expired:
if not refresh_token:
Expand Down Expand Up @@ -474,6 +495,94 @@ async def reset_password_email(
redirect_to=options.get("redirect_to"),
)

# MFA methods

async def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse:
session = await self.get_session()
if not session:
raise AuthSessionMissingError()
response = await self._request(
"POST",
"factors",
body=params,
jwt=session.access_token,
xform=AuthMFAEnrollResponse.parse_obj,
)
if response.totp.qr_code:
response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}"
return response

async def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse:
session = await self.get_session()
if not session:
raise AuthSessionMissingError()
return await self._request(
"POST",
f"factors/{params.get('factor_id')}/challenge",
jwt=session.access_token,
xform=AuthMFAChallengeResponse.parse_obj,
)

async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse:
session = await self.get_session()
if not session:
raise AuthSessionMissingError()
response = await self._request(
"POST",
f"factors/{params.get('factor_id')}/verify",
body=params,
jwt=session.access_token,
xform=AuthMFAVerifyResponse.parse_obj,
)
session = Session.parse_obj(response.dict())
await self._save_session(session)
self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session)
return response

async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
session = await self.get_session()
if not session:
raise AuthSessionMissingError()
return await self._request(
"DELETE",
f"factors/{params.get('factor_id')}",
jwt=session.access_token,
xform=AuthMFAUnenrollResponse.parse_obj,
)

async def _list_factors(self) -> AuthMFAListFactorsResponse:
response = await self.get_user()
all = response.user.factors or []
totp = [f for f in all if f.factor_type == "totp" and f.status == "verified"]
return AuthMFAListFactorsResponse(all=all, totp=totp)

async def _get_authenticator_assurance_level(
self,
) -> AuthMFAGetAuthenticatorAssuranceLevelResponse:
session = await self.get_session()
if not session:
return AuthMFAGetAuthenticatorAssuranceLevelResponse(
current_level=None,
next_level=None,
current_authentication_methods=[],
)
payload = self._decode_jwt(session.access_token)
current_level: Union[AuthenticatorAssuranceLevels, None] = None
if payload.get("aal"):
current_level = payload.get("aal")
next_level = current_level
verified_factors = [
f for f in session.user.factors or [] if f.status == "verified"
]
if verified_factors:
next_level = "aal2"
current_authentication_methods = payload.get("amr") or []
return AuthMFAGetAuthenticatorAssuranceLevelResponse(
current_level=current_level,
next_level=next_level,
current_authentication_methods=current_authentication_methods,
)

# Private methods

async def _remove_session(self) -> None:
Expand Down Expand Up @@ -685,7 +794,8 @@ def _get_url_for_provider(
query = urlencode(params)
return f"{self._url}/authorize?{query}"


async def test():
client = AsyncGoTrueClient()
await client.initialize()
def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
"""
Decodes a JWT (without performing any validation).
"""
return decode_jwt_payload(jwt)
82 changes: 82 additions & 0 deletions gotrue/_async/gotrue_mfa_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from ..types import (
AuthMFAChallengeResponse,
AuthMFAEnrollResponse,
AuthMFAGetAuthenticatorAssuranceLevelResponse,
AuthMFAListFactorsResponse,
AuthMFAUnenrollResponse,
AuthMFAVerifyResponse,
MFAChallengeParams,
MFAEnrollParams,
MFAUnenrollParams,
MFAVerifyParams,
)


class AsyncGoTrueMFAAPI:
"""
Contains the full multi-factor authentication API.
"""

async def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse:
"""
Starts the enrollment process for a new Multi-Factor Authentication
factor. This method creates a new factor in the 'unverified' state.
Present the QR code or secret to the user and ask them to add it to their
authenticator app. Ask the user to provide you with an authenticator code
from their app and verify it by calling challenge and then verify.
The first successful verification of an unverified factor activates the
factor. All other sessions are logged out and the current one gets an
`aal2` authenticator level.
"""
raise NotImplementedError()

async def challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse:
"""
Prepares a challenge used to verify that a user has access to a MFA
factor. Provide the challenge ID and verification code by calling `verify`.
"""
raise NotImplementedError()

async def verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse:
"""
Verifies a verification code against a challenge. The verification code is
provided by the user by entering a code seen in their authenticator app.
"""
raise NotImplementedError()

async def unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
"""
Unenroll removes a MFA factor. Unverified factors can safely be ignored
and it's not necessary to unenroll them. Unenrolling a verified MFA factor
cannot be done from a session with an `aal1` authenticator level.
"""
raise NotImplementedError()

async def list_factors(self) -> AuthMFAListFactorsResponse:
"""
Returns the list of MFA factors enabled for this user. For most use cases
you should consider using `get_authenticator_assurance_level`.
This uses a cached version of the factors and avoids incurring a network call.
If you need to update this list, call `get_user` first.
"""
raise NotImplementedError()

async def get_authenticator_assurance_level(
self,
) -> AuthMFAGetAuthenticatorAssuranceLevelResponse:
"""
Returns the Authenticator Assurance Level (AAL) for the active session.
- `aal1` (or `null`) means that the user's identity has been verified only
with a conventional login (email+password, OTP, magic link, social login,
etc.).
- `aal2` means that the user's identity has been verified both with a
conventional login and at least one MFA factor.
Although this method returns a promise, it's fairly quick (microseconds)
and rarely uses the network. You can use this to check whether the current
user needs to be shown a screen to verify their MFA factors.
"""
raise NotImplementedError()

0 comments on commit 4c2b443

Please sign in to comment.