In [None]:
import copy
import json
import os
import re

In [None]:
def load_json(file_path):
    """Read a JSON document from disk and return the decoded Python object.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file content (dict, list, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON interchange text is UTF-8; relying on the platform default encoding
    # mis-decodes (or fails on) non-ASCII content on non-UTF-8 locales.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)

In [None]:
def modify_input(data, split_pattern, replacement):
    """Return a deep copy of ``data`` with every item's ``'input'`` re-terminated.

    For each item, the ``'input'`` text is truncated just before the first
    occurrence of ``split_pattern`` (the whole text is kept when the pattern
    is absent) and ``replacement`` is appended. ``data`` itself is not modified.

    Args:
        data: Iterable of dict-like items, each with an ``'input'`` string.
        split_pattern: Marker at which the input text is cut off.
        replacement: Text appended after the kept prefix.

    Returns:
        The modified deep copy of ``data``.
    """
    rewritten = copy.deepcopy(data)
    for record in rewritten:
        # Only the text before the first marker is needed, so one split suffices.
        kept_prefix = record['input'].split(split_pattern, 1)[0]
        record['input'] = '{}{}'.format(kept_prefix, replacement)
    return rewritten

In [None]:
def prepare_for_summary_stage(train_summarize, valid, test):
    """Prepare data for the summary-stage-only ablation study.

    Each split is passed through ``modify_input`` so that its ``'input'`` is
    cut at ``'\\nRecall:'`` and ends with ``'\\nSummarize:'`` instead.

    Args:
        train_summarize: Training items for the summarize stage.
        valid: Validation items.
        test: Test items.

    Returns:
        A 3-tuple ``(train, valid, test)`` of modified deep copies, in the
        same order as the arguments.
    """
    cut_marker = '\nRecall:'
    new_suffix = '\nSummarize:'
    return tuple(
        modify_input(split, cut_marker, new_suffix)
        for split in (train_summarize, valid, test)
    )

In [None]:
def prepare_for_without_analysis(train_recall, train_summarize, valid, test):
    """Prepare data for the "without Analysis stage" ablation study.

    The recall training data is deep-copied unchanged. The summarize training
    inputs are cut at ``'\\nAnalyze:'`` and re-terminated with
    ``'\\nSummarize:'``. ``valid`` and ``test`` are handed back as the very
    same objects (no copy is made).

    Returns:
        A 4-tuple ``(recall_train, summarize_train, valid, test)``.
    """
    return (
        copy.deepcopy(train_recall),
        modify_input(train_summarize, '\nAnalyze:', '\nSummarize:'),
        valid,
        test,
    )

In [None]:
def prepare_for_without_recall(train_analyze, train_summarize, valid, test):
    """Prepare data for the "without Recall stage" ablation study.

    Training data: in deep copies of ``train_analyze`` and ``train_summarize``
    every span from ``'Recall:'`` up to and including the next
    ``'\\nAnalyze:'`` is replaced by ``'Analyze:'``.

    Evaluation data: every ``valid`` / ``test`` item is expanded into one item
    per option letter. Each expanded item's ``'input'`` is the original text
    before ``'\\nRecall:'`` followed by ``'\\nAnalyze: For option <letter>,'``.
    Option letters are the distinct ``(A)``-style markers found in the first
    validation input (or the first test input when ``valid`` is empty); the
    same letters are used for every item.

    Args:
        train_analyze: Training items for the analyze stage (``'input'`` key).
        train_summarize: Training items for the summarize stage.
        valid: Validation items.
        test: Test items.

    Returns:
        A 4-tuple ``(analyze_train, summarize_train, valid_expanded,
        test_expanded)``. None of the arguments is modified.
    """
    # DOTALL lets the lazy ``.*?`` cross line breaks; without it a Recall
    # section containing a newline would silently stay in the prompt.
    recall_section = re.compile(r'Recall:.*?\nAnalyze:', re.DOTALL)

    def strip_recall(data):
        # Deep-copy so the caller's training data is left untouched.
        stripped = copy.deepcopy(data)
        for item in stripped:
            item['input'] = recall_section.sub('Analyze:', item['input'])
        return stripped

    ablation_analysis = strip_recall(train_analyze)
    ablation_AS = strip_recall(train_summarize)

    # Read the option letters from the first available evaluation prompt.
    # An empty split yields no options instead of an IndexError, and the set
    # prevents a letter that appears twice from duplicating evaluation items.
    reference = valid if len(valid) > 0 else test
    if len(reference) > 0:
        options = sorted(set(re.findall(r'\(([A-Z])\)', reference[0]["input"])))
    else:
        options = []

    def generate_ablation_data(data, split_pattern, prefix):
        # One output item per (input item, option letter) pair.
        new_data = []
        for item in data:
            input_sentence = item['input'].split(split_pattern)[0]
            for option in options:
                new_item = copy.deepcopy(item)
                new_item['input'] = f'{input_sentence}{prefix} For option {option},'
                new_data.append(new_item)
        return new_data

    ablation_AS_valid = generate_ablation_data(valid, '\nRecall:', '\nAnalyze:')
    ablation_AS_test = generate_ablation_data(test, '\nRecall:', '\nAnalyze:')

    return ablation_analysis, ablation_AS, ablation_AS_valid, ablation_AS_test

In [None]:
# Dataset whose SFT files are turned into ablation variants.
dataset_name = "strategyqa"
base_path = f"{dataset_name}/final/sft"

# Load the five files of the recall_analyze_summarize directory, in the
# order recall / analyze / summarize (training stages), then valid and test.
train_recall, train_analyze, train_summarize, valid, test = (
    load_json(f"{base_path}/recall_analyze_summarize/{split_name}.json")
    for split_name in ("recall", "analyze", "summarize", "valid", "test")
)

In [None]:
# Summary-only ablation: train / valid / test for the summarize stage alone.
(
    ablation_summarize,
    ablation_summarize_valid,
    ablation_summarize_test,
) = prepare_for_summary_stage(train_summarize, valid, test)

# Ablation without the Analysis stage (Recall + Summarize, hence "RS").
(
    ablation_recall,
    ablation_RS,
    ablation_RS_valid,
    ablation_RS_test,
) = prepare_for_without_analysis(train_recall, train_summarize, valid, test)

# Ablation without the Recall stage (Analyze + Summarize, hence "AS").
(
    ablation_analyze,
    ablation_AS,
    ablation_AS_valid,
    ablation_AS_test,
) = prepare_for_without_recall(train_analyze, train_summarize, valid, test)

In [None]:
def remove_duplicates(data_list, unique_keys):
    """Drop items whose values under ``unique_keys`` were already seen.

    Two items count as duplicates when the tuple of their values for
    ``unique_keys`` is equal. The first occurrence is kept and the original
    order is preserved. Items are not copied.

    Args:
        data_list: Iterable of dict-like items.
        unique_keys: Keys whose (hashable) values identify an item.

    Returns:
        A new list holding the first item for each distinct identifier.

    Raises:
        KeyError: If an item lacks one of ``unique_keys``.
        TypeError: If one of the identifying values is unhashable.
    """
    first_by_identifier = {}
    for record in data_list:
        identifier = tuple(record[field] for field in unique_keys)
        # setdefault stores only the first record per identifier; dicts keep
        # insertion order, so first-seen order is retained.
        first_by_identifier.setdefault(identifier, record)
    return list(first_by_identifier.values())

In [None]:
# De-duplicate the training sets before writing them out.
ablation_summarize = remove_duplicates(ablation_summarize, ['input'])
# The recall+summarize training set is bound to `ablation_RS` at notebook
# level; `ablation_KS` exists only inside prepare_for_without_analysis, so
# referring to it here raised NameError on a fresh kernel.
ablation_RS = remove_duplicates(ablation_RS, ['input'])
# NOTE(review): assumes every analyze item carries an 'analyze' key - confirm
# against the data schema; a missing key raises KeyError here.
ablation_analyze = remove_duplicates(ablation_analyze, ['input', 'analyze'])

# Destination file -> payload, grouped by ablation setting.
ablation_outputs = {
    # Summarize stage only.
    f'{base_path}/summarize/summarize.json': ablation_summarize,
    f'{base_path}/summarize/valid.json': ablation_summarize_valid,
    f'{base_path}/summarize/test.json': ablation_summarize_test,
    # Without the Analysis stage.
    f'{base_path}/recall_summarize/recall.json': ablation_recall,
    f'{base_path}/recall_summarize/summarize.json': ablation_RS,
    f'{base_path}/recall_summarize/valid.json': ablation_RS_valid,
    f'{base_path}/recall_summarize/test.json': ablation_RS_test,
    # Without the Recall stage.
    f'{base_path}/analyze_summarize/analyze.json': ablation_analyze,
    f'{base_path}/analyze_summarize/summarize.json': ablation_AS,
    f'{base_path}/analyze_summarize/valid.json': ablation_AS_valid,
    f'{base_path}/analyze_summarize/test.json': ablation_AS_test,
}

for output_path, payload in ablation_outputs.items():
    # Create the ablation directory on first use so a clean checkout can run
    # the notebook top to bottom.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(payload, f)