# preprocess one 

# cell 1

In [None]:
# 🆕 put this in a fresh cell, at the TOP of a new session
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch

# Hub id of the pretrained Donut model; only the processor is loaded from it here.
BASE_MODEL = "naver-clova-ix/donut-base"   # same base you used before

# Image processor + tokenizer bundle, taken from the base model id.
processor = DonutProcessor.from_pretrained(BASE_MODEL)
# Weights are restored from the directory that the phase-1 training loop writes
# with trainer.save_model(f"{out_dir}/checkpoint_final") for step 3.
model     = VisionEncoderDecoderModel.from_pretrained(
              "./donut_phase1_step3/checkpoint_final")   # final phase-1 checkpoint
# Move the model to the GPU when CUDA is available, otherwise keep it on the CPU.
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))


# cell 2

In [1]:
import cv2, numpy as np
from pathlib import Path
from tqdm import tqdm

def deskew(img, limit=5):
    """Estimate the skew of a page image and rotate it upright if the skew is small.

    The foreground is found with a blurred, inverted Otsu threshold; the angle
    of the minimum-area rectangle around all foreground pixels is taken as the
    page skew.

    Args:
        img: BGR image array (as returned by ``cv2.imread``).
        limit: Maximum absolute skew, in degrees, that will be corrected.
            Larger estimates leave the image unrotated.

    Returns:
        ``(image, angle)``. ``angle`` is the estimated correction in degrees,
        normalised to (-45, 45]; it is reported even when no rotation was
        applied and is ``0.0`` when the image has no foreground pixels.
    """
    g = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blur = cv2.GaussianBlur(g, (9, 9), 0)
    th = cv2.threshold(blur, 0, 255,
                       cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
    coords = np.column_stack(np.where(th > 0))
    if coords.shape[0] == 0:
        # Blank page: there is nothing to fit a rectangle to.
        return img, 0.0

    raw_angle = cv2.minAreaRect(coords)[-1]
    # OpenCV releases differ in the range they report this angle in
    # ([-90, 0) in older builds, (0, 90] in newer ones).  A rectangle's edge
    # orientation is only defined modulo 90 degrees, so fold the value first;
    # for the older range this yields exactly the same result as
    # `-(90 + a) if a < -45 else -a`.
    folded = raw_angle % 90
    angle = -folded if folded < 45 else 90 - folded

    if abs(angle) < limit:
        (h, w) = img.shape[:2]
        M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
        img = cv2.warpAffine(img, M, (w, h),
                             flags=cv2.INTER_CUBIC,
                             borderMode=cv2.BORDER_REPLICATE)
    return img, angle

def clahe(img):
    """Apply CLAHE to the lightness channel of a BGR image and return BGR.

    The image is converted to LAB, only the L channel is equalised
    (clip limit 2.0, 8x8 tiles), and the result is converted back to BGR.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2LAB))
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    merged = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)

def resize_long(img, target=1100):
    """Resize `img` so that its longer side is `target` pixels, keeping aspect ratio.

    Args:
        img: Image array whose first two dimensions are (height, width).
        target: Desired length of the longer side, in pixels.

    Returns:
        The resized image, or `img` itself (same object) when the longer side
        already equals `target`.
    """
    h, w = img.shape[:2]
    longest = max(h, w)
    if longest == target:
        return img

    scale = target / longest
    # round() instead of int(): truncation can land one pixel short of the
    # target because of float error; max(1, ...) keeps extreme aspect ratios
    # from producing a zero-sized dimension.
    new_size = (max(1, round(w * scale)), max(1, round(h * scale)))
    # INTER_AREA is meant for shrinking; when enlarging it degrades to
    # nearest-neighbour-like output, so use cubic interpolation there.
    interpolation = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
    return cv2.resize(img, new_size, interpolation=interpolation)

def preprocess_one(path):
    """Clean a single image: deskew, CLAHE, resize, then save as JPEG.

    NOTE: a later cell redefines `preprocess_one` with an extra `show`
    argument; when both cells are run, that later definition is the one in effect.

    Args:
        path: Path (``str`` or ``pathlib.Path``) of the image to clean.

    Returns:
        ``(out, ang)``: the path of the written ``<name>.clean.jpg`` file
        (the last suffix of `path` is replaced) and the skew angle reported
        by `deskew`.

    Raises:
        OSError: If the image cannot be read or the result cannot be written.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    img = cv2.imread(str(path))
    if img is None:
        # cv2.imread signals a missing/undecodable file by returning None.
        raise OSError(f"Could not read image: {path}")

    img, ang = deskew(img)
    img = clahe(img)
    img = resize_long(img)

    out = path.with_suffix('.clean.jpg')
    if not cv2.imwrite(str(out), img, [cv2.IMWRITE_JPEG_QUALITY, 95]):
        raise OSError(f"Could not write image: {out}")
    return out, ang

def batch_preprocess(dir_):
    """Run `preprocess_one` on every ``*.png`` / ``*.jpg`` directly inside `dir_`.

    Files ending in ``.clean.jpg`` are skipped: they are outputs of a previous
    run written next to their sources, and cleaning them again would produce
    ``.clean.clean.jpg`` copies.

    Args:
        dir_: Directory (``str`` or ``pathlib.Path``) holding the raw images.
    """
    dir_ = Path(dir_)
    candidates = list(dir_.glob('*.png')) + list(dir_.glob('*.jpg'))
    files = sorted(p for p in candidates if not p.name.endswith('.clean.jpg'))
    for p in tqdm(files):
        preprocess_one(p)
    print("✅ Done cleaning")


# cell 3

In [None]:
import matplotlib.pyplot as plt

def preprocess_one(path, show=False):
    """Clean a single image (deskew, CLAHE, resize), save it, optionally preview it.

    Args:
        path: Path (``str`` or ``pathlib.Path``) of the image to clean.
        show: When true, display the cleaned image with matplotlib, titled
            with the skew angle reported by `deskew`.

    Returns:
        ``(out, ang)``: the path of the written ``<name>.clean.jpg`` file
        (the last suffix of `path` is replaced) and the skew angle reported
        by `deskew`.

    Raises:
        OSError: If the image cannot be read or the result cannot be written.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    img = cv2.imread(str(path))
    if img is None:
        # cv2.imread signals a missing/undecodable file by returning None.
        raise OSError(f"Could not read image: {path}")

    img, ang = deskew(img)
    img = clahe(img)
    img = resize_long(img)

    out = path.with_suffix('.clean.jpg')
    if not cv2.imwrite(str(out), img, [cv2.IMWRITE_JPEG_QUALITY, 95]):
        raise OSError(f"Could not write image: {out}")

    if show:
        # OpenCV arrays are BGR; matplotlib expects RGB.
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.title(f"After preprocess (rotate {ang:.2f}°)")
        plt.axis('off')
        plt.show()
    return out, ang

# --- demo ---
# preprocess_one(Path('sample_invoice.jpg'), show=True)


# data augmentation

# cell 4

In [None]:
from albumentations import (
    Compose, Rotate, RandomBrightnessContrast,
    Perspective, MotionBlur, JpegCompression
)

# Augmentation pipeline applied to BGR arrays: small rotation, mild perspective
# warp, brightness/contrast jitter, light motion blur and JPEG re-compression.
# Each transform fires independently with its own probability `p`.
AUG = Compose(
    transforms=[
        Rotate(limit=3, border_mode=cv2.BORDER_REPLICATE, p=0.7),
        Perspective(scale=(0.02, 0.05), p=0.3),
        RandomBrightnessContrast(brightness_limit=0.1,
                                 contrast_limit=0.1,
                                 p=0.3),
        MotionBlur(blur_limit=3, p=0.2),
        JpegCompression(quality_lower=60, quality_upper=95, p=0.3),
    ],
)

def augment_and_save(path, n=3, out_dir=None):
    """Write `n` randomly augmented JPEG copies of the image at `path`.

    Copies are named ``<stem>_aug<i>.jpg``.

    Args:
        path: Path (``str`` or ``pathlib.Path``) of the source image.
        n: Number of augmented copies to generate.
        out_dir: Directory to write into; created if missing. ``None``
            (default) keeps the original behaviour of writing into the
            current working directory.

    Raises:
        OSError: If the source cannot be read or a copy cannot be written.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    img = cv2.imread(str(path))
    if img is None:
        # cv2.imread signals a missing/undecodable file by returning None.
        raise OSError(f"Could not read image: {path}")

    if out_dir is None:
        target_dir = Path()
    else:
        target_dir = Path(out_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

    for i in range(n):
        aug = AUG(image=img)['image']
        out = target_dir / f"{path.stem}_aug{i}.jpg"
        if not cv2.imwrite(str(out), aug):
            raise OSError(f"Could not write image: {out}")


# cell 5

In [None]:
import sentencepiece as spm

def train_tokenizer(corpus_txt,
                    model_prefix='vi_invoice',
                    vocab_size=32000):
    """Train a SentencePiece model on `corpus_txt` and print the model file name.

    Args:
        corpus_txt: Training corpus, passed to the trainer as ``input``.
        model_prefix: Prefix handed to the trainer; the printed file name is
            ``model_prefix + ".model"``.
        vocab_size: Target vocabulary size.
    """
    # Fixed ids for the control tokens.
    special_ids = dict(pad_id=0, unk_id=1, bos_id=2, eos_id=3)
    # Symbols registered as user-defined pieces.
    extra_symbols = ['<none>', '₫', 'VNĐ', '<address>', '<taxcode>']

    spm.SentencePieceTrainer.Train(
        input=corpus_txt,
        model_prefix=model_prefix,
        vocab_size=vocab_size,
        character_coverage=0.9995,
        user_defined_symbols=extra_symbols,
        **special_ids,
    )
    saved_as = model_prefix + ".model"
    print("✅ Tokenizer saved:", saved_as)


# cell 6

In [None]:
import re, dateparser
from Levenshtein import distance as ld

# Matches a standalone run of exactly 10 digits, optionally followed by 3 more
# (13 in total), delimited by word boundaries on both sides.
# NOTE(review): presumably a Vietnamese tax code (MST); a hyphenated
# "10 digits-3 digits" form would only match its first 10 digits — confirm
# whether that form needs to be supported.
MST = re.compile(r'\b\d{10}(?:\d{3})?\b')

def normalize_amount(txt):
    """Reduce an amount string to its bare digits.

    Commas are treated like dots, spaces are removed, and the first run that
    starts with a digit and consists of digits and dots is returned with the
    dots stripped (``"1.250.000 VND"`` -> ``"1250000"``).

    Args:
        txt: Raw amount text; may be ``None`` or empty.

    Returns:
        The digit string, or `txt` unchanged when it is empty/``None`` or
        contains no digit.
    """
    if not txt:
        return txt
    compact = txt.replace(',', '.').replace(' ', '')
    # Require a leading digit so stray punctuation ("Amt. due ...") is not
    # picked up as the number and collapsed to an empty string.
    m = re.search(r'\d[\d.]*', compact)
    return m.group(0).replace('.', '') if m else txt

def fix_taxcode(code):
    """Return the first substring of `code` matched by `MST`.

    When nothing matches (including when `code` is empty or ``None``),
    `code` is returned unchanged.
    """
    found = MST.search(code if code else '')
    if found is None:
        return code
    return found.group(0)

def parse_date(txt):
    """Normalise a date string to ``YYYY-MM-DD``, reading it day-first.

    Args:
        txt: Raw date text; may be ``None`` or empty.

    Returns:
        The ISO-formatted date, or `txt` unchanged when it is empty/``None``
        or cannot be parsed.
    """
    if not txt:
        # Same tolerance for missing values as fix_taxcode; avoids handing
        # None to the parser.
        return txt
    parsed = dateparser.parse(txt, settings={'DATE_ORDER': 'DMY'})
    return parsed.strftime('%Y-%m-%d') if parsed else txt


# cell 7

In [None]:
# Phase-1 curriculum folders (64 invoices each).
# One entry per training step, in training order: the directory of cleaned
# images and the JSON-lines label file that belongs to it.
PHASES = [
    {
        'img_dir': 'data/folderA/clean_images',
        'label_file': 'data/folderA/train_labels.jsonl',
    },
    {
        'img_dir': 'data/folderB/clean_images',
        'label_file': 'data/folderB/train_labels.jsonl',
    },
    {
        'img_dir': 'data/folderC/clean_images',
        'label_file': 'data/folderC/train_labels.jsonl',
    },
]

from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch

# Hub id of the pretrained Donut model that phase-1 training starts from.
BASE_MODEL = "naver-clova-ix/donut-base"

# Processor and model are both loaded fresh from the base id, so running this
# cell replaces any `processor` / `model` objects created by an earlier cell.
processor = DonutProcessor.from_pretrained(BASE_MODEL)
model      = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL)
# NOTE(review): training with `labels` needs model.config.pad_token_id and
# model.config.decoder_start_token_id to be set — confirm the base config
# provides them, or set them here before training.
# Move the model to the GPU when CUDA is available, otherwise keep it on the CPU.
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))


