In [1]:
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.autograd import Variable

from torchvision import transforms

from trainer import fit
import numpy as np

# Global switch read by later cells: move the model and image batches to the GPU when PyTorch can see one.
cuda = bool(torch.cuda.is_available())

In [3]:
# Prepare DataLoader
from datasets import ImageFolderDataset
import csv

query_folder = '../AIC20_ReID/image_query'
# NOTE(review): not read by any code in this notebook at the moment.
query_csv = 'metadata/Label-Test-Query - Query.csv'

size = (224, 224)

from pathlib import Path

# Use every .jpg in the query folder. Sorting makes the dataset order independent
# of the filesystem's directory-listing order, so runs are reproducible.
query_images = sorted(image_path.name for image_path in Path(query_folder).glob('*.jpg'))

# The image names are passed twice; the later cells treat the returned labels as
# image file names (presumably the third argument is the label list -- confirm
# against datasets.ImageFolderDataset).
query_dataset = ImageFolderDataset(query_folder, query_images, query_images,
                                   transform=transforms.Compose([
                                       transforms.Resize(size),
                                       transforms.ToTensor()
                                   ]))

batch_size = 8
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
# Inference only: shuffling has no benefit here and a fixed order keeps runs reproducible.
query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=batch_size, shuffle=False, **kwargs)

In [4]:
# Load Model
model_path = 'weights/onlinetriplet-b4-200405-hardest_30epochs.pth'
# map_location puts the loaded tensors on the device the rest of the notebook uses
# (extract_embeddings moves image batches to the GPU exactly when `cuda` is True).
# A checkpoint saved on a GPU therefore loads on a CPU-only machine, and a
# CPU-saved model ends up on the GPU alongside its inputs.
# NOTE(review): this unpickles a whole nn.Module, which can execute arbitrary code;
# only load trusted checkpoints. Newer PyTorch releases default to weights_only=True
# and will need weights_only=False passed explicitly for a full-module pickle.
model = torch.load(model_path, map_location='cuda' if cuda else 'cpu')
# feature_extractor = model.embedding_net
feature_extractor = model



In [5]:
N_DIMS = 1792  # expected embedding width; only used to shape the result for an empty loader


def extract_embeddings(dataloader, model, use_cuda=None):
    """Compute ``model.get_embedding`` for every batch yielded by ``dataloader``.

    Args:
        dataloader: iterable of ``(images, target)`` batches.
        model: module exposing ``get_embedding(images)``; it is put in eval mode.
        use_cuda: move each image batch to the GPU before embedding. ``None``
            (default) falls back to the notebook-level ``cuda`` flag.

    Returns:
        ``(embeddings, labels)``. ``embeddings`` is a float64 array with one row per
        image actually yielded, each model output flattened per sample; its width is
        taken from the model output rather than assumed to be ``N_DIMS``. ``labels`` is
        the per-batch targets concatenated into a list, in the same order as the rows.
    """
    if use_cuda is None:
        use_cuda = cuda
    model.eval()
    chunks = []
    labels = []
    with torch.no_grad():
        for images, target in dataloader:
            if use_cuda:
                images = images.cuda()
            batch_embedding = model.get_embedding(images)
            chunks.append(batch_embedding.detach().cpu().numpy().reshape(len(images), -1))
            labels += target
    if not chunks:
        return np.zeros((0, N_DIMS)), labels
    return np.concatenate(chunks).astype(np.float64), labels

In [6]:
# Embed every query image; the returned labels are what the later cells sort by and use as file names on export.
query_embedding, query_labels = extract_embeddings(dataloader=query_loader, model=model)

In [7]:
def pdist_torch(emb1, emb2):
    """Pairwise Euclidean distances between the rows of two 2-D tensors.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b.

    Args:
        emb1: tensor of shape (m, d).
        emb2: tensor of shape (n, d) on the same device and with the same dtype.

    Returns:
        Tensor of shape (m, n). Squared distances are clamped at 1e-12 before the
        square root, so identical rows give about 1e-6 instead of 0 and tiny negative
        round-off values cannot produce NaN.
    """
    m, n = emb1.shape[0], emb2.shape[0]
    emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n)
    emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist_mtx = emb1_pow + emb2_pow
    # dist_mtx = 1 * dist_mtx + (-2) * (emb1 @ emb2.T); keyword form, because the
    # positional addmm_(beta, alpha, mat1, mat2) overload is deprecated.
    dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2)
    dist_mtx = dist_mtx.clamp(min=1e-12).sqrt()
    return dist_mtx

