Skip to content

Commit

Permalink
hopefully the last bug fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
zackees committed May 8, 2023
1 parent e17b0c0 commit b4d5b33
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 53 deletions.
35 changes: 27 additions & 8 deletions tests/test_transcribe_anything_api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,28 +11,47 @@
from transcribe_anything.api import transcribe

HERE = os.path.abspath(os.path.dirname(__file__))
TESTS_DATA_DIR = os.path.join(HERE, "test_data", "en")
LOCALFILE_DIR = os.path.join(HERE, "localfile")
TESTS_DATA_DIR = os.path.join(LOCALFILE_DIR, "text_video", "en")


class TranscribeAnythingApiTester(unittest.TestCase):
"""Tester for transcribe anything."""

def test_local_file(self) -> None:
"""Check that the command works on a local file."""
expected_base_dir = os.path.join(LOCALFILE_DIR, "text_video", "en")
shutil.rmtree(expected_base_dir, ignore_errors=True)
shutil.rmtree(TESTS_DATA_DIR, ignore_errors=True)
vidfile = os.path.join(LOCALFILE_DIR, "video.mp4")
prev_dir = os.getcwd()
os.chdir(LOCALFILE_DIR)
transcribe(url_or_file=vidfile, language="en", model="tiny")
os.chdir(prev_dir)
expected_paths = [
expected_base_dir,
# os.path.join(TESTS_DATA_DIR, "out.mp3"),
os.path.join(expected_base_dir, "out.txt"),
os.path.join(expected_base_dir, "out.srt"),
os.path.join(expected_base_dir, "out.vtt"),
TESTS_DATA_DIR,
os.path.join(TESTS_DATA_DIR, "out.txt"),
os.path.join(TESTS_DATA_DIR, "out.srt"),
os.path.join(TESTS_DATA_DIR, "out.vtt"),
]
for expected_path in expected_paths:
self.assertTrue(
os.path.exists(expected_path),
f"expected path {expected_path} not found",
)

def test_fetch_command_installed(self) -> None:
"""Check that the command works on a live short video."""
shutil.rmtree(TESTS_DATA_DIR, ignore_errors=True)
transcribe(
url_or_file="https://www.youtube.com/watch?v=DWtpNPZ4tb4",
language="en",
model="tiny",
output_dir=TESTS_DATA_DIR,
)
expected_paths = [
TESTS_DATA_DIR,
os.path.join(TESTS_DATA_DIR, "out.txt"),
os.path.join(TESTS_DATA_DIR, "out.srt"),
os.path.join(TESTS_DATA_DIR, "out.vtt"),
]
for expected_path in expected_paths:
self.assertTrue(
Expand Down
90 changes: 45 additions & 45 deletions transcribe_anything/api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -41,7 +41,6 @@ def make_temp_wav() -> str:

def cleanup() -> None:
if os.path.exists(tmp.name):
print(f"make_temp_wav: Removing {tmp.name}")
os.remove(tmp.name)

atexit.register(cleanup)
Expand Down Expand Up @@ -106,56 +105,57 @@ def transcribe(
raise ValueError(f"Unknown device {device}")
print(f"Using device {device}")
model_str = f" --model {model}" if model else ""
output_dir_str = f' --output_dir "{output_dir}"' if output_dir else ""
task_str = f" --task {task}" if task else ""
language_str = f" --language {language}" if language else ""
cmd_list = []
if sys.platform == "win32":
# Set the text mode to UTF-8 on Windows.
cmd_list.extend(["chcp", "65001", "&&"])
cmd_list.extend(
[
"whisper",
f'"{tmp_wav}"',
"--device",
device,
model_str,
output_dir_str,
task_str,
language_str,
]
)
if other_args:
cmd_list.extend(other_args)
# Remove the empty strings.
cmd_list = [x.strip() for x in cmd_list if x.strip()]
cmd = " ".join(cmd_list)
sys.stderr.write(f"Running:\n {cmd}\n")
proc = subprocess.Popen( # pylint: disable=consider-using-with
cmd, shell=True, universal_newlines=True
)
while True:
rtn = proc.poll()
if rtn is None:
time.sleep(0.25)
continue
if rtn != 0:
msg = f"Failed to execute {cmd}\n "
raise OSError(msg)
break
files = [os.path.join(output_dir, name) for name in os.listdir(output_dir)]
for file in files:
# Change the filename to remove the double extension
file_name = os.path.basename(file)
base_path = os.path.dirname(file)
new_file = os.path.join(base_path, chop_double_extension(file_name))
_, ext = os.path.splitext(new_file)
outfile = os.path.join(base_path, f"out{ext}")
if os.path.exists(outfile):
os.remove(outfile)
assert os.path.isfile(file), f"Path {file} doesn't exist."
assert not os.path.exists(outfile), f"Path {outfile} already exists."
shutil.move(file, outfile)

with tempfile.TemporaryDirectory() as tmpdir:
cmd_list.extend(
[
"whisper",
f'"{tmp_wav}"',
"--device",
device,
model_str,
f'--output_dir "{tmpdir}"',
task_str,
language_str,
]
)
if other_args:
cmd_list.extend(other_args)
# Remove the empty strings.
cmd_list = [x.strip() for x in cmd_list if x.strip()]
cmd = " ".join(cmd_list)
sys.stderr.write(f"Running:\n {cmd}\n")
proc = subprocess.Popen( # pylint: disable=consider-using-with
cmd, shell=True, universal_newlines=True
)
while True:
rtn = proc.poll()
if rtn is None:
time.sleep(0.25)
continue
if rtn != 0:
msg = f"Failed to execute {cmd}\n "
raise OSError(msg)
break
files = [os.path.join(tmpdir, name) for name in os.listdir(tmpdir)]
for file in files:
# Change the filename to remove the double extension
file_name = os.path.basename(file)
base_path = os.path.dirname(file)
new_file = os.path.join(base_path, chop_double_extension(file_name))
_, ext = os.path.splitext(new_file)
outfile = os.path.join(output_dir, f"out{ext}")
if os.path.exists(outfile):
os.remove(outfile)
assert os.path.isfile(file), f"Path {file} doesn't exist."
assert not os.path.exists(outfile), f"Path {outfile} already exists."
shutil.move(file, outfile)
return output_dir


Expand Down

0 comments on commit b4d5b33

Please sign in to comment.