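# Faster Whisper

Generate an `.srt` subtitle file for a YouTube video (or any site supported by `yt-dlp`) with faster-whisper, optionally isolating the vocals with Demucs before transcription.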

In [None]:
#@markdown **GPU check** (you typically want a V100, P100 or T4)
# The model-loading cell puts both Demucs and Whisper on "cuda", so a GPU runtime is required.
# First list the attached GPU(s), then print the full nvidia-smi status table.
!nvidia-smi -L
!nvidia-smi

In [None]:
#@markdown **Setup Whisper**
# Installs the downloader (yt-dlp), the CTranslate2-based Whisper runtime and Demucs.
# NOTE(review): versions are unpinned, so a future release may change the APIs used
# later (e.g. demucs.separate.load_track, WhisperModel.transcribe vad_parameters).
# TODO: pin known-good versions once they are confirmed.
%pip install -q yt-dlp faster-whisper demucs

In [None]:
#@markdown **load Whisper**
model_size = "large-v2"  # @param ["tiny","base","small","medium", "large-v1", "large-v2"]

# NOTE(review): import order is kept as-is on purpose. torch and faster_whisper
# (CTranslate2) both load CUDA libraries, so reordering these lines is not a safe
# cosmetic change. TODO confirm whether the order actually matters on Colab.
import torch, torchaudio, faster_whisper, math
from yt_dlp import YoutubeDL
from tqdm import tqdm
from demucs.pretrained import get_model as demucs_get_model
from demucs.separate import load_track as demucs_load_track
from demucs.apply import apply_model as demucs_apply_model

# Both models are loaded once here and reused by the "Run Whisper" cell.
# The Demucs "htdemucs" model (used there to separate vocals) is moved to the GPU
# unconditionally, so it occupies VRAM even when `vocals_extraction` is disabled.
DEMUCS_MODEL = demucs_get_model("htdemucs").cuda()
# `model_size` selects the Whisper checkpoint; inference runs on the GPU.
WHISPER_MODEL = faster_whisper.WhisperModel(model_size, device="cuda")

def convert_to_hms(seconds: float) -> str:
	"""Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

	The offset is rounded to the nearest whole millisecond *before* it is split
	into fields. Working on the fractional part directly loses a millisecond to
	float noise (``2.3 % 1 == 0.2999...`` would give ``,299``). Integer division
	also makes carries propagate (59.9996 s -> ``00:01:00,000``).

	:param seconds: non-negative time offset in seconds.
	:return: timestamp string such as ``01:02:03,456``.
	"""
	# int(...) guards against numpy scalars, whose round() may return a float.
	total_ms = int(round(seconds * 1000))
	hours, rem_ms = divmod(total_ms, 3_600_000)
	minutes, rem_ms = divmod(rem_ms, 60_000)
	secs, millis = divmod(rem_ms, 1000)
	return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"

def convert_seg(segment: faster_whisper.transcribe.Segment) -> str:
	"""Render one Whisper segment as the body of an SRT cue.

	Produces ``"<start> --> <end>\\n<text>"`` with surrounding whitespace removed
	from the text. The cue index is not included; the caller prepends it.
	"""
	timing = " --> ".join(convert_to_hms(t) for t in (segment.start, segment.end))
	return timing + "\n" + segment.text.strip()

# tqdm bar_format for the transcription progress bar. `n`/`total` are seconds of
# audio: the bar's total is the audio duration and it is advanced by segment end times.
TQDM_FORMAT = "{desc}: {percentage:5.1f}% |{bar}| {n:.2f}/{total:.2f} audio seconds [{elapsed}<{remaining}, {rate_fmt}]"

# Maps lower-case language names (and common aliases) to Whisper language codes.
# Keys are what the user may type in the `language` form field of the Run cell.
TO_LANGUAGE_CODE = { # from https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
	"afrikaans": "af",
	"albanian": "sq",
	"amharic": "am",
	"arabic": "ar",
	"armenian": "hy",
	"assamese": "as",
	"azerbaijani": "az",
	"bashkir": "ba",
	"basque": "eu",
	"belarusian": "be",
	"bengali": "bn",
	"bosnian": "bs",
	"breton": "br",
	"bulgarian": "bg",
	"burmese": "my",
	"castilian": "es",
	"catalan": "ca",
	"chinese": "zh",
	"croatian": "hr",
	"czech": "cs",
	"danish": "da",
	"dutch": "nl",
	"english": "en",
	"estonian": "et",
	"faroese": "fo",
	"finnish": "fi",
	"flemish": "nl",
	"french": "fr",
	"galician": "gl",
	"georgian": "ka",
	"german": "de",
	"greek": "el",
	"gujarati": "gu",
	"haitian creole": "ht",
	"haitian": "ht",
	"hausa": "ha",
	"hawaiian": "haw",
	"hebrew": "he",
	"hindi": "hi",
	"hungarian": "hu",
	"icelandic": "is",
	"indonesian": "id",
	"italian": "it",
	"japanese": "ja",
	"javanese": "jw",
	"kannada": "kn",
	"kazakh": "kk",
	"khmer": "km",
	"korean": "ko",
	"lao": "lo",
	"latin": "la",
	"latvian": "lv",
	"letzeburgesch": "lb",
	"lingala": "ln",
	"lithuanian": "lt",
	"luxembourgish": "lb",
	"macedonian": "mk",
	"malagasy": "mg",
	"malay": "ms",
	"malayalam": "ml",
	"maltese": "mt",
	"mandarin": "zh",  # alias for "chinese", as in upstream whisper's table
	"maori": "mi",
	"marathi": "mr",
	"moldavian": "ro",
	"moldovan": "ro",
	"mongolian": "mn",
	"myanmar": "my",
	"nepali": "ne",
	"norwegian": "no",
	"nynorsk": "nn",
	"occitan": "oc",
	"panjabi": "pa",
	"pashto": "ps",
	"persian": "fa",
	"polish": "pl",
	"portuguese": "pt",
	"punjabi": "pa",
	"pushto": "ps",
	"romanian": "ro",
	"russian": "ru",
	"sanskrit": "sa",
	"serbian": "sr",
	"shona": "sn",
	"sindhi": "sd",
	"sinhala": "si",
	"sinhalese": "si",
	"slovak": "sk",
	"slovenian": "sl",
	"somali": "so",
	"spanish": "es",
	"sundanese": "su",
	"swahili": "sw",
	"swedish": "sv",
	"tagalog": "tl",
	"tajik": "tg",
	"tamil": "ta",
	"tatar": "tt",
	"telugu": "te",
	"thai": "th",
	"tibetan": "bo",
	"turkish": "tr",
	"turkmen": "tk",
	"ukrainian": "uk",
	"urdu": "ur",
	"uzbek": "uz",
	"valencian": "ca",
	"vietnamese": "vi",
	"welsh": "cy",
	"yiddish": "yi",
	"yoruba": "yo",
}

In [None]:
#@markdown **Run Whisper**

video_links = "QblE6s_bYLY"  # @param {type:"string"}
#@markdown >*accepts links from other streaming sites supported by `yt-dlp`; a bare YouTube video ID also works*
language = "english"  # @param {type:"string"}
#@markdown >*language name (e.g. `english`) or its Whisper code (e.g. `en`)*
translation_mode = "transcription only"  # @param ["transcription only", "transcription + translation"]
#@markdown SileroVAD settings:
vad_threshold = 0.4  # @param {type:"number"}
chunk_duration = 15.0  # @param {type:"number"}
#@markdown vocals extraction (Demucs): for audio longer than ~1h it needs a lot of RAM and may crash Colab
vocals_extraction = False  # @param {type:"boolean"}
#@markdown new vocabulary (e.g. person names)
new_vocabulary = "Genshin Impact, Jeht, Liloupar"  # @param {type:"string"}
#@markdown leave this setting enabled unless the output has very serious hallucinations
condition_on_previous_text = True  # @param {type:"boolean"}

# --- sanity checks on the form inputs ---
# Strip stray whitespace from pasted values before validating them.
video_links = video_links.strip()
language = language.strip().lower()
assert vad_threshold >= 0.01
assert chunk_duration >= 0.1
assert video_links != ""
assert language != ""
# Accept either a language name ("english") or a Whisper code ("en").
language_code = TO_LANGUAGE_CODE.get(language, language)
assert language_code in TO_LANGUAGE_CODE.values(), "invalid language"

# The SRT below holds a single text per cue, so in "translate" mode it contains
# only Whisper's translation, not the original-language transcription.
if translation_mode == "transcription + translation":
	task = "translate"
elif translation_mode == "transcription only":
	task = "transcribe"
else:
	raise ValueError("Invalid translation mode")

# An empty/blank vocabulary means "no initial prompt".
new_vocabulary = new_vocabulary.strip() or None

print("Downloading …")
ydl_opts = {
	"format": "bestaudio",
	"outtmpl": "%(id)s_audio",
	"overwrites": True,
	# A "watch?v=…&list=…" link should fetch just that video, not the whole playlist.
	"noplaylist": True,
}
with YoutubeDL(ydl_opts) as ydl:
	vid_info = ydl.extract_info(video_links, download=True)
	vid_id = ydl.sanitize_info(vid_info)["id"]
# With a playlist, vid_id would be the playlist id and no "<vid_id>_audio" file would exist.
assert vid_info.get("_type", "video") != "playlist", "playlists are not supported, pass a single video"
audio_path = vid_id + "_audio"  # same as the output template above
print()

if vocals_extraction:
	print("Separating vocals …")
	n_channels = DEMUCS_MODEL.audio_channels
	# The full track stays on the CPU because it sometimes does not fit in VRAM;
	# apply_model is told to run the model on "cuda".
	raw_audio = demucs_load_track(audio_path, n_channels, DEMUCS_MODEL.samplerate)
	# apply_model expects (batch, channels, samples). This assumes load_track returns
	# (channels, samples) or (samples,) (TODO confirm), so normalise to the former first.
	if raw_audio.dim() == 1:
		raw_audio = raw_audio[None]  # (samples,) -> (1, samples)
	if raw_audio.shape[0] == 1 and n_channels > 1:
		raw_audio = raw_audio.repeat(n_channels, 1)  # mono -> model channel count
	# Normalise the mix the same way demucs' own separate.py does, then undo it on the output.
	ref = raw_audio.mean(0)
	ref_mean = ref.mean()
	ref_std = ref.std() + 1e-8  # epsilon avoids dividing by zero on a silent track
	raw_audio = (raw_audio - ref_mean) / ref_std
	demucs_extract = demucs_apply_model(DEMUCS_MODEL, raw_audio[None], device="cuda", split=True, overlap=.25)[0]
	demucs_extract = demucs_extract * ref_std + ref_mean
	# Vocals stem is downmixed to mono with shape (1, samples) for torchaudio.save.
	demucs_res = demucs_extract[DEMUCS_MODEL.sources.index("vocals")].mean(0, keepdim=True)
	audio_path = vid_id + "_vocals.wav"
	torchaudio.save(audio_path, demucs_res, DEMUCS_MODEL.samplerate)
	# Free the (potentially huge) intermediates before running Whisper.
	del raw_audio, ref, demucs_extract, demucs_res
	torch.cuda.empty_cache()

print("Running Whisper … PLEASE WAIT")
# `segments` is lazy: the actual transcription happens while it is iterated below.
segments, info = WHISPER_MODEL.transcribe(
	audio_path, task=task,
	language=language_code,
	condition_on_previous_text=condition_on_previous_text,
	initial_prompt=new_vocabulary,
	vad_filter=True,
	# chunk_duration caps the length (in seconds) of a single VAD speech chunk.
	vad_parameters=dict(threshold=vad_threshold, max_speech_duration_s=chunk_duration),
)

full_txt = []  # numbered SRT cues, each terminated by a blank line
timestamps = 0.0  # audio seconds already reported to the progress bar
with tqdm(total=info.duration, bar_format=TQDM_FORMAT) as pbar:
	for i, segment in enumerate(segments, start=1):
		full_txt.append(f"{i}\n{convert_seg(segment)}\n\n")
		# Clamp so a segment ending past info.duration cannot overshoot the bar.
		seg_end = min(segment.end, info.duration)
		if seg_end > timestamps:
			pbar.update(seg_end - timestamps)
			timestamps = seg_end
	if timestamps < info.duration:
		pbar.update(info.duration - timestamps)

srt_path = vid_id + ".srt"
with open(srt_path, mode="w", encoding="UTF-8") as f:
	f.writelines(full_txt)
print(f"Saved subtitles to {srt_path}")