In [1]:
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.autograd import Variable

from torchvision import transforms

from trainer import fit
import numpy as np

# True when PyTorch can see a CUDA device; later cells use it to pick DataLoader
# options (workers / pinned memory) and to move image batches onto the GPU.
cuda = torch.cuda.is_available()

In [2]:
# Set up data loaders for the query and gallery evaluation splits.
import os

from datasets import EvalDataset

# Image root. Set AIC20_REID_IMAGE_DIR to point elsewhere instead of editing this
# machine-specific default path.
root_dir = os.environ.get(
    'AIC20_REID_IMAGE_DIR',
    '/home/cuong/AIC20-Track2/AIC20_track2/AIC20_ReID/image_train',
)
query_csv = 'reid_query_easy.csv'
gallery_csv = 'reid_gallery_easy.csv'

size = (224, 224)

# Query and gallery must go through identical preprocessing so their embeddings are
# comparable; build the pipeline once rather than duplicating it per split.
eval_transform = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
])

query_dataset = EvalDataset(root_dir, query_csv, transform=eval_transform)
gallery_dataset = EvalDataset(root_dir, gallery_csv, transform=eval_transform)

batch_size = 8
# Extra worker / pinned host memory only pay off when batches are copied to a GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
# Evaluation gains nothing from shuffling; a fixed order keeps the extracted embedding
# arrays identical from run to run.
query_loader = torch.utils.data.DataLoader(query_dataset, batch_size=batch_size, shuffle=False, **kwargs)
gallery_loader = torch.utils.data.DataLoader(gallery_dataset, batch_size=batch_size, shuffle=False, **kwargs)

In [3]:
PATH = 'triplet-b4-200403.pth'
# Remap stored tensors onto the device this session actually has. Without map_location,
# torch.load restores tensors onto the device they were saved from, which fails on a
# CPU-only machine when the checkpoint was written from a GPU.
# NOTE(review): the loaded object exposes `.embedding_net`, i.e. this unpickles a whole
# model object rather than a state_dict — only load trusted checkpoints. On torch >= 2.6,
# torch.load defaults to weights_only=True and will reject such a file; pass
# weights_only=False there (that kwarg does not exist before torch 1.13).
model = torch.load(PATH, map_location='cuda' if cuda else 'cpu')
feature_extractor = model.embedding_net

In [4]:
# Expected embedding width of the loaded checkpoint. extract_embeddings infers the width
# from the model output; this value only sets the width of an empty result.
N_DIMS = 1792


def extract_embeddings(dataloader, model, n_dims=None, device=None):
    """Embed every sample yielded by ``dataloader`` with ``model.get_embedding``.

    Args:
        dataloader: iterable of ``(images, target)`` batches (e.g. a
            ``torch.utils.data.DataLoader``); ``target`` must be convertible with
            ``np.asarray`` (a CPU tensor or array of labels).
        model: module exposing ``get_embedding(images)`` that returns one row per image.
            It is switched to eval mode and left in eval mode.
        n_dims: expected embedding width. ``None`` (default) accepts whatever width the
            model produces; otherwise a different width raises ``ValueError``.
        device: device image batches are moved to. ``None`` (default) uses the device
            of the model's first parameter (CPU if it has none), so inputs always land
            where the weights are.

    Returns:
        ``(embeddings, labels)``: float64 arrays of shape ``(num_samples, width)`` and
        ``(num_samples,)``, rows in the order the dataloader yields them.

    Raises:
        ValueError: if ``n_dims`` is given and the model output width differs.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    embedding_chunks, label_chunks = [], []
    with torch.no_grad():
        for images, target in dataloader:
            batch_embedding = model.get_embedding(images.to(device)).cpu().numpy()
            if n_dims is not None and batch_embedding.shape[-1] != n_dims:
                raise ValueError(
                    'expected embeddings of width {}, got shape {}'.format(
                        n_dims, batch_embedding.shape))
            embedding_chunks.append(batch_embedding)
            label_chunks.append(np.asarray(target))

    if not embedding_chunks:
        # Empty dataloader: return correctly shaped empty arrays instead of failing
        # inside np.concatenate.
        width = N_DIMS if n_dims is None else n_dims
        return np.zeros((0, width)), np.zeros(0)

    # float64 matches the dtype of the preallocated np.zeros arrays used previously,
    # so downstream torch.from_numpy calls see the same tensor dtype.
    embeddings = np.concatenate(embedding_chunks).astype(np.float64)
    labels = np.concatenate(label_chunks).astype(np.float64)
    return embeddings, labels

In [5]:
# Embed both splits with the same model so query and gallery vectors share one space.
query_embedding, query_labels = extract_embeddings(dataloader=query_loader, model=model)
gallery_embedding, gallery_labels = extract_embeddings(dataloader=gallery_loader, model=model)

In [6]:
# Sanity checks before scoring: one label per embedding row, and both splits share a width.
assert len(query_embedding) == len(query_labels)
assert len(gallery_embedding) == len(gallery_labels)
assert query_embedding.shape[1] == gallery_embedding.shape[1]
print(gallery_embedding.shape)

(10775, 1792)


In [7]:
import reid_metrics

# reid_evaluate receives the features as torch tensors and the label arrays as NumPy.
query_features = torch.from_numpy(query_embedding)
gallery_features = torch.from_numpy(gallery_embedding)
# NOTE(review): the printed tuple looks like (mAP, CMC / rank-k accuracies) — confirm
# against reid_metrics.reid_evaluate.
evaluation = reid_metrics.reid_evaluate(query_features, gallery_features,
                                        query_labels, gallery_labels)
print(evaluation)

(0.5612963664556764, array([0.8603839], dtype=float32))
