Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions packages/runtime/src/cross/model-meta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,25 @@ export type UniqueConstraint = { name: string; fields: string[] };
* ZModel data model metadata
*/
export type ModelMeta = {
/**
* Model fields
*/
fields: Record<string, Record<string, FieldInfo>>;

/**
* Model unique constraints
*/
uniqueConstraints: Record<string, Record<string, UniqueConstraint>>;

/**
* Information for cascading delete
*/
deleteCascade: Record<string, string[]>;

/**
* Name of model that backs the `auth()` function
*/
authModel?: string;
};

/**
Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/enhancements/policy/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ export function withPolicy<DbClient extends object>(
const _zodSchemas = options?.zodSchemas ?? getDefaultZodSchemas(options?.loadPath);

// validate user context
if (context?.user) {
const idFields = getIdFields(_modelMeta, 'User');
if (context?.user && _modelMeta.authModel) {
const idFields = getIdFields(_modelMeta, _modelMeta.authModel);
if (
!hasAllFields(
context.user,
Expand Down
14 changes: 11 additions & 3 deletions packages/schema/src/cli/cli-util.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { isDataSource, isPlugin, Model } from '@zenstackhq/language/ast';
import { getLiteral } from '@zenstackhq/sdk';
import { getDataModels, getLiteral, hasAttribute } from '@zenstackhq/sdk';
import colors from 'colors';
import fs from 'fs';
import getLatestVersion from 'get-latest-version';
Expand Down Expand Up @@ -95,10 +95,18 @@ export async function loadDocument(fileName: string): Promise<Model> {
function validationAfterMerge(model: Model) {
const dataSources = model.declarations.filter((d) => isDataSource(d));
if (dataSources.length == 0) {
console.error(colors.red('Validation errors: Model must define a datasource'));
console.error(colors.red('Validation error: Model must define a datasource'));
throw new CliError('schema validation errors');
} else if (dataSources.length > 1) {
console.error(colors.red('Validation errors: Multiple datasource declarations are not allowed'));
console.error(colors.red('Validation error: Multiple datasource declarations are not allowed'));
throw new CliError('schema validation errors');
}

// at most one `@@auth` model
const dataModels = getDataModels(model);
const authModels = dataModels.filter((d) => hasAttribute(d, '@@auth'));
if (authModels.length > 1) {
console.error(colors.red('Validation error: Multiple `@@auth` models are not allowed'));
throw new CliError('schema validation errors');
}
}
Expand Down
13 changes: 10 additions & 3 deletions packages/schema/src/language-server/validator/schema-validator.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Model, isDataModel, isDataSource } from '@zenstackhq/language/ast';
import { hasAttribute } from '@zenstackhq/sdk';
import { LangiumDocuments, ValidationAcceptor } from 'langium';
import { getAllDeclarationsFromImports, resolveImport, resolveTransitiveImports } from '../../utils/ast-utils';
import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from '../constants';
import { isDataSource, Model } from '@zenstackhq/language/ast';
import { AstValidator } from '../types';
import { LangiumDocuments, ValidationAcceptor } from 'langium';
import { validateDuplicatedDeclarations } from './utils';
import { getAllDeclarationsFromImports, resolveImport, resolveTransitiveImports } from '../../utils/ast-utils';

/**
* Validates toplevel schema.
Expand Down Expand Up @@ -33,6 +34,12 @@ export default class SchemaValidator implements AstValidator<Model> {
) {
this.validateDataSources(model, accept);
}

// at most one `@@auth` model
const authModels = model.declarations.filter((d) => isDataModel(d) && hasAttribute(d, '@@auth'));
if (authModels.length > 1) {
accept('error', 'Multiple `@@auth` models are not allowed', { node: authModels[1] });
}
}

private validateDataSources(model: Model, accept: ValidationAcceptor) {
Expand Down
11 changes: 5 additions & 6 deletions packages/schema/src/language-server/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,17 +278,16 @@ export class ZModelLinker extends DefaultLinker {
const model = getContainingModel(node);

if (model) {
let userModel;
userModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => {
let authModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => {
return isDataModel(d) && hasAttribute(d, '@@auth');
});
if (!userModel) {
userModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => {
if (!authModel) {
authModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => {
return isDataModel(d) && d.name === 'User';
});
}
if (userModel) {
node.$resolvedType = { decl: userModel, nullable: true };
if (authModel) {
node.$resolvedType = { decl: authModel, nullable: true };
}
}
} else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {
analyzePolicies,
createProject,
emitProject,
getAuthModel,
getDataModels,
getLiteral,
getPrismaClientImportSpec,
Expand Down Expand Up @@ -744,13 +745,11 @@ export default class PolicyGenerator {
);

if (hasAuthRef) {
const userModel = model.$container.declarations.find(
(decl): decl is DataModel => isDataModel(decl) && decl.name === 'User'
);
if (!userModel) {
throw new PluginError(name, 'User model not found');
const authModel = getAuthModel(getDataModels(model.$container));
if (!authModel) {
throw new PluginError(name, 'Auth model not found');
}
const userIdFields = getIdFields(userModel);
const userIdFields = getIdFields(authModel);
if (!userIdFields || userIdFields.length === 0) {
throw new PluginError(name, 'User model does not have an id field');
}
Expand Down
21 changes: 21 additions & 0 deletions packages/schema/tests/schema/validation/schema-validation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,25 @@ describe('Toplevel Schema Validation Tests', () => {
`)
).toContain('Cannot find model file models/abc.zmodel');
});

it('multiple auth models', async () => {
expect(
await loadModelWithError(`
datasource db1 {
provider = 'postgresql'
url = env('DATABASE_URL')
}

model X {
id String @id
@@auth
}

model Y {
id String @id
@@auth
}
`)
).toContain('Multiple `@@auth` models are not allowed');
});
});
11 changes: 9 additions & 2 deletions packages/sdk/src/model-meta-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@ import { lowerCaseFirst } from 'lower-case-first';
import { CodeBlockWriter, Project, VariableDeclarationKind } from 'ts-morph';
import {
emitProject,
getAttribute,
getAttributeArg,
getAttributeArgs,
getAuthModel,
getDataModels,
getLiteral,
hasAttribute,
isEnumFieldReference,
isForeignKeyField,
isIdField,
resolved,
saveProject,
getAttribute,
isEnumFieldReference,
} from '.';

export async function generate(
Expand Down Expand Up @@ -113,6 +114,12 @@ function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter)
}
}
});
writer.write(',');

const authModel = getAuthModel(dataModels);
if (authModel) {
writer.writeLine(`authModel: '${authModel.name}'`);
}
});
}

Expand Down
8 changes: 8 additions & 0 deletions packages/sdk/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,11 @@ export function getPreviewFeatures(model: Model) {

return [] as string[];
}

export function getAuthModel(dataModels: DataModel[]) {
let authModel = dataModels.find((m) => hasAttribute(m, '@@auth'));
if (!authModel) {
authModel = dataModels.find((m) => m.name === 'User');
}
return authModel;
}
27 changes: 27 additions & 0 deletions tests/integration/tests/enhancements/with-policy/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,31 @@ describe('With Policy: auth() test', () => {
const authDb1 = withPolicy({ id: 'user2', role: 'ADMIN' });
await expect(authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy();
});

it('non User auth model', async () => {
const { withPolicy } = await loadSchema(
`
model Foo {
id String @id @default(uuid())
role String

@@auth()
}

model Post {
id String @id @default(uuid())
title String

@@allow('read', true)
@@allow('create', auth().role == 'ADMIN')
}
`
);

const userDb = withPolicy({ id: 'user1', role: 'USER' });
await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy();

const adminDb = withPolicy({ id: 'user1', role: 'ADMIN' });
await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy();
});
});