In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderCNN(nn.Module):
    '''
    Sequential encoder assembled from a declarative list of layer specifications.

    Each entry of `layers` is turned into one torch.nn module; `forward` feeds
    the input through those modules in the order they were listed.
    '''

    # Layer-type string accepted in `layers` -> torch.nn class that implements it.
    _LAYER_CLASSES = {
        'conv1d': nn.Conv1d,
        'conv2d': nn.Conv2d,
        'maxpool1d': nn.MaxPool1d,
        'maxpool2d': nn.MaxPool2d,
        'avgpool1d': nn.AvgPool1d,
        'avgpool2d': nn.AvgPool2d,
        'linear': nn.Linear,
        'dropout': nn.Dropout,
    }

    def __init__(self, layers, hparams):
        '''
        Args:
            layers: Description of all layers in the Encoder: [(layer_type, {layer_params})]
                - layer types - ['conv1d', 'conv2d', 'maxpool1d', 'maxpool2d', 'avgpool1d', 'avgpool2d', 'linear', 'dropout']
                - layer_params - dict of keyword arguments forwarded to the layer's constructor

            hparams: Hyperparameters for the model (stored on `self.hp`, not interpreted here)

        Raises:
            ValueError: if a layer_type is not one of the supported strings
        '''
        super(EncoderCNN, self).__init__()
        self.hp = hparams
        self.layers = nn.ModuleList()

        for layer_type, layer_params in layers:
            # Non-string types can never match; treat them as invalid instead of
            # letting an unhashable value blow up inside the dict lookup.
            layer_cls = self._LAYER_CLASSES.get(layer_type) if isinstance(layer_type, str) else None
            if layer_cls is None:
                raise ValueError(f'Invalid layer type: {layer_type}')
            self.layers.append(layer_cls(**layer_params))

    def forward(self, input):
        '''Apply every configured layer to `input` in order and return the result.'''
        for layer in self.layers:
            input = layer(input)
        return input
    
class DecoderRNN(nn.Module):
    '''
    Single-layer LSTM decoder that maps per-step feature vectors to vocabulary logits.

    The module owns a token embedding table, but `forward` does not apply it:
    callers are expected to look tokens up in `self.embedding` themselves and
    concatenate the result with their other per-step features.
    '''

    def __init__(self, vocab, vocab_dict, input_size, embedding_size):
        '''
        Args:
            vocab: Collection of tokens; its length sizes the embedding table and the output layer
            vocab_dict: Token -> index mapping (stored on the module, not used by it)
            input_size: Width of the non-embedding part of each step's input
            embedding_size: Size of the embedding vector; also used as the LSTM hidden size
        '''
        super(DecoderRNN, self).__init__()

        self.vocab = vocab
        self.vocab_dict = vocab_dict

        self.embedding = nn.Embedding(len(vocab), embedding_size)
        self.embedding_size = embedding_size
        # Each step's input is expected to be input_size features followed by one embedding.
        self.lstm = nn.LSTM(input_size+embedding_size, embedding_size, batch_first=True)
        self.output = nn.Linear(embedding_size, len(vocab))

    def forward(self, input, hidden=None):
        '''
        Args:
            input: Input to the decoder. Because the LSTM is batch_first, a 3-D tensor is read as
                (batch, seq, input_size + embedding_size); nn.LSTM interprets a 2-D tensor as one
                unbatched sequence, so add a length-1 sequence axis for a batched single step.
            hidden: (h, c) state of the previous time step, or None to start from zeros

        Returns:
            (logits, hidden): vocabulary logits for every step and the updated LSTM state
        '''
        # nn.LSTM treats hx=None as "zero initial state", so no branching is needed.
        output, hidden = self.lstm(input, hidden)
        output = self.output(output)

        return output, hidden
    
class EncoderDecoder(nn.Module):
    '''Container that pairs an encoder module with a decoder module.'''

    def __init__(self, encoder, decoder):
        '''
        Args:
            encoder: Module that turns the model input into a context representation
            decoder: Module that generates output from that representation
        '''
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def predict(self, input):
        '''
        Encode `input` and set up the first decoder token.

        NOTE(review): this method is unfinished. It only prepares the context
        vector and an initial token tensor and then returns None; the decoding
        loop still has to be written.
        '''
        context_vec = self.encoder(input)
        # Allocate on the same device as the input instead of an unconditional
        # .cuda(), which raises on hosts without a GPU.
        prev_token = torch.ones((input.shape[0]), dtype=torch.long, device=input.device)
        # TODO: the value 1 is a placeholder for the start token id; confirm it
        # against the vocabulary before implementing decoding.

        