def run_query_query(emb_query):
    """Return the query-vs-query Euclidean distance matrix as a NumPy array.

    Every query embedding is compared with every other query embedding (no
    gallery is involved); for n input rows the result is n x n.
    """
    distances = pdist_torch(emb_query, emb_query)
    return distances.detach().cpu().numpy()

In [8]:
# Order embeddings and labels by label so that row i of every array below refers to
# the same image (the labels are later used as image file names when exporting).
query_embedding_indices = np.argsort(np.asarray(query_labels))
# Fancy indexing instead of rebuilding the arrays row by row in Python.
query_embedding_sorted = query_embedding[query_embedding_indices]
query_labels_sorted = np.asarray(query_labels)[query_embedding_indices]

query_tensor = torch.from_numpy(query_embedding_sorted)
if cuda:
    query_tensor = query_tensor.cuda()

dists = run_query_query(query_tensor)

In [9]:
# Group the queries into connected components of the graph whose edges join any
# two queries closer than `threshold` in embedding space.
threshold = 0.7  # NOTE(review): hand-picked distance cutoff; likely needs retuning per model
n_queries = len(query_labels_sorted)

mark = [False] * n_queries


def DFS(u, component):
    """Depth-first traversal from ``u``; appends every newly reached vertex to ``component``.

    Vertices are visited in the same order as the plain recursive DFS (neighbours
    scanned by increasing index). An explicit stack is used so a large component
    does not hit Python's recursion limit (1000 by default) and raise RecursionError.
    """
    mark[u] = True
    component.append(u)
    # Each entry is (vertex, index of the next candidate neighbour to examine).
    stack = [(u, 0)]
    while stack:
        node, start = stack[-1]
        for v in range(start, n_queries):
            if not mark[v] and dists[node][v] < threshold:
                # Remember where to resume scanning `node`, then descend into v.
                stack[-1] = (node, v + 1)
                mark[v] = True
                component.append(v)
                stack.append((v, 0))
                break
        else:
            # No unvisited neighbour left: this vertex is finished.
            stack.pop()


components = []
for idx in range(n_queries):
    if not mark[idx]:
        components.append([])
        DFS(idx, components[-1])

for component in components:
    if len(component) > 1:
        print(component)

[0, 345]
[1, 900]
[3, 135]
[4, 539]
[5, 1037, 542, 514, 793, 523, 251, 774, 1013]
[7, 571, 747, 824]
[9, 319, 981]
[11, 248, 151, 498, 1022]
[21, 488, 889]
[29, 490]
[31, 70, 243, 275, 181, 309, 497]
[32, 472, 106]
[33, 94]
[37, 261, 954, 990]
[39, 983]
[43, 272, 294, 289]
[45, 105, 196, 212]
[46, 137, 350, 565, 473]
[49, 819, 1012]
[51, 448, 446, 519]
[52, 450]
[58, 626, 998, 74, 264, 648]
[61, 563]
[64, 256, 315, 642, 969]
[65, 879, 407, 1021]
[67, 804, 422, 524, 569, 742, 608, 903]
[68, 505, 335, 424, 810, 1000, 361, 147, 439, 910, 171, 568, 541, 849]
[71, 790]
[73, 202, 484, 817, 529, 285, 97, 242, 999, 600]
[78, 988, 1040]
[82, 940, 979]
[83, 829]
[88, 180, 102, 937, 796, 984]
[90, 177]
[91, 404, 872, 504]
[95, 663]
[98, 591, 672]
[104, 176, 935]
[109, 110, 694, 339, 1019]
[111, 353, 651, 919]
[114, 220, 1042, 582, 455, 559, 400, 526, 870, 907, 675, 486, 164]
[115, 797]
[117, 310]
[119, 507, 697, 780]
[123, 252, 138]
[124, 166]
[126, 366, 532]
[127, 317]
[129, 442, 767]
[136, 230,

In [10]:
import os
from shutil import copyfile

output_path = '../full_query_groups'

# Create the output root (and any missing parents). Refuse to write into a
# non-empty directory so group folders from an earlier run (e.g. with a different
# threshold) are never mixed with this run's groups.
os.makedirs(output_path, exist_ok=True)
if os.listdir(output_path):
    raise FileExistsError(
        '{} is not empty; remove it before exporting the query groups again'.format(output_path))

# One sub-folder per connected component, named by its index, holding copies of
# that component's query images.
for idx, component in enumerate(components):
    group_path = os.path.join(output_path, str(idx))
    os.mkdir(group_path)
    for vertex in component:
        image_name = query_labels_sorted[vertex]
        copyfile(os.path.join(query_folder, image_name),
                 os.path.join(group_path, image_name))
        