Skip to content

Commit

Permalink
feat: Adding EXCLUDE constraint support (#3116)
Browse files Browse the repository at this point in the history
* feat: Adding EXCLUDE constraint support

* Fixing make style, differentiating identity tests

* Removing unnecessary f-string

* Undo removed line, reorder unnamed constraint dict alphabetically

* Refactoring code

* Add missing keyword argument in self.expressions

* Get rid of redundant arg_types defn

---------

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
  • Loading branch information
VaggelisD and georgesittas committed Mar 11, 2024
1 parent b1c8cac commit 09708f5
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 29 deletions.
25 changes: 21 additions & 4 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,15 @@ class EncodeColumnConstraint(ColumnConstraintKind):
pass


# https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
class ExcludeColumnConstraint(ColumnConstraintKind):
pass


class WithOperator(Expression):
arg_types = {"this": True, "op": True}


class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {
Expand Down Expand Up @@ -1854,14 +1863,22 @@ class Index(Expression):
arg_types = {
"this": False,
"table": False,
"using": False,
"where": False,
"columns": False,
"unique": False,
"primary": False,
"amp": False, # teradata
"params": False,
}


class IndexParameters(Expression):
arg_types = {
"using": False,
"include": False,
"partition_by": False, # teradata
"columns": False,
"with_storage": False,
"partition_by": False,
"tablespace": False,
"where": False,
}


Expand Down
33 changes: 23 additions & 10 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class Generator(metaclass=_Generator):
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}",
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda *_: "EXTERNAL",
exp.GlobalProperty: lambda *_: "GLOBAL",
Expand Down Expand Up @@ -140,6 +141,7 @@ class Generator(metaclass=_Generator):
exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]),
exp.VolatileProperty: lambda *_: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}",
}

# Whether null ordering is supported in order by
Expand Down Expand Up @@ -1210,17 +1212,9 @@ def hint_sql(self, expression: exp.Hint) -> str:

return f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */"

def index_sql(self, expression: exp.Index) -> str:
unique = "UNIQUE " if expression.args.get("unique") else ""
primary = "PRIMARY " if expression.args.get("primary") else ""
amp = "AMP " if expression.args.get("amp") else ""
name = self.sql(expression, "this")
name = f"{name} " if name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table}" if table else ""
def indexparameters_sql(self, expression: exp.IndexParameters) -> str:
using = self.sql(expression, "using")
using = f" USING {using}" if using else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
partition_by = self.expressions(expression, key="partition_by", flat=True)
Expand All @@ -1229,7 +1223,26 @@ def index_sql(self, expression: exp.Index) -> str:
include = self.expressions(expression, key="include", flat=True)
if include:
include = f" INCLUDE ({include})"
return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{include}{partition_by}{where}"
with_storage = self.expressions(expression, key="with_storage", flat=True)
with_storage = f" WITH {with_storage}" if with_storage else ""
tablespace = self.sql(expression, "tablespace")
tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else ""

return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}"

def index_sql(self, expression: exp.Index) -> str:
unique = "UNIQUE " if expression.args.get("unique") else ""
primary = "PRIMARY " if expression.args.get("primary") else ""
amp = "AMP " if expression.args.get("amp") else ""
name = self.sql(expression, "this")
name = f"{name} " if name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table}" if table else ""

index = "INDEX " if not table else ""

params = self.sql(expression, "params")
return f"{unique}{primary}{amp}{index}{name}{table}{params}"

def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
Expand Down
73 changes: 58 additions & 15 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,9 @@ class Parser(metaclass=_Parser):
exp.DefaultColumnConstraint, this=self._parse_bitwise()
),
"ENCODE": lambda self: self.expression(exp.EncodeColumnConstraint, this=self._parse_var()),
"EXCLUDE": lambda self: self.expression(
exp.ExcludeColumnConstraint, this=self._parse_index_params()
),
"FOREIGN KEY": lambda self: self._parse_foreign_key(),
"FORMAT": lambda self: self.expression(
exp.DateFormatColumnConstraint, this=self._parse_var_or_string()
Expand Down Expand Up @@ -877,7 +880,15 @@ class Parser(metaclass=_Parser):
"RENAME": lambda self: self._parse_alter_table_rename(),
}

SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE", "PERIOD"}
SCHEMA_UNNAMED_CONSTRAINTS = {
"CHECK",
"EXCLUDE",
"FOREIGN KEY",
"LIKE",
"PERIOD",
"PRIMARY KEY",
"UNIQUE",
}

NO_PAREN_FUNCTION_PARSERS = {
"ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()),
Expand Down Expand Up @@ -1008,7 +1019,8 @@ class Parser(metaclass=_Parser):
CLONE_KEYWORDS = {"CLONE", "COPY"}
HISTORICAL_DATA_KIND = {"TIMESTAMP", "OFFSET", "STATEMENT", "STREAM"}

OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"}
OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS", "WITH"}

OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN}

TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE}
Expand Down Expand Up @@ -2841,6 +2853,7 @@ def _parse_join(

def _parse_opclass(self) -> t.Optional[exp.Expression]:
this = self._parse_conjunction()

if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False):
return this

Expand All @@ -2849,6 +2862,37 @@ def _parse_opclass(self) -> t.Optional[exp.Expression]:

return this

def _parse_index_params(self) -> exp.IndexParameters:
using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None

if self._match(TokenType.L_PAREN, advance=False):
columns = self._parse_wrapped_csv(self._parse_with_operator)
else:
columns = None

include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None
partition_by = self._parse_partition_by()
with_storage = (
self._parse_csv(self._parse_conjunction) if self._match(TokenType.WITH) else None
)
tablespace = (
self._parse_var(any_token=True)
if self._match_text_seq("USING", "INDEX", "TABLESPACE")
else None
)
where = self._parse_where()

return self.expression(
exp.IndexParameters,
using=using,
columns=columns,
include=include,
partition_by=partition_by,
where=where,
with_storage=with_storage,
tablespace=tablespace,
)

def _parse_index(
self,
index: t.Optional[exp.Expression] = None,
Expand All @@ -2872,27 +2916,16 @@ def _parse_index(
index = self._parse_id_var()
table = None

using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None

if self._match(TokenType.L_PAREN, advance=False):
columns = self._parse_wrapped_csv(lambda: self._parse_ordered(self._parse_opclass))
else:
columns = None

include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None
params = self._parse_index_params()

return self.expression(
exp.Index,
this=index,
table=table,
using=using,
columns=columns,
unique=unique,
primary=primary,
amp=amp,
include=include,
partition_by=self._parse_partition_by(),
where=self._parse_where(),
params=params,
)

def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]:
Expand Down Expand Up @@ -6094,3 +6127,13 @@ def _parse_truncate_table(self) -> t.Optional[exp.TruncateTable] | exp.Expressio
option=option,
partition=partition,
)

def _parse_with_operator(self) -> t.Optional[exp.Expression]:
this = self._parse_ordered(self._parse_opclass)

if not self._match(TokenType.WITH):
return this

op = self._parse_var(any_token=True)

return self.expression(exp.WithOperator, this=this, op=op)
12 changes: 12 additions & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ def test_postgres(self):
self.validate_identity(
"SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"
)
self.validate_identity(
"CREATE TABLE t (vid INT NOT NULL, CONSTRAINT ht_vid_nid_fid_idx EXCLUDE (INT4RANGE(vid, nid) WITH &&, INT4RANGE(fid, fid, '[]') WITH &&))"
)
self.validate_identity(
"CREATE TABLE t (i INT, PRIMARY KEY (i), EXCLUDE USING gist(col varchar_pattern_ops DESC NULLS LAST WITH &&) WITH (sp1 = 1, sp2 = 2))"
)
self.validate_identity(
"CREATE TABLE t (i INT, EXCLUDE USING btree(INT4RANGE(vid, nid, '[]') ASC NULLS FIRST WITH &&) INCLUDE (col1, col2))"
)
self.validate_identity(
"CREATE TABLE t (i INT, EXCLUDE USING gin(col1 WITH &&, col2 WITH ||) USING INDEX TABLESPACE tablespace WHERE (id > 5))"
)
self.validate_identity(
"SELECT c.oid, n.nspname, c.relname "
"FROM pg_catalog.pg_class AS c "
Expand Down

0 comments on commit 09708f5

Please sign in to comment.