diff --git a/pkg/graphql/execution_engine_v2.go b/pkg/graphql/execution_engine_v2.go index 39fec39a1..286791637 100644 --- a/pkg/graphql/execution_engine_v2.go +++ b/pkg/graphql/execution_engine_v2.go @@ -21,8 +21,9 @@ import ( ) type EngineV2Configuration struct { - schema *Schema - plannerConfig plan.Configuration + schema *Schema + plannerConfig plan.Configuration + websocketBeforeStartHook WebsocketBeforeStartHook } func NewEngineV2Configuration(schema *Schema) EngineV2Configuration { @@ -52,6 +53,11 @@ func (e *EngineV2Configuration) SetFieldConfigurations(fieldConfigs plan.FieldCo e.plannerConfig.Fields = fieldConfigs } +// SetWebsocketBeforeStartHook - sets before start hook which will be called before processing any operation sent over websockets +func (e *EngineV2Configuration) SetWebsocketBeforeStartHook(hook WebsocketBeforeStartHook) { + e.websocketBeforeStartHook = hook +} + type EngineResultWriter struct { buf *bytes.Buffer flushCallback func(data []byte) @@ -157,6 +163,10 @@ type ExecutionEngineV2 struct { executionPlanCache *lru.Cache } +type WebsocketBeforeStartHook interface { + OnBeforeStart(reqCtx context.Context, operation *Request) error +} + type ExecutionOptionsV2 func(ctx *internalExecutionContext) func WithBeforeFetchHook(hook resolve.BeforeFetchHook) ExecutionOptionsV2 { @@ -268,6 +278,10 @@ func (e *ExecutionEngineV2) getCachedPlan(ctx *internalExecutionContext, operati return p } +func (e *ExecutionEngineV2) GetWebsocketBeforeStartHook() WebsocketBeforeStartHook { + return e.config.websocketBeforeStartHook +} + func (e *ExecutionEngineV2) getExecutionCtx() *internalExecutionContext { return e.internalExecutionContextPool.Get().(*internalExecutionContext) } diff --git a/pkg/subscription/executor_v2.go b/pkg/subscription/executor_v2.go index 81d977efa..ae974f82f 100644 --- a/pkg/subscription/executor_v2.go +++ b/pkg/subscription/executor_v2.go @@ -10,12 +10,14 @@ import ( "github.com/jensneuse/graphql-go-tools/pkg/graphql" ) +// ExecutorV2Pool - provides reusable executors type ExecutorV2Pool struct { - engine *graphql.ExecutionEngineV2 - executorPool *sync.Pool + engine *graphql.ExecutionEngineV2 + executorPool *sync.Pool + connectionInitReqCtx context.Context // connectionInitReqCtx - holds original request context used to establish websocket connection } -func NewExecutorV2Pool(engine *graphql.ExecutionEngineV2) *ExecutorV2Pool { +func NewExecutorV2Pool(engine *graphql.ExecutionEngineV2, connectionInitReqCtx context.Context) *ExecutorV2Pool { return &ExecutorV2Pool{ engine: engine, executorPool: &sync.Pool{ @@ -23,6 +25,7 @@ func NewExecutorV2Pool(engine *graphql.ExecutionEngineV2) *ExecutorV2Pool { return &ExecutorV2{} }, }, + connectionInitReqCtx: connectionInitReqCtx, } } @@ -37,6 +40,7 @@ func (e *ExecutorV2Pool) Get(payload []byte) (Executor, error) { engine: e.engine, operation: &operation, context: context.Background(), + reqCtx: e.connectionInitReqCtx, }, nil } @@ -50,6 +54,7 @@ type ExecutorV2 struct { engine *graphql.ExecutionEngineV2 operation *graphql.Request context context.Context + reqCtx context.Context } func (e *ExecutorV2) Execute(writer resolve.FlushWriter) error { @@ -73,4 +78,5 @@ func (e *ExecutorV2) Reset() { e.engine = nil e.operation = nil e.context = context.Background() + e.reqCtx = context.TODO() } diff --git a/pkg/subscription/handler.go b/pkg/subscription/handler.go index 80a847903..219bb95fa 100644 --- a/pkg/subscription/handler.go +++ b/pkg/subscription/handler.go @@ -191,6 +191,11 @@ func (h *Handler) handleStart(id string, payload []byte) { return } + if err = h.handleOnBeforeStart(executor); err != nil { + h.handleError(id, graphql.RequestErrorsFromError(err)) + return + } + if executor.OperationType() == ast.OperationTypeSubscription { ctx := h.subCancellations.Add(id) go h.startSubscription(ctx, id, executor) @@ -200,6 +205,19 @@ func (h *Handler) handleStart(id string, payload []byte) { go h.handleNonSubscriptionOperation(id, executor) } +func (h *Handler) handleOnBeforeStart(executor Executor) error { + switch e := executor.(type) { + case *ExecutorV2: + if hook := e.engine.GetWebsocketBeforeStartHook(); hook != nil { + return hook.OnBeforeStart(e.reqCtx, e.operation) + } + case *ExecutorV1: + // do nothing + } + + return nil +} + // handleNonSubscriptionOperation will handle a non-subscription operation like a query or a mutation. func (h *Handler) handleNonSubscriptionOperation(id string, executor Executor) { defer func() { @@ -427,4 +445,4 @@ func (h *Handler) handleError(id string, errors graphql.RequestErrors) { // ActiveSubscriptions will return the actual number of active subscriptions for that client. func (h *Handler) ActiveSubscriptions() int { return len(h.subCancellations) -} \ No newline at end of file +} diff --git a/pkg/subscription/handler_test.go b/pkg/subscription/handler_test.go index 5b939dd77..eef78318d 100644 --- a/pkg/subscription/handler_test.go +++ b/pkg/subscription/handler_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "testing" @@ -23,6 +24,20 @@ import ( type handlerRoutine func(ctx context.Context) func() bool +type websocketHook struct { + called bool + reqCtx context.Context + hook func(reqCtx context.Context, operation *graphql.Request) error +} + +func (w *websocketHook) OnBeforeStart(reqCtx context.Context, operation *graphql.Request) error { + w.called = true + if w.hook != nil { + return w.hook(reqCtx, operation) + } + return nil +} + func TestHandler_Handle(t *testing.T) { starwars.SetRelativePathToStarWarsPackage("../starwars") @@ -280,72 +295,8 @@ func TestHandler_Handle(t *testing.T) { chatServer := httptest.NewServer(chat.GraphQLEndpointHandler()) defer chatServer.Close() - chatSchemaBytes, err := chat.LoadSchemaFromExamplesDirectoryWithinPkg() - require.NoError(t, err) - - chatSchema, err := graphql.NewSchemaFromReader(bytes.NewBuffer(chatSchemaBytes)) - require.NoError(t, err) - - engineConf := graphql.NewEngineV2Configuration(chatSchema) - engineConf.SetDataSources([]plan.DataSourceConfiguration{ - { - RootNodes: []plan.TypeField{ - {TypeName: "Mutation", FieldNames: []string{"post"}}, - {TypeName: "Subscription", FieldNames: []string{"messageAdded"}}, - }, - ChildNodes: []plan.TypeField{ - {TypeName: "Message", FieldNames: []string{"text", "createdBy"}}, - }, - Factory: &graphql_datasource.Factory{ - HTTPClient: httpclient.DefaultNetHttpClient, - }, - Custom: graphql_datasource.ConfigJson(graphql_datasource.Configuration{ - Fetch: graphql_datasource.FetchConfiguration{ - URL: chatServer.URL, - Method: http.MethodPost, - Header: nil, - }, - Subscription: graphql_datasource.SubscriptionConfiguration{ - URL: chatServer.URL, - }, - }), - }, - }) - engineConf.SetFieldConfigurations([]plan.FieldConfiguration{ - { - TypeName: "Mutation", - FieldName: "post", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "roomName", - SourceType: plan.FieldArgumentSource, - }, - { - Name: "username", - SourceType: plan.FieldArgumentSource, - }, - { - Name: "text", - SourceType: plan.FieldArgumentSource, - }, - }, - }, - { - TypeName: "Subscription", - FieldName: "messageAdded", - Arguments: []plan.ArgumentConfiguration{ - { - Name: "roomName", - SourceType: plan.FieldArgumentSource, - }, - }, - }, - }) - engine, err := graphql.NewExecutionEngineV2(ctx, abstractlogger.NoopLogger, engineConf) - require.NoError(t, err) - - executorPool := NewExecutorV2Pool(engine) t.Run("connection_init", func(t *testing.T) { + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) t.Run("should send connection error message when error on read occurrs", func(t *testing.T) { @@ -383,6 +334,7 @@ func TestHandler_Handle(t *testing.T) { }) t.Run("connection_keep_alive", func(t *testing.T) { + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) t.Run("should successfully send keep alive messages after connection_init", func(t *testing.T) { @@ -417,6 +369,7 @@ func TestHandler_Handle(t *testing.T) { }) t.Run("erroneous operation(s)", func(t *testing.T) { + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) ctx, cancelFunc := context.WithCancel(context.Background()) handlerRoutineFunc := handlerRoutine(ctx) @@ -445,9 +398,11 @@ func TestHandler_Handle(t *testing.T) { }) t.Run("non-subscription query", func(t *testing.T) { - subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + executorPool, hookHolder := setupEngineV2(t, ctx, chatServer.URL) t.Run("should process query and return error when query is not valid", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + payload, err := chat.GraphQLRequestForOperation(chat.InvalidOperation) require.NoError(t, err) client.prepareStartMessage("1", payload).withoutError().and().send() @@ -474,8 +429,20 @@ func TestHandler_Handle(t *testing.T) { }) t.Run("should process and send result for a query", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + payload, err := chat.GraphQLRequestForOperation(chat.MutationSendMessage) require.NoError(t, err) + + hookHolder.hook = func(ctx context.Context, operation *graphql.Request) error { + assert.Equal(t, hookHolder.reqCtx, ctx) + assert.Contains(t, operation.Query, "mutation SendMessage") + return nil + } + defer func() { + hookHolder.hook = nil + }() + client.prepareStartMessage("1", payload).withoutError().and().send() ctx, cancelFunc := context.WithCancel(context.Background()) @@ -504,10 +471,56 @@ func TestHandler_Handle(t *testing.T) { assert.Contains(t, messagesFromServer, expectedDataMessage) assert.Contains(t, messagesFromServer, expectedCompleteMessage) assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) + assert.True(t, hookHolder.called) + }) + + t.Run("should process and send error message from hook for a query", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + + payload, err := chat.GraphQLRequestForOperation(chat.MutationSendMessage) + require.NoError(t, err) + + errMsg := "error_on_operation" + hookHolder.hook = func(ctx context.Context, operation *graphql.Request) error { + return errors.New(errMsg) + } + defer func() { + hookHolder.hook = nil + }() + + client.prepareStartMessage("1", payload).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + cancelFunc() + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + waitForClientHavingTwoMessages := func() bool { + return client.hasMoreMessagesThan(0) + } + require.Eventually(t, waitForClientHavingTwoMessages, 5*time.Second, 5*time.Millisecond) + + jsonErrMessage, err := json.Marshal(graphql.RequestErrors{ + {Message: errMsg}, + }) + require.NoError(t, err) + expectedErrMessage := Message{ + Id: "1", + Type: MessageTypeError, + Payload: jsonErrMessage, + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedErrMessage) + assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) + assert.True(t, hookHolder.called) }) + }) t.Run("subscription query", func(t *testing.T) { + executorPool, hookHolder := setupEngineV2(t, ctx, chatServer.URL) + t.Run("should start subscription on start", func(t *testing.T) { subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) payload, err := chat.GraphQLRequestForOperation(chat.SubscriptionLiveMessages) @@ -593,9 +606,52 @@ func TestHandler_Handle(t *testing.T) { cancelFunc() }) + + t.Run("should interrupt subscription on start and return error message from hook", func(t *testing.T) { + subscriptionHandler, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) + + payload, err := chat.GraphQLRequestForOperation(chat.SubscriptionLiveMessages) + require.NoError(t, err) + + errMsg := "sub_interrupted" + hookHolder.hook = func(ctx context.Context, operation *graphql.Request) error { + return errors.New(errMsg) + } + + client.prepareStartMessage("1", payload).withoutError().and().send() + + ctx, cancelFunc := context.WithCancel(context.Background()) + handlerRoutineFunc := handlerRoutine(ctx) + go handlerRoutineFunc() + + time.Sleep(10 * time.Millisecond) + cancelFunc() + + go sendChatMutation(t, chatServer.URL) + + require.Eventually(t, func() bool { + return client.hasMoreMessagesThan(0) + }, 1*time.Second, 10*time.Millisecond) + + jsonErrMessage, err := json.Marshal(graphql.RequestErrors{ + {Message: errMsg}, + }) + require.NoError(t, err) + expectedErrMessage := Message{ + Id: "1", + Type: MessageTypeError, + Payload: jsonErrMessage, + } + + messagesFromServer := client.readFromServer() + assert.Contains(t, messagesFromServer, expectedErrMessage) + assert.Equal(t, 0, subscriptionHandler.ActiveSubscriptions()) + assert.True(t, hookHolder.called) + }) }) t.Run("connection_terminate", func(t *testing.T) { + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) t.Run("should successfully disconnect from client", func(t *testing.T) { @@ -612,6 +668,7 @@ func TestHandler_Handle(t *testing.T) { }) t.Run("client is disconnected", func(t *testing.T) { + executorPool, _ := setupEngineV2(t, ctx, chatServer.URL) _, client, handlerRoutine := setupSubscriptionHandlerTest(t, executorPool) t.Run("server should not read from client and stop handler", func(t *testing.T) { @@ -632,6 +689,82 @@ func TestHandler_Handle(t *testing.T) { } +func setupEngineV2(t *testing.T, ctx context.Context, chatServerURL string) (*ExecutorV2Pool, *websocketHook) { + chatSchemaBytes, err := chat.LoadSchemaFromExamplesDirectoryWithinPkg() + require.NoError(t, err) + + chatSchema, err := graphql.NewSchemaFromReader(bytes.NewBuffer(chatSchemaBytes)) + require.NoError(t, err) + + engineConf := graphql.NewEngineV2Configuration(chatSchema) + engineConf.SetDataSources([]plan.DataSourceConfiguration{ + { + RootNodes: []plan.TypeField{ + {TypeName: "Mutation", FieldNames: []string{"post"}}, + {TypeName: "Subscription", FieldNames: []string{"messageAdded"}}, + }, + ChildNodes: []plan.TypeField{ + {TypeName: "Message", FieldNames: []string{"text", "createdBy"}}, + }, + Factory: &graphql_datasource.Factory{ + HTTPClient: httpclient.DefaultNetHttpClient, + }, + Custom: graphql_datasource.ConfigJson(graphql_datasource.Configuration{ + Fetch: graphql_datasource.FetchConfiguration{ + URL: chatServerURL, + Method: http.MethodPost, + Header: nil, + }, + Subscription: graphql_datasource.SubscriptionConfiguration{ + URL: chatServerURL, + }, + }), + }, + }) + engineConf.SetFieldConfigurations([]plan.FieldConfiguration{ + { + TypeName: "Mutation", + FieldName: "post", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "roomName", + SourceType: plan.FieldArgumentSource, + }, + { + Name: "username", + SourceType: plan.FieldArgumentSource, + }, + { + Name: "text", + SourceType: plan.FieldArgumentSource, + }, + }, + }, + { + TypeName: "Subscription", + FieldName: "messageAdded", + Arguments: []plan.ArgumentConfiguration{ + { + Name: "roomName", + SourceType: plan.FieldArgumentSource, + }, + }, + }, + }) + + hookHolder := &websocketHook{ + reqCtx: context.Background(), + } + engineConf.SetWebsocketBeforeStartHook(hookHolder) + + engine, err := graphql.NewExecutionEngineV2(ctx, abstractlogger.NoopLogger, engineConf) + require.NoError(t, err) + + executorPool := NewExecutorV2Pool(engine, hookHolder.reqCtx) + + return executorPool, hookHolder +} + func setupSubscriptionHandlerTest(t *testing.T, executorPool ExecutorPool) (subscriptionHandler *Handler, client *mockClient, routine handlerRoutine) { client = newMockClient()