In [190]:
import torch
import numpy as np
import yaml
from utils import *

from tqdm import tqdm
from pesq import pesq
from pystoi import stoi
import random
from torch.utils.data import Subset
import soundfile as sf
from datetime import datetime

from IPython.display import Audio, display
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
from dataset import CustomDataset

from models.SEANet_TFiLM import SEANet_TFiLM
from models.SEANet_TFiLM_nok import SEANet_TFiLM as SEANet_TFiLM_nok
from models.SEANet import SEANet

DEVICE = 'cuda'
def load_model(model, checkpoint_path):
    model = model.to(DEVICE)
    ckpt = torch.load(checkpoint_path)
    model.load_state_dict(ckpt['generator_state_dict'])
    print(f"Model loaded from {checkpoint_path}")
    return model

In [200]:
## Load Dataset
config_path = "configs/K64_main.yaml"
config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)
test_dataset = CustomDataset(path_dir_nb=config['dataset']['nb_test'], 
                                 path_dir_wb=config['dataset']['wb_test'], seg_len=config['dataset']['seg_len'], mode="train")

## For USAC44 Dataset
path_wb = ["/home/woongjib/Projects/USAC44_mono_48k"]
path_nb = ["/home/woongjib/Projects/USAC44_mono_48k_HEAAC16_LPF_Crop"]
test_dataset = CustomDataset(path_dir_nb=path_nb, 
                                 path_dir_wb=path_wb, seg_len=3, mode="val")

test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

idx=13
hr, lr, cond, name, _ = test_dataset[idx]

for hr, lr, cond, name, _ in test_dataloader:
    break

print(hr.shape)
print(lr.shape)
print(name)
print(cond.shape)

# sf.write("HR.wav", hr.squeeze().cpu().numpy(), samplerate=48000)
# sf.write("LR.wav", lr.squeeze().cpu().numpy(), samplerate=48000)

# t1 = draw_spec(hr.squeeze().numpy(), figsize=(15,4), sr=48000, win_len=2048, hop_len=2048, use_colorbar=True, 
#               save_fig=True, save_path="hr.png")
# t2 = draw_spec(lr.squeeze().numpy(), figsize=(15,4), sr=48000, win_len=2048, hop_len=2048, use_colorbar=True,
#               save_fig=True, save_path="lr.png")


Index:0 with 105051 samples
Index:1 with 62475 samples
LR 167526 and HR 167526 file numbers loaded!
train: 167526 files loaded
Index:0 with 43 samples
LR 43 and HR 43 file numbers loaded!
val: 43 files loaded
torch.Size([1, 1, 674408])
torch.Size([1, 1, 674408])
('Alice',)
torch.Size([1, 1, 320, 329])


In [213]:
## Load Model
# model = SEANet_TFiLM(kmeans_model_path=config['model']['kmeans_path'])
# model = load_model(model, "/home/woongjib/Projects/BESSL__/ckpt_K64/epoch_41_lsdH_0.441.pth")

model = SEANet_TFiLM_nok(kmeans_model_path=config['model']['kmeans_path'])
model = load_model(model, "/home/woongjib/Projects/BESSL__/ckpt_nok/epoch_13_lsdH_0.430.pth")

model_base = SEANet()
model_base = load_model(model_base, "/home/woongjib/Projects/BESSL__/ckpt_baseline/epoch_43_lsdH_0.548.pth")

## Loop
lsd_list = []
lsd_list_b = []
lsd_list_nb = []

for idx in range(10):
    wb, nb, cond, name, _ = test_dataset[idx]
    print(name)
    with torch.no_grad():
        recon = model(nb.to(DEVICE), cond.to(DEVICE))
        recon_b = model_base(nb.to(DEVICE), cond.to(DEVICE))
        
        a = lsd_batch(wb.cpu(), recon.cpu(), fs=48000)
        b = lsd_batch(wb.cpu(), recon_b.cpu(), fs=48000)
        c = lsd_batch(wb.cpu(), nb.cpu(), fs=48000)
        
        lsd_list.append(a)
        lsd_list_b.append(b)
        lsd_list_nb.append(c)
        
        sf.write(f"outputs/{name}_bwe.wav", recon.squeeze().cpu().numpy(), samplerate=48000)
        sf.write(f"outputs/{name}_bwe_base.wav", recon_b.squeeze().cpu().numpy(), samplerate=48000)
        sf.write(f"outputs/{name}_gt.wav", wb.squeeze().cpu().numpy(), samplerate=48000)
        
average_lsd = sum(lsd_list) / len(lsd_list)
average_lsd_b = sum(lsd_list_b) / len(lsd_list)
average_lsd_nb = sum(lsd_list_nb) / len(lsd_list)

print(f"\nAverage LSD for the 20 samples: {average_lsd:.3f}")
print(f"Average LSD base for the 20 samples: {average_lsd_b:.3f}")
print(f"Average LSD nb for the 20 samples: {average_lsd_nb:.3f}")

# display(Audio(wb.cpu().numpy(), rate=48000))
# display(Audio(recon.squeeze().cpu().numpy(), rate=48000))
# display(Audio(recon_b.squeeze().cpu().numpy(), rate=48000))
# display(Audio(nb.cpu().numpy(), rate=48000))

# t1 = draw_spec(wb.squeeze().numpy(), figsize=(15,4), sr=48000, win_len=2048, hop_len=2048, use_colorbar=True, 
#               save_fig=True, save_path="hr")
# t2 = draw_spec(recon.squeeze().cpu().numpy(), figsize=(15,4), sr=48000, win_len=2048, hop_len=2048, use_colorbar=True,
#               save_fig=True, save_path="bwe")
# t2 = draw_spec(recon_b.squeeze().cpu().numpy(), figsize=(15,4), sr=48000, win_len=2048, hop_len=2048, use_colorbar=True,
#               save_fig=True, save_path="bwe_b")
# t3 = draw_spec(nb.squeeze().cpu().numpy(), figsize=(15,4), sr=48000, win_len=2048, hop_len=2048, use_colorbar=True,
#               save_fig=True, save_path="lr")    

# sf.write("bwe.wav", recon.squeeze().cpu().numpy(), samplerate=48000)
# sf.write("bwe_b.wav", recon_b.squeeze().cpu().numpy(), samplerate=48000)

# 15425536831381643598
## 8495



**** CHECKPOINT LOADED! **** 
Model loaded from /home/woongjib/Projects/BESSL__/ckpt_nok/epoch_13_lsdH_0.430.pth
Model loaded from /home/woongjib/Projects/BESSL__/ckpt_baseline/epoch_43_lsdH_0.548.pth
Alice
Arirang_speech
Green_speech
HarryPotter
KoreanM1
Music_1
Music_2
Music_3
Music_4
Music_5

Average LSD for the 20 samples: 0.510
Average LSD base for the 20 samples: 0.664
Average LSD nb for the 20 samples: 0.956
