In [50]:
import json
import random
from typing import Any, Dict, List, Tuple

import numpy as np
import tiktoken

In [51]:
# Price table consumed by estimate_price(): for every model name it holds the cost of
# ONE input token and ONE output token (estimate_price multiplies these values by token
# counts and prints the result with a "$" prefix). The trailing comment on each entry
# names the unit the price is divided down from.
MODEL = {
    'gpt-4o-2024-05-13':{
        'input_price': 5/1000000, # price per 1M tokens
        'output_price': 15/1000000 # price per 1M tokens
    },
    'gpt-4-turbo-2024-04-09':{
        'input_price': 10/1000000, # price per 1M tokens
        'output_price': 30/1000000 # price per 1M tokens
    },
    'gpt-4-0125-preview':{ #'gpt-4-1106-preview'
        'input_price': 10/1000000, # price per 1M tokens
        'output_price': 30/1000000 # price per 1M tokens
    },
    'gpt-4-0613':{
        'input_price': 30/1000000, # price per 1M tokens
        'output_price': 60/1000000 # price per 1M tokens
    },
    'gpt-3.5-turbo-0125':{
        'input_price': 0.5/1000000, # price per 1M tokens
        'output_price': 1.5/1000000 # price per 1M tokens
    },
    'gpt-3.5-turbo-1106':{
        'input_price': 0.001/1000, # price per 1K tokens
        'output_price': 0.002/1000 # price per 1K tokens
    },
}

# Prepare Reasoning Request

In [52]:
def read_data(file_path: str) -> List[Dict[str, Any]]:
    """Load a multiple-choice QA file into a uniform list of question records.

    The layout is selected by the file extension:

    * ``jsonl`` -- one JSON object per line with ``question.stem``,
      ``question.choices`` (each having ``label`` and ``text``) and ``answerKey``.
      Blank lines are skipped.
    * ``json`` -- a JSON array of objects with ``question`` and a truthy/falsy
      ``answer``; each becomes a fixed two-way ``"A. yes, B. no"`` question.

    Args:
        file_path: Path of the dataset file.

    Returns:
        A list of dicts with the keys ``'Q'`` (question text), ``'Options'``
        (options rendered as ``"A. text, B. text"``), ``'GT'`` (label of the
        correct option) and ``'Other'`` (labels of all remaining options).

    Raises:
        ValueError: If the extension is neither ``jsonl`` nor ``json``. Returning
            an empty list here would only surface later as a confusing failure.
    """
    file_suffix = file_path.split('.')[-1]
    if file_suffix not in ('jsonl', 'json'):
        raise ValueError(
            f"Unsupported file type '{file_suffix}' for {file_path}; expected .jsonl or .json"
        )

    file_data = []

    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(file_path, 'r', encoding='utf-8') as file:
        if file_suffix == 'jsonl':
            for line in file:
                if not line.strip():
                    # A trailing or stray blank line is not a record.
                    continue
                data = json.loads(line)
                choices = data['question']['choices']
                choices_text = ', '.join(f"{choice['label']}. {choice['text']}" for choice in choices)
                file_data.append({
                    'Q': data['question']['stem'],
                    'Options': choices_text.strip(),
                    'GT': data['answerKey'],
                    'Other': [choice['label'] for choice in choices if choice['label'] != data['answerKey']]
                })
        else:
            for data in json.load(file):
                # Boolean answers are mapped onto a yes/no question: A = yes, B = no.
                file_data.append({
                    'Q': data['question'],
                    'Options': 'A. yes, B. no',
                    'GT': 'A' if data['answer'] else 'B',
                    'Other': ['B' if data['answer'] else 'A']
                })

    return file_data

In [53]:
def reasoning_prompt_system(
    examples: List[Dict[str, Any]],
    request_prompt: Dict[str, Any],
    num_shot: int = 4,
) -> Tuple[List[Dict[str, str]], List[str]]:
    """Build the few-shot system prompt from randomly drawn worked examples.

    Args:
        examples: Pool of worked examples. Each entry must provide ``'Q'``,
            ``'Options'``, ``'Key Information'`` and ``'Explanations'`` (a list
            of lines).
        request_prompt: Mapping whose ``'Key'`` entry is the instruction text
            placed in front of the examples.
        num_shot: How many examples to draw (without replacement).

    Returns:
        A pair ``(system_message, reasoning_message)``:

        * ``system_message`` -- a one-element list holding the chat message
          ``{"role": "system", "content": ...}``.
        * ``reasoning_message`` -- for each drawn example, the text
          ``"Key Information: ...\\nExplanations: ..."``, i.e. what an ideal reply
          to that example looks like. estimate_price() uses these to guess the
          output length.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``num_shot`` exceeds
            ``len(examples)``.
    """
    shots = random.sample(examples, k=num_shot)
    example_blocks = []
    reasoning_message = []

    for shot in shots:
        explanations = '\n'.join(shot['Explanations'])
        key_information = shot['Key Information']
        example_blocks.append(
            "\n###\n"
            f"{shot['Q']}\n"
            f"Options: {shot['Options']}\n"
            f"Key Information: {key_information}\n"
            f"Explanations: {explanations}\n"
            "###\n"
        )
        reasoning_message.append(f"Key Information: {key_information}\nExplanations: " + explanations)

    examples_string = "".join(example_blocks)
    system_message = [
        {
            "role": "system",
            "content": f"{request_prompt['Key']}{examples_string}".strip()
        },
    ]
    return system_message, reasoning_message

