In [None]:
!nvidia-smi

Wed Dec 11 03:42:50 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
!pip install pdfkit
# imgkit shells out to wkhtmltoimage from the wkhtmltox package. The bionic build
# fails to install here because it depends on libssl1.1, which is no longer
# installable; use the jammy build (assumes an Ubuntu 22.04 runtime - adjust the
# file name if the runtime's Ubuntu release differs).
!wget -q https://github.com/wkhtmltopdf/packaging/releases/download/0.12.6.1-2/wkhtmltox_0.12.6.1-2.jammy_amd64.deb -O wkhtmltopdf.deb
!sudo apt-get install -y ./wkhtmltopdf.deb

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Note, selecting 'wkhtmltox' instead of './wkhtmltopdf.deb'
Some packages could not be installed. This may mean that you have
requested an impossible situation or if you are using the unstable
distribution that some required packages have not yet been created
or been moved out of Incoming.
The following information may help to resolve the situation:

The following packages have unmet dependencies:
 wkhtmltox : Depends: libssl1.1 but it is not installable
E: Unable to correct problems, you have held broken packages.


In [None]:
# Pinned versions that work together with transformers 4.27.2. Do not upgrade
# huggingface_hub afterwards: that silently discards the pin below.
!pip install -qqq transformers==4.27.2 --progress-bar off
!pip install -qqq pytorch-lightning==1.9.4 --progress-bar off
!pip install -qqq torchmetrics==0.11.4 --progress-bar off
!pip install -qqq imgkit==1.2.3 --progress-bar off
!pip install -qqq easyocr==1.6.2 --progress-bar off
!pip install -qqq Pillow==9.4.0 --progress-bar off
!pip install -qqq tensorboardX==2.5.1 --progress-bar off
!pip install -qqq huggingface_hub==0.11.1 --progress-bar off
# Provides the `Levenshtein` module imported by the accuracy helpers.
!pip install -qqq Levenshtein --progress-bar off
!pip install -qqq --upgrade --no-cache-dir gdown

In [None]:
from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from sklearn.model_selection import train_test_split
import imgkit
import easyocr
import torchvision.transforms as T
from pathlib import Path
import matplotlib.pyplot as plt
import os
import cv2
from typing import List
import json
from torchmetrics import Accuracy
from huggingface_hub import notebook_login
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

%matplotlib inline
pl.seed_everything(42)

INFO:lightning_fabric.utilities.seed:Global seed set to 42


42

In [None]:
!gdown 1xa2SDBjrYpBKrloKp4Yj9_cAEUwU7LHz

Downloading...
From (original): https://drive.google.com/uc?id=1xa2SDBjrYpBKrloKp4Yj9_cAEUwU7LHz
From (redirected): https://drive.google.com/uc?id=1xa2SDBjrYpBKrloKp4Yj9_cAEUwU7LHz&confirm=t&uuid=82acf285-9dfb-4385-a44b-38509ce218ff
To: /content/unstructured-Documents-Maana.zip
100% 184M/184M [00:05<00:00, 34.7MB/s]


In [None]:
!unzip -q unstructured-Documents-Maana.zip
!mv "unstructured-berkeley-project-1-documents/" "documents"

## Convert HTML to images


In [None]:
# Normalise class-directory names (lower-case, spaces -> underscores) so they can
# be used directly as label names. Only the entry's own name is rewritten, and
# the listing is materialised before any rename happens.
for class_dir in list(Path("documents").glob("*")):
    class_dir.rename(class_dir.with_name(class_dir.name.lower().replace(" ", "_")))

In [None]:
list(Path("documents").glob("*"))

In [None]:
# Mirror every class sub-directory of documents/ under images/ so the rendered
# pages keep the same class layout.
for class_dir in Path("documents").glob("*"):
    (Path("images") / class_dir.name).mkdir(parents=True, exist_ok=True)

In [None]:
def convert_html_to_image(file_path: Path, images_dir: Path, scale: float = 1.0) -> Path:
    """Render an HTML document to a JPEG and optionally rescale it.

    The output is written to ``images_dir / <parent dir name> / <file stem>.jpg``,
    so the class sub-directory of the source file is preserved.

    Args:
        file_path: HTML file to render.
        images_dir: Root directory for rendered images; the class sub-directory
            is created if it does not exist yet.
        scale: Factor applied to both width and height after rendering. With the
            default 1.0 the rendered JPEG is kept as-is (no re-encode).

    Returns:
        Path of the written JPEG.
    """
    save_path = images_dir / file_path.parent.name / file_path.with_suffix(".jpg").name
    save_path.parent.mkdir(parents=True, exist_ok=True)

    # imgkit builds a wkhtmltoimage command line, so hand it plain strings.
    imgkit.from_file(str(file_path), str(save_path), options={'quiet': '', 'format': 'jpeg'})

    if scale != 1.0:
        # Resize inside the `with` so the source file handle is released before
        # the result overwrites the same path.
        with Image.open(save_path) as rendered:
            width, height = rendered.size
            resized = rendered.resize((int(width * scale), int(height * scale)))
        resized.save(str(save_path))

    return save_path

In [None]:
document_paths = list(Path("documents").glob("*/*"))

for doc_path in tqdm(document_paths):
    convert_html_to_image(doc_path, Path("images"), scale=0.8)

In [None]:
image_paths = sorted(list(Path("images").glob("*/*.jpg")))

image = Image.open(image_paths[0]).convert("RGB")
width, height = image.size
image

