In [2]:
# File of per-entry overlap metrics; read later with one JSON record per line,
# each structured into an EntryOverlapMetric.
path = 'metrics/filtered_pile_metrics_all'
# File of aggregate overlap stats; read later with one JSON record per line,
# each structured into a DataOverlapStats.
aggregate_stats_path = 'output_stats_pile_all'

import json
import cattrs
import pandas as pd

from data_overlap_spec import DataOverlapStats 
from collections import defaultdict

In [3]:
def scenario_spec_to_class(scenario_spec) -> str:
    """Return the bare class name of a scenario spec.

    Everything up to and including the last dot of ``scenario_spec.class_name``
    is dropped, e.g. ``'pkg.mod.MATHScenario'`` becomes ``'MATHScenario'``.
    A name without any dot is returned unchanged.
    """
    _, _, bare_class_name = scenario_spec.class_name.rpartition('.')
    return bare_class_name

In [4]:
# Only overlap stats computed with this n-gram size are aggregated below.
OVERLAP_NGRAM_N = 13

# Read all records up front; the context manager closes the file handle as soon
# as the lines have been read.
with open(aggregate_stats_path, "r") as aggregate_stats_file:
    output_stats_jsons = aggregate_stats_file.readlines()

# Total number of instances per scenario class name. Despite the variable name,
# the key is the bare class name, because that is what stats_to_row uses to look
# the count up.
scenario_spec_to_instance_count = defaultdict(int)
data_overlap_stats_list = []
for output_stats_json in output_stats_jsons:
    if not output_stats_json.strip():
        # Tolerate blank lines (e.g. a trailing newline), which json.loads rejects.
        continue
    output_stats_dict = json.loads(output_stats_json)
    data_overlap_stats = cattrs.structure(output_stats_dict, DataOverlapStats)
    data_overlap_stats_key = data_overlap_stats.data_overlap_stats_key
    light_scenario_key = data_overlap_stats_key.light_scenario_key
    scenario_spec = light_scenario_key.scenario_spec
    class_name = scenario_spec_to_class(scenario_spec)
    n = data_overlap_stats_key.overlap_protocol_spec.n
    if n != OVERLAP_NGRAM_N:
        continue
    data_overlap_stats_list.append(data_overlap_stats)
    scenario_spec_to_instance_count[class_name] += data_overlap_stats.num_instances


In [6]:
import ast
import json
import cattrs
import pandas as pd
from nltk import ngrams
from collections import defaultdict
from typing import List, Tuple
from dataclasses import dataclass

from data_overlap_spec import DataOverlapStats, DataOverlapStatsKey, EntryOverlapNgrams
from compute_data_overlap_metrics import load_light_scenarios_from_jsonl
from common.util import get_tokenizer
from common.general import asdict_without_nones

from enum import Enum
%matplotlib inline



In [9]:
@dataclass(frozen=True)
class EntryDataOverlapKey:
    """Identifies one part (the input or the references) of a single scenario instance."""

    # Key of the overlap stats this entry belongs to.
    stats_key: DataOverlapStatsKey
    # Which part of the instance is keyed: either PART_INPUT or PART_REF.
    part: str
    # Identifier of the instance.
    instance_id: str


# Input: List[EntryOverlapNgrams]
@dataclass(frozen=True)
class EntryOverlapNgrams:
    """Overlapping n-gram counts recorded for one entry, identified by its EntryDataOverlapKey.

    NOTE(review): this definition shadows the EntryOverlapNgrams imported from
    data_overlap_spec earlier in the notebook — confirm the redefinition is intended.
    """

    entry_data_overlap_key: EntryDataOverlapKey

    # (str, int) pairs; presumably (n-gram, count) — TODO confirm against the producer.
    overlapping_ngram_counts: List[Tuple[str, int]]


class PartialOverlapSpec(int, Enum):
    """Selects how partial overlap is scored; members are also plain ints."""

    binary = 0
    jaccard = 1
    token = 2

    def __str__(self) -> str:
        """Render the member as its bare name (e.g. ``binary``)."""
        return self.name

