In [None]:
import os

# Sanity check: PYTHONPATH should contain the parent directory of the
# image_captioning package so that the project imports in the next cell resolve.
# .get() keeps this purely diagnostic cell from aborting Restart & Run All with
# a KeyError when the variable is unset (the package may still be importable,
# e.g. from an installed distribution).
python_path = os.environ.get("PYTHONPATH", "")
print(python_path or "PYTHONPATH is not set")

In [None]:
import json

import pandas as pd
import numpy as np
import editdistance
import matplotlib.pyplot  as plt
from matplotlib.pyplot import figure, imshow, axis
from matplotlib.image import imread

import torch

import image_captioning.constants as C
from image_captioning.caption import caption_image_beam_search, visualize_att

In [None]:
# Load the SVHN evaluation report and derive per-sample error metrics.
# Every column is read as a plain Python object so label strings such as "07"
# keep their leading zeros. `np.object` / `np.float` were removed in NumPy 1.24,
# so the builtin `object` / `float` types are used instead.
report = pd.read_csv(str(C.SVHN_EVAL_PATH), dtype=object)
report["score"] = report["score"].astype(float)
# NOTE(review): assumes `score` is a log-probability of the predicted sequence,
# since it is exponentiated here; confirm against the evaluation script.
report["probability"] = np.exp(report["score"])
# Empty predictions come back from the CSV as NaN; use "" so string ops work.
report["predicted"] = report["predicted"].fillna("")
# Character-level edit distance between prediction and ground truth.
report["edit_distance"] = [
    editdistance.eval(predicted, correct)
    for predicted, correct in zip(report["predicted"], report["correct"])
]
# Edit distance normalised by the length of the ground-truth label.
report["norm_edit_distance"] = report["edit_distance"] / report["correct"].str.len()
# "Delusion": how wrong the prediction is, weighted by the model's confidence.
report["delusion"] = report["edit_distance"] * report["probability"]
report["error"] = report["predicted"] != report["correct"]
# Most confidently-wrong samples first.
report = report.sort_values(by="delusion", ascending=False)
report.head()

In [None]:
# Dataset-wide averages: exact-match error rate plus raw and normalised edit
# distance, displayed as a single-row table.
METRIC_COLUMNS = ["error", "norm_edit_distance", "edit_distance"]
metric_means = report[METRIC_COLUMNS].mean()
metric_means.to_frame().transpose()

In [None]:
# Number of highest-delusion samples to visualise in the next cell.
SHOW = 30


class args:
    # Inference settings, grouped as class attributes and read as `args.<name>`.
    model = "BEST_checkpoint_svhn_1_cap_per_img_5_min_word_freq.pth.tar"  # path relative to the CWD
    word_map = str(C.DIGIT_WORD_MAP_PATH)
    beam_size = 3
    smooth = True  # forwarded to visualize_att


device = "cpu"

# Load model. map_location remaps tensors that were saved on a GPU onto
# `device`; without it, a CUDA-saved checkpoint fails to deserialize on a
# CPU-only machine.
# NOTE(review): the checkpoint appears to hold whole pickled modules (they are
# moved with .to()). torch>=2.6 defaults to weights_only=True and will refuse
# such a file; pass weights_only=False there, and only for trusted checkpoints.
checkpoint = torch.load(args.model, map_location=device)
decoder = checkpoint['decoder'].to(device)
decoder.eval()
encoder = checkpoint['encoder'].to(device)
encoder.eval()

# Load word map (word -> index) and build its inverse (index -> word).
with open(args.word_map, 'r') as word_map_file:
    word_map = json.load(word_map_file)
rev_word_map = {ix: word for word, ix in word_map.items()}

In [None]:
# Re-run beam search on the top SHOW rows of the report (sorted by delusion,
# descending) and plot the attention maps for each predicted token.
# head(SHOW) caps the loop at exactly SHOW samples.
for _, row in report.head(SHOW).iterrows():
    # Encode, decode with attention and beam search
    seq, alphas = caption_image_beam_search(encoder, decoder, row["path"], word_map, args.beam_size)
    alphas = torch.FloatTensor(alphas)
    # Visualize caption and attention of best sequence
    visualize_att(row["path"], seq, alphas, rev_word_map, args.smooth)
    # NOTE(review): if visualize_att shows/finishes its figure internally, this
    # title lands on a new, empty figure; confirm, and move the title into
    # visualize_att if so.
    plt.title("correct: %s, predicted: %s\ndistance: %d, prob: %.3f" % (
        row["correct"], row["predicted"], row["edit_distance"], row["probability"]), fontsize=12)