Skip to content

Commit

Permalink
Fix: Ignore Identifier nodes in the diffing algorithm (#3065)
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Mar 1, 2024
1 parent e2becea commit c8a753b
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 20 deletions.
41 changes: 30 additions & 11 deletions sqlglot/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,18 @@ def compute_node_mappings(
return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy)


LEAF_EXPRESSION_TYPES = (
# The expression types for which Update edits are allowed.
UPDATABLE_EXPRESSION_TYPES = (
exp.Boolean,
exp.DataType,
exp.Identifier,
exp.Literal,
exp.Table,
exp.Column,
exp.Lambda,
)

# Node types the diffing algorithm skips entirely: instances are left out of the
# source/target node indexes, are not treated as children when collecting leaves,
# and are dropped from the expression-only argument lists.
IGNORED_LEAF_EXPRESSION_TYPES = (exp.Identifier,)


class ChangeDistiller:
"""
Expand All @@ -152,8 +157,16 @@ def diff(

self._source = source
self._target = target
self._source_index = {id(n): n for n, *_ in self._source.bfs()}
self._target_index = {id(n): n for n, *_ in self._target.bfs()}
self._source_index = {
id(n): n
for n, *_ in self._source.bfs()
if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
}
self._target_index = {
id(n): n
for n, *_ in self._target.bfs()
if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES)
}
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
Expand All @@ -170,7 +183,10 @@ def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.Lis
for kept_source_node_id, kept_target_node_id in matching_set:
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]
if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
if (
not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES)
or source_node == target_node
):
edit_script.extend(
self._generate_move_edits(source_node, target_node, matching_set)
)
Expand Down Expand Up @@ -307,17 +323,16 @@ def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False

for _, node in expression.iter_expressions():
has_child_exprs = True
yield from _get_leaves(node)
if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES):
has_child_exprs = True
yield from _get_leaves(node)

if not has_child_exprs:
yield expression


def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
if type(source) is type(target) and (
not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent)
):
if type(source) is type(target):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")

Expand All @@ -343,7 +358,11 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
if expression:
for a in expression.args.values():
args.extend(ensure_list(a))
return [a for a in args if isinstance(a, exp.Expression)]
return [
a
for a in args
if isinstance(a, exp.Expression) and not isinstance(a, IGNORED_LEAF_EXPRESSION_TYPES)
]


def _lcs(
Expand Down
43 changes: 34 additions & 9 deletions tests/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlglot import exp, parse_one
from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff
from sqlglot.expressions import Join, to_identifier
from sqlglot.expressions import Join, to_table


class TestDiff(unittest.TestCase):
Expand All @@ -18,15 +18,13 @@ def test_simple(self):
self._validate_delta_only(
diff(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")),
[
Remove(to_identifier("b", quoted=False)), # the Identifier node
Remove(parse_one("b")), # the Column node
],
)

self._validate_delta_only(
diff(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")),
[
Insert(to_identifier("c", quoted=False)), # the Identifier node
Insert(parse_one("c")), # the Column node
],
)
Expand All @@ -38,9 +36,39 @@ def test_simple(self):
),
[
Update(
to_identifier("table_one", quoted=False),
to_identifier("table_two", quoted=False),
), # the Identifier node
to_table("table_one", quoted=False),
to_table("table_two", quoted=False),
), # the Table node
],
)

def test_lambda(self):
    # Only the lambda passed to x(...) differs between the two queries.
    before = parse_one("SELECT a, b, c, x(a -> a)")
    after = parse_one("SELECT a, b, c, x(b -> b)")

    # The expected edit is a single Update whose operands are Lambda nodes
    # built directly from identifiers.
    old_lambda = exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")])
    new_lambda = exp.Lambda(this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")])
    expected_delta = [Update(old_lambda, new_lambda)]

    self._validate_delta_only(diff(before, after), expected_delta)

def test_udf(self):
    # Case 1: the quoted UDF name changes; the expected delta inserts the new
    # call and removes the old one.
    renamed_source = parse_one('SELECT a, b, "my.udf1"()')
    renamed_target = parse_one('SELECT a, b, "my.udf2"()')
    self._validate_delta_only(
        diff(renamed_source, renamed_target),
        [
            Insert(parse_one('"my.udf2"()')),
            Remove(parse_one('"my.udf1"()')),
        ],
    )

    # Case 2: same UDF name, last argument changes from z to w; the expected
    # delta touches only the differing Column argument.
    args_source = parse_one('SELECT a, b, "my.udf"(x, y, z)')
    args_target = parse_one('SELECT a, b, "my.udf"(x, y, w)')
    self._validate_delta_only(
        diff(args_source, args_target),
        [
            Insert(exp.column("w")),
            Remove(exp.column("z")),
        ],
    )

Expand Down Expand Up @@ -95,7 +123,6 @@ def test_cte(self):
diff(parse_one(expr_src), parse_one(expr_tgt)),
[
Remove(parse_one("LOWER(c) AS c")), # the Alias node
Remove(to_identifier("c", quoted=False)), # the Identifier node
Remove(parse_one("LOWER(c)")), # the Lower node
Remove(parse_one("'filter'")), # the Literal node
Insert(parse_one("'different_filter'")), # the Literal node
Expand Down Expand Up @@ -162,9 +189,7 @@ def test_identifier(self):
self._validate_delta_only(
diff(expr_src, expr_tgt),
[
Insert(expression=exp.to_identifier("b")),
Insert(expression=exp.to_column("tbl.b")),
Insert(expression=exp.to_identifier("tbl")),
],
)

Expand Down

0 comments on commit c8a753b

Please sign in to comment.