@dataclass(frozen=True)
class FrequencySpec:
    """Frequency-based filtering and weighting options used when computing a metric."""

    # Filter ngrams with frequency >= filter_value; 0 means no filter
    filter_value: int
    # Whether to apply weight; we'll do inverse frequency
    weighting: bool
        
@dataclass(frozen=True)
class MetricProtocolSpec:
    """Specification for how we compute the metric"""

    # How partial overlap is scored (binary, jaccard or token).
    partial_overlap_spec: PartialOverlapSpec
    # Frequency filter threshold and whether weighting is applied.
    frequency_spec: FrequencySpec
        
@dataclass(frozen=True)
class OverlapMetric:
    """A metric score together with the protocol spec it was computed under."""

    metric_score: float  # use 0/1 for binary, can revise as needed
    metric_protocol_spec: MetricProtocolSpec

# Output: List[EntryOverlapMetric]
@dataclass(frozen=True)
class EntryOverlapMetric:
    """Overlap metric for a single entry, identified by its EntryDataOverlapKey."""

    # Which instance part the metric belongs to.
    entry_data_overlap_key: EntryDataOverlapKey

    # The score and the protocol spec it was computed under.
    overlap_metric: OverlapMetric
        
        

In [10]:

# Load the per-entry overlap metrics, one JSON record per line. The context
# manager closes the file handle as soon as the lines have been read.
with open(path, "r") as overlap_metrics_file:
    overlap_metrics_jsons = overlap_metrics_file.readlines()

entry_overlap_metric_list = []
for entry_overlap_metric_json in overlap_metrics_jsons:
    if not entry_overlap_metric_json.strip():
        # Tolerate blank lines (e.g. a trailing newline), which json.loads rejects.
        continue
    entry_overlap_metric_dict = json.loads(entry_overlap_metric_json)
    entry_overlap_metric_list.append(cattrs.structure(entry_overlap_metric_dict, EntryOverlapMetric))
    

In [11]:
# Inspect the first parsed record as a sanity check on the loaded structure.
entry_overlap_metric_list[0]

EntryOverlapMetric(entry_data_overlap_key=EntryDataOverlapKey(stats_key=DataOverlapStatsKey(light_scenario_key=LightScenarioKey(scenario_spec=ScenarioSpec(class_name='helm.benchmark.scenarios.math_scenario.MATHScenario', args={'subject': 'intermediate_algebra', 'level': 1, 'use_official_examples': False, 'use_chain_of_thought': True}), split='test'), overlap_protocol_spec=OverlapProtocolSpec(n=13)), part='references', instance_id='id103'), overlap_metric=OverlapMetric(metric_score=1.0, metric_protocol_spec=MetricProtocolSpec(partial_overlap_spec=<PartialOverlapSpec.binary: 0>, frequency_spec=FrequencySpec(filter_value=0, weighting=False))))

In [12]:
# Scores grouped first by (class_name, part) and then by the metric protocol
# that produced them:
#   {(class_name, part): {metric_protocol_spec: [metric_score, ...]}}
score_dict = {}

for entry_overlap_metric in entry_overlap_metric_list:
    scenario_spec = entry_overlap_metric.entry_data_overlap_key.stats_key.light_scenario_key.scenario_spec
    class_name = scenario_spec_to_class(scenario_spec)

    metric_protocol_spec = entry_overlap_metric.overlap_metric.metric_protocol_spec
    metric_score = entry_overlap_metric.overlap_metric.metric_score
    part = entry_overlap_metric.entry_data_overlap_key.part

    # stats_to_row unpacks every key as (class_name, part), so group on that
    # pair rather than on the full scenario_spec.
    key = (class_name, part)

    # setdefault creates the nested dict / list the first time a key is seen.
    score_dict.setdefault(key, {}).setdefault(metric_protocol_spec, []).append(metric_score)


In [13]:
def metric_to_label(metric: MetricProtocolSpec) -> str:
    """Build a label of the form ``"<partial overlap>, <filter value> <weighting>"``."""
    frequency = metric.frequency_spec
    label_parts = (
        str(metric.partial_overlap_spec),
        f"{frequency.filter_value} {frequency.weighting}",
    )
    return ", ".join(label_parts)

