In [1]:
import jsonlines
import json
from tsllm.envs import get_env_answer_checker
import numpy as np
import random

## Deduplication

In [26]:
# Deduplicate the MCTS rollouts of each question and label every unique answer
# with the gsm8k answer checker.
#
# Each line of `input_file_path` is read for the keys "question", "groundtruth",
# "result" (with a "#token" count) and "output" (a list of dicts with a "text"
# field). The "output" list is replaced by a deduplicated "answer" list whose
# items get an extra "correct" field holding check_fn's result, and the objects
# are written to `output_file_path`.
input_file_path = "train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12.jsonl"
output_file_path = "train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-dedup.jsonl"

check_fn = get_env_answer_checker("gsm8k")

total_tokens = 0
cnt = 0  # number of input lines (questions)
dedup_objs = []
cnt_before_dedup, cnt_after_dedup = 0, 0
correct_cnt = 0
with jsonlines.open(input_file_path, "r") as reader:
    for obj in reader:
        total_tokens += obj["result"]["#token"]
        cnt += 1
        texts = set()  # texts already kept for this question
        new_output_list = []
        for o in obj["output"]:
            cnt_before_dedup += 1
            txt = o["text"]
            if txt in texts:
                continue
            texts.add(txt)
            cnt_after_dedup += 1
            # Only unique answers are checked, so the accuracy printed below is
            # measured over the deduplicated answers.
            o["correct"] = check_fn(obj["question"], obj["groundtruth"], txt)
            if o["correct"]:
                correct_cnt += 1
            new_output_list.append(o)
        obj.pop("output")
        obj["answer"] = new_output_list
        dedup_objs.append(obj)

# Mean of result["#token"] per question; guarded so an empty input file does
# not raise ZeroDivisionError.
print(total_tokens / cnt if cnt else float("nan"))
print("{} -> {} after deduplicate.".format(cnt_before_dedup, cnt_after_dedup))
print("correct: {}".format(correct_cnt / cnt_after_dedup if cnt_after_dedup else float("nan")))


with jsonlines.open(output_file_path, "w") as writer:
    for obj in dedup_objs:
        writer.write(obj)

del input_file_path, output_file_path

5199.825371336812
89676 -> 78732 after deduplicate.
correct: 0.732091144642585


## Merge previous rollout data and current rollout data for policy training

Set `input_file_path_0` to the path of the previous rollout data and `input_file_path_1` to the path of the current rollout data.

Set `output_file_path` to the path where the merged data will be written.

In [3]:
# Merge the current iteration's deduplicated rollouts with the SFT-initialisation
# data for policy training.
#
# input_file_path_1: current rollouts (output of the deduplication cell); each
#                    object is keyed by its "i" field.
# input_file_path_0: previous data; its line index is looked up as "i" above, so
#                    a missing index raises KeyError.
# For each question, only the FIRST previous answer is considered, and it is
# appended unless an identical text is already among the current answers.
input_file_path_1 = "train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-dedup.jsonl"
# The package directory is `tsllm` (as in the imports and the critic cell);
# this path previously read `tslmm/...`.
input_file_path_0 = "tsllm/envs/gsm8k/train_data/sft_init.jsonl"
output_file_path = "train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-merge_train-dedup.jsonl"

obj_dict = {}
cnt = 0  # previous answers actually appended
total_cnt = 0  # answers marked "correct" across the questions of input_file_path_0
with jsonlines.open(input_file_path_1, "r") as reader:
    for obj in reader:
        obj_dict[obj["i"]] = obj

with jsonlines.open(input_file_path_0, "r") as reader:
    for i, obj in enumerate(reader):
        obj_to_merge = obj_dict[i]
        assert obj_to_merge["question"] == obj["question"]
        current_texts = {o["text"] for o in obj_to_merge["answer"]}

        come_in_output = obj["answer"][0]
        if come_in_output["text"] not in current_texts:
            obj_to_merge["answer"].append(come_in_output)
            cnt += 1
        total_cnt += sum(1 for x in obj_to_merge["answer"] if x["correct"])

print("ADD {} new instances".format(cnt))
print("TOTAL DATA: {}".format(total_cnt))

# SL sft training data
with jsonlines.open(output_file_path, "w") as writer:
    for obj in obj_dict.values():
        writer.write(obj)

# Delete this cell's own path variables. `input_file_path` is not defined here
# (the deduplication cell deletes it), so deleting it raised NameError on a
# fresh kernel.
del input_file_path_0, input_file_path_1, output_file_path

ADD 6504 new instances
TOTAL DATA: 64134


## Merge previous rollout data and current rollout data for critic training

Set `input_file_path_0` to the path of the previous rollout data and `input_file_path_1` to the path of the current rollout data.

Set `output_file_path` to the path where the merged data will be written.

In [24]:
# Build value (critic) training data: merge the current deduplicated rollouts
# with a random subsample of the previous sampled data, skipping answers whose
# text is already present for that question.
#
# input_file_path_0: previous data; its line index is looked up as the "i" key
#                    of the current rollouts, so a missing index raises KeyError.
# input_file_path_1: current rollouts (output of the deduplication cell).
input_file_path_0 = "tsllm/offline_rl/gsm8k_data/processed/gsm8k_train_cot_sample_sft_k100_merged_dedup_sample17x3.jsonl"
input_file_path_1 = "train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-dedup.jsonl"
output_file_path = "train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-value-sl_train-dedup.jsonl"

seed = 1
random.seed(seed)
# np.random.choice below draws from the global NumPy RNG, so this seed fixes
# which previous answers are sampled.
np.random.seed(seed)

obj_dict = {}
with jsonlines.open(input_file_path_1, "r") as reader:
    for obj in reader:
        obj_dict[obj["i"]] = obj


cnt = 0  # previous answers actually appended
total_cnt = 0  # previous answers read, before subsampling
merged_dedup_cnt = 0  # merged answers over the questions of input_file_path_0
# Maximum number of previous answers sampled per question.
# NOTE(review): 51 presumably matches the "17x3" samples of the previous file and
# 12 the "k12" rollouts of this iteration; confirm.
K = 51 - 12
with jsonlines.open(input_file_path_0, "r") as reader:
    for i, obj in enumerate(reader):
        merged = obj_dict[i]
        assert obj["question"] == merged["question"]
        current_texts = {o["text"] for o in merged["answer"]}
        total_cnt += len(obj["answer"])
        if len(obj["answer"]) > K:
            # K distinct answers, drawn without replacement.
            subsample_list = np.random.choice(obj["answer"], K, replace=False)
        else:
            subsample_list = obj["answer"]

        for o in subsample_list:
            if o["text"] not in current_texts:
                current_texts.add(o["text"])
                merged["answer"].append(o)
                cnt += 1
        merged_dedup_cnt += len(merged["answer"])
print("ADD {}/{} NEW DATA, NOW {}".format(cnt, total_cnt, merged_dedup_cnt))

# Value (critic) SL training data
with jsonlines.open(output_file_path, "w") as writer:
    for obj in obj_dict.values():
        writer.write(obj)

ADD 280037/345945 NEW DATA, NOW 358769
