In [1]:
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.autograd import Variable

from torchvision import transforms

from trainer import fit
import numpy as np

# Single source of truth for GPU use: later cells consult this flag before moving
# tensors (and the DataLoader enables pinned memory) only when it is True.
cuda: bool = torch.cuda.is_available()
print(cuda)

True


In [2]:
# # Set up data loaders
# import csv
# from datasets import ImageFolderDataset

# root_dir = '../AIC20_ReID/image_train'
# query_csv = 'metadata/reid_query_easy.csv'
# gallery_csv = 'metadata/reid_gallery_easy.csv'

# size = (224, 224)

# def get_images_labels(vehicle_csv):
#     image_names = []
#     labels = []
#     with open(vehicle_csv, 'r') as csv_file:
#         csv_reader = csv.reader(csv_file)
#         header = next(csv_reader)
#         for row in csv_reader:
#             image_name, vehicle_id = row
#             image_names.append(image_name)
#             labels.append(int(vehicle_id))
#     return image_names, labels

# query_image_names, query_labels = get_images_labels(query_csv) 
# gallery_image_names, gallery_labels = get_images_labels(gallery_csv) 

# query_dataset = ImageFolderDataset(root_dir, query_image_names, query_labels,
#                                        transform = transforms.Compose([
#                                         transforms.Resize(size),  
#                                         transforms.ToTensor()
#                                       ]))
# gallery_dataset = ImageFolderDataset(root_dir, gallery_image_names, gallery_labels,
#                                      transform = transforms.Compose([
#                                         transforms.Resize(size),
#                                         transforms.ToTensor()
#                                       ]))

# batch_size = 8
# kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
# query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=batch_size, shuffle=True, **kwargs)
# gallery_loader = torch.utils.data.DataLoader(gallery_dataset, batch_size=batch_size, shuffle=False, **kwargs)

In [2]:
# Set up data loaders.
# Query images come straight from the query CSV. Gallery images are expanded from
# the tracklets listed in test_track.txt for every labelled tracklet row of the test CSV.
# Only clusters whose number (third "_"-separated field of the cluster code) lies in
# [MIN_CLUSTER, MAX_CLUSTER] are kept.
import csv

from datasets import ImageFolderDataset

query_folder = '../AIC20_ReID/image_query'
gallery_folder = '../AIC20_ReID/image_test'
track_txt = '../AIC20_ReID/test_track.txt'
query_csv = 'metadata/Label-Test-Query - Query.csv'
gallery_csv = 'metadata/Label-Test-Query - Test.csv'

MIN_CLUSTER, MAX_CLUSTER = 1, 50  # inclusive range of cluster numbers to evaluate

size = (224, 224)
batch_size = 8

# Cluster-code string -> dense integer label. Shared by query and gallery so the same
# code gets the same label in both splits; numbering is 0-based in first-seen order.
dict_cluster_codes = {}


def _cluster_label(cluster_code):
    """Return the integer label for `cluster_code`, assigning the next free id on first sight."""
    # len() is evaluated before insertion, so new codes get 0, 1, 2, ... in order.
    return dict_cluster_codes.setdefault(cluster_code, len(dict_cluster_codes))


def _in_cluster_range(cluster_code):
    """True if the cluster number (third '_'-separated field) is within [MIN_CLUSTER, MAX_CLUSTER]."""
    return MIN_CLUSTER <= int(cluster_code.split("_")[2]) <= MAX_CLUSTER


def _read_labelled_rows(csv_path):
    """Yield (image_name, cluster_code) from columns 0 and 5 of `csv_path`, skipping the header row."""
    # newline='' is what the csv module expects for correct line handling.
    with open(csv_path, 'r', newline='') as csv_file:
        csv_reader = csv.reader(csv_file)
        next(csv_reader)  # header row
        for row in csv_reader:
            yield row[0], row[5]


query_images = []
query_cluster_codes = []
for image_name, cluster_code in _read_labelled_rows(query_csv):
    if _in_cluster_range(cluster_code):
        query_images.append(image_name)
        query_cluster_codes.append(_cluster_label(cluster_code))

# One tracklet per line of test_track.txt; the 0-based line index is the tracklet id.
# split() tolerates a missing trailing space, '\r\n' endings and repeated blanks,
# whereas split(" ")[:-1] silently dropped the last image on such lines.
with open(track_txt, 'r') as track_file:
    tracklet_lists = [line.split() for line in track_file]

gallery_images = []
gallery_cluster_codes = []
for row_name, cluster_code in _read_labelled_rows(gallery_csv):
    # Only names of the form "<tracklet_id>_..." reference a tracklet.
    if not (_in_cluster_range(cluster_code) and "_" in row_name):
        continue
    tracklet_id = int(row_name.split("_")[0])
    if not 0 <= tracklet_id < len(tracklet_lists):
        raise ValueError(f"{gallery_csv}: tracklet {tracklet_id} not found in {track_txt}")
    tracklet_images = tracklet_lists[tracklet_id]
    if not tracklet_images:
        # An empty tracklet contributes nothing and must not claim a label id.
        continue
    label = _cluster_label(cluster_code)
    gallery_images.extend(tracklet_images)
    gallery_cluster_codes.extend([label] * len(tracklet_images))

