<a href="https://colab.research.google.com/github/vasudevgupta7/gsoc-wav2vec2/blob/export-v3/notebooks/wav2vec2_saved_model_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning with an extra head

In this notebook, we will load the pre-trained wav2vec2 model from [TFHub](https://tfhub.dev) and fine-tune it on the [LibriSpeech dataset](https://huggingface.co/datasets/librispeech_asr) by appending a Language Modeling (LM) head on top of the pre-trained model. The underlying task is **Automatic Speech Recognition**, i.e. given some speech, the model should be able to transcribe it into text.

## Setting Up

Before diving in, let's see which GPU we got using `nvidia-smi`.

In [1]:
!nvidia-smi

Tue Jul 27 11:25:53 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   63C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               

The following cell will install the [`gsoc-wav2vec2`](https://github.com/vasudevgupta7/gsoc-wav2vec2) package from my code repository, along with all its dependencies.

In [2]:
!pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main

  Building wheel for wav2vec2 (setup.py) ... done
  Building wheel for python-Levenshtein (setup.py) ... done
  Building wheel for subprocess32 (setup.py) ... done
  Building wheel for pathtools (setup.py) ... done


## Model setup using `TFHub`

We will start by importing all the important libraries & modules.

In [4]:
import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)

TF version: 2.5.0


First, we will download our model from TFHub and wrap the model signature with [`hub.KerasLayer`](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer) so that we can use the model like any other Keras layer. Fortunately, `hub.KerasLayer` can do both in a single line.

**Note:** When a model is loaded with `hub.KerasLayer`, it becomes somewhat opaque. If you need finer control over the model, you can load it with `tf.keras.models.load_model(...)` instead.

In [5]:
# TODO: change after export to TFHub

# pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=False)
!wget https://huggingface.co/vasudevgupta/gsoc-wav2vec2/resolve/main/saved-model.tar.gz
!tar -xf saved-model.tar.gz

pretrained_layer = hub.KerasLayer("saved-model", trainable=True)

--2021-07-27 11:26:27--  https://huggingface.co/vasudevgupta/gsoc-wav2vec2/resolve/main/saved-model.tar.gz
Resolving huggingface.co (huggingface.co)... 15.197.130.34
Connecting to huggingface.co (huggingface.co)|15.197.130.34|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/vasudevgupta/gsoc-wav2vec2/a51f3029eaf96da1a4ed2de852a9168057b0605a312fbb6ae86324d892117c5e [following]
--2021-07-27 11:26:28--  https://cdn-lfs.huggingface.co/vasudevgupta/gsoc-wav2vec2/a51f3029eaf96da1a4ed2de852a9168057b0605a312fbb6ae86324d892117c5e
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 52.84.169.23, 52.84.169.2, 52.84.169.48, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|52.84.169.23|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 224928775 (215M) [application/x-gzip]
Saving to: ‘saved-model.tar.gz’


2021-07-27 11:26:36 (28.7 MB/s) - ‘saved-model.tar.gz’ saved [224928775/224928775]



You can refer to this [script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/export2hub.py) if you are interested in how the model was exported. The `pretrained_layer` object is the frozen version of [`Wav2Vec2Model`](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/wav2vec2/modeling.py). The pre-trained weights were converted from the HuggingFace PyTorch [pre-trained weights](https://huggingface.co/facebook/wav2vec2-base) using [this script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/convert_torch_to_tf.py).

Originally, wav2vec2 was pre-trained with a masked language modelling approach, with the objective of identifying the true quantized latent speech representation for a masked time step. You can read more about the training objective in the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477).

Now, we will define a few constants and hyper-parameters that will be useful in the next few cells. `AUDIO_MAXLEN` is intentionally set to `246000` because the model signature only accepts a static sequence length of `246000`.

In [6]:
AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

In the following cell, we will combine `pretrained_layer` and a dense layer (the LM head) into a single model using [TensorFlow's Functional API](https://www.tensorflow.org/guide/keras/functional).

In [7]:
inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

The dense layer defined above has an output dimension of `vocab_size` because we want to predict a score (logit) for each token in the vocabulary at each time step.

## Setting up training state

Alright, let's define our training forward pass by calling the model with `training=True` and wrapping the call in `tf.function(...)`. The wrapping matters because it lets TensorFlow run the forward pass as a graph, which is where the performance benefits during training come from.

Additionally, we pass `jit_compile=True` to compile the model graph with XLA on the accelerator (i.e. GPU/TPU) and fuse many operations, which usually improves performance out of the box.

In [8]:
@tf.function(jit_compile=True)
def forward(batch):
    return model(batch, training=True)

In TensorFlow, model weights are built only when `model.__call__` is called for the first time, so the following cell will build the model weights for us. We will then run `model.summary()` to check the total number of trainable parameters.

In [9]:
forward(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 246000)]          0         
_________________________________________________________________
keras_layer (KerasLayer)     (None, 768, 768)          94371712  
_________________________________________________________________
dense (Dense)                (None, 768, 32)           24608     
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________
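
The `768` time steps in the output shape above follow from the convolutional feature encoder that sits at the front of wav2vec2. As a rough sketch, assuming the kernel sizes and strides of the wav2vec2-base configuration described in the wav2vec 2.0 paper (these values are not defined anywhere in this notebook), the number of output frames for a given input length can be estimated like this:

```python
# Kernel sizes and strides of the seven 1-D convolutions in the wav2vec2-base
# feature encoder. These are assumed from the wav2vec 2.0 paper, not read from
# the model loaded in this notebook.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_output_frames(num_samples):
    """Return how many time steps remain after the convolutional stack."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # Unpadded 1-D convolution: floor((length - kernel) / stride) + 1
        length = (length - kernel) // stride + 1
    return length


print(num_output_frames(246000))  # 768
print(num_output_frames(16000))   # 49
```

Under these assumptions, `246000` input samples map to `768` frames, which matches the second dimension reported by `model.summary()`.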


Now, we need to define the `loss_fn` and the optimizer to be able to train the model. The following cell does that for us. We will use the `Adam` optimizer for simplicity. `CTCLoss` is a very common loss for tasks (like `ASR`) where input sub-parts can't be easily aligned with output sub-parts. You can read more about CTC loss in this amazing [blog post](https://distill.pub/2017/ctc/).


`CTCLoss` (from the [`gsoc-wav2vec2`](https://github.com/vasudevgupta7/gsoc-wav2vec2) package) accepts 3 arguments: `config`, `model_input_shape` and `division_factor`. If `division_factor=1`, the loss is simply summed over the batch, so pass `division_factor` accordingly (e.g. the batch size) to get the mean over the batch.
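
To make the role of `division_factor` concrete, here is a tiny plain-Python sketch. It assumes the loss is reduced as "sum of per-sample losses divided by `division_factor`", which is what the description above suggests; `reduce_losses` and the loss values are made up for illustration and are not part of the package.

```python
def reduce_losses(per_sample_losses, division_factor=1):
    """Sum the per-sample losses, then divide by `division_factor`."""
    return sum(per_sample_losses) / division_factor


per_sample = [12.0, 8.0]  # made-up CTC losses for a batch of 2

print(reduce_losses(per_sample))                     # 20.0 (plain sum)
print(reduce_losses(per_sample, division_factor=2))  # 10.0 (mean over the batch)
```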

In [10]:
from wav2vec2 import CTCLoss

LEARNING_RATE = 1e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

## Loading & Pre-processing data

Let's now download the LibriSpeech dataset from the [official website](http://www.openslr.org/12) and set it up.

In [11]:
!wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
!tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/

--2021-07-27 11:27:52--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’


2021-07-27 11:28:10 (18.8 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]



**Note:** We are using the `dev-clean` configuration because this notebook is only for demonstration purposes, so a small amount of data is enough.

In [12]:
ls ./data/train/

dev-clean.tar.gz  LibriSpeech/


Our dataset lies in the `LibriSpeech` directory. Let's narrow down further and pick a sub-directory to look at a few files.

In [13]:
data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)

Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0016.flac', '2428-83705-0022.flac', '2428-83705-0014.flac', '2428-83705-0015.flac', '2428-83705-0021.flac', '2428-83705-0032.flac', '2428-83705-0013.flac', '2428-83705-0010.flac', '2428-83705-0027.flac', '2428-83705-0033.flac', '2428-83705-0036.flac', '2428-83705-0003.flac', '2428-83705-0034.flac', '2428-83705-0009.flac', '2428-83705-0035.flac', '2428-83705-0023.flac', '2428-83705-0012.flac', '2428-83705-0041.flac', '2428-83705-0042.flac', '2428-83705-0007.flac', '2428-83705-0006.flac', '2428-83705-0024.flac', '2428-83705-0020.flac', '2428-83705-0030.flac', '2428-83705-0029.flac', '2428-83705-0025.flac', '2428-83705-0008.flac', '2428-83705-0019.flac', '2428-83705-0038.flac', '2428-83705-0028.flac', '2428-83705-0001.flac', '2428-83705-0037.flac', '2428-83705-0017.flac', '2428-83705-0000.flac', '2428-83705-0004.flac', '2428-83705-0026.flac', '2428-83705-0002.flac', '2428-83705-0031.flac', '2428-83705-0018.flac', '24

Alright, so each sub-directory has many `.flac` files and a single `.txt` file. The `.txt` file contains the text transcriptions for all the speech samples (i.e. `.flac` files) present in that sub-directory.

In the following cell, we will define a function for loading the text data into memory and formatting it.

In [14]:
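def read_txt_file(file_path):
  """Map each utterance id to its transcription text."""
  with open(file_path, "r") as f:
    samples = f.read().split("\n")
  # each line looks like `<file-id> <transcription>`; lines with fewer than 3 tokens are skipped
  samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 2}
  return samples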

Similarly, we will define a function for loading a speech sample from a `.flac` file.

`REQUIRED_SAMPLE_RATE` is set to `16000` because wav2vec2 was pre-trained on audio sampled at 16 kHz, and it's recommended to fine-tune it without any major change in the data distribution caused by a different sampling rate.

In [15]:
import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}

Now, let's have a look at a sample.

In [16]:
from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)

Text Transcription: HER FATHER IS A MOST REMARKABLE PERSON TO SAY THE LEAST 
Audio:


Now, we will combine all the speech and text samples. The function in the next cell does that.

In [17]:
def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(txt_samples[file_id], speech_samples[file_id]) for file_id in speech_samples.keys()]
  return samples

It's time to have a look at a few samples ...

In [18]:
samples = fetch_sound_text_mapping(data_dir)
samples[:5]

[('THERE WERE NO SIGNS OF FALTERING ABOUT HER FLOW OF LANGUAGE',
  array([-0.00036621, -0.00015259, -0.00012207, ..., -0.0005188 ,
         -0.00048828, -0.00048828])),
 ("A BIRD IN THE HAND IS WORTH TWO IN A BUSH' AND IT WILL BE SOMETHING TO HAVE BY US",
  array([0.00085449, 0.00073242, 0.0005188 , ..., 0.00048828, 0.00054932,
         0.0005188 ])),
 ('THERE SHE OWNS A COTTAGE OR IT MAY BE A PIGSTYE FOR ALL I KNOW',
  array([-0.00027466, -0.00033569, -0.00036621, ...,  0.00021362,
          0.        ,  0.        ])),
 ('WHEN SHE HEARD OF MY ENGAGEMENT WITH MARY ANN SHE WROTE AND SUGGESTED THAT WE SHOULD SPEND OUR HONEYMOON IN HER COTTAGE OR PIGSTYE AND THAT I SHOULD PAY HER RENT FOR IT',
  array([ 6.10351562e-05,  9.15527344e-05,  9.15527344e-05, ...,
         -3.05175781e-04, -5.79833984e-04, -8.23974609e-04])),
 ('I SHALL MAKE PAPA GIVE ME FIVE HUNDRED POUNDS AT LEAST',
  array([0.00036621, 0.00027466, 0.00015259, ..., 0.00039673, 0.00048828,
         0.00079346]))]

Let's pre-process the data now!

We will first define the tokenizer and processor using the `gsoc-wav2vec2` package. Then, we will do some very simple pre-processing: speech is normalized along the time axis using `processor`, and text is tokenized using `tokenizer`.

In [19]:
from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))

Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE
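
For intuition about what "normalized along the time axis" means, here is a minimal plain-Python sketch. It assumes zero-mean, unit-variance normalization with a small epsilon, which is the usual choice for wav2vec2 inputs; the exact formula and epsilon used by `Wav2Vec2Processor` are not shown in this notebook, so treat `normalize` and `eps` as illustrative.

```python
import math


def normalize(audio, eps=1e-5):
    """Scale a waveform to zero mean and (roughly) unit variance."""
    mean = sum(audio) / len(audio)
    variance = sum((x - mean) ** 2 for x in audio) / len(audio)
    # eps keeps the division stable for near-silent audio
    scale = math.sqrt(variance + eps)
    return [(x - mean) / scale for x in audio]


normalized = normalize([0.1, -0.2, 0.3, -0.4])
print(normalized)
```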


Now, we will define a Python generator that calls the pre-processing functions defined in the cells above.

In [20]:
def inputs_generator(samples):
  for text, speech in samples:
    yield preprocess_text(text), preprocess_speech(speech)

In [21]:
from functools import partial
generator = partial(inputs_generator, samples=samples)
next(iter(generator()))

(<tf.Tensor: shape=(59,), dtype=int32, numpy=
 array([ 6, 11,  5, 13,  5,  4, 18,  5, 13,  5,  4,  9,  8,  4, 12, 10, 21,
         9, 12,  4,  8, 20,  4, 20,  7, 15,  6,  5, 13, 10,  9, 21,  4,  7,
        24,  8, 16,  6,  4, 11,  5, 13,  4, 20, 15,  8, 18,  4,  8, 20,  4,
        15,  7,  9, 21, 16,  7, 21,  5], dtype=int32)>,
 <tf.Tensor: shape=(58240,), dtype=float32, numpy=
 array([-0.00590616, -0.00194895, -0.00138363, ..., -0.00873275,
        -0.00816743, -0.00816743], dtype=float32)>)

## Setting up `tf.data.Dataset`

The following cell sets up a `tf.data.Dataset` object using its `.from_generator(...)` method. We will use the `generator` object defined in the cell above.

**Note:** For distributed training (especially on TPUs), `.from_generator(...)` doesn't work currently and it is recommended to train on data stored in `.tfrecord` format. The TFRecords should ideally be stored inside a GCS Bucket in order for the TPUs to work to the fullest extent.

You can refer to [this script](https://github.com/vasudevgupta7/gsoc-wav2vec2/blob/main/src/make_tfrecords.py) for more details on how to convert LibriSpeech data into tfrecords.

In [22]:
output_signature = (
    tf.TensorSpec(shape=(None,), dtype=tf.int32),  # 1-D token ids of variable length
    tf.TensorSpec(shape=(None,), dtype=tf.float32),  # 1-D waveform of variable length
)
dataset = tf.data.Dataset.from_generator(generator, output_signature=output_signature)

Let's shuffle the dataset using the `.shuffle(...)` method. Elements are drawn at random from a buffer of `buffer_size` elements, so the shuffling is only approximate whenever the buffer is smaller than the dataset. This trade-off exists because the complete dataset often can't fit into memory for a true shuffle (e.g. the complete LibriSpeech TFRecords take around 250 GB on disk).

In [23]:
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)
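
The idea behind buffer-based shuffling can be sketched in plain Python. This is only an illustration of the general technique, not the actual `tf.data` implementation; `buffered_shuffle` is a hypothetical helper.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield `items` in a random order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Input exhausted: drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(10), buffer_size=3, seed=42))
print(shuffled)
```

With a small buffer, an element can only move a limited distance towards the front of the stream, which is why the result is an approximate shuffle. With `buffer_size` at least as large as the dataset (as in the cell above, where it equals the number of files), every ordering is possible.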

We will feed the dataset to the model in batches, so let's prepare the batches in the following cell. All the sequences in a batch must be padded to a constant length, and we will use the `.padded_batch(...)` method for that. We also need to truncate sequences to a maximum length, as some of them are very long.

In [24]:
dataset = dataset.map(lambda labels, speech: (labels[: LABEL_MAXLEN], speech[: AUDIO_MAXLEN]))
dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(LABEL_MAXLEN, AUDIO_MAXLEN), padding_values=(0, 0.))
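
The combined effect of the truncation in `.map(...)` and the padding in `.padded_batch(...)` on a single sequence can be sketched in plain Python. `pad_or_truncate` is a hypothetical helper used only for illustration.

```python
def pad_or_truncate(sequence, max_length, pad_value=0):
    """Cut `sequence` to `max_length`, then right-pad it with `pad_value`."""
    trimmed = list(sequence[:max_length])
    return trimmed + [pad_value] * (max_length - len(trimmed))


print(pad_or_truncate([6, 11, 5], 5))            # [6, 11, 5, 0, 0]
print(pad_or_truncate([6, 11, 5, 13, 5, 4], 5))  # [6, 11, 5, 13, 5]
```

Every sequence therefore ends up with exactly `max_length` elements, which is what the static input shape of the model requires.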

Accelerators (like GPUs/TPUs) are very fast and often data-loading (& pre-processing) becomes the bottleneck during training as the data-loading part happens on CPUs. This can increase the training time significantly especially when there is a lot of online pre-processing involved or data is streamed online from GCS buckets. To handle those issues, `tf.data.Dataset` offers the `.prefetch(...)` method. This method helps in preparing the next few batches in parallel (on CPUs) while the model is making predictions (on GPUs/TPUs) on the current batch.

In [25]:
dataset = dataset.prefetch(tf.data.AUTOTUNE)

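Since this notebook is made for demonstration purposes, we will take only the first `num_train_batches` batches for training and the next `num_val_batches` batches for validation. You are encouraged to train on the whole dataset though. Note that `.shuffle(...)` reshuffles on every iteration by default, so this `take`/`skip` split is not guaranteed to stay disjoint across iterations; pass `reshuffle_each_iteration=False` to `.shuffle(...)` if you need a fixed split.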

In [38]:
num_train_batches = 16
num_val_batches = 6

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

## Training

Let's define our `train_step` function now. There are 3 main steps in `train_step`: 
1. forward pass with variables tracking
2. backward pass for calculating gradients
3. variables update to minimize training loss

All the trainable variables in the scope of `tf.GradientTape(...)` are tracked during the forward pass. `.gradient(...)` then computes the gradient of the loss w.r.t. those tracked variables, and `.apply_gradients(...)` updates the trainable variables using the `optimizer` defined above.

In [27]:
@tf.function
def train_step(speech, labels):
    with tf.GradientTape() as gtape:
        speech = forward(speech)
        loss = loss_fn(labels, speech)
    grads = gtape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
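
Stripped of TensorFlow, the same three steps can be seen in a toy example with a single weight and a hand-derived gradient. This sketch uses plain gradient descent rather than `Adam`, and `toy_train_step` is a made-up name, so it only illustrates the structure of a training step.

```python
def toy_train_step(weight, x, target, learning_rate=0.1):
    # 1. Forward pass: prediction and squared-error loss.
    prediction = weight * x
    loss = (prediction - target) ** 2
    # 2. Backward pass: d(loss)/d(weight), derived by hand for this toy model.
    grad = 2 * (prediction - target) * x
    # 3. Update: move the weight against the gradient.
    weight = weight - learning_rate * grad
    return weight, loss


weight = 0.0
losses = []
for _ in range(20):
    weight, loss = toy_train_step(weight, x=1.0, target=3.0)
    losses.append(loss)

print(weight, losses[0], losses[-1])
```

In the real `train_step`, `tf.GradientTape` takes care of step 2 automatically and the optimizer takes care of step 3.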

Let's finally kick off training!

We will iterate over our dataset (instance of `tf.data.Dataset`) and each batch will be fed to `train_step(...)` for calculating loss, gradients & updating parameters.

In [28]:
from tqdm.auto import tqdm
EPOCHS = 3

pbar = tqdm(range(EPOCHS), total=EPOCHS)
for e in pbar:
  running_loss, steps = tf.constant(0.), 0
  for labels, speech in train_dataset:
      loss = train_step(speech, labels)
      running_loss += loss
      steps += 1
  pbar.set_postfix(tr_loss=running_loss.numpy().item()/steps, epoch=e)


Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.



## Evaluation

Let's compute the loss over the validation dataset using the `eval_step(...)` function defined in the following cell.

In [29]:
@tf.function(jit_compile=True)
def eval_fwd(batch):
  return model(batch, training=False)

@tf.function
def eval_step(speech, labels):
    speech = eval_fwd(speech)
    loss = loss_fn(labels, speech)
    return loss, tf.argmax(speech, axis=-1)

We need to compute `WER` (word error rate) over our validation data. We will use the `load_metric(...)` function from the [HuggingFace datasets](https://huggingface.co/docs/datasets/) library. Let's first install the `datasets` library using `pip` and then define the `metric` object.

In [30]:
!pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")





It's time to run the evaluation on validation data now.

In [33]:
pbar = tqdm(val_dataset, total=num_val_batches)
for labels, speech in pbar:
    loss, predictions = eval_step(speech, labels)
    pbar.set_postfix(val_loss=loss.numpy().item())
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)





We use the `tokenizer.decode(...)` method to decode our predictions and labels back into text, and add them to the metric so that `WER` can be computed later.

Now, let's calculate the metric value in the following cell:
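
Decoding CTC predictions typically means collapsing repeated tokens and then dropping the blank token, which is presumably why labels are decoded with `group_tokens=False` while predictions are not. The sketch below shows that idea in plain Python. The vocabulary `ID_TO_CHAR`, the blank id `0` and the function `ctc_greedy_decode` are assumptions made for illustration; the real mapping lives in `vocab.json` and the real logic in the tokenizer.

```python
from itertools import groupby

# Hypothetical id-to-character table; the real mapping lives in vocab.json.
ID_TO_CHAR = {0: "", 1: "H", 2: "E", 3: "L", 4: "O"}
BLANK_ID = 0  # assumed id shared by the CTC blank and the padding token


def ctc_greedy_decode(token_ids, group_tokens=True):
    """Turn a sequence of token ids into text, CTC style."""
    if group_tokens:
        # Collapse runs of identical ids, e.g. [1, 1, 2] -> [1, 2]
        token_ids = [token for token, _ in groupby(token_ids)]
    # Drop blanks only after collapsing, so a blank can separate genuine repeats.
    return "".join(ID_TO_CHAR[t] for t in token_ids if t != BLANK_ID)


frames = [1, 1, 0, 2, 3, 3, 0, 3, 4, 4]  # per-frame argmax ids
labels = [1, 2, 3, 3, 4, 0, 0]           # padded reference label

print(ctc_greedy_decode(frames))                      # HELLO
print(ctc_greedy_decode(labels, group_tokens=False))  # HELLO
print(ctc_greedy_decode(labels))                      # HELO
```

The last line shows why reference labels should not be grouped: the genuine double letter would be lost.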

In [34]:
metric.compute()

1.0

**Note:** Here the metric value isn't meaningful, as the model was trained on very little data, and ASR-like tasks often require a very large amount of data to learn a mapping from speech to text. You should train on a much larger dataset to get good results. This notebook only gives you a template for fine-tuning a pre-trained speech model.
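
For reference, word error rate is the word-level edit distance (substitutions, deletions and insertions) between prediction and reference, divided by the number of reference words. Below is a minimal single-pair sketch in plain Python; `word_error_rate` is a hypothetical helper, and the HuggingFace metric may aggregate over a whole batch differently.

```python
def word_error_rate(reference, prediction):
    """Word-level edit distance divided by the number of reference words."""
    ref_words, pred_words = reference.split(), prediction.split()
    # previous_row[j]: distance between the reference words seen so far
    # and the first j predicted words.
    previous_row = list(range(len(pred_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        current_row = [i]
        for j, pred_word in enumerate(pred_words, start=1):
            substitution = previous_row[j - 1] + (ref_word != pred_word)
            deletion = previous_row[j] + 1
            insertion = current_row[j - 1] + 1
            current_row.append(min(substitution, deletion, insertion))
        previous_row = current_row
    # Assumes a non-empty reference.
    return previous_row[-1] / len(ref_words)


print(word_error_rate("HER FATHER IS A PERSON", "HER FATHER IS A PERSON"))  # 0.0
print(word_error_rate("HER FATHER IS A PERSON", "HER FATHER IS PERSON"))    # 0.2
print(word_error_rate("HER FATHER", "H"))                                   # 1.0
```

A value of `1.0`, as obtained above, means that the number of word errors equals the number of reference words.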

## Wrapping training in `tf.keras.Model` (Recommended)

Now that we have defined the training and evaluation steps and can train the model with our custom loop, we will wrap our training pipeline in `tf.keras.Model`. This lets us get rid of a lot of boilerplate code and gives us all the benefits of the `.fit(...)` method, which include easy logging with the `callbacks` argument, early stopping, model checkpointing, automatic data sharding when running distributed training, and many more.

To be able to do this, we will implement the `compile(...)` method for configuring the `optimizer`, the loss-tracking metric and the loss function, along with the `train_step(...)` and `test_step(...)` methods. These two methods are very similar to the `train_step` and `eval_step` functions we defined above; we only make small changes to turn them from functions into methods of the class.

In [35]:
class Trainer(tf.keras.Model):
  def __init__(self, model):
    super().__init__()
    self.model = model

  def compile(self, optimizer, loss_fn):
      super().compile(optimizer=optimizer)
      self.loss_fn = loss_fn
      self.loss_tracker = tf.keras.metrics.Mean(name="loss")

  @property
  def metrics(self):
      """TFKeras will call `metric.reset_states()` because of this method."""
      return [self.loss_tracker]

  @tf.function(jit_compile=True)
  def call(self, speech, training=False):
      return self.model(speech, training=training)

  def train_step(self, data):
      labels, speech = data
      with tf.GradientTape() as gtape:
          speech = self(speech, training=True)
          loss = self.loss_fn(labels, speech)
      grads = gtape.gradient(loss, self.model.trainable_variables)
      self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

      self.loss_tracker.update_state(loss)
      return {m.name: m.result() for m in self.metrics}

  def test_step(self, data):
      labels, speech = data
      speech = self(speech, training=False)
      loss = self.loss_fn(labels, speech)

      self.loss_tracker.update_state(loss)
      return {m.name: m.result() for m in self.metrics}

In [36]:
trainer = Trainer(model)
trainer.compile(optimizer, loss_fn)

The above cell sets up our training state. Now we can start training with the `.fit(...)` method.

In [37]:
trainer.fit(train_dataset, validation_data=val_dataset, epochs=EPOCHS)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7f0596d85990>

In [39]:
save_dir = "finetuned-wav2vec2"
trainer.save(save_dir)



INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets




## Inference

In [40]:
finetuned_model = tf.keras.models.load_model(save_dir)

In [41]:
!wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav

--2021-07-27 11:35:10--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 192.30.255.112
Connecting to github.com (github.com)|192.30.255.112|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav [following]
--2021-07-27 11:35:10--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’


2021-07-27 11:35:10 (9.11 MB/s) - ‘SA2.wav’ saved [94252/94252]



In [43]:
import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(tf.constant(speech), 0)

outputs = finetuned_model(speech)
outputs

<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 1.7183228 , -0.85034037, -1.0313914 , ..., -1.2755425 ,
         -1.0585048 , -1.2861346 ],
        [ 1.4849069 , -0.73765254, -0.9256569 , ..., -1.061569  ,
         -0.81141865, -1.3527607 ],
        [ 1.6891884 , -1.2361852 , -0.9793967 , ..., -1.4525627 ,
         -1.1310242 , -1.7152848 ],
        ...,
        [ 1.8763846 , -0.06146083,  0.15416975, ..., -1.529478  ,
         -0.7728075 , -0.4524945 ],
        [ 1.7784934 , -0.11706376,  0.19549124, ..., -1.4768647 ,
         -0.8392049 , -0.43076077],
        [ 1.6242267 , -0.2248601 ,  0.10749584, ..., -1.4436699 ,
         -0.9423501 , -0.47027552]]], dtype=float32)>

In [44]:
predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions

['H']

Finally, we have reached the end of this notebook. But it's not the end of learning TensorFlow for speech-related tasks: this [repository](https://github.com/tulasiram58827/TTS_TFLite) contains some more amazing tutorials. If you encounter any bug in this notebook, please create an issue [here](https://github.com/vasudevgupta7/gsoc-wav2vec2/issues).