a global configuration to throw error on NaN #1274

Closed

soumith opened this issue Apr 17, 2017 · 10 comments

Labels: todo (Not as important as medium or high priority tasks, but we will work on these.)

Comments

soumith (Member) commented Apr 17, 2017

NumPy has a global setting (numpy.seterr) where, instead of silently returning NaN, it can raise an error.
https://docs.scipy.org/doc/numpy/reference/generated/numpy.seterr.html

We should have something like this in pytorch (maybe).
This task is to scope out that work, or reject the feature request.

soumith added the enhancement and todo labels Apr 17, 2017
kdexd commented Apr 20, 2017

This might be helpful, following the example from their docs:

>>> import numpy as np
>>> np.int16(32000) * np.int16(3)
30464
>>> old_settings = np.seterr(all='warn', over='raise')
>>> np.int16(32000) * np.int16(3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
FloatingPointError: overflow encountered in short_scalars

Such error raising might be helpful in PyTorch when a user is experimenting with quantization / deep compression. That is one use case I can think of.
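
For NaN specifically, the relevant NumPy switch appears to be invalid='raise', which turns operations whose result would be NaN into exceptions. A small sketch (not from the docs):

import numpy as np

np.seterr(invalid='raise')            # 'invalid' covers results that would be NaN
np.float64(0.0) / np.float64(0.0)     # raises FloatingPointError instead of returning nan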

samuela (Contributor) commented Mar 30, 2018

please please please

samuela (Contributor) commented Mar 30, 2018

FWIW I've implemented a wrapped version of torch that does its best to emulate this sort of behavior: https://github.com/samuela/kindling/blob/master/kindling/nan_police.py.

In [1]: from kindling.nan_police import torch

In [2]: x = torch.ones(2) / 0 * 0

In [3]: x
Out[3]:

nan
nan
[torch.FloatTensor of size 2]

In [4]: torch.sum(x)
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-4-e7f45fec8fb4> in <module>()
----> 1 torch.sum(x)

~/Development/kindling/kindling/nan_police.py in __call__(self, *args, **kwargs)
    147         if argnan_path == []:
    148           raise Exception(
--> 149             f'Found a NaN at positional argument {i + 1} (of {len(args)}) when '
    150             f'calling `{path}`!'
    151           )

Exception: Found a NaN at positional argument 1 (of 1) when calling `torch.sum`!

I've added it to my mini-toolkit of pytorch utilities, kindling.
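
Roughly, the idea is a proxy object that forwards attribute lookups to torch and wraps callables so tensor arguments get checked with torch.isnan before the real call runs. A stripped-down sketch of that pattern (not the actual kindling code; positional args only, no nested containers):

import torch as _torch

def _has_nan(value):
    # Only floating-point tensors can hold NaN.
    return (_torch.is_tensor(value)
            and value.is_floating_point()
            and bool(_torch.isnan(value).any()))

class _NanPolice:
    # Forwards attribute access to torch, wrapping callables with a NaN check.
    def __getattr__(self, name):
        attr = getattr(_torch, name)
        if not callable(attr):
            return attr
        def checked(*args, **kwargs):
            for i, arg in enumerate(args):
                if _has_nan(arg):
                    raise RuntimeError(
                        f'Found a NaN at positional argument {i + 1} '
                        f'(of {len(args)}) when calling torch.{name}!')
            return attr(*args, **kwargs)
        return checked

torch = _NanPolice()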

ElleryL commented Jun 12, 2018

please add this feature

yaysummeriscoming commented:

please, this would be great!

suruoxi commented Mar 15, 2019

Great feature for debugging. Any updates?

soumith (Member, Author) commented Mar 27, 2019

Anomaly detection provides this feature: https://pytorch.org/docs/stable/autograd.html#torch.autograd.detect_anomaly

However, it's not cheap.

@samuela the simulated version you have is not cheap either, right?

Shall we close this feature request with a pointer to detect_anomaly?

samuela (Contributor) commented Mar 27, 2019

@soumith No, I don't imagine that it's particularly fast, but I haven't benchmarked it either. I use it when debugging these sorts of errors and then switch back to the regular torch after figuring things out.

soumith (Member, Author) commented Mar 28, 2019

In that case, I'm closing the issue.

The globally configurable switch that detects anomalies, including NaN, is:

torch.autograd.set_detect_anomaly(True)

It is going to be really slow, but there isn't anything else we are planning that would be faster.

This can also be used as a context manager to limit detection to a set of statements:

with torch.autograd.detect_anomaly():
    inp = torch.rand(10, 10, requires_grad=True)
    out = run_fn(inp)
    out.backward()

Full documentation at: https://pytorch.org/docs/stable/autograd.html#torch.autograd.detect_anomaly
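
As an illustration, here is a minimal case (assuming a recent PyTorch) where the forward pass is NaN-free but the backward pass produces a NaN, so anomaly mode raises during backward():

import torch

torch.autograd.set_detect_anomaly(True)

x = torch.zeros(1, requires_grad=True)
out = torch.sqrt(x) * 0   # forward is fine: sqrt(0) * 0 == 0
out.sum().backward()      # backward hits 0 * inf = NaN; anomaly mode raises a RuntimeError here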

soumith closed this as completed Mar 28, 2019
R-N commented Nov 4, 2023

> In that case, I'm closing the issue.
>
> The globally configurable switch that detects anomalies, including NaN, is:
>
> torch.autograd.set_detect_anomaly(True)

It doesn't seem to raise an error on NaN in the forward pass, or does it? The docs say:

"any backward computation that generate “nan” value will raise an error"