# cell 8

In [None]:
from datasets import load_dataset
from albumentations.pytorch import ToTensorV2  # keep if you use AUG on-the-fly

def build_dataset(img_dir: str, label_file: str,
                  max_len: int = 600, augment=False):
    """Build a torch-formatted dataset of pixel values and label ids.

    Each record of the JSON-lines `label_file` must provide ``file_name``
    (image file inside `img_dir`) and ``label`` (target text).

    Args:
        img_dir: Directory containing the images named in the label file.
        label_file: Path of the JSON-lines label file.
        max_len: Label sequences are truncated and padded to this length.
        augment: When true, pass each image through `AUG` once while the
            dataset is being mapped (not re-sampled per epoch).

    Returns:
        A dataset with ``pixel_values`` and ``labels`` columns in torch
        format. Padding positions in ``labels`` are set to -100.

    Raises:
        OSError: If an image listed in the label file cannot be read.
    """
    pad_id = processor.tokenizer.pad_token_id

    def proc_example(ex):
        # read → optional aug → pixel values
        img_path = f"{img_dir}/{ex['file_name']}"
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/undecodable file by returning None.
            raise OSError(f"Could not read image: {img_path}")
        if augment:
            # AUG contains no tensor conversion, so its output is already an
            # HxWxC numpy array; no permute/numpy round-trip is needed.
            img = AUG(image=img)["image"]
        # OpenCV loads BGR; the image processor expects RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        px = processor.image_processor(img, return_tensors="pt").pixel_values[0]

        ids = processor.tokenizer(
            ex["label"],
            add_special_tokens=False,
            max_length=max_len,
            padding="max_length",  # equal lengths so examples can be stacked into a batch
            truncation=True,
        ).input_ids
        # -100 is the index ignored by the loss, so padding is not trained on.
        labels = [-100 if tok == pad_id else tok for tok in ids]
        return {"pixel_values": px, "labels": labels}

    ds = load_dataset("json", data_files=label_file)["train"]
    ds = ds.map(proc_example, remove_columns=ds.column_names)
    ds.set_format("torch")
    return ds


