Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@
- [x] Sorting
- [x] Pagination
- [x] Distinct
- [ ] Update
- [x] Update
- [x] Input validation
- [x] Top-level
- [x] Nested to-many
- [x] Nested to-one
- [x] Incremental update for numeric fields
- [x] Array update
- [x] Upsert
- [ ] Implement with "on conflict"
- [x] Delete
- [x] Aggregation
- [x] Count
Expand All @@ -54,22 +55,23 @@
- [x] Computed fields
- [ ] Prisma client extension
- [ ] Misc
- [ ] Cache validation schemas
- [ ] Compound ID
- [ ] Cross field comparison
- [ ] Many-to-many relation
- [ ] Cache validation schemas
- [x] Many-to-many relation
- [ ] Empty AND/OR/NOT behavior
- [?] Logging
- [ ] Error system
- [?] Error system
- [x] Custom table name
- [x] Custom field name
- [ ] Empty AND/OR/NOT behavior
- [?] Strict undefined check
- [ ] Access Policy
- [ ] Short-circuit pre-create check for scalar-field only policies
- [ ] Inject "replace into"
- [ ] Inject "on conflict do update"
- [ ] Polymorphism
- [x] Migration
- [ ] Databases
- [x] SQLite
- [x] PostgreSQL
- [ ] Schema
- [ ] MySQL
3 changes: 1 addition & 2 deletions packages/cli/src/actions/generate.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import { isPlugin, LiteralExpr, type Model } from '@zenstackhq/language/ast';
import type { CliGenerator } from '@zenstackhq/runtime/client';
import { TsSchemaGenerator } from '@zenstackhq/sdk';
import { PrismaSchemaGenerator, TsSchemaGenerator } from '@zenstackhq/sdk';
import colors from 'colors';
import fs from 'node:fs';
import path from 'node:path';
import invariant from 'tiny-invariant';
import { PrismaSchemaGenerator } from '../prisma/prisma-schema-generator';
import { getSchemaFile, loadSchemaDocument } from './action-utils';

