In [4]:
import warnings
warnings.filterwarnings("ignore")

import argparse
import logging
import os
import sys
import time

from collections import defaultdict, OrderedDict

import matplotlib
import numpy as np
import soundfile as sf
import torch
import yaml

from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from tqdm import tqdm

import sys

sys.path.append("../../cuhksz-phd/sho_util/pyfiles/")
from basic import plot_spectrogram

sys.path.append("../")
from pyfiles.dataset import PretrainingMelDataset
from pyfiles.utils import Dict2Obj

import seq2seq_vc
import seq2seq_vc.models
import seq2seq_vc.losses
import seq2seq_vc.trainers
import seq2seq_vc.collaters

# from seq2seq_vc.datasets import ParallelVCMelDataset
from torch.utils.data import Dataset

from seq2seq_vc.utils import read_hdf5
from seq2seq_vc.utils.types import str_or_none
# from seq2seq_vc.vocoder import Vocoder
# from seq2seq_vc.vocoder.s3prl_feat2wav import S3PRL_Feat2Wav
# from seq2seq_vc.vocoder.griffin_lim import Spectrogram2Waveform
# from seq2seq_vc.vocoder.encodec import EnCodec_decoder

# set to avoid matplotlib error in CLI environment
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

from seq2seq_vc.schedulers.warmup_lr import WarmupLR

# Registry of learning-rate schedulers selectable through config["scheduler_type"].
scheduler_classes = {"warmuplr": WarmupLR}

class Dict2Obj(object):
    """Expose the entries of a dictionary as instance attributes.

    Every key becomes an attribute name, so ``vars(obj)`` gives back a plain
    dict with the same key/value pairs.

    Note: this definition replaces the ``Dict2Obj`` imported from
    ``pyfiles.utils`` earlier in this cell.
    """

    def __init__(self, dictionary):
        """Copy each key/value pair of ``dictionary`` onto ``self``."""
        for name, value in dictionary.items():
            setattr(self, name, value)

import joblib
import glob
# Utterance split definition, loaded from a local .npy file. PretrainingL2Arctic
# indexes it positionally: element 0 is used for "train", 1 for "valid" and
# 2 for "test".
# NOTE(review): allow_pickle=True makes np.load unpickle arbitrary Python
# objects, so this file must only ever come from a trusted source.
datasplit = list(np.load("./data_split_ARCTIC.npy", allow_pickle=True))
# One row per speaker: [speaker id, accent, sex].
information = [
    ["ABA", "Arabic", "M"],
    ["SKA", "Arabic", "F"],
    ["YBAA", "Arabic", "M"],
    ["ZHAA", "Arabic", "F"],
    ["BWC", "Mandarin", "M"],
    ["LXC", "Mandarin", "F"],
    ["NCC", "Mandarin", "F"],
    ["TXHC", "Mandarin", "M"],
    ["ASI", "Hindi", "M"],
    ["RRBI", "Hindi", "M"],
    ["SVBI", "Hindi", "F"],
    ["TNI", "Hindi", "F"],
    ["HJK", "Korean", "F"],
    ["HKK", "Korean", "M"],
    ["YDCK", "Korean", "F"],
    ["YKWK", "Korean", "M"],
    ["EBVS", "Spanish", "M"],
    ["ERMS", "Spanish", "M"],
    ["MBMPS", "Spanish", "F"],
    ["NJS", "Spanish", "F"],
    ["HQTV", "Vietnamese", "M"],
    ["PNV", "Vietnamese", "F"],
    ["THV", "Vietnamese", "F"],
    ["TLV", "Vietnamese", "M"],
]

# Forward lookups: speaker id -> accent / sex.
spk2acc = {name: accent for name, accent, _ in information}
spk2sex = {name: sex for name, _, sex in information}

# Reverse lookups: accent / sex -> speaker ids, in the order they appear in the table.
acc2spk = {accent: [] for accent in set(spk2acc.values())}
sex2spk = {sex: [] for sex in set(spk2sex.values())}
for spk in spk2acc:
    acc2spk[spk2acc[spk]].append(spk)
    sex2spk[spk2sex[spk]].append(spk)

