## Inference

In [18]:
import json
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForTokenClassification
from peft import PeftModel

# -------------------------------
# 1️⃣ Paths
# -------------------------------
base_model_name = "microsoft/layoutlmv3-base"
lora_checkpoint = "../models/layoutlmv3-finetuned-lora/checkpoint-2490"
train_json_path = "../data/SROIE2019/train/train.json"

img_path = "../data/SROIE2019/train/img/X00016469669.jpg"
ocr_path = "../data/SROIE2019/train/box/X00016469669.txt"

# -------------------------------
# 2️⃣ Load label map from training JSON
# -------------------------------
with open(train_json_path, "r", encoding="utf-8") as f:
    train_data = json.load(f)

all_labels = sorted({label for record in train_data for label in record.get("labels", [])})
id2label = {i: label for i, label in enumerate(all_labels)}
label2id = {label: i for i, label in id2label.items()}

print("Label map:", id2label)

# -------------------------------
# 3️⃣ Load base model + LoRA adapter and merge
# -------------------------------
base_model = AutoModelForTokenClassification.from_pretrained(
    base_model_name,
    num_labels=len(id2label)
)
base_model.config.id2label = id2label
base_model.config.label2id = label2id

# Load adapter
model = PeftModel.from_pretrained(base_model, lora_checkpoint)

# Merge LoRA weights into base model for standard inference
model = model.merge_and_unload()
model.eval()

# -------------------------------
# 4️⃣ Load processor
# -------------------------------
processor = AutoProcessor.from_pretrained(base_model_name, apply_ocr=False)

# -------------------------------
# 5️⃣ OCR Parsing
# -------------------------------
def parse_ocr_file(txt_path):
    words, bboxes = [], []
    with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            parts = line.strip().split(",")
            if len(parts) < 9:
                continue
            coords = list(map(int, parts[:8]))
            text = ",".join(parts[8:]).strip()
            if not text:
                continue
            xs, ys = coords[::2], coords[1::2]
            bbox = [min(xs), min(ys), max(xs), max(ys)]
            words.append(text)
            bboxes.append(bbox)
    return words, bboxes

def normalize_bbox(bbox, image_w, image_h):
    x0, y0, x1, y1 = bbox
    return [
        int(1000 * (x0 / image_w)),
        int(1000 * (y0 / image_h)),
        int(1000 * (x1 / image_w)),
        int(1000 * (y1 / image_h)),
    ]

def merge_tokens(tokens, labels):
    words, word_labels = [], []
    current_word, current_label = "", None
    for token, label in zip(tokens, labels):
        if token.startswith("Ġ"):
            if current_word:
                words.append(current_word.strip())
                word_labels.append(current_label)
            current_word = token[1:]
            current_label = label
        else:
            current_word += token
    if current_word:
        words.append(current_word.strip())
        word_labels.append(current_label)
    return list(zip(words, word_labels))

# -------------------------------
# 6️⃣ Load image + OCR
# -------------------------------
image = Image.open(img_path).convert("RGB")
words, bboxes = parse_ocr_file(ocr_path)
img_w, img_h = image.size
normalized_bboxes = [normalize_bbox(b, img_w, img_h) for b in bboxes]

# -------------------------------
# 7️⃣ Encode inputs
# -------------------------------
encoding = processor(
    image,
    words,
    boxes=normalized_bboxes,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=512
)

# -------------------------------
# 8️⃣ Inference
# -------------------------------
with torch.no_grad():
    outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    tokens = processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"].squeeze())

pred_labels = [model.config.id2label[p] for p in predictions]
merged = merge_tokens(tokens, pred_labels)

# -------------------------------
# 9️⃣ Print clean results
# -------------------------------
print("Detected entities:")
for word, label in merged:
    if label != "O":
        print(f"{word}: {label}")


Some weights of LayoutLMv3ForTokenClassification were not initialized from the model checkpoint at microsoft/layoutlmv3-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Label map: {0: 'B-INVOICE_ID', 1: 'I-INVOICE_ID', 2: 'O'}
Detected entities:
<s>: None
TAN: I-INVOICE_ID
CHAY: I-INVOICE_ID
YEE: I-INVOICE_ID
HO: I-INVOICE_ID
TRADING: B-INVOICE_ID
NO.2&4,JALAN: I-INVOICE_ID
HARMONI: I-INVOICE_ID
3/2,: B-INVOICE_ID
TAMAN: I-INVOICE_ID
DESA: I-INVOICE_ID
HARMONI.: I-INVOICE_ID
81100: I-INVOICE_ID
JOHOR: B-INVOICE_ID
BAHRU: B-INVOICE_ID
JOHOR: B-INVOICE_ID
07-355: I-INVOICE_ID
2616: B-INVOICE_ID
CASH: I-INVOICE_ID
BILL: I-INVOICE_ID
:: I-INVOICE_ID
01-143008: I-INVOICE_ID
DATE: I-INVOICE_ID
:: I-INVOICE_ID
09/01/2019: I-INVOICE_ID
8:01:11: I-INVOICE_ID
CASHIER: I-INVOICE_ID
:: I-INVOICE_ID
01: I-INVOICE_ID
DESCRIPTION: I-INVOICE_ID
QTY: I-INVOICE_ID
PRICE: I-INVOICE_ID
AMOUNT: I-INVOICE_ID
RM: I-INVOICE_ID
PLASTIC: I-INVOICE_ID
31.00: I-INVOICE_ID
31.00: I-INVOICE_ID
TOTAL: I-INVOICE_ID
AMOUNT:: I-INVOICE_ID
31.00: I-INVOICE_ID
CASH: I-INVOICE_ID
RECEIVED: I-INVOICE_ID
:: I-INVOICE_ID
101.00: I-INVOICE_ID
CHANGE: I-INVOICE_ID
:: I-INVOICE_ID
70.00: I-INV

