## Import libraries

In [None]:
from functools import partial

import numpy as np
import soundfile as sf
import librosa as lb
from tqdm import tqdm
import matplotlib.pyplot as plt

from transformers import (
    WhisperProcessor,
    WhisperFeatureExtractor,
    WhisperTokenizerFast,
    WhisperForConditionalGeneration,
)
import evaluate
from transformers import pipeline
from datasets import load_dataset, DatasetDict

import torch
import torchaudio
from torch.cuda import empty_cache
from torch.utils.data import Dataset, DataLoader

In [None]:
%%time
# Checkpoint name shared by every component loaded in this cell.
model_id = 'openai/whisper-large-v2'

# Pre/post-processing components. The tokenizer and feature extractor are also
# loaded on their own so they can be handed to the pipeline individually.
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)
tokenizer = WhisperTokenizerFast.from_pretrained(model_id)
processor = WhisperProcessor.from_pretrained(model_id)

# Model weights for the same checkpoint.
model = WhisperForConditionalGeneration.from_pretrained(model_id)

In [None]:
# Decoder prompt IDs requesting the "transcribe" task for Hindi audio.
forced_decoder_ids = processor.get_decoder_prompt_ids(
    language='Hindi',
    task="transcribe",
)

# Prompt IDs built from a text prompt; the config below currently sets
# "prompt_ids" to None instead of using this value.
prompt_ids = processor.get_prompt_ids(
    'Glossary'
)

In [None]:
# "wer" metric loaded through the `evaluate` library; used later to score transcriptions.
wer_metric = evaluate.load("wer")

In [None]:
# Inference settings read by the evaluation loop via `config.get(...)`.
config = dict(
    forced_decoder_ids=forced_decoder_ids,  # What decoder IDs to use
    prompt_ids=None,  # What prompt IDs to use
    num_beams=1,  # Number of beams to use for beam search
    return_timestamps=True,
    chunk_length_s=30,
)

## Load dataset

Load a toy dataset to evaluate the performance.

In [None]:
# Build both splits in one expression: "train" merges the train and validation
# splits, "test" is kept as-is.
common_voice = DatasetDict(
    {
        split_name: load_dataset(
            "mozilla-foundation/common_voice_11_0", "hi", split=split_expr
        )
        for split_name, split_expr in (
            ("train", "train+validation"),
            ("test", "test"),
        )
    }
)

print(common_voice)

In [None]:
%%time
# Clip length in seconds = number of samples / sampling rate, for every training clip.
duration = np.array(
    [
        clip["audio"]["array"].shape[0] / clip["audio"]["sampling_rate"]
        for clip in common_voice["train"]
    ]
)

In [None]:
# np.percentile expects q on a 0-100 scale: 50 is the median and 95 the 95th
# percentile (0.5 / 0.95 would return values from the very bottom of the distribution).
print(
    f'Median length of the audio: {np.percentile(duration, 50)}\n95th percentile length of audio: {np.percentile(duration, 95)}'
)

Since single audio clips are quite short, we combine them to form longer audio files by concatenating a fixed number of clips (`concat_n`, set to 50 below) one after the other. The concatenated audio is then also transformed by removing the silences.

### Transform dataset

In [None]:
def cut_silences(
    audio: np.ndarray,
    silence_model,
    silence_threshold: float = 0.5,
    *,
    sampling_rate: int = 16000,
    **kwargs,
):
    """
    Removes silences from the audio file

    The speech-detection helpers are injected through ``kwargs`` so this function
    does not depend on how the VAD utilities were loaded.

    Args:
        audio: Waveform samples as a numpy array.
        silence_model: Model object forwarded unchanged to ``get_speech_timestamps``.
        silence_threshold: Value forwarded as ``threshold`` to ``get_speech_timestamps``.
        sampling_rate: Sampling rate of ``audio`` forwarded to
            ``get_speech_timestamps``. Defaults to 16000.
        **kwargs: Must contain two callables:
            ``get_speech_timestamps(audio, model, threshold=..., sampling_rate=...)``
            returning a sequence of speech segments, and
            ``collect_chunks(segments, tensor)`` returning a tensor with those
            segments joined together.

    Returns:
        A numpy array holding only the detected speech. It is empty when no
        speech segment is reported.

    Raises:
        TypeError: If either helper callable is missing from ``kwargs``.
    """
    get_speech_timestamps = kwargs.get('get_speech_timestamps')
    collect_chunks = kwargs.get('collect_chunks')

    # Fail with an explicit message instead of "'NoneType' object is not callable".
    if not callable(get_speech_timestamps) or not callable(collect_chunks):
        raise TypeError(
            "cut_silences requires callable 'get_speech_timestamps' and "
            "'collect_chunks' keyword arguments"
        )

    speech_timestamps = get_speech_timestamps(
        audio, silence_model, threshold=silence_threshold, sampling_rate=sampling_rate
    )

    waveform = torch.tensor(audio)

    # With no detected speech there is nothing to collect; return an empty
    # array of the same dtype rather than handing an empty list to collect_chunks.
    if len(speech_timestamps) == 0:
        return waveform[:0].numpy()

    return collect_chunks(speech_timestamps, waveform).numpy()

