diff --git a/api/admin_test.go b/api/admin_test.go index a2489ba3e..bb708a2b0 100644 --- a/api/admin_test.go +++ b/api/admin_test.go @@ -51,7 +51,7 @@ func (ts *AdminTestSuite) makeSuperAdmin(email string) string { u.Role = "supabase_admin" - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err, "Error generating access token") p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} diff --git a/api/api.go b/api/api.go index d6b778abe..dae206f74 100644 --- a/api/api.go +++ b/api/api.go @@ -152,13 +152,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.With(api.requireAuthentication).Post("/logout", api.Logout) - r.Route("/reauthenticate", func(r *router) { - r.Use(api.requireAuthentication) + r.With(api.requireAuthentication).Route("/reauthenticate", func(r *router) { r.Get("/", api.Reauthenticate) }) - r.Route("/user", func(r *router) { - r.Use(api.requireAuthentication) + r.With(api.requireAuthentication).Route("/user", func(r *router) { r.Get("/", api.UserGet) r.With(sharedLimiter).Put("/", api.UserUpdate) }) diff --git a/api/audit_test.go b/api/audit_test.go index 59513cdc0..ad868de3a 100644 --- a/api/audit_test.go +++ b/api/audit_test.go @@ -48,7 +48,7 @@ func (ts *AuditTestSuite) makeSuperAdmin(email string) string { u.Role = "supabase_admin" - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err, "Error generating access token") p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} diff --git a/api/auth.go b/api/auth.go index bab0e1c00..cd16f8fb9 100644 --- a/api/auth.go +++ b/api/auth.go @@ -2,10 +2,12 @@ package api import ( "context" + "errors" "fmt" "net/http" "time" + "github.com/gofrs/uuid" jwt "github.com/golang-jwt/jwt" "github.com/netlify/gotrue/models" "github.com/netlify/gotrue/storage" @@ -20,7 +22,16 @@ func (a *API) requireAuthentication(w http.ResponseWriter, r *http.Request) (con return nil, err } - return a.parseJWTClaims(token, r, w) + ctx, err := a.parseJWTClaims(token, r, w) + if err != nil { + return ctx, err + } + + ctx, err = a.maybeLoadUserOrSession(ctx) + if err != nil { + return ctx, err + } + return ctx, err } func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) { @@ -71,3 +82,41 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request, w http.ResponseWrit return withToken(ctx, token), nil } + +func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, error) { + claims := getClaims(ctx) + if claims == nil { + return ctx, errors.New("invalid token") + } + + if claims.Subject == "" { + return nil, errors.New("invalid claim: subject missing") + } + + var user *models.User + if claims.Subject != "" { + userId, err := uuid.FromString(claims.Subject) + if err != nil { + return ctx, err + } + user, err = models.FindUserByID(a.db, userId) + if err != nil { + return ctx, err + } + ctx = withUser(ctx, user) + } + + var session *models.Session + if claims.SessionId != "" { + sessionId, err := uuid.FromString(claims.SessionId) + if err != nil { + return ctx, err + } + session, err = models.FindSessionById(a.db, sessionId) + if err != nil { + return ctx, err + } + ctx = withSession(ctx, session) + } + return ctx, nil +} diff --git a/api/context.go b/api/context.go index d0150b937..1543b5ba1 100644 --- a/api/context.go +++ b/api/context.go @@ -20,6 +20,7 @@ const ( signatureKey = contextKey("signature") externalProviderTypeKey = contextKey("external_provider_type") userKey = contextKey("user") + sessionKey = contextKey("session") externalReferrerKey = contextKey("external_referrer") functionHooksKey = contextKey("function_hooks") adminUserKey = contextKey("admin_user") @@ -65,12 +66,12 @@ func getRequestID(ctx context.Context) string { return obj.(string) } -// withUser adds the user id to the context. +// withUser adds the user to the context. func withUser(ctx context.Context, u *models.User) context.Context { return context.WithValue(ctx, userKey, u) } -// getUser reads the user id from the context. +// getUser reads the user from the context. func getUser(ctx context.Context) *models.User { obj := ctx.Value(userKey) if obj == nil { @@ -79,6 +80,20 @@ func getUser(ctx context.Context) *models.User { return obj.(*models.User) } +// withSession adds the session to the context. +func withSession(ctx context.Context, s *models.Session) context.Context { + return context.WithValue(ctx, sessionKey, s) +} + +// getSession reads the session from the context. +func getSession(ctx context.Context) *models.Session { + obj := ctx.Value(sessionKey) + if obj == nil { + return nil + } + return obj.(*models.Session) +} + // withSignature adds the provided request ID to the context. func withSignature(ctx context.Context, id string) context.Context { return context.WithValue(ctx, signatureKey, id) diff --git a/api/helpers.go b/api/helpers.go index d2be35ffb..94f2aee57 100644 --- a/api/helpers.go +++ b/api/helpers.go @@ -13,7 +13,6 @@ import ( "github.com/gofrs/uuid" "github.com/netlify/gotrue/conf" "github.com/netlify/gotrue/models" - "github.com/netlify/gotrue/storage" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -49,23 +48,6 @@ func sendJSON(w http.ResponseWriter, status int, obj interface{}) error { return err } -func getUserFromClaims(ctx context.Context, conn *storage.Connection) (*models.User, error) { - claims := getClaims(ctx) - if claims == nil { - return nil, errors.New("Invalid token") - } - - if claims.Subject == "" { - return nil, errors.New("Invalid claim: id") - } - - userID, err := uuid.FromString(claims.Subject) - if err != nil { - return nil, errors.New("Invalid user ID") - } - return models.FindUserByID(conn, userID) -} - func (a *API) isAdmin(ctx context.Context, u *models.User, aud string) bool { config := a.config if aud == "" { diff --git a/api/invite_test.go b/api/invite_test.go index de3663f42..52a8e0924 100644 --- a/api/invite_test.go +++ b/api/invite_test.go @@ -58,7 +58,7 @@ func (ts *InviteTestSuite) makeSuperAdmin(email string) string { u.Role = "supabase_admin" - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err, "Error generating access token") p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}} diff --git a/api/logout.go b/api/logout.go index 252245412..c95508d65 100644 --- a/api/logout.go +++ b/api/logout.go @@ -14,16 +14,17 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { a.clearCookieTokens(config, w) - u, err := getUserFromClaims(ctx, a.db) - if err != nil { - return unauthorizedError("Invalid user").WithInternalError(err) - } + s := getSession(ctx) + u := getUser(ctx) - err = a.db.Transaction(func(tx *storage.Connection) error { + err := a.db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, u, models.LogoutAction, "", nil); terr != nil { return terr } - return models.Logout(tx, u.ID) + if s != nil { + return models.Logout(tx, u.ID) + } + return models.LogoutAllRefreshTokens(tx, u.ID) }) if err != nil { return internalServerError("Error logging out user").WithInternalError(err) diff --git a/api/phone_test.go b/api/phone_test.go index 394fd1e20..5f5f9cc68 100644 --- a/api/phone_test.go +++ b/api/phone_test.go @@ -126,7 +126,7 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { u.PhoneConfirmedAt = &now require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user") - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err) cases := []struct { diff --git a/api/reauthenticate.go b/api/reauthenticate.go index 0d7291beb..57eee0b3b 100644 --- a/api/reauthenticate.go +++ b/api/reauthenticate.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" - "github.com/gofrs/uuid" "github.com/netlify/gotrue/api/sms_provider" "github.com/netlify/gotrue/conf" "github.com/netlify/gotrue/models" @@ -20,19 +19,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() config := a.config - claims := getClaims(ctx) - userID, err := uuid.FromString(claims.Subject) - if err != nil { - return badRequestError("Could not read User ID claim") - } - user, err := models.FindUserByID(a.db, userID) - if err != nil { - if models.IsNotFoundError(err) { - return notFoundError(err.Error()) - } - return internalServerError("Database error finding user").WithInternalError(err) - } - + user := getUser(ctx) email, phone := user.GetEmail(), user.GetPhone() if email == "" && phone == "" { @@ -49,7 +36,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { } } - err = a.db.Transaction(func(tx *storage.Connection) error { + err := a.db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserReauthenticateAction, "", nil); terr != nil { return terr } diff --git a/api/token.go b/api/token.go index cad41ae3f..88c23edd6 100644 --- a/api/token.go +++ b/api/token.go @@ -26,6 +26,7 @@ type GoTrueClaims struct { AppMetaData map[string]interface{} `json:"app_metadata"` UserMetaData map[string]interface{} `json:"user_metadata"` Role string `json:"role"` + SessionId string `json:"session_id"` } // AccessTokenResponse represents an OAuth2 success response @@ -328,7 +329,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h } } - tokenString, terr = generateAccessToken(user, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret) + tokenString, terr = generateAccessToken(user, newToken.SessionId.UUID.String(), time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret) if terr != nil { return internalServerError("error generating jwt token").WithInternalError(terr) } @@ -526,7 +527,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R return sendJSON(w, http.StatusOK, token) } -func generateAccessToken(user *models.User, expiresIn time.Duration, secret string) (string, error) { +func generateAccessToken(user *models.User, sessionId string, expiresIn time.Duration, secret string) (string, error) { claims := &GoTrueClaims{ StandardClaims: jwt.StandardClaims{ Subject: user.ID.String(), @@ -538,6 +539,7 @@ func generateAccessToken(user *models.User, expiresIn time.Duration, secret stri AppMetaData: user.AppMetaData, UserMetaData: user.UserMetaData, Role: user.Role, + SessionId: sessionId, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) @@ -560,7 +562,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u return internalServerError("Database error granting user").WithInternalError(terr) } - tokenString, terr = generateAccessToken(user, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret) + tokenString, terr = generateAccessToken(user, refreshToken.SessionId.UUID.String(), time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret) if terr != nil { return internalServerError("error generating jwt token").WithInternalError(terr) } diff --git a/api/user.go b/api/user.go index 816601049..f63909cff 100644 --- a/api/user.go +++ b/api/user.go @@ -4,7 +4,6 @@ import ( "encoding/json" "net/http" - "github.com/gofrs/uuid" "github.com/netlify/gotrue/api/sms_provider" "github.com/netlify/gotrue/logger" "github.com/netlify/gotrue/models" @@ -29,24 +28,12 @@ func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error { return badRequestError("Could not read claims") } - userID, err := uuid.FromString(claims.Subject) - if err != nil { - return badRequestError("Could not read User ID claim") - } - aud := a.requestAud(ctx, r) if aud != claims.Audience { return badRequestError("Token audience doesn't match request audience") } - user, err := models.FindUserByID(a.db, userID) - if err != nil { - if models.IsNotFoundError(err) { - return notFoundError(err.Error()) - } - return internalServerError("Database error finding user").WithInternalError(err) - } - + user := getUser(ctx) return sendJSON(w, http.StatusOK, user) } @@ -62,20 +49,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { return badRequestError("Could not read User Update params: %v", err) } - claims := getClaims(ctx) - userID, err := uuid.FromString(claims.Subject) - if err != nil { - return badRequestError("Could not read User ID claim") - } - - user, err := models.FindUserByID(a.db, userID) - if err != nil { - if models.IsNotFoundError(err) { - return notFoundError(err.Error()) - } - return internalServerError("Database error finding user").WithInternalError(err) - } - + user := getUser(ctx) log := logger.GetLogEntry(r) log.Debugf("Checking params for token %v", params) diff --git a/api/user_test.go b/api/user_test.go index 67f983c88..83163bd1c 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -47,7 +47,7 @@ func (ts *UserTestSuite) SetupTest() { func (ts *UserTestSuite) TestUserGet() { u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err, "Error finding user") - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err, "Error generating access token") req := httptest.NewRequest(http.MethodGet, "http://localhost/user", nil) @@ -111,7 +111,7 @@ func (ts *UserTestSuite) TestUserUpdateEmail() { require.NoError(ts.T(), u.SetPhone(ts.API.db, c.userData["phone"]), "Error setting user phone") require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user") - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err, "Error generating access token") var buffer bytes.Buffer @@ -170,7 +170,7 @@ func (ts *UserTestSuite) TestUserUpdatePhoneAutoconfirmEnabled() { for _, c := range cases { ts.Run(c.desc, func() { - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err, "Error generating access token") var buffer bytes.Buffer @@ -244,7 +244,7 @@ func (ts *UserTestSuite) TestUserUpdatePassword() { req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) req.Header.Set("Content-Type", "application/json") - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) @@ -272,7 +272,7 @@ func (ts *UserTestSuite) TestUserUpdatePasswordReauthentication() { u.EmailConfirmedAt = &now require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user") - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err) // request for reauthentication nonce diff --git a/api/verify_test.go b/api/verify_test.go index 0379d253a..3b17a75cc 100644 --- a/api/verify_test.go +++ b/api/verify_test.go @@ -103,7 +103,7 @@ func (ts *VerifyTestSuite) TestVerifySecureEmailChange() { req.Header.Set("Content-Type", "application/json") // Generate access token for request - token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) diff --git a/migrations/20210927181326_add_refresh_token_parent.up.sql b/migrations/20210927181326_add_refresh_token_parent.up.sql index 2a9ff0e3f..3a55c725c 100644 --- a/migrations/20210927181326_add_refresh_token_parent.up.sql +++ b/migrations/20210927181326_add_refresh_token_parent.up.sql @@ -1,8 +1,24 @@ -- adds parent column ALTER TABLE auth.refresh_tokens -ADD COLUMN IF NOT EXISTS parent varchar(255) NULL, -ADD CONSTRAINT refresh_tokens_token_unique UNIQUE ("token"), -ADD CONSTRAINT refresh_tokens_parent_fkey FOREIGN KEY (parent) REFERENCES auth.refresh_tokens("token"); +ADD COLUMN IF NOT EXISTS parent varchar(255) NULL; + +DO $$ +BEGIN + IF NOT EXISTS(SELECT * + FROM information_schema.constraint_column_usage + WHERE table_schema = 'auth' and table_name='refresh_tokens' and constraint_name='refresh_tokens_token_unique') + THEN + ALTER TABLE "auth"."refresh_tokens" ADD CONSTRAINT refresh_tokens_token_unique UNIQUE ("token"); + END IF; + + IF NOT EXISTS(SELECT * + FROM information_schema.constraint_column_usage + WHERE table_schema = 'auth' and table_name='refresh_tokens' and constraint_name='refresh_tokens_parent_fkey') + THEN + ALTER TABLE "auth"."refresh_tokens" ADD CONSTRAINT refresh_tokens_parent_fkey FOREIGN KEY (parent) REFERENCES auth.refresh_tokens("token"); + END IF; + + CREATE INDEX IF NOT EXISTS refresh_tokens_parent_idx ON refresh_tokens USING btree (parent); +END $$; -CREATE INDEX IF NOT EXISTS refresh_tokens_parent_idx ON refresh_tokens USING btree (parent); diff --git a/migrations/20220811173540_add_sessions_table.up.sql b/migrations/20220811173540_add_sessions_table.up.sql new file mode 100644 index 000000000..7c4badcba --- /dev/null +++ b/migrations/20220811173540_add_sessions_table.up.sql @@ -0,0 +1,23 @@ +-- Add session_id column to refresh_tokens table +create table if not exists auth.sessions ( + id uuid not null, + user_id uuid not null, + created_at timestamptz null, + updated_at timestamptz null, + constraint sessions_pkey primary key (id), + constraint sessions_user_id_fkey foreign key (user_id) references auth.users(id) on delete cascade +); +comment on table auth.sessions is 'Auth: Stores session data associated to a user.'; + +alter table auth.refresh_tokens +add column if not exists session_id uuid null; + +do $$ +begin + if not exists(select * + from information_schema.constraint_column_usage + where table_schema = 'auth' and table_name='sessions' and constraint_name='refresh_tokens_session_id_fkey') + then + alter table "auth"."refresh_tokens" add constraint refresh_tokens_session_id_fkey foreign key (session_id) references auth.sessions(id) on delete cascade; + end if; +END $$; diff --git a/models/errors.go b/models/errors.go index 6c33c70c2..d2b6fe640 100644 --- a/models/errors.go +++ b/models/errors.go @@ -5,6 +5,8 @@ func IsNotFoundError(err error) bool { switch err.(type) { case UserNotFoundError: return true + case SessionNotFoundError: + return true case ConfirmationTokenNotFoundError: return true case RefreshTokenNotFoundError: @@ -19,6 +21,12 @@ func IsNotFoundError(err error) bool { return false } +type SessionNotFoundError struct{} + +func (e SessionNotFoundError) Error() string { + return "Session not found" +} + // UserNotFoundError represents when a user is not found. type UserNotFoundError struct{} diff --git a/models/refresh_token.go b/models/refresh_token.go index e20de2ebf..5f1bc4630 100644 --- a/models/refresh_token.go +++ b/models/refresh_token.go @@ -5,6 +5,7 @@ import ( "net/http" "time" + "github.com/gobuffalo/nulls" "github.com/gobuffalo/pop/v5" "github.com/gofrs/uuid" "github.com/netlify/gotrue/crypto" @@ -20,7 +21,8 @@ type RefreshToken struct { UserID uuid.UUID `db:"user_id"` - Parent storage.NullString `db:"parent"` + Parent storage.NullString `db:"parent"` + SessionId nulls.UUID `db:"session_id"` Revoked bool `db:"revoked"` CreatedAt time.Time `db:"created_at"` @@ -60,14 +62,19 @@ func GrantRefreshTokenSwap(r *http.Request, tx *storage.Connection, user *User, // RevokeTokenFamily revokes all refresh tokens that descended from the provided token. func RevokeTokenFamily(tx *storage.Connection, token *RefreshToken) error { + var err error tablename := (&pop.Model{Value: RefreshToken{}}).TableName() - err := tx.RawQuery(` - with recursive token_family as ( - select id, user_id, token, revoked, parent from `+tablename+` where parent = ? - union - select r.id, r.user_id, r.token, r.revoked, r.parent from `+tablename+` r inner join token_family t on t.token = r.parent - ) - update `+tablename+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec() + if token.SessionId.Valid && len(token.SessionId.UUID) > 0 { + err = tx.RawQuery(`update `+tablename+` set revoked = true where session_id = ?;`, token.SessionId).Exec() + } else { + err = tx.RawQuery(` + with recursive token_family as ( + select id, user_id, token, revoked, parent from `+tablename+` where parent = ? + union + select r.id, r.user_id, r.token, r.revoked, r.parent from `+tablename+` r inner join token_family t on t.token = r.parent + ) + update `+tablename+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec() + } if err != nil { return err } @@ -87,11 +94,6 @@ func GetValidChildToken(tx *storage.Connection, token *RefreshToken) (*RefreshTo return refreshToken, nil } -// Logout deletes all refresh tokens for a user. -func Logout(tx *storage.Connection, id uuid.UUID) error { - return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: RefreshToken{}}).TableName()+" WHERE instance_id = ? AND user_id = ?", uuid.Nil, id).Exec() -} - func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshToken) (*RefreshToken, error) { token := &RefreshToken{ UserID: user.ID, @@ -100,6 +102,13 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok } if oldToken != nil { token.Parent = storage.NullString(oldToken.Token) + token.SessionId = oldToken.SessionId + } else { + session, err := CreateSession(tx, user) + if err != nil { + return nil, errors.Wrap(err, "Error generated unique session id") + } + token.SessionId = nulls.NewUUID(session.ID) } if err := tx.Create(token); err != nil { @@ -111,3 +120,9 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok } return token, nil } + +// Deprecated. For backward compatibility, some access tokens may not have a sessionId. Use models.Logout instead. +// LogoutAllRefreshTokens deletes all sessions for a user. +func LogoutAllRefreshTokens(tx *storage.Connection, userId uuid.UUID) error { + return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: RefreshToken{}}).TableName()+" WHERE user_id = ?", userId).Exec() +} diff --git a/models/sessions.go b/models/sessions.go new file mode 100644 index 000000000..331d9fe5a --- /dev/null +++ b/models/sessions.go @@ -0,0 +1,66 @@ +package models + +import ( + "database/sql" + "time" + + "github.com/gobuffalo/pop/v5" + "github.com/gofrs/uuid" + "github.com/netlify/gotrue/storage" + "github.com/pkg/errors" +) + +type Session struct { + ID uuid.UUID `json:"-" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +func (Session) TableName() string { + tableName := "sessions" + return tableName +} + +func NewSession(user *User) (*Session, error) { + id, err := uuid.NewV4() + if err != nil { + return nil, errors.Wrap(err, "Error generating unique session id") + } + session := &Session{ + ID: id, + UserID: user.ID, + } + return session, nil +} + +func CreateSession(tx *storage.Connection, user *User) (*Session, error) { + session, err := NewSession(user) + if err != nil { + return nil, err + } + if err := tx.Create(session); err != nil { + return nil, errors.Wrap(err, "error creating session") + } + return session, nil +} + +func FindSessionById(tx *storage.Connection, id uuid.UUID) (*Session, error) { + session := &Session{} + if err := tx.Eager().Q().Where("id = ?", id).First(session); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SessionNotFoundError{} + } + return nil, errors.Wrap(err, "error finding user") + } + return session, nil +} + +// Logout deletes all sessions for a user. +func Logout(tx *storage.Connection, userId uuid.UUID) error { + return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE user_id = ?", userId).Exec() +} + +func LogoutSession(tx *storage.Connection, sessionId uuid.UUID) error { + return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id = ?", sessionId).Exec() +}