diff --git a/refreshable/async.go b/refreshable/async.go new file mode 100644 index 00000000..4ccd6f75 --- /dev/null +++ b/refreshable/async.go @@ -0,0 +1,103 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package refreshable + +import ( + "context" + "time" +) + +// NewFromChannel populates an Updatable with the values channel. +// If an element is already available, the returned Value is guaranteed to be populated. +// The channel should be closed when no longer used to avoid leaking resources. +func NewFromChannel[T any](values <-chan T) Ready[T] { + out := newReady[T]() + select { + case initial, ok := <-values: + if !ok { + return out // channel already closed + } + out.Update(initial) + default: + } + go func() { + for value := range values { + out.Update(value) + } + }() + return out +} + +// NewFromTickerFunc returns a Ready Refreshable populated by the result of the provider called each interval. +// If the providers bool return is false, the value is ignored. +// The result's ReadyC channel is closed when a new value is populated. +// The refreshable will stop updating when the provided context is cancelled or the returned UnsubscribeFunc func is called. +func NewFromTickerFunc[T any](ctx context.Context, interval time.Duration, provider func(ctx context.Context) (T, bool)) (Ready[T], UnsubscribeFunc) { + out := newReady[T]() + ctx, cancel := context.WithCancel(ctx) + values := make(chan T) + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + defer close(values) + for { + if value, ok := provider(ctx); ok { + out.Update(value) + } + select { + case <-ticker.C: + continue + case <-ctx.Done(): + return + } + } + }() + return out, UnsubscribeFunc(cancel) +} + +// Wait waits until the Ready has a current value or the context expires. +func Wait[T any](ctx context.Context, ready Ready[T]) (T, bool) { + select { + case <-ready.ReadyC(): + return ready.Current(), true + case <-ctx.Done(): + var zero T + return zero, false + } +} + +// ready is an Updatable which exposes a channel that is closed when a value is first available. +// Current returns the zero value before Update is called, marking the value ready. +type ready[T any] struct { + in Updatable[T] + readyC <-chan struct{} + cancel context.CancelFunc +} + +func newReady[T any]() *ready[T] { + ctx, cancel := context.WithCancel(context.Background()) + return &ready[T]{ + in: newZero[T](), + readyC: ctx.Done(), + cancel: cancel, + } +} + +func (r *ready[T]) Current() T { + return r.in.Current() +} + +func (r *ready[T]) Subscribe(consumer func(T)) UnsubscribeFunc { + return r.in.Subscribe(consumer) +} + +func (r *ready[T]) ReadyC() <-chan struct{} { + return r.readyC +} + +func (r *ready[T]) Update(val T) { + r.in.Update(val) + r.cancel() +} diff --git a/refreshable/go.mod b/refreshable/go.mod index 98d87a92..c0122769 100644 --- a/refreshable/go.mod +++ b/refreshable/go.mod @@ -1,4 +1,4 @@ -module github.com/palantir/pkg/refreshable +module github.com/palantir/pkg/refreshable/v2 go 1.20 diff --git a/refreshable/godel/config/check-plugin.yml b/refreshable/godel/config/check-plugin.yml index aa1fc55b..e69de29b 100644 --- a/refreshable/godel/config/check-plugin.yml +++ b/refreshable/godel/config/check-plugin.yml @@ -1,5 +0,0 @@ -checks: - golint: - filters: - - value: "should have comment or be unexported" - - value: "or a comment on this block" diff --git a/refreshable/refreshable.go b/refreshable/refreshable.go index 1a6e7cb7..f5c8fa3e 100644 --- a/refreshable/refreshable.go +++ b/refreshable/refreshable.go @@ -4,14 +4,92 @@ package refreshable -type Refreshable interface { +import ( + "context" +) + +// A Refreshable is a generic container type for a volatile underlying value. +// It supports atomic access and user-provided callback "subscriptions" on updates. +type Refreshable[T any] interface { // Current returns the most recent value of this Refreshable. - Current() interface{} + // If the value has not been initialized, returns T's zero value. + Current() T + + // Subscribe calls the consumer function when Value updates until stop is closed. + // The consumer must be relatively fast: Updatable.Set blocks until all subscribers have returned. + // Expensive or error-prone responses to refreshed values should be asynchronous. + // Updates considered no-ops by reflect.DeepEqual may be skipped. + // When called, consumer is executed with the Current value. + Subscribe(consumer func(T)) UnsubscribeFunc +} + +// A Updatable is a Refreshable which supports setting the value with a user-provided value. +// When a utility returns a (non-Updatable) Refreshable, it implies that value updates are handled internally. +type Updatable[T any] interface { + Refreshable[T] + // Update updates the Refreshable with a new T. + // It blocks until all subscribers have completed. + Update(T) +} + +// A Validated is a Refreshable capable of rejecting updates according to validation logic. +// Its Current method returns the most recent value to pass validation. +type Validated[T any] interface { + Refreshable[T] + // Validation returns the result of the most recent validation. + // If the last value was valid, Validation returns the same value as Current and a nil error. + // If the last value was invalid, it and the error are returned. Current returns the most recent valid value. + Validation() (T, error) +} + +// Ready extends Refreshable for asynchronous implementations which may not have a value when they are constructed. +// Callers should check that the Ready channel is closed before using the Current value. +type Ready[T any] interface { + Refreshable[T] + // ReadyC returns a channel which is closed after a value is successfully populated. + ReadyC() <-chan struct{} +} - // Subscribe subscribes to changes of this Refreshable. The provided function is called with the value of Current() - // whenever the value changes. - Subscribe(consumer func(interface{})) (unsubscribe func()) +// UnsubscribeFunc removes a subscription from a refreshable's internal tracking and/or stops its update routine. +// It is safe to call multiple times. +type UnsubscribeFunc func() + +// New returns a new Updatable that begins with the given value. +func New[T any](val T) Updatable[T] { + return newDefault(val) +} + +// Map returns a new Refreshable based on the current one that handles updates based on the current Refreshable. +func Map[T any, M any](original Refreshable[T], mapFn func(T) M) (Refreshable[M], UnsubscribeFunc) { + out := newDefault(mapFn(original.Current())) + stop := original.Subscribe(func(v T) { + out.Update(mapFn(v)) + }) + return (*readOnlyRefreshable[M])(out), stop +} + +// MapContext is like Map but unsubscribes when the context is cancelled. +func MapContext[T any, M any](ctx context.Context, original Refreshable[T], mapFn func(T) M) Refreshable[M] { + out, stop := Map(original, mapFn) + go func() { + <-ctx.Done() + stop() + }() + return out +} + +// MapWithError is similar to Validate but allows for the function to return a mapping/mutation +// of the input object in addition to returning an error. The returned validRefreshable will contain the mapped value. +// An error is returned if the current original value fails to map. +func MapWithError[T any, M any](original Refreshable[T], mapFn func(T) (M, error)) (Validated[M], UnsubscribeFunc, error) { + v, stop := newValidRefreshable(original, mapFn) + _, err := v.Validation() + return v, stop, err +} - // Map returns a new Refreshable based on the current one that handles updates based on the current Refreshable. - Map(func(interface{}) interface{}) Refreshable +// Validate returns a new Refreshable that returns the latest original value accepted by the validatingFn. +// If the upstream value results in an error, it is reported by Validation(). +// An error is returned if the current original value is invalid. +func Validate[T any](original Refreshable[T], validatingFn func(T) error) (Validated[T], UnsubscribeFunc, error) { + return MapWithError(original, identity(validatingFn)) } diff --git a/refreshable/refreshable_default.go b/refreshable/refreshable_default.go index 2ad8488c..22f9e8d0 100644 --- a/refreshable/refreshable_default.go +++ b/refreshable/refreshable_default.go @@ -5,84 +5,84 @@ package refreshable import ( - "fmt" "reflect" "sync" "sync/atomic" ) -type DefaultRefreshable struct { - typ reflect.Type - current *atomic.Value - - sync.Mutex // protects subscribers - subscribers []*func(interface{}) +type defaultRefreshable[T any] struct { + mux sync.Mutex + current atomic.Value + subscribers []*func(T) } -func NewDefaultRefreshable(val interface{}) *DefaultRefreshable { - current := atomic.Value{} - current.Store(val) - - return &DefaultRefreshable{ - current: ¤t, - typ: reflect.TypeOf(val), - } +func newDefault[T any](val T) *defaultRefreshable[T] { + d := new(defaultRefreshable[T]) + d.current.Store(&val) + return d } -func (d *DefaultRefreshable) Update(val interface{}) error { - d.Lock() - defer d.Unlock() - - if valType := reflect.TypeOf(val); valType != d.typ { - return fmt.Errorf("new refreshable value must be type %s: got %s", d.typ, valType) - } +func newZero[T any]() *defaultRefreshable[T] { + d := new(defaultRefreshable[T]) + var zero T + d.current.Store(&zero) + return d +} - if reflect.DeepEqual(d.current.Load(), val) { - return nil +// Update changes the value of the Refreshable, then blocks while subscribers are executed. +func (d *defaultRefreshable[T]) Update(val T) { + d.mux.Lock() + defer d.mux.Unlock() + old := d.current.Swap(&val) + if reflect.DeepEqual(*(old.(*T)), val) { + return } - d.current.Store(val) - for _, sub := range d.subscribers { (*sub)(val) } - return nil } -func (d *DefaultRefreshable) Current() interface{} { - return d.current.Load() +func (d *defaultRefreshable[T]) Current() T { + return *(d.current.Load().(*T)) } -func (d *DefaultRefreshable) Subscribe(consumer func(interface{})) (unsubscribe func()) { - d.Lock() - defer d.Unlock() +func (d *defaultRefreshable[T]) Subscribe(consumer func(T)) UnsubscribeFunc { + d.mux.Lock() + defer d.mux.Unlock() consumerFnPtr := &consumer d.subscribers = append(d.subscribers, consumerFnPtr) + consumer(d.Current()) + return d.unsubscribe(consumerFnPtr) +} + +func (d *defaultRefreshable[T]) unsubscribe(consumerFnPtr *func(T)) UnsubscribeFunc { return func() { - d.unsubscribe(consumerFnPtr) + d.mux.Lock() + defer d.mux.Unlock() + + matchIdx := -1 + for idx, currSub := range d.subscribers { + if currSub == consumerFnPtr { + matchIdx = idx + break + } + } + if matchIdx != -1 { + d.subscribers = append(d.subscribers[:matchIdx], d.subscribers[matchIdx+1:]...) + } } + } -func (d *DefaultRefreshable) unsubscribe(consumerFnPtr *func(interface{})) { - d.Lock() - defer d.Unlock() +// readOnlyRefreshable aliases defaultRefreshable but hides the Update method so the type +// does not implement Updatable. +type readOnlyRefreshable[T any] defaultRefreshable[T] - matchIdx := -1 - for idx, currSub := range d.subscribers { - if currSub == consumerFnPtr { - matchIdx = idx - break - } - } - if matchIdx != -1 { - d.subscribers = append(d.subscribers[:matchIdx], d.subscribers[matchIdx+1:]...) - } +func (d *readOnlyRefreshable[T]) Current() T { + return (*defaultRefreshable[T])(d).Current() } -func (d *DefaultRefreshable) Map(mapFn func(interface{}) interface{}) Refreshable { - newRefreshable := NewDefaultRefreshable(mapFn(d.Current())) - d.Subscribe(func(updatedVal interface{}) { - _ = newRefreshable.Update(mapFn(updatedVal)) - }) - return newRefreshable +func (d *readOnlyRefreshable[T]) Subscribe(consumer func(T)) UnsubscribeFunc { + return (*defaultRefreshable[T])(d).Subscribe(consumer) } diff --git a/refreshable/refreshable_default_test.go b/refreshable/refreshable_default_test.go index 75be7915..53a9328c 100644 --- a/refreshable/refreshable_default_test.go +++ b/refreshable/refreshable_default_test.go @@ -7,58 +7,66 @@ package refreshable_test import ( "testing" - "github.com/palantir/pkg/refreshable" - "github.com/stretchr/testify/assert" + "github.com/palantir/pkg/refreshable/v2" "github.com/stretchr/testify/require" ) func TestDefaultRefreshable(t *testing.T) { - type container struct{ Value string } + type container struct { + Value string + } v := &container{Value: "original"} - r := refreshable.NewDefaultRefreshable(v) - assert.Equal(t, r.Current(), v) + r := refreshable.New(v) + require.Equal(t, r.Current(), v) t.Run("Update", func(t *testing.T) { v2 := &container{Value: "updated"} - err := r.Update(v2) - require.NoError(t, err) - assert.Equal(t, r.Current(), v2) + r.Update(v2) + require.Equal(t, r.Current(), v2) }) t.Run("Subscribe", func(t *testing.T) { var v1, v2 container - unsub1 := r.Subscribe(func(i interface{}) { - v1 = *(i.(*container)) + unsub1 := r.Subscribe(func(i *container) { + v1 = *i }) - _ = r.Subscribe(func(i interface{}) { - v2 = *(i.(*container)) + _ = r.Subscribe(func(i *container) { + v2 = *i }) - assert.Equal(t, v1.Value, "") - assert.Equal(t, v2.Value, "") - err := r.Update(&container{Value: "value"}) - require.NoError(t, err) - assert.Equal(t, v1.Value, "value") - assert.Equal(t, v2.Value, "value") + require.Equal(t, v1.Value, "updated") + require.Equal(t, v2.Value, "updated") + r.Update(&container{Value: "value"}) + require.Equal(t, v1.Value, "value") + require.Equal(t, v2.Value, "value") unsub1() - err = r.Update(&container{Value: "value2"}) - require.NoError(t, err) - assert.Equal(t, v1.Value, "value", "should be unchanged after unsubscribing") - assert.Equal(t, v2.Value, "value2", "should be updated after unsubscribing other") + r.Update(&container{Value: "value2"}) + require.Equal(t, v1.Value, "value", "should be unchanged after unsubscribing") + require.Equal(t, v2.Value, "value2", "should be updated after unsubscribing other") }) t.Run("Map", func(t *testing.T) { - err := r.Update(&container{Value: "value"}) - require.NoError(t, err) - m := r.Map(func(i interface{}) interface{} { - return len(i.(*container).Value) + r.Update(&container{Value: "value"}) + rLen, stop := refreshable.Map[*container, int](r, func(i *container) int { + return len(i.Value) }) - assert.Equal(t, m.Current(), 5) + defer stop() + require.Equal(t, 5, rLen.Current()) - err = r.Update(&container{Value: "updated"}) - require.NoError(t, err) - assert.Equal(t, m.Current(), 7) + rLenUpdates := 0 + rLen.Subscribe(func(int) { rLenUpdates++ }) + require.Equal(t, 1, rLenUpdates) + // update to new value with same length and ensure the + // equality check prevented unnecessary subscriber updates. + r.Update(&container{Value: "VALUE"}) + require.Equal(t, "VALUE", r.Current().Value) + require.Equal(t, 1, rLenUpdates) + + r.Update(&container{Value: "updated"}) + require.Equal(t, "updated", r.Current().Value) + require.Equal(t, 7, rLen.Current()) + require.Equal(t, 2, rLenUpdates) }) } diff --git a/refreshable/refreshable_types.go b/refreshable/refreshable_types.go deleted file mode 100644 index e4a784af..00000000 --- a/refreshable/refreshable_types.go +++ /dev/null @@ -1,406 +0,0 @@ -// Copyright (c) 2021 Palantir Technologies. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package refreshable - -import ( - "time" -) - -type String interface { - Refreshable - CurrentString() string - MapString(func(string) interface{}) Refreshable - SubscribeToString(func(string)) (unsubscribe func()) -} - -type StringPtr interface { - Refreshable - CurrentStringPtr() *string - MapStringPtr(func(*string) interface{}) Refreshable - SubscribeToStringPtr(func(*string)) (unsubscribe func()) -} - -type StringSlice interface { - Refreshable - CurrentStringSlice() []string - MapStringSlice(func([]string) interface{}) Refreshable - SubscribeToStringSlice(func([]string)) (unsubscribe func()) -} - -type Int interface { - Refreshable - CurrentInt() int - MapInt(func(int) interface{}) Refreshable - SubscribeToInt(func(int)) (unsubscribe func()) -} - -type IntPtr interface { - Refreshable - CurrentIntPtr() *int - MapIntPtr(func(*int) interface{}) Refreshable - SubscribeToIntPtr(func(*int)) (unsubscribe func()) -} - -type Int64 interface { - Refreshable - CurrentInt64() int64 - MapInt64(func(int64) interface{}) Refreshable - SubscribeToInt64(func(int64)) (unsubscribe func()) -} - -type Int64Ptr interface { - Refreshable - CurrentInt64Ptr() *int64 - MapInt64Ptr(func(*int64) interface{}) Refreshable - SubscribeToInt64Ptr(func(*int64)) (unsubscribe func()) -} - -type Float64 interface { - Refreshable - CurrentFloat64() float64 - MapFloat64(func(float64) interface{}) Refreshable - SubscribeToFloat64(func(float64)) (unsubscribe func()) -} - -type Float64Ptr interface { - Refreshable - CurrentFloat64Ptr() *float64 - MapFloat64Ptr(func(*float64) interface{}) Refreshable - SubscribeToFloat64Ptr(func(*float64)) (unsubscribe func()) -} - -type Bool interface { - Refreshable - CurrentBool() bool - MapBool(func(bool) interface{}) Refreshable - SubscribeToBool(func(bool)) (unsubscribe func()) -} - -type BoolPtr interface { - Refreshable - CurrentBoolPtr() *bool - MapBoolPtr(func(*bool) interface{}) Refreshable - SubscribeToBoolPtr(func(*bool)) (unsubscribe func()) -} - -// Duration is a Refreshable that can return the current time.Duration. -type Duration interface { - Refreshable - CurrentDuration() time.Duration - MapDuration(func(time.Duration) interface{}) Refreshable - SubscribeToDuration(func(time.Duration)) (unsubscribe func()) -} - -type DurationPtr interface { - Refreshable - CurrentDurationPtr() *time.Duration - MapDurationPtr(func(*time.Duration) interface{}) Refreshable - SubscribeToDurationPtr(func(*time.Duration)) (unsubscribe func()) -} - -func NewBool(in Refreshable) Bool { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewBoolPtr(in Refreshable) BoolPtr { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewDuration(in Refreshable) Duration { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewDurationPtr(in Refreshable) DurationPtr { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewString(in Refreshable) String { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewStringPtr(in Refreshable) StringPtr { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewStringSlice(in Refreshable) StringSlice { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewInt(in Refreshable) Int { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewIntPtr(in Refreshable) IntPtr { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewInt64(in Refreshable) Int64 { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewInt64Ptr(in Refreshable) Int64Ptr { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewFloat64(in Refreshable) Float64 { - return refreshableTyped{ - Refreshable: in, - } -} - -func NewFloat64Ptr(in Refreshable) Float64Ptr { - return refreshableTyped{ - Refreshable: in, - } -} - -var ( - _ Bool = (*refreshableTyped)(nil) - _ BoolPtr = (*refreshableTyped)(nil) - _ Duration = (*refreshableTyped)(nil) - _ Int = (*refreshableTyped)(nil) - _ IntPtr = (*refreshableTyped)(nil) - _ Int64 = (*refreshableTyped)(nil) - _ Int64Ptr = (*refreshableTyped)(nil) - _ Float64 = (*refreshableTyped)(nil) - _ Float64Ptr = (*refreshableTyped)(nil) - _ String = (*refreshableTyped)(nil) - _ StringPtr = (*refreshableTyped)(nil) - _ StringSlice = (*refreshableTyped)(nil) -) - -type refreshableTyped struct { - Refreshable -} - -func (rt refreshableTyped) CurrentString() string { - return rt.Current().(string) -} - -func (rt refreshableTyped) MapString(mapFn func(string) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(string)) - }) -} - -func (rt refreshableTyped) SubscribeToString(subFn func(string)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(string)) - }) -} - -func (rt refreshableTyped) CurrentStringPtr() *string { - return rt.Current().(*string) -} - -func (rt refreshableTyped) MapStringPtr(mapFn func(*string) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(*string)) - }) -} - -func (rt refreshableTyped) SubscribeToStringPtr(subFn func(*string)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(*string)) - }) -} - -func (rt refreshableTyped) CurrentStringSlice() []string { - return rt.Current().([]string) -} - -func (rt refreshableTyped) MapStringSlice(mapFn func([]string) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.([]string)) - }) -} - -func (rt refreshableTyped) SubscribeToStringSlice(subFn func([]string)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.([]string)) - }) -} - -func (rt refreshableTyped) CurrentInt() int { - return rt.Current().(int) -} - -func (rt refreshableTyped) MapInt(mapFn func(int) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(int)) - }) -} - -func (rt refreshableTyped) SubscribeToInt(subFn func(int)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(int)) - }) -} - -func (rt refreshableTyped) CurrentIntPtr() *int { - return rt.Current().(*int) -} - -func (rt refreshableTyped) MapIntPtr(mapFn func(*int) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(*int)) - }) -} - -func (rt refreshableTyped) SubscribeToIntPtr(subFn func(*int)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(*int)) - }) -} - -func (rt refreshableTyped) CurrentInt64() int64 { - return rt.Current().(int64) -} - -func (rt refreshableTyped) MapInt64(mapFn func(int64) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(int64)) - }) -} - -func (rt refreshableTyped) SubscribeToInt64(subFn func(int64)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(int64)) - }) -} - -func (rt refreshableTyped) CurrentInt64Ptr() *int64 { - return rt.Current().(*int64) -} - -func (rt refreshableTyped) MapInt64Ptr(mapFn func(*int64) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(*int64)) - }) -} - -func (rt refreshableTyped) SubscribeToInt64Ptr(subFn func(*int64)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(*int64)) - }) -} - -func (rt refreshableTyped) CurrentFloat64() float64 { - return rt.Current().(float64) -} - -func (rt refreshableTyped) MapFloat64(mapFn func(float64) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(float64)) - }) -} - -func (rt refreshableTyped) SubscribeToFloat64(subFn func(float64)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(float64)) - }) -} - -func (rt refreshableTyped) CurrentFloat64Ptr() *float64 { - return rt.Current().(*float64) -} - -func (rt refreshableTyped) MapFloat64Ptr(mapFn func(*float64) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(*float64)) - }) -} - -func (rt refreshableTyped) SubscribeToFloat64Ptr(subFn func(*float64)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(*float64)) - }) -} - -func (rt refreshableTyped) CurrentBool() bool { - return rt.Current().(bool) -} - -func (rt refreshableTyped) MapBool(mapFn func(bool) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(bool)) - }) -} - -func (rt refreshableTyped) SubscribeToBool(subFn func(bool)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(bool)) - }) -} - -func (rt refreshableTyped) CurrentBoolPtr() *bool { - return rt.Current().(*bool) -} - -func (rt refreshableTyped) MapBoolPtr(mapFn func(*bool) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(*bool)) - }) -} - -func (rt refreshableTyped) SubscribeToBoolPtr(subFn func(*bool)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(*bool)) - }) -} - -func (rt refreshableTyped) CurrentDuration() time.Duration { - return rt.Current().(time.Duration) -} - -func (rt refreshableTyped) MapDuration(mapFn func(time.Duration) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(time.Duration)) - }) -} - -func (rt refreshableTyped) SubscribeToDuration(subFn func(time.Duration)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(time.Duration)) - }) -} - -func (rt refreshableTyped) CurrentDurationPtr() *time.Duration { - return rt.Current().(*time.Duration) -} - -func (rt refreshableTyped) MapDurationPtr(mapFn func(*time.Duration) interface{}) Refreshable { - return rt.Map(func(i interface{}) interface{} { - return mapFn(i.(*time.Duration)) - }) -} - -func (rt refreshableTyped) SubscribeToDurationPtr(subFn func(*time.Duration)) (unsubscribe func()) { - return rt.Subscribe(func(i interface{}) { - subFn(i.(*time.Duration)) - }) -} diff --git a/refreshable/refreshable_validating.go b/refreshable/refreshable_validating.go index d00bcf58..b6fa8adf 100644 --- a/refreshable/refreshable_validating.go +++ b/refreshable/refreshable_validating.go @@ -1,95 +1,56 @@ -// Copyright (c) 2021 Palantir Technologies. All rights reserved. +// Copyright (c) 2022 Palantir Technologies. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package refreshable -import ( - "errors" - "sync/atomic" -) - -type ValidatingRefreshable struct { - Refreshable - lastValidateErr *atomic.Value +type validRefreshable[T any] struct { + r Updatable[validRefreshableContainer[T]] } -// this is needed to be able to store the absence of an error in an atomic.Value -type errorWrapper struct { - err error +type validRefreshableContainer[T any] struct { + validated T + unvalidated T + lastErr error } -func (v *ValidatingRefreshable) LastValidateErr() error { - return v.lastValidateErr.Load().(errorWrapper).err -} +func (v *validRefreshable[T]) Current() T { return v.r.Current().validated } -// NewValidatingRefreshable returns a new Refreshable whose current value is the latest value that passes the provided -// validatingFn successfully. This returns an error if the current value of the passed in Refreshable does not pass the -// validatingFn or if the validatingFn or Refreshable are nil. -func NewValidatingRefreshable(origRefreshable Refreshable, validatingFn func(interface{}) error) (*ValidatingRefreshable, error) { - mappingFn := func(i interface{}) (interface{}, error) { - if err := validatingFn(i); err != nil { - return nil, err - } - return nil, nil - } - return newValidatingRefreshable(origRefreshable, mappingFn, false) +func (v *validRefreshable[T]) Subscribe(consumer func(T)) UnsubscribeFunc { + return v.r.Subscribe(func(val validRefreshableContainer[T]) { + consumer(val.validated) + }) } -// NewMapValidatingRefreshable is similar to NewValidatingRefreshable but allows for the function to return a mapping/mutation -// of the input object in addition to returning an error. The returned ValidatingRefreshable will contain the mapped value. -// The mapped value must always be of the same type (but not necessarily that of the input type). -func NewMapValidatingRefreshable(origRefreshable Refreshable, mappingFn func(interface{}) (interface{}, error)) (*ValidatingRefreshable, error) { - return newValidatingRefreshable(origRefreshable, mappingFn, true) +// Validation returns the most recent upstream Refreshable and its validation result. +// If nil, the validRefreshable is up-to-date with its original. +func (v *validRefreshable[T]) Validation() (T, error) { + c := v.r.Current() + return c.unvalidated, c.lastErr } -func newValidatingRefreshable(origRefreshable Refreshable, validatingFn func(interface{}) (interface{}, error), storeMappedVal bool) (*ValidatingRefreshable, error) { - if validatingFn == nil { - return nil, errors.New("failed to create validating Refreshable because the validating function was nil") - } - - if origRefreshable == nil { - return nil, errors.New("failed to create validating Refreshable because the passed in Refreshable was nil") - } - - var validatedRefreshable *DefaultRefreshable - currentVal := origRefreshable.Current() - mappedVal, err := validatingFn(currentVal) - if err != nil { - return nil, err - } - if storeMappedVal { - validatedRefreshable = NewDefaultRefreshable(mappedVal) - } else { - validatedRefreshable = NewDefaultRefreshable(currentVal) - } - - var lastValidateErr atomic.Value - lastValidateErr.Store(errorWrapper{}) - v := ValidatingRefreshable{ - Refreshable: validatedRefreshable, - lastValidateErr: &lastValidateErr, - } +func newValidRefreshable[T any, M any](original Refreshable[T], mappingFn func(T) (M, error)) (*validRefreshable[M], UnsubscribeFunc) { + valid := &validRefreshable[M]{r: newDefault(validRefreshableContainer[M]{})} + stop := original.Subscribe(func(valueT T) { + updateValidRefreshable(valid, valueT, mappingFn) + }) + return valid, stop +} - updateValueFn := func(i interface{}) { - mappedVal, err := validatingFn(i) - if err != nil { - v.lastValidateErr.Store(errorWrapper{err}) - return - } - if storeMappedVal { - err = validatedRefreshable.Update(mappedVal) - } else { - err = validatedRefreshable.Update(i) - } - v.lastValidateErr.Store(errorWrapper{err: err}) +func updateValidRefreshable[T any, M any](valid *validRefreshable[M], value T, mapFn func(T) (M, error)) { + validated := valid.r.Current().validated + unvalidated, err := mapFn(value) + if err == nil { + validated = unvalidated } + valid.r.Update(validRefreshableContainer[M]{ + validated: validated, + unvalidated: unvalidated, + lastErr: err, + }) +} - origRefreshable.Subscribe(updateValueFn) - - // manually update value after performing subscription. This ensures that, if the current value changed between when - // it was fetched earlier in the function and when the subscription was performed, it is properly captured. - updateValueFn(origRefreshable.Current()) - - return &v, nil +// identity is a validating map function that returns its input argument type. +func identity[T any](validatingFn func(T) error) func(i T) (T, error) { + return func(i T) (T, error) { return i, validatingFn(i) } } diff --git a/refreshable/refreshable_validating_test.go b/refreshable/refreshable_validating_test.go index f7d1c384..824f0144 100644 --- a/refreshable/refreshable_validating_test.go +++ b/refreshable/refreshable_validating_test.go @@ -8,91 +8,108 @@ import ( "errors" "net/url" "testing" + "time" - "github.com/palantir/pkg/refreshable" + "github.com/palantir/pkg/refreshable/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestValidatingRefreshable(t *testing.T) { type container struct{ Value string } - r := refreshable.NewDefaultRefreshable(container{Value: "value"}) - vr, err := refreshable.NewValidatingRefreshable(r, func(i interface{}) error { - if len(i.(container).Value) == 0 { + r := refreshable.New(container{Value: "value"}) + vr, _, err := refreshable.Validate[container](r, func(i container) error { + if len(i.Value) == 0 { return errors.New("empty") } return nil }) require.NoError(t, err) - require.NoError(t, vr.LastValidateErr()) - require.Equal(t, r.Current().(container).Value, "value") - require.Equal(t, vr.Current().(container).Value, "value") + v, err := vr.Validation() + require.NoError(t, err) + require.Equal(t, "value", v.Value) + require.Equal(t, "value", r.Current().Value) + require.Equal(t, "value", vr.Current().Value) // attempt bad update - err = r.Update(container{}) - require.NoError(t, err, "no err expected from default refreshable") - require.Equal(t, r.Current().(container).Value, "") - - require.EqualError(t, vr.LastValidateErr(), "empty", "expected err from validating refreshable") - require.Equal(t, vr.Current().(container).Value, "value", "expected unchanged validating refreshable") + r.Update(container{}) + require.Equal(t, r.Current().Value, "") + v, err = vr.Validation() + require.EqualError(t, err, "empty", "expected validation error") + require.Equal(t, "", v.Value, "expected invalid value from Validation") + require.Equal(t, vr.Current().Value, "value", "expected unchanged validating refreshable") // attempt good update - require.NoError(t, r.Update(container{Value: "value2"})) - require.NoError(t, vr.LastValidateErr()) - require.Equal(t, "value2", vr.Current().(container).Value) - require.Equal(t, "value2", r.Current().(container).Value) + r.Update(container{Value: "value2"}) + v, err = vr.Validation() + require.NoError(t, err) + require.Equal(t, "value2", v.Value) + require.Equal(t, "value2", vr.Current().Value) + require.Equal(t, "value2", r.Current().Value) } func TestMapValidatingRefreshable(t *testing.T) { - r := refreshable.NewDefaultRefreshable("https://palantir.com:443") - vr, err := refreshable.NewMapValidatingRefreshable(r, func(i interface{}) (interface{}, error) { - return url.Parse(i.(string)) - }) + r := refreshable.New("https://palantir.com:443") + vr, _, err := refreshable.MapWithError[string, *url.URL](r, url.Parse) + require.NoError(t, err) + _, err = vr.Validation() require.NoError(t, err) - require.NoError(t, vr.LastValidateErr()) - require.Equal(t, r.Current().(string), "https://palantir.com:443") - require.Equal(t, vr.Current().(*url.URL).Hostname(), "palantir.com") + require.Equal(t, r.Current(), "https://palantir.com:443") + require.Equal(t, vr.Current().Hostname(), "palantir.com") // attempt bad update - err = r.Update(":::error.com") - require.NoError(t, err, "no err expected from default refreshable") - assert.Equal(t, r.Current().(string), ":::error.com") - require.EqualError(t, vr.LastValidateErr(), "parse \":::error.com\": missing protocol scheme", "expected err from validating refreshable") - assert.Equal(t, vr.Current().(*url.URL).Hostname(), "palantir.com", "expected unchanged validating refreshable") + r.Update(":::error.com") + assert.Equal(t, r.Current(), ":::error.com") + _, err = vr.Validation() + require.EqualError(t, err, "parse \":::error.com\": missing protocol scheme", "expected err from validating refreshable") + assert.Equal(t, vr.Current().Hostname(), "palantir.com", "expected unchanged validating refreshable") // attempt good update - require.NoError(t, r.Update("https://example.com")) - require.NoError(t, vr.LastValidateErr()) - require.Equal(t, r.Current().(string), "https://example.com") - require.Equal(t, vr.Current().(*url.URL).Hostname(), "example.com") + r.Update("https://example.com") + _, err = vr.Validation() + require.NoError(t, err) + require.Equal(t, r.Current(), "https://example.com") + require.Equal(t, vr.Current().Hostname(), "example.com") } // TestValidatingRefreshable_SubscriptionRaceCondition tests that the ValidatingRefreshable stays current // if the underlying refreshable updates during the creation process. func TestValidatingRefreshable_SubscriptionRaceCondition(t *testing.T) { - r := &updateImmediatelyRefreshable{r: refreshable.NewDefaultRefreshable(1), newValue: 2} - vr, err := refreshable.NewValidatingRefreshable(r, func(i interface{}) error { return nil }) + //r := &updateImmediatelyRefreshable{r: refreshable.New(1), newValue: 2} + r := refreshable.New(1) + var seen1, seen2 bool + vr, _, err := refreshable.Validate[int](r, func(i int) error { + go r.Update(2) + switch i { + case 1: + seen1 = true + case 2: + seen2 = true + } + return nil + }) require.NoError(t, err) // If this returns 1, it is likely because the VR contains a stale value - assert.Equal(t, 2, vr.Current()) + assert.Eventually(t, func() bool { + return vr.Current() == 2 + }, time.Second, time.Millisecond) + + assert.True(t, seen1, "expected to process 1 value") + assert.True(t, seen2, "expected to process 2 value") } // updateImmediatelyRefreshable is a mock implementation which updates to newValue immediately when Current() is called type updateImmediatelyRefreshable struct { - r *refreshable.DefaultRefreshable - newValue interface{} + r refreshable.Updatable[int] + newValue int } -func (r *updateImmediatelyRefreshable) Current() interface{} { +func (r *updateImmediatelyRefreshable) Current() int { c := r.r.Current() - _ = r.r.Update(r.newValue) + r.r.Update(r.newValue) return c } -func (r *updateImmediatelyRefreshable) Subscribe(f func(interface{})) func() { +func (r *updateImmediatelyRefreshable) Subscribe(f func(int)) refreshable.UnsubscribeFunc { return r.r.Subscribe(f) } - -func (r *updateImmediatelyRefreshable) Map(f func(interface{}) interface{}) refreshable.Refreshable { - return r.r.Map(f) -}