## Preamble

In [2]:
# load packages + declare constants

%load_ext autoreload
%autoreload

import os

import pandas as pd
import numpy as np
import pingouin as pg
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
import plotly.figure_factory as ff
import pingouin as pg
import statsmodels.api as sm

from analysis_utils import *
from scipy.optimize import curve_fit

pio.renderers.default = "vscode"

# --- figure output -------------------------------------------------------------
# Figures are written under FIGS (created here if missing) with extension FIG_TYPE;
# the file names themselves are built by format_fig_path.
FIGS = '../figs'
FIG_TYPE = 'pdf'
os.makedirs(FIGS, exist_ok=True)

# --- language model --------------------------------------------------------------
# Vocabulary size of the LM. NOTE(review): presumably read from the ARPA header via
#   gunzip -c data/local/lm/3-gram.arpa.gz | head -n 3
LM_VOCAB_SIZE = 200_003

# --- page geometry (millimetres converted to pixels at IN_TO_PX px per inch) -----
COL_SIZE_MM = 80
MID_MARGIN_SIZE_MM = 10
MM_TO_IN = 0.03937008
IN_TO_PX = 96
COL_SIZE_PX = int(COL_SIZE_MM * MM_TO_IN * IN_TO_PX)
MID_MARGIN_SIZE_PX = int(MID_MARGIN_SIZE_MM * MM_TO_IN * IN_TO_PX)
# two columns plus the margin between them
DOUBLE_COL_SIZE_PX = 2 * COL_SIZE_PX + MID_MARGIN_SIZE_PX

# --- typography --------------------------------------------------------------
FONT_SIZE = 9
FONT_FAMILY = "Times New Roman"
FONT = {'size': FONT_SIZE, 'family': FONT_FAMILY}

# HTML-formatted symbols for plotly labels: e = error rate, p = accuracy, H = entropy.
# In the Boothroyd figures, subscript i labels the highest-entropy ("out") bin and
# subscript c the lower-entropy ("in") bins.
E_I = "<i>e<sub>i</sub></i>"
E_C = "<i>e<sub>c</sub></i>"
P_I = "<i>p<sub>i</sub></i>"
P_C = "<i>p<sub>c</sub></i>"
H_I = "<i>H<sub>i</sub></i>"
H_C = "<i>H<sub>c</sub></i>"

# --- display names ------------------------------------------------------------
# keys are "<mdl>_<latlm>_<reslm>" (model, lattice LM, rescoring LM)
MDL_LATLM_RESLM2RENAME = {
    'wav2vec2-large-960h-lv60_null_null': 'W2V2-L',
    'wav2vec2-base-960h_null_null': 'W2V2-B',
    'tdnn_1d_sp_tgsmall_tgsmall': 'TDNN-3',
    'tdnn_1d_sp_tgsmall_fglarge': 'TDNN-4',
    'tri6b_tgsmall_tgsmall': 'GMM-3',
}

# keys are the LMs used to compute per-utterance perplexity/entropy
PERPLM2RENAME = {
    'tgsmall': 'Pruned 3-gram',
    'fglarge': '4-gram',
    'rnnlm_lstm_1a': 'RNN',
}

def format_fig_path(prefix : str, **kwargs) -> str:
    """Return the path under FIGS for a figure named by `prefix` and some options.

    Options are appended in sorted key order as ``-<key>`` followed by one
    ``_<value>`` per value (values sorted, with '-' replaced by '_'). A scalar
    option (str, int, float or bool) becomes a single lower-cased string;
    otherwise the option must be a non-empty set/list/tuple of strings, which are
    used without lower-casing. The path ends in ``.<FIG_TYPE>``.

    Example: ``format_fig_path('k', b=True, a=('x-y',))`` gives
    ``f"{FIGS}/k-a_x_y-b_true.{FIG_TYPE}"``.
    """
    segments = [f"{FIGS}/{prefix}"]
    for key in sorted(kwargs):
        vals = kwargs[key]
        if isinstance(vals, (str, int, float, bool)):
            vals = (str(vals).lower(),)
        # each option must end up as a non-empty collection of strings
        assert isinstance(vals, (set, list, tuple))
        assert len(vals)
        assert all(isinstance(x, str) for x in vals)
        segments.append(f'-{key}')
        segments.extend(f"_{val.replace('-', '_')}" for val in sorted(vals))
    segments.append(f'.{FIG_TYPE}')
    return ''.join(segments)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# load tables
# Each table is previewed after loading. The perplexity and per-utterance WER
# tables are inner-joined with the transcript table on `utt` to attach each
# utterance's length (`len`).

print("text_df contents")
text_df = read_text_as_df()
display(text_df.head())
utt_len = text_df[['utt', 'len']]

print("perp_df contents")
perp_df = read_perps_as_df().merge(utt_len, on='utt')
display(perp_df.head())

print("wer_df contents")
wer_df = read_best_wers_as_df()
display(wer_df.head())

print("uttwer_df contents")
uttwer_df = read_best_uttwers_as_df().merge(utt_len, on='utt')
display(uttwer_df.head())

text_df contents


Unnamed: 0,utt,text,part,len
0,lbi-100-121669-0000,TOM THE PIPER'S SON,train-clean-460,4
1,lbi-100-121669-0001,THE PIG WAS EAT AND TOM WAS BEAT AND TOM RAN C...,train-clean-460,15
2,lbi-100-121669-0002,HE NEVER DID ANY WORK EXCEPT TO PLAY THE PIPES...,train-clean-460,36
3,lbi-100-121669-0003,BUT HE WAS SO SLY AND CAUTIOUS THAT NO ONE HAD...,train-clean-460,42
4,lbi-100-121669-0004,AND THEY LIVED ALL ALONE IN A LITTLE HUT AWAY ...,train-clean-460,51


perp_df contents


Unnamed: 0,utt,perp,perplm,part,ent,len
0,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_002583_002756,188.283,tgmed,PRV,5.237946,11
1,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_002781_002912,2046.096,tgmed,PRV,7.623689,4
2,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_006175_006438,839.002,tgmed,PRV,6.732213,14
3,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_006553_006711,399.898,tgmed,PRV,5.99121,8
4,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_008786_008920,2707.916,tgmed,PRV,7.903935,4


wer_df contents


Unnamed: 0,wer,ins,del,sub,lmwt,wip,mdl,latlm,reslm,part,snr,acc
0,0.0524,265,298,2287,11,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9476
1,0.6015,1029,10167,21528,10,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,2.0,0.3985
2,0.2877,740,3749,11163,12,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,8.0,0.7123
3,0.193,589,1905,8006,12,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,11.0,0.807
4,0.8439,441,23839,21628,8,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,-3.0,0.1561


uttwer_df contents


Unnamed: 0,utt,wer,mdl,latlm,reslm,part,snr,acc,len
0,lbi-1272-128104-0000,0.0588,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9412,17
1,lbi-1272-128104-0001,0.1,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9,10
2,lbi-1272-128104-0002,0.0312,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9688,32
3,lbi-1272-128104-0003,0.0417,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9583,24
4,lbi-1272-128104-0004,0.1618,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.8382,68


## Perplexity

In [9]:
print("entropy/perplexity by partition and LM")
# Perplexity is recomputed as exp(mean entropy) (entropy in nats, hence np.exp)
# rather than averaging per-utterance perplexities; one row per partition,
# one column per (statistic, LM).
df = (
    agg_mean_by_lens(perp_df, 'len', 'ent', ['part', 'perplm'])
    .assign(perp=lambda frame: np.exp(frame['ent']))
    .pivot(values=['ent', 'perp'], index='part', columns='perplm')
)
display(df.round(2))


entropy/perplexity by partition and LM


