In [None]:
import os

from matplotlib.cm import ScalarMappable
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torchaudio
from tqdm.auto import tqdm

from solver.demucs import DemucsSolver
from utils import audio_player_list, read_config_yaml, bandwidth_to_max_bin, Config

# --- Configuration: which trained model to load and where inputs/outputs live ---
MODEL_VERSION = '20230403_ISMIR22'
MODEL_TYPE = 'hdemucs'
DEVICE = 'cpu'
CONFIG_PATH = f'config/cfg_{MODEL_TYPE}.yaml'
MODEL_DIR = f'checkpoints/{MODEL_TYPE}/{MODEL_VERSION}/'
# Machine-specific input location: set the TEST_AUDIO_DIR environment variable
# to point elsewhere instead of editing the notebook.
TEST_AUDIO_DIR = os.environ.get(
    'TEST_AUDIO_DIR',
    '/mnt/Projects/PianoConcerto_MMO/MMO_rendered/PCD/excerpts')

OUT_DIR = f'separated/2023_PCPipeline/{MODEL_TYPE}/{MODEL_VERSION}'
# exist_ok=True replaces the check-then-create pattern (racy) and keeps the
# cell idempotent when re-run.
os.makedirs(OUT_DIR, exist_ok=True)

In [None]:
# --- Build the separator in inference mode and load trained weights ---
model_cfg = read_config_yaml(CONFIG_PATH)

demucs_solver = DemucsSolver(model_type=MODEL_TYPE,
                             device=DEVICE,
                             model_cfg=model_cfg,
                             train=False)

best_ckpt_path = os.path.join(MODEL_DIR, 'demucs_best.pth')
try:
    demucs_solver.load_checkpoint(best_ckpt_path)
except Exception as err:
    # Fall back to the solver's own best-model lookup in MODEL_DIR. Catch
    # Exception (not a bare except) so KeyboardInterrupt/SystemExit still
    # propagate, and report why the explicit checkpoint was not used.
    print(f'Could not load {best_ckpt_path} ({err!r}); '
          f'falling back to load_best_model({MODEL_DIR!r})')
    demucs_solver.load_best_model(MODEL_DIR)

In [None]:
# --- STFT/time-axis parameters; Fs is also the working sample rate used for
# --- separation, playback and export.
H = 1024          # hop size in samples (one frame every H samples, see T_coef)
N = 4096          # FFT length in samples (bin spacing Fs / N, see F_coef)
Fs = 44100        # working sample rate in Hz
K = N // 2        # highest bin index; F_coef spans bins 0..K
gamma = 10        # NOTE(review): not used by the visible cells

F_coef = np.arange(K + 1) * Fs / N   # centre frequency (Hz) of each bin 0..K

PREVIEW_STEP = 10  # audition only every PREVIEW_STEP-th excerpt

# Only sub-directories are excerpts; stray files would make the load fail.
concerto_ids = sorted(
    entry for entry in os.listdir(TEST_AUDIO_DIR)
    if os.path.isdir(os.path.join(TEST_AUDIO_DIR, entry)))

for concerto_id in concerto_ids[::PREVIEW_STEP]:
    print(concerto_id)
    audio, sr = torchaudio.load(os.path.join(TEST_AUDIO_DIR,
                                             concerto_id,
                                             f'{concerto_id}_OP.wav'))
    # Everything downstream (model input, playback rate, time axis) assumes
    # Fs, so bring files recorded at another rate onto it.
    if sr != Fs:
        audio = torchaudio.functional.resample(audio, sr, Fs)

    # Number of hop-H frames covering the signal (== ceil(len / H)).
    num_frames = int(np.ceil(audio.shape[1] / H))
    # NOTE(review): T_coef / extent are not consumed by the visible cells;
    # they look like leftovers for spectrogram plotting.
    T_coef = np.arange(num_frames) * H / Fs
    extent = [T_coef[0], T_coef[-1], F_coef[0], F_coef[-1]]

    # unsqueeze(0) adds a leading batch dimension for the solver.
    estimates_dict = demucs_solver.separate(audio.unsqueeze(0))

    # .cpu() before .numpy() so this also works when DEVICE is a GPU;
    # squeeze(0) drops the batch dimension again.
    piano = estimates_dict['piano'].detach().cpu().numpy().squeeze(0)
    orch = estimates_dict['orch'].detach().cpu().numpy().squeeze(0)

    audio_player_list([audio.detach().cpu().numpy(), piano, orch],
                      [Fs, Fs, Fs], width=180, height=30,
                      columns=['mix', 'piano', 'orch'])

In [None]:
# --- Separate every excerpt and write its stems to OUT_DIR as
# --- <concerto_id>_P.wav (piano) and <concerto_id>_O.wav (orchestra).
# Only sub-directories are excerpts; stray files would make the load fail.
concerto_ids = sorted(
    entry for entry in os.listdir(TEST_AUDIO_DIR)
    if os.path.isdir(os.path.join(TEST_AUDIO_DIR, entry)))

progress = tqdm(concerto_ids, desc='separating')
for concerto_id in progress:
    # Show the current excerpt on the bar instead of print(), which would
    # break the progress-bar rendering.
    progress.set_postfix_str(concerto_id)
    audio, sr = torchaudio.load(os.path.join(TEST_AUDIO_DIR,
                                             concerto_id,
                                             f'{concerto_id}_OP.wav'))
    # The stems are written at Fs, so the input must be at Fs as well;
    # otherwise the exported files would play at the wrong speed.
    if sr != Fs:
        audio = torchaudio.functional.resample(audio, sr, Fs)

    # unsqueeze(0) adds a leading batch dimension for the solver.
    estimates_dict = demucs_solver.separate(audio.unsqueeze(0))

    # .cpu() before .numpy() so this also works when DEVICE is a GPU;
    # squeeze(0) drops the batch dimension again.
    piano = estimates_dict['piano'].detach().cpu().numpy().squeeze(0)
    orch = estimates_dict['orch'].detach().cpu().numpy().squeeze(0)

    # soundfile takes (frames, channels); assumes the stems are
    # (channels, samples) after squeeze(0) — hence the transpose.
    sf.write(os.path.join(OUT_DIR, f'{concerto_id}_P.wav'), piano.T, Fs)
    sf.write(os.path.join(OUT_DIR, f'{concerto_id}_O.wav'), orch.T, Fs)