Skip to content

Commit

Permalink
Add FetchJWTSVIDs function for workloadapi and jwtSource (#187)
Browse files Browse the repository at this point in the history
Signed-off-by: Yuhan Li <liyuhan.loveyana@bytedance.com>
  • Loading branch information
loveyana committed Apr 29, 2022
1 parent c882182 commit 23ed83e
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 15 deletions.
56 changes: 50 additions & 6 deletions v2/workloadapi/client.go
Expand Up @@ -163,10 +163,29 @@ func (c *Client) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwts
return nil, err
}

if len(resp.Svids) == 0 {
return nil, errors.New("there were no SVIDs in the response")
svids, err := parseJWTSVIDs(resp, audience, true)
if err != nil {
return nil, err
}

return svids[0], nil
}

// FetchJWTSVIDs fetches all JWT-SVIDs.
func (c *Client) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

audience := append([]string{params.Audience}, params.ExtraAudiences...)
resp, err := c.wlClient.FetchJWTSVID(ctx, &workload.JWTSVIDRequest{
SpiffeId: params.Subject.String(),
Audience: audience,
})
if err != nil {
return nil, err
}
return jwtsvid.ParseInsecure(resp.Svids[0].Svid, audience)

return parseJWTSVIDs(resp, audience, false)
}

// FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed
Expand Down Expand Up @@ -357,6 +376,9 @@ func parseX509Context(resp *workload.X509SVIDResponse) (*X509Context, error) {
// Otherwise all SVIDs are parsed and returned.
func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svid.SVID, error) {
n := len(resp.Svids)
if n == 0 {
return nil, errors.New("no SVIDs in response")
}
if firstOnly {
n = 1
}
Expand All @@ -371,9 +393,6 @@ func parseX509SVIDs(resp *workload.X509SVIDResponse, firstOnly bool) ([]*x509svi
svids = append(svids, s)
}

if len(svids) == 0 {
return nil, errors.New("no SVIDs in response")
}
return svids, nil
}

Expand Down Expand Up @@ -413,6 +432,31 @@ func parseX509Bundle(spiffeID string, bundle []byte) (*x509bundle.Bundle, error)
return x509bundle.FromX509Authorities(td, certs), nil
}

// parseJWTSVIDs parses one or all of the SVIDs in the response. If firstOnly
// is true, then only the first SVID in the response is parsed and returned.
// Otherwise all SVIDs are parsed and returned.
func parseJWTSVIDs(resp *workload.JWTSVIDResponse, audience []string, firstOnly bool) ([]*jwtsvid.SVID, error) {
n := len(resp.Svids)
if n == 0 {
return nil, errors.New("there were no SVIDs in the response")
}
if firstOnly {
n = 1
}

svids := make([]*jwtsvid.SVID, 0, n)
for i := 0; i < n; i++ {
svid := resp.Svids[i]
s, err := jwtsvid.ParseInsecure(svid.Svid, audience)
if err != nil {
return nil, err
}
svids = append(svids, s)
}

return svids, nil
}

func parseJWTSVIDBundles(resp *workload.JWTBundlesResponse) (*jwtbundle.Set, error) {
bundles := []*jwtbundle.Bundle{}

Expand Down
49 changes: 40 additions & 9 deletions v2/workloadapi/client_test.go
Expand Up @@ -232,8 +232,8 @@ func TestFetchJWTSVID(t *testing.T) {
subjectID := spiffeid.RequireFromPath(td, "/subject")
audienceID := spiffeid.RequireFromPath(td, "/audience")
extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience")
token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()})
respJWT := makeJWTSVIDResponse(subjectID.String(), token)
token := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal()
respJWT := makeJWTSVIDResponse([]string{token}, subjectID)
wl.SetJWTSVIDResponse(respJWT)

