In [1]:
from src.analyzer import LogprobAnalyzer, AnalysisConfig
import pandas as pd
import altair as alt
# Switch Altair to the VegaFusion data transformer so charts built from large
# frames are not rejected by the default row limit.
alt.data_transformers.enable("vegafusion")

# CSV of log-probabilities handed to the analyzer (relative path).
logprob_data_path = "data/logprobs.csv"
analyzer = LogprobAnalyzer(logprob_data_path)

def _model_family(model_name):
    """Return the lower-cased family prefix of a hub-style model id.

    Examples: 'meta-llama/Llama-3.2-1B' -> 'llama',
    'openai-community/gpt2-xl' -> 'gpt2'.

    The result is lower-cased so that it matches the keys of the
    'model_family' rename mapping ('llama', 'gpt2'); without this the Llama
    ids produce 'Llama' and are never renamed. The last '/'-separated segment
    is used so an id without a namespace does not raise IndexError.
    """
    repo_name = model_name.split('/')[-1]
    return repo_name.split('-')[0].lower()


analyzer.add_categorizer(
    output_columns='model_family',  # name of the derived column
    source_columns=('model_name',),
    categorizer=_model_family,
)

def _combine_logprobs(prior, likelihood, posterior):
    """Return prior + likelihood - posterior, or pd.NA if any input is missing."""
    if pd.isna(prior) or pd.isna(likelihood) or pd.isna(posterior):
        return pd.NA
    return prior + likelihood - posterior


analyzer.add_categorizer(
    output_columns='_bce_sum',  # name of the derived column
    source_columns=('prior_logprob', 'likelihood_logprob', 'posterior_logprob'),
    categorizer=_combine_logprobs,
)

# Display labels for the derived model-family values.
MODEL_FAMILY_LABELS = {
    'llama': 'Llama3',
    'gpt2': 'GPT2',
}
analyzer.add_rename_mapping('model_family', MODEL_FAMILY_LABELS)

# Raw model ids -> short display labels, listed in the order they should be sorted.
MODEL_NAME_LABELS = {
    'openai-community/gpt2': 'GPT-2-S',
    'openai-community/gpt2-medium': 'GPT-2-M',
    'openai-community/gpt2-large': 'GPT-2-L',
    'openai-community/gpt2-xl': 'GPT-2-XL',
    'meta-llama/Llama-3.2-1B': 'Llama3.2-1B',
    'meta-llama/Llama-3.2-3B': 'Llama3.2-3B',
    'meta-llama/Llama-3.1-8B': 'Llama3.1-8B',
}
analyzer.add_rename_mapping('model_name', MODEL_NAME_LABELS)

# The sort order uses the *new* names; dicts keep insertion order, so the
# mapping's values are already in the intended sequence.
analyzer.set_sort_order('model_name', list(MODEL_NAME_LABELS.values()))

# --- Plot 2: Pairwise MSE BCE ---
# Box-plot configuration: model_name on x, pairwise_bce_mse on y, one facet per
# model_family.
# NOTE(review): the title mentions "Temperature" but no temperature field is
# encoded in this config — confirm the wording is intended.
config_mse = AnalysisConfig(
    plot_fn=alt.Chart.mark_boxplot,
    fig_title="BCE (Pairwise MSE method) by Model and Temperature",
    # model_name holds string labels, so it must be nominal (:N), not
    # quantitative (:Q); this also matches the tooltip encoding below.
    x_category='model_name:N',
    y_category='pairwise_bce_mse:Q',
    facet_category='model_family:N',
    tooltip_fields=[  # explicit tooltips instead of defaults
        alt.Tooltip('model_name:N', title='Model'),
        alt.Tooltip('model_family:N', title='Model Family'),
        alt.Tooltip('mean(pairwise_bce_mse):Q', title='Mean BCE', format=".3f"),
        alt.Tooltip("count():Q", title="Count", format="d"),
    ],
    # Human-readable titles keyed by field name.
    titles={
        'pairwise_bce_mse': 'BCE (Pairwise MSE method)',
        'model_name': 'Language Model',
        'model_family': 'Model Family'
    },
    interactive_chart=False,
    legend_config={"orient": "top"}
)

# Build the chart described by config_mse: the metric callable is passed together
# with the name it is reported under and the column it reads (via metric_kwargs).
# NOTE(review): pairwise_mse_of_group is not imported or defined in the visible
# cell, so a fresh kernel would raise NameError here — confirm which module
# provides it and import it alongside the other imports.
chart_mse = analyzer.visualize(
    config=config_mse,
    metric=pairwise_mse_of_group,
    metric_name="pairwise_bce_mse",
    aggregate=True,
    metric_kwargs={'value_col': '_bce_sum'} # column produced by the '_bce_sum' categorizer
)
chart_mse.show() # Render the chart in the notebook output

MaxRowsError: The number of rows in your dataset is greater than the maximum allowed (5000).

Try enabling the VegaFusion data transformer which raises this limit by pre-evaluating data
transformations in Python.
    >> import altair as alt
    >> alt.data_transformers.enable("vegafusion")

Or, see https://altair-viz.github.io/user_guide/large_datasets.html for additional information
on how to plot large datasets.

alt.FacetChart(...)