### Compute MI on raw audio
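- find the relevant audio files (WAV / cbin) for Bengalese finch (BF), starling, and speech
- compute mel spectrograms of the recordings
- segment the spectrograms into 0.01, 0.1, and 1 second chunks
- cluster those chunks with k-means
- compute mutual information (MI) between the cluster labels as a function of distance
- plot the MI decay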

In [None]:
from glob import glob

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from avgn.spectrogramming import spectrogramming as sg
from parallelspaper.config.paths import DATA_DIR, FIGURE_DIR
from parallelspaper.speech_datasets import LCOL_DICT
from parallelspaper.birdsong_datasets import BCOL_DICT
from parallelspaper.utils import save_fig
from tqdm.autonotebook import tqdm

In [None]:
# Settings handed to the avgn spectrogramming helpers (sg._build_mel_basis,
# sg.melspectrogram). 'sample_rate' is added later, once a file has been read.
hparams = dict(
    # --- filtering ---
    highcut=15000,
    lowcut=500,
    # --- spectrograms ---
    mel_filter=True,      # whether a mel filter should be used
    num_mels=32,          # number of channels in the mel-spectrogram
    num_freq=512,         # number of channels in a spectrogram
    preemphasis=0.97,
    frame_shift_ms=5,     # fft step size
    frame_length_ms=10,   # fft frame length
    min_level_db=-50,     # minimum threshold (db) used when computing the spectrogram
    spec_thresh_min=-40,  # (db)
    ref_level_db=50,      # reference db used when computing the spectrogram
    fmin=300,             # low frequency cutoff of the mel filter
    fmax=None,            # high frequency cutoff of the mel filter
)

In [None]:
from scipy.io import wavfile
import numpy as np
import sklearn.cluster
from parallelspaper import information_theory as it 
from sklearn.externals.joblib import Parallel, delayed
import pandas as pd
from datetime import datetime
from parallelspaper.hvc_funcs import load_cbin # for loading cbin files

In [None]:
# Locate the recordings of each dataset on disk (paths are machine-specific).
starling_wavs = glob(
    '../../../animalvocalizationgenerativenet/data/st_wavs/b1077/wavs/*.wav'
)
bf_wavs = glob('/mnt/cube/Datasets/BengaleseFinch/sober/*/gy6or6*.cbin')
# Each speech file is wrapped in its own one-element list so that it has the
# same "group of files" layout as the per-day bird groups built below.
human_wavs = np.array(
    [[path] for path in glob('/mnt/cube/Datasets/buckeye/s01/*.wav')]
)

In [None]:
# Group the Bengalese finch recordings by recording day.
# The last two "_"-separated fields of a file stem are parsed as a
# ddmmyy_HHMM timestamp.
bf_labs = [
    datetime.strptime("_".join(stem.split("_")[-2:]), "%d%m%y_%H%M")
    for stem in (path.split("/")[-1].split(".")[0] for path in bf_wavs)
]
# put the files, then the timestamps themselves, into chronological order
bf_wavs = np.array(bf_wavs)[np.argsort(bf_labs)]
bf_labs = np.sort(np.array(bf_labs))
# one 'ddmmyy' key per file, then one array of files per distinct key
bf_days = np.array([stamp.strftime('%d%m%y') for stamp in bf_labs])
bf_day_wavs = [bf_wavs[bf_days == day] for day in np.unique(bf_days)]

In [None]:
# break starling wavs into day
# File names (minus the 4-character extension) are parsed as
# "%Y-%m-%d_%H-%M-%S-%f" timestamps and used to order the recordings.
st_labs = [datetime.strptime(i[:-4].split('/')[-1], "%Y-%m-%d_%H-%M-%S-%f") for i in starling_wavs]
st_wavs = np.array(starling_wavs)[np.argsort(st_labs)]
st_labs = np.array(st_labs)[np.argsort(st_labs)]
st_days = np.array([lab.strftime('%d%m%y') for lab in st_labs])
# st_days is in sorted order, so the mask must be applied to the sorted file
# array (st_wavs). Masking the unsorted starling_wavs would assign files to the
# wrong day and scramble their order within a day.
st_day_wavs = [st_wavs[st_days == i] for i in np.unique(st_days)]

In [None]:
# number of recording-day groups found for the Bengalese finch and the starling
len(bf_day_wavs), len(st_day_wavs)

##### plot an example

In [None]:
# Load the first speech recording and display part of its mel spectrogram.
wav_loc = human_wavs[0][0]
rate, data= wavfile.read(wav_loc)
# store the file's sample rate in hparams before building the mel basis
# (presumably read by the sg helpers - confirm against avgn)
hparams['sample_rate'] = rate
_mel_basis = sg._build_mel_basis(hparams) # build a basis function if you are using a mel spectrogram
spec = sg.melspectrogram(data, hparams, _mel_basis)
fig, ax = plt.subplots(figsize=(30,3))
# only columns 5000-9999 of the spectrogram are shown
ax.matshow(spec[:,5000:10000], interpolation=None, aspect='auto', origin='lower')

In [None]:
def norm(x):
    """Rescale ``x`` linearly so its minimum maps to 0 and its maximum to 1.

    Parameters
    ----------
    x : array-like
        Values to rescale; the minimum and maximum are taken over all elements.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x``. When every element of ``x`` is
        equal (zero range) an array of zeros is returned instead of the NaNs
        that a 0/0 division would produce.

    Raises
    ------
    ValueError
        If ``x`` is empty (raised by ``np.min``).
    """
    x = np.asarray(x)
    lowest = np.min(x)
    span = np.max(x) - lowest
    if span == 0:
        # constant input: avoid dividing by zero
        return np.zeros_like(x, dtype=float)
    return (x - lowest) / span

In [None]:
def process_wav(wav_loc, hparams, time_bin, _mel_basis):
    """Slice one recording into fixed-length, 8-bit mel-spectrogram chunks.

    Parameters
    ----------
    wav_loc : str
        Path to a ``.wav`` file (read with ``scipy.io.wavfile.read``) or a
        ``.cbin`` file (read with ``load_cbin``).
    hparams : dict
        Spectrogram settings; must contain ``'frame_shift_ms'``. The dict is
        not modified: a copy carrying this file's sample rate is passed to the
        spectrogramming helpers.
    time_bin : float
        Length of each chunk, in seconds.
    _mel_basis : object
        Ignored. The mel basis is rebuilt for every file after that file's
        sample rate has been stored in the settings. The parameter is kept so
        existing callers keep working.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(n_chunks, spec.shape[0], frames_per_chunk)``.
        Values are min-max normalised over the whole file and scaled to 0-255.
        A recording shorter than one chunk yields an array with zero chunks.

    Raises
    ------
    ValueError
        If the file extension is neither ``wav`` nor ``cbin``, or if
        ``time_bin`` is shorter than one spectrogram frame shift.
    """
    # pick the loader from the file extension
    extension = wav_loc.split('.')[-1].lower()
    if extension == 'wav':
        rate, data = wavfile.read(wav_loc)
    elif extension == 'cbin':
        data, rate = load_cbin(wav_loc)
    else:
        raise ValueError(
            "unsupported audio file %r: expected a .wav or .cbin file" % (wav_loc,)
        )

    # work on a copy so the caller's (shared) settings are left untouched
    file_hparams = dict(hparams, sample_rate=rate)

    # number of spectrogram frames that make up one time bin
    fptb = int(time_bin / (file_hparams['frame_shift_ms'] / 1000))
    if fptb < 1:
        raise ValueError(
            "time_bin=%r is shorter than one frame shift (%r ms)"
            % (time_bin, file_hparams['frame_shift_ms'])
        )

    # build a basis function for the mel spectrogram, then compute it
    mel_basis = sg._build_mel_basis(file_hparams)
    spec = sg.melspectrogram(data, file_hparams, mel_basis)

    # trailing frames that do not fill a whole chunk are dropped
    n_chunks = int(np.shape(spec)[1] / fptb)
    if n_chunks == 0:
        # nothing to normalise; keep the chunk dimensions so results still stack
        return np.zeros((0, np.shape(spec)[0], fptb), dtype=np.uint8)

    spec_samples = [spec[:, i * fptb:(i + 1) * fptb] for i in range(n_chunks)]
    return np.array(norm(spec_samples) * 255).astype(np.uint8)
# smoke test: cut the example recording into 1 second chunks
test = process_wav(wav_loc, hparams, time_bin=1.0, _mel_basis=_mel_basis)

In [None]:
# display the first chunk returned by the smoke test
plt.matshow(test[0])

### Compute MI across timescales and datasets

In [None]:
def MI_raw_audio(dset, _mel_basis, wav_files, n_clusters=100, time_bin=1.0,
                 verbosity=0, n_jobs=20, seconds_dist=100):
    """Estimate sequential mutual information over clustered spectrogram chunks.

    Every file is cut into ``time_bin``-second spectrogram chunks with
    ``process_wav``. All chunks are clustered together with mini-batch
    k-means, and the resulting cluster labels (one label sequence per group of
    files) are passed to ``it.sequential_mutual_information``.

    Spectrogram settings are read from the module-level ``hparams``.

    Parameters
    ----------
    dset : str
        Dataset label; returned unchanged as the first element of the result.
    _mel_basis : object
        Forwarded to ``process_wav``.
    wav_files : iterable of iterables of str
        Groups of file paths. The chunks of all files in a group are
        concatenated, in order, into a single label sequence.
    n_clusters : int
        Number of k-means clusters.
    time_bin : float
        Chunk length in seconds.
    verbosity : int
        Verbosity passed to joblib and to the MI computation.
    n_jobs : int
        Number of parallel workers.
    seconds_dist : float
        Upper bound, in seconds, on the distances examined: MI is computed at
        chunk distances ``1 .. int(seconds_dist / time_bin) - 1``.

    Returns
    -------
    list
        ``[dset, time_bin, MI, var_MI, MI_shuff, MI_shuff_var]``, where the
        last four entries are the values returned by
        ``it.sequential_mutual_information``.
    """
    # Segment every file; one worker pool is shared by all groups.
    with Parallel(n_jobs=n_jobs, verbose=verbosity) as parallel:
        song_pieces = [
            np.vstack(parallel(
                delayed(process_wav)(wav_loc, hparams, time_bin, _mel_basis)
                for wav_loc in tqdm(wf_day, leave=False, desc='wav segmentation')))
            for wf_day in wav_files
        ]

    # stack the chunks of all groups and flatten each chunk to a vector
    all_pieces = np.vstack(song_pieces)
    n_pieces, n_rows, n_cols = np.shape(all_pieces)
    flat_pieces = all_pieces.reshape((n_pieces, n_rows * n_cols))

    # cluster the chunks; the label of each chunk ends up in mbk.labels_
    mbk = sklearn.cluster.MiniBatchKMeans(init='k-means++', n_clusters=n_clusters, batch_size=100,
                                          n_init=10, max_no_improvement=10, verbose=0,
                                          random_state=0)
    mbk.fit(flat_pieces)

    # chunk distances at which MI is evaluated
    distances = np.arange(1, int(seconds_dist / time_bin))

    # split the labels back into one sequence per group
    list_lens = [len(pieces) for pieces in song_pieces]
    bounds = np.concatenate([[0], np.cumsum(list_lens)]).astype(int)
    seqs = [mbk.labels_[start:stop]
            for start, stop in zip(bounds[:-1], bounds[1:])]

    # calculate Mutual information
    (MI, var_MI), (MI_shuff, MI_shuff_var) = it.sequential_mutual_information(
        seqs,
        distances,
        n_jobs=n_jobs,
        verbosity=verbosity,
        n_shuff_repeats=1,
        estimate=True,
    )

    return [dset, time_bin, MI, var_MI, MI_shuff, MI_shuff_var]

In [None]:
# Run the MI analysis for every dataset at every chunk length, appending one
# result row per (dataset, time_bin) pair to MI_raw.
n_clusters = 100
MI_raw = pd.DataFrame(columns = ['dset', 'time_bin', 'MI', 'var_MI', 'MI_shuff', 'MI_shuff_var'])
for (dset, wav_files) in tqdm([['starling', st_day_wavs], ['bengalese finch', bf_day_wavs], ['english', human_wavs]], desc="dataset"):
    print(dset)
    for time_bin in tqdm([1.0, 0.1, 0.01], leave=False, desc="time_bin"):
        # NOTE(review): _mel_basis is the module-level basis built for the
        # example speech recording; process_wav rebuilds its own per file.
        results = MI_raw_audio(dset, _mel_basis, wav_files, n_clusters = n_clusters, time_bin=time_bin, verbosity=0, n_jobs=20, seconds_dist = 100)
        # append the result list as a new row at the next integer index
        MI_raw.loc[len(MI_raw)] = results
        

In [None]:
# persist the results table (assumes DATA_DIR/'MI_DF' already exists - TODO confirm)
MI_raw.to_pickle(DATA_DIR/'MI_DF/MI_raw.pickle')

In [None]:
# display the results table
MI_raw

In [None]:
# Plot MI minus shuffled-MI against distance on log-log axes:
# one column per dataset, one row per chunk length.
fig, axs = plt.subplots(ncols = 3, nrows = 3, figsize = (20,10))

# plotting colour of each dataset
dset_colors = {
    'english': LCOL_DICT['english'],
    'bengalese finch': BCOL_DICT['BF'],
    'starling': BCOL_DICT['Starling'],
}

for dsi, dset in enumerate(['english', 'bengalese finch', 'starling']):
    color = dset_colors[dset]
    for tbi, time_bin in enumerate([0.01, 0.1, 1.0]):
        ax = axs[tbi, dsi]

        # the single result row for this dataset / chunk length
        subset_MI_DF = MI_raw[(MI_raw.dset == dset) & (MI_raw.time_bin == time_bin)]
        sig = subset_MI_DF.MI.values[0] - subset_MI_DF.MI_shuff.values[0]
        # convert chunk distances to seconds
        distances = np.arange(1,len(sig)+1)*time_bin

        ax.scatter(distances, sig, alpha = 1, s=40, color=color)
        # fully transparent line, kept from the original
        # (NOTE(review): presumably there to drive axis autoscaling - confirm)
        ax.plot(distances, sig, alpha = 0, color=color)

        ax.tick_params(which='both', direction='in', labelsize=14, pad=10)
        ax.tick_params(which='major', length=10, width =3)
        ax.tick_params(which='minor', length=5, width =2)

        # Base 10 is the default for log scales. The basex/basey keywords were
        # removed in Matplotlib 3.5, so they are not passed.
        ax.set_xscale("log")
        ax.set_yscale("log")

        for axis in ['top','bottom','left','right']:
            ax.spines[axis].set_linewidth(3)
            ax.spines[axis].set_color('k')

        ax.set_xlim([1e-3, 100])
    axs[0,dsi].set_title(dset.capitalize(), fontsize=18)
    axs[2,dsi].set_xlabel('Distance between elements (seconds)', fontsize=18)

axs[1,0].set_ylabel('Mutual Information (bits)', fontsize=18)

save_fig(FIGURE_DIR/'spectrogram_MI')