Skip to content

Commit

Permalink
fix: download spacy models to custom cache
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkolenz committed Jun 29, 2023
1 parent 9349361 commit 43013d8
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 7 deletions.
71 changes: 65 additions & 6 deletions nlp_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

import itertools
import logging
import tarfile
import typing as t
import urllib.request
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path

import arg_services
import grpc
Expand All @@ -14,7 +17,14 @@
import typer
from arg_services.nlp.v1 import nlp_pb2, nlp_pb2_grpc
from mashumaro.mixins.dict import DataClassDictMixin
from spacy.cli.download import download as spacy_download
from rich.progress import Progress
from spacy import about as spacy_about
from spacy.cli.download import (
get_latest_version as spacy_get_latest_version,
)
from spacy.cli.download import (
get_model_filename as spacy_get_model_filename,
)
from spacy.language import Language as SpacyLanguage
from spacy.tokens import Doc, DocBin
from thinc.types import Floats1d as SpacyVector
Expand Down Expand Up @@ -251,15 +261,64 @@ def vector(self, obj) -> SpacyVector:
# Module-level mapping from an EmbeddingModel key to its ModelBase instance.
# NOTE(review): presumably filled lazily by the model loaders so each model is
# only instantiated once per process; the code that populates it is not visible here.
model_cache: dict[EmbeddingModel, ModelBase] = {}


class UrlReportHook:
    """Report hook for ``urllib.request.urlretrieve`` that drives a progress bar.

    An instance is passed as the ``reporthook`` argument. On the first call it
    registers a task on the given progress display; every call afterwards
    updates that task with the number of bytes transferred so far.

    Args:
        progress: Progress display the download task is added to.
        name: Human-readable name of the download, shown in the task label.
    """

    def __init__(self, progress: Progress, name: str):
        self.task = None
        self.progress = progress
        self.name = name

    def __call__(self, block_num: int, block_size: int, total_size: int):
        """Record that ``block_num`` blocks of ``block_size`` bytes were read.

        Args:
            block_num: Number of blocks transferred so far.
            block_size: Size of a single block in bytes.
            total_size: Total size of the download in bytes; not positive
                (``urlretrieve`` passes ``-1``) when the size is unknown.
        """
        # An unknown size is reported as an open-ended task instead of
        # forwarding a meaningless negative total.
        total = total_size if total_size > 0 else None

        if self.task is None:
            self.task = self.progress.add_task(
                f"Downloading {self.name}...", total=total
            )

        downloaded = block_num * block_size

        # The last block is usually only partially filled, so the product can
        # overshoot the real size. Clamp it instead of skipping the update,
        # otherwise the task never reaches 100%.
        if total is not None:
            downloaded = min(downloaded, total)

        self.progress.update(self.task, completed=downloaded)

        # Forget the task once everything is done so the hook can be reused.
        if self.progress.finished:
            self.task = None


def get_tarfile_members(
    tf: tarfile.TarFile, prefix: str
) -> t.Generator[tarfile.TarInfo, None, None]:
    """Yield the archive members located below ``prefix``, with it stripped.

    Each matching ``TarInfo`` is modified in place: its path is rewritten to be
    relative to ``prefix``, so extracting the yielded members places them
    directly inside the extraction target. Members whose path does not start
    with ``prefix`` are skipped.

    Args:
        tf: Opened tar archive to read the member list from.
        prefix: Leading path that selects members and is cut from their paths.

    Yields:
        The selected members, in archive order, with shortened paths.
    """
    cut = len(prefix)

    for info in tf.getmembers():
        if not info.path.startswith(prefix):
            continue

        info.path = info.path[cut:]
        yield info


def _load_spacy_model(name: t.Optional[str]) -> SpacyLanguage:
    """Load a spaCy pipeline, downloading it into a private cache if needed.

    The latest compatible version of the model is resolved and looked up under
    ``~/.cache/nlp-service/spacy/<name>-<version>``. If that directory does not
    exist, the source distribution is downloaded and only the model data
    directory inside it is unpacked there.

    Args:
        name: Name of the spaCy model. If empty or ``None``, a blank English
            pipeline is returned instead.

    Returns:
        The loaded (or blank) spaCy pipeline.
    """
    import shutil

    if not name:
        return spacy.blank("en")

    version = spacy_get_latest_version(name)
    filename = spacy_get_model_filename(name, version, sdist=True)
    versioned_name = f"{name}-{version}"
    path = Path.home() / ".cache" / "nlp-service" / "spacy" / versioned_name

    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)

        # Build the sibling names with with_name(): with_suffix() would treat
        # the last version component (e.g. ".0") as the suffix and replace it.
        tmpfile = path.with_name(f"{versioned_name}.tar.gz")
        staging = path.with_name(f"{versioned_name}.partial")
        download_url = f"{spacy_about.__download_url__}/{filename}"

        try:
            with Progress() as progress:
                urllib.request.urlretrieve(
                    download_url, tmpfile, UrlReportHook(progress, versioned_name)
                )

            # Drop leftovers of a previously killed run before unpacking.
            shutil.rmtree(staging, ignore_errors=True)

            with tarfile.open(tmpfile, mode="r:gz") as tf:
                # The sdist nests the model data below this directory.
                member_prefix = f"{versioned_name}/{name}/{versioned_name}/"
                members = get_tarfile_members(tf, member_prefix)
                tf.extractall(path=staging, members=members)

            # Only publish the cache entry once it is complete, so an
            # interrupted download or extraction is retried on the next call
            # instead of leaving a broken model behind.
            staging.rename(path)
        finally:
            tmpfile.unlink(missing_ok=True)
            shutil.rmtree(staging, ignore_errors=True)

    return spacy.load(path)


def _load_spacy(config: nlp_pb2.NlpConfig) -> SpacyCache:
Expand Down
73 changes: 72 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ sentence-transformers = { version = "^2.2.2", optional = true }
torch = { version = ">=2.0.0, !=2.0.1, <3.0", optional = true }
transformers = { version = "^4.30.2", optional = true }
openai = { version = "^0.27.8", optional = true }
rich = "^13.4.2"

[tool.poetry.extras]
wmd = ["gensim"]
Expand Down

0 comments on commit 43013d8

Please sign in to comment.