Skip to content

Commit

Permalink
feat: add hashed refresh tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Sep 18, 2022
1 parent b6bec2f commit d1605c0
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 59 deletions.
53 changes: 19 additions & 34 deletions api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,32 +271,19 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return oauthError("invalid_grant", "Invalid Refresh Token")
}

var newToken *models.RefreshToken
refreshTokenReuseWindow := token.UpdatedAt.Add(time.Second * time.Duration(config.Security.RefreshTokenReuseInterval))

if token.Revoked {
a.clearCookieTokens(config, w)
err = db.Transaction(func(tx *storage.Connection) error {
validToken, terr := models.GetValidChildToken(tx, token)
if terr != nil {
if errors.Is(terr, models.RefreshTokenNotFoundError{}) {
// revoked token has no descendants
return nil
}
return terr
}
// check if token is the last previous revoked token
if validToken.Parent == storage.NullString(token.Token) {
refreshTokenReuseWindow := token.UpdatedAt.Add(time.Second * time.Duration(config.Security.RefreshTokenReuseInterval))
if time.Now().Before(refreshTokenReuseWindow) {
newToken = validToken
}
}
return nil
})
if err != nil {
return internalServerError("Error validating reuse interval").WithInternalError(err)
}
if time.Now().Before(refreshTokenReuseWindow) {
// token can still be used, and the browser/client probably did
// not sync refresh tokens concurrently this often happens with
// multiple tabs without tab synchronization
} else {
// for some reason the browser/client did not sync refresh
// tokens in the window, the token can't be reused

a.clearCookieTokens(config, w)

if newToken == nil {
if config.Security.RefreshTokenRotationEnabled {
// Revoke all tokens in token family
err = db.Transaction(func(tx *storage.Connection) error {
Expand All @@ -310,11 +297,11 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return internalServerError(err.Error())
}
}

return oauthError("invalid_grant", "Invalid Refresh Token").WithInternalMessage("Possible abuse attempt: %v", r)
}
}

var tokenString string
var newTokenResponse *AccessTokenResponse

err = db.Transaction(func(tx *storage.Connection) error {
Expand All @@ -323,23 +310,21 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return terr
}

if newToken == nil {
newToken, terr = models.GrantRefreshTokenSwap(r, tx, user, token)
if terr != nil {
return internalServerError(terr.Error())
}
newRefreshToken, terr := models.GrantRefreshTokenSwap(r, tx, user, token)
if terr != nil {
return internalServerError(terr.Error())
}

tokenString, terr = generateAccessToken(user, newToken.SessionId, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret)
accessToken, terr := generateAccessToken(user, newRefreshToken.SessionId, time.Second*time.Duration(config.JWT.Exp), config.JWT.Secret)
if terr != nil {
return internalServerError("error generating jwt token").WithInternalError(terr)
}

newTokenResponse = &AccessTokenResponse{
Token: tokenString,
Token: accessToken,
TokenType: "bearer",
ExpiresIn: config.JWT.Exp,
RefreshToken: newToken.Token,
RefreshToken: newRefreshToken.Token, // sending back the original, not hashed token
User: user,
}
if terr = a.setCookieTokens(config, newTokenResponse, false, w); terr != nil {
Expand Down Expand Up @@ -586,7 +571,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u
Token: tokenString,
TokenType: "bearer",
ExpiresIn: config.JWT.Exp,
RefreshToken: refreshToken.Token,
RefreshToken: refreshToken.Token, // sending back the original, not hashed token
User: user,
}, nil
}
Expand Down
36 changes: 25 additions & 11 deletions api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,14 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving foo user")
first, err := models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{})
require.NoError(ts.T(), err)
require.NotEmpty(ts.T(), first.Token)
second, err := models.GrantRefreshTokenSwap(&http.Request{}, ts.API.db, u, first)
require.NoError(ts.T(), err)
require.NotEmpty(ts.T(), second.Token)
third, err := models.GrantRefreshTokenSwap(&http.Request{}, ts.API.db, u, second)
require.NoError(ts.T(), err)
require.NotEmpty(ts.T(), third.Token)
require.False(ts.T(), third.Revoked)

cases := []struct {
desc string
Expand All @@ -169,18 +173,28 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
expectedBody map[string]interface{}
}{
{
"Valid refresh within reuse interval",
true,
"Valid refresh token, not revoked",
false,
0,
third.Token,
http.StatusOK,
map[string]interface{}{
// some refresh token should be returned, we can't know it's value
},
},
{
"Valid refresh token, revoked but within reuse interval",
false,
30,
second.Token,
http.StatusOK,
map[string]interface{}{
"refresh_token": third.Token,
// some refresh token should be returned, we can't know it's value
},
},
{
"Invalid refresh, first token is not the previous revoked token",
true,
"Invalid refresh token, revoked but outside of reuse interval",
false,
0,
first.Token,
http.StatusBadRequest,
Expand All @@ -190,20 +204,20 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
},
},
{
"Invalid refresh, revoked third token",
"Invalid refresh token, revoked and outside of reuse interval and will invalidate all",
true,
0,
second.Token,
first.Token,
http.StatusBadRequest,
map[string]interface{}{
"error": "invalid_grant",
"error_description": "Invalid Refresh Token",
},
},
{
"Invalid refresh, third token revoked by previous case",
"Invalid refresh token, revoked and outside of reuse interval, invalidated in previous case",
true,
30,
0,
third.Token,
http.StatusBadRequest,
map[string]interface{}{
Expand All @@ -213,7 +227,7 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
},
}

for _, c := range cases {
for i, c := range cases {
ts.Run(c.desc, func() {
ts.Config.Security.RefreshTokenRotationEnabled = c.refreshTokenRotationEnabled
ts.Config.Security.RefreshTokenReuseInterval = c.reuseInterval
Expand All @@ -230,7 +244,7 @@ func (ts *TokenTestSuite) TestTokenRefreshTokenRotation() {
data := make(map[string]interface{})
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
for k, v := range c.expectedBody {
require.Equal(ts.T(), v, data[k])
require.Equal(ts.T(), v, data[k], "mismatch on example %d with key %s", i, k)
}
})
}
Expand Down
41 changes: 37 additions & 4 deletions crypto/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package crypto

import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"math"
Expand All @@ -12,13 +14,44 @@ import (
"github.com/pkg/errors"
)

// SecureToken is an object that represents a unique randomly generated string
// that can be sent to a client and/or stored in a database for lookup only.
type SecureToken struct {
Original string `json:"-"`
Hashed string `json:"-"`
}

// SecureToken creates a new random token
func SecureToken() string {
b := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
func GenerateSecureToken() SecureToken {
bytes := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
panic(err.Error()) // rand should never fail
}
return base64.RawURLEncoding.EncodeToString(b)

original := base64.RawURLEncoding.EncodeToString(bytes)

return SecureToken{
Original: original,
Hashed: HashSHA224Base64(original),
}
}

// HashSHA224 hashes the provided string with SHA256/224 and returns it as
// Base64 URL encoded. SHA256/224 is a good hashing function as it's shorter
// than SHA256 but also is not succeptible to a length extension attack.
func HashSHA224Base64(str string) string {
bytes := sha256.Sum224([]byte(str))

return base64.RawURLEncoding.EncodeToString(bytes[:])
}

// HashSHA224 hashes the provided string with SHA256/224 and returns it as
// hex encoded. SHA256/224 is a good hashing function as it's shorter
// than SHA256 but also is not succeptible to a length extension attack.
func HashSHA224Hex(str string) string {
bytes := sha256.Sum224([]byte(str))

return hex.EncodeToString(bytes[:])
}

// GenerateOtp generates a random n digit otp
Expand Down
19 changes: 13 additions & 6 deletions models/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import (
type RefreshToken struct {
ID int64 `db:"id"`

Token string `db:"token"`
Token string `db:"-"`
HashedToken string `db:"token"`

UserID uuid.UUID `db:"user_id"`

Expand Down Expand Up @@ -77,7 +78,7 @@ func RevokeTokenFamily(tx *storage.Connection, token *RefreshToken) error {
union
select r.id, r.user_id, r.token, r.revoked, r.parent from `+tablename+` r inner join token_family t on t.token = r.parent
)
update `+tablename+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec()
update `+tablename+` r set revoked = true from token_family where token_family.id = r.id;`, token.HashedToken).Exec()
}
if err != nil {
return err
Expand All @@ -88,7 +89,7 @@ func RevokeTokenFamily(tx *storage.Connection, token *RefreshToken) error {
// GetValidChildToken returns the child token of the token provided if the child is not revoked.
func GetValidChildToken(tx *storage.Connection, token *RefreshToken) (*RefreshToken, error) {
refreshToken := &RefreshToken{}
err := tx.Q().Where("parent = ? and revoked = false", token.Token).First(refreshToken)
err := tx.Q().Where("parent = ? and revoked = false", token.HashedToken).First(refreshToken)
if err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, RefreshTokenNotFoundError{}
Expand All @@ -99,13 +100,18 @@ func GetValidChildToken(tx *storage.Connection, token *RefreshToken) (*RefreshTo
}

func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshToken, params *GrantParams) (*RefreshToken, error) {
secureToken := crypto.GenerateSecureToken()

token := &RefreshToken{
UserID: user.ID,
Token: crypto.SecureToken(),
Parent: "",
Token: secureToken.Original, // returning the original token
// storing the hashed token, with H: prefix that identifies
// hashed values (for backward compatibility)
HashedToken: "H:" + secureToken.Hashed,
Parent: "",
}
if oldToken != nil {
token.Parent = storage.NullString(oldToken.Token)
token.Parent = storage.NullString(oldToken.HashedToken)
token.SessionId = oldToken.SessionId
}

Expand All @@ -121,6 +127,7 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
if err := tx.Create(token); err != nil {
return nil, errors.Wrap(err, "error creating refresh token")
}
token.Token = secureToken.Original

if err := user.UpdateLastSignInAt(tx); err != nil {
return nil, errors.Wrap(err, "error update user`s last_sign_in field")
Expand Down
19 changes: 17 additions & 2 deletions models/refresh_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/netlify/gotrue/conf"
"github.com/netlify/gotrue/crypto"
"github.com/netlify/gotrue/storage"
"github.com/netlify/gotrue/storage/test"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -41,20 +42,34 @@ func (ts *RefreshTokenTestSuite) TestGrantAuthenticatedUser() {
require.NoError(ts.T(), err)

require.NotEmpty(ts.T(), r.Token)
require.NotEmpty(ts.T(), r.HashedToken)
require.NotEqual(ts.T(), r.Token, r.HashedToken)
require.Equal(ts.T(), r.HashedToken, "H:"+crypto.HashSHA224Base64(r.Token))
require.Equal(ts.T(), u.ID, r.UserID)
}

func (ts *RefreshTokenTestSuite) TestGrantRefreshTokenSwap() {
u := ts.createUser()
r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{})
require.NoError(ts.T(), err)
require.NotEmpty(ts.T(), r.Token)
require.NotEmpty(ts.T(), r.HashedToken)

s, err := GrantRefreshTokenSwap(&http.Request{}, ts.db, u, r)
require.NoError(ts.T(), err)
require.NotEmpty(ts.T(), s.Token)
require.NotEmpty(ts.T(), s.HashedToken)
require.NotEqual(ts.T(), s.Token, s.HashedToken)
require.Equal(ts.T(), s.Parent, storage.NullString(r.HashedToken))

_, nr, err := FindUserWithRefreshToken(ts.db, r.Token)
_, nr, err := FindUserWithRefreshToken(ts.db, r.Token) // using the original not hashed token
require.NoError(ts.T(), err)

require.Equal(ts.T(), nr.Token, r.Token)
require.NotEmpty(ts.T(), nr.HashedToken)
require.NotEqual(ts.T(), nr.Token, nr.HashedToken)
require.Equal(ts.T(), nr.HashedToken, "H:"+crypto.HashSHA224Base64(r.Token))

require.Equal(ts.T(), r.ID, nr.ID)
require.True(ts.T(), nr.Revoked, "expected old token to be revoked")

Expand All @@ -68,7 +83,7 @@ func (ts *RefreshTokenTestSuite) TestLogout() {
require.NoError(ts.T(), err)

require.NoError(ts.T(), Logout(ts.db, u.ID))
u, r, err = FindUserWithRefreshToken(ts.db, r.Token)
u, r, err = FindUserWithRefreshToken(ts.db, r.Token) // using the original not hashed token
require.Errorf(ts.T(), err, "expected error when there are no refresh tokens to authenticate. user: %v token: %v", u, r)

require.True(ts.T(), IsNotFoundError(err), "expected NotFoundError")
Expand Down
15 changes: 14 additions & 1 deletion models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/gobuffalo/pop/v5"
"github.com/gofrs/uuid"
"github.com/netlify/gotrue/crypto"
"github.com/netlify/gotrue/storage"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
Expand Down Expand Up @@ -371,14 +372,26 @@ func FindUserByTokenAndTokenType(tx *storage.Connection, token string, tokenType

// FindUserWithRefreshToken finds a user from the provided refresh token.
func FindUserWithRefreshToken(tx *storage.Connection, token string) (*User, *RefreshToken, error) {
// it may just be that an attacker has stolen a hashed token value from the database
// (which is identified by the H: prefix). this attacker would be able to look up the
// exact token unless we remove the `H:` prefix.
// technically this input value should have been rejected
// before arriving in this function, but it is now too late to
// reject and is best if the data is sanitized.
token = strings.TrimPrefix(token, "H:")

hashedToken := "H:" + crypto.HashSHA224Base64(token)

refreshToken := &RefreshToken{}
if err := tx.Where("token = ?", token).First(refreshToken); err != nil {
if err := tx.Where("token = ? OR token = ?", hashedToken, token).First(refreshToken); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, nil, RefreshTokenNotFoundError{}
}
return nil, nil, errors.Wrap(err, "error finding refresh token")
}

refreshToken.Token = token // database only holds the hashed token, so adding the original here

user, err := findUser(tx, "id = ?", refreshToken.UserID)
if err != nil {
return nil, nil, err
Expand Down

0 comments on commit d1605c0

Please sign in to comment.