In [None]:
import os, re, json, datetime, sys
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

RE_OPT = re.compile(r"^[ \t]*([A-Ea-e])[.)]\s*(.+?)\s*$", re.MULTILINE)

def parse_options(text: str) -> Tuple[List[str], Dict[str, str]]:
    letters, mapping, seen = [], {}, set()
    for m in RE_OPT.finditer(text or ""):
        L = m.group(1).upper()
        opt = m.group(2).strip()
        if L not in seen:
            seen.add(L)
            letters.append(L)
            mapping[L] = opt
    return letters, mapping

def bnb_config(dtype: torch.dtype):
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=dtype,
        bnb_4bit_quant_storage=torch.uint8,
    )


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from datasets import load_from_disk, concatenate_datasets
from dotenv import load_dotenv
from huggingface_hub import login
load_dotenv()
token = os.getenv("HUGGINGFACE_TOKEN")
login(token)
BASE_MODEL_ID = "google/gemma-3-4b-it"
  
SEGMENT = "train"                              
USE_GENERATE = False                        
MAX_ITEMS = None                           

TORCH_DTYPE = torch.bfloat16

USE_4BIT = True

ATTN_IMPL = "eager"

sys.path.append("..")
from train.build_dataset_v2 import concat_datasets 

USE_TEST_DATASETS_LIST = [
    ["../dataset/my_korean", 900], 
    ["../dataset/my_race_middle", 3500], 
    ["../dataset/my_race_high", 10500], 
    ["../dataset/my_cloth", 5100],
]

raw0 = load_from_disk(USE_TEST_DATASETS_LIST[0][0])["train"]
raw1 = load_from_disk(USE_TEST_DATASETS_LIST[1][0])["train"]
raw2 = load_from_disk(USE_TEST_DATASETS_LIST[2][0])["train"]
raw3 = load_from_disk(USE_TEST_DATASETS_LIST[3][0])["train"]

dataset0 = raw0.shuffle(seed=42).select(range(USE_TEST_DATASETS_LIST[0][1]))
dataset1 = raw1.shuffle(seed=42).select(range(USE_TEST_DATASETS_LIST[1][1]))
dataset2 = raw2.shuffle(seed=42).select(range(USE_TEST_DATASETS_LIST[2][1]))
dataset3 = raw3.shuffle(seed=42).select(range(USE_TEST_DATASETS_LIST[3][1]))

dataset = concatenate_datasets([dataset0, dataset1, dataset2, dataset3])
dataset = dataset.shuffle(seed=42)

In [3]:
tok_id = BASE_MODEL_ID
tokenizer = AutoTokenizer.from_pretrained(tok_id, use_fast=True)

qconf = bnb_config(TORCH_DTYPE) if USE_4BIT else None
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=TORCH_DTYPE,
    device_map="auto",
    attn_implementation=ATTN_IMPL,
    quantization_config=qconf,
)

model = base
device = next(model.parameters()).device
model.eval()
print("Loaded on:", device)


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


Loaded on: cuda:0


In [None]:
print(len(dataset), "samples")
print(dataset) 


20000 samples
Dataset({
    features: ['example_id', 'article', 'question', 'options', 'answer'],
    num_rows: 20000
})


In [None]:
import train.build_dataset_v2 as bd
import torch

with torch.no_grad():
    item = dataset[119]
    item = bd.create_conversation(item)
    msgs = item["messages"][:2]  
    prompt = tokenizer.apply_chat_template(
        msgs,
        add_generation_prompt=True,
        tokenize=False
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False 
    )

    gen_text = tokenizer.decode(
        out[0, inputs.input_ids.shape[1]:],
        skip_special_tokens=True
    )

print("=== GENERATED ===")
print(gen_text)
# print(item["messages"][1])
print(item["messages"][2])

=== GENERATED ===
1. Main: discuss food safety incidents
2. Text: KFC's food safety issues
3. A: introduce food safety incidents
4. D: appeal to people for food safety
answer is: D
{'role': 'assistant', 'content': 'D'}


