In [None]:
!git clone https://github.com/vasudevgupta7/gsoc-wav2vec2 --branch=export

import sys
import os

os.chdir("gsoc-wav2vec2")
sys.path.append("src")

In [29]:
import tensorflow as tf
import tensorflow_hub as hub

from wav2vec2 import Wav2Vec2Config, CTCLoss

# Model configuration built with no overrides; later cells read
# `config.vocab_size` from it and pass it to `CTCLoss`.
config = Wav2Vec2Config()

In [30]:
# TODO: update it to load from TFHub later
# Until then the exported model is loaded from the "saved-model" path.
loaded = hub.load("saved-model")

signature_names = list(loaded.signatures.keys())
print("Available signatures are:", signature_names)

# Keep the "train" signature; it is the callable used for fine-tuning below.
pretrained_model = loaded.signatures["train"]

Available signatures are: ['infer', 'train']


In [31]:
# Wrap the exported "train" signature as a frozen (non-trainable) Keras layer.
# The signature is read from `loaded` instead of from `pretrained_model` itself,
# so re-running this cell never wraps an already-wrapped KerasLayer.
pretrained_model = hub.KerasLayer(loaded.signatures["train"], trainable=False)

# Dense projection to `config.vocab_size` units; with the backbone frozen this
# head is the only component whose variables are trained.
lm_head = tf.keras.layers.Dense(config.vocab_size)

In [37]:
# Total trainable variables across the backbone layer and the head.
# NOTE(review): a Dense layer creates its variables on first call, so on a fresh
# kernel this may report 0 unless `forward` has already run — the recorded
# output (2) appears to come from an out-of-order execution; confirm.
num_backbone_vars = len(pretrained_model.trainable_variables)
num_head_vars = len(lm_head.trainable_variables)
print("Number of trainable variables:", num_backbone_vars + num_head_vars)

Number of trainable variables: 2


In [None]:
# Batch size per device. The global batch size is set equal to it here
# (no multi-device scaling is applied in this notebook).
BATCH_SIZE_PER_DEVICE = 8
BATCH_SIZE = BATCH_SIZE_PER_DEVICE
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 5e-5

In [None]:
# Shape tuple given to CTCLoss as its second positional argument.
# NOTE(review): 246000 looks like a maximum audio length in samples — confirm
# against the CTCLoss implementation and the data loader settings.
ctc_input_shape = (BATCH_SIZE_PER_DEVICE, 246000)
loss_fn = CTCLoss(config, ctc_input_shape, division_factor=BATCH_SIZE)

optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

In [35]:
@tf.function(jit_compile=True)
def forward(batch):
    """Run `pretrained_model` on `batch` and pass its "output_0" entry through `lm_head`.

    Args:
        batch: input forwarded unchanged to `pretrained_model`.

    Returns:
        The output of `lm_head` applied to the backbone's "output_0" tensor.
    """
    backbone_outputs = pretrained_model(batch)
    features = backbone_outputs["output_0"]
    return lm_head(features)

In [None]:
@tf.function
def train_step(speech, labels):
    """Perform one optimisation step on a batch and return its loss.

    Args:
        speech: batch passed to `forward`.
        labels: targets passed to `loss_fn` as its first argument.

    Returns:
        The value computed by `loss_fn` for this batch.
    """
    with tf.GradientTape() as tape:
        logits = forward(speech)
        loss = loss_fn(labels, logits)

    # Gathered after the forward pass, by which point `lm_head` has been
    # called at least once; do not move this above the tape block.
    variables = [
        *pretrained_model.trainable_variables,
        *lm_head.trainable_variables,
    ]
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

In [None]:
from data_utils import LibriSpeechDataLoaderArgs, LibriSpeechDataLoader

# Load LibriSpeech "test-clean" from the data directory rather than from TFRecords.
data_args = LibriSpeechDataLoaderArgs(
    data_dir="data/LibriSpeech/test-clean",
    from_tfrecords=False,
)
dataloader = LibriSpeechDataLoader(data_args)

# Keep only the first two elements (presumably to keep this a quick smoke run).
dataset = dataloader(seed=None).take(2)

In [None]:
from tqdm import tqdm

# Run one optimisation step per (speech, label) element; tqdm reports progress.
progress_bar = tqdm(dataset)
for speech_batch, label_batch in progress_bar:
    train_step(speech_batch, label_batch)