# Verifier의 성능 확인

Verifier가 선택한 답이 맞을 확률이 높은 것이 맞는지?

In [1]:
import os; os.chdir("../")
import sys; sys.path.append('scripts')

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
import yaml
import matplotlib.pyplot as plt
import torch
from datasets import load_from_disk

from transformers import (
    BitsAndBytesConfig,
    AutoModelForCausalLM, AutoTokenizer,
    
)
from tqdm import tqdm
from utils import HF_NAME_MAP
from utils import set_seed, init_tokenizer, validate_args, _extract_answer

config_path = "configs/basic.yml"
with open(config_path, 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)


def get_likelihood(model, prompt_tokens, label_tokens):
    """
    Compute the likelihood for multiple label tokens given a shared prompt.

    Args:
        model: The causal language model.
        prompt_tokens (torch.Tensor): The prompt tokens of shape (1, q_tokens).
        label_tokens (torch.Tensor): The label tokens of shape (n_samples, a_tokens).

    Returns:
        torch.Tensor: A tensor of shape (n_samples,) containing the log-likelihood for each sample.
    """
    n_samples = label_tokens.size(0)
    q_tokens = prompt_tokens.size(1)

    # Repeat the prompt tokens for each label
    repeated_prompt_tokens = prompt_tokens.repeat(n_samples, 1).to(label_tokens.device)  # Shape: (n_samples

    # Concatenate prompt and label tokens
    input_tokens = torch.cat([repeated_prompt_tokens, label_tokens], dim=1)  # Shape: (n_samples, q_tokens + a_tokens)

    with torch.no_grad():
        outputs = model(input_tokens)
        logits = outputs.logits.detach().cpu()  # Shape: (n_samples, seq_length, vocab_size)

    # Extract logits corresponding to label tokens
    label_start_idx = q_tokens  # Labels start after the prompt
    label_logits = logits[:, label_start_idx - 1:-1, :]  # Shape: (n_samples, a_tokens, vocab_size)

    # Compute log-probabilities for the label tokens
    log_probs = torch.log_softmax(label_logits, dim=-1)  # Shape: (n_samples, a_tokens, vocab_size)
    label_log_probs = log_probs.gather(2, label_tokens.unsqueeze(-1)).squeeze(-1)  # Shape: (n_samples, a_tokens)

    # Sum log-probabilities over all label tokens for each sample
    # total_log_likelihood = label_log_probs.sum(dim=1)  # Shape: (n_samples,)
    total_log_likelihood = label_log_probs.mean(dim=1)  # Shape: (n_samples,)

    return total_log_likelihood.detach().cpu()

In [3]:
model_name = "sft_llama-1b"
task_name = "gsm8k"
model_type, pt_name = model_name.split("_")
hf_name = HF_NAME_MAP[pt_name]

tokenizer = AutoTokenizer.from_pretrained(hf_name)
init_tokenizer(tokenizer)

dset = load_from_disk("data/ver_sft_llama-1b_gsm8k/test")

# verifier_path = f"models/veri_{model_name}_{task_name}"
# verifier_path = f"models/{model_name}_{task_name}"
paths = [
    "/home/chanwoo/chanwoo/repo/verifier/models/verifier/checkpoints/veri_sft_llama-1b_gsm8k/checkpoint-13149/target",
    f"models/{model_name}_{task_name}",
]


In [4]:
import re
from datetime import datetime

# verifier_path = "/home/chanwoo/chanwoo/repo/verifier/models/verifier/checkpoints/veri_sft_llama-8b_gsm8k/checkpoint-3600/target"
# verifier_path = "models/verifier/checkpoints/veri_sft_llama-1b_gsm8k/checkpoint-13149/target"
# verifier_path = "/home/chanwoo/chanwoo/repo/verifier/models/verifier/checkpoints/veri_sft_llama-1b_gsm8k/checkpoint-13149/target"
verifier_path = "/home/chanwoo/chanwoo/repo/verifier/models/veri_sft_llama-8b_gsm8k/target"
verifier = AutoModelForCausalLM.from_pretrained(
    verifier_path,
    quantization_config=BitsAndBytesConfig(**config['qt']),
    **config['model'][pt_name]
)


# gen_model = AutoModelForCausalLM.from_pretrained(
#     f"models/{model_name}_{task_name}",
#     quantization_config=BitsAndBytesConfig(**config['qt']),
#     **config['model'][pt_name]
# )

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
def parse_log(file_path, start_time_str):
    start_time = datetime.strptime(start_time_str, "%Y-%m-%d %H:%M:%S")
    results = []

    with open(file_path, "r") as log_file:
        lines = log_file.readlines()

    parsing = False
    current_entry = {}

    for line in lines:
        # Check if we should start parsing after the specific time
        match_time = re.match(r"\[(.*?)\]", line)
        if match_time:
            log_time = datetime.strptime(match_time.group(1), "%Y-%m-%d %H:%M:%S,%f")
            if log_time >= start_time and "Starting Evaluate script" in line:
                parsing = True
                print("Parsing started")


        if not parsing:
            continue

        # Parse Question
        elif "[INFO] - Question:" in line:
            question_match = re.search(r"Question: (.*)", line)
            if question_match:
                current_entry["Question"] = eval(question_match.group(1))  # Safely parse list

        # Parse Prediction
        elif "[INFO] - Prediction:" in line and "INFO" in line:
            prediction_match = re.search(r"Prediction: (.*)", line)
            if prediction_match:
                current_entry["Prediction"] = eval(prediction_match.group(1))  # Safely parse list

        # Parse Answer
        elif "[INFO] - Answer:" in line:
            answer_match = re.search(r"Answer: (.*)", line)
            if answer_match:
                current_entry["Answer"] = eval(answer_match.group(1))  # Safely parse list

        # If all fields are collected, save the entry and reset
        if all(key in current_entry for key in ["Question", "Prediction", "Answer"]):
            results.append(current_entry)
            current_entry = {}

    return results

