In [None]:
import os
import json
import time
import string
import random
import shutil
from typing import List, Dict, Tuple, Iterable, Set
from datasets import load_from_disk, concatenate_datasets, DatasetDict
import math
from dotenv import load_dotenv

# Load variables from a .env file into the process environment, then read the
# Gemini API key (None when it is not set anywhere).
load_dotenv()
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")

# System instruction sent with every batch request: reasoning first, then a single
# final line of the form "answer is X".
system_message = "\n".join([
    "You are a helpful assistant.",
    "Provide reasoning only (no mention or hints about options).",
    "At the very end, output exactly: 'answer is X' where X is one of A,B,C,D,E.",
    "The phrase 'answer is' must appear once, in lowercase, and only on the final line.",
    "Do not reveal or imply the answer prior to that final line.",
])

user_prompt = """
[QUESTION]
다음 글에 드러난 Natalie의 심경 변화로 가장 적절한 것은?

[PASSAGE]
My heart _ when I was asked to the back room by the immigration officer.

[OPTIONS]
A. beat
B. ached
C. rose
D. sank

[ANSWER]
The setting implies a sudden negative emotional shift—being called to a secluded area by an authority typically triggers dread or alarm. In English, this feeling is commonly expressed with an idiomatic verb that metaphorically indicates a downward drop in mood. The construction requires a simple past intransitive verb following “My heart …,” matching that idiom and conveying immediate fear or disappointment rather than physical action or neutral elevation.

answer is D

[QUESTION]
Which word best fits in the blank?

[PASSAGE]
My _ , with his very American last name, had no trouble at all.

[OPTIONS]
A. son
B. husband
C. friend
D. daughter

[ANSWER]
The appositive “with his very American last name” marks a male referent and highlights a contrast that explains differential treatment at immigration. The construction “My \_\_\_, with his …, had …” most naturally denotes a close, primary companion in travel and paperwork contexts rather than a child or casual acquaintance, and the tone suggests an adult relationship central to the speaker.

answer is B

[QUESTION]
다음 중 빈칸에 들어가기 가장 적절한 말은?

[PASSAGE]
Suddenly a pair of arms came around me and a small  _  said," Thank you, lady."

[OPTIONS]
A. boy
B. girl
C. voice
D. sound

[ANSWER]
The clause requires a singular noun that can grammatically and semantically serve as the subject of “said.” When the narrator perceives only sound rather than the speaker’s identity, English uses the idiomatic collocation “a small voice said…,” especially to indicate a child speaking softly. The preceding “arms came around me” provides a tactile cue while the subject of the reporting verb is supplied by the audible cue.

answer is C

[QUESTION]
{question}

[PASSAGE]
{passage}

[OPTIONS]
{option_block}

[ANSWER]
""".strip()


def format_options(opts: List[str]) -> str:
    """Render options as newline-separated lettered lines: "A. ...", "B. ...", ..."""
    rendered = []
    for position, option_text in enumerate(opts):
        label = string.ascii_uppercase[position]
        rendered.append(f"{label}. {option_text}")
    return "\n".join(rendered)


def build_request_text(question: str, passage: str, options: List[str]) -> str:
    """Fill the few-shot ``user_prompt`` template with one question, passage and option list."""
    fields = {
        "question": question,
        "passage": passage,
        "option_block": format_options(options),
    }
    return user_prompt.format(**fields)

