diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index ef10216477..02912e6da1 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -99,8 +99,8 @@ func NewPersister(c *pop.Connection, r Dependencies, config configuration.Provid } func (p *Persister) Connection(ctx context.Context) *pop.Connection { - if c := ctx.Value(transactionContextKey); c != nil { - return c.(*pop.Connection).WithContext(ctx) + if c, ok := ctx.Value(transactionContextKey).(*pop.Connection); ok { + return c.WithContext(ctx) } return p.conn.WithContext(ctx) } diff --git a/persistence/sql/persister_consent.go b/persistence/sql/persister_consent.go index afc2fd132a..d5cbd35d0c 100644 --- a/persistence/sql/persister_consent.go +++ b/persistence/sql/persister_consent.go @@ -301,27 +301,23 @@ func (p *Persister) FindGrantedAndRememberedConsentRequests(ctx context.Context, } func (p *Persister) FindSubjectsGrantedConsentRequests(ctx context.Context, subject string, limit, offset int) ([]consent.HandledConsentRequest, error) { - rs := make([]consent.HandledConsentRequest, 0) - - return rs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - tn := consent.HandledConsentRequest{}.TableName() + var rs []consent.HandledConsentRequest + c := p.Connection(ctx) + tn := consent.HandledConsentRequest{}.TableName() - if err := c. - Where(fmt.Sprintf("r.subject = ? AND r.skip=FALSE AND %s.error='{}'", tn), subject). - Join("hydra_oauth2_consent_request AS r", fmt.Sprintf("%s.challenge = r.challenge", tn)). - Order(fmt.Sprintf("%s.requested_at DESC", tn)). - Paginate(offset/limit+1, limit). - All(&rs); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return errors.WithStack(consent.ErrNoPreviousConsentFound) - } - return sqlcon.HandleError(err) + if err := c. + Where(fmt.Sprintf("r.subject = ? AND r.skip=FALSE AND %s.error='{}'", tn), subject). + Join("hydra_oauth2_consent_request AS r", fmt.Sprintf("%s.challenge = r.challenge", tn)). + Order(fmt.Sprintf("%s.requested_at DESC", tn)). + Paginate(offset/limit+1, limit). + All(&rs); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.WithStack(consent.ErrNoPreviousConsentFound) } + return nil, sqlcon.HandleError(err) + } - var err error - rs, err = p.resolveHandledConsentRequests(ctx, rs) - return err - }) + return p.resolveHandledConsentRequests(ctx, rs) } func (p *Persister) CountSubjectsGrantedConsentRequests(ctx context.Context, subject string) (int, error) { @@ -336,31 +332,31 @@ func (p *Persister) CountSubjectsGrantedConsentRequests(ctx context.Context, sub func (p *Persister) resolveHandledConsentRequests(ctx context.Context, requests []consent.HandledConsentRequest) ([]consent.HandledConsentRequest, error) { var result []consent.HandledConsentRequest - return result, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { - for _, v := range requests { - _, err := p.GetConsentRequest(ctx, v.ID) - if errors.Is(err, sqlcon.ErrNoRows) || errors.Is(err, x.ErrNotFound) { - return errors.WithStack(consent.ErrNoPreviousConsentFound) - } else if err != nil { - return err - } - - // this will probably never error because we first check if the consent request actually exists - if err := v.AfterFind(p.Connection(ctx)); err != nil { - return err - } - if v.RememberFor > 0 && v.RequestedAt.Add(time.Duration(v.RememberFor)*time.Second).Before(time.Now().UTC()) { - continue - } - result = append(result, v) + for _, v := range requests { + _, err := p.GetConsentRequest(ctx, v.ID) + if errors.Is(err, sqlcon.ErrNoRows) || errors.Is(err, x.ErrNotFound) { + return nil, errors.WithStack(consent.ErrNoPreviousConsentFound) + } else if err != nil { + return nil, err } - if len(result) == 0 { - return errors.WithStack(consent.ErrNoPreviousConsentFound) + // This will probably never error because we first check if the consent request actually exists + if err := v.AfterFind(p.Connection(ctx)); err != nil { + return nil, err } - return nil - }) + if v.RememberFor > 0 && v.RequestedAt.Add(time.Duration(v.RememberFor)*time.Second).Before(time.Now().UTC()) { + continue + } + + result = append(result, v) + } + + if len(result) == 0 { + return nil, errors.WithStack(consent.ErrNoPreviousConsentFound) + } + + return result, nil } func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) {