In [1]:
from datasets import load_from_disk, concatenate_datasets

# Load the "test" split of every benchmark saved on disk and keep only its
# first example (`select(range(1))`).
(
    ds_korean,
    ds_cloth,
    ds_race_middle_long,
    ds_race_middle_short,
    ds_race_high_long,
    ds_race_high_short,
) = (
    load_from_disk(f"../dataset/분류 모델 용/{subset}")["test"].select(range(1))
    for subset in (
        "korean",
        "cloth",
        "race_middle_long",
        "race_middle_short",
        "race_high_long",
        "race_high_short",
    )
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
sys.path.append('..')

from huggingface_hub import login
import torch
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
import os

In [3]:
from dotenv import load_dotenv

# Read the .env file so HUGGINGFACE_TOKEN becomes visible through os.getenv.
load_dotenv()
token = os.getenv("HUGGINGFACE_TOKEN")

# `model_id` is the adapter directory handed to PeftModel below;
# `base_model_id` is the checkpoint the adapter is attached to.
model_id = "./output/0"
base_model_id = "google/gemma-3-4b-it"

torch_dtype = torch.bfloat16

login(token)

# Keyword arguments for loading the base model: eager attention, automatic
# device placement and 4-bit NF4 quantization with double quantization.
# The compute and storage dtypes of the quantized weights follow `torch_dtype`.
model_kwargs = {
    "attn_implementation": "eager",
    "torch_dtype": torch_dtype,
    "device_map": "auto",
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_storage=torch_dtype,
    ),
}

# NOTE(review): no visible cell consumes `args` - confirm whether this SFT
# configuration is still needed in an evaluation-only notebook.
args = SFTConfig(
    output_dir="output/results",
    packing=True,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=10,
    save_strategy="epoch",
    # Exactly one mixed-precision flag is enabled, chosen from `torch_dtype`.
    fp16=torch_dtype == torch.float16,
    bf16=torch_dtype == torch.bfloat16,

    #eval_strategy="epoch",
    #per_device_eval_batch_size=6,

    lr_scheduler_type="constant",
    push_to_hub=False,
    report_to="tensorboard",
    dataset_kwargs={"add_special_tokens": False, "append_concat_token": True},
)

# The tokenizer and weights come from the base checkpoint; the adapter stored
# in `model_id` is then wrapped around the quantized base model.
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
base = AutoModelForCausalLM.from_pretrained(base_model_id, **model_kwargs)
model = PeftModel.from_pretrained(base, model_id)

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.16s/it]


In [4]:
import re
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnAnswerIs(StoppingCriteria):
    """Stopping criterion that fires when the decoded tail matches a regex.

    On every call the last ``lookback_tokens`` token ids of the first sequence
    are decoded (special tokens skipped) and searched with ``pattern``,
    compiled case-insensitively. With the default pattern, generation stops as
    soon as the text ends in "answer is", optionally followed by a colon and
    whitespace.

    Only row 0 of ``input_ids`` is inspected.
    """

    def __init__(self, tokenizer, pattern=r"answer\s*is\s*:?\s*$", lookback_tokens=64):
        super().__init__()
        self.lookback = lookback_tokens
        self.regex = re.compile(pattern, flags=re.IGNORECASE)
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode only a short window at the end of the sequence.
        window = input_ids[0][-self.lookback:]
        decoded = self.tokenizer.decode(window.tolist(), skip_special_tokens=True)
        return self.regex.search(decoded) is not None


In [5]:
import build_prompt as bp

# Device on which the model's first parameter lives.
device = next(model.parameters()).device

def safe_letter_token_id(tokenizer, letter: str) -> int:
    """Return the id of the single token that represents an option letter.

    The letter is upper-cased, then looked up in the tokenizer vocabulary in
    this order: ``"▁X"``, ``" X"``, ``"X"``. If none of these is a vocabulary
    entry, ``" X"`` is encoded without special tokens and accepted only when
    it maps to exactly one token.

    Args:
        tokenizer: Object exposing ``get_vocab()`` and
            ``encode(text, add_special_tokens=...)``.
        letter: Option letter such as ``"A"``; case is ignored.

    Returns:
        The integer token id for the letter.

    Raises:
        ValueError: If the letter cannot be mapped to a single token. The
            function previously fell through and returned ``None`` here,
            which only surfaced later as a confusing indexing failure.
    """
    letter = letter.upper()
    vocab = tokenizer.get_vocab()
    for tok in (f"▁{letter}", f" {letter}", letter):
        if tok in vocab:
            return vocab[tok]

    ids = tokenizer.encode(f" {letter}", add_special_tokens=False)
    if len(ids) == 1:
        return ids[0]

    raise ValueError(
        f"Could not map option letter {letter!r} to a single token id "
        f"(encoding ' {letter}' produced {len(ids)} tokens)."
    )

