diff --git a/internal/api/admin.go b/internal/api/admin.go index 0ff63a841..0d53406de 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -514,6 +514,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { user := getUser(ctx) config := a.config adminUser := getAdminUser(ctx) + db := a.db.WithContext(ctx) // ShouldSoftDelete defaults to false params := &adminUserDeleteParams{} @@ -525,7 +526,7 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { } } - err := a.db.Transaction(func(tx *storage.Connection) error { + err := db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, adminUser, models.UserDeletedAction, "", map[string]interface{}{ "user_id": user.ID, "user_email": user.Email, @@ -575,8 +576,9 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro config := a.config user := getUser(ctx) factor := getFactor(ctx) + db := a.db.WithContext(ctx) - err := a.db.Transaction(func(tx *storage.Connection) error { + err := db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.DeleteFactorAction, r.RemoteAddr, map[string]interface{}{ "user_id": user.ID, "factor_id": factor.ID, @@ -608,12 +610,13 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro user := getUser(ctx) adminUser := getAdminUser(ctx) params := &adminUserUpdateFactorParams{} + db := a.db.WithContext(ctx) if err := retrieveRequestParams(r, params); err != nil { return err } - err := a.db.Transaction(func(tx *storage.Connection) error { + err := db.Transaction(func(tx *storage.Connection) error { if params.FriendlyName != "" { if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil { return terr diff --git a/internal/api/external.go b/internal/api/external.go index 42c59f38d..ae61bc86b 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -83,7 +83,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ flowStateID := "" if isPKCEFlow(flowType) { - flowState, err := generateFlowState(a.db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil) + flowState, err := generateFlowState(db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil) if err != nil { return "", err } @@ -200,7 +200,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re var flowState *models.FlowState // if there's a non-empty FlowStateID we perform PKCE Flow if flowStateID := getFlowStateID(ctx); flowStateID != "" { - flowState, err = models.FindFlowStateByID(a.db, flowStateID) + flowState, err = models.FindFlowStateByID(db, flowStateID) if models.IsNotFoundError(err) { return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) } else if err != nil { @@ -506,7 +506,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p return user, nil } -func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) { +func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storage.Connection) (context.Context, error) { var state string switch r.Method { case http.MethodPost: @@ -564,7 +564,7 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C if err != nil { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)") } - u, err := models.FindUserByID(a.db, linkingTargetUserID) + u, err := models.FindUserByID(db, linkingTargetUserID) if err != nil { if models.IsNotFoundError(err) { return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeUserNotFound, "Linking target user not found") diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index f52ea53ac..a02623a38 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -27,6 +27,8 @@ type OAuthProviderData struct { // extracting the provider requested func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) { ctx := r.Context() + db := a.db.WithContext(ctx) + oauthToken := r.URL.Query().Get("oauth_token") if oauthToken != "" { ctx = withRequestToken(ctx, oauthToken) @@ -37,7 +39,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con } var err error - ctx, err = a.loadExternalState(ctx, r) + ctx, err = a.loadExternalState(ctx, r, db) if err != nil { u, uerr := url.ParseRequestURI(a.config.SiteURL) if uerr != nil { diff --git a/internal/api/hooks.go b/internal/api/hooks.go index f850ab1ad..95c2367c8 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -14,31 +14,31 @@ import ( func (a *API) triggerBeforeUserCreated( r *http.Request, - conn *storage.Connection, + db *storage.Connection, user *models.User, ) error { if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) { return nil } - if err := checkTX(conn); err != nil { + if err := checkTX(db); err != nil { return err } req := v0hooks.NewBeforeUserCreatedInput(r, user) res := new(v0hooks.BeforeUserCreatedOutput) - return a.hooksMgr.InvokeHook(conn, r, req, res) + return a.hooksMgr.InvokeHook(db, r, req, res) } func (a *API) triggerBeforeUserCreatedExternal( r *http.Request, - conn *storage.Connection, + db *storage.Connection, userData *provider.UserProvidedData, providerType string, ) error { if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) { return nil } - if err := checkTX(conn); err != nil { + if err := checkTX(db); err != nil { return err } @@ -55,7 +55,7 @@ func (a *API) triggerBeforeUserCreatedExternal( err error decision models.AccountLinkingResult ) - err = a.db.Transaction(func(tx *storage.Connection) error { + err = db.Transaction(func(tx *storage.Connection) error { decision, err = models.DetermineAccountLinking( tx, config, userData.Emails, aud, providerType, userData.Metadata.Subject) @@ -93,7 +93,7 @@ func (a *API) triggerBeforeUserCreatedExternal( if err != nil { return err } - return a.triggerBeforeUserCreated(r, conn, user) + return a.triggerBeforeUserCreated(r, db, user) } func checkTX(conn *storage.Connection) error { diff --git a/internal/api/identity.go b/internal/api/identity.go index 4c26bccbd..5bfa2a767 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -15,6 +15,7 @@ import ( func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() + db := a.db.WithContext(ctx) config := a.config claims := getClaims(ctx) @@ -49,7 +50,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeIdentityNotFound, "Identity doesn't exist") } - err = a.db.Transaction(func(tx *storage.Connection) error { + err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(config.AuditLog, r, tx, user, models.IdentityUnlinkAction, "", map[string]interface{}{ "identity_id": identityToBeDeleted.ID, "provider": identityToBeDeleted.Provider, diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index ee2c1a532..2393059de 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -130,7 +130,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } if isPKCEFlow(flowType) { - if _, err = generateFlowState(a.db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil { + if _, err = generateFlowState(db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil { return err } } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index b49271113..a1e1de0bf 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -418,7 +418,7 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error }, } output := v0hooks.SendSMSOutput{} - err := a.hooksMgr.InvokeHook(a.db, r, &input, &output) + err := a.hooksMgr.InvokeHook(db, r, &input, &output) if err != nil { return apierrors.NewInternalServerError("error invoking hook") } diff --git a/internal/api/token.go b/internal/api/token.go index 4925290b2..a2b9bdcd1 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -165,7 +165,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri output.Message = v0hooks.DefaultPasswordHookRejectionMessage } if output.ShouldLogoutUser { - if err := models.Logout(a.db, user.ID); err != nil { + if err := models.Logout(db, user.ID); err != nil { return err } } diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 16395a573..2ba3dbfd2 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -19,7 +19,8 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h return err } - tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, a.db, r, tokens.RefreshTokenGrantParams{ + db := a.db.WithContext(ctx) + tokenResponse, err := a.tokenService.RefreshTokenGrant(ctx, db, r, tokens.RefreshTokenGrantParams{ RefreshToken: params.RefreshToken, }) if err != nil {