# Baseline Source Separation Statistics

This notebook calculates the baseline source separation statistics for a given dataset. It is currently set up for the folder structure of the new CHiME dataset, which includes mixed, isolated voice, and isolated background noise recordings.

Metrics are calculated by comparison of the mixed audio to the isolated sources, giving a measure of the audio quality prior to separation and providing a reference against which to compare model outputs.

In [25]:
# Make the parent of this notebook's directory importable.
# `currentdir` is the directory of the file that `inspect` reports for the
# current frame, and `parentdir` is one level above it; `parentdir` is put at
# the front of sys.path so it takes precedence over other entries.
# NOTE(review): presumably this is what lets `import dataset` below resolve to a
# project module living in the parent directory - confirm.
import os,sys,inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir)

import csv
import numpy as np
import tensorflow as tf
import mir_eval
import dataset
import datetime

### Build the Data Pipeline

In [29]:
#  Configuration
#  Everything below except `shuffle` and `root` is passed to
#  dataset.get_paired_dataset when the pipeline is built; `root` is used to
#  build the directory names handed to dataset.zip_files.

#  Audio and transform settings
sample_rate = 16384
n_fft, fft_hop = 1024, 256

#  Patch settings
patch_window, patch_hop = 256, 128

#  Reader / batching settings
n_parallel_readers = 4
batch_size = 50
n_shuffle = 1
normalise = True

#  Not referenced by any other cell visible in this notebook
shuffle = False

#  Dataset location; must end with a path separator because directory names
#  are appended to it directly.
#  Alternative location used previously: 'C:/Users/Toby/Speech_Data/BG_test/'
root = '/home/enterprise.internal.city.ac.uk/acvn728/NewCHiME/'





#  Create the pipeline
tf.reset_default_graph()

# The second directory does not depend on the environment, so build it once.
# NOTE(review): presumably the three directories hold mixed, voice and
# background recordings respectively (matching the order the audio tensors are
# unpacked below) - confirm against dataset.zip_files.
dir_voice = root + 'et05_bth'

# Start from an empty (0, 3) array and add one file list per environment.
file_lists = [np.empty((0, 3))]
for env in ('bus', 'caf', 'ped', 'str'):
    dir_mixed = '{r}et05_{e}_simu'.format(r=root, e=env)
    dir_background = '{r}et05_{e}_bg'.format(r=root, e=env)
    file_lists.append(dataset.zip_files(dir_mixed, dir_voice, dir_background))

all_files = np.concatenate(file_lists)

data = dataset.get_paired_dataset(
    all_files,
    sample_rate,
    n_fft,
    fft_hop,
    patch_window,
    patch_hop,
    n_parallel_readers,
    batch_size,
    n_shuffle,
    normalise,
)

#  Create the iterator; only the last three of the six outputs are used here
pipe = data.make_initializable_iterator()
next_batch = pipe.get_next()
mixed_audio, voice_audio, background_audio = next_batch[3:]


# Show the dataset's shapes and types as the cell output
data

<PrefetchDataset shapes: ((?, 256, 513, 4), (?, 256, 513, 4), (?, 256, 513, 4), (?, 65280, 1), (?, 65280, 1), (?, 65280, 1)), types: (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32)>

### Run the Data and Collect Results

In [30]:
# Session configuration: enable the GPU `allow_growth` option and restrict the
# visible device list to device 0.
tf_config = tf.ConfigProto()
tf_config.gpu_options.visible_device_list = '0'
tf_config.gpu_options.allow_growth = True

# Open the session and run the global variable initialiser
sess = tf.Session(config=tf_config)
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
# Baseline evaluation: score the unprocessed mixture against each isolated
# source for every example in the dataset, then average the scores.

# (Re-)initialise the iterator before pulling batches
sess.run(pipe.initializer)
batch_count = 0
metrics = []
nsdrs = np.empty((0, 2))  # kept from the original cell; not populated here

# One length-2 entry per example, ordered (voice, background)
sdr_rows, sir_rows, sar_rows = [], [], []

while True:
    try:
        mixed, voice, background = sess.run([mixed_audio, voice_audio, background_audio])
    except tf.errors.OutOfRangeError:
        # The iterator is exhausted. Leave the loop: without this, every
        # subsequent sess.run would raise again and the loop would never end.
        break

    # Reshape for mir_eval: swap the last two axes so that each example
    # becomes a (1, n_samples) row that can be stacked per source
    mixed = np.transpose(mixed, (0, 2, 1))
    voice = np.transpose(voice, (0, 2, 1))
    background = np.transpose(background, (0, 2, 1))

    for i in range(voice.shape[0]):
        # References are the two isolated sources; the "estimate" for both is
        # the unseparated mixture, which is what makes this a baseline.
        ref_sources = np.concatenate((voice[i, :, :], background[i, :, :]), axis=0)
        est_sources = np.concatenate((mixed[i, :, :], mixed[i, :, :]), axis=0)

        # Calculate audio quality statistics. The permutation search is
        # disabled so that row 0 stays voice and row 1 stays background.
        sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(ref_sources, est_sources, compute_permutation=False)
        sdr_rows.append(sdr)
        sir_rows.append(sir)
        sar_rows.append(sar)
    print('{ts}:\t{bc} processed.'.format(ts=datetime.datetime.now(), bc=batch_count))
    batch_count += 1

# Stack the per-example results into (n_examples, 2) arrays. The reshape keeps
# the (0, 2) shape when no example was processed.
sdrs = np.reshape(sdr_rows, (-1, 2))
sirs = np.reshape(sir_rows, (-1, 2))
sars = np.reshape(sar_rows, (-1, 2))

# Column 0 is the voice mean, column 1 the background mean
mean_sdr = np.mean(sdrs, axis=0)
mean_sir = np.mean(sirs, axis=0)
mean_sar = np.mean(sars, axis=0)
        

In [32]:
# Summarise the mean scores, one row per source. Index 0 of each mean vector is
# the voice and index 1 the background, which is the order the reference
# sources are stacked in when the scores are computed.
# The list is rebuilt here rather than appended to, so re-running this cell
# cannot add duplicate rows to `metrics`.
metrics = [
    {'test': source, 'mean_sdr': mean_sdr[idx],
     'mean_sir': mean_sir[idx], 'mean_sar': mean_sar[idx]}
    for idx, source in enumerate(('voice', 'background'))
]

print('{ts}:\nProcessing complete\n{m}'.format(ts=datetime.datetime.now(), m=metrics))

2018-11-05 13:47:36.871146:
Processing complete
[{'mean_sir': 3.460299907111517, 'mean_sdr': 1.9689712679452347, 'test': 'voice', 'mean_sar': 14.605391982548587}, {'mean_sir': -2.7038786284632264, 'mean_sdr': -3.3982215847168984, 'test': 'background', 'mean_sar': 14.605391982548587}]


### Save the Results

In [33]:
# Write the summary rows to a CSV file, creating the output folder on first use
output_dir = 'test_metrics'
if not os.path.isdir(output_dir):
    os.mkdir(output_dir)

file_name = 'test_metrics/NewCHiMEDatasetBaselineMetrics.csv'

# The rows built in this notebook carry no 'mean_cost' or 'mean_nsdr' keys, so
# those two columns are written empty.
columns = ['test', 'mean_cost', 'mean_sdr', 'mean_sir', 'mean_sar', 'mean_nsdr']

with open(file_name, 'w') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=columns, lineterminator='\n')
    writer.writeheader()
    writer.writerows(metrics)