In [None]:
import pandas as pd
from datasets import load_dataset, Dataset, load_from_disk
from tqdm import tqdm
import os 

# Project root. "~" is expanded up front because os.makedirs / open do not expand
# it themselves (os.makedirs would create a literal "~" directory).
PROJ_DIR = os.path.expanduser("~/Dialect_Bias/")

# Define functions that build one rule-transformed dataset per dialect rule

In [33]:
import json 

# Load the rule catalogue; its top-level keys are the rule names used below
# (the values are not used in this cell).
with open("data/attestA_rules.json", "r") as f:
    attest_a_rules = json.load(f)
rule_list = [rule_name for rule_name in attest_a_rules]

In [34]:
def safe_save(save_df, save_dir, file_name):
    """Write ``save_df`` to ``save_dir/file_name`` as CSV (without the index).

    Args:
        save_df: DataFrame to write.
        save_dir: Target directory; created (with parents) if missing. An empty
            string means the current working directory.
        file_name: Name of the CSV file inside ``save_dir``.
    """
    if save_dir:
        # exist_ok avoids the check-then-create race, and skipping "" avoids
        # os.makedirs("") raising FileNotFoundError.
        os.makedirs(save_dir, exist_ok=True)
    save_df.to_csv(os.path.join(save_dir, file_name), index=False)
    

def generate_rule_transformed_dataset(df, dialect, row_to_transform="question", rows_to_save = ["id", "context", "answers"], save_dir="data/oblig_rule_transforms"):
    """Apply ``dialect`` to one text column of ``df`` and save the result as CSV.

    Args:
        df: Source DataFrame; must contain an "id" column.
        dialect: Object exposing ``transform(text)``, ``executed_rules`` (read
            right after each transform) and ``dialect_name``.
        row_to_transform: Name of the column whose text is transformed.
        rows_to_save: Columns of ``df`` merged back onto the output by "id".
            "id" is added automatically if missing, since it is the merge key.
        save_dir: Directory that receives "<dialect_name>.csv".

    Returns:
        (save_df, failed_ids): ``save_df`` has columns id, transformed_text,
        rules_executed, the remaining ``rows_to_save`` columns, and
        rule_transform (= ``dialect.dialect_name``). ``failed_ids`` lists the ids
        whose transform raised; those rows are omitted from ``save_df``.
    """
    # Copy so the shared default list is never mutated; ensure the merge key is present.
    merge_cols = list(rows_to_save)
    if "id" not in merge_cols:
        merge_cols.insert(0, "id")

    transformed_texts = []
    rules_executed = []
    ids_so_far = []
    failed_ids = []

    for _, row in tqdm(df.iterrows(), total=df.shape[0]):
        try:
            transformed = dialect.transform(row[row_to_transform])
            executed = dialect.executed_rules
            row_id = row["id"]
        except Exception as e:
            # Best-effort: record the failure and keep going.
            print(e)
            failed_ids.append(row["id"])
            continue
        # Append only after every step succeeded, so the three lists stay aligned.
        # NOTE(review): if `executed_rules` is a single object mutated in place by
        # each transform, these entries would alias — confirm it is rebound per call.
        transformed_texts.append(transformed)
        rules_executed.append(executed)
        ids_so_far.append(row_id)

    save_df = pd.DataFrame({"id": ids_so_far, "transformed_text": transformed_texts, "rules_executed": rules_executed})
    save_df = save_df.merge(df[merge_cols], on="id")
    save_df["rule_transform"] = dialect.dialect_name
    safe_save(save_df, save_dir, f"{dialect.dialect_name}.csv")
    return save_df, failed_ids

def build_dialects_for_each_rule(rule_list, df, row_to_transform="question", rows_to_save = ["id", "context", "answers"], save_dir="data/oblig_rule_transforms"):
    """Build and save one single-rule dialect transform of ``df`` per rule.

    Args:
        rule_list: Rule (feature) names; each becomes its own dialect.
        df: Source DataFrame with an "id" column.
        row_to_transform: Column whose text is transformed.
        rows_to_save: Columns of ``df`` merged back onto each output.
        save_dir: Directory receiving one "<rule>.csv" per rule.

    Returns:
        (dialect_df_list, all_failed_ids): one saved DataFrame per rule, and the
        concatenated ids that failed across all rules (may contain repeats).
    """
    dialect_df_list = []
    all_failed_ids = []

    for rule in tqdm(rule_list):
        # NOTE(review): `Dialects` is not defined/imported in the visible cells —
        # confirm it is in scope before running.
        dialect = Dialects.DialectFromFeatureList(feature_list=[rule], dialect_name=rule)
        dialect_df, failed_ids = generate_rule_transformed_dataset(
            df,
            dialect,
            # Previously not forwarded, so a non-default column was silently ignored.
            row_to_transform=row_to_transform,
            rows_to_save=rows_to_save,
            save_dir=save_dir,
        )
        all_failed_ids.extend(failed_ids)
        dialect_df_list.append(dialect_df)

    return dialect_df_list, all_failed_ids

