In [39]:
# test.py
import museval
from tqdm import tqdm

import utils

import numpy as np
import torch

from utils import compute_loss

def predict(audio, model):
    """
    Separate a mixture into sources by sliding ``model`` over it in fixed-size windows.

    The mixture is zero-padded at the end so its length is a multiple of
    ``model.shapes["output_frames"]``, and zero-padded on both sides with the network's
    context (``output_start_frame`` samples in front, ``input_frames - output_end_frame``
    behind) so the first and last output frames can be predicted. Each window is run
    through ``utils.compute_output`` without gradient tracking, and the per-source results
    are stitched together and cropped back to the original length.

    :param audio: Mixture with time along axis 1 (2-D, e.g. (channels, samples)), given
        either as a NumPy array or as a PyTorch tensor.
    :param model: Network exposing ``shapes`` (frame bookkeeping dict) and ``instruments``.
    :return: Dict mapping each name in ``model.instruments`` to a float32 prediction with
        the same shape as ``audio``. Values are NumPy arrays for NumPy input. For tensor
        input they are tensors, moved back to the GPU if ``audio`` was on the GPU.
    """
    if isinstance(audio, torch.Tensor):
        # ``is_cuda`` is a boolean property, not a method (calling it raised TypeError)
        is_cuda = audio.is_cuda
        audio = audio.detach().cpu().numpy()
        return_mode = "pytorch"
    else:
        return_mode = "numpy"

    expected_outputs = audio.shape[1]
    output_frames = model.shapes["output_frames"]
    input_frames = model.shapes["input_frames"]

    # Pad input at the end if its length is not divisible by the frame shift
    pad_back = audio.shape[1] % output_frames
    pad_back = 0 if pad_back == 0 else output_frames - pad_back
    if pad_back > 0:
        audio = np.pad(audio, [(0, 0), (0, pad_back)], mode="constant", constant_values=0.0)

    target_outputs = audio.shape[1]
    outputs = {key: np.zeros(audio.shape, np.float32) for key in model.instruments}

    # Pad the mixture at beginning and end so the network can predict the edges of the signal
    pad_front_context = model.shapes["output_start_frame"]
    pad_back_context = input_frames - model.shapes["output_end_frame"]
    audio = np.pad(audio, [(0, 0), (pad_front_context, pad_back_context)], mode="constant", constant_values=0.0)

    with torch.no_grad():
        for target_start_pos in range(0, target_outputs, output_frames):
            # Because of the front padding, input window [pos, pos + input_frames)
            # yields predictions for the target range [pos, pos + output_frames)
            curr_input = audio[:, target_start_pos:target_start_pos + input_frames]
            curr_input = torch.from_numpy(curr_input).unsqueeze(0)  # add batch dimension

            for key, curr_targets in utils.compute_output(model, curr_input).items():
                outputs[key][:, target_start_pos:target_start_pos + output_frames] = curr_targets.squeeze(0).cpu().numpy()

    # Crop to the expected length (undo the frame-shift padding)
    outputs = {key: value[:, :expected_outputs] for key, value in outputs.items()}

    if return_mode == "pytorch":
        # torch.from_numpy works per array, not on a dict of arrays
        outputs = {key: torch.from_numpy(value) for key, value in outputs.items()}
        if is_cuda:
            outputs = {key: value.cuda() for key, value in outputs.items()}
    return outputs

