In [None]:
from datasets import load_from_disk, concatenate_datasets

# Root directory holding one saved DatasetDict per source corpus.
_DATA_ROOT = "../dataset/분류 모델 용"


def _load_test_subset(subdir, n_rows):
    """Return the first `n_rows` rows of the "test" split saved under _DATA_ROOT/subdir."""
    saved = load_from_disk(f"{_DATA_ROOT}/{subdir}")
    return saved["test"].select(range(n_rows))


# NOTE(review): the evaluation outputs saved further down report different
# per-split sizes (e.g. 2000 items for cloth), so they predate these limits.
ds_korean = _load_test_subset("korean", 124)
ds_cloth = _load_test_subset("cloth", 476)
ds_race_middle_long = _load_test_subset("race_middle_long", 175)
ds_race_middle_short = _load_test_subset("race_middle_short", 175)
ds_race_high_long = _load_test_subset("race_high_long", 525)
ds_race_high_short = _load_test_subset("race_high_short", 525)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
sys.path.append('..')

from huggingface_hub import login
import torch
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

In [None]:
import os

from dotenv import load_dotenv

# Read HUGGINGFACE_TOKEN (and any other settings) from a local .env file.
# `os` is imported here because no earlier cell imports it; without it this
# cell fails on a fresh kernel.
load_dotenv()
token = os.getenv("HUGGINGFACE_TOKEN")

# Checkpoint under evaluation; the commented-out id is the hub checkpoint it
# was presumably fine-tuned from.
model_id = "./output/3"
# model_id = "google/gemma-3-4b-it"

torch_dtype = torch.bfloat16

# Authenticate only when a token is configured: login() without a token falls
# back to an interactive prompt, which stalls a "Run All".
if token:
    login(token)

model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch_dtype,
    device_map="auto",
)

# 4-bit NF4 quantisation with double quantisation; compute and storage dtypes
# follow the model dtype chosen above.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# Training configuration carried over from fine-tuning.
# NOTE(review): `args` is not referenced by any evaluation code in this notebook.
args = SFTConfig(
    output_dir="output/results",
    packing=True,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=10,
    save_strategy="epoch",
    fp16=torch_dtype == torch.float16,
    bf16=torch_dtype == torch.bfloat16,

    #eval_strategy="epoch",
    #per_device_eval_batch_size=6,

    lr_scheduler_type="constant",
    push_to_hub=False,
    report_to="tensorboard",
    dataset_kwargs={
        "add_special_tokens": False,
        "append_concat_token": True,
        }
)

model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.25s/it]


In [None]:
import re
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnAnswerIs(StoppingCriteria):
    """Stop generation once a sequence's decoded tail ends with "answer is".

    The evaluation loop then reads the next-token logits over the option
    letters, so generation must halt just before the letter would be emitted.

    Args:
        tokenizer: object exposing ``decode(ids, skip_special_tokens=...)``.
        pattern: regex matched case-insensitively against the decoded tail.
        lookback_tokens: number of trailing tokens decoded on every step.
    """

    def __init__(self, tokenizer, pattern=r"answer\s*is\s*:?\s*$", lookback_tokens=64):
        super().__init__()
        self.tokenizer = tokenizer
        self.regex = re.compile(pattern, flags=re.IGNORECASE)
        self.lookback = lookback_tokens

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """Return a ``(batch_size,)`` bool tensor, True for rows whose tail matches.

        Every row is checked (not just row 0), matching the per-sequence
        return contract of transformers' StoppingCriteria.
        """
        tails = input_ids[:, -self.lookback:].tolist()
        done = [
            bool(self.regex.search(self.tokenizer.decode(tail, skip_special_tokens=True)))
            for tail in tails
        ]
        return torch.tensor(done, dtype=torch.bool, device=input_ids.device)


In [None]:
import build_prompt as bp

# Device of the model's first parameter (with device_map="auto" the model may
# span several devices). evaluate_until_answer_is recomputes this locally.
device = next(model.parameters()).device