## EasyOCR

In [None]:
reader = easyocr.Reader(['en'])

In [None]:
%%time
image_path = image_paths[0]
ocr_result = reader.readtext(str(image_path))

In [None]:
ocr_result[0]

In [None]:
# DejaVu Sans Condensed is taken from OpenCV's bundled Qt font folder. Builds of
# OpenCV without Qt may not ship it, so fall back to Pillow's built-in font
# instead of letting ImageFont.truetype raise.
font_path = Path(cv2.__path__[0]) / "qt/fonts/DejaVuSansCondensed.ttf"
if font_path.exists():
    font = ImageFont.truetype(str(font_path), size=12)
else:
    print(f"{font_path} not found; using Pillow's default font.")
    font = ImageFont.load_default()

In [None]:
def create_bounding_box(bbox_data):
    """Collapse a polygon of ``(x, y)`` points into an axis-aligned box.

    Args:
        bbox_data: Iterable of ``(x, y)`` pairs, e.g. the corner points that
            EasyOCR's ``readtext`` reports for each detection.

    Returns:
        ``[left, top, right, bottom]``, each coordinate truncated with ``int``.
    """
    x_coords, y_coords = zip(*bbox_data)
    return [
        int(min(x_coords)),
        int(min(y_coords)),
        int(max(x_coords)),
        int(max(y_coords)),
    ]

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(28, 28))

left_image = Image.open(image_path).convert("RGB")
right_image = Image.new("RGB", left_image.size, (255, 255, 255))

left_draw = ImageDraw.Draw(left_image)
right_draw = ImageDraw.Draw(right_image)

for i, (bbox, word, confidence) in enumerate(ocr_result):
    box = create_bounding_box(bbox)

    left_draw.rectangle(box, outline="blue", width=2)
    left, top, right, bottom = box

    left_draw.text((right + 5, top), text=str(i + 1), fill="red", font=font)
    right_draw.text((left, top), text=word, fill="black", font=font)

ax1.imshow(left_image)
ax2.imshow(right_image)
ax1.axis("off");
ax2.axis("off");

In [None]:
for image_path in tqdm(image_paths):
    ocr_result = reader.readtext(str(image_path), batch_size=16)

    ocr_page = []
    for bbox, word, confidence in ocr_result:
        ocr_page.append({
            "word": word, "bounding_box": create_bounding_box(bbox)
        })

    with image_path.with_suffix(".json").open("w") as f:
        json.dump(ocr_page, f)

In [None]:
!tar -cJf financial-documents-ocr.tar.xz "./images"

In [None]:
!gdown 1bQ4mFbVRUtOEJSe8b4hUYIcngSgfdldw

In [None]:
!tar -xf financial-documents-ocr.tar.xz

In [None]:
image_paths = sorted(list(Path("images").glob("*/*.jpg")))

## LayoutLMv3

### Preprocessing

In [None]:
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
processor = LayoutLMv3Processor(feature_extractor, tokenizer)

Calling the processor is similar to using a tokenizer: it takes the page image, the OCR words and their boxes, and returns model-ready tensors.

In [None]:
image_path = image_paths[0]
image = Image.open(image_path).convert("RGB")
width, height = image.size
width, height

In [None]:
width_scale = 1000 / width
height_scale = 1000 / height

width_scale, height_scale

In [None]:
json_path = image_path.with_suffix(".json")
json_path

In [None]:
with json_path.open("r") as f:
    ocr_result = json.load(f)

In [None]:
def scale_bounding_box(box: List[int], width_scale : float = 1.0, height_scale : float = 1.0) -> List[int]:
    """Rescale a ``[left, top, right, bottom]`` box axis by axis.

    The x coordinates (indices 0 and 2) are multiplied by ``width_scale`` and the
    y coordinates (indices 1 and 3) by ``height_scale``. Results are truncated to
    int. In this notebook it maps pixel boxes onto the 0-1000 grid LayoutLMv3 uses.
    """
    factors = (width_scale, height_scale, width_scale, height_scale)
    return [int(box[index] * factor) for index, factor in enumerate(factors)]

In [None]:
words = []
boxes = []
for row in ocr_result:
    boxes.append(scale_bounding_box(row["bounding_box"], width_scale, height_scale))
    words.append(row["word"])

In [None]:
len(words), len(boxes)

In [None]:
encoding = processor(
    image,
    words,
    boxes=boxes,
    max_length=512,
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)

encoding.keys()

In [None]:
print(f"""
input_ids:  {list(encoding["input_ids"].squeeze().shape)}
word boxes: {list(encoding["bbox"].squeeze().shape)}
image data: {list(encoding["pixel_values"].squeeze().shape)}
image size: {image.size}
""")

##### Image

In [None]:
image_data = encoding["pixel_values"][0]
image_data.shape

In [None]:
transform = T.ToPILImage()
transform(image_data)

Word boxes use the `[left, top, right, bottom]` format.

In [None]:
def unnormalize_box(bbox, width, height):
    """Map a box from the 0-1000 grid back to pixel coordinates.

    Args:
        bbox: ``[left, top, right, bottom]`` on the 0-1000 grid.
        width, height: Pixel size of the target image.

    Returns:
        ``[left, top, right, bottom]`` in pixels (not rounded).
    """
    extents = (width, height, width, height)
    return [extents[i] * (bbox[i] / 1000) for i in range(4)]

### Model

In [None]:
model = LayoutLMv3ForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=2)

