diff --git a/README.md b/README.md index 389bf268..5597fdac 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # Graphql Lambda Subscriptions + [![Release](https://github.com/reconbot/graphql-lambda-subscriptions/actions/workflows/test.yml/badge.svg)](https://github.com/reconbot/graphql-lambda-subscriptions/actions/workflows/test.yml) This is a fork of [`subscriptionless`](https://github.com/andyrichardson/subscriptionless) and is a Amazon Lambda Serverless equivalent to [graphQL-ws](https://github.com/enisdenjo/graphql-ws). It follows the [`graphql-ws prototcol`](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md). It is tested with the [Architect Sandbox](https://arc.codes/docs/en/reference/cli/sandbox) against `graphql-ws` directly and run in production today. For many applications `graphql-lambda-subscriptions` should do what `graphql-ws` does for you today without having to run a server. @@ -17,6 +18,7 @@ I had different requirements and needed more features. This project wouldn't exi - Provides a Pub/Sub system to broadcast events to subscriptions - Provides hooks for the full lifecycle of a subscription - Type compatible with GraphQL and [`nexus.js`](https://nexusjs.org) +- Optional Logging ## Quick Start @@ -203,7 +205,7 @@ resources: ```tf resource "aws_dynamodb_table" "connections-table" { - name = "subscriptionless_connections" + name = "graphql_connections" billing_mode = "PROVISIONED" read_capacity = 1 write_capacity = 1 @@ -216,7 +218,7 @@ resource "aws_dynamodb_table" "connections-table" { } resource "aws_dynamodb_table" "subscriptions-table" { - name = "subscriptionless_subscriptions" + name = "graphql_subscriptions" billing_mode = "PROVISIONED" read_capacity = 1 write_capacity = 1 @@ -370,7 +372,7 @@ Context values are accessible in all resolver level functions (`resolve`, `subsc 📖 Default value -Assuming no `context` argument is provided, the default value is an object containing a `connectionParams` attribute. +Assuming no `context` argument is provided, the default value is an object containing a `connectionInitPayload` attribute. This attribute contains the [(optionally parsed)](#events) payload from `connection_init`. @@ -379,7 +381,7 @@ export const resolver = { Subscribe: { mySubscription: { resolve: (event, args, context) => { - console.log(context.connectionParams); // payload from connection_init + console.log(context.connectionInitPayload); // payload from connection_init }, }, }, @@ -418,9 +420,9 @@ The default context value is passed as an argument. ```ts const instance = createInstance({ /* ... */ - context: ({ connectionParams }) => ({ + context: ({ connectionInitPayload }) => ({ myAttr: 'hello', - user: connectionParams.user, + user: connectionInitPayload.user, }), }); ``` diff --git a/lib/handleStateMachineEvent.ts b/lib/handleStateMachineEvent.ts index 61ab3039..201006cd 100644 --- a/lib/handleStateMachineEvent.ts +++ b/lib/handleStateMachineEvent.ts @@ -1,40 +1,41 @@ import { MessageType } from 'graphql-ws' import { ServerClosure, ServerInstance } from './types' -import { sendMessage } from './utils/sendMessage' +import { postToConnection } from './utils/postToConnection' import { deleteConnection } from './utils/deleteConnection' -export const handleStateMachineEvent = (c: ServerClosure): ServerInstance['stateMachineHandler'] => async (input) => { - if (!c.pingpong) { +export const handleStateMachineEvent = (serverPromise: Promise): ServerInstance['stateMachineHandler'] => async (input) => { + const server = await serverPromise + if (!server.pingpong) { throw new Error('Invalid pingpong settings') } - const connection = Object.assign(new c.model.Connection(), { + const connection = Object.assign(new server.model.Connection(), { id: input.connectionId, }) // Initial state - send ping message if (input.state === 'PING') { - await sendMessage(c)({ ...input, message: { type: MessageType.Ping } }) - await c.mapper.update(Object.assign(connection, { hasPonged: false }), { + await postToConnection(server)({ ...input, message: { type: MessageType.Ping } }) + await server.mapper.update(Object.assign(connection, { hasPonged: false }), { onMissing: 'skip', }) return { ...input, state: 'REVIEW', - seconds: c.pingpong.delay, + seconds: server.pingpong.delay, } } // Follow up state - check if pong was returned - const conn = await c.mapper.get(connection) + const conn = await server.mapper.get(connection) if (conn.hasPonged) { return { ...input, state: 'PING', - seconds: c.pingpong.timeout, + seconds: server.pingpong.timeout, } } - await deleteConnection(c)({ ...input }) + await deleteConnection(server)({ ...input }) return { ...input, state: 'ABORT', diff --git a/lib/handleGatewayEvent.ts b/lib/handleWebSocketEvent.ts similarity index 76% rename from lib/handleGatewayEvent.ts rename to lib/handleWebSocketEvent.ts index 6c64b28a..17253e20 100644 --- a/lib/handleGatewayEvent.ts +++ b/lib/handleWebSocketEvent.ts @@ -7,9 +7,10 @@ import { subscribe } from './messages/subscribe' import { connection_init } from './messages/connection_init' import { pong } from './messages/pong' -export const handleGatewayEvent = (server: ServerClosure): ServerInstance['gatewayHandler'] => async (event) => { +export const handleWebSocketEvent = (serverPromise: Promise): ServerInstance['webSocketHandler'] => async (event) => { + const server = await serverPromise if (!event.requestContext) { - server.log('handleGatewayEvent unknown') + server.log('handleWebSocketEvent unknown') return { statusCode: 200, body: '', @@ -17,7 +18,7 @@ export const handleGatewayEvent = (server: ServerClosure): ServerInstance['gatew } if (event.requestContext.eventType === 'CONNECT') { - server.log('handleGatewayEvent CONNECT', { connectionId: event.requestContext.connectionId }) + server.log('handleWebSocketEvent CONNECT', { connectionId: event.requestContext.connectionId }) await server.onConnect?.({ event }) return { statusCode: 200, @@ -30,7 +31,7 @@ export const handleGatewayEvent = (server: ServerClosure): ServerInstance['gatew if (event.requestContext.eventType === 'MESSAGE') { const message = event.body === null ? null : JSON.parse(event.body) - server.log('handleGatewayEvent MESSAGE', { connectionId: event.requestContext.connectionId, type: message.type }) + server.log('handleWebSocketEvent MESSAGE', { connectionId: event.requestContext.connectionId, type: message.type }) if (message.type === MessageType.ConnectionInit) { await connection_init({ server, event, message }) @@ -74,7 +75,7 @@ export const handleGatewayEvent = (server: ServerClosure): ServerInstance['gatew } if (event.requestContext.eventType === 'DISCONNECT') { - server.log('handleGatewayEvent DISCONNECT', { connectionId: event.requestContext.connectionId }) + server.log('handleWebSocketEvent DISCONNECT', { connectionId: event.requestContext.connectionId }) await disconnect({ server, event, message: null }) return { statusCode: 200, @@ -82,7 +83,7 @@ export const handleGatewayEvent = (server: ServerClosure): ServerInstance['gatew } } - server.log('handleGatewayEvent UNKNOWN', { connectionId: event.requestContext.connectionId }) + server.log('handleWebSocketEvent UNKNOWN', { connectionId: event.requestContext.connectionId }) return { statusCode: 200, body: '', diff --git a/lib/index-test.ts b/lib/index-test.ts index d06b9718..a1c1f6f8 100644 --- a/lib/index-test.ts +++ b/lib/index-test.ts @@ -3,10 +3,10 @@ import { Handler } from 'aws-lambda' import { tables } from '@architect/sandbox' import { createInstance } from '.' import { mockServerArgs } from './test/mockServer' -import { APIGatewayWebSocketEvent, WebsocketResponse } from './types' +import { APIGatewayWebSocketEvent, WebSocketResponse } from './types' describe('createInstance', () => { - describe('gatewayHandler', () => { + describe('webSocketHandler', () => { before(async () => { await tables.start({ cwd: './mocks/arc-basic-events', quiet: true }) }) @@ -18,8 +18,8 @@ describe('createInstance', () => { it('is type compatible with aws-lambda handler', async () => { const server = createInstance(await mockServerArgs()) - const gatewayHandler: Handler = server.gatewayHandler - assert.ok(gatewayHandler) + const webSocketHandler: Handler = server.webSocketHandler + assert.ok(webSocketHandler) }) }) }) diff --git a/lib/index.ts b/lib/index.ts index c875c209..786e4680 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -1,15 +1,15 @@ import { ServerArgs, ServerClosure, ServerInstance } from './types' import { publish } from './pubsub/publish' import { complete } from './pubsub/complete' -import { handleGatewayEvent } from './handleGatewayEvent' +import { handleWebSocketEvent } from './handleWebSocketEvent' import { handleStateMachineEvent } from './handleStateMachineEvent' import { makeServerClosure } from './makeServerClosure' export const createInstance = (opts: ServerArgs): ServerInstance => { - const closure: ServerClosure = makeServerClosure(opts) + const closure: Promise = makeServerClosure(opts) return { - gatewayHandler: handleGatewayEvent(closure), + webSocketHandler: handleWebSocketEvent(closure), stateMachineHandler: handleStateMachineEvent(closure), publish: publish(closure), complete: complete(closure), @@ -17,6 +17,25 @@ export const createInstance = (opts: ServerArgs): ServerInstance => { } export * from './pubsub/subscribe' -export * from './types' +export { + ServerArgs, + ServerInstance, + APIGatewayWebSocketRequestContext, + SubscribeOptions, + SubscribeArgs, + SubscribePseudoIterable, + MaybePromise, + ApiGatewayManagementApiSubset, + TableNames, + APIGatewayWebSocketEvent, + LoggerFunction, + ApiSebSocketHandler, + WebSocketResponse, + StateFunctionInput, + PubSubEvent, + PartialBy, + SubscriptionDefinition, + SubscriptionFilter, +} from './types' export { Subscription } from './model/Subscription' export { Connection } from './model/Connection' diff --git a/lib/makeServerClosure.ts b/lib/makeServerClosure.ts index 0ebda6ee..935f6dd4 100644 --- a/lib/makeServerClosure.ts +++ b/lib/makeServerClosure.ts @@ -3,22 +3,30 @@ import { ServerArgs, ServerClosure } from './types' import { createModel } from './model/createModel' import { Subscription } from './model/Subscription' import { Connection } from './model/Connection' -import { log } from './utils/logger' +import { log as debugLogger } from './utils/logger' -export function makeServerClosure(opts: ServerArgs): ServerClosure { +export const makeServerClosure = async (opts: ServerArgs): Promise => { + const { + tableNames, + log = debugLogger, + dynamodb, + apiGatewayManagementApi, + ...rest + } = opts return { - log: log, - ...opts, + ...rest, + apiGatewayManagementApi: await apiGatewayManagementApi, + log, model: { Subscription: createModel({ model: Subscription, - table: opts.tableNames?.subscriptions || 'subscriptionless_subscriptions', + table: (await tableNames)?.subscriptions || 'graphql_subscriptions', }), Connection: createModel({ model: Connection, - table: opts.tableNames?.connections || 'subscriptionless_connections', + table: (await tableNames)?.connections || 'graphql_connections', }), }, - mapper: new DataMapper({ client: opts.dynamodb }), + mapper: new DataMapper({ client: await dynamodb }), } } diff --git a/lib/messages/complete.ts b/lib/messages/complete.ts index bd7c0e58..bf44f4cf 100644 --- a/lib/messages/complete.ts +++ b/lib/messages/complete.ts @@ -26,7 +26,7 @@ export const complete: MessageHandler = server.schema, parse(sub.subscription.query), undefined, - await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }), + await constructContext({ server, connectionInitPayload: sub.connectionInitPayload, connectionId: sub.connectionId }), sub.subscription.variables, sub.subscription.operationName, undefined, diff --git a/lib/messages/connection_init.ts b/lib/messages/connection_init.ts index 9c1dda4e..47d70123 100644 --- a/lib/messages/connection_init.ts +++ b/lib/messages/connection_init.ts @@ -1,7 +1,7 @@ import { StepFunctions } from 'aws-sdk' import { ConnectionInitMessage, MessageType } from 'graphql-ws' import { StateFunctionInput, MessageHandler } from '../types' -import { sendMessage } from '../utils/sendMessage' +import { postToConnection } from '../utils/postToConnection' import { deleteConnection } from '../utils/deleteConnection' /** Handler function for 'connection_init' message. */ @@ -34,7 +34,7 @@ export const connection_init: MessageHandler = payload, }) await server.mapper.put(connection) - return sendMessage(server)({ + return postToConnection(server)({ ...event.requestContext, message: { type: MessageType.ConnectionAck }, }) diff --git a/lib/messages/disconnect.ts b/lib/messages/disconnect.ts index 5f4c46b7..d3d1aeda 100644 --- a/lib/messages/disconnect.ts +++ b/lib/messages/disconnect.ts @@ -36,7 +36,7 @@ export const disconnect: MessageHandler = async ({ server, event }) => { server.schema, parse(sub.subscription.query), undefined, - await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }), + await constructContext({ server, connectionInitPayload: sub.connectionInitPayload, connectionId: sub.connectionId }), sub.subscription.variables, sub.subscription.operationName, undefined, diff --git a/lib/messages/ping.ts b/lib/messages/ping.ts index 0e13cfac..0cb162f7 100644 --- a/lib/messages/ping.ts +++ b/lib/messages/ping.ts @@ -1,5 +1,5 @@ import { PingMessage, MessageType } from 'graphql-ws' -import { sendMessage } from '../utils/sendMessage' +import { postToConnection } from '../utils/postToConnection' import { deleteConnection } from '../utils/deleteConnection' import { MessageHandler } from '../types' @@ -7,7 +7,7 @@ import { MessageHandler } from '../types' export const ping: MessageHandler = async ({ server, event, message }) => { try { await server.onPing?.({ event, message }) - return sendMessage(server)({ + return postToConnection(server)({ ...event.requestContext, message: { type: MessageType.Pong }, }) diff --git a/lib/messages/subscribe.ts b/lib/messages/subscribe.ts index 14408f73..9f839a94 100644 --- a/lib/messages/subscribe.ts +++ b/lib/messages/subscribe.ts @@ -8,7 +8,7 @@ import { import { APIGatewayWebSocketEvent, ServerClosure, MessageHandler, SubscribePseudoIterable, PubSubEvent } from '../types' import { constructContext } from '../utils/constructContext' import { getResolverAndArgs } from '../utils/getResolverAndArgs' -import { sendMessage } from '../utils/sendMessage' +import { postToConnection } from '../utils/postToConnection' import { deleteConnection } from '../utils/deleteConnection' import { isArray } from '../utils/isArray' @@ -38,7 +38,7 @@ const setupSubscription: MessageHandler = async ({ server, eve if (errors) { server.log('subscribe:validateError', errors) - return sendMessage(server)({ + return postToConnection(server)({ ...event.requestContext, message: { type: MessageType.Error, @@ -48,7 +48,7 @@ const setupSubscription: MessageHandler = async ({ server, eve }) } - const contextValue = await constructContext({ server, connectionParams: connection.payload, connectionId }) + const contextValue = await constructContext({ server, connectionInitPayload: connection.payload, connectionId }) const execContext = buildExecutionContext( server.schema, @@ -61,7 +61,7 @@ const setupSubscription: MessageHandler = async ({ server, eve ) if (isArray(execContext)) { - return sendMessage(server)({ + return postToConnection(server)({ ...event.requestContext, message: { type: MessageType.Error, @@ -87,7 +87,7 @@ const setupSubscription: MessageHandler = async ({ server, eve const onSubscribeErrors = await onSubscribe?.(root, args, context, info) if (onSubscribeErrors){ server.log('onSubscribe', { onSubscribeErrors }) - return sendMessage(server)({ + return postToConnection(server)({ ...event.requestContext, message: { type: MessageType.Error, @@ -110,7 +110,7 @@ const setupSubscription: MessageHandler = async ({ server, eve ...message.payload, }, connectionId: connection.id, - connectionParams: connection.payload, + connectionInitPayload: connection.payload, requestContext: event.requestContext, ttl: connection.ttl, }) @@ -155,7 +155,7 @@ async function executeQuery(server: ServerClosure, message: SubscribeMessage, co undefined, ) - await sendMessage(server)({ + await postToConnection(server)({ ...event.requestContext, message: { type: MessageType.Next, @@ -164,7 +164,7 @@ async function executeQuery(server: ServerClosure, message: SubscribeMessage, co }, }) - await sendMessage(server)({ + await postToConnection(server)({ ...event.requestContext, message: { type: MessageType.Complete, diff --git a/lib/model/Subscription.ts b/lib/model/Subscription.ts index bf5b0c77..b9981f58 100644 --- a/lib/model/Subscription.ts +++ b/lib/model/Subscription.ts @@ -41,7 +41,7 @@ export class Subscription { /** Redundant copy of connection_init payload */ @attribute() - connectionParams: object + connectionInitPayload: object @attribute() requestContext: APIGatewayWebSocketRequestContext diff --git a/lib/pubsub/complete-test.ts b/lib/pubsub/complete-test.ts index 1f3b5625..2afba159 100644 --- a/lib/pubsub/complete-test.ts +++ b/lib/pubsub/complete-test.ts @@ -12,7 +12,6 @@ describe('pubsub:complete', () => { }) it('takes a topic', async () => { - const server = await mockServerContext() - await complete(server)({ topic: 'Topic12' }) + await complete(mockServerContext())({ topic: 'Topic12' }) }) }) diff --git a/lib/pubsub/complete.ts b/lib/pubsub/complete.ts index d02226e2..d5699c65 100644 --- a/lib/pubsub/complete.ts +++ b/lib/pubsub/complete.ts @@ -3,13 +3,14 @@ import { parse } from 'graphql' import { CompleteMessage, MessageType } from 'graphql-ws' import { buildExecutionContext } from 'graphql/execution/execute' import { ServerClosure, PubSubEvent, SubscribePseudoIterable, ServerInstance } from '../types' -import { sendMessage } from '../utils/sendMessage' +import { postToConnection } from '../utils/postToConnection' import { constructContext } from '../utils/constructContext' import { getResolverAndArgs } from '../utils/getResolverAndArgs' import { isArray } from '../utils/isArray' import { getFilteredSubs } from './getFilteredSubs' -export const complete = (server: ServerClosure): ServerInstance['complete'] => async event => { +export const complete = (serverPromise: Promise): ServerInstance['complete'] => async event => { + const server = await serverPromise const subscriptions = await getFilteredSubs({ server, event }) server.log('pubsub:complete %j', { event, subscriptions }) @@ -18,7 +19,7 @@ export const complete = (server: ServerClosure): ServerInstance['complete'] => a id: sub.subscriptionId, type: MessageType.Complete, } - await sendMessage(server)({ + await postToConnection(server)({ ...sub.requestContext, message, }) @@ -28,7 +29,7 @@ export const complete = (server: ServerClosure): ServerInstance['complete'] => a server.schema, parse(sub.subscription.query), undefined, - await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }), + await constructContext({ server, connectionInitPayload: sub.connectionInitPayload, connectionId: sub.connectionId }), sub.subscription.variables, sub.subscription.operationName, undefined, diff --git a/lib/pubsub/publish.ts b/lib/pubsub/publish.ts index 2772602f..535e0a3b 100644 --- a/lib/pubsub/publish.ts +++ b/lib/pubsub/publish.ts @@ -1,11 +1,12 @@ import { parse, execute } from 'graphql' import { MessageType, NextMessage } from 'graphql-ws' import { ServerClosure, ServerInstance } from '../types' -import { sendMessage } from '../utils/sendMessage' +import { postToConnection } from '../utils/postToConnection' import { constructContext } from '../utils/constructContext' import { getFilteredSubs } from './getFilteredSubs' -export const publish = (server: ServerClosure): ServerInstance['publish'] => async event => { +export const publish = (serverPromise: Promise): ServerInstance['publish'] => async event => { + const server = await serverPromise server.log('pubsub:publish %j', { event }) const subscriptions = await getFilteredSubs({ server, event }) server.log('pubsub:publish %j', { subscriptions: subscriptions.map(({ connectionId, filter, subscription }) => ({ connectionId, filter, subscription }) ) }) @@ -15,7 +16,7 @@ export const publish = (server: ServerClosure): ServerInstance['publish'] => asy server.schema, parse(sub.subscription.query), event, - await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }), + await constructContext({ server, connectionInitPayload: sub.connectionInitPayload, connectionId: sub.connectionId }), sub.subscription.variables, sub.subscription.operationName, undefined, @@ -27,7 +28,7 @@ export const publish = (server: ServerClosure): ServerInstance['publish'] => asy payload, } - await sendMessage(server)({ + await postToConnection(server)({ ...sub.requestContext, message, }) diff --git a/lib/test/execute-helper.ts b/lib/test/execute-helper.ts index 7a707f97..cb572415 100644 --- a/lib/test/execute-helper.ts +++ b/lib/test/execute-helper.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import WebSocket from 'ws' import { deferGenerator } from 'inside-out-async' diff --git a/lib/test/graphql-ws-schema.ts b/lib/test/graphql-ws-schema.ts index 8424e529..76bf4b2b 100644 --- a/lib/test/graphql-ws-schema.ts +++ b/lib/test/graphql-ws-schema.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import ws from 'ws' import { useServer } from 'graphql-ws/lib/use/ws' import { makeExecutableSchema } from '@graphql-tools/schema' @@ -51,7 +52,10 @@ const schema = makeExecutableSchema({ resolvers, }) -export const startGqlWSServer = async () => { +export const startGqlWSServer = async (): Promise<{ + url: string + stop: () => Promise +}> => { const server = new ws.Server({ port: PORT, path: '/', diff --git a/lib/test/integration-events-test.ts b/lib/test/integration-events-test.ts index 5324328a..69a9c600 100644 --- a/lib/test/integration-events-test.ts +++ b/lib/test/integration-events-test.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { assert, use } from 'chai' import { start as sandBoxStart, end as sandBoxStop } from '@architect/sandbox' import { collect, map } from 'streaming-iterables' diff --git a/lib/types.ts b/lib/types.ts index 9bd512b2..2607c2a4 100644 --- a/lib/types.ts +++ b/lib/types.ts @@ -10,10 +10,13 @@ import { Connection } from './model/Connection' export type ServerArgs = { schema: GraphQLSchema - dynamodb: DynamoDB - apiGatewayManagementApi?: ApiGatewayManagementApiSubset - context?: ((arg: { connectionParams: any, connectionId: string }) => MaybePromise) | object - tableNames?: Partial + dynamodb: MaybePromise + apiGatewayManagementApi?: MaybePromise + tableNames?: MaybePromise> + /* + Makes the context object for all operations defaults to { connectionInitPayload, connectionId } + */ + context?: ((arg: { connectionInitPayload: any, connectionId: string }) => MaybePromise) | object pingpong?: { machine: string delay: number @@ -21,7 +24,9 @@ export type ServerArgs = { } onConnect?: (e: { event: APIGatewayWebSocketEvent }) => MaybePromise onDisconnect?: (e: { event: APIGatewayWebSocketEvent }) => MaybePromise - /* Takes connection_init event and returns payload to be persisted (may include auth steps) */ + /* + Takes connection_init event and returns the connectionInitPayload to be persisted. Throw if you'd like the connection to be disconnected. Useful for auth. + */ onConnectionInit?: (e: { event: APIGatewayWebSocketEvent message: ConnectionInitMessage @@ -35,6 +40,9 @@ export type ServerArgs = { message: PongMessage }) => MaybePromise onError?: (error: any, context: any) => MaybePromise + /* + Defaults to debug('graphql-lambda-subscriptions') from https://www.npmjs.com/package/debug + */ log?: LoggerFunction } @@ -47,10 +55,11 @@ export type ServerClosure = { Connection: typeof Connection } log: LoggerFunction -} & Omit + apiGatewayManagementApi?: ApiGatewayManagementApiSubset +} & Omit export interface ServerInstance { - gatewayHandler: ApiGatewayHandler + webSocketHandler: ApiSebSocketHandler stateMachineHandler: (input: StateFunctionInput) => Promise publish: (event: PubSubEvent) => Promise complete: (event: PartialBy) => Promise @@ -63,7 +72,7 @@ export type TableNames = { export type LoggerFunction = (input: string, obj?: any) => void -export type WebsocketResponse = { +export type WebSocketResponse = { statusCode: number headers?: Record body: string @@ -139,6 +148,6 @@ export interface ApiGatewayManagementApiSubset { deleteConnection(input: { ConnectionId: string }): { promise: () => Promise } } -export type ApiGatewayHandler = (event: TEvent) => Promise +export type ApiSebSocketHandler = (event: TEvent) => Promise export type PartialBy = Omit & Partial> diff --git a/lib/utils/constructContext.ts b/lib/utils/constructContext.ts index 7945eda7..d39b01a0 100644 --- a/lib/utils/constructContext.ts +++ b/lib/utils/constructContext.ts @@ -2,9 +2,9 @@ import { ServerClosure } from '../types' // eslint-disable-next-line @typescript-eslint/no-explicit-any -export const constructContext = ({ server, connectionParams, connectionId }: { connectionParams: object, server: ServerClosure, connectionId: string }): any => { +export const constructContext = ({ server, connectionInitPayload, connectionId }: { connectionInitPayload: object, server: ServerClosure, connectionId: string }): any => { if (typeof server.context === 'function') { - return server.context({ connectionParams, connectionId }) + return server.context({ connectionInitPayload, connectionId }) } - return { ...server.context, connectionParams, connectionId } + return { ...server.context, connectionInitPayload, connectionId } } diff --git a/lib/utils/sendMessage.ts b/lib/utils/postToConnection.ts similarity index 94% rename from lib/utils/sendMessage.ts rename to lib/utils/postToConnection.ts index 6c501852..9ad879e6 100644 --- a/lib/utils/sendMessage.ts +++ b/lib/utils/postToConnection.ts @@ -11,7 +11,7 @@ import { ServerClosure } from '../types' type GraphqlWSMessages = ConnectionAckMessage | NextMessage | CompleteMessage | ErrorMessage | PingMessage | PongMessage -export const sendMessage = (server: ServerClosure) => +export const postToConnection = (server: ServerClosure) => async ({ connectionId: ConnectionId, domainName, diff --git a/mocks/arc-basic-events/lib/graphql.js b/mocks/arc-basic-events/lib/graphql.js index 40ea48f2..9a2200fa 100644 --- a/mocks/arc-basic-events/lib/graphql.js +++ b/mocks/arc-basic-events/lib/graphql.js @@ -174,10 +174,43 @@ const buildSubscriptionServer = async () => { apiGatewayManagementApi: makeManagementAPI(), onError: err => { console.log('onError', err.message) - // throw err }, }) return server } -module.exports = { buildSubscriptionServer } +const fetchTableNames = async () => { + const tables = await arcTables() + + const ensureName = (table) => { + const actualTableName = tables.name(table) + if (!actualTableName) { + throw new Error(`No table found for ${table}`) + } + return actualTableName + } + + return { + connections: ensureName('Connection'), + subscriptions: ensureName('Subscription'), + } + +} + +const subscriptionServer = createInstance({ + dynamodb: arcTables.db, + schema, + context: () => { + return { + publish: subscriptionServer.publish, + complete: subscriptionServer.complete, + } + }, + tableNames: fetchTableNames(), + apiGatewayManagementApi: makeManagementAPI(), + onError: err => { + console.log('onError', err.message) + }, +}) + +module.exports = { subscriptionServer, buildSubscriptionServer } diff --git a/mocks/arc-basic-events/src/ws/connect/index.js b/mocks/arc-basic-events/src/ws/connect/index.js index 41d7c5c7..5968bbb3 100644 --- a/mocks/arc-basic-events/src/ws/connect/index.js +++ b/mocks/arc-basic-events/src/ws/connect/index.js @@ -1,9 +1,3 @@ -const { buildSubscriptionServer } = require('../../../lib/graphql') +const { subscriptionServer } = require('../../../lib/graphql') -const serverPromise = buildSubscriptionServer() - -exports.handler = async function connect (event) { - // console.log('connect') - const server = await serverPromise - return server.gatewayHandler(event) -} +exports.handler = subscriptionServer.webSocketHandler diff --git a/mocks/arc-basic-events/src/ws/default/index.js b/mocks/arc-basic-events/src/ws/default/index.js index 4872a408..5968bbb3 100644 --- a/mocks/arc-basic-events/src/ws/default/index.js +++ b/mocks/arc-basic-events/src/ws/default/index.js @@ -1,9 +1,3 @@ -const { buildSubscriptionServer } = require('../../../lib/graphql') +const { subscriptionServer } = require('../../../lib/graphql') -const serverPromise = buildSubscriptionServer() - -exports.handler = async function ws (event) { - // console.log('default', event) - const server = await serverPromise - return server.gatewayHandler(event) -} +exports.handler = subscriptionServer.webSocketHandler diff --git a/mocks/arc-basic-events/src/ws/disconnect/index.js b/mocks/arc-basic-events/src/ws/disconnect/index.js index f9bb02b3..5968bbb3 100644 --- a/mocks/arc-basic-events/src/ws/disconnect/index.js +++ b/mocks/arc-basic-events/src/ws/disconnect/index.js @@ -1,9 +1,3 @@ -const { buildSubscriptionServer } = require('../../../lib/graphql') +const { subscriptionServer } = require('../../../lib/graphql') -const serverPromise = buildSubscriptionServer() - -exports.handler = async function ws (event) { - // console.log('disconnect') - const server = await serverPromise - return server.gatewayHandler(event) -} +exports.handler = subscriptionServer.webSocketHandler diff --git a/package.json b/package.json index 8d50b596..3751d581 100644 --- a/package.json +++ b/package.json @@ -36,7 +36,6 @@ "@aws/dynamodb-data-mapper": "^0.7.3", "@aws/dynamodb-data-mapper-annotations": "^0.7.3", "@aws/dynamodb-expressions": "^0.7.3", - "aggregate-error": "^4.0.0", "debug": "^4.3.2", "streaming-iterables": "^6.0.0" }, @@ -51,19 +50,20 @@ "@microsoft/api-extractor": "^7.18.4", "@types/architect__sandbox": "^3.3.3", "@types/aws-lambda": "^8.10.81", - "@types/chai": "^4.2.19", "@types/chai-subset": "^1.3.3", + "@types/chai": "^4.2.19", "@types/debug": "^4.1.7", "@types/mocha": "^9.0.0", "@types/node": "^16.6.0", "@types/ws": "^7.4.7", "@typescript-eslint/eslint-plugin": "^4.27.0", "@typescript-eslint/parser": "^4.27.0", + "aggregate-error": "^4.0.0", "aws-sdk": ">= 2.844.0", - "chai": "^4.3.4", "chai-subset": "^1.6.0", - "esbuild": "^0.12.20", + "chai": "^4.3.4", "esbuild-register": "^3.0.0", + "esbuild": "^0.12.20", "eslint": "^7.29.0", "graphql": ">= 14.0.0", "graphql-ws": "^5.3.0", diff --git a/rollup.config.js b/rollup.config.js index 2cbeccfb..809b9b22 100644 --- a/rollup.config.js +++ b/rollup.config.js @@ -3,7 +3,9 @@ import resolve from 'rollup-plugin-node-resolve' export default { input: './dist-ts/index.js', plugins: [ - resolve({}), + resolve({ + preferBuiltins: true, + }), ], output: [ { format: 'esm', file: './dist/index-esm.mjs' },