# Finetuning the original model
This notebook does the initial finetuning of the model (5 epochs, default LLM learning rate, etc.).
The resulting weights will be saved and used in a separate notebook that covers the actual distillation.
Common functions/classes will probably be moved into a separate script once this is working.

In [1]:
import os
import json
import torch
import pandas as pd
from pathlib import Path
from typing import List, Tuple
from PIL import Image
from Levenshtein import distance
from transformers import (
    AutoProcessor,
    LayoutLMv3Processor,
    LayoutLMv3ForTokenClassification
)

In [2]:
# create runtime vars
# Label vocabulary for token classification; "none" is the label given to words
# that do not appear under any key field.
categories = ["address", "company", "date", "total", "none"]
# NOTE: despite their names, cat_id_map maps id -> category and id_cat_map maps
# category -> id. Both names are kept for existing code; id2label / label2id are
# clearer aliases for the same dictionaries.
cat_id_map = dict(enumerate(categories))
id_cat_map = {v: k for k, v in cat_id_map.items()}
id2label = cat_id_map
label2id = id_cat_map

### Define dataset & preprocessing functions

In [3]:
def read_SROIE_csv(filepath: str, delim: str = ",") -> pd.DataFrame:
    """
    Read a SROIE bounding-box file into a DataFrame.

    Each line holds eight coordinates followed by the text, all separated by ``delim``.
    Because the text itself may contain ``delim``, lines do not have a fixed number of
    fields, so the default ``pd.read_csv`` cannot parse them. Instead, the first eight
    fields are taken as coordinates and all remaining fields are re-joined as the text,
    with each ``delim`` inside the text replaced by a single space.

    Args:
        filepath: path to the box file.
        delim: field separator.

    Returns:
        DataFrame with columns ``p1``, ``p2``, ``p3``, ``p4`` (each a list of two
        coordinate strings, in file order) and ``text``. Blank lines are skipped.
    """
    p1 = []
    p2 = []
    p3 = []
    p4 = []
    text = []
    with open(filepath) as f:
        for line in f:
            # A blank line (e.g. a trailing empty line) has no coordinates; keeping it
            # would produce short/empty coordinate lists that break box processing later.
            if not line.strip():
                continue
            fields = line.split(delim)
            # get bbox coords
            p1.append(fields[:2])
            p2.append(fields[2:4])
            p3.append(fields[4:6])
            p4.append(fields[6:8])
            # Re-join with spaces rather than ``delim``: flatten_keys also turns commas
            # into spaces, so words from both sources split the same way.
            text.append(" ".join(fields[8:]).replace("\n", ""))
    return pd.DataFrame({
        "p1": p1,
        "p2": p2,
        "p3": p3,
        "p4": p4,
        "text": text,
    })

In [5]:
def split_bbox_to_word_level(df: pd.DataFrame) -> Tuple[List[str], List[List[int]]]:
    """
    Explode line-level boxes into word-level ones.

    Each row's ``text`` is split on single spaces. Every resulting piece, including
    empty strings produced by consecutive spaces, is paired with that row's ``bbox``,
    so several words share one bounding box.

    Returns:
        (words, bboxes): two parallel lists, one entry per word.

    Raises:
        ValueError: if ``df`` lacks a ``bbox`` or a ``text`` column.
    """
    needed = {"bbox", "text"}
    if not needed.issubset(df.columns):
        raise ValueError("df does not have bbox or text column!")
    out_words: List[str] = []
    out_boxes = []
    for line_text, line_box in zip(df["text"], df["bbox"]):
        pieces = line_text.split(" ")
        out_words += pieces
        # the same box object is reused for every word on the line
        out_boxes += [line_box for _ in pieces]
    return out_words, out_boxes

In [6]:
# flatten key dictionary to a word level, with each word getting the label it appears under
def flatten_keys(keys: dict) -> Tuple[List[str], List[str]]:
    """
    Flatten a key dictionary down to the word level.

    In each value, commas are treated as spaces. The value is split on single spaces,
    empty pieces are dropped, and every remaining word is tagged with the dictionary
    key it came from. Output order follows the dictionary's iteration order.

    Returns:
        (words, labels): two parallel lists, one entry per word.
    """
    tagged = [
        (token, field)
        for field, value in keys.items()
        for token in value.replace(",", " ").split(" ")
        if token
    ]
    return [token for token, _ in tagged], [field for _, field in tagged]

