Update on "Extend SampleInput str representation with tensor data."
As in the title. The aim of this addition is to make it easier to debug certain CI failures that cannot be reproduced locally. For instance, we currently see messages like
```
Exception: Caused by sample input at index 0: SampleInput(input=Tensor[size=(20,), device="cuda:0", dtype=torch.float64], args=(), kwargs={}, broadcasts_input=False, name='')
```
which is not really useful without the actual sample data, as all those sample parameters can often be determined by other means. With the data included, a sample can be related to the `index` part of error messages such as:
```
Mismatched elements: 2 / 20 (10.0%)
Greatest absolute difference: nan at index (10,) (up to 1e-05 allowed)
Greatest relative difference: nan at index (10,) (up to 1e-07 allowed)
```
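
For illustration, here is a minimal sketch of how such a mismatch index maps onto the flattened `data=[...]` list that the extended repr prints (the `flat_offset` helper is hypothetical, not part of this PR, and assumes a C-contiguous tensor whose data is printed in full):
```python
# Minimal sketch, not PR code: map a mismatch index reported by
# torch.testing.assert_close back to the flattened `data=[...]` list
# printed by the extended SampleInput repr.
def flat_offset(shape, index):
    # Row-major (C-contiguous) offset, matching the flattening order.
    offset = 0
    for dim, i in zip(shape, index):
        offset = offset * dim + i
    return offset

# For the messages above: size (20,) and mismatch index (10,) point at
# element 10 of the printed data list.
print(flat_offset((20,), (10,)))  # -> 10
```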

As an example of the usefulness of this PR, consider the following failure message:
```
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_polygamma_polygamma_n_0_cpu_int32 ('RERUN', {'yellow': True}) [1.5510s] [ 70%]
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_polygamma_polygamma_n_0_cpu_int32 ('RERUN', {'yellow': True}) [0.0473s] [ 70%]
inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_polygamma_polygamma_n_0_cpu_int32 FAILED [0.0493s] [ 70%]

==================================== RERUNS ====================================
__ TestInductorOpInfoCPU.test_comprehensive_polygamma_polygamma_n_0_cpu_int32 __
Traceback (most recent call last):
<snip>
AssertionError: Tensor-likes are not close!

Mismatched elements: 9 / 25 (36.0%)
Greatest absolute difference: inf at index (0, 0) (up to 1e-05 allowed), inf vs 20177651499008.0
Greatest relative difference: inf at index (0, 0) (up to 1.3e-06 allowed)

The above exception was the direct cause of the following exception:

<snip>
Exception: Caused by sample input at index 0: SampleInput(input=Tensor[size=(5, 5), device="cpu", dtype=torch.int32, data=[-8, 6, 9, 0, 0, 5, 5, 7, 6, 5, 1, -5, 2, -1, 8, -4, 0, -6, 3, -5]], args=(1), kwargs={}, broadcasts_input=False, name='')
```
from which we learn that the `torch.polygamma` result is actually correct, because `polygamma(0, -8) -> inf`, while the reference value used (20177651499008.0) is wrong (see #106692 for more details).
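
The conclusion is easy to probe locally; a minimal sketch (the expected value follows from the pole of digamma discussed above, and the exact output depends on whether #106692 has been fixed in your build):
```python
import torch

# digamma (= polygamma of order 0) has poles at the non-positive integers,
# so the mathematically correct value at -8 is inf; per #106692, the eager
# CPU reference returned a large finite value here at the time of this commit.
x = torch.tensor([-8], dtype=torch.int32)
print(torch.polygamma(0, x))  # inf once #106692 is resolved
```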





[ghstack-poisoned]
pearu committed Feb 11, 2024
2 parents 4d7f5d0 + a275012 commit 10025db
Showing 10 changed files with 2,000 additions and 2,022 deletions.
aten/src/ATen/native/mps/operations/Bucketization.mm (2 changes: 1 addition & 1 deletion)

```diff
@@ -252,7 +252,7 @@ static void searchsorted_mps_contiguous(Tensor& result,
 
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   MPSStream* mpsStream = getCurrentMPSStream();
-  dispatch_sync(mpsStream->queue(), ^() {
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
     @autoreleasepool {
       id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
```
torch/__init__.py (12 changes: 6 additions & 6 deletions)

```diff
@@ -1015,24 +1015,24 @@ def set_float32_matmul_precision(precision: str) -> None:
     Supports three settings:
 
     * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
-      bits) for internal computations.
+      bits with 23 bits explicitly stored) for internal computations.
     * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
-      mantissa bits) or treat each float32 number as the sum of two bfloat16 numbers
-      (approximately 16 mantissa bits), if the appropriate fast matrix multiplication
+      mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers
+      (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication
       algorithms are available. Otherwise float32 matrix multiplications are computed
       as if the precision is "highest". See below for more information on the bfloat16
       approach.
     * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
-      bits) for internal computations, if a fast matrix multiplication algorithm
+      bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm
       using that datatype internally is available. Otherwise float32
       matrix multiplications are computed as if the precision is "high".
 
     When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
     that is more complicated than simply truncating to some smaller number mantissa bits
-    (e.g. 10 for TensorFloat32, 8 for bfloat16). Refer to [Henry2019]_ for a complete
+    (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete
     description of this algorithm. To briefly explain here, the first step is to realize
     that we can perfectly encode a single float32 number as the sum of three bfloat16
-    numbers (because float32 has 24 mantissa bits while bfloat16 has 8, and both have the
+    numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
     same number of exponent bits). This means that the product of two float32 numbers can
     be exactly given by the sum of nine products of bfloat16 numbers. We can then trade
     accuracy for speed by dropping some of these products. The "high" precision algorithm
```
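
Since the corrected docstring hinges on the three-way bfloat16 split, here is a quick numerical check of that claim (a sketch for values in the normal float32 range; not library code):
```python
import torch

# Split a float32 value into three bfloat16 pieces and reconstruct it.
# float32 stores 23 mantissa bits explicitly (24 effective); each bfloat16
# piece contributes roughly 8 effective bits, so three pieces suffice.
x = torch.tensor(0.7316423, dtype=torch.float32)
b0 = x.to(torch.bfloat16)
r1 = x - b0.to(torch.float32)   # residual is exact in float32
b1 = r1.to(torch.bfloat16)
r2 = r1 - b1.to(torch.float32)  # again exact in float32
b2 = r2.to(torch.bfloat16)
rec = b0.to(torch.float32) + b1.to(torch.float32) + b2.to(torch.float32)
print(torch.equal(x, rec))  # True for typical normal-range values
```
For context, this is the docstring of `torch.set_float32_matmul_precision`, so the bfloat16-based behavior described above is selected with `torch.set_float32_matmul_precision("high")`.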
torch/_inductor/codegen/cpp.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -3364,7 +3364,7 @@ def group_fn(self, sizes):
         return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
 
     def get_kernel_group(self):
-        from .cpp_wrapper_cpu import CppWrapperCodeGen
+        from .wrapper import CppWrapperCodeGen
 
         self.kernel_group: Union[CppWrapperKernelGroup, KernelGroup]
         if isinstance(V.graph.wrapper_code, CppWrapperCodeGen):
```
