Skip to content

Commit

Permalink
feat: subscriptions, ws connection init callback (#425)
Browse files Browse the repository at this point in the history
* feat: subscriptions, ws connection init callback

* feat: subscriptions, ws connection init callback
  • Loading branch information
YuriBuerov committed Oct 14, 2022
1 parent 3289bbe commit 3ef7902
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 19 deletions.
18 changes: 13 additions & 5 deletions pkg/engine/datasource/graphql_datasource/graphql_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -1163,16 +1163,24 @@ func (p *Planner) addField(ref int) {
p.nodes = append(p.nodes, field)
}

type OnWsConnectionInitCallback func(ctx context.Context, url string, header http.Header) (json.RawMessage, error)

type Factory struct {
BatchFactory resolve.DataSourceBatchFactory
HTTPClient *http.Client
StreamingClient *http.Client
subscriptionClient *SubscriptionClient
BatchFactory resolve.DataSourceBatchFactory
HTTPClient *http.Client
StreamingClient *http.Client
OnWsConnectionInitCallback *OnWsConnectionInitCallback
subscriptionClient *SubscriptionClient
}

func (f *Factory) Planner(ctx context.Context) plan.DataSourcePlanner {
if f.subscriptionClient == nil {
f.subscriptionClient = NewGraphQLSubscriptionClient(f.HTTPClient, f.StreamingClient, ctx)
opts := make([]Options, 0)
if f.OnWsConnectionInitCallback != nil {
opts = append(opts, WithOnWsConnectionInitCallback(f.OnWsConnectionInitCallback))
}

f.subscriptionClient = NewGraphQLSubscriptionClient(f.HTTPClient, f.StreamingClient, ctx, opts...)
}
return &Planner{
batchFactory: f.BatchFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ import (
"sync"
"time"

"github.com/cespare/xxhash/v2"

"github.com/buger/jsonparser"
"github.com/cespare/xxhash/v2"
"github.com/jensneuse/abstractlogger"
"nhooyr.io/websocket"
)
Expand All @@ -20,14 +19,15 @@ const ackWaitTimeout = 30 * time.Second
// It takes care of de-duplicating connections to the same origin under certain circumstances
// If Hash(URL,Body,Headers) result in the same result, an existing connection is re-used
type SubscriptionClient struct {
streamingClient *http.Client
httpClient *http.Client
engineCtx context.Context
log abstractlogger.Logger
hashPool sync.Pool
handlers map[uint64]ConnectionHandler
handlersMu sync.Mutex
wsSubProtocol string
streamingClient *http.Client
httpClient *http.Client
engineCtx context.Context
log abstractlogger.Logger
hashPool sync.Pool
handlers map[uint64]ConnectionHandler
handlersMu sync.Mutex
wsSubProtocol string
onWsConnectionInitCallback *OnWsConnectionInitCallback

readTimeout time.Duration
}
Expand All @@ -52,10 +52,17 @@ func WithWSSubProtocol(protocol string) Options {
}
}

func WithOnWsConnectionInitCallback(callback *OnWsConnectionInitCallback) Options {
return func(options *opts) {
options.onWsConnectionInitCallback = callback
}
}

type opts struct {
readTimeout time.Duration
log abstractlogger.Logger
wsSubProtocol string
readTimeout time.Duration
log abstractlogger.Logger
wsSubProtocol string
onWsConnectionInitCallback *OnWsConnectionInitCallback
}

func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...Options) *SubscriptionClient {
Expand All @@ -78,7 +85,8 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi
return xxhash.New()
},
},
wsSubProtocol: op.wsSubProtocol,
wsSubProtocol: op.wsSubProtocol,
onWsConnectionInitCallback: op.onWsConnectionInitCallback,
}
}

Expand Down Expand Up @@ -199,6 +207,11 @@ func (c *SubscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti
return nil, fmt.Errorf("upgrade unsuccessful")
}

connectionInitMessage, err := c.getConnectionInitMessage(reqCtx, options.URL, options.Header)
if err != nil {
return nil, err
}

// init + ack
err = conn.Write(reqCtx, websocket.MessageText, connectionInitMessage)
if err != nil {
Expand All @@ -223,6 +236,30 @@ func (c *SubscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti
}
}

func (c *SubscriptionClient) getConnectionInitMessage(ctx context.Context, url string, header http.Header) ([]byte, error) {
if c.onWsConnectionInitCallback == nil {
return connectionInitMessage, nil
}

callback := *c.onWsConnectionInitCallback

payload, err := callback(ctx, url, header)
if err != nil {
return nil, err
}

if len(payload) == 0 {
return connectionInitMessage, nil
}

msg, err := jsonparser.Set(connectionInitMessage, payload, "payload")
if err != nil {
return nil, err
}

return msg, nil
}

type ConnectionHandler interface {
StartBlocking(sub Subscription)
SubscribeCH() chan<- Subscription
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package graphql_datasource

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
Expand All @@ -10,6 +11,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/buger/jsonparser"
ll "github.com/jensneuse/abstractlogger"
"github.com/stretchr/testify/assert"
Expand All @@ -27,6 +30,40 @@ func logger() ll.Logger {
return ll.NewZapLogger(logger, ll.DebugLevel)
}

func TestGetConnectionInitMessageHelper(t *testing.T) {
var callback OnWsConnectionInitCallback = func(ctx context.Context, url string, header http.Header) (json.RawMessage, error) {
return json.RawMessage(`{"authorization":"secret"}`), nil
}

tests := []struct {
name string
callback *OnWsConnectionInitCallback
want string
}{
{
name: "without payload",
callback: nil,
want: `{"type":"connection_init"}`,
},
{
name: "with payload",
callback: &callback,
want: `{"type":"connection_init","payload":{"authorization":"secret"}}`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := SubscriptionClient{onWsConnectionInitCallback: tt.callback}
got, err := client.getConnectionInitMessage(context.Background(), "", nil)
require.NoError(t, err)
require.NotEmpty(t, got)

assert.Equal(t, tt.want, string(got))
})
}
}

func TestWebsocketSubscriptionClientDeDuplication(t *testing.T) {
serverDone := &sync.WaitGroup{}
connectedClients := atomic.NewInt64(0)
Expand Down

0 comments on commit 3ef7902

Please sign in to comment.