## Preamble

In [None]:
# load packages
#
# NOTE(review): a bare `%autoreload` reloads all modules once, right now; if
# automatic reloading before every cell was intended, use `%autoreload 2`.
# `analysis_utils` is a project-local module. Its wildcard import presumably
# provides the read_*_as_df, agg_mean_by_lens and bin_series helpers used below;
# an explicit import list would make that dependency visible.

%load_ext autoreload
%autoreload

import pandas as pd
import numpy as np
import pingouin as pg
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go

from itertools import product
from scipy.optimize import curve_fit
from analysis_utils import *
from sklearn.metrics import r2_score

# render plotly figures with the VS Code notebook renderer
pio.renderers.default = "vscode"

In [None]:
# load tables
#
# text_df: per-utterance transcript table. Only its 'utt' and 'len' columns are
# used here (presumably 'len' is the utterance length in words; confirm).
# perp_df and uttwer_df are per-utterance tables. Both get 'len' merged in so
# later cells can aggregate by utterance length.
# wer_df appears to hold corpus-level WERs per configuration; it is merged
# below on mdl/latlm/reslm/part/snr.

print("text_df contents")
text_df = read_text_as_df()
display(text_df.head())

print("perp_df contents")
perp_df = read_perps_as_df()
perp_df = perp_df.merge(text_df[['utt', 'len']], on='utt')
display(perp_df.head())


print("wer_df contents")
wer_df = read_best_wers_as_df()
display(wer_df.head())

print("uttwer_df contents")
uttwer_df = read_best_uttwers_as_df()
uttwer_df = uttwer_df.merge(text_df[['utt', 'len']], on='utt')
display(uttwer_df.head())

# Sanity check: aggregating the per-utterance results by length (presumably a
# length-weighted mean; see agg_mean_by_lens) should reproduce the
# corpus-level numbers in wer_df for every configuration.
# NOTE(review): the message says "WER", but the compared columns are 'acc'
# (acc_x from uttwer_df, acc_y from wer_df). Presumably acc = 1 - wer, so the
# absolute difference is the same.
df = agg_mean_by_lens(uttwer_df, 'len', ['wer', 'acc'], ['mdl', 'latlm', 'reslm', 'part', 'snr'])
df = df.merge(wer_df, on=['mdl', 'latlm', 'reslm', 'part', 'snr'])
print(f"max WER diff (%) btw uttwer and wer: {np.abs(df['acc_x'] - df['acc_y']).max():.01%}")

## Perplexity

In [None]:
print("entropy/perplexity by partition and LM")
# Mean entropy per (partition, perplexity LM), aggregated with utterance length
# via agg_mean_by_lens (presumably length-weighted; TODO confirm).
# Perplexity is exp(mean entropy), not the mean of per-utterance perplexities.
df = agg_mean_by_lens(perp_df, 'len', 'ent', ['part', 'perplm'])
df['perp'] = np.exp(df['ent'])
# one row per partition, one column per (statistic, LM)
df = df.pivot(values=['ent', 'perp'], index='part', columns='perplm')
display(df.round(2))


In [None]:
print('distribution of per-utt entropy by partition and LM')
# One violin (with an inner box plot) per partition, coloured by perplexity LM.
# Entropies are in nats (perplexity = exp(entropy)).
fig = px.violin(
    perp_df, x='ent', y='part', color='perplm',
    box=True,
    # `labels` keys are column names. The colour column is 'perplm', so the
    # former 'lm' key never applied and the legend showed the raw column name.
    labels=dict(ent='Entropy (nats)', part='Partition', perplm="LM"),
    width=600, height=800,
)
fig.show()

In [None]:
print('Test of normality of entropy given LM')
# pingouin's 'normaltest' method is the D'Agostino-Pearson omnibus test; one
# row per perplexity LM.
display(pg.normality(perp_df, dv='ent', group='perplm', method='normaltest').round(3))

print("pairwise spearman correlations of entropy across LMs")
# Wide frame: one row per utterance, one column of entropies per LM.
df = perp_df.pivot(values='ent', index='utt', columns='perplm')
# One-sided test (alternative='greater'): positive rank correlation expected.
display(pg.pairwise_corr(df, columns=df.columns, alternative='greater', method='spearman').round(3))

print("scatter plot matrix of per-utterance entropy of each LM")
fig = px.scatter_matrix(df, dimensions=df.columns, opacity=0.1)
fig.show()

In [None]:
print("per-utterance perplexity vs. rank by LM")

# Rank each utterance's perplexity within its own LM (pandas' default rank
# method gives tied values their average rank).
df = perp_df.copy()
df['rank'] = df.groupby(['perplm'])['perp'].rank()

fig = px.scatter(df, x='rank', y='perp', color='perplm', log_y=True)
fig.show()

## WER

In [None]:
part = 'dev-clean'
latlm = reslm = 'tgsmall'
mdl = 'tdnn_1d_sp'
desc = f"({part} partition, {mdl} model, {latlm} lattice LM, and {reslm} rescoring lm)"

# Per-utterance WERs of one configuration, restricted to finite SNRs. Other
# cells select the no-noise condition with np.isinf, so it is excluded here.
df = uttwer_df.loc[
    np.isfinite(uttwer_df['snr']) &
    (uttwer_df['latlm'] == latlm) &
    (uttwer_df['reslm'] == reslm) &
    (uttwer_df['part'] == part) &
    (uttwer_df['mdl'] == mdl)
].copy()
df['snr'] = df['snr'].astype('int')

with pd.option_context('display.max_rows', 10):
    print(f"test of normality of per-utterance WERs given SNR {desc}")
    display(pg.normality(df, dv='wer', group='snr', method='normaltest').round(3).sort_index())


    print(f"spearman correlation of WERs across SNRs {desc}")
    # wide frame: one row per utterance, one WER column per SNR
    df = df.pivot(values='wer', index='utt', columns='snr')
    display(pg.pairwise_corr(df, columns=df.columns, alternative='greater', method='spearman').round(3).sort_index())

