Skip to content
This repository was archived by the owner on Sep 8, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion supabase_auth/_async/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
"DELETE",
f"factors/{params.get('factor_id')}",
jwt=session.access_token,
xform=partial(AuthMFAUnenrollResponse, model_validate),
xform=partial(model_validate, AuthMFAUnenrollResponse),
)

async def _list_factors(self) -> AuthMFAListFactorsResponse:
Expand Down
2 changes: 1 addition & 1 deletion supabase_auth/_sync/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
"DELETE",
f"factors/{params.get('factor_id')}",
jwt=session.access_token,
xform=partial(AuthMFAUnenrollResponse, model_validate),
xform=partial(model_validate, AuthMFAUnenrollResponse),
)

def _list_factors(self) -> AuthMFAListFactorsResponse:
Expand Down
22 changes: 18 additions & 4 deletions supabase_auth/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,13 +520,21 @@ class GenerateEmailChangeLinkParams(TypedDict):
]


class MFAEnrollParams(TypedDict):
factor_type: Literal["totp", "phone"]
class MFAEnrollTOTPParams(TypedDict):
factor_type: Literal["totp"]
issuer: NotRequired[str]
friendly_name: NotRequired[str]


class MFAEnrollPhoneParams(TypedDict):
factor_type: Literal["phone"]
friendly_name: NotRequired[str]
phone: str


MFAEnrollParams = Union[MFAEnrollTOTPParams, MFAEnrollPhoneParams]


