🚀 Feature
Implement a function similar to the C2 LengthsSum operator in PyTorch. Currently, workloads that need it implement it inefficiently by calling the c2 wrapper. Description:
"""
Aggregates the first dimension of a tensor by individual segments of
specified length.
Category: Length Conversion
Call Args:
data(BlobReference|Tensor): A tensor to be aggregated
lengths(BlobReference|Tensor): A 1D tensor of lengths.
Returns:
output(BlobReference|Tensor): A tensor with first dimension len(lengths)
where the ith row in the first dimension is the sum of the ith segment.
Example:
>>> import numpy as np
>>> import dper3.modules as modules
>>> from dper3.core.environment import Caffe2EagerEnv, DperEnv
>>> from dper3.core.train_utils import as_blob, blob_to_numpy
>>> Caffe2EagerEnv().enable()
>>> data = as_blob(np.array([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=np.float32))
>>> lengths = as_blob(np.array([1, 1, 2], dtype=np.int32))
>>> convert = modules.LengthsSum()
>>> output = convert(data, lengths)
>>> blob_to_numpy(output)
array([[1., 1.],
[2., 2.],
[7., 7.]], dtype=float32)
>>> DperEnv.reset()
"""
Pointer to the c2 CUDA implementation (note that SparseFused is false in the template instantiation):
pytorch/caffe2/operators/segment_reduction_op_gpu.cu
Lines 370 to 503 in e429d05
template <typename T, class Context = CUDAContext, bool SparseFused = true>
class CUDASparseLengthsSumOp : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit CUDASparseLengthsSumOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...) {}
  ~CUDASparseLengthsSumOp() {}
  bool RunOnDevice() override {
    if (SparseFused) {
      return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
          this, Input(INDICES));
    } else {
      // type doesn't matter
      return DoRunWithType<int32_t>();
    }
  }
  template <typename IndexType>
  bool DoRunWithType() {
    if (SparseFused) {
      return DispatchHelper<TensorTypes2<float, at::Half>, IndexType>::call(
          this, Input(DATA));
    } else {
      return DoRunWithType2<IndexType, T>();
    }
  }
  template <typename IndexType, typename InType>
  bool DoRunWithType2() {
    auto& dataInput = Input(DATA);
    auto& lengthsInput = Input(LENGTHS);
    CAFFE_ENFORCE_EQ(1, lengthsInput.dim(), "LENGTHS must be a vector");
    const int64_t dataSize = dataInput.dim(0);
    // Either first dim the data or how much we pull in indexies from it
    int64_t dataToReduceSize;
    const int64_t outputSize = lengthsInput.dim(0);
    const int len_length = outputSize;
    auto shape = dataInput.sizes().vec();
    shape[0] = outputSize;
    auto* output = Output(0, shape, at::dtype<T>());
    T* out_data = output->template mutable_data<T>();
    if (len_length <= 0) {
      // return early to avoid invalid empty kernel
      return true;
    }
    const IndexType* indices;
    if (SparseFused) { // static if
      auto& indicesInput = Input(INDICES);
      CAFFE_ENFORCE_EQ(1, indicesInput.dim(), "INDICES must be a vector");
      indices = indicesInput.template data<IndexType>();
      dataToReduceSize = indicesInput.dim(0);
    } else {
      dataToReduceSize = dataSize;
    }
    // only compute this the first time
    inclusive_scan_length_buffer_.ResizeLike(lengthsInput);
    inclusive_scan_wrapper(
        lengthsInput.template data<int>(),
        len_length,
        &inclusive_scan_buffer_,
        &inclusive_scan_length_buffer_,
        &context_);
    auto* prefix_sum_length_data =
        inclusive_scan_length_buffer_.template data<int>();
    int N = dataSize;
    int post = dataInput.size_from_dim(1);
    auto maxThreads =
        GetDeviceProperty(CaffeCudaGetDevice()).maxThreadsPerBlock;
    if (SparseFused) {
      const InType* in_data = dataInput.template data<InType>();
      if (post <= maxThreads) {
        int multiple = std::min(maxThreads / post, SEGREDUCE_MINBLOCKS);
        dim3 block(post, multiple);
        size_t smem = sizeof(T) * post * multiple;
        // calling cuda kernel with ExactBlock = true, Average = false
        sparse_length_sum_kernel<InType, T, IndexType, true, false>
            <<<len_length, block, smem, context_.cuda_stream()>>>(
                in_data,
                out_data,
                prefix_sum_length_data,
                indices,
                N,
                post,
                len_length,
                dataToReduceSize);
      } else {
        // calling cuda kernel with ExactBlock = false, Average = false
        sparse_length_sum_kernel<InType, T, IndexType, false, false>
            <<<len_length, maxThreads, 0, context_.cuda_stream()>>>(
                in_data,
                out_data,
                prefix_sum_length_data,
                indices,
                N,
                post,
                len_length,
                dataToReduceSize);
      }
    } else {
      const T* in_data = dataInput.template data<T>();
      if (post <= maxThreads) {
        length_sum_kernel<T, true, false>
            <<<len_length, post, 0, context_.cuda_stream()>>>(
                in_data, out_data, prefix_sum_length_data, N, post, len_length);
      } else {
        length_sum_kernel<T, true, false>
            <<<len_length, maxThreads, 0, context_.cuda_stream()>>>(
                in_data, out_data, prefix_sum_length_data, N, post, len_length);
      }
    }
    return true;
  }
  enum { DATA = 0, INDICES = 1, LENGTHS = 1 + (SparseFused ? 1 : 0) };
 private:
  // menber field to manage memory
  Tensor inclusive_scan_buffer_{CUDA};
  Tensor inclusive_scan_length_buffer_{CUDA};
};
Implementing it in a sufficiently generic way may be a lot of work, so just reimplementing the c2 functionality is ok for now. We might want to slightly future-proof it by allowing either a lengths or an indices argument, and also an axis argument, though for now it's ok to allow only axis=0 (see the sketch right after this paragraph for how axis could later be reduced to the axis=0 case).
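If an axis argument is added later, it could presumably be reduced to the axis=0 case by moving the reduced dimension to the front, along these lines (a sketch reusing the lengths_sum_reference helper above):

import torch

def segmented_sum_along_axis(data, lengths, axis=0):
    # Move the reduced axis to dim 0, do the axis=0 reduction, then move it back.
    moved = torch.movedim(data, axis, 0)
    out = lengths_sum_reference(moved, lengths)
    return torch.movedim(out, 0, axis)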
I propose the following API. I suggest making lengths a keyword-only argument so that if we later decide to support either indices or lengths, the change does not have to be breaking.
_segmented_sum(Tensor input, *, Tensor lengths, int axis=0, unsafe=False)
unsafe means that there are no checks that the lengths argument contains valid values, which may be desirable for performance reasons.
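As a rough illustration (the checks below are my assumptions about what "valid" means, not a spec), unsafe=False could imply validation along these lines before dispatching to the kernel:

import torch

def _check_lengths(input: torch.Tensor, lengths: torch.Tensor, axis: int = 0) -> None:
    # Validation that unsafe=False would perform; skipped entirely when unsafe=True.
    if lengths.dim() != 1:
        raise ValueError("lengths must be a 1D tensor")
    if bool((lengths < 0).any()):
        raise ValueError("lengths must contain non-negative values")
    if int(lengths.sum()) != input.size(axis):
        raise ValueError("sum(lengths) must equal input.size(axis)")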
This can be a first step towards implementing numpy's reduceat functionality: https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.ufunc.reduceat.html
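For comparison, numpy's add.reduceat computes the same segment sums once lengths are converted to start offsets (assuming all lengths are positive; reduceat treats empty segments differently):

import numpy as np

data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=np.float32)
lengths = np.array([1, 1, 2])

# Convert segment lengths to segment start offsets: [1, 1, 2] -> [0, 1, 2].
starts = np.concatenate(([0], np.cumsum(lengths)[:-1]))
print(np.add.reduceat(data, starts, axis=0))
# [[1. 1.]
#  [2. 2.]
#  [7. 7.]]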