diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index f1831d60a77d6..2a1d2a8dcbe93 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -612,11 +612,33 @@ options: - cname: ltValue arguments: - - arg: THByteTensor* result + - arg: THBoolTensor* result output: True - THTensor* self - real other - cname: ltTensor + arguments: + - arg: THBoolTensor* result + output: True + - arg: THTensor* self + broadcast: other fallback + - THTensor* other +]] +[[ + name: _th_lt_byte + cpu_bool: True + cuda_bool: True + variants: + - function + return: argument 0 + options: + - cname: ltValueByte + arguments: + - arg: THByteTensor* result + output: True + - THTensor* self + - real other + - cname: ltTensorByte arguments: - arg: THByteTensor* result output: True @@ -653,11 +675,33 @@ options: - cname: gtValue arguments: - - arg: THByteTensor* result + - arg: THBoolTensor* result output: True - THTensor* self - real other - cname: gtTensor + arguments: + - arg: THBoolTensor* result + output: True + - arg: THTensor* self + broadcast: other fallback + - THTensor* other +]] +[[ + name: _th_gt_byte + cpu_bool: True + cuda_bool: True + variants: + - function + return: argument 0 + options: + - cname: gtValueByte + arguments: + - arg: THByteTensor* result + output: True + - THTensor* self + - real other + - cname: gtTensorByte arguments: - arg: THByteTensor* result output: True @@ -694,11 +738,33 @@ options: - cname: leValue arguments: - - arg: THByteTensor* result + - arg: THBoolTensor* result output: True - THTensor* self - real other - cname: leTensor + arguments: + - arg: THBoolTensor* result + output: True + - arg: THTensor* self + broadcast: other fallback + - THTensor* other +]] +[[ + name: _th_le_byte + cpu_bool: True + cuda_bool: True + variants: + - function + return: argument 0 + options: + - cname: leValueByte + arguments: + - arg: THByteTensor* result + output: True + - THTensor* self + - real other + - cname: leTensorByte arguments: - arg: THByteTensor* result output: True @@ -735,11 +801,33 @@ options: - cname: geValue arguments: - - arg: THByteTensor* result + - arg: THBoolTensor* result output: True - THTensor* self - real other - cname: geTensor + arguments: + - arg: THBoolTensor* result + output: True + - arg: THTensor* self + broadcast: other fallback + - THTensor* other +]] +[[ + name: _th_ge_byte + cpu_bool: True + cuda_bool: True + variants: + - function + return: argument 0 + options: + - cname: geValueByte + arguments: + - arg: THByteTensor* result + output: True + - THTensor* self + - real other + - cname: geTensorByte arguments: - arg: THByteTensor* result output: True @@ -776,11 +864,33 @@ options: - cname: eqValue arguments: - - arg: THByteTensor* result + - arg: THBoolTensor* result output: True - THTensor* self - real other - cname: eqTensor + arguments: + - arg: THBoolTensor* result + output: True + - arg: THTensor* self + broadcast: other fallback + - THTensor* other +]] +[[ + name: _th_eq_byte + cpu_bool: True + cuda_bool: True + variants: + - function + return: argument 0 + options: + - cname: eqValueByte + arguments: + - arg: THByteTensor* result + output: True + - THTensor* self + - real other + - cname: eqTensorByte arguments: - arg: THByteTensor* result output: True @@ -817,11 +927,33 @@ options: - cname: neValue arguments: - - arg: THByteTensor* result + - arg: THBoolTensor* result output: True - THTensor* self - real other - cname: neTensor + arguments: + - arg: THBoolTensor* result + output: True + - arg: 
THTensor* self + broadcast: other fallback + - THTensor* other +]] +[[ + name: _th_ne_byte + cpu_bool: True + cuda_bool: True + variants: + - function + return: argument 0 + options: + - cname: neValueByte + arguments: + - arg: THByteTensor* result + output: True + - THTensor* self + - real other + - cname: neTensorByte arguments: - arg: THByteTensor* result output: True diff --git a/aten/src/ATen/native/Itertools.cpp b/aten/src/ATen/native/Itertools.cpp index cb379df7a31a2..09920acd1ac5a 100644 --- a/aten/src/ATen/native/Itertools.cpp +++ b/aten/src/ATen/native/Itertools.cpp @@ -12,7 +12,7 @@ Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) { // or i <= j <= k <= ... (depending on diagonal) Tensor range = at::arange(n, opt.dtype(kLong)); std::vector index_grids = at::meshgrid(std::vector(dims, range)); - Tensor mask = at::ones(index_grids[0].sizes(), opt.dtype(kByte)); + Tensor mask = at::full(index_grids[0].sizes(), true, opt.dtype(kBool)); if(diagonal) { for(int64_t i = 0; i < dims - 1; i++) { mask *= index_grids[i] <= index_grids[i+1]; diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp index a2148c73cdfd8..ceba7c753697a 100644 --- a/aten/src/ATen/native/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/LegacyDefinitions.cpp @@ -64,4 +64,124 @@ Tensor gather_cpu(const Tensor & self, int64_t dim, const Tensor & index, bool s return legacy::cpu::_th_gather(self, dim, index); } +Tensor & lt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_lt_byte_out(result, self, other); + } else { + return legacy::cpu::_th_lt_out(result, self, other); + } +} + +Tensor & lt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_lt_byte_out(result, self, value); + } else { + return legacy::cpu::_th_lt_out(result, self, value); + } +} + +Tensor & le_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_le_byte_out(result, self, other); + } else { + return legacy::cpu::_th_le_out(result, self, other); + } +} + +Tensor & le_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_le_byte_out(result, self, value); + } else { + return legacy::cpu::_th_le_out(result, self, value); + } +} + +Tensor & gt_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return 
legacy::cpu::_th_gt_byte_out(result, self, other); + } else { + return legacy::cpu::_th_gt_out(result, self, other); + } +} + +Tensor & gt_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_gt_byte_out(result, self, value); + } else { + return legacy::cpu::_th_gt_out(result, self, value); + } +} + +Tensor & ge_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_ge_byte_out(result, self, other); + } else { + return legacy::cpu::_th_ge_out(result, self, other); + } +} + +Tensor & ge_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_ge_byte_out(result, self, value); + } else { + return legacy::cpu::_th_ge_out(result, self, value); + } +} + +Tensor & eq_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_eq_byte_out(result, self, other); + } else { + return legacy::cpu::_th_eq_out(result, self, other); + } +} + +Tensor & eq_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_eq_byte_out(result, self, value); + } else { + return legacy::cpu::_th_eq_out(result, self, value); + } +} + +Tensor & ne_out_cpu(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_ne_byte_out(result, self, other); + } else { + return legacy::cpu::_th_ne_out(result, self, other); + } +} + +Tensor & ne_scalar_out_cpu(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cpu::_th_ne_byte_out(result, self, value); + } else { + return legacy::cpu::_th_ne_out(result, self, value); + } +} + }} // namespace at::native diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index e988b69a72113..1d835f0590cc5 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -83,7 +83,7 @@ static void max_kernel_impl( Tensor& max_indices, 
const Tensor& self, c10::optional dim) { - AT_DISPATCH_ALL_TYPES(self.scalar_type(), "max", [&] { + AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "max", [&] { Reduction::apply(max, max_indices, self, dim, true); }); } @@ -93,7 +93,7 @@ static void min_kernel_impl( Tensor& min_indices, const Tensor& self, c10::optional dim) { - AT_DISPATCH_ALL_TYPES(self.scalar_type(), "min", [&] { + AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "min", [&] { Reduction::apply(min, min_indices, self, dim, false); }); } diff --git a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp index b101cd52954a0..d7c1aba7218a4 100644 --- a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp @@ -60,4 +60,123 @@ Tensor gather_cuda(const Tensor & self, int64_t dim, const Tensor & index, bool return legacy::cuda::_th_gather(self, dim, index); } +Tensor & lt_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_lt_byte_out(result, self, other); + } else { + return legacy::cuda::_th_lt_out(result, self, other); + } +} + +Tensor & lt_scalar_out_cuda(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.lt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_lt_byte_out(result, self, value); + } else { + return legacy::cuda::_th_lt_out(result, self, value); + } +} + +Tensor & le_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_le_byte_out(result, self, other); + } else { + return legacy::cuda::_th_le_out(result, self, other); + } +} + +Tensor & le_scalar_out_cuda(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.le received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_le_byte_out(result, self, value); + } else { + return legacy::cuda::_th_le_out(result, self, value); + } +} + +Tensor & gt_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_gt_byte_out(result, self, other); + } else { + return legacy::cuda::_th_gt_out(result, self, other); + } +} + +Tensor & gt_scalar_out_cuda(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.gt received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_gt_byte_out(result, self, value); + } else { + return 
legacy::cuda::_th_gt_out(result, self, value); + } +} + +Tensor & ge_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_ge_byte_out(result, self, other); + } else { + return legacy::cuda::_th_ge_out(result, self, other); + } +} + +Tensor & ge_scalar_out_cuda(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.ge received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_ge_byte_out(result, self, value); + } else { + return legacy::cuda::_th_ge_out(result, self, value); + } +} + +Tensor & eq_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_eq_byte_out(result, self, other); + } else { + return legacy::cuda::_th_eq_out(result, self, other); + } +} + +Tensor & eq_scalar_out_cuda(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.eq received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_eq_byte_out(result, self, value); + } else { + return legacy::cuda::_th_eq_out(result, self, value); + } +} + +Tensor & ne_out_cuda(Tensor & result, const Tensor & self, const Tensor & other) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_ne_byte_out(result, self, other); + } else { + return legacy::cuda::_th_ne_out(result, self, other); + } +} + +Tensor & ne_scalar_out_cuda(Tensor & result, const Tensor & self, const Scalar value) { + if (result.dtype() == at::ScalarType::Byte) { + AT_WARN("torch.ne received 'out' parameter with dtype torch.uint8, this behavior is now deprecated," \ + "please use 'out' parameter with dtype torch.bool instead."); + return legacy::cuda::_th_ne_byte_out(result, self, value); + } else { + return legacy::cuda::_th_ne_out(result, self, value); + } +} }} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f12429976435c..93470d95d53e6 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3539,8 +3539,8 @@ - func: ne(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_ne_out - CUDA: legacy::cuda::_th_ne_out + CPU: ne_scalar_out_cpu + CUDA: ne_scalar_out_cuda - func: ne(Tensor self, Scalar other) -> Tensor variants: method, function @@ -3550,8 +3550,8 @@ - func: ne(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
dispatch: - CPU: legacy::cpu::_th_ne_out - CUDA: legacy::cuda::_th_ne_out + CPU: ne_out_cpu + CUDA: ne_out_cuda - func: ne(Tensor self, Tensor other) -> Tensor variants: method, function @@ -3561,8 +3561,8 @@ - func: eq(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_eq_out - CUDA: legacy::cuda::_th_eq_out + CPU: eq_scalar_out_cpu + CUDA: eq_scalar_out_cuda - func: eq(Tensor self, Scalar other) -> Tensor variants: method, function @@ -3572,8 +3572,8 @@ - func: eq(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_eq_out - CUDA: legacy::cuda::_th_eq_out + CPU: eq_out_cpu + CUDA: eq_out_cuda - func: eq(Tensor self, Tensor other) -> Tensor variants: method, function @@ -3583,8 +3583,8 @@ - func: ge(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_ge_out - CUDA: legacy::cuda::_th_ge_out + CPU: ge_scalar_out_cpu + CUDA: ge_scalar_out_cuda - func: ge(Tensor self, Scalar other) -> Tensor variants: method, function @@ -3594,8 +3594,8 @@ - func: ge(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_ge_out - CUDA: legacy::cuda::_th_ge_out + CPU: ge_out_cpu + CUDA: ge_out_cuda - func: ge(Tensor self, Tensor other) -> Tensor variants: method, function @@ -3605,8 +3605,8 @@ - func: le(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_le_out - CUDA: legacy::cuda::_th_le_out + CPU: le_scalar_out_cpu + CUDA: le_scalar_out_cuda - func: le(Tensor self, Scalar other) -> Tensor variants: method, function @@ -3616,8 +3616,8 @@ - func: le(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_le_out - CUDA: legacy::cuda::_th_le_out + CPU: le_out_cpu + CUDA: le_out_cuda - func: le(Tensor self, Tensor other) -> Tensor variants: method, function @@ -3627,8 +3627,8 @@ - func: gt(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_gt_out - CUDA: legacy::cuda::_th_gt_out + CPU: gt_scalar_out_cpu + CUDA: gt_scalar_out_cuda - func: gt(Tensor self, Scalar other) -> Tensor variants: method, function @@ -3638,8 +3638,8 @@ - func: gt(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_gt_out - CUDA: legacy::cuda::_th_gt_out + CPU: gt_out_cpu + CUDA: gt_out_cuda - func: gt(Tensor self, Tensor other) -> Tensor variants: method, function @@ -3649,8 +3649,8 @@ - func: lt(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: legacy::cpu::_th_lt_out - CUDA: legacy::cuda::_th_lt_out + CPU: lt_scalar_out_cpu + CUDA: lt_scalar_out_cuda - func: lt(Tensor self, Scalar other) -> Tensor variants: method, function @@ -3660,8 +3660,8 @@ - func: lt(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
dispatch: - CPU: legacy::cpu::_th_lt_out - CUDA: legacy::cuda::_th_lt_out + CPU: lt_out_cpu + CUDA: lt_out_cuda - func: lt(Tensor self, Tensor other) -> Tensor variants: method, function diff --git a/aten/src/ATen/test/atest.cpp b/aten/src/ATen/test/atest.cpp index 95b3950569d3d..856deb63a3887 100644 --- a/aten/src/ATen/test/atest.cpp +++ b/aten/src/ATen/test/atest.cpp @@ -36,7 +36,7 @@ TEST(atest, atest) { float b = a.to(); ASSERT_EQ(b, 4); - foo = (foo * foo) == (foo.pow(3)); + foo = ((foo * foo) == (foo.pow(3))).to(kByte); foo = 2 + (foo + 1); // foo = foo[3]; auto foo_v = foo.accessor(); diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index 5b49bc0704d03..1545b0f85f60a 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -9,12 +9,12 @@ TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor); #if !defined(TH_REAL_IS_HALF) && !defined(TH_REAL_IS_BFLOAT16) -TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(geValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(neValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(eqValue)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(ltValue)(THBoolTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(leValue)(THBoolTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(gtValue)(THBoolTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(geValue)(THBoolTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(neValue)(THBoolTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(eqValue)(THBoolTensor *r_, THTensor* t, scalar_t value); TH_API void THTensor_(ltValueT)(THTensor *r_, THTensor* t, scalar_t value); TH_API void THTensor_(leValueT)(THTensor *r_, THTensor* t, scalar_t value); @@ -23,12 +23,12 @@ TH_API void THTensor_(geValueT)(THTensor *r_, THTensor* t, scalar_t value); TH_API void THTensor_(neValueT)(THTensor *r_, THTensor* t, scalar_t value); TH_API void THTensor_(eqValueT)(THTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(ltTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(leTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(gtTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(geTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(neTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(eqTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(ltTensor)(THBoolTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(leTensor)(THBoolTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(gtTensor)(THBoolTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(geTensor)(THBoolTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(neTensor)(THBoolTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(eqTensor)(THBoolTensor *r_, THTensor *ta, THTensor *tb); TH_API void THTensor_(ltTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); TH_API void THTensor_(leTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); @@ -37,6 +37,20 @@ TH_API void THTensor_(geTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); TH_API void 
THTensor_(neTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(ltValueByte)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(leValueByte)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(gtValueByte)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(geValueByte)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(neValueByte)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(eqValueByte)(THByteTensor *r_, THTensor* t, scalar_t value); + +TH_API void THTensor_(ltTensorByte)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(leTensorByte)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(gtTensorByte)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(geTensorByte)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(neTensorByte)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(eqTensorByte)(THByteTensor *r_, THTensor *ta, THTensor *tb); + TH_API accreal THTensor_(sumall)(THTensor *t); TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb); diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp index 1dc2bbb6c5630..cb1c146eec496 100644 --- a/aten/src/TH/generic/THTensorMoreMath.cpp +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -9,30 +9,30 @@ #include #endif -#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \ - void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \ - { \ - THByteTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ - TH_TENSOR_APPLY2(unsigned char, r_, scalar_t, t, \ - *r__data = (*t_data OP value) ? 1 : 0;); \ - } \ - void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, scalar_t value) \ - { \ - THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ - TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t, \ - *r__data = (*t_data OP value) ? 1 : 0;); \ - } \ - void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \ - { \ - THByteTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ - TH_TENSOR_APPLY3(unsigned char, r_, scalar_t, ta, scalar_t, tb, \ - *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ - } \ - void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \ - { \ - THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ - TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb, \ - *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ +#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \ + void THTensor_(NAME##Value)(THBoolTensor *r_, THTensor* t, scalar_t value) \ + { \ + THBoolTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ + TH_TENSOR_APPLY2(bool, r_, scalar_t, t, \ + *r__data = (*t_data OP value) ? 1 : 0;); \ + } \ + void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, scalar_t value) \ + { \ + THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ + TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t, \ + *r__data = (*t_data OP value) ? 1 : 0;); \ + } \ + void THTensor_(NAME##Tensor)(THBoolTensor *r_, THTensor *ta, THTensor *tb) \ + { \ + THBoolTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ + TH_TENSOR_APPLY3(bool, r_, scalar_t, ta, scalar_t, tb, \ + *r__data = (*ta_data OP *tb_data) ? 
1 : 0;); \ + } \ + void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \ + { \ + THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ + TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb, \ + *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ } TENSOR_IMPLEMENT_LOGICAL(lt,<) @@ -42,6 +42,27 @@ TENSOR_IMPLEMENT_LOGICAL(ge,>=) TENSOR_IMPLEMENT_LOGICAL(eq,==) TENSOR_IMPLEMENT_LOGICAL(ne,!=) +#define TENSOR_IMPLEMENT_LOGICAL_BYTE(NAME,OP) \ + void THTensor_(NAME##ValueByte)(THByteTensor *r_, THTensor* t, scalar_t value) \ + { \ + THByteTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ + TH_TENSOR_APPLY2(unsigned char, r_, scalar_t, t, \ + *r__data = (*t_data OP value) ? 1 : 0;); \ + } \ + void THTensor_(NAME##TensorByte)(THByteTensor *r_, THTensor *ta, THTensor *tb) \ + { \ + THByteTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ + TH_TENSOR_APPLY3(unsigned char, r_, scalar_t, ta, scalar_t, tb, \ + *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ + } \ + +TENSOR_IMPLEMENT_LOGICAL_BYTE(lt,<) +TENSOR_IMPLEMENT_LOGICAL_BYTE(gt,>) +TENSOR_IMPLEMENT_LOGICAL_BYTE(le,<=) +TENSOR_IMPLEMENT_LOGICAL_BYTE(ge,>=) +TENSOR_IMPLEMENT_LOGICAL_BYTE(eq,==) +TENSOR_IMPLEMENT_LOGICAL_BYTE(ne,!=) + int THTensor_(equal)(THTensor *ta, THTensor* tb) { int equal = 1; diff --git a/aten/src/THC/generic/THCTensorMathCompare.cu b/aten/src/THC/generic/THCTensorMathCompare.cu index 0581d46f63193..06d275ba97357 100644 --- a/aten/src/THC/generic/THCTensorMathCompare.cu +++ b/aten/src/THC/generic/THCTensorMathCompare.cu @@ -2,52 +2,52 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathCompare.cu" #else -void THCTensor_(ltValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +void THCTensor_(ltValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - THC_logicalValue(state, self_, src, + THC_logicalValue(state, self_, src, TensorLTValueOp(value)); + bool>(value)); } -void THCTensor_(gtValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +void THCTensor_(gtValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - THC_logicalValue(state, self_, src, + THC_logicalValue(state, self_, src, TensorGTValueOp(value)); + bool>(value)); } -void THCTensor_(leValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +void THCTensor_(leValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - THC_logicalValue(state, self_, src, + THC_logicalValue(state, self_, src, TensorLEValueOp(value)); + bool>(value)); } -void THCTensor_(geValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +void THCTensor_(geValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - THC_logicalValue(state, self_, src, + THC_logicalValue(state, self_, src, TensorGEValueOp(value)); + bool>(value)); } -void THCTensor_(eqValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +void THCTensor_(eqValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - THC_logicalValue(state, self_, src, + THC_logicalValue(state, self_, src, 
TensorEQValueOp(value)); + bool>(value)); } -void THCTensor_(neValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +void THCTensor_(neValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - THC_logicalValue(state, self_, src, + THC_logicalValue(state, self_, src, TensorNEValueOp(value)); + bool>(value)); } void THCTensor_(ltValueT)(THCState *state, THCTensor *self_, THCTensor *src, scalar_t value) @@ -98,4 +98,53 @@ void THCTensor_(neValueT)(THCState *state, THCTensor *self_, THCTensor *src, sca scalar_t>(value)); } + +void THCTensor_(ltValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); + THC_logicalValue(state, self_, src, + TensorLTValueOp(value)); +} + +void THCTensor_(gtValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); + THC_logicalValue(state, self_, src, + TensorGTValueOp(value)); +} + +void THCTensor_(leValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); + THC_logicalValue(state, self_, src, + TensorLEValueOp(value)); +} + +void THCTensor_(geValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); + THC_logicalValue(state, self_, src, + TensorGEValueOp(value)); +} + +void THCTensor_(eqValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); + THC_logicalValue(state, self_, src, + TensorEQValueOp(value)); +} + +void THCTensor_(neValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); + THC_logicalValue(state, self_, src, + TensorNEValueOp(value)); +} + #endif diff --git a/aten/src/THC/generic/THCTensorMathCompare.h b/aten/src/THC/generic/THCTensorMathCompare.h index f54d47e7f175f..aede51b512dbd 100644 --- a/aten/src/THC/generic/THCTensorMathCompare.h +++ b/aten/src/THC/generic/THCTensorMathCompare.h @@ -2,12 +2,12 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathCompare.h" #else -THC_API void THCTensor_(ltValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); -THC_API void THCTensor_(gtValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); -THC_API void THCTensor_(leValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); -THC_API void THCTensor_(geValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); -THC_API void THCTensor_(eqValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); -THC_API void THCTensor_(neValue)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(ltValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(gtValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(leValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(geValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value); +THC_API void 
THCTensor_(eqValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(neValue)(THCState *state, THCudaBoolTensor *self_, THCTensor *src, scalar_t value); THC_API void THCTensor_(ltValueT)(THCState *state, THCTensor *self_, THCTensor *src, scalar_t value); THC_API void THCTensor_(gtValueT)(THCState *state, THCTensor *self_, THCTensor *src, scalar_t value); @@ -16,5 +16,11 @@ THC_API void THCTensor_(geValueT)(THCState *state, THCTensor *self_, THCTensor * THC_API void THCTensor_(eqValueT)(THCState *state, THCTensor *self_, THCTensor *src, scalar_t value); THC_API void THCTensor_(neValueT)(THCState *state, THCTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(ltValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(gtValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(leValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(geValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(eqValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); +THC_API void THCTensor_(neValueByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src, scalar_t value); #endif diff --git a/aten/src/THC/generic/THCTensorMathCompareT.cu b/aten/src/THC/generic/THCTensorMathCompareT.cu index 559553d7b4a55..106b4d70aeab2 100644 --- a/aten/src/THC/generic/THCTensorMathCompareT.cu +++ b/aten/src/THC/generic/THCTensorMathCompareT.cu @@ -2,52 +2,52 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathCompareT.cu" #else -void THCTensor_(ltTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +void THCTensor_(ltTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THC_logicalTensor(state, self_, src1, src2, + THC_logicalTensor(state, self_, src1, src2, TensorLTOp()); + bool>()); } -void THCTensor_(gtTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +void THCTensor_(gtTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THC_logicalTensor(state, self_, src1, src2, + THC_logicalTensor(state, self_, src1, src2, TensorGTOp()); + bool>()); } -void THCTensor_(leTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +void THCTensor_(leTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THC_logicalTensor(state, self_, src1, src2, + THC_logicalTensor(state, self_, src1, src2, TensorLEOp()); + bool>()); } -void THCTensor_(geTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +void THCTensor_(geTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THC_logicalTensor(state, self_, src1, src2, + THC_logicalTensor(state, self_, src1, src2, TensorGEOp()); + bool>()); } -void THCTensor_(eqTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +void THCTensor_(eqTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2) { 
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THC_logicalTensor(state, self_, src1, src2, + THC_logicalTensor(state, self_, src1, src2, TensorEQOp()); + bool>()); } -void THCTensor_(neTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +void THCTensor_(neTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THC_logicalTensor(state, self_, src1, src2, + THC_logicalTensor(state, self_, src1, src2, TensorNEOp()); + bool>()); } void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2) @@ -98,4 +98,52 @@ void THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, T scalar_t>()); } +void THCTensor_(ltTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); + THC_logicalTensor(state, self_, src1, src2, + TensorLTOp()); +} + +void THCTensor_(gtTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); + THC_logicalTensor(state, self_, src1, src2, + TensorGTOp()); +} + +void THCTensor_(leTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); + THC_logicalTensor(state, self_, src1, src2, + TensorLEOp()); +} + +void THCTensor_(geTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); + THC_logicalTensor(state, self_, src1, src2, + TensorGEOp()); +} + +void THCTensor_(eqTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); + THC_logicalTensor(state, self_, src1, src2, + TensorEQOp()); +} + +void THCTensor_(neTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2) +{ + THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); + THC_logicalTensor(state, self_, src1, src2, + TensorNEOp()); +} + #endif diff --git a/aten/src/THC/generic/THCTensorMathCompareT.h b/aten/src/THC/generic/THCTensorMathCompareT.h index 2d5597832f958..d4387ceb2a277 100644 --- a/aten/src/THC/generic/THCTensorMathCompareT.h +++ b/aten/src/THC/generic/THCTensorMathCompareT.h @@ -2,12 +2,12 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorMathCompareT.h" #else -THC_API void THCTensor_(ltTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(gtTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(leTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(geTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(eqTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); -THC_API void THCTensor_(neTensor)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(ltTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(gtTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void 
THCTensor_(leTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(geTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(eqTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(neTensor)(THCState *state, THCudaBoolTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(ltTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(gtTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); @@ -16,4 +16,11 @@ THC_API void THCTensor_(geTensorT)(THCState *state, THCTensor *self_, THCTensor THC_API void THCTensor_(eqTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); THC_API void THCTensor_(neTensorT)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(ltTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(gtTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(leTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(geTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(eqTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); +THC_API void THCTensor_(neTensorByte)(THCState *state, THCudaByteTensor *self_, THCTensor *src1, THCTensor *src2); + #endif diff --git a/test/expect/TestScript.test_listconstruct_erasure.expect b/test/expect/TestScript.test_listconstruct_erasure.expect index 818a115e47153..0f7d470b0709e 100644 --- a/test/expect/TestScript.test_listconstruct_erasure.expect +++ b/test/expect/TestScript.test_listconstruct_erasure.expect @@ -12,7 +12,7 @@ ModelProto { Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []}, Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]}, - Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]}, + Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 9}]}, Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}]} ] } diff --git a/test/onnx/expect/TestOperators.test_equal.expect b/test/onnx/expect/TestOperators.test_equal.expect index dd518e0c3bce4..3fc69486853c9 100644 --- a/test/onnx/expect/TestOperators.test_equal.expect +++ b/test/onnx/expect/TestOperators.test_equal.expect @@ -8,16 +8,6 @@ graph { output: "2" op_type: "Equal" } - node { - input: "2" - output: "3" - op_type: "Cast" - attribute { - name: "to" - i: 2 - type: INT - } - } name: "torch-jit-export" input { name: "0" @@ -58,10 +48,10 @@ graph { } } output { - name: "3" + name: "2" type { tensor_type { - elem_type: 2 + elem_type: 9 shape { dim { dim_value: 1 diff --git a/test/onnx/expect/TestOperators.test_ge.expect b/test/onnx/expect/TestOperators.test_ge.expect index 03c5b964b258b..ccb24b28f860c 100644 --- a/test/onnx/expect/TestOperators.test_ge.expect +++ b/test/onnx/expect/TestOperators.test_ge.expect @@ -13,16 +13,6 @@ graph { output: "3" op_type: "Not" } - node { - input: "3" - output: "4" 
- op_type: "Cast" - attribute { - name: "to" - i: 2 - type: INT - } - } name: "torch-jit-export" input { name: "0" @@ -57,10 +47,10 @@ graph { } } output { - name: "4" + name: "3" type { tensor_type { - elem_type: 2 + elem_type: 9 shape { dim { dim_value: 3 diff --git a/test/onnx/expect/TestOperators.test_gt.expect b/test/onnx/expect/TestOperators.test_gt.expect index e0f7af6e6599a..733fcbcf34d69 100644 --- a/test/onnx/expect/TestOperators.test_gt.expect +++ b/test/onnx/expect/TestOperators.test_gt.expect @@ -8,16 +8,6 @@ graph { output: "2" op_type: "Greater" } - node { - input: "2" - output: "3" - op_type: "Cast" - attribute { - name: "to" - i: 2 - type: INT - } - } name: "torch-jit-export" input { name: "0" @@ -58,10 +48,10 @@ graph { } } output { - name: "3" + name: "2" type { tensor_type { - elem_type: 2 + elem_type: 9 shape { dim { dim_value: 1 diff --git a/test/onnx/expect/TestOperators.test_isnan.expect b/test/onnx/expect/TestOperators.test_isnan.expect index 941e91949fc45..63f748cca3f2b 100644 --- a/test/onnx/expect/TestOperators.test_isnan.expect +++ b/test/onnx/expect/TestOperators.test_isnan.expect @@ -7,16 +7,6 @@ graph { output: "1" op_type: "IsNaN" } - node { - input: "1" - output: "2" - op_type: "Cast" - attribute { - name: "to" - i: 2 - type: INT - } - } name: "torch-jit-export" input { name: "0" @@ -32,10 +22,10 @@ graph { } } output { - name: "2" + name: "1" type { tensor_type { - elem_type: 2 + elem_type: 9 shape { dim { dim_value: 3 diff --git a/test/onnx/expect/TestOperators.test_le.expect b/test/onnx/expect/TestOperators.test_le.expect index 7bd6b76e4f5f8..82622456aa15c 100644 --- a/test/onnx/expect/TestOperators.test_le.expect +++ b/test/onnx/expect/TestOperators.test_le.expect @@ -13,16 +13,6 @@ graph { output: "3" op_type: "Not" } - node { - input: "3" - output: "4" - op_type: "Cast" - attribute { - name: "to" - i: 2 - type: INT - } - } name: "torch-jit-export" input { name: "0" @@ -57,10 +47,10 @@ graph { } } output { - name: "4" + name: "3" type { tensor_type { - elem_type: 2 + elem_type: 9 shape { dim { dim_value: 3 diff --git a/test/onnx/expect/TestOperators.test_lt.expect b/test/onnx/expect/TestOperators.test_lt.expect index e23a15e23ef87..59fed6233be47 100644 --- a/test/onnx/expect/TestOperators.test_lt.expect +++ b/test/onnx/expect/TestOperators.test_lt.expect @@ -8,16 +8,6 @@ graph { output: "2" op_type: "Less" } - node { - input: "2" - output: "3" - op_type: "Cast" - attribute { - name: "to" - i: 2 - type: INT - } - } name: "torch-jit-export" input { name: "0" @@ -58,10 +48,10 @@ graph { } } output { - name: "3" + name: "2" type { tensor_type { - elem_type: 2 + elem_type: 9 shape { dim { dim_value: 1 diff --git a/test/onnx/expect/TestOperators.test_ne.expect b/test/onnx/expect/TestOperators.test_ne.expect index f1b9cc3b53b91..35b9240eb9107 100644 --- a/test/onnx/expect/TestOperators.test_ne.expect +++ b/test/onnx/expect/TestOperators.test_ne.expect @@ -13,16 +13,6 @@ graph { output: "3" op_type: "Not" } - node { - input: "3" - output: "4" - op_type: "Cast" - attribute { - name: "to" - i: 2 - type: INT - } - } name: "torch-jit-export" input { name: "0" @@ -63,10 +53,10 @@ graph { } } output { - name: "4" + name: "3" type { tensor_type { - elem_type: 2 + elem_type: 9 shape { dim { dim_value: 1 diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index 75a74a7a91720..9c835313406c0 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -27,7 +27,7 @@ def check_onnx_opset_operator(model, ops, 
opset_version=_export_onnx_opset_versi
     # check the schema with the onnx checker
     onnx.checker.check_model(model)

-    # check target type and attributes 
+    # check target type and attributes
     graph = model.graph
     # ops should contain an object for each node
     # in graph.node, in the right order.
@@ -64,8 +64,7 @@ class MyModule(Module):
             def forward(self, x):
                 return torch.isnan(x)

-        ops = [{"op_name" : "IsNaN"},
-               {"op_name" : "Cast", "attributes" : [{"name" : "to", "i" : 2, "type" : 2}]}]
+        ops = [{"op_name" : "IsNaN"}]
         ops = {9 : ops, 10 : ops}
         x = torch.tensor([1.0, float('nan'), 2.0])
         check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
@@ -192,7 +191,7 @@ def forward(self, x):
         class DynamicSliceModel(torch.jit.ScriptModule):
             @torch.jit.script_method
             def forward(self, x):
-                return x[1:x.size(0)] 
+                return x[1:x.size(0)]

         ops_9 = [{"op_name" : "Constant"},
                  {"op_name" : "Constant"},
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 5c4fedad21fff..06dac5d5ba83c 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -643,7 +643,7 @@ def _test_broadcast_helper(
                 tensor = tensor.cuda(rank_to_GPU[rank][0])
             dist.broadcast(tensor, src, group_id)
             self.assertEqual(tensor.size(), expected_tensor.size())
-            self.assertEqual(tensor.ne(expected_tensor).max(), 0)
+            self.assertEqual(tensor.ne(expected_tensor).max(), torch.tensor(False))
         self._barrier()

diff --git a/test/test_torch.py b/test/test_torch.py
index b6378b959a679..e76fe8ecb8e11 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6843,26 +6843,71 @@ def reference(x, k, o3, o32):
         self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k, 'F'), reference)

     def test_logical(self):
-        x = torch.rand(100, 100) * 2 - 1
-
-        xgt = torch.gt(x, 1)
-        xlt = torch.lt(x, 1)
-
-        xeq = torch.eq(x, 1)
-        xne = torch.ne(x, 1)
+        for device in torch.testing.get_all_device_types():
+            x = torch.tensor([1, 2, 3, 4], device=device)
+            self.assertEqual(x.lt(2), torch.tensor([True, False, False, False]))
+            self.assertEqual(x.le(2), torch.tensor([True, True, False, False]))
+            self.assertEqual(x.ge(2), torch.tensor([False, True, True, True]))
+            self.assertEqual(x.gt(2), torch.tensor([False, False, True, True]))
+            self.assertEqual(x.eq(2), torch.tensor([False, True, False, False]))
+            self.assertEqual(x.ne(2), torch.tensor([True, False, True, True]))
+
+            b = torch.tensor([2], device=device)
+            self.assertEqual(x.lt(b), torch.tensor([True, False, False, False]))
+            self.assertEqual(x.le(b), torch.tensor([True, True, False, False]))
+            self.assertEqual(x.ge(b), torch.tensor([False, True, True, True]))
+            self.assertEqual(x.gt(b), torch.tensor([False, False, True, True]))
+            self.assertEqual(x.eq(b), torch.tensor([False, True, False, False]))
+            self.assertEqual(x.ne(b), torch.tensor([True, False, True, True]))
+
+
+            with warnings.catch_warnings(record=True) as warningsCount:
+                byteRes = torch.empty_like(x, device=device).byte()
+                boolRes = torch.empty_like(x, device=device).bool()
+
+                torch.lt(x, b, out=byteRes)
+                torch.lt(x, b, out=boolRes)
+                self.assertEqual(byteRes.bool(), boolRes)
+
+                torch.le(x, b, out=byteRes)
+                torch.le(x, b, out=boolRes)
+                self.assertEqual(byteRes.bool(), boolRes)
+
+                torch.ge(x, b, out=byteRes)
+                torch.ge(x, b, out=boolRes)
+                self.assertEqual(byteRes.bool(), boolRes)
+
+                torch.gt(x, b, out=byteRes)
+                torch.gt(x, b, out=boolRes)
+                self.assertEqual(byteRes.bool(), boolRes)
+
+                torch.eq(x, b, out=byteRes)
+                torch.eq(x, b, out=boolRes)
+                self.assertEqual(byteRes.bool(), boolRes)
+
+                torch.ne(x, b, out=byteRes)
+                torch.ne(x, b, out=boolRes)
+                self.assertEqual(byteRes.bool(), boolRes)
+
+            self.assertEqual(len(warningsCount), 6)
+
+            # Bool Tensor
+            x = torch.tensor([True, False, True, False], device=device)
+            self.assertEqual(x.lt(True), torch.tensor([False, True, False, True]))
+            self.assertEqual(x.le(True), torch.tensor([True, True, True, True]))
+            self.assertEqual(x.ge(True), torch.tensor([True, False, True, False]))
+            self.assertEqual(x.gt(True), torch.tensor([False, False, False, False]))
+            self.assertEqual(x.eq(True), torch.tensor([True, False, True, False]))
+            self.assertEqual(x.ne(True), torch.tensor([False, True, False, True]))

-        neqs = xgt + xlt
-        all = neqs + xeq
-        self.assertEqual(neqs.long().sum(), xne.long().sum(), 0)
-        self.assertEqual(x.nelement(), all.long().sum())

     def test_isfinite(self):
         x = torch.Tensor([1, inf, 2, -inf, nan, -10])
-        self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 0, 1, 0, 0, 1]))
+        self.assertEqual(torch.isfinite(x), torch.BoolTensor([True, False, True, False, False, True]))

     def test_isfinite_int(self):
         x = torch.tensor([1, 2, 3])
-        self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 1, 1]))
+        self.assertEqual(torch.isfinite(x), torch.BoolTensor([True, True, True]))

     def test_isfinite_type(self):
         with self.assertRaises(TypeError):
@@ -11525,10 +11570,6 @@ def test_bitwise_ops(self):
                 else:
                     self.assertFalse(x[idx] ^ y[idx])

-            invert_result = ~x
-            for idx in iter_indices(x):
-                self.assertEqual(1 - x[idx], invert_result[idx])
-
             x_clone = x.clone()
             x_clone &= y
             self.assertEqual(x_clone, and_result)
@@ -11541,9 +11582,20 @@ def test_bitwise_ops(self):
             x_clone ^= y
             self.assertEqual(x_clone, xor_result)

-    def test_invert(self):
-        x = torch.ByteTensor([0, 1, 1])
-        self.assertEqual((~x).tolist(), [1, 0, 0])
+    def test_op_invert(self):
+        res = 0xffff - torch.arange(127, dtype=torch.int8)
+        for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
+            a = torch.arange(127, dtype=dtype)
+            self.assertEqual(res.to(dtype), ~a)
+
+        self.assertEqual(torch.tensor([True, False]),
+                         ~torch.tensor([False, True]))
+
+        # test exceptions
+        for dtype in (torch.half, torch.float, torch.double):
+            a = torch.zeros(10, dtype=dtype)
+            with self.assertRaises(TypeError):
+                b = ~a

     def test_apply(self):
         x = torch.arange(1, 6)
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index fda08869eb112..10f90958cadaf 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -304,14 +304,14 @@ static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) {
 static Tensor dispatch_invert(const Tensor & self) {
   AutoNoGIL no_gil;
   OptionalDeviceGuard device_guard(device_of(self));
-  return 1 - self;
+  return self.bitwise_not();
 }

 static PyObject * THPVariable_invert(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  if (self_.scalar_type() != at::kByte) {
-    throw TypeError("~ (operator.invert) is only implemented on byte tensors");
+  if (!isIntegralType(self_.scalar_type()) && self_.scalar_type() != at::kBool) {
+    throw TypeError("~ (operator.invert) is only implemented on integer and Boolean-type tensors");
   }
   return THPVariable_Wrap(dispatch_invert(self_));
   END_HANDLE_TH_ERRORS
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 0fd9ba50f3d42..bc49b9bb5076a 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -959,7 +959,7 @@ class ShapePropagator {
         [this](Node* node) -> type_vec_t {
           if (auto maybe_tensor_types = gatherTensorTypes(node)) {
-            return {broadcast(*maybe_tensor_types, 0)->toScalarType(at::kByte)};
+            return {broadcast(*maybe_tensor_types, 0)->toScalarType(at::kBool)};
           }
           return {};
         }};
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index ad5022b4d73ce..f091f9c5ce3c7 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -797,7 +797,7 @@ const std::vector<std::string> functions = {
             other_size = other.size()
             def backward(grad_output):
                 grad_self = (grad_output * condition.type_as(grad_output))._grad_sum_to_size(self_size)
-                grad_other = (grad_output * (1 - condition).type_as(grad_output))._grad_sum_to_size(other_size)
+                grad_other = (grad_output * (condition.bitwise_not()).type_as(grad_output))._grad_sum_to_size(other_size)
                 return None, grad_self, grad_other
             return torch.where(condition, self, other), backward
diff --git a/torch/functional.py b/torch/functional.py
index c936c441370b5..7c045a11016d6 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -209,12 +209,12 @@ def isfinite(tensor):
         tensor (Tensor): A tensor to check

     Returns:
-        Tensor: A ``torch.ByteTensor`` containing a 1 at each location of finite elements and 0 otherwise
+        Tensor: A ``torch.Tensor`` with dtype ``torch.bool`` containing True at each location of finite elements and False otherwise

     Example::

        >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
-       tensor([ 1,  0,  1,  0,  0], dtype=torch.uint8)
+       tensor([True, False, True, False, False])
     """
     if not isinstance(tensor, torch.Tensor):
         raise TypeError("The argument is not a tensor: {}".format(repr(tensor)))
@@ -224,7 +224,7 @@ def isfinite(tensor):
     # have a similar concept. It's safe to assume any created LongTensor doesn't
     # overflow and it's finite.
     if not tensor.is_floating_point():
-        return torch.ones_like(tensor, dtype=torch.uint8)
+        return torch.ones_like(tensor, dtype=torch.bool)
     return (tensor == tensor) & (tensor.abs() != inf)
@@ -235,17 +235,17 @@ def isinf(tensor):
         tensor (Tensor): A tensor to check

     Returns:
-        Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `+/-INF` elements and 0 otherwise
+        Tensor: A ``torch.Tensor`` with dtype ``torch.bool`` containing True at each location of `+/-INF` elements and False otherwise

     Example::

         >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
-        tensor([ 0,  1,  0,  1,  0], dtype=torch.uint8)
+        tensor([False, True, False, True, False])
     """
     if not isinstance(tensor, torch.Tensor):
         raise TypeError("The argument is not a tensor: {}".format(repr(tensor)))
     if tensor.dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
-        return torch.zeros_like(tensor, dtype=torch.uint8)
+        return torch.zeros_like(tensor, dtype=torch.bool)
     return tensor.abs() == inf
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index 351272b879be8..c15bcdd0c4ed1 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -98,12 +98,10 @@ def _comparison_operator(g, input, other, op_name):

 # NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
 # integer input type not supported in opset8. Cast to float if possible.
-@wrap_logical_op_with_cast_to('Byte')
 def gt(g, input, other):
     return _comparison_operator(g, input, other, "Greater")


-@wrap_logical_op_with_cast_to('Byte')
 def lt(g, input, other):
     return _comparison_operator(g, input, other, "Less")
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index f8518bff9c404..bc9cee79e14f8 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -215,7 +215,7 @@ def symbolic(g, self, dim=None, keepdim=None):
             # dim-reduce path
             desc = 'is' if allow_multi_dim_support else 'i'
             dim, keepdim = sym_help._get_const(dim, desc, 'dim'), sym_help._get_const(keepdim, 'i', 'keepdim')
-            dim_list = dim if allow_multi_dim_support else [dim] 
+            dim_list = dim if allow_multi_dim_support else [dim]
             return g.op(onnx_op_name, self, axes_i=dim_list, keepdims_i=keepdim)
     return symbolic
@@ -769,18 +769,15 @@ def wrap_with_not(g, input, other):
     return wrap_with_not


-@wrap_logical_op_with_cast_to('Byte')
 def eq(g, self, other):
     return g.op("Equal", self, other)


-@wrap_logical_op_with_cast_to('Byte')
 @wrap_logical_op_with_negation
 def ne(g, self, other):
     return g.op("Equal", self, other)


-@wrap_logical_op_with_cast_to('Byte')
 def gt(g, input, other):
     return gt_impl(g, input, other)

@@ -790,7 +787,6 @@ def gt_impl(g, input, other):
     return g.op("Greater", input, sym_help._if_scalar_type_as(g, other, input))


-@wrap_logical_op_with_cast_to('Byte')
 def lt(g, input, other):
     return lt_impl(g, input, other)

@@ -800,14 +796,12 @@ def lt_impl(g, input, other):
     return g.op("Less", input, sym_help._if_scalar_type_as(g, other, input))


-@wrap_logical_op_with_cast_to('Byte')
 @wrap_logical_op_with_negation
 def ge(g, input, other):
     other = sym_help._maybe_get_scalar(other)
     return lt_impl(g, input, sym_help._if_scalar_type_as(g, other, input))


-@wrap_logical_op_with_cast_to('Byte')
 @wrap_logical_op_with_negation
 def le(g, input, other):
     other = sym_help._maybe_get_scalar(other)
@@ -1618,7 +1612,6 @@ def nonzero(g, input):
 @parse_args('v')
 def isnan(g, input):
     output = g.op('IsNaN', input)
-    output = sym_help._cast_func_template(sym_help.cast_pytorch_to_onnx['Byte'], g, output, None)
     return output
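
Note for reviewers: the snippets below sketch the user-visible behavior this patch targets. They assume a build that includes these changes; names like byte_out and bool_out are illustrative, and the expectation that the C++-side AT_WARN deprecation surfaces as a recordable Python warning (as the new test_logical exercises) is an assumption, not a guarantee of this snippet.

import warnings
import torch

x = torch.tensor([1, 2, 3, 4])

# Comparisons now produce torch.bool results instead of torch.uint8.
assert x.gt(2).dtype == torch.bool
assert x.gt(2).tolist() == [False, False, True, True]

# A uint8 'out' tensor still works, but should hit the deprecation path and warn.
byte_out = torch.empty(4, dtype=torch.uint8)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.gt(x, 2, out=byte_out)
assert len(caught) >= 1  # deprecated byte-out overload

# A bool 'out' tensor is the supported replacement and agrees elementwise.
bool_out = torch.empty(4, dtype=torch.bool)
torch.gt(x, 2, out=bool_out)
assert bool_out.equal(byte_out.bool())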
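
The broadened `~` operator can be sanity-checked the same way. A minimal sketch, assuming the new dispatch in python_variable_methods.cpp (bitwise_not for integer dtypes, logical negation for bool, TypeError for floating point):

import torch

a = torch.arange(5, dtype=torch.int32)
assert (~a).tolist() == [-1, -2, -3, -4, -5]  # two's complement: ~a == -1 - a

flags = torch.tensor([True, False])
assert (~flags).tolist() == [False, True]     # logical not on bool tensors

try:
    ~torch.zeros(3)                           # float input is rejected per this patch
    raise AssertionError("expected TypeError")
except TypeError:
    pass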
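
Finally, a sketch matching the updated torch.isfinite / torch.isinf docstrings in torch/functional.py, which now describe torch.bool masks (again assuming a post-patch build):

import torch

t = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
assert torch.isfinite(t).dtype == torch.bool
assert torch.isfinite(t).tolist() == [True, False, False, False]
assert torch.isinf(t).tolist() == [False, True, True, False]

# Integral inputs are always finite, so an all-True bool tensor comes back.
assert torch.isfinite(torch.tensor([1, 2, 3])).tolist() == [True, True, True]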