In [21]:
import os
import csv
import ast
import random

### Load CSV Rephrased Data

In [4]:
def load_csv_data(file_path, encoding="utf-8"):
    """Read one rephrased-data CSV into a list of row dicts.

    For every data row:
      * ``choices`` is converted from its Python-literal text (e.g. "['a', 'b']")
        with ``ast.literal_eval``.
      * ``concept``, ``name`` and ``option`` are converted from the exact strings
        "True" / "False" to ``bool``.
    All other columns are kept as the strings produced by ``csv.DictReader``.

    Args:
        file_path: path to a CSV file with a header row.
        encoding: text encoding of the file. Passed explicitly so decoding does
            not depend on the platform's locale default.

    Returns:
        list of dicts, one per data row, in file order.

    Raises:
        TypeError: if a flag column holds anything other than "True"/"False".
        KeyError: if the ``choices`` column or a flag column is missing.
        ValueError / SyntaxError: if ``choices`` is not a valid Python literal.
    """
    # Loop-invariant lookups, built once instead of once per row.
    bool_columns = ("concept", "name", "option")
    text_to_bool = {"True": True, "False": False}
    data_list = []

    with open(file_path, newline='', encoding=encoding) as csvfile:
        csv_reader = csv.DictReader(csvfile)

        for row in csv_reader:
            # literal_eval only evaluates Python literals, so it is safe on file data.
            row["choices"] = ast.literal_eval(row["choices"])

            for param in bool_columns:
                value = row[param]
                if value not in text_to_bool:
                    # Keep TypeError (callers may catch it) but report the offending
                    # value and location so bad rows can be found quickly.
                    raise TypeError(
                        f"{param} data cannot be recognized: {value!r} "
                        f"(line {csv_reader.line_num} of {file_path})"
                    )
                row[param] = text_to_bool[value]

            data_list.append(row)

    return data_list

def load_all_rephrase_data(split, dir_path, file_name):
    """Load one rephrased CSV per split name.

    For each name ``s`` in ``split``, the file ``f"{dir_path}/{s}{file_name}"``
    is parsed with ``load_csv_data``.

    Returns:
        dict mapping each split name to its list of parsed rows, with keys
        inserted in the iteration order of ``split``.
    """
    return {
        split_name: load_csv_data(f"{dir_path}/{split_name}{file_name}")
        for split_name in split
    }


split = ["validation", "test", "train"]

# Load both rephrasing versions for every split. Later cells sample v1 and v2
# with shared row indices, so both must cover the same splits.
v1_data = load_all_rephrase_data(
    split,
    dir_path="91123",
    file_name="_rephrased_clean_name_2_91123.csv",
)
v2_data = load_all_rephrase_data(
    split,
    dir_path="91223",
    file_name="_rephrased_name_91223.csv",
)

### Sample Rephrased Data

In [34]:
def save_data(samples, file_path, fieldnames=None, encoding="utf-8"):
    """Write a list of row dicts to ``file_path`` as a CSV with a header row.

    Args:
        samples: list of dicts that share the same keys (e.g. rows returned by
            ``load_csv_data``). Non-string values are written via ``str()``, so
            lists and bools become their Python reprs.
        file_path: destination path. Its parent directory is created if missing.
        fieldnames: optional column order. Defaults to the keys of the first row.
        encoding: text encoding for the output file.

    If ``samples`` is empty and ``fieldnames`` is not given, an empty file is
    written. Previously this raised IndexError on ``samples[0]``, which happens
    whenever a small category samples zero rows.
    """
    if fieldnames is None:
        fieldnames = list(samples[0].keys()) if samples else None

    # Create the output directory (e.g. ./eval) instead of failing on open().
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_path, 'w', newline='', encoding=encoding) as csvfile:
        # With no rows and no explicit columns there is no header to write.
        if fieldnames is not None:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(samples)

    print(f'CSV file "{file_path}" has been created with the data.')

