Support the strided tensor on input for torch.cat (#46859)
Summary:
Pull Request resolved: #46859

In the current implementation, non-contiguous inputs fall back to the slow path. This change enables the fast path for non-contiguous inputs (up to 4 dimensions).
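
For illustration, here is the kind of call that previously fell back to the slow path: concatenating non-contiguous CUDA tensors produced by `transpose`. This is a minimal sketch using the libtorch C++ frontend, not code from this PR.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Two CUDA tensors made non-contiguous (strided) by transposing a view.
  auto a = torch::randn({512, 512}, torch::kCUDA).transpose(0, 1);
  auto b = torch::randn({512, 512}, torch::kCUDA).transpose(0, 1);
  std::cout << std::boolalpha << a.is_contiguous() << "\n";  // false

  // Before this change such inputs always took the slow path; with it,
  // strided inputs of up to 4 dimensions use the batched fast path.
  auto out = torch::cat({a, b}, /*dim=*/0);
  std::cout << out.sizes() << "\n";  // [1024, 512]
}
```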

Test Plan:
#benchmark

before
```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : all

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1,1,1)_N2_dim0_cuda
# Input: sizes: (1, 1, 1), N: 2, dim: 0, device: cuda
Forward Execution Time (us) : 17.126

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(512,512,2)_N2_dim1_cuda
# Input: sizes: (512, 512, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 20.652

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(128,1024,2)_N2_dim1_cuda
# Input: sizes: (128, 1024, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 20.412

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1024,1024,2)_N2_dim0_cuda
# Input: sizes: (1024, 1024, 2), N: 2, dim: 0, device: cuda
Forward Execution Time (us) : 48.265

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1025,1023,2)_N2_dim1_cuda
# Input: sizes: (1025, 1023, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 52.964

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1024,1024,2)_N2_dim2_cuda
# Input: sizes: (1024, 1024, 2), N: 2, dim: 2, device: cuda
Forward Execution Time (us) : 71.111

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f8a3cdc2440>,111,65]_N5_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f8a3cdc2440>, 111, 65], N: 5, dim: 0, device: cuda
Forward Execution Time (us) : 39.492

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[96,<function<lambda>at0x7f8a3cdc2b90>,64]_N5_dim1_cuda
# Input: sizes: [96, <function <lambda> at 0x7f8a3cdc2b90>, 64], N: 5, dim: 1, device: cuda
Forward Execution Time (us) : 31.596

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[128,64,<function<lambda>at0x7f880e7db3b0>]_N5_dim2_cuda
# Input: sizes: [128, 64, <function <lambda> at 0x7f880e7db3b0>], N: 5, dim: 2, device: cuda
Forward Execution Time (us) : 66.668

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f880e7db5f0>,32,64]_N50_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f880e7db5f0>, 32, 64], N: 50, dim: 0, device: cuda
Forward Execution Time (us) : 54.562

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[32,<function<lambda>at0x7f880e7db680>,64]_N50_dim1_cuda
# Input: sizes: [32, <function <lambda> at 0x7f880e7db680>, 64], N: 50, dim: 1, device: cuda
Forward Execution Time (us) : 53.255

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[33,65,<function<lambda>at0x7f880e7db710>]_N50_dim2_cuda
# Input: sizes: [33, 65, <function <lambda> at 0x7f880e7db710>], N: 50, dim: 2, device: cuda
Forward Execution Time (us) : 69.771

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(64,32,4,16,32)_N2_dim2_cuda
# Input: sizes: (64, 32, 4, 16, 32), N: 2, dim: 2, device: cuda
Forward Execution Time (us) : 98.438

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(16,32,4,16,32)_N8_dim2_cuda
# Input: sizes: (16, 32, 4, 16, 32), N: 8, dim: 2, device: cuda
Forward Execution Time (us) : 115.045

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(9,31,5,15,33)_N17_dim4_cuda
# Input: sizes: (9, 31, 5, 15, 33), N: 17, dim: 4, device: cuda
Forward Execution Time (us) : 476.497

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f880e7db7a0>]_N100_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f880e7db7a0>], N: 100, dim: 0, device: cuda
Forward Execution Time (us) : 86.307

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f880e7db830>]_N1000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f880e7db830>], N: 1000, dim: 0, device: cuda
Forward Execution Time (us) : 453.269

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f880e7db8c0>]_N2000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f880e7db8c0>], N: 2000, dim: 0, device: cuda
Forward Execution Time (us) : 935.365

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7f880e7db950>]_N3000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7f880e7db950>], N: 3000, dim: 0, device: cuda
Forward Execution Time (us) : 1355.937
```
after
```
WARNING:2020-11-01 21:14:23 3332963:3336757 EventProfilerController.cpp:143] (x1) Lost sample due to delays (ms): 488, 11, 4121, 0
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : all

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1,1,1)_N2_dim0_cuda
# Input: sizes: (1, 1, 1), N: 2, dim: 0, device: cuda
Forward Execution Time (us) : 17.174

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(512,512,2)_N2_dim1_cuda
# Input: sizes: (512, 512, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 20.399

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(128,1024,2)_N2_dim1_cuda
# Input: sizes: (128, 1024, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 23.349

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1024,1024,2)_N2_dim0_cuda
# Input: sizes: (1024, 1024, 2), N: 2, dim: 0, device: cuda
Forward Execution Time (us) : 47.847

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1025,1023,2)_N2_dim1_cuda
# Input: sizes: (1025, 1023, 2), N: 2, dim: 1, device: cuda
Forward Execution Time (us) : 53.463

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(1024,1024,2)_N2_dim2_cuda
# Input: sizes: (1024, 1024, 2), N: 2, dim: 2, device: cuda
Forward Execution Time (us) : 72.789

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7fd5b5567710>,111,65]_N5_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7fd5b5567710>, 111, 65], N: 5, dim: 0, device: cuda
Forward Execution Time (us) : 39.747

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[96,<function<lambda>at0x7fd5b56b1320>,64]_N5_dim1_cuda
# Input: sizes: [96, <function <lambda> at 0x7fd5b56b1320>, 64], N: 5, dim: 1, device: cuda
Forward Execution Time (us) : 31.814

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[128,64,<function<lambda>at0x7fd3a2289680>]_N5_dim2_cuda
# Input: sizes: [128, 64, <function <lambda> at 0x7fd3a2289680>], N: 5, dim: 2, device: cuda
Forward Execution Time (us) : 67.202

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7fd3a2289710>,32,64]_N50_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7fd3a2289710>, 32, 64], N: 50, dim: 0, device: cuda
Forward Execution Time (us) : 65.229

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[32,<function<lambda>at0x7fd3a22897a0>,64]_N50_dim1_cuda
# Input: sizes: [32, <function <lambda> at 0x7fd3a22897a0>, 64], N: 50, dim: 1, device: cuda
Forward Execution Time (us) : 60.843

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[33,65,<function<lambda>at0x7fd3a2289830>]_N50_dim2_cuda
# Input: sizes: [33, 65, <function <lambda> at 0x7fd3a2289830>], N: 50, dim: 2, device: cuda
Forward Execution Time (us) : 69.756

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(64,32,4,16,32)_N2_dim2_cuda
# Input: sizes: (64, 32, 4, 16, 32), N: 2, dim: 2, device: cuda
Forward Execution Time (us) : 98.222

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(16,32,4,16,32)_N8_dim2_cuda
# Input: sizes: (16, 32, 4, 16, 32), N: 8, dim: 2, device: cuda
Forward Execution Time (us) : 112.521

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes(9,31,5,15,33)_N17_dim4_cuda
# Input: sizes: (9, 31, 5, 15, 33), N: 17, dim: 4, device: cuda
Forward Execution Time (us) : 477.736

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7fd3a22898c0>]_N100_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7fd3a22898c0>], N: 100, dim: 0, device: cuda
Forward Execution Time (us) : 50.617

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7fd3a2289950>]_N1000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7fd3a2289950>], N: 1000, dim: 0, device: cuda
Forward Execution Time (us) : 461.631

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7fd3a22899e0>]_N2000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7fd3a22899e0>], N: 2000, dim: 0, device: cuda
Forward Execution Time (us) : 840.469

# Benchmarking PyTorch: cat
# Mode: Eager
# Name: cat_sizes[<function<lambda>at0x7fd3a2289a70>]_N3000_dim0_cuda
# Input: sizes: [<function <lambda> at 0x7fd3a2289a70>], N: 3000, dim: 0, device: cuda
Forward Execution Time (us) : 1317.866
```
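
Beyond the timings above, a quick correctness sanity check is to compare the strided-input result against the same concatenation on contiguous copies. This sketch is my own illustration (libtorch C++ frontend), not part of the PR's test plan.

```cpp
#include <torch/torch.h>
#include <cassert>

int main() {
  // Non-contiguous CUDA inputs created by permuting dimensions.
  auto a = torch::randn({8, 16, 32}, torch::kCUDA).permute({2, 0, 1});
  auto b = torch::randn({8, 16, 32}, torch::kCUDA).permute({2, 0, 1});

  // Fast path on strided inputs vs. reference on contiguous copies.
  auto strided   = torch::cat({a, b}, /*dim=*/1);
  auto reference = torch::cat({a.contiguous(), b.contiguous()}, /*dim=*/1);

  assert(torch::allclose(strided, reference));
}
```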

Reviewed By: ngimel

Differential Revision: D24527676

fbshipit-source-id: 04d8efd89d7856fb45ce6edd8c105a5f5b218135
lly-zero-one authored and facebook-github-bot committed Nov 7, 2020
1 parent 6e69a24 commit 5b894af
Showing 1 changed file with 90 additions and 47 deletions.
137 changes: 90 additions & 47 deletions aten/src/ATen/native/cuda/Shape.cu
@@ -36,8 +36,8 @@ inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) {
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
static inline __device__ IndexType compute(
const IndexType outputSize[Dims],
const IndexType outputStride[Dims],
const IndexType tensorSize[Dims],
const IndexType tensorStride[Dims],
const IndexType dimSize,
const unsigned int concatDim,
IndexType linearIndex) {
@@ -49,22 +49,22 @@ struct CatArrIndexToOffset {

#pragma unroll
for (int i = Dims - 1; i >= 1; --i) {
IndexType curDimSize = i == concatDim ? dimSize : outputSize[i];
IndexType curDimSize = i == concatDim ? dimSize : tensorSize[i];
IndexType nextDimIndex = linearIndex / curDimSize;
IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex;
IndexType curDimOffset = curDimIndex * outputStride[i];
IndexType curDimOffset = curDimIndex * tensorStride[i];
offset += curDimOffset;
linearIndex = nextDimIndex;
}

return offset + linearIndex * outputStride[0];
return offset + linearIndex * tensorStride[0];
}
};

template<typename IndexType, unsigned int MaxDims>
struct OutputTensorSizeStride {
IndexType outputSize[MaxDims];
IndexType outputStride[MaxDims];
struct TensorSizeStride {
IndexType tensorSize[MaxDims];
IndexType tensorStride[MaxDims];
};

/**
Expand Down Expand Up @@ -97,7 +97,7 @@ C10_LAUNCH_BOUNDS_1(512)
__global__ void HIP_CatArrayBatchedCopy(
T* output,
CatArrInputTensor<T, IndexType>* inputs,
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
const int concatDim,
IndexType dimStride) {

@@ -115,32 +115,38 @@ __global__ void HIP_CatArrayBatchedCopy(

while( tid < nElements){
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.outputSize, os.outputStride, dimSize, concatDim, tid);
os.tensorSize, os.tensorStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[tid];

tid += stride;
}
}

// pass meta data directly through kernel argument instead of pin memory
template <typename T, typename IndexType, int n>
// In the contiguous case we do not need stride_size; set it to 1 as a placeholder
// so the code compiles.
template <typename T, typename IndexType, int n, int stride_size>
struct CatArrInputTensorMetadata {
T* input[n];
IndexType offset[n];
IndexType dimSize[n];
IndexType nElements[n];
bool isContiguous[n];
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> tensorStride[stride_size];
};

template <typename T, typename IndexType, int Dims>
template <typename T, typename IndexType, int Dims, int batch_size, int stride_size>
__global__ void CatArrayBatchedCopy(
T* output,
CatArrInputTensorMetadata<T, IndexType, CAT_ARRAY_BATCH_SIZE> inputs,
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
const int concatDim,
IndexType dimStride) {

IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
IndexType nElements = inputs.nElements[blockIdx.y];
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> ins = stride_size > 1 ? inputs.tensorStride[blockIdx.y] : inputs.tensorStride[0];
bool isContig = inputs.isContiguous[blockIdx.y];

if(tid >= nElements) return;

@@ -152,10 +158,15 @@ __global__ void CatArrayBatchedCopy(
IndexType stride = gridDim.x * blockDim.x;

while( tid < nElements){
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.outputSize, os.outputStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[tid];

IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.tensorSize, os.tensorStride, dimSize, concatDim, tid);
if (isContig) {
output[dataOffset + elementOffset] = data[tid];
} else {
IndexType inElementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
ins.tensorSize, ins.tensorStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[inElementOffset];
}
tid += stride;
}
}
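
The new `isContig` branch is the heart of the change: for a contiguous input the linear index is also the memory offset, while for a strided input the same logical index is mapped through the input's own sizes and strides before reading. A standalone sketch of that non-contiguous read, written out for two dimensions (my own illustration, not PR code):

```cpp
#include <cassert>
#include <vector>

int main() {
  // The "input" is a 2x3 view over a 3x2 row-major buffer (a transpose),
  // so its strides are {1, 2} instead of the contiguous {3, 1}.
  std::vector<float> buffer = {0, 1, 2, 3, 4, 5};   // 3x2, row-major
  const unsigned int inSize[2]   = {2, 3};          // logical (transposed) shape
  const unsigned int inStride[2] = {1, 2};          // strides of the transposed view

  std::vector<float> gathered(6);
  for (unsigned int tid = 0; tid < 6; ++tid) {
    // Same unravel-and-restride computation as CatArrIndexToOffset, Dims = 2.
    unsigned int row = tid / inSize[1];
    unsigned int col = tid % inSize[1];
    unsigned int inElementOffset = row * inStride[0] + col * inStride[1];
    gathered[tid] = buffer[inElementOffset];        // the non-contiguous read
  }
  // The transpose of [[0,1],[2,3],[4,5]] is [[0,2,4],[1,3,5]].
  assert((gathered == std::vector<float>{0, 2, 4, 1, 3, 5}));
  return 0;
}
```
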
@@ -197,25 +208,25 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
auto d_inputs = static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
d_inputs_storage.data_ptr());

OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

// Next, let's initialize the size, stride arrays for the output Tensor.
if (memory_format == c10::MemoryFormat::Contiguous) {
for (int i = 0; i < nDims; ++i) {
param.outputSize[i] = at::native::size(out, i);
param.outputStride[i] = out.stride(i);
outputParam.tensorSize[i] = at::native::size(out, i);
outputParam.tensorStride[i] = out.stride(i);
}
} else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
// permute the semantics of dims from NCHW to NHWC so that the input
// tensor is now contiguous
param.outputSize[0] = at::native::size(out, 0);
param.outputStride[0] = out.stride(0);
outputParam.tensorSize[0] = at::native::size(out, 0);
outputParam.tensorStride[0] = out.stride(0);
for (int i = 1; i < nDims - 1; ++i) {
param.outputSize[i] = at::native::size(out, i + 1);
param.outputStride[i] = out.stride(i + 1);
outputParam.tensorSize[i] = at::native::size(out, i + 1);
outputParam.tensorStride[i] = out.stride(i + 1);
}
param.outputSize[nDims - 1] = at::native::size(out, 1);
param.outputStride[nDims - 1] = out.stride(1);
outputParam.tensorSize[nDims - 1] = at::native::size(out, 1);
outputParam.tensorStride[nDims - 1] = out.stride(1);
} else {
TORCH_CHECK(false, "unsupported memory format");
}
@@ -283,7 +294,7 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
#define HANDLE_CASE(DIMS) \
HIP_CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, d_inputs, param, dimension, param.outputStride[dimension]);
data, d_inputs, outputParam, dimension, outputParam.tensorStride[dimension]);
switch (nDims) {
case 1:
HANDLE_CASE(1);
@@ -303,32 +314,32 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
}
}

template <typename scalar_t>
template <typename scalar_t, int batch_size, int stride_size>
void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
int nDims, c10::MemoryFormat memory_format) {
// First, let's set up our kernel parameters. We start with a raw pointer to
// the storage for the output Tensor.
scalar_t *data = out.data_ptr<scalar_t>();
CatArrInputTensorMetadata<scalar_t, unsigned int, CAT_ARRAY_BATCH_SIZE> catMetaData;
OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;
CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

// Next, let's initialize the size, stride arrays for the output Tensor.
if (memory_format == c10::MemoryFormat::Contiguous) {
for (int i = 0; i < nDims; ++i) {
param.outputSize[i] = at::native::size(out, i);
param.outputStride[i] = out.stride(i);
outputParam.tensorSize[i] = at::native::size(out, i);
outputParam.tensorStride[i] = out.stride(i);
}
} else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
// permute the semantics of dims from NCHW to NHWC so that the input
// tensor is now contiguous
param.outputSize[0] = at::native::size(out, 0);
param.outputStride[0] = out.stride(0);
outputParam.tensorSize[0] = at::native::size(out, 0);
outputParam.tensorStride[0] = out.stride(0);
for (int i = 1; i < nDims - 1; ++i) {
param.outputSize[i] = at::native::size(out, i + 1);
param.outputStride[i] = out.stride(i + 1);
outputParam.tensorSize[i] = at::native::size(out, i + 1);
outputParam.tensorStride[i] = out.stride(i + 1);
}
param.outputSize[nDims - 1] = at::native::size(out, 1);
param.outputStride[nDims - 1] = out.stride(1);
outputParam.tensorSize[nDims - 1] = at::native::size(out, 1);
outputParam.tensorStride[nDims - 1] = out.stride(1);
} else {
TORCH_CHECK(false, "unsupported memory format");
}
@@ -338,9 +349,9 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
// Now we loop
int batchCounter = 0;
int64_t offset = 0;
for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) {
for (int i = 0; i < inputs.size() ; i += batch_size) {
for (batchCounter = 0;
batchCounter < CAT_ARRAY_BATCH_SIZE &&
batchCounter < batch_size &&
(i+batchCounter) < inputs.size();
++batchCounter) {
int64_t dimSize = 0;
@@ -353,7 +364,17 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
catMetaData.offset[batchCounter] = offset;
catMetaData.dimSize[batchCounter] = dimSize;
catMetaData.nElements[batchCounter] = inputs[i+batchCounter].numel();

if (stride_size > 1) {
auto strides = inputs[i+batchCounter].strides();
auto sizes = inputs[i+batchCounter].sizes();
for(int j = 0; j < nDims; j++){
catMetaData.tensorStride[batchCounter].tensorSize[j] = sizes[j];
catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j];
}
catMetaData.isContiguous[batchCounter] = false;
} else {
catMetaData.isContiguous[batchCounter] = true;
}
// update offset
offset += dimSize;
}
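
The per-input size/stride arrays recorded here are what drive the smaller batch used for the strided path later in this file: the metadata struct is passed to the kernel by value through constant-memory-backed argument space, which is limited in size. A rough standalone check (my own sketch; the values 128 for `CAT_ARRAY_BATCH_SIZE`, 4 for `CAT_ARRAY_MAX_INPUT_DIMS`, and the ~4 KB kernel-argument limit are assumptions, not stated in this diff):

```cpp
#include <cstdio>

constexpr int kMaxDims = 4;  // assumed CAT_ARRAY_MAX_INPUT_DIMS

template <typename IndexType>
struct SizeStride { IndexType size[kMaxDims]; IndexType stride[kMaxDims]; };

// Same field layout as CatArrInputTensorMetadata, for size accounting only.
template <typename T, typename IndexType, int n, int stride_size>
struct MetadataSketch {
  T* input[n];
  IndexType offset[n];
  IndexType dimSize[n];
  IndexType nElements[n];
  bool isContiguous[n];
  SizeStride<IndexType> tensorStride[stride_size];
};

int main() {
  printf("contiguous (128, 1):   %zu bytes\n", sizeof(MetadataSketch<float, unsigned int, 128, 1>));
  printf("strided    (64, 64):   %zu bytes\n", sizeof(MetadataSketch<float, unsigned int, 64, 64>));
  printf("strided    (128, 128): %zu bytes\n", sizeof(MetadataSketch<float, unsigned int, 128, 128>));
}
```

Under those assumptions the full 128-tensor batch with per-tensor strides would not fit in the argument space, while halving both template parameters keeps the struct under the limit.
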
@@ -381,9 +402,9 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
}
// Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, param, dimension, param.outputStride[dimension]);
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);
switch (nDims) {
case 1:
HANDLE_CASE(1);
@@ -527,24 +548,46 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
return t.scalar_type() == firstType;
});
allSameType = allSameType && (out.scalar_type() == firstType);

#ifdef __HIP_PLATFORM_HCC__
if (inputs.size() > 1 &&
out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(out) &&
allContiguous &&
all32BitIndexable &&
allSameType) {

#ifdef __HIP_PLATFORM_HCC__
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
hip_parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
});
#else
// Contiguous inputs and non-contiguous inputs (<= 4 dims) are handled differently.
// For contiguous inputs we do not need to pass stride metadata to the CUDA kernel through
// constant memory, so more inputs can be passed to the kernel per launch.
// For non-contiguous inputs we reduce the number of inputs per launch because of the
// constant memory limit.
if (inputs.size() > 1 &&
out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(out) &&
allContiguous &&
all32BitIndexable &&
allSameType) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE, 1>(out, inputs, dimension, nDims, memory_format);
});
} else if (inputs.size() > 1 &&
out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(out) &&
nDims <= CAT_ARRAY_MAX_INPUT_DIMS &&
all32BitIndexable &&
allSameType) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(out, inputs, dimension, nDims, memory_format);
});
#endif
} else {
