Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core/config: refactor change dispatcher #4657

Merged
merged 8 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 12 additions & 11 deletions config/config_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/google/uuid"
"github.com/rs/zerolog"

"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
Expand All @@ -19,27 +20,27 @@ import (
// A ChangeListener is called when configuration changes.
type ChangeListener = func(context.Context, *Config)

type changeDispatcherEvent struct {
cfg *Config
}

// A ChangeDispatcher manages listeners on config changes.
type ChangeDispatcher struct {
sync.Mutex
onConfigChangeListeners []ChangeListener
target events.Target[changeDispatcherEvent]
}

// Trigger triggers a change.
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
dispatcher.Lock()
defer dispatcher.Unlock()

for _, li := range dispatcher.onConfigChangeListeners {
li(ctx, cfg)
}
dispatcher.target.Dispatch(ctx, changeDispatcherEvent{
cfg: cfg,
})
}

// OnConfigChange adds a listener.
func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) {
dispatcher.Lock()
defer dispatcher.Unlock()
dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li)
dispatcher.target.AddListener(func(ctx context.Context, evt changeDispatcherEvent) {
li(ctx, evt.cfg)
})
}

// A Source gets configuration.
Expand Down
13 changes: 10 additions & 3 deletions config/layered_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package config_test
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -13,6 +15,8 @@ import (
)

func TestLayeredConfig(t *testing.T) {
t.Parallel()

ctx := context.Background()

t.Run("error on initial build", func(t *testing.T) {
Expand All @@ -33,12 +37,15 @@ func TestLayeredConfig(t *testing.T) {
})
require.NoError(t, err)

var dst *config.Config
var dst atomic.Pointer[config.Config]
dst.Store(layered.GetConfig())
layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) {
dst = c
dst.Store(c)
})

underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}})
assert.Equal(t, "b.com", dst.Options.GetDeriveInternalDomain())
assert.Eventually(t, func() bool {
return dst.Load().Options.GetDeriveInternalDomain() == "b.com"
}, 10*time.Second, time.Millisecond)
})
}
166 changes: 166 additions & 0 deletions internal/events/target.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package events

import (
"context"
"errors"
"sync"

"github.com/google/uuid"
)

type (
// A Listener is a function that listens for events of type T.
Listener[T any] func(ctx context.Context, event T)
// A Handle represents a listener.
Handle string

addListenerEvent[T any] struct {
listener Listener[T]
handle Handle
}
removeListenerEvent[T any] struct {
handle Handle
}
dispatchEvent[T any] struct {
ctx context.Context
event T
}
)

// A Target is a target for events.
//
// Listeners are added with AddListener with a function to be called when the event occurs.
// AddListener returns a Handle which can be used to remove a listener with RemoveListener.
//
// Dispatch dispatches events to all the registered listeners.
//
// Target is safe to use in its zero state.
//
// The first time any method of Target is called a background goroutine is started that handles
// any requests and maintains the state of the listeners. Each listener also starts a
// separate goroutine so that all listeners can be invoked concurrently.
//
// The channels to the main goroutine and to the listener goroutines have a size of 1 so typically
// methods and dispatches will return immediately. However a slow listener will cause the next event
// dispatch to block. This is the opposite behavior from Manager.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for follow-up: that might be a great place to inject telemetry instrumentation to, in order to automatically calculate delays, especially if we give human readable names to subscribers

//
// Close will cancel all the goroutines. Subsequent calls to AddListener, RemoveListener, Close and
// Dispatch are no-ops.
type Target[T any] struct {
initOnce sync.Once
ctx context.Context
cancel context.CancelCauseFunc
addListenerCh chan addListenerEvent[T]
removeListenerCh chan removeListenerEvent[T]
dispatchCh chan dispatchEvent[T]
listeners map[Handle]chan dispatchEvent[T]
}

// AddListener adds a listener to the target.
func (t *Target[T]) AddListener(listener Listener[T]) Handle {
t.init()

// using a handle is necessary because you can't use a function as a map key.
handle := Handle(uuid.NewString())

select {
case <-t.ctx.Done():
case t.addListenerCh <- addListenerEvent[T]{listener, handle}:
}

return handle
}

