From c7938c717f333453540e864a65c5b6f591eec634 Mon Sep 17 00:00:00 2001 From: Francis Gulotta Date: Mon, 16 Aug 2021 10:20:15 -0400 Subject: [PATCH] fix: onSubscribe callback type also test it --- lib/messages/subscribe-test.ts | 107 +++++++++++++++++++++++++++++++-- lib/types.ts | 4 +- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/lib/messages/subscribe-test.ts b/lib/messages/subscribe-test.ts index 6c3241dc..fe06cd09 100644 --- a/lib/messages/subscribe-test.ts +++ b/lib/messages/subscribe-test.ts @@ -5,17 +5,19 @@ import { mockServerContext } from '../test/mockServer' import { connection_init } from './connection_init' import { equals } from '@aws/dynamodb-expressions' import { collect } from 'streaming-iterables' +import { subscribe as pubsubSubscribe } from '../pubsub/subscribe' +import { makeExecutableSchema } from '@graphql-tools/schema' const connectionId = '7rWmyMbMr' const ConnectionId = connectionId const connectionInitEvent: any = { requestContext: { connectedAt: 1628905962601, connectionId, domainName: 'localhost:6001', eventType: 'MESSAGE', messageDirection: 'IN', messageId: 'Pn6evkpk2', requestId: 'gN1MPybyL', requestTimeEpoch: 1628905962602, routeKey: '$default', stage: 'testing' }, isBase64Encoded: false, body: '{"type":"connection_init"}' } describe('messages/subscribe', () => { - before(async () => { + beforeEach(async () => { await tables.start({ cwd: './mocks/arc-basic-events', quiet: true }) }) - after(async () => { + afterEach(async () => { tables.end() }) @@ -101,7 +103,104 @@ describe('messages/subscribe', () => { assert.match(error.message, /Cannot query field "HIHOWEAREYOU" on type "Query"/ ) }) describe('callbacks', () => { - it('fires onSubscribe before subscribing') - it('fires onAfterSubscribe after subscribing') + it('fires onSubscribe before subscribing', async () => { + + const onSubscribe: string[] = [] + + const typeDefs = ` + type Query { + hello: String + } + type Subscription { + greetings: String + } + ` + const resolvers = { + Query: { + hello: () => 'Hello World!', + }, + Subscription: { + greetings:{ + subscribe: pubsubSubscribe('greetings', { + onSubscribe() { + onSubscribe.push('We did it!') + throw new Error('don\'t subscribe!') + }, + }), + resolve: ({payload}) => { + return payload + }, + }, + }, + } + + const schema = makeExecutableSchema({ + typeDefs, + resolvers, + }) + const server = await mockServerContext({ + schema, + }) + const event: any = { requestContext: { connectedAt: 1628889984369, connectionId, domainName: 'localhost:3339', eventType: 'MESSAGE', messageDirection: 'IN', messageId: 'el4MNdOJy', requestId: '0yd7bkvXz', requestTimeEpoch: 1628889984774, routeKey: '$default', stage: 'testing' }, isBase64Encoded: false, body: '{"id":"1234","type":"subscribe","payload":{"query":"subscription { greetings }"}}' } + + await connection_init({ server, event: connectionInitEvent, message: JSON.parse(connectionInitEvent.body) }) + try { + await subscribe({ server, event, message: JSON.parse(event.body) }) + throw new Error('should not have subscribed') + } catch (error) { + assert.equal(error.message, 'don\'t subscribe!') + } + assert.deepEqual(onSubscribe, ['We did it!']) + const subscriptions = await collect(server.mapper.query(server.model.Subscription, { connectionId: equals(event.requestContext.connectionId) }, { indexName: 'ConnectionIndex' })) + assert.isEmpty(subscriptions) + + }) + it('fires onAfterSubscribe after subscribing', async () => { + const events: string[] = [] + + const typeDefs = ` + type Query { + hello: String + } + type Subscription { + greetings: String + } + ` + const resolvers = { + Query: { + hello: () => 'Hello World!', + }, + Subscription: { + greetings:{ + subscribe: pubsubSubscribe('greetings', { + onSubscribe() { + events.push('onSubscribe') + }, + onAfterSubscribe() { + events.push('onAfterSubscribe') + }, + }), + resolve: ({payload}) => { + return payload + }, + }, + }, + } + + const schema = makeExecutableSchema({ + typeDefs, + resolvers, + }) + const server = await mockServerContext({ + schema, + }) + const event: any = { requestContext: { connectedAt: 1628889984369, connectionId, domainName: 'localhost:3339', eventType: 'MESSAGE', messageDirection: 'IN', messageId: 'el4MNdOJy', requestId: '0yd7bkvXz', requestTimeEpoch: 1628889984774, routeKey: '$default', stage: 'testing' }, isBase64Encoded: false, body: '{"id":"1234","type":"subscribe","payload":{"query":"subscription { greetings }"}}' } + + await connection_init({ server, event: connectionInitEvent, message: JSON.parse(connectionInitEvent.body) }) + await subscribe({ server, event, message: JSON.parse(event.body) }) + assert.deepEqual(events, ['onSubscribe', 'onAfterSubscribe']) + const subscriptions = await collect(server.mapper.query(server.model.Subscription, { connectionId: equals(event.requestContext.connectionId) }, { indexName: 'ConnectionIndex' })) + assert.isNotEmpty(subscriptions) + }) }) }) diff --git a/lib/types.ts b/lib/types.ts index 39aed707..ab0cd069 100644 --- a/lib/types.ts +++ b/lib/types.ts @@ -77,7 +77,7 @@ export type SubscribePsuedoIterable = { topicDefinitions: SubscriptionDefinition[] onSubscribe?: (...args: SubscribeArgs) => void | Promise onComplete?: (...args: SubscribeArgs) => void | Promise - onAfterSubscribe?: (...args: SubscribeArgs) => PubSubEvent | Promise | undefined | Promise + onAfterSubscribe?: (...args: SubscribeArgs) => PubSubEvent | Promise | void | Promise } export type SubscribeArgs = [root: any, args: Record, context: any, info: GraphQLResolveInfo] @@ -121,7 +121,7 @@ export interface SubscribeOptions { filter?: object | ((...args: SubscribeArgs) => object) onSubscribe?: (...args: SubscribeArgs) => void | Promise onComplete?: (...args: SubscribeArgs) => void | Promise - onAfterSubscribe?: (...args: SubscribeArgs) => PubSubEvent | Promise | undefined | Promise + onAfterSubscribe?: (...args: SubscribeArgs) => PubSubEvent | Promise | void | Promise } export type ApiGatewayHandler = (event: TEvent) => void | Promise