In [None]:
model.config

In [None]:
# Run one forward pass on the sample page to check that the model accepts the
# processor's output. This is inference only, so gradient tracking is disabled
# to avoid building and keeping an autograd graph.
encoding = processor(
    image,
    words,
    boxes=boxes,
    max_length=512,
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)

with torch.no_grad():
    outputs = model(**encoding)

In [None]:
outputs.logits

### Accuracy

In [None]:
import re
import Levenshtein
import numpy as np
from collections import defaultdict

def parse_tagged_string(input_string):
    """
    Parses a string with tags like <page_header>, <paragraph>, <sep/> into a structured dictionary.
    Handles unclosed or malformed tags.

    <s>, </s> and <sep/> markers are stripped first. An opening tag implicitly
    closes any still-open tag, a closing tag only counts when it matches the
    currently open tag, and text left open at the end of the string is kept.
    Only elements whose tag is one of the known layout tags and whose stripped
    content is non-empty are emitted.

    Args:
        input_string (str): The input string.

    Returns:
        dict: A dictionary representing the structured data, or None if parsing fails.
    """
    known_tags = {
        "page_header",
        "paragraph",
        "subheading",
        "title",
        "image",
        "table",
        "code_snippet",
        "page_footer",
    }

    try:
        text = input_string.replace("<s>", "").replace("</s>", "").replace("<sep/>", "")
        elements = []

        def emit(tag, begin, end=None):
            # Record text[begin:end] for `tag` if it is a known tag with content.
            segment = text[begin:end].strip()
            if segment and tag in known_tags:
                elements.append({"type": tag, "content": segment})

        segment_start = 0
        open_tag = None

        for match in re.finditer(r"<(/?)(\w+)>", text):
            is_closing = match.group(1) == "/"
            name = match.group(2)

            if not is_closing:
                # A new opening tag ends whatever was still open.
                if open_tag is not None:
                    emit(open_tag, segment_start, match.start())
                open_tag = name
                segment_start = match.end()
            elif open_tag == name:
                emit(open_tag, segment_start, match.start())
                open_tag = None
                segment_start = match.end()

        # Keep the trailing text of a tag that was never closed.
        if open_tag is not None and segment_start < len(text):
            emit(open_tag, segment_start)

        return {"document": elements}

    except Exception as e:
        print(f"Error during parsing: {e}")
        return None

def safe_divide(numerator, denominator):
    """Return ``numerator / denominator``, or ``0`` when the denominator equals zero."""
    return 0 if denominator == 0 else numerator / denominator

def calculate_text_accuracy(ref_text, pred_text):
    """
    Character-level accuracy of ``pred_text`` against ``ref_text``.

    Accuracy is ``max(0, len(ref) - levenshtein(ref, pred)) / len(ref)``. The
    word count returned alongside it is meant as a weight for aggregating many
    pairs.

    Args:
        ref_text (str): Reference text.
        pred_text (str): Predicted text.

    Returns:
        tuple: (accuracy, word_count). The accuracy lies in [0, 1]. The word_count is
        the number of whitespace-separated words in ``ref_text``. An empty
        reference yields (1.0, 0) when the prediction is also empty and (0.0, 0)
        otherwise. A non-string operand yields (0.0, 0).
    """
    if not ref_text:
        return (0.0, 0) if pred_text else (1.0, 0)

    if not (isinstance(ref_text, str) and isinstance(pred_text, str)):
        return 0.0, 0

    ref_length = len(ref_text)
    edits = Levenshtein.distance(ref_text, pred_text)
    accuracy = safe_divide(max(0, ref_length - edits), ref_length)
    return accuracy, len(ref_text.split())

def calculate_overall_text_extraction_accuracy(reference, prediction):
    """
    Word-weighted text accuracy over all elements of a document, tag-agnostic.

    Elements are paired by position (i-th reference element with i-th predicted
    element). Each pair contributes its calculate_text_accuracy score weighted
    by the reference word count.

    Reference elements with no predicted counterpart are scored against an
    empty prediction, so they count as fully missed. This matches
    calculate_text_accuracy_for_tag; previously they were silently dropped,
    which inflated the score of truncated predictions. Surplus predicted
    elements are scored against an empty reference, which
    calculate_text_accuracy weights with zero words.

    Args:
        reference (dict): Parsed reference document ({"document": [...]}).
        prediction (dict): Parsed prediction document ({"document": [...]}).

    Returns:
        float: Weighted accuracy, or 0 when the reference contains no words.
    """
    total_accuracy = 0
    total_words = 0

    ref_texts = [item["content"] for item in reference["document"]]
    pred_texts = [item["content"] for item in prediction["document"]]

    for ref_text, pred_text in zip(ref_texts, pred_texts):
        accuracy, word_count = calculate_text_accuracy(ref_text, pred_text)
        total_accuracy += accuracy * word_count
        total_words += word_count

    # Reference content the prediction never produced is a full miss.
    for ref_text in ref_texts[len(pred_texts):]:
        accuracy, word_count = calculate_text_accuracy(ref_text, "")
        total_accuracy += accuracy * word_count
        total_words += word_count

    for pred_text in pred_texts[len(ref_texts):]:
        accuracy, word_count = calculate_text_accuracy("", pred_text)
        total_accuracy += accuracy * word_count
        total_words += word_count

    overall_accuracy = safe_divide(total_accuracy, total_words)

    return overall_accuracy

