## Preamble

In [31]:
# load packages + declare constants

%load_ext autoreload
%autoreload

import os

import pandas as pd
import numpy as np
import pingouin as pg
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
import pingouin as pg

from analysis_utils import *

# Plotly defaults for every figure in this notebook.
pio.renderers.default = "vscode"
pio.templates.default = "plotly_white"

# Unigram count from the n-gram LM's ARPA header (presumably its vocabulary size), via
# gunzip -c data/local/lm/3-gram.arpa.gz | head -n 3
LM_VOCAB_SIZE = 200_003

# Where figures are written and in which format; the directory is created up front.
FIGS = '../figs'
os.makedirs(FIGS, exist_ok=True)
FIG_TYPE = 'pdf'

# Page geometry in millimetres, converted to whole pixels at IN_TO_PX pixels per inch.
COL_SIZE_MM = 80
MID_MARGIN_SIZE_MM = 10

MM_TO_IN = 0.03937008
IN_TO_PX = 96


def _mm_to_px(length_mm):
    """Convert a length in millimetres to pixels, truncated to an int."""
    return int(length_mm * MM_TO_IN * IN_TO_PX)


COL_SIZE_PX = _mm_to_px(COL_SIZE_MM)
MID_MARGIN_SIZE_PX = _mm_to_px(MID_MARGIN_SIZE_MM)

# two columns plus the gutter between them
DOUBLE_COL_SIZE_PX = 2 * COL_SIZE_PX + MID_MARGIN_SIZE_PX

# Figure typography.
FONT_SIZE = 8
FONT_FAMILY = "Times New Roman"
FONT = {"size": FONT_SIZE, "family": FONT_FAMILY}

# HTML-formatted symbols for plotly axis titles. In the B&N plots the "i" quantities label the
# out-of-context (last) entropy bin and the "c" quantities the in-context bins.
E_I = "<i>e<sub>i</sub></i>"
E_C = "<i>e<sub>c</sub></i>"
P_I = "<i>p<sub>i</sub></i>"
P_C = "<i>p<sub>c</sub></i>"
H_I = "<i>H<sub>i</sub></i>"
H_C = "<i>H<sub>c</sub></i>"
H_Y = "<i>H<sub>y</sub></i>"


def _short_names(mapping):
    """Return a rename mapping's short names as a tuple, in insertion order."""
    return tuple(mapping.values())


# '<mdl>_<latlm>_<reslm>' -> short model name used in tables and figures
MDL_LATLM_RESLM2RENAME = {
    'wav2vec2-large-960h-lv60_null_null': 'W2V2-L',
    'wav2vec2-base-960h_null_null': 'W2V2-B',
    'tdnn_1d_sp_tgsmall_tgsmall': 'TDNN-3',
    'tdnn_1d_sp_tgsmall_fglarge': 'TDNN-4',
    # 'tdnn_1d_sp_tgsmall_rnnlm_lstm_1a': 'TDNN-R',
    'tri6b_tgsmall_tgsmall': 'GMM-3',
}
MDL_RENAMES = _short_names(MDL_LATLM_RESLM2RENAME)

# perplexity LM -> short name; LMs missing here are dropped from the tables
PERPLM2RENAME = {
    'tgsmall': '3-gram',
    'fglarge': '4-gram',
    'rnnlm_lstm_1a': 'RNN',
}
PERPLM_RENAMES = _short_names(PERPLM2RENAME)

# data partition -> short name; also the row order of the tables
PART2RENAME = {
    'PRV': 'CL-P',
    'ROC': 'CL-R',
    'dev-other': 'LS-O',
    'dev-clean': 'LS-C',
}
PART_RENAMES = _short_names(PART2RENAME)

# LM whose entropies are tabulated (PERPLM) and used to place the bin edges (BINLM).
PERPLM = 'rnnlm_lstm_1a'
BINLM = PERPLM
# Bin edges are taken from BINLM's entropies on BIN_PART between these quantiles.
BIN_QUANT_LOWER = 0.05
# NOTE(review): 0.05 reads like a significance level (alpha) rather than statistical power;
# POWER is not used in the visible notebook -- confirm intent before relying on it.
POWER = BIN_QUANT_LOWER
BIN_QUANT_UPPER = 0.95
BIN_PART = 'dev-clean'
BIN_NAMES = ('HP', 'LP', 'ZP')  # order in ascending ent

def format_fig_path(prefix : str, **kwargs) -> str:
    """Build the output path of a figure under ``FIGS`` with extension ``FIG_TYPE``.

    Keyword arguments are appended in sorted key order as ``-{key}``, followed by
    ``_{value}`` for each of that key's values (values sorted). Every value is
    lowercased, and any ``-`` inside a value becomes ``_``, so ``-`` only ever
    separates keys.

    Example: ``format_fig_path('acc-ratio', mdl='W2V2-L')`` gives
    ``'{FIGS}/acc-ratio-mdl_w2v2_l.{FIG_TYPE}'``.

    Args:
        prefix: Base name of the figure file.
        **kwargs: Each value is a scalar (str, int, float, bool) or a non-empty
            set/frozenset/list/tuple of strings.

    Returns:
        The figure path as a string.

    Raises:
        TypeError: If a value is neither a scalar nor a collection of strings.
        ValueError: If a collection value is empty.
    """
    parts = [f"{FIGS}/{prefix}"]
    for key, vals in sorted(kwargs.items()):
        if isinstance(vals, (str, int, float, bool)):
            vals = (str(vals),)
        if not isinstance(vals, (set, frozenset, list, tuple)) or not all(isinstance(x, str) for x in vals):
            raise TypeError(f"value of {key!r} must be a scalar or a collection of strings, got {vals!r}")
        if not vals:
            raise ValueError(f"value of {key!r} must not be empty")
        parts.append(f'-{key}')
        # Lowercase collection members too (previously only scalars were lowercased),
        # so a value maps to the same file name however it is passed.
        for val in sorted(v.lower() for v in vals):
            parts.append(f"_{val.replace('-', '_')}")
    parts.append(f'.{FIG_TYPE}')
    return ''.join(parts)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# load tables

