# In [None]:
from src.analyzer import LogprobAnalyzer, AnalysisConfig
from src.metrics import pairwise_mse_of_group, single_evidence_estimate
import altair as alt
import pandas as pd

# Build the analyzer from the CSV of logged log-probabilities.
logprob_data_path = "data/logprobs.csv"
analyzer = LogprobAnalyzer(logprob_data_path)


# Rename mapping for nicer labels: Hugging Face model ids -> short plot labels.
# Entries are listed in the order the model facets should appear
# (model family first, then increasing size).
model_labels = {
    'openai-community/gpt2': 'GPT-2',
    'openai-community/gpt2-medium': 'GPT-2-M',
    'openai-community/gpt2-large': 'GPT-2-L',
    'openai-community/gpt2-xl': 'GPT-2-XL',
    'meta-llama/Llama-3.2-1B': 'Llama3.2-1B',
    'meta-llama/Llama-3.2-3B': 'Llama3.2-3B',
    'meta-llama/Llama-3.1-8B': 'Llama3.1-8B',
}

# The sort order must use the *renamed* labels. It is derived from the mapping
# (dicts preserve insertion order) so every model gets an explicit position.
# The previous hand-written list ranked only 'GPT-2-M' and 'Llama3.2-1B'.
# NOTE(review): that looked like an oversight; confirm the partial list was not
# intended to restrict which models are plotted.
model_order = list(model_labels.values())

analyzer.add_rename_mapping('model_name', model_labels)
analyzer.set_sort_order('model_name', model_order)


def _bce_sum_row(prior, likelihood, posterior):
    """Combine one row's three log-probabilities into ``prior + likelihood - posterior``.

    The arguments are presumably passed positionally in ``source_columns``
    order (prior, likelihood, posterior log-probabilities). If any of them is
    missing (NaN/None/NA), the row yields ``pd.NA`` instead of a number.
    """
    if any(pd.isna(value) for value in (prior, likelihood, posterior)):
        return pd.NA
    return prior + likelihood - posterior


# Register the per-row sum under an explicit output column name.
analyzer.add_categorizer(
    output_columns='_bce_sum',
    source_columns=('prior_logprob', 'likelihood_logprob', 'posterior_logprob'),
    categorizer=_bce_sum_row,
)

# --- Plot 1: Variance BCE ---
# Line chart, one facet per model. `variance(...)` in the y and tooltip
# encodings is Altair's aggregate shorthand, so the chart itself reduces the
# evidence estimates to one variance per (model, temperature).
config_var = AnalysisConfig(
    fig_title="BCE (Variance method) by Model and Temperature",
    plot_fn=alt.Chart.mark_line,
    facet_category='model_name:N',  # facet_columns left unset: analyzer picks the default
    x_category='temperature:Q',
    y_category='variance(single_evidence_estimate):Q',
    titles={
        'variance(single_evidence_estimate)': 'BCE (Variance method)',
        'model_name': 'Language Model',
        'temperature': 'Temperature'
    },
    tooltip_fields=[
        alt.Tooltip('model_name:N', title='Model'),
        alt.Tooltip('temperature:Q', title='Temp'),
        alt.Tooltip('variance(single_evidence_estimate):Q', title='BCE', format=".3f"),
        alt.Tooltip("count():Q", title="Count", format="d"),
    ],
    legend_config={"orient": "top"},
    interactive_chart=False,
)

# The metric is presumably evaluated per row (aggregate=False) from the three
# log-probability columns named below; the chart then takes the variance.
chart_var = analyzer.visualize(
    metric=single_evidence_estimate,
    metric_name="single_evidence_estimate",
    metric_kwargs=dict(
        log_prior_col='prior_logprob',
        log_likelihood_col='likelihood_logprob',
        log_posterior_col='posterior_logprob',
    ),
    aggregate=False,
    config=config_var,
)
chart_var.show()


# --- Plot 2: Pairwise MSE BCE ---
# Box plot of the pairwise-MSE BCE values per temperature, one facet per model.
config_mse = AnalysisConfig(
    plot_fn=alt.Chart.mark_boxplot,
    fig_title="BCE (Pairwise MSE method) by Model and Temperature",
    # Temperature is encoded as ordinal so each temperature gets its own band.
    # With a quantitative x, Vega-Lite draws each box at a fixed pixel width on
    # a continuous axis, so boxes for nearby temperatures can overlap.
    x_category='temperature:O',
    y_category='pairwise_bce_mse:Q',
    facet_category='model_name:N',
    # facet_columns is left unset so the analyzer computes its default layout.
    tooltip_fields=[
        alt.Tooltip('model_name:N', title='Model'),
        alt.Tooltip('temperature:Q', title='Temp'),
        alt.Tooltip('mean(pairwise_bce_mse):Q', title='Mean BCE', format=".3f"),
        alt.Tooltip("count():Q", title="Count", format="d"),
    ],
    titles={
        'pairwise_bce_mse': 'BCE (Pairwise MSE method)',
        'model_name': 'Language Model',
        'temperature': 'Temperature'
    },
    interactive_chart=False,
    legend_config={"orient": "top"}
)

# pairwise_mse_of_group is applied with aggregate=True to the `_bce_sum` column
# (the categorizer output registered earlier in this script).
chart_mse = analyzer.visualize(
    config=config_mse,
    metric=pairwise_mse_of_group,
    metric_name="pairwise_bce_mse",
    aggregate=True,
    metric_kwargs={'value_col': '_bce_sum'}
)
chart_mse.show()