def calculate_tag_categorization_accuracy(reference, prediction):
    """
    Calculates tag categorization accuracy and F1 score.

    Elements are compared position by position: the i-th reference tag is
    compared with the i-th predicted tag. When the documents differ in length,
    the unmatched tail still counts. A leftover reference element is a false
    negative for its tag, and a leftover predicted element is a false positive.
    Plain zip() used to drop these, so F1 ignored missing or hallucinated
    elements.

    Args:
        reference (dict): Parsed reference document.
        prediction (dict): Parsed prediction document.

    Returns:
        tuple: (tag_accuracy_dict, overall_accuracy, tag_f1_dict, overall_f1)
               - tag_accuracy_dict: per tag, correct positions divided by the
                 tag's reference count, or -9999 if the tag is absent from the
                 reference.
               - overall_accuracy: position-wise matches divided by the longer
                 document's length.
               - tag_f1_dict: per-tag F1, or -9999 if the tag is absent from both
                 documents.
               - overall_f1: micro-averaged F1 (TP/FP/FN summed over tags).
    """
    from itertools import zip_longest

    ref_tags = [item["type"] for item in reference["document"]]
    pred_tags = [item["type"] for item in prediction["document"]]

    # Predefined list of tags
    all_tags = [
        "page_header",
        "paragraph",
        "subheading",
        "title",
        "image",
        "table",
        "code_snippet",
        "page_footer",
    ]

    # Pad the shorter sequence with None so the unmatched tail is scored too.
    aligned_tags = list(zip_longest(ref_tags, pred_tags))

    tag_accuracy_dict = {}
    tag_f1_dict = {}
    overall_tp = 0
    overall_fp = 0
    overall_fn = 0

    for tag in all_tags:
        tp = 0
        fp = 0
        fn = 0

        for r_tag, p_tag in aligned_tags:
            if r_tag == tag:
                if p_tag == tag:
                    tp += 1  # True Positive
                else:
                    fn += 1  # False Negative (includes a missing prediction)
            elif p_tag == tag:
                fp += 1      # False Positive (includes a surplus prediction)

        overall_tp += tp
        overall_fp += fp
        overall_fn += fn

        precision = safe_divide(tp, tp + fp)
        recall = safe_divide(tp, tp + fn)
        f1_score = safe_divide(2 * precision * recall, precision + recall)

        total_for_tag_ref = ref_tags.count(tag)

        if total_for_tag_ref > 0:
            tag_accuracy_dict[tag] = safe_divide(tp, total_for_tag_ref)
        else:
            tag_accuracy_dict[tag] = -9999  # Use -9999 for tags not present in reference

        # F1 is -9999 only when the tag appears in neither document.
        tag_f1_dict[tag] = f1_score if total_for_tag_ref > 0 or tag in pred_tags else -9999

    correct_tags = overall_tp  # Overall correct tags is sum of TPs
    overall_accuracy = safe_divide(
        correct_tags, max(len(ref_tags), len(pred_tags))
    )

    overall_precision = safe_divide(overall_tp, overall_tp + overall_fp)
    overall_recall = safe_divide(overall_tp, overall_tp + overall_fn)
    overall_f1 = safe_divide(2 * overall_precision * overall_recall, overall_precision + overall_recall)

    return tag_accuracy_dict, overall_accuracy, tag_f1_dict, overall_f1

def calculate_text_accuracy_for_tag(reference, prediction, tag):
    """
    Word-weighted text accuracy restricted to elements of one tag.

    Elements of ``tag`` are paired in order of appearance. Reference elements
    with no predicted partner are scored against "", and surplus predictions
    are scored against an empty reference.

    Returns:
        float: Weighted accuracy, or -9999 when the tag appears in neither
        document (not evaluated).
    """
    ref_texts = [el["content"] for el in reference["document"] if el["type"] == tag]
    pred_texts = [el["content"] for el in prediction["document"] if el["type"] == tag]

    if not (ref_texts or pred_texts):
        return -9999

    n_ref = len(ref_texts)
    n_pred = len(pred_texts)
    pairs = list(zip(ref_texts, pred_texts))
    pairs += [(ref_text, "") for ref_text in ref_texts[n_pred:]]
    pairs += [("", pred_text) for pred_text in pred_texts[n_ref:]]

    weighted_sum = 0
    word_total = 0
    for ref_text, pred_text in pairs:
        accuracy, word_count = calculate_text_accuracy(ref_text, pred_text)
        weighted_sum += accuracy * word_count
        word_total += word_count

    return safe_divide(weighted_sum, word_total)

def calculate_all_accuracies(reference, prediction):
    """
    Bundle every metric for one parsed (reference, prediction) document pair.

    Returns:
        dict with three sections:
        - "overall_text_extraction": {"percentage": tag-agnostic text accuracy}
        - "tag_categorization": overall tag accuracy, per-tag accuracy dict,
          overall F1 and per-tag F1 dict
        - "text_extraction_by_tag": {tag: {"percentage": text accuracy for that tag}}
    """
    overall_text_accuracy = calculate_overall_text_extraction_accuracy(reference, prediction)
    (
        tag_accuracy_dict,
        overall_tag_accuracy,
        tag_f1_dict,
        overall_f1,
    ) = calculate_tag_categorization_accuracy(reference, prediction)

    report_tags = (
        "paragraph",
        "subheading",
        "page_footer",
        "title",
        "image",
        "table",
        "page_header",
        "code_snippet",
    )

    text_extraction_by_tag = {}
    tag_f1_scores = {}
    for tag in report_tags:
        text_extraction_by_tag[tag] = {
            "percentage": calculate_text_accuracy_for_tag(reference, prediction, tag)
        }
        tag_f1_scores[tag] = tag_f1_dict[tag]

    tag_categorization = {
        "overall_tag_accuracy": overall_tag_accuracy,
        "tag_accuracy_dict": tag_accuracy_dict,
        "overall_f1": overall_f1,
        "tag_f1_dict": tag_f1_scores,
    }
    return {
        "overall_text_extraction": {"percentage": overall_text_accuracy},
        "tag_categorization": tag_categorization,
        "text_extraction_by_tag": text_extraction_by_tag,
    }

