Support the stride tensor on input for torch.cat #46859

Closed · wants to merge 1 commit
aten/src/ATen/native/cuda/Shape.cu (137 changes: 90 additions, 47 deletions)
@@ -36,8 +36,8 @@ inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) {
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
static inline __device__ IndexType compute(
const IndexType outputSize[Dims],
const IndexType outputStride[Dims],
const IndexType tensorSize[Dims],
const IndexType tensorStride[Dims],
const IndexType dimSize,
const unsigned int concatDim,
IndexType linearIndex) {
@@ -49,22 +49,22 @@ struct CatArrIndexToOffset {

#pragma unroll
for (int i = Dims - 1; i >= 1; --i) {
IndexType curDimSize = i == concatDim ? dimSize : outputSize[i];
IndexType curDimSize = i == concatDim ? dimSize : tensorSize[i];
IndexType nextDimIndex = linearIndex / curDimSize;
IndexType curDimIndex = linearIndex - curDimSize * nextDimIndex;
IndexType curDimOffset = curDimIndex * outputStride[i];
IndexType curDimOffset = curDimIndex * tensorStride[i];
offset += curDimOffset;
linearIndex = nextDimIndex;
}

return offset + linearIndex * outputStride[0];
return offset + linearIndex * tensorStride[0];
}
};

template<typename IndexType, unsigned int MaxDims>
struct OutputTensorSizeStride {
IndexType outputSize[MaxDims];
IndexType outputStride[MaxDims];
struct TensorSizeStride {
IndexType tensorSize[MaxDims];
IndexType tensorStride[MaxDims];
};
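Illustration, not part of the diff: the renamed TensorSizeStride bundle is consumed by CatArrIndexToOffset::compute, which peels a linear element index into per-dimension indices and dots them with the strides. A minimal host-side C++ sketch of that same decomposition (hypothetical helper, written here only to show the arithmetic):

#include <cstdint>

// Host-side mirror of CatArrIndexToOffset::compute.
template <int Dims>
uint32_t index_to_offset(const uint32_t size[Dims],
                         const uint32_t stride[Dims],
                         uint32_t dimSize,        // extent of the concat dim for this input
                         unsigned int concatDim,
                         uint32_t linearIndex) {
  uint32_t offset = 0;
  for (int i = Dims - 1; i >= 1; --i) {
    // Along the concat dimension the index space uses dimSize (the current
    // input's extent), which matters when size/stride describe the larger output.
    uint32_t curDimSize = (i == concatDim) ? dimSize : size[i];
    uint32_t nextDimIndex = linearIndex / curDimSize;
    uint32_t curDimIndex = linearIndex - curDimSize * nextDimIndex;
    offset += curDimIndex * stride[i];
    linearIndex = nextDimIndex;
  }
  return offset + linearIndex * stride[0];
}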

/**
@@ -97,7 +97,7 @@ C10_LAUNCH_BOUNDS_1(512)
__global__ void HIP_CatArrayBatchedCopy(
T* output,
CatArrInputTensor<T, IndexType>* inputs,
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
const int concatDim,
IndexType dimStride) {

@@ -115,32 +115,38 @@ __global__ void HIP_CatArrayBatchedCopy(

while( tid < nElements){
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.outputSize, os.outputStride, dimSize, concatDim, tid);
os.tensorSize, os.tensorStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[tid];

tid += stride;
}
}

// Pass metadata directly through kernel arguments instead of pinned memory.
template <typename T, typename IndexType, int n>
// In the contiguous case stride_size is not needed; it is set to 1 as a
// placeholder so the template still compiles.
template <typename T, typename IndexType, int n, int stride_size>
struct CatArrInputTensorMetadata {
T* input[n];
IndexType offset[n];
IndexType dimSize[n];
IndexType nElements[n];
bool isContiguous[n];
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> tensorStride[stride_size];
};
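Illustration, not part of the diff: CatArrInputTensorMetadata is passed by value as a kernel argument, so its footprint is bounded by the kernel-parameter/constant-memory budget the comments above refer to. A rough sizing sketch (assumptions: IndexType = unsigned int, 64-bit pointers, CAT_ARRAY_MAX_INPUT_DIMS = 4; the exact limit depends on the CUDA toolchain):

#include <cstdio>

int main() {
  constexpr int kMaxDims = 4;                               // CAT_ARRAY_MAX_INPUT_DIMS (assumed)
  constexpr size_t perInputBase = sizeof(void*)             // input pointer
                                + 3 * sizeof(unsigned int)  // offset, dimSize, nElements
                                + sizeof(bool);             // isContiguous
  constexpr size_t perInputStrides = 2 * kMaxDims * sizeof(unsigned int);  // tensorSize + tensorStride
  std::printf("contiguous batch: ~%zu bytes/input\n", perInputBase);
  std::printf("strided batch:    ~%zu bytes/input\n", perInputBase + perInputStrides);
  // Carrying per-input size/stride arrays more than doubles the per-input cost,
  // which is why the non-contiguous dispatch later in this file uses a smaller batch.
}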

template <typename T, typename IndexType, int Dims>
template <typename T, typename IndexType, int Dims, int batch_size, int stride_size>
__global__ void CatArrayBatchedCopy(
T* output,
CatArrInputTensorMetadata<T, IndexType, CAT_ARRAY_BATCH_SIZE> inputs,
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
const int concatDim,
IndexType dimStride) {

IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
IndexType nElements = inputs.nElements[blockIdx.y];
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> ins = stride_size > 1 ? inputs.tensorStride[blockIdx.y] : inputs.tensorStride[0];
bool isContig = inputs.isContiguous[blockIdx.y];

if(tid >= nElements) return;

Expand All @@ -152,10 +158,15 @@ __global__ void CatArrayBatchedCopy(
IndexType stride = gridDim.x * blockDim.x;

while( tid < nElements){
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.outputSize, os.outputStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[tid];

IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.tensorSize, os.tensorStride, dimSize, concatDim, tid);
if (isContig) {
output[dataOffset + elementOffset] = data[tid];
} else {
IndexType inElementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
ins.tensorSize, ins.tensorStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[inElementOffset];
}
tid += stride;
}
}
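Illustration, not part of the diff: for a non-contiguous input the copy maps the same linear index twice, once through the output's size/stride to find where to write and once through the input's own size/stride to find where to read. A serial host-side sketch of that loop, reusing the hypothetical index_to_offset helper from the sketch above:

#include <cstdint>

// What one blockIdx.y slot of CatArrayBatchedCopy does, serially.
void copy_one_input(float* output, const float* data,
                    uint32_t dataOffset,                        // base offset of this input's slab in the output
                    const uint32_t outSize[2], const uint32_t outStride[2],
                    const uint32_t inSize[2], const uint32_t inStride[2],
                    uint32_t dimSize, unsigned int concatDim,
                    uint32_t nElements, bool isContiguous) {
  for (uint32_t tid = 0; tid < nElements; ++tid) {
    uint32_t outOffset = index_to_offset<2>(outSize, outStride, dimSize, concatDim, tid);
    uint32_t inOffset = isContiguous
        ? tid                                                   // contiguous: read linearly
        : index_to_offset<2>(inSize, inStride, dimSize, concatDim, tid);
    output[dataOffset + outOffset] = data[inOffset];
  }
}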
@@ -197,25 +208,25 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
auto d_inputs = static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
d_inputs_storage.data_ptr());

OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

// Next, let's initialize the size, stride arrays for the output Tensor.
if (memory_format == c10::MemoryFormat::Contiguous) {
for (int i = 0; i < nDims; ++i) {
param.outputSize[i] = at::native::size(out, i);
param.outputStride[i] = out.stride(i);
outputParam.tensorSize[i] = at::native::size(out, i);
outputParam.tensorStride[i] = out.stride(i);
}
} else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
// permute the semantics of dims from NCHW to NHWC so that the input
// tensor is now contiguous
param.outputSize[0] = at::native::size(out, 0);
param.outputStride[0] = out.stride(0);
outputParam.tensorSize[0] = at::native::size(out, 0);
outputParam.tensorStride[0] = out.stride(0);
for (int i = 1; i < nDims - 1; ++i) {
param.outputSize[i] = at::native::size(out, i + 1);
param.outputStride[i] = out.stride(i + 1);
outputParam.tensorSize[i] = at::native::size(out, i + 1);
outputParam.tensorStride[i] = out.stride(i + 1);
}
param.outputSize[nDims - 1] = at::native::size(out, 1);
param.outputStride[nDims - 1] = out.stride(1);
outputParam.tensorSize[nDims - 1] = at::native::size(out, 1);
outputParam.tensorStride[nDims - 1] = out.stride(1);
} else {
TORCH_CHECK(false, "unsupported memory format");
}
@@ -283,7 +294,7 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
#define HANDLE_CASE(DIMS) \
HIP_CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, d_inputs, param, dimension, param.outputStride[dimension]);
data, d_inputs, outputParam, dimension, outputParam.tensorStride[dimension]);
switch (nDims) {
case 1:
HANDLE_CASE(1);
@@ -303,32 +314,32 @@ void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
}
}

template <typename scalar_t>
template <typename scalar_t, int batch_size, int stride_size>
void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
int nDims, c10::MemoryFormat memory_format) {
// First, let's set up our kernel parameters. We start with a raw pointer to
// the storage for the output Tensor.
scalar_t *data = out.data_ptr<scalar_t>();
CatArrInputTensorMetadata<scalar_t, unsigned int, CAT_ARRAY_BATCH_SIZE> catMetaData;
OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;
CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

// Next, let's initialize the size, stride arrays for the output Tensor.
if (memory_format == c10::MemoryFormat::Contiguous) {
for (int i = 0; i < nDims; ++i) {
param.outputSize[i] = at::native::size(out, i);
param.outputStride[i] = out.stride(i);
outputParam.tensorSize[i] = at::native::size(out, i);
outputParam.tensorStride[i] = out.stride(i);
}
} else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
// permute the semantics of dims from NCHW to NHWC so that the input
// tensor is now contiguous
param.outputSize[0] = at::native::size(out, 0);
param.outputStride[0] = out.stride(0);
outputParam.tensorSize[0] = at::native::size(out, 0);
outputParam.tensorStride[0] = out.stride(0);
for (int i = 1; i < nDims - 1; ++i) {
param.outputSize[i] = at::native::size(out, i + 1);
param.outputStride[i] = out.stride(i + 1);
outputParam.tensorSize[i] = at::native::size(out, i + 1);
outputParam.tensorStride[i] = out.stride(i + 1);
}
param.outputSize[nDims - 1] = at::native::size(out, 1);
param.outputStride[nDims - 1] = out.stride(1);
outputParam.tensorSize[nDims - 1] = at::native::size(out, 1);
outputParam.tensorStride[nDims - 1] = out.stride(1);
} else {
TORCH_CHECK(false, "unsupported memory format");
}
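Illustration, not part of the diff: the channels-last branch above stores the output's sizes/strides in NHWC order so the generic index math sees a contiguous layout. A small standalone sketch with concrete, hypothetical numbers:

#include <cstdint>

// Re-order size/stride arrays from logical NCHW order to NHWC storage order,
// mirroring the loop in parallel_cat above (nDims fixed to 4 for the sketch).
void permute_nchw_to_nhwc(const uint32_t sizeIn[4], const uint32_t strideIn[4],
                          uint32_t sizeOut[4], uint32_t strideOut[4]) {
  const int nDims = 4;
  sizeOut[0] = sizeIn[0];                 // N stays first
  strideOut[0] = strideIn[0];
  for (int i = 1; i < nDims - 1; ++i) {   // H and W shift down by one slot
    sizeOut[i] = sizeIn[i + 1];
    strideOut[i] = strideIn[i + 1];
  }
  sizeOut[nDims - 1] = sizeIn[1];         // C goes last
  strideOut[nDims - 1] = strideIn[1];
}
// Example: a channels-last tensor with N=2, C=3, H=4, W=5 has
//   NCHW view:  size {2, 3, 4, 5}, stride {60, 1, 15, 3}
// and after the permutation
//   NHWC view:  size {2, 4, 5, 3}, stride {60, 15, 3, 1}   <- row-major contiguous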
@@ -338,9 +349,9 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
// Now we loop
int batchCounter = 0;
int64_t offset = 0;
for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) {
for (int i = 0; i < inputs.size() ; i += batch_size) {
for (batchCounter = 0;
batchCounter < CAT_ARRAY_BATCH_SIZE &&
batchCounter < batch_size &&
(i+batchCounter) < inputs.size();
++batchCounter) {
int64_t dimSize = 0;
Expand All @@ -353,7 +364,17 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
catMetaData.offset[batchCounter] = offset;
catMetaData.dimSize[batchCounter] = dimSize;
catMetaData.nElements[batchCounter] = inputs[i+batchCounter].numel();

if (stride_size > 1) {
auto strides = inputs[i+batchCounter].strides();
auto sizes = inputs[i+batchCounter].sizes();
for(int j = 0; j < nDims; j++){
catMetaData.tensorStride[batchCounter].tensorSize[j] = sizes[j];
catMetaData.tensorStride[batchCounter].tensorStride[j] = strides[j];
}
catMetaData.isContiguous[batchCounter] = false;
} else {
catMetaData.isContiguous[batchCounter] = true;
}
// update offset
offset += dimSize;
}
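Illustration, not part of the diff: the surrounding loops launch one kernel per chunk of batch_size inputs, and each chunk's metadata (plus, when stride_size > 1, per-input size/stride arrays) travels in a single CatArrInputTensorMetadata argument. A tiny sketch of the chunking with hypothetical numbers:

#include <algorithm>
#include <cstdio>

int main() {
  const int batch_size = 64;    // assumed value, for illustration only
  const int num_inputs = 150;
  for (int i = 0; i < num_inputs; i += batch_size) {
    int batchCounter = std::min(batch_size, num_inputs - i);
    std::printf("launch covering inputs [%d, %d)\n", i, i + batchCounter);
  }
  // Prints [0, 64), [64, 128), [128, 150): three launches for 150 inputs.
}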
@@ -381,9 +402,9 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
}
// Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, catMetaData, param, dimension, param.outputStride[dimension]);
data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);
switch (nDims) {
case 1:
HANDLE_CASE(1);
@@ -527,24 +548,46 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
return t.scalar_type() == firstType;
});
allSameType = allSameType && (out.scalar_type() == firstType);

#ifdef __HIP_PLATFORM_HCC__
if (inputs.size() > 1 &&
out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(out) &&
allContiguous &&
all32BitIndexable &&
allSameType) {

#ifdef __HIP_PLATFORM_HCC__
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
hip_parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
});
#else
// Contiguous inputs and non-contiguous inputs (<= 4 dims) are handled differently.
// For contiguous inputs we do not need to pass stride metadata to the CUDA kernel
// through constant memory, so more inputs can be packed into a single launch.
// For non-contiguous inputs we reduce the number of inputs per launch because of
// the constant-memory limit.
if (inputs.size() > 1 &&
out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(out) &&
allContiguous &&
all32BitIndexable &&
allSameType) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE, 1>(out, inputs, dimension, nDims, memory_format);
});
} else if (inputs.size() > 1 &&
out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(out) &&
nDims <= CAT_ARRAY_MAX_INPUT_DIMS &&
all32BitIndexable &&
allSameType) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(out, inputs, dimension, nDims, memory_format);
});
#endif
} else {
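Illustration, not part of the diff: with this change a cat over non-contiguous CUDA inputs (for example transposed views) can take the batched kernel instead of falling back to the generic copy path, using each input's own size/stride metadata. A usage sketch in ATen C++, illustrative only:

#include <ATen/ATen.h>

at::Tensor cat_transposed_views() {
  // Both inputs are 8x4 non-contiguous views of 4x8 CUDA tensors.
  auto a = at::randn({4, 8}, at::kCUDA).t();
  auto b = at::randn({4, 8}, at::kCUDA).t();
  // 2-D, same dtype, 32-bit indexable: eligible for the strided parallel_cat branch above.
  return at::cat({a, b}, /*dim=*/0);  // 16x4 result
}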