# S-Bert caption representation and Pairwise Cosine Similarity

In [1]:
import json
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm
import torch

In [2]:
bert = SentenceTransformer('bert-large-nli-stsb-mean-tokens')

## COCO

In [3]:
f1 = '/data/project/rw/woong.ssang/graph_matching/txt2img_new/data/coco/annotations/captions_train2014.json'
cocotr = json.load(open(f1, 'r'))
f2 = '/data/project/rw/woong.ssang/graph_matching/txt2img_new/data/coco/annotations/captions_val2014.json'
cocoval = json.load(open(f2, 'r'))

In [4]:
l_sent_id = [t['id'] for t in cocotr['annotations']] + [t['id'] for t in cocoval['annotations']]

In [34]:
l_img_id = [t['image_id'] for t in cocotr['annotations']] + [t['image_id'] for t in cocoval['annotations']]


In [6]:
d_img_id2idx = {img_id: idx for idx, img_id in enumerate(sorted(list(set(l_img_id))))}

In [5]:
l_sent = [t['caption'] for t in cocotr['annotations']] + [t['caption'] for t in cocoval['annotations']]

In [7]:
l_sent[:3]

['A very clean and well decorated empty bathroom',
 'A panoramic view of a kitchen and all of its appliances.',
 'A blue and white bathroom with butterfly themed wall tiles.']

In [8]:
l_emb = bert.encode(l_sent, batch_size=128, show_progress_bar=True, )

Batches: 100%|██████████| 4819/4819 [14:57<00:00,  5.37it/s]


In [12]:
l_emb = np.array(l_emb)

In [17]:
l_norm_emb = l_emb / np.sqrt((l_emb ** 2).sum(axis=1, keepdims=True))

In [14]:
np.save('coco_sbert_emb.npy', {'image_id':np.array(l_img_id), 
                               'sent_id': np.array(l_sent_id),
                               'normed_sbert_emb': l_norm_emb})

In [2]:
d = np.load('coco_sbert_emb.npy', allow_pickle=True)

In [3]:
d = d.tolist()
l_img_id = d['imgage_id']
l_sent_id = d['sent_id']
l_norm_emb = d['normed_sbert_emb']

In [4]:
B = np.zeros((123287, 5, 1024), dtype='float32')
cap_count = np.zeros((123287,), dtype='int')

In [7]:
for i, (sent_id, img_id, emb) in enumerate(zip(l_sent_id, l_img_id, l_norm_emb)):
    img_idx = d_img_id2idx[img_id]
    sent_idx = cap_count[img_idx]
    if sent_idx >= 5:
        # skip more captions
#         print(f'{img_id}')
        continue
    B[img_idx, sent_idx, :] = emb
    cap_count[img_idx] += 1

In [8]:
n_img = len(B)
n_batch = 20
batch_size = int(n_img / n_batch)
pbar = tqdm(total=n_batch ** 2)
l_row = []
for i in range(n_batch):
    i_s = i * batch_size
    if i == n_batch - 1:
        i_e = n_img
    else:
        i_e = (i + 1) * batch_size
        
    B_i = B[i_s:i_e]
    B_i = torch.from_numpy(B_i).cuda()
    
    l_B = []
    for j in range(n_batch):
#         print(j)
        j_s = j * batch_size
        if j == n_batch - 1:
            j_e = n_img
        else:
            j_e = (j + 1) * batch_size
            
        B_j = B[j_s:j_e]
        B_j = torch.from_numpy(B_j).cuda()
        
        BB = torch.einsum('ijk,lmk->iljm', B_i, B_j)
        BB = BB.reshape((len(B_i), len(B_j), -1)).mean(axis=-1)
        
        l_B.append(BB.cpu())
        
        pbar.update(1)
        
    row = np.hstack(l_B)
    l_row.append(row)

M = np.vstack(l_row)

100%|██████████| 400/400 [03:04<00:00,  3.17it/s]

In [9]:
np.save('coco_sbert_mean.npy', M)

In [13]:
M.shape

(123287, 123287)

In [15]:
np.save('coco_sbert_img_id.npy', sorted(list(set(l_img_id))))

------

## Flickr 30K

In [3]:
f = '/data/project/rw/woong.ssang/graph_matching/txt2img_new/data/f30k/dataset_flickr30k.json'
f30k = json.load(open(f, 'r'))

In [6]:
f30k['images'][0]['sentences'][0]['raw']

'Two young guys with shaggy hair look at their hands while hanging out in the yard.'

In [8]:
f30k['images'][0]