Unnamed: 0_level_0,ent,ent,ent,ent,ent,perp,perp,perp,perp,perp
perplm,fglarge,rnnlm_lstm_1a,tglarge,tgmed,tgsmall,fglarge,rnnlm_lstm_1a,tglarge,tgmed,tgsmall
part,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2
PRV,5.13,5.09,5.15,5.35,5.48,169.04,162.94,172.09,211.34,240.49
ROC,5.44,5.35,5.45,5.61,5.72,229.9,210.93,232.26,273.85,306.19
dev-clean,5.02,4.69,5.15,5.52,5.73,151.64,109.37,173.04,249.27,307.44
dev-other,4.95,4.64,5.08,5.43,5.63,140.83,103.87,161.37,227.9,279.69
hp,7.14,7.24,7.13,7.6,7.87,1262.63,1396.79,1250.38,1999.89,2611.67
lp,8.4,8.26,8.43,8.36,8.35,4427.25,3847.25,4604.73,4255.38,4228.28
test-clean,5.06,4.74,5.19,5.55,5.76,158.3,114.17,179.32,257.72,315.82
test-other,4.98,4.68,5.11,5.46,5.66,145.48,108.24,165.57,235.89,288.2
zp,8.95,9.01,8.98,8.7,8.65,7692.62,8145.53,7911.05,6012.86,5732.55


In [24]:
print('distribution of per-utt entropy by partition and LM')
parts = ('dev-clean', 'dev-other', 'ROC', 'PRV')
perplms = ('Pruned 3-gram', '4-gram', 'RNN')
assert all(v in PERPLM2RENAME.values() for v in perplms)
df = perp_df.loc[perp_df['part'].isin(parts)]
# LMs without a display name map to NaN and are dropped (as is any other row with a NaN)
df = df.assign(perplm=df['perplm'].map(PERPLM2RENAME)).dropna()
fig = px.box(
    df, y='ent', color='perplm', x='part',
    # plotly does not treat '_' as a subscript; use <sub> like E_I/H_I/... above
    labels=dict(ent='<i>H<sub>s</sub></i>', part='Partition', perplm="LM"),
    category_orders=dict(part=parts, perplm=perplms),
)
fig.update_traces(marker=dict(size=4), line=dict(width=1))
fig.update_layout(
    legend=dict(orientation="h", yanchor="bottom", y=1.0),
    yaxis=dict(tickangle=270, title_standoff=5),
    margin=dict(l=0, r=10, t=10, b=30),
    font=FONT,
    width=COL_SIZE_PX, height=COL_SIZE_PX,
)
# fig.show()
# NOTE(review): this is a box plot, but the file keeps its "violin" prefix so existing
# references to the figure still resolve
fig.write_image(format_fig_path("violin-ent", perplms=perplms, parts=parts))

distribution of per-utt entropy by partition and LM






In [23]:
# Normality of per-utterance entropy for each LM (pooled over the partitions below),
# then one-sided Spearman correlations of each utterance's entropy between LMs.
parts = ('dev-clean', 'dev-other', 'ROC', 'PRV')
perplms = ('Pruned 3-gram', '4-gram', 'RNN')
assert all(name in PERPLM2RENAME.values() for name in perplms)
df = perp_df.loc[perp_df['part'].isin(parts)]
df = df.assign(perplm=df['perplm'].map(PERPLM2RENAME)).dropna()
ent_normality = pg.normality(df, dv='ent', group='perplm', method='normaltest')
display(ent_normality.round(3))

print("pairwise spearman correlations of entropy across LMs")
# one row per utterance, one column per LM
df = df.pivot(values='ent', index='utt', columns='perplm')
ent_corr = pg.pairwise_corr(df, columns=df.columns, alternative='greater', method='spearman')
display(ent_corr.round(3))


Unnamed: 0_level_0,W,pval,normal
perplm,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
RNN,2389.619,0.0,False
Pruned 3-gram,694.098,0.0,False
4-gram,1011.655,0.0,False


pairwise spearman correlations of entropy across LMs


Unnamed: 0,X,Y,method,alternative,n,r,CI95%,p-unc,power
0,4-gram,Pruned 3-gram,spearman,greater,17351,0.889,"[0.89, 1.0]",0.0,1.0
1,4-gram,RNN,spearman,greater,17351,0.859,"[0.86, 1.0]",0.0,1.0
2,Pruned 3-gram,RNN,spearman,greater,17351,0.783,"[0.78, 1.0]",0.0,1.0


In [None]:
print("per-utterance perplexity vs. rank by LM")

perplms = ('tgsmall', 'fglarge', 'rnnlm_lstm_1a')
parts = ('dev-clean', 'dev-other', 'ROC', 'PRV')

df = perp_df.loc[perp_df['perplm'].isin(perplms) & perp_df['part'].isin(parts)]
# Normalized rank of each utterance's perplexity within its (LM, partition), in (0, 1].
# assign() returns a new frame, so no column is written into a filtered slice
# (which would trigger pandas' SettingWithCopyWarning).
df = df.assign(
    rank=df.groupby(['perplm', 'part'])['perp'].transform(lambda x: x.rank() / len(x)),
)

fig = px.scatter(
    df, x='rank', y='perp', facet_col='perplm', color='part', log_y=True,
    labels=dict(rank='normalized rank', perp='Perplexity', perplm='LM', part='Partition'),
    category_orders=dict(part=parts, perplm=perplms),
)
fig.update_traces(marker=dict(size=3))
fig.update_layout(
    yaxis=dict(tickangle=45, title_standoff=5),
    margin=dict(l=0, r=10, t=20, b=0),
    font=FONT,
    width=DOUBLE_COL_SIZE_PX, height=COL_SIZE_PX,
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01,
    ),
)
# fig.show()
fig.write_image(format_fig_path('perp-by-rank', perplms=perplms, parts=parts))

## WER

In [None]:
part = 'dev-clean'
latlm = reslm = 'tgsmall'
mdl = 'tdnn_1d_sp'
desc = f"({part} partition, {mdl} model, {latlm} lattice LM, and {reslm} rescoring lm)"
# SNRs (dB) shown in the scatter-plot matrix; each must be one of the decoded SNRs
snr_dims = [5, 10, 20, 30]

df = uttwer_df.loc[
    np.isfinite(uttwer_df['snr']) &
    (uttwer_df['latlm'] == latlm) &
    (uttwer_df['reslm'] == reslm) &
    (uttwer_df['part'] == part) &
    (uttwer_df['mdl'] == mdl)
].copy()
df['snr'] = df['snr'].astype('int')

with pd.option_context('display.max_rows', 10):
    print(f"test of normality of per-utterance WERs given SNR {desc}")
    display(pg.normality(df, dv='wer', group='snr', method='normaltest').round(3).sort_index())

    print(f"spearman correlation of WERs across SNRs {desc}")
    # one row per utterance, one column per SNR
    wer_by_snr = df.pivot(values='wer', index='utt', columns='snr')
    display(
        pg.pairwise_corr(
            wer_by_snr, columns=wer_by_snr.columns, alternative='greater', method='spearman'
        ).round(3).sort_index()
    )

missing_snrs = set(snr_dims) - set(wer_by_snr.columns)
assert not missing_snrs, f"SNRs not available for the scatter matrix: {sorted(missing_snrs)}"

print(f"scatter plot matrix of per-utterance WERs of select SNRs {desc}")
fig = px.scatter_matrix(wer_by_snr, dimensions=snr_dims, opacity=0.1)
# The SPLOM has one x and one y axis per *dimension*; looping over every SNR column
# would add layout entries for axes that do not exist.
fig.update_layout({
    f"{axis}axis{i + 1}": dict(range=[-0.1, 1])
    for axis in 'xy'
    for i in range(len(snr_dims))
})
fig.show()

## Zhang et al. (2023)

In [19]:
# Zhang et al (2023) "Estimate the noise effect on automatic speech recognition
# accuracy for mandarin by an approach associating articulation index"
# FIXME(sdrobert): the fit is very bad if we use eq. 12

latlm = 'tgsmall'
reslm = 'tgsmall'
part = 'dev-other'
desc = f"({part} partition, {latlm} lattice LM, and {reslm} rescoring lm)"
num_points = 100
fit_inverse = False

# Configurations decoded without an LM ('null') are relabelled so they pass the LM filter.
df = wer_df.replace(dict(latlm=dict(null=latlm), reslm=dict(null=reslm)))
df = df.loc[
    (df['latlm'] == latlm) &
    (df['reslm'] == reslm) &
    (df['part'] == part)
].copy()

# Infinite SNR is the noise-free condition: keep its accuracy per model separately.
idx = np.isinf(df['snr'])
# copy() so the in-place percentage scaling below does not write into a filtered slice
df, Ainvs = df.loc[~idx].copy(), df.loc[idx, ['mdl', 'acc']]
snr_min = df['snr'].min() - 1
snr_max = df['snr'].max() + 1
x_interp = np.linspace(snr_min, snr_max, num_points)

