In [1]:
import os
os.environ['http_proxy'] = "http://192.41.170.23:3128" 
os.environ['https_proxy'] = "http://192.41.170.23:3128" 

In [2]:
!pip install pycocotools
!pip install ruamel.yaml

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [3]:
import ruamel.yaml as yaml
import numpy as np
import torch
from torch import nn
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, PILToTensor
import cv2
from captum.attr import visualization
from tqdm import tqdm
from transformers import ViTForImageClassification
import timm
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import json
from torch.utils.data import DataLoader
import torch
from torchvision.datasets import CocoCaptions
import pycocotools

from collections import OrderedDict
from datasets import load_dataset
import gc
from typing import Any, Tuple, Callable, Optional, List
from sklearn.metrics import recall_score

from models.model_retrieval_mplug import MPLUG
from models.vit import interpolate_pos_embed, resize_pos_embed
from models.tokenization_bert import BertTokenizer
from torchvision import transforms

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

cuda_card = 1

In [4]:
def normalize_vector(arr):
    norms = torch.linalg.norm(arr, axis=1, keepdims=True)
    return arr / norms

In [5]:
class CocoCustom(CocoCaptions): 
    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, annFile, transform, target_transform, transforms)
        from pycocotools.coco import COCO

        self.annotations = json.load(open(annFile))
        self.num_captions = len(self.annotations['annotations'])

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        id = self.ids[index]
        image = self._load_image(id)
        target = self._load_target(id)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, id, target

    def getAnnotationRange(self, index: int, count: int) -> List[Any]:
        return [self.annotations['annotations'][index]['caption'] for index in range(index, index+count)]

    def getImgIdFromAnnotationIndex(self, annotation_index: int) -> int:
        return self.annotations['annotations'][annotation_index]['image_id']
    
    def buildFaissIndex(self, text_encoder, tokenize, batch_size, nlist) :
        tokenized = tokenize(self.getAnnotationRange(0, batch_size)).cuda(cuda_card)
        encoded_captions = normalize_vector(text_encoder(tokenized, get_all_token=False).detach().cpu().numpy().astype('float32'))
        vector_dimension = encoded_captions.shape[1]
        
        quantizer = faiss.IndexFlatIP(vector_dimension)
        index = faiss.IndexIVFFlat(quantizer, vector_dimension, nlist)
        index.train(encoded_captions)
        index.add(encoded_captions)
        
        for i in tqdm(range(batch_size, self.num_captions - batch_size, batch_size)):
            tokenized = clip.tokenize(self.getAnnotationRange(i, batch_size)).cuda(cuda_card)
            encoded_captions = normalize_vector(model.encode_text(tokenized, get_all_token=False).detach().cpu().numpy().astype('float32'))
            index.add(encoded_captions)

        return index

    def __len__(self) -> int:
        return len(self.ids)


In [6]:
config = yaml.load(open("./configs/retrieval_coco_mplug_large.yaml", 'r'), Loader=yaml.Loader)

In [7]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = MPLUG(config=config, tokenizer=tokenizer)



In [8]:
checkpoint = torch.load('./checkpoint/mplug_large.pth', map_location='cpu') 
state_dict = checkpoint['model']

if config["clip_name"] == "ViT-B-16":
    num_patches = int(config["image_res"] * config["image_res"] / (16 * 16))
elif config["clip_name"] == "ViT-L-14":
    num_patches = int(config["image_res"] * config["image_res"] / (14 * 14))

pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768).float())
pos_embed = resize_pos_embed(state_dict['visual_encoder.visual.positional_embedding'].unsqueeze(0),
                                               pos_embed.unsqueeze(0))
state_dict['visual_encoder.visual.positional_embedding'] = pos_embed

if config['distill']:
    if config["clip_name"] == "ViT-B-16":
        num_patches = int(config["image_res"] * config["image_res"] / (16 * 16))
    elif config["clip_name"] == "ViT-L-14":
        num_patches = int(config["image_res"] * config["image_res"] / (14 * 14))
    pos_embed = nn.Parameter(torch.zeros(num_patches + 1, 768).float())

    pos_embed = resize_pos_embed(state_dict['visual_encoder_m.visual.positional_embedding'].unsqueeze(0),
                                 pos_embed.unsqueeze(0))
    state_dict['visual_encoder_m.visual.positional_embedding'] = pos_embed

for key in list(state_dict.keys()):
    if ('fusion' in key or 'bert' in key) and 'decode' not in key:
        encoder_key = key.replace('fusion.', '').replace('bert.', '')
        state_dict[encoder_key] = state_dict[key]
        del state_dict[key]

msg = model.load_state_dict(state_dict, strict=False)
print('load checkpoint')
# print(msg)
model = model.cuda(cuda_card)

load checkpoint


In [9]:
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
test_transform = transforms.Compose([
        transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC),
        transforms.ToTensor(),
        normalize,
 ])

In [10]:
path = '../../Dataset/CV/mscoco/2017'
cocoCaptions = CocoCustom(root = path + '/val2017',
                        annFile = path + '/annotations/captions_val2017.json',
                        transform=test_transform)

print('Number of samples: ', len(cocoCaptions))
img, img_id, target = cocoCaptions[3]

print("Image Size:", img.size())
print("Captions:", target)
print("Image Id:", img_id)

