In [29]:
from datasets import load_from_disk
import pandas as pd
import scipy.stats as stats
from tqdm import tqdm

In [3]:
# Load the previously saved rewards dataset from its on-disk directory.
rewards = load_from_disk("REBEL/rewards")

In [4]:
# Materialise the dataset as a pandas DataFrame; the trailing expression
# displays the available column names as the cell output.
df = rewards.to_pandas()
df.columns

Index(['prompt_uid', 'prompt_category', 'prompt',
       'claude-3-haiku-20240307_response', 'claude-3-opus-20240229_response',
       'claude-3-sonnet-20240229_response', 'command-r_response',
       'command-r-plus_response', 'dbrx-instruct_response',
       'gpt-3.5-turbo-0125_response', 'gpt-4-turbo-2024-04-09_response',
       'llama-3-70b-instruct_response', 'mistral-large_response',
       'mistral-medium_response', 'mistral-small_response',
       'mixtral-8x7b-instruct_response',
       'response_mixtral-8x7b-instruct_reward',
       'response_mixtral-8x7b-instruct_by_objective',
       'response_mistral-small_reward', 'response_mistral-small_by_objective',
       'response_mistral-medium_reward',
       'response_mistral-medium_by_objective',
       'response_gpt-3.5-turbo-0125_reward',
       'response_gpt-3.5-turbo-0125_by_objective',
       'response_mistral-large_reward', 'response_mistral-large_by_objective',
       'response_gpt-4-turbo-2024-04-09_reward',
       'respo

In [26]:

def kruskal_for_category_metric(df_long, category, metric):
    """Kruskal-Wallis p-value across models for one (category, metric) pair.

    Parameters
    ----------
    df_long : pandas.DataFrame
        Long-format frame with the columns 'prompt_category', 'metric',
        'model' and 'value'.
    category
        Value matched against the 'prompt_category' column.
    metric
        Value matched against the 'metric' column.

    Returns
    -------
    float
        The p-value of ``scipy.stats.kruskal`` computed over one sample of
        'value' per model. NaN is returned when the test is undefined:
        fewer than two models have data for the pair, or every observation
        in the pair is identical.
    """
    selected = (
        (df_long['prompt_category'] == category)
        & (df_long['metric'] == metric)
    )
    # Rows without a model label are dropped up front; groupby would skip
    # them anyway, and this keeps the identical-values check below in sync
    # with the samples that are actually tested.
    sub = df_long.loc[selected, ['model', 'value']].dropna(subset=['model'])

    # One sample of values per model.
    groups = [grp['value'].values for _, grp in sub.groupby('model')]
    # At least two non-empty groups are required for the test.
    valid = [g for g in groups if len(g) > 0]
    if len(valid) < 2:
        return float('nan')

    # kruskal raises ValueError when all observations are identical (the tie
    # correction is zero); report that undefined case as NaN instead of
    # aborting the caller's loop over every (category, metric) pair.
    if sub['value'].nunique() < 2:
        return float('nan')

    return stats.kruskal(*valid).pvalue

In [None]:

# 1. Columns that hold the per-objective scores of each model's response.
response_cols = [col for col in df.columns if col.startswith("response_") and col.endswith("_by_objective")]

# 2. Build the "long" DataFrame: one row per (prompt, model, metric) score.
records = []
for _, row in df.iterrows():
    uid = row['prompt_uid']
    category = row['prompt_category']
    for col in response_cols:
        # Model name = column name without the "response_" prefix and the
        # "_by_objective" suffix.
        model = col[len("response_"):-len("_by_objective")]
        metric_dict = row[col]
        # Skip cells that are empty or not a dict.
        if not isinstance(metric_dict, dict):
            continue
        for metric, value in metric_dict.items():
            records.append({
                'prompt_uid': uid,
                'prompt_category': category,
                'model': model,
                'metric': metric,
                'value': value
            })

# Explicit columns keep the frame well-formed even when no record was
# collected (otherwise the column lookups below raise KeyError).
long_df = pd.DataFrame(
    records,
    columns=['prompt_uid', 'prompt_category', 'model', 'metric', 'value'],
)

# 3. Unique categories and metrics, in order of first appearance.
categories = long_df['prompt_category'].unique()
metrics = long_df['metric'].unique()


# Grid of p-values; pairs that never occur in the data stay NaN.
pvals_kruskal = pd.DataFrame(index=categories, columns=metrics, dtype=float)

# Split the long frame once per (category, metric) pair instead of
# re-filtering the whole frame for every cell, and report progress on a
# single bar rather than one bar per category.
pair_groups = long_df.groupby(['prompt_category', 'metric'], sort=False)
for (cat, met), pair_df in tqdm(pair_groups, total=pair_groups.ngroups, desc="kruskal"):
    pvals_kruskal.at[cat, met] = kruskal_for_category_metric(pair_df, cat, met)

# Ensure a float dtype for the comparisons that follow.
p_values = pvals_kruskal.astype(float)

100%|██████████| 19/19 [00:01<00:00, 18.16it/s]
100%|██████████| 19/19 [00:01<00:00, 18.78it/s]
100%|██████████| 19/19 [00:01<00:00, 18.72it/s]
100%|██████████| 19/19 [00:01<00:00, 18.77it/s]
100%|██████████| 19/19 [00:01<00:00, 18.62it/s]
100%|██████████| 19/19 [00:01<00:00, 18.68it/s]
100%|██████████| 19/19 [00:01<00:00, 18.58it/s]
100%|██████████| 19/19 [00:01<00:00, 18.69it/s]
100%|██████████| 19/19 [00:01<00:00, 18.53it/s]
100%|██████████| 19/19 [00:01<00:00, 18.61it/s]
100%|██████████| 19/19 [00:01<00:00, 18.73it/s]
100%|██████████| 19/19 [00:01<00:00, 18.77it/s]
100%|██████████| 19/19 [00:01<00:00, 18.74it/s]
100%|██████████| 19/19 [00:01<00:00, 18.73it/s]
100%|██████████| 19/19 [00:01<00:00, 18.69it/s]
100%|██████████| 19/19 [00:01<00:00, 18.78it/s]
100%|██████████| 19/19 [00:01<00:00, 18.63it/s]
100%|██████████| 19/19 [00:01<00:00, 18.79it/s]
100%|██████████| 19/19 [00:01<00:00, 18.70it/s]
100%|██████████| 19/19 [00:01<00:00, 18.70it/s]
100%|██████████| 19/19 [00:01<00:00, 18.

In [34]:

# Flag every (category, metric) cell whose p-value exceeds 0.99;
# NaN cells compare as False.
mask = p_values.gt(0.99)

# Single boolean: does the flag hold for every cell of the grid?
mask.to_numpy().all()

np.True_