# Transcripts: one row per utterance, with its length in words.
print("text_df contents")
text_df = read_text_as_df()
display(text_df.head())
utt_lens = text_df[['utt', 'len']]

# Per-utterance perplexity/entropy under each LM, bucketed into entropy bins. The bin edges
# come from BINLM's entropies on BIN_PART only and are then applied to every LM and partition.
# Entropies outside the edges end up without a bin (NaN in the output).
print("perp_df contents")
perp_df = read_perps_as_df().merge(utt_lens, on='utt')
is_bin_source = (perp_df['perplm'] == BINLM) & (perp_df['part'] == BIN_PART)
bin_bounds = bin_series(
    perp_df.loc[is_bin_source, 'ent'],
    len(BIN_NAMES),
    lower_quant=BIN_QUANT_LOWER,
    upper_quant=BIN_QUANT_UPPER,
    by_rank=False,
)[1]
ent_bin = bin_series(perp_df['ent'], bin_bounds, by_rank=False, fmt="{:.01f}")[0]
# interval label (ascending entropy) -> short bin name
bin_cats = {interval: name for interval, name in zip(ent_bin.dtype.categories, BIN_NAMES)}
print(bin_cats)
perp_df = perp_df.assign(ent_bin=ent_bin.map(bin_cats))
display(perp_df.head())

# Corpus-level WERs.
print("wer_df contents")
wer_df = read_best_wers_as_df()
display(wer_df.head())

# Per-utterance WERs, with utterance lengths attached.
print("uttwer_df contents")
uttwer_df = read_best_uttwers_as_df().merge(utt_lens, on='utt')
display(uttwer_df.head())

text_df contents


Unnamed: 0,utt,text,part,len
0,lbi-100-121669-0000,TOM THE PIPER'S SON,train-clean-460,4
1,lbi-100-121669-0001,THE PIG WAS EAT AND TOM WAS BEAT AND TOM RAN C...,train-clean-460,15
2,lbi-100-121669-0002,HE NEVER DID ANY WORK EXCEPT TO PLAY THE PIPES...,train-clean-460,36
3,lbi-100-121669-0003,BUT HE WAS SO SLY AND CAUTIOUS THAT NO ONE HAD...,train-clean-460,42
4,lbi-100-121669-0004,AND THEY LIVED ALL ALONE IN A LITTLE HUT AWAY ...,train-clean-460,51


perp_df contents
{'(3.4,4.5]': 'HP', '(4.5,5.6]': 'LP', '(5.6,6.8]': 'ZP'}


Unnamed: 0,utt,perp,perplm,part,ent,len,ent_bin
0,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_002583_002756,188.283,tgmed,PRV,5.237946,11,LP
1,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_002781_002912,2046.096,tgmed,PRV,7.623689,4,
2,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_006175_006438,839.002,tgmed,PRV,6.732213,14,ZP
3,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_006553_006711,399.898,tgmed,PRV,5.99121,8,ZP
4,PRV_se0_ag1_f_01_PRV_se0_ag1_f_01_1_008786_008920,2707.916,tgmed,PRV,7.903935,4,


wer_df contents


Unnamed: 0,wer,ins,del,sub,lmwt,wip,mdl,latlm,reslm,part,snr,acc
0,0.0524,265,298,2287,11,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9476
1,0.6015,1029,10167,21528,10,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,2.0,0.3985
2,0.2877,740,3749,11163,12,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,8.0,0.7123
3,0.193,589,1905,8006,12,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,11.0,0.807
4,0.8439,441,23839,21628,8,0.0,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,-3.0,0.1561


uttwer_df contents


Unnamed: 0,utt,wer,mdl,latlm,reslm,part,snr,acc,len
0,lbi-1272-128104-0000,0.0588,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9412,17
1,lbi-1272-128104-0001,0.1,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9,10
2,lbi-1272-128104-0002,0.0312,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9688,32
3,lbi-1272-128104-0003,0.0417,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.9583,24
4,lbi-1272-128104-0004,0.1618,tdnn_1d_sp,tgsmall,tgsmall,dev-clean,30.0,0.8382,68


## Perplexity

In [67]:
# Share (%) of each partition's utterances that fall in each entropy bin under PERPLM.
# 'total' is below 100 because utterances outside the bin edges are dropped by dropna().
print("proportion bins")
df = (
    perp_df.loc[perp_df['perplm'] == PERPLM]
    .groupby('part')['ent_bin']
    .value_counts(normalize=True, dropna=False)
    .reset_index()
)
df = (
    df.assign(part=df['part'].map(PART2RENAME))
    .dropna()
    .pivot(index='part', columns='ent_bin', values='proportion')
    .reindex(index=PART_RENAMES, columns=BIN_NAMES)
)
df = df.assign(total=df.sum(axis=1)).mul(100).round(1)
display(df)
print(df.to_latex(float_format="{:.01f}".format))

