Skip to content

Commit

Permalink
fix: don't treat \ as escape in standard strings, support E-strings…
Browse files Browse the repository at this point in the history
…, support vars after `->>` operator (#14700)
  • Loading branch information
ephys committed Jul 3, 2022
1 parent 6871f3f commit 1c85d01
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 46 deletions.
16 changes: 16 additions & 0 deletions src/dialects/abstract/index.ts
Expand Up @@ -108,6 +108,12 @@ export type DialectSupports = {
* e.g. 'DEFERRABLE' and 'INITIALLY DEFERRED'
*/
deferrableConstraints: false,

/**
* This dialect supports E-prefixed strings, e.g. "E'foo'", which
* enables the ability to use backslash escapes inside of the string.
*/
escapeStringConstants: boolean,
};

export abstract class AbstractDialect {
Expand Down Expand Up @@ -197,6 +203,7 @@ export abstract class AbstractDialect {
tmpTableTrigger: false,
indexHints: false,
searchPath: false,
escapeStringConstants: false,
};

declare readonly defaultVersion: string;
Expand All @@ -216,6 +223,15 @@ export abstract class AbstractDialect {

abstract createBindCollector(): BindCollector;

/**
* Whether this dialect can use \ in strings to escape string delimiters.
*
* @returns
*/
canBackslashEscape(): boolean {
return false;
}

static getDefaultPort(): number {
throw new Error(`getDefaultPort not implemented in ${this.name}`);
}
Expand Down
4 changes: 4 additions & 0 deletions src/dialects/mariadb/index.js
Expand Up @@ -62,6 +62,10 @@ export class MariaDbDialect extends AbstractDialect {
return createUnspecifiedOrderedBindCollector();
}

canBackslashEscape() {
return true;
}

static getDefaultPort() {
return 3306;
}
Expand Down
4 changes: 4 additions & 0 deletions src/dialects/mysql/index.js
Expand Up @@ -61,6 +61,10 @@ export class MysqlDialect extends AbstractDialect {
return createUnspecifiedOrderedBindCollector();
}

canBackslashEscape() {
return true;
}

static getDefaultPort() {
return 3306;
}
Expand Down
9 changes: 9 additions & 0 deletions src/dialects/postgres/index.js
Expand Up @@ -52,6 +52,7 @@ export class PostgresDialect extends AbstractDialect {
TSVECTOR: true,
deferrableConstraints: true,
searchPath: true,
escapeStringConstants: true,
});

constructor(sequelize) {
Expand All @@ -72,6 +73,14 @@ export class PostgresDialect extends AbstractDialect {
return createSpecifiedOrderedBindCollector();
}

canBackslashEscape() {
// postgres can use \ to escape if one of these is true:
// - standard_conforming_strings is off
// - the string is prefixed with E (out of scope for this method)

return !this.sequelize.options.standardConformingStrings;
}

static getDefaultPort() {
return 5432;
}
Expand Down
39 changes: 35 additions & 4 deletions src/utils/sql.ts
Expand Up @@ -40,6 +40,7 @@ function mapBindParametersAndReplacements(
let previousSliceEnd = 0;
let isSingleLineComment = false;
let isCommentBlock = false;
let stringIsBackslashEscapable = false;

for (let i = 0; i < sqlString.length; i++) {
const char = sqlString[i];
Expand All @@ -53,8 +54,9 @@ function mapBindParametersAndReplacements(
}

if (isString) {
if (char === `'` && !isBackslashEscaped(sqlString, i - 1)) {
if (char === `'` && (!stringIsBackslashEscapable || !isBackslashEscaped(sqlString, i - 1))) {
isString = false;
stringIsBackslashEscapable = false;
}

continue;
Expand Down Expand Up @@ -99,6 +101,25 @@ function mapBindParametersAndReplacements(

if (char === `'`) {
isString = true;

// The following query is supported in almost all dialects,
// SELECT E'test';
// but postgres interprets it as an E-prefixed string, while other dialects interpret it as
// SELECT E 'test';
// which selects the type E and aliases it to 'test'.

stringIsBackslashEscapable
// all ''-style strings in this dialect can be backslash escaped
= dialect.canBackslashEscape()
// checking if this is a postgres-style E-prefixed string, which also supports backslash escaping
|| (
dialect.supports.escapeStringConstants
// is this a E-prefixed string, such as `E'abc'` ?
&& sqlString[i - 1] === 'E'
// reject things such as `AE'abc'` (the prefix must be exactly E)
&& canPrecedeNewToken(sqlString[i - 2])
);

continue;
}

Expand Down Expand Up @@ -133,7 +154,7 @@ function mapBindParametersAndReplacements(
if (onBind) {
// we want to be conservative with what we consider to be a bind parameter to avoid risk of conflict with potential operators
// users need to add a space before the bind parameter (except after '(', ',', and '=')
if (previousChar !== undefined && !/[\s(,=]/.test(previousChar)) {
if (!canPrecedeNewToken(previousChar)) {
continue;
}

Expand Down Expand Up @@ -162,7 +183,7 @@ function mapBindParametersAndReplacements(
const previousChar = sqlString[i - 1];
// we want to be conservative with what we consider to be a replacement to avoid risk of conflict with potential operators
// users need to add a space before the bind parameter (except after '(', ',', '=', and '[' (for arrays))
if (previousChar !== undefined && !/[\s(,=[]/.test(previousChar)) {
if (!canPrecedeNewToken(previousChar) && previousChar !== '[') {
continue;
}

Expand Down Expand Up @@ -197,7 +218,9 @@ function mapBindParametersAndReplacements(

// we want to be conservative with what we consider to be a replacement to avoid risk of conflict with potential operators
// users need to add a space before the bind parameter (except after '(', ',', '=', and '[' (for arrays))
if (previousChar !== undefined && !/[\s(,=[]/.test(previousChar)) {
// -> [ is temporarily added to allow 'ARRAY[:name]' to be replaced
// https://github.com/sequelize/sequelize/issues/14410 will make this obsolete.
if (!canPrecedeNewToken(previousChar) && previousChar !== '[') {
continue;
}

Expand Down Expand Up @@ -230,11 +253,19 @@ function mapBindParametersAndReplacements(
}
}

if (isString) {
throw new Error(`The following SQL query includes an unterminated string literal:\n${sqlString}`);
}

output += sqlString.slice(previousSliceEnd, sqlString.length);

return output;
}

function canPrecedeNewToken(char: string | undefined): boolean {
return char === undefined || /[\s(>,=]/.test(char);
}

/**
* Maps bind parameters from Sequelize's format ($1 or $name) to the dialect's format.
*
Expand Down
122 changes: 120 additions & 2 deletions test/support.ts
Expand Up @@ -285,9 +285,109 @@ export function getPoolMax(): number {
return Config[getTestDialect()].pool?.max ?? 1;
}

type ExpectationKey = Dialect | 'default';
type ExpectationKey = 'default' | Permutations<Dialect>;

export type ExpectationRecord<V> = PartialRecord<ExpectationKey, V | Expectation<V> | Error>;

type Permutations<T extends string, U extends string = T> =
T extends any ? (T | `${T} ${Permutations<Exclude<U, T>>}`) : never;

type PartialRecord<K extends keyof any, V> = Partial<Record<K, V>>;

export function expectPerDialect<Out>(
method: () => Out,
assertions: ExpectationRecord<Out>,
) {
const expectations: PartialRecord<'default' | Dialect, Out | Error | Expectation<Out>> = Object.create(null);

for (const [key, value] of Object.entries(assertions)) {
const acceptedDialects = key.split(' ') as Array<Dialect | 'default'>;

for (const dialect of acceptedDialects) {
if (dialect === 'default' && acceptedDialects.length > 1) {
throw new Error(`The 'default' expectation cannot be combined with other dialects.`);
}

if (expectations[dialect] !== undefined) {
throw new Error(`The expectation for ${dialect} was already defined.`);
}

expectations[dialect] = value;
}
}

let result: Out | Error;

try {
result = method();
} catch (error: unknown) {
assert(error instanceof Error, 'method threw a non-error');

result = error;
}

const expectation = expectations[sequelize.dialect.name] ?? expectations.default;
if (expectation === undefined) {
throw new Error(`No expectation was defined for ${sequelize.dialect.name} and the 'default' expectation has not been defined.`);
}

if (expectation instanceof Error) {
assert(result instanceof Error, `Expected method to error with "${expectation.message}", but it returned ${inspect(result)}.`);

expect(result.message).to.equal(expectation.message);
} else {
assert(!(result instanceof Error), `Did not expect query to error, but it errored with ${result instanceof Error ? result.message : ''}`);

assertMatchesExpectation(result, expectation);
}
}

function assertMatchesExpectation<V>(result: V, expectation: V | Expectation<V>): void {
if (expectation instanceof Expectation) {
expectation.assert(result);
} else {
expect(result).to.deep.equal(expectation);
}
}

abstract class Expectation<Value> {
abstract assert(value: Value): void;
}

class SqlExpectation extends Expectation<string> {
constructor(private readonly sql: string) {
super();
}

assert(value: string) {
expect(minifySql(value)).to.equal(minifySql(this.sql));
}
}

export function toMatchSql(sql: string) {
return new SqlExpectation(sql);
}

type HasPropertiesInput<Obj extends Record<string, unknown>> = {
[K in keyof Obj]?: any | Expectation<Obj[K]> | Error;
};

class HasPropertiesExpectation<Obj extends Record<string, unknown>> extends Expectation<Obj> {
constructor(private readonly properties: HasPropertiesInput<Obj>) {
super();
}

assert(value: Obj) {
for (const key of Object.keys(this.properties) as Array<keyof Obj>) {
assertMatchesExpectation(value[key], this.properties[key]);
}
}
}

export function toHaveProperties<Obj extends Record<string, unknown>>(properties: HasPropertiesInput<Obj>) {
return new HasPropertiesExpectation<Obj>(properties);
}

export function expectsql(
query: { query: string, bind: unknown } | Error,
assertions: { query: PartialRecord<ExpectationKey, string | Error>, bind: PartialRecord<ExpectationKey, unknown> },
Expand All @@ -302,7 +402,25 @@ export function expectsql(
| { query: PartialRecord<ExpectationKey, string | Error>, bind: PartialRecord<ExpectationKey, unknown> }
| PartialRecord<ExpectationKey, string | Error>,
): void {
const expectations: PartialRecord<ExpectationKey, string | Error> = 'query' in assertions ? assertions.query : assertions;
const rawExpectationMap: PartialRecord<ExpectationKey, string | Error> = 'query' in assertions ? assertions.query : assertions;
const expectations: PartialRecord<'default' | Dialect, string | Error> = Object.create(null);

for (const [key, value] of Object.entries(rawExpectationMap)) {
const acceptedDialects = key.split(' ') as Array<Dialect | 'default'>;

for (const dialect of acceptedDialects) {
if (dialect === 'default' && acceptedDialects.length > 1) {
throw new Error(`The 'default' expectation cannot be combined with other dialects.`);
}

if (expectations[dialect] !== undefined) {
throw new Error(`The expectation for ${dialect} was already defined.`);
}

expectations[dialect] = value;
}
}

let expectation = expectations[sequelize.dialect.name];

const dialect = sequelize.dialect;
Expand Down

0 comments on commit 1c85d01

Please sign in to comment.