In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
from os.path import join
# import config
import numpy as np
import joblib
import seaborn as sns
import matplotlib.pyplot as plt
from neuro import analyze_helper, viz
from neuro.features.feat_select import get_alphas
import matplotlib.patches as mpatches
import dvu
import matplotlib
dvu.set_style()


In [None]:
# Load the results saved out by the main_curves notebook. The pickle holds a dict with
# keys 'r' (the run table, a DataFrame), 'cols_varied' (column names of `r` used as
# grouping keys below) and 'mets' (metric column names of `r`).
# NOTE(review): joblib.load unpickles — only point this at files you produced yourself.
data = joblib.load('results_best_ensemble.pkl')
r = data['r']
cols_varied = data['cols_varied']
mets = data['mets']

In [None]:
# Keep a handle on the unfiltered run table before the filtering cell below rebinds `r`
# to a filtered subset. This is an alias of the same DataFrame object, not a copy.
rr = r

In [None]:
# How many non-zero weights the enet mask kept, for runs using the merged v3 question set.
r.loc[r['qa_questions_version'] == 'v3_boostexamples_merged',
      'weight_enet_mask_num_nonzero'].value_counts()

In [None]:
# Restrict the run table to the runs analysed below:
#   - qa_questions_version is '' or 'v3_boostexamples_merged'
#   - the two raw LLaMA feature spaces are excluded
#   - qa_embedding_model is '', 'ensemble1' or 'ensemble2'
#   - feature-selected runs (alpha >= 0) are kept only if they used stability seeds
excluded_feature_spaces = ['meta-llama/Llama-2-7b-hf', 'meta-llama/Meta-Llama-3-8B']
r = r[r.qa_questions_version.isin(['', 'v3_boostexamples_merged'])]
r = r[~r.feature_space.isin(excluded_feature_spaces)]
r = r[r.qa_embedding_model.isin(['', 'ensemble1', 'ensemble2'])]
r = r[(r.feature_selection_alpha < 0) | (r.feature_selection_stability_seeds > 0)]

# Stability seeds no longer vary independently after the filter above, so stop grouping on them.
cols_varied = [c for c in cols_varied if c != 'feature_selection_stability_seeds']

r.shape

### Check runs (full grid)

In [None]:
# Full-grid sanity check: mean test correlation for every combination of the varied
# hyperparameters. Runs with num_stories == 15 are dropped, and feature-selected runs
# are kept only for qa_embedder at get_alphas('qa_embedder')[3].
keep_rows = (r.num_stories != 15) & (
    (r.feature_selection_alpha < 0)
    | ((r.feature_space_simplified == 'qa_embedder')
       & (r.feature_selection_alpha == get_alphas('qa_embedder')[3]))
)
cols_top = ['feature_space', 'embedding_layer',
            'qa_embedding_model', 'qa_questions_version', 'feature_selection_alpha']
row_cols = [c for c in cols_varied if c not in cols_top]

# groupby leaves the varied columns as index levels; pivot_table resolves them by name.
d = (
    r[keep_rows]
    .groupby(cols_varied)[['corrs_test_mean']]
    .mean()
    .pivot_table(index=row_cols, columns=cols_top,
                 values='corrs_test_mean', aggfunc='mean')
    .sort_index(axis=1)
)
display(d.style.background_gradient(cmap='viridis', axis=1).format(precision=3))

## Check simplified table

In [None]:
# Every feature space present before filtering.
rr['feature_space'].unique()

In [None]:
# Simplified feature-space labels remaining after filtering.
r['feature_space_simplified'].unique()

In [None]:
# Best run per (subject, num_stories, feature space, alpha), selected by tuning-set
# performance, then tabulated by test correlation.
#
# num_stories == -1 appears to denote "all stories"; it is recoded to 100 so it sorts
# last and can be labelled "All" in the plots. Build a new frame with `assign` rather
# than assigning to `d.num_stories`: `d = r` is an alias, so attribute assignment
# silently rewrote the column in `r` and changed earlier cells' output on re-run.
num_stories_list = [5, 10, 20, 100]
alpha_qa = get_alphas('qa_embedder')[3]

