In [4]:
# CELL 2 (REPLACE) — install xong KHÔNG import transformers ở đây (để tránh cache về disk trước)
!pip -q install -U --no-cache-dir \
  "transformers>=4.51.0" \
  "accelerate>=0.30.0" \
  "datasets>=2.19.0" \
  "peft>=0.11.0" \
  "trl>=0.9.6" \
  "bitsandbytes>=0.43.1" \
  "huggingface_hub>=0.23.0" \
  "tokenizers>=0.21.0" \
  "safetensors>=0.4.3" \
  "sentencepiece" \
  "jsonschema>=4.22.0" \
  "rapidfuzz>=3.9.0" \
  "openpyxl>=3.1.5"


In [5]:
# CELL 1 (REPLACE) — RAM mode phải chạy TRƯỚC MỌI import transformers/tokenizers/peft
import os
from pathlib import Path

!df -h /dev/shm

RAM_BASE = Path("/dev/shm/kaggle_ram")
RAM_BASE.mkdir(parents=True, exist_ok=True)

HF_HOME = RAM_BASE / "hf"
HF_HOME.mkdir(parents=True, exist_ok=True)

os.environ["HF_HOME"] = str(HF_HOME)
os.environ["HF_HUB_CACHE"] = str(HF_HOME / "hub")
os.environ["HF_DATASETS_CACHE"] = str(HF_HOME / "datasets")
os.environ["TRANSFORMERS_CACHE"] = str(HF_HOME / "transformers")
os.environ["TORCH_HOME"] = str(RAM_BASE / "torch")
os.environ["XDG_CACHE_HOME"] = str(RAM_BASE / ".cache")

# outputs/logs cũng đẩy vào RAM
WORKDIR = RAM_BASE / "working"
DATA_DIR = WORKDIR / "data"
TEACH_CACHE_DIR = WORKDIR / "teacher_outputs"
WORKDIR.mkdir(parents=True, exist_ok=True)
DATA_DIR.mkdir(parents=True, exist_ok=True)
TEACH_CACHE_DIR.mkdir(parents=True, exist_ok=True)

print("HF_HOME =", os.environ["HF_HOME"])
print("HF_HUB_CACHE =", os.environ["HF_HUB_CACHE"])
print("WORKDIR =", WORKDIR)
print("DATA_DIR =", DATA_DIR)
print("TEACH_CACHE_DIR =", TEACH_CACHE_DIR)


Filesystem      Size  Used Avail Use% Mounted on
shm             114G     0  114G   0% /dev/shm
HF_HOME = /dev/shm/kaggle_ram/hf
HF_HUB_CACHE = /dev/shm/kaggle_ram/hf/hub
WORKDIR = /dev/shm/kaggle_ram/working
DATA_DIR = /dev/shm/kaggle_ram/working/data
TEACH_CACHE_DIR = /dev/shm/kaggle_ram/working/teacher_outputs


In [6]:
# CELL 3 (REPLACE) — setup PATHS/QWEN32B_PATH nhưng KHÔNG reset WORKDIR/DATA_DIR về /kaggle/working nữa
import os, re, json, time, math, random
from glob import glob

# WORKDIR/DATA_DIR đã được set từ CELL 1 (RAM). Không ghi đè lại.
DATA_DIR.mkdir(parents=True, exist_ok=True)

def find_qwen32b_path():
    candidates = []
    for p in glob("/kaggle/input/**", recursive=True):
        if os.path.isdir(p):
            low = p.lower()
            if "qwen" in low and ("32b" in low or "32-b" in low):
                if os.path.exists(os.path.join(p, "config.json")):
                    candidates.append(p)
    candidates = sorted(candidates, key=lambda x: len(x))
    return candidates[0] if candidates else None

QWEN32B_PATH = find_qwen32b_path()
print("QWEN32B_PATH =", QWEN32B_PATH)

TEACHERS = {
    "open_finance_8b": "DragonLLM/Llama-Open-Finance-8B",
    "finance_llama3_8b": "instruction-pretrain/finance-Llama3-8B",
    "fingpt_lora_llama3_8b": "FinGPT/fingpt-mt_llama3-8b_lora",
}
FINGPT_BASE = "meta-llama/Meta-Llama-3-8B"

PATHS = {
    "vn_mcocr": "/kaggle/input/vietnamese-receipts-mc-ocr-2021",
    "invoice_ocr": "/kaggle/input/invoice-ocr",
    "hi_quality_invoice": "/kaggle/input/high-quality-invoice-images-for-ocr",
    "gl_xlsx": "/kaggle/input/generalledger/Data file for students.xlsx",
    "transactions_csv": "/kaggle/input/financial-transactions-dataset/financial_transactions.csv",
    "forecast_csv": "/kaggle/input/financial-forecasting-data/simulated_financial_forecasting_data.csv",
    "data_retriever_csv": "/kaggle/input/data-retreiver/Data_ret.csv",
}
print("DATA PATHS OK")


QWEN32B_PATH = /kaggle/input/qwen-3/transformers/32b/1
DATA PATHS OK


In [7]:
from jsonschema import validate
from jsonschema.exceptions import ValidationError

