<a href="https://colab.research.google.com/github/vasudevgupta7/gsoc-wav2vec2/blob/export-v2/wav2vec2_base_saved_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to train a TensorFlow SavedModel with an extra head

In this notebook, we will load the pre-trained wav2vec2 model from [TFHub](https://tfhub.dev) and train it on the [LibriSpeech dataset](https://huggingface.co/datasets/librispeech_asr) by appending one extra head on top of the pre-trained model.

## Setting Up

Before diving in, let's check which GPU we have been assigned using `nvidia-smi`.

In [None]:
!nvidia-smi

Sat Jul 17 12:41:30 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   53C    P0    28W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               

The following cell clones my code repository ([`gsoc-wav2vec2`](https://github.com/vasudevgupta7/gsoc-wav2vec2)) and installs its dependencies.

In [None]:
!git clone https://github.com/vasudevgupta7/gsoc-wav2vec2 --branch=export

import sys
import os

os.chdir("gsoc-wav2vec2")
sys.path.append("src")

!pip3 install -qe .

fatal: destination path 'gsoc-wav2vec2' already exists and is not an empty directory.


In [None]:
# This cell will be removed once the model is exported to TFHub
!wget https://huggingface.co/vasudevgupta/tf-wav2vec2-base/resolve/main/wav2vec2-base.tar.gz
!tar -xf wav2vec2-base.tar.gz

--2021-07-17 12:41:35--  https://huggingface.co/vasudevgupta/tf-wav2vec2-base/resolve/main/wav2vec2-base.tar.gz
Resolving huggingface.co (huggingface.co)... 15.197.130.34
Connecting to huggingface.co (huggingface.co)|15.197.130.34|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/vasudevgupta/tf-wav2vec2-base/ba29ac5ff1f78271a6c9e6466cedd221e811b5ed58020337d238bc14512de9f3 [following]
--2021-07-17 12:41:35--  https://cdn-lfs.huggingface.co/vasudevgupta/tf-wav2vec2-base/ba29ac5ff1f78271a6c9e6466cedd221e811b5ed58020337d238bc14512de9f3
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 52.85.144.69, 52.85.144.70, 52.85.144.56, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|52.85.144.69|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 387426816 (369M) [application/octet-stream]
Saving to: ‘wav2vec2-base.tar.gz.1’


2021-07-17 12:41:44 (39.9 MB/s) - ‘wav2vec2-base.tar.gz.1’ save

## Model setup using `TFHub`

We will start by importing all the important libraries & modules.

In [None]:
import tensorflow as tf
import tensorflow_hub as hub

from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()



We will load the pre-trained SavedModel with [`hub.load(...)`](https://www.tensorflow.org/hub/api_docs/python/hub/load). When given a TFHub handle, it downloads the model first and then calls [`tf.saved_model.load(...)`](https://www.tensorflow.org/api_docs/python/tf/saved_model/load); for now, as the TODO in the cell notes, we point it at a local `saved-model` directory.

In [None]:
# TODO: update it to load from TFHub later
loaded = hub.load("saved-model")
print("Available signatures are:", list(loaded.signatures.keys()))

Available signatures are: ['infer', 'train']


We can see two signatures above because the model was saved with [`tf.saved_model.save(...)`](https://www.tensorflow.org/api_docs/python/tf/saved_model/save) using both an `infer` and a `train` signature (see this [script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/export2hub.py)). We will use the `train` signature to train the model on our downstream task, and wrap it with [`hub.KerasLayer`](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer) so that the pre-trained variables can be frozen. (Note: in this notebook we train only the extra head, not the pre-trained model.)
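
To make the idea of named signatures concrete, here is a minimal sketch that saves a toy `tf.Module` with an `infer` and a `train` signature and loads it back. `TinyModel`, its weights and the dropout rate are made up for illustration; the actual export script may be organised differently.

```python
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Toy stand-in for the exported model (hypothetical, illustration only)."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones((4, 2)), name="w")

    @tf.function(input_signature=[tf.TensorSpec((None, 4), tf.float32)])
    def infer(self, x):
        # Inference path: plain projection.
        return tf.matmul(x, self.w)

    @tf.function(input_signature=[tf.TensorSpec((None, 4), tf.float32)])
    def train(self, x):
        # Training path: same projection, with dropout applied to the input.
        return tf.matmul(tf.nn.dropout(x, rate=0.1), self.w)


model = TinyModel()
export_dir = tempfile.mkdtemp()

# Each entry in `signatures` becomes a named entry point of the SavedModel.
tf.saved_model.save(
    model, export_dir, signatures={"infer": model.infer, "train": model.train}
)

reloaded = tf.saved_model.load(export_dir)
signature_names = sorted(reloaded.signatures.keys())
outputs = reloaded.signatures["infer"](x=tf.ones((3, 4)))
print(signature_names, list(outputs.keys()))
```

A signature returns a dictionary of tensors; an unnamed single output is stored under the key `output_0`, which is why the forward pass in this notebook indexes the result with `["output_0"]`.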

`lm_head` is a simple dense layer with an output dimension of `vocab_size`, since we want a score (logit) for each token in the vocabulary at each time step; applying a softmax to these scores gives per-token probabilities.

In [None]:
pretrained_model = loaded.signatures["train"]
pretrained_model = hub.KerasLayer(pretrained_model, trainable=False)

lm_head = tf.keras.layers.Dense(config.vocab_size)
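
As a rough picture of what such a head computes, the sketch below reproduces a linear projection with a bias in NumPy and applies a softmax over the vocabulary axis. The sizes (768 hidden units, 32 tokens, 5 frames) are placeholders chosen for illustration, not values read from `config`.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, frames, hidden_size, vocab_size = 2, 5, 768, 32  # illustrative sizes

# Stand-in for the pre-trained model's output: one vector per time step.
hidden_states = rng.standard_normal((batch, frames, hidden_size))

# A dense layer owns exactly two arrays: a kernel and a bias.
kernel = rng.standard_normal((hidden_size, vocab_size)) * 0.02
bias = np.zeros(vocab_size)

# Unnormalized scores (logits) of shape (batch, frames, vocab_size).
logits = hidden_states @ kernel + bias


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


probs = softmax(logits)
print(logits.shape, probs.sum(axis=-1).round(3))
```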

Now let's define the forward pass by combining the pre-trained model and the LM head. Wrapping it with `tf.function(...)` lets TensorFlow trace it into a graph, which typically speeds up training.

Additionally, we pass `jit_compile=True` so that XLA compiles the graph for the accelerator (e.g. a GPU) and fuses many operations, which often improves performance with no further changes.

In [None]:
@tf.function(jit_compile=True)
def forward(batch):
    return lm_head(pretrained_model(batch)["output_0"])

## Setting training state

In the following cell, we define the hyperparameters used in this notebook. `AUDIO_MAXLEN` is intentionally set to `246000` because the model signature only accepts a static sequence length of `246000`.

In [None]:
BATCH_SIZE = 2
LEARNING_RATE = 5e-5
AUDIO_MAXLEN = 246000
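
For a sense of scale, the arithmetic below converts `AUDIO_MAXLEN` into seconds and into the number of frames a wav2vec2-style convolutional feature extractor would emit. The 16 kHz sampling rate and the kernel/stride values are assumptions taken from the published wav2vec 2.0 base architecture; they are not read from this notebook's `config`.

```python
AUDIO_MAXLEN = 246000
SAMPLE_RATE = 16000  # assumption: 16 kHz audio
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)  # assumption: wav2vec 2.0 base
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)   # assumption: wav2vec 2.0 base


def num_frames(num_samples):
    """Output length of a stack of unpadded 1-D convolutions."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length


duration_seconds = AUDIO_MAXLEN / SAMPLE_RATE
frames = num_frames(AUDIO_MAXLEN)
print(duration_seconds, frames)
```

Under these assumptions, each input is 15.375 seconds long and is reduced to 768 time steps before reaching the head.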

In TensorFlow, a model's weights are built only when `model.__call__` is called for the first time, so the following cell runs one forward pass on random input to build them.

Further, we will be checking the total number of trainable variables to assert that the pre-trained model is frozen.

In [None]:
forward(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
print("Number of trainable variables:", len(list(pretrained_model.trainable_variables) + lm_head.trainable_variables))

Number of trainable variables: 2


Now we need to define `loss_fn` and an optimizer to train the model, which the following cell does. We use the `Adam` optimizer for simplicity. `CTCLoss` is a common choice for tasks (like `ASR`) where the input frames can't easily be aligned with the output tokens. You can read more about CTC loss in this [blog post](https://distill.pub/2017/ctc/).
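
To illustrate how CTC relates frame-level predictions to a shorter label sequence, here is a minimal sketch of the collapsing rule: merge consecutive repeats, then drop the blank token. The function name and the choice of `0` as the blank id are hypothetical and may not match the vocabulary used by the package.

```python
def ctc_collapse(frame_ids, blank_id=0):
    """Map a per-frame id sequence to a label sequence (CTC collapsing rule)."""
    collapsed = []
    previous = None
    for token in frame_ids:
        # Keep a token only if it starts a new run and is not the blank.
        if token != previous and token != blank_id:
            collapsed.append(token)
        previous = token
    return collapsed


# A blank between two identical ids keeps them as separate labels.
frame_ids = [0, 3, 3, 0, 3, 5, 5, 0]
print(ctc_collapse(frame_ids))
```

Many different frame sequences collapse to the same labels; the CTC loss sums the probability over all of them, so no frame-level alignment has to be provided.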


`CTCLoss` (from the [`gsoc-wav2vec2`](https://github.com/vasudevgupta7/gsoc-wav2vec2) package) accepts 3 arguments: `config`, `model_input_shape` & `division_factor`. With `division_factor=1` the loss is simply summed over the batch, so pass `division_factor=BATCH_SIZE` to get the mean over the batch.

In [None]:
from wav2vec2 import CTCLoss

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
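
The effect of `division_factor` can be seen with plain arithmetic. The helper below is hypothetical and only mirrors the behaviour described above (sum the per-example losses, then divide); it is not the package's implementation, and the loss values are made up.

```python
def reduce_loss(per_example_losses, division_factor=1):
    """Sum the per-example losses and divide by `division_factor`."""
    return sum(per_example_losses) / division_factor


per_example_losses = [12.0, 8.0]  # made-up losses for a batch of 2

summed = reduce_loss(per_example_losses, division_factor=1)
mean_over_batch = reduce_loss(
    per_example_losses, division_factor=len(per_example_losses)
)
print(summed, mean_over_batch)
```

Dividing by the batch size keeps the loss scale, and therefore the effective step size, comparable when the batch size changes.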

## Setting up `tf.data.Dataset`


We need some data to train on. I have uploaded part of LibriSpeech as `tfrecords` to the HuggingFace Hub ([see this](https://huggingface.co/datasets/vasudevgupta/gsoc-librispeech/tree/main)), so we will download one of those `tfrecords` files using `wget`.

Note: While converting the [original LibriSpeech dataset](http://www.openslr.org/12) (i.e. `.flac` files) into `tfrecords`, some standard pre-processing is applied. Refer to [this script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/make_tfrecords.py) for details.

In [None]:
!wget https://huggingface.co/datasets/vasudevgupta/gsoc-librispeech/resolve/main/train-clean-100/train-clean-100-0.tfrecord -P /content/gsoc-wav2vec2/data/train/

--2021-07-17 12:42:20--  https://huggingface.co/datasets/vasudevgupta/gsoc-librispeech/resolve/main/train-clean-100/train-clean-100-0.tfrecord
Resolving huggingface.co (huggingface.co)... 15.197.130.34
Connecting to huggingface.co (huggingface.co)|15.197.130.34|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/datasets/vasudevgupta/gsoc-librispeech/df6dfb983f6514a98fc05d2ee219f57b1286589d61c1271bb10a0ed3effd6ae8 [following]
--2021-07-17 12:42:20--  https://cdn-lfs.huggingface.co/datasets/vasudevgupta/gsoc-librispeech/df6dfb983f6514a98fc05d2ee219f57b1286589d61c1271bb10a0ed3effd6ae8
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 52.85.132.4, 52.85.132.34, 52.85.132.50, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|52.85.132.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 483730930 (461M) [application/octet-stream]
Saving to: ‘/content/gsoc-wav2vec2/data/train/train-cl

In [None]:
ls /content/gsoc-wav2vec2/data/train/

train-clean-100-0.tfrecord  train-clean-100-0.tfrecord.1


The following cell sets up a `tf.data.Dataset` object using my `gsoc-wav2vec2` package.

In [None]:
from data_utils import LibriSpeechDataLoaderArgs, LibriSpeechDataLoader

data_args = LibriSpeechDataLoaderArgs(
    from_tfrecords=True,
    tfrecords=["data/train/train-clean-100-0.tfrecord"],
    audio_maxlen=AUDIO_MAXLEN,
    batch_size=BATCH_SIZE,
)
dataloader = LibriSpeechDataLoader(data_args)
dataset = dataloader(seed=None)

Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE
Reading tfrecords from ['data/train/train-clean-100-0.tfrecord'] ... Done!
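
The loader's internals are not shown in this notebook. As a rough sketch of what bringing every waveform to a fixed `audio_maxlen` could look like, the snippet below truncates long clips and zero-pads short ones on the right. Both the helper name and the right-padding-with-zeros choice are assumptions, not the package's confirmed behaviour.

```python
import numpy as np


def pad_or_truncate(waveform, maxlen, pad_value=0.0):
    """Return a 1-D float32 array of exactly `maxlen` samples."""
    waveform = np.asarray(waveform, dtype=np.float32)[:maxlen]
    padding = maxlen - waveform.shape[0]
    # Pad only on the right so the speech stays at the start of the clip.
    return np.pad(waveform, (0, padding), constant_values=pad_value)


short_clip = pad_or_truncate(np.ones(5), maxlen=8)
long_clip = pad_or_truncate(np.ones(12), maxlen=8)
print(short_clip, long_clip.shape)
```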


Since this notebook is for demonstration purposes, we will take only the first `num_batches` batches and train on those. You are encouraged to train on the whole dataset, though.

In [None]:
num_batches = 2
dataset = dataset.take(num_batches)
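
As a quick check of what `.take(...)` does, here is a small sketch on a toy dataset of integers (the toy dataset is an assumption made for illustration; it has nothing to do with the speech data).

```python
import tensorflow as tf

# Ten integers grouped into batches of two: [0, 1], [2, 3], ...
toy_dataset = tf.data.Dataset.range(10).batch(2)

# Keep only the first two batches, in order.
first_batches = [
    batch.tolist() for batch in toy_dataset.take(2).as_numpy_iterator()
]
print(first_batches)
```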

## Training & Evaluation

Let's define our `train_step` function now. There are 3 main steps in `train_step`: 
1. forward pass with variables tracking
2. backward pass for calculating gradients
3. variables update to minimize training loss

All trainable variables used inside the `tf.GradientTape(...)` context are tracked during the forward pass. Then `.gradient(...)` computes the gradient of the loss with respect to those tracked variables, and `.apply_gradients(...)` updates them using the `optimizer` defined above.

In [None]:
@tf.function
def train_step(speech, labels):
    # 1. forward pass, recorded by the tape
    with tf.GradientTape() as gtape:
        logits = forward(speech)
        loss = loss_fn(labels, logits)
    # 2. backward pass (the frozen pre-trained model contributes no trainable variables)
    trainable_variables = list(pretrained_model.trainable_variables) + lm_head.trainable_variables
    grads = gtape.gradient(loss, trainable_variables)
    # 3. variables update
    optimizer.apply_gradients(zip(grads, trainable_variables))
    return loss
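
The same three steps can be tried on a toy problem that runs in a few seconds. In this sketch a frozen `Dense` layer stands in for the pre-trained model and a second `Dense` layer for the head; the layer sizes, the data, the SGD optimizer and the squared-error loss are all made up for illustration.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Stand-in for the frozen pre-trained model (hypothetical).
frozen_body = tf.keras.layers.Dense(4, trainable=False)
# Stand-in for the trainable head (hypothetical).
toy_head = tf.keras.layers.Dense(1)
toy_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

features = tf.random.uniform((8, 3))
targets = tf.reduce_sum(features, axis=-1, keepdims=True)

# Call once so that both layers create their weights.
toy_head(frozen_body(features))


def toy_train_step(x, y):
    # 1. forward pass, recorded by the tape
    with tf.GradientTape() as tape:
        predictions = toy_head(frozen_body(x))
        loss = tf.reduce_mean(tf.square(predictions - y))
    # 2. backward pass, only for the head's kernel and bias
    variables = toy_head.trainable_variables
    grads = tape.gradient(loss, variables)
    # 3. variables update
    toy_optimizer.apply_gradients(zip(grads, variables))
    return loss


losses = [float(toy_train_step(features, targets)) for _ in range(20)]
print(len(frozen_body.trainable_variables), len(toy_head.trainable_variables))
print(losses[0], losses[-1])
```

Only the head's two variables receive gradients, so the frozen layer keeps its initial weights while the loss on the toy data goes down.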

Let's finally kick off training!

We will iterate over our dataset (instance of `tf.data.Dataset`) and each batch will be fed to `train_step(...)` for calculating loss, gradients & updating parameters.

In [None]:
from tqdm import tqdm

pbar = tqdm(dataset, total=num_batches)
for speech, label in pbar:
    loss = train_step(speech, label)
    pbar.set_postfix(tr_loss=loss.numpy())

Let's compute the loss over the validation dataset using `eval_step(...)`, defined in the following cell.

In [None]:
@tf.function
def eval_step(speech, labels):
    logits = forward(speech)
    loss = loss_fn(labels, logits)
    return loss

We are using the same dataset just for demonstration purposes. In general, we should use separate data (generally called `validation/dev` data) sampled before initiating training.

In [None]:
pbar = tqdm(dataset, total=num_batches)
for speech, label in pbar:
    loss = eval_step(speech, label)
    pbar.set_postfix(val_loss=loss.numpy())
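
One simple way to set validation data aside before training is to split the list of `tfrecords` shards. The sketch below does this with the standard library; the shard names, the 25% fraction and the helper itself are hypothetical.

```python
import random


def split_shards(shards, val_fraction=0.25, seed=0):
    """Shuffle shard paths deterministically and split into (train, val)."""
    shards = sorted(shards)
    random.Random(seed).shuffle(shards)
    num_val = max(1, int(len(shards) * val_fraction))
    return shards[num_val:], shards[:num_val]


# Hypothetical shard names, used only to demonstrate the split.
all_shards = [f"data/train/shard-{i}.tfrecord" for i in range(8)]
train_shards, val_shards = split_shards(all_shards)
print(len(train_shards), len(val_shards))
```

Each list could then be passed to its own data loader through the `tfrecords` argument, so that no validation example is ever seen during training.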

We have reached the end of this notebook, but not the end of learning TensorFlow for speech-related tasks: this [repository](https://github.com/vasudevgupta7/gsoc-wav2vec2) contains more tutorials. Feel free to go through them. If you encounter a bug, please create an issue [here](https://github.com/vasudevgupta7/gsoc-wav2vec2/issues).