Improve compare kernel #29743

Status: Closed · wants to merge 9 commits
8 changes: 6 additions & 2 deletions aten/src/ATen/native/TensorIterator.cpp
@@ -230,15 +230,18 @@ void TensorIterator::compute_types() {
}

for (auto &op : operands_) {
+ bool skip_output = compute_common_dtype_only_for_inputs && op.is_output;
+ bool is_different = op.tensor.defined() && op.current_dtype != common_dtype_;
+
if (may_have_differing_types) {
validate_dtype(op, common_dtype_, common_dtype_strategy_);
- bool cast_by_copy = compute_common_dtype && !common_device_is_cuda && (!compute_common_dtype_only_for_inputs || !op.is_output);
+ bool cast_by_copy = compute_common_dtype && !common_device_is_cuda && !skip_output;
if (cast_by_copy) {
maybe_copy_casting_to_common_dtype(op, common_dtype_);
}
}

- if (op.tensor.defined() && op.current_dtype != common_dtype_) {
+ if (is_different && !skip_output) {
have_differing_types_ = true;
}
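The new `skip_output` / `is_different` flags make the intent explicit: an output excluded from common-dtype computation is neither copy-cast nor counted toward `have_differing_types_`, so a comparison's bool output no longer forces dynamic casting by itself. A hypothetical host-side illustration of the behavior this enables (the ATen calls here are assumptions for illustration, not part of this diff):

```cpp
#include <ATen/ATen.h>

// Sketch: with compute_common_dtype_only_for_inputs, the float/long inputs
// promote to float for the comparison, while the bool output is skipped both
// for copy-casting and for the differing-types bookkeeping.
void compare_mixed_dtypes() {
  at::Tensor a = at::randn({4});               // float input
  at::Tensor b = at::arange(4);                // long input, promoted to float
  at::Tensor out = at::empty({4}, at::kBool);  // output keeps kBool
  at::lt_out(out, a, b);                       // `out` is not cast to float
}
```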

@@ -686,6 +689,7 @@ TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
iter.allow_cpu_scalars_ = true;
iter.compute_common_dtype_only_for_inputs();
iter.build();
+ iter.dynamic_cast_if(iter.dtype() != kBool);
return iter;
}
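`comparison_op` now forces dynamic casting whenever the requested output dtype is not `kBool`, because the CUDA comparison kernels (below) always compute a `bool` result that must be cast on store. A hedged sketch of the case this line guards (assuming the ATen C++ API):

```cpp
// A non-bool `out` makes iter.dtype() != kBool, so each bool result is
// dynamically cast on store instead of being written as raw bool bytes.
at::Tensor a = at::randn({4});
at::Tensor b = at::randn({4});
at::Tensor out = at::empty({4}, at::kFloat);
at::lt_out(out, a, b);  // stores 0.0f / 1.0f
```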

7 changes: 6 additions & 1 deletion aten/src/ATen/native/TensorIterator.h
@@ -301,7 +301,7 @@ struct CAFFE2_API TensorIterator {
bool is_final_output() const { return final_output_; }

bool needs_dynamic_casting() const {
- return (common_dtype_strategy_ != CommonDTypeStrategy::NONE) && have_differing_types_;
+ return force_dynamic_casting_ || ((common_dtype_strategy_ != CommonDTypeStrategy::NONE) && have_differing_types_);
}

void set_check_mem_overlap(bool check_mem_overlap) {
@@ -343,6 +343,10 @@ struct CAFFE2_API TensorIterator {
resize_outputs_ = false;
}

+ void dynamic_cast_if(bool condition) {
+   force_dynamic_casting_ = force_dynamic_casting_ || condition;
+ }
+
void build();

protected:
@@ -383,6 +387,7 @@ struct CAFFE2_API TensorIterator {
bool final_output_ = true;
bool check_mem_overlap_ = false;
bool have_differing_types_ = false;
+ bool force_dynamic_casting_ = false;
bool all_ops_same_shape_ = false;
bool requires_channels_last_output_ = false;
};
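`force_dynamic_casting_` is sticky: `dynamic_cast_if` ORs the condition in and nothing ever clears it, so once set, `needs_dynamic_casting()` stays true for the iterator's lifetime. A self-contained mock of that semantics (standalone sketch, not PyTorch code):

```cpp
#include <cassert>

// Minimal reproduction of the flag logic added above.
struct IterFlags {
  bool force_dynamic_casting_ = false;
  void dynamic_cast_if(bool condition) {
    force_dynamic_casting_ = force_dynamic_casting_ || condition;
  }
};

int main() {
  IterFlags f;
  f.dynamic_cast_if(false);  // no effect
  f.dynamic_cast_if(true);   // sets the flag
  f.dynamic_cast_if(false);  // flag stays set; later calls cannot clear it
  assert(f.force_dynamic_casting_);
  return 0;
}
```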
12 changes: 6 additions & 6 deletions aten/src/ATen/native/cuda/BinaryCompareKernel.cu
@@ -11,47 +11,47 @@ namespace at { namespace native {

void lt_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "lt_cuda", [&]() {
- gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+ gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
return a < b;
});
});
}

void le_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "le_cuda", [&]() {
- gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+ gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
return a <= b;
});
});
}

void gt_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "gt_cuda", [&]() {
- gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+ gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
return a > b;
});
});
}

void ge_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "ge_cuda", [&]() {
- gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+ gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
return a >= b;
});
});
}

void eq_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "eq_cuda", [&]() {
- gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+ gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
return a == b;
});
});
}

void ne_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "ne_cuda", [&]() {
- gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+ gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
return a != b;
});
});
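All six lambdas now return `bool` rather than `scalar_t`, matching the natural result type of a comparison; combined with the forced dynamic casting added in `comparison_op`, non-bool outputs still receive correctly cast values. A self-contained CUDA sketch of the pattern (a plain kernel, not the `gpu_kernel_with_scalars` machinery):

```cuda
#include <cuda_runtime.h>

// Compute the comparison in scalar_t but store a bool, mirroring the
// lambdas' new `-> bool` return type.
template <typename scalar_t>
__global__ void lt_bool_kernel(const scalar_t* a, const scalar_t* b,
                               bool* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = a[i] < b[i];  // result type is bool, independent of scalar_t
  }
}
```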
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu
@@ -40,8 +40,8 @@ void bitwise_xor_kernel_cuda(TensorIterator& iter) {

void logical_xor_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "logical_xor_cuda", [&]() {
- gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-   return static_cast<scalar_t>(bool(a) != bool(b));
+ gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+   return bool(a) != bool(b);
});
});
}
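`logical_xor` follows the same pattern: each operand is reduced to its truth value and the lambda returns the `bool` result directly, with no cast back to `scalar_t`. A hypothetical usage sketch (the ATen calls are assumptions for illustration):

```cpp
// Nonzero elements count as true, so the result is an elementwise XOR of
// truth values, produced as a kBool tensor.
at::Tensor a = at::tensor({0.0f, 2.5f, -1.0f, 0.0f});
at::Tensor b = at::tensor({0.0f, 0.0f,  3.0f, 4.0f});
at::Tensor r = at::logical_xor(a, b);  // [false, true, false, true]
```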
4 changes: 1 addition & 3 deletions aten/src/ATen/native/cuda/Copy.cu
@@ -58,9 +58,7 @@ void copy_device_to_device(TensorIterator& iter, bool non_blocking) {
cudaMemcpyDeviceToDevice,
copy_stream));
} else {
- // this is done intentionally done after build because copy has a "promotion"
- // rule that always "promote" to target dtype.
- iter.promote_common_dtype();
+ iter.dynamic_cast_if(true);
AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, kBFloat16, iter.dtype(0), "copy_", [&] {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
});
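For cross-dtype device-to-device copies, `dynamic_cast_if(true)` replaces the old promotion rule: the identity lambda runs in the destination dtype (`iter.dtype(0)`) and dynamic casting converts each source element as it is read, with no intermediate promoted tensor. A self-contained CUDA sketch of one such dtype pair (hypothetical fixed types; the real path handles any pair generically):

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Copy with a per-element cast: each half source value is widened to
// float in flight, the effect forced dynamic casting achieves for any
// src/dst dtype combination inside gpu_kernel.
__global__ void copy_half_to_float(const __half* src, float* dst, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    dst[i] = __half2float(src[i]);
  }
}
```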