From 73e8d8742de3575b3165a707b5d2f486b2598d9d Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Thu, 7 Mar 2024 18:25:25 +0800 Subject: [PATCH] fix: unlink identity bugs (#1475) --- internal/api/identity.go | 20 ++++- internal/api/identity_test.go | 138 +++++++++++++++++++++++++++++++++- internal/models/user.go | 6 ++ 3 files changed, 158 insertions(+), 6 deletions(-) diff --git a/internal/api/identity.go b/internal/api/identity.go index 4db43d8c2..f47708555 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -60,11 +60,23 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { if terr := tx.Destroy(identityToBeDeleted); terr != nil { return internalServerError("Database error deleting identity").WithInternalError(terr) } - if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil { - if models.IsUniqueConstraintViolatedError(terr) { - return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr) + + switch identityToBeDeleted.Provider { + case "phone": + user.PhoneConfirmedAt = nil + if terr := user.SetPhone(tx, ""); terr != nil { + return internalServerError("Database error updating user phone").WithInternalError(terr) + } + if terr := tx.UpdateOnly(user, "phone_confirmed_at"); terr != nil { + return internalServerError("Database error updating user phone").WithInternalError(terr) + } + default: + if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil { + if models.IsUniqueConstraintViolatedError(terr) { + return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr) + } + return internalServerError("Database error updating user email").WithInternalError(terr) } - return internalServerError("Database error updating user email").WithInternalError(terr) } if terr := user.UpdateAppMetaDataProviders(tx); terr != nil { return internalServerError("Database error updating user providers").WithInternalError(terr) diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go index a71c612d0..7f70af416 100644 --- a/internal/api/identity_test.go +++ b/internal/api/identity_test.go @@ -2,10 +2,13 @@ package api import ( "context" + "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" + "github.com/gofrs/uuid" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/supabase/auth/internal/api/provider" @@ -34,9 +37,10 @@ func (ts *IdentityTestSuite) SetupTest() { models.TruncateAll(ts.API.db) // Create user - u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + u, err := models.NewUser("", "one@example.com", "password", ts.Config.JWT.Aud, nil) require.NoError(ts.T(), err, "Error creating test user model") require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + require.NoError(ts.T(), u.Confirm(ts.API.db)) // Create identity i, err := models.NewIdentity(u, "email", map[string]interface{}{ @@ -45,10 +49,31 @@ func (ts *IdentityTestSuite) SetupTest() { }) require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(i)) + + // Create user with 2 identities + u, err = models.NewUser("123456789", "two@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + require.NoError(ts.T(), u.Confirm(ts.API.db)) + require.NoError(ts.T(), u.ConfirmPhone(ts.API.db)) + + i, err = models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": u.GetEmail(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i)) + + i2, err := models.NewIdentity(u, "phone", map[string]interface{}{ + "sub": u.ID.String(), + "phone": u.GetPhone(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i2)) } func (ts *IdentityTestSuite) TestLinkIdentityToUser() { - u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + u, err := models.FindUserByEmailAndAudience(ts.API.db, "one@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) ctx := withTargetUser(context.Background(), u) @@ -79,3 +104,112 @@ func (ts *IdentityTestSuite) TestLinkIdentityToUser() { require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked")) require.Nil(ts.T(), u) } + +func (ts *IdentityTestSuite) TestUnlinkIdentityError() { + ts.Config.Security.ManualLinkingEnabled = true + userWithOneIdentity, err := models.FindUserByEmailAndAudience(ts.API.db, "one@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + userWithTwoIdentities, err := models.FindUserByEmailAndAudience(ts.API.db, "two@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + cases := []struct { + desc string + user *models.User + identityId uuid.UUID + expectedError *HTTPError + }{ + { + desc: "User must have at least 1 identity after unlinking", + user: userWithOneIdentity, + identityId: userWithOneIdentity.Identities[0].ID, + expectedError: badRequestError("User must have at least 1 identity after unlinking"), + }, + { + desc: "Identity doesn't exist", + user: userWithTwoIdentities, + identityId: uuid.Must(uuid.NewV4()), + expectedError: badRequestError("Identity doesn't exist"), + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, c.user, nil, models.PasswordGrant) + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", c.identityId), nil) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expectedError.Code, w.Code) + + var data HTTPError + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), c.expectedError.Message, data.Message) + }) + } +} + +func (ts *IdentityTestSuite) TestUnlinkIdentity() { + ts.Config.Security.ManualLinkingEnabled = true + + // we want to test 2 cases here: unlinking a phone identity and email identity from a user + cases := []struct { + desc string + // the provider to be unlinked + provider string + // the remaining provider that should be linked to the user + providerRemaining string + }{ + { + desc: "Unlink phone identity successfully", + provider: "phone", + providerRemaining: "email", + }, + { + desc: "Unlink email identity successfully", + provider: "email", + providerRemaining: "phone", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // teardown and reset the state of the db to prevent running into errors + ts.SetupTest() + u, err := models.FindUserByEmailAndAudience(ts.API.db, "two@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + identity, err := models.FindIdentityByIdAndProvider(ts.API.db, u.ID.String(), c.provider) + require.NoError(ts.T(), err) + + token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.PasswordGrant) + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", identity.ID), nil) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // sanity checks + u, err = models.FindUserByID(ts.API.db, u.ID) + require.NoError(ts.T(), err) + require.Len(ts.T(), u.Identities, 1) + require.Equal(ts.T(), u.Identities[0].Provider, c.providerRemaining) + + // conditional checks depending on the provider that was unlinked + switch c.provider { + case "phone": + require.Equal(ts.T(), "", u.GetPhone()) + require.Nil(ts.T(), u.PhoneConfirmedAt) + case "email": + require.Equal(ts.T(), "", u.GetEmail()) + require.Nil(ts.T(), u.EmailConfirmedAt) + } + + // user still has a phone / email identity linked so it should not be unconfirmed + require.NotNil(ts.T(), u.ConfirmedAt) + }) + } + +} diff --git a/internal/models/user.go b/internal/models/user.go index 6111fa7af..105c57a9c 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -262,6 +262,12 @@ func (u *User) UpdateUserEmailFromIdentities(tx *storage.Connection) error { if terr := u.SetEmail(tx, primaryIdentity.GetEmail()); terr != nil { return terr } + if primaryIdentity.GetEmail() == "" { + u.EmailConfirmedAt = nil + if terr := tx.UpdateOnly(u, "email_confirmed_at"); terr != nil { + return terr + } + } return nil }