<a href="https://colab.research.google.com/github/thibaultdouzon/NeuralDocumentClassification/blob/master/chapter_2_text.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training a classifier on OCR text input


# Imports & Cloning repository


You will need this version of the `transformers` library for the rest of the project.

In [None]:
%pip install transformers==4.45.1
%pip install torchmetrics

In [None]:
import os
import pickle
import sys
from dataclasses import dataclass
from os import path

import matplotlib.pyplot as plt
import tqdm

In [None]:
class_names = ["email", "form", "handwritten", "invoice", "advertisement"]
NUM_CLASSES = len(class_names)

In [None]:
if not os.path.exists("NeuralDocumentClassification"):
    !git clone https://github.com/thibaultdouzon/NeuralDocumentClassification.git
else:
    !git -C NeuralDocumentClassification pull
sys.path.append("NeuralDocumentClassification")

In [None]:
from src import download_dataset

dataset_path = "dataset"
download_dataset.download_and_extract("all", dataset_path)

In [None]:
with open(path.join(dataset_path, "train.pkl"), "rb") as f:
    train_dataset = pickle.load(f)

with open(path.join(dataset_path, "test.pkl"), "rb") as f:
    test_dataset = pickle.load(f)

with open(path.join(dataset_path, "validation.pkl"), "rb") as f:
    validation_dataset = pickle.load(f)


for split_name, split_dataset in zip(
    ["train", "test", "validation"], [train_dataset, test_dataset, validation_dataset]
):
    print(f"{split_name}_dataset contains {len(split_dataset)} documents")
train_dataset[0].keys()

Each `dataset` object is a `list` of documents. A document is a `dict` with the following structure:

```json
{
  "id": "Unique document identifier",
  "image": "A PIL.Image object containing the document's image",
  "label": "An integer in [0 .. 4] representing the class of the document",
  "words": "A list of strings (text segments, not necessarily single words!) extracted from the image by an OCR",
  "boxes": "A list of tuples of numbers providing the position of each word in the document"
}
```
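
Before exploring, it can help to check that documents really follow this structure. The following is a minimal sketch: `check_document` is a hypothetical helper (not part of the repository), and it assumes each box is a 4-number tuple, as described above. The example document is made up, with the image omitted.

```python
# Hypothetical helper: sanity-check one document dict against the structure above
NUM_CLASSES = 5  # "email", "form", "handwritten", "invoice", "advertisement"


def check_document(doc: dict) -> list[str]:
    """Return a list of problems found in `doc` (empty if it looks consistent)."""
    problems = []
    for key in ("id", "image", "label", "words", "boxes"):
        if key not in doc:
            problems.append(f"missing key: {key}")
    if problems:
        return problems
    if not 0 <= doc["label"] < NUM_CLASSES:
        problems.append(f"label out of range: {doc['label']}")
    if len(doc["words"]) != len(doc["boxes"]):
        problems.append("words and boxes have different lengths")
    for box in doc["boxes"]:
        if len(box) != 4:
            problems.append(f"box does not have 4 coordinates: {box}")
    return problems


# Made-up document; the image is replaced by None for brevity
doc = {
    "id": "doc-0",
    "image": None,
    "label": 3,
    "words": ["INVOICE No. 42", "Total: $10"],
    "boxes": [(10, 10, 200, 30), (10, 50, 120, 70)],
}
print(check_document(doc))  # []
```

On the real data you could run it over `train_dataset` and print any document with a non-empty result.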


# Explore the data

Take the time to explore the textual data included in the dataset.


Ideas

- 10 most common words? (hint: Counter)


In [None]:
### Insert your code here ###
# See the expected solution by clicking on the cell below

In [None]:
# @title

from collections import Counter

all_texts = [
    [word for sentence in doc["words"] for word in sentence.split()]
    for doc in validation_dataset + test_dataset + train_dataset
]

most_common_words = Counter([w for text in all_texts for w in text])
most_common_words.most_common(10)

- Count the number of unique words

In [None]:
### Insert your code here ###
# See the expected solution by clicking on the cell below

In [None]:
# @title

n_unique_words = len({w for text in all_texts for w in text})
n_unique_words

- Distribution of words (cumulative occurrences plot)

In [None]:
### Insert your code here ###
# See the expected solution by clicking on the cell below

In [None]:
# @title

# Zipf's law
total_occurrences = sum(most_common_words.values())  # compute the total once
plt.figure(figsize=(10, 5))
plt.plot(
    [c / total_occurrences for w, c in most_common_words.most_common(50)]
)

# put words on xlabel
plt.xticks(
    range(50),
    [w for w, c in most_common_words.most_common(50)],
    rotation=80,
    fontsize=9,
)
plt.ylabel("Word frequency")
plt.title("Word frequency in the dataset")
plt.show()

In [None]:
# @title
from itertools import accumulate

cum_word_occurences = list(
    accumulate([count for word, count in most_common_words.most_common(n_unique_words)])
)

plt.figure(figsize=(10, 5))
plt.plot(cum_word_occurences)

plt.xlabel("Rank of the word")
plt.ylabel("Number of occurrences")
plt.title("Cumulative number of occurrences of the most common words")
plt.show()
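
The cumulative curve can also be read numerically. The sketch below uses a hypothetical helper, `words_needed_for_coverage` (not part of the repository), to count how many of the most frequent words are needed to cover a given fraction of all word occurrences. With the notebook's data you would pass `most_common_words` instead of the toy counter.

```python
from collections import Counter
from itertools import accumulate


def words_needed_for_coverage(counts: Counter, fraction: float) -> int:
    """Smallest number of most-frequent words whose occurrences reach `fraction` of all tokens."""
    total = sum(counts.values())
    cumulative = accumulate(count for _, count in counts.most_common())
    for rank, running_total in enumerate(cumulative, start=1):
        if running_total >= fraction * total:
            return rank
    return len(counts)


# Toy corpus: 10 tokens, "the" alone accounts for 4 of them
toy_counts = Counter("the the the the of of and invoice total date".split())
print(words_needed_for_coverage(toy_counts, 0.5))  # 2
```

A small result for a high fraction (say 0.8) is another way of seeing the Zipf-like shape of the previous plot.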

# Classification with Scikit-learn

In this part, we will train simple classification algorithms using the scikit-learn library.
The following code defines the training samples we will use.

You can try modifying it to further clean the data, for example with the nltk library.


In [None]:
import nltk
import sklearn


@dataclass
class TextSample:
    text: str
    label: int

    def __init__(self, document: dict):
        self.text = " ".join(
            [word for sentence in document["words"] for word in sentence.split()]
        )
        self.label = document["label"]


train_samples = [TextSample(doc) for doc in train_dataset]

test_samples = [TextSample(doc) for doc in test_dataset]

validation_samples = [TextSample(doc) for doc in validation_dataset]
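
As a starting point for cleaning, here is a minimal sketch that needs no download. `clean_text` and the tiny `STOP_WORDS` set are hypothetical and only illustrative; nltk's `nltk.corpus.stopwords.words("english")` (after `nltk.download("stopwords")`) provides a fuller list. You could, for instance, apply such a function to the joined text inside `TextSample.__init__`.

```python
import re

# Tiny illustrative stop-word list (an assumption, not nltk's list)
STOP_WORDS = {"the", "a", "an", "of", "and", "to", "in", "for", "on", "is"}


def clean_text(text: str, min_len: int = 2) -> str:
    """Lowercase, keep alphabetic tokens only, drop stop words and very short tokens."""
    tokens = re.findall(r"[a-z]+", text.lower())
    return " ".join(t for t in tokens if t not in STOP_WORDS and len(t) >= min_len)


print(clean_text("INVOICE No. 42 - Total of the order: $10.50"))  # invoice no total order
```

Dropping digits is a design choice worth testing: invoice numbers and amounts may actually be useful signals for some classes.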


## Tokenization and Vectorization

To train models on this problem, we need to convert texts into vectors that represent our documents.
Take a look at scikit-learn's [CountVectorizer](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html#) and [TfidfVectorizer](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html).
First fit a vectorizer on the training set, then apply the vectorization transformation to each dataset split.

What are the shapes of the resulting vectors? What does each dimension mean?
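To make the bag-of-words idea concrete, here is a plain-Python sketch of what a count vectorizer computes. It mimics, as far as we know, `CountVectorizer`'s defaults (lowercasing, the `(?u)\b\w\w+\b` token pattern, an alphabetically sorted vocabulary) but returns a dense list of lists, whereas scikit-learn returns a sparse matrix. `fit_vocabulary` and `transform` are hypothetical names.

```python
import re

# CountVectorizer's default token pattern: runs of 2+ word characters
TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def fit_vocabulary(texts: list[str]) -> dict[str, int]:
    """Map each training token to a column index (alphabetical order)."""
    tokens = {tok for text in texts for tok in TOKEN_PATTERN.findall(text.lower())}
    return {tok: idx for idx, tok in enumerate(sorted(tokens))}


def transform(texts: list[str], vocabulary: dict[str, int]) -> list[list[int]]:
    """Dense (n_documents, vocabulary_size) count matrix; unseen tokens are ignored."""
    matrix = [[0] * len(vocabulary) for _ in texts]
    for row, text in zip(matrix, texts):
        for tok in TOKEN_PATTERN.findall(text.lower()):
            if tok in vocabulary:
                row[vocabulary[tok]] += 1
    return matrix


train_texts = ["Invoice total due", "Dear John, see the invoice"]
vocab = fit_vocabulary(train_texts)
print(vocab)  # {'dear': 0, 'due': 1, 'invoice': 2, 'john': 3, 'see': 4, 'the': 5, 'total': 6}
print(transform(["invoice invoice receipt"], vocab))  # [[0, 0, 2, 0, 0, 0, 0]]
```

Note that "receipt" is silently dropped: the vocabulary is learned on the training texts only, which is why the vectorizer must be fitted on the training split and merely applied to the other splits.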


In [None]:
### Insert your code here ###
# See the expected solution by clicking on the cell below

In [None]:
# @title

from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform([sample.text for sample in train_samples])
X_test = vectorizer.transform([sample.text for sample in test_samples])
X_validation = vectorizer.transform([sample.text for sample in validation_samples])

Y_train = [sample.label for sample in train_samples]
Y_test = [sample.label for sample in test_samples]
Y_validation = [sample.label for sample in validation_samples]

X_train.shape, X_test.shape, X_validation.shape
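# Each matrix has shape (n_documents, vocabulary_size): one row per document of the split,
# one column per token of the vocabulary learned on the *training* set only.
# The value at (i, j) is the number of occurrences of the j-th vocabulary token in the i-th document
# (the matrices are sparse, since most tokens do not appear in a given document).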
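
`TfidfTransformer` is imported above but not used; in scikit-learn it can be applied to the count matrices, e.g. `TfidfTransformer().fit_transform(X_train)`. The sketch below reproduces, in plain Python, what we understand to be its documented defaults: smoothed IDF, `idf(t) = ln((1 + n) / (1 + df(t))) + 1`, multiplied by the raw counts, then L2 normalization of each row. The `tfidf` function is hypothetical and works on dense lists.

```python
import math


def tfidf(counts: list[list[int]]) -> list[list[float]]:
    """Smoothed TF-IDF with L2 row normalisation on a dense count matrix."""
    n_docs = len(counts)
    n_terms = len(counts[0]) if counts else 0
    # Document frequency: in how many documents each term appears at least once
    df = [sum(1 for row in counts if row[j] > 0) for j in range(n_terms)]
    idf = [math.log((1 + n_docs) / (1 + df[j])) + 1 for j in range(n_terms)]

    weighted = []
    for row in counts:
        values = [tf * w for tf, w in zip(row, idf)]
        norm = math.sqrt(sum(v * v for v in values)) or 1.0  # avoid dividing by zero
        weighted.append([v / norm for v in values])
    return weighted


# Two documents, three terms: term 0 appears in both, terms 1 and 2 in one each
counts = [[1, 1, 0], [1, 0, 2]]
for row in tfidf(counts):
    print([round(v, 3) for v in row])
```

In the first document, term 1 ends up with a higher weight than term 0 even though both occur once, because term 0 appears in every document and therefore carries less information.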

## Basic Model: Scikit-Learn Classification

Use any Scikit-Learn classification model to train a first text model.
Good first picks: [Support Vector Classifier](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html) or [Random Forest](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html).


In [None]:
### Insert your code here ###
# See the expected solution by clicking on the cell below

In [None]:
# @title

from sklearn.svm import SVC

model = SVC(kernel="linear")
model.fit(X_train, Y_train)

## Evaluate the model

Use scikit-learn [metrics](https://scikit-learn.org/stable/modules/model_evaluation.html) to evaluate your model.


In [None]:
### Insert your code here ###
# See the expected solution by clicking on the cell below

In [None]:
# @title

from sklearn.metrics import accuracy_score, confusion_matrix

print("Test")
Y_pred = model.predict(X_test)
accuracy = accuracy_score(Y_test, Y_pred)
print(f"Accuracy on the test set: {accuracy:.2f}")
print(confusion_matrix(Y_test, Y_pred))

print("Validation")
Y_pred = model.predict(X_validation)
accuracy = accuracy_score(Y_validation, Y_pred)
print(f"Accuracy on the validation set: {accuracy:.2f}")
print(confusion_matrix(Y_validation, Y_pred))
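
To read the confusion matrices printed above: in scikit-learn's convention, rows are true classes and columns are predicted classes, so the diagonal counts correct predictions. The plain-Python sketch below (hypothetical helpers `confusion` and `per_class_recall`, toy labels) rebuilds such a matrix and derives per-class recall from it; `sklearn.metrics.classification_report` gives precision, recall and F1 directly.

```python
def confusion(y_true: list[int], y_pred: list[int], n_classes: int) -> list[list[int]]:
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_recall(matrix: list[list[int]]) -> list[float]:
    """Fraction of each true class that was predicted correctly (diagonal / row sum)."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion(y_true, y_pred, n_classes=3)
print(cm)  # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(per_class_recall(cm))  # [0.5, 1.0, 0.5]
```

Per-class numbers are often more informative than accuracy alone, since they show which document types the model confuses.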

# Transformers

Done playing with kid toys.

Most modern AI models, in language processing and beyond, are built on the [Transformer architecture](https://arxiv.org/pdf/1706.03762). The original research paper is one of the most influential of the last decade.


In [None]:
import torch
import transformers
from torch import nn
from torch.utils import data

## Tokenization

Transformers usually rely on subword tokenizers, i.e. a single word _can_ be split into multiple tokens.
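
The idea behind WordPiece, the subword scheme used by BERT-style tokenizers such as LayoutLM's, can be sketched as greedy longest-match-first splitting: repeatedly take the longest prefix of the remaining characters that is in the vocabulary, marking non-initial pieces with `##`. This is a simplified sketch: the real tokenizer also lowercases and splits on whitespace and punctuation first, and `toy_vocab` is a made-up vocabulary.

```python
def wordpiece(word: str, vocab: set[str], unk: str = "[UNK]") -> list[str]:
    """Greedy longest-match-first subword split; continuation pieces carry a '##' prefix."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest candidate first, then shrink it one character at a time
        while start < end:
            candidate = word[start:end] if start == 0 else "##" + word[start:end]
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no prefix of the remainder is in the vocabulary
        pieces.append(piece)
        start = end
    return pieces


toy_vocab = {"token", "##ize", "##s", "hello"}
print(wordpiece("tokenize", toy_vocab))  # ['token', '##ize']
print(wordpiece("tokens", toy_vocab))  # ['token', '##s']
print(wordpiece("xyz", toy_vocab))  # ['[UNK]']
```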


In [None]:
# Let's use the LayoutLM tokenizer first

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "microsoft/layoutlm-base-uncased"
)

In [None]:
encoding = tokenizer("Hello, world! I can tokenize any sentence.")

for token_id in encoding["input_ids"]:
    print(tokenizer.decode(token_id))

# Note how `tokenize` is encoded as `token ##ize`

## Dataset for LayoutLM

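LayoutLM uses both textual and 2D positional information (where each word sits on the page). Here is a new data sample class that keeps one bounding box per word. Since each OCR text segment comes with a single box, the segment box is split into approximate word boxes.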


In [None]:
def split_sentence_into_words(
    sentence: str, sentence_box: tuple[int, int, int, int]
) -> tuple[list[str], list[tuple[int, int, int, int]]]:
    """Split an OCR text segment into words and estimate one box per word.

    The segment's width is shared out proportionally to the number of characters,
    assuming every character (including spaces) has roughly the same width.
    """
    words = sentence.split()
    ret_boxes = []

    box_width = sentence_box[2] - sentence_box[0]

    word_left = sentence_box[0]
    for word in words:
        word_right = word_left + int(len(word) * box_width / len(sentence))
        ret_boxes.append((word_left, sentence_box[1], word_right, sentence_box[3]))
        # Skip the width of one space character before the next word
        word_left = word_right + int(box_width / len(sentence))

    return words, ret_boxes


@dataclass
class TextBoxSample:
    words: list[str]
    boxes: list[tuple[int, int, int, int]]  # (left, top, right, bottom)
    label: int

    def __init__(self, document: dict):
        self.words = []
        self.boxes = []

        # We need to split the words in the sentences and compute the bounding boxes for each word
        for sentence, sentence_box in zip(document["words"], document["boxes"]):
            new_words, new_boxes = split_sentence_into_words(sentence, sentence_box)
            self.words.extend(new_words)
            self.boxes.extend(new_boxes)

        self.label = document["label"]


train_samples = [TextBoxSample(doc) for doc in train_dataset]

test_samples = [TextBoxSample(doc) for doc in test_dataset]

validation_samples = [TextBoxSample(doc) for doc in validation_dataset]
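
A worked example of the character-proportional split performed by `split_sentence_into_words`, with made-up values: the segment "Hello world" (11 characters) spans 110 pixels, so each character gets 10 pixels.

```python
# Worked example of the character-proportional split (values are made up)
sentence = "Hello world"
left, top, right, bottom = 100, 10, 210, 20
char_width = (right - left) / len(sentence)  # 110 / 11 = 10 per character

boxes = []
word_left = left
for word in sentence.split():
    word_right = word_left + int(len(word) * char_width)  # 5 characters -> 50 wide
    boxes.append((word_left, top, word_right, bottom))
    word_left = word_right + int(char_width)  # skip one space

print(boxes)  # [(100, 10, 150, 20), (160, 10, 210, 20)]
```

The estimate is only approximate with proportional fonts, but it keeps every word inside its segment box, which is what matters most for the layout embeddings.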

Let's implement the PyTorch dataset that will hold those samples. Keep it very simple: most of the computation is deferred to the batching function.


In [None]:
### Modify the code here ###
# See the expected solution by clicking on the cell below

class DocumentTextBoxDataset(data.Dataset):
    def __init__(self, samples: list[TextBoxSample]):
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, idx: int) -> TextBoxSample:
        raise NotImplementedError

In [None]:
# @title

class DocumentTextBoxDataset(data.Dataset):
    def __init__(self, samples: list[TextBoxSample]):
        self.samples = samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx) -> TextBoxSample:
        return self.samples[idx]

## Batching function

As in the vision part, we need to implement a batching function that groups multiple samples together and prepares them to be fed to the model.

Hugging Face `transformers` tokenizers can tokenize a whole batch at once and perform most of the computation (padding, truncation, conversion to tensors) for us.
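
To see where the batching function fits, here is a simplified, plain-Python picture of how a `DataLoader` iterates over a map-style dataset: it uses `__len__` to build the list of indices, optionally shuffles them, fetches samples with `__getitem__`, and hands each list of samples to `collate_fn`. This is a sketch of the idea, not PyTorch's implementation (no workers, no prefetching, no samplers); `simple_batches` and `toy_collate` are hypothetical.

```python
import random
from typing import Callable, Iterator, Sequence


def simple_batches(
    dataset: Sequence,
    batch_size: int,
    collate_fn: Callable[[list], object],
    shuffle: bool = False,
    seed: int = 0,
) -> Iterator[object]:
    """Simplified DataLoader loop over a map-style dataset."""
    indices = list(range(len(dataset)))  # relies on __len__
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        samples = [dataset[i] for i in indices[start : start + batch_size]]  # __getitem__
        yield collate_fn(samples)  # turn a list of samples into one batch


# Toy collate function over (words, label) pairs
def toy_collate(samples: list[tuple[list[str], int]]) -> dict:
    return {
        "n_words": [len(words) for words, _ in samples],
        "labels": [label for _, label in samples],
    }


toy_dataset = [(["a", "b"], 0), (["c"], 1), (["d", "e", "f"], 2)]
for batch in simple_batches(toy_dataset, batch_size=2, collate_fn=toy_collate):
    print(batch)
# {'n_words': [2, 1], 'labels': [0, 1]}
# {'n_words': [3], 'labels': [2]}
```

The last batch can be smaller than `batch_size`, which is worth keeping in mind when averaging metrics over batches.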


In [None]:
# The LayoutLM tokenizer does not handle bounding boxes, so we use the LayoutLMv2 tokenizer instead.
# Otherwise we would have to map word bounding boxes to tokens ourselves,
# which is tricky because a single word can be split into multiple tokens.

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "microsoft/layoutlmv2-base-uncased"
)

# Usage example (batched inputs are supported too)
tokenizer(
    text=train_samples[0].words,
    boxes=train_samples[0].boxes,
    padding="max_length",
    truncation=True,
)
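
The mapping mentioned in the comment above can be sketched in plain Python: every subword token inherits the box of the word it comes from, special tokens get fixed boxes, and the sequence is truncated and padded with an attention mask. This is a hypothetical illustration (`align_and_pad` and `toy_tokenize` are not real APIs); the special-token boxes are, to our knowledge, the LayoutLMv2 tokenizer's defaults for `cls_token_box`, `sep_token_box` and `pad_token_box`, but check the documentation before relying on them.

```python
from typing import Callable

Box = tuple[int, int, int, int]
CLS_BOX: Box = (0, 0, 0, 0)
SEP_BOX: Box = (1000, 1000, 1000, 1000)
PAD_BOX: Box = (0, 0, 0, 0)


def align_and_pad(
    words: list[str],
    boxes: list[Box],
    subword_tokenize: Callable[[str], list[str]],
    max_len: int,
) -> tuple[list[str], list[Box], list[int]]:
    """Give every subword token its word's box, add [CLS]/[SEP], truncate, then pad."""
    tokens, token_boxes = [], []
    for word, box in zip(words, boxes):
        pieces = subword_tokenize(word)
        tokens.extend(pieces)
        token_boxes.extend([box] * len(pieces))  # all pieces of a word share its box

    # Truncate so that [CLS] and [SEP] still fit (assumes max_len >= 2)
    tokens = ["[CLS]"] + tokens[: max_len - 2] + ["[SEP]"]
    token_boxes = [CLS_BOX] + token_boxes[: max_len - 2] + [SEP_BOX]
    attention_mask = [1] * len(tokens)

    n_pad = max_len - len(tokens)
    tokens += ["[PAD]"] * n_pad
    token_boxes += [PAD_BOX] * n_pad
    attention_mask += [0] * n_pad  # padding positions are ignored by attention
    return tokens, token_boxes, attention_mask


# Hypothetical tokenizer reproducing the "token ##ize" split seen earlier
def toy_tokenize(word: str) -> list[str]:
    return ["token", "##ize"] if word == "tokenize" else [word]


tokens, token_boxes, mask = align_and_pad(
    ["please", "tokenize"], [(10, 10, 60, 20), (70, 10, 150, 20)], toy_tokenize, max_len=6
)
print(tokens)  # ['[CLS]', 'please', 'token', '##ize', '[SEP]', '[PAD]']
print(mask)  # [1, 1, 1, 1, 1, 0]
```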

This is the definition we will use for a batch of samples:


In [None]:
@dataclass
class TextBoxBatch:
    words: torch.LongTensor  # (batch_size, max_seq_len)
    boxes: torch.LongTensor  # (batch_size, max_seq_len, 4)
    labels: torch.LongTensor  # (batch_size)
    token_type_ids: torch.LongTensor  # (batch_size, max_seq_len)
    attention_mask: torch.LongTensor  # (batch_size, max_seq_len)

    def to(self, device: str):
        self.words = self.words.to(device)
        self.boxes = self.boxes.to(device)
        self.labels = self.labels.to(device)
        self.token_type_ids = self.token_type_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        return self

    def __post_init__(self):
        # LayoutLM expects box coordinates normalized to [0, 1000]
        self.boxes.clamp_(min=0, max=1000)

        # Make sure left <= right and top <= bottom (invalid boxes collapse to zero width/height)
        self.boxes[:, :, 0] = torch.minimum(self.boxes[:, :, 0], self.boxes[:, :, 2])
        self.boxes[:, :, 1] = torch.minimum(self.boxes[:, :, 1], self.boxes[:, :, 3])

        # The tokenizer pads/truncates every sequence to 512 tokens (padding="max_length")
        batch_size = self.words.shape[0]
        assert self.words.shape == (batch_size, 512)
        assert self.boxes.shape == (batch_size, 512, 4)
        assert self.labels.shape == (batch_size,)
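
Why repair the boxes at all? As far as we know, LayoutLM looks up an embedding for each coordinate and also for the box width and height, so out-of-range coordinates or a negative width/height would index outside its embedding tables. The repair in `__post_init__` amounts to the following per-box rule, written here as a hypothetical plain-Python helper `fix_box` on made-up boxes.

```python
def fix_box(box: tuple[int, int, int, int]) -> tuple[int, int, int, int]:
    """Plain-Python version of the repair done in TextBoxBatch.__post_init__."""
    left, top, right, bottom = (min(max(v, 0), 1000) for v in box)  # clamp to [0, 1000]
    return (min(left, right), min(top, bottom), right, bottom)


print(fix_box((-5, 20, 1200, 40)))  # (0, 20, 1000, 40)
print(fix_box((300, 50, 200, 40)))  # (200, 40, 200, 40)
```

The second box shows the trade-off: an inverted box is not swapped but collapsed to zero width and height, which loses its extent but guarantees valid inputs.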

Implement the batching function that converts a list of `TextBoxSample` to a `TextBoxBatch`. Use the tokenizer to tokenize words and boxes.


In [None]:
### Modify the code here ###
# See the expected solution by clicking on the cell below

def collate_fn(
    samples: list[TextBoxSample],
    tokenizer: transformers.LayoutLMv2Tokenizer = tokenizer,
) -> TextBoxBatch:
    # Implement the collate_fn function
    pass

In [None]:
# @title

def collate_fn(
    samples: list[TextBoxSample],
    tokenizer: transformers.LayoutLMv2Tokenizer = tokenizer,
) -> TextBoxBatch:
    encodings = tokenizer(
        text=[sample.words for sample in samples],
        boxes=[sample.boxes for sample in samples],
        padding="max_length",
        truncation=True,
        return_tensors="pt",  # return PyTorch tensors
    )
    encodings["labels"] = torch.tensor(
        [sample.label for sample in samples], dtype=torch.long
    )

    return TextBoxBatch(
        words=encodings["input_ids"],
        boxes=encodings["bbox"],
        labels=encodings["labels"],
        token_type_ids=encodings["token_type_ids"],
        attention_mask=encodings["attention_mask"],
    )

In [None]:
# If you got it right, this should work properly

collate_fn(train_samples[:12])

## Model - LayoutLM

The `transformers` library provides both model code and pretrained weights. We will use the weights of a [model fine-tuned](https://huggingface.co/gurvgupta/LayoutLM_rvl-cdip) on RVL-CDIP, available on the Hugging Face Hub.
Let's first download its weights and rename the checkpoint file to `pytorch_model.bin`, the file name `from_pretrained` looks for, so that the model weights can be loaded.

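This pre-trained model was already fine-tuned on a superset of our dataset (RVL-CDIP has 16 document classes).
We will still fine-tune it for a few epochs because our label set is different: with `ignore_mismatched_sizes=True`, the original classification head is replaced by a freshly initialized one with `NUM_CLASSES` outputs.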


In [None]:
if not path.exists("LayoutLM_rvl-cdip"):
    !git lfs install
    !git clone https://huggingface.co/gurvgupta/LayoutLM_rvl-cdip
    !mv LayoutLM_rvl-cdip/LayoutLM_rvl-cdip_epoch_50.pt LayoutLM_rvl-cdip/pytorch_model.bin

In [None]:
from transformers.models.layoutlm import LayoutLMForSequenceClassification

model = LayoutLMForSequenceClassification.from_pretrained(
    "./LayoutLM_rvl-cdip", num_labels=NUM_CLASSES, ignore_mismatched_sizes=True
)
model

The model can be used as follows; observe its inputs and output type.


In [None]:
# The model can be used like this

batch = collate_fn(train_samples[:2])
model(
    input_ids=batch.words,
    bbox=batch.boxes,
    token_type_ids=batch.token_type_ids,
    attention_mask=batch.attention_mask,
)
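
The output's `logits` are unnormalized class scores, one per class. To turn them into probabilities and a predicted class name, apply a softmax and take the argmax. The sketch below uses made-up logits for a single document and assumes label `i` corresponds to `class_names[i]`; with real outputs, `logits.softmax(dim=-1)` and `logits.argmax(dim=-1)` do the same on tensors.

```python
import math

class_names = ["email", "form", "handwritten", "invoice", "advertisement"]


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [0.1, -1.2, 0.3, 2.5, -0.4]  # made-up output for one document
probs = softmax(logits)
predicted = max(range(len(probs)), key=probs.__getitem__)
print(class_names[predicted], round(probs[predicted], 3))
```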

## Train the model

First, copy the training loop from the previous notebook and adapt it to the new batch format and to the model's output: the model returns an object whose `.logits` field holds the class scores.
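
The loss takes those logits directly. With default arguments, `nn.CrossEntropyLoss` computes, for each sample, the negative log-probability of the true class under a softmax and averages over the batch. A plain-Python sketch of that computation (hypothetical `cross_entropy` helper, toy inputs), using the log-sum-exp trick for numerical stability:

```python
import math


def cross_entropy(batch_logits: list[list[float]], labels: list[int]) -> float:
    """Mean of -log softmax(logits)[label], computed with the log-sum-exp trick."""
    losses = []
    for logits, label in zip(batch_logits, labels):
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        losses.append(log_sum_exp - logits[label])
    return sum(losses) / len(losses)


# Uniform logits over 5 classes give a loss of ln(5), whatever the label
print(round(cross_entropy([[0.0] * 5], [2]), 4))  # 1.6094
```

A freshly initialized 5-class head should start with a loss around ln(5) ≈ 1.61, which is a handy sanity check for the first training steps.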


In [None]:
### Insert your code here ###
# See the expected solution by clicking on the cell below

In [None]:
# @title
# Copied from `chapter_1_vision.ipynb`

def train_one_epoch(
    model: nn.Module,
    dataloader: data.DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,  # type: ignore
    device: torch.device,
) -> float:
    """This function should train the model for one epoch and return the average loss"""
    model.train()
    model.to(device)

    epoch_loss = 0.0
    with tqdm.tqdm(desc="Training", total=len(dataloader)) as pbar:
        for i, batch in enumerate(dataloader):
            batch.to(device)
            words, boxes, labels, token_type_ids, attention_mask = (
                batch.words,
                batch.boxes,
                batch.labels,
                batch.token_type_ids,
                batch.attention_mask,
            )

            optimizer.zero_grad()  # Reset gradients
            outputs = model(
                input_ids=words,
                bbox=boxes,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            ).logits  # Compute model's predictions

            loss = loss_fn(outputs, labels)  # Compute the loss

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            pbar.set_postfix(loss=epoch_loss / (i + 1))
            pbar.update(1)
    mean_loss = epoch_loss / len(dataloader)
    print(f"Training loss (↓): {mean_loss:.4f}")
    return mean_loss


def evaluate(
    model: nn.Module,
    dataloader: data.DataLoader,
    loss_fn: nn.Module,
    metric_fn: nn.Module,
    device: torch.device,
    dataset_name: str = "validation",
) -> tuple[float, float]:
    """This function should evaluate the model on the dataset and return the average loss and metric"""
    model.eval()
    model.to(device)

    epoch_loss = 0.0
    epoch_metric = 0.0
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader, desc="Evaluation"):
            batch.to(device)
            words, boxes, labels, token_type_ids, attention_mask = (
                batch.words,
                batch.boxes,
                batch.labels,
                batch.token_type_ids,
                batch.attention_mask,
            )

            outputs = model(
                input_ids=words,
                bbox=boxes,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            ).logits  # Compute model's predictions

            loss = loss_fn(outputs, labels)
            metric = metric_fn(outputs.argmax(dim=-1), labels)

            epoch_loss += loss.item()
            epoch_metric += metric.item()

        mean_loss = epoch_loss / len(dataloader)
        print(f"{dataset_name.capitalize()} loss (↓): {mean_loss:.4f}")
        mean_metric = epoch_metric / len(dataloader)
        print(f"{dataset_name.capitalize()} metric (↑): {mean_metric:.4f}")
        return mean_loss, mean_metric


def train(
    model: nn.Module,
    train_dataloader: data.DataLoader,
    validation_dataloader: data.DataLoader,
    loss_fn: nn.Module,
    metric_fn: nn.Module,
    optimizer: torch.optim.Optimizer,  # type: ignore
    device: torch.device,
    n_epochs: int = 10,
) -> tuple[list[float], list[float], list[float]]:
    """This function should train the model for some epochs and return the training and validation losses"""
    train_losses = []
    validation_losses = []
    validation_metrics = []

    for epoch in range(n_epochs):
        print(f"Epoch {epoch + 1}/{n_epochs}")
        train_loss = train_one_epoch(
            model, train_dataloader, loss_fn, optimizer, device
        )
        train_losses.append(train_loss)

        validation_loss, validation_metric = evaluate(
            model, validation_dataloader, loss_fn, metric_fn, device
        )
        validation_losses.append(validation_loss)
        validation_metrics.append(validation_metric)

    return train_losses, validation_losses, validation_metrics
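
One subtlety in `evaluate`: averaging per-batch accuracies weights every batch equally, so when the last batch is smaller the result can differ from the accuracy over all samples. A made-up illustration with three batches of sizes 8, 8 and 2:

```python
# Made-up evaluation run: three batches of sizes 8, 8 and 2
batch_sizes = [8, 8, 2]
batch_correct = [8, 4, 0]
batch_accuracies = [c / n for c, n in zip(batch_correct, batch_sizes)]

mean_of_batches = sum(batch_accuracies) / len(batch_accuracies)  # what `evaluate` computes
sample_accuracy = sum(batch_correct) / sum(batch_sizes)  # accuracy over all samples

print(mean_of_batches)  # 0.5
print(sample_accuracy)  # 0.6666666666666666
```

The gap is usually small on real validation sets. torchmetrics metrics also accumulate state across calls, so calling `metric_fn.compute()` after the loop (and `metric_fn.reset()` before the next evaluation) should give the dataset-level value; treat this as a suggestion to check against the torchmetrics documentation rather than a tested change.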

In [None]:
import torchmetrics

train_loader = data.DataLoader(
    DocumentTextBoxDataset(train_samples),
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
)
validation_loader = data.DataLoader(
    DocumentTextBoxDataset(validation_samples),
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=False,
)

device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Uncomment to force CPU execution (much slower, but useful for debugging):
# device = "cpu"

selected_model = model

optimizer = torch.optim.Adam(selected_model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()
metric_fn = torchmetrics.Accuracy(task="multiclass", num_classes=NUM_CLASSES).to(device)

n_epochs = 2

hist = train(
    selected_model,
    train_loader,
    validation_loader,
    loss_fn,
    metric_fn,
    optimizer,
    device,
    n_epochs=n_epochs,
)

Do not hesitate to compare the performance of the models from the first and second notebooks.