In [379]:
%pip install numpy==1.22.4

Note: you may need to restart the kernel to use updated packages.


In [380]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

import logging

debug = logging.getLogger("Debug")
info  = print
plt.ion()   # interactive mode

<contextlib.ExitStack at 0x2e45d8910>

## Data and Classes
- Create Dataloader class

Note: Working on Part (a) as of now.  
Guiding light: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

In [381]:
START_TOKEN = "START"
END_TOKEN = "END"
UNK_TOKEN = "UNK"

MAX_EXAMPLES = 100
class Vocabulary:
    def __init__(self, freq_dict, wd_to_id, id_to_wd):
        self.freq_dict = freq_dict
        self.wd_to_id = wd_to_id
        self.id_to_wd = id_to_wd
        self.N = len(freq_dict)
    
    def get_id(self, word):
        if word in self.wd_to_id:
            return self.wd_to_id[word]
        else:
            return self.wd_to_id[UNK_TOKEN]

class LatexFormulaDataset(Dataset):
    """Latex Formula Dataset: Image and Text"""
    
    def __init__(self, csv_file, root_dir, max_examples=None, transform = None):
        """
        Arguments:
            csv_file (string): Path to the csv file with image name and text
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        #@TODO: May want to preload images
        self.df = pd.read_csv(csv_file)

        if max_examples is not None:
            info(self.df.shape)
            self.df = self.df.iloc[:max_examples, :]
            info(self.df.shape)
            
        info("Loading Dataset")

        info(self.df.head())
        
        self.root_dir = root_dir
        self.transform = transform

        '''Tokenize the formula by splitting on spaces'''
        self.df['formula'] = self.df['formula'].apply(lambda x: x.split())
        self.vocab= self.construct_vocab()  

        self.maxlen = 0
        for formula in self.df['formula']:
            if len(formula) > self.maxlen:
                self.maxlen = len(formula)

        self.df['formula'] = self.df['formula'].apply(lambda x: x + [END_TOKEN] + [UNK_TOKEN]*(self.maxlen - len(x)))
        self.maxlen += 1
        #slice df to first max_examples using iloc


            
        #Embedding layer
        self.embed = nn.Embedding(self.vocab.N, 512)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        Returns sample of type image, textformula
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.df.iloc[idx, 0])
        image = io.imread(img_name)
        formula = self.df.iloc[idx, 1]
        formula = np.array([formula], dtype=str).reshape(-1, 1)
        formula = [[self.vocab.get_id(wd[0]) for wd in formula]] 
        sample = {'image': image, 'formula': torch.tensor(formula, dtype=torch.int64)}

        if self.transform:
            sample['image'] = self.transform(sample['image'])
            
        return sample 
    
    def construct_vocab(self):
        """
        Constructs vocabulary from the dataset formulas
        """
        freq_dict = {}
        for formula in self.df['formula']:
            for wd in formula:
                if wd not in freq_dict:
                    freq_dict[wd] = 1
                else:
                    freq_dict[wd] += 1
        freq_dict[START_TOKEN] = 1
        freq_dict[END_TOKEN] = 1
        freq_dict[UNK_TOKEN] = 1
        N = len(freq_dict)
        wd_to_id = {}
        for i, wd in enumerate(freq_dict):
            wd_to_id[wd] = i
        id_to_wd = {v: k for k, v in wd_to_id.items()}
    
        #pad the formulas with 
        return Vocabulary(freq_dict, wd_to_id, id_to_wd)      

def get_dataloader(csv_path, image_root, batch_size, transform = None, max_examples = None):
    """
    Returns dataloader,dataset for the dataset
    """
    dataset = LatexFormulaDataset(csv_path, image_root, max_examples=max_examples,transform=transform) #checked
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader, dataset
     

### Encoder Network
- A CNN to encode image to more meaningful vector

In [382]:
class EncoderCNN(nn.Module):
    def __init__(self):
        super().__init__()
    
        #@TODO:reduce number of layers: eliminate pools and acts
        self.conv1 = nn.Conv2d(3, 32, (5, 5))
        self.act1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d((2, 2))
        
        self.conv2 = nn.Conv2d(32, 64, (5, 5))
        self.act2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d((2, 2))
        
        self.conv3 = nn.Conv2d(64, 128, (5, 5))
        self.act3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d((2, 2))
        
        self.conv4 = nn.Conv2d(128, 256, (5, 5))
        self.act4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d((2, 2))
        
        self.conv5 = nn.Conv2d(256, 512, (5, 5))
        self.act5 = nn.ReLU()
        self.pool5 = nn.MaxPool2d((2, 2))
        
        self.avg_pool = nn.AvgPool2d((3, 3))
    
    def forward(self, x):
        x = self.act1(self.conv1(x))
        x = self.pool1(x)
        
        x = self.act2(self.conv2(x))
        x = self.pool2(x)
        
        x = self.act3(self.conv3(x))
        x = self.pool3(x)
        
        x = self.act4(self.conv4(x))
        x = self.pool4(x)
        
        x = self.act5(self.conv5(x))
        x = self.pool5(x)
        
        x = self.avg_pool(x)
        x = x.view(-1,512) 
        # info(f"Encoder Output Shape: {x.shape}")
        return x

### Vocabulary
- https://github.com/harvardnlp/im2markup/blob/master

### Decoder Network

In [383]:
class Decoder(nn.Module):
    """
    Inputs:
    (here M is whatever the batch size is passed)

    context_size : size of the context vector [shape: (1,M,context_size)]
    n_layers: number of layers [for our purposes, defaults to 1]
    hidden_size : size of the hidden state vectors [shape: (n_layers,M,hidden_size)]
    embed_size : size of the embedding vectors [shape: (1,M,embed_size)]
    vocab_size : size of the vocabulary
    max_length : maximum length of the formula
    """
    def __init__(self, context_size, vocab, n_layers = 1, hidden_size = 512, embed_size = 512,  max_length = 100):
        super().__init__()
        self.context_size = context_size
        self.vocab = vocab
        self.vocab_size = vocab.N
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.max_length = max_length


        self.input_size = context_size + embed_size

        self.embed = nn.Embedding(self.vocab_size, embed_size)
        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        self.linear = nn.Linear(hidden_size, self.vocab_size)
        self.softmax = nn.Softmax(dim = 1)
    
    def forward(self, context, target_tensor = None):
        """
        M: batch_size
        context is the context vector from the encoder [shape: (M,context_size)]
        target_tensor is the formula in tensor form [shape: (M,max_length)] (in the second dimension, it is sequence of indices of formula tokens)
            if target_tensor is not None, then we are in Teacher Forcing mode
            else normal jo bhi (last prediction ka embed is concatenated)
        """
        # info("Decoder Forward")
        # info(f"Context shape: {context.shape}")
        
        batch_size = context.shape[0]

        #initialize hidden state and cell state
            #@TODO: Some caveat in the size of the cell state. Should it be same as hidden_size? (check nn.LSTM documentation)
        hidden = torch.zeros((batch_size, self.hidden_size))
        cell = torch.zeros((batch_size, self.hidden_size))

        #initialize the input with embedding of the start token
        init_embed = self.embed(torch.tensor([self.vocab.wd_to_id[START_TOKEN]])).reshape((1, self.embed_size))
        init_embed = torch.repeat_interleave(init_embed, batch_size, dim = 0)

        # info(f"Initial Embedding Shape: {init_embed.shape}")

        input = torch.cat([context, init_embed], dim = 1)

        #initialize the output_history and init_output
        outputs = []
        output = torch.zeros((batch_size, self.vocab_size))

        for i in range(self.max_length):
            hidden, cell = self.lstm(input, (hidden, cell))
            output = self.linear(hidden)
            # output = self.softmax(output)
            outputs.append(output)
            if target_tensor is not None:
                embedding = self.embed(target_tensor[:, i]).reshape((batch_size, self.embed_size))
                input = torch.cat([context, embedding], dim = 1)
            else:
                #add the embedding of the last prediction
                input = torch.cat([context, self.embed(torch.argmax(output, dim = 1))], dim = 1)
        # info(f"Outputs: {outputs}")
        return torch.stack(outputs), hidden, cell

### Utility Functions

In [384]:
import time
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from tqdm import tqdm

plt.switch_backend('agg')
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)
def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))
def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

### Training Code.
- Dataloader automatically loads in batches. The data need not be modified by us.

In [385]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):
    total_loss = 0
    idx = 0
    for data in dataloader:
        idx+=1
        # info(f"----Batch {idx}----")
        input_tensor, target_tensor = data['image'], data['formula']

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_output = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_output)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001, print_every=1, plot_every=5):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss() #as stated in assignment

    for epoch in tqdm(range(1, n_epochs + 1)):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)

## Training

In [386]:
batch_size = 32
batch_size = MAX_EXAMPLES

vocab_size = 1000
CONTEXT_SIZE = 512
HIDDEN_SIZE = 512
# OUTPUT_SIZE  = vocab_size
# MAX_LENGTH = 10000

In [387]:
# image processing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Lambda(lambda x: x/255.0), #min-max normalisation
])

In [388]:
#part a
#train_csv_path = "/kaggle/input/converting-handwritten-equations-to-latex-code/col_774_A4_2023/SyntheticData/train.csv"
#image_root_path = "/kaggle/input/converting-handwritten-equations-to-latex-code/col_774_A4_2023/SyntheticData/images"
train_csv_path = "data/SyntheticData/train.csv"
image_root_path = "data/SyntheticData/images"
train_dataloader, train_dataset = get_dataloader(train_csv_path, image_root_path, batch_size, transform, max_examples=MAX_EXAMPLES)
vocab = train_dataset.vocab
MAX_LENGTH = train_dataset.maxlen

print(train_dataset.df.shape)

(75000, 2)
(100, 2)
Loading Dataset
            image                                            formula
0  74d337e8a0.png  $ \gamma _ { \Omega R , 5 } ^ { T } = - \gamma...
1  2d0f18f71d.png  $ l ^ { ( -- ) \underline { { m } } } u _ { \u...
2  6d9b9de88d.png  $ \left[ H , \gamma _ { i } ^ { \left( 2 \righ...
3  38c6d510bb.png  $ < a _ { i } > \; \propto \; \int _ { \omega ...
4  24537a86e3.png  $ \Psi ( \mu _ { 1 } , \ldots , \mu _ { K } ) ...
(100, 2)


In [389]:
#create a network instance
encoder = EncoderCNN()
decoder = Decoder(CONTEXT_SIZE, vocab, n_layers=1, hidden_size= HIDDEN_SIZE, embed_size=512,max_length=MAX_LENGTH)
train(train_dataloader, encoder, decoder, 1000)

  0%|          | 1/1000 [00:07<2:00:18,  7.23s/it]

0m 7s (- 120m 19s) (1 0%) 5.2621


  0%|          | 2/1000 [00:14<1:58:30,  7.12s/it]

0m 14s (- 118m 46s) (2 0%) 4.8850


  0%|          | 3/1000 [00:21<1:56:51,  7.03s/it]

0m 21s (- 117m 26s) (3 0%) 3.3599


  0%|          | 4/1000 [00:28<1:55:44,  6.97s/it]

0m 28s (- 116m 32s) (4 0%) 2.6453


  0%|          | 5/1000 [00:34<1:55:01,  6.94s/it]

0m 34s (- 115m 56s) (5 0%) 2.6303


  1%|          | 6/1000 [00:41<1:54:32,  6.91s/it]

0m 41s (- 115m 29s) (6 0%) 2.4709


  1%|          | 7/1000 [00:48<1:54:08,  6.90s/it]

0m 48s (- 115m 6s) (7 0%) 2.3081


  1%|          | 8/1000 [00:55<1:54:09,  6.90s/it]

0m 55s (- 114m 55s) (8 0%) 2.3134


  1%|          | 9/1000 [01:02<1:52:54,  6.84s/it]

1m 2s (- 114m 19s) (9 0%) 2.2730


  1%|          | 10/1000 [01:09<1:52:18,  6.81s/it]

1m 9s (- 113m 54s) (10 1%) 2.2354


  1%|          | 11/1000 [01:15<1:51:11,  6.75s/it]

1m 15s (- 113m 21s) (11 1%) 2.2204


  1%|          | 12/1000 [01:22<1:50:56,  6.74s/it]

1m 22s (- 113m 1s) (12 1%) 2.2183


  1%|▏         | 13/1000 [01:28<1:50:20,  6.71s/it]

1m 29s (- 112m 37s) (13 1%) 2.2223


  1%|▏         | 14/1000 [01:35<1:50:36,  6.73s/it]

1m 35s (- 112m 25s) (14 1%) 2.2220


  2%|▏         | 15/1000 [01:42<1:51:13,  6.78s/it]

1m 42s (- 112m 21s) (15 1%) 2.2233


  2%|▏         | 16/1000 [01:49<1:50:25,  6.73s/it]

1m 49s (- 112m 1s) (16 1%) 2.2224


  2%|▏         | 17/1000 [01:56<1:50:49,  6.76s/it]

1m 56s (- 111m 55s) (17 1%) 2.2231


  2%|▏         | 18/1000 [02:02<1:50:24,  6.75s/it]

2m 2s (- 111m 41s) (18 1%) 2.2204


  2%|▏         | 19/1000 [02:09<1:49:23,  6.69s/it]

2m 9s (- 111m 21s) (19 1%) 2.2175


  2%|▏         | 20/1000 [02:16<1:49:14,  6.69s/it]

2m 16s (- 111m 8s) (20 2%) 2.2147


  2%|▏         | 21/1000 [02:22<1:49:03,  6.68s/it]

2m 22s (- 110m 55s) (21 2%) 2.2136


  2%|▏         | 22/1000 [02:29<1:49:45,  6.73s/it]

2m 29s (- 110m 50s) (22 2%) 2.2097


  2%|▏         | 23/1000 [02:37<1:54:59,  7.06s/it]

2m 37s (- 111m 27s) (23 2%) 2.2099


  2%|▏         | 24/1000 [02:45<1:59:27,  7.34s/it]

2m 45s (- 112m 7s) (24 2%) 2.2098


  2%|▎         | 25/1000 [02:54<2:05:20,  7.71s/it]

2m 54s (- 113m 6s) (25 2%) 2.2062


  3%|▎         | 26/1000 [03:02<2:06:46,  7.81s/it]

3m 2s (- 113m 39s) (26 2%) 2.2086


  3%|▎         | 27/1000 [03:10<2:07:55,  7.89s/it]

3m 10s (- 114m 11s) (27 2%) 2.2075


  3%|▎         | 28/1000 [03:18<2:12:04,  8.15s/it]

3m 18s (- 115m 4s) (28 2%) 2.2054


  3%|▎         | 29/1000 [03:26<2:10:30,  8.06s/it]

3m 26s (- 115m 22s) (29 2%) 2.2061


  3%|▎         | 30/1000 [03:34<2:08:58,  7.98s/it]

3m 34s (- 115m 36s) (30 3%) 2.2054


  3%|▎         | 31/1000 [03:42<2:07:42,  7.91s/it]

3m 42s (- 115m 47s) (31 3%) 2.2059


  3%|▎         | 32/1000 [03:50<2:07:36,  7.91s/it]

3m 50s (- 116m 2s) (32 3%) 2.2039


  3%|▎         | 33/1000 [03:58<2:07:40,  7.92s/it]

3m 58s (- 116m 17s) (33 3%) 2.2026


  3%|▎         | 34/1000 [04:05<2:06:37,  7.87s/it]

4m 5s (- 116m 25s) (34 3%) 2.2029


  4%|▎         | 35/1000 [04:14<2:07:52,  7.95s/it]

4m 14s (- 116m 43s) (35 3%) 2.2046


  4%|▎         | 36/1000 [04:21<2:07:23,  7.93s/it]

4m 21s (- 116m 52s) (36 3%) 2.2018


  4%|▎         | 37/1000 [04:29<2:06:00,  7.85s/it]

4m 29s (- 116m 55s) (37 3%) 2.2039


  4%|▎         | 37/1000 [04:31<1:57:42,  7.33s/it]


KeyboardInterrupt: 