# Utils 

In [25]:
import pickle
import math
from random import random
import xml.etree.ElementTree as ET
import numpy as np
import albumentations as A
import cv2 as cv
from dataclasses import dataclass, field
from functools import partial
from random import randint
import html
import random
from pathlib import Path
from typing import Union, Tuple, Dict, Sequence, Optional, List, Any, Callable, Optional
import pandas as pd
from torch import Tensor, nn
from torch.utils.data import Dataset
from PIL import Image
from torchmetrics import Metric
import torch
from torchvision import models
from torch.utils.data import Dataset
import editdistance
import wandb
from torch.utils.data.dataloader import DataLoader
from torch.optim import Optimizer
import tqdm

def pickle_load(file) -> Any:
    with open(file, "rb") as f:
        return pickle.load(f)

def pickle_save(obj, file):
    with open(file, "wb") as f:
        pickle.dump(obj, f)

def read_xml(file: Union[Path, str]) -> ET.Element:
    tree = ET.parse(file)
    root = tree.getroot()

    return root

def find_child_by_tag(element: ET.Element, tag: str, value: str) -> Union[ET.Element, None]:
    for child in element:
        if child.get(tag) == value:
            return child
    return None

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def dpi_adjusting(img: np.ndarray, scale: float, **kwargs) -> np.ndarray:
    height, width = img.shape[:2]
    new_height, new_width = math.ceil(height * scale), math.ceil(width * scale)
    return cv.resize(img, (new_width, new_height))

In [26]:
class LabelParser:
    def __init__(self):
        self.classes = None
        self.vocab_size = None
        self.class_to_idx = None
        self.idx_to_class = None

    def fit(self, classes: Sequence[str]):
        self.classes = list(classes)
        self.vocab_size = len(classes)
        self.idx_to_class = dict(enumerate(classes))
        self.class_to_idx = {cls: i for i, cls in self.idx_to_class.items()}

        return self

    def addClasses(self, classes: List[str]):
        all_classes = sorted(set(self.classes + classes))

        self.fit(all_classes)

    def encode_labels(self, sequence: Sequence[str]):
        self._check_fitted()
        return [self.class_to_idx[c] for c in sequence]

    def decode_labels(self, sequence: Sequence[int]):
        self._check_fitted()
        return [self.idx_to_class[c] for c in sequence]

    def _check_fitted(self):
        if self.classes is None:
            raise ValueError("LabelParser class was not fitted yet")

## Image transformations

In [27]:


class SafeRandomScale(A.RandomScale):
    def apply(self, img, scale=0, interpolation=cv.INTER_LINEAR, **params):
        height, width = img.shape[:2]
        new_height, new_width = int(height * scale), int(width * scale)
        if new_height <= 0 or new_width <= 0:
            return img
        return super().apply(img, scale, interpolation, **params)

def adjust_dpi(img: np.ndarray, scale: float, **kwargs):
    height, width = img.shape
    new_height, new_width = math.ceil(height * scale), math.ceil(width * scale)
    return cv.resize(img, (new_width, new_height))

