From 0a72e813b3a6e59d7117eea0e70abf8a11acab10 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 15 Jul 2025 14:39:42 +0100 Subject: [PATCH 1/3] Add `load_with_torchcodec`, modify load()'s warnings (#3974) --- .github/scripts/unittest-linux/install.sh | 8 +- docs/source/torchaudio.rst | 7 +- src/torchaudio/__init__.py | 5 +- src/torchaudio/_backend/utils.py | 17 ++ src/torchaudio/_torchcodec.py | 161 +++++++++++++++ .../test_load_torchcodec.py | 193 ++++++++++++++++++ 6 files changed, 385 insertions(+), 6 deletions(-) create mode 100644 src/torchaudio/_torchcodec.py create mode 100644 test/torchaudio_unittest/test_load_torchcodec.py diff --git a/.github/scripts/unittest-linux/install.sh b/.github/scripts/unittest-linux/install.sh index 8859b827f0..24dd7e3476 100755 --- a/.github/scripts/unittest-linux/install.sh +++ b/.github/scripts/unittest-linux/install.sh @@ -74,7 +74,7 @@ case $GPU_ARCH_TYPE in ;; esac PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${GPU_ARCH_ID}" -pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}" +pip install --progress-bar=off --pre torch torchcodec --index-url="${PYTORCH_WHEEL_INDEX}" # 2. Install torchaudio @@ -86,6 +86,10 @@ python setup.py install # 3. Install Test tools printf "* Installing test tools\n" +# On this CI, for whatever reason, we're only able to install ffmpeg 4. +conda install -y "ffmpeg<5" +python -c "import torch; import torchaudio; import torchcodec; print(torch.__version__, torchaudio.__version__, torchcodec.__version__)" + NUMBA_DEV_CHANNEL="" if [[ "$(python --version)" = *3.9* || "$(python --version)" = *3.10* ]]; then # Numba isn't available for Python 3.9 and 3.10 except on the numba dev channel and building from source fails @@ -94,7 +98,7 @@ if [[ "$(python --version)" = *3.9* || "$(python --version)" = *3.10* ]]; then fi ( set -x - conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} sox libvorbis parameterized 'requests>=2.20' 'ffmpeg>=6,<7' + conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} sox libvorbis parameterized 'requests>=2.20' pip install kaldi-io SoundFile librosa coverage pytest pytest-cov scipy expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs tinytag pyroomacoustics flashlight-text git+https://github.com/kpu/kenlm # TODO: might be better to fix the single call to `pip install` above diff --git a/docs/source/torchaudio.rst b/docs/source/torchaudio.rst index 3212d554bb..13e14ceb6c 100644 --- a/docs/source/torchaudio.rst +++ b/docs/source/torchaudio.rst @@ -7,9 +7,11 @@ torchaudio Starting with version 2.8, we are refactoring TorchAudio to transition it into a maintenance phase. As a result: - - The APIs listed below are deprecated in 2.8 and will be removed in 2.9. + - Most APIs listed below are deprecated in 2.8 and will be removed in 2.9. - The decoding and encoding capabilities of PyTorch for both audio and video - are being consolidated into TorchCodec. + are being consolidated into TorchCodec. We provide + ``torchaudio.load_with_torchcodec()`` as a replacement for + ``torchaudio.load()``. Please see https://github.com/pytorch/audio/issues/3902 for more information. @@ -26,6 +28,7 @@ it easy to handle audio data. info load + load_with_torchcodec save list_audio_backends diff --git a/src/torchaudio/__init__.py b/src/torchaudio/__init__.py index 6c9c39d031..2a6f924ecf 100644 --- a/src/torchaudio/__init__.py +++ b/src/torchaudio/__init__.py @@ -7,16 +7,16 @@ get_audio_backend as _get_audio_backend, info as _info, list_audio_backends as _list_audio_backends, - load as _load, + load, save as _save, set_audio_backend as _set_audio_backend, ) +from ._torchcodec import load_with_torchcodec AudioMetaData = dropping_class_io_support(_AudioMetaData) get_audio_backend = dropping_io_support(_get_audio_backend) info = dropping_io_support(_info) list_audio_backends = dropping_io_support(_list_audio_backends) -load = dropping_io_support(_load) save = dropping_io_support(_save) set_audio_backend = dropping_io_support(_set_audio_backend) @@ -45,6 +45,7 @@ __all__ = [ "AudioMetaData", "load", + "load_with_torchcodec", "info", "save", "io", diff --git a/src/torchaudio/_backend/utils.py b/src/torchaudio/_backend/utils.py index 0cde6b1927..c39bc936d2 100644 --- a/src/torchaudio/_backend/utils.py +++ b/src/torchaudio/_backend/utils.py @@ -1,6 +1,7 @@ import os from functools import lru_cache from typing import BinaryIO, Dict, Optional, Tuple, Type, Union +import warnings import torch @@ -127,6 +128,14 @@ def load( ) -> Tuple[torch.Tensor, int]: """Load audio data from source. + .. warning:: + In 2.9, this function's implementation will be changed to use + :func:`~torchaudio.load_with_torchcodec` under the hood. Some + parameters like ``normalize``, ``format``, ``buffer_size``, and + ``backend`` will be ignored. We recommend that you port your code to + rely directly on TorchCodec's decoder instead: + https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder. + By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with ``float32`` dtype, and the shape of `[channel, time]`. @@ -201,6 +210,14 @@ def load( integer type, else ``float32`` type. If ``channels_first=True``, it has `[channel, time]` else `[time, channel]`. """ + warnings.warn( + "In 2.9, this function's implementation will be changed to use " + "torchaudio.load_with_torchcodec` under the hood. Some " + "parameters like ``normalize``, ``format``, ``buffer_size``, and " + "``backend`` will be ignored. We recommend that you port your code to " + "rely directly on TorchCodec's decoder instead: " + "https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder." + ) backend = dispatcher(uri, format, backend) return backend.load(uri, frame_offset, num_frames, normalize, channels_first, format, buffer_size) diff --git a/src/torchaudio/_torchcodec.py b/src/torchaudio/_torchcodec.py new file mode 100644 index 0000000000..db62ec413a --- /dev/null +++ b/src/torchaudio/_torchcodec.py @@ -0,0 +1,161 @@ +"""TorchCodec integration for TorchAudio.""" + +import os +from typing import BinaryIO, Optional, Tuple, Union + +import torch + + +def load_with_torchcodec( + uri: Union[BinaryIO, str, os.PathLike], + frame_offset: int = 0, + num_frames: int = -1, + normalize: bool = True, + channels_first: bool = True, + format: Optional[str] = None, + buffer_size: int = 4096, + backend: Optional[str] = None, +) -> Tuple[torch.Tensor, int]: + """Load audio data from source using TorchCodec's AudioDecoder. + + .. note:: + + This function supports the same API as ``torchaudio.load()``, and relies + on TorchCodec's decoding capabilities under the hood. It is provided for + convenience, but we do recommend that you port your code to natively use + ``torchcodec``'s ``AudioDecoder`` class for better performance: + https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder. + In TorchAudio 2.9, ``torchaudio.load()`` will be relying on + ``load_with_torchcodec``. Note that some parameters of + ``torchaudio.load()``, like ``normalize``, ``buffer_size``, and + ``backend``, are ignored by ``load_with_torchcodec``. + + + Args: + uri (path-like object or file-like object): + Source of audio data. The following types are accepted: + + * ``path-like``: File path or URL. + * ``file-like``: Object with ``read(size: int) -> bytes`` method. + + frame_offset (int, optional): + Number of samples to skip before start reading data. + num_frames (int, optional): + Maximum number of samples to read. ``-1`` reads all the remaining samples, + starting from ``frame_offset``. + normalize (bool, optional): + TorchCodec always returns normalized float32 samples. This parameter + is ignored and a warning is issued if set to False. + Default: ``True``. + channels_first (bool, optional): + When True, the returned Tensor has dimension `[channel, time]`. + Otherwise, the returned Tensor's dimension is `[time, channel]`. + format (str or None, optional): + Format hint for the decoder. May not be supported by all TorchCodec + decoders. (Default: ``None``) + buffer_size (int, optional): + Not used by TorchCodec AudioDecoder. Provided for API compatibility. + backend (str or None, optional): + Not used by TorchCodec AudioDecoder. Provided for API compatibility. + + Returns: + (torch.Tensor, int): Resulting Tensor and sample rate. + Always returns float32 tensors. If ``channels_first=True``, shape is + `[channel, time]`, otherwise `[time, channel]`. + + Raises: + ImportError: If torchcodec is not available. + ValueError: If unsupported parameters are used. + RuntimeError: If TorchCodec fails to decode the audio. + + Note: + - TorchCodec always returns normalized float32 samples, so the ``normalize`` + parameter has no effect. + - The ``buffer_size`` and ``backend`` parameters are ignored. + - Not all audio formats supported by torchaudio backends may be supported + by TorchCodec. + """ + # Import torchcodec here to provide clear error if not available + try: + from torchcodec.decoders import AudioDecoder + except ImportError as e: + raise ImportError( + "TorchCodec is required for load_with_torchcodec. " + "Please install torchcodec to use this function." + ) from e + + # Parameter validation and warnings + if not normalize: + import warnings + warnings.warn( + "TorchCodec AudioDecoder always returns normalized float32 samples. " + "The 'normalize=False' parameter is ignored.", + UserWarning, + stacklevel=2 + ) + + if buffer_size != 4096: + import warnings + warnings.warn( + "The 'buffer_size' parameter is not used by TorchCodec AudioDecoder.", + UserWarning, + stacklevel=2 + ) + + if backend is not None: + import warnings + warnings.warn( + "The 'backend' parameter is not used by TorchCodec AudioDecoder.", + UserWarning, + stacklevel=2 + ) + + if format is not None: + import warnings + warnings.warn( + "The 'format' parameter is not supported by TorchCodec AudioDecoder.", + UserWarning, + stacklevel=2 + ) + + # Create AudioDecoder + try: + decoder = AudioDecoder(uri) + except Exception as e: + raise RuntimeError(f"Failed to create AudioDecoder for {uri}: {e}") from e + + # Get sample rate from metadata + sample_rate = decoder.metadata.sample_rate + if sample_rate is None: + raise RuntimeError("Unable to determine sample rate from audio metadata") + + # Decode the entire file first, then subsample manually + # This is the simplest approach since torchcodec uses time-based indexing + try: + audio_samples = decoder.get_all_samples() + except Exception as e: + raise RuntimeError(f"Failed to decode audio samples: {e}") from e + + data = audio_samples.data + + # Apply frame_offset and num_frames (which are actually sample offsets) + if frame_offset > 0: + if frame_offset >= data.shape[1]: + # Return empty tensor if offset is beyond available data + empty_shape = (data.shape[0], 0) if channels_first else (0, data.shape[0]) + return torch.zeros(empty_shape, dtype=torch.float32), sample_rate + data = data[:, frame_offset:] + + if num_frames == 0: + # Return empty tensor if num_frames is 0 + empty_shape = (data.shape[0], 0) if channels_first else (0, data.shape[0]) + return torch.zeros(empty_shape, dtype=torch.float32), sample_rate + elif num_frames > 0: + data = data[:, :num_frames] + + # TorchCodec returns data in [channel, time] format by default + # Handle channels_first parameter + if not channels_first: + data = data.transpose(0, 1) # [channel, time] -> [time, channel] + + return data, sample_rate \ No newline at end of file diff --git a/test/torchaudio_unittest/test_load_torchcodec.py b/test/torchaudio_unittest/test_load_torchcodec.py new file mode 100644 index 0000000000..62af890aca --- /dev/null +++ b/test/torchaudio_unittest/test_load_torchcodec.py @@ -0,0 +1,193 @@ +from unittest.mock import patch +import subprocess +import re + +import pytest +import torch + +import torchaudio +from torchaudio import load_with_torchcodec +from torchaudio_unittest.common_utils import get_asset_path, skipIfNoModule + + +def get_ffmpeg_version(): + """Get FFmpeg version to check for compatibility issues.""" + try: + result = subprocess.run(['ffmpeg', '-version'], capture_output=True, text=True, timeout=5) + if result.returncode == 0: + # Extract version number from output like "ffmpeg version 4.4.2-0ubuntu0.22.04.1" + match = re.search(r'ffmpeg version (\d+)\.', result.stdout) + if match: + return int(match.group(1)) + except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError): + pass + return None + + +def is_ffmpeg4(): + """Check if FFmpeg version is 4, which has known compatibility issues.""" + version = get_ffmpeg_version() + return version == 4 + + +# Test with wav files that should work with both torchaudio and torchcodec +TEST_FILES = [ + "sinewave.wav", + "steam-train-whistle-daniel_simon.wav", + "vad-go-mono-32000.wav", + "vad-go-stereo-44100.wav", + "VCTK-Corpus/wav48/p224/p224_002.wav", +] + + +@skipIfNoModule("torchcodec") +@pytest.mark.parametrize("filename", TEST_FILES) +def test_basic_load(filename): + """Test basic loading functionality against torchaudio.load.""" + # Skip problematic files on FFmpeg4 due to known compatibility issues + if is_ffmpeg4() and filename != "sinewave.wav": + pytest.skip("FFmpeg4 has known compatibility issues with some audio files") + + file_path = get_asset_path(*filename.split("/")) + + # Load with torchaudio + waveform_ta, sample_rate_ta = torchaudio.load(file_path) + + # Load with torchcodec + waveform_tc, sample_rate_tc = load_with_torchcodec(file_path) + + # Check sample rates match + assert sample_rate_ta == sample_rate_tc + + # Check shapes match + assert waveform_ta.shape == waveform_tc.shape + + # Check data types (should both be float32) + assert waveform_ta.dtype == torch.float32 + assert waveform_tc.dtype == torch.float32 + + # Check values are close (allowing for small differences in decoders) + torch.testing.assert_close(waveform_ta, waveform_tc) + +@skipIfNoModule("torchcodec") +@pytest.mark.parametrize("frame_offset,num_frames", [ + (0, 1000), # First 1000 samples + (1000, 2000), # 2000 samples starting from 1000 + (5000, -1), # From 5000 to end + (0, -1), # Full file +]) +def test_frame_offset_and_num_frames(frame_offset, num_frames): + """Test frame_offset and num_frames parameters.""" + file_path = get_asset_path("sinewave.wav") + + # Load with torchaudio + waveform_ta, sample_rate_ta = torchaudio.load( + file_path, frame_offset=frame_offset, num_frames=num_frames + ) + + # Load with torchcodec + waveform_tc, sample_rate_tc = load_with_torchcodec( + file_path, frame_offset=frame_offset, num_frames=num_frames + ) + + # Check results match + assert sample_rate_ta == sample_rate_tc + assert waveform_ta.shape == waveform_tc.shape + torch.testing.assert_close(waveform_ta, waveform_tc) + +@skipIfNoModule("torchcodec") +def test_channels_first(): + """Test channels_first parameter.""" + file_path = get_asset_path("sinewave.wav") # Use sinewave.wav for compatibility + + # Test channels_first=True (default) + waveform_cf_true, sample_rate = load_with_torchcodec(file_path, channels_first=True) + + # Test channels_first=False + waveform_cf_false, _ = load_with_torchcodec(file_path, channels_first=False) + + # Check that transpose relationship holds + assert waveform_cf_true.shape == waveform_cf_false.transpose(0, 1).shape + torch.testing.assert_close(waveform_cf_true, waveform_cf_false.transpose(0, 1)) + + # Compare with torchaudio + waveform_ta_true, _ = torchaudio.load(file_path, channels_first=True) + waveform_ta_false, _ = torchaudio.load(file_path, channels_first=False) + + assert waveform_cf_true.shape == waveform_ta_true.shape + assert waveform_cf_false.shape == waveform_ta_false.shape + torch.testing.assert_close(waveform_cf_true, waveform_ta_true) + torch.testing.assert_close(waveform_cf_false, waveform_ta_false) + +@skipIfNoModule("torchcodec") +def test_normalize_parameter_warning(): + """Test that normalize=False produces a warning.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns(UserWarning, match="normalize=False.*ignored"): + # This should produce a warning + waveform, sample_rate = load_with_torchcodec(file_path, normalize=False) + + # Result should still be float32 (normalized) + assert waveform.dtype == torch.float32 + +@skipIfNoModule("torchcodec") +def test_buffer_size_parameter_warning(): + """Test that non-default buffer_size produces a warning.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns(UserWarning, match="buffer_size.*not used"): + # This should produce a warning + waveform, sample_rate = load_with_torchcodec(file_path, buffer_size=8192) + + +@skipIfNoModule("torchcodec") +def test_backend_parameter_warning(): + """Test that specifying backend produces a warning.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns(UserWarning, match="backend.*not used"): + # This should produce a warning + waveform, sample_rate = load_with_torchcodec(file_path, backend="ffmpeg") + + +@skipIfNoModule("torchcodec") +def test_invalid_file(): + """Test that invalid files raise appropriate errors.""" + with pytest.raises(RuntimeError, match="Failed to create AudioDecoder"): + load_with_torchcodec("/nonexistent/file.wav") + + +@skipIfNoModule("torchcodec") +def test_format_parameter(): + """Test that format parameter produces a warning.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns(UserWarning, match="format.*not supported"): + waveform, sample_rate = load_with_torchcodec(file_path, format="wav") + + # Check basic properties + assert waveform.dtype == torch.float32 + assert sample_rate > 0 + + +@skipIfNoModule("torchcodec") +def test_multiple_warnings(): + """Test that multiple unsupported parameters produce multiple warnings.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns() as warning_list: + # This should produce multiple warnings + waveform, sample_rate = load_with_torchcodec( + file_path, + normalize=False, + buffer_size=8192, + backend="ffmpeg" + ) + + + # Check that expected warnings are present + messages = [str(w.message) for w in warning_list] + assert any("normalize=False" in msg for msg in messages) + assert any("buffer_size" in msg for msg in messages) + assert any("backend" in msg for msg in messages) \ No newline at end of file From 1ee65166a266b42e091d4e6ede99a6cb86a9630e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 15 Jul 2025 17:09:37 +0100 Subject: [PATCH 2/3] Add save_with_torchcodec, modify save()'s warnings (#3975) --- docs/source/torchaudio.rst | 9 +- src/torchaudio/__init__.py | 6 +- src/torchaudio/_backend/utils.py | 16 + src/torchaudio/_torchcodec.py | 209 +++++++- .../test_load_save_torchcodec.py | 449 ++++++++++++++++++ .../test_load_torchcodec.py | 193 -------- 6 files changed, 674 insertions(+), 208 deletions(-) create mode 100644 test/torchaudio_unittest/test_load_save_torchcodec.py delete mode 100644 test/torchaudio_unittest/test_load_torchcodec.py diff --git a/docs/source/torchaudio.rst b/docs/source/torchaudio.rst index 13e14ceb6c..aa933e84ad 100644 --- a/docs/source/torchaudio.rst +++ b/docs/source/torchaudio.rst @@ -9,9 +9,11 @@ torchaudio - Most APIs listed below are deprecated in 2.8 and will be removed in 2.9. - The decoding and encoding capabilities of PyTorch for both audio and video - are being consolidated into TorchCodec. We provide - ``torchaudio.load_with_torchcodec()`` as a replacement for - ``torchaudio.load()``. + are being consolidated into TorchCodec. For convenience, we provide + :func:`~torchaudio.load_with_torchcodec` as a replacement for + :func:`~torchaudio.load` and :func:`~torchaudio.save_with_torchcodec` as a + replacement for :func:`~torchaudio.save`, but we recommend that you port + your code to native torchcodec APIs. Please see https://github.com/pytorch/audio/issues/3902 for more information. @@ -30,6 +32,7 @@ it easy to handle audio data. load load_with_torchcodec save + save_with_torchcodec list_audio_backends .. _backend: diff --git a/src/torchaudio/__init__.py b/src/torchaudio/__init__.py index 2a6f924ecf..e533cafe9d 100644 --- a/src/torchaudio/__init__.py +++ b/src/torchaudio/__init__.py @@ -8,16 +8,15 @@ info as _info, list_audio_backends as _list_audio_backends, load, - save as _save, + save, set_audio_backend as _set_audio_backend, ) -from ._torchcodec import load_with_torchcodec +from ._torchcodec import load_with_torchcodec, save_with_torchcodec AudioMetaData = dropping_class_io_support(_AudioMetaData) get_audio_backend = dropping_io_support(_get_audio_backend) info = dropping_io_support(_info) list_audio_backends = dropping_io_support(_list_audio_backends) -save = dropping_io_support(_save) set_audio_backend = dropping_io_support(_set_audio_backend) from . import ( # noqa: F401 @@ -46,6 +45,7 @@ "AudioMetaData", "load", "load_with_torchcodec", + "save_with_torchcodec", "info", "save", "io", diff --git a/src/torchaudio/_backend/utils.py b/src/torchaudio/_backend/utils.py index c39bc936d2..eb7c51f0cb 100644 --- a/src/torchaudio/_backend/utils.py +++ b/src/torchaudio/_backend/utils.py @@ -252,6 +252,14 @@ def save( ): """Save audio data to file. + .. warning:: + In 2.9, this function's implementation will be changed to use + :func:`~torchaudio.save_with_torchcodec` under the hood. Some + parameters like format, encoding, bits_per_sample, buffer_size, and + ``backend`` will be ignored. We recommend that you port your code to + rely directly on TorchCodec's decoder instead: + https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder + Note: The formats this function can handle depend on the availability of backends. Please use the following functions to fetch the supported formats. @@ -326,6 +334,14 @@ def save( Refer to http://sox.sourceforge.net/soxformat.html for more details. """ + warnings.warn( + "In 2.9, this function's implementation will be changed to use " + "torchaudio.save_with_torchcodec` under the hood. Some " + "parameters like format, encoding, bits_per_sample, buffer_size, and " + "``backend`` will be ignored. We recommend that you port your code to " + "rely directly on TorchCodec's encoder instead: " + "https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder" + ) backend = dispatcher(uri, format, backend) return backend.save( uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression diff --git a/src/torchaudio/_torchcodec.py b/src/torchaudio/_torchcodec.py index db62ec413a..ab82e1fb77 100644 --- a/src/torchaudio/_torchcodec.py +++ b/src/torchaudio/_torchcodec.py @@ -20,15 +20,16 @@ def load_with_torchcodec( .. note:: - This function supports the same API as ``torchaudio.load()``, and relies - on TorchCodec's decoding capabilities under the hood. It is provided for - convenience, but we do recommend that you port your code to natively use - ``torchcodec``'s ``AudioDecoder`` class for better performance: + This function supports the same API as :func:`~torchaudio.load`, and + relies on TorchCodec's decoding capabilities under the hood. It is + provided for convenience, but we do recommend that you port your code to + natively use ``torchcodec``'s ``AudioDecoder`` class for better + performance: https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder. - In TorchAudio 2.9, ``torchaudio.load()`` will be relying on - ``load_with_torchcodec``. Note that some parameters of - ``torchaudio.load()``, like ``normalize``, ``buffer_size``, and - ``backend``, are ignored by ``load_with_torchcodec``. + In TorchAudio 2.9, :func:`~torchaudio.load` will be relying on + :func:`~torchaudio.load_with_torchcodec`. Note that some parameters of + :func:`~torchaudio.load`, like ``normalize``, ``buffer_size``, and + ``backend``, are ignored by :func:`~torchaudio.load_with_torchcodec`. Args: @@ -158,4 +159,194 @@ def load_with_torchcodec( if not channels_first: data = data.transpose(0, 1) # [channel, time] -> [time, channel] - return data, sample_rate \ No newline at end of file + return data, sample_rate + + +def save_with_torchcodec( + uri: Union[str, os.PathLike], + src: torch.Tensor, + sample_rate: int, + channels_first: bool = True, + format: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, + buffer_size: int = 4096, + backend: Optional[str] = None, + compression: Optional[Union[float, int]] = None, +) -> None: + """Save audio data to file using TorchCodec's AudioEncoder. + + .. note:: + + This function supports the same API as :func:`~torchaudio.save`, and + relies on TorchCodec's encoding capabilities under the hood. It is + provided for convenience, but we do recommend that you port your code to + natively use ``torchcodec``'s ``AudioEncoder`` class for better + performance: + https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder. + In TorchAudio 2.9, :func:`~torchaudio.save` will be relying on + :func:`~torchaudio.save_with_torchcodec`. Note that some parameters of + :func:`~torchaudio.save`, like ``format``, ``encoding``, + ``bits_per_sample``, ``buffer_size``, and ``backend``, are ignored by + are ignored by :func:`~torchaudio.save_with_torchcodec`. + + This function provides a TorchCodec-based alternative to torchaudio.save + with the same API. TorchCodec's AudioEncoder provides efficient encoding + with FFmpeg under the hood. + + Args: + uri (path-like object): + Path to save the audio file. The file extension determines the format. + + src (torch.Tensor): + Audio data to save. Must be a 1D or 2D tensor with float32 values + in the range [-1, 1]. If 2D, shape should be [channel, time] when + channels_first=True, or [time, channel] when channels_first=False. + + sample_rate (int): + Sample rate of the audio data. + + channels_first (bool, optional): + Indicates whether the input tensor has channels as the first dimension. + If True, expects [channel, time]. If False, expects [time, channel]. + Default: True. + + format (str or None, optional): + Audio format hint. Not used by TorchCodec (format is determined by + file extension). A warning is issued if provided. + Default: None. + + encoding (str or None, optional): + Audio encoding. Not fully supported by TorchCodec AudioEncoder. + A warning is issued if provided. Default: None. + + bits_per_sample (int or None, optional): + Bits per sample. Not directly supported by TorchCodec AudioEncoder. + A warning is issued if provided. Default: None. + + buffer_size (int, optional): + Not used by TorchCodec AudioEncoder. Provided for API compatibility. + A warning is issued if not default value. Default: 4096. + + backend (str or None, optional): + Not used by TorchCodec AudioEncoder. Provided for API compatibility. + A warning is issued if provided. Default: None. + + compression (float, int or None, optional): + Compression level or bit rate. Maps to bit_rate parameter in + TorchCodec AudioEncoder. Default: None. + + Raises: + ImportError: If torchcodec is not available. + ValueError: If input parameters are invalid. + RuntimeError: If TorchCodec fails to encode the audio. + + Note: + - TorchCodec AudioEncoder expects float32 samples in [-1, 1] range. + - Some parameters (format, encoding, bits_per_sample, buffer_size, backend) + are not used by TorchCodec but are provided for API compatibility. + - The output format is determined by the file extension in the uri. + - TorchCodec uses FFmpeg under the hood for encoding. + """ + # Import torchcodec here to provide clear error if not available + try: + from torchcodec.encoders import AudioEncoder + except ImportError as e: + raise ImportError( + "TorchCodec is required for save_with_torchcodec. " + "Please install torchcodec to use this function." + ) from e + + # Parameter validation and warnings + if format is not None: + import warnings + warnings.warn( + "The 'format' parameter is not used by TorchCodec AudioEncoder. " + "Format is determined by the file extension.", + UserWarning, + stacklevel=2 + ) + + if encoding is not None: + import warnings + warnings.warn( + "The 'encoding' parameter is not fully supported by TorchCodec AudioEncoder.", + UserWarning, + stacklevel=2 + ) + + if bits_per_sample is not None: + import warnings + warnings.warn( + "The 'bits_per_sample' parameter is not directly supported by TorchCodec AudioEncoder.", + UserWarning, + stacklevel=2 + ) + + if buffer_size != 4096: + import warnings + warnings.warn( + "The 'buffer_size' parameter is not used by TorchCodec AudioEncoder.", + UserWarning, + stacklevel=2 + ) + + if backend is not None: + import warnings + warnings.warn( + "The 'backend' parameter is not used by TorchCodec AudioEncoder.", + UserWarning, + stacklevel=2 + ) + + # Input validation + if not isinstance(src, torch.Tensor): + raise ValueError(f"Expected src to be a torch.Tensor, got {type(src)}") + + if src.dtype != torch.float32: + src = src.float() + + if sample_rate <= 0: + raise ValueError(f"sample_rate must be positive, got {sample_rate}") + + # Handle tensor shape and channels_first + if src.ndim == 1: + # Convert to 2D: [1, time] for channels_first=True + if channels_first: + data = src.unsqueeze(0) # [1, time] + else: + # For channels_first=False, input is [time] -> reshape to [time, 1] -> transpose to [1, time] + data = src.unsqueeze(1).transpose(0, 1) # [time, 1] -> [1, time] + elif src.ndim == 2: + if channels_first: + data = src # Already [channel, time] + else: + data = src.transpose(0, 1) # [time, channel] -> [channel, time] + else: + raise ValueError(f"Expected 1D or 2D tensor, got {src.ndim}D tensor") + + # Create AudioEncoder + try: + encoder = AudioEncoder(data, sample_rate=sample_rate) + except Exception as e: + raise RuntimeError(f"Failed to create AudioEncoder: {e}") from e + + # Determine bit_rate from compression parameter + bit_rate = None + if compression is not None: + if isinstance(compression, (int, float)): + bit_rate = int(compression) + else: + import warnings + warnings.warn( + f"Unsupported compression type {type(compression)}. " + "TorchCodec AudioEncoder expects int or float for bit_rate.", + UserWarning, + stacklevel=2 + ) + + # Save to file + try: + encoder.to_file(uri, bit_rate=bit_rate) + except Exception as e: + raise RuntimeError(f"Failed to save audio to {uri}: {e}") from e diff --git a/test/torchaudio_unittest/test_load_save_torchcodec.py b/test/torchaudio_unittest/test_load_save_torchcodec.py new file mode 100644 index 0000000000..3edb4c423b --- /dev/null +++ b/test/torchaudio_unittest/test_load_save_torchcodec.py @@ -0,0 +1,449 @@ +from unittest.mock import patch +import re +import subprocess + + +import os +import tempfile +import pytest +import torch + +import torchaudio +from torchaudio import load_with_torchcodec, save_with_torchcodec +from torchaudio_unittest.common_utils import get_asset_path + +def get_ffmpeg_version(): + """Get FFmpeg version to check for compatibility issues.""" + try: + result = subprocess.run(['ffmpeg', '-version'], capture_output=True, text=True, timeout=5) + if result.returncode == 0: + # Extract version number from output like "ffmpeg version 4.4.2-0ubuntu0.22.04.1" + match = re.search(r'ffmpeg version (\d+)\.', result.stdout) + if match: + return int(match.group(1)) + except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError): + pass + return None + + +def is_ffmpeg4(): + """Check if FFmpeg version is 4, which has known compatibility issues.""" + version = get_ffmpeg_version() + return version == 4 + + +# Test with wav files that should work with both torchaudio and torchcodec +TEST_FILES = [ + "sinewave.wav", + "steam-train-whistle-daniel_simon.wav", + "vad-go-mono-32000.wav", + "vad-go-stereo-44100.wav", + "VCTK-Corpus/wav48/p224/p224_002.wav", +] + + +@pytest.mark.parametrize("filename", TEST_FILES) +def test_basic_load(filename): + """Test basic loading functionality against torchaudio.load.""" + # Skip problematic files on FFmpeg4 due to known compatibility issues + if is_ffmpeg4() and filename != "sinewave.wav": + pytest.skip("FFmpeg4 has known compatibility issues with some audio files") + + file_path = get_asset_path(*filename.split("/")) + + # Load with torchaudio + waveform_ta, sample_rate_ta = torchaudio.load(file_path) + + # Load with torchcodec + waveform_tc, sample_rate_tc = load_with_torchcodec(file_path) + + # Check sample rates match + assert sample_rate_ta == sample_rate_tc + + # Check shapes match + assert waveform_ta.shape == waveform_tc.shape + + # Check data types (should both be float32) + assert waveform_ta.dtype == torch.float32 + assert waveform_tc.dtype == torch.float32 + + # Check values are close (allowing for small differences in decoders) + torch.testing.assert_close(waveform_ta, waveform_tc) + +@pytest.mark.parametrize("frame_offset,num_frames", [ + (0, 1000), # First 1000 samples + (1000, 2000), # 2000 samples starting from 1000 + (5000, -1), # From 5000 to end + (0, -1), # Full file +]) +def test_frame_offset_and_num_frames(frame_offset, num_frames): + """Test frame_offset and num_frames parameters.""" + file_path = get_asset_path("sinewave.wav") + + # Load with torchaudio + waveform_ta, sample_rate_ta = torchaudio.load( + file_path, frame_offset=frame_offset, num_frames=num_frames + ) + + # Load with torchcodec + waveform_tc, sample_rate_tc = load_with_torchcodec( + file_path, frame_offset=frame_offset, num_frames=num_frames + ) + + # Check results match + assert sample_rate_ta == sample_rate_tc + assert waveform_ta.shape == waveform_tc.shape + torch.testing.assert_close(waveform_ta, waveform_tc) + +def test_channels_first(): + """Test channels_first parameter.""" + file_path = get_asset_path("sinewave.wav") # Use sinewave.wav for compatibility + + # Test channels_first=True (default) + waveform_cf_true, sample_rate = load_with_torchcodec(file_path, channels_first=True) + + # Test channels_first=False + waveform_cf_false, _ = load_with_torchcodec(file_path, channels_first=False) + + # Check that transpose relationship holds + assert waveform_cf_true.shape == waveform_cf_false.transpose(0, 1).shape + torch.testing.assert_close(waveform_cf_true, waveform_cf_false.transpose(0, 1)) + + # Compare with torchaudio + waveform_ta_true, _ = torchaudio.load(file_path, channels_first=True) + waveform_ta_false, _ = torchaudio.load(file_path, channels_first=False) + + assert waveform_cf_true.shape == waveform_ta_true.shape + assert waveform_cf_false.shape == waveform_ta_false.shape + torch.testing.assert_close(waveform_cf_true, waveform_ta_true) + torch.testing.assert_close(waveform_cf_false, waveform_ta_false) + +def test_normalize_parameter_warning(): + """Test that normalize=False produces a warning.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns(UserWarning, match="normalize=False.*ignored"): + # This should produce a warning + waveform, sample_rate = load_with_torchcodec(file_path, normalize=False) + + # Result should still be float32 (normalized) + assert waveform.dtype == torch.float32 + +def test_buffer_size_parameter_warning(): + """Test that non-default buffer_size produces a warning.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns(UserWarning, match="buffer_size.*not used"): + # This should produce a warning + waveform, sample_rate = load_with_torchcodec(file_path, buffer_size=8192) + + +def test_backend_parameter_warning(): + """Test that specifying backend produces a warning.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns(UserWarning, match="backend.*not used"): + # This should produce a warning + waveform, sample_rate = load_with_torchcodec(file_path, backend="ffmpeg") + + +def test_invalid_file(): + """Test that invalid files raise appropriate errors.""" + with pytest.raises(RuntimeError, match="Failed to create AudioDecoder"): + load_with_torchcodec("/nonexistent/file.wav") + + +def test_format_parameter(): + """Test that format parameter produces a warning.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns(UserWarning, match="format.*not supported"): + waveform, sample_rate = load_with_torchcodec(file_path, format="wav") + + # Check basic properties + assert waveform.dtype == torch.float32 + assert sample_rate > 0 + + +def test_multiple_warnings(): + """Test that multiple unsupported parameters produce multiple warnings.""" + file_path = get_asset_path("sinewave.wav") + + with pytest.warns() as warning_list: + # This should produce multiple warnings + waveform, sample_rate = load_with_torchcodec( + file_path, + normalize=False, + buffer_size=8192, + backend="ffmpeg" + ) + + + # Check that expected warnings are present + messages = [str(w.message) for w in warning_list] + assert any("normalize=False" in msg for msg in messages) + assert any("buffer_size" in msg for msg in messages) + assert any("backend" in msg for msg in messages) + + +# ===== SAVE WITH TORCHCODEC TESTS ===== + +@pytest.mark.parametrize("filename", TEST_FILES) +def test_save_basic_save(filename): + """Test basic saving functionality against torchaudio.save.""" + # Load a test file first + file_path = get_asset_path(*filename.split("/")) + waveform, sample_rate = torchaudio.load(file_path) + + with tempfile.TemporaryDirectory() as temp_dir: + # Save with torchaudio + ta_path = os.path.join(temp_dir, "ta_output.wav") + torchaudio.save(ta_path, waveform, sample_rate) + + # Save with torchcodec + tc_path = os.path.join(temp_dir, "tc_output.wav") + save_with_torchcodec(tc_path, waveform, sample_rate) + + # Load both back and compare + waveform_ta, sample_rate_ta = torchaudio.load(ta_path) + waveform_tc, sample_rate_tc = torchaudio.load(tc_path) + + # Check sample rates match + assert sample_rate_ta == sample_rate_tc + + # Check shapes match + assert waveform_ta.shape == waveform_tc.shape + + # Check data types (should both be float32) + assert waveform_ta.dtype == torch.float32 + assert waveform_tc.dtype == torch.float32 + + # Check values are close (allowing for small differences in encoders) + torch.testing.assert_close(waveform_ta, waveform_tc, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("channels_first", [True, False]) +def test_save_channels_first(channels_first): + """Test channels_first parameter.""" + # Create test data + if channels_first: + waveform = torch.randn(2, 16000) # [channel, time] + else: + waveform = torch.randn(16000, 2) # [time, channel] + + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + # Save with torchaudio + ta_path = os.path.join(temp_dir, "ta_output.wav") + torchaudio.save(ta_path, waveform, sample_rate, channels_first=channels_first) + + # Save with torchcodec + tc_path = os.path.join(temp_dir, "tc_output.wav") + save_with_torchcodec(tc_path, waveform, sample_rate, channels_first=channels_first) + + # Load both back and compare + waveform_ta, sample_rate_ta = torchaudio.load(ta_path) + waveform_tc, sample_rate_tc = torchaudio.load(tc_path) + + # Check results match + assert sample_rate_ta == sample_rate_tc + assert waveform_ta.shape == waveform_tc.shape + torch.testing.assert_close(waveform_ta, waveform_tc, atol=1e-3, rtol=1e-3) + + +def test_save_compression_parameter(): + """Test compression parameter (maps to bit_rate).""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + # Test with compression (bit_rate) + output_path = os.path.join(temp_dir, "output.wav") + save_with_torchcodec(output_path, waveform, sample_rate, compression=128000) + + # Should not raise an error and file should exist + assert os.path.exists(output_path) + + # Load back and check basic properties + waveform_loaded, sample_rate_loaded = torchaudio.load(output_path) + assert sample_rate_loaded == sample_rate + assert waveform_loaded.shape[0] == 1 # Should be mono + + +def test_save_format_parameter_warning(): + """Test that format parameter produces a warning.""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + output_path = os.path.join(temp_dir, "output.wav") + + with pytest.warns(UserWarning, match="format.*not used"): + save_with_torchcodec(output_path, waveform, sample_rate, format="wav") + + # Should still work despite warning + assert os.path.exists(output_path) + + +def test_save_encoding_parameter_warning(): + """Test that encoding parameter produces a warning.""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + output_path = os.path.join(temp_dir, "output.wav") + + with pytest.warns(UserWarning, match="encoding.*not fully supported"): + save_with_torchcodec(output_path, waveform, sample_rate, encoding="PCM_16") + + # Should still work despite warning + assert os.path.exists(output_path) + + +def test_save_bits_per_sample_parameter_warning(): + """Test that bits_per_sample parameter produces a warning.""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + output_path = os.path.join(temp_dir, "output.wav") + + with pytest.warns(UserWarning, match="bits_per_sample.*not directly supported"): + save_with_torchcodec(output_path, waveform, sample_rate, bits_per_sample=16) + + # Should still work despite warning + assert os.path.exists(output_path) + + +def test_save_buffer_size_parameter_warning(): + """Test that non-default buffer_size produces a warning.""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + output_path = os.path.join(temp_dir, "output.wav") + + with pytest.warns(UserWarning, match="buffer_size.*not used"): + save_with_torchcodec(output_path, waveform, sample_rate, buffer_size=8192) + + # Should still work despite warning + assert os.path.exists(output_path) + + +def test_save_backend_parameter_warning(): + """Test that specifying backend produces a warning.""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + output_path = os.path.join(temp_dir, "output.wav") + + with pytest.warns(UserWarning, match="backend.*not used"): + save_with_torchcodec(output_path, waveform, sample_rate, backend="ffmpeg") + + # Should still work despite warning + assert os.path.exists(output_path) + + +def test_save_edge_cases(): + """Test edge cases and error conditions.""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + output_path = os.path.join(temp_dir, "output.wav") + + # Test with very small waveform + small_waveform = torch.randn(1, 10) + save_with_torchcodec(output_path, small_waveform, sample_rate) + waveform_loaded, sample_rate_loaded = torchaudio.load(output_path) + assert sample_rate_loaded == sample_rate + + # Test with different sample rates + for sr in [8000, 22050, 44100]: + sr_path = os.path.join(temp_dir, f"output_{sr}.wav") + save_with_torchcodec(sr_path, waveform, sr) + waveform_loaded, sample_rate_loaded = torchaudio.load(sr_path) + assert sample_rate_loaded == sr + + +def test_save_invalid_inputs(): + """Test that invalid inputs raise appropriate errors.""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + output_path = os.path.join(temp_dir, "output.wav") + + # Test with invalid sample rate + with pytest.raises(ValueError, match="sample_rate must be positive"): + save_with_torchcodec(output_path, waveform, -1) + + # Test with invalid tensor dimensions + with pytest.raises(ValueError, match="Expected 1D or 2D tensor"): + invalid_waveform = torch.randn(1, 2, 16000) # 3D tensor + save_with_torchcodec(output_path, invalid_waveform, sample_rate) + + # Test with non-tensor input + with pytest.raises(ValueError, match="Expected src to be a torch.Tensor"): + save_with_torchcodec(output_path, [1, 2, 3], sample_rate) + + +def test_save_multiple_warnings(): + """Test that multiple unsupported parameters produce multiple warnings.""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + output_path = os.path.join(temp_dir, "output.wav") + + with pytest.warns() as warning_list: + save_with_torchcodec( + output_path, + waveform, + sample_rate, + format="wav", + encoding="PCM_16", + bits_per_sample=16, + buffer_size=8192, + backend="ffmpeg" + ) + + # Check that expected warnings are present + messages = [str(w.message) for w in warning_list] + assert any("format" in msg for msg in messages) + assert any("encoding" in msg for msg in messages) + assert any("bits_per_sample" in msg for msg in messages) + assert any("buffer_size" in msg for msg in messages) + assert any("backend" in msg for msg in messages) + + # Should still work despite warnings + assert os.path.exists(output_path) + + +def test_save_different_formats(): + """Test saving to different audio formats.""" + waveform = torch.randn(1, 16000) + sample_rate = 16000 + + with tempfile.TemporaryDirectory() as temp_dir: + # Test common formats + formats = ["wav", "mp3", "flac"] + + for fmt in formats: + output_path = os.path.join(temp_dir, f"output.{fmt}") + try: + save_with_torchcodec(output_path, waveform, sample_rate) + assert os.path.exists(output_path) + + # Try to load back (may not work for all formats with all backends) + try: + waveform_loaded, sample_rate_loaded = torchaudio.load(output_path) + assert sample_rate_loaded == sample_rate + except Exception: + # Some formats might not be supported by the loading backend + pass + except Exception as e: + # Some formats might not be supported by torchcodec + pytest.skip(f"Format {fmt} not supported: {e}") \ No newline at end of file diff --git a/test/torchaudio_unittest/test_load_torchcodec.py b/test/torchaudio_unittest/test_load_torchcodec.py deleted file mode 100644 index 62af890aca..0000000000 --- a/test/torchaudio_unittest/test_load_torchcodec.py +++ /dev/null @@ -1,193 +0,0 @@ -from unittest.mock import patch -import subprocess -import re - -import pytest -import torch - -import torchaudio -from torchaudio import load_with_torchcodec -from torchaudio_unittest.common_utils import get_asset_path, skipIfNoModule - - -def get_ffmpeg_version(): - """Get FFmpeg version to check for compatibility issues.""" - try: - result = subprocess.run(['ffmpeg', '-version'], capture_output=True, text=True, timeout=5) - if result.returncode == 0: - # Extract version number from output like "ffmpeg version 4.4.2-0ubuntu0.22.04.1" - match = re.search(r'ffmpeg version (\d+)\.', result.stdout) - if match: - return int(match.group(1)) - except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError): - pass - return None - - -def is_ffmpeg4(): - """Check if FFmpeg version is 4, which has known compatibility issues.""" - version = get_ffmpeg_version() - return version == 4 - - -# Test with wav files that should work with both torchaudio and torchcodec -TEST_FILES = [ - "sinewave.wav", - "steam-train-whistle-daniel_simon.wav", - "vad-go-mono-32000.wav", - "vad-go-stereo-44100.wav", - "VCTK-Corpus/wav48/p224/p224_002.wav", -] - - -@skipIfNoModule("torchcodec") -@pytest.mark.parametrize("filename", TEST_FILES) -def test_basic_load(filename): - """Test basic loading functionality against torchaudio.load.""" - # Skip problematic files on FFmpeg4 due to known compatibility issues - if is_ffmpeg4() and filename != "sinewave.wav": - pytest.skip("FFmpeg4 has known compatibility issues with some audio files") - - file_path = get_asset_path(*filename.split("/")) - - # Load with torchaudio - waveform_ta, sample_rate_ta = torchaudio.load(file_path) - - # Load with torchcodec - waveform_tc, sample_rate_tc = load_with_torchcodec(file_path) - - # Check sample rates match - assert sample_rate_ta == sample_rate_tc - - # Check shapes match - assert waveform_ta.shape == waveform_tc.shape - - # Check data types (should both be float32) - assert waveform_ta.dtype == torch.float32 - assert waveform_tc.dtype == torch.float32 - - # Check values are close (allowing for small differences in decoders) - torch.testing.assert_close(waveform_ta, waveform_tc) - -@skipIfNoModule("torchcodec") -@pytest.mark.parametrize("frame_offset,num_frames", [ - (0, 1000), # First 1000 samples - (1000, 2000), # 2000 samples starting from 1000 - (5000, -1), # From 5000 to end - (0, -1), # Full file -]) -def test_frame_offset_and_num_frames(frame_offset, num_frames): - """Test frame_offset and num_frames parameters.""" - file_path = get_asset_path("sinewave.wav") - - # Load with torchaudio - waveform_ta, sample_rate_ta = torchaudio.load( - file_path, frame_offset=frame_offset, num_frames=num_frames - ) - - # Load with torchcodec - waveform_tc, sample_rate_tc = load_with_torchcodec( - file_path, frame_offset=frame_offset, num_frames=num_frames - ) - - # Check results match - assert sample_rate_ta == sample_rate_tc - assert waveform_ta.shape == waveform_tc.shape - torch.testing.assert_close(waveform_ta, waveform_tc) - -@skipIfNoModule("torchcodec") -def test_channels_first(): - """Test channels_first parameter.""" - file_path = get_asset_path("sinewave.wav") # Use sinewave.wav for compatibility - - # Test channels_first=True (default) - waveform_cf_true, sample_rate = load_with_torchcodec(file_path, channels_first=True) - - # Test channels_first=False - waveform_cf_false, _ = load_with_torchcodec(file_path, channels_first=False) - - # Check that transpose relationship holds - assert waveform_cf_true.shape == waveform_cf_false.transpose(0, 1).shape - torch.testing.assert_close(waveform_cf_true, waveform_cf_false.transpose(0, 1)) - - # Compare with torchaudio - waveform_ta_true, _ = torchaudio.load(file_path, channels_first=True) - waveform_ta_false, _ = torchaudio.load(file_path, channels_first=False) - - assert waveform_cf_true.shape == waveform_ta_true.shape - assert waveform_cf_false.shape == waveform_ta_false.shape - torch.testing.assert_close(waveform_cf_true, waveform_ta_true) - torch.testing.assert_close(waveform_cf_false, waveform_ta_false) - -@skipIfNoModule("torchcodec") -def test_normalize_parameter_warning(): - """Test that normalize=False produces a warning.""" - file_path = get_asset_path("sinewave.wav") - - with pytest.warns(UserWarning, match="normalize=False.*ignored"): - # This should produce a warning - waveform, sample_rate = load_with_torchcodec(file_path, normalize=False) - - # Result should still be float32 (normalized) - assert waveform.dtype == torch.float32 - -@skipIfNoModule("torchcodec") -def test_buffer_size_parameter_warning(): - """Test that non-default buffer_size produces a warning.""" - file_path = get_asset_path("sinewave.wav") - - with pytest.warns(UserWarning, match="buffer_size.*not used"): - # This should produce a warning - waveform, sample_rate = load_with_torchcodec(file_path, buffer_size=8192) - - -@skipIfNoModule("torchcodec") -def test_backend_parameter_warning(): - """Test that specifying backend produces a warning.""" - file_path = get_asset_path("sinewave.wav") - - with pytest.warns(UserWarning, match="backend.*not used"): - # This should produce a warning - waveform, sample_rate = load_with_torchcodec(file_path, backend="ffmpeg") - - -@skipIfNoModule("torchcodec") -def test_invalid_file(): - """Test that invalid files raise appropriate errors.""" - with pytest.raises(RuntimeError, match="Failed to create AudioDecoder"): - load_with_torchcodec("/nonexistent/file.wav") - - -@skipIfNoModule("torchcodec") -def test_format_parameter(): - """Test that format parameter produces a warning.""" - file_path = get_asset_path("sinewave.wav") - - with pytest.warns(UserWarning, match="format.*not supported"): - waveform, sample_rate = load_with_torchcodec(file_path, format="wav") - - # Check basic properties - assert waveform.dtype == torch.float32 - assert sample_rate > 0 - - -@skipIfNoModule("torchcodec") -def test_multiple_warnings(): - """Test that multiple unsupported parameters produce multiple warnings.""" - file_path = get_asset_path("sinewave.wav") - - with pytest.warns() as warning_list: - # This should produce multiple warnings - waveform, sample_rate = load_with_torchcodec( - file_path, - normalize=False, - buffer_size=8192, - backend="ffmpeg" - ) - - - # Check that expected warnings are present - messages = [str(w.message) for w in warning_list] - assert any("normalize=False" in msg for msg in messages) - assert any("buffer_size" in msg for msg in messages) - assert any("backend" in msg for msg in messages) \ No newline at end of file From 54e5addbe211fec0c8582fad72367eafb3b845e5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 15 Jul 2025 18:11:09 +0100 Subject: [PATCH 3/3] Install torchcodec on nightly? --- .github/scripts/unittest-linux/install.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/scripts/unittest-linux/install.sh b/.github/scripts/unittest-linux/install.sh index 24dd7e3476..b2c3538eda 100755 --- a/.github/scripts/unittest-linux/install.sh +++ b/.github/scripts/unittest-linux/install.sh @@ -74,7 +74,8 @@ case $GPU_ARCH_TYPE in ;; esac PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${GPU_ARCH_ID}" -pip install --progress-bar=off --pre torch torchcodec --index-url="${PYTORCH_WHEEL_INDEX}" +pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}" +pip install --progress-bar=off --pre torchcodec --index-url="https://download.pytorch.org/whl/nightly/cpu" # 2. Install torchaudio