Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
user := getUser(ctx)
config := a.config
adminUser := getAdminUser(ctx)
db := a.db.WithContext(ctx)

// ShouldSoftDelete defaults to false
params := &adminUserDeleteParams{}
Expand All @@ -525,7 +526,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
}
}

err := a.db.Transaction(func(tx *storage.Connection) error {
err := db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, adminUser, models.UserDeletedAction, "", map[string]interface{}{
"user_id": user.ID,
"user_email": user.Email,
Expand Down Expand Up @@ -575,8 +576,9 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro
config := a.config
user := getUser(ctx)
factor := getFactor(ctx)
db := a.db.WithContext(ctx)

err := a.db.Transaction(func(tx *storage.Connection) error {
err := db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.DeleteFactorAction, r.RemoteAddr, map[string]interface{}{
"user_id": user.ID,
"factor_id": factor.ID,
Expand Down Expand Up @@ -608,12 +610,13 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
user := getUser(ctx)
adminUser := getAdminUser(ctx)
params := &adminUserUpdateFactorParams{}
db := a.db.WithContext(ctx)

if err := retrieveRequestParams(r, params); err != nil {
return err
}

err := a.db.Transaction(func(tx *storage.Connection) error {
err := db.Transaction(func(tx *storage.Connection) error {
if params.FriendlyName != "" {
if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil {
return terr
Expand Down
8 changes: 4 additions & 4 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ

flowStateID := ""
if isPKCEFlow(flowType) {
flowState, err := generateFlowState(a.db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
flowState, err := generateFlowState(db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -200,7 +200,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
var flowState *models.FlowState
// if there's a non-empty FlowStateID we perform PKCE Flow
if flowStateID := getFlowStateID(ctx); flowStateID != "" {
flowState, err = models.FindFlowStateByID(a.db, flowStateID)
flowState, err = models.FindFlowStateByID(db, flowStateID)
if models.IsNotFoundError(err) {
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err)
} else if err != nil {
Expand Down Expand Up @@ -506,7 +506,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p
return user, nil
}

func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) {
func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storage.Connection) (context.Context, error) {
var state string
switch r.Method {
case http.MethodPost:
Expand Down Expand Up @@ -564,7 +564,7 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)")
}
u, err := models.FindUserByID(a.db, linkingTargetUserID)
u, err := models.FindUserByID(db, linkingTargetUserID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found")
Expand Down
4 changes: 3 additions & 1 deletion internal/api/external_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type OAuthProviderData struct {
// extracting the provider requested
func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
ctx := r.Context()
db := a.db.WithContext(ctx)

oauthToken := r.URL.Query().Get("oauth_token")
if oauthToken != "" {
ctx = withRequestToken(ctx, oauthToken)
Expand All @@ -37,7 +39,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con
}

var err error
ctx, err = a.loadExternalState(ctx, r)
ctx, err = a.loadExternalState(ctx, r, db)
if err != nil {
u, uerr := url.ParseRequestURI(a.config.SiteURL)
if uerr != nil {
Expand Down
14 changes: 7 additions & 7 deletions internal/api/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,31 @@ import (

func (a *API) triggerBeforeUserCreated(
r *http.Request,
conn *storage.Connection,
db *storage.Connection,
user *models.User,
) error {
if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) {
return nil
}
if err := checkTX(conn); err != nil {
if err := checkTX(db); err != nil {
return err
}

req := v0hooks.NewBeforeUserCreatedInput(r, user)
res := new(v0hooks.BeforeUserCreatedOutput)
return a.hooksMgr.InvokeHook(conn, r, req, res)
return a.hooksMgr.InvokeHook(db, r, req, res)
}

func (a *API) triggerBeforeUserCreatedExternal(
r *http.Request,
conn *storage.Connection,
db *storage.Connection,
userData *provider.UserProvidedData,
providerType string,
) error {
if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) {
return nil
}
if err := checkTX(conn); err != nil {
if err := checkTX(db); err != nil {
return err
}

Expand All @@ -55,7 +55,7 @@ func (a *API) triggerBeforeUserCreatedExternal(
err error
decision models.AccountLinkingResult
)
err = a.db.Transaction(func(tx *storage.Connection) error {
err = db.Transaction(func(tx *storage.Connection) error {
decision, err = models.DetermineAccountLinking(
tx, config, userData.Emails, aud,
providerType, userData.Metadata.Subject)
Expand Down Expand Up @@ -93,7 +93,7 @@ func (a *API) triggerBeforeUserCreatedExternal(
if err != nil {
return err
}
return a.triggerBeforeUserCreated(r, conn, user)
return a.triggerBeforeUserCreated(r, db, user)
}

func checkTX(conn *storage.Connection) error {
Expand Down
3 changes: 2 additions & 1 deletion internal/api/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config

claims := getClaims(ctx)
Expand Down Expand Up @@ -49,7 +50,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error {
return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist")
}

err = a.db.Transaction(func(tx *storage.Connection) error {
err = db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.IdentityUnlinkAction, "", map[string]interface{}{
"identity_id": identityToBeDeleted.ID,
"provider": identityToBeDeleted.Provider,
Expand Down
2 changes: 1 addition & 1 deletion internal/api/magic_link.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error {
}

if isPKCEFlow(flowType) {
if _, err = generateFlowState(a.db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil {
if _, err = generateFlowState(db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error
},
}
output := v0hooks.SendSMSOutput{}
err := a.hooksMgr.InvokeHook(a.db, r, &input, &output)
err := a.hooksMgr.InvokeHook(db, r, &input, &output)
if err != nil {
return apierrors.NewInternalServerError("error invoking hook")
}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
output.Message = v0hooks.DefaultPasswordHookRejectionMessage
}
if output.ShouldLogoutUser {
if err := models.Logout(a.db, user.ID); err != nil {
if err := models.Logout(db, user.ID); err != nil {
return err
}
}
Expand Down
3 changes: 2 additions & 1 deletion internal/api/token_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return err
}

tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, a.db, r, tokens.RefreshTokenGrantParams{
db := a.db.WithContext(ctx)
tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, db, r, tokens.RefreshTokenGrantParams{
RefreshToken: params.RefreshToken,
})
if err != nil {
Expand Down