In [None]:
# default_exp trainer.gradtts

# Trainer

In [None]:
# export
import json
import os
from pathlib import Path
from pprint import pprint
import numpy as np

import torch
from torch.cuda.amp import autocast, GradScaler
import torch.distributed as dist
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
import time

from uberduck_ml_dev.models.common import MelSTFT
from uberduck_ml_dev.utils.plot import (
    plot_attention,
    plot_gate_outputs,
    plot_spectrogram,
    plot_tensor,
)
from uberduck_ml_dev.text.util import text_to_sequence, random_utterance
from uberduck_ml_dev.text.symbols import symbols_with_ipa
from uberduck_ml_dev.trainer.base import TTSTrainer

from uberduck_ml_dev.data_loader import (
    TextAudioSpeakerLoader,
    TextMelCollate,
    DistributedBucketSampler,
    TextMelDataset,
)
from uberduck_ml_dev.vendor.tfcompat.hparam import HParams
from uberduck_ml_dev.utils.plot import save_figure_to_numpy, plot_spectrogram
from uberduck_ml_dev.utils.utils import slice_segments, clip_grad_value_
from uberduck_ml_dev.text.symbols import SYMBOL_SETS

TypeError: Couldn't build proto file into descriptor pool!
Invalid proto descriptor for file "tensorboard/compat/proto/tensor_shape.proto":
  tensorboard.TensorShapeProto.dim: "tensorboard.TensorShapeProto.dim" is already defined in file "tensorboardX/src/tensor_shape.proto".
  tensorboard.TensorShapeProto.unknown_rank: "tensorboard.TensorShapeProto.unknown_rank" is already defined in file "tensorboardX/src/tensor_shape.proto".
  tensorboard.TensorShapeProto.Dim.size: "tensorboard.TensorShapeProto.Dim.size" is already defined in file "tensorboardX/src/tensor_shape.proto".
  tensorboard.TensorShapeProto.Dim.name: "tensorboard.TensorShapeProto.Dim.name" is already defined in file "tensorboardX/src/tensor_shape.proto".
  tensorboard.TensorShapeProto.Dim: "tensorboard.TensorShapeProto.Dim" is already defined in file "tensorboardX/src/tensor_shape.proto".
  tensorboard.TensorShapeProto: "tensorboard.TensorShapeProto" is already defined in file "tensorboardX/src/tensor_shape.proto".
  tensorboard.TensorShapeProto.dim: "tensorboard.TensorShapeProto.Dim" seems to be defined in "tensorboardX/src/tensor_shape.proto", which is not imported by "tensorboard/compat/proto/tensor_shape.proto".  To use it here, please add the necessary import.


# Grad TTS Trainer

In [None]:
# export
from tqdm import tqdm
from uberduck_ml_dev.text.util import text_to_sequence, random_utterance
from uberduck_ml_dev.models.gradtts import GradTTS
from uberduck_ml_dev.utils.utils import intersperse


