In [29]:
import argparse
import os
import shutil
import numpy as np

import torch
from PIL import Image, ImageDraw

# Retrieval demo configuration. These constants replace the CLI flags of the
# original "Test CGD" script (--query_img_name, --data_base, --retrieval_num).
data_base_name = 'cub_cropped_resnet50_G_64_0.1_0.5_0.1_32_data_base.pth'
retrieval_num = 8
query_seed = 0  # fixes which test image is drawn as the query
# Dataset name is the first '_'-separated token of the database file name.
data_name = data_base_name.split('_')[0]

# Load the database BEFORE sampling a query from it: sampling first only worked
# when `data_base` was left over from an earlier kernel session.
data_base = torch.load('results/{}'.format(data_base_name))

rng = np.random.default_rng(query_seed)
query_img_name = data_base['test_images'][rng.integers(len(data_base['test_images']))]

# Guard for when query_img_name is set by hand instead of sampled.
if query_img_name not in data_base['test_images']:
    raise FileNotFoundError('{} not found'.format(query_img_name))
query_index = data_base['test_images'].index(query_img_name)
# The context manager closes the file; convert() returns an independent image.
with Image.open(query_img_name) as img:
    query_image = img.convert('RGB').resize((224, 224), resample=Image.BILINEAR)
query_feature = data_base['test_features'][query_index]

# 'isc' stores a separate gallery split; other datasets retrieve within the test split.
gallery_split = 'test' if data_name != 'isc' else 'gallery'
gallery_images = data_base['{}_images'.format(gallery_split)]
gallery_features = data_base['{}_features'.format(gallery_split)]

# Labels as float32 column vectors so they can be stacked into label_concat.
# as_tensor avoids the copy-construct warning of torch.tensor(<tensor>).
query_label = torch.as_tensor(data_base['test_labels'][query_index], dtype=torch.float32).reshape(1, 1)
gallery_labels = torch.as_tensor(data_base['{}_labels'.format(gallery_split)], dtype=torch.float32).unsqueeze(1)

# Row 0 is the query, rows 1: are the gallery items (same order as gallery_images).
feature_concat = torch.cat((query_feature.unsqueeze(0), gallery_features))
label_concat = torch.cat((query_label, gallery_labels))



In [2]:
from sklearn.metrics.pairwise import rbf_kernel
import numpy as np

def embed(sigma_vec, k_vec, X, Y, Gamma, kernel=True):
    """Build one embedding per (k, sigma) pair by a Laplacian-weighted update of Gamma.

    For every pair the returned matrix is::

        0.5 * diag(L_X)^-1 @ (k * L_Y - L_X) @ Gamma + Gamma

    where J = I - 11^T / n is the centring matrix, L_Y is the Laplacian of
    -0.5 * J @ cdist(Y, Y) @ J, and L_X is either

      * kernel=True : the Laplacian of the RBF affinity exp(-||x_i - x_j||^2 / sigma^2)
                      with a zeroed diagonal (what sklearn's ``rbf_kernel`` computes
                      with gamma = 1 / sigma**2), or
      * kernel=False: the matrix -0.5 * J @ cdist(X, X) @ J itself (sigma unused).

    A 1e-9 term guards the division by diag(L_X). All computation runs on
    ``Gamma``'s device and dtype; X and Y are moved/cast there.

    Args:
        sigma_vec: iterable of RBF bandwidths.
        k_vec: iterable of weights applied to the label Laplacian.
        X: 2-D feature tensor with n rows.
        Y: 2-D label/target tensor with n rows.
        Gamma: (n, dim) starting embedding.
        kernel: use the RBF affinity for L_X (True) or the centred distance matrix (False).

    Returns:
        dict mapping (k, sigma) -> (n, dim) tensor, inserted k-major, sigma-minor.
    """
    def my_laplacian(A):
        # Unnormalised Laplacian: diag(column sums) - A.
        return torch.diag(A.sum(dim=0)) - A

    # Follow Gamma instead of hard-coding .cuda(): a CPU-resident X used to crash on
    # J @ dist_X, and a float64 X produced a float64 kernel that cannot matmul Gamma.
    device, dtype = Gamma.device, Gamma.dtype
    X = torch.as_tensor(X).to(device=device, dtype=dtype)
    Y = torch.as_tensor(Y).to(device=device, dtype=dtype)
    n = len(Y)

    dist_X = torch.cdist(X, X)
    J = torch.eye(n, device=device, dtype=dtype) - torch.full((n, n), 1.0 / n, device=device, dtype=dtype)
    # L_Y depends on neither k nor sigma, so build it once.
    LY = my_laplacian(-0.5 * J @ torch.cdist(Y, Y) @ J)

    if kernel:
        LX_by_sigma = {}
        for sigma in sigma_vec:
            gamma = 1 / (sigma ** 2)
            # RBF affinity computed on-device from dist_X (no NumPy/sklearn round trip).
            w = torch.exp(-gamma * dist_X.pow(2))
            w.fill_diagonal_(0)
            LX_by_sigma[sigma] = my_laplacian(w)
    else:
        # The centred distance matrix is independent of sigma and only built when used.
        WX = -0.5 * J @ dist_X @ J
        LX_by_sigma = {sigma: WX for sigma in sigma_vec}

    embeddings = {}
    for k in k_vec:
        for sigma in sigma_vec:
            LX = LX_by_sigma[sigma]
            scale = 0.5 / (torch.diagonal(LX) + 0.000000001)
            pLI = k * LY - LX
            # scale[:, None] * M == diag(scale) @ M, but O(n^2 * dim) instead of O(n^3).
            embeddings[(k, sigma)] = scale[:, None] * (pLI @ Gamma) + Gamma
    return embeddings

