# Whisper with Hugging Face `transformers`

In [None]:
#@markdown **GPU check** (free tier T4, paid tier V100 or A100, if error then u must enable gpu session)
!nvidia-smi -L

In [None]:
#@markdown **Setup Whisper**
%pip install -q flash-attn deepl srt

# Sentence/clause punctuation (CJK full-width and ASCII), one character per entry.
# NOTE(review): not referenced elsewhere in the visible notebook.
PUNCT_MATCH = list("。、,.〜！!？?-")
# str.translate() table that deletes ASCII, typographic, full-width and CJK corner quotes.
REMOVE_QUOTES = {ord(quote_char): None for quote_char in '"„“‟”＂「」'}

# Lower-case file extensions (with the leading dot) treated as audio/video input
# when audio_path points at a folder. Kept in alphabetical order.
MEDIA_EXT = ["." + ext for ext in """
	3gp aac ac3 aif aifc aiff amr ape asf asx au avi
	bdmv bwf caf dat dts dtshd eac3 eb3 ec3 f4v flac fli flv
	l16 m1a m2a m2ts m4a m4v mid mka mkv mlp mod mov
	mp1 mp2 mp3 mp4 mpa mpc mpeg mpg ofr oga ogg opus pcm
	qt ra ram rm snd spx stm tak thd ts tta
	vob voc vqf wav wave webm wma wmv wv
""".split()]

# Filler words: a subtitle line made up only of these (compared after `clean_text`)
# is dropped from the output.
GARBAGE_LIST = """
	a aa ah ahh h
	ha haa hah haha hahaha
	hm hmm huh
	m mh mm mmh mmm
	o oh
""".split()

# Words that only count as filler when the previous subtitle line was filler too
# (used by the garbage-line filter of the transcription function).
# They are written in the form `clean_text` produces: the phrases are glued into one
# token ("thank you" -> "thankyou") and doubled a/m/h/o letters are collapsed, which
# is why "good" appears as "god" here. These are not typos.
NEED_CONTEXT_LINES = ["feelsgod", "godbye", "godnight", "thankyou"]

import re

# Deletes . , : ; ! ? and turns "-" into a space.
_CLEAN_PUNCT_TABLE = str.maketrans("-", " ", ".,:;!?")
# Two or more consecutive spaces.
_SPACE_RUN = re.compile(r" {2,}")
# A run of the same filler letter (a, m, h or o) of any length.
_FILLER_LETTER_RUN = re.compile(r"([amho])\1+")
# Applied in this order: "that/it feels" first so "feels good" can then be glued.
_GLUED_PHRASES = (
	("that feels", "feels"),
	("it feels", "feels"),
	("feels good", "feelsgood"),
	("good bye", "goodbye"),
	("good night", "goodnight"),
	("thank you", "thankyou"),
)


def clean_text(text: str) -> str:
	"""Normalise a subtitle line so filler / sign-off detection can compare whole words.

	Steps, in order:
	  1. drop . , : ; ! ? and turn "-" into a space, then squeeze repeated spaces;
	  2. lower-case;
	  3. fold "that feels"/"it feels" into "feels" and glue "feels good",
	     "good bye", "good night", "thank you" into single tokens;
	  4. collapse any run of a repeated a/m/h/o into a single letter ("hmmm" -> "hm").
	     This also turns "good" into "god" (so "goodbye" -> "godbye").

	Args:
		text: raw subtitle text.

	Returns:
		The normalised text; words are separated by single spaces.
	"""
	# Regex substitutions replace the former fixed number of str.replace passes,
	# which left very long runs (41+ identical letters, 8+ spaces) only partly collapsed.
	text = _SPACE_RUN.sub(" ", text.translate(_CLEAN_PUNCT_TABLE)).lower()
	for phrase, token in _GLUED_PHRASES:
		text = text.replace(phrase, token)
	return _FILLER_LETTER_RUN.sub(r"\1", text)

