Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fetch local_strategy parameter and disable symlinks by default #2476

Merged
merged 37 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
90643c2
symlink-free fetch API mockup
asumagic Mar 27, 2024
462cb12
Formatting
asumagic Mar 27, 2024
5f0f34e
Lint
asumagic Mar 27, 2024
7c66922
load_audio inference uses NO_LINK strategy
asumagic Mar 27, 2024
cdcb54c
Formatting fix..?
asumagic Mar 27, 2024
e06b726
Fix now broken fetchfrom logic, whoops
asumagic Mar 27, 2024
c5c0a64
Typos
asumagic Mar 28, 2024
a4ac3bb
Docstring for link_with_strategy
asumagic Mar 28, 2024
b1004e2
Fix potential Python 3.8 compat issue with paths
asumagic Mar 28, 2024
6215f89
Fix incorrectly removed import
asumagic Apr 24, 2024
46191a1
Fix include formatting
asumagic Apr 24, 2024
4c6489c
Warn when using the SYMLINK strategy on Windows
asumagic May 6, 2024
5412fef
localstrategy configurability for utils/inference
asumagic May 6, 2024
9a333c7
More explicit ESC50 fetch localstrategy
asumagic May 6, 2024
f09c206
Default to NO_LINK strategy
asumagic May 6, 2024
f9de878
File fetching in inference more explicitly uses NO_LINK
asumagic May 6, 2024
507b472
Precommit fix
asumagic May 6, 2024
a0e4c7e
Fix attempt for parameter transfer?
asumagic May 7, 2024
2484e99
Pre-commit fix
asumagic May 7, 2024
0a733a3
Pre-commit fix the return
asumagic May 7, 2024
c4b3872
Fix log formatting in fetching
asumagic Jun 7, 2024
b21e0db
Explicitly detect symlink self-recursion in fetch() and error out
asumagic Jun 7, 2024
78e5b94
Fix missing savedir for parameter_transfer
asumagic Jun 26, 2024
3add4d8
Cleanup fetch kwargs logic
asumagic Jun 26, 2024
c762f5f
Merge branch 'develop' into symlinkinator
asumagic Jul 12, 2024
a020eec
Merge branch 'develop' into symlinkinator
asumagic Jul 17, 2024
0614651
Log formatting
asumagic Jul 17, 2024
3cc9c93
Refactor out some source guessing complexity, add params to fetch
asumagic Jul 18, 2024
db3ba48
Merge branch 'develop' into symlinkinator
asumagic Jul 19, 2024
9153015
Type annotations and formatting
asumagic Jul 19, 2024
cee008f
Fix bad assert
asumagic Jul 19, 2024
6ad247b
Move some fetch messages to debug when minor
asumagic Jul 19, 2024
ba99992
Move some parameter transfer logging from info to debug level
asumagic Jul 19, 2024
7e0533b
Be more explicit in logging for HF fetch
asumagic Jul 19, 2024
d17aa10
Merge branch 'develop' into symlinkinator
asumagic Sep 11, 2024
ee05542
Lint fixes
asumagic Sep 11, 2024
b9a7a7f
Fix invalid default for `local_strategy` in doc
asumagic Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions recipes/ESC50/esc50_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import speechbrain as sb
from speechbrain.dataio.dataio import load_data_csv, read_audio
from speechbrain.utils.fetching import fetch
from speechbrain.utils.fetching import LocalStrategy, fetch

logger = logging.getLogger(__name__)

Expand All @@ -47,14 +47,16 @@ def download_esc50(data_path):
temp_path = os.path.join(data_path, "temp_download")