def concat_raw(datasets_info: List[Tuple[str, int]], segment: str, seed: int = 42):
    """Sample rows of one split from several saved datasets, concatenate them, and shuffle.

    For each ``(data_path, n)`` the split ``segment`` is loaded with ``load_from_disk``,
    shuffled with ``seed``, and its first ``min(n, len(split))`` rows are kept. The
    concatenation of all samples is shuffled once more with ``seed``. The result is
    the same as sampling into an empty seed dataset, as the previous implementation did.

    Args:
        datasets_info: ``(path to a saved DatasetDict, max rows to sample)`` pairs.
        segment: Split name to read from every dataset (e.g. ``"train"``).
        seed: Shuffle seed used for both the per-dataset sampling and the final shuffle.

    Raises:
        ValueError: ``datasets_info`` is empty.
    """
    if not datasets_info:
        raise ValueError("datasets_info must contain at least one (path, n) entry")

    parts = []
    for data_path, n in datasets_info:
        raw = load_from_disk(data_path)[segment]
        parts.append(raw.shuffle(seed=seed).select(range(min(n, len(raw)))))
    # Concatenate once instead of growing the dataset inside the loop.
    return concatenate_datasets(parts).shuffle(seed=seed)


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
BATCH_DIR = "./batch"
BATCH_COMPLETE_DIR = "./batch_complete"

def ensure_dirs():
    """Create the pending and completed batch directories if they are missing."""
    for directory in (BATCH_DIR, BATCH_COMPLETE_DIR):
        os.makedirs(directory, exist_ok=True)

def choose_keys_for_batch(raw_dd: "DatasetDict", segments=("train","validation","test"),
                          send_ratio: float = 0.7, seed: int = 42) -> Set[str]:
    """Pick a reproducible random subset of ``"<split>:<example_id>"`` keys to send.

    For each split present in ``raw_dd`` (missing splits are skipped), the keys are
    shuffled with one ``random.Random(seed)`` shared across splits, and the first
    ``round(len(keys) * send_ratio)`` are kept.

    Args:
        raw_dd: Mapping of split name -> dataset exposing an ``"example_id"`` column.
        segments: Split names to consider, in order. The order affects the RNG stream.
        send_ratio: Fraction of each split to select, within [0, 1].
        seed: Seed for the shuffling RNG.

    Returns:
        Set of selected keys.

    Raises:
        ValueError: ``send_ratio`` is outside [0, 1]. A negative ratio would otherwise
            turn ``keys[:n]`` into "all but the last |n|" and select almost everything.
    """
    if not 0.0 <= send_ratio <= 1.0:
        raise ValueError(f"send_ratio must be within [0, 1], got {send_ratio!r}")

    rng = random.Random(seed)
    include: Set[str] = set()
    for seg in segments:
        if seg not in raw_dd:
            continue
        # Read only the example_id column instead of decoding every full row.
        keys = [f"{seg}:{example_id}" for example_id in raw_dd[seg]["example_id"]]
        rng.shuffle(keys)
        n_send = round(len(keys) * send_ratio)
        include.update(keys[:n_send])
    return include

def _build_request_obj(key: str, row: dict, model_temperature: float,
                       max_output_tokens: int) -> dict:
    """Build one batch JSONL record (``{"key": ..., "request": ...}``) for a dataset row."""
    return {
        "key": key,
        "request": {
            "contents": [{
                "role": "user",
                "parts": [{
                    "text": build_request_text(
                        question=row["question"],
                        passage=row["article"],
                        options=row["options"],
                    )
                }]
            }],
            "generation_config": {
                "temperature": model_temperature,
                "max_output_tokens": max_output_tokens,
            },
            "system_instruction": {
                "role": "system",
                "parts": [{"text": system_message}]
            }
        }
    }


