diff --git a/oauth2/trust/handler.go b/oauth2/trust/handler.go index 453ab376975..d46e6ca371c 100644 --- a/oauth2/trust/handler.go +++ b/oauth2/trust/handler.go @@ -8,6 +8,7 @@ import ( "net/http" "time" + "github.com/ory/fosite" "github.com/ory/x/pagination/tokenpagination" "github.com/ory/hydra/v2/x" @@ -110,7 +111,12 @@ func (h *Handler) trustOAuth2JwtGrantIssuer(w http.ResponseWriter, r *http.Reque var grantRequest createGrantRequest if err := json.NewDecoder(r.Body).Decode(&grantRequest); err != nil { - h.registry.Writer().WriteError(w, r, errorsx.WithStack(err)) + h.registry.Writer().WriteError(w, r, + errorsx.WithStack(&fosite.RFC6749Error{ + ErrorField: "error", + DescriptionField: err.Error(), + CodeField: http.StatusBadRequest, + })) return } diff --git a/oauth2/trust/handler_test.go b/oauth2/trust/handler_test.go index f5e2b48d6f0..4d6bc80ba7d 100644 --- a/oauth2/trust/handler_test.go +++ b/oauth2/trust/handler_test.go @@ -9,11 +9,13 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" + "io" "net/http" "net/http/httptest" "testing" "time" + "github.com/tidwall/gjson" "gopkg.in/square/go-jose.v2" "github.com/ory/x/pointerx" @@ -154,6 +156,24 @@ func (s *HandlerTestSuite) TestGrantCanNotBeCreatedWithSubjectAndAnySubject() { s.Require().Error(err, "expected error, because a grant with a subject and allow_any_subject cannot be created") } +func (s *HandlerTestSuite) TestGrantCanNotBeCreatedWithUnknownJWK() { + createRequestParams := hydra.TrustOAuth2JwtGrantIssuer{ + AllowAnySubject: pointerx.Ptr(true), + ExpiresAt: time.Now().Add(1 * time.Hour), + Issuer: "ory", + Jwk: hydra.JsonWebKey{ + Alg: "unknown", + }, + Scope: []string{"openid", "offline", "profile"}, + } + + _, res, err := s.hydraClient.OAuth2Api.TrustOAuth2JwtGrantIssuer(context.Background()).TrustOAuth2JwtGrantIssuer(createRequestParams).Execute() + s.Assert().Equal(http.StatusBadRequest, res.StatusCode) + body, _ := io.ReadAll(res.Body) + s.Contains(gjson.GetBytes(body, "error_description").String(), "unknown json web key type") + s.Require().Error(err, "expected error, because the key type was unknown") +} + func (s *HandlerTestSuite) TestGrantCanNotBeCreatedWithMissingFields() { createRequestParams := s.newCreateJwtBearerGrantParams( "", diff --git a/persistence/sql/persister_grant_jwk.go b/persistence/sql/persister_grant_jwk.go index 28f06669726..a9bc2ed4444 100644 --- a/persistence/sql/persister_grant_jwk.go +++ b/persistence/sql/persister_grant_jwk.go @@ -14,6 +14,7 @@ import ( "gopkg.in/square/go-jose.v2" "github.com/ory/hydra/v2/oauth2/trust" + "github.com/ory/x/otelx" "github.com/ory/x/stringsx" "github.com/ory/x/sqlcon" @@ -21,9 +22,9 @@ import ( var _ trust.GrantManager = &Persister{} -func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jose.JSONWebKey) error { +func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jose.JSONWebKey) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateGrant") - defer span.End() + defer otelx.End(span, &err) return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { // add key, if it doesn't exist @@ -42,9 +43,9 @@ func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jo }) } -func (p *Persister) GetConcreteGrant(ctx context.Context, id string) (trust.Grant, error) { +func (p *Persister) GetConcreteGrant(ctx context.Context, id string) (_ trust.Grant, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteGrant") - defer span.End() + defer otelx.End(span, &err) var data trust.SQLData if err := p.QueryWithNetwork(ctx).Where("id = ?", id).First(&data); err != nil { @@ -54,9 +55,9 @@ func (p *Persister) GetConcreteGrant(ctx context.Context, id string) (trust.Gran return p.jwtGrantFromSQlData(data), nil } -func (p *Persister) DeleteGrant(ctx context.Context, id string) error { +func (p *Persister) DeleteGrant(ctx context.Context, id string) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteGrant") - defer span.End() + defer otelx.End(span, &err) return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { grant, err := p.GetConcreteGrant(ctx, id) @@ -72,9 +73,9 @@ func (p *Persister) DeleteGrant(ctx context.Context, id string) error { }) } -func (p *Persister) GetGrants(ctx context.Context, limit, offset int, optionalIssuer string) ([]trust.Grant, error) { +func (p *Persister) GetGrants(ctx context.Context, limit, offset int, optionalIssuer string) (_ []trust.Grant, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetGrants") - defer span.End() + defer otelx.End(span, &err) grantsData := make([]trust.SQLData, 0) @@ -97,18 +98,18 @@ func (p *Persister) GetGrants(ctx context.Context, limit, offset int, optionalIs return grants, nil } -func (p *Persister) CountGrants(ctx context.Context) (int, error) { +func (p *Persister) CountGrants(ctx context.Context) (n int, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CountGrants") - defer span.End() + defer otelx.End(span, &err) - n, err := p.QueryWithNetwork(ctx). + n, err = p.QueryWithNetwork(ctx). Count(&trust.SQLData{}) return n, sqlcon.HandleError(err) } -func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject string, keyId string) (*jose.JSONWebKey, error) { +func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject string, keyId string) (_ *jose.JSONWebKey, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKey") - defer span.End() + defer otelx.End(span, &err) var data trust.SQLData query := p.QueryWithNetwork(ctx). @@ -128,9 +129,9 @@ func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject str return &keySet.Keys[0], nil } -func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject string) (*jose.JSONWebKeySet, error) { +func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject string) (_ *jose.JSONWebKeySet, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeys") - defer span.End() + defer otelx.End(span, &err) grantsData := make([]trust.SQLData, 0) query := p.QueryWithNetwork(ctx). @@ -163,9 +164,9 @@ func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject st return filteredKeySet, nil } -func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) ([]string, error) { +func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) (_ []string, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeyScopes") - defer span.End() + defer otelx.End(span, &err) var data trust.SQLData query := p.QueryWithNetwork(ctx). @@ -181,11 +182,11 @@ func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subje return p.jwtGrantFromSQlData(data).Scope, nil } -func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (bool, error) { +func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (ok bool, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.IsJWTUsed") - defer span.End() + defer otelx.End(span, &err) - err := p.ClientAssertionJWTValid(ctx, jti) + err = p.ClientAssertionJWTValid(ctx, jti) if err != nil { return true, nil } @@ -193,9 +194,9 @@ func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (bool, error) { return false, nil } -func (p *Persister) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) error { +func (p *Persister) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MarkJWTUsedForTime") - defer span.End() + defer otelx.End(span, &err) return p.SetClientAssertionJWT(ctx, jti, exp) } @@ -230,9 +231,9 @@ func (p *Persister) jwtGrantFromSQlData(data trust.SQLData) trust.Grant { } } -func (p *Persister) FlushInactiveGrants(ctx context.Context, notAfter time.Time, limit int, batchSize int) error { +func (p *Persister) FlushInactiveGrants(ctx context.Context, notAfter time.Time, _ int, _ int) (err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveGrants") - defer span.End() + defer otelx.End(span, &err) deleteUntil := time.Now().UTC() if deleteUntil.After(notAfter) {