In [1]:
import os
import time
import logging
import argparse

from utils.hparams import HParam
from utils.writer import MyWriter
from datasets.dataloader import create_dataloader

# Command-line style options for the run. Every option is a string; the table
# below lists (flags, extra keyword arguments) for each one.
parser = argparse.ArgumentParser()
for _flags, _extra in (
    (('-b', '--base_dir'),
     dict(default='.', help="Root directory of run.")),
    (('-c', '--config'),
     dict(required=True, help="yaml file for configuration")),
    (('-e', '--embedder_path'),
     dict(required=True, help="path of embedder model pt file")),
    (('--checkpoint_path',),
     dict(default=None, help="path of checkpoint pt file")),
    (('-m', '--model'),
     dict(required=True,
          help="Name of the model. Used for both logging and saving checkpoints.")),
):
    parser.add_argument(*_flags, type=str, **_extra)

# Inside a notebook there is no real command line, so a fixed argument list is
# parsed instead of sys.argv.
args = parser.parse_args(["-c", "config.yaml", "-e", "embedder.pt", "-m", "eval"])

# The config file is consumed twice: once as a structured HParam object and
# once as raw text, which is saved into / compared against checkpoints later.
hp = HParam(args.config)
with open(args.config, 'r') as config_file:
    hp_str = config_file.read()

# Checkpoint directory: <base_dir>/<hp.log.chkpt_dir>/<model>.
pt_dir = os.path.join(args.base_dir, hp.log.chkpt_dir, args.model)
os.makedirs(pt_dir, exist_ok=True)

# Log directory: <base_dir>/<hp.log.log_dir>/<model>.
log_dir = os.path.join(args.base_dir, hp.log.log_dir, args.model)
os.makedirs(log_dir, exist_ok=True)

# Stays None unless --checkpoint_path was given.
chkpt_path = args.checkpoint_path

# Log to a timestamped file inside log_dir as well as to the output stream.
log_file = os.path.join(log_dir, '%s-%d.log' % (args.model, time.time()))
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
)
logger = logging.getLogger()

# Refuse to continue when either dataset directory is left blank in the config.
if hp.data.train_dir == '' or hp.data.test_dir == '':
    logger.error("train_dir, test_dir cannot be empty.")
    raise Exception("Please specify directories of data in %s" % args.config)

writer = MyWriter(hp, log_dir)

# Build the training loader first, then the test loader.
trainloader, testloader = (
    create_dataloader(hp, args, train=is_train) for is_train in (True, False)
)

import os
import math
import torch
import torch.nn as nn
import traceback

from utils.adabound import AdaBound
from utils.audio import Audio
from utils.evaluation import validate
from model.model import VoiceFilter
from model.embedder import SpeechEmbedder
from utils.power_law_loss import PowerLawCompLoss
from utils.gdrive import GDrive

# Speaker embedder: read the saved state dict, build the network on the GPU,
# load the weights and put it in evaluation mode.
embedder_pt = torch.load(args.embedder_path)
embedder = SpeechEmbedder(hp).cuda()
embedder.load_state_dict(embedder_pt)
embedder.eval()

audio = Audio(hp)
model = VoiceFilter(hp).cuda()

# Pick the optimizer named in the config; anything else is rejected.
optimizer_name = hp.train.optimizer
if optimizer_name == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam)
elif optimizer_name == 'adabound':
    optimizer = AdaBound(
        model.parameters(),
        lr=hp.train.adabound.initial,
        final_lr=hp.train.adabound.final,
    )
else:
    raise Exception("%s optimizer not supported" % optimizer_name)

# Global step counter; overwritten when a checkpoint is restored.
step = 0

In [2]:
# Hard-coded checkpoint to resume from. This replaces the value that the first
# cell derived from args.checkpoint_path.
chkpt_path = "chkpt/final_try/chkpt_66000.pt"

In [3]:
logger.info("Resuming from checkpoint: %s" % chkpt_path)
checkpoint = torch.load(chkpt_path)

# Restore the network weights first, then the optimizer state.
for _target, _key in ((model, 'model'), (optimizer, 'optimizer')):
    _target.load_state_dict(checkpoint[_key])
step = checkpoint['step']

# The hparams loaded in this session take precedence; a mismatch with the
# text stored in the checkpoint is only reported, not treated as an error.
hparams_unchanged = (hp_str == checkpoint['hp_str'])
if not hparams_unchanged:
    logger.warning("New hparams is different from checkpoint.")

2021-11-25 08:31:21,665 - INFO - Resuming from checkpoint: chkpt/final_try/chkpt_66000.pt


In [4]:
# Loss objects: `criterion` (MSE) is the one the evaluation and training cells
# call; `_criterion` is created here but not referenced in the visible cells.
criterion = nn.MSELoss()
_criterion = PowerLawCompLoss()

# Put the separation model into evaluation mode.
model.eval()

# Gradient-accumulation bookkeeping read and updated by the training cell.
accum = accum_loss = 0

In [8]:
from mir_eval.separation import bss_eval_sources