def export_batches_jsonl(
    raw_dd: DatasetDict,
    segments=("train","validation","test"),
    include_keys: Set[str] | None = None,
    model_temperature: float = 0.2,
    max_output_tokens: int = 5000,
    chunk_size: int = 100
) -> List[str]:
    """Write batch request files ``BATCH_DIR/batch_00001.jsonl``, ... and return their paths.

    Every row of each present split in ``segments`` becomes one request keyed
    ``"<split>:<example_id>"``. When ``include_keys`` is given, only rows whose key is
    in it are exported. Requests are split into files of at most ``chunk_size`` lines.
    Existing files with the same names are overwritten. Other files already in
    BATCH_DIR are left untouched.

    Args:
        raw_dd: Splits whose rows provide ``example_id``, ``question``, ``article``, ``options``.
        segments: Split names to export; missing splits are skipped.
        include_keys: Optional allow-list of request keys.
        model_temperature: ``generation_config.temperature`` for every request.
        max_output_tokens: ``generation_config.max_output_tokens`` for every request.
        chunk_size: Maximum number of requests per JSONL file.

    Returns:
        Paths of the written files, in order (empty when no rows were selected).

    Raises:
        ValueError: ``chunk_size`` is not positive. A negative size previously produced
            zero files silently, and zero raised ZeroDivisionError.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    ensure_dirs()

    all_reqs = []
    for segment in segments:
        if segment not in raw_dd:
            continue
        for row in raw_dd[segment]:
            key = f"{segment}:{row['example_id']}"
            if include_keys is not None and key not in include_keys:
                continue
            all_reqs.append(_build_request_obj(key, row, model_temperature, max_output_tokens))

    jsonl_paths: List[str] = []
    for file_no, start in enumerate(range(0, len(all_reqs), chunk_size), start=1):
        part = all_reqs[start:start + chunk_size]
        jsonl_path = os.path.join(BATCH_DIR, f"batch_{file_no:05d}.jsonl")
        with open(jsonl_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(req, ensure_ascii=False) + "\n" for req in part)
        jsonl_paths.append(jsonl_path)

    return jsonl_paths

from google import genai
from google.genai import types

def submit_batch_job(jsonl_path: str, display_name: str = "gemini-explain-batch",
                     model_name: str = "gemini-2.5-flash"):
    """Upload a request JSONL file and start a batch job that reads it.

    Returns:
        ``(client, job_name)``: the client that created the job, for later polling, and
        the server-assigned job name.
    """
    client = genai.Client(api_key=GOOGLE_API_KEY)

    upload_cfg = types.UploadFileConfig(
        display_name=os.path.basename(jsonl_path),
        mime_type="jsonl",
    )
    uploaded_file = client.files.upload(file=jsonl_path, config=upload_cfg)

    batch_job = client.batches.create(
        model=model_name,
        src=uploaded_file.name,
        config={"display_name": display_name},
    )
    return client, batch_job.name

def wait_and_download_results(client: "genai.Client", job_name: str, poll_sec: int = 30) -> bytes:
    terminal = {
        "JOB_STATE_SUCCEEDED",
        "JOB_STATE_FAILED",
        "JOB_STATE_CANCELLED",
        "JOB_STATE_EXPIRED",
    }
    while True:
        job = client.batches.get(name=job_name)
        state = getattr(job.state, "name", str(job.state))
        print(f"[Batch] {job_name} state: {state}")
        if state in terminal:
            break
        time.sleep(poll_sec)

    if state != "JOB_STATE_SUCCEEDED":
        raise RuntimeError(f"Batch job finished with state={state}. error={getattr(job, 'error', None)}")

    if getattr(job, "dest", None) and getattr(job.dest, "file_name", None):
        return client.files.download(file=job.dest.file_name)

    if getattr(job, "dest", None) and getattr(job.dest, "inlined_responses", None):
        lines = []
        for idx, r in enumerate(job.dest.inlined_responses):
            if getattr(r, "response", None):
                try:
                    txt = r.response.text
                except Exception:
                    cand = r.response.candidates[0]
                    parts = cand.content.parts
                    txt = "".join(getattr(p, "text", "") for p in parts)
                lines.append(json.dumps({"key": f"idx:{idx}", "text": txt}, ensure_ascii=False))
            elif getattr(r, "error", None):
                lines.append(json.dumps({"key": f"idx:{idx}", "error": str(r.error)}, ensure_ascii=False))
        return ("\n".join(lines)).encode("utf-8")

def extract_text_from_response_obj(resp_obj: dict) -> str:
    """Return the stripped text of a response dict, or "" when none is present.

    If the dict has a non-empty ``candidates`` list, the text parts of the first
    candidate are concatenated. Otherwise a top-level string ``"text"`` is used.
    """
    candidates = resp_obj.get("candidates") or []
    if not candidates:
        direct = resp_obj.get("text")
        return direct.strip() if isinstance(direct, str) else ""

    content = candidates[0].get("content") or {}
    pieces = [
        part.get("text", "")
        for part in (content.get("parts") or [])
        if isinstance(part, dict) and "text" in part
    ]
    return "".join(pieces).strip()

def _result_key(obj: dict, position: int) -> str:
    """Return a result line's request key (``key`` or ``metadata.key``), else ``idx:<position>``."""
    key = obj.get("key")
    if not key:
        metadata = obj.get("metadata")
        # metadata may be absent or null; only a dict can carry a key.
        if isinstance(metadata, dict):
            key = metadata.get("key")
    return key or f"idx:{position}"