In [54]:
def reasoning_prompt_user(data: List[Dict[str, Any]]) -> List[List[Dict[str, str]]]:
    """Render one single-message user conversation per question record.

    Each prompt ends with a partially written answer skeleton
    (``"<GT> is correct. Because"`` followed by one
    ``"<label> is incorrect. Because"`` line per wrong option) for the model to
    complete.

    Args:
        data: Question records with the keys ``'Q'``, ``'Options'``, ``'GT'`` and
            ``'Other'`` (the shape produced by read_data()).

    Returns:
        One entry per record, in input order. Every entry is itself a list with
        exactly one ``{"role": "user", "content": ...}`` message, so it can be
        concatenated directly onto the system-message list.
    """
    user_message = []
    for item in data:
        option_content = "".join(f"{option} is incorrect. Because\n" for option in item['Other'])
        # extract_answer_reasoning() later splits this text on
        # "\nKey Information:\nExplanations: ", so the layout must stay in sync with it.
        user_content = (
            f"{item['Q']}\n"
            f"Options: {item['Options']}\n"
            "Key Information:\n"
            f"Explanations: {item['GT']} is correct. Because\n{option_content}"
        ).strip()
        user_message.append([{"role": "user", "content": user_content}])
    return user_message

In [55]:
def num_tokens_from_messages(messages: List[Dict[str, str]], model: str) -> int:
    """Estimate how many prompt tokens a list of chat messages costs for `model`.

    Every message contributes a fixed per-message overhead plus the encoded
    length of each of its values; a ``"name"`` key adds a per-name adjustment.
    Plain strings in `messages` are accepted too and contribute the overhead
    plus their own encoded length. Two tokens are added once for the reply
    priming.

    Raises:
        NotImplementedError: If `model` is not one of the models listed below.
    """
    # NOTE(review): the encoding is looked up before the model is validated, so a model
    # name tiktoken does not know presumably fails here rather than in the
    # NotImplementedError branch below -- confirm.
    encoding = tiktoken.encoding_for_model(model)

    three_token_models = {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-3.5-turbo-1106",
        'gpt-3.5-turbo-0125',
        "gpt-4-0613",
        "gpt-4-0125-preview",
        "gpt-4-1106-preview",
        "gpt-4-turbo-2024-04-09",
        "gpt-4o-2024-05-13",
    }

    if model in three_token_models:
        # every message follows <|im_start|><im_sep>{content}<|im_end|>
        tokens_per_message, tokens_per_name = 3, 1
    elif model == "gpt-3.5-turbo-0301":
        # every message follows <|start|>{role/name}\n{content}<|end|>\n;
        # if there's a name, the role is omitted (hence the -1)
        tokens_per_message, tokens_per_name = 4, -1
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )

    # every reply is primed with <im_start>assistant<im_sep>
    # NOTE(review): OpenAI's published version of this helper appears to add 3 here for
    # the newer models; this notebook adds 2 -- confirm which is intended.
    total = 2
    for message in messages:
        total += tokens_per_message
        if isinstance(message, dict):
            for key, value in message.items():
                total += len(encoding.encode(value))
                if key == "name":
                    total += tokens_per_name
        elif isinstance(message, str):
            total += len(encoding.encode(message))
    return total

In [56]:
def wrap_reasoning_prompt(
    model_name, 
    system_message, 
    user_message,
    few_shot_message,
    num_prompts = None,
    max_tokens = None,
    num_sample = None,
    temperature = 0.8,
    save_path = None
):
    """Assemble chat-completion request payloads and print a cost estimate.

    Args:
        model_name: Value stored under ``'model'`` in every payload.
        system_message: Message list prepended to each user conversation, or
            ``None`` to send the user conversation on its own.
        user_message: List of user conversations (each a list of messages).
        few_shot_message: Example replies, forwarded to estimate_price().
        num_prompts: Use only the first `num_prompts` conversations; ``None``
            means all of them.
        max_tokens: Stored under ``'max_tokens'`` and forwarded to estimate_price().
        num_sample: Stored under ``'n'`` and forwarded unchanged to estimate_price().
        temperature: Stored under ``'temperature'``.
        save_path: If truthy, the payloads are written there as JSON lines.

    Returns:
        The list of payload dicts, one per selected conversation.
    """
    limit = len(user_message) if num_prompts is None else num_prompts
    selected_messages = user_message[:limit]

    prompts = []
    for message in selected_messages:
        full_messages = message if system_message is None else system_message + message
        prompts.append({
            'model': model_name,
            'messages': full_messages,
            'max_tokens': max_tokens,
            'n': num_sample,
            'temperature': temperature,
        })

    # The estimate is printed before anything is written to disk.
    estimate_price(system_message, selected_messages, few_shot_message, num_sample, max_tokens)

    if save_path:
        with open(save_path, 'w') as file:
            file.writelines(json.dumps(prompt) + '\n' for prompt in prompts)

    return prompts

