Added flip() fn in ATen (CPU + CUDA) #7873

Merged 967 commits into pytorch:master on Jun 16, 2018

Conversation

@weiyangfb (Contributor) commented May 26, 2018:

Summary:

  1. fixes flip a Tensor #229
  2. implemented torch.flip() to reverse a tensor (contiguous or non-contiguous) along the specified dimensions
  3. implemented forward and backward functions for both CPU and CUDA
  4. added tests in test_torch, test_cuda, and test_autograd

Details:
Given that a tensor element's offset = sum_i indices[i] * strides[i], we can flip the indices of each element and then copy its value to the corresponding offset.

Usage:
x = torch.arange(8).view(2, 2, 2).flip(0, 1, 2) # flip along the 1st, 2nd, and 3rd dimensions
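
To make the offset arithmetic in the Details section concrete, here is a minimal Python sketch (an illustration only, not the PR's actual kernel) that flips a contiguous tensor by recomputing destination offsets; the final assert reuses the usage example above.

```
import torch

def flip_by_offsets(t, dims):
    # Illustration of offset = sum_i indices[i] * strides[i]: for every source element,
    # flip its indices along `dims` and write the value to the resulting offset.
    src = t.contiguous()
    out = torch.empty_like(src)
    sizes, strides = src.size(), src.stride()
    flat_src, flat_out = src.view(-1), out.view(-1)
    for linear in range(src.numel()):
        # recover the multi-dimensional indices from the linear index
        rem, indices = linear, []
        for s in strides:
            indices.append(rem // s)
            rem = rem % s
        # flip the requested indices and recompute the destination offset
        dst = sum((sizes[d] - 1 - i if d in dims else i) * strides[d]
                  for d, i in enumerate(indices))
        flat_out[dst] = flat_src[linear]
    return out

x = torch.arange(8).view(2, 2, 2)
assert torch.equal(flip_by_offsets(x, (0, 1, 2)), x.flip(0, 1, 2))
```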

Future work:

  1. use thrust to speed up CUDA implementation

@ivan-bilan:

Great, can't wait for this to be released.

@sethah (Contributor) commented May 27, 2018:

Will this need an entry in _torch_docs.py?

@ngimel (Collaborator) commented May 28, 2018:

Nice work!
Please unify the dimension error checking for the CUDA and CPU versions (right now it's 50 lines of copy-pasted code).
For the CUDA implementation, please run a collapseDims pass on the input (see https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/TensorInfo.cuh), so that e.g. a last-dimension flip of a multi-D contiguous tensor is the same as a dimension flip of a 2D tensor.
Also, instead of implementing a specialized kernel for this, you can create a TensorInfo object for the flipped tensor with negative strides for the flipped dimensions (negative strides are generally not supported, and the TensorInfo IndexType is usually unsigned, but you can instantiate it with a signed type) and run kernelPointwiseApply2 from CUDAApplyUtils.cuh with CopyOp. That way, you don't have to reimplement the indexToOffset and back functions (TensorInfo already has them), and you don't have to materialize an indices tensor (which is really bad for performance).
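
The negative-stride idea can be previewed outside ATen. Below is a small NumPy sketch (NumPy is used only because it exposes raw strides; this is not the ATen code being requested) showing that a flip is just a view whose base pointer is shifted to the end of the flipped dimension and whose stride for that dimension is negated, which a pointwise copy kernel can then materialize.

```
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(6, dtype=np.float32).reshape(2, 3)

# Flip dim 1: shift the base pointer by (size - 1) * stride and negate that stride.
flipped_view = as_strided(x[:, -1:],              # base now points at the last column
                          shape=x.shape,
                          strides=(x.strides[0], -x.strides[1]))
assert (flipped_view == x[:, ::-1]).all()

# Copying the strided view into a fresh contiguous buffer materializes the flipped
# tensor, which is what a pointwise copy with a signed-stride TensorInfo would do.
out = np.ascontiguousarray(flipped_view)
```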

@fmassa (Member) commented May 28, 2018:

To follow up on my previous post, here is an (untested) implementation of flip that uses a combination of meshgrid and advanced indexing. It might be good to benchmark it against the current implementation.

import torch

def multi_meshgrid(*args):
    """
    Creates a meshgrid from possibly many
    elements (instead of only 2).
    Returns a tuple of n-d tensors, each with as
    many dimensions as there are arguments.
    """
    args = list(args)
    template = [1 for _ in args]
    for i in range(len(args)):
        n = args[i].shape[0]
        template_copy = template.copy()
        template_copy[i] = n
        args[i] = args[i].view(*template_copy)
        # there will be some broadcast magic going on
    return tuple(args)

def flip(tensor, dims):
    if not isinstance(dims, (tuple, list)):
        dims = [dims]
    indices = [torch.arange(tensor.shape[dim] - 1, -1, -1,
        dtype=torch.int64) for dim in dims]
    multi_indices = multi_meshgrid(*indices)
    # start from full slices, then replace the flipped dims with index tensors
    final_indices = [slice(None) for _ in tensor.shape]
    for i, dim in enumerate(dims):
        final_indices[dim] = multi_indices[i]
    flipped = tensor[tuple(final_indices)]
    # need to permute the final dimensions
    # if dims is not consecutive, but I'm lazy
    # now :-)
    return flipped

@weiyangfb (Contributor, Author):

@sethah Yes, I agree an entry should be added to _torch_docs.py; I will do that after the code is finalized.

@weiyangfb (Contributor, Author):

@fmassa I just gave it a try, and your implementation is indeed much faster!! Here are some results:

Your implementation:

data = torch.arange(1000000).view(1000,1000)
%timeit flip(data, (0,1))
----------------------------

100 loops, best of 3: 7.62 ms per loop

My implementation:

data = torch.arange(1000000).view(1000,1000)
%timeit data.flip(0,1)
----------------------------

100 loops, best of 3: 19.5 ms per loop

@weiyangfb (Contributor, Author):

@ngimel Thanks a ton for the great suggestions! I will modify the CUDA implementation to use TensorInfo. Could you help me understand how to apply negative strides for the flipped dimensions, and why a signed IndexType is important here? Can I also ask why the collapseDims pass is relevant to flipping an nD tensor (maybe with a simple example)?

And yes, I will reuse the error checks.

@weiyangfb (Contributor, Author):

@fmassa Your implementation is very nice and only requires one copy of the input tensor. Can I translate your code into the CPU implementation of flip() using tensor.index()?

@fmassa (Member) commented May 30, 2018:

@weiyangfb definitely! My implementation also works on the GPU; it might be good to benchmark it there as well to see how it compares to the dedicated kernel.
It should probably be a bit slower because I called arange on the CPU, but that could be called on the GPU as well.

The benefit of writing it as a native function is that we don't need to implement a backward pass for it (even though the backward of flip is simply a flip).
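
As a quick sanity check of that last point, the sketch below (written against the flip() added in this PR, or any recent PyTorch) verifies that the gradient of flip is just the flipped upstream gradient.

```
import torch

x = torch.randn(2, 3, requires_grad=True)
grad_out = torch.randn(2, 3)

# y[i, j] = x[1 - i, 2 - j], so dL/dx is grad_out with the same dims flipped.
x.flip(0, 1).backward(grad_out)
assert torch.equal(x.grad, grad_out.flip(0, 1))
```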

@ngimel (Collaborator) commented May 30, 2018:

collapseDims is important because it reduces the amount of indexing math you have to do. Suppose you have a contiguous 4D tensor where you want to flip the last dimension. You can collapse the first 3 dims to view this tensor as 2D; then your indexing math is simpler (you have to loop over just 2 dimensions). If you are flipping multiple dimensions, applying collapseDims is much trickier (it may be impossible if your flip dimensions are not contiguous, say you want to flip 0 and 2), but for a single flipped dimension collapseDims should help.
Now, to negative strides. Suppose you want to flip a 1D tensor. You can create a TensorInfo object with the data pointer pointing to the end of your output tensor, set the stride of the 0-th dimension to -1, copy your original tensor to the tensor described by this TensorInfo object (using the standard pointwiseApply kernel that's already in ATen), and then view the result as a contiguous tensor. Similarly for flips in other dimensions or multiple flipped dimensions: you'd have to move the base pointer and set negative strides for the dimensions you want to flip, but that is CPU code, not GPU. Obviously, since you want negative strides, you cannot use an unsigned IndexType for those values.
That said, it is quite possible that @fmassa's implementation already achieves a good fraction of peak bandwidth; you should benchmark it first (not the absolute time, but what bandwidth you achieve compared to the maximum on your card), in which case you can just use it.
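
A tiny sanity check of the collapsing argument (my own sketch, not ATen's collapseDims pass): flipping the last dimension of a contiguous 4D tensor gives the same result as collapsing the first three dims and flipping dim 1 of the resulting 2D view.

```
import torch

t = torch.arange(2 * 3 * 4 * 5).view(2, 3, 4, 5)

# Collapse the leading dims, flip the (now) last dim, and reshape back.
assert torch.equal(t.flip(3), t.view(-1, 5).flip(1).view_as(t))
```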

@weiyangfb (Contributor, Author):

@fmassa You are completely right! I translated your code and tested it out. Now the CPU performance is similar, while on the GPU my current implementation is slightly faster.

data = torch.arange(1000000).view(1000,1000)
%timeit flip_meshgrid(data, (0,1))
--------------------------------------------

100 loops, best of 3: 10.8 ms per loop

data = torch.arange(1000000).view(1000,1000)
%timeit data.flip(0,1)
--------------------------------------------

100 loops, best of 3: 11 ms per loop

data_cuda = torch.arange(1000000, device=torch.device('cuda')).view(1000,1000)
%timeit flip_meshgrid(data_cuda, (0,1))
--------------------------------------------

1000 loops, best of 3: 1.72 ms per loop

data_cuda = torch.arange(1000000, device=torch.device('cuda')).view(1000,1000)
%timeit data_cuda.flip(0,1)
--------------------------------------------

1000 loops, best of 3: 637 µs per loop

@fmassa (Member) commented May 30, 2018:

Nice, thanks for the benchmarks @weiyangfb! Can you also try adding a torch.cuda.synchronize() when benchmarking the CUDA kernels? Also, I'd be curious to know what fraction of the time was spent on the indexing and what fraction on the torch.arange. Would it be possible to check that as well?

Thanks!

@weiyangfb (Contributor, Author) commented May 30, 2018:

@ngimel Thanks a lot for the detailed instructions! Even though collapseDims might not help in the case of an nD flip, I'd love to use it to speed up the single-dimension flip. So if I understand correctly, I will probably need to apply collapseDims to the nD input tensor with the dim to be flipped excluded, which gives a 2D tensor. Then I will need to use IndexToOffset along with a negative stride to move elements from the src to the dst tensor. One quick question: setting the stride of the 0-th dimension to -1 works for a 1D tensor, so what formula works for a 2D tensor?

Currently I have removed the materialized indices in the CUDA kernel and tested it. I am still not quite sure how to measure GPU bandwidth; here are some numbers from nvidia-smi --query-gpu=gpu_name,gpu_bus_id,utilization.gpu,utilization.memory,memory.used --format=csv -l

@fmassa implementation:

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
%timeit flip_meshgrid(data_cuda, (0,1))
-------------------------------------------------------------------
Tesla K40m, 00000000:28:00.0, 83 %, 29 %, 478 MiB

1000 loops, best of 3: 1.72 ms per loop

My implementation with materialized indices:

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
%timeit data_cuda.flip(0,1)
-------------------------------------------------------------------
Tesla K40m, 00000000:28:00.0, 90 %, 73 %, 478 MiB

1000 loops, best of 3: 637 µs per loop

My current implementation without materialized indices:

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
%timeit data_cuda.flip(0,1)
-------------------------------------------------------------------
Tesla K40m, 00000000:28:00.0, 85 %, 36 %, 463 MiB

1000 loops, best of 3: 357 µs per loop

@ngimel (Collaborator) commented May 31, 2018:

@weiyangfb, correct about collapseDims. As for bandwidth, you can compute it as bytes/time; in your case that's 8e6/357e-6 = 22.4 GB/s, which is not that great. K40 peak bandwidth is around 200 GB/s. (8e6 because your tensor has 1 million elements, 4 bytes per element, and each element has to be read and written, hence 4*2.) You can also compare your time with e.g. the time for a pointwise operation on a tensor of the same size, e.g. a *= 2.
For an nD tensor, for each dimension that you are flipping you have to shift the base pointer by (dim[i]-1)*stride[i] and set the stride to -stride[i], where stride[i] are the strides of a contiguous tensor of those dimensions. At least I think so, please check my math.
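
For reference, the bandwidth arithmetic above can be written out in a few lines of Python (the element size and timing come from the numbers quoted in this thread, not a new measurement):

```
numel = 1000 * 1000        # elements in the benchmark tensor
bytes_per_element = 4      # float32
time_s = 357e-6            # measured kernel time from the benchmark above

# Each element is read once and written once, hence the factor of 2.
achieved = numel * bytes_per_element * 2 / time_s
print("%.1f GB/s" % (achieved / 1e9))   # ~22.4 GB/s vs. ~200 GB/s peak on a K40
```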

@weiyangfb (Contributor, Author):

@pytorchbot retest this please

@weiyangfb (Contributor, Author):

@ngimel Huge thanks for walking me through the CUDA performance details and the formula for flipping! For an nD tensor, I think your math is correct. Thanks a lot for sharing the formula! I will implement this for the case of flipping a single dim and update the PR in a bit.

This performance analysis is super helpful! I will keep tracking it. I am also trying to use a.t().contiguous() as a benchmark. Is it going to be a tighter lower bound, since flip() requires similar non-contiguous memory access?

@ngimel (Collaborator) commented Jun 4, 2018:

For flipping, save for some alignment issues (which can certainly be avoided, e.g. by using a 1024x1024 tensor), your accesses are still contiguous (elements that are adjacent in the original tensor are still adjacent in the flipped one, even if they are in a different order), so comparing with a regular pointwise op is better. You might want to run your comparison against a real 2D tensor that cannot be collapsed to 1D (you can create one e.g. by running torch.chunk on the 1st dim), to include the index math that you necessarily have for flipping.

@weiyangfb (Contributor, Author):

@ngimel Thanks a lot! Now it all makes sense! I am using TensorInfo and collapseDims to speed up the cases where the flip dim is the 1st or last dim. Here are some performance results:

data_cuda = torch.arange(1000000, device=cuda).view(100,100,100)
%timeit data_cuda.flip(0)
----------------------------------
10000 loops, best of 3: 178 µs per loop
data_cuda = torch.arange(1000000, device=cuda).view(100,100,100)
%timeit data_cuda.flip(2)
----------------------------------
10000 loops, best of 3: 181 µs per loop

benchmark:

data_cuda = torch.arange(1000000, device=cuda).view(100,100,100)
%timeit data_cuda.mul(2)
----------------------------------
10000 loops, best of 3: 86.3 µs per loop

And if I understand correctly, collapseDims might not be able to squeeze an nD tensor down to 2D if the flip dim is not the 1st or last dim, so I am using the previous implementation for those cases.

data_cuda = torch.arange(1000000, device=cuda).view(100,100,100)
%timeit data_cuda.flip(1)
----------------------------------
1000 loops, best of 3: 364 µs per loop

@weiyangfb (Contributor, Author) commented Jun 13, 2018:

@fmassa Using torch.cuda.synchronize() does not change the runtime much; am I doing it correctly?

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
def meshgrid():
    flip_meshgrid(data_cuda, (0, 1))
    torch.cuda.synchronize()
%timeit meshgrid()

1000 loops, best of 3: 1.76 ms per loop

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
def flip():
    data_cuda.flip(0,1)
    torch.cuda.synchronize()
%timeit flip()

1000 loops, best of 3: 353 µs per loop

I don't know why, though. Can I ask how one would normally profile the fraction of time spent?
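
One possible way to split the time (a sketch, not something from this thread) is to time each piece with CUDA events, which make the synchronization explicit; the same helper can wrap the flip_meshgrid call from earlier in the thread to separate the arange/meshgrid setup from the indexing itself.

```
import torch

def timed(fn, iters=100):
    # Requires a CUDA device; the events measure the GPU time of the work launched between them.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # milliseconds per call

data = torch.arange(1000000, device='cuda').view(1000, 1000)
arange_ms = timed(lambda: torch.arange(999, -1, -1, device='cuda'))  # index construction only
flip_ms = timed(lambda: data.flip(0, 1))                             # full flip
print("arange: %.3f ms, flip: %.3f ms" % (arange_ms, flip_ms))
```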

@weiyangfb (Contributor, Author):

@fmassa @ngimel is this PR ready for stamp?

@ngimel (Collaborator) left a comment:

Generally looks good, I'd still like to reduce the amount of integer math.

@weiyangfb (Contributor, Author):

@pytorchbot retest this please

@weiyangfb (Contributor, Author):

The failing caffe2 test seems unrelated.

@weiyangfb (Contributor, Author):

@pytorchbot retest this please

@weiyangfb (Contributor, Author):

@pytorchbot retest this please

@weiyangfb (Contributor, Author):

@pytorchbot retest this please

@weiyangfb (Contributor, Author):

The failing caffe2 and lint tests seem unrelated; can I get a stamp on this?

@soumith merged commit c9b8d85 into pytorch:master on Jun 16, 2018
@weiyangfb deleted the flip_tensor branch on June 22, 2018
@adam-dziedzic:

Can you reproduce these results:

>>> a
tensor([[1., 1., 1.],
        [1., 0., 2.]], dtype=torch.float64)
>>> torch.flip(a, [1])
tensor([[1., 1.],
        [0., 1.]], dtype=torch.float64)

?

@weiyangfb (Contributor, Author):

@adam-dziedzic Yes, I can reproduce your results. I think this is a bug. Let me create an issue for this.

@ashwhall:

@weiyangfb Does this operation copy the memory or give a view into it? I'm flipping HD video, so copying the data is a real memory bottleneck.

@soumith (Member) commented Jul 24, 2018:

@ashwhall it copies the memory over, but does it pretty efficiently.

facebook-github-bot pushed a commit that referenced this pull request Nov 8, 2018
Summary:
- a workaround for #13292; a complete fix requires investigating the root cause when using advanced indexing
- this PR brings the `flip()` CUDA implementation to the CPU kernel
- with this change:
```
>>> t = torch.randn(1, 3, 4, 5)
>>> t.flip(1, 3).shape
torch.Size([1, 3, 4, 5])
```
- performance:
```
====== with this PR ======
>>> a = torch.randn(1000, 1000)
>>> %timeit -r 100 a.flip(0, 1)
1.98 ms ± 579 µs per loop (mean ± std. dev. of 100 runs, 1000 loops each)

====== Perf at previous PR #7873 ======
100 loops, best of 3: 11 ms per loop
```
Pull Request resolved: #13344

Differential Revision: D12968003

Pulled By: weiyangfb

fbshipit-source-id: 66f434049d143a0575a35b5c983b3e0577a1a28d
@soumith mentioned this pull request on Dec 12, 2018
Successfully merging this pull request may close these issues: flip a Tensor (#229)