# Lower-case language name (including aliases) -> Whisper language code passed to
# generate(). From https://github.com/openai/whisper/blob/main/whisper/tokenizer.py;
# "mandarin" is one of the upstream aliases and was missing here.
# NOTE(review): upstream also maps "cantonese" -> "yue", which only large-v3 knows;
# it is intentionally not added because the default model here is large-v2.
TO_LANGUAGE_CODE = {
	"afrikaans": "af",
	"albanian": "sq",
	"amharic": "am",
	"arabic": "ar",
	"armenian": "hy",
	"assamese": "as",
	"azerbaijani": "az",
	"bashkir": "ba",
	"basque": "eu",
	"belarusian": "be",
	"bengali": "bn",
	"bosnian": "bs",
	"breton": "br",
	"bulgarian": "bg",
	"burmese": "my",
	"castilian": "es",
	"catalan": "ca",
	"chinese": "zh",
	"croatian": "hr",
	"czech": "cs",
	"danish": "da",
	"dutch": "nl",
	"english": "en",
	"estonian": "et",
	"faroese": "fo",
	"finnish": "fi",
	"flemish": "nl",
	"french": "fr",
	"galician": "gl",
	"georgian": "ka",
	"german": "de",
	"greek": "el",
	"gujarati": "gu",
	"haitian creole": "ht",
	"haitian": "ht",
	"hausa": "ha",
	"hawaiian": "haw",
	"hebrew": "he",
	"hindi": "hi",
	"hungarian": "hu",
	"icelandic": "is",
	"indonesian": "id",
	"italian": "it",
	"japanese": "ja",
	"javanese": "jw",
	"kannada": "kn",
	"kazakh": "kk",
	"khmer": "km",
	"korean": "ko",
	"lao": "lo",
	"latin": "la",
	"latvian": "lv",
	"letzeburgesch": "lb",
	"lingala": "ln",
	"lithuanian": "lt",
	"luxembourgish": "lb",
	"macedonian": "mk",
	"malagasy": "mg",
	"malay": "ms",
	"malayalam": "ml",
	"maltese": "mt",
	"mandarin": "zh",
	"maori": "mi",
	"marathi": "mr",
	"moldavian": "ro",
	"moldovan": "ro",
	"mongolian": "mn",
	"myanmar": "my",
	"nepali": "ne",
	"norwegian": "no",
	"nynorsk": "nn",
	"occitan": "oc",
	"panjabi": "pa",
	"pashto": "ps",
	"persian": "fa",
	"polish": "pl",
	"portuguese": "pt",
	"punjabi": "pa",
	"pushto": "ps",
	"romanian": "ro",
	"russian": "ru",
	"sanskrit": "sa",
	"serbian": "sr",
	"shona": "sn",
	"sindhi": "sd",
	"sinhala": "si",
	"sinhalese": "si",
	"slovak": "sk",
	"slovenian": "sl",
	"somali": "so",
	"spanish": "es",
	"sundanese": "su",
	"swahili": "sw",
	"swedish": "sv",
	"tagalog": "tl",
	"tajik": "tg",
	"tamil": "ta",
	"tatar": "tt",
	"telugu": "te",
	"thai": "th",
	"tibetan": "bo",
	"turkish": "tr",
	"turkmen": "tk",
	"ukrainian": "uk",
	"urdu": "ur",
	"uzbek": "uz",
	"valencian": "ca",
	"vietnamese": "vi",
	"welsh": "cy",
	"yiddish": "yi",
	"yoruba": "yo",
}

In [None]:
# Workaround for a Hugging Face error seen on Colab: make huggingface_hub believe it is
# not running on Colab.
# NOTE(review): presumably this skips huggingface_hub's Colab-specific code path
# (e.g. looking up HF_TOKEN in Colab secrets) — confirm against the library version in use.
import huggingface_hub.utils._runtime as _runtime
setattr(_runtime, "_is_google_colab", False)

import os, srt, json, torch
from datetime import timedelta
from transformers import pipeline
from google.colab import files as g_files, drive as g_drive

#@markdown **load Whisper**
model_size = "large-v2"  # @param ["tiny", "base", "small", "medium", "large-v1", "large-v2", "large-v3"]

import importlib.util

SILERO_SAMPLING_RATE = 16000  # Silero VAD operating value (also the rate Whisper expects)
SILERO_MODEL, (silero_get_speech_timestamps, silero_save_audio, silero_read_audio, _, silero_collect_chunks) = torch.hub.load(
	repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=False
)

