In [None]:
import json
import matplotlib.pyplot as plt
from PIL import Image
import os
import random
from torchvision import transforms

def show_pil_img(img, title=None):
    """Open a new matplotlib figure and draw ``img`` in it.

    Args:
        img: image accepted by ``plt.imshow`` (e.g. a PIL image).
        title: optional figure title; skipped when falsy.
    """
    plt.figure()
    if not title:
        plt.imshow(img)
        return
    plt.title(title)
    plt.imshow(img)

def denormalize_image_tensor(img_tensor, mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711)):
    """
    Invert per-channel normalization ``(x - mean) / std``, i.e. return ``x * std + mean``.

    The channel axis must be the third-from-last one (``(C, H, W)`` or
    ``(B, C, H, W)``) with C == len(mean) == len(std). The statistics are
    created with the same dtype and device as ``img_tensor``.
    """
    # Reshape to (C, 1, 1) so the per-channel statistics broadcast over H and W.
    mean_t = img_tensor.new_tensor(mean).view(-1, 1, 1)
    std_t = img_tensor.new_tensor(std).view(-1, 1, 1)
    return img_tensor * std_t + mean_t


def show_normalized_image_tensor(img_tensor, mean=(0.48145466, 0.4578275, 0.40821073),
                                 std=(0.26862954, 0.26130258, 0.27577711),
                                 title=None):
    """
    Undo normalization on a channel-first (C, H, W) image tensor and plot it in a new figure.

    Args:
        img_tensor: normalized image tensor, channel-first (C, H, W).
        mean, std: per-channel statistics that were used to normalize the image.
            The inverse is computed from these arguments (previously it ignored
            them and used hard-coded, partly ImageNet constants). The defaults
            appear to be the OpenAI CLIP preprocessing constants — confirm they
            match the dataset's transform.
        title: optional figure title; skipped when falsy.
    """
    # detach/cpu so tensors that require grad or live on the GPU can be converted to numpy.
    img = denormalize_image_tensor(img_tensor.detach().cpu(), mean, std)
    # Clamp to the displayable range; imshow would otherwise clip with a warning.
    img = img.clamp(0.0, 1.0)
    plt.figure()
    if title:
        plt.title(title)
    # matplotlib expects channel-last (H, W, C).
    plt.imshow(img.permute(1, 2, 0).numpy(), interpolation='nearest')


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import os
from omegaconf import OmegaConf

from dataloader.stir_dataset import STIRDataset
from dataloader.data_loaders import get_dataloader

from model import composition_models
from utils.simple_tokenizer import SimpleTokenizer
from utils.util import set_seed, mkdir, load_config_file, write_json
from utils.logger import setup_logger

from tqdm import tqdm

DATA_CONFIG_PATH = 'configs/dataset_config.yaml'
TRAINER_CONFIG_PATH = 'configs/train_config.yaml'

# Read both YAML files (dataset config first); data_config is also used on its own later in the notebook.
data_config, train_config = (
    load_config_file(path) for path in (DATA_CONFIG_PATH, TRAINER_CONFIG_PATH)
)

# OmegaConf.merge gives precedence to later arguments, so dataset values win on key conflicts.
config = OmegaConf.merge(train_config, data_config)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    config.device = "cuda"
else:
    config.device = "cpu"

In [None]:
# Text tokenizer shared by both dataset splits
tokenizer = SimpleTokenizer()

# The validation split is evaluated below; in this notebook the train split is only used for its texts.
val_dataset = STIRDataset(data_config, tokenizer, split='val')
train_dataset = STIRDataset(data_config, tokenizer, split='train')

# Model hyper-parameters; the training-split texts are passed to the composition model's constructor.
texts = train_dataset.get_all_texts()
config.model = 'concat'
config.embed_dim = 512
config.n_gpu = 1
opt = config

# Build the model for the configured composition method.
print('Creating model and optimizer for', opt.model)
if opt.model == 'concat':
    model = composition_models.Concat(texts, embed_dim=opt.embed_dim)
else:
    # Fail loudly instead of continuing with `model` undefined (or stale from a previous run).
    raise ValueError(f"Unsupported model type: {opt.model!r}")

# Checkpoint location; override with the STIR_CHECKPOINT_PATH environment variable
# instead of editing this machine-specific default.
checkpoint_path = os.environ.get(
    "STIR_CHECKPOINT_PATH",
    "/home/trevant/DL4CV/project_saved_checkpoints/checkpoint_best_concat.pt",
)
# Explicit check (asserts are stripped under `python -O`).
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

print(f"Loading saved checkpoint at {checkpoint_path}")
# torch.load unpickles the file: only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(torch.device(config.device))

