In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
import torchaudio
from matplotlib.cm import ScalarMappable
from tqdm.auto import tqdm

from solver.spl import SPLSolver
from utils import audio_player_list, read_config_yaml, bandwidth_to_max_bin, Config

MODEL_VERSION = '20230402_ISMIR22'
DEVICE = 'cpu'
CONFIG_PATH = 'config/cfg_spl.yaml'
MODEL_DIR = f'checkpoints/spl/{MODEL_VERSION}/'
# Root of the rendered excerpts: one sub-directory per concerto, each holding
# `<concerto_id>_OP.wav`. Set PCD_EXCERPTS_DIR to run on machines without this mount.
TEST_AUDIO_DIR = os.environ.get(
    'PCD_EXCERPTS_DIR',
    '/mnt/Projects/PianoConcerto_MMO/MMO_rendered/PCD/excerpts',
)

OUT_DIR = f'separated/2023_PCPipeline/spl/{MODEL_VERSION}'

# Create the output directory; no-op when it already exists.
os.makedirs(OUT_DIR, exist_ok=True)

model_cfg = read_config_yaml(CONFIG_PATH)

spl_solver = SPLSolver(device=DEVICE,
                       model_cfg=model_cfg,
                       train=False)

# TODO: expose a public checkpoint-loading method on SPLSolver instead of
# reaching into the private `_model` attribute.
# `map_location` remaps tensors that were saved on another device (e.g. a CUDA
# training run) onto DEVICE; without it such a checkpoint fails to load on a
# CPU-only machine.
checkpoint_path = os.path.join(MODEL_DIR, 'spl_best.pth')
spl_solver._model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

In [None]:
# STFT parameters used to place the spectrograms on physical axes.
# NOTE(review): assumed to match the solver's STFT settings in CONFIG_PATH — confirm.
H = 1024      # hop size (samples)
N = 4096      # FFT size (samples)
Fs = 44100    # sample rate everything below assumes (Hz)
K = N // 2    # index of the Nyquist bin
gamma = 10    # log-compression factor: log(1 + gamma * |X|)

F_coef = np.arange(K + 1) * Fs / N   # centre frequency of each bin (Hz)

PREVIEW_OFFSET = 1    # index of the first concerto to preview ...
PREVIEW_STEP = 10     # ... then every PREVIEW_STEP-th one after it
MAX_FREQ_HZ = 4000    # upper limit of the displayed frequency range


def _to_numpy(x):
    """Return `x` as a NumPy array, detaching and moving torch tensors to CPU first."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# Only sub-directories are concerto folders; stray files (e.g. .DS_Store)
# would otherwise make torchaudio.load fail.
concerto_ids = sorted(d for d in os.listdir(TEST_AUDIO_DIR)
                      if os.path.isdir(os.path.join(TEST_AUDIO_DIR, d)))

for concerto_id in concerto_ids[PREVIEW_OFFSET::PREVIEW_STEP]:
    audio, sr = torchaudio.load(os.path.join(TEST_AUDIO_DIR,
                                             concerto_id,
                                             f'{concerto_id}_OP.wav'))
    # Axes, audio players and the STFT parameters all assume Fs; fail loudly
    # rather than mislabel time/frequency or play back at the wrong rate.
    if sr != Fs:
        raise ValueError(f'{concerto_id}: expected {Fs} Hz audio, got {sr} Hz')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        model_outputs, _, masked_stfts, estimates_dict = spl_solver.separate(audio, mwf=False)

    # Log-compressed magnitude of each target's masked STFT, computed once per
    # target (not once per channel).
    compressed = {target: np.log(1 + gamma * np.abs(_to_numpy(masked_stfts[target])))
                  for target in model_outputs}

    # One colour scale for every panel so panels and colorbar are comparable.
    # log(1 + gamma*|X|) is non-negative, so 0 is a valid lower bound.
    vmax = max(float(spec.max()) for spec in compressed.values())

    n_targets = len(compressed)
    # Channel axis is index 1 of the spectrogram (same indexing as spec[0, ch, ...]).
    n_channels = next(iter(compressed.values())).shape[1]
    fig, ax = plt.subplots(n_targets, n_channels,
                           figsize=(3 * n_channels, 3 * n_targets),
                           layout='constrained', squeeze=False)

    image = None
    for row, (target, spec) in enumerate(compressed.items()):
        # NOTE(review): assumes spec is (batch, channel, freq, time), as implied
        # by plotting spec[0, ch, ...] with frequency on the y axis. The extent is
        # taken from the actual spectrogram size instead of the audio length.
        n_bins, n_frames = spec.shape[-2:]
        extent = [0, (n_frames - 1) * H / Fs, 0, (n_bins - 1) * Fs / N]

        for ch in range(n_channels):
            panel = ax[row, ch]
            image = panel.imshow(spec[0, ch, ...],
                                 origin='lower',
                                 cmap='gray_r',
                                 aspect='auto',
                                 extent=extent,
                                 vmin=0,
                                 vmax=vmax)
            if ch == 0:
                panel.set_ylabel('Frequency (Hz)')
            if row == n_targets - 1:
                panel.set_xlabel('Time (seconds)')

            panel.set_ylim([0, MAX_FREQ_HZ])
            panel.set_title(f'{target.upper()}, channel: {ch}', fontsize=10)

    fig.suptitle(concerto_id, fontsize=12)
    # Colorbar from a real image so its ticks match the shared vmin/vmax.
    fig.colorbar(image, ax=ax[:, -1], shrink=0.5, location='right')
    plt.show()
    plt.close(fig)   # release figure memory inside the loop

    audio_player_list([_to_numpy(audio),
                       _to_numpy(estimates_dict['piano']),
                       _to_numpy(estimates_dict['orch'])],
                      [Fs, Fs, Fs], width=180, height=30,
                      columns=['mix', 'piano', 'orch'])

In [None]:
# Separate every excerpt (mwf=True, num_iter=1 — presumably a Wiener-filter
# post-processing step inside SPLSolver) and write the estimates to OUT_DIR as
# `<concerto_id>_P.wav` (piano) and `<concerto_id>_O.wav` (orchestra).
STEM_SUFFIX_TO_TARGET = {'P': 'piano', 'O': 'orch'}

# Only sub-directories are concerto folders; stray files would crash the loader.
concerto_ids = sorted(d for d in os.listdir(TEST_AUDIO_DIR)
                      if os.path.isdir(os.path.join(TEST_AUDIO_DIR, d)))

for concerto_id in tqdm(concerto_ids, desc='separating'):
    audio, sr = torchaudio.load(os.path.join(TEST_AUDIO_DIR,
                                             concerto_id,
                                             f'{concerto_id}_OP.wav'))
    # Outputs are written at Fs; a different input rate would be silently mislabelled.
    if sr != Fs:
        raise ValueError(f'{concerto_id}: expected {Fs} Hz audio, got {sr} Hz')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, _, _, estimates_dict = spl_solver.separate(mix=audio,
                                                      mwf=True,
                                                      num_iter=1)

    for suffix, target in STEM_SUFFIX_TO_TARGET.items():
        # Estimates appear to be (channels, samples); soundfile wants
        # (frames, channels), hence the transpose.
        sf.write(os.path.join(OUT_DIR, f'{concerto_id}_{suffix}.wav'),
                 estimates_dict[target].cpu().detach().numpy().T,
                 Fs)