proportion bins


ent_bin,HP,LP,ZP,total
part,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
CL-P,23.8,41.9,23.4,89.1
CL-R,18.3,38.7,26.9,84.0
LS-O,39.9,39.5,11.2,90.6
LS-C,37.2,40.1,12.6,89.9


\begin{tabular}{lrrrr}
\toprule
ent_bin & HP & LP & ZP & total \\
part &  &  &  &  \\
\midrule
CL-P & 23.8 & 41.9 & 23.4 & 89.1 \\
CL-R & 18.3 & 38.7 & 26.9 & 84.0 \\
LS-O & 39.9 & 39.5 & 11.2 & 90.6 \\
LS-C & 37.2 & 40.1 & 12.6 & 89.9 \\
\bottomrule
\end{tabular}



In [None]:
# Mean entropy per (partition, LM); agg_mean_by_lens presumably weights by utterance length.
print("entropy/perplexity by partition and LM")
df = agg_mean_by_lens(perp_df, 'len', 'ent', ['part', 'perplm'])
df = (
    df.assign(perplm=df['perplm'].map(PERPLM2RENAME), part=df['part'].map(PART2RENAME))
    .dropna()  # drop LMs / partitions without a short name
    .pivot(index='part', columns='perplm', values='ent')
    .reindex(index=PART_RENAMES, columns=PERPLM_RENAMES)
)
# marginals: a 'mean' row over partitions first, then a 'mean' column over LMs (incl. that row)
df.loc['mean', :] = df.mean(axis=0)
df['mean'] = df.mean(axis=1)
display(df.round(1))
print(df.to_latex(float_format="{:.01f}".format))


In [None]:
print('distribution of per-utt entropy by partition and LM')
# NOTE(review): this is a box plot even though the output file is named "violin-ent"; the
# name is kept so existing references to the figure file keep working.
df = (
    perp_df
    .assign(perplm=perp_df['perplm'].map(PERPLM2RENAME), part=perp_df['part'].map(PART2RENAME))
    .drop(columns='ent_bin')  # keep utterances that fall outside every entropy bin
    .dropna()
)
fig = px.box(
    df,
    x='part',
    y='ent',
    color='perplm',
    labels={'ent': H_Y, 'part': 'Partition', 'lm': "LM", 'perplm': "LM"},
    category_orders={'part': PART_RENAMES, 'perplm': PERPLM_RENAMES},
)
fig.update_traces(marker={'size': 4}, line={'width': 1})
fig.update_layout(
    font=FONT,
    width=COL_SIZE_PX,
    height=COL_SIZE_PX,
    margin={'l': 0, 'r': 10, 't': 10, 'b': 30},
    legend={'orientation': "h", 'yanchor': "bottom", 'y': 1.0},
    yaxis={'tickangle': 270, 'title_standoff': 5},
)
fig.write_image(format_fig_path("violin-ent"))

In [None]:
df = (
    perp_df
    .assign(perplm=perp_df['perplm'].map(PERPLM2RENAME), part=perp_df['part'].map(PART2RENAME))
    .drop(columns='ent_bin')  # keep utterances that fall outside every entropy bin
    .dropna()
)
# per-LM normality test of the per-utterance entropies
display(pg.normality(df, dv='ent', group='perplm', method='normaltest').round(3))

print("pairwise spearman correlations of entropy across LMs")
# one row per utterance, one column per LM
df = df.pivot(index='utt', columns='perplm', values='ent')
spearman = pg.pairwise_corr(df, columns=df.columns, alternative='greater', method='spearman')
display(spearman.round(3))


## WER

In [68]:
def _model_name(frame):
    """Map each row's mdl/latlm/reslm triple to its short model name (NaN if unlisted)."""
    return (frame['mdl'] + '_' + frame['latlm'] + '_' + frame['reslm']).map(MDL_LATLM_RESLM2RENAME)


# Clean-speech (snr == inf) WER per partition, entropy bin and model, plus each partition's
# corpus-level WER as the 'all' bin.
clean_uttwer = uttwer_df.loc[np.isinf(uttwer_df['snr'])]
df = perp_df.loc[perp_df['perplm'] == PERPLM].merge(clean_uttwer, on=['utt', 'part', 'len'])
df = df.assign(mdl=_model_name(df), part=df['part'].map(PART2RENAME))
df = agg_mean_by_lens(df, 'len', 'wer',  ['part', 'mdl', 'ent_bin'])

clean_wer = wer_df.loc[np.isinf(wer_df['snr'])]
df_all = clean_wer.assign(
    ent_bin='all',
    mdl=_model_name(clean_wer),
    part=clean_wer['part'].map(PART2RENAME),
)
df = pd.concat([df, df_all])[['part', 'mdl', 'ent_bin', 'wer']].dropna()

df = df.pivot(index=['part', 'ent_bin'], columns='mdl', values='wer').mul(100).round(1)
display(df)
print(df.to_latex(float_format="{:.01f}".format))