# `attn_implementation` is a model-loading option, so it belongs in `model_kwargs`
# (forwarded to from_pretrained), not in `generate_kwargs` (forwarded to generate()).
# FlashAttention 2 also needs an Ampere-or-newer GPU (compute capability >= 8, e.g. A100;
# not the T4/V100 mentioned above) plus the `flash_attn` package. Only request it when
# both are present; otherwise keep transformers' default attention.
USE_FLASH_ATTENTION_2 = (
	torch.cuda.is_available()
	and torch.cuda.get_device_capability(0)[0] >= 8
	and importlib.util.find_spec("flash_attn") is not None
)

WHISPER_MODEL = pipeline(
	task="automatic-speech-recognition",
	model="openai/whisper-" + model_size,
	device="cuda:0",
	torch_dtype=torch.float16,
	model_kwargs={"attn_implementation": "flash_attention_2"} if USE_FLASH_ATTENTION_2 else {},
	chunk_length_s=30,  # split long audio into 30 s windows; if not set, output is capped at `max_new_tokens`
	batch_size=24,
	generate_kwargs={"num_beams": 5},
)

def _vad_to_source_seconds(t: float, speech_timestamps: list, sampling_rate: int) -> float:
	"""Map a time on the VAD-concatenated audio back onto the original file's timeline.

	`silero_collect_chunks` glues the speech segments together, so timestamps Whisper
	reports are relative to that shorter signal. Find the segment containing `t` and
	re-add the non-speech audio that was cut before it.

	Args:
		t: seconds on the concatenated (speech-only) audio.
		speech_timestamps: segments with "start"/"end" sample indices, as passed to
			`silero_collect_chunks`.
		sampling_rate: sample rate those indices refer to.

	Returns:
		The corresponding position, in seconds, in the original audio.
	"""
	if not speech_timestamps:
		return t
	target = t * sampling_rate  # position in samples on the concatenated audio
	consumed = 0  # speech samples already walked past
	for seg in speech_timestamps:
		seg_len = seg["end"] - seg["start"]
		if target <= consumed + seg_len:
			return (seg["start"] + target - consumed) / sampling_rate
		consumed += seg_len
	# Past the end of the kept speech: extend linearly from the last segment.
	return (speech_timestamps[-1]["end"] + target - consumed) / sampling_rate


def _is_garbage_line(content: str, last_line_garbage: bool) -> bool:
	"""Return True when a subtitle line contains nothing but filler.

	After `clean_text`, every word must be empty, in GARBAGE_LIST, or in
	NEED_CONTEXT_LINES while the previous line was itself garbage (so a lone
	"Thank you." directly after a filler-only line is dropped too).
	"""
	for word in clean_text(content).split(" "):
		word = word.strip()
		if not word or word in GARBAGE_LIST:
			continue
		if last_line_garbage and word in NEED_CONTEXT_LINES:
			continue
		return False
	return True


def my_transcribe_func(audio_path: str, task: str, language: str) -> None:
	"""Transcribe (or translate) one audio file with Whisper and write an .srt next to it.

	Silero VAD keeps only the speech parts, Whisper runs on them, and the returned
	timestamps are mapped back onto the original audio's timeline. Filler-only lines
	are dropped. `<basename>.srt` and `<basename>.debug.json` (raw pipeline chunks,
	with timestamps relative to the speech-only audio) are written, and the SRT is
	offered as a Colab download.

	Args:
		audio_path: path of the audio/video file to read with Silero's `read_audio`.
		task: Whisper task, "transcribe" or "translate".
		language: lower-case language name; must be a key of TO_LANGUAGE_CODE.
	"""
	audiofilebasename = os.path.splitext(audio_path)[0]
	out_path = audiofilebasename + ".srt"

	# SileroVAD operates on mono channel at 16 kHz
	silero_wav = silero_read_audio(audio_path, sampling_rate=SILERO_SAMPLING_RATE)
	speech_timestamps = silero_get_speech_timestamps(silero_wav, SILERO_MODEL, sampling_rate=SILERO_SAMPLING_RATE)
	if not speech_timestamps:
		# collect_chunks cannot concatenate an empty list of segments.
		print("No speech detected by VAD, transcribing the whole file")
		speech_timestamps = [{"start": 0, "end": len(silero_wav)}]
	audio_wav = silero_collect_chunks(speech_timestamps, silero_wav).numpy()
	speech_duration = len(audio_wav) / SILERO_SAMPLING_RATE

	print("Running Whisper … PLEASE WAIT")
	# NOTE(review): depending on the transformers version, call-time generate_kwargs may
	# replace (rather than merge with) those given to pipeline(), dropping num_beams — confirm.
	result = WHISPER_MODEL(
		audio_wav,
		return_timestamps=True,
		generate_kwargs={
			"language": TO_LANGUAGE_CODE[language],
			"task": task,
		}
	)

	subs = []
	prev_end_s = 0.0
	for i, chunk in enumerate(result["chunks"], start=1):
		start_s, end_s = chunk["timestamp"]
		# The pipeline can leave a timestamp as None (typically the final end time).
		if start_s is None:
			start_s = prev_end_s
		if end_s is None:
			end_s = max(start_s, speech_duration)
		prev_end_s = end_s
		subs.append(srt.Subtitle(
			index=i,
			start=timedelta(seconds=_vad_to_source_seconds(start_s, speech_timestamps, SILERO_SAMPLING_RATE)),
			end=timedelta(seconds=_vad_to_source_seconds(end_s, speech_timestamps, SILERO_SAMPLING_RATE)),
			content=chunk["text"].strip(),
		))

	# for debugging only
	with open(audiofilebasename + ".debug.json", mode="w", encoding="utf8") as f:
		json.dump(list(result["chunks"]), f, indent="\t")

	# Removal of garbage lines
	clean_subs = []
	last_line_garbage = False
	for sub in subs:
		is_garbage = _is_garbage_line(sub.content, last_line_garbage)
		if not is_garbage:
			clean_subs.append(sub)
		last_line_garbage = is_garbage

	with open(out_path, mode="w", encoding="utf8") as f:
		f.write(srt.compose(clean_subs))
	print("\nDone! Subs written to", out_path)
	print("Downloading SRT file:")
	g_files.download(out_path)

