Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix!: allow transformations that mutate the visited node's parent #3182

Merged
merged 1 commit into from Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
71 changes: 35 additions & 36 deletions sqlglot/expressions.py
Expand Up @@ -320,19 +320,39 @@ def append(self, arg_key: str, value: t.Any) -> None:
value.index = len(values)
values.append(value)

def set(self, arg_key: str, value: t.Any) -> None:
def set(self, arg_key: str, value: t.Any, index: t.Optional[int] = None) -> None:
"""
Sets arg_key to value.

Args:
arg_key: name of the expression arg.
value: value to set the arg to.
"""
if value is None:
index: if the arg is a list, this specifies what position to add the value in it.
"""
if index is not None:
expressions = self.args.get(arg_key) or []

if seq_get(expressions, index) is None:
return
if value is None:
expressions.pop(index)
for v in expressions[index:]:
v.index = v.index - 1
return

if isinstance(value, list):
expressions.pop(index)
expressions[index:index] = value
else:
expressions[index] = value

value = expressions
elif value is None:
self.args.pop(arg_key, None)
else:
self.args[arg_key] = value
self._set_parent(arg_key, value)
return

self.args[arg_key] = value
self._set_parent(arg_key, value, index)

def _set_parent(self, arg_key: str, value: t.Any, index: t.Optional[int] = None) -> None:
if hasattr(value, "parent"):
Expand Down Expand Up @@ -579,13 +599,13 @@ def transform(self, fun: t.Callable, *args: t.Any, copy: bool = True, **kwargs)
new_node = None

for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node):
parent, arg_key, index = node.parent, node.arg_key, node.index
new_node = fun(node, *args, **kwargs)

if root:
if new_node is not node:
node.replace(new_node)
else:
if not root:
root = new_node
elif new_node is not node:
parent.set(arg_key, new_node, index)

assert root
return root.assert_is(Expression)
Expand Down Expand Up @@ -617,37 +637,16 @@ def replace(self, expression):
"""
parent = self.parent

if not parent:
if not parent or parent is expression:
return expression

key = self.arg_key
value = parent.args.get(key)
exp_is_list = type(expression) is list

if type(value) is list:
index = self.index

if exp_is_list:
value.pop(index)
value[index:index] = expression
parent._set_parent(key, value)
else:
if expression is None:
value.pop(index)

for v in value[index:]:
v.index = v.index - 1
else:
value[index] = expression
parent._set_parent(key, expression, index=index)
elif value is not None:
if expression is None:
parent.args.pop(key)
else:
if exp_is_list and value.parent:
value.parent.replace(expression)
else:
parent.set(key, expression)
if type(expression) is list and isinstance(value, Expression) and value.parent:
value.parent.replace(expression)
else:
parent.set(key, expression, self.index)

if expression is not self:
self.parent = None
Expand Down
12 changes: 12 additions & 0 deletions tests/test_expressions.py
Expand Up @@ -501,6 +501,18 @@ def fun(node):

self.assertEqual(expression.transform(fun).sql(), "FUN(a)")

def test_transform_with_parent_mutation(self):
expression = parse_one("SELECT COUNT(1) FROM table")

def fun(node):
if str(node) == "COUNT(1)":
# node gets silently mutated here - its parent points to the filter node
return exp.Filter(this=node, expression=exp.Where(this=exp.true()))
return node

transformed = expression.transform(fun)
self.assertEqual(transformed.sql(), "SELECT COUNT(1) FILTER(WHERE TRUE) FROM table")

def test_transform_multiple_children(self):
expression = parse_one("SELECT * FROM x")

Expand Down