Skip to content

Commit bb10b1c

Browse files
committed
[IR] introduce slice support (#2302)
1 parent 2ae13be commit bb10b1c

File tree

3 files changed

+61
-10
lines changed

3 files changed

+61
-10
lines changed

onnxscript/ir/_core.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2282,7 +2282,12 @@ def doc_string(self, value: str | None) -> None:
22822282
def opset_imports(self) -> dict[str, int]:
22832283
return self._opset_imports
22842284

2285-
def __getitem__(self, index: int) -> Node:
2285+
@typing.overload
2286+
def __getitem__(self, index: int) -> Node: ...
2287+
@typing.overload
2288+
def __getitem__(self, index: slice) -> tuple[Node, ...]: ...
2289+
2290+
def __getitem__(self, index):
22862291
return self._nodes[index]
22872292

22882293
def __len__(self) -> int:
@@ -2712,7 +2717,12 @@ def __init__(
27122717
self._metadata_props: dict[str, str] | None = metadata_props
27132718
self._nodes: tuple[Node, ...] = tuple(nodes)
27142719

2715-
def __getitem__(self, index: int) -> Node:
2720+
@typing.overload
2721+
def __getitem__(self, index: int) -> Node: ...
2722+
@typing.overload
2723+
def __getitem__(self, index: slice) -> tuple[Node, ...]: ...
2724+
2725+
def __getitem__(self, index):
27162726
return self._nodes[index]
27172727

27182728
def __len__(self) -> int:
@@ -2961,7 +2971,12 @@ def outputs(self) -> MutableSequence[Value]:
29612971
def attributes(self) -> OrderedDict[str, Attr]:
29622972
return self._attributes
29632973

2964-
def __getitem__(self, index: int) -> Node:
2974+
@typing.overload
2975+
def __getitem__(self, index: int) -> Node: ...
2976+
@typing.overload
2977+
def __getitem__(self, index: slice) -> tuple[Node, ...]: ...
2978+
2979+
def __getitem__(self, index):
29652980
return self._graph.__getitem__(index)
29662981

29672982
def __len__(self) -> int:

onnxscript/ir/_linked_list.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from __future__ import annotations
66

7-
from typing import Generic, Iterable, Iterator, Sequence, TypeVar
7+
from typing import Generic, Iterable, Iterator, Sequence, TypeVar, overload
88

99
T = TypeVar("T")
1010

@@ -136,27 +136,54 @@ def __len__(self) -> int:
136136
)
137137
return self._length
138138

139-
def __getitem__(self, index: int) -> T:
139+
@overload
140+
def __getitem__(self, index: int) -> T: ...
141+
@overload
142+
def __getitem__(self, index: slice) -> list[T]: ...
143+
144+
def __getitem__(self, index):
140145
"""Get the node at the given index.
141146
142147
Complexity is O(n).
143148
"""
144-
if index >= self._length or index < -self._length:
149+
if isinstance(index, slice):
150+
start, stop, step = index.indices(self._length)
151+
elif index < 0:
152+
start, stop, step = index, -self._length - 1, -1
153+
else:
154+
start, stop, step = index, self._length, 1
155+
156+
if start >= self._length or start < -self._length:
145157
raise IndexError(
146158
f"Index out of range: {index} not in range [-{self._length}, {self._length})"
147159
)
148-
if index < 0:
160+
if (step < 0 and stop >= start) or (step > 0 and stop <= start):
161+
return []
162+
if step < 0:
149163
# Look up from the end of the list
150164
iterator = reversed(self)
151165
item = next(iterator)
152-
for _ in range(-index - 1):
166+
# Skip index to match with start
167+
for _ in range(-start - 1 if start < 0 else self._length - start - 1):
153168
item = next(iterator)
154169
else:
155170
iterator = iter(self) # type: ignore[assignment]
156171
item = next(iterator)
157-
for _ in range(index):
172+
# Skip index to match with start
173+
for _ in range(start):
158174
item = next(iterator)
159-
return item
175+
176+
# Return a single element if index is an integer
177+
if isinstance(index, int):
178+
return item
179+
180+
# Return a list if index is a slice
181+
items = [item]
182+
for i in range(1, abs(stop - start)):
183+
item = next(iterator)
184+
if i % step == 0:
185+
items.append(item)
186+
return items
160187

161188
def _insert_one_after(
162189
self,

onnxscript/ir/_linked_list_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,15 @@ def test_insert_after_supports_taking_elements_from_another_doubly_linked_list(
373373
self.assertEqual(len(other_linked_list), 1)
374374
self.assertEqual([elem.value for elem in other_linked_list], [42])
375375

376+
@parameterized.parameterized.expand(
377+
[(s, t, p) for s in [-2, 0, 2, 3] for t in [2, -1, -2] for p in [-3, -2, -1, 1, 2]]
378+
)
379+
def test_get_item_slice(self, start, stop, step):
380+
elems = [_TestElement(i) for i in range(5)]
381+
linked_list = _linked_list.DoublyLinkedSet(elems)
382+
self.assertEqual(len(linked_list), 5)
383+
self.assertEqual(linked_list[start:stop:step], elems[start:stop:step])
384+
376385

377386
if __name__ == "__main__":
378387
unittest.main()

0 commit comments

Comments
 (0)