Fix(postgres): fallback to parameter parser if heredoc is untokenizable (#2935)

* Fix(postgres): fallback to parameter parser if heredoc is untokenizable

* PR feedback
georgesittas committed Feb 8, 2024
1 parent b827626 commit 08cd117
Showing 6 changed files with 52 additions and 1 deletion.
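The gist of the fix: Postgres tokenizes "$" as the opener of a dollar-quoted (heredoc) string, so a bare placeholder like $1 could previously fail to tokenize. Now a "$" whose tag cannot open a heredoc falls back to the parameter parser. A minimal sketch of the fixed behavior, mirroring the new Postgres test below:

import sqlglot

# "$1" is no longer mis-lexed as the start of a dollar-quoted string; it
# falls back to a parameter token and the statement round-trips unchanged.
sql = "SELECT x FROM t WHERE CAST($1 AS TEXT) = 'ok'"
print(sqlglot.transpile(sql, read="postgres")[0])
# Expected: SELECT x FROM t WHERE CAST($1 AS TEXT) = 'ok'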
3 changes: 3 additions & 0 deletions sqlglot/dialects/postgres.py
@@ -232,6 +232,9 @@ class Tokenizer(tokens.Tokenizer):
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
HEREDOC_STRINGS = ["$"]

HEREDOC_TAG_IS_IDENTIFIER = True
HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~~": TokenType.LIKE,
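The two new class attributes drive the fallback: heredoc tags must now scan like identifiers, and a "$" followed by anything else is emitted as a PARAMETER token. A quick way to observe this at the token level (a sketch; exact token streams may vary across versions):

from sqlglot.dialects.postgres import Postgres

tokenizer = Postgres.Tokenizer()

for sql in ("SELECT $1", "SELECT $$hello$$", "SELECT $tag$hello$tag$"):
    types = [token.token_type.name for token in tokenizer.tokenize(sql)]
    print(sql, "->", types)

# "$1" yields a PARAMETER token for the "$" (the parameter parser then picks
# up the trailing number), while "$$hello$$" and "$tag$hello$tag$" are still
# scanned as heredoc strings.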
20 changes: 20 additions & 0 deletions sqlglot/tokens.py
@@ -504,6 +504,7 @@ def _quotes_to_format(
command_prefix_tokens={
_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
},
heredoc_tag_is_identifier=klass.HEREDOC_TAG_IS_IDENTIFIER,
)
token_types = RsTokenTypeSettings(
bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
@@ -517,6 +518,7 @@ def _quotes_to_format(
semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON],
string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING],
var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR],
heredoc_string_alternative=_TOKEN_TYPE_TO_INDEX[klass.HEREDOC_STRING_ALTERNATIVE],
)
klass._RS_TOKENIZER = RsTokenizer(settings, token_types)
else:
@@ -573,6 +575,12 @@ class Tokenizer(metaclass=_Tokenizer):
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()

# Whether heredoc tags must follow the same lexical rules as unquoted identifiers
HEREDOC_TAG_IS_IDENTIFIER = False

# Token type to emit as a fallback if the heredoc prefix doesn't actually open a heredoc
HEREDOC_STRING_ALTERNATIVE = TokenType.VAR

# Autofilled
_COMMENTS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
@@ -1249,6 +1257,18 @@ def _scan_string(self, start: str) -> bool:
elif token_type == TokenType.BIT_STRING:
base = 2
elif token_type == TokenType.HEREDOC_STRING:
if (
self.HEREDOC_TAG_IS_IDENTIFIER
and not self._peek.isidentifier()
and self._peek != end
):
if self.HEREDOC_STRING_ALTERNATIVE != TokenType.VAR:
self._add(self.HEREDOC_STRING_ALTERNATIVE)
else:
self._scan_var()

return True

self._advance()
tag = "" if self._char == end else self._extract_string(end)
end = f"{start}{tag}{end}"
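Both hooks are opt-in on the base tokenizer, so other dialects can adopt the same behavior. A hypothetical dialect tokenizer (name invented for illustration) would enable it like so:

from sqlglot import tokens
from sqlglot.tokens import TokenType

class MyTokenizer(tokens.Tokenizer):
    # "$" opens a heredoc only when the tag scans like an identifier or is
    # empty ("$$..."); any other "$<char>" falls back to the token type below.
    HEREDOC_STRINGS = ["$"]
    HEREDOC_TAG_IS_IDENTIFIER = True
    HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER

Leaving HEREDOC_STRING_ALTERNATIVE at its TokenType.VAR default makes the fallback scan a plain variable instead of emitting a single token.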
6 changes: 6 additions & 0 deletions sqlglotrs/src/settings.rs
@@ -17,6 +17,7 @@ pub struct TokenTypeSettings {
pub semicolon: TokenType,
pub string: TokenType,
pub var: TokenType,
pub heredoc_string_alternative: TokenType,
}

#[pymethods]
@@ -34,6 +35,7 @@ impl TokenTypeSettings {
semicolon: TokenType,
string: TokenType,
var: TokenType,
heredoc_string_alternative: TokenType,
) -> Self {
TokenTypeSettings {
bit_string,
@@ -47,6 +49,7 @@ impl TokenTypeSettings {
semicolon,
string,
var,
heredoc_string_alternative,
}
}
}
@@ -69,6 +72,7 @@ pub struct TokenizerSettings {
pub var_single_tokens: HashSet<char>,
pub commands: HashSet<TokenType>,
pub command_prefix_tokens: HashSet<TokenType>,
pub heredoc_tag_is_identifier: bool,
}

#[pymethods]
@@ -90,6 +94,7 @@ impl TokenizerSettings {
var_single_tokens: HashSet<String>,
commands: HashSet<TokenType>,
command_prefix_tokens: HashSet<TokenType>,
heredoc_tag_is_identifier: bool,
) -> Self {
let to_char = |v: &String| {
if v.len() == 1 {
@@ -138,6 +143,7 @@ impl TokenizerSettings {
var_single_tokens: var_single_tokens_native,
commands,
command_prefix_tokens,
heredoc_tag_is_identifier,
}
}
}
19 changes: 18 additions & 1 deletion sqlglotrs/src/tokenizer.rs
@@ -399,6 +399,19 @@ impl<'a> TokenizerState<'a> {
} else if *token_type == self.token_types.bit_string {
(Some(2), *token_type, end.clone())
} else if *token_type == self.token_types.heredoc_string {
if self.settings.heredoc_tag_is_identifier
&& !self.is_identifier(self.peek_char)
&& self.peek_char.to_string() != *end
{
if self.token_types.heredoc_string_alternative != self.token_types.var {
self.add(self.token_types.heredoc_string_alternative, None)?
} else {
self.scan_var()?
};

return Ok(true)
};

self.advance(1)?;
let tag = if self.current_char.to_string() == *end {
String::from("")
@@ -469,7 +482,7 @@ impl<'a> TokenizerState<'a> {
} else if self.peek_char.to_ascii_uppercase() == 'E' && scientific == 0 {
scientific += 1;
self.advance(1)?;
} else if self.peek_char.is_alphabetic() || self.peek_char == '_' {
} else if self.is_identifier(self.peek_char) {
let number_text = self.text();
let mut literal = String::from("");

@@ -643,6 +656,10 @@ impl<'a> TokenizerState<'a> {
Ok(text)
}

fn is_identifier(&self, name: char) -> bool {
name.is_alphabetic() || name == '_'
}

fn extract_value(&mut self) -> Result<String, TokenizerError> {
loop {
if !self.peek_char.is_whitespace()
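The Rust port mirrors the Python branch one-for-one, extracting an is_identifier helper that is also reused in number scanning. The decision itself is small enough to sketch standalone (simplified names, Python for illustration):

def heredoc_fallback(peek: str, end: str, tag_is_identifier: bool) -> bool:
    """Return True when a "$" prefix should NOT be scanned as a heredoc."""
    return tag_is_identifier and not peek.isidentifier() and peek != end

assert heredoc_fallback("1", "$", tag_is_identifier=True)      # $1 -> fallback
assert not heredoc_fallback("t", "$", tag_is_identifier=True)  # $tag$...$tag$
assert not heredoc_fallback("$", "$", tag_is_identifier=True)  # $$...$$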
4 changes: 4 additions & 0 deletions tests/dialects/test_clickhouse.py
@@ -77,6 +77,10 @@ def test_clickhouse(self):
self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""")
self.validate_identity("SELECT * FROM table LIMIT 1 BY a, b")
self.validate_identity("SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b")
self.validate_identity(
"SELECT $1$foo$1$",
"SELECT 'foo'",
)
self.validate_identity(
"SELECT * FROM table LIMIT 1, 2 BY a, b",
"SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b",
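The ClickHouse test pins down the other side of the behavior: ClickHouse leaves HEREDOC_TAG_IS_IDENTIFIER at its False default, so a non-identifier tag like "1" still delimits a dollar-quoted string there. A small sketch of the contrast:

import sqlglot

# In ClickHouse, "$1$foo$1$" is a dollar-quoted string with tag "1".
print(sqlglot.transpile("SELECT $1$foo$1$", read="clickhouse")[0])
# Expected, per the test above: SELECT 'foo'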
1 change: 1 addition & 0 deletions tests/dialects/test_postgres.py
@@ -33,6 +33,7 @@ def test_postgres(self):
self.assertIsInstance(expr, exp.AlterTable)
self.assertEqual(expr.sql(dialect="postgres"), alter_table_only)

self.validate_identity("SELECT x FROM t WHERE CAST($1 AS TEXT) = 'ok'")
self.validate_identity("SELECT * FROM t TABLESAMPLE SYSTEM (50) REPEATABLE (55)")
self.validate_identity("x @@ y")
self.validate_identity("CAST(x AS MONEY)")