params := jwtsvid.Params{
Expand All @@ -245,7 +245,36 @@ func TestFetchJWTSVID(t *testing.T) {
jwtSvid, err := c.FetchJWTSVID(context.Background(), params)

require.NoError(t, err)
assertJWTSVID(t, jwtSvid, subjectID, token.Marshal(), audienceID.String(), extraAudienceID.String())
assertJWTSVID(t, jwtSvid, subjectID, token, audienceID.String(), extraAudienceID.String())
}

func TestFetchJWTSVIDs(t *testing.T) {
ca := test.NewCA(t, td)
wl := fakeworkloadapi.New(t)
defer wl.Stop()
c, _ := New(context.Background(), WithAddr(wl.Addr()))
defer c.Close()

subjectID := spiffeid.RequireFromPath(td, "/subject")
extraSubjectID := spiffeid.RequireFromPath(td, "/extra_subject")
audienceID := spiffeid.RequireFromPath(td, "/audience")
extraAudienceID := spiffeid.RequireFromPath(td, "/extra_audience")
subjectIDToken := ca.CreateJWTSVID(subjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal()
extraSubjectIDToken := ca.CreateJWTSVID(extraSubjectID, []string{audienceID.String(), extraAudienceID.String()}).Marshal()
respJWT := makeJWTSVIDResponse([]string{subjectIDToken, extraSubjectIDToken}, subjectID, extraSubjectID)
wl.SetJWTSVIDResponse(respJWT)

params := jwtsvid.Params{
Subject: subjectID,
Audience: audienceID.String(),
ExtraAudiences: []string{extraAudienceID.String()},
}

jwtSvid, err := c.FetchJWTSVIDs(context.Background(), params)

require.NoError(t, err)
assertJWTSVID(t, jwtSvid[0], subjectID, subjectIDToken, audienceID.String(), extraAudienceID.String())
assertJWTSVID(t, jwtSvid[1], extraSubjectID, extraSubjectIDToken, audienceID.String(), extraAudienceID.String())
}

func TestFetchJWTBundles(t *testing.T) {
Expand Down Expand Up @@ -357,12 +386,14 @@ func makeX509SVIDs(ca *test.CA, ids ...spiffeid.ID) []*x509svid.SVID {
return svids
}

func makeJWTSVIDResponse(spiffeID string, token *jwtsvid.SVID) *workload.JWTSVIDResponse {
svids := []*workload.JWTSVID{
{
SpiffeId: spiffeID,
Svid: token.Marshal(),
},
func makeJWTSVIDResponse(token []string, ids ...spiffeid.ID) *workload.JWTSVIDResponse {
svids := []*workload.JWTSVID{}
for i, id := range ids {
svid := &workload.JWTSVID{
SpiffeId: id.String(),
Svid: token[i],
}
svids = append(svids, svid)
}
return &workload.JWTSVIDResponse{
Svids: svids,
Expand Down
10 changes: 10 additions & 0 deletions v2/workloadapi/convenience.go
Expand Up @@ -71,6 +71,16 @@ func FetchJWTSVID(ctx context.Context, params jwtsvid.Params, options ...ClientO
return c.FetchJWTSVID(ctx, params)
}

// FetchJWTSVID fetches all JWT-SVIDs.
func FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params, options ...ClientOption) ([]*jwtsvid.SVID, error) {
c, err := New(ctx, options...)
if err != nil {
return nil, err
}
defer c.Close()
return c.FetchJWTSVIDs(ctx, params)
}

// FetchJWTBundles fetches the JWT bundles for JWT-SVID validation, keyed
// by a SPIFFE ID of the trust domain to which they belong.
func FetchJWTBundles(ctx context.Context, options ...ClientOption) (*jwtbundle.Set, error) {
Expand Down
9 changes: 9 additions & 0 deletions v2/workloadapi/jwtsource.go
Expand Up @@ -63,6 +63,15 @@ func (s *JWTSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*j
return s.watcher.client.FetchJWTSVID(ctx, params)
}

// FetchJWTSVIDs fetches all JWT-SVIDs from the source with the given parameters.
// It implements the jwtsvid.Source interface.
func (s *JWTSource) FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error) {
if err := s.checkClosed(); err != nil {
return nil, err
}
return s.watcher.client.FetchJWTSVIDs(ctx, params)
}

// GetJWTBundleForTrustDomain returns the JWT bundle for the given trust
// domain. It implements the jwtbundle.Source interface.
func (s *JWTSource) GetJWTBundleForTrustDomain(trustDomain spiffeid.TrustDomain) (*jwtbundle.Bundle, error) {
Expand Down
1 change: 1 addition & 0 deletions v2/workloadapi/watcher.go
Expand Up @@ -13,6 +13,7 @@ type sourceClient interface {
WatchX509Context(context.Context, X509ContextWatcher) error
WatchJWTBundles(context.Context, JWTBundleWatcher) error
FetchJWTSVID(context.Context, jwtsvid.Params) (*jwtsvid.SVID, error)
FetchJWTSVIDs(context.Context, jwtsvid.Params) ([]*jwtsvid.SVID, error)
Close() error
}

Expand Down

0 comments on commit 23ed83e

Please sign in to comment.