# ===== Schemas =====
RECEIPT_SCHEMA = {
    "type": "object",
    "properties": {
        "vendor_name": {"type": ["string", "null"]},
        "address": {"type": ["string", "null"]},
        "date": {"type": ["string", "null"]},            # YYYY-MM-DD preferred
        "total_amount": {"type": ["number", "null"]},
        "currency": {"type": ["string", "null"]},        # "VND"
        "confidence": {"type": "number"},
        "flags": {"type": "array", "items": {"type": "string"}}
    },
    "required": ["vendor_name","address","date","total_amount","currency","confidence","flags"]
}

INVOICE_SCHEMA = {
    "type": "object",
    "properties": {
        "vendor_name": {"type": ["string", "null"]},
        "invoice_no": {"type": ["string", "null"]},
        "date": {"type": ["string", "null"]},
        "subtotal": {"type": ["number", "null"]},
        "tax": {"type": ["number", "null"]},
        "total": {"type": ["number", "null"]},
        "currency": {"type": ["string", "null"]},
        "confidence": {"type": "number"},
        "flags": {"type": "array", "items": {"type": "string"}}
    },
    "required": ["vendor_name","invoice_no","date","subtotal","tax","total","currency","confidence","flags"]
}

JOURNAL_SCHEMA = {
    "type": "object",
    "properties": {
        "entries": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "account": {"type": "string"},
                    "debit": {"type": "number"},
                    "credit": {"type": "number"},
                    "memo": {"type": ["string","null"]}
                },
                "required": ["account","debit","credit","memo"]
            }
        },
        "confidence": {"type": "number"},
        "flags": {"type": "array", "items": {"type": "string"}}
    },
    "required": ["entries","confidence","flags"]
}

TASK2SCHEMA = {
    "receipt_extract_text": RECEIPT_SCHEMA,
    "invoice_extract_text": INVOICE_SCHEMA,
    "journal_from_structured_txn": JOURNAL_SCHEMA,
}

def schema_pass(task: str, obj: dict) -> bool:
    try:
        validate(instance=obj, schema=TASK2SCHEMA[task])
        return True
    except ValidationError:
        return False
    except Exception:
        return False

# ===== JSON extract / repair =====
def extract_json_from_text(text: str):
    if text is None:
        return None
    # pick first {...} block
    m = re.search(r"\{.*\}", text, flags=re.S)
    if not m:
        return None
    s = m.group(0)
    try:
        return json.loads(s)
    except Exception:
        return None

def json_repair_minimal(text: str):
    """
    deterministic repair for common LLM issues:
    - trailing commas
    - single quotes -> double quotes (simple cases)
    """
    if text is None:
        return None
    m = re.search(r"\{.*\}", text, flags=re.S)
    if not m:
        return None
    s = m.group(0).strip()

    s = re.sub(r",\s*}", "}", s)
    s = re.sub(r",\s*]", "]", s)
    # naive quote fix (only if it looks like JSON)
    if "'" in s and '"' not in s:
        s = s.replace("'", '"')

    try:
        return json.loads(s)
    except Exception:
        return None

# ===== Prompt builder =====
def build_prompt(task: str, input_data):
    if task == "receipt_extract_text":
        return f"""
Extract receipt key fields from Vietnamese text.
Return ONLY valid JSON with fields:
vendor_name,address,date,total_amount,currency,confidence,flags

Receipt Text:
{input_data}
""".strip()

    if task == "invoice_extract_text":
        return f"""
Extract invoice fields from text.
Return ONLY valid JSON with fields:
vendor_name,invoice_no,date,subtotal,tax,total,currency,confidence,flags

Invoice Text:
{input_data}
""".strip()

    if task == "journal_from_structured_txn":
        return f"""
You are an ERP accountant.
Given a structured transaction JSON, propose journal entries.
Return ONLY valid JSON with fields:
- entries: array of objects (account, debit, credit, memo)
- confidence
- flags

Transaction:
{json.dumps(input_data, ensure_ascii=False)}
""".strip()

    raise ValueError("Unknown task")

print("Schemas + prompt builder ready.")


Schemas + prompt builder ready.


In [8]:
import pandas as pd

def infer_col(df, candidates):
    cols = {c.lower(): c for c in df.columns}
    for cand in candidates:
        if cand.lower() in cols:
            return cols[cand.lower()]
    # fuzzy contains
    for c in df.columns:
        low = c.lower()
        for cand in candidates:
            if cand.lower() in low:
                return c
    return None

