Discussion about cuda kernel #12

Closed

ClementPinard opened this issue Jun 11, 2018 · 11 comments

Comments
@ClementPinard
Contributor

ClementPinard commented Jun 11, 2018

Hello,

This is more of a discussion thread than a real issue, but I've been working on the CUDA kernel's readability.
PyTorch actually provides a very nice way of presenting tensor data to kernels as if it were still a multidimensional array.

See here for a working prototype: https://github.com/ClementPinard/extension-cpp/blob/deviceTensorExperiments/cuda/lltm_cuda_kernel.cu

Essentially, I designed a simple converter from at::Tensor to THCDeviceTensor<scalar_t, 2, size_t, RestrictPtrTraits>.

The conversion is not very pretty, but it lets us write more readable memory accesses in kernels while ultimately doing exactly the same thing (even the __restrict__ qualifier is kept).
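
For illustration, here is a minimal sketch of what such a converter can look like, assuming a THCDeviceTensor constructor that takes a data pointer plus size and stride arrays (the linked prototype is the authoritative version):

// Hedged sketch of an at::Tensor -> THCDeviceTensor converter; the exact
// constructor signature may differ between PyTorch versions.
template <typename scalar_t, int Dim>
THCDeviceTensor<scalar_t, Dim, size_t, RestrictPtrTraits>
toDeviceTensor(at::Tensor t) {
  size_t sizes[Dim];
  size_t strides[Dim];
  for (int i = 0; i < Dim; ++i) {
    sizes[i] = t.size(i);      // extent along dimension i
    strides[i] = t.stride(i);  // element stride along dimension i
  }
  return THCDeviceTensor<scalar_t, Dim, size_t, RestrictPtrTraits>(
      t.data<scalar_t>(), sizes, strides);
}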

Let's look at the current code for the forward pass:

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    const scalar_t* __restrict__ gates,
    const scalar_t* __restrict__ old_cell,
    scalar_t* __restrict__ new_h,
    scalar_t* __restrict__ new_cell,
    scalar_t* __restrict__ input_gate,
    scalar_t* __restrict__ output_gate,
    scalar_t* __restrict__ candidate_cell,
    size_t state_size) {
  const int column = blockIdx.x * blockDim.x + threadIdx.x;
  const int index = blockIdx.y * state_size + column;
  const int gates_row = blockIdx.y * (state_size * 3);
  if (column < state_size) {
    input_gate[index] = sigmoid(gates[gates_row + column]);
    output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
    candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
    new_cell[index] =
        old_cell[index] + candidate_cell[index] * input_gate[index];
    new_h[index] = tanh(new_cell[index]) * output_gate[index];
  }
}

The column and index variables are kind of hard to figure out. The code actually relies on the fact that gridDim.y is the batch size, and thus blockIdx.y is the batch index. column is then the index within the state, index is batch_idx * batch_stride + column, and gates_row is the first index of the gates for that particular element of the batch, whose batch stride is three times as large.
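
To make the arithmetic concrete (a hypothetical example, not from the original kernel): with state_size = 4 and blockIdx.y = 2, the thread with column = 1 writes to index = 2 * 4 + 1 = 9, while its gate values start at gates_row = 2 * 12 = 24, so it reads gates[24 + 1], gates[24 + 4 + 1], and gates[24 + 2 * 4 + 1].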

Now, my proposed code:

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    const dTensor2R gates,
    const dTensor2R old_cell,
    dTensor2R new_h,
    dTensor2R new_cell,
    dTensor2R input_gate,
    dTensor2R output_gate,
    dTensor2R candidate_cell,
    size_t state_size) {
  const int n = blockIdx.y; //batch index
  CUDA_KERNEL_LOOP(c, state_size) {
    input_gate[n][c] = sigmoid((scalar_t) gates[n][c]);
    output_gate[n][c] = sigmoid((scalar_t) gates[n][c + state_size]);
    candidate_cell[n][c] = elu((scalar_t) gates[n][c + 2 * state_size]);
    new_cell[n][c] =
        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
    new_h[n][c] = tanh((scalar_t) new_cell[n][c]) * output_gate[n][c];
  }
}

I use dTensor2R, which is defined as THCDeviceTensor<scalar_t, 2, size_t, RestrictPtrTraits> in a macro above.
Besides using the strided loop CUDA_KERNEL_LOOP (just for the sake of good practice), we now only need to compute n, which is explicitly the batch index, and c, which is the column from above.
Every relevant value can now be accessed with tensor[n][c + shift], making it very similar to an actual 2D array.
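
For completeness, here is roughly what the two macros look like; the grid-stride loop follows the common Caffe/PyTorch definition, and the linked branch has the authoritative versions:

// Hedged sketch of the macros used above.
#define dTensor2R THCDeviceTensor<scalar_t, 2, size_t, RestrictPtrTraits>

// Classic grid-stride loop: each thread strides over the columns.
#define CUDA_KERNEL_LOOP(i, n)                                     \
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);  \
       i += blockDim.x * gridDim.x)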

I tested my code on master (from a few days ago) and it works for both check.py and grad_check.py. It does not need the PyTorch source code, only the compiled binaries and the headers.

Is this proposal legitimate? I feel it could be a good way of letting people write CUDA code with more complicated ND tensors (like 4D tensors for regular feature maps) without all the complex indexing. And if so, that could be a good reason to write a more user-friendly method for at::Tensor to THCDeviceTensor conversion.

@goldsborough
Contributor

Thanks for contributing your ideas. I like how it makes the CUDA kernel shorter and more readable (assuming one knows what the macro does). It's important to note, however, that any use of TH things is not officially supported in C++ extensions. TH is a very low level backend to PyTorch and an active construction site. We remove or change things in it almost every day and there is no guarantee of any kind that THCDeviceTensor will still exist tomorrow. ATen is the only supported interface to PyTorch. It's fine to use TH things for your project as long as it works, but we won't advertise it for all users. It may be worth adding convenient functionality such as this to ATen directly, to make writing CUDA kernels easier.

@braveapple

braveapple commented Jun 22, 2018

at::Tensor max_ROI_cuda(
  at::Tensor x,
  at::Tensor ROI_size
) {
  const auto batch_size = x.size(0);
  const auto channel_num = x.size(1);
  const auto feat_height = x.size(2);
  const auto feat_width = x.size(3);

  auto ROI_pos = at::zeros({x.size(0), x.size(1)}, x.type());

  const dim3 blocksPerGrid(1); // 1 block per grid (1D) (x, )
  const dim3 threadsPerBlock(batch_size, channel_num); // batch_size * channel_num threads per block (2D) (x, y)
  
  AT_DISPATCH_FLOATING_TYPES(x.type(), "max_ROI_cuda", ([&] {
    max_ROI_cuda_kernel<scalar_t><<<blocksPerGrid, threadsPerBlock>>>(
      feat_height,
      feat_width,
      x.data<scalar_t>(),
      ROI_size.data<scalar_t>(),
      ROI_pos.data<scalar_t>()
    );
  }));

  return ROI_pos; 
}

When calling at::zeros({x.size(0), x.size(1)}, x.type()), I got two build errors:

(1) error: no instance of constructor "at::Type::Type" matches the argument list (argument types are: (int64_t, int64_t))
(2) error: no suitable user-defined conversion from "at::Type" to "at::IntList" exists

Can anybody help me fix this problem? Thanks.

@ClementPinard
Contributor Author

ClementPinard commented Jun 22, 2018

I think you should replace x.type() with x.options():

auto ROI_pos = at::zeros({x.size(0), x.size(1)}, x.options());
The key here is that the second argument is not the type but a set of options, which contains the type but also e.g. the device.
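
For instance, a minimal sketch (the device override is just an illustration; setters like .device() exist on TensorOptions, but check your version's headers):

// x.options() bundles dtype, device, and layout, so the new tensor
// is allocated on the same device as x.
auto opts = x.options();
auto ROI_pos = at::zeros({x.size(0), x.size(1)}, opts);

// Individual fields can be overridden, e.g. to allocate on the CPU:
auto on_cpu = at::zeros({4, 4}, opts.device(at::kCPU));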

More info: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/TensorOptions.h

This is done e.g. here: https://github.com/pytorch/pytorch/master/aten/src/ATen/native/SummaryOps.cpp#L36

@braveapple

Thanks. Your advice helped me a lot.

@goldsborough
Contributor

Hmm, no, this should not have been a problem: TensorOptions has an implicit constructor from Type, otherwise all such code in the wild would break. It's true that x.options() has been, for about a week now, the correct way of doing this, since it preserves the device, but x.type() should still work fine. @braveapple, your code compiles perfectly fine for me, I just tried it. Could you maybe paste the full error you got at the time?

@braveapple

braveapple commented Jun 23, 2018

Hello @goldsborough. When I used x.type(), I got this build error:


$ python step.py install
running install
running bdist_egg
running egg_info
writing space_dropout_cuda.egg-info/PKG-INFO
writing top-level names to space_dropout_cuda.egg-info/top_level.txt
writing dependency_links to space_dropout_cuda.egg-info/dependency_links.txt
reading manifest file 'space_dropout_cuda.egg-info/SOURCES.txt'
writing manifest file 'space_dropout_cuda.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'space_dropout_cuda' extension
gcc -pthread -B /home/dmt/anaconda2/compiler_compat -Wl,--sysroot=/ -fno-strict-aliasing -g -O2 -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include/TH -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include/THC -I/usr/local/cuda/include -I/home/dmt/anaconda2/include/python2.7 -c space_dropout_cuda.cpp -o build/temp.linux-x86_64-2.7/space_dropout_cuda.o -DTORCH_EXTENSION_NAME=space_dropout_cuda -std=c++11
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
/usr/local/cuda/bin/nvcc -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include/TH -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include/THC -I/usr/local/cuda/include -I/home/dmt/anaconda2/include/python2.7 -c space_dropout_cuda_kernel.cu -o build/temp.linux-x86_64-2.7/space_dropout_cuda_kernel.o -DTORCH_EXTENSION_NAME=space_dropout_cuda --compiler-options '-fPIC' -std=c++11

space_dropout_cuda_kernel.cu(182): error: no instance of constructor "at::Type::Type" matches the argument list argument types are: (int64_t, int64_t)

space_dropout_cuda_kernel.cu(182): error: no suitable user-defined conversion from "at::Type" to "at::IntList" exists

2 errors detected in the compilation of "/tmp/tmpxft_00007903_00000000-6_space_dropout_cuda_kernel.cpp1.ii".
error: command '/usr/local/cuda/bin/nvcc' failed with exit status 1

@braveapple

braveapple commented Jun 23, 2018

Hello @goldsborough. When I used x.options(), I got a similar build error:


$ python step.py install
running install
running bdist_egg
running egg_info
writing space_dropout_cuda.egg-info/PKG-INFO
writing top-level names to space_dropout_cuda.egg-info/top_level.txt
writing dependency_links to space_dropout_cuda.egg-info/dependency_links.txt
reading manifest file 'space_dropout_cuda.egg-info/SOURCES.txt'
writing manifest file 'space_dropout_cuda.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'space_dropout_cuda' extension
gcc -pthread -B /home/dmt/anaconda2/compiler_compat -Wl,--sysroot=/ -fno-strict-aliasing -g -O2 -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include/TH -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include/THC -I/usr/local/cuda/include -I/home/dmt/anaconda2/include/python2.7 -c space_dropout_cuda.cpp -o build/temp.linux-x86_64-2.7/space_dropout_cuda.o -DTORCH_EXTENSION_NAME=space_dropout_cuda -std=c++11
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
/usr/local/cuda/bin/nvcc -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include/TH -I/home/dmt/anaconda2/lib/python2.7/site-packages/torch/lib/include/THC -I/usr/local/cuda/include -I/home/dmt/anaconda2/include/python2.7 -c space_dropout_cuda_kernel.cu -o build/temp.linux-x86_64-2.7/space_dropout_cuda_kernel.o -DTORCH_EXTENSION_NAME=space_dropout_cuda --compiler-options '-fPIC' -std=c++11
space_dropout_cuda_kernel.cu(182): error: class "at::Tensor" has no member "options"

space_dropout_cuda_kernel.cu(182): error: no instance of constructor "at::Type::Type" matches the argument list argument types are: (int64_t, int64_t)

2 errors detected in the compilation of "/tmp/tmpxft_00000639_00000000-6_space_dropout_cuda_kernel.cpp1.ii".
error: command '/usr/local/cuda/bin/nvcc' failed with exit status 1

@ClementPinard
Contributor Author

What's your PyTorch version?

import torch
torch.__version__

The error class "at::Tensor" has no member "options" makes me think that your version is not very up to date.

@braveapple

braveapple commented Jun 25, 2018

@ClementPinard Thanks for your reply! My PyTorch version is 0.4.0 (the newest version).

@ClementPinard
Contributor Author

https://github.com/pytorch/pytorch/blob/v0.4.0/aten/src/ATen/test/basic.cpp

Looking at the 0.4.0 version of this code, I think you can try inverting the type and the sizes:

auto ROI_pos = at::zeros(x.type(), {x.size(0), x.size(1)});

@ClementPinard
Contributor Author

Packed tensor accessors are now a thing, thanks @t-vi! Would it be a good idea to implement them here? I just implemented them for my own extension, and they work like a charm (and are more official than THCDeviceTensor 😆). It would be nice to spread awareness of this awesome feature, which sadly has no documentation for the moment (apart from tests, e.g. here).
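
For reference, a minimal sketch of what the forward kernel could look like with packed accessors; the template arguments follow the pattern from the ATen tests, and the exact names may still change:

// Hedged sketch: forward kernel rewritten with PackedTensorAccessor
// (only the first gate is shown; the rest follows the same pattern).
template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    const at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits, size_t> gates,
    at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits, size_t> input_gate,
    size_t state_size) {
  const int n = blockIdx.y;  // batch index, as before
  CUDA_KERNEL_LOOP(c, state_size) {
    input_gate[n][c] = sigmoid(gates[n][c]);
  }
}

// Host side: the accessor is built straight from the at::Tensor, e.g.
// gates.packed_accessor<scalar_t, 2, at::RestrictPtrTraits, size_t>()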