# Only these SNR columns are plotted. The axis-range updates are sized to the
# plotted dimensions, not to every SNR column of df, so no layout entries are
# created for axes that do not exist.
scatter_snrs = [5, 10, 20, 30]
print(f"scatter plot matrix of per-utterance WERs of select SNRs {desc}")
fig = px.scatter_matrix(df, dimensions=scatter_snrs, opacity=0.1)
fig.update_layout({"xaxis"+str(i+1): dict(range = [-0.1, 1]) for i in range(len(scatter_snrs))})
fig.update_layout({"yaxis"+str(i+1): dict(range = [-0.1, 1]) for i in range(len(scatter_snrs))})
fig.show()

In [None]:
# Zhang et al (2023) "Estimate the noise effect on automatic speech recognition
# accuracy for mandarin by an approach associating articulation index"
# FIXME(sdrobert): the fit is very bad if we use eq. 12

latlm = 'tgsmall'
reslm = 'tgsmall'
part = 'dev-clean'
desc = f"({part} partition, {latlm} lattice LM, and {reslm} rescoring lm)"
num_points = 100

df = wer_df.loc[
    (wer_df['latlm'] == latlm) &
    (wer_df['reslm'] == reslm) &
    (wer_df['part'] == part)
].copy()

# No-noise rows carry an infinite SNR. Split them off: Ainvs holds each model's
# clean accuracy, and df keeps only the noisy conditions.
idx = np.isinf(df['snr'])
df, Ainvs = df.loc[~idx], df.loc[idx, ['mdl', 'acc']]
# plotting/interpolation range: one dB beyond the observed SNRs on each side
snr_min = df['snr'].min() - 1
snr_max = df['snr'].max() + 1
x_interp = np.linspace(snr_min, snr_max, num_points)

mdls = df['mdl'].unique()
# Every noisy model must have a clean accuracy and vice versa. Compare as sets:
# the former elementwise array comparison was order-dependent and fails
# outright when the two arrays differ in length.
assert set(mdls) == set(Ainvs['mdl'].unique())
# spacing (in interpolation points) between the per-model fit annotations
ratio = num_points // (len(mdls) + 2)

def zhang_func(x : np.ndarray, A : float, B : float, C : float) -> np.ndarray:
    """Accuracy-vs-SNR curve of Zhang et al. (2023).

    Evaluates ``1 / (A + exp(-(x + B) / C))`` elementwise. For positive ``C``
    the value rises with ``x`` towards ``1 / A``. ``B`` shifts the curve along
    the SNR axis and ``C`` sets how quickly it saturates.
    """
    decay = np.exp(-(x + B) / C)
    return 1 / (decay + A)


# Fit zhang_func to each model's accuracy-vs-SNR points and plot points + fits.
fit = []
fig = go.Figure()
palette = px.colors.qualitative.Plotly
for mdl_idx, mdl in enumerate(mdls):
    # cycle the palette so more models than colours doesn't raise IndexError
    colour = palette[mdl_idx % len(palette)]
    df_ = df.loc[df['mdl'] == mdl]
    # zhang_func tends to 1 / A at high SNR, so the reciprocal of the clean
    # accuracy is a natural starting value for A
    Ainv = Ainvs.loc[Ainvs['mdl'] == mdl, 'acc'].iloc[0]
    A_init = 1 / Ainv
    x = df_['snr'].array
    y = df_['acc'].array
    (A, B, C), _ = curve_fit(
        zhang_func, x, y,
        p0=(A_init, 0, 1),
        # A >= 1 keeps the predicted accuracy below 1; C >= 0.01 keeps the
        # slope positive and away from division by ~0
        bounds=([1, -np.inf, 0.01], [np.inf, np.inf, np.inf]),
    )
    y_pred = zhang_func(x, A, B, C)
    r2 = r2_score(y, y_pred)
    fit.append(dict(mdl=mdl, A=A, B=B, C=C, r2=r2))
    # same curve as the fitted model, evaluated on the interpolation grid
    y_interp = zhang_func(x_interp, A, B, C)
    fig.add_scatter(
        x=x, y=df_['acc'] * 100,
        name=mdl, mode='markers',
        marker=dict(color=colour),
    )
    fig.add_scatter(
        x=x_interp, y=y_interp * 100,
        mode='lines',
        opacity=0.5,
        showlegend=False,
        line=dict(color=colour),
    )
    # stagger the annotations along the curve so they don't overlap
    fig.add_annotation(
        x=x_interp[ratio * (mdl_idx + 1)], y=y_interp[ratio * (mdl_idx + 1)] * 100,
        text=f"A={A:.02f},B={B:.02f},C={C:.02f}",
        showarrow=True,
        font=dict(color=colour),
    )
print(f"Zhang et al fits by model {desc}")
display(pd.DataFrame.from_records(fit).round(3))

print(f"accuracy (inv. WER) by SNR across models w/ Zhang et al fits {desc}")
fig.update_layout(
    xaxis_title="SNR (dB)",
    yaxis_title="accuracy (%)",
    legend_title="model",
    xaxis_tickformat='d',
    yaxis_tickformat='d',
    xaxis_range=[snr_min, snr_max],
    yaxis_range=[0, 100],
)
fig.show()


## Perplexity vs. WER

In [None]:
# wer by perp
#
# Per-utterance WER vs. perplexity for one decoding configuration on
# no-noise speech.

mdl = 'tdnn_1d_sp'
latlm = reslm = perplm = 'tgsmall'
num_points = 100
part = 'dev-clean'
print(
    f"mdl {mdl}, partition {part}, lattice LM {latlm}, rescore LM {reslm}, "
    f"perplexity LM {perplm}"
)

