<a href="https://colab.research.google.com/github/vasudevgupta7/gsoc-wav2vec2/blob/patch-4/notebooks/librispeech-evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wav2Vec2 inference on LibriSpeech dataset

In this notebook, we evaluate TensorFlow Wav2Vec2 using the checkpoint fine-tuned on 960 hours of the LibriSpeech dataset.

In [1]:
!nvidia-smi

Wed Jun 23 15:27:00 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

Let's start with the basic setup and install the `wav2vec2` package from this [repository](https://github.com/vasudevgupta7/gsoc-wav2vec2).

In [2]:
!git clone https://github.com/vasudevgupta7/gsoc-wav2vec2 --branch=main && cd gsoc-wav2vec2 && pip3 install .

import os
os.chdir("./gsoc-wav2vec2/src")

Cloning into 'gsoc-wav2vec2'...
remote: Enumerating objects: 430, done.
remote: Counting objects: 100% (430/430), done.
remote: Compressing objects: 100% (248/248), done.
remote: Total 430 (delta 249), reused 343 (delta 173), pack-reused 0
Receiving objects: 100% (430/430), 380.84 KiB | 2.00 MiB/s, done.
Resolving deltas: 100% (249/249), done.
Processing /content/gsoc-wav2vec2
Collecting wandb
  Downloading https://files.pythonhosted.org/packages/e0/b4/9d92953d8cddc8450c859be12e3dbdd4c7754fb8def94c28b3b351c6ee4e/wandb-0.10.32-py2.py3-none-any.whl (1.8MB)
     |████████████████████████████████| 1.8MB 8.6MB/s
Collecting huggingface-hub
  Downloading https://files.pythonhosted.org/packages/45/94/27f4f66d8d763f60204f447287cbe78d8bdf9c86d87dbc1fe26e792e727a/huggingface_hub-0.0.11-py3-none-any.whl
Collecting tensorflow_io
  Downloading https://files.pythonhosted.org/packages/e6/d2/6fd39a3519e325037462721092248b468ccbeeeb5dc870cea072655ee4b0/tensorflow_io-0.1

Now that we have installed the required packages, let's download the `test-clean` and `test-other` evaluation sets from the official LibriSpeech [website](https://www.openslr.org/12). This may take a little while, depending on your internet connection.

In [3]:
!wget https://www.openslr.org/resources/12/test-clean.tar.gz -P /content/gsoc-wav2vec2/data && tar -xf /content/gsoc-wav2vec2/data/test-clean.tar.gz -C /content/gsoc-wav2vec2/data/
!wget https://www.openslr.org/resources/12/test-other.tar.gz -P /content/gsoc-wav2vec2/data && tar -xf /content/gsoc-wav2vec2/data/test-other.tar.gz -C /content/gsoc-wav2vec2/data/

--2021-06-23 15:27:19--  https://www.openslr.org/resources/12/test-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 346663984 (331M) [application/x-gzip]
Saving to: ‘/content/gsoc-wav2vec2/data/test-clean.tar.gz’


2021-06-23 15:27:38 (17.9 MB/s) - ‘/content/gsoc-wav2vec2/data/test-clean.tar.gz’ saved [346663984/346663984]

--2021-06-23 15:27:41--  https://www.openslr.org/resources/12/test-other.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 328757843 (314M) [application/x-gzip]
Saving to: ‘/content/gsoc-wav2vec2/data/test-other.tar.gz’


2021-06-23 15:27:58 (18.8 MB/s) - ‘/content/gsoc-wav2vec2/data/test-other.tar.gz’ saved [328757843/328757843]
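After extraction, each split lives under `data/LibriSpeech/<split>/<speaker>/<chapter>/`, where every chapter folder holds one `.flac` file per utterance plus a `<speaker>-<chapter>.trans.txt` file whose lines are `<utterance_id> <TRANSCRIPT>`. The helper below is a minimal sketch of pairing audio paths with their transcripts; `load_librispeech_transcripts` is a hypothetical name (not part of the `wav2vec2` package), and `LibriSpeechDataLoader` may do this differently. The tiny split it builds uses made-up IDs and transcripts.

```python
import os
import tempfile


def load_librispeech_transcripts(split_dir):
    """Hypothetical helper: map utterance id -> (flac_path, transcript).

    Assumes the standard LibriSpeech layout <split_dir>/<speaker>/<chapter>/
    with one <speaker>-<chapter>.trans.txt file per chapter folder.
    """
    samples = {}
    for root, _, files in os.walk(split_dir):
        for name in files:
            if not name.endswith(".trans.txt"):
                continue
            with open(os.path.join(root, name), encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    # Each line is "<utterance_id> <UPPERCASE TRANSCRIPT>"
                    utt_id, text = line.split(" ", 1)
                    samples[utt_id] = (os.path.join(root, utt_id + ".flac"), text)
    return samples


# Build a tiny fake split (made-up ids and text) just to exercise the parser.
with tempfile.TemporaryDirectory() as tmp:
    chapter_dir = os.path.join(tmp, "1234", "5678")
    os.makedirs(chapter_dir)
    with open(os.path.join(chapter_dir, "1234-5678.trans.txt"), "w", encoding="utf-8") as f:
        f.write("1234-5678-0000 HELLO WORLD\n1234-5678-0001 GOOD MORNING\n")
    demo = load_librispeech_transcripts(tmp)

print(sorted(demo))  # ['1234-5678-0000', '1234-5678-0001']
```

Pointing it at `/content/gsoc-wav2vec2/data/LibriSpeech/test-clean` should give one entry per utterance in that split.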



Let's import `Wav2Vec2Processor` and `Wav2Vec2ForCTC` from our installed `wav2vec2` package.

In [4]:
import tensorflow as tf
from wav2vec2 import Wav2Vec2Processor, Wav2Vec2ForCTC

Now we instantiate the processor and the tokenizer with their default configurations, and load the model with the convenient `.from_pretrained(...)` method, which downloads the fine-tuned weights automatically from the Hugging Face Hub.

In [5]:
model_id = "vasudevgupta/tf-wav2vec2-base-960h"

processor = Wav2Vec2Processor(is_tokenizer=False)
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

Downloading model weights from `https://huggingface.co/vasudevgupta/tf-wav2vec2-base-960h` ... Done
Total number of loaded variables: 213


`processor` converts raw speech into the format expected by our `Wav2Vec2ForCTC` model, e.g., by normalizing the speech along the frames axis.
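As a rough illustration, the sketch below applies zero-mean, unit-variance normalization along the last (frames) axis with NumPy. This is an assumption about the kind of normalization `processor` performs; `normalize_speech` and the `eps` value are hypothetical and not taken from the `wav2vec2` package.

```python
import numpy as np


def normalize_speech(speech, eps=1e-7):
    """Hypothetical sketch: zero-mean, unit-variance along the frames axis.

    `eps` guards against division by zero for silent (constant) inputs.
    """
    speech = np.asarray(speech, dtype=np.float32)
    mean = speech.mean(axis=-1, keepdims=True)
    var = speech.var(axis=-1, keepdims=True)
    return (speech - mean) / np.sqrt(var + eps)


raw = np.array([0.0, 0.5, 1.0, 1.5, 2.0])
normed = normalize_speech(raw)
print(normed.round(3))  # mean ~0, std ~1
```

Because the statistics are computed per waveform (per row in a batch), clips recorded at different volumes end up on a comparable scale.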

`tokenizer` converts the model's output token ids into strings and removes special tokens (depending on your tokenizer configuration).
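Later on, `tokenizer.decode(...)` is called with `group_tokens=True` for model predictions and `group_tokens=False` for labels. To show what CTC-style greedy decoding typically involves, here is a minimal pure-Python sketch with a hypothetical five-token vocabulary: merge consecutive repeated ids, drop the blank/padding id, and map the `|` word delimiter to a space. `VOCAB`, `BLANK_ID`, and `ctc_greedy_decode` are illustrative names, not the package's API, and the real vocabulary may differ.

```python
from itertools import groupby

# Hypothetical toy vocabulary: id 0 is the CTC blank / padding token and
# "|" marks word boundaries. The real tokenizer's vocabulary may differ.
VOCAB = ["<pad>", "|", "A", "C", "T"]
BLANK_ID = 0


def ctc_greedy_decode(token_ids, group_tokens=True):
    """Turn per-frame token ids into a string.

    group_tokens=True merges consecutive repeats first (needed for CTC
    outputs); blank/padding ids are then dropped and "|" becomes a space.
    """
    if group_tokens:
        token_ids = [tok for tok, _ in groupby(token_ids)]
    chars = [VOCAB[tok] for tok in token_ids if tok != BLANK_ID]
    return "".join(chars).replace("|", " ").strip()


# Frame-level argmax ids: a blank between the two 3s keeps a doubled "C".
frames = [3, 3, 0, 2, 2, 4, 1, 1, 0, 3, 0, 3, 2, 4]
print(ctc_greedy_decode(frames))  # CAT CCAT

# Labels are not frame-aligned, so repeats must NOT be merged; padding is dropped.
print(ctc_greedy_decode([3, 2, 4, 1, 3, 2, 4, 0, 0], group_tokens=False))  # CAT CAT
```

The blank token is what lets CTC emit genuinely repeated characters (as in "HELLO"): without a blank between them, the grouping step would merge the repeats.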

For good out-of-the-box performance with TensorFlow 2, we decorate our forward pass with `tf.function(...)`. The argument `jit_compile=True` compiles the traced function with **XLA**, which fuses operations to generate very efficient code for accelerators.

In [11]:
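@tf.function(jit_compile=True)
def tf_forward(speech, training=False):
  # logits: (batch_size, num_frames, vocab_size)
  tf_out = model(speech, training=training)
  # Greedy decoding: most likely token id per frame -> (batch_size, num_frames).
  # The batch dimension is kept even for a batch of size 1, so callers can iterate over it.
  return tf.argmax(tf_out, axis=-1)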

Next, we write a function that iterates over the complete evaluation dataset and collects the predictions and labels from each step in Python lists.

In [23]:
from data_utils import LibriSpeechDataLoader, LibriSpeechDataLoaderArgs
from tqdm.auto import tqdm

def infer_librispeech(dataset: tf.data.Dataset):
  predictions = []
  labels = []
  for batch in tqdm(dataset, total=len(dataset), desc="LibriSpeech Inference ... "):
    speech, label = batch
    tf_out = tf_forward(speech, training=False)
    predictions.extend([tokenizer.decode(pred, group_tokens=True) for pred in tf_out.numpy().tolist()])
    labels.extend([tokenizer.decode(tgt, group_tokens=False) for tgt in label.numpy().tolist()])
  return predictions, labels

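Now we define the arguments for `LibriSpeechDataLoader`, which builds the `tf.data.Dataset` consumed by `infer_librispeech(...)`. To keep this demo short, only the first 2 batches of `test-clean` are used; remove the `dataset.take(2)` line to run inference on the complete test set.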

In [25]:
args = LibriSpeechDataLoaderArgs(data_dir="../data/LibriSpeech/test-clean", batch_size=32, audio_maxlen=500000, labels_maxlen=256)

dataset = LibriSpeechDataLoader(args)(seed=None)
dataset = dataset.take(2) # this will take 2 batches

DISCARDING 2 samples
LOADED 2618 FILES FROM ../data/LibriSpeech/test-clean
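`audio_maxlen` and `labels_maxlen` cap the length of each example. LibriSpeech audio is sampled at 16 kHz, so `audio_maxlen=500000` corresponds to at most 31.25 seconds of speech. A fixed maximum length also lets every batch be padded to the same static length, which suits the XLA-compiled `tf_forward` (new input shapes trigger retracing and recompilation). The hypothetical `pad_batch` helper below is a NumPy sketch of that idea. It assumes over-long clips are dropped; the `DISCARDING 2 samples` message above hints at this, but it is an assumption about how `LibriSpeechDataLoader` actually works.

```python
import numpy as np

SAMPLING_RATE = 16_000  # LibriSpeech audio is sampled at 16 kHz
AUDIO_MAXLEN = 500_000  # same value as audio_maxlen in the cell above


def pad_batch(waveforms, maxlen=AUDIO_MAXLEN, pad_value=0.0):
    """Hypothetical helper: drop clips longer than `maxlen`, right-pad the rest.

    Returns a float32 array of shape (num_kept, maxlen) and the number of
    discarded clips. Every batch gets the same static length, so an
    XLA-compiled forward pass is not recompiled for each new audio length.
    """
    kept = [np.asarray(w, dtype=np.float32) for w in waveforms if len(w) <= maxlen]
    batch = np.full((len(kept), maxlen), pad_value, dtype=np.float32)
    for i, wav in enumerate(kept):
        batch[i, : len(wav)] = wav
    return batch, len(waveforms) - len(kept)


print(AUDIO_MAXLEN / SAMPLING_RATE)  # 31.25 (seconds of audio at most)

# Toy usage with a tiny maxlen: the 7-sample clip is discarded.
batch, dropped = pad_batch([np.ones(3), np.ones(5), np.ones(7)], maxlen=5)
print(batch.shape, dropped)  # (2, 5) 1
```

Note that the batch dimension can still change (e.g. a smaller final batch), which may cost one extra compilation.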


The following cell takes ~7 minutes.

In [None]:
predictions, labels = infer_librispeech(dataset)
list(zip(predictions, labels))

Now we calculate the **Word Error Rate (WER)** to judge how well our model performs. We use the `load_metric(...)` function from the Hugging Face `datasets` library to set up the metric. First, let's install the `datasets` library using `pip`.

In [14]:
!pip3 install datasets

Collecting datasets
  Downloading https://files.pythonhosted.org/packages/08/a2/d4e1024c891506e1cee8f9d719d20831bac31cb5b7416983c4d2f65a6287/datasets-1.8.0-py3-none-any.whl (237kB)
     |████████████████████████████████| 245kB 7.5MB/s
Collecting fsspec
  Downloading https://files.pythonhosted.org/packages/0e/3a/666e63625a19883ae8e1674099e631f9737bd5478c4790e5ad49c5ac5261/fsspec-2021.6.1-py3-none-any.whl (115kB)
     |████████████████████████████████| 122kB 53.6MB/s
Collecting xxhash
  Downloading https://files.pythonhosted.org/packages/7d/4f/0a862cad26aa2ed7a7cd87178cbbfa824fc1383e472d63596a0d018374e7/xxhash-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (243kB)
     |████████████████████████████████| 245kB 47.9MB/s
Installing collected packages: fsspec, xxhash, datasets
Successfully installed datasets-1.8.0 fsspec-2021.6.1 xxhash-2.0.2


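Let's load the WER metric script using `load_metric("wer")` and compute the metric over our predictions. The WER script relies on the `jiwer` package, so you may need to run `pip3 install jiwer` first.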

In [None]:
from datasets import load_metric

wer = load_metric("wer")
wer.compute(references=labels, predictions=predictions)
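WER counts the word substitutions $S$, deletions $D$, and insertions $I$ needed to turn the reference into the prediction, divided by the number of reference words $N$: $\mathrm{WER} = (S + D + I) / N$. The sketch below computes this with a word-level Levenshtein distance, aggregated over the whole corpus (total edits over total reference words). It is a from-scratch illustration, not the `jiwer`-backed implementation that `load_metric("wer")` uses, so results may differ slightly, for example in text normalization. The example sentences are made up.

```python
def word_edit_distance(ref_words, hyp_words):
    """Minimum number of word substitutions, deletions and insertions."""
    # prev[j] = distance between the first i-1 reference words and first j hypothesis words
    prev = list(range(len(hyp_words) + 1))
    for i, r in enumerate(ref_words, start=1):
        curr = [i] + [0] * len(hyp_words)
        for j, h in enumerate(hyp_words, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of reference word r
                curr[j - 1] + 1,         # insertion of hypothesis word h
                prev[j - 1] + (r != h),  # substitution (cost 0 if the words match)
            )
        prev = curr
    return prev[-1]


def corpus_wer(references, predictions):
    """Total word edits divided by the total number of reference words."""
    errors = sum(word_edit_distance(r.split(), p.split()) for r, p in zip(references, predictions))
    total = sum(len(r.split()) for r in references)
    return errors / total


refs = ["THE CAT SAT ON THE MAT", "HELLO WORLD"]
hyps = ["THE CAT SAT ON MAT", "HELLO WORD"]
print(corpus_wer(refs, hyps))  # 0.25 -> 1 deletion + 1 substitution over 8 reference words
```

Aggregating over the corpus (rather than averaging per-sentence WERs) weights long utterances more heavily, which is the usual convention for reporting LibriSpeech results.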