In [None]:
from datasets import load_from_disk, concatenate_datasets

# Folder that holds every classification-model dataset saved with `save_to_disk`.
_CLS_DATA_ROOT = "../dataset/분류 모델 용"


def _test_head(subset, n_rows):
    """Load `subset` from disk and keep the first `n_rows` rows of its "test" split."""
    return load_from_disk(f"{_CLS_DATA_ROOT}/{subset}")["test"].select(range(n_rows))


# Fixed-size evaluation slices, one per source corpus.
ds_korean = _test_head("korean", 124)
ds_cloth = _test_head("cloth", 476)
ds_race_middle_long = _test_head("race_middle_long", 175)
ds_race_middle_short = _test_head("race_middle_short", 175)
ds_race_high_long = _test_head("race_high_long", 525)
ds_race_high_short = _test_head("race_high_short", 525)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
sys.path.append('..')

from huggingface_hub import login
import torch
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

In [None]:
# `os` is needed for os.getenv below; no visible cell imported it, so a fresh
# kernel would fail with NameError without this import.
import os

from dotenv import load_dotenv

# Read variables from a local .env file into the process environment.
load_dotenv()
token = os.getenv("HUGGINGFACE_TOKEN")

# Checkpoint to evaluate; swap in the commented line to evaluate the base model.
model_id = "./output/0"
# model_id = "google/gemma-3-4b-it"

torch_dtype = torch.bfloat16

# Authenticate against the Hugging Face Hub with the token read above.
login(token)

model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch_dtype,
    device_map="auto",
)

# 4-bit NF4 quantization with double quantization; compute and storage dtypes
# follow the dtype chosen for the model.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# NOTE(review): `args` is not referenced by any evaluation cell visible in this
# notebook — confirm whether it is still needed here.
args = SFTConfig(
    output_dir="output/results",
    packing=True,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=10,
    save_strategy="epoch",
    # Exactly one of fp16/bf16 is enabled, matching `torch_dtype`.
    fp16=True if torch_dtype == torch.float16 else False,
    bf16=True if torch_dtype == torch.bfloat16 else False,

    #eval_strategy="epoch",
    #per_device_eval_batch_size=6,

    lr_scheduler_type="constant",
    push_to_hub=False,
    report_to="tensorboard",
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": True,
        }
)

# Load the (quantized) model and its tokenizer from the same checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.98s/it]