def scenario_spec_to_str(scenario_spec) -> str:
    """Render a scenario spec as ``"<module>.<Class>, {arg: value, ...}"``.

    Only the last two dot-separated components of ``class_name`` are kept, and
    the args are listed in the iteration order of ``scenario_spec.args``.
    """
    qualified_tail = '.'.join(scenario_spec.class_name.split('.')[-2:])
    rendered_args = ', '.join(f'{name}: {value}' for name, value in scenario_spec.args.items())
    return qualified_tail + ', {' + rendered_args + '}'



In [16]:
import numpy

In [10]:

# Values taken by EntryDataOverlapKey.part. The loaded records spell the
# reference part "references" (plural), so the constant uses that spelling.
PART_INPUT: str = "input"
PART_REF: str = "references"

# Metric protocols reported as columns, in column order: for each frequency
# filter value (0 = no filter, then 10) the binary score, then jaccard and
# token scores, each unweighted first and weighted second.
metric_protocol_specs_list = [
    MetricProtocolSpec(partial_overlap_spec, FrequencySpec(filter_value, weighting))
    for filter_value in (0, 10)
    for partial_overlap_spec, weighting in (
        (PartialOverlapSpec.binary, False),
        (PartialOverlapSpec.jaccard, False),
        (PartialOverlapSpec.jaccard, True),
        (PartialOverlapSpec.token, False),
        (PartialOverlapSpec.token, True),
    )
]

# Leading identifier columns; one label per metric protocol is appended later.
cols = ['class_name', 'part', 'num_instances']
# Accumulates one output row per score_dict key.
res = []
# {key: {metric_spec: sum of scores / num_instances}}, filled in by stats_to_row.
scenario_spec_to_results = defaultdict(dict)
def stats_to_row(key, metric_scores):
    """Build one output row for a ``(class_name, part)`` key.

    Args:
        key: ``(class_name, part)`` pair.
        metric_scores: mapping from metric spec to the list of scores recorded
            under that spec.

    Returns:
        ``[class_name, part, num_instances]`` followed by one normalised score
        (sum of the scores divided by ``num_instances``) per entry of
        ``metric_protocol_specs_list``, in that order. A spec with no recorded
        scores is treated as an empty sum and yields ``0.0``.

    Raises:
        ZeroDivisionError: if no instances are recorded for ``class_name``.
    """
    class_name, part = key
    # .get() rather than indexing: indexing a defaultdict would silently insert
    # a zero count for a class name that was never counted.
    num_instances = scenario_spec_to_instance_count.get(class_name, 0)
    if not num_instances:
        raise ZeroDivisionError(
            f"no instances recorded for {class_name!r}; cannot normalise its overlap scores"
        )

    row_results = scenario_spec_to_results[key]
    for metric_spec, values in metric_scores.items():
        row_results[metric_spec] = sum(values) / num_instances

    ret = [class_name, part, num_instances]
    for metric_spec in metric_protocol_specs_list:
        # A protocol with no scores for this key is an empty sum, i.e. 0.0.
        ret.append(row_results.get(metric_spec, 0.0))
    return ret



# One column label per metric protocol, in the order the row values are emitted.
cols.extend(metric_to_label(protocol_spec) for protocol_spec in metric_protocol_specs_list)

# One output row per score_dict entry.
res.extend(stats_to_row(row_key, row_scores) for row_key, row_scores in score_dict.items())

    

In [11]:
# Assemble the collected rows into a table whose columns are the identifier
# columns followed by one column per metric label.
agg_metrics_df = pd.DataFrame(res, columns=cols)

In [12]:
# Persist the aggregated table; index=False keeps the row index out of the CSV.
agg_metrics_df.to_csv('the_pile_metrics_class.csv', index=False)

In [13]:
# Display the aggregated table.
agg_metrics_df

