Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

triu/tril: complete dtype support for CPU/CUDA. #101414

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions aten/src/ATen/native/TriangularOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void apply_triu_tril_single(
parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
for (int64_t i : c10::irange(start, end)) {
for (int64_t j = 0; j < std::min(m, i + k); j++) {
result[i * res_row_stride + j * res_col_stride] = 0;
result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
}
if (!inplace) { // copy the rest of the self if not inplace
for (int64_t j = std::max(zero, i + k); j < m; j++) {
Expand All @@ -71,7 +71,7 @@ void apply_triu_tril_single(
parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
for (int64_t i : c10::irange(start, end)) {
for (int64_t j = std::max(zero, i + k + 1); j < m; j++) {
result[i * res_row_stride + j * res_col_stride] = 0;
result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
}
if (!inplace) { // copy the rest of the self if not inplace
for (int64_t j = zero; j < std::min(m, i + k + 1); j++) {
Expand Down Expand Up @@ -155,7 +155,8 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
result_c = result;
}

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
ScalarType::ComplexHalf,
ScalarType::BFloat16,
ScalarType::Half,
ScalarType::Bool,
Expand Down
8 changes: 6 additions & 2 deletions aten/src/ATen/native/cuda/TriangularOps.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,12 @@ void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k
int64_t N = self.numel();
dim3 dim_block = cuda::getApplyBlock();
dim3 dim_grid((N + dim_block.x - 1) / dim_block.x);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, at::ScalarType::Half, at::ScalarType::Bool,
self.scalar_type(), "triu_tril_cuda_template", [&]{
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw given that this op only copies elements or sets them to 0, can you dispatch just based on the size of the type?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you fine with a follow-up?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sure

at::ScalarType::ComplexHalf,
at::ScalarType::Half,
at::ScalarType::BFloat16,
at::ScalarType::Bool,
self.scalar_type(), "triu_tril_cuda_template", [&] {
if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(self)) {
auto result_info = cuda::detail::getTensorInfo<scalar_t, int32_t>(result);
auto self_info = cuda::detail::getTensorInfo<scalar_t, int32_t>(self);
Expand Down
6 changes: 2 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16510,15 +16510,13 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
sample_inputs_func=sample_inputs_adjoint),
OpInfo('tril',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half),
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
error_inputs_func=error_inputs_tril_triu,
sample_inputs_func=sample_inputs_tril_triu),
OpInfo('triu',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half),
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
error_inputs_func=error_inputs_tril_triu,
Expand Down