In [1]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import preprocessing as pre
import dataset as ds
import utils
from vocab import Vocab, label_to_index, indices_to_latex
from model import CRNN, init_weights
from IPython.display import clear_output

# **GET TRAIN/VAL/TEST SPLIT** #

In [None]:
data_root = '../data'
data_folders = ['IM2LATEX-100K', 'IM2LATEX-100K-HANDWRITTEN']
img_folders = ['formula_images', 'images']

# imgs stored in data_root/data_folders/img_folders
assert len(data_folders) == len(img_folders)

# get test/train/val indices
# change ftype if different image type used
df_list = []
for i in range(len(data_folders)):
    print(f'{data_root}/{data_folders[i]}')
    df = pre.data_to_df(f'{data_root}/{data_folders[i]}', max_entries=None)
    df['dataset'] = data_folders[i]
    df_list.append(df)

# **APPEND LABELS**

In [None]:
label_folders = ['im2latex_formulas.lst', 'formulas.lst']

for i, df in enumerate(df_list):
    labels = pre.extract_labels(f'{data_root}/{data_folders[i]}/{label_folders[i]}')
    df['label'] = labels.loc[df['index']].values
    df['label_token_indices'] = np.nan
    # image shape after being processed
    df['padded_height'] = np.nan
    df['padded_width'] = np.nan

# **MERGE AND SHUFFLE DATAFRAMES** #

In [None]:
merge_shuffle_df = pd.concat(df_list).sample(frac=1).reset_index(drop=True).astype('object')
merge_shuffle_df.to_csv('../saved/merge_shuffle_df.csv')

## *If loading from saved df* ##

In [None]:
merge_shuffle_df = pd.read_csv('../saved/merge_shuffle_df.csv').astype('object')

# **TOKENIZE LATEX, CREATE FULL DATAFRAME** #

In [4]:
image_sizes = [
    (40, 160), (40, 200), (40, 240), (40, 280), (40, 320), \
    (50, 120), (50, 200), (50, 240), (50, 280), (50, 320)
]

In [None]:
merge_shuffle_df, images, vocab = pre.scale_images(
    merge_shuffle_df,
    maxlen=200,
    image_sizes=image_sizes,
    reshape_strat='pad'
)

In [None]:
merge_shuffle_df.to_csv('../saved/merge_shuffle_df_processed.csv')
image_handler = open('../saved/images', 'wb')
pickle.dump(images, image_handler)
vocab_handler = open('../saved/vocab', 'wb')
pickle.dump(vocab, vocab_handler)
image_handler.close()
vocab_handler.close()

## *Load items in if everything is saved* ##

In [5]:
merge_shuffle_df = pd.read_csv('../saved/merge_shuffle_df_processed.csv')
images = pickle.load(open('../saved/images', 'rb'))
vocab = Vocab()
vocab = pickle.load(open('../saved/vocab', 'rb'))

merge_shuffle_df = merge_shuffle_df.astype('object')

In [6]:
for idx in merge_shuffle_df.index:
    clear_output(wait=True)
    print(idx)
    merge_shuffle_df.at[idx, 'label'] \
        = [x[1:-1] for x in np.asarray(merge_shuffle_df['label'].iloc[idx][1:-1].split(', '))]
    merge_shuffle_df.at[idx, 'label_token_indices'] \
        = np.array(merge_shuffle_df['label_token_indices'].iloc[idx][1:-1].split(), dtype=float)

102460


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



# **CREATE DATALOADERS** #

In [7]:
train_dataloaders = [
    ds.gen_dataloader(
        merge_shuffle_df=merge_shuffle_df,
        images=images,
        split='train',
        extra_cond=((merge_shuffle_df['padded_height'] == h) & (merge_shuffle_df['padded_width'] == w)),
        batch_size=30, 
        shuffle=True,
        merge_train_validate=True
    ) for (h, w) in image_sizes
]

test_dataloaders = [
    ds.gen_dataloader(
        merge_shuffle_df=merge_shuffle_df,
        images=images,
        split='test',
        extra_cond=((merge_shuffle_df['padded_height'] == h) & (merge_shuffle_df['padded_width'] == w)),
        batch_size=30, 
        shuffle=True,
        merge_train_validate=True
    ) for (h, w) in image_sizes
]

# **CONSTRUCT MODEL** #

In [8]:
import torch
from torch import nn
from torch.nn import functional as F

In [9]:
model = CRNN(vocab, embed_size=32)
model.apply(init_weights)
model.to('cuda:0')

CRNN(
  (cnn): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
    (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): PReLU(num_parameters=1)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): PReLU(num_parameters=1)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil

## *If there are trained parameters to load in* ##

In [11]:
model.load_state_dict(torch.load('../saved/run4/CRNN_params_epoch_5.pt'))

<All keys matched successfully>

