Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"npm.packageManager": "pnpm",
"eslint.packageManager": "pnpm"
}
2 changes: 1 addition & 1 deletion packages/internal/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/internal",
"version": "0.1.18",
"version": "0.1.20",
"description": "ZenStack internal runtime library",
"main": "lib/index.js",
"types": "lib/index.d.ts",
Expand Down
55 changes: 19 additions & 36 deletions packages/internal/src/handler/data/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export default class DataHandler<DbClient> implements RequestHandler {
break;
}
} catch (err: any) {
console.error(`Error handling ${method} ${model}: ${err}`);
console.log(`Error handling ${method} ${model}: ${err}`);
if (err instanceof RequestHandlerError) {
switch (err.code) {
case ServerErrorCode.DENIED_BY_POLICY:
Expand All @@ -76,11 +76,18 @@ export default class DataHandler<DbClient> implements RequestHandler {
message: err.message,
});
}
} else if (err.code && PRISMA_ERROR_MAPPING[err.code]) {
res.status(400).send({
code: PRISMA_ERROR_MAPPING[err.code],
message: 'database access error',
});
} else if (err.code) {
if (PRISMA_ERROR_MAPPING[err.code]) {
res.status(400).send({
code: PRISMA_ERROR_MAPPING[err.code],
message: 'database access error',
});
} else {
res.status(400).send({
code: 'PRISMA:' + err.code,
message: 'an unhandled Prisma error occurred',
});
}
} else {
console.error(
`An unknown error occurred: ${JSON.stringify(err)}`
Expand Down Expand Up @@ -110,7 +117,7 @@ export default class DataHandler<DbClient> implements RequestHandler {
if (id) {
if (processedArgs.where) {
processedArgs.where = {
AND: [args.where, { id }],
AND: [processedArgs.where, { id }],
};
} else {
processedArgs.where = { id };
Expand All @@ -127,13 +134,7 @@ export default class DataHandler<DbClient> implements RequestHandler {
}

console.log(`Finding ${model}:\n${JSON.stringify(processedArgs)}`);
await this.queryProcessor.postProcess(
model,
processedArgs,
r,
'read',
context
);
await this.queryProcessor.postProcess(model, r, 'read', context);

res.status(200).send(r);
}
Expand Down Expand Up @@ -190,13 +191,7 @@ export default class DataHandler<DbClient> implements RequestHandler {
return created;
});

await this.queryProcessor.postProcess(
model,
processedArgs,
r,
'create',
context
);
await this.queryProcessor.postProcess(model, r, 'create', context);
res.status(201).send(r);
}

Expand Down Expand Up @@ -265,13 +260,7 @@ export default class DataHandler<DbClient> implements RequestHandler {
return updated;
});

await this.queryProcessor.postProcess(
model,
updateArgs,
r,
'update',
context
);
await this.queryProcessor.postProcess(model, r, 'update', context);
res.status(200).send(r);
}

