Skip to content

Commit

Permalink
fix(optimizer): Remove XOR from connector simplifications (#3380)
Browse files Browse the repository at this point in the history
* fix(optimizer): Remove XOR from connector simplifications

* PR Feedback, further simplify A XOR A -> False

* PR Feedback 2

* PR Feedback 3
  • Loading branch information
VaggelisD committed Apr 30, 2024
1 parent 3e8de71 commit b2a7e55
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 8 deletions.
31 changes: 27 additions & 4 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6519,7 +6519,7 @@ def condition(

def and_(
*expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts
) -> Condition:
) -> And:
"""
Combine multiple conditions with an AND logical operator.
Expand All @@ -6537,12 +6537,12 @@ def and_(
Returns:
And: the new condition
"""
return t.cast(Condition, _combine(expressions, And, dialect, copy=copy, **opts))
return t.cast(And, _combine(expressions, And, dialect, copy=copy, **opts))


def or_(
*expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts
) -> Condition:
) -> Or:
"""
Combine multiple conditions with an OR logical operator.
Expand All @@ -6560,7 +6560,30 @@ def or_(
Returns:
Or: the new condition
"""
return t.cast(Condition, _combine(expressions, Or, dialect, copy=copy, **opts))
return t.cast(Or, _combine(expressions, Or, dialect, copy=copy, **opts))


def xor(
    *expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts
) -> Xor:
    """
    Join several conditions into a single expression using the XOR logical operator.

    Example:
        >>> xor("x=1", xor("y=1", "z=1")).sql()
        'x = 1 XOR (y = 1 XOR z = 1)'

    Args:
        *expressions: the SQL code strings to parse.
            If an Expression instance is passed, this is used as-is.
        dialect: the dialect used to parse the input expression.
        copy: whether to copy `expressions` (only applies to Expressions).
        **opts: other options to use to parse the input expressions.

    Returns:
        Xor: the new condition
    """
    combined = _combine(expressions, Xor, dialect, copy=copy, **opts)
    # t.cast has no runtime effect; it only narrows the static type to Xor.
    return t.cast(Xor, combined)


def not_(expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts) -> Not:
Expand Down
17 changes: 13 additions & 4 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def flatten(expression):
def simplify_connectors(expression, root=True):
def _simplify_connectors(expression, left, right):
if left == right:
if isinstance(expression, exp.Xor):
return exp.false()
return left
if isinstance(expression, exp.And):
if is_false(left) or is_false(right):
Expand Down Expand Up @@ -365,10 +367,17 @@ def uniq_sort(expression, root=True):
C AND A AND B AND B -> A AND B AND C
"""
if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())

if isinstance(expression, exp.Xor):
result_func = exp.xor
# Do not deduplicate XOR as A XOR A != A if A == True
deduped = None
arr = tuple((gen(e), e) for e in flattened)
else:
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
deduped = {gen(e): e for e in flattened}
arr = tuple(deduped.items())

# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
Expand All @@ -378,7 +387,7 @@ def uniq_sort(expression, root=True):
break
else:
# we didn't have to sort but maybe we need to dedup
if len(deduped) < len(flattened):
if deduped and len(deduped) < len(flattened):
expression = result_func(*deduped.values(), copy=False)

return expression
Expand Down
11 changes: 11 additions & 0 deletions tests/fixtures/optimizer/simplify.sql
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ a AND b;
(x is not null) != (y is null);
(NOT x IS NULL) <> (y IS NULL);

# dialect: mysql
A XOR A;
FALSE;

--------------------------------------
-- Absorption
--------------------------------------
Expand Down Expand Up @@ -232,6 +236,13 @@ x - 1;
A AND D AND B AND E AND F AND G AND E AND A;
A AND B AND D AND E AND F AND G;

A OR D OR B OR E OR F OR G OR E OR A;
A OR B OR D OR E OR F OR G;

# dialect: mysql
A XOR D XOR B XOR E XOR F XOR G XOR C;
A XOR B XOR C XOR D XOR E XOR F XOR G;

A AND NOT B AND C AND B;
FALSE;

Expand Down

0 comments on commit b2a7e55

Please sign in to comment.