Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add session_id to refresh_tokens table #600

Merged
merged 9 commits into from
Aug 24, 2022
2 changes: 1 addition & 1 deletion api/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (ts *AdminTestSuite) makeSuperAdmin(email string) string {

u.Role = "supabase_admin"

token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
require.NoError(ts.T(), err, "Error generating access token")

p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}}
Expand Down
6 changes: 2 additions & 4 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati

r.With(api.requireAuthentication).Post("/logout", api.Logout)

r.Route("/reauthenticate", func(r *router) {
r.Use(api.requireAuthentication)
r.With(api.requireAuthentication).Route("/reauthenticate", func(r *router) {
r.Get("/", api.Reauthenticate)
})

r.Route("/user", func(r *router) {
r.Use(api.requireAuthentication)
r.With(api.requireAuthentication).Route("/user", func(r *router) {
r.Get("/", api.UserGet)
r.With(sharedLimiter).Put("/", api.UserUpdate)
})
Expand Down
2 changes: 1 addition & 1 deletion api/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (ts *AuditTestSuite) makeSuperAdmin(email string) string {

u.Role = "supabase_admin"

token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
require.NoError(ts.T(), err, "Error generating access token")

p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}}
Expand Down
51 changes: 50 additions & 1 deletion api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package api

import (
"context"
"errors"
"fmt"
"net/http"
"time"

"github.com/gofrs/uuid"
jwt "github.com/golang-jwt/jwt"
"github.com/netlify/gotrue/models"
"github.com/netlify/gotrue/storage"
Expand All @@ -20,7 +22,16 @@ func (a *API) requireAuthentication(w http.ResponseWriter, r *http.Request) (con
return nil, err
}

return a.parseJWTClaims(token, r, w)
ctx, err := a.parseJWTClaims(token, r, w)
if err != nil {
return ctx, err
}

ctx, err = a.maybeLoadUserOrSession(ctx)
if err != nil {
return ctx, err
}
return ctx, err
}

func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) {
Expand Down Expand Up @@ -71,3 +82,41 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request, w http.ResponseWrit

return withToken(ctx, token), nil
}

func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, error) {
claims := getClaims(ctx)
if claims == nil {
return ctx, errors.New("invalid token")
}

if claims.Subject == "" {
return nil, errors.New("invalid claim: subject missing")
}

var user *models.User
if claims.Subject != "" {
userId, err := uuid.FromString(claims.Subject)
if err != nil {
return ctx, err
}
user, err = models.FindUserByID(a.db, userId)
if err != nil {
return ctx, err
}
ctx = withUser(ctx, user)
}

var session *models.Session
if claims.SessionId != "" {
sessionId, err := uuid.FromString(claims.SessionId)
if err != nil {
return ctx, err
}
session, err = models.FindSessionById(a.db, sessionId)
if err != nil {
return ctx, err
}
ctx = withSession(ctx, session)
}
return ctx, nil
}
19 changes: 17 additions & 2 deletions api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
signatureKey = contextKey("signature")
externalProviderTypeKey = contextKey("external_provider_type")
userKey = contextKey("user")
sessionKey = contextKey("session")
externalReferrerKey = contextKey("external_referrer")
functionHooksKey = contextKey("function_hooks")
adminUserKey = contextKey("admin_user")
Expand Down Expand Up @@ -65,12 +66,12 @@ func getRequestID(ctx context.Context) string {
return obj.(string)
}

// withUser adds the user id to the context.
// withUser adds the user to the context.
func withUser(ctx context.Context, u *models.User) context.Context {
return context.WithValue(ctx, userKey, u)
}

// getUser reads the user id from the context.
// getUser reads the user from the context.
func getUser(ctx context.Context) *models.User {
obj := ctx.Value(userKey)
if obj == nil {
Expand All @@ -79,6 +80,20 @@ func getUser(ctx context.Context) *models.User {
return obj.(*models.User)
}

// withSession adds the session to the context.
func withSession(ctx context.Context, s *models.Session) context.Context {
return context.WithValue(ctx, sessionKey, s)
}

// getSession reads the session from the context.
func getSession(ctx context.Context) *models.Session {
obj := ctx.Value(sessionKey)
if obj == nil {
return nil
}
return obj.(*models.Session)
}

// withSignature adds the provided request ID to the context.
func withSignature(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, signatureKey, id)
Expand Down
18 changes: 0 additions & 18 deletions api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/gofrs/uuid"
"github.com/netlify/gotrue/conf"
"github.com/netlify/gotrue/models"
"github.com/netlify/gotrue/storage"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -49,23 +48,6 @@ func sendJSON(w http.ResponseWriter, status int, obj interface{}) error {
return err
}

func getUserFromClaims(ctx context.Context, conn *storage.Connection) (*models.User, error) {
claims := getClaims(ctx)
if claims == nil {
return nil, errors.New("Invalid token")
}

if claims.Subject == "" {
return nil, errors.New("Invalid claim: id")
}

userID, err := uuid.FromString(claims.Subject)
if err != nil {
return nil, errors.New("Invalid user ID")
}
return models.FindUserByID(conn, userID)
}

func (a *API) isAdmin(ctx context.Context, u *models.User, aud string) bool {
config := a.config
if aud == "" {
Expand Down
2 changes: 1 addition & 1 deletion api/invite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (ts *InviteTestSuite) makeSuperAdmin(email string) string {

u.Role = "supabase_admin"

token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
require.NoError(ts.T(), err, "Error generating access token")

p := jwt.Parser{ValidMethods: []string{jwt.SigningMethodHS256.Name}}
Expand Down
13 changes: 7 additions & 6 deletions api/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error {

a.clearCookieTokens(config, w)

u, err := getUserFromClaims(ctx, a.db)
if err != nil {
return unauthorizedError("Invalid user").WithInternalError(err)
}
s := getSession(ctx)
u := getUser(ctx)

err = a.db.Transaction(func(tx *storage.Connection) error {
err := a.db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(r, tx, u, models.LogoutAction, "", nil); terr != nil {
return terr
}
return models.Logout(tx, u.ID)
if s != nil {
return models.Logout(tx, u.ID)
}
return models.LogoutAllRefreshTokens(tx, u.ID)
})
if err != nil {
return internalServerError("Error logging out user").WithInternalError(err)
Expand Down
2 changes: 1 addition & 1 deletion api/phone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() {
u.PhoneConfirmedAt = &now
require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user")

token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
token, err := generateAccessToken(u, "", time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
require.NoError(ts.T(), err)

cases := []struct {
Expand Down
17 changes: 2 additions & 15 deletions api/reauthenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"net/http"

"github.com/gofrs/uuid"
"github.com/netlify/gotrue/api/sms_provider"
"github.com/netlify/gotrue/conf"
"github.com/netlify/gotrue/models"
Expand All @@ -20,19 +19,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
config := a.config

claims := getClaims(ctx)
userID, err := uuid.FromString(claims.Subject)
if err != nil {
return badRequestError("Could not read User ID claim")
}
user, err := models.FindUserByID(a.db, userID)
if err != nil {
if models.IsNotFoundError(err) {
return notFoundError(err.Error())
}
return internalServerError("Database error finding user").WithInternalError(err)
}

user := getUser(ctx)
email, phone := user.GetEmail(), user.GetPhone()

if email == "" && phone == "" {
Expand All @@ -49,7 +36,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error {
}
}

err = a.db.Transaction(func(tx *storage.Connection) error {
err := a.db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserReauthenticateAction, "", nil); terr != nil {
return terr
}
Expand Down
8 changes: 5 additions & 3 deletions api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type GoTrueClaims struct {
AppMetaData map[string]interface{} `json:"app_metadata"`
UserMetaData map[string]interface{} `json:"user_metadata"`
Role string `json:"role"`
SessionId string `json:"session_id"`
}

// AccessTokenResponse represents an OAuth2 success response
Expand Down Expand Up @@ -328,7 +329,7 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
}
}

tokenString, terr = generateAccessToken(user, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret)
tokenString, terr = generateAccessToken(user, newToken.SessionId.UUID.String(), time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret)
if terr != nil {
return internalServerError("error generating jwt token").WithInternalError(terr)
}
Expand Down Expand Up @@ -526,7 +527,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
return sendJSON(w, http.StatusOK, token)
}

func generateAccessToken(user *models.User, expiresIn time.Duration, secret string) (string, error) {
func generateAccessToken(user *models.User, sessionId string, expiresIn time.Duration, secret string) (string, error) {
claims := &GoTrueClaims{
StandardClaims: jwt.StandardClaims{
Subject: user.ID.String(),
Expand All @@ -538,6 +539,7 @@ func generateAccessToken(user *models.User, expiresIn time.Duration, secret stri
AppMetaData: user.AppMetaData,
UserMetaData: user.UserMetaData,
Role: user.Role,
SessionId: sessionId,
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
Expand All @@ -560,7 +562,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u
return internalServerError("Database error granting user").WithInternalError(terr)
}

tokenString, terr = generateAccessToken(user, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret)
tokenString, terr = generateAccessToken(user, refreshToken.SessionId.UUID.String(), time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret)
if terr != nil {
return internalServerError("error generating jwt token").WithInternalError(terr)
}
Expand Down
30 changes: 2 additions & 28 deletions api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"net/http"

"github.com/gofrs/uuid"
"github.com/netlify/gotrue/api/sms_provider"
"github.com/netlify/gotrue/logger"
"github.com/netlify/gotrue/models"
Expand All @@ -29,24 +28,12 @@ func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error {
return badRequestError("Could not read claims")
}

userID, err := uuid.FromString(claims.Subject)
if err != nil {
return badRequestError("Could not read User ID claim")
}

aud := a.requestAud(ctx, r)
if aud != claims.Audience {
return badRequestError("Token audience doesn't match request audience")
}

user, err := models.FindUserByID(a.db, userID)
if err != nil {
if models.IsNotFoundError(err) {
return notFoundError(err.Error())
}
return internalServerError("Database error finding user").WithInternalError(err)
}

user := getUser(ctx)
return sendJSON(w, http.StatusOK, user)
}

Expand All @@ -62,20 +49,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
return badRequestError("Could not read User Update params: %v", err)
}

claims := getClaims(ctx)
userID, err := uuid.FromString(claims.Subject)
if err != nil {
return badRequestError("Could not read User ID claim")
}

user, err := models.FindUserByID(a.db, userID)
if err != nil {
if models.IsNotFoundError(err) {
return notFoundError(err.Error())
}
return internalServerError("Database error finding user").WithInternalError(err)
}

user := getUser(ctx)
log := logger.GetLogEntry(r)
log.Debugf("Checking params for token %v", params)

Expand Down