From 4998581d0143a2f87526026514d446707dee18d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Garc=C3=ADa=20Hierro?= Date: Fri, 5 May 2023 11:07:25 +0100 Subject: [PATCH] refactor: remove embedding of context.Context in resolve.Context (#522) Instead, use an approach very similar to the one used by net/http.Request Also, update CI to use Go 1.19 and 1.20, enable race detector and run the linters as a separate job. --- .github/workflows/ci-full.yml | 20 - .github/workflows/main.yml | 44 +- Makefile | 39 +- .../graphql_datasource/graphql_datasource.go | 11 +- .../graphql_datasource_test.go | 67 +- .../graphql_datasource/graphql_sse_handler.go | 23 +- .../graphql_subscription_client_test.go | 2 + .../graphql_tws_handler_test.go | 8 + .../graphql_ws_handler_test.go | 6 + .../datasource/httpclient/httpclient.go | 73 +- .../datasource/kafka_datasource/config.go | 141 --- .../kafka_datasource/config_test.go | 149 --- .../kafka_datasource/kafka_consumer_group.go | 323 ------- .../kafka_consumer_group_test.go | 391 -------- .../kafka_datasource/kafka_datasource.go | 75 -- .../kafka_datasource/kafka_datasource_test.go | 317 ------- .../sarama_config_parameters_test.go | 885 ------------------ .../kafka_datasource/testdata/kafka_jaas.conf | 8 - .../staticdatasource/static_datasource.go | 2 +- pkg/engine/resolve/dataloader_test.go | 24 +- pkg/engine/resolve/defer_test.go | 2 +- pkg/engine/resolve/fetcher.go | 4 +- pkg/engine/resolve/inputtemplate.go | 42 +- pkg/engine/resolve/inputtemplate_test.go | 4 +- pkg/engine/resolve/resolve.go | 34 +- pkg/engine/resolve/resolve_mock_test.go | 73 +- pkg/engine/resolve/resolve_test.go | 136 +-- .../datasource_http_polling_stream.go | 26 +- pkg/fastbuffer/fastbuffer.go | 10 + pkg/graphql/execution_engine_v2.go | 2 +- .../execution_engine_v2_norace_test.go | 170 ++++ pkg/graphql/execution_engine_v2_test.go | 155 --- pkg/subscription/context.go | 37 +- pkg/subscription/context_test.go | 8 +- pkg/subscription/handler.go | 8 +- pkg/subscription/handler_test.go | 2 +- pkg/subscription/mock_client_test.go | 20 +- .../federation_intergation_test.go | 4 + 38 files changed, 585 insertions(+), 2760 deletions(-) delete mode 100644 .github/workflows/ci-full.yml delete mode 100644 pkg/engine/datasource/kafka_datasource/config.go delete mode 100644 pkg/engine/datasource/kafka_datasource/config_test.go delete mode 100644 pkg/engine/datasource/kafka_datasource/kafka_consumer_group.go delete mode 100644 pkg/engine/datasource/kafka_datasource/kafka_consumer_group_test.go delete mode 100644 pkg/engine/datasource/kafka_datasource/kafka_datasource.go delete mode 100644 pkg/engine/datasource/kafka_datasource/kafka_datasource_test.go delete mode 100644 pkg/engine/datasource/kafka_datasource/sarama_config_parameters_test.go delete mode 100644 pkg/engine/datasource/kafka_datasource/testdata/kafka_jaas.conf create mode 100644 pkg/graphql/execution_engine_v2_norace_test.go diff --git a/.github/workflows/ci-full.yml b/.github/workflows/ci-full.yml deleted file mode 100644 index 2125ed6aa..000000000 --- a/.github/workflows/ci-full.yml +++ /dev/null @@ -1,20 +0,0 @@ -name: ci-full-manual -on: [workflow_dispatch] -jobs: - build: - name: Build (go ${{ matrix.go }}/${{ matrix.os }}) - runs-on: ${{ matrix.os }} - strategy: - matrix: - go: [ '1.18' ] - os: [ 'ubuntu-latest'] - steps: - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v1 - with: - go-version: ${{ matrix.go }} - id: go - - name: Check out code into the Go module directory - uses: actions/checkout@v1 - - name: CI - run: 
make ci-full diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8e5a9d89a..11ea6fd96 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,24 +5,56 @@ on: branches: - master jobs: - build: - name: Build (go ${{ matrix.go }}/${{ matrix.os }}) + test: + name: Build and test (go ${{ matrix.go }} / ${{ matrix.os }}) runs-on: ${{ matrix.os }} strategy: matrix: - go: [ '1.18' ] + go: [ '1.19', '1.20' ] os: [ubuntu-latest, windows-latest] steps: - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v1 + uses: actions/setup-go@v4 with: - go-version: ${{ matrix.go }} + go-version: ^${{ matrix.go }} id: go - name: Set git to use LF run: | git config --global core.autocrlf false git config --global core.eol lf - name: Check out code into the Go module directory - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: CI run: make ci + - name: Run tests under race detector + if: runner.os != 'Windows' # These are very slow on Windows, skip them + run: make test-race + + lint: + name: Linters + runs-on: ubuntu-latest + steps: + - name: Set up Go 1.20 + uses: actions/setup-go@v4 + with: + go-version: ^1.20 + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + - name: Run linters + uses: golangci/golangci-lint-action@v3 + with: + version: v1.51.1 + args: --timeout=3m + ci: + name: CI Success + if: ${{ always() }} + runs-on: ubuntu-latest + needs: [test, lint] + steps: + - run: exit 1 + if: >- + ${{ + contains(needs.*.result, 'failure') + || contains(needs.*.result, 'cancelled') + }} diff --git a/Makefile b/Makefile index 8af39d6d3..23fb63a45 100644 --- a/Makefile +++ b/Makefile @@ -1,40 +1,32 @@ -GOLANG_CI_VERSION = "v1.51.1" -GOLANG_CI_VERSION_SHORT = "1.51.1" -HAS_GOLANG_CI_LINT := $(shell command -v /tmp/ci/golangci-lint;) -INSTALLED_VERSION := $(shell command -v /tmp/ci/golangci-lint version;) -HAS_CORRECT_VERSION := $(shell command -v if [[ $(INSTALLED_VERSION) == *$(GOLANG_CI_VERSION_SHORT)* ]]; echo "OK" fi) - -.PHONY: bootstrap - .PHONY: test test: - go test --short -count=1 ./... + go test ./... -.PHONY: test-full -test-full: +.PHONY: test-quick +test-quick: go test -count=1 ./... +.PHONY: test-race +test-race: + go test -race ./... + # updateTestFixtures will update all! golden fixtures .PHONY: updateTestFixtures updateTestFixtures: go test ./pkg/... -update -.PHONY: lint -lint: - /tmp/ci/golangci-lint run - .PHONY: format format: go fmt ./... 
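Before the remaining hunks, a quick illustration of the core refactor: resolve.Context no longer embeds context.Context; it holds one in an unexported field, following the net/http.Request approach named in the commit message. The sketch below is a minimal reconstruction, not the actual implementation: the constructor and accessor names come from the resolve.NewContext(...) and ctx.Context() calls in the test hunks further down, while WithContext and the nil checks are assumptions drawn from the http.Request analogy.

package resolve

import "context"

// Context carries per-request resolver state. Rather than embedding
// context.Context, it now stores one in an unexported field.
type Context struct {
	ctx context.Context
	// ... remaining resolver state elided
}

// NewContext wraps a context.Context, as used by the updated tests below.
func NewContext(ctx context.Context) *Context {
	if ctx == nil {
		panic("nil context")
	}
	return &Context{ctx: ctx}
}

// Context returns the underlying context, mirroring (*http.Request).Context.
func (c *Context) Context() context.Context {
	return c.ctx
}

// WithContext returns a shallow copy with the context replaced, mirroring
// (*http.Request).WithContext. (Assumed helper; the patch may differ.)
func (c *Context) WithContext(ctx context.Context) *Context {
	if ctx == nil {
		panic("nil context")
	}
	c2 := *c
	c2.ctx = ctx
	return &c2
}

The payoff is the same as in net/http: callers that need a plain context.Context take ctx.Context() instead of treating resolve.Context itself as one, which keeps cancellation and deadlines explicit at each call site.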
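Related to the same refactor, the pkg/engine/datasource/httpclient hunks below replace the context-value plumbing for undefined variables (CtxSetUndefinedVariables / CtxGetUndefinedVariables) with data carried in the request input itself, under a top-level "undefined" key. A short usage sketch, modeled on the new SetUndefinedVariables / UndefinedVariables functions and the updated TestSource_Load; the input literal here is illustrative only:

package main

import (
	"fmt"

	"github.com/wundergraph/graphql-go-tools/pkg/engine/datasource/httpclient"
)

func main() {
	input := []byte(`{"body":{"variables":{"a":null,"b":null,"c":null}}}`)

	// Embed the undefined-variable names in the input document itself
	// (stored under the top-level "undefined" key) instead of smuggling
	// them through a context.Context value.
	input = httpclient.SetUndefinedVariables(input, []string{"a", "c"})

	// Any later consumer of the same input can recover the list without
	// needing the originating context.
	fmt.Println(httpclient.UndefinedVariables(input)) // [a c]
}

With the list embedded in the input bytes, Source.Load needs its context.Context only for cancellation and deadlines, and the same input can be inspected anywhere without the original context.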
.PHONY: prepare-merge
-prepare-merge: format test lint
+prepare-merge: format test

.PHONY: ci
-ci: bootstrap test lint
+ci: test

-.PHONY: ci-full
-ci-full: bootstrap test-full lint
+.PHONY: ci-quick
+ci-quick: test-quick

.PHONY: generate
generate: $(GOPATH)/bin/go-enum $(GOPATH)/bin/mockgen $(GOPATH)/bin/stringer
@@ -52,12 +44,3 @@ $(GOPATH)/bin/mockgen:
$(GOPATH)/bin/stringer:
	go get -u -a golang.org/x/tools/cmd/stringer
	go install golang.org/x/tools/cmd/stringer
-
-bootstrap:
-ifndef HAS_GOLANG_CI_LINT
-	curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh| sh -s -- -b /tmp/ci ${GOLANG_CI_VERSION}
-endif
-
-updateci:
-	rm /tmp/ci/golangci-lint
-	curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh| sh -s -- -b /tmp/ci ${GOLANG_CI_VERSION}
diff --git a/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/pkg/engine/datasource/graphql_datasource/graphql_datasource.go
index 0eda763e9..0c3802ac0 100644
--- a/pkg/engine/datasource/graphql_datasource/graphql_datasource.go
+++ b/pkg/engine/datasource/graphql_datasource/graphql_datasource.go
@@ -1329,7 +1329,8 @@ type Source struct {
	httpClient *http.Client
}

-func (s *Source) compactAndUnNullVariables(input []byte, undefinedVariables []string) []byte {
+func (s *Source) compactAndUnNullVariables(input []byte) []byte {
+	undefinedVariables := httpclient.UndefinedVariables(input)
	variables, _, _, err := jsonparser.Get(input, "body", "variables")
	if err != nil {
		return input
@@ -1339,7 +1340,9 @@ func (s *Source) compactAndUnNullVariables(input []byte, undefinedVariables []st
	}
	if bytes.ContainsAny(variables, " \t\n\r") {
		buf := bytes.NewBuffer(make([]byte, 0, len(variables)))
-		_ = json.Compact(buf, variables)
+		if err := json.Compact(buf, variables); err != nil {
+			panic(fmt.Errorf("compacting variables: %w", err))
+		}
		variables = buf.Bytes()
	}

@@ -1409,9 +1412,7 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) {
}

func (s *Source) Load(ctx context.Context, input []byte, writer io.Writer) (err error) {
-	undefinedVariables := httpclient.CtxGetUndefinedVariables(ctx)
-
-	input = s.compactAndUnNullVariables(input, undefinedVariables)
+	input = s.compactAndUnNullVariables(input)
	return httpclient.Do(s.httpClient, ctx, input, writer)
}

diff --git a/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go
index 33e1c032d..23ff43e0a 100644
--- a/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go
+++ b/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go
@@ -7801,23 +7801,23 @@ func TestSubscriptionSource_Start(t *testing.T) {

	t.Run("invalid json: should stop before sending to upstream", func(t *testing.T) {
		next := make(chan []byte)
-		ctx := context.Background()
-		defer ctx.Done()
+		ctx := resolve.NewContext(context.Background())
+		defer ctx.Context().Done()

-		source := newSubscriptionSource(ctx)
+		source := newSubscriptionSource(ctx.Context())
		chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: "#test") { text createdBy } }"}`)
-		err := source.Start(ctx, chatSubscriptionOptions, next)
+		err := source.Start(ctx.Context(), chatSubscriptionOptions, next)
		require.ErrorIs(t, err, resolve.ErrUnableToResolve)
	})

	t.Run("invalid syntax (roomNam)", func(t *testing.T) {
		next := make(chan []byte)
-		ctx := 
context.Background() - defer ctx.Done() + ctx := resolve.NewContext(context.Background()) + defer ctx.Context().Done() - source := newSubscriptionSource(ctx) + source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, next) + err := source.Start(ctx.Context(), chatSubscriptionOptions, next) require.NoError(t, err) msg, ok := <-next @@ -7851,12 +7851,12 @@ func TestSubscriptionSource_Start(t *testing.T) { t.Run("should successfully subscribe with chat example", func(t *testing.T) { next := make(chan []byte) - ctx := context.Background() - defer ctx.Done() + ctx := resolve.NewContext(context.Background()) + defer ctx.Context().Done() - source := newSubscriptionSource(ctx) + source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, next) + err := source.Start(ctx.Context(), chatSubscriptionOptions, next) require.NoError(t, err) username := "myuser" @@ -7913,12 +7913,12 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { t.Run("invalid syntax (roomNam)", func(t *testing.T) { next := make(chan []byte) - ctx := context.Background() - defer ctx.Done() + ctx := resolve.NewContext(context.Background()) + defer ctx.Context().Done() - source := newSubscriptionSource(ctx) + source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, next) + err := source.Start(ctx.Context(), chatSubscriptionOptions, next) require.NoError(t, err) msg, ok := <-next @@ -7952,12 +7952,12 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { t.Run("should successfully subscribe with chat example", func(t *testing.T) { next := make(chan []byte) - ctx := context.Background() - defer ctx.Done() + ctx := resolve.NewContext(context.Background()) + defer ctx.Context().Done() - source := newSubscriptionSource(ctx) + source := newSubscriptionSource(ctx.Context()) chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomName: \"#test\") { text createdBy } }"}`) - err := source.Start(ctx, chatSubscriptionOptions, next) + err := source.Start(ctx.Context(), chatSubscriptionOptions, next) require.NoError(t, err) username := "myuser" @@ -8159,7 +8159,8 @@ func TestSource_Load(t *testing.T) { buf := bytes.NewBuffer(nil) undefinedVariables := []string{"a", "c"} - ctx := httpclient.CtxSetUndefinedVariables(context.Background(), undefinedVariables) + ctx := context.Background() + input = httpclient.SetUndefinedVariables(input, undefinedVariables) require.NoError(t, src.Load(ctx, input, buf)) assert.Equal(t, `{"variables":{"b":null}}`, buf.String()) @@ -8171,7 +8172,7 @@ func TestUnNullVariables(t *testing.T) { t.Run("should not unnull variables if not enabled", func(t *testing.T) { t.Run("two variables, one null", func(t 
*testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"a":null,"b":true}}}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"a":null,"b":true}}}`)) expected := `{"body":{"variables":{"a":null,"b":true}}}` assert.Equal(t, expected, string(out)) }) @@ -8179,77 +8180,77 @@ func TestUnNullVariables(t *testing.T) { t.Run("variables with whitespace and empty objects", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"email":null,"firstName": "FirstTest", "lastName":"LastTest","phone":123456,"preferences":{ "notifications":{}},"password":"password"}},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"email":null,"firstName": "FirstTest", "lastName":"LastTest","phone":123456,"preferences":{ "notifications":{}},"password":"password"}},"unnull_variables":true}`)) expected := `{"body":{"variables":{"firstName":"FirstTest","lastName":"LastTest","phone":123456,"password":"password"}},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) t.Run("empty variables", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{}},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{}},"unnull_variables":true}`)) expected := `{"body":{"variables":{}},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) t.Run("null inside an array", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"list":["a",null,"b"]}},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"list":["a",null,"b"]}},"unnull_variables":true}`)) expected := `{"body":{"variables":{"list":["a",null,"b"]}},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) t.Run("complex null inside nested objects and arrays", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"a":null, "b": {"key":null, "nested": {"nestedkey": null}}, "arr": ["1", null, "3"], "d": {"nested_arr":["4",null,"6"]}}},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"a":null, "b": {"key":null, "nested": {"nestedkey": null}}, "arr": ["1", null, "3"], "d": {"nested_arr":["4",null,"6"]}}},"unnull_variables":true}`)) expected := `{"body":{"variables":{"b":{"key":null,"nested":{"nestedkey":null}},"arr":["1",null,"3"],"d":{"nested_arr":["4",null,"6"]}}},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) t.Run("two variables, one null", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"a":null,"b":true}},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"a":null,"b":true}},"unnull_variables":true}`)) expected := `{"body":{"variables":{"b":true}},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) t.Run("two variables, one null reverse", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"a":true,"b":null}},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"a":true,"b":null}},"unnull_variables":true}`)) expected := `{"body":{"variables":{"a":true}},"unnull_variables":true}` assert.Equal(t, expected, string(out)) 
}) t.Run("null variables", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":null},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":null},"unnull_variables":true}`)) expected := `{"body":{"variables":null},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) t.Run("ignore null inside non variables", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"foo":null},"body":"query {foo(bar: null){baz}}"},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"foo":null},"body":"query {foo(bar: null){baz}}"},"unnull_variables":true}`)) expected := `{"body":{"variables":{},"body":"query {foo(bar: null){baz}}"},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) t.Run("ignore null in variable name", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"not_null":1,"null":2,"not_null2":3}},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"variables":{"not_null":1,"null":2,"not_null2":3}},"unnull_variables":true}`)) expected := `{"body":{"variables":{"not_null":1,"null":2,"not_null2":3}},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) t.Run("variables missing", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"query":"{foo}"},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"query":"{foo}"},"unnull_variables":true}`)) expected := `{"body":{"query":"{foo}"},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) t.Run("variables null", func(t *testing.T) { s := &Source{} - out := s.compactAndUnNullVariables([]byte(`{"body":{"query":"{foo}","variables":null},"unnull_variables":true}`), []string{}) + out := s.compactAndUnNullVariables([]byte(`{"body":{"query":"{foo}","variables":null},"unnull_variables":true}`)) expected := `{"body":{"query":"{foo}","variables":null},"unnull_variables":true}` assert.Equal(t, expected, string(out)) }) diff --git a/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go b/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go index 9b810df52..4b57a9a0b 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go @@ -65,28 +65,7 @@ func (h *gqlSSEConnectionHandler) StartBlocking(sub Subscription) { } func (h *gqlSSEConnectionHandler) subscribe(ctx context.Context, sub Subscription, dataCh, errCh chan []byte) { - // if we used the downstream context, we got a panic if the downstream client disconnects immediately after the request was sent - // this happens, e.g. 
with React strict mode which renders the component twice - // to solve the issue, we use a separate context for the origin request - // with a goroutine that cancels the origin request if the downstream client disconnects - // in order to free resources after the initial handshake, we cancel the goroutine after we've received a response - originCtx, cancelOriginRequest := context.WithCancel(context.Background()) - defer cancelOriginRequest() - waitForResponse, cancelWaitForResponse := context.WithCancel(context.Background()) - go func() { - select { - case <-ctx.Done(): - // cancel the origin request if the downstream client disconnected - cancelOriginRequest() - case <-waitForResponse.Done(): - // end the goroutine to free resources - } - }() - resp, err := h.performSubscriptionRequest(originCtx) - // cancel the goroutine to free resources - // the originRequest will be canceled through defer cancelOriginRequest() - // as we check on every iteration (below) if the downstream ctx is done - cancelWaitForResponse() + resp, err := h.performSubscriptionRequest(ctx) if err != nil { h.log.Error("failed to perform subscription request", log.Error(err)) diff --git a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index efff20a18..d0cfdb09e 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -288,6 +288,8 @@ func TestWebsocketSubscriptionClientWithServerDisconnect(t *testing.T) { return true }, time.Second, time.Millisecond*10, "server did not close") assert.Eventuallyf(t, func() bool { + client.handlersMu.Lock() + defer client.handlersMu.Unlock() return len(client.handlers) == 0 }, time.Second, time.Millisecond, "client handlers not 0") } diff --git a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go index 5352e691c..997908556 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go @@ -80,6 +80,8 @@ func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) { }, time.Second, time.Millisecond*10, "server did not close") serverCancel() assert.Eventuallyf(t, func() bool { + client.handlersMu.Lock() + defer client.handlersMu.Unlock() return len(client.handlers) == 0 }, time.Second, time.Millisecond, "client handlers not 0") } @@ -151,6 +153,8 @@ func TestWebsocketSubscriptionClientPing_GQLTWS(t *testing.T) { }, time.Second, time.Millisecond*10, "server did not close") serverCancel() assert.Eventuallyf(t, func() bool { + client.handlersMu.Lock() + defer client.handlersMu.Unlock() return len(client.handlers) == 0 }, time.Second, time.Millisecond, "client handlers not 0") } @@ -297,6 +301,8 @@ func TestWebSocketSubscriptionClientInitIncludePing_GQLTWS(t *testing.T) { }, time.Second, time.Millisecond*10, "server did not close") serverCancel() assertion.Eventuallyf(func() bool { + client.handlersMu.Lock() + defer client.handlersMu.Unlock() return len(client.handlers) == 0 }, time.Second, time.Millisecond, "client handlers not 0") } @@ -367,6 +373,8 @@ func TestWebsocketSubscriptionClient_GQLTWS_Upstream_Dies(t *testing.T) { clientCancel() serverCancel() assert.Eventuallyf(t, func() bool { + client.handlersMu.Lock() + defer client.handlersMu.Unlock() return len(client.handlers) == 0 }, 
time.Second, time.Millisecond, "client handlers not 0") } diff --git a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go index 3eb3e6a96..0f8737143 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go @@ -87,6 +87,8 @@ func TestWebSocketSubscriptionClientInitIncludeKA_GQLWS(t *testing.T) { }, time.Second, time.Millisecond*10, "server did not close") serverCancel() assertion.Eventuallyf(func() bool { + client.handlersMu.Lock() + defer client.handlersMu.Unlock() return len(client.handlers) == 0 }, time.Second, time.Millisecond, "client handlers not 0") } @@ -152,6 +154,8 @@ func TestWebsocketSubscriptionClient_GQLWS(t *testing.T) { }, time.Second, time.Millisecond*10, "server did not close") serverCancel() assert.Eventuallyf(t, func() bool { + client.handlersMu.Lock() + defer client.handlersMu.Unlock() return len(client.handlers) == 0 }, time.Second, time.Millisecond, "client handlers not 0") } @@ -323,6 +327,8 @@ func TestWebsocketSubscriptionClient_GQLWS_Upstream_Dies(t *testing.T) { serverCancel() clientCancel() assert.Eventuallyf(t, func() bool { + client.handlersMu.Lock() + defer client.handlersMu.Unlock() return len(client.handlers) == 0 }, time.Second, time.Millisecond, "client handlers not 0") } diff --git a/pkg/engine/datasource/httpclient/httpclient.go b/pkg/engine/datasource/httpclient/httpclient.go index b1b12b252..5d5428d02 100644 --- a/pkg/engine/datasource/httpclient/httpclient.go +++ b/pkg/engine/datasource/httpclient/httpclient.go @@ -2,8 +2,8 @@ package httpclient import ( "bytes" - "context" "encoding/json" + "fmt" "io" "github.com/buger/jsonparser" @@ -15,24 +15,21 @@ import ( "github.com/wundergraph/graphql-go-tools/pkg/lexer/literal" ) -type ctxKey string - const ( - PATH = "path" - URL = "url" - URLENCODEBODY = "url_encode_body" - BASEURL = "base_url" - METHOD = "method" - BODY = "body" - HEADER = "header" - QUERYPARAMS = "query_params" - USESSE = "use_sse" - SSEMETHODPOST = "sse_method_post" - SCHEME = "scheme" - HOST = "host" - UNNULLVARIABLES = "unnull_variables" - - removeUndefinedVariables ctxKey = "remove_undefined_variables" + PATH = "path" + URL = "url" + URLENCODEBODY = "url_encode_body" + BASEURL = "base_url" + METHOD = "method" + BODY = "body" + HEADER = "header" + QUERYPARAMS = "query_params" + USESSE = "use_sse" + SSEMETHODPOST = "sse_method_post" + SCHEME = "scheme" + HOST = "host" + UNNULLVARIABLES = "unnull_variables" + UNDEFINED_VARIABLES = "undefined" ) var ( @@ -50,18 +47,6 @@ var ( } ) -func CtxSetUndefinedVariables(ctx context.Context, undefinedVariables []string) context.Context { - return context.WithValue(ctx, removeUndefinedVariables, undefinedVariables) -} - -func CtxGetUndefinedVariables(ctx context.Context) []string { - undefinedVariables := ctx.Value(removeUndefinedVariables) - if undefinedVariables, ok := undefinedVariables.([]string); ok { - return undefinedVariables - } - return nil -} - func wrapQuotesIfString(b []byte) []byte { if bytes.HasPrefix(b, []byte("$$")) && bytes.HasSuffix(b, []byte("$$")) { @@ -236,3 +221,31 @@ func GetSubscriptionInput(input []byte) (url, header, body []byte) { }, subscriptionInputPaths...) 
return } + +func setUndefinedVariables(data []byte, undefinedVariables []string) ([]byte, error) { + if len(undefinedVariables) > 0 { + encoded, err := json.Marshal(undefinedVariables) + if err != nil { + return nil, err + } + return sjson.SetRawBytes(data, UNDEFINED_VARIABLES, encoded) + } + return data, nil +} + +func SetUndefinedVariables(data []byte, undefinedVariables []string) []byte { + result, err := setUndefinedVariables(data, undefinedVariables) + if err != nil { + panic(fmt.Errorf("couldn't set undefined variables: %w", err)) + } + return result +} + +func UndefinedVariables(data []byte) []string { + var undefinedVariables []string + gjson.GetBytes(data, UNDEFINED_VARIABLES).ForEach(func(key, value gjson.Result) bool { + undefinedVariables = append(undefinedVariables, value.Str) + return true + }) + return undefinedVariables +} diff --git a/pkg/engine/datasource/kafka_datasource/config.go b/pkg/engine/datasource/kafka_datasource/config.go deleted file mode 100644 index 180a55318..000000000 --- a/pkg/engine/datasource/kafka_datasource/config.go +++ /dev/null @@ -1,141 +0,0 @@ -package kafka_datasource - -import ( - "fmt" - - "github.com/Shopify/sarama" -) - -const ( - IsolationLevelReadUncommitted = "ReadUncommitted" - IsolationLevelReadCommitted = "ReadCommitted" -) - -const DefaultIsolationLevel = IsolationLevelReadUncommitted - -const ( - BalanceStrategyRange = "BalanceStrategyRange" - BalanceStrategySticky = "BalanceStrategySticky" - BalanceStrategyRoundRobin = "BalanceStrategyRoundRobin" -) - -const DefaultBalanceStrategy = BalanceStrategyRange - -var ( - DefaultKafkaVersion = "V1_0_0_0" - SaramaSupportedKafkaVersions = map[string]sarama.KafkaVersion{ - "V0_10_2_0": sarama.V0_10_2_0, - "V0_10_2_1": sarama.V0_10_2_1, - "V0_11_0_0": sarama.V0_11_0_0, - "V0_11_0_1": sarama.V0_11_0_1, - "V0_11_0_2": sarama.V0_11_0_2, - "V1_0_0_0": sarama.V1_0_0_0, - "V1_1_0_0": sarama.V1_1_0_0, - "V1_1_1_0": sarama.V1_1_1_0, - "V2_0_0_0": sarama.V2_0_0_0, - "V2_0_1_0": sarama.V2_0_1_0, - "V2_1_0_0": sarama.V2_1_0_0, - "V2_2_0_0": sarama.V2_2_0_0, - "V2_3_0_0": sarama.V2_3_0_0, - "V2_4_0_0": sarama.V2_4_0_0, - "V2_5_0_0": sarama.V2_5_0_0, - "V2_6_0_0": sarama.V2_6_0_0, - "V2_7_0_0": sarama.V2_7_0_0, - "V2_8_0_0": sarama.V2_8_0_0, - } -) - -type SASL struct { - // Whether or not to use SASL authentication when connecting to the broker - // (defaults to false). 
- Enable bool `json:"enable"` - // User is the authentication identity (authcid) to present for - // SASL/PLAIN or SASL/SCRAM authentication - User string `json:"user"` - // Password for SASL/PLAIN authentication - Password string `json:"password"` -} - -type GraphQLSubscriptionOptions struct { - BrokerAddresses []string `json:"broker_addresses"` - Topics []string `json:"topics"` - GroupID string `json:"group_id"` - ClientID string `json:"client_id"` - KafkaVersion string `json:"kafka_version"` - StartConsumingLatest bool `json:"start_consuming_latest"` - BalanceStrategy string `json:"balance_strategy"` - IsolationLevel string `json:"isolation_level"` - SASL SASL `json:"sasl"` - startedCallback func() -} - -func (g *GraphQLSubscriptionOptions) Sanitize() { - if g.KafkaVersion == "" { - g.KafkaVersion = DefaultKafkaVersion - } - - // Strategy for allocating topic partitions to members (default BalanceStrategyRange) - if g.BalanceStrategy == "" { - g.BalanceStrategy = DefaultBalanceStrategy - } - - if g.IsolationLevel == "" { - g.IsolationLevel = DefaultIsolationLevel - } -} - -func (g *GraphQLSubscriptionOptions) Validate() error { - switch { - case len(g.BrokerAddresses) == 0: - return fmt.Errorf("broker_addresses cannot be empty") - case len(g.Topics) == 0: - return fmt.Errorf("topics cannot be empty") - case g.GroupID == "": - return fmt.Errorf("group_id cannot be empty") - case g.ClientID == "": - return fmt.Errorf("client_id cannot be empty") - } - - if _, ok := SaramaSupportedKafkaVersions[g.KafkaVersion]; !ok { - return fmt.Errorf("kafka_version is invalid: %s", g.KafkaVersion) - } - - switch g.BalanceStrategy { - case BalanceStrategyRange, BalanceStrategySticky, BalanceStrategyRoundRobin: - default: - return fmt.Errorf("balance_strategy is invalid: %s", g.BalanceStrategy) - } - - switch g.IsolationLevel { - case IsolationLevelReadUncommitted, IsolationLevelReadCommitted: - default: - return fmt.Errorf("isolation_level is invalid: %s", g.IsolationLevel) - } - - if g.SASL.Enable { - switch { - case g.SASL.User == "": - return fmt.Errorf("sasl.user cannot be empty") - case g.SASL.Password == "": - return fmt.Errorf("sasl.password cannot be empty") - } - } - - return nil -} - -type SubscriptionConfiguration struct { - BrokerAddresses []string `json:"broker_addresses"` - Topics []string `json:"topics"` - GroupID string `json:"group_id"` - ClientID string `json:"client_id"` - KafkaVersion string `json:"kafka_version"` - StartConsumingLatest bool `json:"start_consuming_latest"` - BalanceStrategy string `json:"balance_strategy"` - IsolationLevel string `json:"isolation_level"` - SASL SASL `json:"sasl"` -} - -type Configuration struct { - Subscription SubscriptionConfiguration -} diff --git a/pkg/engine/datasource/kafka_datasource/config_test.go b/pkg/engine/datasource/kafka_datasource/config_test.go deleted file mode 100644 index 49178d46e..000000000 --- a/pkg/engine/datasource/kafka_datasource/config_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package kafka_datasource - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestConfig_GraphQLSubscriptionOptions(t *testing.T) { - t.Run("Set default isolation_level", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{} - g.Sanitize() - require.Equal(t, DefaultIsolationLevel, g.IsolationLevel) - }) - - t.Run("Set default balance_strategy", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{} - g.Sanitize() - require.Equal(t, DefaultBalanceStrategy, g.BalanceStrategy) - }) - - t.Run("Set default Kafka version", 
func(t *testing.T) { - g := &GraphQLSubscriptionOptions{} - g.Sanitize() - require.Equal(t, DefaultKafkaVersion, g.KafkaVersion) - }) - - t.Run("Empty broker_addresses not allowed", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{ - Topics: []string{"foobar"}, - GroupID: "groupid", - ClientID: "clientid", - } - g.Sanitize() - err := g.Validate() - require.Equal(t, err.Error(), "broker_addresses cannot be empty") - }) - - t.Run("Empty topic not allowed", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{ - BrokerAddresses: []string{"localhost:9092"}, - GroupID: "groupid", - ClientID: "clientid", - } - g.Sanitize() - err := g.Validate() - require.Equal(t, err.Error(), "topics cannot be empty") - }) - - t.Run("Empty client_id not allowed", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{ - BrokerAddresses: []string{"localhost:9092"}, - Topics: []string{"foobar"}, - GroupID: "groupid", - } - g.Sanitize() - err := g.Validate() - require.Equal(t, err.Error(), "client_id cannot be empty") - }) - - t.Run("Invalid Kafka version", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{ - BrokerAddresses: []string{"localhost:9092"}, - Topics: []string{"foobar"}, - GroupID: "groupid", - ClientID: "clientid", - KafkaVersion: "1.3.5", - } - g.Sanitize() - err := g.Validate() - require.Equal(t, err.Error(), "kafka_version is invalid: 1.3.5") - }) - - t.Run("Invalid SASL configuration - SASL nil", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{ - BrokerAddresses: []string{"localhost:9092"}, - Topics: []string{"foobar"}, - GroupID: "groupid", - ClientID: "clientid", - SASL: SASL{}, - } - g.Sanitize() - err := g.Validate() - require.NoError(t, err) - }) - - t.Run("Invalid SASL configuration - auth disabled", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{ - BrokerAddresses: []string{"localhost:9092"}, - Topics: []string{"foobar"}, - GroupID: "groupid", - ClientID: "clientid", - SASL: SASL{ - Enable: false, - }, - } - g.Sanitize() - err := g.Validate() - require.NoError(t, err) - }) - - t.Run("Invalid SASL configuration - user cannot be empty", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{ - BrokerAddresses: []string{"localhost:9092"}, - Topics: []string{"foobar"}, - GroupID: "groupid", - ClientID: "clientid", - SASL: SASL{ - Enable: true, - }, - } - g.Sanitize() - err := g.Validate() - require.Equal(t, err.Error(), "sasl.user cannot be empty") - }) - - t.Run("Invalid SASL configuration - password cannot be empty", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{ - BrokerAddresses: []string{"localhost:9092"}, - Topics: []string{"foobar"}, - GroupID: "groupid", - ClientID: "clientid", - SASL: SASL{ - Enable: true, - User: "foobar", - }, - } - g.Sanitize() - err := g.Validate() - require.Equal(t, err.Error(), "sasl.password cannot be empty") - }) - - t.Run("Valid SASL configuration", func(t *testing.T) { - g := &GraphQLSubscriptionOptions{ - BrokerAddresses: []string{"localhost:9092"}, - Topics: []string{"foobar"}, - GroupID: "groupid", - ClientID: "clientid", - SASL: SASL{ - Enable: true, - User: "foobar", - Password: "password", - }, - } - g.Sanitize() - err := g.Validate() - require.NoError(t, err) - }) -} diff --git a/pkg/engine/datasource/kafka_datasource/kafka_consumer_group.go b/pkg/engine/datasource/kafka_datasource/kafka_consumer_group.go deleted file mode 100644 index 30ee6c0df..000000000 --- a/pkg/engine/datasource/kafka_datasource/kafka_consumer_group.go +++ /dev/null @@ -1,323 +0,0 @@ -package kafka_datasource - -import ( - "context" - 
"sync" - "time" - - "github.com/Shopify/sarama" - "github.com/buger/jsonparser" - log "github.com/jensneuse/abstractlogger" -) - -const consumerGroupRetryInterval = time.Second - -type KafkaConsumerGroupBridge struct { - log log.Logger - ctx context.Context -} - -type KafkaConsumerGroup struct { - consumerGroup sarama.ConsumerGroup - options *GraphQLSubscriptionOptions - log log.Logger - startedCallback func() - wg sync.WaitGroup - ctx context.Context - cancel context.CancelFunc -} - -type kafkaConsumerGroupHandler struct { - log log.Logger - startedCallback func() - options *GraphQLSubscriptionOptions - messages chan *sarama.ConsumerMessage - ctx context.Context -} - -// Setup is run at the beginning of a new session, before ConsumeClaim. -func (k *kafkaConsumerGroupHandler) Setup(_ sarama.ConsumerGroupSession) error { - k.log.Debug("kafkaConsumerGroupHandler.Setup", - log.Strings("topics", k.options.Topics), - log.String("groupID", k.options.GroupID), - log.String("clientID", k.options.ClientID), - ) - return nil -} - -// Cleanup is run at the end of a session, once all ConsumeClaim goroutines have exited -// but before the offsets are committed for the very last time. -func (k *kafkaConsumerGroupHandler) Cleanup(_ sarama.ConsumerGroupSession) error { - k.log.Debug("kafkaConsumerGroupHandler.Cleanup", - log.Strings("topics", k.options.Topics), - log.String("groupID", k.options.GroupID), - log.String("clientID", k.options.ClientID), - ) - return nil -} - -// ConsumeClaim must start a consumer loop of ConsumerGroupClaim's Messages(). -// Once the Messages() channel is closed, the Handler must finish its processing -// loop and exit. -func (k *kafkaConsumerGroupHandler) ConsumeClaim(session sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error { - if k.options.StartConsumingLatest { - // Reset the offset before start consuming and don't commit the consumed messages. - // In this way, it will only read the latest messages. - session.ResetOffset(claim.Topic(), claim.Partition(), sarama.OffsetNewest, "") - } - - if k.startedCallback != nil { - k.startedCallback() - } - - for msg := range claim.Messages() { - ctx, cancel := context.WithTimeout(k.ctx, time.Second*5) - select { - case k.messages <- msg: - cancel() - // If the client wants to most recent messages, don't commit the - // offset and reset the offset to sarama.OffsetNewest, then start consuming. - if !k.options.StartConsumingLatest { - session.MarkMessage(msg, "") // Commit the message and advance the offset. - } - case <-ctx.Done(): - cancel() - return nil - } - } - k.log.Debug("kafkaConsumerGroupHandler.ConsumeClaim is gone", - log.Strings("topics", k.options.Topics), - log.String("groupID", k.options.GroupID), - log.String("clientID", k.options.ClientID)) - return nil -} - -// NewKafkaConsumerGroup creates a new sarama.ConsumerGroup and returns a new -// *KafkaConsumerGroup instance. 
-func NewKafkaConsumerGroup(log log.Logger, saramaConfig *sarama.Config, options *GraphQLSubscriptionOptions) (*KafkaConsumerGroup, error) { - cg, err := sarama.NewConsumerGroup(options.BrokerAddresses, options.GroupID, saramaConfig) - if err != nil { - return nil, err - } - - ctx, cancel := context.WithCancel(context.Background()) - return &KafkaConsumerGroup{ - consumerGroup: cg, - startedCallback: options.startedCallback, - log: log, - options: options, - ctx: ctx, - cancel: cancel, - }, nil -} - -func (k *KafkaConsumerGroup) startConsuming(handler sarama.ConsumerGroupHandler) { - defer k.wg.Done() - - defer func() { - if err := k.consumerGroup.Close(); err != nil { - k.log.Error("KafkaConsumerGroup.Close returned an error", - log.Strings("topics", k.options.Topics), - log.String("groupID", k.options.GroupID), - log.String("clientID", k.options.ClientID), - log.Error(err)) - } - }() - - k.wg.Add(1) - go func() { - defer k.wg.Done() - - // Errors returns a read channel of errors that occurred during the consumer life-cycle. - // By default, errors are logged and not returned over this channel. - // If you want to implement any custom error handling, set your config's - // Consumer.Return.Errors setting to true, and read from this channel. - for err := range k.consumerGroup.Errors() { - k.log.Error("KafkaConsumerGroup.Consumer", - log.Strings("topics", k.options.Topics), - log.String("groupID", k.options.GroupID), - log.String("clientID", k.options.ClientID), - log.Error(err)) - } - }() - - // From Sarama documents: - // - // This method should be called inside an infinite loop, when a - // server-side rebalance happens, the consumer session will need to be - // recreated to get the new claims. - for { - select { - case <-k.ctx.Done(): - return - default: - } - - k.log.Info("KafkaConsumerGroup.consumerGroup.Consume has been called", - log.Strings("topics", k.options.Topics), - log.String("groupID", k.options.GroupID), - log.String("clientID", k.options.ClientID)) - - // Blocking call - err := k.consumerGroup.Consume(k.ctx, k.options.Topics, handler) - if err != nil { - k.log.Error("KafkaConsumerGroup.startConsuming", - log.Strings("topics", k.options.Topics), - log.String("groupID", k.options.GroupID), - log.String("clientID", k.options.ClientID), - log.Error(err)) - } - // Rebalance or node restart takes time. Every Consume call - // triggers a context switch on the CPU. We should prevent an - // interrupt storm. - <-time.After(consumerGroupRetryInterval) - } -} - -// StartConsuming initializes a new consumer group handler and starts consuming at -// background. -func (k *KafkaConsumerGroup) StartConsuming(messages chan *sarama.ConsumerMessage) { - handler := &kafkaConsumerGroupHandler{ - log: k.log, - startedCallback: k.options.startedCallback, - options: k.options, - messages: messages, - ctx: k.ctx, - } - - k.wg.Add(1) - go k.startConsuming(handler) -} - -// Close stops background goroutines and closes the underlying ConsumerGroup instance. -func (k *KafkaConsumerGroup) Close() error { - select { - case <-k.ctx.Done(): - // Already closed - return nil - default: - } - - k.cancel() - return k.consumerGroup.Close() -} - -// WaitUntilConsumerStop waits until ConsumerGroup.Consume function stops. 
-func (k *KafkaConsumerGroup) WaitUntilConsumerStop() { - k.wg.Wait() -} - -func NewKafkaConsumerGroupBridge(ctx context.Context, logger log.Logger) *KafkaConsumerGroupBridge { - if logger == nil { - logger = log.NoopLogger - } - return &KafkaConsumerGroupBridge{ - ctx: ctx, - log: logger, - } -} - -func (c *KafkaConsumerGroupBridge) prepareSaramaConfig(options *GraphQLSubscriptionOptions) (*sarama.Config, error) { - sc := sarama.NewConfig() - sc.Version = SaramaSupportedKafkaVersions[options.KafkaVersion] - sc.ClientID = options.ClientID - sc.Consumer.Return.Errors = true - - // Strategy for allocating topic partitions to members (default BalanceStrategyRange) - // See this: https://chrzaszcz.dev/2021/09/kafka-assignors/ - // Sanitize function doesn't allow an empty BalanceStrategy parameter. - switch options.BalanceStrategy { - case BalanceStrategyRange: - sc.Consumer.Group.Rebalance.Strategy = sarama.BalanceStrategyRange - case BalanceStrategySticky: - sc.Consumer.Group.Rebalance.Strategy = sarama.BalanceStrategySticky - case BalanceStrategyRoundRobin: - sc.Consumer.Group.Rebalance.Strategy = sarama.BalanceStrategyRoundRobin - } - - if options.StartConsumingLatest { - // Start consuming from the latest offset after a client restart - sc.Consumer.Offsets.Initial = sarama.OffsetNewest - } - - // IsolationLevel support 2 mode: - // - use `ReadUncommitted` (default) to consume and return all messages in message channel - // - use `ReadCommitted` to hide messages that are part of an aborted transaction - switch options.IsolationLevel { - case IsolationLevelReadCommitted: - sc.Consumer.IsolationLevel = sarama.ReadCommitted - case IsolationLevelReadUncommitted: - sc.Consumer.IsolationLevel = sarama.ReadUncommitted - } - - // SASL based authentication with broker. While there are multiple SASL authentication methods - // the current implementation is limited to plaintext (SASL/PLAIN) authentication - if options.SASL.Enable { - sc.Net.SASL.Enable = true - sc.Net.SASL.User = options.SASL.User - sc.Net.SASL.Password = options.SASL.Password - } - - return sc, nil -} - -// Subscribe creates a new consumer group with given config and streams messages via next channel. -func (c *KafkaConsumerGroupBridge) Subscribe(ctx context.Context, options GraphQLSubscriptionOptions, next chan<- []byte) error { - options.Sanitize() - if err := options.Validate(); err != nil { - return err - } - - saramaConfig, err := c.prepareSaramaConfig(&options) - if err != nil { - return err - } - - cg, err := NewKafkaConsumerGroup(c.log, saramaConfig, &options) - if err != nil { - return err - } - - messages := make(chan *sarama.ConsumerMessage) - cg.StartConsuming(messages) - - // Wait for messages. - go func() { - defer func() { - if err := cg.Close(); err != nil { - c.log.Error("KafkaConsumerGroup.Close returned an error", - log.Strings("topics", options.Topics), - log.String("groupID", options.GroupID), - log.String("clientID", options.ClientID), - log.Error(err), - ) - } - close(next) - }() - - for { - select { - case <-c.ctx.Done(): - // Gateway context - return - case <-ctx.Done(): - // Request context - return - case msg, ok := <-messages: - if !ok { - return - } - // The "data" field contains the result of your GraphQL request. 
- result, err := jsonparser.Set([]byte(`{}`), msg.Value, "data") - if err != nil { - return - } - next <- result - } - } - }() - - return nil -} - -var _ sarama.ConsumerGroupHandler = (*kafkaConsumerGroupHandler)(nil) diff --git a/pkg/engine/datasource/kafka_datasource/kafka_consumer_group_test.go b/pkg/engine/datasource/kafka_datasource/kafka_consumer_group_test.go deleted file mode 100644 index 22b0d618e..000000000 --- a/pkg/engine/datasource/kafka_datasource/kafka_consumer_group_test.go +++ /dev/null @@ -1,391 +0,0 @@ -//go:build !windows - -package kafka_datasource - -import ( - "context" - "sync" - "testing" - "time" - - "github.com/Shopify/sarama" - "github.com/Shopify/sarama/mocks" - log "github.com/jensneuse/abstractlogger" - "github.com/stretchr/testify/require" - "go.uber.org/zap" -) - -const defaultPartition = 0 - -// newMockKafkaBroker creates a MockBroker to test ConsumerGroups. -func newMockKafkaBroker(t *testing.T, topic, group string, fr *sarama.FetchResponse) *sarama.MockBroker { - mockBroker := sarama.NewMockBroker(t, 0) - - mockMetadataResponse := sarama.NewMockMetadataResponse(t). - SetBroker(mockBroker.Addr(), mockBroker.BrokerID()). - SetLeader(topic, defaultPartition, mockBroker.BrokerID()). - SetController(mockBroker.BrokerID()) - - mockProducerResponse := sarama.NewMockProduceResponse(t). - SetError(topic, 0, sarama.ErrNoError). - SetVersion(2) - - mockOffsetResponse := sarama.NewMockOffsetResponse(t). - SetOffset(topic, defaultPartition, sarama.OffsetOldest, 0). - SetOffset(topic, defaultPartition, sarama.OffsetNewest, 1). - SetVersion(1) - - mockCoordinatorResponse := sarama.NewMockFindCoordinatorResponse(t). - SetCoordinator(sarama.CoordinatorType(0), group, mockBroker) - - mockJoinGroupResponse := sarama.NewMockJoinGroupResponse(t) - - mockSyncGroupResponse := sarama.NewMockSyncGroupResponse(t). - SetMemberAssignment(&sarama.ConsumerGroupMemberAssignment{ - Version: 0, - Topics: map[string][]int32{topic: {0}}, - UserData: nil, - }) - - mockHeartbeatResponse := sarama.NewMockHeartbeatResponse(t) - - mockOffsetFetchResponse := sarama.NewMockOffsetFetchResponse(t). - SetOffset(group, topic, defaultPartition, 0, "", sarama.KError(0)) - - // Need to mock ApiVersionsRequest when we upgrade Sarama - - //mockApiVersionsResponse := sarama.NewMockApiVersionsResponse(t) - mockOffsetCommitResponse := sarama.NewMockOffsetCommitResponse(t) - mockBroker.SetHandlerByMap(map[string]sarama.MockResponse{ - "MetadataRequest": mockMetadataResponse, - "ProduceRequest": mockProducerResponse, - "OffsetRequest": mockOffsetResponse, - "OffsetFetchRequest": mockOffsetFetchResponse, - "FetchRequest": sarama.NewMockSequence(fr), - "FindCoordinatorRequest": mockCoordinatorResponse, - "JoinGroupRequest": mockJoinGroupResponse, - "SyncGroupRequest": mockSyncGroupResponse, - "HeartbeatRequest": mockHeartbeatResponse, - //"ApiVersionsRequest": mockApiVersionsResponse, - "OffsetCommitRequest": mockOffsetCommitResponse, - }) - - return mockBroker -} - -// testConsumerGroupHandler implements sarama.ConsumerGroupHandler interface for testing purposes. 
-type testConsumerGroupHandler struct { - processMessage func(msg *sarama.ConsumerMessage) - ctx context.Context - cancel context.CancelFunc -} - -func newDefaultConsumerGroupHandler(processMessage func(msg *sarama.ConsumerMessage)) *testConsumerGroupHandler { - ctx, cancel := context.WithCancel(context.Background()) - return &testConsumerGroupHandler{ - processMessage: processMessage, - ctx: ctx, - cancel: cancel, - } -} - -func (d *testConsumerGroupHandler) Setup(_ sarama.ConsumerGroupSession) error { - d.cancel() // ready for consuming - return nil -} - -func (d *testConsumerGroupHandler) Cleanup(_ sarama.ConsumerGroupSession) error { return nil } -func (d *testConsumerGroupHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error { - for msg := range claim.Messages() { - d.processMessage(msg) - sess.MarkMessage(msg, "") // Commit the message and advance the offset. - } - return nil -} - -func newTestConsumerGroup(groupID string, brokers []string) (sarama.ConsumerGroup, error) { - kConfig := mocks.NewTestConfig() - kConfig.Version = sarama.MaxVersion - kConfig.Consumer.Return.Errors = true - kConfig.ClientID = "graphql-go-tools-test" - kConfig.Consumer.Offsets.Initial = sarama.OffsetNewest - - // Start with a client - client, err := sarama.NewClient(brokers, kConfig) - if err != nil { - return nil, err - } - - // Create a new consumer group - return sarama.NewConsumerGroupFromClient(groupID, client) -} - -func TestKafkaMockBroker(t *testing.T) { - var ( - testMessageKey = sarama.StringEncoder("test.message.key") - testMessageValue = sarama.StringEncoder("test.message.value") - topic = "test.topic" - consumerGroup = "consumer.group" - ) - - fr := &sarama.FetchResponse{Version: 11} - mockBroker := newMockKafkaBroker(t, topic, consumerGroup, fr) - defer mockBroker.Close() - - brokerAddr := []string{mockBroker.Addr()} - - cg, err := newTestConsumerGroup(consumerGroup, brokerAddr) - require.NoError(t, err) - - defer func() { - require.NoError(t, cg.Close()) - }() - - called := 0 - - // Stop after 15 seconds and return an error. - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) - processMessage := func(msg *sarama.ConsumerMessage) { - defer cancel() - - t.Logf("Processed message topic: %s, key: %s, value: %s, ", msg.Topic, msg.Key, msg.Value) - key, _ := testMessageKey.Encode() - value, _ := testMessageValue.Encode() - require.Equal(t, key, msg.Key) - require.Equal(t, value, msg.Value) - require.Equal(t, topic, msg.Topic) - called++ - } - - handler := newDefaultConsumerGroupHandler(processMessage) - - errCh := make(chan error, 1) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - - // Start consuming. Consume is a blocker call and it runs handler.ConsumeClaim at background. - errCh <- cg.Consume(ctx, []string{topic}, handler) - }() - - // Ready for consuming - <-handler.ctx.Done() - - // Add a message to the topic. KafkaConsumerGroupBridge group will fetch that message and trigger ConsumeClaim method. - fr.AddMessage(topic, defaultPartition, testMessageKey, testMessageValue, 0) - - // When this context is canceled, the processMessage function has been called and run without any problem. - <-ctx.Done() - - wg.Wait() - - // KafkaConsumerGroupBridge is stopped here. - require.NoError(t, <-errCh) - require.Equal(t, 1, called) - require.ErrorIs(t, ctx.Err(), context.Canceled) -} - -// It's just a simple example of graphql federation gateway server, it's NOT a production ready code. 
-func logger() log.Logger { - logger, err := zap.NewDevelopmentConfig().Build() - if err != nil { - panic(err) - } - - return log.NewZapLogger(logger, log.DebugLevel) -} - -func TestKafkaConsumerGroup_StartConsuming_And_Stop(t *testing.T) { - var ( - testMessageKey = sarama.StringEncoder("test.message.key") - testMessageValue = sarama.StringEncoder("test.message.value") - topic = "test.topic" - consumerGroup = "consumer.group" - ) - - fr := &sarama.FetchResponse{Version: 11} - mockBroker := newMockKafkaBroker(t, topic, consumerGroup, fr) - defer mockBroker.Close() - - // Add a message to the topic. The consumer group will fetch that message and trigger ConsumeClaim method. - fr.AddMessage(topic, defaultPartition, testMessageKey, testMessageValue, 0) - - options := GraphQLSubscriptionOptions{ - BrokerAddresses: []string{mockBroker.Addr()}, - Topics: []string{topic}, - GroupID: consumerGroup, - ClientID: "graphql-go-tools-test", - KafkaVersion: testMockKafkaVersion, - } - options.Sanitize() - require.NoError(t, options.Validate()) - - saramaConfig := sarama.NewConfig() - saramaConfig.Version = SaramaSupportedKafkaVersions[options.KafkaVersion] - - cg, err := NewKafkaConsumerGroup(logger(), saramaConfig, &options) - require.NoError(t, err) - - messages := make(chan *sarama.ConsumerMessage) - cg.StartConsuming(messages) - - msg := <-messages - expectedKey, _ := testMessageKey.Encode() - require.Equal(t, expectedKey, msg.Key) - - expectedValue, _ := testMessageValue.Encode() - require.Equal(t, expectedValue, msg.Value) - - require.NoError(t, cg.Close()) - - done := make(chan struct{}) - go func() { - defer func() { - close(done) - }() - - cg.WaitUntilConsumerStop() - }() - - select { - case <-time.After(15 * time.Second): - require.Fail(t, "KafkaConsumerGroup could not closed in 15 seconds") - case <-done: - } -} - -func TestKafkaConsumerGroup_Config_StartConsumingLatest(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - consumedMsgCh := make(chan *sarama.ConsumerMessage) - var mockTopicName = "test.mock.topic" - - // Create a new Kafka consumer handler here. We'll test ConsumeClaim method. - // If the StartConsumingLatest config option is true, it resets the offset, - // and we'll observe this behavior. - kg := &kafkaConsumerGroupHandler{ - ctx: ctx, - messages: consumedMsgCh, - log: logger(), - options: &GraphQLSubscriptionOptions{ - StartConsumingLatest: true, - Topics: []string{mockTopicName}, - GroupID: "test.consumer.group", - ClientID: "test.client.id", - }, - } - session := &mockConsumerGroupSession{ - resetOffsetParams: make(map[string]interface{}), - } - - claim := &mockConsumerGroupClaim{ - topicName: mockTopicName, - messages: make(chan *sarama.ConsumerMessage, 1), - } - // Produce a test message. - claim.messages <- &sarama.ConsumerMessage{ - Topic: mockTopicName, - Partition: defaultPartition, - Key: []byte("key"), - Value: []byte("value"), - } - - errCh := make(chan error) - go func() { - errCh <- kg.ConsumeClaim(session, claim) - }() - - select { - case <-consumedMsgCh: - // Test message has been consumed - case <-time.After(15 * time.Second): - require.Fail(t, "the message could not be consumed") - } - - // This will stop ConsumeClaim method, and it will return with an error or nil. 
- close(claim.messages) - require.NoError(t, <-errCh) - - // If the StartConsumingLatest switch works without any problem, we observe the following changes: - - // sarama.ConsumerGroupSession - require.Equal(t, mockTopicName, session.resetOffsetParams["topic"]) - require.Equal(t, int32(defaultPartition), session.resetOffsetParams["partition"]) - require.Equal(t, sarama.OffsetNewest, session.resetOffsetParams["offset"]) - require.Equal(t, "", session.resetOffsetParams["metadata"]) - - // sarama.ConsumerGroupClaim - require.False(t, session.markMessageCalled) -} - -type mockConsumerGroupSession struct { - markMessageCalled bool - resetOffsetParams map[string]interface{} -} - -func (m *mockConsumerGroupSession) Claims() map[string][]int32 { - panic("implement me") -} - -func (m *mockConsumerGroupSession) MemberID() string { - panic("implement me") -} - -func (m *mockConsumerGroupSession) GenerationID() int32 { - panic("implement me") -} - -func (m *mockConsumerGroupSession) MarkOffset(topic string, partition int32, offset int64, metadata string) { - panic("implement me") -} - -func (m *mockConsumerGroupSession) Commit() { - panic("implement me") -} - -func (m *mockConsumerGroupSession) ResetOffset(topic string, partition int32, offset int64, metadata string) { - m.resetOffsetParams["topic"] = topic - m.resetOffsetParams["partition"] = partition - m.resetOffsetParams["offset"] = offset - m.resetOffsetParams["metadata"] = metadata -} - -func (m *mockConsumerGroupSession) MarkMessage(msg *sarama.ConsumerMessage, metadata string) { - m.markMessageCalled = true -} - -func (m *mockConsumerGroupSession) Context() context.Context { - panic("implement me") -} - -var _ sarama.ConsumerGroupSession = (*mockConsumerGroupSession)(nil) - -type mockConsumerGroupClaim struct { - topicName string - messages chan *sarama.ConsumerMessage -} - -func (m *mockConsumerGroupClaim) Topic() string { - return m.topicName -} - -func (m *mockConsumerGroupClaim) Partition() int32 { - return defaultPartition -} - -func (m *mockConsumerGroupClaim) InitialOffset() int64 { - return 0 -} - -func (m *mockConsumerGroupClaim) HighWaterMarkOffset() int64 { - return 0 -} - -func (m *mockConsumerGroupClaim) Messages() <-chan *sarama.ConsumerMessage { - return m.messages -} - -var _ sarama.ConsumerGroupClaim = (*mockConsumerGroupClaim)(nil) diff --git a/pkg/engine/datasource/kafka_datasource/kafka_datasource.go b/pkg/engine/datasource/kafka_datasource/kafka_datasource.go deleted file mode 100644 index 9726eb903..000000000 --- a/pkg/engine/datasource/kafka_datasource/kafka_datasource.go +++ /dev/null @@ -1,75 +0,0 @@ -package kafka_datasource - -import ( - "context" - "encoding/json" - - "github.com/jensneuse/abstractlogger" - - "github.com/wundergraph/graphql-go-tools/pkg/engine/plan" -) - -type Planner struct { - ctx context.Context - config Configuration -} - -func (p *Planner) Register(_ *plan.Visitor, configuration plan.DataSourceConfiguration, _ bool) error { - return json.Unmarshal(configuration.Custom, &p.config) -} - -func (p *Planner) ConfigureFetch() plan.FetchConfiguration { - return plan.FetchConfiguration{} -} - -func (p *Planner) ConfigureSubscription() plan.SubscriptionConfiguration { - input, _ := json.Marshal(p.config.Subscription) - return plan.SubscriptionConfiguration{ - Input: string(input), - DataSource: &SubscriptionSource{ - client: NewKafkaConsumerGroupBridge(p.ctx, abstractlogger.NoopLogger), - }, - } -} - -func (p *Planner) DataSourcePlanningBehavior() plan.DataSourcePlanningBehavior { - return 
plan.DataSourcePlanningBehavior{ - MergeAliasedRootNodes: false, - OverrideFieldPathFromAlias: false, - } -} - -func (p *Planner) DownstreamResponseFieldAlias(_ int) (alias string, exists bool) { return } - -type Factory struct{} - -func (f *Factory) Planner(ctx context.Context) plan.DataSourcePlanner { - return &Planner{ - ctx: ctx, - } -} - -func ConfigJSON(config Configuration) json.RawMessage { - out, _ := json.Marshal(config) - return out -} - -type GraphQLSubscriptionClient interface { - Subscribe(ctx context.Context, options GraphQLSubscriptionOptions, next chan<- []byte) error -} - -type SubscriptionSource struct { - client GraphQLSubscriptionClient -} - -func (s *SubscriptionSource) Start(ctx context.Context, input []byte, next chan<- []byte) error { - var options GraphQLSubscriptionOptions - err := json.Unmarshal(input, &options) - if err != nil { - return err - } - return s.client.Subscribe(ctx, options, next) -} - -var _ plan.PlannerFactory = (*Factory)(nil) -var _ plan.DataSourcePlanner = (*Planner)(nil) diff --git a/pkg/engine/datasource/kafka_datasource/kafka_datasource_test.go b/pkg/engine/datasource/kafka_datasource/kafka_datasource_test.go deleted file mode 100644 index 1b478d806..000000000 --- a/pkg/engine/datasource/kafka_datasource/kafka_datasource_test.go +++ /dev/null @@ -1,317 +0,0 @@ -//go:build !windows - -package kafka_datasource - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "testing" - - "github.com/Shopify/sarama" - "github.com/buger/jsonparser" - "github.com/jensneuse/abstractlogger" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/wundergraph/graphql-go-tools/pkg/engine/datasourcetesting" - "github.com/wundergraph/graphql-go-tools/pkg/engine/plan" - "github.com/wundergraph/graphql-go-tools/pkg/engine/resolve" -) - -const ( - testMockKafkaVersion = "V2_8_0_0" - testDefinition = ` -schema { - subscription: Subscription -} - -type Subscription { - remainingJedis: Int! -} -` -) - -type runTestOnTestDefinitionOptions func(planConfig *plan.Configuration, extraChecks *[]datasourcetesting.CheckFunc) - -func runTestOnTestDefinition(operation, operationName string, expectedPlan plan.Plan, options ...runTestOnTestDefinitionOptions) func(t *testing.T) { - extraChecks := make([]datasourcetesting.CheckFunc, 0) - config := plan.Configuration{ - DataSources: []plan.DataSourceConfiguration{ - { - RootNodes: []plan.TypeField{ - { - TypeName: "Subscription", - FieldNames: []string{"remainingJedis"}, - }, - }, - Custom: ConfigJSON(Configuration{ - Subscription: SubscriptionConfiguration{ - BrokerAddresses: []string{"localhost:9092"}, - Topics: []string{"test.topic"}, - GroupID: "test.consumer.group", - ClientID: "test.client.id", - KafkaVersion: testMockKafkaVersion, - BalanceStrategy: DefaultBalanceStrategy, - IsolationLevel: DefaultIsolationLevel, - SASL: SASL{ - Enable: true, - User: testSASLUser, - Password: testSASLPassword, - }, - }, - }), - Factory: &Factory{}, - }, - }, - } - - for _, opt := range options { - opt(&config, &extraChecks) - } - - return datasourcetesting.RunTest(testDefinition, operation, operationName, expectedPlan, config, extraChecks...) 
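// How the pieces above fit together: ConfigJSON marshals a Configuration into
// the bytes carried by plan.DataSourceConfiguration.Custom; Planner.Register
// unmarshals them into p.config, and ConfigureSubscription re-marshals
// p.config.Subscription as the subscription trigger input. A minimal wiring
// sketch (illustrative only, mirroring the test config above):
//
//	custom := ConfigJSON(Configuration{
//		Subscription: SubscriptionConfiguration{
//			BrokerAddresses: []string{"localhost:9092"},
//			Topics:          []string{"test.topic"},
//		},
//	})
//	ds := plan.DataSourceConfiguration{Custom: custom, Factory: &Factory{}}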
-} - -func testWithFactory(factory *Factory) runTestOnTestDefinitionOptions { - return func(planConfig *plan.Configuration, extraChecks *[]datasourcetesting.CheckFunc) { - for _, ds := range planConfig.DataSources { - ds.Factory = factory - } - } -} - -func TestKafkaDataSource(t *testing.T) { - factory := &Factory{} - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - t.Run("subscription", runTestOnTestDefinition(` - subscription RemainingJedis { - remainingJedis - } - `, "RemainingJedis", &plan.SubscriptionResponsePlan{ - Response: &resolve.GraphQLSubscription{ - Trigger: resolve.GraphQLSubscriptionTrigger{ - Input: []byte(fmt.Sprintf(`{"broker_addresses":["localhost:9092"],"topics":["test.topic"],"group_id":"test.consumer.group","client_id":"test.client.id","kafka_version":"%s","start_consuming_latest":false,"balance_strategy":"%s","isolation_level":"%s","sasl":{"enable":true,"user":"%s","password":"%s"}}`, - testMockKafkaVersion, - DefaultBalanceStrategy, - DefaultIsolationLevel, - testSASLUser, - testSASLPassword, - )), - Source: &SubscriptionSource{ - client: NewKafkaConsumerGroupBridge(ctx, logger()), - }, - }, - Response: &resolve.GraphQLResponse{ - Data: &resolve.Object{ - Fields: []*resolve.Field{ - { - Name: []byte("remainingJedis"), - Position: resolve.Position{ - Line: 3, - Column: 4, - }, - Value: &resolve.Integer{ - Path: []string{"remainingJedis"}, - Nullable: false, - }, - }, - }, - }, - }, - }, - }, testWithFactory(factory))) - - t.Run("subscription with variables", datasourcetesting.RunTest(` - type Subscription { - foo(bar: String): Int! - } -`, ` - subscription SubscriptionWithVariables($bar: String) { - foo(bar: $bar) - } - `, "SubscriptionWithVariables", &plan.SubscriptionResponsePlan{ - Response: &resolve.GraphQLSubscription{ - Trigger: resolve.GraphQLSubscriptionTrigger{ - Input: []byte(fmt.Sprintf(`{"broker_addresses":["localhost:9092"],"topics":["test.topic.$$0$$"],"group_id":"test.consumer.group","client_id":"test.client.id","kafka_version":"%s","start_consuming_latest":false,"balance_strategy":"%s","isolation_level":"%s","sasl":{"enable":true,"user":"%s","password":"%s"}}`, - testMockKafkaVersion, - DefaultBalanceStrategy, - DefaultIsolationLevel, - testSASLUser, - testSASLPassword, - )), - Variables: resolve.NewVariables( - &resolve.ContextVariable{ - Path: []string{"bar"}, - Renderer: resolve.NewPlainVariableRendererWithValidation(`{"type":["string","null"]}`), - }, - ), - Source: &SubscriptionSource{ - client: NewKafkaConsumerGroupBridge(ctx, logger()), - }, - }, - Response: &resolve.GraphQLResponse{ - Data: &resolve.Object{ - Fields: []*resolve.Field{ - { - Name: []byte("foo"), - Position: resolve.Position{ - Line: 3, - Column: 4, - }, - Value: &resolve.Integer{ - Path: []string{"foo"}, - Nullable: false, - }, - }, - }, - }, - }, - }, - }, plan.Configuration{ - DataSources: []plan.DataSourceConfiguration{ - { - RootNodes: []plan.TypeField{ - { - TypeName: "Subscription", - FieldNames: []string{"foo"}, - }, - }, - Custom: ConfigJSON(Configuration{ - Subscription: SubscriptionConfiguration{ - BrokerAddresses: []string{"localhost:9092"}, - Topics: []string{"test.topic.{{.arguments.bar}}"}, - GroupID: "test.consumer.group", - ClientID: "test.client.id", - KafkaVersion: testMockKafkaVersion, - BalanceStrategy: DefaultBalanceStrategy, - IsolationLevel: DefaultIsolationLevel, - SASL: SASL{ - Enable: true, - User: testSASLUser, - Password: testSASLPassword, - }, - }, - }), - Factory: factory, - }, - }, - Fields: []plan.FieldConfiguration{ - 
{ - TypeName: "Subscription", - FieldName: "foo", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "bar", - SourceType: plan.FieldArgumentSource, - }, - }, - }, - }, - })) -} - -var errSubscriptionClientFail = errors.New("subscription client fail error") - -type FailingSubscriptionClient struct{} - -func (f FailingSubscriptionClient) Subscribe(ctx context.Context, options GraphQLSubscriptionOptions, next chan<- []byte) error { - return errSubscriptionClientFail -} - -func TestKafkaDataSource_Subscription_Start(t *testing.T) { - newSubscriptionSource := func(ctx context.Context) SubscriptionSource { - subscriptionSource := SubscriptionSource{client: NewKafkaConsumerGroupBridge(ctx, abstractlogger.NoopLogger)} - return subscriptionSource - } - - t.Run("should return error when input is invalid", func(t *testing.T) { - source := SubscriptionSource{client: FailingSubscriptionClient{}} - err := source.Start(context.Background(), []byte(`{"broker_addresses":"",topic":"","group_id":""}`), nil) - assert.Error(t, err) - }) - - t.Run("should send and receive a message, then cancel subscription", func(t *testing.T) { - next := make(chan []byte) - subscriptionLifecycle, cancelSubscription := context.WithCancel(context.Background()) - resolverLifecycle, cancelResolver := context.WithCancel(context.Background()) - defer cancelResolver() - - topic := "graphql-go-tools.test.topic" - groupID := "graphql-go-tools.test.groupid" - source := newSubscriptionSource(resolverLifecycle) - - fr := &sarama.FetchResponse{Version: 11} - mockBroker := newMockKafkaBroker(t, topic, groupID, fr) - defer mockBroker.Close() - - options := GraphQLSubscriptionOptions{ - BrokerAddresses: []string{mockBroker.Addr()}, - Topics: []string{topic}, - GroupID: groupID, - ClientID: "graphql-go-tools.test.groupid", - KafkaVersion: testMockKafkaVersion, - } - optionsBytes, err := json.Marshal(options) - require.NoError(t, err) - err = source.Start(subscriptionLifecycle, optionsBytes, next) - require.NoError(t, err) - - testMessageKey := sarama.StringEncoder("test.message.key") - testMessageValue := sarama.StringEncoder(`{"stock":[{"name":"Trilby","price":293,"inStock":2}]}`) - - // Add a message to the topic. The consumer group will fetch that message and trigger ConsumeClaim method. - fr.AddMessage(topic, defaultPartition, testMessageKey, testMessageValue, 0) - - nextBytes := <-next - assert.Equal(t, `{"data":{"stock":[{"name":"Trilby","price":293,"inStock":2}]}}`, string(nextBytes)) - - cancelSubscription() - _, ok := <-next - assert.False(t, ok) - }) -} - -func TestKafkaConsumerGroupBridge_Subscribe(t *testing.T) { - var ( - testMessageKey = sarama.StringEncoder("test.message.key") - testMessageValue = sarama.StringEncoder(`{"stock":[{"name":"Trilby","price":293,"inStock":2}]}`) - topic = "test.topic" - consumerGroup = "consumer.group" - ) - - fr := &sarama.FetchResponse{Version: 11} - mockBroker := newMockKafkaBroker(t, topic, consumerGroup, fr) - defer mockBroker.Close() - - // Add a message to the topic. The consumer group will fetch that message and trigger ConsumeClaim method. - fr.AddMessage(topic, defaultPartition, testMessageKey, testMessageValue, 0) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - cg := NewKafkaConsumerGroupBridge(ctx, logger()) // use abstractlogger.NoopLogger if there is no available logger. 
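// The expectations in this test and the one above hinge on how the bridge
// frames messages: Subscribe pushes each Kafka message value into `next`
// wrapped in a GraphQL-style data envelope, roughly (a sketch, not the exact
// bridge code):
//
//	out := append([]byte(`{"data":`), msg.Value...)
//	out = append(out, '}')
//	next <- out
//
// hence the `{"data":{...}}` equality check earlier and the
// jsonparser.Get(msg, "data") extraction below.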
- - options := GraphQLSubscriptionOptions{ - BrokerAddresses: []string{mockBroker.Addr()}, - Topics: []string{topic}, - GroupID: consumerGroup, - ClientID: "graphql-go-tools-test", - KafkaVersion: testMockKafkaVersion, - } - - next := make(chan []byte) - err := cg.Subscribe(ctx, options, next) - require.NoError(t, err) - - msg := <-next - expectedMsg, err := testMessageValue.Encode() - require.NoError(t, err) - - value, _, _, err := jsonparser.Get(msg, "data") - require.NoError(t, err) - require.Equal(t, expectedMsg, value) -} diff --git a/pkg/engine/datasource/kafka_datasource/sarama_config_parameters_test.go b/pkg/engine/datasource/kafka_datasource/sarama_config_parameters_test.go deleted file mode 100644 index b482c4178..000000000 --- a/pkg/engine/datasource/kafka_datasource/sarama_config_parameters_test.go +++ /dev/null @@ -1,885 +0,0 @@ -//go:build !windows - -package kafka_datasource - -import ( - "context" - "errors" - "fmt" - "os" - "sort" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/Shopify/sarama" - "github.com/go-zookeeper/zk" - "github.com/ory/dockertest" - "github.com/ory/dockertest/docker" - "github.com/stretchr/testify/require" -) - -// Possible errors with dockertest setup: -// -// Error: API error (404): could not find an available, non-overlapping IPv4 address pool among the defaults to assign to the network -// Solution: docker network prune - -const ( - testBrokerAddr = "localhost:9092" - testClientID = "graphql-go-tools-test" - messageTemplate = "topic: %s - message: %d" - testTopic = "start-consuming-latest-test" - testConsumerGroup = "start-consuming-latest-cg" - testSASLUser = "admin" - testSASLPassword = "admin-secret" - initialBrokerPort = 9092 - maxIdleConsumerSeconds = 10 * time.Second -) - -var defaultZooKeeperEnvVars = []string{ - "ALLOW_ANONYMOUS_LOGIN=yes", -} - -// See the following blogpost to understand how Kafka listeners works: -// https://www.confluent.io/blog/kafka-listeners-explained/ - -var defaultKafkaEnvVars = []string{ - "KAFKA_ZOOKEEPER_CONNECT=zookeeper:2181", - "ALLOW_PLAINTEXT_LISTENER=yes", - "KAFKA_LISTENER_SECURITY_PROTOCOL_MAP=INSIDE:PLAINTEXT,OUTSIDE:PLAINTEXT", - "KAFKA_INTER_BROKER_LISTENER_NAME=INSIDE", -} - -type kafkaCluster struct { - pool *dockertest.Pool - network *docker.Network - kafkaRunOptions kafkaClusterOptions -} - -func newKafkaCluster(t *testing.T) *kafkaCluster { - pool, err := dockertest.NewPool("") - require.NoError(t, err) - - require.NoError(t, pool.Client.Ping()) - - network, err := pool.Client.CreateNetwork(docker.CreateNetworkOptions{Name: "zookeeper_kafka_network"}) - require.NoError(t, err) - - return &kafkaCluster{ - pool: pool, - network: network, - } -} - -func getPortID(port int) string { - return fmt.Sprintf("%d/tcp", port) -} - -func (k *kafkaCluster) startZooKeeper(t *testing.T) { - t.Log("Trying to run ZooKeeper") - - resource, err := k.pool.RunWithOptions(&dockertest.RunOptions{ - Name: "zookeeper-graphql", - Repository: "zookeeper", - Tag: "3.8.0", - NetworkID: k.network.ID, - Hostname: "zookeeper", - ExposedPorts: []string{"2181"}, - Env: defaultZooKeeperEnvVars, - }, func(config *docker.HostConfig) { - config.AutoRemove = true - }) - require.NoError(t, err) - - t.Cleanup(func() { - if err = k.pool.Purge(resource); err != nil { - require.NoError(t, err) - } - }) - - conn, _, err := zk.Connect([]string{fmt.Sprintf("127.0.0.1:%s", resource.GetPort("2181/tcp"))}, 10*time.Second) - require.NoError(t, err) - - defer conn.Close() - - retryFn := func() error { - switch conn.State() 
{ - case zk.StateHasSession, zk.StateConnected: - return nil - default: - return errors.New("not yet connected") - } - } - - require.NoError(t, k.pool.Retry(retryFn)) - t.Log("ZooKeeper has been started") -} - -type kafkaClusterOption func(k *kafkaClusterOptions) - -type kafkaClusterOptions struct { - envVars []string - saslAuth bool -} - -func withKafkaEnvVars(envVars []string) kafkaClusterOption { - return func(k *kafkaClusterOptions) { - k.envVars = envVars - } -} - -func withKafkaSASLAuth() kafkaClusterOption { - return func(k *kafkaClusterOptions) { - k.saslAuth = true - } -} - -func (k *kafkaCluster) startKafka(t *testing.T, port int, envVars []string) *dockertest.Resource { - t.Logf("Trying to run Kafka on %d", port) - - internalPort := port + 1 - hostname := fmt.Sprintf("kafka%d", port) - - // We need a deterministic way to produce broker IDs. Kafka produces random IDs if we don't set - // deliberately. We need to use the same ID to handle node restarts properly. - // All port numbers have to be bigger or equal to 9092 - // - // * If the port number is 9092, brokerID is 0 - // * If the port number is 9094, brokerID is 2 - brokerID := port % initialBrokerPort - - envVars = append(envVars, fmt.Sprintf("KAFKA_CFG_BROKER_ID=%d", brokerID)) - envVars = append(envVars, fmt.Sprintf("KAFKA_ADVERTISED_LISTENERS=INSIDE://%s:%d,OUTSIDE://localhost:%d", hostname, internalPort, port)) - envVars = append(envVars, fmt.Sprintf("KAFKA_LISTENERS=INSIDE://0.0.0.0:%d,OUTSIDE://0.0.0.0:%d", internalPort, port)) - - portID := getPortID(port) - - // Name and Hostname have to be unique - resource, err := k.pool.RunWithOptions(&dockertest.RunOptions{ - Name: fmt.Sprintf("kafka-graphql-%d", port), - Repository: "bitnami/kafka", - Tag: "3.1", - NetworkID: k.network.ID, - Hostname: hostname, - Env: envVars, - PortBindings: map[docker.Port][]docker.PortBinding{ - docker.Port(portID): {{HostIP: "localhost", HostPort: portID}}, - }, - ExposedPorts: []string{portID}, - }, func(config *docker.HostConfig) { - config.RestartPolicy = docker.RestartOnFailure(10) - if k.kafkaRunOptions.saslAuth { - wd, _ := os.Getwd() - config.Mounts = []docker.HostMount{{ - Target: "/opt/bitnami/kafka/config/kafka_jaas.conf", - Source: fmt.Sprintf("%s/testdata/kafka_jaas.conf", wd), - Type: "bind", - }} - } - }) - require.NoError(t, err) - - t.Cleanup(func() { - err := k.pool.Purge(resource) - if err != nil { - err = errors.Unwrap(errors.Unwrap(err)) - _, ok := err.(*docker.NoSuchContainer) - if ok { - // we closed this resource manually - err = nil - } - } - require.NoError(t, err) - }) - - retryFn := func() error { - config := sarama.NewConfig() - config.Producer.Return.Successes = true - config.Producer.Return.Errors = true - if k.kafkaRunOptions.saslAuth { - config.Net.SASL.Enable = true - config.Net.SASL.User = testSASLUser - config.Net.SASL.Password = testSASLPassword - } - - brokerAddr := resource.GetHostPort(portID) - asyncProducer, err := sarama.NewAsyncProducer([]string{brokerAddr}, config) - if err != nil { - return err - } - defer asyncProducer.Close() - - var total int - loop: - for { - total++ - if total > 100 { - return fmt.Errorf("tried 100 times but no messages have been produced") - } - message := &sarama.ProducerMessage{ - Topic: "grahpql-go-tools-health-check", - Value: sarama.StringEncoder("hello, world!"), - } - - asyncProducer.Input() <- message - - select { - case <-asyncProducer.Errors(): - // We should try again - // - // Possible error msg: kafka: Failed to produce message to topic 
grahpql-go-tools-health-check:
-                // kafka server: In the middle of a leadership election, there is currently no leader for this
-                // partition and hence it is unavailable for writes.
-                continue loop
-            case <-asyncProducer.Successes():
-                break loop
-            }
-
-        }
-        return nil
-    }
-
-    require.NoError(t, k.pool.Retry(retryFn))
-
-    t.Logf("Kafka is ready to accept connections on %d", port)
-    return resource
-}
-
-func (k *kafkaCluster) start(t *testing.T, numMembers int, options ...kafkaClusterOption) map[string]*dockertest.Resource {
-    for _, opt := range options {
-        opt(&k.kafkaRunOptions)
-    }
-    if len(k.kafkaRunOptions.envVars) == 0 {
-        k.kafkaRunOptions.envVars = defaultKafkaEnvVars
-    }
-
-    t.Cleanup(func() {
-        require.NoError(t, k.pool.Client.RemoveNetwork(k.network.ID))
-    })
-
-    k.startZooKeeper(t)
-
-    resources := make(map[string]*dockertest.Resource)
-    var port = initialBrokerPort // Initial port
-    for i := 0; i < numMembers; i++ {
-        var envVars []string
-        envVars = append(envVars, k.kafkaRunOptions.envVars...)
-        portID := getPortID(port)
-        resources[portID] = k.startKafka(t, port, envVars)
-
-        // Increase the port numbers. Every member uses a different hostname
-        // and different port numbers. This layout is helpful for debugging:
-        //
-        // Member 1:
-        // 9092 - INSIDE
-        // 9093 - OUTSIDE
-        //
-        // Member 2:
-        // 9094 - INSIDE
-        // 9095 - OUTSIDE
-        port = port + 2
-    }
-    require.NotEmpty(t, resources)
-    return resources
-}
-
-func (k *kafkaCluster) restart(t *testing.T, port int, broker *dockertest.Resource, options ...kafkaClusterOption) (*dockertest.Resource, error) {
-    if err := broker.Close(); err != nil {
-        return nil, err
-    }
-
-    for _, opt := range options {
-        opt(&k.kafkaRunOptions)
-    }
-    if len(k.kafkaRunOptions.envVars) == 0 {
-        k.kafkaRunOptions.envVars = defaultKafkaEnvVars
-    }
-
-    var envVars []string
-    envVars = append(envVars, k.kafkaRunOptions.envVars...)
- return k.startKafka(t, port, envVars), nil -} - -func produceTestMessages(t *testing.T, options *GraphQLSubscriptionOptions, messages map[string][]string) { - config := sarama.NewConfig() - if options.SASL.Enable { - config.Net.SASL.Enable = true - config.Net.SASL.User = options.SASL.User - config.Net.SASL.Password = options.SASL.Password - } - - asyncProducer, err := sarama.NewAsyncProducer(options.BrokerAddresses, config) - require.NoError(t, err) - - for _, topic := range options.Topics { - values, ok := messages[topic] - if ok { - for _, value := range values { - message := &sarama.ProducerMessage{ - Topic: topic, - Value: sarama.StringEncoder(value), - } - asyncProducer.Input() <- message - } - } - } -} - -func consumeTestMessages(t *testing.T, messages chan *sarama.ConsumerMessage, producedMessages map[string][]string) { - var expectedNumMessages int - for _, values := range producedMessages { - expectedNumMessages += len(values) - } - - consumedMessages := make(map[string][]string) - var numMessages int -L: - for { - select { - case <-time.After(maxIdleConsumerSeconds): - require.Failf(t, "all produced messages could not be consumed", "consumer is idle for %s", maxIdleConsumerSeconds) - case msg := <-messages: - numMessages++ - topic := msg.Topic - value := string(msg.Value) - consumedMessages[topic] = append(consumedMessages[topic], value) - if numMessages >= expectedNumMessages { - break L - } - } - } - - require.Equal(t, producedMessages, consumedMessages) -} - -func testStartConsumer(t *testing.T, options *GraphQLSubscriptionOptions) (*KafkaConsumerGroup, chan *sarama.ConsumerMessage) { - ctx, cancel := context.WithCancel(context.Background()) - options.startedCallback = func() { - cancel() - } - - options.Sanitize() - require.NoError(t, options.Validate()) - - // Start a consumer - saramaConfig := sarama.NewConfig() - if options.SASL.Enable { - saramaConfig.Net.SASL.Enable = true - saramaConfig.Net.SASL.User = options.SASL.User - saramaConfig.Net.SASL.Password = options.SASL.Password - } - saramaConfig.Consumer.Return.Errors = true - - cg, err := NewKafkaConsumerGroup(logger(), saramaConfig, options) - require.NoError(t, err) - - messages := make(chan *sarama.ConsumerMessage) - cg.StartConsuming(messages) - - <-ctx.Done() - - return cg, messages -} - -func skipWhenShort(t *testing.T) { - if testing.Short() { - t.Skip("skipping kafka docker tests in short mode") - } -} - -func getBrokerAddresses(brokers map[string]*dockertest.Resource) (brokerAddresses []string) { - for portID, broker := range brokers { - brokerAddresses = append(brokerAddresses, broker.GetHostPort(portID)) - } - return brokerAddresses -} - -func publishMessagesContinuously(t *testing.T, ctx context.Context, options *GraphQLSubscriptionOptions) { - config := sarama.NewConfig() - if options.SASL.Enable { - config.Net.SASL.Enable = true - config.Net.SASL.User = options.SASL.User - config.Net.SASL.Password = options.SASL.Password - } - - asyncProducer, err := sarama.NewAsyncProducer(options.BrokerAddresses, config) - require.NoError(t, err) - - var i int - for { - select { - case <-ctx.Done(): - return - default: - } - for _, topic := range options.Topics { - message := &sarama.ProducerMessage{ - Topic: topic, - Value: sarama.StringEncoder(fmt.Sprintf(messageTemplate, topic, i)), - } - asyncProducer.Input() <- message - } - i++ - } -} - -func TestSarama_StartConsumingLatest_True(t *testing.T) { - skipWhenShort(t) - - // Test scenario: - // - // 1- Start a new consumer - // 2- Produce 10 messages - // 3- The 
consumer consumes the produced messages - // 4- Stop the consumer - // 5- Produce more messages - // 6- Start a new consumer with the same consumer group name - // 7- Produce more messages - // 8- Consumer will consume the messages produced on step 7. - - // Important note about offset management in Kafka: - // - // config.Consumer.Offsets.Initial only takes effect when offsets are not committed to Kafka/Zookeeper. - // If the consumer group already has offsets committed, the consumer will resume from the committed offset. - - k := newKafkaCluster(t) - brokers := k.start(t, 1) - - const ( - testTopic = "start-consuming-latest-test" - testConsumerGroup = "start-consuming-latest-cg" - ) - - options := &GraphQLSubscriptionOptions{ - BrokerAddresses: getBrokerAddresses(brokers), - Topics: []string{testTopic}, - GroupID: testConsumerGroup, - ClientID: "graphql-go-tools-test", - StartConsumingLatest: true, - } - - cg, messages := testStartConsumer(t, options) - - // Produce messages - // message-1 - // message-9 - testMessages := map[string][]string{ - testTopic: {"message-1", "message-2"}, - } - produceTestMessages(t, options, testMessages) - consumeTestMessages(t, messages, testMessages) - - // Stop the first consumer group - require.NoError(t, cg.Close()) - - // Produce more messages - // message-3 - // message-4 - // These messages will be ignored by the consumer. - testMessages = map[string][]string{ - testTopic: {"message-3", "message-4"}, - } - produceTestMessages(t, options, testMessages) - - // Start a new consumer with the same consumer group name - cg, messages = testStartConsumer(t, options) - - // Produce more messages - // message-5 - // message-6 - testMessages = map[string][]string{ - testTopic: {"message-5", "message-6"}, - } - produceTestMessages(t, options, testMessages) - consumeTestMessages(t, messages, testMessages) - - // Stop the second consumer group - require.NoError(t, cg.Close()) -} - -func TestSarama_StartConsuming_And_Restart(t *testing.T) { - skipWhenShort(t) - - // Test scenario: - // - // 1- Start a new consumer - // 2- Produce 10 messages - // 3- The consumer consumes the produced messages - // 4- Stop the consumer - // 5- Produce more messages - // 6- Start a new consumer with the same consumer group name - // 7- Produce more messages - // 8- Consumer will consume all messages. 
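// Because StartConsumingLatest is false here, nothing resets the committed
// offsets, which is why step 8 holds: the reconnecting group resumes from
// its last committed offset and also receives the messages produced while
// it was away. A rough sketch of the sarama knobs involved (assuming
// sarama's standard Config API, not the bridge's exact code):
//
//	cfg := sarama.NewConfig()
//	// Consulted only when the group has no committed offset yet:
//	cfg.Consumer.Offsets.Initial = sarama.OffsetOldest
//	// Once offsets are committed, consumption resumes from them.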
- - k := newKafkaCluster(t) - brokers := k.start(t, 1) - - options := &GraphQLSubscriptionOptions{ - BrokerAddresses: getBrokerAddresses(brokers), - Topics: []string{testTopic}, - GroupID: testConsumerGroup, - ClientID: "graphql-go-tools-test", - StartConsumingLatest: false, - } - - cg, messages := testStartConsumer(t, options) - - // Produce messages - testMessages := map[string][]string{ - testTopic: {"message-1", "message-2"}, - } - produceTestMessages(t, options, testMessages) - consumeTestMessages(t, messages, testMessages) - - // Stop the first consumer group - require.NoError(t, cg.Close()) - - // Produce more messages - testMessages = map[string][]string{ - testTopic: {"message-3", "message-4"}, - } - produceTestMessages(t, options, testMessages) - - // Start a new consumer with the same consumer group name - cg, messages = testStartConsumer(t, options) - - // Produce more messages - testMessages = map[string][]string{ - testTopic: {"message-5", "message-6"}, - } - produceTestMessages(t, options, testMessages) - - testMessages = map[string][]string{ - testTopic: {"message-3", "message-4", "message-5", "message-6"}, - } - consumeTestMessages(t, messages, testMessages) - - // Stop the second consumer group - require.NoError(t, cg.Close()) -} - -func TestSarama_ConsumerGroup_SASL_Authentication(t *testing.T) { - skipWhenShort(t) - - kafkaEnvVars := []string{ - "ALLOW_PLAINTEXT_LISTENER=yes", - "KAFKA_OPTS=-Djava.security.auth.login.config=/opt/bitnami/kafka/config/kafka_jaas.conf", - "KAFKA_ZOOKEEPER_CONNECT=zookeeper:2181", - "KAFKA_LISTENER_SECURITY_PROTOCOL_MAP=INSIDE:PLAINTEXT,OUTSIDE:SASL_PLAINTEXT", - "KAFKA_CFG_SASL_ENABLED_MECHANISMS=PLAIN", - "KAFKA_CFG_SASL_MECHANISM_INTER_BROKER_PROTOCOL=PLAIN", - "KAFKA_CFG_INTER_BROKER_LISTENER_NAME=INSIDE", - } - k := newKafkaCluster(t) - brokers := k.start(t, 1, withKafkaEnvVars(kafkaEnvVars), withKafkaSASLAuth()) - - options := &GraphQLSubscriptionOptions{ - BrokerAddresses: getBrokerAddresses(brokers), - Topics: []string{testTopic}, - GroupID: testConsumerGroup, - ClientID: "graphql-go-tools-test", - StartConsumingLatest: false, - SASL: SASL{ - Enable: true, - User: testSASLUser, - Password: testSASLPassword, - }, - } - - cg, messages := testStartConsumer(t, options) - - // Produce messages - testMessages := map[string][]string{ - testTopic: {"message-1", "message-2"}, - } - produceTestMessages(t, options, testMessages) - consumeTestMessages(t, messages, testMessages) - - require.NoError(t, cg.Close()) -} - -func TestSarama_Balance_Strategy(t *testing.T) { - skipWhenShort(t) - - strategies := map[string]string{ - BalanceStrategyRange: "range", - BalanceStrategySticky: "sticky", - BalanceStrategyRoundRobin: "roundrobin", - "": "range", // Sanitize function will set DefaultBalanceStrategy, it is BalanceStrategyRange. 
-    }
-
-    for strategy, name := range strategies {
-        options := &GraphQLSubscriptionOptions{
-            BrokerAddresses: []string{testBrokerAddr},
-            Topics:          []string{testTopic},
-            GroupID:         testConsumerGroup,
-            ClientID:        testClientID,
-            BalanceStrategy: strategy,
-        }
-        options.Sanitize()
-        require.NoError(t, options.Validate())
-
-        kc := &KafkaConsumerGroupBridge{
-            ctx: context.Background(),
-            log: logger(),
-        }
-
-        sc, err := kc.prepareSaramaConfig(options)
-        require.NoError(t, err)
-
-        st := sc.Consumer.Group.Rebalance.Strategy
-        require.Equal(t, name, st.Name())
-    }
-}
-
-func TestSarama_Isolation_Level(t *testing.T) {
-    skipWhenShort(t)
-
-    strategies := map[string]sarama.IsolationLevel{
-        IsolationLevelReadUncommitted: sarama.ReadUncommitted,
-        IsolationLevelReadCommitted:   sarama.ReadCommitted,
-        "": sarama.ReadUncommitted, // The Sanitize function will set DefaultIsolationLevel, which is sarama.ReadUncommitted.
-    }
-
-    for isolationLevel, value := range strategies {
-        options := &GraphQLSubscriptionOptions{
-            BrokerAddresses: []string{testBrokerAddr},
-            Topics:          []string{testTopic},
-            GroupID:         testConsumerGroup,
-            ClientID:        testClientID,
-            IsolationLevel:  isolationLevel,
-        }
-        options.Sanitize()
-        require.NoError(t, options.Validate())
-
-        kc := &KafkaConsumerGroupBridge{
-            ctx: context.Background(),
-            log: logger(),
-        }
-
-        sc, err := kc.prepareSaramaConfig(options)
-        require.NoError(t, err)
-
-        require.Equal(t, value, sc.Consumer.IsolationLevel)
-    }
-}
-
-func TestSarama_Config_SASL_Authentication(t *testing.T) {
-    skipWhenShort(t)
-
-    options := &GraphQLSubscriptionOptions{
-        BrokerAddresses: []string{testBrokerAddr},
-        Topics:          []string{testTopic},
-        GroupID:         testConsumerGroup,
-        ClientID:        testClientID,
-        SASL: SASL{
-            Enable:   true,
-            User:     "foobar",
-            Password: "password",
-        },
-    }
-    options.Sanitize()
-    require.NoError(t, options.Validate())
-
-    kc := &KafkaConsumerGroupBridge{
-        ctx: context.Background(),
-        log: logger(),
-    }
-
-    sc, err := kc.prepareSaramaConfig(options)
-    require.NoError(t, err)
-    require.True(t, sc.Net.SASL.Enable)
-    require.Equal(t, "foobar", sc.Net.SASL.User)
-    require.Equal(t, "password", sc.Net.SASL.Password)
-}
-
-func TestSarama_Multiple_Broker(t *testing.T) {
-    skipWhenShort(t)
-
-    k := newKafkaCluster(t)
-    brokers := k.start(t, 3)
-
-    options := &GraphQLSubscriptionOptions{
-        BrokerAddresses:      getBrokerAddresses(brokers),
-        Topics:               []string{testTopic},
-        GroupID:              testConsumerGroup,
-        ClientID:             "graphql-go-tools-test",
-        StartConsumingLatest: false,
-    }
-
-    cg, messages := testStartConsumer(t, options)
-
-    // Produce messages
-    testMessages := map[string][]string{
-        testTopic: {"message-1", "message-2"},
-    }
-    produceTestMessages(t, options, testMessages)
-    consumeTestMessages(t, messages, testMessages)
-
-    require.NoError(t, cg.Close())
-}
-
-func TestSarama_Cluster_Member_Restart(t *testing.T) {
-    skipWhenShort(t)
-
-    k := newKafkaCluster(t)
-    brokers := k.start(t, 2)
-
-    options := &GraphQLSubscriptionOptions{
-        BrokerAddresses:      getBrokerAddresses(brokers),
-        Topics:               []string{testTopic},
-        GroupID:              testConsumerGroup,
-        ClientID:             "graphql-go-tools-test",
-        StartConsumingLatest: false,
-    }
-
-    cg, messages := testStartConsumer(t, options)
-
-    // Restart one of the cluster members here.
-    // Note that we deliberately don't update the initial list of broker addresses.
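// A note on the port parsing in the loop below: strings.Trim(portID, "/tcp")
// strips a character *set*, not a suffix; it works here only because port
// digits never occur in that set. strings.TrimSuffix would state the intent
// directly:
//
//	port, err := strconv.Atoi(strings.TrimSuffix(portID, "/tcp"))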
- for portID, broker := range brokers { - t.Logf(fmt.Sprintf("Restart the member on %s", portID)) - port, err := strconv.Atoi(strings.Trim(portID, "/tcp")) - require.NoError(t, err) - - newBroker, err := k.restart(t, port, broker) - require.NoError(t, err) - brokers[portID] = newBroker - break - } - - // Stop publishMessagesContinuously properly. A leaking goroutine - // may lead to inconsistencies in the other tests. - var wg sync.WaitGroup - ctx, cancel := context.WithCancel(context.Background()) - - wg.Add(1) - go func() { - defer wg.Done() - publishMessagesContinuously(t, ctx, options) - }() - -L: - for { - select { - case <-time.After(10 * time.Second): - require.Fail(t, "No message received in 10 seconds") - case msg, ok := <-messages: - if !ok { - require.Fail(t, "messages channel is closed") - } - t.Logf("Message received from %s: %v", msg.Topic, string(msg.Value)) - break L - } - } - - require.NoError(t, cg.Close()) - - // Stop publishMessagesContinuously - cancel() - wg.Wait() -} - -func TestSarama_Cluster_Add_Member(t *testing.T) { - skipWhenShort(t) - - k := newKafkaCluster(t) - brokers := k.start(t, 1) - - options := &GraphQLSubscriptionOptions{ - BrokerAddresses: getBrokerAddresses(brokers), - Topics: []string{testTopic}, - GroupID: testConsumerGroup, - ClientID: "graphql-go-tools-test", - StartConsumingLatest: false, - } - - cg, messages := testStartConsumer(t, options) - - // Add a new Kafka node to the cluster - var ports []int - for portID := range brokers { - port, err := strconv.Atoi(strings.Trim(portID, "/tcp")) - require.NoError(t, err) - ports = append(ports, port) - } - // Find an unoccupied port for the new node. - sort.Ints(ports) - // [9092, 9094, 9096] - port := ports[len(ports)-1] + 2 // A Kafka node uses 2 ports. Increase by 2 to find an unoccupied port. - _, err := k.addNewBroker(t, port) - require.NoError(t, err) - - // Stop publishMessagesContinuously properly. A leaking goroutine - // may lead to inconsistencies in the other tests. 
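// The context/WaitGroup pairing below is the shutdown handshake referred to
// above: cancel() signals publishMessagesContinuously to return, and
// wg.Wait() blocks until that goroutine has actually exited, so no producer
// keeps running into the next test.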
- var wg sync.WaitGroup - ctx, cancel := context.WithCancel(context.Background()) - - wg.Add(1) - go func() { - defer wg.Done() - publishMessagesContinuously(t, ctx, options) - }() - -L: - for { - select { - case <-time.After(10 * time.Second): - require.Fail(t, "No message received in 10 seconds") - case msg, ok := <-messages: - if !ok { - require.Fail(t, "messages channel is closed") - } - t.Logf("Message received from %s: %v", msg.Topic, string(msg.Value)) - break L - } - } - - require.NoError(t, cg.Close()) - - // Stop publishMessagesContinuously - cancel() - wg.Wait() -} - -func TestSarama_Subscribe_To_Multiple_Topics(t *testing.T) { - skipWhenShort(t) - - k := newKafkaCluster(t) - brokers := k.start(t, 1) - - options := &GraphQLSubscriptionOptions{ - BrokerAddresses: getBrokerAddresses(brokers), - Topics: []string{"test-topic-1", "test-topic-2"}, - GroupID: testConsumerGroup, - ClientID: "graphql-go-tools-test", - StartConsumingLatest: false, - } - - cg, messages := testStartConsumer(t, options) - - testMessages := map[string][]string{ - "test-topic-1": {"test-topic-1-message-1", "test-topic-1-message-2"}, - "test-topic-2": {"test-topic-2-message-1", "test-topic-2-message-2"}, - } - - produceTestMessages(t, options, testMessages) - - consumeTestMessages(t, messages, testMessages) - require.NoError(t, cg.Close()) -} diff --git a/pkg/engine/datasource/kafka_datasource/testdata/kafka_jaas.conf b/pkg/engine/datasource/kafka_datasource/testdata/kafka_jaas.conf deleted file mode 100644 index 58fbb4ffc..000000000 --- a/pkg/engine/datasource/kafka_datasource/testdata/kafka_jaas.conf +++ /dev/null @@ -1,8 +0,0 @@ -KafkaServer { - org.apache.kafka.common.security.plain.PlainLoginModule required - username="admin" - password="admin-secret" - user_admin="admin-secret" - user_alice="alice-secret"; -}; -Client{}; \ No newline at end of file diff --git a/pkg/engine/datasource/staticdatasource/static_datasource.go b/pkg/engine/datasource/staticdatasource/static_datasource.go index dfe51458d..c1e31ee7f 100644 --- a/pkg/engine/datasource/staticdatasource/static_datasource.go +++ b/pkg/engine/datasource/staticdatasource/static_datasource.go @@ -60,7 +60,7 @@ func (p *Planner) ConfigureSubscription() plan.SubscriptionConfiguration { type Source struct{} -func (_ Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) { +func (Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) { _, err = w.Write(input) return } diff --git a/pkg/engine/resolve/dataloader_test.go b/pkg/engine/resolve/dataloader_test.go index 270628eb4..650404707 100644 --- a/pkg/engine/resolve/dataloader_test.go +++ b/pkg/engine/resolve/dataloader_test.go @@ -93,7 +93,7 @@ func TestDataLoader_Load(t *testing.T) { }, }, DataSource: userService, - }, &Context{Context: context.Background()}, `{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}` + }, &Context{ctx: context.Background()}, `{"data":{"me": {"id": "1234","username": "Me","__typename": "User"}}}` })) t.Run("requires nested request", testFn(map[int]fetchState{ @@ -149,7 +149,7 @@ func TestDataLoader_Load(t *testing.T) { }, }, DataSource: userService, - }, &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"data":{"user": {"id":11, "username": "Username 11"}}}` + }, &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"data":{"user": {"id":11, "username": "Username 11"}}}` })) t.Run("fetch error", func(t *testing.T) { @@ -187,7 +187,7 @@ 
func TestDataLoader_Load(t *testing.T) { bufPair := NewBufPair() err := dl.Load( - &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, + &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, &SingleFetch{ BufferId: 2, InputTemplate: InputTemplate{ @@ -257,7 +257,7 @@ func TestDataLoader_Load(t *testing.T) { }, }, DataSource: nil, - }, &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"user": {"id":11, "username": "Username 11"}}` + }, &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"user": {"id":11, "username": "Username 11"}}` })) t.Run("fetch errors in corresponding call", testFnErr(map[int]fetchState{ @@ -287,7 +287,7 @@ func TestDataLoader_Load(t *testing.T) { }, }, DataSource: nil, - }, &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `someError` + }, &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `someError` })) t.Run("doesn't requires nested request", testFn(map[int]fetchState{ @@ -317,7 +317,7 @@ func TestDataLoader_Load(t *testing.T) { }, }, DataSource: nil, - }, &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"user": {"id":22, "username": "Username 22"}}` + }, &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"user": {"id":22, "username": "Username 22"}}` })) t.Run("requires nested request with array in path", testFn(map[int]fetchState{ @@ -373,7 +373,7 @@ func TestDataLoader_Load(t *testing.T) { }, }, DataSource: userService, - }, &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp", arrayElementKey}}, `{"data":{"user": {"id":11, "username": "Username 11"}}}` + }, &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp", arrayElementKey}}, `{"data":{"user": {"id":11, "username": "Username 11"}}}` })) t.Run("requires nested request with null array in path", testFn(map[int]fetchState{ @@ -429,7 +429,7 @@ func TestDataLoader_Load(t *testing.T) { }, }, DataSource: userService, - }, &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp", arrayElementKey}}, `{"data":{"user": {"id":11, "username": "Username 11"}}}` + }, &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp", arrayElementKey}}, `{"data":{"user": {"id":11, "username": "Username 11"}}}` })) } @@ -511,7 +511,7 @@ func TestDataLoader_LoadBatch(t *testing.T) { DataSource: userService, }, BatchFactory: batchFactory, - }, &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"name": "Trilby"}` + }, &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"name": "Trilby"}` })) t.Run("deeply nested fetch with varying fields", testFn(map[int]fetchState{ @@ -577,7 +577,7 @@ func TestDataLoader_LoadBatch(t *testing.T) { DataSource: carService, }, BatchFactory: batchFactory, - }, &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"vehicle", "engine"}}, `{"horsepower": 200}` + }, &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"vehicle", "engine"}}, `{"horsepower": 200}` })) t.Run("doesn't requires nested request", testFn(map[int]fetchState{ @@ -608,7 +608,7 @@ func 
TestDataLoader_LoadBatch(t *testing.T) { }, }, }, - }, &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"user": {"id":22, "username": "Username 22"}}` + }, &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, `{"user": {"id":22, "username": "Username 22"}}` })) t.Run("fetch error", func(t *testing.T) { @@ -648,7 +648,7 @@ func TestDataLoader_LoadBatch(t *testing.T) { Return(expErr) err := dl.LoadBatch( - &Context{Context: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, + &Context{ctx: context.Background(), lastFetchID: 1, responseElements: []string{"someProp"}}, &BatchFetch{ Fetch: &SingleFetch{ BufferId: 2, diff --git a/pkg/engine/resolve/defer_test.go b/pkg/engine/resolve/defer_test.go index 48e3eac83..9060cf9ea 100644 --- a/pkg/engine/resolve/defer_test.go +++ b/pkg/engine/resolve/defer_test.go @@ -424,7 +424,7 @@ func BenchmarkDefer(b *testing.B) { _ = resolver.ResolveGraphQLStreamingResponse(ctx, res, nil, writer) ctx.Free() - ctx.Context = bgCtx + ctx.ctx = bgCtx // writer.flushed = writer.flushed[:0] } } diff --git a/pkg/engine/resolve/fetcher.go b/pkg/engine/resolve/fetcher.go index 4902bffee..8f71b0df3 100644 --- a/pkg/engine/resolve/fetcher.go +++ b/pkg/engine/resolve/fetcher.go @@ -56,7 +56,7 @@ func (f *Fetcher) Fetch(ctx *Context, fetch *SingleFetch, preparedInput *fastbuf } if !f.EnableSingleFlightLoader || fetch.DisallowSingleFlight { - err = fetch.DataSource.Load(ctx.Context, preparedInput.Bytes(), dataBuf) + err = fetch.DataSource.Load(ctx.Context(), preparedInput.Bytes(), dataBuf) extractResponse(dataBuf.Bytes(), buf, fetch.ProcessResponseConfig) if ctx.afterFetchHook != nil { @@ -103,7 +103,7 @@ func (f *Fetcher) Fetch(ctx *Context, fetch *SingleFetch, preparedInput *fastbuf f.inflightFetchMu.Unlock() - err = fetch.DataSource.Load(ctx.Context, preparedInput.Bytes(), dataBuf) + err = fetch.DataSource.Load(ctx.Context(), preparedInput.Bytes(), dataBuf) extractResponse(dataBuf.Bytes(), &inflight.bufPair, fetch.ProcessResponseConfig) inflight.err = err diff --git a/pkg/engine/resolve/inputtemplate.go b/pkg/engine/resolve/inputtemplate.go index 17b29e100..2c78c1677 100644 --- a/pkg/engine/resolve/inputtemplate.go +++ b/pkg/engine/resolve/inputtemplate.go @@ -37,23 +37,28 @@ type InputTemplate struct { var setTemplateOutputNull = errors.New("set to null") -func (i *InputTemplate) Render(ctx *Context, data []byte, preparedInput *fastbuffer.FastBuffer) (err error) { - undefinedVariables := make([]string, 0) +func (i *InputTemplate) Render(ctx *Context, data []byte, preparedInput *fastbuffer.FastBuffer) error { + var undefinedVariables []string - for j := range i.Segments { - switch i.Segments[j].SegmentType { + for _, segment := range i.Segments { + var err error + switch segment.SegmentType { case StaticSegmentType: - preparedInput.WriteBytes(i.Segments[j].Data) + preparedInput.WriteBytes(segment.Data) case VariableSegmentType: - switch i.Segments[j].VariableKind { + switch segment.VariableKind { case ObjectVariableKind: - err = i.renderObjectVariable(ctx, data, i.Segments[j], preparedInput) + err = i.renderObjectVariable(ctx.Context(), data, segment, preparedInput) case ContextVariableKind: - err = i.renderContextVariable(ctx, i.Segments[j], preparedInput, &undefinedVariables) + var undefined bool + undefined, err = i.renderContextVariable(ctx, segment, preparedInput) + if undefined { + undefinedVariables = append(undefinedVariables, 
segment.VariableSourcePath[0]) + } case HeaderVariableKind: - err = i.renderHeaderVariable(ctx, i.Segments[j].VariableSourcePath, preparedInput) + err = i.renderHeaderVariable(ctx, segment.VariableSourcePath, preparedInput) default: - err = fmt.Errorf("InputTemplate.Render: cannot resolve variable of kind: %d", i.Segments[j].VariableKind) + err = fmt.Errorf("InputTemplate.Render: cannot resolve variable of kind: %d", segment.VariableKind) } if err != nil { if errors.Is(err, setTemplateOutputNull) { @@ -67,10 +72,12 @@ func (i *InputTemplate) Render(ctx *Context, data []byte, preparedInput *fastbuf } if len(undefinedVariables) > 0 { - ctx.Context = httpclient.CtxSetUndefinedVariables(ctx.Context, undefinedVariables) + output := httpclient.SetUndefinedVariables(preparedInput.Bytes(), undefinedVariables) + // The returned slice might be different, we need to copy back the data + preparedInput.Reset() + preparedInput.WriteBytes(output) } - - return + return nil } func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables []byte, segment TemplateSegment, preparedInput *fastbuffer.FastBuffer) error { @@ -94,15 +101,16 @@ func (i *InputTemplate) renderObjectVariable(ctx context.Context, variables []by return segment.Renderer.RenderVariable(ctx, value, preparedInput) } -func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput *fastbuffer.FastBuffer, undefinedVariables *[]string) error { +func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegment, preparedInput *fastbuffer.FastBuffer) (variableWasUndefined bool, err error) { value, valueType, offset, err := jsonparser.Get(ctx.Variables, segment.VariableSourcePath...) if err != nil || valueType == jsonparser.Null { + undefined := false if err == jsonparser.KeyPathNotFoundError { - *undefinedVariables = append(*undefinedVariables, segment.VariableSourcePath[0]) + undefined = true } preparedInput.WriteBytes(literal.NULL) - return nil + return undefined, nil } if valueType == jsonparser.String { value = ctx.Variables[offset-len(value)-2 : offset] @@ -113,7 +121,7 @@ func (i *InputTemplate) renderContextVariable(ctx *Context, segment TemplateSegm } } } - return segment.Renderer.RenderVariable(ctx, value, preparedInput) + return false, segment.Renderer.RenderVariable(ctx.Context(), value, preparedInput) } func (i *InputTemplate) renderHeaderVariable(ctx *Context, path []string, preparedInput *fastbuffer.FastBuffer) error { diff --git a/pkg/engine/resolve/inputtemplate_test.go b/pkg/engine/resolve/inputtemplate_test.go index dd3ca992a..9b209bd7d 100644 --- a/pkg/engine/resolve/inputtemplate_test.go +++ b/pkg/engine/resolve/inputtemplate_test.go @@ -278,14 +278,14 @@ func TestInputTemplate_Render(t *testing.T) { }, } ctx := &Context{ - Context: context.Background(), + ctx: context.Background(), Variables: []byte(""), } buf := fastbuffer.New() err := template.Render(ctx, nil, buf) assert.NoError(t, err) out := buf.String() - assert.Equal(t, `{"key":null}`, out) + assert.Equal(t, `{"undefined":["a"],"key":null}`, out) }) t.Run("when SetTemplateOutputToNullOnVariableNull: true", func(t *testing.T) { diff --git a/pkg/engine/resolve/resolve.go b/pkg/engine/resolve/resolve.go index f1ac745a1..a936be77f 100644 --- a/pkg/engine/resolve/resolve.go +++ b/pkg/engine/resolve/resolve.go @@ -115,7 +115,7 @@ type AfterFetchHook interface { } type Context struct { - context.Context + ctx context.Context Variables []byte Request Request pathElements [][]byte @@ -138,8 +138,11 @@ type 
Request struct { } func NewContext(ctx context.Context) *Context { + if ctx == nil { + panic("nil context.Context") + } return &Context{ - Context: ctx, + ctx: ctx, Variables: make([]byte, 0, 4096), pathPrefix: make([]byte, 0, 4096), pathElements: make([][]byte, 0, 16), @@ -152,7 +155,20 @@ func NewContext(ctx context.Context) *Context { } } -func (c *Context) Clone() Context { +func (c *Context) Context() context.Context { + return c.ctx +} + +func (c *Context) WithContext(ctx context.Context) *Context { + if ctx == nil { + panic("nil context.Context") + } + cpy := *c + cpy.ctx = ctx + return &cpy +} + +func (c *Context) clone() Context { variables := make([]byte, len(c.Variables)) copy(variables, c.Variables) pathPrefix := make([]byte, len(c.pathPrefix)) @@ -175,7 +191,7 @@ func (c *Context) Clone() Context { copy(patches[i].data, c.patches[i].data) } return Context{ - Context: c.Context, + ctx: c.ctx, Variables: variables, Request: c.Request, pathElements: pathElements, @@ -191,7 +207,7 @@ func (c *Context) Clone() Context { } func (c *Context) Free() { - c.Context = nil + c.ctx = nil c.Variables = c.Variables[:0] c.pathPrefix = c.pathPrefix[:0] c.pathElements = c.pathElements[:0] @@ -530,7 +546,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ copy(subscriptionInput, rendered) r.freeBufPair(buf) - c, cancel := context.WithCancel(ctx) + c, cancel := context.WithCancel(ctx.Context()) defer cancel() resolverDone := r.ctx.Done() @@ -586,7 +602,7 @@ func (r *Resolver) ResolveGraphQLStreamingResponse(ctx *Context, response *Graph buf.Write(literal.LBRACK) - done := ctx.Context.Done() + done := ctx.Context().Done() Loop: for { @@ -823,7 +839,7 @@ func (r *Resolver) resolveArrayAsynchronous(ctx *Context, array *Array, arrayIte itemBuf := r.getBufPair() *bufSlice = append(*bufSlice, itemBuf) itemData := (*arrayItems)[i] - cloned := ctx.Clone() + cloned := ctx.clone() go func(ctx Context, i int) { ctx.addPathElement([]byte(strconv.Itoa(i))) if e := r.resolveNode(&ctx, array.Item, itemData, itemBuf); e != nil && !errors.Is(e, errTypeNameSkipped) { @@ -1224,7 +1240,7 @@ func (r *Resolver) freeResultSet(set *resultSet) { func (r *Resolver) resolveFetch(ctx *Context, fetch Fetch, data []byte, set *resultSet) (err error) { // if context is cancelled, we should not resolve the fetch - if errors.Is(ctx.Err(), context.Canceled) { + if errors.Is(ctx.Context().Err(), context.Canceled) { return nil } diff --git a/pkg/engine/resolve/resolve_mock_test.go b/pkg/engine/resolve/resolve_mock_test.go index fd8b54b9f..7f185d187 100644 --- a/pkg/engine/resolve/resolve_mock_test.go +++ b/pkg/engine/resolve/resolve_mock_test.go @@ -6,36 +6,37 @@ package resolve import ( context "context" - gomock "github.com/golang/mock/gomock" - fastbuffer "github.com/wundergraph/graphql-go-tools/pkg/fastbuffer" io "io" reflect "reflect" + + gomock "github.com/golang/mock/gomock" + fastbuffer "github.com/wundergraph/graphql-go-tools/pkg/fastbuffer" ) -// MockDataSource is a mock of DataSource interface +// MockDataSource is a mock of DataSource interface. type MockDataSource struct { ctrl *gomock.Controller recorder *MockDataSourceMockRecorder } -// MockDataSourceMockRecorder is the mock recorder for MockDataSource +// MockDataSourceMockRecorder is the mock recorder for MockDataSource. type MockDataSourceMockRecorder struct { mock *MockDataSource } -// NewMockDataSource creates a new mock instance +// NewMockDataSource creates a new mock instance. 
func NewMockDataSource(ctrl *gomock.Controller) *MockDataSource { mock := &MockDataSource{ctrl: ctrl} mock.recorder = &MockDataSourceMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockDataSource) EXPECT() *MockDataSourceMockRecorder { return m.recorder } -// Load mocks base method +// Load mocks base method. func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte, arg2 io.Writer) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2) @@ -43,118 +44,118 @@ func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte, arg2 io.Writer) return ret0 } -// Load indicates an expected call of Load +// Load indicates an expected call of Load. func (mr *MockDataSourceMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1, arg2) } -// MockBeforeFetchHook is a mock of BeforeFetchHook interface +// MockBeforeFetchHook is a mock of BeforeFetchHook interface. type MockBeforeFetchHook struct { ctrl *gomock.Controller recorder *MockBeforeFetchHookMockRecorder } -// MockBeforeFetchHookMockRecorder is the mock recorder for MockBeforeFetchHook +// MockBeforeFetchHookMockRecorder is the mock recorder for MockBeforeFetchHook. type MockBeforeFetchHookMockRecorder struct { mock *MockBeforeFetchHook } -// NewMockBeforeFetchHook creates a new mock instance +// NewMockBeforeFetchHook creates a new mock instance. func NewMockBeforeFetchHook(ctrl *gomock.Controller) *MockBeforeFetchHook { mock := &MockBeforeFetchHook{ctrl: ctrl} mock.recorder = &MockBeforeFetchHookMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockBeforeFetchHook) EXPECT() *MockBeforeFetchHookMockRecorder { return m.recorder } -// OnBeforeFetch mocks base method +// OnBeforeFetch mocks base method. func (m *MockBeforeFetchHook) OnBeforeFetch(arg0 HookContext, arg1 []byte) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnBeforeFetch", arg0, arg1) } -// OnBeforeFetch indicates an expected call of OnBeforeFetch +// OnBeforeFetch indicates an expected call of OnBeforeFetch. func (mr *MockBeforeFetchHookMockRecorder) OnBeforeFetch(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnBeforeFetch", reflect.TypeOf((*MockBeforeFetchHook)(nil).OnBeforeFetch), arg0, arg1) } -// MockAfterFetchHook is a mock of AfterFetchHook interface +// MockAfterFetchHook is a mock of AfterFetchHook interface. type MockAfterFetchHook struct { ctrl *gomock.Controller recorder *MockAfterFetchHookMockRecorder } -// MockAfterFetchHookMockRecorder is the mock recorder for MockAfterFetchHook +// MockAfterFetchHookMockRecorder is the mock recorder for MockAfterFetchHook. type MockAfterFetchHookMockRecorder struct { mock *MockAfterFetchHook } -// NewMockAfterFetchHook creates a new mock instance +// NewMockAfterFetchHook creates a new mock instance. 
func NewMockAfterFetchHook(ctrl *gomock.Controller) *MockAfterFetchHook { mock := &MockAfterFetchHook{ctrl: ctrl} mock.recorder = &MockAfterFetchHookMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockAfterFetchHook) EXPECT() *MockAfterFetchHookMockRecorder { return m.recorder } -// OnData mocks base method +// OnData mocks base method. func (m *MockAfterFetchHook) OnData(arg0 HookContext, arg1 []byte, arg2 bool) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnData", arg0, arg1, arg2) } -// OnData indicates an expected call of OnData +// OnData indicates an expected call of OnData. func (mr *MockAfterFetchHookMockRecorder) OnData(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnData", reflect.TypeOf((*MockAfterFetchHook)(nil).OnData), arg0, arg1, arg2) } -// OnError mocks base method +// OnError mocks base method. func (m *MockAfterFetchHook) OnError(arg0 HookContext, arg1 []byte, arg2 bool) { m.ctrl.T.Helper() m.ctrl.Call(m, "OnError", arg0, arg1, arg2) } -// OnError indicates an expected call of OnError +// OnError indicates an expected call of OnError. func (mr *MockAfterFetchHookMockRecorder) OnError(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockAfterFetchHook)(nil).OnError), arg0, arg1, arg2) } -// MockDataSourceBatch is a mock of DataSourceBatch interface +// MockDataSourceBatch is a mock of DataSourceBatch interface. type MockDataSourceBatch struct { ctrl *gomock.Controller recorder *MockDataSourceBatchMockRecorder } -// MockDataSourceBatchMockRecorder is the mock recorder for MockDataSourceBatch +// MockDataSourceBatchMockRecorder is the mock recorder for MockDataSourceBatch. type MockDataSourceBatchMockRecorder struct { mock *MockDataSourceBatch } -// NewMockDataSourceBatch creates a new mock instance +// NewMockDataSourceBatch creates a new mock instance. func NewMockDataSourceBatch(ctrl *gomock.Controller) *MockDataSourceBatch { mock := &MockDataSourceBatch{ctrl: ctrl} mock.recorder = &MockDataSourceBatchMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockDataSourceBatch) EXPECT() *MockDataSourceBatchMockRecorder { return m.recorder } -// Demultiplex mocks base method +// Demultiplex mocks base method. func (m *MockDataSourceBatch) Demultiplex(arg0 *BufPair, arg1 []*BufPair) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Demultiplex", arg0, arg1) @@ -162,13 +163,13 @@ func (m *MockDataSourceBatch) Demultiplex(arg0 *BufPair, arg1 []*BufPair) error return ret0 } -// Demultiplex indicates an expected call of Demultiplex +// Demultiplex indicates an expected call of Demultiplex. func (mr *MockDataSourceBatchMockRecorder) Demultiplex(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Demultiplex", reflect.TypeOf((*MockDataSourceBatch)(nil).Demultiplex), arg0, arg1) } -// Input mocks base method +// Input mocks base method. 
func (m *MockDataSourceBatch) Input() *fastbuffer.FastBuffer { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Input") @@ -176,36 +177,36 @@ func (m *MockDataSourceBatch) Input() *fastbuffer.FastBuffer { return ret0 } -// Input indicates an expected call of Input +// Input indicates an expected call of Input. func (mr *MockDataSourceBatchMockRecorder) Input() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Input", reflect.TypeOf((*MockDataSourceBatch)(nil).Input)) } -// MockDataSourceBatchFactory is a mock of DataSourceBatchFactory interface +// MockDataSourceBatchFactory is a mock of DataSourceBatchFactory interface. type MockDataSourceBatchFactory struct { ctrl *gomock.Controller recorder *MockDataSourceBatchFactoryMockRecorder } -// MockDataSourceBatchFactoryMockRecorder is the mock recorder for MockDataSourceBatchFactory +// MockDataSourceBatchFactoryMockRecorder is the mock recorder for MockDataSourceBatchFactory. type MockDataSourceBatchFactoryMockRecorder struct { mock *MockDataSourceBatchFactory } -// NewMockDataSourceBatchFactory creates a new mock instance +// NewMockDataSourceBatchFactory creates a new mock instance. func NewMockDataSourceBatchFactory(ctrl *gomock.Controller) *MockDataSourceBatchFactory { mock := &MockDataSourceBatchFactory{ctrl: ctrl} mock.recorder = &MockDataSourceBatchFactoryMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockDataSourceBatchFactory) EXPECT() *MockDataSourceBatchFactoryMockRecorder { return m.recorder } -// CreateBatch mocks base method +// CreateBatch mocks base method. func (m *MockDataSourceBatchFactory) CreateBatch(arg0 [][]byte) (DataSourceBatch, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateBatch", arg0) @@ -214,7 +215,7 @@ func (m *MockDataSourceBatchFactory) CreateBatch(arg0 [][]byte) (DataSourceBatch return ret0, ret1 } -// CreateBatch indicates an expected call of CreateBatch +// CreateBatch indicates an expected call of CreateBatch. func (mr *MockDataSourceBatchFactoryMockRecorder) CreateBatch(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateBatch", reflect.TypeOf((*MockDataSourceBatchFactory)(nil).CreateBatch), arg0) diff --git a/pkg/engine/resolve/resolve_test.go b/pkg/engine/resolve/resolve_test.go index 6c6341e6d..7c2cd1786 100644 --- a/pkg/engine/resolve/resolve_test.go +++ b/pkg/engine/resolve/resolve_test.go @@ -1,5 +1,7 @@ package resolve +//go:generate mockgen -package resolve -destination resolve_mock_test.go . 
DataSource,BeforeFetchHook,AfterFetchHook,DataSourceBatch,DataSourceBatchFactory + import ( "bytes" "context" @@ -150,10 +152,10 @@ func TestResolver_ResolveNode(t *testing.T) { t.Run("Nullable empty object", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ Nullable: true, - }, Context{Context: context.Background()}, `null` + }, Context{ctx: context.Background()}, `null` })) t.Run("empty object", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { - return &EmptyObject{}, Context{Context: context.Background()}, `{}` + return &EmptyObject{}, Context{ctx: context.Background()}, `{}` })) t.Run("object with null field", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -163,7 +165,7 @@ func TestResolver_ResolveNode(t *testing.T) { Value: &Null{}, }, }, - }, Context{Context: context.Background()}, `{"foo":null}` + }, Context{ctx: context.Background()}, `{"foo":null}` })) t.Run("default graphql object", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -175,7 +177,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"data":null}` + }, Context{ctx: context.Background()}, `{"data":null}` })) t.Run("graphql object with simple data source", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -245,7 +247,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` })) t.Run("skip single field should resolve to empty response", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -279,7 +281,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"skip":true}`)}, `{"data":{"user":{}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"skip":true}`)}, `{"data":{"user":{}}}` })) t.Run("skip multiple fields should resolve to empty response", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -323,7 +325,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"skip":true}`)}, `{"data":{"user":{}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"skip":true}`)}, `{"data":{"user":{}}}` })) t.Run("skip __typename field be possible", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -365,7 +367,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"skip":true}`)}, `{"data":{"user":{"id":"1"}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"skip":true}`)}, `{"data":{"user":{"id":"1"}}}` })) t.Run("include __typename field be possible", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node 
Node, ctx Context, expectedOutput string) { return &Object{ @@ -407,7 +409,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"include":true}`)}, `{"data":{"user":{"id":"1","__typename":"User"}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"include":true}`)}, `{"data":{"user":{"id":"1","__typename":"User"}}}` })) t.Run("include __typename field with false value", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -449,7 +451,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"include":false}`)}, `{"data":{"user":{"id":"1"}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"include":false}`)}, `{"data":{"user":{"id":"1"}}}` })) t.Run("skip field when skip variable is true", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -521,7 +523,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"skip":true}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky"}}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"skip":true}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky"}}}}` })) t.Run("don't skip field when skip variable is false", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -593,7 +595,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"skip":false}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"skip":false}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` })) t.Run("don't skip field when skip variable is missing", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -665,7 +667,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` })) t.Run("include field when include variable is true", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -737,7 +739,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"include":true}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"include":true}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` })) t.Run("exclude field when include variable is false", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -809,7 +811,7 @@ func 
TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"include":false}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky"}}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"include":false}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky"}}}}` })) t.Run("exclude field when include variable is missing", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -881,7 +883,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky"}}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{}`)}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky"}}}}` })) t.Run("fetch with context variable resolver", testFn(true, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) @@ -929,7 +931,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"id":1}`)}, `{"name":"Jens"}` + }, Context{ctx: context.Background(), Variables: []byte(`{"id":1}`)}, `{"name":"Jens"}` })) t.Run("resolve array of strings", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -950,7 +952,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"strings":["Alex","true","123"]}` + }, Context{ctx: context.Background()}, `{"strings":["Alex","true","123"]}` })) t.Run("resolve array of mixed scalar types", testErrFn(func(t *testing.T, r *Resolver, ctrl *gomock.Controller) (node Node, ctx Context, expectedErr string) { return &Object{ @@ -971,7 +973,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `invalid value type 'number' for path /data/strings/2, expecting string, got: 123. You can fix this by configuring this field as Int/Float/JSON Scalar` + }, Context{ctx: context.Background()}, `invalid value type 'number' for path /data/strings/2, expecting string, got: 123. 
You can fix this by configuring this field as Int/Float/JSON Scalar` })) t.Run("resolve array items", func(t *testing.T) { t.Run("with unescape json enabled", func(t *testing.T) { @@ -995,7 +997,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"jsonList":[{"field":"value"}]}` + }, Context{ctx: context.Background()}, `{"jsonList":[{"field":"value"}]}` })) t.Run("json input", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1017,7 +1019,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"jsonList":[{"field":"value"}]}` + }, Context{ctx: context.Background()}, `{"jsonList":[{"field":"value"}]}` })) }) t.Run("with unescape json disabled", func(t *testing.T) { @@ -1041,7 +1043,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"jsonList":["{\"field\":\"value\"}"]}` + }, Context{ctx: context.Background()}, `{"jsonList":["{\"field\":\"value\"}"]}` })) t.Run("json input", testErrFn(func(t *testing.T, r *Resolver, ctrl *gomock.Controller) (node Node, ctx Context, expectedErr string) { return &Object{ @@ -1063,7 +1065,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, + }, Context{ctx: context.Background()}, `invalid value type 'object' for path /data/jsonList/0, expecting string, got: {"field":"value"}. You can fix this by configuring this field as Int/Float/JSON Scalar` })) }) @@ -1190,7 +1192,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"synchronousFriends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}],"asynchronousFriends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}],"nullableFriends":null,"strings":["foo","bar","baz"],"integers":[123,456,789],"floats":[1.2,3.4,5.6],"booleans":[true,false,true]}` + }, Context{ctx: context.Background()}, `{"synchronousFriends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}],"asynchronousFriends":[{"id":1,"name":"Alex"},{"id":2,"name":"Patric"}],"nullableFriends":null,"strings":["foo","bar","baz"],"integers":[123,456,789],"floats":[1.2,3.4,5.6],"booleans":[true,false,true]}` })) t.Run("array response from data source", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1220,7 +1222,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, + }, Context{ctx: context.Background()}, `{"pets":[{"name":"Woofie"}]}` })) t.Run("non null object with field condition can be null", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { @@ -1248,7 +1250,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, + }, Context{ctx: context.Background()}, `{}` })) t.Run("object with multiple type conditions", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { @@ -1318,7 +1320,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, + }, Context{ctx: context.Background()}, `{"data":{"namespaceCreate":{"code":"UserAlreadyHasPersonalNamespace","message":""}}}` })) t.Run("resolve fieldsets based on __typename", testFn(false, false, func(t *testing.T, ctrl 
*gomock.Controller) (node Node, ctx Context, expectedOutput string) { @@ -1350,7 +1352,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, + }, Context{ctx: context.Background()}, `{"pets":[{"name":"Woofie"}]}` })) @@ -1400,7 +1402,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, + }, Context{ctx: context.Background()}, `{"pet":{"id":"1","detail":null}}` })) @@ -1434,7 +1436,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, + }, Context{ctx: context.Background()}, `{"pets":[{"name":"Woofie"}]}` })) t.Run("parent object variables", testFn(true, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { @@ -1512,7 +1514,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"id":1,"name":"Jens","pet":{"name":"Woofie"}}` + }, Context{ctx: context.Background()}, `{"id":1,"name":"Jens","pet":{"name":"Woofie"}}` })) t.Run("with unescape json enabled", func(t *testing.T) { t.Run("json object within a string", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { @@ -1541,7 +1543,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a JSON object - }, Context{Context: context.Background()}, `{"data":{ "hello": "world", "numberAsString": "1", "number": 1, "bool": true, "null": null, "array": [1,2,3], "object": {"key": "value"} }}` + }, Context{ctx: context.Background()}, `{"data":{ "hello": "world", "numberAsString": "1", "number": 1, "bool": true, "null": null, "array": [1,2,3], "object": {"key": "value"} }}` })) t.Run("json array within a string", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1569,7 +1571,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a JSON array - }, Context{Context: context.Background()}, `{"data":[1, 2, 3]}` + }, Context{ctx: context.Background()}, `{"data":[1, 2, 3]}` })) t.Run("string with array and objects brackets", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1597,7 +1599,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a string - }, Context{Context: context.Background()}, `{"data":"hi[1beep{2}]"}` + }, Context{ctx: context.Background()}, `{"data":"hi[1beep{2}]"}` })) t.Run("plain scalar values within a string", func(t *testing.T) { t.Run("boolean", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { @@ -1622,7 +1624,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a string - }, Context{Context: context.Background()}, `{"data":"true"}` + }, Context{ctx: context.Background()}, `{"data":"true"}` })) t.Run("int", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1650,7 +1652,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a string - }, Context{Context: context.Background()}, `{"data":"1"}` + }, Context{ctx: context.Background()}, `{"data":"1"}` })) t.Run("float", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ 
-1678,7 +1680,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a string - }, Context{Context: context.Background()}, `{"data":"2.0"}` + }, Context{ctx: context.Background()}, `{"data":"2.0"}` })) t.Run("null", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1706,7 +1708,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a string - }, Context{Context: context.Background()}, `{"data":"null"}` + }, Context{ctx: context.Background()}, `{"data":"null"}` })) t.Run("string", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1733,7 +1735,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expect data value to be valid JSON string - }, Context{Context: context.Background()}, `{"data":"hello world"}` + }, Context{ctx: context.Background()}, `{"data":"hello world"}` })) }) t.Run("plain scalar values as is", func(t *testing.T) { @@ -1759,7 +1761,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a JSON boolean - }, Context{Context: context.Background()}, `{"data":true}` + }, Context{ctx: context.Background()}, `{"data":true}` })) t.Run("int", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1787,7 +1789,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a JSON boolean - }, Context{Context: context.Background()}, `{"data":1}` + }, Context{ctx: context.Background()}, `{"data":1}` })) t.Run("float", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1815,7 +1817,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expected output is a JSON boolean - }, Context{Context: context.Background()}, `{"data":2.0}` + }, Context{ctx: context.Background()}, `{"data":2.0}` })) t.Run("null", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node Node, ctx Context, expectedOutput string) { return &Object{ @@ -1842,7 +1844,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, // expect data value to be valid JSON string - }, Context{Context: context.Background()}, `{"data":null}` + }, Context{ctx: context.Background()}, `{"data":null}` })) }) }) @@ -1952,7 +1954,7 @@ func TestResolver_WithHooks(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), beforeFetchHook: beforeFetch, afterFetchHook: afterFetch}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` + }, Context{ctx: context.Background(), beforeFetchHook: beforeFetch, afterFetchHook: afterFetch}, `{"data":{"user":{"id":"1","name":"Jens","registered":true,"pet":{"name":"Barky","kind":"Dog"}}}}` })) } @@ -1997,7 +1999,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { Data: &Object{ Nullable: true, }, - }, Context{Context: context.Background()}, `{"data":null}` + }, Context{ctx: context.Background()}, `{"data":null}` })) t.Run("__typename without renaming", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2064,7 +2066,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, 
`{"data":{"user":{"id":1,"name":"Jannik","__typename":"User","aliased":"User","rewritten":"User"}}}` + }, Context{ctx: context.Background()}, `{"data":{"user":{"id":1,"name":"Jannik","__typename":"User","aliased":"User","rewritten":"User"}}}` })) t.Run("__typename with renaming", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2132,7 +2134,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, Context{ - Context: context.Background(), + ctx: context.Background(), RenameTypeNames: []RenameTypeName{ { From: []byte("User"), @@ -2174,7 +2176,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"errors":[{"message":"unable to resolve","locations":[{"line":3,"column":4}],"path":["country"]}],"data":null}` + }, Context{ctx: context.Background()}, `{"errors":[{"message":"unable to resolve","locations":[{"line":3,"column":4}],"path":["country"]}],"data":null}` })) t.Run("fetch with simple error", testFn(true, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) @@ -2207,7 +2209,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"errors":[{"message":"errorMessage"}],"data":{"name":null}}` + }, Context{ctx: context.Background()}, `{"errors":[{"message":"errorMessage"}],"data":{"name":null}}` })) t.Run("nested fetch error for non-nullable field", testFn(true, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) @@ -2254,7 +2256,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"errors":[{"message":"errorMessage"},{"message":"unable to resolve","locations":[{"line":0,"column":0}],"path":["nestedObject"]}],"data":null}` + }, Context{ctx: context.Background()}, `{"errors":[{"message":"errorMessage"},{"message":"unable to resolve","locations":[{"line":0,"column":0}],"path":["nestedObject"]}],"data":null}` })) t.Run("fetch with two Errors", testFn(true, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) @@ -2288,7 +2290,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"errors":[{"message":"errorMessage1"},{"message":"errorMessage2"}],"data":{"name":null}}` + }, Context{ctx: context.Background()}, `{"errors":[{"message":"errorMessage1"},{"message":"errorMessage2"}],"data":{"name":null}}` })) t.Run("not nullable object in nullable field", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2328,7 +2330,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"data":{"nullableField":null}}` + }, Context{ctx: context.Background()}, `{"data":{"nullableField":null}}` })) t.Run("null field should bubble up to parent with error", testFnWithError(false, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2474,7 +2476,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { 
}, }, }, - }, Context{Context: context.Background()}, `invalid value type 'array' for path /data/stringObject/stringField, expecting string, got: [{"id":1},{"id":2},{"id":3}]. You can fix this by configuring this field as Int/Float Scalar` + }, Context{ctx: context.Background()}, `invalid value type 'array' for path /data/stringObject/stringField, expecting string, got: [{"id":1},{"id":2},{"id":3}]. You can fix this by configuring this field as Int/Float Scalar` })) t.Run("empty nullable array should resolve correctly", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2506,7 +2508,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"data":{"nullableArray":[]}}` + }, Context{ctx: context.Background()}, `{"data":{"nullableArray":[]}}` })) t.Run("empty not nullable array should resolve correctly", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2539,7 +2541,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"data":{"notNullableArray":[]}}` + }, Context{ctx: context.Background()}, `{"data":{"notNullableArray":[]}}` })) t.Run("when data null not nullable array should resolve to data null and errors", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2591,7 +2593,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"errors":[{"message":"unable to resolve","locations":[{"line":0,"column":0}]}],"data":null}` + }, Context{ctx: context.Background()}, `{"errors":[{"message":"unable to resolve","locations":[{"line":0,"column":0}]}],"data":null}` })) t.Run("when data null and errors present not nullable array should result to null data upsteam error and resolve error", testFn(false, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2629,7 +2631,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background()}, `{"errors":[{"message":"Could not get a name","locations":[{"line":3,"column":5}],"path":["todos",0,"name"]},{"message":"unable to resolve","locations":[{"line":0,"column":0}]}],"data":null}` + }, Context{ctx: context.Background()}, `{"errors":[{"message":"Could not get a name","locations":[{"line":3,"column":5}],"path":["todos",0,"name"]},{"message":"unable to resolve","locations":[{"line":0,"column":0}]}],"data":null}` })) t.Run("complex GraphQL Server plan", testFn(true, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { serviceOne := NewMockDataSource(ctrl) @@ -2879,7 +2881,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"firstArg":"firstArgValue","thirdArg":123,"secondArg": true, "fourthArg": 12.34}`)}, 
`{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"serviceTwo":{"fieldTwo":"fieldTwoValue","serviceOneResponse":{"fieldOne":"fieldOneValue"}},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"firstArg":"firstArgValue","thirdArg":123,"secondArg": true, "fourthArg": 12.34}`)}, `{"data":{"serviceOne":{"fieldOne":"fieldOneValue"},"serviceTwo":{"fieldTwo":"fieldTwoValue","serviceOneResponse":{"fieldOne":"fieldOneValue"}},"anotherServiceOne":{"fieldOne":"anotherFieldOneValue"},"secondServiceTwo":{"fieldTwo":"secondFieldTwoValue"},"reusingServiceOne":{"fieldOne":"reUsingFieldOneValue"}}}` })) t.Run("federation", testFn(true, false, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { @@ -3069,7 +3071,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","name":"Trilby"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-1","name":"Trilby"}}]}}}` + }, Context{ctx: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","name":"Trilby"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-1","name":"Trilby"}}]}}}` })) t.Run("federation with enabled dataloader", testFn(true, true, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) @@ -3273,7 +3275,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","name":"Trilby"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","name":"Fedora"}}]}}}` + }, Context{ctx: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","name":"Trilby"}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","name":"Fedora"}}]}}}` })) t.Run("federation with null response", testFn(true, true, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) @@ -3495,7 +3497,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, // ... 
`{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"foo","product":{"upc":"top-1","name":"Trilby"}},{"body":"bar","product":{"upc":"top-2","name":"Fedora"}},{"body":"baz","product":null},{"body":"bat","product":null},{"body":"bal","product":{"upc":"top-5","name":"Boater"}},{"body":"ban","product":{"upc":"top-6","name":"Top Hat"}}]}}} - }, Context{Context: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"foo","product":{"upc":"top-1","name":"Trilby"}},{"body":"bar","product":{"upc":"top-2","name":"Fedora"}},{"body":"baz","product":null},{"body":"bat","product":{"upc":"top-4","name":"Boater"}},{"body":"bal","product":{"upc":"top-5","name":"Top Hat"}},{"body":"ban","product":{"upc":"top-6","name":"Bowler"}}]}}}` + }, Context{ctx: context.Background(), Variables: nil}, `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"foo","product":{"upc":"top-1","name":"Trilby"}},{"body":"bar","product":{"upc":"top-2","name":"Fedora"}},{"body":"baz","product":null},{"body":"bat","product":{"upc":"top-4","name":"Boater"}},{"body":"bal","product":{"upc":"top-5","name":"Top Hat"}},{"body":"ban","product":{"upc":"top-6","name":"Bowler"}}]}}}` })) t.Run("federation with enabled dataloader and fetch error ", testFn(true, true, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { @@ -3694,7 +3696,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: nil}, `{"errors":[{"message":"errorMessage"},{"message":"unable to resolve","locations":[{"line":0,"column":0}],"path":["me","reviews","0","product"]},{"message":"unable to resolve","locations":[{"line":0,"column":0}],"path":["me","reviews","1","product"]}],"data":{"me":{"id":"1234","username":"Me","reviews":[null,null]}}}` + }, Context{ctx: context.Background(), Variables: nil}, `{"errors":[{"message":"errorMessage"},{"message":"unable to resolve","locations":[{"line":0,"column":0}],"path":["me","reviews","0","product"]},{"message":"unable to resolve","locations":[{"line":0,"column":0}],"path":["me","reviews","1","product"]}],"data":{"me":{"id":"1234","username":"Me","reviews":[null,null]}}}` })) t.Run("federation with optional variable", testFn(true, true, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { userService := NewMockDataSource(ctrl) @@ -3929,7 +3931,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{Context: context.Background(), Variables: []byte(`{"companyId":"abc123","date":null}`)}, `{"data":{"me":{"employment":{"id":"xyz987","times":[{"id":"t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}}}}` + }, Context{ctx: context.Background(), Variables: []byte(`{"companyId":"abc123","date":null}`)}, `{"data":{"me":{"employment":{"id":"xyz987","times":[{"id":"t1","employee":{"id":"xyz987"},"start":"2022-11-02T08:00:00","end":"2022-11-02T12:00:00"}]}}}}` })) } @@ -3951,7 +3953,7 @@ func TestResolver_WithHeader(t *testing.T) { header := make(http.Header) header.Set(tc.header, "foo") ctx := &Context{ - Context: context.Background(), + ctx: context.Background(), Request: Request{ Header: header, }, @@ -4090,7 +4092,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { resolver, plan, out := setup(c, fakeStream) ctx := Context{ - Context: c, + ctx: c, } err := resolver.ResolveGraphQLSubscription(&ctx, plan, out) @@ -4106,7 
+4108,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { resolver, plan, out := setup(c, nil) ctx := Context{ - Context: c, + ctx: c, } err := resolver.ResolveGraphQLSubscription(&ctx, plan, out) @@ -4126,7 +4128,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { resolver, plan, out := setup(c, fakeStream) ctx := Context{ - Context: c, + ctx: c, } err := resolver.ResolveGraphQLSubscription(&ctx, plan, out) diff --git a/pkg/execution/datasource/datasource_http_polling_stream.go b/pkg/execution/datasource/datasource_http_polling_stream.go index 1b0009a14..4c78f1806 100644 --- a/pkg/execution/datasource/datasource_http_polling_stream.go +++ b/pkg/execution/datasource/datasource_http_polling_stream.go @@ -110,11 +110,27 @@ type HttpPollingStreamDataSource struct { Log log.Logger once sync.Once ch chan []byte - closed bool Delay time.Duration client *http.Client request *http.Request lastData []byte + + // The mutex guards the fields following it. Use the + // accessor methods to read/write them. + mu sync.RWMutex + closed bool +} + +func (s *HttpPollingStreamDataSource) close() { + s.mu.Lock() + defer s.mu.Unlock() + s.closed = true +} + +func (s *HttpPollingStreamDataSource) isClosed() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.closed } func (h *HttpPollingStreamDataSource) Resolve(ctx context.Context, args ResolverArgs, out io.Writer) (n int, err error) { @@ -130,7 +146,7 @@ func (h *HttpPollingStreamDataSource) Resolve(ctx context.Context, args Resolver } go h.startPolling(ctx) }) - if h.closed { + if h.isClosed() { return } select { @@ -145,7 +161,7 @@ func (h *HttpPollingStreamDataSource) Resolve(ctx context.Context, args Resolver ) } case <-ctx.Done(): - h.closed = true + h.close() return } return @@ -162,7 +178,7 @@ func (h *HttpPollingStreamDataSource) startPolling(ctx context.Context) { var data []byte select { case <-ctx.Done(): - h.closed = true + h.close() return default: response, err := h.client.Do(h.request) @@ -186,7 +202,7 @@ func (h *HttpPollingStreamDataSource) startPolling(ctx context.Context) { h.lastData = data select { case <-ctx.Done(): - h.closed = true + h.close() return case h.ch <- data: continue diff --git a/pkg/fastbuffer/fastbuffer.go b/pkg/fastbuffer/fastbuffer.go index c5186b8c3..13675d9c4 100644 --- a/pkg/fastbuffer/fastbuffer.go +++ b/pkg/fastbuffer/fastbuffer.go @@ -40,6 +40,16 @@ func (f *FastBuffer) Len() int { return len(f.b) } +// Grow increases the buffer capacity to be able to hold at least n more bytes +func (f *FastBuffer) Grow(n int) { + required := len(f.b) + n - cap(f.b) + if required > 0 { + b := make([]byte, len(f.b), len(f.b)+n) + copy(b, f.b) + f.b = b + } +} + func (f *FastBuffer) UnsafeString() string { sliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&f.b)) stringHeader := reflect.StringHeader{Data: sliceHeader.Data, Len: sliceHeader.Len} diff --git a/pkg/graphql/execution_engine_v2.go b/pkg/graphql/execution_engine_v2.go index 36e0f0a77..23d4fb750 100644 --- a/pkg/graphql/execution_engine_v2.go +++ b/pkg/graphql/execution_engine_v2.go @@ -127,7 +127,7 @@ func (e *internalExecutionContext) setRequest(request resolve.Request) { } func (e *internalExecutionContext) setContext(ctx context.Context) { - e.resolveContext.Context = ctx + e.resolveContext = e.resolveContext.WithContext(ctx) } func (e *internalExecutionContext) setVariables(variables []byte) {
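The setContext change above is the consumer side of this patch's core refactor: resolve.Context now carries an unexported ctx field (visible in the Context{ctx: ...} literals throughout the tests) instead of an embedded context.Context. A rough sketch of the net/http.Request-style accessor pair this implies; the authoritative definitions live in pkg/engine/resolve/resolve.go and may differ in detail:

// Sketch only: the real resolve.Context carries many more fields.
type Context struct {
	ctx context.Context
}

// Context returns the request-scoped context, mirroring http.Request.Context.
func (c Context) Context() context.Context {
	return c.ctx
}

// WithContext returns a copy of c with its context swapped out, mirroring
// http.Request.WithContext; the receiver itself is left unmodified.
func (c Context) WithContext(ctx context.Context) Context {
	if ctx == nil {
		panic("nil context")
	}
	c.ctx = ctx // c is a value copy at this point
	return c
}

Keeping the field unexported forces callers through WithContext, which is what lets setContext above swap contexts without mutating shared state.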
diff --git a/pkg/graphql/execution_engine_v2_norace_test.go b/pkg/graphql/execution_engine_v2_norace_test.go new file mode 100644 index 000000000..483ab2989 --- /dev/null +++ b/pkg/graphql/execution_engine_v2_norace_test.go @@ -0,0 +1,170 @@ +//go:build !race + +package graphql + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/pkg/testing/federationtesting" + "github.com/wundergraph/graphql-go-tools/pkg/testing/flags" +) + +// This test produces data races in the generated gql code. Disable it when the race +// detector is enabled. +func TestExecutionEngineV2_FederationAndSubscription_IntegrationTest(t *testing.T) { + if flags.IsWindows { + t.Skip("skip on windows - test is timing dependent") + } + + runIntegration := func(t *testing.T, enableDataLoader bool, secondRun bool) { + t.Helper() + ctx, cancelFn := context.WithCancel(context.Background()) + setup := newFederationSetup() + t.Cleanup(func() { + cancelFn() + setup.accountsUpstreamServer.Close() + setup.productsUpstreamServer.Close() + setup.reviewsUpstreamServer.Close() + setup.pollingUpstreamServer.Close() + }) + + engine, schema, err := newFederationEngine(ctx, setup, enableDataLoader) + require.NoError(t, err) + + t.Run("should successfully execute a federation operation", func(t *testing.T) { + gqlRequest := &Request{ + OperationName: "", + Variables: nil, + Query: federationtesting.QueryReviewsOfMe, + } + + validationResult, err := gqlRequest.ValidateForSchema(schema) + require.NoError(t, err) + require.True(t, validationResult.Valid) + + execCtx, execCtxCancelFn := context.WithCancel(context.Background()) + defer execCtxCancelFn() + + resultWriter := NewEngineResultWriter() + err = engine.Execute(execCtx, gqlRequest, &resultWriter) + if assert.NoError(t, err) { + assert.Equal(t, + `{"data":{"me":{"reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","name":"Trilby","price":11}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","name":"Fedora","price":22}}]}}}`, + resultWriter.String(), + ) + } + }) + + t.Run("should successfully execute a federation subscription", func(t *testing.T) { + query := ` +subscription UpdatedPrice { + updatedPrice { + name + price + reviews { + body + author { + id + username + } + } + } +}` + + gqlRequest := &Request{ + OperationName: "", + Variables: nil, + Query: query, + } + + validationResult, err := gqlRequest.ValidateForSchema(schema) + require.NoError(t, err) + require.True(t, validationResult.Valid) + + execCtx, execCtxCancelFn := context.WithCancel(context.Background()) + defer execCtxCancelFn() + + message := make(chan string) + resultWriter := NewEngineResultWriter() + resultWriter.SetFlushCallback(func(data []byte) { + message <- string(data) + }) + + go func() { + err := engine.Execute(execCtx, gqlRequest, &resultWriter) + assert.NoError(t, err) + }() + + if assert.NoError(t, err) { + assert.Eventuallyf(t, func() bool { + msg := `{"data":{"updatedPrice":{"name":"Trilby","price":%d,"reviews":[{"body":"A highly effective form of birth control.","author":{"id":"1234","username":"Me"}}]}}}` + price := 10 + if secondRun { + price += 2 + } + + firstMessage := <-message + expectedFirstMessage := fmt.Sprintf(msg, price) + assert.Equal(t, expectedFirstMessage, firstMessage) + + secondMessage := <-message + expectedSecondMessage := fmt.Sprintf(msg, price+1) + assert.Equal(t, expectedSecondMessage, secondMessage) + return true + }, time.Second, 10*time.Millisecond, "did not receive
expected messages") + } + }) + + /* Uncomment when polling subscriptions are ready: + + t.Run("should successfully subscribe to rest data source", func(t *testing.T) { + gqlRequest := &Request{ + OperationName: "", + Variables: nil, + Query: "subscription Counter { counter }", + } + + validationResult, err := gqlRequest.ValidateForSchema(setup.schema) + require.NoError(t, err) + require.True(t, validationResult.Valid) + + execCtx, execCtxCancelFn := context.WithCancel(context.Background()) + defer execCtxCancelFn() + + message := make(chan string) + resultWriter := NewEngineResultWriter() + resultWriter.SetFlushCallback(func(data []byte) { + fmt.Println(string(data)) + message <- string(data) + }) + + err = setup.engine.Execute(execCtx, gqlRequest, &resultWriter) + assert.NoError(t, err) + + if assert.NoError(t, err) { + assert.Eventuallyf(t, func() bool { + firstMessage := <-message + assert.Equal(t, `{"data":{"counter":1}}`, firstMessage) + secondMessage := <-message + assert.Equal(t, `{"data":{"counter":2}}`, secondMessage) + return true + }, time.Second, 10*time.Millisecond, "did not receive expected messages") + } + }) + */ + + } + + t.Run("federation", func(t *testing.T) { + runIntegration(t, false, false) + }) + + t.Run("federation with data loader enabled", func(t *testing.T) { + runIntegration(t, true, true) + }) +} diff --git a/pkg/graphql/execution_engine_v2_test.go b/pkg/graphql/execution_engine_v2_test.go index 67f1846a4..e8251d347 100644 --- a/pkg/graphql/execution_engine_v2_test.go +++ b/pkg/graphql/execution_engine_v2_test.go @@ -10,7 +10,6 @@ import ( "net/http/httptest" "sync" "testing" - "time" "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" @@ -28,7 +27,6 @@ import ( accounts "github.com/wundergraph/graphql-go-tools/pkg/testing/federationtesting/accounts/graph" products "github.com/wundergraph/graphql-go-tools/pkg/testing/federationtesting/products/graph" reviews "github.com/wundergraph/graphql-go-tools/pkg/testing/federationtesting/reviews/graph" - "github.com/wundergraph/graphql-go-tools/pkg/testing/flags" ) func TestEngineResponseWriter_AsHTTPResponse(t *testing.T) { @@ -1567,159 +1565,6 @@ func TestExecutionEngineV2_Execute(t *testing.T) { )) } -func TestExecutionEngineV2_FederationAndSubscription_IntegrationTest(t *testing.T) { - if flags.IsWindows { - t.Skip("skip on windows - test is timing dependendent") - } - - runIntegration := func(t *testing.T, enableDataLoader bool, secondRun bool) { - t.Helper() - ctx, cancelFn := context.WithCancel(context.Background()) - setup := newFederationSetup() - t.Cleanup(func() { - cancelFn() - setup.accountsUpstreamServer.Close() - setup.productsUpstreamServer.Close() - setup.reviewsUpstreamServer.Close() - setup.pollingUpstreamServer.Close() - }) - - engine, schema, err := newFederationEngine(ctx, setup, enableDataLoader) - require.NoError(t, err) - - t.Run("should successfully execute a federation operation", func(t *testing.T) { - gqlRequest := &Request{ - OperationName: "", - Variables: nil, - Query: federationtesting.QueryReviewsOfMe, - } - - validationResult, err := gqlRequest.ValidateForSchema(schema) - require.NoError(t, err) - require.True(t, validationResult.Valid) - - execCtx, execCtxCancelFn := context.WithCancel(context.Background()) - defer execCtxCancelFn() - - resultWriter := NewEngineResultWriter() - err = engine.Execute(execCtx, gqlRequest, &resultWriter) - if assert.NoError(t, err) { - assert.Equal(t, - `{"data":{"me":{"reviews":[{"body":"A highly effective form of birth 
control.","product":{"upc":"top-1","name":"Trilby","price":11}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","name":"Fedora","price":22}}]}}}`, - resultWriter.String(), - ) - } - }) - - t.Run("should successfully execute a federation subscription", func(t *testing.T) { - query := ` -subscription UpdatedPrice { - updatedPrice { - name - price - reviews { - body - author { - id - username - } - } - } -}` - - gqlRequest := &Request{ - OperationName: "", - Variables: nil, - Query: query, - } - - validationResult, err := gqlRequest.ValidateForSchema(schema) - require.NoError(t, err) - require.True(t, validationResult.Valid) - - execCtx, execCtxCancelFn := context.WithCancel(context.Background()) - defer execCtxCancelFn() - - message := make(chan string) - resultWriter := NewEngineResultWriter() - resultWriter.SetFlushCallback(func(data []byte) { - message <- string(data) - }) - - go func() { - err := engine.Execute(execCtx, gqlRequest, &resultWriter) - assert.NoError(t, err) - }() - - if assert.NoError(t, err) { - assert.Eventuallyf(t, func() bool { - msg := `{"data":{"updatedPrice":{"name":"Trilby","price":%d,"reviews":[{"body":"A highly effective form of birth control.","author":{"id":"1234","username":"Me"}}]}}}` - price := 10 - if secondRun { - price += 2 - } - - firstMessage := <-message - expectedFirstMessage := fmt.Sprintf(msg, price) - assert.Equal(t, expectedFirstMessage, firstMessage) - - secondMessage := <-message - expectedSecondMessage := fmt.Sprintf(msg, price+1) - assert.Equal(t, expectedSecondMessage, secondMessage) - return true - }, time.Second, 10*time.Millisecond, "did not receive expected messages") - } - }) - - /* Uncomment when polling subscriptions are ready: - - t.Run("should successfully subscribe to rest data source", func(t *testing.T) { - gqlRequest := &Request{ - OperationName: "", - Variables: nil, - Query: "subscription Counter { counter }", - } - - validationResult, err := gqlRequest.ValidateForSchema(setup.schema) - require.NoError(t, err) - require.True(t, validationResult.Valid) - - execCtx, execCtxCancelFn := context.WithCancel(context.Background()) - defer execCtxCancelFn() - - message := make(chan string) - resultWriter := NewEngineResultWriter() - resultWriter.SetFlushCallback(func(data []byte) { - fmt.Println(string(data)) - message <- string(data) - }) - - err = setup.engine.Execute(execCtx, gqlRequest, &resultWriter) - assert.NoError(t, err) - - if assert.NoError(t, err) { - assert.Eventuallyf(t, func() bool { - firstMessage := <-message - assert.Equal(t, `{"data":{"counter":1}}`, firstMessage) - secondMessage := <-message - assert.Equal(t, `{"data":{"counter":2}}`, secondMessage) - return true - }, time.Second, 10*time.Millisecond, "did not receive expected messages") - } - }) - */ - - } - - t.Run("federation", func(t *testing.T) { - runIntegration(t, false, false) - }) - - t.Run("federation with data loader enabled", func(t *testing.T) { - runIntegration(t, true, true) - }) -} - func testNetHttpClient(t *testing.T, testCase roundTripperTestCase) *http.Client { defaultClient := httpclient.DefaultNetHttpClient return &http.Client{ diff --git a/pkg/subscription/context.go b/pkg/subscription/context.go index e61b33e6c..b71a77b5c 100644 --- a/pkg/subscription/context.go +++ b/pkg/subscription/context.go @@ -3,6 +3,7 @@ package subscription import ( "context" "net/http" + "sync" ) type InitialHttpRequestContext struct { @@ -17,27 +18,47 @@ func 
NewInitialHttpRequestContext(r *http.Request) *InitialHttpRequestContext { } } -type subscriptionCancellations map[string]context.CancelFunc +type subscriptionCancellations struct { + mu sync.RWMutex + cancellations map[string]context.CancelFunc +} -func (sc subscriptionCancellations) AddWithParent(id string, parent context.Context) context.Context { +func (sc *subscriptionCancellations) AddWithParent(id string, parent context.Context) context.Context { ctx, cancelFunc := context.WithCancel(parent) - sc[id] = cancelFunc + sc.mu.Lock() + defer sc.mu.Unlock() + if sc.cancellations == nil { + sc.cancellations = make(map[string]context.CancelFunc) + } + sc.cancellations[id] = cancelFunc return ctx } -func (sc subscriptionCancellations) Cancel(id string) (ok bool) { - cancelFunc, ok := sc[id] +func (sc *subscriptionCancellations) Cancel(id string) (ok bool) { + sc.mu.Lock() + defer sc.mu.Unlock() + cancelFunc, ok := sc.cancellations[id] if !ok { return false } cancelFunc() - delete(sc, id) + delete(sc.cancellations, id) return true } -func (sc subscriptionCancellations) CancelAll() { - for _, cancelFunc := range sc { +func (sc *subscriptionCancellations) CancelAll() { + // We have full control over the cancellation functions (see AddWithParent()), so + // it's fine to invoke them with the lock held + sc.mu.RLock() + defer sc.mu.RUnlock() + for _, cancelFunc := range sc.cancellations { cancelFunc() } } + +func (sc *subscriptionCancellations) Len() int { + sc.mu.RLock() + defer sc.mu.RUnlock() + return len(sc.cancellations) +} diff --git a/pkg/subscription/context_test.go b/pkg/subscription/context_test.go index d49d32ea7..8858533b3 100644 --- a/pkg/subscription/context_test.go +++ b/pkg/subscription/context_test.go @@ -28,15 +28,15 @@ func TestSubscriptionCancellations(t *testing.T) { var ctx context.Context t.Run("should add a cancellation func to map", func(t *testing.T) { - require.Equal(t, 0, len(cancellations)) + require.Equal(t, 0, cancellations.Len()) ctx = cancellations.AddWithParent("1", context.Background()) - assert.Equal(t, 1, len(cancellations)) + assert.Equal(t, 1, cancellations.Len()) assert.NotNil(t, ctx) }) t.Run("should execute cancellation from map", func(t *testing.T) { - require.Equal(t, 1, len(cancellations)) + require.Equal(t, 1, cancellations.Len()) ctxTestFunc := func() bool { <-ctx.Done() return true } ok := cancellations.Cancel("1") assert.Eventually(t, ctxTestFunc, time.Second, 5*time.Millisecond) assert.True(t, ok) - assert.Equal(t, 0, len(cancellations)) + assert.Equal(t, 0, cancellations.Len()) }) } diff --git a/pkg/subscription/handler.go b/pkg/subscription/handler.go index 7ceaa88b5..12491d27f 100644 --- a/pkg/subscription/handler.go +++ b/pkg/subscription/handler.go @@ -76,7 +76,7 @@ type Handler struct { keepAliveInterval time.Duration // subscriptionUpdateInterval is the actual interval on which the server sends subscription updates to the client. subscriptionUpdateInterval time.Duration - // subCancellations is map containing the cancellation functions to every active subscription. + // subCancellations stores a map containing the cancellation functions for every active subscription. subCancellations subscriptionCancellations // executorPool is responsible to create and hold executors. executorPool ExecutorPool @@ -126,9 +126,7 @@ func NewHandler(logger abstractlogger.Logger, client Client, executorPool Execut
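Because AddWithParent now allocates the map lazily under the lock, the zero value of subscriptionCancellations stays ready to use and all operations are safe for concurrent callers. A sketch of the kind of test the race detector can now pass; the test name and iteration count are illustrative, not part of this patch:

func TestSubscriptionCancellationsConcurrentSketch(t *testing.T) {
	var sc subscriptionCancellations // zero value, no constructor needed
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			id := strconv.Itoa(i)
			sc.AddWithParent(id, context.Background())
			sc.Cancel(id) // concurrent Add/Cancel pairs must not trip the race detector
		}(i)
	}
	wg.Wait()
	if sc.Len() != 0 {
		t.Fatalf("expected no active cancellations, got %d", sc.Len())
	}
}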
// Handle will handle the subscription connection. func (h *Handler) Handle(ctx context.Context) { - defer func() { - h.subCancellations.CancelAll() - }() + defer h.subCancellations.CancelAll() for { if !h.client.IsConnected() { @@ -508,5 +506,5 @@ func (h *Handler) handleError(id string, errors graphql.RequestErrors) { // ActiveSubscriptions will return the actual number of active subscriptions for that client. func (h *Handler) ActiveSubscriptions() int { - return len(h.subCancellations) + return h.subCancellations.Len() } diff --git a/pkg/subscription/handler_test.go b/pkg/subscription/handler_test.go index ea4d5367d..dbb28edbb 100644 --- a/pkg/subscription/handler_test.go +++ b/pkg/subscription/handler_test.go @@ -583,7 +583,7 @@ func TestHandler_Handle(t *testing.T) { handlerRoutineFunc := handlerRoutine(ctx) go handlerRoutineFunc() - time.Sleep(10 * time.Millisecond) + time.Sleep(50 * time.Millisecond) defer cancelFunc() go sendChatMutation(t, chatServer.URL) diff --git a/pkg/subscription/mock_client_test.go b/pkg/subscription/mock_client_test.go index 41c566a60..d1ccb7f1c 100644 --- a/pkg/subscription/mock_client_test.go +++ b/pkg/subscription/mock_client_test.go @@ -2,9 +2,11 @@ package subscription import ( "errors" + "sync" ) type mockClient struct { + mu sync.Mutex messagesFromServer []Message messageToServer *Message err error @@ -21,18 +23,24 @@ func newMockClient() *mockClient { } func (c *mockClient) ReadFromClient() (*Message, error) { + c.mu.Lock() returnErr := c.err + c.mu.Unlock() returnMessage := <-c.messagePipe if returnErr != nil { return nil, returnErr } + c.mu.Lock() + defer c.mu.Unlock() c.serverHasRead = true c.err = nil return returnMessage, returnErr } func (c *mockClient) WriteToClient(message Message) error { + c.mu.Lock() + defer c.mu.Unlock() c.messagesFromServer = append(c.messagesFromServer, message) return c.err } @@ -47,11 +55,15 @@ func (c *mockClient) Disconnect() error { } func (c *mockClient) hasMoreMessagesThan(num int) bool { + c.mu.Lock() + defer c.mu.Unlock() return len(c.messagesFromServer) > num } func (c *mockClient) readFromServer() []Message { - return c.messagesFromServer + c.mu.Lock() + defer c.mu.Unlock() + return c.messagesFromServer[0:len(c.messagesFromServer):len(c.messagesFromServer)] } func (c *mockClient) prepareConnectionInitMessage() *mockClient { @@ -106,11 +118,15 @@ func (c *mockClient) send() bool { } func (c *mockClient) withoutError() *mockClient { + c.mu.Lock() + defer c.mu.Unlock() c.err = nil return c } func (c *mockClient) withError() *mockClient { + c.mu.Lock() + defer c.mu.Unlock() c.err = errors.New("error") return c } @@ -120,6 +136,8 @@ func (c *mockClient) and() *mockClient { } func (c *mockClient) reset() *mockClient { + c.mu.Lock() + defer c.mu.Unlock() c.messagesFromServer = []Message{} c.err = nil return c diff --git a/pkg/testing/federationtesting/federation_intergation_test.go b/pkg/testing/federationtesting/federation_intergation_test.go index 5cfd686fa..73156f5e4 100644 --- a/pkg/testing/federationtesting/federation_intergation_test.go +++ b/pkg/testing/federationtesting/federation_intergation_test.go @@ -1,3 +1,5 @@ +//go:build !race + package federationtesting import ( @@ -62,6 +64,8 @@ func (f *federationSetup) close() { f.gatewayServer.Close() } +// This test produces data races in the generated gql code. Disable it when the race +// detector is enabled. func TestFederationIntegrationTest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel()
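One detail worth noting in the mockClient changes above: readFromServer returns the messages through a full slice expression, a[low:high:max], so the returned slice's capacity equals its length and a subsequent append by the caller reallocates instead of writing into the backing array the mock is still appending to. A standalone illustration of that behavior:

package main

import "fmt"

func main() {
	shared := []int{1, 2, 3}

	view := shared[0:len(shared):len(shared)] // cap(view) == len(view) == 3
	view = append(view, 4)                    // reallocates; shared's backing array is untouched

	fmt.Println(shared) // [1 2 3]
	fmt.Println(view)   // [1 2 3 4]
}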