Skip to content
This repository was archived by the owner on May 17, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ import { pong } from './messages/pong'

export const handleGatewayEvent = (server: ServerClosure): ApiGatewayHandler<APIGatewayWebSocketEvent, WebsocketResponse> => async (event) => {
if (!event.requestContext) {
server.log('handleGatewayEvent unknown')
return {
statusCode: 200,
body: '',
}
}

if (event.requestContext.eventType === 'CONNECT') {
server.log('handleGatewayEvent CONNECT', { connectionId: event.requestContext.connectionId })
await server.onConnect?.({ event })
return {
statusCode: 200,
Expand All @@ -33,6 +35,7 @@ export const handleGatewayEvent = (server: ServerClosure): ApiGatewayHandler<API

if (event.requestContext.eventType === 'MESSAGE') {
const message = event.body === null ? null : JSON.parse(event.body)
server.log('handleGatewayEvent MESSAGE', { connectionId: event.requestContext.connectionId, type: message.type })

if (message.type === MessageType.ConnectionInit) {
await connection_init({ server, event, message })
Expand Down Expand Up @@ -76,13 +79,15 @@ export const handleGatewayEvent = (server: ServerClosure): ApiGatewayHandler<API
}

if (event.requestContext.eventType === 'DISCONNECT') {
server.log('handleGatewayEvent DISCONNECT', { connectionId: event.requestContext.connectionId })
await disconnect({ server, event, message: null })
return {
statusCode: 200,
body: '',
}
}

server.log('handleGatewayEvent UNKNOWN', { connectionId: event.requestContext.connectionId })
return {
statusCode: 200,
body: '',
Expand Down
2 changes: 1 addition & 1 deletion lib/index-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ describe('createInstance', () => {
})

after(async () => {
tables.end()
await tables.end()
})

it('is type compatible with aws-lambda handler', async () => {
Expand Down
2 changes: 2 additions & 0 deletions lib/makeServerClosure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ import { ServerArgs, ServerClosure } from './types'
import { createModel } from './model/createModel'
import { Subscription } from './model/Subscription'
import { Connection } from './model/Connection'
import { log } from './utils/logger'

export function makeServerClosure(opts: ServerArgs): ServerClosure {
return {
log: log,
...opts,
model: {
Subscription: createModel({
Expand Down
4 changes: 2 additions & 2 deletions lib/messages/complete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { parse } from 'graphql'
import { CompleteMessage } from 'graphql-ws'
import { buildExecutionContext } from 'graphql/execution/execute'
import { collect } from 'streaming-iterables'
import { SubscribePseudoIterable, MessageHandler } from '../types'
import { SubscribePseudoIterable, MessageHandler, PubSubEvent } from '../types'
import { deleteConnection } from '../utils/deleteConnection'
import { constructContext } from '../utils/constructContext'
import { getResolverAndArgs } from '../utils/getResolverAndArgs'
Expand Down Expand Up @@ -37,7 +37,7 @@ export const complete: MessageHandler<CompleteMessage> =

const [field, root, args, context, info] = getResolverAndArgs(server)(execContext)

const onComplete = (field?.subscribe as SubscribePseudoIterable)?.onComplete
const onComplete = (field?.subscribe as SubscribePseudoIterable<PubSubEvent>)?.onComplete
if (onComplete) {
await onComplete(root, args, context, info)
}
Expand Down
4 changes: 2 additions & 2 deletions lib/messages/disconnect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { equals } from '@aws/dynamodb-expressions'
import { buildExecutionContext } from 'graphql/execution/execute'
import { constructContext } from '../utils/constructContext'
import { getResolverAndArgs } from '../utils/getResolverAndArgs'
import { SubscribePseudoIterable, MessageHandler } from '../types'
import { SubscribePseudoIterable, MessageHandler, PubSubEvent } from '../types'
import { isArray } from '../utils/isArray'
import { collect } from 'streaming-iterables'
import { Connection } from '../model/Connection'
Expand Down Expand Up @@ -49,7 +49,7 @@ export const disconnect: MessageHandler<null> =

const [field, root, args, context, info] = getResolverAndArgs(server)(execContext)

const onComplete = (field?.subscribe as SubscribePseudoIterable)?.onComplete
const onComplete = (field?.subscribe as SubscribePseudoIterable<PubSubEvent>)?.onComplete
if (onComplete) {
await onComplete(root, args, context, info)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/messages/subscribe-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ describe('messages/subscribe', () => {
})

afterEach(async () => {
tables.end()
await tables.end()
})

it('executes a query/mutation', async () => {
Expand Down
28 changes: 16 additions & 12 deletions lib/messages/subscribe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
assertValidExecutionArguments,
execute,
} from 'graphql/execution/execute'
import { APIGatewayWebSocketEvent, ServerClosure, SubscribeHandler, MessageHandler } from '../types'
import { APIGatewayWebSocketEvent, ServerClosure, MessageHandler, SubscribePseudoIterable, PubSubEvent } from '../types'
import { constructContext } from '../utils/constructContext'
import { getResolverAndArgs } from '../utils/getResolverAndArgs'
import { sendMessage } from '../utils/sendMessage'
Expand All @@ -25,9 +25,11 @@ export const subscribe: MessageHandler<SubscribeMessage> =
}

const setupSubscription: MessageHandler<SubscribeMessage> = async ({ server, event, message }) => {
const connectionId = event.requestContext.connectionId

const connection = await server.mapper.get(
Object.assign(new server.model.Connection(), {
id: event.requestContext.connectionId,
id: connectionId,
}),
)
const connectionParams = connection.payload || {}
Expand All @@ -39,7 +41,7 @@ const setupSubscription: MessageHandler<SubscribeMessage> = async ({ server, eve
throw new AggregateError(errors)
}

const contextValue = await constructContext({ server, connectionParams, connectionId: connection.id })
const contextValue = await constructContext({ server, connectionParams, connectionId })

const execContext = buildExecutionContext(
server.schema,
Expand Down Expand Up @@ -74,33 +76,33 @@ const setupSubscription: MessageHandler<SubscribeMessage> = async ({ server, eve
throw new Error('No field')
}

const { topicDefinitions, onSubscribe, onAfterSubscribe } = await (field.subscribe as SubscribeHandler)(
root,
args,
context,
info,
)
const { topicDefinitions, onSubscribe, onAfterSubscribe } = field.subscribe as SubscribePseudoIterable<PubSubEvent>

server.log('onSubscribe', { onSubscribe: !!onSubscribe })
await onSubscribe?.(root, args, context, info)

await Promise.all(topicDefinitions.map(async ({ topic, filter }) => {
const filterData = typeof filter === 'function' ? await filter(root, args, context, info) : filter

const subscription = Object.assign(new server.model.Subscription(), {
id: `${event.requestContext.connectionId}|${message.id}`,
id: `${connectionId}|${message.id}`,
topic,
filter: filter || {},
filter: filterData || {},
subscriptionId: message.id,
subscription: {
variableValues: args,
...message.payload,
},
connectionId: event.requestContext.connectionId,
connectionId,
connectionParams,
requestContext: event.requestContext,
ttl: connection.ttl,
})
server.log('subscribe:putSubscription %j', subscription)
await server.mapper.put(subscription)
}))

server.log('onAfterSubscribe', { onAfterSubscribe: !!onAfterSubscribe })
await onAfterSubscribe?.(root, args, context, info)
}

Expand All @@ -125,6 +127,8 @@ const validateMessage = (server: ServerClosure) => (message: SubscribeMessage) =

// eslint-disable-next-line @typescript-eslint/no-explicit-any
async function executeQuery(server: ServerClosure, message: SubscribeMessage, contextValue: any, event: APIGatewayWebSocketEvent) {
server.log('executeQuery', { connectionId: event.requestContext.connectionId })

const result = await execute(
server.schema,
parse(message.payload.query),
Expand Down
4 changes: 2 additions & 2 deletions lib/model/createModel.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { DynamoDbTable } from '@aws/dynamodb-data-mapper'
import { Class } from '../types'

export const createModel = <T extends Class>({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const createModel = <T extends { new(...args: any[]): any }>({
model,
table,
}: {
Expand Down
18 changes: 18 additions & 0 deletions lib/pubsub/complete-test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { tables } from '@architect/sandbox'
import { mockServerContext } from '../test/mockServer'
import { complete } from './complete'

describe('pubsub:complete', () => {
before(async () => {
await tables.start({ cwd: './mocks/arc-basic-events', quiet: true })
})

after(async () => {
await tables.end()
})

it('takes a topic', async () => {
const server = await mockServerContext()
await complete(server)({ topic: 'Topic12' })
})
})
8 changes: 5 additions & 3 deletions lib/pubsub/complete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@ import AggregateError from 'aggregate-error'
import { parse } from 'graphql'
import { CompleteMessage, MessageType } from 'graphql-ws'
import { buildExecutionContext } from 'graphql/execution/execute'
import { ServerClosure, PubSubEvent, SubscribePseudoIterable } from '../types'
import { ServerClosure, PubSubEvent, SubscribePseudoIterable, PartialBy } from '../types'
import { sendMessage } from '../utils/sendMessage'
import { constructContext } from '../utils/constructContext'
import { getResolverAndArgs } from '../utils/getResolverAndArgs'
import { isArray } from '../utils/isArray'
import { getFilteredSubs } from './getFilteredSubs'

export const complete = (server: ServerClosure) => async (event: PubSubEvent): Promise<void> => {
export const complete = (server: ServerClosure) => async (event: PartialBy<PubSubEvent, 'payload'>): Promise<void> => {
const subscriptions = await getFilteredSubs({ server, event })
server.log('pubsub:complete %j', { event, subscriptions })

const iters = subscriptions.map(async (sub) => {
const message: CompleteMessage = {
id: sub.subscriptionId,
Expand Down Expand Up @@ -38,7 +40,7 @@ export const complete = (server: ServerClosure) => async (event: PubSubEvent): P

const [field, root, args, context, info] = getResolverAndArgs(server)(execContext)

const onComplete = (field?.subscribe as SubscribePseudoIterable)?.onComplete
const onComplete = (field?.subscribe as SubscribePseudoIterable<PubSubEvent>)?.onComplete
if (onComplete) {
await onComplete(root, args, context, info)
}
Expand Down
20 changes: 20 additions & 0 deletions lib/pubsub/getFilteredSubs-test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { assert } from 'chai'
import { collapseKeys } from './getFilteredSubs'

describe('collapseKeys', () => {
it('makes the deep objects into dots', () => {
assert.deepEqual(collapseKeys({}), {})
assert.deepEqual(collapseKeys({ a: 4, b: { c: 5, d: 'hi', e: { f: false } } }), {
a: 4,
'b.c': 5,
'b.d': 'hi',
'b.e.f': false,
})
assert.deepEqual(collapseKeys({ a: [1,2,3, { b: 4, c: [], d: null, e: undefined }] }), {
'a.0': 1,
'a.1': 2,
'a.2': 3,
'a.3.b': 4,
})
})
})
93 changes: 48 additions & 45 deletions lib/pubsub/getFilteredSubs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,44 @@ import {
} from '@aws/dynamodb-expressions'
import { collect } from 'streaming-iterables'
import { Subscription } from '../model/Subscription'
import { ServerClosure, PubSubEvent } from '../types'
import { ServerClosure, PubSubEvent, PartialBy } from '../types'

export const getFilteredSubs = async ({ server, event }: { server: Omit<ServerClosure, 'gateway'>, event: PartialBy<PubSubEvent, 'payload'> }): Promise<Subscription[]> => {
if (!event.payload || Object.keys(event.payload).length === 0) {
server.log('getFilteredSubs %j', { event })

const iterator = server.mapper.query(
server.model.Subscription,
{ topic: equals(event.topic) },
{ indexName: 'TopicIndex' },
)

return await collect(iterator)
}
const flattenPayload = collapseKeys(event.payload)
const conditions: ConditionExpression[] = Object.entries(flattenPayload).map(([key, value]) => ({
type: 'Or',
conditions: [
{
...attributeNotExists(),
subject: `filter.${key}`,
},
{
...equals(value),
subject: `filter.${key}`,
},
],
}))

server.log('getFilteredSubs %j', { event, conditions })

export const getFilteredSubs = async ({ server, event }: { server: Omit<ServerClosure, 'gateway'>, event: PubSubEvent }): Promise<Subscription[]> => {
const flattenPayload = flatten(event.payload)
const iterator = server.mapper.query(
server.model.Subscription,
{ topic: equals(event.topic) },
{
filter: {
type: 'And',
conditions: Object.entries(flattenPayload).reduce(
(p, [key, value]) => [
...p,
{
type: 'Or',
conditions: [
{
...attributeNotExists(),
subject: `filter.${key}`,
},
{
...equals(value),
subject: `filter.${key}`,
},
],
},
],
[] as ConditionExpression[],
),
conditions,
},
indexName: 'TopicIndex',
},
Expand All @@ -43,33 +52,27 @@ export const getFilteredSubs = async ({ server, event }: { server: Omit<ServerCl
return await collect(iterator)
}

export const flatten = (
export const collapseKeys = (
obj: object,
): Record<string, number | string | boolean> => {
if (obj === undefined || obj === null) {
return {}
}
return Object.entries(obj).reduce((p, [k1, v1]) => {
const record = {}
for (const [k1, v1] of Object.entries(obj)) {
if (typeof v1 === 'string' || typeof v1 === 'number' || typeof v1 === 'boolean') {
record[k1] = v1
continue
}

if (v1 && typeof v1 === 'object') {
const next = Object.entries(v1).reduce(
(prev, [k2, v2]) => ({
...prev,
[`${k1}.${k2}`]: v2,
}),
{},
)
return {
...p,
...flatten(next),
const next = {}

for (const [k2, v2] of Object.entries(v1)) {
next[`${k1}.${k2}`] = v2
}
}

if (typeof v1 === 'string' ||
typeof v1 === 'number' ||
typeof v1 === 'boolean') {
return { ...p, [k1]: v1 }
for (const [k1, v1] of Object.entries(collapseKeys(next))) {
record[k1] = v1
}
}

return p
}, {})
}
return record
}
5 changes: 4 additions & 1 deletion lib/pubsub/publish.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import { sendMessage } from '../utils/sendMessage'
import { constructContext } from '../utils/constructContext'
import { getFilteredSubs } from './getFilteredSubs'

export const publish = (server: ServerClosure) => async (event: PubSubEvent): Promise<void> => {
export const publish = (server: ServerClosure) => async <T extends PubSubEvent>(event: T): Promise<void> => {
server.log('pubsub:publish %j', { event })
const subscriptions = await getFilteredSubs({ server, event })
server.log('pubsub:publish %j', { subscriptions: subscriptions.map(({ connectionId, filter, subscription }) => ({ connectionId, filter, subscription }) ) })

const iters = subscriptions.map(async (sub) => {
const payload = await execute(
server.schema,
Expand Down
Loading