In [1]:
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

def pil_loader(fp: Path, mode: str) -> Image.Image:
    """Open the image stored at ``fp`` and return it converted to ``mode``.

    The file is opened in binary mode inside a ``with`` block. The conversion
    happens while the handle is still open, so the returned image is built
    before the file is closed.
    """
    with open(fp, "rb") as handle:
        converted = Image.open(handle).convert(mode)
    return converted

def first_and_last_nonzeros(arr):
    """Return the indices of the first and last nonzero entries of ``arr``.

    ``arr`` may be any indexable sequence supporting ``len()`` (a list, a 1-D
    numpy array, ...).

    Returns:
        ``(left, right)`` with ``left <= right`` when at least one entry is
        nonzero. When no entry is nonzero -- including when ``arr`` is empty --
        returns ``(len(arr), -1)``, so ``left > right`` unambiguously means
        "nothing found".
    """
    n = len(arr)
    left = next((i for i in range(n) if arr[i] != 0), None)
    if left is None:
        return n, -1
    # A nonzero entry exists, so the backward scan is guaranteed to find one.
    right = next(i for i in range(n - 1, -1, -1) if arr[i] != 0)
    return left, right


def crop(filename: Path, padding: int = 8) -> Optional[Image.Image]:
    """Tightly crop a rendered formula image to its ink, then re-pad it.

    The image is composited onto a white background, converted to grayscale,
    cropped to the bounding box of all non-white pixels, and surrounded by
    ``padding`` white pixels on every side.

    Args:
        filename: Path of the image file to crop.
        padding: Width in pixels of the white border added on each side.

    Returns:
        The cropped grayscale ("L") image, or ``None`` if the image has no
        non-white pixels.
    """
    image = pil_loader(filename, mode="RGBA")

    # Replace the transparency layer with a white background
    background = Image.new("RGBA", image.size, "WHITE")
    background.paste(image, (0, 0), image)
    gray = background.convert("L")

    # Invert so ink is nonzero and the background is 0 (array stays uint8).
    arr = 255 - np.array(gray)

    # Some images have no text. Test this directly instead of inferring it from
    # the bounds: a first==last bound is valid text that is one pixel thick.
    if not arr.any():
        print(f"{filename.name} is ignored because it does not contain any text")
        return None

    # Rows/columns containing ink have nonzero sums.
    y_start, y_end = first_and_last_nonzeros(np.sum(arr, axis=1))
    x_start, x_end = first_and_last_nonzeros(np.sum(arr, axis=0))

    cropped = arr[y_start : y_end + 1, x_start : x_end + 1]
    H, W = cropped.shape

    # Pad with zeros (white after inverting back). Keep uint8 so the result maps
    # directly to an 8-bit grayscale image instead of a float64 round trip.
    padded = np.zeros((H + padding * 2, W + padding * 2), dtype=np.uint8)
    padded[padding : H + padding, padding : W + padding] = cropped

    # Invert the color back to have a white background and black text
    return Image.fromarray(255 - padded).convert("L")

In [2]:
# Data locations, resolved against the current user's home directory.
home = Path.home()
RAW_IMAGES_DIRNAME = home.joinpath("Downloads", "ML_Project", "formula_images")
PROCESSED_IMAGES_DIRNAME = RAW_IMAGES_DIRNAME.with_name("formula_images_processed")

