In [10]:
from datasets import load_from_disk, concatenate_datasets

# All evaluation sets are saved side by side under one directory; only the subfolder differs.
_DATASET_ROOT = "../dataset/분류 모델 용"


def _first_test_rows(subdir, n_rows):
    """Load the DatasetDict saved at ``_DATASET_ROOT/subdir`` and keep the first ``n_rows``
    rows of its ``"test"`` split."""
    return load_from_disk(f"{_DATASET_ROOT}/{subdir}")["test"].select(range(n_rows))


# Number of test rows taken from each evaluation set.
ds_korean = _first_test_rows("korean", 124)
ds_cloth = _first_test_rows("cloth", 476)
ds_race_middle_long = _first_test_rows("race_middle_long", 175)
ds_race_middle_short = _first_test_rows("race_middle_short", 175)
ds_race_high_long = _first_test_rows("race_high_long", 525)
ds_race_high_short = _first_test_rows("race_high_short", 525)


In [11]:
# Spot-check one raw Korean test record (the same kind of item later passed to create_conversation).
_peek_index = 3
print(ds_korean[_peek_index])

{'example_id': 3, 'article': 'People sometimes make downward social comparisons ―\ncomparing themselves to inferior or worse-off others ― to feel\nbetter about themselves. This is self-enhancement at work.\nBut what happens when the only available comparison target\nwe have is superior or better off than we are? Can\nself-enhancement motives still be served in such situations?\nYes, they can, as captured by the self-evaluation maintenance\nmodel. According to this theory, we shift between two\nprocesses ― reflection and comparison ― in a way that lets us\nmaintain favorable self-views. In areas that are not especially\nrelevant to our self-definition, we engage in reflection,\nwhereby we flatter ourselves by association with others’\naccomplishments. Suppose you care very little about your\nown athletic skills, but when your friend scores the winning\ngoal during a critical soccer match, you beam with pride,\nexperience a boost to your self-esteem, and take delight in\nher victory cele

In [12]:
import sys
sys.path.append('..')

from huggingface_hub import login
import torch
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

In [None]:
# `os` must be imported here: this cell runs before the later cell that also imports it,
# so on a fresh kernel `os.getenv` below would otherwise raise NameError.
import os

from dotenv import load_dotenv

# Read HUGGINGFACE_TOKEN from a local .env file instead of hard-coding the secret.
load_dotenv()
token = os.getenv("HUGGINGFACE_TOKEN")

# Fine-tuned checkpoint under evaluation; switch to the hub id to score the base model.
model_id = "./output/4/checkpoint-2841"
# model_id = "google/gemma-3-4b-it"

torch_dtype = torch.bfloat16

# Only authenticate when a token was found. login(None) falls back to an interactive
# prompt, which stalls an unattended run.
if token:
    login(token)
else:
    print("HUGGINGFACE_TOKEN not set; skipping Hugging Face login.")

model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch_dtype,
    device_map="auto",
)

# 4-bit NF4 quantization with double quantization; compute and storage dtypes follow torch_dtype.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# Training configuration. NOTE(review): none of the visible evaluation cells use `args`;
# it is kept so any other cell that references it keeps working.
args = SFTConfig(
    output_dir="output/results",
    packing=True,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=10,
    save_strategy="epoch",
    fp16=True if torch_dtype == torch.float16 else False,
    bf16=True if torch_dtype == torch.bfloat16 else False,

    #eval_strategy="epoch",
    #per_device_eval_batch_size=6,

    lr_scheduler_type="constant",
    push_to_hub=False,
    report_to="tensorboard",
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": True,
        }
)

model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id)

KeyboardInterrupt: 

: 

In [None]:
import build_dataset as bd

# Device holding the model's first parameter. NOTE(review): evaluate_until_answer_is
# recomputes its own device, so this global is only for interactive use.
_first_param = next(iter(model.parameters()))
device = _first_param.device