def load_key_to_text_from_results_bytes(result_bytes: bytes) -> Dict[str, str]:
    """Parse batch-result JSONL into a ``request key -> text`` mapping.

    Each non-blank line is a JSON object, handled as follows:
      * ``"response"`` holding a dict -> text extracted from that response;
      * otherwise, an ``"error"`` field -> ``"[ERROR] <error>"``;
      * otherwise -> text extracted from the object itself (e.g. the
        ``{"key", "text"}`` lines synthesized for inline results).
    Every branch resolves the key the same way: ``key``, then ``metadata.key``, then
    ``idx:<n>``, where n is the index of the line among non-blank lines. Previously
    error lines could be stored under a ``None`` key, and the length-based fallback
    could collide.

    Raises:
        UnicodeDecodeError: the payload is not UTF-8.
        json.JSONDecodeError: a non-blank line is not valid JSON.
    """
    key2text: Dict[str, str] = {}
    records = (line for line in result_bytes.decode("utf-8").splitlines() if line.strip())
    for position, line in enumerate(records):
        obj = json.loads(line)
        key = _result_key(obj, position)
        response = obj.get("response")
        if isinstance(response, dict):
            key2text[key] = extract_text_from_response_obj(response)
        elif "error" in obj:
            key2text[key] = f"[ERROR] {obj['error']}"
        else:
            key2text[key] = extract_text_from_response_obj(obj)
    return key2text

