In [1]:
!pip install editdistance torchmetrics pytorch_lightning

Collecting editdistance
  Downloading editdistance-0.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Downloading editdistance-0.8.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (401 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m401.8/401.8 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: editdistance
Successfully installed editdistance-0.8.1


In [2]:
import pickle
import math
from random import random
import xml.etree.ElementTree as ET
import numpy as np
import albumentations as A
import cv2 as cv
from dataclasses import dataclass, field
from functools import partial
from random import randint
import html
import random
from pathlib import Path
from typing import Union, Tuple, Dict, Sequence, Optional, List, Any, Callable, Optional
import pandas as pd
from torch import Tensor, nn
from torch.utils.data import Dataset
from PIL import Image
from torchmetrics import Metric
import torch
from torchvision import models
from torch.utils.data import Dataset
import editdistance
import wandb
from torch.utils.data.dataloader import DataLoader
from torch.optim import Optimizer
import torch.optim as optim
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import concurrent
from pytorch_lightning.callbacks import TQDMProgressBar

In [3]:
class LabelParser:
    def __init__(self):
        self.classes = None
        self.vocab_size = None
        self.class_to_idx = None
        self.idx_to_class = None
        self.ctc_classes = None
        self.ctc_idx_to_class = None
        self.ctc_class_to_idx = None
        
    def fit(self, classes: Sequence[str]):
        self.classes = list(classes)
        self.vocab_size = len(classes)
        self.idx_to_class = dict(enumerate(classes))
        self.class_to_idx = {cls: i for i, cls in self.idx_to_class.items()}
        
        self.ctc_classes = ["<blank>"] + self.classes
        self.ctc_idx_to_class = dict(enumerate(self.ctc_classes))
        self.ctc_class_to_idx = {cls: i for i, cls in self.ctc_idx_to_class.items()}
        
        return self

    def addClasses(self, classes: List[str]):
        all_classes = sorted(set(self.classes + classes))

        self.fit(all_classes)

    def encode_labels(self, sequence: Sequence[str]):
        self._check_fitted()
        return [self.class_to_idx[c] for c in sequence]

    def decode_labels(self, sequence: Sequence[int]):
        self._check_fitted()
        return [self.idx_to_class[c] for c in sequence]
    
    def ctc_encode_labels(self, sequence: Sequence[str]):
        self._check_fitted()
        return [self.ctc_class_to_idx[c] for c in sequence]
    
    def ctc_decode_labels(self, sequence: Sequence[int]):
        self._check_fitted()
        return [self.ctc_idx_to_class[c] for c in sequence]
    
    def _check_fitted(self):
        if self.classes is None:
            raise ValueError("LabelParser class was not fitted yet")

In [4]:
def pickle_load(file) -> Any:
    with open(file, "rb") as f:
        return pickle.load(f)

def pickle_save(obj, file):
    with open(file, "wb") as f:
        pickle.dump(obj, f)

def read_xml(file: Union[Path, str]) -> ET.Element:
    tree = ET.parse(file)
    root = tree.getroot()

    return root

def find_child_by_tag(element: ET.Element, tag: str, value: str) -> Union[ET.Element, None]:
    for child in element:
        if child.get(tag) == value:
            return child
    return None

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def dpi_adjusting(img: np.ndarray, scale: float, **kwargs) -> np.ndarray:
    height, width = img.shape[:2]
    new_height, new_width = math.ceil(height * scale), math.ceil(width * scale)
    return cv.resize(img, (new_width, new_height))

class LitProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, model):
        # don't show the version number
        items = super().get_metrics(trainer, model)
        for k in list(items.keys()):
            if k.startswith("grad"):
                items.pop(k, None)
        items.pop("v_num", None)
        return items
    
def decode_prediction_and_target(
    pred: Tensor, target: Tensor, label_encoder: LabelParser
) -> Tuple[str, str]:

    # Decode prediction and target.
    p, t = pred.tolist(), target.tolist()
    pred_str = "".join(label_encoder.ctc_decode_labels(p))
    target_str = "".join(label_encoder.ctc_decode_labels(t))
    return pred_str, target_str

def matplotlib_imshow(
    img: torch.Tensor, mean: float = 0.5, std: float = 0.5, one_channel=True
):
    assert img.device.type == "cpu"
    if one_channel and img.ndim == 3:
        img = img.mean(dim=0)
    img = img * std + mean  # unnormalize
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

In [5]:
class SafeRandomScale(A.RandomScale):
    def apply(self, img, scale=0, interpolation=cv.INTER_LINEAR, **params):
        height, width = img.shape[:2]
        new_height, new_width = int(height * scale), int(width * scale)
        if new_height <= 0 or new_width <= 0:
            return img
        return super().apply(img, scale, interpolation, **params)

def adjust_dpi(img: np.ndarray, scale: float, **kwargs):
    height, width = img.shape
    new_height, new_width = math.ceil(height * scale), math.ceil(width * scale)
    return cv.resize(img, (new_width, new_height))