def randomly_displace_and_pad(
    img: np.ndarray,
    padded_size: Tuple[int, int],
    crop_if_necessary: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    Randomly displace an image within a frame, and pad zeros around the image.

    Args:
        img (np.ndarray): image to process
        padded_size (Tuple[int, int]): (height, width) tuple indicating the size of the frame
        crop_if_necessary (bool): whether to crop the image if its size exceeds that
            of the frame
    """
    frame_h, frame_w = padded_size
    img_h, img_w = img.shape
    if frame_h < img_h or frame_w < img_w:
        if crop_if_necessary:
            print(
                "WARNING (`randomly_displace_and_pad`): cropping input image before "
                "padding because it exceeds the size of the frame."
            )
            img_h, img_w = min(img_h, frame_h), min(img_w, frame_w)
            img = img[:img_h, :img_w]
        else:
            raise AssertionError(
                f"Frame is smaller than the image: ({frame_h}, {frame_w}) vs. ({img_h},"
                f" {img_w})"
            )

    res = np.zeros((frame_h, frame_w), dtype=img.dtype)

    pad_top =  randint(0, frame_h - img_h)
    pad_bottom = pad_top + img_h
    pad_left = randint(0, frame_w - img_w)
    pad_right = pad_left + img_w

    res[pad_top:pad_bottom, pad_left:pad_right] = img
    return res

@dataclass
class ImageTransforms:
    max_img_size: Tuple[int, int]  # (h, w)
    normalize_params: Tuple[float, float]  # (mean, std)
    scale: float = (
        0.5
    )
    random_scale_limit: float = 0.1
    random_rotate_limit: int = 10

    train_trnsf: A.Compose = field(init=False)
    test_trnsf: A.Compose = field(init=False)

    def __post_init__(self):
        scale, random_scale_limit, random_rotate_limit, normalize_params =(
            self.scale,
            self.random_scale_limit,
            self.random_rotate_limit,
            self.normalize_params
        )

        max_img_h, max_img_w = self.max_img_size
        max_scale = scale + scale * random_scale_limit
        padded_h, padded_w = math.ceil(max_scale * max_img_h), math.ceil(max_scale * max_img_w)

        self.train_trnsf = A.Compose([
            A.Lambda(partial(adjust_dpi, scale=scale)),
            SafeRandomScale(scale_limit=random_scale_limit, p=0.5),
            A.SafeRotate(
                limit = random_rotate_limit,
                border_mode = cv.BORDER_CONSTANT,
                value = 0
            ),
            A.RandomBrightnessContrast(),
            A.Perspective(scale=(0.01, 0.05)),
            A.GaussNoise(),
            A.Normalize(*normalize_params),
            A.Lambda(
                image=partial(
                    randomly_displace_and_pad,
                    padded_size=(padded_h, padded_w),
                    crop_if_necessary=False,
                )
            )
        ])

        self.test_trnsf = A.Compose([
            A.Lambda(partial(adjust_dpi, scale=scale)),
            A.Normalize(*normalize_params),
            A.PadIfNeeded(
                max_img_h, max_img_w, border_mode=cv.BORDER_CONSTANT, value=0
            )
        ])


## IAM Dataset and Synthetic Dataset

In [28]:

class IAMDataset(Dataset):
    MEAN = 0.8275
    STD = 0.2314
    MAX_FORM_HEIGHT = 3542
    MAX_FORM_WIDTH = 2479

    MAX_SEQ_LENS = {
        "word": 55,
        "line": 90,
        "form": 700,
    }  # based on the maximum seq lengths found in the dataset

    _pad_token = "<PAD>"
    _sos_token = "<SOS>"
    _eos_token = "<EOS>"

    root: Path
    data: pd.DataFrame
    label_enc: LabelParser
    parse_method: str
    only_lowercase: bool
    transforms: Optional[A.Compose]
    id_to_idx: Dict[str, int]
    _split: str
    _return_writer_id: Optional[bool]

    def __init__(
        self,
        root: Union[Path, str],
        parse_method: str,
        split: str,
        return_writer_id: bool = False,
        only_lowercase: bool = False,
        label_enc: Optional[LabelParser] = None,
    ):
        super().__init__()
        _parse_methods = ["form", "line", "word"]
        err_message = (
            f"{parse_method} is not a possible parsing method: {_parse_methods}"
        )
        assert parse_method in _parse_methods, err_message

        _splits = ["train", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message

        self._split = split
        self._return_writer_id = return_writer_id
        self.only_lowercase = only_lowercase
        self.root = Path(root)
        self.label_enc = label_enc
        self.parse_method = parse_method

        # Process the data.
        if not hasattr(self, "data"):
            if self.parse_method == "form":
                self.data = self._get_forms()
            elif self.parse_method == "word":
                self.data = self._get_words(skip_bad_segmentation=True)
            elif self.parse_method == "line":
                self.data = self._get_lines()

        # Create the label encoder.
        if self.label_enc is None:
            vocab = [self._pad_token, self._sos_token, self._eos_token]
            s = "".join(self.data["target"].tolist())
            if self.only_lowercase:
                s = s.lower()
            vocab += sorted(list(set(s)))
            self.label_enc = LabelParser().fit(vocab)
        if not "target_enc" in self.data.columns:
            self.data.insert(
                2,
                "target_enc",
                self.data["target"].apply(
                    lambda s: np.array(
                        self.label_enc.encode_labels(
                            [c for c in (s.lower() if self.only_lowercase else s)]
                        )
                    )
                ),
            )

        self.transforms = self._get_transforms(split)
        self.id_to_idx = {
            Path(self.data.iloc[i]["img_path"]).stem: i for i in range(len(self))
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data.iloc[idx]
        img = cv.imread(data["img_path"], cv.IMREAD_GRAYSCALE)
        if all(col in data.keys() for col in ["bb_y_start", "bb_y_end"]):
            # Crop the image vertically.
            img = img[data["bb_y_start"] : data["bb_y_end"], :]
        assert isinstance(img, np.ndarray), (
            f"Error: image at path {data['img_path']} is not properly loaded. "
            f"Is there something wrong with this image?"
        )
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
        if self._return_writer_id:
            return img, data["writer_id"], data["target_enc"]
        return img, data["target_enc"]

    def get_max_height(self):
        return (self.data["bb_y_end"] - self.data["bb_y_start"]).max()

    @property
    def vocab(self):
        return self.label_enc.classes

    @staticmethod
    def collate_fn(
        batch: Sequence[Tuple[np.ndarray, np.ndarray]],
        pad_val: int,
        eos_tkn_idx: int,
        dataset_returns_writer_id: bool = False,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if dataset_returns_writer_id:
            imgs, writer_ids, targets = zip(*batch)
        else:
            imgs, targets = zip(*batch)

        img_sizes = [im.shape for im in imgs]
        if (
            not len(set(img_sizes)) == 1
        ):  # images are of varying sizes, so pad them to the maximum size in the batch
            hs, ws = zip(*img_sizes)
            pad_fn = A.PadIfNeeded(
                max(hs), max(ws), border_mode=cv.BORDER_CONSTANT, value=0
            )
            imgs = [pad_fn(image=im)["image"] for im in imgs]
        imgs = np.stack(imgs, axis=0)

        seq_lengths = [t.shape[0] for t in targets]
        targets_padded = np.full((len(targets), max(seq_lengths) + 1), pad_val)
        for i, t in enumerate(targets):
            targets_padded[i, : seq_lengths[i]] = t
            targets_padded[i, seq_lengths[i]] = eos_tkn_idx

        imgs, targets_padded = torch.tensor(imgs), torch.tensor(targets_padded)
        if dataset_returns_writer_id:
            return imgs, targets_padded, torch.tensor(writer_ids)
        return imgs, targets_padded

    def set_transforms_for_split(self, split: str):
        _splits = ["train", "val", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message
        self.transforms = self._get_transforms(split)

    def _get_transforms(self, split: str) -> A.Compose:
        max_img_w = self.MAX_FORM_WIDTH

        if self.parse_method == "form":
            max_img_h = (self.data["bb_y_end"] - self.data["bb_y_start"]).max()
        else:  # word or line
            max_img_h = self.MAX_FORM_HEIGHT

        transforms = ImageTransforms(
            (max_img_h, max_img_w), (IAMDataset.MEAN, IAMDataset.STD)
        )

        if split == "train":
            return transforms.train_trnsf
        elif split == "test" or split == "val":
            return transforms.test_trnsf

    def statistics(self) -> Dict[str, float]:
        assert len(self) > 0
        tmp = self.transforms
        self.transforms = None
        mean, std, cnt = 0, 0, 0
        for img, _ in self:
            mean += np.mean(img)
            std += np.var(img)
            cnt += 1
        mean /= cnt
        std = np.sqrt(std / cnt)
        self.transforms = tmp
        return {"mean": mean, "std": std}

    def _get_forms(self) -> pd.DataFrame:
        """Read all form images from the IAM dataset.

        Returns:
            pd.DataFrame
                A pandas dataframe containing the image path, image id, target, vertical
                upper bound, vertical lower bound, and target length.
        """
        data = {
            "img_path": [],
            "img_id": [],
            "target": [],
            "bb_y_start": [],
            "bb_y_end": [],
            "target_len": [],
        }
        for form_dir in ["formsA-D", "formsE-H", "formsI-Z"]:
            dr = self.root / form_dir
            for img_path in dr.iterdir():
                doc_id = img_path.stem
                xml_root = read_xml(self.root / "xml" / (doc_id + ".xml"))

                # Based on some empiricial evaluation, the 'asy' and 'dsy'
                # attributes of a line xml tag seem to correspond to its upper and
                # lower bound, respectively. We add padding of 150 pixels.
                bb_y_start = int(xml_root[1][0].get("asy")) - 150
                bb_y_end = int(xml_root[1][-1].get("dsy")) + 150

                form_text = []
                for line in xml_root.iter("line"):
                    form_text.append(html.unescape(line.get("text", "")))

                img_w, img_h = Image.open(str(img_path)).size
                data["img_path"].append(str(img_path))
                data["img_id"].append(doc_id)
                data["target"].append("\n".join(form_text))
                data["bb_y_start"].append(bb_y_start)
                data["bb_y_end"].append(bb_y_end)
                data["target_len"].append(len("\n".join(form_text)))
        return pd.DataFrame(data).sort_values(
            "target_len"
        )  # by default, sort by target length

    def _get_lines(self, skip_bad_segmentation: bool = False) -> pd.DataFrame:
        """Read all line images from the IAM dataset.

        Args:
            skip_bad_segmentation (bool): skip lines that have the
                segmentation='err' xml attribute
        Returns:
            List of 2-tuples, where each tuple contains the path to a line image
            along with its ground truth text.
        """
        data = {"img_path": [], "img_id": [], "target": []}
        root = self.root / "lines"
        for d1 in root.iterdir():
            for d2 in d1.iterdir():
                doc_id = d2.name
                xml_root = read_xml(self.root / "xml" / (doc_id + ".xml"))
                for img_path in d2.iterdir():
                    target = self._find_line(
                        xml_root, img_path.stem, skip_bad_segmentation
                    )
                    if target is not None:
                        data["img_path"].append(str(img_path.resolve()))
                        data["img_id"].append(doc_id)
                        data["target"].append(target)
        return pd.DataFrame(data)

    def _get_words(self, skip_bad_segmentation: bool = False) -> pd.DataFrame:
        """Read all word images from the IAM dataset.

        Args:
            skip_bad_segmentation (bool): skip lines that have the
                segmentation='err' xml attribute
        Returns:
            List of 2-tuples, where each tuple contains the path to a word image
            along with its ground truth text.
        """
        data = {"img_path": [], "img_id": [], "writer_id": [], "target": []}
        root = self.root / "words"
        for d1 in root.iterdir():
            if d1.is_file():
                continue

            for d2 in d1.iterdir():

                doc_id = d2.name
                xml_root = read_xml(self.root / "xml" / (doc_id + ".xml"))
                writer_id = int(xml_root.get("writer-id"))
                for img_path in d2.iterdir():
                    img = cv.imread(str(img_path.resolve()), cv.IMREAD_GRAYSCALE)
                    if isinstance(img, np.ndarray):
                        target = self._find_word(
                            xml_root, img_path.stem, skip_bad_segmentation
                        )
                        if target is not None:
                            data["img_path"].append(str(img_path.resolve()))
                            data["img_id"].append(doc_id)
                            data["writer_id"].append(writer_id)
                            data["target"].append(target)
        return pd.DataFrame(data)

    def _find_line(
        self,
        xml_root: ET.Element,
        line_id: str,
        skip_bad_segmentation: bool = False,
    ) -> Union[str, None]:
        line = find_child_by_tag(xml_root[1].findall("line"), "id", line_id)
        if line is not None and not (
            skip_bad_segmentation and line.get("segmentation") == "err"
        ):
            return html.unescape(line.get("text"))
        return None

    def _find_word(
        self,
        xml_root: ET.Element,
        word_id: str,
        skip_bad_segmentation: bool = False,
    ) -> Union[str, None]:
        line_id = "-".join(word_id.split("-")[:-1])
        line = find_child_by_tag(xml_root[1].findall("line"), "id", line_id)
        if line is not None and not (
            skip_bad_segmentation and line.get("segmentation") == "err"
        ):
            word = find_child_by_tag(line.findall("word"), "id", word_id)
            if word is not None:
                return html.unescape(word.get("text"))
        return None

class IAMSyntheticDataGenerator(Dataset):
    """
    Data generator that creates synthetic line/form images by stitching together word
    images from the IAM dataset.
    Calling `__getitem__()` samples a newly generated synthetic image every time
    it is called.
    """

    PUNCTUATION = [",", ".", ";", ":", "'", '"', "!", "?"]

    def __init__(
        self,
        iam_root: Union[str, Path],
        label_encoder: Optional[LabelParser] = None,
        transforms: Optional[A.Compose] = None,
        line_width: Tuple[int, int] = (1500, 2000),
        lines_per_form: Tuple[int, int] = (1, 11),
        words_per_line: Tuple[int, int] = (4, 10),
        words_per_sequence: Tuple[int, int] = (7, 13),
        px_between_lines: Tuple[int, int] = (25, 50),
        px_between_words: int = 50,
        px_around_image: Tuple[int, int] = (100, 200),
        sample_form: bool = False,
        only_lowercase: bool = False,
        rng_seed: int = 0,
        max_height: Optional[int] = None,
    ):
        super().__init__()
        self.iam_root = iam_root
        self.label_enc = label_encoder
        self.transforms = transforms
        self.line_width = line_width
        self.lines_per_form = lines_per_form
        self.words_per_line = words_per_line
        self.words_per_sequence = words_per_sequence
        self.px_between_lines = px_between_lines
        self.px_between_words = px_between_words
        self.px_around_image = px_around_image
        self.sample_form = sample_form
        self.only_lowercase = only_lowercase
        self.rng_seed = rng_seed
        self.max_height = max_height

        self.iam_words = IAMDataset(
            iam_root,
            "word",
            "test",
            only_lowercase=only_lowercase,
        )
        if self.max_height is None:
            self.max_height = IAMDataset.MAX_FORM_HEIGHT
        if sample_form and "\n" not in self.label_encoder.classes:
            # Add the `\n` token to the label encoder (since forms can contain newlines)
            self.label_encoder.addClasses(["\n"])
        self.iam_words.transforms = None
        self.rng = np.random.default_rng(rng_seed)

    def __len__(self):
        # This dataset does not have a finite length since it can generate random
        # images at will, so return 1.
        return 1

    @property
    def label_encoder(self):
        if self.label_enc is not None:
            return self.label_enc
        return self.iam_words.label_enc

    def __getitem__(self, *args, **kwargs):
        """By calling this method, a newly generated synthetic image is sampled."""
        if self.sample_form:
            img, target = self.generate_form()
        else:
            img, target = self.generate_line()
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
        # Encode the target sequence using the label encoder.
        target_enc = np.array(self.label_encoder.encode_labels([c for c in target]))
        return img, target_enc

    def generate_line(self) -> Tuple[np.ndarray, str]:
        words_to_sample = self.rng.integers(*self.words_per_line)
        line_width = self.rng.integers(*self.line_width)
        return self.sample_lines(words_to_sample, line_width, sample_one_line=True)

    def generate_form(self) -> Tuple[np.ndarray, str]:
        # Randomly pick the number of words and inter-line distance in the form.
        words_to_sample = self.rng.integers(*self.lines_per_form) * 5  # 7 is handpicked
        px_between_lines = self.rng.integers(*self.px_between_lines)

        # Sample line images.
        line_width = self.rng.integers(*self.line_width)
        lines, target = self.sample_lines(words_to_sample, line_width)

        # Concatenate the lines vertically.
        form_w = max(l.shape[1] for l in lines)
        form_h = sum(l.shape[0] + px_between_lines for l in lines)
        if form_h > self.max_height:
            print(
                "Generated form height exceeds maximum height. Generating a new form."
            )
            return self.generate_form()
        form = np.ones((form_h, form_w), dtype=lines[0].dtype) * 255
        curr_h = 0
        for line_img in lines:
            h, w = line_img.shape
            if curr_h + h + px_between_lines > self.max_height:
                break

            form[curr_h : curr_h + h, :w] = line_img
            curr_h += h + px_between_lines

        # Add a random amount of padding around the image.
        pad_px = self.rng.integers(*self.px_around_image)
        new_h, new_w = form.shape[0] + pad_px * 2, form.shape[1] + pad_px * 2
        form = A.PadIfNeeded(
            new_h, new_w, border_mode=cv.BORDER_CONSTANT, value=255, always_apply=True
        )(image=form)["image"]

        return form, target

    def set_rng(self, seed: int):
        self.rng = np.random.default_rng(seed)

    def sample_word_image(self) -> Tuple[np.ndarray, str]:
        idx = random.randint(0, len(self.iam_words) - 1)
        img, target = self.iam_words[idx]
        target = "".join(self.iam_words.label_enc.decode_labels(target))
        return img, target

    def sample_word_image_sequence(
        self, words_to_sample: int
    ) -> List[Tuple[np.ndarray, str]]:
        """Sample a sequence of contiguous words."""
        assert words_to_sample >= 1
        start_idx = random.randint(0, len(self.iam_words) - 1)

        img_idxs = [start_idx]
        img_path = Path(self.iam_words.data.iloc[start_idx]["img_path"])
        _, _, line_id, word_id = img_path.stem.split("-")
        sampled_words = 1
        while sampled_words < words_to_sample:
            word_id = f"{int(word_id) + 1 :02}"
            img_name = (
                "-".join(img_path.stem.split("-")[:-2] + [line_id, word_id]) + ".png"
            )
            if not (img_path.parent / img_name).is_file():
                # Previous image was the last on its line. Go to the next line.
                line_id = f"{int(line_id) + 1 :02}"
                word_id = "00"
                img_name = (
                    "-".join(img_path.stem.split("-")[:-2] + [line_id, word_id])
                    + ".png"
                )
            if not (img_path.parent / img_name).is_file():
                # End of the document.
                return self.sample_word_image_sequence(words_to_sample)
            # Find the dataset index for the sampled word.
            ix = self.iam_words.id_to_idx.get(Path(img_name).stem)
            if ix is None:
                # If the image has segmentation=err attribute, it will
                # not be in the dataset. In this case try again.
                return self.sample_word_image_sequence(words_to_sample)
            img_idxs.append(ix)
            sampled_words += 1

        imgs, targets = zip(*[self.iam_words[idx] for idx in img_idxs])
        targets = [
            "".join(self.iam_words.label_enc.decode_labels(t)) for t in targets
        ]
        return list(zip(imgs, targets))

    def sample_lines(
        self, words_to_sample: int, max_line_width: int, sample_one_line: bool = False
    ) -> Tuple[Union[List[np.ndarray], np.ndarray], str]:
        """
        Calls `sample_word_image_sequence` several times, using some heuristics
        to glue the sequences together.

        Returns:
            - list of line images
            - transcription for all lines combined
        """
        curr_pos, sampled_words = 0, 0
        imgs, targets, lines = [], [], []
        target_str, last_target = "", ""

        # Sample images.
        while sampled_words < words_to_sample:
            words_per_seq = self.rng.integers(*self.words_per_sequence)
            # Sample a sequence of contiguous words.
            img_tgt_seq = self.sample_word_image_sequence(words_per_seq)
            for i, (img, tgt) in enumerate(img_tgt_seq):
                # Add the sequence to the sampled words so far.
                if sampled_words >= words_to_sample:
                    break
                h, w = img.shape

                if curr_pos + w > max_line_width:
                    # Concatenate the sampled images into a line.
                    line = self.concatenate_line(imgs, targets, max_line_width)

                    if sample_one_line:
                        return line, target_str

                    lines.append(line)
                    target_str += "\n"
                    last_target = "\n"
                    curr_pos = 0
                    imgs, targets = [], []

                # Basic heuristics to avoid some strange looking sentences.
                if i == 0 and (
                    (last_target in self.PUNCTUATION and tgt in self.PUNCTUATION)
                    or (tgt in self.PUNCTUATION and sampled_words == 0)
                ):
                    continue

                if (
                    sampled_words == 0
                    or tgt in [c for c in self.PUNCTUATION if c not in ["'", '"']]
                    or last_target == "\n"
                ):
                    target_str += tgt
                else:
                    target_str += " " + tgt

                targets.append(tgt)
                imgs.append(img)

                sampled_words += 1
                last_target = tgt
                if tgt in self.PUNCTUATION:
                    # Reduce horizontal spacing for punctuation tokens.
                    curr_pos = max(0, curr_pos - self.px_between_words)
                curr_pos += w + self.px_between_words
        if imgs and targets:
            # Concatenate the remaining images into a new line.
            line = self.concatenate_line(imgs, targets, max_line_width)
            lines.append(line)
            if sample_one_line:
                return line, target_str
        return lines, target_str

    def concatenate_line(
        self, imgs: List[np.ndarray], targets: List[str], line_width: int
    ) -> np.ndarray:
        """
        Concatenate a series of (img, target) tuples into a line to create a line image.
        """
        assert len(imgs) == len(targets)

        line_height = max(im.shape[0] for im in imgs)
        line = np.ones((line_height, line_width), dtype=imgs[0].dtype) * 255

        curr_pos = 0
        prev_lower_bound = line_height
        for img, tgt in zip(imgs, targets):
            h, w = img.shape
            # Center the image in the middle of the line.
            start_h = min(max(0, int((line_height - h) / 2)), line_height - h)

            if tgt in [",", "."]:
                # If sampled a comma or dot, place them at the bottom of the line.
                start_h = min(max(0, prev_lower_bound - int(h / 2)), line_height - h)
            elif tgt in ['"', "'"]:
                # If sampled a quote, place them at the top of the line.
                start_h = 0
            if tgt in self.PUNCTUATION:
                # Reduce horizontal spacing for punctuation tokens.
                curr_pos = max(0, curr_pos - self.px_between_words)

            assert curr_pos + w <= line_width, f"{curr_pos + w} > {line_width}"
            assert start_h + h <= line_height, f"{start_h + h} > {line_height}"

            # Concatenate the word image to the line.
            line[start_h : start_h + h, curr_pos : curr_pos + w] = img

            curr_pos += w + self.px_between_words
            prev_lower_bound = start_h + h
        return line

    @staticmethod
    def get_worker_init_fn():
        def worker_init_fn(worker_id: int):
            set_seed(worker_id)
            worker_info = torch.utils.data.get_worker_info()
            dataset = worker_info.dataset  # the dataset copy in this worker process
            if hasattr(dataset, "set_rng"):
                dataset.set_rng(worker_id)
            else:  # dataset is instance of `IAMDatasetSynthetic` class
                dataset.synth_dataset.set_rng(worker_id)

        return worker_init_fn



class IAMDatasetSynthetic(Dataset):
    """
    A Pytorch dataset combining the IAM dataset with the IAMSyntheticDataGenerator
    dataset.

    The distribution of real/synthetic images can be controlled by setting the
    `synth_prob` argument.
    """

    iam_dataset: IAMDataset
    synth_dataset: IAMSyntheticDataGenerator

    def __init__(self, iam_dataset: IAMDataset, nr_of_samples: int, **kwargs):
        """
        Args:
            iam_dataset (Dataset): the IAM dataset to sample from
            synth_prob (float): the probability of sampling a synthetic image when
                calling `__getitem__()`.
        """
        self.iam_dataset = iam_dataset
        self.nr_of_samples = nr_of_samples
        self.synth_dataset = IAMSyntheticDataGenerator(
            iam_root=iam_dataset.root,
            label_encoder=iam_dataset.label_enc,
            transforms=iam_dataset.transforms,
            sample_form=(True if iam_dataset.parse_method == "form" else False),
            only_lowercase=iam_dataset.only_lowercase,
            max_height=(
                (iam_dataset.data["bb_y_end"] - iam_dataset.data["bb_y_start"]).max()
                if iam_dataset.parse_method == "form"
                else None
            ),
            **kwargs,
        )

    def __getitem__(self, idx):
        iam = self.iam_dataset
        if len(self.iam_dataset) <= idx < self.nr_of_samples:
            # Sample from the synthetic dataset.
            img, target = self.synth_dataset[0]
        else:
            # Index the IAM dataset.
            img, target = iam[idx]
        assert not np.any(np.isnan(img)), img
        return img, target

    def __len__(self):
        return self.nr_of_samples

## Metrics

In [29]:

class CharacterErrorRate(Metric):

    def __init__(self, label_encoder: LabelParser):
        super().__init__()
        self.label_encoder = label_encoder

        self.add_state("edits", default=torch.Tensor([0]), dist_reduce_fx="sum")
        self.add_state("total_chars", default=torch.Tensor([0]), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.ndim == target.ndim

        eos_tkn_idx, sos_tkn_idx = list(
            self.label_encoder.encode_labels(["<EOS>", "<SOS>"])
        )

        if (preds[:, 0] == sos_tkn_idx).all():  # this should normally be the case
            preds = preds[:, 1:]

        eos_idxs_prd = (preds == eos_tkn_idx).float().argmax(1).tolist()
        eos_idxs_tgt = (target == eos_tkn_idx).float().argmax(1).tolist()

        for i, (p, t) in enumerate(zip(preds, target)):
            eos_idx_p, eos_idx_t = eos_idxs_prd[i], eos_idxs_tgt[i]
            p = p[:eos_idx_p] if eos_idx_p else p
            t = t[:eos_idx_t] if eos_idx_t else t
            p_str, t_str = map(tensor_to_str, (p, t))
            editd = editdistance.eval(p_str, t_str)

            self.edits += editd
            self.total_chars += t.numel()

    def compute(self) -> torch.Tensor:
        return self.edits.float() / self.total_chars

class WordErrorRate(Metric):
    def __init__(self, label_encoder: LabelParser):
        super().__init__()
        self.label_encoder = label_encoder

        self.add_state("edits", default=torch.Tensor([0]), dist_reduce_fx="sum")
        self.add_state("total_words", default=torch.Tensor([0]), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.ndim == target.ndim

        eos_tkn_idx, sos_tkn_idx = self.label_encoder.encode_labels(["<EOS>", "<SOS>"])

        if (preds[:, 0] == sos_tkn_idx).all():
            preds = preds[:, 1:]

        eos_idxs_prd = (preds == eos_tkn_idx).float().argmax(1).tolist()
        eos_idxs_tgt = (target == eos_tkn_idx).float().argmax(1).tolist()

        for i, (p, t) in enumerate(zip(preds, target)):
            eos_idx_p, eos_idx_t = eos_idxs_prd[i], eos_idxs_tgt[i]
            p = (p[:eos_idx_p] if eos_idx_p else p).flatten().tolist()
            t = (t[:eos_idx_t] if eos_idx_t else t).flatten().tolist()
            p_words = "".join(self.label_encoder.decode_labels(p)).split()
            t_words = "".join(self.label_encoder.decode_labels(t)).split()
            editd = editdistance.eval(p_words, t_words)

            self.edits += editd
            self.total_words += len(t_words)

    def compute(self) -> torch.Tensor:
        """Compute Word Error Rate."""
        return self.edits.float() / self.total_words

def tensor_to_str(t: torch.Tensor) -> str:
    return "".join(map(str, t.flatten().tolist()))

## Model class

In [30]:

class PosEmbedding1D(nn.Module):
    """
    Implements 1D sinusoidal embeddings.

    Adapted from 'The Annotated Transformer', http://nlp.seas.harvard.edu/2018/04/03/attention.html
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros((max_len, d_model), requires_grad=False)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Add a 1D positional embedding to an input tensor.

        Args:
            x (Tensor): tensor of shape (B, T, d_model) to add positional
                embedding to
        """
        _, T, _ = x.shape
        # assert T <= self.pe.size(0) \
        assert T <= self.pe.size(1), (
            f"Stored 1D positional embedding does not have enough dimensions for the current feature map. "
            f"Currently max_len={self.pe.size(1)}, T={T}. Consider increasing max_len such that max_len >= T."
        )
        return x + self.pe[:, :T]



class PosEmbedding2D(nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        self.d_model = d_model
        pe_x = torch.zeros((max_len, d_model // 2), requires_grad=False)
        pe_y = torch.zeros((max_len, d_model // 2), requires_grad=False)

        pos = torch.arange(0, max_len).unsqueeze(1)

        div_term = torch.exp(
            -math.log(10000.0) * torch.arange(0, d_model // 2, 2) / d_model
        )

        pe_y[:, 0::2] = torch.sin(pos * div_term)
        pe_y[:, 1::2] = torch.cos(pos * div_term)
        pe_x[:, 0::2] = torch.sin(pos * div_term)
        pe_x[:, 1::2] = torch.cos(pos * div_term)

        self.register_buffer("pe_x", pe_x)
        self.register_buffer("pe_y", pe_y)

    def forward(self, x):
        _, w, h, _ = x.shape

        pe_x_ = self.pe_x[:w, :].unsqueeze(1).expand(-1, h, -1)
        pe_y_ = self.pe_y[:h, :].unsqueeze(0).expand(w, -1, -1)

        pe = torch.cat([pe_y_, pe_x_], -1)
        pe = pe.unsqueeze(0)

        return x + pe


class encoderHTR(nn.Module):
    def __init__(self, d_model: int, encoder_type: str, dropout=0.1, bias=True):
        super().__init__()
        assert encoder_type in ["resnet18", "resnet34", "resnet50"], "Model not found"

        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.pos_embd = PosEmbedding2D(d_model)

        resnet = getattr(models, encoder_type)(pretrained=False)

        modules = list(resnet.children())
        cnv_1 = modules[0]
        cnv_1 = nn.Conv2d(
            1,
            cnv_1.out_channels,
            cnv_1.kernel_size,
            cnv_1.stride,
            cnv_1.padding,
            bias=cnv_1.bias
        )
        self.encoder = nn.Sequential(cnv_1, *modules[1:-2])
        self.linear = nn.Conv2d(resnet.fc.in_features, d_model, kernel_size=1)

    def forward(self, imgs):
        x = self.encoder(imgs.unsqueeze(1))
        x = self.linear(x).transpose(1, 2).transpose(2, 3)
        x = self.pos_embd(x)
        x = self.dropout(x)
        x = x.flatten(1, 2)

        return x


class decoderHTR(nn.Module):
    def __init__(self,
                 vocab_length,
                 max_seq_len,
                 eos_tkn_idx,
                 sos_tkn_idx,
                 pad_tkn_idx,
                 d_model,
                 num_layers,
                 nhead,
                 dim_ffn,
                 dropout,
                 activation="relu"):
        super().__init__()
        self.vocab_length = vocab_length
        self.max_seq_len = max_seq_len
        self.eos_idx = eos_tkn_idx
        self.sos_idx = sos_tkn_idx
        self.pad_idx = pad_tkn_idx
        self.d_model = d_model
        self.num_layers = num_layers
        self.nhead = nhead
        self.dim_ffn = dim_ffn
        self.pos_emb = PosEmbedding1D(d_model)
        self.emb = nn.Embedding(vocab_length, d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model,
            nhead,
            dim_ffn,
            dropout,
            activation,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.clf = nn.Linear(d_model, vocab_length)
        self.drop = nn.Dropout(dropout)

    def forward(self, memory: torch.Tensor):
        B, _, _ = memory.shape
        all_logits = []
        sampled_ids = [torch.full([B], self.sos_idx).to(memory.device)]
        tgt = self.pos_emb(
            self.emb(sampled_ids[0]).unsqueeze(1) * math.sqrt(self.d_model)
        )
        tgt = self.drop(tgt)
        eos_sampled = torch.zeros(B).bool()

        for t in range(self.max_seq_len):
            tgt_mask = self.subsequent_mask(len(sampled_ids)).to(memory.device)
            out = self.decoder(tgt, memory, tgt_mask=tgt_mask)
            logits = self.clf(out[:, -1, :])
            _, pred = torch.max(logits, -1)
            all_logits.append(logits)
            sampled_ids.append(pred)
            for i, pr in enumerate(pred):
                if pr == self.eos_idx:
                    eos_sampled[i] = True
            if eos_sampled.all():
                break

            tgt_ext = self.drop(
                self.pos_emb.pe[:, len(sampled_ids)]
                + self.emb(pred) * math.sqrt(self.d_model)
            ).unsqueeze(1)
            tgt = torch.cat([tgt, tgt_ext], 1)
        sampled_ids = torch.stack(sampled_ids, 1)
        all_logits = torch.stack(all_logits, 1)

        eos_idxs = (sampled_ids == self.eos_idx).float().argmax(1)
        for i in range(B):
            if eos_idxs[i] != 0:
                sampled_ids[i, eos_idxs[i] + 1:] = self.pad_idx

        return all_logits, sampled_ids

    def forward_teacher_forcing(self, memory: torch.Tensor, tgt: torch.Tensor):
        B, T = tgt.shape
        tgt = torch.cat(
            [
                torch.full([B], self.sos_idx).unsqueeze(1).to(memory.device),
                tgt[:, :-1]
            ],
            1
        )

        tgt_key_masking = tgt == self.pad_idx
        tgt_mask = self.subsequent_mask(T).to(tgt.device)

        tgt = self.pos_emb(self.emb(tgt) * math.sqrt(self.d_model))
        tgt = self.drop(tgt)
        out = self.decoder(
            tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_masking
        )
        logits = self.clf(out)
        return logits

    @staticmethod
    def subsequent_mask(size: int):
        mask = torch.triu(torch.ones(size, size), diagonal=1)
        return mask == 1


class FullPageHTR(nn.Module):
    encoder: encoderHTR
    decoder: decoderHTR
    cer_metric: CharacterErrorRate
    wer_metric: WordErrorRate
    loss_fn: Callable
    label_encoder: LabelParser

    def __init__(self, label_encoder: LabelParser,
                 max_seq_len=500,
                 d_model=260,
                 num_layers=6,
                 nhead=4,
                 dim_feedforward=1024,
                 encoder_name="resnet18",
                 drop_enc=0.5,
                 drop_dec=0.5,
                 activ_dec="gelu",
                 label_smoothing=0.1,
                 vocab_len: Optional[int] = None):
        super().__init__()
        self.eos_token_idx, self.sos_token_idx, self.pad_token_idx = label_encoder.encode_labels(
            ["<EOS>", "<SOS>", "<PAD>"]
        )

        self.encoder = encoderHTR(d_model, encoder_type=encoder_name, dropout=drop_enc)
        self.decoder = decoderHTR(vocab_length=(vocab_len or len(label_encoder.classes)),
                                  max_seq_len=max_seq_len,
                                  eos_tkn_idx=self.eos_token_idx,
                                  sos_tkn_idx=self.sos_token_idx,
                                  pad_tkn_idx=self.pad_token_idx,
                                  d_model=d_model,
                                  num_layers=num_layers,
                                  nhead=nhead,
                                  dim_ffn=dim_feedforward,
                                  dropout=drop_dec,
                                  activation=activ_dec)
        self.label_encoder = label_encoder
        self.cer_metric = CharacterErrorRate(label_encoder)
        self.wer_metric = WordErrorRate(label_encoder)
        self.loss_fn = nn.CrossEntropyLoss(
            ignore_index=self.pad_token_idx,
            label_smoothing=label_smoothing
        )

    def forward(self, imgs: torch.Tensor, targets: Optional[torch.Tensor] = None):
        logits, sampled_ids = self.decoder(self.encoder(imgs))
        loss = None
        if targets is not None:
            loss = self.loss_fn(
                logits[:, : targets.size(1), :].transpose(1, 2),
                targets[:, : logits.size(1)],
            )
        return logits, sampled_ids, loss

    def forward_teacher_forcing(self, imgs: torch.Tensor, targets: torch.Tensor):
        memory = self.encoder(imgs)
        logits = self.decoder.forward_teacher_forcing(memory, targets)
        loss = self.loss_fn(logits.transpose(1, 2), targets)

        return logits, loss

    def calculate_metrics(self, preds: torch.Tensor, targets: torch.Tensor):
        self.cer_metric.reset()
        self.wer_metric.reset()

        cer = self.cer_metric(preds, targets)
        wer = self.wer_metric(preds, targets)
        return {"CER": cer, "WER": wer}

    def set_num_output_classes(self, n_classes: int):
        old_vocab_len = self.decoder.vocab_length
        self.decoder.vocab_length = n_classes
        self.decoder.clf = nn.Linear(self.decoder.d_model, n_classes)

        new_embs = nn.Embedding(n_classes, self.decoder.d_model)
        with torch.no_grad():
            new_embs.weight[:old_vocab_len] = self.decoder.emb.weight
            self.decoder.emb = new_embs


## Model Trainer Class

In [31]:

class ModelTrainer:

    def __init__(self, run_name: str,
                 model: FullPageHTR,
                 ds_name: str,
                 train_data: DataLoader,
                 val_data: DataLoader,
                 optimizer: Optimizer,
                 num_epochs: int,
                 device: torch.device,
                 normalization_steps: int):

        self.normalization_steps = normalization_steps

        self.model = model
        self.train_data, self.val_data = train_data, val_data
        self.num_epochs = num_epochs
        self.optimizer = optimizer
        self.ds_name = ds_name
        self.run_name = run_name
        self.device = device

    def _init_wandb(self):

        wandb.init(project="fullpage-htr-base",
                   config={
                       "run_name": self.run_name,
                       "learning_rate": self.optimizer.param_groups[0]["lr"],
                       "epochs": self.num_epochs,
                       "dataset": self.ds_name
                   })
        wandb.define_metric('Train')
        wandb.define_metric('Val')

    def train_epoch_ga(self, epoch_nr, ds_size):
      self.model.train()
      total_loss = 0.0
      total_cer = 0.0
      total_wer = 0.0

      nr_batches = 0
      b_loss = 0.0
      b_cer = 0.0
      b_wer = 0.0
      for idx, batch in enumerate(tqdm(self.train_data)):
        inputs, labels = batch
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)

        outputs, loss = self.model.forward_teacher_forcing(inputs, labels)
        loss = loss / (self.normalization_steps * labels.size(0))
        loss.backward()
        b_loss += loss.item()

        _, preds = outputs.max(-1)
        res = self.model.calculate_metrics(preds, labels)

        b_cer += res["CER"] / (self.normalization_steps * labels.size(0))
        b_wer += res["WER"] / (self.normalization_steps * labels.size(0))

        if idx > 0 and (idx % self.normalization_steps == 0 or idx + 1 == len(self.train_data)):

          self.optimizer.step()
          self.optimizer.zero_grad()
          if self.wanda:
            wandb.log({
              'Train Loss': b_loss ,
              'Train CER' : b_cer ,
              'Train WER' : b_wer ,
              'Train': idx + ds_size * epoch_nr
            })

          total_loss += b_loss
          total_cer  += b_cer
          total_wer  += b_wer
          b_cer = 0.0
          b_wer = 0.0
          b_loss = 0.0

      total_loss /= (ds_size // self.normalization_steps)
      total_cer  /= (ds_size // self.normalization_steps)
      total_wer  /= (ds_size // self.normalization_steps)

      return total_loss, total_cer, total_wer




    def train_epoch(self, epoch_nr, ds_size):
      self.model.train()
      total_loss = 0.0
      total_cer = 0.0
      total_wer = 0.0

      nr_batches = 0
      b_cer = 0.0
      b_wer = 0.0
      for i, mb in enumerate(tqdm(self.train_data)):

        inputs, labels = mb
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)

        output_logits, loss = self.model.forward_teacher_forcing(inputs, labels)

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        _, preds = output_logits.max(-1)
        res = self.model.calculate_metrics(preds, labels)


        b_cer = res["CER"]
        b_wer = res["WER"]
        if self.wanda:
          wandb.log({
              'Train Loss': loss.item() / labels.size(0),
              'Train CER' : b_cer / labels.size(0),
              'Train WER' : b_wer / labels.size(0),
              'Train': i + ds_size * epoch_nr
            })

        total_loss += loss.item() / labels.size(0)
        total_cer += b_cer / labels.size(0)
        total_wer += b_wer / labels.size(0)



      total_loss /= ds_size
      total_cer /= ds_size
      total_wer /= ds_size

      return total_loss, total_cer, total_wer

    def val_epoch(self, epoch_nr, ds_size):
      self.model.eval()
      total_loss = 0.0
      total_cer = 0.0
      total_wer = 0.0
      nr_batches = 0
      b_loss = 0.0
      b_cer = 0.0
      b_wer = 0.0
      with torch.no_grad():
        for i, mb in enumerate(tqdm(self.val_data)):

          inputs, labels = mb
          inputs = inputs.to(self.device)
          labels = labels.to(self.device)
          self.optimizer.zero_grad()

          output_logits, _, loss = self.model.forward(inputs, labels)


          b_loss = loss.item()
          _, preds = output_logits.max(-1)
          res = self.model.calculate_metrics(preds, labels)
          b_cer = res["CER"]
          b_wer = res["WER"]
          if self.wanda:
            wandb.log({
                'Val Loss': b_loss / labels.size(0),
                'Val CER' : b_cer / labels.size(0),
                'Val WER' : b_wer / labels.size(0),
                'Val' : i  + ds_size * epoch_nr})

          total_loss += b_loss / labels.size(0)
          total_cer += b_cer / labels.size(0)
          total_wer += b_wer / labels.size(0)


          torch.cuda.empty_cache()


        total_loss /= ds_size
        total_cer /= ds_size
        total_wer /= ds_size
      return total_loss, total_cer, total_wer

    def train(self, train_len, val_len, wanda=True):
        self.wanda = wanda
        if wanda:
          self._init_wandb()
        for i in range(self.num_epochs):
            print(f'#.Epoch {i}')
            torch.cuda.empty_cache()
            train_loss, train_cer, train_wer = self.train_epoch_ga(i, train_len)
            val_loss, val_cer, val_wer = self.val_epoch(i, val_len)
            print(f"Train Loss avg: {train_loss}, Train CER avg: {train_cer}, Train WER avg: {train_wer}")
            print(f"Val Loss avg: {val_loss}, Val CER avg: {val_cer}, Val WER avg: {val_wer}")
        if wanda:
          wandb.finish()

In [32]:
from copy import copy

ds = IAMDataset(root="/Users/tefannastasa/BachelorsWorkspace/BachModels/BachModels/data/raw", label_enc=None, parse_method="form" ,split="train")
ds_train, ds_val = torch.utils.data.random_split(ds, [math.ceil(0.8 * len(ds)), math.floor(0.2 * len(ds))])

ds_val.data = copy(ds)
ds_val.data.set_transforms_for_split("val")
train_len = len(ds_train)
val_len = len(ds_val)

In [33]:

batch_size = 2
pad_tkn_idx, eos_tkn_idx = ds.label_enc.encode_labels(["<PAD>", "<EOS>"])
collate_fn = partial(
        IAMDataset.collate_fn, pad_val=pad_tkn_idx, eos_tkn_idx=eos_tkn_idx
)
num_workers = 1
dl_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
)
dl_val = DataLoader(
    ds_val,
    batch_size=2 * batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
)
train_len //= batch_size
val_len //= 2 * batch_size

In [35]:
import gc

try:
  device = "cuda"
  torch.cuda.empty_cache()
  gc.collect()

  model = FullPageHTR(ds.label_enc).to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
  trainer = ModelTrainer("Testing_run", model, ds_name="IAM_forms" , train_data=dl_train, val_data=dl_val, optimizer=optimizer, num_epochs=100, device=device, normalization_steps=56)

  wandb.finish()
  trainer.train(train_len, val_len, wanda=True)
except RuntimeError:
  del model
  print("Error time!!")



AssertionError: Torch not compiled with CUDA enabled