In [1]:
# !nvidia-smi

In [2]:
# from google.colab import drive
# drive.mount('/content/gdrive')
# %cd /content/gdrive/My\ Drive/Colab/ITSP

In [13]:
!ls

best_ckpt.pkl	data.py		  itsp	     media	  static
best_ckpt.pt	db.sqlite3	  latex      model	  template
build_vocab.py	evaluation.ipynb  manage.py  __pycache__  utils.py


In [14]:
! pip install distance

Defaulting to user installation because normal site-packages is not writeable


In [15]:
# load checkpoint and evaluating
from os.path import join
from functools import partial
import argparse

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import Im2LatexDataset
from build_vocab import Vocab, load_vocab
from utils import collate_fn
from model import LatexProducer, Im2LatexModel
from model.score import score_files

In [16]:
from argparse import Namespace

args = Namespace(
    model_path = "best_ckpt.pkl",

    # model args
    data_path = "./data/",
    cuda = True,
    batch_size = 8,
    beam_size = 5,
    result_path = "./results/result.txt",
    ref_path = "./results/ref.txt",
    max_len = 64,
    split = "validate"
)

args

Namespace(batch_size=8, beam_size=5, cuda=True, data_path='./data/', max_len=64, model_path='best_ckpt.pkl', ref_path='./results/ref.txt', result_path='./results/result.txt', split='validate')

In [18]:
# Loading Model
checkpoint = torch.load(join(args.model_path))
model_args = checkpoint['args']

# Read the dictionary and set other related parameters
vocab = load_vocab(args.data_path)
use_cuda = True if args.cuda and torch.cuda.is_available() else False

Load vocab including 250 words!


In [20]:
# Load test set
data_loader = DataLoader(
    Im2LatexDataset(args.data_path, args.split, args.max_len),
    batch_size=args.batch_size,
    collate_fn=partial(collate_fn, vocab.sign2id),
    pin_memory=True if use_cuda else False,
    num_workers=4
)

model = Im2LatexModel(
    len(vocab), model_args.emb_dim, model_args.dec_rnn_h,
    add_pos_feat=model_args.add_position_features,
    dropout=model_args.dropout
)
model.load_state_dict(checkpoint['model_state_dict'])

RuntimeError: Error(s) in loading state_dict for Im2LatexModel:
	size mismatch for embedding.weight: copying a param with shape torch.Size([394, 80]) from checkpoint, the shape in current model is torch.Size([250, 80]).
	size mismatch for W_out.weight: copying a param with shape torch.Size([394, 512]) from checkpoint, the shape in current model is torch.Size([250, 512]).

In [None]:
result_file = open(args.result_path, 'w')
ref_file = open(args.ref_path, 'w')

In [None]:
latex_producer = LatexProducer(
    model, vocab, max_len=args.max_len,
    use_cuda=use_cuda, beam_size=args.beam_size)

for imgs, tgt4training, tgt4cal_loss in tqdm(data_loader):
    try:
        reference = latex_producer._idx2formulas(tgt4cal_loss)
        results = latex_producer(imgs)
    except RuntimeError:
        break

    result_file.write('\n'.join(results))
    ref_file.write('\n'.join(reference))

result_file.close()
ref_file.close()
score = score_files(args.result_path, args.ref_path)
print("beam search result:", score)