# cell 9

In [None]:
from transformers import Trainer, TrainingArguments
import os, math, datetime as dt

from transformers.trainer_utils import get_last_checkpoint

# Train the same `model` object on each curriculum folder in turn; every step
# continues from the weights left by the previous one.
for i, phase in enumerate(PHASES, 1):
    print(f"\n🟩 Phase-1 / step {i}  —  {phase['img_dir']}")
    ds = build_dataset(phase["img_dir"], phase["label_file"])

    out_dir = f"./donut_phase1_step{i}"
    args = TrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        num_train_epochs=3,
        learning_rate=2e-5,
        # A fixed 200-step warmup is longer than a whole step here (about 64
        # images / effective batch 8 * 3 epochs ≈ 24 optimiser steps), so the
        # learning rate would never reach its target; scale warmup to the run.
        warmup_ratio=0.1,
        # Mixed precision requires a GPU; the model falls back to CPU when
        # CUDA is unavailable, so only enable fp16 when it can actually run.
        fp16=torch.cuda.is_available(),
        logging_steps=50,
        save_total_limit=2,
        report_to="none",
    )

    trainer = Trainer(model=model, args=args, train_dataset=ds)

    # Resuming has to be requested on train(); the TrainingArguments field of
    # the same name is not acted on by Trainer.  get_last_checkpoint returns
    # None when the folder holds no "checkpoint-<step>" directory, in which
    # case training simply starts from the current weights.
    last_ckpt = get_last_checkpoint(out_dir) if os.path.isdir(out_dir) else None
    trainer.train(resume_from_checkpoint=last_ckpt)

    # keep a clearly-named checkpoint for this step
    trainer.save_model(f"{out_dir}/checkpoint_final")


# cell 10

In [None]:
def predict_one(img_path):
    """Run the model on one image file and return the decoded token string.

    Args:
        img_path: Path (``str`` or ``pathlib.Path``) of the image to read.
            NOTE(review): the training folders are named ``clean_images``, so
            inputs presumably should be cleaned with `preprocess_one` first —
            confirm.

    Returns:
        The decoded output sequence for the image, special tokens included.

    Raises:
        OSError: If the image cannot be read.
    """
    img = cv2.imread(str(img_path))
    if img is None:
        # cv2.imread signals a missing/undecodable file by returning None.
        raise OSError(f"Could not read image: {img_path}")
    # OpenCV loads BGR; the image processor expects RGB channel order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # The file defines `processor` (there is no `proc`), and the image
    # processor needs pixel data rather than a file path.
    pv = processor.image_processor(img,
                                   return_tensors='pt').pixel_values.to(model.device)
    out = model.generate(pv, max_length=512)
    return processor.batch_decode(out, skip_special_tokens=False)[0]

# demo:
# pred = predict_one('sample_invoice.jpg')
# print(pred)