df = perp_df.loc[(perp_df['perplm'] == perplm) & (perp_df['part'] == part)]
# 'len' is a key as well: both tables carry the same per-utterance length
# (merged from text_df). Merging on it, as the other cells do, avoids
# duplicate len_x/len_y columns.
df = df.merge(uttwer_df.loc[
    (uttwer_df['reslm'] == reslm) &
    (uttwer_df['latlm'] == latlm) &
    (uttwer_df['mdl'] == mdl)
], on=['utt', 'part', 'len'])
# Without noise. The other cells encode the no-noise condition as an infinite
# SNR (np.isinf), which isnull() never matched. ~isfinite covers inf and NaN.
df = df.loc[~np.isfinite(df['snr'])]
# clip both axes to the central 90% of the data so outliers don't dominate
ymin, ymax = df['wer'].quantile(0.05), df['wer'].quantile(0.95)
xmin, xmax = df['perp'].quantile(0.05), df['perp'].quantile(0.95)

print("per-utterance WER by perplexity")
fig = px.scatter(df, x='perp', y='wer')
fig.update_xaxes(type='log', range=[np.log10(xmin), np.log10(xmax)])
fig.update_yaxes(range=[ymin, ymax])
fig.show()

In [None]:
# boothroyd's k
#
# Boothroyd & Nittrouer relate the in-context error e_in to the out-of-context
# error e_out by e_in = e_out ** k, with an extra multiplicative exp(c) when
# add_intercept is set. Utterances are binned by entropy, and the
# highest-entropy bin stands in for "out of context". Each bin's WER is
# regressed on it across SNR conditions in log space:
# ln(e_in) = k * ln(e_out) + c.

latlm = reslm = perplm = binlm = 'tgsmall'
mdl = 'tdnn_1d_sp'
num_bins = 5
num_points = 100
binpart = 'dev-clean'
part = 'dev-clean'
x_interp = np.linspace(0.01, 100, num_points)  # out-of-context accuracy (%) grid
add_intercept = True
print(
    f"mdl {mdl}, part {part} lattice lm {latlm}, rescore lm {reslm} perplexity LM "
    f"{perplm}, bin part {binpart}, bin LM {binlm}"
)

df = perp_df.loc[(perp_df['perplm'] == perplm) & (perp_df['part'] == part)].copy()
# Bin edges come from the (binlm, binpart) entropies and are then applied to
# the (perplm, part) entropies.
bounds = bin_series(perp_df.loc[(perp_df['perplm'] == binlm) & (perp_df['part'] == binpart), 'ent'], num_bins)[1]
df['ent_bin'] = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
bin_cats = df['ent_bin'].dtype.categories
out_cat = bin_cats[num_bins - 1]  # highest-entropy bin: "out of context"

print("mean entropy by bin and ratio (highest/bin)")
df_ent = agg_mean_by_lens(df, 'len', 'ent', 'ent_bin')
df_ent['ratio'] = df_ent.loc[df_ent['ent_bin'] == out_cat, 'ent'].iloc[0] / df_ent['ent']
display(df_ent.round(3))

df = df.merge(
    uttwer_df.loc[
        (uttwer_df['reslm'] == reslm) &
        (uttwer_df['latlm'] == latlm) &
        (uttwer_df['mdl'] == mdl)
    ], on=['utt', 'part', 'len'])

# mean WER per (SNR, entropy bin), pivoted to one column per bin
df = agg_mean_by_lens(df, 'len', 'wer', ['snr', 'ent_bin'])
df = df.pivot(values='wer', index='snr', columns='ent_bin')

fig_acc, fig_loge = go.Figure(), go.Figure()
x = df[out_cat]  # out-of-context error rate (fraction), one value per SNR
# Log-axis limits in percent: 1 point below the smallest error rate, up to 30%.
# NOTE(review): yields NaN/-inf if the smallest error rate is <= 1%.
log_x_lims = np.log10(100 * x.min() - 1), np.log10(30)
log_y_lims = np.log10(100 * df[bin_cats[0]].min() - 1), np.log10(30)
log_x_interp = np.linspace(log_x_lims[0], np.log10(100), num_points)

fits = []
for bin_idx in range(num_bins):
    bin_cat = bin_cats[bin_idx]
    y = df[bin_cat]  # in-context error rate (fraction) of this bin
    fit : pd.DataFrame = pg.linear_regression(np.log(x), np.log(y), add_intercept=add_intercept)
    iv_name, int_name = f"k {bin_cat}", f"c {bin_cat}"
    # the regressor is named after the x Series, i.e. the out-of-context bin label
    fit['names'] = fit['names'].map({out_cat: iv_name, "Intercept": int_name})
    if add_intercept:
        c = fit.loc[fit['names'] == int_name, 'coef'].iloc[0]
    else:
        c = 0
    fits.append(fit)
    k = fit.loc[fit['names'] == iv_name, 'coef'].iloc[0]
    colour = px.colors.qualitative.Plotly[bin_idx]
    # accuracy form: acc_in = 100 * (1 - exp(c) * (1 - acc_out / 100) ** k)
    y_interp = 100 * (1 - np.exp(c) * (1 - x_interp / 100) ** k)
    interp_name = f"k={k:.02f}" + (f", c={c:.02f}" if add_intercept else "")
    fig_acc.add_scatter(
        x=100 - x * 100, y=100 - y * 100,
        name=bin_cat,
        mode='markers',
        legendgroup="points",
        marker=dict(color=colour),
    )
    fig_acc.add_scatter(
        x=x_interp, y=y_interp,
        name=interp_name,
        legendgroup="fits",
        mode='lines', opacity=0.5,
        line=dict(color=colour),
    )
    fig_loge.add_scatter(
        x=100 * x, y=100 * y,
        name=bin_cat,
        mode='markers',
        legendgroup="points",
        marker=dict(color=colour),
    )
    # Error-rate form in log10 space with percentages (E = 100 * e):
    #   log10(E_in) = k * (log10(E_out) - 2) + c / ln(10) + 2
    # c was fit with natural logs, so it must be converted to base 10 here.
    # Adding it unconverted shifted the plotted fit whenever add_intercept is set.
    y_interp = 10 ** (k * (log_x_interp - np.log10(100)) + c / np.log(10) + np.log10(100))
    fig_loge.add_scatter(
        x=10 ** (log_x_interp), y=y_interp,
        mode='lines',
        opacity=0.5,
        name=interp_name,
        legendgroup="fits",
        line=dict(color=colour),
    )
