Skip to content
This repository was archived by the owner on May 17, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Graphql Lambda Subscriptions

[![Release](https://github.com/reconbot/graphql-lambda-subscriptions/actions/workflows/test.yml/badge.svg)](https://github.com/reconbot/graphql-lambda-subscriptions/actions/workflows/test.yml)

This is a fork of [`subscriptionless`](https://github.com/andyrichardson/subscriptionless) and is a Amazon Lambda Serverless equivalent to [graphQL-ws](https://github.com/enisdenjo/graphql-ws). It follows the [`graphql-ws prototcol`](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md). It is tested with the [Architect Sandbox](https://arc.codes/docs/en/reference/cli/sandbox) against `graphql-ws` directly and run in production today. For many applications `graphql-lambda-subscriptions` should do what `graphql-ws` does for you today without having to run a server.
Expand All @@ -17,6 +18,7 @@ I had different requirements and needed more features. This project wouldn't exi
- Provides a Pub/Sub system to broadcast events to subscriptions
- Provides hooks for the full lifecycle of a subscription
- Type compatible with GraphQL and [`nexus.js`](https://nexusjs.org)
- Optional Logging

## Quick Start

Expand Down Expand Up @@ -203,7 +205,7 @@ resources:

```tf
resource "aws_dynamodb_table" "connections-table" {
name = "subscriptionless_connections"
name = "graphql_connections"
billing_mode = "PROVISIONED"
read_capacity = 1
write_capacity = 1
Expand All @@ -216,7 +218,7 @@ resource "aws_dynamodb_table" "connections-table" {
}

resource "aws_dynamodb_table" "subscriptions-table" {
name = "subscriptionless_subscriptions"
name = "graphql_subscriptions"
billing_mode = "PROVISIONED"
read_capacity = 1
write_capacity = 1
Expand Down Expand Up @@ -370,7 +372,7 @@ Context values are accessible in all resolver level functions (`resolve`, `subsc

<summary>📖 Default value</summary>

Assuming no `context` argument is provided, the default value is an object containing a `connectionParams` attribute.
Assuming no `context` argument is provided, the default value is an object containing a `connectionInitPayload` attribute.

This attribute contains the [(optionally parsed)](#events) payload from `connection_init`.

Expand All @@ -379,7 +381,7 @@ export const resolver = {
Subscribe: {
mySubscription: {
resolve: (event, args, context) => {
console.log(context.connectionParams); // payload from connection_init
console.log(context.connectionInitPayload); // payload from connection_init
},
},
},
Expand Down Expand Up @@ -418,9 +420,9 @@ The default context value is passed as an argument.
```ts
const instance = createInstance({
/* ... */
context: ({ connectionParams }) => ({
context: ({ connectionInitPayload }) => ({
myAttr: 'hello',
user: connectionParams.user,
user: connectionInitPayload.user,
}),
});
```
Expand Down
21 changes: 11 additions & 10 deletions lib/handleStateMachineEvent.ts
Original file line number Diff line number Diff line change
@@ -1,40 +1,41 @@
import { MessageType } from 'graphql-ws'
import { ServerClosure, ServerInstance } from './types'
import { sendMessage } from './utils/sendMessage'
import { postToConnection } from './utils/postToConnection'
import { deleteConnection } from './utils/deleteConnection'

export const handleStateMachineEvent = (c: ServerClosure): ServerInstance['stateMachineHandler'] => async (input) => {
if (!c.pingpong) {
export const handleStateMachineEvent = (serverPromise: Promise<ServerClosure>): ServerInstance['stateMachineHandler'] => async (input) => {
const server = await serverPromise
if (!server.pingpong) {
throw new Error('Invalid pingpong settings')
}
const connection = Object.assign(new c.model.Connection(), {
const connection = Object.assign(new server.model.Connection(), {
id: input.connectionId,
})

// Initial state - send ping message
if (input.state === 'PING') {
await sendMessage(c)({ ...input, message: { type: MessageType.Ping } })
await c.mapper.update(Object.assign(connection, { hasPonged: false }), {
await postToConnection(server)({ ...input, message: { type: MessageType.Ping } })
await server.mapper.update(Object.assign(connection, { hasPonged: false }), {
onMissing: 'skip',
})
return {
...input,
state: 'REVIEW',
seconds: c.pingpong.delay,
seconds: server.pingpong.delay,
}
}

// Follow up state - check if pong was returned
const conn = await c.mapper.get(connection)
const conn = await server.mapper.get(connection)
if (conn.hasPonged) {
return {
...input,
state: 'PING',
seconds: c.pingpong.timeout,
seconds: server.pingpong.timeout,
}
}

await deleteConnection(c)({ ...input })
await deleteConnection(server)({ ...input })
return {
...input,
state: 'ABORT',
Expand Down
13 changes: 7 additions & 6 deletions lib/handleGatewayEvent.ts → lib/handleWebSocketEvent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@ import { subscribe } from './messages/subscribe'
import { connection_init } from './messages/connection_init'
import { pong } from './messages/pong'

export const handleGatewayEvent = (server: ServerClosure): ServerInstance['gatewayHandler'] => async (event) => {
export const handleWebSocketEvent = (serverPromise: Promise<ServerClosure>): ServerInstance['webSocketHandler'] => async (event) => {
const server = await serverPromise
if (!event.requestContext) {
server.log('handleGatewayEvent unknown')
server.log('handleWebSocketEvent unknown')
return {
statusCode: 200,
body: '',
}
}

if (event.requestContext.eventType === 'CONNECT') {
server.log('handleGatewayEvent CONNECT', { connectionId: event.requestContext.connectionId })
server.log('handleWebSocketEvent CONNECT', { connectionId: event.requestContext.connectionId })
await server.onConnect?.({ event })
return {
statusCode: 200,
Expand All @@ -30,7 +31,7 @@ export const handleGatewayEvent = (server: ServerClosure): ServerInstance['gatew

if (event.requestContext.eventType === 'MESSAGE') {
const message = event.body === null ? null : JSON.parse(event.body)
server.log('handleGatewayEvent MESSAGE', { connectionId: event.requestContext.connectionId, type: message.type })
server.log('handleWebSocketEvent MESSAGE', { connectionId: event.requestContext.connectionId, type: message.type })

if (message.type === MessageType.ConnectionInit) {
await connection_init({ server, event, message })
Expand Down Expand Up @@ -74,15 +75,15 @@ export const handleGatewayEvent = (server: ServerClosure): ServerInstance['gatew
}

if (event.requestContext.eventType === 'DISCONNECT') {
server.log('handleGatewayEvent DISCONNECT', { connectionId: event.requestContext.connectionId })
server.log('handleWebSocketEvent DISCONNECT', { connectionId: event.requestContext.connectionId })
await disconnect({ server, event, message: null })
return {
statusCode: 200,
body: '',
}
}

server.log('handleGatewayEvent UNKNOWN', { connectionId: event.requestContext.connectionId })
server.log('handleWebSocketEvent UNKNOWN', { connectionId: event.requestContext.connectionId })
return {
statusCode: 200,
body: '',
Expand Down
8 changes: 4 additions & 4 deletions lib/index-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import { Handler } from 'aws-lambda'
import { tables } from '@architect/sandbox'
import { createInstance } from '.'
import { mockServerArgs } from './test/mockServer'
import { APIGatewayWebSocketEvent, WebsocketResponse } from './types'
import { APIGatewayWebSocketEvent, WebSocketResponse } from './types'

describe('createInstance', () => {
describe('gatewayHandler', () => {
describe('webSocketHandler', () => {
before(async () => {
await tables.start({ cwd: './mocks/arc-basic-events', quiet: true })
})
Expand All @@ -18,8 +18,8 @@ describe('createInstance', () => {
it('is type compatible with aws-lambda handler', async () => {
const server = createInstance(await mockServerArgs())

const gatewayHandler: Handler<APIGatewayWebSocketEvent, WebsocketResponse> = server.gatewayHandler
assert.ok(gatewayHandler)
const webSocketHandler: Handler<APIGatewayWebSocketEvent, WebSocketResponse> = server.webSocketHandler
assert.ok(webSocketHandler)
})
})
})
27 changes: 23 additions & 4 deletions lib/index.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,41 @@
import { ServerArgs, ServerClosure, ServerInstance } from './types'
import { publish } from './pubsub/publish'
import { complete } from './pubsub/complete'
import { handleGatewayEvent } from './handleGatewayEvent'
import { handleWebSocketEvent } from './handleWebSocketEvent'
import { handleStateMachineEvent } from './handleStateMachineEvent'
import { makeServerClosure } from './makeServerClosure'

export const createInstance = (opts: ServerArgs): ServerInstance => {
const closure: ServerClosure = makeServerClosure(opts)
const closure: Promise<ServerClosure> = makeServerClosure(opts)

return {
gatewayHandler: handleGatewayEvent(closure),
webSocketHandler: handleWebSocketEvent(closure),
stateMachineHandler: handleStateMachineEvent(closure),
publish: publish(closure),
complete: complete(closure),
}
}

export * from './pubsub/subscribe'
export * from './types'
export {
ServerArgs,
ServerInstance,
APIGatewayWebSocketRequestContext,
SubscribeOptions,
SubscribeArgs,
SubscribePseudoIterable,
MaybePromise,
ApiGatewayManagementApiSubset,
TableNames,
APIGatewayWebSocketEvent,
LoggerFunction,
ApiSebSocketHandler,
WebSocketResponse,
StateFunctionInput,
PubSubEvent,
PartialBy,
SubscriptionDefinition,
SubscriptionFilter,
} from './types'
export { Subscription } from './model/Subscription'
export { Connection } from './model/Connection'
22 changes: 15 additions & 7 deletions lib/makeServerClosure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,30 @@ import { ServerArgs, ServerClosure } from './types'
import { createModel } from './model/createModel'
import { Subscription } from './model/Subscription'
import { Connection } from './model/Connection'
import { log } from './utils/logger'
import { log as debugLogger } from './utils/logger'

export function makeServerClosure(opts: ServerArgs): ServerClosure {
export const makeServerClosure = async (opts: ServerArgs): Promise<ServerClosure> => {
const {
tableNames,
log = debugLogger,
dynamodb,
apiGatewayManagementApi,
...rest
} = opts
return {
log: log,
...opts,
...rest,
apiGatewayManagementApi: await apiGatewayManagementApi,
log,
model: {
Subscription: createModel({
model: Subscription,
table: opts.tableNames?.subscriptions || 'subscriptionless_subscriptions',
table: (await tableNames)?.subscriptions || 'graphql_subscriptions',
}),
Connection: createModel({
model: Connection,
table: opts.tableNames?.connections || 'subscriptionless_connections',
table: (await tableNames)?.connections || 'graphql_connections',
}),
},
mapper: new DataMapper({ client: opts.dynamodb }),
mapper: new DataMapper({ client: await dynamodb }),
}
}
2 changes: 1 addition & 1 deletion lib/messages/complete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export const complete: MessageHandler<CompleteMessage> =
server.schema,
parse(sub.subscription.query),
undefined,
await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }),
await constructContext({ server, connectionInitPayload: sub.connectionInitPayload, connectionId: sub.connectionId }),
sub.subscription.variables,
sub.subscription.operationName,
undefined,
Expand Down
4 changes: 2 additions & 2 deletions lib/messages/connection_init.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { StepFunctions } from 'aws-sdk'
import { ConnectionInitMessage, MessageType } from 'graphql-ws'
import { StateFunctionInput, MessageHandler } from '../types'
import { sendMessage } from '../utils/sendMessage'
import { postToConnection } from '../utils/postToConnection'
import { deleteConnection } from '../utils/deleteConnection'

/** Handler function for 'connection_init' message. */
Expand Down Expand Up @@ -34,7 +34,7 @@ export const connection_init: MessageHandler<ConnectionInitMessage> =
payload,
})
await server.mapper.put(connection)
return sendMessage(server)({
return postToConnection(server)({
...event.requestContext,
message: { type: MessageType.ConnectionAck },
})
Expand Down
2 changes: 1 addition & 1 deletion lib/messages/disconnect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export const disconnect: MessageHandler<null> = async ({ server, event }) => {
server.schema,
parse(sub.subscription.query),
undefined,
await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }),
await constructContext({ server, connectionInitPayload: sub.connectionInitPayload, connectionId: sub.connectionId }),
sub.subscription.variables,
sub.subscription.operationName,
undefined,
Expand Down
4 changes: 2 additions & 2 deletions lib/messages/ping.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { PingMessage, MessageType } from 'graphql-ws'
import { sendMessage } from '../utils/sendMessage'
import { postToConnection } from '../utils/postToConnection'
import { deleteConnection } from '../utils/deleteConnection'
import { MessageHandler } from '../types'

/** Handler function for 'ping' message. */
export const ping: MessageHandler<PingMessage> = async ({ server, event, message }) => {
try {
await server.onPing?.({ event, message })
return sendMessage(server)({
return postToConnection(server)({
...event.requestContext,
message: { type: MessageType.Pong },
})
Expand Down
16 changes: 8 additions & 8 deletions lib/messages/subscribe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
import { APIGatewayWebSocketEvent, ServerClosure, MessageHandler, SubscribePseudoIterable, PubSubEvent } from '../types'
import { constructContext } from '../utils/constructContext'
import { getResolverAndArgs } from '../utils/getResolverAndArgs'
import { sendMessage } from '../utils/sendMessage'
import { postToConnection } from '../utils/postToConnection'
import { deleteConnection } from '../utils/deleteConnection'
import { isArray } from '../utils/isArray'

Expand Down Expand Up @@ -38,7 +38,7 @@ const setupSubscription: MessageHandler<SubscribeMessage> = async ({ server, eve

if (errors) {
server.log('subscribe:validateError', errors)
return sendMessage(server)({
return postToConnection(server)({
...event.requestContext,
message: {
type: MessageType.Error,
Expand All @@ -48,7 +48,7 @@ const setupSubscription: MessageHandler<SubscribeMessage> = async ({ server, eve
})
}

const contextValue = await constructContext({ server, connectionParams: connection.payload, connectionId })
const contextValue = await constructContext({ server, connectionInitPayload: connection.payload, connectionId })

const execContext = buildExecutionContext(
server.schema,
Expand All @@ -61,7 +61,7 @@ const setupSubscription: MessageHandler<SubscribeMessage> = async ({ server, eve
)

if (isArray(execContext)) {
return sendMessage(server)({
return postToConnection(server)({
...event.requestContext,
message: {
type: MessageType.Error,
Expand All @@ -87,7 +87,7 @@ const setupSubscription: MessageHandler<SubscribeMessage> = async ({ server, eve
const onSubscribeErrors = await onSubscribe?.(root, args, context, info)
if (onSubscribeErrors){
server.log('onSubscribe', { onSubscribeErrors })
return sendMessage(server)({
return postToConnection(server)({
...event.requestContext,
message: {
type: MessageType.Error,
Expand All @@ -110,7 +110,7 @@ const setupSubscription: MessageHandler<SubscribeMessage> = async ({ server, eve
...message.payload,
},
connectionId: connection.id,
connectionParams: connection.payload,
connectionInitPayload: connection.payload,
requestContext: event.requestContext,
ttl: connection.ttl,
})
Expand Down Expand Up @@ -155,7 +155,7 @@ async function executeQuery(server: ServerClosure, message: SubscribeMessage, co
undefined,
)

await sendMessage(server)({
await postToConnection(server)({
...event.requestContext,
message: {
type: MessageType.Next,
Expand All @@ -164,7 +164,7 @@ async function executeQuery(server: ServerClosure, message: SubscribeMessage, co
},
})

await sendMessage(server)({
await postToConnection(server)({
...event.requestContext,
message: {
type: MessageType.Complete,
Expand Down
Loading