In [6]:
def extract_gold_letter(ex):
    """Return the gold answer stored under the ``"content"`` key of ``ex``."""
    gold = ex["content"]
    return gold

In [None]:
import re, json, datetime
from tqdm.auto import tqdm
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

def evaluate_until_answer_is(
    ds,
    ds_name,
    bp,
    model,
    tokenizer,
    max_items=None,
    do_sample=False,
    temperature=0.7,
    top_p=0.9,
    lookback_tokens=64,
    save_results=True,
):
    """Generate up to an "answer is" cue, then score the next option letter.

    For each item the first two chat messages become the prompt and the
    ``"content"`` of the third message is the gold letter. Generation runs
    until ``StopOnAnswerIs`` fires or 256 new tokens are produced. One more
    forward pass over the full sequence gives the next-token logits, which
    are restricted to the token ids of A-E; the argmax is the prediction.

    Items whose generated text does not end with the cue are marked as
    skipped. They are still recorded in ``results`` with their prediction,
    but are excluded from the accuracy.

    Args:
        ds: Indexable dataset supporting ``len(ds)`` and ``ds[i]``.
        ds_name: Label embedded in the result file name.
        bp: Module/object providing ``create_conversation_for_generate(item)``,
            which must return a mapping with a ``"messages"`` list.
        model: Causal LM used for ``generate`` and the scoring forward pass.
        tokenizer: Tokenizer matching ``model``.
        max_items: Evaluate only the first ``max_items`` items; ``None`` means
            all. Negative values are treated as 0.
        do_sample: Whether to sample during generation.
        temperature: Sampling temperature, forwarded only when ``do_sample``.
        top_p: Nucleus sampling threshold, forwarded only when ``do_sample``.
        lookback_tokens: Number of trailing tokens the stopping criterion
            decodes on each step.
        save_results: When true, dump the summary to
            ``./eval/results_{ds_name}_{timestamp}.json``.

    Returns:
        dict with ``model_id``, ``num_items_total``, ``num_evaluated``,
        ``num_stopped``, ``num_skipped``, ``accuracy_no_skipped`` (a percentage
        string) and ``results`` (one record per item).

    Raises:
        ValueError: If an option letter has no single-token id.
    """
    device = next(model.parameters()).device
    model.eval()

    stopper = StoppingCriteriaList([StopOnAnswerIs(tokenizer, lookback_tokens=lookback_tokens)])
    # Same cue the stopping criterion looks for, compiled once for all items.
    answer_cue = re.compile(r"answer\s*is\s*:?\s*$", re.IGNORECASE)

    option_letters = ['A', 'B', 'C', 'D', 'E']
    option_token_ids = [safe_letter_token_id(tokenizer, L) for L in option_letters]
    # A missing id would otherwise reach the logits indexing below and fail
    # there (or mis-index) with a far less helpful message.
    unresolved = [L for L, tok_id in zip(option_letters, option_token_ids) if tok_id is None]
    if unresolved:
        raise ValueError(f"No single-token id found for option letters: {unresolved}")

    N = len(ds) if max_items is None else max(0, min(max_items, len(ds)))
    correct = 0
    evaluated = 0
    stopped_cnt = 0
    skipped_cnt = 0
    results = []

    pbar = tqdm(range(N))
    for i in pbar:

        item = ds[i]
        item = bp.create_conversation_for_generate(item)
        msgs = item['messages'][:2]
        gold = extract_gold_letter(item['messages'][2])

        prompt_ids = tokenizer.apply_chat_template(
            msgs, add_generation_prompt=True, return_tensors="pt"
        ).to(device)

        gen_out = model.generate(
            prompt_ids,
            max_new_tokens=256,
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            top_p=top_p if do_sample else None,
            stopping_criteria=stopper,
            return_dict_in_generate=True,
            use_cache=True,
        )
        seq = gen_out.sequences
        # Only the newly generated part is decoded and stored.
        gen_text = tokenizer.decode(seq[0, prompt_ids.shape[1]:], skip_special_tokens=True)

        # Skipped = generation ended without reaching the answer cue.
        is_skipped = answer_cue.search(gen_text.strip()) is None

        # Logits for the token that would follow the generated sequence.
        with torch.inference_mode():
            out2 = model(seq.to(device))
            next_logits = out2.logits[:, -1, :].squeeze(0)

        # Softmax over the five option tokens only, so probabilities sum to 1.
        option_logits = next_logits[option_token_ids]
        option_probs = torch.softmax(option_logits, dim=0)

        probs = option_probs.detach().cpu().float().tolist()
        pred_idx = int(torch.argmax(option_probs).item())
        pred = option_letters[pred_idx]

        is_correct = (pred == gold)

        if is_skipped:
            skipped_cnt += 1
        else:
            stopped_cnt += 1
            evaluated += 1
            correct += int(is_correct)

        results.append({
            "index": i,
            "generated_until_stop": gen_text,
            "pred": pred,
            "probs": {L: float(p) for L, p in zip(option_letters, probs)},
            "gold": gold,
            "is_correct": is_correct,
            "is_skipped": is_skipped
        })

        pbar.set_postfix(
            acc=f"{(correct/max(1,evaluated))*100:.2f}%",
            stopped=stopped_cnt,
            skipped=skipped_cnt
        )

    final = {
        "model_id": getattr(model.config, "name_or_path", "unknown"),
        "num_items_total": N,
        "num_evaluated": evaluated,
        "num_stopped": stopped_cnt,
        "num_skipped": skipped_cnt,
        "accuracy_no_skipped": f"{(correct/max(1,evaluated))*100:.2f}%",
        "results": results
    }

    if save_results:
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        path = f"./eval/results_{ds_name}_{ts}.json"
        # Create ./eval on first use so the dump cannot fail on a missing folder.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(final, f, ensure_ascii=False, indent=2)
        print(f"\nSaved: {path}")

    print(f"\nEvaluated {evaluated}/{N} (stopped={stopped_cnt}, skipped={skipped_cnt})")
    print(f"Accuracy: {final['accuracy_no_skipped']}")
    return final

