In [12]:
from bisect import bisect_left
from itertools import accumulate
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gc
import json
import random
import torch
import wandb

In [2]:
class RewardModelDataset(Dataset):
    """Step-level (context, target) pairs built from a JSON-lines file.

    Each record is expected to provide ``sample["question"]["problem"]`` and
    ``sample["label"]["steps"]``, where every step holds a ``"completions"``
    list whose entries carry ``"text"`` and ``"rating"``. Records without any
    steps are dropped.

    Every step of every remaining record yields exactly one example: the
    context is the problem followed by one randomly chosen completion for each
    step up to and including that step (earlier steps are shown with their
    rating, the last one without), and the target is the one-hot encoded
    rating of the last chosen completion.

    All examples are materialised once in ``__init__``; the random completion
    choices are therefore fixed for the lifetime of the dataset object.
    """

    def __init__(self, json_file):
        """Load ``json_file`` and pre-compute every context/target pair.

        Args:
            json_file: Path to a JSON-lines file (one JSON object per line).
        """
        self.data = [
            sample
            for sample in self.load_data_from_file(json_file)
            if len(sample["label"]["steps"]) > 0
        ]

        step_counts = [len(sample["label"]["steps"]) for sample in self.data]
        self.total_length = sum(step_counts)

        # slots[k] is the number of examples contributed by samples 0..k, so
        # sample k owns the flat indices slots[k-1] .. slots[k] - 1.
        self.slots = list(accumulate(step_counts))

        self.ctx_target_pairs = [
            self._build_pair(idx) for idx in tqdm(range(self.total_length))
        ]

    def _locate(self, idx):
        """Map a flat example index to ``(sample_idx, step_idx)``."""
        # First cumulative count strictly greater than idx. Because the counts
        # are integers, this is the same as bisect_left on idx + 1.
        sample_idx = bisect_left(self.slots, idx + 1)
        first_idx = self.slots[sample_idx - 1] if sample_idx > 0 else 0
        return sample_idx, idx - first_idx

    @staticmethod
    def _rating_to_one_hot(rating):
        """Encode a rating as a float32 one-hot vector over (-1, 0, 1).

        Any value other than -1 or 1 (including 0 and ``None``) is encoded as
        the middle class.
        """
        if rating == -1:
            hot = 0
        elif rating == 1:
            hot = 2
        else:
            hot = 1
        one_hot = torch.zeros(3, dtype=torch.float32)
        one_hot[hot] = 1.0
        return one_hot

    def _build_pair(self, idx):
        """Build the ``{"context", "target"}`` dict for flat index ``idx``."""
        sample_idx, step_idx = self._locate(idx)
        sample = self.data[sample_idx]
        steps = sample["label"]["steps"][: step_idx + 1]

        context = sample["question"]["problem"] + "[ANS]"

        # Earlier steps are shown together with the rating of the sampled completion.
        for step in steps[:-1]:
            completion = random.choice(step["completions"])
            context += f"  [SEP]{completion['text']} <[RATING]> {completion['rating']}"

        # The last step is shown without its rating; that rating is the target.
        final_completion = random.choice(steps[-1]["completions"])
        context += f"  [SEP]{final_completion['text']} <[RATING]>"

        return {
            "context": context,
            "target": self._rating_to_one_hot(final_completion["rating"]),
        }

    def load_data_from_file(self, json_file):
        """Parse a JSON-lines file into a list of objects, skipping blank lines."""
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(json_file, "r", encoding="utf-8") as file:
            return [json.loads(line) for line in file if line.strip()]

    def __len__(self):
        """Return the total number of step-level examples."""
        return self.total_length

    def __getitem__(self, idx):
        """Return the pre-computed ``{"context", "target"}`` dict at ``idx``."""
        return self.ctx_target_pairs[idx]

In [14]:
def load_data_from_file(json_file):
    """Load a JSON-lines file into a list of parsed objects.

    Args:
        json_file: Path to a file holding one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(json_file, "r", encoding="utf-8") as file:
        # Blank lines (e.g. a trailing newline at end of file) are skipped
        # instead of being handed to json.loads, which would raise.
        return [json.loads(line) for line in file if line.strip()]

In [15]:
# Read the raw training records (one JSON object per line) into memory.
train_path = "prmdata/train.jsonl"
dataset = load_data_from_file(train_path)

In [16]:
from tqdm import tqdm

In [17]:
# For every sample, the number of candidate completions recorded at each step.
completion_counts = [
    [len(step["completions"]) for step in sample["label"]["steps"]]
    for sample in tqdm(dataset)
]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97782/97782 [00:00<00:00, 620066.42it/s]


In [22]:
# Peek at the problem statement of the first record.
first_sample = dataset[0]
first_sample["question"]["problem"]

'The first four terms in an arithmetic sequence are $x+y$, $x-y$, $xy$, and $x/y$, in that order. What is the fifth term? Express your answer as a common fraction.'

97782

In [25]:
# Problem statement of every record, in dataset order (duplicates kept).
probs = [sample["question"]["problem"] for sample in tqdm(dataset)]

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97782/97782 [00:00<00:00, 1821242.56it/s]


In [27]:
# Average number of records per distinct problem statement.
n_records = len(dataset)
n_distinct_problems = len(set(probs))
n_records / n_distinct_problems

9.030476542297746

In [28]:
# One empty bucket per distinct problem statement, keyed in first-seen order.
problem_sets = {}
for problem in probs:
    problem_sets.setdefault(problem, [])

In [30]:
# Group the step lists of all records under their problem statement.
# The bucket is appended to in place: no temporary named `_` (which would
# shadow IPython's last-output variable) and no redundant write-back.
for sample in tqdm(dataset):
    prob = sample["question"]["problem"]
    problem_sets[prob].append(sample["label"]["steps"])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 97782/97782 [00:00<00:00, 965635.09it/s]


In [34]:
# Show the first recorded step list of the first problem in the mapping.
first_problem = list(problem_sets)[0]
problem_sets[first_problem][0]

[{'completions': [{'text': 'To find the fifth term, I need to identify the common difference of the arithmetic sequence and add it to the fourth term.',
    'rating': 1,
    'flagged': None}],
  'human_completion': None,
  'chosen_completion': 0},
 {'completions': [{'text': 'The common difference is the same for any consecutive pair of terms, so I can use any of them to find it.',
    'rating': 1,
    'flagged': None}],
  'human_completion': None,
  'chosen_completion': 0},
 {'completions': [{'text': 'For example, using the first and second terms, I can write $x-y = x+y + d$, where $d$ is the common difference.',
    'rating': 1,
    'flagged': None}],
  'human_completion': None,
  'chosen_completion': 0},
 {'completions': [{'text': 'Solving for $d$, I get $d = -2y$.',
    'rating': 1,
    'flagged': None}],
  'human_completion': None,
  'chosen_completion': 0},
 {'completions': [{'text': 'Using another pair of terms, such as the second and third, I can check if this value of $d$ is co

In [38]:
# Total number of step lists across all buckets.
sum(map(len, problem_sets.values()))

97782