In [30]:
# Embedding hyper-parameters: RBF bandwidth(s), label-Laplacian weight(s), output dimension.
sigma_vec = [100]
k_vec = [1000]
dim = 2
# Seed the random starting embedding so retrieval results are reproducible.
torch.manual_seed(0)
Gamma = torch.randn(len(label_concat), dim, device='cuda')
# Index with the configured values so editing sigma_vec/k_vec cannot leave a
# stale hard-coded key behind (which would raise KeyError).
embedding = embed(sigma_vec, k_vec, feature_concat, label_concat, Gamma)[(k_vec[0], sigma_vec[0])]

In [32]:
# Distance from the query embedding (row 0) to every gallery embedding (rows 1:).
# squeeze(0) keeps a 1-D vector even for a single-item gallery.
dist_matrix = torch.cdist(embedding[:1], embedding[1:]).squeeze(0)
if data_name != 'isc':
    # The query was drawn from the gallery itself; exclude it from its own results.
    dist_matrix[query_index] = float('inf')
# topk raises when k exceeds the gallery size, so clamp it.
idx = dist_matrix.topk(k=min(retrieval_num, dist_matrix.numel()), dim=-1, largest=False)[1]

# splitext keeps dotted names intact ('a.b.jpg' -> 'a.b'), avoiding folder collisions.
result_path = os.path.join('results', os.path.splitext(os.path.basename(query_img_name))[0])
if os.path.exists(result_path):
    shutil.rmtree(result_path)
os.makedirs(result_path)
query_image.save(os.path.join(result_path, 'query_img.jpg'))
for rank, gallery_index in enumerate(idx.tolist(), start=1):
    # Context manager closes each file; convert() returns an independent image.
    with Image.open(gallery_images[gallery_index]) as img:
        retrieval_image = img.convert('RGB').resize((224, 224), resample=Image.BILINEAR)
    retrieval_status = (gallery_labels[gallery_index] == query_label).item()
    retrieval_dist = dist_matrix[gallery_index].item()
    # Green frame: same label as the query; red frame: different label.
    ImageDraw.Draw(retrieval_image).rectangle(
        (0, 0, 223, 223), outline='green' if retrieval_status else 'red', width=8)
    retrieval_image.save(os.path.join(result_path, 'retrieval_img_{}_{:.4f}.jpg'.format(rank, retrieval_dist)))