def process_batches_sequentially(
    model_name: str = "gemini-2.5-flash",
    display_name_prefix: str = "explain_mcq_batch",
    poll_seconds: int = 60,
    resume: bool = False,
) -> Dict[str, str]:
    """Run every JSONL file in BATCH_DIR as a batch job, one at a time, and collect outputs.

    Files are processed in sorted name order. For each file the steps are:
      1. submit it;
      2. wait for the job to finish;
      3. save the raw result JSONL as ``BATCH_COMPLETE_DIR/<stem>.results.jsonl``;
      4. parse it into ``key -> text``;
      5. move the input file into BATCH_COMPLETE_DIR so a re-run does not resubmit it.
    Saving results before the move means a failure in a later batch no longer
    discards the (already paid for) results of earlier batches.

    Args:
        model_name: Model used for every batch job.
        display_name_prefix: Job display-name prefix, suffixed with a 5-digit index.
        poll_seconds: Seconds between job-status polls.
        resume: If True, first load every ``*.results.jsonl`` saved in
            BATCH_COMPLETE_DIR, so the returned mapping also covers batches finished
            by an earlier (e.g. interrupted) run. Off by default so results from
            unrelated earlier runs are never mixed in.

    Returns:
        Mapping from request key to response text.

    Raises:
        RuntimeError: propagated from ``wait_and_download_results`` when a job does not
            succeed; earlier batches' results are already on disk at that point.
    """
    results_suffix = ".results.jsonl"
    ensure_dirs()
    all_key2text: Dict[str, str] = {}

    if resume:
        saved = sorted(f for f in os.listdir(BATCH_COMPLETE_DIR) if f.endswith(results_suffix))
        for name in saved:
            with open(os.path.join(BATCH_COMPLETE_DIR, name), "rb") as fh:
                all_key2text.update(load_key_to_text_from_results_bytes(fh.read()))
        print(f"==> Resumed {len(all_key2text)} items from {len(saved)} saved result file(s)")

    files = sorted(
        [os.path.join(BATCH_DIR, f) for f in os.listdir(BATCH_DIR) if f.endswith(".jsonl")]
    )

    for idx, jsonl_path in enumerate(files, start=1):
        base = os.path.basename(jsonl_path)
        disp = f"{display_name_prefix}_{idx:05d}"
        print(f"\n==> Submitting {base} as {disp}")
        client, job_name = submit_batch_job(jsonl_path, display_name=disp, model_name=model_name)
        result_bytes = wait_and_download_results(client, job_name, poll_sec=poll_seconds)

        # Persist raw results first so they survive a parse error or a later failure.
        results_path = os.path.join(BATCH_COMPLETE_DIR, os.path.splitext(base)[0] + results_suffix)
        with open(results_path, "wb") as fh:
            fh.write(result_bytes)

        key2text = load_key_to_text_from_results_bytes(result_bytes)
        print(f" -> parsed {len(key2text)} items from {base}")
        all_key2text.update(key2text)

        dest_path = os.path.join(BATCH_COMPLETE_DIR, base)
        shutil.move(jsonl_path, dest_path)
        print(f" -> moved to {dest_path}")

    return all_key2text

def attach_predictions(raw_dd: DatasetDict, key2text: Dict[str, str],
                       segments=("train","validation","test")) -> DatasetDict:
    """Return a copy of the requested splits with a ``pred`` column added.

    ``pred`` is looked up in ``key2text`` by ``"<split>:<example_id>"``. Rows with no
    result get ``""``. Splits not in ``raw_dd`` are skipped.
    """
    def make_mapper(split_name):
        # Bind the split name per call instead of closing over a loop variable.
        def add_pred(example):
            return {"pred": key2text.get(f"{split_name}:{example['example_id']}", "")}
        return add_pred

    per_split = {
        split_name: raw_dd[split_name].map(make_mapper(split_name))
        for split_name in segments
        if split_name in raw_dd
    }
    return DatasetDict(per_split)


In [None]:
# --- Configuration -----------------------------------------------------------
datasets_info = [
    ("../dataset/cloth", 12000),  # (saved DatasetDict dir, max rows sampled per split)
]
segments = ("train",)
send_ratio = 1.0
random_seed = 42

out_dir_raw = "./combined_raw"
out_dir_with_pred = "./combined_with_pred"

model_name = "gemini-2.5-flash-lite"
temperature = 0.2
max_tokens = 5000
poll_seconds = 60
display_name_prefix = "explain_mcq_batch_subset"
chunk_size = 100  # requests per JSONL batch file

# --- Pipeline ----------------------------------------------------------------
print("[0] Concatenating RAW datasets...")
raw_dd = DatasetDict({seg: concat_raw(datasets_info, seg) for seg in segments})
raw_dd.save_to_disk(out_dir_raw)
print(f" -> saved: {out_dir_raw}")

print("[1] Selecting keys to send...")
include_keys = choose_keys_for_batch(raw_dd, segments=segments, send_ratio=send_ratio, seed=random_seed)
print(f" -> selected {len(include_keys)} items")

# Step [3] submits every .jsonl in BATCH_DIR. Leftovers from an earlier run are
# overwritten only when names collide; any extra files would be processed too.
leftover_batches = (
    [f for f in os.listdir(BATCH_DIR) if f.endswith(".jsonl")] if os.path.isdir(BATCH_DIR) else []
)
if leftover_batches:
    print(f" !! warning: {len(leftover_batches)} .jsonl file(s) already present in {BATCH_DIR}")

