# Prelims

## Imports

In [1]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Session-wide setup: silence warnings, define logging helpers, enable interactive plotting.
import warnings
import logging

warnings.filterwarnings("ignore")  # NOTE(review): this hides *all* warnings, including pandas/torch ones

info  = print                       # progress messages go straight to stdout
debug = logging.getLogger("Debug")  # named logger; no handler/level is configured in this cell
plt.ion()   # interactive mode

<contextlib.ExitStack at 0x105dbbf70>

## Utils

In [2]:
# Pick the compute device: CUDA first, then Apple MPS, falling back to CPU.
def _pick_device():
    """Return the preferred available torch.device and announce the choice."""
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        print("Running CUDA Mode:", chosen, torch.cuda.get_device_name(0))
        return chosen
    if torch.backends.mps.is_available():
        chosen = torch.device("mps")
        print("Running MPS Mode:", chosen)
        return chosen
    chosen = torch.device("cpu")
    print("Running CPU Mode:", chosen)
    return chosen

device = _pick_device()

import time
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from tqdm import tqdm

# Switch matplotlib to the non-interactive Agg backend (renders to image buffers/files only).
# NOTE(review): this overrides the interactive mode turned on by plt.ion() in the first cell;
# presumably figures will no longer pop up or show via plt.show() after this — confirm intended.
plt.switch_backend('agg')
def asMinutes(s):
    """Render a duration in seconds as ``'<minutes>m <seconds>s'`` (both parts truncated by %d)."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)
def timeSince(since, percent):
    """Return ``'<elapsed> (- <estimated remaining>)'`` given a start timestamp and the fraction done.

    ``percent`` is a fraction in (0, 1]; the remaining time is extrapolated
    linearly from the elapsed time.
    """
    elapsed = time.time() - since
    remaining = elapsed / (percent) - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))
    
def saveModel(save_path, model_state, optimiser_state, loss):
    """
    Write a training checkpoint with ``torch.save``.

    The checkpoint is a dict with keys ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``.

    Args:
        save_path: destination path (str / os.PathLike) or a file-like object.
            For filesystem paths, a missing parent directory is created first,
            since torch.save cannot create it and would otherwise fail
            mid-training.
        model_state: typically ``model.state_dict()``.
        optimiser_state: typically ``optimizer.state_dict()``.
        loss: loss value stored alongside the weights.
    """
    if isinstance(save_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(save_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save({
            'model_state_dict': model_state,
            'optimizer_state_dict': optimiser_state,
            'loss': loss,
    }, save_path)

Running MPS Mode: mps


# Class Definitions

## Vocab and Dataloader

In [3]:
# VOCAB: shared vocabulary, left as None until the combined vocabulary is built further down.
# MAX_SEQ_LEN: number of token positions every encoded formula is padded/truncated to.
VOCAB, MAX_SEQ_LEN = None, 128

In [21]:
# Special tokens shared by the dataset, the vocabulary and the decoder.
START_TOKEN = "<START>"
END_TOKEN = "<END>"
UNK_TOKEN = "<UNK>"
PAD_TOKEN = "<PAD>"

class Vocabulary:
    """Token <-> id lookup tables together with token frequencies.

    Attributes:
        freq_dict: token -> occurrence count.
        wd_to_id: token -> integer id.
        id_to_wd: integer id -> token.
        N: vocabulary size, taken as ``len(freq_dict)``.
    """
    def __init__(self, freq_dict, wd_to_id, id_to_wd):
        self.freq_dict, self.wd_to_id, self.id_to_wd = freq_dict, wd_to_id, id_to_wd
        self.N = len(freq_dict)

    def get_id(self, word):
        """Return the id of ``word``; unknown words map to the id of UNK_TOKEN."""
        if word not in self.wd_to_id:
            return self.wd_to_id[UNK_TOKEN]
        return self.wd_to_id[word]

class LatexFormulaDataset(Dataset):
    """Latex Formula Dataset: Image and Text

    Each CSV row provides an image file name (first column) and a ``formula``
    column of whitespace-separated LaTeX tokens. On construction every formula
    is tokenised, wrapped in START/END tokens and right-padded with PAD tokens
    to ``max(longest formula, MAX_SEQ_LEN)``; a per-dataset vocabulary is then
    built from the padded token lists.
    """

    def __init__(self, csv_file, root_dir, transform = None, max_examples = None):
        """
        Arguments:
            csv_file (string): Path to the csv file with image name and text.
                Must contain a ``formula`` column; the image file name is read
                from the first column.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample image.
            max_examples (int, optional): If given, only the first
                ``max_examples`` rows of the CSV are used.
        """
        info("Loading Dataset...")
        self.df = pd.read_csv(csv_file)

        # Slice the dataset if max_examples is not None. copy() so the
        # column assignments below modify an independent frame, not a view.
        if max_examples is not None:
            self.df = self.df.iloc[:max_examples, :].copy()

        self.root_dir = root_dir
        self.transform = transform

        # Tokenise on whitespace and wrap every formula in START/END markers.
        self.df['formula'] = self.df['formula'].apply(
            lambda text: [START_TOKEN] + text.split() + [END_TOKEN])

        # Length of the longest tokenised formula (0 for an empty CSV).
        self.maxlen = max((len(tokens) for tokens in self.df['formula']), default=0)

        # Pad every formula to a common length (never shorter than MAX_SEQ_LEN);
        # __getitem__ truncates to MAX_SEQ_LEN afterwards.
        pad_to = max(self.maxlen, MAX_SEQ_LEN)
        self.df['formula'] = self.df['formula'].apply(
            lambda tokens: tokens + [PAD_TOKEN] * (pad_to - len(tokens)))
        self.vocab = self.construct_vocab()
        info("Loaded.")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        Return one sample as ``{'image': image, 'formula': ids}``.

        ``ids`` is an int64 tensor of the first MAX_SEQ_LEN token ids (so the
        END token is cut off for formulas longer than that). Tokens are mapped
        with the module-level ``VOCAB`` (the shared vocabulary) when it has
        been set, otherwise with this dataset's own ``self.vocab``.
        ``self.transform``, if given, is applied to the image.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.df.iloc[idx, 0])
        image = io.imread(img_name)
        # Read by column name, consistent with how __init__ builds 'formula'.
        formula = self.df['formula'].iloc[idx]

        vocab = VOCAB if VOCAB is not None else self.vocab
        ids = [vocab.get_id(wd) for wd in formula][:MAX_SEQ_LEN]
        sample = {'image': image, 'formula': torch.tensor(ids, dtype=torch.int64)}

        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample

    def construct_vocab(self):
        """
        Build a Vocabulary from the (already tokenised and padded) formulas.

        Ids follow first-occurrence order over all formulas. UNK_TOKEN's count
        is forced to 1 (it is appended at the end if it never occurs).
        """
        freq_dict = {}
        for formula in self.df['formula']:
            for wd in formula:
                freq_dict[wd] = freq_dict.get(wd, 0) + 1
        freq_dict[UNK_TOKEN] = 1
        wd_to_id = {wd: i for i, wd in enumerate(freq_dict)}
        id_to_wd = {i: wd for wd, i in wd_to_id.items()}
        return Vocabulary(freq_dict, wd_to_id, id_to_wd)

def get_dataloader(csv_path, image_root, batch_size, transform = None, max_examples = None):
    """
    Build a LatexFormulaDataset and wrap it in a shuffling DataLoader.

    Returns:
        (dataloader, dataset) tuple.
    """
    ds = LatexFormulaDataset(csv_path, image_root, transform=transform, max_examples=max_examples)
    return DataLoader(ds, batch_size=batch_size, shuffle=True), ds
 

## Model

## Encoder

In [22]:
#import vgg16 from torch
from torchvision.models import vgg16

## Decoder

In [23]:
class DecoderLSTM(nn.Module):
    """
    Autoregressive LSTM decoder: turns a context vector into per-step token logits.

    Inputs:
    (here M is whatever the batch size is passed)

    context_size : size of the context vector [shape: (M,context_size)]
    vocab : vocabulary object; ``vocab.N`` is the vocabulary size and
        ``vocab.wd_to_id[START_TOKEN]`` the id fed at the first step
    max_seq_len : number of decoding steps (length of every output sequence)
    n_layers : number of layers. NOTE: stored only — a single nn.LSTMCell is
        always built, so values other than 1 have no effect.
    hidden_size : size of the hidden/cell state vectors [shape: (M,hidden_size)].
        The context vector is used as the initial hidden state, so this must
        equal context_size.
    embed_size : size of the token embedding vectors [shape: (M,embed_size)]
    teacher_forcing_ratio : per-step probability of feeding the ground-truth
        previous token (when a target is given) instead of the model's own
        argmax prediction. Default 0.5.
    """
    def __init__(self, context_size, vocab, max_seq_len, n_layers = 1, hidden_size = 512, embed_size = 512,
                 teacher_forcing_ratio = 0.5):
        super().__init__()
        self.context_size = context_size
        self.vocab = vocab
        self.vocab_size = vocab.N
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        # Each step feeds [context ; previous-token embedding] into the cell.
        self.input_size = context_size + embed_size
        self.max_seq_len = max_seq_len
        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.embed = nn.Embedding(self.vocab_size, embed_size)
        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        self.linear = nn.Linear(hidden_size, self.vocab_size)

    def forward(self, context, target_tensor = None):
        """
        M: batch_size
        context: context vector from the encoder [shape: (M,context_size)].
            Every internal tensor is created on context.device.
        target_tensor: formula as token ids [shape: (M,max_seq_len)].
            If given, each step (after the first) feeds the ground-truth previous
            token with probability teacher_forcing_ratio (teacher forcing);
            otherwise, and always when target_tensor is None, the embedding of
            the previous step's argmax prediction is fed.

        Returns:
            (logits [shape: (M,max_seq_len,vocab_size)], final hidden, final cell)
        """
        batch_size = context.shape[0]
        dev = context.device

        # Initial state: hidden = context (hence hidden_size == context_size), cell = 0.
        hidden = context
        cell = torch.zeros(batch_size, self.hidden_size, device=dev)

        # Start-token embedding for every row; shape (M, embed_size) even when M == 1.
        start_ids = torch.full((batch_size,), self.vocab.wd_to_id[START_TOKEN], dtype=torch.long, device=dev)
        init_embed = self.embed(start_ids)

        outputs = []
        output = torch.zeros((batch_size, self.vocab_size), device=dev)

        for i in range(self.max_seq_len):
            # One draw per step; with the default ratio 0.5 this is the same
            # "rand > 0.5" decision rule as before.
            use_teacher = torch.rand(1).item() > 1.0 - self.teacher_forcing_ratio
            if i == 0:
                embedding = init_embed
            elif use_teacher and target_tensor is not None:
                embedding = self.embed(target_tensor[:, i - 1]).reshape((batch_size, self.embed_size))
            else:
                # Feed back the model's own greedy prediction from the previous step.
                embedding = self.embed(torch.argmax(output, dim=1))

            lstm_input = torch.cat([context, embedding], dim=1)
            hidden, cell = self.lstm(lstm_input, (hidden, cell))
            output = self.linear(hidden)
            outputs.append(output)

        output_tensor = torch.stack(outputs, dim=1)  # (M, L, V)

        return output_tensor, hidden, cell
    

## Full Model (Encoder + Decoder)

In [24]:
class HandwritingToLatexModel(nn.Module):
    """
    Image -> LaTeX-token model: a pretrained VGG16 encoder feeding a DecoderLSTM.

    The last VGG16 classifier layer is replaced by a trainable projection to
    ``context_size`` features; all other encoder weights are frozen.

    Args:
        context_size: encoder output size / decoder context size. The decoder
            seeds its hidden state with the context, so this must equal
            ``hidden_size``.
        vocab, max_seq_len, n_layers, hidden_size, embed_size: forwarded
            positionally to DecoderLSTM.
    """
    def __init__(self, context_size, vocab, max_seq_len, n_layers, hidden_size, embed_size):
        super().__init__()
        self.encoder = vgg16(weights="IMAGENET1K_V1")
        # Project the classifier's penultimate features to context_size so the
        # encoder output always matches what the decoder expects.
        in_features = self.encoder.classifier[-1].in_features
        self.encoder.classifier[-1] = nn.Linear(in_features, context_size)
        # Freeze the whole encoder, then unfreeze only the new projection layer.
        for param in self.encoder.parameters():
            param.requires_grad = False
        for param in self.encoder.classifier[-1].parameters():
            param.requires_grad = True
        self.decoder = DecoderLSTM(context_size, vocab, max_seq_len, n_layers, hidden_size, embed_size)

    def forward(self, image, target_tensor = None):
        """
        Args:
            image: image batch for VGG16 (the notebook's transforms produce
                3 x 224 x 224 tensors).
            target_tensor: optional token ids (M, max_seq_len) enabling
                teacher forcing in the decoder.

        Returns:
            Decoder logits of shape (M, max_seq_len, vocab_size).
        """
        context = self.encoder(image)
        outputs, _, _ = self.decoder(context, target_tensor)
        return outputs

# Train

In [25]:
# Relative roots (from the notebook's working directory) for input data and saved checkpoints.
DATA_BASE_PATH, SAVE_BASE_PATH = "data/", "checkpoints/"

## Training Code

In [26]:
def train_epoch(dataloader, model, optimizer, criterion, log_every=100):
    """
    Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        dataloader: yields dicts with ``'image'`` and ``'formula'`` tensors.
        model: called as ``model(images, formulas)``; its output is assumed
            to be logits of shape (B, L, V).
        optimizer, criterion: optimiser and loss (class scores expected on dim 1).
        log_every: every ``log_every``-th batch, print the greedy decoding of
            the first sample next to its target (using the global VOCAB).
            0 or None disables this.

    Returns:
        float: mean of the per-batch losses.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    total_loss = 0.0
    n_batches = 0
    pb = tqdm(dataloader, desc="Batch")
    for batch_idx, data in enumerate(pb, start=1):
        input_tensor = data['image'].to(device)
        target_tensor = data['formula'].to(device)
        outputs = model(input_tensor, target_tensor)

        if log_every and batch_idx % log_every == 0:
            generated_formula = [VOCAB.id_to_wd[token.item()] for token in torch.argmax(outputs, dim=2)[0]]
            required_formula = [VOCAB.id_to_wd[token.item()] for token in target_tensor[0]]
            print(f"Generated: {' '.join(generated_formula)}")
            print(f"Actual: {' '.join(required_formula)}")

        # CrossEntropyLoss wants (B, V, L) logits against (B, L) targets.
        loss = criterion(outputs.permute(0, 2, 1), target_tensor)

        #backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        n_batches += 1
        pb.set_description(f"Loss: {batch_loss}")

    if n_batches == 0:
        raise ValueError("train_epoch: dataloader produced no batches")
    return total_loss / n_batches