In [None]:
# Rank all targets for all queries on a fixed random subset of the validation set.
EVAL_SUBSET_SIZE = 1000  # evaluate on this many validation samples only
EVAL_SEED = 0  # makes the subset reproducible across kernel restarts

model.eval()
all_query_features = []
all_target_features = []

from torch.utils.data import random_split
# A seeded generator keeps the subset (and hence the row order of `similarity`) reproducible.
subset_generator = torch.Generator().manual_seed(EVAL_SEED)
indices = torch.randperm(len(val_dataset), generator=subset_generator)[:EVAL_SUBSET_SIZE]
small_set = torch.utils.data.Subset(val_dataset, indices)

# is_train=False: evaluation must keep small_set order so that row i of `similarity`
# corresponds to small_set[i] == val_dataset[indices[i]].
# NOTE(review): assumes get_dataloader only shuffles / drops samples when is_train=True — confirm.
eval_dataloader = get_dataloader(config, small_set, is_train=False)
device = torch.device(config.device)
with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="evaluating"):
        query_img_input = batch['query_img_input'].to(device)
        # query_text is passed to the model as-is (not moved to device);
        # presumably compose_img_text handles tokenization/device placement — confirm.
        query_text_input = batch['query_text']
        target_img_input = batch['target_img_input'].to(device)

        # Embed the composed (query image + text) query and the target image; no loss is computed here.
        composition_features = model.compose_img_text(query_img_input, query_text_input)
        target_image_features = model.extract_img_feature(target_img_input)

        all_query_features.append(composition_features)
        all_target_features.append(target_image_features)

    if not all_query_features:
        raise RuntimeError("evaluation dataloader produced no batches")

    all_query_features = torch.vstack(all_query_features)
    all_target_features = torch.vstack(all_target_features)

    # L2-normalize so the dot product below is cosine similarity.
    all_query_features = F.normalize(all_query_features, dim=-1)
    all_target_features = F.normalize(all_target_features, dim=-1)

    # similarity[i, j]: score of target j for query i (both indexed in small_set order).
    similarity = all_query_features @ all_target_features.t()


In [None]:
# Row i lists target positions ordered from most to least similar to query i.
sorted_targets = similarity.sort(dim=1, descending=True).indices
print("sorted targets", sorted_targets)

## Visualizing the retrieved images

In [None]:
# Dataset roots. The defaults are machine-specific absolute paths; set the
# environment variables to run elsewhere without editing the notebook.
IMG_DIR = os.environ.get("VG_IMG_DIR", "/home/trevant/DL4CV/project_datasets/VG_ALL")  # path of VG images
SKETCHES_PATH = os.environ.get("SKETCHES_PATH", "/home/trevant/DL4CV/project_datasets/imagenet-sketch")  # root that query sketch paths are joined onto

In [None]:
INDEX_TO_VIEW = 42  # row of `similarity`, i.e. position within small_set, to inspect

# Rows/columns of `sorted_targets` are positions in small_set, not val_dataset indices.
# Map them back through `indices` (assumes the eval dataloader kept small_set order).
query_sample = val_dataset[int(indices[INDEX_TO_VIEW])]
query_img_path = query_sample['query_img_path']
query_text = query_sample['query_text']
gt_target = query_sample['target_img_id']

# Target image *ids* (despite the variable name) of the 10 best-ranked candidates, best first.
# int() turns the 0-d (possibly GPU) index tensors into plain Python ints.
top_retrieved_images_paths = [
    val_dataset[int(indices[int(target_pos)])]['target_img_id']
    for target_pos in sorted_targets[INDEX_TO_VIEW, :10]
]


In [None]:
# Show the query sketch with its modification text, the ground-truth target, and the
# top-10 retrieved targets for row INDEX_TO_VIEW. Positions in `sorted_targets` index
# small_set, so map them back to val_dataset through `indices`
# (assumes the eval dataloader kept small_set order).
query_sample = val_dataset[int(indices[INDEX_TO_VIEW])]

# The dataset's query sketch path is joined onto SKETCHES_PATH; a new name keeps this cell idempotent.
query_img_full_path = os.path.join(SKETCHES_PATH, query_sample['query_img_path'])
# plt.imshow copies the pixel data, so the file can be closed right after plotting.
with Image.open(query_img_full_path) as query_img:
    show_pil_img(query_img, title=f"+ {query_sample['query_text']}")

show_normalized_image_tensor(query_sample['target_img_input'],
                             title=f"GT TARGET")

for rank, target_pos in enumerate(sorted_targets[INDEX_TO_VIEW, :10], start=1):
    target_sample = val_dataset[int(indices[int(target_pos)])]
    show_normalized_image_tensor(target_sample['target_img_input'], title=f"rank {rank}")
