Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ ZenStack v3 allows you to define database-evaluated computed fields with the fol
postCount: (eb) =>
eb
.selectFrom('Post')
.whereRef('Post.authorId', '=', 'User.id')
.whereRef('Post.authorId', '=', 'id')
.select(({ fn }) =>
fn.countAll<number>().as('postCount')
),
Expand Down
82 changes: 52 additions & 30 deletions packages/runtime/src/client/crud/dialects/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {

// #region common query builders

buildSelectModel(eb: ExpressionBuilder<any, any>, model: string) {
buildSelectModel(eb: ExpressionBuilder<any, any>, model: string, modelAlias: string) {
const modelDef = requireModel(this.schema, model);
let result = eb.selectFrom(model);
let result = eb.selectFrom(model === modelAlias ? model : `${model} as ${modelAlias}`);
// join all delegate bases
let joinBase = modelDef.baseModel;
while (joinBase) {
result = this.buildDelegateJoin(model, joinBase, result);
result = this.buildDelegateJoin(model, modelAlias, joinBase, result);
joinBase = requireModel(this.schema, joinBase).baseModel;
}
return result;
Expand All @@ -63,12 +63,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
model: GetModels<Schema>,
args: FindArgs<Schema, GetModels<Schema>, true>,
query: SelectQueryBuilder<any, any, {}>,
modelAlias: string,
) {
let result = query;

// where
if (args.where) {
result = result.where((eb) => this.buildFilter(eb, model, model, args?.where));
result = result.where((eb) => this.buildFilter(eb, model, modelAlias, args?.where));
}

// skip && take
Expand All @@ -85,7 +86,7 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
result = this.buildOrderBy(
result,
model,
model,
modelAlias,
args.orderBy,
skip !== undefined || take !== undefined,
negateOrderBy,
Expand All @@ -95,14 +96,14 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
if ('distinct' in args && (args as any).distinct) {
const distinct = ensureArray((args as any).distinct) as string[];
if (this.supportsDistinctOn) {
result = result.distinctOn(distinct.map((f) => sql.ref(`${model}.${f}`)));
result = result.distinctOn(distinct.map((f) => sql.ref(`${modelAlias}.${f}`)));
} else {
throw new QueryError(`"distinct" is not supported by "${this.schema.provider.type}" provider`);
}
}

if (args.cursor) {
result = this.buildCursorFilter(model, result, args.cursor, args.orderBy, negateOrderBy);
result = this.buildCursorFilter(model, result, args.cursor, args.orderBy, negateOrderBy, modelAlias);
}
return result;
}
Expand Down Expand Up @@ -172,13 +173,15 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
cursor: FindArgs<Schema, GetModels<Schema>, true>['cursor'],
orderBy: FindArgs<Schema, GetModels<Schema>, true>['orderBy'],
negateOrderBy: boolean,
modelAlias: string,
) {
const _orderBy = orderBy ?? makeDefaultOrderBy(this.schema, model);

const orderByItems = ensureArray(_orderBy).flatMap((obj) => Object.entries<SortOrder>(obj));

const eb = expressionBuilder<any, any>();
const cursorFilter = this.buildFilter(eb, model, model, cursor);
const subQueryAlias = `${model}$cursor$sub`;
const cursorFilter = this.buildFilter(eb, model, subQueryAlias, cursor);

let result = query;
const filters: ExpressionWrapper<any, any, any>[] = [];
Expand All @@ -192,9 +195,11 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
const op = j === i ? (_order === 'asc' ? '>=' : '<=') : '=';
andFilters.push(
eb(
eb.ref(`${model}.${field}`),
eb.ref(`${modelAlias}.${field}`),
op,
eb.selectFrom(model).select(`${model}.${field}`).where(cursorFilter),
this.buildSelectModel(eb, model, subQueryAlias)
.select(`${subQueryAlias}.${field}`)
.where(cursorFilter),
),
);
}
Expand Down Expand Up @@ -341,34 +346,38 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
private buildToManyRelationFilter(
eb: ExpressionBuilder<any, any>,
model: string,
table: string,
modelAlias: string,
field: string,
fieldDef: FieldDef,
payload: any,
) {
// null check needs to be converted to fk "is null" checks
if (payload === null) {
return eb(sql.ref(`${table}.${field}`), 'is', null);
return eb(sql.ref(`${modelAlias}.${field}`), 'is', null);
}

const relationModel = fieldDef.type;

// evaluating the filter involves creating an inner select,
// give it an alias to avoid conflict
const relationFilterSelectAlias = `${modelAlias}$${field}$filter`;

const buildPkFkWhereRefs = (eb: ExpressionBuilder<any, any>) => {
const m2m = getManyToManyRelation(this.schema, model, field);
if (m2m) {
// many-to-many relation
const modelIdField = getIdFields(this.schema, model)[0]!;
const relationIdField = getIdFields(this.schema, relationModel)[0]!;
return eb(
sql.ref(`${relationModel}.${relationIdField}`),
sql.ref(`${relationFilterSelectAlias}.${relationIdField}`),
'in',
eb
.selectFrom(m2m.joinTable)
.select(`${m2m.joinTable}.${m2m.otherFkName}`)
.whereRef(
sql.ref(`${m2m.joinTable}.${m2m.parentFkName}`),
'=',
sql.ref(`${table}.${modelIdField}`),
sql.ref(`${modelAlias}.${modelIdField}`),
),
);
} else {
Expand All @@ -380,13 +389,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
result = this.and(
eb,
result,
eb(sql.ref(`${table}.${fk}`), '=', sql.ref(`${relationModel}.${pk}`)),
eb(sql.ref(`${modelAlias}.${fk}`), '=', sql.ref(`${relationFilterSelectAlias}.${pk}`)),
);
} else {
result = this.and(
eb,
result,
eb(sql.ref(`${table}.${pk}`), '=', sql.ref(`${relationModel}.${fk}`)),
eb(sql.ref(`${modelAlias}.${pk}`), '=', sql.ref(`${relationFilterSelectAlias}.${fk}`)),
);
}
}
Expand All @@ -407,10 +416,12 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
eb,
result,
eb(
this.buildSelectModel(eb, relationModel)
this.buildSelectModel(eb, relationModel, relationFilterSelectAlias)
.select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count'))
.where(buildPkFkWhereRefs(eb))
.where((eb1) => this.buildFilter(eb1, relationModel, relationModel, subPayload)),
.where((eb1) =>
this.buildFilter(eb1, relationModel, relationFilterSelectAlias, subPayload),
),
'>',
0,
),
Expand All @@ -423,11 +434,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
eb,
result,
eb(
this.buildSelectModel(eb, relationModel)
this.buildSelectModel(eb, relationModel, relationFilterSelectAlias)
.select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count'))
.where(buildPkFkWhereRefs(eb))
.where((eb1) =>
eb1.not(this.buildFilter(eb1, relationModel, relationModel, subPayload)),
eb1.not(
this.buildFilter(eb1, relationModel, relationFilterSelectAlias, subPayload),
),
),
'=',
0,
Expand All @@ -441,10 +454,12 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
eb,
result,
eb(
this.buildSelectModel(eb, relationModel)
this.buildSelectModel(eb, relationModel, relationFilterSelectAlias)
.select((eb1) => eb1.fn.count(eb1.lit(1)).as('$count'))
.where(buildPkFkWhereRefs(eb))
.where((eb1) => this.buildFilter(eb1, relationModel, relationModel, subPayload)),
.where((eb1) =>
this.buildFilter(eb1, relationModel, relationFilterSelectAlias, subPayload),
),
'=',
0,
),
Expand Down Expand Up @@ -874,8 +889,9 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
);
const sort = this.negateSort(value._count, negated);
result = result.orderBy((eb) => {
let subQuery = this.buildSelectModel(eb, relationModel);
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, relationModel);
const subQueryAlias = `${modelAlias}$orderBy$${field}$count`;
let subQuery = this.buildSelectModel(eb, relationModel, subQueryAlias);
const joinPairs = buildJoinPairs(this.schema, model, modelAlias, field, subQueryAlias);
subQuery = subQuery.where(() =>
this.and(
eb,
Expand Down Expand Up @@ -909,7 +925,8 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
buildSelectAllFields(
model: string,
query: SelectQueryBuilder<any, any, any>,
omit?: Record<string, boolean | undefined>,
omit: Record<string, boolean | undefined> | undefined,
modelAlias: string,
) {
const modelDef = requireModel(this.schema, model);
let result = query;
Expand All @@ -921,13 +938,13 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
if (omit?.[field] === true) {
continue;
}
result = this.buildSelectField(result, model, model, field);
result = this.buildSelectField(result, model, modelAlias, field);
}

// select all fields from delegate descendants and pack into a JSON field `$delegate$Model`
const descendants = getDelegateDescendantModels(this.schema, model);
for (const subModel of descendants) {
result = this.buildDelegateJoin(model, subModel.name, result);
result = this.buildDelegateJoin(model, modelAlias, subModel.name, result);
result = result.select((eb) => {
const jsonObject: Record<string, Expression<any>> = {};
for (const field of Object.keys(subModel.fields)) {
Expand Down Expand Up @@ -964,11 +981,16 @@ export abstract class BaseCrudDialect<Schema extends SchemaDef> {
}
}

buildDelegateJoin(thisModel: string, otherModel: string, query: SelectQueryBuilder<any, any, any>) {
buildDelegateJoin(
thisModel: string,
thisModelAlias: string,
otherModelAlias: string,
query: SelectQueryBuilder<any, any, any>,
) {
const idFields = getIdFields(this.schema, thisModel);
query = query.leftJoin(otherModel, (qb) => {
query = query.leftJoin(otherModelAlias, (qb) => {
for (const idField of idFields) {
qb = qb.onRef(`${thisModel}.${idField}`, '=', `${otherModel}.${idField}`);
qb = qb.onRef(`${thisModelAlias}.${idField}`, '=', `${otherModelAlias}.${idField}`);
}
return qb;
});
Expand Down
31 changes: 25 additions & 6 deletions packages/runtime/src/client/crud/dialects/postgresql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,22 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale

// however if there're filter/orderBy/take/skip,
// we need to build a subquery to handle them before aggregation

// give sub query an alias to avoid conflict with parent scope
// (e.g., for cases like self-relation)
const subQueryAlias = `${relationModel}$${relationField}$sub`;

result = eb.selectFrom(() => {
let subQuery = this.buildSelectModel(eb, relationModel);
let subQuery = this.buildSelectModel(eb, relationModel, subQueryAlias);
subQuery = this.buildSelectAllFields(
relationModel,
subQuery,
typeof payload === 'object' ? payload?.omit : undefined,
subQueryAlias,
);

if (payload && typeof payload === 'object') {
subQuery = this.buildFilterSortTake(relationModel, payload, subQuery);
subQuery = this.buildFilterSortTake(relationModel, payload, subQuery, subQueryAlias);
}

// add join conditions
Expand All @@ -106,7 +112,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
invariant(relationIds.length === 1, 'many-to-many relation must have exactly one id field');
subQuery = subQuery.where(
eb(
eb.ref(`${relationModel}.${relationIds[0]}`),
eb.ref(`${subQueryAlias}.${relationIds[0]}`),
'in',
eb
.selectFrom(m2m.joinTable)
Expand All @@ -119,7 +125,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
),
);
} else {
const joinPairs = buildJoinPairs(this.schema, model, parentName, relationField, relationModel);
const joinPairs = buildJoinPairs(this.schema, model, parentName, relationField, subQueryAlias);
subQuery = subQuery.where((eb) =>
this.and(eb, ...joinPairs.map(([left, right]) => eb(sql.ref(left), '=', sql.ref(right)))),
);
Expand All @@ -130,6 +136,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale

result = this.buildRelationObjectSelect(
relationModel,
joinTableName,
relationField,
relationFieldDef,
result,
Expand All @@ -149,14 +156,22 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale

private buildRelationObjectSelect(
relationModel: string,
relationModelAlias: string,
relationField: string,
relationFieldDef: FieldDef,
qb: SelectQueryBuilder<any, any, any>,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
parentName: string,
) {
qb = qb.select((eb) => {
const objArgs = this.buildRelationObjectArgs(relationModel, relationField, eb, payload, parentName);
const objArgs = this.buildRelationObjectArgs(
relationModel,
relationModelAlias,
relationField,
eb,
payload,
parentName,
);

if (relationFieldDef.array) {
return eb.fn
Expand All @@ -172,6 +187,7 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale

private buildRelationObjectArgs(
relationModel: string,
relationModelAlias: string,
relationField: string,
eb: ExpressionBuilder<any, any>,
payload: true | FindArgs<Schema, GetModels<Schema>, true>,
Expand Down Expand Up @@ -202,7 +218,10 @@ export class PostgresCrudDialect<Schema extends SchemaDef> extends BaseCrudDiale
...Object.entries(relationModelDef.fields)
.filter(([, value]) => !value.relation)
.filter(([name]) => !(typeof payload === 'object' && (payload.omit as any)?.[name] === true))
.map(([field]) => [sql.lit(field), this.fieldRef(relationModel, field, eb, undefined, false)])
.map(([field]) => [
sql.lit(field),
this.fieldRef(relationModel, field, eb, relationModelAlias, false),
])
.flatMap((v) => v),
);
} else if (payload.select) {
Expand Down
Loading