def load_in_rules(rule_list, save_dir, pair=False):
    """Load and concatenate the per-rule CSVs written by the transform step.

    Args:
        rule_list: Entries are either a rule name (str), read from "<rule>.csv",
            or a rule combination (list/tuple of str), read from
            "<rule1>+<rule2>+....csv".
        save_dir: Directory containing the CSVs.
        pair: Kept for backward compatibility. Combinations are detected from
            the entry type, so a plain string is never split into characters.

    Returns:
        One DataFrame with a fresh 0..n-1 index (each CSV is read with its own
        index starting at 0, so a plain concat would repeat labels).

    Raises:
        ValueError: if ``rule_list`` is empty.
    """
    frames = []
    for rule in rule_list:
        file_stem = rule if isinstance(rule, str) else "+".join(rule)
        frames.append(pd.read_csv(os.path.join(save_dir, f"{file_stem}.csv")))
    if not frames:
        raise ValueError("rule_list is empty; nothing to load")
    return pd.concat(frames, ignore_index=True)

def get_slice_with_exec(rules_df, exec_col = "rules_executed"):
    """Return the rows of ``rules_df`` where at least one rule was executed.

    A row counts as executed when its ``exec_col`` value is present and is not
    the string "{}" (what an empty executed-rules dict becomes after a CSV
    round-trip).

    Args:
        rules_df: DataFrame of transformed rows (e.g. from ``load_in_rules``).
        exec_col: Column holding the stringified executed-rules dict. It was
            previously ignored in favour of a hard-coded "rules_executed".

    Returns:
        A new DataFrame (an explicit copy, so callers may add or overwrite
        columns without ``SettingWithCopyWarning``).
    """
    executed = rules_df[exec_col]
    return rules_df[executed.notna() & (executed != "{}")].copy()



# Load and combine the transformed datasets

In [4]:
import pandas as pd
import os 

def load_in_rules(rule_list, save_dir, pair=False):
    """Load and concatenate the per-rule CSVs written by the transform step.

    Args:
        rule_list: Entries are either a rule name (str), read from "<rule>.csv",
            or a rule combination (list/tuple of str), read from
            "<rule1>+<rule2>+....csv".
        save_dir: Directory containing the CSVs.
        pair: Kept for backward compatibility. Combinations are detected from
            the entry type, so a plain string is never split into characters
            and a list is never formatted as "['a', 'b'].csv".

    Returns:
        One DataFrame with a fresh 0..n-1 index (each CSV is read with its own
        index starting at 0, so a plain concat would repeat labels).

    Raises:
        ValueError: if ``rule_list`` is empty.
    """
    frames = []
    for rule in rule_list:
        file_stem = rule if isinstance(rule, str) else "+".join(rule)
        frames.append(pd.read_csv(os.path.join(save_dir, f"{file_stem}.csv")))
    if not frames:
        raise ValueError("rule_list is empty; nothing to load")
    return pd.concat(frames, ignore_index=True)

def get_slice_with_exec(rules_df, exec_col = "rules_executed"):
    """Return the rows of ``rules_df`` where at least one rule was executed.

    A row counts as executed when its ``exec_col`` value is present and is not
    the string "{}" (what an empty executed-rules dict becomes after a CSV
    round-trip).

    Args:
        rules_df: DataFrame of transformed rows (e.g. from ``load_in_rules``).
        exec_col: Column holding the stringified executed-rules dict. It was
            previously ignored in favour of a hard-coded "rules_executed".

    Returns:
        A new DataFrame (an explicit copy, so callers may add or overwrite
        columns without ``SettingWithCopyWarning``).
    """
    executed = rules_df[exec_col]
    return rules_df[executed.notna() & (executed != "{}")].copy()

def safe_save(save_df, save_dir, file_name):
    """Write ``save_df`` to ``save_dir/file_name`` as CSV (without the index).

    Args:
        save_df: DataFrame to write.
        save_dir: Target directory; created (with parents) if missing. An empty
            string means the current working directory.
        file_name: Name of the CSV file inside ``save_dir``.
    """
    if save_dir:
        # exist_ok avoids the check-then-create race, and skipping "" avoids
        # os.makedirs("") raising FileNotFoundError.
        os.makedirs(save_dir, exist_ok=True)
    save_df.to_csv(os.path.join(save_dir, file_name), index=False)



In [23]:
# Which benchmark and which rule-combination family to load.
dataset_name = "mmlu"
rules = "pair"