def sample_data(split_type, v1_data, v2_data, idxs, file_name, fraction=0.5,
                seed=None, out_dir="./eval"):
    """Sample the same row indices from v1 and v2 of a split and save both.

    Args:
        split_type: split key into ``v1_data`` / ``v2_data`` (e.g. "train").
        v1_data, v2_data: dicts mapping split name -> list of rows.
        idxs: candidate row indices (a list) valid for both versions.
        file_name: suffix of the output files, e.g. "name.csv".
        fraction: share of ``idxs`` to keep. The count is truncated with
            ``int()``, so small inputs can yield 0 rows.
        seed: optional seed for a private ``random.Random``, making the sample
            reproducible without touching global random state. ``None`` keeps
            using the module-level ``random`` generator.
        out_dir: directory for the output files.

    Writes ``{out_dir}/v1_{split_type}_{file_name}`` and
    ``{out_dir}/v2_{split_type}_{file_name}``.

    Returns:
        tuple ``(v1_sampled, v2_sampled)`` of the sampled row lists.
    """
    num_sampled = int(len(idxs) * fraction)
    sampler = random if seed is None else random.Random(seed)
    sampled_idxs = sampler.sample(idxs, num_sampled)

    # Identical indices keep v1 and v2 rows aligned one-to-one.
    v1_sampled = [v1_data[split_type][i] for i in sampled_idxs]
    v2_sampled = [v2_data[split_type][i] for i in sampled_idxs]

    save_data(v1_sampled, f"{out_dir}/v1_{split_type}_{file_name}")
    save_data(v2_sampled, f"{out_dir}/v2_{split_type}_{file_name}")

    return v1_sampled, v2_sampled

def _category_indices(rows):
    """Bucket the row indices of one split by which rephrasing flags are set.

    Buckets are checked in priority order, so they are disjoint:
      * "both":    concept and option are set
      * "concept": concept set, option not set
      * "option":  option set, concept not set
      * "name":    name set, neither concept nor option set
    Rows with none of the three flags set are not placed in any bucket.

    Returns:
        dict with keys in the order "name", "concept", "option", "both", each
        mapping to an ascending list of row indices.
    """
    buckets = {"name": [], "concept": [], "option": [], "both": []}
    for idx, row in enumerate(rows):
        if row["concept"] and row["option"]:
            buckets["both"].append(idx)
        elif row["concept"]:
            buckets["concept"].append(idx)
        elif row["option"]:
            buckets["option"].append(idx)
        elif row["name"]:
            buckets["name"].append(idx)
    return buckets


def sample_test_val_data(split_type, v1_data, v2_data, train_fraction=0.1, eval_fraction=0.5):
    """Print per-category counts for one split and save v1/v2 CSV samples.

    Rows are grouped into "name" / "concept" / "option" / "both" categories
    (see ``_category_indices``). For each category, ``sample_data`` writes the
    same sampled rows from v1 and v2.

    Args:
        split_type: split key into ``v1_data`` / ``v2_data``.
        v1_data, v2_data: dicts mapping split name -> list of parsed rows.
        train_fraction: fraction sampled when ``split_type == "train"``.
        eval_fraction: fraction sampled for every other split.

    Raises:
        AssertionError: if any category selects different rows in v1 and v2.
    """
    v1_categories = _category_indices(v1_data[split_type])
    v2_categories = _category_indices(v2_data[split_type])

    # sample_data indexes v1 and v2 with the same row indices, so every
    # category must select exactly the same rows in both versions.
    for category, v1_idxs in v1_categories.items():
        v2_idxs = v2_categories[category]
        assert v1_idxs == v2_idxs, (
            f"{split_type}: '{category}' rows differ between v1 and v2 "
            f"({len(v1_idxs)} vs {len(v2_idxs)} rows)"
        )

    print(f"Statistics for split: {split_type}")
    print(f"Name only cases: {len(v1_categories['name'])}")
    print(f"Concept only cases: {len(v1_categories['concept'])}")
    print(f"Option only cases: {len(v1_categories['option'])}")
    print(f"Both concept and option cases: {len(v1_categories['both'])}")

    fraction = train_fraction if split_type == "train" else eval_fraction
    for s_type, idxs in v1_categories.items():
        sample_data(split_type, v1_data, v2_data, idxs, f"{s_type}.csv", fraction=fraction)

In [35]:
# Only the train split is sampled in this cell; sample_test_val_data keeps 10% of each train category.
sample_test_val_data(split_type="train", v1_data=v1_data, v2_data=v2_data)

Statistics for split: train
Name only cases: 901
Concept only cases: 277
Option only cases: 816
Both concept and option cases: 168
CSV file "./eval/v1_train_name.csv" has been created with the data.
CSV file "./eval/v2_train_name.csv" has been created with the data.
CSV file "./eval/v1_train_concept.csv" has been created with the data.
CSV file "./eval/v2_train_concept.csv" has been created with the data.
CSV file "./eval/v1_train_option.csv" has been created with the data.
CSV file "./eval/v2_train_option.csv" has been created with the data.
CSV file "./eval/v1_train_both.csv" has been created with the data.
CSV file "./eval/v2_train_both.csv" has been created with the data.
