Skip to content

Commit 477fe0d

Browse files
committed
feat(runtime): inject enhanced client or tx context so it can be retrieved in extensions
1 parent 82b8d25 commit 477fe0d

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

packages/runtime/src/enhancements/node/proxy.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ export function makeProxy<T extends PrismaProxyHandler>(
289289
return propVal;
290290
}
291291

292-
return createHandlerProxy(makeHandler(target, prop), propVal, prop, errorTransformer);
292+
return createHandlerProxy(makeHandler(target, prop), propVal, prop, target, errorTransformer);
293293
},
294294
});
295295

@@ -303,10 +303,15 @@ function createHandlerProxy<T extends PrismaProxyHandler>(
303303
handler: T,
304304
origTarget: any,
305305
model: string,
306+
dbOrTx: any,
306307
errorTransformer?: ErrorTransformer
307308
): T {
308309
return new Proxy(handler, {
309310
get(target, propKey) {
311+
if (propKey === '$zenstack_parent') {
312+
return dbOrTx;
313+
}
314+
310315
const prop = target[propKey as keyof T];
311316
if (typeof prop !== 'function') {
312317
// the proxy handler doesn't have this method, fall back to the original target
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
import { Prisma } from '@prisma/client';
3+
4+
describe('Proxy Extension Context', () => {
5+
it('works', async () => {
6+
const { enhance } = await loadSchema(
7+
`
8+
model Counter {
9+
model String @unique
10+
value Int
11+
12+
@@allow('all', true)
13+
}
14+
15+
model Address {
16+
id String @id @default(cuid())
17+
city String
18+
19+
@@allow('all', true)
20+
}
21+
`
22+
);
23+
24+
const db = enhance();
25+
const dbExtended = db.$extends({
26+
model: {
27+
$allModels: {
28+
async createWithCounter(this: any, args: any) {
29+
const context = Prisma.getExtensionContext(this);
30+
const modelName = context.$name;
31+
const dbOrTx = this['$zenstack_parent'];
32+
33+
const fn = async (tx: any) => {
34+
const counter = await tx.counter.findUnique({
35+
where: { model: modelName },
36+
});
37+
await tx.counter.upsert({
38+
where: { model: modelName },
39+
update: { value: (counter?.value ?? 0) + 1 },
40+
create: { model: modelName, value: 1 },
41+
});
42+
return tx[modelName].create(args);
43+
};
44+
45+
if (dbOrTx['$transaction']) {
46+
// not running in a transaction, so we need to create a new transaction
47+
return dbOrTx.$transaction(fn);
48+
}
49+
50+
return fn(dbOrTx);
51+
},
52+
},
53+
},
54+
});
55+
56+
const cities = [
57+
'Vienna',
58+
'Paris',
59+
'London',
60+
'Berlin',
61+
'New York',
62+
'Tokyo',
63+
'Sydney',
64+
'Seoul',
65+
'Mumbai',
66+
'Delhi',
67+
'Shanghai',
68+
]
69+
70+
await Promise.all([
71+
...cities.map((city) => dbExtended.address.createWithCounter({ data: { city } })),
72+
...cities.map((city) => dbExtended.$transaction((tx: any) => tx.address.createWithCounter({ data: { city: `${city}$tx` } }))),
73+
]);
74+
75+
// expecting object
76+
await expect(
77+
dbExtended.counter.findUniqueOrThrow({ where: { model: 'Address' } })
78+
).resolves.toMatchObject({
79+
model: 'Address',
80+
value: cities.length * 2,
81+
});
82+
});
83+
});

0 commit comments

Comments
 (0)