Skip to content

Commit

Permalink
Merge pull request #544 from azdagron/node-api-jwt-svid-support
Browse files Browse the repository at this point in the history
Updates node API to support JWT SVID
  • Loading branch information
azdagron committed Jul 23, 2018
2 parents d7f02b3 + 5005aaf commit 79c6a7c
Show file tree
Hide file tree
Showing 20 changed files with 674 additions and 258 deletions.
2 changes: 1 addition & 1 deletion pkg/agent/attestor/node/node.go
Expand Up @@ -317,7 +317,7 @@ func (a *attestor) parseAttestationResponse(id string, r *node.AttestResponse) (
return nil, nil, fmt.Errorf("incorrect svid: %s", id)
}

svid, err := x509.ParseCertificate(svidMsg.SvidCert)
svid, err := x509.ParseCertificate(svidMsg.Cert)
if err != nil {
return nil, nil, fmt.Errorf("invalid svid: %v", err)
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/agent/attestor/node/node_test.go
Expand Up @@ -44,7 +44,7 @@ type NodeAttestorTestSuite struct {
keyManager *mock_keymanager.MockKeyManager
nodeClient *mock_node.MockNodeClient
config *Config
expectation *node.SvidUpdate
expectation *node.X509SVIDUpdate
}

func (s *NodeAttestorTestSuite) SetupTest() {
Expand Down Expand Up @@ -247,11 +247,11 @@ func (s *NodeAttestorTestSuite) setAttestResponse(challenges []challengeResponse
}, nil)
}
stream.EXPECT().Recv().Return(&node.AttestResponse{
SvidUpdate: &node.SvidUpdate{
Svids: map[string]*node.Svid{
"spiffe://example.com/spire/agent/join_token/foobar": &node.Svid{
SvidCert: svid.Raw,
Ttl: 300,
SvidUpdate: &node.X509SVIDUpdate{
Svids: map[string]*node.X509SVID{
"spiffe://example.com/spire/agent/join_token/foobar": &node.X509SVID{
Cert: svid.Raw,
ExpiresAt: svid.NotAfter.Unix(),
}},
}}, nil)
stream.EXPECT().CloseSend()
Expand Down
2 changes: 1 addition & 1 deletion pkg/agent/attestor/workload/workload_test.go
Expand Up @@ -33,7 +33,7 @@ type WorkloadAttestorTestSuite struct {
ctrl *gomock.Controller

attestor *attestor
expectation *node.SvidUpdate
expectation *node.X509SVIDUpdate
attestor1 *mock_workloadattestor.MockWorkloadAttestor
attestor2 *mock_workloadattestor.MockWorkloadAttestor
}
Expand Down
51 changes: 42 additions & 9 deletions pkg/agent/client/client.go
Expand Up @@ -28,8 +28,14 @@ var (
ErrUnableToGetStream = errors.New("unable to get a stream")
)

type JWTSVID struct {
Token string
ExpiresAt time.Time
}

type Client interface {
FetchUpdates(req *node.FetchX509SVIDRequest) (*Update, error)
FetchUpdates(ctx context.Context, req *node.FetchX509SVIDRequest) (*Update, error)
FetchJWTSVID(ctx context.Context, jsr *node.JSR) (*JWTSVID, error)

// Release releases any resources that were held by this Client, if any.
Release()
Expand Down Expand Up @@ -74,8 +80,8 @@ func (c *client) credsFunc() (credentials.TransportCredentials, error) {
return credentials.NewTLS(tlsConfig), nil
}

func (c *client) dial() (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // TODO: Make this timeout configurable?
func (c *client) dial(ctx context.Context) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second) // TODO: Make this timeout configurable?
defer cancel()

config := grpcutil.GRPCDialerConfig{
Expand All @@ -90,13 +96,13 @@ func (c *client) dial() (*grpc.ClientConn, error) {
return conn, nil
}

func (c *client) FetchUpdates(req *node.FetchX509SVIDRequest) (*Update, error) {
nodeClient, err := c.newNodeClient()
func (c *client) FetchUpdates(ctx context.Context, req *node.FetchX509SVIDRequest) (*Update, error) {
nodeClient, err := c.newNodeClient(ctx)
if err != nil {
return nil, err
}

stream, err := nodeClient.FetchX509SVID(context.Background())
stream, err := nodeClient.FetchX509SVID(ctx)
// We weren't able to get a stream...close the client and return the error.
if err != nil {
c.Release()
Expand All @@ -114,7 +120,7 @@ func (c *client) FetchUpdates(req *node.FetchX509SVIDRequest) (*Update, error) {
}

regEntries := map[string]*common.RegistrationEntry{}
svids := map[string]*node.Svid{}
svids := map[string]*node.X509SVID{}
var lastBundle []byte
// Read all the server responses from the stream.
for {
Expand Down Expand Up @@ -142,6 +148,33 @@ func (c *client) FetchUpdates(req *node.FetchX509SVIDRequest) (*Update, error) {
}, nil
}

func (c *client) FetchJWTSVID(ctx context.Context, jsr *node.JSR) (*JWTSVID, error) {
nodeClient, err := c.newNodeClient(ctx)
if err != nil {
return nil, err
}

response, err := nodeClient.FetchJWTSVID(ctx, &node.FetchJWTSVIDRequest{
Jsr: jsr,
})
// We weren't able to make the request...close the client and return the error.
if err != nil {
c.Release()
c.c.Log.Errorf("%v: %v", ErrUnableToGetStream, err)
return nil, ErrUnableToGetStream
}

svid := response.GetSvid()
if svid == nil {
return nil, errors.New("JWTSVID response missing SVID")
}

return &JWTSVID{
Token: svid.Token,
ExpiresAt: time.Unix(svid.ExpiresAt, 0),
}, nil
}

func (c *client) Release() {
c.m.Lock()
defer c.m.Unlock()
Expand All @@ -152,7 +185,7 @@ func (c *client) Release() {
}
}

func (c *client) newNodeClient() (node.NodeClient, error) {
func (c *client) newNodeClient(ctx context.Context) (node.NodeClient, error) {
if c.newNodeClientCallback != nil {
return c.newNodeClientCallback()
}
Expand All @@ -161,7 +194,7 @@ func (c *client) newNodeClient() (node.NodeClient, error) {
defer c.m.Unlock()

if c.conn == nil {
conn, err := c.dial()
conn, err := c.dial(ctx)
if err != nil {
return nil, err
}
Expand Down
9 changes: 5 additions & 4 deletions pkg/agent/client/client_test.go
@@ -1,6 +1,7 @@
package client

import (
"context"
"io"
"testing"

Expand Down Expand Up @@ -34,14 +35,14 @@ func TestFetchUpdates(t *testing.T) {
Csrs: [][]byte{{1, 2, 3, 4}},
}
res := &node.FetchX509SVIDResponse{
SvidUpdate: &node.SvidUpdate{
SvidUpdate: &node.X509SVIDUpdate{
Bundle: []byte{10, 20, 30, 40},
RegistrationEntries: []*common.RegistrationEntry{{
EntryId: "1",
}},
Svids: map[string]*node.Svid{
Svids: map[string]*node.X509SVID{
"someSpiffeId": {
SvidCert: []byte{11, 22, 33},
Cert: []byte{11, 22, 33},
},
},
},
Expand All @@ -53,7 +54,7 @@ func TestFetchUpdates(t *testing.T) {
nodeFsc.EXPECT().Recv().Return(res, nil)
nodeFsc.EXPECT().Recv().Return(nil, io.EOF)

update, err := client.FetchUpdates(req)
update, err := client.FetchUpdates(context.Background(), req)
require.Nil(t, err)

assert.Equal(t, res.SvidUpdate.Bundle, update.Bundle)
Expand Down
2 changes: 1 addition & 1 deletion pkg/agent/client/update.go
Expand Up @@ -10,7 +10,7 @@ import (

type Update struct {
Entries map[string]*common.RegistrationEntry
SVIDs map[string]*node.Svid
SVIDs map[string]*node.X509SVID
Bundle []byte
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/agent/client/update_test.go
Expand Up @@ -17,15 +17,15 @@ func TestString(t *testing.T) {
u := &Update{
Bundle: []byte{1, 2, 3},
Entries: map[string]*common.RegistrationEntry{entries[0].EntryId: entries[0]},
SVIDs: map[string]*node.Svid{
SVIDs: map[string]*node.X509SVID{
"spiffe://example.org": {
SvidCert: []byte{4, 5},
Ttl: 5,
Cert: []byte{4, 5},
ExpiresAt: 5,
},
},
}

expected := "{ Entries: [{ spiffeID: spiffe://example.org/spire/agent, parentID: spiffe://example.org/spire/agent/join_token/abcd, selectors: [type:\"spiffe_id\" value:\"spiffe://example.org/spire/agent/join_token/abcd\" ]}], SVIDs: [spiffe://example.org: svid_cert:\"\\004\\005\" ttl:5 ], Bundle: bytes}"
expected := "{ Entries: [{ spiffeID: spiffe://example.org/spire/agent, parentID: spiffe://example.org/spire/agent/join_token/abcd, selectors: [type:\"spiffe_id\" value:\"spiffe://example.org/spire/agent/join_token/abcd\" ]}], SVIDs: [spiffe://example.org: cert:\"\\004\\005\" expires_at:5 ], Bundle: bytes}"
if u.String() != expected {
t.Errorf("expected: %s, got: %s", expected, u.String())
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/agent/manager/manager.go
Expand Up @@ -71,7 +71,7 @@ func (m *manager) Initialize(ctx context.Context) error {
m.storeSVID(m.svid.State().SVID)
m.storeBundle(m.cache.Bundle())

return m.synchronize()
return m.synchronize(ctx)
}

func (m *manager) Run(ctx context.Context) error {
Expand Down Expand Up @@ -120,7 +120,7 @@ func (m *manager) runSynchronizer(ctx context.Context) error {
for {
select {
case <-t.C:
err := m.synchronize()
err := m.synchronize(ctx)
if err != nil {
// Just log the error to keep waiting for next sinchronization...
m.c.Log.Errorf("synchronize failed: %v", err)
Expand Down

0 comments on commit 79c6a7c

Please sign in to comment.