type Options = {
Expand Down
2 changes: 1 addition & 1 deletion packages/cli/test/ts-schema-gen.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { describe, expect, it } from 'vitest';

describe('TypeScript schema generation tests', () => {
it('generates correct data models', async () => {
const schema = await generateTsSchema(`
const { schema } = await generateTsSchema(`
model User {
id String @id @default(uuid())
name String
Expand Down
1 change: 1 addition & 0 deletions packages/runtime/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"@types/tmp": "^0.2.6",
"@zenstackhq/language": "workspace:*",
"@zenstackhq/testtools": "workspace:*",
"@zenstackhq/sdk": "workspace:*",
"tmp": "^0.2.3"
}
}
18 changes: 18 additions & 0 deletions packages/runtime/src/client/client-impl.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {
DefaultConnectionProvider,
DefaultQueryExecutor,
Kysely,
Log,
PostgresDialect,
Expand Down Expand Up @@ -47,6 +48,7 @@ export const ZenStackClient = function <Schema extends SchemaDef>(

export class ClientImpl<Schema extends SchemaDef> {
private kysely: ToKysely<Schema>;
private kyselyRaw: ToKysely<any>;
public readonly $options: ClientOptions<Schema>;
public readonly $schema: Schema;
readonly kyselyProps: KyselyProps;
Expand Down Expand Up @@ -77,6 +79,7 @@ export class ClientImpl<Schema extends SchemaDef> {
new DefaultConnectionProvider(baseClient.kyselyProps.driver)
),
};
this.kyselyRaw = baseClient.kyselyRaw;
} else {
const dialect = this.getKyselyDialect();
const driver = new ZenStackDriver(
Expand All @@ -103,6 +106,17 @@ export class ClientImpl<Schema extends SchemaDef> {
driver,
executor,
};

// raw kysely instance with default executor
this.kyselyRaw = new Kysely({
...this.kyselyProps,
executor: new DefaultQueryExecutor(
compiler,
adapter,
connectionProvider,
[]
),
});
}

this.kysely = new Kysely(this.kyselyProps);
Expand All @@ -114,6 +128,10 @@ export class ClientImpl<Schema extends SchemaDef> {
return this.kysely;
}

public get $qbRaw() {
return this.kyselyRaw;
}

private getKyselyDialect() {
return match(this.schema.provider.type)
.with('sqlite', () => this.makeSqliteKyselyDialect())
Expand Down
5 changes: 5 additions & 0 deletions packages/runtime/src/client/contract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ export type ClientContract<Schema extends SchemaDef> = {
*/
readonly $qb: ToKysely<Schema>;

/**
* The raw Kysely query builder without any ZenStack enhancements.
*/
readonly $qbRaw: ToKysely<any>;

/**
* Starts a transaction.
*/
Expand Down
103 changes: 66 additions & 37 deletions packages/runtime/src/client/crud/dialects/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import {
buildFieldRef,
buildJoinPairs,
getField,
getIdFields,
getManyToManyRelation,
getRelationForeignKeyFieldPairs,
isEnum,
makeDefaultOrderBy,
Expand Down Expand Up @@ -68,18 +70,18 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
eb: ExpressionBuilder<any, any>,
model: string,
modelAlias: string,
where: object | undefined
where: boolean | object | undefined
) {
let result = this.true(eb);

if (where === undefined) {
return result;
if (where === true || where === undefined) {
return this.true(eb);
}

if (where === null || typeof where !== 'object') {
throw new InternalError('impossible null as filter');
if (where === false) {
return this.false(eb);
}

let result = this.true(eb);

for (const [key, payload] of Object.entries(where)) {
if (payload === undefined) {
continue;
Expand Down Expand Up @@ -148,7 +150,12 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
}

// call expression builder and combine the results
if ('$expr' in where && typeof where['$expr'] === 'function') {
if (
typeof where === 'object' &&
where !== null &&
'$expr' in where &&
typeof where['$expr'] === 'function'
) {
result = this.and(eb, result, where['$expr'](eb));
}

Expand Down Expand Up @@ -356,45 +363,67 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
fieldDef: FieldDef,
payload: any
) {
const relationModel = fieldDef.type;

const relationKeyPairs = getRelationForeignKeyFieldPairs(
this.schema,
model,
field
);

// null check needs to be converted to fk "is null" checks
if (payload === null) {
return eb(sql.ref(`${table}.${field}`), 'is', null);
}

const relationModel = fieldDef.type;

const buildPkFkWhereRefs = (eb: ExpressionBuilder<any, any>) => {
let r = this.true(eb);
for (const { fk, pk } of relationKeyPairs.keyPairs) {
if (relationKeyPairs.ownedByModel) {
r = this.and(
eb,
r,
eb(
sql.ref(`${table}.${fk}`),
'=',
sql.ref(`${relationModel}.${pk}`)
)
);
} else {
r = this.and(
eb,
r,
eb(
sql.ref(`${table}.${pk}`),
const m2m = getManyToManyRelation(this.schema, model, field);
if (m2m) {
// many-to-many relation
const modelIdField = getIdFields(this.schema, model)[0]!;
const relationIdField = getIdFields(
this.schema,
relationModel
)[0]!;
return eb(
sql.ref(`${relationModel}.${relationIdField}`),
'in',
eb
.selectFrom(m2m.joinTable)
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
.whereRef(
sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
'=',
sql.ref(`${relationModel}.${fk}`)
sql.ref(`${table}.${modelIdField}`)
)
);
);
} else {
const relationKeyPairs = getRelationForeignKeyFieldPairs(
this.schema,
model,
field
);

let result = this.true(eb);
for (const { fk, pk } of relationKeyPairs.keyPairs) {
if (relationKeyPairs.ownedByModel) {
result = this.and(
eb,
result,
eb(
sql.ref(`${table}.${fk}`),
'=',
sql.ref(`${relationModel}.${pk}`)
)
);
} else {
result = this.and(
eb,
result,
eb(
sql.ref(`${table}.${pk}`),
'=',
sql.ref(`${relationModel}.${fk}`)
)
);
}
}
return result;
}
return r;
};

let result = this.true(eb);
Expand Down
65 changes: 54 additions & 11 deletions packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ import {
type RawBuilder,
type SelectQueryBuilder,
} from 'kysely';
import invariant from 'tiny-invariant';
import { match } from 'ts-pattern';
import type { SchemaDef } from '../../../schema';
import type { BuiltinType, FieldDef, GetModels } from '../../../schema/schema';
import type { FindArgs } from '../../crud-types';
import {
buildFieldRef,
buildJoinPairs,
getIdFields,
getManyToManyRelation,
requireField,
requireModel,
} from '../../query-utils';
Expand Down Expand Up @@ -129,21 +132,61 @@ export class PostgresCrudDialect<
}

// add join conditions
const joinPairs = buildJoinPairs(

const m2m = getManyToManyRelation(
this.schema,
model,
parentName,
relationField,
relationModel
relationField
);
subQuery = subQuery.where((eb) =>
this.and(
eb,
...joinPairs.map(([left, right]) =>
eb(sql.ref(left), '=', sql.ref(right))

if (m2m) {
// many-to-many relation
const parentIds = getIdFields(this.schema, model);
const relationIds = getIdFields(
this.schema,
relationModel
);
invariant(
parentIds.length === 1,
'many-to-many relation must have exactly one id field'
);
invariant(
relationIds.length === 1,
'many-to-many relation must have exactly one id field'
);
subQuery = subQuery.where(
eb(
eb.ref(`${relationModel}.${relationIds[0]}`),
'in',
eb
.selectFrom(m2m.joinTable)
.select(
`${m2m.joinTable}.${m2m.otherFkName}`
)
.whereRef(
`${parentName}.${parentIds[0]}`,
'=',
`${m2m.joinTable}.${m2m.parentFkName}`
)
)
)
);
);
} else {
const joinPairs = buildJoinPairs(
this.schema,
model,
parentName,
relationField,
relationModel
);
subQuery = subQuery.where((eb) =>
this.and(
eb,
...joinPairs.map(([left, right]) =>
eb(sql.ref(left), '=', sql.ref(right))
)
)
);
}

return subQuery.as(joinTableName);
});
Expand Down
Loading