# download the data
fetch(
archive_path = fetch(
"master.zip",
"https://github.com/karoldvl/ESC-50/archive/",
savedir=temp_path,
# URL, so will be fetched directly in the savedir anyway
local_strategy=LocalStrategy.COPY_SKIP_CACHE,
)

# unpack the .zip file
shutil.unpack_archive(os.path.join(temp_path, "master.zip"), data_path)
shutil.unpack_archive(archive_path, data_path)

# move the files up to the datapath
files = os.listdir(os.path.join(data_path, "ESC-50-master"))
Expand Down
9 changes: 7 additions & 2 deletions speechbrain/inference/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import speechbrain
from speechbrain.inference.interfaces import Pretrained
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.fetching import fetch
from speechbrain.utils.fetching import LocalStrategy, fetch


class EncoderClassifier(Pretrained):
Expand Down Expand Up @@ -293,7 +293,12 @@ def classify_file(self, path, savedir="audio_cache"):
(label encoder should be provided).
"""
source, fl = split_path(path)
path = fetch(fl, source=source, savedir=savedir)
path = fetch(
fl,
source=source,
savedir=savedir,
local_strategy=LocalStrategy.NO_LINK,
)

batch, fs_file = torchaudio.load(path)
batch = batch.to(self.device)
Expand Down
22 changes: 20 additions & 2 deletions speechbrain/inference/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from speechbrain.utils.data_pipeline import DataPipeline
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.distributed import run_on_main
from speechbrain.utils.fetching import fetch
from speechbrain.utils.fetching import LocalStrategy, fetch
from speechbrain.utils.superpowers import import_from_path

logger = logging.getLogger(__name__)
Expand All @@ -49,6 +49,7 @@ def foreign_class(
use_auth_token=False,
download_only=False,
huggingface_cache_dir=None,
local_strategy: LocalStrategy = LocalStrategy.NO_LINK,
**kwargs,
):
"""Fetch and load an interface from an outside source
Expand Down Expand Up @@ -95,6 +96,10 @@ def foreign_class(
If true, class and instance creation is skipped.
huggingface_cache_dir : str
Path to HuggingFace cache; if None -> "~/.cache/huggingface" (default: None)
local_strategy : speechbrain.utils.fetching.LocalStrategy
The fetching strategy to use, which controls the behavior of remote file
fetching with regards to symlinking and copying.
See :func:`speechbrain.utils.fetching.fetch` for further details.
**kwargs : dict
Arguments to forward to class constructor.

Expand All @@ -114,6 +119,7 @@ def foreign_class(
use_auth_token=use_auth_token,
revision=None,
huggingface_cache_dir=huggingface_cache_dir,
local_strategy=local_strategy,
)
pymodule_local_path = fetch(
filename=pymodule_file,
Expand All @@ -124,6 +130,7 @@ def foreign_class(
use_auth_token=use_auth_token,
revision=None,
huggingface_cache_dir=huggingface_cache_dir,
local_strategy=local_strategy,
)
sys.path.append(str(pymodule_local_path.parent))

Expand Down Expand Up @@ -286,7 +293,12 @@ def load_audio(self, path, savedir="."):
The path can be a local path, a web url, or a link to a huggingface repo.
"""
source, fl = split_path(path)
path = fetch(fl, source=source, savedir=savedir)
path = fetch(
fl,
source=source,
savedir=savedir,
local_strategy=LocalStrategy.NO_LINK,
)
signal, sr = torchaudio.load(str(path), channels_first=False)
return self.audio_normalizer(signal, sr)

Expand Down Expand Up @@ -397,6 +409,7 @@ def from_hparams(
download_only=False,
huggingface_cache_dir=None,
overrides_must_match=True,
local_strategy: LocalStrategy = LocalStrategy.NO_LINK,
**kwargs,
):
"""Fetch and load based from outside source based on HyperPyYAML file
Expand Down Expand Up @@ -450,6 +463,9 @@ def from_hparams(
Path to HuggingFace cache; if None -> "~/.cache/huggingface" (default: None)
overrides_must_match : bool
Whether the overrides must match the parameters already in the file.
local_strategy : LocalStrategy, optional
Which strategy to use to deal with files locally. (default:
`LocalStrategy.SYMLINK`)
**kwargs : dict
Arguments to forward to class constructor.

Expand All @@ -469,6 +485,7 @@ def from_hparams(
use_auth_token=use_auth_token,
revision=revision,
huggingface_cache_dir=huggingface_cache_dir,
local_strategy=local_strategy,
)
try:
pymodule_local_path = fetch(
Expand All @@ -480,6 +497,7 @@ def from_hparams(
use_auth_token=use_auth_token,
revision=revision,
huggingface_cache_dir=huggingface_cache_dir,
local_strategy=local_strategy,
)
sys.path.append(str(pymodule_local_path.parent))
except ValueError:
Expand Down
9 changes: 7 additions & 2 deletions speechbrain/inference/interpretability.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from speechbrain.inference.interfaces import Pretrained
from speechbrain.processing.NMF import spectral_phase
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.fetching import fetch
from speechbrain.utils.fetching import LocalStrategy, fetch


class PIQAudioInterpreter(Pretrained):
Expand Down Expand Up @@ -153,7 +153,12 @@ def interpret_file(self, path, savedir="audio_cache"):
The sampling frequency of the model. Useful to save the audio.
"""
source, fl = split_path(path)
path = fetch(fl, source=source, savedir=savedir)
path = fetch(
fl,
source=source,
savedir=savedir,
local_strategy=LocalStrategy.NO_LINK,
)

batch, fs_file = torchaudio.load(path)
batch = batch.to(self.device)
Expand Down
9 changes: 7 additions & 2 deletions speechbrain/inference/separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from speechbrain.inference.interfaces import Pretrained
from speechbrain.utils.data_utils import split_path
from speechbrain.utils.fetching import fetch
from speechbrain.utils.fetching import LocalStrategy, fetch


class SepformerSeparation(Pretrained):
Expand Down Expand Up @@ -97,7 +97,12 @@ def separate_file(self, path, savedir="audio_cache"):
Separated sources
"""
source, fl = split_path(path)
path = fetch(fl, source=source, savedir=savedir)
path = fetch(
fl,
source=source,
savedir=savedir,
local_strategy=LocalStrategy.NO_LINK,
)

batch, fs_file = torchaudio.load(path)
batch = batch.to(self.device)
Expand Down
Loading
Loading