In [8]:
# Benchmark splits to evaluate; entries are index-aligned with `ds_name`.
ds = [
    ds_korean,
    ds_cloth,
    ds_race_middle_long,
    ds_race_middle_short,
    ds_race_high_long,
    ds_race_high_short,
]
# Labels passed to the evaluator, which embeds them in the result file names.
ds_name = [
    "ds_korean",
    "ds_cloth",
    "ds_race_middle_long",
    "ds_race_middle_short",
    "ds_race_high_long",
    "ds_race_high_short",
]

In [10]:
# Evaluate every benchmark with greedy decoding and save one result file each.
# Datasets and labels are paired with zip instead of a hard-coded count, so
# adding or removing an entry cannot silently skip a dataset or raise an
# IndexError.
if len(ds) != len(ds_name):
    raise ValueError(
        f"ds has {len(ds)} entries but ds_name has {len(ds_name)}; they must be index-aligned."
    )

for dataset, dataset_name in zip(ds, ds_name):
    evaluate_until_answer_is(
        dataset,
        dataset_name,
        bp,
        model,
        tokenizer,
        max_items=None,
        do_sample=False,
        save_results=True,
    )

  0%|          | 0/1 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|██████████| 1/1 [00:00<00:00,  2.36it/s, acc=0.00%, skipped=1, stopped=0]


FileNotFoundError: [Errno 2] No such file or directory: './eval/results_ds_korean_20250929_153839.json'

In [None]:
import json
from pathlib import Path

# Folder scanned for result JSON files (the evaluator writes into ./eval).
IN_DIR = Path("./eval")
# Output folders for the first half (OUT1) and second half (OUT2) of each file.
OUT1, OUT2 = Path("./eval_1"), Path("./eval_2")
for _out_dir in (OUT1, OUT2):
    _out_dir.mkdir(parents=True, exist_ok=True)

def recalc_stats(data: dict, new_results: list) -> dict:
    """Return a shallow copy of ``data`` whose summary describes ``new_results``.

    ``data`` itself is left untouched. In the copy, ``results`` is replaced by
    ``new_results`` and the counters are recomputed from those records: a
    record counts as skipped when its ``is_skipped`` flag is truthy, and as
    correct when it is not skipped and ``is_correct`` is truthy. The number of
    stopped items is by construction equal to the number of evaluated items.
    ``accuracy_no_skipped`` is a percentage string with two decimals, or
    ``"0.00%"`` when nothing was evaluated.
    """
    kept = [row for row in new_results if not row.get("is_skipped", False)]
    total = len(new_results)
    evaluated = len(kept)
    hits = sum(bool(row.get("is_correct", False)) for row in kept)

    if evaluated > 0:
        accuracy = f"{(hits / evaluated) * 100:.2f}%"
    else:
        accuracy = "0.00%"

    summary = dict(data)
    summary.update({
        "results": new_results,
        "num_items_total": total,
        "num_evaluated": evaluated,
        "num_stopped": evaluated,
        "num_skipped": total - evaluated,
        "accuracy_no_skipped": accuracy,
    })
    return summary

