In [84]:
import os
import torch
import torch.nn as nn
import numpy as np
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.models import resnet50

from models import model_pool
from models.util import create_model
from scipy.spatial import distance

import shutil

In [107]:
# step 1. get support set loader 
normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
transform_train = transforms.Compose([
                    transforms.Resize((550, 550)),
                    transforms.RandomCrop(448, padding=8),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize
                ])
support_dataset = ImageFolder(root='../../datasets/open_set_few_shot_retrieval_set/support_document', transform=transform_train)
support_loader = torch.utils.data.DataLoader(support_dataset, batch_size=5, shuffle=True, num_workers=4)

In [120]:
# step 2. get query set loader
normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
transform_train = transforms.Compose([
                    transforms.Resize((550, 550)),
                    transforms.RandomCrop(448, padding=8),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize
                ])
query_dataset = ImageFolder(root='../../datasets/open_set_few_shot_retrieval_set/query_set', transform=transform_train)
query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=1, shuffle=True, num_workers=4)

In [121]:
# step 3. load the best pretrained model
model_path = './checkpoint/best_model_sofar_multi.pth'
model = create_model('resnet50', 11, 'SOFAR')
ckpt = torch.load(model_path)
corrected_dict = { k.replace('features.', ''): v for k, v in ckpt.items() } 
model.load_state_dict(corrected_dict, strict = False)
model = model.cuda()

model.classifier = nn.Identity()

In [123]:
# step 4. get the support set representation vector (5shot; for 5 support set samples )
for idx, (img, target) in enumerate(support_loader) : 
    img = img.cuda()
    target = target.cuda()
    support_output = model(img)
    
    break

In [125]:
# step 5. Start Retrieval
result_dict = dict()

for idx, (img, target) in enumerate(query_loader) : 
    img = img.cuda()
    target = target.cuda()
    img_path = query_dataset.imgs[idx][0]
    
    query_vec = model(img)
    query_vec = query_vec.cpu().detach().numpy()
    
    distance_list = list()
    for i in range(len(support_output)) :
        support_vec = support_output[i].reshape(1, -1)
        support_vec = support_vec.cpu().detach().numpy()
        #dist = np.linalg.norm(support_vec - query_vec) # euclidean distacne
        dist = distance.cosine(support_vec, query_vec) # cosine distance
        distance_list.append(dist)
    avg_dist = np.mean(distance_list)
    result_dict[img_path] = avg_dist

In [126]:
# step 6. Similarity sort
sorted_dict = sorted(result_dict.items(), key = lambda item: item[1])

In [127]:
# step 6-1. Get the top50 simiilar samples
sorted_dict_50 = sorted_dict[:50] 

In [128]:
top_50_list = list()

for key, item in sorted_dict_50 : 
    top_50_list.append(key)

In [129]:
count = 0

for i in range(len(top_50_list)) :
    if 'document' in top_50_list[i] :
        count += 1

In [130]:
count / 50

0.14

In [131]:
# step 7. save the result to the new folder
for i in range(len(top_50_list)) :
    filenm = top_50_list[i].split('/')[-1]
    dst_file_path = os.path.join('../../datasets/open_set_few_shot_retrieval_set/multi_cosine_document/', filenm)
    shutil.copy(top_50_list[i], dst_file_path)