|
24 | 24 | import importlib.metadata |
25 | 25 | import io |
26 | 26 | import os |
| 27 | +import tempfile |
27 | 28 | import time |
28 | 29 | import warnings |
29 | 30 | from pathlib import Path |
30 | | -from typing import Callable, Optional, Union, Mapping |
| 31 | +from typing import Callable, Mapping, Optional, Union |
31 | 32 |
|
32 | 33 | import requests |
33 | 34 | from requests import Response |
@@ -251,7 +252,7 @@ def _hash_md5(self, file: Union[str, Path]) -> str: |
251 | 252 |
|
252 | 253 | def upload( |
253 | 254 | self, |
254 | | - audio: str | Path | dict[str, str|Path], |
| 255 | + audio: str | Path | dict[str, str | Path], |
255 | 256 | media_url: Optional[str] = None, |
256 | 257 | callback: Optional[Callable] = None, |
257 | 258 | ) -> str: |
@@ -279,12 +280,39 @@ def upload( |
279 | 280 | or "media://{md5-hash-of-audio-file}" otherwise. |
280 | 281 | """ |
281 | 282 |
|
| 283 | + # whether to delete the audio file after upload. will only be set to True |
| 284 | + # when audio is provided as a waveform and saved in a temporary file. |
| 285 | + delete = False |
| 286 | + |
282 | 287 | if isinstance(audio, Mapping): |
283 | | - if "audio" not in audio: |
| 288 | + if "audio" in audio: |
| 289 | + audio = audio["audio"] |
| 290 | + |
| 291 | + elif "waveform" in audio: |
| 292 | + delete = True |
| 293 | + try: |
| 294 | + import scipy.io |
| 295 | + except ImportError: |
| 296 | + raise ImportError( |
| 297 | + "To process the waveform directly, you need to install `scipy`." |
| 298 | + ) |
| 299 | + |
| 300 | + sample_rate = audio["sample_rate"] |
| 301 | + waveform = audio["waveform"] |
| 302 | + # common pattern is to provide waveform as a torch tensor. |
| 303 | + # turn it into a numpy array before passing to scipy.io.wavfile. |
| 304 | + if hasattr(audio["waveform"], "numpy"): |
| 305 | + waveform = audio["waveform"].numpy(force=True) |
| 306 | + |
| 307 | + # write waveform to a temporary audio file |
| 308 | + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: |
| 309 | + scipy.io.wavfile.write(f.name, sample_rate, waveform.squeeze()) |
| 310 | + f.flush() |
| 311 | + audio = f.name |
| 312 | + else: |
284 | 313 | raise ValueError( |
285 | 314 | "When `audio` is a dict, it must provide the path to the audio file in 'audio' key." |
286 | 315 | ) |
287 | | - audio = audio["audio"] |
288 | 316 |
|
289 | 317 | # get the total size of the file to upload |
290 | 318 | # to provide progress information to the hook |
@@ -318,6 +346,9 @@ def upload( |
318 | 346 | Failed to upload audio to presigned URL {presigned_url}. |
319 | 347 | Please check your internet connection or visit https://pyannote.openstatus.dev/ to check the status of the pyannoteAI API.""" |
320 | 348 | ) |
| 349 | + finally: |
| 350 | + if delete and os.path.exists(audio): |
| 351 | + os.remove(audio) |
321 | 352 |
|
322 | 353 | # TODO: handle HTTPError returned by the API |
323 | 354 | response.raise_for_status() |
|
0 commit comments