# Benchmark Results Summary

This notebook consolidates all benchmark results into three tables:
1. **Debias Methods** - Selection bias correction methods
2. **PU Methods** - Positive-Unlabeled learning methods
3. **Debias+PU Methods** - Combined methods (placeholder)

In [43]:
import pandas as pd
import yaml
from pathlib import Path
from IPython.display import display

In [None]:
# Model categories
DEBIAS_MODELS = [
    'naive', 'ips', 'dr', 'mtips', 'mtdr', 'sdr2'
]

PU_MODELS = [
    'bpr', 'ubpr', 'cubpr', 'nnpu', 'upu', 'pu_naive',
    'uprl', 'rmf', 'ncrmf'
]

# Placeholder for combined methods
DEBIAS_PU_MODELS = ['counterif']

# Datasets and metrics
DATASETS = ['hs', 'saferlhf', 'ufb']
METRICS = ['AUROC', 'NLL', 'MAE', 'RMSE']

# Results directory
RESULTS_DIR = Path('../results/cache')

In [45]:
def load_all_results(cache_dir: Path) -> dict:
    """
    Load all performance.yaml files from the cache directory.
    
    Returns:
        dict: {(model_name, dataset_name): {metric: value, ...}}
    """
    results = {}
    
    for perf_file in cache_dir.glob('*/*/performance.yaml'):
        model = perf_file.parent.parent.name
        dataset = perf_file.parent.name
        
        # Skip debug directories
        if 'debug' in dataset:
            continue
        
        try:
            with open(perf_file) as f:
                data = yaml.safe_load(f)
            results[(model, dataset)] = data
        except Exception as e:
            print(f"Warning: Failed to load {perf_file}: {e}")
    
    return results


def build_results_table(results: dict, models: list, datasets: list, metrics: list) -> pd.DataFrame:
    """
    Build a results table for a specific set of models with MultiIndex columns.
    
    Args:
        results: Dictionary of results from load_all_results
        models: List of model names to include
        datasets: List of dataset names
        metrics: List of metric names
    
    Returns:
        DataFrame with models as rows and (dataset, metric) MultiIndex columns
    """
    # Create MultiIndex columns
    columns = pd.MultiIndex.from_product([datasets, metrics], names=['Dataset', 'Metric'])
    
    # Build data
    data = []
    for model in models:
        row = []
        for dataset in datasets:
            for metric in metrics:
                key = (model, dataset)
                if key in results:
                    metric_key = f"{metric} on test"
                    value = results[key].get(metric_key, None)
                    row.append(round(value, 4) if value is not None else None)
                else:
                    row.append(None)
        data.append(row)
    
    df = pd.DataFrame(data, index=models, columns=columns)
    df.index.name = 'Model'
    
    return df


def highlight_best(df: pd.DataFrame, lower_is_better: list = ['NLL', 'MAE', 'RMSE']):
    """
    Highlight the best value in each column.
    For metrics in lower_is_better, highlight the minimum; otherwise highlight the maximum.
    """
    def highlight_col(s):
        metric = s.name[1] if isinstance(s.name, tuple) else s.name
        if metric in lower_is_better:
            is_best = s == s.min()
        else:
            is_best = s == s.max()
        return ['font-weight: bold' if v else '' for v in is_best]
    
    # Apply styles with centered Dataset header
    styled = df.style.format(precision=4, na_rep='-').apply(highlight_col, axis=0)
    # Center the top-level (Dataset) header
    styled = styled.set_table_styles([
        {'selector': 'th.col_heading.level0', 'props': [('text-align', 'center')]},
    ])
    return styled


# Load all results
all_results = load_all_results(RESULTS_DIR)
print(f"Loaded {len(all_results)} result files")
print(f"Models found: {set(k[0] for k in all_results.keys())}")
print(f"Datasets found: {set(k[1] for k in all_results.keys())}")

Loaded 0 result files
Models found: set()
Datasets found: set()


## Table 1: Debias Methods

In [40]:
debias_table = build_results_table(all_results, DEBIAS_MODELS, DATASETS, METRICS)
display(highlight_best(debias_table))