print(f"[2] Exporting batches to {BATCH_DIR} ({chunk_size} per file)...")
# Named variable instead of `_`, which IPython reserves for the last cell output.
jsonl_paths = export_batches_jsonl(
    raw_dd,
    segments=segments,
    include_keys=include_keys,
    model_temperature=temperature,
    max_output_tokens=max_tokens,
    chunk_size=chunk_size,
)
print(f" -> {len(jsonl_paths)} files written to {BATCH_DIR}")

print("[3] Processing batches sequentially...")
key2text = process_batches_sequentially(
    model_name=model_name,
    display_name_prefix=display_name_prefix,
    poll_seconds=poll_seconds
)
print(f" -> total parsed items: {len(key2text)}")

print("[4] Attaching predictions to RAW dataset copy...")
dd_with_pred = attach_predictions(raw_dd, key2text, segments=segments)
dd_with_pred.save_to_disk(out_dir_with_pred)
print(f" -> saved: {out_dir_with_pred}")

print(f"RAW dataset dir: {out_dir_raw}")
print(f"With predictions dir: {out_dir_with_pred}")
print(f"Batches created in: {BATCH_DIR}")
print(f"Completed batches moved to: {BATCH_COMPLETE_DIR}")

[0] Concatenating RAW datasets...


Saving the dataset (0/1 shards):   0%|          | 0/12000 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 12000/12000 [00:00<00:00, 179525.71 examples/s]

 -> saved: ./combined_raw
[1] Selecting keys to send...





 -> selected 12000 items
[2] Exporting batches to ./batch (100 per file)...
 -> 120 files written to ./batch
[3] Processing batches sequentially...

==> Submitting batch_00001.jsonl as explain_mcq_batch_subset_00001
[Batch] batches/0jtyjg45xk3e2woqzxbo7r3zkq1rps1uxik6 state: JOB_STATE_PENDING
[Batch] batches/0jtyjg45xk3e2woqzxbo7r3zkq1rps1uxik6 state: JOB_STATE_SUCCEEDED
 -> parsed 100 items from batch_00001.jsonl
 -> moved to ./batch_complete/batch_00001.jsonl

==> Submitting batch_00002.jsonl as explain_mcq_batch_subset_00002
[Batch] batches/uc7m2ezdi43uajhmd49t9myhicwximaomwul state: JOB_STATE_PENDING
[Batch] batches/uc7m2ezdi43uajhmd49t9myhicwximaomwul state: JOB_STATE_PENDING
[Batch] batches/uc7m2ezdi43uajhmd49t9myhicwximaomwul state: JOB_STATE_PENDING
[Batch] batches/uc7m2ezdi43uajhmd49t9myhicwximaomwul state: JOB_STATE_PENDING
[Batch] batches/uc7m2ezdi43uajhmd49t9myhicwximaomwul state: JOB_STATE_PENDING
[Batch] batches/uc7m2ezdi43uajhmd49t9myhicwximaomwul state: JOB_STATE_PENDIN



[Batch] batches/fu0ldawonztfxj2mh2doz8pfwc4xobo7mkwx state: BATCH_STATE_RUNNING
[Batch] batches/fu0ldawonztfxj2mh2doz8pfwc4xobo7mkwx state: JOB_STATE_SUCCEEDED
 -> parsed 100 items from batch_00007.jsonl
 -> moved to ./batch_complete/batch_00007.jsonl

