diff --git a/tqdm/rich.py b/tqdm/rich.py index 00e1ddf26..be98b6b79 100644 --- a/tqdm/rich.py +++ b/tqdm/rich.py @@ -9,14 +9,17 @@ from warnings import warn from rich.progress import ( - BarColumn, Progress, ProgressColumn, Text, TimeElapsedColumn, TimeRemainingColumn, filesize) + BarColumn, Progress, ProgressColumn, Text, TimeElapsedColumn, TimeRemainingColumn, + filesize, TextColumn) from .std import TqdmExperimentalWarning from .std import tqdm as std_tqdm +from typing import TypeAlias __author__ = {"github.com/": ["casperdcl"]} __all__ = ['tqdm_rich', 'trrange', 'tqdm', 'trange'] +PostfixColumn: TypeAlias = TextColumn class FractionColumn(ProgressColumn): """Renders completed/total, e.g. '0.5/2.3 G'.""" @@ -97,8 +100,9 @@ def __init__(self, *args, **kwargs): warn("rich is experimental/alpha", TqdmExperimentalWarning, stacklevel=2) d = self.format_dict + postfix = TextColumn(self.postfix) if self.postfix is not None else None if progress is None: - progress = ( + progress = [ "[progress.description]{task.description}" "[progress.percentage]{task.percentage:>4.0f}%", BarColumn(bar_width=None), @@ -106,8 +110,10 @@ def __init__(self, *args, **kwargs): unit_scale=d['unit_scale'], unit_divisor=d['unit_divisor']), "[", TimeElapsedColumn(), "<", TimeRemainingColumn(), ",", RateColumn(unit=d['unit'], unit_scale=d['unit_scale'], - unit_divisor=d['unit_divisor']), "]" - ) + unit_divisor=d['unit_divisor'])] + if self.postfix is not None: + progress.extend([",", PostfixColumn("{task.fields[postfix]}")]) + progress.append("]") options.setdefault('transient', not self.leave) self._prog = Progress(*progress, **options) self._prog.__enter__() @@ -122,6 +128,33 @@ def close(self): def clear(self, *_, **__): pass + def _update_postfix(self): + if self.postfix is None and type(self._prog.columns[-1]) is not PostfixColumn: + columns = list(self._prog.columns) + columns.pop(-1) + columns.pop(-1) + self._prog.columns = tuple(columns) + elif self.postfix is not None and type(self._prog.columns[-2]) is not PostfixColumn: + if type(self._prog.columns[-2]) is not PostfixColumn: + # Add "," and column for postfix if it does not exist + columns = list(self._prog.columns) + columns.insert(-1, ",") + columns.insert(-1, + PostfixColumn("{task.fields[postfix]}")) + self._prog.columns = tuple(columns) + self._prog.update(self._task_id, postfix=self.postfix) + + + def set_postfix(self, ordered_dict=None, refresh=True, **kwargs): + super_result = super().set_postfix(ordered_dict, refresh, **kwargs) + self._update_postfix() + return super_result + + def set_postfix_str(self, s='', refresh=True): + super_result = super().set_postfix_str(s, refresh) + self._update_postfix() + return super_result + def display(self, *_, **__): if not hasattr(self, '_prog'): return