In [30]:
from datasets import load_dataset

STILL_DATASET_ID = "RUC-AIBOX/STILL-3-Preview-RL-Data"
ds = load_dataset(STILL_DATASET_ID)


def _to_problem_record(example):
    """Expose a STILL example under the ``problem``/``answer`` keys used elsewhere in this notebook.

    ``Dataset.map`` merges the returned keys into each row without dropping
    the existing columns, so ``question`` (and any other original fields) are
    still present afterwards and must be removed separately.
    """
    return {"problem": example["question"], "answer": example["answer"]}


# Result is a datasets.Dataset (not a plain JSON list) with an added "problem" column.
still_ds = ds['train'].map(_to_problem_record)


In [31]:
# Load the prior training JSON files from ../train and merge them into one list of records
import os
import json

# Directory (resolved relative to the notebook's working directory) holding the
# prior training splits, and the files whose records are merged into `prior_data`.
PRIOR_DATA_DIR = os.path.abspath("../train")
PRIOR_DATA_FILES = ['aime.json', 'amc.json', 'math.json', 'omni_math.json']

prior_data = []
for file in PRIOR_DATA_FILES:
    # get absolute path by merging file with directory
    file_path = os.path.join(PRIOR_DATA_DIR, file)
    # JSON is UTF-8 by spec; be explicit so decoding does not depend on the
    # platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # `list.extend` on a dict would silently add only its keys; fail loudly instead.
    if not isinstance(data, list):
        raise TypeError(
            f"{file_path}: expected a top-level JSON list of records, "
            f"got {type(data).__name__}"
        )
    prior_data.extend(data)

len(prior_data)


15965

In [27]:
from rllm.utils import RAG

# Build a retrieval index over every prior problem statement; the filtering
# cell below queries it for each STILL problem's closest prior match.
# NOTE(review): in the saved session this cell ran as In[27], i.e. before the
# In[31] cell that builds `prior_data`, so it relied on leftover kernel state.
# Run the cells top to bottom.
rag_searcher = RAG(
    docs=[record["problem"] for record in prior_data],
)

In [32]:
# Deduplicate STILL against the prior training data: discard every STILL
# problem whose closest prior problem (per the RAG index) scores above the
# threshold, and keep the rest.
from tqdm.auto import tqdm

# Score above which a STILL problem counts as a near-duplicate of a prior
# problem. The score scale is defined by `rllm.utils.RAG` (presumably a
# similarity in [0, 1]; TODO confirm).
DUPLICATE_SCORE_THRESHOLD = 0.90

filter_problems = []  # STILL problems kept (no close match in prior_data)
num_problems = 0      # STILL problems discarded as near-duplicates

# tqdm already reports progress, so no manual counter/print is needed.
for d in tqdm(still_ds, desc="Filtering stil data", total=len(still_ds)):
    top_matches = rag_searcher.top_k(d["problem"], k=1)
    # If the search ever returns no hits, there is nothing to be a duplicate of.
    score = top_matches[0]["score"] if top_matches else float("-inf")
    if score > DUPLICATE_SCORE_THRESHOLD:
        num_problems += 1
    else:
        filter_problems.append(d)

# Save the kept problems as JSON. These records still carry the original STILL
# columns (e.g. "question", "messages"); a later cell strips them and rewrites the file.
with open("still.json", "w") as f:
    json.dump(filter_problems, f, indent=2)

Filtering stil data:   3%|▎         | 1021/29925 [00:07<03:33, 135.23it/s]

1000


Filtering stil data:   7%|▋         | 2015/29925 [00:14<03:27, 134.79it/s]

2000


Filtering stil data:  10%|█         | 3023/29925 [00:22<03:18, 135.79it/s]

3000


Filtering stil data:  13%|█▎        | 4017/29925 [00:29<03:11, 135.13it/s]

4000


Filtering stil data:  17%|█▋        | 5025/29925 [00:37<03:03, 135.98it/s]

5000


Filtering stil data:  20%|██        | 6019/29925 [00:44<02:55, 135.99it/s]

6000


Filtering stil data:  23%|██▎       | 7027/29925 [00:51<02:48, 135.50it/s]

7000


Filtering stil data:  27%|██▋       | 8021/29925 [00:59<02:41, 135.71it/s]

8000


Filtering stil data:  30%|███       | 9015/29925 [01:06<02:33, 136.31it/s]

9000


Filtering stil data:  33%|███▎      | 10023/29925 [01:14<02:25, 136.84it/s]

10000


Filtering stil data:  37%|███▋      | 11017/29925 [01:21<02:18, 136.77it/s]

11000


Filtering stil data:  40%|████      | 12025/29925 [01:28<02:10, 137.53it/s]

12000


Filtering stil data:  44%|████▎     | 13019/29925 [01:35<02:03, 136.99it/s]

13000


Filtering stil data:  47%|████▋     | 14027/29925 [01:43<01:55, 137.21it/s]

14000


Filtering stil data:  50%|█████     | 15021/29925 [01:50<01:48, 137.65it/s]

15000


Filtering stil data:  54%|█████▎    | 16015/29925 [01:57<01:41, 136.82it/s]

16000


Filtering stil data:  57%|█████▋    | 17023/29925 [02:05<01:33, 137.93it/s]

17000


Filtering stil data:  60%|██████    | 18017/29925 [02:12<01:26, 138.16it/s]

18000


Filtering stil data:  64%|██████▎   | 19028/29925 [02:19<01:18, 139.18it/s]

19000


Filtering stil data:  67%|██████▋   | 20016/29925 [02:26<01:11, 139.14it/s]

20000


Filtering stil data:  70%|███████   | 21027/29925 [02:34<01:03, 139.33it/s]

21000


Filtering stil data:  74%|███████▎  | 22024/29925 [02:41<00:56, 138.69it/s]

22000


Filtering stil data:  77%|███████▋  | 23016/29925 [02:48<00:49, 139.74it/s]

23000


Filtering stil data:  80%|████████  | 24015/29925 [02:55<00:41, 140.79it/s]

24000


Filtering stil data:  84%|████████▎ | 25020/29925 [03:02<00:34, 140.81it/s]

25000


Filtering stil data:  87%|████████▋ | 26024/29925 [03:09<00:27, 141.36it/s]

26000


Filtering stil data:  90%|█████████ | 27014/29925 [03:16<00:20, 141.15it/s]

27000


Filtering stil data:  94%|█████████▎| 28016/29925 [03:24<00:13, 140.23it/s]

28000


Filtering stil data:  97%|█████████▋| 29021/29925 [03:31<00:06, 140.83it/s]

29000


Filtering stil data: 100%|██████████| 29925/29925 [03:37<00:00, 137.53it/s]


In [33]:
# Drop the raw "question" field (already copied into "problem" by the earlier
# `map`) and the "messages" field from every kept record, then overwrite
# still.json with the slimmed-down records.
# `pop(key, None)` keeps this cell idempotent: re-running it no longer raises
# KeyError on records whose fields were already removed.
for d in filter_problems:
    d.pop("question", None)
    d.pop("messages", None)

# Save final list as json
with open("still.json", "w") as f:
    json.dump(filter_problems, f, indent=2)
len(filter_problems)

26026