Skip to content

Commit

Permalink
added max_line_length parameter for .srt files
Browse files Browse the repository at this point in the history
  • Loading branch information
rBrenick committed Dec 29, 2022
1 parent 0b5dcfd commit 6f2e2aa
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
10 changes: 7 additions & 3 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ def cli():
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")


parser.add_argument("--max_line_length", type=optional_int, default=42, help="max amount of characters for a line in the subtitle files")

parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")

Expand Down Expand Up @@ -299,7 +301,9 @@ def cli():
threads = args.pop("threads")
if threads > 0:
torch.set_num_threads(threads)


output_max_line_length = args.pop("max_line_length")

from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)

Expand All @@ -318,7 +322,7 @@ def cli():

# save SRT
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
write_srt(result["segments"], file=srt, max_line_length=output_max_line_length)


if __name__ == '__main__':
Expand Down
43 changes: 41 additions & 2 deletions whisper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def write_vtt(transcript: Iterator[dict], file: TextIO):
)


def write_srt(transcript: Iterator[dict], file: TextIO):
def write_srt(transcript: Iterator[dict], file: TextIO, max_line_length: int = 42):
"""
Write a transcript to a file in SRT format.
Expand All @@ -76,13 +76,52 @@ def write_srt(transcript: Iterator[dict], file: TextIO):
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
write_srt(result["segments"], file=srt)
"""

comma_split_threshold = int(float(max_line_length) * 0.75)

for i, segment in enumerate(transcript, start=1):

# left the .replace() here to not change unnecessarily
# but I don't think it's needed?
segment_text = segment['text'].strip().replace('-->', '->')

if len(segment_text) > max_line_length:
segment_text = split_text_into_multiline(segment_text, max_line_length, comma_split_threshold)

# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['text'].strip().replace('-->', '->')}\n",
f"{segment_text}\n",
file=file,
flush=True,
)


def split_text_into_multiline(segment_text: str, max_line_length: int, comma_split_threshold: int) -> str:
    """
    Greedily wrap a subtitle segment onto several lines.

    Words (space-separated) are appended to the current line until adding the
    next one would make the line longer than `max_line_length`. A line is also
    ended early when it already ends with a comma and is longer than
    `comma_split_threshold` characters.

    Args:
        segment_text: the text to wrap.
        max_line_length: maximum number of characters per line. A single word
            longer than this is kept whole on its own line.
        comma_split_threshold: line length above which a trailing comma
            forces a line break.

    Returns:
        The wrapped text with lines joined by '\\n'. Empty or space-only
        input yields an empty string.
    """
    # Consecutive, leading or trailing spaces make split(' ') return empty
    # strings. Left in, such an empty "word" can become a line of its own,
    # i.e. a blank line in the result, and a blank line terminates a cue in
    # an .srt file. Drop them before wrapping.
    words = [word for word in segment_text.split(' ') if word]
    if not words:
        return ''

    lines = [words[0]]

    for word in words[1:]:
        current_line = lines[-1]

        # start a new line if the last word ended with a comma,
        # and we're mostly through this line
        if current_line.endswith(',') and len(current_line) > comma_split_threshold:
            lines.append(word)
            continue

        line_with_word = f'{current_line} {word}'

        if len(line_with_word) > max_line_length:
            lines.append(word)
        else:
            lines[-1] = line_with_word

    return '\n'.join(lines)

0 comments on commit 6f2e2aa

Please sign in to comment.