In [7]:
# ===== FIX Transformers/TRL/BnB (avoid GenerationMixin bug + avoid dependency hell) =====
!pip -q install -U --no-cache-dir \
  "transformers==4.46.3" \
  "accelerate>=0.30.0" \
  "datasets>=2.19.0" \
  "peft>=0.11.0" \
  "trl>=0.9.6" \
  "bitsandbytes>=0.43.1" \
  "huggingface_hub>=0.23.0" \
  "tokenizers>=0.20.0,<0.21.0" \
  "safetensors>=0.4.3" \
  "jsonschema>=4.22.0" \
  "rapidfuzz>=3.9.0" \
  "openpyxl>=3.1.5"

import transformers, tokenizers
from transformers.generation import GenerationMixin

print("✅ transformers =", transformers.__version__)
print("✅ tokenizers   =", tokenizers.__version__)
print("✅ GenerationMixin import OK")


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m0:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m380.9/380.9 kB[0m [31m147.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m515.2/515.2 kB[0m [31m235.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.0/557.0 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m348.0/348.0 kB[0m [31m531.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m20.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m32.5 MB/s[0m eta [36m

In [2]:
import os, re, json, time, math, random
from pathlib import Path
from glob import glob

WORKDIR = Path("/kaggle/working")
DATA_DIR = WORKDIR / "data"
DATA_DIR.mkdir(parents=True, exist_ok=True)

# ===== Auto-discover Qwen3-32B model path in /kaggle/input =====
def find_qwen32b_path():
    candidates = []
    for p in glob("/kaggle/input/**", recursive=True):
        if os.path.isdir(p):
            low = p.lower()
            if "qwen" in low and ("32b" in low or "32-b" in low):
                # must contain config.json to be HF-compatible
                if os.path.exists(os.path.join(p, "config.json")):
                    candidates.append(p)
    # choose shortest path (usually the root model folder)
    candidates = sorted(candidates, key=lambda x: len(x))
    return candidates[0] if candidates else None

QWEN32B_PATH = find_qwen32b_path()
print("QWEN32B_PATH =", QWEN32B_PATH)

# ===== Teachers (HF) =====
TEACHERS = {
    "open_finance_8b": "DragonLLM/Llama-Open-Finance-8B",            # gated :contentReference[oaicite:3]{index=3}
    "finance_llama3_8b": "instruction-pretrain/finance-Llama3-8B",   # public :contentReference[oaicite:4]{index=4}
    "fingpt_lora_llama3_8b": "FinGPT/fingpt-mt_llama3-8b_lora"        # adapter :contentReference[oaicite:5]{index=5}
}
FINGPT_BASE = "meta-llama/Meta-Llama-3-8B"  # gated sometimes :contentReference[oaicite:6]{index=6}

# ===== Dataset roots (bạn đã mount sẵn) =====
PATHS = {
    "vn_mcocr": "/kaggle/input/vietnamese-receipts-mc-ocr-2021",
    "invoice_ocr": "/kaggle/input/invoice-ocr",
    "hi_quality_invoice": "/kaggle/input/high-quality-invoice-images-for-ocr",
    "gl_xlsx": "/kaggle/input/generalledger/Data file for students.xlsx",
    "transactions_csv": "/kaggle/input/financial-transactions-dataset/financial_transactions.csv",
    "forecast_csv": "/kaggle/input/financial-forecasting-data/simulated_financial_forecasting_data.csv",
    "data_retriever_csv": "/kaggle/input/data-retreiver/Data_ret.csv",
}
print("DATA PATHS OK:", PATHS)


QWEN32B_PATH = /kaggle/input/qwen-3/transformers/32b/1
DATA PATHS OK: {'vn_mcocr': '/kaggle/input/vietnamese-receipts-mc-ocr-2021', 'invoice_ocr': '/kaggle/input/invoice-ocr', 'hi_quality_invoice': '/kaggle/input/high-quality-invoice-images-for-ocr', 'gl_xlsx': '/kaggle/input/generalledger/Data file for students.xlsx', 'transactions_csv': '/kaggle/input/financial-transactions-dataset/financial_transactions.csv', 'forecast_csv': '/kaggle/input/financial-forecasting-data/simulated_financial_forecasting_data.csv', 'data_retriever_csv': '/kaggle/input/data-retreiver/Data_ret.csv'}


In [3]:
from jsonschema import validate
from jsonschema.exceptions import ValidationError

# ===== Schemas =====
RECEIPT_SCHEMA = {
    "type": "object",
    "properties": {
        "vendor_name": {"type": ["string", "null"]},
        "address": {"type": ["string", "null"]},
        "date": {"type": ["string", "null"]},            # YYYY-MM-DD preferred
        "total_amount": {"type": ["number", "null"]},
        "currency": {"type": ["string", "null"]},        # "VND"
        "confidence": {"type": "number"},
        "flags": {"type": "array", "items": {"type": "string"}}
    },
    "required": ["vendor_name","address","date","total_amount","currency","confidence","flags"]
}

INVOICE_SCHEMA = {
    "type": "object",
    "properties": {
        "vendor_name": {"type": ["string", "null"]},
        "invoice_no": {"type": ["string", "null"]},
        "date": {"type": ["string", "null"]},
        "subtotal": {"type": ["number", "null"]},
        "tax": {"type": ["number", "null"]},
        "total": {"type": ["number", "null"]},
        "currency": {"type": ["string", "null"]},
        "confidence": {"type": "number"},
        "flags": {"type": "array", "items": {"type": "string"}}
    },
    "required": ["vendor_name","invoice_no","date","subtotal","tax","total","currency","confidence","flags"]
}

JOURNAL_SCHEMA = {
    "type": "object",
    "properties": {
        "entries": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "account": {"type": "string"},
                    "debit": {"type": "number"},
                    "credit": {"type": "number"},
                    "memo": {"type": ["string","null"]}
                },
                "required": ["account","debit","credit","memo"]
            }
        },
        "confidence": {"type": "number"},
        "flags": {"type": "array", "items": {"type": "string"}}
    },
    "required": ["entries","confidence","flags"]
}

TASK2SCHEMA = {
    "receipt_extract_text": RECEIPT_SCHEMA,
    "invoice_extract_text": INVOICE_SCHEMA,
    "journal_from_structured_txn": JOURNAL_SCHEMA,
}

def schema_pass(task: str, obj: dict) -> bool:
    try:
        validate(instance=obj, schema=TASK2SCHEMA[task])
        return True
    except ValidationError:
        return False
    except Exception:
        return False

# ===== JSON extract / repair =====
def extract_json_from_text(text: str):
    if text is None:
        return None
    # pick first {...} block
    m = re.search(r"\{.*\}", text, flags=re.S)
    if not m:
        return None
    s = m.group(0)
    try:
        return json.loads(s)
    except Exception:
        return None

def json_repair_minimal(text: str):
    """
    deterministic repair for common LLM issues:
    - trailing commas
    - single quotes -> double quotes (simple cases)
    """
    if text is None:
        return None
    m = re.search(r"\{.*\}", text, flags=re.S)
    if not m:
        return None
    s = m.group(0).strip()

    s = re.sub(r",\s*}", "}", s)
    s = re.sub(r",\s*]", "]", s)
    # naive quote fix (only if it looks like JSON)
    if "'" in s and '"' not in s:
        s = s.replace("'", '"')

    try:
        return json.loads(s)
    except Exception:
        return None

# ===== Prompt builder =====
def build_prompt(task: str, input_data):
    if task == "receipt_extract_text":
        return f"""
Extract receipt key fields from Vietnamese text.
Return ONLY valid JSON with fields:
vendor_name,address,date,total_amount,currency,confidence,flags

Receipt Text:
{input_data}
""".strip()

    if task == "invoice_extract_text":
        return f"""
Extract invoice fields from text.
Return ONLY valid JSON with fields:
vendor_name,invoice_no,date,subtotal,tax,total,currency,confidence,flags

Invoice Text:
{input_data}
""".strip()

    if task == "journal_from_structured_txn":
        return f"""
You are an ERP accountant.
Given a structured transaction JSON, propose journal entries.
Return ONLY valid JSON with fields:
entries[{account,debit,credit,memo}],confidence,flags

Transaction:
{json.dumps(input_data, ensure_ascii=False)}
""".strip()

    raise ValueError("Unknown task")

print("Schemas + prompt builder ready.")


Schemas + prompt builder ready.


In [4]:
import pandas as pd

def infer_col(df, candidates):
    cols = {c.lower(): c for c in df.columns}
    for cand in candidates:
        if cand.lower() in cols:
            return cols[cand.lower()]
    # fuzzy contains
    for c in df.columns:
        low = c.lower()
        for cand in candidates:
            if cand.lower() in low:
                return c
    return None

def load_vn_mcocr_cases(limit=300):
    root = PATHS["vn_mcocr"]
    cases = []

    # Prefer CSV with gold labels if present
    csv_candidates = [
        os.path.join(root, "mcocr_train_df.csv"),
        os.path.join(root, "mcocr_val_sample_df.csv"),
        os.path.join(root, "results.csv"),
    ]
    for p in csv_candidates:
        if os.path.exists(p):
            df = pd.read_csv(p)

            text_col = infer_col(df, ["text", "ocr_text", "raw_text", "content", "transcription"])
            if text_col is None:
                # fallback pick longest string col
                str_cols = [c for c in df.columns if df[c].dtype == "object"]
                if str_cols:
                    text_col = max(str_cols, key=lambda c: df[c].astype(str).str.len().mean())

            seller_col = infer_col(df, ["seller", "vendor", "vendor_name", "merchant", "store", "shop"])
            addr_col   = infer_col(df, ["address", "seller_address", "vendor_address"])
            date_col   = infer_col(df, ["timestamp", "date", "datetime", "time"])
            total_col  = infer_col(df, ["total_cost", "total", "amount", "total_amount", "sum"])

            for i, row in df.head(limit).iterrows():
                raw_text = str(row[text_col]) if text_col else ""

                gold = None
                if seller_col or addr_col or date_col or total_col:
                    def safe_float(x):
                        try:
                            if pd.isna(x): 
                                return None
                            s = str(x)
                            s = re.sub(r"[^\d\.\-]", "", s)
                            return float(s) if s else None
                        except:
                            return None

                    gold = {
                        "vendor_name": str(row[seller_col]) if seller_col and pd.notna(row[seller_col]) else None,
                        "address": str(row[addr_col]) if addr_col and pd.notna(row[addr_col]) else None,
                        "date": str(row[date_col]) if date_col and pd.notna(row[date_col]) else None,
                        "total_amount": safe_float(row[total_col]) if total_col else None,
                        "currency": "VND",
                        "confidence": 0.0,
                        "flags": []
                    }

                cases.append({
                    "id": f"vn_mcocr_{i}",
                    "task": "receipt_extract_text",
                    "input": raw_text,
                    "gold": gold,
                    "meta": {"source": os.path.basename(p)}
                })
            return cases

    # fallback txt (OCR lines)
    txt_candidates = [
        os.path.join(root, "text_recognition_train_data.txt"),
        os.path.join(root, "text_recognition_val_data.txt"),
    ]
    for p in txt_candidates:
        if os.path.exists(p):
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                for idx, line in enumerate(f):
                    if idx >= limit:
                        break
                    parts = line.strip().split("\t")
                    raw_text = parts[-1] if parts else ""
                    cases.append({
                        "id": f"vn_mcocr_txt_{idx}",
                        "task": "receipt_extract_text",
                        "input": raw_text,
                        "gold": None,
                        "meta": {"source": os.path.basename(p)}
                    })
            return cases

    return []

def load_gl_cases(limit=200):
    xlsx_path = PATHS["gl_xlsx"]
    if not os.path.exists(xlsx_path):
        return []
    xls = pd.ExcelFile(xlsx_path)
    # take first sheet by default
    df = pd.read_excel(xlsx_path, sheet_name=xls.sheet_names[0])

    cases = []
    for i, row in df.head(limit).iterrows():
        txn = row.to_dict()
        cases.append({
            "id": f"gl_{i}",
            "task": "journal_from_structured_txn",
            "input": txn,
            "gold": None,
            "meta": {"sheet": xls.sheet_names[0]}
        })
    return cases

def load_invoice_ocr_cases(limit=200):
    """
    Robust loader:
    - If JSON/CSV annotations exist -> use their text fields
    - Otherwise use image paths (text-only LLM can't read images, but still valid for KD if you later OCR)
    """
    root = PATHS["invoice_ocr"]
    if not os.path.exists(root):
        return []

    ann_files = []
    for ext in ["*.json","*.csv"]:
        ann_files += glob(os.path.join(root, "**", ext), recursive=True)

    cases = []
    if ann_files:
        # take first annotation file found
        p = ann_files[0]
        if p.endswith(".csv"):
            df = pd.read_csv(p)
            text_col = infer_col(df, ["text","ocr","raw","content"])
            for i, row in df.head(limit).iterrows():
                raw_text = str(row[text_col]) if text_col else ""
                cases.append({
                    "id": f"invoice_ocr_csv_{i}",
                    "task": "invoice_extract_text",
                    "input": raw_text,
                    "gold": None,
                    "meta": {"ann": os.path.basename(p)}
                })
        else:
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                js = json.load(f)
            # try to find list items with "text"
            items = []
            if isinstance(js, list):
                items = js
            elif isinstance(js, dict):
                # common keys
                for k in ["data","items","annotations","samples"]:
                    if k in js and isinstance(js[k], list):
                        items = js[k]
                        break

            for i, it in enumerate(items[:limit]):
                raw_text = it.get("text") or it.get("ocr_text") or it.get("content") or ""
                cases.append({
                    "id": f"invoice_ocr_json_{i}",
                    "task": "invoice_extract_text",
                    "input": str(raw_text),
                    "gold": None,
                    "meta": {"ann": os.path.basename(p)}
                })

        return cases

    # fallback: use image paths (for later OCR pipeline)
    imgs = glob(os.path.join(root, "**", "*.png"), recursive=True) + glob(os.path.join(root, "**", "*.jpg"), recursive=True)
    for i, ip in enumerate(imgs[:limit]):
        cases.append({
            "id": f"invoice_ocr_img_{i}",
            "task": "invoice_extract_text",
            "input": f"[IMAGE_PATH]{ip}",
            "gold": None,
            "meta": {"img": os.path.basename(ip)}
        })
    return cases

print("Loaders ready.")


Loaders ready.


In [5]:
import numpy as np
import pandas as pd
from datetime import datetime, date

def _json_default(o):
    if isinstance(o, (pd.Timestamp, datetime, date)):
        return o.isoformat()
    if isinstance(o, (np.integer,)):
        return int(o)
    if isinstance(o, (np.floating,)):
        return float(o)
    if isinstance(o, (np.ndarray,)):
        return o.tolist()
    return str(o)

def write_jsonl(path, rows):
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False, default=_json_default) + "\n")


eval_cases = []
eval_cases += load_vn_mcocr_cases(limit=300)
eval_cases += load_invoice_ocr_cases(limit=200)
eval_cases += load_gl_cases(limit=150)

print("Total eval cases:", len(eval_cases))
eval_path = str(DATA_DIR / "eval_cases.jsonl")
write_jsonl(eval_path, eval_cases)
print("Saved:", eval_path)

# KD training uses the same pool (you can enlarge later)
kd_pool = eval_cases.copy()
random.shuffle(kd_pool)
kd_pool = kd_pool[:500]  # keep KD small for iteration speed
kd_pool_path = str(DATA_DIR / "kd_pool.jsonl")
write_jsonl(kd_pool_path, kd_pool)
print("Saved KD pool:", kd_pool_path, "| size:", len(kd_pool))


Total eval cases: 450
Saved: /kaggle/working/data/eval_cases.jsonl
Saved KD pool: /kaggle/working/data/kd_pool.jsonl | size: 450


In [8]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

bnb4 = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

def load_4bit_model(repo_or_path: str):
    tok = AutoTokenizer.from_pretrained(repo_or_path, use_fast=True, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    mdl = AutoModelForCausalLM.from_pretrained(
        repo_or_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb4,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    mdl.eval()
    return tok, mdl

def load_fingpt(adapter_repo: str, base_repo: str):
    tok, base = load_4bit_model(base_repo)
    mdl = PeftModel.from_pretrained(base, adapter_repo)
    mdl.eval()
    return tok, mdl

@torch.no_grad()
def generate_batch(tok, mdl, prompts, max_new_tokens=320):
    enc = tok(prompts, return_tensors="pt", padding=True, truncation=True).to(mdl.device)
    out = mdl.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=0.0,
        top_p=1.0,
        repetition_penalty=1.05,
    )
    return tok.batch_decode(out, skip_special_tokens=True)

print("Model utils ready.")


2026-01-21 03:44:45.971626: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768967086.492261     242 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768967086.646269     242 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768967087.831907     242 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768967087.831937     242 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768967087.831940     242 computation_placer.cc:177] computation placer alr

Model utils ready.


In [9]:
from huggingface_hub import login

HF_TOKEN = os.getenv("HF_TOKEN", None)
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("HF login OK")
else:
    print("HF_TOKEN missing -> gated teachers may fail (OpenFinance/Llama base).")

TEACH_CACHE_DIR = DATA_DIR / "teacher_outputs"
TEACH_CACHE_DIR.mkdir(parents=True, exist_ok=True)

def run_teacher(name, repo=None, adapter=None, base=None, batch_size=16):
    print(f"\n=== Teacher: {name} ===")
    try:
        if adapter and base:
            tok, mdl = load_fingpt(adapter, base)
        else:
            tok, mdl = load_4bit_model(repo)

        # load KD pool
        with open(kd_pool_path, "r", encoding="utf-8") as f:
            pool = [json.loads(x) for x in f]

        out_path = TEACH_CACHE_DIR / f"{name}.jsonl"
        with open(out_path, "w", encoding="utf-8") as fw:
            for i in range(0, len(pool), batch_size):
                batch = pool[i:i+batch_size]
                prompts = [build_prompt(ex["task"], ex["input"]) for ex in batch]
                texts = generate_batch(tok, mdl, prompts)

                for ex, t in zip(batch, texts):
                    fw.write(json.dumps({
                        "id": ex["id"],
                        "task": ex["task"],
                        "raw": t
                    }, ensure_ascii=False) + "\n")

        del mdl
        torch.cuda.empty_cache()
        print("Saved:", str(out_path))
        return str(out_path)
    except Exception as e:
        print("FAILED:", name, "| reason:", str(e))
        return None

teacher_paths = {}

# finance llama3 (public)
teacher_paths["finance_llama3_8b"] = run_teacher("finance_llama3_8b", repo=TEACHERS["finance_llama3_8b"])

# open finance (gated)
teacher_paths["open_finance_8b"] = run_teacher("open_finance_8b", repo=TEACHERS["open_finance_8b"])

# fingpt lora (needs base llama3 gated)
teacher_paths["fingpt_lora_llama3_8b"] = run_teacher(
    "fingpt_lora_llama3_8b",
    adapter=TEACHERS["fingpt_lora_llama3_8b"],
    base=FINGPT_BASE
)

print("\nTeacher paths:", teacher_paths)


HF_TOKEN missing -> gated teachers may fail (OpenFinance/Llama base).

=== Teacher: finance_llama3_8b ===


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/705 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading shards:   0%|          | 0/7 [00:00<?, ?it/s]

model-00001-of-00007.safetensors:   0%|          | 0.00/4.89G [00:00<?, ?B/s]

model-00002-of-00007.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00003-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00005-of-00007.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00006-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00007-of-00007.safetensors:   0%|          | 0.00/2.57G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/177 [00:00<?, ?B/s]

FAILED: finance_llama3_8b | reason: name 'account' is not defined

=== Teacher: open_finance_8b ===
FAILED: open_finance_8b | reason: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/DragonLLM/Llama-Open-Finance-8B.
401 Client Error. (Request ID: Root=1-69704bfe-00fbec8f3132e0df050f49de;1ea2811f-0d02-438b-b27d-665d9c29c577)

Cannot access gated repo for url https://huggingface.co/DragonLLM/Llama-Open-Finance-8B/resolve/main/config.json.
Access to model DragonLLM/Llama-Open-Finance-8B is restricted. You must have access to it and be authenticated to access it. Please log in.

=== Teacher: fingpt_lora_llama3_8b ===
FAILED: fingpt_lora_llama3_8b | reason: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3-8B.
401 Client Error. (Request ID: Root=1-69704bfe-130b44fe3fef28350127adda;d60bda15-af48-4ea2-adef-e59b2c776d6d)

Cannot access gated repo for url https://huggingface.co/m

In [10]:
def load_teacher_outputs(path):
    out = {}
    if not path:
        return out
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            r = json.loads(line)
            out[r["id"]] = r
    return out

teacher_outputs = {k: load_teacher_outputs(v) for k,v in teacher_paths.items() if v}

def router_priority(task):
    if task in ["receipt_extract_text", "invoice_extract_text"]:
        return ["open_finance_8b", "finance_llama3_8b", "fingpt_lora_llama3_8b"]
    if task == "journal_from_structured_txn":
        return ["finance_llama3_8b", "open_finance_8b", "fingpt_lora_llama3_8b"]
    return ["finance_llama3_8b", "open_finance_8b", "fingpt_lora_llama3_8b"]

distilled = []
dropped = 0
picked = {}

with open(kd_pool_path, "r", encoding="utf-8") as f:
    pool = [json.loads(x) for x in f]

for ex in pool:
    task = ex["task"]
    cid  = ex["id"]

    chosen_obj = None
    chosen_teacher = None

    for tname in router_priority(task):
        if tname not in teacher_outputs:
            continue
        rec = teacher_outputs[tname].get(cid)
        if not rec:
            continue

        raw = rec["raw"]
        obj = extract_json_from_text(raw)
        if obj is None:
            obj = json_repair_minimal(raw)

        if obj is not None and schema_pass(task, obj):
            chosen_obj = obj
            chosen_teacher = tname
            break

    if chosen_obj is None:
        dropped += 1
        continue

    picked[chosen_teacher] = picked.get(chosen_teacher, 0) + 1
    distilled.append({
        "id": cid,
        "task": task,
        "prompt": build_prompt(task, ex["input"]),
        "answer_json": chosen_obj
    })

print("KD distilled:", len(distilled), "| dropped:", dropped)
print("picked_by_teacher:", picked)

distill_path = str(DATA_DIR / "distilled_train.jsonl")
with open(distill_path, "w", encoding="utf-8") as f:
    for r in distilled:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print("Saved:", distill_path)


KD distilled: 0 | dropped: 450
picked_by_teacher: {}
Saved: /kaggle/working/data/distilled_train.jsonl


In [11]:
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
from transformers import TrainingArguments

assert QWEN32B_PATH is not None, "Cannot find Qwen3-32B in /kaggle/input"

# Load Qwen base
qwen_tok, qwen_base = load_4bit_model(QWEN32B_PATH)

# Auto-target modules (robust across architectures)
def guess_lora_targets(model):
    names = set()
    for n, _ in model.named_modules():
        if any(k in n for k in ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]):
            names.add(n.split(".")[-1])
    # fallback default
    if not names:
        return ["q_proj","k_proj","v_proj","o_proj"]
    return sorted(list(names))

targets = guess_lora_targets(qwen_base)
print("LoRA targets:", targets)

lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=targets
)

student = get_peft_model(qwen_base, lora_cfg)
student.print_trainable_parameters()

# Prepare dataset
rows = []
with open(distill_path, "r", encoding="utf-8") as f:
    for line in f:
        r = json.loads(line)
        rows.append({
            "text": r["prompt"] + "\n\n" + json.dumps(r["answer_json"], ensure_ascii=False)
        })

train_ds = Dataset.from_list(rows)

args = TrainingArguments(
    output_dir=str(WORKDIR / "student_ckpt"),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    num_train_epochs=1,
    logging_steps=10,
    save_steps=200,
    bf16=True,
    optim="paged_adamw_8bit",
    report_to="none"
)

trainer = SFTTrainer(
    model=student,
    tokenizer=qwen_tok,
    train_dataset=train_ds,
    args=args,
    max_seq_length=2048,
)

trainer.train()

ADAPTER_DIR = str(WORKDIR / "student_adapter")
trainer.model.save_pretrained(ADAPTER_DIR)
qwen_tok.save_pretrained(ADAPTER_DIR)

print("Saved student adapter:", ADAPTER_DIR)




ValueError: The checkpoint you are trying to load has model type `qwen3` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.

In [None]:
def normalize_obj(obj):
    if obj is None:
        return None
    # stable dump for determinism comparison
    return json.dumps(obj, ensure_ascii=False, sort_keys=True)

def field_exact(gold, pred, key):
    if gold is None or pred is None:
        return None
    if gold.get(key) is None:
        return None
    return 1.0 if str(gold.get(key)).strip() == str(pred.get(key)).strip() else 0.0

def eval_model(tok, mdl, cases, repeats=3, batch_size=8):
    stats = {
        "json_valid_rate": 0,
        "schema_pass_rate": 0,
        "determinism_rate": 0,
        "n": len(cases),
        "field_vendor_acc": [],
        "field_total_acc": [],
        "field_date_acc": [],
    }

    det_same = 0
    valid = 0
    schema_ok = 0

    for i in range(0, len(cases), batch_size):
        batch = cases[i:i+batch_size]
        prompts = [build_prompt(ex["task"], ex["input"]) for ex in batch]

        # determinism: run repeats times
        outputs_all = []
        for _ in range(repeats):
            texts = generate_batch(tok, mdl, prompts, max_new_tokens=320)
            objs = []
            for ex, t in zip(batch, texts):
                obj = extract_json_from_text(t) or json_repair_minimal(t)
                objs.append(obj)
            outputs_all.append(objs)

        # compute per-sample stats
        for j, ex in enumerate(batch):
            task = ex["task"]
            gold = ex.get("gold")

            # use first run as "pred"
            pred = outputs_all[0][j]

            if pred is not None:
                valid += 1
                if schema_pass(task, pred):
                    schema_ok += 1

            # determinism check: all normalized equal
            norms = [normalize_obj(outputs_all[r][j]) for r in range(repeats)]
            if len(set(norms)) == 1:
                det_same += 1

            # Tier B field acc if gold exists & relevant
            if gold and isinstance(gold, dict) and task in ["receipt_extract_text","invoice_extract_text"]:
                v = field_exact(gold, pred, "vendor_name")
                d = field_exact(gold, pred, "date")
                # receipt: total_amount ; invoice: total
                if task == "receipt_extract_text":
                    tacc = field_exact(gold, pred, "total_amount")
                else:
                    tacc = field_exact(gold, pred, "total")

                if v is not None: stats["field_vendor_acc"].append(v)
                if d is not None: stats["field_date_acc"].append(d)
                if tacc is not None: stats["field_total_acc"].append(tacc)

    n = max(1, stats["n"])
    stats["json_valid_rate"] = valid / n
    stats["schema_pass_rate"] = schema_ok / n
    stats["determinism_rate"] = det_same / n

    def avg(x):
        return float(sum(x)/len(x)) if x else None

    stats["vendor_acc"] = avg(stats["field_vendor_acc"])
    stats["date_acc"] = avg(stats["field_date_acc"])
    stats["total_acc"] = avg(stats["field_total_acc"])

    # cleanup arrays
    stats.pop("field_vendor_acc", None)
    stats.pop("field_date_acc", None)
    stats.pop("field_total_acc", None)

    return stats

# Load eval cases
with open(eval_path, "r", encoding="utf-8") as f:
    eval_cases = [json.loads(x) for x in f]

# BASE QWEN
base_tok, base_mdl = load_4bit_model(QWEN32B_PATH)
base_stats = eval_model(base_tok, base_mdl, eval_cases, repeats=3)
print("BASE:", base_stats)

# STUDENT = base + adapter
from peft import PeftModel
student_tok, student_base = load_4bit_model(QWEN32B_PATH)
student_mdl = PeftModel.from_pretrained(student_base, ADAPTER_DIR)
student_mdl.eval()
student_stats = eval_model(student_tok, student_mdl, eval_cases, repeats=3)
print("STUDENT:", student_stats)

report = {
    "base_qwen3_32b": base_stats,
    "student_qwen3_32b_adapter": student_stats,
}

report_path = str(DATA_DIR / "eval_report.json")
with open(report_path, "w", encoding="utf-8") as f:
    json.dump(report, f, ensure_ascii=False, indent=2)

print("Saved report:", report_path)
