diff --git a/api/external.go b/api/external.go index e005fb4e1..5fa3b0a9a 100644 --- a/api/external.go +++ b/api/external.go @@ -110,7 +110,9 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re providerType := getExternalProviderType(ctx) var userData *provider.UserProvidedData var providerAccessToken string - var providerRefreshToken string = "" + var providerRefreshToken string + var grantParams models.GrantParams + if providerType == "twitter" { // future OAuth1.0 providers will use this method oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType) @@ -275,7 +277,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re } } - token, terr = a.issueRefreshToken(ctx, tx, user) + token, terr = a.issueRefreshToken(ctx, tx, user, grantParams) if terr != nil { return oauthError("server_error", terr.Error()) } diff --git a/api/signup.go b/api/signup.go index a2ea8f5a7..175aecda9 100644 --- a/api/signup.go +++ b/api/signup.go @@ -64,6 +64,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } var user *models.User + var grantParams models.GrantParams params.Aud = a.requestAud(ctx, r) switch params.Provider { @@ -218,7 +219,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return terr } - token, terr = a.issueRefreshToken(ctx, tx, user) + token, terr = a.issueRefreshToken(ctx, tx, user, grantParams) if terr != nil { return terr } diff --git a/api/token.go b/api/token.go index cd5356700..a4685befa 100644 --- a/api/token.go +++ b/api/token.go @@ -171,6 +171,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri return unprocessableEntityError("Only an email address or phone number should be provided on login.") } var user *models.User + var grantParams models.GrantParams var provider string if params.Email != "" { provider = "email" @@ -218,7 +219,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri return terr } - token, terr = a.issueRefreshToken(ctx, tx, user) + token, terr = a.issueRefreshToken(ctx, tx, user, grantParams) if terr != nil { return terr } @@ -419,6 +420,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R } var user *models.User + var grantParams models.GrantParams var token *AccessTokenResponse err = a.db.Transaction(func(tx *storage.Connection) error { var terr error @@ -507,7 +509,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R } } - token, terr = a.issueRefreshToken(ctx, tx, user) + token, terr = a.issueRefreshToken(ctx, tx, user, grantParams) if terr != nil { return oauthError("server_error", terr.Error()) } @@ -545,7 +547,7 @@ func generateAccessToken(user *models.User, sessionId string, expiresIn time.Dur return token.SignedString([]byte(secret)) } -func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User) (*AccessTokenResponse, error) { +func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, user *models.User, grantParams models.GrantParams) (*AccessTokenResponse, error) { config := a.config now := time.Now() @@ -556,7 +558,7 @@ func (a *API) issueRefreshToken(ctx context.Context, conn *storage.Connection, u err := conn.Transaction(func(tx *storage.Connection) error { var terr error - refreshToken, terr = models.GrantAuthenticatedUser(tx, user, models.GrantParams{}) + refreshToken, terr = models.GrantAuthenticatedUser(tx, user, grantParams) if terr != nil { return internalServerError("Database error granting user").WithInternalError(terr) } diff --git a/api/verify.go b/api/verify.go index 07890cf40..c46827def 100644 --- a/api/verify.go +++ b/api/verify.go @@ -77,9 +77,10 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request) error { params.RedirectTo = a.getRedirectURLOrReferrer(r, r.FormValue("redirect_to")) var ( - user *models.User - err error - token *AccessTokenResponse + user *models.User + grantParams models.GrantParams + err error + token *AccessTokenResponse ) err = a.db.Transaction(func(tx *storage.Connection) error { @@ -121,7 +122,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request) error { return terr } - token, terr = a.issueRefreshToken(ctx, tx, user) + token, terr = a.issueRefreshToken(ctx, tx, user, grantParams) if terr != nil { return terr } @@ -182,8 +183,9 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request) error { } var ( - user *models.User - token *AccessTokenResponse + user *models.User + grantParams models.GrantParams + token *AccessTokenResponse ) err = a.db.Transaction(func(tx *storage.Connection) error { @@ -217,7 +219,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request) error { return terr } - token, terr = a.issueRefreshToken(ctx, tx, user) + token, terr = a.issueRefreshToken(ctx, tx, user, grantParams) if terr != nil { return terr }