In [None]:
import sys, os
sys.path.append('../')

import numpy as np
import data_loader.data_loader as data_loader
from argparse import Namespace
import matplotlib.pyplot as plt
import maxflow
import IPython.display as ipd
from librosa.core import stft, istft

In [None]:
# Loader configuration: one example per batch, 4 worker jobs, no cap on the
# number of examples (get_top=None). The three requested items are unpacked
# later, in this order, as (mixture wav, clean source wavs, phase difference).
# NOTE(review): 'rpd' presumably stands for a (raw) phase difference -- confirm
# against data_loader.
default_parameters = Namespace(
    batch_size=1,
    n_jobs=4,
    get_top=None,
    return_items=['mixture_wav', 'clean_sources_wavs', 'rpd'],
    input_dataset_p='/mnt/data/CS544_data/timit_5400_1800_512_2_fm_random_taus_delays/val',
)
data_gen = data_loader.get_numpy_data_generator(default_parameters)

In [None]:
def get_normalized_spectrogram(wav, n_fft=512, win_length=512, hop_length=128,
                               power=0.2):
    """Return a standardized, row-reversed STFT of ``wav`` and its compressed magnitude.

    Parameters
    ----------
    wav : array-like
        Waveform handed to ``stft`` unchanged.
    n_fft, win_length, hop_length : int, optional
        Forwarded to ``stft``. The defaults (512 / 512 / 128) are the values
        that were previously hard-coded.
    power : float, optional
        Exponent applied to the magnitude of the standardized STFT
        (default 0.2, the previously hard-coded value).

    Returns
    -------
    (stft_representation, spec) : tuple of arrays
        ``stft_representation`` is the STFT with its first axis reversed and
        then standardized (global mean subtracted, divided by global std).
        ``spec`` is ``abs(stft_representation) ** power``.
    """
    # Reverse the first axis of the STFT output.
    # NOTE(review): presumably done so low frequencies end up at the bottom of
    # an imshow plot -- confirm.
    stft_representation = stft(wav,
                               n_fft=n_fft,
                               win_length=win_length,
                               hop_length=hop_length)[::-1, :]

    # Standardize with statistics taken over every bin of the (complex) STFT.
    # 10e-9 (i.e. 1e-8) keeps the division finite when the std is zero.
    centered = stft_representation - np.mean(stft_representation)
    stft_representation = centered / (np.std(stft_representation) + 10e-9)

    spec = np.abs(stft_representation)
    return stft_representation, spec ** power

In [None]:
# Pull a single batch from the loader and convert it to numpy.
data_gen = iter(data_gen)
batch_data_list = next(data_gen)
numpy_data_list = data_loader.convert_to_numpy(batch_data_list)
mix_wav, clean_wavs, phase_diff = numpy_data_list
s1_clean = clean_wavs[0]
s2_clean = clean_wavs[1]

# Clip the phase difference to [-2, 2], rescale it to [0, 1], and reverse the
# first axis so it lines up with the row-reversed spectrograms below.
phase_diff = np.clip(phase_diff, -2, 2) / 4 + 0.5
phase_diff = phase_diff[::-1, :]

# Standardized STFT and compressed magnitude for the mixture and each source.
(mix_stft, mix_spec), (s1_stft, s1_spec), (s2_stft, s2_spec) = (
    get_normalized_spectrogram(wav) for wav in (mix_wav, s1_clean, s2_clean)
)

# Same quantity as mix_spec, kept under its own name.
spec = np.power(np.abs(mix_stft), 0.2)

plt.title("Phase Difference")
plt.imshow(phase_diff)
plt.colorbar()

plt.figure()
plt.title("Mixture Spectrogram")
plt.imshow(mix_spec, cmap='coolwarm')
plt.colorbar()

In [None]:
# Superimpose the two clean-source spectrograms on one axis. Each image is
# semi-transparent and gets its own colorbar so both sources stay readable.
fig, ax = plt.subplots()

pa = ax.imshow(s1_spec, cmap='Reds', interpolation='nearest', alpha=0.9)
cba = plt.colorbar(pa)
cba.set_label('Source 1 Activation')

pb = ax.imshow(s2_spec, cmap='Blues', interpolation='nearest', alpha=0.6)
cbb = plt.colorbar(pb)
cbb.set_label('Source 2 Activation')

plt.show()

In [None]:
def build_graph_from_img(spec, phase_diff, horiz_weight=1, vert_weight=1):
    """Build a 4-connected grid graph for a max-flow / min-cut segmentation.

    One node is created per element of ``phase_diff``. Neighbouring nodes are
    joined by symmetric edges whose capacity is
    ``weight * exp(-(difference of the two phase_diff values) ** 2)``, so
    similar neighbours are linked strongly. Every node is also connected to
    the terminals with capacities taken from ``spec``.

    Parameters
    ----------
    spec : 2-D array
        Drives the terminal edges: ``spec`` is the source capacity and
        ``max(spec) - spec`` the sink capacity. Must match the grid shape.
    phase_diff : 2-D array
        Defines the grid shape and the neighbour-edge capacities.
    horiz_weight, vert_weight : float, optional
        Multipliers for edges to the right-hand / lower neighbour.

    Returns
    -------
    (g, nodeids) : tuple
        The populated ``maxflow.Graph[float]`` and the grid of node ids
        returned by ``add_grid_nodes``.
    """
    height, width = phase_diff.shape
    # The constructor arguments are only size estimates used for
    # pre-allocation. A single graph is created (a second, discarded
    # construction used to follow this one).
    g = maxflow.Graph[float](height * width, height * width * 2)
    nodeids = g.add_grid_nodes(phase_diff.shape)

    # Absolute differences to the right / lower neighbour, zero-padded on the
    # last column / row so both arrays have the full grid shape.
    horiz_diff = np.abs(phase_diff[:, :-1] - phase_diff[:, 1:])
    vert_diff = np.abs(phase_diff[:-1] - phase_diff[1:])
    horiz_diff = np.concatenate([horiz_diff, np.zeros((height, 1))], axis=1)
    vert_diff = np.concatenate([vert_diff, np.zeros((1, width))], axis=0)

    # Evaluate all capacities at once instead of one np.exp call per node.
    horiz_caps = np.exp(-horiz_diff ** 2) * horiz_weight
    vert_caps = np.exp(-vert_diff ** 2) * vert_weight

    for r in range(height):
        for c in range(width):
            node = nodeids[r, c]
            if c + 1 < width:
                cap = horiz_caps[r, c]
                g.add_edge(node, nodeids[r, c + 1], cap, cap)
            if r + 1 < height:
                cap = vert_caps[r, c]
                g.add_edge(node, nodeids[r + 1, c], cap, cap)

    # make terminal edges
    g.add_grid_tedges(nodeids, spec, np.max(spec) - spec)
    return (g, nodeids)