In [7]:
class SROIEProcDataset(torch.utils.data.Dataset):
    """
    Takes care of the processing of each SROIE document for use by LayoutLMv3.

    Expects three folders under ``{filepath}/data``: ``img`` (document images),
    ``box`` (line-level bounding-box files read with ``read_SROIE_csv``) and ``key``
    (JSON files mapping a field name to its text). Files are paired by sorted order
    and must share the same stem.

    Each item is the processor encoding for one document. Word boxes are normalised
    to integers in [0, 1000], and word labels are integer ids from ``id_cat_map``.
    """

    def __init__(self, filepath: str, model_name: str = "microsoft/layoutlmv3-base") -> None:
        """
        Args:
            filepath: dataset root containing ``data/img``, ``data/box`` and ``data/key``.
            model_name: checkpoint whose processor is loaded. OCR is disabled because
                the dataset already provides the words and boxes.

        Raises:
            RuntimeError: if the three folders hold different numbers of files, or if
                the sorted files do not line up by stem.
        """
        self.processor = AutoProcessor.from_pretrained(model_name, apply_ocr=False)
        # read in data on disk - sort to ensure data is accessed in the same order
        img = sorted(os.listdir(f"{filepath}/data/img"))
        bbox = sorted(os.listdir(f"{filepath}/data/box"))
        key = sorted(os.listdir(f"{filepath}/data/key"))
        # check the dataset is valid
        if len(img) != len(bbox) or len(bbox) != len(key):
            raise RuntimeError("Different number of documents with images and bounding box data")
        # Equal counts are not enough: one missing file plus one extra file elsewhere
        # would silently pair the wrong image, boxes and keys for later documents.
        for img_name, box_name, key_name in zip(img, bbox, key):
            if not (Path(img_name).stem == Path(box_name).stem == Path(key_name).stem):
                raise RuntimeError(
                    f"Mismatched document files: {img_name}, {box_name}, {key_name}"
                )
        # preprocess filepaths so no extra processing needs to be done
        self.img = [f"{filepath}/data/img/{f}" for f in img]
        self.bbox = [f"{filepath}/data/box/{f}" for f in bbox]
        self.keys = [f"{filepath}/data/key/{f}" for f in key]

    def __process_bbox(self, bbox: pd.DataFrame, image) -> pd.DataFrame:
        """
        Process each bounding box in the data into a LayoutLMv3-compatible format.

        Args:
            bbox: output of ``read_SROIE_csv`` (columns ``p1``..``p4`` and ``text``).
            image: the PIL document image; its ``size`` is used to normalise the
                coordinates. It is passed in explicitly because it only exists as a
                local variable of ``__getitem__``.

        Returns:
            DataFrame with one row per non-empty word: ``word`` and ``bbox``. Each
            ``bbox`` is ``[x0, y0, x1, y1]`` built from ``p1`` and ``p3``, as ints
            clamped to [0, 1000].
        """
        # p1 and p3 are used as opposite corners -> [x0, y0, x1, y1]
        # (assumes SROIE's clockwise-from-top-left corner order - confirm)
        line_boxes = pd.DataFrame({
            "text": bbox["text"].tolist(),
            "bbox": [c1 + c3 for c1, c3 in zip(bbox["p1"], bbox["p3"])],
        })
        # split the bbox data so we predict on the word level - train on this
        # we will combine this data back to the bbox level after predictions occur
        words, boxes = split_bbox_to_word_level(line_boxes)
        width, height = image.size

        def _scale(coord, extent):
            # LayoutLMv3 expects integer coordinates within [0, 1000]; clamp boxes
            # that spill slightly outside the image
            return min(max(int(int(coord) / extent * 1000), 0), 1000)

        # remove empty words (produced by consecutive spaces in the text)
        kept = [(word, box) for word, box in zip(words, boxes) if word != ""]
        return pd.DataFrame({
            "word": [word for word, _ in kept],
            "bbox": [
                [
                    _scale(box[0], width),
                    _scale(box[1], height),
                    _scale(box[2], width),
                    _scale(box[3], height),
                ]
                for _, box in kept
            ],
        })

    def __getitem__(self, idx):
        """
        Build the processor encoding for document ``idx``.

        Each word takes the label of the first key field it appears under. Words that
        appear in no key field get "none". Labels are converted to integer ids via
        ``id_cat_map`` before encoding.
        """
        # context manager releases the underlying file handle once converted
        with Image.open(self.img[idx]) as raw_image:
            image = raw_image.convert("RGB")
        # process bboxes into LayoutLMv3 format (normalised against this image)
        bboxes = self.__process_bbox(read_SROIE_csv(self.bbox[idx]), image)
        # retrieve keys - these are the labels that we want to process on
        with open(self.keys[idx], encoding="utf-8") as f:
            keys = json.load(f)
        key_words, key_fields = flatten_keys(keys)
        # Exactly one label per word: a word may occur several times in the keys
        # (e.g. a repeated number in an address). Emitting every match would make the
        # label list longer than the word list, so the first occurrence wins.
        word_to_field = {}
        for key_word, field in zip(key_words, key_fields):
            word_to_field.setdefault(key_word, field)
        # now format as lists to be fed into the processor
        words = bboxes.word.tolist()
        boxes = bboxes.bbox.tolist()
        # the processor/model take integer label ids, not category strings
        labels = [id_cat_map[word_to_field.get(word, "none")] for word in words]
        encoding = self.processor(
            image,
            words,
            boxes=boxes,
            word_labels=labels,
            truncation=True,
            padding="max_length",
        )
        return encoding

    def __len__(self):
        return len(self.img)

In [8]:
# Build the dataset from the SROIE checkout located in the current working directory.
data_path = str(Path.cwd() / "ICDAR-2019-SROIE")
dataset = SROIEProcDataset(data_path)

### Finetune the model