In [1]:
import os
from bitwise_network import *
from binarized_network import *
from binary_layers import *
from datasets.two_source_mixture import *
from datasets.sinusoidal_data import *
from datasets.binary_data import *
import torch
import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
import scipy.signal as signal
from bss_eval import *
import IPython.display as ipd
import soundfile as sf
import mir_eval

# Render matplotlib figures inside the notebook.
%matplotlib inline
# autoreload 2: re-import edited modules (e.g. bitwise_network, bss_eval)
# before each cell executes, so library edits apply without a kernel restart.
%load_ext autoreload
%autoreload 2

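# Binary Neural Networks

Evaluates a real-valued network, a bitwise network with the same architecture, and binarized networks on two-source mixtures (target + interference), reporting mean BSS Eval metrics (SDR, SIR, SAR).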

In [6]:
def evaluate(model, batch):
    """Run `model` on one batch and score its estimates with BSS Eval.

    Parameters
    ----------
    model : callable
        Separation network; called on ``batch['mixture']``.
    batch : dict
        Must provide 'mixture', 'target' and 'interference' tensors.

    Returns
    -------
    The result of ``bss_eval_batch(estimates, sources)``; ``run_evaluation``
    accumulates it with ``BSSMetricsList.extend``.
    """
    mix, targ, inter = batch['mixture'], batch['target'], batch['interference']
    # Evaluation only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        estimates = model(mix)
    # Scale each row by its maximum along dim 1. keepdim=True keeps the reduced
    # axis so the division broadcasts row-wise for any batch size; without it a
    # (batch, time) tensor (assumed layout -- TODO confirm) is divided by a
    # (batch,) vector, which only lines up when batch == 1.
    # NOTE(review): a row whose maximum is 0 produces inf/nan here.
    estimates = estimates / torch.max(estimates, dim=1, keepdim=True)[0]
    sources = torch.stack([targ, inter], dim=0)
    metrics = bss_eval_batch(estimates, sources)

    return metrics

def run_evaluation(model, dl):
    """Score `model` on every batch yielded by `dl` and report the averages.

    Each batch is passed through ``evaluate``; the per-batch metrics are
    gathered in a ``BSSMetricsList``. The mean SDR, SIR and SAR are printed,
    and the full list is returned for further inspection.
    """
    summary = BSSMetricsList()
    for current in dl:
        summary.extend(evaluate(model, current))

    mean_sdr, mean_sir, mean_sar = summary.mean()
    print('Mean SDR %f' % mean_sdr)
    print('Mean SIR %f' % mean_sir)
    print('Mean SAR %f' % mean_sar)
    return summary
    
def get_audio_output(model, batch, rate=16000):
    """Play each mixture in `batch` followed by the model's estimate for it.

    Parameters
    ----------
    model : callable
        Separation network; called on ``batch['mixture']``.
    batch : dict
        Must provide a 'mixture' tensor whose first axis indexes clips.
    rate : int, optional
        Sample rate given to ``IPython.display.Audio`` (default 16000, the
        value previously hard-coded here).
    """
    mix = batch['mixture']
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        estimate = model(mix).detach().numpy()
    mix = mix.detach().numpy()
    for mixture_clip, estimate_clip in zip(mix, estimate):
        print('Mixture')
        ipd.display(ipd.Audio(mixture_clip, rate=rate))
        print('Estimate')
        ipd.display(ipd.Audio(estimate_clip, rate=rate))

def visualize_mask(binary_sample, raw_sample, model=None):
    """Plot `ibm` (left) and `mask` (right) side by side as binary colormaps.

    FIXME(review): stale. ``evaluate`` in this notebook is now
    ``evaluate(model, batch)`` and returns BSS metrics, so the call below
    passes three arguments to a two-argument function and raises TypeError
    before anything is plotted. It appears to target an older
    ``evaluate(binary_sample, raw_sample, model)`` that returned a 4-tuple
    ending in ``(mask, ibm)`` -- confirm before reviving.

    Parameters
    ----------
    binary_sample, raw_sample :
        One item each from the binary and raw datasets (format not shown here).
    model : optional
        Network forwarded to the evaluation call.
    """
    _, _, mask, ibm = evaluate(binary_sample, raw_sample, model)
    plt.figure(figsize=(15, 5))
    plt.subplot(121)
    plt.pcolormesh(ibm, cmap='binary')
    plt.subplot(122)
    plt.pcolormesh(mask, cmap='binary')
    
def _decode_bit_planes(bmag, num_bits):
    """Collapse each group of `num_bits` consecutive rows into one integer-valued row.

    Row i of the result is ``sum_j 2**(num_bits - j - 1) * bmag[num_bits*i + j]``,
    i.e. the first row of each group is the most significant bit. Trailing
    rows that do not fill a complete group are ignored. Returns float64.
    """
    bits = np.asarray(bmag, dtype=np.float64)
    n_rows = bits.shape[0] // num_bits
    grouped = bits[:n_rows * num_bits].reshape(n_rows, num_bits, *bits.shape[1:])
    # MSB-first weights: 2**(num_bits-1), ..., 2, 1
    weights = 2.0 ** np.arange(num_bits - 1, -1, -1)
    return np.tensordot(grouped, weights, axes=([1], [0]))


def visualize_input(binary_sample, num_bits=4):
    """Show a sample's bit-plane input, decoded back to integers, beside its `ibm`.

    ``binary_sample['bmag']`` is read as stacked bit planes: each block of
    `num_bits` consecutive rows is decoded (MSB first) into one row of
    quantized values. Left panel: the decoded values; right panel:
    ``binary_sample['ibm']``.
    """
    bmag, ibm = binary_sample['bmag'], binary_sample['ibm']
    qmag = _decode_bit_planes(bmag, num_bits)

    _, (ax_input, ax_ibm) = plt.subplots(1, 2, figsize=(15, 5))
    ax_input.pcolormesh(qmag, cmap='binary')
    ax_input.set_title('Decoded input (%d-bit)' % num_bits)
    ax_ibm.pcolormesh(ibm, cmap='binary')
    ax_ibm.set_title('ibm')

In [7]:
# Configuration
# Passed as `toy=` to `make_data` in the dataset-loading cell; presumably
# True selects a small toy dataset instead of the full one -- confirm
# against make_data.
toy: bool = False

In [8]:
# Load Dataset
# NOTE(review): the first argument to make_data is presumably the batch size.
train_dl, val_dl = make_data(1, toy=toy)
# Use the builtin next(): Python 3 iterators implement __next__, and current
# PyTorch DataLoader iterators no longer provide a Python-2-style .next().
sample = next(iter(train_dl))
mix, targ = sample['mixture'], sample['target']
ipd.display(ipd.Audio(mix, rate=16000))
ipd.display(ipd.Audio(targ, rate=16000))

## Bitwise Neural Network Evaluation

In [9]:
# Evaluate real network
# NOTE(review): `nn` shadows the conventional `torch.nn` alias; kept because
# later cells refer to this model as `nn`.
nn = BitwiseNetwork(1024, 256, fc_sizes=[2048, 2048])
# map_location='cpu': evaluation here runs on CPU (outputs are converted with
# .numpy()), so checkpoints saved from a GPU run must still load.
nn.load_state_dict(torch.load('models/real_network.model', map_location='cpu'))
nn.eval()
bss_metrics = run_evaluation(nn, val_dl)

Mean SDR 8.011654
Mean SIR 127.438362
Mean SAR 8.011649


In [10]:
# Listen to every 128th validation batch (batches 0, 128, 256, ...).
for i, batch in enumerate(val_dl):
    if i % 128:
        continue
    get_audio_output(nn, batch)

Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


Mixture


Estimate


In [8]:
# Evaluate bitwise network
bnn = BitwiseNetwork(1024, 256, fc_sizes=[2048, 2048])
# map_location='cpu' so a checkpoint saved from a GPU run loads for this
# CPU-side evaluation.
bnn.load_state_dict(torch.load('models/bitwise_network.model', map_location='cpu'))
# Keep the original order: load weights, then inference(), then eval().
bnn.inference()
bnn.eval()
# Keep the metrics (as the real-network cell does) instead of echoing the
# BSSMetricsList repr as cell output.
bitwise_metrics = run_evaluation(bnn, val_dl)

  z[index] = x
  output = mkl_fft.rfftn_numpy(a, s, axes)


Mean SDR 7.350534
Mean SIR 13.080113
Mean SAR 9.343989


<bss_eval.BSSMetricsList at 0x7f97b95f3fd0>

## Binarized Networks

In [5]:
# FIXME(review): stale cell -- fails on a fresh kernel.
# `toy_binary_set` / `toy_raw_set` are not defined in any visible cell, and
# the calls below follow an older (binary_sample, raw_sample, model) calling
# convention: `run_evaluation` now takes (model, dl), `get_audio_output` takes
# (model, batch, ...), and `visualize_mask` calls `evaluate` with three
# arguments although it takes two.
bin_net = BinarizedNetwork(2052, 513, fc_sizes=[1024, 1024])
bin_net.eval()
bin_net.load_state_dict(torch.load('models/toy_bin_network.model'))
run_evaluation(toy_binary_set, toy_raw_set, bin_net)
visualize_mask(toy_binary_set[0], toy_raw_set[0], bin_net)
visualize_mask(toy_binary_set[100], toy_raw_set[100], bin_net)
get_audio_output(toy_binary_set[0], toy_raw_set[0], bin_net)
get_audio_output(toy_binary_set[100], toy_raw_set[100], bin_net)

  z[index] = x
  output = mkl_fft.rfftn_numpy(a, s, axes)


Mean SDR 13.219864
Mean SIR 25.260276
Mean SAR 13.559282


<bss_eval.BSSMetricsList at 0x7f35240d6c18>

In [None]:
# FIXME(review): stale cell -- `binary_set` / `raw_set` are not defined in any
# visible cell, and `run_evaluation` now takes (model, dl), so this
# three-argument call fails.
bin_net = BinarizedNetwork(2052, 513, fc_sizes=[2048, 2048])
bin_net.eval()
bin_net.load_state_dict(torch.load('models/bin_network.model'))
run_evaluation(bin_net, binary_set, raw_set)

In [27]:
# Listen to one recording from the Nonspeech data directory. Set the
# NONSPEECH_DIR environment variable to use another location; the default is
# the path this cell originally hard-coded.
nonspeech_dir = os.environ.get('NONSPEECH_DIR', '/media/data/Nonspeech')
s, sr = sf.read(os.path.join(nonspeech_dir, 'n81.wav'))
ipd.Audio(s, rate=sr)

In [25]:
# Names of the weight tensors in the real network's state dict, in order.
print(list(filter(lambda key: key.endswith('weight'), nn.state_dict())))

['linear_list.0.weight', 'linear_list.1.weight', 'linear_list.2.weight']