mdls = df['mdl'].unique()
# Every model with noisy results must also have a noise-free result. Compare as sets:
# an elementwise == between arrays depends on order and fails when lengths differ.
assert set(mdls) == set(Ainvs['mdl']), (sorted(mdls), sorted(Ainvs['mdl'].unique()))

df['acc'] *= 100

# fit = []
# fig = go.Figure()
# for mdl_idx, mdl in enumerate(mdls):
#     colour = px.colors.qualitative.Plotly[mdl_idx]
#     df_ = df.loc[df['mdl'] == mdl]
#     Ainv = Ainvs.loc[Ainvs['mdl'] == mdl, 'acc'].iloc[0]
#     A_init = 1 / Ainv
#     N = len(df_)
#     x = df_['snr'].array
#     y = df_['acc'].array
#     A, B, C = zhang_fit(x, y, fit_inverse)
#     y_pred = zhang_func(x, A, B, C)
#     r2 = r2_score(y, y_pred)
#     fit.append(dict(mdl=mdl, A=A, B=B, C=C, r2=r2))
#     y_interp = zhang_func(x_interp, A, B, C)
#     fig.add_scatter(
#         x=x_interp, y=y_interp * 100,
#         name=f"{mdl} fit",
#         mode='lines',
#         opacity=0.5,
#         showlegend=False,
#         line=dict(color=colour),
#     )
#     fig.add_scatter(
#         x=x, y=df_['acc'] * 100,
#         name=mdl, mode='markers',
#         marker=dict(color=colour),
#     )
#     # fig.add_annotation(
#     #     x=x_interp[ratio * (mdl_idx + 1)], y=y_interp[ratio * (mdl_idx + 1)] * 100,
#     #     text=f"A={A:.02f},B={B:.02f},C={C:.02f}",
#     #     showarrow=True,
#     #     font=dict(color=colour, size=FONT_SIZE),
#     # )
# print(f"Zhang et al fits by model {desc}")
# display(pd.DataFrame.from_records(fit).round(3))

# Accuracy (%) against SNR, one colour per model, sized to a single column.
fig = px.scatter(df, x='snr', y='acc', color='mdl')
fig.update_traces(marker=dict(size=5), line=dict(width=2))
fig.update_xaxes(title='SNR (dB)', range=[snr_min, snr_max], tickformat='d')
fig.update_yaxes(title='Accuracy (%)', range=[0, 100], tickformat='d')
fig.update_layout(
    legend=dict(title='Model', yanchor="top", y=0.99, xanchor='left', x=0.01),
    margin=dict(l=0, r=10, t=10, b=0),
    font=FONT,
    width=COL_SIZE_PX,
    height=COL_SIZE_PX,
)
fig.show()
# fig.write_image(format_fig_path("zhang", latlm=latlm, reslm=reslm, part=part))






## Perplexity vs. WER

In [None]:
# wer by perp
# Per-utterance WER against per-utterance perplexity in the noise-free condition,
# for a single model / decoding-LM configuration.

perplm = 'tgsmall'

mdl = 'tdnn_1d_sp'
latlm = reslm = 'tgsmall'
part = 'dev-clean'
print(
    f"mdl {mdl}, partition {part}, lattice LM {latlm}, rescore LM {reslm}, "
    f"perplexity LM {perplm}"
)

df = perp_df.loc[(perp_df['perplm'] == perplm) & (perp_df['part'] == part)]
# 'len' is a join key too (both tables carry it), avoiding len_x/len_y duplicates
df = df.merge(uttwer_df.loc[
    (uttwer_df['reslm'] == reslm) &
    (uttwer_df['latlm'] == latlm) &
    (uttwer_df['mdl'] == mdl)
], on=['utt', 'part', 'len'])
# Without noise. The other cells encode the noise-free condition as an infinite SNR,
# so select non-finite SNRs (inf, or NaN) instead of NaN only.
df = df.loc[~np.isfinite(df['snr'])]
assert len(df), "no noise-free utterances match the selected configuration"
# axes are clipped to the central 90% of each variable
ymin, ymax = df['wer'].quantile(0.05), df['wer'].quantile(0.95)
xmin, xmax = df['perp'].quantile(0.05), df['perp'].quantile(0.95)

print("per-utterance WER by perplexity")
fig = px.scatter(df, x='perp', y='wer')
fig.update_xaxes(type='log', range=[np.log10(xmin), np.log10(xmax)])
fig.update_yaxes(range=[ymin, ymax])
fig.show()

## Boothroyd and Nittrouer

In [172]:
# stats
# Boothroyd-style fits: pair each lower-entropy ("in") bin with the highest-entropy
# ("out") bin at the same SNR, partition and model, then fit with boothroyd_fit.

perplm = binlm = 'rnnlm_lstm_1a'

parts = ('dev-clean', 'dev-other', 'ROC', 'PRV')
binpart = 'dev-clean'

num_bins = 3
power = 0.05  # passed to boothroyd_fit as its `alpha` argument

# bin edges come from the binning LM's entropies on `binpart` and are applied to all parts
bounds = bin_series(
    perp_df.loc[(perp_df['perplm'] == binlm) & (perp_df['part'] == binpart), 'ent'],
    num_bins, lower_quant=0.05, upper_quant=0.95, by_rank=True,
)[1]
df = perp_df.loc[(perp_df['perplm'] == perplm) & (perp_df['part'].isin(parts))].copy()
df['ent_bin'] = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
bin_cats = df['ent_bin'].dtype.categories

df = df.merge(uttwer_df, on=['utt', 'part', 'len'])
# Configurations without a display name are dropped explicitly: assigning
# `.map(...).dropna()` back to a column re-aligns on the index and reinserts the NaNs.
df['mdl'] = (df['mdl'] + '_' + df['latlm'] + '_' + df['reslm']).map(MDL_LATLM_RESLM2RENAME)
df = df.dropna(subset=['mdl'])
df = agg_mean_by_lens(df, 'len', 'wer', ['snr', 'part', 'mdl', 'ent_bin'])
df['lwer'] = np.log(df['wer'])
# noisy conditions only, and WER below 100% (log WER < 0)
df = df.loc[np.isfinite(df['snr']) & (df['lwer'] < 0)].dropna()

mask = df['ent_bin'] == bin_cats[num_bins - 1]
df, df_out = df.loc[~mask], df.loc[mask]
df = df.merge(df_out, on=['snr', 'part', 'mdl'], suffixes=('_in', '_out'))
df = df.assign(ent_bin_in=df.ent_bin_in.cat.remove_unused_categories())

df = df.loc[df['mdl'] == 'W2V2-L']

# sensitivity of the fit to excluding low out-of-bin error rates (wer_out > b),
# with and without boothroyd_fit's negative_log_fit option
fits = []
for b in (0.0, 0.1, 0.25, 0.5, 0.75, 0.9):
    df_ = df.loc[df['wer_out'] > b]
    for negative_log_fit in (True, False):
        fit = boothroyd_fit(df_, alpha=power, negative_log_fit=negative_log_fit)
        fit['b'] = b
        fit['nlf'] = negative_log_fit
        fits.append(fit.reset_index())
fits = pd.concat(fits).pivot(index=['nlf', 'b'], columns=['name'], values=['coef', 'se', 'ci_low', 'ci_high'])
display(fits.round(2))

fit = boothroyd_fit(df, alpha=power, negative_log_fit=True)

fig = px.scatter(df, 'wer_out', 'wer_in', color='ent_bin_in', log_x=True, log_y=True)
# the fitted curves are functions of the out-of-bin error rate, i.e. the x axis
x = np.linspace(df['wer_out'].min(), 0.99, 100)
for bin_idx in range(num_bins - 1):
    bin_label = bin_cats[bin_idx]
    fit_ = fit.loc[f"ent_bin_in[{bin_label}]"]
    # NOTE: confidence bands could be drawn from fit_["ci_low"] / fit_["ci_high"]
    fig.add_scatter(
        x=x,
        y=boothroyd_func(x, fit_["coef"]),
        line=dict(color="black"),
        name=f"fit {bin_label}",
    )
fig.show()

# # display(pg.normality(fit['resid'].iloc[0]))

