In [1]:
%load_ext autoreload
%autoreload 2
%cd ..

/home/branch/rnd/research-pt2/grad-tts


In [2]:
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session
# (deprecations, numerical warnings from torch/sklearn, ...); consider filtering
# only the specific categories/modules that are noisy.
warnings.filterwarnings("ignore")

import os
import json
import sys
import dotsi

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

import torch 

# Project-local modules, resolved relative to the working directory set in the
# first cell.
# NOTE(review): DictToAttr, Struct and AttrDict are not used in the cells shown here.
from utils import DictToAttr, Struct, AttrDict
from data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate

# `params` supplies the GradTTS constructor arguments and `params.out_size`
# used by model.compute_loss.
import params
from model.tts import GradTTS

In [3]:
# [PREPARE CONFIG]
# ESD speakers 0011-0020 all use the same filelist layout
# (resources/filelists/esd-<id>/<split>.txt) and the same audio root, so their
# [filelist_path, fileroot] pairs are generated instead of written out by hand.
# ESD_ROOT can be overridden through the environment when running on another
# machine; the default is the original local path.
ESD_ROOT = os.environ.get("ESD_ROOT", "/home/branch/Downloads/ESD/ESD")
ESD_SPEAKER_IDS = [f"{speaker:04d}" for speaker in range(11, 21)]


def _esd_metadata_paths(split):
    """Return one [filelist_path, fileroot] pair per ESD speaker.

    split: filelist stem, e.g. "train" or "eval".
    """
    return [
        [f"resources/filelists/esd-{speaker_id}/{split}.txt", os.path.join(ESD_ROOT, speaker_id)]
        for speaker_id in ESD_SPEAKER_IDS
    ]


config = {
    "dataset": {
        "train_metadata_paths": _esd_metadata_paths("train"),
        "eval_metadata_paths": _esd_metadata_paths("eval"),
        "label2id_path": "resources/esd_emotion.json",
        "cmudict_path": "resources/cmu_dictionary",
        "eval_texts": [
            "In all these lines the facts are drawn together by a strong thread of unity.",
            "After the construction and action of the machine had been explained, the doctor asked the governor what kind of men he had commanded at Goree.",
            "After a few years of active exertion the Society was rewarded by fresh legislation.",
        ]
    },
    "training": {
        "device": "cpu",
        "batch_size": 3,
        "num_workers": 1,
        "lr_init": 1e-4,
        "n_epochs": 999,
        # Starting value of the optimizer-step counter used by the training loop.
        "global_step": 0,
        # NOTE(review): absolute local path; adjust per machine.
        "saveroot": "/home/branch/rnd/misc/exp-test/daisy-test-a10",
        # Evaluate / checkpoint every N optimizer steps.
        "eval_every": 3,
    }
}

# Wrap in dotsi.Dict so later cells can use attribute access
# (e.g. config.training.device, config.dataset.label2id_path).
config = dotsi.Dict(config)

In [4]:
# [PREPARE DATASET]
# Label <-> integer-id maps; id2label names the classes in the evaluation
# embedding plots. The file is closed deterministically via `with`.
with open(config.dataset.label2id_path, "r") as label_file:
    label2id = json.load(label_file)
id2label = {label_id: label for label, label_id in label2id.items()}


def _build_concat_dataset(metadata_paths):
    """Concatenate one TextMelSpeakerDataset per (filelist_path, fileroot) pair."""
    return torch.utils.data.ConcatDataset([
        TextMelSpeakerDataset(
            filelist_path=filelist_path,
            cmudict_path=config.dataset.cmudict_path,
            split_char="|",
            label2id_path=config.dataset.label2id_path,
            fileroot=fileroot,
        )
        for filelist_path, fileroot in metadata_paths
    ])


train_dataset = _build_concat_dataset(config.dataset.train_metadata_paths)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.training.batch_size,
    collate_fn=TextMelSpeakerBatchCollate(),
    shuffle=True,
    num_workers=config.training.num_workers,
)

# Evaluation is optional. Both names are always defined, so later cells can test
# `eval_loader is not None` instead of hitting a NameError. A missing key, None,
# or an empty list all disable evaluation; an empty list would otherwise make
# ConcatDataset raise.
eval_metadata_paths = config.dataset.get("eval_metadata_paths")
if eval_metadata_paths:
    eval_dataset = _build_concat_dataset(eval_metadata_paths)

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.training.batch_size,
        collate_fn=TextMelSpeakerBatchCollate(),
        shuffle=False,
        num_workers=config.training.num_workers,
    )
else:
    eval_dataset = None
    eval_loader = None

In [5]:
# [PREPARE MODEL]
# Every GradTTS keyword argument has a same-named attribute in the `params`
# module, so the keyword arguments are collected by name, in the same order.
model = GradTTS(**{
    name: getattr(params, name)
    for name in (
        "n_vocab", "n_spks", "spk_emb_dim",
        "n_enc_channels", "filter_channels", "filter_channels_dp",
        "n_heads", "n_enc_layers",
        "enc_kernel", "enc_dropout", "window_size",
        "n_feats", "dec_dim",
        "beta_min", "beta_max", "pe_scale",
        "with_reference_encoder",
        "re_n_feats", "re_n_channels", "re_filter_channels",
        "re_n_heads", "re_n_layers", "re_kernel_size", "re_p_dropout",
        "with_film",
    )
})

# Move the parameters to the training device; `_` discards the returned module.
_ = model.to(config.training.device)

In [6]:
# [PREPARE OPTIMIZER]
# AdamW over all model parameters at the configured initial learning rate.
# There is no standalone criterion: the training loop gets its loss terms from
# model.compute_loss.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.training.lr_init,
)

In [7]:
# [PREPARE CHECKPOINT DIRECTORY]
# Ensure <saveroot>/, <saveroot>/logs/ and <saveroot>/checkpoints/ exist.
# makedirs(exist_ok=True) creates missing parents and is a no-op for existing
# directories, so re-running this cell is safe.
for checkpoint_subdir in ("", "logs", "checkpoints"):
    os.makedirs(os.path.join(config.training.saveroot, checkpoint_subdir), exist_ok=True)

In [12]:
# [TRAINING LOOP]
# Trains for config.training.n_epochs epochs. Every config.training.eval_every
# optimizer steps, a folder <saveroot>/checkpoints/<global_step>/ is written with
# the loss history, the config, the optimizer state and the model weights. When
# an evaluation loader exists, a PCA scatter plot of reference-encoder
# embeddings (eval_emo_pcs.png) is saved there as well.
device = config.training.device
# Eval batches embedded per evaluation. 5 reproduces the previous hard-coded
# limit; set config.training.max_eval_batches = None to use the whole eval set.
max_eval_batches = config.training.get("max_eval_batches", 5)


def _collect_eval_embeddings(loader, max_batches):
    """Embed up to `max_batches` eval batches with the reference encoder.

    Returns (embeddings, label_ids), each concatenated along dim 0 on the CPU,
    or (None, None) when no batch was processed. The model is switched to eval
    mode for the pass and always back to train mode afterwards, even if the
    pass is interrupted (e.g. by a KeyboardInterrupt).
    """
    if max_batches is not None and max_batches <= 0:
        return None, None
    embeds, label_ids = [], []
    model.eval()
    try:
        with torch.no_grad():
            for eval_step, eval_batch in enumerate(loader):
                # The remaining outputs (gammas, betas) are not needed here.
                ref_embeddings, _, _ = model.reference_encoder(
                    eval_batch["y"].to(device),
                    eval_batch["y_lengths"].to(device),
                )
                embeds.append(ref_embeddings.cpu())
                label_ids.append(eval_batch["spk"].cpu())
                sys.stdout.write(f"\r{eval_step+1}/{len(loader)}")
                # Stop right after the last wanted batch so no extra batch is loaded.
                if max_batches is not None and eval_step + 1 >= max_batches:
                    break
    finally:
        model.train()
    if not embeds:
        return None, None
    return torch.cat(embeds), torch.cat(label_ids)


def _plot_embedding_pcs(embeds, label_ids, title, out_path):
    """Save a scatter plot of the first two principal components of `embeds`.

    One series per distinct label id, named through id2label (batch["spk"] is
    decoded with the map built from label2id_path). Returns False and writes
    nothing when there are too few samples/components for a 2-D plot.
    """
    if len(embeds) < 2:
        return False
    pcs = PCA().fit_transform(embeds.numpy())
    if pcs.shape[1] < 2:
        return False
    label_ids = label_ids.numpy()
    fig, ax = plt.subplots()
    ax.set_title(title)
    for label_id in np.unique(label_ids):
        mask = label_ids == label_id
        ax.scatter(pcs[mask, 0], pcs[mask, 1], label=id2label[int(label_id)], alpha=0.33)
    ax.legend(loc="best")
    fig.savefig(out_path)
    plt.close(fig)  # avoid accumulating open figures across evaluations
    return True


def _save_checkpoint(ckpt_dir, step):
    """Write losses.tsv, config.json, optimizer_state.pt and model_weights.pt to ckpt_dir."""
    pd.DataFrame(loss_hist, columns=["loss"]).to_csv(
        os.path.join(ckpt_dir, "losses.tsv"),
        sep="\t",
        index=None,
    )
    # Record the step this checkpoint was taken at (the live config only holds
    # the starting value) without mutating `config` itself.
    ckpt_config = {**config, "training": {**config.training, "global_step": step}}
    with open(os.path.join(ckpt_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(ckpt_config, f, ensure_ascii=False, indent=4)
    torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer_state.pt"))
    torch.save(model.state_dict(), os.path.join(ckpt_dir, "model_weights.pt"))


loss_hist = list()
global_step = config.training.global_step

_ = model.train()
for epoch in range(config.training.n_epochs):
    for step, batch in enumerate(train_loader):

        # Forward pass
        optimizer.zero_grad()

        dur_loss, prior_loss, diff_loss, aux_clf_loss = model.compute_loss(
            batch["x"].to(device), batch["x_lengths"].to(device),
            batch["y"].to(device), batch["y_lengths"].to(device),
            labels=batch["spk"].to(device),
            out_size=params.out_size,
        )

        # Total loss is the unweighted sum of the four terms
        loss = dur_loss + prior_loss + diff_loss + aux_clf_loss

        # Backprop
        loss.backward()
        optimizer.step()

        # Logging
        loss_value = loss.item()
        loss_hist.append(loss_value)
        global_step += 1
        sys.stdout.write(
            f"\r[epoch] {epoch+1}/{config.training.n_epochs} [global step] {global_step}"
            f" [step] {step+1}/{len(train_loader)} [loss] {loss_value}"
        )

        if (global_step % config.training.eval_every) == 0:
            ckpt_dir = os.path.join(config.training.saveroot, "checkpoints", str(global_step))
            os.makedirs(ckpt_dir, exist_ok=True)

            # Evaluation (optional): PCA of reference-encoder embeddings
            if eval_loader is not None:
                print("\n===> Evaluating...")
                eval_embeds, eval_ids = _collect_eval_embeddings(eval_loader, max_eval_batches)
                if eval_embeds is not None:
                    _plot_embedding_pcs(
                        eval_embeds,
                        eval_ids,
                        f"Global Step: {global_step}",
                        os.path.join(ckpt_dir, "eval_emo_pcs.png"),
                    )

            # The checkpoint is saved whether or not an eval set is configured.
            _save_checkpoint(ckpt_dir, global_step)



[epoch] 1/3 [step] 3/4995 [loss] 19.675281524658203
===> Evaluating...
[epoch] 1/6 [step] 6/4995 [loss] 19.381908416748047
===> Evaluating...
[epoch] 1/9 [step] 9/4995 [loss] 17.178924560546875
===> Evaluating...
[epoch] 1/12 [step] 12/4995 [loss] 14.079953193664555
===> Evaluating...
5/334

KeyboardInterrupt: 

In [9]:
# Display the current configuration (debugging aid).
config

{'dataset': {'train_metadata_paths': [['resources/filelists/esd-0011/train.txt',
    '/home/branch/Downloads/ESD/ESD/0011'],
   ['resources/filelists/esd-0012/train.txt',
    '/home/branch/Downloads/ESD/ESD/0012'],
   ['resources/filelists/esd-0013/train.txt',
    '/home/branch/Downloads/ESD/ESD/0013'],
   ['resources/filelists/esd-0014/train.txt',
    '/home/branch/Downloads/ESD/ESD/0014'],
   ['resources/filelists/esd-0015/train.txt',
    '/home/branch/Downloads/ESD/ESD/0015'],
   ['resources/filelists/esd-0016/train.txt',
    '/home/branch/Downloads/ESD/ESD/0016'],
   ['resources/filelists/esd-0017/train.txt',
    '/home/branch/Downloads/ESD/ESD/0017'],
   ['resources/filelists/esd-0018/train.txt',
    '/home/branch/Downloads/ESD/ESD/0018'],
   ['resources/filelists/esd-0019/train.txt',
    '/home/branch/Downloads/ESD/ESD/0019'],
   ['resources/filelists/esd-0020/train.txt',
    '/home/branch/Downloads/ESD/ESD/0020']],
  'eval_metadata_paths': [['resources/filelists/esd-0011/eval.tx

In [19]:
# Recompute the four loss terms for `batch` without tracking gradients.
# NOTE(review): `batch` is whatever the most recently executed loop left behind
# (hidden state). If the training loop was interrupted during evaluation, the
# model may also still be in eval mode. Check both before trusting these values.
with torch.no_grad():
    dur_loss, prior_loss, diff_loss, aux_clf_loss = model.compute_loss(
        *[batch[key].to(config.training.device) for key in ("x", "x_lengths", "y", "y_lengths")],
        labels=batch["spk"].to(config.training.device),
        out_size=params.out_size,
    )

In [15]:
# Display the integer label ids carried by the leftover `batch`
# (the same field the training loop passes to compute_loss as `labels`).
batch["spk"]

tensor([2, 3, 2])

In [10]:
# Run the reference encoder on the leftover `batch` without tracking gradients.
# The training loop unpacks this call's outputs as (ref_embeddings, gammas, betas),
# so here `u` is the embeddings, `b` the gammas and `g` the betas.
# Inputs are moved to the configured device, as in the training loop, so the
# cell keeps working when config.training.device is not "cpu".
with torch.no_grad():
    u, b, g = model.reference_encoder(
        batch["y"].to(config.training.device),
        batch["y_lengths"].to(config.training.device),
    )

In [11]:
# Shapes of the reference-encoder outputs from the probe cell above; the training
# loop names those outputs (ref_embeddings, gammas, betas) for (u, b, g).
b.shape, g.shape, u.shape

(torch.Size([3, 6, 192, 1]),
 torch.Size([3, 6, 192, 1]),
 torch.Size([3, 80, 336]))

In [61]:
from glob import glob
from shutil import copyfile
import pandas as pd

In [65]:
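# One-off filelist maintenance, kept commented out for reference:
#   1. capitalise the third "|"-separated column of every ESD filelist in place;
#   2. copy the filelists from /home/branch/rnd/datasets/esd-0*/ into resources/filelists/.
# for df_name in glob("resources/filelists/esd*/*.txt"):
#     df = pd.read_csv(df_name, sep="|", names=["x","y","z"])
#     df["z"] = [str(i).capitalize() for i in df["z"].to_list()]
#     df.to_csv(df_name, header=False, index=None, sep="|")

# for df_name in glob("/home/branch/rnd/datasets/esd-0*/*.txt"):
#     idd = df_name.split("/")[-2]
#     tdd = df_name.split("/")[-1]
#     tg_name = f"resources/filelists/{idd}/{tdd}"
#     copyfile(df_name, tg_name)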