From 9f5e0a7728bfb15f48e198ac39b4eefeb47959fd Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 6 Oct 2025 19:36:56 -0700 Subject: [PATCH 1/4] feat: implement field validation --- packages/{ => config}/eslint-config/base.js | 0 .../{ => config}/eslint-config/package.json | 0 .../{ => config}/typescript-config/base.json | 0 .../typescript-config/package.json | 0 .../{ => config}/vitest-config/base.config.js | 0 .../{ => config}/vitest-config/package.json | 0 .../src/client/crud/operations/base.ts | 13 +- .../crud/{validator.ts => validator/index.ts} | 56 +-- .../src/client/crud/validator/utils.ts | 325 ++++++++++++++++++ packages/runtime/src/client/query-utils.ts | 29 +- packages/testtools/src/types.d.ts | 1 + packages/testtools/src/vitest-ext.ts | 46 ++- pnpm-lock.yaml | 80 ++--- tests/e2e/orm/client-api/compound-id.test.ts | 2 +- .../orm/validation/custom-validation.test.ts | 111 ++++++ tests/e2e/orm/validation/nested.test.ts | 39 +++ tests/e2e/orm/validation/toplevel.test.ts | 118 +++++++ 17 files changed, 734 insertions(+), 86 deletions(-) rename packages/{ => config}/eslint-config/base.js (100%) rename packages/{ => config}/eslint-config/package.json (100%) rename packages/{ => config}/typescript-config/base.json (100%) rename packages/{ => config}/typescript-config/package.json (100%) rename packages/{ => config}/vitest-config/base.config.js (100%) rename packages/{ => config}/vitest-config/package.json (100%) rename packages/runtime/src/client/crud/{validator.ts => validator/index.ts} (96%) create mode 100644 packages/runtime/src/client/crud/validator/utils.ts create mode 100644 tests/e2e/orm/validation/custom-validation.test.ts create mode 100644 tests/e2e/orm/validation/nested.test.ts create mode 100644 tests/e2e/orm/validation/toplevel.test.ts diff --git a/packages/eslint-config/base.js b/packages/config/eslint-config/base.js similarity index 100% rename from packages/eslint-config/base.js rename to packages/config/eslint-config/base.js diff --git a/packages/eslint-config/package.json b/packages/config/eslint-config/package.json similarity index 100% rename from packages/eslint-config/package.json rename to packages/config/eslint-config/package.json diff --git a/packages/typescript-config/base.json b/packages/config/typescript-config/base.json similarity index 100% rename from packages/typescript-config/base.json rename to packages/config/typescript-config/base.json diff --git a/packages/typescript-config/package.json b/packages/config/typescript-config/package.json similarity index 100% rename from packages/typescript-config/package.json rename to packages/config/typescript-config/package.json diff --git a/packages/vitest-config/base.config.js b/packages/config/vitest-config/base.config.js similarity index 100% rename from packages/vitest-config/base.config.js rename to packages/config/vitest-config/base.config.js diff --git a/packages/vitest-config/package.json b/packages/config/vitest-config/package.json similarity index 100% rename from packages/vitest-config/package.json rename to packages/config/vitest-config/package.json diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 34924952..65bdbbc2 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -131,15 +131,10 @@ export abstract class BaseOperationHandler { model: GetModels, filter: any, ): Promise { - const idFields = requireIdFields(this.schema, model); - const _filter = flattenCompoundUniqueFilters(this.schema, model, filter); - const query = kysely - .selectFrom(model) - .where((eb) => eb.and(_filter)) - .select(idFields.map((f) => kysely.dynamic.ref(f))) - .limit(1) - .modifyEnd(this.makeContextComment({ model, operation: 'read' })); - return this.executeQueryTakeFirst(kysely, query, 'exists'); + return this.readUnique(kysely, model, { + where: filter, + select: this.makeIdSelect(model), + }); } protected async read( diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator/index.ts similarity index 96% rename from packages/runtime/src/client/crud/validator.ts rename to packages/runtime/src/client/crud/validator/index.ts index beb31faf..e0cb39f6 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator/index.ts @@ -4,17 +4,18 @@ import stableStringify from 'json-stable-stringify'; import { match, P } from 'ts-pattern'; import { z, ZodSchema, ZodType } from 'zod'; import { + type AttributeApplication, type BuiltinType, type EnumDef, type FieldDef, type GetModels, type ModelDef, type SchemaDef, -} from '../../schema'; -import { enumerate } from '../../utils/enumerate'; -import { extractFields } from '../../utils/object-utils'; -import { formatError } from '../../utils/zod-utils'; -import { AGGREGATE_OPERATORS, LOGICAL_COMBINATORS, NUMERIC_FIELD_TYPES } from '../constants'; +} from '../../../schema'; +import { enumerate } from '../../../utils/enumerate'; +import { extractFields } from '../../../utils/object-utils'; +import { formatError } from '../../../utils/zod-utils'; +import { AGGREGATE_OPERATORS, LOGICAL_COMBINATORS, NUMERIC_FIELD_TYPES } from '../../constants'; import { type AggregateArgs, type CountArgs, @@ -29,8 +30,8 @@ import { type UpdateManyAndReturnArgs, type UpdateManyArgs, type UpsertArgs, -} from '../crud-types'; -import { InputValidationError, InternalError } from '../errors'; +} from '../../crud-types'; +import { InputValidationError, InternalError } from '../../errors'; import { fieldHasDefaultValue, getDiscriminatorField, @@ -38,7 +39,8 @@ import { getUniqueFields, requireField, requireModel, -} from '../query-utils'; +} from '../../query-utils'; +import { addCustomValidation, addNumberValidation, addStringValidation } from './utils'; type GetSchemaFunc = (model: GetModels, options: Options) => ZodType; @@ -191,11 +193,14 @@ export class InputValidator { schema = getSchema(model, options); this.schemaCache.set(cacheKey!, schema); } - const { error } = schema.safeParse(args); + const { error, data } = schema.safeParse(args); if (error) { - throw new InputValidationError(`Invalid ${operation} args: ${formatError(error)}`, error); + throw new InputValidationError( + `Invalid ${operation} args for model "${model}": ${formatError(error)}`, + error, + ); } - return args as T; + return data as T; } // #region Find @@ -235,13 +240,13 @@ export class InputValidator { return result; } - private makePrimitiveSchema(type: string) { + private makePrimitiveSchema(type: string, attributes?: AttributeApplication[]) { if (this.schema.typeDefs && type in this.schema.typeDefs) { return this.makeTypeDefSchema(type); } else { return match(type) - .with('String', () => z.string()) - .with('Int', () => z.number().int()) + .with('String', () => addStringValidation(z.string(), attributes)) + .with('Int', () => addNumberValidation(z.number().int(), attributes)) .with('Float', () => z.number()) .with('Boolean', () => z.boolean()) .with('BigInt', () => z.union([z.number().int(), z.bigint()])) @@ -860,7 +865,7 @@ export class InputValidator { uncheckedVariantFields[field] = fieldSchema; } } else { - let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type); + let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes); if (fieldDef.array) { fieldSchema = z @@ -889,14 +894,17 @@ export class InputValidator { } }); + const uncheckedCreateSchema = addCustomValidation(z.strictObject(uncheckedVariantFields), modelDef.attributes); + const checkedCreateSchema = addCustomValidation(z.strictObject(checkedVariantFields), modelDef.attributes); + if (!hasRelation) { - return this.orArray(z.strictObject(uncheckedVariantFields), canBeArray); + return this.orArray(uncheckedCreateSchema, canBeArray); } else { return z.union([ - z.strictObject(uncheckedVariantFields), - z.strictObject(checkedVariantFields), - ...(canBeArray ? [z.array(z.strictObject(uncheckedVariantFields))] : []), - ...(canBeArray ? [z.array(z.strictObject(checkedVariantFields))] : []), + uncheckedCreateSchema, + checkedCreateSchema, + ...(canBeArray ? [z.array(uncheckedCreateSchema)] : []), + ...(canBeArray ? [z.array(checkedCreateSchema)] : []), ]); } } @@ -1112,7 +1120,7 @@ export class InputValidator { uncheckedVariantFields[field] = fieldSchema; } } else { - let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type).optional(); + let fieldSchema: ZodType = this.makePrimitiveSchema(fieldDef.type, fieldDef.attributes).optional(); if (this.isNumericField(fieldDef)) { fieldSchema = z.union([ @@ -1161,10 +1169,12 @@ export class InputValidator { } }); + const uncheckedUpdateSchema = addCustomValidation(z.strictObject(uncheckedVariantFields), modelDef.attributes); + const checkedUpdateSchema = addCustomValidation(z.strictObject(checkedVariantFields), modelDef.attributes); if (!hasRelation) { - return z.strictObject(uncheckedVariantFields); + return uncheckedUpdateSchema; } else { - return z.union([z.strictObject(uncheckedVariantFields), z.strictObject(checkedVariantFields)]); + return z.union([uncheckedUpdateSchema, checkedUpdateSchema]); } } diff --git a/packages/runtime/src/client/crud/validator/utils.ts b/packages/runtime/src/client/crud/validator/utils.ts new file mode 100644 index 00000000..9c3d647c --- /dev/null +++ b/packages/runtime/src/client/crud/validator/utils.ts @@ -0,0 +1,325 @@ +import { invariant } from '@zenstackhq/common-helpers'; +import type { + AttributeApplication, + BinaryExpression, + CallExpression, + Expression, + FieldExpression, + MemberExpression, + UnaryExpression, +} from '@zenstackhq/sdk/schema'; +import { match, P } from 'ts-pattern'; +import { z } from 'zod'; +import { ExpressionUtils } from '../../../schema'; +import { QueryError } from '../../errors'; + +function getArgValue(expr: Expression | undefined): T | undefined { + if (!expr || !ExpressionUtils.isLiteral(expr)) { + return undefined; + } + return expr.value as T; +} + +export function addStringValidation(schema: z.ZodString, attributes: AttributeApplication[] | undefined): z.ZodSchema { + if (!attributes || attributes.length === 0) { + return schema; + } + + for (const attr of attributes) { + match(attr.name) + .with('@length', () => { + const min = getArgValue(attr.args?.[0]?.value); + if (min !== undefined) { + schema = schema.min(min); + } + const max = getArgValue(attr.args?.[1]?.value); + if (max !== undefined) { + schema = schema.max(max); + } + }) + .with('@startsWith', () => { + const value = getArgValue(attr.args?.[0]?.value); + if (value !== undefined) { + schema = schema.startsWith(value); + } + }) + .with('@endsWith', () => { + const value = getArgValue(attr.args?.[0]?.value); + if (value !== undefined) { + schema = schema.endsWith(value); + } + }) + .with('@contains', () => { + const value = getArgValue(attr.args?.[0]?.value); + if (value !== undefined) { + schema = schema.includes(value); + } + }) + .with('@regex', () => { + const pattern = getArgValue(attr.args?.[0]?.value); + if (pattern !== undefined) { + schema = schema.regex(new RegExp(pattern)); + } + }) + .with('@email', () => { + schema = schema.email(); + }) + .with('@datetime', () => { + schema = schema.datetime(); + }) + .with('@url', () => { + schema = schema.url(); + }) + .with('@trim', () => { + schema = schema.trim(); + }) + .with('@lower', () => { + schema = schema.toLowerCase(); + }) + .with('@upper', () => { + schema = schema.toUpperCase(); + }); + } + return schema; +} + +export function addNumberValidation(schema: z.ZodNumber, attributes: AttributeApplication[] | undefined): z.ZodSchema { + if (!attributes || attributes.length === 0) { + return schema; + } + + for (const attr of attributes) { + const val = getArgValue(attr.args?.[0]?.value); + if (val === undefined) { + continue; + } + match(attr.name) + .with('@gt', () => { + schema = schema.gt(val); + }) + .with('@gte', () => { + schema = schema.gte(val); + }) + .with('@lt', () => { + schema = schema.lt(val); + }) + .with('@lte', () => { + schema = schema.lte(val); + }) + .with('@lt', () => { + schema = schema.lt(val); + }) + .with('@lte', () => { + schema = schema.lte(val); + }); + } + return schema; +} + +export function addCustomValidation(schema: z.ZodSchema, attributes: AttributeApplication[] | undefined): z.ZodSchema { + const attrs = attributes?.filter((a) => a.name === '@@validate'); + if (!attrs || attrs.length === 0) { + return schema; + } + + for (const attr of attrs) { + const expr = attr.args?.[0]?.value; + if (!expr) { + continue; + } + const message = getArgValue(attr.args?.[1]?.value); + const pathExpr = attr.args?.[2]?.value; + let path: string[] | undefined = undefined; + if (pathExpr && ExpressionUtils.isArray(pathExpr)) { + path = pathExpr.items.map((e) => ExpressionUtils.getLiteralValue(e) as string); + } + schema = applyValidation(schema, expr, message, path); + } + return schema; +} + +function applyValidation( + schema: z.ZodSchema, + expr: Expression, + message: string | undefined, + path: string[] | undefined, +) { + const options: z.CustomErrorParams = {}; + if (message) { + options.message = message; + } + if (path) { + options.path = path; + } + return schema.refine((data) => Boolean(evalExpression(data, expr)), options); +} + +function evalExpression(data: any, expr: Expression): unknown { + return match(expr) + .with({ kind: 'literal' }, (e) => e.value) + .with({ kind: 'array' }, (e) => e.items.map((item) => evalExpression(data, item))) + .with({ kind: 'field' }, (e) => evalField(data, e)) + .with({ kind: 'member' }, (e) => evalMember(data, e)) + .with({ kind: 'unary' }, (e) => evalUnary(data, e)) + .with({ kind: 'binary' }, (e) => evalBinary(data, e)) + .with({ kind: 'call' }, (e) => evalCall(data, e)) + .with({ kind: 'this' }, () => data ?? null) + .with({ kind: 'null' }, () => null) + .exhaustive(); +} + +function evalField(data: any, e: FieldExpression) { + return data?.[e.field] ?? null; +} + +function evalUnary(data: any, expr: UnaryExpression) { + const operand = evalExpression(data, expr.operand); + switch (expr.op) { + case '!': + return !operand; + default: + throw new Error(`Unsupported unary operator: ${expr.op}`); + } +} + +function evalBinary(data: any, expr: BinaryExpression) { + const left = evalExpression(data, expr.left); + const right = evalExpression(data, expr.right); + return match(expr.op) + .with('&&', () => Boolean(left) && Boolean(right)) + .with('||', () => Boolean(left) || Boolean(right)) + .with('==', () => left == right) // eslint-disable-line eqeqeq + .with('!=', () => left != right) // eslint-disable-line eqeqeq + .with('<', () => (left as any) < (right as any)) + .with('<=', () => (left as any) <= (right as any)) + .with('>', () => (left as any) > (right as any)) + .with('>=', () => (left as any) >= (right as any)) + .with('?', () => { + if (!Array.isArray(left)) { + return false; + } + return left.some((item) => item === right); + }) + .with('!', () => { + if (!Array.isArray(left)) { + return false; + } + return left.every((item) => item === right); + }) + .with('^', () => { + if (!Array.isArray(left)) { + return false; + } + return !left.some((item) => item === right); + }) + .with('in', () => { + if (!Array.isArray(right)) { + return false; + } + return right.includes(left); + }) + .exhaustive(); +} + +function evalMember(data: any, expr: MemberExpression) { + let result: any = evalExpression(data, expr.receiver); + for (const member of expr.members) { + if (!result || typeof result !== 'object') { + return undefined; + } + result = result[member]; + } + return result ?? null; +} + +function evalCall(data: any, expr: CallExpression) { + const fieldArg = expr.args?.[0] ? evalExpression(data, expr.args[0]) : undefined; + return ( + match(expr.function) + // string functions + .with('length', (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`); + + const min = getArgValue(expr.args?.[1]); + const max = getArgValue(expr.args?.[2]); + if (min && fieldArg.length < min) { + return false; + } + if (max && fieldArg.length > max) { + return false; + } + return true; + }) + .with(P.union('startsWith', 'endsWith', 'contains'), (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`); + invariant(expr.args?.[1], `"${f}" requires a search argument`); + + const search = getArgValue(expr.args?.[1])!; + const caseInsensitive = getArgValue(expr.args?.[2]) ?? false; + + const matcher = (x: string, y: string) => + match(f) + .with('startsWith', () => x.startsWith(y)) + .with('endsWith', () => x.endsWith(y)) + .with('contains', () => x.includes(y)) + .exhaustive(); + return caseInsensitive + ? matcher(fieldArg.toLowerCase(), search.toLowerCase()) + : matcher(fieldArg, search); + }) + .with('regex', (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(typeof fieldArg === 'string', `"${f}" first argument must be a string`); + const pattern = getArgValue(expr.args?.[1])!; + invariant(pattern !== undefined, `"${f}" requires a pattern argument`); + return new RegExp(pattern).test(fieldArg); + }) + .with(P.union('email', 'url', 'datetime'), (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + return z.string()[f]().safeParse(fieldArg).success; + }) + // list functions + .with(P.union('has', 'hasEvery', 'hasSome'), (f) => { + invariant(expr.args?.[1], `${f} requires a search argument`); + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(Array.isArray(fieldArg), `"${f}" first argument must be an array field`); + + const search = evalExpression(data, expr.args?.[1])!; + const matcher = (x: any[], y: any) => + match(f) + .with('has', () => x.some((item) => item === y)) + .with('hasEvery', () => { + invariant(Array.isArray(y), 'hasEvery second argument must be an array'); + return y.every((v) => x.some((item) => item === v)); + }) + .with('hasSome', () => { + invariant(Array.isArray(y), 'hasSome second argument must be an array'); + return y.some((v) => x.some((item) => item === v)); + }) + .exhaustive(); + return matcher(fieldArg, search); + }) + .with('isEmpty', (f) => { + if (fieldArg === undefined || fieldArg === null) { + return false; + } + invariant(Array.isArray(fieldArg), `"${f}" first argument must be an array field`); + return fieldArg.length === 0; + }) + .otherwise(() => { + throw new QueryError(`Unknown function "${expr.function}"`); + }) + ); +} diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 869d3535..b5107cdf 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -304,16 +304,37 @@ export function flattenCompoundUniqueFilters(schema: SchemaDef, model: string, f return filter; } - const result: any = {}; + const flattenedResult: any = {}; + const restFilter: any = {}; + for (const [key, value] of Object.entries(filter)) { if (compoundUniques.some(({ name }) => name === key)) { // flatten the compound field - Object.assign(result, value); + Object.assign(flattenedResult, value); } else { - result[key] = value; + restFilter[key] = value; + } + } + + if (Object.keys(flattenedResult).length === 0) { + // nothing flattened + return filter; + } else if (Object.keys(restFilter).length === 0) { + // all flattened + return flattenedResult; + } else { + const flattenedKeys = Object.keys(flattenedResult); + const restKeys = Object.keys(restFilter); + if (flattenedKeys.some((k) => restKeys.includes(k))) { + // keys overlap, cannot merge directly, build an AND clause + return { + AND: [flattenedResult, restFilter], + }; + } else { + // safe to merge directly + return { ...flattenedResult, ...restFilter }; } } - return result; } export function ensureArray(value: T | T[]): T[] { diff --git a/packages/testtools/src/types.d.ts b/packages/testtools/src/types.d.ts index b547127c..9f58106f 100644 --- a/packages/testtools/src/types.d.ts +++ b/packages/testtools/src/types.d.ts @@ -7,6 +7,7 @@ interface CustomMatchers { toResolveWithLength: (length: number) => Promise; toBeRejectedNotFound: () => Promise; toBeRejectedByPolicy: (expectedMessages?: string[]) => Promise; + toBeRejectedByValidation: (expectedMessages?: string[]) => Promise; } declare module 'vitest' { diff --git a/packages/testtools/src/vitest-ext.ts b/packages/testtools/src/vitest-ext.ts index 70b5a61b..06b1709b 100644 --- a/packages/testtools/src/vitest-ext.ts +++ b/packages/testtools/src/vitest-ext.ts @@ -1,4 +1,4 @@ -import { NotFoundError, RejectedByPolicyError } from '@zenstackhq/runtime'; +import { InputValidationError, NotFoundError, RejectedByPolicyError } from '@zenstackhq/runtime'; import { expect } from 'vitest'; function isPromise(value: any) { @@ -19,6 +19,18 @@ function expectError(err: any, errorType: any) { } } +function expectErrorMessages(expectedMessages: string[], message: string) { + for (const m of expectedMessages) { + if (!message.includes(m)) { + return { + message: () => `expected message not found in error: ${m}, got message: ${message}`, + pass: false, + }; + } + } + return undefined; +} + expect.extend({ async toResolveTruthy(received: Promise) { if (!isPromise(received)) { @@ -84,14 +96,9 @@ expect.extend({ await received; } catch (err) { if (expectedMessages && err instanceof RejectedByPolicyError) { - const message = err.message || ''; - for (const m of expectedMessages) { - if (!message.includes(m)) { - return { - message: () => `expected message not found in error: ${m}, got message: ${message}`, - pass: false, - }; - } + const r = expectErrorMessages(expectedMessages, err.message || ''); + if (r) { + return r; } } return expectError(err, RejectedByPolicyError); @@ -101,4 +108,25 @@ expect.extend({ pass: false, }; }, + + async toBeRejectedByValidation(received: Promise, expectedMessages?: string[]) { + if (!isPromise(received)) { + return { message: () => 'a promise is expected', pass: false }; + } + try { + await received; + } catch (err) { + if (expectedMessages && err instanceof InputValidationError) { + const r = expectErrorMessages(expectedMessages, err.message || ''); + if (r) { + return r; + } + } + return expectError(err, InputValidationError); + } + return { + message: () => `expected InputValidationError, got no error`, + pass: false, + }; + }, }); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3c85aa5e..740f983e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -134,7 +134,7 @@ importers: version: 0.2.6 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/runtime': specifier: workspace:* version: link:../runtime @@ -143,10 +143,10 @@ importers: version: link:../testtools '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../vitest-config + version: link:../config/vitest-config better-sqlite3: specifier: 'catalog:' version: 12.2.0 @@ -158,10 +158,16 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config + + packages/config/eslint-config: {} + + packages/config/typescript-config: {} + + packages/config/vitest-config: {} packages/create-zenstack: dependencies: @@ -177,10 +183,10 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config packages/dialects/sql.js: devDependencies: @@ -189,13 +195,13 @@ importers: version: 1.4.9 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../../eslint-config + version: link:../../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../typescript-config + version: link:../../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../../vitest-config + version: link:../../config/vitest-config kysely: specifier: 'catalog:' version: 0.27.6 @@ -203,8 +209,6 @@ importers: specifier: ^1.13.0 version: 1.13.0 - packages/eslint-config: {} - packages/ide/vscode: dependencies: '@zenstackhq/language': @@ -225,10 +229,10 @@ importers: version: 1.101.0 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../../eslint-config + version: link:../../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../typescript-config + version: link:../../config/typescript-config packages/language: dependencies: @@ -253,13 +257,13 @@ importers: version: link:../common-helpers '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../vitest-config + version: link:../config/vitest-config glob: specifier: ^11.0.2 version: 11.0.2 @@ -296,13 +300,13 @@ importers: version: 8.11.11 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../../eslint-config + version: link:../../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../typescript-config + version: link:../../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../../vitest-config + version: link:../../config/vitest-config packages/runtime: dependencies: @@ -357,7 +361,7 @@ importers: version: 2.0.7 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/language': specifier: workspace:* version: link:../language @@ -366,10 +370,10 @@ importers: version: link:../sdk '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../vitest-config + version: link:../config/vitest-config tsx: specifier: ^4.19.2 version: 4.19.2 @@ -397,10 +401,10 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config decimal.js: specifier: ^10.4.3 version: 10.4.3 @@ -419,10 +423,10 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config packages/testtools: dependencies: @@ -477,10 +481,10 @@ importers: version: 0.2.6 '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config copyfiles: specifier: ^2.4.1 version: 2.4.1 @@ -488,10 +492,6 @@ importers: specifier: 'catalog:' version: 5.8.3 - packages/typescript-config: {} - - packages/vitest-config: {} - packages/zod: dependencies: '@zenstackhq/runtime': @@ -503,10 +503,10 @@ importers: devDependencies: '@zenstackhq/eslint-config': specifier: workspace:* - version: link:../eslint-config + version: link:../config/eslint-config '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../typescript-config + version: link:../config/typescript-config zod: specifier: ~3.25.0 version: 3.25.76 @@ -531,7 +531,7 @@ importers: version: link:../../packages/cli '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../packages/typescript-config + version: link:../../packages/config/typescript-config prisma: specifier: 'catalog:' version: 6.14.0(typescript@5.8.3) @@ -580,10 +580,10 @@ importers: version: 11.0.0 '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../packages/typescript-config + version: link:../../packages/config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../../packages/vitest-config + version: link:../../packages/config/vitest-config tests/regression: dependencies: @@ -605,10 +605,10 @@ importers: version: link:../../packages/sdk '@zenstackhq/typescript-config': specifier: workspace:* - version: link:../../packages/typescript-config + version: link:../../packages/config/typescript-config '@zenstackhq/vitest-config': specifier: workspace:* - version: link:../../packages/vitest-config + version: link:../../packages/config/vitest-config packages: diff --git a/tests/e2e/orm/client-api/compound-id.test.ts b/tests/e2e/orm/client-api/compound-id.test.ts index b983b045..dc11c253 100644 --- a/tests/e2e/orm/client-api/compound-id.test.ts +++ b/tests/e2e/orm/client-api/compound-id.test.ts @@ -1,5 +1,5 @@ -import { describe, expect, it } from 'vitest'; import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; describe('Compound ID tests', () => { describe('to-one relation', () => { diff --git a/tests/e2e/orm/validation/custom-validation.test.ts b/tests/e2e/orm/validation/custom-validation.test.ts new file mode 100644 index 00000000..543afb50 --- /dev/null +++ b/tests/e2e/orm/validation/custom-validation.test.ts @@ -0,0 +1,111 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Custom validation tests', () => { + it('works with custom validation', async () => { + const db = await createTestClient( + ` + model Foo { + id Int @id @default(autoincrement()) + str1 String? + str2 String? + str3 String? + str4 String? + str5 String? + int1 Int? + list1 Int[] + list2 Int[] + + @@validate( + (str1 == null || length(str1, 8, 10)) + && (int1 == null || (int1 > 1 && int1 < 4)), + 'invalid fields') + + @@validate(str1 == null || (startsWith(str1, 'a') && endsWith(str1, 'm') && contains(str1, 'b')), 'invalid fields') + + @@validate(str2 == null || regex(str2, '^x.*z$'), 'invalid str2') + + @@validate(str3 == null || email(str3), 'invalid str3') + + @@validate(str4 == null || url(str4), 'invalid str4') + + @@validate(str5 == null || datetime(str5), 'invalid str5') + + @@validate(list1 == null || (has(list1, 1) && hasSome(list1, [2, 3]) && hasEvery(list1, [4, 5])), 'invalid list1') + + @@validate(list2 == null || isEmpty(list2), 'invalid list2', ['x', 'y']) + } + `, + { provider: 'postgresql' }, + ); + + await db.foo.create({ data: { id: 1 } }); + + for (const action of ['create', 'update']) { + const _t = + action === 'create' + ? (data: any) => db.foo.create({ data: { id: 2, ...data } }) + : (data: any) => db.foo.update({ where: { id: 1 }, data }); + // violates length + await expect(_t({ str1: 'abd@efg.com' })).toBeRejectedByValidation(['invalid fields']); + await expect(_t({ str1: 'a@b.c' })).toBeRejectedByValidation(['invalid fields']); + + // violates int1 > 1 + await expect(_t({ int1: 1 })).toBeRejectedByValidation(['invalid fields']); + + // violates startsWith + await expect(_t({ str1: 'b@cd.com' })).toBeRejectedByValidation(['invalid fields']); + + // violates endsWith + await expect(_t({ str1: 'a@b.gov' })).toBeRejectedByValidation(['invalid fields']); + + // violates contains + await expect(_t({ str1: 'a@cd.com' })).toBeRejectedByValidation(['invalid fields']); + + // violates regex + await expect(_t({ str2: 'xab' })).toBeRejectedByValidation(['invalid str2']); + + // violates email + await expect(_t({ str3: 'not-an-email' })).toBeRejectedByValidation(['invalid str3']); + + // violates url + await expect(_t({ str4: 'not-an-url' })).toBeRejectedByValidation(['invalid str4']); + + // violates datetime + await expect(_t({ str5: 'not-an-datetime' })).toBeRejectedByValidation(['invalid str5']); + + // violates has + await expect(_t({ list1: [2, 3, 4, 5] })).toBeRejectedByValidation(['invalid list1']); + + // violates hasSome + await expect(_t({ list1: [1, 4, 5] })).toBeRejectedByValidation(['invalid list1']); + + // violates hasEvery + await expect(_t({ list1: [1, 2, 3, 4] })).toBeRejectedByValidation(['invalid list1']); + + // violates isEmpty + let thrown = false; + try { + await _t({ list2: [1] }); + } catch (err) { + thrown = true; + expect((err as any).cause.issues[0].path).toEqual(['data', 'x', 'y']); + } + expect(thrown); + + // satisfies all + await expect( + _t({ + str1: 'ab12345m', + str2: 'x...z', + str3: 'ab@c.com', + str4: 'http://a.b.c', + str5: new Date().toISOString(), + int1: 2, + list1: [1, 2, 4, 5], + list2: [], + }), + ).toResolveTruthy(); + } + }); +}); diff --git a/tests/e2e/orm/validation/nested.test.ts b/tests/e2e/orm/validation/nested.test.ts new file mode 100644 index 00000000..80849e37 --- /dev/null +++ b/tests/e2e/orm/validation/nested.test.ts @@ -0,0 +1,39 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Nested field validation tests', () => { + it('works with nested create/update', async () => { + const db = await createTestClient(` + model User { + id Int @id @default(autoincrement()) + profile Profile? + } + + model Profile { + id Int @id @default(autoincrement()) + email String @email + user User @relation(fields: [userId], references: [id]) + userId Int @unique + @@validate(contains(email, 'zenstack'), 'email must be a zenstack email') + } + `); + + await db.user.create({ data: { id: 1 } }); + + for (const action of ['create', 'update']) { + const _t = + action === 'create' + ? (data: any) => db.user.update({ where: { id: 1 }, data: { profile: { create: data } } }) + : (data: any) => db.user.update({ where: { id: 1 }, data: { profile: { update: data } } }); + + // violates email + await expect(_t({ email: 'zenstack' })).toBeRejectedByValidation(['Invalid email']); + + // violates custom validation + await expect(_t({ email: 'a@b.com' })).toBeRejectedByValidation(['email must be a zenstack email']); + + // satisfies all + await expect(_t({ email: 'me@zenstack.dev' })).toResolveTruthy(); + } + }); +}); diff --git a/tests/e2e/orm/validation/toplevel.test.ts b/tests/e2e/orm/validation/toplevel.test.ts new file mode 100644 index 00000000..5c8843f3 --- /dev/null +++ b/tests/e2e/orm/validation/toplevel.test.ts @@ -0,0 +1,118 @@ +import { createTestClient } from '@zenstackhq/testtools'; +import { describe, expect, it } from 'vitest'; + +describe('Toplevel field validation tests', () => { + it('works with string fields', async () => { + const db = await createTestClient(` + model Foo { + id Int @id @default(autoincrement()) + str1 String? @length(2, 4) @startsWith('a') @endsWith('b') @contains('m') @regex('b{2}') + str2 String? @email + str3 String? @datetime + str4 String? @url + str5 String? @trim @lower + str6 String? @upper + } + `); + + await db.foo.create({ data: { id: 1 } }); + + for (const action of ['create', 'update', 'upsert', 'updateMany']) { + const _t = + action === 'create' + ? (data: any) => db.foo.create({ data }) + : action === 'update' + ? (data: any) => db.foo.update({ where: { id: 1 }, data }) + : action === 'upsert' + ? (data: any) => db.foo.upsert({ where: { id: 1 }, create: data, update: data }) + : (data: any) => db.foo.updateMany({ where: { id: 1 }, data }); + + // violates @length min + await expect(_t({ str1: 'a' })).toBeRejectedByValidation(); + + // violates @length max + await expect(_t({ str1: 'abcde' })).toBeRejectedByValidation(); + + // violates @startsWith + await expect(_t({ str1: 'bcd' })).toBeRejectedByValidation(); + + // violates @endsWith + await expect(_t({ str1: 'abc' })).toBeRejectedByValidation(); + + // violates @contains + await expect(_t({ str1: 'abz' })).toBeRejectedByValidation(); + + // violates @regex + await expect(_t({ str1: 'amcb' })).toBeRejectedByValidation(); + + // satisfies all + await expect(_t({ str1: 'ambb' })).toResolveTruthy(); + + // violates @email + await expect(_t({ str2: 'not-an-email' })).toBeRejectedByValidation(['Invalid email']); + + // satisfies @email + await expect(_t({ str2: 'test@example.com' })).toResolveTruthy(); + + // violates @datetime + await expect(_t({ str3: 'not-datetime' })).toBeRejectedByValidation(); + + // satisfies @datetime + await expect(_t({ str3: new Date().toISOString() })).toResolveTruthy(); + + // violates @url + await expect(_t({ str4: 'not-a-url' })).toBeRejectedByValidation(); + + // satisfies @url + await expect(_t({ str4: 'https://example.com' })).toResolveTruthy(); + + // test @trim and @lower + if (action !== 'updateMany') { + await expect(_t({ str5: ' AbC ' })).resolves.toMatchObject({ str5: 'abc' }); + } else { + await expect(_t({ str5: ' AbC ' })).resolves.toMatchObject({ count: 1 }); + } + + // test @upper + if (action !== 'updateMany') { + await expect(_t({ str6: 'aBc' })).resolves.toMatchObject({ str6: 'ABC' }); + } else { + await expect(_t({ str6: 'aBc' })).resolves.toMatchObject({ count: 1 }); + } + } + }); + + it('works with number fields', async () => { + const db = await createTestClient(` + model Foo { + id Int @id @default(autoincrement()) + int1 Int? @gt(2) @lt(4) + int2 Int? @gte(2) @lte(4) + } + `); + + await db.foo.create({ data: { id: 1 } }); + + for (const action of ['create', 'update']) { + const _t = + action === 'create' + ? (data: any) => db.foo.create({ data }) + : (data: any) => db.foo.update({ where: { id: 1 }, data }); + + // violates @gt + await expect(_t({ int1: 1 })).toBeRejectedByValidation(); + + // violates @lt + await expect(_t({ int1: 4 })).toBeRejectedByValidation(); + + // violates @gte + await expect(_t({ int2: 1 })).toBeRejectedByValidation(); + + // violates @lte + await expect(_t({ int2: 5 })).toBeRejectedByValidation(); + + // satisfies all + await expect(_t({ int1: 3, int2: 4 })).toResolveTruthy(); + } + }); +}); From c52b1df258f3261c89dfe78c8e69148d86748d12 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 8 Oct 2025 12:28:39 -0700 Subject: [PATCH 2/4] update --- packages/language/res/stdlib.zmodel | 8 +- .../src/client/crud/dialects/postgresql.ts | 2 +- .../src/client/crud/dialects/sqlite.ts | 2 +- .../src/client/crud/validator/index.ts | 27 +++- .../src/client/crud/validator/utils.ts | 135 ++++++++++++++---- packages/runtime/src/client/query-builder.ts | 2 +- packages/runtime/src/utils/type-utils.ts | 2 +- packages/sdk/src/schema/schema.ts | 2 +- .../e2e/orm/client-api/type-coverage.test.ts | 4 +- .../orm/validation/custom-validation.test.ts | 8 +- tests/e2e/orm/validation/nested.test.ts | 6 +- tests/e2e/orm/validation/toplevel.test.ts | 103 +++++++++---- 12 files changed, 231 insertions(+), 70 deletions(-) diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 52d34ae4..9684692d 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -543,22 +543,22 @@ attribute @upper() @@@targetField([StringField]) @@@validation /** * Validates a number field is greater than the given value. */ -attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation +attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates a number field is greater than or equal to the given value. */ -attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation +attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates a number field is less than the given value. */ -attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation +attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates a number field is less than or equal to the given value. */ -attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation +attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates the entity with a complex condition. diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index b6c40661..10bed59a 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import Decimal from 'decimal.js'; +import { Decimal } from 'decimal.js'; import { sql, type Expression, diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 5c024dfb..20aeead8 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import Decimal from 'decimal.js'; +import { Decimal } from 'decimal.js'; import { ExpressionWrapper, sql, diff --git a/packages/runtime/src/client/crud/validator/index.ts b/packages/runtime/src/client/crud/validator/index.ts index e0cb39f6..38bd421b 100644 --- a/packages/runtime/src/client/crud/validator/index.ts +++ b/packages/runtime/src/client/crud/validator/index.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import Decimal from 'decimal.js'; +import { Decimal } from 'decimal.js'; import stableStringify from 'json-stable-stringify'; import { match, P } from 'ts-pattern'; import { z, ZodSchema, ZodType } from 'zod'; @@ -40,7 +40,13 @@ import { requireField, requireModel, } from '../../query-utils'; -import { addCustomValidation, addNumberValidation, addStringValidation } from './utils'; +import { + addBigIntValidation, + addCustomValidation, + addDecimalValidation, + addNumberValidation, + addStringValidation, +} from './utils'; type GetSchemaFunc = (model: GetModels, options: Options) => ZodType; @@ -247,10 +253,21 @@ export class InputValidator { return match(type) .with('String', () => addStringValidation(z.string(), attributes)) .with('Int', () => addNumberValidation(z.number().int(), attributes)) - .with('Float', () => z.number()) + .with('Float', () => addNumberValidation(z.number(), attributes)) .with('Boolean', () => z.boolean()) - .with('BigInt', () => z.union([z.number().int(), z.bigint()])) - .with('Decimal', () => z.union([z.number(), z.instanceof(Decimal), z.string()])) + .with('BigInt', () => + z.union([ + addNumberValidation(z.number().int(), attributes), + addBigIntValidation(z.bigint(), attributes), + ]), + ) + .with('Decimal', () => + z.union([ + addNumberValidation(z.number(), attributes), + addDecimalValidation(z.instanceof(Decimal), attributes), + addDecimalValidation(z.string(), attributes), + ]), + ) .with('DateTime', () => z.union([z.date(), z.string().datetime()])) .with('Bytes', () => z.instanceof(Uint8Array)) .otherwise(() => z.unknown()); diff --git a/packages/runtime/src/client/crud/validator/utils.ts b/packages/runtime/src/client/crud/validator/utils.ts index 9c3d647c..1d3a12d1 100644 --- a/packages/runtime/src/client/crud/validator/utils.ts +++ b/packages/runtime/src/client/crud/validator/utils.ts @@ -8,6 +8,7 @@ import type { MemberExpression, UnaryExpression, } from '@zenstackhq/sdk/schema'; +import { Decimal } from 'decimal.js'; import { match, P } from 'ts-pattern'; import { z } from 'zod'; import { ExpressionUtils } from '../../../schema'; @@ -25,62 +26,63 @@ export function addStringValidation(schema: z.ZodString, attributes: AttributeAp return schema; } + let result = schema; for (const attr of attributes) { match(attr.name) .with('@length', () => { const min = getArgValue(attr.args?.[0]?.value); if (min !== undefined) { - schema = schema.min(min); + result = result.min(min); } const max = getArgValue(attr.args?.[1]?.value); if (max !== undefined) { - schema = schema.max(max); + result = result.max(max); } }) .with('@startsWith', () => { const value = getArgValue(attr.args?.[0]?.value); if (value !== undefined) { - schema = schema.startsWith(value); + result = result.startsWith(value); } }) .with('@endsWith', () => { const value = getArgValue(attr.args?.[0]?.value); if (value !== undefined) { - schema = schema.endsWith(value); + result = result.endsWith(value); } }) .with('@contains', () => { const value = getArgValue(attr.args?.[0]?.value); if (value !== undefined) { - schema = schema.includes(value); + result = result.includes(value); } }) .with('@regex', () => { const pattern = getArgValue(attr.args?.[0]?.value); if (pattern !== undefined) { - schema = schema.regex(new RegExp(pattern)); + result = result.regex(new RegExp(pattern)); } }) .with('@email', () => { - schema = schema.email(); + result = result.email(); }) .with('@datetime', () => { - schema = schema.datetime(); + result = result.datetime(); }) .with('@url', () => { - schema = schema.url(); + result = result.url(); }) .with('@trim', () => { - schema = schema.trim(); + result = result.trim(); }) .with('@lower', () => { - schema = schema.toLowerCase(); + result = result.toLowerCase(); }) .with('@upper', () => { - schema = schema.toUpperCase(); + result = result.toUpperCase(); }); } - return schema; + return result; } export function addNumberValidation(schema: z.ZodNumber, attributes: AttributeApplication[] | undefined): z.ZodSchema { @@ -88,6 +90,7 @@ export function addNumberValidation(schema: z.ZodNumber, attributes: AttributeAp return schema; } + let result = schema; for (const attr of attributes) { const val = getArgValue(attr.args?.[0]?.value); if (val === undefined) { @@ -95,25 +98,108 @@ export function addNumberValidation(schema: z.ZodNumber, attributes: AttributeAp } match(attr.name) .with('@gt', () => { - schema = schema.gt(val); + result = result.gt(val); }) .with('@gte', () => { - schema = schema.gte(val); + result = result.gte(val); }) .with('@lt', () => { - schema = schema.lt(val); + result = result.lt(val); }) .with('@lte', () => { - schema = schema.lte(val); + result = result.lte(val); + }); + } + return result; +} + +export function addBigIntValidation(schema: z.ZodBigInt, attributes: AttributeApplication[] | undefined): z.ZodSchema { + if (!attributes || attributes.length === 0) { + return schema; + } + + let result = schema; + for (const attr of attributes) { + const val = getArgValue(attr.args?.[0]?.value); + if (val === undefined) { + continue; + } + const bigIntVal = BigInt(val); + match(attr.name) + .with('@gt', () => { + result = result.gt(bigIntVal); + }) + .with('@gte', () => { + result = result.gte(bigIntVal); }) .with('@lt', () => { - schema = schema.lt(val); + result = result.lt(bigIntVal); }) .with('@lte', () => { - schema = schema.lte(val); + result = result.lte(bigIntVal); }); } - return schema; + return result; +} + +export function addDecimalValidation( + schema: z.ZodType | z.ZodString, + attributes: AttributeApplication[] | undefined, +): z.ZodSchema { + let result: z.ZodSchema = schema; + + // parse string to Decimal + if (schema instanceof z.ZodString) { + result = schema + .superRefine((v, ctx) => { + try { + new Decimal(v); + } catch (err) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: `Invalid decimal: ${err}`, + }); + } + }) + .transform((val) => new Decimal(val)); + } + + // add validations + + function refine(schema: z.ZodSchema, op: 'gt' | 'gte' | 'lt' | 'lte', value: number) { + return schema.superRefine((v, ctx) => { + const base = z.number(); + const { error } = base[op](value).safeParse((v as Decimal).toNumber()); + error?.errors.forEach((e) => { + ctx.addIssue(e); + }); + }); + } + + if (attributes) { + for (const attr of attributes) { + const val = getArgValue(attr.args?.[0]?.value); + if (val === undefined) { + continue; + } + + match(attr.name) + .with('@gt', () => { + result = refine(result, 'gt', val); + }) + .with('@gte', () => { + result = refine(result, 'gte', val); + }) + .with('@lt', () => { + result = refine(result, 'lt', val); + }) + .with('@lte', () => { + result = refine(result, 'lte', val); + }); + } + } + + return result; } export function addCustomValidation(schema: z.ZodSchema, attributes: AttributeApplication[] | undefined): z.ZodSchema { @@ -122,6 +208,7 @@ export function addCustomValidation(schema: z.ZodSchema, attributes: AttributeAp return schema; } + let result = schema; for (const attr of attrs) { const expr = attr.args?.[0]?.value; if (!expr) { @@ -133,9 +220,9 @@ export function addCustomValidation(schema: z.ZodSchema, attributes: AttributeAp if (pathExpr && ExpressionUtils.isArray(pathExpr)) { path = pathExpr.items.map((e) => ExpressionUtils.getLiteralValue(e) as string); } - schema = applyValidation(schema, expr, message, path); + result = applyValidation(result, expr, message, path); } - return schema; + return result; } function applyValidation( @@ -245,10 +332,10 @@ function evalCall(data: any, expr: CallExpression) { const min = getArgValue(expr.args?.[1]); const max = getArgValue(expr.args?.[2]); - if (min && fieldArg.length < min) { + if (min !== undefined && fieldArg.length < min) { return false; } - if (max && fieldArg.length > max) { + if (max !== undefined && fieldArg.length > max) { return false; } return true; diff --git a/packages/runtime/src/client/query-builder.ts b/packages/runtime/src/client/query-builder.ts index 91ec4dfa..be14605b 100644 --- a/packages/runtime/src/client/query-builder.ts +++ b/packages/runtime/src/client/query-builder.ts @@ -1,4 +1,4 @@ -import type Decimal from 'decimal.js'; +import { type Decimal } from 'decimal.js'; import type { Generated, Kysely } from 'kysely'; import type { FieldHasDefault, diff --git a/packages/runtime/src/utils/type-utils.ts b/packages/runtime/src/utils/type-utils.ts index e5cd1f33..e6bbff62 100644 --- a/packages/runtime/src/utils/type-utils.ts +++ b/packages/runtime/src/utils/type-utils.ts @@ -1,4 +1,4 @@ -import type Decimal from 'decimal.js'; +import type { Decimal } from 'decimal.js'; export type Optional = Omit & Partial>; diff --git a/packages/sdk/src/schema/schema.ts b/packages/sdk/src/schema/schema.ts index e8beefc9..7c2cbc72 100644 --- a/packages/sdk/src/schema/schema.ts +++ b/packages/sdk/src/schema/schema.ts @@ -1,4 +1,4 @@ -import type Decimal from 'decimal.js'; +import type { Decimal } from 'decimal.js'; import type { Expression } from './expression'; export type DataSourceProviderType = 'sqlite' | 'postgresql'; diff --git a/tests/e2e/orm/client-api/type-coverage.test.ts b/tests/e2e/orm/client-api/type-coverage.test.ts index 9ce29fce..71e543ce 100644 --- a/tests/e2e/orm/client-api/type-coverage.test.ts +++ b/tests/e2e/orm/client-api/type-coverage.test.ts @@ -1,6 +1,6 @@ -import Decimal from 'decimal.js'; -import { describe, expect, it } from 'vitest'; import { createTestClient, getTestDbProvider } from '@zenstackhq/testtools'; +import { Decimal } from 'decimal.js'; +import { describe, expect, it } from 'vitest'; describe('Zmodel type coverage tests', () => { it('supports all types - plain', async () => { diff --git a/tests/e2e/orm/validation/custom-validation.test.ts b/tests/e2e/orm/validation/custom-validation.test.ts index 543afb50..35667e4c 100644 --- a/tests/e2e/orm/validation/custom-validation.test.ts +++ b/tests/e2e/orm/validation/custom-validation.test.ts @@ -39,13 +39,13 @@ describe('Custom validation tests', () => { { provider: 'postgresql' }, ); - await db.foo.create({ data: { id: 1 } }); + await db.foo.create({ data: { id: 100 } }); for (const action of ['create', 'update']) { const _t = action === 'create' - ? (data: any) => db.foo.create({ data: { id: 2, ...data } }) - : (data: any) => db.foo.update({ where: { id: 1 }, data }); + ? (data: any) => db.foo.create({ data }) + : (data: any) => db.foo.update({ where: { id: 100 }, data }); // violates length await expect(_t({ str1: 'abd@efg.com' })).toBeRejectedByValidation(['invalid fields']); await expect(_t({ str1: 'a@b.c' })).toBeRejectedByValidation(['invalid fields']); @@ -91,7 +91,7 @@ describe('Custom validation tests', () => { thrown = true; expect((err as any).cause.issues[0].path).toEqual(['data', 'x', 'y']); } - expect(thrown); + expect(thrown).toBe(true); // satisfies all await expect( diff --git a/tests/e2e/orm/validation/nested.test.ts b/tests/e2e/orm/validation/nested.test.ts index 80849e37..0949a503 100644 --- a/tests/e2e/orm/validation/nested.test.ts +++ b/tests/e2e/orm/validation/nested.test.ts @@ -3,7 +3,8 @@ import { describe, expect, it } from 'vitest'; describe('Nested field validation tests', () => { it('works with nested create/update', async () => { - const db = await createTestClient(` + const db = await createTestClient( + ` model User { id Int @id @default(autoincrement()) profile Profile? @@ -16,7 +17,8 @@ describe('Nested field validation tests', () => { userId Int @unique @@validate(contains(email, 'zenstack'), 'email must be a zenstack email') } - `); + `, + ); await db.user.create({ data: { id: 1 } }); diff --git a/tests/e2e/orm/validation/toplevel.test.ts b/tests/e2e/orm/validation/toplevel.test.ts index 5c8843f3..51d52366 100644 --- a/tests/e2e/orm/validation/toplevel.test.ts +++ b/tests/e2e/orm/validation/toplevel.test.ts @@ -1,9 +1,11 @@ import { createTestClient } from '@zenstackhq/testtools'; +import { Decimal } from 'decimal.js'; import { describe, expect, it } from 'vitest'; describe('Toplevel field validation tests', () => { it('works with string fields', async () => { - const db = await createTestClient(` + const db = await createTestClient( + ` model Foo { id Int @id @default(autoincrement()) str1 String? @length(2, 4) @startsWith('a') @endsWith('b') @contains('m') @regex('b{2}') @@ -13,19 +15,22 @@ describe('Toplevel field validation tests', () => { str5 String? @trim @lower str6 String? @upper } - `); + `, + ); - await db.foo.create({ data: { id: 1 } }); + await db.foo.create({ data: { id: 100 } }); for (const action of ['create', 'update', 'upsert', 'updateMany']) { + console.log(`Testing action: ${action}`); const _t = action === 'create' ? (data: any) => db.foo.create({ data }) : action === 'update' - ? (data: any) => db.foo.update({ where: { id: 1 }, data }) + ? (data: any) => db.foo.update({ where: { id: 100 }, data }) : action === 'upsert' - ? (data: any) => db.foo.upsert({ where: { id: 1 }, create: data, update: data }) - : (data: any) => db.foo.updateMany({ where: { id: 1 }, data }); + ? (data: any) => + db.foo.upsert({ where: { id: 100 }, create: { id: 101, ...data }, update: data }) + : (data: any) => db.foo.updateMany({ where: { id: 100 }, data }); // violates @length min await expect(_t({ str1: 'a' })).toBeRejectedByValidation(); @@ -83,36 +88,86 @@ describe('Toplevel field validation tests', () => { }); it('works with number fields', async () => { - const db = await createTestClient(` + const db = await createTestClient( + ` model Foo { id Int @id @default(autoincrement()) int1 Int? @gt(2) @lt(4) int2 Int? @gte(2) @lte(4) } - `); + `, + ); - await db.foo.create({ data: { id: 1 } }); + // violates @gt + await expect(db.foo.create({ data: { int1: 1 } })).toBeRejectedByValidation(); - for (const action of ['create', 'update']) { - const _t = - action === 'create' - ? (data: any) => db.foo.create({ data }) - : (data: any) => db.foo.update({ where: { id: 1 }, data }); + // violates @lt + await expect(db.foo.create({ data: { int1: 4 } })).toBeRejectedByValidation(); - // violates @gt - await expect(_t({ int1: 1 })).toBeRejectedByValidation(); + // violates @gte + await expect(db.foo.create({ data: { int2: 1 } })).toBeRejectedByValidation(); - // violates @lt - await expect(_t({ int1: 4 })).toBeRejectedByValidation(); + // violates @lte + await expect(db.foo.create({ data: { int2: 5 } })).toBeRejectedByValidation(); - // violates @gte - await expect(_t({ int2: 1 })).toBeRejectedByValidation(); + // satisfies all + await expect(db.foo.create({ data: { int1: 3, int2: 4 } })).toResolveTruthy(); + }); - // violates @lte - await expect(_t({ int2: 5 })).toBeRejectedByValidation(); + it('works with bigint fields', async () => { + const db = await createTestClient( + ` + model Foo { + id Int @id @default(autoincrement()) + int1 BigInt? @gt(2) @lt(4) + int2 BigInt? @gte(2) @lte(4) + } + `, + ); - // satisfies all - await expect(_t({ int1: 3, int2: 4 })).toResolveTruthy(); + // violates @gt + await expect(db.foo.create({ data: { int1: 1 } })).toBeRejectedByValidation(); + + // violates @lt + await expect(db.foo.create({ data: { int1: 4 } })).toBeRejectedByValidation(); + + // violates @gte + await expect(db.foo.create({ data: { int2: 1n } })).toBeRejectedByValidation(); + + // violates @lte + await expect(db.foo.create({ data: { int2: 5n } })).toBeRejectedByValidation(); + + // satisfies all + await expect(db.foo.create({ data: { int1: 3, int2: 4 } })).toResolveTruthy(); + }); + + it('works with decimal fields', async () => { + const db = await createTestClient( + ` + model Foo { + id Int @id @default(autoincrement()) + int1 Decimal? @gt(2) @lt(4) + int2 Decimal? @gte(2) @lte(4) } + `, + ); + + // violates @gt + await expect(db.foo.create({ data: { int1: 1 } })).toBeRejectedByValidation(); + + // violates @lt + await expect(db.foo.create({ data: { int1: new Decimal(4) } })).toBeRejectedByValidation(); + + // invalid decimal string + await expect(db.foo.create({ data: { int2: 'f1.2' } })).toBeRejectedByValidation(); + + // violates @gte + await expect(db.foo.create({ data: { int2: '1.1' } })).toBeRejectedByValidation(); + + // violates @lte + await expect(db.foo.create({ data: { int2: '5.12345678' } })).toBeRejectedByValidation(); + + // satisfies all + await expect(db.foo.create({ data: { int1: '3.3', int2: new Decimal(3.9) } })).toResolveTruthy(); }); }); From c618515ca2f3e881a5d19d6c1e34ebaa5de4ec06 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 8 Oct 2025 12:50:31 -0700 Subject: [PATCH 3/4] update --- packages/language/res/stdlib.zmodel | 8 ++++---- packages/runtime/src/client/contract.ts | 2 +- packages/runtime/src/client/crud/dialects/postgresql.ts | 2 +- packages/runtime/src/client/crud/dialects/sqlite.ts | 2 +- packages/runtime/src/client/crud/validator/index.ts | 2 +- packages/runtime/src/client/crud/validator/utils.ts | 6 +++--- packages/runtime/src/client/query-builder.ts | 2 +- packages/runtime/src/utils/type-utils.ts | 2 +- packages/sdk/src/schema/schema.ts | 2 +- tests/e2e/orm/client-api/type-coverage.test.ts | 2 +- tests/e2e/orm/validation/toplevel.test.ts | 2 +- 11 files changed, 16 insertions(+), 16 deletions(-) diff --git a/packages/language/res/stdlib.zmodel b/packages/language/res/stdlib.zmodel index 9684692d..85dc8e91 100644 --- a/packages/language/res/stdlib.zmodel +++ b/packages/language/res/stdlib.zmodel @@ -543,22 +543,22 @@ attribute @upper() @@@targetField([StringField]) @@@validation /** * Validates a number field is greater than the given value. */ -attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation +attribute @gt(_ value: Any, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates a number field is greater than or equal to the given value. */ -attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation +attribute @gte(_ value: Any, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates a number field is less than the given value. */ -attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation +attribute @lt(_ value: Any, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates a number field is less than or equal to the given value. */ -attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation +attribute @lte(_ value: Any, _ message: String?) @@@targetField([IntField, FloatField, DecimalField, BigIntField]) @@@validation /** * Validates the entity with a complex condition. diff --git a/packages/runtime/src/client/contract.ts b/packages/runtime/src/client/contract.ts index 002f478c..2374bc6e 100644 --- a/packages/runtime/src/client/contract.ts +++ b/packages/runtime/src/client/contract.ts @@ -1,4 +1,4 @@ -import type { Decimal } from 'decimal.js'; +import type Decimal from 'decimal.js'; import { type GetModels, type IsDelegateModel, type ProcedureDef, type SchemaDef } from '../schema'; import type { AuthType } from '../schema/auth'; import type { OrUndefinedIf, Simplify, UnwrapTuplePromises } from '../utils/type-utils'; diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 10bed59a..b6c40661 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { Decimal } from 'decimal.js'; +import Decimal from 'decimal.js'; import { sql, type Expression, diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 20aeead8..5c024dfb 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { Decimal } from 'decimal.js'; +import Decimal from 'decimal.js'; import { ExpressionWrapper, sql, diff --git a/packages/runtime/src/client/crud/validator/index.ts b/packages/runtime/src/client/crud/validator/index.ts index 38bd421b..90cc67e0 100644 --- a/packages/runtime/src/client/crud/validator/index.ts +++ b/packages/runtime/src/client/crud/validator/index.ts @@ -1,5 +1,5 @@ import { invariant } from '@zenstackhq/common-helpers'; -import { Decimal } from 'decimal.js'; +import Decimal from 'decimal.js'; import stableStringify from 'json-stable-stringify'; import { match, P } from 'ts-pattern'; import { z, ZodSchema, ZodType } from 'zod'; diff --git a/packages/runtime/src/client/crud/validator/utils.ts b/packages/runtime/src/client/crud/validator/utils.ts index 1d3a12d1..6b0a17d5 100644 --- a/packages/runtime/src/client/crud/validator/utils.ts +++ b/packages/runtime/src/client/crud/validator/utils.ts @@ -8,7 +8,7 @@ import type { MemberExpression, UnaryExpression, } from '@zenstackhq/sdk/schema'; -import { Decimal } from 'decimal.js'; +import Decimal from 'decimal.js'; import { match, P } from 'ts-pattern'; import { z } from 'zod'; import { ExpressionUtils } from '../../../schema'; @@ -275,8 +275,8 @@ function evalBinary(data: any, expr: BinaryExpression) { return match(expr.op) .with('&&', () => Boolean(left) && Boolean(right)) .with('||', () => Boolean(left) || Boolean(right)) - .with('==', () => left == right) // eslint-disable-line eqeqeq - .with('!=', () => left != right) // eslint-disable-line eqeqeq + .with('==', () => left == right) + .with('!=', () => left != right) .with('<', () => (left as any) < (right as any)) .with('<=', () => (left as any) <= (right as any)) .with('>', () => (left as any) > (right as any)) diff --git a/packages/runtime/src/client/query-builder.ts b/packages/runtime/src/client/query-builder.ts index be14605b..91ec4dfa 100644 --- a/packages/runtime/src/client/query-builder.ts +++ b/packages/runtime/src/client/query-builder.ts @@ -1,4 +1,4 @@ -import { type Decimal } from 'decimal.js'; +import type Decimal from 'decimal.js'; import type { Generated, Kysely } from 'kysely'; import type { FieldHasDefault, diff --git a/packages/runtime/src/utils/type-utils.ts b/packages/runtime/src/utils/type-utils.ts index e6bbff62..e5cd1f33 100644 --- a/packages/runtime/src/utils/type-utils.ts +++ b/packages/runtime/src/utils/type-utils.ts @@ -1,4 +1,4 @@ -import type { Decimal } from 'decimal.js'; +import type Decimal from 'decimal.js'; export type Optional = Omit & Partial>; diff --git a/packages/sdk/src/schema/schema.ts b/packages/sdk/src/schema/schema.ts index 7c2cbc72..e8beefc9 100644 --- a/packages/sdk/src/schema/schema.ts +++ b/packages/sdk/src/schema/schema.ts @@ -1,4 +1,4 @@ -import type { Decimal } from 'decimal.js'; +import type Decimal from 'decimal.js'; import type { Expression } from './expression'; export type DataSourceProviderType = 'sqlite' | 'postgresql'; diff --git a/tests/e2e/orm/client-api/type-coverage.test.ts b/tests/e2e/orm/client-api/type-coverage.test.ts index 71e543ce..a0c24880 100644 --- a/tests/e2e/orm/client-api/type-coverage.test.ts +++ b/tests/e2e/orm/client-api/type-coverage.test.ts @@ -1,5 +1,5 @@ import { createTestClient, getTestDbProvider } from '@zenstackhq/testtools'; -import { Decimal } from 'decimal.js'; +import Decimal from 'decimal.js'; import { describe, expect, it } from 'vitest'; describe('Zmodel type coverage tests', () => { diff --git a/tests/e2e/orm/validation/toplevel.test.ts b/tests/e2e/orm/validation/toplevel.test.ts index 51d52366..ed2c513c 100644 --- a/tests/e2e/orm/validation/toplevel.test.ts +++ b/tests/e2e/orm/validation/toplevel.test.ts @@ -1,5 +1,5 @@ import { createTestClient } from '@zenstackhq/testtools'; -import { Decimal } from 'decimal.js'; +import Decimal from 'decimal.js'; import { describe, expect, it } from 'vitest'; describe('Toplevel field validation tests', () => { From 9639a7790c5cfb5b21406bf4e236c6789267d41c Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 8 Oct 2025 12:59:16 -0700 Subject: [PATCH 4/4] update --- tests/e2e/orm/validation/toplevel.test.ts | 38 ++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/tests/e2e/orm/validation/toplevel.test.ts b/tests/e2e/orm/validation/toplevel.test.ts index ed2c513c..a7d76475 100644 --- a/tests/e2e/orm/validation/toplevel.test.ts +++ b/tests/e2e/orm/validation/toplevel.test.ts @@ -1,4 +1,4 @@ -import { createTestClient } from '@zenstackhq/testtools'; +import { createTestClient, loadSchemaWithError } from '@zenstackhq/testtools'; import Decimal from 'decimal.js'; import { describe, expect, it } from 'vitest'; @@ -170,4 +170,40 @@ describe('Toplevel field validation tests', () => { // satisfies all await expect(db.foo.create({ data: { int1: '3.3', int2: new Decimal(3.9) } })).toResolveTruthy(); }); + + it('rejects accessing relation fields', async () => { + await loadSchemaWithError( + ` + model Foo { + id Int @id @default(autoincrement()) + bars Bar[] + @@validate(bars != null) + } + + model Bar { + id Int @id @default(autoincrement()) + foo Foo @relation(fields: [fooId], references: [id]) + fooId Int + } + `, + 'cannot use relation fields', + ); + + await loadSchemaWithError( + ` + model Foo { + id Int @id @default(autoincrement()) + bars Bar[] + @@validate(bars.fooId > 0) + } + + model Bar { + id Int @id @default(autoincrement()) + foo Foo @relation(fields: [fooId], references: [id]) + fooId Int + } + `, + 'cannot use relation fields', + ); + }); });