Unnamed: 0_level_0,mdl,GMM-3,TDNN-3,TDNN-4,W2V2-B,W2V2-L
part,ent_bin,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
CL-P,HP,73.5,56.2,52.4,50.6,37.8
CL-P,LP,79.2,64.1,62.1,59.2,47.0
CL-P,ZP,83.3,69.0,68.6,66.0,54.3
CL-P,all,78.7,61.9,59.4,58.2,46.2
CL-R,HP,45.8,31.2,26.4,25.2,14.6
CL-R,LP,54.0,38.5,35.2,31.5,20.8
CL-R,ZP,58.1,43.6,41.6,38.2,26.0
CL-R,all,53.9,37.3,33.9,32.8,23.2
LS-C,HP,8.4,3.7,2.4,2.2,1.5
LS-C,LP,11.1,4.9,3.5,3.3,2.2


\begin{tabular}{llrrrrr}
\toprule
 & mdl & GMM-3 & TDNN-3 & TDNN-4 & W2V2-B & W2V2-L \\
part & ent_bin &  &  &  &  &  \\
\midrule
\multirow[t]{4}{*}{CL-P} & HP & 73.5 & 56.2 & 52.4 & 50.6 & 37.8 \\
 & LP & 79.2 & 64.1 & 62.1 & 59.2 & 47.0 \\
 & ZP & 83.3 & 69.0 & 68.6 & 66.0 & 54.3 \\
 & all & 78.7 & 61.9 & 59.4 & 58.2 & 46.2 \\
\cline{1-7}
\multirow[t]{4}{*}{CL-R} & HP & 45.8 & 31.2 & 26.4 & 25.2 & 14.6 \\
 & LP & 54.0 & 38.5 & 35.2 & 31.5 & 20.8 \\
 & ZP & 58.1 & 43.6 & 41.6 & 38.2 & 26.0 \\
 & all & 53.9 & 37.3 & 33.9 & 32.8 & 23.2 \\
\cline{1-7}
\multirow[t]{4}{*}{LS-C} & HP & 8.4 & 3.7 & 2.4 & 2.2 & 1.5 \\
 & LP & 11.1 & 4.9 & 3.5 & 3.3 & 2.2 \\
 & ZP & 16.2 & 7.8 & 5.9 & 6.6 & 4.4 \\
 & all & 10.5 & 4.7 & 3.3 & 3.3 & 2.2 \\
\cline{1-7}
\multirow[t]{4}{*}{LS-O} & HP & 22.1 & 10.0 & 6.5 & 6.3 & 3.2 \\
 & LP & 28.4 & 13.1 & 9.7 & 10.0 & 5.2 \\
 & ZP & 37.0 & 18.7 & 15.3 & 16.2 & 8.5 \\
 & all & 26.1 & 12.2 & 8.7 & 8.8 & 4.6 \\
\cline{1-7}
\bottomrule
\end{tabular}



## Zhang et al. (2023)

In [None]:
# Zhang et al (2023) "Estimate the noise effect on automatic speech recognition
# accuracy for mandarin by an approach associating articulation index"
# FIXME(sdrobert): the fit is very bad if we use eq. 12
#
# Plots corpus-level accuracy (%) against SNR for every model on one partition. Systems
# without their own LMs (latlm/reslm == 'null') are relabelled to the chosen LMs so they
# survive the filter.

latlm = 'tgsmall'
reslm = 'tgsmall'
part = 'dev-other'
desc = f"({part} partition, {latlm} lattice LM, and {reslm} rescoring lm)"
num_points = 100
fit_inverse = False

df = wer_df.replace(dict(latlm=dict(null=latlm), reslm=dict(null=reslm)))
df = df.loc[
    (df['latlm'] == latlm) &
    (df['reslm'] == reslm) &
    (df['part'] == part)
].copy()

# Rows with infinite SNR are the clean condition; their accuracies are kept per model in
# Ainvs. The noisy rows are copied so the in-place rescale below does not write to a slice
# (which raises SettingWithCopyWarning).
idx = np.isinf(df['snr'])
df, Ainvs = df.loc[~idx].copy(), df.loc[idx, ['mdl', 'acc']]
snr_min = df['snr'].min() - 1
snr_max = df['snr'].max() + 1
x_interp = np.linspace(snr_min, snr_max, num_points)

mdls = df['mdl'].unique()
# Every model with noisy results must have a clean reference. Compare as sets: an elementwise
# array comparison depends on order and fails outright when the lengths differ.
assert set(mdls) == set(Ainvs['mdl'].unique()), (
    f"models with noisy results {sorted(mdls)} != models with clean results "
    f"{sorted(Ainvs['mdl'].unique())}"
)

df['acc'] *= 100

