Skip to content

Commit

Permalink
feat: forbid generating an access token without a session (#1504)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Enforces the precondition of a valid session before one can create an
access token. This supports refactors around `generateAccessToken` and
`updateMFASessionAndClaims`. Also allows for stronger guarantees within
the function since one can always assume there is a valid session.

There were a few test changes:
- To mirror real world use, Access Tokens should now only exist where
there is a valid session. We wrap `generateAccessToken` into a helper
`generateAccessTokenAndSession` to replace previous occurrences where
session was set to nil.
- We split TestUpdatePassword into cases where reauthentication is
required and reauthentication is not required. We also attach a session
to two of the test cases as they were previously nil
  • Loading branch information
J0 committed Mar 28, 2024
1 parent 31a5854 commit 795e93d
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 43 deletions.
7 changes: 6 additions & 1 deletion internal/api/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,14 @@ func (ts *AuditTestSuite) makeSuperAdmin(email string) string {
require.NoError(ts.T(), err, "Error making new user")

u.Role = "supabase_admin"
require.NoError(ts.T(), ts.API.db.Create(u))

session, err := models.NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(session))

var token string
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.PasswordGrant)
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &session.ID, models.PasswordGrant)
require.NoError(ts.T(), err, "Error generating access token")

p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}}
Expand Down
15 changes: 13 additions & 2 deletions internal/api/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (ts *IdentityTestSuite) TestUnlinkIdentityError() {

for _, c := range cases {
ts.Run(c.desc, func() {
token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, c.user, nil, models.PasswordGrant)
token := ts.generateAccessTokenAndSession(context.Background(), c.user, models.PasswordGrant)
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", c.identityId), nil)
require.NoError(ts.T(), err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
Expand Down Expand Up @@ -183,7 +183,7 @@ func (ts *IdentityTestSuite) TestUnlinkIdentity() {
identity, err := models.FindIdentityByIdAndProvider(ts.API.db, u.ID.String(), c.provider)
require.NoError(ts.T(), err)

token, _, _ := ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.PasswordGrant)
token := ts.generateAccessTokenAndSession(context.Background(), u, models.PasswordGrant)
req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", identity.ID), nil)
require.NoError(ts.T(), err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
Expand Down Expand Up @@ -213,3 +213,14 @@ func (ts *IdentityTestSuite) TestUnlinkIdentity() {
}

}

func (ts *IdentityTestSuite) generateAccessTokenAndSession(ctx context.Context, u *models.User, authenticationMethod models.AuthenticationMethod) string {
s, err := models.NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(s))

token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, u, &s.ID, models.PasswordGrant)
require.NoError(ts.T(), err)
return token

}
7 changes: 6 additions & 1 deletion internal/api/invite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,16 @@ func (ts *InviteTestSuite) makeSuperAdmin(email string) string {

u, err := models.NewUser("123456789", email, "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"})
require.NoError(ts.T(), err, "Error making new user")
require.NoError(ts.T(), ts.API.db.Create(u))

u.Role = "supabase_admin"

var token string
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.Invite)

session, err := models.NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(session))
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &session.ID, models.Invite)

require.NoError(ts.T(), err, "Error generating access token")

Expand Down
5 changes: 4 additions & 1 deletion internal/api/logout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ func (ts *LogoutTestSuite) SetupTest() {

// generate access token to use for logout
var t string
t, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.PasswordGrant)
s, err := models.NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(s))
t, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &s.ID, models.PasswordGrant)
require.NoError(ts.T(), err)
ts.token = t
}
Expand Down
7 changes: 5 additions & 2 deletions internal/api/phone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,11 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() {
u.PhoneConfirmedAt = &now
require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user")

var token string
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.OTP)
s, err := models.NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(s))

token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, u, &s.ID, models.PasswordGrant)
require.NoError(ts.T(), err)

cases := []struct {
Expand Down
23 changes: 11 additions & 12 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,18 +293,17 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)

func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) {
config := a.config
aal, amr := models.AAL1, []models.AMREntry{}
sid := ""
if sessionId != nil {
sid = sessionId.String()
session, terr := models.FindSessionByID(tx, *sessionId, false)
if terr != nil {
return "", 0, terr
}
aal, amr, terr = session.CalculateAALAndAMR(user)
if terr != nil {
return "", 0, terr
}
if sessionId == nil {
return "", 0, internalServerError("Session is required to issue access token")
}
sid := sessionId.String()
session, terr := models.FindSessionByID(tx, *sessionId, false)
if terr != nil {
return "", 0, terr
}
aal, amr, terr := session.CalculateAALAndAMR(user)
if terr != nil {
return "", 0, terr
}

issuedAt := time.Now().UTC()
Expand Down
97 changes: 75 additions & 22 deletions internal/api/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,20 @@ func (ts *UserTestSuite) generateToken(user *models.User, sessionId *uuid.UUID)
return token
}

func (ts *UserTestSuite) generateAccessTokenAndSession(user *models.User) string {
session, err := models.NewSession(user.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(session))

token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, &session.ID, models.PasswordGrant)
require.NoError(ts.T(), err, "Error generating access token")
return token
}