def load_vn_mcocr_cases(limit=300):
    root = PATHS["vn_mcocr"]
    cases = []

    # Prefer CSV with gold labels if present
    csv_candidates = [
        os.path.join(root, "mcocr_train_df.csv"),
        os.path.join(root, "mcocr_val_sample_df.csv"),
        os.path.join(root, "results.csv"),
    ]
    for p in csv_candidates:
        if os.path.exists(p):
            df = pd.read_csv(p)

            text_col = infer_col(df, ["text", "ocr_text", "raw_text", "content", "transcription"])
            if text_col is None:
                # fallback pick longest string col
                str_cols = [c for c in df.columns if df[c].dtype == "object"]
                if str_cols:
                    text_col = max(str_cols, key=lambda c: df[c].astype(str).str.len().mean())

            seller_col = infer_col(df, ["seller", "vendor", "vendor_name", "merchant", "store", "shop"])
            addr_col   = infer_col(df, ["address", "seller_address", "vendor_address"])
            date_col   = infer_col(df, ["timestamp", "date", "datetime", "time"])
            total_col  = infer_col(df, ["total_cost", "total", "amount", "total_amount", "sum"])

            for i, row in df.head(limit).iterrows():
                raw_text = str(row[text_col]) if text_col else ""

                gold = None
                if seller_col or addr_col or date_col or total_col:
                    def safe_float(x):
                        try:
                            if pd.isna(x): 
                                return None
                            s = str(x)
                            s = re.sub(r"[^\d\.\-]", "", s)
                            return float(s) if s else None
                        except:
                            return None

                    gold = {
                        "vendor_name": str(row[seller_col]) if seller_col and pd.notna(row[seller_col]) else None,
                        "address": str(row[addr_col]) if addr_col and pd.notna(row[addr_col]) else None,
                        "date": str(row[date_col]) if date_col and pd.notna(row[date_col]) else None,
                        "total_amount": safe_float(row[total_col]) if total_col else None,
                        "currency": "VND",
                        "confidence": 0.0,
                        "flags": []
                    }

                cases.append({
                    "id": f"vn_mcocr_{i}",
                    "task": "receipt_extract_text",
                    "input": raw_text,
                    "gold": gold,
                    "meta": {"source": os.path.basename(p)}
                })
            return cases

    # fallback txt (OCR lines)
    txt_candidates = [
        os.path.join(root, "text_recognition_train_data.txt"),
        os.path.join(root, "text_recognition_val_data.txt"),
    ]
    for p in txt_candidates:
        if os.path.exists(p):
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                for idx, line in enumerate(f):
                    if idx >= limit:
                        break
                    parts = line.strip().split("\t")
                    raw_text = parts[-1] if parts else ""
                    cases.append({
                        "id": f"vn_mcocr_txt_{idx}",
                        "task": "receipt_extract_text",
                        "input": raw_text,
                        "gold": None,
                        "meta": {"source": os.path.basename(p)}
                    })
            return cases

    return []

def load_gl_cases(limit=200):
    xlsx_path = PATHS["gl_xlsx"]
    if not os.path.exists(xlsx_path):
        return []
    xls = pd.ExcelFile(xlsx_path)
    # take first sheet by default
    df = pd.read_excel(xlsx_path, sheet_name=xls.sheet_names[0])

    cases = []
    for i, row in df.head(limit).iterrows():
        txn = row.to_dict()
        cases.append({
            "id": f"gl_{i}",
            "task": "journal_from_structured_txn",
            "input": txn,
            "gold": None,
            "meta": {"sheet": xls.sheet_names[0]}
        })
    return cases

def load_invoice_ocr_cases(limit=200):
    """
    Robust loader:
    - If JSON/CSV annotations exist -> use their text fields
    - Otherwise use image paths (text-only LLM can't read images, but still valid for KD if you later OCR)
    """
    root = PATHS["invoice_ocr"]
    if not os.path.exists(root):
        return []

    ann_files = []
    for ext in ["*.json","*.csv"]:
        ann_files += glob(os.path.join(root, "**", ext), recursive=True)

    cases = []
    if ann_files:
        # take first annotation file found
        p = ann_files[0]
        if p.endswith(".csv"):
            df = pd.read_csv(p)
            text_col = infer_col(df, ["text","ocr","raw","content"])
            for i, row in df.head(limit).iterrows():
                raw_text = str(row[text_col]) if text_col else ""
                cases.append({
                    "id": f"invoice_ocr_csv_{i}",
                    "task": "invoice_extract_text",
                    "input": raw_text,
                    "gold": None,
                    "meta": {"ann": os.path.basename(p)}
                })
        else:
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                js = json.load(f)
            # try to find list items with "text"
            items = []
            if isinstance(js, list):
                items = js
            elif isinstance(js, dict):
                # common keys
                for k in ["data","items","annotations","samples"]:
                    if k in js and isinstance(js[k], list):
                        items = js[k]
                        break

            for i, it in enumerate(items[:limit]):
                raw_text = it.get("text") or it.get("ocr_text") or it.get("content") or ""
                cases.append({
                    "id": f"invoice_ocr_json_{i}",
                    "task": "invoice_extract_text",
                    "input": str(raw_text),
                    "gold": None,
                    "meta": {"ann": os.path.basename(p)}
                })

        return cases

    # fallback: use image paths (for later OCR pipeline)
    imgs = glob(os.path.join(root, "**", "*.png"), recursive=True) + glob(os.path.join(root, "**", "*.jpg"), recursive=True)
    for i, ip in enumerate(imgs[:limit]):
        cases.append({
            "id": f"invoice_ocr_img_{i}",
            "task": "invoice_extract_text",
            "input": f"[IMAGE_PATH]{ip}",
            "gold": None,
            "meta": {"img": os.path.basename(ip)}
        })
    return cases

print("Loaders ready.")


Loaders ready.


In [9]:
# CELL SPLIT (REPLACE) — tách MCOCR (gold) riêng + tạo kd_pool lớn hơn nhưng có kiểm soát
import numpy as np
import pandas as pd
from datetime import datetime, date
import random, json, re
from pathlib import Path

