diff --git a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index f3432a9a6..ed729f981 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -15,15 +15,13 @@ import ( "nhooyr.io/websocket" ) -var ( - connectionInitMessage = []byte(`{"type":"connection_init"}`) -) - const ( - startMessage = `{"type":"start","id":"%s","payload":%s}` - stopMessage = `{"type":"stop","id":"%s"}` - internalError = `{"errors":[{"message":"connection error"}]}` - connectionError = `{"errors":[{"message":"connection error"}]}` + initMessageWithPayload = `{"type":"connection_init", "payload":%s}` + initMessageNoPayload = `{"type":"connection_init"}` + startMessage = `{"type":"start","id":"%s","payload":%s}` + stopMessage = `{"type":"stop","id":"%s"}` + internalError = `{"errors":[{"message":"connection error"}]}` + connectionError = `{"errors":[{"message":"connection error"}]}` ) // WebSocketGraphQLSubscriptionClient is a WebSocket client that allows running multiple subscriptions via the same WebSocket Connection @@ -86,7 +84,6 @@ func NewWebSocketGraphQLSubscriptionClient(httpClient *http.Client, ctx context. // If an existing WS with the same ID (Hash) exists, it is being re-used // If no connection exists, the client initiates a new one and sends the "init" and "connection ack" messages func (c *WebSocketGraphQLSubscriptionClient) Subscribe(ctx context.Context, options GraphQLSubscriptionOptions, next chan<- []byte) error { - handlerID, err := c.generateHandlerIDHash(options) if err != nil { return err @@ -112,6 +109,12 @@ func (c *WebSocketGraphQLSubscriptionClient) Subscribe(ctx context.Context, opti if options.Header == nil { options.Header = http.Header{} } + + initMessage, err := connectionInitMessage(options.Header) + if err != nil { + return err + } + options.Header.Set("Sec-WebSocket-Protocol", "graphql-ws") options.Header.Set("Sec-WebSocket-Version", "13") @@ -128,7 +131,7 @@ func (c *WebSocketGraphQLSubscriptionClient) Subscribe(ctx context.Context, opti return fmt.Errorf("upgrade unsuccessful") } // init + ack - err = conn.Write(ctx, websocket.MessageText, connectionInitMessage) + err = conn.Write(ctx, websocket.MessageText, []byte(initMessage)) if err != nil { return err } @@ -181,6 +184,21 @@ func (c *WebSocketGraphQLSubscriptionClient) generateHandlerIDHash(options Graph return xxh.Sum64(), nil } +func connectionInitMessage(header http.Header) (string, error) { + if len(header) == 0 { + return initMessageNoPayload, nil + } + payload := make(map[string]string, len(header)) + for name := range header { + payload[name] = header.Get(name) + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + return "", err + } + return fmt.Sprintf(initMessageWithPayload, payloadBytes), nil +} + func newConnectionHandler(ctx context.Context, conn *websocket.Conn, readTimeout time.Duration, log abstractlogger.Logger) *connectionHandler { return &connectionHandler{ conn: conn, diff --git a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 4ed9b800e..2ec1eeec7 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -420,3 +420,92 @@ func TestWebsocketSubscriptionClientDeDuplication(t *testing.T) { return connectedClients.Load() == 0 }, time.Second, time.Millisecond, "clients not 0") } + +func TestWebsocketSubscriptionClientWithInitPayload(t *testing.T) { + assertInitAck := func(ctx context.Context, conn *websocket.Conn) { + msgType, data, err := conn.Read(ctx) + assert.NoError(t, err) + assert.Equal(t, websocket.MessageText, msgType) + assert.Equal(t, `{"type":"connection_init", "payload":{"Authorization":"Bearer XXX"}}`, string(data)) + err = conn.Write(ctx, websocket.MessageText, []byte(`{"type":"connection_ack"}`)) + assert.NoError(t, err) + } + + handshakeHappened := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + assert.NoError(t, err) + assertInitAck(r.Context(), conn) + handshakeHappened = true + <-conn.CloseRead(r.Context()).Done() + })) + defer server.Close() + + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + client := NewWebSocketGraphQLSubscriptionClient(http.DefaultClient, clientCtx, + WithReadTimeout(time.Millisecond), + WithLogger(logger()), + ) + + subscribeHeaders := http.Header{} + subscribeHeaders.Add("Authorization", "Bearer XXX") + next := make(chan []byte) + subCtx, subCancel := context.WithCancel(context.Background()) + defer subCancel() + err := client.Subscribe(subCtx, GraphQLSubscriptionOptions{ + URL: strings.Replace(server.URL, "http", "ws", -1), + Body: GraphQLBody{ + Query: `subscription {messageAdded(roomName: "room"){text}}`, + }, + Header: subscribeHeaders, + }, next) + assert.NoError(t, err) + + assert.Len(t, client.handlers, 1, "handler not registered") + assert.Eventuallyf(t, func() bool { + return handshakeHappened + }, time.Second, time.Millisecond, "handshake was not performed") +} + +func TestConnectionInitMessage(t *testing.T) { + for i, tc := range []struct { + header http.Header + expectedMessage string + }{ + { + header: http.Header{}, + expectedMessage: `{"type":"connection_init"}`, + }, + { + header: nil, + expectedMessage: `{"type":"connection_init"}`, + }, + { + header: http.Header{"Foo": []string{"bar"}}, + expectedMessage: `{"type":"connection_init", "payload":{"Foo":"bar"}}`, + }, + { + header: http.Header{"Foo": []string{"bar", "baz"}}, + expectedMessage: `{"type":"connection_init", "payload":{"Foo":"bar"}}`, + }, + { + header: http.Header{"Foo": []string{""}}, + expectedMessage: `{"type":"connection_init", "payload":{"Foo":""}}`, + }, + { + header: http.Header{"Foo": []string{}}, + expectedMessage: `{"type":"connection_init", "payload":{"Foo":""}}`, + }, + { + header: http.Header{"Foo": nil}, + expectedMessage: `{"type":"connection_init", "payload":{"Foo":""}}`, + }, + } { + t.Run(fmt.Sprint(i), func(t *testing.T) { + msg, err := connectionInitMessage(tc.header) + assert.NoError(t, err) + assert.Equal(t, tc.expectedMessage, msg) + }) + } +}