Skip to content

Commit

Permalink
Speed up subscription behavior by tracking state in middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
markerikson committed Oct 8, 2022
1 parent 73abb6a commit 6975282
Show file tree
Hide file tree
Showing 20 changed files with 456 additions and 191 deletions.
96 changes: 67 additions & 29 deletions packages/toolkit/src/query/core/buildMiddleware/batchActions.ts
@@ -1,5 +1,7 @@
import type { QueryThunk, RejectedAction } from '../buildThunks'
import type { InternalHandlerBuilder } from './types'
import type { SubscriptionState } from '../apiState'
import { produceWithPatches } from 'immer'

// Copied from https://github.com/feross/queue-microtask
let promise: Promise<any>
Expand All @@ -14,39 +16,75 @@ const queueMicrotaskShim =
}, 0)
)

export const buildBatchedActionsHandler: InternalHandlerBuilder<boolean> = ({
api,
queryThunk,
}) => {
let abortedQueryActionsQueue: RejectedAction<QueryThunk, any>[] = []
export const buildBatchedActionsHandler: InternalHandlerBuilder<
[actionShouldContinue: boolean, subscriptionExists: boolean]
> = ({ api, queryThunk }) => {
const { actuallyMutateSubscriptions } = api.internalActions
const subscriptionsPrefix = `${api.reducerPath}/subscriptions`

let previousSubscriptions: SubscriptionState =
null as unknown as SubscriptionState

let dispatchQueued = false

return (action, mwApi) => {
if (queryThunk.rejected.match(action)) {
const { condition, arg } = action.meta

if (condition && arg.subscribe) {
// request was aborted due to condition (another query already running)
// _Don't_ dispatch right away - queue it for a debounced grouped dispatch
abortedQueryActionsQueue.push(action)

if (!dispatchQueued) {
queueMicrotaskShim(() => {
mwApi.dispatch(
api.internalActions.subscriptionRequestsRejected(
abortedQueryActionsQueue
)
)
abortedQueryActionsQueue = []
dispatchQueued = false
})
dispatchQueued = true
}
// _Don't_ let the action reach the reducers now!
return false
return (action, mwApi, internalState) => {
if (!previousSubscriptions) {
// Initialize it the first time this handler runs
previousSubscriptions = JSON.parse(
JSON.stringify(internalState.currentSubscriptions)
)
}

// Intercept requests by hooks to see if they're subscribed
// Necessary because we delay updating store state to the end of the tick
if (api.internalActions.internal_probeSubscription.match(action)) {
const { queryCacheKey, requestId } = action.payload
const hasSubscription =
!!internalState.currentSubscriptions[queryCacheKey]?.[requestId]
return [false, hasSubscription]
}

// Update subscription data based on this action
const didMutate = actuallyMutateSubscriptions(
internalState.currentSubscriptions,
action
)

if (didMutate) {
if (!dispatchQueued) {
queueMicrotaskShim(() => {
// Deep clone the current subscription data
const newSubscriptions: SubscriptionState = JSON.parse(
JSON.stringify(internalState.currentSubscriptions)
)
// Figure out a smaller diff between original and current
const [, patches] = produceWithPatches(
previousSubscriptions,
() => newSubscriptions
)

// Sync the store state for visibility
mwApi.next(api.internalActions.subscriptionsUpdated(patches))
// Save the cloned state for later reference
previousSubscriptions = newSubscriptions
dispatchQueued = false
})
dispatchQueued = true
}

const isSubscriptionSliceAction =
!!action.type?.startsWith(subscriptionsPrefix)
const isAdditionalSubscriptionAction =
queryThunk.rejected.match(action) &&
action.meta.condition &&
!!action.meta.arg.subscribe

const actionShouldContinue =
!isSubscriptionSliceAction && !isAdditionalSubscriptionAction

return [actionShouldContinue, false]
}

return true
return [true, false]
}
}
19 changes: 13 additions & 6 deletions packages/toolkit/src/query/core/buildMiddleware/cacheCollection.ts
Expand Up @@ -7,6 +7,7 @@ import type {
TimeoutId,
InternalHandlerBuilder,
ApiMiddlewareInternalHandler,
InternalMiddlewareState,
} from './types'

export type ReferenceCacheCollection = never
Expand Down Expand Up @@ -54,16 +55,19 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({

function anySubscriptionsRemainingForKey(
queryCacheKey: string,
api: SubMiddlewareApi
internalState: InternalMiddlewareState
) {
const subscriptions =
api.getState()[reducerPath].subscriptions[queryCacheKey]
const subscriptions = internalState.currentSubscriptions[queryCacheKey]
return !!subscriptions && !isObjectEmpty(subscriptions)
}

const currentRemovalTimeouts: QueryStateMeta<TimeoutId> = {}

const handler: ApiMiddlewareInternalHandler = (action, mwApi) => {
const handler: ApiMiddlewareInternalHandler = (
action,
mwApi,
internalState
) => {
if (unsubscribeQueryResult.match(action)) {
const state = mwApi.getState()[reducerPath]
const { queryCacheKey } = action.payload
Expand All @@ -72,6 +76,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
queryCacheKey,
state.queries[queryCacheKey]?.endpointName,
mwApi,
internalState,
state.config
)
}
Expand All @@ -94,6 +99,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
queryCacheKey as QueryCacheKey,
queryState?.endpointName,
mwApi,
internalState,
state.config
)
}
Expand All @@ -104,6 +110,7 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
queryCacheKey: QueryCacheKey,
endpointName: string | undefined,
api: SubMiddlewareApi,
internalState: InternalMiddlewareState,
config: ConfigState<string>
) {
const endpointDefinition = context.endpointDefinitions[
Expand All @@ -125,13 +132,13 @@ export const buildCacheCollectionHandler: InternalHandlerBuilder = ({
Math.min(keepUnusedDataFor, THIRTY_TWO_BIT_MAX_TIMER_SECONDS)
)

if (!anySubscriptionsRemainingForKey(queryCacheKey, api)) {
if (!anySubscriptionsRemainingForKey(queryCacheKey, internalState)) {
const currentTimeout = currentRemovalTimeouts[queryCacheKey]
if (currentTimeout) {
clearTimeout(currentTimeout)
}
currentRemovalTimeouts[queryCacheKey] = setTimeout(() => {
if (!anySubscriptionsRemainingForKey(queryCacheKey, api)) {
if (!anySubscriptionsRemainingForKey(queryCacheKey, internalState)) {
api.dispatch(removeQueryResult({ queryCacheKey }))
}
delete currentRemovalTimeouts![queryCacheKey]
Expand Down
Expand Up @@ -197,6 +197,7 @@ export const buildCacheLifecycleHandler: InternalHandlerBuilder = ({
const handler: ApiMiddlewareInternalHandler = (
action,
mwApi,
internalState,
stateBefore
) => {
const cacheKey = getCacheKey(action)
Expand Down
33 changes: 26 additions & 7 deletions packages/toolkit/src/query/core/buildMiddleware/index.ts
Expand Up @@ -10,7 +10,11 @@ import type { QueryThunkArg } from '../buildThunks'
import { buildCacheCollectionHandler } from './cacheCollection'
import { buildInvalidationByTagsHandler } from './invalidationByTags'
import { buildPollingHandler } from './polling'
import type { BuildMiddlewareInput, InternalHandlerBuilder } from './types'
import type {
BuildMiddlewareInput,
InternalHandlerBuilder,
InternalMiddlewareState,
} from './types'
import { buildWindowEventHandler } from './windowEventHandling'
import { buildCacheLifecycleHandler } from './cacheLifecycle'
import { buildQueryLifecycleHandler } from './queryLifecycle'
Expand Down Expand Up @@ -69,6 +73,10 @@ export function buildMiddleware<
const batchedActionsHandler = buildBatchedActionsHandler(builderArgs)
const windowEventsHandler = buildWindowEventHandler(builderArgs)

let internalState: InternalMiddlewareState = {
currentSubscriptions: {},
}

return (next) => {
return (action) => {
if (!initialized) {
Expand All @@ -77,19 +85,30 @@ export function buildMiddleware<
mwApi.dispatch(api.internalActions.middlewareRegistered(apiUid))
}

const mwApiWithNext = { ...mwApi, next }

const stateBefore = mwApi.getState()

if (!batchedActionsHandler(action, mwApi, stateBefore)) {
return
}
const [actionShouldContinue, hasSubscription] = batchedActionsHandler(
action,
mwApiWithNext,
internalState,
stateBefore
)

let res: any

const res = next(action)
if (actionShouldContinue) {
res = next(action)
} else {
res = hasSubscription
}

if (!!mwApi.getState()[reducerPath]) {
// Only run these checks if the middleware is registered okay

// This looks for actions that aren't specific to the API slice
windowEventsHandler(action, mwApi, stateBefore)
windowEventsHandler(action, mwApiWithNext, internalState, stateBefore)

if (
isThisApiSliceAction(action) ||
Expand All @@ -98,7 +117,7 @@ export function buildMiddleware<
// Only run these additional checks if the actions are part of the API slice,
// or the action has hydration-related data
for (let handler of handlers) {
handler(action, mwApi, stateBefore)
handler(action, mwApiWithNext, internalState, stateBefore)
}
}
}
Expand Down
33 changes: 21 additions & 12 deletions packages/toolkit/src/query/core/buildMiddleware/polling.ts
Expand Up @@ -6,6 +6,7 @@ import type {
TimeoutId,
InternalHandlerBuilder,
ApiMiddlewareInternalHandler,
InternalMiddlewareState,
} from './types'

export const buildPollingHandler: InternalHandlerBuilder = ({
Expand All @@ -20,26 +21,30 @@ export const buildPollingHandler: InternalHandlerBuilder = ({
pollingInterval: number
}> = {}

const handler: ApiMiddlewareInternalHandler = (action, mwApi) => {
const handler: ApiMiddlewareInternalHandler = (
action,
mwApi,
internalState
) => {
if (
api.internalActions.updateSubscriptionOptions.match(action) ||
api.internalActions.unsubscribeQueryResult.match(action)
) {
updatePollingInterval(action.payload, mwApi)
updatePollingInterval(action.payload, mwApi, internalState)
}

if (
queryThunk.pending.match(action) ||
(queryThunk.rejected.match(action) && action.meta.condition)
) {
updatePollingInterval(action.meta.arg, mwApi)
updatePollingInterval(action.meta.arg, mwApi, internalState)
}

if (
queryThunk.fulfilled.match(action) ||
(queryThunk.rejected.match(action) && !action.meta.condition)
) {
startNextPoll(action.meta.arg, mwApi)
startNextPoll(action.meta.arg, mwApi, internalState)
}

if (api.util.resetApiState.match(action)) {
Expand All @@ -49,11 +54,12 @@ export const buildPollingHandler: InternalHandlerBuilder = ({

function startNextPoll(
{ queryCacheKey }: QuerySubstateIdentifier,
api: SubMiddlewareApi
api: SubMiddlewareApi,
internalState: InternalMiddlewareState
) {
const state = api.getState()[reducerPath]
const querySubState = state.queries[queryCacheKey]
const subscriptions = state.subscriptions[queryCacheKey]
const subscriptions = internalState.currentSubscriptions[queryCacheKey]

if (!querySubState || querySubState.status === QueryStatus.uninitialized)
return
Expand Down Expand Up @@ -84,11 +90,12 @@ export const buildPollingHandler: InternalHandlerBuilder = ({

function updatePollingInterval(
{ queryCacheKey }: QuerySubstateIdentifier,
api: SubMiddlewareApi
api: SubMiddlewareApi,
internalState: InternalMiddlewareState
) {
const state = api.getState()[reducerPath]
const querySubState = state.queries[queryCacheKey]
const subscriptions = state.subscriptions[queryCacheKey]
const subscriptions = internalState.currentSubscriptions[queryCacheKey]

if (!querySubState || querySubState.status === QueryStatus.uninitialized) {
return
Expand All @@ -105,7 +112,7 @@ export const buildPollingHandler: InternalHandlerBuilder = ({
const nextPollTimestamp = Date.now() + lowestPollingInterval

if (!currentPoll || nextPollTimestamp < currentPoll.nextPollTimestamp) {
startNextPoll({ queryCacheKey }, api)
startNextPoll({ queryCacheKey }, api, internalState)
}
}

Expand All @@ -125,13 +132,15 @@ export const buildPollingHandler: InternalHandlerBuilder = ({

function findLowestPollingInterval(subscribers: Subscribers = {}) {
let lowestPollingInterval = Number.POSITIVE_INFINITY
for (const subscription of Object.values(subscribers)) {
if (!!subscription.pollingInterval)
for (let key in subscribers) {
if (!!subscribers[key].pollingInterval) {
lowestPollingInterval = Math.min(
subscription.pollingInterval,
subscribers[key].pollingInterval!,
lowestPollingInterval
)
}
}

return lowestPollingInterval
}
return handler
Expand Down

0 comments on commit 6975282

Please sign in to comment.