In [1]:
import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import WhisperProcessor, WhisperFeatureExtractor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Names of the dataset columns / model input key that the preprocessing
# functions below look up.
audio_column_name = "audio"
model_input_name = "input_features"
train_text_column_name = "text"

# Load the Whisper-tiny processor once and expose its two components:
# the feature extractor for the audio and the tokenizer for the transcripts.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
feature_extractor, tokenizer = processor.feature_extractor, processor.tokenizer

In [3]:
# Validation split of the dummy LibriSpeech dataset ("clean" configuration).
raw_dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
)
# Remember the original column names so they can be dropped after mapping.
raw_dataset_features = list(raw_dataset.features)

# Keep the first 50 // 2 = 25 examples and repeat them 2 * 50 = 100 times,
# giving 2500 rows to preprocess.
raw_dataset = raw_dataset.select(range(50 // 2))
raw_dataset = concatenate_datasets([raw_dataset] * (2 * 50))

In [8]:
def prepare_dataset(batch):
    """Add model inputs, audio length and label ids to one example.

    Reads the audio entry and transcript from ``batch`` (using the
    module-level column names), runs the module-level ``feature_extractor``
    and ``tokenizer`` on them, and stores the results back on ``batch``.

    Args:
        batch: A mapping holding the audio column (with ``"array"`` and
            ``"sampling_rate"`` entries) and the transcript column.

    Returns:
        The same ``batch`` object, with ``model_input_name``,
        ``"input_length"`` and ``"labels"`` set.
    """
    audio = batch[audio_column_name]
    waveform = audio["array"]

    # The extractor output is indexed at 0 to take the entry for this
    # single example.
    extracted = feature_extractor(waveform, sampling_rate=audio["sampling_rate"], return_tensors="np")
    batch[model_input_name] = extracted.get(model_input_name)[0]

    # Length of the audio array (number of elements), not of the features.
    batch["input_length"] = len(waveform)

    # Token ids of the transcript become the labels.
    encoded = tokenizer(batch[train_text_column_name])
    batch["labels"] = encoded.input_ids

    return batch

In [9]:
# Baseline run: preprocess every example with the stock feature extractor.
# The raw input columns are removed from the result.
vectorized_dataset = raw_dataset.map(
    prepare_dataset,
    remove_columns=raw_dataset_features,
    keep_in_memory=True,
)

Map: 100%|███████████████████████████████████████████████████████████████████| 2500/2500 [01:36<00:00, 25.98 examples/s]


In [13]:
class TorchWhisperFeatureExtractor(WhisperFeatureExtractor):
    """Whisper feature extractor whose log-mel computation runs on torch ops.

    Only ``_np_extract_fbank_features`` is replaced; everything else is
    inherited from ``WhisperFeatureExtractor``. The method keeps its original
    name so that it takes the place of the parent's implementation.
    """

    def _np_extract_fbank_features(self, waveform: np.ndarray) -> np.ndarray:
        """
        Compute the log-mel spectrogram of the provided audio using torch filters.

        Args:
            waveform: NumPy array of audio samples. Either a single 1-D signal
                or several signals stacked along the first axis (the input
                shapes ``torch.stft`` accepts).

        Returns:
            The normalised log-mel spectrogram as a float32 NumPy array.
        """
        waveform = torch.from_numpy(waveform).type(torch.float32)

        window = torch.hann_window(self.n_fft)
        stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
        # Power spectrum; the final STFT frame is dropped.
        magnitudes = stft[..., :-1].abs() ** 2

        mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
        mel_spec = mel_filters.T @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # Floor every value at 8.0 below the spectrogram's peak. The peak is
        # taken over the last two axes only, so with stacked input each
        # example is floored against its own maximum; for a single waveform
        # this equals the global maximum.
        peak = log_spec.amax(dim=(-2, -1), keepdim=True)
        log_spec = torch.maximum(log_spec, peak - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec.numpy()

In [14]:
# Instantiate the torch-backed extractor from the same checkpoint as the
# baseline processor so both use the same configuration.
torch_feature_extractor = TorchWhisperFeatureExtractor.from_pretrained(
    "openai/whisper-tiny"
)

In [19]:
def prepare_torch_dataset(batch):
    """Add model inputs, audio length and label ids using the torch extractor.

    Same contract as ``prepare_dataset``, except that the audio is processed
    by the module-level ``torch_feature_extractor``.

    Args:
        batch: A mapping holding the audio column (with ``"array"`` and
            ``"sampling_rate"`` entries) and the transcript column.

    Returns:
        The same ``batch`` object, with ``model_input_name``,
        ``"input_length"`` and ``"labels"`` set.
    """
    # process audio
    sample = batch[audio_column_name]
    # Use the torch-backed extractor defined in this notebook; the previous
    # `flax_feature_extractor` name is not defined anywhere in it.
    inputs = torch_feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np")

    # process audio length
    batch[model_input_name] = inputs.get(model_input_name)[0]
    batch["input_length"] = len(sample["array"])

    # process targets
    input_str = batch[train_text_column_name]
    batch["labels"] = tokenizer(input_str).input_ids

    return batch

In [20]:
# Torch run: same mapping as the baseline, but through prepare_torch_dataset.
# The raw input columns are removed from the result.
vectorized_torch_dataset = raw_dataset.map(
    prepare_torch_dataset,
    remove_columns=raw_dataset_features,
    keep_in_memory=True,
)

Map: 100%|██████████████████████████████████████████████████████████████████| 2500/2500 [00:23<00:00, 105.25 examples/s]


In [27]:
# Largest element-wise absolute difference between the torch and baseline
# features of the first example.
torch_features = np.array(vectorized_torch_dataset[0]["input_features"])
baseline_features = np.array(vectorized_dataset[0]["input_features"])
np.abs(torch_features - baseline_features).max()

8.58306884765625e-06