In [None]:
#@markdown **Mount Google Drive** (skip this if your audio file isn't stored there)
# Absolute mount point (the Colab convention), so the location does not depend on the
# current working directory; Drive files then appear under /content/drive/MyDrive/.
g_drive.mount("/content/drive")

In [None]:
#@markdown **Upload audio file to Colab** (optional) <br>
#@markdown If this step fails, or is very slow, try one of these options:
#@markdown  - Drag your file into the Files sidebar, and set audio_path to the filename
#@markdown  - OR upload it to Google Drive, mount it, and set audio_path to the absolute path

gfiles = g_files.upload()  # maps each uploaded filename to its contents
uploaded_names = list(gfiles)
if uploaded_names:
	# The run cell falls back to this file when audio_path does not exist.
	uploaded_file = uploaded_names[0]
	print("Upload complete")
	if len(uploaded_names) > 1:
		print("Several files uploaded; only", uploaded_file, "is used as the fallback audio_path")

In [None]:
#@markdown **Run Whisper**

audio_path = "wu16rnkF8II.opus"  # @param {type:"string"}
#@markdown >*path can be a file or a folder containing audio files*
language = "english"  # @param {type:"string"}
translation_mode = "transcription only"  # @param ["transcription only", "transcription + translation"]

# some sanity checks
assert audio_path != ""
assert language != ""
language = language.strip().lower()
assert language in TO_LANGUAGE_CODE, "invalid language"

# Whisper's "translate" task produces English output.
if translation_mode == "transcription + translation":
	task = "translate"
elif translation_mode == "transcription only":
	task = "transcribe"
else:
	raise ValueError("Invalid translation mode")

if not os.path.exists(audio_path):
	# Fall back to the file chosen in the upload cell, if that cell ran successfully.
	# Only NameError is caught, so the "Is your audio_path correct?" error below is no
	# longer swallowed and replaced by a misleading message.
	try:
		audio_path = uploaded_file
	except NameError:
		raise ValueError("Input audio not found. Did you upload a file?") from None
	if not os.path.exists(audio_path):
		raise ValueError("Input audio not found. Is your audio_path correct?")

if os.path.isfile(audio_path):
	print(audio_path)
	my_transcribe_func(audio_path, task, language)
elif os.path.isdir(audio_path):
	for root, dirs, files in os.walk(audio_path):
		dirs.sort()  # in-place sort makes os.walk visit sub-folders in a stable order
		for filename in sorted(files):
			if os.path.splitext(filename)[1].lower() not in MEDIA_EXT:
				continue
			filepath = os.path.join(root, filename)
			print(filepath)
			my_transcribe_func(filepath, task, language)
else:
	raise ValueError("cannot open audio path")