In [None]:
import os
from tqdm import notebook
import glob
import soundfile as sf
from scipy import signal
import numpy as np
import sys
import matplotlib.pyplot as plt

In [None]:
def get_top_k_peaks(f_corr, Cxy, top_n_peaks=5, distance_bw_peaks=2):
    """Return the strongest local maxima of ``Cxy`` with their frequencies.

    Local maxima are located with ``scipy.signal.find_peaks`` and then ranked
    by amplitude, largest first.

    Args:
        f_corr: 1-D NumPy array of frequencies, aligned sample-for-sample with ``Cxy``.
        Cxy: 1-D NumPy array of values in which peaks are searched.
        top_n_peaks: Maximum number of peaks to return. Fewer are returned when
            fewer local maxima exist; ``0`` yields empty arrays.
        distance_bw_peaks: Minimum separation between neighbouring peaks in
            samples, forwarded as ``distance`` to ``find_peaks``.

    Returns:
        Tuple ``(freqs, amplitudes, order)``: the frequencies and amplitudes of
        the selected peaks in descending amplitude order, and ``order``, the
        positions of those peaks within the array of *detected peaks* (not
        within ``Cxy`` itself).

    Raises:
        ValueError: If ``top_n_peaks`` is negative.
    """
    if top_n_peaks < 0:
        raise ValueError("top_n_peaks must be non-negative, got {}".format(top_n_peaks))

    peak_positions, _ = signal.find_peaks(Cxy, distance=distance_bw_peaks)
    freqs_at_peaks = f_corr[peak_positions]
    amplitudes_at_peaks = Cxy[peak_positions]

    if top_n_peaks == 0:
        # Guard explicitly: ``-0 == 0`` would make the partition/slice below
        # select every peak instead of none.
        ind_sorted_by_coef = np.array([], dtype=np.intp)
    elif top_n_peaks < len(amplitudes_at_peaks):
        # Partial selection of the k largest (https://stackoverflow.com/a/23734295),
        # followed by a full sort of just those k entries.
        ind = np.argpartition(amplitudes_at_peaks, -top_n_peaks)[-top_n_peaks:]
        ind_sorted_by_coef = ind[np.argsort(-amplitudes_at_peaks[ind])]
    else:
        # Not more peaks than requested: rank all of them.
        ind_sorted_by_coef = np.argsort(-amplitudes_at_peaks)

    return freqs_at_peaks[ind_sorted_by_coef], amplitudes_at_peaks[ind_sorted_by_coef], ind_sorted_by_coef


In [None]:
def get_spectral_coherence(x, y, sr=8000, nperseg_ms=0.02, noverlap_ms=0.01, nfft=512):
    """Estimate the coherence between ``x`` and ``y`` via ``scipy.signal.coherence``.

    Segment length and overlap are given relative to the sample rate and
    converted to whole sample counts before the call.

    Args:
        x: First signal.
        y: Second signal.
        sr: Sample rate, forwarded as ``fs``.
        nperseg_ms: Segment length. The value is multiplied by ``sr`` directly
            and truncated with ``int``; despite the ``_ms`` suffix it therefore
            acts as a duration in seconds (default 0.02 * 8000 = 160 samples).
        noverlap_ms: Overlap between segments, converted the same way
            (default 0.01 * 8000 = 80 samples).
        nfft: FFT length, forwarded unchanged.

    Returns:
        Tuple ``(f_corr, Cxy)`` exactly as produced by ``scipy.signal.coherence``.
    """
    samples_per_segment = int(sr * nperseg_ms)
    samples_overlap = int(sr * noverlap_ms)
    return signal.coherence(
        x,
        y,
        fs=sr,
        nperseg=samples_per_segment,
        noverlap=samples_overlap,
        nfft=nfft,
    )


In [None]:
def top_k_peaks_plain(f, Cxy, top_n_peaks=5):
    """Select the ``top_n_peaks`` largest entries of ``Cxy`` without peak detection.

    Args:
        f: 1-D NumPy array of frequencies aligned with ``Cxy``.
        Cxy: 1-D NumPy array of values to rank.
        top_n_peaks: Number of largest entries to keep.

    Returns:
        Tuple ``(freqs, values, idxs)`` ordered from largest to smallest value,
        where ``idxs`` are positions within ``Cxy``.
    """
    # Ascending argsort reversed gives the descending ranking.
    descending_order = np.argsort(Cxy)[::-1]
    keep = descending_order[:top_n_peaks]
    return f[keep], Cxy[keep], keep

