From e02f02ac0a1456c1fab9b805e86a348b4b8a9488 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Fri, 17 May 2024 17:10:42 -0700 Subject: [PATCH] don't check memory format for empty tensors [ghstack-poisoned] --- test/test_modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index a62cc635a3d30..ab05e9df43558 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -663,10 +663,10 @@ def inner_check_out_mem_format(output): d = output.dim() if (d == 4 and ((input_mem_format == torch.channels_last) or (module_mem_format == torch.channels_last and module_memformat_affects_out))): - self.assertTrue(output.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last)) elif (d == 5 and ((input_mem_format == torch.channels_last_3d) or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))): - self.assertTrue(output.is_contiguous(memory_format=torch.channels_last_3d)) + self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last_3d)) else: self.assertTrue(output.is_contiguous()) return self._traverse_obj(output, inner_check_out_mem_format)