Skip to content

Commit

Permalink
adds conversion function for webvtt files
Browse files Browse the repository at this point in the history
  • Loading branch information
zackees committed May 10, 2023
1 parent 6984df2 commit 3a08ea0
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 3 deletions.
6 changes: 6 additions & 0 deletions src/video_subtitles/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Don't say anything.",
)
parser.add_argument(
"--webvtt",
action="store_true",
help="Output WebVTT format.",
)
parser.add_argument("--api-key", default=None, help="Transcribe Anything API key.")
args = parser.parse_args()
if not args.languages:
Expand Down Expand Up @@ -120,6 +125,7 @@ def main() -> int:
deepl_api_key=api_key,
out_languages=args.languages,
model=args.model,
convert_to_webvtt=args.webvtt,
)
if not args.quite:
say(f"Finished generating srt files for {file}")
Expand Down
1 change: 1 addition & 0 deletions src/video_subtitles/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _generate_subtitles(
deepl_api_key=deeply_api_key,
out_languages=languages,
model=model,
convert_to_webvtt=True,
)
except Exception as e: # pylint: disable=broad-except
print(e)
Expand Down
10 changes: 10 additions & 0 deletions src/video_subtitles/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import shutil

from video_subtitles.convert_to_webvtt import convert_to_webvtt as convert_webvtt
from video_subtitles.translate import srt_wrap, translate

IS_GITHUB = os.environ.get("GITHUB_ACTIONS", False)
Expand All @@ -25,6 +26,7 @@ def run( # pylint: disable=too-many-locals,too-many-branches,too-many-statement
deepl_api_key: str | None,
out_languages: list[str],
model: str,
convert_to_webvtt: bool,
) -> str:
"""Run the program."""
from transcribe_anything.api import ( # pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -87,5 +89,13 @@ def run( # pylint: disable=too-many-locals,too-many-branches,too-many-statement
os.remove(out_file)
shutil.move(srt_file, out_file)
shutil.rmtree(os.path.dirname(srt_file))
if convert_to_webvtt:
srt_files = find_srt_files(outdir)
for srt_file in srt_files:
webvtt_file = os.path.splitext(srt_file)[0] + ".vtt"
if os.path.exists(webvtt_file):
os.remove(webvtt_file)
convert_webvtt(srt_file, webvtt_file)
os.remove(srt_file)
print("Done translating")
return outdir
1 change: 1 addition & 0 deletions tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_imports(self) -> None:
deepl_api_key=None,
out_languages=["es", "fr", "zh"],
model="small",
convert_to_webvtt=False,
)
self.assertTrue(os.path.exists("text_video"))
self.assertTrue(os.path.exists(os.path.join("text_video", "en.srt")))
Expand Down
5 changes: 2 additions & 3 deletions tests/test_srt_to_webvtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import tempfile
import unittest

import webvtt
from video_subtitles.convert_to_webvtt import convert_to_webvtt

HERE = os.path.dirname(os.path.abspath(__file__))


TEST_SRT = os.path.join(HERE, "test.srt")


Expand All @@ -21,7 +20,7 @@ def test_translate(self) -> None:
# translator = DeeplTranslator() # free version
with tempfile.TemporaryDirectory() as tmpdirname:
out_file = os.path.join(tmpdirname, "out.vtt")
webvtt.from_srt(TEST_SRT).save(out_file)
convert_to_webvtt(TEST_SRT, out_file)


if __name__ == "__main__":
Expand Down

0 comments on commit 3a08ea0

Please sign in to comment.