Skip to content

Commit

Permalink
1.2.1: 4x faster subm indice generation
Browse files Browse the repository at this point in the history
  • Loading branch information
traveller59 committed Jun 4, 2020
1 parent 492865a commit 11bcbbf
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 29 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
# Changelog

## [1.2.1] - 2020-06-04
### Changed
- The subm indice pair generation speed is greatly increased by two tricks: (1) most subm convolutions use kernel size 3, so the loops can be fully unrolled for a ~100% performance increase; (2) subm indice pairs satisfy the symmetry indicePairs[0, i] = indicePairs[1, kernelVolume - i - 1], which yields another ~100% performance increase.


## [1.2.0] - 2020-05-28
### Added
- add batch gemm support. Small performance increase but more GPU memory usage. You can use algo=spconv.ConvAlgo.Batch to enable it.
Expand Down
81 changes: 56 additions & 25 deletions include/spconv/indice.cu.h
Expand Up @@ -274,10 +274,13 @@ __global__ void getSubMIndicePairsKernel3(
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum,
const tv::SimpleVector<Index, 3> outSpatialShape, Index spatialVolume) {
auto numActIn = indicesIn.dim(0);

Index point[3];
Index index = 0;
Index offset;

constexpr unsigned KV = K0 * K1 * K2;
constexpr unsigned center = KV / 2;
*(indiceNum.data() + center) = numActIn;
for (int ix : tv::KernelLoopX<int>(numActIn)) {
const Index *indice_data = indicesIn.data() + ix * (3 + 1);
#pragma unroll
Expand All @@ -287,19 +290,32 @@ __global__ void getSubMIndicePairsKernel3(
#pragma unroll
for (int k = 0; k < K2; ++k) {
offset = i * K1 * K2 + j * K2 + k;
point[2] = indice_data[3] - k + K2 / 2;
point[1] = indice_data[2] - j + K1 / 2;
point[0] = indice_data[1] - i + K0 / 2;
if (point[1] >= 0 && point[1] < outSpatialShape[1] && point[2] >= 0 &&
point[2] < outSpatialShape[2] && point[0] >= 0 &&
point[0] < outSpatialShape[0]) {
index = tv::ArrayIndexRowMajor<3, 3>::runPtrs(
point, outSpatialShape.data(), 0) +
spatialVolume * indice_data[0];
if (gridsOut[index] != -1) {
Index oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(1, offset, oldNum) = gridsOut[index];
indicePairs(0, offset, oldNum) = ix;
if (offset > center){
continue;
}
if (center == offset){
// center of subm indice pairs dont need atomicadd
indicePairs(1, offset, ix) = ix;
indicePairs(0, offset, ix) = ix;
}else{
point[2] = indice_data[3] - k + K2 / 2;
point[1] = indice_data[2] - j + K1 / 2;
point[0] = indice_data[1] - i + K0 / 2;
if (point[1] >= 0 && point[1] < outSpatialShape[1] && point[2] >= 0 &&
point[2] < outSpatialShape[2] && point[0] >= 0 &&
point[0] < outSpatialShape[0]) {
index = tv::ArrayIndexRowMajor<3, 3>::runPtrs(
point, outSpatialShape.data(), 0) +
spatialVolume * indice_data[0];
if (gridsOut[index] != -1) {
// for subm: indicePairs[0, i] = indicePairs[1, kernelVolume - i - 1]
Index oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
atomicAdd(indiceNum.data() + KV - offset - 1, Index(1));
indicePairs(1, offset, oldNum) = gridsOut[index];
indicePairs(0, offset, oldNum) = ix;
indicePairs(1, KV - offset - 1, oldNum) = ix;
indicePairs(0, KV - offset - 1, oldNum) = gridsOut[index];
}
}
}
}
Expand All @@ -317,6 +333,9 @@ __global__ void getSubMIndicePairsKernel2(
Index point[2];
Index index = 0;
Index offset;
constexpr unsigned KV = K0 * K1;
constexpr unsigned center = KV / 2;
*(indiceNum.data() + center) = numActIn;

for (int ix : tv::KernelLoopX<int>(numActIn)) {
const Index *indice_data = indicesIn.data() + ix * (2 + 1);
Expand All @@ -325,17 +344,29 @@ __global__ void getSubMIndicePairsKernel2(
#pragma unroll
for (int j = 0; j < K1; ++j) {
offset = i * K1 + j;
point[1] = indice_data[2] - j + K1 / 2;
point[0] = indice_data[1] - i + K0 / 2;
if (point[1] >= 0 && point[1] < outSpatialShape[1] && point[0] >= 0 &&
point[0] < outSpatialShape[0]) {
index = tv::ArrayIndexRowMajor<2, 2>::runPtrs(
point, outSpatialShape.data(), 0) +
spatialVolume * indice_data[0];
if (gridsOut[index] > -1) {
Index oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(1, offset, oldNum) = gridsOut[index];
indicePairs(0, offset, oldNum) = ix;
if (offset > center){
continue;
}
if (center == offset){
// center of subm indice pairs dont need atomicadd
indicePairs(1, offset, ix) = ix;
indicePairs(0, offset, ix) = ix;
}else{
point[1] = indice_data[2] - j + K1 / 2;
point[0] = indice_data[1] - i + K0 / 2;
if (point[1] >= 0 && point[1] < outSpatialShape[1] && point[0] >= 0 &&
point[0] < outSpatialShape[0]) {
index = tv::ArrayIndexRowMajor<2, 2>::runPtrs(
point, outSpatialShape.data(), 0) +
spatialVolume * indice_data[0];
if (gridsOut[index] > -1) {
Index oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
atomicAdd(indiceNum.data() + KV - offset - 1, Index(1));
indicePairs(1, offset, oldNum) = gridsOut[index];
indicePairs(0, offset, oldNum) = ix;
indicePairs(1, KV - offset - 1, oldNum) = ix;
indicePairs(0, KV - offset - 1, oldNum) = gridsOut[index];
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -96,7 +96,7 @@ def build_extension(self, ext):
packages = find_packages(exclude=('tools', 'tools.*'))
setup(
name='spconv',
version='1.2',
version='1.2.1',
author='Yan Yan',
author_email='scrin@foxmail.com',
description='spatial sparse convolution for pytorch',
Expand Down
6 changes: 3 additions & 3 deletions test/test_conv.py
Expand Up @@ -752,8 +752,8 @@ def main_subm(algo, dtype=torch.float32):


if __name__ == '__main__':
# main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.float32)
# main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.half)
main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.float32)
main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.half)
# TestCase().assertAllClose(out_my, out_ref)
# unittest.main()
TestSpConv().testSpConv3d()
# TestSpConv().testSpConv3d()

0 comments on commit 11bcbbf

Please sign in to comment.