def train(train_dataloader, model, n_epochs, optimizer = None, learning_rate=0.001, print_every=1, save_interval=2, save_prefix = 'model'):
    """
    Train ``model`` for ``n_epochs`` epochs with periodic checkpoints and progress reports.

    Args:
        train_dataloader: loader passed to train_epoch each epoch.
        model: the model to train (switched to train mode).
        n_epochs: number of epochs.
        optimizer: optimiser to use; if None, Adam with ``learning_rate`` is created.
        learning_rate: only used when ``optimizer`` is None.
        print_every: report the average loss of the last ``print_every``
            epochs every ``print_every`` epochs.
        save_interval: write a checkpoint every ``save_interval`` epochs to
            ``f'{SAVE_BASE_PATH}{save_prefix}_epoch_{epoch}.pt'``.
        save_prefix: checkpoint file-name prefix.

    Returns:
        list[float]: mean training loss of each epoch.
    """
    start = time.time()
    epoch_losses = []
    print_loss_total = 0.0  # Reset every print_every

    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # PAD positions do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=VOCAB.wd_to_id[PAD_TOKEN]).to(device) #as stated in assignment

    model.train()
    for epoch in range(1, n_epochs + 1):
        info(f"Epoch {epoch}")
        loss = train_epoch(train_dataloader, model, optimizer, criterion)
        epoch_losses.append(loss)
        print_loss_total += loss

        if epoch % save_interval == 0:
            saveModel(f'{SAVE_BASE_PATH}{save_prefix}_epoch_{epoch}.pt', model.state_dict(), optimizer.state_dict(), loss)

        # Report only on reporting epochs; print_loss_avg exists only here, so
        # reporting every epoch would fail (or repeat stale values) when print_every > 1.
        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0.0
            info('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs), epoch, epoch / n_epochs * 100, print_loss_avg))

    return epoch_losses

## Load Datasets

In [27]:
#part a
# Hyper-parameters for part (a).
batch_size = 32
vocab_size = 1000  # NOTE(review): not referenced in the visible cells
CONTEXT_SIZE = 512
HIDDEN_SIZE = 512
EMBED_SIZE = 512
MAX_EXAMPLES = 1000  # NOTE(review): the loaders below pass max_examples=None (full data)

# Synthetic images: to tensor, then resize to 224x224.
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((224, 224))])

