# Assumptions:

Only documents up to 2 pages are considered, because:

- most invoices are one-pagers,
- pages 2, 3, 4 of multi-page invoices are usually similar to each other

# Config

In [None]:
import os

# Silence albumentations' "new version available" check. These are set in the
# first cell so they are already in the environment when `import albumentations`
# runs in the next cell.
# NOTE(review): albumentations documents NO_ALBUMENTATIONS_UPDATE as the switch
# for this check; the originally used variable name is kept as well in case a
# pinned version or another tool reads it -- confirm and drop the unused one.
os.environ["ALBUMENTATIONS_DISABLE_UPDATE_CHECK"] = "1"
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

In [None]:
from copy import deepcopy
from typing import Callable

import albumentations as A
import ast
import fsspec
import numpy as np
import timm
import torch
import pandas as pd

from enum import Enum
from fsspec import AbstractFileSystem
from PIL import Image
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torchmetrics import Accuracy
from torchmetrics.classification import (
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
)
from tqdm.notebook import tqdm

In [None]:
# Pick the compute device once for the whole notebook: CUDA when torch reports
# it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression so the notebook displays the selected device

# Utils

## Clear cuda cache

In [None]:
import gc

# Collect unreachable Python objects first, then ask torch to empty its CUDA
# cache (re-run this cell between experiments to free GPU memory).
gc.collect()
torch.cuda.empty_cache()

# Dataset

## Helpers

In [None]:
def get_filesystem(
    path: str, fs_args: dict | None = None, credentials: dict | None = None
) -> AbstractFileSystem:
    """Create the fsspec filesystem used to open data files.

    Args:
        path: Location the filesystem is requested for. Currently unused
            because the protocol is hard-coded to ``"file"``.
        fs_args: Extra keyword arguments for ``fsspec.filesystem``; the
            caller's dict is copied, never mutated.
        credentials: Credential keyword arguments; also copied. On a key
            clash the value from ``fs_args`` wins.

    Returns:
        The filesystem instance returned by ``fsspec.filesystem``.
    """
    protocol = "file"  # constant for local development

    extra_options = deepcopy(fs_args) if fs_args else {}
    options = deepcopy(credentials) if credentials else {}

    if protocol == "file":
        # Only applied when the caller did not set ``auto_mkdir`` explicitly.
        extra_options.setdefault("auto_mkdir", True)

    # Filesystem arguments are applied last so they override credentials.
    options.update(extra_options)
    return fsspec.filesystem(protocol, **options)

## Definition of dataset

In [None]:
class ImageSequencesDataset(Dataset):
    """Dataset of labelled image sequences described by a CSV metadata file.

    Every CSV row holds a Python-literal list of image paths (``paths_column``)
    and a label (``label_column``). Indexing an item opens each image of the
    row through the configured filesystem, transforms it, and returns the list
    of transformed images together with the row's label.

    Args:
        filepath: Path of the CSV metadata file (read with ``pandas.read_csv``).
        paths_column: Name of the column holding the list of image paths.
        label_column: Name of the column holding the label.
        fs_args: Extra arguments forwarded to ``get_filesystem``.
        credentials: Credentials forwarded to ``get_filesystem``.
        transform: Optional callable turning a PIL image into a tensor. When
            omitted, the image is only converted with ``A.ToTensorV2``.
    """

    def __init__(
        self,
        filepath: str,
        paths_column: str = "paths",
        label_column: str = "label",
        fs_args: dict | None = None,
        credentials: dict | None = None,
        transform: Callable[[Image.Image], torch.Tensor] | None = None,
    ):
        super().__init__()
        self.data_file_path = filepath
        self.paths_column = paths_column
        self.label_column = label_column
        # The metadata table is loaded eagerly; the images themselves are only
        # read when an item is requested.
        self.data: pd.DataFrame = self._load()
        self.transform_fn = transform
        self._fs: AbstractFileSystem = get_filesystem(self.data_file_path, fs_args, credentials)

    def __len__(self) -> int:
        """Return the number of metadata rows (0 when no table is set)."""
        if self.data is None:
            return 0
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple[list[torch.Tensor], int]:
        """Return the transformed images of row ``idx`` and the row's label."""
        sequence_paths = self.data.loc[idx, self.paths_column]
        sequence_label = self.data.loc[idx, self.label_column]
        images = [self._read_image(image_path) for image_path in sequence_paths]
        return images, sequence_label

    # helpers
    def _read_image(self, img_path: str) -> torch.Tensor:
        """Open one image through the filesystem and transform it."""
        with self._fs.open(img_path, "rb") as handle:
            # The transform runs inside the ``with`` block, so the pixels are
            # read while the underlying file is still open.
            return self._transform(Image.open(handle))

    def _load(self) -> pd.DataFrame:
        """Read the metadata CSV and parse the paths column into Python lists."""
        frame = pd.read_csv(self.data_file_path)
        # The lists are stored in the CSV as their literal text representation.
        parsed_paths = frame[self.paths_column].apply(ast.literal_eval)
        frame[self.paths_column] = parsed_paths
        return frame

    def _transform(self, img: Image.Image) -> torch.Tensor:
        """Apply the user transform, or a plain tensor conversion as fallback."""
        if self.transform_fn is None:
            return A.ToTensorV2()(image=np.array(img))["image"]
        return self.transform_fn(img)