# fit = []
# fig = go.Figure()
# for mdl_idx, mdl in enumerate(mdls):
#     colour = px.colors.qualitative.Plotly[mdl_idx]
#     df_ = df.loc[df['mdl'] == mdl]
#     Ainv = Ainvs.loc[Ainvs['mdl'] == mdl, 'acc'].iloc[0]
#     A_init = 1 / Ainv
#     N = len(df_)
#     x = df_['snr'].array
#     y = df_['acc'].array
#     A, B, C = zhang_fit(x, y, fit_inverse)
#     y_pred = zhang_func(x, A, B, C)
#     r2 = r2_score(y, y_pred)
#     fit.append(dict(mdl=mdl, A=A, B=B, C=C, r2=r2))
#     y_interp = zhang_func(x_interp, A, B, C)
#     fig.add_scatter(
#         x=x_interp, y=y_interp * 100,
#         name=f"{mdl} fit",
#         mode='lines',
#         opacity=0.5,
#         showlegend=False,
#         line=dict(color=colour),
#     )
#     fig.add_scatter(
#         x=x, y=df_['acc'] * 100,
#         name=mdl, mode='markers',
#         marker=dict(color=colour),
#     )
#     # fig.add_annotation(
#     #     x=x_interp[ratio * (mdl_idx + 1)], y=y_interp[ratio * (mdl_idx + 1)] * 100,
#     #     text=f"A={A:.02f},B={B:.02f},C={C:.02f}",
#     #     showarrow=True,
#     #     font=dict(color=colour, size=FONT_SIZE),
#     # )
# print(f"Zhang et al fits by model {desc}")
# display(pd.DataFrame.from_records(fit).round(3))

fig = px.scatter(
    df, x='snr', y='acc', color='mdl',
)

fig.update_traces(marker=dict(size=5), line=dict(width=2))
fig.update_layout(
    xaxis=dict(title='SNR (dB)', range=[snr_min, snr_max], tickformat='d'),
    yaxis=dict(title='Accuracy (%)', range=[0, 100], tickformat='d'),
    legend=dict(title='Model', yanchor="top", y=0.99, xanchor='left', x=0.01),
    margin=dict(l=0, r=10, t=10, b=0),
    font=FONT,
    width=COL_SIZE_PX, height=COL_SIZE_PX,
)
fig.show()
# fig.write_image(format_fig_path("zhang", latlm=latlm, reslm=reslm, part=part))


In [None]:
# wer by perp
#
# Per-utterance clean-speech WER against LM perplexity for one system on one partition.

perplm = 'tgsmall'

mdl = 'tdnn_1d_sp'
latlm = reslm = 'tgsmall'
part = 'dev-clean'
print(
    f"mdl {mdl}, partition {part}, lattice LM {latlm}, rescore LM {reslm}, "
    f"perplexity LM {perplm}"
)

df = perp_df.loc[(perp_df['perplm'] == perplm) & (perp_df['part'] == part)]
# 'len' is in both frames; merging on it too avoids len_x/len_y duplicate columns
df = df.merge(uttwer_df.loc[
    (uttwer_df['reslm'] == reslm) &
    (uttwer_df['latlm'] == latlm) &
    (uttwer_df['mdl'] == mdl)
], on=['utt', 'part', 'len'])
# Without noise: clean results carry snr == inf, as assumed by the WER-table and Zhang cells.
# A `snr.isnull()` filter would match none of those rows.
df = df.loc[np.isinf(df['snr'])]
assert not df.empty, "no clean-speech utterances matched the selected system and partition"
# clip both axes to the central 90% of the points
ymin, ymax = df['wer'].quantile(0.05), df['wer'].quantile(0.95)
xmin, xmax = df['perp'].quantile(0.05), df['perp'].quantile(0.95)

print("per-utterance WER by perplexity")
fig = px.scatter(df, x='perp', y='wer')
fig.update_xaxes(type='log', range=[np.log10(xmin), np.log10(xmax)])
fig.update_yaxes(range=[ymin, ymax])
fig.show()

## Boothroyd and Nittrouer

In [45]:
# idealized curve, following B&N's ZP -> LP, ZP -> HP
#
# Draws 100 * (1 - boothroyd_func(1 - p, k)) against 100 * p for several k, labelling each
# curve with an arrow annotation.
k_low = 1.38
k_high = 2.72
k_max = 500
range_ = [0, 100]
num_pts = 100
tickvals = list(range(range_[0], range_[1] + 1, 25))

x = np.linspace(0.0, 1.0, num_pts)


def _bn_percent(p, k):
    """y-axis accuracy (%) for x-axis accuracy ``p`` (a fraction) under exponent ``k``."""
    return 100 * (1 - boothroyd_func(1 - p, k))


# (k, x-position of the label arrow as a fraction, arrow offset in px used for both ax and ay)
curve_specs = (
    (1, .66, 20),
    (k_low, .44, 30),
    (k_high, .22, 40),
    (k_max, 0.02, 25),
)

fig = go.Figure()
for k, anchor, offset in curve_specs:
    fig.add_scatter(
        x=100 * x,
        y=_bn_percent(x, k),
        showlegend=False,
        line=dict(color='black'),
    )
    fig.add_annotation(
        x=100 * anchor,
        y=_bn_percent(anchor, k),
        text=f"<i>k</i> = {k}",
        showarrow=True,
        arrowhead=1,
        ax=offset,
        ay=offset,
        arrowcolor="black",
    )

side = COL_SIZE_PX // 1.6
axis_style = dict(range=range_, tickvals=tickvals, mirror=True, showline=True)
fig.update_layout(
    width=side, height=side,
    margin=dict(l=0, r=0, t=0, b=0),
    font=FONT,
    xaxis=dict(title=f"Accuracy {P_I} (%)", **axis_style),
    yaxis=dict(title=f"Accuracy {P_C} (%)", **axis_style),
)
for update_axes in (fig.update_xaxes, fig.update_yaxes):
    update_axes(title_standoff=5)