In [None]:
def find_gen_start_for_row(inputs, out_ids, row_idx):
    attn = inputs.attention_mask[row_idx].bool()
    prompt_ids = inputs.input_ids[row_idx][attn]
    seq = out_ids[row_idx]

    gen_start = None
    plen = prompt_ids.size(0)
    limit = seq.size(0) - plen
    for s in range(max(0, limit - 4), limit + 1):
        if torch.equal(seq[s:s+plen], prompt_ids):
            gen_start = s + plen
            break
    if gen_start is None:
        for s in range(0, limit + 1):
            if torch.equal(seq[s:s+plen], prompt_ids):
                gen_start = s + plen
        if gen_start is None:
            gen_start = plen
    return gen_start


In [7]:
print(dataset[0]["answer"])

A


In [None]:
import re, time, copy, math, logging, torch
from tqdm.auto import tqdm

logging.basicConfig(level=logging.INFO, format="%(message)s")

model.eval()
device = next(model.parameters()).device

tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
if getattr(tokenizer, "pad_token_id", None) is None:
    tokenizer.pad_token = tokenizer.eos_token

ANS_PATTERNS = [
    re.compile(r"answer\s*is[:\s]*([A-Z])\b", re.IGNORECASE),
    re.compile(r"\b([A-E])\b") 
]

def extract_letter_from_text(text: str) -> str | None:
    m = ANS_PATTERNS[0].search(text or "")
    if m: return m.group(1).upper()
    cands = ANS_PATTERNS[1].findall(text or "")
    return cands[-1].upper() if cands else None

BATCH_SIZE = 4
MAX_NEW_TOKENS = 256  

def generate_preds(dataset, batch_size, max_new_tokens):
    n = len(dataset)
    pred_list    = [None]  * n
    correct_list = [False] * n
    attempted = correct = unparsable = gold_missing = 0
    gen_time_sum = 0.0
    gen_tokens_sum = 0
    reduced_budget_batches = 0
    zero_budget_samples = 0

    num_batches = math.ceil(n / batch_size)
    bar = tqdm(range(num_batches), desc="Generating", unit="batch")

    for b in bar:
        start = b * batch_size
        end   = min(start + batch_size, n)

        batch_indices, batch_prompts = [], []
        for idx in range(start, end):
            item = dataset[idx]
            item = bd.create_conversation(item)
            if "messages" not in item or len(item["messages"]) < 2:
                continue
            msgs = item["messages"][:2]
            prompt = tokenizer.apply_chat_template(
                msgs, add_generation_prompt=True, tokenize=False
            )
            batch_indices.append(idx)
            batch_prompts.append(prompt)

        if not batch_prompts:
            bar.set_postfix({"acc":"n/a","att":attempted,"ok":correct,
                            "unpars":unparsable,"gold_miss":gold_missing})
            continue

        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=False 
        ).to(device)

        t0 = time.perf_counter()
        with torch.no_grad():
            out_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False
            )
        dt = time.perf_counter() - t0
        gen_time_sum += dt

        for bi, idx in enumerate(batch_indices):
            gen_start = find_gen_start_for_row(inputs, out_ids, bi)
            old_tokens = out_ids[bi, :gen_start]
            new_tokens = out_ids[bi, gen_start:]
            input_text = tokenizer.decode(old_tokens, skip_special_tokens=True)
            gen_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
            gen_tokens_sum += int(new_tokens.numel())

            gold = dataset[idx]["answer"]

            # print("=====in=====")
            # print(input_text)
            # print("=====gen=====")
            # print(gen_text)
            # print(gold)
            
            if gold is None:
                gold_missing += 1

            pred = extract_letter_from_text(gen_text)
            if pred is None:
                unparsable += 1
                attempted += 1
                pred_list[idx] = None
                correct_list[idx] = False
            else:
                attempted += 1
                ok = (gold is not None and pred == gold)
                correct += int(ok)
                pred_list[idx] = gen_text
                correct_list[idx] = ok

        acc = (correct / attempted * 100) if attempted else 0.0
        toks_per_s = (gen_tokens_sum / gen_time_sum) if gen_time_sum > 0 else 0.0
        bar.set_postfix({
            "acc": f"{acc:.2f}%",
            "att": attempted,
            "ok": correct,
            "unpars": unparsable,
            "gold_miss": gold_missing,
            "toks/s": f"{toks_per_s:.1f}",
            "reduced_batches": reduced_budget_batches,
            "zero_budget": zero_budget_samples
        })

    dataset_copy = dataset.add_column("pred", pred_list)
    dataset_copy = dataset_copy.add_column("correct", correct_list)
    return dataset_copy