# Alphabetically sorted label vocabularies; list positions serve as integer ids.
accents = sorted(acc2spk)
speakers = sorted(spk2acc)
genders = sorted(sex2spk)

In [5]:
class PretrainingL2Arctic(Dataset):
    """Paired-feature dataset built from per-speaker ``.npy`` files.

    Files are looked up as ``<dataset_dir>/<speaker>/<feature>/<utt>.npy`` where
    ``<feature>`` is one of the two entries of ``input_output_type`` or the
    literal ``accent_embedding``.

    Args:
        dataset_dir: Root directory of the dataset (with or without a trailing
            path separator).
        speakers: Speaker ids to include; the position in this list is the
            integer ``speaker_id`` of each item.
        accents: Accent label vocabulary; positions are used as ``accent_id``.
        genders: Gender label vocabulary; positions are used as ``gender_id``.
        datasplit: Indexable holding the utterance names of the train, valid
            and test splits at positions 0, 1 and 2.
        scaler: Mapping from feature name to an object with a ``transform``
            method, applied to the loaded (transposed) feature array.
        mode: One of ``"train"``, ``"valid"`` or ``"test"``.
        input_output_type: ``[source feature name, target feature name]``.

    Raises:
        ValueError: If ``mode`` is not one of the three split names.

    Note:
        The speaker -> accent/sex lookups come from the module-level
        ``spk2acc`` and ``spk2sex`` dictionaries.
    """

    def __init__(self, dataset_dir, speakers, accents, genders, datasplit, scaler, mode="train", input_output_type=["wavlm", "mel"]):
        # Copy so that neither the shared default list nor a caller-owned list
        # is aliased by this instance.
        input_output_type = list(input_output_type)

        modefiles = datasplit[["train", "valid", "test"].index(mode)]
        try:
            # Constant-time membership tests instead of scanning the split
            # list once per globbed file.
            allowed = frozenset(modefiles)
        except TypeError:
            # Unhashable entries: fall back to the container's own `in`.
            allowed = modefiles

        # Utterance names per speaker, restricted to the requested split.
        data = {}
        for spk in speakers:
            pattern = os.path.join(dataset_dir, spk, input_output_type[0], "*.npy")
            names = (
                os.path.splitext(os.path.basename(path))[0]
                for path in glob.glob(pattern)
            )
            data[spk] = sorted(name for name in names if name in allowed)

        # Flat index: one [speaker_idx, accent_idx, gender_idx, utt_name] per item.
        files = []
        for s, spk in enumerate(speakers):
            accentid = accents.index(spk2acc[spk])
            genderid = genders.index(spk2sex[spk])
            files.extend([s, accentid, genderid, basename] for basename in data[spk])

        self.dataset_dir = dataset_dir
        self.data = data
        self.scaler = scaler
        self.files = files
        self.speakers = speakers
        self.accents = accents
        self.genders = genders
        self.input_output_type = input_output_type

    def __len__(self):
        """Return the number of utterances in the selected split."""
        return len(self.files)

    def _feature_path(self, speaker_name, feature, basename):
        """Return ``<dataset_dir>/<speaker_name>/<feature>/<basename>.npy``."""
        return os.path.join(self.dataset_dir, speaker_name, feature, f"{basename}.npy")

    def __getitem__(self, idx):
        """Load one utterance.

        Returns:
            dict with keys ``src_feat`` / ``trg_feat`` (scaled source and target
            features), ``src_condition`` / ``trg_condition`` (the utterance's
            accent embedding), ``utt_id``, ``speaker_id``, ``accent_id`` and
            ``gender_id``.
        """
        speaker, accent, gender, basename = self.files[idx]
        speaker_name = self.speakers[speaker]
        src_type, trg_type = self.input_output_type[0], self.input_output_type[1]

        src_mel = self._feature_path(speaker_name, src_type, basename)
        trg_mel = self._feature_path(speaker_name, trg_type, basename)
        # Build the embedding path from its components. The previous
        # str.replace() on the full path also rewrote any other occurrence of
        # the feature name (e.g. inside dataset_dir), yielding a wrong path.
        condition_path = self._feature_path(speaker_name, "accent_embedding", basename)

        items = {}
        # Arrays are transposed before scaling; presumably they are stored as
        # (feature_dim, frames) - TODO confirm against the extraction script.
        items["src_feat"] = self.scaler[src_type].transform(np.load(src_mel).T)
        items["trg_feat"] = self.scaler[trg_type].transform(np.load(trg_mel).T)
        condition = np.load(condition_path)
        items["src_condition"] = condition
        # Independent copy so in-place edits of one side never leak to the other.
        items["trg_condition"] = condition.copy()
        items["utt_id"] = basename
        items["speaker_id"] = speaker
        items["accent_id"] = accent
        items["gender_id"] = gender

        return items

