In [43]:
import os
import torch
import torch.nn as nn
import numpy as np
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.models import resnet50

from models import model_pool
from models.util import create_model

import shutil

In [58]:
# support set에 해당하는 dataloader (5장만 가져오도록)
normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
transform_train = transforms.Compose([
                    transforms.Resize((550, 550)),
                    transforms.RandomCrop(448, padding=8),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize
                ])
support_dataset = ImageFolder(root='../../datasets/open_set_few_shot_retrieval_set/support_snow', transform=transform_train)
support_loader = torch.utils.data.DataLoader(support_dataset, batch_size=5, shuffle=True, num_workers=4)

In [59]:
# query set이 loader에서 전체 돌면서 support set과의 mean(sim)을 확인해서 저장
normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
transform_train = transforms.Compose([
                    transforms.Resize((550, 550)),
                    transforms.RandomCrop(448, padding=8),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize
                ])
query_dataset = ImageFolder(root='../../datasets/open_set_few_shot_retrieval_set/query_set', transform=transform_train)
query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=1, shuffle=True, num_workers=4)

In [60]:
# model load
model_path = './checkpoint/best_model_sofar_multi.pth'
model = create_model('resnet50', 11, 'SOFAR')
ckpt = torch.load(model_path)
corrected_dict = { k.replace('features.', ''): v for k, v in ckpt.items() } 
model.load_state_dict(corrected_dict, strict = False)
model = model.cuda()

In [61]:
# classifier 제외하고 feature extractor만 가져오기
model.classifier = nn.Identity()

In [62]:
# 여기서 support set을 먼저 지정하기
for idx, (img, target) in enumerate(support_loader) : 
    img = img.cuda()
    target = target.cuda()
    support_output = model(img)
    
    break

In [63]:
support_output[0].reshape(1, -1).shape

torch.Size([1, 401408])

In [64]:
result_dict = dict()

for idx, (img, target) in enumerate(query_loader) : 
    img = img.cuda()
    target = target.cuda()
    img_path = query_dataset.imgs[idx][0]
    
    query_vec = model(img)
    query_vec = query_vec.cpu().detach().numpy()
    
    distance_list = list()
    for i in range(len(support_output)) :
        support_vec = support_output[i].reshape(1, -1)
        support_vec = support_vec.cpu().detach().numpy()
        dist = np.linalg.norm(support_vec - query_vec)
        distance_list.append(dist)
    avg_dist = np.mean(distance_list)
    result_dict[img_path] = avg_dist

In [65]:
# value 기준으로 sort
sorted_dict = sorted(result_dict.items(), key = lambda item: item[1])

In [66]:
sorted_dict_50 = sorted_dict[:50] 

In [67]:
top_50_list = list()

for key, item in sorted_dict_50 : 
    top_50_list.append(key)

In [68]:
count = 0

for i in range(len(top_50_list)) :
    if 'document' in top_50_list[i] :
        count += 1

In [69]:
count / 50

0.02

In [45]:
# 폴더로 저장
for i in range(len(top_50_list)) :
    filenm = top_50_list[i].split('/')[-1]
    dst_file_path = os.path.join('../../datasets/open_set_few_shot_retrieval_set/multi_euclidean_document/', filenm)
    shutil.copy(top_50_list[i], dst_file_path)