[BC-BREAKING] Changed tensor comparison return type from uint8 to bool #21113

Closed

Wants to merge 40 commits (changes shown are from 33 commits).

Commits
8ac7b69
Changed tensor comparison return type from uint8 to bool
izdeby May 30, 2019
60c3dee
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby May 30, 2019
906ae2c
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby May 30, 2019
6baa96b
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby May 30, 2019
9c422e1
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby May 30, 2019
5a1ac07
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby May 30, 2019
804a503
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby May 30, 2019
3274dab
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby May 30, 2019
6ec7344
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby May 30, 2019
c86ae6c
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby May 30, 2019
3234e8e
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby Jun 1, 2019
1c015f6
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby Jun 3, 2019
3a1d363
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby Jun 3, 2019
0af4bb3
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby Jun 3, 2019
5eefdcb
Update on "[WIP] Changed tensor comparison return type from uint8 to …
izdeby Jun 3, 2019
877c1fc
Update on "Changed tensor comparison return type from uint8 to bool"
izdeby Jun 3, 2019
1f6c450
Update on "Changed tensor comparison return type from uint8 to bool"
izdeby Jun 3, 2019
88dfa92
Update on "Changed tensor comparison return type from uint8 to bool"
izdeby Jun 4, 2019
ebba3b9
Update on "[BC-BREAKING] [WIP] Changed tensor comparison return type …
izdeby Jun 7, 2019
e236554
Update on "[BC-BREAKING] [WIP] Changed tensor comparison return type …
izdeby Jun 13, 2019
fa54652
Update on "[BC-BREAKING] [WIP] Changed tensor comparison return type …
izdeby Jun 13, 2019
119b6f4
Update on "[BC-BREAKING] [WIP] Changed tensor comparison return type …
izdeby Jun 13, 2019
3b3480e
Update on "[BC-BREAKING] [WIP] Changed tensor comparison return type …
izdeby Jun 13, 2019
6947868
Update on "[BC-BREAKING] [WIP] Changed tensor comparison return type …
izdeby Jun 26, 2019
09b71b8
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jun 26, 2019
2a2d2cd
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 11, 2019
7fe64bf
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 11, 2019
cbad160
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 11, 2019
3fe4537
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 11, 2019
71b5bf3
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 16, 2019
4ade573
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 22, 2019
bcd45a2
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 24, 2019
8f555d6
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 25, 2019
5c69147
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 25, 2019
b3024a4
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 25, 2019
ebc100e
fix onnx expect files
houseroad Jul 26, 2019
1a12bc9
fix one more onnx test
houseroad Jul 26, 2019
630c5f4
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 26, 2019
2a7a70b
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 30, 2019
0d1b92f
Update on "[BC-BREAKING] Changed tensor comparison return type from u…
izdeby Jul 31, 2019
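
The user-visible effect of this change, sketched below for context (this summary is mine, not taken from the PR description): tensor comparison operators now allocate their result as torch.bool instead of torch.uint8, so code that relied on the old uint8 masks may need an explicit cast.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 2, 2])

mask = a < b                   # comparison ops now return a bool tensor
print(mask.dtype)              # torch.bool (previously torch.uint8)

# Boolean masks keep working for indexing and logical operations.
print(a[mask])                 # tensor([1])

# Code that treated the mask as an integer tensor needs an explicit cast.
counts = mask.to(torch.uint8)  # or mask.long(), depending on what the old code expected
print(counts)                  # tensor([1, 0, 0], dtype=torch.uint8)
```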
144 changes: 138 additions & 6 deletions aten/src/ATen/Declarations.cwrap
@@ -628,11 +628,33 @@
options:
- cname: ltValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: ltTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_lt_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: ltValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: ltTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -669,11 +691,33 @@
options:
- cname: gtValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: gtTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_gt_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: gtValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: gtTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -710,11 +754,33 @@
options:
- cname: leValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: leTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_le_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: leValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: leTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -751,11 +817,33 @@
options:
- cname: geValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: geTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_ge_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: geValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: geTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -792,11 +880,33 @@
options:
- cname: eqValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: eqTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_eq_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: eqValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: eqTensorByte
arguments:
- arg: THByteTensor* result
output: True
@@ -833,11 +943,33 @@
options:
- cname: neValue
arguments:
- arg: THByteTensor* result
- arg: THBoolTensor* result
output: True
- THTensor* self
- real other
- cname: neTensor
arguments:
- arg: THBoolTensor* result
output: True
- arg: THTensor* self
broadcast: other fallback
- THTensor* other
]]
[[
name: _th_ne_byte
cpu_bool: True
cuda_bool: True
variants:
- function
return: argument 0
options:
- cname: neValueByte
arguments:
- arg: THByteTensor* result
output: True
- THTensor* self
- real other
- cname: neTensorByte
arguments:
- arg: THByteTensor* result
output: True
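A rough reading of the Declarations.cwrap changes above (my summary, not wording from the PR): the default bindings for each comparison op (`ltValue`/`ltTensor` and friends) now write into a `THBoolTensor*` result, while new `_th_*_byte` declarations backed by the `*ValueByte`/`*TensorByte` kernels keep a uint8 result available for the deprecated path. Both the scalar and the tensor overloads are covered; a minimal Python-level sketch of the resulting default behavior:

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 2, 1])

# Scalar overload (roughly the *Value kernels) and tensor overload (roughly the *Tensor
# kernels) both produce a bool result by default after this change.
print(torch.lt(a, 2).dtype)  # torch.bool
print(torch.lt(a, b).dtype)  # torch.bool
```
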
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Itertools.cpp
@@ -12,7 +12,7 @@ Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) {
// or i <= j <= k <= ... (depending on diagonal)
Tensor range = at::arange(n, opt.dtype(kLong));
std::vector<Tensor> index_grids = at::meshgrid(std::vector<Tensor>(dims, range));
Tensor mask = at::ones(index_grids[0].sizes(), opt.dtype(kByte));
Tensor mask = at::full(index_grids[0].sizes(), true, opt.dtype(kBool));
if(diagonal) {
for(int64_t i = 0; i < dims - 1; i++) {
mask *= index_grids[i] <= index_grids[i+1];
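The Itertools.cpp change builds the internal triu mask as a bool tensor of trues (via `at::full`) rather than a uint8 tensor of ones, so the subsequent elementwise comparisons stay in bool. A Python-level sketch of the equivalent mask construction (illustrative only; the PR itself only touches this C++ helper):

```python
import torch

size = (3, 3)

old_mask = torch.ones(size, dtype=torch.uint8)       # previous approach: uint8 ones
new_mask = torch.full(size, True, dtype=torch.bool)  # new approach: bool trues

i = torch.arange(3).view(-1, 1)
j = torch.arange(3).view(1, -1)

# Combining bool masks with a logical AND mirrors the
# `mask *= index_grids[i] <= index_grids[i+1]` pattern in the C++ code.
new_mask = new_mask & (i <= j)
print(new_mask)  # upper-triangular (including the diagonal) pattern of True
```
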
120 changes: 120 additions & 0 deletions aten/src/ATen/native/LegacyDefinitions.cpp
@@ -64,4 +64,124 @@ Tensor gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s
return legacy::cpu::_th_gather(self, dim, index);
}

Tensor & lt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_lt_byte_out(result, self, other);
} else {
return legacy::cpu::_th_lt_out(result, self, other);
}
}

Tensor & lt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_lt_byte_out(result, self, value);
} else {
return legacy::cpu::_th_lt_out(result, self, value);
}
}

Tensor & le_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_le_byte_out(result, self, other);
} else {
return legacy::cpu::_th_le_out(result, self, other);
}
}

Tensor & le_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_le_byte_out(result, self, value);
} else {
return legacy::cpu::_th_le_out(result, self, value);
}
}

Tensor & gt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_gt_byte_out(result, self, other);
} else {
return legacy::cpu::_th_gt_out(result, self, other);
}
}

Tensor & gt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_gt_byte_out(result, self, value);
} else {
return legacy::cpu::_th_gt_out(result, self, value);
}
}

Tensor & ge_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_ge_byte_out(result, self, other);
} else {
return legacy::cpu::_th_ge_out(result, self, other);
}
}

Tensor & ge_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_ge_byte_out(result, self, value);
} else {
return legacy::cpu::_th_ge_out(result, self, value);
}
}

Tensor & eq_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_eq_byte_out(result, self, other);
} else {
return legacy::cpu::_th_eq_out(result, self, other);
}
}

Tensor & eq_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_eq_byte_out(result, self, value);
} else {
return legacy::cpu::_th_eq_out(result, self, value);
}
}

Tensor & ne_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_ne_byte_out(result, self, other);
} else {
return legacy::cpu::_th_ne_out(result, self, other);
}
}

Tensor & ne_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) {
if (result.dtype() == at::ScalarType::Byte) {
AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \
"please use 'out' parameter with dtype torch.bool instead.");
return legacy::cpu::_th_ne_byte_out(result, self, value);
} else {
return legacy::cpu::_th_ne_out(result, self, value);
}
}

}} // namespace at::native
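
The `*_out_cpu` wrappers added above keep the legacy behavior reachable: a uint8 `out` tensor triggers AT_WARN and routes to the `_th_*_byte_out` kernels, while any other case goes through the new bool kernels. A sketch of what a caller saw as of this PR (later releases may change the exact warning or drop the uint8 path):

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 2, 2])

# Recommended: a bool `out` tensor, matching the new default result dtype.
out_bool = torch.empty(3, dtype=torch.bool)
torch.lt(a, b, out=out_bool)
print(out_bool)               # tensor([ True, False, False])

# Deprecated path exercised by the wrappers above: a uint8 `out` tensor is still
# accepted, emits a deprecation warning, and dispatches to the legacy byte kernels.
out_byte = torch.empty(3, dtype=torch.uint8)
torch.lt(a, b, out=out_byte)  # warns that a uint8 `out` is deprecated; prefer torch.bool
print(out_byte)               # tensor([1, 0, 0], dtype=torch.uint8)
```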
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -83,7 +83,7 @@ static void max_kernel_impl(
Tensor& max_indices,
const Tensor& self,
c10::optional<int64_t> dim) {
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "max", [&] {
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "max", [&] {
Reduction<scalar_t, int64_t>::apply(max, max_indices, self, dim, true);
});
}
@@ -93,7 +93,7 @@ static void min_kernel_impl(
Tensor& min_indices,
const Tensor& self,
c10::optional<int64_t> dim) {
AT_DISPATCH_ALL_TYPES(self.scalar_type(), "min", [&] {
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "min", [&] {
Reduction<scalar_t, int64_t>::apply(min, min_indices, self, dim, false);
});
}
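The dispatch change above (AT_DISPATCH_ALL_TYPES to AT_DISPATCH_ALL_TYPES_AND with ScalarType::Bool) extends the CPU min/max reduction kernels to bool inputs, so the bool masks produced by the new comparison ops can be reduced directly. An illustrative sketch:

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([2, 2, 2])

mask = a < b                               # tensor([ True, False, False]), dtype torch.bool
values, indices = torch.max(mask, dim=0)   # dim-wise max over a bool tensor now dispatches
print(values, indices)                     # tensor(True) tensor(0)
```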