<a href="https://colab.research.google.com/github/vasudevgupta7/gsoc-wav2vec2/blob/deploy/notebooks/wav2vec2_onnx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wav2Vec2 ONNX

In this notebook, we export the TensorFlow Wav2Vec2 model to ONNX and compare the CPU latency of the exported ONNX model against the original TensorFlow model.

In [13]:
!pip3 install -qU tf2onnx onnxruntime

# TODO: switch from the training-v2 branch to main once training-v2 is merged
!pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@training-v2

In [14]:
import tensorflow as tf

In [15]:
from wav2vec2 import Wav2Vec2ForCTC

model_id = "vasudevgupta/gsoc-wav2vec2-base-960h"
model = Wav2Vec2ForCTC.from_pretrained(model_id)

@tf.function(jit_compile=True)
def jitted_forward(speech):
    return model(speech)

Downloading model weights from https://huggingface.co/vasudevgupta/gsoc-wav2vec2-base-960h ... Done
Total number of loaded variables: 213


In [16]:
from contextlib import contextmanager
import time

@contextmanager
def timeit(prefix="Time taken:"):
  # perf_counter is a monotonic, high-resolution clock meant for measuring intervals
  start = time.perf_counter()
  yield
  time_taken = time.perf_counter() - start
  print(prefix, time_taken, "seconds")

In [17]:
!wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/sample.wav

--2021-08-08 01:53:26--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/sample.wav
Resolving github.com (github.com)... 140.82.114.4
Connecting to github.com (github.com)|140.82.114.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/sample.wav [following]
--2021-08-08 01:53:26--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/sample.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 93638 (91K) [audio/wav]
Saving to: ‘sample.wav’


2021-08-08 01:53:26 (6.41 MB/s) - ‘sample.wav’ saved [93638/93638]



In [18]:
import soundfile as sf
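AUDIO_MAXLEN = 246000  # maximum number of samples kept per clip

# sf.read returns (samples, sample_rate); the sample rate is discarded here
speech, _ = sf.read("sample.wav")
# Add a batch dimension and truncate to AUDIO_MAXLEN samples (shorter clips are not padded)
speech = tf.constant(speech, dtype=tf.float32)[None, :AUDIO_MAXLEN]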
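
The slice above only truncates, so a clip shorter than `AUDIO_MAXLEN` keeps its own length, and that length is what ends up in the export signature. If inputs of varying length need to share one fixed-length ONNX graph, one option is to pad or truncate every waveform to the same number of samples first. The helper below is a minimal sketch of that idea: `pad_or_truncate` is a hypothetical name that is not part of this notebook or the `wav2vec2` package, and zero-padding is an assumption whose effect on the model's predictions is not evaluated here.

```python
import numpy as np


def pad_or_truncate(waveform, target_len, pad_value=0.0):
    """Return a float32 array of shape (1, target_len).

    Hypothetical helper: longer inputs are cut at target_len, shorter
    inputs are right-padded with pad_value (zero-padding is an assumption).
    """
    waveform = np.asarray(waveform, dtype=np.float32)
    if waveform.ndim != 1:
        raise ValueError("expected a mono, 1-D waveform")

    clipped = waveform[:target_len]
    padded = np.full((target_len,), pad_value, dtype=np.float32)
    padded[: clipped.shape[0]] = clipped

    # Add the leading batch dimension expected by the model input.
    return padded[None, :]
```

With this, `pad_or_truncate(samples, AUDIO_MAXLEN)` could be wrapped in `tf.constant(...)` in place of the slicing above, at the cost of always running the model on `AUDIO_MAXLEN` samples.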

In [19]:
import tf2onnx
ONNX_PATH = "onnx-wav2vec2.onnx"

# Batch dimension is dynamic (None); the time dimension is fixed to this sample's length
input_signature = (tf.TensorSpec((None, speech.shape[1]), tf.float32, name="speech"),)
_ = tf2onnx.convert.from_keras(model, input_signature=input_signature, output_path=ONNX_PATH)

Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`


In [20]:
import onnxruntime as rt
session = rt.InferenceSession(ONNX_PATH)

In [21]:
import numpy as np

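# session.run returns a list with one array per model output
onnx_outputs = session.run(None, {"speech": speech.numpy()})
tf_outputs = jitted_forward(speech)

# Compare the first ONNX output against the TF output
np.allclose(onnx_outputs[0], tf_outputs.numpy(), atol=1e-2)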

True
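
A single boolean does not show how close the two outputs actually are. As a rough sketch, the helper below reports the largest and mean absolute difference between two arrays. `compare_outputs` is a hypothetical name introduced for illustration, and its `within_atol` flag uses a purely absolute threshold, which is slightly stricter than `np.allclose` (which also adds a relative tolerance term).

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-2):
    """Summarize the element-wise gap between two same-shaped arrays.

    Hypothetical helper; within_atol checks max |reference - candidate| <= atol
    with no relative tolerance, unlike np.allclose.
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")

    abs_diff = np.abs(reference - candidate)
    max_abs_diff = float(abs_diff.max())
    return {
        "max_abs_diff": max_abs_diff,
        "mean_abs_diff": float(abs_diff.mean()),
        "within_atol": bool(max_abs_diff <= atol),
    }
```

In this notebook it could be called as `compare_outputs(tf_outputs.numpy(), onnx_outputs[0])`.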

In [24]:
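# jitted_forward was already called once above with the same input,
# so this measures a compiled call rather than tracing and compilation
with timeit(prefix="JIT Compiled Wav2vec2 time taken:"):
  jitted_forward(speech)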

with timeit(prefix="Eager mode time taken:"):
  model(speech)

with timeit(prefix="ONNX-Wav2Vec2 time taken:"):
  session.run(None, {"speech": speech.numpy()})

JIT Compiled Wav2vec2 time taken: 1.9526395797729492 seconds
Eager mode time taken: 1.234731674194336 seconds
ONNX-Wav2Vec2 time taken: 0.795966386795044 seconds
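
Each figure above comes from a single call, so the numbers can shift noticeably between runs. A more stable comparison would discard a few warm-up calls and report a summary over several repeats. The following is a minimal sketch of that approach using only the standard library; `benchmark` and its `warmup` and `repeats` parameters are hypothetical names, not part of this notebook.

```python
import statistics
import time


def benchmark(fn, warmup=1, repeats=5):
    """Time fn() several times and return summary statistics in seconds.

    Hypothetical helper: warm-up calls run first and are not timed.
    """
    for _ in range(warmup):
        fn()

    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)

    return {
        "median": statistics.median(durations),
        "min": min(durations),
        "max": max(durations),
    }
```

For example, `benchmark(lambda: session.run(None, {"speech": speech.numpy()}))` and `benchmark(lambda: model(speech))` could then be compared on their medians rather than on one-off timings.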
