diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index 5fedf10a7..9b4e706bc 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -78,9 +78,25 @@ export type UniqueConstraint = { name: string; fields: string[] }; * ZModel data model metadata */ export type ModelMeta = { + /** + * Model fields + */ fields: Record>; + + /** + * Model unique constraints + */ uniqueConstraints: Record>; + + /** + * Information for cascading delete + */ deleteCascade: Record; + + /** + * Name of model that backs the `auth()` function + */ + authModel?: string; }; /** diff --git a/packages/runtime/src/enhancements/policy/index.ts b/packages/runtime/src/enhancements/policy/index.ts index efe6f8f75..8b05241dd 100644 --- a/packages/runtime/src/enhancements/policy/index.ts +++ b/packages/runtime/src/enhancements/policy/index.ts @@ -73,8 +73,8 @@ export function withPolicy( const _zodSchemas = options?.zodSchemas ?? getDefaultZodSchemas(options?.loadPath); // validate user context - if (context?.user) { - const idFields = getIdFields(_modelMeta, 'User'); + if (context?.user && _modelMeta.authModel) { + const idFields = getIdFields(_modelMeta, _modelMeta.authModel); if ( !hasAllFields( context.user, diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index 1f57da695..eed22e2eb 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -1,5 +1,5 @@ import { isDataSource, isPlugin, Model } from '@zenstackhq/language/ast'; -import { getLiteral } from '@zenstackhq/sdk'; +import { getDataModels, getLiteral, hasAttribute } from '@zenstackhq/sdk'; import colors from 'colors'; import fs from 'fs'; import getLatestVersion from 'get-latest-version'; @@ -95,10 +95,18 @@ export async function loadDocument(fileName: string): Promise { function validationAfterMerge(model: Model) { const dataSources = model.declarations.filter((d) => isDataSource(d)); if (dataSources.length == 0) { - console.error(colors.red('Validation errors: Model must define a datasource')); + console.error(colors.red('Validation error: Model must define a datasource')); throw new CliError('schema validation errors'); } else if (dataSources.length > 1) { - console.error(colors.red('Validation errors: Multiple datasource declarations are not allowed')); + console.error(colors.red('Validation error: Multiple datasource declarations are not allowed')); + throw new CliError('schema validation errors'); + } + + // at most one `@@auth` model + const dataModels = getDataModels(model); + const authModels = dataModels.filter((d) => hasAttribute(d, '@@auth')); + if (authModels.length > 1) { + console.error(colors.red('Validation error: Multiple `@@auth` models are not allowed')); throw new CliError('schema validation errors'); } } diff --git a/packages/schema/src/language-server/validator/schema-validator.ts b/packages/schema/src/language-server/validator/schema-validator.ts index 5b05d4128..b80bf890d 100644 --- a/packages/schema/src/language-server/validator/schema-validator.ts +++ b/packages/schema/src/language-server/validator/schema-validator.ts @@ -1,9 +1,10 @@ +import { Model, isDataModel, isDataSource } from '@zenstackhq/language/ast'; +import { hasAttribute } from '@zenstackhq/sdk'; +import { LangiumDocuments, ValidationAcceptor } from 'langium'; +import { getAllDeclarationsFromImports, resolveImport, resolveTransitiveImports } from '../../utils/ast-utils'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from '../constants'; -import { isDataSource, Model } from '@zenstackhq/language/ast'; import { AstValidator } from '../types'; -import { LangiumDocuments, ValidationAcceptor } from 'langium'; import { validateDuplicatedDeclarations } from './utils'; -import { getAllDeclarationsFromImports, resolveImport, resolveTransitiveImports } from '../../utils/ast-utils'; /** * Validates toplevel schema. @@ -33,6 +34,12 @@ export default class SchemaValidator implements AstValidator { ) { this.validateDataSources(model, accept); } + + // at most one `@@auth` model + const authModels = model.declarations.filter((d) => isDataModel(d) && hasAttribute(d, '@@auth')); + if (authModels.length > 1) { + accept('error', 'Multiple `@@auth` models are not allowed', { node: authModels[1] }); + } } private validateDataSources(model: Model, accept: ValidationAcceptor) { diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index 1a78f3b95..de6afc9e5 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -278,17 +278,16 @@ export class ZModelLinker extends DefaultLinker { const model = getContainingModel(node); if (model) { - let userModel; - userModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => { + let authModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => { return isDataModel(d) && hasAttribute(d, '@@auth'); }); - if (!userModel) { - userModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => { + if (!authModel) { + authModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => { return isDataModel(d) && d.name === 'User'; }); } - if (userModel) { - node.$resolvedType = { decl: userModel, nullable: true }; + if (authModel) { + node.$resolvedType = { decl: authModel, nullable: true }; } } } else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) { diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index d76954b43..95edb942a 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -35,6 +35,7 @@ import { analyzePolicies, createProject, emitProject, + getAuthModel, getDataModels, getLiteral, getPrismaClientImportSpec, @@ -744,13 +745,11 @@ export default class PolicyGenerator { ); if (hasAuthRef) { - const userModel = model.$container.declarations.find( - (decl): decl is DataModel => isDataModel(decl) && decl.name === 'User' - ); - if (!userModel) { - throw new PluginError(name, 'User model not found'); + const authModel = getAuthModel(getDataModels(model.$container)); + if (!authModel) { + throw new PluginError(name, 'Auth model not found'); } - const userIdFields = getIdFields(userModel); + const userIdFields = getIdFields(authModel); if (!userIdFields || userIdFields.length === 0) { throw new PluginError(name, 'User model does not have an id field'); } diff --git a/packages/schema/tests/schema/validation/schema-validation.test.ts b/packages/schema/tests/schema/validation/schema-validation.test.ts index c7b000338..5f1cc6254 100644 --- a/packages/schema/tests/schema/validation/schema-validation.test.ts +++ b/packages/schema/tests/schema/validation/schema-validation.test.ts @@ -38,4 +38,25 @@ describe('Toplevel Schema Validation Tests', () => { `) ).toContain('Cannot find model file models/abc.zmodel'); }); + + it('multiple auth models', async () => { + expect( + await loadModelWithError(` + datasource db1 { + provider = 'postgresql' + url = env('DATABASE_URL') + } + + model X { + id String @id + @@auth + } + + model Y { + id String @id + @@auth + } + `) + ).toContain('Multiple `@@auth` models are not allowed'); + }); }); diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 76bcd301a..a01aaddd4 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -15,17 +15,18 @@ import { lowerCaseFirst } from 'lower-case-first'; import { CodeBlockWriter, Project, VariableDeclarationKind } from 'ts-morph'; import { emitProject, + getAttribute, getAttributeArg, getAttributeArgs, + getAuthModel, getDataModels, getLiteral, hasAttribute, + isEnumFieldReference, isForeignKeyField, isIdField, resolved, saveProject, - getAttribute, - isEnumFieldReference, } from '.'; export async function generate( @@ -113,6 +114,12 @@ function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter) } } }); + writer.write(','); + + const authModel = getAuthModel(dataModels); + if (authModel) { + writer.writeLine(`authModel: '${authModel.name}'`); + } }); } diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 251be568b..6d726d091 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -352,3 +352,11 @@ export function getPreviewFeatures(model: Model) { return [] as string[]; } + +export function getAuthModel(dataModels: DataModel[]) { + let authModel = dataModels.find((m) => hasAttribute(m, '@@auth')); + if (!authModel) { + authModel = dataModels.find((m) => m.name === 'User'); + } + return authModel; +} diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index 0eed19f9d..cffbafaff 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -185,4 +185,31 @@ describe('With Policy: auth() test', () => { const authDb1 = withPolicy({ id: 'user2', role: 'ADMIN' }); await expect(authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); }); + + it('non User auth model', async () => { + const { withPolicy } = await loadSchema( + ` + model Foo { + id String @id @default(uuid()) + role String + + @@auth() + } + + model Post { + id String @id @default(uuid()) + title String + + @@allow('read', true) + @@allow('create', auth().role == 'ADMIN') + } + ` + ); + + const userDb = withPolicy({ id: 'user1', role: 'USER' }); + await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy(); + + const adminDb = withPolicy({ id: 'user1', role: 'ADMIN' }); + await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + }); });