<a href="https://colab.research.google.com/github/ynusinovich/layoutlmv2-practice/blob/main/Fine_tuning_LayoutLMv2ForTokenClassification_on_CORD_%2B_FUNSD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, we are going to fine-tune `LayoutLMv2ForTokenClassification` on the [CORD](https://github.com/clovaai/cord) dataset. The goal for the model is to label words appearing in scanned documents (namely, receipts) appropriately. This task is treated as a NER problem (sequence labeling). However, compared to BERT, LayoutLMv2 also incorporates visual and layout information about the tokens when encoding them into vectors. This makes the LayoutLMv2 model very powerful for document understanding tasks.

LayoutLMv2 is itself an upgrade of LayoutLM. The main novelty of LayoutLMv2 is that it also pre-trains visual embeddings, whereas the original LayoutLM only adds visual embeddings during fine-tuning.

* Paper: https://arxiv.org/abs/2012.14740
* Original repo: https://github.com/microsoft/unilm/tree/master/layoutlmv2

NOTES: 

* you first need to prepare the CORD dataset for LayoutLMv2. For that, check out the notebook "Prepare CORD for LayoutLMv2".
* this notebook is heavily inspired by [this Github repository](https://github.com/omarsou/layoutlm_CORD), which fine-tunes both BERT and LayoutLM (v1) on the CORD dataset.



## Install dependencies

First, we install the required libraries:
* Transformers (for the LayoutLMv2 model)
* Datasets (for data preprocessing)
* Seqeval (for metrics)
* Detectron2 (which LayoutLMv2 requires for its visual backbone).



In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!rm -r transformers
!git clone -b modeling_layoutlmv2_v2 https://github.com/NielsRogge/transformers.git
!cd tranformers
!pip install -q ./transformers 

rm: cannot remove 'transformers': No such file or directory
Cloning into 'transformers'...
remote: Enumerating objects: 125241, done.[K
remote: Total 125241 (delta 0), reused 0 (delta 0), pack-reused 125241[K
Receiving objects: 100% (125241/125241), 103.52 MiB | 16.76 MiB/s, done.
Resolving deltas: 100% (91748/91748), done.
/bin/bash: line 0: cd: tranformers: No such file or directory
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 

In [3]:
!pip install -q datasets seqeval

[K     |████████████████████████████████| 346 kB 19.3 MB/s 
[K     |████████████████████████████████| 43 kB 2.5 MB/s 
[K     |████████████████████████████████| 212 kB 72.9 MB/s 
[K     |████████████████████████████████| 1.1 MB 61.4 MB/s 
[K     |████████████████████████████████| 86 kB 7.0 MB/s 
[K     |████████████████████████████████| 140 kB 73.3 MB/s 
[K     |████████████████████████████████| 127 kB 74.3 MB/s 
[K     |████████████████████████████████| 271 kB 63.7 MB/s 
[K     |████████████████████████████████| 144 kB 54.2 MB/s 
[K     |████████████████████████████████| 94 kB 1.5 MB/s 
[K     |████████████████████████████████| 112 kB 56.4 MB/s 
[?25h  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatib

In [4]:
!pip install pyyaml==5.1
# workaround: install old version of pytorch since detectron2 hasn't released packages for pytorch 1.9 (issue: https://github.com/facebookresearch/detectron2/issues/3158)
!pip install torch==1.8.0+cu101 torchvision==0.9.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html

# install detectron2 that matches pytorch 1.8
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
# !pip install -q detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.8/index.html
# exit(0)  # After installation, you need to "restart runtime" in Colab. This line can also restart runtime
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyyaml==5.1
  Downloading PyYAML-5.1.tar.gz (274 kB)
[K     |████████████████████████████████| 274 kB 24.8 MB/s 
[?25hBuilding wheels for collected packages: pyyaml
  Building wheel for pyyaml (setup.py) ... [?25l[?25hdone
  Created wheel for pyyaml: filename=PyYAML-5.1-cp37-cp37m-linux_x86_64.whl size=44092 sha256=f40fe2c71608a6d4df388d1cdfa76274e17cad4320cc38c6ee698df4c7ff4d13
  Stored in directory: /root/.cache/pip/wheels/77/f5/10/d00a2bd30928b972790053b5de0c703ca87324f3fead0f2fd9
Successfully built pyyaml
Installing collected packages: pyyaml
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 6.0
    Uninstalling PyYAML-6.0:
      Successfully uninstalled PyYAML-6.0
Successfully installed pyyaml-5.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://download.pytorch.org/whl/tor

## Prepare the data

First, let's read in the annotations which we prepared in the other notebook. These contain the word-level annotations (words, labels, normalized bounding boxes).

**Load FUNSD Data**

In [5]:
from datasets import load_dataset

datasets = load_dataset("nielsr/funsd")

Downloading builder script:   0%|          | 0.00/4.54k [00:00<?, ?B/s]

Downloading and preparing dataset funsd/funsd to /root/.cache/huggingface/datasets/nielsr___funsd/funsd/1.0.0/8b0472b536a2dcb975d59a4fb9d6fea4e6a1abe260b7fed6f75301e168cbe595...


Downloading data:   0%|          | 0.00/16.8M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset funsd downloaded and prepared to /root/.cache/huggingface/datasets/nielsr___funsd/funsd/1.0.0/8b0472b536a2dcb975d59a4fb9d6fea4e6a1abe260b7fed6f75301e168cbe595. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 149
    })
    test: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 50
    })
})

In [7]:
datasets['train'].features

{'bboxes': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'id': Value(dtype='string', id=None),
 'image_path': Value(dtype='string', id=None),
 'ner_tags': Sequence(feature=ClassLabel(num_classes=7, names=['O', 'B-HEADER', 'I-HEADER', 'B-QUESTION', 'I-QUESTION', 'B-ANSWER', 'I-ANSWER'], id=None), length=-1, id=None),
 'words': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}

In [8]:
labels = datasets['train'].features['ner_tags'].feature.names
id2label = {v: k for v, k in enumerate(labels)}
id2label

{0: 'O',
 1: 'B-HEADER',
 2: 'I-HEADER',
 3: 'B-QUESTION',
 4: 'I-QUESTION',
 5: 'B-ANSWER',
 6: 'I-ANSWER'}

In [9]:
datasets_train_ner_tag_values = [[]] * len(datasets["train"]["ner_tags"])
for i in range(len(datasets["train"]["ner_tags"])):
  for j in range(len(datasets["train"]["ner_tags"][i])):
    datasets_train_ner_tag_values[i].append(id2label[datasets["train"]["ner_tags"][i][j]])

datasets_test_ner_tag_values = [[]] * len(datasets["test"]["ner_tags"])
for i in range(len(datasets["test"]["ner_tags"])):
  for j in range(len(datasets["test"]["ner_tags"][i])):
    datasets_test_ner_tag_values[i].append(id2label[datasets["test"]["ner_tags"][i][j]])

**Load CORD Data**

In [10]:
import pandas as pd

train = pd.read_pickle('/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/train/train.pkl')
val = pd.read_pickle('/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/dev/dev.pkl')
test = pd.read_pickle('/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/test/test.pkl')

**Combine CORD and FUNSD**

In [11]:
count_funsd = 0
for i in range(len(datasets_train_ner_tag_values)):
  count_funsd += len(datasets_train_ner_tag_values[i])
count_funsd

count_cord = 0
for i in range(len(train[1])):
  for j in range(len(train[1][i])):
    count_cord += len(train[1][i])
count_cord

count_funsd/count_cord

4.906633928813223

In [12]:
funsd_train_count = len(datasets["train"]["words"])
funsd_train_count

149

In [13]:
funsd_test_count = len(datasets["test"]["words"])
funsd_test_count

50

In [14]:
import random

funsd_train_and_val_indices = list(range(funsd_train_count))
random.shuffle(funsd_train_and_val_indices)
funsd_train_sample = funsd_train_and_val_indices[:135]
funsd_train_sample = random.sample(funsd_train_sample, 27)
funsd_val_sample = funsd_train_and_val_indices[136:]
funsd_val_sample = random.sample(funsd_val_sample, 3)
funsd_test_sample = random.sample(range(len(datasets["test"]["words"])), 5)

funsd_train_sample.sort()
funsd_val_sample.sort()
funsd_test_sample.sort()

In [15]:
from os import listdir
import os
import shutil 

funsd_training_images_dir = '/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/FUNSD_full_data/training_data/images/'
funsd_training_annotations_dir = '/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/FUNSD_full_data/training_data/annotations/'
funsd_testing_images_dir = '/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/FUNSD_full_data/testing_data/images/'
funsd_testing_annotations_dir = '/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/FUNSD_full_data/testing_data/annotations/'

funsd_training_images_file_names = [f for f in listdir(funsd_training_images_dir)]
funsd_training_annotations_file_names = [f for f in listdir(funsd_training_annotations_dir)]
funsd_testing_images_file_names = [f for f in listdir(funsd_testing_images_dir)]
funsd_testing_annotations_file_names = [f for f in listdir(funsd_testing_annotations_dir)]

In [16]:
datasets_train_order = [datasets["train"]["image_path"][i].split("/images/")[-1][:-4] for i in range(len(datasets["train"]["image_path"]))]
datasets_test_order = [datasets["test"]["image_path"][i].split("/images/")[-1][:-4] for i in range(len(datasets["test"]["image_path"]))]

In [17]:
funsd_training_images_file_names_ordered = []
funsd_training_annotations_file_names_ordered = []
funsd_testing_images_file_names_ordered = []
funsd_testing_annotations_file_names_ordered = []

for item in datasets_train_order:
  if item + ".png" in funsd_training_images_file_names:
    funsd_training_images_file_names_ordered.append(item + ".png")
  if item + ".json" in funsd_training_annotations_file_names:
    funsd_training_annotations_file_names_ordered.append(item + ".json")
for item in datasets_test_order:
  if item + ".png" in funsd_testing_images_file_names:
    funsd_testing_images_file_names_ordered.append(item + ".png")
  if item + ".json" in funsd_testing_annotations_file_names:
    funsd_testing_annotations_file_names_ordered.append(item + ".json")

In [18]:
funsd_training_images_file_names_sample = [funsd_training_images_file_names_ordered[f] for f in range(len(funsd_training_images_file_names_ordered)) if f in funsd_train_sample]
funsd_training_annotations_file_names_sample = [funsd_training_annotations_file_names_ordered[f] for f in range(len(funsd_training_annotations_file_names_ordered)) if f in funsd_train_sample]
funsd_validation_images_file_names_sample = [funsd_training_images_file_names_ordered[f] for f in range(len(funsd_training_images_file_names_ordered)) if f in funsd_val_sample]
funsd_validation_annotations_file_names_sample = [funsd_training_annotations_file_names_ordered[f] for f in range(len(funsd_training_annotations_file_names_ordered)) if f in funsd_val_sample]
funsd_testing_images_file_names_sample = [funsd_testing_images_file_names_ordered[f] for f in range(len(funsd_testing_images_file_names_ordered)) if f in funsd_test_sample]
funsd_testing_annotations_file_names_sample = [funsd_testing_annotations_file_names_ordered[f] for f in range(len(funsd_testing_annotations_file_names_ordered)) if f in funsd_test_sample]

In [19]:
training_images_dir = "/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/train/image"
for file in funsd_training_images_file_names_sample:
    src = os.path.join(funsd_training_images_dir, file)
    dst = os.path.join(training_images_dir, file)
    shutil.copyfile(src, dst)
training_annotations_dir = "/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/train/json"
for file in funsd_training_annotations_file_names_sample:
    src = os.path.join(funsd_training_annotations_dir, file)
    dst = os.path.join(training_annotations_dir, file)
    shutil.copyfile(src, dst)
validation_images_dir = "/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/dev/image"
for file in funsd_validation_images_file_names_sample:
    src = os.path.join(funsd_training_images_dir, file)
    dst = os.path.join(validation_images_dir, file)
    shutil.copyfile(src, dst)
validation_annotations_dir = "/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/dev/json"
for file in funsd_validation_annotations_file_names_sample:
    src = os.path.join(funsd_training_annotations_dir, file)
    dst = os.path.join(validation_annotations_dir, file)
    shutil.copyfile(src, dst)
testing_images_dir = "/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/test/image"
for file in funsd_testing_images_file_names_sample:
    src = os.path.join(funsd_testing_images_dir, file)
    dst = os.path.join(testing_images_dir, file)
    shutil.copyfile(src, dst)
testing_annotations_dir = "/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/test/json"
for file in funsd_testing_annotations_file_names_sample:
    src = os.path.join(funsd_testing_annotations_dir, file)
    dst = os.path.join(testing_annotations_dir, file)
    shutil.copyfile(src, dst)

In [20]:
datasets_train_ner_tag_values_sample = [[] for i in range(len(funsd_training_images_file_names_sample))]
for file_index in range(len(funsd_training_images_file_names_sample)):
  for i in range(len(datasets["train"]["ner_tags"])):
    if datasets["train"]["image_path"][i].split("/images/")[-1] == funsd_training_images_file_names_sample[file_index]:
      ner_tags_for_file = datasets["train"]["ner_tags"][i]
      for j in range(len(ner_tags_for_file)):
        datasets_train_ner_tag_values_sample[file_index].append(id2label[ner_tags_for_file[j]])

datasets_val_ner_tag_values_sample = [[] for i in range(len(funsd_validation_images_file_names_sample))]
for file_index in range(len(funsd_validation_images_file_names_sample)):
  for i in range(len(datasets["train"]["ner_tags"])):
    if datasets["train"]["image_path"][i].split("/images/")[-1] == funsd_validation_images_file_names_sample[file_index]:
      ner_tags_for_file = datasets["train"]["ner_tags"][i]
      for j in range(len(ner_tags_for_file)):
        datasets_val_ner_tag_values_sample[file_index].append(id2label[ner_tags_for_file[j]])

datasets_test_ner_tag_values_sample = [[] for i in range(len(funsd_testing_images_file_names_sample))]
for file_index in range(len(funsd_testing_images_file_names_sample)):
  for i in range(len(datasets["test"]["ner_tags"])):
    if datasets["test"]["image_path"][i].split("/images/")[-1] == funsd_testing_images_file_names_sample[file_index]:
      ner_tags_for_file = datasets["test"]["ner_tags"][i]
      for j in range(len(ner_tags_for_file)):
        datasets_test_ner_tag_values_sample[file_index].append(id2label[ner_tags_for_file[j]])

In [32]:
train[0].extend([datasets["train"]["words"][i] for i in range(len(datasets["train"]["words"])) if i in funsd_train_sample])
train[1].extend(datasets_train_ner_tag_values_sample)
train[2].extend([datasets["train"]["bboxes"][i] for i in range(len(datasets["train"]["bboxes"])) if i in funsd_train_sample])

val[0].extend([datasets["train"]["words"][i] for i in range(len(datasets["train"]["words"])) if i in funsd_val_sample])
val[1].extend(datasets_val_ner_tag_values_sample)
val[2].extend([datasets["train"]["bboxes"][i] for i in range(len(datasets["train"]["bboxes"])) if i in funsd_val_sample])

test[0].extend([datasets["test"]["words"][i] for i in range(len(datasets["test"]["words"])) if i in funsd_test_sample])
test[1].extend(datasets_test_ner_tag_values_sample)
test[2].extend([datasets["test"]["bboxes"][i] for i in range(len(datasets["test"]["bboxes"])) if i in funsd_test_sample])

Let's define a list of all unique labels. For that, let's first count the number of occurrences for each label:

In [35]:
from collections import Counter

all_labels = [item for sublist in train[1] for item in sublist] + [item for sublist in val[1] for item in sublist] + [item for sublist in test[1] for item in sublist]
Counter(all_labels)

Counter({'B-ANSWER': 639,
         'B-HEADER': 97,
         'B-QUESTION': 800,
         'I-ANSWER': 1626,
         'I-HEADER': 256,
         'I-QUESTION': 1008,
         'O': 1238,
         'menu.cnt': 2429,
         'menu.discountprice': 403,
         'menu.etc': 19,
         'menu.itemsubtotal': 7,
         'menu.nm': 6597,
         'menu.num': 109,
         'menu.price': 2585,
         'menu.sub_cnt': 189,
         'menu.sub_etc': 9,
         'menu.sub_nm': 822,
         'menu.sub_price': 160,
         'menu.sub_unitprice': 14,
         'menu.unitprice': 750,
         'menu.vatyn': 9,
         'sub_total.discount_price': 191,
         'sub_total.etc': 283,
         'sub_total.othersvc_price': 6,
         'sub_total.service_price': 353,
         'sub_total.subtotal_price': 1482,
         'sub_total.tax_price': 1283,
         'total.cashprice': 1393,
         'total.changeprice': 1297,
         'total.creditcardprice': 410,
         'total.emoneyprice': 129,
         'total.menuqty_cn

As we can see, there are some labels that contain very few examples. Let's replace them by the "neutral" label "O" (which stands for "Outside").

In [36]:
replacing_labels = {'menu.etc': 'O', 'mneu.itemsubtotal': 'O', 'menu.sub_etc': 'O', 'menu.sub_unitprice': 'O', 'menu.vatyn': 'O',
                  'void_menu.nm': 'O', 'void_menu.price': 'O', 'sub_total.othersvc_price': 'O'}

In [37]:
def replace_elem(elem):
  try:
    return replacing_labels[elem]
  except KeyError:
    return elem
def replace_list(ls):
  return [replace_elem(elem) for elem in ls]
train[1] = [replace_list(ls) for ls in train[1]]
val[1] = [replace_list(ls) for ls in val[1]]
test[1] = [replace_list(ls) for ls in test[1]]

In [38]:
all_labels = [item for sublist in train[1] for item in sublist] + [item for sublist in val[1] for item in sublist] + [item for sublist in test[1] for item in sublist]
Counter(all_labels)

Counter({'B-ANSWER': 639,
         'B-HEADER': 97,
         'B-QUESTION': 800,
         'I-ANSWER': 1626,
         'I-HEADER': 256,
         'I-QUESTION': 1008,
         'O': 1299,
         'menu.cnt': 2429,
         'menu.discountprice': 403,
         'menu.itemsubtotal': 7,
         'menu.nm': 6597,
         'menu.num': 109,
         'menu.price': 2585,
         'menu.sub_cnt': 189,
         'menu.sub_nm': 822,
         'menu.sub_price': 160,
         'menu.unitprice': 750,
         'sub_total.discount_price': 191,
         'sub_total.etc': 283,
         'sub_total.service_price': 353,
         'sub_total.subtotal_price': 1482,
         'sub_total.tax_price': 1283,
         'total.cashprice': 1393,
         'total.changeprice': 1297,
         'total.creditcardprice': 410,
         'total.emoneyprice': 129,
         'total.menuqty_cnt': 630,
         'total.menutype_cnt': 130,
         'total.total_etc': 89,
         'total.total_price': 2120})

Now we have to save all the unique labels in a list.

In [39]:
labels = list(set(all_labels))
print(labels)

['total.total_etc', 'total.creditcardprice', 'I-ANSWER', 'B-HEADER', 'menu.price', 'menu.sub_price', 'sub_total.etc', 'menu.num', 'sub_total.discount_price', 'total.menutype_cnt', 'menu.discountprice', 'B-QUESTION', 'menu.cnt', 'menu.itemsubtotal', 'O', 'total.cashprice', 'menu.sub_nm', 'total.total_price', 'B-ANSWER', 'sub_total.tax_price', 'I-QUESTION', 'sub_total.subtotal_price', 'total.menuqty_cnt', 'menu.sub_cnt', 'menu.nm', 'total.emoneyprice', 'menu.unitprice', 'I-HEADER', 'sub_total.service_price', 'total.changeprice']


In [40]:
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for idx, label in enumerate(labels)}
print(label2id)
print(id2label)

{'total.total_etc': 0, 'total.creditcardprice': 1, 'I-ANSWER': 2, 'B-HEADER': 3, 'menu.price': 4, 'menu.sub_price': 5, 'sub_total.etc': 6, 'menu.num': 7, 'sub_total.discount_price': 8, 'total.menutype_cnt': 9, 'menu.discountprice': 10, 'B-QUESTION': 11, 'menu.cnt': 12, 'menu.itemsubtotal': 13, 'O': 14, 'total.cashprice': 15, 'menu.sub_nm': 16, 'total.total_price': 17, 'B-ANSWER': 18, 'sub_total.tax_price': 19, 'I-QUESTION': 20, 'sub_total.subtotal_price': 21, 'total.menuqty_cnt': 22, 'menu.sub_cnt': 23, 'menu.nm': 24, 'total.emoneyprice': 25, 'menu.unitprice': 26, 'I-HEADER': 27, 'sub_total.service_price': 28, 'total.changeprice': 29}
{0: 'total.total_etc', 1: 'total.creditcardprice', 2: 'I-ANSWER', 3: 'B-HEADER', 4: 'menu.price', 5: 'menu.sub_price', 6: 'sub_total.etc', 7: 'menu.num', 8: 'sub_total.discount_price', 9: 'total.menutype_cnt', 10: 'menu.discountprice', 11: 'B-QUESTION', 12: 'menu.cnt', 13: 'menu.itemsubtotal', 14: 'O', 15: 'total.cashprice', 16: 'menu.sub_nm', 17: 'total.

In [41]:
from os import listdir
from torch.utils.data import Dataset
import torch
from PIL import Image

class CORDandFUNSDDataset(Dataset):
    """CORD and FUNSD dataset."""

    def __init__(self, annotations, image_dir, processor=None, max_length=512):
        """
        Args:
            annotations (List[List]): List of lists containing the word-level annotations (words, labels, boxes).
            image_dir (string): Directory with all the document images.
            processor (LayoutLMv2Processor): Processor to prepare the text + image.
        """
        self.words, self.labels, self.boxes = annotations
        self.image_dir = image_dir
        self.image_file_names = [f for f in listdir(image_dir)]
        self.processor = processor

    def __len__(self):
        return len(self.image_file_names)

    def __getitem__(self, idx):
        # first, take an image
        item = self.image_file_names[idx]
        image = Image.open(self.image_dir + item).convert("RGB")

        # get word-level annotations 
        words = self.words[idx]
        boxes = self.boxes[idx]
        word_labels = self.labels[idx]

        assert len(words) == len(boxes) == len(word_labels)
        
        word_labels = [label2id[label] for label in word_labels]
        # use processor to prepare everything
        encoded_inputs = self.processor(image, words, boxes=boxes, word_labels=word_labels, 
                                        padding="max_length", truncation=True, 
                                        return_tensors="pt")
        
        # remove batch dimension
        for k,v in encoded_inputs.items():
          encoded_inputs[k] = v.squeeze()

        assert encoded_inputs.input_ids.shape == torch.Size([512])
        assert encoded_inputs.attention_mask.shape == torch.Size([512])
        assert encoded_inputs.token_type_ids.shape == torch.Size([512])
        assert encoded_inputs.bbox.shape == torch.Size([512, 4])
        assert encoded_inputs.image.shape == torch.Size([3, 224, 224])
        assert encoded_inputs.labels.shape == torch.Size([512]) 
      
        return encoded_inputs

In [42]:
from transformers import LayoutLMv2Processor

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

train_dataset = CORDandFUNSDDataset(annotations=train,
                            image_dir='/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/train/image/', 
                            processor=processor)
val_dataset = CORDandFUNSDDataset(annotations=val,
                            image_dir='/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/dev/image/', 
                            processor=processor)
test_dataset = CORDandFUNSDDataset(annotations=test,
                            image_dir='/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/test/image/', 
                            processor=processor)

Downloading:   0%|          | 0.00/135 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/707 [00:00<?, ?B/s]

Let's verify an example:

In [43]:
encoding = train_dataset[0]
encoding.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'labels', 'image'])

In [44]:
for k,v in encoding.items():
  print(k, v.shape)

input_ids torch.Size([512])
token_type_ids torch.Size([512])
attention_mask torch.Size([512])
bbox torch.Size([512, 4])
labels torch.Size([512])
image torch.Size([3, 224, 224])


In [45]:
print(processor.tokenizer.decode(encoding['input_ids']))

[CLS] 1 ns gr telor kebuli 27, 272 subtotal 27, 272 tax 2, 727 total 30, 000 cash 50, 000 change 20, 000 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

In [46]:
train[0][0]

['1',
 'NS',
 'GR',
 'TELOR',
 'KEBULI',
 '27,272',
 'SUBTOTAL',
 '27,272',
 'TAX',
 '2,727',
 'TOTAL',
 '30,000',
 'CASH',
 '50,000',
 'Change',
 '20,000']

In [47]:
train[1][0]

['menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'sub_total.subtotal_price',
 'sub_total.subtotal_price',
 'sub_total.tax_price',
 'sub_total.tax_price',
 'total.total_price',
 'total.total_price',
 'total.cashprice',
 'total.cashprice',
 'total.changeprice',
 'total.changeprice']

In [48]:
[id2label[label] for label in encoding['labels'].tolist() if label != -100]

['menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'sub_total.subtotal_price',
 'sub_total.subtotal_price',
 'sub_total.tax_price',
 'sub_total.tax_price',
 'total.total_price',
 'total.total_price',
 'total.cashprice',
 'total.cashprice',
 'total.changeprice',
 'total.changeprice']

In [49]:
for id, label in zip(encoding['input_ids'][:30], encoding['labels'][:30]):
  print(processor.tokenizer.decode([id]), label.item())

[CLS] -100
1 12
ns 24
gr 24
tel 24
##or -100
ke 24
##bu -100
##li -100
27 4
, -100
272 -100
sub 21
##to -100
##tal -100
27 21
, -100
272 -100
tax 19
2 19
, -100
72 -100
##7 -100
total 17
30 17
, -100
000 -100
cash 15
50 15
, -100


Next, we create corresponding dataloaders.

In [50]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2)

## Train the model

Let's train the model using native PyTorch. We use the AdamW optimizer with learning rate = 5e-5 (this is a good default value when fine-tuning Transformer-based models).



In [51]:
from transformers import LayoutLMv2ForTokenClassification, AdamW
import torch
from tqdm.notebook import tqdm

model = LayoutLMv2ForTokenClassification.from_pretrained('microsoft/layoutlmv2-base-uncased',
                                                                      num_labels=len(labels))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

global_step = 0
num_train_epochs = 4

#put the model in training mode
model.train() 
for epoch in range(num_train_epochs):  
   print("Epoch:", epoch)
   for batch in tqdm(train_dataloader):
        # get the inputs;
        input_ids = batch['input_ids'].to(device)
        bbox = batch['bbox'].to(device)
        image = batch['image'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        labels = batch['labels'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        outputs = model(input_ids=input_ids,
                        bbox=bbox,
                        image=image,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        labels=labels) 
        loss = outputs.loss
        
        # print loss every 100 steps
        if global_step % 100 == 0:
          print(f"Loss after {global_step} steps: {loss.item()}")

        loss.backward()
        optimizer.step()
        global_step += 1

model.save_pretrained("/content/drive/MyDrive/Work Shared/AISC/2022-02-19 Computer Vision Discussion Group/checkpoints")

Downloading:   0%|          | 0.00/802M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/layoutlmv2-base-uncased were not used when initializing LayoutLMv2ForTokenClassification: ['layoutlmv2.visual.backbone.bottom_up.res3.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.2.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.16.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.4.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.17.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.8.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.11.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.1.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.15.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.22.conv1.norm.num_batches_tracke

Epoch: 0


  0%|          | 0/414 [00:00<?, ?it/s]

Loss after 0 steps: 3.3981549739837646
Loss after 100 steps: 3.2159981727600098
Loss after 200 steps: 1.3570280075073242
Loss after 300 steps: 1.8687337636947632
Loss after 400 steps: 0.8611792325973511
Epoch: 1


  0%|          | 0/414 [00:00<?, ?it/s]

Loss after 500 steps: 0.5236421823501587
Loss after 600 steps: 0.9588270783424377
Loss after 700 steps: 0.41407713294029236
Loss after 800 steps: 0.4185955226421356
Epoch: 2


  0%|          | 0/414 [00:00<?, ?it/s]

Loss after 900 steps: 0.49218064546585083
Loss after 1000 steps: 0.2872087061405182
Loss after 1100 steps: 0.16050751507282257
Loss after 1200 steps: 0.5465169548988342
Epoch: 3


  0%|          | 0/414 [00:00<?, ?it/s]

Loss after 1300 steps: 0.09599148482084274
Loss after 1400 steps: 0.40294086933135986
Loss after 1500 steps: 0.5218316316604614
Loss after 1600 steps: 0.0502297580242157


## Evaluation

Let's evaluate the model on the test set. First, let's do a sanity check on the first example of the test set.

In [52]:
encoding = test_dataset[0]
processor.tokenizer.decode(encoding['input_ids'])

'[CLS] 1001 - choco bun 22. 000 x1 22. 000 6001 - plastic bag small 0 x1 0 22. 000 total. total item : 2 cash 25. 000 tendered : 3. 000 change : [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

In [53]:
ground_truth_labels = [id2label[label] for label in encoding['labels'].squeeze().tolist() if label != -100]
print(ground_truth_labels)

['menu.nm', 'menu.nm', 'menu.unitprice', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.unitprice', 'menu.cnt', 'menu.price', 'total.total_price', 'total.total_price', 'total.menuqty_cnt', 'total.menuqty_cnt', 'total.menuqty_cnt', 'total.cashprice', 'total.cashprice', 'total.cashprice', 'total.changeprice', 'total.changeprice']


In [54]:
for k,v in encoding.items():
  encoding[k] = v.unsqueeze(0).to(device)

model.eval()
# forward pass
outputs = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'],
                token_type_ids=encoding['token_type_ids'], bbox=encoding['bbox'],
                image=encoding['image'])

In [55]:
prediction_indices = outputs.logits.argmax(-1).squeeze().tolist()
print(prediction_indices)

[15, 24, 24, 24, 24, 24, 24, 26, 26, 26, 12, 12, 4, 4, 4, 24, 24, 24, 24, 24, 24, 26, 12, 12, 4, 17, 17, 17, 17, 17, 22, 22, 22, 22, 15, 15, 15, 15, 29, 15, 15, 29, 29, 29, 29, 29, 29, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,

In [56]:
prediction_indices = outputs.logits.argmax(-1).squeeze().tolist()
predictions = [id2label[label] for gt, label in zip(encoding['labels'].squeeze().tolist(), prediction_indices) if gt != -100]
print(predictions)

['menu.nm', 'menu.nm', 'menu.unitprice', 'menu.cnt', 'menu.price', 'menu.nm', 'menu.nm', 'menu.nm', 'menu.unitprice', 'menu.cnt', 'menu.price', 'total.total_price', 'total.total_price', 'total.menuqty_cnt', 'total.menuqty_cnt', 'total.menuqty_cnt', 'total.cashprice', 'total.cashprice', 'total.changeprice', 'total.changeprice', 'total.changeprice']


In [57]:
import numpy as np

preds_val = None
out_label_ids = None

# put model in evaluation mode
model.eval()
for batch in tqdm(test_dataloader, desc="Evaluating"):
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        bbox = batch['bbox'].to(device)
        image = batch['image'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        labels = batch['labels'].to(device)

        # forward pass
        outputs = model(input_ids=input_ids, bbox=bbox, image=image, attention_mask=attention_mask, 
                        token_type_ids=token_type_ids, labels=labels)
        
        if preds_val is None:
          preds_val = outputs.logits.detach().cpu().numpy()
          out_label_ids = batch["labels"].detach().cpu().numpy()
        else:
          preds_val = np.append(preds_val, outputs.logits.detach().cpu().numpy(), axis=0)
          out_label_ids = np.append(
              out_label_ids, batch["labels"].detach().cpu().numpy(), axis=0
          )

Evaluating:   0%|          | 0/53 [00:00<?, ?it/s]

In [58]:
import warnings
warnings.filterwarnings("ignore")
from seqeval.metrics import (
    classification_report,
    f1_score,
    precision_score,
    recall_score)

def results_test(preds, out_label_ids, labels):
  preds = np.argmax(preds, axis=2)

  label_map = {i: label for i, label in enumerate(labels)}

  out_label_list = [[] for _ in range(out_label_ids.shape[0])]
  preds_list = [[] for _ in range(out_label_ids.shape[0])]

  for i in range(out_label_ids.shape[0]):
      for j in range(out_label_ids.shape[1]):
          if out_label_ids[i, j] != -100:
              out_label_list[i].append(label_map[out_label_ids[i][j]])
              preds_list[i].append(label_map[preds[i][j]])

  results = {
      "precision": precision_score(out_label_list, preds_list),
      "recall": recall_score(out_label_list, preds_list),
      "f1": f1_score(out_label_list, preds_list),
  }
  return results, classification_report(out_label_list, preds_list)

In [59]:
labels = list(set(all_labels))
val_result, class_report = results_test(preds_val, out_label_ids, labels)
print("Overall results:", val_result)
print(class_report)

Overall results: {'precision': 0.8390052356020943, 'recall': 0.8518272425249169, 'f1': 0.845367622815694}
                         precision    recall  f1-score   support

                 ANSWER       0.21      0.13      0.16        68
                 HEADER       0.00      0.00      0.00         8
               QUESTION       0.34      0.51      0.41       110
                enu.cnt       0.94      0.99      0.97       224
      enu.discountprice       0.75      0.60      0.67        10
       enu.itemsubtotal       0.00      0.00      0.00         6
                 enu.nm       0.94      0.95      0.94       251
                enu.num       0.83      0.91      0.87        11
              enu.price       0.96      0.99      0.98       247
            enu.sub_cnt       0.80      0.24      0.36        17
             enu.sub_nm       0.65      0.88      0.75        32
          enu.sub_price       0.88      0.75      0.81        20
          enu.unitprice       0.97      0.99    

The results I was getting were: 

`{'precision': 0.9307458143074582, 'recall': 0.9272175890826384, 'f1': 0.9289783516900872}`