## Collate function

In [None]:
def collate_fn(
    batch: list[tuple[list[torch.Tensor], int]],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Flatten a batch of image sequences, longest sequence first.

    Args:
        batch: ``(images, label)`` pairs as produced by the dataset, where
            ``images`` is the list of tensors of one sequence.

    Returns:
        A tuple ``(images, lengths, labels)``: all images of the batch stacked
        into one tensor (sequence after sequence), the length of every
        sequence, and the label of every sequence. Sequences are ordered by
        descending length, which is what ``pack_padded_sequence`` with
        ``enforce_sorted=True`` expects.
    """
    sequence_lengths = np.array([len(sample[0]) for sample in batch])
    order = np.argsort(-sequence_lengths)  # sort in descending order

    ordered_samples = [batch[position] for position in order]
    flat_images = [image for images, _ in ordered_samples for image in images]
    labels = [label for _, label in ordered_samples]

    stacked_images = torch.stack(flat_images)
    lengths_tensor = torch.LongTensor(sequence_lengths[order])
    labels_tensor = torch.LongTensor(labels)
    return stacked_images, lengths_tensor, labels_tensor

# Image transformer

In [None]:
def transform(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a normalised, augmented tensor.

    The pipeline resizes to 224x224, then applies Gaussian blur, a random
    brightness/contrast change and grayscale conversion (each with
    probability 0.5), normalises every channel with mean 0.5 and std 0.5, and
    converts the result to a tensor.

    The image is first converted to RGB. Grayscale, palette or RGBA inputs
    would otherwise reach the three-channel ``Normalize`` step (and later the
    backbone) with the wrong number of channels, and tensors of differing
    channel counts cannot be stacked into one batch.

    Args:
        img: Source image in any PIL mode.

    Returns:
        The transformed image tensor.
    """
    _transform = A.Compose(
        [
            A.Resize(height=224, width=224),
            A.GaussianBlur(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.ToGray(p=0.5),
            A.Normalize(
                mean=(0.5, 0.5, 0.5),
                std=(0.5, 0.5, 0.5),
            ),
            A.ToTensorV2(),
        ]
    )

    # Force exactly three channels regardless of how the file was stored.
    rgb_pixels = np.array(img.convert("RGB"))
    return _transform(image=rgb_pixels)["image"]

# Test pretrained timm model (encoder backbone)

In [None]:
# Identifier of the pretrained timm model used as the image encoder backbone.
model_name = "timm/vit_base_patch16_clip_224.openai"

In [None]:
# Load the pretrained weights. NOTE(review): num_classes=0 presumably drops
# timm's classification head so the model returns feature vectors -- confirm
# against the timm documentation.
backbone = timm.create_model(model_name, pretrained=True, num_classes=0)

In [None]:
# IPython `??` introspection (notebook-only syntax, not valid plain Python):
# shows details of the loaded backbone object.
backbone??

In [None]:
# Smoke-test the backbone on a single training image run through `transform`.
image = Image.open("../data/05_model_input/notebooks/train/invoice_0_type_2/invoice_0_type_2_0.jpg")
inputs = transform(image)

# Get the model's output
# unsqueeze(0) adds a leading batch dimension of size 1; no_grad disables
# gradient tracking for this forward pass.
with torch.no_grad():
    outputs = backbone(inputs.unsqueeze(0))

In [None]:
# Bare expression so the notebook displays the backbone output computed above.
outputs

# Image Encoders

## Recursive Image Encoder

In [None]:
class RecursiveImageEncoder(nn.Module):
    """Encode sequences of images with a backbone followed by an RNN.

    Every image is embedded by ``backbone``; the embeddings of one sequence
    get sinusoidal positional embeddings added and are fed through an
    ``nn.RNN``. The final hidden state of the last RNN layer is the encoding
    of the sequence.

    Args:
        backbone: Image encoder exposing ``num_features`` (its embedding size).
        hidden_size: Hidden size of the RNN, i.e. the size of the encoding.
        rnn_layers_num: Number of stacked RNN layers.
        dropout: Dropout passed to ``nn.RNN``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        hidden_size: int = 512,
        rnn_layers_num: int = 1,
        dropout: float = 0,
    ):
        super().__init__()
        self.backbone = backbone
        self.embedding_dim = backbone.num_features
        self.rnn = nn.RNN(
            input_size=self.embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
            num_layers=rnn_layers_num,
            dropout=dropout,
        )

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Encode a flattened batch of image sequences.

        Args:
            x: All images of the batch stacked along dim 0, sequence after
                sequence.
            lengths: Number of images of every sequence, sorted in descending
                order (required by ``enforce_sorted=True``); must sum to
                ``x.size(0)``.

        Returns:
            One row per sequence holding the final hidden state of the last
            RNN layer.
        """
        features = self.backbone(x)
        sequences = torch.split(features, lengths.tolist(), dim=0)
        padded_sequences = pad_sequence(list(sequences), batch_first=True)

        # Build the embeddings for the padded length, on the device and in the
        # dtype of the features, instead of relying on a notebook-level global.
        positional_embeddings = self._get_positional_embeddings(
            padded_sequences.size(1), self.embedding_dim, device=padded_sequences.device
        )
        padded_sequences = padded_sequences + positional_embeddings.to(padded_sequences.dtype)

        packed_seqs = pack_padded_sequence(
            padded_sequences, lengths.cpu(), batch_first=True, enforce_sorted=True
        )
        _, last_hidden_state = self.rnn(packed_seqs)

        # The hidden state has one leading entry per RNN layer. Taking the last
        # one equals the former squeeze(dim=0) for a single layer and still
        # yields a 2-D result when rnn_layers_num > 1.
        return last_hidden_state[-1]

    def _get_positional_embeddings(self, seq_len, embedding_dim, device=None):
        """Return sinusoidal positional embeddings of shape (seq_len, embedding_dim).

        Even columns hold sines and odd columns cosines of the position scaled
        by geometrically decreasing frequencies. ``device`` is where the result
        is placed; ``None`` leaves it on the CPU.
        """
        position = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2) * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
        )
        pe = torch.zeros(seq_len, embedding_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd embedding_dim there is one cosine column fewer than sines.
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        return pe if device is None else pe.to(device)

## Avg Image Encoder

In [None]:
class AvgImageEncoder(nn.Module):
    """Encode an image sequence as the mean of its images' backbone features."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Average the backbone features of every sequence.

        Args:
            x: All images of the batch stacked along dim 0, sequence after
                sequence.
            lengths: Number of images of every sequence; must sum to
                ``x.size(0)``.

        Returns:
            One row per sequence: the mean over that sequence's image features.
        """
        image_features = self.backbone(x)
        per_sequence = image_features.split(lengths.tolist(), dim=0)
        return torch.stack([chunk.mean(dim=0) for chunk in per_sequence])

# Classification head

In [None]:
class MLP(nn.Module):
    """Fully connected classification head.

    Builds ``Linear -> activation`` pairs for every entry of ``hidden_shape``
    followed by a final ``Linear`` output layer without activation.

    Args:
        input_size: Size of the input features.
        hidden_shape: Sizes of the hidden layers, in order. May be empty, in
            which case the module is a single linear layer.
        output_size: Number of outputs.
        activation: Activation module placed after every hidden layer; the
            same instance is reused for all of them. Defaults to a fresh
            ``nn.ReLU()`` per model, so no module object is shared between
            models through the default argument.
    """

    def __init__(self, input_size, hidden_shape, output_size, activation=None):
        super().__init__()
        if activation is None:
            activation = nn.ReLU()
        hidden_shape = tuple(hidden_shape)
        self.hidden = nn.Sequential(
            *self.__build_hidden_layers(input_size, hidden_shape, activation)
        )
        # Without hidden layers the output layer consumes the input directly.
        last_width = hidden_shape[-1] if hidden_shape else input_size
        self.out = nn.Linear(last_width, output_size)

    def __build_hidden_layers(self, input_size, hidden_shape, activation):
        """Return the ``Linear``/activation modules for the hidden layers."""
        widths = (input_size, *hidden_shape)
        hidden_layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            hidden_layers.extend([nn.Linear(fan_in, fan_out), activation])
        return hidden_layers

    def forward(self, x):
        """Apply the hidden layers and the output layer to ``x``."""
        x = self.hidden(x)
        x = self.out(x)
        return x

# Classifier

In [None]:
class DocumentClassifier(nn.Module):
    """Document classifier: a sequence encoder followed by a classification head.

    Args:
        encoder: Module turning ``(images, lengths)`` into one encoding per
            document.
        classifier: Head mapping every encoding to class scores.
    """

    def __init__(self, encoder: AvgImageEncoder | RecursiveImageEncoder, classifier: MLP):
        super().__init__()
        self.encoder = encoder
        self.classifier = classifier

    def forward(self, x):
        """Classify a batch; ``x`` is unpacked into the encoder's arguments."""
        encoded = self.encoder(*x)
        return self.classifier(encoded)

# ClassificationReport

In [None]:
class ClassificationReport:
    """Bundle of multiclass accuracy, precision, recall and F1 metrics.

    Args:
        num_classes: Number of classes of the classification problem.
        average: Averaging mode passed to precision, recall and F1. Accuracy
            is created without it and therefore uses its own default.
    """

    def __init__(self, num_classes: int = 4, average: str = "macro"):
        averaged = {"num_classes": num_classes, "average": average}
        # All metrics live on the notebook-level `device`.
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes).to(device)
        self.precision = MulticlassPrecision(**averaged).to(device)
        self.recall = MulticlassRecall(**averaged).to(device)
        self.f1 = MulticlassF1Score(**averaged).to(device)

    def generate(self, y_pred, y_true):
        """Return ``(accuracy, precision, recall, f1)`` for the given labels."""
        metrics = (self.accuracy, self.precision, self.recall, self.f1)
        return tuple(metric(y_pred, y_true) for metric in metrics)

# Evaluate function

In [None]:
def evaluate(model, dataloader, criterion):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    Args:
        model: Classifier called as ``model((items, lengths))``.
        dataloader: Yields ``(items, lengths, labels)`` batches.
        criterion: Loss called as ``criterion(output, labels)``. It is assumed
            to return the mean loss over the documents of the batch.

    Returns:
        A tuple ``(avg_loss, accuracy, precision, recall, f1)`` where
        ``avg_loss`` is the per-document average loss and the remaining values
        come from ``ClassificationReport.generate``.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    classification_report = ClassificationReport()
    model.eval()
    model.to(device)

    total_loss = 0.0
    seen_documents = 0
    true_batches = []
    pred_batches = []

    with torch.no_grad():
        for items, lengths, labels in tqdm(dataloader):
            items, lengths, labels = (
                items.to(device),
                lengths.to(device),
                labels.to(device),
            )
            output = model((items, lengths))
            _, labels_pred = output.max(dim=1)

            # Keep the integer label tensors as they are; concatenating them
            # onto a float accumulator would silently turn them into floats.
            true_batches.append(labels)
            pred_batches.append(labels_pred)

            loss = criterion(output, labels)
            # Weight by the number of documents. `items` holds one row per
            # image, so its size would over-weight multi-page documents.
            batch_documents = labels.size(0)
            total_loss += loss.item() * batch_documents
            seen_documents += batch_documents

    if seen_documents == 0:
        raise ValueError("evaluate() received a dataloader without any samples")

    avg_epoch_loss = total_loss / seen_documents
    y_true = torch.cat(true_batches, dim=0)
    y_pred = torch.cat(pred_batches, dim=0)
    accuracy, precision, recall, f1 = classification_report.generate(y_pred, y_true)
    return avg_epoch_loss, accuracy, precision, recall, f1