func (ts *UserTestSuite) TestUserGet() {
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err, "Error finding user")
token := ts.generateToken(u, nil)
token := ts.generateAccessTokenAndSession(u)

require.NoError(ts.T(), err, "Error generating access token")

Expand Down Expand Up @@ -120,7 +130,7 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
require.NoError(ts.T(), u.SetPhone(ts.API.db, c.userData["phone"]), "Error setting user phone")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user")

token := ts.generateToken(u, nil)
token := ts.generateAccessTokenAndSession(u)

require.NoError(ts.T(), err, "Error generating access token")

Expand Down Expand Up @@ -183,7 +193,7 @@ func (ts *UserTestSuite) TestUserUpdatePhoneAutoconfirmEnabled() {

for _, c := range cases {
ts.Run(c.desc, func() {
token := ts.generateToken(u, nil)
token := ts.generateAccessTokenAndSession(u)
require.NoError(ts.T(), err, "Error generating access token")

var buffer bytes.Buffer
Expand Down Expand Up @@ -244,28 +254,28 @@ func (ts *UserTestSuite) TestUserUpdatePassword() {
expected expected
}{
{
desc: "Invalid password length",
newPassword: "",
desc: "Need reauthentication because outside of recently logged in window",
newPassword: "newpassword123",
nonce: "",
requireReauthentication: false,
sessionId: nil,
expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false},
requireReauthentication: true,
sessionId: &notRecentlyLoggedIn.ID,
expected: expected{code: http.StatusBadRequest, isAuthenticated: false},
},
{
desc: "No nonce provided",
newPassword: "newpassword123",
nonce: "",
sessionId: &notRecentlyLoggedIn.ID,
requireReauthentication: true,
sessionId: nil,
expected: expected{code: http.StatusBadRequest, isAuthenticated: false},
},
{
desc: "Need reauthentication because outside of recently logged in window",
newPassword: "newpassword123",
nonce: "",
desc: "Invalid nonce",
newPassword: "newpassword1234",
nonce: "123456",
sessionId: &notRecentlyLoggedIn.ID,
requireReauthentication: true,
sessionId: r2.SessionId,
expected: expected{code: http.StatusBadRequest, isAuthenticated: false},
expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false},
},
{
desc: "No need reauthentication because recently logged in",
Expand All @@ -275,20 +285,63 @@ func (ts *UserTestSuite) TestUserUpdatePassword() {
sessionId: r.SessionId,
expected: expected{code: http.StatusOK, isAuthenticated: true},
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
ts.Config.Security.UpdatePasswordRequireReauthentication = c.requireReauthentication
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"password": c.newPassword, "nonce": c.nonce}))

req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer)
req.Header.Set("Content-Type", "application/json")
token := ts.generateToken(u, c.sessionId)

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

// Setup response recorder
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expected.code, w.Code)

// Request body
u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

require.Equal(ts.T(), c.expected.isAuthenticated, u.Authenticate(context.Background(), c.newPassword))
})
}
}

func (ts *UserTestSuite) TestUserUpdatePasswordNoReauthenticationRequired() {
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

type expected struct {
code int
isAuthenticated bool
}

var cases = []struct {
desc string
newPassword string
nonce string
requireReauthentication bool
expected expected
}{
{
desc: "Invalid nonce",
newPassword: "newpassword1234",
nonce: "123456",
requireReauthentication: true,
sessionId: nil,
desc: "Invalid password length",
newPassword: "",
nonce: "",
requireReauthentication: false,
expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false},
},

{
desc: "Valid password length",
newPassword: "newpassword",
nonce: "",
requireReauthentication: false,
sessionId: nil,
expected: expected{code: http.StatusOK, isAuthenticated: true},
},
}
Expand All @@ -301,7 +354,7 @@ func (ts *UserTestSuite) TestUserUpdatePassword() {

req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer)
req.Header.Set("Content-Type", "application/json")
token := ts.generateToken(u, c.sessionId)
token := ts.generateAccessTokenAndSession(u)

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

Expand Down Expand Up @@ -330,7 +383,7 @@ func (ts *UserTestSuite) TestUserUpdatePasswordReauthentication() {
u.EmailConfirmedAt = &now
require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user")

token := ts.generateToken(u, nil)
token := ts.generateAccessTokenAndSession(u)

// request for reauthentication nonce
req := httptest.NewRequest(http.MethodGet, "http://localhost/reauthenticate", nil)
Expand Down
8 changes: 6 additions & 2 deletions internal/api/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,13 @@ func (ts *VerifyTestSuite) TestVerifySecureEmailChange() {
req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer)
req.Header.Set("Content-Type", "application/json")

// Generate access token for request
// Generate access token for request and a mock session
var token string
token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, nil, models.MagicLink)
session, err := models.NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Create(session))

token, _, err = ts.API.generateAccessToken(context.Background(), ts.API.db, u, &session.ID, models.MagicLink)
require.NoError(ts.T(), err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

Expand Down

0 comments on commit 795e93d

Please sign in to comment.