print("Boothroyd & Nittrouer model fits")
display(pd.concat(fits).round(3))

print("in-context vs. out-of-context accuracy and B & N fits")
fig_acc.update_layout(
    xaxis_title="out-of-context accuracy (%)",
    yaxis_title="in-context accuracy (%)",
    xaxis_tickformat='d',
    yaxis_tickformat='d',
    xaxis_range=[0, 100],
    yaxis_range=[0, 100],
    width=800, height=400,
)
fig_acc.show()
print("in-context vs. out-of-context error rates and B & N fits")
fig_loge.update_layout(
    xaxis_title="out-of-context error rate (%)",
    yaxis_title="in-context error rate (%)",
    width=800, height=400,
)
fig_loge.update_xaxes(type='log', range=log_x_lims)
fig_loge.update_yaxes(type='log', range=log_y_lims)
fig_loge.show()


In [None]:
# Klakow and Peters (2002). "Testing the correlation of word error rate and perplexity"
# "... slope a is smaller for tasks that are acoustically more challenging. Hence on
# those tasks larger reductions in PP are needed to obtain a given reduction in WER." 
#
# Model: WER = b * PP ** a. With PP = exp(entropy) this becomes
# log(WER) = a * entropy + log(b), which is fit below once per SNR across
# entropy bins.

latlm = reslm = perplm = binlm = 'tgsmall'
mdl = 'tdnn_1d_sp'
num_bins = 5
num_points = 100
part = binpart = 'dev-clean'
print(
    f"mdl {mdl}, part {part} lattice lm {latlm}, rescore lm {reslm} perplexity LM "
    f"{perplm}, bin part {binpart}, bin LM {binlm}"
)

def klakow_func(perp : np.ndarray, a : float, b: float) -> np.ndarray:
    """Klakow & Peters (2002) power law relating WER to perplexity.

    Returns ``b * perp ** a`` elementwise, i.e.
    ``log(WER) = a * log(perp) + log(b)``.
    """
    powered = perp ** a
    return b * powered

df = perp_df.loc[(perp_df['perplm'] == perplm) & (perp_df['part'] == part)].copy()
# Bin edges come from the (binlm, binpart) entropies and are then applied to
# the (perplm, part) entropies.
bounds = bin_series(perp_df.loc[(perp_df['perplm'] == binlm) & (perp_df['part'] == binpart), 'ent'], num_bins)[1]
bins = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
df['ent_bin'] = bins
bin_cats = df['ent_bin'].dtype.categories
# Mean entropy per bin: the regressor of every per-SNR fit below.
# NOTE(review): x is used positionally (one entry per bin, presumably in bin
# order); the per-SNR WERs below are assumed to come out in the same order.
x = agg_mean_by_lens(df, 'len', 'ent', 'ent_bin')['ent']
print("entropy by bin")
display(x.round(3))

# noisy conditions only (finite SNR)
df = df.merge(uttwer_df.loc[
    (uttwer_df['reslm'] == reslm) &
    (uttwer_df['latlm'] == latlm) &
    (uttwer_df['mdl'] == mdl) &
    np.isfinite(uttwer_df['snr'])
], on=['utt', 'part', 'len'])
display(df.head())
# lowest SNR whose overall WER is below 30% (70% accuracy); marked on the k plot
snr_30 = agg_mean_by_lens(df, 'len', 'wer', ['snr'])
snr_30 = snr_30.loc[snr_30['wer'] < .3, 'snr'].min()
df = agg_mean_by_lens(df, 'len', 'wer', ['snr', 'ent_bin'])

snrs = df['snr'].unique()
snrs.sort()
fits = []
curve_params_list = []
for snr in snrs:
    # log(WER) = a * entropy + log(b) across the entropy bins at this SNR
    snr_mask = df['snr'] == snr
    y = np.log(df.loc[snr_mask, "wer"])
    fit : pd.DataFrame = pg.linear_regression(x, y)
    curve_params_list.append({
        "snr": snr,
        "a": fit.loc[fit['names'] == 'ent', 'coef'].iloc[0],
        "b": np.exp(fit.loc[fit['names'] == 'Intercept', 'coef'].iloc[0]),
    })
    iv_name, int_name = f"iv {int(snr)}", f"int {int(snr)}"
    fit['names'] = fit['names'].map({'ent': iv_name, "Intercept": int_name})
    fits.append(fit)
print("regression fits for Klakow and Peters models")
display(pd.concat(fits).round(3))

# Indices into snrs / curve_params_list of the low, mid and high SNR fits to
# overlay. NOTE(review): hard-coded; needs at least 17 noisy SNR conditions.
snr_mini, snr_midi, snr_maxi = 10, 16, len(snrs) - 1
# .copy(): the WER is rescaled in place next, which on a .loc slice triggers
# SettingWithCopyWarning.
df = df.loc[(df['snr'] >= snrs[snr_mini]) & (df['snr'] <= snrs[snr_maxi])].copy()
df['wer'] *= 100  # percent