fig.write_image(format_fig_path('bn'))

In [42]:
# stats
#
# Boothroyd & Nittrouer k. For each model, partition and SNR, every in-context bin (all but the
# last entry of BIN_NAMES) is paired with the out-of-context bin (the last entry) and k is fitted.
# Then, per model, two plots are written: in- vs out-of-context accuracy with the fitted curves,
# and pointwise k.
#
# NOTE(review): the preparation and fitting steps used to be commented out. The plots then read
# whatever `df`/`fits` an earlier run had left in the kernel, so the cell broke on Restart & Run
# All. They are re-enabled here. boothroyd_fit may be slow (its output carries a 'bootstrap'
# column); the debug histogram of the first bootstrap has been dropped.

print('merging')
df = perp_df.loc[perp_df['perplm'] == PERPLM].merge(uttwer_df, on=['utt', 'part', 'len'])
df = df.assign(
    mdl=(df['mdl'] + '_' + df['latlm'] + '_' + df['reslm']).map(MDL_LATLM_RESLM2RENAME),
    part=df['part'].map(PART2RENAME),
)
df = agg_mean_by_lens(df, 'len', 'wer',  ['snr', 'part', 'mdl', 'ent_bin'])
df['lwer'] = np.log(df['wer'])
df['Wer'] = 100 * df['wer']
df['acc'] = 1 - df['wer']
df['Acc'] = 100 * df['acc']
# Noisy conditions only. dropna removes un-binned utterances and unnamed models/partitions.
df = df.loc[np.isfinite(df['snr'])].dropna()

# pair each in-context bin with the out-of-context (last) bin of the same snr/part/mdl
mask = df['ent_bin'] == BIN_NAMES[-1]
df, df_out = df.loc[~mask], df.loc[mask]
df = df.merge(df_out, on=['snr', 'part', 'mdl'], suffixes=('_in', '_out'))
df = df.assign(ent_bin_in=df['ent_bin_in'].cat.remove_unused_categories())
# pointwise k = log(wer_in) / log(wer_out)
df['k'] = df['lwer_in'] / df['lwer_out']


def _fit_k(subset, mdl_label, part_label):
    """Run boothroyd_fit on ``subset`` and tag the result with its model/partition labels."""
    fit = boothroyd_fit(subset)
    fit['mdl'], fit['part'] = mdl_label, part_label
    return fit


fits = []
print('fitting all')
fits.append(_fit_k(df, 'all', 'all'))

for part in PART_RENAMES:
    print(f'fitting {part}')
    fits.append(_fit_k(df.loc[df['part'] == part], 'all', part))

for mdl in MDL_RENAMES:
    print(f'fitting {mdl}')
    fits.append(_fit_k(df.loc[df['mdl'] == mdl], mdl, 'all'))
    for part in PART_RENAMES:
        print(f'fitting {mdl} x {part}')
        fits.append(_fit_k(df.loc[(df['mdl'] == mdl) & (df['part'] == part)], mdl, part))

fits = pd.concat(fits)
# numeric fitted k indexed by (mdl, part) with one column per bin name; used for the curves below
fit_coefs = fits.pivot(values='coef', columns='name', index=['mdl', 'part'])
fits['coef+ci'] = fits.apply(lambda row: f"{row['coef']:.2f} [{row['ci_low']:.2f}, {row['ci_high']:.2f}]", axis=1)
fits = fits.pivot(values='coef+ci', columns='name', index=['mdl', 'part'])

display(fits)
print(fits.swaplevel(0, 1).to_latex())


def _relabel_part_traces(fig):
    """Rename px's "<part>, <bin>" marker traces to "<part>", keeping one legend entry each.

    Only traces of the first in-context bin stay in the partition legend; the bins themselves
    are explained by marker symbol in legend2.
    """
    for trace in fig.data:
        if trace.mode == 'markers':
            name = trace.name.split(', ')
            if name[1] in BIN_NAMES[1:]:
                trace['name'] = ''
                trace['showlegend'] = False
            else:
                trace['name'] = name[0]


def _add_bin_symbol_key(fig, bin_idx):
    """Add a data-less black marker so legend2 shows which symbol denotes bin ``bin_idx``."""
    fig.add_scatter(
        y=[None], mode='markers',
        marker=dict(color='black', symbol=bin_idx),
        legend="legend2",
        name=BIN_NAMES[bin_idx],
    )


range_ = [0, 100]
x_interp = np.linspace(*range_, 100)
tickvals = list(range(range_[0], range_[1] + 1, 25))
num_in_bins = len(BIN_NAMES) - 1  # every bin but the last is an in-context bin
fig_px = int(COL_SIZE_PX * 0.75)
k_part = 'LS-C'  # partition whose fitted k is drawn on every model's plots

