In [None]:
from common import utils, data, models, argbind, viz
import nussl
import torch
import os
from contextlib import contextmanager
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr

# Switch on the 'args.debug' entry in argbind's global ARGS mapping
# (presumably makes argbind report the arguments it binds — confirm in common/argbind).
argbind.ARGS['args.debug'] = True
# Call the shared logger helper with level 'info' (presumably configures logging — confirm).
utils.logger(level='info')

# Source names for this notebook — presumably the four MUSDB18 stems; confirm.
# NOTE(review): not referenced by any of the cells visible here.
LABELS = ['bass', 'drums', 'other', 'vocals']

@contextmanager
def scope(output_folder, override_args=None):
    """Run a ``with`` body inside an experiment folder with its saved args bound.

    Enters ``utils.chdir(output_folder)``, loads ``./args.yml`` from there with
    ``argbind.load_args``, writes every entry of ``override_args`` over the
    loaded values, and enters ``argbind.scope`` with the result. Both context
    managers stay entered for as long as the caller's ``with`` body runs.

    Parameters
    ----------
    output_folder
        Folder that contains the ``args.yml`` to load.
    override_args
        Optional mapping of argument key -> value applied on top of the
        loaded arguments. ``None`` (the default) means "no overrides".

    Yields
    ------
    The loaded arguments with the overrides applied.
    """
    # A None sentinel replaces the former ``{}`` default so that no mutable
    # object is shared between calls, and so callers may pass None explicitly.
    if override_args is None:
        override_args = {}

    with utils.chdir(output_folder):
        args = argbind.load_args('./args.yml')
        for key, val in override_args.items():
            args[key] = val
        with argbind.scope(args):
            yield args

In [None]:
# Values written over the experiment's saved args.yml for this cell.
overrides = {
    'device.use': 'cpu',
    'deep_mask_estimation.model_path': 'checkpoints/latest.model.pth',
}

with scope('../../../output/musdb18/', overrides) as args:
    # Build the device and the separation model from the bound arguments.
    device = utils.device()
    separator = models.deep_mask_estimation(device)
    stft_params, sample_rate = data.signal()

    # Build the dataset under the 'test' argbind scope
    # (presumably selects the test-split arguments — confirm in common/argbind).
    with argbind.scope(args, 'test'):
        tfm, new_labels = data.transform(stft_params, sample_rate, 'vocals', True)
        dataset = data.mixer(stft_params, tfm)

    # Draw one item at random; no seed is set, so each run may pick a different one.
    pick = np.random.randint(len(dataset))
    item = dataset[pick]
    mix = item['mix']

    separator.audio_signal = mix
    estimates = separator()

    # Add the remainder of the mix after removing the first estimate as an extra track.
    residual = mix - estimates[0]
    estimates.append(residual)

    viz.embed(estimates)

In [None]:
with scope('../../../output/musdb18/') as args:
    device = utils.device()
    separator = models.deep_mask_estimation(device)

    def separate(audio):
        """Separate one audio clip and return an HTML multitrack player.

        ``audio`` is unpacked as a ``(sample_rate, samples)`` pair — presumably
        what Gradio's "audio" input passes in; confirm for the installed
        Gradio version. The clip is wrapped in a ``nussl.AudioSignal``, run
        through ``separator``, and the mix minus the first estimate is added
        as one more track. Returns whatever ``nussl.play_utils.multitrack``
        returns with ``display=False``.
        """
        rate, samples = audio
        mix = nussl.AudioSignal(audio_data_array=samples, sample_rate=rate)

        separator.audio_signal = mix
        sources = separator()
        # Remainder of the mix once the first estimate is taken out.
        sources.append(mix - sources[0])

        # Label every track by its position for the player.
        tracks = {}
        for idx, source in enumerate(sources):
            tracks[f'Estimate {idx}'] = source

        return nussl.play_utils.multitrack(tracks, ext='.mp3', display=False)

    demo = gr.Interface(
        fn=separate,
        inputs="audio",
        outputs="html",
    )
    # share=True asks Gradio for a shareable link in addition to the local one.
    demo.launch(share=True)