In [9]:
import json
from light_scenario import ScenarioSpecInstanceIds
import cattrs
# Scenario class name (with the trailing "Scenario" removed, as done by
# scenario_spec_to_dataset_name) -> display name(s) of the dataset(s) it covers.
# A one-element list is returned directly by the lookup; a longer list means
# the concrete dataset must be resolved through `dataset_mapping`.
mapping = {
    "BabiQA": ["bAbI"],
    "BBQ": ["BBQ"],
    "BLiMP": ["BLiMP"],
    "BOLD": ["BOLD"],
    "BoolQ": ["BoolQ"],
    "CivilComments": ["CivilComments"],
    "Code": ["HumanEval", "APPS"],
    "Copyright": ["Copyright"],
    "CommonSense": ["CommonSenseQA", "HellaSwag", "PIQA", "SIQA"],
    "Disinformation": ["Disinformation - HELM"],
    "DyckLanguage": ["DyckLanguage"],
    "EntityDataImputation": ["EntityDataImputation"],
    "EntityMatching": ["EntityMatching"],
    "GSM8K": ["GSM8K"],
    "ICE": ["ICE"],
    "IMDB": ["IMDB"],
    "LegalSupport": ["LegalSupport"],
    "LSAT": ["LSAT"],
    "MATH": ["MATH"],
    "MMLU": ["MMLU"],
    "MSMARCO": ["MS MARCO"],
    "NarrativeQA": ["NarrativeQA"],
    "NaturalQA": ["Natural Questions"],
    "QuAC": ["QuAC"],
    "RAFT": ["RAFT"],
    "RealToxicityPrompts": ["RealToxicityPrompts"],
    "Summarization": ["XSum", "CNN/Daily Mail"],
    "SyntheticEfficiency": ["SyntheticEfficiency"],
    "SyntheticReasoning": ["SyntheticReasoning"],
    # NOTE(review): "Synethetic" looks like a typo for "Synthetic"; it is kept
    # as-is because this exact string shows up in the notebook's saved output.
    "SRN": ["SynetheticReasoningNatural"],
    "ThePile": ["The Pile"],
    "TruthfulQA": ["TruthfulQA"],
    "TwitterAAE": ["TwitterAAE"],
    "WIKIFact": ["WikiFact"],
}

# Value of a scenario's `dataset_name` / `dataset` argument -> display name.
# Only consulted for scenario classes that map to several datasets above.
# NOTE(review): "OpenBookQA" is resolved here but is not listed under
# "CommonSense" in `mapping` -- confirm whether that list should include it.
dataset_mapping = {
    "humaneval": "HumanEval",
    "apps": "APPS",
    "xsum-sampled": "XSum",
    "cnn-dm": "CNN/Daily Mail",
    "hellaswag": "HellaSwag",
    "openbookqa": "OpenBookQA",
}

def scenario_spec_to_dataset_name(scenario_spec):
    """Return the display name of the dataset described by ``scenario_spec``.

    The last component of ``scenario_spec.class_name`` (a dotted path) is
    taken, a trailing ``"Scenario"`` is removed, and the result is looked up
    in the module-level ``mapping``.  When that entry holds a single name it
    is returned directly; otherwise the scenario's ``dataset_name`` argument
    (or ``dataset`` if the former is absent) is translated through
    ``dataset_mapping``.

    Args:
        scenario_spec: object exposing ``class_name`` (dotted string) and
            ``args`` (a mapping of scenario arguments).

    Returns:
        The dataset display name as a string.

    Raises:
        KeyError: if the class name is not in ``mapping``, if a multi-dataset
            scenario carries neither ``dataset_name`` nor ``dataset``, or if
            the dataset argument is not in ``dataset_mapping``.
    """
    suffix = 'Scenario'
    # Keep only the class name, not the full module path.
    class_name = scenario_spec.class_name.split('.')[-1]
    # Strip the suffix only when it is really there; a blind [:-8] slice would
    # silently chop eight arbitrary characters off any other class name.
    if class_name.endswith(suffix):
        class_name = class_name[:-len(suffix)]

    if class_name not in mapping:
        raise KeyError(
            f"No dataset mapping for scenario class {scenario_spec.class_name!r} "
            f"(looked up as {class_name!r})"
        )
    dataset_names = mapping[class_name]

    # A single candidate needs no further disambiguation.
    if len(dataset_names) == 1:
        return dataset_names[0]

    args = scenario_spec.args  # ScenarioSpec args
    # 'dataset_name' takes precedence over 'dataset' when both are present.
    if 'dataset_name' in args:
        dataset = args['dataset_name']
    elif 'dataset' in args:
        dataset = args['dataset']
    else:
        raise KeyError(
            f"Scenario class {class_name!r} covers several datasets but its args "
            f"contain neither 'dataset_name' nor 'dataset'"
        )

    if dataset not in dataset_mapping:
        raise KeyError(
            f"Unknown dataset {dataset!r} for scenario class {class_name!r}"
        )
    return dataset_mapping[dataset]

