diff --git a/api/mfa_test.go b/api/mfa_test.go index b7143a283..c36c56358 100644 --- a/api/mfa_test.go +++ b/api/mfa_test.go @@ -13,6 +13,8 @@ import ( "github.com/netlify/gotrue/conf" "github.com/netlify/gotrue/models" "github.com/netlify/gotrue/utilities" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -20,8 +22,11 @@ import ( type MFATestSuite struct { suite.Suite - API *API - Config *conf.GlobalConfiguration + API *API + Config *conf.GlobalConfiguration + TestDomain string + TestEmail string + TestOTPKey *otp.Key } func TestMFA(t *testing.T) { @@ -50,9 +55,27 @@ func (ts *MFATestSuite) SetupTest() { s, err := models.NewSession(u, &f.ID) require.NoError(ts.T(), err, "Error creating test session") require.NoError(ts.T(), ts.API.db.Create(s), "Error saving test session") + + // Generate TOTP related settings + emailValue, err := u.Email.Value() + require.NoError(ts.T(), err) + testEmail := emailValue.(string) + testDomain := strings.Split(testEmail, "@")[1] + ts.TestDomain = testDomain + ts.TestEmail = testEmail + + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: ts.TestDomain, + AccountName: ts.TestEmail, + }) + require.NoError(ts.T(), err) + ts.TestOTPKey = key + } func (ts *MFATestSuite) TestEnrollFactor() { + testFriendlyName := "bob" + alternativeFriendlyName := "john" var cases = []struct { desc string FriendlyName string @@ -62,31 +85,31 @@ func (ts *MFATestSuite) TestEnrollFactor() { }{ { "TOTP: No issuer", - "john", + alternativeFriendlyName, models.TOTP, "", http.StatusOK, }, { "Invalid factor type", - "bob", + testFriendlyName, "", - "john.com", + ts.TestDomain, http.StatusUnprocessableEntity, }, { "TOTP: Factor has friendly name", - "bob", + testFriendlyName, models.TOTP, - "supabase.com", + ts.TestDomain, http.StatusOK, }, { "TOTP: Enrolling without friendly name", "", models.TOTP, - "supabase.com", + ts.TestDomain, http.StatusOK, }, } @@ -139,7 +162,6 @@ func (ts *MFATestSuite) TestChallengeFactor() { require.Equal(ts.T(), http.StatusOK, w.Code) } -// TODO: Check behavior that downgrades all other sessions func (ts *MFATestSuite) TestMFAVerifyFactor() { cases := []struct { desc string @@ -170,30 +192,24 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { ts.Run(v.desc, func() { u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) - emailValue, err := u.Email.Value() - require.NoError(ts.T(), err) - testEmail := emailValue.(string) - testDomain := strings.Split(testEmail, "@")[1] - // set factor secret - key, err := totp.Generate(totp.GenerateOpts{ - Issuer: testDomain, - AccountName: testEmail, - }) + + //r, err := models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) require.NoError(ts.T(), err) - sharedSecret := key.Secret() + + sharedSecret := ts.TestOTPKey.Secret() factors, err := models.FindFactorsByUser(ts.API.db, u) f := factors[0] f.Secret = sharedSecret require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor") - s2, err := models.NewSession(u, &f.ID) + secondarySession, err := models.NewSession(u, &f.ID) require.NoError(ts.T(), err, "Error creating test session") - require.NoError(ts.T(), ts.API.db.Create(s2), "Error saving test session") + require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session") - user, err := models.FindUserByEmailAndAudience(ts.API.db, testEmail, ts.Config.JWT.Aud) + user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.TestEmail, ts.Config.JWT.Aud) ts.Require().NoError(err) - // Make a challenge var buffer bytes.Buffer + token, err := generateAccessToken(user, nil, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret, nil, "") require.NoError(ts.T(), err) @@ -227,9 +243,8 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { require.Equal(ts.T(), v.expectedHTTPCode, w.Code) if v.expectedHTTPCode == http.StatusOK { - _, err = models.FindSessionById(ts.API.db, s2.ID) + _, err = models.FindSessionById(ts.API.db, secondarySession.ID) require.EqualError(ts.T(), err, models.SessionNotFoundError{}.Error()) - // Check that session is downgraded } if !v.validChallenge { _, err := models.FindChallengeByChallengeID(ts.API.db, c.ID) @@ -260,36 +275,27 @@ func (ts *MFATestSuite) TestUnenrollFactor() { require.NoError(ts.T(), err) s, err := models.FindSessionByUserID(ts.API.db, u.ID) require.NoError(ts.T(), err) - var s2 *models.Session + var secondarySession *models.Session if v.CreateAdditionalSession { factors, err := models.FindFactorsByUser(ts.API.db, u) require.NoError(ts.T(), err, "error finding factors") f := factors[0] - s2, err = models.NewSession(u, &f.ID) + secondarySession, err = models.NewSession(u, &f.ID) require.NoError(ts.T(), err, "Error creating test session") - require.NoError(ts.T(), ts.API.db.Create(s2), "Error saving test session") + require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session") } - emailValue, err := u.Email.Value() - require.NoError(ts.T(), err) - testEmail := emailValue.(string) - testDomain := strings.Split(testEmail, "@")[1] - key, err := totp.Generate(totp.GenerateOpts{ - Issuer: testDomain, - AccountName: testEmail, - }) - require.NoError(ts.T(), err) - sharedSecret := key.Secret() factors, err := models.FindFactorsByUser(ts.API.db, u) require.NoError(ts.T(), err) f := factors[0] + + sharedSecret := ts.TestOTPKey.Secret() f.Secret = sharedSecret if v.IsFactorVerified { err = f.UpdateStatus(ts.API.db, models.FactorVerifiedState) require.NoError(ts.T(), err) } - require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor") var buffer bytes.Buffer @@ -312,7 +318,7 @@ func (ts *MFATestSuite) TestUnenrollFactor() { if v.IsFactorVerified && v.CreateAdditionalSession { _, err = models.FindFactorByFactorID(ts.API.db, f.ID) require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error()) - _, err = models.FindSessionById(ts.API.db, s2.ID) + _, err = models.FindSessionById(ts.API.db, secondarySession.ID) require.EqualError(ts.T(), err, models.SessionNotFoundError{}.Error()) } })