In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import sys
sys.path.append('/content/drive/MyDrive/TTS_2023_V3')

In [3]:
import os

import torch
import torch.nn as nn
import constants
from dataset import Dataset
from evaluate import evaluate
from model import FastSpeech2Loss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
#Instantly make your loops show a smart progress meter - just wrap any iterable with
from tqdm import tqdm
from model_utils import get_model, get_param_num
from tools import log, synth_one_sample, to_device, get_vocoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train():
    print("Getting ready for training ...")

    # preprocess_config, model_config, train_config = configs

    # Get dataset
    dataset = Dataset("train.txt", sort=True, drop_last=True)
    batch_size = constants.BATCH_SIZE
    group_size = 4  # Set this larger than 1 to enable sorting in Dataset
    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    # Prepare model
    model, optimizer = get_model(constants.RESTORE_STEP, device, train=True)
    model = nn.DataParallel(model)
    num_param = get_param_num(model)
    Loss = FastSpeech2Loss().to(device)
    print("Number of FastSpeech2 Parameters:", num_param)

    # Load vocoder
    vocoder = get_vocoder(device, constants.VOCODER_CONFIG_PATH, constants.VOCODER_PRETRAINED_MODEL_PATH)

    os.makedirs(constants.CKPT_PATH, exist_ok=True)
    os.makedirs(constants.LOG_PATH, exist_ok=True)
    os.makedirs(constants.RESULT_PATH, exist_ok=True)

    train_log_path = constants.LOG_PATH + "/train"
    val_log_path = constants.LOG_PATH + "/val"
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)

    # Training
    step =  constants.RESTORE_STEP+1
    epoch = 1
    grad_acc_step = constants.GRAD_ACC_STEP
    grad_clip_thresh = constants.GRAD_CLIP_THRESH
    total_step = constants.TOTAL_STEP
    log_step = constants.LOG_STEP
    save_step = constants.SAVE_STEP
    synth_step = constants.SYNTH_STEP
    val_step = constants.VAL_STEP
    print('total_step:', total_step)
    print('restore_step:',constants.RESTORE_STEP)
    print('grad_acc_step:',grad_acc_step)
    print('grad_clip_thresh:',grad_clip_thresh)
    print('log_step:',log_step)
    print('save_step:',save_step)
    print('synth_step:',synth_step)
    print('val_step:',val_step)
    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = constants.RESTORE_STEP
    outer_bar.update()
    while True:
        # inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
        for batchs in loader:
            for batch in batchs:

                batch = to_device(batch, device)

                # Forward
                output = model(*(batch[2:]))

                # Cal Loss
                losses = Loss(batch, output)
                total_loss = losses[0]

                # Backward
                total_loss = total_loss / grad_acc_step
                total_loss.backward()
                if step % grad_acc_step == 0:
                    # Clipping gradients to avoid gradient explosion
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

                    # Update weights
                    optimizer.step_and_update_lr()
                    optimizer.zero_grad()

                if step % log_step == 0:
                    losses = [l.item() for l in losses]
                    message1 = "Step {}/{}, ".format(step, total_step)
                    message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
                        *losses
                    )

                    with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                        f.write(message1 + message2 + "\n")

                    outer_bar.write(message1 + message2)

                    log(train_logger, step, losses=losses)

                if step % synth_step == 0:
                    fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                        batch,
                        output,
                        vocoder,
                        constants
                    )
                    log(
                        train_logger,
                        fig=fig,
                        tag="Training/step_{}_{}".format(step, tag),
                    )
                    sampling_rate = constants.SAMPLING_RATE
                    log(
                        train_logger,
                        audio=wav_reconstruction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_reconstructed".format(step, tag),
                    )
                    log(
                        train_logger,
                        audio=wav_prediction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_synthesized".format(step, tag),
                    )

                if step % val_step == 0:
                    model.eval()
                    message = evaluate(model, step, val_logger, vocoder)
                    with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                        f.write(message + "\n")
                    outer_bar.write(message)

                    model.train()

                if step % save_step == 0:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "optimizer": optimizer._optimizer.state_dict(),
                        },
                        os.path.join(
                            constants.CKPT_PATH,
                            "{}.pth.tar".format(step),
                        ),
                    )
                if step>=total_step:
                    break

                step += 1
                outer_bar.update(1)

            if step>=total_step:
                    break

        if step>=total_step:
                    break
        epoch += 1



In [4]:
train()

Getting ready for training ...
Number of FastSpeech2 Parameters: 35186497
Removing weight norm...
total_step: 1045000
restore_step: 1040000
grad_acc_step: 1
grad_clip_thresh: 1.0
log_step: 1000
save_step: 5000
synth_step: 1000
val_step: 10000


Training: 100%|█████████▉| 1041000/1045000 [17:21<54:10,  1.23it/s]

Step 1041000/1045000, Total Loss: 0.5602, Mel Loss: 0.2561, Mel PostNet Loss: 0.2555, Pitch Loss: 0.0187, Energy Loss: 0.0162, Duration Loss: 0.0137


Training: 100%|█████████▉| 1042000/1045000 [31:00<03:17, 15.20it/s]

Step 1042000/1045000, Total Loss: 0.7109, Mel Loss: 0.3259, Mel PostNet Loss: 0.3253, Pitch Loss: 0.0261, Energy Loss: 0.0181, Duration Loss: 0.0156


Training: 100%|█████████▉| 1043000/1045000 [32:13<02:18, 14.41it/s]

Step 1043000/1045000, Total Loss: 0.7298, Mel Loss: 0.3271, Mel PostNet Loss: 0.3263, Pitch Loss: 0.0287, Energy Loss: 0.0223, Duration Loss: 0.0254


Training: 100%|█████████▉| 1044000/1045000 [33:26<01:17, 12.84it/s]

Step 1044000/1045000, Total Loss: 0.6110, Mel Loss: 0.2685, Mel PostNet Loss: 0.2676, Pitch Loss: 0.0326, Energy Loss: 0.0185, Duration Loss: 0.0237


Training: 100%|██████████| 1045000/1045000 [34:39<00:00, 13.91it/s]

Step 1045000/1045000, Total Loss: 0.7228, Mel Loss: 0.3209, Mel PostNet Loss: 0.3201, Pitch Loss: 0.0241, Energy Loss: 0.0304, Duration Loss: 0.0273


Training: 100%|██████████| 1045000/1045000 [34:41<00:00, 502.00it/s]
