Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve make_tensor performance for float and complex types
For floating types, `make_tensor` calls `rand` and then does a linear interpolation (aka lerp) from `low` to `high`. This makes the lerp step faster by: - using inplace operations - using `add`'s `alpha` parameter to avoid an extra kernel - adding shortcuts for special values of `low` and `high` For complex types, `make_tensor` does the `rand` + interpolation step twice and calls `torch.complex(real, imag)` at the end. This reduce overhead by doing a single `rand` + interpolate of double the size, then calling `torch.view_as_complex` at the end. My benchmarks show speedups in all cases for float32 and complex64. | Device | dtype | Size | low | high | Master (us) | This PR (us) | Speedup | |--------|-----------|-------|-----|------|-------------|--------------|---------| | CPU | float32 | 8 | | | 19.4 | 15.1 | 1.3 | | | | | 0 | | 19.7 | 9.21 | 2.1 | | | | | 0 | 1 | 19.7 | 5.94 | 3.3 | | | | 4096 | | | 36.8 | 31.3 | 1.2 | | | | | 0 | | 37.1 | 24.7 | 1.5 | | | | | 0 | 1 | 36.9 | 21.0 | 1.8 | | | | 2**24 | | | 167,000 | 115,000 | 1.5 | | | | | 0 | | 179,000 | 85,200 | 2.1 | | | | | 0 | 1 | 180,000 | 80,800 | 2.2 | | | complex32 | 8 | | | 37.0 | 17.6 | 2.1 | | | | | 0 | | 37.4 | 11.3 | 3.3 | | | | | 0 | 1 | 37.5 | 7.66 | 4.9 | | | | 4096 | | | 73.1 | 49.9 | 1.5 | | | | | 0 | | 73.5 | 41.5 | 1.8 | | | | | 0 | 1 | 73.6 | 37.6 | 2.0 | | | | 2**24 | | | 409,000 | 229,000 | 1.8 | | | | | 0 | | 411,000 | 170,000 | 2.4 | | | | | 0 | 1 | 409,000 | 163,000 | 2.5 | | CUDA | float32 | 8 | | | 40.4 | 30.9 | 1.3 | | | | | 0 | | 39.2 | 17.6 | 2.2 | | | | | 0 | 1 | 39.2 | 11.1 | 3.5 | | | | 4096 | | | 38.7 | 32.2 | 1.2 | | | | | 0 | | 39.2 | 18.0 | 2.2 | | | | | 0 | 1 | 39.3 | 11.1 | 3.5 | | | | 2**24 | | | 2,300 | 1,840 | 1.3 | | | | | 0 | | 2,300 | 704 | 3.3 | | | | | 0 | 1 | 2,300 | 242 | 9.5 | | | complex32 | 8 | | | 78.7 | 34.7 | 2.3 | | | | | 0 | | 80.8 | 20.5 | 3.9 | | | | | 0 | 1 | 83.5 | 13.8 | 6.0 | | | | 4096 | | | 82.7 | 34.8 | 2.4 | | | | | 0 | | 83.9 | 20.5 | 4.1 | | | | | 0 | 1 | 81.5 | 13.9 | 5.9 | | | | 2**24 | | | 5,520 | 3,670 | 1.5 | | | | | 0 | | 5,520 | 1,400 | 3.9 | | | | | 0 | 1 | 5,520 | 484 | 11.4 | ghstack-source-id: 9c6a6e04f29e49b68718bc60ef3f3f6417415365 Pull Request resolved: pytorch#85473
- Loading branch information