In [17]:
import torch
import numpy as np
import json
from musdb import DB
import museval
import argparse
import soundfile as sf
from tqdm import trange
import os
from ignite.handlers import Checkpoint

from model import X_UMX

In [18]:
# --- Run configuration -------------------------------------------------------
# Numeric settings: a later cell multiplies these two to get the number of
# samples fed to the model per chunk, and passes `sr` to the WAV writer.
sr = 44_100
chunk_dur = 30

# File-system locations: JSON config, trained checkpoint, and the directory
# that receives one sub-folder of estimates per track.
config_path = "config.json"
checkpoint_path = "saved/x-umx_checkpoint_362400.pt"
estimate_path = "estimates/"

In [19]:
# Load the JSON configuration. The context manager closes the file handle as
# soon as parsing is done instead of leaving it to the garbage collector.
with open(config_path) as config_file:
    config = json.load(config_file)

# Test subset of the dataset rooted at config['root'], read from WAV files.
db = DB(config['root'], is_wav=True, subsets='test')

# Bin count derived from the configured bandwidth relative to the sample rate
# and FFT size; passed to the model constructor below.
max_bins = int(config["bandwidth"] / sr * config['nfft']) + 1
# NOTE(review): the trailing positional arguments (2, 3) are opaque from this
# notebook — presumably channel count and LSTM layer count; confirm against
# model.X_UMX.
model = X_UMX(config['nfft'], config['nhop'], config['hidden_size'], max_bins, 2, 3)

# torch.load unpickles the file, so only point checkpoint_path at checkpoints
# you trust. Tensors are mapped to CPU first; the model is moved to the GPU
# after its weights have been restored.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
to_load = {
    'model': model
}
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
model = model.cuda()
# Switch to inference mode; as the last expression this also displays the
# model's module tree.
model.eval()

X_UMX(
  (affine1): Sequential(
    (0): Conv1d(11896, 2048, kernel_size=(1,), stride=(1,), groups=4, bias=False)
    (1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Tanh()
  )
  (bass_lstm): LSTM(512, 256, num_layers=3, dropout=0.4, bidirectional=True)
  (drums_lstm): LSTM(512, 256, num_layers=3, dropout=0.4, bidirectional=True)
  (vocals_lstm): LSTM(512, 256, num_layers=3, dropout=0.4, bidirectional=True)
  (others_lstm): LSTM(512, 256, num_layers=3, dropout=0.4, bidirectional=True)
  (affine2): Sequential(
    (0): Conv1d(1024, 2048, kernel_size=(1,), stride=(1,), bias=False)
    (1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv1d(2048, 16392, kernel_size=(1,), stride=(1,), groups=4, bias=False)
    (4): BatchNorm1d(16392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [21]:
# Separate the first three tracks of `db` in fixed-length chunks and write one
# WAV file per source to <estimate_path>/<track name>/.
# The museval scoring path (the commented-out `results` / `estimates` / `score`
# lines) is currently disabled.
#results = museval.EvalStore(frames_agg='median', tracks_agg='median')
chunk_size = int(chunk_dur * sr)
# Send each chunk to whatever device holds the model weights rather than
# hard-coding CUDA, so input and model can never end up on different devices.
device = next(model.parameters()).device
for track in db[:3]:
    print(track.name)
    track_dir = os.path.join(estimate_path, track.name)
    os.makedirs(track_dir, exist_ok=True)
    audio = track.audio
    # Ceiling division: a trailing partial chunk still counts as a chunk.
    nchunks = -(-audio.shape[0] // chunk_size)
    outputs = []
    for chunk_idx in trange(nchunks):
        start = chunk_idx * chunk_size
        # Slicing clamps at the end of the array, so the final chunk is simply
        # shorter; no explicit min() is needed.
        cur_chunk = audio[start:start + chunk_size, :]
        # Transpose the 2-D chunk and prepend an axis of size 1.
        # assumes track.audio is (samples, channels) — TODO confirm
        x = torch.from_numpy(cur_chunk).float().t().unsqueeze(0).to(device)

        with torch.no_grad():
            X = model.t2f(x)
            Xmag = X.abs()
            pred_mask = model(Xmag)
            # NOTE(review): bare squeeze() drops every size-1 axis, not only
            # the leading one added above; the concatenate/indexing below rely
            # on the result being 3-D (source, channel, time) — confirm against
            # model.f2t before narrowing this to squeeze(0).
            xpred = model.f2t(pred_mask * X.unsqueeze(1), length=x.shape[-1]).squeeze().cpu().numpy()

        outputs.append(xpred)
    # Stitch the per-chunk estimates back together along the last axis.
    xpred = np.concatenate(outputs, 2)
    for i, s in enumerate(['drums', 'bass', 'other', 'vocals']):
        sf.write(
            os.path.join(track_dir, s + '.wav'),
            xpred[i].T,
            sr
        )

    #estimates = {
    #    'drums': xpred[0].T,
    #    'bass': xpred[1].T,
    #    'other': xpred[2].T,
    #    'vocals': xpred[3].T
    #}
    #score = museval.eval_mus_track(track, estimates)
    #print(score)
    #results.add_track(score)

AM Contra - Heart Peripheral
100%|██████████| 8/8 [00:03<00:00,  2.49it/s]
Al James - Schoolboy Facination
100%|██████████| 7/7 [00:02<00:00,  2.35it/s]
Angels In Amplifiers - I'm Alright
100%|██████████| 6/6 [00:02<00:00,  2.23it/s]