print("WER by (PP, SNR) with select K & P fits")
fig = px.bar(df, x='ent_bin', y='wer', color='snr', barmode='overlay', color_continuous_scale="viridis", opacity=1.0)
for dict_ in (curve_params_list[snr_mini], curve_params_list[snr_midi], curve_params_list[snr_maxi]):
    # model WER (%) at each bin's mean perplexity exp(entropy)
    y = klakow_func(np.exp(x), dict_['a'], dict_['b']) * 100
    interp_name = f"a={dict_['a']:.03f}, b={dict_['b']:.03f} WER ∈ [{y.min():.02f},{y.max():.02f}]"
    fig.add_scatter(
        x=bins.dtype.categories,
        y=y,
        showlegend=False,
        name=interp_name,
        mode='markers+lines',
        marker=dict(color='red'), line=dict(color='red'))
    fig.add_annotation(
        x=bins.dtype.categories[0], y=y.iloc[0],
        text=interp_name,
        showarrow=True,
        opacity=1,
        font=dict(color="black"),
        bgcolor='white',
    )
fig.update_layout(
    yaxis_range=[0, 100]
)
fig.show()

df = pd.DataFrame.from_records(curve_params_list)
df['logb/a'] = np.log(df['b']) / df['a']
print('K & P model parameter ratio by snr')
fig = px.scatter(df, x='snr', y='logb/a')
fig.show()
print("K & P model parameters by SNR")
df = pd.melt(df, ['snr'], ['a', 'b'], var_name='param', value_name='val')
fig = px.scatter(df, x='snr', y='val', color='param')
fig.update_layout(yaxis_range=[0, 1])
fig.show()

