In [None]:
!pip install --upgrade pip
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio

In [None]:
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
import torch
import jiwer
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from datasets import Audio
from transformers import WhisperForConditionalGeneration
from dataclasses import dataclass
from typing import Any, Dict, List, Union

In [None]:
# Checkpoint shared by the tokenizer, feature extractor, processor and model.
model_name = "openai/whisper-small"

# Load the speech dataset (test split only). No authentication argument is passed:
# the original `use_auth_token=False` was a no-op, and that keyword is deprecated
# in recent `datasets` releases (replaced by `token`).
dataset = load_dataset("DTU54DL/common-accent", split="test")

# Load the Whisper tokenizer, feature extractor and combined processor.
# The model itself is instantiated in a later cell with
# WhisperForConditionalGeneration; a Whisper checkpoint must not be loaded through
# Wav2Vec2ForCTC, which is a different architecture.
tokenizer = WhisperTokenizer.from_pretrained(model_name, language="english", task="transcribe")
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name, language="english", task="transcribe")

# Cast the audio column so that samples are provided at a 16 kHz sampling rate.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

In [None]:
def prepare_dataset(batch):
    """Add model inputs and label ids to a single dataset example.

    Reads ``batch["audio"]`` (its ``"array"`` and ``"sampling_rate"`` entries)
    and ``batch["sentence"]``, then stores two new entries on the same mapping:

    * ``"input_features"`` - first element of the features produced by the
      notebook-level ``feature_extractor`` (log-Mel input features).
    * ``"labels"`` - token ids produced by the notebook-level ``tokenizer``
      for the target sentence.

    Returns the same ``batch`` object, updated in place.
    """
    clip = batch["audio"]
    waveform = clip["array"]
    rate = clip["sampling_rate"]

    # Audio -> log-Mel input features; only the first (single) item is kept.
    extracted = feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]

    # Target text -> label ids.
    encoded = tokenizer(batch["sentence"])
    batch["labels"] = encoded.input_ids

    return batch

In [None]:
# Apply `prepare_dataset` to every example, adding "input_features" and "labels".
dataset = dataset.map(function=prepare_dataset)

In [None]:
# Drop the "accent" column. The membership check keeps this cell re-runnable:
# without it, executing the cell a second time fails because the column is gone.
if "accent" in dataset.column_names:
    dataset = dataset.remove_columns(["accent"])

In [None]:
from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub from within the notebook.
# NOTE(review): no token is passed when loading the dataset or the checkpoint in
# the visible cells, so this is presumably only needed for a later step such as
# pushing to the Hub - confirm, and drop this cell if it is not required.
notebook_login()

In [None]:
# Instantiate the Whisper sequence-to-sequence model from the `model_name` checkpoint.
model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
# Configure generation: English transcription, with no forced decoder ids
# stored on the generation config.
generation_config = model.generation_config
generation_config.language = "english"
generation_config.task = "transcribe"
generation_config.forced_decoder_ids = None

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech examples into one padded batch of inputs and labels.

    Audio features and label ids are padded separately (through the processor's
    feature extractor and tokenizer respectively) because they have different
    lengths and need different padding methods.
    """

    # Object exposing `.feature_extractor.pad(...)` and `.tokenizer.pad(...)`.
    processor: Any
    # Token id that is stripped from the front of the labels when every
    # sequence in the batch starts with it.
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build a batch from ``features``.

        Args:
            features: examples, each providing ``"input_features"`` and ``"labels"``.

        Returns:
            The padded audio batch returned by the feature extractor, with an
            extra ``"labels"`` tensor in which padded positions are ``-100``.
        """
        # Audio side: hand the features to the feature extractor and get tensors back.
        audio_inputs = [{"input_features": example["input_features"]} for example in features]
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # Label side: pad the token id sequences to a common length.
        token_inputs = [{"input_ids": example["labels"]} for example in features]
        padded_labels = self.processor.tokenizer.pad(token_inputs, return_tensors="pt")

        # Positions the attention mask does not mark with 1 are padding; set them
        # to -100 so they are ignored by the loss.
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # When every sequence already begins with the start token, remove it here
        # because it is appended again later.
        every_row_has_start = (labels[:, 0] == self.decoder_start_token_id).all().cpu().item()
        batch["labels"] = labels[:, 1:] if every_row_has_start else labels

        return batch

In [None]:
# Build the collator from the shared processor and the model's decoder start token id.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor, model.config.decoder_start_token_id)

In [None]:
# Evaluate the model on every example of the dataset and report the mean
# per-utterance Word Error Rate (WER).

# Clear any forced decoder ids stored on the model config before generating.
model.config.forced_decoder_ids = None

total_wer = 0
total_samples = 0

for data in dataset:
    sample = data["audio"]
    input_features = processor(
        sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
    ).input_features

    # generate token ids
    predicted_ids = model.generate(input_features)

    # decode token ids to text; only the variant without special tokens is needed
    # for scoring, so the ids are decoded once per sample
    transcription_without_special_tokens = processor.batch_decode(predicted_ids, skip_special_tokens=True)

    # Calculate Word Error Rate (WER): reference sentence first, hypothesis second
    wer = jiwer.wer(data["sentence"], transcription_without_special_tokens[0])
    total_wer += wer
    total_samples += 1

    print("Labels:", data["sentence"])
    print("Predicted Transcription:", transcription_without_special_tokens)
    print("Word Error Rate (WER):", wer)
    print()

# An empty dataset would otherwise raise ZeroDivisionError; report NaN instead.
average_wer = total_wer / total_samples if total_samples else float("nan")
print("Average Word Error Rate (WER):", average_wer)