class LatexDataset(Dataset):
    """Dataset of (formula image, LaTeX label) pairs in the im2latex layout.

    Args:
        data_dir: Directory containing ``im2latex_formulas.lst``, the
            ``im2latex_<split>.lst`` index file, and a ``formula_images``
            sub-directory holding ``<image_name>.png`` files.
        transform: Optional callable applied to each grayscale PIL image.
        split: Selects the index file ``im2latex_<split>.lst``.

    Items are dicts ``{"image": <image or transformed image>, "label": str}``.
    """

    def __init__(self, data_dir, transform=None, split="train"):
        self.data_dir = data_dir
        self.transform = transform
        self.split = split
        self.images = []
        self.labels = []

        self.load_data()

    def load_data(self):
        """Fill ``self.images`` / ``self.labels`` from the split's index file.

        Each non-blank index line starts with ``<formula_idx> <image_name>``;
        ``formula_idx`` is a 0-based line number in the formulas file. Also
        runs the one-off cropping step in ``_crop_raw_images``.
        """
        formula_file = os.path.join(self.data_dir, "im2latex_formulas.lst")
        image_list_file = os.path.join(self.data_dir, f"im2latex_{self.split}.lst")

        # newline="\n": split lines on LF only. With universal newlines, a stray
        # "\r" inside a formula starts an extra line and shifts every later
        # formula index. "\r\n" endings still work because strip() drops "\r".
        with open(formula_file, "r", encoding="latin-1", newline="\n") as f:
            formulas = f.readlines()

        with open(image_list_file, "r", encoding="latin-1") as f:
            image_list = f.readlines()

        for line in image_list:
            parts = line.split()
            if not parts:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            formula_idx = int(parts[0])
            image_name = parts[1]

            self.labels.append(formulas[formula_idx].strip())
            # NOTE(review): these paths point at the raw images; the cropped
            # copies written by _crop_raw_images are never read by this class.
            # Confirm which set is intended for training.
            self.images.append(
                os.path.join(self.data_dir, "formula_images", f"{image_name}.png")
            )

        self._crop_raw_images()

    @staticmethod
    def _crop_raw_images():
        """Crop every raw PNG into PROCESSED_IMAGES_DIRNAME unless that directory exists."""
        # NOTE(review): an interrupted run leaves a partially filled directory
        # that is never completed; delete it to force a full re-run.
        if PROCESSED_IMAGES_DIRNAME.exists():
            return
        PROCESSED_IMAGES_DIRNAME.mkdir(parents=True, exist_ok=True)
        print("Cropping images...")
        for image_filename in RAW_IMAGES_DIRNAME.glob("*.png"):
            cropped_image = crop(image_filename, padding=8)
            if cropped_image is None:
                # crop() returns None for images without any text.
                continue
            cropped_image.save(PROCESSED_IMAGES_DIRNAME / image_filename.name)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "label": ...}`` for sample ``idx``."""
        # Load through the same helper crop() uses; it closes the file handle
        # in a ``with`` block before returning.
        image = pil_loader(self.images[idx], mode="L")

        if self.transform:
            image = self.transform(image)

        return {"image": image, "label": self.labels[idx]}

# Only ToTensor: no resizing is applied here.
# NOTE(review): the default collate stacks image tensors, so all images in a
# batch must share the same size -- confirm this holds for the chosen images.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Same location as RAW_IMAGES_DIRNAME, derived from the user's home directory
# instead of a hard-coded, user-specific absolute path.
data_dir = str(RAW_IMAGES_DIRNAME)

train_dataset = LatexDataset(data_dir, transform=transform, split="train")
val_dataset = LatexDataset(data_dir, transform=transform, split="validate")
test_dataset = LatexDataset(data_dir, transform=transform, split="test")

batch_size = 32
# On Windows, DataLoader workers are spawned processes that must re-import the
# Dataset class; a class defined inside a notebook cannot be found that way, so
# load in-process there.
NUM_WORKERS = 0 if os.name == "nt" else 4
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CNNWithPositionalEncoding(nn.Module):
    """Convolutional image encoder with an additive 2-D sinusoidal position encoding.

    Input is a single-channel image batch (conv1 has 1 input channel). The
    output has 64 channels (conv6). Height is halved by pool1, pool3 and pool4,
    and width by pool2, pool3 and pool4.

    Args:
        d_model: Channels of the positional encoding. Must equal the CNN's
            output channels (64) for the addition in ``forward``.
        height, width: Size of the positional-encoding grid. Must be at least
            the encoder's output feature-map size.
    """

    def __init__(self, d_model, height, width):
        super(CNNWithPositionalEncoding, self).__init__()

        # CNN architecture
        self.conv1 = nn.Conv2d(1, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.bn1 = nn.BatchNorm2d(512)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))

        self.conv2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.bn2 = nn.BatchNorm2d(512)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))

        self.conv3 = nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv5 = nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv6 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        # Positional encoding, registered as a buffer so model.to(device) moves it
        # with the module. A plain tensor attribute stays on the CPU.
        # persistent=False keeps it out of state_dict, so existing checkpoints load.
        self.register_buffer(
            "positional_encoding",
            positionalencoding2d(d_model, height, width),
            persistent=False,
        )

    def forward(self, x):
        """Encode ``x`` and add the positional encoding; returns (batch, 64, H', W')."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)

        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.conv4(x))
        x = self.pool3(x)

        x = F.relu(self.conv5(x))
        x = self.pool4(x)

        x = F.relu(self.conv6(x))

        # Crop the encoding to the actual feature-map size. This is a no-op when
        # the sizes match and also allows smaller inputs.
        _, _, feat_h, feat_w = x.shape
        x = x + self.positional_encoding[:, :feat_h, :feat_w].unsqueeze(0)

        return x