In [None]:
import re
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnAnswerIs(StoppingCriteria):
    """Stopping criterion that fires when the decoded tail of the sequence matches a regex.

    With the default pattern, generation stops as soon as the text ends in
    "answer is" (optionally followed by a colon and whitespace), matched
    case-insensitively.

    Args:
        tokenizer: Object whose `decode` turns token ids back into text.
        pattern: Regular expression searched for in the decoded tail.
        lookback_tokens: Number of trailing token ids decoded on every call.
    """

    def __init__(self, tokenizer, pattern=r"answer\s*is\s*:?\s*$", lookback_tokens=64):
        super().__init__()
        self.tokenizer = tokenizer
        self.lookback = lookback_tokens
        self.regex = re.compile(pattern, flags=re.IGNORECASE)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last `lookback` tokens of row 0 decode to matching text."""
        # Only the first row of `input_ids` is inspected.
        recent_ids = input_ids[0, -self.lookback:].tolist()
        decoded_tail = self.tokenizer.decode(recent_ids, skip_special_tokens=True)
        return self.regex.search(decoded_tail) is not None


In [None]:
import build_prompt as bp

# Device that holds the model's first parameter.
device = next(model.parameters()).device 

def safe_letter_token_id(tokenizer, letter: str) -> int:
    """Return the id of a single vocabulary token that represents `letter`.

    The letter is upper-cased, then looked up in the tokenizer vocabulary in
    this order: "▁X", " X", "X". If none of those entries exists, " X" is
    encoded without special tokens and accepted only when it yields exactly
    one id.

    Args:
        tokenizer: Object exposing `get_vocab()` and
            `encode(text, add_special_tokens=...)`.
        letter: The option letter to resolve (case-insensitive).

    Returns:
        The integer token id for the letter.

    Raises:
        ValueError: If the letter cannot be mapped to exactly one token.
    """
    vocab = tokenizer.get_vocab()
    letter = letter.upper()
    for tok in (f"▁{letter}", f" {letter}", letter):
        if tok in vocab:
            return vocab[tok]
    ids = tokenizer.encode(f" {letter}", add_special_tokens=False)
    if len(ids) == 1:
        return ids[0]
    # Previously this fell through and returned None, which only surfaced later
    # as a confusing indexing error; fail here with a clear message instead.
    raise ValueError(
        f"Could not map option letter {letter!r} to a single token id "
        f"(encoding ' {letter}' produced {len(ids)} ids)."
    )

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


[{'role': 'system', 'content': "You are a skilled English test-solving tutor.\nIn the passage, any text referred to as 'underlined' is shown in bold. Example: **Was Underline**\nRead the problem and reason step by step.\nUse no more than 5 steps, shorter is better.\nmax 50 characters per step, and write the final answer at the end."}, {'role': 'user', 'content': '\n[QUESTION]\n다음 글의 요지로 가장 적절한 것은?\n\n[PASSAGE]\nPeople sometimes make downward social comparisons ―\ncomparing themselves to inferior or worse-off others ― to feel\nbetter about themselves. This is self-enhancement at work.\nBut what happens when the only available comparison target\nwe have is superior or better off than we are? Can\nself-enhancement motives still be served in such situations?\nYes, they can, as captured by the self-evaluation maintenance\nmodel. According to this theory, we shift between two\nprocesses ― reflection and comparison ― in a way that lets us\nmaintain favorable self-views. In areas that are not 

In [6]:
def extract_gold_letter(ex):
    """Return the gold answer, i.e. the value stored under the "content" key of `ex`."""
    gold = ex["content"]
    return gold

In [None]:
import re, json, datetime
from tqdm.auto import tqdm
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

def evaluate_until_answer_is(
    ds,
    ds_name,
    bp,
    model,
    tokenizer,
    max_items=None,
    do_sample=False,
    temperature=0.7,
    top_p=0.9,
    lookback_tokens=64,
    save_results=True,
):
    """Generate reasoning until "answer is", then score the next-token option letter.

    For every item the model generates up to 256 new tokens, stopping early once
    the text ends in "answer is". A second forward pass over the generated
    sequence yields next-token logits, which are restricted to the tokens for
    A-E; the arg-max letter is the prediction. Items whose generation does not
    end in "answer is" are counted as skipped: they are still recorded in the
    results but excluded from the accuracy.

    Args:
        ds: Indexable dataset supporting `len()`.
        ds_name: Label used in the name of the saved results file.
        bp: Object providing `create_conversation_for_generate(item)`, which
            returns a dict whose "messages" list holds the prompt messages in
            positions 0-1 and the gold-answer message in position 2.
        model: Causal LM used for `generate` and for the scoring forward pass.
        tokenizer: Tokenizer matching `model`.
        max_items: Evaluate at most this many items; None means all of `ds`.
        do_sample: Passed to `generate`; when False, `temperature` and `top_p`
            are passed as None.
        temperature: Sampling temperature, used only when `do_sample` is True.
        top_p: Nucleus sampling threshold, used only when `do_sample` is True.
        lookback_tokens: Number of trailing tokens the stopping criterion decodes.
        save_results: When True, write the summary as JSON under ./eval.

    Returns:
        dict with the keys "model_id", "num_items_total", "num_evaluated",
        "num_stopped", "num_skipped", "accuracy_no_skipped" (formatted
        percentage string) and "results" (one record per item).

    Raises:
        ValueError: If an option letter has no single-token id.
    """
    import os  # local import: needed only to create the results directory

    device = next(model.parameters()).device
    model.eval()

    # Keep a handle on the criterion so the same compiled regex decides both
    # when generation stops and whether an item counts as stopped.
    answer_stop = StopOnAnswerIs(tokenizer, lookback_tokens=lookback_tokens)
    stopper = StoppingCriteriaList([answer_stop])
    option_letters = ['A', 'B', 'C', 'D', 'E']
    option_token_ids = [safe_letter_token_id(tokenizer, L) for L in option_letters]
    unresolved = [L for L, tid in zip(option_letters, option_token_ids) if tid is None]
    if unresolved:
        raise ValueError(f"No single-token id found for option letters: {unresolved}")

    N = len(ds) if max_items is None else min(max_items, len(ds))
    correct = 0
    evaluated = 0
    stopped_cnt = 0
    skipped_cnt = 0
    results = []

    pbar = tqdm(range(N))
    for i in pbar:
        item = bp.create_conversation_for_generate(ds[i])
        msgs = item['messages'][:2]
        gold = extract_gold_letter(item['messages'][2])

        # return_dict=True also yields the attention mask, so generate() does
        # not have to infer it (and warn) when pad and eos tokens coincide.
        encoded = tokenizer.apply_chat_template(
            msgs, add_generation_prompt=True, return_tensors="pt", return_dict=True
        ).to(device)
        prompt_ids = encoded["input_ids"]

        gen_out = model.generate(
            prompt_ids,
            attention_mask=encoded["attention_mask"],
            max_new_tokens=256,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            top_p=top_p if do_sample else None,
            stopping_criteria=stopper,
            return_dict_in_generate=True,
            use_cache=True,
        )
        seq = gen_out.sequences
        gen_text = tokenizer.decode(seq[0, prompt_ids.shape[1]:], skip_special_tokens=True)

        # Skipped: generation ended without reaching "answer is".
        is_skipped = answer_stop.regex.search(gen_text.strip()) is None
        if is_skipped:
            skipped_cnt += 1
        else:
            stopped_cnt += 1

        # Score the token that would follow the generated text.
        with torch.inference_mode():
            out2 = model(seq.to(device))
            next_logits = out2.logits[:, -1, :].squeeze(0)

        # Softmax over the option tokens only, so the probabilities sum to 1.
        option_logits = next_logits[option_token_ids]
        option_probs = torch.softmax(option_logits, dim=0)

        probs = option_probs.detach().cpu().float().tolist()
        pred_idx = int(torch.argmax(option_probs).item())
        pred = option_letters[pred_idx]

        is_correct = (pred == gold)

        # Skipped items are recorded below but do not count toward accuracy.
        if not is_skipped:
            evaluated += 1
            correct += int(is_correct)

        results.append({
            "index": i,
            "generated_until_stop": gen_text,
            "pred": pred,
            "probs": {L: float(p) for L, p in zip(option_letters, probs)},
            "gold": gold,
            "is_correct": is_correct,
            "is_skipped": is_skipped
        })

        pbar.set_postfix(
            acc=f"{(correct/max(1,evaluated))*100:.2f}%",
            stopped=stopped_cnt,
            skipped=skipped_cnt
        )

    final = {
        "model_id": getattr(model.config, "name_or_path", "unknown"),
        "num_items_total": N,
        "num_evaluated": evaluated,
        "num_stopped": stopped_cnt,
        "num_skipped": skipped_cnt,
        "accuracy_no_skipped": f"{(correct/max(1,evaluated))*100:.2f}%",
        "results": results
    }

    if save_results:
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        path = f"./eval/results_{ds_name}_{ts}.json"
        # Create the output folder first so a long run is not lost to a
        # FileNotFoundError when ./eval does not exist yet.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(final, f, ensure_ascii=False, indent=2)
        print(f"\nSaved: {path}")

    print(f"\nEvaluated {evaluated}/{N} (stopped={stopped_cnt}, skipped={skipped_cnt})")
    print(f"Accuracy: {final['accuracy_no_skipped']}")
    return final

In [8]:
# Evaluation datasets and their labels, kept in the same order so that
# position k of `ds_name` names position k of `ds`.
ds = [
    ds_korean,
    ds_cloth,
    ds_race_middle_long,
    ds_race_middle_short,
    ds_race_high_long,
    ds_race_high_short,
]
ds_name = [
    "ds_korean",
    "ds_cloth",
    "ds_race_middle_long",
    "ds_race_middle_short",
    "ds_race_high_long",
    "ds_race_high_short",
]

In [9]:
# Evaluate every dataset in turn. Iterating over the lists themselves (instead
# of a hard-coded range(6)) keeps the loop correct when datasets are added or
# removed; the assert catches the two lists drifting out of sync.
assert len(ds) == len(ds_name), "ds and ds_name must have the same length"
for eval_ds, eval_name in zip(ds, ds_name):
    evaluate_until_answer_is(
        eval_ds,
        eval_name,
        bp,
        model,
        tokenizer,
        max_items=None,
        do_sample=False,
        save_results=True,
    )

100%|██████████| 124/124 [07:25<00:00,  3.60s/it, acc=74.34%, skipped=11, stopped=113]



Saved: ./eval/results_ds_korean_20250908_014206.json

Evaluated 113/124 (stopped=113, skipped=11)
Accuracy: 74.34%


100%|██████████| 2000/2000 [1:11:54<00:00,  2.16s/it, acc=56.36%, skipped=41, stopped=1959]



Saved: ./eval/results_ds_cloth_20250908_025401.json

Evaluated 1959/2000 (stopped=1959, skipped=41)
Accuracy: 56.36%


100%|██████████| 772/772 [29:55<00:00,  2.33s/it, acc=73.95%, skipped=35, stopped=737]



Saved: ./eval/results_ds_race_middle_long_20250908_032357.json

Evaluated 737/772 (stopped=737, skipped=35)
Accuracy: 73.95%


100%|██████████| 664/664 [23:53<00:00,  2.16s/it, acc=77.40%, skipped=18, stopped=646]



Saved: ./eval/results_ds_race_middle_short_20250908_034750.json

Evaluated 646/664 (stopped=646, skipped=18)
Accuracy: 77.40%


100%|██████████| 1686/1686 [1:23:43<00:00,  2.98s/it, acc=66.20%, skipped=112, stopped=1574]



Saved: ./eval/results_ds_race_high_long_20250908_051134.json

Evaluated 1574/1686 (stopped=1574, skipped=112)
Accuracy: 66.20%


100%|██████████| 1812/1812 [1:13:33<00:00,  2.44s/it, acc=68.92%, skipped=65, stopped=1747]


Saved: ./eval/results_ds_race_high_short_20250908_062508.json

Evaluated 1747/1812 (stopped=1747, skipped=65)
Accuracy: 68.92%