# # norm = pg.normality(df, 'k', 'ent_bin_in')
# # display(norm.round(3))
# # assert not norm.iloc[0]['normal']  # use ANOVA or Welch-ANOVA otherwise
# # assert num_bins == 3  # use kruskal otherwise
# # k_low = df.loc[df['ent_bin_in'] == bin_cats[0], 'k']
# # k_high = df.loc[df['ent_bin_in'] == bin_cats[1], 'k']
# # mwu = pg.mwu(k_low, k_high)
# # display(mwu.round(3))
# # assert mwu.iloc[0]['p-val'] < power

# # norms, krs, pts, frs, lrs = [], [], [], [], []
# # for bin in range(0, num_bins - 1):
# #     bin_ = bin_cats[bin]
# #     df_bin = df.loc[df['ent_bin_in'] == bin_]

# #     norm = pg.normality(df_bin, 'k', 'mdl', alpha=power)
# #     assert not norm.iloc[0]['normal']
# #     norms.append(norm)
# #     kr = pg.kruskal(df_bin, 'k', 'mdl')
# #     assert kr.iloc[0]['p-unc'] < power
# #     krs.append(kr)
# #     pt = pg.pairwise_tests(df_bin, 'k', 'mdl', parametric=False, alpha=power)
# #     pts.append(pt)

# #     for df_ in (norm, kr, pt):
# #         df_['ent_bin_in'] = bin_
# #         df_['dv'] = 'k'
# #         df_['iv'] = 'mdl'

# #     fr = pg.friedman(df_bin, 'k', 'snr', 'mdl')
# #     fr['ent_bin_in'], fr['dv'], fr['iv'] = bin_, 'k', 'snr'
# #     frs.append(kr)

# #     for mdl in df['mdl'].unique():
# #         df_mdl = df_bin.loc[df_bin['mdl'] == mdl]

# #         lr = pg.linear_regression(df_mdl[['lwer_out']], df_mdl['lwer_in'], add_intercept=True)
# #         lr['ent_bin_in'], lr['mdl'], lr['part'] = bin_, mdl, 'all'
# #         lrs.append(lr)
# #         px.scatter(x=np.arange(len(lr.residuals_)), y=lr.residuals_).show()
# #         display(norm)
# #         for part in parts:
# #             df_part = df_mdl.loc[df_mdl['part'] == part]
# #             lr = pg.linear_regression(df_part[['lwer_out']], df_part['lwer_in'], add_intercept=True)
# #             lr['ent_bin_in'], lr['mdl'], lr['part'] = bin_, mdl, part
# #             lrs.append(lr)


# # idx = ['ent_bin_in', 'dv', 'iv']

# # norms = pd.concat(norms).set_index(idx)
# # display(norms.round(3))

# # krs = pd.concat(krs).set_index(idx)
# # display(krs.round(3))

# # frs = pd.concat(frs).set_index(idx)
# # display(frs.round(3))

# # pts = pd.concat(pts).set_index(idx)
# # display(pts.round(3))

# # lrs = pd.concat(lrs).pivot(index=['ent_bin_in', 'mdl', 'part'], values=['coef', 'se', 'pval'], columns=['names'])
# # display(lrs.round(3))

# # x_lims = [1, 101]
# # y_lims = [1, 101]

# # fig = px.scatter(
# #     df.loc[df['mdl'].isin({'TDNN-3', 'W2V2-L'})], x='wer_out', y='wer_in', color='part', symbol='ent_bin_in',
# #     facet_col='mdl',
# #     symbol_sequence=list(range(num_bins - 1)),
# #     category_orders=dict(part=parts),
# #     labels=dict(
# #         wer_out=f"Error rate {E_I} (%) for {H_I} in {bin_cats[num_bins - 1]}",
# #         wer_in=f"Error rate {E_C} (%)",
# #     ),
# #     log_x=True, log_y=True, range_x=x_lims, range_y=y_lims,
# # )
# # for i, trace in enumerate(fig.data):
# #     if trace.mode == 'markers':
# #         name = trace.name.split(', ')
# #         if name[1] in bin_cats[1:]:
# #             trace['name'] = ''
# #             trace['showlegend']=False
# #         else:
# #             trace['name'] = name[0]
# # for bin in range(num_bins - 1):
# #     fig.add_scatter(
# #         y=[None], mode='markers',
# #         marker=dict(color='black', symbol=bin),
# #         legend="legend2",
# #         name=bin_cats[bin],
# #     )
# # for col in range(1, 3):
# #     fig.add_scatter(
# #         x=x_lims,
# #         y=x_lims,
# #         mode='lines',
# #         line=dict(color="grey", width=1, dash='dash'),
# #         showlegend=False,
# #         col=col, row=1,
# #     )
# # fig.update_traces(marker=dict(line_width=1, size=4))
# # fig.update_layout(
# #     width=DOUBLE_COL_SIZE_PX, height=COL_SIZE_PX,
# #     margin=dict(l=0, r=10, t=20, b=0),
# #     font=FONT,
# #     legend=dict(
# #         title_text='Partition',
# #         yanchor="top",
# #         y=0.99,
# #         xanchor="left",
# #         x=0.01,
# #     ),
# #     legend2=dict(
# #         title_text=f"{H_C} range",
# #         yanchor="bottom",
# #         y=0.01,
# #         xanchor="right",
# #         x=0.99,
# #     )
# # )
# # fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
# # fig.write_image(format_fig_path('foo', binlm=binlm))

Unnamed: 0_level_0,Unnamed: 1_level_0,coef,coef,se,se,ci_low,ci_low,ci_high,ci_high
Unnamed: 0_level_1,name,"ent_bin_in[(3.4,4.3]]","ent_bin_in[(4.3,5.0]]","ent_bin_in[(3.4,4.3]]","ent_bin_in[(4.3,5.0]]","ent_bin_in[(3.4,4.3]]","ent_bin_in[(4.3,5.0]]","ent_bin_in[(3.4,4.3]]","ent_bin_in[(4.3,5.0]]"
nlf,b,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
False,0.0,1.31,1.16,0.05,0.05,1.29,1.16,1.42,1.25
False,0.1,1.4,1.23,0.07,0.01,1.36,1.23,1.49,1.25
False,0.25,1.44,1.26,0.04,0.01,1.4,1.25,1.54,1.28
False,0.5,1.58,1.28,0.02,0.02,1.56,1.25,1.61,1.31
False,0.75,1.72,1.32,0.12,0.08,1.7,1.29,1.79,1.32
False,0.9,1.77,1.29,0.09,0.13,1.59,1.18,1.79,1.34
True,0.0,1.5,1.24,0.06,0.03,1.4,1.2,1.62,1.29
True,0.1,1.56,1.27,0.05,0.02,1.47,1.23,1.65,1.29
True,0.25,1.58,1.27,0.04,0.02,1.5,1.24,1.66,1.29
True,0.5,1.64,1.28,0.04,0.02,1.59,1.25,1.74,1.3








In [None]:
# Compare each lower-entropy ("in", subscript c) bin with the highest-entropy
# ("out", subscript i) bin for one model configuration, pairing rows with the same
# SNR and partition. Writes four figures: k, r, acc-ratio and lwer-ratio.

# mdl = 'wav2vec2-base-960h'
mdl = 'wav2vec2-large-960h-lv60'
latlm = reslm = 'null'

# mdl = 'tdnn_1d_sp'
# latlm = 'tgsmall'
# reslm = 'fglarge'

perplm = binlm = 'rnnlm_lstm_1a'

num_bins = 3
binpart = 'dev-clean'
parts = ('dev-clean', 'dev-other', 'ROC', 'PRV')
add_intercept = False
log_lims = 1, 101
# the configuration is encoded into every output file name
cfg = dict(
    mdl=mdl, latlm=latlm, reslm=reslm, perplm=perplm, binlm=binlm, num_bins=num_bins,
    binpart=binpart, parts=parts, intercept=add_intercept
)

# bin edges come from the binning LM's entropies on `binpart` and are applied to all parts
bounds = bin_series(
    perp_df.loc[(perp_df['perplm'] == binlm) & (perp_df['part'] == binpart), 'ent'],
    num_bins, lower_quant=0.05, upper_quant=0.95, by_rank=False,
)[1]
df = perp_df.loc[(perp_df['perplm'] == perplm) & (perp_df['part'].isin(parts))].copy()
df['ent_bin'] = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
bin_cats = df['ent_bin'].dtype.categories