def safe_letter_token_id(tokenizer, letter: str) -> int:
    """Return the single vocabulary id used to score answer option ``letter``.

    The letter is upper-cased first. It is then looked up in ``tokenizer.get_vocab()``
    in this order:

    1. the SentencePiece word-start piece (``"▁A"``);
    2. a space-prefixed piece (``" A"``);
    3. the bare letter (``"A"``).

    If none of these is present, the result of
    ``tokenizer.encode(" A", add_special_tokens=False)`` is used, but only when that
    encoding is exactly one id.

    Args:
        tokenizer: object exposing ``get_vocab()`` and
            ``encode(text, add_special_tokens=...)``.
        letter: the option letter (case-insensitive).

    Returns:
        The token id representing the letter.

    Raises:
        ValueError: if ``letter`` is empty or has no single-token representation.
    """
    if not letter:
        raise ValueError("letter must be a non-empty string")
    vocab = tokenizer.get_vocab()
    letter = letter.upper()
    for tok in (f"▁{letter}", f" {letter}", letter):
        if tok in vocab:
            return vocab[tok]
    ids = tokenizer.encode(f" {letter}", add_special_tokens=False)
    if len(ids) == 1:
        return ids[0]
    # Fail loudly here. An implicit None would only surface later as a confusing
    # error when the id list is used to index the logits.
    raise ValueError(
        f"Option letter {letter!r} has no single-token id "
        f"(encode(' {letter}') -> {ids})"
    )

In [None]:
def extract_gold_letter(ex):
    """Return the reference answer held in one chat message.

    ``ex`` is a message dict; its ``"content"`` value is returned unchanged. The caller
    passes the third message of the built conversation. A missing key raises KeyError.
    """
    gold = ex["content"]
    return gold

In [None]:
import re, json, datetime
from tqdm.auto import tqdm
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
import os