# Zero-shot tests (to check whether encoder and classifier work together)

## Encoder backbone

In [None]:
# Fresh pretrained backbone for the zero-shot checks (same model identifier as
# `model_name` above); rebinding `backbone` replaces the earlier instance.
backbone = timm.create_model(
    "timm/vit_base_patch16_clip_224.openai", pretrained=True, num_classes=0
)

## Dataset & Dataloader

In [None]:
# The zero-shot checks run on the training metadata, with the augmenting
# `transform` applied to every image.
dataset = ImageSequencesDataset(
    filepath="../data/05_model_input/notebooks/train/metadata.csv", transform=transform
)

In [None]:
# Small shuffled batches; `collate_fn` flattens the sequences and sorts them by
# descending length.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

## Tests with Recursive Image Encoder

### Encoder

In [None]:
# Default arguments: a single RNN layer with hidden size 512.
encoder_rnn = RecursiveImageEncoder(backbone)

### Classification head

In [None]:
# Input size 512 matches the encoder's default RNN hidden size; 4 outputs match
# the default `num_classes` of ClassificationReport.
classification_head_rnn = MLP(512, (512, 256, 64, 16), 4)

### Classifier

In [None]:
# Untrained head on top of the RNN encoder.
model_rnn = DocumentClassifier(encoder_rnn, classification_head_rnn)

