In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as T

import numpy as np
import pandas as pd

from PIL import Image
from typing import Union
from tqdm import tqdm
from datetime import datetime
from collections import Counter

import warnings
import json
import re
import os

In [21]:
class ConvBlock(nn.Module):
    """
    2-D convolution followed by a LeakyReLU activation.

    With the default arguments (kernel_size=3, stride=1, padding=1) the spatial
    size of the input is preserved; other combinations change it according to
    the usual ``nn.Conv2d`` output-size formula (e.g. ``padding=0`` with a 3x3
    kernel shrinks H and W by 2).

    Parameters
    ----------
    in_channels: int
        Number of input channels.
    out_channels: int
        Number of output channels.
    kernel_size: int
        Convolution kernel size (default 3).
    stride: int
        Convolution stride (default 1).
    padding: int
        Zero padding added on every side (default 1).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.relu = nn.LeakyReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution then LeakyReLU: (B, C_in, H, W) -> (B, C_out, H', W')."""
        return self.relu(self.conv(x))


class CNNLSTM(nn.Module):
    """
    CNN encoder + single-layer LSTM decoder for image-to-LaTeX generation.

    The encoder feature map is averaged over all spatial positions and projected
    by ``hidden0_fc`` to the initial LSTM hidden state (the initial cell state is
    zero). Decoding is greedy: at every step the embedding of the previously
    predicted token is concatenated with the previous LSTM output and fed to the
    LSTM; ``fc_out`` maps the LSTM output to vocabulary logits.

    Parameters
    ----------
    emb_dim: int
        Size of the token embeddings.
    hidden_dim: int
        Number of encoder output channels and LSTM hidden size.
    dropout_prob: float
        Dropout applied to the LSTM output before ``fc_out`` in ``forward``.
    vocab_len: int
        Vocabulary size including the 4 special tokens; must be greater than 4.
    max_output_length: int
        Number of time steps in the output; position 0 is the <SOS> slot.
    device: str
        Value reported by the ``device`` property. Tensors created internally
        are placed on the input's device.
    """

    def __init__(
        self,
        emb_dim: int = 80,
        hidden_dim: int = 512,
        dropout_prob: float = 0,
        vocab_len: int = None,
        max_output_length: int = 512,
        device: str = "cpu",
    ) -> None:
        super().__init__()

        # the 4 special tokens plus at least one regular token must be present
        assert vocab_len is not None and vocab_len > 4, (
            "Vocabulary length must be greater than 4 (4 special tokens + regular tokens)"
        )

        self._encoder_out = hidden_dim
        self._hidden_dim = hidden_dim
        self._vocab_len = vocab_len
        # standard vocabulary with special tokens
        self._vocab = {
            "<UNK>": 0,
            "<SOS>": 1,
            "<PAD>": 2,
            "<EOS>": 3,
        }
        # output dim is set to vocab_len
        self._output_dim = self._vocab_len
        self._emb_dim = emb_dim
        self._max_len = max_output_length
        self._device = device

        self.encoder = nn.Sequential(
            ConvBlock(in_channels=1, out_channels=64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBlock(in_channels=64, out_channels=128),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBlock(in_channels=128, out_channels=256),
            ConvBlock(in_channels=256, out_channels=256),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
            ConvBlock(
                in_channels=256, out_channels=self._encoder_out, padding=0
            ),
        )

        # input = [previous token embedding, previous LSTM output]
        self.decoder = nn.LSTM(
            input_size=self._emb_dim + self._encoder_out,
            hidden_size=self._hidden_dim,
            num_layers=1,
            batch_first=True,
        )

        self.dropout = nn.Dropout(p=dropout_prob)
        self.embedding = nn.Embedding(
            self._output_dim,
            self._emb_dim,  # padding_idx=self._vocab["<PAD>"]
        )
        self.hidden0_fc = nn.Sequential(
            nn.Linear(self._encoder_out, self._hidden_dim)
        )
        self.fc_out = nn.Sequential(
            nn.Linear(self._hidden_dim, self._output_dim)
        )

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run the CNN and average over spatial positions: (B, 1, H, W) -> (B, encoder_out)."""
        features = self.encoder(x)  # (B, C, H', W')
        return features.flatten(2).mean(dim=2)

    def _init_state(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Initial (hidden, cell, previous output) for decoding, on the input's device."""
        batch_size = x.size(0)
        encoded_img = self._encode(x)
        # LSTM states are (num_layers=1, B, hidden_dim)
        hidden = self.hidden0_fc(encoded_img).unsqueeze(0)
        cell = torch.zeros_like(hidden)
        output = encoded_img.new_zeros(batch_size, self._hidden_dim)
        return hidden, cell, output

    def _start_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """(B,) tensor filled with the <SOS> index, on the input's device."""
        return torch.full(
            (x.size(0),), self._vocab["<SOS>"], dtype=torch.long, device=x.device
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Greedy decoding with gradients (used for training).

        Parameters
        ----------
        x: torch.Tensor
            Batch of single-channel images, shape (B, 1, H, W).

        Returns
        -------
        torch.Tensor
            Logits of shape (B, max_len, vocab_len). Position 0 (the <SOS> slot)
            is not predicted and holds the constant value of the <PAD> index.
        """
        batch_size = x.size(0)
        hidden, cell, output = self._init_state(x)
        input_token = self._start_tokens(x)

        # collect per-step logits and concatenate once instead of writing in
        # place into a tensor that requires grad
        step_logits = [
            x.new_full(
                (batch_size, 1, self._output_dim), float(self._vocab["<PAD>"])
            )
        ]
        for _ in range(1, self._max_len):
            hidden, cell, output = self.decode(
                hidden=hidden,
                cell=cell,
                out_t=output,
                input_token=input_token,
            )
            logit = self.fc_out(self.dropout(output))  # (B, vocab_len)
            step_logits.append(logit.unsqueeze(1))
            input_token = logit.argmax(dim=1)

        return torch.cat(step_logits, dim=1)

    def decode(
        self,
        hidden: torch.Tensor,
        cell: torch.Tensor,
        out_t: torch.Tensor,
        input_token: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run a single LSTM step.

        Parameters
        ----------
        hidden, cell: torch.Tensor
            LSTM state of shape (1, B, hidden_dim); 2-D (B, hidden_dim) states
            are accepted and get a leading layer dimension.
        out_t: torch.Tensor
            Previous LSTM output, shape (B, hidden_dim).
        input_token: torch.Tensor
            Previous token indices, shape (B,) or (B, 1).

        Returns
        -------
        (hidden_t, cell_t, out_t)
            New states (1, B, hidden_dim) and LSTM output (B, hidden_dim).
        """
        if hidden.dim() == 2:
            hidden = hidden.unsqueeze(0)
        if cell.dim() == 2:
            cell = cell.unsqueeze(0)

        prev_y = self.embedding(input_token.reshape(-1))  # (B, emb_dim)
        # a batch_first LSTM needs (B, seq_len=1, features); a 2-D tensor would be
        # interpreted as an unbatched sequence of length B
        input_t = torch.cat([prev_y, out_t], dim=1).unsqueeze(1)
        step_out, (hidden_t, cell_t) = self.decoder(input_t, (hidden, cell))

        return hidden_t, cell_t, step_out.squeeze(1)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Greedy decoding without gradient tracking.

        Stops early once every sequence in the batch has produced <EOS>; steps
        after the stop (and position 0) keep the <PAD>-index fill value.

        Returns
        -------
        torch.Tensor
            Logits of shape (B, max_len, vocab_len).
        """
        batch_size = x.size(0)
        hidden, cell, output = self._init_state(x)
        input_token = self._start_tokens(x)

        outputs = x.new_full(
            (batch_size, self._max_len, self._output_dim),
            float(self._vocab["<PAD>"]),
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=x.device)

        for t in range(1, self._max_len):
            hidden, cell, output = self.decode(
                hidden=hidden,
                cell=cell,
                out_t=output,
                input_token=input_token,
            )
            logit = self.fc_out(output)
            outputs[:, t] = logit
            input_token = logit.argmax(dim=1)

            finished |= input_token == self._vocab["<EOS>"]
            if bool(finished.all()):
                break

        return outputs

    @property
    def output_dim(self):
        return self._output_dim

    @property
    def max_len(self):
        return self._max_len

    @property
    def device(self):
        return self._device

    @property
    def img_size(self):
        # never assigned in __init__; None until a caller sets ``_img_size``
        return getattr(self, "_img_size", None)


In [22]:
class LaTEXTokenizer:
    """
    Whitespace tokenizer for LaTeX expressions (vocabulary built from IM2LATEX-100k).

    ---
    Parameters
    ---
    token2id: dict[str, int]
        Mapping from tokens to indices. Must contain "<UNK>", "<SOS>", "<PAD>"
        and "<EOS>" for ``tokenize`` to work.
    """

    def __init__(self, token2id: dict[str, int]) -> None:
        self._token2id = {k: int(v) for k, v in token2id.items()}
        self._id2token = {int(v): k for k, v in token2id.items()}

    def tokenize(
        self,
        x: list[str],
        return_tensors: bool = True,
        pad: bool = True,
        max_len: int = 512,
    ) -> Union[torch.Tensor, list[list[int]]]:
        """
        Tokenize a list of sentences.

        Dots are surrounded by spaces and every digit becomes its own token
        before splitting on whitespace. Each sequence is wrapped as
        <SOS> ... <EOS>; if it would exceed ``max_len`` the formula body is
        truncated so that <EOS> stays the last token.

        ---
        Parameters
        ---
        x: list[str]
            Input list of sentences.
        return_tensors: bool
            Return a ``torch.long`` tensor instead of nested lists.
        pad: bool
            Right-pad every sequence with <PAD> up to ``max_len``.
        max_len: int
            Maximum sequence length including <SOS> and <EOS>.

        ---
        Returns
        ---
        torch.Tensor | list[list[int]]
            Token indices; shape (B, MAX_LEN) when padded.
        """
        unk_id = self._token2id["<UNK>"]
        sos_id = self._token2id["<SOS>"]
        eos_id = self._token2id["<EOS>"]
        pad_id = self._token2id["<PAD>"]

        # separate dots, then separate digits
        sentences = [s.replace(".", " . ") for s in x]
        sentences = [re.sub(r"(\d)", r" \1", s) for s in sentences]
        ids = [
            [self._token2id.get(token, unk_id) for token in s.strip().split()]
            for s in sentences
        ]

        if any(unk_id in s for s in ids):
            warnings.warn(
                "Got unknown token. May affect final result",
            )

        # truncate the body (not the wrapped sequence) so <EOS> is never lost
        body_len = max(max_len - 2, 0)
        ids = [([sos_id] + s[:body_len] + [eos_id])[:max_len] for s in ids]

        if pad:
            ids = [s + [pad_id] * (max_len - len(s)) for s in ids]

        if return_tensors:
            # integer dtype: these are embedding indices
            return torch.tensor(ids, dtype=torch.long)

        return ids

    def decode(
        self,
        x: Union[torch.Tensor, np.ndarray, list[int]],
    ) -> list[str]:
        """
        Decode index sequences into strings, dropping <SOS>, <PAD> and <EOS>.

        ---
        Parameters
        ---
        x: Union[torch.Tensor, np.ndarray, list[int]]
            A sequence of index sequences, or a single flat sequence of indices.

        ---
        Returns
        ---
        list[str]
            One decoded string per input sequence (ready to be displayed as LaTeX).
        """
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().tolist()
        elif isinstance(x, np.ndarray):
            x = x.tolist()

        # a single flat sequence of ids is treated as a batch of one
        if len(x) > 0 and isinstance(x[0], (int, float, np.integer, np.floating)):
            x = [x]

        skip = set(self.special_tokens.values())
        decoded = []
        for seq in x:
            words = [self._id2token[i] for i in seq if i not in skip]
            decoded.append(re.sub(r" +", " ", " ".join(words)).strip())

        return decoded

    @property
    def special_tokens(self) -> dict[str, int]:
        """
        Get mapping (str2int) of the special tokens skipped by ``decode``.
        """
        tokens = ["<SOS>", "<PAD>", "<EOS>"]
        return {k: self.token2id.get(k, 3) for k in tokens}

    @property
    def id2token(self) -> dict[int, str]:
        """
        Get vocabulary (int2str)
        """
        return self._id2token

    @property
    def token2id(self) -> dict[str, int]:
        """
        Get vocabulary (str2int)
        """
        return self._token2id


In [23]:
def read_input(path: str, col_name: str = None) -> pd.DataFrame:
    """
    Read a formulas file and expose its formula column as ``"formula"``.

    ``.json`` files are loaded with ``json.load`` and returned as-is
    (NOTE(review): not converted to a DataFrame, ``col_name`` is ignored).
    Any other extension ``ext`` is read with ``pandas.read_<ext>``.

    Parameters
    ----------
    path: str
        Path to the input file.
    col_name: str, optional
        Column to rename to ``"formula"``. If omitted the table must have
        exactly one column.

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    ValueError
        If pandas has no reader for the extension, if ``col_name`` is omitted
        and the table has more than one column, or if ``col_name`` is missing.
    """
    path = path.strip()
    ext = os.path.splitext(path)[1].lstrip(".").lower()

    if ext == "json":
        with open(path, "r") as f:
            return json.load(f)

    # NOTE: extensions such as "pickle" would unpickle the file — trusted inputs only
    read_file = getattr(pd, f"read_{ext}", None) if ext else None
    if read_file is None:
        raise ValueError(f"Unsupported file extension for {path!r}")

    df = read_file(path)
    print(f"{df.shape=}")

    if not col_name:
        if df.shape[1] != 1:
            raise ValueError(
                "DataFrame should have only one column or column to be used must be specified"
            )
        col_name = df.columns.to_list()[0]
    elif col_name not in df.columns:
        raise ValueError(f"Column {col_name!r} not found in {path!r}")

    return df.rename({col_name: "formula"}, axis=1)


def save_output(path: str, obj: object) -> None:
    """
    Persist ``obj`` to ``path``.

    Paths ending in ``.json`` are written with ``json.dump`` (4-space indent).
    For any other extension ``ext`` the object's own ``to_<ext>`` method
    (e.g. ``DataFrame.to_csv``) is called with ``index=False``.
    """
    if not path.endswith(".json"):
        # pick the pandas-style writer from the extension, ie to_csv
        writer = getattr(obj, "to_" + path.strip().split(".")[-1])
        writer(path, index=False)
        return

    with open(path, "w") as handle:
        json.dump(obj, handle, indent=4)


def prepare_csv_array(row: str) -> list[int]:
    """
    Parse a list of ints serialized as text, e.g. ``"[1, 2, 3]" -> [1, 2, 3]``.

    All square brackets are ignored (so ``"[[1, 2]]"`` also works). Empty
    items are skipped, so ``"[]"`` yields ``[]`` instead of failing on ``int("")``.
    """
    body = row.replace("[", "").replace("]", "")
    return [int(item) for item in body.split(",") if item.strip()]

In [38]:
class IM2LaTEX100K(Dataset):
    """
    Dataset of (formula image, token ids) pairs.

    Parameters
    ----------
    data: pd.DataFrame
        Must contain an "image" column (file name relative to ``image_folder``)
        and a "formula" column whose cells hold a list whose first element is
        the list of token ids (the layout produced by ``preprocess_function``).
    image_folder: str
        Directory containing the images.
    vocab_len: int
        Stored on the instance; not used when loading samples.
    transform: T.Compose | None
        Applied to the PIL image; when None a copy of the PIL image is returned.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        image_folder: str,
        vocab_len: int,
        transform: Union[T.Compose, None],
    ):
        super().__init__()
        self._file = ""
        self._data = data
        self._transform = transform
        self._vocab_len = vocab_len
        self._image_folder = image_folder

    def __getitem__(self, index) -> tuple[torch.Tensor]:
        row = self._data.iloc[index]
        img_path = os.path.join(self._image_folder, row["image"])

        # the context manager releases the file handle; Image.open is lazy,
        # so load/transform (or copy) before the file is closed
        with Image.open(img_path) as img:
            img.load()
            if self._transform is not None:
                sample = self._transform(img)
            else:
                sample = img.copy()

        tokens = torch.tensor(row["formula"][0], dtype=torch.long)

        return sample, tokens

    def __len__(self) -> int:
        return self._data.shape[0]

    def __repr__(self) -> str:
        return f"IM2LaTEXDataset(file={self._file}, image_folder={self._image_folder})"

    def __str__(self) -> str:
        return self.__repr__()


In [39]:
def make_vocab_func(
    input_file: str, col_name: str, add_special: bool,
) -> tuple[dict[str, int], dict[int, str]]:
    """
    Build the token vocabulary from a formulas file.

    Formulas are split the same way ``LaTEXTokenizer.tokenize`` does it: dots
    are surrounded by spaces, each digit becomes its own token, then the string
    is split on whitespace. Unique tokens are sorted; with ``add_special`` the
    special tokens <UNK>, <SOS>, <PAD>, <EOS> are prepended and get ids 0-3,
    the mapping CNNLSTM hard-codes.

    Parameters
    ----------
    input_file: str
        File readable by ``read_input``.
    col_name: str
        Column holding the formulas.
    add_special: bool
        Prepend the special tokens.

    Returns
    -------
    (token2id, id2token): tuple[dict[str, int], dict[int, str]]
    """
    df = read_input(input_file, col_name)

    # rows without a formula cannot be tokenized (str methods fail on NaN)
    formulas = df["formula"].dropna().astype(str)
    formulas = formulas.str.replace(".", " . ", regex=False)
    formulas = formulas.str.replace(r"(\d)", r" \1", regex=True)

    vocab = sorted({token for formula in formulas for token in formula.split()})

    if add_special:
        vocab = ["<UNK>", "<SOS>", "<PAD>", "<EOS>"] + vocab

    token2id = {k: i for i, k in enumerate(vocab)}
    id2token = {i: k for i, k in enumerate(vocab)}

    return token2id, id2token

In [40]:
def preprocess_function(
    input_file: str,
    vocab: dict[str, int],
    col_name: str,
    add_padding: bool,
    debug: bool,
    max_len: int = 512,
) -> pd.DataFrame:
    """
    Read ``input_file`` and replace its formula column with token ids.

    Rows with any missing value are dropped (the count is printed). Each
    formula is tokenized on its own with ``LaTEXTokenizer``, so every cell of
    the resulting "formula" column (moved to the last position) is a list
    holding one list of ids.

    Parameters
    ----------
    input_file: str
        File readable by ``read_input``.
    vocab: dict[str, int]
        token2id mapping including the special tokens.
    col_name: str
        Column holding the formulas.
    add_padding: bool
        Pad every sequence with <PAD> to ``max_len``.
    debug: bool
        Keep only one random row (unseeded).
    max_len: int
        Sequence length passed to the tokenizer (default 512, CNNLSTM's default).

    Returns
    -------
    pd.DataFrame
    """
    tokenizer = LaTEXTokenizer(token2id=vocab)

    raw_df = read_input(input_file, col_name=col_name)
    df = raw_df.dropna(axis=0)

    n_dropped = raw_df.shape[0] - df.shape[0]
    if n_dropped:
        print(f"Dropped {n_dropped} rows from {input_file}")

    if debug:
        df = df.sample(1)

    tokenized = df["formula"].apply(
        lambda formula: tokenizer.tokenize(
            [formula], return_tensors=False, pad=add_padding, max_len=max_len
        )
    )

    return df.drop("formula", axis=1).assign(formula=tokenized)

In [41]:
def edit_distance(token1: torch.Tensor, token2: torch.Tensor) -> int:
    """
    Levenshtein distance between two 1-D sequences of token ids.

    Insertion, deletion and substitution each cost 1. The sequences may have
    different lengths (the previous version silently assumed equal lengths).
    Tensors, numpy arrays and plain sequences are accepted.
    """
    seq1 = token1.tolist() if hasattr(token1, "tolist") else list(token1)
    seq2 = token2.tolist() if hasattr(token2, "tolist") else list(token2)

    # distance is symmetric; keep the DP row over the shorter sequence
    if len(seq1) < len(seq2):
        seq1, seq2 = seq2, seq1

    previous = list(range(len(seq2) + 1))
    for i, a in enumerate(seq1, start=1):
        current = [i]
        for j, b in enumerate(seq2, start=1):
            substitution = previous[j - 1] + (a != b)
            current.append(min(previous[j] + 1, current[j - 1] + 1, substitution))
        previous = current

    return int(previous[-1])


def make_training_report(
    loss: float, levenstein: float, accuracy: float, lr: float
) -> str:
    """
    Format the progress-bar postfix, e.g.
    ``"Loss=12.7653 | Acc=0.0478 | ED=86.7 | LR=5.0e-02"``.

    ``lr`` was previously declared as ``lr=float`` (the type used as a
    default value), which made omitting it fail at format time.
    """
    return f"Loss={loss:.4f} | Acc={accuracy:.4f} | ED={levenstein:.1f} | LR={lr:.1e}"

In [44]:
def _batch_token_metrics(
    predictions: torch.Tensor, tokens: torch.Tensor, end_token: int
) -> tuple[int, int, float]:
    """Per batch: (correct tokens, total tokens, summed edit distance), each row cut at its <EOS>."""
    correct, total, distance = 0, 0, 0
    for pred_row, true_row in zip(predictions, tokens):
        eos_positions = (true_row == end_token).nonzero(as_tuple=True)[0]
        # position 0 is <SOS>; stop before <EOS> (whole row if <EOS> is absent)
        stop = int(eos_positions[0]) if eos_positions.numel() > 0 else true_row.size(0)
        pred_seq, true_seq = pred_row[1:stop], true_row[1:stop]
        correct += int((pred_seq == true_seq).sum())
        total += true_seq.numel()
        distance += edit_distance(pred_seq, true_seq)
    return correct, total, distance


def _save_rotating_checkpoint(
    path: str, payload: dict, saved_paths: list[str], keep: int
) -> None:
    """Save ``payload`` to ``path`` and delete the oldest saved files beyond ``keep``."""
    torch.save(payload, path)
    saved_paths.append(path)
    while len(saved_paths) > keep:
        oldest = saved_paths.pop(0)
        if os.path.exists(oldest):
            os.remove(oldest)


def training_func(
    model: nn.Module,
    n_epochs: int,
    loss_fn: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.LRScheduler,
    dataloaders: dict[str, DataLoader],
    checkpoint: str,
    end_token: int,
) -> tuple[dict[dict[int]], nn.Module]:
    """
    Train ``model`` and keep checkpoints of the best validation loss.

    Each epoch runs a "train" and a "val" phase; "test" runs only in the last
    epoch. After every epoch the history is updated and, when the validation
    loss beats the best seen in any previous epoch, a checkpoint (model,
    optimizer and scheduler state, history) is written under ``checkpoint``.
    Only the 2 most recent checkpoint files are kept.

    Accuracy and edit distance are computed per sequence on positions
    1 .. (index of <EOS>) - 1, so batches of any size are supported.

    Returns
    -------
    (history, model)
        history maps epoch -> {"acc", "edit_dist", "loss", "lr"}, each a
        per-phase dict except "lr".
    """
    phases = ["train", "val", "test"]
    lengths = {phase: len(dl) for phase, dl in dataloaders.items()}
    history = {}

    # for checkpoints
    keep = 2  # number of checkpoints to keep
    _last_saved = []  # cached paths of checkpoints
    # best validation loss over ALL epochs (must not be reset per epoch)
    min_loss = float("inf")
    lr = scheduler.get_last_lr()[0]

    # range (1, n_epochs + 1) for prettier monitoring
    for epoch in range(1, n_epochs + 1):
        losses = {phase: 0 for phase in phases}
        num_correct = {phase: 0 for phase in phases}
        num_total = {phase: 0 for phase in phases}
        num_sequences = {phase: 0 for phase in phases}
        accuracies = {phase: 0 for phase in phases}
        total_distances = {phase: 0 for phase in phases}
        levenstein = {phase: 0 for phase in phases}

        for phase in phases:
            if phase == "train":
                model.train()
            else:
                # eval on val and test
                model.eval()

            if phase == "test" and epoch != n_epochs:
                # skip test phase if not last epoch
                continue

            # set_grad_enabled(False) works as torch.no_grad() for val/test
            with tqdm(
                dataloaders[phase], miniters=1, unit="batch"
            ) as pbar, torch.set_grad_enabled(phase == "train"):
                for img, tokens in pbar:
                    pbar.set_description_str(
                        f"{phase.capitalize():5}({epoch:03d})"
                    )

                    # tokens [batch_size, max_len]
                    # output [batch_size, max_len, len(vocab)]
                    img = img.to(model.device)
                    tokens = tokens.to(model.device)
                    output = model(img)

                    # position 0 (<SOS>) is not predicted -> excluded from the loss
                    output_dim = output.shape[-1]
                    output_flat = output[:, 1:].reshape(-1, output_dim)
                    tokens_flat = tokens[:, 1:].reshape(-1)

                    loss = loss_fn(output_flat, tokens_flat)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                        optimizer.zero_grad()

                    # scale loss by length of dataloader to obtain avg loss
                    losses[phase] += loss.item() / lengths[phase]

                    correct, total, distance = _batch_token_metrics(
                        output.argmax(2), tokens, end_token
                    )
                    num_correct[phase] += correct
                    num_total[phase] += total
                    total_distances[phase] += distance
                    num_sequences[phase] += tokens.size(0)

                    accuracies[phase] = (
                        num_correct[phase] / num_total[phase]
                        if num_total[phase]
                        else 0.0
                    )
                    # average edit distance per processed sequence
                    levenstein[phase] = total_distances[phase] / num_sequences[phase]

                    report = make_training_report(
                        losses[phase],
                        levenstein[phase],
                        accuracies[phase],
                        scheduler.get_last_lr()[0],
                    )
                    pbar.set_postfix_str(report)

                # step scheduler after epoch if training
                if phase == "train":
                    lr = scheduler.get_last_lr()[0]
                    scheduler.step()

        # end of epoch: record history and checkpoint once all phases are done
        history[epoch] = {
            "acc": accuracies,
            "edit_dist": levenstein,
            "loss": losses,
            "lr": lr,
        }

        if losses["val"] < min_loss:
            min_loss = losses["val"]
            file_path = f"{checkpoint}/{datetime.now().strftime('%d%m-%H%M%S')}-acc-{accuracies['val']:.3f}.pth"
            _save_rotating_checkpoint(
                file_path,
                {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "history": history,
                },
                _last_saved,
                keep,
            )

    if _last_saved:
        print(f"Last checkpoint: {_last_saved[-1]}")
    return history, model


def train_model(
    checkpoint: str,
    data: list[pd.DataFrame],
    batch_size: int,
    num_epochs: int,
    device: torch.device,
    vocab: dict[str, int],
    image_folder: str,
    lr: float = 5e-2,
):
    """
    Build a CNNLSTM with its datasets/dataloaders and train it with ``training_func``.

    Parameters
    ----------
    checkpoint: str
        Checkpoint directory; "NA" selects ./artifacts/CNNLSTM/. Created
        (including parents) if missing.
    data: list[pd.DataFrame]
        Train, val and test frames, in that order.
    batch_size: int
        0 selects 1 on CPU and 8 on any other device.
    num_epochs: int
        Number of training epochs.
    device: torch.device | str
        Device for the model.
    vocab: dict[str, int]
        token2id mapping; must contain "<PAD>" and "<EOS>".
    image_folder: str
        Directory with the formula images.
    lr: float
        Initial Adam learning rate (previously hard-coded to 5e-2).

    Returns
    -------
    (history, model)
    """
    # [TODO] do not pass vocab only pass vocab length

    model = CNNLSTM(vocab_len=len(vocab), device=device)

    if checkpoint == "NA":
        checkpoint = f"./artifacts/{CNNLSTM.__name__}/"

    # makedirs also creates missing parents (os.mkdir fails for ./artifacts/...)
    os.makedirs(checkpoint, exist_ok=True)

    # if batch_size is not specified; normalize so str and torch.device both work
    if batch_size == 0:
        batch_size = 1 if torch.device(device).type == "cpu" else 8

    # [TODO] Use model.input_size member
    transforms = T.Compose(
        [
            T.ToTensor(),
            T.Grayscale(),
        ]
    )

    phases = ["train", "val", "test"]
    datasets = {
        phase: IM2LaTEX100K(
            data=frame,
            transform=transforms,
            vocab_len=len(vocab),
            image_folder=image_folder,
        )
        for phase, frame in zip(phases, data)
    }
    dataloaders = {
        phase: DataLoader(
            dataset, batch_size=batch_size, shuffle=(phase == "train")
        )
        for phase, dataset in datasets.items()
    }

    # move the model before building the optimizer over its parameters
    model.to(device)

    ce_loss = nn.CrossEntropyLoss(ignore_index=vocab["<PAD>"])
    optimizer = optim.Adam(params=model.parameters(), lr=lr)
    scheduler = CosineAnnealingLR(
        optimizer=optimizer, T_max=10, eta_min=1e-3, last_epoch=-1
    )

    history, model = training_func(
        model=model,
        n_epochs=num_epochs,
        loss_fn=ce_loss,
        optimizer=optimizer,
        scheduler=scheduler,
        dataloaders=dataloaders,
        checkpoint=checkpoint,
        end_token=vocab["<EOS>"],
    )

    return history, model

In [45]:
# Build the vocabulary from the full formulas file, tokenize every split,
# then train.
FORMULAS_FILE = "../data/raw/im2latex_formulas.norm.csv"
# NOTE(review): all three splits point at the same debug file, so val/test
# metrics are measured on the training data — swap in real split files.
SPLIT_FILES = {
    "train": "../data/interim/im2latex.debug.csv",
    "val": "../data/interim/im2latex.debug.csv",
    "test": "../data/interim/im2latex.debug.csv",
}

token2id, _ = make_vocab_func(
    input_file=FORMULAS_FILE,
    col_name="formulas",
    add_special=True,
)

split_dfs = {
    split: preprocess_function(
        input_file=split_path,
        vocab=token2id,
        col_name="formula",
        add_padding=True,
        debug=False,
    )
    for split, split_path in SPLIT_FILES.items()
}
train_df, val_df, test_df = split_dfs["train"], split_dfs["val"], split_dfs["test"]

history, model = train_model(
    checkpoint="../artifacts/CNNLSTM/",
    data=[train_df, val_df, test_df],
    batch_size=1,
    num_epochs=1,
    device="cpu",
    vocab=token2id,
    image_folder="../data/raw/formula_images_processed/formula_images_processed/",
)

df.shape=(102863, 1)
df.shape=(5, 2)
df.shape=(5, 2)
df.shape=(5, 2)


Train(001): 100%|██████████| 5/5 [00:35<00:00,  7.20s/batch, Loss=12.7653 | Acc=0.0478 | ED=86.7 | LR=5.0e-02]
Val  (001): 100%|██████████| 5/5 [00:11<00:00,  2.22s/batch, Loss=14.3147 | Acc=0.0533 | ED=86.1 | LR=4.9e-02]
Test (001):   0%|          | 0/5 [00:02<?, ?batch/s]


KeyboardInterrupt: 