class GradTTSTrainer(TTSTrainer):
    """Trainer that fits a ``GradTTS`` model on text/mel pairs.

    The trainer builds train/test ``TextMelDataset`` objects from the file
    lists named in the hparams, optimizes the model with Adam, logs the three
    loss terms and gradient norms every iteration, and periodically logs
    generated spectrograms/audio and saves ``state_dict`` checkpoints.
    """

    # Names that must be available as attributes on the trainer after the
    # base-class constructor has run.
    REQUIRED_HPARAMS = [
        "training_audiopaths_and_text",
        "test_audiopaths_and_text",
    ]

    def __init__(self, *args, **kwargs):
        """Initialize the base trainer and validate the required params.

        Raises:
            Exception: if any name in ``REQUIRED_HPARAMS`` is not an attribute
                of the trainer once ``TTSTrainer.__init__`` has finished.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): the check is against the trainer instance, not
        # self.hparams; presumably the base class mirrors hparams onto the
        # instance -- confirm against TTSTrainer.
        for param in self.REQUIRED_HPARAMS:
            if not hasattr(self, param):
                raise Exception(f"GradTTSTrainer missing a required param: {param}")
        self.sampling_rate = self.hparams.sampling_rate
        self.checkpoint_path = self.hparams.log_dir

    def _build_text_mel_dataset(self, audiopaths_and_text):
        """Build a ``TextMelDataset`` for one file list from the hparams.

        Args:
            audiopaths_and_text: file list passed as the dataset's first
                positional argument.

        Returns:
            The constructed ``TextMelDataset``.
        """
        hparams = self.hparams
        return TextMelDataset(
            audiopaths_and_text,
            hparams.text_cleaners,
            1.0,
            hparams.n_feats,
            hparams.sampling_rate,
            hparams.mel_fmin,
            hparams.mel_fmax,
            hparams.filter_length,
            hparams.hop_length,
            (hparams.filter_length - hparams.hop_length) // 2,
            hparams.win_length,
            intersperse_text=hparams.intersperse_text,
            # The interspersed token id is one past the last symbol index.
            intersperse_token=(len(SYMBOL_SETS[hparams.symbol_set])),
            symbol_set=hparams.symbol_set,
        )

    def sample_inference(self, model, timesteps=10, spk=None):
        """Synthesize audio for a random utterance with the given model.

        Args:
            model: callable taking ``(x, x_lengths, ...)`` and returning
                ``(y_enc, y_dec, attn)``.
            timesteps: currently unused (see the note in the body).
            spk: forwarded to the model as ``spk``.

        Returns:
            Whatever ``self.sample`` returns for the decoder output.
        """
        # NOTE(review): `timesteps` is accepted but never used; the model is
        # always called with n_timesteps=50. Left as-is so that existing
        # callers relying on the 50-step output are not silently changed.
        with torch.no_grad():
            sequence = text_to_sequence(
                random_utterance(),
                self.text_cleaners,
                1.0,
                symbol_set=self.hparams.symbol_set,
            )
            if self.hparams.intersperse_text:
                sequence = intersperse(
                    sequence, (len(SYMBOL_SETS[self.hparams.symbol_set]))
                )
            # Add a leading batch dimension of size 1.
            x = torch.LongTensor(sequence).cuda()[None]
            x_lengths = torch.LongTensor([x.shape[-1]]).cuda()
            y_enc, y_dec, attn = model(
                x,
                x_lengths,
                n_timesteps=50,
                temperature=1.5,
                stoc=False,
                spk=spk,
                length_scale=0.91,
            )
            if self.hparams.vocoder_algorithm == "hifigan":
                audio = self.sample(
                    y_dec,
                    algorithm=self.hparams.vocoder_algorithm,
                    hifigan_config=self.hparams.hifigan_config,
                    hifigan_checkpoint=self.hparams.hifigan_checkpoint,
                    cudnn_enabled=self.hparams.cudnn_enabled,
                )
            else:
                audio = self.sample(y_dec.cpu()[0])
            return audio

    def train(self, checkpoint=None):
        """Run the full training loop.

        Args:
            checkpoint: optional path to a ``state_dict`` file to warm-start
                from. When ``None`` (the default) ``self.hparams.checkpoint``
                is used instead, which matches the previous behavior.
        """
        if self.distributed_run:
            self.init_distributed()

        # train.log and the checkpoints are written under log_dir below, so
        # make sure it exists before anything tries to open a file in it.
        os.makedirs(self.hparams.log_dir, exist_ok=True)

        train_dataset = self._build_text_mel_dataset(
            self.hparams.training_audiopaths_and_text
        )
        collate_fn = TextMelCollate()

        loader = DataLoader(
            dataset=train_dataset,
            batch_size=self.hparams.batch_size,
            collate_fn=collate_fn,
            drop_last=True,
            num_workers=0,
            shuffle=False,
        )

        test_dataset = self._build_text_mel_dataset(
            self.hparams.test_audiopaths_and_text
        )

        model = GradTTS(self.hparams)

        # An explicit argument wins over the hparam; previously the argument
        # was silently ignored.
        checkpoint_file = (
            checkpoint if checkpoint is not None else self.hparams.checkpoint
        )
        if checkpoint_file:
            # Load onto CPU first; the model is moved to the GPU right after,
            # so the device the checkpoint was saved from does not matter.
            model.load_state_dict(torch.load(checkpoint_file, map_location="cpu"))
        model = model.cuda()

        print(
            "Number of encoder + duration predictor parameters: %.2fm"
            % (model.encoder.nparams / 1e6)
        )
        print("Number of decoder parameters: %.2fm" % (model.decoder.nparams / 1e6))
        print("Total parameters: %.2fm" % (model.nparams / 1e6))

        print("Initializing optimizer...")
        optimizer = torch.optim.Adam(
            params=model.parameters(), lr=self.hparams.learning_rate
        )

        # Log the ground-truth mel of every test item once, at step 0, so the
        # generated images logged later have a reference next to them.
        test_batch = test_dataset.sample_test_batch(size=self.hparams.test_size)
        for i, (_, mel, _) in enumerate(test_batch):
            self.log(
                f"image_{i}/ground_truth",
                0,
                image=plot_tensor(mel.squeeze()),
            )

        iteration = 0
        last_time = time.time()
        for epoch in range(0, self.hparams.n_epochs):
            model.train()
            dur_losses = []
            prior_losses = []
            diff_losses = []
            for batch_idx, batch in enumerate(loader):
                model.zero_grad()
                # NOTE(review): the batch tensors are passed to the model as
                # produced by the collate function; device placement is
                # assumed to happen inside the model -- confirm.
                x, x_lengths, y, _, y_lengths, speaker_ids = batch

                dur_loss, prior_loss, diff_loss = model.compute_loss(
                    x, x_lengths, y, y_lengths, out_size=self.hparams.out_size
                )
                loss = sum([dur_loss, prior_loss, diff_loss])
                loss.backward()

                # Encoder and decoder gradients are clipped independently.
                enc_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.encoder.parameters(), max_norm=1
                )
                dec_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.decoder.parameters(), max_norm=1
                )
                optimizer.step()

                self.log("training/duration_loss", iteration, dur_loss.item())
                self.log("training/prior_loss", iteration, prior_loss.item())
                self.log("training/diffusion_loss", iteration, diff_loss.item())
                self.log("training/encoder_grad_norm", iteration, enc_grad_norm)
                self.log("training/decoder_grad_norm", iteration, dec_grad_norm)

                dur_losses.append(dur_loss.item())
                prior_losses.append(prior_loss.item())
                diff_losses.append(diff_loss.item())

                iteration += 1

            # Per-epoch summary: mean of each loss term plus wall-clock time.
            log_msg = f"Epoch {epoch}, iter: {iteration}: dur_loss: {np.mean(dur_losses):.4f} | prior_loss: {np.mean(prior_losses):.4f} | diff_loss: {np.mean(diff_losses):.4f} | time: {time.time()-last_time:.2f}s"
            last_time = time.time()
            with open(f"{self.hparams.log_dir}/train.log", "a") as f:
                f.write(log_msg + "\n")
                print(log_msg)

            if epoch % self.log_interval == 0:
                # model.train() at the top of the next epoch undoes this.
                model.eval()
                with torch.no_grad():
                    for i, item in enumerate(test_batch):
                        x, _y, _speaker_id = item
                        x = x.to(torch.long).unsqueeze(0)
                        x_lengths = torch.LongTensor([x.shape[-1]])
                        y_enc, y_dec, attn = model(x, x_lengths, n_timesteps=50)
                        self.log(
                            f"image_{i}/generated_enc",
                            iteration,
                            image=plot_tensor(y_enc.squeeze().cpu()),
                        )
                        self.log(
                            f"image_{i}/generated_dec",
                            iteration,
                            image=plot_tensor(y_dec.squeeze().cpu()),
                        )
                        self.log(
                            f"image_{i}/alignment",
                            iteration,
                            image=plot_tensor(attn.squeeze().cpu()),
                        )
                        # The audio sample comes from a random utterance, not
                        # from the test item whose images are logged above.
                        self.log(
                            f"audio/inference_{i}",
                            iteration,
                            audio=self.sample_inference(model),
                        )

            if epoch % self.save_every == 0:
                torch.save(
                    model.state_dict(),
                    f=f"{self.hparams.log_dir}/{self.checkpoint_name}_{epoch}.pt",
                )

In [None]:
# Baseline hyperparameters used to smoke-test trainer construction in this
# notebook. The mapping is unpacked into keyword arguments, so HParams receives
# exactly the same names and values as a direct keyword call would pass.
DEFAULTS = HParams(
    **{
        # File lists, logging destination and text front-end.
        "training_audiopaths_and_text": "train.txt",
        "test_audiopaths_and_text": "val.txt",
        "cudnn_enabled": True,
        "log_dir": "output",
        "symbol_set": "gradtts",
        "intersperse_text": True,
        "n_spks": 1,
        "spk_emb_dim": 64,
        "sampling_rate": 22050,
        "hop_length": 256,
        "win_length": 1024,
        "n_enc_channels": 192,
        "filter_channels": 768,
        "filter_channels_dp": 256,
        "n_enc_layers": 6,
        "enc_kernel": 3,
        "enc_dropout": 0.1,
        "n_heads": 2,
        "window_size": 4,
        "dec_dim": 64,
        "beta_min": 0.05,
        "beta_max": 20.0,
        "pe_scale": 1000,
        "test_size": 2,
        "n_epochs": 10000,
        "batch_size": 1,
        "learning_rate": 1e-4,
        "seed": 37,
        # 2 * sampling_rate // hop_length for the values above (== 172).
        "out_size": 2 * 22050 // 256,
        "filter_length": 1024,
        "rank": 0,
        "distributed_run": False,
        "oversample_weights": None,
        "text_cleaners": ["english_cleaners"],
        "max_wav_value": 32768.0,
        "n_feats": 80,
        "mel_fmax": 8000,
        "mel_fmin": 0.0,
        "checkpoint": None,
        "log_interval": 100,
        "save_every": 1000,
    }
)
# Constructing the trainer prints the resolved hparams (see the cell output).
trainer = GradTTSTrainer(DEFAULTS, rank=0)

TTSTrainer start 1311.942060225
Initializing trainer with hparams:
{'batch_size': 1,
 'beta_max': 20.0,
 'beta_min': 0.05,
 'checkpoint': None,
 'cudnn_enabled': True,
 'dec_dim': 64,
 'distributed_run': False,
 'enc_dropout': 0.1,
 'enc_kernel': 3,
 'filter_channels': 768,
 'filter_channels_dp': 256,
 'filter_length': 1024,
 'hop_length': 256,
 'intersperse_text': True,
 'learning_rate': 0.0001,
 'log_dir': 'output',
 'log_interval': 100,
 'max_wav_value': 32768.0,
 'mel_fmax': 8000,
 'mel_fmin': 0.0,
 'n_enc_channels': 192,
 'n_enc_layers': 6,
 'n_epochs': 10000,
 'n_feats': 80,
 'n_heads': 2,
 'n_spks': 1,
 'out_size': 172,
 'oversample_weights': None,
 'pe_scale': 1000,
 'rank': 0,
 'sampling_rate': 22050,
 'save_every': 1000,
 'seed': 37,
 'spk_emb_dim': 64,
 'symbol_set': 'gradtts',
 'test_audiopaths_and_text': 'val.txt',
 'test_size': 2,
 'text_cleaners': ['english_cleaners'],
 'training_audiopaths_and_text': 'train.txt',
 'win_length': 1024,
 'window_size': 4}