id_codes = len(dict_cluster_codes)  # number of distinct cluster labels assigned

# Deterministic evaluation preprocessing, shared by both splits.
eval_transform = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])
query_dataset = ImageFolderDataset(query_folder, query_images, query_cluster_codes,
                                   transform=eval_transform)
gallery_dataset = ImageFolderDataset(gallery_folder, gallery_images, gallery_cluster_codes,
                                     transform=eval_transform)

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
# Evaluation needs no shuffling; a fixed order keeps re-runs reproducible
# (embeddings and labels are extracted together, so alignment is unaffected).
query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=batch_size, shuffle=False, **kwargs)
gallery_loader = torch.utils.data.DataLoader(gallery_dataset, batch_size=batch_size, shuffle=False, **kwargs)

In [17]:
# Checkpoint containing the whole pickled model object (later cells call
# model.eval() / model.get_embedding directly, so this is not a bare state_dict).
PATH = 'weights/onlinetriplet-b4-200406-hardest-30epochs-random_erasing.pth'
# map_location places every stored tensor on the device used by the rest of the notebook:
# a GPU-saved checkpoint still loads on a CPU-only machine, and a CPU-saved one ends up
# on the GPU when inputs are moved there.
# NOTE: torch.load unpickles arbitrary objects - only load trusted checkpoints.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects whole pickled
# models; pass weights_only=False on those versions (older releases lack the argument).
device = torch.device('cuda' if cuda else 'cpu')
model = torch.load(PATH, map_location=device)
# If the checkpoint wraps the backbone, model.embedding_net may be the extractor instead.
feature_extractor = model

In [4]:
# Fallback embedding width, used only when the width cannot be read from the model
# output (the loader yields no batches) and no explicit n_dims is passed.
N_DIMS = 1792


def extract_embeddings(dataloader, model, n_dims=None, use_cuda=None):
    """Embed every sample yielded by `dataloader` with `model.get_embedding`.

    Args:
        dataloader: iterable of (images, target) batches with a sized `.dataset`;
            output rows are filled in iteration order.
        model: object exposing `eval()` and `get_embedding(images)`; it is left in eval mode.
        n_dims: embedding width. Defaults to the flattened width of the first batch's
            output (N_DIMS if the loader is empty).
        use_cuda: move each input batch to the GPU. Defaults to the notebook-level `cuda` flag.

    Returns:
        (embeddings, labels): float64 arrays of shape (len(dataset), width) and
        (len(dataset),). Rows the loader never yields (e.g. drop_last) stay zero.
    """
    if use_cuda is None:
        use_cuda = cuda
    n_samples = len(dataloader.dataset)
    embeddings = None  # allocated lazily once the output width is known
    labels = np.zeros(n_samples)
    offset = 0
    with torch.no_grad():
        model.eval()
        for images, target in dataloader:
            if use_cuda:
                images = images.cuda()
            count = len(images)
            # Flatten to (batch, features) so the width can be read from the output.
            batch_embeddings = model.get_embedding(images).cpu().numpy().reshape(count, -1)
            if embeddings is None:
                width = batch_embeddings.shape[1] if n_dims is None else n_dims
                embeddings = np.zeros((n_samples, width))
            embeddings[offset:offset + count] = batch_embeddings
            labels[offset:offset + count] = target.numpy()
            offset += count
    if embeddings is None:
        embeddings = np.zeros((n_samples, N_DIMS if n_dims is None else n_dims))
    return embeddings, labels

In [18]:
# Embed both splits with the same model (query first, then gallery) so the features are comparable.
(query_embedding, query_labels), (gallery_embedding, gallery_labels) = (
    extract_embeddings(split_loader, model) for split_loader in (query_loader, gallery_loader)
)

In [15]:
# Sanity check: one (num_images, embedding_width) line per split, query first.
for split_embedding in (query_embedding, gallery_embedding):
    print(split_embedding.shape)

(176, 1792)
(2739, 1792)


In [19]:
import reid_metrics


def _as_tensor(embedding):
    """Wrap a NumPy embedding matrix as a torch tensor, moved to the GPU when `cuda` is set."""
    tensor = torch.from_numpy(embedding)
    return tensor.cuda() if cuda else tensor


query_tensor, gallery_tensor = (_as_tensor(split) for split in (query_embedding, gallery_embedding))
# The evaluator's return value is printed unchanged.
print(reid_metrics.reid_evaluate(query_tensor, gallery_tensor, query_labels, gallery_labels))

(0.513823585650262, array([0.54605263], dtype=float32))
