diff --git a/oauth2/handler.go b/oauth2/handler.go index 3dfe8130f9..54c6573038 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -1340,8 +1340,14 @@ func (h *Handler) createVerifiableCredential(w http.ResponseWriter, r *http.Requ "id": fmt.Sprintf("did:jwk:%s", base64.RawURLEncoding.EncodeToString(proofJWKJSON)), }, }) - - rawToken, _, err := h.r.OpenIDJWTStrategy().Generate(ctx, session.Claims.ToMapClaims(), jwt.NewHeaders()) + signingKeyID, err := h.r.OpenIDJWTStrategy().GetPublicKeyID(ctx) + if err != nil { + h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) + return + } + headers := jwt.NewHeaders() + headers.Add("kid", signingKeyID) + rawToken, _, err := h.r.OpenIDJWTStrategy().Generate(ctx, session.Claims.ToMapClaims(), headers) if err != nil { h.r.Writer().WriteError(w, r, errorsx.WithStack(err)) return diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index a0ccc53c6a..e94eecba9d 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -8,6 +8,7 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -1164,6 +1165,18 @@ func assertCreateVerifiableCredential(t *testing.T, reg driver.Registry, nonce s func assertVerifiableCredentialContainsPublicKey(t *testing.T, reg driver.Registry, vc *hydraoauth2.VerifiableCredentialResponse, pubKeyJWK *jose.JSONWebKey) { ctx := context.Background() token, err := jwt.Parse(vc.Credential, func(token *jwt.Token) (interface{}, error) { + kid, found := token.Header["kid"] + if !found { + return nil, errors.New("missing kid header") + } + openIDKey, err := reg.OpenIDJWTStrategy().GetPublicKeyID(ctx) + if err != nil { + return nil, err + } + if kid != openIDKey { + return nil, errors.New("invalid kid header") + } + return x.Must(reg.OpenIDJWTStrategy().GetPublicKey(ctx)).Key, nil }) require.NoError(t, err)