d = r.assign(num_stories=r.num_stories.replace(-1, 100))
d = d[d.num_stories.isin(num_stories_list)]
# Keep unselected runs (alpha < 0) plus qa_embedder at the chosen selection alpha.
d = d[
    (d.feature_selection_alpha < 0)
    | ((d.feature_space_simplified == 'qa_embedder')
       & (d.feature_selection_alpha == alpha_qa))
]

group_cols = ['subject', 'num_stories',
              'feature_space_simplified', 'feature_selection_alpha']
metric_sort = 'corrs_tune_pc_weighted_mean'
# Keep the single top-ranked row per group. `groupby(...).first()` would instead take
# the first non-null value of each column independently, which can mix metrics from
# different runs whenever the best run has a NaN in some metric.
d = (
    d.sort_values(by=metric_sort, ascending=False)
    .groupby(group_cols)
    .head(1)
    .sort_values(group_cols)
    .loc[:, [*group_cols, *mets]]
    .reset_index(drop=True)
)

cols_top = ['feature_space_simplified', 'feature_selection_alpha']
d_tab = (
    d.pivot_table(index=[c for c in group_cols if c not in cols_top],
                  columns=cols_top, values='corrs_test_mean', aggfunc='mean')
    .sort_index(axis=1)
)
display(
    d_tab.style
    .background_gradient(cmap='magma', axis=1)
    .format(precision=3)
)

In [None]:
def _report_qa_vs_llama(table, num_stories, qa_alpha, llama_alpha=-1.0):
    """Print and return mean test correlation of QA-Emb vs. LLaMA at one training size.

    Parameters
    ----------
    table : pd.DataFrame
        Pivot table shaped like `d_tab`: index levels ('subject', 'num_stories'),
        columns (feature_space_simplified, feature_selection_alpha).
    num_stories : int
        Training-set size to average over (100 stands for "all stories").
    qa_alpha : float
        feature_selection_alpha column key selecting the QA-Emb runs.
    llama_alpha : float
        feature_selection_alpha column key selecting the LLaMA runs.

    Returns
    -------
    tuple of float
        (mean QA correlation, mean LLaMA correlation) over subjects at `num_stories`.
    """
    at_size = table.index.get_level_values('num_stories') == num_stories
    val_qa = table[('qa_embedder', qa_alpha)][at_size].mean()
    val_llama = table[('llama', llama_alpha)][at_size].mean()
    print(f'val_all_qa_{num_stories}: {val_qa:.3f} '
          f'val_all_llama_{num_stories}: {val_llama:.3f} '
          f'improvement: {(val_qa - val_llama) / val_llama:.2%}')
    return val_qa, val_llama


# Use the same alpha the table was filtered on, rather than a hard-coded 0.28 key.
qa_alpha_selected = get_alphas('qa_embedder')[3]
val_all_qa_100, val_all_llama_100 = _report_qa_vs_llama(d_tab, 100, qa_alpha_selected)
# Small-data comparison: 5 training stories (labels now match the setting used).
val_all_qa_5, val_all_llama_5 = _report_qa_vs_llama(d_tab, 5, qa_alpha_selected)

In [None]:
# Paper figure: test correlation vs. number of training stories for each model.
# Bars show the across-subject mean with standard-error bars; points show individual
# subjects. The S01-S03 strip is drawn with legend=False, so legend entries come from
# the strip of the remaining subjects.
legend_names = {
    ('BERT', -1.0): 'BERT',
    ('Eng1000', -1.0): 'Eng1000',
    ('LLaMA', -1.0): 'LLaMA (best)',
    ('QA-Emb', -1.0): 'QA',
    ('QA-Emb', 0.28): 'QA (35 questions)',
}
model_keys = zip(d.feature_space_simplified.map(viz.feature_space_rename),
                 d.feature_selection_alpha)
# Unmapped (feature space, alpha) pairs are kept as raw tuples.
d['legend'] = [legend_names.get(key, key) for key in model_keys]

