<a href="https://colab.research.google.com/github/clemsage/NeuralDocumentClassification/blob/master/skeleton_ocr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Training a classifier on OCR text input


# Imports & Cloning repository


In [56]:
import os
import pickle
import sys
from dataclasses import dataclass
from os import path

import matplotlib.pyplot as plt
import tqdm


In [3]:
# Human-readable names of the document categories handled in this notebook.
class_names = [
    "email",
    "form",
    "handwritten",
    "invoice",
    "advertisement",
]
# Derived from the list so the two can never disagree.
NUM_CLASSES = len(class_names)

In [None]:
# Get the course repository: clone it on the first run, otherwise update the
# existing checkout in place (`!` lines are IPython shell escapes).
if not os.path.exists("NeuralDocumentClassification"):
    !git clone https://github.com/thibaultdouzon/NeuralDocumentClassification.git
else:
    !git -C NeuralDocumentClassification pull
# Put the checkout on the import path so its top-level packages can be imported.
sys.path.append("NeuralDocumentClassification")

In [4]:
# NOTE(review): `src` is presumably the package inside the cloned repository
# that was appended to sys.path — confirm.
from src import download_dataset

# Directory handed to the downloader; the pickled splits are read from it below.
dataset_path = "dataset"

# NOTE(review): "all" presumably selects every available split — confirm in download_dataset.
download_dataset.download_and_extract("all", dataset_path)

In [None]:
# SECURITY: pickle.load can execute arbitrary code stored in the file; only
# load splits that come from a source you trust.
_splits = {}
for split_name in ("train", "test", "validation"):
    with open(path.join(dataset_path, f"{split_name}.pkl"), "rb") as f:
        _splits[split_name] = pickle.load(f)

train_dataset = _splits["train"]
test_dataset = _splits["test"]
validation_dataset = _splits["validation"]

# Quick sanity check: size of every split, then the keys of one document.
for split_name, split_dataset in _splits.items():
    print(f"{split_name}_dataset contains {len(split_dataset)} documents")
train_dataset[0].keys()


Each `dataset` object is a `list` of documents. A document is a `dict` with the following structure:

```json
{
  "id": "Unique document identifier",
  "image": "A PIL.Image object containing the document's image",
  "label": "A number in [0 .. 4] representing the class of the document",
  "words": "A list of strings (not words !) extracted from the image with an OCR",
  "boxes": "A list of tuples of numbers providing the position of each word in the document"
}
```


# Explore the data

Take the time to explore the textual data included in the dataset.


Ideas

- 10 most common words? (hint: Counter)
- Count number of unique words
- Distribution of words (cumulative occurrences plot)


In [None]:
# @title

from collections import Counter

# Every document of every split, as one flat list of whitespace-separated tokens.
corpus = validation_dataset + test_dataset + train_dataset
all_texts = [" ".join(doc["words"]).split() for doc in corpus]

# Number of occurrences of each token over the whole corpus.
most_common_words = Counter(token for tokens in all_texts for token in tokens)
most_common_words.most_common(10)

In [None]:
# @title

# Vocabulary size: number of distinct tokens found across all documents.
n_unique_words = len(set().union(*all_texts))
n_unique_words

In [None]:
# @title

# Zipf's law

# Relative frequency of the most common words, ordered by rank.
top_n = 50
# Computed once; may hold fewer than `top_n` entries on a small vocabulary.
top_words = most_common_words.most_common(top_n)
# Hoisted out of the comprehension: the total is the same for every word.
total_count = sum(most_common_words.values())

plt.figure(figsize=(10, 5))
plt.plot([count / total_count for _, count in top_words])

# put words on xlabel (tick count follows the number of words actually plotted)
plt.xticks(
    range(len(top_words)),
    [word for word, _ in top_words],
    rotation=80,
    fontsize=9,
)
plt.ylabel("Word frequency")
plt.title("Word frequency in the dataset")
plt.show()

In [None]:
# @title
from itertools import accumulate

# Running total of occurrences when walking the vocabulary from the most
# frequent word to the least frequent one.
ranked_counts = (
    count for _, count in most_common_words.most_common(n_unique_words)
)
cum_word_occurences = list(accumulate(ranked_counts))

plt.figure(figsize=(10, 5))
plt.plot(cum_word_occurences)

