In [1]:
import pandas as pd
import torch
import torchvision.transforms as transforms
import pickle

from data_utils import get_karpathy_split, refcoco_splits
from data_loader import get_caption_loader, COCOCaptionDataset, get_reg_loader

from build_vocab import Vocabulary

In [2]:
import sys
sys.path.append('/home/simeon/Dokumente/Code/Uni/Repos/Adaptive/nlg-eval')
from nlgeval import NLGEval
nlgeval = NLGEval(no_skipthoughts=True, no_glove=True)  # loads the models

In [3]:
crop_size=224
image_dir='/home/simeon/Dokumente/Code/Data/COCO/'

In [4]:
with open('data/coco_vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

In [4]:
caps_df = get_karpathy_split(splits_path='/home/simeon/Dokumente/Code/Data/COCO/splits/karpathy/caption_datasets/', caps_path='/home/simeon/Dokumente/Code/Data/COCO/')

In [5]:
transform = transforms.Compose([
    transforms.Resize((crop_size, crop_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))
])

In [7]:
c = caps_df.loc[caps_df.split == 'restval'].iloc[:1000]

In [8]:
loader = get_caption_loader(
        decoding_level='word', 
        split=['val'],
        data_df=caps_df, 
        image_dir=image_dir, 
        vocab=vocab,
        transform=transform, 
        batch_size=20, 
        shuffle=False,
        num_workers=2, 
        drop_last=False
)

In [9]:
for i, (images, captions, lengths, _, _) in enumerate(loader):    
    if i > 2:
        break
    print(i)
    
idx = [i.item() for i in captions[2]]
' '.join([vocab.idx2word[i] for i in idx])

0
1
2


'<start> a bicycle is chained up to a pole at a train station <end> <pad> <pad>'

In [30]:
len(loader)

1251

In [35]:
from tqdm.autonotebook import tqdm

hypotheses = []
references = []

for i, (images, _, _, image_ids, _) in enumerate(loader):
    
    if i % 100 == 0:
        print(i)
    
    # Build caption based on Vocabulary and the '<end>' token
    for image_idx in range(images.size()[0]):

        img_id = int(image_ids[image_idx])

        refs = caps_df.loc[caps_df.image_id == img_id].caption.to_list()
        references.append(refs)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200


In [6]:
with open('data/refcoco_vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

In [7]:
ref_df = refcoco_splits('/home/simeon/Dokumente/Code/Data/RefCOCO/refcoco/')[0]

In [8]:
loader = get_reg_loader(
        decoding_level='word', 
        split=['val'],
        data_df=ref_df.groupby('ann_id').agg('first').reset_index(), 
        image_dir=image_dir, 
        vocab=vocab,
        transform=transform, 
        batch_size=20, 
        shuffle=False,
        num_workers=2, 
        drop_last=False
)

In [9]:
len(loader)

191

In [10]:
for i, (images, captions, positions, lengths, ann_ids, filenames) in enumerate(loader):    
    if i > 2:
        break
    print(i)

0
1
2


In [11]:
idx = [i.item() for i in captions[2]]
' '.join([vocab.idx2word[i] for i in idx])

'<start> bird closest to camera <end> <pad> <pad> <pad>'

In [12]:
from adaptive_reg import Encoder2Decoder
import torch.nn.functional as F
import torch.nn as nn

In [13]:
model = Encoder2Decoder(256, len(vocab), 512)

In [14]:
#encoder = AttentiveCNN(256, 512)
# encoder.affine_b = nn.Linear(2048+7, 256)

In [15]:
model.sampler(images, positions)

(tensor([[2465,  493, 2886,  454,  239, 1564,  454,  981, 2209, 2717,  565, 1634,
          1833, 1409,   18, 2105,  507,  410, 2465, 2465],
         [1835, 1095, 2657, 2394, 2062,  405, 2510,  341,  729,  931, 1095, 1634,
          1521, 2465, 2465, 1255, 2465,  843, 2165, 2165],
         [2209,  701, 1634, 2465,  639, 2324, 2460, 2465, 1897, 2465, 2653,  870,
           707, 2209,  933, 1392, 1598, 1095, 2394, 1203],
         [2465, 2465, 2796, 2209, 2046, 2445,  275, 1952, 2465,  791, 1952,  289,
           701, 2465,  454,  512,  512, 2417,  830, 2465],
         [2465, 1634, 2165, 2031, 2698, 2165, 2165,  788, 2465, 2465, 1049, 1897,
           922, 2165, 2465, 2465, 1049, 1416,  701, 2962],
         [1672, 1318,  311, 2465, 2209, 1891, 2465,  757, 1897, 1049, 1049, 1049,
          1634, 2465, 2465, 1318,  488,  208,  134, 1049],
         [1255, 1952, 2465,  512, 1255, 2110, 1489, 2465,  493,  454,  843,  729,
          2153,  841,  454,  454, 2465, 2465, 2465, 2796],
         [246