Expand Down Expand Up @@ -307,13 +296,7 @@ export default class DataHandler<DbClient> implements RequestHandler {
console.log(`Deleting ${model}:\n${JSON.stringify(delArgs)}`);
const db = (this.service.db as any)[model];
const r = await db.delete(delArgs);
await this.queryProcessor.postProcess(
model,
delArgs,
r,
'delete',
context
);
await this.queryProcessor.postProcess(model, r, 'delete', context);

res.status(200).send(r);
}
Expand All @@ -334,7 +317,7 @@ export default class DataHandler<DbClient> implements RequestHandler {
context
);
console.log(
`Finding to-be-deleted ${model}:\n${JSON.stringify(readArgs)}`
`Finding pre-operation ${model}:\n${JSON.stringify(readArgs)}`
);
const read = await db.findFirst(readArgs);
if (!read) {
Expand Down
128 changes: 126 additions & 2 deletions packages/internal/src/handler/data/query-processor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ export class QueryProcessor {
}

if (r.include || r.select) {
if (r.include && r.select) {
throw Error(
'Passing both "include" and "select" at the same level of query is not supported'
);
}

// "include" and "select" are mutually exclusive
const selector = r.include ? 'include' : 'select';
for (const [field, value] of Object.entries(r[selector])) {
Expand Down Expand Up @@ -59,11 +65,129 @@ export class QueryProcessor {
return r;
}

private async getToOneFieldInfo(
model: string,
fieldName: string,
fieldValue: any
) {
if (
!!fieldValue &&
!Array.isArray(fieldValue) &&
typeof fieldValue === 'object' &&
typeof fieldValue.id == 'string'
) {
return null;
}

const fieldInfo = await this.service.resolveField(model, fieldName);
if (!fieldInfo || fieldInfo.isArray) {
return null;
}

return fieldInfo;
}

private async collectRelationFields(
model: string,
data: any,
map: Map<string, string[]>
) {
for (const [fieldName, fieldValue] of Object.entries(data)) {
const val: any = fieldValue;
const fieldInfo = await this.getToOneFieldInfo(
model,
fieldName,
fieldValue
);
if (!fieldInfo) {
continue;
}

if (!map.has(fieldInfo.type)) {
map.set(fieldInfo.type, []);
}
map.get(fieldInfo.type)!.push(val.id);

// recurse into field value
this.collectRelationFields(fieldInfo.type, val, map);
}
}

private async checkIdsAgainstPolicy(
relationFieldMap: Map<string, string[]>,
operation: PolicyOperationKind,
context: QueryContext
) {
const promises = Array.from(relationFieldMap.entries()).map(
async ([model, ids]) => {
const args = {
select: { id: true },
where: {
id: { in: ids },
},
};

const processedArgs = this.processQueryArgs(
model,
args,
operation,
context,
true
);

const checkedIds: Array<{ id: string }> = await this.service.db[
model
].findMany(processedArgs);
return [model, checkedIds.map((r) => r.id)] as [
string,
string[]
];
}
);
return new Map<string, string[]>(await Promise.all(promises));
}

private async sanitizeData(
model: string,
data: any,
validatedIds: Map<string, string[]>
) {
for (const [fieldName, fieldValue] of Object.entries(data)) {
const fieldInfo = await this.getToOneFieldInfo(
model,
fieldName,
fieldValue
);
if (!fieldInfo) {
continue;
}
const fv = fieldValue as { id: string };
const valIds = validatedIds.get(fieldInfo.type);

if (!valIds || !valIds.includes(fv.id)) {
console.log(
`Deleting field ${fieldName} from ${model}#${data.id}, because field value #${fv.id} failed policy check`
);
delete data[fieldName];
}

await this.sanitizeData(fieldInfo.type, fieldValue, validatedIds);
}
}

async postProcess(
model: string,
queryArgs: any,
data: any,
operation: PolicyOperationKind,
context: QueryContext
) {}
) {
const relationFieldMap = new Map<string, string[]>();
await this.collectRelationFields(model, data, relationFieldMap);
const validatedIds = await this.checkIdsAgainstPolicy(
relationFieldMap,
operation,
context
);
await this.sanitizeData(model, data, validatedIds);
}
}
2 changes: 1 addition & 1 deletion packages/schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": "zenstack",
"displayName": "ZenStack CLI and Language Tools",
"description": "ZenStack CLI and Language Tools",
"version": "0.1.38",
"version": "0.1.40",
"engines": {
"vscode": "^1.56.0"
},
Expand Down
10 changes: 9 additions & 1 deletion packages/schema/src/generator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,17 @@ export class ZenStackGenerator {
new PrismaGenerator(),
new ServiceGenerator(),
new ReactHooksGenerator(),
new NextAuthGenerator(),
];

try {
require('next-auth');
generators.push(new NextAuthGenerator());
} catch {
console.warn(
'Next-auth module is not installed, skipping generating adapter.'
);
}

for (const generator of generators) {
await generator.generate(context);
}
Expand Down
17 changes: 14 additions & 3 deletions packages/schema/src/generator/service/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Context, Generator } from '../types';
import { Project, StructureKind } from 'ts-morph';
import { Project, StructureKind, VariableDeclarationKind } from 'ts-morph';
import * as path from 'path';
import colors from 'colors';
import { INTERNAL_PACKAGE } from '../constants';
Expand Down Expand Up @@ -41,6 +41,15 @@ export default class ServiceGenerator implements Generator {
.addBody()
.setBodyText('return this._prisma;');

sf.addVariableStatement({
declarationKind: VariableDeclarationKind.Let,
declarations: [
{
name: 'guardModule',
type: 'any',
},
],
});
cls
.addMethod({
name: 'resolveField',
Expand All @@ -57,8 +66,10 @@ export default class ServiceGenerator implements Generator {
],
})
.addBody().setBodyText(`
const module: any = await import('./query/guard');
return module._fieldMapping?.[model]?.[field];
if (!guardModule) {
guardModule = await import('./query/guard');
}
return guardModule._fieldMapping?.[model]?.[field];
`);

cls
Expand Down
6 changes: 4 additions & 2 deletions packages/schema/src/language-server/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,14 @@ export class ZModelLinker extends DefaultLinker {
document: LangiumDocument<AstNode>,
extraScopes: ScopeProvider[]
) {
this.resolve(node.left, document, extraScopes);
this.resolve(node.right, document, extraScopes);
switch (node.operator) {
// TODO: support arithmetics?
// case '+':
// case '-':
// case '*':
// case '/':
// this.resolve(node.left, document, extraScopes);
// this.resolve(node.right, document, extraScopes);
// this.resolveToBuiltinTypeOrDecl(node, 'Int');
// break;

Expand All @@ -195,6 +195,8 @@ export class ZModelLinker extends DefaultLinker {
case '!=':
case '&&':
case '||':
this.resolve(node.left, document, extraScopes);
this.resolve(node.right, document, extraScopes);
this.resolveToBuiltinTypeOrDecl(node, 'Boolean');
break;

Expand Down
Loading