# "in" bins are [min_bin, max_bin); max_bin is the "out" bin
min_bin, max_bin = 0, num_bins - 1

df = df.merge(
    uttwer_df.loc[
        (uttwer_df['reslm'] == reslm) &
        (uttwer_df['latlm'] == latlm) &
        (uttwer_df['mdl'] == mdl)
    ], on=['utt', 'part', 'len'])

df = agg_mean_by_lens(df, 'len', 'wer', ['snr', 'ent_bin', 'part'])
df['lwer'] = np.log(df['wer'])  # log of the fractional WER, taken before scaling to %
df['wer'] *= 100
df['acc'] = 100 - df['wer']
mask = df['ent_bin'] == bin_cats[max_bin]
df, df_out = df.loc[~mask], df.loc[mask]
df = df.merge(df_out, on=['snr', 'part'], suffixes=('_in', '_out'))
df['k'] = df['lwer_in'] / df['lwer_out']
df['r'] = df['wer_out'] / df['wer_in']


def _scatter_by_part_and_bin(df, x, y, labels, **px_kwargs):
    """Scatter `y` vs. `x`, coloured by partition and marker-coded by "in" entropy bin.

    px names each trace "<part>, <ent_bin_in>". Only the traces of the first bin keep a
    legend entry (renamed to the partition); a second legend ("legend2") gets one black
    dummy marker per "in" bin. Reads the cell's `parts`, `bin_cats`, `min_bin`, `max_bin`.
    Extra keyword arguments are forwarded to `px.scatter`.
    """
    fig = px.scatter(
        df, x=x, y=y, color='part', symbol="ent_bin_in",
        symbol_sequence=list(range(min_bin, max_bin)),
        category_orders=dict(part=parts),
        labels=labels,
        **px_kwargs,
    )
    fig.update_traces(marker=dict(line_width=1, size=4))
    for trace in fig.data:
        name = trace.name.split(', ')
        if name[1] in bin_cats[min_bin + 1:]:
            trace['name'] = ''
            trace['showlegend'] = False
        else:
            trace['name'] = name[0]
    for bin_idx in range(min_bin, max_bin):
        fig.add_scatter(
            y=[None], mode='markers',
            marker=dict(color='black', symbol=bin_idx),
            legend="legend2",
            name=bin_cats[bin_idx],
        )
    return fig


# k: ratio of log error rates
fig = _scatter_by_part_and_bin(
    df, 'wer_out', 'k',
    labels=dict(
        k=f"Ratio of log error rates (log {E_C} over log {E_I})",
        wer_out=f"Error rate {E_I} (%) for {H_I} in {bin_cats[max_bin]}",
    ),
)
fig.update_layout(
    width=COL_SIZE_PX, height=COL_SIZE_PX,
    margin=dict(l=0, r=10, t=10, b=0),
    font=FONT,
    xaxis=dict(range=[0, 101]),
    yaxis=dict(range=[1, 3]),
    legend=dict(title_text='Partition', yanchor="top", y=0.99, xanchor="left", x=0.01),
    legend2=dict(title_text=f"{H_C} range", yanchor="top", y=0.99, xanchor="left", x=0.35),
)
# fig.show()
fig.write_image(format_fig_path('k', **cfg))

# r: ratio of error rates
fig = _scatter_by_part_and_bin(
    df, 'wer_out', 'r',
    labels=dict(
        r=f"Ratio of error rates ({E_I} over {E_C})",
        wer_out=f"Error rate {E_I} (%) for {H_I} in {bin_cats[max_bin]}",
    ),
)
fig.update_layout(
    width=COL_SIZE_PX, height=COL_SIZE_PX,
    margin=dict(l=0, r=10, t=10, b=0),
    font=FONT,
    xaxis=dict(range=[0, 101]),
    yaxis=dict(range=[0.9, 3.5]),
    legend=dict(title_text='Partition', yanchor="top", y=0.99, xanchor="right", x=0.99),
    legend2=dict(title_text=f"{H_C} range", yanchor="top", y=0.99, xanchor="right", x=0.65),
)
# fig.show()
fig.write_image(format_fig_path('r', **cfg))

# accuracy of "in" bins vs. the "out" bin, with the identity line for reference
fig = _scatter_by_part_and_bin(
    df, 'acc_out', 'acc_in',
    labels=dict(
        acc_out=f"Accuracy {P_I} (%) for {H_I} in {bin_cats[max_bin]}",
        acc_in=f"Accuracy {P_C} (%)",
    ),
)
fig.add_scatter(
    x=[0, 100],
    y=[0, 100],
    mode='lines',
    line=dict(color="grey", width=1, dash='dash'),
    showlegend=False,
)
fig.update_layout(
    width=DOUBLE_COL_SIZE_PX, height=COL_SIZE_PX,
    margin=dict(l=0, r=10, t=10, b=0),
    font=FONT,
    xaxis=dict(range=[0, 100]),
    yaxis=dict(range=[0, 100]),
    legend=dict(title_text='Partition', yanchor="top", y=0.99, xanchor="left", x=0.01),
    legend2=dict(title_text=f"{H_C} range", yanchor="bottom", y=0.01, xanchor="right", x=0.99),
)
# fig.show()
fig.write_image(format_fig_path('acc-ratio', **cfg))

# error rates of "in" bins vs. the "out" bin on log-log axes, with the identity line
fig = _scatter_by_part_and_bin(
    df, 'wer_out', 'wer_in',
    labels=dict(
        wer_out=f"Error rate {E_I} (%) for {H_I} in {bin_cats[max_bin]}",
        wer_in=f"Error rate {E_C} (%)",
    ),
    log_x=True, log_y=True, range_x=log_lims, range_y=log_lims,
)
fig.add_scatter(
    x=log_lims,
    y=log_lims,
    mode='lines',
    line=dict(color="grey", width=1, dash='dash'),
    showlegend=False,
)
fig.update_layout(
    width=DOUBLE_COL_SIZE_PX, height=COL_SIZE_PX,
    margin=dict(l=0, r=10, t=10, b=0),
    font=FONT,
    legend=dict(title_text="Partition", yanchor="top", y=0.99, xanchor="left", x=0.01),
    legend2=dict(title_text=f"{H_C} range", yanchor="bottom", y=0.01, xanchor="right", x=0.99),
)
fig.show()
fig.write_image(format_fig_path('lwer-ratio', **cfg))



In [None]:
# boothroyd prediction
# Configuration: fits are made on one model/partition/LM setup (train_*); the test_*
# settings are presumably consumed by later cells.
num_bins = 7
train_mdl = 'tdnn_1d_sp'
train_part = 'dev-clean'
train_latlm = train_reslm = 'tgsmall'
train_perplm = 'fglarge'
test_mdls = ('tri6b', 'wav2vec2-large-960h-lv60', 'wav2vec2-base-960h')
test_parts = ('dev-other',)
test_reslms = ('fglarge',)
test_perplms = ('tgmed', 'fglarge')
add_intercept = False

# Keep the SNRs whose training WER lies inside [min_wer, max_wer]. The maximum matters
# most, as high values tend to inflate correlations (i.e. 0.99^k ~= 0.99).
min_wer, max_wer = 0.0, 0.20
train_wers = wer_df.replace(dict(latlm=dict(null=train_latlm), reslm=dict(null=train_reslm)))
is_train_noisy = (
    (train_wers['mdl'] == train_mdl)
    & (train_wers['latlm'] == train_latlm)
    & (train_wers['reslm'] == train_reslm)
    & (train_wers['part'] == train_part)
    & np.isfinite(train_wers['snr'])
)
snr_wer_range = train_wers.loc[is_train_noisy].groupby('snr')['wer'].agg(['min', 'max'])
snr_in_range = (snr_wer_range['min'] >= min_wer) & (snr_wer_range['max'] <= max_wer)
good_snrs = snr_wer_range.index[snr_in_range]
good_snr_min, good_snr_max = good_snrs.min(), good_snrs.max()
good_snr_mid = (good_snr_min + good_snr_max) / 2
print(f"good SNRs: [{good_snr_min}, {good_snr_max}]")

# Entropy bins: edges from the training LM on the training partition, applied to all records.
bounds = bin_series(
    perp_df.loc[(perp_df['perplm'] == train_perplm) & (perp_df['part'] == train_part), 'ent'],
    num_bins,
)[1]
df = perp_df.copy()
df['perp_bin'] = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
bin_cats = df['perp_bin'].dtype.categories