def _json_default(o):
    if isinstance(o, (pd.Timestamp, datetime, date)):
        return o.isoformat()
    if isinstance(o, (np.integer,)):
        return int(o)
    if isinstance(o, (np.floating,)):
        return float(o)
    if isinstance(o, (np.ndarray,)):
        return o.tolist()
    return str(o)

def write_jsonl(path, rows):
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False, default=_json_default) + "\n")

# 1) Load dữ liệu
mcocr_cases = load_vn_mcocr_cases(limit=2000)          # tăng để có nhiều gold
invoice_cases = load_invoice_ocr_cases(limit=2000)     # tăng, nhưng thường thiếu gold
gl_cases = load_gl_cases(limit=2000)

# 2) Split MCOCR có gold để EVAL chuẩn
mcocr_gold = [c for c in mcocr_cases if isinstance(c.get("gold"), dict)]
random.shuffle(mcocr_gold)

# Giữ eval_gold đủ lớn để đo; train_gold còn lại để KD/augment nếu muốn
N_EVAL_GOLD = min(500, len(mcocr_gold))
eval_gold_cases = mcocr_gold[:N_EVAL_GOLD]
train_gold_cases = mcocr_gold[N_EVAL_GOLD:]

# 3) Eval set: ưu tiên MCOCR gold + thêm journal để đo schema/journal
eval_cases = []
eval_cases += eval_gold_cases
# thêm journal (không có gold) để đo schema strict
eval_cases += gl_cases[:500]

# (tuỳ chọn) invoice text-only nếu có (nếu input là IMAGE_PATH thì eval text-only LLM sẽ fail => bỏ)
invoice_text_only = [c for c in invoice_cases if isinstance(c.get("input"), str) and not str(c["input"]).startswith("[IMAGE_PATH]")]
eval_cases += invoice_text_only[:300]

random.shuffle(eval_cases)

# 4) KD pool: lấy (train_gold + invoice_text_only + gl) => lớn, đa dạng
kd_pool = []
kd_pool += train_gold_cases
kd_pool += invoice_text_only
kd_pool += gl_cases
random.shuffle(kd_pool)

# giới hạn để iteration ổn; bạn có thể nâng dần (2k, 5k, 10k)
KD_POOL_MAX = min(5000, len(kd_pool))
kd_pool = kd_pool[:KD_POOL_MAX]

print("MCOCR total:", len(mcocr_cases), "| gold:", len(mcocr_gold))
print("EVAL cases:", len(eval_cases), "| EVAL gold:", len(eval_gold_cases))
print("KD pool:", len(kd_pool))

eval_path = str(DATA_DIR / "eval_cases.jsonl")
kd_pool_path = str(DATA_DIR / "kd_pool.jsonl")
write_jsonl(eval_path, eval_cases)
write_jsonl(kd_pool_path, kd_pool)

print("Saved eval_path:", eval_path)
print("Saved kd_pool_path:", kd_pool_path)


MCOCR total: 1155 | gold: 0
EVAL cases: 500 | EVAL gold: 0
KD pool: 2000
Saved eval_path: /dev/shm/kaggle_ram/working/data/eval_cases.jsonl
Saved kd_pool_path: /dev/shm/kaggle_ram/working/data/kd_pool.jsonl


In [10]:
# =======================
# CELL 4B (REPLACE)
# =======================
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

HF_CACHE_DIR = os.environ["HF_HUB_CACHE"]

def assert_cuda_bf16_ready():
    assert torch.cuda.is_available(), "CUDA is not available. You said H100 runtime."
    name = torch.cuda.get_device_name(0)
    print("CUDA device:", name)
    # H100 supports bf16
    assert torch.cuda.is_bf16_supported(), "bf16 not supported on this GPU runtime."
    return name

def load_bf16_model(repo_or_path: str):
    """
    Full BF16 weights on GPU. Use this for BOTH training+eval if you want pure BF16.
    """
    assert_cuda_bf16_ready()

    tok = AutoTokenizer.from_pretrained(
        repo_or_path,
        use_fast=True,
        trust_remote_code=True,
        cache_dir=HF_CACHE_DIR,
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    mdl = AutoModelForCausalLM.from_pretrained(
        repo_or_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        cache_dir=HF_CACHE_DIR,
    )
    mdl.eval()
    return tok, mdl

@torch.no_grad()
def generate_batch(tok, mdl, prompts, max_new_tokens=256):
    enc = tok(prompts, return_tensors="pt", padding=True, truncation=True).to(mdl.device)
    out = mdl.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        top_p=1.0,
        repetition_penalty=1.05,
        pad_token_id=tok.eos_token_id,
    )
    in_len = enc["input_ids"].shape[1]
    gen = out[:, in_len:]
    return tok.batch_decode(gen, skip_special_tokens=True)

print("Model utils ready.")




Model utils ready.


In [11]:
# CELL 5 (REPLACE) — teacher cell: không hardcode token + TEACH_CACHE_DIR dùng RAM
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

secret_value_1 = user_secrets.get_secret("HF_TOKEN")

HF_TOKEN = os.getenv("HF_TOKEN", secret_value_1)
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("HF login OK")
else:
    print("HF_TOKEN missing -> gated teachers may fail")

