In [1]:
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from dataloader import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu
import torch.nn.functional as F
from tqdm import tqdm
from rouge_score import rouge_scorer
from rouge import Rouge

In [2]:
# Parameters
# data_folder = '/media/ssd/caption data'  # folder with data files saved by create_input_files.py
# data_name = 'coco_5_cap_per_img_5_min_word_freq'  # base name shared by data files
df_path='/home/ss4yd/vision_transformer/captioning_vision_transformer/prepared_prelim_data_tokenized_cls256.pickle'
data_name = 'hipt_captioning_task'  # base name shared by data files
word_map=read_obj('./word_map_cls256.pickle')

checkpoint = './BEST_checkpoint_hipt_captioning_task.pth.tar'  # model checkpoint
# word_map_file = '/media/ssd/caption data/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json'  # word map, ensure it's the same the data was encoded with and the model was trained with
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Load model
checkpoint = torch.load(checkpoint)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()

# Load word map (word2ix)
# with open(word_map_file, 'r') as j:
#     word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}
vocab_size = len(word_map)

In [3]:
beam_size=1
loader = torch.utils.data.DataLoader(
            PreLoadedReps_v2(df_path,'test'),
            batch_size=1, shuffle=True, num_workers=1, pin_memory=True)

# TODO: Batched Beam Search
# Therefore, do not use a batch_size greater than 1 - IMPORTANT!

# Lists to store references (true captions), and hypothesis (prediction) for each image
# If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
# references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
references = list()
hypotheses = list()

with torch.no_grad():
    # For each image
    for i, (image, caps, caplens, allcaps) in enumerate(
            tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))):

        k = beam_size

        # Move to GPU device, if available
        image = image.to(device).squeeze(0)  # (1, 3, 256, 256)
#         print(image.shape)
#         print([rev_word_map[i] for i in caps[0].numpy()])
        # Encode
        encoder_out,_ = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
        enc_image_size = encoder_out.size(1)
        encoder_dim = encoder_out.size(-1)

        # Flatten encoding
        encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # We'll treat the problem as having a batch size of k
        encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

        # Tensor to store top k previous words at each step; now they're just <start>
        k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)

        # Tensor to store top k sequences; now they're just <start>
        seqs = k_prev_words  # (k, 1)

        # Tensor to store top k sequences' scores; now they're just 0
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

        # Lists to store completed sequences and scores
        complete_seqs = list()
        complete_seqs_scores = list()

        # Start decoding
        step = 1
        h, c = decoder.init_hidden_state(encoder_out)

        # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
        while True:
            embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

            awe, _ = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)

            gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
            awe = gate * awe

            h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

            scores = decoder.fc(h)  # (s, vocab_size)
            scores = F.log_softmax(scores, dim=1)

            # Add
            scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

            # For the first step, all k points will have the same scores (since same k previous words, h, c)
            if step == 1:
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
                
            else:
                # Unroll and find top scores, and their unrolled indices
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)
            
            # Convert unrolled indices to actual indices of scores
            prev_word_inds = torch.div(top_k_words, vocab_size, rounding_mode='floor')  # (s)
            next_word_inds = top_k_words % vocab_size  # (s)
            
#             print(prev_word_inds,next_word_inds)

            # Add new words to sequences
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

            # Which sequences are incomplete (didn't reach <end>)?
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != word_map['<end>']]
#             print(incomplete_inds)
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            # Set aside complete sequences
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds])
            k -= len(complete_inds)  # reduce beam length accordingly

            # Proceed with incomplete sequences
            if k == 0:
                break
            seqs = seqs[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            # Break if things have been going on too long
            if step > 200:
                break
            step += 1

        i = complete_seqs_scores.index(max(complete_seqs_scores))
        seq = complete_seqs[i]

        # References
        img_caps = allcaps[0].tolist()
        img_captions = list(
            map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}],
                img_caps))  # remove <start> and pads
        references.append(img_captions)

        # Hypotheses
        hypotheses.append([w for w in seq if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}])

        assert len(references) == len(hypotheses)

# Calculate BLEU-4 scores
bleu4 = corpus_bleu(references, hypotheses)
bleu4

EVALUATING AT BEAM SIZE 1: 100%|██████████| 5/5 [00:01<00:00,  3.74it/s]


0.11288196593337896

In [4]:
# if __name__ == '__main__':
#     beam_size = 1
#     print("\nBLEU-4 score @ beam size of %d is %.4f." % (beam_size, evaluate(beam_size)))

In [18]:
encoder(image)[0].shape

torch.Size([1, 192])

In [20]:
torch.argmax(encoder(image)[1])



tensor(16, device='cuda:0')

In [23]:
[rev_word_map[i] for i in seqs[0].cpu().numpy()]

['<start>']

In [7]:
# [rev_word_map[i] for i in pd.read_pickle('./prepared_prelim_data_tokenized.pickle')['idx_tokens'][1]]