# Noise-free rows (infinite SNR) plus rows at the good SNRs, aggregated per condition and bin.
kept_uttwers = uttwer_df.loc[
    np.isinf(uttwer_df['snr'])
    | uttwer_df['snr'].between(good_snr_min, good_snr_max)
]
df = agg_mean_by_lens(
    df.merge(kept_uttwers, on=['utt', 'part', 'len']),
    'len',
    ['wer', 'ent', 'len'],
    ['snr', 'perp_bin', 'perplm', 'reslm', 'latlm', 'mdl', 'part'],
)
df['lwer'] = np.log(df['wer'])
df = df.replace(dict(latlm=dict(null=train_latlm), reslm=dict(null=train_reslm))).dropna()

train_df = df.loc[
    (df['latlm'] == train_latlm)
    & (df['reslm'] == train_reslm)
    & (df['perplm'] == train_perplm)
    & (df['mdl'] == train_mdl)
    & (df['part'] == train_part)
]

print('train entropy by bin')
display(train_df.groupby('perp_bin', observed=False)[['ent']].mean().round(3))

# (k, c) pairs from entropy alone, keyed like train()'s output: k = (12 - H_in) / (12 - H_out)
# using the first training row's entropy in each bin, and c = 0.
# NOTE(review): 12 looks like an approximation of ln(LM_VOCAB_SIZE) ~= 12.2 -- confirm.
bin_ent = [
    train_df.loc[train_df['perp_bin'] == bin_cats[b], 'ent'].iloc[0]
    for b in range(num_bins)
]
ent_fit = {
    (in_bin, out_bin): ((12 - bin_ent[in_bin]) / (12 - bin_ent[out_bin]), 0)
    for in_bin in range(num_bins)
    for out_bin in range(num_bins)
}

def train(df : pd.DataFrame, verbose: bool = False) -> dict[tuple[int, int], tuple[float,float]]:
    """Fit log WER of every bin against log WER of every bin, pairing rows by SNR.

    Only rows with a finite SNR are used. For each (in_bin, out_bin) pair the
    rows of the two bins are merged on ``snr`` and passed to
    ``boothroyd_fit(lwer_out, lwer_in, add_intercept=add_intercept)``, whose
    result is unpacked as ``(k, c)``.

    Relies on the notebook globals ``num_bins``, ``bin_cats`` and ``add_intercept``.

    NOTE(review): cell In[172] calls ``boothroyd_fit(df, alpha=..., negative_log_fit=...)``
    and treats the result as a DataFrame; confirm this positional (x, y) form is still
    supported by analysis_utils.

    Args:
        df: records with at least ``snr``, ``perp_bin`` and ``lwer`` columns.
        verbose: if True, print the head of each paired table and the fitted
            ``(k, c)`` (previously always printed).

    Returns:
        dict mapping ``(in_bin, out_bin)`` bin indices to the fitted ``(k, c)``.
    """
    df = df.loc[np.isfinite(df['snr'])]
    # select each bin once instead of re-filtering inside the inner loop
    per_bin = [df.loc[df['perp_bin'] == bin_cats[b], ['snr', 'lwer']] for b in range(num_bins)]
    fits = dict()
    for in_bin, df_in in enumerate(per_bin):
        for out_bin, df_out in enumerate(per_bin):
            df_in_out = df_in.merge(df_out, on='snr', suffixes=('_in', '_out'))
            if verbose:
                print(df_in_out.head())
            k, c = boothroyd_fit(
                df_in_out['lwer_out'].to_numpy(),
                df_in_out['lwer_in'].to_numpy(),
                add_intercept=add_intercept,
            )
            if verbose:
                print(k, c)
            fits[(in_bin, out_bin)] = k, c
    return fits

def test(df: pd.DataFrame, fits : dict[(int, int), tuple[float, float]], plot : bool = False) -> pd.DataFrame:
    res = dict()
    is_inf = np.isinf(df['snr'])
    df_nonoise, df = df.loc[is_inf], df.loc[~is_inf]
    for in_bin in range(num_bins):
        df_in = df.loc[df['perp_bin'] == bin_cats[in_bin]]
        wer_true = df_nonoise.loc[df_nonoise['perp_bin'] == bin_cats[in_bin], 'wer'].iloc[0]
        df_in = df_in[['snr', 'lwer']]
        for out_bin in range(num_bins):
            df_out = df.loc[df['perp_bin'] == bin_cats[out_bin]]
            k, c = fits[(in_bin, out_bin)]
            wer_pred = df_nonoise.loc[df_nonoise['perp_bin'] == bin_cats[out_bin], 'wer'].iloc[0] ** k
            df_out = df_out[['snr', 'lwer']]
            df_in_out = df_in.merge(df_out, on='snr', suffixes=('_in', '_out'))
            y_true = df_in_out['lwer_in'].to_numpy()
            y_pred = k * df_in_out['lwer_out'].to_numpy() + c
            r2 = r2_score(y_true, y_pred)
            res[(in_bin, out_bin)] = r2, 100 * wer_true, 100 * wer_pred
    df = pd.DataFrame.from_dict(res, orient='index', columns=['r2', 'wer_true', 'wer_pred'])
    df.sort_index()
    df.index = pd.MultiIndex.from_product([bin_cats] * 2, names=['in_bin', 'out_bin'])
    if plot:
        im = df.reset_index().pivot(values='r2', columns='out_bin', index='in_bin')
        fig = px.imshow(
            im,
            labels=dict(x="out-of-context bin", y="in-context bin", z="R^2"),
            x=bin_cats,
            y=bin_cats,
            zmin=-1,
            text_auto=".3f",
            color_continuous_scale='BrBG',
        )
        fig.show()
    return df

def display_test(df: pd.DataFrame, groupby=None):
    """Display summary statistics of bin-pair test results.

    Same-bin pairs (in_bin == out_bin) are dropped. ``describe()`` is shown
    for r2, the absolute WER error ``wer_diff``, ``wer_true`` and the
    relative error ``wer_prop`` (percent of wer_true). It is shown twice:
    over all remaining pairs, and over the pairs that use neither the first
    nor the last bin (``bin_cats[0]`` / ``bin_cats[-1]``, a notebook global).

    Args:
        df: frame indexed by ``(in_bin, out_bin)`` with columns ``r2``,
            ``wer_true`` and ``wer_pred``.
        groupby: optional column name(s) to summarise per group.
    """
    cols = ['r2', 'wer_diff', 'wer_true', 'wer_prop']
    df = df.reset_index()
    df = df.loc[df['in_bin'] != df['out_bin']].copy()
    df['wer_diff'] = np.abs(df['wer_pred'] - df['wer_true'])
    df['wer_prop'] = df['wer_diff'] / df['wer_true'] * 100

    def summarise(frame: pd.DataFrame) -> pd.DataFrame:
        # describe() over the whole frame, or per group when requested
        grouped = frame.groupby(groupby) if groupby else frame
        return grouped[cols].describe()

    extremes = {bin_cats[0], bin_cats[-1]}
    interior = df.loc[~df['in_bin'].isin(extremes) & ~df['out_bin'].isin(extremes)]
    summary = pd.concat(
        [summarise(df), summarise(interior)],
        keys=['w/ extreme bins', 'w/o extreme bins'],
    )
    display(summary.transpose().round(3))


# print('all equal fit on train')
# display_test(test(
#     train_df,
#     dict((key, (1, 0)) for key in product(range(num_bins), repeat=2)),
#     True
# ))

# Baseline: the entropy-ratio k (intercept 0), scored on the training data.
print('entropy fit on train')
ent_fit_scores = test(train_df, ent_fit, plot=True)
display_test(ent_fit_scores)

# Fit (k, c) for every bin pair on the training data, then score on that same data.
fit = train(train_df)

print('train and test on self')
self_fit_scores = test(train_df, fit, plot=True)
display_test(self_fit_scores)

# for test_mdl in test_mdls:
#     test_df = df.loc[
#         (df['latlm'] == train_latlm) &
#         (df['reslm'] == train_reslm) &
#         (df['perplm'] == train_perplm) &
#         (df['mdl'] == test_mdl) &
#         (df['part'] == train_part)
#     ]

