In [None]:
from tqdm import tqdm
import numpy as np
import torch
import copy
import math
from torch.utils.data import DataLoader
from patter import ModelFactory
from patter.data import AudioDataset, BucketingSampler, audio_seq_collate_fn
from patter.decoder import GreedyCTCDecoder
from patter.data.features import PerturbedSpectrogramFeaturizer
from patter.evaluator import validate
from patter.models import SpeechModel

In [None]:
# Load the pretrained model that transfer learning starts from.
# NOTE(review): this is an absolute, user-specific path — point it at your own checkpoint before running.
seed_model_path = "/data/users/ryan/models/deepspeech/librispeech_pretrained_patter.pt"
# ModelFactory.load comes from the project-local `patter` package; the object it returns is
# used below for its `labels` and `input_cfg` attributes and as the source of the copy that gets trained.
seed_model = ModelFactory.load(seed_model_path)

In [3]:
# Inputs: manifests describing the train / validation splits used for the transfer.
train_manifest_path = "/home/ryan/data/patter_data/an4-jl/an4_train_manifest.jl"
val_manifest_path = "/home/ryan/data/patter_data/an4-jl/an4_val_manifest.jl"

# Output: where the best transferred model is written during training.
new_model_path = "/data/users/ryan/models/deepspeech/an4_transferred.pt"

In [4]:
# Make an independent copy of the model.
# copy.copy() is only a shallow copy: the copy would share its submodules and parameters with
# seed_model, so replacing model.output[1] and freezing layers in the cells below would silently
# modify seed_model as well. deepcopy gives the new model its own layers and weights.
model = copy.deepcopy(seed_model)

In [5]:
# Choose the label set for the new model.
# This experiment goes from English to English, so the seed model's labels are reused as-is.
# NB: the first label MUST represent the CTC blank label (canonically is '\xa0')
# The chosen labels are attached to the model in the same statement.
model.labels = labels = seed_model.labels
print("New Labels:", labels)

