Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: pass GrantParams object to issueRefreshToken #662

Merged
merged 1 commit into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
providerType := getExternalProviderType(ctx)
var userData *provider.UserProvidedData
var providerAccessToken string
var providerRefreshToken string = ""
var providerRefreshToken string
var grantParams models.GrantParams

if providerType == "twitter" {
// future OAuth1.0 providers will use this method
oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType)
Expand Down Expand Up @@ -275,7 +277,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
}
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return oauthError("server_error", terr.Error())
}
Expand Down
3 changes: 2 additions & 1 deletion api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}

var user *models.User
var grantParams models.GrantParams
params.Aud = a.requestAud(ctx, r)

switch params.Provider {
Expand Down Expand Up @@ -218,7 +219,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
return terr
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return terr
}
Expand Down
10 changes: 6 additions & 4 deletions api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
return unprocessableEntityError("Only an email address or phone number should be provided on login.")
}
var user *models.User
var grantParams models.GrantParams
var provider string
if params.Email != "" {
provider = "email"
Expand Down Expand Up @@ -218,7 +219,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
return terr
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return terr
}
Expand Down Expand Up @@ -419,6 +420,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
}

var user *models.User
var grantParams models.GrantParams
var token *AccessTokenResponse
err = a.db.Transaction(func(tx *storage.Connection) error {
var terr error
Expand Down Expand Up @@ -507,7 +509,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
}
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return oauthError("server_error", terr.Error())
}
Expand Down Expand Up @@ -545,7 +547,7 @@ func generateAccessToken(user *models.User, sessionId string, expiresIn time.Dur
return token.SignedString([]byte(secret))
}

func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User) (*AccessTokenResponse, error) {
func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User, grantParams models.GrantParams) (*AccessTokenResponse, error) {
config := a.config

now := time.Now()
Expand All @@ -556,7 +558,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u

err := conn.Transaction(func(tx *storage.Connection) error {
var terr error
refreshToken, terr = models.GrantAuthenticatedUser(tx, user, models.GrantParams{})
refreshToken, terr = models.GrantAuthenticatedUser(tx, user, grantParams)
if terr != nil {
return internalServerError("Database error granting user").WithInternalError(terr)
}
Expand Down
16 changes: 9 additions & 7 deletions api/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request) error {
params.RedirectTo = a.getRedirectURLOrReferrer(r, r.FormValue("redirect_to"))

var (
user *models.User
err error
token *AccessTokenResponse
user *models.User
grantParams models.GrantParams
err error
token *AccessTokenResponse
)

err = a.db.Transaction(func(tx *storage.Connection) error {
Expand Down Expand Up @@ -121,7 +122,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request) error {
return terr
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return terr
}
Expand Down Expand Up @@ -182,8 +183,9 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request) error {
}

var (
user *models.User
token *AccessTokenResponse
user *models.User
grantParams models.GrantParams
token *AccessTokenResponse
)

err = a.db.Transaction(func(tx *storage.Connection) error {
Expand Down Expand Up @@ -217,7 +219,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request) error {
return terr
}

token, terr = a.issueRefreshToken(ctx, tx, user)
token, terr = a.issueRefreshToken(ctx, tx, user, grantParams)
if terr != nil {
return terr
}
Expand Down