Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions QEfficient/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,16 @@
from requests.exceptions import HTTPError
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from QEfficient.utils.constants import QEFF_MODELS_DIR
from QEfficient.utils.constants import QEFF_MODELS_DIR, Constants
from QEfficient.utils.logging_utils import logger


class DownloadRetryLimitExceeded(Exception):
    """
    Raised when hf_download fails to download the model within the given max_retries attempts.
    """


def login_and_download_hf_lm(model_name, *args, **kwargs):
logger.info(f"loading HuggingFace model for {model_name}")
hf_token = kwargs.pop("hf_token", None)
Expand All @@ -37,12 +43,12 @@ def hf_download(
hf_token: Optional[str] = None,
allow_patterns: Optional[List[str]] = None,
ignore_patterns: Optional[List[str]] = None,
max_retries: Optional[int] = Constants.MAX_RETRIES,
):
# Setup cache_dir
if cache_dir is not None:
os.makedirs(cache_dir, exist_ok=True)

max_retries = 5
retry_count = 0
while retry_count < max_retries:
try:
Expand All @@ -59,14 +65,23 @@ def hf_download(
except requests.ReadTimeout as e:
logger.info(f"Read timeout: {e}")
retry_count += 1

except HTTPError as e:
retry_count = max_retries
if e.response.status_code == 401:
logger.info("You need to pass a valid `--hf_token=...` to download private checkpoints.")
raise e
except OSError as e:
if "Consistency check failed" in str(e):
logger.info(
"Consistency check failed during model download. The file appears to be incomplete. Resuming the download..."
)
retry_count += 1
else:
raise e

if retry_count >= max_retries:
raise DownloadRetryLimitExceeded(
f"Unable to download full model after {max_retries} tries. If the model fileS are huge in size, please try again."
)
return model_path


Expand Down
1 change: 1 addition & 0 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,4 @@ class Constants:
INPUT_STR = ["My name is"]
GB = 2**30
MAX_QPC_LIMIT = 30
MAX_RETRIES = 5 # This constant is used to set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
Loading