<a href="https://colab.research.google.com/github/vasudevgupta7/gsoc-wav2vec2/blob/export/notebooks/wav2vec2-base-saved-model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to train a TensorFlow SavedModel with an extra head

In this notebook, we load a pre-trained wav2vec2 model from TFHub and train it on the LibriSpeech dataset after appending one extra head on top of the pre-trained model.

## Setting Up

Before diving in, let's check which GPU we have been assigned using `nvidia-smi`.

In [1]:
!nvidia-smi

Sat Jul 17 12:41:30 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   53C    P0    28W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

The following cell clones the code repository and installs all the dependencies.

In [2]:
!git clone https://github.com/vasudevgupta7/gsoc-wav2vec2 --branch=export

import sys
import os

os.chdir("gsoc-wav2vec2")
sys.path.append("src")

!pip3 install -qe .

fatal: destination path 'gsoc-wav2vec2' already exists and is not an empty directory.


In [4]:
# This cell will be removed once the model is exported to TFHub
!wget https://huggingface.co/vasudevgupta/tf-wav2vec2-base/resolve/main/wav2vec2-base.tar.gz
!tar -xf wav2vec2-base.tar.gz

--2021-07-17 12:41:35--  https://huggingface.co/vasudevgupta/tf-wav2vec2-base/resolve/main/wav2vec2-base.tar.gz
Resolving huggingface.co (huggingface.co)... 15.197.130.34
Connecting to huggingface.co (huggingface.co)|15.197.130.34|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/vasudevgupta/tf-wav2vec2-base/ba29ac5ff1f78271a6c9e6466cedd221e811b5ed58020337d238bc14512de9f3 [following]
--2021-07-17 12:41:35--  https://cdn-lfs.huggingface.co/vasudevgupta/tf-wav2vec2-base/ba29ac5ff1f78271a6c9e6466cedd221e811b5ed58020337d238bc14512de9f3
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 52.85.144.69, 52.85.144.70, 52.85.144.56, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|52.85.144.69|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 387426816 (369M) [application/octet-stream]
Saving to: ‘wav2vec2-base.tar.gz.1’


2021-07-17 12:41:44 (39.9 MB/s) - ‘wav2vec2-base.tar.gz.1’ save

In [5]:
import tensorflow as tf
import tensorflow_hub as hub

from wav2vec2 import Wav2Vec2Config, CTCLoss

config = Wav2Vec2Config()



In [6]:
# TODO: update it to load from TFHub later
loaded = hub.load("saved-model")
print("Available signatures are:", list(loaded.signatures.keys()))

Available signatures are: ['infer', 'train']


In [7]:
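# Wrap the "train" signature as a Keras layer; trainable=False keeps the backbone frozen
pretrained_model = loaded.signatures["train"]
pretrained_model = hub.KerasLayer(pretrained_model, trainable=False)

# Extra head: a dense layer with one output unit per vocabulary token
lm_head = tf.keras.layers.Dense(config.vocab_size)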

In [8]:
@tf.function(jit_compile=True)
def forward(batch):
    return lm_head(pretrained_model(batch)["output_0"])

In [9]:
BATCH_SIZE = 2
LEARNING_RATE = 5e-5
AUDIO_MAXLEN = 246000
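
To give a feel for what `AUDIO_MAXLEN = 246000` means, the sketch below converts a raw sample count into the number of frames the model would emit. It assumes the standard wav2vec 2.0 base feature encoder (seven 1-D convolutions with the kernel sizes and strides listed, no padding) and 16 kHz audio, which is LibriSpeech's sampling rate. `num_output_frames` is a hypothetical helper written for this illustration; it is not part of the repository.

```python
# Assumption: standard wav2vec 2.0 base feature-encoder layout (not read from this repo's config)
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)
SAMPLE_RATE = 16000  # LibriSpeech audio is sampled at 16 kHz


def num_output_frames(num_samples, kernels=CONV_KERNELS, strides=CONV_STRIDES):
    """Sequence length after a stack of unpadded 1-D convolutions."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        length = (length - kernel) // stride + 1
    return length


AUDIO_MAXLEN = 246000
print("seconds of audio:", AUDIO_MAXLEN / SAMPLE_RATE)
print("output frames:", num_output_frames(AUDIO_MAXLEN))
```

Under these assumptions, 246000 samples correspond to roughly 15.4 seconds of audio and 768 output frames, so each frame covers about 20 ms of speech.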

In [10]:
forward(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
print("Number of trainable variables:", len(list(pretrained_model.trainable_variables) + lm_head.trainable_variables))

Number of trainable variables: 2
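
The count of 2 above is what one would expect when the backbone is frozen: the only trainable variables left are the kernel and the bias of the dense head. The sketch below spells this out in plain Python. `HIDDEN_SIZE = 768` and `VOCAB_SIZE = 32` are assumed values typical of a base model with a character vocabulary; they are not read from `Wav2Vec2Config`, and `dense_variable_shapes` is a hypothetical helper.

```python
import math

# Assumed sizes for illustration only; check `config` for the real values
HIDDEN_SIZE = 768
VOCAB_SIZE = 32


def dense_variable_shapes(in_features, out_features):
    """Shapes of the variables held by a fully connected layer with a bias."""
    return {"kernel": (in_features, out_features), "bias": (out_features,)}


shapes = dense_variable_shapes(HIDDEN_SIZE, VOCAB_SIZE)
num_params = sum(math.prod(shape) for shape in shapes.values())
print("trainable variables:", len(shapes))
print("trainable parameters:", num_params)
```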


In [11]:
loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
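
For readers unfamiliar with CTC, here is a minimal pure-Python sketch of the quantity a CTC loss measures: the negative log-probability of the target sequence, summed over every frame-level alignment that collapses to it. This is an illustration of the general CTC forward algorithm, not the repository's `CTCLoss` implementation; it works on plain probability lists, ignores batching and `division_factor`, and the function name `ctc_neg_log_likelihood` is hypothetical.

```python
import math


def ctc_neg_log_likelihood(probs, labels, blank=0):
    """Negative log-likelihood of `labels` under CTC.

    probs:  list of T frames, each a list of per-token probabilities.
    labels: list of target token ids (without blanks).
    """
    # Interleave blanks around the labels: [a, b] -> [blank, a, blank, b, blank]
    ext = [blank]
    for token in labels:
        ext += [token, blank]
    size = len(ext)

    # alpha[s]: total probability of all partial alignments ending at ext[s]
    alpha = [0.0] * size
    alpha[0] = probs[0][ext[0]]
    if size > 1:
        alpha[1] = probs[0][ext[1]]

    for t in range(1, len(probs)):
        prev, alpha = alpha, [0.0] * size
        for s in range(size):
            total = prev[s]  # stay on the same symbol
            if s >= 1:
                total += prev[s - 1]  # advance by one symbol
            # Skipping the blank is only allowed between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += prev[s - 2]
            alpha[s] = total * probs[t][ext[s]]

    # A valid alignment ends on the last label or on the trailing blank
    likelihood = alpha[-1] + (alpha[-2] if size > 1 else 0.0)
    return -math.log(likelihood)


# Two frames, two tokens (0 = blank, 1 = "a"), uniform probabilities
uniform = [[0.5, 0.5], [0.5, 0.5]]
print(ctc_neg_log_likelihood(uniform, [1]))
```

In the toy call, three alignments ("aa", "a_", "_a") collapse to "a", each with probability 0.25, so the likelihood is 0.75.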

In [13]:
!wget https://huggingface.co/datasets/vasudevgupta/gsoc-librispeech/resolve/main/train-clean-100/train-clean-100-0.tfrecord -P /content/gsoc-wav2vec2/data/train/

--2021-07-17 12:42:20--  https://huggingface.co/datasets/vasudevgupta/gsoc-librispeech/resolve/main/train-clean-100/train-clean-100-0.tfrecord
Resolving huggingface.co (huggingface.co)... 15.197.130.34
Connecting to huggingface.co (huggingface.co)|15.197.130.34|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/datasets/vasudevgupta/gsoc-librispeech/df6dfb983f6514a98fc05d2ee219f57b1286589d61c1271bb10a0ed3effd6ae8 [following]
--2021-07-17 12:42:20--  https://cdn-lfs.huggingface.co/datasets/vasudevgupta/gsoc-librispeech/df6dfb983f6514a98fc05d2ee219f57b1286589d61c1271bb10a0ed3effd6ae8
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 52.85.132.4, 52.85.132.34, 52.85.132.50, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|52.85.132.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 483730930 (461M) [application/octet-stream]
Saving to: ‘/content/gsoc-wav2vec2/data/train/train-cl

In [14]:
ls /content/gsoc-wav2vec2/data/train/

train-clean-100-0.tfrecord  train-clean-100-0.tfrecord.1


In [15]:
ls

assets/      notebooks/        saved-model/  src/        wav2vec2-base.tar.gz
data/        readme.md         setup.cfg     tests/      wav2vec2-base.tar.gz.1
LICENSE.txt  requirements.txt  setup.py      vocab.json


In [16]:
from data_utils import LibriSpeechDataLoaderArgs, LibriSpeechDataLoader

data_args = LibriSpeechDataLoaderArgs(
    from_tfrecords=True,
    tfrecords=["data/train/train-clean-100-0.tfrecord"],
    audio_maxlen=AUDIO_MAXLEN,
    batch_size=BATCH_SIZE,
)
dataloader = LibriSpeechDataLoader(data_args)
dataset = dataloader(seed=None)

num_batches = 2
dataset = dataset.take(num_batches)

Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE
Reading tfrecords from ['data/train/train-clean-100-0.tfrecord'] ... Done!


In [12]:
@tf.function
def train_step(speech, labels):
    # Record the forward pass so the loss can be differentiated
    with tf.GradientTape() as gtape:
        logits = forward(speech)
        loss = loss_fn(labels, logits)
    # The backbone is frozen, so in practice only the head's variables get updated
    trainable_variables = list(pretrained_model.trainable_variables) + lm_head.trainable_variables
    grads = gtape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(grads, trainable_variables))
    return loss

In [None]:
from tqdm import tqdm

pbar = tqdm(dataset, total=num_batches)
for speech, label in pbar:
    loss = train_step(speech, label)
    pbar.set_postfix(tr_loss=loss)

In [20]:
@tf.function
def eval_step(speech, labels):
    # Forward pass only: no gradient tape and no optimizer update
    logits = forward(speech)
    loss = loss_fn(labels, logits)
    return loss

In [None]:
pbar = tqdm(dataset, total=num_batches)
for speech, label in pbar:
    loss = eval_step(speech, label)
    pbar.set_postfix(val_loss=loss)