Skip to content

Commit

Permalink
fix(client): Ensure result extensions are applied after all query ext…
Browse files Browse the repository at this point in the history
…ensions

Fix #20437
  • Loading branch information
SevInf committed Jul 28, 2023
1 parent 612138c commit 29f4a5a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 41 deletions.
40 changes: 2 additions & 38 deletions packages/client/src/runtime/RequestHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,10 @@ import { throwValidationException } from './core/errorRendering/throwValidationE
import { hasBatchIndex } from './core/errors/ErrorWithBatchIndex'
import { NotFoundError } from './core/errors/NotFoundError'
import { createApplyBatchExtensionsFunction } from './core/extensions/applyQueryExtensions'
import { applyResultExtensions } from './core/extensions/applyResultExtensions'
import { MergedExtensionsList } from './core/extensions/MergedExtensionsList'
import { visitQueryResult } from './core/extensions/visitQueryResult'
import { deserializeJsonResponse } from './core/jsonProtocol/deserializeJsonResponse'
import { getBatchId } from './core/jsonProtocol/getBatchId'
import { isWrite } from './core/jsonProtocol/isWrite'
import { dmmfToJSModelName } from './core/model/utils/dmmfToJSModelName'
import { PrismaPromiseInteractiveTransaction, PrismaPromiseTransaction } from './core/request/PrismaPromise'
import { Action, JsArgs } from './core/types/JsApi'
import { DataLoader } from './DataLoader'
Expand Down Expand Up @@ -64,13 +61,6 @@ export type HandleErrorParams = {
transaction?: PrismaPromiseTransaction
}

type ApplyExtensionsParams = {
result: object
modelName: string
args: JsArgs
extensions: MergedExtensionsList
}

export class RequestHandler {
client: Client
dataloader: DataLoader<RequestParams>
Expand Down Expand Up @@ -150,20 +140,14 @@ export class RequestHandler {
}
}