In [57]:
def estimate_price(
    system_message, 
    user_message,
    few_shot_message,
    num_sample=1, 
    max_tokens=None
):
    """Print an estimated API cost for the given prompts, once per model in MODEL.

    The report covers every model in the MODEL table, not only the one the
    prompts were built for. The report header also reads the notebook-level
    global ``dataset_name``, which must be defined before this is called.

    Args:
        system_message: Shared system-message list, or a falsy value when there
            is none (it then contributes 0 tokens).
        user_message: List of user conversations; one request is assumed per entry.
        few_shot_message: Example reply texts. Only used when `max_tokens` is
            ``None``, as a stand-in for the expected output length.
        num_sample: Completions requested per prompt. ``None`` is treated as 1.
        max_tokens: Output cap per completion. When given, every completion is
            priced as if it used the full cap.

    Returns:
        None. The estimate is only printed.
    """
    # wrap_reasoning_prompt() forwards its own default (None) explicitly, which would
    # make the multiplications below raise TypeError; treat it as a single sample.
    if num_sample is None:
        num_sample = 1

    num_requests = len(user_message)

    for model_name in MODEL.keys():
        print(f'{model_name}/{dataset_name}')
        
        input_price = MODEL[model_name]['input_price']
        output_price = MODEL[model_name]['output_price']

        # calculate input token: shared system prompt + each user conversation
        num_tokens_system_message = num_tokens_from_messages(system_message, model_name) if system_message else 0
        len_input_token = [num_tokens_system_message + num_tokens_from_messages(item, model_name) for item in user_message]
        mean_input_tokens = np.mean(len_input_token)

        input_mean_price = mean_input_tokens * input_price
        input_estimated_price = input_mean_price * num_requests
        print(f'The mean length of GPT input is {mean_input_tokens}')

        # calculate output token
        if max_tokens is None:
            # No cap given: assume replies are about as long as the few-shot answers.
            len_output_token = [num_tokens_from_messages([message], model_name) for message in few_shot_message]
            mean_output_tokens = np.mean(len_output_token)
            output_mean_price = mean_output_tokens * output_price
            print(f'The mean length of output(*num_sample) is {mean_output_tokens * num_sample}')
            print(f'estimated tokens per request is {mean_input_tokens + mean_output_tokens}')
        else:
            # Upper bound: every completion uses the whole cap.
            output_mean_price = max_tokens * output_price
            print(f'The max length of output is {max_tokens}')
            print(f'estimated tokens per request is {mean_input_tokens + max_tokens}')
        
        output_estimated_price = output_mean_price * num_sample * num_requests
        
        # Each line prints "<total over all requests>/<per request, single sample>".
        print(f'estimated input price is ${input_estimated_price}/${input_mean_price}')
        print(f'estimated output price is ${output_estimated_price}/${output_mean_price}')
        print(f'estimated final price is ${input_estimated_price+output_estimated_price}')
        print()

## CSQA

In [58]:
# Dataset key: used as the directory prefix by the path f-strings in the cells below
# and read as a global by estimate_price() for its report header.
dataset_name = "csqa"

In [59]:
# Load the original train and dev splits into the uniform record format of read_data().
training_data = read_data(f'{dataset_name}/original/train_rand_split.jsonl')
validation_data = read_data(f'{dataset_name}/original/dev_rand_split.jsonl')

In [11]:
# Carve a new validation set (same size as the original dev split) out of the training
# data, and use the original dev split as the test set.
# Sampling positions instead of records lets the held-out rows be dropped in one pass
# (no repeated list.remove scans) and removes exactly the sampled rows even if two
# records happen to compare equal.
# NOTE: this cell shrinks `training_data`; re-running it without reloading the data
# holds out a further batch.
held_out_indices = random.sample(range(len(training_data)), k=len(validation_data))
new_validation_data = [training_data[index] for index in held_out_indices]
held_out_lookup = set(held_out_indices)
training_data = [item for index, item in enumerate(training_data) if index not in held_out_lookup]
new_test_data = validation_data

In [12]:
# Few-shot example pool for this dataset; reasoning_prompt_system() reads the 'Q',
# 'Options', 'Key Information' and 'Explanations' fields of each entry.
with open(f'{dataset_name}/prompt/{dataset_name}_cot.json', 'r') as file:
    csqa_examples = json.load(file)

# Instruction text; reasoning_prompt_system() uses its 'Key' entry.
with open('request.json', 'r') as file:
    request_prompt = json.load(file)

In [13]:
# One shared system prompt with 8 randomly drawn few-shot examples, and one user
# conversation per record still left in training_data.
csqa_system_message, reasoning_message = reasoning_prompt_system(csqa_examples, request_prompt, num_shot=8)
csqa_user_message = reasoning_prompt_user(training_data)

In [14]:
# Build the request payloads (8 samples per question, capped at 256 output tokens) and
# print the cost estimate for every model in MODEL. save_path=None: nothing is written.
csqa_example = wrap_reasoning_prompt(
    model_name='gpt-4o-2024-05-13', 
    system_message=csqa_system_message, 
    user_message=csqa_user_message,
    few_shot_message=reasoning_message,
    num_prompts=None,
    max_tokens=256,
    num_sample=8,
    temperature=0.8,
    save_path=None
)

gpt-4o-2024-05-13/csqa
The mean length of GPT input is 1777.4390845070423
The max length of output is 256
estimated tokens per request is 2033.4390845070423
estimated input price is $75.718905/$0.008887195422535212
estimated output price is $261.7344/$0.00384
estimated final price is $337.453305

gpt-4-turbo-2024-04-09/csqa
The mean length of GPT input is 1790.743896713615
The max length of output is 256
estimated tokens per request is 2046.743896713615
estimated input price is $152.57138/$0.01790743896713615
estimated output price is $523.4688/$0.00768
estimated final price is $676.04018

gpt-4-0125-preview/csqa
The mean length of GPT input is 1790.743896713615
The max length of output is 256
estimated tokens per request is 2046.743896713615
estimated input price is $152.57138/$0.01790743896713615
estimated output price is $523.4688/$0.00768
estimated final price is $676.04018

gpt-4-0613/csqa
The mean length of GPT input is 1790.743896713615
The max length of output is 256
estimated 

