<a href="https://colab.research.google.com/github/robbiedantonio/Img2LaTeX/blob/main/Img2LaTeX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Img2LaTeX Project
#### Using deep learning to generate LaTeX source strings from typeset (PDF-rendered) images of mathematical expressions

**Loading data and implementing token dictionary**

In [1]:
'''
Mount Google Drive, copy data to local runtime, and unzip folders
'''
# Colab-only import: this cell will not run outside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

# Lines starting with `!` are IPython shell escapes, not Python.
# Copy the three image archives and the label file (math.txt) from the
# Img2LaTeX_data folder on Drive into /content on the runtime's local disk.
! cp /content/drive/MyDrive/'Img2LaTeX_data'/train.zip /content
! cp /content/drive/MyDrive/'Img2LaTeX_data'/test.zip /content
! cp /content/drive/MyDrive/'Img2LaTeX_data'/val.zip /content
! cp /content/drive/'MyDrive'/'Img2LaTeX_data'/math.txt /content

# Extract each archive into the current directory.
#   -q  : quiet output
#   -DD : do not restore timestamps on extracted entries
# NOTE(review): the relative paths assume the working directory is /content
# (where the archives were just copied) -- confirm if the cwd is ever changed.
# Later cells expect ./train, ./test and ./val to exist after this step.
! unzip -DD -q  ./train.zip -d  .
! unzip -DD -q  ./test.zip -d  .
! unzip -DD -q  ./val.zip -d  .

Mounted at /content/drive


In [2]:
'''
Length of datasets
'''
# `name = !cmd` is IPython syntax: it runs the shell command and stores its
# output lines in a list-like object. `ls <dir> | wc -l` prints one line, the
# number of directory entries, so element [0] is that count as a string.
num_train_str = !ls train | wc -l
num_test_str = !ls test | wc -l
num_val_str = !ls val | wc -l
# Convert the captured count strings to integers.
num_train = int(num_train_str[0])
num_test = int(num_test_str[0])
num_val = int(num_val_str[0])

# Report the per-split counts and their total.
print(f'Number of train images: {num_train}\nNumber of test images: {num_test}\nNumber of validation images: {num_val}\nTotal images: {num_train+num_test+num_val}')

Number of train images: 158480
Number of test images: 30637
Number of validation images: 6765
Total images: 195882


In [4]:
'''
Locations of the three image folders and of the LaTeX label file,
all relative to the current working directory
'''
# One folder per dataset split; each root keeps its trailing slash.
train_root, test_root, val_root = (f"./{split}/" for split in ("train", "test", "val"))
# Text file holding the LaTeX strings.
labels = "./math.txt"

In [7]:
'''
Load data and preprocess images
'''
import os
import torch.utils.data
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

