diff --git a/option/query_lambda.go b/option/query_lambda.go index a532428b..9f16dd3d 100644 --- a/option/query_lambda.go +++ b/option/query_lambda.go @@ -2,6 +2,11 @@ package option import "github.com/rockset/rockset-go-client/openapi" +const ( + QueryLambdaActive = "ACTIVE" + QueryLambdaInvalid = "INVALID" +) + type ExecuteQueryLambdaRequest struct { openapi.ExecuteQueryLambdaRequest Tag string diff --git a/wait/collections.go b/wait/collections.go index 2e84bdb3..7a66977e 100644 --- a/wait/collections.go +++ b/wait/collections.go @@ -11,7 +11,7 @@ import ( // UntilCollectionReady waits until the collection is ready. func (w *Waiter) UntilCollectionReady(ctx context.Context, workspace, name string) error { - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.CollectionStatusReady}, + return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.CollectionStatusReady}, nil, func(ctx context.Context) (string, error) { c, err := w.rc.GetCollection(ctx, workspace, name) return c.GetStatus(), err diff --git a/wait/fake/fake_resource_getter.go b/wait/fake/fake_resource_getter.go index 985c243d..2fd8ae5c 100644 --- a/wait/fake/fake_resource_getter.go +++ b/wait/fake/fake_resource_getter.go @@ -84,6 +84,22 @@ type FakeResourceGetter struct { result1 openapi.QueryInfo result2 error } + GetQueryLambdaVersionStub func(context.Context, string, string, string) (openapi.QueryLambdaVersion, error) + getQueryLambdaVersionMutex sync.RWMutex + getQueryLambdaVersionArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 string + arg4 string + } + getQueryLambdaVersionReturns struct { + result1 openapi.QueryLambdaVersion + result2 error + } + getQueryLambdaVersionReturnsOnCall map[int]struct { + result1 openapi.QueryLambdaVersion + result2 error + } GetViewStub func(context.Context, string, string) (openapi.View, error) getViewMutex sync.RWMutex getViewArgsForCall []struct { @@ -483,6 +499,73 @@ func (fake *FakeResourceGetter) GetQueryInfoReturnsOnCall(i int, result1 openapi }{result1, result2} } +func (fake *FakeResourceGetter) GetQueryLambdaVersion(arg1 context.Context, arg2 string, arg3 string, arg4 string) (openapi.QueryLambdaVersion, error) { + fake.getQueryLambdaVersionMutex.Lock() + ret, specificReturn := fake.getQueryLambdaVersionReturnsOnCall[len(fake.getQueryLambdaVersionArgsForCall)] + fake.getQueryLambdaVersionArgsForCall = append(fake.getQueryLambdaVersionArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 string + arg4 string + }{arg1, arg2, arg3, arg4}) + stub := fake.GetQueryLambdaVersionStub + fakeReturns := fake.getQueryLambdaVersionReturns + fake.recordInvocation("GetQueryLambdaVersion", []interface{}{arg1, arg2, arg3, arg4}) + fake.getQueryLambdaVersionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionCallCount() int { + fake.getQueryLambdaVersionMutex.RLock() + defer fake.getQueryLambdaVersionMutex.RUnlock() + return len(fake.getQueryLambdaVersionArgsForCall) +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionCalls(stub func(context.Context, string, string, string) (openapi.QueryLambdaVersion, error)) { + fake.getQueryLambdaVersionMutex.Lock() + defer fake.getQueryLambdaVersionMutex.Unlock() + fake.GetQueryLambdaVersionStub = stub +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionArgsForCall(i int) (context.Context, string, string, string) { + fake.getQueryLambdaVersionMutex.RLock() + defer fake.getQueryLambdaVersionMutex.RUnlock() + argsForCall := fake.getQueryLambdaVersionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionReturns(result1 openapi.QueryLambdaVersion, result2 error) { + fake.getQueryLambdaVersionMutex.Lock() + defer fake.getQueryLambdaVersionMutex.Unlock() + fake.GetQueryLambdaVersionStub = nil + fake.getQueryLambdaVersionReturns = struct { + result1 openapi.QueryLambdaVersion + result2 error + }{result1, result2} +} + +func (fake *FakeResourceGetter) GetQueryLambdaVersionReturnsOnCall(i int, result1 openapi.QueryLambdaVersion, result2 error) { + fake.getQueryLambdaVersionMutex.Lock() + defer fake.getQueryLambdaVersionMutex.Unlock() + fake.GetQueryLambdaVersionStub = nil + if fake.getQueryLambdaVersionReturnsOnCall == nil { + fake.getQueryLambdaVersionReturnsOnCall = make(map[int]struct { + result1 openapi.QueryLambdaVersion + result2 error + }) + } + fake.getQueryLambdaVersionReturnsOnCall[i] = struct { + result1 openapi.QueryLambdaVersion + result2 error + }{result1, result2} +} + func (fake *FakeResourceGetter) GetView(arg1 context.Context, arg2 string, arg3 string) (openapi.View, error) { fake.getViewMutex.Lock() ret, specificReturn := fake.getViewReturnsOnCall[len(fake.getViewArgsForCall)] @@ -816,6 +899,8 @@ func (fake *FakeResourceGetter) Invocations() map[string][][]interface{} { defer fake.getIntegrationMutex.RUnlock() fake.getQueryInfoMutex.RLock() defer fake.getQueryInfoMutex.RUnlock() + fake.getQueryLambdaVersionMutex.RLock() + defer fake.getQueryLambdaVersionMutex.RUnlock() fake.getViewMutex.RLock() defer fake.getViewMutex.RUnlock() fake.getVirtualInstanceMutex.RLock() diff --git a/wait/query.go b/wait/query.go index c5374c14..0ad535be 100644 --- a/wait/query.go +++ b/wait/query.go @@ -6,12 +6,14 @@ import ( "github.com/rockset/rockset-go-client/option" ) -// UntilQueryDone waits until queryID has either completed, errored, or been cancelled. +// UntilQueryDone waits until queryID has either completed. +// Returns ErrBadWaitState if the query failed or was cancelled. func (w *Waiter) UntilQueryDone(ctx context.Context, queryID string) error { - // TODO should this only wait for COMPLETED and return an error for ERROR and CANCELLED? - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []option.QueryState{option.QueryCompleted, option.QueryError, option.QueryCancelled}, - func(ctx context.Context) (option.QueryState, error) { - q, err := w.rc.GetQueryInfo(ctx, queryID) - return option.QueryState(q.GetStatus()), err - })) + return w.rc.RetryWithCheck(ctx, + ResourceHasState(ctx, []option.QueryState{option.QueryCompleted}, + []option.QueryState{option.QueryError, option.QueryCancelled}, + func(ctx context.Context) (option.QueryState, error) { + q, err := w.rc.GetQueryInfo(ctx, queryID) + return option.QueryState(q.GetStatus()), err + })) } diff --git a/wait/query_lambda.go b/wait/query_lambda.go new file mode 100644 index 00000000..35edb6fe --- /dev/null +++ b/wait/query_lambda.go @@ -0,0 +1,25 @@ +package wait + +import ( + "context" + + "github.com/rockset/rockset-go-client/option" +) + +// UntilQueryLambdaVersionGone waits until a query lambda is deleted, i.e. GetQueryLambda() returns "not found". +func (w *Waiter) UntilQueryLambdaVersionGone(ctx context.Context, workspace, name, version string) error { + return w.rc.RetryWithCheck(ctx, ResourceIsGone(ctx, func(ctx context.Context) error { + _, err := w.rc.GetQueryLambdaVersion(ctx, workspace, name, version) + return err + })) +} + +// UntilQueryLambdaVersionActive waits until the Virtual Instance is active. +func (w *Waiter) UntilQueryLambdaVersionActive(ctx context.Context, workspace, name, version string) error { + return w.rc.RetryWithCheck(ctx, + ResourceHasState(ctx, []string{option.QueryLambdaActive}, []string{option.QueryLambdaInvalid}, + func(ctx context.Context) (string, error) { + ql, err := w.rc.GetQueryLambdaVersion(ctx, workspace, name, version) + return ql.GetState(), err + })) +} diff --git a/wait/query_lambda_test.go b/wait/query_lambda_test.go new file mode 100644 index 00000000..6bd28384 --- /dev/null +++ b/wait/query_lambda_test.go @@ -0,0 +1,35 @@ +package wait_test + +import ( + "context" + "testing" + + "github.com/rockset/rockset-go-client/openapi" + "github.com/rockset/rockset-go-client/option" + "github.com/rockset/rockset-go-client/wait" + "github.com/stretchr/testify/assert" +) + +func TestWait_untilQueryLambdaActive(t *testing.T) { + ctx := context.TODO() + + rs := fakeRocksetClient() + rs.GetQueryLambdaVersionReturnsOnCall(0, openapi.QueryLambdaVersion{State: openapi.PtrString("")}, nil) + rs.GetQueryLambdaVersionReturnsOnCall(1, openapi.QueryLambdaVersion{State: openapi.PtrString(option.QueryLambdaActive)}, nil) + + err := wait.New(&rs).UntilQueryLambdaVersionActive(ctx, "ws", "ql", "v") + assert.NoError(t, err) + assert.Equal(t, 2, rs.GetQueryLambdaVersionCallCount()) +} + +func TestWait_untilQueryLambdaActive_invalid(t *testing.T) { + ctx := context.TODO() + + rs := fakeRocksetClient() + rs.GetQueryLambdaVersionReturnsOnCall(0, openapi.QueryLambdaVersion{State: openapi.PtrString("")}, nil) + rs.GetQueryLambdaVersionReturnsOnCall(1, openapi.QueryLambdaVersion{State: openapi.PtrString(option.QueryLambdaInvalid)}, nil) + + err := wait.New(&rs).UntilQueryLambdaVersionActive(ctx, "ws", "ql", "v") + assert.ErrorIs(t, err, wait.ErrBadWaitState) + assert.Equal(t, 2, rs.GetQueryLambdaVersionCallCount()) +} diff --git a/wait/virtual_instance.go b/wait/virtual_instance.go index 74f537cf..d66d8b50 100644 --- a/wait/virtual_instance.go +++ b/wait/virtual_instance.go @@ -12,7 +12,7 @@ import ( // UntilVirtualInstanceActive waits until the Virtual Instance is active. func (w *Waiter) UntilVirtualInstanceActive(ctx context.Context, id string) error { - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.VirtualInstanceActive}, + return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.VirtualInstanceActive}, nil, func(ctx context.Context) (string, error) { vi, err := w.rc.GetVirtualInstance(ctx, id) return vi.GetState(), err @@ -29,7 +29,7 @@ func (w *Waiter) UntilVirtualInstanceGone(ctx context.Context, id string) error // UntilVirtualInstanceSuspended waits until the Virtual Instance is suspended. func (w *Waiter) UntilVirtualInstanceSuspended(ctx context.Context, id string) error { - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.VirtualInstanceSuspended}, + return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.VirtualInstanceSuspended}, nil, func(ctx context.Context) (string, error) { vi, err := w.rc.GetVirtualInstance(ctx, id) return vi.GetState(), err @@ -39,7 +39,7 @@ func (w *Waiter) UntilVirtualInstanceSuspended(ctx context.Context, id string) e // UntilMountActive waits until the collection mount is active, and queries can be issued to it on the // virtual instance. func (w *Waiter) UntilMountActive(ctx context.Context, vID, workspace, collection string) error { - return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.MountActive}, + return w.rc.RetryWithCheck(ctx, ResourceHasState(ctx, []string{option.MountActive}, nil, func(ctx context.Context) (string, error) { cm, err := w.rc.GetCollectionMount(ctx, vID, workspace+"."+collection) return cm.GetState(), err diff --git a/wait/wait.go b/wait/wait.go index 4b1c0126..091ef4df 100644 --- a/wait/wait.go +++ b/wait/wait.go @@ -3,6 +3,7 @@ package wait import ( "context" "errors" + "fmt" "github.com/rs/zerolog" @@ -25,6 +26,7 @@ type ResourceGetter interface { GetCollectionMount(ctx context.Context, id, collectionPath string) (openapi.CollectionMount, error) GetIntegration(ctx context.Context, name string) (openapi.Integration, error) GetQueryInfo(ctx context.Context, queryID string) (openapi.QueryInfo, error) + GetQueryLambdaVersion(ctx context.Context, workspace, name, version string) (openapi.QueryLambdaVersion, error) GetView(ctx context.Context, workspace, name string) (openapi.View, error) GetVirtualInstance(ctx context.Context, id string) (openapi.VirtualInstance, error) GetWorkspace(ctx context.Context, name string) (openapi.Workspace, error) @@ -34,8 +36,12 @@ func New(rs ResourceGetter) *Waiter { return &Waiter{rs} } -// ResourceHasState implements RetryFn to wait until the resource has the desired state -func ResourceHasState[T comparable](ctx context.Context, states []T, +var ErrBadWaitState = errors.New("encountered bad state while waiting") + +// ResourceHasState implements RetryFn to wait until the resource has the desired state, and if a bad state is +// encountered it will return ErrBadWaitState +func ResourceHasState[T comparable](ctx context.Context, validStates, badStates []T, + // TODO should T be Stringer instead? Then all fn func(ctx context.Context) (T, error)) retry.CheckFn { return func() (bool, error) { zl := zerolog.Ctx(ctx) @@ -44,11 +50,16 @@ func ResourceHasState[T comparable](ctx context.Context, states []T, return false, err } - for _, s := range states { + for _, s := range validStates { if state == s { return false, nil } } + for _, s := range badStates { + if state == s { + return false, fmt.Errorf("%w: %v", ErrBadWaitState, state) + } + } zl.Trace().Any("current", state).Msg("waiting for resource state") diff --git a/wait/wait_test.go b/wait/wait_test.go index 92308e00..72e4417c 100644 --- a/wait/wait_test.go +++ b/wait/wait_test.go @@ -96,16 +96,16 @@ func (s *WaitTestSuite) TestResourceIsGone() { } }) - retry, err := rc() - assert.True(s.T(), retry) + r, err := rc() + assert.True(s.T(), r) assert.NoError(s.T(), err) - retry, err = rc() - assert.False(s.T(), retry) + r, err = rc() + assert.False(s.T(), r) assert.NoError(s.T(), err) - retry, err = rc() - assert.False(s.T(), retry) + r, err = rc() + assert.False(s.T(), r) assert.Error(s.T(), err) } @@ -113,7 +113,7 @@ func (s *WaitTestSuite) TestResourceHasState() { ctx := context.TODO() var counter int - rc := wait.ResourceHasState(ctx, []string{"foo", "bar"}, func(ctx context.Context) (string, error) { + rc := wait.ResourceHasState(ctx, []string{"foo", "bar"}, nil, func(ctx context.Context) (string, error) { defer func() { counter++ }() switch counter { @@ -126,15 +126,45 @@ func (s *WaitTestSuite) TestResourceHasState() { } }) - retry, err := rc() - assert.True(s.T(), retry) + r, err := rc() + assert.True(s.T(), r) assert.NoError(s.T(), err) - retry, err = rc() - assert.False(s.T(), retry) + r, err = rc() + assert.False(s.T(), r) assert.NoError(s.T(), err) - retry, err = rc() - assert.False(s.T(), retry) + r, err = rc() + assert.False(s.T(), r) assert.Error(s.T(), err) } + +func (s *WaitTestSuite) TestResourceHasState_badState() { + ctx := context.TODO() + var counter int + + rc := wait.ResourceHasState(ctx, []string{"foo", "bar"}, []string{"err"}, func(ctx context.Context) (string, error) { + defer func() { counter++ }() + + switch counter { + case 0: + return "baz", nil + case 1: + return "bar", nil + default: + return "err", nil + } + }) + + r, err := rc() + assert.True(s.T(), r) + assert.NoError(s.T(), err) + + r, err = rc() + assert.False(s.T(), r) + assert.NoError(s.T(), err) + + r, err = rc() + assert.False(s.T(), r) + assert.ErrorIs(s.T(), err, wait.ErrBadWaitState) +}