In [None]:
def process_layer(exp_dir, layer_index, top_n_instances=9, top_n_peaks=5):
    """Average the strongest input/deconv coherence peaks for each filter of a layer.

    The directory layout read here is
    ``<exp_dir>/<layer_index:02d>/<filter_idx:04d>/input_audio/input_NN.wav`` and
    ``.../deconv_audio/deconv_NN.wav``. For every filter, the coherence between
    each input/deconv pair is computed, its ``top_n_peaks`` strongest peaks are
    taken, and the peak amplitudes are averaged rank-wise over the instances.

    Args:
        exp_dir: Root directory of one experiment's inspection output.
        layer_index: Layer number; formatted as a two-digit subdirectory name.
        top_n_instances: Number of input/deconv pairs (indices ``0..n-1``) read
            per filter.
        top_n_peaks: Number of coherence peaks kept per pair. Pairs that yield a
            different number of peaks are skipped.

    Returns:
        Tuple ``(res, mean_top_n_vals)``. ``res`` maps filter index to a list of
        ``top_n_peaks`` mean peak amplitudes; filters with no usable pair are
        omitted. ``mean_top_n_vals`` is the rank-wise mean over all filters in
        ``res``; when ``res`` is empty it is an all-NaN array of length
        ``top_n_peaks``.
    """
    layer_dir = os.path.join(exp_dir, "{:02d}".format(layer_index))
    # The filter count is the number of entries in the layer directory; the
    # loop below relies on them being named 0000, 0001, ... consecutively.
    num_filters = len(glob.glob(os.path.join(layer_dir, "*")))
    res = {}
    vals = []
    for filter_idx in notebook.tqdm(range(num_filters), position=1):
        filter_subfld = os.path.join(layer_dir, "{:04d}".format(filter_idx))

        peak_values = []
        for ix in range(top_n_instances):
            deconv_ix = os.path.join(filter_subfld, "deconv_audio", "deconv_{:02}.wav".format(ix))
            input_ix = os.path.join(filter_subfld, "input_audio", "input_{:02}.wav".format(ix))

            # NOTE(review): the sample rates returned by sf.read are discarded and
            # get_spectral_coherence runs with its default ``sr`` - confirm the
            # files were written at that rate.
            x, _ = sf.read(input_ix)
            y, _ = sf.read(deconv_ix)

            f, Cxy = get_spectral_coherence(x, y)

            freq_peaks, val_peaks, _ = get_top_k_peaks(f, Cxy, top_n_peaks=top_n_peaks)
            # Only full-length results can be stacked and averaged rank-wise.
            if len(freq_peaks) == top_n_peaks:
                peak_values.append(val_peaks)

        if peak_values:
            mean_top_n_peaks = np.mean(np.asarray(peak_values), 0)
            res[filter_idx] = mean_top_n_peaks.tolist()
            vals.append(mean_top_n_peaks.tolist())

    if not vals:
        # np.mean over an empty array would warn and collapse to a scalar NaN;
        # return a NaN vector of the usual length instead.
        return res, np.full(top_n_peaks, np.nan)

    mean_top_n_vals = np.mean(np.asarray(vals), 0)
    return res, mean_top_n_vals

In [None]:
def plot_n_peaks(f, Cxy, top_n_peaks, top_n=None):
    """Plot ``Cxy`` against ``f`` and mark the leading peaks with their rank.

    Args:
        f: X-axis values for the curve.
        Cxy: Y-axis values for the curve.
        top_n_peaks: Indexable sequence of ``(frequency, amplitude)`` pairs.
        top_n: Number of leading pairs to annotate; all of them when ``None``.
    """
    fig, ax = plt.subplots(figsize=(20, 10))
    n_marked = len(top_n_peaks) if top_n is None else top_n
    ax.plot(f, Cxy)
    for rank in range(n_marked):
        peak_freq, peak_amp = top_n_peaks[rank]
        # Cross on the peak, rank label nudged 3 x-units to its right.
        ax.plot(peak_freq, peak_amp, marker='x', color='black', alpha=0.8)
        ax.text(peak_freq + 3, peak_amp, "{:d}".format(rank), color='black')
    plt.show()

In [None]:
# Experiment output directories to analyse. Each entry is passed to
# process_layer as ``exp_dir`` by the driver loop further down.
# NOTE(review): these are hard-coded absolute paths on one machine; they must be
# adjusted before the notebook can run elsewhere.
exp_dirs = [
    "/media/user/nvme/contrastive_experiments/experiments_audioset_full_latest/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline_rs8881/inspection_all_maps_f/",
    "/media/user/nvme/contrastive_experiments/experiments_audioset_full_latest/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline_rs8882_noagc/inspection_all_maps_f/",
    "/media/user/nvme/contrastive_experiments/experiments_audioset_full_latest/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._baseline_rs8883_noagc/inspection_all_maps_f/"
]

In [None]:
# exp_dir = "/media/user/nvme/contrastive_experiments/experiments_audioset_v5_full/cnn12_1x_full_tr_8x128_Adam_1e-3_warmupcosine_wd0._fixed_lr_scaling_randomgain_gaussiannoise_timemask_bgnoise_nolineareval_full_ft_fullmodel_r2/inspection_all_maps_f"

In [None]:
# Run process_layer for every experiment directory and every layer, keeping only
# the per-layer mean peak values.
# Resulting structure: outputs[<position in exp_dirs>][<layer index>] -> mean_top_n_vals.
outputs = {}
cnt = 0  # running position of the current experiment within exp_dirs
for exp_dir in notebook.tqdm(exp_dirs, position=0):
    output = {}
    # Layers 1 through 11 inclusive.
    for layer_idx in notebook.tqdm(range(1, 12), position=1):
        # Five input/deconv pairs per filter; top_n_peaks keeps its default.
        # The per-filter dict ``res`` is not stored.
        res, mean_top_n_vals = process_layer(exp_dir, layer_idx, top_n_instances=5)
        output[layer_idx] = mean_top_n_vals
    outputs[cnt] = output
    cnt += 1