New Labels: ['\xa0', "'", 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', ' ']


In [6]:
# Create a new output layer entirely, sized from the layer it replaces rather than a hard-coded width.
# Falls back to 800 (the value previously hard-coded here) if the old layer does not expose `in_features`.
old_output_layer = model.output[1]
output_in_features = getattr(old_output_layer, "in_features", 800)
model.output[1] = torch.nn.Linear(in_features=output_in_features, out_features=len(labels), bias=False)

# and initialize it (bias=False, so the weight matrix is the layer's only parameter)
torch.nn.init.xavier_uniform_(model.output[1].weight)

In [7]:
# Freeze the feature-extraction stack: every parameter of the conv and rnn modules stops
# receiving gradients, leaving only the remaining (output) parameters trainable.
for frozen_module in (model.conv, model.rnn):
    for param in frozen_module.parameters():
        param.requires_grad_(False)

In [8]:
# Build the featurizer from the seed model's input config so features match what the model expects.
# Pass perturbation configs here to augment training data; None disables perturbation.
featurizer = PerturbedSpectrogramFeaturizer.from_config(
    seed_model.input_cfg,
    perturbation_configs=None,
)

# Train / validation corpora for the target domain; both use the same duration filter.
# NOTE(review): durations are presumably in seconds — confirm against AudioDataset.
duration_limits = {"max_duration": 17.0, "min_duration": 1.0}
train_corpus = AudioDataset(train_manifest_path, labels, featurizer, **duration_limits)
val_corpus = AudioDataset(val_manifest_path, labels, featurizer, **duration_limits)

Dataset loaded with 0.70 hours. Filtered 0.00 hours.
Dataset loaded with 0.10 hours. Filtered 0.00 hours.


In [9]:
# set up data loaders
batch_size = 32
num_workers = 4
# Use the GPU only when one is actually present, so the notebook also runs on CPU-only hosts.
cuda = torch.cuda.is_available()

# Training batches come from the bucketing sampler, which is reshuffled at the start of every epoch.
train_sampler = BucketingSampler(train_corpus, batch_size=batch_size)
train_loader = DataLoader(train_corpus, num_workers=num_workers, collate_fn=audio_seq_collate_fn, pin_memory=cuda, batch_sampler=train_sampler)
# Validation uses plain sequential batching; worker count follows the shared num_workers setting.
eval_loader = DataLoader(val_corpus, num_workers=num_workers, collate_fn=audio_seq_collate_fn, pin_memory=cuda, batch_size=batch_size)

In [10]:
# Move the model to the GPU when CUDA use is enabled; otherwise leave it where it is.
model = model.cuda() if cuda else model

In [11]:
# set up optimizer
lr = 3e-4
momentum = 0.9
annealing = 1.01
# Keep the trainable parameters in a list: a set has no stable iteration order, which makes
# the optimizer's parameter ordering (and hence its state) differ from run to run.
trainable_params = [p for p in model.parameters() if p.requires_grad]

# we typically use SGD w/ nesterov momentum and a custom learning rate schedule
#optimizer = torch.optim.SGD(trainable_params, lr=lr, momentum=momentum, nesterov=True)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=1/annealing)

# Don't need a learning rate schedule if using Adam
class NoOpScheduler(object):
    """Stand-in for a learning-rate scheduler whose step() does nothing.

    Lets the training loop call scheduler.step() unconditionally, whether or not
    a real schedule (e.g. the StepLR shown above) is in use.
    """

    def __init__(self):
        pass

    def step(self):
        """Do nothing; the learning rate is left unchanged."""
        pass

# but for example purposes, we'll just use Adam
optimizer = torch.optim.Adam(trainable_params, lr=lr, amsgrad=True)
scheduler = NoOpScheduler()

In [12]:
# Report the initial performance of the model on the new validation set.
# Put the model in eval mode first, exactly as the training cells do before they call validate.
model.eval()
err = validate(eval_loader, model, tqdm=False)
print("WER: {0:.3f}, CER: {1:.3f}".format(err.wer, err.cer))

# Baseline WER: a checkpoint is only saved during training when it beats this value.
best_wer = err.wer

WER: 102.872, CER: 277.516


In [13]:
# Stage 1: train with the conv and rnn layers frozen, so only the new output layer is updated.
num_epochs = 40
for epoch in range(num_epochs):
    train_sampler.shuffle()

    model.train()
    # Wrap the loader in a fresh progress bar each epoch under a separate name. Rebinding
    # `train_loader` itself would wrap the previous epoch's tqdm object again and again.
    progress = tqdm(train_loader, desc="Epoch {}".format(epoch+1), leave=False)
    for i, data in enumerate(progress):
        feat, target, feat_len, target_len = data
        if cuda:
            # `async` is a reserved word since Python 3.7; `non_blocking` is the replacement keyword.
            feat = feat.cuda(non_blocking=True)

        optimizer.zero_grad()

        output, output_len = model(feat, feat_len)
        loss = model.loss(output, target, output_len.squeeze(0), target_len)

        # per-utterance loss for this batch
        scalar_loss = loss.item()/feat.size(0)
        if not math.isfinite(scalar_loss):
            # An inf/NaN loss would backpropagate non-finite gradients and corrupt the weights,
            # so skip the update for this batch instead of stepping on it.
            continue
        progress.set_postfix(loss="{0:.3f}".format(scalar_loss))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, 400)
        optimizer.step()

    # Advance the LR schedule once per epoch, after the optimizer updates
    # (the order PyTorch expects since 1.1; a no-op with NoOpScheduler).
    scheduler.step()

    model.eval()
    err = validate(eval_loader, model, tqdm=False)
    print("Epoch {0} :: WER: {1:.3f}, CER: {2:.3f}".format(epoch+1, err.wer, err.cer))

    # keep only the checkpoint with the best validation WER seen so far
    if err.wer < best_wer:
        best_wer = err.wer
        torch.save(SpeechModel.serialize(model), new_model_path)

Epoch 2:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 1 :: WER: 110.183, CER: 112.736


                                                        

Epoch 2 :: WER: 97.128, CER: 58.766


Epoch 4:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 3 :: WER: 99.217, CER: 68.868


Epoch 5:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 4 :: WER: 99.739, CER: 70.715


Epoch 6:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 5 :: WER: 99.217, CER: 66.863


Epoch 7:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 6 :: WER: 99.739, CER: 57.783


Epoch 8:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 7 :: WER: 97.781, CER: 44.300


                                                        