In [16]:
# Persist the held-out validation records and the test records (the original dev split);
# process_and_save() later reads these two files back.
# NOTE(review): the target directory comes from the global `dataset_name`; this cell's
# execution count (16) is out of sequence with its neighbours, so confirm it ran while
# `dataset_name` was still "csqa".
with open(f'{dataset_name}/prompt/valid_raw.json', 'w') as f:
    json.dump(new_validation_data, f)

with open(f'{dataset_name}/prompt/test_raw.json', 'w') as f:
    json.dump(new_test_data, f)

## StrategyQA

In [15]:
# Switch the dataset key (path prefix and estimate_price() header) to StrategyQA.
dataset_name = "strategyqa"

In [16]:
# A single .json file; read_data() turns each entry into an "A. yes, B. no" question.
training_data = read_data(f'{dataset_name}/original/{dataset_name}_train.json')

In [17]:
# Shuffles in place. No random seed is set anywhere in the visible notebook, so the
# split produced below differs from run to run.
random.shuffle(training_data)

In [18]:
# Split the shuffled data: first 80% train, next 10% validation, and everything left
# (the last 10% plus any integer-division leftovers) test.
train_size = len(training_data) * 8 // 10
other_size = len(training_data) // 10
new_training_data = training_data[:train_size]
validation_data = training_data[train_size:train_size + other_size]
test_data = training_data[train_size + other_size:]

In [19]:
# Few-shot example pool for this dataset; reasoning_prompt_system() reads the 'Q',
# 'Options', 'Key Information' and 'Explanations' fields of each entry.
with open(f'{dataset_name}/prompt/{dataset_name}_cot.json', 'r') as file:
    strategyqa_examples = json.load(file)
    
# Instruction text; reasoning_prompt_system() uses its 'Key' entry.
with open('request.json', 'r') as file:
    request_prompt = json.load(file)

In [20]:
# One shared system prompt with 8 randomly drawn few-shot examples, and one user
# conversation per record of the 80% training slice.
strategyqa_system_message, reasoning_message = reasoning_prompt_system(strategyqa_examples, request_prompt, num_shot=8)
strategyqa_user_message = reasoning_prompt_user(new_training_data)

In [21]:
# Build the request payloads (4 samples per question, capped at 256 output tokens) and
# print the cost estimate for every model in MODEL. save_path=None: nothing is written.
strategyqa_example = wrap_reasoning_prompt(
    model_name='gpt-4o-2024-05-13', 
    system_message=strategyqa_system_message,
    user_message=strategyqa_user_message,
    few_shot_message=reasoning_message,
    num_prompts=None,
    max_tokens=256,
    num_sample=4,
    temperature=0.8,
    save_path=None
)

gpt-4o-2024-05-13/strategyqa
The mean length of GPT input is 1039.1512008733625
The max length of output is 256
estimated tokens per request is 1295.1512008733625
estimated input price is $9.518625000000002/$0.0051957560043668135
estimated output price is $28.13952/$0.00384
estimated final price is $37.658145000000005

gpt-4-turbo-2024-04-09/strategyqa
The mean length of GPT input is 1047.3160480349345
The max length of output is 256
estimated tokens per request is 1303.3160480349345
estimated input price is $19.18683/$0.010473160480349346
estimated output price is $56.27904/$0.00768
estimated final price is $75.46587

gpt-4-0125-preview/strategyqa
The mean length of GPT input is 1047.3160480349345
The max length of output is 256
estimated tokens per request is 1303.3160480349345
estimated input price is $19.18683/$0.010473160480349346
estimated output price is $56.27904/$0.00768
estimated final price is $75.46587

gpt-4-0613/strategyqa
The mean length of GPT input is 1047.316048034934

In [27]:
# Persist the validation and test slices; process_and_save() later reads these files.
# NOTE(review): the target directory comes from the global `dataset_name` -- confirm it
# was still "strategyqa" when this cell ran (its execution count is out of sequence).
with open(f'{dataset_name}/prompt/valid_raw.json', 'w') as f:
    json.dump(validation_data, f)

with open(f'{dataset_name}/prompt/test_raw.json', 'w') as f:
    json.dump(test_data, f)

## OBQA

In [22]:
# Switch the dataset key (path prefix and estimate_price() header) to OBQA.
dataset_name = "obqa"

In [23]:
# This dataset ships with its own train/dev/test files, so no re-splitting is done here.
training_data = read_data(f'{dataset_name}/original/train.jsonl')
validation_data = read_data(f'{dataset_name}/original/dev.jsonl')
test_data = read_data(f'{dataset_name}/original/test.jsonl')

In [24]:
# Few-shot example pool for this dataset; reasoning_prompt_system() reads the 'Q',
# 'Options', 'Key Information' and 'Explanations' fields of each entry.
with open(f'{dataset_name}/prompt/{dataset_name}_cot.json', 'r') as file:
    obqa_examples = json.load(file)
    
# Instruction text; reasoning_prompt_system() uses its 'Key' entry.
with open('request.json', 'r') as file:
    request_prompt = json.load(file)

In [25]:
# One shared system prompt with 7 randomly drawn few-shot examples, and one user
# conversation per training record.
obqa_system_message, reasoning_message = reasoning_prompt_system(obqa_examples, request_prompt, num_shot=7)
obqa_user_message = reasoning_prompt_user(training_data)

In [26]:
# Build the request payloads (4 samples per question, capped at 256 output tokens) and
# print the cost estimate for every model in MODEL. save_path=None: nothing is written.
obqa_example = wrap_reasoning_prompt(
    model_name='gpt-4o-2024-05-13', 
    system_message=obqa_system_message,
    user_message=obqa_user_message,
    few_shot_message=reasoning_message,
    num_prompts=None,
    max_tokens=256,
    num_sample=4,
    temperature=0.8,
    save_path=None
)

