Don't download both safetensor and bin files. #2480

Merged
merged 2 commits on Jan 18, 2024
Changes from all commits
19 changes: 16 additions & 3 deletions vllm/model_executor/weight_utils.py
@@ -1,12 +1,13 @@
 """Utilities for downloading and initializing model weights."""
 import filelock
 import glob
+import fnmatch
 import json
 import os
 from collections import defaultdict
 from typing import Any, Iterator, List, Optional, Tuple
 
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, HfFileSystem
 import numpy as np
 from safetensors.torch import load_file, save_file, safe_open
 import torch
@@ -149,6 +150,20 @@ def prepare_hf_model_weights(
         allow_patterns += ["*.pt"]
 
     if not is_local:
+        # Before we download, we look at what is available:
+        fs = HfFileSystem()
+        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
+
+        # Depending on what is available, we download different things.
+        for pattern in allow_patterns:
+            matching = fnmatch.filter(file_list, pattern)
+            if len(matching) > 0:
+                allow_patterns = [pattern]
+                if pattern == "*.safetensors":
+                    use_safetensors = True
+                break
+
+        logger.info(f"Downloading model weights {allow_patterns}")
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
         with get_lock(model_name_or_path, cache_dir):
@@ -163,8 +178,6 @@
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
         if len(hf_weights_files) > 0:
-            if pattern == "*.safetensors":
-                use_safetensors = True
             break
     if not use_safetensors:
         # Exclude files that are not needed for inference.
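In short, the new block asks the Hub which files the repo actually contains and then narrows allow_patterns to the first pattern that matches anything, so snapshot_download no longer pulls both *.safetensors and *.bin copies of the same weights. Below is a minimal standalone sketch of that selection step, assuming huggingface_hub is installed; the helper name pick_weight_patterns and the example repo id facebook/opt-125m are illustrative, not part of this PR.

import fnmatch

from huggingface_hub import HfFileSystem


def pick_weight_patterns(model_name_or_path, allow_patterns, revision=None):
    """Return (allow_patterns, use_safetensors) based on what the repo hosts."""
    fs = HfFileSystem()
    # detail=False returns plain paths such as "facebook/opt-125m/model.safetensors".
    file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

    use_safetensors = False
    for pattern in allow_patterns:
        # fnmatch's "*" also matches "/", so "*.safetensors" matches the full path.
        if fnmatch.filter(file_list, pattern):
            if pattern == "*.safetensors":
                use_safetensors = True
            return [pattern], use_safetensors
    # Nothing matched; keep the original patterns unchanged.
    return allow_patterns, use_safetensors


patterns, use_safetensors = pick_weight_patterns(
    "facebook/opt-125m", ["*.safetensors", "*.bin", "*.pt"])
print(patterns, use_safetensors)

Because the pattern choice now happens before snapshot_download, the later glob loop only needs to locate the files already on disk, which is why the duplicated *.safetensors check is removed in the last hunk.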