# Handwritten images: replicate grayscale into 3 channels (the input channel count VGG16 expects), then resize.
transform_hw = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])

synthetic_dir = f"{DATA_BASE_PATH}/SyntheticData"
train_csv_path = f"{synthetic_dir}/train.csv"
image_root_path = f"{synthetic_dir}/images/"
train_dataloader, train_dataset = get_dataloader(
    train_csv_path, image_root_path, batch_size, transform=transform, max_examples=None)

handwritten_dir = f"{DATA_BASE_PATH}/HandwrittenData"
hw_train_csv_path = f"{handwritten_dir}/train_hw.csv"
hw_image_root_path = f"{handwritten_dir}/images/train/"
hw_train_dataloader, hw_train_dataset = get_dataloader(
    hw_train_csv_path, hw_image_root_path, batch_size, transform=transform_hw, max_examples=None)

Loading Dataset...
Loaded.
Loading Dataset...
Loaded.


## Build Combined Vocabulary

In [28]:
#Create vocabulary consisting of both vocabs of synthetic and handwritten datasets
def combine_vocabs(v1, v2):
    """
    Merge two Vocabulary objects into a new Vocabulary.

    Token counts are summed. Ids follow first-seen order: every token of
    ``v1.freq_dict`` in its order, then tokens that appear only in ``v2``.
    UNK_TOKEN's count is reset to 1.
    """
    merged = dict(v1.freq_dict)
    for wd, count in v2.freq_dict.items():
        merged[wd] = merged.get(wd, 0) + count
    merged[UNK_TOKEN] = 1
    wd_to_id = {wd: idx for idx, wd in enumerate(merged)}
    id_to_wd = dict(enumerate(merged))
    return Vocabulary(merged, wd_to_id, id_to_wd)

