In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os, shutil

from config import args
from models.audio_net import AudioNet
from data.lrs3_dataset import LRS3Main
from data.utils import collate_fn
from utils.general import num_params, train, evaluate
from tqdm import tqdm
from sys import exit


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Select matplotlib's non-interactive "Agg" backend.
matplotlib.use("Agg")

# Seed the PyTorch and NumPy random number generators from the same config value.
torch.manual_seed(args["SEED"])
np.random.seed(args["SEED"])

# Put cuDNN in deterministic mode and switch its benchmark mode off.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use CUDA when it is available; the extra DataLoader options are only set in that case.
gpuAvailable = torch.cuda.is_available()
if gpuAvailable:
    device = torch.device("cuda")
    kwargs = {"num_workers": args["NUM_WORKERS"], "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [3]:
# STFT settings and noise settings, both read from the shared config
audioParams = dict(
    stftWindow=args["STFT_WINDOW"],
    stftWinLen=args["STFT_WIN_LENGTH"],
    stftOverlap=args["STFT_OVERLAP"],
)
noiseParams = dict(
    noiseFile=args["DATA_DIRECTORY"] + "/noise.wav",
    noiseProb=args["NOISE_PROBABILITY"],
    noiseSNR=args["NOISE_SNR_DB"],
)

In [4]:
# Build the training and validation datasets and wrap each one in a DataLoader.
dataset = "train"
datadir = args["DATA_DIRECTORY"]
reqInpLen = args["MAIN_REQ_INPUT_LENGTH"]
charToIx = args["CHAR_TO_INDEX"]
stepSize = args["STEP_SIZE"]

# The training set uses `noiseParams` exactly as declared earlier in the notebook.
trainData = LRS3Main(dataset, datadir, reqInpLen, charToIx, stepSize, audioParams, noiseParams)
trainLoader = DataLoader(trainData, batch_size=args["BATCH_SIZE"], collate_fn=collate_fn, shuffle=True, **kwargs)

# Validation keeps the same noise file and SNR but sets the noise probability to 0.
# These settings get their own name so that `noiseParams` is never overwritten:
# re-running this cell therefore always builds the training set with the
# training noise settings instead of silently inheriting noiseProb=0.
valNoiseParams = {"noiseFile": datadir + "/noise.wav", "noiseProb": 0, "noiseSNR": args["NOISE_SNR_DB"]}
valData = LRS3Main("val", datadir, reqInpLen, charToIx, stepSize, audioParams, valNoiseParams)
# NOTE(review): shuffle=True is kept on the validation loader to preserve the
# existing behaviour -- confirm that shuffling validation batches is intended.
valLoader = DataLoader(valData, batch_size=args["BATCH_SIZE"], collate_fn=collate_fn, shuffle=True, **kwargs)

In [5]:
# Model, optimizer, learning-rate scheduler and loss function.

# The network is built from config values and then moved to the selected device.
model = AudioNet(
    args["TX_NUM_FEATURES"],
    args["TX_ATTENTION_HEADS"],
    args["TX_NUM_LAYERS"],
    args["PE_MAX_LENGTH"],
    args["AUDIO_FEATURE_SIZE"],
    args["TX_FEEDFORWARD_DIM"],
    args["TX_DROPOUT"],
    args["NUM_CLASSES"],
)
model.to(device)

# Adam with the configured initial learning rate and the two beta coefficients.
optimizer = optim.Adam(
    model.parameters(),
    betas=(args["MOMENTUM1"], args["MOMENTUM2"]),
    lr=args["INIT_LR"],
)

# Plateau scheduler in "min" mode with an absolute threshold; the learning rate
# is never reduced below the configured final value.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    threshold_mode="abs",
    factor=args["LR_SCHEDULER_FACTOR"],
    patience=args["LR_SCHEDULER_WAIT"],
    threshold=args["LR_SCHEDULER_THRESH"],
    min_lr=args["FINAL_LR"],
    verbose=True,
)

# CTC loss with index 0 as the blank label.
loss_function = nn.CTCLoss(zero_infinity=False, blank=0)

In [9]:
# Start from a fresh "checkpoints" directory. An existing one is deleted only
# after the user answers "y"; answering "n" exits instead.
checkpointsDir = args["CODE_DIRECTORY"] + "/checkpoints"

