Skip to content

Commit

Permalink
Exception formatting: handle case where module is None (facebookres…
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasha10 committed Oct 10, 2022
1 parent 06d8a66 commit 9ce6720
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 10 deletions.
9 changes: 4 additions & 5 deletions hydra/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import logging.config
import os
import sys
import traceback
import warnings
from dataclasses import dataclass
from os.path import dirname, join, normpath, realpath
from traceback import print_exc, print_exception
from types import FrameType, TracebackType
from typing import Any, List, Optional, Sequence, Tuple

Expand Down Expand Up @@ -248,7 +248,7 @@ def run_and_report(func: Any) -> Any:
if search_max == 0 or tb is None:
# could not detect run_job, probably a runtime exception before we got there.
# do not sanitize the stack trace.
print_exc()
traceback.print_exc()
sys.exit(1)

# strip OmegaConf frames from bottom of stack
Expand All @@ -257,8 +257,7 @@ def run_and_report(func: Any) -> Any:
while end is not None:
frame = end.tb_frame
mdl = inspect.getmodule(frame)
assert mdl is not None
name = mdl.__name__
name = mdl.__name__ if mdl is not None else ""
if name.startswith("omegaconf."):
break
end = end.tb_next
Expand Down Expand Up @@ -288,7 +287,7 @@ class FakeTracebackType:
assert iter_tb.tb_next is not None
iter_tb = iter_tb.tb_next

print_exception(None, value=ex, tb=final_tb) # type: ignore
traceback.print_exception(None, value=ex, tb=final_tb) # type: ignore
sys.stderr.write(
"\nSet the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.\n"
)
Expand Down
1 change: 1 addition & 0 deletions news/2342.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix an issue where Hydra's exception-handling logic could raise an `AssertionError`
43 changes: 38 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from pathlib import Path
from textwrap import dedent
from typing import Any, Optional
from typing import Any, NoReturn, Optional
from unittest.mock import patch

from omegaconf import DictConfig, OmegaConf
Expand Down Expand Up @@ -216,25 +216,58 @@ def test_failure(self, demo_func: Any, expected_traceback_regex: str) -> None:
stderr_output = mock_stderr.read()
assert_regex_match(expected_traceback_regex, stderr_output)

def test_simplified_traceback_with_no_module(self) -> None:
"""
Test that simplified traceback logic can succeed even if
`inspect.getmodule(frame)` returns `None` for one of
the frames in the stacktrace.
"""
demo_func = self.DemoFunctions.run_job_wrapper
expected_traceback_regex = dedent(
r"""
Traceback \(most recent call last\):$
File "[^"]+", line \d+, in nested_error$
assert False, "nested_err"$
AssertionError: nested_err$
assert False$
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace\.$
"""
)
mock_stderr = io.StringIO()
with raises(SystemExit, match="1"), patch("sys.stderr", new=mock_stderr):
# Patch `inspect.getmodule` so that it will return None. This simulates a
# situation where a python module cannot be identified from a traceback
# stack frame. This can occur when python extension modules or
# multithreading are involved.
with patch("inspect.getmodule", new=lambda *args: None):
run_and_report(demo_func)
mock_stderr.seek(0)
stderr_output = mock_stderr.read()
assert_regex_match(expected_traceback_regex, stderr_output)

def test_simplified_traceback_failure(self) -> None:
"""
Test that a warning is printed and the original exception is re-raised
when an exception occurs during the simplified traceback logic.
"""
demo_func = self.DemoFunctions.run_job_wrapper

def throws(*args: Any, **kwargs: Any) -> NoReturn:
assert False, "Error thrown"

expected_traceback_regex = dedent(
r"""
An error occurred during Hydra's exception formatting:$
AssertionError\(.*\)$
AssertionError\(.*Error thrown.*\)$
"""
)
mock_stderr = io.StringIO()
with raises(AssertionError, match="nested_err"), patch(
"sys.stderr", new=mock_stderr
):
# patch `inspect.getmodule` so that an exception will occur in the
# simplified traceback logic:
with patch("inspect.getmodule", new=lambda *args: None):
# patch `traceback.print_exception` so that an exception will occur
# in the simplified traceback logic:
with patch("traceback.print_exception", new=throws):
run_and_report(demo_func)
mock_stderr.seek(0)
stderr_output = mock_stderr.read()
Expand Down

0 comments on commit 9ce6720

Please sign in to comment.