In [None]:
# dataloaders is a list
def train_loop(model, train_dataloaders, vocab, num_epochs, lr=0.1, test_dataloaders=None, main_device='cpu'):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=1e-6)
    ce_loss = nn.CrossEntropyLoss(reduction='none')
    ctc_loss = nn.CTCLoss(blank=vocab.get_index('<nul>'), reduction='none', zero_infinity=True)
    animator = utils.Animator(
        measurement_names=['ctc_loss'],
        refresh=1
    )
    metrics = utils.MetricBuffer(['sum_ctc_loss', 'n_tokens'])
    
    for epoch in range(num_epochs):
        torch.autograd.set_detect_anomaly(True)

        n_imgs = 0
        for dataloader in train_dataloaders:
            for i, (imgs, labels) in enumerate(dataloader):
                model.train()
                optimizer.zero_grad()

                imgs = torch.tensor(imgs).to(main_device).float()
                labels = labels.to(main_device).long()
                output = model(imgs, teacher_forcing=False)
                # output.shape = (seq_len, batch_size, vocab_size)

                # for CTCLoss
                log_softmax_probs = F.log_softmax(output, dim=2)
                labels = labels[:, 1:log_softmax_probs.shape[0]+1]
                # requires shape (seq_len, batch_size, vocab_size)
                print(log_softmax_probs.shape, labels.shape, (torch.ones((labels.shape[0])) * labels.shape[1]).shape)
                main_loss = ctc_loss(
                    log_softmax_probs,
                    labels, 
                    (torch.ones((labels.shape[0])) * labels.shape[1]).int().to(output.device),
                    (labels != vocab.get_index('<pad>')).sum(dim=1).int().to(output.device)
                )

                # for CrossEntropyLoss
                second_loss = ce_loss(
                    output.permute(1, 2, 0),
                    labels
                )

                main_loss.sum().backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
                optimizer.step()

                metrics.update([('sum_ctc_loss', main_loss.sum().cpu().detach().numpy())])
                metrics.update([('n_tokens', (labels != vocab.get_index('<pad>')).sum(dim=[0, 1]).cpu().detach().numpy())])

                n_imgs += imgs.shape[0]
                # print(n_imgs)
                
                print(i)
                if i % 10 == 0:
                    clear_output(wait=True)
                    print(f"n_imgs #{n_imgs}")
                    print(f"Current loss: {metrics['sum_ctc_loss'] / metrics['n_tokens']}")
                    # try printing from test set
                    #for test_imgs, test_labels in gen_dataloader('train', batch_size=2, shuffle=True):
                    if test_dataloaders is not None:
                        test_loader = test_dataloaders[np.random.randint(0, len(test_dataloaders))]
                        for test_imgs, test_labels in test_loader:
                            model.eval()
                            with torch.no_grad():
                                output = model(test_imgs.to(main_device).float())
                                print(output)
                                seq_list = vocab.indices_to_latex(output, vocab)
                                label_list = vocab.indices_to_latex(test_labels.int(), vocab)
                                print(len(seq_list))
                                for i, seq in enumerate(seq_list):
                                    plt.imshow(test_imgs[i, 0])
                                    plt.show()
                                    print("Prediction:")
                                    print(seq)
                                    print("Real:")
                                    print(label_list[i])
                                    print()
                                    print()
                                break
            
        animator.append([
            ('ctc_loss', metrics['sum_ctc_loss'] / metrics['n_tokens']),
        ])

        metrics.clear()
        torch.save(model.state_dict(), f'../saved/main2/CRNN_params_epoch_{epoch}.pt')
            
    torch.save(model.state_dict(), f'../saved/main2/CRNN_params_epoch_{num_epochs}.pt')

In [None]:
train_loop(model, train_dataloaders, vocab, num_epochs=5, lr=3e-4, test_dataloaders=test_dataloaders, main_device='cuda:0')

# **EVALUATE MODEL** #

In [None]:
import nltk

In [None]:
sum_bleu = 0
sum_edit = 0
n_examples = 0
for test_loader in test_dataloaders:
    for test_imgs, test_labels in test_loader:
        model.eval()
        with torch.no_grad():
            output = model(test_imgs.to('cuda:0').float())
            seq_list = indices_to_latex(output, vocab)
            label_list = indices_to_latex(test_labels.int(), vocab, using_ctc_loss=False)
            for i, seq in enumerate(seq_list):
                plt.imshow(test_imgs[i, 0])
                plt.show()
                bleu_score = nltk.translate.bleu_score.sentence_bleu([label_list[i]], seq)
                edit_dist = nltk.edit_distance(label_list[i], seq)
                sum_bleu += bleu_score
                sum_edit += edit_dist
                n_examples += 1
                print("Prediction:")
                print(seq)
                print("Real:")
                print(label_list[i])
                print("BLEU:")
                print(bleu_score)
                print("Edit:")
                print(edit_dist)
                print()

clear_output(wait=True)
print(f"Avg. BLEU-4: {sum_bleu / n_examples}")
print(f"Avg. edit distance: {sum_edit / n_examples}")