def final_accuracy(reference_string, prediction_string):
    """
    Parse two tagged strings and compute every accuracy / F1 metric for them.

    Args:
        reference_string (str): The reference string with tags.
        prediction_string (str): The prediction string with tags.

    Returns:
        dict: Output of calculate_all_accuracies, or None (after printing an
        error) when either string cannot be parsed.
    """
    parsed = [parse_tagged_string(text) for text in (reference_string, prediction_string)]

    if any(document is None for document in parsed):
        print("Error: Could not parse input strings.")
        return None

    reference_data, prediction_data = parsed
    return calculate_all_accuracies(reference_data, prediction_data)

def avg_accuracy(all_accs):
    """
    Average a list of calculate_all_accuracies() results metric by metric.

    Values at or below -9998 (the -9999 "not evaluated" sentinel) are skipped
    per metric. A metric with no evaluated values averages to NaN, and no
    numpy "Mean of empty slice" warning is emitted.

    Args:
        all_accs (list[dict]): Non-empty list of per-document accuracy dicts.

    Returns:
        dict: Same layout as a single result, except that "tag_categorization"
        stores the averaged overall tag accuracy under "percentage".

    Raises:
        ValueError: If ``all_accs`` is empty.
    """
    if not all_accs:
        raise ValueError("avg_accuracy() needs at least one accuracy dict")

    def _mean_valid(values):
        # Drop the -9999 sentinel; NaN when nothing evaluated remains.
        valid = [value for value in values if value > -9998]
        return np.mean(valid) if valid else float("nan")

    overall_text_accuracies = [acc["overall_text_extraction"]["percentage"] for acc in all_accs]
    tag_categorization_accuracies = [acc["tag_categorization"]["overall_tag_accuracy"] for acc in all_accs]
    overall_f1_scores = [acc["tag_categorization"]["overall_f1"] for acc in all_accs]
    tag_accuracy_dicts = [acc["tag_categorization"]["tag_accuracy_dict"] for acc in all_accs]
    tag_f1_dicts = [acc["tag_categorization"]["tag_f1_dict"] for acc in all_accs]

    tag_level_accuracies = defaultdict(list)
    for acc in all_accs:
        for tag, metrics in acc["text_extraction_by_tag"].items():
            tag_level_accuracies[tag].append(metrics["percentage"])

    # Every per-document dict covers the same tag set; use the first for the keys.
    tags = list(tag_accuracy_dicts[0])
    avg_tag_accuracies = {tag: _mean_valid(d[tag] for d in tag_accuracy_dicts) for tag in tags}
    avg_tag_f1_scores = {tag: _mean_valid(d[tag] for d in tag_f1_dicts) for tag in tags}

    return {
        "overall_text_extraction": {"percentage": _mean_valid(overall_text_accuracies)},
        "tag_categorization": {
            "percentage": _mean_valid(tag_categorization_accuracies),
            "tag_accuracy_dict": avg_tag_accuracies,
            "overall_f1": _mean_valid(overall_f1_scores),  # Average overall F1
            "tag_f1_dict": avg_tag_f1_scores,  # Average tag F1 scores
        },
        "text_extraction_by_tag": {
            tag: {"percentage": _mean_valid(values)}
            for tag, values in tag_level_accuracies.items()
        },
    }

In [None]:
import json
import re
from collections import defaultdict

# Evaluate tagged model predictions against the reference sequences.
# `test_data` is expected to be a list of dicts with a "target_sequence" key.
# Nothing above creates it, so the cell skips itself until it exists; this keeps
# Restart & Run All working through to the training section.

# (tag key, label used in the printed report), in report order.
REPORT_TAGS = [
    ("paragraph", "Paragraph"),
    ("image", "Image"),
    ("title", "Title"),
    ("table", "Table"),
    ("page_header", "Page Header"),
    ("subheading", "Subheading"),
    ("code_snippet", "Code Snippet"),
    ("page_footer", "Page Footer"),
]

if "test_data" not in globals():
    print("Skipping evaluation: `test_data` is not defined.")
else:
    processed_data = []
    for item in test_data:
        processed_item = 'your model tokenized output' #FILL THIS OUT WITH THE SPECIFIED FORMAT ABOVE
        processed_data.append({"prediction": processed_item, "target_sequence": item["target_sequence"]})

    all_accs = []
    for item in processed_data:
        try:
            accuracies = final_accuracy(item['target_sequence'], item['prediction'])
        except Exception as e:  # best effort: one malformed sample must not abort the evaluation
            print(f"Skip: {e}")
            continue
        if accuracies:
            all_accs.append(accuracies)

    if not all_accs:
        print("No samples could be scored.")
    else:
        average_accuracies = avg_accuracy(all_accs)
        tag_categorization = average_accuracies['tag_categorization']
        text_by_tag = average_accuracies['text_extraction_by_tag']

        print(f"Average Overall Text Extraction Accuracy: {average_accuracies['overall_text_extraction']['percentage']:.4f}")
        print(f"Average Tag Categorization Accuracy: {tag_categorization['percentage']:.4f}")
        print(f"Average Overall Tag Categorization F1 Score: {tag_categorization['overall_f1']:.4f}")
        for tag, label in REPORT_TAGS:
            print(f"Average {label} Extraction Accuracy: {text_by_tag[tag]['percentage']:.4f}")

        print("\nTag-Specific F1 Scores:")
        for tag, label in REPORT_TAGS:
            print(f"{label} Tag F1 Score: {tag_categorization['tag_f1_dict'].get(tag, float('nan')):.4f}")

