Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
normalize_gm,
)
from torch._dynamo.utils import ifdynstaticdefault, same
from torch._dynamo.variables import ConstantVariable
from torch._dynamo.variables.lists import RangeVariable

from torch.nn import functional as F
from torch.testing._internal.common_utils import (
Expand Down Expand Up @@ -2369,6 +2371,157 @@ def fn():
opt_fn = torch._dynamo.optimize(nopython=True)(fn)
self.assertEqual(opt_fn(), fn())

def gen_random_range_args(self):
args_count = random.randint(1, 3)
args = [random.randint(-10, 10) for _ in range(args_count)]
if args_count == 3 and args[2] == 0:
args[2] = 1
return args

def test_range_length(self):
def test(*args, expected=None):
r = range(*args)
range_variable = RangeVariable([ConstantVariable.create(v) for v in args])

self.assertEqual(len(r), range_variable.range_length())

if expected is not None:
self.assertEqual(len(r), expected)

test(1, 1, 1, expected=0)
test(1, 0, expected=0)
test(-10, expected=0)

test(4, expected=4)
test(10, expected=10)

# step >1
test(1, 10, 2, expected=5)

# negative step
test(10, 1, -1, expected=9)
test(10, 1, -3)

# Fuzz testing
for i in range(100):
args = self.gen_random_range_args()
print("testing :", args)
test(*args)

def test_indexed_range(self):
def test(range, index, expected=None):
range_variable = RangeVariable(
[
ConstantVariable.create(v)
for v in [range.start, range.stop, range.step]
]
)

self.assertEqual(
range[index],
range_variable.apply_index(index).as_python_constant(),
)

if expected is not None:
self.assertEqual(range[index], expected)

test(range(10), 1, expected=1)
test(range(10, 20, 2), 1, expected=12)

# Fuzz testing
for i in range(100):
range_args = self.gen_random_range_args()
r = range(*range_args)

if len(r) == 0:
continue

index = random.randint(0, len(r) - 1)

print("testing:", r, index)
test(r, index)

def test_sliced_range(self):
def test(range, slice, expected=None):
range_variable = RangeVariable(
[
ConstantVariable.create(v)
for v in [range.start, range.stop, range.step]
]
)

self.assertEqual(
range[slice],
range_variable.apply_slice(slice).as_python_constant(),
)

if expected is not None:
self.assertEqual(
range[slice],
expected,
)

test(range(10), slice(1, 10, 2), expected=range(1, 10, 2))
test(range(10), slice(None, 10, None), expected=range(0, 10))
test(range(10), slice(-1, 7, None), expected=range(9, 7))
test(range(10), slice(-1, 7, 2), expected=range(9, 7, 2))
test(range(1, 10, 2), slice(3, 7, 2), expected=range(7, 11, 4))
test(range(1, 10, 2), slice(-3, 7, 2), expected=range(5, 11, 4))
test(range(-1, -5, -3), slice(5, None, -3), expected=range(-4, 2, 9))

def rand_slice():
def flip_coin():
# 1 out of 10
return random.randint(1, 10) == 5

def r_item(allow_zero=True):
i = random.randint(-10, 10)
if not allow_zero and i == 0:
i = 1
if flip_coin():
i = None
return i

arg_count = random.randint(1, 3)

if arg_count == 1:
return slice(r_item())
elif arg_count == 2:
return slice(r_item(), r_item())
else:
return slice(r_item(), r_item(), r_item(False))

# Fuzz testing
for i in range(100):
range_args = self.gen_random_range_args()
r = range(*range_args)
# generate random slice
s = rand_slice()

print("testing:", r, s)
test(r, s)

def test_range_with_slice_index(self):
def fn(x):
acc = 1
for k in range(2)[1::2]:
acc *= acc * k
return x * acc

opt_fn = torch.compile(fullgraph=True)(fn)
x = torch.ones(1)
self.assertEqual(opt_fn(x), fn(x))

def test_range_with_index(self):
def fn(x):
acc = 1
acc *= acc * range(10, 20, 2)[2]
return x * acc

opt_fn = torch.compile(fullgraph=True)(fn)
x = torch.ones(1)
self.assertEqual(opt_fn(x), fn(x))

def test_rand_inlined(self):
@torch.compile(backend="eager", dynamic=True)
def fn():
Expand Down
114 changes: 113 additions & 1 deletion torch/_dynamo/variables/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
import torch.fx

from ..._guards import Source

from .. import polyfill, variables
Expand Down Expand Up @@ -176,13 +177,124 @@ def debug_repr(self):
def python_type(self):
return range

def start(self):
return self.items[0].as_python_constant()

def stop(self):
return self.items[1].as_python_constant()

def step(self):
return self.items[2].as_python_constant()

def range_length(self):
lo = self.start()
hi = self.stop()
step = self.step()

assert step != 0
if step > 0 and lo < hi:
return 1 + (hi - 1 - lo) // step
elif step < 0 and lo > hi:
return 1 + (lo - 1 - hi) // (0 - step)
else:
return 0

def _get_slice_indices(self, length, slice):
step_is_negative = 0

if slice.step is None:
step = 1
step_is_negative = False
else:
step = slice.step
step_is_negative = slice.step < 0

# Find lower and upper bounds for start and stop.
if step_is_negative:
lower = -1
upper = length + lower
else:
lower = 0
upper = length

# Compute start
if slice.start is None:
start = upper if step_is_negative else lower
else:
start = slice.start

if start < 0:
start += length
if start < lower:
start = lower
else:
if start > upper:
start = upper

# Compute stop.
if slice.stop is None:
stop = lower if step_is_negative else upper

else:
stop = slice.stop

if stop < 0:
stop += length
if stop < lower:
stop = lower
else:
if stop > upper:
stop = upper

return [start, stop, step]

def apply_index(self, index):
length = self.range_length()
if index < 0:
index = length + index

if index < 0 or index >= length:
raise IndexError(f"index {index} is out of range")

return variables.ConstantVariable.create(self.start() + (index * self.step()))

def apply_slice(self, slice):
(slice_start, slice_stop, slice_step) = self._get_slice_indices(
self.range_length(), slice
)

def compute_item(index):
return self.start() + (index * self.step())

sub_step = self.step() * slice_step
sub_start = compute_item(slice_start)
sub_stop = compute_item(slice_stop)

result = RangeVariable(
[
variables.ConstantVariable.create(x)
for x in [sub_start, sub_stop, sub_step]
],
mutable_local=MutableLocal() if self.mutable_local else None,
)
return result

def as_python_constant(self):
return range(*[x.as_python_constant() for x in self.items])

def getitem_const(self, arg: VariableTracker):
# implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c
index = arg.as_python_constant()

if isinstance(index, slice):
return self.apply_slice(index)
else:
return self.apply_index(index)

def as_proxy(self):
return self.python_type()(*self._as_proxy())

def unpack_var_sequence(self, tx):
def unpack_var_sequence(self, tx=None):
return [variables.ConstantVariable.create(x) for x in self.as_python_constant()]

def reconstruct(self, codegen):
Expand Down