Skip to content

Commit

Permalink
Move file to /tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Turian committed Jun 16, 2024
1 parent ad578ee commit afa656d
Showing 1 changed file with 35 additions and 8 deletions.
43 changes: 35 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import os.path
import shutil
import tempfile
import uuid

import torchaudio

from cog import BasePredictor, Input, Path


Expand All @@ -25,23 +25,35 @@ def predict(
checkpoint: str = Input(
description="Model checkpoint to use. EARS-WHAM speech enhancement or EARS-Reverb dereverberation.",
choices=["EARS-WHAM", "EARS-Reverb"],
default="EARS-WHAM"
default="EARS-WHAM",
),
corrector: str = Input(
description="Corrector class for the PC sampler.",
choices=["ald", "langevin", "none"],
default="ald",
),
corrector_steps: int = Input(
description="Number of corrector steps", default=1
),
snr: float = Input(
description="SNR value for (annealed) Langevin dynamics.", default=0.5
),
N: int = Input(description="Number of reverse steps", default=30),
) -> Path:

# Make a temporary directory
with tempfile.TemporaryDirectory() as temp_dir:
try:
# Copy the audio to the temporary directory
audio_path = os.path.join(temp_dir, os.path.basename(audio) + ".wav")
audio_path = os.path.join(temp_dir, "testfile.wav")

# They don't resample for us :\
x, sr = torchaudio.load(audio)
if sr != 48000:
x = torchaudio.transforms.Resampler(sr, 48000)(x)
torchaudio.save(audio_path, x, sr)
#print(f"Copying {audio} to {audio_path}")
#shutil.copy(audio, audio_path)
# print(f"Copying {audio} to {audio_path}")
# shutil.copy(audio, audio_path)

enhanced_dir = os.path.join(temp_dir, "enhanced")
os.mkdir(enhanced_dir)
Expand All @@ -53,17 +65,32 @@ def predict(
else:
raise ValueError(f"Unknown checkpoint: {checkpoint}")

os.system(
f"cd /sgmse ; python3 enhancement.py --test_dir {temp_dir} --enhanced_dir {enhanced_dir} --ckpt {ckpt}"
command = (
f"cd /sgmse ; python3 enhancement.py "
f"--test_dir {temp_dir} "
f"--enhanced_dir {enhanced_dir} "
f"--ckpt {ckpt} "
f"--corrector {corrector} "
f"--corrector_steps {corrector_steps} "
f"--snr {snr} "
f"--N {N}"
)

# TODO: subprocess
os.system(command)
files = [
f
for f in os.listdir(enhanced_dir)
if os.path.isfile(os.path.join(enhanced_dir, f))
and f.endswith(".wav")
]
assert len(files) == 1
return Path(files[0])
outfile = os.path.join(enhanced_dir, files[0])

new_filename = os.path.join("/tmp", str(uuid.uuid4()) + ".wav")
shutil.move(outfile, new_filename)

return Path(new_filename)
except:
raise
finally:
Expand Down

0 comments on commit afa656d

Please sign in to comment.