#     print(f"train on {train_mdl}, test on {test_mdl}")
#     display_test(test(test_df, fit, True))

#     print(f"entropy fit on {test_mdl}")
#     display_test(test(test_df, ent_fit, True))


# for test_part in test_parts:
#     test_df = df.loc[
#         (df['latlm'] == train_latlm) &
#         (df['reslm'] == train_reslm) &
#         (df['perplm'] == train_perplm) &
#         (df['mdl'] == train_mdl) &
#         (df['part'] == test_part)
#     ]

#     print(f"train on {train_part}, test on {test_part}")
#     display_test(test(test_df, fit, True))

#     print(f"entropy fit on {test_part}")
#     display_test(test(test_df, ent_fit, True))

# for test_reslm in test_reslms:
#     test_df = df.loc[
#         (df['latlm'] == train_latlm) &
#         (df['reslm'] == test_reslm) &
#         (df['perplm'] == train_perplm) &
#         (df['mdl'] == train_mdl) &
#         (df['part'] == train_part)
#     ]

#     print(f"train on {train_reslm}-rescored, test on {test_reslm}-rescored")
#     display_test(test(test_df, fit, True))

#     print(f"entropy fit on {test_reslm} rescore")
#     display_test(test(test_df, ent_fit, True))

# for test_perplm in test_perplms:
#     test_df = df.loc[
#         (df['latlm'] == train_latlm) &
#         (df['reslm'] == train_reslm) &
#         (df['perplm'] == test_perplm) &
#         (df['mdl'] == train_mdl) &
#         (df['part'] == train_part)
#     ]

#     print(f"partitioned with {train_perplm}, test on {test_perplm} bins")
#     display_test(test(test_df, fit))

#     print(f"entropy fit on {test_perplm} bins")
#     display_test(test(test_df, ent_fit))


## Klakow and Peters

In [None]:
# Klakow and Peters (2002). "Testing the correlation of word error rate and perplexity"
# "... slope a is smaller for tasks that are acoustically more challenging. Hence on
# those tasks larger reductions in PP are needed to obtain a given reduction in WER."

num_bins = 5
# NOTE(review): num_points and max_bin are not used in the visible cells;
# remove them if nothing else reads them.
num_points = 100
perplm = binlm = 'rnnlm_lstm_1a'
binpart = 'dev-clean'

# Recognizer under analysis. Alternative configuration (no lattice/rescoring LM):
#   mdl = 'wav2vec2-large-960h-lv60'
#   latlm = reslm = 'null'
mdl = 'tdnn_1d_sp'
latlm = 'tgsmall'
reslm = 'fglarge'

part = 'dev-clean'
max_bin = num_bins - 1

# Used to name the exported figure.
cfg = dict(
    num_bins=num_bins, perplm=perplm, binlm=binlm, binpart=binpart, mdl=mdl,
    latlm=latlm, reslm=reslm, part=part,
)

print(
    f"mdl {mdl}, part {part} lattice lm {latlm}, rescore lm {reslm} perplexity LM "
    f"{perplm}, bin part {binpart}, bin LM {binlm}"
)

# Utterances of the evaluation partition, scored by the perplexity LM.
df = perp_df.loc[(perp_df['perplm'] == perplm) & (perp_df['part'] == part)].copy()

# Bin edges come from the (possibly different) binning LM / partition.
# lower_quant/upper_quant presumably trim outliers before computing the edges
# (see analysis_utils.bin_series) — TODO confirm.
bin_source = perp_df.loc[(perp_df['perplm'] == binlm) & (perp_df['part'] == binpart), 'ent']
bounds = bin_series(bin_source, num_bins, by_rank=False, lower_quant=0.05, upper_quant=0.95)[1]
bins = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
df['ent_bin'] = bins
bin_cats = df['ent_bin'].dtype.categories

# Aggregated entropy per bin; `x` is the regressor for the per-SNR fits below.
df_ent = agg_mean_by_lens(df, 'len', 'ent', 'ent_bin')
print('entropy by bin')
display(df_ent.round(3))
x = df_ent['ent']

# Attach per-utterance WERs of the chosen recognizer at finite SNRs.
wer_sel = (
    (uttwer_df['reslm'] == reslm)
    & (uttwer_df['latlm'] == latlm)
    & (uttwer_df['mdl'] == mdl)
    & np.isfinite(uttwer_df['snr'])
)
df = df.merge(uttwer_df.loc[wer_sel], on=['utt', 'part', 'len'])
display(df.head())

# Lowest SNR at which the aggregated WER drops below 50%.
wer_by_snr = agg_mean_by_lens(df, 'len', 'wer', ['snr'])
snr_50 = wer_by_snr.loc[wer_by_snr['wer'] < .5, 'snr'].min()
df = agg_mean_by_lens(df, 'len', 'wer', ['snr', 'ent_bin'])

# Fit the Klakow & Peters model in log space, log(WER) = a * H + log(b), once
# per SNR, regressing the per-bin log WER on the per-bin entropy `x`.
# NOTE(review): x and y are paired by position, so y must list the bins in the
# same order as x — presumably guaranteed by agg_mean_by_lens grouping; confirm.
snrs = np.sort(df['snr'].unique())
curve_params_list = []
fits = []
for snr in snrs:
    snr_mask = df['snr'] == snr
    y = np.log(df.loc[snr_mask, "wer"])
    # fit = klakow_fit(x, y, add_intercept=True)
    fit = pg.linear_regression(x, y, True)
    fit['snr'] = snr
    fits.append(fit)
    curve_params_list.append({
        "snr": snr,
        "a": fit.loc[fit['names'] == 'ent', 'coef'].iloc[0],
        # b is reported on the linear scale: WER = b * PP ** a
        "b": np.exp(fit.loc[fit['names'] == 'Intercept', 'coef'].iloc[0]),
        "se": fit.loc[fit['names'] == 'ent', 'se'].iloc[0],
    })
display(pd.concat(fits).round(3))

# Indices into `snrs` / `curve_params_list` of the SNRs whose fitted curves are
# overlaid. The middle index is clamped so the cell still runs when fewer than
# 17 SNR levels are available.
snr_mini, snr_maxi = 0, len(snrs) - 1
snr_midi = min(16, snr_maxi)
# .copy(): the in-place scaling below must act on an independent frame rather
# than on a slice of the aggregated data (avoids SettingWithCopyWarning).
df = df.loc[(df['snr'] >= snrs[snr_mini]) & (df['snr'] <= snrs[snr_maxi])].copy()
df['wer'] *= 100

print("WER by (PP, SNR) with select K & P fits")
fig = px.bar(
    df,
    x='ent_bin',
    y='wer',
    color='snr',
    barmode='overlay',
    color_continuous_scale="viridis",
    opacity=1.0,
    labels={
        'wer': 'WER (%)',
        'ent_bin': 'Entropy range (nats)',
        'snr': 'SNR (dB)',
    }
)
for dict_ in (curve_params_list[snr_mini], curve_params_list[snr_midi], curve_params_list[snr_maxi]):
    # K & P curve at each bin's perplexity exp(x), in percent (klakow_func is
    # presumably b * PP ** a — see analysis_utils).
    y = klakow_func(np.exp(x), dict_['a'], dict_['b']) * 100
    interp_name = f"a={dict_['a']:.03f}, b={dict_['b']:.03f}"
    fig.add_scatter(
        x=bins.dtype.categories,
        y=y,
        showlegend=False,
        name=interp_name,
        mode='markers+lines',
        marker=dict(color='red'), line=dict(color='red'))
    # fig.add_annotation(
    #     x=bins.dtype.categories[2], y=y.iloc[2],
    #     text=interp_name,
    #     showarrow=True,
    #     opacity=1,
    #     font=dict(color="black", size=FONT_SIZE),
    #     bgcolor='white',
    # )
fig.update_layout(
    yaxis=dict(range=[0, 100]),
    font=dict(size=FONT_SIZE),
    margin=dict(l=0, r=0, t=10, b=0, pad=5),
    width=COL_SIZE_PX, height=COL_SIZE_PX,
    coloraxis=dict(colorbar=dict(thickness=20, title=dict(side="right"))),
)
fig.show()
fig.write_image(format_fig_path('kp-over-snr', **cfg))

