From 4c76be456c0b831d9afacf05ed4f298adb14f2be Mon Sep 17 00:00:00 2001 From: Martin Englund Date: Wed, 18 Oct 2023 09:17:56 -0700 Subject: [PATCH] make it possible to wait for a query lambda --- errors/wait.go | 25 +++++++++ option/collection.go | 10 ++-- option/query.go | 2 + option/query_lambda.go | 9 ++++ option/virtual_instance.go | 42 ++++++++------- wait/collections.go | 6 +-- wait/collections_test.go | 8 +-- wait/fake/fake_resource_getter.go | 85 +++++++++++++++++++++++++++++++ wait/query.go | 16 +++--- wait/query_lambda.go | 25 +++++++++ wait/query_lambda_test.go | 36 +++++++++++++ wait/query_test.go | 24 +++++++++ wait/virtual_instance.go | 18 +++---- wait/virtual_instance_test.go | 24 ++++----- wait/wait.go | 20 ++++++-- wait/wait_test.go | 59 ++++++++++++++++----- 16 files changed, 336 insertions(+), 73 deletions(-) create mode 100644 errors/wait.go create mode 100644 wait/query_lambda.go create mode 100644 wait/query_lambda_test.go create mode 100644 wait/query_test.go diff --git a/errors/wait.go b/errors/wait.go new file mode 100644 index 00000000..09ebea01 --- /dev/null +++ b/errors/wait.go @@ -0,0 +1,25 @@ +package errors + +import ( + "errors" + "fmt" +) + +var ErrBadWaitState = errors.New("encountered bad state while waiting for resource") + +type BadWaitState struct { + State string +} + +func (e BadWaitState) Error() string { + return fmt.Sprintf("%s: %s", ErrBadWaitState.Error(), e.State) +} +func (e BadWaitState) Unwrap() error { + return ErrBadWaitState +} + +func NewBadWaitStateError(state string) BadWaitState { + return BadWaitState{ + State: state, + } +} diff --git a/option/collection.go b/option/collection.go index 26527955..67c82272 100644 --- a/option/collection.go +++ b/option/collection.go @@ -7,10 +7,14 @@ import ( "github.com/rockset/rockset-go-client/openapi" ) +type CollectionStatus string + +func (c CollectionStatus) String() string { return string(c) } + const ( - CollectionStatusCreated = "CREATED" - CollectionStatusInitialized = "INITIALIZED" - CollectionStatusReady = "READY" + CollectionStatusCreated CollectionStatus = "CREATED" + CollectionStatusInitialized CollectionStatus = "INITIALIZED" + CollectionStatusReady CollectionStatus = "READY" ) type ListCollectionOptions struct { diff --git a/option/query.go b/option/query.go index 20a73008..f4e3528c 100644 --- a/option/query.go +++ b/option/query.go @@ -4,6 +4,8 @@ import "github.com/rockset/rockset-go-client/openapi" type QueryState string +func (q QueryState) String() string { return string(q) } + const ( QueryQueued QueryState = "QUEUED" QueryRunning QueryState = "RUNNING" diff --git a/option/query_lambda.go b/option/query_lambda.go index a532428b..30837f54 100644 --- a/option/query_lambda.go +++ b/option/query_lambda.go @@ -2,6 +2,15 @@ package option import "github.com/rockset/rockset-go-client/openapi" +type QueryLambdaState string + +func (q QueryLambdaState) String() string { return string(q) } + +const ( + QueryLambdaActive QueryLambdaState = "ACTIVE" + QueryLambdaInvalid QueryLambdaState = "INVALID" +) + type ExecuteQueryLambdaRequest struct { openapi.ExecuteQueryLambdaRequest Tag string diff --git a/option/virtual_instance.go b/option/virtual_instance.go index c9ce0670..f98bdc3b 100644 --- a/option/virtual_instance.go +++ b/option/virtual_instance.go @@ -2,24 +2,32 @@ package option import "time" +type VirtualInstanceState string + +func (v VirtualInstanceState) String() string { return string(v) } + +type MountState string + +func (m MountState) String() string { return string(m) } + const ( - VirtualInstanceInitializing = "INITIALIZING" - VirtualInstanceProvisioningResources = "PROVISIONING_RESOURCES" - VirtualInstanceRebalancingCollections = "REBALANCING_COLLECTIONS" - VirtualInstanceActive = "ACTIVE" - VirtualInstanceSuspending = "SUSPENDING" - VirtualInstanceSuspended = "SUSPENDED" - VirtualInstanceResuming = "RESUMING" - VirtualInstanceDeleted = "DELETED" - - MountCreating = "CREATING" - MountActive = "ACTIVE" - MountRefreshing = "REFRESHING" - MountExpired = "EXPIRED" - MountDeleting = "DELETING" - MountSwitchingRefreshType = "SWITCHING_REFRESH_TYPE" - MountSuspended = "SUSPENDED" - MountSuspending = "SUSPENDING" + VirtualInstanceInitializing VirtualInstanceState = "INITIALIZING" + VirtualInstanceProvisioningResources VirtualInstanceState = "PROVISIONING_RESOURCES" + VirtualInstanceRebalancingCollections VirtualInstanceState = "REBALANCING_COLLECTIONS" + VirtualInstanceActive VirtualInstanceState = "ACTIVE" + VirtualInstanceSuspending VirtualInstanceState = "SUSPENDING" + VirtualInstanceSuspended VirtualInstanceState = "SUSPENDED" + VirtualInstanceResuming VirtualInstanceState = "RESUMING" + VirtualInstanceDeleted VirtualInstanceState = "DELETED" + + MountCreating MountState = "CREATING" + MountActive MountState = "ACTIVE" + MountRefreshing MountState = "REFRESHING" + MountExpired MountState = "EXPIRED" + MountDeleting MountState = "DELETING" + MountSwitchingRefreshType MountState = "SWITCHING_REFRESH_TYPE" + MountSuspended MountState = "SUSPENDED" + MountSuspending MountState = "SUSPENDING" ) // VirtualInstanceOptions contains the optional settings for a virtual instance. diff --git a/wait/collections.go b/wait/collections.go index 2e84bdb3..e71593e2 100644 --- a/wait/collections.go +++ b/wait/collections.go @@ -11,10 +11,10 @@ import ( // UntilCollectionReady waits until the collection is ready. func (w *Waiter) UntilCollectionReady(ctx context.Context, workspace, name string) error { - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.CollectionStatusReady}, - func(ctx context.Context) (string, error) { + return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []option.CollectionStatus{option.CollectionStatusReady}, nil, + func(ctx context.Context) (option.CollectionStatus, error) { c, err := w.rc.GetCollection(ctx, workspace, name) - return c.GetStatus(), err + return option.CollectionStatus(c.GetStatus()), err })) } diff --git a/wait/collections_test.go b/wait/collections_test.go index 3c00d57c..1afb401e 100644 --- a/wait/collections_test.go +++ b/wait/collections_test.go @@ -15,9 +15,9 @@ func TestWait_untilCollectionReady(t *testing.T) { ctx := context.TODO() rs := fakeRocksetClient() - rs.GetCollectionReturnsOnCall(0, openapi.Collection{Status: openapi.PtrString(option.CollectionStatusInitialized)}, nil) - rs.GetCollectionReturnsOnCall(1, openapi.Collection{Status: openapi.PtrString(option.CollectionStatusCreated)}, nil) - rs.GetCollectionReturnsOnCall(2, openapi.Collection{Status: openapi.PtrString(option.CollectionStatusReady)}, nil) + rs.GetCollectionReturnsOnCall(0, openapi.Collection{Status: openapi.PtrString(option.CollectionStatusInitialized.String())}, nil) + rs.GetCollectionReturnsOnCall(1, openapi.Collection{Status: openapi.PtrString(option.CollectionStatusCreated.String())}, nil) + rs.GetCollectionReturnsOnCall(2, openapi.Collection{Status: openapi.PtrString(option.CollectionStatusReady.String())}, nil) err := wait.New(&rs).UntilCollectionReady(ctx, "workspace", "collection") assert.NoError(t, err) @@ -28,7 +28,7 @@ func TestWait_untilCollectionGone(t *testing.T) { ctx := context.TODO() rs := fakeRocksetClient() - rs.GetCollectionReturnsOnCall(0, openapi.Collection{Status: openapi.PtrString(option.CollectionStatusReady)}, nil) + rs.GetCollectionReturnsOnCall(0, openapi.Collection{Status: openapi.PtrString(option.CollectionStatusReady.String())}, nil) rs.GetCollectionReturnsOnCall(1, openapi.Collection{}, NotFoundErr) err := wait.New(&rs).UntilCollectionGone(ctx, "workspace", "collection") diff --git a/wait/fake/fake_resource_getter.go b/wait/fake/fake_resource_getter.go index 985c243d..2fd8ae5c 100644 --- a/wait/fake/fake_resource_getter.go +++ b/wait/fake/fake_resource_getter.go @@ -84,6 +84,22 @@ type FakeResourceGetter struct { result1 openapi.QueryInfo result2 error } + GetQueryLambdaVersionStub func(context.Context, string, string, string) (openapi.QueryLambdaVersion, error) + getQueryLambdaVersionMutex sync.RWMutex + getQueryLambdaVersionArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 string + arg4 string + } + getQueryLambdaVersionReturns struct { + result1 openapi.QueryLambdaVersion + result2 error + } + getQueryLambdaVersionReturnsOnCall map[int]struct { + result1 openapi.QueryLambdaVersion + result2 error + } GetViewStub func(context.Context, string, string) (openapi.View, error) getViewMutex sync.RWMutex getViewArgsForCall []struct { @@ -483,6 +499,73 @@ func (fake *FakeResourceGetter) GetQueryInfoReturnsOnCall(i int, result1 openapi }{result1, result2} } +func (fake *FakeResourceGetter) GetQueryLambdaVersion(arg1 context.Context, arg2 string, arg3 string, arg4 string) (openapi.QueryLambdaVersion, error) { + fake.getQueryLambdaVersionMutex.Lock() + ret, specificReturn := fake.getQueryLambdaVersionReturnsOnCall[len(fake.getQueryLambdaVersionArgsForCall)] + fake.getQueryLambdaVersionArgsForCall = append(fake.getQueryLambdaVersionArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 string + arg4 string + }{arg1, arg2, arg3, arg4}) + stub := fake.GetQueryLambdaVersionStub + fakeReturns := fake.getQueryLambdaVersionReturns + fake.recordInvocation("GetQueryLambdaVersion", []interface{}{arg1, arg2, arg3, arg4}) + fake.getQueryLambdaVersionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionCallCount() int { + fake.getQueryLambdaVersionMutex.RLock() + defer fake.getQueryLambdaVersionMutex.RUnlock() + return len(fake.getQueryLambdaVersionArgsForCall) +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionCalls(stub func(context.Context, string, string, string) (openapi.QueryLambdaVersion, error)) { + fake.getQueryLambdaVersionMutex.Lock() + defer fake.getQueryLambdaVersionMutex.Unlock() + fake.GetQueryLambdaVersionStub = stub +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionArgsForCall(i int) (context.Context, string, string, string) { + fake.getQueryLambdaVersionMutex.RLock() + defer fake.getQueryLambdaVersionMutex.RUnlock() + argsForCall := fake.getQueryLambdaVersionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionReturns(result1 openapi.QueryLambdaVersion, result2 error) { + fake.getQueryLambdaVersionMutex.Lock() + defer fake.getQueryLambdaVersionMutex.Unlock() + fake.GetQueryLambdaVersionStub = nil + fake.getQueryLambdaVersionReturns = struct { + result1 openapi.QueryLambdaVersion + result2 error + }{result1, result2} +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionReturnsOnCall(i int, result1 openapi.QueryLambdaVersion, result2 error) { + fake.getQueryLambdaVersionMutex.Lock() + defer fake.getQueryLambdaVersionMutex.Unlock() + fake.GetQueryLambdaVersionStub = nil + if fake.getQueryLambdaVersionReturnsOnCall == nil { + fake.getQueryLambdaVersionReturnsOnCall = make(map[int]struct { + result1 openapi.QueryLambdaVersion + result2 error + }) + } + fake.getQueryLambdaVersionReturnsOnCall[i] = struct { + result1 openapi.QueryLambdaVersion + result2 error + }{result1, result2} +} + func (fake *FakeResourceGetter) GetView(arg1 context.Context, arg2 string, arg3 string) (openapi.View, error) { fake.getViewMutex.Lock() ret, specificReturn := fake.getViewReturnsOnCall[len(fake.getViewArgsForCall)] @@ -816,6 +899,8 @@ func (fake *FakeResourceGetter) Invocations() map[string][][]interface{} { defer fake.getIntegrationMutex.RUnlock() fake.getQueryInfoMutex.RLock() defer fake.getQueryInfoMutex.RUnlock() + fake.getQueryLambdaVersionMutex.RLock() + defer fake.getQueryLambdaVersionMutex.RUnlock() fake.getViewMutex.RLock() defer fake.getViewMutex.RUnlock() fake.getVirtualInstanceMutex.RLock() diff --git a/wait/query.go b/wait/query.go index c5374c14..0ad535be 100644 --- a/wait/query.go +++ b/wait/query.go @@ -6,12 +6,14 @@ import ( "github.com/rockset/rockset-go-client/option" ) -// UntilQueryDone waits until queryID has either completed, errored, or been cancelled. +// UntilQueryDone waits until queryID has either completed. +// Returns ErrBadWaitState if the query failed or was cancelled. func (w *Waiter) UntilQueryDone(ctx context.Context, queryID string) error { - // TODO should this only wait for COMPLETED and return an error for ERROR and CANCELLED? - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []option.QueryState{option.QueryCompleted, option.QueryError, option.QueryCancelled}, - func(ctx context.Context) (option.QueryState, error) { - q, err := w.rc.GetQueryInfo(ctx, queryID) - return option.QueryState(q.GetStatus()), err - })) + return w.rc.RetryWithCheck(ctx, + ResourceHasState(ctx, []option.QueryState{option.QueryCompleted}, + []option.QueryState{option.QueryError, option.QueryCancelled}, + func(ctx context.Context) (option.QueryState, error) { + q, err := w.rc.GetQueryInfo(ctx, queryID) + return option.QueryState(q.GetStatus()), err + })) } diff --git a/wait/query_lambda.go b/wait/query_lambda.go new file mode 100644 index 00000000..70ea8781 --- /dev/null +++ b/wait/query_lambda.go @@ -0,0 +1,25 @@ +package wait + +import ( + "context" + + "github.com/rockset/rockset-go-client/option" +) + +// UntilQueryLambdaVersionGone waits until a query lambda is deleted, i.e. GetQueryLambda() returns "not found". +func (w *Waiter) UntilQueryLambdaVersionGone(ctx context.Context, workspace, name, version string) error { + return w.rc.RetryWithCheck(ctx, ResourceIsGone(ctx, func(ctx context.Context) error { + _, err := w.rc.GetQueryLambdaVersion(ctx, workspace, name, version) + return err + })) +} + +// UntilQueryLambdaVersionActive waits until the Virtual Instance is active. +func (w *Waiter) UntilQueryLambdaVersionActive(ctx context.Context, workspace, name, version string) error { + return w.rc.RetryWithCheck(ctx, + ResourceHasState(ctx, []option.QueryLambdaState{option.QueryLambdaActive}, []option.QueryLambdaState{option.QueryLambdaInvalid}, + func(ctx context.Context) (option.QueryLambdaState, error) { + ql, err := w.rc.GetQueryLambdaVersion(ctx, workspace, name, version) + return option.QueryLambdaState(ql.GetState()), err + })) +} diff --git a/wait/query_lambda_test.go b/wait/query_lambda_test.go new file mode 100644 index 00000000..5c6219da --- /dev/null +++ b/wait/query_lambda_test.go @@ -0,0 +1,36 @@ +package wait_test + +import ( + "context" + rockerr "github.com/rockset/rockset-go-client/errors" + "testing" + + "github.com/rockset/rockset-go-client/openapi" + "github.com/rockset/rockset-go-client/option" + "github.com/rockset/rockset-go-client/wait" + "github.com/stretchr/testify/assert" +) + +func TestWait_untilQueryLambdaActive(t *testing.T) { + ctx := context.TODO() + + rs := fakeRocksetClient() + rs.GetQueryLambdaVersionReturnsOnCall(0, openapi.QueryLambdaVersion{State: openapi.PtrString("")}, nil) + rs.GetQueryLambdaVersionReturnsOnCall(1, openapi.QueryLambdaVersion{State: openapi.PtrString(option.QueryLambdaActive.String())}, nil) + + err := wait.New(&rs).UntilQueryLambdaVersionActive(ctx, "ws", "ql", "v") + assert.NoError(t, err) + assert.Equal(t, 2, rs.GetQueryLambdaVersionCallCount()) +} + +func TestWait_untilQueryLambdaActive_invalid(t *testing.T) { + ctx := context.TODO() + + rs := fakeRocksetClient() + rs.GetQueryLambdaVersionReturnsOnCall(0, openapi.QueryLambdaVersion{State: openapi.PtrString("")}, nil) + rs.GetQueryLambdaVersionReturnsOnCall(1, openapi.QueryLambdaVersion{State: openapi.PtrString(option.QueryLambdaInvalid.String())}, nil) + + err := wait.New(&rs).UntilQueryLambdaVersionActive(ctx, "ws", "ql", "v") + assert.ErrorIs(t, err, rockerr.ErrBadWaitState) + assert.Equal(t, 2, rs.GetQueryLambdaVersionCallCount()) +} diff --git a/wait/query_test.go b/wait/query_test.go new file mode 100644 index 00000000..6b1ae6fb --- /dev/null +++ b/wait/query_test.go @@ -0,0 +1,24 @@ +package wait_test + +import ( + "context" + "github.com/rockset/rockset-go-client/option" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/rockset/rockset-go-client/openapi" + "github.com/rockset/rockset-go-client/wait" +) + +func TestWait_untilQueryDone(t *testing.T) { + ctx := context.TODO() + + rs := fakeRocksetClient() + rs.GetQueryInfoReturnsOnCall(0, openapi.QueryInfo{Status: openapi.PtrString(option.QueryRunning.String())}, nil) + rs.GetQueryInfoReturnsOnCall(1, openapi.QueryInfo{Status: openapi.PtrString(option.QueryCompleted.String())}, nil) + + err := wait.New(&rs).UntilQueryDone(ctx, "id") + assert.NoError(t, err) + assert.Equal(t, 2, rs.GetQueryInfoCallCount()) +} diff --git a/wait/virtual_instance.go b/wait/virtual_instance.go index 74f537cf..84da320a 100644 --- a/wait/virtual_instance.go +++ b/wait/virtual_instance.go @@ -12,10 +12,10 @@ import ( // UntilVirtualInstanceActive waits until the Virtual Instance is active. func (w *Waiter) UntilVirtualInstanceActive(ctx context.Context, id string) error { - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.VirtualInstanceActive}, - func(ctx context.Context) (string, error) { + return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []option.VirtualInstanceState{option.VirtualInstanceActive}, nil, + func(ctx context.Context) (option.VirtualInstanceState, error) { vi, err := w.rc.GetVirtualInstance(ctx, id) - return vi.GetState(), err + return option.VirtualInstanceState(vi.GetState()), err })) } @@ -29,20 +29,20 @@ func (w *Waiter) UntilVirtualInstanceGone(ctx context.Context, id string) error // UntilVirtualInstanceSuspended waits until the Virtual Instance is suspended. func (w *Waiter) UntilVirtualInstanceSuspended(ctx context.Context, id string) error { - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.VirtualInstanceSuspended}, - func(ctx context.Context) (string, error) { + return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []option.VirtualInstanceState{option.VirtualInstanceSuspended}, nil, + func(ctx context.Context) (option.VirtualInstanceState, error) { vi, err := w.rc.GetVirtualInstance(ctx, id) - return vi.GetState(), err + return option.VirtualInstanceState(vi.GetState()), err })) } // UntilMountActive waits until the collection mount is active, and queries can be issued to it on the // virtual instance. func (w *Waiter) UntilMountActive(ctx context.Context, vID, workspace, collection string) error { - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.MountActive}, - func(ctx context.Context) (string, error) { + return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []option.MountState{option.MountActive}, nil, + func(ctx context.Context) (option.MountState, error) { cm, err := w.rc.GetCollectionMount(ctx, vID, workspace+"."+collection) - return cm.GetState(), err + return option.MountState(cm.GetState()), err })) } diff --git a/wait/virtual_instance_test.go b/wait/virtual_instance_test.go index 9469db2b..5d613fd9 100644 --- a/wait/virtual_instance_test.go +++ b/wait/virtual_instance_test.go @@ -17,8 +17,8 @@ func TestWait_untilVirtualInstanceActive(t *testing.T) { ctx := context.TODO() rs := fakeRocksetClient() - rs.GetVirtualInstanceReturnsOnCall(0, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceInitializing)}, nil) - rs.GetVirtualInstanceReturnsOnCall(1, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceActive)}, nil) + rs.GetVirtualInstanceReturnsOnCall(0, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceInitializing.String())}, nil) + rs.GetVirtualInstanceReturnsOnCall(1, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceActive.String())}, nil) err := wait.New(&rs).UntilVirtualInstanceActive(ctx, "id") assert.NoError(t, err) @@ -29,9 +29,9 @@ func TestWait_untilVirtualInstanceSuspended(t *testing.T) { ctx := context.TODO() rs := fakeRocksetClient() - rs.GetVirtualInstanceReturnsOnCall(0, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceActive)}, nil) - rs.GetVirtualInstanceReturnsOnCall(1, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceSuspending)}, nil) - rs.GetVirtualInstanceReturnsOnCall(2, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceSuspended)}, nil) + rs.GetVirtualInstanceReturnsOnCall(0, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceActive.String())}, nil) + rs.GetVirtualInstanceReturnsOnCall(1, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceSuspending.String())}, nil) + rs.GetVirtualInstanceReturnsOnCall(2, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceSuspended.String())}, nil) err := wait.New(&rs).UntilVirtualInstanceSuspended(ctx, "id") assert.NoError(t, err) @@ -42,7 +42,7 @@ func TestWait_untilVirtualInstanceGone(t *testing.T) { ctx := context.TODO() rs := fakeRocksetClient() - rs.GetVirtualInstanceReturnsOnCall(0, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceActive)}, nil) + rs.GetVirtualInstanceReturnsOnCall(0, openapi.VirtualInstance{State: openapi.PtrString(option.VirtualInstanceActive.String())}, nil) rs.GetVirtualInstanceReturnsOnCall(1, openapi.VirtualInstance{}, NotFoundErr) err := wait.New(&rs).UntilVirtualInstanceGone(ctx, "id") @@ -54,8 +54,8 @@ func TestWait_untilMountActive(t *testing.T) { ctx := context.TODO() rs := fakeRocksetClient() - rs.GetCollectionMountReturnsOnCall(0, openapi.CollectionMount{State: openapi.PtrString(option.MountCreating)}, nil) - rs.GetCollectionMountReturnsOnCall(1, openapi.CollectionMount{State: openapi.PtrString(option.MountActive)}, nil) + rs.GetCollectionMountReturnsOnCall(0, openapi.CollectionMount{State: openapi.PtrString(option.MountCreating.String())}, nil) + rs.GetCollectionMountReturnsOnCall(1, openapi.CollectionMount{State: openapi.PtrString(option.MountActive.String())}, nil) err := wait.New(&rs).UntilMountActive(ctx, "id", "workspace", "collection") assert.NoError(t, err) @@ -66,8 +66,8 @@ func TestWait_untilMountGone404(t *testing.T) { ctx := context.TODO() rs := fakeRocksetClient() - rs.GetCollectionMountReturnsOnCall(0, openapi.CollectionMount{State: openapi.PtrString(option.MountActive)}, nil) - rs.GetCollectionMountReturnsOnCall(1, openapi.CollectionMount{State: openapi.PtrString(option.MountDeleting)}, nil) + rs.GetCollectionMountReturnsOnCall(0, openapi.CollectionMount{State: openapi.PtrString(option.MountActive.String())}, nil) + rs.GetCollectionMountReturnsOnCall(1, openapi.CollectionMount{State: openapi.PtrString(option.MountDeleting.String())}, nil) rs.GetCollectionMountReturnsOnCall(2, openapi.CollectionMount{}, NotFoundErr) err := wait.New(&rs).UntilMountGone(ctx, "id", "workspace", "collection") @@ -84,8 +84,8 @@ func TestWait_untilMountGone400(t *testing.T) { e404.ErrorModel.Message = &msg rs := fakeRocksetClient() - rs.GetCollectionMountReturnsOnCall(0, openapi.CollectionMount{State: openapi.PtrString(option.MountActive)}, nil) - rs.GetCollectionMountReturnsOnCall(1, openapi.CollectionMount{State: openapi.PtrString(option.MountDeleting)}, nil) + rs.GetCollectionMountReturnsOnCall(0, openapi.CollectionMount{State: openapi.PtrString(option.MountActive.String())}, nil) + rs.GetCollectionMountReturnsOnCall(1, openapi.CollectionMount{State: openapi.PtrString(option.MountDeleting.String())}, nil) rs.GetCollectionMountReturnsOnCall(2, openapi.CollectionMount{}, e404) err := wait.New(&rs).UntilMountGone(ctx, "id", "workspace", "collection") diff --git a/wait/wait.go b/wait/wait.go index 4b1c0126..78c7d20b 100644 --- a/wait/wait.go +++ b/wait/wait.go @@ -3,7 +3,7 @@ package wait import ( "context" "errors" - + "fmt" "github.com/rs/zerolog" rockerr "github.com/rockset/rockset-go-client/errors" @@ -25,6 +25,7 @@ type ResourceGetter interface { GetCollectionMount(ctx context.Context, id, collectionPath string) (openapi.CollectionMount, error) GetIntegration(ctx context.Context, name string) (openapi.Integration, error) GetQueryInfo(ctx context.Context, queryID string) (openapi.QueryInfo, error) + GetQueryLambdaVersion(ctx context.Context, workspace, name, version string) (openapi.QueryLambdaVersion, error) GetView(ctx context.Context, workspace, name string) (openapi.View, error) GetVirtualInstance(ctx context.Context, id string) (openapi.VirtualInstance, error) GetWorkspace(ctx context.Context, name string) (openapi.Workspace, error) @@ -34,8 +35,12 @@ func New(rs ResourceGetter) *Waiter { return &Waiter{rs} } -// ResourceHasState implements RetryFn to wait until the resource has the desired state -func ResourceHasState[T comparable](ctx context.Context, states []T, +var ErrBadWaitState = errors.New("encountered bad state while waiting") + +// ResourceHasState implements RetryFn to wait until the resource has the desired state, and if a bad state is +// encountered it will return ErrBadWaitState +func ResourceHasState[T fmt.Stringer](ctx context.Context, validStates, badStates []T, + // TODO should T be Stringer instead? Then all fn func(ctx context.Context) (T, error)) retry.CheckFn { return func() (bool, error) { zl := zerolog.Ctx(ctx) @@ -44,11 +49,16 @@ func ResourceHasState[T comparable](ctx context.Context, states []T, return false, err } - for _, s := range states { - if state == s { + for _, s := range validStates { + if state.String() == s.String() { return false, nil } } + for _, s := range badStates { + if state.String() == s.String() { + return false, rockerr.NewBadWaitStateError(state.String()) // fmt.Errorf("%w: %v", ErrBadWaitState, state) + } + } zl.Trace().Any("current", state).Msg("waiting for resource state") diff --git a/wait/wait_test.go b/wait/wait_test.go index 92308e00..88621fbb 100644 --- a/wait/wait_test.go +++ b/wait/wait_test.go @@ -96,24 +96,27 @@ func (s *WaitTestSuite) TestResourceIsGone() { } }) - retry, err := rc() - assert.True(s.T(), retry) + r, err := rc() + assert.True(s.T(), r) assert.NoError(s.T(), err) - retry, err = rc() - assert.False(s.T(), retry) + r, err = rc() + assert.False(s.T(), r) assert.NoError(s.T(), err) - retry, err = rc() - assert.False(s.T(), retry) + r, err = rc() + assert.False(s.T(), r) assert.Error(s.T(), err) } +type state string + +func (s state) String() string { return string(s) } func (s *WaitTestSuite) TestResourceHasState() { ctx := context.TODO() var counter int - rc := wait.ResourceHasState(ctx, []string{"foo", "bar"}, func(ctx context.Context) (string, error) { + rc := wait.ResourceHasState(ctx, []state{"foo", "bar"}, nil, func(ctx context.Context) (state, error) { defer func() { counter++ }() switch counter { @@ -126,15 +129,45 @@ func (s *WaitTestSuite) TestResourceHasState() { } }) - retry, err := rc() - assert.True(s.T(), retry) + r, err := rc() + assert.True(s.T(), r) assert.NoError(s.T(), err) - retry, err = rc() - assert.False(s.T(), retry) + r, err = rc() + assert.False(s.T(), r) assert.NoError(s.T(), err) - retry, err = rc() - assert.False(s.T(), retry) + r, err = rc() + assert.False(s.T(), r) assert.Error(s.T(), err) } + +func (s *WaitTestSuite) TestResourceHasState_badState() { + ctx := context.TODO() + var counter int + + rc := wait.ResourceHasState(ctx, []state{"foo", "bar"}, []state{"err"}, func(ctx context.Context) (state, error) { + defer func() { counter++ }() + + switch counter { + case 0: + return "baz", nil + case 1: + return "bar", nil + default: + return "err", nil + } + }) + + r, err := rc() + assert.True(s.T(), r) + assert.NoError(s.T(), err) + + r, err = rc() + assert.False(s.T(), r) + assert.NoError(s.T(), err) + + r, err = rc() + assert.False(s.T(), r) + assert.ErrorIs(s.T(), err, rockerr.ErrBadWaitState) +}