In [None]:
from src.analyzer import LogprobAnalyzer, AnalysisConfig
from src.metrics import pairwise_bce_of_group
import pandas as pd
import altair as alt
alt.data_transformers.enable("vegafusion")

logprob_data_path = "data/logprobs.csv"
analyzer = LogprobAnalyzer(logprob_data_path)

analyzer.add_category(
    output_columns='model_family', # Explicit output name
    source_columns=('model_name',),
    categorizer=lambda model_name: model_name.split('/')[1].split('-')[0]
)

analyzer.add_rename_mapping('model_family', {
    'Llama': 'Llama 3',
    'gpt2': 'GPT 2',
})

analyzer.add_rename_mapping('model_name', {
    'openai-community/gpt2': 'GPT-2-S',
    'openai-community/gpt2-medium': 'GPT-2-M',
    'openai-community/gpt2-large': 'GPT-2-L',
    'openai-community/gpt2-xl': 'GPT-2-XL',
    'meta-llama/Llama-3.2-1B': 'Llama3.2-1B',
    'meta-llama/Llama-3.2-3B': 'Llama3.2-3B',
    'meta-llama/Llama-3.1-8B': 'Llama3.1-8B',
})

# Set sort order using the *new* names
analyzer.set_sort_order({'model_name': ['GPT-2-S', 'GPT-2-M', 'GPT-2-L', 'GPT-2-XL', 'Llama3.2-1B', 'Llama3.2-3B', 'Llama3.1-8B']})

analyzer.calculate_metric(
    metric=pairwise_bce_of_group,
    metric_name=" BCE (Pairwise MSE)",
    group_by_cols=['class_type', 'model_name', 'model_family'],
    metric_kwargs={'log_prior_col': 'prior_logprob', 'log_likelihood_col': 'likelihood_logprob', 'log_posterior_col': 'posterior_logprob', 'square': True}
)

config_mse = AnalysisConfig(
    plot_fn=alt.Chart.mark_boxplot,
    fig_title="BCE (Pairwise MSE method by class_type) by Model",
    x_category='model_name:N',
    y_category='BCE (Pairwise MSE):Q',
    facet_category='model_family:N',
    tooltip_fields=[
        alt.Tooltip('model_name:N', title='Model'),
        alt.Tooltip('model_family:N', title='Model Family'),
        alt.Tooltip('BCE (Pairwise MSE):Q', title='BCE', format=".3f"),
        alt.Tooltip('median():Q', title='Median', format=".3f"),
        alt.Tooltip('mean():Q', title='Mean', format=".3f"),
        alt.Tooltip("count():Q", title="Count", format="d"),
    ],
    titles={
        'BCE (Pairwise MSE)': 'BCE (Pairwise MSE method)',
        'model_name': 'Language Model',
    },
    interactive_chart=False,
    legend_config={"orient": "top"}, 
    chart_properties={"resolve": {"scale": {"x": "independent"}}}
)

chart_mse = analyzer.visualize(
    config=config_mse
)
chart_mse.show()

In [7]:
analyzer.metric_df

Unnamed: 0,class_type,model_name,model_family,BCE (Pairwise MSE)
0,culinary_techniques,GPT-2-S,GPT 2,3.014349
1,culinary_techniques,GPT-2-S,GPT 2,0.192832
2,culinary_techniques,GPT-2-S,GPT 2,244.851510
3,culinary_techniques,GPT-2-S,GPT 2,70.873421
4,culinary_techniques,GPT-2-S,GPT 2,3.344786
...,...,...,...,...
32405,tech_innovators,Llama3.1-8B,Llama 3,17.961401
32406,tech_innovators,Llama3.1-8B,Llama 3,243.535149
32407,tech_innovators,Llama3.1-8B,Llama 3,10.215959
32408,tech_innovators,Llama3.1-8B,Llama 3,66.769850


In [2]:
config_mse = AnalysisConfig(
    plot_fn=alt.Chart.mark_boxplot,
    fig_title="BCE (pairwise MAE method) by Model",
    x_category='model_name',
    y_category='pairwise_bce_mae:Q',
    facet_category='model_family:N',
    tooltip_fields=[ # Using explicit tooltips
        alt.Tooltip('model_name:N', title='Model'),
        alt.Tooltip('model_family:N', title='Model Family'),
        alt.Tooltip('pairwise_bce_mae:Q', title='BCE', format=".3f"),
        alt.Tooltip("count():Q", title="Count", format="d"),
    ],
    titles={
        'pairwise_bce_mae': 'BCE (pairwise MAE method)',
        'model_name': 'Language Model',
    },
    interactive_chart=False,
    legend_config={"orient": "top"}, 
    chart_properties={"resolve": {"scale": {"x": "independent"}}}
)

chart_mse = analyzer.visualize(
    config=config_mse,
    metric=pairwise_error_of_group,
    metric_name="pairwise_bce_mae",
    aggregate=True,
    group_by_cols=['model_name', 'model_family',],
    metric_kwargs={'value_col': 'evidence_estimate', 'square': False}
)
chart_mse.show()

NameError: name 'pairwise_error_of_group' is not defined