Skip to content

Commit

Permalink
Merge branch 'master' into td-macos-travis
Browse files Browse the repository at this point in the history
  • Loading branch information
evan2645 committed Jun 28, 2019
2 parents 2ed41f7 + 000a2aa commit 9718d9e
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 108 deletions.
53 changes: 53 additions & 0 deletions pkg/common/bundleutil/refreshhint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package bundleutil

import (
"math"
"time"
)

const (
refreshHintLeewayFactor = 10

// MinimumRefreshHint is the smallest refresh hint the client allows.
// Anything smaller than the minimum will be reset to the minimum.
MinimumRefreshHint = time.Minute
)

// CalculateRefreshHint is used to calculate the refresh hint for a given
// bundle. If the bundle already contains a refresh hint, then that is used,
// Otherwise, it looks at the lifetimes of the bundle contents and returns a
// fraction of the smallest. It is fairly aggressive but ensures clients don't
// miss a rotation period and lose their ability to fetch.
// TODO: reevaluate our strategy here when we rework the TTL story inside SPIRE.
func CalculateRefreshHint(bundle *Bundle) time.Duration {
if r := bundle.RefreshHint(); r > 0 {
return safeRefreshHint(r)
}

const maxDuration time.Duration = math.MaxInt64

smallestLifetime := maxDuration
for _, rootCA := range bundle.RootCAs() {
lifetime := rootCA.NotAfter.Sub(rootCA.NotBefore)
if lifetime < smallestLifetime {
smallestLifetime = lifetime
}
}

// TODO: look at JWT key lifetimes... requires us to track issued_at dates
// which we currently do not do.

// Set the refresh hint to a fraction of the smallest lifetime, if found.
var refreshHint time.Duration
if smallestLifetime != maxDuration {
refreshHint = smallestLifetime / refreshHintLeewayFactor
}
return safeRefreshHint(refreshHint)
}

func safeRefreshHint(refreshHint time.Duration) time.Duration {
if refreshHint < MinimumRefreshHint {
return MinimumRefreshHint
}
return refreshHint
}
65 changes: 65 additions & 0 deletions pkg/common/bundleutil/refreshhint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package bundleutil

import (
"crypto/x509"
"testing"
"time"

"github.com/spiffe/spire/proto/spire/common"
"github.com/stretchr/testify/require"
)

func TestCalculateRefreshHint(t *testing.T) {
emptyBundle := New("domain.test")

emptyBundleWithRefreshHint, err := BundleFromProto(&common.Bundle{
TrustDomainId: "domain.test",
RefreshHint: 3600,
})
require.NoError(t, err)

now := time.Now()
bundleWithCerts := New("domain.test")
bundleWithCerts.AppendRootCA(&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour * 2),
})
bundleWithCerts.AppendRootCA(&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
})
bundleWithCerts.AppendRootCA(&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour * 3),
})

testCases := []struct {
name string
bundle *Bundle
refreshHint time.Duration
}{
{
name: "empty bundle with no refresh hint",
bundle: emptyBundle,
refreshHint: MinimumRefreshHint,
},
{
name: "empty bundle with refresh hint",
bundle: emptyBundleWithRefreshHint,
refreshHint: time.Hour,
},
{
// the bundle has a few certs. the lowest lifetime is 1 hour.
// so we expect to get back a fraction of that time.
name: "bundle with certs",
bundle: bundleWithCerts,
refreshHint: time.Hour / refreshHintLeewayFactor,
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
require.Equal(t, testCase.refreshHint, CalculateRefreshHint(testCase.bundle), "refresh hint is wrong")
})
}
}
7 changes: 0 additions & 7 deletions pkg/common/bundleutil/types.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
package bundleutil

import (
"time"

"gopkg.in/square/go-jose.v2"
)

const (
// DefaultRefreshHint is the default refresh hint returned from the bundle
// endpoint. Hard coding for now until we have a grasp on the right
// strategy.
DefaultRefreshHint = time.Minute * 10

x509SVIDUse = "x509-svid"
jwtSVIDUse = "jwt-svid"
)
Expand Down
64 changes: 47 additions & 17 deletions pkg/server/bundle/client/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ import (
"github.com/spiffe/spire/proto/spire/server/datastore"
)