mapQueryEngineResult(
{ dataPath, unpacker, modelName, args, extensions }: RequestParams,
response: QueryEngineResult<any>,
) {
mapQueryEngineResult({ dataPath, unpacker }: RequestParams, response: QueryEngineResult<any>) {
const data = response?.data
const elapsed = response?.elapsed

/**
* Unpack
*/
let result = this.unpack(data, dataPath, unpacker)
if (modelName) {
result = this.applyResultExtensions({ result, modelName, args, extensions })
}
const result = this.unpack(data, dataPath, unpacker)
if (process.env.PRISMA_CLIENT_GET_TIME) {
return { data: result, elapsed }
}
Expand Down Expand Up @@ -275,26 +259,6 @@ export class RequestHandler {
return unpacker ? unpacker(deserializeResponse) : deserializeResponse
}

applyResultExtensions({ result, modelName, args, extensions }: ApplyExtensionsParams) {
if (extensions.isEmpty() || result == null) {
return result
}
const model = this.client._runtimeDataModel.models[modelName]
if (!model) {
return result
}
return visitQueryResult({
result,
args: args ?? {},
modelName,
runtimeDataModel: this.client._runtimeDataModel,
visitor(value, dmmfModelName, args) {
const modelName = dmmfToJSModelName(dmmfModelName)
return applyResultExtensions({ result: value, modelName, select: args.select, extensions })
},
})
}

get [Symbol.toStringTag]() {
return 'RequestHandler'
}
Expand Down
36 changes: 33 additions & 3 deletions packages/client/src/runtime/getPrismaClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@ import {
import { prettyPrintArguments } from './core/errorRendering/prettyPrintArguments'
import { $extends } from './core/extensions/$extends'
import { applyQueryExtensions } from './core/extensions/applyQueryExtensions'
import { applyResultExtensions } from './core/extensions/applyResultExtensions'
import { MergedExtensionsList } from './core/extensions/MergedExtensionsList'
import { visitQueryResult } from './core/extensions/visitQueryResult'
import { checkPlatformCaching } from './core/init/checkPlatformCaching'
import { serializeJsonQuery } from './core/jsonProtocol/serializeJsonQuery'
import { MetricsClient } from './core/metrics/MetricsClient'
import {
applyModelsAndClientExtensions,
unApplyModelsAndClientExtensions,
} from './core/model/applyModelsAndClientExtensions'
import { dmmfToJSModelName } from './core/model/utils/dmmfToJSModelName'
import { rawCommandArgsMapper } from './core/raw-query/rawCommandArgsMapper'
import { RawQueryArgs } from './core/raw-query/RawQueryArgs'
import {
Expand All @@ -60,7 +63,7 @@ import { UserArgs } from './core/request/UserArgs'
import { RuntimeDataModel } from './core/runtimeDataModel'
import { getTracingHelper } from './core/tracing/TracingHelper'
import { getLockCountPromise } from './core/transaction/utils/createLockCountPromise'
import { JsInputValue } from './core/types/JsApi'
import { JsArgs, JsInputValue } from './core/types/JsApi'
import { getLogLevel } from './getLogLevel'
import { itxClientDenyList } from './itxClientDenyList'
import { mergeBy } from './mergeBy'
Expand Down Expand Up @@ -848,7 +851,7 @@ Or read our docs at https://www.prisma.io/docs/concepts/components/prisma-client

let index = -1
// prepare recursive fn that will pipe params through middlewares
const consumer = (changedMiddlewareParams: QueryMiddlewareParams) => {
const consumer = async (changedMiddlewareParams: QueryMiddlewareParams) => {
// if this `next` was called and there's some more middlewares
const nextMiddleware = this._middlewares.get(++index)

Expand Down Expand Up @@ -879,7 +882,11 @@ Or read our docs at https://www.prisma.io/docs/concepts/components/prisma-client
delete requestParams.transaction // client extensions check for this
}

return applyQueryExtensions(this, requestParams) // also executes the query
const result = await applyQueryExtensions(this, requestParams) // also executes the query
if (!requestParams.model) {
return result
}
return this._applyResultExtensions(result, requestParams.model, requestParams.args)
}

return this._tracingHelper.runInChildSpan(spanOptions.operation, () => {
Expand All @@ -893,6 +900,29 @@ Or read our docs at https://www.prisma.io/docs/concepts/components/prisma-client
})
}

_applyResultExtensions(result: object | null, modelName: string, args: JsArgs) {
if (this._extensions.isEmpty() || result == null) {
return result
}
const model = this._runtimeDataModel.models[modelName]
if (!model) {
return result
}
return visitQueryResult({
result,
args: args ?? {},
modelName,
runtimeDataModel: this._runtimeDataModel,
visitor: (value, dmmfModelName, args) =>
applyResultExtensions({
result: value,
modelName: dmmfToJSModelName(dmmfModelName),
select: args.select,
extensions: this._extensions,
}),
})
}

async _executeRequest({
args,
clientMethod,
Expand Down
25 changes: 25 additions & 0 deletions packages/client/tests/functional/extensions/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,31 @@ testMatrix.setupTestSuite(
})
})

test('result extensions are applied after query extension', async () => {
const xprisma = prisma.$extends({
result: {
user: {
fullName: {
needs: { firstName: true, lastName: true },
compute(user) {
return `${user.firstName} ${user.lastName}`
},
},
},
},
query: {
user: {
findFirstOrThrow() {
return Promise.resolve({ email: 'ext@example.com', firstName: 'From', lastName: 'Query' })
},
},
},
})

const result = await xprisma.user.findFirstOrThrow()
expect(result.fullName).toBe('From Query')
})

testIf(provider !== 'sqlite')('top-level raw queries interception', async () => {
const fnEmitter = jest.fn()
const fnUser = jest.fn()
Expand Down

0 comments on commit 29f4a5a

Please sign in to comment.