In [1]:
# Cell: Build tokenized train/val/test JSONL from "data" folder (robust to stray chars/separators)

from pathlib import Path
import json, random, re
from typing import List, Tuple, Dict, Iterable

# ----------------------------
# Config
# ----------------------------
# Input / output locations and split settings.
DATA_DIR = Path("data")          # root folder; each sub-folder holds one level pair
OUT_DIR = Path("processed")      # destination of the JSONL splits and vocab.json
SEED = 42                        # seed of the deterministic shuffle before splitting
SPLITS = (0.80, 0.10, 0.10)      # train / val / test fractions
PAD_TO_RECTANGLE = True          # right-pad short rows with BACKGROUND

# Cleaning/sanitization behavior.
STRICT = False                        # True -> raise on an unknown character; False -> apply UNKNOWN_POLICY
UNKNOWN_POLICY = "map_to_background"  # "map_to_background" | "drop"
# NOTE(review): "#" is also a VOCAB tile below, so a row whose first tile is
# "#" is treated as a whole-line comment and skipped — confirm this is intended.
COMMENT_PREFIXES = ("#", "//", ";")   # whole-line comments to skip
SKIP_SEPARATOR_LINES = True           # skip rows made of one repeated non-vocab char (----, =====, ...)
MIN_SEP_RUN = 5                       # minimum row length for the separator test

# ----------------------------
# Token vocabulary (from our project context)
# ----------------------------
# NOTE(review): "-" is not part of VOCAB, so rows consisting only of "-" are
# dropped as separator lines and any other "-" is mapped to BACKGROUND. The
# recorded run reports ~9.9M unknown characters and ~54k skipped separator
# lines; presumably "-" is a real tile in this data set — confirm.
VOCAB = list("MFyYEgGkKrX#%|*Bb?@Q!12DSCULotT<>[]")
BACKGROUND = '|'
tok2id = dict(zip(VOCAB, range(len(VOCAB))))
id2tok = {idx: tok for tok, idx in tok2id.items()}
VOCAB_SET = set(VOCAB)

# ----------------------------
# Helpers
# ----------------------------
def read_level_txt(p: Path) -> List[str]:
    """Load a level text file as a list of rows.

    The file is decoded as UTF-8 (undecodable bytes become U+FFFD), leading
    BOM characters are removed, the text is split into lines without their
    terminators, and rows that are exactly empty are trimmed from the end.
    Blank rows in the middle of the file are kept.
    """
    text = p.read_text(encoding="utf-8", errors="replace").lstrip("\ufeff")
    rows = text.splitlines()  # splitlines() already removes every \n / \r
    end = len(rows)
    while end > 0 and rows[end - 1] == "":
        end -= 1
    return rows[:end]

# Matches a line that is one character repeated two or more times.
SEP_LINE_RE = re.compile(r"^(.)\1+$")

def is_separator_line(line: str) -> bool:
    """Tell whether ``line`` is a decorative separator such as ``-----``.

    A line qualifies only when SKIP_SEPARATOR_LINES is enabled, it is at
    least MIN_SEP_RUN characters long, it consists of a single repeated
    character, and that character is not a VOCAB tile (so a row made of one
    valid tile is never mistaken for a separator).
    """
    if not SKIP_SEPARATOR_LINES or len(line) < MIN_SEP_RUN:
        return False
    hit = SEP_LINE_RE.match(line)
    return hit is not None and hit.group(1) not in VOCAB_SET

def sanitize_lines(lines: List[str], stats: Dict[str,int]) -> List[str]:
    """Filter and clean raw level rows, recording what was changed in ``stats``.

    Rows that are blank after stripping are skipped silently. Rows whose
    stripped text starts with one of COMMENT_PREFIXES, or that
    ``is_separator_line`` accepts, are skipped and counted. Each remaining
    row is cleaned character by character on its *unstripped* text: VOCAB
    characters are kept; any other character raises ValueError when STRICT
    is set, is removed when UNKNOWN_POLICY is "drop", and is replaced by
    BACKGROUND for every other policy value. Rows left empty by cleaning
    are dropped and counted.

    ``stats`` is mutated in place and must already hold the counters
    "skipped_comment_lines", "skipped_separator_lines", "unknown_chars",
    "dropped_chars" and "dropped_empty_after_clean".

    NOTE(review): "#" is listed in both COMMENT_PREFIXES and VOCAB, so a row
    whose first tile is "#" is skipped as a comment — confirm this is intended.
    """
    kept_rows = []
    for raw_row in lines:
        body = raw_row.strip()
        if body == "":
            continue
        if body.startswith(tuple(COMMENT_PREFIXES)):
            stats["skipped_comment_lines"] += 1
            continue
        if is_separator_line(body):
            stats["skipped_separator_lines"] += 1
            continue

        out_chars = []
        for ch in raw_row:
            if ch in VOCAB_SET:
                out_chars.append(ch)
                continue
            if STRICT:
                raise ValueError(f"Unknown token '{ch}' in line: {raw_row}")
            # "unknown_chars" counts every unknown character, dropped or mapped.
            stats["unknown_chars"] += 1
            if UNKNOWN_POLICY == "drop":
                stats["dropped_chars"] += 1
            else:
                # "map_to_background" and any unrecognised policy both map.
                out_chars.append(BACKGROUND)

        if not out_chars:
            # Every character of the row was dropped.
            stats["dropped_empty_after_clean"] += 1
            continue
        kept_rows.append("".join(out_chars))
    return kept_rows

def normalize_rectangular(lines: List[str]) -> Tuple[List[str], int, int]:
    """Return ``(rows, W, H)`` where W is the longest row length and H the row count.

    When PAD_TO_RECTANGLE is enabled and W is positive, shorter rows are
    right-padded with BACKGROUND up to W; otherwise the rows are returned
    unchanged. An empty input yields ``([], 0, 0)``.
    """
    if not lines:
        return [], 0, 0
    H = len(lines)
    W = max(map(len, lines))
    if not (PAD_TO_RECTANGLE and W > 0):
        return lines, W, H
    padded = [row + BACKGROUND * (W - len(row)) for row in lines]
    return padded, W, H

def tokenize_level(lines: List[str]) -> List[int]:
    """Convert grid rows to a flat, row-major list of token ids.

    Each character is looked up in ``tok2id``; a character outside the
    vocabulary raises KeyError, so rows must be sanitized beforehand.
    """
    return [tok2id[ch] for row in lines for ch in row]

def collect_pairs(data_dir: Path):
    """Yield one tokenized record per level folder found under ``data_dir``.

    Sub-folders are visited in sorted order. A folder is used only when it
    contains both ``corrupted.txt`` and ``original.txt``; each file is read,
    sanitized, normalized and tokenized. Folders where either side has no
    rows left after sanitization are skipped.

    Each yielded dict holds "level_id" (the folder name), the flattened
    "corrupted_ids" / "original_ids", and the width/height of each side.

    The generator's *return value* is the dict of counters accumulated
    during the scan. It is carried on ``StopIteration.value`` (or delivered
    by ``yield from``); consuming the generator with ``list()`` discards it.
    """
    counter_names = (
        "folders_seen",
        "folders_kept",
        "skipped_missing_files",
        "skipped_comment_lines",
        "skipped_separator_lines",
        "unknown_chars",
        "dropped_chars",
        "dropped_empty_after_clean",
        "empty_after_sanitize_pairs",
    )
    stats = dict.fromkeys(counter_names, 0)

    level_dirs = sorted(entry for entry in data_dir.iterdir() if entry.is_dir())
    for level_dir in level_dirs:
        stats["folders_seen"] += 1
        corrupted_path = level_dir / "corrupted.txt"
        original_path = level_dir / "original.txt"
        if not corrupted_path.exists() or not original_path.exists():
            stats["skipped_missing_files"] += 1
            continue

        # Read both files first, then sanitize both (corrupted before original),
        # so the shared counters are updated in a fixed order.
        raw_corrupted = read_level_txt(corrupted_path)
        raw_original = read_level_txt(original_path)
        clean_corrupted = sanitize_lines(raw_corrupted, stats)
        clean_original = sanitize_lines(raw_original, stats)

        # A side with no surviving rows makes the pair unusable.
        if not (clean_corrupted and clean_original):
            stats["empty_after_sanitize_pairs"] += 1
            continue

        grid_corrupted, cW, cH = normalize_rectangular(clean_corrupted)
        grid_original, oW, oH = normalize_rectangular(clean_original)

        # Sanitization left only vocabulary characters, so lookup cannot fail.
        record = {
            "level_id": level_dir.name,
            "corrupted_ids": tokenize_level(grid_corrupted),
            "original_ids": tokenize_level(grid_original),
            "width_corrupted": cW,
            "height_corrupted": cH,
            "width_original": oW,
            "height_original": oH,
            # For debugging: add these if you want the cleaned text persisted
            # "corrupted_text": "\n".join(grid_corrupted),
            # "original_text":  "\n".join(grid_original),
        }
        stats["folders_kept"] += 1
        yield record

    return stats

# ----------------------------
# Main
# ----------------------------
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Collect every usable pair. list() discards the generator's return value,
# so the cleaning counters are recomputed separately for the report below.
pairs_iter = collect_pairs(DATA_DIR)
pairs = list(pairs_iter)  # exhaust generator

total = len(pairs)
if total == 0:
    raise RuntimeError(f"No valid level pairs found under {DATA_DIR.resolve()} after sanitization. "
                       f"Try setting STRICT=False and UNKNOWN_POLICY='map_to_background'.")

# Deterministic shuffle + split
random.Random(SEED).shuffle(pairs)
train_ratio, val_ratio, test_ratio = SPLITS
# Validate with an explicit exception: an `assert` would be stripped under
# `python -O` and let a bad SPLITS tuple through silently.
if abs((train_ratio + val_ratio + test_ratio) - 1.0) >= 1e-9:
    raise ValueError("SPLITS must sum to 1.0")

# int() truncates, so train/val are rounded down and the test split absorbs
# the remainder.
n_train = int(total * train_ratio)
n_val   = int(total * val_ratio)
n_test  = total - n_train - n_val

train_set = pairs[:n_train]
val_set   = pairs[n_train:n_train+n_val]
test_set  = pairs[n_train+n_val:]

def write_jsonl(path: Path, rows: Iterable[Dict]):
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines (one object per line).

    The file is created or overwritten; non-ASCII characters are written
    verbatim rather than escaped.
    """
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in rows)

# Persist the three splits, one JSONL file each.
write_jsonl(OUT_DIR / "train.jsonl", train_set)
write_jsonl(OUT_DIR / "val.jsonl",   val_set)
write_jsonl(OUT_DIR / "test.jsonl",  test_set)

# Store the vocabulary next to the data so later cells can map ids back to tiles.
with (OUT_DIR / "vocab.json").open("w", encoding="utf-8") as f:
    json.dump({"vocab": VOCAB, "tok2id": tok2id}, f, ensure_ascii=False, indent=2)

# ----------------------------
# Summary (with cleaning report)
# ----------------------------
def avg_dims(rows, key_w, key_h):
    """Return the mean of ``key_w`` and of ``key_h`` over ``rows``, rounded to 2 places.

    A record missing a key contributes 0 to that mean. An empty ``rows``
    yields ``(0.0, 0.0)``.
    """
    if not rows:
        return (0.0, 0.0)

    def _mean(key):
        return sum(record.get(key, 0) for record in rows) / len(rows)

    return (round(_mean(key_w), 2), round(_mean(key_h), 2))

# Average grid dimensions of the training split, per side.
cW_tr, cH_tr = avg_dims(train_set, "width_corrupted", "height_corrupted")
oW_tr, oH_tr = avg_dims(train_set, "width_original", "height_original")

# `total` counts the pairs that survived sanitization, not every folder visited.
print(f"✓ Processed {total} level folders from: {DATA_DIR.resolve()}")
print(f"Split -> Train: {len(train_set)} | Val: {len(val_set)} | Test: {len(test_set)}  (seed={SEED})")
print(f"Outputs -> {OUT_DIR / 'train.jsonl'}, {OUT_DIR / 'val.jsonl'}, {OUT_DIR / 'test.jsonl'}")
print(f"Saved   -> {OUT_DIR / 'vocab.json'}")
print(f"Train avg dims (corrupted): {cW_tr}x{cH_tr} | (original): {oW_tr}x{oH_tr}")
print("--- Cleaning report ---")
print(f"STRICT={STRICT}, UNKNOWN_POLICY='{UNKNOWN_POLICY}'")
print("Note: counts reflect both corrupted/original files across all folders.")
# `list(collect_pairs(...))` above discards the generator's return value (the `stats` dict is only
# reachable through `StopIteration.value` or `yield from`), so re-scan the data to rebuild the counters.
# Light-weight rescan that only counts issues:
def quick_scan_unknowns(base: Path) -> Dict[str,int]:
    """Re-run sanitization over every level pair under ``base`` and total the counters.

    Folders lacking ``corrupted.txt`` or ``original.txt`` are counted as
    skipped. For the others, both files are read and sanitized with a fresh
    per-file counter dict whose values are then added to the totals. Only
    the counters are returned; the cleaned rows are discarded.
    """
    counter_names = (
        "folders_seen", "skipped_missing_files", "skipped_comment_lines",
        "skipped_separator_lines", "unknown_chars", "dropped_chars",
        "dropped_empty_after_clean",
    )
    totals = dict.fromkeys(counter_names, 0)
    for level_dir in sorted(entry for entry in base.iterdir() if entry.is_dir()):
        totals["folders_seen"] += 1
        pair = (level_dir / "corrupted.txt", level_dir / "original.txt")
        if not all(txt.exists() for txt in pair):
            totals["skipped_missing_files"] += 1
            continue
        for txt in pair:
            # simulate sanitize with zeroed counters, then merge them in
            per_file = dict.fromkeys(totals, 0)
            sanitize_lines(read_level_txt(txt), per_file)
            for name, value in per_file.items():
                if name in totals:
                    totals[name] += value
    return totals

# Second pass over the data, solely to obtain the cleaning counters.
scan = quick_scan_unknowns(DATA_DIR)
for k,v in scan.items():
    print(f"{k}: {v}")
# Reminder that characters outside VOCAB were mapped or dropped, not kept.
print("If '-' or other stray characters are meaningful tiles for you, tell me and I'll add them to VOCAB.")

✓ Processed 3843 level folders from: C:\Users\xhepon\Documents\a-No.2\rp\explore\Mario-AI-Framework\ship_to_train\data
Split -> Train: 3074 | Val: 384 | Test: 385  (seed=42)
Outputs -> processed\train.jsonl, processed\val.jsonl, processed\test.jsonl
Saved   -> processed\vocab.json
Train avg dims (corrupted): 199.99x9.71 | (original): 199.99x8.08
--- Cleaning report ---
STRICT=False, UNKNOWN_POLICY='map_to_background'
Note: counts reflect both corrupted/original files across all folders.
folders_seen: 3843
skipped_missing_files: 0
skipped_comment_lines: 0
skipped_separator_lines: 54759
unknown_chars: 9928048
dropped_chars: 0
dropped_empty_after_clean: 0
If '-' or other stray characters are meaningful tiles for you, tell me and I'll add them to VOCAB.


In [None]:
# # Cell: Fine-tune a small Transformer to repair levels (corrupted_ids -> original_ids)

# import json, math, random
# from pathlib import Path
# from typing import List, Dict, Tuple

# import torch
# import torch.nn as nn
# from torch.utils.data import Dataset, DataLoader

# # ----------------------------
# # Paths & config
# # ----------------------------
# DATA_DIR = Path("processed")
# TRAIN_PATH = DATA_DIR / "train.jsonl"
# VAL_PATH   = DATA_DIR / "val.jsonl"
# VOCAB_PATH = DATA_DIR / "vocab.json"

# SAVE_DIR   = Path("checkpoints")
# SAVE_DIR.mkdir(parents=True, exist_ok=True)
# CKPT_PATH  = SAVE_DIR / "level_repair_transformer.pt"
# TOKCONF    = SAVE_DIR / "token_config.json"

# SEED = 42
# random.seed(SEED)
# torch.manual_seed(SEED)

# # Model/training hyperparams (tweak as needed)
# BATCH_SIZE = 32
# EMBED_DIM  = 256
# FF_DIM     = 512
# N_HEADS    = 8
# N_LAYERS   = 4
# DROPOUT    = 0.1
# LR         = 3e-4
# EPOCHS     = 5
# MAX_LEN    = 4096          # maximum allowed sequence length after flattening

# # ----------------------------
# # Load vocab and define PAD token
# # ----------------------------
# with VOCAB_PATH.open("r", encoding="utf-8") as f:
#     vocab_file = json.load(f)
# VOCAB = vocab_file["vocab"]                   # list of string tokens
# tok2id = vocab_file["tok2id"]                 # mapping str -> int

# # add a dedicated PAD token id for sequence padding (not present in the training data)
# PAD_ID = len(VOCAB)
# NUM_TOKENS = len(VOCAB) + 1                   # +1 for PAD

# # Save token config used by the model
# with TOKCONF.open("w", encoding="utf-8") as f:
#     json.dump({"PAD_ID": PAD_ID,
#                "NUM_TOKENS": NUM_TOKENS,
#                "VOCAB": VOCAB}, f, indent=2)

# # ----------------------------
# # Dataset
# # ----------------------------
# def read_jsonl(path: Path) -> List[Dict]:
#     rows = []
#     with path.open("r", encoding="utf-8") as f:
#         for ln in f:
#             if ln.strip():
#                 rows.append(json.loads(ln))
#     return rows

# class LevelPairs(Dataset):
#     def __init__(self, jsonl_path: Path, max_len: int = MAX_LEN):
#         self.rows = read_jsonl(jsonl_path)
#         self.max_len = max_len

#         # Optionally filter overly long samples (Transformer memory safeguard)
#         kept = []
#         for r in self.rows:
#             if len(r["corrupted_ids"]) <= self.max_len and len(r["original_ids"]) <= self.max_len:
#                 kept.append(r)
#         dropped = len(self.rows) - len(kept)
#         self.rows = kept
#         if dropped:
#             print(f"[{jsonl_path.name}] Dropped {dropped} samples exceeding MAX_LEN={self.max_len}")

#     def __len__(self):
#         return len(self.rows)

#     def __getitem__(self, idx):
#         r = self.rows[idx]
#         src = r["corrupted_ids"]
#         tgt = r["original_ids"]
#         # We feed tgt_in (shifted right, BOS-free) and compute loss on tgt_out (shifted left)
#         # Here we skip using explicit BOS/EOS; we learn to reproduce full sequence.
#         return torch.tensor(src, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)

# def collate_fn(batch: List[Tuple[torch.Tensor, torch.Tensor]]):
#     # Pad to max length in batch with PAD_ID
#     src_seqs, tgt_seqs = zip(*batch)
#     max_src = max(s.size(0) for s in src_seqs)
#     max_tgt = max(t.size(0) for t in tgt_seqs)

#     padded_src = torch.full((len(batch), max_src), PAD_ID, dtype=torch.long)
#     padded_tin = torch.full((len(batch), max_tgt), PAD_ID, dtype=torch.long)  # decoder input
#     padded_tout= torch.full((len(batch), max_tgt), PAD_ID, dtype=torch.long)  # supervised target

#     for i, (s, t) in enumerate(zip(src_seqs, tgt_seqs)):
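#         # NOTE: tin and tout are filled identically here (no 1-position shift is applied), so under the
#         # causal mask position i already sees t_i; shift tin right by one (e.g. behind a BOS token) for real teacher forcing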
#         padded_src[i, :s.size(0)] = s
#         padded_tin[i, :t.size(0)] = t
#         padded_tout[i,:t.size(0)] = t

#     src_key_padding_mask = (padded_src == PAD_ID)  # (B, S)
#     tgt_key_padding_mask = (padded_tin == PAD_ID)  # (B, T)

#     return {
#         "src": padded_src,
#         "tin": padded_tin,
#         "tout": padded_tout,
#         "src_mask": src_key_padding_mask,
#         "tgt_mask": tgt_key_padding_mask,
#     }

# train_ds = LevelPairs(TRAIN_PATH)
# val_ds   = LevelPairs(VAL_PATH)

# train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_fn, num_workers=0)
# val_dl   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=0)

# print(f"Loaded: train={len(train_ds)} | val={len(val_ds)} | vocab={len(VOCAB)} (+PAD)")

# # ----------------------------
# # Model
# # ----------------------------
# class PositionalEncoding(nn.Module):
#     def __init__(self, d_model: int, max_len: int = 65536):
#         super().__init__()
#         pe = torch.zeros(max_len, d_model)
#         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
#         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
#         pe[:, 0::2] = torch.sin(position * div_term)
#         pe[:, 1::2] = torch.cos(position * div_term)
#         self.register_buffer('pe', pe.unsqueeze(0))  # (1, L, D)

#     def forward(self, x):
#         # x: (B, L, D)
#         L = x.size(1)
#         return x + self.pe[:, :L, :]

# class RepairTransformer(nn.Module):
#     def __init__(self, vocab_size: int, pad_id: int, d_model: int, ff_dim: int,
#                  n_heads: int, n_layers: int, dropout: float):
#         super().__init__()
#         self.pad_id = pad_id
#         self.d_model = d_model

#         self.src_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
#         self.tgt_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
#         self.pos_enc = PositionalEncoding(d_model)

#         encoder_layer = nn.TransformerEncoderLayer(d_model, n_heads, ff_dim, dropout, batch_first=True)
#         decoder_layer = nn.TransformerDecoderLayer(d_model, n_heads, ff_dim, dropout, batch_first=True)
#         self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
#         self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)

#         self.proj = nn.Linear(d_model, vocab_size)

#     def make_square_subsequent_mask(self, size: int, device):
#         # causal mask for decoder (T, T)
#         return torch.triu(torch.full((size, size), float('-inf'), device=device), diagonal=1)

#     def forward(self, src, tin, src_pad_mask, tgt_pad_mask):
#         # src, tin: (B, S/T)
#         src_emb = self.pos_enc(self.src_emb(src))
#         tgt_emb = self.pos_enc(self.tgt_emb(tin))

#         memory = self.encoder(src_emb, src_key_padding_mask=src_pad_mask)  # (B, S, D)

#         # causal mask for decoder self-attn
#         T = tin.size(1)
#         causal = self.make_square_subsequent_mask(T, tin.device)
#         out = self.decoder(
#             tgt_emb, memory,
#             tgt_mask=causal,
#             tgt_key_padding_mask=tgt_pad_mask,
#             memory_key_padding_mask=src_pad_mask
#         )  # (B, T, D)
#         logits = self.proj(out)  # (B, T, V)
#         return logits

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = RepairTransformer(
#     vocab_size=NUM_TOKENS, pad_id=PAD_ID,
#     d_model=EMBED_DIM, ff_dim=FF_DIM,
#     n_heads=N_HEADS, n_layers=N_LAYERS, dropout=DROPOUT
# ).to(device)

# optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

# # ----------------------------
# # Train / eval loops
# # ----------------------------
# def run_epoch(loader, train_mode=True):
#     model.train(train_mode)
#     total_loss = 0.0
#     total_tokens = 0
#     for batch in loader:
#         src  = batch["src"].to(device)
#         tin  = batch["tin"].to(device)
#         tout = batch["tout"].to(device)
#         src_mask = batch["src_mask"].to(device)
#         tgt_mask = batch["tgt_mask"].to(device)

#         if train_mode:
#             optimizer.zero_grad()

#         logits = model(src, tin, src_mask, tgt_mask)  # (B, T, V)
#         # Compute loss over all positions
#         loss = criterion(logits.reshape(-1, logits.size(-1)), tout.reshape(-1))

#         if train_mode:
#             loss.backward()
#             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#             optimizer.step()

#         # token count excludes pads in target
#         ntoks = (tout != PAD_ID).sum().item()
#         total_loss += loss.item() * ntoks
#         total_tokens += ntoks

#     return total_loss / max(1, total_tokens)

# best_val = float("inf")
# for epoch in range(1, EPOCHS + 1):
#     tr_loss = run_epoch(train_dl, train_mode=True)
#     vl_loss = run_epoch(val_dl, train_mode=False)
#     print(f"Epoch {epoch:02d} | train xent: {tr_loss:.4f} | val xent: {vl_loss:.4f} | ppl: {math.exp(min(vl_loss, 20)):.2f}")

#     if vl_loss < best_val:
#         best_val = vl_loss
#         torch.save({
#             "model_state_dict": model.state_dict(),
#             "config": {
#                 "EMBED_DIM": EMBED_DIM, "FF_DIM": FF_DIM,
#                 "N_HEADS": N_HEADS, "N_LAYERS": N_LAYERS,
#                 "DROPOUT": DROPOUT, "PAD_ID": PAD_ID,
#                 "NUM_TOKENS": NUM_TOKENS
#             }
#         }, CKPT_PATH)
#         print(f"  ✓ Saved checkpoint -> {CKPT_PATH}")

# # ----------------------------
# # Greedy decode on a few val samples
# # ----------------------------
# id2tok = {int(v): k for k, v in tok2id.items()}
# id2tok[PAD_ID] = "<PAD>"

# def greedy_decode(src_ids: List[int], max_len: int = 4096) -> List[int]:
#     model.eval()
#     with torch.no_grad():
#         src = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)
#         src_mask = (src == PAD_ID)
#         memory = model.encoder(model.pos_enc(model.src_emb(src)), src_key_padding_mask=src_mask)

#         # initialize decoder input with zeros — here we’ll just start from a copy prompt
#         # (you can introduce BOS/EOS later if you want exact length control)
#         out = torch.full((1, 1), PAD_ID, dtype=torch.long, device=device)

#         for _ in range(min(len(src_ids), max_len)):
#             tgt_mask = (out == PAD_ID)
#             causal = model.make_square_subsequent_mask(out.size(1), device)
#             dec = model.decoder(
#                 model.pos_enc(model.tgt_emb(out)),
#                 memory, tgt_mask=causal,
#                 tgt_key_padding_mask=tgt_mask,
#                 memory_key_padding_mask=src_mask
#             )
#             logits = model.proj(dec[:, -1, :])  # (1, V)
#             next_id = logits.argmax(-1).unsqueeze(1)  # (1,1)
#             out = torch.cat([out, next_id], dim=1)
#         # drop the first PAD seed token
#         return out.squeeze(0).tolist()[1:len(src_ids)+1]

# def ids_to_grid(ids: List[int], width: int, height: int) -> List[str]:
#     # reconstruct rows using your BACKGROUND padding policy; assumes rectangular dims were saved earlier
#     chars = []
#     for i in ids:
#         if i == PAD_ID:
#             ch = "|"  # visualize PAD as background, you can pick something else
#         else:
#             ch = VOCAB[i]
#         chars.append(ch)
#     # clamp if lengths mismatch
#     total = width * height
#     chars = chars[:total] + (["|"] * max(0, total - len(chars)))
#     return ["".join(chars[r*width:(r+1)*width]) for r in range(height)]

# # Show a few qualitative examples from val set
# val_rows = read_jsonl(VAL_PATH)
# print("\n=== Qualitative samples (val) ===")
# for r in random.sample(val_rows, k=min(3, len(val_rows))):
#     src = r["corrupted_ids"]
#     tgt = r["original_ids"]
#     w   = r.get("width_original", 0) or r.get("width_corrupted", 0)
#     h   = r.get("height_original", 0) or r.get("height_corrupted", 0)

#     pred = greedy_decode(src, max_len=len(src))
#     # decode to grid strings
#     tgt_grid  = ids_to_grid(tgt, w, h) if (w and h) else ["(no dims recorded)"]
#     pred_grid = ids_to_grid(pred, w, h) if (w and h) else ["(no dims recorded)"]

#     print(f"\nLevel: {r.get('level_id','?')}")
#     print("Corrupted (first 120 ids):", src[:120], "...")
#     print("Target    (first 120 ids):", tgt[:120], "...")
#     print("Pred      (first 120 ids):", pred[:120], "...")
#     if w and h:
#         print(f"Dims: {w}x{h}")
#         print("Target grid preview:")
#         for line in tgt_grid[:min(6, len(tgt_grid))]:
#             print(" ", line[:min(120, len(line))])
#         print("Pred grid preview:")
#         for line in pred_grid[:min(6, len(pred_grid))]:
#             print(" ", line[:min(120, len(line))])

# print(f"\nDone. Best checkpoint at: {CKPT_PATH}")

[train.jsonl] Dropped 1 samples exceeding MAX_LEN=4096
Loaded: train=3073 | val=384 | vocab=35 (+PAD)


In [None]:
# Cell: Fine-tune a Hugging Face causal LM (Qwen / gpt-oss / GPT-2) for level repair
# !pip install -q transformers accelerate datasets

import json, random, math
from pathlib import Path
from typing import Dict, List
from dataclasses import dataclass

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    Trainer, TrainingArguments
)

# -----------------------------
# Config
# -----------------------------
# NOTE: this cell rebinds DATA_DIR, SEED and VOCAB, which the preprocessing
# cell also defines; here DATA_DIR points at that cell's *output* folder.
DATA_DIR   = Path("processed")
TRAIN_PATH = DATA_DIR / "train.jsonl"
VAL_PATH   = DATA_DIR / "val.jsonl"
VOCAB_PATH = DATA_DIR / "vocab.json"

# Choose a small-ish causal LM; you can swap:
#   - "Qwen/Qwen2.5-0.5B" (good quality; needs GPU)
#   - "Qwen/Qwen2-0.5B"
#   - "Gryphe/gpt-oss-mini" or "Gryphe/gpt-oss-125M" (community)
#   - "gpt2" (tiny, CPU-friendly baseline)
MODEL_NAME = "Qwen/Qwen2.5-0.5B"   # change if needed
OUTPUT_DIR = "hf-checkpoints/level-repair-qwen"
SEED       = 42

BATCH_SIZE = 2          # keep small for big models; increase if you have VRAM
GRAD_ACCUM = 8          # gradient accumulation steps
EPOCHS     = 3
LR         = 2e-5
MAX_LEN    = 4096       # max tokens per example after tokenization
USE_ROW_BREAKS = True   # if dims saved, insert "\n" at row breaks for structure
PREVIEW_SAMPLES = 3     # number of validation samples shown in the qualitative check

# Seed Python's and torch's global RNGs.
random.seed(SEED)
torch.manual_seed(SEED)

# -----------------------------
# Load vocab (id -> tile char)
# -----------------------------
with open(VOCAB_PATH, "r", encoding="utf-8") as f:
    vconf = json.load(f)
VOCAB: List[str] = vconf["vocab"]
# Values are coerced to int before building the reverse (id -> tile) map.
TOK2ID: Dict[str,int] = {k:int(v) for k,v in vconf["tok2id"].items()}
ID2TOK: Dict[int,str] = {v:k for k,v in TOK2ID.items()}

def ids_to_grid(ids: List[int], w: int, h: int) -> str:
    """Render token ids as tile characters, one row per line when the dims are usable.

    Ids missing from ID2TOK render as "|". Row breaks are inserted only when
    USE_ROW_BREAKS is set, both dims are non-zero and ``w*h`` exceeds the
    number of ids by at most 64; the characters are then cut to ``w*h`` and
    split into ``h`` rows of up to ``w`` characters joined by newlines.
    In every other case the characters are returned as one unbroken string.
    """
    tiles = [ID2TOK.get(token_id, "|") for token_id in ids]
    dims_usable = USE_ROW_BREAKS and w and h and w*h <= len(tiles) + 64  # tolerate a little mismatch
    if not dims_usable:
        return "".join(tiles)
    tiles = tiles[:w*h]
    return "\n".join("".join(tiles[r*w:(r+1)*w]) for r in range(h))

def read_jsonl(path: Path) -> List[Dict]:
    """Parse a UTF-8 JSON Lines file into a list of objects, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as source:
        return [json.loads(line) for line in source if line.strip()]

# -----------------------------
# Tokenizer & model
# -----------------------------
# NOTE(security): trust_remote_code=True allows the checkpoint repository to run
# its own Python code while loading. Keep it only for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Ensure we have a pad token (some decoder-only models don't). Use eos as pad if needed.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Only ever *grow* the embedding matrix. Some checkpoints ship an embedding
# matrix larger than the tokenizer; an unconditional resize would shrink it
# and change the checkpoint's shape for no benefit.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
    model.resize_token_embeddings(len(tokenizer))

# Fixed prompt scaffolding shared by training examples and generation.
INSTR = "### Instruction:\nRepair the level.\n\n"
CORR  = "### Corrupted:\n"
REPR  = "\n\n### Repaired:\n"

def build_example(row: Dict) -> Dict[str, torch.Tensor]:
    """Turn one JSONL record into a tokenized prompt + target training example.

    The corrupted and original grids are each rendered with their *own*
    recorded width/height (falling back to the other side's value only when
    a field is missing or zero). The two sides can differ in size, and
    reshaping the corrupted grid with the original's dimensions would
    truncate or mis-wrap it.

    The text is tokenized without truncation so the caller can detect and
    drop over-long examples. Truncating here would silently cut the target
    and could leave an example with no supervised tokens at all.

    Returns a dict with "input_ids", "attention_mask" and "labels" tensors
    plus the record's "level_id" string. Label positions that belong to the
    prompt are set to -100 so the loss covers only the repaired level.

    NOTE(review): the mask assumes that tokenizing the prompt alone yields a
    prefix of tokenizing prompt + target — confirm for the chosen tokenizer.
    """
    def _dims(side: str, other: str):
        width = row.get(f"width_{side}") or row.get(f"width_{other}") or 0
        height = row.get(f"height_{side}") or row.get(f"height_{other}") or 0
        return width, height

    corrupted_txt = ids_to_grid(row["corrupted_ids"], *_dims("corrupted", "original"))
    target_txt    = ids_to_grid(row["original_ids"],  *_dims("original", "corrupted"))

    prompt = INSTR + CORR + corrupted_txt + REPR
    full   = prompt + target_txt

    enc = tokenizer(full, truncation=False, padding=False, return_tensors=None)
    p_ids = tokenizer(prompt, truncation=False, padding=False, return_tensors=None)["input_ids"]

    input_ids = enc["input_ids"]
    # Positions before this index belong to the prompt and are ignored by the loss.
    loss_mask_upto = min(len(p_ids), len(input_ids))
    labels = [-100] * loss_mask_upto + list(input_ids[loss_mask_upto:])

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(enc["attention_mask"], dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
        "level_id": row.get("level_id", "")
    }

class RepairDataset(Dataset):
    """In-memory dataset of tokenized repair examples built from a JSONL file.

    Every record is converted with ``build_example`` at construction time;
    examples whose ``input_ids`` exceed MAX_LEN tokens are discarded and the
    number discarded is printed. The filter only has an effect when
    ``build_example`` can return more than MAX_LEN tokens.
    """

    def __init__(self, jsonl_path: Path):
        records = read_jsonl(jsonl_path)
        kept = [
            example
            for example in map(build_example, records)
            if example["input_ids"].numel() <= MAX_LEN
        ]
        dropped = len(records) - len(kept)
        if dropped:
            print(f"[{jsonl_path.name}] dropped {dropped} over-long samples (> {MAX_LEN} tokens)")
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Tokenize both splits up front; every example is held in memory.
train_ds = RepairDataset(TRAIN_PATH)
val_ds   = RepairDataset(VAL_PATH)
print(f"Loaded HF datasets: train={len(train_ds)} | val={len(val_ds)} | model={MODEL_NAME}")

# -----------------------------
# Data collator (already pre-tokenized & padded per-batch by Trainer)
# -----------------------------
@dataclass
class SimpleCollator:
    """Pad a list of pre-tokenized examples to the longest one in the batch.

    Each feature must provide "input_ids", "attention_mask" and "labels" as
    1-D tensors or lists; any other key (e.g. "level_id") is ignored.
    Padding is appended on the right using:

    * ``pad_token_id`` for "input_ids",
    * ``0`` for "attention_mask", so padded positions are not attended to
      (padding the mask with the pad token id would mark them as real tokens),
    * ``-100`` for "labels", so padded positions are ignored by the loss.

    Returns a dict of ``(batch, max_len)`` long tensors.
    """

    pad_token_id: int

    def __call__(self, features):
        fill_values = {
            "input_ids": self.pad_token_id,
            "attention_mask": 0,
            "labels": -100,
        }
        # Dynamic pad to the longest in batch
        max_len = max(len(f["input_ids"]) for f in features)
        batch = {}
        for key, fill in fill_values.items():
            rows = []
            for f in features:
                seq = f[key]
                seq = seq.tolist() if isinstance(seq, torch.Tensor) else list(seq)
                rows.append(torch.tensor(seq + [fill] * (max_len - len(seq)), dtype=torch.long))
            batch[key] = torch.stack(rows, dim=0)
        return batch

# Collator instance used for both training and evaluation batches.
collator = SimpleCollator(pad_token_id=tokenizer.pad_token_id)

# -----------------------------
# Training
# -----------------------------
# args = TrainingArguments(
#     output_dir=OUTPUT_DIR,
#     per_device_train_batch_size=BATCH_SIZE,
#     per_device_eval_batch_size=BATCH_SIZE,
#     gradient_accumulation_steps=GRAD_ACCUM,
#     num_train_epochs=EPOCHS,
#     learning_rate=LR,
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.03,
#     weight_decay=0.01,
#     logging_steps=25,
#     evaluation_strategy="steps",
#     eval_steps=100,
#     save_strategy="steps",
#     save_steps=100,
#     save_total_limit=2,
#     bf16=torch.cuda.is_available(),   # use bf16 when possible
#     fp16=not torch.cuda.is_available() and False,  # keep False on CPU
#     report_to="none",
# )

# trainer = Trainer(
#     model=model,
#     args=args,
#     tokenizer=tokenizer,
#     data_collator=collator,
#     train_dataset=train_ds,
#     eval_dataset=val_ds,
# )

# trainer.train()
# trainer.save_model(OUTPUT_DIR)  # saves adapter/weights + tokenizer

# # -----------------------------
# # Quick qualitative check
# # -----------------------------
# def generate_repair(corrupted_ids: List[int], w:int, h:int, max_new_tokens:int=4096) -> str:
#     corrupted_txt = ids_to_grid(corrupted_ids, w, h)
#     prompt = INSTR + CORR + corrupted_txt + REPR
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#     with torch.no_grad():
#         out = model.generate(
#             **inputs,
#             max_new_tokens=min(max_new_tokens, 2*w*h if w and h else 2048),
#             do_sample=False,
#             temperature=1.0,
#             top_p=1.0,
#             eos_token_id=tokenizer.eos_token_id,
#             pad_token_id=tokenizer.pad_token_id,
#         )
#     text = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
#     return text

# val_rows = read_jsonl(VAL_PATH)
# print("\n=== Generations (validation samples) ===")
# for row in random.sample(val_rows, k=min(PREVIEW_SAMPLES, len(val_rows))):
#     w = row.get("width_original") or row.get("width_corrupted") or 0
#     h = row.get("height_original") or row.get("height_corrupted") or 0
#     pred = generate_repair(row["corrupted_ids"], w, h)
#     print(f"\nLevel: {row.get('level_id','?')}")
#     print("Pred repaired (first 6 lines):")
#     print("\n".join(pred.splitlines()[:6]))


# Cell: Ultra-compatible Trainer patch (filters kwargs to match your transformers version)
import inspect, math, torch, transformers
from transformers import Trainer, TrainingArguments

print("Transformers version:", transformers.__version__)

def filter_kwargs(cls, kwargs):
    """Keep only the entries of ``kwargs`` whose key is a parameter name of ``cls.__init__``.

    Matching is by name against the signature of ``cls.__init__`` (which
    includes ``self``); keys the constructor does not name are dropped.
    """
    accepted = inspect.signature(cls.__init__).parameters
    return {name: value for name, value in kwargs.items() if name in accepted}

# Start with a generous set, then filter to what your version supports
# Request bf16 only when a CUDA device that actually supports it is present;
# asking for bf16 on a GPU without support is not valid.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# Start with a generous set, then filter to what your version supports
train_args = dict(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    weight_decay=0.01,
    logging_steps=25,
    save_steps=100,
    save_total_limit=2,
    lr_scheduler_type="cosine",   # will be dropped if unsupported
    warmup_ratio=0.03,            # will be dropped if unsupported
    bf16=use_bf16,                # will be dropped if unsupported
    fp16=False,                   # fp16 stays off; training falls back to fp32 without bf16
    report_to="none",             # will be dropped if unsupported
    # NOTE: we intentionally DO NOT pass evaluation_strategy or evaluate_during_training
)

args = TrainingArguments(**filter_kwargs(TrainingArguments, train_args))

# Trainer's constructor differs between versions: some accept `tokenizer`,
# newer ones take `processing_class` instead, very old ones accept neither.
trainer_kwargs = dict(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,  # will be filtered below if unsupported
)

trainer_params = inspect.signature(Trainer.__init__).parameters
if "tokenizer" not in trainer_params:
    trainer_kwargs.pop("tokenizer", None)
    # Hand the tokenizer over under its newer name so it is still saved
    # alongside the model checkpoints.
    if "processing_class" in trainer_params:
        trainer_kwargs["processing_class"] = tokenizer

trainer = Trainer(**trainer_kwargs)

trainer.train()

# Manual eval (since scheduling args aren’t available)
try:
    eval_out = trainer.evaluate()
    print("Eval metrics:", eval_out)
    if "eval_loss" in eval_out:
        # Clamp the exponent so a diverged loss does not overflow.
        ppl = math.exp(min(20, eval_out["eval_loss"]))
        print(f"Perplexity (clamped): {ppl:.2f}")
except Exception as e:
    # Deliberately best-effort: evaluation problems must not lose the trained model.
    print("Evaluation skipped (not supported in this version):", e)

# Save model
try:
    trainer.save_model(OUTPUT_DIR)
    print("Saved to:", OUTPUT_DIR)
except Exception as e:
    # Deliberately best-effort; the failure is reported rather than raised.
    print("Save failed:", e)