TEACH_CACHE_DIR.mkdir(parents=True, exist_ok=True)

def run_teacher(name, repo=None, adapter=None, base=None, batch_size=16):
    print(f"\n=== Teacher: {name} ===")
    try:
        if adapter and base:
            tok, mdl = load_fingpt(adapter, base)
        else:
            tok, mdl = load_4bit_model(repo)

        with open(kd_pool_path, "r", encoding="utf-8") as f:
            pool = [json.loads(x) for x in f]

        out_path = TEACH_CACHE_DIR / f"{name}.jsonl"
        with open(out_path, "w", encoding="utf-8") as fw:
            for i in range(0, len(pool), batch_size):
                batch = pool[i:i+batch_size]
                prompts = [build_prompt(ex["task"], ex["input"]) for ex in batch]
                texts = generate_batch(tok, mdl, prompts)

                for ex, t in zip(batch, texts):
                    fw.write(json.dumps({"id": ex["id"], "task": ex["task"], "raw": t}, ensure_ascii=False) + "\n")

        del mdl
        torch.cuda.empty_cache()
        print("Saved:", str(out_path))
        return str(out_path)
    except Exception as e:
        print("FAILED:", name, "| reason:", str(e))
        return None


teacher_paths = {}

# finance llama3 (public)
teacher_paths["finance_llama3_8b"] = run_teacher("finance_llama3_8b", repo=TEACHERS["finance_llama3_8b"])

# open finance (gated)
teacher_paths["open_finance_8b"] = run_teacher("open_finance_8b", repo=TEACHERS["open_finance_8b"])

# fingpt lora (needs base llama3 gated)
teacher_paths["fingpt_lora_llama3_8b"] = run_teacher(
    "fingpt_lora_llama3_8b",
    adapter=TEACHERS["fingpt_lora_llama3_8b"],
    base=FINGPT_BASE
)

print("\nTeacher paths:", teacher_paths)


HF login OK

=== Teacher: finance_llama3_8b ===
FAILED: finance_llama3_8b | reason: name 'load_4bit_model' is not defined

=== Teacher: open_finance_8b ===
FAILED: open_finance_8b | reason: name 'load_4bit_model' is not defined

=== Teacher: fingpt_lora_llama3_8b ===
FAILED: fingpt_lora_llama3_8b | reason: name 'load_fingpt' is not defined

Teacher paths: {'finance_llama3_8b': None, 'open_finance_8b': None, 'fingpt_lora_llama3_8b': None}


In [12]:
# CELL DISTILL (REPLACE) — robust JSON extraction + coercion so KD distilled != 0
import re, json
from datetime import datetime

def _strip_code_fences(s: str) -> str:
    if not s:
        return s
    s = s.strip()
    # remove ```json ... ``` or ``` ... ```
    s = re.sub(r"^```(?:json)?\s*", "", s, flags=re.I)
    s = re.sub(r"\s*```$", "", s)
    return s.strip()

def _try_json_load(s: str):
    try:
        return json.loads(s)
    except Exception:
        return None

def _json_repair_minimal(s: str):
    if s is None:
        return None
    s = _strip_code_fences(s)
    s = s.strip()

    # remove trailing commas
    s = re.sub(r",\s*}", "}", s)
    s = re.sub(r",\s*]", "]", s)

    # if single quotes and no double quotes (naive)
    if "'" in s and '"' not in s:
        s = s.replace("'", '"')

    return _try_json_load(s)

def extract_json_robust(text: str):
    """
    Strategy:
    1) try whole text (after stripping fences)
    2) try JSON code block
    3) scan all {...} candidates and pick the first that parses
    """
    if text is None:
        return None

    t = _strip_code_fences(text)

    # 1) whole
    obj = _try_json_load(t)
    if obj is not None:
        return obj

    # 2) try inside ```json ... ```
    m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.S | re.I)
    if m:
        obj = _try_json_load(m.group(1))
        if obj is not None:
            return obj
        obj = _json_repair_minimal(m.group(1))
        if obj is not None:
            return obj

    # 3) scan brace blocks non-greedy
    for m in re.finditer(r"\{.*?\}", t, flags=re.S):
        cand = m.group(0)
        obj = _try_json_load(cand)
        if obj is not None:
            return obj
        obj = _json_repair_minimal(cand)
        if obj is not None:
            return obj

    return None

def _to_float(x):
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return float(x)
    s = str(x)
    s = s.replace(",", ".")
    s = re.sub(r"[^0-9\.\-]", "", s)
    try:
        return float(s) if s else None
    except Exception:
        return None

def _to_str_or_none(x):
    if x is None:
        return None
    s = str(x).strip()
    return s if s else None

def _to_flags(x):
    if x is None:
        return []
    if isinstance(x, list):
        return [str(v) for v in x if str(v).strip()]
    # split by comma/newline
    s = str(x)
    parts = re.split(r"[,;\n]+", s)
    return [p.strip() for p in parts if p.strip()]

