In [None]:
!pip install -qqq transformers==4.27.2 --progress-bar off
!pip install -qqq pytorch-lightning==1.9.4 --progress-bar off
!pip install -qqq torchmetrics==0.11.4 --progress-bar off
!pip install -qqq imgkit==1.2.3 --progress-bar off
!pip install -qqq easyocr==1.6.2 --progress-bar off
!pip install -qqq Pillow==9.4.0 --progress-bar off
!pip install -qqq tensorboardX==2.5.1 --progress-bar off
# !pip install -qqq huggingface_hub==0.11.1 --progress-bar off
!pip install -qqq --upgrade --no-cache-dir gdown

In [None]:
from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from sklearn.model_selection import train_test_split
import imgkit
import easyocr
import torchvision.transforms as T
from pathlib import Path
import matplotlib.pyplot as plt
import os
import cv2
from typing import List
import json
from torchmetrics import Accuracy
# from huggingface_hub import notebook_login
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

%matplotlib inline
pl.seed_everything(42)

In [None]:
# Collect every page image (one sub-directory per document class) and display
# the first one as a sanity check. PIL's Image is imported in the setup cell.
image_paths = sorted(Path("images").glob("*/*.png"))
assert image_paths, "no PNG files found under images/<class>/"

image = Image.open(image_paths[0]).convert("RGB")
width, height = image.size
image

In [None]:
def create_bounding_box(bbox_data):
    """Collapse an EasyOCR quadrilateral (four ``[x, y]`` corners) into an
    axis-aligned ``[left, top, right, bottom]`` box of plain Python ints.

    The ``int`` conversion also keeps the coordinates JSON-serialisable
    (EasyOCR may hand back NumPy scalars).
    """
    xs = [point[0] for point in bbox_data]
    ys = [point[1] for point in bbox_data]
    return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]


# Build the reader once, outside the loop: constructing it loads the OCR models.
# NOTE(review): assumes English-language documents -- adjust the list if not.
reader = easyocr.Reader(["en"])

# OCR every page and cache the words/boxes as JSON next to the image (same
# stem, ".json" suffix); the dataset class reads these files later.
for image_path in tqdm(image_paths):
    ocr_result = reader.readtext(str(image_path), batch_size=16)

    ocr_page = [
        {"word": word, "bounding_box": create_bounding_box(bbox)}
        for bbox, word, _confidence in ocr_result
    ]

    with image_path.with_suffix(".json").open("w") as f:
        json.dump(ocr_page, f)

In [None]:
# Re-scan the page images (one sub-directory per class) in a stable, sorted order.
image_paths = sorted(Path("images").glob("*/*.png"))

In [None]:
# Words and boxes come from the OCR step above (cached as JSON per page), so
# the image side of the processor is told not to run its own OCR.
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
processor = LayoutLMv3Processor(feature_extractor, tokenizer)

In [None]:
print(f"""
input_ids:  {list(encoding["input_ids"].squeeze().shape)}
word boxes: {list(encoding["bbox"].squeeze().shape)}
image data: {list(encoding["pixel_values"].squeeze().shape)} 
image size: {image.size}
""")

In [None]:
encoding = processor(
    image, 
    words,
    boxes=boxes,
    max_length=512,
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)

outputs = model(**encoding)

In [None]:
# Hold out 20% of the pages for evaluation.
# - stratify on the class (parent directory) so each class keeps its share in
#   both splits (requires at least 2 pages per class);
# - fix random_state so the split does not depend on how much global RNG state
#   earlier cells consumed.
train_images, test_images = train_test_split(
    image_paths,
    test_size=.2,
    random_state=42,
    stratify=[path.parent.name for path in image_paths],
)

In [None]:
# Class names are the sub-directory names under images/, sorted so that label
# indices are stable between runs. Stray files at that level are ignored.
DOCUMENT_CLASSES = sorted(p.name for p in Path("images").glob("*") if p.is_dir())
DOCUMENT_CLASSES

In [None]:
def scale_bounding_box(
    box: List[int],
    width_scale: float = 1.0,
    height_scale: float = 1.0,
    min_value: int = 0,
    max_value: int = 1000,
) -> List[int]:
    """Scale an ``[x0, y0, x1, y1]`` box and clamp each coordinate.

    x-coordinates are multiplied by ``width_scale`` and y-coordinates by
    ``height_scale``. Each product is truncated to ``int`` and then clamped to
    ``[min_value, max_value]``.

    The default 0-1000 range is the coordinate grid LayoutLMv3's layout
    embeddings accept. OCR boxes can extend slightly past the page edge, and an
    out-of-range coordinate makes the model fail with an IndexError.

    Args:
        box: box as ``[x0, y0, x1, y1]`` (only the first four items are used).
        width_scale: multiplier for the x-coordinates.
        height_scale: multiplier for the y-coordinates.
        min_value: lower clamp bound for every coordinate.
        max_value: upper clamp bound for every coordinate.

    Returns:
        The scaled, clamped box as a list of four ints.
    """
    scaled = (
        box[0] * width_scale,
        box[1] * height_scale,
        box[2] * width_scale,
        box[3] * height_scale,
    )
    return [min(max(int(coord), min_value), max_value) for coord in scaled]

