In [None]:
import ast
import bisect
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import clear_output

import dataset as ds
import preprocessing as pre
import utils
from model import CRNN, init_weights
from vocab import Vocab, label_to_index, indices_to_latex

# **GET TRAIN/VAL/TEST SPLIT** #

In [None]:
# Images live under data_root/<data folder>/<image folder>.
data_root = './data'
data_folders = ['IM2LATEX-100K']  # optionally also 'IM2LATEX-100K-HANDWRITTEN'
img_folders = ['formula_images']  # optionally also 'images'

# Every data folder needs a matching image folder.
assert len(data_folders) == len(img_folders)

# Build one DataFrame per dataset and tag each row with the dataset it came from.
# (Change ftype in the preprocessing call if a different image type is used.)
df_list = []
for folder in data_folders:
    dataset_dir = f'{data_root}/{folder}'
    print(dataset_dir)
    df = pre.data_to_df(dataset_dir, max_entries=None)
    df['dataset'] = folder
    df_list.append(df)

# **APPEND LABELS**

In [None]:
# One label file per entry in data_folders, in the same order.
label_folders = ['im2latex_formulas.lst']  # optionally also 'formulas.lst'

for position, df in enumerate(df_list):
    label_path = f'{data_root}/{data_folders[position]}/{label_folders[position]}'
    labels = pre.extract_labels(label_path)
    # Look up each row's formula through its 'index' column.
    df['label'] = labels.loc[df['index']].values
    # Placeholders that are filled in once labels are tokenized and images are
    # reshaped (padded_* hold the image shape after processing).
    for placeholder in ('label_token_indices', 'padded_height', 'padded_width'):
        df[placeholder] = np.nan

# **MERGE AND SHUFFLE DATAFRAMES** #

In [None]:
# Fixed seed so that the shuffled row order is reproducible across kernel restarts.
SHUFFLE_SEED = 0

# Concatenate all datasets, shuffle the rows, and renumber them from 0.
# The object dtype lets later cells store lists/arrays inside single cells.
merge_shuffle_df = (
    pd.concat(df_list)
    .sample(frac=1, random_state=SHUFFLE_SEED)
    .reset_index(drop=True)
    .astype('object')
)
merge_shuffle_df.to_csv('../saved/merge_shuffle_df.csv')

## *If loading from saved df* ##

In [None]:
# Reload the shuffled frame saved above; object dtype so cells can later hold lists/arrays.
merge_shuffle_df = pd.read_csv('../saved/merge_shuffle_df.csv')
merge_shuffle_df = merge_shuffle_df.astype('object')

# **TOKENIZE LATEX, CREATE FULL DATAFRAME** #

In [None]:
# Candidate (height, width) shapes that processed images are padded to.
image_sizes = (
    [(40, width) for width in (160, 200, 240, 280, 320)]
    + [(50, width) for width in (120, 200, 240, 280, 320)]
)

In [None]:
import cv2

In [None]:
def _closest_aspect_ratio(aspect_ratios, ratio):
    """Return the position in sorted `aspect_ratios` whose ratio is closest to `ratio`.

    `aspect_ratios` holds (ratio, index) tuples sorted ascending. Closeness is
    measured multiplicatively (ratio of ratios), ties going to the smaller one.
    """
    # The second tuple element only breaks ties for bisect; -1 sorts before any index.
    pos = bisect.bisect_left(aspect_ratios, (ratio, -1))
    if pos == 0:
        return 0
    if pos == len(aspect_ratios):
        return len(aspect_ratios) - 1
    lower = aspect_ratios[pos - 1][0]
    upper = aspect_ratios[pos][0]
    return pos - 1 if ratio / lower <= upper / ratio else pos


def _select_pad_target(img_shape, image_sizes):
    """Pick the (height, width) from `image_sizes` that the image is padded to.

    A candidate is accepted when the image fits inside it and it needs no more
    padding in either dimension than the candidate accepted so far.
    Returns None when the image fits in none of the sizes.
    """
    pad_dh, pad_dw = np.inf, np.inf
    target = None
    for height, width in image_sizes:
        dh = height - img_shape[0]
        dw = width - img_shape[1]
        if 0 <= dh <= pad_dh and 0 <= dw <= pad_dw:
            pad_dh, pad_dw = dh, dw
            target = (height, width)
    return target