Epoch 8 :: WER: 95.561, CER: 35.888


                                                        

Epoch 9 :: WER: 92.298, CER: 28.184


                                                         

Epoch 10 :: WER: 85.640, CER: 23.349


                                                         

Epoch 11 :: WER: 77.546, CER: 19.182


                                                         

Epoch 12 :: WER: 69.713, CER: 16.903


                                                         

Epoch 13 :: WER: 58.486, CER: 13.679


                                                         

Epoch 14 :: WER: 55.483, CER: 12.736


                                                         

Epoch 15 :: WER: 48.172, CER: 11.321


Epoch 17:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 16 :: WER: 48.956, CER: 10.063


                                                         

Epoch 17 :: WER: 45.953, CER: 9.355


                                                         

Epoch 18 :: WER: 44.517, CER: 9.041


                                                         

Epoch 19 :: WER: 40.209, CER: 7.665


                                                         

Epoch 20 :: WER: 32.115, CER: 6.918


                                                         

Epoch 21 :: WER: 31.070, CER: 6.643


                                                         

Epoch 22 :: WER: 27.676, CER: 6.643


                                                         

Epoch 23 :: WER: 23.890, CER: 5.975


                                                         

Epoch 24 :: WER: 21.802, CER: 5.818


                                                         

Epoch 25 :: WER: 19.974, CER: 5.307


Epoch 27:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 26 :: WER: 20.627, CER: 5.739


                                                         

Epoch 27 :: WER: 18.146, CER: 5.031


                                                         

Epoch 28 :: WER: 17.755, CER: 5.071


                                                         

Epoch 29 :: WER: 17.102, CER: 4.717


                                                         

Epoch 30 :: WER: 16.710, CER: 5.071


                                                         

Epoch 31 :: WER: 14.752, CER: 4.442


Epoch 33:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 32 :: WER: 14.752, CER: 4.442


Epoch 34:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 33 :: WER: 14.752, CER: 4.324


                                                         

Epoch 34 :: WER: 14.360, CER: 4.285


                                                         

Epoch 35 :: WER: 14.099, CER: 3.931


                                                         

Epoch 36 :: WER: 13.577, CER: 4.088


Epoch 38:   0%|          | 0/29 [00:00<?, ?it/s]         

Epoch 37 :: WER: 13.708, CER: 3.970


                                                         

Epoch 38 :: WER: 13.055, CER: 3.813


                                                         

Epoch 39 :: WER: 12.924, CER: 3.695


                                                         

Epoch 40 :: WER: 12.402, CER: 3.616


In [14]:
# reload the previously best found model
model = ModelFactory.load(new_model_path)
if cuda:
    model = model.cuda()

# confirm the reloaded checkpoint reproduces the best validation score
model.eval()
err = validate(eval_loader, model, tqdm=False)
print("WER: {0:.3f}, CER: {1:.3f}".format(err.wer, err.cer))

# add the rnns for additional fine tuning: conv stays frozen, rnn is explicitly made trainable
# so this stage does not depend on what requires_grad state the loader happens to restore
for param in model.conv.parameters():
    param.requires_grad_(False)
for param in model.rnn.parameters():
    param.requires_grad_(True)
# list (not set) so the optimizer sees the parameters in a stable order
trainable_params = [p for p in model.parameters() if p.requires_grad]
# fresh optimizer for the enlarged parameter set
optimizer = torch.optim.Adam(trainable_params, lr=3e-4, amsgrad=True)

WER: 12.402, CER: 3.616


