In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderCNN(nn.Module):
    '''
    Sequential encoder assembled from a declarative list of layer descriptions.

    The layers are stored in an ``nn.ModuleList`` named ``layers`` and applied
    one after another, in the order given, by ``forward``.
    '''

    # Maps every accepted layer-type name to the torch.nn class that builds it.
    # The first eight entries are the original set; the remaining ones are
    # parameter-compatible additions (activation / normalisation / reshaping).
    _LAYER_FACTORIES = {
        'conv1d': nn.Conv1d,
        'conv2d': nn.Conv2d,
        'maxpool1d': nn.MaxPool1d,
        'maxpool2d': nn.MaxPool2d,
        'avgpool1d': nn.AvgPool1d,
        'avgpool2d': nn.AvgPool2d,
        'linear': nn.Linear,
        'dropout': nn.Dropout,
        'relu': nn.ReLU,
        'batchnorm1d': nn.BatchNorm1d,
        'batchnorm2d': nn.BatchNorm2d,
        'flatten': nn.Flatten,
    }

    def __init__(self, layers, hparams):
        '''
        Args:
            layers: Description of all layers in the Encoder: [(layer_type, {layer_params})]
                - layer types - ['conv1d', 'conv2d', 'maxpool1d', 'maxpool2d', 'avgpool1d', 'avgpool2d',
                                 'linear', 'dropout', 'relu', 'batchnorm1d', 'batchnorm2d', 'flatten']
                - layer_params - dict of keyword arguments forwarded unchanged to the
                                 corresponding torch.nn constructor

            hparams: Hyperparameters for the model (stored as ``self.hp``, not interpreted here)

        Raises:
            ValueError: if a layer type is not one of the names listed above.
        '''
        super(EncoderCNN, self).__init__()
        self.hp = hparams
        self.layers = nn.ModuleList()

        for layer_type, layer_params in layers:
            # Non-string keys can never match; treat them as invalid instead of
            # letting an unhashable value raise TypeError from the dict lookup.
            factory = self._LAYER_FACTORIES.get(layer_type) if isinstance(layer_type, str) else None
            if factory is None:
                raise ValueError(f'Invalid layer type: {layer_type}')
            self.layers.append(factory(**layer_params))

    def forward(self, input):
        '''Feed ``input`` through every configured layer in order and return the result.'''
        for layer in self.layers:
            input = layer(input)
        return input
    
class DecoderRNN(nn.Module):
    '''
    Single-layer, batch-first LSTM decoder that produces one vocabulary-sized
    score vector per time step.

    The module owns a token embedding table (``self.embedding``), but ``forward``
    does not apply it: callers build the LSTM input themselves and pass it in.
    '''

    def __init__(self, vocab, vocab_dict, input_size, embedding_size):
        '''
        Args:
            vocab: Sequence of tokens; ``len(vocab)`` sets the number of embedding
                rows and the size of the output layer
            vocab_dict: Token -> index mapping (stored as ``self.vocab_dict``, not used here)
            input_size: Size of the non-embedding part of each LSTM input step; the LSTM
                is built for inputs of size ``input_size + embedding_size``
            embedding_size: Size of the embedding vector, also used as the LSTM hidden size
        '''
        super(DecoderRNN, self).__init__()

        self.vocab = vocab
        self.vocab_dict = vocab_dict

        self.embedding = nn.Embedding(len(vocab), embedding_size)
        self.embedding_size = embedding_size
        self.lstm = nn.LSTM(input_size+embedding_size, embedding_size, batch_first=True)
        self.output = nn.Linear(embedding_size, len(vocab))

    def forward(self, input, hidden):
        '''
        Args:
            input: Input to the decoder, already sized for the LSTM
                (last dimension ``input_size + embedding_size``, batch first)
            hidden: Hidden state of the previous time step as an ``(h, c)`` tuple,
                or None to start from the LSTM's default zero state

        Returns:
            (output, hidden): per-step scores over the vocabulary and the new LSTM state
        '''
        # nn.LSTM treats hx=None as "start from zeros", so no branch is needed.
        output, hidden = self.lstm(input, hidden)
        output = self.output(output)

        return output, hidden
    