def coerce_to_schema(task: str, obj: dict):
    if not isinstance(obj, dict):
        return None

    if task == "receipt_extract_text":
        out = {
            "vendor_name": _to_str_or_none(obj.get("vendor_name") or obj.get("vendor") or obj.get("seller")),
            "address": _to_str_or_none(obj.get("address") or obj.get("addr")),
            "date": _to_str_or_none(obj.get("date")),
            "total_amount": _to_float(obj.get("total_amount") or obj.get("total") or obj.get("amount")),
            "currency": _to_str_or_none(obj.get("currency")) or "VND",
            "confidence": _to_float(obj.get("confidence")) if obj.get("confidence") is not None else 0.5,
            "flags": _to_flags(obj.get("flags")),
        }
        return out

    if task == "invoice_extract_text":
        out = {
            "vendor_name": _to_str_or_none(obj.get("vendor_name") or obj.get("vendor") or obj.get("seller")),
            "invoice_no": _to_str_or_none(obj.get("invoice_no") or obj.get("invoice_number") or obj.get("invoice")),
            "date": _to_str_or_none(obj.get("date")),
            "subtotal": _to_float(obj.get("subtotal")),
            "tax": _to_float(obj.get("tax") or obj.get("vat")),
            "total": _to_float(obj.get("total") or obj.get("total_amount") or obj.get("amount")),
            "currency": _to_str_or_none(obj.get("currency")) or "VND",
            "confidence": _to_float(obj.get("confidence")) if obj.get("confidence") is not None else 0.5,
            "flags": _to_flags(obj.get("flags")),
        }
        return out

    if task == "journal_from_structured_txn":
        entries = obj.get("entries")
        if not isinstance(entries, list):
            entries = []
        norm_entries = []
        for e in entries:
            if not isinstance(e, dict):
                continue
            norm_entries.append({
                "account": _to_str_or_none(e.get("account")) or "",
                "debit": _to_float(e.get("debit")) or 0.0,
                "credit": _to_float(e.get("credit")) or 0.0,
                "memo": _to_str_or_none(e.get("memo")),
            })
        out = {
            "entries": norm_entries,
            "confidence": _to_float(obj.get("confidence")) if obj.get("confidence") is not None else 0.5,
            "flags": _to_flags(obj.get("flags")),
        }
        return out

    return obj

# ---- run distill again ----
def load_teacher_outputs(path):
    out = {}
    if not path:
        return out
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            r = json.loads(line)
            out[r["id"]] = r
    return out

teacher_outputs = {k: load_teacher_outputs(v) for k, v in teacher_paths.items() if v}

def router_priority(task):
    if task in ["receipt_extract_text", "invoice_extract_text"]:
        return ["open_finance_8b", "finance_llama3_8b", "fingpt_lora_llama3_8b"]
    if task == "journal_from_structured_txn":
        return ["finance_llama3_8b", "open_finance_8b", "fingpt_lora_llama3_8b"]
    return ["finance_llama3_8b", "open_finance_8b", "fingpt_lora_llama3_8b"]

distilled = []
dropped = 0
picked = {}

with open(kd_pool_path, "r", encoding="utf-8") as f:
    pool = [json.loads(x) for x in f]

for ex in pool:
    task = ex["task"]
    cid = ex["id"]

    chosen_obj = None
    chosen_teacher = None

    for tname in router_priority(task):
        if tname not in teacher_outputs:
            continue
        rec = teacher_outputs[tname].get(cid)
        if not rec:
            continue

        raw = rec.get("raw", "")
        obj = extract_json_robust(raw)
        obj = coerce_to_schema(task, obj) if obj is not None else None

        if obj is not None and schema_pass(task, obj):
            chosen_obj = obj
            chosen_teacher = tname
            break

    if chosen_obj is None:
        dropped += 1
        continue

    picked[chosen_teacher] = picked.get(chosen_teacher, 0) + 1
    distilled.append({
        "id": cid,
        "task": task,
        "prompt": build_prompt(task, ex["input"]),
        "answer_json": chosen_obj
    })

print("KD distilled:", len(distilled), "| dropped:", dropped)
print("picked_by_teacher:", picked)

distill_path = str(DATA_DIR / "distilled_train.jsonl")
with open(distill_path, "w", encoding="utf-8") as f:
    for r in distilled:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print("Saved:", distill_path)


KD distilled: 0 | dropped: 2000
picked_by_teacher: {}
Saved: /dev/shm/kaggle_ram/working/data/distilled_train.jsonl


In [13]:
# CELL TRAIN (REPLACE) — bảo đảm lưu adapter vào 1 folder rõ ràng + tồn tại adapter_config.json
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
from transformers import TrainingArguments
import json, os
from pathlib import Path

assert QWEN32B_PATH is not None, "Cannot find Qwen3-32B in /kaggle/input"

# Load Qwen base
qwen_tok, qwen_base = load_4bit_model(QWEN32B_PATH)

def guess_lora_targets(model):
    names = set()
    for n, _ in model.named_modules():
        if any(k in n for k in ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]):
            names.add(n.split(".")[-1])
    return sorted(list(names)) if names else ["q_proj","k_proj","v_proj","o_proj"]

targets = guess_lora_targets(qwen_base)
print("LoRA targets:", targets)

lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=targets,
)

student = get_peft_model(qwen_base, lora_cfg)
student.print_trainable_parameters()

# Prepare dataset
rows = []
with open(distill_path, "r", encoding="utf-8") as f:
    for line in f:
        r = json.loads(line)
        rows.append({"text": r["prompt"] + "\n\n" + json.dumps(r["answer_json"], ensure_ascii=False)})

