In [3]:
import logging
import math
import os
import tempfile
import zipfile
import time
import shutil
from multiprocessing import Pool

import gradio as gr
import jax.numpy as jnp
import numpy as np
import yt_dlp as youtube_dl
from jax.experimental.compilation_cache import compilation_cache as cc
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.pipelines.audio_utils import ffmpeg_read

from whisper_jax import FlaxWhisperPipline

In [4]:
cc.initialize_cache("./jax_cache")
checkpoint = "openai/whisper-tiny"

DEBUG = True
BATCH_SIZE = 32
CHUNK_LENGTH_S = 30
NUM_PROC = 32
FILE_LIMIT_MB = 100000
YT_LENGTH_LIMIT_S = 720000  # limit to 2 hour YouTube files

title = description = article = " Whisper JAX ⚡️ "

language_names = sorted(TO_LANGUAGE_CODE.keys())

logger = logging.getLogger("whisper-jax-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)

temp_path_zip_file = os.path.join("/home/ubuntu/whisper-gradio-ytb-demo/src", 'temp')

Initialized persistent compilation cache at ./jax_cache


In [5]:
pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
stride_length_s = CHUNK_LENGTH_S / 6
chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
step = chunk_len - stride_left - stride_right
pool = Pool(NUM_PROC)

#do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
logger.info("compiling forward call...")
start = time.time()
random_inputs = {"input_features": np.ones((BATCH_SIZE, 80, 3000))}
random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
compile_time = time.time() - start
logger.info(f"compiled in {compile_time}s")

2023-05-05 13:10:14;INFO;compiling forward call...
2023-05-05 13:10:17.476211: E external/xla/xla/stream_executor/cuda/cuda_blas.cc:190] failed to create cublas handle: cublas error
2023-05-05 13:10:17.476255: E external/xla/xla/stream_executor/cuda/cuda_blas.cc:193] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
2023-05-05 13:10:17.594513: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:850] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2023-05-05 13:10:17.594546: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:853] Conv: (f32[32,384,3000]{2,1,0}, u8[0]{0}) custom-call(f32[32,80,3000]{2,1,0}, f32[384,80,3]{2,1,0}), window={size=3 pad=1_1}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForwa

XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv.2 = (f32[32,384,3000]{2,1,0}, u8[0]{0}) custom-call(f32[32,80,3000]{2,1,0} %Arg_167.168, f32[384,80,3]{2,1,0} %transpose.405), window={size=3 pad=1_1}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="pmap(generate)/jit(main)/encoder/conv1/conv_general_dilated[window_strides=(1,) padding=((1, 1),) lhs_dilation=(1,) rhs_dilation=(1,) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1)) feature_group_count=1 batch_group_count=1 precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_44303/3957122359.py" source_line=12}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: INTERNAL: All algorithms tried for %cudnn-conv.2 = (f32[32,384,3000]{2,1,0}, u8[0]{0}) custom-call(f32[32,80,3000]{2,1,0} %Arg_167.168, f32[384,80,3]{2,1,0} %transpose.405), window={size=3 pad=1_1}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", metadata={op_name="pmap(generate)/jit(main)/encoder/conv1/conv_general_dilated[window_strides=(1,) padding=((1, 1),) lhs_dilation=(1,) rhs_dilation=(1,) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1)) feature_group_count=1 batch_group_count=1 precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_44303/3957122359.py" source_line=12}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine eng34{k2=0,k4=2,k5=1,k6=0,k7=0,k19=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng34{k2=1,k4=2,k5=1,k6=0,k7=0,k19=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng4{}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng34{k2=2,k4=1,k5=0,k6=0,k7=0,k19=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng42{k2=1,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng31{k2=2,k4=2,k5=3,k6=2,k7=1}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng34{k2=2,k4=2,k5=0,k6=0,k7=0,k19=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng30{k2=2,k4=2,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng42{k2=2,k4=1,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng1{k2=2,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng28{k2=1,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng2{k2=1,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng1{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng28{k2=0,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng2{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng28{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng0{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng1{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'
  Profiling failure on cuDNN engine eng28{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(4686): 'status'

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

In [None]:
def identity(batch):
    return batch

# Extra

In [11]:
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    if seconds is None:
        # we have a malformed timestamp so just return it as is
        return seconds
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"

def create_transcript_zip(videos,tmpdir):
        """
      Clear the temporary directory contents
      
      Create a zip file for each video transcript and return the path to the zip of all transcripts.

      Args:
      videos (list of dict): Each dictionary must have "title" and "transcript" keys, containing the video title
      and its transcript respectively.

      Returns:
      str: Path to the zip file containing all transcript zip files.
      """
        for filename in os.listdir(tmpdir):
            file_path = os.path.join(tmpdir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
                print(f'Deleted {file_path}')
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')

        # Create a temporary directory to store all transcript zip files
        zip_paths = []
        # Loop through all videos and create a transcript zip file for each
        for video in videos:
            # Create a zip file with the video title as the filename
            zip_path = os.path.join(tmpdir, f"{video['title']}.zip")
            if not os.path.exists(temp_path_zip_file):
              os.makedirs(temp_path_zip_file)
            with zipfile.ZipFile(zip_path, "w") as zip_file:
                # Write the transcript to an SRT file with the same name as the video
                srt_path = os.path.join(tmpdir, f"{video['title']}.srt")
                with open(srt_path, "w") as srt_file:
                    srt_file.write(video["transcript"])
                # Add the SRT file to the zip
                zip_file.write(srt_path, f"{video['title']}.srt")
            zip_paths.append(zip_path)
        # Create a zip file containing all transcript zip files
        all_zip_path = os.path.join(tmpdir, "all_transcripts.zip")
        with zipfile.ZipFile(all_zip_path, "w") as all_zip_file:
            for zip_path in zip_paths:
                all_zip_file.write(zip_path, os.path.basename(zip_path))
        return all_zip_path

In [12]:
def _return_yt_html_embed(yt_url):
    video_id = yt_url[-1].split("?v=")[-1]
    return f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe> </center>'

def download_yt_audio(yt_url, filename):
    title_ytb = youtube_dl.YoutubeDL().extract_info(yt_url, download=False).get("title", None)
    info_loader = youtube_dl.YoutubeDL()
    try:
        info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        raise gr.Error(str(err)) from err

    file_length = info["duration_string"]
    file_h_m_s = file_length.split(":")
    file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
    if len(file_h_m_s) == 1:
        file_h_m_s.insert(0, 0)
    if len(file_h_m_s) == 2:
        file_h_m_s.insert(0, 0)

    file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")

    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
            return title_ytb
        except youtube_dl.utils.ExtractorError as err:
            raise gr.Error(str(err)) 

# Main

In [6]:
def tqdm_generate(inputs: dict, task: str, return_timestamps: bool):
    inputs_len = inputs["array"].shape[0]
    all_chunk_start_idx = np.arange(0, inputs_len, step)
    num_samples = len(all_chunk_start_idx)
    num_batches = math.ceil(num_samples / BATCH_SIZE)
    dummy_batches = list(
        range(num_batches)
    )  # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841

    dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
    #progress(0, desc="Pre-processing audio file...")
    logger.info("pre-processing audio file...")
    dataloader = pool.map(identity, dataloader)
    logger.info("done post-processing")

    start_time = time.time()
    logger.info("transcribing...")
    model_outputs = [
        pipeline.forward(
            batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True
        )
        for batch, _ in zip(
            dataloader
        )
    ]
    runtime = time.time() - start_time
    logger.info("done transcription")

    logger.info("post-processing...")
    post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
    text = post_processed["text"]
    if return_timestamps:
        timestamps = post_processed.get("chunks")
        timestamps = [
            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in timestamps
        ]
        text = "\n".join(str(feature) for feature in timestamps)
    logger.info("done post-processing")
    return text, runtime

In [7]:
def transcribe_chunked_audio(inputs, task, return_timestamps):
    #progress(0, desc="Loading audio file...")
    logger.info("loading audio file...")
    if inputs is None:
        logger.warning("No audio file")
        raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
    file_size_mb = os.stat(inputs).st_size / (1024 * 1024)
    if file_size_mb > FILE_LIMIT_MB:
        logger.warning("Max file size exceeded")
        raise gr.Error(
            f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB."
        )

    with open(inputs, "rb") as f:
        inputs = f.read()

    inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
    sampling_rate = float(pipeline.feature_extractor.sampling_rate)
    inputs = {"array": inputs, "sampling_rate": sampling_rate}
    logger.info("done loading")
    text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
    return text, runtime

In [23]:
import uuid

def transcribe_youtube(yt_urls, task, return_timestamps):
        final_files_data = []
        yt_urls = yt_urls.split()
        html_embed_str = _return_yt_html_embed(yt_urls)
        with tempfile.TemporaryDirectory() as tmpdirname:
            print(yt_urls)
            for yt_url in yt_urls:
                ran_id = str(uuid.uuid4())
                filepath = os.path.join(tmpdirname, f"{ran_id}_video.mp4")
                print(f"///////////----{filepath}")
                print(f"\n--Doing for {yt_urls.index(yt_url)}--{filepath}----\n")
                title_ytb = download_yt_audio(yt_url, filepath)

                with open(filepath, "rb") as f:
                    inputs = f.read()

        #         inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
        #         inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
        #         logger.info("done loading...")
        #         text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps)
        #         final_files_data.append({"title": title_ytb, "transcript": text})
        # path_of_zip_file = create_transcript_zip(final_files_data, temp_path_zip_file)
        # return html_embed_str, path_of_zip_file, runtime

In [24]:
yt_urls = "https://www.youtube.com/watch?v=4AHz39IIkmc https://www.youtube.com/watch?v=vhr-i1WtfXY"

In [25]:
html_embed_str, path_of_zip_file, runtime = transcribe_youtube(yt_urls,"transcribe",False)

['https://www.youtube.com/watch?v=4AHz39IIkmc', 'https://www.youtube.com/watch?v=vhr-i1WtfXY']
///////////----/tmp/tmp5zqrooq3/ed7423e2-37c9-4039-9b2d-24c5aaa430cd_video.mp4

--Doing for 0--/tmp/tmp5zqrooq3/ed7423e2-37c9-4039-9b2d-24c5aaa430cd_video.mp4----

[youtube] Extracting URL: https://www.youtube.com/watch?v=4AHz39IIkmc
[youtube] 4AHz39IIkmc: Downloading webpage
[youtube] 4AHz39IIkmc: Downloading android player API JSON
[youtube] Extracting URL: https://www.youtube.com/watch?v=4AHz39IIkmc
[youtube] 4AHz39IIkmc: Downloading webpage
[youtube] 4AHz39IIkmc: Downloading android player API JSON
[youtube] Extracting URL: https://www.youtube.com/watch?v=4AHz39IIkmc
[youtube] 4AHz39IIkmc: Downloading webpage
[youtube] 4AHz39IIkmc: Downloading android player API JSON
[info] 4AHz39IIkmc: Downloading 1 format(s): 597+140
[dashsegments] Total fragments: 1
[download] Destination: /tmp/tmp5zqrooq3/ed7423e2-37c9-4039-9b2d-24c5aaa430cd_video.f597.mp4
[download] 100% of    2.08MiB in 00:00:00 at 

TypeError: cannot unpack non-iterable NoneType object