# Read the filtered scenario-spec file (one JSON document per line), remember
# the instance ids of every scenario spec, and tally how many instance ids
# belong to each dataset display name.
class_name_to_counts = dict()  # dataset display name -> number of instance ids
scenario_spec_instance_id_dict = dict()  # scenario spec -> its instance ids
scenario_spec_instance_ids_json = 'filtered_scenario_spec_instance_ids.json'

# The context manager closes the handle even if reading fails. The lines are
# still materialised under the original variable name so that any other cell
# relying on it keeps working.
with open(scenario_spec_instance_ids_json, "r", encoding="utf-8") as scenario_spec_instance_ids_file:
    scenario_spec_instance_ids_jsons = scenario_spec_instance_ids_file.readlines()

# A dedicated loop variable keeps `scenario_spec_instance_ids_json` pointing at
# the file path instead of being overwritten with the last line read.
for scenario_spec_instance_ids_line in scenario_spec_instance_ids_jsons:
    # Tolerate blank lines (e.g. a trailing newline at the end of the file).
    if not scenario_spec_instance_ids_line.strip():
        continue

    scenario_spec_instance_ids_dict = json.loads(scenario_spec_instance_ids_line)
    scenario_spec_instance_ids = cattrs.structure(scenario_spec_instance_ids_dict, ScenarioSpecInstanceIds)
    scenario_spec_instance_id_dict[
        scenario_spec_instance_ids.scenario_spec
    ] = scenario_spec_instance_ids.instance_ids

    # Despite the variable name, this is the dataset display name.
    class_name = scenario_spec_to_dataset_name(scenario_spec_instance_ids.scenario_spec)

    # Report every name that has no count yet.
    if class_name not in class_name_to_counts:
        print(class_name)

    # Copyright specs are stored in scenario_spec_instance_id_dict above but are
    # left out of the counts; as they never get a count, they are printed each
    # time they occur.
    if class_name == 'Copyright':
        continue

    class_name_to_counts[class_name] = (
        class_name_to_counts.get(class_name, 0) + len(scenario_spec_instance_ids.instance_ids)
    )
    

bAbI
BBQ
BLiMP
BOLD
BoolQ
CivilComments
APPS
HumanEval
HellaSwag
OpenBookQA
Copyright
Copyright
Copyright
Copyright
Copyright
Disinformation - HELM
DyckLanguage
EntityDataImputation
EntityMatching
GSM8K
ICE
IMDB
LegalSupport
LSAT
MATH
MMLU
MS MARCO
NarrativeQA
Natural Questions
QuAC
RAFT
RealToxicityPrompts
CNN/Daily Mail
XSum
SyntheticEfficiency
SyntheticReasoning
SynetheticReasoningNatural
The Pile
TruthfulQA
TwitterAAE
WikiFact


In [10]:
# Display the per-dataset instance counts accumulated by the previous cell.
class_name_to_counts

{'bAbI': 4000,
 'BBQ': 1000,
 'BLiMP': 4000,
 'BOLD': 1000,
 'BoolQ': 5128,
 'CivilComments': 45000,
 'APPS': 1000,
 'HumanEval': 164,
 'HellaSwag': 1000,
 'OpenBookQA': 500,
 'Disinformation - HELM': 79,
 'DyckLanguage': 500,
 'EntityDataImputation': 424,
 'EntityMatching': 1400,
 'GSM8K': 1000,
 'ICE': 2939,
 'IMDB': 4421,
 'LegalSupport': 1000,
 'LSAT': 461,
 'MATH': 874,
 'MMLU': 1641,
 'MS MARCO': 2504,
 'NarrativeQA': 2350,
 'Natural Questions': 7876,
 'QuAC': 4321,
 'RAFT': 1348,
 'RealToxicityPrompts': 1000,
 'CNN/Daily Mail': 1000,
 'XSum': 1000,
 'SyntheticEfficiency': 600,
 'SyntheticReasoning': 3000,
 'SynetheticReasoningNatural': 2000,
 'The Pile': 2952,
 'TruthfulQA': 1895,
 'TwitterAAE': 2000,
 'WikiFact': 7927}

In [None]:
# Write one LaTeX table row per row of `sorted_df`: three identifying columns
# rendered as-is, followed by ten metric columns with four decimal places.
label_columns = ('Scenario', 'Part', 'Count')
metric_columns = (
    'Binary',
    'Jaccard',
    'Token',
    'Filtered Binary',
    'Filtered Jaccard',
    'Filtered Token',
    'Weighted Jaccard',
    'Weighted Token',
    'Weighted Filtered Jaccard',
    'Weighted Filtered Token',
)

with open(f'outputs/{dataset}_latex_table.txt', 'w') as output_file:
    for row_index, row in sorted_df.iterrows():
        cells = [f"{row[column]}" for column in label_columns]
        cells.extend(format(row[column], '.4f') for column in metric_columns)
        # " & " separates the LaTeX columns; the two backslashes end the row.
        output_file.write(" & ".join(cells) + " \\\\\n")