def positionalencoding2d(d_model, height, width):
    """Build a fixed sinusoidal 2-D positional encoding.

    The first half of the channels encodes the column (width) position and the
    second half the row (height) position. Each half interleaves sin (even
    channels) and cos (odd channels) at geometrically spaced frequencies.

    :param d_model: number of channels; must be divisible by 4
    :param height: number of rows
    :param width: number of columns
    :return: tensor of shape (d_model, height, width)
    :raises ValueError: if d_model is not divisible by 4
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "dimension not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each spatial axis uses half of the channels.
    half = d_model // 2
    div_term = torch.exp(torch.arange(0., half, 2) *
                         -(math.log(10000.0) / half))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    # Column encodings: (half/2, width) broadcast over every row.
    pe[0:half:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:half:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    # Row encodings: (half/2, height) broadcast over every column.
    pe[half::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[half + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    return pe

In [8]:
import torch.nn as nn
import json
# JSON text is UTF-8 by specification; be explicit so the platform's locale
# encoding (e.g. cp1252 on Windows) is not used to decode the vocabulary.
with open("vocab.json", "r", encoding="utf-8") as f:
    VOCAB = json.load(f)

class Attention(nn.Module):
    """Additive (concat) attention over a sequence of encoder features.

    Each position is scored as ``v^T tanh(W [hidden; encoder_output])``. The
    scores are then softmax-normalised over the sequence.

    Args:
        enc_hid_dim: Encoder features are expected to be ``2 * enc_hid_dim`` wide.
        dec_hid_dim: Size of the decoder hidden state (also the scoring width).
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        concat_width = enc_hid_dim * 2 + dec_hid_dim
        self.attn = nn.Linear(concat_width, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, encoder_outputs, hidden):
        """Return weights of shape (batch, seq_len) that sum to 1 along seq_len.

        ``encoder_outputs`` must be 3-D: (batch, seq_len, 2 * enc_hid_dim).
        ``hidden`` must be 2-D: (batch, dec_hid_dim).
        """
        _, seq_len, _ = encoder_outputs.shape
        # One copy of the decoder state per encoder position, for feature-wise concat.
        tiled_hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)
        combined = torch.cat((tiled_hidden, encoder_outputs), dim=2)
        scores = self.v(torch.tanh(self.attn(combined))).squeeze(2)
        return F.softmax(scores, dim=1)