def predict_song(args, audio_path, model):
    """
    Load a mixture from disk and separate it into sources with ``model``.

    The mixture is loaded at its original sampling rate. Its channels are adapted to
    ``args.channels`` (downmixed to mono or mono duplicated), and it is resampled to
    ``args.sr`` for prediction. The predictions are then resampled back to the mixture's
    rate, cropped/padded to the mixture length and adapted back to the mixture's channels.

    :param args: Needs ``channels`` (model input channels) and ``sr`` (model sampling rate).
    :param audio_path: Path of the mixture file, passed to ``utils.load``.
    :param model: Separation model; switched to eval mode.
    :return: Dict of per-source Fortran-ordered arrays of shape (mix_channels, mix_len),
        at the mixture's original sampling rate.
    """
    model.eval()

    # Load mixture in original sampling rate
    mix_audio, mix_sr = utils.load(audio_path, sr=None, mono=False)
    mix_channels = mix_audio.shape[0]
    mix_len = mix_audio.shape[1]

    # Adapt mixture channels to required input channels
    if args.channels == 1:
        mix_audio = np.mean(mix_audio, axis=0, keepdims=True)
    elif mix_channels == 1:
        # Duplicate channels if input is mono but model is stereo
        mix_audio = np.tile(mix_audio, [args.channels, 1])
    else:
        assert mix_channels == args.channels, "Mixture channel count does not match args.channels"

    # Resample to model sampling rate
    mix_audio = utils.resample(mix_audio, mix_sr, args.sr)

    sources = predict(mix_audio, model)

    # Resample back to the mixture sampling rate in case the model runs at a different rate
    sources = {key: utils.resample(value, args.sr, mix_sr) for key, value in sources.items()}

    for key in list(sources):
        # Padding in predict() or inconsistent down-/upsampling can leave a few samples too
        # many or too few; match the mixture length exactly
        diff = sources[key].shape[1] - mix_len
        if diff > 0:
            print("WARNING: Cropping " + str(diff) + " samples")
            sources[key] = sources[key][:, :-diff]
        elif diff < 0:
            print("WARNING: Padding output by " + str(-diff) + " samples")
            # np.pad accepts only (array, pad_width, mode) positionally; the fill value is a keyword
            sources[key] = np.pad(sources[key], [(0, 0), (0, -diff)], mode="constant", constant_values=0.0)

        # Adapt channels back to the mixture's layout
        if mix_channels > args.channels:
            assert args.channels == 1
            # Duplicate mono predictions
            sources[key] = np.tile(sources[key], [mix_channels, 1])
        elif mix_channels < args.channels:
            assert mix_channels == 1
            # Reduce model output to mono
            sources[key] = np.mean(sources[key], axis=0, keepdims=True)

        sources[key] = np.asfortranarray(sources[key])  # So librosa does not complain if we want to save it

    return sources

def evaluate(args, dataset, model, instruments):
    """
    Score source-separation quality with museval's BSS metrics for every song in ``dataset``.

    :param args: Forwarded unchanged to ``predict_song``.
    :param dataset: Iterable of dicts holding the mixture path under ``"mix"`` and one
        reference path per instrument name.
    :param model: Separation model; switched to eval mode.
    :param instruments: Instrument names, in the order used for stacking and reporting.
    :return: One dict per song, mapping each instrument to its SDR/ISR/SIR/SAR values.
    """
    model.eval()
    results = []
    with torch.no_grad():
        for track in dataset:
            mix_path = track["mix"]
            print("Evaluating " + mix_path)

            # References at their native sampling rate/channel count, transposed to time-first
            references = np.stack(
                [utils.load(track[inst], sr=None, mono=False)[0].T for inst in instruments]
            )

            estimates_by_source = predict_song(args, mix_path, model)
            estimates = np.stack([estimates_by_source[inst].T for inst in instruments])

            sdr, isr, sir, sar, _ = museval.metrics.bss_eval(references, estimates)
            results.append({
                inst: {"SDR": sdr[i], "ISR": isr[i], "SIR": sir[i], "SAR": sar[i]}
                for i, inst in enumerate(instruments)
            })

    return results


def validate(args, model, criterion, test_data):
    """
    Compute the average loss of ``model`` over ``test_data`` without gradient tracking.

    :param args: Needs ``batch_size``, ``num_workers`` and ``cuda``.
    :param model: Model to evaluate; switched to eval mode.
    :param criterion: Loss function forwarded to ``compute_loss``.
    :param test_data: Dataset yielding ``(x, targets)`` pairs; after batching, ``targets``
        is a dict of tensors.
    :return: Running mean of the per-batch average losses.
    """
    # PREPARE DATA
    dataloader = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers)

    # VALIDATE
    model.eval()
    total_loss = 0.
    # len(dataloader) includes the final partial batch; len(test_data) // batch_size did not
    with tqdm(total=len(dataloader)) as pbar, torch.no_grad():
        for example_num, (x, targets) in enumerate(dataloader):
            if args.cuda:
                x = x.cuda()
                for k in list(targets.keys()):
                    targets[k] = targets[k].cuda()

            _, avg_loss = compute_loss(model, x, targets, criterion)

            # Incremental running mean over batches
            total_loss += (1. / float(example_num + 1)) * (avg_loss - total_loss)

            pbar.set_description("Current loss: " + str(total_loss))
            pbar.update(1)

    return total_loss

In [35]:
class Parameters:
    """
    Notebook stand-in for the command-line arguments of the original predict script.

    Every attribute mirrors one argparse option of that script. Keyword arguments override
    the defaults, e.g. ``Parameters(cuda=False)`` on a CPU-only machine; unknown names
    raise ``TypeError``.
    """

    def __init__(self, **overrides):
        self.cuda = True                     # use CUDA
        self.features = 32                   # number of feature channels per layer
        self.load_model = "waveunet/model"   # checkpoint of a previously trained model
        self.batch_size = 4
        self.levels = 6                      # number of downsampling/upsampling blocks
        self.depth = 1                       # number of convs per block
        self.sr = 44100                      # model sampling rate
        self.channels = 2                    # number of input audio channels
        self.kernel_size = 5                 # filter width of kernels; has to be odd
        self.output_size = 2.0               # output duration in seconds
        self.strides = 4                     # strides in Waveunet
        self.conv_type = "gn"                # normal / bn (BatchNorm) / gn (GroupNorm)
        self.res = "fixed"                   # resampling: fixed sinc lowpass or learned conv
        self.separate = 1                    # 1: separate model per source, 0: one shared model
        self.feature_growth = "double"       # per-level feature growth: "add" or "double"
        self.input = "audio_examples/Cristina Vane - So Easy/mix.mp3"  # mixture to separate
        self.output = "./out"                # output folder (None: use the input's folder)

        unknown = set(overrides) - set(vars(self))
        if unknown:
            raise TypeError("Unknown parameter(s): " + ", ".join(sorted(unknown)))
        for name, value in overrides.items():
            setattr(self, name, value)

In [36]:
import os

# Build the run configuration and make sure the output folder exists.
args = Parameters()
# output=None means "write next to the input file", so there is no folder to create then
if args.output is not None:
    os.makedirs(args.output, exist_ok=True)


In [41]:
# Exploration: load the input mixture at its native sampling rate and inspect its layout
audio_path = args.input
mix_audio, mix_sr = utils.load(audio_path, sr=None, mono=False)
mix_channels, mix_len = mix_audio.shape[0], mix_audio.shape[1]
print(mix_audio.shape)
    



(2, 1323000)


In [None]:
# predict.py

In [38]:
import argparse
import os
import utils

# from test import predict_song
# import test
from waveunet import Waveunet

# parser = argparse.ArgumentParser()
# parser.add_argument('--cuda', action='store_true',
#                     help='use CUDA (default: False)')
# parser.add_argument('--features', type=int, default=32,
#                     help='# of feature channels per layer')
# parser.add_argument('--load_model', type=str,
#                     help='Reload a previously trained model')
# parser.add_argument('--batch_size', type=int, default=4,
#                     help="Batch size")
# parser.add_argument('--levels', type=int, default=6,
#                     help="Number DS/US blocks")
# parser.add_argument('--depth', type=int, default=1,
#                     help="Number of convs per block")
# parser.add_argument('--sr', type=int, default=44100,
#                     help="Sampling rate")
# parser.add_argument('--channels', type=int, default=2,
#                     help="Number of input audio channels")
# parser.add_argument('--kernel_size', type=int, default=5,
#                     help="Filter width of kernels. Has to be an odd number")
# parser.add_argument('--output_size', type=float, default=2.0,
#                     help="Output duration")
# parser.add_argument('--strides', type=int, default=4,
#                     help="Strides in Waveunet")
# parser.add_argument('--conv_type', type=str, default="gn",
#                     help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn")
# parser.add_argument('--res', type=str, default="fixed",
#                     help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned")
# parser.add_argument('--separate', type=int, default=1,
#                     help="Train separate model for each source (1) or only one (0)")
# parser.add_argument('--feature_growth', type=str, default="double",
#                     help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)")

# parser.add_argument('--input', type=str, default=os.path.join("audio_examples", "Cristina Vane - So Easy", "mix.mp3"),
#                     help="Path to input mixture to be separated")
# parser.add_argument('--output', type=str, default=None, help="Output path (same folder as input path if not set)")

# args = parser.parse_args()

INSTRUMENTS = ["bass", "drums", "other", "vocals"]
NUM_INSTRUMENTS = len(INSTRUMENTS)

# MODEL
# Per-level feature counts grow linearly ("add") or double each level ("double")
num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \
               [args.features*2**i for i in range(0, args.levels)]
target_outputs = int(args.output_size * args.sr)  # output window length in samples
model = Waveunet(args.channels, num_features, args.channels, INSTRUMENTS, kernel_size=args.kernel_size,
                 target_output_size=target_outputs, depth=args.depth, strides=args.strides,
                 conv_type=args.conv_type, res=args.res, separate=args.separate)

if args.cuda:
    model = utils.DataParallel(model)
    print("move model to gpu")
    model.cuda()

print("Loading model from checkpoint " + str(args.load_model))
state = utils.load_model(model, None, args.load_model)

preds = predict_song(args, args.input, model)

# predict_song resamples its outputs back to the mixture's own sampling rate, so the files
# must be written at that rate, not at the model rate args.sr
_, input_sr = utils.load(args.input, sr=None, mono=False)

output_folder = os.path.dirname(args.input) if args.output is None else args.output
if output_folder:
    # Don't rely on an earlier cell having created the folder
    os.makedirs(output_folder, exist_ok=True)
for inst, source in preds.items():
    utils.write_wav(os.path.join(output_folder, os.path.basename(args.input) + "_" + inst + ".wav"), source, input_sr)

Using valid convolutions with 97961 inputs and 88409 outputs
move model to gpu
Loading model from checkpoint waveunet/model




In [42]:
# Number of output samples per network window; predict() uses it as the frame shift
model.shapes["output_frames"]

88409

In [50]:
# Context padding that predict() adds around the mixture so edge output frames get a full input window
pad_front_context = model.shapes["output_start_frame"]
pad_back_context = model.shapes["input_frames"] - model.shapes["output_end_frame"]
print(pad_front_context, pad_back_context)
print(model.shapes["input_frames"], model.shapes["output_end_frame"])
print(model.shapes["input_frames"] - model.shapes["output_frames"])
# predict() assumes front + back context make up exactly the extra input samples per window
assert pad_front_context + pad_back_context == model.shapes["input_frames"] - model.shapes["output_frames"], \
    "context padding does not match the input/output frame difference"

4776 4776
97961 93185
9552


In [48]:
# Replay predict()'s padding steps on the loaded mixture to inspect the resulting shapes
audio = mix_audio
print(audio.shape)
expected_outputs = audio.shape[1]

# Right-pad so the length becomes a multiple of the network's output frame count
output_shift = model.shapes["output_frames"]
pad_back = -audio.shape[1] % output_shift
if pad_back:
    audio = np.pad(audio, ((0, 0), (0, pad_back)), mode="constant", constant_values=0.0)
print(audio.shape)
target_outputs = audio.shape[1]
outputs = dict((key, np.zeros(audio.shape, np.float32)) for key in model.instruments)

# Add the receptive-field context on both ends so the first/last frames can be predicted
pad_front_context = model.shapes["output_start_frame"]
pad_back_context = model.shapes["input_frames"] - model.shapes["output_end_frame"]
audio = np.pad(audio, ((0, 0), (pad_front_context, pad_back_context)), mode="constant", constant_values=0.0)
print(audio.shape)
print(audio.shape[1] % output_shift)

(2, 1323000)
(2, 1326135)
(2, 1335687)
9552