In [6]:
# Dataset Variables
dataset_dir = "/mntcephfs/lab_data/shoinoue/Dataset/L2-ARCTIC/"

# One fitted scaler per feature type, keyed by the feature name used in
# `input_output_type`.
# NOTE(review): joblib.load unpickles its input; only load trusted files.
scaler = {}
for feature_name, scaler_filename in (
    ("mel", "ckpts/scalers/LibriTTS-R.save"),
    ("wavlm", "ckpts/scalers/LibriTTS-R_wavlm.save"),
):
    scaler[feature_name] = joblib.load(scaler_filename)

In [9]:
# Experiment knobs.
conditiontype = "nocondition"
size = "small"

# Output directories used by other experiment variants:
# args["outdir"] = f"/mntcephfs/lab_data/shoinoue/Models/trained_models/AC_01/ckpts/pretraining_concatenation_LibriTTS-R/"
# args["outdir"] = f"/mntcephfs/lab_data/shoinoue/Models/trained_models/AC_01/ckpts/pretraining_addition_LibriTTS-R/"
# Alternative warm-start checkpoint:
# args["init_checkpoint"] = f"/mntcephfs/lab_data/shoinoue/Models/trained_models/AC_01/ckpts/pretraining_nocondition_LibriTTS-R_wavlmmel_small/checkpoint-300001steps.pkl"

# Stand-in for the command-line arguments of the training script.
args = {
    "rank": 0,
    # A "_" separator is inserted only when `size` is non-empty.
    "outdir": f"/mntcephfs/lab_data/shoinoue/Models/trained_models/AC_01/ckpts/pretraining2_nocondition_LibriTTS-R_wavlmmel{'_'*int(bool(size))}{size}/",
}

# The input feature type is inferred from the output directory name.
input_type = "wavlm" if "wavlm" in args["outdir"] else "mel"

args.update(
    config_path=f"./../egs/l2-arctic/cascade/conf/{size}m2mvtn.wavlmmel_pt2.yaml",
    init_checkpoint="",  # empty string: no warm start from pretrained modules
    resume="/mntcephfs/lab_data/shoinoue/Models/trained_models/AC_01/ckpts/pretraining2_nocondition_LibriTTS-R_wavlmmel_small/checkpoint-180000steps.pkl",
    distributed=False,
)
args = Dict2Obj(args)

# load main config
# NOTE(review): yaml.Loader can construct arbitrary Python objects; acceptable
# for a trusted local config file, unsafe for untrusted input.
with open(args.config_path) as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)
# The run arguments are merged into the config, overriding same-named keys.
config.update(vars(args))

# Customization
config["model_params"]["conditiontype"] = conditiontype
config["optimizer_params"]["lr"] = 8e-5

In [10]:
# Train on the CUDA device selected by args.rank.
device = torch.device("cuda")
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(args.rank)
# exist_ok=True replaces the separate exists() check, which could race with
# another process creating the same directory between the check and the call.
os.makedirs(args.outdir, exist_ok=True)
    
### Dataset Preparation ###
# The trainer-facing split "dev" corresponds to the dataset mode "valid".
dataset = {}
for split_name, mode_name in (("train", "train"), ("dev", "valid")):
    dataset[split_name] = PretrainingL2Arctic(
        dataset_dir, speakers, accents, genders, datasplit, scaler, mode_name, ["wavlm", "mel"]
    )

