diff --git a/CHANGELOG.md b/CHANGELOG.md
index 485a39b..ff4fa57 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,10 @@
 # Changelog
 
+## [1.2.1] - 2020-06-04
+### Changed
+- Subm indice pair generation is greatly accelerated by two tricks: 1. most subm convolutions use kernel size 3, so the offset loops can be fully unrolled for a ~100% performance increase; 2. subm indice pairs satisfy indicePairs[0, i] = indicePairs[1, kernelVolume - i - 1], so only the first half of the kernel offsets needs to be generated, for another ~100% performance increase.
+
+
 ## [1.2.0] - 2020-05-28
 ### Added
 - add batch gemm support. small performance increasement but more gpu memory usage. you can use algo=spconv.ConvAlgo.Batch to use it.
diff --git a/include/spconv/indice.cu.h b/include/spconv/indice.cu.h
index cb2cd44..828ed7c 100644
--- a/include/spconv/indice.cu.h
+++ b/include/spconv/indice.cu.h
@@ -274,10 +274,13 @@ __global__ void getSubMIndicePairsKernel3(
     tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum,
     const tv::SimpleVector<Index, 3> outSpatialShape, Index spatialVolume) {
   auto numActIn = indicesIn.dim(0);
+
   Index point[3];
   Index index = 0;
   Index offset;
-
+  constexpr unsigned KV = K0 * K1 * K2;
+  constexpr unsigned center = KV / 2;
+  *(indiceNum.data() + center) = numActIn;
   for (int ix : tv::KernelLoopX<int>(numActIn)) {
     const Index *indice_data = indicesIn.data() + ix * (3 + 1);
 #pragma unroll
@@ -287,19 +290,32 @@ __global__ void getSubMIndicePairsKernel3(
 #pragma unroll
       for (int k = 0; k < K2; ++k) {
         offset = i * K1 * K2 + j * K2 + k;
-        point[2] = indice_data[3] - k + K2 / 2;
-        point[1] = indice_data[2] - j + K1 / 2;
-        point[0] = indice_data[1] - i + K0 / 2;
-        if (point[1] >= 0 && point[1] < outSpatialShape[1] && point[2] >= 0 &&
-            point[2] < outSpatialShape[2] && point[0] >= 0 &&
-            point[0] < outSpatialShape[0]) {
-          index = tv::ArrayIndexRowMajor<3, 3>::runPtrs(
-                      point, outSpatialShape.data(), 0) +
-                  spatialVolume * indice_data[0];
-          if (gridsOut[index] != -1) {
-            Index oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
-            indicePairs(1, offset, oldNum) = gridsOut[index];
-            indicePairs(0, offset, oldNum) = ix;
+        if (offset > center){
+          continue;
+        }
+        if (center == offset){
+          // the center of subm indice pairs doesn't need atomicAdd
+          indicePairs(1, offset, ix) = ix;
+          indicePairs(0, offset, ix) = ix;
+        }else{
+          point[2] = indice_data[3] - k + K2 / 2;
+          point[1] = indice_data[2] - j + K1 / 2;
+          point[0] = indice_data[1] - i + K0 / 2;
+          if (point[1] >= 0 && point[1] < outSpatialShape[1] && point[2] >= 0 &&
+              point[2] < outSpatialShape[2] && point[0] >= 0 &&
+              point[0] < outSpatialShape[0]) {
+            index = tv::ArrayIndexRowMajor<3, 3>::runPtrs(
+                        point, outSpatialShape.data(), 0) +
+                    spatialVolume * indice_data[0];
+            if (gridsOut[index] != -1) {
+              // for subm: indicePairs[0, i] = indicePairs[1, kernelVolume - i - 1]
+              Index oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
+              atomicAdd(indiceNum.data() + KV - offset - 1, Index(1));
+              indicePairs(1, offset, oldNum) = gridsOut[index];
+              indicePairs(0, offset, oldNum) = ix;
+              indicePairs(1, KV - offset - 1, oldNum) = ix;
+              indicePairs(0, KV - offset - 1, oldNum) = gridsOut[index];
+            }
           }
         }
       }
@@ -317,6 +333,9 @@ __global__ void getSubMIndicePairsKernel2(
   Index point[2];
   Index index = 0;
   Index offset;
+  constexpr unsigned KV = K0 * K1;
+  constexpr unsigned center = KV / 2;
+  *(indiceNum.data() + center) = numActIn;
   for (int ix : tv::KernelLoopX<int>(numActIn)) {
     const Index *indice_data = indicesIn.data() + ix * (2 + 1);
 
@@ -325,17 +344,29 @@ __global__ void getSubMIndicePairsKernel2(
 #pragma unroll
     for (int i = 0; i < K0; ++i) {
 #pragma unroll
       for (int j = 0; j < K1; ++j) {
         offset = i * K1 + j;
-        point[1] = indice_data[2] - j + K1 / 2;
-        point[0] = indice_data[1] - i + K0 / 2;
-        if (point[1] >= 0 && point[1] < outSpatialShape[1] && point[0] >= 0 &&
-            point[0] < outSpatialShape[0]) {
-          index = tv::ArrayIndexRowMajor<2, 2>::runPtrs(
-                      point, outSpatialShape.data(), 0) +
-                  spatialVolume * indice_data[0];
-          if (gridsOut[index] > -1) {
-            Index oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
-            indicePairs(1, offset, oldNum) = gridsOut[index];
-            indicePairs(0, offset, oldNum) = ix;
+        if (offset > center){
+          continue;
+        }
+        if (center == offset){
+          // the center of subm indice pairs doesn't need atomicAdd
+          indicePairs(1, offset, ix) = ix;
+          indicePairs(0, offset, ix) = ix;
+        }else{
+          point[1] = indice_data[2] - j + K1 / 2;
+          point[0] = indice_data[1] - i + K0 / 2;
+          if (point[1] >= 0 && point[1] < outSpatialShape[1] && point[0] >= 0 &&
+              point[0] < outSpatialShape[0]) {
+            index = tv::ArrayIndexRowMajor<2, 2>::runPtrs(
+                        point, outSpatialShape.data(), 0) +
+                    spatialVolume * indice_data[0];
+            if (gridsOut[index] > -1) {
+              Index oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
+              atomicAdd(indiceNum.data() + KV - offset - 1, Index(1));
+              indicePairs(1, offset, oldNum) = gridsOut[index];
+              indicePairs(0, offset, oldNum) = ix;
+              indicePairs(1, KV - offset - 1, oldNum) = ix;
+              indicePairs(0, KV - offset - 1, oldNum) = gridsOut[index];
+            }
           }
         }
       }
diff --git a/setup.py b/setup.py
index eb328dd..cf5e554 100644
--- a/setup.py
+++ b/setup.py
@@ -96,7 +96,7 @@ def build_extension(self, ext):
 packages = find_packages(exclude=('tools', 'tools.*'))
 setup(
     name='spconv',
-    version='1.2',
+    version='1.2.1',
     author='Yan Yan',
     author_email='scrin@foxmail.com',
     description='spatial sparse convolution for pytorch',
diff --git a/test/test_conv.py b/test/test_conv.py
index b9a0810..ababc87 100644
--- a/test/test_conv.py
+++ b/test/test_conv.py
@@ -752,8 +752,8 @@ def main_subm(algo, dtype=torch.float32):
 
 
 if __name__ == '__main__':
-    # main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.float32)
-    # main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.half)
+    main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.float32)
+    main_subm(algo=spconv.ConvAlgo.Native, dtype=torch.half)
     # TestCase().assertAllClose(out_my, out_ref)
     # unittest.main()
-    TestSpConv().testSpConv3d()
+    # TestSpConv().testSpConv3d()
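
Note (not part of the patch above): the two invariants the rewritten kernels rely on are easy to check outside CUDA. The Python sketch below is a minimal brute-force model of subm indice pair generation under the same geometry (odd kernel size, output sites identical to input sites); build_subm_pairs and every other name in it are illustrative stand-ins, not spconv's API. It asserts both that offset k's pair list equals offset kernelVolume - 1 - k's list with input and output swapped, and that the center offset pairs each active site with itself, which is why the kernels can write indiceNum[center] = numActIn up front and skip atomics for the center.

# Illustrative sketch, not spconv code: verify the subm indice pair symmetry.
import itertools

def build_subm_pairs(active, kernel_size, spatial_shape):
    # Brute-force subm indice pairs: pairs[offset] holds (input_idx, output_idx)
    # tuples, i.e. what indicePairs(0, offset, n) / indicePairs(1, offset, n) store.
    grid = {coord: idx for idx, coord in enumerate(active)}  # plays the role of gridsOut
    kv = 1
    for k in kernel_size:
        kv *= k
    pairs = [[] for _ in range(kv)]
    for in_idx, coord in enumerate(active):
        # row-major enumeration of kernel offsets, matching offset = i * K1 + j
        for offset, rel in enumerate(itertools.product(*map(range, kernel_size))):
            # candidate output location, mirroring point[d] = coord[d] - rel[d] + k / 2
            out = tuple(c - r + k // 2 for c, r, k in zip(coord, rel, kernel_size))
            if all(0 <= o < s for o, s in zip(out, spatial_shape)) and out in grid:
                pairs[offset].append((in_idx, grid[out]))
    return pairs

active = [(1, 1), (1, 2), (2, 2), (3, 1)]   # active 2D sites, batch index omitted
pairs = build_subm_pairs(active, (3, 3), (5, 5))
kv, center = 9, 4

# Trick 2: offset k's pairs are offset (kv - 1 - k)'s pairs with in/out swapped,
# so a kernel only needs to generate offsets <= center and mirror the rest.
for k in range(kv):
    assert sorted(pairs[k]) == sorted((o, i) for i, o in pairs[kv - 1 - k])

# Center shortcut: every active site pairs with itself, so a kernel can set
# indiceNum[center] = numActIn and write slot ix directly, without atomicAdd.
assert pairs[center] == [(i, i) for i in range(len(active))]
print("verified", sum(map(len, pairs)), "pairs")

Under these assumptions, the mirrored writes in the kernels above (slot KV - offset - 1 alongside slot offset) are just this swap done eagerly, which is where the second ~100% speedup claimed in the changelog comes from.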