# __Beam Search Decoding Hyperparameter Grid Search__

### __Deep Learning__

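#### __Project: Image Captioning with Visual Attention__

This notebook loads the best trained captioning checkpoint. It then measures BLEU-4 on the validation set using beam search decoding, over a grid of beam sizes and numbers of returned sequences. The scores are saved to `beam_hyp.json`.

In the runs below, every beam-search configuration scores far below the checkpoint's greedy-decoding BLEU-4 (≈0.044), and beam size 200 does worse than 100. The beam-search decoding and the choice of beam sizes therefore warrant further investigation.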

In [1]:
import os
# Work from the directory named by PYTHONPATH so relative paths used below
# (e.g. "./models/...", "./beam_hyp.json") resolve against it.
# PYTHONPATH may list several directories joined by os.pathsep, which os.chdir
# cannot accept; use the first non-empty entry (presumably the project root).
_pythonpath_entries = [entry for entry in os.environ["PYTHONPATH"].split(os.pathsep) if entry]
if not _pythonpath_entries:
    raise RuntimeError("PYTHONPATH is empty; cannot determine the project root")
os.chdir(_pythonpath_entries[0])

import json

import torch
import torchvision

import scripts.data_loading as dl
import scripts.data_processing as dp
from scripts import model
from scripts.eval import CocoValidator

%load_ext autoreload
%autoreload 2

MODEL_PATH = "./models/best_models/best_model_e256_a256_d512_lr9e-05.pth"
validator = CocoValidator(dl.DATASET_PATHS[dl.DatasetType.VALIDATION], dp.Vocabulary())

loading annotations into memory...
Done (t=0.17s)
creating index...
index created!


In [2]:
# Confirm a GPU is visible before the expensive beam-search runs.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [3]:
# Choose the device first so the checkpoint can be mapped onto it: without
# map_location, tensors saved from a GPU cannot be loaded on a CPU-only machine.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(MODEL_PATH, map_location=device)
print(f"Greedy decoding BLEU-4: {state_dict['bleu_4']}")

# The decoder dimensions must match the checkpoint (its file name encodes
# e256 / a256 / d512). NOTE(review): 10_004 presumably is the vocabulary size
# and 196 the encoder feature size used in training — confirm against the
# training configuration.
decoder_weights = state_dict["decoder"]
decoder = model.LSTMDecoder(num_embeddings=10_004, embedding_dim=256, encoder_dim=196, decoder_dim=512, attention_dim=256)
decoder.load_state_dict(decoder_weights)
decoder.to(device)
# Inference only: disable dropout / batch-norm updates. This is harmless if the
# validator also switches modes.
decoder.eval()

# NOTE(review): the encoder is not restored from the checkpoint here; presumably
# VGG19Encoder loads its own (frozen) pretrained weights — confirm.
encoder = model.VGG19Encoder()
encoder.to(device)
encoder.eval()
print('model loading finished')

Greedy decoding BLEU-4: 0.04379865235910912
model loading finished


In [11]:
# Beam-search settings to evaluate with the encoder/decoder loaded above.
beam_sizes = [100, 200]
num_sequences = [1, 3, 5]
RESULTS_PATH = './beam_hyp.json'

results = []
for beam_size in beam_sizes:
    for num_seq in num_sequences:
        # The returned score is recorded under "bleu_4" and printed as BLEU-4.
        result = validator.validate_beam(encoder, decoder, beam_size, num_seq, device)

        results.append({"beam_size": beam_size, "num_seq": num_seq, "bleu_4": result})
        print(f"bleu-4={result} for beam={beam_size} num_seq={num_seq}")

        # Checkpoint after every configuration. Each validation run is expensive,
        # so a failure in a later run (e.g. running out of GPU memory at a larger
        # beam) must not discard scores already computed. The file is overwritten
        # with the full list so far.
        with open(RESULTS_PATH, 'w') as out_json:
            json.dump(results, out_json, indent=2)

bleu-4=0.0021536660842141737 for beam=100 num_seq=1
bleu-4=0.006660488824502859 for beam=100 num_seq=3
bleu-4=0.009767907412849232 for beam=100 num_seq=5
bleu-4=0.00012228212423081913 for beam=200 num_seq=1
bleu-4=0.0005507570913908208 for beam=200 num_seq=3
bleu-4=0.0012397428113114983 for beam=200 num_seq=5