Unnamed: 0,class_name,part,num_instances,"binary, 0 False","jaccard, 0 False","jaccard, 0 True","token, 0 False","token, 0 True","binary, 10 False","jaccard, 10 False","jaccard, 10 True","token, 10 False","token, 10 True"
0,MATHScenario,references,17778,0.000225,2.6e-05,2.3e-05,6.7e-05,5.2e-05,0.000225,2.6e-05,2.3e-05,6.7e-05,5.2e-05
1,MATHScenario,input,17778,0.000562,0.000134,0.0001,0.000317,0.000199,0.000562,0.000134,0.0001,0.000317,0.000199
2,RAFTScenario,input,550,0.112727,0.046415,0.031255,0.061694,0.046487,0.105455,0.039375,0.030887,0.054882,0.046049
3,SummarizationScenario,input,49683,0.01417,0.001822,0.001558,0.002907,0.002626,0.014069,0.00179,0.001556,0.002901,0.002625
4,ICEScenario,input,13089,0.030102,0.000525,0.000383,0.000965,0.000771,0.028421,0.000476,0.000381,0.00092,0.000766
5,BoolQScenario,input,22188,0.025329,0.009753,0.007059,0.01382,0.011654,0.025014,0.009406,0.007043,0.013614,0.01163
6,TwitterAAEScenario,input,100000,7e-05,4.3e-05,1.2e-05,5.9e-05,1.8e-05,6e-05,3.4e-05,1.1e-05,5e-05,1.7e-05
7,SummarizationScenario,references,49683,0.002174,0.001167,0.001017,0.00149,0.001354,0.002154,0.001163,0.001017,0.001484,0.001354
8,IMDBScenario,input,75000,0.000333,0.000129,0.000128,0.000152,0.00015,0.000333,0.000128,0.000128,0.000152,0.00015
9,CivilCommentsScenario,input,692436,0.000703,0.000184,8.6e-05,0.00027,0.000175,0.000646,0.00013,8.4e-05,0.000235,0.000172


In [14]:
# Rank the rows by the binary overlap score with filter value 0 (no frequency
# filter) and no weighting, highest first.
sorted_df = agg_metrics_df.sort_values(by='binary, 0 False', ascending=False)

# Display the ranked table.
sorted_df

Unnamed: 0,class_name,part,num_instances,"binary, 0 False","jaccard, 0 False","jaccard, 0 True","token, 0 False","token, 0 True","binary, 10 False","jaccard, 10 False","jaccard, 10 True","token, 10 False","token, 10 True"
34,NarrativeQAScenario,input,1572,0.149491,0.044497,0.038621,0.059716,0.056266,0.148219,0.044351,0.038612,0.059671,0.056263
27,MSMARCOScenario,references,9022,0.1135,0.01057,0.007853,0.01805,0.014817,0.113279,0.00989,0.007824,0.018033,0.014808
2,RAFTScenario,input,550,0.112727,0.046415,0.031255,0.061694,0.046487,0.105455,0.039375,0.030887,0.054882,0.046049
18,QuACScenario,input,12567,0.065012,0.013977,0.01145,0.021143,0.01944,0.064932,0.013827,0.01144,0.02106,0.019372
24,NaturalQAScenario,input,8578,0.057939,0.022707,0.015525,0.031078,0.02535,0.057356,0.021549,0.015465,0.030788,0.025338
17,ThePileScenario,input,21455,0.048893,0.004632,0.001468,0.006373,0.003448,0.042694,0.001928,0.001449,0.005133,0.003382
4,ICEScenario,input,13089,0.030102,0.000525,0.000383,0.000965,0.000771,0.028421,0.000476,0.000381,0.00092,0.000766
20,CopyrightScenario,references,6977,0.029812,0.004327,0.002994,0.005929,0.004864,0.029096,0.00371,0.002972,0.005877,0.004854
23,CodeScenario,input,10164,0.026368,0.002346,0.002151,0.00373,0.00338,0.026072,0.002304,0.00215,0.003702,0.003374
5,BoolQScenario,input,22188,0.025329,0.009753,0.007059,0.01382,0.011654,0.025014,0.009406,0.007043,0.013614,0.01163