gpt-4o-2024-05-13/obqa
The mean length of GPT input is 1554.7401654226346
The max length of output is 256
estimated tokens per request is 1810.7401654226346
estimated input price is $38.534235/$0.007773700827113174
estimated output price is $76.13952/$0.00384
estimated final price is $114.673755

gpt-4-turbo-2024-04-09/obqa
The mean length of GPT input is 1577.029453298366
The max length of output is 256
estimated tokens per request is 1833.029453298366
estimated input price is $78.17335000000001/$0.01577029453298366
estimated output price is $152.27904/$0.00768
estimated final price is $230.45239000000004

gpt-4-0125-preview/obqa
The mean length of GPT input is 1577.029453298366
The max length of output is 256
estimated tokens per request is 1833.029453298366
estimated input price is $78.17335000000001/$0.01577029453298366
estimated output price is $152.27904/$0.00768
estimated final price is $230.45239000000004

gpt-4-0613/obqa
The mean length of GPT input is 1577.029453298366
The ma

In [17]:
# Persist the dev and test records; process_and_save() later reads these two files.
# NOTE(review): the target directory comes from the global `dataset_name` -- confirm it
# was still "obqa" when this cell ran (its execution count is out of sequence).
with open(f'{dataset_name}/prompt/valid_raw.json', 'w') as f:
    json.dump(validation_data, f)

with open(f'{dataset_name}/prompt/test_raw.json', 'w') as f:
    json.dump(test_data, f)

# Collect Reasoning Response

In [27]:
import re
import json
import random
import numpy as np
from collections import Counter
from itertools import combinations
from Levenshtein import distance as levenshtein_distance

## Function for Processing Reasoning Dataset

In [28]:
def read_reasoning_data(dataset_name, num_iter, num_option, num_augment):
    """Collect accepted reasoning samples from successive request rounds.

    For every round ``index`` in ``1 .. num_iter - 1`` the results file
    ``train_openai_{index}_results.jsonl`` is parsed with
    extract_answer_reasoning(), and the prompts that need another attempt are
    written to ``train_openai_{index + 1}.jsonl`` as input for the next round.

    Args:
        dataset_name: Directory prefix of the dataset.
        num_iter: One more than the number of result files to read.
        num_option: Number of answer options per question.
        num_augment: Maximum number of samples kept per question.

    Returns:
        ``(train_list, retry_dataset)``: all accepted samples across rounds, and
        the retry prompts of the last round read (``None`` if no round was read).
    """
    train_list = []
    retry_dataset = None
    # One line of key information plus one explanation line per option.
    check_num = num_option + 1

    for index in range(1, num_iter):
        read_file = f'{dataset_name}/prompt/reasoning/train_openai_{index}_results.jsonl'
        write_file = f'{dataset_name}/prompt/reasoning/train_openai_{index + 1}.jsonl'
        train_dataset, retry_dataset = extract_answer_reasoning(
            read_file, check_num, num_augment, write_file
        )
        train_list.extend(train_dataset)
        print(f"Length of {index} data: {len(train_dataset)}")
        print(f"Length of {index + 1} retry data: {len(retry_dataset)}\n")

    return train_list, retry_dataset


def collate_content(content):
    """Normalise a model reply into a list of non-empty, stripped lines.

    If the reply contains a ``Key Information ... Explanations:`` span, anything
    before ``Key Information`` is dropped and the key-information part is
    trimmed so that ``Explanations:`` starts on its own line. Whitespace
    (including newlines) right after every ``Explanations:`` is collapsed to a
    single space, so the first explanation ends up on the same line as the
    marker.
    """
    found = re.search(r'(Key Information.*?)(Explanations:)', content, flags=re.DOTALL)
    if found is None:
        normalized = content
    else:
        # Keep the text from "Explanations:" onwards untouched; only the
        # key-information part is trimmed.
        normalized = found.group(1).strip() + '\n' + content[found.end(1):]

    normalized = re.sub(r'Explanations:\s*', 'Explanations: ', normalized)
    stripped_lines = (line.strip() for line in normalized.strip().split("\n"))
    return [line for line in stripped_lines if line]


