Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate headers to WS upstream using connection_init #343

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@ import (
"nhooyr.io/websocket"
)

var (
connectionInitMessage = []byte(`{"type":"connection_init"}`)
)

const (
startMessage = `{"type":"start","id":"%s","payload":%s}`
stopMessage = `{"type":"stop","id":"%s"}`
internalError = `{"errors":[{"message":"connection error"}]}`
connectionError = `{"errors":[{"message":"connection error"}]}`
initMessageWithPayload = `{"type":"connection_init", "payload":%s}`
initMessageNoPayload = `{"type":"connection_init"}`
startMessage = `{"type":"start","id":"%s","payload":%s}`
stopMessage = `{"type":"stop","id":"%s"}`
internalError = `{"errors":[{"message":"connection error"}]}`
connectionError = `{"errors":[{"message":"connection error"}]}`
)

// WebSocketGraphQLSubscriptionClient is a WebSocket client that allows running multiple subscriptions via the same WebSocket Connection
Expand Down Expand Up @@ -86,7 +84,6 @@ func NewWebSocketGraphQLSubscriptionClient(httpClient *http.Client, ctx context.
// If an existing WS with the same ID (Hash) exists, it is being re-used
// If no connection exists, the client initiates a new one and sends the "init" and "connection ack" messages
func (c *WebSocketGraphQLSubscriptionClient) Subscribe(ctx context.Context, options GraphQLSubscriptionOptions, next chan<- []byte) error {

handlerID, err := c.generateHandlerIDHash(options)
if err != nil {
return err
Expand All @@ -112,6 +109,12 @@ func (c *WebSocketGraphQLSubscriptionClient) Subscribe(ctx context.Context, opti
if options.Header == nil {
options.Header = http.Header{}
}

initMessage, err := connectionInitMessage(options.Header)
if err != nil {
return err
}

options.Header.Set("Sec-WebSocket-Protocol", "graphql-ws")
options.Header.Set("Sec-WebSocket-Version", "13")

Expand All @@ -128,7 +131,7 @@ func (c *WebSocketGraphQLSubscriptionClient) Subscribe(ctx context.Context, opti
return fmt.Errorf("upgrade unsuccessful")
}
// init + ack
err = conn.Write(ctx, websocket.MessageText, connectionInitMessage)
err = conn.Write(ctx, websocket.MessageText, []byte(initMessage))
if err != nil {
return err
}
Expand Down Expand Up @@ -181,6 +184,21 @@ func (c *WebSocketGraphQLSubscriptionClient) generateHandlerIDHash(options Graph
return xxh.Sum64(), nil
}

func connectionInitMessage(header http.Header) (string, error) {
if len(header) == 0 {
return initMessageNoPayload, nil
}
payload := make(map[string]string, len(header))
for name := range header {
payload[name] = header.Get(name)
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return "", err
}
return fmt.Sprintf(initMessageWithPayload, payloadBytes), nil
}

func newConnectionHandler(ctx context.Context, conn *websocket.Conn, readTimeout time.Duration, log abstractlogger.Logger) *connectionHandler {
return &connectionHandler{
conn: conn,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,92 @@ func TestWebsocketSubscriptionClientDeDuplication(t *testing.T) {
return connectedClients.Load() == 0
}, time.Second, time.Millisecond, "clients not 0")
}

func TestWebsocketSubscriptionClientWithInitPayload(t *testing.T) {
assertInitAck := func(ctx context.Context, conn *websocket.Conn) {
msgType, data, err := conn.Read(ctx)
assert.NoError(t, err)
assert.Equal(t, websocket.MessageText, msgType)
assert.Equal(t, `{"type":"connection_init", "payload":{"Authorization":"Bearer XXX"}}`, string(data))
err = conn.Write(ctx, websocket.MessageText, []byte(`{"type":"connection_ack"}`))
assert.NoError(t, err)
}

handshakeHappened := false
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(w, r, nil)
assert.NoError(t, err)
assertInitAck(r.Context(), conn)
handshakeHappened = true
<-conn.CloseRead(r.Context()).Done()
}))
defer server.Close()

clientCtx, clientCancel := context.WithCancel(context.Background())
defer clientCancel()
client := NewWebSocketGraphQLSubscriptionClient(http.DefaultClient, clientCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
)

subscribeHeaders := http.Header{}
subscribeHeaders.Add("Authorization", "Bearer XXX")
next := make(chan []byte)
subCtx, subCancel := context.WithCancel(context.Background())
defer subCancel()
err := client.Subscribe(subCtx, GraphQLSubscriptionOptions{
URL: strings.Replace(server.URL, "http", "ws", -1),
Body: GraphQLBody{
Query: `subscription {messageAdded(roomName: "room"){text}}`,
},
Header: subscribeHeaders,
}, next)
assert.NoError(t, err)

assert.Len(t, client.handlers, 1, "handler not registered")
assert.Eventuallyf(t, func() bool {
return handshakeHappened
}, time.Second, time.Millisecond, "handshake was not performed")
}

func TestConnectionInitMessage(t *testing.T) {
for i, tc := range []struct {
header http.Header
expectedMessage string
}{
{
header: http.Header{},
expectedMessage: `{"type":"connection_init"}`,
},
{
header: nil,
expectedMessage: `{"type":"connection_init"}`,
},
{
header: http.Header{"Foo": []string{"bar"}},
expectedMessage: `{"type":"connection_init", "payload":{"Foo":"bar"}}`,
},
{
header: http.Header{"Foo": []string{"bar", "baz"}},
expectedMessage: `{"type":"connection_init", "payload":{"Foo":"bar"}}`,
},
{
header: http.Header{"Foo": []string{""}},
expectedMessage: `{"type":"connection_init", "payload":{"Foo":""}}`,
},
{
header: http.Header{"Foo": []string{}},
expectedMessage: `{"type":"connection_init", "payload":{"Foo":""}}`,
},
{
header: http.Header{"Foo": nil},
expectedMessage: `{"type":"connection_init", "payload":{"Foo":""}}`,
},
} {
t.Run(fmt.Sprint(i), func(t *testing.T) {
msg, err := connectionInitMessage(tc.header)
assert.NoError(t, err)
assert.Equal(t, tc.expectedMessage, msg)
})
}
}