result = generate_preds(dataset, BATCH_SIZE, MAX_NEW_TOKENS)



Generating:   0%|          | 0/5000 [00:00<?, ?batch/s]

Generating: 100%|██████████| 5000/5000 [6:09:05<00:00,  4.43s/batch, acc=63.90%, att=2e+4, ok=12780, unpars=53, gold_miss=0, toks/s=55.7, reduced_batches=0, zero_budget=0]   


In [9]:
import os, datetime

OUT_DIR = "./generate_output"
os.makedirs(OUT_DIR, exist_ok=True)
base = os.path.join(OUT_DIR, "init_generate")

result.save_to_disk(base) 
print(f"Saved HF Arrow shards to: {base}")

Saving the dataset (1/1 shards): 100%|██████████| 20000/20000 [00:00<00:00, 598758.60 examples/s]

Saved HF Arrow shards to: ./generate_output/init_generate





Rationalization

In [None]:
import train.build_dataset_rationalization as bdr
from datasets import load_from_disk

ds = load_from_disk("./generate_output/init_generate")
print("Loaded:", ds.num_rows, "rows; columns:", ds.column_names)

def is_incorrect(ex):
    v = ex["correct"] if "correct" in ex else False
    if isinstance(v, bool):
        return (v is False)
    if isinstance(v, str):
        return v.strip().lower() in {"false", "0", "no"}
    if isinstance(v, int):
        return v == 0
    return True

incorrect = ds.filter(is_incorrect)
print("Incorrect:", incorrect.num_rows)

for i in range(min(3, incorrect.num_rows)):
    print("---")
    print("pred:", incorrect[i].get("pred"))
    print("correct flag:", incorrect[i].get("correct"))

Loaded: 20000 rows; columns: ['example_id', 'article', 'question', 'options', 'answer', 'pred', 'correct']


Filter: 100%|██████████| 20000/20000 [00:00<00:00, 99553.87 examples/s]

Incorrect: 7220
---
pred: 1. Context: power outage
2. Text: something blocked power
3. A: poles (electricity poles)
answer is: A
correct flag: False
---
pred: 1. Context: growth and becoming more real
2. Best fit: A. cleverer (implies growth in understanding and maturity)
correct flag: False
---
pred: 1. Identify the main point about Descartes' experience.
2. Recognize the role of dreams in Descartes' discovery.
3. Understand that the discovery was significant and influenced his life path.
4. Realize that the dream contributed to an important discovery, not always the case.
answer is: B. Dreams sometimes contribute to important discoveries.
correct flag: False