if len(rows) == 0:
    raise RuntimeError("distilled_train.jsonl is empty")

train_ds = Dataset.from_list(rows)

args = TrainingArguments(
    output_dir=str(WORKDIR / "student_ckpt"),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,        # giảm từ 2e-4
    num_train_epochs=2,        # chỉ tăng nếu B-STRICT tốt lên
    warmup_steps=20,
    logging_steps=10,
    save_steps=200,
    bf16=True,
    optim="paged_adamw_8bit",
    report_to="none",
)


trainer = SFTTrainer(
    model=student,
    args=args,
    train_dataset=train_ds,
    processing_class=qwen_tok,   # TRL mới dùng processing_class
)

trainer.train()

# ✅ SAVE ADAPTER to a concrete path (and export ADAPTER_DIR for later cells)
ADAPTER_DIR = Path(WORKDIR) / "outputs" / "adapters" / "student_adapter"
ADAPTER_DIR.mkdir(parents=True, exist_ok=True)

trainer.model.save_pretrained(str(ADAPTER_DIR))
qwen_tok.save_pretrained(str(ADAPTER_DIR))

# hard assert: must exist
assert (ADAPTER_DIR / "adapter_config.json").exists(), f"Missing adapter_config.json in {ADAPTER_DIR}"
print("✅ Saved student adapter:", str(ADAPTER_DIR))


2026-01-22 06:45:13.946990: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769064314.425773     106 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769064314.554155     106 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769064315.687238     106 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769064315.687277     106 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769064315.687278     106 computation_placer.cc:177] computation placer alr

NameError: name 'load_4bit_model' is not defined

In [None]:
# =======================
# CELL EVAL (STRICT GOLD, BF16) — score only MCOCR gold + NO default fill
# Place: REPLACE your current CELL EVAL (after TRAIN)
# =======================
import json, re, random
from rapidfuzz import fuzz

def _strip_code_fences(s: str) -> str:
    if not s:
        return s
    s = s.strip()
    s = re.sub(r"^```(?:json)?\s*", "", s, flags=re.I)
    s = re.sub(r"\s*```$", "", s)
    return s.strip()

def _try_json_load(s: str):
    try:
        return json.loads(s)
    except Exception:
        return None

def extract_json_robust(text: str):
    if text is None:
        return None
    t = _strip_code_fences(text)

    obj = _try_json_load(t)
    if obj is not None:
        return obj

    m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.S | re.I)
    if m:
        obj = _try_json_load(m.group(1))
        if obj is not None:
            return obj

    for m in re.finditer(r"\{.*?\}", t, flags=re.S):
        obj = _try_json_load(m.group(0))
        if obj is not None:
            return obj
    return None

def _to_float(x):
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return float(x)
    s = str(x).replace(",", ".")
    s = re.sub(r"[^0-9\.\-]", "", s)
    try:
        return float(s) if s else None
    except Exception:
        return None

def _to_str(x):
    if x is None:
        return None
    s = str(x).strip()
    return s if s else None

def _to_flags(x):
    if x is None:
        return None
    if isinstance(x, list):
        return [str(v).strip() for v in x if str(v).strip()]
    return [p.strip() for p in re.split(r"[,;\n]+", str(x)) if p.strip()]

def coerce_receipt_strict(obj: dict):
    """
    STRICT: NO DEFAULT FILL. Missing fields stay None -> schema should fail.
    """
    if not isinstance(obj, dict):
        return None
    return {
        "vendor_name": _to_str(obj.get("vendor_name") or obj.get("vendor") or obj.get("seller")),
        "address": _to_str(obj.get("address") or obj.get("addr")),
        "date": _to_str(obj.get("date")),
        "total_amount": _to_float(obj.get("total_amount") or obj.get("total") or obj.get("amount")),
        "currency": _to_str(obj.get("currency")),
        "confidence": _to_float(obj.get("confidence")),
        "flags": _to_flags(obj.get("flags")),
    }

def vendor_similarity(a, b):
    if not a or not b:
        return None
    return fuzz.token_set_ratio(str(a), str(b)) / 100.0

def amount_close(g, p, rel_tol=0.02, abs_tol=2000.0):
    if g is None or p is None:
        return None
    try:
        g = float(g); p = float(p)
        return abs(g - p) <= max(abs_tol, rel_tol * max(abs(g), 1.0))
    except Exception:
        return None

