In [3]:
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.autograd import Variable

from torchvision import transforms

from trainer import fit
import numpy as np

# Whether PyTorch can see a CUDA device. Used below to decide whether to pin
# DataLoader memory and whether to move tensors onto the GPU.
cuda: bool = torch.cuda.is_available()

In [4]:
# Prepare DataLoader
from datasets import ImageFolderDataset
import csv

query_folder = '../AIC20_ReID/image_query'
query_csv = 'metadata/Label-Test-Query - Query.csv'

# Keep only rows whose cluster id (third "_"-separated field of column 5)
# lies in the inclusive range [CLUSTER_MIN, CLUSTER_MAX].
CLUSTER_MIN, CLUSTER_MAX = 1, 50

size = (224, 224)
query_images = []
# NOTE(review): despite its name, this list stores the image *filename* of each
# kept row, not its cluster code. These are passed to ImageFolderDataset as the
# labels (presumably returned as the loader targets), and the grouping cell
# further down joins each label onto `query_folder` to copy the file, so it
# relies on them being filenames. Keep it this way unless that cell changes too.
query_cluster_codes = []

# newline='' is what the csv module requires for files it reads.
with open(query_csv, 'r', newline='') as csv_file:
    csv_reader = csv.reader(csv_file)
    header = next(csv_reader, None)  # skip the header row; None if the file is empty
    for row in csv_reader:
        image_name, cluster_code = row[0], row[5]
        cluster_id = int(cluster_code.split("_")[2])
        if CLUSTER_MIN <= cluster_id <= CLUSTER_MAX:
            query_images.append(image_name)
            query_cluster_codes.append(image_name)

query_dataset = ImageFolderDataset(query_folder, query_images, query_cluster_codes,
                                   transform=transforms.Compose([
                                       transforms.Resize(size),
                                       transforms.ToTensor()
                                   ]))
batch_size = 8
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
# No shuffling: this loader is only used to extract embeddings, and those are
# re-sorted by label afterwards, so a random order buys nothing.
query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=batch_size, shuffle=False, **kwargs)

In [5]:
# Load Model
model_path = 'weights/onlinetriplet-b4-200405-hardest_30epochs.pth'
# torch.load returns the whole pickled model object here (it is used directly
# as the model below). Unpickling can execute arbitrary code, so only load
# checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine. The model is then moved to the GPU explicitly when one is available,
# so it lives on the same device the input images are sent to.
model = torch.load(model_path, map_location='cpu')
if cuda:
    model = model.cuda()
feature_extractor = model



In [6]:
N_DIMS = 1792


def extract_embeddings(dataloader, model):
    """Embed every batch of ``dataloader`` with ``model.get_embedding``.

    Puts ``model`` in eval mode and runs without autograd. Each batch of images
    is moved to the device of the model's first parameter (CPU if the model has
    no parameters), so inputs always match where the model actually lives.

    Args:
        dataloader: iterable yielding ``(images, target)`` batches.
        model: object with ``eval()``, ``parameters()`` and
            ``get_embedding(images)``. ``get_embedding`` must return one row
            per image (2-D output).

    Returns:
        ``(embeddings, labels)``:
        - ``embeddings`` is a float64 numpy array with one row per yielded
          sample. Its width is taken from the model output rather than assumed.
        - ``labels`` is a list of the per-sample targets, in the same order.
        If the loader yields no batches, ``embeddings`` has shape
        ``(0, N_DIMS)``.
    """
    model.eval()
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    batch_embeddings = []
    labels = []
    with torch.no_grad():
        for images, target in dataloader:
            output = model.get_embedding(images.to(device))
            batch_embeddings.append(output.detach().cpu().numpy())
            labels.extend(target)

    if not batch_embeddings:
        return np.zeros((0, N_DIMS)), labels
    # float64 to match the dtype the rest of the notebook has always received.
    embeddings = np.concatenate(batch_embeddings, axis=0).astype(np.float64, copy=False)
    return embeddings, labels

In [7]:
# Embed every query image; row i of query_embedding belongs to query_labels[i].
query_embedding, query_labels = extract_embeddings(dataloader=query_loader, model=model)

In [39]:
def pdist_torch(emb1, emb2):
    """Pairwise Euclidean distance matrix between two sets of embeddings.

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the whole matrix comes
    from a single matrix multiply.

    Args:
        emb1: 2-D tensor, one embedding per row (m rows).
        emb2: 2-D tensor, one embedding per row (n rows). Must have the same
            width, dtype and device as ``emb1``.

    Returns:
        (m, n) tensor whose entry [i, j] is the distance between emb1[i] and
        emb2[j]. Squared distances are clamped to at least 1e-12 before the
        square root. As a result, negative round-off never yields NaN, and
        identical rows give 1e-6 rather than 0.
    """
    sq1 = emb1.pow(2).sum(dim=1, keepdim=True)      # (m, 1)
    sq2 = emb2.pow(2).sum(dim=1, keepdim=True).t()  # (1, n); broadcasts with sq1 to (m, n)
    # beta/alpha are passed by keyword: the positional
    # addmm(beta, alpha, mat1, mat2) overload is deprecated in PyTorch.
    dist_sq = torch.addmm(sq1 + sq2, emb1, emb2.t(), beta=1, alpha=-2)
    return dist_sq.clamp(min=1e-12).sqrt()