In [2]:
# Load dataset
import torch.utils.data as data
from torchvision import transforms
from torchtext.vocab import build_vocab_from_iterator
import pandas as pd
from PIL import Image

# Special vocabulary tokens: padding, start-of-sequence and end-of-sequence markers.
PAD, SOS, EOS = "<pad>", "<sos>", "<eos>"

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

def load_img(path, size = (224, 224)):
    '''
    Load an image file as a normalised, inverted 3-channel tensor.

    Args:
        path: Location of the image file
        size: Target size passed to transforms.Resize

    Returns:
        Tensor produced by Resize -> ToTensor -> Normalize(mean=0.5, std=0.5),
        then flipped with `1 - x`. With that normalisation a white pixel ends
        up at 0 and a black pixel at 2.
    '''
    transform = transforms.Compose([transforms.Resize(size, antialias=True), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    # The context manager releases the file handle deterministically.
    with Image.open(path) as img:
        # Normalize is configured for three channels, so grayscale, palette or
        # RGBA files must be converted first or the transform fails.
        im = transform(img.convert("RGB"))
    im = 1 - im
    return im

class Img2LatexDataset(data.Dataset):
    '''
    Dataset of (image tensor, padded token-id tensor) pairs read from a CSV.

    The CSV must provide an "image" column (file name appended to `img_dir`)
    and a "formula" column (whitespace-separated tokens). Every formula is
    encoded as SOS + tokens + EOS and right-padded with PAD to the length of
    the longest formula, then stored in a new "IndexList" column.
    '''

    def __init__(self, img_dir, formula_path, img_size = (224, 224), vocab = None):
        '''
        Args:
            img_dir: Prefix that is concatenated with each "image" entry to form the file path
            formula_path: CSV path (or file-like object) handed to pd.read_csv
            img_size: Size forwarded to load_img
            vocab: Optional (token_to_idx, tokens) pair, e.g. another dataset's get_vocab(),
                to reuse an existing vocabulary. When omitted, the vocabulary is built from
                this CSV in first-occurrence order with SOS, EOS and PAD appended last.

        Raises:
            ValueError: if `vocab` is given and a formula contains a token it does not know
        '''
        self.data_frame = pd.read_csv(formula_path)
        self.img_dir = img_dir
        self.img_size = img_size

        # Tokenise once and reuse the result for vocabulary building, length
        # computation and encoding.
        tokenized = [formula.split() for formula in self.data_frame["formula"]]

        if vocab is None:
            self.token_to_idx = {}
            self.tokens = []

            for formula_tokens in tokenized:
                for token in formula_tokens:
                    if token not in self.token_to_idx:
                        self.token_to_idx[token] = len(self.token_to_idx)
                        self.tokens.append(token)

            for special_token in [SOS, EOS, PAD]:
                # If the data already contains the literal special token, keep its id;
                # re-adding it would leave two different tokens sharing one index.
                if special_token not in self.token_to_idx:
                    self.token_to_idx[special_token] = len(self.token_to_idx)
                    self.tokens.append(special_token)
        else:
            self.token_to_idx, self.tokens = vocab
            unknown = sorted({token for formula_tokens in tokenized for token in formula_tokens if token not in self.token_to_idx})
            if unknown:
                raise ValueError(f"Tokens missing from the supplied vocabulary: {unknown}")

        sos_idx = self.token_to_idx[SOS]
        eos_idx = self.token_to_idx[EOS]
        pad_idx = self.token_to_idx[PAD]

        # +2 leaves room for the SOS and EOS markers around the longest formula.
        max_len = max(len(formula_tokens) for formula_tokens in tokenized) + 2

        def indexer(formula_tokens):
            index_list = [sos_idx]
            index_list.extend(self.token_to_idx[token] for token in formula_tokens)
            index_list.append(eos_idx)
            index_list.extend([pad_idx] * (max_len - len(index_list)))

            return index_list

        # Build an explicit Series so each cell holds one Python list of ids.
        self.data_frame["IndexList"] = pd.Series(
            [indexer(formula_tokens) for formula_tokens in tokenized],
            index=self.data_frame.index,
        )

    def __getitem__(self, index):
        '''Return (image tensor, int64 tensor of padded token ids) for the row at position `index`.'''
        img = load_img(self.img_dir + self.data_frame["image"].iloc[index], self.img_size)
        return img, torch.tensor(self.data_frame["IndexList"].iloc[index], dtype=torch.long)

    def __len__(self):
        '''Number of rows in the CSV.'''
        return len(self.data_frame)

    def get_vocab(self):
        '''Return (token_to_idx, tokens): the token -> id mapping and the id -> token list.'''
        return self.token_to_idx, self.tokens

# Synthetic training set: rendered formula images plus the CSV that lists
# each image file together with its formula.
img_dir = "../data/SyntheticData/images/"
formula_dir = "../data/SyntheticData/train.csv"  # despite the name, this is the CSV file path

dataset = Img2LatexDataset(img_dir=img_dir, formula_path=formula_dir)


Using device: cuda


In [6]:
# Training hyperparameters; stored on the encoder as `enc.hp`.
hparams = {"lr": 0.0001, "batch_size": 96, "epochs": 10}

# Channel widths of the convolutional stack: stage k maps
# channel_seq[k] -> channel_seq[k + 1].
channel_seq = [3, 32, 64, 128, 256, 512]
num_conv_pool = 5

# Each stage is a 5x5 convolution followed by a 2x2 max-pool; a final 3x3
# average-pool closes the encoder.
enc_layers = []
for stage in range(num_conv_pool):
    conv_cfg = {'in_channels': channel_seq[stage], 'out_channels': channel_seq[stage + 1], 'kernel_size': 5}
    pool_cfg = {'kernel_size': 2}
    enc_layers += [('conv2d', conv_cfg), ('maxpool2d', pool_cfg)]
enc_layers += [('avgpool2d', {'kernel_size': (3, 3)})]

enc = EncoderCNN(enc_layers, hparams).to(device)
dec = DecoderRNN(dataset.tokens, dataset.token_to_idx, 512, 512).to(device)

model = EncoderDecoder(enc, dec).to(device)

In [None]:
# List the storage type and shape of every learnable tensor in the model.
for param in model.parameters():
    storage_type, shape = type(param.data), param.size()
    print(storage_type, shape)

In [7]:
# Per-token losses are kept (reduction='none') so the training loop can mask
# out PAD positions before averaging.
criterion = nn.CrossEntropyLoss(reduction='none')
# Take the learning rate from hparams so the config cell stays the single source of truth.
optimizer = torch.optim.Adam(model.parameters(), lr = hparams["lr"])
PAD_IDX = dataset.token_to_idx[PAD]
# `device` is a torch.device; comparing it with the string "cuda" is always
# False, so the device type has to be checked instead.
if device.type == "cuda":
    torch.cuda.empty_cache()
def remove_trailing_pads(labels, pad_idx=None):
    '''
    Drop the trailing columns of a 2-D label batch that contain only padding.

    Args:
        labels: Tensor indexed as [row, column]; columns are trimmed from the right
        pad_idx: Padding id to look for; defaults to the module-level PAD_IDX

    Returns:
        A view of `labels` that ends at the last column holding at least one
        non-pad entry (zero columns if everything is padding).
    '''
    if pad_idx is None:
        pad_idx = PAD_IDX

    # Columns that hold at least one real token.
    has_token = (labels != pad_idx).any(dim=0)
    token_cols = has_token.nonzero()
    if token_cols.numel() == 0:
        return labels[:, :0]

    # Cut after the LAST such column. Counting the non-empty columns would drop
    # real tokens whenever an interior column happened to be all padding.
    return labels[:, : int(token_cols.max()) + 1]

loader = data.DataLoader(dataset, batch_size = enc.hp["batch_size"], shuffle = True)
print(len(loader))
model_path = "./models/model.pt"
model_backup_path = "./models/model_backup.pt"
current_params_path = "./models/current_params.txt"

# Resume from the last checkpoint and keep a backup copy of it before it is
# overwritten. map_location lets a checkpoint saved on GPU load on any device.
state_dict = torch.load(model_path, map_location=device)
torch.save((state_dict), model_backup_path)
model.load_state_dict(state_dict)
model.train()
print(f"LOADED MODEL to {device}")

# teacher_forcing: always feed the ground-truth tokens to the decoder.
# fifty_fifty: use teacher forcing on odd batches only (when teacher_forcing is off).
fifty_fifty = False
teacher_forcing = True

prev_loss = 100
for epoch in range(8):
    curr_loss = 0
    for bidx, batch in enumerate(loader):
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        labels = remove_trailing_pads(labels)
        # flatten(start_dim=1) keeps the batch axis even when the batch holds a
        # single sample; a bare squeeze() would remove it as well.
        context_vec = model.encoder(images).flatten(start_dim=1)
        if (bidx%2 and fifty_fifty) or teacher_forcing:
            # Teacher forcing: every step sees the context vector plus the
            # embedding of the ground-truth token at that position.
            inputs = torch.cat([context_vec.unsqueeze(1).repeat(1, labels.shape[1], 1), model.decoder.embedding(labels)], dim=2)
            print(f"Running Batch {bidx}, Epoch {epoch}, Total Tokens: {labels.shape[1]}")
            output, _ = model.decoder(inputs, None)

            # The prediction made after the final token has no target.
            output = output[:, :-1, :]

        else:
            # Free running: feed the model its own greedy prediction, one step at a time.
            prev_token = torch.full((labels.shape[0],), dataset.token_to_idx[SOS], dtype=torch.long, device=device)
            hidden = None
            step_logits = []

            for _ in range(labels.shape[1]-1):
                prev_token_embed = model.decoder.embedding(prev_token)
                # Add a length-1 sequence axis: a 2-D input would be read by the
                # LSTM as ONE unbatched sequence whose time steps are the samples.
                step_input = torch.cat([context_vec, prev_token_embed], dim=1).unsqueeze(1)
                logits, hidden = model.decoder(step_input, hidden)
                logits = logits.squeeze(1)
                step_logits.append(logits)
                prev_token = logits.argmax(dim=1)

            output = torch.stack(step_logits, dim=1)

        # Class-index targets give the same loss as one-hot float targets
        # without materialising a (batch, steps, vocab) tensor.
        target = labels[:, 1:]
        mask = target != PAD_IDX

        optimizer.zero_grad()
        loss = criterion(output.transpose(1, 2), target)
        # Average over real tokens only; PAD positions contribute nothing.
        loss = loss * mask
        loss = loss.sum() / mask.sum()
        # The graph is rebuilt for every batch, so it does not need to be retained.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        print(f"Loss: {batch_loss}")
        curr_loss += batch_loss
        if bidx % 10 == 9:
            print(f"SAVING MODEL to {model_path}")
            torch.save(model.state_dict(), model_path)
            print("SAVED MODEL")
            print(f"Epoch: {epoch}, Batch: {bidx}, Loss: {batch_loss}")
            try:
                with open(current_params_path, 'w') as f:
                    f.write(f"Epoch: {epoch}, Batch: {bidx}, Loss: {batch_loss}")
            except OSError:
                # Progress logging is best effort, but only file-system errors are
                # ignored; a bare except would also swallow KeyboardInterrupt.
                print("\n Could not write to file \n")
    print(f"AVG LOSS: {(curr_loss)/len(loader)}, Epoch: {epoch+1}")
    prev_loss = curr_loss

782
LOADED MODEL to cuda
Running Batch 0, Epoch 0, Total Tokens: 171
Loss: 6.324601650238037
Running Batch 1, Epoch 0, Total Tokens: 135
Loss: 6.2754950523376465
Running Batch 2, Epoch 0, Total Tokens: 150
Loss: 6.211289882659912
Running Batch 3, Epoch 0, Total Tokens: 186
Loss: 6.125494480133057
Running Batch 4, Epoch 0, Total Tokens: 203
Loss: 5.97739839553833
Running Batch 5, Epoch 0, Total Tokens: 224
Loss: 5.7900309562683105
Running Batch 6, Epoch 0, Total Tokens: 143
Loss: 5.58447790145874
Running Batch 7, Epoch 0, Total Tokens: 143
Loss: 5.351614952087402
Running Batch 8, Epoch 0, Total Tokens: 176
Loss: 5.184464931488037
Running Batch 9, Epoch 0, Total Tokens: 146
Loss: 4.966055870056152
SAVING MODEL to ./models/model.pt
SAVED MODEL
Epoch: 0, Batch: 9, Loss: 4.966055870056152
Running Batch 10, Epoch 0, Total Tokens: 126
Loss: 4.906158924102783
Running Batch 11, Epoch 0, Total Tokens: 140
Loss: 4.785327911376953
Running Batch 12, Epoch 0, Total Tokens: 155
Loss: 4.72297716140747

KeyboardInterrupt: 

In [None]:
# Keep a separately named snapshot of the weights after synthetic-data training.
# NOTE(review): the fine-tuning cell loads "./models/model_synth.pt", which is
# not the file written here - confirm which checkpoint is intended.
synth_checkpoint_path = "./models/model_50synth.pt"
torch.save(model.state_dict(), synth_checkpoint_path)

In [None]:
# fine tuning
handwritten_imgs = "../data/HandwrittenData/images/"
handwritten_labels = "../data/HandwrittenData/train_hw.csv"

handwritten_dataset = Img2LatexDataset(handwritten_imgs, handwritten_labels)

# Img2LatexDataset numbers tokens in the order they first appear in its own
# CSV, so the ids it just produced do not line up with the vocabulary the
# decoder was built from (dataset.tokens). Re-encode the handwritten formulas
# with the model's vocabulary before training on them.
hw_formula_tokens = [formula.split() for formula in handwritten_dataset.data_frame["formula"]]
unknown_tokens = sorted({token for formula_tokens in hw_formula_tokens for token in formula_tokens if token not in dataset.token_to_idx})
if unknown_tokens:
    raise ValueError(f"Handwritten formulas contain tokens missing from the model vocabulary: {unknown_tokens}")

hw_max_len = max(len(formula_tokens) for formula_tokens in hw_formula_tokens) + 2  # room for SOS and EOS

def encode_with_model_vocab(formula_tokens):
    '''Encode one tokenised formula as SOS + tokens + EOS, right-padded with PAD to hw_max_len.'''
    index_list = [dataset.token_to_idx[SOS]]
    index_list.extend(dataset.token_to_idx[token] for token in formula_tokens)
    index_list.append(dataset.token_to_idx[EOS])
    index_list.extend([dataset.token_to_idx[PAD]] * (hw_max_len - len(index_list)))
    return index_list

handwritten_dataset.data_frame["IndexList"] = pd.Series(
    [encode_with_model_vocab(formula_tokens) for formula_tokens in hw_formula_tokens],
    index=handwritten_dataset.data_frame.index,
)
handwritten_dataset.token_to_idx = dataset.token_to_idx
handwritten_dataset.tokens = dataset.tokens

handwritten_loader = data.DataLoader(handwritten_dataset, batch_size=64, shuffle=True)
criterion = nn.CrossEntropyLoss(reduction='none')
# Only the encoder's parameters are handed to the optimizer, so the decoder
# weights are not updated during fine-tuning.
optimizer = torch.optim.Adam(model.encoder.parameters(), lr = 0.001)
PAD_IDX = handwritten_dataset.token_to_idx[PAD]
# `device` is a torch.device; comparing it with the string "cuda" is always
# False, so the device type has to be checked instead.
if device.type == "cuda":
    torch.cuda.empty_cache()
def remove_trailing_pads(labels, pad_idx=None):
    '''
    Drop the trailing columns of a 2-D label batch that contain only padding.

    Args:
        labels: Tensor indexed as [row, column]; columns are trimmed from the right
        pad_idx: Padding id to look for; defaults to the module-level PAD_IDX

    Returns:
        A view of `labels` that ends at the last column holding at least one
        non-pad entry (zero columns if everything is padding).
    '''
    if pad_idx is None:
        pad_idx = PAD_IDX

    # Columns that hold at least one real token.
    has_token = (labels != pad_idx).any(dim=0)
    token_cols = has_token.nonzero()
    if token_cols.numel() == 0:
        return labels[:, :0]

    # Cut after the LAST such column. Counting the non-empty columns would drop
    # real tokens whenever an interior column happened to be all padding.
    return labels[:, : int(token_cols.max()) + 1]

print(len(handwritten_loader))
model_path = "./models/model_hw.pt"
model_backup_path = "./models/model_hw_backup.pt"
current_params_path = "./models/current_params_hw.txt"

# Start from the weights trained on synthetic data and keep a backup copy.
# map_location lets a checkpoint saved on GPU load on any device.
# NOTE(review): no cell visible in this notebook writes "./models/model_synth.pt"
# - confirm this is the intended checkpoint.
state_dict = torch.load("./models/model_synth.pt", map_location=device)
torch.save((state_dict), model_backup_path)
model.load_state_dict(state_dict)
model.train()
print(f"LOADED MODEL to {device}")

# teacher_forcing: always feed the ground-truth tokens to the decoder.
# fifty_fifty: use teacher forcing on odd batches only (when teacher_forcing is off).
fifty_fifty = True
teacher_forcing = False

prev_loss = 100
for epoch in range(1):
    curr_loss = 0
    for bidx, batch in enumerate(handwritten_loader):
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        labels = remove_trailing_pads(labels)
        # flatten(start_dim=1) keeps the batch axis even when the batch holds a
        # single sample; a bare squeeze() would remove it as well.
        context_vec = model.encoder(images).flatten(start_dim=1)
        if (bidx%2 and fifty_fifty) or teacher_forcing:
            # Teacher forcing: every step sees the context vector plus the
            # embedding of the ground-truth token at that position.
            inputs = torch.cat([context_vec.unsqueeze(1).repeat(1, labels.shape[1], 1), model.decoder.embedding(labels)], dim=2)
            print(f"Running Batch {bidx}, Epoch {epoch}, Total Tokens: {labels.shape[1]}")
            output, _ = model.decoder(inputs, None)

            # The prediction made after the final token has no target.
            output = output[:, :-1, :]

        else:
            # Free running: feed the model its own greedy prediction, one step at a time.
            prev_token = torch.full((labels.shape[0],), handwritten_dataset.token_to_idx[SOS], dtype=torch.long, device=device)
            hidden = None
            step_logits = []

            for _ in range(labels.shape[1]-1):
                prev_token_embed = model.decoder.embedding(prev_token)
                # Add a length-1 sequence axis: a 2-D input would be read by the
                # LSTM as ONE unbatched sequence whose time steps are the samples.
                step_input = torch.cat([context_vec, prev_token_embed], dim=1).unsqueeze(1)
                logits, hidden = model.decoder(step_input, hidden)
                logits = logits.squeeze(1)
                step_logits.append(logits)
                prev_token = logits.argmax(dim=1)

            output = torch.stack(step_logits, dim=1)

        # Class-index targets give the same loss as one-hot float targets
        # without materialising a (batch, steps, vocab) tensor.
        target = labels[:, 1:]
        mask = target != PAD_IDX

        optimizer.zero_grad()
        loss = criterion(output.transpose(1, 2), target)
        # Average over real tokens only; PAD positions contribute nothing.
        loss = loss * mask
        loss = loss.sum() / mask.sum()
        # The graph is rebuilt for every batch, so it does not need to be retained.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        print(f"Loss: {batch_loss}")
        curr_loss += batch_loss
        if bidx % 10 == 9:
            print(f"SAVING MODEL to {model_path}")
            torch.save(model.state_dict(), model_path)
            print("SAVED MODEL")
            print(f"Epoch: {epoch}, Batch: {bidx}, Loss: {batch_loss}")
            try:
                with open(current_params_path, 'w') as f:
                    f.write(f"Epoch: {epoch}, Batch: {bidx}, Loss: {batch_loss}")
            except OSError:
                # Progress logging is best effort, but only file-system errors are
                # ignored; a bare except would also swallow KeyboardInterrupt.
                print("\n Could not write to file \n")
    print(f"AVG LOSS: {(curr_loss)/len(handwritten_loader)}, Epoch: {epoch+1}")
    prev_loss = curr_loss