if os.path.exists(checkpointsDir):
    ch = input("Continue and remove the 'checkpoints' directory? y/n: ")
    # Keep asking until one of the two accepted answers is given.
    while ch not in ("y", "n"):
        print("Invalid input")
        ch = input("Continue and remove the 'checkpoints' directory? y/n: ")
    if ch == "n":
        exit()
    shutil.rmtree(checkpointsDir)

# Create the directory itself, then the sub-directories for model weights and plots.
for subDir in ("", "/models", "/plots"):
    os.mkdir(checkpointsDir + subDir)


In [6]:
# Loading of pre-trained weights is currently disabled; the code is kept here for reference.
# #loading the pretrained weights
# if args["PRETRAINED_MODEL_FILE"] is not None:
#     print("\n\nPre-trained Model File: %s" %(args["PRETRAINED_MODEL_FILE"]))
#     print("\nLoading the pre-trained model .... \n")
#     model.load_state_dict(torch.load(args["CODE_DIRECTORY"] + args["PRETRAINED_MODEL_FILE"], map_location=device))
#     model.to(device)
#     print("Loading Done.\n")

# Empty histories for the loss and WER values; the training loop appends to them.
trainingLossCurve, validationLossCurve = [], []
trainingWERCurve, validationWERCurve = [], []

In [7]:
# Index the training dataset with -1 to inspect the item it returns.
trainData[-1]

(tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [7.1049e-05, 2.5245e-04, 7.7931e-04,  ..., 7.8156e-05, 1.3247e-04,
          8.4065e-05],
         [1.7867e-04, 4.2720e-04, 1.6799e-03,  ..., 1.8789e-04, 2.0205e-04,
          2.2435e-04],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00]], dtype=torch.float64),
 tensor([13,  8,  2, 10,  8,  1, 17,  5,  7,  1, 10,  2,  5, 17,  9,  1,  8, 13,
          6,  3,  5, 20, 11,  2,  1, 20, 11,  4,  4, 12,  1, 12,  4,  7,  4, 10,
          8,  1, 15,  6,  3,  9,  6,  7,  1,  5,  1, 17,  2, 10,  3,  5,  6,  7,
          1, 10,  5, 12,  6, 13,  8,  1, 12,  6, 10,  2, 17,  3, 11, 14,  1,  5,
          7, 12,  1,  8,  2,  7, 12,  1,  3, 

In [8]:
# Display the configured data directory.
args["DATA_DIRECTORY"]

'../lrs3'

In [9]:
# Display the sample paths stored in the training dataset's `datalist`.
# NOTE(review): this echoes the whole list; slicing it (e.g. [:5]) would keep the output short.
trainLoader.dataset.datalist

['../lrs3/train_mini/00j9bKdiOjk/50001',
 '../lrs3/train_mini/00j9bKdiOjk/50002',
 '../lrs3/train_mini/00j9bKdiOjk/50003',
 '../lrs3/train_mini/0af00UcTOSc/50001',
 '../lrs3/train_mini/0af00UcTOSc/50002',
 '../lrs3/train_mini/0af00UcTOSc/50003',
 '../lrs3/train_mini/0af00UcTOSc/50004',
 '../lrs3/train_mini/0af00UcTOSc/50005',
 '../lrs3/train_mini/0af00UcTOSc/50007',
 '../lrs3/train_mini/0af00UcTOSc/50008',
 '../lrs3/train_mini/0af00UcTOSc/50009',
 '../lrs3/train_mini/0af00UcTOSc/50010',
 '../lrs3/train_mini/0af00UcTOSc/50011',
 '../lrs3/train_mini/0af00UcTOSc/50012',
 '../lrs3/train_mini/0af00UcTOSc/50013',
 '../lrs3/train_mini/0af00UcTOSc/50014',
 '../lrs3/train_mini/0akiEFwtkyA/50001',
 '../lrs3/train_mini/0akiEFwtkyA/50002',
 '../lrs3/train_mini/0Amg53UuRqE/50001',
 '../lrs3/train_mini/0Bhk65bYSI0/50001',
 '../lrs3/train_mini/0Bhk65bYSI0/50002',
 '../lrs3/train_mini/0Bhk65bYSI0/50003',
 '../lrs3/train_mini/0Bhk65bYSI0/50004',
 '../lrs3/train_mini/0Bhk65bYSI0/50005',
 '../lrs3/train_

In [18]:
# Display the sample paths stored in the training dataset's `datalist`.
# NOTE(review): same expression as an earlier inspection cell and it echoes the
# whole list -- this cell looks like a leftover and is a candidate for removal.
trainLoader.dataset.datalist

['../lrs3/train_mini/00j9bKdiOjk/50001',
 '../lrs3/train_mini/00j9bKdiOjk/50002',
 '../lrs3/train_mini/00j9bKdiOjk/50003',
 '../lrs3/train_mini/0af00UcTOSc/50001',
 '../lrs3/train_mini/0af00UcTOSc/50002',
 '../lrs3/train_mini/0af00UcTOSc/50003',
 '../lrs3/train_mini/0af00UcTOSc/50004',
 '../lrs3/train_mini/0af00UcTOSc/50005',
 '../lrs3/train_mini/0af00UcTOSc/50007',
 '../lrs3/train_mini/0af00UcTOSc/50008',
 '../lrs3/train_mini/0af00UcTOSc/50009',
 '../lrs3/train_mini/0af00UcTOSc/50010',
 '../lrs3/train_mini/0af00UcTOSc/50011',
 '../lrs3/train_mini/0af00UcTOSc/50012',
 '../lrs3/train_mini/0af00UcTOSc/50013',
 '../lrs3/train_mini/0af00UcTOSc/50014',
 '../lrs3/train_mini/0akiEFwtkyA/50001',
 '../lrs3/train_mini/0akiEFwtkyA/50002',
 '../lrs3/train_mini/0Amg53UuRqE/50001',
 '../lrs3/train_mini/0Bhk65bYSI0/50001',
 '../lrs3/train_mini/0Bhk65bYSI0/50002',
 '../lrs3/train_mini/0Bhk65bYSI0/50003',
 '../lrs3/train_mini/0Bhk65bYSI0/50004',
 '../lrs3/train_mini/0Bhk65bYSI0/50005',
 '../lrs3/train_

In [28]:
# Display the sample paths stored in the training dataset's `datalist`.
# NOTE(review): same expression as an earlier inspection cell and it echoes the
# whole list -- this cell looks like a leftover and is a candidate for removal.
trainLoader.dataset.datalist

['../lrs3/train_mini/00j9bKdiOjk/50001',
 '../lrs3/train_mini/00j9bKdiOjk/50002',
 '../lrs3/train_mini/00j9bKdiOjk/50003',
 '../lrs3/train_mini/0af00UcTOSc/50001',
 '../lrs3/train_mini/0af00UcTOSc/50002',
 '../lrs3/train_mini/0af00UcTOSc/50003',
 '../lrs3/train_mini/0af00UcTOSc/50004',
 '../lrs3/train_mini/0af00UcTOSc/50005',
 '../lrs3/train_mini/0af00UcTOSc/50007',
 '../lrs3/train_mini/0af00UcTOSc/50008',
 '../lrs3/train_mini/0af00UcTOSc/50009',
 '../lrs3/train_mini/0af00UcTOSc/50010',
 '../lrs3/train_mini/0af00UcTOSc/50011',
 '../lrs3/train_mini/0af00UcTOSc/50012',
 '../lrs3/train_mini/0af00UcTOSc/50013',
 '../lrs3/train_mini/0af00UcTOSc/50014',
 '../lrs3/train_mini/0akiEFwtkyA/50001',
 '../lrs3/train_mini/0akiEFwtkyA/50002',
 '../lrs3/train_mini/0Amg53UuRqE/50001',
 '../lrs3/train_mini/0Bhk65bYSI0/50001',
 '../lrs3/train_mini/0Bhk65bYSI0/50002',
 '../lrs3/train_mini/0Bhk65bYSI0/50003',
 '../lrs3/train_mini/0Bhk65bYSI0/50004',
 '../lrs3/train_mini/0Bhk65bYSI0/50005',
 '../lrs3/train_

In [7]:
def _save_curve_plot(title, yLabel, trainCurve, valCurve, savePath):
    """Plot a training curve and a validation curve together and save the figure.

    Args:
        title: Title drawn above the plot.
        yLabel: Label of the y axis.
        trainCurve: Sequence of per-step training values (drawn in blue).
        valCurve: Sequence of per-step validation values (drawn in red).
        savePath: File path the figure is written to.
    """
    fig, ax = plt.subplots()
    ax.set_title(title)
    ax.set_xlabel("Step No.")
    ax.set_ylabel(yLabel)
    # The x axis is 1-based: the first recorded value is drawn at x = 1.
    ax.plot(list(range(1, len(trainCurve)+1)), trainCurve, "blue", label="Train")
    ax.plot(list(range(1, len(valCurve)+1)), valCurve, "red", label="Validation")
    ax.legend()
    fig.savefig(savePath)
    # Close this specific figure so figures do not accumulate across checkpoints.
    plt.close(fig)


#printing the total and trainable parameters in the model
numTotalParams, numTrainableParams = num_params(model)
print("\nNumber of total parameters in the model = %d" %(numTotalParams))
print("Number of trainable parameters in the model = %d\n" %(numTrainableParams))

# Override the configured number of training steps for this run.
args["NUM_STEPS"] = 5
numSteps = args["NUM_STEPS"]
saveFrequency = args["SAVE_FREQUENCY"]

print("\nTraining the model .... \n")

# Indices of the space and <EOS> characters; validation additionally uses greedy decoding.
trainParams = {"spaceIx":args["CHAR_TO_INDEX"][" "], "eosIx":args["CHAR_TO_INDEX"]["<EOS>"]}
valParams = {"decodeScheme":"greedy", "spaceIx":args["CHAR_TO_INDEX"][" "], "eosIx":args["CHAR_TO_INDEX"]["<EOS>"]}
for step in range(numSteps):

    #train the model for one step
    trainingLoss, trainingCER, trainingWER = train(model, trainLoader, optimizer, loss_function, device, trainParams)
    trainingLossCurve.append(trainingLoss)
    trainingWERCurve.append(trainingWER)

    #evaluate the model on validation set
    validationLoss, validationCER, validationWER = evaluate(model, valLoader, loss_function, device, valParams)
    validationLossCurve.append(validationLoss)
    validationWERCurve.append(validationWER)

    #printing the stats after each step
    print("Step: %03d || Tr.Loss: %.6f  Val.Loss: %.6f || Tr.CER: %.3f  Val.CER: %.3f || Tr.WER: %.3f  Val.WER: %.3f"
          %(step, trainingLoss, validationLoss, trainingCER, validationCER, trainingWER, validationWER))

    #make a scheduler step (the plateau scheduler monitors the validation WER)
    scheduler.step(validationWER)

    #saving the model weights and loss/metric curves in the checkpoints directory after every few steps
    # Step 0 is skipped for the periodic save, but the final step is always
    # saved -- including when the run consists of a single step.
    isLastStep = (step == numSteps-1)
    isPeriodicSave = (step != 0) and (step%saveFrequency == 0)
    if isPeriodicSave or isLastStep:

        savePath = args["CODE_DIRECTORY"] + "/checkpoints/models/train-step_{:04d}-wer_{:.3f}.pt".format(step, validationWER)
        torch.save(model.state_dict(), savePath)

        _save_curve_plot("Loss Curves", "Loss value", trainingLossCurve, validationLossCurve,
                         args["CODE_DIRECTORY"] + "/checkpoints/plots/train-step_{:04d}-loss.png".format(step))
        _save_curve_plot("WER Curves", "WER", trainingWERCurve, validationWERCurve,
                         args["CODE_DIRECTORY"] + "/checkpoints/plots/train-step_{:04d}-wer.png".format(step))


print("\nTraining Done.\n")


Number of total parameters in the model = 38507048
Number of trainable parameters in the model = 38507048


Training the model .... 



                                                                           

Step: 000 || Tr.Loss: 3.747638  Val.Loss: 3.324944 || Tr.CER: 1.016  Val.CER: 1.000 || Tr.WER: 1.000  Val.WER: 1.000


                                                                           

Step: 001 || Tr.Loss: 3.378331  Val.Loss: 3.324467 || Tr.CER: 1.000  Val.CER: 1.000 || Tr.WER: 1.000  Val.WER: 1.000


                                                                           

Step: 002 || Tr.Loss: 3.305391  Val.Loss: 3.365579 || Tr.CER: 1.000  Val.CER: 1.000 || Tr.WER: 1.000  Val.WER: 1.000


                                                                           

Step: 003 || Tr.Loss: 3.331936  Val.Loss: 3.326054 || Tr.CER: 1.000  Val.CER: 1.000 || Tr.WER: 1.000  Val.WER: 1.000


                                                                           

Step: 004 || Tr.Loss: 3.304192  Val.Loss: 3.444248 || Tr.CER: 1.000  Val.CER: 1.000 || Tr.WER: 1.000  Val.WER: 1.000

Training Done.

