# Import libraries & load configuration

In [1]:
import os
import torch
import random
import librosa
import numpy as np
import IPython.display
import json

from utils.audio import Audio
from utils.hparams import HParam

In [2]:
%matplotlib inline

from mir_eval.separation import bss_eval_sources
import matplotlib
import matplotlib.pyplot as plt

In [3]:
from utils.hparams import HParam
from datasets.dataloader import create_dataloader

In [4]:
import torch
import torch.nn as nn

from utils.adabound import AdaBound
from utils.power_law_loss import PowerLawCompLoss
from model.model import VoiceFilter
from model.embedder import SpeechEmbedder

In [5]:
# Hyperparameters shared by the dataloaders, the Audio helper and the models below.
hp = HParam("config.yaml")


# Prepare

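Create the test dataloaders for the three evaluation sets: the VN data (`vin`, `zalo-train`, `zalo-test`), the generated LibriSpeech test mixtures, and the GGSpeakerID test split.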

In [6]:
testloader_vn = create_dataloader(hp, "generate", dataset_detail=["vin", "zalo-train", "zalo-test"], scheme="test_cuda", size=5000)
testloader_lb = create_dataloader(hp, "generate", dataset_detail=["librispeech-test"], scheme="test_cuda", size=5000)
testloader_gg = create_dataloader(hp, "gg", dataset_detail="test", scheme="test_cuda")

`Audio` is a helper class that simplifies common operations on a single audio file, such as converting a waveform to a mel spectrogram, converting a mel spectrogram back to a waveform, and similar conversions.

In [7]:
# Shared Audio helper built from `hp`; the forward functions below use it to turn
# estimated spectrograms back into waveforms (`_istft` / `spec2wav`).
audio = Audio(hp)

Load the pretrained speaker embedder and the three VoiceFilter checkpoints to compare.

In [8]:
# Checkpoints are deserialized onto the CPU (used as torch.load's map_location);
# the modules are moved to the GPU afterwards with `.cuda()`.
device = "cpu"

# Speaker embedder producing the d-vectors that condition every VoiceFilter.
embedder_pt = torch.load("embedder.pt", map_location=device)
embedder = SpeechEmbedder(hp)
embedder.load_state_dict(embedder_pt)
embedder = embedder.cuda()
embedder.eval()


def load_voicefilter(checkpoint_path):
    """Build a VoiceFilter from ``hp`` and load its weights from a training checkpoint.

    Args:
        checkpoint_path: path to a checkpoint dict whose ``'model'`` entry is the
            VoiceFilter ``state_dict``.

    Returns:
        The VoiceFilter moved to the GPU and switched to eval mode.
    """
    voicefilter = VoiceFilter(hp)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    voicefilter.load_state_dict(checkpoint['model'])
    voicefilter = voicefilter.cuda()
    voicefilter.eval()
    return voicefilter


# Power-law compressed loss
model = load_voicefilter("chkpt/powlaw_loss/chkpt_168000.pt")
# First try (MSE loss)
model_0 = load_voicefilter("chkpt/new_dataloader/chkpt_108000.pt")
# MSE, 48k-step checkpoint (Ms. Tam)
model_t = load_voicefilter("chkpt/mstam_mse/chkpt_48000.pt")

# Main

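Inference function for the model trained with the power-law compressed loss. It returns the batch loss, the SDR of the unprocessed mixture, and the SDR after separation.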

In [9]:
def powerlaw_forward(model, batch):
    """Run the power-law-compressed VoiceFilter on one test batch and score it.

    The model is fed the mixture STFT magnitude compressed with exponent 0.3.
    The mask it returns is raised to 10/3 (= 1/0.3) and applied to the complex
    mixture STFT. The result is inverted to waveforms and compared with the
    targets using BSS-Eval SDR.

    Relies on the notebook globals ``embedder`` and ``audio``.

    Args:
        model: VoiceFilter trained with ``PowerLawCompLoss``.
        batch: 9-tuple from a test dataloader. Only the d-vector mels (index 0),
            the target/mixed STFTs (indices 5, 6) and the target/mixed waveforms
            (indices 7, 8) are used here.

    Returns:
        tuple: ``(loss, sdrs_before, sdrs)``. ``loss`` is the batch
        ``PowerLawCompLoss`` as a Python float, ``sdrs_before`` holds the SDR of
        each unprocessed mixture, and ``sdrs`` holds the SDR of each estimate.
    """
    loss_fn = PowerLawCompLoss()
    (dvec_mels, _target_mag, _, _mixed_mag, _mixed_phase,
     target_stft, mixed_stft, target_wavs, mixed_wavs) = batch

    with torch.no_grad():
        # One d-vector per utterance, stacked along a new leading dimension.
        dvec = torch.stack([embedder(mel.cuda()) for mel in dvec_mels], dim=0)
        target_stft = target_stft.cuda()
        mixed_stft = mixed_stft.cuda()

        compressed_mag = torch.pow(mixed_stft.abs(), 0.3)
        est_mask = model(compressed_mag, dvec)
        loss = loss_fn(est_mask, mixed_stft, target_stft).item()

        # Raise the mask to 10/3 (the inverse of the 0.3 input compression)
        # before applying it to the complex mixture STFT.
        est_stft = (mixed_stft * torch.pow(est_mask, 10/3)).cpu().numpy()

    sdrs_before, sdrs = [], []
    for stft_i, target_wav, mixed_wav in zip(est_stft, target_wavs, mixed_wavs):
        # NOTE(review): transposed before the ISTFT; presumably audio._istft
        # expects the opposite axis order from the model output. Confirm.
        est_wav = audio._istft(stft_i.T, length=len(target_wav))
        sdrs_before.append(
            bss_eval_sources(target_wav, mixed_wav, compute_permutation=False)[0][0])
        sdrs.append(
            bss_eval_sources(target_wav, est_wav, compute_permutation=False)[0][0])

    return loss, sdrs_before, sdrs

In [10]:
def mse_forward(model, batch):
    """Run an MSE-trained VoiceFilter on one test batch and score it.

    The model predicts a mask on the mixture magnitude. The masked magnitude is
    recombined with the mixture phase via ``audio.spec2wav`` and compared with
    the target waveform using BSS-Eval SDR.

    Relies on the notebook globals ``embedder`` and ``audio``.

    Args:
        model: VoiceFilter trained with an MSE loss on magnitudes.
        batch: 9-tuple from a test dataloader. The d-vector mels (index 0), the
            target/mixed magnitudes (indices 1, 3), the mixed phase (index 4) and
            the target waveforms (index 7) are used here.

    Returns:
        tuple: ``(loss, sdrs)``. ``loss`` is the batch MSE between estimated and
        target magnitudes as a Python float; ``sdrs`` holds the SDR of each estimate.
    """
    loss_fn = nn.MSELoss()
    dvec_mels, target_mag, _, mixed_mag, mixed_phase, _, _, target_wavs, _ = batch

    with torch.no_grad():
        dvec = torch.stack([embedder(mel.cuda()) for mel in dvec_mels], dim=0)
        mixed_mag = mixed_mag.cuda()
        target_mag = target_mag.cuda()

        est_mag = mixed_mag * model(mixed_mag, dvec)
        # MSE is symmetric, so (estimate, target) yields exactly the same value
        # as the reversed argument order.
        loss = loss_fn(est_mag, target_mag).item()

        est_mag = est_mag.cpu().numpy()
        mixed_phase = mixed_phase.numpy()

    sdrs = [
        bss_eval_sources(
            target_wav,
            audio.spec2wav(mag, phase, length=len(target_wav)),
            compute_permutation=False,
        )[0][0]
        for mag, phase, target_wav in zip(est_mag, mixed_phase, target_wavs)
    ]
    return loss, sdrs

# Final evaluation

## GGSpeakerID

In [None]:
%%time
losses_p = []
losses_0 = []
losses_t = []
sdrs_before = []
sdrs_p = []
sdrs_0 = []
sdrs_t = []
step = 0
for batch in testloader_gg:
    loss, sdrs_b, sdrs = powerlaw_forward(model, batch)
    losses_p.append(loss)
    sdrs_before += sdrs_b
    sdrs_p += sdrs

    loss, sdrs = mse_forward(model_0, batch)
    losses_0.append(loss)
    sdrs_0 += sdrs

    loss, sdrs = mse_forward(model_t, batch)
    losses_t.append(loss)
    sdrs_t += sdrs
    
    print(f"Step {step} done")
    step+=1

Step 0 done
Step 1 done
Step 2 done
Step 3 done
Step 4 done
Step 5 done
Step 6 done
Step 7 done
Step 8 done
Step 9 done


In [None]:
# Saves the results accumulated by the GGSpeakerID evaluation cell above.
# The dict is built and serialized *before* the file is opened: opening with "w"
# truncates it immediately, so a serialization error would otherwise wipe out
# previously saved results.
result_json = {
    "losses_p": losses_p,
    "losses_0": losses_0,
    "losses_t": losses_t,
    "sdrs_before": sdrs_before,
    "sdrs_p": sdrs_p,
    "sdrs_0": sdrs_0,
    "sdrs_t": sdrs_t
}
# default=float converts numpy scalars that are not float subclasses
# (e.g. np.float32) instead of raising TypeError.
json_object = json.dumps(result_json, indent=4, default=float)
with open("GGSpeaker_test.json", "w") as f:
    f.write(json_object)

## VN dataset

In [None]:
%%time
losses_p = []
losses_0 = []
losses_t = []
sdrs_before = []
sdrs_p = []
sdrs_0 = []
sdrs_t = []
step = 0
for batch in testloader_vn:
    loss, sdrs_b, sdrs = powerlaw_forward(model, batch)
    losses_p.append(loss)
    sdrs_before += sdrs_b
    sdrs_p += sdrs

    loss, sdrs = mse_forward(model_0, batch)
    losses_0.append(loss)
    sdrs_0 += sdrs

    loss, sdrs = mse_forward(model_t, batch)
    losses_t.append(loss)
    sdrs_t += sdrs
    
    print(f"Step {step} done")
    step+=1

In [None]:
# Saves the results accumulated by the VN evaluation cell above.
# The dict is built and serialized *before* the file is opened: opening with "w"
# truncates it immediately, so a serialization error would otherwise wipe out
# previously saved results.
result_json = {
    "losses_p": losses_p,
    "losses_0": losses_0,
    "losses_t": losses_t,
    "sdrs_before": sdrs_before,
    "sdrs_p": sdrs_p,
    "sdrs_0": sdrs_0,
    "sdrs_t": sdrs_t
}
# default=float converts numpy scalars that are not float subclasses
# (e.g. np.float32) instead of raising TypeError.
json_object = json.dumps(result_json, indent=4, default=float)
with open("VNdata_test.json", "w") as f:
    f.write(json_object)

## Generated test set (LibriSpeech)

In [None]:
%%time
losses_p = []
losses_0 = []
losses_t = []
sdrs_before = []
sdrs_p = []
sdrs_0 = []
sdrs_t = []
step = 0
for batch in testloader_lb:
    loss, sdrs_b, sdrs = powerlaw_forward(model, batch)
    losses_p.append(loss)
    sdrs_before += sdrs_b
    sdrs_p += sdrs

    loss, sdrs = mse_forward(model_0, batch)
    losses_0.append(loss)
    sdrs_0 += sdrs

    loss, sdrs = mse_forward(model_t, batch)
    losses_t.append(loss)
    sdrs_t += sdrs
    
    print(f"Step {step} done")
    step+=1

In [None]:
# Saves the results accumulated by the LibriSpeech evaluation cell above.
# The dict is built and serialized *before* the file is opened: opening with "w"
# truncates it immediately, so a serialization error would otherwise wipe out
# previously saved results.
result_json = {
    "losses_p": losses_p,
    "losses_0": losses_0,
    "losses_t": losses_t,
    "sdrs_before": sdrs_before,
    "sdrs_p": sdrs_p,
    "sdrs_0": sdrs_0,
    "sdrs_t": sdrs_t
}
# default=float converts numpy scalars that are not float subclasses
# (e.g. np.float32) instead of raising TypeError.
json_object = json.dumps(result_json, indent=4, default=float)
with open("Generate_test.json", "w") as f:
    f.write(json_object)