def filter_answer(answer_list, check_num):
    """Parse the sampled completions of one request and keep the well-formed ones.

    A completion is kept when its first line carries ``Key Information:`` and its
    explanation lines yield at least one ``"<label> is correct."`` prefix, with
    every parsed prefix matching ``"<A-E> is correct."`` / ``"<A-E> is incorrect."``.

    Args:
        answer_list: The ``choices`` of a chat-completion response; each item
            provides ``finish_reason`` and ``message.content``.
        check_num: Only lines ``1 .. check_num - 1`` of the normalised reply are
            examined as explanation lines.

    Returns:
        ``(num_retry, max_tokens_double, candidate_answer)``:

        * ``num_retry`` -- how many completions were rejected.
        * ``max_tokens_double`` -- True if any completion was cut off by the
          token limit.
        * ``candidate_answer`` -- one dict per kept completion with the keys
          ``general_knowledge``, ``answer_prefix``, ``specific_knowledge`` and
          ``LLM_answer`` (the label the model called correct).
    """
    num_retry = 0
    candidate_answer = []
    max_tokens_double = False

    answer_pattern = re.compile(r"^[A-E] is (correct|incorrect)\.$")
    
    for item in answer_list:
        # Skip if the response was cut off due to length
        if item["finish_reason"] == 'length':
            num_retry += 1
            max_tokens_double = True
            continue
        
        # A missing or empty reply normalises to an empty list; reject it instead of
        # failing on result_content[0].
        result_content = collate_content(item['message']['content'] or '')
        if not result_content or "Key Information:" not in result_content[0]:
            # if after collation the knowledge part is still strange then skip
            num_retry += 1
            continue
        
        general_knowledge = result_content[0].split("Key Information:")[1].strip()
        if not general_knowledge.endswith(('.', '"', "'")):
            general_knowledge = general_knowledge + '.'
        
        answer_dict = {
            'general_knowledge': general_knowledge,
            'answer_prefix': [],
            'specific_knowledge': [],
            'LLM_answer': None
        }
       
        for content in result_content[1:check_num]:
            if "Explanations:" in content and "correct" in content:
                content = content.split("Explanations: ")[-1].strip()
            elif "is correct" in content or "is incorrect" in content:
                content = content.strip()
            else:
                # if after collation the specific_knowledge part is still strange then skip
                continue
            
            content_split = None
            for delimiter in ["Because ", "Because, ", "Although ", "While "]:
                if delimiter in content:
                    # Split at the FIRST occurrence only: the reason itself may contain
                    # the delimiter again, and an unlimited split would then produce
                    # more than two parts and break the unpacking below.
                    content_split = content.split(delimiter, 1)
                    break
                    
            if not content_split:
                continue
                
            answer_prefix, reason_part = content_split
            answer_prefix = answer_prefix.strip()
            reason_part = reason_part.strip()

            if not answer_pattern.match(answer_prefix):
                # A malformed prefix invalidates the whole completion.
                answer_dict['LLM_answer'] = None
                break
            
            if " correct" in answer_prefix:
                answer_dict['LLM_answer'] = answer_prefix.split(" is correct")[0].strip()

            answer_dict['answer_prefix'].append(answer_prefix)
            answer_dict['specific_knowledge'].append(reason_part)

        if answer_dict['LLM_answer']:
            candidate_answer.append(answer_dict)
        else:
            # case-1 Sometimes all prefix will be 'is incorrect'
            # case-2 the answer format is strange
            num_retry += 1
        
    return num_retry, max_tokens_double, candidate_answer
            

def select_candidate(candidates, augment_num):
    """Pick up to `augment_num` candidates, preferring distinct key information.

    A candidate is preferred when the edit distance between its
    ``general_knowledge`` and that of EVERY other candidate exceeds a fixed
    threshold. If fewer than `augment_num` candidates qualify, the selection is
    topped up with a random draw from the remaining candidates.

    Args:
        candidates: Candidate dicts as produced by filter_answer().
        augment_num: Target number of candidates to return.

    Returns:
        The selected candidate dicts (fewer than `augment_num` if there are not
        enough candidates).
    """
    knowledge_texts = [candidate['general_knowledge'] for candidate in candidates]
    count = len(knowledge_texts)

    # Symmetric matrix of pairwise edit distances; the diagonal stays 0.
    pairwise = np.zeros((count, count))
    for a, b in combinations(range(count), 2):
        dist = levenshtein_distance(knowledge_texts[a], knowledge_texts[b])
        pairwise[a, b] = dist
        pairwise[b, a] = dist

    threshold = 10

    isolated = [
        knowledge_texts[row]
        for row in range(count)
        if all(pairwise[row, col] > threshold for col in range(count) if col != row)
    ]
    allowed = isolated[:augment_num]
    chosen = [candidate for candidate in candidates if candidate['general_knowledge'] in allowed]

    shortfall = augment_num - len(chosen)
    if shortfall > 0:
        leftovers = [candidate for candidate in candidates if candidate not in chosen]
        chosen.extend(random.sample(leftovers, min(shortfall, len(leftovers))))

    return chosen