plt.title("Cumulative number of occurences of the most common words")
plt.xlabel("Rank of the word")
plt.ylabel("Number of occurences")
plt.show()

# Classification with Scikit Learn


In [6]:
import nltk
import sklearn


@dataclass
class TextSample:
    """A document reduced to its OCR text and its class label."""

    text: str
    label: int

    def __init__(self, document: dict):
        """Build the sample from a dataset document.

        Args:
            document: Mapping with a "words" entry (iterable of OCR strings)
                and a "label" entry.
        """
        tokens: list[str] = []
        for sentence in document["words"]:
            tokens.extend(sentence.split())
        # Re-joining with single spaces collapses any run of whitespace.
        self.text = " ".join(tokens)
        self.label = document["label"]


# Wrap every document of each split into a TextSample (train, test, validation).
train_samples, test_samples, validation_samples = (
    [TextSample(doc) for doc in split]
    for split in (train_dataset, test_dataset, validation_dataset)
)


## Tokenization and Vectorization

To train models to solve our problem, we need to convert texts into vectors that represent our documents.
Take a look at Scikit Learn [CountVectorizer](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html#) and [TFIDFVectorizer](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html).
First fit a vectorizer on the training set, then apply the vectorization transformation to each dataset split.

What are the shapes of the resulting vectors? What does each dimension mean?


In [None]:
# @title

from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer

train_texts = [sample.text for sample in train_samples]
test_texts = [sample.text for sample in test_samples]
validation_texts = [sample.text for sample in validation_samples]

# The vocabulary is learned on the training split only, then reused as-is
# to encode the two other splits.
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train_texts)
X_test = vectorizer.transform(test_texts)
X_validation = vectorizer.transform(validation_texts)

Y_train = [sample.label for sample in train_samples]
Y_test = [sample.label for sample in test_samples]
Y_validation = [sample.label for sample in validation_samples]


X_train.shape, X_test.shape, X_validation.shape
# First dimension: number of documents in the split.
# Second dimension: size of the vocabulary learned from the training split.
# The value at (i, j) is the number of occurrences of the j-th vocabulary word in the i-th document.

## Basic Model: Scikit-Learn Classification

Use any Scikit-Learn classification model to train a first text model.
Good first picks: [Support Vector Classifier](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html) or [Random Forest](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)


In [None]:
# @title

from sklearn.svm import SVC

# Support vector classifier with a linear kernel, fitted on the training
# count vectors and their labels.
model = SVC(kernel="linear")
model.fit(X=X_train, y=Y_train)


## Evaluate the model