def count_metrics(results: list):
    """Count outcomes in a list of result records and derive two accuracies.

    A record is skipped when its ``is_skipped`` flag is truthy; it is correct
    when it is not skipped and ``is_correct`` is truthy.

    Returns a dict with:
        correct_excl: number of correct, non-skipped records
        eval_cnt: number of non-skipped records
        total_cnt: number of records
        skipped: number of skipped records
        acc_excl: correct / eval_cnt in percent (0.0 when eval_cnt is 0)
        acc_incl: correct / total_cnt in percent (0.0 when total_cnt is 0)
    """
    total = len(results)
    skipped = 0
    correct = 0
    for row in results:
        if row.get("is_skipped", False):
            skipped += 1
        elif row.get("is_correct", False):
            correct += 1
    evaluated = total - skipped

    if evaluated > 0:
        acc_excl = correct / evaluated * 100
    else:
        acc_excl = 0.0
    if total > 0:
        acc_incl = correct / total * 100
    else:
        acc_incl = 0.0

    return {
        "correct_excl": correct,
        "eval_cnt": evaluated,
        "total_cnt": total,
        "skipped": skipped,
        "acc_excl": acc_excl,
        "acc_incl": acc_incl,
    }

# One row per processed file: (file name, front size, back size,
# front/back accuracy excluding skips, front/back accuracy including skips).
processed = []

# Running totals over all files for the front halves (agg_1) and the back
# halves (agg_2).
agg_1 = dict.fromkeys(("correct_excl", "eval_cnt", "total_cnt", "skipped"), 0)
agg_2 = dict.fromkeys(("correct_excl", "eval_cnt", "total_cnt", "skipped"), 0)

for path in sorted(IN_DIR.glob("*.json")):
    data = json.loads(path.read_text(encoding="utf-8"))

    # Split the records at the midpoint; with an odd count the back half
    # receives the extra record.
    results = data.get("results", [])
    n_front = len(results) // 2
    front, back = results[:n_front], results[n_front:]

    # Each half is written under the original file name, with its summary
    # fields recomputed for just that half.
    for out_dir, half in ((OUT1, front), (OUT2, back)):
        half_path = out_dir / path.name
        half_path.write_text(
            json.dumps(recalc_stats(data, half), ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    m1, m2 = count_metrics(front), count_metrics(back)

    processed.append((
        path.name,
        len(front), len(back),
        f"{m1['acc_excl']:.2f}%", f"{m2['acc_excl']:.2f}%",
        f"{m1['acc_incl']:.2f}%", f"{m2['acc_incl']:.2f}%",
    ))

    for agg, metrics in ((agg_1, m1), (agg_2, m2)):
        for k in agg:
            agg[k] += metrics[k]

# Per-file report: front half on the left, back half on the right.
for name, nf, nb, accf_ex, accb_ex, accf_in, accb_in in processed:
    front_part = f"eval_1(front)={nf}개 acc_excl_skip={accf_ex} | acc_incl_skip={accf_in}"
    back_part = f"eval_2(back)={nb}개 acc_excl_skip={accb_ex} | acc_incl_skip={accb_in}"
    print(f"{name}: {front_part}  ||  {back_part}")

def pretty_overall(agg):
    """Turn aggregated counts into ``(acc_excl, acc_incl)`` percentages.

    ``acc_excl`` divides ``correct_excl`` by ``eval_cnt`` and ``acc_incl``
    divides it by ``total_cnt``; either one is 0.0 when its denominator is
    not positive.
    """
    evaluated = agg["eval_cnt"]
    if evaluated > 0:
        acc_excl = agg["correct_excl"] / evaluated * 100
    else:
        acc_excl = 0.0

    total = agg["total_cnt"]
    if total > 0:
        acc_incl = agg["correct_excl"] / total * 100
    else:
        acc_incl = 0.0

    return acc_excl, acc_incl

# Overall accuracies across every processed file, one line per half.
acc1_ex, acc1_in = pretty_overall(agg_1)
acc2_ex, acc2_in = pretty_overall(agg_2)

for label, totals, acc_ex, acc_in in (
    ("eval_1(front): ", agg_1, acc1_ex, acc1_in),
    ("eval_2(back) : ", agg_2, acc2_ex, acc2_in),
):
    print(
        f"{label}"
        f"correct={totals['correct_excl']}, "
        f"eval={totals['eval_cnt']}, total={totals['total_cnt']}, skipped={totals['skipped']} -> "
        f"acc_excl_skip={acc_ex:.2f}% | acc_incl_skip={acc_in:.2f}%"
    )

def peek(dir_name: str, file_name: str):
    """Print the summary fields of one result JSON file.

    Loads ``dir_name/file_name`` and prints the path followed by those of the
    summary keys (model id, counters, accuracy) that are present in the file,
    in a fixed order. The per-item ``results`` list is never printed.
    """
    target = Path(dir_name) / file_name
    payload = json.loads(target.read_text(encoding="utf-8"))

    summary_keys = (
        "model_id",
        "num_items_total",
        "num_evaluated",
        "num_stopped",
        "num_skipped",
        "accuracy_no_skipped",
    )
    head = {}
    for key in summary_keys:
        if key in payload:
            head[key] = payload[key]

    print(target, "->", head)
