-
Notifications
You must be signed in to change notification settings - Fork 72
[IR] Implement append/pop for node input/outputs #2289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -1518,6 +1518,35 @@ | |||||
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead." | ||||||
) | ||||||
|
||||||
def append_input(self, /, value: Value | None) -> None: | ||||||
"""Append an input to the node. | ||||||
|
||||||
Args: | ||||||
value: The input value to append. | ||||||
|
||||||
Raises: | ||||||
ValueError: If the input value has a producer set already. | ||||||
""" | ||||||
self._inputs = (*self._inputs, value) | ||||||
if value is not None: | ||||||
value._add_usage(self, len(self._inputs) - 1) # pylint: disable=protected-access | ||||||
|
||||||
def pop_input(self) -> Value | None: | ||||||
"""Remove a trailing input from the node. | ||||||
|
||||||
Args: | ||||||
value: The input value to remove. | ||||||
|
||||||
Raises: | ||||||
ValueError: If the input value is used by other nodes. | ||||||
Comment on lines
+1540
to
+1541
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
if not self._inputs: | ||||||
raise ValueError("No inputs to pop.") | ||||||
value = self._inputs[-1] | ||||||
if value is not None: | ||||||
value._remove_usage(self, len(self._inputs) - 1) | ||||||
Check warningCode scanning / lintrunner PYLINT/W0212 Warning
Access to a protected member _remove_usage of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access |
||||||
self._inputs = self._inputs[:-1] | ||||||
|
||||||
def predecessors(self) -> Sequence[Node]: | ||||||
"""Return the predecessor nodes of the node, deduplicated, in a deterministic order.""" | ||||||
# Use the ordered nature of a dictionary to deduplicate the nodes | ||||||
|
@@ -1588,6 +1617,7 @@ | |||||
raise ValueError("The node to append to does not belong to any graph.") | ||||||
self._graph.insert_after(self, nodes) | ||||||
|
||||||
|
||||||
@property | ||||||
def outputs(self) -> Sequence[Value]: | ||||||
return self._outputs | ||||||
|
@@ -1596,6 +1626,47 @@ | |||||
def outputs(self, _: Sequence[Value]) -> None: | ||||||
raise AttributeError("outputs is immutable. Please create a new node instead.") | ||||||
|
||||||
def append_output(self, /, value: Value) -> None: | ||||||
"""Append an output to the node. | ||||||
|
||||||
This is used to add an output to a node that has already been created. | ||||||
|
||||||
Args: | ||||||
value: The output value to append. | ||||||
|
||||||
Raises: | ||||||
ValueError: If the output value has a producer set already. | ||||||
""" | ||||||
if value.producer() is not None and value.producer() is not self: | ||||||
raise ValueError( | ||||||
f"Output value cannot have a producer when used for appending an output. " | ||||||
f"Output: {value}" | ||||||
) | ||||||
self._outputs = (*self._outputs, value) | ||||||
value._producer = self | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. update index Check warningCode scanning / lintrunner PYLINT/W0212 Warning
Access to a protected member _producer of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access |
||||||
|
||||||
def pop_output(self) -> Value: | ||||||
"""Remove a trailing output from the node. | ||||||
|
||||||
Args: | ||||||
value: The output value to remove. | ||||||
|
||||||
Raises: | ||||||
ValueError: If the output value is used by other nodes. | ||||||
""" | ||||||
if not self._outputs: | ||||||
raise ValueError("No outputs to pop.") | ||||||
value = self._outputs[-1] | ||||||
if value.uses(): | ||||||
raise ValueError( | ||||||
"Cannot pop an output that is used by other nodes. " | ||||||
"Remove the usages first by replacing user inputs with None. " | ||||||
f"Output: {value}, uses: {value.uses()}" | ||||||
) | ||||||
self._outputs = self._outputs[:-1] | ||||||
value._producer = None | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. update index Check warningCode scanning / lintrunner PYLINT/W0212 Warning
Access to a protected member _producer of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access |
||||||
return value | ||||||
|
||||||
@property | ||||||
def attributes(self) -> OrderedDict[str, Attr | RefAttr]: | ||||||
return self._attributes | ||||||
|
Check failure
Code scanning / lintrunner
MYPY/return Error