# OCR on handwritten pages
Train a handwriting recognition model on my handwriting to automatically transcribe my mission journals.

State-of-the-art research so far: [Full Page Handwriting Recognition via Image to Sequence Extraction](https://paperswithcode.com/paper/full-page-handwriting-recognition-via-image), by Sumeet S. Singh and Sergey Karayev, 11-Mar-2021. See also: [YouTube presentation from the authors](https://youtu.be/BOIrib04fmE) (~1 hr.).
* Uses a convolutional neural network (ResNet34) as an encoder and a Transformer network as a decoder. Trained on thousands of samples from the IAM, WikiText, and proprietary datasets. The input data was also augmented to help the model generalize.

## References
* [PyTorch implementation](https://github.com/tobiasvanderwerff/full-page-handwriting-recognition) by Tobias van der Werff
* Training data: [Text Recognition Data Generator (`trdg`)](https://github.com/Belval/TextRecognitionDataGenerator), a Python package for generating images of random text that can be used to train handwriting recognition (or OCR) models
* Potential training data: [IAM dataset files, official website](https://fki.tic.heia-fr.ch/databases/download-the-iam-handwriting-database)
* Potential training data: [IAM forms dataset, on Kaggle](https://www.kaggle.com/naderabdalghani/iam-handwritten-forms-dataset)
* Potential training data: [Images of handwritten names, Kaggle](https://www.kaggle.com/landlord/handwriting-recognition)

Others
* [Stanford's CS231n course PDF outline of handwritten text recognition](http://cs231n.stanford.edu/reports/2017/pdfs/810.pdf), including a discussion of the IAM dataset and various data augmentation methods
* [Towards Data Science article](https://towardsdatascience.com/build-a-handwritten-text-recognition-system-using-tensorflow-2326a3487cd5) and associated [GitHub repo](https://github.com/githubharald/SimpleHTR) demonstrating a handwritten text recognition task in TensorFlow, using a convolutional neural network combined with an LSTM network to recognize individual words (not full paragraphs or pages).

## Research notes
[**Research paper**](https://paperswithcode.com/paper/full-page-handwriting-recognition-via-image)
* First successful implementation of full-page handwritten text recognition
* Does not require uniformly formatted data (like the IAM dataset)
* Achieves a character-error rate (CER) of about 6%
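For intuition, CER is the Levenshtein edit distance between the predicted and ground-truth text, divided by the number of ground-truth characters. Below is a minimal pure-Python sketch; the function names `levenshtein` and `character_error_rate` are illustrative and not taken from the paper or its code.

```python
def levenshtein(a, b):
    """Minimum number of insertions, deletions, and substitutions
    needed to turn sequence `a` into sequence `b`."""
    # prev[j] holds the edit distance between the previous prefix of `a` and b[:j]
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,             # delete x
                curr[j - 1] + 1,         # insert y
                prev[j - 1] + (x != y),  # substitute (free if x == y)
            ))
        prev = curr
    return prev[-1]


def character_error_rate(prediction: str, target: str) -> float:
    """CER = edit distance / number of characters in the target."""
    return levenshtein(prediction, target) / len(target)


print(levenshtein("kitten", "sitting"))  # 3
# One missing "o" out of 19 target characters
print(round(character_error_rate("the quick brwn fox", "the quick brown fox"), 3))  # 0.053
```

At a CER of about 6%, a page like the 456-character example quoted in the authors' talk would contain roughly 27 character errors (0.06 × 456 ≈ 27).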

[**YouTube presentation from the authors**](https://youtu.be/BOIrib04fmE)
* Research goals that this paper addressed:
  - Multi-paragraph text detection in proper sequence
  - Capture indentation
  - Ignore scratched-out text, math, tables, or unrecognized symbols
  - Use character-level generation to avoid the constraint of a language model (with a fixed vocabulary size)
* Prior to performing image-to-sequence extraction, you need to convert the image from grayscale to black-and-white (to assist in focusing on the text)
* This method is called "offline handwritten text recognition (HTR)", which means the input is an image of text. In contrast, "online HTR" means that the input is pen strokes, with directional and timing data, such as what could be obtained from a digital stylus on a Microsoft Surface or Apple iPad.
* You can include special sequence tags like `<deleted text>` or `<side note>` and the model can learn to recognize those regions
* "If you added a Transformer at the end of the CNN encoder, it would probably increase accuracy, but we found it wasn't necessary and model size was too large with a Transformer encoder as well"
* "You don't need line-number encodings in the decoder, if you omit them, your model will perform almost as well."
* "Decoder uses local attention, though Transformers commonly use global attention. We found that local attention didn't reduce efficacy and trained faster, since we could use larger batch sizes."
* "We didn't use beam search"
* "We used ResNet34 for the encoder, which has 21M parameters. You could use any CNN, though. We did not use the pre-trained version and trained ours from scratch."
* "Inference takes 4.6 seconds per page on a CPU thread, when the input image is 2500x2200 pixels with 456 characters and 11.65 lines."
* "Decoder is a Transformer with 6M parameters. We use cross-entropy loss, dropout, and cross attention on the entire image."
* "The longest part, where we spent the most time, was preparing the data to use for model training. Ideally, you want 100k samples. We had about 20k after augmenting our proprietary 13k samples. We also used WikiText rendered in over 300 fonts, with varying degrees of skewness, text layout on page, contrast, brightness, and blank images. You really need a lot of data to train these models on. The goal is to have a large enough sample base that it is independent and identically distributed, so mini-batch samples all reflect roughly the same (identical) data distribution."
* "Model doesn't perform well on out-of-distribution data."
* "The model is sensitive to image padding"
* The encoder accounts for most of the model's parameters (21M for ResNet34 vs. 6M for the decoder), so a more efficient encoder is a potential area for improvement.
* "Character error rate on full-page data is about 6.3%"
* "During training, weights are 32 bits, gradients are 16 bits"
* "We use AWS Lambda for inferencing, so inferencing happens on CPUs"
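The talk does not spell out how the decoder's local attention is implemented. One common interpretation, assumed here, is a causal sliding window: each output position attends only to itself and the previous `window - 1` positions. A NumPy sketch of such a mask (the function name and the `window` parameter are hypothetical):

```python
import numpy as np

def local_causal_mask(seq_len: int, window: int) -> np.ndarray:
    """Boolean mask of shape (seq_len, seq_len) where True marks a
    query/key pair that must NOT be attended to (the convention PyTorch
    uses for boolean attention masks such as `tgt_mask`)."""
    query = np.arange(seq_len)[:, None]  # row index = query position
    key = np.arange(seq_len)[None, :]    # column index = key position
    future = key > query                 # causal: no peeking at later tokens
    too_far = query - key >= window      # outside the local window
    return future | too_far

print(local_causal_mask(5, window=2).astype(int))
# [[0 1 1 1 1]
#  [0 0 1 1 1]
#  [1 0 0 1 1]
#  [1 1 0 0 1]
#  [1 1 1 0 0]]
```

A mask like this could be converted with `torch.from_numpy` and passed as `tgt_mask` to `nn.TransformerDecoder`. Note that masking a dense attention matrix still computes every score, so the training-speed and batch-size gains the authors describe would likely require a windowed (chunked) attention implementation rather than a mask alone.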

# Neural Network
Adapted from [Full Page Handwriting Recognition via Image to Sequence Extraction](https://paperswithcode.com/paper/full-page-handwriting-recognition-via-image), by Singh et al. in March 2021. See pages 4-8 of the paper.

## Model architecture

Screenshot from pg. 4 of the paper by Singh et al.

<img src="/imgs/Model-Architecture_FPHR_Singh-et-al_2021.png" width="600">


By the way, Visual Studio Code has trouble loading images in Markdown if the images are stored locally. Here are other potential ways to view the image using Markdown, from a [Stack Overflow answer](https://stackoverflow.com/questions/32370281/how-to-embed-image-or-picture-in-jupyter-notebook-either-from-a-local-machine-o):

```md
![Screenshot from pg. 4 of the paper by Singh et al.](https://drive.google.com/uc?id=1HQL3j2eVXS_N5KEHDpvWR4YVE1jhEcg3&export=download)

<!-- Original link: https://drive.google.com/file/d/1HQL3j2eVXS_N5KEHDpvWR4YVE1jhEcg3/view -->

<img src="Model-Architecture_FPHR_Singh-et-al_2021.png" style="height:300px">
```

## Library imports

In [None]:
# Mount Google Drive for access to data
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Set variables with paths to the image and label folders, which are accessed by
# the dataset on initialization. The expected folder structure looks like this:

# /content/drive/MyDrive/School/Deep learning final project
#  |-- training_data
#    |-- labels
#      |-- train
#      |-- test
#    |-- processed_images
#      |-- train
#      |-- test
#  |-- journal_data
#    |-- labels
#      |-- train
#      |-- test
#    |-- processed_images
#      |-- train
#      |-- test

# Where training_data has generated images of full-page text
# and journal_data has images of actual handwritten text

path_to_image_folder = r'/content/drive/MyDrive/School/Deep learning final project/training_data/processed_images'
path_to_label_folder = r'/content/drive/MyDrive/School/Deep learning final project/training_data/labels'

# File paths to the folders with label data
label_folders = [
    r'/content/drive/MyDrive/School/Deep learning final project/training_data/labels/train',
    r'/content/drive/MyDrive/School/Deep learning final project/training_data/labels/test',
    r'/content/drive/MyDrive/School/Deep learning final project/journal_data/labels/train',
    r'/content/drive/MyDrive/School/Deep learning final project/journal_data/labels/test']

In [None]:
%pip install torchmetrics
from torchmetrics import Metric             # Package for computing measurements on PyTorch models



In [None]:
# PyTorch modules for neural network
import torch                                # PyTorch
import torch.nn as nn                       # Neural network module
import torch.optim as optim                 # Optimizer module (e.g., SGD, Adam)
import torchvision                          # Computer vision module
from torchvision import transforms          # Image augmentation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("GPU is available, device set to GPU." if torch.cuda.is_available() 
        else "GPU unavailable, device set to CPU.")

# Other packages
import numpy as np                          # Array and mathematical functions
from tqdm.notebook import tqdm              # Progress bars
import matplotlib.pyplot as plt             # Plotting functionality
from PIL import Image                       # Loading and manipulating images
import random                               # Create train-test split of data
import shutil                               # Move files
import os                                   # Work with files and folders
import math                                 # Use the log() function (for positional encoding)
import string                               # Character lists (e.g., for character-level vocab)
import editdistance                         # Implementation of Levenshtein edit distance
import gc                                   # Garbage collector, for clearing memory
from typing import Callable, Optional       # Modules for type annotations (type hints)

GPU is available, device set to GPU.


In [None]:
!nvidia-smi

Tue May  3 16:42:02 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   62C    P8    11W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Metrics and utility functions
Adapted from: Tobias van der Werff's GitHub repo: [full-page-handwriting-recognition -> metrics.py](https://github.com/tobiasvanderwerff/full-page-handwriting-recognition/blob/master/src/metrics.py)

In [None]:
# from torchmetrics import Metric
# import editdistance

class CharacterErrorRate(Metric):
    '''
    Calculates the character error rate, a measurement of the percentage
    of characters that were predicted incorrectly.

    Calculation: Levenshtein edit distance / length of target.
    '''

    def __init__(self):
        super().__init__()
        self.add_state("edits", default=torch.Tensor([0]), dist_reduce_fx="sum")
        self.add_state("total_chars", default=torch.Tensor([0]), dist_reduce_fx="sum")
    
    def update(self, predictions, targets):
        '''
        Updates the running count of the number of edits
        and ground truth characters.

        Parameters
        ---
        `predictions`: Tensor of shape (batch_size, predicted_characters)
        `targets`: Tensor of shape (batch_size, target_characters)
        '''
        assert predictions.ndim == targets.ndim
        
        sos_token_idx = token_to_index['<START>']
        eos_token_idx = token_to_index['<END>']

        # Check if the first token is the start token (i.e., the standard situation)
        if (predictions[:, 0] == sos_token_idx).all():
            # Remove the start token
            predictions = predictions[:, 1:]
        
        eos_idxs_pred = (predictions == eos_token_idx).float().argmax(dim=1).tolist()
        eos_idxs_tgt = (targets == eos_token_idx).float().argmax(dim=1).tolist()

        for i, (p, t) in enumerate(zip(predictions, targets)):
            eos_idx_p, eos_idx_t = eos_idxs_pred[i], eos_idxs_tgt[i]
            p = p[:eos_idx_p] if eos_idx_p else p
            t = t[:eos_idx_t] if eos_idx_t else t
            # Compare the sequences of token indices directly. The editdistance
            # package accepts any sequence of hashable items (like lists of
            # integers), not just strings, so each token counts as one edit unit.
            # Converting to strings would instead compare digits, brackets, and commas.
            # See: https://pypi.org/project/editdistance/#distance-with-any-object
            edit_d = editdistance.eval(p.flatten().tolist(), t.flatten().tolist())

            self.edits += edit_d
            self.total_chars += t.numel()
        
    def compute(self):
        '''Calculate the character error rate'''
        return self.edits.float() / self.total_chars

In [None]:
# Testing a simple tensor-to-string conversion (joining the flattened indices into one string)
t = torch.tensor([[9, 8, 3], [1, 4, 91]])
to_str = "".join(map(str, t.flatten().tolist()))
to_str

'9831491'
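The output above also shows why converting index tensors to strings is risky for computing CER: joining the digits loses the token boundaries, so different token sequences can produce the same string, and the edit distance stops counting one edit per character. A pure-Python illustration, with plain lists standing in for `tensor.tolist()` output:

```python
a = [1, 4, 91]   # three tokens
b = [14, 9, 1]   # three different tokens

joined_a = "".join(map(str, a))
joined_b = "".join(map(str, b))
print(joined_a, joined_b, joined_a == joined_b)  # 1491 1491 True

# str() on a list keeps separators, but brackets, commas, spaces,
# and each individual digit would all be compared as characters
print(str(a), str(b))  # [1, 4, 91] [14, 9, 1]
```

Comparing the lists themselves (an edit distance over list elements, which the `editdistance` package supports for any hashable items) keeps exactly one edit unit per token.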

In [None]:
# import string

# Define a set to store character-level vocabulary from the dataset
dataset_vocab = set()

# Find all possible characters the model will encounter during training.
# Only the final two label folders are read, since those contain
# custom-labeled data rather than generated labels. The generated
# labels use only ASCII-printable characters, but the custom labels
# may have other characters.
custom_label_folders = label_folders[2:4]

total_count = sum(len(os.listdir(folder)) for folder in custom_label_folders)

p_bar = tqdm(total=total_count, leave=False)
p_bar.set_description('Reading files in dataset')

for folder in custom_label_folders:
    for onefile in os.scandir(folder):
        with open(onefile.path, mode='rt', encoding='utf-8') as labelfile:
            # Wrapping labelfile.read() in list() converts the string to
            # a character-level list
            dataset_vocab.update(list(labelfile.read()))
        p_bar.update(n=1)
p_bar.close()


new_characters = [char for char in dataset_vocab if char not in list(string.printable)]
special_tokens = ['<START>', '<END>', '<PAD>', '<INSERT>', '</INSERT>']
all_chars = list(string.printable) + new_characters
all_chars = sorted(all_chars)
all_chars = special_tokens + all_chars

index_to_token = dict(enumerate(all_chars))
token_to_index = {char: i for i, char in index_to_token.items()}

print(f"Total number of characters: {len(token_to_index)}")
print(token_to_index.keys())

  0%|          | 0/10 [00:00<?, ?it/s]

Total number of characters: 116
dict_keys(['<START>', '<END>', '<PAD>', '<INSERT>', '</INSERT>', '\t', '\n', '\x0b', '\x0c', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '¡', '¿', 'É', 'á', 'í', 'ñ', 'ó', 'ú', '•', '❤', '🙂'])


Text-to-tensor conversion

In [None]:
def text_to_tensor(
    text: str,
    token_to_idx: dict,
    start_token: str = '<START>',
    end_token: str = '<END>',
    pad_token: str = '<PAD>',
    max_len: int = 3500,
    multichar_tokens: list = ['<INSERT>', '</INSERT>'],
    placeholder_chars: list = ['🔴', '🟠'],
    device: torch.device = device,
    dtype: torch.dtype = torch.long) -> torch.Tensor:
    '''
    Takes an input string (labels) and converts it to
    an integer Tensor (torch.long by default), where each
    item in the Tensor is an index into the
    `token_to_idx` vocabulary dictionary.

    Also adds the `start_token` to the beginning of
    the tensor, the `end_token` after all the tokens
    in the `text`, and fills the remaining space up
    to `max_len` tokens with the `pad_token`.

    Returns a Tensor of type `dtype`.

    Parameters
    ---
    `text`: str
        The text string to be converted into a Tensor.
        This will be the label (target) text.
    `token_to_idx`: dict
        A token-to-index dictionary for each token
        in the vocabulary.
    `start_token`: str, default='<START>'
        The token to be added at the beginning of each 
        `text` sequence.
    `end_token`: str, default='<END>'
        The token to be added after all the tokens in the
        `text` sequence.
    `pad_token`: str, default='<PAD>'
        The token to use for padding the given `text` up to
        the `max_len` number of tokens.
    `max_len`: int, default=3500
        The maximum length (in tokens) of the encoded `text`.
        Raises an error if `text` is too long to fit within `max_len`
        tokens. If the encoded `text` is shorter than `max_len`, the
        remaining positions are filled with `pad_token`.
    `multichar_tokens`: list, default=['<INSERT>', '</INSERT>']
        Since this is a character-level vocabulary, any multi-character
        sequences that should be treated as a single token must be replaced
        with a single-character placeholder before iterating through the `text`
        character-by-character, and will be converted back into the 
        multi-character token during that iteration.
    `placeholder_chars`: list, default=['🔴', '🟠']
        The single-character placeholders that will take the place of
        multi-character tokens so the function can iterate through the
        `text` character-by-character but still recognize the correct
        tokens. These tokens must not exist within the `token_to_idx`.
    `device`: a torch.device object
        The device on which to place the returned tensor. Can be either
        torch.device('cpu') or torch.device('cuda').
    `dtype`: a torch datatype, default=torch.long
        See: https://pytorch.org/docs/stable/tensor_attributes.html
        for a list of the acceptable data types. Although a vocabulary
        with fewer than 256 characters would fit in torch.uint8, keep
        torch.long for labels: nn.Embedding expects integer indices
        (torch.long or torch.int), and nn.CrossEntropyLoss expects
        class-index targets of type torch.long.
    '''
    if not isinstance(multichar_tokens, list):
        multichar_tokens = [multichar_tokens]

    if not isinstance(placeholder_chars, list):
        placeholder_chars = [placeholder_chars]

    assert len(placeholder_chars) >= len(multichar_tokens), (
        "Each multi-character token needs its own placeholder character."
    )

    for char in placeholder_chars:
        assert char not in token_to_idx, (
            f"The provided placeholder character {char} is in the `token_to_idx`. "
            "Please use a placeholder character not found in the `token_to_idx`."
        )

    # List of token indices for each character in the input text.
    # Begin the sequence with the '<START>' token.
    index_list = [token_to_idx[start_token]]

    # Replace multi-character tokens with single-character placeholders.
    # Each replacement builds on the previous one, so every token is replaced
    # (not just the last one in `multichar_tokens`).
    replaced_text = text
    for token, placeholder in zip(multichar_tokens, placeholder_chars):
        replaced_text = replaced_text.replace(token, placeholder)

    # Iterate through the text character-by-character and return integer indices
    for char in replaced_text:
        if char in placeholder_chars:
            idx = placeholder_chars.index(char)
            index_list.append(token_to_idx[multichar_tokens[idx]])
        else:
            index_list.append(token_to_idx[char])

    # End the sequence with the '<END>' token.
    index_list.append(token_to_idx[end_token])

    # Check the length after adding the start and end tokens,
    # since those also count toward max_len
    assert len(index_list) <= max_len, (
        f"The encoded `text` has more tokens than the `max_len` setting allows. "
        f"Please increase max_len to accommodate the input text. "
        f"Encoded length={len(index_list)}, and max_len={max_len}."
    )

    # Fill remaining space in sequence with '<PAD>' token up to max_len
    if len(index_list) < max_len:
        num_pads = max_len - len(index_list)
        index_list += [token_to_idx[pad_token]] * num_pads


    # Convert the index_list into a Tensor
    return torch.tensor(index_list, dtype=dtype, device=device)
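An alternative to swapping multi-character tokens for placeholder characters is to split the label on those tokens with a regular expression and keep each match as a single token. A minimal sketch, assuming the same `<INSERT>`/`</INSERT>` tags (the function name `tokenize_label` is hypothetical):

```python
import re

def tokenize_label(text, multichar_tokens=("<INSERT>", "</INSERT>")):
    """Split `text` into a list of tokens: each multi-character tag is
    one token, and every other character is its own token."""
    # Longest tokens first so overlapping tags match the longer one;
    # the capture group makes re.split keep the tags in its output.
    pattern = "(" + "|".join(
        re.escape(t) for t in sorted(multichar_tokens, key=len, reverse=True)
    ) + ")"
    tokens = []
    for piece in re.split(pattern, text):
        if piece in multichar_tokens:
            tokens.append(piece)
        else:
            tokens.extend(piece)  # one token per character (empty pieces add nothing)
    return tokens

print(tokenize_label("a <INSERT>b</INSERT>."))
# ['a', ' ', '<INSERT>', 'b', '</INSERT>', '.']
```

Each token can then be looked up in `token_to_index` directly (e.g., `[token_to_index[t] for t in tokens]`), and no placeholder characters need to be reserved outside the vocabulary.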

Tensor-to-text conversion

In [None]:
def tensor_to_text(
    tensor: torch.Tensor,
    idx_to_token: dict) -> str:
    '''
    Takes an input Tensor and converts it to
    a string by returning the `idx_to_token` token
    for each index given in the tensor.

    Returns a string.

    Parameters
    ---
    `tensor`: torch.Tensor
        The integer-type Tensor that holds indices
        to tokens in the idx_to_token dict.
    `idx_to_token`: dict
        An index-to-token dictionary for each token
        in the vocabulary.
    '''
    return_string = ''
    for item in tensor:
        return_string += idx_to_token[item.item()]
    return return_string
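`tensor_to_text` returns every token, including `<START>`, `<END>`, and all of the `<PAD>` tokens. For readable transcriptions, one option is to stop at the first `<END>` and drop the other special tokens. A pure-Python sketch that works on a list of indices (for example, `tensor.tolist()`); the function name and the toy vocabulary are hypothetical:

```python
def indices_to_text(indices, idx_to_token, end_token="<END>",
                    skip_tokens=("<START>", "<PAD>")):
    """Convert token indices to a string, stopping at the first
    `end_token` and omitting any tokens listed in `skip_tokens`."""
    pieces = []
    for idx in indices:
        token = idx_to_token[idx]
        if token == end_token:
            break
        if token not in skip_tokens:
            pieces.append(token)
    return "".join(pieces)

# Toy vocabulary, for illustration only
toy_vocab = {0: "<START>", 1: "<END>", 2: "<PAD>", 3: "h", 4: "i"}
print(indices_to_text([0, 3, 4, 1, 2, 2], toy_vocab))  # hi
```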

Testing the conversion functions

In [None]:
test_str = """This is a test paragraph, to see if
the text-to-tensor and tensor-to-text functions work correctly.
    This is a sample indented line.

This is a new paragraph, with a <INSERT>superscript</INSERT> added.

It's also possible to have special characters, like ~ or á."""

test_tensor = text_to_tensor(test_str, token_to_idx=token_to_index)
print(test_tensor[:20])
print('='*20, '\n', '='*20, sep='')
print(tensor_to_text(test_tensor, idx_to_token=index_to_token)[:302] + ' ...to max_len tokens (3500 by default)')

tensor([ 0, 62, 82, 83, 93, 10, 83, 93, 10, 75, 10, 94, 79, 93, 94, 10, 90, 75,
        92, 75], device='cuda:0')
<START>This is a test paragraph, to see if
the text-to-tensor and tensor-to-text functions work correctly.
    This is a sample indented line.

This is a new paragraph, with a <INSERT>superscript</INSERT> added.

It's also possible to have special characters, like ~ or á.<END><PAD><PAD><PAD><PAD><PAD> ...to max_len tokens (3500 by default)


## Dataset and DataLoader

### Dataset
The dataset retrieves an item (an image with its associated text label) from either a 'train' or 'test' folder and returns an (image, label) tuple, where both the image and label are `torch.Tensor`s.
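Because the dataset loads an image and its label using the same index, the image list and the label list must be in the same order, and `os.listdir` does not guarantee any particular order. A sketch that pairs files explicitly by name, assuming each image and its label share a file name apart from the extension (e.g., `page_001.png` and `page_001.txt`); the helper name is hypothetical:

```python
import os

def pair_image_and_label_files(image_dir, label_dir):
    """Return a sorted list of (image_path, label_path) tuples, matching
    files whose names are identical apart from the extension.
    Raises ValueError if any image or label lacks a partner."""
    images = {os.path.splitext(f)[0]: os.path.join(image_dir, f)
              for f in os.listdir(image_dir)}
    labels = {os.path.splitext(f)[0]: os.path.join(label_dir, f)
              for f in os.listdir(label_dir)}
    unmatched = set(images) ^ set(labels)  # names found in only one folder
    if unmatched:
        raise ValueError(f"Unpaired files: {sorted(unmatched)}")
    return [(images[stem], labels[stem]) for stem in sorted(images)]
```

Failing loudly on unpaired files is a deliberate choice here: a silently misaligned pair would train the model on the wrong transcription.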

**Augmentations**

The dataset will also implement augmentation, including:
* padding images to match the size of the largest image in the dataset (so all images in all batches are the same size). For the 10,000 training images and 10 fine-tuning images I am using, the **largest width is 2295 pixels**, and the **largest height is 1884 pixels**.
* scaling
* rotation
* brightness
* (background color?)
* contrast
* perspective
* Gaussian noise

During training, images can be randomly placed anywhere within the padding dimensions of the largest image in the dataset, but during testing the images should be centered (still within padding).
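The padding and translation limits that `__getitem__` computes can be checked by hand. A pure-Python sketch of the same arithmetic (the function name is hypothetical; the 2295 × 1884 canvas is the largest image size given above, and the 2000 × 1500 page is a made-up example):

```python
def padding_and_translation(width, height, max_width, max_height):
    """Split the extra space around a (width x height) image so it is
    centered on a (max_width x max_height) canvas, and compute the largest
    translation (as a fraction of the canvas size) that keeps the image
    inside the canvas."""
    pad_width, pad_height = max_width - width, max_height - height
    left, top = pad_width // 2, pad_height // 2
    # (left, top, right, bottom): the order torchvision's Pad uses for 4-tuples
    padding = (left, top, pad_width - left, pad_height - top)
    t_horizontal = (pad_width / 2) / max_width
    t_vertical = (pad_height / 2) / max_height
    return padding, t_horizontal, t_vertical

padding, t_h, t_v = padding_and_translation(2000, 1500, 2295, 1884)
print(padding)                       # (147, 192, 148, 192)
print(round(t_h, 4), round(t_v, 4))  # 0.0643 0.1019
```

Multiplying the horizontal fraction by the canvas width recovers half the horizontal padding (0.0643 × 2295 ≈ 147.5 px). These fractions therefore keep the page inside the canvas only when the translation is applied to the already padded image; applied to the unpadded image, the same fraction would shift text past the original image edges.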

For more information on augmentation used in the paper, see Singh et al., pp. 10-11 ("Image Augmentation" and "Data Sampling" sections).

**Dataset references**
* [Examples of custom datasets](https://github.com/utkuozbulak/pytorch-custom-dataset-examples)
* [PyTorch docs: Datasets](https://pytorch.org/vision/stable/datasets.html)
* [PyTorch docs: DataLoaders and Datasets](https://pytorch.org/docs/stable/data.html)
* [PyTorch docs: custom datasets (in particular, see `VisionDataset`)](https://pytorch.org/vision/stable/datasets.html#base-classes-for-custom-datasets)

**Transforms references**
* [PyTorch docs: Transforming and Augmenting Images](https://pytorch.org/vision/stable/transforms.html)
* [PyTorch docs: illustrated examples of transforms](https://pytorch.org/vision/stable/auto_examples/plot_transforms.html)
* [PyTorch source code (GitHub): transforms.py](https://github.com/pytorch/vision/blob/main/torchvision/transforms/transforms.py)

**Note:** `torchvision.transforms.RandomApply(transforms=[...], p=0.5)` treats the list as a unit: with probability `p`, _all_ transformations in the list are applied; otherwise _none_ are. To apply transformations independently, wrap each one in its own `RandomApply`.

See also: one-step transform: [`TrivialAugmentWide`](https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#trivialaugmentwide)

In [None]:
class FPHRdataset(torch.utils.data.Dataset):
    def __init__(
        self,
        path_to_image_folder,
        path_to_label_folder,
        max_width,
        max_height,
        augmentation_likelihood: float = 0.5,
        train: bool = True,
        vocab_dict: dict=token_to_index):
        '''
        Args
        ---
        `path_to_image_folder`: str
            Root directory where images are stored. Assumes that
            within that directory are two sub-directories:
            'test' and 'train', each with images.
        `path_to_label_folder`: str
            Root directory where labels are stored. Assumes that
            within that directory are two sub-directories:
            'test' and 'train', each with labels.
        `max_width`: int
            Defines the size for each image; that is, each image will be padded
            so its width equals max_width (in pixels). Should be at least
            as large as the maximum width (in pixels) in any image
            in the entire dataset.
        `max_height`: int
            Defines the size for each image; that is, each image will be padded
            so its height equals max_height (in pixels). Should be at least
            as large as the maximum height (in pixels) in any image
            in the entire dataset.
        `augmentation_likelihood`: float, default=0.5
            Probability that each group of augmentation transforms (each
            `RandomApply` block) is applied to the input image during training.
        `train`: bool, default=True
            If True, uses the 'train' subdirectories in the `path_...`
            arguments. If False, uses the 'test' subdirectories.
        `vocab_dict`: dict
            Dictionary where keys are tokens in the vocabulary 
            (single characters, in this case), and values are indices of those
            tokens.
        '''
        # No need to call the base class's constructor
        # (that is, we don't need: super().__init__())

        self.max_width = max_width
        self.max_height = max_height
        self.augment_prob = augmentation_likelihood
        self.train = train
        # Store the vocabulary so __getitem__ can use it
        self.vocab_dict = vocab_dict

        # Store the paths to the directories for images and labels
        subfolder = 'train' if self.train else 'test'
        self.path_to_image_folder = os.path.join(path_to_image_folder, subfolder)
        self.path_to_label_folder = os.path.join(path_to_label_folder, subfolder)

        # Store all image and label file paths in memory. os.listdir() returns
        # files in arbitrary order, so sort both lists to keep each image aligned
        # with its label (assumes matching file names, e.g. 001.png and 001.txt).
        self.image_files = [os.path.join(self.path_to_image_folder, file)
                            for file in sorted(os.listdir(self.path_to_image_folder))]
        self.label_files = [os.path.join(self.path_to_label_folder, file)
                            for file in sorted(os.listdir(self.path_to_label_folder))]
        assert len(self.image_files) == len(self.label_files), (
            "The image and label folders contain different numbers of files.")


    def create_transformations(self, padding, t_horizontal, t_vertical, augment_prob, train=True):
        '''
        Compose the transformations to be applied to the input image.
        '''
        if train:
            # Perform augmentations with torchvision.transforms.
            # Order of transforms:
            # 1. Pad: pad image so all images are the same size. Padding comes first
            #    so RandomAffine's translation moves the page within the padded
            #    canvas instead of cropping text off the image edges.
            # 2. ColorJitter: adjusts brightness, contrast, and saturation
            # 3. RandomAffine: changes scale, rotates image, and translates (moves) image
            # 4. RandomPerspective: shifts the perspective of the image
            # 5. GaussianBlur: applies a Gaussian blur to the image (in place of adding Gaussian noise)
            # 6. ToTensor: converts the image into a torch.Tensor
            augmentations = transforms.Compose(
                transforms = [
                    transforms.Pad(padding=padding),
                    transforms.RandomApply(
                        transforms = [
                            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
                            transforms.RandomAffine(degrees=2.5, translate=(t_horizontal, t_vertical), scale=(0.9, 1.0)),
                            transforms.RandomPerspective(distortion_scale=0.1, p=augment_prob)],
                        p = augment_prob),
                    transforms.RandomApply(
                        transforms = [
                            transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 1.0))],
                        p = augment_prob),
                    transforms.ToTensor()
                    ])
        else:
            # For inference (testing), just center the image within padding space
            augmentations = transforms.Compose([transforms.Pad(padding=padding), transforms.ToTensor()])

        return augmentations


    def __getitem__(self, index):
        # =====================================
        # Set values needed for transformations
        #  (padding, vertical and horizontal translation)
        input_img = Image.open(self.image_files[index])
        width, height = input_img.size
        pad_width, pad_height = (self.max_width - width, self.max_height - height)

        # Split the padding between opposite sides of the image. When the
        # padding is odd, the right/bottom side gets the extra pixel.
        pad_left = pad_width // 2
        pad_top = pad_height // 2
        padding = (pad_left, pad_top, pad_width - pad_left, pad_height - pad_top)

        # Determine the fraction of the padded canvas's width and height used
        #  for translation such that there is still padding on all sides of the image.
        # left/right translation
        t_horizontal = (pad_width / 2) / self.max_width
        # top/bottom translation
        t_vertical = (pad_height / 2) / self.max_height
        # =====================================

        # Compose the transformations for the input
        augmentations = self.create_transformations(
            padding = padding,
            t_horizontal = t_horizontal,
            t_vertical = t_vertical,
            augment_prob = self.augment_prob,
            train = self.train)
        # Apply transformations
        img_tensor = augmentations(input_img)

        with open(self.label_files[index], mode='rt', encoding='utf-8') as labelfile:
            target_label = labelfile.read()

        target_label_tensor = text_to_tensor(
            text = target_label,
            token_to_idx = self.vocab_dict,
            multichar_tokens = ['<INSERT>', '</INSERT>'],
            placeholder_chars = ['🔴', '🟠'],
            device = device,
            dtype = torch.long
        )

        # print(f"Padding size: {padding}")
        # print(f"Horizontal translation percent {t_horizontal:.2f}")
        # print(f"Vertical translation percent {t_vertical:.2f}")
        
        return (img_tensor.to(device), target_label_tensor.to(device))


    def __len__(self):
        return len(self.image_files)

## Positional encoding

### 1D positional encoding (for decoder Transformer)

Implementation reference: Tobias van der Werff's GitHub repo: [full-page-handwriting-recognition/src/models](https://github.com/tobiasvanderwerff/full-page-handwriting-recognition/blob/47a2d27fc40815898474e9f74badc7d544740fee/src/models.py#L18).

See also:
* [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html#positional-encoding), Harvard's implementation of the "Attention Is All You Need" (2017) paper by Vaswani et al.
* [How to Code the Transformer in PyTorch](https://blog.floydhub.com/the-transformer-in-pytorch/#giving-our-words-context-the-positional-encoding), FloydHub blog
* [Transformers from Scratch](http://peterbloem.nl/blog/transformers#input-using-the-positions), Peter Bloem's blog, with accompanying code on [GitHub](https://github.com/pbloem/former) and videos on [YouTube](https://www.youtube.com/playlist?list=PLIXJ-Sacf8u60G1TwcznBmK6rEL3gmZmV). Note that Peter Bloem uses learned positional _embeddings_ rather than fixed sinusoidal encodings. The drawback of learned embeddings is that every position up to the maximum sequence length must be seen during training; a position beyond the longest training sequence has no learned embedding.

#### Implement the function for positional encodings
Implementation of the function proposed in the paper "Attention Is All You Need" by Vaswani et al. (2017).

$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$

$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$

Where $pos$ is the token position, $d_{model}$ is the model's embedding dimension, and $i$ indexes pairs of embedding dimensions ($0 \le i < d_{model}/2$), so the even columns $2i$ hold sines and the odd columns $2i+1$ hold cosines of the same frequency.

This creates a 2D matrix with one row per position and $d_{model}$ columns. There are as many rows as the `max_len` argument passed when creating the positional encodings, and as many columns as there are embedding dimensions.
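As a cross-check of the formula, the sketch below computes the encoding matrix with plain Python `math` (no PyTorch), once directly from the formula and once with the log-space frequency that `PositionalEncoding1D` uses. The function name `sinusoidal_pe` is hypothetical, and it assumes an even `d_model`.

```python
import math
from typing import List


def sinusoidal_pe(max_len: int, d_model: int, log_space: bool = False) -> List[List[float]]:
    """Return a max_len x d_model list of lists of sinusoidal positional encodings."""
    assert d_model % 2 == 0, "d_model must be even so sines and cosines pair up"
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(d_model // 2):
            if log_space:
                # exp(-2i * ln(10000) / d_model) == 1 / 10000^(2i / d_model)
                freq = math.exp(-(2 * i) * math.log(10000.0) / d_model)
            else:
                freq = 1.0 / (10000 ** (2 * i / d_model))
            pe[pos][2 * i] = math.sin(pos * freq)      # even column: sine
            pe[pos][2 * i + 1] = math.cos(pos * freq)  # odd column: cosine
    return pe


direct = sinusoidal_pe(max_len=50, d_model=8)
via_log = sinusoidal_pe(max_len=50, d_model=8, log_space=True)

print(len(direct), len(direct[0]))  # 50 8
print(direct[0])                    # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

The two versions agree to floating-point precision; position 0 always encodes as alternating 0s and 1s because $\sin 0 = 0$ and $\cos 0 = 1$.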

In [None]:
class PositionalEncoding1D(nn.Module):
    '''
    Positional encodings for labels (text sequences),
    which have a single dimension (tokens).

    Adapted from Tobias van der Werff's GitHub repo:
      full-page-handwriting-recognition
    '''
    def __init__(self, d_model, max_len=3500):
        # Inherit from parent class (nn.Module)
        super().__init__()

        self.max_len = max_len

        # The positional encodings form a 2D matrix (tensor) with one row per
        # position (max_len rows) and one column per embedding dimension
        # (d_model columns)
        pe = torch.zeros((max_len, d_model), requires_grad=False)
        
        # A 2D tensor with max_len rows and 1 column. Each row holds the
        # value of that position (from 0 to max_len - 1)
        position = torch.arange(0, max_len).unsqueeze(1)
        
        # Frequencies 1 / 10000^(2i / d_model), computed in log space as
        # exp(-2i * ln(10000) / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )

        # Indexing technique accesses all rows (sequence position),
        # but skips every other column (embedding dimension)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add a leading dimension so the encodings broadcast across the batch.
        # The shape of pe is now (1, max_len, d_model)
        pe = pe.unsqueeze(0)

        # Save the encodings in the model without registering them as a parameter
        # (a "buffer" is saved with the model, but is not a learnable parameter).
        self.register_buffer("pos_enc", pe)
    
    def forward(self, x):
        '''
        Add a 1D positional encoding to an embedding tensor.

        Args
        ---
        `x`: Tensor of shape (batch_size, num_tokens, d_model)
            The embedding tensor to which positional encodings
            will be added.
        
        Returns
        ---
        The sum of the input `x` (embedding tensor) with the positional
        encodings.
        '''
        _, T, _ = x.shape
        assert T <= self.pos_enc.size(1), (
            f"The given embedding has {T} tokens, which is more than the max length stored in the positional encodings ({self.max_len}). "
            + "\nPlease increase the 'max_len' argument in the `PositionalEncoding1D` instance so max_len is greater than the number of tokens in the input."
        )

        # Add positional encodings to the input tensor, across all batches
        # and up to the number of rows as there are input tokens.
        return x + self.pos_enc[:, :T]

### 2D positional encoding (for encoder ResNet)

Implementation reference: Tobias van der Werff's GitHub repo: [full-page-handwriting-recognition/src/models](https://github.com/tobiasvanderwerff/full-page-handwriting-recognition/blob/47a2d27fc40815898474e9f74badc7d544740fee/src/models.py#L56)

p. 7 of Singh et al. (2021) states:

>    "The encoder uses a CNN to extract a 2D feature-map from the input image.
    It uses the ResNet architecture without its last two layers: 
    the average-pool and linear projection. The feature-map is then projected 
    to match the Transformer's hidden-size dmodel, then a 2D positional encoding 
    added and finally flattened into a 1D sequence. 2D positional encoding is a 
    fixed sinusoidal encoding as in Vaswani et al. (2017), but using the first 
    dmodel/2 channels to encode the Y coordinate and the rest to encode the 
    X coordinate (similar to Parmar et al. (Image Transformer, 2018)).
    Output I of the Flatten layer is made available to all Transformer decoder layers,
    as is standard."

$PE_{(y, 2i)} = \sin\left(\frac{y}{10000^{2i/d_{model}}}\right)$

$PE_{(y, 2i+1)} = \cos\left(\frac{y}{10000^{2i/d_{model}}}\right)$

$PE_{(x, \frac{d_{model}}{2} + 2i)} = \sin\left(\frac{x}{10000^{2i/d_{model}}}\right)$

$PE_{(x, \frac{d_{model}}{2} + 2i + 1)} = \cos\left(\frac{x}{10000^{2i/d_{model}}}\right)$

Where $y$ and $x$ are the row and column coordinates in the feature map, $d_{model}$ is the model's embedding dimension, and $i$ indexes pairs of channels within each half ($0 \le i < d_{model}/4$).

This creates a 3D tensor with one row per $y$ coordinate, one column per $x$ coordinate, and $d_{model}$ channels. There are as many rows as the height of the feature map output by the encoder CNN, as many columns as its width (both measured in feature-map cells, not input-image pixels), and as many channels as there are embedding dimensions.
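A plain-Python sketch (hypothetical function `pe_2d`, no PyTorch) of the formulas above for a single feature-map cell. It follows the layout described in the quote: the first $d_{model}/2$ channels encode $y$ and the rest encode $x$, and both halves use the $2i/d_{model}$ exponent, which matches the `div_term` computed in `PositionalEncoding2D` below.

```python
import math
from typing import List


def pe_2d(y: int, x: int, d_model: int) -> List[float]:
    """2D sinusoidal encoding of feature-map cell (y, x): first half encodes y, second half x."""
    assert d_model % 4 == 0, "each half needs an even number of channels"
    half = d_model // 2
    enc = [0.0] * d_model
    for i in range(half // 2):
        freq = 1.0 / (10000 ** (2 * i / d_model))  # divided by d_model, not d_model / 2
        enc[2 * i] = math.sin(y * freq)             # y half: sine
        enc[2 * i + 1] = math.cos(y * freq)         # y half: cosine
        enc[half + 2 * i] = math.sin(x * freq)      # x half: sine
        enc[half + 2 * i + 1] = math.cos(x * freq)  # x half: cosine
    return enc


print(pe_2d(0, 0, d_model=8))  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Two cells in the same row share an identical first half and differ only in the second half, which is how the decoder can tell rows and columns apart after the feature map is flattened.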

In [None]:
class PositionalEncoding2D(nn.Module):
    '''
    Positional encodings for inputs (feature maps),
    which have three dimensions (channels, height, width).

    The feature maps are the final output from the Convolutional
    Neural Network used as an image encoder, such as the ResNet 
    models used in this implementation.

    p. 7 of Singh et al. (2021) states:

    "The encoder uses a CNN to extract a 2D feature-map from the input image.
    It uses the ResNet architecture without its last two layers: 
    the average-pool and linear projection. The feature-map is then projected 
    to match the Transformer's hidden-size dmodel, then a 2D positional encoding 
    added and finally flattened into a 1D sequence. 2D positional encoding is a 
    fixed sinusoidal encoding as in Vaswani et al. (2017), but using the first 
    dmodel/2 channels to encode the Y coordinate and the rest to encode the 
    X coordinate (similar to Parmar et al. (Image Transformer, 2018)).
    Output I of the Flatten layer is made available to all Transformer decoder layers,
    as is standard."

    Adapted from Tobias van der Werff's GitHub repo:
      full-page-handwriting-recognition
    '''
    def __init__(self, d_model, max_len=100):
        '''

        Args
        ---
        d_model: int
            the number of embedding dimensions in the model

        max_len: int
            the maximum size (in pixels) for the height and width
            of the feature map output from the encoder CNN
        '''
        # Inherit from parent class (nn.Module)
        super().__init__()

        self.max_len = max_len

        assert d_model % 4 == 0, f"Model dimensions must be divisible by 4. `d_model` is {d_model}, which is not divisible by 4."

        # Initialize positional encodings.
        # These are 2D tensors of max_len rows and (d_model / 2) columns.
        pe_x = torch.zeros((max_len, d_model // 2), requires_grad=False)
        pe_y = torch.zeros((max_len, d_model // 2), requires_grad=False)

        # A 2D tensor with max_len rows and 1 column. Each row holds the
        # value of that position (from 0 to max_len - 1)
        position = torch.arange(0, max_len).unsqueeze(1)
        
        # === COMMENT FROM TOBIAS van der WERFF ===
        # Div term is calculated in log space, as done in other implementations;
        # this is most likely for numerical stability (i.e., precision). The expression below is
        # equivalent to:
        #     div_term = 10000 ** (torch.arange(0, d_model // 2, 2) / d_model)
        # === END ===
        # Note: the expression below actually equals the reciprocal of that line,
        # i.e. 10000 ** (-torch.arange(0, d_model // 2, 2) / d_model).
        
        div_term = torch.exp(
            -math.log(10000.0) * torch.arange(0, d_model // 2, 2) / d_model
        )

        # Calculate the positional encodings
        # Indexing technique accesses all rows (y or x coordinate),
        # but skips every other column (embedding dimension).
        # Shape is (max_len, d_model / 4)
        pe_y[:, 0::2] = torch.sin(position * div_term)
        pe_y[:, 1::2] = torch.cos(position * div_term)
        pe_x[:, 0::2] = torch.sin(position * div_term)
        pe_x[:, 1::2] = torch.cos(position * div_term)

        # Save the encodings in the model's state without registering them as parameters
        # (a "buffer" is saved with the model, but is not a learnable parameter).
        self.register_buffer("pos_enc_y", pe_y)
        self.register_buffer("pos_enc_x", pe_x)
    
    def forward(self, x):
        '''
        Add a 2D positional encoding to an embedding tensor.

        Args
        ---
        `x`: Tensor of shape (batch_size, height, width, d_model)
            The embedding tensor to which positional encodings
            will be added. This is the layout produced by `ResNetEncoder`,
            which moves the channels of its (batch_size, d_model, height, width)
            feature map to the last dimension.
        
        Returns
        ---
        The sum of the input `x` (embedding tensor) with the positional
        encodings. Following Singh et al. (2021), the first d_model / 2
        channels encode the row (y) coordinate and the remaining channels
        encode the column (x) coordinate.
        '''
        _, h, w, _ = x.shape
        assert h <= self.pos_enc_y.size(0) and w <= self.pos_enc_x.size(0), (
            "The stored positional encodings do not have enough positions to support the input feature map."
            + f"\nInput feature map must be at most {self.pos_enc_y.size(0)} high by {self.pos_enc_x.size(0)} wide, "
            + f"\nbut the input provided is {h} high by {w} wide."
            + f"\nPlease re-initialize the `PositionalEncoding2D` instance and set the 'max_len' argument greater than or equal to {max(h, w)}."
        )

        # Broadcast each 1D encoding across the other spatial axis:
        # pe_y_ varies down the rows (y) and is constant along each row;
        # pe_x_ varies across the columns (x) and is constant down each column.
        pe_y_ = self.pos_enc_y[:h, :].unsqueeze(1).expand(-1, w, -1)    # shape: (h, w, d_model / 2)
        pe_x_ = self.pos_enc_x[:w, :].unsqueeze(0).expand(h, -1, -1)    # shape: (h, w, d_model / 2)
        # Concatenate along the channel dimension: y in the first half, x in the second
        pe = torch.cat([pe_y_, pe_x_], -1)                              # shape: (h, w, d_model)
        # Unsqueeze to broadcast across the batch dimension
        pe = pe.unsqueeze(0)                                            # shape: (1, h, w, d_model)
        
        return x + pe

## ResNet encoder
Adapted from: Tobias van der Werff's GitHub repo: [full-page-handwriting-recognition/src/models](https://github.com/tobiasvanderwerff/full-page-handwriting-recognition/blob/47a2d27fc40815898474e9f74badc7d544740fee/src/models.py#L250)

ResNet is included as one of the models in [`torchvision.models`](https://pytorch.org/vision/stable/models.html). The source code is found in the [PyTorch Torchvision GitHub repo](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py).
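Because the encoder keeps ResNet only up to its last convolutional stage, the size of the feature map (and therefore the length of the sequence the decoder attends over) follows from the standard convolution output-size formula. The sketch below, in plain Python with the hypothetical helper `resnet_feature_map_size`, assumes torchvision's ResNet layout: `conv1` (7×7, stride 2, padding 3), `maxpool` (3×3, stride 2, padding 1), and a stride-2 first block in each of `layer2`, `layer3`, and `layer4`, for five halvings in total. The example page size of 1900 × 2300 pixels is the `max_height` × `max_width` used in the commented-out dataset setup later in this notebook.

```python
from typing import Tuple


def resnet_feature_map_size(height: int, width: int) -> Tuple[int, int]:
    """Spatial size of a torchvision ResNet's output after removing avgpool and fc."""
    def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
        # Output-size formula shared by nn.Conv2d and nn.MaxPool2d (dilation 1)
        return (n + 2 * padding - kernel) // stride + 1

    h, w = height, width
    h, w = conv_out(h, 7, 2, 3), conv_out(w, 7, 2, 3)   # conv1
    h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)   # maxpool
    for _ in range(3):                                  # layer2, layer3, layer4
        h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
    return h, w


h, w = resnet_feature_map_size(1900, 2300)
print(h, w, h * w)  # 60 72 4320
```

Under these assumptions a full page becomes a 60 × 72 grid, i.e. 4320 positions of memory for cross-attention. Both sides must stay within the `max_len` of `PositionalEncoding2D` (default 100).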

In [None]:
class ResNetEncoder(nn.Module):
    """
    Takes as input grayscale (single-channel) images
    and outputs an encoding of the image that the decoder
    attends over to produce the sequential output (text).
    """

    def __init__(self, d_model: int, model_name: str, dropout: float):
        """
        Initialize the encoder part of the network.

        Parameters
        ---
        `d_model`: int
            The number of dimensions in the model's embedding space.
            Singh et al. (p. 9) used d_model=260.
        `model_name`: {'resnet18', 'resnet34', 'resnet50'}
            The model to use for the encoder. PyTorch supports a number
            of ResNet architectures, but the three implemented here are
            the ones used by Singh et al. (base model was resnet34)
        `dropout`: float
            The dropout probability (the fraction of elements zeroed
            during training) applied to the encoder's output.
        """
        # Inherit the methods of nn.Module
        super().__init__()

        # Check that inputs are specified correctly
        assert d_model % 4 == 0, f"Model embedding dimension ('d_model') must be divisible by 4. Provided value was: {d_model}."
        _models = ['resnet18', 'resnet34', 'resnet50']
        err_message = f"{model_name} is not one of the available models: {_models}"
        assert model_name in _models, err_message

        # Save attributes
        self.d_model = d_model
        self.model_name = model_name
        self.pos_encoding = PositionalEncoding2D(d_model)
        self.drop = nn.Dropout(p=dropout)

        # Load the resnet model from torchvision.models with random weights
        resnet = getattr(torchvision.models, model_name)(pretrained=False)
        modules = list(resnet.children())

        # Modify the first convolutional layer so its input takes a single-channel image
        # (i.e., grayscale rather than RGB)
        conv1 = modules[0]
        conv1 = nn.Conv2d(
            in_channels = 1,
            out_channels = conv1.out_channels,
            kernel_size = conv1.kernel_size,
            stride = conv1.stride,
            padding = conv1.padding,
            bias = conv1.bias is not None   # nn.Conv2d expects a bool here, not the bias tensor
        )

        # Create the encoder module by combining conv1 with the rest of the ResNet
        # except the last two layers, which will be replaced by 1x1 conv and positional encoding.
        # The final two layers of ResNet are avgpool and fully-connected linear. See: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L203
        self.encoder = nn.Sequential(conv1, *modules[1:-2])
        # Define the linear layer, a 1x1 convolution that takes as input the same number of features
        # as the final ResNet fully-connected layer (which we removed in the line above)
        self.linear = nn.Conv2d(in_channels=resnet.fc.in_features, out_channels=d_model, kernel_size=1)


    def forward(self, imgs):
        # imgs shape: (batch_size, 1, H, W), single-channel images
        x = self.encoder(imgs)                              # x shape: (batch_size, C, H', W'), C = resnet.fc.in_features (512 or 2048)
        x = self.linear(x).transpose(1, 2).transpose(2, 3)  # x shape: (batch_size, H', W', d_model)
        x = self.pos_encoding(x)                            # x shape: (batch_size, H', W', d_model)
        x = self.drop(x)                                    # x shape: (batch_size, H', W', d_model)
        x = x.flatten(1, 2)                                 # x shape: (batch_size, H'*W', d_model)
        return x
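After the two transposes, `ResNetEncoder.forward` holds a (batch_size, H', W', d_model) tensor, and `x.flatten(1, 2)` merges the two spatial axes in row-major order: feature-map cell (y, x) becomes sequence position y·W' + x. The plain-Python helpers below (hypothetical names, no PyTorch) illustrate only that index arithmetic. Because the flattened sequence no longer carries the grid explicitly, the 2D positional encoding is added before this step, as the Singh et al. quote describes.

```python
from typing import List, Tuple


def flatten_positions(height: int, width: int) -> List[Tuple[int, int]]:
    """Return the (y, x) feature-map cell found at each index of the flattened sequence."""
    # Row-major order: every column of row 0, then every column of row 1, ...
    return [(y, x) for y in range(height) for x in range(width)]


def sequence_index(y: int, x: int, width: int) -> int:
    """Position of feature-map cell (y, x) in the flattened sequence."""
    return y * width + x


order = flatten_positions(height=2, width=3)
print(order)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```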

## Transformer decoder
In "[Full Page Handwriting Recognition via Image to Sequence Extraction](https://paperswithcode.com/paper/full-page-handwriting-recognition-via-image)" (2021), Singh et al. describe the hyperparameters they used for the Transformer decoder on pages 8-9 (section 4: Training Configuration and Procedure).

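PyTorch has a built-in module [`torch.nn.TransformerDecoder`](https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html), a stack of `nn.TransformerDecoderLayer` blocks. Each layer applies masked self-attention over the tokens generated so far, cross-attention over the encoder's output (the "memory"), and a feed-forward network.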

This is not to be confused with the `torch.nn.Transformer` module which implements the full encoder-decoder Transformer from the paper "Attention is All You Need" by Vaswani et al. (2017).

Adapted from: Tobias van der Werff's GitHub repo: [full-page-handwriting-recognition -> models.py](https://github.com/tobiasvanderwerff/full-page-handwriting-recognition/blob/47a2d27fc40815898474e9f74badc7d544740fee/src/models.py#L107)
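During training, `decode_teacher_forcing` (below) builds three inputs before calling `nn.TransformerDecoder`: the target shifted right by one token, a causal mask (`tgt_mask`), and a padding mask (`tgt_key_padding_mask`). The plain-Python sketch below shows all three on a made-up sequence; the token indices and helper names are hypothetical. In PyTorch's boolean masks, `True` marks a position that may **not** be attended to, which is the convention `subsequent_mask` follows.

```python
from typing import List


def shift_right(target: List[int], sos_idx: int) -> List[int]:
    """Decoder input for teacher forcing: prepend <START> and drop the last token."""
    return [sos_idx] + target[:-1]


def causal_mask(size: int) -> List[List[bool]]:
    """True where attention is blocked: position `row` may not see any `col` > row."""
    return [[col > row for col in range(size)] for row in range(size)]


def key_padding_mask(tokens: List[int], pad_idx: int) -> List[bool]:
    """True at padding positions, which no query should attend to."""
    return [tok == pad_idx for tok in tokens]


# Hypothetical indices: 0 = <PAD>, 1 = <START>, 2 = <END>, 5-7 = characters
target = [5, 6, 7, 2, 0, 0]
dec_in = shift_right(target, sos_idx=1)
print(dec_in)                       # [1, 5, 6, 7, 2, 0]
print(key_padding_mask(dec_in, 0))  # [False, False, False, False, False, True]
print(causal_mask(3))               # [[False, True, True], [False, False, True], [False, False, False]]
```

At position t the decoder sees `<START>` plus the true tokens before t and is trained to output `target[t]`; the causal mask stops it from peeking at the answer.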

In [None]:
class TransformerDecoder(nn.Module):
    # Class-level variables (with type annotations)
    decoder: nn.TransformerDecoder
    clf: nn.Linear
    emb: nn.Embedding
    pos_enc: PositionalEncoding1D
    drop: nn.Dropout
    
    vocab_len: int
    max_seq_len: int
    eos_token_idx: int
    sos_token_idx: int
    pad_token_idx: int
    d_model: int
    num_layers: int
    num_heads: int
    dim_feedforward: int
    dropout: float
    activation_fn: str


    def __init__(
        self,
        vocab_len: int,
        max_seq_len: int,
        eos_token_idx: int,
        sos_token_idx: int,
        pad_token_idx: int,
        d_model: int,
        num_layers: int,
        num_heads: int,
        dim_feedforward: int,
        dropout: float,
        activation_fn: str = 'gelu'
    ):
        # Inherit methods from the base class (nn.Module)
        super().__init__()
        assert d_model % 4 == 0, f"Model dimensions must be divisible by 4. `d_model` is {d_model}, which is not divisible by 4."

        self.vocab_len = vocab_len
        self.max_seq_len = max_seq_len
        self.eos_token_idx = eos_token_idx
        self.sos_token_idx = sos_token_idx
        self.pad_token_idx = pad_token_idx
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation_fn = activation_fn

        self.emb = nn.Embedding(vocab_len, d_model)
        self.pos_enc = PositionalEncoding1D(d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model = d_model,
            nhead = num_heads,
            dim_feedforward = dim_feedforward,
            dropout = dropout,
            activation = activation_fn,
            batch_first = True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.clf = nn.Linear(d_model, vocab_len)
        self.drop = nn.Dropout(p=dropout)


    def forward(self, memory: torch.Tensor):
        '''
        Generate text from the encoder output `memory` with greedy decoding.

        At each step, the sequence generated so far is passed through the
        Transformer decoder, the final linear layer (`clf`) produces logits
        over the vocabulary for the last position, and the token with the
        highest logit is appended. Taking the argmax of the logits selects
        the same token as taking the argmax of their softmax, so no softmax
        is applied here.

        In other words, this uses "greedy decoding" rather than "beam search,"
        since only the single most probable token is kept at each step.

        Returns a 2-tuple of (all_logits, sampled_ids), where `sampled_ids`
        starts with the <START> token.
        '''
        B, _, _ = memory.shape
        # Logits are the raw, unnormalized scores output by the final linear
        # layer; the softmax function turns them into probabilities that sum to 1.
        # https://en.wikipedia.org/wiki/Logit
        # https://deepai.org/machine-learning-glossary-and-terms/logit
        # https://datascience.stackexchange.com/questions/31041/what-does-logits-in-machine-learning-mean/31045#31045
        # https://stackoverflow.com/questions/41455101/what-is-the-meaning-of-the-word-logits-in-tensorflow
        all_logits = []

        # Set start-of-sequence tokens for each item in the batch
        sampled_ids = [torch.full(size=[B], fill_value=self.sos_token_idx).to(memory.device)]
        
        # target's shape is: (batch_size, 1, d_model)
        # Following Vaswani et al. (2017, section 3.4), the token embeddings are
        # multiplied by sqrt(d_model) before the positional encodings are added.
        # This is separate from the 1/sqrt(d_k) scaling of dot products inside
        # attention, which nn.TransformerDecoderLayer already applies.
        target = self.pos_enc(
            self.emb(sampled_ids[0]).unsqueeze(1) * math.sqrt(self.d_model)
        )
        target = self.drop(target)

        # Track which items in the batch have produced an end-of-sequence token
        eos_sampled = torch.zeros(B, dtype=torch.bool, device=memory.device)

        for _ in range(self.max_seq_len):
            target_mask = self.subsequent_mask(size=len(sampled_ids)).to(memory.device)
            out = self.decoder(target, memory, tgt_mask=target_mask)    # shape: (batch_size, tokens, d_model)
            # Apply the linear classifier to the decoder output for the last
            #   token only (across all batch items and all d_model dimensions)
            logits = self.clf(out[:, -1, :])                            # shape: (batch_size, vocab_size)
            _, pred = torch.max(input=logits, dim=-1)
            all_logits.append(logits)
            sampled_ids.append(pred)
            # Mark the items whose prediction is the end-of-sequence token
            eos_sampled |= pred == self.eos_token_idx
            # Exit the loop once every item in the batch has predicted an
            # end-of-sequence token, or once max_seq_len tokens have been
            # generated (the next input built below would never be used, and
            # its position index could exceed the positional-encoding table)
            if eos_sampled.all() or len(sampled_ids) > self.max_seq_len:
                break
            # Add position encodings to the (scaled) embeddings of the predicted tokens
            target_ext = self.drop(
                self.pos_enc.pos_enc[:, len(sampled_ids)-1] + self.emb(pred) * math.sqrt(self.d_model)
            ).unsqueeze(1)
            # Concatenate along dimension 1 (tokens)
            target = torch.cat([target, target_ext], 1)
        
        # torch.stack concatenates a list of tensors along a new dimension
        sampled_ids = torch.stack(sampled_ids, dim=1)
        all_logits = torch.stack(all_logits, dim=1)

        # Replace all tokens after the first EOS token with the pad token.
        # argmax returns the first EOS position; it is 0 (the <START> slot)
        # when a sequence contains no EOS token at all.
        eos_indexes = (sampled_ids == self.eos_token_idx).float().argmax(dim=1)
        for batch_item in range(B):
            if eos_indexes[batch_item] != 0:
                # The sampled sequence has an EOS token,
                # so set all tokens after it to the pad token
                sampled_ids[batch_item, eos_indexes[batch_item] + 1:] = self.pad_token_idx
        
        return all_logits, sampled_ids

    
    def decode_teacher_forcing(self, memory: torch.Tensor, target: torch.Tensor):
        '''
        Decode all positions in parallel with "teacher forcing" (used in training).

        Instead of feeding back its own predictions, the decoder receives the
        ground-truth target sequence shifted one token to the right (prefixed
        with the <START> token), so the prediction at position t is conditioned
        on the true tokens before t. A causal mask keeps each position from
        attending to later positions, and every layer cross-attends over the
        ResNet encoder's output.

        Returns the logits over the vocabulary (unnormalized scores; a softmax
        would turn them into probabilities, but the argmax is the same), with
        dimensions: (batch_size, tokens, vocab_len)

        Args
        ---
        `memory`: a Tensor of shape (batch_size, w*h, d_model)
            This is the tensor output by the encoder model. It is used
            as the Key and Value vectors for attention.
        `target`: a Tensor of shape (batch_size, tokens)
            The target tokens being predicted.
        '''
        B, T = target.shape

        # Shift elements of target to the right, for predicting
        # the next item in the sequence (teacher forcing)
        target = torch.cat(
            [
                torch.full([B], fill_value=self.sos_token_idx).unsqueeze(1).to(memory.device),
                target[:, :-1]
            ],
            dim=1
        )

        # Implement masking for causal self-attention (where all future tokens
        # cannot be used in the prediction).
        # == COMMENT FROM TOBIAS VAN DER WERFF ==
        # This is a combination of pad token masking (target_key_padding_mask)
        # and causal self-attention masking (target_mask), where the target_mask
        # is of shape (T, T) and the targets are shifted to the right by one.
        # == END OF COMMENT FROM TOBIAS VAN DER WERFF ==
        # The comparison on the right-hand side is evaluated first: it checks
        # whether each token in `target` equals pad_token_idx, producing a
        # Boolean tensor (True at padding positions) that is then assigned to
        # target_key_padding_mask.
        target_key_padding_mask = target == self.pad_token_idx
        target_mask = self.subsequent_mask(size=T).to(target.device)

        target = self.pos_enc(self.emb(target) * math.sqrt(self.d_model))
        target = self.drop(target)
        out = self.decoder(
            target, memory, tgt_mask=target_mask, tgt_key_padding_mask=target_key_padding_mask
        )
        logits = self.clf(out)
        return logits


    # The @staticmethod decorator places a standalone function within the
    # scope of a class, to show that the function is somehow related to
    # the class, even though the function does not require the class nor
    # any instantiated class object to perform its operations.
    # See: https://stackoverflow.com/questions/23508248/why-do-we-use-staticmethod
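    @staticmethod
    def subsequent_mask(size: int):
        # Boolean (size x size) mask that is True above the diagonal, i.e. where
        # position i would attend to a later position j > i. PyTorch treats True
        # in a boolean attention mask as "not allowed to attend."
        mask = torch.triu(torch.ones(size, size), diagonal=1)
        return mask == 1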
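The greedy loop in `TransformerDecoder.forward` can be traced without PyTorch. The sketch below mirrors its control flow for a batch: every sequence starts with `<START>`, each step appends the argmax token, the loop stops once every sequence has produced `<END>`, and tokens after the first `<END>` are replaced with `<PAD>`. `toy_logits` is a hypothetical stand-in for the decoder plus classifier that simply spells out fixed target sequences; the vocabulary indices are made up.

```python
from typing import Callable, List, Sequence


def greedy_decode_batch(
    next_logits: Callable[[int, List[int]], Sequence[float]],
    batch_size: int, sos_idx: int, eos_idx: int, pad_idx: int, max_seq_len: int,
) -> List[List[int]]:
    seqs = [[sos_idx] for _ in range(batch_size)]
    for _ in range(max_seq_len):
        for b, seq in enumerate(seqs):
            logits = next_logits(b, seq)
            # argmax of logits == argmax of softmax(logits), since softmax is monotonic
            seq.append(max(range(len(logits)), key=logits.__getitem__))
        # Stop once every sequence contains <END> (finished ones keep generating until then)
        if all(eos_idx in seq[1:] for seq in seqs):
            break
    # Replace everything after the first <END> with <PAD>
    for seq in seqs:
        if eos_idx in seq[1:]:
            cut = seq.index(eos_idx, 1) + 1
            seq[cut:] = [pad_idx] * (len(seq) - cut)
    return seqs


PAD, START, END, A, B = 0, 1, 2, 3, 4
scripts = [[A, B, END], [B, END]]  # what each toy "model" will spell out


def toy_logits(b: int, seq: List[int]) -> List[float]:
    step = len(seq) - 1  # tokens generated so far
    want = scripts[b][step] if step < len(scripts[b]) else A  # emits junk after <END>
    return [1.0 if k == want else 0.0 for k in range(5)]


print(greedy_decode_batch(toy_logits, 2, START, END, PAD, max_seq_len=10))
# [[1, 3, 4, 2], [1, 4, 2, 0]]
```

The second sequence finishes one step early, keeps generating while the first catches up, and its extra token is then overwritten with `<PAD>`, which is why the post-processing step needs each item's own EOS position.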

## Complete network

Adapted from: Tobias van der Werff's GitHub repo: [full-page-handwriting-recognition -> models.py](https://github.com/tobiasvanderwerff/full-page-handwriting-recognition/blob/47a2d27fc40815898474e9f74badc7d544740fee/src/models.py#L304).

For the model's hyperparameters, see [pp. 8-9 of the PDF](https://arxiv.org/pdf/2103.06450v2.pdf) of "Full Page Handwriting Recognition via Image to Sequence Extraction" by Singh et al. (2021).

In [None]:
from typing import Callable, Optional  # used in the type annotations below

class HTRmodel(nn.Module):
    # Class-level variables with type annotations
    encoder: ResNetEncoder
    decoder: TransformerDecoder
    cer_metric: CharacterErrorRate
    loss_fn: Callable

    def __init__(
        self,
        max_seq_len: int = 3500,
        d_model: int = 260,
        num_layers: int = 6,
        num_heads: int = 4,
        dim_feedforward: int = 1024,
        encoder_name: str = 'resnet18',
        drop_enc: float = 0.5,
        drop_dec: float = 0.5,
        activation_dec: str = 'gelu',
        label_smoothing: float = 0.0,
        vocab_len: Optional[int] = None,
        vocab_dict: dict = token_to_index
    ):
        '''
        A complete neural network for full-page handwritten text recognition.
        Incorporates a convolutional neural network encoder (`ResNetEncoder`)
        with a Transformer decoder (`TransformerDecoder`). The encoder takes a
        grayscale (or binary) image as input and outputs a feature map. Each layer
        in the decoder attends over the feature map output from the encoder, along
        with masked self-attention on previous output tokens.

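        Parameters
        ---
        max_seq_len: int, default=3500
            Maximum number of tokens generated during greedy decoding.
        d_model: int, default=260
            Embedding dimension shared by the encoder and decoder
            (must be divisible by 4).
        num_layers: int, default=6
            Number of Transformer decoder layers.
        num_heads: int, default=4
            Number of attention heads in each decoder layer.
        dim_feedforward: int, default=1024
            Hidden size of the feed-forward block in each decoder layer.
        encoder_name: {'resnet18', 'resnet34', 'resnet50'}, default='resnet18'
            ResNet architecture used as the image encoder.
        drop_enc: float, default=0.5
            Dropout probability applied to the encoder's output.
        drop_dec: float, default=0.5
            Dropout probability used in the decoder.
        activation_dec: str, default='gelu'
            Activation function of the decoder's feed-forward blocks.
        label_smoothing: float, default=0.0
            The label-smoothing epsilon setting for cross-entropy loss.
            0 means no smoothing.
        vocab_len: Optional[int], default=None
            Size of the output vocabulary. If None, it is inferred from
            `vocab_dict` as the largest token index plus one.
        vocab_dict: dict, default=token_to_index
            Mapping from tokens to integer indices; must contain the
            '<START>', '<END>', and '<PAD>' tokens.
        '''
        super().__init__()

        # Infer the vocabulary size when it is not given explicitly
        # (assumes the token indices run from 0 up to the largest index)
        if vocab_len is None:
            vocab_len = max(vocab_dict.values()) + 1
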
        # Get the indices for the special tokens
        self.sos_token_idx = vocab_dict['<START>']
        self.eos_token_idx = vocab_dict['<END>']
        self.pad_token_idx = vocab_dict['<PAD>']

        # Initialize encoder and decoder
        self.encoder = ResNetEncoder(
            d_model=d_model, model_name=encoder_name, dropout=drop_enc
        )

        self.decoder = TransformerDecoder(
            vocab_len = vocab_len,
            max_seq_len = max_seq_len,
            eos_token_idx = self.eos_token_idx,
            sos_token_idx = self.sos_token_idx,
            pad_token_idx = self.pad_token_idx,
            d_model = d_model,
            num_layers = num_layers,
            num_heads = num_heads,
            dim_feedforward = dim_feedforward,
            dropout = drop_dec,
            activation_fn = activation_dec
        )

        # Initialize metrics and loss function
        self.cer_metric = CharacterErrorRate()
        self.loss_fn = nn.CrossEntropyLoss(
            ignore_index=self.pad_token_idx, label_smoothing=label_smoothing
        )


    def forward(self, imgs, targets=None):
        '''
        Run inference on the model using greedy decoding,
        that is, return predicted text from the input `imgs`.

        Returns a 3-tuple of (logits, predictions, loss)
        - logits: unnormalized scores over the vocabulary at each
            generated step
        - predictions: argmax of the logits at each sequential step,
            preceded by the <START> token
        - loss: if `targets` is not None, this value holds the loss
            of predictions vs. targets; otherwise None
        
        Parameters
        ---
        imgs: Tensor
            A Tensor of shape: (batch_size, 1, height, width)
        targets: Tensor
            A Tensor of shape (batch_size, tokens)
        '''
        # Encoder input dims: (batch_size, 1, height, width)
        # Encoder output / decoder input dims: (batch_size, H'*W', d_model)
        logits, predictions = self.decoder(self.encoder(imgs))
        # logits dims: (batch_size, steps, vocab_len), where steps is at most
        #  max_seq_len (decoding stops early once every item has produced <END>).
        #  In `predictions`, tokens after the first <END> are replaced by <PAD>.
        loss = None
        if targets is not None:
            # print(f"Target shape: {targets.shape}")
            # print(f"Logits shape: {logits.shape}")
            # print(f"Predictions shape: {predictions.shape}")
            # # Target shape: torch.Size([1, 3500])
            # # Logits shape: torch.Size([1, 3500, 116])
            # # Predictions shape: torch.Size([1, 3500])
            loss = self.loss_fn(
                logits[:, :targets.size(1), :].transpose(1, 2),
                targets[:, :logits.size(1)]
            )
        return logits, predictions, loss


    def forward_teacher_forcing(
        self, imgs: torch.Tensor, targets: torch.Tensor):
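        '''
        Compute logits and loss with teacher forcing (used for training).

        Teacher forcing means that the decoder receives the ground-truth
        tokens (shifted right by one) as input instead of its own
        predictions, so all positions are predicted in a single parallel
        pass rather than by step-by-step decoding.

        Returns:
        - logits: unnormalized scores over the output classes (that is,
            the model's character-level vocabulary), with shape
            (batch_size, tokens, vocab_len)
        - loss: the cross-entropy loss of the logits vs. `targets`
        '''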
        memory = self.encoder(imgs)
        logits = self.decoder.decode_teacher_forcing(memory, targets)
        loss = self.loss_fn(logits.transpose(1, 2), targets)
        return logits, loss


    def calculate_char_error(self, predictions, targets):
        self.cer_metric.reset()
        cer = self.cer_metric(predictions, targets)
        return cer
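`calculate_char_error` delegates to `CharacterErrorRate`, which is defined elsewhere in this notebook. For reference, the usual definition of character error rate is the Levenshtein (edit) distance between prediction and reference divided by the reference length. The plain-Python sketch below (hypothetical helper names) implements that definition on strings; it is not necessarily how `CharacterErrorRate` is implemented, which may, for example, work on token-index tensors or handle special tokens differently.

```python
from typing import List


def levenshtein(a: str, b: str) -> int:
    """Minimum number of insertions, deletions, and substitutions turning a into b."""
    prev = list(range(len(b) + 1))  # distances from "" to each prefix of b
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute (free if the characters match)
            ))
        prev = curr
    return prev[-1]


def character_error_rate(predictions: List[str], references: List[str]) -> float:
    """Total edit distance divided by total reference length, over a batch."""
    edits = sum(levenshtein(p, r) for p, r in zip(predictions, references))
    total = sum(len(r) for r in references)
    return edits / total


print(character_error_rate(["helo wrld"], ["hello world"]))  # 2 edits / 11 reference characters
```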

## Model training

### Set training variables

In [None]:
# Create file to save model performance metrics
if not os.path.exists('./model_performance.csv'):
    with open('./model_performance.csv', mode='a', encoding='UTF-8') as csv_file:
        csv_file.write(','.join(
            ['Batch number', 'Training loss', 'Validation loss', 'Validation error']) 
            + '\n')

# Create .csv file to save checkpoint performance info
if not os.path.exists('./checkpoint_comparisons.csv'):
    with open('./checkpoint_comparisons.csv', mode='a', encoding='UTF-8') as chkpt_file:
        chkpt_file.write(','.join(
            ['Batch number', 'Loss', 'Character error rate'])
            + '\n')

### Training loop

#### Saving checkpoints
To save periodic checkpoints, create a checkpoint dictionary with at least the `model.state_dict()` and `optimizer.state_dict()` as entries in the dictionary. You may also want to include entries for `epoch`, `loss`, or `nn.Embedding()` layers that occur outside of the main `model`. Next, use `torch.save(checkpoint_dict, filepath)` with a `.pt` or `.pth` extension in `filepath` and the saved file will be a serialized dictionary that holds all the information you wanted to capture.

To load the saved checkpoint, use `checkpoint = torch.load(checkpoint_filepath)`, then initialize the model and optimizer as normal, then run `model.load_state_dict(checkpoint['model_state_dict'])` to load the parameters and registered buffers (such as a BatchNorm layer's running mean) into the model. Similarly, run `optimizer.load_state_dict(checkpoint['optim_state_dict'])` to restore the optimizer's state (for Adam, its hyperparameters such as the learning rate, plus its running moment estimates) into the initialized optimizer instance. Notice that each saved `state_dict()` is simply a value in a dictionary, accessed with the key 'model_state_dict' or 'optim_state_dict'; `torch.load()` reads the checkpoint file back into that Python dictionary.
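A minimal sketch of the save/load cycle described above, using a tiny `nn.Linear` model and Adam as stand-ins for `HTRmodel` and its optimizer. The keys `'model_state_dict'` and `'optim_state_dict'` follow the paragraph above; `'batch_number'` and `'loss'` are illustrative extras, and the temporary file path is only for the demo. `map_location='cpu'` lets a checkpoint saved on a GPU load on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-in model and optimizer (any nn.Module works the same way)
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step so that Adam has internal state worth saving
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

checkpoint = {
    'batch_number': 1,
    'loss': loss.item(),
    'model_state_dict': model.state_dict(),
    'optim_state_dict': optimizer.state_dict(),
}
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')
torch.save(checkpoint, path)

# Later: rebuild the objects first, then load the saved state into them
restored_model = nn.Linear(4, 2)
restored_optim = torch.optim.Adam(restored_model.parameters(), lr=1e-3)
loaded = torch.load(path, map_location='cpu')
restored_model.load_state_dict(loaded['model_state_dict'])
restored_optim.load_state_dict(loaded['optim_state_dict'])
restored_model.eval()  # switch dropout/BatchNorm to inference behavior before predicting
```

To resume training instead of predicting, call `restored_model.train()` and continue from `loaded['batch_number']`.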

**References**
* [PyTorch tutorial: Saving and loading models](https://pytorch.org/tutorials/beginner/saving_loading_models.html), the most detailed reference, with helpful information on what a state dictionary is, how to configure a checkpoint dictionary, and a reminder to call `model.eval()` to put the model in inference mode before using a loaded model for predictions. The train/eval mode is not stored in the state dictionary, and a newly initialized model starts in training mode (`model.train()`).
* [Weights and Biases blog: Saving checkpoints in PyTorch](https://wandb.ai/wandb/common-ml-errors/reports/How-to-Save-and-Load-Models-in-PyTorch--VmlldzozMjg0MTE), quick and helpful reference with all key information included.
* [PyTorch recipes: Saving and loading a general checkpoint](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html), takes a slightly different approach and clearly demonstrates the checkpoint dictionary.
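A minimal round trip through the save and load steps described above, assuming a tiny `nn.Linear` stand-in for `HTRmodel` and a temporary file path (both hypothetical). The dictionary keys mirror the ones `train_model` uses (`'model_state'` and `'optimizer_state'`); whatever keys you choose, loading has to use the same ones.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-ins for HTRmodel and its Adam optimizer
model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.9, 0.999))

# Take one optimizer step so the optimizer has some state worth saving
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optim.step()
optim.zero_grad()

checkpoint_dict = {
    'model_state': model.state_dict(),
    'optimizer_state': optim.state_dict(),
    'epoch_num': 1,
    'loss': loss.item(),  # store a plain float, not the loss Tensor
}
checkpoint_filepath = os.path.join(tempfile.gettempdir(), 'checkpoint_example.pt')
torch.save(checkpoint_dict, checkpoint_filepath)

# Later (e.g. in a new session): re-create the model and optimizer, then restore their state
restored_model = nn.Linear(4, 2)
restored_optim = torch.optim.Adam(restored_model.parameters(), lr=0.0002, betas=(0.9, 0.999))

checkpoint = torch.load(checkpoint_filepath, map_location='cpu')
restored_model.load_state_dict(checkpoint['model_state'])
restored_optim.load_state_dict(checkpoint['optimizer_state'])

# The mode is not saved, so switch to evaluation mode before running predictions
restored_model.eval()
```

Passing `map_location='cpu'` lets a checkpoint saved from Colab's GPU load on a machine without one; the model can be moved with `.to(device)` afterwards.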

In [None]:
# def initialize(
#     batch_size = 4,
#     vocab_dict = token_to_index,
#     device = device):
#     """
#     Initialize the datasets, dataloaders, neural network (model), 
#     and optimizer (Adam).

#     Returns a tuple of the format:
#         (train_loader, val_loader, model, optimizer)
#     """

#     # Initialize the training and validation datasets
#     train_dataset = FPHRdataset(
#         path_to_image_folder = r'/content/drive/MyDrive/School/Deep learning final project/training_data/processed_images',
#         path_to_label_folder = r'/content/drive/MyDrive/School/Deep learning final project/training_data/labels',
#         max_width = 2300,
#         max_height = 1900,
#         augmentation_likelihood = 0.5,
#         train = True,
#         vocab_dict = vocab_dict
#     )

#     val_dataset = FPHRdataset(
#         path_to_image_folder = r'/content/drive/MyDrive/School/Deep learning final project/training_data/processed_images',
#         path_to_label_folder = r'/content/drive/MyDrive/School/Deep learning final project/training_data/labels',
#         max_width = 2300,
#         max_height = 1900,
#         augmentation_likelihood = 0.5,
#         train = False,
#         vocab_dict = vocab_dict
#     )

#     # ==============================
#     # Initialize data loaders
#     # ==============================
#     # Shuffle can be set to False to save memory. See: https://medium.com/@raghadalghonaim/memory-leakage-with-pytorch-23f15203faa4
#     # Note that pin_memory cannot be set to True, or an exception will raise that says the Tensors must be dense float Tensors.
#     train_loader = torch.utils.data.DataLoader(
#         train_dataset, batch_size=batch_size, shuffle=True)
#     val_loader = torch.utils.data.DataLoader(
#         val_dataset, batch_size=batch_size, shuffle=True)
    
#     # ==============================
#     # Initialize model and optimizer
#     # ==============================

#     # Create the model and place it on the available device (CPU or GPU)
#     model = HTRmodel(
#         max_seq_len = 3500,
#         d_model = 260,
#         num_layers = 6,
#         num_heads = 4,
#         dim_feedforward = 1024,
#         encoder_name = 'resnet18',
#         drop_enc = 0.2,
#         drop_dec = 0.2,
#         activation_dec = 'gelu',
#         label_smoothing = 0,
#         vocab_len = len(vocab_dict),
#         vocab_dict = vocab_dict
#     ).to(device)

#     print("Model loaded. Total number of parameters: "
#         f"{sum([layer.numel() for layer in model.parameters()]):,d}.")

#     # Initialize the optimizer.
#     # Hyperparameters come from Singh et al., page 9.
#     optim = torch.optim.Adam(
#         model.parameters(), lr=0.0002, betas=(0.9, 0.999)
#     )

#     return train_loader, val_loader, model, optim


In [None]:
# def process_one_batch(model):
#     pass

In [None]:
# def validate_one_batch(model):
#     pass

In [None]:
# def log_metrics(model):
#     pass

In [None]:
# def save_checkpoint(model):
#     pass

In [None]:
# def train_model_split_into_smaller_functions(
#     num_epochs,
#     batch_size = 4,
#     save_every = 500,
#     grad_accum_factor = 8,
#     vocab_dict = token_to_index,
#     device = device):
#     """
#     Loop for training the neural network on full-page handwritten
#     text recognition.

#     Returns a tuple of lists of the format:
#         `(batch_num, loss_train, loss_val, error_val)`,
#     where
#         `batch_num` is a list of batch numbers when model performance was logged
#         `loss_train` is a list of loss values during model training at each
#             `batch_num`
#         `loss_val` is a list of average loss values from the validation set,
#             captured at each `batch_num`
#         `error_val` is a list of the average character error rate from the
#             validation set, captured at each `batch_num`
#     """
    
#     # Clear unused variables from memory
#     gc.collect()

#     # Create lists to store model performance information.
#     batch_num = []
#     loss_train = []
#     loss_val = []
#     error_val = []

#     # Initialize the datasets, dataloaders, neural network, and optimizer
#     train_loader, val_loader, model, optim = initialize(
#         batch_size = batch_size,
#         vocab_dict = vocab_dict,
#         device = device)

#     # Determine the total number of iterations (batches) that will be processed
#     #  during training, for setting the 'total' argument of tqdm's progress bar
    
#     # Set number of iterations (batches) to use during validation loop.
#     # len(val_loader) will use all available items in the validation set.
#     # At the end of the validation loop, the loss and error values from each
#     # batch will be averaged, returning an average loss and avg. error rate.
#     val_iterations = 5      # much fewer than len(val_loader)
    
#     # Calculate total iterations (training plus validation)
#     total_iterations = (
#         (num_epochs * len(train_loader))
#       + ((num_epochs * len(train_loader)) // save_every) * val_iterations
#     )

#     iter_count = 0
#     iters_per_epoch = len(train_loader)
#     progress_bar = tqdm(total=total_iterations)
#     progress_bar.set_description('Beginning model training')
#     for epoch in range(num_epochs):
#         # Process one batch of images at a time in a loop through
#         #  all images in the training dataset.
#         for imgs, lbls in train_loader:
#             # Make sure the Tensors returned by the DataLoader are placed
#             #  on the GPU to avoid running out of CPU memory.
#             #  See: https://medium.com/@raghadalghonaim/memory-leakage-with-pytorch-23f15203faa4
#             # Note that this step is unnecessary since I add the Tensors to the
#             #  correct device in the __getitem__ function of my Dataset class.
#             # imgs = imgs.to(device)
#             # lbls = lbls.to(device)
#             # Alternative version:
#             # imgs = imgs.cuda(non_blocking=True)
#             # lbls = lbls.cuda(non_blocking=True)
            
#             # Clear unused variables from memory
#             gc.collect()

#             iter_count += 1

#             # Set BatchNorm and Dropout layers to training mode
#             model.train()

#             # Run data through model
#             train_logits, train_loss = model.forward_teacher_forcing(imgs=imgs, targets=lbls)
#             # Training predictions are not needed, so delete variable to save memory
#             del train_logits

#             # =========================
#             # Update model parameters
#             # =========================
#             # See: https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch/62076913#62076913
#             # Scale loss by the number of batches to accumulate
#             loss = train_loss / grad_accum_factor
#             # Backpropagate loss
#             loss.backward()
#             if iter_count % grad_accum_factor == 0:
#                 # Update model parameters using the accumulated gradients
#                 optim.step()
#                 # Reset gradients if accumulation factor is reached
#                 optim.zero_grad()

#             # =========================
#             # Run validation
#             # =========================
#             if iter_count % save_every == 0:
#                 with torch.no_grad():
#                     # Set model to inference mode so BatchNorm
#                     # and Dropout layers perform consistently.
#                     model.eval()
#                     # Capture all losses and errors to average at the end.
#                     val_losses = []
#                     val_errors = []
#                     # Loop through validation set.
#                     for i, (val_imgs, val_lbls) in enumerate(val_loader):
#                         val_logits, val_predictions, val_loss = model.forward(imgs=val_imgs, targets=val_lbls)
#                         val_error = model.calculate_char_error(predictions=val_predictions, targets=val_lbls)
#                         val_losses.append(val_loss.item())
#                         val_errors.append(float(val_error))
#                         progress_bar.set_description(
#                             f"Validation loop ({i+1}/{val_iterations}). "
#                             f"Loss: {val_losses[-1]:,.2f}, "
#                             f"Error: {val_errors[-1]:.1%} | ")
#                         progress_bar.update(n=1)
#                         progress_bar.refresh()
#                         if i + 1 >= val_iterations:
#                             break
#                 # Average the loss and error values
#                 val_loss = np.mean(val_losses)
#                 val_error = np.mean(val_errors)

#                 # =========================
#                 # Log model metrics
#                 # =========================
#                 batch_num.append(iter_count)
#                 loss_train.append(train_loss.detach().item())
#                 loss_val.append(val_loss)
#                 error_val.append(val_error)
#                 with open('./model_performance.csv', mode='a', encoding='UTF-8') as csv_file:
#                     csv_file.write(','.join(
#                         map(str, list(zip(batch_num, loss_train, loss_val, error_val))[-1]))
#                         + '\n'
#                     )
#                 progress_bar.set_description(
#                     f"Resuming training. "
#                     f"Epoch: {epoch+1}/{num_epochs}, "
#                     f"Avg. val. loss: {loss_val[-1]:,.2f}, "
#                     f"Avg. val. error: {error_val[-1]:.1%} | ")

#                 # =========================
#                 # Save checkpoint
#                 # =========================
#                 checkpoint_dict = {
#                     'model_state': model.state_dict(),
#                     'optimizer_state': optim.state_dict(),
#                     'epoch_num': epoch + 1,
#                     'loss': loss_val[-1]
#                 }
#                 checkpoint_filepath = f"./checkpoint_{str(iter_count).zfill(5)}.pt"
#                 torch.save(checkpoint_dict, checkpoint_filepath)

#                 # Save model accuracy
#                 # model_info = f"Epoch: {str(epoch+1).zfill(2)}, loss: {loss_val[-1]:,.3f}, character error rate: {error_val[-1]:.2%}"
#                 with open('./checkpoint_comparisons.csv', mode='a', encoding='UTF-8') as chkpt_file:
#                     chkpt_file.write(','.join(
#                         map(str, [iter_count, loss_val[-1], error_val[-1]]))
#                         + '\n'
#                     )
            
#             # =========================
#             # Update progress bar
#             # =========================
#             progress_bar.update(n=1)
    
#     return batch_num, loss_train, loss_val, error_val


Original training loop (all-in-one function).

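💡 Better implementation of logging to a .csv file: use the `csv` module's `csv.writer()`, and open the file with `newline=''` (an argument to `open()`, not to `csv.writer()`) to prevent an extra blank line between rows. The writer also converts numbers to strings for you, whereas `','.join()` raises a `TypeError` when given ints or floats.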

See:
* https://www.adamsmith.haus/python/answers/how-to-write-to-a-%60.csv%60-file-without-blank-lines-in-python
* https://docs.python.org/3/library/csv.html?highlight=csv#id3
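A sketch of that approach, written as a small hypothetical helper `append_row()` that could replace the manual `','.join()` calls in the training loop. The header matches the one created for `model_performance.csv`; the temporary path and the numbers in the usage lines are placeholders, not real results.

```python
import csv
import os
import tempfile


def append_row(filepath, row, header=None):
    """Append one row to a .csv file, writing `header` first if the file is new."""
    write_header = header is not None and not os.path.exists(filepath)
    # newline='' goes to open(): csv.writer ends rows with '\r\n', and without it
    # text mode on Windows turns that into '\r\r\n', which shows up as blank lines
    with open(filepath, mode='a', newline='', encoding='UTF-8') as csv_file:
        writer = csv.writer(csv_file)
        if write_header:
            writer.writerow(header)
        writer.writerow(row)  # ints and floats are converted to strings automatically


# Usage with placeholder values
header = ['Batch number', 'Training loss', 'Validation loss', 'Validation error']
path = os.path.join(tempfile.mkdtemp(), 'model_performance.csv')
append_row(path, [400, 1.2345, 1.5, 0.52], header=header)
append_row(path, [800, 0.9, 1.1, 0.41], header=header)
```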

In [None]:
def train_model(
    num_epochs,
    batch_size = 4,
    save_every = 500,
    grad_accum_factor = 8,
    vocab_dict = token_to_index,
    device = device):
    
    # Clear unused variables from memory
    gc.collect()

    # Create lists to store model performance information.
    batch_num = []
    loss_train = []
    loss_val = []
    error_val = []

    # Initialize the training and validation datasets
    train_dataset = FPHRdataset(
        path_to_image_folder = path_to_image_folder,
        path_to_label_folder = path_to_label_folder,
        max_width = 2300,
        max_height = 1900,
        augmentation_likelihood = 0.5,
        train = True,
        vocab_dict = vocab_dict
    )

    val_dataset = FPHRdataset(
        path_to_image_folder = path_to_image_folder,
        path_to_label_folder = path_to_label_folder,
        max_width = 2300,
        max_height = 1900,
        augmentation_likelihood = 0.5,
        train = False,
        vocab_dict = vocab_dict
    )

    # ==============================
    # Initialize data loaders
    # ==============================
    # Shuffle is set to False to save memory. See: https://medium.com/@raghadalghonaim/memory-leakage-with-pytorch-23f15203faa4
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    
    # ==============================
    # Initialize model and optimizer
    # ==============================

    # Create the model and place it on the available device (CPU or GPU)
    model = HTRmodel(
        max_seq_len = 3500,
        d_model = 260,
        num_layers = 6,
        num_heads = 4,
        dim_feedforward = 1024,
        encoder_name = 'resnet18',
        drop_enc = 0.2,
        drop_dec = 0.2,
        activation_dec = 'gelu',
        label_smoothing = 0,
        vocab_len = len(vocab_dict),
        vocab_dict = vocab_dict
    ).to(device)

    print("Model loaded. Total number of parameters: "
        f"{sum([layer.numel() for layer in model.parameters()]):,d}.")

    # Initialize the optimizer.
    # Hyperparameters come from Singh et al., page 9.
    optim = torch.optim.Adam(
        model.parameters(), lr=0.0002, betas=(0.9, 0.999)
    )

    # Determine the total number of iterations (batches) that will be processed
    #  during training, for setting the 'total' argument of tqdm's progress bar
    val_iterations = 5      # len(val_loader)
    total_iterations = (
        (num_epochs * len(train_loader))
        # Validation runs every `save_every` batches, counted across all epochs
        + ((num_epochs * len(train_loader)) // save_every) * val_iterations
    )

    iter_count = 0
    iters_per_epoch = len(train_loader)
    progress_bar = tqdm(total=total_iterations)
    progress_bar.set_description('Beginning model training')
    for epoch in range(num_epochs):
        # Process one batch of images at a time in a loop through
        #  all images in the training dataset.
        for imgs, lbls in train_loader:
            # Make sure the Tensors returned by the DataLoader are placed
            #  on the GPU, to avoid running out of CPU memory.
            #  See: https://medium.com/@raghadalghonaim/memory-leakage-with-pytorch-23f15203faa4
            # Note that this step is unnecessary since I add the Tensors to the
            #  correct device in the __getitem__ function of my Dataset class.
            # imgs = imgs.to(device)
            # lbls = lbls.to(device)
            # imgs = imgs.cuda(non_blocking=True)
            # lbls = lbls.cuda(non_blocking=True)
            
            # Clear unused variables from memory
            gc.collect()

            iter_count += 1

            # Set BatchNorm and Dropout layers to training mode
            model.train()

            # Run data through model
            train_logits, train_loss = model.forward_teacher_forcing(imgs=imgs, targets=lbls)
            # Training predictions are not needed, so delete variable to save memory
            del train_logits

            # =========================
            # Update model parameters
            # =========================
            # See: https://stackoverflow.com/questions/62067400/understanding-accumulated-gradients-in-pytorch/62076913#62076913
            # Scale loss by the number of batches to accumulate
            loss = train_loss / grad_accum_factor
            # Backpropagate loss
            loss.backward()
            if iter_count % grad_accum_factor == 0:
                # Update model parameters using the accumulated gradients
                optim.step()
                # Reset gradients if accumulation factor is reached
                optim.zero_grad()

            # =========================
            # Run validation
            # =========================
            if iter_count % save_every == 0:
                with torch.no_grad():
                    # Set model to inference mode so BatchNorm
                    # and Dropout layers perform consistently.
                    model.eval()
                    # Capture all losses and errors to average at the end.
                    val_losses = []
                    val_errors = []
                    # Loop through validation set.
                    for i, (val_imgs, val_lbls) in enumerate(val_loader):
                        # val_imgs = val_imgs.to(device)
                        # val_lbls = val_lbls.to(device)
                        val_logits, val_predictions, val_loss = model.forward(imgs=val_imgs, targets=val_lbls)
                        val_error = model.calculate_char_error(predictions=val_predictions, targets=val_lbls)
                        val_losses.append(val_loss.item())
                        val_errors.append(float(val_error))
                        progress_bar.set_description(
                            f"Validation loop ({i+1}/{val_iterations}). "
                            f"Loss: {val_losses[-1]:,.2f}, "
                            f"Error: {val_errors[-1]:.1%} | ")
                        progress_bar.update(n=1)
                        progress_bar.refresh()
                        # Stop after exactly `val_iterations` batches
                        if i + 1 >= val_iterations:
                            break
                # Average the loss and error values
                val_loss = np.mean(val_losses)
                val_error = np.mean(val_errors)

                # =========================
                # Log model metrics
                # =========================
                batch_num.append(iter_count)
                loss_train.append(train_loss.detach().item())
                loss_val.append(val_loss)
                error_val.append(val_error)
                with open('./model_performance.csv', mode='a', encoding='UTF-8') as csv_file:
                    csv_file.write(','.join(
                        map(str, list(zip(batch_num, loss_train, loss_val, error_val))[-1]))
                        + '\n'
                    )
                progress_bar.set_description(
                    f"Resuming training. "
                    f"Epoch: {epoch+1}/{num_epochs}, "
                    f"Avg. val. loss: {loss_val[-1]:,.2f}, "
                    f"Avg. val. error: {error_val[-1]:.1%} | ")

                # =========================
                # Save checkpoint
                # =========================
                checkpoint_dict = {
                    'model_state': model.state_dict(),
                    'optimizer_state': optim.state_dict(),
                    'epoch_num': epoch + 1,
                    'loss': loss_val[-1]
                }
                checkpoint_filepath = f"./checkpoint_{str(iter_count).zfill(5)}.pt"
                torch.save(checkpoint_dict, checkpoint_filepath)

                # Save model accuracy
                # model_info = f"Epoch: {str(epoch+1).zfill(2)}, loss: {loss_val[-1]:,.3f}, character error rate: {error_val[-1]:.2%}"
                with open('./checkpoint_comparisons.csv', mode='a', encoding='UTF-8') as chkpt_file:
                    chkpt_file.write(','.join(
                        map(str, [iter_count, loss_val[-1], error_val[-1]]))
                        + '\n'
                    )
            
            # =========================
            # Update progress bar
            # =========================
            progress_bar.update(n=1)
    
    return batch_num, loss_train, loss_val, error_val
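Two details of the loop above are easy to get wrong: the progress-bar total and the number of validation batches. The sketch below, with a hypothetical helper name, computes the total using integer division (validation happens whenever the running `iter_count` is a multiple of `save_every`, across epochs) and uses `itertools.islice` to take exactly `val_iterations` batches from a loader instead of checking the index by hand. A plain list stands in for `val_loader`.

```python
from itertools import islice


def count_progress_steps(num_epochs, batches_per_epoch, save_every, val_iterations):
    """Total progress-bar updates: one per training batch, plus
    `val_iterations` for each validation run."""
    total_batches = num_epochs * batches_per_epoch
    # Validation runs whenever the running batch count is a multiple of
    # save_every, so integer division counts the validation runs
    num_validation_runs = total_batches // save_every
    return total_batches + num_validation_runs * val_iterations


# A list stands in for val_loader; islice works on any iterable, including a DataLoader
fake_val_loader = [f"batch {i}" for i in range(20)]
val_iterations = 5
for i, batch in enumerate(islice(fake_val_loader, val_iterations)):
    print(f"Validation loop ({i + 1}/{val_iterations}): {batch}")

print(count_progress_steps(num_epochs=1, batches_per_epoch=8000, save_every=400, val_iterations=5))
# 8100
```

If the validation set can have fewer batches than `val_iterations`, `islice` simply stops early, so passing `min(val_iterations, len(val_loader))` to the counting helper keeps the progress bar accurate.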


In [None]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   28160 B  |   28160 B  |   28160 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |   28160 B  |   28160 B  |   28160 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |   28160 B  |   28160 B  |   28160 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |   28160 B  |   28160 B  |   28160 B  |       0 B  |
|---------------------------------------------------------------

In [None]:
# Make sure to run the code in the "Metrics and utility functions" section to
# initialize the vocabulary dictionaries:
#   token_to_index is a dictionary with vocabulary as keys and indexes as values.
#   index_to_token is the inverse: a dictionary with indexes as keys as vocab as values.

# Set the number of epochs to use for training.
# Note that a model checkpoint is saved every `save_every`
# batches (along with the logged metrics), not once per epoch.
num_epochs = 1

# Set frequency for logging performance metrics,
# in number of batches processed.
save_every = 400

# Set gradient accumulation factor. 1 means "reset gradient every batch" (normal),
# and any number greater than 1 means to accumulate gradients for that number of
# batches. The paper authors used a gradient accumulation factor of 2 (see Singh et al., pg. 9)
grad_accum_factor = 8

# Train model and collect performance info
batch_num, loss_train, loss_val, error_val = train_model(
    num_epochs = num_epochs,
    batch_size = 1,
    save_every = save_every,
    grad_accum_factor = grad_accum_factor)
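With `batch_size = 1` and `grad_accum_factor = 8`, each optimizer step uses gradients from 8 images. The check below, using a tiny `nn.Linear` stand-in and random data (both hypothetical), shows why the loss is divided by `grad_accum_factor`: for a mean-reduced loss, eight scaled backward passes add up to the same gradient as one batch of 8. Two caveats for the real model: BatchNorm layers in the ResNet encoder still see only one image per forward pass, and if the loss is averaged over tokens, pages with different text lengths are weighted differently than in a true batch of 8.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)      # hypothetical stand-in with no BatchNorm or Dropout
inputs = torch.randn(8, 3)
targets = torch.randn(8, 1)
loss_fn = nn.MSELoss()       # mean reduction over the batch
grad_accum_factor = 8

# (1) One ordinary backward pass on a batch of 8
model.zero_grad()
loss_fn(model(inputs), targets).backward()
full_batch_grad = model.weight.grad.clone()

# (2) Eight micro-batches of size 1, each loss scaled by 1 / grad_accum_factor
model.zero_grad()
for x, y in zip(inputs.split(1), targets.split(1)):
    loss = loss_fn(model(x), y) / grad_accum_factor
    loss.backward()          # gradients add up in .grad across backward calls
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))  # True
```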

# Bug: CPU RAM increases during training, leading to OOM error

The problem occurs during the training iterations, not during the validation loop.

Another thing to look into is the validation loop -- it takes _forever_! This could be because each token is passed through the model one at a time, rather than in parallel. Is there any way around that? Model inference follows that pattern, so any performance improvement would greatly help in applying the model to the actual pictures of my journal.
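One thing worth checking, assuming the model's inference `forward` decodes greedily for up to `max_seq_len` (3,500) steps: does the loop stop once every page in the batch has produced the end-of-sequence token? The sketch below shows greedy decoding with that early exit. `next_token_logits` is a hypothetical callable standing in for one decoder pass over the tokens generated so far, and device handling is omitted. The sketch does not cache earlier positions, so each step still reprocesses the whole prefix; caching would be a further optimization.

```python
import torch


@torch.no_grad()
def greedy_decode(next_token_logits, batch_size, sos_idx, eos_idx, max_len):
    """Greedy autoregressive decoding that stops early once every sequence has ended.

    next_token_logits: hypothetical callable mapping the tokens generated so far,
        shape (batch_size, current_len), to logits for the next position,
        shape (batch_size, vocab_size).
    Returns a LongTensor of shape (batch_size, length), where length <= max_len.
    """
    tokens = torch.full((batch_size, 1), sos_idx, dtype=torch.long)
    finished = torch.zeros(batch_size, dtype=torch.bool)
    for _ in range(max_len - 1):
        next_tokens = next_token_logits(tokens).argmax(dim=-1)
        # Sequences that already ended keep emitting EOS, which acts as padding
        next_tokens = next_tokens.masked_fill(finished, eos_idx)
        tokens = torch.cat([tokens, next_tokens.unsqueeze(1)], dim=1)
        finished |= next_tokens == eos_idx
        if finished.all():
            break  # every page in the batch is done; skip the remaining steps
    return tokens


# Usage with a fake decoder that emits token 2 three times, then EOS
SOS, EOS, VOCAB_SIZE = 0, 1, 5
decoder_calls = []


def fake_next_token_logits(tokens):
    decoder_calls.append(tokens.shape[1])
    logits = torch.zeros(tokens.shape[0], VOCAB_SIZE)
    logits[:, 2 if tokens.shape[1] <= 3 else EOS] = 1.0
    return logits


output = greedy_decode(fake_next_token_logits, batch_size=2, sos_idx=SOS, eos_idx=EOS, max_len=3500)
print(output.tolist())     # [[0, 2, 2, 2, 1], [0, 2, 2, 2, 1]]
print(len(decoder_calls))  # 4 decoder passes instead of 3,499
```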

**Next steps**
* (Not sure if this would help) ... Create separate functions for the training and validation loops, so any variables used in those loops can be released as soon as an individual iteration completes.

* Since the validation loop doesn't suffer from the problem, perhaps the problem is related to the computation graph? (val loop uses `with torch.no_grad()`).

* Is the problem in the Dataset? That is, does loading the tensors from images and text create variable references that are not released?

* Is there anywhere that I have a list of PyTorch tensor objects? If so, am I `.detach()`-ing the tensors before storing them?

* If I explicitly `del` each variable in the training loop after I finish using it, will that solve the problem?

* Are the tensors from the Dataset bound to the computation graph? (that is, `requires_grad_(requires_grad=True)`?)

* Ensure that the computation graph exists solely on the GPU, and is not kept on the CPU. See this [PyTorch discussion on computation graphs and transferring Tensors](https://discuss.pytorch.org/t/link-between-require-grad-and-moving-tensors-between-cpu-and-gpu-memory/97504).

* This could be related to computation graph calculation, perhaps with the Tensors coming from the datasets, or the loss or logits tensors from the forward pass. See this [PyTorch forum on not detaching tensors from the computation graph](https://discuss.pytorch.org/t/moving-from-gpu-to-cpu-not-freeing-gpuram/85313/5).

* PyTorch automatically allocates around 2GB of CPU RAM to facilitate movement to GPU. See this [PyTorch forum on freeing CPU RAM after moving tensors to the GPU](https://discuss.pytorch.org/t/how-to-free-cpu-ram-after-module-to-cuda-device/20381).

* When PyTorch first places a Tensor on the GPU, it creates a CUDA context that takes up a fixed chunk of GPU memory (somewhere around 900MB here; the exact amount depends on the GPU and CUDA version). See this [PyTorch forum discussion on clearing GPU memory after transferring back to CPU](https://discuss.pytorch.org/t/model-to-cpu-does-not-release-gpu-memory-allocated-by-registered-buffer/126102).
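A small experiment for the questions above about lists of Tensors and `requires_grad`, using a random weight matrix as a stand-in for model parameters (hypothetical). Appending a loss Tensor keeps its `grad_fn`, and with it the tensors autograd saved for the backward pass, alive for as long as the list exists; appending `loss.item()` keeps only a Python float. The `torch.no_grad()` block shows why the validation loop does not build a graph at all.

```python
import torch

torch.manual_seed(0)
w = torch.randn(100, 100, requires_grad=True)  # stand-in for model parameters

stored_tensors = []  # anti-pattern: keeps autograd history alive
stored_floats = []   # what the training loop should store

for _ in range(3):
    x = torch.randn(100, 1)            # stand-in for a batch from the DataLoader
    loss = (w @ x).pow(2).mean()
    stored_tensors.append(loss)        # loss.grad_fn and its saved tensors stay referenced
    stored_floats.append(loss.item())  # a plain Python float with no link to the graph

print(x.requires_grad)                 # False: input data should not require gradients

with torch.no_grad():                  # as in the validation loop: no graph is recorded
    val_loss = (w @ torch.randn(100, 1)).pow(2).mean()
print(val_loss.requires_grad)          # False
```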



---

There's a memory leak in the training loop, before the validation loop begins. Somewhere in that code, something accumulates and takes up increasing amounts of system RAM. The problem does not occur in the validation loop.

Potential sources: the loss function might be hosted on the CPU, rather than the GPU. Check the `.to(device)` methods on any returned Tensor in preceding functions and classes.

I don't remember this happening before, so check the areas I modified, including the decoder class and the overall model. Are there any places I store a variable that isn't used later? For example, am I storing any Tensors inside Python lists?

I could also check that I store plain numbers rather than Tensors in Python lists. `.item()` on its own already returns a Python number with no link to the computation graph; `.detach()` is only needed when I want to keep the Tensor itself.

Also, see `/var/colab/app.log` to view the system logs, including where it saves cookies and the server where the Jupyter notebook is running (so I could potentially access it from VS Code). You can find logs under Runtime > View runtime logs.
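To find out which step makes system RAM grow, a minimal approach is to log the process's memory every few batches. The sketch below assumes a Unix runtime such as Colab and uses the standard-library `resource` module, which reports *peak* resident memory; a peak that keeps climbing batch after batch points to a leak. (If `psutil` is installed, `psutil.Process().memory_info().rss` gives current rather than peak usage.) The deliberately leaky list is only there to demonstrate the log.

```python
import resource
import sys


def peak_rss_mb():
    """Peak resident set size (RSS) of this process, in MB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is in kilobytes on Linux (e.g. Colab) but in bytes on macOS
    divisor = 1024 ** 2 if sys.platform == 'darwin' else 1024
    return peak / divisor


# Demonstration: a loop that deliberately keeps 20 MB per "batch"
log_every = 1  # in the training loop, something like 50 would be more practical
memory_log = []
leaky_storage = []
for iter_count in range(1, 6):
    leaky_storage.append(bytearray(20 * 1024 * 1024))
    if iter_count % log_every == 0:
        memory_log.append((iter_count, peak_rss_mb()))

for iter_count, mb in memory_log:
    print(f"Batch {iter_count}: peak RSS {mb:,.0f} MB")
```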

**Helpful resources**
* [Super helpful article on PyTorch performance improvements](https://towardsdatascience.com/7-tips-for-squeezing-maximum-performance-from-pytorch-ca4a40951259), TowardsDataScience, from the creator of PyTorch Lightning
  * Big idea: create tensors directly on the device they will be hosted on, rather than creating them on the CPU and moving them with `.to(device)`.
* [CUDA memory allocation is around 400MB when first used, even with a small tensor](https://stackoverflow.com/questions/64068771/pytorch-what-happens-to-memory-when-moving-tensor-to-gpu), Stack Overflow answer
* [Backpropagation on partial parameter groups, between CPU and GPU](https://discuss.pytorch.org/t/keeping-only-part-of-model-parameters-on-gpu/71308), PyTorch forums
* [Really clear article on speeding up PyTorch](https://betterprogramming.pub/how-to-make-your-pytorch-code-run-faster-93079f3c1f7b), Better Programming (member-only article)
* [Ensuring that function calls don't take up extra memory](https://pythonspeed.com/articles/function-calls-prevent-garbage-collection/), PythonSpeed article
* [Diagnosing Data Starvation: where CPU RAM limits GPU usage](https://www.willprice.dev/2021/03/27/debugging-pytorch-performance-bottlenecks.html), Will Price blog
* [Using Memory Profiler](https://towardsdatascience.com/did-you-know-how-to-identify-variables-in-your-python-code-which-are-high-on-memory-consumption-787bef949dbd), Medium article
* [Memory Leakage with PyTorch: 4 Tips](https://medium.com/@raghadalghonaim/memory-leakage-with-pytorch-23f15203faa4), Medium article
* [Performance Tuning Guide](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#general-optimizations), PyTorch recipes
* [PyTorch FAQ: CUDA out-of-memory errors](https://pytorch.org/docs/stable/notes/faq.html#my-model-reports-cuda-runtime-error-2-out-of-memory)
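A short illustration of the "create tensors on the device" tip from the first article above; the shapes are arbitrary, and the CPU fallback lets it run without a GPU.

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Allocates on the CPU first, then copies the data to the device
mask_copied = torch.zeros(4, 3500).to(device)

# Allocates directly on the device, with no temporary CPU tensor
mask_direct = torch.zeros(4, 3500, device=device)

# *_like constructors inherit the device and dtype of an existing tensor
mask_ones = torch.ones_like(mask_direct)
```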



In [None]:
# Install memory profiler
%pip install memory_profiler
%load_ext memory_profiler



In [None]:
from memory_test import train_model_all_in_one

In [None]:
vocab_dict = token_to_index

In [None]:
%mprun -f train_model_all_in_one train_model_all_in_one(num_epochs=1, vocab_dict=token_to_index, device=device, FPHRdataset=FPHRdataset, HTRmodel=HTRmodel, batch_size = 1, save_every = 400, grad_accum_factor = 8)


sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/local/lib/python3.7/dist-packages/memory_profiler.py", line 845, in enable
    sys.settrace(self.trace_memory_usage)



Model loaded. Total number of parameters: 17,833,280.


  0%|          | 0/8100.0 [00:00<?, ?it/s]


sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/local/lib/python3.7/dist-packages/memory_profiler.py", line 848, in disable
    sys.settrace(self._original_trace_function)






In [None]:
# =======================================
# OPTIONAL: Clear GPU memory.
# Run this cell if the training
#   loop encounters an error and
#   memory is not released automatically.
# Afterwards, you'll need to re-run the
#   notebook starting from the cell after
#   library imports.
# =======================================
CLEAR_GPU_MEMORY = False

if CLEAR_GPU_MEMORY:
    print('Clearing memory...')
    print('Afterwards, remember to re-run this notebook from the beginning.')
    # Remove all variables that reference PyTorch objects
    del FPHRdataset
    del ResNetEncoder
    del TransformerDecoder
    del HTRmodel
    del PositionalEncoding1D
    del PositionalEncoding2D
    del train_model
    del CharacterErrorRate
    del text_to_tensor
    del tensor_to_text
    del token_to_index
    del index_to_token
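    # Free objects that are no longer referenced, then release
    # PyTorch's cached (but unused) GPU memory back to the driver
    gc.collect()
    torch.cuda.empty_cache()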
    # Examine GPU memory usage
    print(torch.cuda.memory_summary())

Clearing memory...
Afterwards, remember to re-run this notebook from the beginning.
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  400214 KB |    8747 MB |    4737 GB |    4736 GB |
|       from large pool |  315294 KB |    8649 MB |    4729 GB |    4728 GB |
|       from small pool |   84919 KB |     100 MB |       7 GB |       7 GB |
|---------------------------------------------------------------------------|
| Active memory         |  400214 KB |    8747 MB |    4737 GB |    4736 GB |
|       from large pool |  315294 KB |    8649 MB |    4729 GB |    4728 GB |
|       from small pool |   84919 KB |     100 MB |       

## Plots of model performance

In [None]:
plt.title('Model loss')
plt.plot(batch_num, loss_train, 'tan', label='Training loss')
plt.plot(batch_num, loss_val, 'slateblue', label='Validation loss')
plt.legend()
plt.show()

In [None]:
plt.title('Model error during training')
plt.plot(batch_num, error_val, 'darkred', label='Validation error %')
plt.legend()
plt.show()

## Inference: Testing model on journal pages

In [None]:
# Use model.eval() to set the model in evaluation (inference)
#   mode for the BatchNorm and Dropout modules.
# Also use the "with torch.no_grad():" context manager
#   to reduce computational load by not calculating gradient information.