## Training

In [None]:
# Stratify by class (the parent directory name) so both splits keep the same
# class proportions. random_state pins the split independently of how much of
# NumPy's global RNG earlier cells consumed.
train_images, test_images = train_test_split(
    image_paths,
    test_size=0.2,
    random_state=42,
    stratify=[path.parent.name for path in image_paths],
)

In [None]:
DOCUMENT_CLASSES = sorted(list(map(lambda p: p.name, Path("images").glob("*"))))
DOCUMENT_CLASSES

In [None]:
image_path = image_paths[300]
print(image_path)
DOCUMENT_CLASSES.index(image_path.parent.name)

In [None]:
image_path = image_paths[0]
print(image_path)
DOCUMENT_CLASSES.index(image_path.parent.name)

In [None]:
class DocumentClassificationDataset(Dataset):
    """Document-classification dataset built from rendered pages and cached OCR.

    For every image path, the OCR result is read from the sibling ``.json`` file.
    Boxes are rescaled to the 0-1000 grid with ``scale_bounding_box``, and image,
    words and boxes are encoded with the LayoutLMv3 ``processor``. The label is
    the index of the image's parent directory name in ``DOCUMENT_CLASSES``.

    Each item also carries ``roles`` and ``font_sizes`` for auxiliary heads.
    They come from optional ``"role"``/``"font_size"`` keys of the OCR rows
    (defaults ``"body"`` and 12) and are aligned to token positions. Every item
    therefore has exactly ``MAX_LENGTH`` entries, so batches collate with the
    default DataLoader collate function. Per-word lists differ in length from
    page to page and cannot be stacked.
    """

    MAX_LENGTH = 512
    # Fill values for token positions that do not come from an OCR word
    # (special tokens and padding).
    ROLE_PAD = ""
    FONT_SIZE_PAD = 0.0

    def __init__(self, image_paths, processor):
        self.image_paths = image_paths
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _read_ocr_rows(ocr_result, width_scale, height_scale):
        """Split OCR rows into parallel word, scaled-box, role and font-size lists."""
        words, boxes, roles, font_sizes = [], [], [], []
        for row in ocr_result:
            boxes.append(scale_bounding_box(row["bounding_box"], width_scale, height_scale))
            words.append(row["word"])
            roles.append(row.get("role", "body"))  # Default role is "body"
            font_sizes.append(row.get("font_size", 12))  # Default font size
        return words, boxes, roles, font_sizes

    def __getitem__(self, item):
        image_path = self.image_paths[item]
        with image_path.with_suffix(".json").open("r") as f:
            ocr_result = json.load(f)

        # Open the file in its own `with` so Pillow's handle is released
        # deterministically; convert() returns an in-memory copy that outlives it.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        width, height = image.size
        words, boxes, roles, font_sizes = self._read_ocr_rows(
            ocr_result, 1000 / width, 1000 / height
        )

        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            max_length=self.MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # word_ids() maps each token position to the index of its source word
        # (None for special and padding tokens). Use it to re-index the
        # word-level lists so they line up with input_ids.
        word_ids = encoding.word_ids(0)
        token_roles = [self.ROLE_PAD if w is None else roles[w] for w in word_ids]
        token_font_sizes = [
            self.FONT_SIZE_PAD if w is None else float(font_sizes[w]) for w in word_ids
        ]

        label = DOCUMENT_CLASSES.index(image_path.parent.name)

        return dict(
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            bbox=encoding["bbox"].flatten(end_dim=1),
            pixel_values=encoding["pixel_values"].flatten(end_dim=1),
            labels=torch.tensor(label, dtype=torch.long),
            roles=token_roles,  # Pass roles for multitask learning
            font_sizes=torch.tensor(token_font_sizes, dtype=torch.float)
        )

    @staticmethod
    def preprocess_with_roles_and_fonts(image, ocr_result, processor, width_scale, height_scale):
        """Encode one page with ``processor`` and attach word-level roles and font sizes.

        Unlike ``__getitem__``, this helper keeps ``roles``/``font_sizes`` at word
        level (one entry per OCR row) and returns the processor's batched
        encoding with both added as tensors.

        NOTE(review): ``roles`` is converted to a LongTensor, so this only works
        when the OCR rows carry integer role ids; with the string default "body",
        ``torch.tensor`` raises.
        """
        words, boxes, roles, font_sizes = DocumentClassificationDataset._read_ocr_rows(
            ocr_result, width_scale, height_scale
        )

        encoding = processor(
            image,
            words,
            boxes=boxes,
            max_length=DocumentClassificationDataset.MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        encoding["roles"] = torch.tensor(roles, dtype=torch.long)
        encoding["font_sizes"] = torch.tensor(font_sizes, dtype=torch.float)
        return encoding

In [None]:
train_dataset = DocumentClassificationDataset(train_images, processor)
test_dataset = DocumentClassificationDataset(test_images, processor)

In [None]:
for item in train_dataset:
    print(item["bbox"].shape)
    print(item["pixel_values"].shape)
    print(item["labels"].shape)
    break

In [None]:
train_data_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2
)

