14 changes: 14 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -4838,6 +4838,20 @@ def fn_channels_last(x):
[torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last)],
)

def test_like_channels_last(self):
def foo():
randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32)
xc = randn.contiguous(memory_format=torch.channels_last)
clone = torch.zeros_like(xc, memory_format=torch.preserve_format)
rand_like = torch.rand_like(randn)
return (xc, clone, rand_like)

out = foo()
out_comp = torch.compile()(foo)()

for t, t_comp in zip(out, out_comp):
self.assertEqual(t.stride(), t_comp.stride())

def test_as_strided_scatter(self):
def fn(a, b):
return aten.as_strided_scatter(
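For context, a minimal standalone sketch of the behavior the new inductor test asserts: `zeros_like`/`rand_like` on a channels-last input should come back with the same strides under `torch.compile` as in eager. This is only an illustration of the test above, assuming a recent PyTorch with `torch.compile` available; it is not part of the diff.

```python
import torch

def foo():
    # dense channels-last input
    x = torch.randn(4, 3, 8, 8).contiguous(memory_format=torch.channels_last)
    # preserve_format (and the *_like default) should keep the channels-last layout
    return torch.zeros_like(x, memory_format=torch.preserve_format), torch.rand_like(x)

eager_out = foo()
compiled_out = torch.compile(foo)()
for a, b in zip(eager_out, compiled_out):
    assert a.stride() == b.stride()  # strides match between eager and inductor
```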
33 changes: 20 additions & 13 deletions test/test_torch.py
@@ -28,6 +28,7 @@
from functools import partial
from torch import multiprocessing as mp
from torch.testing import make_tensor

from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON,
IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN,
@@ -5133,16 +5134,21 @@ def _test_memory_format_transformations(self, device, input_generator_fn, transf
# xc is a channels last tensor
xc = input_generator_fn(device)
# xc is not memory dense, but looks like channels last
if memory_format == torch.channels_last:
xc = xc[..., ::2, ::2]
else:
xc = xc[..., ::2, ::2, ::2]
# We don't preserve non-dense striding
if not TEST_WITH_TORCHINDUCTOR:
if memory_format == torch.channels_last:
xc = xc[..., ::2, ::2]
else:
xc = xc[..., ::2, ::2, ::2]

clone = transformation_fn(xc, memory_format=torch.preserve_format)


self.assertFalse(clone.is_contiguous())
self.assertTrue(clone.is_contiguous(memory_format=memory_format))
self.assertFalse(xc.is_contiguous())
self.assertFalse(xc.is_contiguous(memory_format=memory_format))
if not TEST_WITH_TORCHINDUCTOR:
self.assertFalse(xc.is_contiguous())
self.assertFalse(xc.is_contiguous(memory_format=memory_format))
if compare_data:
self.assertEqual(xc, clone.to(xc))

@@ -5165,12 +5171,14 @@ def _test_memory_format_transformations(device, input_generator_fn, transf
if compare_data:
self.assertEqual(xc, clone.to(xc))

x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
for _ in range(10):
permutation = list(range(len(x.shape)))
random.shuffle(permutation)
x = x.permute(permutation)
self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())
# TODO copy _like constructors to stride permutation instead of just layout
if not TEST_WITH_TORCHINDUCTOR:
x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
for i in range(10):
permutation = list(range(len(x.shape)))
random.shuffle(permutation)
x = x.permute(permutation)
self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())

def test_memory_format_to(self, device):
def get_generator(memory_format, shape):
@@ -5223,7 +5231,6 @@ def transformation_fn(tensor, **kwargs):
self._test_memory_format_transformations(
device, get_generator(mf, shape), transformation_fn, mf, True, default_is_preserve=True)

@skipIfTorchInductor("To be supported")
def test_memory_format_factory_like_functions_preserve(self, device):
def get_generator(memory_format, shape):
def input_generator_fn(device):
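The `TEST_WITH_TORCHINDUCTOR` guards above reflect the limitation stated in the new comment: non-dense striding is not preserved by the inductor `*_like` decompositions, only the suggested (dense) memory format is. A small sketch of what "not memory dense, but looks like channels last" means, run in plain eager mode; `torch._prims_common.suggest_memory_format` is the same helper the decomposition below relies on.

```python
import torch
import torch._prims_common as utils

x = torch.randn(4, 3, 8, 8).contiguous(memory_format=torch.channels_last)
xc = x[..., ::2, ::2]  # strided slice: still "looks like" channels last, but no longer memory dense

print(xc.is_contiguous(memory_format=torch.channels_last))  # False, the slice is not dense
print(utils.suggest_memory_format(xc))                      # torch.channels_last is still suggested

# Eager and inductor both produce a dense channels-last clone here, but inductor recovers
# only the suggested layout rather than the exact non-dense strides of xc, hence the
# TEST_WITH_TORCHINDUCTOR guards in the test above.
clone = torch.zeros_like(xc, memory_format=torch.preserve_format)
print(clone.is_contiguous(memory_format=torch.channels_last))  # True
```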
32 changes: 23 additions & 9 deletions torch/_inductor/decomposition.py
@@ -2,9 +2,11 @@
import logging
import math
import typing
from typing import Optional

import torch
import torch._decomp as decomp
import torch._prims_common as utils
import torch.ao.quantization.fx._decomposed
from torch._decomp import (
core_aten_decompositions,
@@ -307,24 +309,34 @@ def view_copy_dtype(self, dtype):
return self.to(dtype).clone()


def get_like_layout(
tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
) -> torch.memory_format:
# TODO: _to_copy tensor to stride permutation
if memory_format in (torch.preserve_format, None):
return utils.suggest_memory_format(tensor)
else:
return memory_format


@register_decomposition(aten.rand_like)
def rand_like(self, *, dtype=None, device=None, **kwargs):
def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
return torch.rand(
[*self.size()],
dtype=dtype or self.dtype,
device=device or self.device,
**kwargs,
)
).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randn_like)
def randn_like(self, *, dtype=None, device=None, **kwargs):
def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
return torch.randn(
[*self.size()],
dtype=dtype or self.dtype,
device=device or self.device,
**kwargs,
)
).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.full_like)
@@ -346,31 +358,33 @@ def full_like(
layout=layout or self.layout,
device=device or self.device,
requires_grad=requires_grad,
)
).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.default)
def randint_like(self, high, *, dtype=None, device=None, **kwargs):
def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
return aten.randint.low(
0,
high,
[*self.size()],
dtype=dtype or self.dtype,
device=device or self.device,
**kwargs,
)
).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(self, low, high, *, dtype=None, device=None, **kwargs):
def randint_like_low(
self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
):
return aten.randint.low(
low,
high,
[*self.size()],
dtype=dtype or self.dtype,
device=device or self.device,
**kwargs,
)
).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint.default)
Expand Down