diff --git a/src/dialects/abstract/index.ts b/src/dialects/abstract/index.ts index 77a3bc206636..f8e3146fc893 100644 --- a/src/dialects/abstract/index.ts +++ b/src/dialects/abstract/index.ts @@ -108,6 +108,12 @@ export type DialectSupports = { * e.g. 'DEFERRABLE' and 'INITIALLY DEFERRED' */ deferrableConstraints: false, + + /** + * This dialect supports E-prefixed strings, e.g. "E'foo'", which + * enables the ability to use backslash escapes inside of the string. + */ + escapeStringConstants: boolean, }; export abstract class AbstractDialect { @@ -197,6 +203,7 @@ export abstract class AbstractDialect { tmpTableTrigger: false, indexHints: false, searchPath: false, + escapeStringConstants: false, }; declare readonly defaultVersion: string; @@ -216,6 +223,15 @@ export abstract class AbstractDialect { abstract createBindCollector(): BindCollector; + /** + * Whether this dialect can use \ in strings to escape string delimiters. + * + * @returns + */ + canBackslashEscape(): boolean { + return false; + } + static getDefaultPort(): number { throw new Error(`getDefaultPort not implemented in ${this.name}`); } diff --git a/src/dialects/mariadb/index.js b/src/dialects/mariadb/index.js index 1734c46937e1..cd241c47a723 100644 --- a/src/dialects/mariadb/index.js +++ b/src/dialects/mariadb/index.js @@ -62,6 +62,10 @@ export class MariaDbDialect extends AbstractDialect { return createUnspecifiedOrderedBindCollector(); } + canBackslashEscape() { + return true; + } + static getDefaultPort() { return 3306; } diff --git a/src/dialects/mysql/index.js b/src/dialects/mysql/index.js index daf677ec8f43..c6dcae16264f 100644 --- a/src/dialects/mysql/index.js +++ b/src/dialects/mysql/index.js @@ -61,6 +61,10 @@ export class MysqlDialect extends AbstractDialect { return createUnspecifiedOrderedBindCollector(); } + canBackslashEscape() { + return true; + } + static getDefaultPort() { return 3306; } diff --git a/src/dialects/postgres/index.js b/src/dialects/postgres/index.js index 20bfe760def9..72e1486801ea 100644 --- a/src/dialects/postgres/index.js +++ b/src/dialects/postgres/index.js @@ -52,6 +52,7 @@ export class PostgresDialect extends AbstractDialect { TSVECTOR: true, deferrableConstraints: true, searchPath: true, + escapeStringConstants: true, }); constructor(sequelize) { @@ -72,6 +73,14 @@ export class PostgresDialect extends AbstractDialect { return createSpecifiedOrderedBindCollector(); } + canBackslashEscape() { + // postgres can use \ to escape if one of these is true: + // - standard_conforming_strings is off + // - the string is prefixed with E (out of scope for this method) + + return !this.sequelize.options.standardConformingStrings; + } + static getDefaultPort() { return 5432; } diff --git a/src/utils/sql.ts b/src/utils/sql.ts index c7fc14ed40b5..1c81abd1203d 100644 --- a/src/utils/sql.ts +++ b/src/utils/sql.ts @@ -40,6 +40,7 @@ function mapBindParametersAndReplacements( let previousSliceEnd = 0; let isSingleLineComment = false; let isCommentBlock = false; + let stringIsBackslashEscapable = false; for (let i = 0; i < sqlString.length; i++) { const char = sqlString[i]; @@ -53,8 +54,9 @@ function mapBindParametersAndReplacements( } if (isString) { - if (char === `'` && !isBackslashEscaped(sqlString, i - 1)) { + if (char === `'` && (!stringIsBackslashEscapable || !isBackslashEscaped(sqlString, i - 1))) { isString = false; + stringIsBackslashEscapable = false; } continue; @@ -99,6 +101,25 @@ function mapBindParametersAndReplacements( if (char === `'`) { isString = true; + + // The following query is supported in almost all dialects, + // SELECT E'test'; + // but postgres interprets it as an E-prefixed string, while other dialects interpret it as + // SELECT E 'test'; + // which selects the type E and aliases it to 'test'. + + stringIsBackslashEscapable + // all ''-style strings in this dialect can be backslash escaped + = dialect.canBackslashEscape() + // checking if this is a postgres-style E-prefixed string, which also supports backslash escaping + || ( + dialect.supports.escapeStringConstants + // is this a E-prefixed string, such as `E'abc'` ? + && sqlString[i - 1] === 'E' + // reject things such as `AE'abc'` (the prefix must be exactly E) + && canPrecedeNewToken(sqlString[i - 2]) + ); + continue; } @@ -133,7 +154,7 @@ function mapBindParametersAndReplacements( if (onBind) { // we want to be conservative with what we consider to be a bind parameter to avoid risk of conflict with potential operators // users need to add a space before the bind parameter (except after '(', ',', and '=') - if (previousChar !== undefined && !/[\s(,=]/.test(previousChar)) { + if (!canPrecedeNewToken(previousChar)) { continue; } @@ -162,7 +183,7 @@ function mapBindParametersAndReplacements( const previousChar = sqlString[i - 1]; // we want to be conservative with what we consider to be a replacement to avoid risk of conflict with potential operators // users need to add a space before the bind parameter (except after '(', ',', '=', and '[' (for arrays)) - if (previousChar !== undefined && !/[\s(,=[]/.test(previousChar)) { + if (!canPrecedeNewToken(previousChar) && previousChar !== '[') { continue; } @@ -197,7 +218,9 @@ function mapBindParametersAndReplacements( // we want to be conservative with what we consider to be a replacement to avoid risk of conflict with potential operators // users need to add a space before the bind parameter (except after '(', ',', '=', and '[' (for arrays)) - if (previousChar !== undefined && !/[\s(,=[]/.test(previousChar)) { + // -> [ is temporarily added to allow 'ARRAY[:name]' to be replaced + // https://github.com/sequelize/sequelize/issues/14410 will make this obsolete. + if (!canPrecedeNewToken(previousChar) && previousChar !== '[') { continue; } @@ -230,11 +253,19 @@ function mapBindParametersAndReplacements( } } + if (isString) { + throw new Error(`The following SQL query includes an unterminated string literal:\n${sqlString}`); + } + output += sqlString.slice(previousSliceEnd, sqlString.length); return output; } +function canPrecedeNewToken(char: string | undefined): boolean { + return char === undefined || /[\s(>,=]/.test(char); +} + /** * Maps bind parameters from Sequelize's format ($1 or $name) to the dialect's format. * diff --git a/test/support.ts b/test/support.ts index 960c6abc1c34..a8624892f049 100644 --- a/test/support.ts +++ b/test/support.ts @@ -285,9 +285,109 @@ export function getPoolMax(): number { return Config[getTestDialect()].pool?.max ?? 1; } -type ExpectationKey = Dialect | 'default'; +type ExpectationKey = 'default' | Permutations; + +export type ExpectationRecord = PartialRecord | Error>; + +type Permutations = + T extends any ? (T | `${T} ${Permutations>}`) : never; + type PartialRecord = Partial>; +export function expectPerDialect( + method: () => Out, + assertions: ExpectationRecord, +) { + const expectations: PartialRecord<'default' | Dialect, Out | Error | Expectation> = Object.create(null); + + for (const [key, value] of Object.entries(assertions)) { + const acceptedDialects = key.split(' ') as Array; + + for (const dialect of acceptedDialects) { + if (dialect === 'default' && acceptedDialects.length > 1) { + throw new Error(`The 'default' expectation cannot be combined with other dialects.`); + } + + if (expectations[dialect] !== undefined) { + throw new Error(`The expectation for ${dialect} was already defined.`); + } + + expectations[dialect] = value; + } + } + + let result: Out | Error; + + try { + result = method(); + } catch (error: unknown) { + assert(error instanceof Error, 'method threw a non-error'); + + result = error; + } + + const expectation = expectations[sequelize.dialect.name] ?? expectations.default; + if (expectation === undefined) { + throw new Error(`No expectation was defined for ${sequelize.dialect.name} and the 'default' expectation has not been defined.`); + } + + if (expectation instanceof Error) { + assert(result instanceof Error, `Expected method to error with "${expectation.message}", but it returned ${inspect(result)}.`); + + expect(result.message).to.equal(expectation.message); + } else { + assert(!(result instanceof Error), `Did not expect query to error, but it errored with ${result instanceof Error ? result.message : ''}`); + + assertMatchesExpectation(result, expectation); + } +} + +function assertMatchesExpectation(result: V, expectation: V | Expectation): void { + if (expectation instanceof Expectation) { + expectation.assert(result); + } else { + expect(result).to.deep.equal(expectation); + } +} + +abstract class Expectation { + abstract assert(value: Value): void; +} + +class SqlExpectation extends Expectation { + constructor(private readonly sql: string) { + super(); + } + + assert(value: string) { + expect(minifySql(value)).to.equal(minifySql(this.sql)); + } +} + +export function toMatchSql(sql: string) { + return new SqlExpectation(sql); +} + +type HasPropertiesInput> = { + [K in keyof Obj]?: any | Expectation | Error; +}; + +class HasPropertiesExpectation> extends Expectation { + constructor(private readonly properties: HasPropertiesInput) { + super(); + } + + assert(value: Obj) { + for (const key of Object.keys(this.properties) as Array) { + assertMatchesExpectation(value[key], this.properties[key]); + } + } +} + +export function toHaveProperties>(properties: HasPropertiesInput) { + return new HasPropertiesExpectation(properties); +} + export function expectsql( query: { query: string, bind: unknown } | Error, assertions: { query: PartialRecord, bind: PartialRecord }, @@ -302,7 +402,25 @@ export function expectsql( | { query: PartialRecord, bind: PartialRecord } | PartialRecord, ): void { - const expectations: PartialRecord = 'query' in assertions ? assertions.query : assertions; + const rawExpectationMap: PartialRecord = 'query' in assertions ? assertions.query : assertions; + const expectations: PartialRecord<'default' | Dialect, string | Error> = Object.create(null); + + for (const [key, value] of Object.entries(rawExpectationMap)) { + const acceptedDialects = key.split(' ') as Array; + + for (const dialect of acceptedDialects) { + if (dialect === 'default' && acceptedDialects.length > 1) { + throw new Error(`The 'default' expectation cannot be combined with other dialects.`); + } + + if (expectations[dialect] !== undefined) { + throw new Error(`The expectation for ${dialect} was already defined.`); + } + + expectations[dialect] = value; + } + } + let expectation = expectations[sequelize.dialect.name]; const dialect = sequelize.dialect; diff --git a/test/unit/utils/sql.test.ts b/test/unit/utils/sql.test.ts index 233c07dc1a8e..723258a67f9a 100644 --- a/test/unit/utils/sql.test.ts +++ b/test/unit/utils/sql.test.ts @@ -1,6 +1,13 @@ import { injectReplacements, mapBindParameters } from '@sequelize/core/_non-semver-use-at-your-own-risk_/utils/sql.js'; import { expect } from 'chai'; -import { expectsql, sequelize } from '../../support'; +import { + createSequelizeInstance, + expectPerDialect, + expectsql, + sequelize, + toHaveProperties, + toMatchSql, +} from '../../support'; const dialect = sequelize.dialect; @@ -54,6 +61,17 @@ describe('mapBindParameters', () => { }); }); + it('parses bind parameters following JSONB indexing', () => { + const { sql } = mapBindParameters(`SELECT * FROM users WHERE json_col->>$key`, dialect); + + expectsql(sql, { + default: `SELECT * FROM users WHERE json_col->>?`, + postgres: `SELECT * FROM users WHERE json_col->>$1`, + sqlite: `SELECT * FROM users WHERE json_col->>$key`, + mssql: `SELECT * FROM users WHERE json_col->>@key`, + }); + }); + it('parses bind parameters followed by a semicolon', () => { const { sql } = mapBindParameters('SELECT * FROM users WHERE id = $id;', dialect); @@ -160,21 +178,48 @@ describe('mapBindParameters', () => { } }); - it('does not consider the token to be a bind parameter if it is part of a string with a backslash escaped quote', () => { - const { sql, bindOrder } = mapBindParameters(`SELECT * FROM users WHERE id = '\\'$id' OR id = $id`, dialect); + it('does not consider the token to be a bind parameter if it is part of a string with a backslash escaped quote, in dialects that support backslash escape', () => { + expectPerDialect(() => mapBindParameters(`SELECT * FROM users WHERE id = '\\' $id' OR id = $id`, dialect), { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = '\\' $id' OR id = $id`), - expectsql(sql, { - default: `SELECT * FROM users WHERE id = '\\'$id' OR id = ?`, - postgres: `SELECT * FROM users WHERE id = '\\'$id' OR id = $1`, - sqlite: `SELECT * FROM users WHERE id = '\\'$id' OR id = $id`, - mssql: `SELECT * FROM users WHERE id = '\\'$id' OR id = @id`, + 'mysql mariadb': toHaveProperties({ + sql: toMatchSql(`SELECT * FROM users WHERE id = '\\' $id' OR id = ?`), + bindOrder: ['id'], + }), }); + }); - if (supportsNamedParameters) { - expect(bindOrder).to.be.null; - } else { - expect(bindOrder).to.deep.eq(['id']); - } + it('does not consider the token to be a bind parameter if it is part of a string with a backslash escaped quote, in dialects that support standardConformingStrings = false', () => { + expectPerDialect(() => mapBindParameters(`SELECT * FROM users WHERE id = '\\' $id' OR id = $id`, getNonStandardConfirmingStringDialect()), { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = '\\' $id' OR id = $id`), + + 'mysql mariadb': toHaveProperties({ + sql: toMatchSql(`SELECT * FROM users WHERE id = '\\' $id' OR id = ?`), + bindOrder: ['id'], + }), + postgres: toHaveProperties({ + sql: `SELECT * FROM users WHERE id = '\\' $id' OR id = $1`, + bindOrder: ['id'], + }), + }); + }); + + it('does not consider the token to be a bind parameter if it is part of an E-prefixed string with a backslash escaped quote, in dialects that support E-prefixed strings', () => { + expectPerDialect(() => mapBindParameters(`SELECT * FROM users WHERE id = E'\\' $id' OR id = $id`, dialect), { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = E'\\' $id' OR id = $id`), + + 'mysql mariadb': toHaveProperties({ + sql: toMatchSql(`SELECT * FROM users WHERE id = E'\\' $id' OR id = ?`), + bindOrder: ['id'], + }), + postgres: toHaveProperties({ + sql: `SELECT * FROM users WHERE id = E'\\' $id' OR id = $1`, + bindOrder: ['id'], + }), + }); }); it('considers the token to be a bind parameter if it is outside a string ending with an escaped backslash', () => { @@ -195,20 +240,15 @@ describe('mapBindParameters', () => { }); it('does not consider the token to be a bind parameter if it is part of a string with an escaped backslash followed by a backslash escaped quote', () => { - const { sql, bindOrder } = mapBindParameters(`SELECT * FROM users WHERE id = '\\\\\\'$id' OR id = $id`, dialect); + expectPerDialect(() => mapBindParameters(`SELECT * FROM users WHERE id = '\\\\\\' $id' OR id = $id`, dialect), { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = '\\\\\\' $id' OR id = $id`), - expectsql(sql, { - default: `SELECT * FROM users WHERE id = '\\\\\\'$id' OR id = ?`, - postgres: `SELECT * FROM users WHERE id = '\\\\\\'$id' OR id = $1`, - sqlite: `SELECT * FROM users WHERE id = '\\\\\\'$id' OR id = $id`, - mssql: `SELECT * FROM users WHERE id = '\\\\\\'$id' OR id = @id`, + 'mysql mariadb': toHaveProperties({ + sql: toMatchSql(`SELECT * FROM users WHERE id = '\\\\\\' $id' OR id = ?`), + bindOrder: ['id'], + }), }); - - if (supportsNamedParameters) { - expect(bindOrder).to.be.null; - } else { - expect(bindOrder).to.deep.eq(['id']); - } }); it('does not consider the token to be a bind parameter if it is in a single line comment', () => { @@ -318,6 +358,17 @@ describe('injectReplacements (named replacements)', () => { }); }); + it('parses named replacements following JSONB indexing', () => { + const sql = injectReplacements(`SELECT * FROM users WHERE json_col->>:key`, dialect, { + key: 'name', + }); + + expectsql(sql, { + default: `SELECT * FROM users WHERE json_col->>'name'`, + mssql: `SELECT * FROM users WHERE json_col->>N'name'`, + }); + }); + it('parses named replacements followed by a semicolon', () => { const sql = injectReplacements('SELECT * FROM users WHERE id = :id;', dialect, { id: 1, @@ -395,12 +446,33 @@ describe('injectReplacements (named replacements)', () => { }); it('does not consider the token to be a replacement if it is part of a string with a backslash escaped quote', () => { - const sql = injectReplacements(`SELECT * FROM users WHERE id = '\\':id' OR id = :id`, dialect, { - id: 1, + const test = () => injectReplacements(`SELECT * FROM users WHERE id = '\\' :id' OR id = :id`, dialect, { id: 1 }); + + expectPerDialect(test, { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = '\\' :id' OR id = :id`), + + 'mysql mariadb': toMatchSql(`SELECT * FROM users WHERE id = '\\' :id' OR id = 1`), }); + }); - expectsql(sql, { - default: `SELECT * FROM users WHERE id = '\\':id' OR id = 1`, + it('does not consider the token to be a replacement if it is part of a string with a backslash escaped quote, in dialects that support standardConformingStrings = false', () => { + const test = () => injectReplacements(`SELECT * FROM users WHERE id = '\\' :id' OR id = :id`, getNonStandardConfirmingStringDialect(), { id: 1 }); + + expectPerDialect(test, { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = '\\' :id' OR id = :id`), + + 'mysql mariadb postgres': toMatchSql(`SELECT * FROM users WHERE id = '\\' :id' OR id = 1`), + }); + }); + + it('does not consider the token to be a replacement if it is part of an E-prefixed string with a backslash escaped quote, in dialects that support E-prefixed strings', () => { + expectPerDialect(() => injectReplacements(`SELECT * FROM users WHERE id = E'\\' :id' OR id = :id`, dialect, { id: 1 }), { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = E'\\' :id' OR id = :id`), + + 'mysql mariadb postgres': toMatchSql(`SELECT * FROM users WHERE id = E'\\' :id' OR id = 1`), }); }); @@ -415,12 +487,13 @@ describe('injectReplacements (named replacements)', () => { }); it('does not consider the token to be a replacement if it is part of a string with an escaped backslash followed by a backslash escaped quote', () => { - const sql = injectReplacements(`SELECT * FROM users WHERE id = '\\\\\\':id' OR id = :id`, dialect, { - id: 1, - }); + const test = () => injectReplacements(`SELECT * FROM users WHERE id = '\\\\\\' :id' OR id = :id`, dialect, { id: 1 }); - expectsql(sql, { - default: `SELECT * FROM users WHERE id = '\\\\\\':id' OR id = 1`, + expectPerDialect(test, { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = '\\\\\\' :id' OR id = :id`), + + 'mysql mariadb': `SELECT * FROM users WHERE id = '\\\\\\' :id' OR id = 1`, }); }); @@ -492,6 +565,15 @@ describe('injectReplacements (positional replacements)', () => { }); }); + it('parses named replacements following JSONB indexing', () => { + const sql = injectReplacements(`SELECT * FROM users WHERE json_col->>?`, dialect, ['name']); + + expectsql(sql, { + default: `SELECT * FROM users WHERE json_col->>'name'`, + mssql: `SELECT * FROM users WHERE json_col->>N'name'`, + }); + }); + it('parses positional replacements followed by a semicolon', () => { const sql = injectReplacements('SELECT * FROM users WHERE id = ?;', dialect, [1]); @@ -546,10 +628,33 @@ describe('injectReplacements (positional replacements)', () => { }); it('does not consider the token to be a replacement if it is part of a string with a backslash escaped quote', () => { - const sql = injectReplacements(`SELECT * FROM users WHERE id = '\\'?' OR id = ?`, dialect, [1]); + const test = () => injectReplacements(`SELECT * FROM users WHERE id = '\\' ?' OR id = ?`, dialect, [1]); - expectsql(sql, { - default: `SELECT * FROM users WHERE id = '\\'?' OR id = 1`, + expectPerDialect(test, { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = '\\' ?' OR id = ?`), + + 'mysql mariadb': toMatchSql(`SELECT * FROM users WHERE id = '\\' ?' OR id = 1`), + }); + }); + + it('does not consider the token to be a replacement if it is part of a string with a backslash escaped quote, in dialects that support standardConformingStrings = false', () => { + const test = () => injectReplacements(`SELECT * FROM users WHERE id = '\\' ?' OR id = ?`, getNonStandardConfirmingStringDialect(), [1]); + + expectPerDialect(test, { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = '\\' ?' OR id = ?`), + + 'mysql mariadb postgres': toMatchSql(`SELECT * FROM users WHERE id = '\\' ?' OR id = 1`), + }); + }); + + it('does not consider the token to be a replacement if it is part of an E-prefixed string with a backslash escaped quote, in dialects that support E-prefixed strings', () => { + expectPerDialect(() => injectReplacements(`SELECT * FROM users WHERE id = E'\\' ?' OR id = ?`, dialect, [1]), { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = E'\\' ?' OR id = ?`), + + 'mysql mariadb postgres': toMatchSql(`SELECT * FROM users WHERE id = E'\\' ?' OR id = 1`), }); }); @@ -562,10 +667,13 @@ describe('injectReplacements (positional replacements)', () => { }); it('does not consider the token to be a replacement if it is part of a string with an escaped backslash followed by a backslash escaped quote', () => { - const sql = injectReplacements(`SELECT * FROM users WHERE id = '\\\\\\'?' OR id = ?`, dialect, [1]); + const test = () => injectReplacements(`SELECT * FROM users WHERE id = '\\\\\\' ?' OR id = ?`, dialect, [1]); - expectsql(sql, { - default: `SELECT * FROM users WHERE id = '\\\\\\'?' OR id = 1`, + expectPerDialect(test, { + default: new Error(`The following SQL query includes an unterminated string literal: +SELECT * FROM users WHERE id = '\\\\\\' ?' OR id = ?`), + + 'mysql mariadb': `SELECT * FROM users WHERE id = '\\\\\\' ?' OR id = 1`, }); }); @@ -656,3 +764,9 @@ describe('injectReplacements (positional replacements)', () => { }); }); }); + +function getNonStandardConfirmingStringDialect() { + return createSequelizeInstance({ + standardConformingStrings: false, + }).dialect; +}