In [None]:
from matplotlib.cm import ScalarMappable
import matplotlib.pyplot as plt
import numpy as np
import os
from pathlib import Path
import scipy.signal
import soundfile as sf

import torch
import torchaudio

from dsp.transforms import ComplexNorm, TorchSTFT
from model.umx import UMXSeparator
from utils import audio_player_list, read_config_yaml, bandwidth_to_max_bin, Config

# Compression factor of the logarithmic magnitude display: log(1 + GAMMA * |X|).
GAMMA = 10

# Parameters used to build the time/frequency axes of the spectrogram plots.
# NOTE(review): these presumably mirror the STFT settings of the loaded model
# config -- confirm against config/cfg_umx_piano.yaml.
H = 1024      # hop size in samples (frame index -> seconds via H / Fs)
N = 4096      # FFT size in samples (bin index -> Hz via Fs / N)
Fs = 44100    # sampling rate in Hz
K = N // 2    # index of the highest frequency bin (K + 1 bins are plotted)

# Model checkpoint selection and inference device.
MODEL_VERSION = '20230328_ISMIR22'
DEVICE = 'cpu'
MODEL_DIR = f'checkpoints/umx/{MODEL_VERSION}/'

# Input directory: one sub-folder per item, each holding '<id>_OP.wav'.
TEST_AUDIO_DIR = '/mnt/Projects/PianoConcerto_MMO/MMO_rendered/PCD/excerpts'

# Output directory for the separated stems.
OUT_DIR = f'separated/2023_PCPipeline/umx/{MODEL_VERSION}'

# exist_ok=True makes the cell idempotent and avoids the check-then-create race
# of `if not os.path.isdir(...)`.
os.makedirs(OUT_DIR, exist_ok=True)

In [None]:
# Read the UMX model configuration, build the separator on the configured
# device, and load the checkpoints for both separation targets.
config_umx = read_config_yaml('config/cfg_umx_piano.yaml')

separator_options = {
    'device': torch.device(DEVICE),
    'model_cfg': config_umx,
    'softmask': True,
    'residual': False,
}
umx_sep = UMXSeparator(**separator_options)

separation_targets = ['piano', 'orch']
umx_sep.load_model(model_dir=MODEL_DIR, targets=separation_targets)

In [None]:
# For every 10th item of the test set: separate the mix, plot the log-compressed
# magnitude spectrograms of each estimated target (one row per target, one
# column per channel), and show audio players for the mix and the estimates.
for concerto_id in sorted(os.listdir(TEST_AUDIO_DIR))[::10]:

    # NOTE: the sample rate returned by the loader (`sr`) is not used; the plot
    # axes and the audio players below rely on the notebook-level `Fs`.
    audio, sr = torchaudio.load(os.path.join(TEST_AUDIO_DIR,
                                             f'{concerto_id}',
                                             f'{concerto_id}_OP.wav'))

    # Time axis (seconds) and frequency axis (Hz) for imshow's `extent`.
    num_frames = int(np.ceil(Fs/H * audio.shape[1] / Fs))
    T_coef = np.arange(num_frames) * H / Fs
    F_coef = np.arange(K + 1) * Fs / N
    extent = [T_coef[0], T_coef[-1], F_coef[0], F_coef[-1]]

    # Inference only: no_grad avoids building an autograd graph for each file.
    with torch.no_grad():
        estimates, targets_stft = umx_sep.forward(audio.unsqueeze(0).to(DEVICE))
    estimates_dict = umx_sep.to_dict(estimates.squeeze(0))
    target_stft_dict = umx_sep.to_dict(targets_stft.squeeze())

    # Compress each target's magnitude spectrogram once (it does not depend on
    # the channel being plotted).
    compressed = {}
    for target in target_stft_dict:
        X_mag = ComplexNorm()(target_stft_dict[target])
        compressed[target] = np.log(1 + GAMMA * X_mag.to('cpu').abs().detach().numpy())

    # One colour range for all panels so that the single colorbar is valid for
    # every image. log(1 + GAMMA * |X|) is non-negative, hence vmin = 0.
    vmin = 0.0
    vmax = max((float(X.max()) for X in compressed.values()), default=1.0)

    fig, ax = plt.subplots(2, 2, figsize=(6,6), layout='constrained')

    for target_idx, target in enumerate(compressed):
        X_compressed = compressed[target]

        for ch in range(2):
            ax[target_idx, ch].imshow(X_compressed[ch, ...],
                                      origin='lower',
                                      cmap='gray_r',
                                      aspect='auto',
                                      extent=extent,
                                      vmin=vmin,
                                      vmax=vmax)
            # Axis labels only on the outer panels.
            if ch == 0:
                ax[target_idx, ch].set_ylabel('Frequency (Hz)')
            if target_idx == 1:
                ax[target_idx, ch].set_xlabel('Time (seconds)')

            ax[target_idx, ch].set_ylim([0, 4000])
            ax[target_idx, ch].set_title(f'{target.upper()}, channel: {ch}', fontsize=10)

    fig.suptitle(concerto_id, fontsize=12)

    # Colorbar tied to the same range as the images (previously it showed an
    # unrelated default 0..1 range).
    colorbar_mappable = ScalarMappable(norm=None, cmap='gray_r')
    colorbar_mappable.set_clim(vmin, vmax)
    fig.colorbar(colorbar_mappable, ax=ax[:, -1], shrink=0.5, location='right')
    plt.show()
    # Release the figure so that figures do not accumulate across iterations.
    plt.close(fig)

    audio_player_list([audio.to('cpu').detach().numpy(),
                       estimates_dict['piano'].to('cpu').detach().numpy(),
                       estimates_dict['orch'].to('cpu').detach().numpy()], [Fs, Fs, Fs], width=180, height=30,
                  columns=['mix', 'piano', 'orch'])


## Write separated files (All 81)

In [None]:
# Separate every item in TEST_AUDIO_DIR and write the piano and orchestra
# estimates to OUT_DIR as '<id>_P.wav' and '<id>_O.wav'.
from tqdm import tqdm
for concerto_id in tqdm(sorted(os.listdir(TEST_AUDIO_DIR))):

    # NOTE: the loader's sample rate (`sr`) is not used; files are written at Fs.
    audio, sr = torchaudio.load(os.path.join(TEST_AUDIO_DIR,
                                             f'{concerto_id}',
                                             f'{concerto_id}_OP.wav'))

    # Move the input to the same device as the plotting cell does (required as
    # soon as DEVICE is not 'cpu') and skip autograd bookkeeping for inference.
    with torch.no_grad():
        estimates, targets_stft = umx_sep.forward(audio.unsqueeze(0).to(DEVICE))
    estimates_dict = umx_sep.to_dict(estimates.squeeze(0))

    # Transposed before writing -- presumably the estimates are
    # (channels, samples) while soundfile expects (frames, channels); confirm.
    sf.write(f'{OUT_DIR}/{concerto_id}_P.wav', estimates_dict['piano'].cpu().detach().numpy().T, Fs)
    sf.write(f'{OUT_DIR}/{concerto_id}_O.wav', estimates_dict['orch'].cpu().detach().numpy().T, Fs)