diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py index 225593437..8b0ae36a3 100644 --- a/helion/autotuner/base_search.py +++ b/helion/autotuner/base_search.py @@ -29,7 +29,7 @@ import torch.multiprocessing as mp from torch.utils._pytree import tree_flatten from torch.utils._pytree import tree_map -from tqdm.auto import tqdm +from tqdm.rich import tqdm from triton.testing import do_bench from .. import exc diff --git a/helion/autotuner/benchmarking.py b/helion/autotuner/benchmarking.py index 8fecfe568..e8d6313b0 100644 --- a/helion/autotuner/benchmarking.py +++ b/helion/autotuner/benchmarking.py @@ -4,7 +4,7 @@ import statistics from typing import Callable -from tqdm.auto import tqdm +from tqdm.rich import tqdm from triton import runtime diff --git a/pyproject.toml b/pyproject.toml index 6f79e8528..8eb890980 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,8 @@ dependencies = [ "typing_extensions>=4.0.0", "filecheck", "psutil", - "tqdm" + "tqdm", + "rich" ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index cd95d4d82..894173feb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ filecheck expecttest numpy tqdm +rich