diff --git a/drizzle-kit/src/api.ts b/drizzle-kit/src/api.ts index 70960756a9..86ee7fe3b1 100644 --- a/drizzle-kit/src/api.ts +++ b/drizzle-kit/src/api.ts @@ -36,6 +36,9 @@ export type DrizzlePgDB = DB & { export type PreparePgDBOptions = { queryConcurrency?: number; }; +export type IntrospectPgDBOptions = { + tableConcurrency?: number; +}; export type DrizzlePgDBIntrospectSchema = Omit< PgSchemaKit, 'internal' @@ -84,6 +87,7 @@ export const introspectPgDB = async ( db: DrizzlePgDB, filters: string[], schemaFilters: string[], + options: IntrospectPgDBOptions = {}, ): Promise => { const matchers = filters.map((it) => { return new Minimatch(it); @@ -119,6 +123,7 @@ export const introspectPgDB = async ( undefined, undefined, undefined, + options, ); const schema = { id: originUUID, prevId: '', ...res } as PgSchemaKit; diff --git a/drizzle-kit/src/serializer/pgSerializer.ts b/drizzle-kit/src/serializer/pgSerializer.ts index 88a74ab409..7ce1bf0df3 100644 --- a/drizzle-kit/src/serializer/pgSerializer.ts +++ b/drizzle-kit/src/serializer/pgSerializer.ts @@ -44,6 +44,38 @@ import type { import { type DB, escapeSingleQuotes, isPgArrayType } from '../utils'; import { getColumnCasing, sqlToStr } from './utils'; +export type PgIntrospectOptions = { + tableConcurrency?: number; +}; + +async function mapWithConcurrency( + items: T[], + concurrency: number | undefined, + fn: (item: T) => Promise, +) { + if (concurrency === undefined) { + await Promise.all(items.map(fn)); + return; + } + + if (!Number.isInteger(concurrency) || concurrency < 1) { + throw new RangeError('tableConcurrency must be a positive integer'); + } + + let nextIndex = 0; + const workerCount = Math.min(concurrency, items.length); + + await Promise.all( + Array.from({ length: workerCount }, async () => { + while (nextIndex < items.length) { + const currentIndex = nextIndex; + nextIndex += 1; + await fn(items[currentIndex]); + } + }), + ); +} + export const indexName = (tableName: string, columns: string[]) => { return `${tableName}_${columns.join('_')}_index`; }; @@ -983,6 +1015,7 @@ export const fromDatabase = async ( status: IntrospectStatus, ) => void, tsSchema?: PgSchemaInternal, + options: PgIntrospectOptions = {}, ): Promise => { const result: Record = {}; const views: Record = {}; @@ -1209,13 +1242,20 @@ WHERE const sequencesInColumns: string[] = []; - const all = allTables - .filter((it) => it.type === 'table') - .map((row) => { + const tableRows = allTables.filter((it) => it.type === 'table'); + tableCount = tableRows.filter((row) => tablesFilter(row.table_name as string)).length; + + if (progressCallback) { + progressCallback('tables', tableCount, 'done'); + } + + await mapWithConcurrency( + tableRows, + options.tableConcurrency, + async (row) => { return new Promise(async (res, rej) => { const tableName = row.table_name as string; if (!tablesFilter(tableName)) return res(''); - tableCount += 1; const tableSchema = row.table_schema; try { @@ -1668,18 +1708,14 @@ WHERE } res(''); }); - }); - - if (progressCallback) { - progressCallback('tables', tableCount, 'done'); - } - - for await (const _ of all) { - } + }, + ); - const allViews = allTables - .filter((it) => it.type === 'view' || it.type === 'materialized_view') - .map((row) => { + const viewRows = allTables.filter((it) => it.type === 'view' || it.type === 'materialized_view'); + const allViews = mapWithConcurrency( + viewRows, + options.tableConcurrency, + async (row) => { return new Promise(async (res, rej) => { const viewName = row.table_name as string; if (!tablesFilter(viewName)) return res(''); @@ -1899,12 +1935,12 @@ WHERE } res(''); }); - }); + }, + ); - viewsCount = allViews.length; + viewsCount = viewRows.length; - for await (const _ of allViews) { - } + await allViews; if (progressCallback) { progressCallback('columns', columnsCount, 'done'); diff --git a/drizzle-kit/tests/pgSerializer.test.ts b/drizzle-kit/tests/pgSerializer.test.ts new file mode 100644 index 0000000000..bde1b1f024 --- /dev/null +++ b/drizzle-kit/tests/pgSerializer.test.ts @@ -0,0 +1,92 @@ +import { describe, expect, test, vi } from 'vitest'; +import { fromDatabase } from '../src/serializer/pgSerializer'; + +const TABLES_QUERY_MARKER = 'pg_catalog.pg_class c'; + +function createObservedDb({ tableCount }: { tableCount: number }) { + let activeQueries = 0; + let maxActiveQueries = 0; + + const query = vi.fn(async (sql: string) => { + activeQueries += 1; + maxActiveQueries = Math.max(maxActiveQueries, activeQueries); + + await new Promise((resolve) => setTimeout(resolve, 5)); + + activeQueries -= 1; + + if (sql.includes(TABLES_QUERY_MARKER)) { + return Array.from({ length: tableCount }, (_, index) => ({ + table_schema: 'public', + table_name: `table_${index}`, + type: 'table', + rls_enabled: false, + })); + } + + return []; + }); + + return { + db: { query }, + query, + getMaxActiveQueries: () => maxActiveQueries, + }; +} + +describe('fromDatabase', () => { + test('limits table introspection fanout with tableConcurrency', async () => { + const observed = createObservedDb({ tableCount: 8 }); + + await fromDatabase( + observed.db as any, + undefined, + [], + undefined, + undefined, + undefined, + { tableConcurrency: 2 }, + ); + + expect(observed.query).toHaveBeenCalled(); + expect(observed.getMaxActiveQueries()).toBeLessThanOrEqual(2); + }); + + test('rejects invalid tableConcurrency values', async () => { + const observed = createObservedDb({ tableCount: 1 }); + + await expect( + fromDatabase( + observed.db as any, + undefined, + [], + undefined, + undefined, + undefined, + { tableConcurrency: 0 }, + ), + ).rejects.toThrow('tableConcurrency must be a positive integer'); + await expect( + fromDatabase( + observed.db as any, + undefined, + [], + undefined, + undefined, + undefined, + { tableConcurrency: -1 }, + ), + ).rejects.toThrow('tableConcurrency must be a positive integer'); + await expect( + fromDatabase( + observed.db as any, + undefined, + [], + undefined, + undefined, + undefined, + { tableConcurrency: 1.5 }, + ), + ).rejects.toThrow('tableConcurrency must be a positive integer'); + }); +});