In [None]:
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration
from datasets import load_dataset, DatasetDict
from datasets import Audio
import torch
import timeit

In [None]:
# Assemble both splits of the "hi" configuration of Common Voice 11.0 in one
# literal: "train" is the concatenation of the upstream train and validation
# splits, "test" is loaded on its own.
common_voice = DatasetDict(
    {
        "train": load_dataset(
            "mozilla-foundation/common_voice_11_0", "hi", split="train+validation"
        ),
        "test": load_dataset(
            "mozilla-foundation/common_voice_11_0", "hi", split="test"
        ),
    }
)

print(common_voice)

In [None]:
# The model is loaded with torchscript=True because it is exported with
# torch.jit.trace further down (see the Transformers TorchScript guide for
# what the flag changes).
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small",
    torchscript=True,
)

# Pre-/post-processing objects, all taken from the same checkpoint as the model.
# The tokenizer and processor are configured for Hindi transcription.
processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small",
    language="Hindi",
    task="transcribe",
)
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small",
    language="Hindi",
    task="transcribe",
)
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

In [None]:
# Fail fast if the model was not loaded in TorchScript mode (the trace further
# down relies on it), then show the flag as this cell's output.
assert model.config.torchscript, "model must be loaded with torchscript=True before tracing"
model.config.torchscript

In [None]:
# Re-declare the "audio" column at the rate the feature extractor reports,
# instead of a hard-coded 16000, so it always matches the `sampling_rate`
# value handed to the processor when the example inputs are built.
common_voice = common_voice.cast_column(
    "audio", Audio(sampling_rate=feature_extractor.sampling_rate)
)

In [None]:
# First record of the training split; it supplies the example audio and
# sentence used to build the tracing inputs.
sample = common_voice["train"][0]

In [None]:
# Build one concrete example of every tensor that is passed to the model's
# forward(); torch.jit.trace runs the model on these inputs to record a graph.

# Run the processor once and read both outputs from the same result, rather
# than repeating the whole feature extraction for each field.
encoded_audio = processor(
    sample["audio"]["array"],
    sampling_rate=feature_extractor.sampling_rate,
    return_attention_mask=True,
    return_tensors="pt",
)
input_features = encoded_audio.input_features
attention_mask = encoded_audio.attention_mask

# Token ids of the reference sentence serve as the example decoder input.
decoder_input_ids = tokenizer(sample["sentence"], return_tensors="pt").input_ids

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
# Switch to eval mode before tracing.
model.eval()
# The example tensors are passed positionally, in exactly this order.
traced_model = torch.jit.trace(model, (input_features, attention_mask, decoder_input_ids))

In [None]:
start_time = timeit.default_timer()
traced_model(input_features, attention_mask, decoder_input_ids)
elapsed = timeit.default_timer() - start_time

elapsed

In [None]:
start_time = timeit.default_timer()
model(input_features, attention_mask, decoder_input_ids)
elapsed = timeit.default_timer() - start_time

elapsed