{'sentids': [0, 1, 2, 3, 4],
 'imgid': 0,
 'sentences': [{'tokens': ['two',
    'young',
    'guys',
    'with',
    'shaggy',
    'hair',
    'look',
    'at',
    'their',
    'hands',
    'while',
    'hanging',
    'out',
    'in',
    'the',
    'yard'],
   'raw': 'Two young guys with shaggy hair look at their hands while hanging out in the yard.',
   'imgid': 0,
   'sentid': 0},
  {'tokens': ['two',
    'young',
    'white',
    'males',
    'are',
    'outside',
    'near',
    'many',
    'bushes'],
   'raw': 'Two young, White males are outside near many bushes.',
   'imgid': 0,
   'sentid': 1},
  {'tokens': ['two',
    'men',
    'in',
    'green',
    'shirts',
    'are',
    'standing',
    'in',
    'a',
    'yard'],
   'raw': 'Two men in green shirts are standing in a yard.',
   'imgid': 0,
   'sentid': 2},
  {'tokens': ['a',
    'man',
    'in',
    'a',
    'blue',
    'shirt',
    'standing',
    'in',
    'a',
    'garden'],
   'raw': 'A man in a blue shirt standing in

In [25]:
l_sent_id = [s['sentid'] for t in f30k['images'] for s in t['sentences']] 

In [27]:
l_img_id = [s['imgid'] for t in f30k['images'] for s in t['sentences']]

In [28]:
d_img_id2idx = {img_id: idx for idx, img_id in enumerate(sorted(list(set(l_img_id))))}

In [29]:
l_sent = [s['raw'] for t in f30k['images'] for s in t['sentences']]

In [30]:
l_sent[:3]

['Two young guys with shaggy hair look at their hands while hanging out in the yard.',
 'Two young, White males are outside near many bushes.',
 'Two men in green shirts are standing in a yard.']

In [33]:
l_emb = bert.encode(l_sent, batch_size=128, show_progress_bar=True, )


Batches:   0%|          | 0/1212 [00:00<?, ?it/s][A
Batches:   0%|          | 1/1212 [00:00<02:03,  9.83it/s][A
Batches:   0%|          | 2/1212 [00:00<02:04,  9.75it/s][A
Batches:   0%|          | 3/1212 [00:00<02:05,  9.61it/s][A
Batches:   0%|          | 4/1212 [00:00<02:05,  9.64it/s][A
Batches:   0%|          | 5/1212 [00:00<02:04,  9.66it/s][A
Batches:   0%|          | 6/1212 [00:00<02:09,  9.28it/s][A
Batches:   1%|          | 7/1212 [00:00<02:14,  8.98it/s][A
Batches:   1%|          | 8/1212 [00:00<02:15,  8.86it/s][A
Batches:   1%|          | 9/1212 [00:00<02:17,  8.77it/s][A
Batches:   1%|          | 10/1212 [00:01<02:18,  8.70it/s][A
Batches:   1%|          | 11/1212 [00:01<02:18,  8.66it/s][A
Batches:   1%|          | 12/1212 [00:01<02:14,  8.94it/s][A
Batches:   1%|          | 13/1212 [00:01<02:16,  8.78it/s][A
Batches:   1%|          | 14/1212 [00:01<02:17,  8.69it/s][A
Batches:   1%|          | 15/1212 [00:01<02:18,  8.65it/s][A
Batches:   1%|▏         |

In [34]:
l_emb = np.array(l_emb)

In [35]:
l_norm_emb = l_emb / np.sqrt((l_emb ** 2).sum(axis=1, keepdims=True))

In [36]:
np.save('f30k_sbert_emb.npy', {'image_id':np.array(l_img_id), 
                               'sent_id': np.array(l_sent_id),
                               'normed_sbert_emb': l_norm_emb})

In [37]:
len(l_norm_emb)

155070

In [42]:
from collections import Counter
max([j for i,j in Counter(l_img_id).items()])

5

In [46]:
n_img = len(list(set(l_img_id)))

In [2]:
# d = np.load('coco_sbert_emb.npy', allow_pickle=True)

In [3]:
# d = d.tolist()
# l_img_id = d['imgage_id']
# l_sent_id = d['sent_id']
# l_norm_emb = d['normed_sbert_emb']

In [47]:
B = np.zeros((n_img, 5, 1024), dtype='float32')
cap_count = np.zeros((n_img,), dtype='int')

In [48]:
for i, (sent_id, img_id, emb) in enumerate(zip(l_sent_id, l_img_id, l_norm_emb)):
    img_idx = d_img_id2idx[img_id]
    sent_idx = cap_count[img_idx]
    if sent_idx >= 5:
        # skip more captions
#         print(f'{img_id}')
        continue
    B[img_idx, sent_idx, :] = emb
    cap_count[img_idx] += 1

In [49]:
n_img = len(B)
n_batch = 20
batch_size = int(n_img / n_batch)
pbar = tqdm(total=n_batch ** 2)
l_row = []
for i in range(n_batch):
    i_s = i * batch_size
    if i == n_batch - 1:
        i_e = n_img
    else:
        i_e = (i + 1) * batch_size
        
    B_i = B[i_s:i_e]
    B_i = torch.from_numpy(B_i).cuda()
    
    l_B = []
    for j in range(n_batch):
#         print(j)
        j_s = j * batch_size
        if j == n_batch - 1:
            j_e = n_img
        else:
            j_e = (j + 1) * batch_size
            
        B_j = B[j_s:j_e]
        B_j = torch.from_numpy(B_j).cuda()
        
        BB = torch.einsum('ijk,lmk->iljm', B_i, B_j)
        BB = BB.reshape((len(B_i), len(B_j), -1)).mean(axis=-1)
        
        l_B.append(BB.cpu())
        
        pbar.update(1)
        
    row = np.hstack(l_B)
    l_row.append(row)

M = np.vstack(l_row)


  0%|          | 0/400 [00:00<?, ?it/s][A
  0%|          | 1/400 [00:04<27:06,  4.08s/it][A
  2%|▏         | 6/400 [00:04<18:47,  2.86s/it][A
  3%|▎         | 11/400 [00:04<13:01,  2.01s/it][A
  4%|▍         | 16/400 [00:04<09:02,  1.41s/it][A
  5%|▌         | 20/400 [00:04<06:18,  1.00it/s][A
  6%|▌         | 24/400 [00:04<04:34,  1.37it/s][A
  8%|▊         | 30/400 [00:05<03:10,  1.94it/s][A
  9%|▉         | 36/400 [00:05<02:13,  2.72it/s][A
 10%|█         | 41/400 [00:05<01:37,  3.67it/s][A
 12%|█▏        | 47/400 [00:05<01:09,  5.08it/s][A
 13%|█▎        | 53/400 [00:05<00:49,  6.96it/s][A
 15%|█▍        | 59/400 [00:05<00:36,  9.38it/s][A
 16%|█▌        | 64/400 [00:06<00:30, 11.15it/s][A
 18%|█▊        | 70/400 [00:06<00:22, 14.57it/s][A
 19%|█▉        | 75/400 [00:06<00:17, 18.46it/s][A
 20%|██        | 81/400 [00:06<00:16, 19.38it/s][A
 22%|██▏       | 87/400 [00:06<00:13, 23.80it/s][A
 23%|██▎       | 93/400 [00:06<00:10, 28.36it/s][A
 24%|██▍       | 98/40

In [50]:
np.save('f30k_sbert_mean.npy', M)

In [51]:
M.shape

(31014, 31014)

In [52]:
np.save('f30k_sbert_img_id.npy', sorted(list(set(l_img_id))))

# Flickr Gen

In [8]:
import pickle
import numpy as np

In [3]:
s = pickle.load(open('/data/project/rw/CBIR/data/f30k/f30k_gencap_embs.pkl', 'rb'))

In [11]:
l_emb = np.array([cap['sentence_emb'] for cap in s])
l_id = np.array([cap['imgid'] for cap in s])

In [13]:
M = l_emb.dot(l_emb.T)

In [14]:
M.shape

(1000, 1000)

In [16]:
l_id.shape

(1000,)

In [17]:
np.save('f30k_gencap_sbert.npy', M)
np.save('f30k_gencap_sbert_img_id.npy', l_id)

In [4]:
s

[{'filename': '1007129816.jpg',
  'imgid': 25,
  'gen_cap': 'a man with a red hat and a white hat',
  'sentence_emb': array([ 0.11125354, -0.37085184,  0.5451936 , ...,  0.02885841,
          0.32426697, -0.17793894], dtype=float32)},
 {'filename': '1009434119.jpg',
  'imgid': 34,
  'gen_cap': 'the brown and white dog is running through the grass',
  'sentence_emb': array([ 0.10185666,  0.47394028,  0.01623874, ...,  0.78574514,
         -0.72369504,  0.31394646], dtype=float32)},
 {'filename': '101362133.jpg',
  'imgid': 51,
  'gen_cap': 'a woman in a white outfit is performing a martial arts move',
  'sentence_emb': array([-1.2471358 ,  0.22585154, -1.4792902 , ..., -0.6924333 ,
         -0.23400767, -0.33210197], dtype=float32)},
 {'filename': '102617084.jpg',
  'imgid': 89,
  'gen_cap': 'a group of people are standing in the snow',
  'sentence_emb': array([-0.02196577, -0.22466417, -0.5837976 , ..., -0.20548226,
         -0.02253086, -0.7795834 ], dtype=float32)},
 {'filename': '10