Implement operators logical_and and logical_or #24379

Closed · wants to merge 24 commits

Conversation

xuhdev (Collaborator) commented on Aug 15, 2019:

Stack from ghstack:

Pull Request resolved: #24379

Differential Revision: D16830128
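
For context, the operators this PR proposes compute element-wise logical AND/OR, mirroring the existing torch.logical_xor. Below is a minimal usage sketch using the libtorch C++ API as it eventually shipped; it is illustrative only, since this particular PR was superseded before merging, and the exact signatures are assumptions based on the released torch::logical_and / torch::logical_or.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Boolean operands; in the released API, non-bool dtypes are also
  // accepted, with nonzero elements treated as true.
  auto a = torch::tensor({1, 0, 1, 0}, torch::kBool);
  auto b = torch::tensor({1, 1, 0, 0}, torch::kBool);

  std::cout << torch::logical_and(a, b) << "\n";  // true, false, false, false
  std::cout << torch::logical_or(a, b) << "\n";   // true, true, true, false
  return 0;
}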

pytorchbot added the labels module: cpu, module: cuda, module: docs, module: internals, and module: operators on Aug 15, 2019
xuhdev added a commit that referenced this pull request Aug 15, 2019
ghstack-source-id: d9d93b7d68d335a9f096bdf66b2b59c27ba4dc1c
Pull Request resolved: #24379
xuhdev requested a review from gchanan on August 15, 2019 00:17
xuhdev (Collaborator, Author) commented on Aug 15, 2019:

@bddppq Are you aware of an nvfunctional equivalent in ROCm?

Review thread on torch/_torch_docs.py (outdated; resolved)
@@ -81,21 +81,29 @@ void div_kernel(TensorIterator& iter) {
   }
 }

-void logical_xor_kernel(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(1), "logical_xor_cpu", [&]() {
+void logical_binary_kernel_impl(TensorIterator& iter, const char* op_name, std::function<bool(bool, bool)> op) {
A reviewer (Contributor) commented:

I don't think we should do this; not only does it not work with ROCm, but it can't be inlined.

xuhdev (Collaborator, Author) replied:

I have now changed this to templating. In my experiments, nvcc seems to be able to inline functions passed in as template parameters.

__host__ __device__ float called1(float x) {
  return x + 1;
}

__host__ __device__ float called2(float x) {
  return x + 10;
}

template <typename Op>
__host__ __device__ float caller(float x, Op op) {
  return op(x);
}

__global__ void user(float x, float& y) {
  x = caller(x, called1);
  y = caller(x, called2);
}

After nvcc -ptx -src-in-ptx -arch=sm_60 test.cu:

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_60
.address_size 64

	// .globl	_Z4userfRf

.visible .entry _Z4userfRf(
	.param .f32 _Z4userfRf_param_0,
	.param .u64 _Z4userfRf_param_1
)
{
	.reg .f32 	%f<4>;
	.reg .b64 	%rd<3>;


	ld.param.f32 	%f1, [_Z4userfRf_param_0];
	ld.param.u64 	%rd1, [_Z4userfRf_param_1];
	cvta.to.global.u64 	%rd2, %rd1;
	add.f32 	%f2, %f1, 0f3F800000;
	add.f32 	%f3, %f2, 0f41200000;
	st.global.f32 	[%rd2], %f3;
	ret;
}
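
The PTX bears this out: there are no call instructions, and both functor invocations have been folded into the kernel body as the two add.f32 instructions with the immediate constants 0f3F800000 (1.0f) and 0f41200000 (10.0f).

Applied to the kernel under review, the templated helper might look roughly as follows. This is a sketch only, not the code from this PR: cpu_kernel refers to the loop helper in ATen's native/cpu/Loops.h, and the function bodies are illustrative.

template <typename Op>
void logical_binary_kernel_impl(TensorIterator& iter, const char* op_name, Op op) {
  AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, iter.dtype(1), op_name, [&]() {
    // Op is a template parameter rather than a type-erased
    // std::function<bool(bool, bool)>, so the comparator can be inlined.
    cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> bool {
      return op(static_cast<bool>(a), static_cast<bool>(b));
    });
  });
}

// A call site then passes a plain lambda:
void logical_and_kernel(TensorIterator& iter) {
  logical_binary_kernel_impl(iter, "logical_and_cpu",
                             [](bool a, bool b) { return a && b; });
}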

Review thread on test/test_torch.py (outdated; resolved)
xuhdev requested a review from gchanan on August 15, 2019 20:36
xuhdev added a commit that referenced this pull request Aug 15, 2019
ghstack-source-id: 2b0ecfda305c32e84d7aadd54640198684e74295
Pull Request resolved: #24379
pytorchbot added the label oncall: quantization (Quantization support in PyTorch) on Aug 16, 2019
xuhdev added a commit that referenced this pull request Nov 16, 2019
Superseding #24379 as type promotion has been implemented.

Close #24379

[ghstack-poisoned]
xuhdev added a commit to xuhdev/pytorch that referenced this pull request Nov 16, 2019
Superseding pytorch#24379 as type promotion has been implemented.

Close pytorch#24379
facebook-github-bot deleted the gh/xuhdev/25/head branch on December 10, 2019 15:21
wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
Summary:
Superseding pytorch#24379 as type promotion has been implemented.

Close pytorch#24379
Pull Request resolved: pytorch#28162

Differential Revision: D18580867

Pulled By: ailzhang

fbshipit-source-id: 7e4d7c37da4dc8df87314bd4f1f6a7539e46586a
Labels
module: cpu (CPU specific problem, e.g., perf, algorithm)
module: cuda (Related to torch.cuda, and CUDA support in general)
module: docs (Related to our documentation, both in docs/ and docblocks)
module: internals (Related to internal abstractions in c10 and ATen)
oncall: quantization (Quantization support in PyTorch)
open source
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

9 participants