Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 14 additions & 26 deletions packages/plugins/trpc/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
PluginOptions,
RUNTIME_PACKAGE,
getPrismaClientImportSpec,
parseOptionAsStrings,
requireOption,
resolvePath,
saveProject,
Expand Down Expand Up @@ -32,11 +33,14 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
let outDir = requireOption<string>(options, 'output');
outDir = resolvePath(outDir, options);

// resolve "generateModels" option
const generateModels = parseOptionAsStrings(options, 'generateModels', name);

// resolve "generateModelActions" option
const generateModelActions = parseOptionAsStrings(options, 'generateModelActions');
const generateModelActions = parseOptionAsStrings(options, 'generateModelActions', name);

// resolve "generateClientHelpers" option
const generateClientHelpers = parseOptionAsStrings(options, 'generateClientHelpers');
const generateClientHelpers = parseOptionAsStrings(options, 'generateClientHelpers', name);
if (generateClientHelpers && !generateClientHelpers.every((v) => ['react', 'next'].includes(v))) {
throw new PluginError(name, `Option "generateClientHelpers" only support values "react" and "next"`);
}
Expand All @@ -50,10 +54,15 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.

const prismaClientDmmf = dmmf;

const modelOperations = prismaClientDmmf.mappings.modelOperations;
const models = prismaClientDmmf.datamodel.models;
let modelOperations = prismaClientDmmf.mappings.modelOperations;
if (generateModels) {
modelOperations = modelOperations.filter((mo) => generateModels.includes(mo.model));
}

// TODO: remove this legacy code that deals with "@Gen.hide" comment syntax inherited
// from original code
const hiddenModels: string[] = [];
resolveModelsComments(models, hiddenModels);
resolveModelsComments(prismaClientDmmf.datamodel.models, hiddenModels);

const zodSchemasImport = (options.zodSchemasImport as string) ?? '@zenstackhq/runtime/zod';
createAppRouter(
Expand Down Expand Up @@ -472,24 +481,3 @@ function createHelper(outDir: string) {
);
checkRead.formatText();
}

function parseOptionAsStrings(options: PluginOptions, optionaName: string) {
const value = options[optionaName];
if (value === undefined) {
return undefined;
} else if (typeof value === 'string') {
// comma separated string
return value
.split(',')
.filter((i) => !!i)
.map((i) => i.trim());
} else if (Array.isArray(value) && value.every((i) => typeof i === 'string')) {
// string array
return value as string[];
} else {
throw new PluginError(
name,
`Invalid "${optionaName}" option: must be a comma-separated string or an array of strings`
);
}
}
129 changes: 129 additions & 0 deletions packages/plugins/trpc/tests/trpc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,4 +285,133 @@ model post_item {
}
);
});

it('generate for selected models and actions', async () => {
const { projectDir } = await loadSchema(
`
datasource db {
provider = 'postgresql'
url = env('DATABASE_URL')
}

generator js {
provider = 'prisma-client-js'
}

plugin trpc {
provider = '${process.cwd()}/dist'
output = '$projectRoot/trpc'
generateModels = ['Post']
generateModelActions = ['findMany', 'update']
}

model User {
id String @id
email String @unique
posts Post[]
}

model Post {
id String @id
title String
author User? @relation(fields: [authorId], references: [id])
authorId String?
}

model Foo {
id String @id
value Int
}
`,
{
addPrelude: false,
pushDb: false,
extraDependencies: [`${origDir}/dist`, '@trpc/client', '@trpc/server'],
compile: true,
}
);

expect(fs.existsSync(path.join(projectDir, 'trpc/routers/User.router.ts'))).toBeFalsy();
expect(fs.existsSync(path.join(projectDir, 'trpc/routers/Foo.router.ts'))).toBeFalsy();
expect(fs.existsSync(path.join(projectDir, 'trpc/routers/Post.router.ts'))).toBeTruthy();

const postRouterContent = fs.readFileSync(path.join(projectDir, 'trpc/routers/Post.router.ts'), 'utf8');
expect(postRouterContent).toContain('findMany:');
expect(postRouterContent).toContain('update:');
expect(postRouterContent).not.toContain('findUnique:');
expect(postRouterContent).not.toContain('create:');

// trpc plugin passes "generateModels" option down to implicitly enabled zod plugin

expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/PostInput.schema.js'))
).toBeTruthy();
// zod for User is generated due to transitive dependency
expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/UserInput.schema.js'))
).toBeTruthy();
expect(fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/FooInput.schema.js'))).toBeFalsy();
});