def run_query_query(emb_query):
    """Distances between every pair of query embeddings (query vs. query).

    Args:
        emb_query: 2-D tensor of query embeddings, one row per query.

    Returns:
        (n, n) numpy array whose entry [i, j] is the Euclidean distance
        between query i and query j, computed by ``pdist_torch``.
    """
    # The same embeddings serve as both the row set and the column set.
    pairwise = pdist_torch(emb_query, emb_query)
    return pairwise.detach().cpu().numpy()

In [40]:
# Reorder embeddings and labels by label so the distance matrix follows a
# deterministic order, independent of the order the loader produced.
query_labels_array = np.asarray(query_labels)
query_embedding_indices = np.argsort(query_labels_array)
# Fancy indexing keeps the 2-D (N, D) shape even when N == 0. Building an array
# from a Python list of rows would instead collapse to shape (0,).
query_embedding_sorted = query_embedding[query_embedding_indices]
query_labels_sorted = query_labels_array[query_embedding_indices]

query_tensor = torch.from_numpy(query_embedding_sorted)
if cuda:
    query_tensor = query_tensor.cuda()

dists = run_query_query(query_tensor)

In [57]:
# Two queries are linked when their embedding distance is below `threshold`.
# Each connected component of that graph becomes one group.
# NOTE(review): 0.9 is a hand-picked cut-off on raw Euclidean distance;
# presumably tuned for this model's embedding scale - confirm before reuse.
threshold = 0.9
n_queries = len(query_labels_sorted)

mark = [False] * n_queries


def DFS(u, component):
    """Append every not-yet-marked query reachable from ``u`` to ``component``.

    Queries are visited in depth-first preorder, scanning candidate neighbours
    in index order, and are marked as they are appended. An explicit stack of
    neighbour iterators replaces recursion, so a large component cannot hit
    Python's recursion limit. The visit order is the same as the recursive
    formulation.

    Reads the module-level ``mark``, ``dists``, ``threshold`` and ``n_queries``.
    """
    mark[u] = True
    component.append(u)
    stack = [(u, iter(range(n_queries)))]
    while stack:
        node, candidates = stack[-1]
        for v in candidates:
            if not mark[v] and dists[node][v] < threshold:
                mark[v] = True
                component.append(v)
                # Descend into v; node's iterator resumes after v on backtrack.
                stack.append((v, iter(range(n_queries))))
                break
        else:
            # Every candidate of `node` has been examined: backtrack.
            stack.pop()


components = []
for idx in range(n_queries):
    if not mark[idx]:
        components.append([])
        DFS(idx, components[-1])

for component in components:
    if len(component) > 1:
        print(component)

[0, 153]
[1, 60, 165]
[4, 133, 107]
[5, 14, 149]
[7, 54, 57, 162, 19, 44, 48, 82, 102, 174, 91, 35, 120, 68]
[10, 166, 173]
[11, 138]
[12, 38]
[16, 17]
[18, 67, 116, 156]
[20, 55, 72, 42, 134, 151, 158]
[22, 36]
[25, 80, 128]
[27, 137]
[28, 118, 88, 124, 152]
[29, 115, 161]
[31, 150, 135]
[34, 121, 159]
[37, 130]
[43, 47]
[45, 63, 85]
[51, 98]
[53, 164]
[56, 129, 141]
[59, 74, 89, 90, 171, 95, 147]
[62, 66]
[65, 86]
[69, 70, 160]
[78, 155, 113]
[87, 97, 140]
[93, 103, 163, 168]
[101, 106, 111, 148, 170]
[109, 117, 167]
[125, 132]
[139, 146]
[143, 169]


In [58]:
import os
from shutil import copyfile

output_path = '../query_groups'

# exist_ok=True lets this cell be re-run without FileExistsError. Folders and
# files from an earlier run are NOT removed, though. If the grouping changed
# (e.g. after adjusting `threshold`), clear `output_path` by hand first.
os.makedirs(output_path, exist_ok=True)

# One sub-folder per component (singletons included), named by its index.
for idx, component in enumerate(components):
    group_path = os.path.join(output_path, str(idx))
    os.makedirs(group_path, exist_ok=True)
    for vertex in component:
        # The sorted labels are used as image filenames inside `query_folder`.
        image_name = query_labels_sorted[vertex]
        copyfile(os.path.join(query_folder, image_name),
                 os.path.join(group_path, image_name))
        