In [None]:
def generate_preds_rationalization(dataset, batch_size, max_new_tokens):
    n = len(dataset)
    pred_list    = [None]  * n
    correct_list = [False] * n
    attempted = correct = unparsable = gold_missing = 0
    gen_time_sum = 0.0
    gen_tokens_sum = 0
    reduced_budget_batches = 0
    zero_budget_samples = 0

    num_batches = math.ceil(n / batch_size)
    bar = tqdm(range(num_batches), desc="Generating", unit="batch")

    for b in bar:
        start = b * batch_size
        end   = min(start + batch_size, n)

        batch_indices, batch_prompts = [], []
        for idx in range(start, end):
            item = dataset[idx]
            item = bd.create_conversation_for_rational(item)
            if "messages" not in item or len(item["messages"]) < 2:
                continue
            msgs = item["messages"][:2]
            prompt = tokenizer.apply_chat_template(
                msgs, add_generation_prompt=True, tokenize=False
            )
            batch_indices.append(idx)
            batch_prompts.append(prompt)

        if not batch_prompts:
            bar.set_postfix({"acc":"n/a","att":attempted,"ok":correct,
                            "unpars":unparsable,"gold_miss":gold_missing})
            continue

        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=False 
        ).to(device)

        t0 = time.perf_counter()
        with torch.no_grad():
            out_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False
            )
        dt = time.perf_counter() - t0
        gen_time_sum += dt

        for bi, idx in enumerate(batch_indices):
            gen_start = find_gen_start_for_row(inputs, out_ids, bi)
            old_tokens = out_ids[bi, :gen_start]
            new_tokens = out_ids[bi, gen_start:]
            input_text = tokenizer.decode(old_tokens, skip_special_tokens=True)
            gen_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
            gen_tokens_sum += int(new_tokens.numel())

            gold = dataset[idx]["answer"]

            # print("=====in=====")
            # print(input_text)
            # print("=====gen=====")
            # print(gen_text)
            # print(gold)
            
            if gold is None:
                gold_missing += 1

            pred = extract_letter_from_text(gen_text)
            if pred is None:
                unparsable += 1
                attempted += 1
                pred_list[idx] = None
                correct_list[idx] = False
            else:
                attempted += 1
                ok = (gold is not None and pred == gold)
                correct += int(ok)
                pred_list[idx] = gen_text
                correct_list[idx] = ok

        acc = (correct / attempted * 100) if attempted else 0.0
        toks_per_s = (gen_tokens_sum / gen_time_sum) if gen_time_sum > 0 else 0.0
        bar.set_postfix({
            "acc": f"{acc:.2f}%",
            "att": attempted,
            "ok": correct,
            "unpars": unparsable,
            "gold_miss": gold_missing,
            "toks/s": f"{toks_per_s:.1f}",
            "reduced_batches": reduced_budget_batches,
            "zero_budget": zero_budget_samples
        })
        
    dataset_copy = dataset.remove_columns("pred")
    dataset_copy = dataset_copy.remove_columns("correct")
    dataset_copy = dataset_copy.add_column("pred", pred_list)
    dataset_copy = dataset_copy.add_column("correct", correct_list)
    return dataset_copy

In [12]:
rationalization = generate_preds_rationalization(incorrect, BATCH_SIZE, MAX_NEW_TOKENS)

base2 = os.path.join(OUT_DIR, "init_rationalization")
rationalization.save_to_disk(base2) 
print(f"Saved HF Arrow shards to: {base2}")

