diff --git a/cli/cli.ts b/cli/cli.ts index 5dd8ba4..3e8cb70 100644 --- a/cli/cli.ts +++ b/cli/cli.ts @@ -13,6 +13,8 @@ import InputHistoryPrompt from './input-history' import { editFile } from './utils' import { println } from './output' import { CLIResult } from './result' +import { createDatabaseConnection } from '@/shared/connectors' +import { DatabaseConnection } from '@/shared/connectors/utils' type YesNo = 'yes' | 'no' inquirer.registerPrompt('input-history', InputHistoryPrompt) @@ -20,7 +22,7 @@ inquirer.registerPrompt('input-history', InputHistoryPrompt) export const startCLI = async (options: CommonOptions) => { const openai = initOpenAI(options.key, options.org) const history: GptSqlResponse[] = [] - + const dbConnection = await createDatabaseConnection(options.database) if (process.stdin.isTTY) { // * Interactive mode // * The history of queries is not calculated from the full his @@ -39,6 +41,7 @@ export const startCLI = async (options: CommonOptions) => { ]) promptHistory.push(query) await executeQueryAndShowResult({ + dbConnection, openai, query, history, @@ -53,6 +56,7 @@ export const startCLI = async (options: CommonOptions) => { }) for await (const query of rl) { await executeQueryAndShowResult({ + dbConnection, openai, query, history, @@ -69,14 +73,14 @@ const executeQueryAndShowResult = async ( openai: OpenAIApi history?: GptSqlResponse[] stdin?: boolean + dbConnection: DatabaseConnection } ) => { const { + dbConnection, history = [], database, historyMode, - openai, - model, format, outputSql, outputResult, @@ -91,10 +95,11 @@ const executeQueryAndShowResult = async ( try { if (!sqlQuery) { spinner.text = 'Getting SQL introspection...' - const introspection = await database.getIntrospection() + const introspection = await dbConnection.getIntrospection() spinner.text = 'Calling OpenAI... ' const result = await getSqlQuery({ ...options, + dbConnection, introspection }) sqlQuery = result.sqlQuery @@ -142,7 +147,7 @@ const executeQueryAndShowResult = async ( } } spinner.text = 'Running query...' - const result = new CLIResult(await database.runSqlQuery(sqlQuery)) + const result = new CLIResult(await dbConnection.runSqlQuery(sqlQuery)) spinner.stop() if (!confirm) { // * Print the SQL query, but only if it's not already printed @@ -217,7 +222,7 @@ const executeQueryAndShowResult = async ( try { spinner.start() const result = new CLIResult( - await database.runSqlQuery(sqlQuery) + await dbConnection.runSqlQuery(sqlQuery) ) spinner.stop() println(chalk.dim(sqlQuery), outputSql) diff --git a/cli/index.ts b/cli/index.ts index 5b666d2..9927854 100644 --- a/cli/index.ts +++ b/cli/index.ts @@ -7,7 +7,7 @@ import { startCLI } from './cli' import { envProgram } from './env' import { parseInteger } from './utils' import { startWeb } from './web' -import { createDatabaseConnection } from '@/shared/connectors' +import { parseConnectionString } from '@/shared/connectors' const program = envProgram .name('chat-dbt') @@ -17,7 +17,10 @@ const program = envProgram 'database connection string, for instance "postgres://user:password@localhost:5432/postgres". Supported databases: postgres, clickhouse.' ) .env('DB_CONNECTION_STRING') - .argParser(value => createDatabaseConnection(value)) + .argParser(value => { + parseConnectionString(value) + return value + }) .makeOptionMandatory(true) ) .addOption( diff --git a/cli/web.ts b/cli/web.ts index 8f04dad..d25e7dc 100644 --- a/cli/web.ts +++ b/cli/web.ts @@ -15,7 +15,7 @@ export const startWeb = async ({ port, browser, ...rest }: WebOptions) => { serverRuntimeConfig: { org, key, - connectionString: database.connectionString + database }, publicRuntimeConfig: options }, diff --git a/next.config.mjs b/next.config.mjs index 08cdf86..24a5b4f 100644 --- a/next.config.mjs +++ b/next.config.mjs @@ -5,7 +5,7 @@ export default { distDir: 'dist/web', // * Load options from environment variables serverRuntimeConfig: { - connectionString: process.env.DB_CONNECTION_STRING, + database: process.env.DB_CONNECTION_STRING, key: process.env.OPENAI_API_KEY, org: process.env.OPENAI_ORGANIZATION }, diff --git a/pages/api/gpt-sql-query.ts b/pages/api/gpt-sql-query.ts index a32be5e..091af44 100644 --- a/pages/api/gpt-sql-query.ts +++ b/pages/api/gpt-sql-query.ts @@ -7,7 +7,7 @@ import { HistoryMode } from '@/shared/options' import { Result } from '@/shared/result' import { createDatabaseConnection } from '@/shared/connectors' -const { key, org, connectionString } = getSecrets() +const { key, org, database } = getSecrets() const { model } = getOptions() const openai = initOpenAI(key, org) @@ -25,19 +25,19 @@ export default async function handler( return res.status(400).json({ error: 'no request', query: '' }) } - const database = createDatabaseConnection(connectionString) + const dbConnection = createDatabaseConnection(database) try { const { sqlQuery, usage } = await getSqlQuery({ openai, model, query, - database, + dbConnection, history, historyMode }) try { - const result = new Result(await database.runSqlQuery(sqlQuery)) + const result = new Result(await dbConnection.runSqlQuery(sqlQuery)) return res.status(200).json({ query, sqlQuery, diff --git a/shared/chat-gpt.ts b/shared/chat-gpt.ts index 3301386..f4b5d5a 100644 --- a/shared/chat-gpt.ts +++ b/shared/chat-gpt.ts @@ -21,7 +21,7 @@ type MessageOptions = { query: string history?: GptSqlResponse[] historyMode: HistoryMode - database: DatabaseConnection + dbConnection: DatabaseConnection introspection?: Instrospection } @@ -30,11 +30,11 @@ const createMessages = async ({ history, historyMode, introspection, - database + dbConnection }: MessageOptions): Promise => { // * Get the SQL introspection const schema = JSON.stringify( - introspection ? introspection : await database.getIntrospection(), + introspection ? introspection : await dbConnection.getIntrospection(), null, 0 ) @@ -42,7 +42,7 @@ const createMessages = async ({ const messages: ChatCompletionRequestMessage[] = [ { role: 'system', - content: `You are a database developer that only responds in ${database.dialectName} without formatting` + content: `You are a database developer that only responds in ${dbConnection.dialectName} without formatting` }, { role: 'system', diff --git a/shared/connectors/index.ts b/shared/connectors/index.ts index cabb292..aa4d062 100644 --- a/shared/connectors/index.ts +++ b/shared/connectors/index.ts @@ -13,3 +13,5 @@ export const createDatabaseConnection = ( return new ClickHouseDatabaseConnection(connectionString) } } + +export { parseConnectionString } diff --git a/utils/options.ts b/utils/options.ts index 11243fa..1424250 100644 --- a/utils/options.ts +++ b/utils/options.ts @@ -1,9 +1,8 @@ import getConfig from 'next/config' import { CommonOptions } from '@/cli' -export const getSecrets = (): Pick & { - connectionString: string -} => getConfig().serverRuntimeConfig +export const getSecrets = (): Pick => + getConfig().serverRuntimeConfig export const getOptions = (): Omit< CommonOptions,