In [None]:
# Build the graph and cut it. Both neighbour weights are zero, so only the
# terminal edges influence the resulting segmentation.
# NOTE(review): the arrays are bound in the opposite order to the parameter
# names -- phase_diff feeds the terminal edges (`spec`) and mix_spec feeds the
# neighbour weights (`phase_diff`). The binding is kept exactly as it was
# (previously positional); confirm it is intended.
g, nodeids = build_graph_from_img(spec=phase_diff,
                                  phase_diff=mix_spec,
                                  horiz_weight=0.0,
                                  vert_weight=0.0)
g.maxflow()
# Per-node segment labels laid out on the same grid as the inputs.
mask = g.get_grid_segments(nodeids)
plt.imshow(mask, cmap='gray_r')

In [None]:
def sisnr(s_pred, s, eps=10e-9):
    """Scale-invariant signal-to-noise ratio (in dB) of ``s_pred`` against ``s``.

    Both signals are mean-centred, ``s_pred`` is projected onto ``s`` to get
    the target component, and the result is the ratio of target energy to
    residual energy in decibels.

    Parameters
    ----------
    s_pred : 1-D array-like
        Estimated signal.
    s : 1-D array-like
        Reference signal.
    eps : float, optional
        Added to both denominators to avoid division by zero
        (default 10e-9, i.e. 1e-8).

    Returns
    -------
    float
        ``10 * log10(|s_target|^2 / (|e_noise|^2 + eps))``.

    Notes
    -----
    The inputs are never modified; centring is done on copies. If the lengths
    differ, both signals are truncated to the shorter length *after* centring
    (the means are taken over the full-length signals).
    """
    s_pred = np.asarray(s_pred)
    s = np.asarray(s)

    # Out-of-place subtraction: leaves the caller's arrays (and any views into
    # them) untouched, and also works for integer-typed inputs.
    s_pred = s_pred - s_pred.mean()
    s = s - s.mean()

    if len(s_pred) != len(s):
        min_len = int(min(len(s), len(s_pred)))
        s = s[:min_len]
        s_pred = s_pred[:min_len]

    # Projection of the estimate onto the reference.
    coef = np.dot(s, s_pred) / (np.dot(s, s) + eps)
    s_target = coef * s
    e_noise = s_pred - s_target
    return 10 * np.log10(np.dot(s_target, s_target) /
                         (np.dot(e_noise, e_noise) + eps))

def compute_sisnr_and_return_pair(s_recon, clean_sources, eps=10e-9):
    """Score ``s_recon`` against every clean source and return the best match.

    ``clean_sources`` is indexed along its first axis; each entry is scored
    with ``sisnr(s_recon, entry, eps=eps)``.

    Returns
    -------
    (score, source) : tuple
        The highest score and the corresponding entry of ``clean_sources``.
        When several entries share the highest score, the one with the
        largest index is returned.
    """
    scored = []
    for idx in range(clean_sources.shape[0]):
        candidate = clean_sources[idx]
        scored.append((sisnr(s_recon, candidate, eps=eps), candidate))
    # Stable ascending sort on the score only; the best pair ends up last.
    scored.sort(key=lambda pair: pair[0])
    return scored[-1]
    

In [None]:
# Raw (un-normalized, un-flipped) STFT of the mixture. This rebinds mix_stft,
# which previously held the standardized, row-reversed version.
mix_stft = stft(mix_wav, n_fft=512, win_length=512, hop_length=128)

# The mask was computed on row-reversed inputs, so reverse its rows again
# before applying it (and its complement) to the raw STFT.
s1_estimate = istft(mix_stft * mask[::-1, :],
                    win_length=512, hop_length=128)
s2_estimate = istft(mix_stft * (1. - mask[::-1, :]),
                    win_length=512, hop_length=128)

# Pair each estimate with its best-scoring clean source. This rebinds
# s1_clean / s2_clean to those best matches.
s1_sisnr, s1_clean = compute_sisnr_and_return_pair(s1_estimate, clean_wavs)
s2_sisnr, s2_clean = compute_sisnr_and_return_pair(s2_estimate, clean_wavs)

# Caption + audio player for the mixture, each estimate, and its matched source.
for _caption, _audio in (
        ("Initial Mixture", mix_wav),
        ("Source 1 Reconstruction: SISDR {}".format(s1_sisnr), s1_estimate),
        ("Source 1 Clean", s1_clean),
        ("Source 2 Reconstruction: SISDR {}".format(s2_sisnr), s2_estimate),
        ("Source 2 Clean", s2_clean),
):
    print(_caption)
    ipd.display(ipd.Audio(_audio, rate=16000))