class EncoderDecoder(nn.Module):
    '''
    Thin container that registers an encoder and a decoder as the sub-modules
    ``encoder`` and ``decoder``, so both share one ``parameters()`` / ``state_dict()``.
    It defines no ``forward`` of its own; callers invoke the two parts directly.
    '''

    def __init__(self, encoder, decoder):
        nn.Module.__init__(self)
        self.encoder, self.decoder = encoder, decoder

In [2]:
# Load dataset
import torch.utils.data as data
from torchvision import transforms
from torchtext.vocab import build_vocab_from_iterator
import pandas as pd
from PIL import Image

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Special vocabulary tokens: padding, start-of-sequence and end-of-sequence.
PAD, SOS, EOS = "<pad>", "<sos>", "<eos>"

def load_img(path, size = (224, 224)):
    """
    Load an image file as an inverted, normalised 3-channel tensor.

    Args:
        path: Location of the image file (anything ``PIL.Image.open`` accepts).
        size: Target size handed to ``transforms.Resize``.

    Returns:
        Tensor produced by Resize -> ToTensor -> Normalize(mean=0.5, std=0.5),
        then inverted as ``1 - x``.
    """
    transform = transforms.Compose([
        transforms.Resize(size, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Image.open is lazy and keeps the file handle open; the context manager closes
    # it once the pixels have been read. convert("RGB") guarantees the three channels
    # that the 3-element mean/std above require (grayscale, palette and RGBA files
    # would otherwise fail inside Normalize). For RGB files it is a plain copy.
    with Image.open(path) as img:
        im = transform(img.convert("RGB"))

    return 1 - im

class Img2LatexDataset(data.Dataset):
    """
    Pairs formula images with their token-index sequences.

    The CSV at ``formula_path`` must provide a ``formula`` column (whitespace-separated
    tokens) and an ``image`` column (file name appended to ``img_dir``). The vocabulary
    is built in first-seen order from the formulas, followed by SOS, EOS and PAD.
    Every formula is encoded as ``[SOS, tokens..., EOS, PAD...]`` padded to a common
    length and stored in a new ``IndexList`` column of ``self.data_frame``.
    """

    def __init__(self, img_dir, formula_path, img_size = (224, 224)):
        """
        Args:
            img_dir: Prefix that is string-concatenated with each image file name.
            formula_path: Path of the CSV file with ``formula`` and ``image`` columns.
            img_size: Size passed through to ``load_img``.
        """
        self.data_frame = pd.read_csv(formula_path)
        self.img_dir = img_dir
        self.img_size = img_size

        self.token_to_idx = {}
        self.tokens = []

        # Split every formula exactly once; the result feeds the vocabulary,
        # the maximum length and the index encoding below.
        tokenized = [formula.split() for formula in self.data_frame["formula"]]

        for formula_tokens in tokenized:
            for token in formula_tokens:
                if token not in self.token_to_idx:
                    self.token_to_idx[token] = len(self.token_to_idx)
                    self.tokens.append(token)

        # Special tokens go last so the formula tokens keep their first-seen indices.
        for special_token in [SOS, EOS, PAD]:
            self.token_to_idx[special_token] = len(self.token_to_idx)
            self.tokens.append(special_token)

        sos_idx = self.token_to_idx[SOS]
        eos_idx = self.token_to_idx[EOS]
        pad_idx = self.token_to_idx[PAD]

        # +2 accounts for SOS and EOS; default=0 keeps an empty CSV from crashing max().
        max_len = max((len(formula_tokens) for formula_tokens in tokenized), default=0) + 2

        def indexer(formula_tokens):
            index_list = [sos_idx]
            index_list.extend(self.token_to_idx[token] for token in formula_tokens)
            index_list.append(eos_idx)
            index_list.extend([pad_idx] * (max_len - len(index_list)))
            return index_list

        # An explicit object Series keeps each cell a Python list aligned to the frame's index.
        self.data_frame["IndexList"] = pd.Series(
            [indexer(formula_tokens) for formula_tokens in tokenized],
            index=self.data_frame.index,
            dtype=object,
        )

    def __getitem__(self, index):
        """Return ``(image_tensor, index_tensor)`` for the row at position ``index``."""
        # Positional lookup: stays correct even if the frame's index labels are
        # no longer 0..n-1 (e.g. after rows were filtered out).
        img = load_img(self.img_dir + self.data_frame["image"].iloc[index], self.img_size)
        return img, torch.tensor(self.data_frame["IndexList"].iloc[index])

    def __len__(self):
        """Number of rows in the formula table."""
        return len(self.data_frame)

    def get_vocab(self):
        """Return ``(token_to_idx, tokens)``: the token -> index dict and the index -> token list."""
        return self.token_to_idx, self.tokens

# Image folder prefix (must end with a separator: file names are appended to it
# by plain string concatenation) and the CSV that lists the training formulas.
img_dir = "../data/SyntheticData/images/"
formula_dir = "../data/SyntheticData/train.csv"

# Reads the CSV and builds the vocabulary and padded index lists up front.
dataset = Img2LatexDataset(img_dir=img_dir, formula_path=formula_dir)


Using device: cuda


In [3]:
# Hyperparameter dict; the encoder keeps a reference to it as `enc.hp`.
hparams = dict(lr=0.001, batch_size=64, epochs=10)

# Channel widths of the conv stack: consecutive entries are (in, out) of one conv layer.
channel_seq = [3, 32, 64, 128, 256, 512]
num_conv_pool = 5

# Each stage is a 5x5 convolution followed by 2x2 max pooling.
enc_layers = []
stage_channels = zip(channel_seq[:num_conv_pool], channel_seq[1:num_conv_pool + 1])
for in_ch, out_ch in stage_channels:
    enc_layers += [
        ('conv2d', dict(in_channels=in_ch, out_channels=out_ch, kernel_size=5)),
        ('maxpool2d', dict(kernel_size=2)),
    ]

# Final 3x3 average pooling on top of the conv/pool stack.
enc_layers += [('avgpool2d', dict(kernel_size=(3, 3)))]

enc = EncoderCNN(enc_layers, hparams).to(device)
dec = DecoderRNN(dataset.tokens, dataset.token_to_idx, 512, 512).to(device)

model = EncoderDecoder(enc, dec).to(device)

In [4]:
# List the tensor type and shape of every parameter registered on the model.
for weight in model.parameters():
    print(type(weight.data), weight.shape)

<class 'torch.Tensor'> torch.Size([32, 3, 5, 5])
<class 'torch.Tensor'> torch.Size([32])
<class 'torch.Tensor'> torch.Size([64, 32, 5, 5])
<class 'torch.Tensor'> torch.Size([64])
<class 'torch.Tensor'> torch.Size([128, 64, 5, 5])
<class 'torch.Tensor'> torch.Size([128])
<class 'torch.Tensor'> torch.Size([256, 128, 5, 5])
<class 'torch.Tensor'> torch.Size([256])
<class 'torch.Tensor'> torch.Size([512, 256, 5, 5])
<class 'torch.Tensor'> torch.Size([512])
<class 'torch.Tensor'> torch.Size([549, 512])
<class 'torch.Tensor'> torch.Size([2048, 1024])
<class 'torch.Tensor'> torch.Size([2048, 512])
<class 'torch.Tensor'> torch.Size([2048])
<class 'torch.Tensor'> torch.Size([2048])
<class 'torch.Tensor'> torch.Size([549, 512])
<class 'torch.Tensor'> torch.Size([549])


In [6]:
# Integer id of the padding token in the dataset vocabulary; used below to
# locate PAD positions in the label tensors.
PAD_IDX = dataset.token_to_idx[PAD]

def remove_trailing_pads(labels, pad_idx=None):
    """
    Drop the trailing columns of ``labels`` that contain only padding.

    Args:
        labels: 2-D tensor of token ids, one row per sequence.
        pad_idx: Id of the padding token. Defaults to the module-level ``PAD_IDX``.

    Returns:
        ``labels[:, :width]`` where ``width`` is one past the last column that holds
        at least one non-pad token (0 when every entry is padding).
    """
    if pad_idx is None:
        pad_idx = PAD_IDX

    # Columns that hold a real token in at least one row.
    has_token = (labels != pad_idx).any(dim=0)
    token_cols = has_token.nonzero()

    # Cut after the LAST such column rather than counting them, so an all-pad
    # column in the middle of the batch cannot shorten the kept range.
    width = int(token_cols[-1].item()) + 1 if token_cols.numel() > 0 else 0

    return labels[:, :width]

# Sanity check of a saved checkpoint: run one shuffled batch through the model with
# teacher forcing and print the greedy (argmax) tokens next to the reference tokens.
loader = data.DataLoader(dataset, batch_size = 31, shuffle = True)

model_path = "./models/model.pt"
current_params_path = "./models/current_params.txt"

# map_location lets a checkpoint that was saved from the GPU load on a CPU-only machine.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()
print(f"LOADED MODEL to {device}")

images, labels = next(iter(loader))
images = images.to(device)
labels = labels.to(device)

labels = remove_trailing_pads(labels)

# Pure inference: no autograd graph is needed for the forward pass.
with torch.no_grad():
    # flatten(1) collapses the trailing spatial axes but always keeps the batch
    # axis, so a batch of size 1 needs no special handling.
    context_vec = model.encoder(images).flatten(start_dim=1)
    print(context_vec.shape)
    print(labels)
    print(labels.shape)

    # The image context is repeated for every time step and concatenated with the
    # embedding of the token at that step.
    tiled_context = context_vec.unsqueeze(1).repeat(1, labels.shape[1], 1)
    token_embeddings = model.decoder.embedding(labels)
    print(tiled_context.shape)
    print(token_embeddings.shape)
    target = nn.functional.one_hot(labels, num_classes=len(dataset.tokens)).float().to(device)
    inputs = torch.cat([tiled_context, token_embeddings], dim=2)

    output, _ = model.decoder(inputs, None)

print(output.shape)
print("H")
mask = labels == PAD_IDX
print(mask)
print(output[mask].shape)
print(target[mask].shape)
print("h")
print(output.shape)

# Greedy decoding: most likely token id at every position.
output = output.argmax(dim=2)
output_tokens = [dataset.tokens[token_idx] for token_idx in output[0].tolist()]
print(output_tokens)
# The prediction at position t is compared against the label at position t+1.
print([dataset.tokens[token_idx] for token_idx in labels[0, 1:].tolist()])


LOADED MODEL to cuda
torch.Size([31, 512])
tensor([[546,   0,   3,  ..., 548, 548, 548],
        [546,   0,  52,  ..., 548, 548, 548],
        [546,   0,  40,  ..., 548, 548, 548],
        ...,
        [546,   0,  21,  ..., 548, 548, 548],
        [546,   0,  55,  ..., 548, 548, 548],
        [546,   0, 130,  ..., 548, 548, 548]], device='cuda:0')
torch.Size([31, 161])
torch.Size([31, 161, 512])
torch.Size([31, 161, 512])
torch.Size([31, 161, 549])
H
tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True]], device='cuda:0')
torch.Size([2835, 549])
torch.Size([2835, 549])
h
torch.Size([31, 161, 549])
['$', '{', '\\cal', '}', '}', '=', '{', '1', '}', '}', '=', '\\frac', '{', '1', '}', '{',

In [7]:
# Row of the current batch to inspect.
t = 27

predicted_ids = output[t].tolist()
reference_ids = labels[t, 1:].tolist()  # reference sequence without its first token

output_tokens = [dataset.tokens[idx] for idx in predicted_ids]
print(output_tokens)
print([dataset.tokens[idx] for idx in reference_ids])

['$', '{', '_', '{', '1', '}', '}', '}', '=', '{', '2', '}', '=', '=', '\\frac', '\\,', '^', '{', '2', '}', '}', '{', '1', '}', '}', '}', '}', '{', '2', '}', '}', '{', '.', '\\frac', '{', '1', '}', '}', '{', '\\frac', '}', '{', '{', '\\frac', '}', '{', '1', '}', '}', '}', '}', '{', '{', '\\frac', '1', '}', '_', '{', '1', '}', '}', '}', '}', '{', '{', '\\frac', '}', '^', '{', '\\right)', '$', '.', '\\frac', '}', '}', '^', '.', '_', '\\,', '\\frac', '{', '}', '^', '{', 'i', '}', '}', '}', '=', '.', '\\frac', '}', '}', '^', '{', '$', '<eos>', '_', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{', '{']
['$', 'S', '_', '{', '[', '1', ']', '}', '^', '{', 'c', '}', '\\,', '=', '\\,', 'e', '^', '{', 'i', '\\vartheta', '_', '{', '[', '1', 

In [42]:
# context_vec.std(dim=0)

tensor([1734.1604, 2072.4707,  252.6192,  215.5585,  217.3692, 1952.0667,
        2512.1045,  230.7138,  220.1318, 2070.3169, 1944.1429, 1655.5834,
         178.0718,  213.9041, 1245.3400, 2149.6555, 1944.7310, 1942.9019,
         374.3238, 1622.1216,  504.0812,  623.3975,  214.0963,  230.4600,
        1964.6390,  205.6519,  273.4448,  226.8584, 1580.5632, 1955.5258,
         233.6819,  261.7012, 1666.2042, 1623.9436, 1670.8289, 1598.0292,
         243.2216,  220.0372, 1949.1882,  222.4726, 1978.6477, 1937.8784,
         302.8336,  217.6581, 1967.6904, 2024.3748, 1591.1747, 1569.5627,
        3150.0876,  218.5603,  216.4439, 1637.1022,  217.0717,  164.1452,
        1994.8130, 1587.9167, 1956.4849,  203.3632,  200.6049,  311.7800,
        1585.7943, 1646.0098, 1944.4889, 1584.9890, 1602.2743,  194.7858,
        1958.3536,  444.8534, 1958.7516, 1928.8143, 1658.7949, 1406.0608,
        2660.7786,  210.7152, 1591.8102, 2446.9768, 1605.8856,  199.1292,
        2092.7134, 1621.6423, 1997.470

In [4]:
# Training cell: resumes from the checkpoint at `model_path`, keeps a backup of the
# loaded weights, and trains with teacher forcing, saving every 10 batches.
PAD_IDX = dataset.token_to_idx[PAD]

# Targets are class indices. PAD targets are ignored and the loss is averaged over
# the remaining positions, which is the same value as masking a per-token loss
# computed against one-hot targets, without materialising the one-hot tensor.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# NOTE(review): the learning rate here and the epoch count below are hard-coded and
# differ from hparams["lr"] / hparams["epochs"]; confirm which values are intended.
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

# `device` is a torch.device, not a str, so test its type rather than comparing to "cuda".
if device.type == "cuda":
    torch.cuda.empty_cache()


def remove_trailing_pads(labels, pad_idx=None):
    """
    Drop the trailing columns of ``labels`` that contain only padding.

    Args:
        labels: 2-D tensor of token ids, one row per sequence.
        pad_idx: Id of the padding token. Defaults to the module-level ``PAD_IDX``.

    Returns:
        ``labels[:, :width]`` where ``width`` is one past the last column that holds
        at least one non-pad token (0 when every entry is padding).
    """
    if pad_idx is None:
        pad_idx = PAD_IDX

    has_token = (labels != pad_idx).any(dim=0)
    token_cols = has_token.nonzero()
    # Cut after the last column with a real token instead of counting such columns.
    width = int(token_cols[-1].item()) + 1 if token_cols.numel() > 0 else 0

    return labels[:, :width]


loader = data.DataLoader(dataset, batch_size = enc.hp["batch_size"], shuffle = True)
print(len(loader))
model_path = "./models/model.pt"
model_backup_path = "./models/model_backup.pt"
current_params_path = "./models/current_params.txt"

# map_location lets a checkpoint that was saved from the GPU load on a CPU-only machine.
state_dict = torch.load(model_path, map_location=device)
# Keep a copy of the weights we started from before the checkpoint file is overwritten.
torch.save(state_dict, model_backup_path)
model.load_state_dict(state_dict)
model.train()
print(f"LOADED MODEL to {device}")

prev_loss = 100
for epoch in range(100):
    curr_loss = 0
    for bidx, batch in enumerate(loader):
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        labels = remove_trailing_pads(labels)
        # flatten(1) collapses the trailing spatial axes but keeps the batch axis;
        # a bare squeeze() would also drop the batch axis for a final batch of
        # size 1 and break the concatenation below.
        context_vec = model.encoder(images).flatten(start_dim=1)

        # Teacher forcing: every step sees the image context plus the embedding of
        # the ground-truth token at that step.
        seq_len = labels.shape[1]
        tiled_context = context_vec.unsqueeze(1).expand(-1, seq_len, -1)
        inputs = torch.cat([tiled_context, model.decoder.embedding(labels)], dim=2)
        print(f"Running Batch {bidx}, Epoch {epoch}, Total Tokens: {labels.shape[1]}")
        output, _ = model.decoder(inputs, None)

        # The output at step t is scored against the token at step t+1, so the last
        # output and the first label (SOS) are dropped.
        logits = output[:, :-1, :]
        targets = labels[:, 1:]

        optimizer.zero_grad()
        # CrossEntropyLoss expects the class dimension second: (batch, classes, steps).
        loss = criterion(logits.transpose(1, 2), targets)
        # A fresh graph is built every batch, so it does not need to be retained.
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call synchronises with the device.
        loss_value = loss.item()
        print(f"Loss: {loss_value}")
        curr_loss += loss_value
        if bidx % 10 == 9:
            print(f"SAVING MODEL to {model_path}")
            torch.save(model.state_dict(), model_path)
            print("SAVED MODEL")
            print(f"Epoch: {epoch}, Batch: {bidx}, Loss: {loss_value}")
            # Progress file is best-effort: only file-system errors are tolerated.
            try:
                with open(current_params_path, 'w') as f:
                    f.write(f"Epoch: {epoch}, Batch: {bidx}, Loss: {loss_value}")
            except OSError:
                print("\n Could not write to file \n")
    print(f"AVG LOSS: {(curr_loss)/len(loader)}, Epoch: {epoch+1}")
    prev_loss = curr_loss
        

1172
LOADED MODEL to cuda
Running Batch 0, Epoch 0, Total Tokens: 123
Loss: 2.506558418273926
Running Batch 1, Epoch 0, Total Tokens: 142
Loss: 2.6690030097961426
Running Batch 2, Epoch 0, Total Tokens: 142
Loss: 2.5177550315856934
Running Batch 3, Epoch 0, Total Tokens: 161
Loss: 2.6396701335906982
Running Batch 4, Epoch 0, Total Tokens: 145
Loss: 2.5864717960357666
Running Batch 5, Epoch 0, Total Tokens: 154
Loss: 2.5781874656677246
Running Batch 6, Epoch 0, Total Tokens: 145
Loss: 2.533294677734375
Running Batch 7, Epoch 0, Total Tokens: 130
Loss: 2.6249117851257324
Running Batch 8, Epoch 0, Total Tokens: 131
Loss: 2.603182315826416
Running Batch 9, Epoch 0, Total Tokens: 145
Loss: 2.6024012565612793
SAVING MODEL to ./models/model.pt
SAVED MODEL
Epoch: 0, Batch: 9, Loss: 2.6024012565612793
Running Batch 10, Epoch 0, Total Tokens: 127
Loss: 2.5511460304260254
Running Batch 11, Epoch 0, Total Tokens: 118
Loss: 2.5721981525421143
Running Batch 12, Epoch 0, Total Tokens: 146
Loss: 2.714

: 