diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 02befb90..11e79ce9 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -878,7 +878,7 @@ async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: "DELETE", f"factors/{params.get('factor_id')}", jwt=session.access_token, - xform=partial(AuthMFAUnenrollResponse, model_validate), + xform=partial(model_validate, AuthMFAUnenrollResponse), ) async def _list_factors(self) -> AuthMFAListFactorsResponse: diff --git a/supabase_auth/_sync/gotrue_client.py b/supabase_auth/_sync/gotrue_client.py index 169ebd0e..d45d6c92 100644 --- a/supabase_auth/_sync/gotrue_client.py +++ b/supabase_auth/_sync/gotrue_client.py @@ -874,7 +874,7 @@ def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: "DELETE", f"factors/{params.get('factor_id')}", jwt=session.access_token, - xform=partial(AuthMFAUnenrollResponse, model_validate), + xform=partial(model_validate, AuthMFAUnenrollResponse), ) def _list_factors(self) -> AuthMFAListFactorsResponse: diff --git a/supabase_auth/types.py b/supabase_auth/types.py index 991a27b4..709e17c3 100644 --- a/supabase_auth/types.py +++ b/supabase_auth/types.py @@ -520,13 +520,21 @@ class GenerateEmailChangeLinkParams(TypedDict): ] -class MFAEnrollParams(TypedDict): - factor_type: Literal["totp", "phone"] +class MFAEnrollTOTPParams(TypedDict): + factor_type: Literal["totp"] issuer: NotRequired[str] friendly_name: NotRequired[str] + + +class MFAEnrollPhoneParams(TypedDict): + factor_type: Literal["phone"] + friendly_name: NotRequired[str] phone: str +MFAEnrollParams = Union[MFAEnrollTOTPParams, MFAEnrollPhoneParams] + + class MFAUnenrollParams(TypedDict): factor_id: str """ @@ -644,11 +652,17 @@ class AuthMFAEnrollResponse(BaseModel): """ Friendly name of the factor, useful for distinguishing between factors """ - phone: str + phone: Optional[str] = None """ Phone number of the MFA factor in E.164 format. Used to send messages """ + @model_validator_v1_v2_compat + def validate_phone_required_for_phone_type(cls, values: dict) -> dict: + if values.get("type") == "phone" and not values.get("phone"): + raise ValueError("phone is required when type is 'phone'") + return values + class AuthMFAUnenrollResponse(BaseModel): id: str @@ -666,7 +680,7 @@ class AuthMFAChallengeResponse(BaseModel): """ Timestamp in UNIX seconds when this challenge will no longer be usable. """ - factor_type: Literal["totp", "phone"] + factor_type: Optional[Literal["totp", "phone"]] = None """ Factor Type which generated the challenge """ diff --git a/tests/_async/test_gotrue.py b/tests/_async/test_gotrue.py index 2a14a562..a7990921 100644 --- a/tests/_async/test_gotrue.py +++ b/tests/_async/test_gotrue.py @@ -7,7 +7,12 @@ from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError from supabase_auth.helpers import decode_jwt -from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session +from .clients import ( + GOTRUE_JWT_SECRET, + auth_client, + auth_client_with_asymmetric_session, + auth_client_with_session, +) from .utils import mock_user_credentials @@ -189,3 +194,97 @@ async def test_set_session_with_invalid_token(): # Try to set the session with invalid tokens with pytest.raises(AuthInvalidJwtError): await client.set_session("invalid.token.here", "invalid_refresh_token") + + +async def test_mfa_enroll(): + client = auth_client_with_session() + + credentials = mock_user_credentials() + + # First sign up to get a valid session + await client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + + # Test MFA enrollment + enroll_response = await client.mfa.enroll( + {"issuer": "test-issuer", "factor_type": "totp", "friendly_name": "test-factor"} + ) + + assert enroll_response.id is not None + assert enroll_response.type == "totp" + assert enroll_response.friendly_name == "test-factor" + assert enroll_response.totp.qr_code is not None + + +async def test_mfa_challenge(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get a valid session + signup_response = await client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Enroll a factor first + enroll_response = await client.mfa.enroll( + {"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"} + ) + + # Test MFA challenge + challenge_response = await client.mfa.challenge({"factor_id": enroll_response.id}) + assert challenge_response.id is not None + assert challenge_response.expires_at is not None + + +async def test_mfa_unenroll(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get a valid session + signup_response = await client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Enroll a factor first + enroll_response = await client.mfa.enroll( + {"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"} + ) + + # Test MFA unenroll + unenroll_response = await client.mfa.unenroll({"factor_id": enroll_response.id}) + assert unenroll_response.id == enroll_response.id + + +async def test_mfa_list_factors(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get a valid session + signup_response = await client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Enroll a factor first + await client.mfa.enroll( + {"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"} + ) + + # Test MFA list factors + list_response = await client.mfa.list_factors() + assert len(list_response.all) == 1 diff --git a/tests/_sync/test_gotrue.py b/tests/_sync/test_gotrue.py index b68fc9e2..33fff720 100644 --- a/tests/_sync/test_gotrue.py +++ b/tests/_sync/test_gotrue.py @@ -7,7 +7,12 @@ from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError from supabase_auth.helpers import decode_jwt -from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session +from .clients import ( + GOTRUE_JWT_SECRET, + auth_client, + auth_client_with_asymmetric_session, + auth_client_with_session, +) from .utils import mock_user_credentials @@ -189,3 +194,97 @@ def test_set_session_with_invalid_token(): # Try to set the session with invalid tokens with pytest.raises(AuthInvalidJwtError): client.set_session("invalid.token.here", "invalid_refresh_token") + + +def test_mfa_enroll(): + client = auth_client_with_session() + + credentials = mock_user_credentials() + + # First sign up to get a valid session + client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + + # Test MFA enrollment + enroll_response = client.mfa.enroll( + {"issuer": "test-issuer", "factor_type": "totp", "friendly_name": "test-factor"} + ) + + assert enroll_response.id is not None + assert enroll_response.type == "totp" + assert enroll_response.friendly_name == "test-factor" + assert enroll_response.totp.qr_code is not None + + +def test_mfa_challenge(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get a valid session + signup_response = client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Enroll a factor first + enroll_response = client.mfa.enroll( + {"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"} + ) + + # Test MFA challenge + challenge_response = client.mfa.challenge({"factor_id": enroll_response.id}) + assert challenge_response.id is not None + assert challenge_response.expires_at is not None + + +def test_mfa_unenroll(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get a valid session + signup_response = client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Enroll a factor first + enroll_response = client.mfa.enroll( + {"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"} + ) + + # Test MFA unenroll + unenroll_response = client.mfa.unenroll({"factor_id": enroll_response.id}) + assert unenroll_response.id == enroll_response.id + + +def test_mfa_list_factors(): + client = auth_client() + credentials = mock_user_credentials() + + # First sign up to get a valid session + signup_response = client.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert signup_response.session is not None + + # Enroll a factor first + client.mfa.enroll( + {"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"} + ) + + # Test MFA list factors + list_response = client.mfa.list_factors() + assert len(list_response.all) == 1