diff --git a/.eslintrc.js b/.eslintrc.js index bf287168..a98af63e 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -16,22 +16,24 @@ module.exports = { '@typescript-eslint', ], rules: { - indent: [ 'error', 2 ], - 'linebreak-style': [ 'error', 'unix' ], - quotes: [ 'error', 'single' ], - semi: 'off', - '@typescript-eslint/semi': ['error', 'never'], - 'quote-props': ['error', 'as-needed'], - 'no-param-reassign': 'error', - 'comma-dangle': ['error', 'always-multiline'], - 'space-infix-ops': ['error'], - 'no-multi-spaces': ['error'], - 'no-unused-vars': 'off', - '@typescript-eslint/no-unused-vars': ['error', { argsIgnorePattern: '^_' }], '@typescript-eslint/member-delimiter-style': ['error', { multiline: { delimiter: 'none' }, singleline: { delimiter: 'comma', requireLast: false }, multilineDetection: 'last-member', }], + '@typescript-eslint/no-unused-vars': ['error', { argsIgnorePattern: '^_' }], + '@typescript-eslint/semi': ['error', 'never'], + 'array-bracket-spacing': ['error', 'never', { singleValue: false }], + 'comma-dangle': ['error', 'always-multiline'], + 'linebreak-style': ['error', 'unix'], + 'no-multi-spaces': ['error'], + 'no-param-reassign': 'error', + 'no-unused-vars': 'off', + 'object-curly-spacing': ['error', 'always'], + 'quote-props': ['error', 'as-needed'], + 'space-infix-ops': ['error'], + indent: ['error', 2], + quotes: ['error', 'single'], + semi: 'off', }, } diff --git a/lib/gateway.ts b/lib/gateway.ts index e30ce900..49b27aea 100644 --- a/lib/gateway.ts +++ b/lib/gateway.ts @@ -12,81 +12,79 @@ import { subscribe } from './messages/subscribe' import { connection_init } from './messages/connection_init' import { pong } from './messages/pong' -export const handleGatewayEvent = - (server: ServerClosure): ApiGatewayHandler => - async (event) => { - if (!event.requestContext) { - return { - statusCode: 200, - body: '', - } - } - - if (event.requestContext.eventType === 'CONNECT') { - await server.onConnect?.({ event }) - return { - statusCode: 200, - headers: { - 'Sec-WebSocket-Protocol': GRAPHQL_TRANSPORT_WS_PROTOCOL, - }, - body: '', - } - } - - if (event.requestContext.eventType === 'MESSAGE') { - const message = JSON.parse(event.body!) +export const handleGatewayEvent = (server: ServerClosure): ApiGatewayHandler => async (event) => { + if (!event.requestContext) { + return { + statusCode: 200, + body: '', + } + } - if (message.type === MessageType.ConnectionInit) { - await connection_init({ server, event, message }) - return { - statusCode: 200, - body: '', - } - } + if (event.requestContext.eventType === 'CONNECT') { + await server.onConnect?.({ event }) + return { + statusCode: 200, + headers: { + 'Sec-WebSocket-Protocol': GRAPHQL_TRANSPORT_WS_PROTOCOL, + }, + body: '', + } + } - if (message.type === MessageType.Subscribe) { - await subscribe({ server, event, message }) - return { - statusCode: 200, - body: '', - } - } + if (event.requestContext.eventType === 'MESSAGE') { + const message = event.body === null ? null : JSON.parse(event.body) - if (message.type === MessageType.Complete) { - await complete({ server, event, message }) - return { - statusCode: 200, - body: '', - } - } + if (message.type === MessageType.ConnectionInit) { + await connection_init({ server, event, message }) + return { + statusCode: 200, + body: '', + } + } - if (message.type === MessageType.Ping) { - await ping({ server, event, message }) - return { - statusCode: 200, - body: '', - } - } + if (message.type === MessageType.Subscribe) { + await subscribe({ server, event, message }) + return { + statusCode: 200, + body: '', + } + } - if (message.type === MessageType.Pong) { - await pong({ server, event, message }) - return { - statusCode: 200, - body: '', - } - } + if (message.type === MessageType.Complete) { + await complete({ server, event, message }) + return { + statusCode: 200, + body: '', } + } - if (event.requestContext.eventType === 'DISCONNECT') { - await disconnect({ server, event, message: null }) - return { - statusCode: 200, - body: '', - } + if (message.type === MessageType.Ping) { + await ping({ server, event, message }) + return { + statusCode: 200, + body: '', } + } + if (message.type === MessageType.Pong) { + await pong({ server, event, message }) return { statusCode: 200, body: '', } } + } + + if (event.requestContext.eventType === 'DISCONNECT') { + await disconnect({ server, event, message: null }) + return { + statusCode: 200, + body: '', + } + } + + return { + statusCode: 200, + body: '', + } +} diff --git a/lib/messages/complete.ts b/lib/messages/complete.ts index 13b80c28..82acca74 100644 --- a/lib/messages/complete.ts +++ b/lib/messages/complete.ts @@ -2,53 +2,47 @@ import AggregateError from 'aggregate-error' import { parse } from 'graphql' import { CompleteMessage } from 'graphql-ws' import { buildExecutionContext } from 'graphql/execution/execute' -import { SubscribePsuedoIterable, MessageHandler } from '../types' -import { deleteConnection } from '../utils/aws' -import { constructContext, getResolverAndArgs } from '../utils/graphql' +import { collect } from 'streaming-iterables' +import { SubscribePseudoIterable, MessageHandler } from '../types' +import { deleteConnection } from '../utils/deleteConnection' +import { constructContext } from '../utils/constructContext' +import { getResolverAndArgs } from '../utils/getResolverAndArgs' import { isArray } from '../utils/isArray' /** Handler function for 'complete' message. */ export const complete: MessageHandler = async ({ server, event, message }) => { try { - const topicSubscriptions = server.mapper.query(server.model.Subscription, { - id: `${event.requestContext.connectionId!}|${message.id}`, - }) - let deletions = [] as Promise[] - for await (const entity of topicSubscriptions) { - deletions = [ - ...deletions, - (async () => { - // only call onComplete per subscription - if (deletions.length === 0) { - const execContext = buildExecutionContext( - server.schema, - parse(entity.subscription.query), - undefined, - await constructContext(server)(entity), - entity.subscription.variables, - entity.subscription.operationName, - undefined, - ) - - if (isArray(execContext)) { - throw new AggregateError(execContext) - } + const topicSubscriptions = await collect(server.mapper.query(server.model.Subscription, { + id: `${event.requestContext.connectionId}|${message.id}`, + })) + if (topicSubscriptions.length === 0) { + return + } + // only call onComplete on the first one as any others are duplicates + const sub = topicSubscriptions[0] + const execContext = buildExecutionContext( + server.schema, + parse(sub.subscription.query), + undefined, + await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }), + sub.subscription.variables, + sub.subscription.operationName, + undefined, + ) - const [field, root, args, context, info] = getResolverAndArgs(server)(execContext) + if (isArray(execContext)) { + throw new AggregateError(execContext) + } - const onComplete = (field?.subscribe as SubscribePsuedoIterable)?.onComplete - if (onComplete) { - await onComplete(root, args, context, info) - } - } + const [field, root, args, context, info] = getResolverAndArgs(server)(execContext) - await server.mapper.delete(entity) - })(), - ] + const onComplete = (field?.subscribe as SubscribePseudoIterable)?.onComplete + if (onComplete) { + await onComplete(root, args, context, info) } - await Promise.all(deletions) + await Promise.all(topicSubscriptions.map(sub => server.mapper.delete(sub))) } catch (err) { await server.onError?.(err, { event, message }) await deleteConnection(server)(event.requestContext) diff --git a/lib/messages/connection_init.ts b/lib/messages/connection_init.ts index bcfbc374..5ca65a1f 100644 --- a/lib/messages/connection_init.ts +++ b/lib/messages/connection_init.ts @@ -1,7 +1,8 @@ import { StepFunctions } from 'aws-sdk' import { ConnectionInitMessage, MessageType } from 'graphql-ws' import { StateFunctionInput, MessageHandler } from '../types' -import { deleteConnection, sendMessage } from '../utils/aws' +import { sendMessage } from '../utils/sendMessage' +import { deleteConnection } from '../utils/deleteConnection' /** Handler function for 'connection_init' message. */ export const connection_init: MessageHandler = @@ -15,10 +16,10 @@ export const connection_init: MessageHandler = await new StepFunctions() .startExecution({ stateMachineArn: server.pingpong.machine, - name: event.requestContext.connectionId!, + name: event.requestContext.connectionId, input: JSON.stringify({ - connectionId: event.requestContext.connectionId!, - domainName: event.requestContext.domainName!, + connectionId: event.requestContext.connectionId, + domainName: event.requestContext.domainName, stage: event.requestContext.stage, state: 'PING', choice: 'WAIT', @@ -30,7 +31,7 @@ export const connection_init: MessageHandler = // Write to persistence const connection = Object.assign(new server.model.Connection(), { - id: event.requestContext.connectionId!, + id: event.requestContext.connectionId, requestContext: event.requestContext, payload: res, }) diff --git a/lib/messages/disconnect.ts b/lib/messages/disconnect.ts index 234775fe..6fb0092b 100644 --- a/lib/messages/disconnect.ts +++ b/lib/messages/disconnect.ts @@ -2,9 +2,12 @@ import AggregateError from 'aggregate-error' import { parse } from 'graphql' import { equals } from '@aws/dynamodb-expressions' import { buildExecutionContext } from 'graphql/execution/execute' -import { constructContext, getResolverAndArgs } from '../utils/graphql' -import { SubscribePsuedoIterable, MessageHandler } from '../types' +import { constructContext } from '../utils/constructContext' +import { getResolverAndArgs } from '../utils/getResolverAndArgs' +import { SubscribePseudoIterable, MessageHandler } from '../types' import { isArray } from '../utils/isArray' +import { collect } from 'streaming-iterables' +import { Connection } from '../model/Connection' /** Handler function for 'disconnect' message. */ export const disconnect: MessageHandler = @@ -12,30 +15,30 @@ export const disconnect: MessageHandler = try { await server.onDisconnect?.({ event }) - const entities = server.mapper.query( + const topicSubscriptions = await collect(server.mapper.query( server.model.Subscription, { connectionId: equals(event.requestContext.connectionId), }, { indexName: 'ConnectionIndex' }, - ) + )) const completed = {} as Record - const deletions = [] as Promise[] - for await (const entity of entities) { + const deletions = [] as Promise[] + for (const sub of topicSubscriptions) { deletions.push( (async () => { // only call onComplete per subscription - if (!completed[entity.subscriptionId]) { - completed[entity.subscriptionId] = true + if (!completed[sub.subscriptionId]) { + completed[sub.subscriptionId] = true const execContext = buildExecutionContext( server.schema, - parse(entity.subscription.query), + parse(sub.subscription.query), undefined, - await constructContext(server)(entity), - entity.subscription.variables, - entity.subscription.operationName, + await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }), + sub.subscription.variables, + sub.subscription.operationName, undefined, ) @@ -46,13 +49,13 @@ export const disconnect: MessageHandler = const [field, root, args, context, info] = getResolverAndArgs(server)(execContext) - const onComplete = (field?.subscribe as SubscribePsuedoIterable)?.onComplete + const onComplete = (field?.subscribe as SubscribePseudoIterable)?.onComplete if (onComplete) { await onComplete(root, args, context, info) } } - await server.mapper.delete(entity) + await server.mapper.delete(sub) })(), ) } @@ -63,7 +66,7 @@ export const disconnect: MessageHandler = // Delete connection server.mapper.delete( Object.assign(new server.model.Connection(), { - id: event.requestContext.connectionId!, + id: event.requestContext.connectionId, }), ), ]) diff --git a/lib/messages/ping.ts b/lib/messages/ping.ts index 90a17713..0e13cfac 100644 --- a/lib/messages/ping.ts +++ b/lib/messages/ping.ts @@ -1,18 +1,18 @@ import { PingMessage, MessageType } from 'graphql-ws' -import { deleteConnection, sendMessage } from '../utils/aws' +import { sendMessage } from '../utils/sendMessage' +import { deleteConnection } from '../utils/deleteConnection' import { MessageHandler } from '../types' /** Handler function for 'ping' message. */ -export const ping: MessageHandler = - async ({ server, event, message }) => { - try { - await server.onPing?.({ event, message }) - return sendMessage(server)({ - ...event.requestContext, - message: { type: MessageType.Pong }, - }) - } catch (err) { - await server.onError?.(err, { event, message }) - await deleteConnection(server)(event.requestContext) - } +export const ping: MessageHandler = async ({ server, event, message }) => { + try { + await server.onPing?.({ event, message }) + return sendMessage(server)({ + ...event.requestContext, + message: { type: MessageType.Pong }, + }) + } catch (err) { + await server.onError?.(err, { event, message }) + await deleteConnection(server)(event.requestContext) } +} diff --git a/lib/messages/pong.ts b/lib/messages/pong.ts index 4a04ac79..afea4b76 100644 --- a/lib/messages/pong.ts +++ b/lib/messages/pong.ts @@ -1,5 +1,5 @@ import { PongMessage } from 'graphql-ws' -import { deleteConnection } from '../utils/aws' +import { deleteConnection } from '../utils/deleteConnection' import { MessageHandler } from '../types' /** Handler function for 'pong' message. */ @@ -9,7 +9,7 @@ export const pong: MessageHandler = await server.onPong?.({ event, message }) await server.mapper.update( Object.assign(new server.model.Connection(), { - id: event.requestContext.connectionId!, + id: event.requestContext.connectionId, hasPonged: true, }), { diff --git a/lib/messages/subscribe-test.ts b/lib/messages/subscribe-test.ts index fe06cd09..547418e7 100644 --- a/lib/messages/subscribe-test.ts +++ b/lib/messages/subscribe-test.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { assert } from 'chai' import { tables } from '@architect/sandbox' import { subscribe } from './subscribe' @@ -127,7 +128,7 @@ describe('messages/subscribe', () => { throw new Error('don\'t subscribe!') }, }), - resolve: ({payload}) => { + resolve: ({ payload }) => { return payload }, }, @@ -180,7 +181,7 @@ describe('messages/subscribe', () => { events.push('onAfterSubscribe') }, }), - resolve: ({payload}) => { + resolve: ({ payload }) => { return payload }, }, diff --git a/lib/messages/subscribe.ts b/lib/messages/subscribe.ts index 7b427e06..aa907b11 100644 --- a/lib/messages/subscribe.ts +++ b/lib/messages/subscribe.ts @@ -7,8 +7,10 @@ import { execute, } from 'graphql/execution/execute' import { APIGatewayWebSocketEvent, ServerClosure, SubscribeHandler, MessageHandler } from '../types' -import { constructContext, getResolverAndArgs } from '../utils/graphql' -import { deleteConnection, sendMessage } from '../utils/aws' +import { constructContext } from '../utils/constructContext' +import { getResolverAndArgs } from '../utils/getResolverAndArgs' +import { sendMessage } from '../utils/sendMessage' +import { deleteConnection } from '../utils/deleteConnection' import { isArray } from '../utils/isArray' /** Handler function for 'subscribe' message. */ @@ -25,7 +27,7 @@ export const subscribe: MessageHandler = const setupSubscription: MessageHandler = async ({ server, event, message }) => { const connection = await server.mapper.get( Object.assign(new server.model.Connection(), { - id: event.requestContext.connectionId!, + id: event.requestContext.connectionId, }), ) const connectionParams = connection.payload || {} @@ -37,7 +39,7 @@ const setupSubscription: MessageHandler = async ({ server, eve throw new AggregateError(errors) } - const contextValue = await constructContext(server)({ connectionParams }) + const contextValue = await constructContext({ server, connectionParams, connectionId: connection.id }) const execContext = buildExecutionContext( server.schema, @@ -91,7 +93,7 @@ const setupSubscription: MessageHandler = async ({ server, eve variableValues: args, ...message.payload, }, - connectionId: event.requestContext.connectionId!, + connectionId: event.requestContext.connectionId, connectionParams, requestContext: event.requestContext, ttl: connection.ttl, @@ -121,6 +123,7 @@ const validateMessage = (server: ServerClosure) => (message: SubscribeMessage) = } } +// eslint-disable-next-line @typescript-eslint/no-explicit-any async function executeQuery(server: ServerClosure, message: SubscribeMessage, contextValue: any, event: APIGatewayWebSocketEvent) { const result = await execute( server.schema, diff --git a/lib/model/Subscription.ts b/lib/model/Subscription.ts index cd1dd419..bf5b0c77 100644 --- a/lib/model/Subscription.ts +++ b/lib/model/Subscription.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/ban-types */ import { attribute, diff --git a/lib/model/createModel.ts b/lib/model/createModel.ts index 1055f4eb..606ebb97 100644 --- a/lib/model/createModel.ts +++ b/lib/model/createModel.ts @@ -7,7 +7,7 @@ export const createModel = ({ }: { table: string model: T -}) => { +}): T => { Object.defineProperties(model.prototype, { [DynamoDbTable]: { value: table }, }) diff --git a/lib/pubsub/complete.ts b/lib/pubsub/complete.ts index 74dba807..d06d9cbc 100644 --- a/lib/pubsub/complete.ts +++ b/lib/pubsub/complete.ts @@ -2,30 +2,31 @@ import AggregateError from 'aggregate-error' import { parse } from 'graphql' import { CompleteMessage, MessageType } from 'graphql-ws' import { buildExecutionContext } from 'graphql/execution/execute' -import { ServerClosure, PubSubEvent, SubscribePsuedoIterable } from '../types' -import { sendMessage } from '../utils/aws' -import { constructContext, getResolverAndArgs } from '../utils/graphql' +import { ServerClosure, PubSubEvent, SubscribePseudoIterable } from '../types' +import { sendMessage } from '../utils/sendMessage' +import { constructContext } from '../utils/constructContext' +import { getResolverAndArgs } from '../utils/getResolverAndArgs' import { isArray } from '../utils/isArray' import { getFilteredSubs } from './getFilteredSubs' -export const complete = (c: ServerClosure) => async (event: PubSubEvent): Promise => { - const subscriptions = await getFilteredSubs(c)(event) +export const complete = (server: ServerClosure) => async (event: PubSubEvent): Promise => { + const subscriptions = await getFilteredSubs({ server, event }) const iters = subscriptions.map(async (sub) => { const message: CompleteMessage = { id: sub.subscriptionId, type: MessageType.Complete, } - await sendMessage(c)({ + await sendMessage(server)({ ...sub.requestContext, message, }) - await c.mapper.delete(sub) + await server.mapper.delete(sub) const execContext = buildExecutionContext( - c.schema, + server.schema, parse(sub.subscription.query), undefined, - await constructContext(c)(sub), + await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }), sub.subscription.variables, sub.subscription.operationName, undefined, @@ -35,9 +36,9 @@ export const complete = (c: ServerClosure) => async (event: PubSubEvent): Promis throw new AggregateError(execContext) } - const [field, root, args, context, info] = getResolverAndArgs(c)(execContext) + const [field, root, args, context, info] = getResolverAndArgs(server)(execContext) - const onComplete = (field?.subscribe as SubscribePsuedoIterable)?.onComplete + const onComplete = (field?.subscribe as SubscribePseudoIterable)?.onComplete if (onComplete) { await onComplete(root, args, context, info) } diff --git a/lib/pubsub/getFilteredSubs.ts b/lib/pubsub/getFilteredSubs.ts index f4609235..23a18417 100644 --- a/lib/pubsub/getFilteredSubs.ts +++ b/lib/pubsub/getFilteredSubs.ts @@ -4,50 +4,44 @@ import { equals, ConditionExpression, } from '@aws/dynamodb-expressions' +import { collect } from 'streaming-iterables' import { Subscription } from '../model/Subscription' import { ServerClosure, PubSubEvent } from '../types' -export const getFilteredSubs = (c: Omit) => - async (event: PubSubEvent): Promise => { - const flattenPayload = flatten(event.payload) - const iterator = c.mapper.query( - c.model.Subscription, - { topic: equals(event.topic) }, - { - filter: { - type: 'And', - conditions: Object.entries(flattenPayload).reduce( - (p, [key, value]) => [ - ...p, - { - type: 'Or', - conditions: [ - { - ...attributeNotExists(), - subject: `filter.${key}`, - }, - { - ...equals(value), - subject: `filter.${key}`, - }, - ], - }, - ], - [] as ConditionExpression[], - ), - }, - indexName: 'TopicIndex', +export const getFilteredSubs = async ({ server, event }: { server: Omit, event: PubSubEvent }): Promise => { + const flattenPayload = flatten(event.payload) + const iterator = server.mapper.query( + server.model.Subscription, + { topic: equals(event.topic) }, + { + filter: { + type: 'And', + conditions: Object.entries(flattenPayload).reduce( + (p, [key, value]) => [ + ...p, + { + type: 'Or', + conditions: [ + { + ...attributeNotExists(), + subject: `filter.${key}`, + }, + { + ...equals(value), + subject: `filter.${key}`, + }, + ], + }, + ], + [] as ConditionExpression[], + ), }, - ) + indexName: 'TopicIndex', + }, + ) - // Aggregate all targets - const subs: Subscription[] = [] - for await (const sub of iterator) { - subs.push(sub) - } - - return subs - } + return await collect(iterator) +} export const flatten = ( obj: object, @@ -71,8 +65,8 @@ export const flatten = ( } if (typeof v1 === 'string' || - typeof v1 === 'number' || - typeof v1 === 'boolean') { + typeof v1 === 'number' || + typeof v1 === 'boolean') { return { ...p, [k1]: v1 } } diff --git a/lib/pubsub/publish.ts b/lib/pubsub/publish.ts index eaf45af5..8828d1cb 100644 --- a/lib/pubsub/publish.ts +++ b/lib/pubsub/publish.ts @@ -1,18 +1,18 @@ import { parse, execute } from 'graphql' import { MessageType, NextMessage } from 'graphql-ws' import { PubSubEvent, ServerClosure } from '../types' -import { sendMessage } from '../utils/aws' -import { constructContext } from '../utils/graphql' +import { sendMessage } from '../utils/sendMessage' +import { constructContext } from '../utils/constructContext' import { getFilteredSubs } from './getFilteredSubs' -export const publish = (c: ServerClosure) => async (event: PubSubEvent): Promise => { - const subscriptions = await getFilteredSubs(c)(event) +export const publish = (server: ServerClosure) => async (event: PubSubEvent): Promise => { + const subscriptions = await getFilteredSubs({ server, event }) const iters = subscriptions.map(async (sub) => { const payload = await execute( - c.schema, + server.schema, parse(sub.subscription.query), event, - await constructContext(c)(sub), + await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }), sub.subscription.variables, sub.subscription.operationName, undefined, @@ -24,7 +24,7 @@ export const publish = (c: ServerClosure) => async (event: PubSubEvent): Promise payload, } - await sendMessage(c)({ + await sendMessage(server)({ ...sub.requestContext, message, }) diff --git a/lib/pubsub/subscribe.ts b/lib/pubsub/subscribe.ts index 9db1afb9..927f412d 100644 --- a/lib/pubsub/subscribe.ts +++ b/lib/pubsub/subscribe.ts @@ -1,4 +1,4 @@ -import { SubscribeArgs, SubscribeHandler, SubscribeOptions, SubscribePsuedoIterable, SubscriptionDefinition } from '../types' +import { SubscribeArgs, SubscribeHandler, SubscribeOptions, SubscribePseudoIterable, SubscriptionDefinition } from '../types' /** Creates subscribe handler */ export const subscribe = ( @@ -8,7 +8,7 @@ export const subscribe = ( onSubscribe, onComplete, onAfterSubscribe, - }: SubscribeOptions = {}): (...args: SubscribeArgs) => SubscribePsuedoIterable => { + }: SubscribeOptions = {}): (...args: SubscribeArgs) => SubscribePseudoIterable => { return (...args: SubscribeArgs) => { const handler = createHandler([{ topic, @@ -24,17 +24,13 @@ export const subscribe = ( } /** Merge multiple subscribe handlers */ -export const concat = - (...handlers: SubscribeHandler[]) => - (...args: SubscribeArgs): SubscribePsuedoIterable => - createHandler( handlers.map((h) => h(...args).topicDefinitions).flat() ) +export const concat = (...handlers: SubscribeHandler[]) => (...args: SubscribeArgs): SubscribePseudoIterable => createHandler(handlers.map((h) => h(...args).topicDefinitions).flat()) const createHandler = (topicDefinitions: SubscriptionDefinition[]) => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - // eslint-disable-next-line require-yield - const handler: any = async function *() { + // eslint-disable-next-line @typescript-eslint/no-explicit-any,require-yield + const handler: any = async function* () { throw new Error('Subscription handler should not have been called') } handler.topicDefinitions = topicDefinitions - return handler as SubscribePsuedoIterable + return handler as SubscribePseudoIterable } diff --git a/lib/stepFunctionHandler.ts b/lib/stepFunctionHandler.ts index 8c18097e..6b26f159 100644 --- a/lib/stepFunctionHandler.ts +++ b/lib/stepFunctionHandler.ts @@ -1,40 +1,42 @@ import { MessageType } from 'graphql-ws' import { ServerClosure, StateFunctionInput } from './types' -import { deleteConnection, sendMessage } from './utils/aws' +import { sendMessage } from './utils/sendMessage' +import { deleteConnection } from './utils/deleteConnection' -export const handleStateMachineEvent = - (c: ServerClosure) => - async (input: StateFunctionInput): Promise => { - const connection = Object.assign(new c.model.Connection(), { - id: input.connectionId, - }) +export const handleStateMachineEvent = (c: ServerClosure) => async (input: StateFunctionInput): Promise => { + if (!c.pingpong) { + throw new Error('Invalid pingpong settings') + } + const connection = Object.assign(new c.model.Connection(), { + id: input.connectionId, + }) - // Initial state - send ping message - if (input.state === 'PING') { - await sendMessage(c)({ ...input, message: { type: MessageType.Ping } }) - await c.mapper.update(Object.assign(connection, { hasPonged: false }), { - onMissing: 'skip', - }) - return { - ...input, - state: 'REVIEW', - seconds: c.pingpong!.delay, - } - } - - // Follow up state - check if pong was returned - const conn = await c.mapper.get(connection) - if (conn.hasPonged) { - return { - ...input, - state: 'PING', - seconds: c.pingpong!.timeout, - } - } + // Initial state - send ping message + if (input.state === 'PING') { + await sendMessage(c)({ ...input, message: { type: MessageType.Ping } }) + await c.mapper.update(Object.assign(connection, { hasPonged: false }), { + onMissing: 'skip', + }) + return { + ...input, + state: 'REVIEW', + seconds: c.pingpong.delay, + } + } - await deleteConnection(c)({ ...input }) - return { - ...input, - state: 'ABORT', - } + // Follow up state - check if pong was returned + const conn = await c.mapper.get(connection) + if (conn.hasPonged) { + return { + ...input, + state: 'PING', + seconds: c.pingpong.timeout, } + } + + await deleteConnection(c)({ ...input }) + return { + ...input, + state: 'ABORT', + } +} diff --git a/lib/test/integration-basic-events-test.ts b/lib/test/integration-basic-events-test.ts index 6e67a712..ae164639 100644 --- a/lib/test/integration-basic-events-test.ts +++ b/lib/test/integration-basic-events-test.ts @@ -37,7 +37,7 @@ const executeSubscription = async (query: string) => { const unsubscribe = client.subscribe( { query }, { - next: ({data}) => { + next: ({ data }) => { values.queueValue(data) }, error: (error: Error) => { diff --git a/lib/test/mockServer.ts b/lib/test/mockServer.ts index 8dc9b11f..b7778caf 100644 --- a/lib/test/mockServer.ts +++ b/lib/test/mockServer.ts @@ -22,7 +22,7 @@ const resolvers = { Subscription: { greetings:{ subscribe: subscribe('greetings'), - resolve: ({payload}) => { + resolve: ({ payload }) => { return payload }, }, @@ -34,6 +34,7 @@ const schema = makeExecutableSchema({ resolvers, }) +// eslint-disable-next-line @typescript-eslint/no-explicit-any const ensureName = (tables: any, table: string) => { const actualTableName = tables.name(table) if (!actualTableName) { diff --git a/lib/types.ts b/lib/types.ts index ab0cd069..d30ce0b8 100644 --- a/lib/types.ts +++ b/lib/types.ts @@ -12,7 +12,7 @@ export type ServerArgs = { schema: GraphQLSchema dynamodb: DynamoDB apiGatewayManagementApi?: ApiGatewayManagementApiSubset - context?: ((arg: { connectionParams: any }) => object) | object + context?: ((arg: { connectionParams: any, connectionId: string }) => MaybePromise) | object tableNames?: Partial pingpong?: { machine: string @@ -34,7 +34,7 @@ export type ServerArgs = { event: APIGatewayWebSocketEvent message: PongMessage }) => MaybePromise - onError?: (error: any, context: any) => void + onError?: (error: any, context: any) => MaybePromise } export type MaybePromise = T | Promise @@ -70,14 +70,14 @@ export type SubscriptionDefinition = { filter?: object | (() => void) } -export type SubscribeHandler = (...args: any[]) => SubscribePsuedoIterable +export type SubscribeHandler = (...args: any[]) => SubscribePseudoIterable -export type SubscribePsuedoIterable = { +export type SubscribePseudoIterable = { (...args: SubscribeArgs): AsyncGenerator topicDefinitions: SubscriptionDefinition[] - onSubscribe?: (...args: SubscribeArgs) => void | Promise - onComplete?: (...args: SubscribeArgs) => void | Promise - onAfterSubscribe?: (...args: SubscribeArgs) => PubSubEvent | Promise | void | Promise + onSubscribe?: (...args: SubscribeArgs) => MaybePromise + onComplete?: (...args: SubscribeArgs) => MaybePromise + onAfterSubscribe?: (...args: SubscribeArgs) => MaybePromise } export type SubscribeArgs = [root: any, args: Record, context: any, info: GraphQLResolveInfo] @@ -92,8 +92,7 @@ export type StateFunctionInput = { seconds: number } -export interface APIGatewayWebSocketRequestContext - extends APIGatewayEventRequestContext { +export interface APIGatewayWebSocketRequestContext extends APIGatewayEventRequestContext { connectionId: string domainName: string } @@ -119,9 +118,9 @@ export interface ApiGatewayManagementApiSubset { export interface SubscribeOptions { filter?: object | ((...args: SubscribeArgs) => object) - onSubscribe?: (...args: SubscribeArgs) => void | Promise - onComplete?: (...args: SubscribeArgs) => void | Promise - onAfterSubscribe?: (...args: SubscribeArgs) => PubSubEvent | Promise | void | Promise + onSubscribe?: (...args: SubscribeArgs) => MaybePromise + onComplete?: (...args: SubscribeArgs) => MaybePromise + onAfterSubscribe?: (...args: SubscribeArgs) => MaybePromise } -export type ApiGatewayHandler = (event: TEvent) => void | Promise +export type ApiGatewayHandler = (event: TEvent) => Promise diff --git a/lib/utils/aws.ts b/lib/utils/aws.ts deleted file mode 100644 index 6edd07c9..00000000 --- a/lib/utils/aws.ts +++ /dev/null @@ -1,64 +0,0 @@ -import { ApiGatewayManagementApi } from 'aws-sdk' -import { - ConnectionAckMessage, - NextMessage, - CompleteMessage, - ErrorMessage, - PingMessage, - PongMessage, -} from 'graphql-ws' -import { ServerClosure, APIGatewayWebSocketRequestContext } from '../types' - -export const sendMessage = - (c: ServerClosure) => - async ({ - connectionId: ConnectionId, - domainName, - stage, - message, - }: { - message: - | ConnectionAckMessage - | NextMessage - | CompleteMessage - | ErrorMessage - | PingMessage - | PongMessage - } & Pick< - APIGatewayWebSocketRequestContext, - 'connectionId' | 'domainName' | 'stage' - >): Promise => { - const api = - c.apiGatewayManagementApi ?? - new ApiGatewayManagementApi({ - apiVersion: 'latest', - endpoint: `${domainName}/${stage}`, - }) - - await api - .postToConnection({ - ConnectionId, - Data: JSON.stringify(message), - }) - .promise() - } - -export const deleteConnection = - (c: ServerClosure) => - async ({ - connectionId: ConnectionId, - domainName, - stage, - }: Pick< - APIGatewayWebSocketRequestContext, - 'connectionId' | 'domainName' | 'stage' - >): Promise => { - const api = - c.apiGatewayManagementApi ?? - new ApiGatewayManagementApi({ - apiVersion: 'latest', - endpoint: `${domainName}/${stage}`, - }) - - await api.deleteConnection({ ConnectionId }).promise() - } diff --git a/lib/utils/constructContext.ts b/lib/utils/constructContext.ts new file mode 100644 index 00000000..7945eda7 --- /dev/null +++ b/lib/utils/constructContext.ts @@ -0,0 +1,10 @@ +/* eslint-disable @typescript-eslint/ban-types */ +import { ServerClosure } from '../types' + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export const constructContext = ({ server, connectionParams, connectionId }: { connectionParams: object, server: ServerClosure, connectionId: string }): any => { + if (typeof server.context === 'function') { + return server.context({ connectionParams, connectionId }) + } + return { ...server.context, connectionParams, connectionId } +} diff --git a/lib/utils/deleteConnection.ts b/lib/utils/deleteConnection.ts new file mode 100644 index 00000000..ec80a5bc --- /dev/null +++ b/lib/utils/deleteConnection.ts @@ -0,0 +1,21 @@ +import { ApiGatewayManagementApi } from 'aws-sdk' +import { ServerClosure } from '../types' + +export const deleteConnection = (c: ServerClosure) => + async ({ + connectionId: ConnectionId, + domainName, + stage, + }:{ + connectionId: string + domainName: string + stage: string + }): Promise => { + const api = c.apiGatewayManagementApi ?? + new ApiGatewayManagementApi({ + apiVersion: 'latest', + endpoint: `${domainName}/${stage}`, + }) + + await api.deleteConnection({ ConnectionId }).promise() + } diff --git a/lib/utils/getResolverAndArgs.ts b/lib/utils/getResolverAndArgs.ts new file mode 100644 index 00000000..179143a2 --- /dev/null +++ b/lib/utils/getResolverAndArgs.ts @@ -0,0 +1,50 @@ +import { getOperationRootType } from 'graphql' +import { + buildResolveInfo, + collectFields, + ExecutionContext, + getFieldDef, +} from 'graphql/execution/execute' +import { addPath } from 'graphql/jsutils/Path' +import { ServerClosure } from '../types' + +type ResolverAndArgs = [ReturnType, null, ExecutionContext['variableValues'], ExecutionContext['contextValue'], ReturnType] + +export const getResolverAndArgs = (c: Omit) => (execContext: ExecutionContext): ResolverAndArgs => { + // Taken from graphql js - https://github.com/graphql/graphql-js/blob/main/src/subscription/subscribe.js#L190 + const type = getOperationRootType(c.schema, execContext.operation) + const fields = collectFields( + execContext, + type, + execContext.operation.selectionSet, + Object.create(null), + Object.create(null), + ) + const responseNames = Object.keys(fields) + const responseName = responseNames[0] + const fieldNodes = fields[responseName] + const fieldNode = fieldNodes[0] + const fieldName = fieldNode.name.value + const fieldDef = getFieldDef(c.schema, type, fieldName) + const path = addPath(undefined, responseName, type.name) + + if (!fieldDef) { + throw new Error('invalid schema, unknown field definition') + } + + const info = buildResolveInfo( + execContext, + fieldDef, + fieldNodes, + type, + path, + ) + + return [ + fieldDef, + null, + execContext.variableValues, + execContext.contextValue, + info, + ] +} diff --git a/lib/utils/graphql.ts b/lib/utils/graphql.ts deleted file mode 100644 index 40728aa8..00000000 --- a/lib/utils/graphql.ts +++ /dev/null @@ -1,52 +0,0 @@ -/* eslint-disable @typescript-eslint/ban-types */ -import { getOperationRootType } from 'graphql' -import { - buildResolveInfo, - collectFields, - ExecutionContext, - getFieldDef, -} from 'graphql/execution/execute' -import { addPath } from 'graphql/jsutils/Path' -import { ServerClosure } from '../types' - -export const constructContext = - (c: ServerClosure) => - ({ connectionParams }: { connectionParams: object }) => - typeof c.context === 'function' - ? c.context({ connectionParams }) - : { ...c.context, connectionParams } - -export const getResolverAndArgs = - (c: Omit) => (execContext: ExecutionContext) => { - // Taken from graphql js - https://github.com/graphql/graphql-js/blob/main/src/subscription/subscribe.js#L190 - const type = getOperationRootType(c.schema, execContext.operation) - const fields = collectFields( - execContext, - type, - execContext.operation.selectionSet, - Object.create(null), - Object.create(null), - ) - const responseNames = Object.keys(fields) - const responseName = responseNames[0] - const fieldNodes = fields[responseName] - const fieldNode = fieldNodes[0] - const fieldName = fieldNode.name.value - const fieldDef = getFieldDef(c.schema, type, fieldName) - const path = addPath(undefined, responseName, type.name) - const info = buildResolveInfo( - execContext, - fieldDef!, - fieldNodes, - type, - path, - ) - - return [ - fieldDef, - null, - execContext.variableValues, - execContext.contextValue, - info, - ] as const - } diff --git a/lib/utils/sendMessage.ts b/lib/utils/sendMessage.ts new file mode 100644 index 00000000..53540fa0 --- /dev/null +++ b/lib/utils/sendMessage.ts @@ -0,0 +1,38 @@ +import { ApiGatewayManagementApi } from 'aws-sdk' +import { + ConnectionAckMessage, + NextMessage, + CompleteMessage, + ErrorMessage, + PingMessage, + PongMessage, +} from 'graphql-ws' +import { ServerClosure } from '../types' + +type GraphqlWSMessages = ConnectionAckMessage | NextMessage | CompleteMessage | ErrorMessage | PingMessage | PongMessage + +export const sendMessage = (c: ServerClosure) => + async ({ + connectionId: ConnectionId, + domainName, + stage, + message, + }: { + connectionId: string + domainName: string + stage: string + message: GraphqlWSMessages + }): Promise => { + const api = c.apiGatewayManagementApi ?? + new ApiGatewayManagementApi({ + apiVersion: 'latest', + endpoint: `${domainName}/${stage}`, + }) + + await api + .postToConnection({ + ConnectionId, + Data: JSON.stringify(message), + }) + .promise() + } diff --git a/mocks/arc-basic-events/lib/graphql.js b/mocks/arc-basic-events/lib/graphql.js index 4a972c3b..c0583958 100644 --- a/mocks/arc-basic-events/lib/graphql.js +++ b/mocks/arc-basic-events/lib/graphql.js @@ -39,14 +39,14 @@ const resolvers = { Subscription: { greetings:{ subscribe: subscribe('greetings', { - async onAfterSubscribe(_, __, {publish, complete}) { + async onAfterSubscribe(_, __, { publish, complete }) { await publish({ topic: 'greetings', payload: 'yoyo' }) await publish({ topic: 'greetings', payload: 'hows it' }) await publish({ topic: 'greetings', payload: 'howdy' }) await complete({ topic: 'greetings', payload: 'wtf' }) }, }), - resolve: ({payload}) => { + resolve: ({ payload }) => { return payload }, }, diff --git a/package.json b/package.json index c36a034f..090302d4 100644 --- a/package.json +++ b/package.json @@ -37,7 +37,7 @@ "@aws/dynamodb-data-mapper-annotations": "^0.7.3", "@aws/dynamodb-expressions": "^0.7.3", "aggregate-error": "^4.0.0", - "esbuild": "^0.12.20" + "streaming-iterables": "^6.0.0" }, "peerDependencies": { "aws-sdk": ">= 2.844.0", @@ -58,15 +58,15 @@ "aws-sdk": ">= 2.844.0", "chai": "^4.3.4", "esbuild-register": "^3.0.0", + "esbuild": "^0.12.20", "eslint": "^7.29.0", - "graphql": ">= 14.0.0", "graphql-ws": "^5.3.0", + "graphql": ">= 14.0.0", "inside-out-async": "^1.0.0", "mocha": "^9.0.1", - "rollup": "^2.56.0", "rollup-plugin-node-resolve": "^5.2.0", + "rollup": "^2.56.0", "semantic-release": "^17.4.4", - "streaming-iterables": "^6.0.0", "ts-node": "^10.2.0", "tslib": "^2.3.0", "typescript": "^4.3.5",