20 changes: 2 additions & 18 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -574,34 +574,18 @@ Tensor& true_divide_(Tensor& self, const Scalar& divisor) {
}

Tensor& floor_divide_out(const Tensor& self, const Tensor& other, Tensor& result) {
TORCH_WARN_ONCE(
"floor_divide is deprecated, and will be removed in a future version of pytorch. "
"It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). "
"This results in incorrect rounding for negative values.\n"
"To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
"or for actual floor division, use torch.div(a, b, rounding_mode='floor')."
);
// FIXME: Not actually doing floor division (#43874)
auto iter = TensorIterator::binary_op(result, self, other);
div_trunc_stub(iter.device_type(), iter);
div_floor_stub(iter.device_type(), iter);
if (!result.defined()) {
result = iter.output();
}
return result;
}

Tensor floor_divide(const Tensor& self, const Tensor& other) {
TORCH_WARN_ONCE(
"floor_divide is deprecated, and will be removed in a future version of pytorch. "
"It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). "
"This results in incorrect rounding for negative values.\n"
"To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
"or for actual floor division, use torch.div(a, b, rounding_mode='floor')."
);
// FIXME: Not actually doing floor division (#43874)
Tensor result;
auto iter = TensorIterator::binary_op(result, self, other);
div_trunc_stub(iter.device_type(), iter);
div_floor_stub(iter.device_type(), iter);
return iter.output();
}
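The swap from div_trunc_stub to div_floor_stub changes how negative quotients round: truncation rounds toward zero, flooring rounds toward negative infinity. A minimal Python-level sketch of the difference (using torch.div's rounding_mode argument; the outputs in the comments are the expected values, not captured output):

import torch

a = torch.tensor([-7, 7])
b = torch.tensor([2, 2])

torch.div(a, b, rounding_mode='trunc')   # tensor([-3,  3])  rounds toward zero (old floor_divide behavior)
torch.div(a, b, rounding_mode='floor')   # tensor([-4,  3])  rounds toward -inf (new floor_divide behavior)
a // b                                   # tensor([-4,  3])  matches Python's // semantics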

29 changes: 0 additions & 29 deletions test/jit/test_upgraders.py
@@ -123,35 +123,6 @@ def test_aten_div_tensor_at_3(self):
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)

@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_div_other_variants(self):
def test_func():
a = torch.ones((4, 5, 6), dtype=torch.int64)
b = 4
return a // b

traced_func = torch.jit.trace(test_func, ())
buffer = io.BytesIO()
torch.jit.save(traced_func, buffer)

current_flag_value = torch._C._get_version_calculator_flag()
# calculate based on old version
torch._C._calculate_package_version_based_on_upgraders(False)
buffer.seek(0)
loaded_func = torch.jit.load(buffer)
version = self._load_model_version(loaded_func)
self.assertTrue(version == 4)

# calculate based on new version
torch._C._calculate_package_version_based_on_upgraders(True)
buffer.seek(0)
loaded_func = torch.jit.load(buffer)
version = self._load_model_version(loaded_func)
self.assertTrue(version == 4)

# make sure we preserve old behaviou
torch._C._calculate_package_version_based_on_upgraders(current_flag_value)

@unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
def test_aten_full_other_variants(self):
def test_func():
3 changes: 3 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1995,6 +1995,7 @@ def forward(self, x):
x = torch.randn(2, 3)
self.run_test(ArithmeticModule(), x)

@unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
def test_floor_div(self):
class FloorDivModule(torch.nn.Module):
def forward(self, x, y):
@@ -2017,6 +2018,7 @@ def forward(self, x, y):
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
self.run_test(FloorDivModule(), (x, y))

@unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
def test_floor_div_script(self):
class FloorDivModule(torch.jit.ScriptModule):
@torch.jit.script_method
@@ -2027,6 +2029,7 @@ def forward(self, x, y):
y = torch.randn(2, 3, 4)
self.run_test(FloorDivModule(), (x, y))

@unittest.skip("Floor division on ONNX is inconsistent with eager (see #78411)")
@skipIfUnsupportedMinOpsetVersion(9)
def test_floordiv(self):
class FloordivModule(torch.nn.Module):
101 changes: 27 additions & 74 deletions test/test_binary_ufuncs.py
@@ -164,7 +164,7 @@ def _numel(x):
# Assumes x is a scalar
return 1

if _numel(l) < 10 and _numel(r) < 10:
if _numel(l) <= 100 and _numel(r) <= 100:
msg = (
"Failed to produce expected results! Input lhs tensor was"
" {0}, rhs tensor was {1}, torch result is {2}, and reference result is"
@@ -1261,8 +1261,7 @@ def test_inplace_dunders(self, device):
t *= 1
t /= 1
t **= 1
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
t //= 1
t //= 1
t %= 1
self.assertEqual(expected, t.data_ptr())

@@ -1902,8 +1901,6 @@ def test_binary_op_scalar_device_unspecified(self, devices):
def test_div_and_floordiv_vs_python(self, device):
# Tests torch division ops which can handle both arguments being
# scalars.
# NOTE: torch.floor_divide currently truncates instead of flooring.
# the quotient. See https://github.com/pytorch/pytorch/issues/43874.
def _scalar_helper(python_op, torch_op):
for a, b in product(range(-10, 10), range(-10, 10)):
for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
@@ -1926,19 +1923,16 @@ def _scalar_helper(python_op, torch_op):
actual_first_tensor = torch_op(a_t, b)
actual_second_tensor = torch_op(a, b_t)

self.assertEqual(actual_scalar, expected_div)
self.assertEqual(actual_tensor.item(), expected_div)
self.assertEqual(actual_scalar, expected)
self.assertEqual(actual_tensor.item(), expected)
self.assertEqual(actual_first_tensor, actual_tensor)
self.assertEqual(actual_second_tensor, actual_tensor)

_scalar_helper(operator.truediv, operator.truediv)
_scalar_helper(operator.truediv, torch.true_divide)
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
_scalar_helper(lambda a, b: math.trunc(a / b), operator.floordiv)
_scalar_helper(lambda a, b: math.trunc(a / b), torch.floor_divide)
_scalar_helper(lambda a, b: math.floor(a / b), operator.floordiv)
_scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide)

# NOTE: torch.floor_divide currently truncates instead of flooring.
# See https://github.com/pytorch/pytorch/issues/43874.
@onlyNativeDeviceTypes
def test_div_and_floordiv_script_vs_python(self, device):
# Creates jitted functions of two tensors
@@ -1960,13 +1954,12 @@ def _wrapped_floordiv(a, b):
continue

expected_div = a / b
expected_truncdiv = math.trunc(a / b)
expected_floordiv = math.floor(a / b)
a_t = torch.tensor(a, device=device)
b_t = torch.tensor(b, device=device)

self.assertEqual(scripted_div(a_t, b_t), expected_div)
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
self.assertEqual(scripted_floordiv(a_t, b_t), expected_truncdiv)
self.assertEqual(scripted_floordiv(a_t, b_t), expected_floordiv)

# Creates jitted functions of one tensor
def _wrapped_div_scalar(a):
@@ -1996,8 +1989,6 @@ def _wrapped_rfloordiv_scalar(a):
a_t = torch.tensor(a, device=device)

self.assertEqual(a / 5, scripted_div_scalar(a_t))
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
self.assertEqual(math.trunc(a / 5), scripted_floordiv_scalar(a_t))

# Skips zero divisors
if a == 0:
@@ -2014,8 +2005,6 @@ def _wrapped_rfloordiv_scalar(a):
# See issue gh-52387
self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))

# NOTE: torch.floor_divide currently truncates instead of flooring
# the quotient. See https://github.com/pytorch/pytorch/issues/43874.
@onlyNativeDeviceTypes
def test_idiv_and_ifloordiv_vs_python(self, device):
def _wrapped_idiv_tensor(a, b):
@@ -2075,7 +2064,6 @@ def _wrapped_ifloordiv_scalar(a):

expected_idiv = a / b
expected_ifloordiv = a // b
expected_itruncdiv = math.trunc(a / b)

a_t = torch.tensor(a, device=device)
b_t = torch.tensor(b, device=device)
@@ -2110,39 +2098,27 @@ def _wrapped_ifloordiv_scalar(a):
if not a_t.is_floating_point() and b_t.is_floating_point():
# Inplace modification fails because a float tensor is required
# if the divisor is a float tensor
with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex(
UserWarning, "floor_divide"
):
a_t.clone().floor_divide_(b_t)
with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex(
UserWarning, "floor_divide"
):
scripted_floor_divide_tensor(a_t.clone(), b_t)
a_t.clone().floor_divide_(b_t)
scripted_floor_divide__tensor(a_t.clone(), b_t)
tmp = a_t.clone()
with self.assertRaises(RuntimeError), self.assertWarnsOnceRegex(
UserWarning, "floor_divide"
):
tmp //= b_t
tmp //= b_t
else:
# Inplace modification is OK when both or neither tensor is
# a float tensor
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
self.assertEqual(
a_t.clone().floor_divide_(b_t).item(), expected_itruncdiv
)
self.assertEqual(
scripted_floor_divide__tensor(a_t.clone(), b_t).item(),
expected_itruncdiv,
)
tmp = a_t.clone()
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
tmp //= b_t
self.assertEqual(tmp.item(), expected_itruncdiv)

with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
self.assertEqual(
scripted_floor_divide__scalar(a_t), math.trunc(a / 5)
a_t.clone().floor_divide_(b_t).item(), expected_ifloordiv
)
self.assertEqual(
scripted_floor_divide__tensor(a_t.clone(), b_t).item(),
expected_ifloordiv,
)
tmp = a_t.clone()
tmp //= b_t
self.assertEqual(tmp.item(), expected_ifloordiv)

self.assertEqual(
scripted_floor_divide__scalar(a_t), math.floor(a / 5)
)

# Tests binary op equivalence with Python builtin ops
# Also tests that reverse operations are equivalent to forward ops
@@ -2747,9 +2723,8 @@ def test_floor_divide_tensor(self, device, dtype):
x = torch.randn(10, device=device).mul(30).to(dtype)
y = torch.arange(1, 11, dtype=dtype, device=device)

with self.assertWarnsOnceRegex(UserWarning, "__floordiv__"):
z = x // y
z_alt = torch.trunc(x.double() / y.double()).to(dtype)
z = x // y
z_alt = torch.floor(x.double() / y.double()).to(dtype)

self.assertEqual(z.dtype, x.dtype)
self.assertEqual(z, z_alt)
@@ -2761,36 +2736,14 @@ def test_floor_divide_scalar(self, device, dtype):
def test_floor_divide_scalar(self, device, dtype):
x = torch.randn(100, device=device).mul(10).to(dtype)

with self.assertWarnsOnceRegex(UserWarning, "__floordiv__"):
z = x // 3
z = x // 3
z_alt = torch.tensor(
[math.trunc(v.item() / 3.0) for v in x], dtype=x.dtype, device=device
[math.floor(v.item() / 3.0) for v in x], dtype=x.dtype, device=device
)

self.assertEqual(z.dtype, x.dtype)
self.assertEqual(z, z_alt)

# Note: this tests fails on XLA
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.long)
def test_floor_divide_out(self, device, dtype):
x = torch.randn(10, device=device).mul(10).to(dtype)
y = torch.arange(1, 11, dtype=dtype, device=device)
o = torch.empty(10, dtype=dtype, device=device)

with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
torch.floor_divide(x, y, out=o)
self.assertEqual(o, x // y)

# Tests scalar with out
torch.floor_divide(x, 2, out=o)
self.assertEqual(o, x // 2)

if dtype == torch.int:
o = torch.empty(10, dtype=torch.float, device=device)
torch.floor_divide(x, y, out=o)
self.assertEqual(o, torch.floor_divide(x.float(), y.float()))

@onlyCPU
@dtypes(*get_all_math_dtypes("cpu"))
def test_rdiv(self, device, dtype):
26 changes: 11 additions & 15 deletions test/test_jit.py
@@ -7090,19 +7090,6 @@ def test_number_div(self):
self.checkScript(div_int_nofuture, ())
self.checkScript(div_float_nofuture, ())

def test_floor_div(self):
@torch.jit.script
def foo(a, b):
# type: (int, int) -> int
return a // b
for i in range(-8, 8):
for j in range(-8, 8):
if j != 0:
self.assertEqual(foo(i, j), i // j)
else:
with self.assertRaisesRegex(RuntimeError, 'division by 0'):
foo(i, j)

# Testing bitwise shorthand aug assignment
def test_bool_augassign_bitwise_or(self):
def func(a: bool, b: bool) -> bool:
@@ -12514,6 +12501,16 @@ def fn():
for a, b in zip(eager_out, script_out):
check_equal_and_dtype(a, b)

def test_floor_div(self):
@torch.jit.script
def foo(a, b):
# type: (int, int) -> int
return a // b
for i in range(-8, 8):
for j in range(-8, 8):
if j != 0:
self.assertEqual(foo(i, j), i // j)

def test_floordiv(self):
funcs_template = dedent('''
def fn():
@@ -12532,8 +12529,7 @@ def fn():
cu = torch.jit.CompilationUnit(funcs_str)
f_script = cu.fn
f = scope['fn']
with self.assertWarnsOnceRegex(UserWarning, "floor_divide"):
self.assertEqual(f_script(), f())
self.assertEqual(f_script(), f())

def test_call_python_fn_from_script_fn(self):
@torch.jit.ignore
8 changes: 3 additions & 5 deletions test/test_sparse.py
@@ -1663,11 +1663,9 @@ def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device,
self.assertEqual(self.safeToDense(y1), expected)
self.assertEqual(self.safeToDense(y2), expected)

with self.assertWarnsOnceRegex(UserWarning, '__floordiv__'):
y1 = x1 // 37.5
y1 = x1 // 37.5
y2 = x1.clone()
with self.assertWarnsOnceRegex(UserWarning, 'floor_divide'):
y2.floor_divide_(37.5)
y2.floor_divide_(37.5)
expected = self.safeToDense(x1) // 37.5
self.assertEqual(self.safeToDense(y1), expected)
self.assertEqual(self.safeToDense(y2), expected)
@@ -3010,7 +3008,7 @@ def test_div_by_sparse_error(self, device):
/ torch.tensor(1., device=device).to_sparse())

def test_floor_divide_by_sparse_error(self, device):
self.assertRaisesRegex(RuntimeError, 'Sparse division requires',
self.assertRaisesRegex(RuntimeError, 'Sparse floor division requires',
lambda: torch.tensor(1., device=device).to_sparse()
// torch.tensor(1., device=device).to_sparse())
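For context on the updated error string: floor dividing a sparse tensor by a scalar is supported, while a sparse divisor is rejected. A small sketch of what the adjusted tests exercise (the exact message text is an assumption carried over from the new assertion):

import torch

x = torch.tensor([[0., 6.], [0., -9.]]).to_sparse()
y = x // 37.5                                # scalar divisor: floor division applied to the stored values

lhs = torch.tensor(1.).to_sparse()
rhs = torch.tensor(1.).to_sparse()
try:
    lhs // rhs                               # a sparse divisor is unsupported
except RuntimeError as e:
    print(e)                                 # expected to mention "Sparse floor division requires"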

14 changes: 2 additions & 12 deletions torch/_tensor.py
@@ -666,21 +666,11 @@ def __rpow__(self, other):

@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __floordiv__(self, other):
warnings.warn("__floordiv__ is deprecated, and its behavior will change in a future version of pytorch. "
"It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). "
"This results in incorrect rounding for negative values. "
"To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
"or for actual floor division, use torch.div(a, b, rounding_mode='floor').", stacklevel=3)
return torch.div(self, other, rounding_mode='trunc')
return torch.floor_divide(self, other)

@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rfloordiv__(self, other):
warnings.warn("__rfloordiv__ is deprecated, and its behavior will change in a future version of pytorch. "
"It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). "
"This results in incorrect rounding for negative values. "
"To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
"or for actual floor division, use torch.div(a, b, rounding_mode='floor').", stacklevel=3)
return torch.div(other, self, rounding_mode='trunc')
return torch.floor_divide(other, self)

@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rlshift__(self, other):
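With __floordiv__ and __rfloordiv__ now forwarding to torch.floor_divide, the // operator on tensors follows Python's floor semantics for both operand orders and no longer emits a deprecation warning. A brief sketch, with expected outputs in the comments:

import torch

t = torch.tensor([2, -2])
5 // t                        # tensor([ 2, -3]); Python agrees: 5 // -2 == -3
torch.floor_divide(5, t)      # same result through the functional form used by __rfloordiv__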