def extract_answer_reasoning(file_path, check_num, augment_num, retry_path):
    """Turn one results file into accepted training items and prompts to retry.

    Each line of `file_path` is a JSON pair ``[prompt, answer]``. Completions are
    parsed with filter_answer(), the answer label is decided by vote among the
    parsed completions, and up to `augment_num` completions agreeing with that
    label are selected with select_candidate().

    Args:
        file_path: JSON-lines results file to read.
        check_num: Number of options + 1; a sample is accepted only if it has
            exactly ``check_num - 1`` option explanations.
        augment_num: Maximum number of samples to keep per question.
        retry_path: If truthy, the prompts to retry are written there as JSON lines.

    Returns:
        ``(item_list, retry_list)``: accepted items (dicts with ``input``, ``GT``,
        ``general_knowledge``, ``specific_knowledge``) and the prompts to resend,
        whose ``n`` (and possibly ``max_tokens``) have been adjusted.
    """
    item_list = []
    retry_list = []
    
    with open(file_path, 'r') as file:
        for line in file:
            prompt, answer = json.loads(line)
            
            if isinstance(answer, list):
                # which means the request is totally failed
                retry_list.append(prompt)
                continue
            
            # The last message is the user prompt built by reasoning_prompt_user():
            # "<question>\nOptions: ...\nKey Information:\nExplanations: <GT> is correct. ..."
            user_prompt = prompt['messages'][-1]['content'].strip().split("\nKey Information:\nExplanations: ")
            question_options = user_prompt[0].strip()
            ground_truth = user_prompt[1].split(" is correct")[0].strip()
                
            num_retry, max_token_double, candidate_answer = filter_answer(answer['choices'], check_num)
            if max_token_double:
                # At least one completion was truncated: give any retry twice the budget.
                # NOTE(review): assumes prompt["max_tokens"] is a number, not None -- confirm.
                prompt["max_tokens"] *= 2

            # Use vote to check if LLM agree with the ground truth
            # NOTE(review): the original intent was "more than 50%", but the code below
            # takes the single most common answer regardless of its share, and that
            # answer replaces the dataset's ground truth -- confirm this is intended.
            LLM_answer = [item['LLM_answer'] for item in candidate_answer]
            if len(LLM_answer) == 0:
                # Nothing usable came back: ask again for as many samples as were rejected.
                prompt['n'] = num_retry
                retry_list.append(prompt)
                continue
            
            most_vote_answer, _ = Counter(LLM_answer).most_common(1)[0]
            
            if most_vote_answer != ground_truth:
                ground_truth = most_vote_answer
                
            # Keep only completions that call the (possibly replaced) ground truth correct.
            consistent_candidates = [candidate for candidate in candidate_answer if f"{ground_truth} is correct." in candidate['answer_prefix']]
            sampled_candidates = select_candidate(consistent_candidates, augment_num)

            for candidate in sampled_candidates:
                object_dict = {
                    'input': question_options,
                    'GT': ground_truth,
                    'general_knowledge': candidate['general_knowledge'],
                    'specific_knowledge': []
                }
                for prefix, specific_knowledge in zip(candidate['answer_prefix'], candidate['specific_knowledge']):
                    if not specific_knowledge.endswith(('.', '"', "'")):
                        specific_knowledge = specific_knowledge.strip() + '.'
                    # prefix[0] is the option label, e.g. "A" from "A is correct."
                    if re.match(r'^[A-Z]$', prefix[0]):
                        object_dict['specific_knowledge'].append(f"For option {prefix[0]}, {specific_knowledge.strip()}")

                # Accept the sample only if every option received an explanation.
                if len(object_dict['specific_knowledge']) == check_num - 1:
                    item_list.append(object_dict)
                else:
                    num_retry += 1
                    
            if num_retry != 0:
                # NOTE(review): 8 is presumably the largest `n` ever requested
                # (the CSQA run uses num_sample=8); above it the prompt is reported
                # and NOT queued for retry -- confirm.
                if num_retry > 8:
                    print("wrong")
                else:
                    prompt['n'] = num_retry
                    retry_list.append(prompt)
                
    if retry_path:
        with open(retry_path, 'w') as file:
            for prompt in retry_list:
                file.write(json.dumps(prompt) + '\n')
    
    return item_list, retry_list

## Final Process and Save

In [29]:
def process_and_save(dataset_name, train_data, save_flag=False):
    """Build the recall / analyze / summarize training sets and the eval sets.

    From every accepted reasoning item three kinds of examples are derived:

    * recall    -- question + options -> key information,
    * analyze   -- one example per option: ... + "For option X," -> explanation,
    * summarize -- ... + all explanations -> "So the answer is option X.".

    Exact duplicates are removed from each set and the counts before/after are
    printed. The validation and test sets are rebuilt from
    ``{dataset_name}/prompt/valid_raw.json`` and ``test_raw.json``.

    Args:
        dataset_name: Directory prefix of the dataset.
        train_data: Items as produced by extract_answer_reasoning().
        save_flag: When True, write the five JSON files under
            ``{dataset_name}/final/sft/recall_analyze_summarize/`` (the
            directory must already exist).

    Returns:
        None.
    """

    def _format_options(options_text):
        """Turn "A. foo, B. bar" into "(A) foo (B) bar"."""
        # Split only at a ", " that is followed by a new "<label>. " marker, so a comma
        # inside an option's own text does not start a bogus option.
        options = re.split(r", (?=\w\. )", options_text)
        return " ".join(f"({option[0]}) {option[3:]}" for option in options)

    def _remove_duplicates(data_list, unique_keys):
        """Keep the first occurrence of every distinct combination of `unique_keys`."""
        seen = set()
        unique_data_list = []
        for data in data_list:
            identifier = tuple(data[key] for key in unique_keys)
            if identifier not in seen:
                seen.add(identifier)
                unique_data_list.append(data)
        return unique_data_list

    def _preprocess(data_list):
        """Convert raw question records into recall-style evaluation inputs."""
        item_list = []
        for item in data_list:
            # read_data() stores the rendered options under 'Options'; 'Choices' is
            # still accepted, presumably for raw files written with that older key.
            options_text = item['Options'] if 'Options' in item else item['Choices']
            item_list.append({
                "input": f"{item['Q']}\n{_format_options(options_text)}\nRecall:",
                "GT": item["GT"]
            })
        return item_list

    ### preprocess train data ###
    train_recall_data = []
    train_analyze_data = []
    train_summarize_data = []
    
    for data in train_data:
        # data['input'] is "<question>\nOptions: A. ..., B. ..."; the second line holds the options.
        data_input = data['input'].replace("Options: ", "")
        lines = data_input.split("\n")
        lines[1] = _format_options(lines[1])
        data_input = "\n".join(lines)
        
        train_recall_data.append(
            {
                'input': f"{data_input}\nRecall:",
                'GT': data['GT'],
                'recall': data['general_knowledge']
            }
        )
        
        for knowledge in data['specific_knowledge']:
            # Each entry reads "For option X, <explanation>": the prefix goes into the
            # prompt, the explanation becomes the target.
            pattern = r"(For option [A-Z],)(.*)"
            matches = re.findall(pattern, knowledge)
            prefix = matches[0][0].strip()
            knowledge_text = matches[0][1].strip()
            train_analyze_data.append(
                {
                    'input': f"{data_input}\nRecall: {data['general_knowledge']}\nAnalyze: {prefix}",
                    'GT': data['GT'],
                    'analyze': knowledge_text
                }
            )

        train_summarize_data.append(
            {
                'input': f"{data_input}\nRecall: {data['general_knowledge']}\nAnalyze: {' '.join(data['specific_knowledge'])}\nSummarize:",
                'GT': data['GT'],
                'summarize': f"So the answer is option {data['GT']}."
            }
        )

    print(f"Length of original recall data: {len(train_recall_data)}")
    train_recall_data = _remove_duplicates(train_recall_data, ['input', 'recall'])
    print(f"After remove duplicates data, the length of recall data: {len(train_recall_data)}")
    print(f"Length of original analyze data: {len(train_analyze_data)}")
    train_analyze_data = _remove_duplicates(train_analyze_data, ['input', 'analyze'])
    print(f"After remove duplicates data, the length of analyze data: {len(train_analyze_data)}")
    print(f"Length of original summarize data: {len(train_summarize_data)}")
    train_summarize_data = _remove_duplicates(train_summarize_data, ['input', 'summarize'])
    print(f"After remove duplicates data, the length of summarize data: {len(train_summarize_data)}")

    ### preprocess validation data ###
    with open(f'{dataset_name}/prompt/valid_raw.json', 'r') as file:
        valid_data = json.load(file)
    valid_dataset = _preprocess(valid_data)

    ### preprocess test data ###
    with open(f'{dataset_name}/prompt/test_raw.json', 'r') as file:
        test_data = json.load(file)
    test_dataset = _preprocess(test_data)

    if save_flag:
        with open(f'{dataset_name}/final/sft/recall_analyze_summarize/recall.json', 'w') as f:
            json.dump(train_recall_data, f)
        with open(f'{dataset_name}/final/sft/recall_analyze_summarize/analyze.json', 'w') as f:
            json.dump(train_analyze_data, f)
        with open(f'{dataset_name}/final/sft/recall_analyze_summarize/summarize.json', 'w') as f:
            json.dump(train_summarize_data, f)
        with open(f'{dataset_name}/final/sft/recall_analyze_summarize/valid.json', 'w') as f:
            json.dump(valid_dataset, f)
        with open(f'{dataset_name}/final/sft/recall_analyze_summarize/test.json', 'w') as f:
            json.dump(test_dataset, f)