In [9]:
# Evaluate every item yielded by `testloader`: run the model, compute the MSE
# between target and estimated magnitudes, reconstruct the waveform, compute
# SDR, and hand everything to the writer under a running index.
#
# NOTE(review): `step` is rebound to 0 here, which discards the value restored
# from the checkpoint in the resume cell. The training cell reads `step`
# afterwards, so it will continue from this loop's final count - confirm that
# this is intended.
step = 0
with torch.no_grad():
    for batch in testloader:
        # Only the first element of each batch is used.
        (dvec_mel, target_wav, mixed_wav,
         target_mag, mixed_mag, mixed_phase) = batch[0]

        # Move inputs to the GPU; the spectrograms get a leading axis added.
        dvec_mel = dvec_mel.cuda()
        target_mag = target_mag.unsqueeze(0).cuda()
        mixed_mag = mixed_mag.unsqueeze(0).cuda()

        # Speaker embedding, with a leading axis added to match the model input.
        dvec = embedder(dvec_mel).unsqueeze(0)

        # The model predicts a mask that is applied to the mixed magnitude.
        est_mask = model(mixed_mag, dvec)
        est_mag = est_mask * mixed_mag
        test_loss = criterion(target_mag, est_mag).item()

        # Drop the leading axis again and convert everything to NumPy.
        mixed_mag, target_mag, est_mag, est_mask = (
            tensor[0].cpu().detach().numpy()
            for tensor in (mixed_mag, target_mag, est_mag, est_mask)
        )
        # Waveform is rebuilt from the estimated magnitude and the mixture phase.
        est_wav = audio.spec2wav(est_mag, mixed_phase)

        sdr = bss_eval_sources(target_wav, est_wav, False)[0][0]
        writer.log_evaluation(
            test_loss, sdr,
            mixed_wav, target_wav, est_wav,
            mixed_mag.T, target_mag.T, est_mag.T, est_mask.T,
            step,
        )
        step += 1

In [None]:
# Training loop: a single pass over `trainloader` with gradient accumulation,
# periodic TensorBoard summaries, and periodic checkpoint + validation + upload.
#
# NOTE(review): `asyncio`, `UploadToDrive` and `drive` are used below but are
# neither imported nor defined in the visible cells (execution counts jump from
# In [4] to In [8]). Confirm they exist before running this on a fresh kernel.

# An earlier cell called model.eval(); switch back to training mode so that
# layers which behave differently during training are active while optimizing.
model.train()

for dvec_mels, target_mag, mixed_mag in trainloader:
    target_mag = target_mag.cuda()
    mixed_mag = mixed_mag.cuda()

    # Embed each reference mel separately, then stack into one tensor.
    dvec_list = list()
    for mel in dvec_mels:
        mel = mel.cuda()
        dvec = embedder(mel)
        dvec_list.append(dvec)
    dvec = torch.stack(dvec_list, dim=0)
    # Detach so that no gradient flows back into the embedder.
    dvec = dvec.detach()

    mask = model(mixed_mag, dvec)
    output = mixed_mag * mask

    # Alternative (disabled): compare power-compressed magnitudes instead.
    # output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power)
    # target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power)
    loss = criterion(output, target_mag)

    # Gradients accumulate across iterations until the optimizer step below.
    loss.backward()
    accum_loss += loss.item()
    accum += 1

    if accum % hp["train"]["grad_accumulate"] == 0:
        optimizer.step()
        optimizer.zero_grad()
        accum = 0
        step += 1
        # Mean loss over the accumulated mini-batches.
        accum_loss /= hp["train"]["grad_accumulate"]

        if accum_loss > 1e8 or math.isnan(accum_loss):
            logger.error("Loss exploded to %.02f at step %d!" % (accum_loss, step))
            raise Exception("Loss exploded")

        # Hard-coded stopping point for this run.
        if step == 2100:
            break

        # write loss to tensorboard
        if step % hp.train.summary_interval == 0:
            writer.log_training(accum_loss, step)
            logger.info("Wrote summary at step %d" % step)

        accum_loss = 0

        # 1. save checkpoint file to resume training
        # 2. evaluate and save sample to tensorboard
        # 3. back up the checkpoint and the logs
        if step % hp.train.checkpoint_interval == 0:
            save_path = os.path.join(pt_dir, 'chkpt_%d.pt' % step)
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': step,
                'hp_str': hp_str,
            }, save_path)
            logger.info("Saved checkpoint to: %s" % save_path)
            validate(audio, model, embedder, testloader, writer, step)
            # validate() is defined elsewhere and may leave the model in eval
            # mode; re-enable training mode explicitly (a no-op if already set).
            model.train()

            asyncio.run(UploadToDrive(drive, save_path))

            # Compress the log files into one archive and upload it.
            os.system(f'zip -j ./tensorboard.zip ./{log_dir}/*')
            drive.Upload('tensorboard.zip', "1sWAUt5vfyD97Cq85J8_zuwMeX4tmfEiZ")