Dataset,hs,hs,hs,hs,saferlhf,saferlhf,saferlhf,saferlhf,ufb,ufb,ufb,ufb
Metric,AUROC,NLL,MAE,RMSE,AUROC,NLL,MAE,RMSE,AUROC,NLL,MAE,RMSE
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
naive,0.8159,0.8909,0.5255,0.5832,0.974,0.5278,0.3011,0.4421,-,-,-,-
ips,0.8143,2.1394,0.6581,0.7863,0.965,1.3238,0.4262,0.6254,-,-,-,-
dr,0.8143,0.5773,0.4146,0.4442,0.9727,0.328,0.2256,0.3287,-,-,-,-
mtips,0.8149,0.7044,0.5025,0.5058,0.973,0.5378,0.2961,0.4411,-,-,-,-
mtdr,0.8139,0.7928,0.5286,0.5471,0.981,0.2465,0.085,0.2433,-,-,-,-
sdr2,0.8142,0.9554,0.5389,0.6043,0.9752,0.5922,0.3156,0.4686,-,-,-,-


## Table 2: PU Methods

In [41]:
pu_table = build_results_table(all_results, PU_MODELS, DATASETS, METRICS)
display(highlight_best(pu_table))

Dataset,hs,hs,hs,hs,saferlhf,saferlhf,saferlhf,saferlhf,ufb,ufb,ufb,ufb
Metric,AUROC,NLL,MAE,RMSE,AUROC,NLL,MAE,RMSE,AUROC,NLL,MAE,RMSE
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
bpr,0.8046,2.2775,0.6504,0.7808,0.933,0.3381,0.1944,0.3224,-,-,-,-
ubpr,0.7366,4.1236,0.6811,0.8227,0.9658,4.0292,0.4587,0.6772,-,-,-,-
cubpr,0.7366,4.1236,0.6811,0.8227,0.9658,4.0292,0.4587,0.6772,-,-,-,-
nnpu,0.8135,0.6469,0.4622,0.4805,0.982,0.3963,0.2791,0.3727,-,-,-,-
upu,0.8122,0.646,0.2381,0.4128,0.8976,1.6887,0.1062,0.3258,-,-,-,-
pu_naive,0.8213,0.5181,0.383,0.4115,0.9787,0.2433,0.1748,0.2666,-,-,-,-
uprl,0.7595,0.5991,0.3387,0.4517,0.9667,0.8162,0.348,0.5058,-,-,-,-
rmf,0.6825,7.9722,0.6842,0.8271,0.952,6.8655,0.4588,0.6773,-,-,-,-
ncrmf,0.6825,7.9722,0.6842,0.8271,0.952,6.8655,0.4588,0.6773,-,-,-,-


No Debias+PU models defined yet. Add model names to DEBIAS_PU_MODELS list when available.


Unnamed: 0_level_0,hs_AUROC,hs_NLL,hs_NDCG,hs_Recall,saferlhf_AUROC,saferlhf_NLL,saferlhf_NDCG,saferlhf_Recall,ufb_AUROC,ufb_NLL,ufb_NDCG,ufb_Recall
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1


No Debias+PU models defined yet. Add model names to DEBIAS_PU_MODELS list when available.


Unnamed: 0_level_0,hs_AUROC,hs_NLL,hs_NDCG,hs_Recall,saferlhf_AUROC,saferlhf_NLL,saferlhf_NDCG,saferlhf_Recall,ufb_AUROC,ufb_NLL,ufb_NDCG,ufb_Recall
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1


No Debias+PU models defined yet. Add model names to DEBIAS_PU_MODELS list when available.


Unnamed: 0_level_0,hs_AUROC,hs_NLL,hs_NDCG,hs_Recall,saferlhf_AUROC,saferlhf_NLL,saferlhf_NDCG,saferlhf_Recall,ufb_AUROC,ufb_NLL,ufb_NDCG,ufb_Recall
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1


## Table 3: Debias+PU Methods (Placeholder)

In [42]:
if DEBIAS_PU_MODELS:
    debias_pu_table = build_results_table(all_results, DEBIAS_PU_MODELS, DATASETS, METRICS)
    display(highlight_best(debias_pu_table))
else:
    print("No Debias+PU models defined yet. Add model names to DEBIAS_PU_MODELS list when available.")
    # Create empty placeholder table with MultiIndex columns
    columns = pd.MultiIndex.from_product([DATASETS, METRICS], names=['Dataset', 'Metric'])
    debias_pu_table = pd.DataFrame(columns=columns)
    debias_pu_table.index.name = 'Model'
    display(debias_pu_table)

Dataset,hs,hs,hs,hs,saferlhf,saferlhf,saferlhf,saferlhf,ufb,ufb,ufb,ufb
Metric,AUROC,NLL,MAE,RMSE,AUROC,NLL,MAE,RMSE,AUROC,NLL,MAE,RMSE
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
counterif,0.813,0.8932,0.5254,0.5831,0.9614,0.5612,0.3186,0.4445,-,-,-,-
