Fix(tokenizer): don't increment array cursor by 2 on CRLF #3204

Merged (2 commits) on Mar 22, 2024
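The gist of the fix: `_advance` used to special-case a `\r\n` pair by jumping the cursor two characters at once (and nudging `_start`), which clashes with callers that pass a precomputed offset, such as the space-skipping fast path in `_scan`. After the fix, each character of the pair is consumed by an ordinary one-step advance, and the remaining special case only keeps the pair from being counted as two line breaks. Below is a minimal, self-contained sketch of the cursor behavior; `visit_positions` is a hypothetical helper for illustration, not sqlglot's API.

def visit_positions(sql: str, crlf_jumps_two: bool) -> list[int]:
    """Cursor positions visited while scanning; illustrates the off-by-one."""
    positions, i = [], 0
    while i < len(sql):
        positions.append(i)
        # Old behavior: on "\r\n", jump by 2, so the "\n" slot is never visited
        # and any offset the caller computed is silently overridden.
        if crlf_jumps_two and sql[i : i + 2] == "\r\n":
            i += 2
        else:
            i += 1
    return positions

print(visit_positions("a\r\nb", crlf_jumps_two=True))   # [0, 1, 3]
print(visit_positions("a\r\nb", crlf_jumps_two=False))  # [0, 1, 2, 3]
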
25 changes: 9 additions & 16 deletions sqlglot/tokens.py
@@ -565,8 +565,7 @@ class Tokenizer(metaclass=_Tokenizer):
         "~": TokenType.TILDA,
         "?": TokenType.PLACEHOLDER,
         "@": TokenType.PARAMETER,
-        # used for breaking a var like x'y' but nothing else
-        # the token type doesn't matter
+        # Used for breaking a var like x'y' but nothing else; the token type doesn't matter
         "'": TokenType.QUOTE,
         "`": TokenType.IDENTIFIER,
         '"': TokenType.IDENTIFIER,
@@ -892,7 +891,7 @@ class Tokenizer(metaclass=_Tokenizer):

     COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN}

-    # handle numeric literals like in hive (3L = BIGINT)
+    # Handle numeric literals like in hive (3L = BIGINT)
     NUMERIC_LITERALS: t.Dict[str, str] = {}

     COMMENTS = ["--", ("/*", "*/")]
@@ -965,8 +964,7 @@ def _scan(self, until: t.Optional[t.Callable] = None) -> None:
         while self.size and not self._end:
             current = self._current

-            # skip spaces inline rather than iteratively call advance()
-            # for performance reasons
+            # Skip spaces here rather than iteratively calling advance() for performance reasons
             while current < self.size:
                 char = self.sql[current]

@@ -975,12 +973,10 @@ def _scan(self, until: t.Optional[t.Callable] = None) -> None:
                 else:
                     break

-            n = current - self._current
-            self._start = current
-            self._advance(n if n > 1 else 1)
+            offset = current - self._current if current > self._current else 1

-            if self._char is None:
-                break
+            self._start = current
+            self._advance(offset)

             if not self._char.isspace():
                 if self._char.isdigit():
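Note on the new `offset` computation: the fast path above consumes a run of spaces and tabs without calling `_advance` per character, and the conditional expression then advances by the length of that run, or by 1 when nothing was skipped, so the scanner always makes progress. A rough standalone rendering of the idea, with hypothetical names:

def skip_offset(sql: str, current: int) -> int:
    """How far the scanner should advance: the run of spaces/tabs, or 1."""
    start = current
    while current < len(sql) and sql[current] in (" ", "\t"):
        current += 1
    return current - start if current > start else 1

assert skip_offset("   SELECT", 0) == 3  # three spaces skipped -> advance(3)
assert skip_offset("SELECT", 0) == 1     # nothing skipped -> advance(1)
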
@@ -1008,12 +1004,9 @@ def _chars(self, size: int) -> str:
     def _advance(self, i: int = 1, alnum: bool = False) -> None:
         if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
             # Ensures we don't count an extra line if we get a \r\n line break sequence
-            if self._char == "\r" and self._peek == "\n":
-                i = 2
-                self._start += 1
-
-            self._col = 1
-            self._line += 1
+            if not (self._char == "\r" and self._peek == "\n"):
+                self._col = 1
+                self._line += 1
         else:
             self._col += i

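With the `i = 2` override gone, `\r` and `\n` are consumed as two ordinary advances; the surviving check only suppresses the line bump on the `\r` half of a CRLF pair, so the pair counts as a single line break while a bare `\r` or `\n` still counts as one. A hedged sketch of the resulting line/column bookkeeping (standalone illustration, not sqlglot's `_advance`):

def track(text: str) -> tuple[int, int]:
    """Final (line, col) after scanning text one character at a time."""
    line, col = 1, 1
    for i, ch in enumerate(text):
        if ch in "\r\n":
            peek = text[i + 1] if i + 1 < len(text) else ""
            # Defer the line bump to the "\n" of a "\r\n" pair.
            if not (ch == "\r" and peek == "\n"):
                line, col = line + 1, 1
        else:
            col += 1
    return line, col

assert track("SELECT 1\r\nSELECT 2") == (2, 9)  # one CRLF -> exactly one new line
assert track("SELECT 1\rSELECT 2") == (2, 9)    # a bare "\r" is still a break
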
33 changes: 24 additions & 9 deletions sqlglotrs/src/tokenizer.rs
@@ -118,8 +118,27 @@ impl<'a> TokenizerState<'a> {

     fn scan(&mut self, until_peek_char: Option<char>) -> Result<(), TokenizerError> {
         while self.size > 0 && !self.is_end {
-            self.start = self.current;
-            self.advance(1)?;
+            let mut current = self.current;
+
+            // Skip spaces here rather than iteratively calling advance() for performance reasons
+            while current < self.size {
+                let ch = self.char_at(current)?;
+
+                if ch == ' ' || ch == '\t' {
+                    current += 1;
+                } else {
+                    break;
+                }
+            }
+
+            let offset = if current > self.current {
+                current - self.current
+            } else {
+                1
+            };
+
+            self.start = current;
+            self.advance(offset as isize)?;

             if self.current_char == '\0' {
                 break;
@@ -153,16 +172,12 @@ impl<'a> TokenizerState<'a> {
     }

     fn advance(&mut self, i: isize) -> Result<(), TokenizerError> {
-        let mut i = i;
         if Some(&self.token_types.break_) == self.settings.white_space.get(&self.current_char) {
             // Ensures we don't count an extra line if we get a \r\n line break sequence.
-            if self.current_char == '\r' && self.peek_char == '\n' {
-                i = 2;
-                self.start += 1;
+            if !(self.current_char == '\r' && self.peek_char == '\n') {
+                self.column = 1;
+                self.line += 1;
             }
-
-            self.column = 1;
-            self.line += 1;
         } else {
             self.column = self.column.wrapping_add_signed(i);
         }
12 changes: 12 additions & 0 deletions tests/test_tokens.py
@@ -85,6 +85,18 @@ def test_crlf(self):
             ],
         )

+        for simple_query in ("SELECT 1\r\n", "\r\nSELECT 1"):
+            tokens = Tokenizer().tokenize(simple_query)
+            tokens = [(token.token_type, token.text) for token in tokens]
+
+            self.assertEqual(
+                tokens,
+                [
+                    (TokenType.SELECT, "SELECT"),
+                    (TokenType.NUMBER, "1"),
+                ],
+            )
+
     def test_command(self):
         tokens = Tokenizer().tokenize("SHOW;")
         self.assertEqual(tokens[0].token_type, TokenType.SHOW)