==> Submitting batch_00008.jsonl as explain_mcq_batch_subset_00008
[Batch] batches/4a6c8lmi9klt60ycjz61hjt7lifgue007tpl state: JOB_STATE_PENDING
[Batch] batches/4a6c8lmi9klt60ycjz61hjt7lifgue007tpl state: JOB_STATE_PENDING
[Batch] batches/4a6c8lmi9klt60ycjz61hjt7lifgue007tpl state: JOB_STATE_PENDING
[Batch] batches/4a6c8lmi9klt60ycjz61hjt7lifgue007tpl state: JOB_STATE_PENDING
[Batch] batches/4a6c8lmi9klt60ycjz61hjt7lifgue007tpl state: JOB_STATE_PENDING
[Batch] batches/4a6c8lmi9klt60ycjz61hjt7lifgue007tpl state: JOB_STATE_PENDING
[Batch] batches/4a6c8lmi9klt60ycjz61hjt7lifgue007tpl state: JOB_STATE_PENDING
[Batch] batches/4a6c8lmi9klt60ycjz61hjt7lifgue007tpl state: JOB_STATE_PENDING
[Batch] batches/4a6c8lmi9klt60ycjz61hjt7lifgue007tpl sta

Map: 100%|██████████| 12000/12000 [00:00<00:00, 12899.14 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 12000/12000 [00:00<00:00, 608450.67 examples/s]

 -> saved: ./combined_with_pred

Done ✅
RAW dataset dir: ./combined_raw
With predictions dir: ./combined_with_pred
Batches created in: ./batch
Completed batches moved to: ./batch_complete





In [None]:
# Inspect split sizes and columns after the `pred` column was attached.
print(str(dd_with_pred))

DatasetDict({
    train: Dataset({
        features: ['example_id', 'article', 'question', 'options', 'answer', 'example_type', 'pred'],
        num_rows: 12000
    })
})


In [None]:
import re
from typing import Optional, Any
from datasets import Dataset, DatasetDict
import pandas as pd  

# Regexes that locate the model's answer letter. The "answer is" / "answer:" keyword
# is matched case-insensitively via a scoped (?i:...) group. The letter itself must
# be an uppercase A-E, as the system prompt requires, so prose such as "the answer is
# a bit subtle" is not misread as "A". The letter must not be followed by another
# word character ("answer is Definitely" is not "D"). Optional "(" or "*" before the
# letter tolerates "(B)" and markdown bold.
PATTERNS = [
    re.compile(r"(?i:(?:correct\s+)?answer\s*is)\s*:?[\s*(]*([A-E])(?!\w)"),
    re.compile(r"(?i:answer)\s*[:\-][\s*(]*([A-E])(?!\w)"),
]


def extract_pred_letter(text: Any) -> Optional[str]:
    """Return the predicted answer letter (A-E) from a model response, or None.

    All matches of all PATTERNS are considered, and the one whose letter appears LAST
    in the text wins. The prompt puts "answer is X" on the final line, so any earlier
    mention in the reasoning must not take precedence.

    Args:
        text: Model output. Non-strings (e.g. None) yield None.

    Returns:
        The uppercase letter, or None when no pattern matches.
    """
    if not isinstance(text, str):
        return None
    best_pos, best_letter = -1, None
    for pat in PATTERNS:
        for m in pat.finditer(text):
            if m.start(1) > best_pos:
                best_pos, best_letter = m.start(1), m.group(1)
    return best_letter

def _map_fn(example: dict) -> dict:
    """Add ``pred_letter`` (parsed from ``pred``) and ``is_correct`` (vs. ``answer``) to a row.

    ``is_correct`` is True only when both letters are available and equal. The gold
    answer is stripped and upper-cased before comparison.
    """
    letter = extract_pred_letter(example.get("pred", ""))
    raw_answer = example.get("answer", None)
    gold = raw_answer.strip().upper() if isinstance(raw_answer, str) else None
    matched = (letter is not None) and (gold is not None) and (letter == gold)
    return {"pred_letter": letter, "is_correct": matched}



def _count_rows(ds_or_dd) -> int:
    """Number of rows in a Dataset, or the total across all splits of a DatasetDict."""
    if isinstance(ds_or_dd, DatasetDict):
        return sum(ds_or_dd.num_rows.values())
    return len(ds_or_dd)


# attach_predictions returns a DatasetDict, so map/filter below also yield
# DatasetDicts. len() on those counts splits rather than rows, which made the
# summary report nonsense; _count_rows counts rows.
work_ds = dd_with_pred

aug_ds = work_ds.map(_map_fn, desc="Parsing pred letters")

correct_ds = aug_ds.filter(lambda e: e["is_correct"] is True)
incorrect_ds = aug_ds.filter(lambda e: e["is_correct"] is False and e["pred_letter"] is not None)
parse_failed_ds = aug_ds.filter(lambda e: e["pred_letter"] is None)

n_total = _count_rows(aug_ds)
n_parse_failed = _count_rows(parse_failed_ds)
summary = {
    "total": n_total,
    "parsed": n_total - n_parse_failed,
    "parse_failed": n_parse_failed,
    "correct": _count_rows(correct_ds),
    "incorrect": _count_rows(incorrect_ds),
}
# Every parsed row is either correct or incorrect.
assert summary["parsed"] == summary["correct"] + summary["incorrect"], summary
print("Summary:", summary)

Summary: {'total': 1, 'parsed': 0, 'parse_failed': 1, 'correct': 1, 'incorrect': 1}


In [None]:
# Inspect the subset whose parsed letter matches the gold answer.
print(str(correct_ds))

DatasetDict({
    train: Dataset({
        features: ['example_id', 'article', 'question', 'options', 'answer', 'example_type', 'pred', 'pred_letter', 'is_correct'],
        num_rows: 9489
    })
})


In [None]:
from datasets import Dataset, DatasetDict, concatenate_datasets
import os, shutil

obj = correct_ds

# Flatten to a single Dataset: known split names first (in a fixed order), then any
# remaining splits. Previously, extra splits were dropped silently whenever a known
# split existed.
if isinstance(obj, DatasetDict):
    preferred = [k for k in ["train", "validation", "test", "dev", "val"] if k in obj]
    order = preferred + [k for k in obj.keys() if k not in preferred]
    ds_single = concatenate_datasets([obj[k] for k in order])
elif isinstance(obj, Dataset):
    ds_single = obj
else:
    raise TypeError(f"expected Dataset or DatasetDict, got {type(obj).__name__}")

# 80/10/10 split. Note that train_test_split shuffles again with the same seed; the
# extra shuffle is kept so existing saved splits remain reproducible.
seed = 42
shuf = ds_single.shuffle(seed=seed)
splits_80_20 = shuf.train_test_split(test_size=0.2, seed=seed)
train_ds = splits_80_20["train"]
temp_ds = splits_80_20["test"]

val_test = temp_ds.train_test_split(test_size=0.5, seed=seed)
val_ds = val_test["train"]
test_ds = val_test["test"]

cloth_with_pred = DatasetDict({
    "train": train_ds,
    "validation": val_ds,
    "test": test_ds,
})
assert len(train_ds) + len(val_ds) + len(test_ds) == len(ds_single)

save_path = "./cloth_with_pred"
if os.path.exists(save_path):
    shutil.rmtree(save_path)
cloth_with_pred.save_to_disk(save_path)

print({
    "total": len(ds_single),
    "train": len(cloth_with_pred["train"]),
    "validation": len(cloth_with_pred["validation"]),
    "test": len(cloth_with_pred["test"]),
})
print(f"Saved to: {save_path}")

Saving the dataset (0/1 shards):   0%|          | 0/7591 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 7591/7591 [00:00<00:00, 113741.04 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 949/949 [00:00<00:00, 63162.82 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 949/949 [00:00<00:00, 86488.95 examples/s]

{'total': 9489, 'train': 7591, 'validation': 949, 'test': 949}
✅ Saved to: ./cloth_with_pred