def evaluate_until_answer_is(
    ds,
    ds_name,
    bd,
    model,
    tokenizer,
    max_items=None,
    do_sample=False,
    temperature=0.7,
    top_p=0.9,
    lookback_tokens=64,
    save_results=True,
    option_letters=None,
):
    """Score a multiple-choice dataset from the next-token logits of the option letters.

    For each example:

    * ``bd.create_conversation`` builds a chat.
    * Its first two messages are rendered with the tokenizer's chat template, with the
      generation prompt appended.
    * One forward pass is run, and the logits of the option-letter tokens at the last
      position are compared.
    * The highest-scoring letter is the prediction. The gold answer is the ``content``
      of the third message.

    Args:
        ds: indexable dataset (``len(ds)``, ``ds[i]``) of raw examples.
        ds_name: label used in the saved results filename.
        bd: module providing ``create_conversation(example)``. It must return a dict
            whose ``"messages"`` list has at least three entries.
        model: causal LM whose output has ``.logits`` indexable as ``[batch, position, vocab]``.
        tokenizer: tokenizer with a chat template.
        max_items: evaluate only the first ``max_items`` examples (``None`` means all).
        do_sample, temperature, top_p, lookback_tokens: accepted for backward
            compatibility but unused. Scoring is a single forward pass; no text is generated.
        save_results: if True, write the summary to
            ``./eval/results_{ds_name}_{timestamp}.json``.
        option_letters: candidate answer letters (default ``A``-``E``), normalised to
            upper case. Pass e.g. ``list("ABCD")`` for four-option datasets so the model
            cannot pick a letter that is never a valid answer.

    Returns:
        dict with the model id, item counts, accuracy, and per-item ``results``.
        Accuracy is given both as a formatted percentage string
        (``"accuracy_no_skipped"``) and as a float fraction (``"accuracy"``).

    Raises:
        ValueError: if ``option_letters`` is empty. Errors from token-id lookup
            also propagate.
    """
    device = next(model.parameters()).device
    model.eval()

    # Upper-case the letters so predictions line up with the upper-cased token lookup.
    if option_letters is None:
        option_letters = ['A', 'B', 'C', 'D', 'E']
    option_letters = [str(L).upper() for L in option_letters]
    if not option_letters:
        raise ValueError("option_letters must contain at least one letter")
    option_token_ids = [safe_letter_token_id(tokenizer, L) for L in option_letters]

    # Clamp so a negative max_items yields zero items instead of a negative total.
    N = len(ds) if max_items is None else max(0, min(max_items, len(ds)))
    correct = 0
    evaluated = 0
    # Always 0 in this scoring mode (nothing is generated, stopped, or skipped);
    # kept so the progress bar and saved JSON keep their fields.
    stopped_cnt = 0
    skipped_cnt = 0
    results = []

    pbar = tqdm(range(N))
    for i in pbar:
        conversation = bd.create_conversation(ds[i])
        # Messages 0-1 form the prompt; message 2 carries the gold answer.
        msgs = conversation['messages'][:2]
        gold = extract_gold_letter(conversation['messages'][2])

        prompt_ids = tokenizer.apply_chat_template(
            msgs, add_generation_prompt=True, return_tensors="pt"
        ).to(device)

        with torch.inference_mode():
            out = model(prompt_ids)
            # Logits for the first token the model would emit after the prompt.
            next_logits = out.logits[:, -1, :].squeeze(0)

        # Upcast before softmax/argmax. In bfloat16, nearly equal probabilities can
        # round to the same value, and argmax would then pick the first option.
        option_logits = next_logits[option_token_ids].float()
        option_probs = torch.softmax(option_logits, dim=0)

        probs = option_probs.cpu().tolist()
        pred_idx = int(torch.argmax(option_logits).item())
        pred = option_letters[pred_idx]
        is_correct = (pred == gold)

        evaluated += 1
        correct += int(is_correct)

        results.append({
            "index": i,
            "generated_until_stop": "",  # no generation in this mode; kept for schema compatibility
            "pred": pred,
            "probs": {L: float(p) for L, p in zip(option_letters, probs)},
            "gold": gold,
            "is_correct": is_correct,
            "is_skipped": False
        })

        pbar.set_postfix(
            acc=f"{(correct/max(1,evaluated))*100:.2f}%",
            stopped=stopped_cnt,
            skipped=skipped_cnt
        )

    accuracy = correct / max(1, evaluated)
    final = {
        "model_id": getattr(model.config, "name_or_path", "unknown"),
        "num_items_total": N,
        "num_evaluated": evaluated,
        "num_correct": correct,
        "num_stopped": stopped_cnt,
        "num_skipped": skipped_cnt,
        "accuracy_no_skipped": f"{accuracy*100:.2f}%",
        "accuracy": accuracy,
        "results": results
    }

    if save_results:
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        path = f"./eval/results_{ds_name}_{ts}.json"
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(final, f, ensure_ascii=False, indent=2)
        print(f"\nSaved: {path}")

    print(f"\nEvaluated {evaluated}/{N} (stopped={stopped_cnt}, skipped={skipped_cnt})")
    print(f"Accuracy: {final['accuracy_no_skipped']}")
    return final

In [None]:
# (name, dataset) pairs in evaluation order. Both parallel lists below are derived from
# this single table, so a name can never drift away from its dataset.
_named_eval_sets = (
    ("ds_korean", ds_korean),
    ("ds_cloth", ds_cloth),
    ("ds_race_middle_long", ds_race_middle_long),
    ("ds_race_middle_short", ds_race_middle_short),
    ("ds_race_high_long", ds_race_high_long),
    ("ds_race_high_short", ds_race_high_short),
)
ds = [pair[1] for pair in _named_eval_sets]
ds_name = [pair[0] for pair in _named_eval_sets]

In [None]:
# Evaluate every (dataset, name) pair. Check the lists are in sync instead of relying on a
# hard-coded count, which would silently skip or crash when a dataset is added or removed.
assert len(ds) == len(ds_name), f"{len(ds)} datasets but {len(ds_name)} names"

eval_summaries = {}
for dataset_i, name_i in zip(ds, ds_name):
    eval_summaries[name_i] = evaluate_until_answer_is(
        dataset_i,
        name_i,
        bd,
        model,
        tokenizer,
        max_items=None,
        do_sample=False,
        save_results=True,
    )

# Per-dataset accuracy at a glance (the full per-item results are in eval_summaries / ./eval).
{name: summary["accuracy_no_skipped"] for name, summary in eval_summaries.items()}

 56%|█████▋    | 70/124 [00:08<00:06,  8.01it/s, acc=82.86%, skipped=0, stopped=0]


KeyboardInterrupt: 