Skip to content

Commit 2937953

Browse files
committed
[IR] introduce slice support (#2302)
1 parent b617ad5 commit 2937953

File tree

3 files changed

+36
-25
lines changed

3 files changed

+36
-25
lines changed

onnxscript/ir/_core.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2282,7 +2282,12 @@ def doc_string(self, value: str | None) -> None:
22822282
def opset_imports(self) -> dict[str, int]:
22832283
return self._opset_imports
22842284

2285-
def __getitem__(self, index: int) -> Node:
2285+
@typing.overload
2286+
def __getitem__(self, index: int) -> Node: ...
2287+
@typing.overload
2288+
def __getitem__(self, index: slice) -> tuple[Node, ...]: ...
2289+
2290+
def __getitem__(self, index):
22862291
return self._nodes[index]
22872292

22882293
def __len__(self) -> int:
@@ -2316,7 +2321,7 @@ def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
23162321
def node(self, index_or_name: int | str, /) -> Node:
23172322
"""Get a node by index or name.
23182323
2319-
This is an O(n) operation. Getting nodes on the ends of the graph (0 or -1) is O(1).
2324+
This is an O(n) operation.
23202325
23212326
.. note::
23222327
If you need repeated random access, consider turning it into a list with ``list(graph)`` .
@@ -2711,8 +2716,13 @@ def __init__(
27112716
self._metadata: _metadata.MetadataStore | None = None
27122717
self._metadata_props: dict[str, str] | None = metadata_props
27132718
self._nodes: tuple[Node, ...] = tuple(nodes)
2719+
2720+
@typing.overload
2721+
def __getitem__(self, index: int) -> Node: ...
2722+
@typing.overload
2723+
def __getitem__(self, index: slice) -> tuple[Node, ...]: ...
27142724

2715-
def __getitem__(self, index: int) -> Node:
2725+
def __getitem__(self, index):
27162726
return self._nodes[index]
27172727

27182728
def __len__(self) -> int:
@@ -2961,7 +2971,12 @@ def outputs(self) -> MutableSequence[Value]:
29612971
def attributes(self) -> OrderedDict[str, Attr]:
29622972
return self._attributes
29632973

2964-
def __getitem__(self, index: int) -> Node:
2974+
@typing.overload
2975+
def __getitem__(self, index: int) -> Node: ...
2976+
@typing.overload
2977+
def __getitem__(self, index: slice) -> tuple[Node, ...]: ...
2978+
2979+
def __getitem__(self, index):
29652980
return self._graph.__getitem__(index)
29662981

29672982
def __len__(self) -> int:

onnxscript/ir/_linked_list.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from __future__ import annotations
66

7-
from typing import Generic, Iterable, Iterator, Sequence, TypeVar
7+
from typing import Generic, Iterable, Iterator, Sequence, TypeVar, overload
88

99
T = TypeVar("T")
1010

@@ -83,9 +83,7 @@ class DoublyLinkedSet(Sequence[T], Generic[T]):
8383
iteration will start from the "next" node at the _original_ location.
8484
8585
Time complexity:
86-
Inserting and removing nodes from the set is O(1). Accessing nodes by index is O(n),
87-
although accessing nodes at either end of the set is O(1). I.e.
88-
``linked_set[0]`` and ``linked_set[-1]`` are O(1).
86+
Inserting and removing nodes from the set is O(1). Accessing nodes by index is O(n).
8987
9088
Values need to be hashable. ``None`` is not a valid value in the set.
9189
"""
@@ -136,27 +134,17 @@ def __len__(self) -> int:
136134
)
137135
return self._length
138136

139-
def __getitem__(self, index: int) -> T:
137+
@overload
138+
def __getitem__(self, index: int) -> T: ...
139+
@overload
140+
def __getitem__(self, index: slice) -> tuple[T, ...]: ...
141+
142+
def __getitem__(self, index):
140143
"""Get the node at the given index.
141144
142145
Complexity is O(n).
143146
"""
144-
if index >= self._length or index < -self._length:
145-
raise IndexError(
146-
f"Index out of range: {index} not in range [-{self._length}, {self._length})"
147-
)
148-
if index < 0:
149-
# Look up from the end of the list
150-
iterator = reversed(self)
151-
item = next(iterator)
152-
for _ in range(-index - 1):
153-
item = next(iterator)
154-
else:
155-
iterator = iter(self) # type: ignore[assignment]
156-
item = next(iterator)
157-
for _ in range(index):
158-
item = next(iterator)
159-
return item
147+
return tuple(self)[index]
160148

161149
def _insert_one_after(
162150
self,

onnxscript/ir/_linked_list_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,14 @@ def test_insert_after_supports_taking_elements_from_another_doubly_linked_list(
373373
self.assertEqual(len(other_linked_list), 1)
374374
self.assertEqual([elem.value for elem in other_linked_list], [42])
375375

376+
def test_get_item_slice(self):
377+
elems = [_TestElement(i) for i in range(3)]
378+
linked_list = _linked_list.DoublyLinkedSet(elems)
379+
self.assertEqual(len(linked_list), 3)
380+
self.assertEqual(list(linked_list[1:2]), elems[1:2])
381+
self.assertEqual(list(linked_list[:2]), elems[:2])
382+
self.assertEqual(list(linked_list[-2:]), elems[-2:])
383+
376384

377385
if __name__ == "__main__":
378386
unittest.main()

0 commit comments

Comments
 (0)