Use Scikit-Learn [metrics](https://scikit-learn.org/stable/modules/model_evaluation.html) to evaluate your model


In [None]:
# @title

from sklearn.metrics import accuracy_score, confusion_matrix

# Same report for both held-out splits: header, accuracy, confusion matrix.
for split_name, X_split, Y_split in (
    ("test", X_test, Y_test),
    ("validation", X_validation, Y_validation),
):
    print(split_name.capitalize())
    Y_pred = model.predict(X_split)
    accuracy = accuracy_score(Y_split, Y_pred)
    print(f"Accuracy on the {split_name} set: {accuracy:.2f}")
    print(confusion_matrix(Y_split, Y_pred))


# Transformers

Done playing with kids' toys.

All modern AI models use the [Transformer architecture](https://arxiv.org/pdf/1706.03762). The initial research paper is one of the most influential of the last decade.


In [86]:
import torch
import transformers
from torch import nn
from torch.utils import data

## Tokenization

Transformers usually use subword tokenizers, i.e. a word _can_ be tokenized into multiple tokens.


In [None]:
# Let's use LayoutLM tokenizer first

# Load the tokenizer published under the "microsoft/layoutlm-base-uncased" model id.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "microsoft/layoutlm-base-uncased"
)

In [None]:
encoding = tokenizer("Hello, world! I can tokenize any sentence.")

# Decode the ids one by one so that each sub-word token shows up on its own line.
for piece in map(tokenizer.decode, encoding["input_ids"]):
    print(piece)

# Note how `tokenize` is encoded as `token ##ize`

## Dataset for LayoutLM

LayoutLM uses both textual and 2D positional information, here is a new data sample class to work with


In [75]:
def split_sentence_into_words(
    sentence: str, sentence_box: tuple[int, int, int, int]
) -> tuple[list[str], list[tuple[int, int, int, int]]]:
    """Split an OCR string into words and give each word its share of the box.

    The horizontal extent of `sentence_box` is distributed proportionally to
    character positions: a word covering characters [start, end) of
    `sentence` gets the matching fraction of the box width. Every word keeps
    the top and bottom of the sentence box.

    Positions are derived from the real character offsets of each word, so
    leading or repeated whitespace is accounted for and rounding errors do
    not accumulate from one word to the next (the last word of a sentence
    without trailing whitespace ends exactly on the right edge).

    Args:
        sentence: Text to split on whitespace.
        sentence_box: (left, top, right, bottom) of the whole sentence.

    Returns:
        The list of words and the list of their (left, top, right, bottom)
        boxes, in reading order. Both are empty when `sentence` holds no word.
    """
    words = sentence.split()
    if not words:
        return [], []

    box_left, box_top, box_right, box_bottom = (
        sentence_box[0],
        sentence_box[1],
        sentence_box[2],
        sentence_box[3],
    )
    box_width = box_right - box_left
    n_chars = len(sentence)  # > 0 here since at least one word was found

    boxes = []
    cursor = 0
    for word in words:
        # Only whitespace separates `cursor` from the next word, so the first
        # match found from `cursor` is that word itself.
        start = sentence.find(word, cursor)
        cursor = start + len(word)
        boxes.append(
            (
                box_left + int(start * box_width / n_chars),
                box_top,
                box_left + int(cursor * box_width / n_chars),
                box_bottom,
            )
        )

    return words, boxes


@dataclass
class TextBoxSample:
    """Word-level view of a document: its words, one box per word, its label."""

    words: list[str]
    boxes: list[tuple[int, int, int, int]]  # (left, top, right, bottom)
    label: int

    def __init__(self, document: dict):
        """Build the sample from a dataset document.

        Args:
            document: Mapping with parallel "words" and "boxes" entries and a
                "label" entry.
        """
        # An OCR entry may contain several words: split each entry, together
        # with its box, into word-level pieces.
        per_sentence = [
            split_sentence_into_words(sentence, sentence_box)
            for sentence, sentence_box in zip(document["words"], document["boxes"])
        ]
        self.words = [word for words, _ in per_sentence for word in words]
        self.boxes = [box for _, boxes in per_sentence for box in boxes]
        self.label = document["label"]


# Rebuild the three sample lists as TextBoxSample objects; this replaces the
# TextSample lists that were bound to the same names earlier.
train_samples, test_samples, validation_samples = (
    [TextBoxSample(doc) for doc in split]
    for split in (train_dataset, test_dataset, validation_dataset)
)

Let's implement the PyTorch dataset that will hold those samples. Keep it very simple: we will defer most of the computation to the batching function.


In [None]:
class TextBoxDataset(data.Dataset):
    """Exercise skeleton: a dataset over a list of TextBoxSample objects.

    Every method raises NotImplementedError until it is filled in.
    """

    def __init__(self, samples: list[TextBoxSample]):
        # To implement: keep hold of `samples`.
        raise NotImplementedError

    def __len__(self) -> int:
        # To implement: number of samples in the dataset.
        raise NotImplementedError

    def __getitem__(self, idx: int) -> TextBoxSample:
        # To implement: the sample stored at position `idx`.
        raise NotImplementedError

In [None]:
# @title


class TextBoxDataset(data.Dataset):
    """Thin dataset wrapper around an in-memory list of TextBoxSample objects."""

    def __init__(self, samples: list[TextBoxSample]):
        # The list is stored as-is (no copy, no preprocessing).
        self.samples = samples

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> TextBoxSample:
        """Return the stored sample at position `idx`, untouched."""
        return self.samples[idx]

## Batching function

As in the vision part, we need to implement a batching function that groups multiple inputs together and prepares them to be fed to the model.

Hugging Face `transformers` tokenizers have the ability to tokenize a whole batch at once and perform most of the computation for us.


In [88]:
# The LayoutLM tokenizer does not support bounding boxes, so we use the LayoutLMv2 tokenizer instead.
# Otherwise we would have to implement the mapping from bounding boxes to tokens ourselves,
# which can be tricky because some words are split into multiple tokens.

# This rebinds `tokenizer`, replacing the LayoutLM tokenizer loaded earlier.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "microsoft/layoutlmv2-base-uncased"
)

