In [1]:
import sys
sys.path.append("../src")

In [2]:
from wav2vec2 import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Processer
from transformers import Wav2Vec2ForCTC as HFWav2Vec2ForCTC

import tensorflow as tf
import torch
import numpy as np

def get_difference(tf_out, hf_out):
    return np.max(tf_out.numpy() - hf_out.numpy())

In [3]:
batch, _ = tf.audio.decode_wav(tf.io.read_file("../data/sample.wav"))
processor = Wav2Vec2Processer(is_tokenizer=False)
batch = tf.transpose(batch, perm=(1, 0))

batch = tf.concat([batch, batch], axis=0)

batch = processor(batch)
hf_batch = torch.from_numpy(batch.numpy()).float()

batch, hf_batch

(<tf.Tensor: shape=(2, 46797), dtype=float32, numpy=
 array([[ 0.00455413, -0.00263517,  0.00814878, ..., -0.00263517,
         -0.01701376, -0.02779771],
        [ 0.00455413, -0.00263517,  0.00814878, ..., -0.00263517,
         -0.01701376, -0.02779771]], dtype=float32)>,
 tensor([[ 0.0046, -0.0026,  0.0081,  ..., -0.0026, -0.0170, -0.0278],
         [ 0.0046, -0.0026,  0.0081,  ..., -0.0026, -0.0170, -0.0278]]))

In [4]:
tf_model = Wav2Vec2ForCTC.from_pretrained("/Users/vasudevgupta/Local/wav2vec2/wav2vec2-base-960h", input_shape=batch.shape)

Loading weights locally from `/Users/vasudevgupta/Local/wav2vec2/wav2vec2-base-960h`
Total number of loaded variables: 212


In [5]:
hf_model = HFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
tf_out = tf_model(batch, training=False)["logits"]

In [7]:
with torch.no_grad():
    hf_out = hf_model(hf_batch)["logits"]

In [8]:
tf_out.shape, hf_out.shape

(TensorShape([2, 145, 32]), torch.Size([2, 145, 32]))

In [9]:
print("difference in logits:", get_difference(tf_out, hf_out))

difference in logits: 0.003186226


In [10]:
def tf_forward(*args, **kwargs):
    return tf_model(*args, **kwargs)
tf_forward = tf.function(tf_forward, autograph=True, jit_compile=True)

In [16]:
tf_out = tf_forward(batch, training=False)["logits"]

In [17]:
print("difference in graph based model logits:", get_difference(tf_out, hf_out))

difference in graph based model logits: 0.0023155212