### Test run

In [None]:
# Smoke run: push every batch through the model once, without gradients, to
# check that the shapes of all components fit together. Outputs and labels are
# discarded.
model_rnn.eval()
with torch.no_grad():
    model_rnn.to(device)
    for batch in tqdm(dataloader):
        items, lengths, labels = batch
        items = items.to(device)
        lengths = lengths.to(device)
        model_rnn((items, lengths))

### Evaluate

In [None]:
# Metrics of the untrained RNN-based classifier; the bare tuple on the last
# line makes the notebook display them.
loss, accuracy, precision, recall, f1 = evaluate(model_rnn, dataloader, CrossEntropyLoss())
loss, accuracy, precision, recall, f1

## Test with Avg Image Encoder

### Encoder

In [None]:
# Reuses the same `backbone` instance as `encoder_rnn`.
encoder_avg = AvgImageEncoder(backbone)

### Classification head

In [None]:
# Input size 768 is presumably the backbone's feature size (the averaged
# features go straight into the head) -- TODO confirm via backbone.num_features.
# NOTE(review): the first hidden width 265 may be a typo for 256.
classification_head_avg = MLP(768, (265, 16, 8), 4)

### Classifier

In [None]:
# Untrained head on top of the averaging encoder.
model_avg = DocumentClassifier(encoder_avg, classification_head_avg)

### Test run

In [None]:
# Smoke run for the averaging model: one gradient-free pass over every batch;
# outputs and labels are discarded.
model_avg.eval()
with torch.no_grad():
    model_avg.to(device)
    for batch in tqdm(dataloader):
        items, lengths, labels = batch
        items = items.to(device)
        lengths = lengths.to(device)
        model_avg((items, lengths))

### Evaluate

In [None]:
# Metrics of the untrained averaging classifier; the bare tuple on the last
# line makes the notebook display them.
loss, accuracy, precision, recall, f1 = evaluate(model_avg, dataloader, CrossEntropyLoss())
loss, accuracy, precision, recall, f1

## Summary

The zero-shot approach is not sufficient: as all metrics clearly show, the models need to be trained. The test runs do confirm, however, that all components of the final model work together.

# Early Stopper

