In [1]:
import os
import torch
import argparse
from train.rm import *
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
)

In [4]:
def fix_valuehead(
    model, rm_ckpt_dir: str, rm_save_dir: str, V_HEAD_WEIGHTS_NAME: str = "value_head.bin"
) -> None:

    path_to_checkpoint = os.path.join(rm_ckpt_dir, SAFE_WEIGHTS_NAME)
    # state_dict = load_file(path_to_checkpoint)
    rm_model = AutoModelForCausalLM.from_pretrained(
        rm_ckpt_dir
    )
    state_dict = model.state_dict()

    
    tokenizer = AutoTokenizer.from_pretrained(
        rm_ckpt_dir, 
        trust_remote_code=True, 
        use_fast=True,
        padding_side="right",
        split_special_tokens=False,
    )
    model.pretrained_model.resize_token_embeddings(len(tokenizer))
    decoder_state_dict = {}
    v_head_state_dict = {}
    for name, param in state_dict.items():
        print(name)
        print(param)
        if name.startswith("v_head."):
            print('yes')
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param

    model.pretrained_model.save_pretrained(
        rm_save_dir, state_dict=decoder_state_dict or None
    )
    tokenizer.save_pretrained(rm_save_dir)
    torch.save(v_head_state_dict, os.path.join(rm_save_dir, V_HEAD_WEIGHTS_NAME))
    #os.remove(path_to_checkpoint)

In [6]:
base_dir = '/groups/kjun/tnn/datasets/'
    
# dataset path
data_dir = base_dir + "/prm800k/math_splits"

llm_model = "Llama-3.2-1B-Instruct"
prm_model = "Llama3.1-8B-PRM-Deepseek-Data"
prm_model = "Qwen2.5-Math-1.5B-Instruct-PRM-0.2"

llm_model_dir = f"{base_dir}/{llm_model}"
prm_model_dir = f"{base_dir}/{prm_model}"
prm_save_dir = f"{base_dir}/{prm_model}-Modified"

args = argparse.Namespace()
args.custom_cfg = f"config/sft_eval_mcts.yaml"
args.qaf = f"eval_data/math500_test.json"
args.sft_model_path = f"{llm_model_dir}"
args.rm_ckpt_path = f"{prm_model_dir}"
args.rm_save_path = f"{prm_save_dir}"

model = AutoModelForCausalLM.from_pretrained(
    args.sft_model_path, 
    trust_remote_code=True,
    #torch_dtype=torch.bfloat16,
    use_cache=False,
)

model = RewardModelWithValueHead(pretrained_model=model)
fix_valuehead(model, args.rm_ckpt_path, args.rm_save_path)

pretrained_model.model.embed_tokens.weight
tensor([[ 3.1281e-03,  1.7822e-02,  2.0996e-02,  ..., -5.2185e-03,
         -4.1992e-02, -3.3447e-02],
        [ 2.3682e-02, -2.2949e-02,  1.9897e-02,  ..., -9.4604e-03,
         -2.2125e-03, -3.9551e-02],
        [ 1.4465e-02,  1.0559e-02,  9.8267e-03,  ...,  6.8359e-03,
         -1.1597e-02,  5.7983e-03],
        ...,
        [-1.0580e-06,  1.0620e-02, -1.9043e-02,  ...,  1.3885e-03,
         -1.7700e-03,  9.7046e-03],
        [-1.3635e-06,  1.0620e-02, -1.9043e-02,  ...,  1.3885e-03,
         -1.7700e-03,  9.7046e-03],
        [-1.1921e-06,  1.0620e-02, -1.9043e-02,  ...,  1.3885e-03,
         -1.7700e-03,  9.7046e-03]])
pretrained_model.model.layers.0.self_attn.q_proj.weight
tensor([[-0.0179,  0.0066,  0.0247,  ..., -0.0087, -0.0117,  0.0201],
        [ 0.0122,  0.0593,  0.0552,  ..., -0.0332, -0.0154,  0.0108],
        [ 0.0178,  0.0155,  0.0344,  ..., -0.0386, -0.0386, -0.0276],
        ...,
        [ 0.0298,  0.0352,  0.0713,  ..., -0.0