diff --git a/internal/zero/controller/controller.go b/internal/zero/controller/controller.go index 00d79768447..45dab9b4931 100644 --- a/internal/zero/controller/controller.go +++ b/internal/zero/controller/controller.go @@ -14,6 +14,7 @@ import ( "github.com/pomerium/pomerium/internal/zero/analytics" sdk "github.com/pomerium/pomerium/internal/zero/api" "github.com/pomerium/pomerium/internal/zero/bootstrap" + "github.com/pomerium/pomerium/internal/zero/leaser" "github.com/pomerium/pomerium/internal/zero/reconciler" "github.com/pomerium/pomerium/internal/zero/reporter" "github.com/pomerium/pomerium/pkg/cmd/pomerium" @@ -42,12 +43,10 @@ func Run(ctx context.Context, opts ...Option) error { } eg.Go(func() error { return run(ctx, "connect", c.runConnect, nil) }) + eg.Go(func() error { return run(ctx, "connect-log", c.RunConnectLog, nil) }) eg.Go(func() error { return run(ctx, "zero-bootstrap", c.runBootstrap, nil) }) eg.Go(func() error { return run(ctx, "pomerium-core", c.runPomeriumCore, src.WaitReady) }) - eg.Go(func() error { return run(ctx, "zero-reconciler", c.runReconciler, src.WaitReady) }) - eg.Go(func() error { return run(ctx, "connect-log", c.RunConnectLog, nil) }) - eg.Go(func() error { return run(ctx, "zero-analytics", c.runAnalytics, src.WaitReady) }) - eg.Go(func() error { return run(ctx, "zero-reporter", c.runReporter, src.WaitReady) }) + eg.Go(func() error { return c.runZeroControlLoop(ctx, src.WaitReady) }) return eg.Wait() } @@ -113,6 +112,19 @@ func (c *controller) runConnect(ctx context.Context) error { return c.api.Connect(ctx) } +func (c *controller) runZeroControlLoop(ctx context.Context, waitFn func(context.Context) error) error { + err := waitFn(ctx) + if err != nil { + return fmt.Errorf("error waiting for initial configuration: %w", err) + } + + return leaser.Run(ctx, c.databrokerClient, + c.runReconciler, + c.runAnalytics, + c.runReporter, + ) +} + func (c *controller) runReconciler(ctx context.Context) error { ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { return c.Str("service", "zero-reconciler") diff --git a/internal/zero/leaser/leaser.go b/internal/zero/leaser/leaser.go new file mode 100644 index 00000000000..9018fd2393d --- /dev/null +++ b/internal/zero/leaser/leaser.go @@ -0,0 +1,49 @@ +// Package leaser groups all Zero services that should run within a lease. +package leaser + +import ( + "context" + "time" + + "golang.org/x/sync/errgroup" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +type service struct { + client databroker.DataBrokerServiceClient + funcs []func(ctx context.Context) error +} + +// GetDataBrokerServiceClient implements the databroker.LeaseHandler interface. +func (c *service) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return c.client +} + +// RunLeased implements the databroker.LeaseHandler interface. +func (c *service) RunLeased(ctx context.Context) error { + eg, ctx := errgroup.WithContext(ctx) + for _, fn := range c.funcs { + fn := fn + eg.Go(func() error { + return fn(ctx) + }) + } + return eg.Wait() +} + +// Run runs services within a lease +func Run( + ctx context.Context, + client databroker.DataBrokerServiceClient, + funcs ...func(ctx context.Context) error, +) error { + srv := &service{ + client: client, + funcs: funcs, + } + leaser := databroker.NewLeaser("zero-ctrl", time.Second*30, srv) + return RunWithRestart(ctx, func(ctx context.Context) error { + return leaser.Run(ctx) + }, srv.databrokerChangeMonitor) +} diff --git a/internal/zero/leaser/monitor.go b/internal/zero/leaser/monitor.go new file mode 100644 index 00000000000..14554040748 --- /dev/null +++ b/internal/zero/leaser/monitor.go @@ -0,0 +1,37 @@ +package leaser + +import ( + "context" + "fmt" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +const typeStr = "pomerium.io/zero/leaser" + +// databrokerChangeMonitor runs infinite sync loop to see if there is any change in databroker +// it doesn't really syncs anything, just checks if the underlying databroker has changed +func (c *service) databrokerChangeMonitor(ctx context.Context) error { + _, recordVersion, serverVersion, err := databroker.InitialSync(ctx, c.GetDataBrokerServiceClient(), &databroker.SyncLatestRequest{ + Type: typeStr, + }) + if err != nil { + return fmt.Errorf("error during initial sync: %w", err) + } + + stream, err := c.GetDataBrokerServiceClient().Sync(ctx, &databroker.SyncRequest{ + Type: typeStr, + ServerVersion: serverVersion, + RecordVersion: recordVersion, + }) + if err != nil { + return fmt.Errorf("error calling sync: %w", err) + } + + for { + _, err := stream.Recv() + if err != nil { + return fmt.Errorf("error receiving record: %w", err) + } + } +} diff --git a/internal/zero/reconciler/restart.go b/internal/zero/leaser/restart.go similarity index 75% rename from internal/zero/reconciler/restart.go rename to internal/zero/leaser/restart.go index 0154214e457..6a852c97d70 100644 --- a/internal/zero/reconciler/restart.go +++ b/internal/zero/leaser/restart.go @@ -1,10 +1,15 @@ -package reconciler +package leaser import ( "context" "errors" "fmt" "sync" + "time" + + "github.com/cenkalti/backoff/v4" + + "github.com/pomerium/pomerium/internal/log" ) // RunWithRestart executes execFn. @@ -44,8 +49,15 @@ func restartContexts( contexts chan<- context.Context, restartFn func(context.Context) error, ) { + bo := backoff.NewExponentialBackOff() + bo.MaxElapsedTime = 0 // never stop + + ticker := time.NewTicker(bo.InitialInterval) + defer ticker.Stop() + defer close(contexts) for base.Err() == nil { + start := time.Now() ctx, cancel := context.WithCancelCause(base) select { case contexts <- ctx: @@ -55,6 +67,20 @@ func restartContexts( cancel(fmt.Errorf("parent context canceled: %w", base.Err())) return } + + if time.Since(start) > bo.MaxInterval { + bo.Reset() + } + next := bo.NextBackOff() + ticker.Reset(next) + + log.Ctx(ctx).Info().Msgf("restarting zero control loop in %s", next.String()) + + select { + case <-base.Done(): + return + case <-ticker.C: + } } } diff --git a/internal/zero/reconciler/restart_test.go b/internal/zero/leaser/restart_test.go similarity index 88% rename from internal/zero/reconciler/restart_test.go rename to internal/zero/leaser/restart_test.go index 2fbe0fdff84..4e17a52cd7d 100644 --- a/internal/zero/reconciler/restart_test.go +++ b/internal/zero/leaser/restart_test.go @@ -1,4 +1,4 @@ -package reconciler_test +package leaser_test import ( "context" @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/pomerium/pomerium/internal/zero/reconciler" + "github.com/pomerium/pomerium/internal/zero/leaser" ) func TestRestart(t *testing.T) { @@ -20,7 +20,7 @@ func TestRestart(t *testing.T) { errExpected := errors.New("execFn error") count := 0 - err := reconciler.RunWithRestart(context.Background(), + err := leaser.RunWithRestart(context.Background(), func(context.Context) error { count++ if count == 1 { @@ -40,7 +40,7 @@ func TestRestart(t *testing.T) { t.Parallel() count := 0 - err := reconciler.RunWithRestart(context.Background(), + err := leaser.RunWithRestart(context.Background(), func(context.Context) error { count++ if count == 1 { @@ -63,7 +63,7 @@ func TestRestart(t *testing.T) { t.Cleanup(cancel) ready := make(chan struct{}) - err := reconciler.RunWithRestart(ctx, + err := leaser.RunWithRestart(ctx, func(context.Context) error { <-ready cancel() @@ -87,7 +87,7 @@ func TestRestart(t *testing.T) { errExpected := errors.New("execFn error") count := 0 ready := make(chan struct{}) - err := reconciler.RunWithRestart(ctx, + err := leaser.RunWithRestart(ctx, func(ctx context.Context) error { count++ if count == 1 { // wait for us to be restarted diff --git a/internal/zero/reconciler/service.go b/internal/zero/reconciler/service.go index 0e987b002bf..86f9cf80da5 100644 --- a/internal/zero/reconciler/service.go +++ b/internal/zero/reconciler/service.go @@ -7,7 +7,6 @@ package reconciler import ( "context" - "fmt" "time" "golang.org/x/sync/errgroup" @@ -15,7 +14,6 @@ import ( "github.com/pomerium/pomerium/internal/atomicutil" connect_mux "github.com/pomerium/pomerium/internal/zero/connect-mux" - "github.com/pomerium/pomerium/pkg/grpc/databroker" ) type service struct { @@ -42,11 +40,6 @@ func Run(ctx context.Context, opts ...Option) error { } c.periodicUpdateInterval.Store(config.checkForUpdateIntervalWhenDisconnected) - return c.runMainLoop(ctx) -} - -// RunLeased implements the databroker.LeaseHandler interface -func (c *service) RunLeased(ctx context.Context) error { eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { return c.watchUpdates(ctx) }) eg.Go(func() error { return c.SyncLoop(ctx) }) @@ -54,48 +47,6 @@ func (c *service) RunLeased(ctx context.Context) error { return eg.Wait() } -// GetDataBrokerServiceClient implements the databroker.LeaseHandler interface. -func (c *service) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { - return c.config.databrokerClient -} - -func (c *service) runMainLoop(ctx context.Context) error { - leaser := databroker.NewLeaser("zero-reconciler", time.Second*30, c) - return RunWithRestart(ctx, func(ctx context.Context) error { - return leaser.Run(ctx) - }, c.databrokerChangeMonitor) -} - -// databrokerChangeMonitor runs infinite sync loop to see if there is any change in databroker -func (c *service) databrokerChangeMonitor(ctx context.Context) error { - _, recordVersion, serverVersion, err := databroker.InitialSync(ctx, c.GetDataBrokerServiceClient(), &databroker.SyncLatestRequest{ - Type: BundleCacheEntryRecordType, - }) - if err != nil { - return fmt.Errorf("error during initial sync: %w", err) - } - - stream, err := c.GetDataBrokerServiceClient().Sync(ctx, &databroker.SyncRequest{ - Type: BundleCacheEntryRecordType, - ServerVersion: serverVersion, - RecordVersion: recordVersion, - }) - if err != nil { - return fmt.Errorf("error calling sync: %w", err) - } - - for { - _, err := stream.Recv() - if err != nil { - return fmt.Errorf("error receiving record: %w", err) - } - } -} - -// run is a main control loop. -// it is very simple and sequential download and reconcile. -// it may be later optimized by splitting between download and reconciliation process, -// as we would get more resource bundles beyond the config. func (c *service) watchUpdates(ctx context.Context) error { return c.config.api.Watch(ctx, connect_mux.WithOnConnected(func(ctx context.Context) {