In [15]:
# --- Parse OCR file ---
def parse_ocr_file(txt_path):
    words, bboxes = [], []
    with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            parts = line.strip().split(",")
            if len(parts) < 9:
                continue
            coords = list(map(int, parts[:8]))
            text = ",".join(parts[8:]).strip()
            if not text:
                continue
            xs, ys = coords[::2], coords[1::2]
            bbox = [min(xs), min(ys), max(xs), max(ys)]
            words.append(text)
            bboxes.append(bbox)
    return words, bboxes

def normalize_bbox(bbox, image_w, image_h):
    x0, y0, x1, y1 = bbox
    return [
        int(1000 * (x0 / image_w)),
        int(1000 * (y0 / image_h)),
        int(1000 * (x1 / image_w)),
        int(1000 * (y1 / image_h)),
    ]

def merge_tokens(tokens, labels):
    words, word_labels = [], []
    current_word, current_label = "", None
    for token, label in zip(tokens, labels):
        if token.startswith("Ġ"):
            if current_word:
                words.append(current_word.strip())
                word_labels.append(current_label)
            current_word = token[1:]
            current_label = label
        else:
            current_word += token
    if current_word:
        words.append(current_word.strip())
        word_labels.append(current_label)
    return list(zip(words, word_labels))

# --- Load image + OCR ---
image = Image.open(img_path).convert("RGB")
words, bboxes = parse_ocr_file(ocr_path)
img_w, img_h = image.size
normalized_bboxes = [normalize_bbox(b, img_w, img_h) for b in bboxes]

# --- Encode ---
encoding = processor(
    image,
    words,
    boxes=normalized_bboxes,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=512
)

# --- Inference ---
with torch.no_grad():
    outputs = model(**encoding)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    tokens = processor.tokenizer.convert_ids_to_tokens(encoding["input_ids"].squeeze())

pred_labels = [model.config.id2label[p] for p in predictions]
merged = merge_tokens(tokens, pred_labels)

# --- Print results ---
for word, label in merged:
    if label != "O":
        print(f"{word}: {label}")




<s>: None
TAN: I-INVOICE_ID
WOON: I-INVOICE_ID
YANN: B-INVOICE_ID
INDAH: I-INVOICE_ID
GIFT: B-INVOICE_ID
&: B-INVOICE_ID
HOME: B-INVOICE_ID
DECO: I-INVOICE_ID
27,JALAN: B-INVOICE_ID
DEDAP: B-INVOICE_ID
13,: B-INVOICE_ID
TAMAN: I-INVOICE_ID
JOHOR: I-INVOICE_ID
JAYA,: I-INVOICE_ID
81100: B-INVOICE_ID
JOHOR: I-INVOICE_ID
BAHRU,JOHOR.: B-INVOICE_ID
TEL:07-3507405: I-INVOICE_ID
FAX:07-3558160: B-INVOICE_ID
RECEIPT: B-INVOICE_ID
19/10/2018: I-INVOICE_ID
20:49:59: I-INVOICE_ID
#01: I-INVOICE_ID
CASHIER:: I-INVOICE_ID
CN: I-INVOICE_ID
LOCATION/SP:: B-INVOICE_ID
05: B-INVOICE_ID
/0531: B-INVOICE_ID
MB:: I-INVOICE_ID
MO26588: I-INVOICE_ID
ROOM: B-INVOICE_ID
NO:: I-INVOICE_ID
01: I-INVOICE_ID
050100035279: B-INVOICE_ID
DESC/ITEM: I-INVOICE_ID
QTY: I-INVOICE_ID
PRICE: B-INVOICE_ID
AMT(RM): I-INVOICE_ID
ST-PRIVILEGE: I-INVOICE_ID
CARD/GD: I-INVOICE_ID
INDAH: B-INVOICE_ID
88888: I-INVOICE_ID
1: I-INVOICE_ID
10.00: I-INVOICE_ID
10.00: I-INVOICE_ID
GF-TABLE: B-INVOICE_ID
LAMP/STITCH: I-INVOICE_ID
<I>:

In [10]:
id2label = {i: label for i, label in enumerate(all_labels)}
label2id = {label: i for i, label in id2label.items()}

print("id2label:", id2label)
print("label2id:", label2id)


id2label: {0: 'B-INVOICE_ID', 1: 'I-INVOICE_ID', 2: 'O'}
label2id: {'B-INVOICE_ID': 0, 'I-INVOICE_ID': 1, 'O': 2}