test_data_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=2
)

In [None]:
class ModelModule(pl.LightningModule):
    """Lightning wrapper around LayoutLMv3 for document classification.

    When ``n_roles`` is given, two auxiliary token-level heads are trained alongside the
    document classifier on the last hidden layer: a role classifier (``n_roles`` classes)
    and a font-size regressor. When ``n_roles`` is omitted only the document classifier is
    built and trained.

    Args:
        n_classes: number of document classes.
        n_roles: number of token roles, or ``None`` to disable the auxiliary heads.
        **pretrained_kwargs: extra keyword arguments forwarded to
            ``LayoutLMv3ForSequenceClassification.from_pretrained`` (e.g.
            ``local_files_only=True`` or ``id2label=...``).
    """

    def __init__(self, n_classes: int, n_roles: "int | None" = None, **pretrained_kwargs):
        super().__init__()
        # Record the constructor arguments in checkpoints so load_from_checkpoint can rebuild the module.
        self.save_hyperparameters()
        self.model = LayoutLMv3ForSequenceClassification.from_pretrained(
            "microsoft/layoutlmv3-base",
            num_labels=n_classes,
            **pretrained_kwargs,
        )
        hidden_size = self.model.config.hidden_size
        self.use_aux_heads = n_roles is not None
        self.role_classifier = torch.nn.Linear(hidden_size, n_roles) if self.use_aux_heads else None
        self.font_regressor = torch.nn.Linear(hidden_size, 1) if self.use_aux_heads else None
        self.train_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=n_classes)

    def forward(self, input_ids, attention_mask, bbox, pixel_values, labels=None, roles=None, font_sizes=None):
        """Run LayoutLMv3 and, when enabled, the auxiliary heads.

        ``roles`` and ``font_sizes`` are accepted for call compatibility but are not used
        here; the auxiliary losses are computed in ``training_step``.

        Returns:
            ``(outputs, role_logits, font_predictions)``. ``outputs`` is the Hugging Face
            classification output (its ``loss`` is set only when ``labels`` is given).
            ``role_logits`` has shape ``(batch, text_len, n_roles)`` and ``font_predictions``
            ``(batch, text_len)``, where ``text_len = input_ids.shape[1]``. Both are ``None``
            when the auxiliary heads are disabled.
        """
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
            pixel_values=pixel_values,
            labels=labels,
            # hidden_states is None unless requested; only the auxiliary heads need it.
            output_hidden_states=self.use_aux_heads,
        )
        if not self.use_aux_heads:
            return outputs, None, None

        # LayoutLMv3 appends the visual patch embeddings after the text tokens, so keep only
        # the first input_ids.shape[1] positions to get one vector per text token.
        text_hidden = outputs.hidden_states[-1][:, : input_ids.shape[1]]
        role_logits = self.role_classifier(text_hidden)  # Role classification
        font_predictions = self.font_regressor(text_hidden).squeeze(-1)  # Font regression
        return outputs, role_logits, font_predictions

    def _auxiliary_loss(self, batch, role_logits, font_predictions):
        """Return role cross-entropy plus font-size MSE, or 0.0 when the heads are disabled."""
        if role_logits is None:
            return 0.0
        # NOTE(review): assumes batch["roles"] and batch["font_sizes"] are token-aligned with
        # input_ids, i.e. shape (batch, text_len), with role positions to skip set to -100.
        role_loss = torch.nn.functional.cross_entropy(
            role_logits.reshape(-1, role_logits.size(-1)),
            batch["roles"].reshape(-1),
            ignore_index=-100,
        )
        font_loss = torch.nn.functional.mse_loss(font_predictions.float(), batch["font_sizes"].float())
        return role_loss + font_loss

    def training_step(self, batch, batch_idx):
        labels = batch["labels"]
        outputs, role_logits, font_predictions = self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["bbox"],
            batch["pixel_values"],
            labels,
        )

        # Classification loss plus the auxiliary losses (if the heads are enabled).
        loss = outputs.loss + self._auxiliary_loss(batch, role_logits, font_predictions)

        # Project-specific accuracy, logged next to the torchmetrics accuracy.
        new_accuracy = calculate_text_accuracy_for_tag(labels, outputs.logits)

        self.log("train_loss", loss)
        self.log("train_acc", self.train_accuracy(outputs.logits, labels), on_step=True, on_epoch=True)
        self.log("new_train_acc", new_accuracy, on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        labels = batch["labels"]
        # forward returns a 3-tuple; validation scores only the document classifier, so
        # val_loss (monitored by checkpointing) stays the classification loss.
        outputs, _, _ = self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["bbox"],
            batch["pixel_values"],
            labels,
        )

        val_loss = outputs.loss
        val_acc = self.val_accuracy(outputs.logits, labels)
        new_val_acc = calculate_text_accuracy_for_tag(labels, outputs.logits)

        self.log("val_loss", val_loss, on_epoch=True)
        self.log("val_acc", val_acc, on_epoch=True)
        self.log("new_val_acc", new_val_acc, on_epoch=True)

        return val_loss

    def configure_optimizers(self):
        # Optimize every parameter, including the auxiliary heads, not just the backbone.
        optimizer = torch.optim.Adam(self.parameters(), lr=0.00001)  # 1e-5
        return optimizer

