Update on "Improve make_tensor performance for float and complex types"
For floating-point types, `make_tensor` calls `rand` and then linearly
interpolates the result from `low` to `high`. This PR instead calls
`uniform_(low, high)` directly, cutting out the separate interpolation step.
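
A minimal sketch of the before/after for floating-point dtypes (the size and `low`/`high` values here are illustrative, not the actual `make_tensor` internals):

```python
import torch

low, high = -9.0, 9.0

# Before: sample in [0, 1) and linearly interpolate into [low, high).
t = torch.rand(4096, dtype=torch.float32)
t = low + t * (high - low)

# After: fill directly in [low, high) with a single in-place kernel.
t = torch.empty(4096, dtype=torch.float32).uniform_(low, high)
```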

For complex types, `make_tensor` does the `rand` + interpolation step
twice and calls `torch.complex(real, imag)` at the end. This PR instead
uses `view_as_real` and `uniform_(low, high)` to fuse everything into a
single in-place fill.
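
A minimal sketch of the same change for complex dtypes (again with illustrative size and bounds):

```python
import torch

low, high = -9.0, 9.0

# Before: two rand + interpolation passes, combined with torch.complex.
real = low + torch.rand(4096) * (high - low)
imag = low + torch.rand(4096) * (high - low)
c = torch.complex(real, imag)

# After: allocate the complex tensor once, then fill its interleaved
# real/imaginary storage in one uniform_ call via view_as_real.
c = torch.empty(4096, dtype=torch.complex64)
torch.view_as_real(c).uniform_(low, high)
```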

My benchmarks show significant speedups in all cases for float32 and
complex64.

| Device | dtype     | Size (elements) | Master (µs) | This PR (µs) | Speedup |
|--------|-----------|-----------------|-------------|--------------|---------|
| CPU    | float32   | 8               | 19.4        | 6.34         | 3.1     |
|        |           | 4096            | 36.8        | 21.3         | 1.7     |
|        |           | 2**24           | 167,000     | 80,500       | 2.1     |
|        | complex64 | 8               | 37.0        | 7.57         | 4.9     |
|        |           | 4096            | 73.1        | 37.6         | 1.9     |
|        |           | 2**24           | 409,000     | 161,000      | 2.5     |
| CUDA   | float32   | 8               | 40.4        | 11.7         | 3.5     |
|        |           | 4096            | 38.7        | 11.7         | 3.3     |
|        |           | 2**24           | 2,300       | 238          | 9.7     |
|        | complex64 | 8               | 78.7        | 14.0         | 5.6     |
|        |           | 4096            | 82.7        | 13.8         | 6.0     |
|        |           | 2**24           | 5,520       | 489          | 11.3    |
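
The benchmark script itself is not part of this commit; this is a sketch of how comparable numbers could be collected with `torch.utils.benchmark` (the sizes and `low`/`high` bounds are assumptions, not the author's script):

```python
import torch
from torch.testing import make_tensor
from torch.utils import benchmark

devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",)
for device in devices:
    for dtype in (torch.float32, torch.complex64):
        for numel in (8, 4096, 2**24):
            timer = benchmark.Timer(
                stmt="make_tensor(numel, dtype=dtype, device=device, low=-9, high=9)",
                globals={"make_tensor": make_tensor, "numel": numel,
                         "dtype": dtype, "device": device},
            )
            # blocked_autorange handles warmup and CUDA synchronization.
            print(device, dtype, numel, timer.blocked_autorange())
```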

[ghstack-poisoned]
peterbell10 committed Sep 27, 2022 · 2 parents 4565183 + 6f1b49e · commit 83aac7b
Showing 1 changed file with 1 addition and 1 deletion.

torch/testing/_internal/common_methods_invocations.py

```diff
@@ -11845,7 +11845,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     supports_out=False,
     sample_kwargs=lambda device, dtype, input: ({'threshold': float.fromhex('0x1.35p-3'),
                                                  'value': -9},
-                                                {'threshold': float.fromhex('0x1.2ap-5'),
+                                                {'threshold': float.fromhex('0x1.35p-3'),
                                                  'value': -9}),
     # TODO(whc) should not need sample_inputs_func, but without it
     # kwargs aren't being hooked up properly
```
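
For reference, the two hex-float literals in this hunk decode as follows (plain Python, shown only to make the size of the threshold change visible):

```python
>>> float.fromhex('0x1.2ap-5')   # old second-sample threshold
0.036376953125
>>> float.fromhex('0x1.35p-3')   # new threshold, now matching the first sample
0.15087890625
```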