// Close closes the event target. This can be called multiple times safely.
// Once closed the target cannot be used.
func (t *Target[T]) Close() {
t.init()

t.cancel(errors.New("target closed"))
}

// Dispatch dispatches an event to all listeners.
func (t *Target[T]) Dispatch(ctx context.Context, evt T) {
t.init()

select {
case <-t.ctx.Done():
case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}:
}
}

// RemoveListener removes a listener from the target.
func (t *Target[T]) RemoveListener(handle Handle) {
t.init()

select {
case <-t.ctx.Done():
case t.removeListenerCh <- removeListenerEvent[T]{handle}:
}
}

func (t *Target[T]) init() {
t.initOnce.Do(func() {
t.ctx, t.cancel = context.WithCancelCause(context.Background())
t.addListenerCh = make(chan addListenerEvent[T], 1)
t.removeListenerCh = make(chan removeListenerEvent[T], 1)
t.dispatchCh = make(chan dispatchEvent[T], 1)
t.listeners = map[Handle]chan dispatchEvent[T]{}
go t.run()
})
}

func (t *Target[T]) run() {
// listen for add/remove/dispatch events and call functions
for {
select {
case <-t.ctx.Done():
return
wasaga marked this conversation as resolved.
Show resolved Hide resolved
case evt := <-t.addListenerCh:
t.addListener(evt.listener, evt.handle)
case evt := <-t.removeListenerCh:
t.removeListener(evt.handle)
case evt := <-t.dispatchCh:
t.dispatch(evt.ctx, evt.event)
}
}
}

// these functions are not thread-safe. They are intended to be called only by "run".

func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
ch := make(chan dispatchEvent[T], 1)
t.listeners[handle] = ch
// start a goroutine to send events to the listener
go func() {
for {
select {
case <-t.ctx.Done():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this case is missing a return, to stop the goroutine from looping.

case evt := <-ch:
listener(evt.ctx, evt.event)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need some way to break out of the loop when ch is closed? It looks like calling RemoveListener() will result in the listener being called with a zero event value.

}
}
}()
}

func (t *Target[T]) removeListener(handle Handle) {
ch, ok := t.listeners[handle]
if !ok {
// nothing to do since the listener doesn't exist
return
}
// close the channel to kill the goroutine
close(ch)
delete(t.listeners, handle)
}

func (t *Target[T]) dispatch(ctx context.Context, evt T) {
// loop over all the listeners and send the event to them
for _, ch := range t.listeners {
select {
case <-t.ctx.Done():
return
case ch <- dispatchEvent[T]{ctx: ctx, event: evt}:
}
}
}
53 changes: 53 additions & 0 deletions internal/events/target_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package events_test

import (
"context"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/pomerium/pomerium/internal/events"
)

func TestTarget(t *testing.T) {
t.Parallel()

var target events.Target[int64]
t.Cleanup(target.Close)

var calls1, calls2, calls3 atomic.Int64
h1 := target.AddListener(func(_ context.Context, i int64) {
calls1.Add(i)
})
h2 := target.AddListener(func(_ context.Context, i int64) {
calls2.Add(i)
})
h3 := target.AddListener(func(_ context.Context, i int64) {
calls3.Add(i)
})

shouldBe := func(i1, i2, i3 int64) {
t.Helper()

assert.Eventually(t, func() bool { return calls1.Load() == i1 }, time.Second, time.Millisecond)
assert.Eventually(t, func() bool { return calls2.Load() == i2 }, time.Second, time.Millisecond)
assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Second, time.Millisecond)
}

target.Dispatch(context.Background(), 1)
shouldBe(1, 1, 1)

target.RemoveListener(h2)
target.Dispatch(context.Background(), 2)
shouldBe(3, 1, 3)

target.RemoveListener(h1)
target.Dispatch(context.Background(), 3)
shouldBe(3, 1, 6)

target.RemoveListener(h3)
target.Dispatch(context.Background(), 4)
shouldBe(3, 1, 6)
}