diff --git a/TODO.md b/TODO.md index a7fd7374..f2393796 100644 --- a/TODO.md +++ b/TODO.md @@ -42,7 +42,7 @@ - [x] Nested to-many - [x] Nested to-one - [x] Incremental update for numeric fields - - [ ] Array update + - [x] Array update - [x] Upsert - [x] Delete - [x] Aggregation diff --git a/packages/runtime/src/client/crud-types.ts b/packages/runtime/src/client/crud-types.ts index fc98c5a2..999b1606 100644 --- a/packages/runtime/src/client/crud-types.ts +++ b/packages/runtime/src/client/crud-types.ts @@ -168,6 +168,8 @@ export type WhereInput< GetFieldType, FieldIsOptional > + : FieldIsArray extends true + ? ArrayFilter> : // primitive PrimitiveFilter< GetFieldType, @@ -183,7 +185,7 @@ export type WhereInput< NOT?: OrArray>; }; -export type EnumFilter< +type EnumFilter< Schema extends SchemaDef, T extends GetEnums, Nullable extends boolean @@ -196,7 +198,15 @@ export type EnumFilter< not?: EnumFilter; }; -export type PrimitiveFilter< +type ArrayFilter = { + equals?: MapBaseType[]; + has?: MapBaseType; + hasEvery?: MapBaseType[]; + hasSome?: MapBaseType[]; + isEmpty?: boolean; +}; + +type PrimitiveFilter< T extends string, Nullable extends boolean > = T extends 'String' @@ -622,7 +632,7 @@ type CreateScalarPayload< Schema, Model, { - [Key in ScalarFields]: MapFieldType< + [Key in ScalarFields]: ScalarCreatePayload< Schema, Model, Key @@ -630,6 +640,18 @@ type CreateScalarPayload< } >; +type ScalarCreatePayload< + Schema extends SchemaDef, + Model extends GetModels, + Field extends ScalarFields +> = + | MapFieldType + | (FieldIsArray extends true + ? { + set?: MapFieldType[]; + } + : never); + type CreateFKPayload< Schema extends SchemaDef, Model extends GetModels @@ -802,6 +824,12 @@ type ScalarUpdatePayload< multiply?: number; divide?: number; } + : never) + | (FieldIsArray extends true + ? { + set?: MapFieldType[]; + push?: OrArray, true>; + } : never); export type UpdateRelationInput< @@ -937,7 +965,9 @@ type NumericFields< Model, Key > extends 'Int' | 'Float' | 'BigInt' | 'Decimal' - ? Key + ? FieldIsArray extends true + ? never + : Key : never]: GetField; }; diff --git a/packages/runtime/src/client/crud/dialects/base.ts b/packages/runtime/src/client/crud/dialects/base.ts index 0365224c..f89fbf15 100644 --- a/packages/runtime/src/client/crud/dialects/base.ts +++ b/packages/runtime/src/client/crud/dialects/base.ts @@ -118,6 +118,19 @@ export abstract class BaseCrudDialect { payload ) ); + } else if (fieldDef.array) { + result = this.and( + eb, + result, + this.buildArrayFilter( + eb, + model, + modelAlias, + key, + fieldDef, + payload + ) + ); } else { result = this.and( eb, @@ -477,37 +490,120 @@ export abstract class BaseCrudDialect { return result; } + private buildArrayFilter( + eb: ExpressionBuilder, + model: string, + modelAlias: string, + field: string, + fieldDef: FieldDef, + payload: any + ) { + const clauses: Expression[] = []; + const fieldType = fieldDef.type as BuiltinType; + const fieldRef = buildFieldRef( + this.schema, + model, + field, + this.options, + eb, + modelAlias + ); + + for (const [key, _value] of Object.entries(payload)) { + if (_value === undefined) { + continue; + } + + const value = this.transformPrimitive(_value, fieldType); + + switch (key) { + case 'equals': { + clauses.push( + this.buildLiteralFilter( + eb, + fieldRef, + fieldType, + eb.val(value) + ) + ); + break; + } + + case 'has': { + clauses.push(eb(fieldRef, '@>', eb.val([value]))); + break; + } + + case 'hasEvery': { + clauses.push(eb(fieldRef, '@>', eb.val(value))); + break; + } + + case 'hasSome': { + clauses.push(eb(fieldRef, '&&', eb.val(value))); + break; + } + + case 'isEmpty': { + clauses.push( + eb(fieldRef, value === true ? '=' : '!=', eb.val([])) + ); + break; + } + + default: { + throw new InternalError(`Invalid array filter key: ${key}`); + } + } + } + + return this.and(eb, ...clauses); + } + buildPrimitiveFilter( eb: ExpressionBuilder, model: string, - table: string, + modelAlias: string, field: string, fieldDef: FieldDef, payload: any ) { if (payload === null) { - return eb(sql.ref(`${table}.${field}`), 'is', null); + return eb(sql.ref(`${modelAlias}.${field}`), 'is', null); } if (isEnum(this.schema, fieldDef.type)) { - return this.buildEnumFilter(eb, table, field, fieldDef, payload); + return this.buildEnumFilter( + eb, + modelAlias, + field, + fieldDef, + payload + ); } return match(fieldDef.type as BuiltinType) .with('String', () => - this.buildStringFilter(eb, table, field, payload) + this.buildStringFilter(eb, modelAlias, field, payload) ) .with(P.union('Int', 'Float', 'Decimal', 'BigInt'), (type) => - this.buildNumberFilter(eb, model, table, field, type, payload) + this.buildNumberFilter( + eb, + model, + modelAlias, + field, + type, + payload + ) ) .with('Boolean', () => - this.buildBooleanFilter(eb, table, field, payload) + this.buildBooleanFilter(eb, modelAlias, field, payload) ) .with('DateTime', () => - this.buildDateTimeFilter(eb, table, field, payload) + this.buildDateTimeFilter(eb, modelAlias, field, payload) ) .with('Bytes', () => - this.buildBytesFilter(eb, table, field, payload) + this.buildBytesFilter(eb, modelAlias, field, payload) ) .exhaustive(); } diff --git a/packages/runtime/src/client/crud/dialects/postgresql.ts b/packages/runtime/src/client/crud/dialects/postgresql.ts index 28114760..857f45e7 100644 --- a/packages/runtime/src/client/crud/dialects/postgresql.ts +++ b/packages/runtime/src/client/crud/dialects/postgresql.ts @@ -25,16 +25,24 @@ export class PostgresCrudDialect< return 'postgresql' as const; } - override transformPrimitive(value: unknown, type: BuiltinType) { - return match(type) - .with('DateTime', () => - value instanceof Date - ? value - : typeof value === 'string' - ? new Date(value) - : value - ) - .otherwise(() => value); + override transformPrimitive(value: unknown, type: BuiltinType): unknown { + if (value === undefined) { + return value; + } + + if (Array.isArray(value)) { + return value.map((v) => this.transformPrimitive(v, type)); + } else { + return match(type) + .with('DateTime', () => + value instanceof Date + ? value + : typeof value === 'string' + ? new Date(value) + : value + ) + .otherwise(() => value); + } } override buildRelationSelection( @@ -347,8 +355,12 @@ export class PostgresCrudDialect< } override buildArrayLiteralSQL(values: unknown[]): string { - return `ARRAY[${values.map((v) => - typeof v === 'string' ? `'${v}'` : v - )}]`; + if (values.length === 0) { + return '{}'; + } else { + return `ARRAY[${values.map((v) => + typeof v === 'string' ? `'${v}'` : v + )}]`; + } } } diff --git a/packages/runtime/src/client/crud/dialects/sqlite.ts b/packages/runtime/src/client/crud/dialects/sqlite.ts index 5b836e2b..2bc6320b 100644 --- a/packages/runtime/src/client/crud/dialects/sqlite.ts +++ b/packages/runtime/src/client/crud/dialects/sqlite.ts @@ -26,19 +26,23 @@ export class SqliteCrudDialect< return 'sqlite' as const; } - override transformPrimitive(value: unknown, type: BuiltinType) { + override transformPrimitive(value: unknown, type: BuiltinType): unknown { if (value === undefined) { return value; } - return match(type) - .with('Boolean', () => (value ? 1 : 0)) - .with('DateTime', () => - value instanceof Date ? value.toISOString() : value - ) - .with('Decimal', () => (value as Decimal).toString()) - .with('Bytes', () => Buffer.from(value as Uint8Array)) - .otherwise(() => value); + if (Array.isArray(value)) { + return value.map((v) => this.transformPrimitive(v, type)); + } else { + return match(type) + .with('Boolean', () => (value ? 1 : 0)) + .with('DateTime', () => + value instanceof Date ? value.toISOString() : value + ) + .with('Decimal', () => (value as Decimal).toString()) + .with('Bytes', () => Buffer.from(value as Uint8Array)) + .otherwise(() => value); + } } override buildRelationSelection( diff --git a/packages/runtime/src/client/crud/operations/base.ts b/packages/runtime/src/client/crud/operations/base.ts index 841aab90..9d4c698e 100644 --- a/packages/runtime/src/client/crud/operations/base.ts +++ b/packages/runtime/src/client/crud/operations/base.ts @@ -540,10 +540,24 @@ export abstract class BaseOperationHandler { isScalarField(this.schema, model, field) || isForeignKeyField(this.schema, model, field) ) { - createFields[field] = this.dialect.transformPrimitive( - value, - fieldDef.type as BuiltinType - ); + if ( + fieldDef.array && + value && + typeof value === 'object' && + 'set' in value && + Array.isArray(value.set) + ) { + // deal with nested "set" for scalar lists + createFields[field] = this.dialect.transformPrimitive( + value.set, + fieldDef.type as BuiltinType + ); + } else { + createFields[field] = this.dialect.transformPrimitive( + value, + fieldDef.type as BuiltinType + ); + } } else { if ( fieldDef.relation?.fields && @@ -1069,18 +1083,36 @@ export abstract class BaseOperationHandler { typeof finalData[field] === 'object' && finalData[field] ) { + // numeric fields incremental updates updateFields[field] = this.transformIncrementalUpdate( model, field, fieldDef, finalData[field] ); - } else { - updateFields[field] = this.dialect.transformPrimitive( - finalData[field], - fieldDef.type as BuiltinType + continue; + } + + if ( + fieldDef.array && + typeof finalData[field] === 'object' && + !Array.isArray(finalData[field]) && + finalData[field] + ) { + // scalar list updates + updateFields[field] = this.transformScalarListUpdate( + model, + field, + fieldDef, + finalData[field] ); + continue; } + + updateFields[field] = this.dialect.transformPrimitive( + finalData[field], + fieldDef.type as BuiltinType + ); } else { if (!allowRelationUpdate) { throw new QueryError( @@ -1170,7 +1202,7 @@ export abstract class BaseOperationHandler { const key = Object.keys(payload)[0]; const value = this.dialect.transformPrimitive( - Object.values(payload)[0], + payload[key!], fieldDef.type as BuiltinType ); const eb = expressionBuilder(); @@ -1182,7 +1214,7 @@ export abstract class BaseOperationHandler { eb ); - const op = match(key) + return match(key) .with('set', () => value) .with('increment', () => eb(fieldRef, '+', value)) .with('decrement', () => eb(fieldRef, '-', value)) @@ -1193,7 +1225,42 @@ export abstract class BaseOperationHandler { `Invalid incremental update operation: ${key}` ); }); - return op; + } + + private transformScalarListUpdate( + model: GetModels, + field: string, + fieldDef: FieldDef, + payload: Record + ) { + invariant( + Object.keys(payload).length === 1, + 'Only one of "set", "push" can be provided' + ); + const key = Object.keys(payload)[0]; + const value = this.dialect.transformPrimitive( + payload[key!], + fieldDef.type as BuiltinType + ); + const eb = expressionBuilder(); + const fieldRef = buildFieldRef( + this.schema, + model, + field, + this.options, + eb + ); + + return match(key) + .with('set', () => value) + .with('push', () => { + return eb(fieldRef, '||', eb.val(ensureArray(value))); + }) + .otherwise(() => { + throw new InternalError( + `Invalid array update operation: ${key}` + ); + }); } private isNumericField(fieldDef: FieldDef) { diff --git a/packages/runtime/src/client/crud/validator.ts b/packages/runtime/src/client/crud/validator.ts index d2df6194..4dfa710c 100644 --- a/packages/runtime/src/client/crud/validator.ts +++ b/packages/runtime/src/client/crud/validator.ts @@ -252,6 +252,11 @@ export class InputValidator { !!fieldDef.optional ); } + } else if (fieldDef.array) { + // array field + fieldSchema = this.makeArrayFilterSchema( + fieldDef.type as BuiltinType + ); } else { // primitive field fieldSchema = this.makePrimitiveFilterSchema( @@ -345,6 +350,16 @@ export class InputValidator { ]); } + private makeArrayFilterSchema(type: BuiltinType) { + return z.object({ + equals: this.makePrimitiveSchema(type).array().optional(), + has: this.makePrimitiveSchema(type).optional(), + hasEvery: this.makePrimitiveSchema(type).array().optional(), + hasSome: this.makePrimitiveSchema(type).array().optional(), + isEmpty: z.boolean().optional(), + }); + } + private makePrimitiveFilterSchema(type: BuiltinType, optional: boolean) { return match(type) .with('String', () => this.makeStringFilterSchema(optional)) @@ -744,7 +759,14 @@ export class InputValidator { ); if (fieldDef.array) { - fieldSchema = z.array(fieldSchema).optional(); + fieldSchema = z + .union([ + z.array(fieldSchema), + z.object({ + set: z.array(fieldSchema), + }), + ]) + .optional(); } if (fieldDef.optional || fieldHasDefaultValue(fieldDef)) { @@ -1068,12 +1090,32 @@ export class InputValidator { divide: z.number().optional(), }) .refine( - (v) => Object.keys(v).length <= 1, + (v) => Object.keys(v).length === 1, 'Only one of "set", "increment", "decrement", "multiply", or "divide" can be provided' ), ]); } + if (fieldDef.array) { + fieldSchema = z + .union([ + fieldSchema.array(), + z + .object({ + set: z.array(fieldSchema).optional(), + push: this.orArray( + fieldSchema, + true + ).optional(), + }) + .refine( + (v) => Object.keys(v).length === 1, + 'Only one of "set", "push" can be provided' + ), + ]) + .optional(); + } + if (fieldDef.optional) { fieldSchema = fieldSchema.nullable(); } diff --git a/packages/runtime/src/client/helpers/schema-db-pusher.ts b/packages/runtime/src/client/helpers/schema-db-pusher.ts index 5171d2ee..d0e752a6 100644 --- a/packages/runtime/src/client/helpers/schema-db-pusher.ts +++ b/packages/runtime/src/client/helpers/schema-db-pusher.ts @@ -161,7 +161,7 @@ export class SchemaDbPusher { } // nullable - if (!fieldDef.optional) { + if (!fieldDef.optional && !fieldDef.array) { col = col.notNull(); } diff --git a/packages/runtime/src/client/query-utils.ts b/packages/runtime/src/client/query-utils.ts index 15ad1676..62d696bc 100644 --- a/packages/runtime/src/client/query-utils.ts +++ b/packages/runtime/src/client/query-utils.ts @@ -192,11 +192,12 @@ export function buildFieldRef( model: string, field: string, options: ClientOptions, - eb: ExpressionBuilder + eb: ExpressionBuilder, + modelAlias?: string ): ExpressionWrapper { const fieldDef = requireField(schema, model, field); if (!fieldDef.computed) { - return eb.ref(field); + return eb.ref(modelAlias ? `${modelAlias}.${field}` : field); } else { let computer: Function | undefined; if ('computedFields' in options) { diff --git a/packages/runtime/src/client/result-processor.ts b/packages/runtime/src/client/result-processor.ts index 88a7ba9f..f47a64f6 100644 --- a/packages/runtime/src/client/result-processor.ts +++ b/packages/runtime/src/client/result-processor.ts @@ -29,11 +29,12 @@ export class ResultProcessor { return data; } for (const [key, value] of Object.entries(data)) { - if (value === undefined || value === null) { + if (value === undefined) { continue; } if (key === '_count') { + // underlying database provider may return string for count data[key] = typeof value === 'string' ? JSON.parse(value) : value; continue; @@ -43,6 +44,15 @@ export class ResultProcessor { if (!fieldDef) { continue; } + + if (value === null) { + // scalar list defaults to empty array + if (fieldDef.array && !fieldDef.relation && value === null) { + data[key] = []; + } + continue; + } + if (fieldDef.relation) { data[key] = this.processRelation(value, fieldDef); } else { diff --git a/packages/runtime/test/client-api/scalar-list.test.ts b/packages/runtime/test/client-api/scalar-list.test.ts new file mode 100644 index 00000000..b031ab83 --- /dev/null +++ b/packages/runtime/test/client-api/scalar-list.test.ts @@ -0,0 +1,235 @@ +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import { createTestClient } from '../utils'; + +const PG_DB_NAME = 'client-api-scalar-list-tests'; + +describe('Scalar list tests', () => { + const schema = ` + model User { + id String @id @default(cuid()) + name String + tags String[] + flags Boolean[] + } + `; + + let client: any; + + beforeEach(async () => { + client = await createTestClient(schema, { + provider: 'postgresql', + dbName: PG_DB_NAME, + }); + }); + + afterEach(async () => { + await client?.$disconnect(); + }); + + it('works with create', async () => { + await expect( + client.user.create({ + data: { + name: 'user', + }, + }) + ).resolves.toMatchObject({ + tags: [], + }); + + await expect( + client.user.create({ + data: { + name: 'user', + tags: [], + }, + }) + ).resolves.toMatchObject({ + tags: [], + }); + + await expect( + client.user.create({ + data: { + name: 'user', + tags: ['tag1', 'tag2'], + }, + }) + ).resolves.toMatchObject({ + tags: ['tag1', 'tag2'], + }); + + await expect( + client.user.create({ + data: { + name: 'user', + tags: { set: ['tag1', 'tag2'] }, + }, + }) + ).resolves.toMatchObject({ + tags: ['tag1', 'tag2'], + }); + + await expect( + client.user.create({ + data: { + name: 'user', + flags: [true, false], + }, + }) + ).resolves.toMatchObject({ flags: [true, false] }); + + await expect( + client.user.create({ + data: { + name: 'user', + flags: { set: [true, false] }, + }, + }) + ).resolves.toMatchObject({ flags: [true, false] }); + }); + + it('works with update', async () => { + const user = await client.user.create({ + data: { + name: 'user', + tags: ['tag1', 'tag2'], + }, + }); + + await expect( + client.user.update({ + where: { id: user.id }, + data: { tags: ['tag3', 'tag4'] }, + }) + ).resolves.toMatchObject({ tags: ['tag3', 'tag4'] }); + + await expect( + client.user.update({ + where: { id: user.id }, + data: { tags: { set: ['tag5'] } }, + }) + ).resolves.toMatchObject({ tags: ['tag5'] }); + + await expect( + client.user.update({ + where: { id: user.id }, + data: { tags: { push: 'tag6' } }, + }) + ).resolves.toMatchObject({ tags: ['tag5', 'tag6'] }); + + await expect( + client.user.update({ + where: { id: user.id }, + data: { tags: { push: [] } }, + }) + ).resolves.toMatchObject({ tags: ['tag5', 'tag6'] }); + + await expect( + client.user.update({ + where: { id: user.id }, + data: { tags: { push: ['tag7', 'tag8'] } }, + }) + ).resolves.toMatchObject({ tags: ['tag5', 'tag6', 'tag7', 'tag8'] }); + + await expect( + client.user.update({ + where: { id: user.id }, + data: { tags: { set: [] } }, + }) + ).resolves.toMatchObject({ tags: [] }); + }); + + it('works with filter', async () => { + const user1 = await client.user.create({ + data: { + name: 'user1', + tags: ['tag1', 'tag2'], + }, + }); + // @ts-ignore + const user2 = await client.user.create({ + data: { + name: 'user2', + }, + }); + const user3 = await client.user.create({ + data: { + name: 'user3', + tags: [], + }, + }); + + await expect( + client.user.findMany({ + where: { tags: { equals: ['tag1', 'tag2'] } }, + }) + ).resolves.toMatchObject([user1]); + + await expect( + client.user.findFirst({ + where: { tags: { equals: ['tag1'] } }, + }) + ).toResolveNull(); + + await expect( + client.user.findMany({ + where: { tags: { has: 'tag1' } }, + }) + ).resolves.toMatchObject([user1]); + + await expect( + client.user.findFirst({ + where: { tags: { has: 'tag3' } }, + }) + ).toResolveNull(); + + await expect( + client.user.findMany({ + where: { tags: { hasSome: ['tag1'] } }, + }) + ).resolves.toMatchObject([user1]); + + await expect( + client.user.findMany({ + where: { tags: { hasSome: ['tag1', 'tag3'] } }, + }) + ).resolves.toMatchObject([user1]); + + await expect( + client.user.findFirst({ + where: { tags: { hasSome: [] } }, + }) + ).toResolveNull(); + + await expect( + client.user.findFirst({ + where: { tags: { hasEvery: ['tag3', 'tag4'] } }, + }) + ).toResolveNull(); + + await expect( + client.user.findMany({ + where: { tags: { hasEvery: ['tag1', 'tag2'] } }, + }) + ).resolves.toMatchObject([user1]); + + await expect( + client.user.findFirst({ + where: { tags: { hasEvery: ['tag1', 'tag3'] } }, + }) + ).toResolveNull(); + + await expect( + client.user.findMany({ + where: { tags: { isEmpty: true } }, + }) + ).resolves.toEqual([user3]); + + await expect( + client.user.findMany({ + where: { tags: { isEmpty: false } }, + }) + ).resolves.toEqual([user1]); + }); +});