Skip to content

Commit

Permalink
Enable fast pass tensor_fill for single element complex tensors (#50383)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #50383

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D25879881

Pulled By: anjali411

fbshipit-source-id: a254cff48ea9a6a38f7ee206815a04c31a9bcab0
  • Loading branch information
anjali411 authored and facebook-github-bot committed Jan 12, 2021
1 parent 6420071 commit 5834438
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 7 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/ScalarOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ inline void fill_inplace(Tensor& self, Scalar value_scalar) {

namespace detail {
Tensor& scalar_fill(Tensor& self, Scalar value) {
AT_DISPATCH_ALL_TYPES_AND3(
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16, self.scalar_type(), "fill_out", [&]() {
fill_inplace<scalar_t>(self, value);
});
Expand Down
4 changes: 3 additions & 1 deletion aten/src/ATen/ScalarOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ namespace c10 {
// to implement this without going through Derived Types (which are not part of core).
inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) {
// This is the fast track we have for CPU scalar tensors.
if (device == at::kCPU && !s.isComplex()) {
if (device == at::kCPU) {
if (s.isFloatingPoint()) {
return at::detail::scalar_tensor_static(s, at::kDouble, at::kCPU);
} else if (s.isComplex()) {
return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU);
} else if (s.isBoolean()) {
return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU);
} else {
Expand Down
5 changes: 1 addition & 4 deletions aten/src/ATen/TensorIndexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,7 @@ static inline int64_t count_specified_dimensions(const ArrayRef<TensorIndex>& in
// The rest of the functions are in `at::indexing::impl` namespace, signifying
// that they shouldn't be used from Python indexing implementation.
static inline Tensor scalarToTensor(Scalar v, const TensorOptions& options, const at::Device& self_device) {
if (self_device == at::kCPU && !v.isComplex() &&
options.dtype_opt()->toScalarType() != ScalarType::ComplexDouble &&
options.dtype_opt()->toScalarType() != ScalarType::ComplexFloat &&
options.dtype_opt()->toScalarType() != ScalarType::ComplexHalf) {
if (self_device == at::kCPU) {
return at::detail::scalar_tensor_static(v, options.dtype_opt()->toScalarType(), self_device);
} else {
return impl::scalarToTensorNonNativeDeviceType(v, options);
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Fill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Tensor& fill_out(Tensor& self, Scalar value) {
self.copy_(out);
return self;
}
if (self.device() == at::kCPU && self.numel() == 1 && !self.is_complex() && !value.isComplex()) {
if (self.device() == at::kCPU && self.numel() == 1) {
return at::detail::scalar_fill(self, value);
}
auto iter = TensorIteratorConfig()
Expand Down

0 comments on commit 5834438

Please sign in to comment.