In [1]:
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
# from datasets import *
# from utils.utils import *
from dataloader import Flickr8KDataset
from decoder import CaptionDecoder
from utils.decoding_utils import greedy_decoding
from utils.utils import save_checkpoint, log_gradient_norm, set_up_causal_mask
from nltk.translate.bleu_score import corpus_bleu
import torch.nn.functional as F
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from trainer import evaluate
import json
import torchvision.models as models
from decoder import CaptionDecoder

In [3]:
config_path = "config.json"
with open(config_path, "r", encoding="utf8") as f:
    config = json.load(f)
    
use_gpu = config["use_gpu"] and torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

encoder = models.resnet50(pretrained=True)

In [4]:
#######################
# Set up the encoder 
#######################
# Download pretrained CNN encoder
encoder = models.resnet50(pretrained=True)
# Extract only the convolutional backbone of the model
encoder = torch.nn.Sequential(*(list(encoder.children())[:-2]))
encoder = encoder.to(device)
# Freeze encoder layers
for param in encoder.parameters():
    param.requires_grad = False
encoder.eval()

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [5]:
decoder = CaptionDecoder(config)
decoder = decoder.to(device)

checkpoint_path = config["checkpoint"]["checkpoint_path"]
decoder.load_state_dict(torch.load(checkpoint_path))

<All keys matched successfully>

In [6]:
test_set = Flickr8KDataset(config, config["split_save"]["test"], training=False)



In [7]:
print(device)

cpu


In [8]:
valid_bleu = evaluate(test_set, encoder, decoder, config, device)

Evaluating model.


In [9]:
print(valid_bleu)

[24.25508596479211, 12.72923072220504, 5.479570187993903, 1.9803409551077067]
