Skip to content

Commit

Permalink
refactor: pass GrantParams object to issueRefreshToken (#662)
Browse files Browse the repository at this point in the history
  • Loading branch information
hf authored Sep 7, 2022
1 parent e62c44c commit 8b1ec24
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 14 deletions.
6 changes: 4 additions & 2 deletions api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
providerType := getExternalProviderType(ctx)
var userData *provider.UserProvidedData
var providerAccessToken string
var providerRefreshToken string = ""
var providerRefreshToken string
var grantParams models.GrantParams

if providerType == "twitter" {
// future OAuth1.0 providers will use this method
oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType)
Expand Down Expand Up @@ -275,7 +277,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
}
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return oauthError("server_error", terr.Error())
}
Expand Down
3 changes: 2 additions & 1 deletion api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}

var user *models.User
var grantParams models.GrantParams
params.Aud = a.requestAud(ctx, r)

switch params.Provider {
Expand Down Expand Up @@ -218,7 +219,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
return terr
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return terr
}
Expand Down
10 changes: 6 additions & 4 deletions api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
return unprocessableEntityError("Only an email address or phone number should be provided on login.")
}
var user *models.User
var grantParams models.GrantParams
var provider string
if params.Email != "" {
provider = "email"
Expand Down Expand Up @@ -219,7 +220,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
return terr
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return terr
}
Expand Down Expand Up @@ -420,6 +421,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
}

var user *models.User
var grantParams models.GrantParams
var token *AccessTokenResponse
err = a.db.Transaction(func(tx *storage.Connection) error {
var terr error
Expand Down Expand Up @@ -508,7 +510,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
}
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return oauthError("server_error", terr.Error())
}
Expand Down Expand Up @@ -550,7 +552,7 @@ func generateAccessToken(user *models.User, sessionId *uuid.UUID, expiresIn time
return token.SignedString([]byte(secret))
}

func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User) (*AccessTokenResponse, error) {
func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User, grantParams models.GrantParams) (*AccessTokenResponse, error) {
config := a.config

now := time.Now()
Expand All @@ -561,7 +563,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u

err := conn.Transaction(func(tx *storage.Connection) error {
var terr error
refreshToken, terr = models.GrantAuthenticatedUser(tx, user, models.GrantParams{})
refreshToken, terr = models.GrantAuthenticatedUser(tx, user, grantParams)
if terr != nil {
return internalServerError("Database error granting user").WithInternalError(terr)
}
Expand Down
16 changes: 9 additions & 7 deletions api/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request) error {
params.RedirectTo = a.getRedirectURLOrReferrer(r, r.FormValue("redirect_to"))

var (
user *models.User
err error
token *AccessTokenResponse
user *models.User
grantParams models.GrantParams
err error
token *AccessTokenResponse
)

err = a.db.Transaction(func(tx *storage.Connection) error {
Expand Down Expand Up @@ -121,7 +122,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request) error {
return terr
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return terr
}
Expand Down Expand Up @@ -182,8 +183,9 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request) error {
}

var (
user *models.User
token *AccessTokenResponse
user *models.User
grantParams models.GrantParams
token *AccessTokenResponse
)

err = a.db.Transaction(func(tx *storage.Connection) error {
Expand Down Expand Up @@ -217,7 +219,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request) error {
return terr
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return terr
}
Expand Down

0 comments on commit 8b1ec24

Please sign in to comment.