diff --git a/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 95dfa2316..e05c88ecc 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -1163,16 +1163,24 @@ func (p *Planner) addField(ref int) { p.nodes = append(p.nodes, field) } +type OnWsConnectionInitCallback func(ctx context.Context, url string, header http.Header) (json.RawMessage, error) + type Factory struct { - BatchFactory resolve.DataSourceBatchFactory - HTTPClient *http.Client - StreamingClient *http.Client - subscriptionClient *SubscriptionClient + BatchFactory resolve.DataSourceBatchFactory + HTTPClient *http.Client + StreamingClient *http.Client + OnWsConnectionInitCallback *OnWsConnectionInitCallback + subscriptionClient *SubscriptionClient } func (f *Factory) Planner(ctx context.Context) plan.DataSourcePlanner { if f.subscriptionClient == nil { - f.subscriptionClient = NewGraphQLSubscriptionClient(f.HTTPClient, f.StreamingClient, ctx) + opts := make([]Options, 0) + if f.OnWsConnectionInitCallback != nil { + opts = append(opts, WithOnWsConnectionInitCallback(f.OnWsConnectionInitCallback)) + } + + f.subscriptionClient = NewGraphQLSubscriptionClient(f.HTTPClient, f.StreamingClient, ctx, opts...) } return &Planner{ batchFactory: f.BatchFactory, diff --git a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index 11622c562..81c064f29 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -7,9 +7,8 @@ import ( "sync" "time" - "github.com/cespare/xxhash/v2" - "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/jensneuse/abstractlogger" "nhooyr.io/websocket" ) @@ -20,14 +19,15 @@ const ackWaitTimeout = 30 * time.Second // It takes care of de-duplicating connections to the same origin under certain circumstances // If Hash(URL,Body,Headers) result in the same result, an existing connection is re-used type SubscriptionClient struct { - streamingClient *http.Client - httpClient *http.Client - engineCtx context.Context - log abstractlogger.Logger - hashPool sync.Pool - handlers map[uint64]ConnectionHandler - handlersMu sync.Mutex - wsSubProtocol string + streamingClient *http.Client + httpClient *http.Client + engineCtx context.Context + log abstractlogger.Logger + hashPool sync.Pool + handlers map[uint64]ConnectionHandler + handlersMu sync.Mutex + wsSubProtocol string + onWsConnectionInitCallback *OnWsConnectionInitCallback readTimeout time.Duration } @@ -52,10 +52,17 @@ func WithWSSubProtocol(protocol string) Options { } } +func WithOnWsConnectionInitCallback(callback *OnWsConnectionInitCallback) Options { + return func(options *opts) { + options.onWsConnectionInitCallback = callback + } +} + type opts struct { - readTimeout time.Duration - log abstractlogger.Logger - wsSubProtocol string + readTimeout time.Duration + log abstractlogger.Logger + wsSubProtocol string + onWsConnectionInitCallback *OnWsConnectionInitCallback } func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...Options) *SubscriptionClient { @@ -78,7 +85,8 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi return xxhash.New() }, }, - wsSubProtocol: op.wsSubProtocol, + wsSubProtocol: op.wsSubProtocol, + onWsConnectionInitCallback: op.onWsConnectionInitCallback, } } @@ -199,6 +207,11 @@ func (c *SubscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti return nil, fmt.Errorf("upgrade unsuccessful") } + connectionInitMessage, err := c.getConnectionInitMessage(reqCtx, options.URL, options.Header) + if err != nil { + return nil, err + } + // init + ack err = conn.Write(reqCtx, websocket.MessageText, connectionInitMessage) if err != nil { @@ -223,6 +236,30 @@ func (c *SubscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti } } +func (c *SubscriptionClient) getConnectionInitMessage(ctx context.Context, url string, header http.Header) ([]byte, error) { + if c.onWsConnectionInitCallback == nil { + return connectionInitMessage, nil + } + + callback := *c.onWsConnectionInitCallback + + payload, err := callback(ctx, url, header) + if err != nil { + return nil, err + } + + if len(payload) == 0 { + return connectionInitMessage, nil + } + + msg, err := jsonparser.Set(connectionInitMessage, payload, "payload") + if err != nil { + return nil, err + } + + return msg, nil +} + type ConnectionHandler interface { StartBlocking(sub Subscription) SubscribeCH() chan<- Subscription diff --git a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 53de7c564..372596624 100644 --- a/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -2,6 +2,7 @@ package graphql_datasource import ( "context" + "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -10,6 +11,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/buger/jsonparser" ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" @@ -27,6 +30,40 @@ func logger() ll.Logger { return ll.NewZapLogger(logger, ll.DebugLevel) } +func TestGetConnectionInitMessageHelper(t *testing.T) { + var callback OnWsConnectionInitCallback = func(ctx context.Context, url string, header http.Header) (json.RawMessage, error) { + return json.RawMessage(`{"authorization":"secret"}`), nil + } + + tests := []struct { + name string + callback *OnWsConnectionInitCallback + want string + }{ + { + name: "without payload", + callback: nil, + want: `{"type":"connection_init"}`, + }, + { + name: "with payload", + callback: &callback, + want: `{"type":"connection_init","payload":{"authorization":"secret"}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := SubscriptionClient{onWsConnectionInitCallback: tt.callback} + got, err := client.getConnectionInitMessage(context.Background(), "", nil) + require.NoError(t, err) + require.NotEmpty(t, got) + + assert.Equal(t, tt.want, string(got)) + }) + } +} + func TestWebsocketSubscriptionClientDeDuplication(t *testing.T) { serverDone := &sync.WaitGroup{} connectedClients := atomic.NewInt64(0)