In [16]:
# run another 20 epochs of training, now fine-tuning the rnn layers together with the output layer
num_epochs = 20
for epoch in range(num_epochs):
    train_sampler.shuffle()

    model.train()
    # Wrap the loader in a fresh progress bar each epoch under a separate name. Rebinding
    # `train_loader` itself would wrap the previous epoch's tqdm object again and again.
    progress = tqdm(train_loader, desc="Epoch {}".format(epoch+1))
    for i, data in enumerate(progress):
        feat, target, feat_len, target_len = data
        if cuda:
            # `async` is a reserved word since Python 3.7; `non_blocking` is the replacement keyword.
            feat = feat.cuda(non_blocking=True)

        optimizer.zero_grad()

        output, output_len = model(feat, feat_len)
        loss = model.loss(output, target, output_len.squeeze(0), target_len)

        # per-utterance loss for this batch
        scalar_loss = loss.item()/feat.size(0)
        if not math.isfinite(scalar_loss):
            # An inf/NaN loss would backpropagate non-finite gradients and corrupt the weights,
            # so skip the update for this batch instead of stepping on it.
            continue
        progress.set_postfix(loss="{0:.3f}".format(scalar_loss))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, 400)
        optimizer.step()

    model.eval()
    err = validate(eval_loader, model, tqdm=False)
    print("Epoch {0} :: WER: {1:.3f}, CER: {2:.3f}".format(epoch+1, err.wer, err.cer))

    # keep only the checkpoint with the best validation WER seen so far
    if err.wer < best_wer:
        best_wer = err.wer
        torch.save(SpeechModel.serialize(model), new_model_path)

Epoch 1: 100%|██████████| 29/29 [00:07<00:00,  3.77it/s]


Epoch 1 :: WER: 3.264, CER: 1.140


Epoch 2: 100%|██████████| 29/29 [00:07<00:00,  3.93it/s]


Epoch 2 :: WER: 1.958, CER: 0.865


Epoch 3: 100%|██████████| 29/29 [00:07<00:00,  3.91it/s]
Epoch 4:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 3 :: WER: 1.958, CER: 0.865


Epoch 4: 100%|██████████| 29/29 [00:07<00:00,  3.85it/s]


Epoch 4 :: WER: 1.697, CER: 0.747


Epoch 5: 100%|██████████| 29/29 [00:07<00:00,  3.89it/s]
Epoch 6:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 5 :: WER: 1.697, CER: 0.668


Epoch 6: 100%|██████████| 29/29 [00:07<00:00,  3.74it/s]
Epoch 7:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 6 :: WER: 1.697, CER: 0.629


Epoch 7: 100%|██████████| 29/29 [00:07<00:00,  3.84it/s]


Epoch 7 :: WER: 1.436, CER: 0.550


Epoch 8: 100%|██████████| 29/29 [00:07<00:00,  3.83it/s]
Epoch 9:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 8 :: WER: 1.567, CER: 0.629


Epoch 9: 100%|██████████| 29/29 [00:07<00:00,  3.73it/s]
Epoch 10:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 9 :: WER: 1.567, CER: 0.590


Epoch 10: 100%|██████████| 29/29 [00:07<00:00,  3.80it/s]
Epoch 11:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 10 :: WER: 1.567, CER: 0.590


Epoch 11: 100%|██████████| 29/29 [00:07<00:00,  3.85it/s]
Epoch 12:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 11 :: WER: 1.436, CER: 0.550


Epoch 12: 100%|██████████| 29/29 [00:07<00:00,  3.80it/s]


Epoch 12 :: WER: 1.305, CER: 0.511


Epoch 13: 100%|██████████| 29/29 [00:07<00:00,  3.82it/s]
Epoch 14:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 13 :: WER: 1.436, CER: 0.550


Epoch 14: 100%|██████████| 29/29 [00:07<00:00,  3.90it/s]
Epoch 15:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 14 :: WER: 1.567, CER: 0.590


Epoch 15: 100%|██████████| 29/29 [00:07<00:00,  3.75it/s]
Epoch 16:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 15 :: WER: 1.436, CER: 0.550


Epoch 16: 100%|██████████| 29/29 [00:07<00:00,  3.77it/s]
Epoch 17:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 16 :: WER: 1.436, CER: 0.550


Epoch 17: 100%|██████████| 29/29 [00:07<00:00,  3.85it/s]
Epoch 18:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 17 :: WER: 1.436, CER: 0.590


Epoch 18: 100%|██████████| 29/29 [00:07<00:00,  3.90it/s]
Epoch 19:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 18 :: WER: 1.697, CER: 0.629


Epoch 19: 100%|██████████| 29/29 [00:07<00:00,  3.92it/s]
Epoch 20:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 19 :: WER: 1.305, CER: 0.511


Epoch 20: 100%|██████████| 29/29 [00:07<00:00,  3.74it/s]


Epoch 20 :: WER: 1.567, CER: 0.629