it('generate for selected models with zod plugin declared', async () => {
const { projectDir } = await loadSchema(
`
datasource db {
provider = 'postgresql'
url = env('DATABASE_URL')
}

generator js {
provider = 'prisma-client-js'
}

plugin zod {
provider = '@core/zod'
}

plugin trpc {
provider = '${process.cwd()}/dist'
output = '$projectRoot/trpc'
generateModels = ['Post']
generateModelActions = ['findMany', 'update']
}

model User {
id String @id
email String @unique
posts Post[]
}

model Post {
id String @id
title String
author User? @relation(fields: [authorId], references: [id])
authorId String?
}

model Foo {
id String @id
value Int
}
`,
{
addPrelude: false,
pushDb: false,
extraDependencies: [`${origDir}/dist`, '@trpc/client', '@trpc/server'],
compile: true,
}
);

// trpc plugin's "generateModels" shouldn't interfere in this case

expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/PostInput.schema.js'))
).toBeTruthy();
expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/UserInput.schema.js'))
).toBeTruthy();
expect(
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/FooInput.schema.js'))
).toBeTruthy();
});
});
42 changes: 35 additions & 7 deletions packages/schema/src/cli/plugin-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ export class PluginRunner {
}

// "@core/access-policy" has implicit requirements
let zodImplicitlyAdded = false;
if ([...plugins, ...corePlugins].find((p) => p.provider === '@core/access-policy')) {
// make sure "@core/model-meta" is enabled
if (!corePlugins.find((p) => p.provider === '@core/model-meta')) {
Expand All @@ -193,25 +194,52 @@ export class PluginRunner {
// '@core/zod' plugin is auto-enabled by "@core/access-policy"
// if there're validation rules
if (!corePlugins.find((p) => p.provider === '@core/zod') && this.hasValidation(options.schema)) {
zodImplicitlyAdded = true;
corePlugins.push({ provider: '@core/zod', options: { modelOnly: true } });
}
}

// core plugins introduced by dependencies
plugins
.flatMap((p) => p.dependencies)
.forEach((dep) => {
plugins.forEach((plugin) => {
// TODO: generalize this
const isTrpcPlugin =
plugin.provider === '@zenstackhq/trpc' ||
// for testing
(process.env.ZENSTACK_TEST && plugin.provider.includes('trpc'));

for (const dep of plugin.dependencies) {
if (dep.startsWith('@core/')) {
const existing = corePlugins.find((p) => p.provider === dep);
if (existing) {
// reset options to default
existing.options = undefined;
// TODO: generalize this
if (existing.provider === '@core/zod') {
// Zod plugin can be automatically enabled in `modelOnly` mode, however
// other plugin (tRPC) for now requires it to run in full mode
existing.options = {};

if (
isTrpcPlugin &&
zodImplicitlyAdded // don't do it for user defined zod plugin
) {
// pass trpc plugin's `generateModels` option down to zod plugin
existing.options.generateModels = plugin.options.generateModels;
}
}
} else {
// add core dependency
corePlugins.push({ provider: dep });
const toAdd = { provider: dep, options: {} as Record<string, unknown> };

// TODO: generalize this
if (dep === '@core/zod' && isTrpcPlugin) {
// pass trpc plugin's `generateModels` option down to zod plugin
toAdd.options.generateModels = plugin.options.generateModels;
}

corePlugins.push(toAdd);
}
}
});
}
});

return corePlugins;
}
Expand Down
71 changes: 64 additions & 7 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
isEnumFieldReference,
isForeignKeyField,
isFromStdlib,
parseOptionAsStrings,
resolvePath,
saveProject,
} from '@zenstackhq/sdk';
Expand All @@ -21,6 +22,7 @@ import { streamAllContents } from 'langium';
import path from 'path';
import { Project } from 'ts-morph';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '.';
import { getDefaultOutputFolder } from '../plugin-utils';
import Transformer from './transformer';
import removeDir from './utils/removeDir';
Expand All @@ -44,12 +46,26 @@ export async function generate(
output = resolvePath(output, options);
await handleGeneratorOutputValue(output);

// calculate the models to be excluded
const excludeModels = getExcludedModels(model, options);

const prismaClientDmmf = dmmf;

const modelOperations = prismaClientDmmf.mappings.modelOperations;
const inputObjectTypes = prismaClientDmmf.schema.inputObjectTypes.prisma;
const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma;
const models: DMMF.Model[] = prismaClientDmmf.datamodel.models;
const modelOperations = prismaClientDmmf.mappings.modelOperations.filter(
(o) => !excludeModels.find((e) => e === o.model)
);

// TODO: better way of filtering than string startsWith?
const inputObjectTypes = prismaClientDmmf.schema.inputObjectTypes.prisma.filter(
(type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLocaleLowerCase()))
);
const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma.filter(
(type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLowerCase()))
);

