In [2]:
# Built-in libraries
import argparse
import json
import math
import os
import random
import re
import subprocess
import sys
import tarfile
from argparse import Namespace
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
from urllib.request import urlretrieve

# Third-party libraries
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import editdistance
import hydra
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Metric
from tqdm import tqdm
import torchvision.models
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# PyTorch Lightning
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers.wandb import WandbLogger

  from .autonotebook import tqdm as notebook_tqdm


## prepare data

In [None]:
# Create the local "data" directory (relative to the current working directory)
# if it doesn't exist
if not os.path.exists("data"):
    os.makedirs("data")

# Download the normalized formula list and the train/validation/test split files
# into data/ (-q: quiet, -P: target directory)
!wget -q -P data https://im2markup.yuntiandeng.com/data/im2latex_formulas.norm.lst
!wget -q -P data https://im2markup.yuntiandeng.com/data/im2latex_validate_filter.lst
!wget -q -P data https://im2markup.yuntiandeng.com/data/im2latex_train_filter.lst
!wget -q -P data https://im2markup.yuntiandeng.com/data/im2latex_test_filter.lst

# If you want to download the raw images, uncomment the following commands and
# comment out the processed-image commands further down.
# NOTE(review): the commented commands below fetch the same
# formula_images_processed.tar.gz archive as the active ones. The raw archive is
# presumably formula_images.tar.gz (extracting to data/formula_images) — confirm
# against the dataset site before uncommenting.
# !wget -q -P data https://im2markup.yuntiandeng.com/data/formula_images_processed.tar.gz
# # Extract raw image data
# !tar -xzf data/formula_images_processed.tar.gz -C data

# Download the already-processed images
!wget -q -P data https://im2markup.yuntiandeng.com/data/formula_images_processed.tar.gz
# Extract processed image data into data/
!tar -xzf data/formula_images_processed.tar.gz -C data


In [None]:
# The project root is the current working directory. It is resolved to an absolute
# path when this cell runs, so run it before anything changes the working directory.
PROJECT_DIRNAME = Path().resolve()

# Everything downloaded or generated by this notebook lives under <project>/data
DATA_DIRNAME = PROJECT_DIRNAME / "data"

# Raw images, processed (cropped) images and the token vocabulary
RAW_IMAGES_DIRNAME = DATA_DIRNAME / "formula_images"
PROCESSED_IMAGES_DIRNAME = DATA_DIRNAME / "formula_images_processed"
VOCAB_FILE = DATA_DIRNAME / "vocab.json"
# Display the resolved paths as the cell output
PROJECT_DIRNAME, DATA_DIRNAME, VOCAB_FILE

(PosixPath('/home/quangliz/Documents/schoolwork/py/CV/img2latex'),
 PosixPath('/home/quangliz/Documents/schoolwork/py/CV/img2latex/data'),
 PosixPath('/home/quangliz/Documents/schoolwork/py/CV/img2latex/data/vocab.json'))

In [None]:
# define some utils
def pil_loader(fp: Path, mode: str) -> Image.Image:
    """Load the image stored at `fp` and convert it to the PIL mode `mode`.

    Args:
        fp: Path of the image file to read.
        mode: Target PIL mode passed to `Image.convert` (e.g. "L" or "RGBA").

    Returns:
        The converted image. The conversion happens while the file is still open,
        so the returned image does not depend on the file handle.
    """
    with open(fp, "rb") as stream:
        return Image.open(stream).convert(mode)
    
def get_all_formulas(filename: Path) -> List[List[str]]:
    """Read every formula from a formula file, one formula per line.

    Args:
        filename: Path to a text file whose lines are whitespace-separated tokens.

    Returns:
        One list of tokens per line, in file order. An empty line yields an empty
        list so that line numbers keep matching the formula indices used by the
        split files.
    """
    # The cleaned formula file is written as UTF-8 by the data-preparation cell, so
    # read it back with the same encoding instead of the platform default.
    with open(filename, encoding="utf-8") as f:
        return [line.strip("\n").split() for line in f]

def get_split(
    all_formulas: List[List[str]],
    filename: Path,
) -> Tuple[List[str], List[List[str]]]:
    """Read a split file and look up the formula belonging to each image.

    Each non-empty line of the split file must contain exactly two
    whitespace-separated fields: an image file name and the index of its formula
    in `all_formulas`.

    Args:
        all_formulas: All tokenized formulas, indexed by line number.
        filename: Path to the split file.

    Returns:
        A tuple `(image_names, formulas)` of equal length, in file order.

    Raises:
        ValueError: If a non-empty line does not have exactly two fields or the
            index is not an integer.
        IndexError: If an index is out of range for `all_formulas`.
    """
    image_names: List[str] = []
    formulas: List[List[str]] = []
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            # Tolerate blank lines (e.g. a trailing newline at the end of the file)
            # instead of failing on the tuple unpacking below.
            if not fields:
                continue
            img_name, formula_idx = fields
            image_names.append(img_name)
            formulas.append(all_formulas[int(formula_idx)])
    return image_names, formulas


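# Uncomment the helpers below if you want to preprocess the raw images yourself.
# The "Extract regions of interest" cell calls `crop` whenever
# data/formula_images_processed does not exist, so it needs these definitions then.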
# def first_and_last_nonzeros(arr):
#     for i in range(len(arr)):
#         if arr[i] != 0:
#             break
#     left = i
#     for i in reversed(range(len(arr))):
#         if arr[i] != 0:
#             break
#     right = i
#     return left, right

# def crop(filename: Path, padding: int = 8) -> Optional[Image.Image]:
#     image = pil_loader(filename, mode="RGBA")

#     # Replace the transparency layer with a white background
#     new_image = Image.new("RGBA", image.size, "WHITE")
#     new_image.paste(image, (0, 0), image)
#     new_image = new_image.convert("L")

#     # Invert the color to have a black background and white text
#     arr = 255 - np.array(new_image)

#     # Area that has text should have nonzero pixel values
#     row_sums = np.sum(arr, axis=1)
#     col_sums = np.sum(arr, axis=0)
#     y_start, y_end = first_and_last_nonzeros(row_sums)
#     x_start, x_end = first_and_last_nonzeros(col_sums)

#     # Some images have no text
#     if y_start >= y_end or x_start >= x_end:
#         print(f"{filename.name} is ignored because it does not contain any text")
#         return None

#     # Cropping
#     cropped = arr[y_start : y_end + 1, x_start : x_end + 1]
#     H, W = cropped.shape

#     # Add paddings
#     new_arr = np.zeros((H + padding * 2, W + padding * 2))
#     new_arr[padding : H + padding, padding : W + padding] = cropped

#     # Invert the color back to have a white background and black text
#     new_arr = 255 - new_arr
#     return Image.fromarray(new_arr).convert("L")

In [None]:
class Tokenizer:
    """Bidirectional mapping between formula tokens and integer indices.

    Four special tokens (`<PAD>`, `<SOS>`, `<EOS>`, `<UNK>`) are always part of the
    vocabulary. A fresh tokenizer assigns them the indices 0-3; a tokenizer built
    from an existing mapping reads their indices from that mapping.

    Args:
        token_to_index: Optional existing mapping (e.g. loaded from `save` output).
            It must contain the four special tokens. An empty or missing mapping
            creates a fresh vocabulary holding only the special tokens.

    Raises:
        KeyError: If `token_to_index` is given but lacks a special token.
    """

    def __init__(self, token_to_index: Optional[Dict[str, int]] = None) -> None:
        self.pad_token = "<PAD>"
        self.sos_token = "<SOS>"
        self.eos_token = "<EOS>"
        self.unk_token = "<UNK>"

        self.token_to_index: Dict[str, int]
        self.index_to_token: Dict[int, str]

        if token_to_index:
            self.token_to_index = token_to_index
            self.index_to_token = {index: token for token, index in self.token_to_index.items()}
            self.pad_index = self.token_to_index[self.pad_token]
            self.sos_index = self.token_to_index[self.sos_token]
            self.eos_index = self.token_to_index[self.eos_token]
            self.unk_index = self.token_to_index[self.unk_token]
        else:
            self.token_to_index = {}
            self.index_to_token = {}
            self.pad_index = self._add_token(self.pad_token)
            self.sos_index = self._add_token(self.sos_token)
            self.eos_index = self._add_token(self.eos_token)
            self.unk_index = self._add_token(self.unk_token)

        # Indices that `decode(..., inference=True)` drops from its output
        self.ignore_indices = {self.pad_index, self.sos_index, self.eos_index, self.unk_index}

    def _add_token(self, token: str) -> int:
        """Add one token to the vocabulary.

        Args:
            token: The token to be added.

        Returns:
            The index of the input token. A token that is already known keeps its
            existing index.
        """
        if token in self.token_to_index:
            return self.token_to_index[token]
        index = len(self)
        self.token_to_index[token] = index
        self.index_to_token[index] = token
        return index

    def __len__(self) -> int:
        """Return the vocabulary size, special tokens included."""
        return len(self.token_to_index)

    def train(self, formulas: List[List[str]], min_count: int = 2) -> None:
        """Create a mapping from tokens to indices and vice versa.

        Args:
            formulas: Lists of tokens.
            min_count: Tokens that appear fewer than `min_count` will not be
                included in the mapping.
        """
        # Count the frequency of each token
        counter: Dict[str, int] = {}
        for formula in formulas:
            for token in formula:
                counter[token] = counter.get(token, 0) + 1

        for token, count in counter.items():
            # Remove tokens that show up fewer than `min_count` times
            if count < min_count:
                continue
            # Go through `_add_token` so that tokens already in the vocabulary
            # (special tokens found in the data, or tokens from an earlier `train`
            # call) keep their index instead of being remapped.
            self._add_token(token)

    def encode(self, formula: List[str]) -> List[int]:
        """Convert tokens to indices, wrapped in `<SOS>` ... `<EOS>`.

        Args:
            formula: The tokens of one formula.

        Returns:
            `len(formula) + 2` indices. Tokens missing from the vocabulary map to
            the `<UNK>` index.
        """
        indices = [self.sos_index]
        indices.extend(self.token_to_index.get(token, self.unk_index) for token in formula)
        indices.append(self.eos_index)
        return indices

    def decode(self, indices: List[int], inference: bool = True) -> List[str]:
        """Convert indices back to tokens, stopping at the first `<EOS>`.

        Args:
            indices: Indices to decode.
            inference: If True, special tokens are left out of the result.

        Returns:
            The decoded tokens preceding the first `<EOS>` (which is never included).

        Raises:
            RuntimeError: If an index before the first `<EOS>` is not in the vocabulary.
        """
        tokens = []
        for index in indices:
            if index not in self.index_to_token:
                raise RuntimeError(f"Found an unknown index {index}")
            if index == self.eos_index:
                break
            if inference and index in self.ignore_indices:
                continue
            tokens.append(self.index_to_token[index])
        return tokens

    def save(self, filename: Union[Path, str]) -> None:
        """Save token-to-index mapping to a json file."""
        with open(filename, "w") as f:
            json.dump(self.token_to_index, f)

    @classmethod
    def load(cls, filename: Union[Path, str]) -> "Tokenizer":
        """Create a `Tokenizer` from a mapping file outputted by `save`.

        Args:
            filename: Path to the file to read from.

        Returns:
            A `Tokenizer` object.
        """
        with open(filename) as f:
            token_to_index = json.load(f)
        return cls(token_to_index)

In [None]:
# Prepare the dataset: crop the raw images (only if no processed images exist),
# clean the ground-truth formulas and build the vocabulary. Every step is skipped
# when its output already exists, so the cell can be re-run safely.

DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
cur_dir = os.getcwd()
os.chdir(DATA_DIRNAME)
try:
    # Extract regions of interest
    if not PROCESSED_IMAGES_DIRNAME.exists():
        # `crop` is only defined when the raw-image preprocessing helpers have been
        # uncommented. Check before creating the output directory: otherwise a failed
        # run would leave an empty directory behind and later runs would skip cropping.
        if "crop" not in globals():
            raise NameError(
                "`crop` is not defined. Uncomment the raw-image preprocessing helpers, "
                "or download and extract formula_images_processed.tar.gz first."
            )
        PROCESSED_IMAGES_DIRNAME.mkdir(parents=True, exist_ok=True)
        print("Cropping images...")
        for image_filename in RAW_IMAGES_DIRNAME.glob("*.png"):
            cropped_image = crop(image_filename, padding=8)
            # `crop` returns None for images that contain no text
            if cropped_image is None:
                continue
            cropped_image.save(PROCESSED_IMAGES_DIRNAME / image_filename.name)

    # Clean the ground truth file
    cleaned_file = "im2latex_formulas.norm.new.lst"

    if not Path(cleaned_file).is_file():
        print("Cleaning data...")

        with open(DATA_DIRNAME / "im2latex_formulas.norm.lst", "r", encoding="utf-8") as infile:
            with open(cleaned_file, "w", encoding="utf-8") as outfile:
                for line in infile:
                    # Replace \left / \right delimiters with the bare delimiter
                    line = re.sub(r'\\left\(', '(', line)
                    line = re.sub(r'\\right\)', ')', line)
                    line = re.sub(r'\\left\[', '[', line)
                    line = re.sub(r'\\right\]', ']', line)
                    line = re.sub(r'\\left\{', '{', line)
                    line = re.sub(r'\\right\}', '}', line)
                    # Drop explicit spacing commands together with their argument
                    line = re.sub(r'\\vspace\s*\{\s*[^}]*\s*\}', '', line)
                    line = re.sub(r'\\hspace\s*\{\s*[^}]*\s*\}', '', line)
                    outfile.write(line)

    # Build vocabulary from the training split only
    if not VOCAB_FILE.is_file():
        print("Building vocabulary...")
        all_formulas = get_all_formulas(cleaned_file)
        _, train_formulas = get_split(all_formulas, "im2latex_train_filter.lst")
        tokenizer = Tokenizer()
        tokenizer.train(train_formulas)
        tokenizer.save(VOCAB_FILE)
finally:
    # Always restore the working directory, even when a step above fails; the other
    # cells use paths relative to the project root.
    os.chdir(cur_dir)

In [None]:
class BaseDataset(Dataset):
    """Dataset pairing formula images with their tokenized formulas.

    Args:
        root_dir: Directory that contains the image files.
        image_filenames: N image file names, relative to `root_dir`.
        formulas: N tokenized formulas; `formulas[i]` belongs to `image_filenames[i]`.
        transform: Optional callable invoked as `transform(image=<numpy array>)`;
            the "image" entry of its result replaces the image.
    """

    def __init__(
        self,
        root_dir: Path,
        image_filenames: List[str],
        formulas: List[List[str]],
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__()
        assert len(image_filenames) == len(formulas)
        self.root_dir, self.transform = root_dir, transform
        self.image_filenames, self.formulas = image_filenames, formulas

    def __len__(self) -> int:
        """Number of (image, formula) pairs."""
        return len(self.formulas)

    def __getitem__(self, idx: int):
        """Return the `(image, formula)` pair stored at position `idx`.

        When the image file does not exist, a blank white 64x128 image paired with
        an empty formula is returned instead.
        """
        image_filename = self.image_filenames[idx]
        formula = self.formulas[idx]
        image_filepath = self.root_dir / image_filename

        if not image_filepath.is_file():
            # Missing file: substitute a blank image with no tokens
            blank = np.full((64, 128), 255, dtype=np.uint8)
            image, formula = Image.fromarray(blank), []
        else:
            image = pil_loader(image_filepath, mode="L")

        if self.transform is None:
            return image, formula
        transformed = self.transform(image=np.array(image))
        return transformed["image"], formula

In [None]:
class Im2Latex(LightningDataModule):
    """Data processing for the Im2Latex-100K dataset.

    Args:
        batch_size: The number of samples per batch.
        num_workers: The number of subprocesses to use for data loading.
        pin_memory: If True, the data loader will copy Tensors into CUDA pinned memory
            before returning them.
        data_dirname: Directory holding the cleaned formula file, the split files,
            `vocab.json` and the `formula_images_processed` folder. Defaults to the
            notebook-level `DATA_DIRNAME`.

    Raises:
        FileNotFoundError: If the cleaned formula file is missing from the data directory.
    """

    def __init__(
        self,
        batch_size: int = 8,
        num_workers: int = 0,
        pin_memory: bool = False,
        data_dirname: Optional[Union[str, Path]] = None,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # `__file__` does not exist inside a notebook kernel, so the data directory
        # comes from the notebook configuration (or from the caller) instead.
        self.data_dirname = DATA_DIRNAME if data_dirname is None else Path(data_dirname)
        self.vocab_file = self.data_dirname / "vocab.json"
        formula_file = self.data_dirname / "im2latex_formulas.norm.new.lst"
        if not formula_file.is_file():
            raise FileNotFoundError(f"{formula_file} not found. Did you run the data preparation cell?")
        self.all_formulas = get_all_formulas(formula_file)
        self.transform = {
            # Light augmentation for training; validation/test images are only
            # converted to tensors.
            "train": A.Compose(
                [
                    A.Affine(scale=(0.6, 1.0), rotate=(-1, 1), p=0.5),
                    A.GaussNoise(p=0.5),
                    A.GaussianBlur(blur_limit=(1, 1), p=0.5),
                    ToTensorV2(),
                ]
            ),
            "val/test": ToTensorV2(),
        }

    @property
    def processed_images_dirname(self):
        """Directory that contains the processed formula images."""
        return self.data_dirname / "formula_images_processed"

    def setup(self, stage: Optional[str] = None) -> None:
        """Load images and formulas, and assign them to a `torch Dataset`.

        `self.train_dataset` and `self.val_dataset` are assigned for the stages
        "fit", "validate" and None; `self.test_dataset` is assigned for the stages
        "test" and None. The tokenizer is always (re)loaded from the vocabulary file.
        """
        self.tokenizer = Tokenizer.load(self.vocab_file)

        # "validate" is the stage used by `Trainer.validate`; it needs `val_dataset` too.
        if stage in ("fit", "validate", None):
            train_image_names, train_formulas = get_split(
                self.all_formulas,
                self.data_dirname / "im2latex_train_filter.lst",
            )
            self.train_dataset = BaseDataset(
                self.processed_images_dirname,
                image_filenames=train_image_names,
                formulas=train_formulas,
                transform=self.transform["train"],
            )

            val_image_names, val_formulas = get_split(
                self.all_formulas,
                self.data_dirname / "im2latex_validate_filter.lst",
            )
            self.val_dataset = BaseDataset(
                self.processed_images_dirname,
                image_filenames=val_image_names,
                formulas=val_formulas,
                transform=self.transform["val/test"],
            )

        if stage in ("test", None):
            test_image_names, test_formulas = get_split(
                self.all_formulas,
                self.data_dirname / "im2latex_test_filter.lst",
            )
            self.test_dataset = BaseDataset(
                self.processed_images_dirname,
                image_filenames=test_image_names,
                formulas=test_formulas,
                transform=self.transform["val/test"],
            )

    def collate_fn(self, batch):
        """Pad a list of `(image, formula)` samples into batch tensors.

        Args:
            batch: Samples as returned by `BaseDataset.__getitem__`. Each image is
                indexed as (C, H, W) and written into a single-channel canvas, so it
                is assumed to have one channel.

        Returns:
            A tuple `(padded_images, batched_indices)`:
            `padded_images` is a float tensor (B, 1, max_H, max_W) that is zero
            outside the pasted images; `batched_indices` is a long tensor
            (B, max_length + 2) holding the encoded formulas, padded with the
            tokenizer's pad index.
        """
        images, formulas = zip(*batch)
        B = len(images)
        max_H = max(image.shape[1] for image in images)
        max_W = max(image.shape[2] for image in images)
        max_length = max(len(formula) for formula in formulas)
        padded_images = torch.zeros((B, 1, max_H, max_W))
        # +2 for the <SOS>/<EOS> indices added by `Tokenizer.encode`. Pad with the
        # tokenizer's own pad index rather than assuming that it is 0.
        batched_indices = torch.full((B, max_length + 2), self.tokenizer.pad_index, dtype=torch.long)
        for i in range(B):
            H, W = images[i].shape[1], images[i].shape[2]
            # Each image is pasted at a random offset inside the canvas.
            # NOTE(review): this also applies to validation/test batches, which makes
            # evaluation inputs non-deterministic — confirm that this is intended.
            y, x = random.randint(0, max_H - H), random.randint(0, max_W - W)
            padded_images[i, :, y : y + H, x : x + W] = images[i]
            indices = self.tokenizer.encode(formulas[i])
            batched_indices[i, : len(indices)] = torch.tensor(indices, dtype=torch.long)
        return padded_images, batched_indices

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training set."""
        return DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self) -> DataLoader:
        """Loader over the validation set, in file order."""
        return DataLoader(
            self.val_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self) -> DataLoader:
        """Loader over the test set, in file order."""
        return DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self.collate_fn,
        )

## define model

In [None]:
class PositionalEncoding2D(nn.Module):
    """2-D positional encodings for the feature maps produced by the encoder.

    The first half of the channels encodes the row position and the second half
    the column position, each using the 1-D sinusoidal encoding.

    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.

    Reference:
    https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2021-labs/blob/main/lab9/text_recognizer/models/transformer_util.py
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000) -> None:
        super().__init__()
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        self.d_model = d_model
        # Stored as a buffer: follows the module across devices, is not trained
        self.register_buffer("pe", self.make_pe(d_model, max_h, max_w))  # (d_model, max_h, max_w)

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> Tensor:
        """Build the (d_model, max_h, max_w) positional encoding table."""
        half = d_model // 2

        # Row encodings, repeated along the width axis
        rows = PositionalEncoding1D.make_pe(d_model=half, max_len=max_h)  # (max_h, 1, half)
        rows = rows.permute(2, 0, 1).expand(-1, -1, max_w)  # (half, max_h, max_w)

        # Column encodings, repeated along the height axis
        cols = PositionalEncoding1D.make_pe(d_model=half, max_len=max_w)  # (max_w, 1, half)
        cols = cols.permute(2, 1, 0).expand(-1, max_h, -1)  # (half, max_h, max_w)

        return torch.cat([rows, cols], dim=0)  # (d_model, max_h, max_w)

    def forward(self, x: Tensor) -> Tensor:
        """Add the positional encoding to a batch of feature maps.

        Args:
            x: (B, d_model, H, W)

        Returns:
            (B, d_model, H, W)
        """
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        height, width = x.size(2), x.size(3)
        return x + self.pe[:, :height, :width]  # type: ignore


class PositionalEncoding1D(nn.Module):
    """Classic Attention-is-all-you-need positional encoding.

    Args:
        d_model: Embedding depth of the inputs.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = self.make_pe(d_model, max_len)  # (max_len, 1, d_model)
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> Tensor:
        """Compute positional encoding.

        Args:
            d_model: Embedding depth; both even and odd values are supported.
            max_len: Number of positions to encode.

        Returns:
            (max_len, 1, d_model) tensor with sines in the even channels and
            cosines in the odd channels.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one odd channel fewer than even channels, so
        # drop the surplus column instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.unsqueeze(1)

    def forward(self, x: Tensor) -> Tensor:
        """Add the positional encoding to `x`, then apply dropout.

        Args:
            x: (S, B, d_model), with S no larger than the `max_len` of the table.

        Returns:
            (S, B, d_model)
        """
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        x = x + self.pe[: x.size(0)]  # type: ignore
        return self.dropout(x)

In [None]:
class ResNetTransformer(nn.Module):
    """Image-to-sequence model: truncated ResNet-18 encoder + Transformer decoder.

    Args:
        d_model: Embedding depth shared by the encoder output and the decoder.
        dim_feedforward: Hidden size of the decoder's feed-forward layers.
        nhead: Number of attention heads.
        dropout: Dropout probability of the decoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        max_output_len: Maximum number of output tokens; two extra positions are
            reserved (presumably for the <SOS> and <EOS> tokens).
        sos_index: Index of the start-of-sequence token.
        eos_index: Index of the end-of-sequence token.
        pad_index: Index of the padding token.
        num_classes: Vocabulary size.
    """

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_decoder_layers: int,
        max_output_len: int,
        sos_index: int,
        eos_index: int,
        pad_index: int,
        num_classes: int,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.max_output_len = max_output_len + 2
        self.sos_index = sos_index
        self.eos_index = eos_index
        self.pad_index = pad_index

        # Encoder: ResNet-18 up to and including layer3 (no pretrained weights)
        resnet = torchvision.models.resnet18(weights=None)
        self.backbone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
        )
        # 1x1 convolution projecting the 256 channels of layer3 to d_model
        self.bottleneck = nn.Conv2d(256, self.d_model, 1)
        self.image_positional_encoder = PositionalEncoding2D(self.d_model)

        # Decoder
        self.embedding = nn.Embedding(num_classes, self.d_model)
        # Causal mask; kept as a plain attribute and cast to the right dtype/device in `decode`
        self.y_mask = generate_square_subsequent_mask(self.max_output_len)
        self.word_positional_encoder = PositionalEncoding1D(self.d_model, max_len=self.max_output_len)
        transformer_decoder_layer = nn.TransformerDecoderLayer(self.d_model, nhead, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(transformer_decoder_layer, num_decoder_layers)
        self.fc = nn.Linear(self.d_model, num_classes)

        # It is empirically important to initialize weights properly
        if self.training:
            self._init_weights()

    def _init_weights(self) -> None:
        """Initialize the embedding, output layer and bottleneck weights."""
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-init_range, init_range)

        nn.init.kaiming_normal_(
            self.bottleneck.weight.data,
            a=0,
            mode="fan_out",
            nonlinearity="relu",
        )
        if self.bottleneck.bias is not None:
            _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.bottleneck.weight.data)
            bound = 1 / math.sqrt(fan_out)
            # NOTE(review): `normal_` takes (mean, std), so this draws from
            # N(mean=-bound, std=bound). `uniform_(-bound, bound)` may have been
            # intended — confirm before changing, as it alters initialization.
            nn.init.normal_(self.bottleneck.bias, -bound, bound)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: (B, _E, _H, _W)
            y: (B, Sy) with elements in (0, num_classes - 1)

        Returns:
            (B, num_classes, Sy) logits
        """
        encoded_x = self.encode(x)  # (Sx, B, E)
        output = self.decode(y, encoded_x)  # (Sy, B, num_classes)
        output = output.permute(1, 2, 0)  # (B, num_classes, Sy)
        return output

    def encode(self, x: Tensor) -> Tensor:
        """Encode inputs.

        Args:
            x: (B, C, _H, _W)

        Returns:
            (Sx, B, E)
        """
        # Resnet expects 3 channels but training images are in gray scale
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        # The backbone stops after layer3, so the total stride is 16
        # (conv1, maxpool, layer2 and layer3 each halve the resolution).
        x = self.backbone(x)  # (B, 256, H, W); H ~ _H / 16, W ~ _W / 16
        x = self.bottleneck(x)  # (B, E, H, W)
        x = self.image_positional_encoder(x)  # (B, E, H, W)
        x = x.flatten(start_dim=2)  # (B, E, H * W)
        x = x.permute(2, 0, 1)  # (Sx, B, E); Sx = H * W
        return x

    def decode(self, y: Tensor, encoded_x: Tensor) -> Tensor:
        """Decode encoded inputs with teacher-forcing.

        Args:
            encoded_x: (Sx, B, E)
            y: (B, Sy) with elements in (0, num_classes - 1)

        Returns:
            (Sy, B, num_classes) logits
        """
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.d_model)  # (Sy, B, E)
        y = self.word_positional_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)  # (Sy, Sy)
        output = self.transformer_decoder(y, encoded_x, y_mask)  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, num_classes)
        return output

    def predict(self, x: Tensor) -> Tensor:
        """Make predictions at inference time with greedy decoding.

        Args:
            x: (B, C, H, W). Input images.

        Returns:
            (B, max_output_len) with elements in (0, num_classes - 1). Every
            position after the first end token is set to the padding index.
        """
        B = x.shape[0]
        S = self.max_output_len

        encoded_x = self.encode(x)  # (Sx, B, E)

        # Keep all decoding state on the input's device to avoid per-step transfers
        output_indices = torch.full((B, S), self.pad_index, dtype=torch.long, device=x.device)
        output_indices[:, 0] = self.sos_index
        has_ended = torch.zeros(B, dtype=torch.bool, device=x.device)

        for Sy in range(1, S):
            y = output_indices[:, :Sy]  # (B, Sy)
            logits = self.decode(y, encoded_x)  # (Sy, B, num_classes)
            # Select the token with the highest conditional probability
            output = torch.argmax(logits, dim=-1)  # (Sy, B)
            output_indices[:, Sy] = output[-1]  # Set the last output token

            # Early stopping of prediction loop to speed up prediction
            has_ended |= output_indices[:, Sy] == self.eos_index
            if torch.all(has_ended):
                break

        # Set all tokens after end token to be padding. `find_first` returns S for
        # rows without an end token, so those rows are left untouched.
        eos_positions = find_first(output_indices, self.eos_index)  # (B,)
        positions = torch.arange(S, device=output_indices.device).unsqueeze(0)  # (1, S)
        output_indices[positions > eos_positions.unsqueeze(1)] = self.pad_index

        return output_indices


def generate_square_subsequent_mask(size: int) -> Tensor:
    """Generate a triangular (size, size) mask.

    Entries on and below the main diagonal are 0.0; entries above it are -inf.
    """
    blocked = torch.full((size, size), float("-inf"))
    # Keep -inf strictly above the diagonal; triu zeroes everything else
    return torch.triu(blocked, diagonal=1)


def find_first(x: Tensor, element: Union[int, float], dim: int = 1) -> Tensor:
    """Find the first occurence of element in x along a given dimension.

    Args:
        x: The input tensor to be searched.
        element: The number to look for.
        dim: The dimension to reduce.

    Returns:
        Indices of the first occurence of the element in x. If not found, return the
        length of x along dim.

    Usage:
        >>> find_first(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1]]), 3)
        tensor([2, 1, 3])

    Reference:
        https://discuss.pytorch.org/t/first-nonzero-index/24769/9

        An element located at index 0 is reported as 0; only slices that do not
        contain the element at all are reported as the length of x along dim.
    """
    mask = x == element
    # True only at the first match of every slice: the running count of matches
    # equals 1 there and the position itself is a match.
    first_match = (mask.cumsum(dim) == 1) & mask
    found, indices = first_match.max(dim)
    # Rely on `found` alone, so the result does not depend on which index `max`
    # reports for a slice that is entirely False.
    indices[~found] = x.shape[dim]
    return indices

In [None]:
class CharacterErrorRate(Metric):
    """Mean normalized edit distance between predicted and target index sequences.

    For each sample, indices in `ignore_indices` are removed from both sequences
    and the edit distance is divided by the length of the longer sequence. A
    sample whose filtered sequences are both empty adds no error but still counts
    towards the total.

    Args:
        ignore_indices: Token indices removed before the comparison.
        *args: Forwarded to `Metric.__init__`.
    """

    def __init__(self, ignore_indices: Set[int], *args):
        super().__init__(*args)
        self.ignore_indices = ignore_indices
        # Accumulators; summed across processes in distributed runs
        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.error: Tensor
        self.total: Tensor

    def update(self, preds, targets):
        """Accumulate the error of one batch; rows of `preds` and `targets` are paired."""

        def keep(sequence):
            return [token for token in sequence.tolist() if token not in self.ignore_indices]

        batch_size = preds.shape[0]
        for row in range(batch_size):
            pred, target = keep(preds[row]), keep(targets[row])
            longest = max(len(pred), len(target))
            if longest == 0:
                continue
            self.error += editdistance.distance(pred, target) / longest
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the accumulated error divided by the number of samples seen."""
        return self.error / self.total


class ExactMatchScore(Metric):
    """Fraction of samples whose prediction equals the target exactly.

    Indices in `ignore_indices` are removed from both sequences before they are
    compared.

    Args:
        ignore_indices: Token indices removed before the comparison.
        *args: Forwarded to `Metric.__init__`.
    """

    def __init__(self, ignore_indices: Set[int], *args):
        super().__init__(*args)
        self.ignore_indices = ignore_indices
        # Accumulators; summed across processes in distributed runs
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.correct: Tensor
        self.total: Tensor

    def update(self, preds, targets):
        """Count the exact matches of one batch; rows of `preds` and `targets` are paired."""

        def keep(sequence):
            return [token for token in sequence.tolist() if token not in self.ignore_indices]

        batch_size = preds.shape[0]
        for row in range(batch_size):
            if keep(preds[row]) == keep(targets[row]):
                self.correct += 1
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the number of exact matches divided by the number of samples seen."""
        return self.correct / self.total


class BLEUScore(Metric):
    """Mean sentence-level BLEU between predicted and target index sequences.

    Indices in `ignore_indices` are removed first and the remaining indices are
    converted to strings for NLTK's `sentence_bleu` (smoothing method 1). A sample
    whose filtered prediction or target is empty adds no score but still counts
    towards the total.

    Args:
        ignore_indices: Token indices removed before scoring.
        *args: Forwarded to `Metric.__init__`.
    """

    def __init__(self, ignore_indices: Set[int], *args):
        super().__init__(*args)
        self.ignore_indices = ignore_indices
        # Accumulators; summed across processes in distributed runs
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.score: Tensor
        self.total: Tensor
        self.smoothing_function = SmoothingFunction().method1

    def update(self, preds, targets):
        """Accumulate the BLEU scores of one batch; rows of `preds` and `targets` are paired."""

        def as_words(sequence):
            # Filter the ignored indices, then turn each token id into a string
            return [str(token) for token in sequence.tolist() if token not in self.ignore_indices]

        batch_size = preds.shape[0]
        for row in range(batch_size):
            pred_str, target_str = as_words(preds[row]), as_words(targets[row])
            # BLEU is only computed when both sequences are non-empty
            if not pred_str or not target_str:
                continue
            bleu = sentence_bleu([target_str], pred_str, smoothing_function=self.smoothing_function)
            self.score += torch.tensor(bleu)
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the accumulated BLEU divided by the number of samples seen."""
        return self.score / self.total


class EditDistance(Metric):
    """Mean raw (unnormalized) edit distance between predictions and targets.

    Indices in `ignore_indices` are removed from both sequences before the
    distance is computed.

    Args:
        ignore_indices: Token indices removed before the comparison.
        *args: Forwarded to `Metric.__init__`.
    """

    def __init__(self, ignore_indices: Set[int], *args):
        super().__init__(*args)
        self.ignore_indices = ignore_indices
        # Accumulators; summed across processes in distributed runs
        self.add_state("distance", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.distance: Tensor
        self.total: Tensor

    def update(self, preds, targets):
        """Accumulate the edit distances of one batch; rows of `preds` and `targets` are paired."""

        def keep(sequence):
            return [token for token in sequence.tolist() if token not in self.ignore_indices]

        batch_size = preds.shape[0]
        for row in range(batch_size):
            # Raw edit distance between the filtered sequences
            self.distance += editdistance.distance(keep(preds[row]), keep(targets[row]))
        self.total += batch_size

    def compute(self) -> Tensor:
        """Return the accumulated distance divided by the number of samples seen."""
        return self.distance / self.total

In [None]:
class LitResNetTransformer(LightningModule):
    """Lightning wrapper that trains and evaluates a ``ResNetTransformer``.

    The module loads a ``Tokenizer`` from ``data/vocab.json``, builds the
    model from the given hyperparameters, uses cross-entropy with the padding
    index ignored, and keeps separate CER / exact-match / BLEU / edit-distance
    metric objects for the train, validation and test stages.

    Args:
        d_model: Forwarded to ``ResNetTransformer``.
        dim_feedforward: Forwarded to ``ResNetTransformer``.
        nhead: Forwarded to ``ResNetTransformer``.
        dropout: Forwarded to ``ResNetTransformer``.
        num_decoder_layers: Forwarded to ``ResNetTransformer``.
        max_output_len: Forwarded to ``ResNetTransformer``.
        lr: Learning rate passed to ``AdamW``.
        weight_decay: Weight decay passed to ``AdamW``.
        milestones: Milestones passed to ``MultiStepLR``.
        gamma: Decay factor passed to ``MultiStepLR``.
    """

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_decoder_layers: int,
        max_output_len: int,
        lr: float = 0.001,
        weight_decay: float = 0.0001,
        milestones: List[int] = [5],
        gamma: float = 0.1,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay
        # Copy so the instance never aliases the shared mutable default list.
        self.milestones = list(milestones)
        self.gamma = gamma

        # Per-epoch buffers: step losses for train/val, predictions for test.
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []

        try:
            vocab_file = Path(__file__).resolve().parents[1] / "data" / "vocab.json"
        except NameError:
            # `__file__` is not defined inside a notebook kernel. Fall back to
            # the `data/` folder under the working directory.
            # NOTE(review): assumes vocab.json is written there — confirm
            # against the data module.
            vocab_file = Path.cwd() / "data" / "vocab.json"
        self.tokenizer = Tokenizer.load(vocab_file)
        self.model = ResNetTransformer(
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            nhead=nhead,
            dropout=dropout,
            num_decoder_layers=num_decoder_layers,
            max_output_len=max_output_len,
            sos_index=self.tokenizer.sos_index,
            eos_index=self.tokenizer.eos_index,
            pad_index=self.tokenizer.pad_index,
            num_classes=len(self.tokenizer),
        )
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_index)

        # Training metrics
        self.train_cer = CharacterErrorRate(self.tokenizer.ignore_indices)
        self.train_exact_match = ExactMatchScore(self.tokenizer.ignore_indices)
        self.train_bleu = BLEUScore(self.tokenizer.ignore_indices)
        self.train_edit_distance = EditDistance(self.tokenizer.ignore_indices)

        # Validation metrics
        self.val_cer = CharacterErrorRate(self.tokenizer.ignore_indices)
        self.val_exact_match = ExactMatchScore(self.tokenizer.ignore_indices)
        self.val_bleu = BLEUScore(self.tokenizer.ignore_indices)
        self.val_edit_distance = EditDistance(self.tokenizer.ignore_indices)

        # Test metrics
        self.test_cer = CharacterErrorRate(self.tokenizer.ignore_indices)
        self.test_exact_match = ExactMatchScore(self.tokenizer.ignore_indices)
        self.test_bleu = BLEUScore(self.tokenizer.ignore_indices)
        self.test_edit_distance = EditDistance(self.tokenizer.ignore_indices)

    def training_step(self, batch, batch_idx):
        """Compute the teacher-forced loss; every 100th batch also update train metrics.

        The model receives the targets without their last token and the loss
        is taken against the targets without their first token.
        """
        imgs, targets = batch
        logits = self.model(imgs, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])

        # Log loss for each step (for progress bar and step-level tracking)
        self.log("train/loss_step", loss, on_step=True, on_epoch=False, prog_bar=True)

        # Store loss for epoch-level logging
        self.training_step_outputs.append(loss.detach())

        # Decoding is only run on every 100th batch to limit the slowdown.
        if batch_idx % 100 == 0:
            with torch.no_grad():  # Use no_grad to save memory
                preds = self.model.predict(imgs)
                self.train_cer.update(preds, targets)
                self.train_exact_match.update(preds, targets)
                self.train_bleu.update(preds, targets)
                self.train_edit_distance.update(preds, targets)

        return loss

    def on_train_epoch_end(self):
        """Log the mean training loss and the accumulated training metrics, then reset them."""
        if len(self.training_step_outputs) > 0:
            epoch_mean_loss = torch.stack(self.training_step_outputs).mean()
            self.log("train/loss_epoch", epoch_mean_loss, prog_bar=True)
            # Clear the list for the next epoch
            self.training_step_outputs.clear()

        # Only log if at least one sampled batch contributed to the metrics.
        if self.train_cer.total > 0:
            self.log("train/cer", self.train_cer.compute(), prog_bar=True)
            self.log("train/exact_match", self.train_exact_match.compute())
            self.log("train/bleu", self.train_bleu.compute())
            self.log("train/edit_distance", self.train_edit_distance.compute())

            # Reset metrics for next epoch
            self.train_cer.reset()
            self.train_exact_match.reset()
            self.train_bleu.reset()
            self.train_edit_distance.reset()

    def validation_step(self, batch, batch_idx):  # batch_idx is required by PyTorch Lightning
        """Compute the validation loss and update the validation metrics for one batch."""
        imgs, targets = batch
        logits = self.model(imgs, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])

        # Store loss for epoch-level logging
        self.validation_step_outputs.append(loss.detach())

        # Log step-level loss for debugging
        self.log("val/loss_step", loss, on_step=True, on_epoch=False, prog_bar=False)

        with torch.no_grad():  # Use no_grad to save memory
            preds = self.model.predict(imgs)
            # Update metrics (don't log yet, will be logged at epoch end)
            self.val_cer.update(preds, targets)
            self.val_exact_match.update(preds, targets)
            self.val_bleu.update(preds, targets)
            self.val_edit_distance.update(preds, targets)

        return loss

    def on_validation_epoch_end(self):
        """Log the mean validation loss and the accumulated validation metrics, then reset them."""
        # Stays None when no validation step stored a loss, so the summary
        # print below never touches an unbound name.
        epoch_mean_loss = None
        if len(self.validation_step_outputs) > 0:
            epoch_mean_loss = torch.stack(self.validation_step_outputs).mean()
            self.log("val/loss_epoch", epoch_mean_loss, prog_bar=True, sync_dist=True)
            # Clear the list for the next epoch
            self.validation_step_outputs.clear()

        if self.val_cer.total > 0:
            val_cer = self.val_cer.compute()
            val_exact_match = self.val_exact_match.compute()
            val_bleu = self.val_bleu.compute()
            val_edit_distance = self.val_edit_distance.compute()

            self.log("val/cer", val_cer, prog_bar=True, sync_dist=True)
            self.log("val/exact_match", val_exact_match, sync_dist=True)
            self.log("val/bleu", val_bleu, sync_dist=True)
            self.log("val/edit_distance", val_edit_distance, sync_dist=True)

            # Reset metrics for next epoch
            self.val_cer.reset()
            self.val_exact_match.reset()
            self.val_bleu.reset()
            self.val_edit_distance.reset()

            # Print a message to confirm metrics were logged
            loss_text = "n/a" if epoch_mean_loss is None else f"{epoch_mean_loss:.4f}"
            print(f"Validation metrics logged: loss_epoch={loss_text}, cer={val_cer:.4f}, exact_match={val_exact_match:.4f}, bleu={val_bleu:.4f}, edit_distance={val_edit_distance:.4f}")

    def test_step(self, batch, batch_idx):  # batch_idx is required by PyTorch Lightning
        """Decode one test batch, log its metrics, and keep the predictions for the epoch-end dump."""
        imgs, targets = batch
        with torch.no_grad():  # Use no_grad to save memory
            preds = self.model.predict(imgs)
            test_cer = self.test_cer(preds, targets)
            test_exact_match = self.test_exact_match(preds, targets)
            test_bleu = self.test_bleu(preds, targets)
            test_edit_distance = self.test_edit_distance(preds, targets)

            self.log("test/cer", test_cer)
            self.log("test/exact_match", test_exact_match)
            self.log("test/bleu", test_bleu)
            self.log("test/edit_distance", test_edit_distance)

            # on_test_epoch_end receives no step outputs, so buffer them here
            # (on CPU, to avoid holding device memory for the whole test run).
            self.test_step_outputs.append(preds.detach().cpu())
            return preds

    def on_test_epoch_end(self):
        """Write every buffered test prediction to ``test_predictions.txt``, one per line.

        Tokens are joined with single spaces; the newline is joined as a final
        token, so each line ends with ``" \\n"``.
        """
        with open("test_predictions.txt", "w") as f:
            for preds in self.test_step_outputs:
                for pred in preds:
                    decoded = self.tokenizer.decode(pred.tolist())
                    f.write(" ".join(list(decoded) + ["\n"]))
        # Release the buffer so a second test run starts empty.
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        """Return AdamW over the model parameters together with a MultiStepLR schedule."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.milestones, gamma=self.gamma)
        return [optimizer], [scheduler]


In [None]:
# Single source of truth for the run: each sub-dict is unpacked into the
# matching constructor (Trainer, ModelCheckpoint, EarlyStopping, Im2Latex,
# LitResNetTransformer, WandbLogger) in the setup cell.
config = {
    "seed": 1234,
    "trainer": {
        "accelerator": "gpu",
        "devices": 1,
        "overfit_batches": 0.0,
        "check_val_every_n_epoch": 1,
        "fast_dev_run": False,
        "max_epochs": 15,
        "min_epochs": 1,
        "num_sanity_val_steps": 0,
        "enable_checkpointing": True,
        "log_every_n_steps": 100,
        "enable_progress_bar": True,
        "enable_model_summary": True,
    },
    "callbacks": {
        "model_checkpoint": {
            "dirpath": "checkpoints",
            "save_top_k": 2,
            "save_weights_only": True,
            "mode": "min",
            "monitor": "val/loss_epoch",
            "filename": "epoch_{epoch:02d}_valloss_{val/loss_epoch:.2f}",
            # The monitored metric name contains "/". With metric-name
            # auto-insertion enabled the "/" would end up in the file name and
            # create an extra sub-folder, so only the formatted values are kept.
            "auto_insert_metric_name": False,
        },
        "early_stopping": {
            "patience": 3,
            "mode": "min",
            "monitor": "val/loss_epoch",
            "min_delta": 0.001,
        },
    },
    "data": {
        "batch_size": 8,
        "num_workers": 2,
        "pin_memory": True,
    },
    "lit_model": {
        "lr": 0.001,
        "weight_decay": 0.0001,
        "milestones": [10],
        "gamma": 0.5,
        "d_model": 64,
        "dim_feedforward": 128,
        "nhead": 2,
        "dropout": 0.3,
        "num_decoder_layers": 2,
        "max_output_len": 150,
    },
    "logger": {
        "project": "image-to-latex",
        "log_model": True,
        "offline": False,
        "name": None,
        "save_dir": "wandb",
        "version": None,
        "prefix": "",
        "job_type": "train",
    },
}


In [None]:
from argparse import Namespace

def dict_to_namespace(d):
    """Recursively turn a (possibly nested) dict into an ``argparse.Namespace``.

    Values that are themselves dicts become nested namespaces; every other
    value (including lists) is attached unchanged.
    """
    namespace = Namespace()
    for key in d:
        value = d[key]
        if isinstance(value, dict):
            value = dict_to_namespace(value)
        setattr(namespace, key, value)
    return namespace


In [None]:
import wandb
from pytorch_lightning import seed_everything

# Apply the configured seed before any data split, weight initialisation or
# dataloader worker is created, so repeated runs start from the same state.
seed_everything(config["seed"], workers=True)

wandb.login()

# Dataloader & model
data_module = Im2Latex(**config["data"])
data_module.setup()

lit_model = LitResNetTransformer(**config["lit_model"])

# Callbacks: keep the best checkpoints and stop early, both on the monitored validation loss
callbacks = [
    ModelCheckpoint(**config["callbacks"]["model_checkpoint"]),
    EarlyStopping(**config["callbacks"]["early_stopping"])
]

# Logger
logger = WandbLogger(**config["logger"])

# Trainer
trainer = Trainer(
    **config["trainer"],
    callbacks=callbacks,
    logger=logger,
)

# Record the full run configuration with the logger
if trainer.logger:
    trainer.logger.log_hyperparams(dict_to_namespace(config))

In [None]:
# Train on the data module's train/val splits
trainer.fit(model=lit_model, datamodule=data_module)

# Evaluate the in-memory model on the test split
trainer.test(model=lit_model, datamodule=data_module)

In [None]:
# === Inference ===

# Determine device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ====== Paths (enter manually or assign from a variable) ======
image_path = ""
checkpoint_path = ""

# Fail early with a clear message if either path is missing
# (the empty placeholders above fail here until they are filled in).
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")

if not os.path.exists(image_path):
    raise FileNotFoundError(f"Image file not found at: {image_path}")

# Load model. map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = LitResNetTransformer.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()
model.to(device)
model.freeze()

# Transform
transform = ToTensorV2()

# Load & preprocess image: grayscale -> tensor -> add batch dimension
image = Image.open(image_path).convert("L")
image_tensor = transform(image=np.array(image))["image"]
image_tensor = image_tensor.unsqueeze(0).float().to(device)

# Inference
with torch.no_grad():
    # .cpu() is a no-op for tensors already on the CPU, so no device check is needed.
    pred = model.model.predict(image_tensor)[0].cpu()
    decoded = model.tokenizer.decode(pred.tolist())
    # Join with spaces: concatenating tokens directly would fuse a LaTeX
    # command with the token after it (e.g. "\alpha" + "x" -> "\alphax").
    decoded_str = " ".join(decoded)

print("=== LaTeX Prediction ===")
print(decoded_str)