In [None]:
class EarlyStopper:
    """Track a metric that should decrease and signal when to stop training.

    The lowest metric value seen so far is remembered together with a deep
    copy of the model's ``state_dict`` taken at that moment.

    Args:
        patience: Number of counted deteriorations after which
            ``early_stop`` returns ``True``.
        min_delta: Margin above the best value that a metric has to exceed to
            be counted as a deterioration.
    """

    def __init__(self, patience: int = 1, min_delta: float = 0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_stop_metric_value = float("inf")
        self.best_model_state_dict: dict | None = None

    def early_stop(self, metric_value: float, model: nn.Module) -> bool:
        """Record ``metric_value`` and report whether training should stop.

        A new minimum resets the counter and snapshots the model. A value more
        than ``min_delta`` above the minimum increments the counter. Values in
        between change nothing.
        """
        if metric_value < self.min_stop_metric_value:
            # New best: remember it and keep an independent copy of the weights.
            self.min_stop_metric_value = metric_value
            self.counter = 0
            self.best_model_state_dict = deepcopy(model.state_dict())
            return False

        tolerated_value = self.min_stop_metric_value + self.min_delta
        if not metric_value > tolerated_value:
            return False

        self.counter += 1
        return self.counter >= self.patience

# Train function

In [None]:
class StopMetric(Enum):
    """Metric that drives early stopping in ``train``.

    Each value is the position of the metric in the tuple returned by
    ``evaluate`` -- ``(loss, accuracy, precision, recall, f1)`` -- because
    ``train`` looks the metric up with ``metrics_valid[stop_metric.value]``.
    """

    LOSS = 0
    ACCURACY = 1
    PRECISION = 2
    RECALL = 3
    F1 = 4

In [None]:
def freez_model(model):
    """Freeze ``model`` in place by disabling gradients for all its parameters."""
    for parameter in model.parameters():
        parameter.requires_grad_(False)

In [None]:
def train_one_epoch(model, train_dataloader, optimizer, criterion):
    """Train ``model`` for one pass over ``train_dataloader``.

    Args:
        model: Classifier called as ``model((items, lengths))``.
        train_dataloader: Yields ``(items, lengths, labels)`` batches.
        optimizer: Optimizer stepping the trainable parameters of ``model``.
        criterion: Loss called as ``criterion(output, labels)``. It is assumed
            to return the mean loss over the documents of the batch.

    Returns:
        A tuple ``(avg_loss, accuracy, precision, recall, f1)`` computed over
        the training batches, where ``avg_loss`` is the per-document average.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    classification_report = ClassificationReport()
    model.train()

    total_loss = 0.0
    seen_documents = 0
    true_batches = []
    pred_batches = []

    for items, lengths, labels in tqdm(train_dataloader):
        items, lengths, labels = items.to(device), lengths.to(device), labels.to(device)

        # Clear the gradients of the previous batch; without this they would
        # accumulate across all batches of the epoch.
        optimizer.zero_grad()
        output = model((items, lengths))
        _, labels_pred = output.max(dim=1)

        # Keep the integer label tensors as they are; concatenating them onto
        # a float accumulator would silently turn them into floats.
        true_batches.append(labels)
        pred_batches.append(labels_pred)

        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # Weight by the number of documents. `items` holds one row per image,
        # so its size would over-weight multi-page documents.
        batch_documents = labels.size(0)
        total_loss += loss.item() * batch_documents
        seen_documents += batch_documents

    if seen_documents == 0:
        raise ValueError("train_one_epoch() received a dataloader without any samples")

    # Average over the documents actually seen in this function's own
    # dataloader rather than over a notebook-level `dataloader` variable.
    avg_epoch_loss = total_loss / seen_documents
    y_true = torch.cat(true_batches, dim=0)
    y_pred = torch.cat(pred_batches, dim=0)
    accuracy, precision, recall, f1 = classification_report.generate(y_pred, y_true)
    return avg_epoch_loss, accuracy, precision, recall, f1

In [None]:
def train(
    model,
    train_dataloader,
    valid_dataloader,
    optimizer,
    criterion,
    epochs,
    stop_metric,
    patience=20,
):
    """Train ``model`` with a frozen backbone and early stopping.

    The encoder backbone is frozen, the model is trained for at most
    ``epochs`` epochs, and train/validation metrics are printed after each
    one. When training ends, the weights of the best validation epoch (by
    ``stop_metric``) are loaded back into ``model``.

    Args:
        model: Classifier exposing ``encoder.backbone``.
        train_dataloader: Batches used for training.
        valid_dataloader: Batches used for validation.
        optimizer: Optimizer for the model's parameters.
        criterion: Loss function.
        epochs: Maximum number of epochs.
        stop_metric: ``StopMetric`` member selecting the monitored validation
            metric. ``LOSS`` is minimised, every other metric is maximised.
        patience: Number of deteriorating epochs tolerated before stopping.
    """
    early_stopper = EarlyStopper(patience)
    freez_model(model.encoder.backbone)
    model.to(device)

    # EarlyStopper treats lower values as better. Loss already behaves that
    # way; score-like metrics are negated so that a higher score counts as an
    # improvement instead of a deterioration.
    sign = 1.0 if stop_metric.name == "LOSS" else -1.0

    for epoch in tqdm(range(epochs)):
        metrics_train = train_one_epoch(model, train_dataloader, optimizer, criterion)
        metrics_valid = evaluate(model, valid_dataloader, criterion)
        print(f"Epoch {epoch + 1}:")
        print(
            f"loss_train={metrics_train[0]}, accuracy_train={metrics_train[1]}, precision_train={metrics_train[2]}, recall_train={metrics_train[3]}, f1_train={metrics_train[4]}"
        )
        print(
            f"loss_valid={metrics_valid[0]}, accuracy_valid={metrics_valid[1]}, precision_valid={metrics_valid[2]}, recall_valid={metrics_valid[3]}, f1_valid={metrics_valid[4]}"
        )
        # float() also turns the tensor-valued metrics into plain numbers.
        monitored_value = sign * float(metrics_valid[stop_metric.value])
        if early_stopper.early_stop(monitored_value, model):
            # Report the same 1-based epoch number as the log line above.
            print(
                f"Early stopping on epoch {epoch + 1}, {stop_metric.name}: {metrics_valid[stop_metric.value]}"
            )
            break
    if early_stopper.best_model_state_dict is not None:
        model.load_state_dict(early_stopper.best_model_state_dict)

# Train Classifier

## Encoder backbone

In [None]:
# Fresh pretrained backbone for the training runs; rebinding `backbone`
# replaces the instance used in the zero-shot checks.
backbone = timm.create_model(
    "timm/vit_base_patch16_clip_224.openai", pretrained=True, num_classes=0
)

## Datasets and Dataloaders

In [None]:
# Training split: augmented images, reshuffled every epoch, 20 documents per batch.
dataset_train = ImageSequencesDataset(
    filepath="../data/05_model_input/notebooks/train/metadata.csv", transform=transform
)
dataloader_train = DataLoader(dataset_train, batch_size=20, shuffle=True, collate_fn=collate_fn)

In [None]:
# Validation split, iterated in fixed order.
# NOTE(review): this uses the same `transform` as training, so the random
# blur / brightness-contrast / grayscale augmentations are applied to the
# validation images too and the validation metrics vary between runs.
dataset_valid = ImageSequencesDataset(
    filepath="../data/05_model_input/notebooks/valid/metadata.csv", transform=transform
)
dataloader_valid = DataLoader(dataset_valid, batch_size=10, shuffle=False, collate_fn=collate_fn)

## Train Classifier with Recursive Image Encoder

### Encoder

In [None]:
# Default arguments: a single RNN layer with hidden size 512.
encoder_rnn = RecursiveImageEncoder(backbone)

### Classification head

In [None]:
# Input size 512 matches the encoder's default RNN hidden size; 4 outputs match
# the default `num_classes` of ClassificationReport.
classification_head_rnn = MLP(512, (512, 256, 64, 16), 4)

### Classifier

In [None]:
# Classifier to be trained: RNN encoder plus its MLP head.
model_rnn = DocumentClassifier(encoder_rnn, classification_head_rnn)

### Train

In [None]:
# Train the RNN-based classifier for at most 10 epochs, early-stopping on the
# validation loss. The optimizer is handed all model parameters; `train`
# freezes the backbone before the first epoch.
# NOTE(review): lr=1e-7 is 100x smaller than the rate used for the averaging
# model below -- confirm this is intentional.
train(
    model_rnn,
    dataloader_train,
    dataloader_valid,
    optim.Adam(model_rnn.parameters(), lr=0.0000001),
    CrossEntropyLoss(),
    10,
    StopMetric.LOSS,
)

## Train Classifier with Avg Image Encoder

### Encoder

In [None]:
# Reuses the same `backbone` instance as `encoder_rnn` above.
encoder_avg = AvgImageEncoder(backbone)

### Classification head

In [None]:
# Input size 768 is presumably the backbone's feature size -- TODO confirm via
# backbone.num_features.
# NOTE(review): the first hidden width 265 may be a typo for 256.
classification_head_avg = MLP(768, (265, 16, 8), 4)

### Classifier

In [None]:
# Classifier to be trained: averaging encoder plus its MLP head.
model_avg = DocumentClassifier(encoder_avg, classification_head_avg)

### Train

In [None]:
# Train the averaging classifier for at most 10 epochs, early-stopping on the
# validation loss. The optimizer is handed all model parameters; `train`
# freezes the backbone before the first epoch.
train(
    model_avg,
    dataloader_train,
    dataloader_valid,
    optim.Adam(model_avg.parameters(), lr=0.00001),
    CrossEntropyLoss(),
    10,
    StopMetric.LOSS,
)