# One shared vocabulary over both training sets; LatexFormulaDataset.__getitem__ reads this global.
VOCAB = combine_vocabs(v1=train_dataset.vocab, v2=hw_train_dataset.vocab)

# model = HandwritingToLatexModel(CONTEXT_SIZE, VOCAB, n_layers=1, hidden_size= HIDDEN_SIZE, embed_size=EMBED_SIZE, max_seq_len=MAX_SEQ_LEN).to(device)

In [29]:
def load_from_checkpoint(checkpoint_path):
    """
    Rebuild the model and its Adam optimiser from a checkpoint written by saveModel.

    The architecture is recreated from the notebook's global hyper-parameters
    (CONTEXT_SIZE, VOCAB, HIDDEN_SIZE, EMBED_SIZE, MAX_SEQ_LEN), so they must
    match the ones used when the checkpoint was saved. The lr passed to Adam
    is replaced by the saved optimiser state's param groups.

    NOTE: torch.load unpickles data — only load trusted checkpoint files.

    Returns:
        (model, optimizer) tuple, with the model already moved to ``device``.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    restored_model = HandwritingToLatexModel(
        CONTEXT_SIZE, VOCAB,
        max_seq_len=MAX_SEQ_LEN, n_layers=1, hidden_size=HIDDEN_SIZE, embed_size=EMBED_SIZE,
    ).to(device)
    restored_model.load_state_dict(checkpoint['model_state_dict'])

    restored_optim = torch.optim.Adam(restored_model.parameters(), lr=0.001)
    restored_optim.load_state_dict(checkpoint['optimizer_state_dict'])

    return restored_model, restored_optim

# Resume from the epoch-10 checkpoint of the "model4.0" run.
model_load, optims_load = load_from_checkpoint(SAVE_BASE_PATH + "model4.0_epoch_10.pt")

## Training

In [30]:
# Fine-tune the restored model on the handwritten set for 25 epochs, checkpointing every 2 epochs.
train(hw_train_dataloader, model_load, n_epochs=25, optimizer=optims_load, save_interval=2, save_prefix='FT4.0');

Epoch 1


Loss: 3.564805030822754:   3%|▎         | 8/282 [00:13<07:35,  1.66s/it] 


KeyboardInterrupt: 