# Utils 

In [None]:
!pip install editdistance torchmetrics pytorch_lightning
!pip install  torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html


In [None]:
import pickle
import math
from random import random
import xml.etree.ElementTree as ET
import numpy as np
import albumentations as A
import cv2 as cv
from dataclasses import dataclass, field
from functools import partial
from random import randint
import html
import random
from pathlib import Path
from typing import Union, Tuple, Dict, Sequence, Optional, List, Any, Callable, Optional
import pandas as pd
from torch import Tensor, nn
from torch.utils.data import Dataset
from PIL import Image
from torchmetrics import Metric
import torch
from torchvision import models
from torch.utils.data import Dataset
import editdistance
import wandb
from torch.utils.data.dataloader import DataLoader
from torch.optim import Optimizer
import torch.optim as optim
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import concurrent
from pytorch_lightning.callbacks import TQDMProgressBar


In [None]:
class LabelParser:
    """Bidirectional mapping between class labels (characters or special tokens such
    as "<PAD>") and integer indices.

    The index of a class is its position in the sequence passed to `fit`.
    """

    def __init__(self):
        self.classes = None
        self.vocab_size = None
        self.class_to_idx = None
        self.idx_to_class = None

    def fit(self, classes: Sequence[str]):
        """Build the mapping from `classes` (index = position) and return `self`.

        The input is materialized into a list once, so any iterable (including a
        generator) is accepted.
        """
        self.classes = list(classes)
        self.vocab_size = len(self.classes)
        self.idx_to_class = dict(enumerate(self.classes))
        self.class_to_idx = {cls: i for i, cls in self.idx_to_class.items()}

        return self

    def addClasses(self, classes: List[str]):
        """Add the classes that are not in the vocabulary yet, and return `self`.

        Existing classes keep their indices, so sequences encoded before this call
        still decode correctly. New classes are appended in sorted order. On a
        parser that was never fitted, this fits it on the sorted unique `classes`.
        """
        if self.classes is None:
            return self.fit(sorted(set(classes)))
        known = set(self.classes)
        new_classes = sorted(set(classes) - known)
        return self.fit(self.classes + new_classes)

    def encode_labels(self, sequence: Sequence[str]):
        """Map each class in `sequence` to its index (KeyError for unknown classes)."""
        self._check_fitted()
        return [self.class_to_idx[c] for c in sequence]

    def decode_labels(self, sequence: Sequence[int]):
        """Map each index in `sequence` back to its class (KeyError if out of range)."""
        self._check_fitted()
        return [self.idx_to_class[c] for c in sequence]

    def _check_fitted(self):
        """Raise ValueError if `fit` has not been called yet."""
        if self.classes is None:
            raise ValueError("LabelParser class was not fitted yet")


In [None]:
def pickle_load(file) -> Any:
    """Unpickle and return the object stored in `file`.

    Warning: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(file, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj

def pickle_save(obj, file):
    """Serialize `obj` to `file` with pickle, overwriting any existing file."""
    with open(file, mode="wb") as handle:
        pickle.dump(obj, handle)

def read_xml(file: Union[Path, str]) -> ET.Element:
    """Parse the XML document in `file` and return its root element."""
    return ET.parse(file).getroot()

def find_child_by_tag(element: ET.Element, tag: str, value: str) -> Union[ET.Element, None]:
    """Return the first child whose attribute `tag` equals `value`, or None.

    Despite its name, `tag` is an XML *attribute* name (e.g. "id"). `element` may
    be any iterable of elements, such as the list returned by `Element.findall`.
    """
    matches = (child for child in element if child.get(tag) == value)
    return next(matches, None)

def set_seed(seed: int):
    """Seed the RNGs of Python's `random`, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def dpi_adjusting(img: np.ndarray, scale: float, **kwargs) -> np.ndarray:
    """Resize `img` by `scale`, rounding each new side length up.

    Only the first two axes (height, width) are used, so multi-channel images
    work too. Extra keyword arguments are accepted and ignored.
    """
    h, w = img.shape[:2]
    # cv.resize expects the destination size as (width, height).
    target_size = (math.ceil(w * scale), math.ceil(h * scale))
    return cv.resize(img, target_size)

class LitProgressBar(TQDMProgressBar):
    """TQDM progress bar that hides the `v_num` entry and every metric whose name
    starts with "grad"."""

    def get_metrics(self, trainer, model):
        metrics = super().get_metrics(trainer, model)
        # Collect the names first: deleting while iterating the dict is not allowed.
        hidden = [name for name in metrics if name.startswith("grad")]
        for name in hidden:
            del metrics[name]
        metrics.pop("v_num", None)  # don't show the version number
        return metrics
    
def decode_prediction_and_target(
    pred: Tensor, target: Tensor, label_encoder: "LabelParser", eos_tkn_idx: int
) -> Tuple[str, str]:
    """Decode a predicted and a target index sequence into strings.

    Args:
        pred: 1-D tensor of predicted indices. Its first element is the <SOS>
            token added by default and is skipped.
        target: 1-D tensor of target indices, decoded from its first element.
        label_encoder: encoder whose `decode_labels` maps indices to characters.
        eos_tkn_idx: index of the <EOS> token. Each sequence is cut at its first
            <EOS> (exclusive); if there is none, the whole sequence is decoded.

    Returns:
        Tuple (prediction string, target string).
    """

    def _until_eos(seq):
        # Truncate at the first <EOS>; keep everything when there is none.
        try:
            return seq[: seq.index(eos_tkn_idx)]
        except ValueError:
            return seq

    # Drop <SOS> *before* searching for <EOS>, so the <EOS> position refers to the
    # same list that is truncated (otherwise the <EOS> token itself is kept).
    p = _until_eos(pred.tolist()[1:])
    t = _until_eos(target.tolist())
    pred_str = "".join(label_encoder.decode_labels(p))
    target_str = "".join(label_encoder.decode_labels(t))
    return pred_str, target_str

def matplotlib_imshow(
    img: torch.Tensor, mean: float = 0.5, std: float = 0.5, one_channel=True
):
    """Undo the (mean, std) normalization of a CPU tensor and display it.

    With `one_channel=True`, a 3-D tensor is first averaged over dim 0, and the
    result is shown with the "Greys" colormap. Otherwise the tensor is transposed
    from (C, H, W) to (H, W, C) before display.

    NOTE(review): this uses a module-level `plt` (matplotlib.pyplot) that is not
    imported in the visible part of this file; confirm it is imported before use.
    """
    assert img.device.type == "cpu"
    if one_channel and img.ndim == 3:
        img = img.mean(dim=0)
    npimg = (img * std + mean).numpy()  # unnormalize
    if not one_channel:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        return
    plt.imshow(npimg, cmap="Greys")


## Image transformations

In [None]:


class SafeRandomScale(A.RandomScale):
    """`A.RandomScale` that returns the image unchanged when the sampled scale
    would shrink its height or width to zero pixels."""

    def apply(self, img, scale=0, interpolation=cv.INTER_LINEAR, **params):
        h, w = img.shape[:2]
        collapses = int(h * scale) <= 0 or int(w * scale) <= 0
        if collapses:
            return img
        return super().apply(img, scale, interpolation, **params)

def adjust_dpi(img: np.ndarray, scale: float, **kwargs):
    """Resize `img` by `scale`, rounding each new side length up.

    Only the first two axes (height, width) are used, so multi-channel images are
    supported as well as 2-D grayscale ones. Extra keyword arguments (e.g. those
    passed by `A.Lambda`) are ignored.
    """
    # shape[:2] instead of unpacking `img.shape` directly: the latter raises on
    # images with a channel axis.
    height, width = img.shape[:2]
    new_height, new_width = math.ceil(height * scale), math.ceil(width * scale)
    # cv.resize expects the destination size as (width, height).
    return cv.resize(img, (new_width, new_height))