plot_kwargs = dict(
    x='num_stories',
    y='corrs_test_mean',
    hue='legend',
    hue_order=['Eng1000', 'BERT', 'LLaMA (best)', 'QA', 'QA (35 questions)'],
    palette=['tomato', '#aaa', '#000', 'C0', 'cadetblue'],
    dodge=True,
)
is_first_three = d.subject.isin(['S01', 'S02', 'S03'])

matplotlib.rcParams.update({'font.size': 14})
fig, ax = plt.subplots(dpi=300, figsize=(8, 4))
sns.barplot(**plot_kwargs, data=d, alpha=0.2, errorbar='se',
            err_kws={'alpha': 0.4}, legend=False, ax=ax)
sns.stripplot(**plot_kwargs, data=d[is_first_three],
              jitter=True, size=4, legend=False, ax=ax)
sns.stripplot(**plot_kwargs, data=d[~is_first_three],
              jitter=True, size=4, ax=ax)
ax.set_ylim(bottom=0)

# num_stories == 100 stands in for the full training set (recoded from -1 earlier).
ax.set_xticklabels([
    'All' if tick.get_text() == '100' else tick.get_text()
    for tick in ax.get_xticklabels()
])
ax.set_xlabel("Number of stories used for training")
ax.set_ylabel("Test correlation")
ax.legend(loc='upper left')

viz.savefig('subsample/subsample_barplot.pdf', bbox_inches='tight')
plt.savefig('subsample/subsample_barplot.png',
            bbox_inches='tight', dpi=500)
plt.show()

### Plot for talk

In [None]:
# Training-set sizes present in the best-run table.
d['num_stories'].unique()

In [None]:
# Talk version of the subsampling figure: only 5 stories vs. all stories, plain bars
# (no error bars, no per-subject points).
# `.copy()` makes `dp` an independent frame, so adding the 'legend' column does not
# trigger pandas' SettingWithCopyWarning.
dp = d[d.num_stories.isin([5, 100])].copy()
dp['legend'] = list(
    zip(dp.feature_space_simplified.map(viz.feature_space_rename), dp.feature_selection_alpha))
# Map (feature space, alpha) pairs to display names; unmapped pairs stay as tuples.
dp['legend'] = dp['legend'].map(lambda x: {
    ('BERT', -1.0): 'BERT',
    ('Eng1000', -1.0): 'Eng1000',
    ('LLaMA', -1.0): 'LLaMA (best)',
    ('QA-Emb', -1.0): 'QA',
    ('QA-Emb', 0.28): 'QA (35 questions)'
}.get(x, x))
kwargs = dict(
    x='num_stories',
    y='corrs_test_mean',
    hue='legend',
    hue_order=['Eng1000', 'BERT', 'LLaMA (best)', 'QA', 'QA (35 questions)'],
    palette=['tomato', '#aaa', '#000', 'C0', 'navy'],
    dodge=True,
)
matplotlib.rcParams.update({'font.size': 14})
plt.figure(dpi=300, figsize=(8, 4))
ax = sns.barplot(**kwargs, data=dp, legend=True, errorbar=None)
plt.ylim(bottom=0)

# num_stories == 100 stands in for the full training set; label it "All".
xtick_labels = ax.get_xticklabels()
ax.set_xticklabels(['All' if label.get_text() ==
                   '100' else label.get_text() for label in xtick_labels])

plt.xlabel("Number of stories used for training")
plt.ylabel("Predictive performance")

plt.legend(loc='upper left')
# Save under distinct names: the paper figure cell writes
# subsample/subsample_barplot.{pdf,png}, and reusing those paths here overwrote it.
plt.savefig('subsample/subsample_barplot_talk.pdf', bbox_inches='tight')
plt.savefig('subsample/subsample_barplot_talk.png',
            bbox_inches='tight', dpi=500)
plt.show()

In [None]:
# One table per training-set size, styled like the combined table above.
num_stories_list = [5, 10, 20, 100]
dt = d_tab.reset_index()
for n_stories in num_stories_list:
    print(f'num_stories = {n_stories}')
    table_n = dt[dt.num_stories == n_stories].drop(columns=['num_stories'])
    styled_n = table_n.style.background_gradient(cmap='magma', axis=1)
    display(styled_n.format(precision=3))