const (
// attemptsPerRefreshHint is the number of attemps within the returned
// refresh hint period that the manager will attempt to refresh the
// bundle. It is important to try more than once within a refresh hint
// period so we can be resilient to temporary downtime or failures.
attemptsPerRefreshHint = 4
)

type TrustDomainConfig struct {
EndpointAddress string
EndpointSpiffeID string
Expand All @@ -27,8 +35,9 @@ type ManagerConfig struct {
}

type Manager struct {
log logrus.FieldLogger
clock clock.Clock
updaters []BundleUpdater
updaters map[string]BundleUpdater
}

func NewManager(config ManagerConfig) *Manager {
Expand All @@ -39,40 +48,64 @@ func NewManager(config ManagerConfig) *Manager {
config.newBundleUpdater = NewBundleUpdater
}

var updaters []BundleUpdater
updaters := make(map[string]BundleUpdater)
for trustDomain, trustDomainConfig := range config.TrustDomains {
updaters = append(updaters, config.newBundleUpdater(BundleUpdaterConfig{
updaters[trustDomain] = config.newBundleUpdater(BundleUpdaterConfig{
TrustDomainConfig: trustDomainConfig,
TrustDomain: trustDomain,
Log: config.Log.WithField("trust_domain", trustDomain),
DataStore: config.DataStore,
}))
})
}

return &Manager{
log: config.Log,
clock: config.Clock,
updaters: updaters,
}
}

func (m *Manager) Run(ctx context.Context) error {
var tasks []func(context.Context) error
for _, updater := range m.updaters {
for trustDomain, updater := range m.updaters {
tasks = append(tasks, func(ctx context.Context) error {
return m.runUpdater(ctx, updater)
return m.runUpdater(ctx, trustDomain, updater)
})
}

return util.RunTasks(ctx, tasks...)
}

func (m *Manager) runUpdater(ctx context.Context, updater BundleUpdater) error {
var refreshHint time.Duration
func (m *Manager) runUpdater(ctx context.Context, trustDomain string, updater BundleUpdater) error {
log := m.log.WithField("trust_domain", trustDomain)
for {
if r, err := updater.UpdateBundle(ctx); err == nil {
refreshHint = r
var nextRefresh time.Duration
log.Debug("Polling for bundle update")
localBundle, endpointBundle, err := updater.UpdateBundle(ctx)
if err != nil {
log.WithError(err).Error("Error updating bundle")
}

switch {
case endpointBundle != nil:
log.Info("Bundle refreshed")
nextRefresh = calculateNextUpdate(endpointBundle)
case localBundle != nil:
nextRefresh = calculateNextUpdate(localBundle)
default:
// We have no bundle to use to calculate the refresh hint. Since
// the endpoint cannot be reached without the local bundle (until
// we implement web auth), we can retry more aggressively. This
// refresh period determines how fast we'll respond to the local
// bundle being bootstrapped.
// TODO: reevaluate once we support web auth
nextRefresh = bundleutil.MinimumRefreshHint
}
timer := m.newRefreshTimer(refreshHint)

log.WithFields(logrus.Fields{
"at": m.clock.Now().Add(nextRefresh).UTC().Format(time.RFC3339),
}).Debug("Scheduling next bundle refresh")

timer := m.clock.Timer(nextRefresh)
select {
case <-timer.C:
case <-ctx.Done():
Expand All @@ -82,9 +115,6 @@ func (m *Manager) runUpdater(ctx context.Context, updater BundleUpdater) error {
}
}

func (m *Manager) newRefreshTimer(refreshHint time.Duration) *clock.Timer {
if refreshHint == 0 {
refreshHint = bundleutil.DefaultRefreshHint
}
return m.clock.Timer(refreshHint)
func calculateNextUpdate(b *bundleutil.Bundle) time.Duration {
return bundleutil.CalculateRefreshHint(b) / attemptsPerRefreshHint
}
56 changes: 31 additions & 25 deletions pkg/server/bundle/client/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,52 @@ import (
)

func TestManager(t *testing.T) {
// create a pair of bundles with distinct refresh hints so we can assert
// that the manager selected the correct refresh hint.
localBundle := bundleutil.BundleFromRootCA("spiffe://domain.test", createCACertificate(t, "local"))
localBundle.SetRefreshHint(time.Hour)
endpointBundle := bundleutil.BundleFromRootCA("spiffe://domain.test", createCACertificate(t, "endpoint"))
endpointBundle.SetRefreshHint(time.Hour * 2)

testCases := []struct {
name string
refreshHint time.Duration
refreshErr error
refreshAfter time.Duration
name string
localBundle *bundleutil.Bundle
endpointBundle *bundleutil.Bundle
nextRefresh time.Duration
}{
{
name: "update refresh hint used",
refreshHint: time.Minute,
refreshAfter: time.Minute,
name: "update failed to obtain local bundle",
nextRefresh: bundleutil.MinimumRefreshHint,
},
{
name: "default refresh hint used ",
refreshHint: 0,
refreshAfter: bundleutil.DefaultRefreshHint,
name: "update failed to obtain endpoint bundle",
localBundle: localBundle,
nextRefresh: calculateNextUpdate(localBundle),
},
{
name: "refresh hint unchanged on error ",
refreshHint: time.Minute,
refreshErr: errors.New("OHNO!"),
refreshAfter: bundleutil.DefaultRefreshHint,
name: "update obtained endpoint bundle",
localBundle: localBundle,
endpointBundle: endpointBundle,
nextRefresh: calculateNextUpdate(endpointBundle),
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
clock := clock.NewMock(t)

updater := newFakeBundleUpdater(testCase.refreshHint, testCase.refreshErr)
updater := newFakeBundleUpdater(testCase.localBundle, testCase.endpointBundle)

done := startManager(t, clock, updater)
defer done()

// wait for the initial refresh
waitForRefresh(t, clock, testCase.refreshAfter)
waitForRefresh(t, clock, testCase.nextRefresh)
require.Equal(t, 1, updater.UpdateCount())

// advance time and make sure another refresh happens
clock.Add(testCase.refreshAfter + time.Millisecond)
waitForRefresh(t, clock, testCase.refreshAfter)
clock.Add(testCase.nextRefresh + time.Millisecond)
waitForRefresh(t, clock, testCase.nextRefresh)
require.Equal(t, 2, updater.UpdateCount())
})
}
Expand Down Expand Up @@ -117,17 +123,17 @@ func waitForRefresh(t *testing.T, clock *clock.Mock, expectedDuration time.Durat
}

type fakeBundleUpdater struct {
refreshHint time.Duration
err error
localBundle *bundleutil.Bundle
endpointBundle *bundleutil.Bundle

mu sync.Mutex
updateCount int
}

func newFakeBundleUpdater(refreshHint time.Duration, err error) *fakeBundleUpdater {
func newFakeBundleUpdater(localBundle, endpointBundle *bundleutil.Bundle) *fakeBundleUpdater {
return &fakeBundleUpdater{
refreshHint: refreshHint,
err: err,
localBundle: localBundle,
endpointBundle: endpointBundle,
}
}

Expand All @@ -137,9 +143,9 @@ func (u *fakeBundleUpdater) UpdateCount() int {
return u.updateCount
}

func (u *fakeBundleUpdater) UpdateBundle(context.Context) (time.Duration, error) {
func (u *fakeBundleUpdater) UpdateBundle(context.Context) (*bundleutil.Bundle, *bundleutil.Bundle, error) {
u.mu.Lock()
defer u.mu.Unlock()
u.updateCount++
return u.refreshHint, u.err
return u.localBundle, u.endpointBundle, errors.New("UNUSED")
}
Loading

0 comments on commit 9718d9e

Please sign in to comment.