# Retrieval visualization

In [None]:
import logging
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader

import paths
from datasets.voc12 import VOCDataset, create_image_only_transforms
from models.blip.blip_config import BlipConfig
from models.blip.blip_retrieval import BlipRetrieval
from models.preprocessing.preprocess import get_processors
from utils.logger import setup_logger

In [None]:
# setup model for validation:
# build the BLIP retrieval model, load the base weights and the retrieval head, move it to the device
setup_logger(level=logging.INFO)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
logging.info(f"Running on device: {device}, cuda available: {use_cuda}")

ckpt_dir = Path(paths.CV_PATH_CKPT)
model_cfg = BlipConfig()
model = BlipRetrieval.from_config(model_cfg)
logging.info(f"Created model {type(model).__name__} with {model.show_n_params()} parameters.")
model.load_checkpoint(ckpt_dir / "blip_model_base.pth")

# TODO: optionally point this at your own retrieval-head checkpoint
eval_ckpt = ckpt_dir / "blip_model_retrieval_head.pth"
model.load_retrieval_head(eval_ckpt)

model = model.to(device)
model.eval()
pass  # last statement is not an expression, so the notebook does not echo model.eval()'s return value


In [None]:
# setup dataset: VOC "val" segmentation split with captions,
# images preprocessed by the eval-mode visual processor
voc_path = Path(paths.CV_PATH_VOC)
vis_processor_val, text_processor_val = get_processors(model_cfg, mode="eval")
val_split_file = voc_path / "ImageSets" / "Segmentation" / "val.txt"
image_transforms = create_image_only_transforms(vis_processor_val)
dataset = VOCDataset(
    voc_path,
    val_split_file,
    load_captions=True,
    transforms=image_transforms,
)


In [None]:
# setup code to show a datapoint
def show_datapoint(n):
    """Display entry ``n`` of the global ``dataset``: raw image, name and reference caption.

    Args:
        n: Index into ``dataset``; also used to look up ``dataset.files[n]["img"]``.
    """
    data = dataset[n]
    # data["image"] is the tensor preprocessed for model input, so re-open the raw file for display
    image_file = dataset.files[n]["img"]
    # Image.open is lazy and keeps the file handle open; the context manager closes it
    # once display() has rendered the image (this function is called in a loop below)
    with Image.open(image_file) as image_pil:
        display(image_pil)
    print(f"Name: {data['name']}")
    print(f"Reference caption: {data['caption']}")


show_datapoint(0)

In [None]:
# collect image features for the dataset
# iterate the transformed `dataset` in order (shuffle=False) so row k of image_feats matches dataset[k]
dataloader = DataLoader(dataset, batch_size=16, shuffle=False,
                        num_workers=0, drop_last=False)
logging.info("Collect image features")

image_feats = []
# inference only: disable autograd once for the whole pass instead of per batch
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        if i % 10 == 0:
            logging.info(f"{i}/{len(dataloader)}")
        image = batch["image"].to(device)
        image_feat = model.forward_image(image)
        image_feats.append(image_feat.detach().cpu().numpy())

if not image_feats:
    # np.concatenate on an empty list gives an opaque error; fail with a clear message instead
    raise ValueError("No image features collected: the dataset is empty.")
# stack the per-batch feature arrays along the batch axis
image_feats = np.concatenate(image_feats, axis=0)


## Search the dataset via text-to-image retrieval

In [None]:
# define search query
caption = "a picture of a plane"
print(f"Search query: {caption}")

# embed the query text (forward_text is called with a one-element list of processed captions)
with torch.no_grad():
    text_feat = model.forward_text([text_processor_val(caption)]).cpu().numpy()

# score every collected image feature against the query; row 0 is the single query.
# assumes text_feat is (1, D) and image_feats is (N, D) — TODO confirm against forward_text/forward_image
sim = np.matmul(text_feat, image_feats.T)[0]

# show the 10 highest-scoring images, best first
top10 = np.argsort(-sim)[:10]
for rank, idx in enumerate(top10, start=1):
    show_datapoint(idx)
    print(f"Rank {rank} with similarity: {sim[idx]:.3f}")