Skip to content

Commit

Permalink
feat: clean up test setup in MFA tests (#1452)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Remove redundant test setup, particularly around `TestSecondarySession`.
- Changes `generateToken` into `generateAAL1Token`
- Abstracts adding the claim and reloading a session into a single
method. Was probably a mistake to add
[AMREntry](https://github.com/supabase/gotrue/blob/e9f38e76d8a7b93c5c2bb0de918a9b156155f018/internal/models/sessions.go#L38).
Should sync to use
[AMRClaim](https://github.com/supabase/gotrue/blob/e9f38e76d8a7b93c5c2bb0de918a9b156155f018/internal/models/amr.go)
at some point
- Add additional check that a provider fields must exist if there's an
SSO Claim on the last entry

---------

Co-authored-by: joel <joel@joels-MacBook-Pro.local>
  • Loading branch information
J0 and joel committed Mar 3, 2024
1 parent b260449 commit 7185af8
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 75 deletions.
104 changes: 44 additions & 60 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ import (

type MFATestSuite struct {
suite.Suite
API *API
Config *conf.GlobalConfiguration
TestDomain string
TestEmail string
TestOTPKey *otp.Key
TestPassword string
TestUser *models.User
TestSession *models.Session
API *API
Config *conf.GlobalConfiguration
TestDomain string
TestEmail string
TestOTPKey *otp.Key
TestPassword string
TestUser *models.User
TestSession *models.Session
TestSecondarySession *models.Session
}

func TestMFA(t *testing.T) {
Expand Down Expand Up @@ -71,6 +72,12 @@ func (ts *MFATestSuite) SetupTest() {
ts.TestUser = u
ts.TestSession = s

secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

ts.TestSecondarySession = secondarySession

// Generate TOTP related settings
testDomain := strings.Split(ts.TestEmail, "@")[1]
ts.TestDomain = testDomain
Expand All @@ -84,7 +91,7 @@ func (ts *MFATestSuite) SetupTest() {

}

func (ts *MFATestSuite) generateToken(user *models.User, sessionId *uuid.UUID) string {
func (ts *MFATestSuite) generateAAL1Token(user *models.User, sessionId *uuid.UUID) string {
token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, sessionId, models.TOTPSignIn)
require.NoError(ts.T(), err, "Error generating access token")
return token
Expand All @@ -94,8 +101,7 @@ func (ts *MFATestSuite) TestEnrollFactor() {
testFriendlyName := "bob"
alternativeFriendlyName := "john"

token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, ts.TestUser, nil, models.TOTPSignIn)
require.NoError(ts.T(), err)
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)

var cases = []struct {
desc string
Expand Down Expand Up @@ -135,15 +141,14 @@ func (ts *MFATestSuite) TestEnrollFactor() {
}
for _, c := range cases {
ts.Run(c.desc, func() {

w := performEnrollFlow(ts, token, c.friendlyName, c.factorType, c.issuer, c.expectedCode)

factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
ts.Require().NoError(err)
latestFactor := factors[len(factors)-1]
require.False(ts.T(), latestFactor.IsVerified())
addedFactor := factors[len(factors)-1]
require.False(ts.T(), addedFactor.IsVerified())
if c.friendlyName != "" && c.expectedCode == http.StatusOK {
require.Equal(ts.T(), c.friendlyName, latestFactor.FriendlyName)
require.Equal(ts.T(), c.friendlyName, addedFactor.FriendlyName)
}
if w.Code == http.StatusOK {
enrollResp := EnrollFactorResponse{}
Expand All @@ -159,13 +164,13 @@ func (ts *MFATestSuite) TestEnrollFactor() {

func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() {
friendlyName := "mary"
token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, ts.TestUser, nil, models.TOTPSignIn)
require.NoError(ts.T(), err)
_ = performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusOK)
response := performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusBadRequest)
issuer := "https://issuer.com"
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)
_ = performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusOK)
response := performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusBadRequest)

var errorResponse HTTPError
err = json.NewDecoder(response.Body).Decode(&errorResponse)
err := json.NewDecoder(response.Body).Decode(&errorResponse)
require.NoError(ts.T(), err)

// Convert the response body to a string and check for the expected error message
Expand All @@ -176,7 +181,7 @@ func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() {

func (ts *MFATestSuite) TestChallengeFactor() {
f := ts.TestUser.Factors[0]
token := ts.generateToken(ts.TestUser, nil)
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)
w := performChallengeFlow(ts, f.ID, token)
require.Equal(ts.T(), http.StatusOK, w.Code)
}
Expand Down Expand Up @@ -210,7 +215,6 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
for _, v := range cases {
ts.Run(v.desc, func() {
// Authenticate users and set secret

var buffer bytes.Buffer
r, err := models.GrantAuthenticatedUser(ts.API.db, ts.TestUser, models.GrantParams{})
require.NoError(ts.T(), err)
Expand All @@ -221,12 +225,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

// Create session to be invalidated
secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

token := ts.generateToken(ts.TestUser, r.SessionId)
token := ts.generateAAL1Token(ts.TestUser, r.SessionId)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", f.ID), &buffer)
Expand Down Expand Up @@ -259,7 +258,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {

if v.expectedHTTPCode == http.StatusOK {
// Ensure alternate session has been deleted
_, err = models.FindSessionByID(ts.API.db, secondarySession.ID, false)
_, err = models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false)
require.EqualError(ts.T(), err, models.SessionNotFoundError{}.Error())
}
if !v.validChallenge {
Expand All @@ -272,7 +271,6 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
}

func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {

cases := []struct {
desc string
isAAL2 bool
Expand All @@ -291,29 +289,20 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
}
for _, v := range cases {
ts.Run(v.desc, func() {
// Create User
var buffer bytes.Buffer
if v.isAAL2 {
ts.TestSession.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
}
var secondarySession *models.Session

// Create Session to test behaviour which downgrades other sessions
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err, "error finding factors")
f := factors[0]
secondarySession, err = models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

sharedSecret := ts.TestOTPKey.Secret()
f.Secret = sharedSecret
err = f.UpdateStatus(ts.API.db, models.FactorStateVerified)
require.NoError(ts.T(), err)
f.Secret = ts.TestOTPKey.Secret()
require.NoError(ts.T(), f.UpdateStatus(ts.API.db, models.FactorStateVerified))
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

var buffer bytes.Buffer

token := ts.generateToken(ts.TestUser, &ts.TestSession.ID)
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/factors/%s/", f.ID), &buffer)
Expand All @@ -324,7 +313,7 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
if v.expectedHTTPCode == http.StatusOK {
_, err = models.FindFactorByFactorID(ts.API.db, f.ID)
require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error())
session, _ := models.FindSessionByID(ts.API.db, secondarySession.ID, false)
session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false)
require.Equal(ts.T(), models.AAL1.String(), session.GetAAL())
require.Nil(ts.T(), session.FactorID)

Expand All @@ -335,19 +324,11 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
}

func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
var secondarySession *models.Session
f := ts.TestUser.Factors[0]
secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

sharedSecret := ts.TestOTPKey.Secret()
f.Secret = sharedSecret

var buffer bytes.Buffer
f := ts.TestUser.Factors[0]
f.Secret = ts.TestOTPKey.Secret()

token := ts.generateToken(ts.TestUser, &ts.TestSession.ID)
require.NoError(ts.T(), err)
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"factor_id": f.ID,
}))
Expand All @@ -357,21 +338,22 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
_, err = models.FindFactorByFactorID(ts.API.db, f.ID)

_, err := models.FindFactorByFactorID(ts.API.db, f.ID)
require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error())
session, _ := models.FindSessionByID(ts.API.db, secondarySession.ID, false)
session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false)
require.Equal(ts.T(), models.AAL1.String(), session.GetAAL())
require.Nil(ts.T(), session.FactorID)

}

// Integration Tests
func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {
ts.Config.Security.RefreshTokenRotationEnabled = true
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

ts.Config.Security.RefreshTokenRotationEnabled = true
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": accessTokenResp.RefreshToken,
Expand All @@ -395,11 +377,11 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {

// Performing MFA Verification followed by a sign in should return an AAL1 session and an AAL2 session
func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
ts.Config.Security.RefreshTokenRotationEnabled = true
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

ts.Config.Security.RefreshTokenRotationEnabled = true
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": ts.TestEmail,
Expand All @@ -415,15 +397,18 @@ func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
ctx, err := ts.API.parseJWTClaims(data.Token, req)
require.NoError(ts.T(), err)

ctx, err = ts.API.maybeLoadUserOrSession(ctx)
require.NoError(ts.T(), err)

require.Equal(ts.T(), models.AAL1.String(), getSession(ctx).GetAAL())
session, err := models.FindSessionByUserID(ts.API.db, accessTokenResp.User.ID)
require.NoError(ts.T(), err)
require.True(ts.T(), session.IsAAL2())
}

func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenResponse) {
ts.API.config.Mailer.Autoconfirm = true
var buffer bytes.Buffer

require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
Expand All @@ -434,7 +419,6 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes
// Setup request
req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer)
req.Header.Set("Content-Type", "application/json")
ts.API.config.Mailer.Autoconfirm = true
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
Expand Down
37 changes: 22 additions & 15 deletions internal/models/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,45 +53,52 @@ func (ts *SessionsTestSuite) TestFindBySessionIDWithForUpdate() {
require.Equal(ts.T(), session.ID, found.ID)
}

func (ts *SessionsTestSuite) AddClaimAndReloadSession(session *Session, claim AuthenticationMethod) *Session {
err := AddClaimToSession(ts.db, session.ID, claim)
require.NoError(ts.T(), err)
session, err = FindSessionByID(ts.db, session.ID, false)
require.NoError(ts.T(), err)
return session
}

func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
totalDistinctClaims := 2
totalDistinctClaims := 3
u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
session, err := NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.db.Create(session))

err = AddClaimToSession(ts.db, session.ID, PasswordGrant)
require.NoError(ts.T(), err)
session = ts.AddClaimAndReloadSession(session, PasswordGrant)

firstClaimAddedTime := time.Now()
err = AddClaimToSession(ts.db, session.ID, TOTPSignIn)
require.NoError(ts.T(), err)
session, err = FindSessionByID(ts.db, session.ID, false)
require.NoError(ts.T(), err)
session = ts.AddClaimAndReloadSession(session, TOTPSignIn)

aal, amr, err := session.CalculateAALAndAMR(u)
_, _, err = session.CalculateAALAndAMR(u)
require.NoError(ts.T(), err)
require.Equal(ts.T(), AAL2.String(), aal)
require.Equal(ts.T(), totalDistinctClaims, len(amr))

err = AddClaimToSession(ts.db, session.ID, TOTPSignIn)
require.NoError(ts.T(), err)
session = ts.AddClaimAndReloadSession(session, TOTPSignIn)

session, err = FindSessionByID(ts.db, session.ID, false)
require.NoError(ts.T(), err)
session = ts.AddClaimAndReloadSession(session, SSOSAML)

aal, amr, err = session.CalculateAALAndAMR(u)
aal, amr, err := session.CalculateAALAndAMR(u)
require.NoError(ts.T(), err)

require.Equal(ts.T(), AAL2.String(), aal)
require.Equal(ts.T(), totalDistinctClaims, len(amr))

found := false
for _, claim := range session.AMRClaims {
if claim.GetAuthenticationMethod() == TOTPSignIn.String() {
require.True(ts.T(), firstClaimAddedTime.Before(claim.UpdatedAt))
found = true
}
}

for _, claim := range amr {
if claim.Method == SSOSAML.String() {
require.NotNil(ts.T(), claim.Provider)
}
}
require.True(ts.T(), found)
}

0 comments on commit 7185af8

Please sign in to comment.