Skip to content

Commit

Permalink
improved coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
de-code committed May 29, 2021
1 parent ba3ce3b commit e880d2b
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions tests/tests_contrib_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import logging
import logging.handlers
import sys
from contextlib import contextmanager
from io import StringIO

try:
from typing import Optional # pylint: disable=unused-import
from typing import Iterator, List, Optional # pylint: disable=unused-import
except ImportError:
pass

Expand All @@ -29,7 +30,7 @@


class CustomTqdm(tqdm):
messages = []
messages = [] # type: List[str]

@classmethod
def write(cls, s, **__): # pylint: disable=arguments-differ
Expand Down Expand Up @@ -188,6 +189,21 @@ def test_use_root_logger_by_default_and_write_to_custom_tqdm(self):
assert CustomTqdm.messages == ['test']


@contextmanager
def add_capturing_logging_handler(
logger # type: logging.Logger
):
# type: (...) -> Iterator[StringIO]
try:
previous_handlers = logger.handlers
out = StringIO()
stream_handler = logging.StreamHandler(out)
logger.addHandler(stream_handler)
yield out
finally:
logger.handlers = previous_handlers


class TestLoggingTqdm:
@pytest.mark.parametrize(
"logger_param,expected_logger",
Expand All @@ -201,15 +217,21 @@ def test_should_log_tqdm_output(
logger_param, # type: Optional[logging.Logger]
expected_logger # type: logging.Logger
):
try:
previous_handlers = expected_logger.handlers
out = StringIO()
stream_handler = logging.StreamHandler(out)
expected_logger.addHandler(stream_handler)
with add_capturing_logging_handler(expected_logger) as out:
with logging_tqdm(total=2, logger=logger_param, mininterval=0) as pbar:
pbar.update(1)
finally:
expected_logger.handlers = previous_handlers
last_log_line = out.getvalue().splitlines()[-1]
last_log_line = out.getvalue().splitlines()[-1]
assert '50%' in last_log_line
assert '1/2' in last_log_line

def test_should_not_output_before_any_progress(self):
with add_capturing_logging_handler(DEFAULT_LOGGER) as out:
with logging_tqdm(total=2, mininterval=0) as _:
pass
assert out.getvalue() == ''

def test_should_not_output_with_none_msg(self):
with add_capturing_logging_handler(DEFAULT_LOGGER) as out:
with logging_tqdm(total=2, mininterval=0) as pbar:
pbar.display()
assert out.getvalue() == ''

0 comments on commit e880d2b

Please sign in to comment.