for mdl in MDL_RENAMES:
    df_ = df.loc[df['mdl'] == mdl]
    # unrounded fitted k per in-context bin (previously re-parsed from the 2-decimal display string)
    ks = [float(fit_coefs.loc[(mdl, k_part), BIN_NAMES[bin_idx]]) for bin_idx in range(num_in_bins)]

    # accuracy ratio: in-context vs out-of-context accuracy, with fitted curves and the diagonal
    fig = px.scatter(
        df_, x='Acc_out', y='Acc_in', color='part', symbol='ent_bin_in',
        symbol_sequence=list(range(num_in_bins)),
        category_orders=dict(part=PART_RENAMES),
    )
    _relabel_part_traces(fig)
    for bin_idx, k in enumerate(ks):
        _add_bin_symbol_key(fig, bin_idx)
        fig.add_scatter(
            x=x_interp,
            y=(1 - boothroyd_func(1 - x_interp / 100, k)) * 100,
            line=dict(color="black", width=1),
            showlegend=False,
        )
    fig.add_scatter(
        x=range_,
        y=range_,
        mode='lines',
        line=dict(color="grey", width=1, dash='dash'),
        showlegend=False,
    )
    fig.update_traces(marker=dict(line_width=0.5, size=4))
    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0, pad=0),
        font=FONT,
        legend=dict(
            title_text='Partition',
            yanchor="bottom",
            y=0.01,
            xanchor="right",
            x=0.99,
            bgcolor='rgba(0,0,0,0)',
        ),
        legend2=dict(
            title_text="In-context bin",
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
            bgcolor='rgba(0,0,0,0)',
        ),
        xaxis=dict(title=f"Accuracy {P_I} (%)", range=range_, tickvals=tickvals, mirror=True, showline=True),
        yaxis=dict(title=f"Accuracy {P_C} (%)", range=range_, tickvals=tickvals, mirror=True, showline=True),
    )
    fig.update_xaxes(title_standoff=5)
    fig.update_yaxes(title_standoff=5)
    fig.write_image(format_fig_path('acc-ratio', mdl=mdl), width=fig_px, height=fig_px, scale=1)

    # pointwise k against out-of-context error rate, with the fitted k as horizontal lines
    fig = px.scatter(
        df_, x='Wer_out', y='k', color='part', symbol='ent_bin_in',
        symbol_sequence=list(range(num_in_bins)),
        category_orders=dict(part=PART_RENAMES),
    )
    _relabel_part_traces(fig)
    for bin_idx, k in enumerate(ks):
        _add_bin_symbol_key(fig, bin_idx)
        fig.add_scatter(
            x=[0, 101],
            y=[k, k],
            mode="lines",
            line=dict(color="black", width=1),
            showlegend=False,
        )
    fig.update_traces(marker=dict(line_width=0.5, size=4))
    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        font=FONT,
        legend=dict(
            title_text='Partition',
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
            bgcolor='rgba(0,0,0,0)',
        ),
        legend2=dict(
            title_text="In-context bin",
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.37,
            bgcolor='rgba(0,0,0,0)',
        ),
        xaxis=dict(title=f"Error rate {E_I} (%)", range=[0, 101], tickvals=tickvals, mirror=True, showline=True),
        yaxis=dict(title=f"Pointwise <i>k</i>", range=[1, 2.5], tickvals=[1, 1.5, 2, 2.5], mirror=True, showline=True),
    )
    fig.update_xaxes(title_standoff=5)
    fig.update_yaxes(title_standoff=5)
    fig.write_image(format_fig_path('point-k', mdl=mdl), width=fig_px, height=fig_px, scale=1)























## Klakow and Peters (2002)

In [None]:
# Klakow and Peters (2002). "Testing the correlation of word error rate and perplexity"
# "... slope a is smaller for tasks that are acoustically more challenging. Hence on
# those tasks larger reductions in PP are needed to obtain a given reduction in WER." 
#
# For one system:
#   1. bin utterances by LM entropy;
#   2. at every SNR, regress log WER on the bins' mean entropy (log WER = a * H + log b);
#   3. plot WER by bin and SNR, together with a few of the fitted curves.

num_bins = 5
perplm = binlm = 'rnnlm_lstm_1a'
binpart = 'dev-clean'

# alternative system: mdl, latlm, reslm = 'wav2vec2-large-960h-lv60', 'null', 'null'
mdl = 'tdnn_1d_sp'
latlm = 'tgsmall'
reslm = 'fglarge'

part = 'dev-clean'

cfg = dict(
    num_bins=num_bins, perplm=perplm, binlm=binlm, binpart=binpart, mdl=mdl,
    latlm=latlm, reslm=reslm, part=part,
)

print(
    f"mdl {mdl}, part {part} lattice lm {latlm}, rescore lm {reslm} perplexity LM "
    f"{perplm}, bin part {binpart}, bin LM {binlm}"
)

df = perp_df.loc[(perp_df['perplm'] == perplm) & (perp_df['part'] == part)].copy()
# bin edges from binlm's entropies on binpart, within the same quantiles as the shared config
bounds = bin_series(
    perp_df.loc[(perp_df['perplm'] == binlm) & (perp_df['part'] == binpart), 'ent'],
    num_bins, by_rank=False, lower_quant=BIN_QUANT_LOWER, upper_quant=BIN_QUANT_UPPER,
)[1]
bins = bin_series(df['ent'], bounds, by_rank=False, fmt="{:.01f}")[0]
df['ent_bin'] = bins
bin_cats = df['ent_bin'].dtype.categories

df_ent = agg_mean_by_lens(df, 'len', 'ent', 'ent_bin')
print('entropy by bin')
display(df_ent.round(3))
x = df_ent['ent']  # mean entropy per bin: the regressor