def reshape_images(img_path, aspect_ratios=None, image_sizes=None, reshape_strat='pad'):
    """Crop, reshape and invert a single equation image.

    Args:
        img_path: path of the image; it is cropped with `crop_equations`.
        aspect_ratios: sorted list of (width / height, index into image_sizes)
            tuples. Only used (and required) when reshape_strat == 'scale'.
        image_sizes: list of (height, width) target shapes. When None the
            cropped image is only inverted, with no padding or resizing.
        reshape_strat: 'scale' resizes the image to the target with the closest
            aspect ratio; 'pad' halves the image (capped at 320 wide, 50 high)
            and pads it with white to a fitting target. Any other value only
            applies the 8 px border.

    Returns:
        (img, pad_height, pad_width): the inverted image with a new leading
        axis, and the chosen target shape (both None if no target was chosen).

    Raises:
        ValueError: if 'scale' is requested without aspect ratios, or if the
            image does not fit any entry of `image_sizes` under 'pad'.
    """
    _, img = crop_equations(img_path)

    pad_height, pad_width = None, None
    if image_sizes is not None:
        # White 8 px border on every side.
        img = np.pad(img, pad_width=8, mode='constant', constant_values=255)

        if reshape_strat == 'scale':
            if not aspect_ratios:
                raise ValueError("reshape_strat='scale' requires non-empty aspect_ratios")
            curr_ratio = img.shape[1] / img.shape[0]
            pos = _closest_aspect_ratio(aspect_ratios, curr_ratio)
            # aspect_ratios stores (ratio, index); the real shape lives in image_sizes.
            pad_height, pad_width = image_sizes[aspect_ratios[pos][1]]
            img = cv2.resize(
                img,
                dsize=(pad_width, pad_height),  # cv2 expects (width, height)
                interpolation=cv2.INTER_AREA
            )
        elif reshape_strat == 'pad':
            # Downsample by 2, capped at 320 x 50 (width x height).
            img = cv2.resize(
                img,
                dsize=(
                    min(img.shape[1] // 2, 320),
                    min(img.shape[0] // 2, 50)
                ),
                interpolation=cv2.INTER_AREA
            )
            target = _select_pad_target(img.shape, image_sizes)
            if target is None:
                raise ValueError(
                    f'image of shape {img.shape} from {img_path} does not fit any of image_sizes'
                )
            pad_height, pad_width = target
            pad_dh = pad_height - img.shape[0]
            pad_dw = pad_width - img.shape[1]
            # With odd padding, top gets less than bottom and left less than right.
            top = pad_dh // 2
            left = pad_dw // 2
            img = np.pad(
                img,
                pad_width=((top, pad_dh - top), (left, pad_dw - left)),
                mode='constant',
                constant_values=255
            )

    # Invert so the background becomes 0, then add a leading axis.
    img = np.abs(255 - img)
    img = np.expand_dims(img, axis=(0))

    return img, pad_height, pad_width


def scale_images(merge_shuffle_df, maxlen=None, image_sizes=None, 
                 reshape_strat='pad', model=None):
    """Reshape every image and turn every label into token indices.

    Args:
        merge_shuffle_df: DataFrame with 'label', 'image_path', 'index' and
            'dataset' columns. It is modified in place: 'label' is replaced by
            the tokenized label, and 'padded_height', 'padded_width' and
            'label_token_indices' are filled in.
        maxlen: value passed to `label_to_index` as `maxlen` for every row.
            When None (and image_sizes is given) it is looked up per row from
            `model.get_seq_len(height, width)` for that row's padded shape.
        image_sizes: list of (height, width) tuples to try to fit images to
            (if batching), or None to leave images unshaped.
            Harvard's image sizes:
              (40, 160), (40, 200), (40, 240), (40, 280), (40, 320)
              (50, 120), (50, 200), (50, 240), (50, 280)
        reshape_strat: 'scale' or 'pad'; any other value applies no reshaping
            beyond what `reshape_images` does unconditionally.
        model: only needed when maxlen is None and image_sizes is given.

    Returns:
        (merge_shuffle_df, images, v): the same DataFrame, a dict mapping
        (index, dataset) to the processed image, and the Vocab built from all
        tokenized labels.
    """
    v = Vocab()
    images = {}
    aspect_ratios = []
    seq_len_dict = {}
    if image_sizes is not None:
        for i, (height, width) in enumerate(image_sizes):
            aspect_ratios.append((width / height, i))
            if maxlen is None:
                assert model is not None
                seq_len_dict[(height, width)] = model.get_seq_len(height, width)
        aspect_ratios.sort()

    n_rows = len(merge_shuffle_df)

    # First pass: tokenize, grow the vocabulary, and reshape the images.
    for i, row in merge_shuffle_df.iterrows():
        label = pre.tokenize(row['label'])
        v.update(label)

        img, pad_height, pad_width = reshape_images(
            row['image_path'], aspect_ratios, image_sizes, reshape_strat
        )

        merge_shuffle_df.at[i, 'label'] = label
        merge_shuffle_df.at[i, 'padded_height'] = pad_height
        merge_shuffle_df.at[i, 'padded_width'] = pad_width
        images[(row['index'], row['dataset'])] = img

        print(f'{i+1}/{n_rows}')
        clear_output(wait=True)

    # Second pass: index the labels, which runs only after the vocabulary has
    # seen every label.
    for i, row in merge_shuffle_df.iterrows():
        # Resolve the length per row; assigning back to `maxlen` here would
        # freeze the first row's sequence length for every later row.
        if maxlen is not None:
            row_maxlen = maxlen
        elif image_sizes is None:
            row_maxlen = None
        else:
            row_maxlen = seq_len_dict[(row['padded_height'], row['padded_width'])]

        merge_shuffle_df.at[i, 'label_token_indices'] = np.array(
            label_to_index(row['label'], v, maxlen=row_maxlen)
        )

        print(f'{i+1}/{n_rows}')
        clear_output(wait=True)

    return merge_shuffle_df, images, v


def crop_equations(path, show_image=False):
    """Crop a grayscale image to the bounding box of its non-white pixels.

    Args:
        path: path of the image file, read as grayscale.
        show_image: if True, display the image as read (before cropping).

    Returns:
        (path, cropped): the input path and the cropped image. An image with
        no non-white pixel is returned uncropped.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img = cv2.imread(path, 0)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f'Could not read image: {path}')

    # True for every column / row containing at least one non-white pixel.
    LR_mask = cv2.reduce(img, 0, cv2.REDUCE_MIN) != 255
    UD_mask = cv2.reduce(img, 1, cv2.REDUCE_MIN) != 255
    # First marked position, and one past the last marked position.
    c1, c2 = LR_mask.argmax(), LR_mask.shape[1] - np.flip(LR_mask).argmax()
    r1, r2 = UD_mask.argmax(), UD_mask.shape[0] - np.flip(UD_mask).argmax()

    if show_image:
        plt.imshow(img, cmap='gray')
        plt.show()

    return path, img[r1:r2, c1:c2]

In [None]:
# Tokenize all labels and pad all images to one of image_sizes.
# Labels are indexed with a fixed maxlen of 200, so no model is needed here.
merge_shuffle_df, images, vocab = scale_images(
    merge_shuffle_df, maxlen=200, reshape_strat='pad', image_sizes=image_sizes
)

In [None]:
# Persist the processed frame, images and vocabulary so the slow preprocessing
# can be skipped next time.
merge_shuffle_df.to_csv('../saved/merge_shuffle_df_processed.csv')

# `with` closes each file even if pickling fails part-way.
with open('../saved/images', 'wb') as image_handler:
    pickle.dump(images, image_handler)
with open('../saved/vocab', 'wb') as vocab_handler:
    pickle.dump(vocab, vocab_handler)

## *Load items in if everything is saved* ##

In [None]:
merge_shuffle_df = pd.read_csv('../saved/merge_shuffle_df_processed.csv')

# NOTE: pickle.load can execute arbitrary code; only load files written by this notebook.
# `with` makes sure the file handles are closed after loading.
with open('../saved/images', 'rb') as image_handler:
    images = pickle.load(image_handler)
with open('../saved/vocab', 'rb') as vocab_handler:
    vocab = pickle.load(vocab_handler)

# object dtype so that cells can hold lists/arrays once the strings are parsed.
merge_shuffle_df = merge_shuffle_df.astype('object')

In [None]:
# The CSV round-trip turns list/array cells into their string form; parse them back.
for idx in merge_shuffle_df.index:
    clear_output(wait=True)
    print(idx)

    raw_label = merge_shuffle_df.at[idx, 'label']
    # Skip cells that are already parsed so the cell can be re-run safely.
    if isinstance(raw_label, str):
        # literal_eval undoes the list repr exactly, including escaped
        # backslashes and tokens that themselves contain ', ' or quotes.
        merge_shuffle_df.at[idx, 'label'] = ast.literal_eval(raw_label)

    raw_indices = merge_shuffle_df.at[idx, 'label_token_indices']
    if isinstance(raw_indices, str):
        # The array string is whitespace-separated numbers inside brackets.
        merge_shuffle_df.at[idx, 'label_token_indices'] \
            = np.array(raw_indices[1:-1].split(), dtype=float)

# **CREATE DATALOADERS** #

In [None]:
def _loaders_per_size(split):
    """Build one dataloader per entry of image_sizes for the given split."""
    loaders = []
    for height, width in image_sizes:
        # Only rows whose processed image has exactly this padded shape.
        same_shape = (
            (merge_shuffle_df['padded_height'] == height)
            & (merge_shuffle_df['padded_width'] == width)
        )
        loaders.append(
            ds.gen_dataloader(
                merge_shuffle_df=merge_shuffle_df,
                images=images,
                split=split,
                extra_cond=same_shape,
                batch_size=20,
                shuffle=True,
                merge_train_validate=True,
            )
        )
    return loaders


train_dataloaders = _loaders_per_size('train')

test_dataloaders = _loaders_per_size('test')

# **CONSTRUCT MODEL** #

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

In [None]:
# Build the CRNN over the vocabulary (embed_size is presumably the token-embedding width — TODO confirm).
model = CRNN(vocab, embed_size=32)
# Run init_weights over the model (presumably on every submodule, as nn.Module.apply does — TODO confirm).
model.apply(init_weights)
# Move the model to 'cuda:0'; as the last expression of the cell, its return value is displayed.
model.to('cuda:0')

## *If there are trained parameters to load in* ##

In [None]:
# Restore the parameters saved after epoch 5.
checkpoint_path = '../saved/params/CRNN_params_epoch_5.pt'
model.load_state_dict(torch.load(checkpoint_path))

In [None]:
def _show_test_predictions(model, test_dataloaders, vocab, main_device):
    """Decode and display the first batch of a randomly chosen test dataloader."""
    test_loader = test_dataloaders[np.random.randint(0, len(test_dataloaders))]
    for test_imgs, test_labels in test_loader:
        model.eval()
        with torch.no_grad():
            output = model(test_imgs.to(main_device).float())
            seq_list = indices_to_latex(output, vocab)
            # Ground-truth labels are not CTC output, so decode them as plain indices.
            label_list = indices_to_latex(test_labels.int(), vocab, using_ctc_loss=False)
            for example_idx, seq in enumerate(seq_list):
                plt.imshow(test_imgs[example_idx, 0])
                plt.show()
                print("Prediction:")
                print(seq)
                print("Real:")
                print(label_list[example_idx])
                print()
                print()
        # Only the first batch is shown.
        break


def train_loop(model, train_dataloaders, vocab, num_epochs, lr=0.1, test_dataloaders=None, main_device='cpu'):
    """Train `model` with CTC loss and save its parameters after every epoch.

    Args:
        model: called as `model(imgs, teacher_forcing=False)`; its output is
            treated as (seq_len, batch_size, vocab_size).
        train_dataloaders: list of dataloaders yielding (imgs, labels) batches;
            all of them are iterated once per epoch.
        vocab: provides `get_index` for the '<nul>' (CTC blank) and '<pad>' tokens.
        num_epochs: number of passes over all train dataloaders.
        lr: Adam learning rate.
        test_dataloaders: optional list of dataloaders; every 10th batch, one
            batch of a random one is decoded and displayed.
        main_device: device the batches are moved to; the model must already be on it.

    Side effects:
        Writes '../saved/params/CRNN_params_epoch_<n>.pt' after each epoch
        (creating the directory if needed) and enables autograd anomaly detection.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=1e-6)
    blank_idx = vocab.get_index('<nul>')
    pad_idx = vocab.get_index('<pad>')
    ctc_loss = nn.CTCLoss(blank=blank_idx, reduction='none', zero_infinity=True)
    animator = utils.Animator(
        measurement_names=['ctc_loss'],
        refresh=1
    )
    metrics = utils.MetricBuffer(['sum_ctc_loss', 'n_tokens'])

    # Fail before training, not after the first epoch, if the directory is unusable.
    params_dir = '../saved/params'
    os.makedirs(params_dir, exist_ok=True)

    # Global switch, so setting it once is enough.
    torch.autograd.set_detect_anomaly(True)

    for epoch in range(num_epochs):
        n_imgs = 0
        for dataloader in train_dataloaders:
            for batch_idx, (imgs, labels) in enumerate(dataloader):
                model.train()
                optimizer.zero_grad()

                # as_tensor accepts tensors and arrays without the copy warning of torch.tensor.
                imgs = torch.as_tensor(imgs).to(main_device).float()
                labels = labels.to(main_device).long()
                output = model(imgs, teacher_forcing=False)
                # output.shape = (seq_len, batch_size, vocab_size)

                log_softmax_probs = F.log_softmax(output, dim=2)
                seq_len, batch_size = log_softmax_probs.shape[:2]
                # Drop the first label column and keep at most seq_len targets.
                labels = labels[:, 1:seq_len + 1]
                # CTC input lengths are the number of output time steps, not the label width.
                input_lengths = torch.full(
                    (batch_size,), seq_len, dtype=torch.int32, device=output.device
                )
                target_lengths = (labels != pad_idx).sum(dim=1).int().to(output.device)
                main_loss = ctc_loss(
                    log_softmax_probs,
                    labels,
                    input_lengths,
                    target_lengths
                )

                batch_loss = main_loss.sum()
                batch_loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
                optimizer.step()

                metrics.update([('sum_ctc_loss', batch_loss.cpu().detach().numpy())])
                metrics.update([('n_tokens', (labels != pad_idx).sum(dim=[0, 1]).cpu().detach().numpy())])

                n_imgs += imgs.shape[0]

                print(batch_idx)
                if batch_idx % 10 == 0:
                    clear_output(wait=True)
                    print(f"n_imgs #{n_imgs}")
                    print(f"Current loss: {metrics['sum_ctc_loss'] / metrics['n_tokens']}")
                    # Truthiness also skips an empty list, which randint cannot sample from.
                    if test_dataloaders:
                        _show_test_predictions(model, test_dataloaders, vocab, main_device)

        animator.append([
            ('ctc_loss', metrics['sum_ctc_loss'] / metrics['n_tokens']),
        ])

        metrics.clear()
        torch.save(model.state_dict(), f'{params_dir}/CRNN_params_epoch_{epoch+1}.pt')

In [None]:
# Train for 5 epochs on the GPU, previewing predictions from the test loaders along the way.
train_loop(
    model,
    train_dataloaders,
    vocab,
    num_epochs=5,
    lr=3e-4,
    main_device='cuda:0',
    test_dataloaders=test_dataloaders,
)

# **EVALUATE MODEL** #

In [None]:
import nltk

In [None]:
# Accumulate sentence-level BLEU and edit distance over every test example.
sum_bleu = 0
sum_edit = 0
n_examples = 0

# Nothing in the loop switches back to train mode, so set eval mode once.
model.eval()
with torch.no_grad():
    for test_loader in test_dataloaders:
        for test_imgs, test_labels in test_loader:
            output = model(test_imgs.to('cuda:0').float())
            seq_list = indices_to_latex(output, vocab)
            label_list = indices_to_latex(test_labels.int(), vocab, using_ctc_loss=False)
            for i, (seq, label) in enumerate(zip(seq_list, label_list)):
                plt.imshow(test_imgs[i, 0])
                plt.show()
                bleu_score = nltk.translate.bleu_score.sentence_bleu([label], seq)
                edit_dist = nltk.edit_distance(label, seq)
                sum_bleu += bleu_score
                sum_edit += edit_dist
                n_examples += 1
                print("Prediction:")
                print(seq)
                print("Real:")
                print(label)
                print("BLEU:")
                print(bleu_score)
                print("Edit:")
                print(edit_dist)
                print()

clear_output(wait=True)
if n_examples == 0:
    # Avoid a ZeroDivisionError when every test dataloader is empty.
    print("No test examples were evaluated.")
else:
    print(f"Avg. BLEU-4: {sum_bleu / n_examples}")
    print(f"Avg. edit distance: {sum_edit / n_examples}")