class LatexDataset(torch.utils.data.Dataset):
    '''
    Dataset of (image, LaTeX string) pairs.

    Line i (0-based) of the label file is paired with the image named
    "<i zero-padded to 7 digits>.png" inside `dataroot`. Lines whose image is
    missing from `dataroot`, cannot be verified by PIL, or whose label is too
    long are skipped.
    '''

    def __init__(self, transform=None, dataroot=None, labels=None, max_seq_length=256):
        '''
        Initialize the dataset
            transform: Optional callable applied to each PIL image in
                __getitem__ (e.g. a torchvision transform)
            dataroot: Directory containing the images of one split
            labels: Path to the text file with one LaTeX string per line
            max_seq_length: The maximum length of a sequence. This allows us
                to simplify training by avoiding sequences that are too long.

        Raises AssertionError if dataroot/labels are missing or do not exist,
        or if max_seq_length is not positive.
        '''
        assert dataroot is not None and labels is not None          # Make sure dataroot and labels are specified
        assert os.path.exists(dataroot) and os.path.exists(labels)  # Make sure dataroot and labels exist
        assert max_seq_length > 0                                   # Make sure max_seq_length is positive

        self.transform = transform
        self.dataroot = dataroot
        self.labels_txt = labels
        self.max_seq_length = max_seq_length
        self._parse()

    def _find_image(self, line_index):
        '''
        Return the path of the image belonging to label line `line_index`
        (0-based), or None if it is not in this split or cannot be verified.
        '''
        imname = str(line_index).zfill(7) + '.png'
        # os.path.join works whether or not dataroot ends with a separator.
        impath = os.path.join(self.dataroot, imname)
        if not os.path.exists(impath):
            # The label file covers all splits; this image is in another one.
            return None

        try:
            # Context manager closes the file handle after verification.
            with Image.open(impath) as img:
                img.verify()
        except Exception:
            # Some images can't be opened; skip them (deliberate best effort).
            return None

        return impath

    def _parse(self):
        '''
        Parse the label file.
        Populates the following private variables:
            self.im_paths: A list of strings storing the associated image paths
            self.labels: A list of strings, where each string is the latex code for an image
        '''
        self.im_paths = []
        self.labels = []

        with open(self.labels_txt) as f:
            for line_index, line in enumerate(f):
                impath = self._find_image(line_index)
                if impath is None:
                    continue

                latex = line.strip('\n')
                # The filter uses the label's character count, not its token
                # count, so it is a conservative bound on sequence length.
                if len(latex) < self.max_seq_length - 1:
                    self.im_paths.append(impath)   # Image path
                    self.labels.append(latex)      # String of latex code

    # Kept so the original method name still resolves; __init__ calls _parse.
    __parse__ = _parse

    def __len__(self):
        '''
        Return length of the dataset.
        '''
        assert len(self.labels) == len(self.im_paths)
        return len(self.labels)

    def __getitem__(self, index):
        '''
        Get a single sample from the dataset.
        Returns a single (image, label) tuple, where image is the PIL image
        (or the result of self.transform applied to it) and label is the
        LaTeX string for that image.
        '''
        img = Image.open(self.im_paths[index])
        if self.transform is not None:
            img = self.transform(img)

        return img, self.labels[index]

In [8]:
'''
Dictionary block: converts a LaTeX string to a dictionary of latex tokens, where
each unique token has its own entry and integer value assigned to it
'''