# Use it like this; it also supports batched inputs.
# Here: a single document, given as its list of words and the matching list of boxes.
tokenizer(
    text=train_samples[0].words,
    boxes=train_samples[0].boxes,
    padding="max_length",
    truncation=True,
)

Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file config.json from cache at /Users/thibaultdouzon/.cache/huggingface/hub/models--microsoft--layoutlmv2-base-uncased/snapshots/ae6f4350c668f88ec580046e35c670df6ec616c1/config.json
Model config LayoutLMv2Config {
  "_name_or_path": "microsoft/layoutlmv2-base-uncased",
  "attention_probs_dropout_prob": 0.1,
  "convert_sync_batchnorm": true,
  "coordinate_size": 128,
  "detectron2_config_args": {
    "MODEL.ANCHOR_GENERATOR.SIZES": [
      [
        32
      ],
      [
        64
      ],
      [
        128
      ],
      [
        256
      ],
      [
        512
      ]
    ],
    "MODEL.BACKBONE.NAME": "build_resnet_fpn_backbone",
    "MODEL.FPN.IN_FEATURES": [
      "res2",
      "res3",
      "res4",
      "res5"
    ],
    "MODEL.MASK_ON": true,
    "MODEL.PIXEL_STD": [
      57.375,
      57.12,
      58.395
    ],
    "MODEL.POST_NMS_TOPK_TEST": 1000,
    "MODEL.RE

{'input_ids': [101, 2013, 1024, 26429, 1010, 6338, 1039, 2006, 12256, 1025, 19802, 2756, 1010, 2722, 2184, 1024, 4466, 2572, 1015, 3395, 1024, 6819, 3527, 12928, 10374, 2013, 1039, 1012, 26429, 6764, 2000, 1024, 11404, 1025, 3021, 2332, 1024, 10507, 1024, 5003, 8747, 1025, 3744, 1025, 14264, 12170, 2100, 2436, 2001, 2333, 2197, 2733, 1998, 25610, 2368, 2128, 1066, 15242, 2070, 1997, 1996, 8378, 1997, 20868, 5358, 2665, 1000, 5371, 8208, 2433, 5585, 2549, 3074, 5491, 2234, 2408, 2628, 2011, 4809, 1997, 1996, 2392, 1997, 2261, 2678, 2696, 10374, 1025, 2228, 2122, 6876, 13262, 1041, 7959, 7610, 2278, 1025, 2031, 2042, 2513, 1998, 2020, 17153, 7652, 2098, 2000, 3765, 18808, 1999, 2254, 1997, 2023, 2095, 2043, 1031, 2187, 1011, 2021, 2245, 2009, 2190, 1056, 2692, 2191, 2469, 1012, 2097, 5860, 4232, 1996, 2065, 29337, 2064, 12210, 2008, 2122, 2678, 2696, 10374, 2024, 1999, 5527, 1010, 4809, 1997, 2037, 4486, 2027, 2019, 2063, 7610, 2072, 3314, 3539, 2099, 1006, 4098, 5349, 17550, 1007, 5170,

In [76]:
@dataclass
class TextBoxBatch:
    """Tensors that make up one batch of tokenized documents."""

    words: torch.LongTensor  # (batch_size, max_seq_len)
    boxes: torch.LongTensor  # (batch_size, max_seq_len, 4)
    labels: torch.LongTensor  # (batch_size, max_seq_len)
    token_type_ids: torch.LongTensor  # (batch_size, max_seq_len)
    attention_mask: torch.LongTensor  # (batch_size, max_seq_len)

    def __post_init__(self):
        """Check that words, boxes and labels use a sequence length of 512."""
        n_samples = self.words.shape[0]
        sequence_shape = (n_samples, 512)
        assert self.words.shape == sequence_shape
        assert self.boxes.shape == sequence_shape + (4,)
        assert self.labels.shape == sequence_shape

    def to(self, device: str):
        """Move every tensor of the batch to `device`, in place, and return self."""
        for name in ("words", "boxes", "labels", "token_type_ids", "attention_mask"):
            setattr(self, name, getattr(self, name).to(device))
        return self

In [77]:
def collate_fn(
    samples: list[TextBoxSample],
    tokenizer: transformers.LayoutLMv2Tokenizer = tokenizer,
) -> TextBoxBatch:
    """Exercise skeleton: turn a list of samples into one TextBoxBatch.

    The default `tokenizer` is whatever the module-level `tokenizer` name is
    bound to when this function is defined. Returns None until implemented.
    """
    # Implement the collate_fn function
    pass

In [78]:
# @title


def collate_fn(
    samples: list[TextBoxSample],
    tokenizer: transformers.LayoutLMv2Tokenizer = tokenizer,
) -> TextBoxBatch:
    """Tokenize a list of samples and pack the result into a TextBoxBatch.

    The returned `labels` tensor has the same shape as the token ids: each
    row holds the document label in column 0 and -100 everywhere else.
    """
    batch_words = [sample.words for sample in samples]
    batch_boxes = [sample.boxes for sample in samples]
    encoded = tokenizer(
        text=batch_words,
        boxes=batch_boxes,
        padding="max_length",
        truncation=True,
        return_tensors="pt",  # return PyTorch tensors
    )
    token_ids = encoded["input_ids"]

    # -100 is the default ignore value for the loss function
    labels = torch.full_like(token_ids, -100)
    labels[:, 0] = torch.tensor(
        [sample.label for sample in samples], dtype=torch.long
    )

    return TextBoxBatch(
        words=token_ids,
        boxes=encoded["bbox"],
        labels=labels,
        token_type_ids=encoded["token_type_ids"],
        attention_mask=encoded["attention_mask"],
    )

In [82]:
# If you got it right, this should work properly:
# the first 12 training samples are batched and the resulting TextBoxBatch is displayed.

collate_fn(train_samples[:12])

TextBoxBatch(words=tensor([[  101,  2013,  1024,  ...,     0,     0,     0],
        [  101,  1049,  2546,  ...,     0,     0,     0],
        [  101,  1016, 17710,  ...,     0,     0,     0],
        ...,
        [  101,  1015,  1015,  ...,     0,     0,     0],
        [  101,  2622,  3642,  ...,     0,     0,     0],
        [  101, 25294,  2470,  ...,     0,     0,     0]]), boxes=tensor([[[  0,   0,   0,   0],
         [145, 105, 176, 121],
         [145, 105, 176, 121],
         ...,
         [  0,   0,   0,   0],
         [  0,   0,   0,   0],
         [  0,   0,   0,   0]],

        [[  0,   0,   0,   0],
         [ 87,  65, 111,  77],
         [ 87,  65, 111,  77],
         ...,
         [  0,   0,   0,   0],
         [  0,   0,   0,   0],
         [  0,   0,   0,   0]],

        [[  0,   0,   0,   0],
         [ 56,  64,  80,  92],
         [299,  94, 328, 117],
         ...,
         [  0,   0,   0,   0],
         [  0,   0,   0,   0],
         [  0,   0,   0,   0]],

      

## Model - LayoutLM

The `transformers` library provides the model's code and weights. We will use the weights of a [fine-tuned model](https://huggingface.co/gurvgupta/LayoutLM_rvl-cdip) on RVL-CDIP from the hub.
Let's first download its weights and rename the checkpoint file so that the model weights can be loaded.

This pre-trained model was already fine-tuned on a superset of our dataset.
We will still fine tune it for a few epochs because our final classes are different.


In [None]:
# Only fetch the checkpoint once: skip everything when the folder already exists.
if not path.exists("LayoutLM_rvl-cdip"):
    !git lfs install
    !git clone https://huggingface.co/gurvgupta/LayoutLM_rvl-cdip
    # Rename the checkpoint to pytorch_model.bin, the file name the loading step reads.
    !mv LayoutLM_rvl-cdip/LayoutLM_rvl-cdip_epoch_50.pt LayoutLM_rvl-cdip/pytorch_model.bin

In [90]:
from transformers.models.layoutlm import LayoutLMForSequenceClassification

# Load the local checkpoint with a NUM_CLASSES-way classification head.
# This rebinds `model`, which previously held the scikit-learn classifier.
# NOTE(review): ignore_mismatched_sizes=True presumably lets the head be
# re-initialised when its size differs from the checkpoint — confirm.
model = LayoutLMForSequenceClassification.from_pretrained(
    "./LayoutLM_rvl-cdip", num_labels=NUM_CLASSES, ignore_mismatched_sizes=True
)
model

Model config LayoutLMConfig {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4
  },
  "layer_norm_eps": 1e-12,
  "max_2d_position_embeddings": 1024,
  "max_position_embeddings": 512,
  "model_type": "layoutlm",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.45.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

loading weights file ./LayoutLM_rvl-cdip/pytorch_model.bin
All model checkpoint weights were used when initializing LayoutLMForSequenceClassification.

Some weights of LayoutLMForSequenceClassification were not initi

LayoutLMForSequenceClassification(
  (layoutlm): LayoutLMModel(
    (embeddings): LayoutLMEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (x_position_embeddings): Embedding(1024, 768)
      (y_position_embeddings): Embedding(1024, 768)
      (h_position_embeddings): Embedding(1024, 768)
      (w_position_embeddings): Embedding(1024, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LayoutLMEncoder(
      (layer): ModuleList(
        (0-11): 12 x LayoutLMLayer(
          (attention): LayoutLMAttention(
            (self): LayoutLMSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True

In [91]:
# IPython help query: displays the signature of the model's forward method.
?model.forward

[0;31mSignature:[0m
[0mmodel[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0minput_ids[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mLongTensor[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbbox[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mLongTensor[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mattention_mask[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mFloatTensor[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtoken_type_ids[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mLongTensor[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mposition_ids[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mLongTensor[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m


In [92]:
# The model can be used like this

# The batch is named `batch`, not `data`: `data` is the `torch.utils.data`
# module imported earlier, and rebinding it would break every later
# reference to `data.DataLoader` / `data.Dataset`.
batch = collate_fn(train_samples[:2])
model(
    input_ids=batch.words,
    bbox=batch.boxes,
    token_type_ids=batch.token_type_ids,
    attention_mask=batch.attention_mask,
)

SequenceClassifierOutput(loss=None, logits=tensor([[ 0.3880,  0.3127,  0.1283, -0.3651, -0.1338],
        [ 0.4449, -0.3743,  0.2650, -0.7613, -0.0456]],
       grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

## Train the model

First, let's copy the training loop procedure from the previous notebook


In [87]:
# Copied from `chapter_1_vision.ipynb`


def train_one_epoch(
    model: nn.Module,
    dataloader: data.DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,  # type: ignore
    device: torch.device,
) -> float:
    """Train the model for one epoch and return the average loss per batch.

    Two batch layouts are accepted:

    - batches exposing `images` and `labels` (the layout this loop was
      originally written for): the model is called as `model(images)` and
      its return value is used directly as the logits;
    - batches exposing `words`, `boxes`, `token_type_ids`, `attention_mask`
      and `labels` (the `TextBoxBatch` layout): the model is called with
      the LayoutLM keyword arguments, the logits are read from the
      `.logits` attribute of its output, and the class of each document is
      read from column 0 of `labels`.
    """
    model.train()
    model.to(device)

    epoch_loss = 0.0
    with tqdm.tqdm(desc="Training", total=len(dataloader)) as pbar:
        for i, batch in enumerate(dataloader):
            optimizer.zero_grad()  # Reset gradients

            if hasattr(batch, "images"):
                images, labels = batch.images.to(device), batch.labels.to(device)
                outputs = model(images)  # Compute model's predictions
            else:
                # Only column 0 of `labels` carries the document class.
                labels = batch.labels[:, 0].to(device)
                outputs = model(
                    input_ids=batch.words.to(device),
                    bbox=batch.boxes.to(device),
                    token_type_ids=batch.token_type_ids.to(device),
                    attention_mask=batch.attention_mask.to(device),
                ).logits

            loss = loss_fn(outputs, labels)  # Compute the loss

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            pbar.set_postfix(loss=epoch_loss / (i + 1))
            pbar.update(1)
    mean_loss = epoch_loss / len(dataloader)
    print(f"Training loss (↓): {mean_loss:.4f}")
    return mean_loss


def evaluate(
    model: nn.Module,
    dataloader: data.DataLoader,
    loss_fn: nn.Module,
    metric_fn: nn.Module,
    device: torch.device,
    dataset_name: str = "validation",
) -> tuple[float, float]:
    """Evaluate the model and return the average loss and metric per batch.

    Two batch layouts are accepted:

    - batches exposing `images` and `labels` (the layout this loop was
      originally written for): the model is called as `model(images)` and
      its return value is used directly as the logits;
    - batches exposing `words`, `boxes`, `token_type_ids`, `attention_mask`
      and `labels` (the `TextBoxBatch` layout): the model is called with
      the LayoutLM keyword arguments, the logits are read from the
      `.logits` attribute of its output, and the class of each document is
      read from column 0 of `labels`.

    Both averages are means over batches; `dataset_name` only affects the
    printed report.
    """
    model.eval()
    model.to(device)

    epoch_loss = 0.0
    epoch_metric = 0.0
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader, desc="Evaluation"):
            if hasattr(batch, "images"):
                images, labels = batch.images.to(device), batch.labels.to(device)
                outputs = model(images)
            else:
                # Only column 0 of `labels` carries the document class.
                labels = batch.labels[:, 0].to(device)
                outputs = model(
                    input_ids=batch.words.to(device),
                    bbox=batch.boxes.to(device),
                    token_type_ids=batch.token_type_ids.to(device),
                    attention_mask=batch.attention_mask.to(device),
                ).logits

            loss = loss_fn(outputs, labels)
            metric = metric_fn(outputs.argmax(dim=-1), labels)

            epoch_loss += loss.item()
            epoch_metric += metric.item()

        mean_loss = epoch_loss / len(dataloader)
        print(f"{dataset_name.capitalize()} loss (↓): {mean_loss:.4f}")
        mean_metric = epoch_metric / len(dataloader)
        print(f"{dataset_name.capitalize()} metric (↑): {mean_metric:.4f}")
        return mean_loss, mean_metric


def train(
    model: nn.Module,
    train_dataloader: data.DataLoader,
    validation_dataloader: data.DataLoader,
    loss_fn: nn.Module,
    metric_fn: nn.Module,
    optimizer: torch.optim.Optimizer,  # type: ignore
    device: torch.device,
    n_epochs: int = 10,
) -> tuple[list[float], list[float], list[float]]:
    """Run `n_epochs` rounds of training followed by validation.

    Returns three lists with one entry per epoch: the training losses, the
    validation losses and the validation metrics.
    """
    train_losses: list[float] = []
    validation_losses: list[float] = []
    validation_metrics: list[float] = []

    for epoch_number in range(1, n_epochs + 1):
        print(f"Epoch {epoch_number}/{n_epochs}")

        # One pass over the training data...
        train_losses.append(
            train_one_epoch(model, train_dataloader, loss_fn, optimizer, device)
        )

        # ...then a measurement on the validation data.
        epoch_val_loss, epoch_val_metric = evaluate(
            model, validation_dataloader, loss_fn, metric_fn, device
        )
        validation_losses.append(epoch_val_loss)
        validation_metrics.append(epoch_val_metric)

    return train_losses, validation_losses, validation_metrics

In [None]:
import torchmetrics

# collate_fn consumes TextBoxSample objects, so the loaders wrap the
# `*_samples` lists in the TextBoxDataset defined in this notebook (the raw
# `*_dataset` lists hold plain dicts).
train_loader = data.DataLoader(
    TextBoxDataset(train_samples),
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
)
validation_loader = data.DataLoader(
    TextBoxDataset(validation_samples),
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=False,
)

# Pick the first available accelerator, falling back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# The LayoutLM classifier loaded earlier in this notebook.
selected_model = model

# NOTE(review): 1e-3 is inherited from the previous notebook; fine-tuning a
# pre-trained transformer usually calls for a much smaller learning rate — confirm.
optimizer = torch.optim.Adam(selected_model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
metric_fn = torchmetrics.Accuracy(task="multiclass", num_classes=NUM_CLASSES).to(device)

n_epochs = 5

hist = train(
    selected_model,
    train_loader,
    validation_loader,
    loss_fn,
    metric_fn,
    optimizer,
    device,
    n_epochs=n_epochs,
)