# 🧾 Key-Value Extraction with LayoutLM
This notebook demonstrates how to preprocess the FUNSD dataset and fine-tune a LayoutLM model for key-value extraction.

## 📦 Install Required Libraries

In [None]:
!pip install transformers datasets seqeval layoutparser opencv-python Pillow scikit-learn

## 📁 Dataset Setup
Download and prepare the FUNSD dataset.

In [None]:
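# Manual step: download the FUNSD dataset from
# https://guillaumejaume.github.io/FUNSD/ and unzip it under data/, rearranging
# the extracted folders if needed so that the annotation JSON files sit in
# data/training_data/ and the page images in data/images/ (the paths the cells below read).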

## 📚 Load and Inspect FUNSD Annotations

In [None]:
import json
import os

# One FUNSD annotation file, loaded raw to eyeball its schema before preprocessing.
example_path = 'data/training_data/0000971160.json'
# JSON is UTF-8 by spec; don't depend on the platform's default encoding (e.g. cp1252 on Windows).
with open(example_path, encoding='utf-8') as f:
    data = json.load(f)

# Show the first two entries of the top-level 'form' list.
for item in data['form'][:2]:
    print(item)

## 🧾 Preprocess to Hugging Face Format

In [None]:
from data_preprocessing import parse_funsd_json

# Run the project's FUNSD parser on one annotation file and preview the first
# ten entries of its 'words', 'boxes' and 'labels' fields, one field per line.
parsed = parse_funsd_json('data/training_data/0000971160.json')
print(*(parsed[key][:10] for key in ('words', 'boxes', 'labels')), sep='\n')

## 🧠 Load LayoutLM and Tokenizer

In [None]:
from transformers import LayoutLMTokenizer, LayoutLMForTokenClassification, set_seed

# Base checkpoint for both tokenizer and model. The token-classification head is
# expected to be freshly initialized with one output per tag; NUM_LABELS must
# match the length of `label_list` in the fine-tuning cell.
MODEL_NAME = 'microsoft/layoutlm-base-uncased'
NUM_LABELS = 7
SEED = 42

# Seed before loading so the randomly initialized classification head is reproducible.
set_seed(SEED)
tokenizer = LayoutLMTokenizer.from_pretrained(MODEL_NAME)
model = LayoutLMForTokenClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)

## 📊 Load Multiple Examples with Dataset

In [None]:
from layoutlm_dataset import FUNSDDataset
import glob

# glob returns files in filesystem-dependent order; sort so the subset
# (and therefore dataset[0]) is the same on every run and machine.
files = sorted(glob.glob('data/training_data/*.json'))
if not files:
    raise FileNotFoundError("No FUNSD annotation files found under data/training_data/")
dataset = FUNSDDataset(files[:10])  # small subset for testing
sample = dataset[0]
# Show the shape of every tensor the dataset yields for one example.
print({k: v.shape for k, v in sample.items()})

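## 🏋️ Fine-tune Model
The next cell only configures the `Trainer`; uncomment `trainer.train()` at its end to actually run fine-tuning.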

In [None]:
from transformers import Trainer, TrainingArguments
import numpy as np
from seqeval.metrics import precision_score, recall_score, f1_score

label_list = ['O', 'B-KEY', 'I-KEY', 'B-VALUE', 'I-VALUE', 'B-OTHER', 'I-OTHER']
from transformers import default_data_collator

def compute_metrics(p):
    """Entity-level precision / recall / F1 for token-classification predictions.

    Args:
        p: an ``EvalPrediction`` or a ``(predictions, label_ids)`` tuple.
            ``predictions`` are 3-D logits whose last axis is the label
            dimension; ``label_ids`` are integer tag ids aligned with them,
            with -100 on positions that must be ignored.

    Returns:
        dict with the seqeval ``precision``, ``recall`` and ``f1`` scores.
    """
    # Index rather than unpack so an EvalPrediction that also carries inputs still works.
    logits, label_ids = p[0], p[1]
    # Some models return a tuple of outputs; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    pred_ids = np.argmax(logits, axis=2)

    true_predictions, true_labels = [], []
    for pred_row, label_row in zip(pred_ids, label_ids):
        # Keep only positions that carry a real label (-100 marks ignored tokens).
        kept = [(pred, lab) for pred, lab in zip(pred_row, label_row) if lab != -100]
        true_predictions.append([label_list[pred] for pred, _ in kept])
        true_labels.append([label_list[lab] for _, lab in kept])

    return {
        "precision": precision_score(true_labels, true_predictions),
        "recall": recall_score(true_labels, true_predictions),
        "f1": f1_score(true_labels, true_predictions),
    }

import torch

# fp16 mixed precision needs a GPU; requesting it unconditionally makes
# TrainingArguments raise on CPU-only machines, so enable it only with CUDA.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir="outputs/layoutlm",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=2,
    logging_dir="outputs/logs",
    logging_steps=10,
    # Evaluation is already off by default. `evaluation_strategy` was renamed
    # `eval_strategy` in newer transformers, so it is omitted rather than set to "no".
    save_strategy="no",
    fp16=use_fp16,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=None,  # no eval split yet; compute_metrics is wired up for when one is added
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

# To train:
# trainer.train()

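## 🧪 Inference and Visualization
Runs the model on the first dataset example and draws a box for every token not predicted as `O`. Unless `trainer.train()` was run above, the classification head is still untrained, so expect noisy labels.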

In [None]:
import os

import layoutparser as lp  # not used in this cell
import matplotlib.pyplot as plt
import cv2
import torch

sample = dataset[0]
# Constructing the Trainer may already have moved the model to GPU; put the
# inputs on whatever device the weights live on.
device = next(model.parameters()).device
input_ids = sample['input_ids'].unsqueeze(0).to(device)
attention_mask = sample['attention_mask'].unsqueeze(0).to(device)
bbox = sample['bbox'].unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, bbox=bbox)
# Select the single batch row explicitly; .squeeze() would collapse a
# length-1 sequence into a scalar and break the zip below.
predictions = outputs.logits.argmax(-1)[0].tolist()

# Load the page image belonging to the annotation file behind dataset[0]
# (assumes FUNSDDataset keeps the order of the file list it was given — confirm).
stem = os.path.splitext(os.path.basename(files[0]))[0]
img_path = os.path.join('data', 'images', stem + '.png')
image = cv2.imread(img_path)
if image is None:  # cv2.imread signals a missing/unreadable file with None, not an exception
    raise FileNotFoundError(f"Could not read image: {img_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# NOTE(review): if FUNSDDataset normalizes boxes to LayoutLM's 0-1000 range,
# they need rescaling by (width / 1000, height / 1000) before drawing — confirm.
token_mask = sample['attention_mask'].tolist()
for box, pred, keep in zip(sample['bbox'], predictions, token_mask):
    # Skip padding positions and tokens predicted as 'O' (label id 0).
    if not keep or pred == 0:
        continue
    x0, y0, x1, y1 = map(int, box.tolist())
    cv2.rectangle(image, (x0, y0), (x1, y1), (0, 255, 0), 1)
    cv2.putText(image, label_list[pred], (x0, y0 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.axis('off')
ax.set_title("Predicted Key-Value Labels")
plt.show()