# Analyze Feature Extractors

## Imports and example audio

In [1]:
import importlib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os.path
import pickle
import soundfile
import string
from scipy import signal

import analysis_helpers as analysis

In [2]:
# Example utterance (LibriSpeech test-other, judging by the file name).
audio_file = "/u/vieting/tmp/audios/librispeech.test_other.533-1066.wav"
example_seconds = 5  # length of the excerpt that is kept
audio, fr = soundfile.read(audio_file)  # fr: sample rate of the file in Hz
if audio.ndim != 1:
    # The reshape below would interleave the channels into the time axis.
    raise ValueError(f"expected a mono audio file, got array of shape {audio.shape}")
# Keep the first `example_seconds` seconds, using the file's own sample rate
# instead of assuming 16 kHz, and reshape to (1, num_samples, 1).
audio = np.reshape(audio[:example_seconds * fr], [1, -1, 1])

In [3]:
# Figure output: when EXPORT is set, the cells below write .pgf files into
# EXPORTPATH instead of calling plt.show().
EXPORT = True
EXPORTPATH = "/u/vieting/Documents/text/2023_interspeech_features/"
cm = 1 / 2.54  # inches per centimeter: multiply a length in cm by this for figsize

if EXPORT:
    # pgf backend compiled with pdflatex, serif fonts, text rendered by LaTeX,
    # and no font setup taken from the rc parameters.
    pgf_rc = {
        "pgf.texsystem": "pdflatex",
        "font.family": "serif",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
    matplotlib.use("pgf")
    matplotlib.rcParams.update(pgf_rc)

## Helpers

In [4]:
# Parameter key of the weights of each (1-based) layer, per feature-extractor
# family. The family is the part of the feature name before the first "_",
# so e.g. "w2v_pretrained_freeze" uses the "w2v" entries.
_LAYER_PARAM_KEYS = {
    "gammatone": {
        1: "features/gammatone_filterbank/W",
        2: "features/temporal_integration/W",
    },
    "i6": {
        1: "features/conv_h/W",
        2: "features/conv_l/W",
    },
    "i6gt": {
        1: "features/scf/conv_h/W",
        2: "features/scf/conv_l/W",
    },
    # Layer n of the wav2vec 2.0 feature extractor is stored under "layer{n-1}".
    "w2v": {n: f"features/layer{n - 1}/layer0/W" for n in range(1, 7)},
}


def get_params_layer_n(feat, n, params):
    """Return the weights of layer ``n`` of feature extractor ``feat``.

    Args:
        feat: Feature-extractor name, e.g. ``"i6"`` or ``"w2v_pretrained_freeze"``.
            Only the prefix before the first ``"_"`` selects the layer-key
            mapping; the full name is used to index ``params``.
        n: 1-based layer index.
        params: Mapping ``feat -> {parameter key -> weights}`` (the ``"params"``
            entry loaded from ``data.p``).

    Returns:
        ``params[feat][key]``, where ``key`` is the parameter key of layer ``n``.

    Raises:
        KeyError: If the family or layer index is unknown, or ``params`` has no
            such entry.
    """
    family = feat.split("_")[0]
    try:
        key = _LAYER_PARAM_KEYS[family][n]
    except KeyError:
        raise KeyError(
            f"no layer {n} known for feature extractor {feat!r} (family {family!r})"
        ) from None
    return params[feat][key]

## Load Data

In [5]:
# Load precomputed results.
# NOTE(review): pickle.load can execute arbitrary code; only load these files
# if they come from a trusted source (e.g. your own export run).
with open("data.p", "rb") as data_file:
    data = pickle.load(data_file)
features = data["features"]
params = data["params"]  # consumed by get_params_layer_n as params[feat][key]

with open("freq_resp.p", "rb") as freq_resp_file:
    # Each entry is unpacked below as (frequencies, response matrix);
    # presumably frequencies are in Hz — confirm against the exporter.
    freq_resp = pickle.load(freq_resp_file)

## Analyze Parameters

In [32]:
# Re-import analysis_helpers so edits to it apply without restarting the kernel.
analysis = importlib.reload(analysis)
fig, axs = plt.subplots(3, 2, figsize=(16 * cm, 9.5 * cm))
titles = {
    "gammatone": "Gammatone",
    "i6": "SC",
    "w2v": "wav2vec 2.0 FE",
    "w2v_pretrained_freeze": "wav2vec 2.0 FE (pre-training)",
}
N_FREQ_BINS = 512  # frequency bins per filter (the scipy.signal.freqz default)


def _response_db(filters, n_freqs=N_FREQ_BINS):
    """Return the dB magnitude response of each FIR filter (one per row of ``filters``).

    The result has shape ``(n_freqs, num_filters)``: row ``i`` is the i-th
    frequency bin returned by ``scipy.signal.freqz`` (ascending), column ``j``
    is filter ``j``. That is the layout ``imshow(..., origin="lower")`` needs.
    An exact zero in a response yields ``-inf``.
    """
    resp = np.empty((n_freqs, filters.shape[0]))
    for col, taps in enumerate(filters):
        _, h = signal.freqz(taps, worN=n_freqs)
        resp[:, col] = 20 * np.log10(np.abs(h))
    return resp


def _label_panel(ax, panel, col, title):
    """Set the y label (left column only), y ticks and the lettered title of one panel."""
    if col == 0:
        ax.set_ylabel("Frequency [kHz]")
    ax.set_yticks([0, 4, 8])
    ax.set_title(f"({string.ascii_lowercase[panel]}) {title}")


# Rows 0-1: responses of the first-layer filters, sorted by analysis.get_sorted_filters.
# The y axis spans 0-8 kHz (hard-coded, i.e. assumes 16 kHz audio).
# Only the bottom row gets x labels.
for panel, feat in enumerate(["gammatone", "i6", "w2v", "w2v_pretrained_freeze"]):
    filters = get_params_layer_n(feat, 1, params).squeeze().T
    f_resp = analysis.get_sorted_filters(_response_db(filters), version="v4")
    row, col = divmod(panel, 2)
    ax = axs[row, col]
    ax.imshow(
        f_resp, origin="lower", aspect="auto", extent=[1, f_resp.shape[1], 0, 8],
        vmin=-100, vmax=30,
    )
    _label_panel(ax, panel, col, titles[feat])

# Row 2: precomputed responses of the full feature extractors (frequencies -> kHz).
for col, feat in enumerate(["i6", "w2v"]):
    f_hz, f_resp = freq_resp[feat]
    f_resp = analysis.get_sorted_filters(f_resp, version="v4")
    ax = axs[2, col]
    ax.imshow(
        f_resp, origin="lower", aspect="auto",
        extent=[1, f_resp.shape[1], f_hz[0] / 1000, f_hz[-1] / 1000],
    )
    ax.set_xlabel("Filter index")
    _label_panel(ax, 4 + col, col, titles[feat])

fig.tight_layout(pad=0.1)

if EXPORT:
    fig.savefig(os.path.join(EXPORTPATH, "first_layer.pgf"))
    plt.close(fig)  # free the figure; nothing is displayed with the pgf backend
else:
    plt.show()

In [7]:
# # get sorted indices for masking in recognition
# feat = "i6"
# filters = get_params_layer_n(feat, 1, params).squeeze().T
# f_resp = np.zeros((filters.shape[0], 512))
# for fltr in range(filters.shape[0]):
#     omega, f_r = signal.freqz(filters[fltr, :])
#     f_resp[fltr, :] = 20 * np.log10(np.abs(f_r[::-1]))
# f_resp = np.rot90(f_resp)
# f = np.linspace(0, 8, filters.shape[0])

# f_resp_nolog = 10 ** (f_resp / 20)  # from log to regular domain
# mean_to_max = f_resp_nolog.mean(axis=0) / f_resp_nolog.max(axis=0)
# sorted_idcs = np.argsort(mean_to_max)
# f_resp = f_resp[:, sorted_idcs]
# print(mean_to_max.max(), mean_to_max.min())

# plt.imshow(f_resp, origin="lower", aspect="auto")
# plt.show()

In [8]:
# feat = "w2v"
# filters = get_params_layer_n(feat, 1, params).squeeze().T
# f_resp = np.zeros((filters.shape[0], 512))
# for fltr in range(filters.shape[0]):
#     omega, f_r = signal.freqz(filters[fltr, :])
#     f_resp[fltr, :] = 20 * np.log10(np.abs(f_r[::-1]))
# f_resp = np.rot90(f_resp)
# f = np.linspace(0, 8, filters.shape[0])

# # f_resp = analysis.get_sorted_filters(f_resp, version="v0")
# f_cent = analysis.get_center_frequencies(f_resp)
# f_cent_2 = np.argsort(f_resp, axis=0)[-2]
# f_cut_u = analysis.get_cutoff_frequencies(f_resp)
# sorting = "v2"
# if sorting == "v0":  # v0: f_cent only
#     sorted_idcs = np.argsort(f_cent)
#     f_resp = f_resp[:, sorted_idcs]
# elif sorting == "v1":  # v1: f_cent, then f_cutoff upper
#     sorted_idcs = np.lexsort((f_cut_u, f_cent))
#     f_resp = f_resp[:, sorted_idcs]
# elif sorting == "v2":  # v2: f_cent, then f_cent_2
#     sorted_idcs = np.lexsort((f_cent_2, f_cent))
#     f_resp = f_resp[:, sorted_idcs]
# else:
#     assert False

# plt.imshow(f_resp, origin="lower", aspect="auto")
# plt.show()
# plt.plot(f_cent[sorted_idcs])
# plt.plot(f_cut_u[sorted_idcs])
# plt.show()
# plt.plot(f_resp[:, 0:3])
# plt.legend([str(idx) for idx in range(5)])
# plt.show()

## Results after masking

In [31]:
# Hard-coded WERs: wer[i] was obtained with num_masked[i] filters masked.
num_masked = [0, 5, 10, 15, 20, 25, 30, 35, 40, 50, 60, 70, 80]
wer = [7.1, 7.1, 7.1, 7.1, 7.1, 7.2, 7.2, 7.3, 7.4, 7.7, 9.0, 12.5, 19.9]

fig, ax = plt.subplots(1, 1, figsize=(7.5 * cm, 2.5 * cm))
ax.plot(num_masked, wer, "-o", ms=3)  # marker size as a number, not the string '3'
ax.set_xlabel("Masked Filters")
ax.set_ylabel("WER")
fig.tight_layout(pad=0.1)

if EXPORT:
    fig.savefig(os.path.join(EXPORTPATH, "masked_filters.pgf"))
    plt.close(fig)  # free the figure; nothing is displayed with the pgf backend
else:
    plt.show()