# Install dependencies

In [4]:
%%capture
!pip install optimum[onnxruntime] transformers git+https://github.com/openai/whisper.git

# Convert Whisper to ONNX

In [5]:
# -*- coding: utf-8 -*-
import warnings
# Suppress every Python warning from this point on. The filter is installed
# before the library imports below, so warnings emitted while importing them
# are hidden as well.
warnings.filterwarnings("ignore")

from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from transformers import (
    set_seed,
    AutoProcessor
)
from pathlib import Path
import os

# Seed passed to set_seed() at the start of the export function.
SEED = 42

# Export the vanilla (unoptimized) ONNX model together with its processor
def export_vanilla_optimized_onnx(model_checkpoint):
    """Export a checkpoint to ONNX and save it next to its processor.

    Despite the function name, only the plain ("vanilla") ONNX export is
    produced; no separate optimization step is run here.

    Args:
        model_checkpoint: Model id or path accepted by ``from_pretrained``.
            It is also used as the sub-directory under
            ``exported_onnx_models/`` that receives the exported files.

    Returns:
        pathlib.Path: Directory the model and processor were saved to.
    """
    set_seed(SEED)
    processor = AutoProcessor.from_pretrained(model_checkpoint)

    # `export=True` converts the checkpoint to ONNX while loading; it is the
    # current spelling of the deprecated `from_transformers=True` flag.
    model = ORTModelForSpeechSeq2Seq.from_pretrained(model_checkpoint, export=True, use_cache=True)

    onnx_path = Path("exported_onnx_models") / model_checkpoint
    model.save_pretrained(onnx_path)
    processor.save_pretrained(onnx_path)
    return onnx_path


# Export openai/whisper-tiny; the files are written under
# exported_onnx_models/openai/whisper-tiny relative to the working directory.
export_vanilla_optimized_onnx(model_checkpoint='openai/whisper-tiny')

# Inference

In [6]:
import numpy as np
import os
import json
import whisper

def transcribe(start_tokens, file, encoder, decoder, tokenizer, skip_special_tokens=True,
               eos_token_id=50257, max_length=448):
    """Greedy-decode one audio file with separate ONNX encoder/decoder sessions.

    Args:
        start_tokens: Prompt token ids that seed the decoder. The sequence is
            copied, so the caller's object is never modified.
        file: Audio path handed to ``whisper.load_audio``.
        encoder: Object whose ``run(None, {'input_features': mel})`` returns
            the encoder hidden states as its first output.
        decoder: Object whose ``run(None, {'input_ids': ...,
            'encoder_hidden_states': ...})`` returns, as its first output,
            scores indexed ``[batch, position, vocab]``.
        tokenizer: Object exposing ``batch_decode(ids, skip_special_tokens=...)``.
        skip_special_tokens: Forwarded unchanged to ``tokenizer.batch_decode``.
        eos_token_id: Token id that terminates decoding. Defaults to 50257,
            the value that used to be hard-coded here.
        max_length: Upper bound on the total number of tokens (prompt plus
            generated). Decoding stops there even if ``eos_token_id`` was
            never produced. The default of 448 is intended to match the
            decoder context length of Whisper checkpoints.

    Returns:
        The decoded text of the single sequence in the batch.

    Raises:
        ValueError: If ``start_tokens`` is empty.
    """
    def next_token(prefix, hidden_states):
        # The ONNX decoder expects a (batch, sequence) int64 tensor; forcing the
        # dtype avoids depending on the platform's default integer width.
        input_ids = np.asarray(prefix, dtype=np.int64)[np.newaxis, :]
        scores = decoder.run(None, {'input_ids': input_ids, 'encoder_hidden_states': hidden_states})[0]
        # Greedy choice: only the scores at the last position decide the next token.
        return int(np.argmax(scores[0, -1]))

    # Work on a private copy; aliasing the argument would grow the caller's
    # list and corrupt the prompt on every subsequent call.
    tokens = [int(token) for token in start_tokens]
    if not tokens:
        raise ValueError("start_tokens must contain at least one token id")

    audio = whisper.pad_or_trim(whisper.load_audio(file))
    mel = np.expand_dims(whisper.log_mel_spectrogram(audio), 0)

    # The encoder runs once; its output is reused at every decoding step.
    hidden_states = encoder.run(None, {'input_features': mel})[0]

    while len(tokens) < max_length:
        tokens.append(next_token(tokens, hidden_states))
        if tokens[-1] == eos_token_id:
            break

    token_batch = np.asarray(tokens, dtype=np.int64)[np.newaxis, :]
    return tokenizer.batch_decode(token_batch, skip_special_tokens=skip_special_tokens)[0]

In [7]:
!wget https://huggingface.co/datasets/osanseviero/dummy_ja_audio/resolve/main/result.flac

--2023-01-19 21:45:14--  https://huggingface.co/datasets/osanseviero/dummy_ja_audio/resolve/main/result.flac
Resolving huggingface.co (huggingface.co)... 3.231.67.228, 54.235.118.239, 2600:1f18:147f:e800:671:b733:ecf3:a585, ...
Connecting to huggingface.co (huggingface.co)|3.231.67.228|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/89/0e/890e934652e9c7b9c7c24080b0215138c3c32ab2ceb3f6416f8d1445e4a6adfb/afcdb9f2ca46039a983e43273addcc75bf83ca503452a4e7c908afc2e04dce61?response-content-disposition=attachment%3B%20filename%3D%22result.flac%22&Expires=1674423914&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzg5LzBlLzg5MGU5MzQ2NTJlOWM3YjljN2MyNDA4MGIwMjE1MTM4YzNjMzJhYjJjZWIzZjY0MTZmOGQxNDQ1ZTRhNmFkZmIvYWZjZGI5ZjJjYTQ2MDM5YTk4M2U0MzI3M2FkZGNjNzViZjgzY2E1MDM0NTJhNGU3YzkwOGFmYzJlMDRkY2U2MT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPWF0dGFjaG1lbnQlM0IlMjBmaWxlbmFtZSUzRCUyMnJlc3VsdC5mbGFjJ

In [9]:
import onnxruntime as ort
from transformers import (
    AutoTokenizer
)

model_id = "openai/whisper-tiny"

# Resolve the ONNX files relative to the working directory, i.e. the same
# exported_onnx_models/<model_id>/ location the export cell writes to, rather
# than an absolute /content/... path that only exists on Colab.
onnx_dir = os.path.join("exported_onnx_models", model_id)
encoder_model = os.path.join(onnx_dir, "encoder_model.onnx")
decoder_model = os.path.join(onnx_dir, "decoder_model.onnx")
encoder_ort_sess = ort.InferenceSession(encoder_model)
decoder_ort_sess = ort.InferenceSession(decoder_model)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Decoder prompt: <|startoftranscript|><|ja|><|translate|><|notimestamps|>
start_tokens = [50258, 50266, 50358, 50363]
# result.flac is fetched into the working directory by the wget cell.
text = transcribe(start_tokens, "result.flac", encoder_ort_sess, decoder_ort_sess, tokenizer, skip_special_tokens=True)
text

' I think Kimura-san is a good person.'