const models: DMMF.Model[] = prismaClientDmmf.datamodel.models.filter(
(m) => !excludeModels.find((e) => e === m.name)
);

// whether Prisma's Unchecked* series of input types should be generated
const generateUnchecked = options.noUncheckedInput !== true;
Expand All @@ -73,7 +89,7 @@ export async function generate(
dataSource?.fields.find((f) => f.name === 'provider')?.value
) as ConnectorType;

await generateModelSchemas(project, model, output);
await generateModelSchemas(project, model, output, excludeModels);

if (options.modelOnly !== true) {
// detailed object schemas referenced from input schemas
Expand Down Expand Up @@ -120,6 +136,45 @@ export async function generate(
}
}

function getExcludedModels(model: Model, options: PluginOptions) {
// resolve "generateModels" option
const generateModels = parseOptionAsStrings(options, 'generateModels', name);
if (generateModels) {
if (options.modelOnly === true) {
// no model reference needs to be considered, directly exclude any model not included
return model.declarations
.filter((d) => isDataModel(d) && !generateModels.includes(d.name))
.map((m) => m.name);
} else {
// calculate a transitive closure of models to be included
const todo = getDataModels(model).filter((dm) => generateModels.includes(dm.name));
const included = new Set<DataModel>();
while (todo.length > 0) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const dm = todo.pop()!;
included.add(dm);

// add referenced models to the todo list
dm.fields
.map((f) => f.type.reference?.ref)
.filter((type): type is DataModel => isDataModel(type))
.forEach((type) => {
if (!included.has(type)) {
todo.push(type);
}
});
}

// finally find the models to be excluded
return getDataModels(model)
.filter((dm) => !included.has(dm))
.map((m) => m.name);
}
} else {
return [];
}
}

async function handleGeneratorOutputValue(output: string) {
// create the output directory and delete contents that might exist from a previous run
await fs.mkdir(output, { recursive: true });
Expand Down Expand Up @@ -184,10 +239,12 @@ async function generateObjectSchemas(
);
}

async function generateModelSchemas(project: Project, zmodel: Model, output: string) {
async function generateModelSchemas(project: Project, zmodel: Model, output: string, excludedModels: string[]) {
const schemaNames: string[] = [];
for (const dm of getDataModels(zmodel)) {
schemaNames.push(await generateModelSchema(dm, project, output));
if (!excludedModels.includes(dm.name)) {
schemaNames.push(await generateModelSchema(dm, project, output));
}
}

project.createSourceFile(
Expand Down
Loading