class Decoder(nn.Module):
    """Single-step LSTM decoder that attends over encoder features.

    Args:
        output_dim: Vocabulary size (embedding rows and logits width).
        emb_dim: Token embedding size.
        enc_hid_dim: Encoder features are ``2 * enc_hid_dim`` wide.
        dec_hid_dim: LSTM hidden size.
        attention: Module called as ``attention(encoder_outputs, hidden)`` with
            ``hidden`` of shape (batch, dec_hid_dim); it must return weights of
            shape (batch, seq_len).
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, attention):
        super(Decoder, self).__init__()

        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.LSTM((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.attention = attention

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: Token ids of shape (batch,).
            hidden: LSTM state ``(h, c)``, each (1, batch, dec_hid_dim).
            encoder_outputs: Tensor of shape (batch, seq_len, 2 * enc_hid_dim).

        Returns:
            ``(prediction, hidden)``: logits of shape (batch, output_dim) and the
            updated LSTM state.
        """
        # (batch,) -> (1, batch, emb_dim): nn.LSTM here is sequence-first.
        embedded = self.embedding(input.unsqueeze(0))

        # Attention expects a 2-D state (batch, dec_hid_dim). hidden[0] is
        # (num_layers=1, batch, dec_hid_dim), so take its top (only) layer.
        a = self.attention(encoder_outputs, hidden[0][-1])
        # (batch, 1, seq_len) @ (batch, seq_len, 2*enc) -> (batch, 1, 2*enc)
        weighted = torch.bmm(a.unsqueeze(1), encoder_outputs)
        # Make the context sequence-first, (1, batch, 2*enc), to match `embedded`.
        weighted = weighted.permute(1, 0, 2)

        rnn_input = torch.cat((embedded, weighted), dim=2)
        output, hidden = self.rnn(rnn_input, hidden)

        output = torch.cat((embedded.squeeze(0), hidden[0][-1], weighted.squeeze(0)), dim=1)
        prediction = self.fc_out(output)
        return prediction, hidden

class Seq2SeqWithAttention(nn.Module):
    """Image-to-sequence model: a CNN encoder plus an attention LSTM decoder.

    Args:
        cnn_model: Encoder returning either (batch, seq_len, features) or a
            feature map (batch, features, H, W). A 4-D map is flattened to
            (batch, H * W, features) before decoding.
        attention: Attention module, also handed to the decoder.
        output_dim, emb_dim, enc_hid_dim, dec_hid_dim: Forwarded to ``Decoder``;
            ``features`` must equal ``2 * enc_hid_dim``.
    """

    def __init__(self, cnn_model, attention, output_dim, emb_dim, enc_hid_dim, dec_hid_dim):
        super(Seq2SeqWithAttention, self).__init__()

        self.cnn_model = cnn_model
        self.attention = attention
        self.decoder = Decoder(output_dim, emb_dim, enc_hid_dim, dec_hid_dim, attention)

    def forward(self, x, target, teacher_forcing_ratio=0.5):
        """Decode ``target.shape[1]`` steps conditioned on ``x``.

        Args:
            x: Input tensor for ``cnn_model``.
            target: Token ids of shape (batch, trg_len); ``target[:, 0]`` is the
                first decoder input.
            teacher_forcing_ratio: Per-step probability of feeding the ground
                truth token instead of the model's own prediction.

        Returns:
            Logits of shape (batch, trg_len, vocab); step 0 is left as zeros.
        """
        encoder_outputs = self.cnn_model(x)
        if encoder_outputs.dim() == 4:
            # (B, C, H, W) -> (B, H*W, C): each spatial cell is one attendable position.
            encoder_outputs = encoder_outputs.flatten(2).permute(0, 2, 1)

        batch_size, _, _ = encoder_outputs.shape
        trg_len = target.shape[1]
        trg_vocab_size = self.decoder.fc_out.out_features
        hid_size = self.decoder.rnn.hidden_size

        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size, device=x.device)
        hidden = (torch.zeros(1, batch_size, hid_size, device=x.device),
                  torch.zeros(1, batch_size, hid_size, device=x.device))

        input = target[:, 0]
        for t in range(1, trg_len):
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            outputs[:, t, :] = output
            # Next input: ground truth with probability teacher_forcing_ratio,
            # otherwise the model's most likely token.
            teacher_force = random.random() < teacher_forcing_ratio
            input = target[:, t] if teacher_force else output.argmax(1)

        return outputs