In [None]:
model_module = ModelModule(n_classes=len(DOCUMENT_CLASSES))  # one output per document class

In [None]:
# Show TensorBoard inline, pointed at lightning_logs (Lightning's default log directory).
%load_ext tensorboard
%tensorboard --logdir lightning_logs

In [None]:
# Keep the three checkpoints with the lowest validation loss, plus the most recent one.
model_checkpoint = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    save_last=True,
    filename="{epoch}-{step}-{val_loss:.4f}",
)

# Single-GPU, 16-bit mixed-precision training for five epochs.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=16,
    max_epochs=5,
    callbacks=[model_checkpoint],
)

In [None]:
trainer.fit(model_module, train_dataloaders=train_data_loader, val_dataloaders=test_data_loader)

## Evaluation

In [None]:
# Path of the checkpoint ModelCheckpoint ranked best (lowest val_loss).
model_checkpoint.best_model_path

In [None]:
# Rebuild the Lightning module from the best checkpoint found during training.
trained_model = ModelModule.load_from_checkpoint(
    model_checkpoint.best_model_path,
    n_classes=len(DOCUMENT_CLASSES),
    local_files_only=True
)

# Evaluate with dropout disabled and autograd off, on the GPU when one is available.
eval_device = "cuda:0" if torch.cuda.is_available() else "cpu"
trained_model = trained_model.eval().to(eval_device)

test_results = []
with torch.inference_mode():
    for batch in test_data_loader:
        # The module returns (outputs, role_logits, font_predictions); only the document
        # classification output is scored here.
        outputs = trained_model(
            batch["input_ids"].to(eval_device),
            batch["attention_mask"].to(eval_device),
            batch["bbox"].to(eval_device),
            batch["pixel_values"].to(eval_device),
        )[0]
        # Labels stay on the CPU, so bring the logits back before comparing.
        new_accuracy = calculate_text_accuracy_for_tag(batch["labels"], outputs.logits.cpu())
        test_results.append(new_accuracy)

# Compute final average accuracy.
# NOTE(review): this is the mean of per-batch values, so a smaller final batch counts as
# much as a full one.
average_test_accuracy = sum(test_results) / len(test_results)
print(f"Final Average Test Accuracy: {average_test_accuracy:.4f}")


In [None]:
trained_model.model.save_pretrained(save_directory=Path("best-model"))  # HF model only, no Lightning wrapper

In [None]:
# Authenticate with the Hugging Face Hub so the push_to_hub call can upload the model.
notebook_login()

In [None]:
# Upload the fine-tuned Hugging Face classifier (trained_model.model, without the Lightning wrapper) to the Hub.
trained_model.model.push_to_hub("layoutlmv3-unstructured-berkeley-project-1")

In [None]:
# Run inference on the first GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

In [None]:
# Load the fine-tuned classifier written to "best-model" by save_pretrained. (push_to_hub
# stores the model under the logged-in user's namespace, so the bare repo name is not a
# resolvable Hub id.)
model = LayoutLMv3ForSequenceClassification.from_pretrained(Path("best-model"))

# Training created the classifier with only num_labels, so the saved config holds generic
# "LABEL_<i>" names. Restore the real class names in the order the dataset encoded them
# (DOCUMENT_CLASSES.index).
model.config.id2label = dict(enumerate(DOCUMENT_CLASSES))
model.config.label2id = {name: idx for idx, name in model.config.id2label.items()}

model = model.eval().to(DEVICE)

In [None]:
def predict_document_image(
    image_path: Path,
    model: LayoutLMv3ForSequenceClassification,
    processor: LayoutLMv3Processor) -> str:
    """Classify one document image with a LayoutLMv3 sequence classifier.

    OCR output is read from the JSON file next to the image (same stem, ``.json`` suffix).
    Each entry must provide ``"word"`` and ``"bounding_box"``. Boxes are rescaled with
    ``scale_bounding_box`` using factors of ``1000 / width`` and ``1000 / height``.

    Args:
        image_path: path to the page image.
        model: classifier used for inference; inputs are moved to ``model.device``.
        processor: processor that turns the image, words and boxes into model inputs.

    Returns:
        The label ``model.config.id2label`` assigns to the highest-scoring class.
    """
    with image_path.with_suffix(".json").open("r") as f:
        ocr_result = json.load(f)

    # convert() returns a new in-memory image, so open the source in its own `with` to
    # guarantee its file handle is closed (Pillow keeps multi-frame files such as TIFF open).
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    width, height = image.size
    width_scale = 1000 / width
    height_scale = 1000 / height

    words = [row["word"] for row in ocr_result]
    boxes = [
        scale_bounding_box(row["bounding_box"], width_scale, height_scale)
        for row in ocr_result
    ]

    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    # Send inputs wherever the model lives instead of assuming a global device.
    device = model.device
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(device),
            attention_mask=encoding["attention_mask"].to(device),
            bbox=encoding["bbox"].to(device),
            pixel_values=encoding["pixel_values"].to(device)
        )

    predicted_class = output.logits.argmax()
    return model.config.id2label[predicted_class.item()]

In [None]:
labels = []
predictions = []
for image_path in tqdm(test_images):
    labels.append(image_path.parent.name)
    predictions.append(predict_document_image(image_path, model, processor))