Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Compatible with huggingface-hub v0.23.0 #1514

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ install_requires =
pydantic
fastapi
uvicorn
huggingface-hub>=0.19.4,<0.23.0
huggingface-hub>=0.19.4
typing_extensions
fsspec>=2023.1.0,<=2023.10.0
s3fs
Expand Down
2 changes: 1 addition & 1 deletion xinference/deploy/docker/cpu.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ RUN python -m pip install --upgrade -i "$PIP_INDEX" pip && \
pydantic \
fastapi \
uvicorn \
"huggingface-hub>=0.19.4,<0.23.0" \
"huggingface-hub>=0.19.4" \
typing_extensions \
"fsspec>=2023.1.0,<=2023.10.0" \
s3fs \
Expand Down
1 change: 1 addition & 0 deletions xinference/model/embedding/tests/test_embedding_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def test_model_from_modelscope():
assert len(r["data"]) == 1
for d in r["data"]:
assert len(d["embedding"]) == 512
shutil.rmtree(model_path, ignore_errors=True)


def test_meta_file():
Expand Down
25 changes: 15 additions & 10 deletions xinference/model/llm/llm_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
)
from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
from ..utils import (
IS_NEW_HUGGINGFACE_HUB,
create_symlink,
download_from_modelscope,
is_valid_model_uri,
parse_uri,
Expand Down Expand Up @@ -625,10 +627,7 @@ def cache_from_modelscope(
llm_spec.model_id,
revision=llm_spec.model_revision,
)
for subdir, dirs, files in os.walk(download_dir):
for file in files:
relpath = os.path.relpath(os.path.join(subdir, file), download_dir)
symlink_local_file(os.path.join(subdir, file), cache_dir, relpath)
create_symlink(download_dir, cache_dir)

elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
file_names, final_file_name, need_merge = _generate_model_file_names(
Expand Down Expand Up @@ -682,9 +681,13 @@ def cache_from_huggingface(
):
return cache_dir

use_symlinks = {}
if not IS_NEW_HUGGINGFACE_HUB:
use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}

if llm_spec.model_format in ["pytorch", "gptq", "awq"]:
assert isinstance(llm_spec, PytorchLLMSpecV1)
retry_download(
download_dir = retry_download(
huggingface_hub.snapshot_download,
llm_family.model_name,
{
Expand All @@ -693,9 +696,10 @@ def cache_from_huggingface(
},
llm_spec.model_id,
revision=llm_spec.model_revision,
local_dir=cache_dir,
local_dir_use_symlinks=True,
**use_symlinks,
)
if IS_NEW_HUGGINGFACE_HUB:
create_symlink(download_dir, cache_dir)

elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
assert isinstance(llm_spec, GgmlLLMSpecV1)
Expand All @@ -704,7 +708,7 @@ def cache_from_huggingface(
)

for file_name in file_names:
retry_download(
download_file_path = retry_download(
huggingface_hub.hf_hub_download,
llm_family.model_name,
{
Expand All @@ -714,9 +718,10 @@ def cache_from_huggingface(
llm_spec.model_id,
revision=llm_spec.model_revision,
filename=file_name,
local_dir=cache_dir,
local_dir_use_symlinks=True,
**use_symlinks,
)
if IS_NEW_HUGGINGFACE_HUB:
symlink_local_file(download_file_path, cache_dir, file_name)

if need_merge:
_merge_cached_files(cache_dir, file_names, final_file_name)
Expand Down
14 changes: 7 additions & 7 deletions xinference/model/llm/pytorch/tests/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Union

import pytest
Expand Down Expand Up @@ -49,6 +48,8 @@ class MockPytorchModel(MockNonPytorchModel, PytorchModel):
@pytest.mark.asyncio
@pytest.mark.parametrize("quantization", ["none"])
async def test_opt_pytorch_model(setup, quantization):
from .....constants import XINFERENCE_CACHE_DIR

endpoint, _ = setup
client = Client(endpoint)
assert len(client.list_models()) == 0
Expand Down Expand Up @@ -97,12 +98,11 @@ def _check():
assert len(client.list_models()) == 0

# check for cached revision
home_address = str(Path.home())
snapshot_address = (
home_address
+ "/.cache/huggingface/hub/models--facebook--opt-125m/snapshots"
valid_file = os.path.join(
XINFERENCE_CACHE_DIR, "opt-pytorch-1b", "__valid_download"
)
actual_revision = os.listdir(snapshot_address)
with open(valid_file, "r") as f:
actual_revision = json.load(f)["revision"]
model_name = "opt"
expected_revision: Union[str, None] = "" # type: ignore

Expand All @@ -112,7 +112,7 @@ def _check():
for spec in family.model_specs:
expected_revision = spec.model_revision

assert [expected_revision] == actual_revision
assert expected_revision == actual_revision


@pytest.mark.asyncio
Expand Down
24 changes: 17 additions & 7 deletions xinference/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union

import huggingface_hub
from fsspec import AbstractFileSystem

from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
Expand All @@ -27,6 +28,7 @@

logger = logging.getLogger(__name__)
MAX_ATTEMPTS = 3
IS_NEW_HUGGINGFACE_HUB: bool = huggingface_hub.__version__ >= "0.23.0"


def is_locale_chinese_simplified() -> bool:
Expand Down Expand Up @@ -76,6 +78,13 @@ def symlink_local_file(path: str, local_dir: str, relpath: str) -> str:
return local_dir_filepath


def create_symlink(download_dir: str, cache_dir: str):
    """Mirror every file under ``download_dir`` into ``cache_dir`` as symlinks.

    Walks the download tree recursively and, for each regular file, creates a
    symlink at the same relative path under ``cache_dir`` via
    ``symlink_local_file``. Directory structure is implied by the relative
    paths; empty directories are not reproduced.
    """
    for root, _dirs, filenames in os.walk(download_dir):
        for name in filenames:
            src_path = os.path.join(root, name)
            rel_path = os.path.relpath(src_path, download_dir)
            symlink_local_file(src_path, cache_dir, rel_path)


def retry_download(
download_func: Callable,
model_name: str,
Expand Down Expand Up @@ -306,22 +315,23 @@ def cache(model_spec: CacheableModelSpec, model_description_type: type):
model_spec.model_id,
revision=model_spec.model_revision,
)
for subdir, dirs, files in os.walk(download_dir):
for file in files:
relpath = os.path.relpath(os.path.join(subdir, file), download_dir)
symlink_local_file(os.path.join(subdir, file), cache_dir, relpath)
create_symlink(download_dir, cache_dir)
else:
from huggingface_hub import snapshot_download as hf_download

retry_download(
use_symlinks = {}
if not IS_NEW_HUGGINGFACE_HUB:
use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
download_dir = retry_download(
hf_download,
model_spec.model_name,
None,
model_spec.model_id,
revision=model_spec.model_revision,
local_dir=cache_dir,
local_dir_use_symlinks=True,
**use_symlinks,
)
if IS_NEW_HUGGINGFACE_HUB:
create_symlink(download_dir, cache_dir)
with open(meta_path, "w") as f:
import json

Expand Down
Loading