
Added flip() fn in ATen (CPU + CUDA) #7873

Merged
merged 967 commits into pytorch:master on Jun 16, 2018

Conversation

@weiyangfb
Copy link
Contributor

weiyangfb commented May 26, 2018

Summary:

  1. fixes #229
  2. implemented torch.flip() to reverse a tensor (contiguous and non-contiguous) along the specified dimensions
  3. implemented forward and backward functions for both CPU and CUDA
  4. added tests in test_torch, test_cuda, and test_autograd

Details:
Given that a tensor element's offset = sum_i indices[i] * strides[i], we can flip the indices for each element and then copy its value to the corresponding flipped offset.

Usage:
x = torch.arange(8).view(2, 2, 2).flip(0, 1, 2) # flip along the 1st, 2nd, and 3rd dimensions
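For reference, here is a minimal Python sketch of the offset math from the Details section (an illustration only, assuming a contiguous input; the flip_by_offset name is made up, and the PR does this arithmetic in C++/CUDA kernels):

import torch

def flip_by_offset(x, dims):
    # offset = sum_i indices[i] * strides[i]; flipping dim d maps index i -> size(d) - 1 - i
    out = torch.empty_like(x)
    flat_in, flat_out = x.view(-1), out.view(-1)
    sizes, strides = x.size(), x.stride()
    for linear in range(x.numel()):
        rem, dst = linear, 0
        for d in range(x.dim()):
            idx, rem = divmod(rem, strides[d])  # recover the index along dim d (contiguous case)
            if d in dims:
                idx = sizes[d] - 1 - idx        # flip that index
            dst += idx * strides[d]             # accumulate the flipped offset
        flat_out[dst] = flat_in[linear]
    return out

t = torch.arange(8).view(2, 2, 2)
assert torch.equal(flip_by_offset(t, (0, 1, 2)), t.flip(0, 1, 2))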

Future work:

  1. use thrust to speed up CUDA implementation
@ivan-bilan

This comment has been minimized.

Copy link

ivan-bilan commented May 27, 2018

Great, can't wait for this to be released.

@sethah

This comment has been minimized.

Copy link
Contributor

sethah commented May 27, 2018

Will this need an entry in _torch_docs.py?

@ngimel

This comment has been minimized.

Copy link
Contributor

ngimel commented May 28, 2018

Nice work!
Please unify the dimensions error checking for the CUDA and CPU versions (right now it's 50 lines of copy-pasted code).
For the CUDA implementation, please run the collapseDims pass on the input (see this file: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/TensorInfo.cuh), so that e.g. a last-dimension flip of a multi-D contiguous tensor is the same as a dimension flip of a 2d tensor.
Also, instead of implementing a specialized kernel for this, you can create a TensorInfo object for the flipped tensor with negative strides for the flipped dimensions (negative strides are generally not supported, and the TensorInfo IndexType is usually unsigned, but you can instantiate it with a signed type) and run kernelPointwiseApply2 from CUDAApplyUtils.cuh with CopyOp. That way, you don't have to reimplement the indexToOffset and back functions (TensorInfo already has them), and don't have to materialize an indices tensor (which is really bad for performance).
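A rough NumPy illustration of the negative-stride point (not ATen code): a flipped view is the same memory with the base pointer moved to the end of the flipped dimension and that dimension's stride negated, which is why a signed IndexType is needed.

import numpy as np

a = np.arange(6, dtype=np.float32)
b = a[::-1]                  # a flip along dim 0, expressed as a strided view
print(a.strides, b.strides)  # (4,) (-4,): the flipped view needs a signed stride
print(b.base is a)           # True: no data was copied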

@fmassa

This comment has been minimized.

Copy link
Member

fmassa commented May 28, 2018

Just to follow up on my previous post, here is an (untested) implementation of flip that uses a combination of meshgrid and advanced indexing. It might be good to benchmark it against the current implementation.

import torch

def multi_meshgrid(*args):
    """
    Creates a meshgrid from possibly many
    elements (instead of only 2).
    Returns a nd tensor with as many dimensions
    as there are arguments
    """
    args = list(args)
    template = [1 for _ in args]
    for i in range(len(args)):
        n = args[i].shape[0]
        template_copy = template.copy()
        template_copy[i] = n
        args[i] = args[i].view(*template_copy)
        # there will be some broadcast magic going on
    return tuple(args)

def flip(tensor, dims):
    if not isinstance(dims, (tuple, list)):
        dims = [dims]
    indices = [torch.arange(tensor.shape[dim] - 1, -1, -1,
        dtype=torch.int64) for dim in dims]
    multi_indices = multi_meshgrid(*indices)
    final_indices = [slice(i) for i in tensor.shape]
    for i, dim in enumerate(dims):
        final_indices[dim] = multi_indices[i]
    flipped = tensor[final_indices]
    # need to permute the final dimensions
    # if dims is not consecutive, but I'm lazy
    # now :-)
    return flipped
std::stringstream ss;
ss << "expected input tensor dims not empty, "
<< "but got tensor dims size=" << flip_dims_size;
throw std::runtime_error(ss.str());

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

Please use the AT_ERROR/AT_CHECK macro, which includes backtraces in errors. Write this as AT_CHECK(flip_dims_size != 0, "expected input tensor dims not to be empty, but got tensor dims size=", flip_dims_size)

// check duplicates in dims
auto flip_dims_v = std::vector<int64_t>(dims);
flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
if ((int64_t)flip_dims_v.size() < flip_dims_size) {

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

Use AT_CHECK

// check len of dims
if (flip_dims_size > total_dims) {
std::stringstream ss;
ss << "expected flip dims size <= tensor total dims, "

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

AT_CHECK

}

if (min_d < 0) {
std::stringstream ss;

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

AT_CHECK


Tensor flip_cpu(const Tensor& self, IntList dims) {

int64_t total_dims = self.dim(), flip_dims_size = dims.size();

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

nit: const

This comment has been minimized.

@weiyangfb

weiyangfb May 30, 2018

Author Contributor

what's wrong with the const here?

This comment has been minimized.

@goldsborough

goldsborough May 30, 2018

Contributor

I suggested you use const in const int64_t total_dims = ....

This comment has been minimized.

@weiyangfb

weiyangfb May 31, 2018

Author Contributor

Ah, I see, I will change it accordingly

}

// check if dims axis within range
int64_t min_d = total_dims, max_d = 0;

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

nit: const


// check if dims axis within range
int64_t min_d = total_dims, max_d = 0;
for (auto d : dims) {

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

You can use std::min_element and std::max_element, or maybe even std::minmax_element which returns both. Should save a line and is clearer

throw std::runtime_error(ss.str());
}

Tensor out_t = self.clone();

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

_t usually indicates a type. Better just call it out or out_tensor

@@ -0,0 +1,156 @@
#include "ATen/NativeFunctions.h"
#include "ATen/ATen.h"
#include <algorithm>

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

The recommended include order is

  1. Project headers
  2. Third party headers
  3. Standard library headers
flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
if ((int64_t)flip_dims_v.size() < flip_dims_size) {
std::stringstream ss;
ss << "dims has duplicates, "

This comment has been minimized.

@goldsborough

goldsborough May 29, 2018

Contributor

AT_CHECK and probably shouldn't copy pasta

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented May 29, 2018

@sethah Yes, I agree there should be an entry added to _torch_docs.py; I will try to do that after the code is finalized.

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented May 29, 2018

@fmassa I just gave it a try, and your implementation is indeed much faster!! Here are some results:

Your implementation:

data = torch.arange(1000000).view(1000,1000)
%timeit flip(data, (0,1))
----------------------------

100 loops, best of 3: 7.62 ms per loop

My implementation:

data = torch.arange(1000000).view(1000,1000)
%timeit data.flip(0,1)
----------------------------

100 loops, best of 3: 19.5 ms per loop

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented May 30, 2018

@ngimel Thanks a ton for the great suggestions! I will modify the CUDA implementation using TensorInfo. But can I try to understand how to apply negative strides for the flipped dimensions, and why a signed IndexType is important here? Can I also ask why the collapseDims pass is relevant to flipping an nD tensor here (maybe with a simple example)?

And yes, I will reuse the error checks.

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented May 30, 2018

@fmassa Your implementation is very nice and only requires one copy of the input tensor. Can I translate your code into the CPU implementation of flip() using tensor.index()?

@fmassa

This comment has been minimized.

Copy link
Member

fmassa commented May 30, 2018

@weiyangfb definitely! My implementation also works on the GPU, it might be good to benchmark it as well to see how it compares to the dedicated kernel.
It should probably be a bit slower because I called arange on the CPU, but that could be called on the GPU as well.

The benefit of writing it as a native function is that we don't need to implement a backward pass for it (even if the backward of flip is simply a flip).
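A quick illustration of that point (hypothetical snippet, not from the PR): the composite version is differentiable for free because it is built from ops autograd already knows, and the gradient of a flip is just the upstream gradient flipped.

import torch

x = torch.arange(3., requires_grad=True)
y = x[torch.arange(2, -1, -1)]           # flip dim 0 via advanced indexing
y.backward(torch.tensor([1., 2., 3.]))
print(x.grad)                            # tensor([3., 2., 1.]): the upstream grad, flipped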

@ngimel

This comment has been minimized.

Copy link
Contributor

ngimel commented May 30, 2018

collapseDims is important because it will reduce the amount of indexing math that you have to do. Suppose you have a contiguous 4d tensor, where you want to flip the last dimension. You can collapse the first 3 dims to view this tensor as 2d, then your indexing math will be simpler (you have to loop over just 2 dimensions). If you are flipping multiple dimensions, applying collapseDims is much trickier (may be impossible, if your flip dimensions are not contiguous, say you want to flip 0 and 2), but for a single flipped dimension collapseDims should help.
Now, to negative strides. Suppose you want to flip a 1d tensor. You can create a TensorInfo object with the data pointer pointing to the end of your output tensor and the stride of the 0-th dimension set to -1, copy your original tensor to the tensor described by this TensorInfo object (using the standard pointwiseApply kernel that's already in ATen), and then view the result as a contiguous tensor. Similarly for flips in other dimensions or multiple flipped dimensions - you'd have to move the base pointer and set negative strides for the dimensions you want to flip, but that is CPU code, not GPU. Obviously, since you want negative strides, you cannot use an unsigned IndexType for those values.
That said, it is quite possible that @fmassa's implementation already achieves a good fraction of peak bandwidth; you should benchmark it first (not the absolute time, but the bandwidth you achieve compared to the maximum on your card), in which case you can just use it.
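A NumPy sketch of that recipe (illustration only; the flip_view helper is made up, but ATen's TensorInfo would carry the equivalent base offset and signed strides): flip dimension d by shifting the base pointer by (size[d] - 1) * stride[d] and negating stride[d].

import numpy as np
from numpy.lib.stride_tricks import as_strided

def flip_view(a, dims):
    index = [slice(None)] * a.ndim
    strides = list(a.strides)
    for d in dims:
        index[d] = slice(-1, None)   # move the base pointer to the last element of dim d
        strides[d] = -strides[d]     # walk that dimension backwards
    return as_strided(a[tuple(index)], shape=a.shape, strides=strides)

a = np.arange(12).reshape(3, 4)
assert (flip_view(a, (1,)) == a[:, ::-1]).all()
assert (flip_view(a, (0, 1)) == a[::-1, ::-1]).all()

# collapseDims intuition: a last-dim flip of a contiguous 4d tensor is the same
# problem as a last-dim flip of a 2d tensor.
a4 = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
assert (flip_view(a4.reshape(-1, 5), (1,)).reshape(a4.shape) == a4[..., ::-1]).all()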

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented May 30, 2018

@fmassa You are completely right! I translated your code and tested it out. Now the CPU performance is similar, while on GPU my current implementation is slightly faster.

data = torch.arange(1000000).view(1000,1000)
%timeit flip_meshgrid(data, (0,1))
--------------------------------------------

100 loops, best of 3: 10.8 ms per loop

data = torch.arange(1000000).view(1000,1000)
%timeit data.flip(0,1)
--------------------------------------------

100 loops, best of 3: 11 ms per loop

data_cuda = torch.arange(1000000, device=torch.device('cuda')).view(1000,1000)
%timeit flip_meshgrid(data_cuda, (0,1))
--------------------------------------------

1000 loops, best of 3: 1.72 ms per loop

data_cuda = torch.arange(1000000, device=torch.device('cuda')).view(1000,1000)
%timeit data_cuda.flip(0,1)
--------------------------------------------

1000 loops, best of 3: 637 µs per loop

@fmassa

This comment has been minimized.

Copy link
Member

fmassa commented May 30, 2018

Nice, thanks for the benchmarks @weiyangfb! Can you also try adding a torch.cuda.synchronize() when benchmarking the CUDA kernels? Also, I'd be curious to know which fraction of the time is spent on the indexing and which on the torch.arange. Would it be possible to check that as well?

Thanks!

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented May 30, 2018

@ngimel Thanks a lot for the detailed instructions! Even though collapseDims might not help in the nD-flip case, I'd love to use it to speed up the single-dimension flip case. So if I understand correctly, I will probably need to apply collapseDims to the nD input tensor with the dim to be flipped excluded - this gives a 2D tensor. Then I will need to use IndexToOffset along with a negative stride to move elements from the src tensor to the dst tensor. One quick question: setting the stride of the 0-th dimension to -1 works for a 1D tensor, so what formula works for the 2D tensor?

I have now removed the materialized indices in the CUDA kernel and tested it. I am still not quite sure how to measure GPU bandwidth; here I looked at some numbers from nvidia-smi --query-gpu=gpu_name,gpu_bus_id,utilization.gpu,utilization.memory,memory.used --format=csv -l

@fmassa implementation:

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
%timeit flip_meshgrid(data_cuda, (0,1))
-------------------------------------------------------------------
Tesla K40m, 00000000:28:00.0, 83 %, 29 %, 478 MiB

1000 loops, best of 3: 1.72 ms per loop

My implementation with materialized indices:

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
%timeit data_cuda.flip(0,1)
-------------------------------------------------------------------
Tesla K40m, 00000000:28:00.0, 90 %, 73 %, 478 MiB

1000 loops, best of 3: 637 µs per loop

My current implementation without materialized indices:

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
%timeit data_cuda.flip(0,1)
-------------------------------------------------------------------
Tesla K40m, 00000000:28:00.0, 85 %, 36 %, 463 MiB

1000 loops, best of 3: 357 µs per loop
temp[i] = indices[i].size(0);
indices[i] = indices[i].view(IntList(temp));
}
return self.index(TensorList(indices));

This comment has been minimized.

@fmassa

fmassa May 31, 2018

Member

note that you need to permute the dimensions if they are not consecutive.
So for example, if you have

a = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
b = a.flip(0, 2)

this implementation will need to transpose the returned results because the advanced indexing puts all non-consecutive indices in the beginning of the tensor.

b.shape # should be 2, 4, 3
result = b.permute(0, 2, 1)  # this is missing from the code!

This comment has been minimized.

@weiyangfb

weiyangfb Jun 7, 2018

Author Contributor

Good catch! I will add this. Is there an easy way to generate the correct permute order?

This comment has been minimized.

@weiyangfb

weiyangfb Jun 7, 2018

Author Contributor

I am still missing something. In your code, how does final_indices = [slice(i) for i in tensor.shape] translate to ATen code?

def multi_meshgrid(*args):
    """
    Creates a meshgrid from possibly many
    elements (instead of only 2).
    Returns a nd tensor with as many dimensions
    as there are arguments
    """
    args = list(args)
    template = [1 for _ in args]
    for i in range(len(args)):
        n = args[i].shape[0]
        template_copy = template.copy()
        template_copy[i] = n
        args[i] = args[i].view(*template_copy)
        # there will be some broadcast magic going on
    return tuple(args)

def flip(tensor, dims):
    if not isinstance(dims, (tuple, list)):
        dims = [dims]
    indices = [torch.arange(tensor.shape[dim] - 1, -1, -1,
        dtype=torch.int64) for dim in dims]
    multi_indices = multi_meshgrid(*indices)
    final_indices = [slice(i) for i in tensor.shape]
    for i, dim in enumerate(dims):
        final_indices[dim] = multi_indices[i]
    flipped = tensor[final_indices]
    # need to permute the final dimensions
    # if dims is not consecutive, but I'm lazy
    # now :-)
    return flipped

This comment has been minimized.

@fmassa

fmassa Jun 8, 2018

Member

To get the final permutation order, all you need to do is sort the indices after the permutation.
Say we select dims 1 and 3 of a 5d tensor; the end result of the flip will have the following order of dimensions

flip_dims = [1, 3, 0, 2, 4]

where 1, 3 come first because of the advanced indexing on non-consecutive dimensions, and then come the other dimensions.
So this is obtained by something like
So this is obtained by something like

dims + [i for i in range(tensor.dim()) if i not in dims]

Once you have this tensor containing the order of the dimensions, all you need to do is to get the indices after sorting:

_, permutation = flip_dims.sort(0)

Note that those steps are only necessary if the dimensions in dims are non-consecutive, i.e., if

torch.all(torch.tensor(dims)[1:] - torch.tensor(dims)[:-1] < 2)

does not hold (or something like that; the condition can be found here).
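A small end-to-end check of that recipe (illustration only, reusing the example shapes from the earlier review comment):

import torch

a = torch.arange(2 * 3 * 4).view(2, 3, 4)
dims = [0, 2]                                            # non-consecutive flip dims

rev0 = torch.arange(a.shape[0] - 1, -1, -1).view(-1, 1)  # broadcastable reversed indices,
rev2 = torch.arange(a.shape[2] - 1, -1, -1).view(1, -1)  # as produced by multi_meshgrid
b = a[rev0, :, rev2]                                     # advanced indexing moves the indexed dims to the front
print(b.shape)                                           # torch.Size([2, 4, 3])

order = dims + [i for i in range(a.dim()) if i not in dims]   # [0, 2, 1]
_, permutation = torch.tensor(order).sort()                   # permutation = [0, 2, 1]
assert torch.equal(b.permute(*permutation.tolist()), a.flip(0, 2))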

This comment has been minimized.

@fmassa

fmassa Jun 8, 2018

Member

For the slice(i), I believe all you need to do is replace all the slices with undefined tensors, as that's what the ATen backend uses to represent full slices.

So I'd say you'd need a std::vector<Tensor> with tensor.dim() elements, where all the dimensions that are full slices are left as undefined tensors.

This comment has been minimized.

@weiyangfb

weiyangfb Jun 11, 2018

Author Contributor

@fmassa Thanks a lot! This works out perfectly now!

@ngimel

This comment has been minimized.

Copy link
Contributor

ngimel commented May 31, 2018

@weiyangfb, correct about collapseDims. As for bandwidth, you can compute it as bytes/time; in your case it's 8e6/357e-6 = 22.4 GB/s, which is not that great. K40 peak bandwidth is around 200 GB/s. (8e6 because your tensor has 1 million elements, 4 bytes per element, and each element has to be read and written, hence 4*2.) You can also compare your time with, e.g., the time for a pointwise operation on a same-size tensor, e.g. a *= 2.
For an nD tensor, for each dimension that you are flipping you have to shift the base pointer by (size[i]-1)*stride[i] and set the stride to -stride[i], where stride[i] are the strides of the contiguous tensor for those dimensions. At least I think so, please check my math.
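For reference, the effective-bandwidth arithmetic from the first paragraph, written out (same numbers as the K40 run above):

elements = 1000000
bytes_moved = elements * 4 * 2        # 4 bytes per element, one read plus one write
seconds = 357e-6
print(bytes_moved / seconds / 1e9)    # ~22.4 GB/s, versus roughly 200 GB/s peak on a K40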

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 4, 2018

@pytorchbot retest this please

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 4, 2018

@ngimel Huge thanks for walking me through the CUDA performance details and the formula for flipping! For an nD tensor, I think your math is correct. Thanks a lot for sharing the formula! I will implement this for the case of flipping a single dim, and will update the PR in a bit.

This performance analysis is super helpful! I will keep tracking it! I am also trying to use a.t().contiguous() as a benchmark. Is it going to be a tighter lower bound, since flip() requires similar non-contiguous memory access?

@ngimel

This comment has been minimized.

Copy link
Contributor

ngimel commented Jun 4, 2018

For flipping, save for some alignment issues (which can be avoided for sure by e.g. using a 1024x1024 tensor), your accesses are still contiguous (elements that are adjacent in the original tensor would still be adjacent in the flipped one, even if they are in a different order), so comparing with a regular pointwise op is better. You might want to run your comparison against a real 2d tensor that cannot be collapsed to 1d (you can create one by e.g. running torch.chunk on the 1st dim), to add the index math that you necessarily have for flipping.
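One way to set up the non-collapsible 2d benchmark @ngimel suggests (a guess at the exact setup, not taken from the PR; run in the same IPython session as the timings above): chunking a wider tensor yields a 1000x1000 view whose row stride is larger than its row length, so it cannot be collapsed to 1d.

data_cuda = torch.arange(2000000, device=torch.device('cuda')).view(1000, 2000)
chunk = torch.chunk(data_cuda, 2, dim=1)[0]   # 1000 x 1000 view with stride (2000, 1)
print(chunk.is_contiguous())                  # False: rows are not adjacent in memory
%timeit torch.cuda.synchronize(); chunk.mul(2); torch.cuda.synchronize()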

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 11, 2018

@ngimel Thanks a lot! Now it all makes sense! I am using TensorInfo and collapseDims to speed up the cases where the flip dim is the 1st or last dim. Here are some performance results:

data_cuda = torch.arange(1000000, device=cuda).view(100,100,100)
%timeit data_cuda.flip(0)
----------------------------------
10000 loops, best of 3: 178 µs per loop
data_cuda = torch.arange(1000000, device=cuda).view(100,100,100)
%timeit data_cuda.flip(2)
----------------------------------
10000 loops, best of 3: 181 µs per loop

benchmark:

data_cuda = torch.arange(1000000, device=cuda).view(100,100,100)
%timeit data_cuda.mul(2)
----------------------------------
10000 loops, best of 3: 86.3 µs per loop

And if I understand it correctly, collapseDims might not be able to squeeze an nD tensor down to 2D if the flip dim is not the 1st or last dim, so I am using the previous implementation for those cases.

data_cuda = torch.arange(1000000, device=cuda).view(100,100,100)
%timeit data_cuda.flip(1)
----------------------------------
1000 loops, best of 3: 364 µs per loop

aryamccarthy and others added some commits May 29, 2018

[c10d] MPI Process Group Implementation (#7783)
This provides a bare-minimum MPI Process Group implementation, the commit is on top of @pietern's Gloo Process Group PR.

* [c10d] MPI Process Group Implementation

ref: #7434

* Better exception, atexit func, and addressed comments

* Clang formatting changes

* Static initialization and addressed comments

* Added constness back

* Test will now launch mpi processes if found

* CMakeList Changed
Fix Windows doc for import error (#7704)
* Fix Windows doc for import error

* Fix doc again

* Fix wrong format

ezyang and others added some commits Jun 11, 2018

Support printing sparse tensors in ATen, fixes #8333. (#8334)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
[C++ API] Cursors (#8190)
* Add cursors to C++ API

* Small self nits

* s/struct/class

* Use more STL like names for cursors
Implement dim_arange operator (#8266)
* Implement arange_like operator

* add ONNX symbolic

* lint

* change name

* Comment the hack
1. fixed flip CPU impl for non-continuous flip dims; 2. added more tests; 3. using TensorInfo and collapseDims to speed up CUDA impl for cases where flip dim is the 1st or last dim
@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 13, 2018

@fmassa Using torch.cuda.synchronize() does not change the runtime much; am I doing it correctly?

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
def meshgrid():
    flip_meshgrid(data_cuda, (0, 1))
    torch.cuda.synchronize()
%timeit meshgrid()

1000 loops, best of 3: 1.76 ms per loop

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
def flip():
    data_cuda.flip(0,1)
    torch.cuda.synchronize()
%timeit flip()

1000 loops, best of 3: 353 µs per loop

I don't know why, though. Can I ask how one would normally profile the fraction of time spent?

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 14, 2018

@fmassa @ngimel is this PR ready for stamp?

@ngimel
Copy link
Contributor

ngimel left a comment

Generally looks good, I'd still like to reduce the amount of integer math.

int64_t total_dims) {
for (IndexType linear_index = blockIdx.x * blockDim.x + threadIdx.x; linear_index < N; linear_index += gridDim.x * blockDim.x) {
int64_t cur_indices = linear_index, rem = 0, dst_offset = 0;
for (int64_t i = 0; i < total_dims; i++) {

This comment has been minimized.

@ngimel

ngimel Jun 14, 2018

Contributor

You know that total_dims here is 2, right? You've collapsed the dims. If you make it a compile-time constant, the loop will be unrolled and be more efficient.

This comment has been minimized.

@weiyangfb

weiyangfb Jun 14, 2018

Author Contributor

Ah, great point! I will remove the for loop.

This comment has been minimized.

@weiyangfb

weiyangfb Jun 15, 2018

Author Contributor

Removing the for loop gives way better performance! (updated %timeit with cuda.synchronize())

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
%timeit torch.cuda.synchronize(); data_cuda.flip(0); torch.cuda.synchronize()
--------------------
10000 loops, best of 3: 116 µs per loop
for (int64_t i = 0; i < total_dims; i++) {
int64_t temp = cur_indices;
cur_indices = cur_indices / in_tensor_info.strides[i];
rem = temp - cur_indices * in_tensor_info.strides[i];

This comment has been minimized.

@ngimel

ngimel Jun 14, 2018

Contributor

You can restructure your code in such a way that there's one less divide - for the last iteration cur_indices is guaranteed to be less than the stride you are dividing it by (the same applies to your second kernel).

This comment has been minimized.

@weiyangfb

weiyangfb Jun 14, 2018

Author Contributor

Great catch! I will change it accordingly.

This comment has been minimized.

@weiyangfb

weiyangfb Jun 14, 2018

Author Contributor

I am still a bit confused about the last iteration. For example, if we have a tensor like this:

[
  [1, 2, 3],
  [4, 5, 6]
]

its sizes = (2, 3) and strides = (3, 1). When linear_index = 2, at the last iteration cur_indices = 2 > strides[-1] = 1.

int flip_dim,
int64_t total_dims) {
for (IndexType linear_index = blockIdx.x * blockDim.x + threadIdx.x; linear_index < N; linear_index += gridDim.x * blockDim.x) {
int64_t cur_indices = linear_index, rem = 0, dst_offset = 0;

This comment has been minimized.

@ngimel

ngimel Jun 14, 2018

Contributor

Can you please check if using int64_t indices is hurting performance? Usually all THC kernels are templated on index type, and for the tensors that can be indexed by 32-bit int they use 32-bit integers. I suspect if you use int32_t where appropriate perf will be better.

This comment has been minimized.

@weiyangfb

weiyangfb Jun 14, 2018

Author Contributor

Yes, I can check that. But will int32_t limit the total number of elements in a tensor, since N is an int64_t here?

This comment has been minimized.

@ngimel

ngimel Jun 14, 2018

Contributor

It's int64_t because you hardcode it to int64_t; in THC, pointwise kernels are dispatched either to int64_t or int32_t variants depending on the number of elements. If it does not affect the performance, using int64_t is certainly safer and saves you another template instantiation; however, I suspect it does affect performance.

This comment has been minimized.

@weiyangfb

weiyangfb Jun 14, 2018

Author Contributor

Ah, I see! Will look into it! Thanks for the super clear clarification 👍

This comment has been minimized.

@weiyangfb

weiyangfb Jun 15, 2018

Author Contributor

You are right! int64_t does hurt performance. For this case:

data_cuda = torch.arange(1000000, device=cuda).view(1000,1000)
%timeit data_cuda.flip(0)

Results look like these:

|  | time |
| --- | --- |
| hardcode int64_t | 178 µs |
| hardcode int32_t | 116 µs |
| templating int64_t | 139 µs |
| templating int32_t | 102 µs |

Maybe templating int64_t is the way to go?

This comment has been minimized.

@soumith

soumith Jun 15, 2018

Member

@weiyangfb you need to insert torch.cuda.synchronize() calls before and after the relevant code, if you want accurate benchmarking.

For example:

%timeit torch.cuda.synchronize(); data_cuda.flip(0); torch.cuda.synchronize()

This comment has been minimized.

@weiyangfb

weiyangfb Jun 15, 2018

Author Contributor

@soumith Thanks a lot! The results of hardcode and templating are the same now:

|  | time |
| --- | --- |
| hardcode int64_t | 165 µs |
| hardcode int32_t | 128 µs |
| templating int64_t | 165 µs |
| templating int32_t | 128 µs |

@ngimel will it be fine to template on index type = int64_t in this case?

This comment has been minimized.

@ngimel

ngimel Jun 15, 2018

Contributor

To me it's still a significant perf difference, but on the other hand, if flip is a small fraction of the workload no one will notice it. Up to @soumith to decide if it's worth the additional code complexity and binary size.

This comment has been minimized.

@soumith

soumith Jun 15, 2018

Member

I vote for a single template, no need to specialize the two cases. It's indeed a niche function for now.

This comment has been minimized.

@weiyangfb

weiyangfb Jun 15, 2018

Author Contributor

Sounds good! I will keep the template.

1. removed for loop in pointwise CUDA kernel; 2. using templated (int64_t) IndexType for indices in pointwise CUDA kernel
@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 15, 2018

@pytorchbot retest this please

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 15, 2018

The failing caffe2 test seems unrelated.

weiyangfb added some commits Jun 15, 2018

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 15, 2018

@pytorchbot retest this please

2 similar comments
@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 15, 2018

@pytorchbot retest this please

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 15, 2018

@pytorchbot retest this please

@weiyangfb

This comment has been minimized.

Copy link
Contributor Author

weiyangfb commented Jun 16, 2018

The failing caffe2 and lint tests seem unrelated; can I get a stamp on this?

@soumith soumith merged commit c9b8d85 into pytorch:master Jun 16, 2018

40 of 42 checks passed

pr/caffe2-py2-gcc4.8-ubuntu14.04-test Build failed
Details
pr/caffe2-py2-mkl-ubuntu16.04-test Build failed
Details
continuous-integration/travis-ci/pr The Travis CI build passed
Details
pr/caffe2-conda2-macos10.13-build Build successful
Details
pr/caffe2-conda2-ubuntu16.04-test Build successful
Details
pr/caffe2-conda3-cuda9.0-cudnn7-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-android-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-clang3.8-rocm1.7.1-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-clang3.8-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-clang3.9-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-cuda8.0-cudnn5-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-cuda8.0-cudnn6-ubuntu16.04-test Build successful
Details
pr/caffe2-py2-cuda8.0-cudnn7-aten-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-cuda8.0-cudnn7-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-cuda9.0-cudnn7-aten-ubuntu16.04-test Build successful
Details
pr/caffe2-py2-cuda9.0-cudnn7-centos7-build Build successful
Details
pr/caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-test Build successful
Details
pr/caffe2-py2-cuda9.0-cudnn7-windows-build Build successful
Details
pr/caffe2-py2-cuda9.1-cudnn7-ubuntu16.04-test Build successful
Details
pr/caffe2-py2-gcc4.9-ubuntu14.04-build Build successful
Details
pr/caffe2-py2-gcc5-ubuntu16.04-test Build successful
Details
pr/caffe2-py2-gcc6-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-gcc7-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-ios-macos10.13-build Build successful
Details
pr/caffe2-py2-setup-ubuntu16.04-build Build successful
Details
pr/caffe2-py2-system-macos10.13-build Build successful
Details
pr/caffe2-py3.6-clang3.8-rocm1.7.1-ubuntu16.04-build Build successful
Details
pr/py2-clang3.8-rocmnightly-ubuntu16.04 Build successful
Details
pr/pytorch-linux-trusty-py2.7 Build successful
Details
pr/pytorch-linux-trusty-py2.7.9 Build successful
Details
pr/pytorch-linux-trusty-py3.5 Build successful
Details
pr/pytorch-linux-trusty-py3.6-gcc4.8 Build successful
Details
pr/pytorch-linux-trusty-py3.6-gcc5.4 Build successful
Details
pr/pytorch-linux-trusty-py3.6-gcc7.2 Build successful
Details
pr/pytorch-linux-trusty-pynightly Build successful
Details
pr/pytorch-linux-xenial-cuda8-cudnn6-py3 Build successful
Details
pr/pytorch-linux-xenial-cuda9-cudnn7-py2 Build successful
Details
pr/pytorch-linux-xenial-cuda9-cudnn7-py3 Build successful
Details
pr/pytorch-linux-xenial-py3-clang5-asan Build successful
Details
pr/pytorch-macos-10.13-cuda9.2-cudnn7-py3 Build successful
Details
pr/pytorch-macos-10.13-py3 Build successful
Details
pr/pytorch-win-ws2016-cuda9-cudnn7-py3 Build successful
Details

pjh5 added a commit to pjh5/pytorch that referenced this pull request Jun 18, 2018

Added flip() fn in ATen (CPU + CUDA) (#7873)
* Spelling fix in MultivariateNormal docstring (#7915)

* [c10d] MPI Process Group Implementation (#7783)

This provides a bare-minimum MPI Process Group implementation, the commit is on top of @pietern's Gloo Process Group PR.

* [c10d] MPI Process Group Implementation

ref: https://github.com/pytorch/pytorch/issues/7434

* Better exception, atexit func, and addressed comments

* Clang formatting changes

* Static initialization and addressed comments

* Added constness back

* Test will now launch mpi processes if found

* CMakeList Changed

* Fix Windows doc for import error (#7704)

* Fix Windows doc for import error

* Fix doc again

* Fix wrong format

* Moved condition for dilated grouped convolutions to CUDNN convolution implementation (#7465)

* Updates to caffe2 operator documentation (#7917)

* Significant updates to the operator docs in prep for merge

* [auto] Update onnx to 307995b - Update from upstream (onnx/onnx#1038)
https://github.com/onnx/onnx/commit/307995b1439e478122780ffc9d4e3ee8910fb7ad

* Test if ASAN is actually working as part of ASAN tests. (#6050)

* Test if ASAN is actually working as part of ASAN tests.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Drop explicit use of libstdc++, we should not care.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Build with DEBUG=1

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Increase main thread stack size when using ASAN.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Split up detail.h (#7836)

* Fix THCUNN SpatialDepthwiseConvolution assuming contiguity (#7952)

* Fix fbcode compatibility (#7939)

* add test for correctness of transpose fusion (#7950)

* [JIT][script] Fix emitted gather and slice for dynamic indices (#7861)

* [JIT][script] Fix emitted gather for dynamic indices

* Also fix slice

* Address comments

* cache and use BLAS_SET_BY_USER so that it doesn't set itself to TRUE when run second time (#7942)

* Add unsafe flag to skip checking in prepare (#7832)

* Add unsafe flag to skip checking in prepare

* pop

* Rename cuda::type to cuda::into_type and provide cuda::from_type. (#7937)

These are used to convert Half -> half and half -> Half respectively.
from_type will be used for runtime type checking in THC.

* Try to fix TORCH_CUDA_ARCH_LIST for PyTorch again (#7936)

* try again

* use DEFINED

* use a loop

* Minor fixes

*  remove sort requirement from pad-sequence (#7928)

* pad-sequence no longer requires sorting entries

pad-sequence can get the max_len from the list of sequences. entries only need to be sorted if output will be used for pack_padded_sequence, which can throw the error itself.

* remove sort requirement from pad-sequence

Picks up from #5974.

Removes the requirement that input sequences to pad_sequence have to be
sorted. Addressed the comments in the PR:
- Updated docstring for pad_sequence
- Remove sort requirement in pad_sequence test
- Test unsorted and sorted sequences in pad_sequence test

* Fix checkBackend error message (#7926)

* Fix checkBackend error message

Fixes #7849

* Switch order of printing args

* Split CI tests in half and run them in parallel (#7867)

* Split and run tests in parallel

* Refactor tests

* Handling of scalars in torch.Size (#5676)

* Handling of scalars in torch.Size

torch.Size() constructor uses python_arg_parser

IntList in python_arg_parser can take iter/range

Have IntList take python iterables and ranges.

Address comments: don't use python_arg_parser and instead call __index__ in THPSize_pynew

Address comments

Address comments

* Rebased

* Address nit

* [JIT] Fission and fusion passes for addmm (#7938)

* Addmm decomposition pass

* Addmm peephole pass

* Fix handling of output shape in fusion pass

* Add DCE to the peephole passes

* add comments

* maybe bugfix?

* Fix GPU tests

* fix py2/3 test issue

* Set smaller grain size for some cases (#7941)

* Fix returning scalar input in Python autograd function (#7934)

* fix _wrap_outputs not working with scalar inputs

* add a test

* Prevent git autocrlf for bash scripts (#7949)

* Delete unused file (#7919)

* Fix typo in autodiff formula for addmm (#7932)

* 1) use meshgrid for flip() CPU implementation, only need one copy of input tensor; 2) changed kernel of CUDA implementation, no need materialized indices tensor; 3) reusing error checking code

* [caffe2] YellowFin parameter update GPU code fix. (#6993)

* [Caffe2] Keep name of caffe2_pybind11_state and caffe2_pybind11_state_gpu in debug build (#7155)

* Allowing MatMul to create a gradient even with 3 inputs. useful if you are differentiating a graph twice (#6536)

* added const for local variables

* Fix the cpp libtorch CUDA build (#7975)

* Use mingfeima's mkldnn (#7977)

* Fix the import part of the windows doc (#7979)

* Change perf test folder after git checkout (#7980)

* Move the broadcast check in MKL Add/Sum to runtime (#7978)

* Use Glog's implementation of STL logging when possible. (#7206)

Inject custom workaround into namespace std so that it can be found by ADL.

* [Hotfix] Bring back warnings and -Werror to ATen (#7866)

* Bring back warnings and -Werror to ATen

* Unbreak...

* Fix tbb errors

* Enable ONNX backend Mean tests (#7985)

* Add third wayt to determine IS_CONDA (#7971)

* Fix EmbeddingBag max_norm option (#7959)

* fix EmbeddingBag max_norm option

* flake8

* add warning to the embedding bag arg change

* Raise error when torch.load a storage on a non-existing device (#7921)

* Raise error when torch.load a storage on a non-existing device

Before, doing torch.load(...) on a CUDA tensor on a CPU-only machine
would raise an unreadable error:

```
~/pytorch/pytorch/torch/cuda/__init__.py in __enter__(self)
    223         if self.idx is -1:
    224             return
--> 225         self.prev_idx = torch._C._cuda_getDevice()
    226         if self.prev_idx != self.idx:
    227             torch._C._cuda_setDevice(self.idx)

AttributeError: module 'torch._C' has no attribute '_cuda_getDevice'
```

This PR makes it so that torch.load raises a hard error if one tries to
load a storage onto a non-existing device and suggests the user to use
torch.load's map_location feature.

* Address comments

* missing dep

* Make THStorage / THCStorage have void* data ptr. (#7964)

* Make THStorage / THCStorage have void* data ptr.

This is the initial step in unifying the ATen and TH tensor representations, next is to only generate a single THStorage / THCStorage type.

The major changes here are:
1) data has been renamed to data_ptr and made void* in THStorage/THCStorage.
2) THStorage / THCStorage stores a at::ScalarType representing its data type (This will be useful when we generate a single THStorage/THCStorage).
3) APIs for Accessing the data as a real*:
a) storage->data<real>() -- this does runtime-type checking (checks that the at::ScalarType is correct).
b) storage->unsafeData<real>() -- as above, but no runtime-type checking (used in inner loops / fast code paths).
c) THStorage_(data)(storage) -- this already existed, just calls storage->data<real>().

* Add include.

* Attempt to fix clang build issues.

* Clarify comment and remove extra character.

* Rename unsafeData -> unsafe_data.

* Remove unnecessary 'to' function to get compile time rather than link time errors.

* Import/export observer symbols for DLL, which fixes the linking error in Visual Studio. (#6834)

* Import/export observer symbols for DLL, which fixes the linking error in Visual Studio.

* Add support of all default cmake build types for release to cuda.

* Remove python bindings for `torch.slice` (#7924)

* skip python bindings for slice

* remove tests

* convert slice test to indexing

* Build ONNX for PyTorch version of libcaffe2 (#7967)

* support loading gzip (#6490)

* support loading gzip

* address comments

* address comments

* fix lint

* fix test for python2

* Add memory leak check in CUDA tests (#7270)

* Add memory leak check in CUDA tests

* Tracking multi-GPU too

* fix run_test.py not running __name__ == '__main__' content; add test for make_cuda_memory_checked_test

* add a comment

* skip if cuda

* 1. Change the wrapper to a method in common.py:TestCase
2. Refactor common constants/method that initialize CUDA context into common_cuda.py
3. Update some test files to use TEST_CUDA and TEST_MULTIGPU

* Fix MaxUnpool3d forward memory leak

* Fix MultiLabelMarginCriterion forward memory leak

* Fix MultiMarginLoss backward memory leak

* default doCUDAMemoryCheck to False

* make the wrapper skip-able

* use TEST_MULTIGPU

* add align_corners=True/False tests for Upsample; fix TEST_CUDNN

* finalize interface

* VolumetricMaxUnpooling_updateOutput

* fix test_nccl

* rename THC caching allocator methods to be clearer

* make the wrapped function a method

* address comments; revert changes to aten/src/THC/THCCachingAllocator.cpp

* fix renamed var

* Revert "Set smaller grain size for some cases" (#7988)

* Entry for c10d in CODEOWNERS (#8001)

* Fix a couple of typos (#7998)

* Fix typo

* Fix typo

* Fix typo

* Fix typo

*  Add on-stack observer cache for Observable (#7931)

observers_list_ stores all the observers for an observable. The list is allocated on heap, which
 can cause LLC miss. Add an on-stack observer cache for fast access. In production, we have seen 20%
 speed up for start and stop observer calls.

* Reduce grain size for Unary operations (#8003)

* [auto] Update onnx to 8ec0e5f - Add index check for Transpose's type inference function (onnx/onnx#1053)
https://github.com/onnx/onnx/commit/8ec0e5fe9badecb1c4cc9f136f791499f20c1377

* Make AT_FORALL_SCALAR_TYPES usable outside of at::namespace. (#7935)

* Make AT_FORALL_SCALAR_TYPES usable outside of at::namespace.

This requires renaming the _cast functions which used the unqualified names.

* Separate onnx mapping of scalar type from cast name.

* Fix flake8.

* Properly cast onnx.

* Remove WITH_ROCM cmake flag/variable (use USE_ROCM solely) (#8013)

* Mention the pytorch-ci-hud on the README. (#8004)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Re-enable build env check (#7969)

* Re-enable build env check

* Fix linux test error

* Try to fix macOS test error

* Update nn.rst (#8029)

* Example for Transformed Distribution (#8011)

* [auto] Update onnx to 33e9cd4 - Remove the usage of default value to fix invalid proto3 files. (onnx/onnx#1052)
https://github.com/onnx/onnx/commit/33e9cd4182fe468675241fba4ae8a16c2f0bd82f

* [auto] Update onnx to 1504a33 - Convert schema assert for duplicate type names to exception (onnx/onnx#1057)
https://github.com/onnx/onnx/commit/1504a33abb7b1bfa773e000e2442545ce403c740

* Support CUDA tensors in ProcessGroupGloo  (#7694)

This adds an unconditional dependency on CUDA, which is not desirable
for the long term. Ideally we have split like ATen where we have
different artifacts for different backends so you can decide at runtime
what to use.

* [auto] Update onnx to 3fb9656 - Fix for fbcode CI (onnx/onnx#1062)
https://github.com/onnx/onnx/commit/3fb965666e7fc271d093ca27529a7a1b1e103c3b

* propagate nan in some activations (#8033)

* propagate nan in some activations

* fix py2 not having math.nan

* flake8

* Fix profiler crash when no events register (#8034)

* Fix profiler crash when no events register

When trying to profile, attempting to print the event table throws a vague error because the event list is empty:

....
max_name_length = max(len(evt.key) for evt in events)
ValueError: max() arg is an empty sequence

This change fixes the error by returning an empty string.

* Update profiler.py

* Allow CI testing with different AVX configs (#8020)

* Allow CI testing with different AVX configs

* Unset ATEN_DISABLE_AVX and ATEN_DISABLE_AVX2 in default config

* Support for generating ATen during the fbcode build, rather than committing the generated files (#8002)

Paint the internal bikeshed a slightly different color to appease Buck tooling.

* Factor python dependency out of interpreter (#7970)

* Factor python dependency out of interpreter

* Remove NO_PYTHON for the autograd engine

If there is no python bindings, then a default Engine is constructed
the first time it is requested.

If the python libraries are loaded, then they override the default
accessor and the default engine becomes a python Engine.

Note: it is possible for two engines to be generated if a non-python
one gets created before the python bindings are loaded. This case
is rare, and just results in additional threads being spawned.

* Fixing AlexNet test which is skipped in CI

* [auto] Update onnx to 760c928 - add missing hasNInputShapes check for bidirectionalBroadcastShapeInference (onnx/onnx#1060)
https://github.com/onnx/onnx/commit/760c9283d0dfdc4b8705e4fa4bd95aca68dea459

* Support modules that output scalar in Gather (and data parallel) (#7973)

* Support modules that output scalar in Gather (and data parallel)

* Improve warning msg

* [auto] Update onnx to 9e7855d - Remove PyTorch generated Upsample tests cases (onnx/onnx#1064)
https://github.com/onnx/onnx/commit/9e7855dcd43e855e26e13a797f4b12ac9d9f2188

* [script] Add support for torch.zeros, torch.ones, etc. (#7799)

* [script] Add support for torch.zeros, torch.ones, etc.

* modifies gen_jit_dispatch to creating bindings for functions that do
  not take tensor arguments, but do have an initial type argument
* adds tensor attributes to these functions for device, layout, and
  dtype specification
* extends the list of valid compiler constants to include device, layout,
  and dtype.
* allows functions with Generators, but only using the default generator

Known limitations:
* when using `torch.float`, we convert it to a scalar tensor and make
  no checks that it is actually used only in a dtype specification.
  This is similar to how we handle Python numbers, creating some situations
  where the script is more permissive. Fixing this requires much more
  significant changes to the IR, so is lower priority for now.
* devices specified using string literals e.g. 'cuda:1' do not work,
  since we do not support string literals in general.

* Add profiling annotations to NeuralNet[Operator|Data] (#8005)

* Update from facebook 1ee4edd286a3 (#8040)

* Adding instance weight to batch distill loss

as title

* add bfloat 16-31

added bfloat 16-31 and their respective unit tests

* [CUDA9] Upgrade - fbcode

CUDA9 upgrade diff D5654023 has been out for a while thanks to Pieter. But with time growing it's becoming quite hard to rebase, because of the symlinks and auto-generated build/config files in tp2. Break D5654023 into two diffs, one touching tp2 config files, and another one touching fbcode TARGETS file (adding nvcc flag). These two should be a bit easier to rebase (for detailed procedure see "Test Plan").

This diff can only be committed if:
1. CUDA 9 rpm is rolled out fleet-wide (TBD)
2. NVidia driver 390.40 is rolled out fleet-wide (done)
3. Upgrade CUDA 9.1, cudnn 7.1, nccl 2.1 (done)
4. Make sure all dependents are built (done)
5. Test all C2 operators, PyTorch (see test plan)

* Share intermediate int32 buffer across Conv ops

Adding a known type

* [C2 fix] infer function for ensure_cpu_output_op

this is adding the missing device funtion for ensure_cpu_output_op

* [int8] Add blob serializer/deserializer for Int8TensorCPU

To export to logfiledb

* [nomnigraph] Add try catch block to optimization passes in predictor

This will catch failures that happen in the optimization pass.

* Caffe2: avoid static initialization order fiasco for CAFFE_ENFORCE

CAFFE_ENFORCE uses strack trace fetcher. Which is currently a
global static variable. If at static initialization time CAFFE_ENFORCE
is used, this is a SIOF. Recently CAFFE_ENFORCE was added into init
functions registration, so we started to see this.

Meyers singleton is going to provide safety here. If stacktrace
fetcher was not registered yet, it will just use a dummy one.

* NUMA support in SparseNN CPU benchmark

Adding support for NUMA in SparseNN CPU benchmark

* [mobile-roofline] Add logging needed for roofline model

This should be all that's needed

* Let the operators using the same input if the operators are not chained

or else, we have to change the input data dims

* fix null-pointer-use UBSAN errors in in reshape_op.h

* revert previous fix on input blob name

as title

* Adding flag to let MineHardNegative automatically extract single value from dict

Model exporter requires the output of the model to be a struct. This makes it convenient to use those models directly in MineHardNegative by allow automatic extraction of the single element of dict, which is a common use case.

* Reverting change that broke internal tests back to OSS compatible state

* Skip CUDA memory leak test on BN tests on windows (#8043)

* workaround for Sequential when one cannot retrieve python source (#8048)

* [auto] Update onnx to 0dbec2a - - Generate protoc type hints on Windows (onnx/onnx#1047)
https://github.com/onnx/onnx/commit/0dbec2a0474abcc92806d54d4dab1948674fcf74

* [auto] Update onnx to 4f8ef17 - Remove erroneous documentation around maps and sequences. (onnx/onnx#1069)
https://github.com/onnx/onnx/commit/4f8ef17ad3965e834b93d3753e54dee296aabc11

* [auto] Update onnx to e6a500e - Extract constant to initializer (onnx/onnx#1050)
https://github.com/onnx/onnx/commit/e6a500e54c50e3d300141f62958dba5f163aea4f

* [auto] Update onnx to 033f956 - make gcc happy (onnx/onnx#1061)
https://github.com/onnx/onnx/commit/033f956f41c55fd409e1c4a0d09795ae5411447f

* Remove NO_PYTHON macros from Exceptions.h/cpp (#8007)

Removes cases where NO_PYTHON was unnecessary in Exception.h/cpp

* [ready] Clean up torch.distributions (#8046)

* Have a single THStorage and THCStorage type. (#8030)

No longer generate data-type specific Storage types, since all Storage types are now identical anyway.
For (some) backwards compatibility and documentation purposes, the Real names, e.g. THLongStorage are now #defined as aliases to the single THStorage type

* Reduce usages of TensorUtils<T>::DataType in THC. (#8056)

TensorUtils<T> is basically ATen-dispatch-lite in that it allows one to do multi-type THC function dispatch with a single call.
However, it is templatized on the Tensor type, and since we are moving to a single Tensor type, this doesn't work.

Most of the functions in TensorUtils (e.g. getDims) can be pulled up a level, to just call THCTensor_nDimension (or directly accessing the member),
but the DataType specific functions are more problematic.

So, this PR does two things:
1) Replaces calls of 'TensorUtils<THCTensor>::DataType' with 'real' since these are identical
2) Templatizes the THC_pointwiseApplyX functions to take scalar types.  To ensure this is done correctly, we static_assert that the scalar type template parameter matches the scalar type of
   the corresponding template parameter.  We will need to get rid of these static_asserts in the future, but this is useful for now.

* Support to run ONNX Upsample operator (mode=nearest) in Caffe2 (#8037)

* Added support to run ONNX Upsample operator (mode=nearest) in Caffe2

* adding error checks to upsample

* adding error checks to upsample

* adding error checks to upsample

* changing to np.isclose

* Revert onnx submodule update

* still fixing

* [auto] Update onnx to eb12f72 - Add conv transpose test cases (onnx/onnx#886)
https://github.com/onnx/onnx/commit/eb12f72a8619e2fbad0d86200677fd96201d4351

* [auto] Update onnx to bd98abb - Add a hook for doing post-processing on protobuf generated header files (onnx/onnx#1068)
https://github.com/onnx/onnx/commit/bd98abbba052c1fa2dadc266dbf0d36c1b941970

* Skip ConvTraspose ONNX backend tests (#8074)

* Post process onnx proto (#8064)

* Post processing onnx generated protobuf files to hide global symbols

* .

* .

* Add code for TensorBoard visualization of JIT GraphExecutors (#8050)

* [auto] Update onnx to cc26486 - bump version to 7 for prelu. (onnx/onnx#1063)
https://github.com/onnx/onnx/commit/cc2648654172f0b7044f9469e6c2204c19a3ae1e

* [auto] Update onnx to 356208d - add input tensor dimension checks to shape inference (onnx/onnx#1070)
https://github.com/onnx/onnx/commit/356208d7560a3e88cabf11fddfe6fbaa748da35c

* Move backtrace to its own header (#8096)

* Move backtrace to its own header

* Move cxxabi.h into Backtrace.cpp

* Fix and ignore some warnings (#8081)

* Do an additional sanity check that nvcc and CUDA include dir agree. (#8094)

If you set CUDA_HOME and CUDA_NVCC_EXECUTABLE together, you may
end up in a situation where the CUDA_VERSION of your includes
mismatches the CUDA version of your nvcc.  See #8092 for a concrete
case where this can occur.  Explicitly detect this situation and
give a good error message in this case!

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* use regex in kwarg parser (#8061)

* Removing remaining NO_PYTHON ifdefs (#8067)

* Remove NO_PYTHON in tracing

* Remove NO_PYTHON in ir.h

* Remove NO_PYTHON in test_jit.cpp

* Replace std::size_t with size_t (#8093)

* Remove out-of-date comment (#8114)

* [Caffe2] Enabling AMD GPU Backend for Caffe2 (#7955)

* Add hip support for caffe2 core

* Add MIOPEN header/wrapper to caffe2 core

* Add HIP device into caffe2 PB

* top level makefile change for rocm/hip

* makefile scaffolding for AMD/RocM/HIP

* Makefile scafodding for AMD/RocM/HIP; add makefile/utility for HIP files

* caffe2 PB update for AMD/ROCM HIP device

* Add AMD/RocM/Thrust dependency

* HIP threadpool update

* Fix makefile macro

* makefile fix: duplicate test/binary name

* makefile clean-up

* makefile clean-up

* add HIP operator registry

* add utilities for hip device

* Add USE_HIP to config summary

* makefile fix for BUILD_TEST

* merge latest

* Fix indentation

* code clean-up

* Guard builds without HIP and use the same cmake script as PyTorch to find HIP

* Setup rocm environment variables in build.sh (ideally should be done in the docker images)

* setup locale

* set HIP_PLATFORM

* Revert "set HIP_PLATFORM"

This reverts commit 8ec58db2b390c9259220c49fa34cd403568300ad.

* continue the build script environment variables mess

* HCC_AMDGPU_TARGET

* Cleanup the mess, has been fixed in the lastest docker images

* Assign protobuf field hip_gpu_id a new field number for backward compatibility

* change name to avoid conflict

* Fix duplicated thread pool flag

* Refactor cmake files to not add hip includes and libs globally

* Fix the wrong usage of environment variables detection in cmake

* Add MIOPEN CNN operators

* Revert "Add MIOPEN CNN operators"

This reverts commit 6e89ad4385b5b8967a7854c4adda52c012cee42a.

* Resolve merge conflicts

* .

* Update GetAsyncNetHIPThreadPool

* Enable BUILD_CAFFE2 in pytorch build

* Unifiy USE_HIP and USE_ROCM

* always check USE_ROCM

* .

* remove unrelated change

* move all core hip files to separate subdirectory

* .

* .

* recurse glob core directory

* .

* correct include

* .

* Detect CUDNN related environment variables in cmake (#8082)

* Implement adaptive softmax (#5287)

* Implement adaptive softmax

* fix test for python 2

* add return_logprob flag

* add a test for cross-entropy path

* address review comments

* Fix docs

* pytorch 0.4 fixes

* address review comments

* don't use no_grad when computing log-probs

* add predict method

* add test for predict

* change methods order

* get rid of hardcoded int values

* Add an optional bias term to the head of AdaptiveSoftmax

* Make libshm also test if rt requires pthread. (#8112)

In some configurations (e.g., our internal build of GCC 5 + GLIBC 2.23),
-lrt is not sufficient to use shm_open; you also need to declare
a dependency on pthread.  This patch adds a surgical extra fix to
detect this situation, in the case that I noticed it failing in the
wild.

Fixes #8110

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* [auto] Update onnx to 2d5ce4a - Remove empty model (onnx/onnx#1058)
https://github.com/onnx/onnx/commit/2d5ce4aeb6c485490ad567cbe610bbe1a83ac72d

* Add missing pragma once. (#8118)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* [auto] Update onnx to 2a87616 - Tests for LRN operator (onnx/onnx#903)
https://github.com/onnx/onnx/commit/2a876162ac91438cea370d75a11c9a96942e89da

* Split SparseTensorImpl off from TensorImpl. (#7990)

* Split SparseTensorImpl off from TensorImpl.

At the moment they have the same data layout, but with the upcoming refactor
they will not, and we need a place to put all of the sparse tensor specific
fields.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Update SparseTensorImpl.h

* [Caffe2] Support non peer access in muji and fix bug when reduced_affix is empty (#6896)

* [Caffe2] Support non peer access in muji

* [Caffe2] Add test for 4 gpus and 2 groups

* [Caffe2] Add comments

* Fix bug when reduced_affix is empty

* Fix typo and add comments about cpu and amd gpu

* Skip OnnxBackendNodeModelTest::test_lrn_default_cuda that causes segfault (#8127)

* Replace most remaining usages of TensorUtils<T>::DataType. (#8124)

As in https://github.com/pytorch/pytorch/pull/8056, this doesn't work with a single TensorImpl type.
This replaces the usages of with a templatized parameter and static_asserts that the new and old are equal.

After this we can get rid of the old template parameter, but I want to ensure they are equivalent across all builds first.

* Add utf-8 header to Python file with Unicode. (#8131)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Add back lrn test (#8134)

* Revert "Skip OnnxBackendNodeModelTest::test_lrn_default_cuda that causes segfault (#8127)"

This reverts commit 410191c4175eaae141306cdb3c3c1c1e8a495225.

* Fix mismatched default values

* Add non_blocking to Tensor/Module.to (#7312)

* Add non_blocking to Tensor/Module.to

* flake8

* Add argparse tests

* cpp parse

* Use C++ parser

* use a common parse function with Tensor.to

* fix test_jit

* use THPObjectPtr

* increase refcount for None, True, and False

* address comments

* address comments
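
A brief usage sketch of the `non_blocking` flag added above, assuming a CUDA-capable build; the asynchronous copy only pays off when the source tensor is in pinned memory:

```
import torch

if torch.cuda.is_available():
    # pinned (page-locked) host memory is what makes non_blocking copies useful
    x = torch.randn(4, 4).pin_memory()
    # the host-to-device copy can overlap with other work on the current stream
    y = x.to('cuda', non_blocking=True)
```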

* Fix job name checking for AVX tests (#8135)

* Fix a corner case for ReShapeOp (#8142)

In my use case, in the backward propagation pass, the reshape needs to
change a [0] tensor into a [0, 0] shaped tensor. The original implementation would
cause an out-of-index issue. This diff fixes the problem.
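
Purely as an illustration of the shape transformation described above (a torch analogue, not the Caffe2 ReshapeOp code path):

```
import torch

grad = torch.zeros(0)          # a [0] tensor coming out of the backward pass
reshaped = grad.reshape(0, 0)  # must become a [0, 0] tensor without indexing errors
print(reshaped.shape)          # torch.Size([0, 0])
```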

* cpu/ideep context converter (#8139)

* fix type mismatch while call torch._C._cuda_setDevice (#8065)

* fix type mismatch while call torch._C._cuda_setDevice

* fix type mismatch in scatter

* fix type mismatch in scatter

* fix type mismatch while call torch._C._cuda_setDevice

* fix type mismatch while call torch._C._cuda_setDevice

* fix type mismatch while call torch._C._cuda_setDevice

* docs: Add warning to torch.repeat() (#8116)

* docs: Add warning to torch.repeat()

closes #7993

* docs: Add links for numpy functions

* docs: Break the too long line

* Accelerate bernoulli number generation on CPU  (#7171)

* opt bernoulli rng with vsl and openmp

* detect cpu vendor for bernoulli

* retrigger test platform

* check the vendor more strictly

* use cpuinfo to check vendor

* docs: add canonical_url and fix redirect link (#8155)

* docs: enable redirect link to work for each specific page

* docs: add canonical_url for search engines

closes #7222

* docs: update redirect link to canonical_url

* docstring support for @script and @script_method (#7898)

* docstring support for @script and @script_method

* make it python2 compatible

* improve according to review

* improve build_stmts

* use filter instead of list comprehension

* improve the way wrap is handled for script_method

* stash the original method instead

* allow dynamic attr for ScriptMethod and GraphExecutor

* a bit comment on build_Expr

* remove _build_wrap

* a bit improve on comments

* rename to __original_methods

* should be _original_methods

* [auto] Update onnx to 968d28d - fix Node::isBefore (onnx/onnx#1075)
https://github.com/onnx/onnx/commit/968d28d901d2efdaf6d5fcfd529106762524cdfa

* remove some unnecessary cudaGetDevices (#8089)

* remove unnecessary cudaGetDevices

* make curDevice argument non-optional, add explicit checks to current_device

* Fix cuda.framework error on OSX. (#8136)

When compiling OSX with CUDA, Caffe2's build system uses
find_package(cuda) to get its grubby hands on the CUDA driver
library (for some strange reason, FindCUDA doesn't save this
information as a variable).  Unfortunately, on OSX, sometimes
this picks up the cuda.framework folder, and then our build
system chokes to death because it doesn't try to link against
this as a framework.  (Is the folder even a framework?  I have
no idea).

This commit attempts to fix this in a two pronged fashion:

1. For some users, reducing the precedence of frameworks
using CMAKE_FIND_FRAMEWORK seems to help.  So we set these
variables.  However, this fix is not perfect; on my laptop
it doesn't actually solve the problem.

2. PyTorch doesn't actually need the CUDA driver API.  So we
only add the dep when building Caffe2.

Fixes #8022

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* [C++ API] Improve and use OrderedDict for parameters / modules (#7823)

* Improve OrderedDict for C++ API

* Give OrderedDict a subject and fix review comments

* Fix OrderedDict use in torch/csrc/jit/script/init.cpp

* Fix __rshift__ bug (#8161)

* Fix __rshift__ bug

* Add small tests for __lshift__ and __rshift__ in test_cuda

* Add a more elaborate check for __lshift__ and __rshift__

* refactor the test to address @zou3519 's comments

* Move non-generic Storage code needed by TensorUtils to non-generic C++. (#8164)

For non-generic function call implementations in Storage used by TensorUtils, we do the following:
1) Move the declaration from generic/C to non-generic/C++; we don't need backwards compatibility on these functions and want to use e.g. at::ScalarType.
2) Move the implementation from generic/C++ to non-generic/C++.
3) Change the generic implementation to call the non-generic implementation.

This will allow us to get rid of the corresponding TensorUtils calls (once we move over the Tensor functions in the same manner).

* Pinning opencv to < 3.4 in conda builds (#7923)

* Pinning opencv to 3.1.0 in conda builds

* Also pinning numpy to 1.11

* Trying only specifying <3.4

* Adding -setup- path, and better code structure (#8122)

* Abstract parallelization to facilitate using threadpools (#8163)

* [Caffe2] Update elementwise ops to support numpy style broadcast (#8070)

* Update elementwise ops to support numpy style broadcast

Update elementwise ops to support numpy style broadcast

* Fix sqrt_op

* Fix compare ops

* Fix gradient test

* Fix optimizer legacy broadcast

* Fix legacy broadcast for elementwise ops

* Skip flaky test

* Fix eigen simple binary op

* Fix attention test

* Fix rnn test

* Fix LSTM test

* Fix tan grad

* Fix schema check

* Export getCudnnHandle (#7726)

* [JIT] Support a single TensorList argument anywhere in the argument list + index_put (#8173)

* [JIT] Support a single TensorList argument anywhere in the argument list

* [JIT] index_put

* use the correct datatype format (#8144)

* Add back onnx console scripts dropped during migration from onnx-caffe2 (#8143)

* Get rid of SOVERSION (again). (#8132)

We don't want SOVERSION because pip will lose the symlink and
double your distribution size, and also because our setup.py
accidentally links against both libcaffe2.dylib and libcaffe2.1.dylib
on OS X.  This leads to a very puzzling error where you get
the error "cannot initialize CUDA without ATen_cuda", because
there are actually two copies of your registry in memory (because
there are two copies of the dynamic library).  Dropping SOVERSION
makes it impossible to make this mistake.

In principle, if the shared library load is done with DYLD_GLOBAL,
that should also prevent two copies of the registry from popping up.
Worth checking at some later point, if you need to bring back
SOVERSION (because, e.g., pip finally fixed their software.)

Partially fixes #8022.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Fix a corner case for ReShapeOp (#8178)

In my use case, in the backward propagation pass, the reshape needs to
change a [0] tensor into a [0, 0] shaped tensor. The original implementation would
cause an out-of-index issue. This diff fixes the problem.

* Better conv error message basing on weight shape (#8051)

* Add retry logic to sccache download for Windows build (#7697)

* Add retry logic to sccache download for Windows build

* fix script bug

* clean up

* fix caffe2 docker build (#7411)

* [ONNX] Fix type_as symbolic (#8183)

* [ONNX] Nuke type_as symbolic

* make it better

* Fix lookup + test

* Yangqing as an ONNX codeowner (#8185)

* Fix protobuf options (#8184)

* protobuf

* fix protobuf_MSVC_STATIC_RUNTIME

* Add a loop unrolling pass to PyTorch JIT (#7672)

* [auto] Update onnx to 4e65fd8 - fuse consecutive squeezes (onnx/onnx#1078)
https://github.com/onnx/onnx/commit/4e65fd83baaeb94fbaa050ae9df1016378157116

* [Caffe2] Merging setup.py with setup_caffe2.py (#8129)

* Merging setup.py files; torch works, caffe2 works up to other KP

* Fix to super call for python 2

* Works on python2 on mac

* Consolidating Caffe2 flags

* Fix scalar check for sparse tensors. (#8197)

* Fix scalar check for sparse tensors.

As discovered in #8152

If `t` is a scalar sparse tensor, `t._indices` used to return a sparse
empty tensor because the scalar check was incorrect. This PR modifies
the scalar check to return a dense tensor instead of a sparse tensor.

i.e.
```
tensor = torch.sparse_coo_tensor([], [], torch.Size([]), device=device)
out = tensor._indices()  # was a sparse tensor, now is dense.
```

* Fix typos

* fix lint

* Add more annotations for arguments in ATen schema (#8192)

* use THCThrustAllocator in BCECriterion (#8188)

* Allow parallel_apply to take in list[Tensor] (#8047)

* Docs for gradcheck and gradgradcheck; expose gradgradcheck (#8166)

* Docs for gradcheck and gradgradcheck; expose gradgradcheck

* address comments
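
A minimal sketch of the now-documented checkers; double-precision inputs are assumed, since the numerical checks are unreliable in float32:

```
import torch
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(3, dtype=torch.double, requires_grad=True)
assert gradcheck(torch.sin, (x,))      # checks first derivatives numerically
assert gradgradcheck(torch.sin, (x,))  # checks second derivatives numerically
```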

* Implement randperm for CUDA (#7606)

* Implement randperm for CUDA

* Use Thrust to implement randperm

* clean up

* Fix test

* Offload small input scenario to CPU

* Fixed test

* Try to fix Windows error

* Fix Windows error and clean up

* Use fork_rng context manager

* Move test_randperm_cuda to test_cuda

* Add half tensor support

* Fix cuda::type error

* Fix CPU offloading

* Fix issues

* No need to check range for n == 0 case
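
A short usage sketch, assuming a CUDA device is available:

```
import torch

if torch.cuda.is_available():
    perm = torch.randperm(10, device='cuda')  # random permutation of 0..9 generated on the GPU
    print(perm)
```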

* Update c10d build to link against Caffe2 (#8201)

This follows #7399.

* add wipe_cache option (#8204)

as title

* Replace (non-data) TensorUtils calls with non-generic THCTensor calls. (#8176)

* Replace (non-data) TensorUtils calls with non-generic THCTensor calls.

TensorUtils is templatized on the THTensor type, so to support a single tensor type (like ATen), we need to remove these.

This PR does the following:
1) Allows THCTensorTypeUtils.cuh to include THCTensor.hpp.
   This involves moving includes of it outside of generic/, so we can use the new implementations.
2) Defines a single _THTensor struct and changes THCRealTensor to be a derived type of _THCTensor.
   This allows us to implement a single non-generic function and avoid static_cast or void * tricks to call it from the generic functions.
3) For functions inside of TensorUtils that don't use data pointers:
   a) Implement the functions in (non-generic) THTensor.cpp and declare them in (non-generic) THTensor.hpp.
   b) Have the generic versions call the non-generic versions.
   c) Replace the corresponding TensorUtils<THCTensor>::fn call with (non-generic) THTensor_fn.

* Add comment about THCTensor struct.

* Error if storage is null in setStorageNd or resizeNd.

* Fix c10d compiler warnings (#8206)

Copy compiler flags from the ones used in setup.py and fix warnings.
This makes the root build that includes c10d headers warning free.

* Bump gloo submodule (#8202)

This includes facebookincubator/gloo#125.

* rm -rf aten/contrib (#8165)

* Remove aten/contrib

* Remove from CMake

* Fix tanh_op on ios build (#8207)

* Fix tanh_op on ios build

* Fix tanh

* [auto] Update onnx to f28e2f1 - fix lrn spec (onnx/onnx#1090)
https://github.com/onnx/onnx/commit/f28e2f1a601875593af35a52888f829ba82c0598

* [cmake] deprecate caffe2_* specific cuda function in cmake. (#8200)

* deprecate caffe2_* specific cuda function in cmake.

* ENV{} -> $ENV{}

* CUDA_ARCH_NAME -> TORCH_CUDA_ARCH_LIST

* .

* .

* .

* skip CUDA memory leak check on Windows altogether (#8213)

* Record shape and type in autograd to validate gradients (#8168)

The check that the gradient is defined is currently disabled because
TestJit.test_ge_optimized will trigger the error.

* [auto] Update onnx to 18d70ff - Graph should only have one (input) kParam node (onnx/onnx#1088)
https://github.com/onnx/onnx/commit/18d70ff5294953ccdf791b44ce5ccd9065584945

* Set up a c10 source folder (#7822)

* Set up a c10 source folder

* Change the benchmark log format and also log flops (#8215)

as title

* Move helper functions to unnamed namespace. (#8224)

Currently, the helper functions in this file are in the global
namespace. I am guessing the intent was to keep them local.

* [auto] Update onnx to e96d823 - Update Google benchmark to 1.4.1 (onnx/onnx#1083)
https://github.com/onnx/onnx/commit/e96d823e5cc69ab02dccaba4d7971897918173c4

* Change new bernoulli implementation to be fully generic. (#8218)

The current implementation depends on THTensor types being unique, which is not guaranteed going forward.

* Structure THTensor like THCTensor is structured. (#8217)

In particular, define a base type, _THTensor, that can be used for all THRealTensor structs.
This is just to have less cognitive load when dealing with generic THTensor/THCTensor types (as in templates).

* move THCP-related utils to cuda/utils.cpp. (#8221)

These files don't follow the usual pattern: in general, the files torch/csrc/X and torch/csrc/cuda/X
both include the generic file torch/csrc/generic/X, where torch/csrc/X includes the cpu implementations and torch/csrc/cuda/X includes the cuda implementations.
(Aside: this is probably not the best structure; the torch/csrc/X files should probably be moved to torch/csrc/cpu/X).

utils.cpp combines these so that torch/csrc/utils.cpp has cuda specific code.  This makes it impossible to declare a single THTensor and THCTensor template type (i.e. THPPointer<_THTensor>, THPointer<_THCTensor>).

* [READY TO MERGE] Use ccache in macOS build (#8009)

* Use ccache in macOS build

* Moving to sccache

* Don't use sccache in test job

* [NEEDS REVIEW] Add nan and inf probability check to multinomial (#7647)

* Add nan and inf probs check to multinomial

* fix bug

* Spawn CUDA test in subprocess

* Make sure invalid input won't pass the test case

* Try to fix error

* Test failure cases in Python 3 only

* Try to fix Windows error

* Move CUDA test to test_cuda.py

* fix issues

* fix module name error

* no need to check for CUDA existence in test_cuda

* Use PY3
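
A rough sketch of the new validation (the exact error wording is not guaranteed):

```
import torch

probs = torch.tensor([0.2, float('nan'), 0.5])
try:
    torch.multinomial(probs, 1)
except RuntimeError as e:
    print(e)  # complains about inf/nan/negative entries instead of sampling garbage
```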

* [READY TO MERGE] Enable tests that use DataLoader with multiple workers on Windows (#6745)

* Don't import TEST_CUDA for test_dataloader on Windows

* test_partial_workers is stuck on Windows

* Don't copy unneeded grads when using a function for several derivatives (Fixes #7722) (#7759)

Trying to copy all results fails when one of them is a tensor list which
has not been populated. This blew up for CuDNN RNNs when the weights
did not require grad.

Thanks to Sylvain Gugger for reporting!
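
Roughly, the failing pattern was gradients flowing only to the input while the RNN weights are frozen; a hedged sketch, assuming a CUDA build with cuDNN:

```
import torch
import torch.nn as nn

if torch.cuda.is_available():
    rnn = nn.LSTM(4, 8).cuda()
    for p in rnn.parameters():
        p.requires_grad_(False)          # frozen weights
    x = torch.randn(5, 2, 4, device='cuda', requires_grad=True)
    out, _ = rnn(x)
    out.sum().backward()                 # only the input needs a gradient
    print(x.grad.shape)                  # torch.Size([5, 2, 4])
```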

* Fix win mkldnn (#7718)

* Sync build_pytorch_libs.bat with build_pytorch_libs.sh

* fix quoting

* add warnings

* fix warnings

* Add /EHa

* [Caffe2] Add ADD operator for IDEEP (#8220)

* Add ADD operator for IDEEP

* Add broadcast check

* Comments

* Allow optional build and installation of native test binaries (#8225)

* test finetuning

* install off by default

* Turn BUILD_TEST=ON for jenkins.

* Turn on install_test in jenkins as well

* Update MKL exporter to IDEEP ops (#8228)

IDEEP exporter support

* [ideep] Add IDEEP Squeeze op (#8227)

Similar to MKLSqueezeOp at caffe2/mkl/operators/squeeze_op.cc

* [auto] Update onnx to 62e63e9 - Fix build errors inside protobuf-bench (onnx/onnx#1084)
https://github.com/onnx/onnx/commit/62e63e9de8f8a3bb8e30c5f7f7f87fb94364ec17

* Use .cc since some downstream libraries are configured for C++ only. (#8234)

* Rename SparseTensor to SparseTensorRef. (#8237)

I want to introduce using SparseTensor = Tensor (as a documentary
type alias for Tensor), but the name is already taken.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* [caffe2] Build Android tests and binaries in CI (#7593)

Update benchmark submodule to version with fixed Android/GNUSTL build

* Remove core and util warnings (#8239)

* Fix some signed/unsigned mismatches

* Skip unused result warning

* Explicit fallthrough for murmur hash

* Enable aligned new support to eliminate warning

* Switch to int instead of unsigned in some cases

* Remove .gitmodules.aten since it is in .gitmodules now (#8232)

* Fix: gradcheck forced float32 (#8230)

* Print requires_grad and grad_fn in string repr of tensor (#8211)

For example:

  >>> torch.ones(3).requires_grad_()
  tensor([ 1.,  1.,  1.], requires_grad=True)

  >>> torch.ones(3).requires_grad_() * 5
  tensor([ 5.,  5.,  5.], grad_fn=<MulBackward0>)

The suffix (dtype, requires_grad, grad_fn) wraps to a new line if
it would cause the line to exceed the linewidth.

  >>> torch.ones(10).double().requires_grad_()
  tensor([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
         dtype=torch.float64, requires_grad=True)

* Fix TEST_CUDA import in test_cuda (#8246)

* Fix lifting cat into its constant version (#8174)

This fixes a bug where schemas including varargs lists did not lift
properly, blocking correct ONNX export.

* Don't override Tensor, Storage macros defined outside torch/csrc in t… (#8243)

* Don't override Tensor, Storage macros defined outside torch/csrc in torch/csrc.

This PR does the following:
1) Removes THSTensor macros in torch/csrc, which aren't used.
2) For macros defined outside of torch/csrc (THTensor, THTensor_, THStorage, THStorage_):
a) No longer override them, i.e. previously THTensor could actually be THCTensor if a generic file was included from a file including THCP.h.
b) Instead, introduce new macros THW* (e.g. THWTensor) to represent a (potentially empty) wildcard character.

In addition to making this code easier to read and codemod, this allows us to more freely change TH/THC; for example:
currently in the THC random code, the state is casted to THByteTensor*; this happens to work because the macros don't happen to override THByteTensor.
But if THByteTensor just becomes an alias of THTensor (which is the plan for a single tensor type), then this no longer works.
The whole thing was previously a bit of a mess because you really had to understand which macros are redefined and which aren't.

We could also rename the macros that live in torch/csrc (e.g. the THPTensor macros), but since that is more self contained, I punted for now.

* Don't change the plugin.

* [auto] Update onnx to 3a035f4 - Add retry logic to model downloading (onnx/onnx#1077)
https://github.com/onnx/onnx/commit/3a035f439799de3568c364f3f87014841037708e

* Fully genericize THC/THCUNN (except for TensorUtils and DeviceTensorUtils). (#8251)

* [cmake] Use CAFFE2_USE_* for public/cuda.cmake (#8248)

* Fix app size check (#8256)

Fix app size check

* wip on CPU impl

* Stop BCELoss from returning negative results (#8147)

* Stop BCELoss from returning negative results

* check explicitly for 0 before taking log

* add tests

* fix lint

* address comments
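
A small illustration of the guarantee (a sketch, not the tests added in the PR): even when a prediction exactly hits 0 or 1, each per-element loss should come out non-negative rather than NaN or slightly negative:

```
import torch
import torch.nn.functional as F

pred   = torch.tensor([0.0, 1.0, 0.5])
target = torch.tensor([0.0, 1.0, 0.0])
loss = F.binary_cross_entropy(pred, target, reduction='none')
assert (loss >= 0).all()
```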

* Relax CUDA_HOME detection logic, to build when libraries are found. (#8244)

Log when no cuda runtime is found, but CUDA is found

* Added backward function for kl_div target (#7839)

* added backward fn for target

* added module test for kl_div target, and assuming targets are probabilities
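
A hedged sketch of what the new backward enables, with the target given as probabilities as the test above assumes:

```
import torch
import torch.nn.functional as F

log_probs = torch.log_softmax(torch.randn(3, 5), dim=1)
target = torch.softmax(torch.randn(3, 5), dim=1).requires_grad_()
loss = F.kl_div(log_probs, target, reduction='sum')
loss.backward()
print(target.grad.shape)  # torch.Size([3, 5])
```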

* Change the output format of caffe2 observers (#8261)

as title

* Remove TensorUtils<T>::getData, provide data<T>() in TH(C)Tensor. (#8247)

* Remove TensorUtils<T>::getData, provide data<T>() in TH(C)Tensor.

* Fix template parameter.

* [caffe2] Move submodule onnx-tensorrt forward (#7659)

Commit 82106f833dcb0070446a150e658e60ca9428f89b is essential.

* [ideep] Add IDEEP fallbacks for Faster-RCNN ops (#8260)

TSIA

* un-genericize THCDeviceTensorUtils. (#8258)

* provide data<T>() in TH(C)Tensor.

* un-genericize THCDeviceTensorUtils.

This is used outside of generic context, so we need to un-genericize it to have a single THCTensor type.

* [caffe2] Fix ATen dispatch for ops with TensorList arg (#8226)

* [cmake] Add and export Modules_CUDA_fix (#8271)

* Add and export Modules_CUDA_fix

* actually, need to include before finding cuda

* [auto] Update onnx to 2508156 - Make error message more verbose (onnx/onnx#1097)
https://github.com/onnx/onnx/commit/2508156135c67f2097aaac42153f641e55fd6c68

* [auto] Update onnx to 39e4668 - fix optimizer does not set ir_version bug (onnx/onnx#1098)
https://github.com/onnx/onnx/commit/39e46687eafd34c78dd59a1218171371aa3679f1

* [cmake] Make cudnn optional (#8265)

* Make cudnn optional

* Remove cudnn file from cpu file

* Move signal window functions to ATen; add Blackman window (#8130)

* Move signal window functions to ATen; add Blackman window

* fix cuda test not checking scipy
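
A usage sketch of the newly added window (the signature is assumed to mirror the other `torch.*_window` factories):

```
import torch

w = torch.blackman_window(8, periodic=False)  # symmetric 8-point Blackman window
print(w)
```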

* [ideep] Fuse Conv-Relu after IDEEP graph rewrite, skip group conv (#8233)

IDEEP supports fusion for non-group conv

* [c10d] NCCL Process Group implementation (#8182)

* [c10d] Process Group NCCL implementation

* Addressed comments

* Added one missing return and clang format again

* Use cmake/Modules for everything and fix gloo build

* Fixed compiler warnings

* Deleted duplicated FindNCCL

* Set up CI build for CUDA 9.2 + macOS (#8274)

* Add macOS CUDA build to CI

* Fix undefined symbols issue

* Use sccache for CUDA build

* Fix sccache issues

* clean up

* c10 build setup (#8264)

* Move c10/ to caffe2/dispatch/

* Set up caffe2/utils directory

* Remove remaining TensorTypeUtils functions. (#8286)

Mostly what's remaining is copy utilities -- these are now provided in THCTensorCopy.hpp and templatized on the ScalarType rather than the TensorType.

* Create initial Python bindings for c10d (#8119)

* Build and install c10d from tools/build_pytorch_libs.sh

* Create initial Python bindings for c10d

* clang-format

* Switch link order to include more symbols

* Add bindings and tests for ProcessGroupGloo

* Add broadcast test

* Separate build flag for c10d

* Explicit PIC property

* Skip c10d tests if not available

* Remove c10d from Windows blacklist

Let it skip by itself because it won't be available anyway.

* Make lint happy

* Comments

* Move c10d module into torch.distributed

* Close tempfile such that it is deleted

* Add option USE_NVRTC which defaults to off (#8289)

* [build] Remove /torch/lib/THD/cmake in favor of /cmake (#7159)

* Remove /torch/lib/THD/cmake in favor of /cmake

* path fix

* Explicitly marking gloo to use cuda

* Fix gloo path in THD

* Have a single THTensor / THCTensor type. (#8288)

* Remove remaining TensorTypeUtils functions.

Mostly what's remaining is copy utilities -- these are now provided in THCTensorCopy.hpp and templatized on the ScalarType rather than the TensorType.

* Have a single THTensor / THCTensor type.

As was previously done with Storages, have only a single (dtype-independent) THTensor / THCTensor.

For documentation and backwards compatibility purposes, the old names, e.g. TH(Cuda)LongTensor alias the new TH(C)Tensor type.

* undef GENERATE_SPARSE.

* [auto] Update onnx to 58efe0a - add float16 support back for math and reduction ops (onnx/onnx#1102)
https://github.com/onnx/onnx/commit/58efe0a9ca6228942d3f7e955babe44459343347

* Some utils for compile-time programming (#7778)

* Add some C++17 features, implemented with C++14

* Add some type traits

* Compile-time type list abstraction

* Some utils for compile-time programming

* Fix compatibility with a larger range of compilers

* Use guts::array instead of std::array because of std::array shortcomings

* code review comments

* Use quotes for includes

* Remove THC's FindMAGMA (#8299)

* Entries for torch.distributed in CODEOWNERS (#8293)

* Add depthwise convolution test for IDEEP (#8301)

* Fix dividing by zero segfault in Reshape (#8302)

when inferring a dimension of a new shape with zero size

* Removes unused THCTensorConv (#8229)

* Replace Variables with Tensors (#8309)

* Clean up old sccache log before build (#8305)

* Remove unused grad ops on mobile to reduce app size (#8297)

Remove unused grad ops on mobile to reduce app size

* Small fixes (#8296)

* [auto] Update onnx to 5ed684e - Remove/replace /MX with /WX for MSVC build. Was typo in a previous ch… (onnx/onnx#1104)
https://github.com/onnx/onnx/commit/5ed684ebe5fd2c1fa1b79aeb7bbacf2844a6cb01

* Fix sample code for cuda stream (#8319)

* [auto] Update onnx to 4b4085c - Add missing warning ignoring flags to onnx_proto CMake target (onnx/onnx#1105)
https://github.com/onnx/onnx/commit/4b4085c2e9d5a944651a2dd0dfdd20ef452bdcdf

* [THD] fix broken THD build with NCCL (#8323)

* Add docstring for `torch.sparse_coo_tensor` (#8152)

* add sparse_coo_tensor docstring

* update empty tensor example

* whitespace

* whitespace again
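
A short usage sketch of the documented constructor:

```
import torch

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])       # 2 x nnz indices
v = torch.tensor([3.0, 4.0, 5.0])   # nnz values
s = torch.sparse_coo_tensor(i, v, (2, 3))
print(s.to_dense())
# tensor([[0., 0., 3.],
#         [4., 0., 5.]])
```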

* add error when backend is not supported by DDP (#8325)

* Fix collect_env.py for Windows (#8326)

* Fix collect_env.py for Windows

* Fix expect file for Win machine

* Fix the script not stopping early on error for MSVC and Ninja (#8277)

* Simplify the solution

* Remove the usage of set errorlevel

* Skip test_multinomial_invalid_probs_cuda on Windows (#8324)

* Support printing sparse tensors in ATen, fixes #8333. (#8334)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* [C++ API] Cursors (#8190)

* Add cursors to C++ API

* Small self nits

* s/struct/class

* Use more STL like names for cursors

* Implement dim_arange operator (#8266)

* Implement arange_like operator

* add ONNX symbolic

* lint

* change name

* Comment the hack

* 1. fixed flip CPU impl for non-contiguous flip dims; 2. added more tests; 3. using TensorInfo and collapseDims to speed up CUDA impl for cases where the flip dim is the 1st or last dim

* nits

* 1. removed for loop in pointwise CUDA kernel; 2. using templated (int64_t) IndexType for indices in pointwise CUDA kernel

* added torch.flip.__doc__

* nits

@weiyangfb weiyangfb deleted the weiyangfb:flip_tensor branch Jun 22, 2018

@adam-dziedzic

adam-dziedzic commented Jul 3, 2018

Can you recreate such results:

>>> a
tensor([[1., 1., 1.],
        [1., 0., 2.]], dtype=torch.float64)
>>> torch.flip(a, [1])
tensor([[1., 1.],
        [0., 1.]], dtype=torch.float64)

?

@weiyangfb

Contributor Author

weiyangfb commented Jul 3, 2018

@adam-dziedzic Yes, I can reproduce your results. I think this is a bug. Let me create an issue for this.

@ashwhall

ashwhall commented Jul 24, 2018

@weiyangfb Does this operation copy the memory or give a view into it? I'm flipping HD video, so copying the data is a real memory-bottleneck.

@soumith

Member

soumith commented Jul 24, 2018

@ashwhall it copies the memory over, but does it pretty efficiently.
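
A minimal check of that copy behavior (a sketch, assuming a build that includes this PR):

```
import torch

x = torch.arange(6).view(2, 3)
y = x.flip(0)
print(y.data_ptr() != x.data_ptr())  # True: flip materializes new storage
y[0, 0] = 100
print(x[1, 0])                       # unchanged, so y is not a view of x
```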

facebook-github-bot added a commit that referenced this pull request Nov 8, 2018

fix flip() shape bug in CPU (#13344)
Summary:
- a workaround for #13292; a complete fix requires investigating the root cause when using advanced indexing
- this PR brings the `flip()` CUDA implementation approach to the CPU kernel
- with this change:
```
>>> t = torch.randn(1, 3, 4, 5)
>>> t.flip(1, 3).shape
torch.Size([1, 3, 4, 5])
```
- performance:
```
====== with this PR ======
>>> a = torch.randn(1000, 1000)
>>> %timeit -r 100 a.flip(0, 1)
1.98 ms ± 579 µs per loop (mean ± std. dev. of 100 runs, 1000 loops each)

====== Perf at previous PR #7873 ======
100 loops, best of 3: 11 ms per loop
```
Pull Request resolved: #13344

Differential Revision: D12968003

Pulled By: weiyangfb

fbshipit-source-id: 66f434049d143a0575a35b5c983b3e0577a1a28d

@soumith soumith referenced this pull request Dec 12, 2018

Closed

flip a Tensor #229