def randomly_displace_and_pad(
    img: np.ndarray,
    padded_size: Tuple[int, int],
    crop_if_necessary: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    Randomly displace an image within a frame, and pad zeros around the image.

    Args:
        img (np.ndarray): image to process
        padded_size (Tuple[int, int]): (height, width) tuple indicating the size of the frame
        crop_if_necessary (bool): whether to crop the image if its size exceeds that
            of the frame
    """
    frame_h, frame_w = padded_size
    img_h, img_w = img.shape
    if frame_h < img_h or frame_w < img_w:
        if crop_if_necessary:
            print(
                "WARNING (`randomly_displace_and_pad`): cropping input image before "
                "padding because it exceeds the size of the frame."
            )
            img_h, img_w = min(img_h, frame_h), min(img_w, frame_w)
            img = img[:img_h, :img_w]
        else:
            raise AssertionError(
                f"Frame is smaller than the image: ({frame_h}, {frame_w}) vs. ({img_h},"
                f" {img_w})"
            )

    res = np.zeros((frame_h, frame_w), dtype=img.dtype)

    pad_top =  randint(0, frame_h - img_h)
    pad_bottom = pad_top + img_h
    pad_left = randint(0, frame_w - img_w)
    pad_right = pad_left + img_w

    res[pad_top:pad_bottom, pad_left:pad_right] = img
    return res

@dataclass
class ImageTransforms:
    max_img_size: Tuple[int, int]  # (h, w)
    normalize_params: Tuple[float, float]  # (mean, std)
    scale: float = (
        0.25
    )
    random_scale_limit: float = 0.1
    random_rotate_limit: int = 1

    train_trnsf: A.Compose = field(init=False)
    test_trnsf: A.Compose = field(init=False)

    def __post_init__(self):
        scale, random_scale_limit, random_rotate_limit, normalize_params =(
            self.scale,
            self.random_scale_limit,
            self.random_rotate_limit,
            self.normalize_params
        )

        max_img_h, max_img_w = self.max_img_size
        max_scale = scale + scale * random_scale_limit
        padded_h, padded_w = math.ceil(max_scale * max_img_h), math.ceil(max_scale * max_img_w)

        self.train_trnsf = A.Compose([
            A.Lambda(partial(adjust_dpi, scale=scale)),
            SafeRandomScale(scale_limit=random_scale_limit, p=0.5),
#             A.SafeRotate(
#                 limit = random_rotate_limit,
#                 border_mode = cv.BORDER_CONSTANT,
#                 value = 0
#             ),
            A.RandomBrightnessContrast(),
#             A.Perspective(scale=(0.01, 0.05)),
            A.GaussNoise(),
            A.Normalize(*normalize_params),
#             A.Lambda(
#                 image=partial(
#                     randomly_displace_and_pad,
#                     padded_size=(padded_h, padded_w),
#                     crop_if_necessary=False,
#                 )
#             )
        ])

        self.test_trnsf = A.Compose([
            A.Lambda(partial(adjust_dpi, scale=scale)),
            A.Normalize(*normalize_params),
            A.PadIfNeeded(
                max_img_h, max_img_w, border_mode=cv.BORDER_CONSTANT, value=0
            )
        ])

In [6]:
class CharacterErrorRate(Metric):

    def __init__(self, label_encoder: LabelParser):
        super().__init__()
        self.label_encoder = label_encoder
        
        self.add_state("cer_sum", default=torch.zeros(1, dtype=torch.float), dist_reduce_fx="sum")
        self.add_state("nr_samples",default=torch.zeros(1, dtype=torch.int64), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.ndim == target.ndim

        for i, (p, t) in enumerate(zip(preds, target)):
            p_str, t_str = map(tensor_to_str, (p, t))
            editd = editdistance.eval(p_str, t_str)

            self.cer_sum += editd/t.numel()
            self.nr_samples +=1

    def compute(self) -> torch.Tensor:
        return self.cer_sum / self.nr_samples.float()

class WordErrorRate(Metric):
    def __init__(self, label_encoder: LabelParser):
        super().__init__()
        self.label_encoder = label_encoder
        
        self.add_state("wer_sum", default=torch.zeros(1, dtype=torch.float), dist_reduce_fx="sum")
        self.add_state("nr_samples", default=torch.zeros(1, dtype=torch.int64), dist_reduce_fx="sum")


    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.ndim == target.ndim

        for i, (p, t) in enumerate(zip(preds, target)):
            p = p.flatten().tolist()
            t = t.flatten().tolist()
            p_words = "".join(self.label_encoder.ctc_decode_labels(p)).split()
            t_words = "".join(self.label_encoder.ctc_decode_labels(t)).split()
            editd = editdistance.eval(p_words, t_words)
            
            self.wer_sum += editd / len(t_words)
            self.nr_samples += 1
            
    def compute(self) -> torch.Tensor:
        """Compute Word Error Rate."""
        return self.wer_sum / self.nr_samples.float()

def tensor_to_str(t: torch.Tensor) -> str:
    return "".join(map(str, t.flatten().tolist()))

In [7]:
class IAMDataset(Dataset):
    MEAN = 0.8275
    STD = 0.2314
    MAX_FORM_HEIGHT = 3542
    MAX_FORM_WIDTH = 2479

    MAX_SEQ_LENS = {
        "word": 55,
        "line": 90,
        "form": 700,
    }  # based on the maximum seq lengths found in the dataset

    root: Path
    data: pd.DataFrame
    label_enc: LabelParser
    parse_method: str
    only_lowercase: bool
    transforms: Optional[A.Compose]
    id_to_idx: Dict[str, int]
    _split: str
    _return_writer_id: Optional[bool]

    def __init__(
        self,
        root: Union[Path, str],
        parse_method: str,
        split: str,
        return_writer_id: bool = False,
        only_lowercase: bool = False,
        label_enc: Optional[LabelParser] = None,
    ):
        super().__init__()
        _parse_methods = ["form", "line", "word"]
        err_message = (
            f"{parse_method} is not a possible parsing method: {_parse_methods}"
        )
        assert parse_method in _parse_methods, err_message

        _splits = ["train", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message

        self._split = split
        self._return_writer_id = return_writer_id
        self.only_lowercase = only_lowercase
        self.root = Path(root)
        self.label_enc = label_enc
        self.parse_method = parse_method

        # Process the data.
        if not hasattr(self, "data"):
            self.data = self._get_forms()

        # Create the label encoder.
        if self.label_enc is None:
            vocab = []
            s = "".join(self.data["target"].tolist())

            if self.only_lowercase:
                s = s.lower()
            vocab += sorted(list(set(s)))
            self.label_enc = LabelParser().fit(vocab)
        if not "target_enc" in self.data.columns:
            self.data.insert(
                2,
                "target_enc",
                self.data["target"].apply(
                    lambda s: np.array(
                        self.label_enc.encode_labels(
                            [c for c in (s.lower() if self.only_lowercase else s)]
                        )
                    )
                ),
            )

        self.transforms = self._get_transforms(split)
        self.id_to_idx = {
            Path(self.data.iloc[i]["img_path"]).stem: i for i in range(len(self))
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data.iloc[idx]
        img = cv.imread(data["img_path"], cv.IMREAD_GRAYSCALE)
        if all(col in data.keys() for col in ["bb_y_start", "bb_y_end"]):
            # Crop the image vertically.
            img = img[data["bb_y_start"] : data["bb_y_end"], :]
        assert isinstance(img, np.ndarray), (
            f"Error: image at path {data['img_path']} is not properly loaded. "
            f"Is there something wrong with this image?"
        )
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
        if self._return_writer_id:
            return img, data["writer_id"], data["target_enc"]
        return img, data["target_enc"]

    def get_max_height(self):
        return (self.data["bb_y_end"] - self.data["bb_y_start"]).max()

    @property
    def vocab(self):
        return self.label_enc.classes

    @staticmethod
    def collate_fn(
        batch: Sequence[Tuple[np.ndarray, np.ndarray]],
        dataset_returns_writer_id: bool = False,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if dataset_returns_writer_id:
            imgs, writer_ids, targets = zip(*batch)
        else:
            imgs, targets = zip(*batch)

        img_sizes = [im.shape for im in imgs]
        if (
            not len(set(img_sizes)) == 1
        ):  # images are of varying sizes, so pad them to the maximum size in the batch
            hs, ws = zip(*img_sizes)
            pad_fn = A.PadIfNeeded(
                max(hs), max(ws), border_mode=cv.BORDER_CONSTANT, value=0
            )
            imgs = [pad_fn(image=im)["image"] for im in imgs]
        imgs = np.stack(imgs, axis=0)
        imgs, targets = torch.Tensor(imgs), torch.Tensor(targets)
        return imgs, targets

    def set_transforms_for_split(self, split: str):
        _splits = ["train", "val", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message
        self.transforms = self._get_transforms(split)

    def _get_transforms(self, split: str) -> A.Compose:
        max_img_w = self.MAX_FORM_WIDTH

        if self.parse_method == "form":
            max_img_h = (self.data["bb_y_end"] - self.data["bb_y_start"]).max()
        else:  # word or line
            max_img_h = self.MAX_FORM_HEIGHT

        transforms = ImageTransforms(
            (max_img_h, max_img_w), (IAMDataset.MEAN, IAMDataset.STD)
        )

        if split == "train":
            return transforms.train_trnsf
        elif split == "test" or split == "val":
            return transforms.test_trnsf

    def statistics(self) -> Dict[str, float]:
        assert len(self) > 0
        tmp = self.transforms
        self.transforms = None
        mean, std, cnt = 0, 0, 0
        for img, _ in self:
            mean += np.mean(img)
            std += np.var(img)
            cnt += 1
        mean /= cnt
        std = np.sqrt(std / cnt)
        self.transforms = tmp
        return {"mean": mean, "std": std}
    
    def get_max_target_len(self):
        return (self.data["target_len"]).max()
    
    def _get_forms(self) -> pd.DataFrame:
        
        data = {
            "img_path": [],
            "img_id": [],
            "target": [],
            "bb_y_start": [],
            "bb_y_end": [],
            "target_len": [],
        }
        for form_dir in ["formsA-D", "formsE-H", "formsI-Z"]:
            dr = self.root / form_dir
            for img_path in dr.iterdir():
                doc_id = img_path.stem
                xml_root = read_xml(self.root / "xml" / (doc_id + ".xml"))

                bb_y_start = int(xml_root[1][0].get("asy")) - 150
                bb_y_end = int(xml_root[1][-1].get("dsy")) + 150

                form_text = []
                for line in xml_root.iter("line"):
                    form_text.append(html.unescape(line.get("text", "")))
                
                img_w, img_h = Image.open(str(img_path)).size
                target = " ".join(form_text)
                target = target.replace("\n", " ")
                data["img_path"].append(str(img_path))
                data["img_id"].append(doc_id)
                data["target"].append(target)
                data["bb_y_start"].append(bb_y_start)
                data["bb_y_end"].append(bb_y_end)
                data["target_len"].append(len(target))
        return pd.DataFrame(data).sort_values(
            "target_len"
        )  # by default, sort by target length


In [8]:
import re
import random
class RIMESDataset(Dataset):
    MEAN = 0.8275
    STD = 0.2314
    
    root: Path
    data: pd.DataFrame
    label_enc: LabelParser
    transforms: Optional[A.Compose]
    id_to_idx: Dict[str, int]
    _split: str
    _return_writer_id: Optional[bool]
    
    max_width: Optional[int]
    max_height: Optional[int]
    
    @staticmethod
    def process_target(target: str):
        # Splitting the input string into lines
        target = target.replace("\\n", "\n")
        target = target.replace("\n", " ")
        
        pattern = re.compile(r'¤?\{?([^ ¤{}0-9]*)\/([^ ¤{}0-9]*)\}?¤?')
        matches = pattern.findall(target)
        match_pos = [a for a in pattern.finditer(target)]
        new_target = ""
        last_ind = 0

        for i, (choices, position) in enumerate(zip(matches, match_pos)):
            if position.start() > 0:
                new_target += target[last_ind:position.start()-1]
            if len(choices) != 0:
                new_target += " " + choices[random.randint(0, len(choices) - 1)] + " "
            last_ind = position.end() + 1
        
        new_target += target[last_ind:]
        
        return new_target
    
    def __init__(
            self,
            root: Union[Path, str],
            split: str, 
            only_lowercase: bool = False,
            label_enc: Optional[LabelParser] = None,):
        super().__init__()
        
        _splits = ["train", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message
        
        self._split = split
        self.only_lowercase = only_lowercase
        self.root = Path(root)
        self.label_enc = label_enc
        
        if not hasattr(self, "data"):
            self.data = self._get_form_data()
        
        if self.label_enc is None:
            vocab = []
            s = "".join(self.data["target"].tolist())
            if self.only_lowercase:
                s = s.lower()
            vocab += sorted(list(set(s)))
            self.label_enc = LabelParser().fit(vocab)
        else:
            vocab = []
            s = "".join(self.data["target"].tolist())
            if self.only_lowercase:
                s = s.lower()
            vocab += sorted(list(set(s)))
            self.label_enc.addClasses(vocab)
        
        if not "target_enc" in self.data.columns:
            self.data.insert(
                2,
                "target_enc",
                self.data["target"].apply(
                    lambda s: np.array(
                        self.label_enc.ctc_encode_labels(
                            [c for c in (s.lower() if self.only_lowercase else s)]
                        )
                    )
                )
            )
        self.transforms = self._get_transforms(split)
        self.id_to_idx = {
            Path(self.data.iloc[i]["img_path"]).stem: i for i in range(len(self))
        }
        
    def __len__(self):
        return len(self.data)
    
    def set_transform_for_split(self, split):
        _splits = ["train", "val", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message
        self.transforms = self._get_transforms(split)
    
    def __getitem__(self, idx):
        data = self.data.iloc[idx]
        img = cv.imread(data["img_path"], cv.IMREAD_GRAYSCALE)
        
        if all(col in data.keys() for col in ["bb_y_start", "bb_y_end"]):
            img = img[data["bb_y_start"]: data["bb_y_end"], :]
        assert isinstance(img, np.ndarray), (
            f"Error: image at path {data['img_path']} is not properly loaded. "
            f"Is there something wrong with this image?"
        )
        if self.transforms is not None:
            img = self.transforms(image=img)["image"]
        
        return img, data["target_enc"]
    
    def get_max_height(self):
        return (self.data["bb_y_end"] - self.data["bb_y_start"]).max() + 150
    
    def get_max_width(self):
        return (self.data["bb_x_end"] - self.data["bb_x_start"]).max() + 150
    
    @property
    def vocab(self):
        return self.label_enc.classes
        
    @staticmethod
    def collate_fn(
        batch: Sequence[Tuple[np.ndarray, np.ndarray]],
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        
        imgs, targets = zip(*batch)

        img_sizes = [im.shape for im in imgs]
        if not len(set(img_sizes)) == 1:
            hs, ws = zip(*img_sizes)
            pad_fn = A.PadIfNeeded(
                max(hs), max(ws), border_mode=cv.BORDER_CONSTANT, value=0
            )
            imgs = [pad_fn(image=im)["image"] for im in imgs]
        imgs = np.stack(imgs, axis=0)
        imgs, targets = torch.Tensor(imgs), torch.Tensor(targets)
        return imgs, targets
    
    def _get_transforms(self, split: str) -> A.Compose:
        max_img_w = self.max_width
    
        max_img_h = self.max_height
    
        transforms = ImageTransforms(
            (max_img_h, max_img_w), (RIMESDataset.MEAN, RIMESDataset.STD)
        )
    
        if split == "train":
            return transforms.train_trnsf
        elif split == "test" or split == "val":
            return transforms.test_trnsf
    
    def get_max_target_len(self):
        return (self.data["target_len"]).max()

    def _get_form_data(self):
        data = {
            "img_path": [],
            "img_id": [],
            "target": [],
            "bb_y_start": [],
            "bb_y_end": [],
            "bb_x_start": [],
            "bb_x_end": [],
            "target_len": [],
        }
        
        
        def process_forms(paths: Tuple[str, str, Path]):
            return_data = {
                "img_path": [],
                "img_id": [],
                "target": [],
                "bb_y_start": [],
                "bb_y_end": [],
                "bb_x_start": [],
                "bb_x_end": [],
                "target_len": []
            }
            img_path, xml_path, root = paths
            img_path = root / img_path
            xml_path = root / xml_path
            doc_id = img_path.stem[:-2]
            xml_root = read_xml(xml_path)
            
            target = ""
            num_corps = 0
            for box in xml_root.iter("box"):
                type_tag =  box.find("type")
                if type_tag.text == "Corps de texte":
                    target = box.find("text").text
                    
                    
                    if target is None or target == "":
                        continue
                    words = target.split("\\n")
                    if len(words) <= 5:
                        continue

                    target = self.process_target(target)
                    if target.find("¤") != -1 or target.find("{") != -1 or target.find("}") != -1:
                        continue #skip if the target sequence is not standard
                        
                    return_data["img_path"].append(str(img_path.resolve()))
                    return_data["img_id"].append(doc_id)
                    return_data["target"].append(target)
                    return_data["bb_y_start"].append(int(box.get("top_left_y")))
                    return_data["bb_y_end"].append(int(box.get("bottom_right_y")))
                    return_data["bb_x_start"].append(int(box.get("top_left_x")))
                    return_data["bb_x_end"].append(int(box.get("bottom_right_x")))
                    return_data["target_len"].append(len(target))
                    num_corps += 1
            
            return return_data
        
        image_pairs = []
        for form_dir in ["DVD1_TIF", "DVD2_TIF", "DVD3_TIF"]:
            dr = self.root / form_dir
            for file in dr.iterdir():
                name = file.stem
                ext = file.suffix
                if ext == ".tif" and name[-1] == "L":
                    image_pairs.append((name + ".tif", name + ".xml", dr))
        
        with ThreadPoolExecutor() as executor:
            results = list(executor.map(process_forms, iter(image_pairs)))
            
            for single_results in results:
                if single_results["img_path"] == "":
                    continue
                data["img_path"].extend(single_results["img_path"])
                data["img_id"].extend(single_results["img_id"])
                data["target"].extend(single_results["target"])
                data["bb_y_start"].extend(single_results["bb_y_start"])
                data["bb_y_end"].extend(single_results["bb_y_end"])
                data["bb_x_start"].extend(single_results["bb_x_start"])
                data["bb_x_end"].extend(single_results["bb_x_end"])
                data["target_len"].extend(single_results["target_len"])
        
        to_ret = pd.DataFrame(data)
        self.max_height = (to_ret["bb_y_end"] - to_ret["bb_y_start"]).max() + 150
        self.max_width = (to_ret["bb_x_end"] - to_ret["bb_x_start"]).max() + 150
        
        return to_ret

In [9]:
class AggregatedDataset(Dataset):
    datasets: List[Dataset]
    def __init__(self, rimes: RIMESDataset,
                 iam: IAMDataset, 
                 split:str,
                 only_lowercase: bool = False,
                 label_enc: Optional[LabelParser] = None,):
        super().__init__()
        _splits = ["train", "test"]
        err_message = f"{split} is not a possible split: {_splits}"
        assert split in _splits, err_message
        
        self._split = split
        self.rimes = rimes
        self.iam = iam
        self._only_lowercase = only_lowercase
        self.label_enc = label_enc
        
        if self.label_enc is None:
            iamLabelEncoder = iam.label_enc
            rimesLabelEncoder = rimes.label_enc
            self.label_enc = LabelParser()
            self.label_enc.addClasses(iamLabelEncoder.classes)
            self.label_enc.addClasses(rimesLabelEncoder.classes)
        
    def __len__(self):
        return len(self.rimes) + len(self.iam)
    
    def get_max_target_len(self):
        return max(self.iam.get_max_target_len(), self.rimes.get_max_target_len())
    
    @staticmethod
    def unified_collate_fn(
        batch: Sequence[Tuple[np.ndarray, np.ndarray]],
        ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        
        imgs, targets = zip(*batch)
        img_sizes = [im.shape for im in imgs]
        if not len(set(img_sizes)) == 1:
            hs, ws = zip(*img_sizes)
            pad_fn = A.PadIfNeeded(
                max(hs), max(ws), border_mode = cv.BORDER_CONSTANT, value=0
            )
            imgs = [pad_fn(image=im)["image"] for im in imgs]
            
        imgs = [np.expand_dims(im, axis=0) for im in imgs]
        imgs = np.stack(imgs, axis=0)
        
        seq_lengths = [t.shape[0] for t in targets]
        targets_padded = np.full((len(targets), max(seq_lengths) + 1), 0)
        for i, t in enumerate(targets):
            targets_padded[i, : seq_lengths[i]] = t
        
        imgs, targets_padded = torch.tensor(imgs), torch.tensor(targets_padded)
        
        return imgs, targets_padded
    
    def __getitem__(self, idx):
        iam = self.iam 
        rimes = self.rimes
        if idx < len(self.rimes):
            img, target = rimes[idx]
        if len(rimes) <= idx < len(rimes) + len(iam):
            img, target = iam[idx - len(self.rimes)]
        
        if img is None :
            raise ValueError("Image is None.")
        if target is None:
            raise ValueError("Image is None.")
        
        assert not np.any(np.isnan(img)), img
        return img, target
        
        return img, target
    def set_transforms_for_split(self, split):
        self.iam.set_transforms_for_split(split)
        self.rimes.set_transform_for_split(split)

In [10]:
def get_gpu_memory_map():
    result = os.popen('nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader').read()
    return int(result.strip())

class LayerNorm(nn.Module):
    def forward(self, x):
        return nn.functional.layer_norm(x, x.size()[1:], weight=None, bias=None, eps=1e-05)

def pCnv(inp,out,groups=1):
  return nn.Sequential(
      nn.Conv2d(inp,out,1,bias=False,groups=groups),
      nn.InstanceNorm2d(out,affine=True)
  )

def dsCnv(inp,k):
  return nn.Sequential(
      nn.Conv2d(inp,inp,k,groups=inp,bias=False,padding=(k - 1) // 2),
      nn.InstanceNorm2d(inp,affine=True)
  )

class InitBlock(nn.Module):
    def __init__(self, fup, num_channels):
        super().__init__()
        
        self.n1 = LayerNorm()
        self.InitSeq = nn.Sequential(
            pCnv(num_channels, fup),
            nn.Softmax(dim=1),
            dsCnv(fup, 11),
            LayerNorm()
        )
        
    def forward(self, x):
        x  = self.n1(x)
        xt = x
        x = self.InitSeq(x)
        x = torch.cat([x, xt], dim=1)
        return x

class Gate(nn.Module):
    def __init__(self, ifsz):
        super().__init__()
        self.ln = LayerNorm()
        
    def forward(self, x):
        t0, t1 = torch.chunk(x, 2, dim=1)
        t0 = torch.tanh(t0)
        t1.sub(2)
        t1 = torch.sigmoid(t1)
        
        return t1 * t0

class GateBlock(nn.Module):
    def __init__(self, ifsz, ofsz, gt = True, ksz = 3):
        super().__init__()
        
        cfsz = int(math.floor(ifsz / 2))
        ifsz2 = ifsz + ifsz%2
        
        self.sq = nn.Sequential(
            pCnv(ifsz, cfsz),
            dsCnv(cfsz, ksz),
            nn.ELU(),
            
            pCnv(cfsz, cfsz * 2),
            dsCnv(cfsz * 2, ksz),
            Gate(cfsz),
            
            pCnv(cfsz, ifsz),
            dsCnv(ifsz, ksz),
            nn.ELU(),
        )
        
        self.gt = gt
        
    def forward(self, x):
        y = self.sq(x)
        
        out = x + y
        return out
    
class OrigamiNet(nn.Module):
    def __init__(self, 
                 n_channels: int, 
                 label_enc: LabelParser, 
                 mul_rate, 
                 layer_resizes, 
                 layer_sizes, 
                 num_layers, 
                 fup, 
                 reduceAxis=3 ):
        super().__init__()
        
        self.layer_resizes = layer_resizes
        self.Init_sequence = InitBlock(fup, 1)
        self.label_enc = label_enc
        
        self.cer_metric = CharacterErrorRate(label_enc)
        self.wer_metric = WordErrorRate(label_enc)
        
        layers = []
        input_size = fup + n_channels
        output_size = input_size
        
        for i in range(num_layers):
            output_size = int(math.floor(layer_sizes[i] * mul_rate) ) if i in layer_sizes else input_size
            layers.append(GateBlock(input_size, output_size, True, 3))
            
            if input_size != output_size:
                layers.append(pCnv(input_size, output_size))
                layers.append(nn.ELU())
            input_size = output_size
            
            if i in layer_resizes:
                layers.append(layer_resizes[i])
        
        layers.append(LayerNorm())
        self.Gatesq = nn.Sequential(*layers)
        self.Finsq = nn.Sequential(
            pCnv(output_size, self.label_enc.vocab_size),
            nn.ELU()
        )
        
        self.n1 = LayerNorm()
        self.it = 0
        self.reduceAxis = reduceAxis
        self.loss_fn = nn.CTCLoss(reduction="none", zero_infinity=True)
        
    def forward(self, image, targets: Optional[torch.Tensor]):
        x = self.Init_sequence(image)
        x = self.Gatesq(x)
        x = self.Finsq(x)
        x = torch.mean(x, self.reduceAxis, keepdim=False)
        x = self.n1(x)
        x = x.permute(0, 2, 1)
        if targets is not None:
            logits = x
            logits = logits.permute(1, 0, 2).log_softmax(2)
            logits_size = torch.IntTensor([logits.size(0)] * targets.size(0))
            targets_size = torch.IntTensor([targets.size(1)] * targets.size(0))
            targets = targets.cpu()
            loss = self.loss_fn(logits, targets, logits_size, targets_size).mean() 
            return x, loss
        return x
    
    def calculate_metrics(self, preds: torch.Tensor, targets: torch.Tensor):
        self.cer_metric.reset()
        self.wer_metric.reset()
        
        cer = self.cer_metric(preds, targets)
        wer = self.wer_metric(preds, targets)
        
        return {"CER": cer, "WER":wer}

In [11]:
import pytorch_lightning as pl

class LitOrigamiNet(pl.LightningModule):
    model: OrigamiNet

    """
    Pytorch Lightning module that acting as a wrapper around the
    FullPageHTREncoderDecoder class.

    Using a PL module allows the model to be used in conjunction with a Pytorch
    Lightning Trainer, and takes care of logging relevant metrics to Tensorboard.
    """

    def __init__(
        self,
        n_channels: int, 
        label_encoder: LabelParser,
        mul_rate: int, 
        layer_resizes: dict,
        layer_sizes: dict,
        num_layers: int,
        fup: int,
        reduce_axis:int = 3,
        learning_rate: float = 0.0002,
        params_to_log: Optional[Dict[str, Union[str, float, int]]] = None,
    ):
        super().__init__()

        # Save hyperparameters.
        self.learning_rate = learning_rate
        if params_to_log is not None:
            self.save_hyperparameters(params_to_log)
        self.save_hyperparameters(
            "learning_rate",
            "n_channels",
            "num_layers",
        )

        # Initialize the model.
        self.model = OrigamiNet(
            n_channels,
            label_encoder,
            mul_rate,
            layer_resizes,
            layer_sizes,
            fup,
            reduce_axis
        )

    def forward(self, imgs: Tensor, targets: Optional[Tensor] = None):
        return self.model(imgs, targets)

    def training_step(self, batch, batch_idx):
        imgs, targets = batch
        logits, loss = self.model(imgs, targets)
        self.log("train_loss", loss, sync_dist=True, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        return self.val_or_test_step(batch)

    def test_step(self, batch, batch_idx):
        return self.val_or_test_step(batch)

    def val_or_test_step(self, batch) -> Tensor:
        imgs, targets = batch
        logits, loss = self(imgs, targets)
        _, preds = logits.max(-1)

        # Update and log metrics.
        self.model.cer_metric(preds, targets)
        self.model.wer_metric(preds, targets)
        self.log("char_error_rate", self.model.cer_metric, on_step=True ,prog_bar=True)
        self.log("word_error_rate", self.model.wer_metric, on_step=True ,prog_bar=True)
        self.log("val_loss", loss, sync_dist=True, prog_bar=True, on_step=True)

        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

In [12]:
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
import matplotlib.pyplot as plt

PREDICTIONS_TO_LOG = {
    "word": 10,
    "line": 6,
    "form": 1,
}


class LogWorstPredictions(Callback):
    """
    At the end of training, log the worst image prediction, meaning the predictions
    with the highest character error rates.
    """

    def __init__(
        self,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloader: Optional[DataLoader] = None,
        test_dataloader: Optional[DataLoader] = None,
        training_skipped: bool = False,
        data_format: str = "word",
    ):
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.training_skipped = training_skipped
        self.data_format = data_format

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if self.training_skipped and self.val_dataloader is not None:
            self.log_worst_predictions(
                self.val_dataloader, trainer, pl_module, mode="val"
            )

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if self.test_dataloader is not None:
            self.log_worst_predictions(
                self.test_dataloader, trainer, pl_module, mode="test"
            )

    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        if self.train_dataloader is not None:
            self.log_worst_predictions(
                self.train_dataloader, trainer, pl_module, mode="train"
            )
        if self.val_dataloader is not None:
            self.log_worst_predictions(
                self.val_dataloader, trainer, pl_module, mode="val"
            )

    def log_worst_predictions(
        self,
        dataloader: DataLoader,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        mode: str = "train",
    ):
        img_cers = []
        device = "cuda:0" if pl_module.on_gpu else "cpu"
        if not self.training_skipped:
            self._load_best_model(trainer, pl_module)
            pl_module = trainer.model

        print(f"Running {mode} inference on best model...")
        # Run inference on the validation set.
        pl_module.eval()
        for img, target in dataloader:
            assert target.ndim == 2, target.ndim
            cer_metric = pl_module.model.cer_metric
            with torch.inference_mode():
                preds, _ = pl_module(img.to(device), target.to(device))
#                 preds = preds.log_softmax(2)
                preds = torch.max(preds, dim=-1)[1] # extract the predicted characters
                for prd, tgt, im in zip(preds, target, img):
                    cer_metric.reset()
                    cer = cer_metric(prd.unsqueeze(0), tgt.unsqueeze(0)).item()
                    img_cers.append((im, cer, prd, tgt))
        
        # Log the worst k predictions.
        to_log = PREDICTIONS_TO_LOG[self.data_format] * 2
        img_cers.sort(key=lambda x: x[1], reverse=True)  # sort by CER
        img_cers = img_cers[:to_log]
        fig = plt.figure(figsize=(15, 10))
        for i, (im, cer, prd, tgt) in enumerate(img_cers):
            pred_str, target_str = decode_prediction_and_target(
                prd, tgt, pl_module.model.label_encoder
            )

            # Create plot.
            ncols = 4 if self.data_format == "word" else 2
            nrows = math.ceil(to_log / ncols)
            ax = fig.add_subplot(nrows, ncols, i + 1, xticks=[], yticks=[])
            matplotlib_imshow(im, IAMDataset.MEAN, IAMDataset.STD)
            ax.set_title(f"Pred: {pred_str} (CER: {cer:.2f})\nTarget: {target_str}")

        # # Log the results to Tensorboard.
        # tensorboard = trainer.logger.experiment
        # tensorboard.add_figure(f"{mode}: worst predictions", fig, trainer.global_step)
        trainer.logger.experiment.log({f"{mode}: predictions vs targets": wandb.Image(fig)})
        plt.close(fig)

        print("Done.")

    @staticmethod
    def _load_best_model(trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        ckpt_callback = None
        for cb in trainer.callbacks:
            if isinstance(cb, ModelCheckpoint):
                ckpt_callback = cb
                break
        assert ckpt_callback is not None, "ModelCheckpoint not found in callbacks."
        best_model_path = ckpt_callback.best_model_path

        print(f"Loading best model at {best_model_path}")
        label_encoder = pl_module.model.label_encoder
        model = LitOrigamiNet.load_from_checkpoint(
            best_model_path,
            label_encoder=label_encoder,
        )
        trainer.model.load_state_dict(model.state_dict())


class LogModelPredictions(Callback):
    """
    Use a fixed test batch to monitor model predictions at the end of every epoch.

    Specifically: it generates matplotlib Figure using a trained network, along with images
    and labels from a batch, that shows the network's prediction alongside the actual target.
    """

    def __init__(
        self,
        label_encoder: LabelParser,
        val_batch: Tuple[torch.Tensor, torch.Tensor],
        use_gpu: bool = True,
        data_format: str = "word",
        train_batch: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        self.label_encoder = label_encoder
        self.val_batch = val_batch
        self.use_gpu = use_gpu
        self.data_format = data_format
        self.train_batch = train_batch

    def on_validation_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ):
        self._predict_intermediate(trainer, pl_module, split="val")

    def on_train_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ):
        if self.train_batch is not None:
            self._predict_intermediate(trainer, pl_module, split="train")

    def _predict_intermediate(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", split="val"
    ):
        """Make predictions on a fixed batch of data and log the results to Tensorboard."""

        # Make predictions.
        if split == "train":
            imgs, targets = self.train_batch
        else:  # split == "val"
            imgs, targets = self.val_batch
        with torch.inference_mode():
            pl_module.eval()
            preds = pl_module(imgs.cuda() if self.use_gpu else imgs)
#             preds = preds.log_softmax(2)
            preds = torch.max(preds, dim=-1)[1] # get the predicted outputs

        # Decode predictions and generate a plot.
        fig = plt.figure(figsize=(15, 10))
        for i, (p, t) in enumerate(zip(preds, targets)):
            pred_str, target_str = decode_prediction_and_target(
                p, t, self.label_encoder
            )

            # Create plot.
            ncols = 2 if self.data_format == "word" else 1
            nrows = math.ceil(preds.size(0) / ncols)
            ax = fig.add_subplot(nrows, ncols, i + 1, xticks=[], yticks=[])
            matplotlib_imshow(imgs[i], IAMDataset.MEAN, IAMDataset.STD)
            ax.set_title(f"Pred: {pred_str}\nTarget: {target_str}")


        trainer.logger.experiment.log({f"{split}: predictions vs targets": wandb.Image(fig)})
        plt.close(fig)

In [13]:
from copy import copy
from pytorch_lightning import seed_everything

seed_everything(12345)
ds = IAMDataset(root="/kaggle/input/iam-rimes/data/raw/IAM", label_enc=None, parse_method="form" ,split="train")
# rimes_ds = RIMESDataset(root="/kaggle/input/iam-rimes/data/raw/RIMES", label_enc=iam_ds.label_enc, split="train")
# ds = AggregatedDataset(rimes_ds, iam_ds, split="train", label_enc=rimes_ds.label_enc)
# print(len(ds.label_enc.ctc_classes))
# print(ds.label_enc.ctc_decode_labels([7]))
# print(f"Maximum target length for Aggregated dataset: {ds.get_max_target_len()}")
# ds_orig, ds_compl = torch.utils.data.random_split(ds, [math.ceil(0.5 * len(ds)), math.floor(0.5 * len(ds))])

ds_train, ds_val = torch.utils.data.random_split(ds, [math.ceil(0.8 * len(ds)), math.floor(0.2 * len(ds))])

ds_val.data = copy(ds)
ds_val.data.set_transforms_for_split("val")
train_len = len(ds_train)
val_len = len(ds_val)
print(train_len, val_len)


1232 307


In [14]:

batch_size = 2

collate_fn = partial(
        AggregatedDataset.unified_collate_fn
)
num_workers = 2
dl_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
)
dl_val = DataLoader(
    ds_val,
    batch_size= batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
)
train_len //= batch_size
val_len //= batch_size

In [15]:
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ModelSummary
from torch.utils.data import DataLoader, Subset
from pytorch_lightning.loggers import WandbLogger
import wandb
import os

wandb.login(key="0350b0cc5bd9521bb37a798168d31b6b65e9caca")
wandb_logger = WandbLogger(project="origami_net_bach", log_model="all")
callbacks = [
        ModelCheckpoint(
            save_top_k=(3),
            mode="min",
            monitor="word_error_rate",
            filename="{epoch}-{char_error_rate:.4f}-{word_error_rate:.4f}",
        ),
        ModelSummary(max_depth=2),
        LitProgressBar(),
        LogWorstPredictions(
            dl_train,
            dl_val,
            training_skipped=False,
            data_format="form",
        ),
        LogModelPredictions(
            ds.label_enc,
            val_batch=next(
                iter(
                    DataLoader(
                        Subset(
                            ds_val,
                            random.sample(
                                range(len(ds_val)), 1
                            ),
                        ),
                        batch_size=1,
                        shuffle=False,
                        collate_fn=collate_fn,
                        num_workers=num_workers,
                        pin_memory=True,
                    )
                )
            ),
            train_batch=next(
                iter(
                    DataLoader(
                        Subset(
                            ds_train,
                            random.sample(
                                range(len(ds_train)),
                                1,
                            ),
                        ),
                        batch_size=1,
                        shuffle=False,
                        collate_fn=collate_fn,
                        num_workers=num_workers,
                        pin_memory=True,
                    )
                )
            ),
            data_format="form",
            use_gpu=True,
        ),
        EarlyStopping(
                    monitor="word_error_rate",
                    patience=50,
                    verbose=True,
                    mode="min",
                    check_on_train_epoch_end=False,
                )
    ]

def get_gpu_memory_map():
    result = os.popen('nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader').read()
    return int(result.strip())

trainer = Trainer(
    max_epochs=3000,
    accelerator="gpu", 
    devices=1,
    callbacks = callbacks,
#     accumulate_grad_batches=54,
    logger=wandb_logger,
    fast_dev_run=False,
    precision="16-mixed")
model = LitOrigamiNet(
    n_channels = 1, 
    label_encoder=ds.label_enc,
    mul_rate= 1.0, 
    layer_resizes= {
            0: nn.MaxPool2d(2, 2),
            2: nn.MaxPool2d(2, 2),
            4: nn.MaxPool2d(2,2),
            6: nn.ZeroPad2d(1),
            8: nn.ZeroPad2d(1),
            10: nn.Upsample((450, 15), align_corners=True, mode="bilinear"),
            11: nn.Upsample((1100, 8), align_corners=True, mode="bilinear")
        },
    layer_sizes= {
            0:  512,
            4:  1024,
            11: 512
        }, 
    num_layers=12, 
    fup=33,
    learning_rate=0.001,
)
# def profile_memory(model, dataloader):
#     model.cuda()
#     model.train()
#     for batch in dataloader:
#         inputs, _ = batch
#         inputs = inputs.cuda()
#         outputs = model(inputs)
#         print(f"Memory Usage: {torch.cuda.memory_allocated()} bytes")
        
            
# profile_memory(model, dl_train)

# trainer.fit(model, dl_train, dl_val)
print(f"Memory Used: {get_gpu_memory_map()} MB")
import traceback
try:
    # Start training
    trainer.fit(model, dl_train, dl_val)
except Exception as e:
    # Catch and print any exceptions during training
    print(f"An error occurred: {e}")
#     traceback.print_exc()
finally:
    # Ensure wandb is properly closed
    wandb.finish()

    # Cleanup to free GPU memory, if the objects are not needed anymore
    del model
    torch.cuda.empty_cache()
    print("Training finished, resources cleared.")
    print(f"Memory Used: {get_gpu_memory_map()} MB")

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mstefannastasa[0m. Use [1m`wandb login --relogin`[0m to force relogin


Memory Used: 256 MB


[34m[1mwandb[0m: wandb version 0.17.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.16.6
[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20240518_123558-egfv9bm3[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mhopeful-smoke-7[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/stefannastasa/origami_net_bach[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/stefannastasa/origami_net_bach/runs/egfv9bm3[0m


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

An error occurred: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 15.89 GiB of which 52.12 MiB is free. Process 3239 has 15.84 GiB memory in use. Of the allocated memory 15.24 GiB is allocated by PyTorch, and 312.59 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF


[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 🚀 View run [33mhopeful-smoke-7[0m at: [34m[4mhttps://wandb.ai/stefannastasa/origami_net_bach/runs/egfv9bm3[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/stefannastasa/origami_net_bach[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)
[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20240518_123558-egfv9bm3/logs[0m


Training finished, resources cleared.
Memory Used: 330 MB
