Skip to content

Commit

Permalink
don't check memory format for empty tensors
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
shunting314 committed May 18, 2024
1 parent ed27879 commit e02f02a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,10 @@ def inner_check_out_mem_format(output):
d = output.dim()
if (d == 4 and ((input_mem_format == torch.channels_last)
or (module_mem_format == torch.channels_last and module_memformat_affects_out))):
self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))
self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last))
elif (d == 5 and ((input_mem_format == torch.channels_last_3d)
or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))):
self.assertTrue(output.is_contiguous(memory_format=torch.channels_last_3d))
self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last_3d))
else:
self.assertTrue(output.is_contiguous())
return self._traverse_obj(output, inner_check_out_mem_format)
Expand Down

0 comments on commit e02f02a

Please sign in to comment.