# Example usage
log_path = "logs/test_verifier-8b.log"
start_time = "2024-12-11 14:25:34"
parsed_data = parse_log(log_path, start_time)

res = []
for entry in parsed_data:
    # print("Best Answer Index:", entry["Best answer index"])
    # print("Question:", entry["Question"])
    # print("Prediction:", entry["Prediction"])
    # print("Answer:", entry["Answer"])
    # print("-" * 80)
    # break
    try:
        answer_pat = r'####\s*\d+'
        res.append({
            'question': entry["Question"],
            'prediction': entry["Prediction"],
            'answer': entry["Answer"],
        })
    except Exception as e:
        print(e)

Parsing started


In [6]:
train_dataset = load_from_disk(f"data/ver_sft_llama-8b_gsm8k/train")
train_dataset[1]

{'prompt': 'Ken created a care package to send to his brother, who was away at boarding school.  Ken placed a box on a scale, and then he poured into the box enough jelly beans to bring the weight to 2 pounds.  Then, he added enough brownies to cause the weight to triple.  Next, he added another 2 pounds of jelly beans.  And finally, he added enough gummy worms to double the weight once again.  What was the final weight of the box of goodies, in pounds?',
 'chosen': 'In pounds, the box started at <<0=0>>0.\nKen added enough jelly beans to cause the weight to rise to 2 pounds, so now the weight was 2 pounds.\nA tripled weight, which is equal to 2 pounds, is 2*3=<<2*3=6>>6 pounds.\nKen then added another 2 pounds so now the weight was 6+2=<<2+6=8>>8 pounds.\nThen, he added enough gummy worms to double the weight again, 8 pounds, so the final weight becomes 8*2=<<8*2=16>>16 pounds.\n#### 16 pounds',
 'rejected': 'All told, 2 pounds of jelly beans were added to the care package, and the br

In [7]:
prompt = train_dataset['prompt'][100]
chosen = train_dataset['chosen'][100]
rejected = train_dataset['rejected'][100]

In [8]:
prompt_answer = tokenizer(prompt + chosen, return_tensors="pt").input_ids
chosen_token = tokenizer(chosen, return_tensors="pt").input_ids
logits = verifier(prompt_answer).logits.detach().cpu()
label_len = chosen_token.shape[1]
torch.gather(logits[:,-label_len:-1,:].log_softmax(-1), dim=2, index=chosen_token[:,1:].unsqueeze(2)).squeeze(2).sum()

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


tensor(-127.2425)

In [9]:
prompt_answer = tokenizer(prompt + rejected, return_tensors="pt").input_ids
rejected_token = tokenizer(rejected, return_tensors="pt").input_ids
logits = verifier(prompt_answer).logits.detach().cpu()
label_len = rejected_token.shape[1]
torch.gather(logits[:,-label_len:-1,:].log_softmax(-1), dim=2, index=rejected_token[:,1:].unsqueeze(2)).squeeze(2).sum()

tensor(-187.8689)

In [10]:
idx = 24
instruction = "Judge whether the reasoning from given question is reasonable?\nQuestion: "

def is_answer(answer, pred):
    pn, gn = _extract_answer(pred, answer)
    return pn == gn


In [11]:
print(entry["Question"][0])
print()
print(entry["Answer"][0])

Peter has 4 boxes with the same number of chocolate bars in each, while Martha has 7 boxes with the same number of chocolate bars in each. If Peter and Martha have totals of 64 and 56 chocolate bars respectively, how many more chocolate bars does Peter have in each box than Martha?

Peter has 64 chocolate bars in 4 equal boxes so there are 64/4 = <<64/4=16>>16 bars in each box
Martha has 56 chocolate bars in 7 equal boxes so there are 56/7 = <<56/7=8>>8 bars in each box
Peter has 16-8 = <<16-8=8>>8 bars more than Martha in each box
#### 8


In [12]:
print(entry["Prediction"][8])

 
Answer: Since for each box Martha has 56/7=<<56/7=8>>8 chocolate bars
That means in each box Martha has 8-4=<<8-4=4>>4 fewer chocolate bars than in each box does Peter
#### 4 fewer


In [13]:
from tqdm import tqdm
exp_res = []
for idx, entry in tqdm(enumerate(parsed_data[:100])):
    is_answers = [is_answer(entry["Answer"][0], pred) for pred in entry["Prediction"]]
    total_num_answer = sum(is_answers)
    questions = [instruction + entry["Question"][0] + pred for pred in entry["Prediction"]]

    # Step 2: Tokenize 입력 배치
    batch_inputs = tokenizer(questions, return_tensors='pt', padding=True, padding_side='right')
    attention_mask = batch_inputs["attention_mask"]


    # Step 3: Tokenize 정답 레이블 (패딩 포함)
    label_tokens = tokenizer(entry["Prediction"], return_tensors="pt", padding=True, padding_side='right')

    # Step 4: Verifier 모델 호출 (배치 처리)
    logits = verifier(batch_inputs.input_ids, attention_mask=attention_mask).logits.detach().cpu()

    # Step 5: Mask 생성 및 로그 확률 계산
    label_mask = (label_tokens.input_ids != tokenizer.pad_token_id)  # 패딩이 아닌 부분은 True
    shifted_labels = label_tokens.input_ids[:, 1:]  # 첫 번째 토큰 제외 (Decoder 방식)

    log_probs = torch.gather(
        logits[:, -shifted_labels.shape[1]-1:, :].log_softmax(-1),  # Label 길이에 맞춘 logits
        dim=2,
        index=shifted_labels.unsqueeze(2)  # 레이블 차원 확장
    ).squeeze(2)  # [Batch, Sequence]

    # Mask 적용 후, 각 문장의 총 로그 확률 계산
    masked_log_probs = log_probs * label_mask[:, 1:].float()  # 첫 번째 패딩 제외
    total_log_probs = masked_log_probs.sum(dim=1)  # 배치별 로그 확률 합계

    exp_res.append({
        "is_answer": is_answers,
        "total_num_answer": total_num_answer,
        "total_log_probs": total_log_probs,
        "questions": questions,
        "label_tokens": label_tokens,
        "logits": logits,
        "log_probs": log_probs,
        "masked_log_probs": masked_log_probs,
    })

100it [41:59, 25.19s/it]


In [22]:
import pickle

with open("exp_res.pkl", "wb") as f:
    pickle.dump(exp_res, f)

KeyboardInterrupt: 

: 

In [21]:
acc = 0
for i in range(100):
    ans_ind = exp_res[i]['total_log_probs'].argmax() 
    acc += exp_res[i]['is_answer'][ans_ind]
acc / 100

0.52

In [18]:
exp_res[0]['total_log_probs'].argmax()

tensor(6)

In [14]:
# Step 1: Instruction 준비
instruction = ""
questions = [instruction + entry["Question"][0] + pred for pred in entry["Prediction"]]

# Step 2: Tokenize 입력 배치
batch_inputs = tokenizer(questions, return_tensors='pt', padding=True, padding_side='right')
attention_mask = batch_inputs["attention_mask"]


# Step 3: Tokenize 정답 레이블 (패딩 포함)
label_tokens = tokenizer(entry["Prediction"], return_tensors="pt", padding=True, padding_side='right')

# Step 4: Verifier 모델 호출 (배치 처리)
logits = gen_model(batch_inputs.input_ids, attention_mask=attention_mask).logits.detach().cpu()

# Step 5: Mask 생성 및 로그 확률 계산
label_mask = (label_tokens.input_ids != tokenizer.pad_token_id)  # 패딩이 아닌 부분은 True
shifted_labels = label_tokens.input_ids[:, 1:]  # 첫 번째 토큰 제외 (Decoder 방식)

log_probs = torch.gather(
    logits[:, -shifted_labels.shape[1]-1:, :].log_softmax(-1),  # Label 길이에 맞춘 logits
    dim=2,
    index=shifted_labels.unsqueeze(2)  # 레이블 차원 확장
).squeeze(2)  # [Batch, Sequence]

# Mask 적용 후, 각 문장의 총 로그 확률 계산
masked_log_probs = log_probs * label_mask[:, 1:].float()  # 첫 번째 패딩 제외
total_log_probs = masked_log_probs.sum(dim=1)  # 배치별 로그 확률 합계
print(total_log_probs)
print(total_log_probs.argmax())

NameError: name 'gen_model' is not defined

In [None]:
# instruction = ""
# # instruction = "Judge whether the reasoning from given question is reasonable?\nQuestion: "
# for idx in range(len(entry["Prediction"])):
#     prompt_answer = tokenizer(instruction + entry["Question"][0] + entry["Prediction"][idx], return_tensors='pt').input_ids
#     label_token = tokenizer(entry["Prediction"][idx], return_tensors="pt").input_ids
#     logits = verifier(prompt_answer).logits.detach().cpu()
#     label_len = label_token.shape[1]
#     logp = torch.gather(logits[:,-label_len:-1,:].log_softmax(-1), dim=2, index=label_token[:,1:].unsqueeze(2)).squeeze(2).sum()
    
#     # logp = 0
#     # for i in range(1, label_token.shape[1]):
#     #     logp += logits[:, -label_len-1+i, :].log_softmax(-1)[:,label_token[0,i]]
#     # logp
#     print(idx, logp)
#     # break