In [None]:
# Load the silero VAD model and its helper utilities through torch.hub.
# NOTE(review): force_reload=True asks torch.hub for a fresh copy of the repo on
# every run; consider disabling it once the cache is known to be good.
vad_model, vad_utils = torch.hub.load(
    repo_or_dir='snakers4/silero-vad',
    model='silero_vad',
    force_reload=True,
)

# Positional unpack: assumes the utils tuple has exactly five entries with
# get_speech_timestamps first and collect_chunks last — TODO confirm against the
# silero-vad version being downloaded.
(get_speech_timestamps, _, _, _, collect_chunks) = vad_utils

In [None]:
# Bind the VAD model and its helpers once, leaving a callable that only
# needs the audio array.
transform_cut_silences = partial(
    cut_silences,
    silence_model=vad_model,
    get_speech_timestamps=get_speech_timestamps,
    collect_chunks=collect_chunks,
)

In [None]:
# Print one example of the dataset
# Fetch the row once and reuse it. Only the last expression of a cell is
# displayed, so the keys are printed explicitly.
example = common_voice["train"][0]
print(example.keys())
example

In [None]:
class CustomAudioDataset(Dataset):
    """
    Groups consecutive clips of a raw dataset into longer audio samples.

    Item ``idx`` is built from rows ``idx*concat_n`` to ``(idx+1)*concat_n - 1``
    of ``raw_ds``: each clip is resampled to 16 kHz, the clips are concatenated
    and their sentences are joined with single spaces. Rows left over after the
    last complete group are ignored.

    Args:
        raw_ds: Indexable dataset whose rows provide ``row["audio"]["array"]``,
            ``row["audio"]["sampling_rate"]`` and ``row["sentence"]``.
        concat_n: Number of consecutive rows merged into one item (must be >= 1).
        transform: Optional callable applied to the concatenated audio array.

    Raises:
        ValueError: If ``concat_n`` is smaller than 1.
    """

    def __init__(self, raw_ds: Dataset, concat_n:int = 10, transform: callable = None):
        if concat_n < 1:
            raise ValueError(f"concat_n must be >= 1, got {concat_n}")

        self.raw_ds = raw_ds
        self.transform = transform
        self.concat_n = concat_n

    def __len__(self):
        """Number of complete groups of ``concat_n`` rows."""
        return len(self.raw_ds)//self.concat_n

    def _resample_audio(self, x: np.ndarray, sr: int):
        """
        Resample audio to 16Khz since that is being used by Whisper
        """
        return lb.resample(x, orig_sr=sr, target_sr=16000)

    def __getitem__(self, idx):
        """
        Return ``(audio, transcript)`` for the ``idx``-th group of rows.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError``, which is also what ends plain ``for`` iteration over
        the dataset.
        """
        n_items = len(self)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(
                f"index {idx} is out of range for a dataset with {n_items} items"
            )

        concat_audio = []
        concat_label = []

        for i in range(idx*self.concat_n, (idx+1)*self.concat_n):
            # Look the row up once: each lookup on the raw dataset may be
            # expensive (e.g. it may decode the audio again).
            row = self.raw_ds[i]
            clip = row["audio"]

            concat_audio.append(
                self._resample_audio(clip["array"], clip["sampling_rate"])
            )
            concat_label.append(row["sentence"])

        concat_audio = np.concatenate(concat_audio, axis=0)
        if self.transform:
            concat_audio = self.transform(concat_audio)

        return concat_audio, ' '.join(concat_label)

In [None]:
# Despite the variable name, this wraps the *test* split: 50 consecutive clips
# are merged per item and silences are removed from the merged audio.
training_data = CustomAudioDataset(
    raw_ds=common_voice["test"],
    concat_n=50,
    transform=transform_cut_silences
)

In [None]:
# Batched loader over the concatenated samples.
# NOTE(review): no collate_fn is given; the default collation presumably cannot
# stack audio arrays of different lengths — confirm before iterating this loader.
train_dataloader = DataLoader(
    training_data, batch_size=8
)

## Loading the pipeline

In [None]:
%%time
# Speech-recognition pipeline built from the components loaded above.
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU (device=-1) so the notebook still runs on machines without a GPU.
transcriber = pipeline(
    task="automatic-speech-recognition",
    model=model,
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1
)

In [None]:
# Number of concatenated samples to score; kept small because inference is slow.
max_eval_samples = 10
n_eval = min(max_eval_samples, len(training_data))

wer_scores = []

# Iterate over a bounded index range so the progress bar total matches the
# number of samples that are actually evaluated.
for sample_idx in tqdm(range(n_eval), total=n_eval):
    empty_cache()
    audio, transc = training_data[sample_idx]

    generated_transc = transcriber(
        audio,
        chunk_length_s=config.get('chunk_length_s'),
        return_timestamps=config.get('return_timestamps'),
        generate_kwargs={
            'forced_decoder_ids': config.get('forced_decoder_ids'),
            'num_beams': config.get('num_beams')
        }
    )['text']

    # WER is a word-level metric that already accounts for insertions and
    # deletions, so the hypothesis and reference are compared as whole strings.
    # Truncating both to a common token length and decoding token by token
    # would hide length mismatches and score sub-word pieces instead of words.
    wer_score = wer_metric.compute(
        predictions=[generated_transc],
        references=[transc]
    )
    wer_scores.append(wer_score)

wer_scores = np.array(wer_scores)