def randomly_displace_and_pad(
    img: np.ndarray,
    padded_size: Tuple[int, int],
    crop_if_necessary: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    Randomly displace an image within a frame, and pad zeros around the image.

    Args:
        img (np.ndarray): image to process; axes beyond the first two (e.g. a
            channel axis) are carried over unchanged
        padded_size (Tuple[int, int]): (height, width) tuple indicating the size of the frame
        crop_if_necessary (bool): whether to crop the image if its size exceeds that
            of the frame

    Returns:
        np.ndarray of shape `padded_size + img.shape[2:]` with the image's dtype.

    Raises:
        AssertionError: if the image exceeds the frame and `crop_if_necessary` is False.
    """
    frame_h, frame_w = padded_size
    img_h, img_w = img.shape[:2]
    if frame_h < img_h or frame_w < img_w:
        if not crop_if_necessary:
            raise AssertionError(
                f"Frame is smaller than the image: ({frame_h}, {frame_w}) vs. ({img_h},"
                f" {img_w})"
            )
        print(
            "WARNING (`randomly_displace_and_pad`): cropping input image before "
            "padding because it exceeds the size of the frame."
        )
        img_h, img_w = min(img_h, frame_h), min(img_w, frame_w)
        img = img[:img_h, :img_w]

    res = np.zeros((frame_h, frame_w) + img.shape[2:], dtype=img.dtype)

    # Random top-left corner such that the image still fits inside the frame.
    pad_top = randint(0, frame_h - img_h)
    pad_left = randint(0, frame_w - img_w)

    res[pad_top : pad_top + img_h, pad_left : pad_left + img_w] = img
    return res

@dataclass
class ImageTransforms:
    """Albumentations pipelines for training (`train_trnsf`) and evaluation
    (`test_trnsf`), built in `__post_init__`.

    Attributes:
        max_img_size: expected maximum (height, width) of input images, before the
            `scale` resize.
        normalize_params: (mean, std) forwarded to `A.Normalize`.
        scale: base resize factor applied to every image.
        random_scale_limit: extra random scaling range used during training.
        random_rotate_limit: rotation limit forwarded to `A.SafeRotate`.
    """

    max_img_size: Tuple[int, int]  # (h, w)
    normalize_params: Tuple[float, float]  # (mean, std)
    scale: float = 0.5
    random_scale_limit: float = 0.1
    random_rotate_limit: int = 10

    train_trnsf: A.Compose = field(init=False)
    test_trnsf: A.Compose = field(init=False)

    def __post_init__(self):
        max_h, max_w = self.max_img_size

        # Largest resize factor reachable in training (base scale plus the maximum
        # random scaling); the training frame is sized for it.
        upper_scale = self.scale + self.scale * self.random_scale_limit
        frame_size = (
            math.ceil(upper_scale * max_h),
            math.ceil(upper_scale * max_w),
        )

        def dpi_step():
            # A fresh transform instance per pipeline.
            return A.Lambda(partial(adjust_dpi, scale=self.scale))

        train_steps = [
            dpi_step(),
            SafeRandomScale(scale_limit=self.random_scale_limit, p=0.5),
            A.SafeRotate(
                limit=self.random_rotate_limit,
                border_mode=cv.BORDER_CONSTANT,
                value=0,
            ),
            A.RandomBrightnessContrast(),
            A.Perspective(scale=(0.01, 0.05)),
            A.GaussNoise(),
            A.Normalize(*self.normalize_params),
            A.Lambda(
                image=partial(
                    randomly_displace_and_pad,
                    padded_size=frame_size,
                    crop_if_necessary=False,
                )
            ),
        ]
        self.train_trnsf = A.Compose(train_steps)

        test_steps = [
            dpi_step(),
            A.Normalize(*self.normalize_params),
            A.PadIfNeeded(max_h, max_w, border_mode=cv.BORDER_CONSTANT, value=0),
        ]
        self.test_trnsf = A.Compose(test_steps)


## IAM Dataset and Synthetic Dataset

In [None]:

class IAMDataset(Dataset):
    """PyTorch dataset for the IAM handwriting database.

    Depending on `parse_method`, a sample is a grayscale image of a full form, a
    text line or a single word, paired with its transcription encoded as an
    `np.ndarray` of label indices.

    Paths read below `root`:
    - `formsA-D`, `formsE-H`, `formsI-Z`: form images
    - `lines/*/<doc_id>/`: line images
    - `words/*/<doc_id>/`: word images
    - `xml/<doc_id>.xml`: ground-truth transcriptions
    """

    MEAN = 0.8275
    STD = 0.2314
    MAX_FORM_HEIGHT = 3542
    MAX_FORM_WIDTH = 2479

    MAX_SEQ_LENS = {
        "word": 55,
        "line": 90,
        "form": 700,
    }  # based on the maximum seq lengths found in the dataset

    _pad_token = "<PAD>"
    _sos_token = "<SOS>"
    _eos_token = "<EOS>"

    root: Path
    data: pd.DataFrame
    label_enc: LabelParser
    parse_method: str
    only_lowercase: bool
    transforms: Optional[A.Compose]
    id_to_idx: Dict[str, int]
    _split: str
    _return_writer_id: Optional[bool]

    def __init__(
        self,
        root: Union[Path, str],
        parse_method: str,
        split: str,
        return_writer_id: bool = False,
        only_lowercase: bool = False,
        label_enc: Optional[LabelParser] = None,
    ):
        """
        Args:
            root (Union[Path, str]): root directory of the IAM dataset
            parse_method (str): sample granularity: "form", "line" or "word"
            split (str): "train" or "test"; selects the image transforms
            return_writer_id (bool): if True, `__getitem__` also returns the writer
                id. Only the word-level data has a `writer_id` column.
            only_lowercase (bool): lowercase all targets before encoding them
            label_enc (Optional[LabelParser]): label encoder to reuse. If None, one
                is fitted on the targets, with <PAD>, <SOS> and <EOS> at indices
                0, 1 and 2, followed by the sorted characters.
        """
        super().__init__()
        _parse_methods = ["form", "line", "word"]
        err_message = (
            f"{parse_method} is not a possible parsing method: {_parse_methods}"
        )
        assert parse_method in _parse_methods, err_message

        _splits = ["train", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message

        self._split = split
        self._return_writer_id = return_writer_id
        self.only_lowercase = only_lowercase
        self.root = Path(root)
        self.label_enc = label_enc
        self.parse_method = parse_method

        # Process the data (unless `data` was already set on the instance).
        if not hasattr(self, "data"):
            if self.parse_method == "form":
                self.data = self._get_forms()
            elif self.parse_method == "word":
                self.data = self._get_words(skip_bad_segmentation=True)
            elif self.parse_method == "line":
                self.data = self._get_lines()

        # Create the label encoder.
        if self.label_enc is None:
            vocab = [self._pad_token, self._sos_token, self._eos_token]
            s = "".join(self.data["target"].tolist())
            if self.only_lowercase:
                s = s.lower()
            vocab += sorted(set(s))
            self.label_enc = LabelParser().fit(vocab)
        if "target_enc" not in self.data.columns:
            self.data.insert(
                2,
                "target_enc",
                self.data["target"].apply(
                    lambda s: np.array(
                        self.label_enc.encode_labels(
                            list(s.lower() if self.only_lowercase else s)
                        )
                    )
                ),
            )

        self.transforms = self._get_transforms(split)
        # Image id (file stem) -> positional index accepted by `__getitem__`.
        # Iterating the column directly avoids a slow `iloc` lookup per row.
        self.id_to_idx = {
            Path(img_path).stem: i for i, img_path in enumerate(self.data["img_path"])
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return `(img, target_enc)`, or `(img, writer_id, target_enc)` when the
        dataset was created with `return_writer_id=True`."""
        data = self.data.iloc[idx]
        img = cv.imread(data["img_path"], cv.IMREAD_GRAYSCALE)
        # cv.imread returns None instead of raising for unreadable files. Check
        # before cropping so the failure reports this message, not a TypeError.
        assert isinstance(img, np.ndarray), (
            f"Error: image at path {data['img_path']} is not properly loaded. "
            f"Is there something wrong with this image?"
        )
        if all(col in data.keys() for col in ["bb_y_start", "bb_y_end"]):
            # Crop the image vertically. Clamp the start at 0: `_get_forms`
            # subtracts a margin that can make it negative, and a negative slice
            # start would wrap around to the bottom of the image.
            img = img[max(0, data["bb_y_start"]) : data["bb_y_end"], :]
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
        if self._return_writer_id:
            return img, data["writer_id"], data["target_enc"]
        return img, data["target_enc"]

    def get_max_height(self):
        """Largest `bb_y_end - bb_y_start` in the data (form data only)."""
        return (self.data["bb_y_end"] - self.data["bb_y_start"]).max()

    @property
    def vocab(self):
        return self.label_enc.classes

    @staticmethod
    def collate_fn(
        batch: Sequence[Tuple[np.ndarray, np.ndarray]],
        pad_val: int,
        eos_tkn_idx: int,
        dataset_returns_writer_id: bool = False,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        """Collate samples into batched tensors.

        If image sizes differ, every image is zero-padded with `A.PadIfNeeded` to
        the largest height and width in the batch. Every target gets `eos_tkn_idx`
        appended and is then right-padded with `pad_val` to a common length.

        Returns:
            (imgs, targets), or (imgs, targets, writer_ids) when
            `dataset_returns_writer_id` is True.
        """
        if dataset_returns_writer_id:
            imgs, writer_ids, targets = zip(*batch)
        else:
            imgs, targets = zip(*batch)

        img_sizes = [im.shape for im in imgs]
        if len(set(img_sizes)) > 1:
            # Images are of varying sizes, so pad them to the maximum size in the batch.
            max_h = max(size[0] for size in img_sizes)
            max_w = max(size[1] for size in img_sizes)
            pad_fn = A.PadIfNeeded(
                max_h, max_w, border_mode=cv.BORDER_CONSTANT, value=0
            )
            imgs = [pad_fn(image=im)["image"] for im in imgs]
        imgs = np.stack(imgs, axis=0)

        seq_lengths = [t.shape[0] for t in targets]
        # One extra column so even the longest target has room for <EOS>.
        targets_padded = np.full((len(targets), max(seq_lengths) + 1), pad_val)
        for i, t in enumerate(targets):
            targets_padded[i, : seq_lengths[i]] = t
            targets_padded[i, seq_lengths[i]] = eos_tkn_idx

        imgs, targets_padded = torch.tensor(imgs), torch.tensor(targets_padded)
        if dataset_returns_writer_id:
            return imgs, targets_padded, torch.tensor(writer_ids)
        return imgs, targets_padded

    def set_transforms_for_split(self, split: str):
        """Replace the transforms with those of `split` ("train", "val" or "test")."""
        _splits = ["train", "val", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message
        self.transforms = self._get_transforms(split)

    def _get_transforms(self, split: str) -> A.Compose:
        """Build the transforms for `split`; returns None for an unknown split."""
        max_img_w = self.MAX_FORM_WIDTH

        if self.parse_method == "form":
            max_img_h = (self.data["bb_y_end"] - self.data["bb_y_start"]).max()
        else:  # word or line
            max_img_h = self.MAX_FORM_HEIGHT

        transforms = ImageTransforms(
            (max_img_h, max_img_w), (IAMDataset.MEAN, IAMDataset.STD)
        )

        if split == "train":
            return transforms.train_trnsf
        elif split == "test" or split == "val":
            return transforms.test_trnsf

    def statistics(self) -> Dict[str, float]:
        """Compute the mean and std of pixel values over the whole dataset.

        Transforms are disabled during the computation, so the values are in the
        raw pixel range of the images loaded by `cv.imread` (not normalized). The
        std is the square root of the average per-image variance.
        """
        assert len(self) > 0
        tmp = self.transforms
        self.transforms = None
        try:
            mean, var, cnt = 0, 0, 0
            for idx in range(len(self)):
                # The image is always the first element, with or without writer ids.
                img = self[idx][0]
                mean += np.mean(img)
                var += np.var(img)
                cnt += 1
        finally:
            # Restore the transforms even if loading an image fails.
            self.transforms = tmp
        return {"mean": mean / cnt, "std": np.sqrt(var / cnt)}

    def _get_forms(self) -> pd.DataFrame:
        """Read all form images from the IAM dataset.

        Returns:
            pd.DataFrame
                A pandas dataframe containing the image path, image id, target, vertical
                upper bound, vertical lower bound, and target length, sorted by
                target length.
        """
        data = {
            "img_path": [],
            "img_id": [],
            "target": [],
            "bb_y_start": [],
            "bb_y_end": [],
            "target_len": [],
        }
        for form_dir in ["formsA-D", "formsE-H", "formsI-Z"]:
            for img_path in (self.root / form_dir).iterdir():
                doc_id = img_path.stem
                xml_root = read_xml(self.root / "xml" / (doc_id + ".xml"))

                # Based on some empirical evaluation, the 'asy' and 'dsy'
                # attributes of a line xml tag seem to correspond to its upper and
                # lower bound, respectively. We add padding of 150 pixels.
                # xml_root[1] is the element holding the <line> tags.
                bb_y_start = int(xml_root[1][0].get("asy")) - 150
                bb_y_end = int(xml_root[1][-1].get("dsy")) + 150

                target = "\n".join(
                    html.unescape(line.get("text", "")) for line in xml_root.iter("line")
                )

                data["img_path"].append(str(img_path))
                data["img_id"].append(doc_id)
                data["target"].append(target)
                data["bb_y_start"].append(bb_y_start)
                data["bb_y_end"].append(bb_y_end)
                data["target_len"].append(len(target))
        return pd.DataFrame(data).sort_values(
            "target_len"
        )  # by default, sort by target length

    def _get_lines(self, skip_bad_segmentation: bool = False) -> pd.DataFrame:
        """Read all line images from the IAM dataset.

        Args:
            skip_bad_segmentation (bool): skip lines that have the
                segmentation='err' xml attribute
        Returns:
            pd.DataFrame with columns `img_path`, `img_id` and `target` (the
            ground-truth text of each line image).
        """
        data = {"img_path": [], "img_id": [], "target": []}
        root = self.root / "lines"
        for d1 in root.iterdir():
            if d1.is_file():
                # Stray files (not form directories) cannot be iterated.
                continue
            for d2 in d1.iterdir():
                doc_id = d2.name
                xml_root = read_xml(self.root / "xml" / (doc_id + ".xml"))
                for img_path in d2.iterdir():
                    target = self._find_line(
                        xml_root, img_path.stem, skip_bad_segmentation
                    )
                    if target is not None:
                        data["img_path"].append(str(img_path.resolve()))
                        data["img_id"].append(doc_id)
                        data["target"].append(target)
        return pd.DataFrame(data)

    def _get_words(self, skip_bad_segmentation: bool = False) -> pd.DataFrame:
        """Read all word images from the IAM dataset.

        Word images that cv.imread cannot load are skipped. Directories are
        processed concurrently in a thread pool.

        Args:
            skip_bad_segmentation (bool): skip words whose line has the
                segmentation='err' xml attribute
        Returns:
            pd.DataFrame with columns `img_path`, `img_id`, `writer_id` and
            `target` (the ground-truth text of each word image).
        """
        data = {"img_path": [], "img_id": [], "writer_id": [], "target": []}
        root = self.root / "words"
        parallel_inputs = []
        for d1 in root.iterdir():
            if d1.is_file():
                continue
            for d2 in d1.iterdir():
                parallel_inputs.append(d2)

        def process_dir(directory):
            # Collect all valid word images of one document directory.
            directory_results = {"img_path": [], "img_id": [], "writer_id": [], "target": []}
            doc_id = directory.name
            xml_root = read_xml(self.root / "xml" / (doc_id + ".xml"))
            writer_id = int(xml_root.get("writer-id"))
            for img_path in directory.iterdir():
                img = cv.imread(str(img_path.resolve()), cv.IMREAD_GRAYSCALE)

                if isinstance(img, np.ndarray):
                    target = self._find_word(
                        xml_root, img_path.stem, skip_bad_segmentation
                    )
                    if target is not None:
                        directory_results["img_path"].append(str(img_path.resolve()))
                        directory_results["img_id"].append(doc_id)
                        directory_results["writer_id"].append(writer_id)
                        directory_results["target"].append(target)
            return directory_results

        with ThreadPoolExecutor() as executor:
            results = list(
                tqdm(
                    executor.map(process_dir, iter(parallel_inputs)),
                    total=len(parallel_inputs),
                )
            )
        for single_results in results:
            for key in data:
                data[key].extend(single_results[key])

        return pd.DataFrame(data)

    def _find_line(
        self,
        xml_root: ET.Element,
        line_id: str,
        skip_bad_segmentation: bool = False,
    ) -> Union[str, None]:
        """Return the unescaped text of line `line_id`, or None if the line is not
        found (or is skipped because of segmentation='err')."""
        line = find_child_by_tag(xml_root[1].findall("line"), "id", line_id)
        if line is not None and not (
            skip_bad_segmentation and line.get("segmentation") == "err"
        ):
            return html.unescape(line.get("text"))
        return None

    def _find_word(
        self,
        xml_root: ET.Element,
        word_id: str,
        skip_bad_segmentation: bool = False,
    ) -> Union[str, None]:
        """Return the unescaped text of word `word_id`, or None if it is not found
        (or its line is skipped because of segmentation='err')."""
        # The line id is the word id without its last dash-separated field.
        line_id = "-".join(word_id.split("-")[:-1])
        line = find_child_by_tag(xml_root[1].findall("line"), "id", line_id)
        if line is not None and not (
            skip_bad_segmentation and line.get("segmentation") == "err"
        ):
            word = find_child_by_tag(line.findall("word"), "id", word_id)
            if word is not None:
                return html.unescape(word.get("text"))
        return None

class IAMSyntheticDataGenerator(Dataset):
    """
    Data generator that creates synthetic line/form images by stitching together word
    images from the IAM dataset.
    Calling `__getitem__()` samples a newly generated synthetic image every time
    it is called.
    """

    PUNCTUATION = [",", ".", ";", ":", "'", '"', "!", "?"]

    def __init__(
        self,
        iam_root: Union[str, Path],
        label_encoder: Optional[LabelParser] = None,
        transforms: Optional[A.Compose] = None,
        line_width: Tuple[int, int] = (1500, 2000),
        lines_per_form: Tuple[int, int] = (1, 11),
        words_per_line: Tuple[int, int] = (4, 10),
        words_per_sequence: Tuple[int, int] = (7, 13),
        px_between_lines: Tuple[int, int] = (25, 50),
        px_between_words: int = 50,
        px_around_image: Tuple[int, int] = (100, 200),
        sample_form: bool = False,
        only_lowercase: bool = False,
        rng_seed: int = 0,
        max_height: Optional[int] = None,
    ):
        """
        Args:
            iam_root: root directory of the IAM dataset (word images are read from it)
            label_encoder: encoder for the generated targets; defaults to the
                encoder of the underlying IAM word dataset
            transforms: optional albumentations pipeline applied to generated images
            line_width, lines_per_form, words_per_line, words_per_sequence,
                px_between_lines, px_around_image: (low, high) ranges sampled with
                `self.rng.integers(low, high)`
            px_between_words: horizontal gap between words, in pixels
            sample_form: generate full forms instead of single lines
            only_lowercase: lowercase all word targets
            rng_seed: seed of `self.rng`
            max_height: maximum form height; defaults to `IAMDataset.MAX_FORM_HEIGHT`
        """
        super().__init__()
        self.iam_root = iam_root
        self.label_enc = label_encoder
        self.transforms = transforms
        self.line_width = line_width
        self.lines_per_form = lines_per_form
        self.words_per_line = words_per_line
        self.words_per_sequence = words_per_sequence
        self.px_between_lines = px_between_lines
        self.px_between_words = px_between_words
        self.px_around_image = px_around_image
        self.sample_form = sample_form
        self.only_lowercase = only_lowercase
        self.rng_seed = rng_seed
        self.max_height = max_height

        self.iam_words = IAMDataset(
            iam_root,
            "word",
            "test",
            only_lowercase=only_lowercase,
        )
        if self.max_height is None:
            self.max_height = IAMDataset.MAX_FORM_HEIGHT
        if sample_form and "\n" not in self.label_encoder.classes:
            if self.label_enc is None:
                # Extend a copy, not the word dataset's own encoder. That encoder
                # produced the words' `target_enc` and decodes them again in
                # `sample_word_image*`, and `addClasses` refits the encoder it is
                # called on, which may change existing indices.
                self.label_enc = LabelParser().fit(self.iam_words.label_enc.classes)
            # Add the `\n` token to the label encoder (since forms can contain newlines)
            self.label_encoder.addClasses(["\n"])
        self.iam_words.transforms = None
        self.rng = np.random.default_rng(rng_seed)

    def __len__(self):
        # This dataset does not have a finite length since it can generate random
        # images at will, so return 1.
        return 1

    @property
    def label_encoder(self):
        """The explicit label encoder if one is set, otherwise the word dataset's."""
        if self.label_enc is not None:
            return self.label_enc
        return self.iam_words.label_enc

    def __getitem__(self, *args, **kwargs):
        """By calling this method, a newly generated synthetic image is sampled.

        The index is ignored. Returns `(img, target_enc)`.
        """
        if self.sample_form:
            img, target = self.generate_form()
        else:
            img, target = self.generate_line()
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
        # Encode the target sequence using the label encoder.
        target_enc = np.array(self.label_encoder.encode_labels(list(target)))
        return img, target_enc

    def generate_line(self) -> Tuple[np.ndarray, str]:
        """Generate a single synthetic line image and its transcription."""
        words_to_sample = self.rng.integers(*self.words_per_line)
        line_width = self.rng.integers(*self.line_width)
        return self.sample_lines(words_to_sample, line_width, sample_one_line=True)

    def generate_form(self) -> Tuple[np.ndarray, str]:
        """Generate a synthetic form image and its newline-separated transcription.

        Forms taller than `self.max_height` are discarded and regenerated. This
        happens in a loop rather than recursively, so repeated rejections cannot
        hit the recursion limit.
        """
        while True:
            # Randomly pick the number of words and inter-line distance in the form.
            # The factor 5 (words per line) is handpicked.
            words_to_sample = self.rng.integers(*self.lines_per_form) * 5
            px_between_lines = self.rng.integers(*self.px_between_lines)

            # Sample line images.
            line_width = self.rng.integers(*self.line_width)
            lines, target = self.sample_lines(words_to_sample, line_width)

            form_w = max(l.shape[1] for l in lines)
            form_h = sum(l.shape[0] + px_between_lines for l in lines)
            if form_h <= self.max_height:
                break
            print(
                "Generated form height exceeds maximum height. Generating a new form."
            )

        # Concatenate the lines vertically on a white canvas. Every line fits,
        # since form_h (the sum of all line heights plus gaps) is <= max_height.
        form = np.ones((form_h, form_w), dtype=lines[0].dtype) * 255
        curr_h = 0
        for line_img in lines:
            h, w = line_img.shape
            form[curr_h : curr_h + h, :w] = line_img
            curr_h += h + px_between_lines

        # Add a random amount of padding around the image.
        pad_px = self.rng.integers(*self.px_around_image)
        new_h, new_w = form.shape[0] + pad_px * 2, form.shape[1] + pad_px * 2
        form = A.PadIfNeeded(
            new_h, new_w, border_mode=cv.BORDER_CONSTANT, value=255, always_apply=True
        )(image=form)["image"]

        return form, target

    def set_rng(self, seed: int):
        """Reset `self.rng` (used for sizes/counts) with `seed`."""
        self.rng = np.random.default_rng(seed)

    def sample_word_image(self) -> Tuple[np.ndarray, str]:
        """Sample one random word image and its decoded transcription."""
        idx = random.randint(0, len(self.iam_words) - 1)
        img, target = self.iam_words[idx]
        target = "".join(self.iam_words.label_enc.decode_labels(target))
        return img, target

    def sample_word_image_sequence(
        self, words_to_sample: int
    ) -> List[Tuple[np.ndarray, str]]:
        """Sample a sequence of contiguous words.

        If the sequence cannot be completed from the random starting word (end of
        document, or a word missing from the dataset), a new start is drawn. This
        retries in a loop instead of recursing, so many consecutive failures
        cannot exhaust the recursion limit.
        """
        assert words_to_sample >= 1
        img_idxs = None
        while img_idxs is None:
            img_idxs = self._try_sample_contiguous_word_idxs(words_to_sample)

        imgs, targets = zip(*[self.iam_words[idx] for idx in img_idxs])
        targets = [
            "".join(self.iam_words.label_enc.decode_labels(t)) for t in targets
        ]
        return list(zip(imgs, targets))

    def _try_sample_contiguous_word_idxs(
        self, words_to_sample: int
    ) -> Optional[List[int]]:
        """Return dataset indices of `words_to_sample` contiguous words, starting at
        a random word, or None if the run cannot be completed from that start."""
        start_idx = random.randint(0, len(self.iam_words) - 1)

        img_idxs = [start_idx]
        img_path = Path(self.iam_words.data.iloc[start_idx]["img_path"])
        # The stem has four dash-separated fields; the last two are the line and
        # word numbers.
        stem_parts = img_path.stem.split("-")
        _, _, line_id, word_id = stem_parts
        doc_prefix = stem_parts[:-2]
        while len(img_idxs) < words_to_sample:
            word_id = f"{int(word_id) + 1:02}"
            img_name = "-".join(doc_prefix + [line_id, word_id]) + ".png"
            if not (img_path.parent / img_name).is_file():
                # Previous image was the last on its line. Go to the next line.
                line_id = f"{int(line_id) + 1:02}"
                word_id = "00"
                img_name = "-".join(doc_prefix + [line_id, word_id]) + ".png"
            if not (img_path.parent / img_name).is_file():
                # End of the document.
                return None
            # Find the dataset index for the sampled word.
            ix = self.iam_words.id_to_idx.get(Path(img_name).stem)
            if ix is None:
                # If the image has segmentation=err attribute, it will
                # not be in the dataset. In this case try again.
                return None
            img_idxs.append(ix)
        return img_idxs

    def sample_lines(
        self, words_to_sample: int, max_line_width: int, sample_one_line: bool = False
    ) -> Tuple[Union[List[np.ndarray], np.ndarray], str]:
        """
        Calls `sample_word_image_sequence` several times, using some heuristics
        to glue the sequences together.

        Words wider than `max_line_width` are skipped, since they cannot be
        placed on any line.

        Returns:
            - list of line images (a single line image if `sample_one_line`)
            - transcription for all lines combined
        """
        curr_pos, sampled_words = 0, 0
        imgs, targets, lines = [], [], []
        target_str, last_target = "", ""

        # Sample images.
        while sampled_words < words_to_sample:
            words_per_seq = self.rng.integers(*self.words_per_sequence)
            # Sample a sequence of contiguous words.
            img_tgt_seq = self.sample_word_image_sequence(words_per_seq)
            for i, (img, tgt) in enumerate(img_tgt_seq):
                # Add the sequence to the sampled words so far.
                if sampled_words >= words_to_sample:
                    break
                h, w = img.shape

                if w > max_line_width:
                    # Cannot fit on any line; `concatenate_line` would fail on it.
                    continue

                if curr_pos + w > max_line_width:
                    # Concatenate the sampled images into a line.
                    line = self.concatenate_line(imgs, targets, max_line_width)

                    if sample_one_line:
                        return line, target_str

                    lines.append(line)
                    target_str += "\n"
                    last_target = "\n"
                    curr_pos = 0
                    imgs, targets = [], []

                # Basic heuristics to avoid some strange looking sentences.
                if i == 0 and (
                    (last_target in self.PUNCTUATION and tgt in self.PUNCTUATION)
                    or (tgt in self.PUNCTUATION and sampled_words == 0)
                ):
                    continue

                if (
                    sampled_words == 0
                    or tgt in [c for c in self.PUNCTUATION if c not in ["'", '"']]
                    or last_target == "\n"
                ):
                    target_str += tgt
                else:
                    target_str += " " + tgt

                targets.append(tgt)
                imgs.append(img)

                sampled_words += 1
                last_target = tgt
                if tgt in self.PUNCTUATION:
                    # Reduce horizontal spacing for punctuation tokens.
                    curr_pos = max(0, curr_pos - self.px_between_words)
                curr_pos += w + self.px_between_words
        if imgs and targets:
            # Concatenate the remaining images into a new line.
            line = self.concatenate_line(imgs, targets, max_line_width)
            lines.append(line)
            if sample_one_line:
                return line, target_str
        return lines, target_str

    def concatenate_line(
        self, imgs: List[np.ndarray], targets: List[str], line_width: int
    ) -> np.ndarray:
        """
        Concatenate a series of (img, target) tuples into a line to create a line image.

        The line is a white canvas `line_width` pixels wide and as tall as the
        tallest word image.
        """
        assert len(imgs) == len(targets)

        line_height = max(im.shape[0] for im in imgs)
        line = np.ones((line_height, line_width), dtype=imgs[0].dtype) * 255

        curr_pos = 0
        prev_lower_bound = line_height
        for img, tgt in zip(imgs, targets):
            h, w = img.shape
            # Center the image in the middle of the line.
            start_h = min(max(0, int((line_height - h) / 2)), line_height - h)

            if tgt in [",", "."]:
                # If sampled a comma or dot, place them at the bottom of the line.
                start_h = min(max(0, prev_lower_bound - int(h / 2)), line_height - h)
            elif tgt in ['"', "'"]:
                # If sampled a quote, place them at the top of the line.
                start_h = 0
            if tgt in self.PUNCTUATION:
                # Reduce horizontal spacing for punctuation tokens.
                curr_pos = max(0, curr_pos - self.px_between_words)

            assert curr_pos + w <= line_width, f"{curr_pos + w} > {line_width}"
            assert start_h + h <= line_height, f"{start_h + h} > {line_height}"

            # Concatenate the word image to the line.
            line[start_h : start_h + h, curr_pos : curr_pos + w] = img

            curr_pos += w + self.px_between_words
            prev_lower_bound = start_h + h
        return line

    @staticmethod
    def get_worker_init_fn():
        """Return a DataLoader `worker_init_fn` that seeds each worker with its id."""

        def worker_init_fn(worker_id: int):
            set_seed(worker_id)
            worker_info = torch.utils.data.get_worker_info()
            dataset = worker_info.dataset  # the dataset copy in this worker process
            if hasattr(dataset, "set_rng"):
                dataset.set_rng(worker_id)
            else:  # dataset is instance of `IAMDatasetSynthetic` class
                dataset.synth_dataset.set_rng(worker_id)

        return worker_init_fn



class IAMDatasetSynthetic(Dataset):
    """
    PyTorch dataset that mixes real IAM samples with synthetic ones.

    On every ``__getitem__`` call a uniform random draw decides, with
    probability ``synth_prob``, whether to return a sample from an internal
    ``IAMSyntheticDataGenerator`` or the real IAM sample at ``idx``. The
    length of this dataset equals the length of the wrapped IAM dataset.
    """

    iam_dataset: IAMDataset
    synth_dataset: IAMSyntheticDataGenerator

    def __init__(self, iam_dataset: IAMDataset, synth_prob: float = 0.3, **kwargs):
        """
        Args:
            iam_dataset (Dataset): the IAM dataset to sample real images from. Its
                root, label encoder, transforms and lowercase flag are reused to
                configure the synthetic generator.
            synth_prob (float): the probability of returning a synthetic image
                from `__getitem__()`.
            **kwargs: forwarded unchanged to `IAMSyntheticDataGenerator`.
        """
        self.iam_dataset = iam_dataset
        self.synth_prob = synth_prob

        is_form = iam_dataset.parse_method == "form"
        if is_form:
            # Height of the tallest form crop, forwarded as the generator's max_height.
            crop_heights = iam_dataset.data["bb_y_end"] - iam_dataset.data["bb_y_start"]
            form_max_height = crop_heights.max()
        else:
            form_max_height = None

        self.synth_dataset = IAMSyntheticDataGenerator(
            iam_root=iam_dataset.root,
            label_encoder=iam_dataset.label_enc,
            transforms=iam_dataset.transforms,
            sample_form=is_form,
            only_lowercase=iam_dataset.only_lowercase,
            max_height=form_max_height,
            **kwargs,
        )

    def __getitem__(self, idx):
        # `random.random() > 1 - p` is true with probability p.
        draw_synthetic = random.random() > 1 - self.synth_prob
        if draw_synthetic:
            # Fixed index 0; presumably the generator samples internally — confirm.
            img, target = self.synth_dataset[0]
        else:
            img, target = self.iam_dataset[idx]
        assert not np.any(np.isnan(img)), img
        return img, target

    def __len__(self):
        return len(self.iam_dataset)
    
    
import time
class RIMESDataset(Dataset):
    """
    PyTorch dataset of RIMES letter images paired with their body text.

    The ``DVD1_TIF``, ``DVD2_TIF`` and ``DVD3_TIF`` folders under ``root`` are
    scanned for ``*L.tif`` images with a matching ``*L.xml`` annotation. Every
    "Corps de texte" box with more than five lines of text becomes one row.
    A sample is the image cropped vertically to that box (then passed through
    the split's transforms) together with the label-encoded target.
    """

    MEAN = 0.8275
    STD = 0.2314

    root: Path
    data: pd.DataFrame
    label_enc: LabelParser
    transforms: Optional[A.Compose]
    id_to_idx: Dict[str, int]
    _split: str
    _return_writer_id: Optional[bool]

    _pad_token = "<PAD>"
    _sos_token = "<SOS>"
    _eos_token = "<EOS>"

    max_width: Optional[int]
    max_height: Optional[int]

    # Column order of the DataFrame built by `_get_form_data`.
    _FORM_COLUMNS = (
        "img_path",
        "img_id",
        "target",
        "bb_y_start",
        "bb_y_end",
        "bb_x_start",
        "bb_x_end",
        "target_len",
    )

    @staticmethod
    def process_target(target: str):
        """
        Normalise a raw RIMES transcription.

        The two-character sequence backslash + ``n`` separates lines and is
        converted into a real newline. An alternatives group ``¤{a/b/...}¤``
        (the closing ``}`` is optional) is replaced by one randomly chosen
        alternative, surrounded by single spaces. An unterminated group is
        left as is.
        """
        new_lines = []
        for line in target.split("\\n"):
            start_index = line.find("¤{")
            while start_index != -1:
                # Find the "¤" that closes this group.
                end_index = line.find("¤", start_index + 1)
                if end_index == -1:
                    break  # Unterminated group: keep the rest of the line.
                seq = line[start_index + 2:end_index]
                # Drop the group's closing brace so it does not end up glued
                # to the last alternative.
                if seq.endswith("}"):
                    seq = seq[:-1]
                choices = seq.split("/")
                val = choices[randint(0, len(choices) - 1)]
                line = line[:start_index] + " " + val + " " + line[end_index + 1:]
                start_index = line.find("¤{", start_index + 1)
            new_lines.append(line)
        return "\n".join(new_lines)

    def __init__(
            self,
            root: Union[Path, str],
            split: str,
            only_lowercase: bool = False,
            label_enc: Optional[LabelParser] = None,):
        """
        Args:
            root: directory containing the ``DVD*_TIF`` folders.
            split: ``"train"`` or ``"test"``; selects the image transforms.
            only_lowercase: lowercase targets before building/encoding the vocab.
            label_enc: existing label encoder. If None, one is fitted on the
                special tokens plus every character found in the targets.
        """
        super().__init__()

        _splits = ["train", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message

        self._split = split
        self.only_lowercase = only_lowercase
        self.root = Path(root)
        self.label_enc = label_enc

        if not hasattr(self, "data"):
            # Also sets `max_height` / `max_width` from the bounding boxes.
            self.data = self._get_form_data()
        # `data` may have been provided by a subclass; make sure the image
        # bounds needed by `_get_transforms` exist in that case too.
        if getattr(self, "max_height", None) is None:
            self.max_height = self.get_max_height()
        if getattr(self, "max_width", None) is None:
            self.max_width = self.get_max_width()

        if self.label_enc is None:
            text = "".join(self.data["target"].tolist())
            if self.only_lowercase:
                text = text.lower()
            vocab = [self._pad_token, self._sos_token, self._eos_token]
            vocab += sorted(set(text))
            self.label_enc = LabelParser().fit(vocab)

        if "target_enc" not in self.data.columns:
            self.data.insert(2, "target_enc", self.data["target"].apply(self._encode_target))
        self.transforms = self._get_transforms(split)
        # Image file stem -> row position. If one image contributes several
        # boxes, the last row wins.
        self.id_to_idx = {
            Path(path).stem: i for i, path in enumerate(self.data["img_path"])
        }

    def _encode_target(self, s: str) -> np.ndarray:
        """Label-encode one target string, lowercasing it first if configured."""
        if self.only_lowercase:
            s = s.lower()
        return np.array(self.label_enc.encode_labels(list(s)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data.iloc[idx]
        img = cv.imread(data["img_path"], cv.IMREAD_GRAYSCALE)
        # Check before cropping: cv.imread returns None on failure, and slicing
        # None would raise an unhelpful TypeError instead of this message.
        assert isinstance(img, np.ndarray), (
            f"Error: image at path {data['img_path']} is not properly loaded. "
            f"Is there something wrong with this image?"
        )
        if all(col in data.keys() for col in ["bb_y_start", "bb_y_end"]):
            # Only the vertical extent of the text box is cropped.
            img = img[data["bb_y_start"]: data["bb_y_end"], :]
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]

        return img, data["target_enc"]

    def get_max_height(self):
        """Largest box height in `data` plus a 150 px margin."""
        return (self.data["bb_y_end"] - self.data["bb_y_start"]).max() + 150

    def get_max_width(self):
        """Largest box width in `data` plus a 150 px margin."""
        return (self.data["bb_x_end"] - self.data["bb_x_start"]).max() + 150

    @property
    def vocab(self):
        return self.label_enc.classes

    @staticmethod
    def collate_fn(
        batch: Sequence[Tuple[np.ndarray, np.ndarray]],
        pad_val: int,
        eos_tkn_idx: int,
        dataset_returns_writer_id: bool = False,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        """
        Collate (image, target) pairs into a batch.

        Images of differing sizes are padded with 0 (constant border) to the
        largest height and width in the batch. Each target gets
        ``eos_tkn_idx`` appended and is right-padded with ``pad_val``.
        ``dataset_returns_writer_id`` is accepted but unused; only
        ``(imgs, targets)`` is returned.
        """
        imgs, targets = zip(*batch)

        img_sizes = [im.shape for im in imgs]
        if not len(set(img_sizes)) == 1:
            hs, ws = zip(*img_sizes)
            pad_fn = A.PadIfNeeded(
                max(hs), max(ws), border_mode=cv.BORDER_CONSTANT, value=0
            )
            imgs = [pad_fn(image=im)["image"] for im in imgs]
        imgs = np.stack(imgs, axis=0)

        seq_lengths = [t.shape[0] for t in targets]
        targets_padded = np.full((len(targets), max(seq_lengths) + 1), pad_val)
        for i, t in enumerate(targets):
            targets_padded[i, : seq_lengths[i]] = t
            targets_padded[i, seq_lengths[i]] = eos_tkn_idx

        imgs, targets_padded = torch.tensor(imgs), torch.tensor(targets_padded)

        return imgs, targets_padded

    def _get_transforms(self, split: str) -> A.Compose:
        """Return the train or test transform pipeline sized to the max image bounds."""
        transforms = ImageTransforms(
            (self.max_height, self.max_width), (RIMESDataset.MEAN, RIMESDataset.STD)
        )

        if split == "train":
            return transforms.train_trnsf
        elif split == "test" or split == "val":
            return transforms.test_trnsf

    def _get_form_data(self):
        """
        Scan the RIMES DVD folders and build one DataFrame row per kept text box.

        Side effect: sets ``self.max_height`` / ``self.max_width`` to the
        largest box height / width plus a 150 px margin.
        """
        columns = self._FORM_COLUMNS

        def process_forms(paths: Tuple[str, str, Path]):
            """Parse one (image name, xml name, directory) triple into column lists."""
            return_data = {col: [] for col in columns}
            img_path, xml_path, root = paths
            img_path = root / img_path
            xml_path = root / xml_path
            doc_id = img_path.stem[:-2]
            xml_root = read_xml(xml_path)

            for box in xml_root.iter("box"):
                type_tag = box.find("type")
                if type_tag is None or type_tag.text != "Corps de texte":
                    continue
                text_tag = box.find("text")
                target = text_tag.text if text_tag is not None else None
                if not target:
                    continue
                # Lines are separated by backslash + "n" in the XML text; only
                # boxes with more than five lines are kept.
                if len(target.split("\\n")) <= 5:
                    continue
                processed = self.process_target(target)

                return_data["img_path"].append(str(img_path.resolve()))
                return_data["img_id"].append(doc_id)
                return_data["target"].append(processed)
                return_data["bb_y_start"].append(int(box.get("top_left_y")))
                return_data["bb_y_end"].append(int(box.get("bottom_right_y")))
                return_data["bb_x_start"].append(int(box.get("top_left_x")))
                return_data["bb_x_end"].append(int(box.get("bottom_right_x")))
                # Length of the stored (processed) target, not the raw XML text.
                return_data["target_len"].append(len(processed))

            return return_data

        image_pairs = []
        for form_dir in ["DVD1_TIF", "DVD2_TIF", "DVD3_TIF"]:
            dr = self.root / form_dir
            for file in dr.iterdir():
                # Keep "...L.tif" images; the annotation shares the stem.
                if file.suffix == ".tif" and file.stem.endswith("L"):
                    image_pairs.append((file.stem + ".tif", file.stem + ".xml", dr))

        data = {col: [] for col in columns}
        with ThreadPoolExecutor() as executor:
            # `map` preserves input order, so rows follow `image_pairs`.
            for single_results in executor.map(process_forms, image_pairs):
                for col in columns:
                    data[col].extend(single_results[col])

        to_ret = pd.DataFrame(data)
        self.max_height = (to_ret["bb_y_end"] - to_ret["bb_y_start"]).max() + 150
        self.max_width = (to_ret["bb_x_end"] - to_ret["bb_x_start"]).max() + 150

        return to_ret
    

class AggregatedDataset(Dataset):
    """
    Concatenation of a RIMES dataset followed by an IAM dataset.

    Indices ``[0, len(rimes))`` map to RIMES samples and
    ``[len(rimes), len(rimes) + len(iam))`` to IAM samples; negative indices
    count from the end. Targets are re-encoded with ``self.label_enc`` so that
    samples from both datasets share one token index space.
    """

    datasets: List[Dataset]

    def __init__(self, rimes: RIMESDataset,
                 iam: IAMDataset,
                 split: str,
                 only_lowercase: bool = False,
                 label_enc: Optional[LabelParser] = None,):
        super().__init__()
        _splits = ["train", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message

        self._split = split
        self.rimes = rimes
        self.iam = iam
        self._only_lowercase = only_lowercase
        self.label_enc = label_enc

        if self.label_enc is None:
            # `LabelParser.addClasses` fails on an unfitted parser (its
            # `classes` is None), so fit directly on the sorted union of both
            # vocabularies, which is what `addClasses` would produce.
            merged = set(iam.label_enc.classes) | set(rimes.label_enc.classes)
            self.label_enc = LabelParser().fit(sorted(merged))

    def __len__(self):
        return len(self.rimes) + len(self.iam)

    def __getitem__(self, idx):
        n_rimes = len(self.rimes)
        total = n_rimes + len(self.iam)
        pos = idx + total if idx < 0 else idx
        if not 0 <= pos < total:
            raise IndexError(f"index {idx} out of range for dataset of length {total}")

        if pos < n_rimes:
            source = self.rimes
            img, target = source[pos]
        else:
            source = self.iam
            img, target = source[pos - n_rimes]

        return img, self._reencode(target, source.label_enc)

    def _reencode(self, target, source_enc: LabelParser) -> np.ndarray:
        """Map token ids in `target` from `source_enc`'s index space to `self.label_enc`'s."""
        if source_enc is self.label_enc:
            return target
        ids = np.asarray(target).tolist()
        return np.array(
            self.label_enc.encode_labels(source_enc.decode_labels(ids)),
            dtype=np.int64,
        )
        
    

## Metrics

In [None]:

class CharacterErrorRate(Metric):
    """
    Character error rate averaged over samples.

    For each sample, the edit distance between the predicted and target token
    sequences is divided by the target length. Both sequences are cut before
    their first ``<EOS>``. A leading ``<SOS>`` is stripped from the predictions
    when every row starts with it. Samples with an empty target are skipped.
    """

    def __init__(self, label_encoder: LabelParser):
        super().__init__()
        self.label_encoder = label_encoder

        self.add_state("cer_sum", default=torch.zeros(1, dtype=torch.float), dist_reduce_fx="sum")
        self.add_state("nr_samples", default=torch.zeros(1, dtype=torch.int64), dist_reduce_fx="sum")

    @staticmethod
    def _truncate_at_eos(seqs: torch.Tensor, eos_tkn_idx: int) -> List[List[int]]:
        """
        Return each row of a 2D tensor as a list of ints, cut before its first EOS.

        Rows without an EOS are returned in full. An EOS at position 0 yields an
        empty list (argmax alone cannot distinguish this from "no EOS").
        """
        is_eos = seqs == eos_tkn_idx
        first_eos = is_eos.float().argmax(1).tolist()
        has_eos = is_eos.any(1).tolist()
        return [
            row[:end] if found else row
            for row, end, found in zip(seqs.tolist(), first_eos, has_eos)
        ]

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate CER for a batch of (B, T) token-id predictions and targets."""
        assert preds.ndim == target.ndim
        eos_tkn_idx, sos_tkn_idx = self.label_encoder.encode_labels(["<EOS>", "<SOS>"])

        if (preds[:, 0] == sos_tkn_idx).all():  # this should normally be the case
            preds = preds[:, 1:]

        for p, t in zip(self._truncate_at_eos(preds, eos_tkn_idx),
                        self._truncate_at_eos(target, eos_tkn_idx)):
            if not t:
                continue  # CER is undefined for an empty reference.
            # Compare token-id lists directly. Joining ids into a digit string
            # would conflate e.g. [1, 23] with [12, 3].
            editd = editdistance.eval(p, t)
            self.cer_sum += editd / len(t)
            self.nr_samples += 1

    def compute(self) -> torch.Tensor:
        return self.cer_sum / self.nr_samples.float()

class WordErrorRate(Metric):
    """
    Word error rate averaged over samples.

    Predicted and target token ids are cut before their first ``<EOS>``,
    decoded with the label encoder, and split on whitespace. The word-level
    edit distance is divided by the number of target words. A leading
    ``<SOS>`` is stripped from the predictions when every row starts with it.
    Samples whose target is empty or whitespace-only are skipped.
    """

    def __init__(self, label_encoder: LabelParser):
        super().__init__()
        self.label_encoder = label_encoder

        self.add_state("wer_sum", default=torch.zeros(1, dtype=torch.float), dist_reduce_fx="sum")
        self.add_state("nr_samples", default=torch.zeros(1, dtype=torch.int64), dist_reduce_fx="sum")

    @staticmethod
    def _truncate_at_eos(seqs: torch.Tensor, eos_tkn_idx: int) -> List[List[int]]:
        """
        Return each row of a 2D tensor as a list of ints, cut before its first EOS.

        Rows without an EOS are returned in full. An EOS at position 0 yields an
        empty list (argmax alone cannot distinguish this from "no EOS").
        """
        is_eos = seqs == eos_tkn_idx
        first_eos = is_eos.float().argmax(1).tolist()
        has_eos = is_eos.any(1).tolist()
        return [
            row[:end] if found else row
            for row, end, found in zip(seqs.tolist(), first_eos, has_eos)
        ]

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate WER for a batch of (B, T) token-id predictions and targets."""
        assert preds.ndim == target.ndim

        eos_tkn_idx, sos_tkn_idx = self.label_encoder.encode_labels(["<EOS>", "<SOS>"])

        if (preds[:, 0] == sos_tkn_idx).all():
            preds = preds[:, 1:]

        for p, t in zip(self._truncate_at_eos(preds, eos_tkn_idx),
                        self._truncate_at_eos(target, eos_tkn_idx)):
            if not t:
                continue

            p_words = "".join(self.label_encoder.decode_labels(p)).split()
            t_words = "".join(self.label_encoder.decode_labels(t)).split()
            if not t_words:
                continue  # Whitespace-only reference: WER undefined, avoid / 0.
            editd = editdistance.eval(p_words, t_words)

            self.wer_sum += editd / len(t_words)
            self.nr_samples += 1

    def compute(self) -> torch.Tensor:
        """Compute Word Error Rate."""
        return self.wer_sum / self.nr_samples.float()

def tensor_to_str(t: torch.Tensor) -> str:
    """
    Concatenate the decimal representation of every element of ``t``.

    The tensor is flattened first. The mapping is not reversible: ``[1, 23]``
    and ``[12, 3]`` both give ``"123"``.
    """
    pieces = [str(value) for value in t.flatten().tolist()]
    return "".join(pieces)

## Model class

In [None]:

class PosEmbedding1D(nn.Module):
    """
    Implements 1D sinusoidal embeddings.

    Adapted from 'The Annotated Transformer', http://nlp.seas.harvard.edu/2018/04/03/attention.html

    The table is stored in the buffer ``pe`` with shape ``(1, max_len, d_model)``.
    Even channels hold sines and odd channels cosines; odd ``d_model`` values
    are supported.
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros((max_len, d_model), requires_grad=False)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine channel than sine channel.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Add a 1D positional embedding to an input tensor.

        Args:
            x (Tensor): tensor of shape (B, T, d_model) to add positional
                embedding to

        Returns:
            Tensor of the same shape as ``x`` with ``pe[:, :T]`` added.
        """
        _, T, _ = x.shape
        assert T <= self.pe.size(1), (
            f"Stored 1D positional embedding does not have enough dimensions for the current feature map. "
            f"Currently max_len={self.pe.size(1)}, T={T}. Consider increasing max_len such that max_len >= T."
        )
        return x + self.pe[:, :T]



class PosEmbedding2D(nn.Module):
    """
    2D sinusoidal positional embedding for a channels-last feature map.

    The buffers ``pe_x`` and ``pe_y`` hold the same sine/cosine table of shape
    ``(max_len, d_model // 2)``. In ``forward``, half of the channels encode
    the column position and the other half the row position, so ``d_model``
    is expected to be even. An odd ``d_model // 2`` is supported.
    """

    def __init__(self, d_model, max_len=100):
        super().__init__()
        self.d_model = d_model
        half = d_model // 2
        pe_x = torch.zeros((max_len, half), requires_grad=False)
        pe_y = torch.zeros((max_len, half), requires_grad=False)

        pos = torch.arange(0, max_len).unsqueeze(1)

        div_term = torch.exp(
            -math.log(10000.0) * torch.arange(0, half, 2) / d_model
        )
        # When `half` is odd there is one fewer cosine channel than sine channel.
        cos_div_term = div_term[: half // 2]

        pe_y[:, 0::2] = torch.sin(pos * div_term)
        pe_y[:, 1::2] = torch.cos(pos * cos_div_term)
        pe_x[:, 0::2] = torch.sin(pos * div_term)
        pe_x[:, 1::2] = torch.cos(pos * cos_div_term)

        self.register_buffer("pe_x", pe_x)
        self.register_buffer("pe_y", pe_y)

    def forward(self, x):
        """
        Add the 2D positional embedding to ``x``.

        Args:
            x (Tensor): 4D channels-last tensor, presumably (B, H, W, d_model)
                as produced by the encoder's transposes.

        Returns:
            Tensor of the same shape. Channels ``[:d_model // 2]`` vary along
            dim 2 and channels ``[d_model // 2:]`` vary along dim 1.
        """
        _, rows, cols, _ = x.shape
        max_len = self.pe_x.size(0)
        assert rows <= max_len and cols <= max_len, (
            f"Stored 2D positional embedding is too small for the current feature map. "
            f"Currently max_len={max_len}, feature map is {rows}x{cols}. "
            f"Consider increasing max_len such that max_len >= max(H, W)."
        )

        # Row embedding varies along dim 1, column embedding along dim 2.
        pe_rows = self.pe_x[:rows, :].unsqueeze(1).expand(-1, cols, -1)
        pe_cols = self.pe_y[:cols, :].unsqueeze(0).expand(rows, -1, -1)

        pe = torch.cat([pe_cols, pe_rows], -1)
        return x + pe.unsqueeze(0)


class encoderHTR(nn.Module):
    """
    ResNet-based image encoder that outputs a flattened feature sequence.

    The torchvision ResNet's first convolution is replaced by a
    single-channel one with the same kernel, stride and padding. The final
    pooling and fully-connected layers are dropped. A 1x1 convolution then
    projects the feature map to ``d_model`` channels. A 2D positional
    embedding and dropout are applied before the spatial dims are flattened.
    The ``bias`` argument is currently unused.
    """

    def __init__(self, d_model: int, encoder_type: str, dropout=0.1, bias=True):
        super().__init__()
        assert encoder_type in ["resnet18", "resnet34", "resnet50"], "Model not found"

        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.pos_embd = PosEmbedding2D(d_model)

        backbone = getattr(models, encoder_type)(pretrained=False)
        # Drop the trailing global pooling and classification layers.
        *feature_layers, _avgpool, _fc = backbone.children()
        rgb_conv = feature_layers[0]
        gray_conv = nn.Conv2d(
            1,
            rgb_conv.out_channels,
            rgb_conv.kernel_size,
            rgb_conv.stride,
            rgb_conv.padding,
            bias=rgb_conv.bias is not None,
        )
        self.encoder = nn.Sequential(gray_conv, *feature_layers[1:])
        self.linear = nn.Conv2d(backbone.fc.in_features, d_model, kernel_size=1)

    def forward(self, imgs):
        # imgs presumably (B, H, W) grayscale; add the channel dimension.
        feats = self.encoder(imgs.unsqueeze(1))
        # (B, d_model, H', W') -> channels-last (B, H', W', d_model).
        feats = self.linear(feats).permute(0, 2, 3, 1)
        feats = self.dropout(self.pos_embd(feats))
        # Merge the spatial dims into one sequence axis: (B, H' * W', d_model).
        return feats.flatten(1, 2)


class decoderHTR(nn.Module):
    """
    Transformer decoder that turns encoder memory into a token sequence.

    ``forward`` decodes greedily, starting from ``<SOS>``.
    ``forward_teacher_forcing`` runs the decoder on the right-shifted ground
    truth, as used in training. Both add the positional encoding of position
    ``i`` to the ``i``-th decoder input token (``<SOS>`` is position 0).
    """

    def __init__(self,
                 vocab_length,
                 max_seq_len,
                 eos_tkn_idx,
                 sos_tkn_idx,
                 pad_tkn_idx,
                 d_model,
                 num_layers,
                 nhead,
                 dim_ffn,
                 dropout,
                 activation="relu"):
        super().__init__()
        self.vocab_length = vocab_length
        self.max_seq_len = max_seq_len
        self.eos_idx = eos_tkn_idx
        self.sos_idx = sos_tkn_idx
        self.pad_idx = pad_tkn_idx
        self.d_model = d_model
        self.num_layers = num_layers
        self.nhead = nhead
        self.dim_ffn = dim_ffn
        self.pos_emb = PosEmbedding1D(d_model)
        self.emb = nn.Embedding(vocab_length, d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model,
            nhead,
            dim_ffn,
            dropout,
            activation,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.clf = nn.Linear(d_model, vocab_length)
        self.drop = nn.Dropout(dropout)

    def forward(self, memory: torch.Tensor):
        """
        Greedy autoregressive decoding.

        Args:
            memory: encoder output of shape (B, S, d_model).

        Returns:
            (all_logits, sampled_ids):
            - all_logits: (B, T, vocab_length), the logits of each generated step.
            - sampled_ids: (B, T + 1) token ids starting with ``<SOS>``, with
              every position after the first ``<EOS>`` set to ``<PAD>``.
            Decoding stops after ``max_seq_len`` steps or once every sequence
            has produced ``<EOS>``.
        """
        B, _, _ = memory.shape
        device = memory.device
        all_logits = []
        sampled_ids = [torch.full([B], self.sos_idx, device=device)]
        tgt = self.pos_emb(
            self.emb(sampled_ids[0]).unsqueeze(1) * math.sqrt(self.d_model)
        )
        tgt = self.drop(tgt)
        eos_sampled = torch.zeros(B, dtype=torch.bool, device=device)

        for _ in range(self.max_seq_len):
            tgt_mask = self.subsequent_mask(len(sampled_ids)).to(device)
            out = self.decoder(tgt, memory, tgt_mask=tgt_mask)
            logits = self.clf(out[:, -1, :])
            _, pred = torch.max(logits, -1)
            all_logits.append(logits)
            sampled_ids.append(pred)
            # Vectorized EOS bookkeeping (one host sync per step, not per sample).
            eos_sampled |= pred == self.eos_idx
            if eos_sampled.all():
                break

            # The token just sampled becomes decoder input at position
            # len(sampled_ids) - 1 (position 0 is <SOS>), matching the
            # positions used by `forward_teacher_forcing`.
            pos = len(sampled_ids) - 1
            tgt_ext = self.drop(
                self.pos_emb.pe[:, pos]
                + self.emb(pred) * math.sqrt(self.d_model)
            ).unsqueeze(1)
            tgt = torch.cat([tgt, tgt_ext], 1)
        sampled_ids = torch.stack(sampled_ids, 1)
        all_logits = torch.stack(all_logits, 1)

        # Pad everything after the first EOS. Index 0 holds <SOS>, so an
        # argmax of 0 means no EOS was sampled for that row.
        first_eos = (sampled_ids == self.eos_idx).float().argmax(1, keepdim=True)
        positions = torch.arange(sampled_ids.size(1), device=device).unsqueeze(0)
        after_eos = (positions > first_eos) & (first_eos != 0)
        sampled_ids = sampled_ids.masked_fill(after_eos, self.pad_idx)

        return all_logits, sampled_ids

    def forward_teacher_forcing(self, memory: torch.Tensor, tgt: torch.Tensor):
        """
        Decode with teacher forcing.

        Args:
            memory: encoder output of shape (B, S, d_model).
            tgt: ground-truth token ids of shape (B, T).

        Returns:
            Logits of shape (B, T, vocab_length); step ``t`` sees ``<SOS>``
            followed by ``tgt[:, :t]``.
        """
        B, T = tgt.shape
        sos = torch.full([B, 1], self.sos_idx, device=memory.device)
        # Shift right: prepend <SOS> and drop the last target token.
        tgt = torch.cat([sos, tgt[:, :-1]], 1)

        tgt_key_masking = tgt == self.pad_idx
        tgt_mask = self.subsequent_mask(T).to(tgt.device)

        tgt = self.pos_emb(self.emb(tgt) * math.sqrt(self.d_model))
        tgt = self.drop(tgt)
        out = self.decoder(
            tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_masking
        )
        logits = self.clf(out)
        return logits

    @staticmethod
    def subsequent_mask(size: int):
        """Boolean (size, size) causal mask; True marks future positions to hide."""
        mask = torch.triu(torch.ones(size, size), diagonal=1)
        return mask == 1


class FullPageHTR(nn.Module):
    """
    Full-page handwriting recognition model: ``encoderHTR`` + ``decoderHTR``.

    Also owns CER/WER metrics and the training loss. ``loss_type`` selects
    ``nn.CrossEntropyLoss`` (ignoring ``<PAD>``) or ``nn.CTCLoss`` (with
    ``<PAD>`` as the blank label).
    """

    encoder: encoderHTR
    decoder: decoderHTR
    cer_metric: CharacterErrorRate
    wer_metric: WordErrorRate
    loss_fn: Callable
    label_encoder: LabelParser

    def __init__(self, label_encoder: LabelParser,
                 max_seq_len=500,
                 d_model=1024,
                 num_layers=6,
                 nhead=4,
                 dim_feedforward=1024,
                 encoder_name="resnet18",
                 drop_enc=0.1,
                 drop_dec=0.1,
                 activ_dec="gelu",
                 loss_type="cross_entropy",
                 label_smoothing=0.0,
                 vocab_len: Optional[int] = None):
        super().__init__()
        self.eos_token_idx, self.sos_token_idx, self.pad_token_idx = label_encoder.encode_labels(
            ["<EOS>", "<SOS>", "<PAD>"]
        )

        self.encoder = encoderHTR(d_model, encoder_type=encoder_name, dropout=drop_enc)
        self.decoder = decoderHTR(vocab_length=(vocab_len or len(label_encoder.classes)),
                                  max_seq_len=max_seq_len,
                                  eos_tkn_idx=self.eos_token_idx,
                                  sos_tkn_idx=self.sos_token_idx,
                                  pad_tkn_idx=self.pad_token_idx,
                                  d_model=d_model,
                                  num_layers=num_layers,
                                  nhead=nhead,
                                  dim_ffn=dim_feedforward,
                                  dropout=drop_dec,
                                  activation=activ_dec)
        self.label_encoder = label_encoder
        self.cer_metric = CharacterErrorRate(label_encoder)
        self.wer_metric = WordErrorRate(label_encoder)
        # Normalise over the vocabulary axis. Without an explicit dim,
        # nn.LogSoftmax picks dim=0 for 3D input, i.e. the batch axis.
        self.log_softmax = nn.LogSoftmax(dim=-1)

        assert loss_type in ["cross_entropy", "ctc_loss"]
        self.loss_type = loss_type
        if loss_type == "cross_entropy":
            self.loss_fn = nn.CrossEntropyLoss(
                ignore_index=self.pad_token_idx,
                label_smoothing=label_smoothing
            )
        elif loss_type == "ctc_loss":
            self.loss_fn = nn.CTCLoss(
                blank=self.pad_token_idx
            )

    def _ctc_loss(self, log_probs: torch.Tensor, targets: torch.Tensor):
        """
        CTC loss for batch-first log-probabilities of shape (B, T, V).

        The input length of each sample is the position of its first greedily
        predicted ``<EOS>``, or T if none is predicted. The target length is
        the position of the first ``<EOS>`` in ``targets``, or 0 if there is
        none.

        Returns:
            (time_first_log_probs, loss). The log-probabilities are permuted to
            (T, B, V) as required by ``nn.CTCLoss``.
        """
        T = log_probs.size(1)
        pred_is_eos = log_probs.argmax(-1) == self.eos_token_idx
        first_pred_eos = pred_is_eos.int().argmax(-1).tolist()
        has_pred_eos = pred_is_eos.any(-1).tolist()
        pred_lengths = tuple(
            pos if found else T for pos, found in zip(first_pred_eos, has_pred_eos)
        )
        target_lengths = tuple((targets == self.eos_token_idx).int().argmax(-1).tolist())

        time_first = log_probs.permute(1, 0, 2)
        loss = self.loss_fn(time_first, targets, pred_lengths, target_lengths)
        return time_first, loss

    def forward(self, imgs: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """
        Greedy decoding, optionally scored against ``targets``.

        Returns:
            (logits, sampled_ids, loss). ``loss`` is None when ``targets`` is
            None. For ``ctc_loss`` the returned logits are log-probabilities
            truncated to ``targets.size(1)`` steps and permuted to (T, B, V).
        """
        logits, sampled_ids = self.decoder(self.encoder(imgs))
        loss = None
        if targets is not None:
            if self.loss_type == "cross_entropy":
                loss = self.loss_fn(
                    logits[:, : targets.size(1), :].transpose(1, 2),
                    targets[:, : logits.size(1)],
                )
            elif self.loss_type == "ctc_loss":
                log_probs = self.log_softmax(logits)[:, : targets.size(1)]
                logits, loss = self._ctc_loss(log_probs, targets)

        return logits, sampled_ids, loss

    def forward_teacher_forcing(self, imgs: torch.Tensor, targets: torch.Tensor):
        """
        Teacher-forced forward pass with loss.

        Returns:
            (logits, loss). For ``ctc_loss`` the logits are log-probabilities
            permuted to (T, B, V); otherwise they are (B, T, V).
        """
        memory = self.encoder(imgs)
        logits = self.decoder.forward_teacher_forcing(memory, targets)
        if self.loss_type == "cross_entropy":
            loss = self.loss_fn(logits.transpose(1, 2), targets)
        elif self.loss_type == "ctc_loss":
            logits, loss = self._ctc_loss(self.log_softmax(logits), targets)
        return logits, loss

    def calculate_metrics(self, preds: torch.Tensor, targets: torch.Tensor):
        """Reset the metrics and return CER and WER for this batch only."""
        self.cer_metric.reset()
        self.wer_metric.reset()

        cer = self.cer_metric(preds, targets)
        wer = self.wer_metric(preds, targets)
        return {"CER": cer, "WER": wer}

    def set_num_output_classes(self, n_classes: int):
        """
        Resize the decoder's classifier and token embedding to ``n_classes``.

        The classifier is freshly initialised. Embedding rows for indices that
        exist in both the old and new vocabularies are copied over, which also
        works when shrinking. New layers are created on the device/dtype of
        the current embedding.
        """
        old_weight = self.decoder.emb.weight
        keep = min(old_weight.size(0), n_classes)
        device, dtype = old_weight.device, old_weight.dtype

        self.decoder.vocab_length = n_classes
        self.decoder.clf = nn.Linear(self.decoder.d_model, n_classes).to(device=device, dtype=dtype)

        new_embs = nn.Embedding(n_classes, self.decoder.d_model).to(device=device, dtype=dtype)
        with torch.no_grad():
            new_embs.weight[:keep] = old_weight[:keep]
        self.decoder.emb = new_embs


## Model Trainer Class

In [None]:
from copy import copy
from pytorch_lightning import seed_everything
!ls /kaggle/input/iam-dataset/raw/
seed_everything(12345)
ds = IAMDataset(root="../input/iam-dataset/raw", label_enc=None, parse_method="form" ,split="train")
# 80/20 split; the validation size is the remainder so the lengths always
# sum to len(ds), even when float rounding is unfavourable.
n_train = math.ceil(0.8 * len(ds))
ds_train, ds_val = torch.utils.data.random_split(ds, [n_train, len(ds) - n_train])

# `random_split` returns Subsets that read samples from `.dataset`. Point the
# validation subset at a shallow copy of `ds` so it can use validation
# transforms without touching the dataset behind `ds_train`.
ds_val.dataset = copy(ds)
ds_val.dataset.set_transforms_for_split("val")
train_len = len(ds_train)
val_len = len(ds_val)

In [None]:

batch_size = 2
pad_tkn_idx, eos_tkn_idx = ds.label_enc.encode_labels(["<PAD>", "<EOS>"])
collate_fn = partial(
        IAMDataset.collate_fn, pad_val=pad_tkn_idx, eos_tkn_idx=eos_tkn_idx
)
num_workers = 4
dl_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
)
dl_val = DataLoader(
    ds_val,
    batch_size=2 * batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
)
train_len //= batch_size
val_len //= 2 * batch_size

import gc
import traceback

# Bind `model` up front so the cleanup below cannot raise NameError when the
# failure happens before the model is built.
model = None
try:
    device = "cuda"
    torch.cuda.empty_cache()
    gc.collect()

    model = FullPageHTR(ds.label_enc).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
    trainer = ModelTrainer("Testing_run", model, ds_name="IAM_forms" , train_data=dl_train, val_data=dl_val, optimizer=optimizer, num_epochs=100, device=device, normalization_steps=56)

    wandb.finish()
    trainer.train(train_len, val_len, wanda=True)
except RuntimeError:
    # Keep the cause visible (e.g. CUDA OOM) instead of only a generic message.
    traceback.print_exc()
    del model
    gc.collect()
    torch.cuda.empty_cache()
    print("Error time!!")

In [None]:
synth_ds = IAMDatasetSynthetic(iam_dataset=ds, synth_prob=0.3)
print("Initialized synthetic dataset")
n_synth_train = math.ceil(0.8 * len(synth_ds))
ds_synth_train, ds_synth_val = torch.utils.data.random_split(
    synth_ds, [n_synth_train, len(synth_ds) - n_synth_train]
)

# `random_split` Subsets read from `.dataset`. Give the validation subset its
# own copy of the wrapper AND of the wrapped IAM dataset: a shallow copy of
# `synth_ds` alone still references `ds`, so calling
# set_transforms_for_split on it would also affect the training data.
# NOTE(review): `synth_val.synth_dataset` is still shared with training and
# keeps the transforms it was constructed with.
synth_val = copy(synth_ds)
synth_val.iam_dataset = copy(synth_ds.iam_dataset)
synth_val.iam_dataset.set_transforms_for_split("val")
ds_synth_val.dataset = synth_val
train_len = len(ds_synth_train)
val_len = len(ds_synth_val)
print(train_len, val_len)

In [None]:
batch_size = 56
pad_tkn_idx, eos_tkn_idx = ds.label_enc.encode_labels(["<PAD>", "<EOS>"])
collate_fn = partial(
        IAMDataset.collate_fn, pad_val=pad_tkn_idx, eos_tkn_idx=eos_tkn_idx
)
num_workers = 4
dl_synth_train = DataLoader(
    ds_synth_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
)
dl_synth_val = DataLoader(
    ds_synth_val,
    batch_size=2 * batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
)
train_len //= batch_size
val_len //= 2 * batch_size

import gc
import traceback

loss_type = "cross_entropy"
# Bind `model` up front so `del model` in `finally` cannot raise NameError
# when the failure happens before the model is built.
model = None
try:
    device = "cuda"
    torch.cuda.empty_cache()
    gc.collect()

    model = FullPageHTR(ds.label_enc, loss_type=loss_type, encoder_name="resnet34").to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
    # Run name is derived from the loss actually used so runs are not mislabelled.
    run_name = f"Synth_{synth_ds.synth_prob}_{loss_type}"
    trainer = ModelTrainer(run_name, model, ds_name="IAM_forms_synthetic" , train_data=dl_synth_train, val_data=dl_synth_val, optimizer=optimizer, num_epochs=100, device=device, normalization_steps=56)

    wandb.finish()
    trainer.train(train_len, val_len, wanda=True)
except RuntimeError:
    traceback.print_exc()
    print("Error time!!")
finally:
    del model
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
import pytorch_lightning as pl

class LitFullPageHTREncoderDecoder(pl.LightningModule):
    """
    PyTorch Lightning module acting as a wrapper around the FullPageHTR model.

    Using a PL module allows the model to be used in conjunction with a PyTorch
    Lightning Trainer, which takes care of logging the training/validation loss
    and the character/word error rates to the attached logger.

    Args:
        label_encoder: fitted ``LabelParser`` handed to ``FullPageHTR``.
        learning_rate: learning rate of the Adam optimizer built in
            ``configure_optimizers``.
        loss_function: forwarded to ``FullPageHTR`` as ``loss_type``.
        params_to_log: extra values stored in ``self.hparams`` in addition to
            the constructor arguments.
        All remaining arguments are forwarded unchanged to ``FullPageHTR``.
    """

    # Class-level annotation so type checkers know the wrapped model's type.
    model: FullPageHTR

    def __init__(
        self,
        label_encoder: LabelParser,
        learning_rate: float = 0.0002,
        label_smoothing: float = 0.0,
        max_seq_len: int = 500,
        d_model: int = 260,
        num_layers: int = 6,
        nhead: int = 4,
        dim_feedforward: int = 1024,
        encoder_name: str = "resnet18",
        drop_enc: float = 0.1,
        drop_dec: float = 0.1,
        activ_dec: str = "gelu",
        loss_function: str = "cross_entropy",
        vocab_len: Optional[int] = None,  # if not specified len(label_encoder) is used
        params_to_log: Optional[Dict[str, Union[str, float, int]]] = None,
    ):
        super().__init__()

        # Save hyperparameters.
        self.learning_rate = learning_rate
        if params_to_log is not None:
            self.save_hyperparameters(params_to_log)
        # Every argument that shapes the model or its loss is saved, so that
        # `load_from_checkpoint(path, label_encoder=...)` rebuilds the same
        # module (a non-default `vocab_len` presumably changes layer sizes).
        self.save_hyperparameters(
            "learning_rate",
            "label_smoothing",
            "d_model",
            "num_layers",
            "nhead",
            "dim_feedforward",
            "max_seq_len",
            "encoder_name",
            "drop_enc",
            "drop_dec",
            "activ_dec",
            "loss_function",
            "vocab_len",
        )

        # Initialize the model.
        self.model = FullPageHTR(
            label_encoder=label_encoder,
            max_seq_len=max_seq_len,
            d_model=d_model,
            num_layers=num_layers,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            encoder_name=encoder_name,
            drop_enc=drop_enc,
            drop_dec=drop_dec,
            activ_dec=activ_dec,
            vocab_len=vocab_len,
            label_smoothing=label_smoothing,
            # Previously accepted by this wrapper but never passed on.
            loss_type=loss_function,
        )

    @property
    def encoder(self):
        """The encoder of the wrapped FullPageHTR model."""
        return self.model.encoder

    @property
    def decoder(self):
        """The decoder of the wrapped FullPageHTR model."""
        return self.model.decoder

    def forward(self, imgs: Tensor, targets: Optional[Tensor] = None):
        """Delegate to the wrapped model; returns whatever FullPageHTR returns."""
        return self.model(imgs, targets)

    def training_step(self, batch, batch_idx):
        """Compute and log the teacher-forcing loss for one ``(imgs, targets)`` batch."""
        imgs, targets = batch
        _, loss = self.model.forward_teacher_forcing(imgs, targets)
        self.log("train_loss", loss, sync_dist=True, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        return self.val_or_test_step(batch)

    def test_step(self, batch, batch_idx):
        return self.val_or_test_step(batch)

    def val_or_test_step(self, batch) -> Tensor:
        """Run inference on a batch, update CER/WER and log them with the loss."""
        imgs, targets = batch
        logits, _, loss = self(imgs, targets)
        # Greedy choice over the last (vocabulary) dimension of the logits.
        _, preds = logits.max(-1)

        # Update and log metrics.
        self.model.cer_metric(preds, targets)
        self.model.wer_metric(preds, targets)
        self.log("char_error_rate", self.model.cer_metric, on_step=True, prog_bar=True)
        self.log("word_error_rate", self.model.wer_metric, on_step=True, prog_bar=True)
        self.log("val_loss", loss, sync_dist=True, prog_bar=True, on_step=True)

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)


In [None]:
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
import matplotlib.pyplot as plt

# How many predictions to plot for each data format ("word", "line", "form").
PREDICTIONS_TO_LOG = dict(word=10, line=6, form=1)


class LogWorstPredictions(Callback):
    """
    Log the worst predictions, i.e. the samples with the highest character
    error rates (CER), as a matplotlib figure.

    When it runs:
      * ``on_fit_end``: for ``train_dataloader`` and ``val_dataloader`` (if given);
      * ``on_validation_end``: for ``val_dataloader``, only if ``training_skipped``;
      * ``on_test_end``: for ``test_dataloader`` (if given).

    Unless ``training_skipped`` is True, the weights of the best checkpoint
    tracked by the trainer's ``ModelCheckpoint`` callback are loaded first.
    """

    def __init__(
        self,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloader: Optional[DataLoader] = None,
        test_dataloader: Optional[DataLoader] = None,
        training_skipped: bool = False,
        data_format: str = "word",
    ):
        super().__init__()
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.training_skipped = training_skipped
        self.data_format = data_format

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if self.training_skipped and self.val_dataloader is not None:
            self.log_worst_predictions(
                self.val_dataloader, trainer, pl_module, mode="val"
            )

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if self.test_dataloader is not None:
            self.log_worst_predictions(
                self.test_dataloader, trainer, pl_module, mode="test"
            )

    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if self.train_dataloader is not None:
            self.log_worst_predictions(
                self.train_dataloader, trainer, pl_module, mode="train"
            )
        if self.val_dataloader is not None:
            self.log_worst_predictions(
                self.val_dataloader, trainer, pl_module, mode="val"
            )

    def log_worst_predictions(
        self,
        dataloader: DataLoader,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        mode: str = "train",
    ):
        """
        Run inference over ``dataloader`` and log the highest-CER samples.

        Args:
            dataloader: yields ``(images, targets)`` batches; targets must be 2-D.
            trainer: the running Trainer (provides the callbacks and the logger).
            pl_module: the LightningModule; must expose ``model.cer_metric``,
                ``model.label_encoder`` and ``decoder``.
            mode: prefix used in the printed message and the logged key.
        """
        from copy import deepcopy

        if not self.training_skipped:
            # Loads in place into `pl_module`, which callbacks always receive
            # unwrapped (unlike `trainer.model`).
            self._load_best_model(trainer, pl_module)
        # The module's own device works for CPU, CUDA and TPU/XLA alike.
        device = pl_module.device

        print(f"Running {mode} inference on best model...")

        # Score samples with a private copy of the CER metric so the per-sample
        # resets do not clobber the state of the model's own metric.
        cer_metric = deepcopy(pl_module.model.cer_metric)
        img_cers = []
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.inference_mode():
                for img, target in dataloader:
                    assert target.ndim == 2, target.ndim
                    target = target.to(device)
                    _, preds, _ = pl_module(img.to(device), target)
                    for prd, tgt, im in zip(preds, target, img):
                        cer_metric.reset()
                        cer = cer_metric(prd.unsqueeze(0), tgt.unsqueeze(0)).item()
                        # Keep stored tensors on the CPU so device memory does
                        # not grow with the size of the dataset.
                        img_cers.append((im, cer, prd.cpu(), tgt.cpu()))
        finally:
            pl_module.train(was_training)

        # Log the worst k predictions.
        to_log = PREDICTIONS_TO_LOG[self.data_format] * 2
        img_cers.sort(key=lambda x: x[1], reverse=True)  # sort by CER, worst first
        worst = img_cers[:to_log]

        ncols = 4 if self.data_format == "word" else 2
        nrows = math.ceil(to_log / ncols)
        fig = plt.figure(figsize=(24, 16))
        for i, (im, cer, prd, tgt) in enumerate(worst):
            # NOTE(review): LogModelPredictions reads `decoder.eos_idx` while this
            # reads `decoder.eos_tkn_idx`; confirm which one the decoder defines.
            pred_str, target_str = decode_prediction_and_target(
                prd, tgt, pl_module.model.label_encoder, pl_module.decoder.eos_tkn_idx
            )
            ax = fig.add_subplot(nrows, ncols, i + 1, xticks=[], yticks=[])
            matplotlib_imshow(im, IAMDataset.MEAN, IAMDataset.STD)
            ax.set_title(f"Pred: {pred_str} (CER: {cer:.2f})\nTarget: {target_str}")

        # Use a key distinct from LogModelPredictions' "predictions vs targets"
        # so the two figures do not end up in the same W&B panel.
        trainer.logger.experiment.log({f"{mode}: worst predictions": wandb.Image(fig)})

        plt.close(fig)

        print("Done.")

    @staticmethod
    def _load_best_model(trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """Copy the weights of the best checkpoint into ``pl_module`` in place."""
        ckpt_callback = next(
            (cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), None
        )
        assert ckpt_callback is not None, "ModelCheckpoint not found in callbacks."
        best_model_path = ckpt_callback.best_model_path
        if not best_model_path:
            # No checkpoint has been written yet; keep the current weights.
            print("No best model checkpoint found; using the current weights.")
            return

        print(f"Loading best model at {best_model_path}")
        best = type(pl_module).load_from_checkpoint(
            best_model_path,
            map_location="cpu",
            label_encoder=pl_module.model.label_encoder,
        )
        # `trainer.model` may be a strategy wrapper with prefixed state-dict
        # keys, so load into the LightningModule itself.
        pl_module.load_state_dict(best.state_dict())


class LogModelPredictions(Callback):
    """
    Use fixed batches to monitor model predictions at the end of every epoch.

    At the end of every validation epoch (and of every training epoch, if
    ``train_batch`` is given) the module predicts on the stored batch. A
    matplotlib figure showing each image with the network's prediction
    alongside the actual target is then logged through
    ``trainer.logger.experiment``.

    Args:
        label_encoder: label parser used to decode predictions and targets.
        val_batch: ``(imgs, targets)`` batch used after validation epochs.
        use_gpu: kept for backward compatibility only; images are always moved
            to the LightningModule's own device (CPU, CUDA or TPU).
        data_format: ``"word"`` lays the figure out in two columns, any other
            format in one.
        train_batch: optional ``(imgs, targets)`` batch used after training epochs.
    """

    def __init__(
        self,
        label_encoder: LabelParser,
        val_batch: Tuple[torch.Tensor, torch.Tensor],
        use_gpu: bool = True,
        data_format: str = "word",
        train_batch: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        super().__init__()
        self.label_encoder = label_encoder
        self.val_batch = val_batch
        self.use_gpu = use_gpu
        self.data_format = data_format
        self.train_batch = train_batch

    def on_validation_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ):
        self._predict_intermediate(trainer, pl_module, split="val")

    def on_train_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ):
        if self.train_batch is not None:
            self._predict_intermediate(trainer, pl_module, split="train")

    def _predict_intermediate(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", split="val"
    ):
        """Predict on the fixed ``split`` batch and log a predictions-vs-targets figure."""
        imgs, targets = self.train_batch if split == "train" else self.val_batch

        # Use the module's device instead of a hard-coded `.cuda()`, which
        # fails when training on TPU or CPU.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.inference_mode():
                _, preds, _ = pl_module(imgs.to(pl_module.device))
        finally:
            # Restore the previous mode: this also runs from on_train_epoch_end,
            # and eval mode (dropout disabled) could otherwise carry over into
            # the next training epoch.
            pl_module.train(was_training)

        # Decode predictions and generate a plot.
        ncols = 2 if self.data_format == "word" else 1
        nrows = math.ceil(preds.size(0) / ncols)
        fig = plt.figure(figsize=(12, 16))
        for i, (p, t) in enumerate(zip(preds, targets)):
            # NOTE(review): LogWorstPredictions reads `decoder.eos_tkn_idx` while
            # this reads `decoder.eos_idx`; confirm which one the decoder defines.
            pred_str, target_str = decode_prediction_and_target(
                p, t, self.label_encoder, pl_module.decoder.eos_idx
            )
            ax = fig.add_subplot(nrows, ncols, i + 1, xticks=[], yticks=[])
            matplotlib_imshow(imgs[i], IAMDataset.MEAN, IAMDataset.STD)
            ax.set_title(f"Pred: {pred_str}\nTarget: {target_str}")

        trainer.logger.experiment.log({f"{split}: predictions vs targets": wandb.Image(fig)})

        plt.close(fig)

In [None]:
!pip install wandb --upgrade
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ModelSummary
from torch.utils.data import DataLoader, Subset
from pytorch_lightning.loggers import WandbLogger
import wandb

# Credentials are never committed: wandb.login() reads the WANDB_API_KEY
# environment variable (or ~/.netrc) and otherwise prompts interactively.
# NOTE(review): the API key previously hard-coded here is exposed in the
# notebook's history and should be revoked/rotated.
wandb.login()
wandb_logger = WandbLogger(project="bach_thesis", log_model="all")


def _random_batch(dataset, num_samples=1):
    """Collate ``num_samples`` randomly chosen samples of ``dataset`` into one batch.

    Same result as taking the first batch of a DataLoader over a random
    ``Subset`` with ``collate_fn``, without spawning worker processes.
    """
    indices = random.sample(range(len(dataset)), num_samples)
    return collate_fn([dataset[i] for i in indices])


callbacks = [
    ModelCheckpoint(
        save_top_k=3,
        mode="min",
        monitor="word_error_rate",
        filename="{epoch}-{char_error_rate:.4f}-{word_error_rate:.4f}",
    ),
    ModelSummary(max_depth=2),
    LitProgressBar(),
    LogWorstPredictions(
        dl_synth_train,
        dl_synth_val,
        training_skipped=False,
        data_format="form",
    ),
    LogModelPredictions(
        ds.label_enc,
        val_batch=_random_batch(ds_val),
        train_batch=_random_batch(ds_train),
        data_format="form",
        use_gpu=True,
    ),
    EarlyStopping(
        monitor="word_error_rate",
        patience=50,
        verbose=True,
        mode="min",
        check_on_train_epoch_end=False,
    ),
]

trainer = Trainer(
    max_epochs=3000,
    accelerator="tpu",
    devices=4,
    callbacks=callbacks,
    logger=wandb_logger,
    fast_dev_run=False,
)
model = LitFullPageHTREncoderDecoder(
    ds.label_enc,
    learning_rate=0.0001,
    label_smoothing=0.0,
    max_seq_len=500,
    d_model=260,
    num_layers=6,
    nhead=4,
    dim_feedforward=1024,
    encoder_name="resnet34",
    drop_enc=0.1,
    drop_dec=0.1,
    activ_dec="gelu",
    loss_function="cross_entropy",
    vocab_len=None,
)
trainer.fit(model, dl_synth_train, dl_synth_val)