Core: Add MultiStringParser to match a collection of strings #3510

Merged
merged 10 commits on Jun 30, 2022
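In practice, the change lets a dialect pass a keyword collection directly to a parser instead of joining it into a regex. A minimal sketch of the before/after pattern, adapted from the dialect_ansi.py hunks further down (the stand-in set and the import line are illustrative assumptions, not part of this PR):

from sqlfluff.core.parser import CodeSegment, MultiStringParser, RegexParser

# Stand-in for dialect.sets("datetime_units") as used in the real dialect code.
datetime_units = {"DAY", "MONTH", "YEAR"}

# Before: the whole keyword set is joined into a single regex.
before = RegexParser(
    r"^(" + r"|".join(datetime_units) + r")$",
    CodeSegment,
    name="date_part",
    type="date_part",
)

# After: the collection is handed straight to the new MultiStringParser.
after = MultiStringParser(
    datetime_units,
    CodeSegment,
    name="date_part",
    type="date_part",
)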
7 changes: 3 additions & 4 deletions src/sqlfluff/core/dialects/base.py
@@ -9,12 +9,11 @@
StringParser,
)
from sqlfluff.core.parser.grammar.base import BaseGrammar
from sqlfluff.core.parser.parsers import BaseParser

DialectElementType = Union[
Type[BaseSegment], BaseGrammar, StringParser, SegmentGenerator
]
DialectElementType = Union[Type[BaseSegment], BaseGrammar, BaseParser, SegmentGenerator]
# NOTE: Post expansion, no generators remain
ExpandedDialectElementType = Union[Type[BaseSegment], StringParser, BaseGrammar]
ExpandedDialectElementType = Union[Type[BaseSegment], BaseParser, BaseGrammar]


class Dialect:
8 changes: 7 additions & 1 deletion src/sqlfluff/core/parser/__init__.py
@@ -31,7 +31,12 @@
OptionallyBracketed,
Conditional,
)
from sqlfluff.core.parser.parsers import StringParser, NamedParser, RegexParser
from sqlfluff.core.parser.parsers import (
StringParser,
NamedParser,
RegexParser,
MultiStringParser,
)
from sqlfluff.core.parser.markers import PositionMarker
from sqlfluff.core.parser.lexer import Lexer, StringLexer, RegexLexer
from sqlfluff.core.parser.parser import Parser
@@ -65,6 +70,7 @@
"OptionallyBracketed",
"Conditional",
"StringParser",
"MultiStringParser",
"NamedParser",
"RegexParser",
"PositionMarker",
4 changes: 2 additions & 2 deletions src/sqlfluff/core/parser/grammar/base.py
@@ -17,7 +17,7 @@
from sqlfluff.core.parser.match_wrapper import match_wrapper
from sqlfluff.core.parser.matchable import Matchable
from sqlfluff.core.parser.context import ParseContext
from sqlfluff.core.parser.parsers import StringParser
from sqlfluff.core.parser.parsers import BaseParser

# Either a Grammar or a Segment CLASS
MatchableType = Union[Matchable, Type[BaseSegment]]
@@ -88,7 +88,7 @@ def _resolve_ref(elem):
# t: instance / f: class, ref, func
(True, str, Ref.keyword),
(True, BaseGrammar, lambda x: x),
(True, StringParser, lambda x: x),
(True, BaseParser, lambda x: x),
(False, BaseSegment, lambda x: x),
]
# Get-out clause for None
136 changes: 111 additions & 25 deletions src/sqlfluff/core/parser/parsers.py
@@ -3,32 +3,32 @@
Matchable objects which return individual segments.
"""

from abc import abstractmethod
import regex
from typing import Type, Optional, List, Tuple, Union
from typing import Collection, Type, Optional, List, Tuple, Union

from sqlfluff.core.parser.context import ParseContext
from sqlfluff.core.parser.matchable import Matchable
from sqlfluff.core.parser.match_result import MatchResult
from sqlfluff.core.parser.segments import RawSegment, BaseSegment


class StringParser(Matchable):
"""An object which matches and returns raw segments based on strings."""
class BaseParser(Matchable):
"""An abstract class from which other Parsers should inherit."""

# Meta segments are handled separately. All StringParser elements
# Meta segments are handled separately. All Parser elements
# are assumed to be not meta.
is_meta = False
is_meta: bool = False

@abstractmethod
def __init__(
self,
template: str,
raw_class: Type[RawSegment],
name: Optional[str] = None,
type: Optional[str] = None,
optional: bool = False,
**segment_kwargs,
):
self.template = template
) -> None:
self.raw_class = raw_class
self.name = name
self.type = type
@@ -39,21 +39,9 @@ def is_optional(self) -> bool:
"""Return whether this element is optional."""
return self.optional

def simple(self, parse_context: "ParseContext") -> Optional[List[str]]:
"""Return simple options for this matcher.

Because string matchers are not case sensitive we can
just return the template here.
"""
return [self.template.upper()]

@abstractmethod
def _is_first_match(self, segment: BaseSegment):
"""Does the segment provided match according to the current rules."""
# Is the target a match and IS IT CODE.
# The latter stops us accidentally matching comments.
if self.template.upper() == segment.raw.upper() and segment.is_code:
return True
return False

def _make_match_from_first_result(self, segments: Tuple[BaseSegment, ...]):
"""Make a MatchResult from the first segment in the given list.
@@ -100,9 +88,107 @@ def match(
return MatchResult.from_unmatched(segments)


class NamedParser(StringParser):
class StringParser(BaseParser):
"""An object which matches and returns raw segments based on strings."""

def __init__(
self,
template: str,
raw_class: Type[RawSegment],
name: Optional[str] = None,
type: Optional[str] = None,
optional: bool = False,
**segment_kwargs,
):
self.template = template.upper()
# Create list version upfront to avoid recreating it multiple times.
self._simple = [self.template]
super().__init__(
raw_class=raw_class,
name=name,
type=type,
optional=optional,
**segment_kwargs,
)

def simple(self, parse_context: "ParseContext") -> Optional[List[str]]:
"""Return simple options for this matcher.

Because string matchers are not case sensitive we can
just return the template here.
"""
return self._simple

def _is_first_match(self, segment: BaseSegment):
"""Does the segment provided match according to the current rules."""
# Is the target a match and IS IT CODE.
# The latter stops us accidentally matching comments.
if self.template == segment.raw_upper and segment.is_code:
return True
return False


class MultiStringParser(BaseParser):
"""An object which matches and returns raw segments on a collection of strings."""

def __init__(
self,
templates: Collection[str],
raw_class: Type[RawSegment],
name: Optional[str] = None,
type: Optional[str] = None,
optional: bool = False,
**segment_kwargs,
):
self.templates = {template.upper() for template in templates}
# Create list version upfront to avoid recreating it multiple times.
self._simple = list(self.templates)
super().__init__(
raw_class=raw_class,
name=name,
type=type,
optional=optional,
**segment_kwargs,
)

def simple(self, parse_context: "ParseContext") -> Optional[List[str]]:
"""Return simple options for this matcher.

Because string matchers are not case sensitive we can
just return the templates here.
"""
return self._simple

def _is_first_match(self, segment: BaseSegment):
"""Does the segment provided match according to the current rules."""
# Is the target a match and IS IT CODE.
# The latter stops us accidentally matching comments.
if segment.is_code and segment.raw_upper in self.templates:
return True
return False


class NamedParser(BaseParser):
"""An object which matches and returns raw segments based on names."""

def __init__(
self,
template: str,
raw_class: Type[RawSegment],
name: Optional[str] = None,
type: Optional[str] = None,
optional: bool = False,
**segment_kwargs,
):
self.template = template.lower()
super().__init__(
raw_class=raw_class,
name=name,
type=type,
optional=optional,
**segment_kwargs,
)

def simple(cls, parse_context: ParseContext) -> Optional[List[str]]:
"""Does this matcher support a uppercase hash matching route?

@@ -122,12 +208,12 @@ def _is_first_match(self, segment: BaseSegment):
"""
# Case sensitivity is not supported. Names are all
# lowercase by convention.
if self.template.lower() == segment.name.lower():
if self.template == segment.name.lower():
return True
return False


class RegexParser(StringParser):
class RegexParser(BaseParser):
"""An object which matches and returns raw segments based on a regex."""

def __init__(
@@ -141,9 +227,9 @@ def __init__(
**segment_kwargs,
):
# Store the optional anti-template
self.template = template
self.anti_template = anti_template
super().__init__(
template=template,
raw_class=raw_class,
name=name,
type=type,
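The matching rule added above reduces to an uppercase set-membership test that only considers code segments. A standalone sketch in plain Python (an illustration, not sqlfluff code) of the behaviour implemented by MultiStringParser._is_first_match:

templates = {t.upper() for t in ("foo", "bar")}

def is_first_match(raw: str, is_code: bool = True) -> bool:
    # Mirror of the rule above: only code segments are considered, and the
    # comparison uses the uppercased raw string against the uppercased set.
    return is_code and raw.upper() in templates

assert is_first_match("FOO") and is_first_match("bar")
assert not is_first_match("fo")           # not in the collection
assert not is_first_match("foo", False)   # non-code segments (e.g. comments) never match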
13 changes: 7 additions & 6 deletions src/sqlfluff/dialects/dialect_ansi.py
@@ -48,6 +48,7 @@
StringParser,
SymbolSegment,
WhitespaceSegment,
MultiStringParser,
)
from sqlfluff.core.parser.segments.base import BracketedSegment
from sqlfluff.dialects.dialect_ansi_keywords import (
@@ -273,8 +274,8 @@
),
# The following functions can be called without parentheses per ANSI specification
BareFunctionSegment=SegmentGenerator(
lambda dialect: RegexParser(
r"^(" + r"|".join(dialect.sets("bare_functions")) + r")$",
lambda dialect: MultiStringParser(
dialect.sets("bare_functions"),
CodeSegment,
name="bare_function",
type="bare_function",
@@ -318,16 +319,16 @@
),
# Ansi Intervals
DatetimeUnitSegment=SegmentGenerator(
lambda dialect: RegexParser(
r"^(" + r"|".join(dialect.sets("datetime_units")) + r")$",
lambda dialect: MultiStringParser(
dialect.sets("datetime_units"),
CodeSegment,
name="date_part",
type="date_part",
)
),
DatePartFunctionName=SegmentGenerator(
lambda dialect: RegexParser(
r"^(" + r"|".join(dialect.sets("date_part_function_name")) + r")$",
lambda dialect: MultiStringParser(
dialect.sets("date_part_function_name"),
CodeSegment,
name="function_name_identifier",
type="function_name_identifier",
5 changes: 3 additions & 2 deletions src/sqlfluff/dialects/dialect_bigquery.py
@@ -35,6 +35,7 @@
StringLexer,
StringParser,
SymbolSegment,
MultiStringParser,
)
from sqlfluff.core.parser.segments.base import BracketedSegment
from sqlfluff.dialects.dialect_bigquery_keywords import (
@@ -160,8 +161,8 @@
),
),
ExtendedDatetimeUnitSegment=SegmentGenerator(
lambda dialect: RegexParser(
r"^(" + r"|".join(dialect.sets("extended_datetime_units")) + r")$",
lambda dialect: MultiStringParser(
dialect.sets("extended_datetime_units"),
CodeSegment,
name="date_part",
type="date_part",
9 changes: 5 additions & 4 deletions src/sqlfluff/dialects/dialect_exasol.py
@@ -29,6 +29,7 @@
StringParser,
RegexParser,
NewlineSegment,
MultiStringParser,
)
from sqlfluff.core.dialects import load_raw_dialect
from sqlfluff.core.parser.segments.generator import SegmentGenerator
@@ -170,16 +171,16 @@
"escaped_identifier", SymbolSegment, type="identifier"
),
SessionParameterSegment=SegmentGenerator(
lambda dialect: RegexParser(
r"^(" + r"|".join(dialect.sets("session_parameters")) + r")$",
lambda dialect: MultiStringParser(
dialect.sets("session_parameters"),
CodeSegment,
name="session_parameter",
type="session_parameter",
)
),
SystemParameterSegment=SegmentGenerator(
lambda dialect: RegexParser(
r"^(" + r"|".join(dialect.sets("system_parameters")) + r")$",
lambda dialect: MultiStringParser(
dialect.sets("system_parameters"),
CodeSegment,
name="system_parameter",
type="system_parameter",
28 changes: 28 additions & 0 deletions test/core/parser/parser_test.py
@@ -0,0 +1,28 @@
"""The Test file for Parsers (Matchable Classes)."""

from sqlfluff.core.parser import (
KeywordSegment,
MultiStringParser,
)
from sqlfluff.core.parser.context import RootParseContext


def test__parser__multistringparser__match(generate_test_segments):
"""Test the MultiStringParser matchable."""
parser = MultiStringParser(["foo", "bar"], KeywordSegment)
with RootParseContext(dialect=None) as ctx:
# Check directly
seg_list = generate_test_segments(["foo", "fo"])
# Matches when it should
assert parser.match(seg_list[:1], parse_context=ctx).matched_segments == (
KeywordSegment("foo", seg_list[0].pos_marker),
)
# Doesn't match when it shouldn't
assert parser.match(seg_list[1:], parse_context=ctx).matched_segments == tuple()


def test__parser__multistringparser__simple():
"""Test the MultiStringParser matchable."""
parser = MultiStringParser(["foo", "bar"], KeywordSegment)
with RootParseContext(dialect=None) as ctx:
assert parser.simple(ctx)
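Since the templates are uppercased and stored in a set, simple() returns the uppercased strings in no guaranteed order. A hypothetical, stricter version of that last check (not part of this PR) could assert the values directly:

def test__parser__multistringparser__simple_values():
    """Hypothetical stricter check: simple() returns the uppercased templates."""
    parser = MultiStringParser(["foo", "bar"], KeywordSegment)
    with RootParseContext(dialect=None) as ctx:
        # Order is not guaranteed because the templates come from a set.
        assert set(parser.simple(ctx)) == {"FOO", "BAR"}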