df = df.merge(uttwer_df.loc[
    (uttwer_df['reslm'] == reslm) &
    (uttwer_df['latlm'] == latlm) &
    (uttwer_df['mdl'] == mdl) &
    np.isfinite(uttwer_df['snr'])
], on=['utt', 'part', 'len'])
display(df.head())
# lowest SNR whose overall WER is below 50% (referenced by the commented plots below)
snr_50 = agg_mean_by_lens(df, 'len', 'wer', ['snr'])
snr_50 = snr_50.loc[snr_50['wer'] < .5, 'snr'].min()
df = agg_mean_by_lens(df, 'len', 'wer', ['snr', 'ent_bin'])

snrs = np.sort(df['snr'].unique())
curve_params_list = []
fits = []
for snr in snrs:
    snr_mask = df['snr'] == snr
    # NOTE(review): assumes this SNR's rows line up one-to-one, in bin order, with x;
    # a zero WER would give log(0) = -inf here
    y = np.log(df.loc[snr_mask, "wer"])
    # fit = klakow_fit(x, y, add_intercept=True)
    fit = pg.linear_regression(x, y, True)
    fit['snr'] = snr
    fits.append(fit)
    curve_params_list.append({
        "snr": snr,
        "a": fit.loc[fit['names'] == 'ent', 'coef'].iloc[0],
        "b": np.exp(fit.loc[fit['names'] == 'Intercept', 'coef'].iloc[0]),
        "se": fit.loc[fit['names'] == 'ent', 'se'].iloc[0]
    })
display(pd.concat(fits).round(3))

# lowest, a middle and highest SNR; the middle index is clamped so fewer than 17 SNRs
# don't raise IndexError
snr_mini, snr_maxi = 0, len(snrs) - 1
snr_midi = min(16, snr_maxi)
# copy so the rescale below does not write into a slice (SettingWithCopyWarning)
df = df.loc[(df['snr'] >= snrs[snr_mini]) & (df['snr'] <= snrs[snr_maxi])].copy()
df['wer'] *= 100

print("WER by (PP, SNR) with select K & P fits")
fig = px.bar(
    df,
    x='ent_bin',
    y='wer',
    color='snr',
    barmode='overlay',
    color_continuous_scale="viridis",
    opacity=1.0,
    labels={
        'wer': 'WER (%)',
        'ent_bin': 'Entropy range (nats)',
        'snr': 'SNR (dB)',
    }
)
for dict_ in (curve_params_list[snr_mini], curve_params_list[snr_midi], curve_params_list[snr_maxi]):
    y = klakow_func(np.exp(x), dict_['a'], dict_['b']) * 100
    interp_name = f"a={dict_['a']:.03f}, b={dict_['b']:.03f}"
    fig.add_scatter(
        x=bins.dtype.categories,
        y=y,
        showlegend=False,
        name=interp_name,
        mode='markers+lines',
        marker=dict(color='red'), line=dict(color='red'))
    # fig.add_annotation(
    #     x=bins.dtype.categories[2], y=y.iloc[2],
    #     text=interp_name,
    #     showarrow=True,
    #     opacity=1,
    #     font=dict(color="black", size=FONT_SIZE),
    #     bgcolor='white',
    # )
fig.update_layout(
    yaxis=dict(range=[0, 100]),
    font=FONT,  # same typography as every other figure
    margin=dict(l=0, r=0, t=10, b=0, pad=5),
    width=COL_SIZE_PX, height=COL_SIZE_PX,
    coloraxis=dict(colorbar=dict(thickness=20, title=dict(side="right"))),
)
fig.show()
fig.write_image(format_fig_path('kp-over-snr', **cfg))

# df = pd.DataFrame.from_records(curve_params_list)
# df['logb'] = np.log(df['b'])
# df['logb/a'] = df['logb'] / df['a']
# df['b^(1/a)'] = np.exp(df['logb/a'])
# print('K & P model parameter ratio by snr')
# fig = px.scatter(df, x='snr', y='logb/a')
# fig.add_vline(x=snr_50, line_dash='dash', line_color='black', annotation_text='50% acc')
# fig.show()
# print("K & P model parameters by SNR")
# df = pd.melt(df, ['snr'], ['a', 'b', 'b^(1/a)'], var_name='param', value_name='val')
# fig = px.scatter(df, x='snr', y='val', color='param')
# fig.add_vline(x=snr_50, line_dash='dash', line_color='black', annotation_text='50% acc')
# fig.update_layout(yaxis_range=[0, 1])
# fig.show()

# print("Predicted k by bin and snr")
# records = []
# ent_out = x[num_bins - 1]
# for snri, dict_ in enumerate(curve_params_list):
#     a, b = dict_['a'], dict_['b']
#     snr = int(snrs[snri])
#     log_b = np.log(b)
#     lwer_out = a * ent_out + log_b
#     for bin_in in range(num_bins):
#         ent_in = x[bin_in]
#         ratio_name = f'{bin_cats[bin_in]} over {bin_cats[num_bins - 1]}'
#         lwer_in = a * ent_in + log_b
#         k = lwer_in / lwer_out
#         records.append(dict(snr=snr, k=k, ratio_name=ratio_name))
# df = pd.DataFrame.from_records(records)
# fig = px.scatter(df, x='snr', y='k', color='ratio_name')
# fig.add_vline(x=snr_50, line_dash='dash', line_color='black', annotation_text='50% acc')
# fig.show()
