-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
torch.empty_like and torch.zeros_like raise error if any memory forma…
…t is provided with sparse input (#43699) (#44058) Summary: Fixes #43699 - Changed the order of `TORCH_CHECK` and `if (options.layout() == kSparse && self.is_sparse())` inside `empty_like` method. - [x] Added tests EDIT: More details on that and why we can not take zeros_like approach. Python code : ```python res = torch.zeros_like(input_coalesced, memory_format=torch.preserve_format) ``` is routed to ```c++ // TensorFactories.cpp Tensor zeros_like( const Tensor& self, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) { if (options.layout() == kSparse && self.is_sparse()) { auto res = at::empty({0}, options); // to be resized res.sparse_resize_and_clear_( self.sizes(), self.sparse_dim(), self.dense_dim()); return res; } auto result = at::empty_like(self, options, optional_memory_format); return result.zero_(); } ``` and passed to `if (options.layout() == kSparse && self.is_sparse())` When we call in Python ```python res = torch.empty_like(input_coalesced, memory_format=torch.preserve_format) ``` it is routed to ```c++ Tensor empty_like( const Tensor& self, const TensorOptions& options_, c10::optional<c10::MemoryFormat> optional_memory_format) { TORCH_CHECK( !(options_.has_memory_format() && optional_memory_format.has_value()), "Cannot set memory_format both in TensorOptions and explicit argument; please delete " "the redundant setter."); TensorOptions options = self.options() .merge_in(options_) .merge_in(TensorOptions().memory_format(optional_memory_format)); TORCH_CHECK( !(options.layout() != kStrided && optional_memory_format.has_value()), "memory format option is only supported by strided tensors"); if (options.layout() == kSparse && self.is_sparse()) { auto result = at::empty({0}, options); // to be resized result.sparse_resize_and_clear_( self.sizes(), self.sparse_dim(), self.dense_dim()); return result; } ``` cc pearu Pull Request resolved: #44058 Reviewed By: albanD Differential Revision: D23672494 Pulled By: mruberry fbshipit-source-id: af232274dd2b516dd6e875fc986e3090fa285658
- Loading branch information
1 parent
1fde54d
commit 24df3b7
Showing
4 changed files
with
133 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters