In [1]:
import sys , os
sys.path.append('../')

from dnn.bitwise_network import *
from dnn.binary_layers import *
from datasets.two_source_mixture import *
from datasets.sinusoidal_data import *
from datasets.quantized_data import *
import torch
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
import scipy.signal as signal
from bss_eval import *
import IPython.display as ipd
import soundfile as sf
import mir_eval

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Binary Neural Network

In [2]:
def evaluate(model, batch):
    """Run ``model`` on one batch and score its estimates with BSS-eval.

    Args:
        model: callable applied to ``batch['mixture']``; ``torch.argmax`` is
            taken over dim 1 of its output to form the estimates (presumably
            dim 1 holds per-class / per-level scores -- confirm per model).
        batch: mapping with ``'mixture'``, ``'target'`` and ``'interference'``
            tensors.

    Returns:
        The result of ``bss_eval_batch(estimates, sources)``, where ``sources``
        stacks ``[target, interference]`` along a new dim 1.
    """
    mix, targ, inter = batch['mixture'], batch['target'], batch['interference']
    # Pure inference: don't build an autograd graph for every validation batch.
    with torch.no_grad():
        estimates = torch.argmax(model(mix), dim=1)

    # Reference signals stacked on dim 1: index 0 = target, index 1 = interference.
    sources = torch.stack([targ, inter], dim=1)
    return bss_eval_batch(estimates, sources)

def run_evaluation(model, dl):
    """Score ``model`` on every batch of ``dl`` and report mean BSS metrics.

    Each batch's result from ``evaluate`` is appended to one
    ``BSSMetricsList``. The list's mean SDR, SIR and SAR are printed, and the
    list itself is returned so per-example scores stay available.
    """
    collected = BSSMetricsList()
    for batch in dl:
        collected.extend(evaluate(model, batch))

    mean_sdr, mean_sir, mean_sar = collected.mean()
    report = (('Mean SDR %f', mean_sdr),
              ('Mean SIR %f', mean_sir),
              ('Mean SAR %f', mean_sar))
    for template, value in report:
        print(template % value)
    return collected
    
def get_audio_output(model, batch, rate=16000):
    """Display an audio player for each mixture in ``batch`` and its estimate.

    The model output is played as-is: unlike ``evaluate``, no argmax over
    dim 1 is applied here. NOTE(review): for models that output per-level
    scores this plays the raw scores -- confirm that is intended.

    Args:
        model: callable applied to ``batch['mixture']``.
        batch: mapping with a ``'mixture'`` tensor; element ``i`` along the
            first dimension is one example.
        rate: sample rate in Hz passed to ``IPython.display.Audio``.
    """
    mix = batch['mixture']
    # Inference only; no autograd graph needed before converting to numpy.
    with torch.no_grad():
        estimate = model(mix).detach().numpy()
    mix = mix.detach().numpy()
    for i, mix_i in enumerate(mix):
        print('Mixture')
        ipd.display(ipd.Audio(mix_i, rate=rate))
        print('Estimate')
        ipd.display(ipd.Audio(estimate[i], rate=rate))

def visualize_mask(binary_sample, raw_sample, model=None):
    """Plot ``ibm`` (left) and ``mask`` (right) side by side.

    FIXME(review): this function is stale and raises ``TypeError`` as written.
    The ``evaluate`` defined in this notebook has the signature
    ``evaluate(model, batch)`` and returns the ``bss_eval_batch`` result.
    It is called here with three positional arguments and its result is
    unpacked into four values. This presumably targets an older ``evaluate``
    that returned ``(_, _, mask, ibm)``. TODO: port it to the current API,
    e.g. read ``ibm`` from ``binary_sample['ibm']`` as ``visualize_input``
    does, and compute ``mask`` from ``model`` -- confirm the model's input
    format first.
    """
    _, _, mask, ibm = evaluate(binary_sample, raw_sample, model)
    plt.figure(figsize=(15, 5))
    plt.subplot(121)  # left panel: ibm
    plt.pcolormesh(ibm, cmap='binary')
    plt.subplot(122)  # right panel: mask
    plt.pcolormesh(mask, cmap='binary')
    
def visualize_input(binary_sample, num_bits=4):
    """Plot a bit-plane magnitude input next to its ``ibm`` array.

    ``binary_sample['bmag']`` stores each quantized row as ``num_bits``
    consecutive bit rows, most-significant bit first. Each group is folded
    back into one row of integer levels (``qmag``), which is drawn on the
    left. ``binary_sample['ibm']`` is drawn on the right.

    Args:
        binary_sample: mapping with a 2-D ``'bmag'`` array (rows grouped by
            ``num_bits``) and an ``'ibm'`` array.
        num_bits: bit rows per quantized row. Trailing rows that do not fill
            a complete group are ignored.
    """
    bmag, ibm = binary_sample['bmag'], binary_sample['ibm']

    bits = np.asarray(bmag, dtype=np.float64)
    n_rows = bits.shape[0] // num_bits
    # Bit j of a group carries weight 2**(num_bits - j - 1): MSB first.
    weights = 2.0 ** np.arange(num_bits - 1, -1, -1)
    groups = bits[:n_rows * num_bits].reshape(n_rows, num_bits, bits.shape[1])
    qmag = np.einsum('j,ijk->ik', weights, groups)

    _, (ax_qmag, ax_ibm) = plt.subplots(1, 2, figsize=(15, 5))
    ax_qmag.pcolormesh(qmag, cmap='binary')
    ax_qmag.set_title('Decoded magnitude (%d bits)' % num_bits)
    ax_ibm.pcolormesh(ibm, cmap='binary')
    ax_ibm.set_title('ibm')

In [3]:
# Configuration
# Quantizer grid: starts at -1 with step 2 / 2**num_bits, so its 2**num_bits
# steps span a width of 2 (presumably the interval [-1, 1) -- confirm
# against Quantizer).
toy = False
num_bits = 8
quantizer = Quantizer(num_bits=num_bits, min=-1, delta=2 / 2 ** num_bits)

In [4]:
# Load Dataset
# Round-trip one training target through the quantizer and play the original
# and the reconstruction to hear the quantization error.
# NOTE(review): the first argument to make_data is presumably the batch size.
train_dl, val_dl = make_data(1, toy=toy)
# Builtin next(): the DataLoader iterator is not guaranteed to expose a .next() method.
sample = next(iter(train_dl))
targ = sample['target']
quantized_targ = quantizer(targ)
reconstructed_targ = quantizer.inverse(quantized_targ)
ipd.display(ipd.Audio(targ, rate=16000))
ipd.display(ipd.Audio(reconstructed_targ, rate=16000))

## Bitwise Neural Network Evaluation

In [6]:
# Evaluate the real-valued autoencoder (real_autoencoder.model): quantize one
# validation target, reconstruct it through the network, and play the result.
# NOTE(review): this checkpoint is loaded from '../models/' while the other
# cells use 'models/' -- confirm which directory is correct.
nn = BitwiseNetwork(512, 128, fc_sizes=[2048, 2048], in_channels=1, out_channels=2**num_bits, autoencode=True)
nn.eval()
nn.load_state_dict(torch.load('../models/real_autoencoder.model'))
# Builtin next(): the DataLoader iterator is not guaranteed to expose a .next() method.
target = next(iter(val_dl))['target']
with torch.no_grad():
    # unsqueeze(1) inserts a size-1 dim at index 1 (presumably the channel
    # dim, matching in_channels=1).
    estimate = nn(quantizer(target).unsqueeze(1))
# The output has out_channels=2**num_bits along dim 1; take the argmax level
# and map it back to waveform values.
estimate = torch.argmax(estimate, dim=1).to(torch.float)
estimate = quantizer.inverse(estimate)
ipd.display(ipd.Audio(estimate, rate=16000))

In [12]:
# Listen to every 128th validation batch: each mixture next to nn's output.
# NOTE(review): the previous cell fed nn quantizer(target).unsqueeze(1), while
# get_audio_output passes the raw mixture -- confirm the input format matches.
for batch_idx, batch in enumerate(val_dl):
    if batch_idx % 128:
        continue
    get_audio_output(nn, batch)

Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


In [10]:
# Evaluate the bitwise network from the bitwise_noadapt checkpoint; inference()
# is called after the weights are loaded.
bnn = BitwiseNetwork(1024, 256, fc_sizes=[2048, 2048])
bnn.load_state_dict(torch.load('models/bitwise_noadapt.model'))
bnn.inference()
bnn.eval()
# Keep the returned metrics list, as the other evaluation cells do, instead of
# echoing its repr; the mean SDR/SIR/SAR are printed by run_evaluation.
bss_metrics = run_evaluation(bnn, val_dl)

Mean SDR 6.244427
Mean SIR 13.486576
Mean SAR 7.603800


<bss_eval.BSSMetricsList at 0x7f8b6af01908>

In [13]:
# Listen to every 128th validation batch through bnn, the bitwise network
# evaluated in the previous cell (nn is bound to a different network here).
for i, batch in enumerate(val_dl):
    if i % 128 == 0:
        get_audio_output(bnn, batch)

Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


In [6]:
# Evaluate the real-valued network from the real_adapt checkpoint and keep its
# per-example BSS metrics.
nn = BitwiseNetwork(1024, 256, fc_sizes=[2048, 2048])
nn.eval()
nn.load_state_dict(torch.load('models/real_adapt.model'))
bss_metrics = run_evaluation(nn, val_dl)

Mean SDR 7.073082
Mean SIR 14.808135
Mean SAR 8.245502


In [7]:
# Listen to every 128th validation batch through the real_adapt network (nn).
for batch_idx, batch in enumerate(val_dl):
    if batch_idx % 128:
        continue
    get_audio_output(nn, batch)

Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


In [13]:
# Evaluate the bitwise network from the bitwise_adapt checkpoint; inference()
# is called after the weights are loaded, as in the bitwise_noadapt cell.
nn = BitwiseNetwork(1024, 256, fc_sizes=[2048, 2048])
nn.eval()
nn.load_state_dict(torch.load('models/bitwise_adapt.model'))
nn.inference()
bss_metrics = run_evaluation(nn, val_dl)

Mean SDR 6.736159
Mean SIR 14.859759
Mean SAR 7.888259


In [12]:
# Listen to every 128th validation batch through the bitwise_adapt network (nn).
for batch_idx, batch in enumerate(val_dl):
    if batch_idx % 128:
        continue
    get_audio_output(nn, batch)

Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate
