In [None]:
from bisect import bisect_left
from itertools import accumulate
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm

In [1]:
import gc
import json
import random

import torch
import wandb
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

In [None]:
class RewardModelDataset(Dataset):
    """Step-level reward-model examples built from a JSON-lines file.

    Every line of the file is one JSON object that must provide
    ``sample["question"]["problem"]`` (a string) and
    ``sample["label"]["steps"]`` (a list). Each step holds a ``"completions"``
    list whose entries have a ``"text"`` and a ``"rating"``.

    One example is produced per (sample, step) pair. Its ``"context"`` is::

        <problem>[ANS]  [SEP]<text> <[RATING]> <rating> ...  [SEP]<text> <[RATING]>

        (one rated segment per earlier step, then the unrated step to score)

    and its ``"target"`` is a float32 one-hot tensor of length 3 encoding the
    rating of that last step (``-1 -> [1,0,0]``, ``0 -> [0,1,0]``,
    ``1 -> [0,0,1]``; anything else, e.g. ``None``, is treated as neutral).

    The completion used for each step is drawn with ``random.choice``, so the
    generated examples depend on the state of the global ``random`` module.

    Attributes:
        data: Loaded samples that have at least one step.
        total_length: Total number of steps over all samples (= dataset size).
        slots: Running total of step counts; ``slots[i]`` is the number of
            examples contributed by samples ``0..i``.
        ctx_target_pairs: The materialised ``{"context", "target"}`` dicts,
            ordered by sample and then by step.
    """

    def __init__(self, json_file):
        """Load ``json_file`` and materialise every (context, target) example.

        Args:
            json_file: Path to a JSON-lines file with the schema described in
                the class docstring.
        """
        self.data = [
            sample
            for sample in self.load_data_from_file(json_file)
            if len(sample["label"]["steps"]) > 0
        ]

        step_counts = [len(sample["label"]["steps"]) for sample in self.data]
        self.total_length = sum(step_counts)
        self.slots = list(accumulate(step_counts))

        # Walk (sample, step) positions directly instead of translating a flat
        # index through the cumulative counts: that translation is easy to get
        # wrong at slot boundaries and must never mix up the sample index with
        # the step index.
        positions = self._iter_step_positions(self.data)
        self.ctx_target_pairs = [
            self._build_pair(self.data[sample_idx], step_idx)
            for sample_idx, step_idx in tqdm(positions, total=self.total_length)
        ]

    @staticmethod
    def _iter_step_positions(data):
        """Yield ``(sample_idx, step_idx)`` for every step of every sample, in order."""
        for sample_idx, sample in enumerate(data):
            for step_idx in range(len(sample["label"]["steps"])):
                yield sample_idx, step_idx

    @staticmethod
    def _rating_to_one_hot(rating):
        """Map a rating to a float32 one-hot tensor over (-1, 0, 1).

        Any rating other than -1 or 1 (including ``None``) maps to the middle,
        neutral class.
        """
        if rating == -1:
            class_idx = 0
        elif rating == 1:
            class_idx = 2
        else:
            class_idx = 1
        one_hot = torch.zeros(3, dtype=torch.float32)
        one_hot[class_idx] = 1.0
        return one_hot

    @staticmethod
    def _build_pair(sample, step_idx):
        """Build the example that asks for the rating of step ``step_idx``.

        Args:
            sample: One loaded sample dict.
            step_idx: Zero-based index of the step to be rated.

        Returns:
            ``{"context": str, "target": torch.Tensor}``.

        Raises:
            IndexError: If a step involved has an empty ``"completions"`` list
                (raised by ``random.choice``).
        """
        steps = sample["label"]["steps"][:step_idx + 1]
        context = sample["question"]["problem"] + "[ANS]"

        # Earlier steps are shown together with their rating ...
        for step in steps[:-1]:
            completion = random.choice(step["completions"])
            context += f"  [SEP]{completion['text']} <[RATING]> {completion['rating']}"

        # ... the last step ends at the rating marker; its rating is the label.
        final_ctx = random.choice(steps[-1]["completions"])
        context += f"  [SEP]{final_ctx['text']} <[RATING]>"

        target = RewardModelDataset._rating_to_one_hot(final_ctx["rating"])
        return {"context": context, "target": target}

    def load_data_from_file(self, json_file):
        """Parse a JSON-lines file into a list of objects.

        Blank lines (such as a trailing newline at the end of the file) are
        skipped rather than handed to ``json.loads``.

        Args:
            json_file: Path of the file to read (decoded as UTF-8).

        Returns:
            A list with one parsed JSON value per non-blank line.
        """
        with open(json_file, "r", encoding="utf-8") as file:
            return [json.loads(line) for line in file if line.strip()]

    def __len__(self):
        """Return the number of examples (one per step)."""
        return self.total_length

    def __getitem__(self, idx):
        """Return the ``{"context", "target"}`` dict stored at position ``idx``."""
        return self.ctx_target_pairs[idx]

In [None]:
train_path = "prmdata/train.jsonl"
test_path = "prmdata/test.jsonl"

# Tunable settings for the split and the loaders.
BATCH_SIZE = 64
VALIDATION_FRACTION = 0.1
SPLIT_SEED = 42

reward_model_dataset = RewardModelDataset(train_path)
validation_size = int(VALIDATION_FRACTION * len(reward_model_dataset))
train_size = len(reward_model_dataset) - validation_size

# Seed the split explicitly so the same examples land in train/validation on
# every run of the notebook.
# NOTE(review): the split is taken over individual (context, target) examples,
# so steps of one problem can end up on both sides of the split - confirm that
# this is acceptable for the validation metric.
train_dataset, validation_dataset = random_split(
    reward_model_dataset,
    [train_size, validation_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
test_set = RewardModelDataset(test_path)

# Only the training data needs reshuffling each epoch; evaluation loaders keep
# a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Width of the classification head: one output per one-hot position produced
# by RewardModelDataset (ratings -1, 0 and 1).
num_classes: int = 3

In [None]:
# Tokenizer that belongs to the pretrained reward-model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="ChaiML/reward_models_100_170000000_cp_498032",
)
# Pad with the end-of-sequence token; the training loop tokenizes whole
# batches with padding=True and therefore needs a padding token to be set.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Start from the pretrained reward model and give it a fresh 3-way head.
model = AutoModelForSequenceClassification.from_pretrained("ChaiML/reward_models_100_170000000_cp_498032")

# Size the new head from the checkpoint's own config instead of hard-coding
# the hidden width.
model.score = torch.nn.Linear(model.config.hidden_size, num_classes)

# Keep the config in step with the replaced head and with the tokenizer:
# num_labels describes the head, and pad_token_id tells the model which
# positions of a padded batch are padding.
model.config.num_labels = num_classes
model.config.pad_token_id = tokenizer.pad_token_id

model.to(device)

In [None]:
# Load model directly
#
# Plain distilgpt2 language model, kept under its own names. Binding it to
# `tokenizer` / `model` would replace the 3-class reward model (already moved
# to `device`) that the optimizer and the training loop below operate on; the
# loop compares `outputs.logits` with 3-element targets, which only fits the
# classification model.

distilgpt2_tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
distilgpt2_model = AutoModelForCausalLM.from_pretrained("distilgpt2")

In [None]:
# Open the Weights & Biases run that receives the loss curves.
wandb.init(
    project='prm',
    name='continuous-loss-plotting',
)

# Loss between the model logits and the 3-element one-hot targets.
criterion = torch.nn.CrossEntropyLoss()
# AdamW over every parameter of the model.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Mean training loss of each finished epoch, appended by the training loop.
train_losses = list()

In [None]:
epochs = 10

# Switch the model to training mode once, before the optimisation loop.
model.train()

for epoch in range(epochs):
    # Running sum of the per-batch losses of this epoch.
    total_train_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} - Training"):
        target = batch["target"].to(device)
        context = batch["context"]

        # Tokenize the raw context strings as one padded / truncated batch,
        # then move every resulting tensor to the training device.
        inputs = tokenizer(context, return_tensors="pt", padding=True, truncation=True)
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        # One optimisation step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = criterion(outputs.logits, target)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the epoch sum and the log.
        batch_loss = loss.item()
        total_train_loss += batch_loss
        wandb.log({'train_batch_loss': batch_loss})

        # Drop the references to this batch's tensors and hand cached GPU
        # memory back before the next batch is processed.
        del loss, context, target, outputs, inputs
        torch.cuda.empty_cache()

    # Epoch summary: mean of the batch losses.
    average_train_loss = total_train_loss / len(train_loader)
    train_losses.append(average_train_loss)
    wandb.log({'train_loss': average_train_loss})
    print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {average_train_loss}")

# Persist the trained weights and the tokenizer side by side in "prm".
model.save_pretrained("prm")
tokenizer.save_pretrained("prm")

wandb.finish()