print("Predicted k by bin and snr")
# Under the fitted K & P model, Boothroyd's k for an (in, out) bin pair is
# log(WER_in) / log(WER_out), with the highest-entropy bin as "out".
records = []
ent_out = x[num_bins - 1]
for snri, dict_ in enumerate(curve_params_list):
    a, b = dict_['a'], dict_['b']
    snr = int(snrs[snri])
    log_b = np.log(b)
    lwer_out = a * ent_out + log_b
    for bin_in in (0, num_bins // 2, num_bins - 1):
        ent_in = x[bin_in]
        ratio_name = f'{bin_cats[bin_in]} over {bin_cats[num_bins - 1]}'
        lwer_in = a * ent_in + log_b
        k = lwer_in / lwer_out
        records.append(dict(snr=snr, k=k, ratio_name=ratio_name))
df = pd.DataFrame.from_records(records)
fig = px.scatter(df, x='snr', y='k', color='ratio_name')
fig.add_vline(x=snr_30, line_dash='dash', line_color='black', annotation_text='70% acc')
fig.show()


In [None]:
# boothroyd prediction
#
# Fit Boothroyd & Nittrouer's k for every pair of perplexity bins on one
# training configuration, using log WERs across noisy SNRs. Then check how well
# those k predict each bin's no-noise WER from another bin's, both on the
# training configuration and on other models, partitions and perplexity LMs.
num_bins = 7
train_mdl = 'tdnn_1d_sp'
train_part = 'dev-clean'
train_latlm = train_reslm = train_perplm = 'tgsmall'
test_mdls = ('tri6b',)
test_parts = ('dev-other',)
test_perplms = ('tgmed', 'fglarge')
add_intercept = False

# Determine SNRs which don't have extremal values. It is more important to set
# the max, as high values tend to inflate correlations (e.g. 0.99^k ~= 0.99).
# For each noisy SNR of the training configuration, take the min/max corpus
# WER; keep the SNRs whose WERs all fall inside [min_wer, max_wer].
min_wer, max_wer = 0.0, 0.20
df = wer_df.loc[
    (wer_df['mdl'] == train_mdl) &
    (wer_df['latlm'] == train_latlm) &
    (wer_df['reslm'] == train_reslm) &
    (wer_df['part'] == train_part) &
    np.isfinite(wer_df['snr'])
].groupby('snr')['wer'].agg(['min', 'max'])
good_snrs = df.index[(df['min'] >= min_wer) & (df['max'] <= max_wer)]
# NOTE(review): if no SNR qualifies, these are NaN and every noisy row is
# filtered out below.
good_snr_min, good_snr_max = good_snrs.min(), good_snrs.max()
good_snr_mid = (good_snr_min + good_snr_max) / 2  # not used by the live code of this cell
print(f"good SNRs: [{good_snr_min}, {good_snr_max}]")

# all records we'll consider
df = perp_df.copy()
# bin edges from the training (perplm, part) entropies, applied to every row
bounds = bin_series(
    perp_df.loc[
        (perp_df['perplm'] == train_perplm) &
        (perp_df['part'] == train_part)
    , 'ent'], num_bins)[1]
df['perp_bin'] = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
bin_cats = df['perp_bin'].dtype.categories

# keep the no-noise rows (infinite SNR) plus the good SNR range
df = df.merge(
    uttwer_df.loc[
        np.isinf(uttwer_df['snr']) |
        ((uttwer_df['snr'] >= good_snr_min) & (uttwer_df['snr'] <= good_snr_max))
    ], on=['utt', 'part', 'len'])
# one row per (SNR, bin, configuration), aggregated with utterance length
df = agg_mean_by_lens(
    df,
    'len',
    ['wer', 'ent', 'len'],
    ['snr', 'perp_bin', 'perplm', 'reslm', 'latlm', 'mdl', 'part'],
)
# NOTE(review): a bin with 0 WER gives -inf here
df['lwer'] = np.log(df['wer'])

train_df = df.loc[
    (df['latlm'] == train_latlm) &
    (df['reslm'] == train_reslm) &
    (df['perplm'] == train_perplm) &
    (df['mdl'] == train_mdl) &
    (df['part'] == train_part)
]

print('train entropy by bin')
display(train_df.groupby('perp_bin', observed=False)[['ent']].mean().round(3))
# Entropy-ratio baseline: k = ent_out / ent_in and c = 0, so a lower-entropy
# in-bin gets k > 1. Each bin's entropy is read from its first training row
# (presumably identical across SNR rows, since the same utterances fall in each
# bin; TODO confirm).
ent_fit = dict()
for in_bin in range(num_bins):
    ent_in = train_df.loc[train_df['perp_bin'] == bin_cats[in_bin], 'ent'].iloc[0]
    for out_bin in range(num_bins):
         ent_out = train_df.loc[train_df['perp_bin'] == bin_cats[out_bin], 'ent'].iloc[0]
         ent_fit[(in_bin, out_bin)] = ent_out / ent_in, 0

def train(df : pd.DataFrame) -> dict[tuple[int, int], tuple[float,float]]:
    """Fit Boothroyd & Nittrouer's k for every (in-context, out-of-context) bin pair.

    For each ordered pair of perplexity bins, the in-bin log WER is regressed on
    the out-bin log WER over the noisy (finite) SNRs the two bins share:
    ``lwer_in = k * lwer_out + c``. ``c`` is 0 unless the notebook global
    ``add_intercept`` is set.

    Also reads the notebook globals ``num_bins`` and ``bin_cats``.

    Returns a dict mapping ``(in_bin, out_bin)`` index pairs to ``(k, c)``.
    """
    noisy = df.loc[np.isfinite(df['snr'])]
    per_bin = [
        noisy.loc[noisy['perp_bin'] == bin_cats[b], ['snr', 'lwer']]
        for b in range(num_bins)
    ]
    fits = {}
    for in_bin, out_bin in product(range(num_bins), repeat=2):
        paired = per_bin[in_bin].merge(per_bin[out_bin], on='snr', suffixes=('_in', '_out'))
        reg = pg.linear_regression(
            paired['lwer_out'],
            paired['lwer_in'],
            add_intercept=add_intercept,
        )
        slope = reg.loc[reg['names'] == 'lwer_out', 'coef'].iloc[0]
        offset = reg.loc[reg['names'] == 'Intercept', 'coef'].iloc[0] if add_intercept else 0
        fits[(in_bin, out_bin)] = slope, offset
    return fits

def test(df: pd.DataFrame, fits : dict[tuple[int, int], tuple[float, float]], plot : bool = False) -> pd.DataFrame:
    """Score Boothroyd & Nittrouer fits on ``df``.

    For each ordered pair of perplexity bins ``(in_bin, out_bin)`` with fit
    ``fits[(in_bin, out_bin)] == (k, c)``:

    * ``r2``: R^2 of ``lwer_in ~ k * lwer_out + c`` over the noisy (finite)
      SNRs shared by both bins;
    * ``wer_true``: the in-bin WER (%) in the no-noise (infinite SNR) condition;
    * ``wer_pred``: that WER (%) predicted from the out-bin no-noise WER,
      ``exp(c) * wer_out ** k``.

    Reads the notebook globals ``num_bins`` and ``bin_cats``. When ``plot`` is
    true, also shows a heatmap of R^2 by bin pair.

    Returns a frame indexed by ``(in_bin, out_bin)`` category labels with
    columns ``r2``, ``wer_true`` and ``wer_pred``.
    """
    res = dict()
    is_inf = np.isinf(df['snr'])
    # no-noise rows supply the WERs to predict; noisy rows supply the R^2
    df_nonoise, df = df.loc[is_inf], df.loc[~is_inf]
    for in_bin in range(num_bins):
        df_in = df.loc[df['perp_bin'] == bin_cats[in_bin], ['snr', 'lwer']]
        wer_true = df_nonoise.loc[df_nonoise['perp_bin'] == bin_cats[in_bin], 'wer'].iloc[0]
        for out_bin in range(num_bins):
            df_out = df.loc[df['perp_bin'] == bin_cats[out_bin], ['snr', 'lwer']]
            k, c = fits[(in_bin, out_bin)]
            wer_out = df_nonoise.loc[df_nonoise['perp_bin'] == bin_cats[out_bin], 'wer'].iloc[0]
            # lwer_in = k * lwer_out + c  =>  wer_in = exp(c) * wer_out ** k.
            # The intercept used to be dropped here, which is only right for c == 0.
            wer_pred = np.exp(c) * wer_out ** k
            df_in_out = df_in.merge(df_out, on='snr', suffixes=('_in', '_out'))
            y_true = df_in_out['lwer_in'].to_numpy()
            y_pred = k * df_in_out['lwer_out'].to_numpy() + c
            r2 = r2_score(y_true, y_pred)
            res[(in_bin, out_bin)] = r2, 100 * wer_true, 100 * wer_pred
    df = pd.DataFrame.from_dict(res, orient='index', columns=['r2', 'wer_true', 'wer_pred'])
    # res was filled in (in_bin, out_bin) row-major order, the same order as
    # product(bin_cats, bin_cats), so the rows can be relabelled with the bin
    # categories directly. The former bare df.sort_index() discarded its
    # result and did nothing.
    df.index = pd.MultiIndex.from_product([bin_cats] * 2, names=['in_bin', 'out_bin'])
    if plot:
        im = df.reset_index().pivot(values='r2', columns='out_bin', index='in_bin')
        fig = px.imshow(
            im,
            labels=dict(x="out-of-context bin", y="in-context bin", z="R^2"),
            x=bin_cats,
            y=bin_cats,
            zmin=-1,
            text_auto=".3f",
            color_continuous_scale='BrBG',
        )
        fig.show()
    return df

def display_test(df: pd.DataFrame, groupby=None):
    """Display summary statistics of a Boothroyd ``test`` result frame.

    Diagonal pairs (``in_bin == out_bin``) are dropped. ``describe()`` is shown
    for:

    * ``r2``;
    * ``wer_diff``: the absolute WER error, in percentage points;
    * ``wer_true``;
    * ``wer_prop``: ``wer_diff`` as a percentage of ``wer_true``.

    The statistics are computed twice: once over all off-diagonal pairs and
    once excluding any pair that involves the lowest or highest bin. If
    ``groupby`` is given, it is passed to ``DataFrame.groupby`` so the
    statistics are computed per group.

    Reads the notebook global ``bin_cats``.
    """
    stat_cols = ['r2', 'wer_diff', 'wer_true', 'wer_prop']
    # Move the (in_bin, out_bin) index levels, and any grouping levels, into
    # columns. A second reset_index() here only added an unused 'index' column.
    df = df.reset_index()
    df = df.loc[df['in_bin'] != df['out_bin']].copy()
    df['wer_diff'] = np.abs(df['wer_pred'] - df['wer_true'])
    df['wer_prop'] = df['wer_diff'] / df['wer_true'] * 100
    df_with = (df.groupby(groupby) if groupby else df)[stat_cols].describe()
    df = df.loc[
        (df['in_bin'] != bin_cats[0]) &
        (df['in_bin'] != bin_cats[-1]) &
        (df['out_bin'] != bin_cats[0]) &
        (df['out_bin'] != bin_cats[-1])
    ]
    df_wo = (df.groupby(groupby) if groupby else df)[stat_cols].describe()
    df = pd.concat([df_with, df_wo], keys=['w/ extreme bins', 'w/o extreme bins'])
    display(df.transpose().round(3))


# Baseline: k = 1, c = 0 for every pair, i.e. predict each bin's WER to be the
# other bin's.
print('all equal fit on train')
display_test(test(
    train_df,
    {key: (1, 0) for key in product(range(num_bins), repeat=2)}
))

print('entropy fit on train')
display_test(test(train_df, ent_fit, True))

# print(f'split by SNR {good_snr_mid} and train/test on quadrants')
# res = dict()
# for train_split, test_split in product(("low", "high"), repeat=2):
#     if train_split == "low":
#         fit = train(train_df.loc[train_df['snr'] <= good_snr_mid])
#     else:
#         fit = train(train_df.loc[train_df['snr'] > good_snr_mid])
#     if test_split == "low":
#         scores = test(train_df.loc[train_df['snr'] <= good_snr_mid], fit)
#     else:
#         scores = test(train_df.loc[train_df['snr'] > good_snr_mid], fit)
#     res[(train_split, test_split)] = scores
# res = pd.concat(res.values(), keys=res.keys(), names=['train SNR', 'test SNR'])
# display_test(res, ["train SNR", "test SNR"])

# regression-based k for every bin pair on the training configuration
fit = train(train_df)

print('train and test on self')
display_test(test(train_df, fit))

# transfer the training fit (and the entropy baseline) to other acoustic models
for test_mdl in test_mdls:
    test_df = df.loc[
        (df['latlm'] == train_latlm) &
        (df['reslm'] == train_reslm) &
        (df['perplm'] == train_perplm) &
        (df['mdl'] == test_mdl) &
        (df['part'] == train_part)
    ]

    print(f"train on {train_mdl}, test on {test_mdl}")
    display_test(test(test_df, fit))

    print(f"entropy fit on {test_mdl}")
    display_test(test(test_df, ent_fit))


# ... to other partitions
for test_part in test_parts:
    test_df = df.loc[
        (df['latlm'] == train_latlm) &
        (df['reslm'] == train_reslm) &
        (df['perplm'] == train_perplm) &
        (df['mdl'] == train_mdl) &
        (df['part'] == test_part)
    ]

    print(f"train on {train_part}, test on {test_part}")
    display_test(test(test_df, fit))

    print(f"entropy fit on {test_part}")
    display_test(test(test_df, ent_fit))

# ... and to other perplexity LMs (the heading now names the training
# perplexity LM rather than the training partition)
for test_perplm in test_perplms:
    test_df = df.loc[
        (df['latlm'] == train_latlm) &
        (df['reslm'] == train_reslm) &
        (df['perplm'] == test_perplm) &
        (df['mdl'] == train_mdl) &
        (df['part'] == train_part)
    ]

    print(f"train on {train_perplm}, test on {test_perplm}")
    display_test(test(test_df, fit, True))

    print(f"entropy fit on {test_perplm}")
    display_test(test(test_df, ent_fit, True))


In [None]:
# klakow prediction
#
# Fit Klakow & Peters' log(WER) = a * entropy + log(b) on the no-noise per-bin
# WERs of one training configuration. Then check how well it predicts the
# overall WER of held-out bins, other models, partitions and perplexity LMs.
num_bins = 7
train_mdl = 'tdnn_1d_sp'
train_part = 'dev-clean'
train_latlm = train_reslm = train_perplm = 'tgsmall'
test_mdls = ('tri6b',)
test_parts = ('dev-other',)
test_perplms = ('tgmed', 'fglarge')

df = perp_df.copy()
# bin edges from the training (perplm, part) entropies, applied to every row
bounds = bin_series(
    perp_df.loc[
        (perp_df['perplm'] == train_perplm) &
        (perp_df['part'] == train_part)
    , 'ent'], num_bins)[1]
df['perp_bin'] = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
bin_cats = df['perp_bin'].dtype.categories

# Klakow's model doesn't have anything to do with SNR, so only the no-noise
# (infinite SNR) rows are kept.
df = df.merge(
    uttwer_df.loc[
        np.isinf(uttwer_df['snr'])
    ], on=['utt', 'part', 'len'])
# one row per (bin, configuration), aggregated with utterance length
df = agg_mean_by_lens(
    df,
    'len',
    ['wer', 'ent', 'len'],
    ['perp_bin', 'perplm', 'reslm', 'latlm', 'mdl', 'part'],
)
# NOTE(review): a bin with 0 WER gives -inf here
df['lwer'] = np.log(df['wer'])

train_df = df.loc[
    (df['latlm'] == train_latlm) &
    (df['reslm'] == train_reslm) &
    (df['perplm'] == train_perplm) &
    (df['mdl'] == train_mdl) &
    (df['part'] == train_part)
]

def train(df: pd.DataFrame) -> tuple[float, float]:
    """Fit Klakow & Peters' model ``log(WER) = a * entropy + log(b)``.

    Regresses ``df['lwer']`` on ``df['ent']`` (least squares with an
    intercept) and returns ``(a, log_b)``.
    """
    coefs = pg.linear_regression(df['ent'], df['lwer']).set_index('names')['coef']
    slope, intercept = coefs['ent'], coefs['Intercept']
    return slope, intercept

def test(df: pd.DataFrame, fit: tuple[float, float]) -> dict[str, float]:
    """Score a Klakow & Peters fit ``(a, log_b)`` on ``df``.

    Returns a dict with:

    * ``wer_true``: the ``len``-weighted mean WER (%) of ``df``;
    * ``wer_pred``: the model's WER (%) at the ``len``-weighted mean entropy;
    * ``r2``: R^2 of the per-row log-WER predictions, or ``None`` when ``df``
      has fewer than two rows.
    """
    a, log_b = fit[0], fit[1]
    weights = df['len']
    total = weights.sum()
    mean_ent = (df['ent'] * weights).sum() / total
    lwer_pred = a * df['ent'].to_numpy() + log_b
    if len(lwer_pred) > 1:
        score = r2_score(df['lwer'].to_numpy(), lwer_pred)
    else:
        score = None
    return dict(
        r2=score,
        wer_true=(df['wer'] * weights).sum() / total * 100,
        wer_pred=np.exp(a * mean_ent + log_b) * 100,
    )

def display_test(records : list[dict], groupby=None):
    """Display Klakow & Peters test results.

    Adds ``wer_diff`` (absolute WER error, in percentage points) and
    ``wer_prop`` (``wer_diff`` as a percentage of ``wer_true``). With more than
    one record, shows the transposed ``describe()`` summary of ``r2``,
    ``wer_diff``, ``wer_true`` and ``wer_prop``; otherwise shows those columns
    as they are.
    """
    results = pd.DataFrame.from_records(records)
    results['wer_diff'] = np.abs(results['wer_pred'] - results['wer_true'])
    results['wer_prop'] = results['wer_diff'] / results['wer_true'] * 100
    source = results.groupby(groupby) if groupby else results
    table = source[['r2', 'wer_diff', 'wer_true', 'wer_prop']]
    if len(records) > 1:
        table = table.describe().transpose()
    display(table.round(3))

# Leave-one-bin-out cross-validation on the training configuration: fit on all
# other bins and test on the held-out one (presumably a single row, in which
# case r2 is None).
print(f"{num_bins}-fold cross-validation")
records = []
for test_bin in range(num_bins):
    test_mask = train_df['perp_bin'] == bin_cats[test_bin]
    records.append(test(train_df.loc[test_mask], train(train_df.loc[~test_mask])))
display_test(records)

print("train and test on self")
fit = train(train_df)
display_test([test(train_df, fit)])

# transfer the training fit to other acoustic models
for test_mdl in test_mdls:
    test_df = df.loc[
        (df['latlm'] == train_latlm) &
        (df['reslm'] == train_reslm) &
        (df['perplm'] == train_perplm) &
        (df['mdl'] == test_mdl) &
        (df['part'] == train_part)
    ]

    print(f"train on {train_mdl}, test on {test_mdl}")
    display_test([test(test_df, fit)])


# ... to other partitions
for test_part in test_parts:
    test_df = df.loc[
        (df['latlm'] == train_latlm) &
        (df['reslm'] == train_reslm) &
        (df['perplm'] == train_perplm) &
        (df['mdl'] == train_mdl) &
        (df['part'] == test_part)
    ]

    print(f"train on {train_part}, test on {test_part}")
    display_test([test(test_df, fit)])

# ... and to other perplexity LMs (the heading now names the training
# perplexity LM rather than the training partition)
for test_perplm in test_perplms:
    test_df = df.loc[
        (df['latlm'] == train_latlm) &
        (df['reslm'] == train_reslm) &
        (df['perplm'] == test_perplm) &
        (df['mdl'] == train_mdl) &
        (df['part'] == train_part)
    ]

    print(f"train on {train_perplm}, test on {test_perplm}")
    display_test([test(test_df, fit)])


## Thoughts

- $k$ is relatively stable under changes in partition and SNR, more so than $a$ and $b$
    - $R^2$ is inflated by low SNRs by virtue of being near to the intercept
- $k$ can probably be inferred from $a,b$
    - $\log b / a$ stabilizes as SNR increases. Lower variance, maybe? How do we reconcile this with $k$ being reasonably robust to changes in partition?
        - Check if $\log b / a$ converges to something else on `dev-other`. Perhaps it's close enough to the `dev-clean` ratio that drastic changes in entropy dominate?
    - Based on its current trajectory, the ratio of $\log b / a \approx 12$ will never be exceeded by the entropy of the partition. The corresponding perplexity is in the vicinity of $162,000$.
- $k$ can be estimated by a ratio of entropies
    - As speech becomes cleaner, errors are more likely to occur one at a time. Guesswork more closely resembles the perplexity computations, which are conditioned on single words.
- Serious problem with ratio estimates (Curran-Everett). $k$ may be compromised.
    - Easy solution is to include intercepts. Regardless, $k$ can be used to predict with or without explaining.
- Klakow's model ($\mathrm{WER} = b \cdot \mathrm{PP}^a$) predicts a WER of $b$ at $0$ entropy (perplexity $1$). However, $0$ entropy ought to mean $0$ errors.