collater_type = config.get("collater_type", "ARM2MVCCollater")
collater_class = getattr(seq2seq_vc.collaters, collater_type)
collater = collater_class()

# No custom samplers; both loaders rely on shuffle=True instead.
sampler = dict.fromkeys(("train", "dev"))
data_loader = {
    split: DataLoader(
        dataset=dataset[split],
        shuffle=True,
        collate_fn=collater,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        sampler=sampler[split],
        pin_memory=config["pin_memory"],
    )
    for split in ("train", "dev")
}

### Model Preparation ###
# The model class is resolved by name from seq2seq_vc.models.
model_type = config.get("model_type", "M2MVTNPro")
model_class = getattr(seq2seq_vc.models, model_type)
model = model_class(**config["model_params"]).to(device)

# Loss functions are mandatory: each config entry maps a class name in
# seq2seq_vc.losses to its constructor keyword arguments.
if not config.get("criterions", None):
    raise ValueError("Please specify criterions in the config file.")
criterion = {
    loss_name: getattr(seq2seq_vc.losses, loss_name)(**loss_kwargs)
    for loss_name, loss_kwargs in config["criterions"].items()
}

### optimizers and schedulers ###
optimizer_class = getattr(
    torch.optim,
    # keep compatibility
    config.get("optimizer_type", "Adam"),
)
optimizer = optimizer_class(
    model.parameters(),
    **config["optimizer_params"],
)

scheduler_type = config.get("scheduler_type", "warmuplr")
scheduler_class = scheduler_classes.get(scheduler_type)
if scheduler_class is None:
    # Without this check an unknown name surfaces later as the opaque
    # "'NoneType' object is not callable".
    raise ValueError(
        f"Unknown scheduler_type {scheduler_type!r}; "
        f"available: {sorted(scheduler_classes)}"
    )
scheduler = scheduler_class(
    optimizer=optimizer,
    **config["scheduler_params"],
)

### define trainer ###
trainer_type = config.get("trainer_type", "ARM2MVCADVTrainer")
trainer_class = getattr(seq2seq_vc.trainers, trainer_type)

# Step and epoch counters start at zero here; no vocoder is attached.
trainer_kwargs = dict(
    steps=0,
    epochs=0,
    data_loader=data_loader,
    sampler=sampler,
    model=model,
    vocoder=None,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    config=config,
    device=device,
)
trainer = trainer_class(**trainer_kwargs)

# Warm start: load only the modules listed in config["init-mods"].
if len(args.init_checkpoint) > 0:
    trainer.load_trained_modules(
        args.init_checkpoint, init_mods=config["init-mods"]
    )

# Resume: load a checkpoint through the trainer.
if len(args.resume) > 0:
    trainer.load_checkpoint(args.resume)

# Optional freezing; runs after loading.
freeze_mods = config.get("freeze-mods", None)
if freeze_mods is not None:
    assert type(freeze_mods) is list
    trainer.freeze_modules(freeze_mods)



In [11]:
try:
    trainer.run()
finally:
    # Runs on normal completion, on errors and on KeyboardInterrupt alike, so
    # the current state is always written out before the cell finishes.
    checkpoint_path = os.path.join(
        config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"
    )
    trainer.save_checkpoint(checkpoint_path)
    # NOTE(review): no logging configuration is visible in this notebook, so
    # this INFO message may not be displayed - confirm.
    logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")

[train]:  60%|██████    | 180003/300000 [00:07<76:41:52,  2.30s/it] Traceback (most recent call last):
  File "/mntcephfs/lab_data/shoinoue/miniconda3/envs/cuhk/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/mntcephfs/lab_data/shoinoue/miniconda3/envs/cuhk/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/mntcephfs/lab_data/shoinoue/miniconda3/envs/cuhk/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/mntcephfs/lab_data/shoinoue/miniconda3/envs/cuhk/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/mntcephfs/lab_data/shoinoue/miniconda3/envs/cuhk/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/mntcephfs/lab_data/shoinoue/miniconda3/envs/cuh

KeyboardInterrupt: 