loading annotations into memory...
Done (t=0.82s)
creating index...
index created!
Number of samples:  5000
Image Size: torch.Size([3, 336, 336])
Captions: ["A stop sign is mounted upside-down on it's post. ", 'A stop sign that is hanging upside down.', 'An upside down stop sign by the road.', 'a stop sign put upside down on a metal pole ', 'A stop sign installed upside down on a street corner']
Image Id: 724


In [11]:
caption_map = {}
captions = []
caption_idx = 0
images_id_list = []
for i in tqdm(range(len(cocoCaptions))):
    img, img_id, target = cocoCaptions[i]
    images_id_list.append(img_id)
    for idx in range(len(target)):
        capt = target[idx]
        caption_map[caption_idx] = img_id
        caption_idx += 1
        captions.append(capt)

 12%|█▏        | 584/5000 [01:52<14:10,  5.19it/s]  


In [None]:
device = torch.device('cuda:' + str(cuda_card))

In [None]:
batch_size = 16
text_feats = []
text_atts = []
text_embeds = []

for i in tqdm(range(0, len(captions), batch_size)):
    text = captions[i: min(len(captions), i+batch_size)]
    text_input = tokenizer(text, padding='max_length', truncation=True, max_length=30, return_tensors="pt").to(device)
    text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask)
    text_feat = text_output.last_hidden_state.detach()
    text_embed = model.text_proj(text_feat[:,0,:]).detach()
    text_feats.append(text_feat)
    text_atts.append(text_input.attention_mask)
    text_embeds.append(normalize_vector(text_embed))
    
text_feats = torch.concat(text_feats)
text_atts = torch.concat(text_atts)
text_embeds = torch.concat(text_embeds)

In [None]:
torch.cuda.empty_cache()

In [None]:
def collate_coco(data):
    images = torch.stack([image for image, image_id, text in data])
    images_id = torch.tensor([image_id for image, image_id, text in data])
    return images, images_id

batch_size = 50
data_loader = DataLoader(cocoCaptions, collate_fn=collate_coco, batch_size=batch_size)

In [None]:
model.eval()
correct_count = 0
total_samples = 0
k = 70

image_feats = []
image_embeds = []

with torch.no_grad():
    for images, images_id in tqdm(data_loader):
        images = images.cuda(cuda_card)
        image_feat = model.visual_encoder.visual(images, skip_last_layer=True).detach()
        image_feat = model.visn_layer_norm(model.visn_fc(image_feat))
        
        image_embed = model.vision_proj(image_feat[:,0,:]).detach()            
        image_embed = normalize_vector(image_embed)

        image_feats.append(image_feat)
        image_embeds.append(image_embed)

image_feats = torch.concat(image_feats)
image_embeds = torch.concat(image_embeds)

sims_matrix = image_embeds.matmul(text_embeds.T)        

In [None]:
score_matrix_i2t = torch.full((len(images_id_list), len(captions)), -100.0)
score_matrix_t2i = torch.full((len(captions), len(images_id_list)), -100.0)

for idx, sims in enumerate(sims_matrix):
    _, topk_idx = sims.topk(k=k, dim=0)
    encoder_output = image_feats[idx].repeat(k,1,1)
    encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
    _, output = model.fusion_encoder(encoder_embeds = text_feats[topk_idx], 
                                attention_mask = text_atts[topk_idx],
                                encoder_hidden_states = encoder_output,
                                encoder_attention_mask = encoder_att,                             
                                return_dict = False,
                               )
    score = model.itm_head(output[:, 0, :])[:, 1].detach().cpu()
    score_matrix_i2t[idx, topk_idx] = score

sims_matrix = sims_matrix.t()

for idx, sims in enumerate(sims_matrix):
    _, topk_idx = sims.topk(k=k, dim=0)
    encoder_output = image_feats[topk_idx]
    encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
    _, output = model.fusion_encoder(encoder_embeds=text_feats[idx].repeat(k, 1, 1),
                                     attention_mask=text_atts[idx].repeat(k, 1),
                                     encoder_hidden_states=encoder_output,
                                     encoder_attention_mask=encoder_att,
                                     return_dict=False,
                                    )
    score = model.itm_head(output[:, 0, :])[:, 1]
    score_matrix_t2i[idx, topk_idx] = score.float()
        

In [None]:
img2txt = dict([(value, []) for value in images_id_list])
for key, value in caption_map.items():
    img2txt[value].append(key)

In [None]:
ranks = np.zeros(score_matrix_i2t.shape[0])
for index, score in enumerate(score_matrix_i2t):
    score = np.array(score)
    inds = np.flip(np.argsort(score))
    # Score
    rank = 1e20
    for i in img2txt[images_id_list[index]]:
        tmp = np.where(inds == i)[0] # search for text i location in score array tmp is single index number
        if len(tmp) != 0: 
            tmp = tmp[0]
        else:
            tmp = 1e20
        if tmp < rank:
            rank = tmp
    ranks[index] = rank

# Compute metrics
tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

for index,score in enumerate(score_matrix_t2i):
    inds = np.argsort(score)[::-1]
    ranks[index] = np.where(inds == caption_map[index])[0][0]

# Compute metrics
ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)   

In [None]:
model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
file_path = "result.txt"

with open(file_path, 'w') as file:
    file.write("Tr1: " + str(tr1) + "\n")
    file.write("Tr5: " + str(tr5) + "\n")
    file.write("Tr10: " + str(tr10) + "\n")
    file.write("Ir1: " + str(ir1) + "\n")
    file.write("Ir5: " + str(ir5) + "\n")
    file.write("Ir10: " + str(ir10) + "\n")
    file.write("Model param: " + str(model_params) + "\n")