# df = pd.DataFrame.from_records(curve_params_list)
# df['logb'] = np.log(df['b'])
# df['logb/a'] = df['logb'] / df['a']
# df['b^(1/a)'] = np.exp(df['logb/a'])
# print('K & P model parameter ratio by snr')
# fig = px.scatter(df, x='snr', y='logb/a')
# fig.add_vline(x=snr_50, line_dash='dash', line_color='black', annotation_text='50% acc')
# fig.show()
# print("K & P model parameters by SNR")
# df = pd.melt(df, ['snr'], ['a', 'b', 'b^(1/a)'], var_name='param', value_name='val')
# fig = px.scatter(df, x='snr', y='val', color='param')
# fig.add_vline(x=snr_50, line_dash='dash', line_color='black', annotation_text='50% acc')
# fig.update_layout(yaxis_range=[0, 1])
# fig.show()

# print("Predicted k by bin and snr")
# records = []
# ent_out = x[num_bins - 1]
# for snri, dict_ in enumerate(curve_params_list):
#     a, b = dict_['a'], dict_['b']
#     snr = int(snrs[snri])
#     log_b = np.log(b)
#     lwer_out = a * ent_out + log_b
#     for bin_in in range(num_bins):
#         ent_in = x[bin_in]
#         ratio_name = f'{bin_cats[bin_in]} over {bin_cats[num_bins - 1]}'
#         lwer_in = a * ent_in + log_b
#         k = lwer_in / lwer_out
#         records.append(dict(snr=snr, k=k, ratio_name=ratio_name))
# df = pd.DataFrame.from_records(records)
# fig = px.scatter(df, x='snr', y='k', color='ratio_name')
# fig.add_vline(x=snr_50, line_dash='dash', line_color='black', annotation_text='50% acc')
# fig.show()


In [None]:
# klakow prediction
num_bins = 7
train_mdl = 'tdnn_1d_sp'
train_part = 'dev-clean'
train_latlm = train_reslm = train_perplm = 'tgsmall'
test_mdls = ('tri6b',)
test_parts = ('dev-other',)
test_perplms = ('tgmed', 'fglarge')

# Bin edges are computed once on the training LM / partition and reused for
# every other LM / partition.
train_ent = perp_df.loc[
    (perp_df['perplm'] == train_perplm) & (perp_df['part'] == train_part), 'ent'
]
bounds = bin_series(train_ent, num_bins)[1]
df = perp_df.copy()
df['perp_bin'] = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
bin_cats = df['perp_bin'].dtype.categories

# Klakow's model doesn't have anything to do with SNR: keep clean speech only.
df = df.merge(uttwer_df.loc[np.isinf(uttwer_df['snr'])], on=['utt', 'part', 'len'])
df = agg_mean_by_lens(
    df, 'len', ['wer', 'ent', 'len'],
    ['perp_bin', 'perplm', 'reslm', 'latlm', 'mdl', 'part'],
)
df['lwer'] = np.log(df['wer'])

# Rows belonging to the training configuration.
is_train = (
    (df['latlm'] == train_latlm)
    & (df['reslm'] == train_reslm)
    & (df['perplm'] == train_perplm)
    & (df['mdl'] == train_mdl)
    & (df['part'] == train_part)
)
train_df = df.loc[is_train]

def train(df: pd.DataFrame) -> tuple[float, float]:
    """Fit the Klakow & Peters model log(WER) = a * entropy + log(b).

    Args:
        df: rows with ``ent`` (regressor) and ``lwer`` (log WER) columns.

    Returns:
        ``(a, log_b)``: the entropy slope and the intercept in log space.
    """
    coefs = pg.linear_regression(df['ent'], df['lwer']).set_index('names')['coef']
    return coefs['ent'], coefs['Intercept']

def test(df: pd.DataFrame, fit: tuple[float, float]) -> dict[str, float]:
    wer_true = (df['wer'] * df['len']).sum() / df['len'].sum() * 100
    ent = (df['ent'] * df['len']).sum() / df['len'].sum()
    wer_pred = np.exp(fit[0] * ent + fit[1]) * 100
    y_true = df['lwer'].to_numpy()
    y_pred = fit[0] * df['ent'].to_numpy() + fit[1]
    if len(y_pred) > 1:
        r2 = r2_score(y_true, y_pred)
    else:
        r2 = None
    return dict(r2=r2, wer_true=wer_true, wer_pred=wer_pred)

def display_test(records : list[dict], groupby=None):
    """Display error statistics for one or more K&P evaluation records.

    Adds the absolute WER error ``wer_diff`` and the relative error
    ``wer_prop`` (percent of wer_true). With a single ungrouped record the
    raw row is shown. Otherwise the output is a transposed ``describe()``,
    computed per group when ``groupby`` is given.

    Args:
        records: dicts with ``r2``, ``wer_true`` and ``wer_pred`` keys.
        groupby: optional column name(s) to summarise per group.
    """
    cols = ['r2', 'wer_diff', 'wer_true', 'wer_prop']
    df = pd.DataFrame.from_records(records)
    # r2 may be None; cast so it is numeric (None -> NaN) and kept by describe().
    df['r2'] = df['r2'].astype(float)
    df['wer_diff'] = np.abs(df['wer_pred'] - df['wer_true'])
    df['wer_prop'] = df['wer_diff'] / df['wer_true'] * 100
    if groupby:
        # A GroupBy object cannot be rounded/displayed directly: always summarise.
        df = df.groupby(groupby)[cols].describe().transpose()
    elif len(records) > 1:
        df = df[cols].describe().transpose()
    else:
        df = df[cols]
    display(df.round(3))

print(f"{num_bins}-fold cross-validation")
# Leave-one-bin-out: fit on all other bins, score on the held-out bin.
records = []
for test_bin in range(num_bins):
    test_mask = train_df['perp_bin'] == bin_cats[test_bin]
    records.append(test(train_df[test_mask], train(train_df.loc[~test_mask])))
display_test(records)

print("train and test on self")
fit = train(train_df)
display_test([test(train_df, fit)])


def select_like_train(frame: pd.DataFrame, **overrides) -> pd.DataFrame:
    """Return the rows of ``frame`` matching the training configuration.

    Keyword arguments replace individual columns of that configuration,
    e.g. ``select_like_train(df, mdl='tri6b')``.
    """
    want = dict(
        latlm=train_latlm, reslm=train_reslm, perplm=train_perplm,
        mdl=train_mdl, part=train_part,
    )
    want.update(overrides)
    mask = pd.Series(True, index=frame.index)
    for col, val in want.items():
        mask &= frame[col] == val
    return frame.loc[mask]


for test_mdl in test_mdls:
    test_df = select_like_train(df, mdl=test_mdl)
    print(f"train on {train_mdl}, test on {test_mdl}")
    display_test([test(test_df, fit)])

for test_part in test_parts:
    test_df = select_like_train(df, part=test_part)
    print(f"train on {train_part}, test on {test_part}")
    display_test([test(test_df, fit)])

for test_perplm in test_perplms:
    test_df = select_like_train(df, perplm=test_perplm)
    print(f"train on {train_perplm}, test on {test_perplm}")
    display_test([test(test_df, fit)])


## Thoughts

- $k$ is relatively stable across changes in partition and SNR; more so than $a$ and $b$
    - $R^2$ is inflated by low SNRs by virtue of being near to the intercept
- $k$ can probably be inferred from $a,b$
    - $\log b / a$ stabilizes as SNR increases. Why?
        - Check if $\log b / a$ converges to something else on `dev-other`. Perhaps it's close enough to the `dev-clean` ratio that drastic changes in entropy dominate?
        - Maybe this is compensatory? 
    - Based on its current trajectory, the ratio $\log b / a \approx 12$ will never be exceeded by the entropy of the partition. The corresponding perplexity, $e^{12} \approx 162{,}755$, is in the vicinity of $162{,}000$.
- $k$ can be estimated by a ratio of entropies
    - As speech becomes cleaner, errors are more likely to occur one at a time. Guesswork more closely resembles the perplexity computations, which are conditioned on single words.
- Serious problem with ratio estimates (Curran-Everett). $k$ may be compromised.
    - Easy solution is to include intercepts. Regardless, $k$ can be used to predict with or without explaining.
- Klakow's model, $\mathrm{WER} = b \cdot \mathrm{PP}^a$, predicts a WER of $b$ at $0$ entropy (perplexity $1$). However, $0$ entropy ought to mean $0$ errors.