In [None]:
class DocumentClassificationDataset(Dataset):
    """Map-style dataset yielding LayoutLMv3 inputs for one document page.

    Each item pairs a page image with the OCR JSON stored next to it: same
    stem, ``.json`` suffix, a list of ``{"word", "bounding_box"}`` records.
    The label is the index of the page's parent-directory name in
    ``DOCUMENT_CLASSES``.

    Args:
        image_paths: sequence of ``pathlib.Path`` objects pointing at page images.
        processor: LayoutLMv3 processor used to build the model inputs. Because
            words and boxes are supplied here, it is expected to be built
            with ``apply_ocr=False``.
    """

    def __init__(self, image_paths, processor):
        self.image_paths = image_paths
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        image_path = self.image_paths[item]

        # Read the cached OCR output; the file is closed before the image work.
        with image_path.with_suffix(".json").open("r") as f:
            ocr_result = json.load(f)

        # `Image.open` owns the file handle, so it is what the `with` must
        # manage; `convert` returns an independent in-memory copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Boxes are normalised to LayoutLMv3's 0-1000 coordinate grid.
        width, height = image.size
        width_scale = 1000 / width
        height_scale = 1000 / height

        words = [row["word"] for row in ocr_result]
        boxes = [
            scale_bounding_box(row["bounding_box"], width_scale, height_scale)
            for row in ocr_result
        ]

        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        label = DOCUMENT_CLASSES.index(image_path.parent.name)

        # With return_tensors="pt" the processor returns a batch of one. Drop
        # that leading dimension so the DataLoader can stack items itself.
        return dict(
            input_ids=encoding["input_ids"].squeeze(0),
            attention_mask=encoding["attention_mask"].squeeze(0),
            bbox=encoding["bbox"].squeeze(0),
            pixel_values=encoding["pixel_values"].squeeze(0),
            labels=torch.tensor(label, dtype=torch.long)
        )

In [None]:
# Wrap each split; OCR JSON, image and label are loaded lazily, per item.
train_dataset = DocumentClassificationDataset(image_paths=train_images, processor=processor)
test_dataset = DocumentClassificationDataset(image_paths=test_images, processor=processor)

In [None]:
BATCH_SIZE = 8
# Cap worker processes at the machine's CPU count. PyTorch warns (and the
# machine gets oversubscribed) when num_workers exceeds the available CPUs.
NUM_WORKERS = min(8, os.cpu_count() or 1)

train_data_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,   # new page order every epoch
    num_workers=NUM_WORKERS
)

test_data_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order for evaluation
    num_workers=NUM_WORKERS
)

In [None]:
class ModelModule(pl.LightningModule):
    """LightningModule that fine-tunes LayoutLMv3 for document classification.

    Args:
        n_classes: number of target classes (size of the classification head).
        learning_rate: Adam learning rate; defaults to the original 1e-5.
    """

    def __init__(self, n_classes: int, learning_rate: float = 1e-5):
        super().__init__()
        # Store constructor args in checkpoints so `load_from_checkpoint` can
        # rebuild the module without them being passed again.
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
            "microsoft/layoutlmv3-base",
            num_labels=n_classes
        )
        # Attach human-readable label names to the model config.
        self.model.config.id2label = dict(enumerate(DOCUMENT_CLASSES))
        self.model.config.label2id = {label: idx for idx, label in enumerate(DOCUMENT_CLASSES)}
        # Separate metric instances so train and validation state never mix.
        self.train_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, input_ids, attention_mask, bbox, pixel_values, labels=None):
        """Run LayoutLMv3; pass ``labels`` to also get ``loss`` on the output."""
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
            pixel_values=pixel_values,
            labels=labels
        )

    def _shared_step(self, batch):
        """Forward one batch (dict of model inputs + labels); return (output, labels)."""
        labels = batch["labels"]
        output = self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["bbox"],
            batch["pixel_values"],
            labels
        )
        return output, labels

    def training_step(self, batch, batch_idx):
        output, labels = self._shared_step(batch)
        # Calling the metric updates its state (and caches the batch value).
        # Logging the metric object, rather than the returned tensor, lets
        # Lightning compute the epoch value from accumulated state and reset it
        # at each epoch end.
        self.train_accuracy(output.logits, labels)
        self.log("train_loss", output.loss)
        self.log("train_acc", self.train_accuracy, on_step=True, on_epoch=True)
        return output.loss

    def validation_step(self, batch, batch_idx):
        output, labels = self._shared_step(batch)
        # Only accumulate here; accuracy is computed once per validation epoch
        # over all batches.
        self.val_accuracy.update(output.logits, labels)
        self.log("val_loss", output.loss)
        self.log("val_acc", self.val_accuracy, on_step=False, on_epoch=True)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

In [None]:
# One classifier output per document class found under images/.
model_module = ModelModule(n_classes=len(DOCUMENT_CLASSES))

In [None]:
# Load the TensorBoard notebook extension and embed a dashboard for ./lightning_logs.
# NOTE(review): presumably this is where the Trainer below writes, since it
# configures no logger or root dir -- confirm if logging settings change.
import tensorboard
%load_ext tensorboard
%tensorboard --logdir lightning_logs

In [None]:
# Keep the three checkpoints with the lowest validation loss, plus the latest one.
model_checkpoint = ModelCheckpoint(
    filename="{epoch}-{step}-{val_loss:.4f}", save_last=True, save_top_k=3, monitor="val_loss", mode="min"
)

# Train on the GPU with 16-bit mixed precision when CUDA is available.
# Otherwise fall back to CPU at full precision instead of failing when the
# Trainer is constructed.
use_gpu = torch.cuda.is_available()

trainer = pl.Trainer(
    accelerator="gpu" if use_gpu else "cpu",
    precision=16 if use_gpu else 32,
    devices=1,
    max_epochs=10,
    callbacks=[model_checkpoint],
)