class LatexDict():
    '''
    Vocabulary mapping whitespace-separated LaTeX tokens to integer IDs.

    IDs 0, 1 and 2 are reserved for '<UKN>', '<PAD>' and '<EOS>'. Every other
    token gets the next free ID in order of first appearance in the label file.
    '''

    def __init__(self, label_file=None, max_seq_length=256):
        '''
        label_file: Path to the text file with one LaTeX string per line
        max_seq_length: Number of columns in the tensors built by map_tokens

        Raises AssertionError if label_file is None.
        '''
        assert label_file is not None                       # Make sure label_file is specified
        self.labels_txt = label_file
        self.max_seq_length = max_seq_length
        self.latex_dict = {'<UKN>':0, '<PAD>':1, '<EOS>':2}            # Initialize with token for unknown, pad, and end of sequence
        self.latex_dict_inverse = {0:'<UKN>', 1:'<PAD>', 2:'<EOS>'}    # Initialize reverse dict for quicker reverse lookups
        self.create_dict()

    def create_dict(self):
        '''
        Go through entire label file and populate normal and reverse dictionary
        '''
        with open(self.labels_txt) as f:
            for line in f:
                for token in line.split():
                    if token not in self.latex_dict:
                        # Assign a new ID for the unseen token
                        new_id = len(self.latex_dict)
                        self.latex_dict[token] = new_id
                        self.latex_dict_inverse[new_id] = token

    def map_tokens(self, latex_strings_list, batch_size):
        '''
        Map a list of LaTeX strings to a tensor of token IDs using the dictionary
        latex_string_list: A list of LaTeX strings
        batch_size: Number of samples in the batch

        Each string is wrapped as '{ <string> }' before being split on
        whitespace. Tokens missing from the dictionary map to '<UKN>'.
        Unused positions hold the '<PAD>' ID.

        Returns:
            ids_tensor: A tensor of token IDs with shape (batch_size, max_seq_length)

        Raises IndexError if there are more strings than batch_size or a
        wrapped string has more than max_seq_length tokens.
        '''
        # NOTE(review): IDs are stored as float32, as in the original code;
        # confirm whether callers cast to an integer dtype before using them
        # as indices or class targets.
        ids_tensor = torch.full((batch_size, self.max_seq_length), self.latex_dict['<PAD>'], dtype=torch.float32)
        unknown_id = self.latex_dict['<UKN>']

        for row, tex_str in enumerate(latex_strings_list):
            tokens = (r'{ ' + tex_str + ' }').split()
            for col, token in enumerate(tokens):
                # Fall back to <UKN> instead of raising KeyError on unseen tokens.
                ids_tensor[row, col] = self.latex_dict.get(token, unknown_id)

        return ids_tensor

    def tokens_to_tex(self, token_vec):
        '''
        Maps a 1D tensor of token IDs to a LaTeX string
        token_vec: A 1D tensor of token IDs

        Decoding stops at the first '<EOS>'. '<PAD>', '<UKN>' and IDs that are
        not in the dictionary are skipped.

        Returns:
            tex_str: A string that starts with a space and has each decoded
                token followed by a space
        '''
        pieces = [' ']
        for token_id in token_vec.tolist():
            token = self.latex_dict_inverse.get(token_id)
            if token is None:
                continue
            if token == '<EOS>':
                break
            if token != '<PAD>' and token != '<UKN>':
                pieces.append(token + ' ')

        return ''.join(pieces)

    # NOTE(review): defining a method named __dict__ shadows the normal
    # instance attribute mapping, so obj.__dict__ is this method and vars(obj)
    # no longer returns the instance attributes. Kept because callers may
    # already use obj.__dict__(); prefer reading obj.latex_dict directly.
    def __dict__(self):
        return self.latex_dict

    def __len__(self):
        '''
        Return the number of entries in the dictionary, special tokens included.
        '''
        return len(self.latex_dict)

**Implementing Model:**

Current design idea:

Image -> Preprocess (transforms) -> CNN encoder + sinusoidal positional encoding -> transformer decoder -> Cross-entropy loss

In [None]:
import torch
import torch.nn as nn

In [None]:
'''
CNN Encoder Class
'''
class CNNEncoder(nn.Module):
    '''
    Convolutional encoder built from repeated stages of
    Conv2d -> MaxPool2d(2, 2) -> BatchNorm2d.
    '''

    def __init__(self, input_channels, output_channels, kernel_sizes, strides, paddings):
        '''
        Build one stage per entry of kernel_sizes

        input_channels: Channel count of the input images
        output_channels: Per-stage output channel counts
        kernel_sizes: Per-stage convolution kernel sizes
        strides: Per-stage convolution strides
        paddings: Per-stage convolution paddings
        '''
        super().__init__()

        self.conv_layers = nn.ModuleList()
        self.pool_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()

        # Each stage consumes the channel count produced by the previous one.
        prev_width = input_channels
        for stage, kernel in enumerate(kernel_sizes):
            width = output_channels[stage]

            conv = nn.Conv2d(prev_width, width, kernel, strides[stage], paddings[stage])
            pool = nn.MaxPool2d(2, 2)
            norm = nn.BatchNorm2d(width)

            self.conv_layers.append(conv)
            self.pool_layers.append(pool)
            self.norm_layers.append(norm)

            prev_width = width

    def forward(self, x):
        '''
        Run the input through every stage in order
        x: Input tensor of shape (batch_size, input_channels, height, width)

        Returns:
            features: A list with one tensor per stage, holding that stage's
                output after convolution, pooling and normalization, of shape
                (batch_size, output_channels[i], stage_height[i], stage_width[i])
        '''
        features = []
        for stage in range(len(self.conv_layers)):
            convolved = self.conv_layers[stage](x)
            pooled = self.pool_layers[stage](convolved)
            x = self.norm_layers[stage](pooled)
            features.append(x)

        return features