# Earlier combinations built around "existential_it" (kept for reference):
# rule_list = [["existential_it", "drop_copula_be_NP"], 
#              ["existential_it", "drop_aux_wh"],
#              ["existential_it", "drop_aux_yn"],
#              ["existential_it", "negative_concord"],
#              ["existential_it", "remove_det_indefinite"],
#              ["existential_it", "plural_interrogative"],
#              ["existential_it", "remove_det_definite"],
#              ]
# Note: the "pair" directory also holds the 3-rule combination below.
rule_list = [
    ["null_prepositions", "drop_copula_be_NP"],
    ["null_prepositions", "one_relativizer"],
    ["null_prepositions", "one_relativizer", "drop_copula_be_NP"],
]

save_dir = f"data/{dataset_name}/{rules}_transforms"

loaded_in_rules = load_in_rules(rule_list, save_dir=save_dir, pair=True)
slice_with_exec = get_slice_with_exec(loaded_in_rules, exec_col="rules_executed")
combined_dir = os.path.join(save_dir, "combined")

if dataset_name == "boolq":
    # Additional post-processing: map boolean answers to "yes"/"no".
    # `assign` builds a new frame instead of writing into a filtered slice
    # (which triggers SettingWithCopyWarning). `== True` is kept deliberately so
    # truthy non-boolean values (e.g. the string "False") do not become "yes".
    slice_with_exec = slice_with_exec.assign(
        answer=slice_with_exec["answer"].apply(lambda x: "yes" if x == True else "no")
    )

In [19]:
import ast 

def clean_executed_rules(x):
    """Extract the rule "type" of every executed rule from a rules_executed value.

    Args:
        x: The executed-rules mapping, either already parsed (dict) or as its
            string repr read back from CSV. Each value is expected to be a
            mapping with a "type" key.

    Returns:
        List of the "type" values in the mapping's iteration order. Empty for
        "{}", the empty string, and missing values (e.g. NaN from read_csv).
    """
    # Check emptiness BEFORE parsing: ast.literal_eval raises on "" and on NaN.
    if isinstance(x, dict):
        rule_dict = x
    elif isinstance(x, str) and x.strip():
        # literal_eval only evaluates Python literals, so it is safe on CSV text.
        rule_dict = ast.literal_eval(x)
    else:
        return []
    return [v["type"] for v in rule_dict.values()]

def get_multi_rule(rules_df, rule_list):
    """Keep rows where every rule of their combination appears to have fired.

    Rows are selected per combination in ``rule_list``: the row's
    "rule_transform" must equal the "+"-joined combination, and the number of
    distinct executed rule types must equal the number of rules in it.

    Args:
        rules_df: DataFrame with "rules_executed" and "rule_transform" columns.
            It is not modified.
        rule_list: List of rule combinations (each a list of rule names).

    Returns:
        Concatenation of the per-combination slices, in ``rule_list`` order, with
        two added columns: "rules_executed_list" (parsed rule types) and
        "rule_executed_set" (sorted distinct types; sorted so the saved CSV does
        not depend on set iteration order). Empty (same columns) if
        ``rule_list`` is empty.
    """
    executed_lists = rules_df["rules_executed"].apply(clean_executed_rules)
    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    annotated = rules_df.assign(
        rules_executed_list=executed_lists,
        rule_executed_set=executed_lists.apply(lambda types: sorted(set(types))),
    )

    n_distinct = annotated["rule_executed_set"].apply(len)
    rule_slice_list = []
    for rules in rule_list:
        transform_name = "+".join(rules)
        mask = (annotated["rule_transform"] == transform_name) & (n_distinct == len(rules))
        rule_slice_list.append(annotated[mask])

    if not rule_slice_list:
        return annotated.iloc[0:0]
    return pd.concat(rule_slice_list)

In [24]:
# Keep rows whose number of distinct executed rule types matches the size of their rule combination.
multi_rule_df = get_multi_rule(rules_df=slice_with_exec, rule_list=rule_list)

In [29]:
# Persist the filtered multi-rule rows as the combined test split.
safe_save(save_df=multi_rule_df, save_dir=combined_dir, file_name="test.csv")

In [28]:
# Distinct subjects covered by the multi-rule rows, in order of first appearance.
pd.unique(multi_rule_df["subject"])

array(['high_school_government_and_politics', 'security_studies',
       'sociology', 'high_school_european_history', 'college_biology',
       'high_school_psychology', 'astronomy', 'electrical_engineering',
       'logical_fallacies', 'nutrition', 'high_school_biology',
       'high_school_macroeconomics', 'virology', 'machine_learning',
       'jurisprudence', 'professional_psychology', 'abstract_algebra',
       'econometrics', 'high_school_mathematics',
       'high_school_computer_science', 'philosophy', 'college_chemistry',
       'human_sexuality', 'high_school_chemistry', 'human_aging',
       'anatomy', 'management', 'college_medicine', 'computer_security',
       'marketing', 'conceptual_physics', 'medical_genetics',
       'public_relations', 'world_religions', 'high_school_us_history',
       'international_law', 'professional_law', 'high_school_physics',
       'moral_disputes', 'high_school_world_history',
       'professional_medicine', 'miscellaneous',
       'high_sch