In [24]:
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
scores = scorer.score('The quick brown fox jumps over the lazy dog',
                      'The quick brown dog jumps on the log.')

In [25]:
def return_string_list(sll):
    if type(sll[0][0])==list:
        references_str=[x[0] for x in sll]
    else:
        references_str=sll
    references_str=[[rev_word_map[x] for x in sl] for sl in references_str]
    references_str=[' '.join(x) for x in references_str]
    return references_str

hypotheses_str=return_string_list(hypotheses)
references_str=return_string_list(references)

In [26]:
hypotheses_str[0], references_str[0]

('2 pieces , <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>',
 '6 pieces , mild atherosclerosis , some attached <unk>')

In [11]:
# for i in range(len(references_str)):
#     print(scorer.score(references_str[i], hypotheses_str[i]))

In [27]:
rouge = Rouge(metrics=['rouge-l'])
scores = rouge.get_scores(hypotheses_str, references_str)

In [28]:
np.mean([x['rouge-l']['r'] for x in scores])

0.5235714285714286

In [14]:
image.shape

torch.Size([30, 256, 384])

In [16]:
k = 3
vocab_size = len(word_map)

# Read image and process
# img = imread(image_path)
# if len(img.shape) == 2:
#     img = img[:, :, np.newaxis]
#     img = np.concatenate([img, img, img], axis=2)
# img = imresize(img, (256, 256))
# img = img.transpose(2, 0, 1)
# img = img / 255.
# img = torch.FloatTensor(img).to(device)
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                  std=[0.229, 0.224, 0.225])
# transform = transforms.Compose([normalize])
# image = transform(img)  # (3, 256, 256)

# Encode
# image = image.unsqueeze(0)  # (1, 3, 256, 256)
encoder_out,_ = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
enc_image_size = encoder_out.size(1)
encoder_dim = encoder_out.size(-1)

# Flatten encoding
encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
num_pixels = encoder_out.size(1)

# We'll treat the problem as having a batch size of k
encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

# Tensor to store top k previous words at each step; now they're just <start>
k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)

# Tensor to store top k sequences; now they're just <start>
seqs = k_prev_words  # (k, 1)

# Tensor to store top k sequences' scores; now they're just 0
top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

# Tensor to store top k sequences' alphas; now they're just 1s
seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device)  # (k, 1, enc_image_size, enc_image_size)

# Lists to store completed sequences, their alphas and scores
complete_seqs = list()
complete_seqs_alpha = list()
complete_seqs_scores = list()

# Start decoding
step = 1
h, c = decoder.init_hidden_state(encoder_out)

# s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
while True:

    embeddings = decoder.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)

    awe, alpha = decoder.attention(encoder_out, h)  # (s, encoder_dim), (s, num_pixels)

    alpha = alpha.view(-1, enc_image_size, enc_image_size)  # (s, enc_image_size, enc_image_size)

    gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
    awe = gate * awe

    h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

    scores = decoder.fc(h)  # (s, vocab_size)
    scores = F.log_softmax(scores, dim=1)

    # Add
    scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

    # For the first step, all k points will have the same scores (since same k previous words, h, c)
    if step == 1:
        top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
    else:
        # Unroll and find top scores, and their unrolled indices
        top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

    # Convert unrolled indices to actual indices of scores
    prev_word_inds = torch.div(top_k_words, vocab_size, rounding_mode='floor')  # (s)
    next_word_inds = top_k_words % vocab_size  # (s)

    # Add new words to sequences, alphas
    seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)
    seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
                           dim=1)  # (s, step+1, enc_image_size, enc_image_size)

    # Which sequences are incomplete (didn't reach <end>)?
    incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                       next_word != word_map['<end>']]
    complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

    # Set aside complete sequences
    if len(complete_inds) > 0:
        complete_seqs.extend(seqs[complete_inds].tolist())
        complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
        complete_seqs_scores.extend(top_k_scores[complete_inds])
    k -= len(complete_inds)  # reduce beam length accordingly

    # Proceed with incomplete sequences
    if k == 0:
        break
    seqs = seqs[incomplete_inds]
    seqs_alpha = seqs_alpha[incomplete_inds]
    h = h[prev_word_inds[incomplete_inds]]
    c = c[prev_word_inds[incomplete_inds]]
    encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
    top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
    k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

    # Break if things have been going on too long
    if step > 50:
        break
    step += 1

i = complete_seqs_scores.index(max(complete_seqs_scores))
seq = complete_seqs[i]
alphas = complete_seqs_alpha[i]

seq, alphas

RuntimeError: shape '[-1, 192, 192]' is invalid for input of size 3

In [None]:
[rev_word_map[i] for i in seq]

In [None]:
len(seq)

In [None]:
len(alphas)

In [None]:
complete_seqs

In [None]:
[rev_word_map[i] for i in complete_seqs[2]]

In [None]:
caps[0].numpy()

In [None]:
[rev_word_map[i] for i in caps[0].numpy()]