def safe_letter_token_id(tokenizer, letter: str) -> int:
    """Return the single vocabulary id that represents an option letter (e.g. "A").

    Lookup order:
      1. vocabulary entries "▁A" (SentencePiece word-start), " A", then "A";
      2. encoding " A", then "A", accepting the result only if it is one token.

    Args:
        tokenizer: object exposing ``get_vocab()`` and
            ``encode(text, add_special_tokens=False)``.
        letter: option letter; case-insensitive.

    Raises:
        ValueError: if no single-token representation exists, rather than
            silently returning None and corrupting the logit indexing later.
    """
    vocab = tokenizer.get_vocab()
    letter = letter.upper()
    for tok in (f"▁{letter}", f" {letter}", letter):
        if tok in vocab:
            return vocab[tok]
    for text in (f" {letter}", letter):
        ids = tokenizer.encode(text, add_special_tokens=False)
        if len(ids) == 1:
            return ids[0]
    raise ValueError(f"No single-token id found for option letter {letter!r}")

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[{'role': 'system', 'content': "You are a skilled English test-solving tutor.\nIn the passage, any text referred to as 'underlined' is shown in bold. Example: **Was Underline**\nRead the problem and reason step by step.\nUse no more than 5 steps, shorter is better.\nmax 50 characters per step, and write the final answer at the end."}, {'role': 'user', 'content': '\n[QUESTION]\n다음 글의 요지로 가장 적절한 것은?\n\n[PASSAGE]\nPeople sometimes make downward social comparisons ―\ncomparing themselves to inferior or worse-off others ― to feel\nbetter about themselves. This is self-enhancement at work.\nBut what happens when the only available comparison target\nwe have is superior or better off than we are? Can\nself-enhancement motives still be served in such situations?\nYes, they can, as captured by the self-evaluation maintenance\nmodel. According to this theory, we shift between two\nprocesses ― reflection and comparison ― in a way that lets us\nmaintain favorable self-views. In areas that are not 

In [6]:
def extract_gold_letter(ex):
    """Return the reference answer held in a chat message.

    `ex` is the third message of the conversation built by
    `bp.create_conversation_for_generate` (presumably the assistant turn).
    Its "content" value is returned unchanged and later compared verbatim
    with the predicted option letter.
    """
    gold_answer = ex["content"]
    return gold_answer

In [None]:
import re, json, datetime
from tqdm.auto import tqdm
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

def evaluate_until_answer_is(
    ds,
    ds_name,
    bp,
    model,
    tokenizer,
    max_items=None,
    do_sample=False,
    temperature=0.7,
    top_p=0.9,
    lookback_tokens=64,
    save_results=True,
    max_new_tokens=256,
    output_dir="./eval",
):
    """Score a multiple-choice dataset by generating up to "answer is" and reading
    the next-token distribution over the option letters.

    For each item the model generates (greedy by default) until StopOnAnswerIs
    fires or `max_new_tokens` is reached. One extra forward pass over the whole
    sequence gives next-token logits, restricted to the letters A-E. The arg-max
    letter is the prediction. Items whose generation did not end with
    "answer is" are counted as skipped: their prediction is still recorded but
    excluded from the accuracy.

    Args:
        ds: indexable dataset; each row goes through
            `bp.create_conversation_for_generate`.
        ds_name: label used in the progress bar and the output filename.
        bp: prompt-building module providing `create_conversation_for_generate`.
        model, tokenizer: causal LM and its tokenizer.
        max_items: evaluate at most this many rows (None = all rows).
        do_sample, temperature, top_p: sampling controls. temperature/top_p are
            only used when `do_sample` is True.
        lookback_tokens: tokens decoded by the stop criterion on each step.
        save_results: write the summary JSON into `output_dir`.
        max_new_tokens: generation budget per item.
        output_dir: directory for result files; created if it does not exist.

    Returns:
        dict with item counts, "accuracy_no_skipped" (a formatted percentage
        string) and the per-item "results".
    """
    import os

    device = next(model.parameters()).device
    model.eval()

    stopper = StoppingCriteriaList([StopOnAnswerIs(tokenizer, lookback_tokens=lookback_tokens)])
    # Same pattern as StopOnAnswerIs: did generation actually stop on "answer is"?
    answer_is_re = re.compile(r"answer\s*is\s*:?\s*$", re.IGNORECASE)
    option_letters = ['A', 'B', 'C', 'D', 'E']
    option_token_ids = [safe_letter_token_id(tokenizer, L) for L in option_letters]

    if do_sample:
        sampling_kwargs = dict(temperature=temperature, top_p=top_p)
    else:
        # Explicit None overrides sampling values inherited from
        # model.generation_config (including top_k), which otherwise trigger a
        # "generation flags are not valid" warning on every single item.
        sampling_kwargs = dict(temperature=None, top_p=None, top_k=None)

    N = len(ds) if max_items is None else min(max_items, len(ds))
    correct = 0
    evaluated = 0
    stopped_cnt = 0
    skipped_cnt = 0
    results = []

    pbar = tqdm(range(N), desc=ds_name)
    for i in pbar:
        conversation = bp.create_conversation_for_generate(ds[i])
        # messages[0:2] form the prompt; messages[2] carries the reference answer.
        msgs = conversation['messages'][:2]
        gold = extract_gold_letter(conversation['messages'][2])

        prompt_ids = tokenizer.apply_chat_template(
            msgs, add_generation_prompt=True, return_tensors="pt"
        ).to(device)

        gen_out = model.generate(
            prompt_ids,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            stopping_criteria=stopper,
            return_dict_in_generate=True,
            use_cache=True,
            **sampling_kwargs,
        )
        seq = gen_out.sequences
        gen_text = tokenizer.decode(seq[0, prompt_ids.shape[1]:], skip_special_tokens=True)

        is_skipped = answer_is_re.search(gen_text.strip()) is None
        if is_skipped:
            skipped_cnt += 1
        else:
            stopped_cnt += 1

        # generate() does not return the distribution *after* its last token,
        # so run one more forward pass to get it.
        # NOTE(review): this materialises logits for every position; newer
        # transformers accept `logits_to_keep=1` to return only the last one -
        # confirm support before enabling.
        with torch.inference_mode():
            next_logits = model(seq.to(device)).logits[0, -1, :]

        # Upcast before softmax: bf16 probabilities lose precision and can tie.
        option_logits = next_logits[option_token_ids].float()
        option_probs = torch.softmax(option_logits, dim=0)

        probs = option_probs.detach().cpu().tolist()
        pred_idx = int(torch.argmax(option_probs).item())
        pred = option_letters[pred_idx]
        is_correct = (pred == gold)

        if not is_skipped:
            evaluated += 1
            correct += int(is_correct)

        results.append({
            "index": i,
            "generated_until_stop": gen_text,
            "pred": pred,
            "probs": {L: float(p) for L, p in zip(option_letters, probs)},
            "gold": gold,
            "is_correct": is_correct,
            "is_skipped": is_skipped
        })

        pbar.set_postfix(
            acc=f"{(correct/max(1,evaluated))*100:.2f}%",
            stopped=stopped_cnt,
            skipped=skipped_cnt
        )

    final = {
        "model_id": getattr(model.config, "name_or_path", "unknown"),
        "num_items_total": N,
        "num_evaluated": evaluated,
        "num_stopped": stopped_cnt,
        "num_skipped": skipped_cnt,
        "accuracy_no_skipped": f"{(correct/max(1,evaluated))*100:.2f}%",
        "results": results
    }

    if save_results:
        # open() does not create directories; make sure the target exists.
        os.makedirs(output_dir, exist_ok=True)
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        path = os.path.join(output_dir, f"results_{ds_name}_{ts}.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(final, f, ensure_ascii=False, indent=2)
        print(f"\nSaved: {path}")

    print(f"\nEvaluated {evaluated}/{N} (stopped={stopped_cnt}, skipped={skipped_cnt})")
    print(f"Accuracy: {final['accuracy_no_skipped']}")
    return final

In [8]:
# Evaluation sets keyed by the label used in progress bars and result filenames.
_eval_sets = {
    "ds_korean": ds_korean,
    "ds_cloth": ds_cloth,
    "ds_race_middle_long": ds_race_middle_long,
    "ds_race_middle_short": ds_race_middle_short,
    "ds_race_high_long": ds_race_high_long,
    "ds_race_high_short": ds_race_high_short,
}
ds = list(_eval_sets.values())
ds_name = list(_eval_sets.keys())

In [9]:
# Evaluate every dataset; iterate the parallel lists directly instead of a
# hard-coded range(6), so adding or removing a set cannot desynchronise them.
assert len(ds) == len(ds_name), "ds and ds_name must have the same length"

eval_results = {}  # ds_name -> summary dict returned by evaluate_until_answer_is
for dataset, name in zip(ds, ds_name):
    eval_results[name] = evaluate_until_answer_is(
        dataset,
        name,
        bp,
        model,
        tokenizer,
        max_items=None,   # None = evaluate the whole split
        do_sample=False,  # greedy decoding
        save_results=True,
    )

  0%|          | 0/124 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  1%|          | 1/124 [00:07<16:14,  7.92s/it, acc=100.00%, skipped=0, stopped=1]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  2%|▏         | 2/124 [00:12<11:41,  5.75s/it, acc=100.00%, skipped=0, stopped=2]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  2%|▏         | 3/124 [00:17<11:27,  5.68s/it, acc=100.00%, skipped=0, stopped=3]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  3%|▎         | 4/124 [00:23<11:44,  5.87s/it, acc=100.00%, skipped=0, stopped=4]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more de


Saved: ./eval/results_ds_korean_20250908_083715.json

Evaluated 123/124 (stopped=123, skipped=1)
Accuracy: 71.54%


  0%|          | 0/2000 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 1/2000 [00:05<2:49:46,  5.10s/it, acc=100.00%, skipped=0, stopped=1]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 2/2000 [00:11<3:12:41,  5.79s/it, acc=100.00%, skipped=0, stopped=2]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 3/2000 [00:17<3:14:32,  5.85s/it, acc=100.00%, skipped=0, stopped=3]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 4/2000 [00:21<2:49:06,  5.08s/it, acc=75.00%, skipped=0, stopped=4] The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info


Saved: ./eval/results_ds_cloth_20250908_115307.json

Evaluated 1998/2000 (stopped=1998, skipped=2)
Accuracy: 54.65%


  0%|          | 0/772 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 1/772 [00:04<1:02:28,  4.86s/it, acc=100.00%, skipped=0, stopped=1]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 2/772 [00:11<1:18:37,  6.13s/it, acc=100.00%, skipped=0, stopped=2]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 3/772 [00:16<1:07:44,  5.29s/it, acc=66.67%, skipped=0, stopped=3] The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  1%|          | 4/772 [00:21<1:10:22,  5.50s/it, acc=75.00%, skipped=0, stopped=4]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for 


Saved: ./eval/results_ds_race_middle_long_20250908_130238.json

Evaluated 772/772 (stopped=772, skipped=0)
Accuracy: 73.96%


  0%|          | 0/664 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 1/664 [00:06<1:12:09,  6.53s/it, acc=0.00%, skipped=0, stopped=1]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 2/664 [00:10<54:37,  4.95s/it, acc=50.00%, skipped=0, stopped=2]  The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 3/664 [00:15<55:23,  5.03s/it, acc=66.67%, skipped=0, stopped=3]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  1%|          | 4/664 [00:19<51:44,  4.70s/it, acc=75.00%, skipped=0, stopped=4]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more det


Saved: ./eval/results_ds_race_middle_short_20250908_140034.json

Evaluated 662/664 (stopped=662, skipped=2)
Accuracy: 73.56%


  0%|          | 0/1686 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 1/1686 [00:06<3:03:04,  6.52s/it, acc=100.00%, skipped=0, stopped=1]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 2/1686 [00:13<3:17:00,  7.02s/it, acc=50.00%, skipped=0, stopped=2] The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 3/1686 [00:22<3:39:24,  7.82s/it, acc=33.33%, skipped=0, stopped=3]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 4/1686 [00:28<3:18:21,  7.08s/it, acc=25.00%, skipped=0, stopped=4]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` 


Saved: ./eval/results_ds_race_high_long_20250908_165011.json

Evaluated 1686/1686 (stopped=1686, skipped=0)
Accuracy: 63.17%


  0%|          | 0/1812 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 1/1812 [00:05<2:59:16,  5.94s/it, acc=100.00%, skipped=0, stopped=1]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 2/1812 [00:10<2:42:49,  5.40s/it, acc=100.00%, skipped=0, stopped=2]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 3/1812 [00:16<2:42:14,  5.38s/it, acc=100.00%, skipped=0, stopped=3]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 4/1812 [00:22<2:47:28,  5.56s/it, acc=100.00%, skipped=0, stopped=4]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info


Saved: ./eval/results_ds_race_high_short_20250908_193820.json

Evaluated 1812/1812 (stopped=1812, skipped=0)
Accuracy: 63.74%