class MFAUnenrollParams(TypedDict):
factor_id: str
"""
Expand Down Expand Up @@ -644,11 +652,17 @@ class AuthMFAEnrollResponse(BaseModel):
"""
Friendly name of the factor, useful for distinguishing between factors
"""
phone: str
phone: Optional[str] = None
"""
Phone number of the MFA factor in E.164 format. Used to send messages
"""

@model_validator_v1_v2_compat
def validate_phone_required_for_phone_type(cls, values: dict) -> dict:
if values.get("type") == "phone" and not values.get("phone"):
raise ValueError("phone is required when type is 'phone'")
return values

Comment on lines +660 to +665
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@silentworks is this the correct way of doing it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's correct.


class AuthMFAUnenrollResponse(BaseModel):
id: str
Expand All @@ -666,7 +680,7 @@ class AuthMFAChallengeResponse(BaseModel):
"""
Timestamp in UNIX seconds when this challenge will no longer be usable.
"""
factor_type: Literal["totp", "phone"]
factor_type: Optional[Literal["totp", "phone"]] = None
"""
Factor Type which generated the challenge
"""
Expand Down
101 changes: 100 additions & 1 deletion tests/_async/test_gotrue.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError
from supabase_auth.helpers import decode_jwt

from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session
from .clients import (
GOTRUE_JWT_SECRET,
auth_client,
auth_client_with_asymmetric_session,
auth_client_with_session,
)
from .utils import mock_user_credentials


Expand Down Expand Up @@ -189,3 +194,97 @@ async def test_set_session_with_invalid_token():
# Try to set the session with invalid tokens
with pytest.raises(AuthInvalidJwtError):
await client.set_session("invalid.token.here", "invalid_refresh_token")


async def test_mfa_enroll():
client = auth_client_with_session()

credentials = mock_user_credentials()

# First sign up to get a valid session
await client.sign_up(
{
"email": credentials.get("email"),
"password": credentials.get("password"),
}
)

# Test MFA enrollment
enroll_response = await client.mfa.enroll(
{"issuer": "test-issuer", "factor_type": "totp", "friendly_name": "test-factor"}
)

assert enroll_response.id is not None
assert enroll_response.type == "totp"
assert enroll_response.friendly_name == "test-factor"
assert enroll_response.totp.qr_code is not None


async def test_mfa_challenge():
client = auth_client()
credentials = mock_user_credentials()

# First sign up to get a valid session
signup_response = await client.sign_up(
{
"email": credentials.get("email"),
"password": credentials.get("password"),
}
)
assert signup_response.session is not None

# Enroll a factor first
enroll_response = await client.mfa.enroll(
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
)

# Test MFA challenge
challenge_response = await client.mfa.challenge({"factor_id": enroll_response.id})
assert challenge_response.id is not None
assert challenge_response.expires_at is not None


async def test_mfa_unenroll():
client = auth_client()
credentials = mock_user_credentials()

# First sign up to get a valid session
signup_response = await client.sign_up(
{
"email": credentials.get("email"),
"password": credentials.get("password"),
}
)
assert signup_response.session is not None

# Enroll a factor first
enroll_response = await client.mfa.enroll(
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
)

# Test MFA unenroll
unenroll_response = await client.mfa.unenroll({"factor_id": enroll_response.id})
assert unenroll_response.id == enroll_response.id


async def test_mfa_list_factors():
client = auth_client()
credentials = mock_user_credentials()

# First sign up to get a valid session
signup_response = await client.sign_up(
{
"email": credentials.get("email"),
"password": credentials.get("password"),
}
)
assert signup_response.session is not None

# Enroll a factor first
await client.mfa.enroll(
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
)

# Test MFA list factors
list_response = await client.mfa.list_factors()
assert len(list_response.all) == 1
101 changes: 100 additions & 1 deletion tests/_sync/test_gotrue.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from supabase_auth.errors import AuthInvalidJwtError, AuthSessionMissingError
from supabase_auth.helpers import decode_jwt

from .clients import GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session
from .clients import (
GOTRUE_JWT_SECRET,
auth_client,
auth_client_with_asymmetric_session,
auth_client_with_session,
)
from .utils import mock_user_credentials


Expand Down Expand Up @@ -189,3 +194,97 @@ def test_set_session_with_invalid_token():
# Try to set the session with invalid tokens
with pytest.raises(AuthInvalidJwtError):
client.set_session("invalid.token.here", "invalid_refresh_token")


def test_mfa_enroll():
client = auth_client_with_session()

credentials = mock_user_credentials()

# First sign up to get a valid session
client.sign_up(
{
"email": credentials.get("email"),
"password": credentials.get("password"),
}
)

# Test MFA enrollment
enroll_response = client.mfa.enroll(
{"issuer": "test-issuer", "factor_type": "totp", "friendly_name": "test-factor"}
)

assert enroll_response.id is not None
assert enroll_response.type == "totp"
assert enroll_response.friendly_name == "test-factor"
assert enroll_response.totp.qr_code is not None


def test_mfa_challenge():
client = auth_client()
credentials = mock_user_credentials()

# First sign up to get a valid session
signup_response = client.sign_up(
{
"email": credentials.get("email"),
"password": credentials.get("password"),
}
)
assert signup_response.session is not None

# Enroll a factor first
enroll_response = client.mfa.enroll(
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
)

# Test MFA challenge
challenge_response = client.mfa.challenge({"factor_id": enroll_response.id})
assert challenge_response.id is not None
assert challenge_response.expires_at is not None


def test_mfa_unenroll():
client = auth_client()
credentials = mock_user_credentials()

# First sign up to get a valid session
signup_response = client.sign_up(
{
"email": credentials.get("email"),
"password": credentials.get("password"),
}
)
assert signup_response.session is not None

# Enroll a factor first
enroll_response = client.mfa.enroll(
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
)

# Test MFA unenroll
unenroll_response = client.mfa.unenroll({"factor_id": enroll_response.id})
assert unenroll_response.id == enroll_response.id


def test_mfa_list_factors():
client = auth_client()
credentials = mock_user_credentials()

# First sign up to get a valid session
signup_response = client.sign_up(
{
"email": credentials.get("email"),
"password": credentials.get("password"),
}
)
assert signup_response.session is not None

# Enroll a factor first
client.mfa.enroll(
{"factor_type": "totp", "issuer": "test-issuer", "friendly_name": "test-factor"}
)

# Test MFA list factors
list_response = client.mfa.list_factors()
assert len(list_response.all) == 1