Generating: 100%|██████████| 1805/1805 [3:27:10<00:00,  6.89s/batch, acc=95.76%, att=7220, ok=6914, unpars=90, gold_miss=0, toks/s=61.1, reduced_batches=0, zero_budget=0]  
Flattening the indices: 100%|██████████| 7220/7220 [00:00<00:00, 35265.30 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 7220/7220 [00:00<00:00, 578889.64 examples/s]

Saved HF Arrow shards to: ./generate_output/init_rationalization





In [None]:
from datasets import load_from_disk, concatenate_datasets, Dataset
ds1 = load_from_disk("./generate_output/init_generate")
ds2 = load_from_disk("./generate_output/init_rationalization")

def drop_incorrect(ds):
    def _filter(dset):
        return dset.filter(
            lambda batch: [bool(x) for x in batch["correct"]],
            batched=True,
        )
    if isinstance(ds, Dataset):
        return _filter(ds)
    else:
        raise TypeError("Unknown dataset type")

ds1_ok = drop_incorrect(ds1)
ds2_ok = drop_incorrect(ds2)

merge = concatenate_datasets([ds1_ok, ds2_ok])


Filter: 100%|██████████| 20000/20000 [00:00<00:00, 1164405.21 examples/s]
Filter:   0%|          | 0/7220 [00:00<?, ? examples/s]

Filter: 100%|██████████| 7220/7220 [00:00<00:00, 950438.61 examples/s]


In [14]:
a = 2
print(merge[a]['pred'])
print(merge[a]['answer'])
base3 = os.path.join(OUT_DIR, "init_merge")
merge.save_to_disk(base3) 
print(f"Saved HF Arrow shards to: {base3}")

1. Sonora Smart Dodd's father, Henry Jackson Smart, inspired the idea for Father's Day.
2. Henry Jackson Smart was kind to his daughter, Sonora Smart Dodd.
3. Options A, C, D are not directly related to the passage.
answer is: B. did a lot for his daughter
B


Saving the dataset (1/1 shards): 100%|██████████| 19694/19694 [00:00<00:00, 141959.40 examples/s]

Saved HF Arrow shards to: ./generate_output/init_merge





In [None]:
from datasets import load_from_disk, DatasetDict
from datasets import ClassLabel

DATA_DIR_IN  = "./generate_output/init_merge"
DATA_DIR_OUT = "./generate_output/loop0_dataset"
SEED = 42

ds = load_from_disk(DATA_DIR_IN)

unique_answers = sorted(set(ds["answer"]))
class_label = ClassLabel(num_classes=len(unique_answers), names=unique_answers)

ds_encoded = ds.class_encode_column("answer")

split1 = ds_encoded.train_test_split(test_size=0.2, seed=SEED, stratify_by_column="answer")

train = split1["train"]
temp  = split1["test"]

split2 = temp.train_test_split(test_size=0.5, seed=SEED, stratify_by_column="answer")

validation = split2["train"]
test       = split2["test"]

dd = DatasetDict({
    "train": train,
    "validation": validation,
    "test": test,
})

print(dd)
dd.save_to_disk(DATA_DIR_OUT)
print(f"Saved to: {DATA_DIR_OUT}")
print("sizes:", {k: v.num_rows for k, v in dd.items()})


Casting to class labels: 100%|██████████| 19694/19694 [00:00<00:00, 340120.24 examples/s]


DatasetDict({
    train: Dataset({
        features: ['example_id', 'article', 'question', 'options', 'answer', 'pred', 'correct'],
        num_rows: 15755
    })
    validation: Dataset({
        features: ['example_id', 'article', 'question', 'options', 'answer', 'pred', 'correct'],
        num_rows: 1969
    })
    test: Dataset({
        features: ['example_id', 'article', 'question', 'options', 'answer', 'pred', 'correct'],
        num_rows: 1970
    })
})


Saving the dataset (1/1 shards): 100%|██████████| 15755/15755 [00:00<00:00, 146711.62 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1969/1969 [00:00<00:00, 126682.89 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1970/1970 [00:00<00:00, 131822.70 examples/s]

Saved to: ./generate_output/loop0_dataset
sizes: {'train': 15755, 'validation': 1969, 'test': 1970}





In [16]:
from datasets import load_from_disk

ds = load_from_disk("./generate_output/loop0_dataset")

In [None]:
mapping = {0: "A", 1: "B", 2: "C", 3: "D", 4: "E"}

def to_letter(batch):
    out = []
    for v in batch["answer"]:
        if isinstance(v, int):
            out.append(mapping.get(v, str(v)))
        elif isinstance(v, str):
            if v.isdigit():
                out.append(mapping.get(int(v), v))
            else:
                out.append(v)
        else:
            out.append(str(v))
    return {"answer_letter": out}

train2 = ds["train"].map(to_letter, batched=True, load_from_cache_file=False)
train2 = train2.remove_columns("answer").rename_column("answer_letter", "answer")

vali2 = ds["validation"].map(to_letter, batched=True, load_from_cache_file=False)
vali2 = vali2.remove_columns("answer").rename_column("answer_letter", "answer")

test2 = ds["test"].map(to_letter, batched=True, load_from_cache_file=False)
test2 = test2.remove_columns("answer").rename_column("answer_letter", "answer")

Map: 100%|██████████| 15755/15755 [00:00<00:00, 368157.13 examples/s]
Map: 100%|██████████| 1969/1969 [00:00<00:00, 274947.05 examples/s]
Map: 100%|██████████| 1970/1970 [00:00<00:00, 277079.20 examples/s]


In [18]:
from datasets import DatasetDict

complete = DatasetDict({
    "train": train2,
    "validation": vali2,
    "test": test2,
})

In [20]:
complete.save_to_disk("./dataset/loop0_dataset")

Saving the dataset (1/1 shards): 100%|██████████| 15755/15755 [00:00<00:00, 595160.45 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1969/1969 [00:00<00:00, 340433.84 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1970/1970 [00:00<00:00, 359924.16 examples/s]