## CSQA

In [30]:
dataset_name = "csqa"
# read_reasoning_data() loops over range(1, num_iter), so 4 means result files 1-3.
num_iter = 4
# Options per question; a sample needs one explanation per option to be accepted.
num_option = 5
# Maximum number of reasoning samples kept per question.
num_augment = 8
csqa_train, csqa_retry = read_reasoning_data(dataset_name, num_iter, num_option, num_augment)

Length of 1 data: 61730
Length of 2 retry data: 3092

Length of 2 data: 5006
Length of 3 retry data: 123

Length of 3 data: 135
Length of 4 retry data: 32



In [32]:
# Writes recall/analyze/summarize plus valid/test JSON files under
# {dataset_name}/final/sft/recall_analyze_summarize/.
process_and_save(dataset_name, csqa_train, save_flag=True)

Length of original recall data: 66871
After reduce duplicates data, the length of recall data: 66740
Length of original analyze data: 334355
After reduce duplicates data, the length of analyze data: 334316
Length of original summarize data: 66871
After reduce duplicates data, the length of summarize data: 66870


## StrategyQA

In [33]:
dataset_name = "strategyqa"
# read_reasoning_data() loops over range(1, num_iter), so 3 means result files 1-2.
num_iter = 3
# Options per question (yes / no).
num_option = 2
# Maximum number of reasoning samples kept per question.
num_augment = 4
strategyqa_train, strategyqa_retry = read_reasoning_data(dataset_name, num_iter, num_option, num_augment)
print(len(strategyqa_train))

Length of 1 data: 7123
Length of 2 retry data: 22

Length of 2 data: 22
Length of 3 retry data: 4

7145


In [34]:
# Writes recall/analyze/summarize plus valid/test JSON files under
# {dataset_name}/final/sft/recall_analyze_summarize/.
process_and_save(dataset_name, strategyqa_train, save_flag=True)

Length of original recall data: 7145
After reduce duplicates data, the length of recall data: 7089
Length of original analyze data: 14290
After reduce duplicates data, the length of analyze data: 14290
Length of original summarize data: 7145
After reduce duplicates data, the length of summarize data: 7145


## OBQA

In [35]:
dataset_name = "obqa"
# read_reasoning_data() loops over range(1, num_iter), so 4 means result files 1-3.
num_iter = 4
# Options per question; a sample needs one explanation per option to be accepted.
num_option = 4
# Maximum number of reasoning samples kept per question.
num_augment = 4
obqa_train, obqa_retry = read_reasoning_data(dataset_name, num_iter, num_option, num_augment)
print(len(obqa_train))

Length of 1 data: 18757
Length of 2 retry data: 593

Length of 2 data: 666
Length of 3 retry data: 108

Length of 3 data: 71
Length of 4 retry data: 77

19494


In [36]:
# Writes recall/analyze/summarize plus valid/test JSON files under
# {dataset_name}/final/sft/recall_analyze_summarize/.
process_and_save(dataset_name, obqa_train, save_flag=True)

Length of original recall data: 19494
After reduce duplicates data, the length of recall data: 19439
Length of original analyze data: 77976
After reduce duplicates data, the length of analyze data: 77965
Length of original summarize data: 19494
After reduce duplicates data, the length of summarize data: 19492