def eval_receipt_gold_only(tok, mdl, cases, batch_size=32, max_new_tokens=256):
    gold_cases = [
        c for c in cases
        if c.get("task") == "receipt_extract_text"
        and isinstance(c.get("gold"), dict)
        and any(c["gold"].get(k) is not None for k in ["vendor_name","date","total_amount","address"])
    ]
    n = len(gold_cases)
    if n == 0:
        return {"n": 0}

    json_valid = 0
    schema_ok = 0
    vendor_sims = []
    total_close = []
    date_exact = []
    bad = []

    for i in range(0, n, batch_size):
        batch = gold_cases[i:i+batch_size]
        prompts = [build_prompt("receipt_extract_text", ex["input"]) for ex in batch]
        texts = generate_batch(tok, mdl, prompts, max_new_tokens=max_new_tokens)

        for ex, t in zip(batch, texts):
            gold = ex["gold"]

            raw_obj = extract_json_robust(t)
            obj = coerce_receipt_strict(raw_obj) if raw_obj else None

            if obj is not None:
                json_valid += 1
                if schema_pass("receipt_extract_text", obj):
                    schema_ok += 1
                else:
                    if len(bad) < 10:
                        bad.append({"id": ex["id"], "reason": "schema_fail", "raw": t[:400]})
            else:
                if len(bad) < 10:
                    bad.append({"id": ex["id"], "reason": "json_parse_fail", "raw": t[:400]})

            vs = vendor_similarity(gold.get("vendor_name"), obj.get("vendor_name") if obj else None)
            if vs is not None:
                vendor_sims.append(vs)

            tc = amount_close(gold.get("total_amount"), obj.get("total_amount") if obj else None)
            if tc is not None:
                total_close.append(1.0 if tc else 0.0)

            gd = gold.get("date")
            pd = obj.get("date") if obj else None
            if gd is not None and pd is not None:
                date_exact.append(1.0 if str(gd).strip() == str(pd).strip() else 0.0)

    def avg(xs):
        return float(sum(xs)/len(xs)) if xs else None

    return {
        "n": n,
        "json_valid_rate": json_valid / n,
        "schema_pass_rate": schema_ok / n,
        "vendor_sim_avg": avg(vendor_sims),
        "total_close_rate": avg(total_close),
        "date_exact_rate": avg(date_exact),
        "debug_bad_samples": bad,
    }

# ---- load eval cases ----
with open(eval_path, "r", encoding="utf-8") as f:
    eval_cases = [json.loads(x) for x in f]

# ---- BASE (BF16) ----
base_tok, base_mdl = load_bf16_model(QWEN32B_PATH)
base_gold = eval_receipt_gold_only(base_tok, base_mdl, eval_cases, batch_size=32, max_new_tokens=256)
print("BASE (GOLD receipt strict, BF16):", {k: v for k, v in base_gold.items() if k != "debug_bad_samples"})

# ---- STUDENT (BF16 base + adapter) ----
from peft import PeftModel
student_tok, student_base = load_bf16_model(QWEN32B_PATH)
student_mdl = PeftModel.from_pretrained(student_base, str(ADAPTER_DIR))
student_mdl.eval()

student_gold = eval_receipt_gold_only(student_tok, student_mdl, eval_cases, batch_size=32, max_new_tokens=256)
print("STUDENT (GOLD receipt strict, BF16):", {k: v for k, v in student_gold.items() if k != "debug_bad_samples"})

report = {
    "base_gold_receipt_strict_bf16": base_gold,
    "student_gold_receipt_strict_bf16": student_gold,
    "meta": {
        "eval_path": str(eval_path),
        "mode": "bf16_inference_only",
        "task": "receipt_extract_text (gold only)",
        "batch_size": 32,
        "max_new_tokens": 256,
        "notes": "strict: no default fill; score only gold-bearing MCOCR receipts",
    },
}

report_path = str(DATA_DIR / "eval_report_gold_strict_bf16.json")
with open(report_path, "w", encoding="utf-8") as f:
    json.dump(report, f, ensure_ascii=False, indent=2)

print("Saved report:", report_path)


In [None]:
# FINAL SAVE (adapter) — student adapter in RAM
ADAPTER_DIR = str(WORKDIR / "student_adapter")
trainer.model.save_pretrained(ADAPTER_DIR)
qwen_tok.save_pretrained(ADAPTER_DIR)
print("Saved student adapter dir:", ADAPTER_DIR)


In [None]:
# CELL SAVE (ADD NEW, cuối notebook) — xuất student adapter + tokenizer + lineage để tải về
import json, shutil
from pathlib import Path

OUT_DIR = WORKDIR / "outputs"
REL = Path("/kaggle/working") / "release" / "qwen3-32b-accounting-distilled-v0.1.0"
REL.mkdir(parents=True, exist_ok=True)

ADAPTER_OUT = REL / "adapters"
TOKEN_OUT  = REL / "tokenizer"
ADAPTER_OUT.mkdir(parents=True, exist_ok=True)
TOKEN_OUT.mkdir(parents=True, exist_ok=True)

# Save LoRA adapter (nhẹ) + tokenizer
trainer.model.save_pretrained(str(ADAPTER_OUT))
qwen_tok.save_pretrained(str(TOKEN_OUT))

lineage = {
    "model_name": "qwen3-32b-accounting-distilled",
    "version": "v0.1.0",
    "method": "Knowledge Distillation (multi-teacher) + QLoRA SFT",
    "teachers": TEACHERS,
    "paths": {
        "distill_path": str(distill_path),
        "adapter_dir": str(ADAPTER_OUT),
        "tokenizer_dir": str(TOKEN_OUT),
    },
}
with open(REL / "lineage.json", "w", encoding="utf-8") as f:
    json.dump(lineage, f, ensure_ascii=False, indent=2)

# Zip để download dễ
zip_path = shutil.make_archive(str(REL), "zip", root_dir=str(REL))
print("✅ RELEASE folder:", REL)
print("✅ ZIP:", zip_path)
