In [None]:
import os
import time
import argparse
import random
import shutil
import timm
import numpy as np
import matplotlib.pyplot as plt
import faiss
from PIL import Image
# from tqdm.notebook import tqdm
from tqdm import tqdm
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
# from model import *
# from utils import *
from metrics import *

## Hyperparameters

In [None]:
seed = 42


def seed_everything(seed_value):
    """Seed every RNG used by this notebook and force deterministic cuDNN kernels.

    Args:
        seed_value: integer seed applied to torch (CPU and all CUDA devices), numpy and `random`.
    """
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly nondeterministic algorithms,
    # which undoes `deterministic=True`; keep it off for reproducible evaluation.
    torch.backends.cudnn.benchmark = False


seed_everything(seed)

In [None]:
# Runtime configuration. Paths can be overridden with environment variables so the
# notebook runs outside the original machine; the defaults are the original values.
# (Restrict GPUs with e.g. os.environ['CUDA_VISIBLE_DEVICES'] = '0' before importing torch.)
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 8
lr = 3e-4     # training hyperparameter; unused by the evaluation cells below
gamma = 0.7   # training hyperparameter; unused by the evaluation cells below
lmbd = 8      # weight of each local token vs. the global cls token in LATransformer
model_path = os.environ.get("LATF_MODEL_PATH", "./model/la-tf++.pth")
# Must contain `query/` and `gallery/` ImageFolder sub-directories.
data_dir = os.environ.get("LATF_DATA_DIR", "/home/shubham/CVP/test")
visualization_path = os.environ.get("LATF_VISUALIZATION_PATH", "/home/shubham/CVP/Predictions")

## DataLoader

In [None]:
# Evaluation preprocessing: resize to the ViT input size, then ImageNet normalization.
# No random augmentation here: query and gallery must be processed identically and
# deterministically, otherwise retrieval metrics depend on the RNG state.
transform_query_list = [
    transforms.Resize((224,224), interpolation=3),  # 3 == Image.BICUBIC
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]
transform_gallery_list = [
    transforms.Resize(size=(224,224),interpolation=3), #Image.BICUBIC
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]
data_transforms = {
    'query': transforms.Compose(transform_query_list),
    'gallery': transforms.Compose(transform_gallery_list),
}

In [None]:
# One ImageFolder per split, read from `<data_dir>/<split>` with that split's transforms.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('query', 'gallery')
}

# shuffle=False keeps loader order == dataset order, which later cells rely on
# when they index images/labels by position.
query_loader, gallery_loader = (
    DataLoader(dataset=image_datasets[split], batch_size=batch_size, shuffle=False)
    for split in ('query', 'gallery')
)

class_names = image_datasets['query'].classes
print(len(class_names))

## Model

In [None]:
class LATransformer(nn.Module):
    """Locally-aware head on top of a timm ViT backbone.

    The backbone's patch embedding, transformer blocks and final norm are reused; the
    196 patch tokens are average-pooled into `num_rows` local tokens.

    Args:
        ViT: ViT module exposing `patch_embed`, `cls_token`, `pos_embed`, `pos_drop`,
            `blocks`, `norm` and `head` (e.g. timm 'vit_base_patch16_224').
        lmbd: weight of each local token when blending it with the global cls token,
            i.e. (cls + lmbd * local) / (1 + lmbd).
        num_classes: classifier output size (the classifier is only built when test=False).
        test: if True, forward() returns the pooled local tokens instead of predictions.
    """

    def __init__(self, ViT, lmbd, num_classes=751, test=False):
        super(LATransformer, self).__init__()
        self.test = test
        self.class_num = num_classes  # output number of classes

        # ViT backbone. forward() never uses its classification head, so freeze it.
        # requires_grad_ is a method and must be called; assigning to it freezes nothing.
        self.model = ViT
        self.model.head.requires_grad_(False)
        self.cls_token = self.model.cls_token  # (1, 1, 768) for ViT-B/16
        self.pos_embed = self.model.pos_embed  # (1, 197, 768) for ViT-B/16
        embed_dim = self.cls_token.shape[-1]

        # 196 patches form a 14 x 14 grid; pooling groups consecutive patch tokens into
        # 14 local tokens (one per grid row, assuming row-major patch order).
        self.num_rows = 14
        self.num_cols = 14

        # Locally aware network: (B, 196, D) -> (B, num_rows, D)
        self.avgpool = nn.AdaptiveAvgPool2d((self.num_rows, embed_dim))
        self.lmbd = lmbd

        if not self.test:
            # A single classifier on the max-pooled, globally enhanced local tokens.
            # NOTE(review): FC_Classifier appears not to be defined in this notebook
            # (`from model import *` is commented out); test=False needs it in scope.
            self.classifier0 = FC_Classifier(input_dim=embed_dim, num_classes=self.class_num,
                                             droprate=0.5, num_bottleneck=256,
                                             return_features=False)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: image batch, (B, 3, 224, 224) for ViT-B/16.
        Returns:
            test=True: pooled local tokens of shape (B, num_rows, D).
            test=False: dict {0: classifier output} computed from the max over globally
            enhanced local tokens.
        """
        # Patch embeddings + prepended cls token + position embeddings
        x = self.model.patch_embed(x)  # (B, 196, D)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, D)
        x = torch.cat((cls_token, x), dim=1)  # (B, 197, D)
        trnsfrmr_inp = self.model.pos_drop(x + self.pos_embed)

        # Transformer encoder followed by the backbone's final LayerNorm
        x = self.model.blocks(trnsfrmr_inp)
        x_trnsfrmr_encdd = self.model.norm(x)  # (B, 197, D)

        cls_token_out = x_trnsfrmr_encdd[:, 0].unsqueeze(1)  # (B, 1, D)

        # Average-pool the patch tokens into local tokens
        Q = x_trnsfrmr_encdd[:, 1:]
        L = self.avgpool(Q)  # (B, num_rows, D)

        if self.test:
            return L  # moving this below the global-cls blending drops the testing score

        # Blend the global cls token into every local token at once:
        # L_i <- (cls + lmbd * L_i) / (1 + lmbd); cls broadcasts over the row dim.
        L = torch.div(torch.add(cls_token_out, torch.mul(L, self.lmbd)), 1 + self.lmbd)
        L, _ = torch.max(L, dim=1)  # (B, D)

        predict = {}
        predict[0] = getattr(self, 'classifier' + str(0))(L)
        return predict

## Load Model

In [None]:
# Load the ViT backbone architecture (pretrained=True downloads timm's weights)
vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)
vit_base = vit_base.to(device)

# Create LA-Transformer in feature-extraction mode (test=True builds no classifier)
model = LATransformer(vit_base, lmbd=lmbd, num_classes=123, test=True).to(device)

# Load LA-Transformer weights. map_location lets a checkpoint saved on GPU load on a
# CPU-only machine.
checkpoint = torch.load(model_path, map_location=device)
load_result = model.load_state_dict(checkpoint, strict=False)
# strict=False skips mismatched keys silently; report them so a wrong or partial
# checkpoint is noticed (e.g. classifier weights from a test=False model will be listed
# as unexpected, which is harmless here).
for key_kind, keys in (("missing", load_result.missing_keys),
                       ("unexpected", load_result.unexpected_keys)):
    if keys:
        print("load_state_dict: {} {} keys, e.g. {}".format(len(keys), key_kind, keys[:5]))
model.eval()

## Extract Features

In [None]:
def extract_feature(model, dataloaders):
    """Run every batch of a DataLoader through `model` and collect the outputs.

    Args:
        model: network (expected in eval mode); batches are moved to the global `device`.
        dataloaders: a single iterable of (images, labels) batches, e.g. a DataLoader.
    Returns:
        (features, imgs): model outputs and the (transformed) input images, each
        concatenated along dim 0 and kept on CPU. Both are empty FloatTensors when the
        loader yields no batches.
    """
    feature_batches = []
    img_batches = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for img, _label in tqdm(dataloaders):
            img_batches.append(img.clone())  # CPU copy kept for visualization
            output = model(img.to(device))
            feature_batches.append(output.detach().cpu())

    if not feature_batches:
        return torch.FloatTensor(), torch.FloatTensor()
    # Concatenate once at the end instead of re-allocating on every batch.
    return torch.cat(feature_batches, 0), torch.cat(img_batches, 0)

In [None]:
# Encode both splits; each call also returns the transformed images for later plotting.
query_feature, query_imgs = extract_feature(model, query_loader)
gallery_feature, gallery_imgs = extract_feature(model, gallery_loader)

In [None]:
# (path, class_index) pairs in dataset order; labels and ids are parsed from these.
# ImageFolder's `.samples` is the same list object as `.imgs`.
query_path = image_datasets['query'].samples
gallery_path = image_datasets['gallery'].samples

In [None]:
# gallery_path

In [None]:
def get_id(img_path):
    """Split (path, label) samples into a filename-derived id list and a label list.

    Args:
        img_path: iterable of (path, label) pairs, e.g. ImageFolder.imgs.
    Returns:
        (camera_id, labels): lists of ints, in input order. camera_id[i] is the integer
        value of the first '_'-separated token of the i-th file name.
    """
    camera_id = []
    labels = []
    for path, label in img_path:
        # os.path.basename handles the platform's separator (split("/") breaks on Windows).
        filename = os.path.basename(path)
        # NOTE(review): called "camera id", but this is just the leading filename token;
        # confirm against the dataset's naming scheme.
        cam_id = int(filename.split("_")[0])
        labels.append(int(label))
        camera_id.append(cam_id)
    return camera_id, labels

gallery_cam, gallery_label = get_id(gallery_path)
query_cam, query_label = get_id(query_path)

# numpy arrays, so gallery labels can be gathered by FAISS neighbour indices (np.take)
gallery_label, query_label = np.asarray(gallery_label), np.asarray(query_label)

## Normalize and Concatenate Averaged GELTs

In [None]:
def normalize_and_flatten(features):
    """Turn per-image local-token features into unit-norm flat vectors.

    Each of the R local tokens is L2-normalized and divided by sqrt(R), so the flattened
    (R * D) vector has unit L2 norm. The inner product of two such vectors is the mean
    cosine similarity of their corresponding tokens.

    Args:
        features: tensor of shape (N, R, D), e.g. (N, 14, 768) from extract_feature.
    Returns:
        list of N 1-D tensors of length R * D (14 * 768 = 10752 for ViT-B/16).
    """
    if features.numel() == 0:
        return []
    num_rows = features.shape[1]
    # Clamp so an all-zero token yields zeros instead of NaNs.
    token_norms = torch.norm(features, p=2, dim=-1, keepdim=True).clamp(min=1e-12)
    normalized = features / (token_norms * np.sqrt(num_rows))
    return list(normalized.reshape(features.shape[0], -1))


concatenated_query_vectors = normalize_and_flatten(query_feature)      # 14*768 -> 10752
concatenated_gallery_vectors = normalize_and_flatten(gallery_feature)  # 14*768 -> 10752

## Calculate Similarity using FAISS

In [None]:
# Exact (brute-force) inner-product index over the normalized gallery vectors.
# faiss and numpy are imported in the first cell.
if len(concatenated_gallery_vectors) == 0:
    raise ValueError("Gallery is empty: no vectors to index.")
gallery_matrix = np.ascontiguousarray(
    np.stack([t.numpy() for t in concatenated_gallery_vectors]), dtype=np.float32)
# Dimension comes from the data (10752 = 14 * 768 for ViT-B/16) instead of being hard-coded.
index = faiss.IndexFlatIP(gallery_matrix.shape[1])
index.add(gallery_matrix)


def search(query, k=1):
    """Retrieve the top-k gallery matches for a single query vector.

    Args:
        query: 1-D torch tensor (one normalized, concatenated query vector).
        k: number of neighbours to return.
    Returns:
        (top_k, gallery_imgs_idxs):
          top_k = (scores, labels), both shaped (1, k). scores are inner products; labels
          are the gallery labels of the neighbours, or -1 where FAISS returned no
          neighbour (k larger than the gallery).
          gallery_imgs_idxs is the (k,) array of raw gallery positions (-1 if none).
    """
    encoded_query = np.ascontiguousarray(query.unsqueeze(dim=0).numpy(), dtype=np.float32)
    scores, neighbour_idxs = index.search(encoded_query, k)
    gallery_imgs_idxs = neighbour_idxs[0].copy()
    # FAISS pads missing results with -1, which np.take would map to the LAST label.
    labels = np.where(neighbour_idxs >= 0, np.take(gallery_label, neighbour_idxs), -1)
    return (scores, labels), gallery_imgs_idxs

In [None]:
mean, std = torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])
t = transforms.Compose([transforms.ToPILImage(),
                        transforms.Resize(size=(128,48))
                       ])


def _to_display_image(img_tensor):
    """Undo ImageNet normalization of a (3, H, W) tensor; return the resized image as an array."""
    denormalized = img_tensor * std.view(3, 1, 1) + mean.view(3, 1, 1)
    # ToPILImage casts floats to uint8 via *255; clamp so overshoot saturates instead of wrapping.
    return np.array(t(denormalized.clamp(0.0, 1.0)))


def visualize(query_img, gallery_imgs, gallery_idxs, save_path, max_matches=10):
    """Save a strip with the query image followed by its ranked gallery matches.

    Args:
        query_img: normalized (3, H, W) query image tensor.
        gallery_imgs: normalized gallery images, indexable by gallery position.
        gallery_idxs: ranked gallery positions (e.g. from search()); negative entries
            (no match) are skipped and fewer than `max_matches` entries are allowed.
        save_path: output image file path.
        max_matches: number of gallery panels in the strip (default 10).
    """
    match_idxs = [idx for idx in gallery_idxs[:max_matches] if idx >= 0]
    fig = plt.figure(figsize=(16., 6.))
    try:
        panels = [query_img] + [gallery_imgs[idx] for idx in match_idxs]
        for position, img_tensor in enumerate(panels, start=1):
            ax = fig.add_subplot(1, max_matches + 1, position)
            ax.axis('off')
            ax.imshow(_to_display_image(img_tensor))
        fig.savefig(save_path)
    finally:
        # Always release the figure; this runs once per query and would otherwise leak.
        plt.close(fig)

In [None]:
# Start each run with an empty output folder so old predictions don't mix with new ones.
# NOTE: this deletes everything under visualization_path; double-check it before running.
shutil.rmtree(visualization_path, ignore_errors=True)
for subdir in ("correct", "incorrect"):
    os.makedirs(os.path.join(visualization_path, subdir), exist_ok=True)

In [None]:
## Evaluate
# Set True to also save strips for queries whose rank-1 match is correct.
visualize_correct = False

rank1_score = 0
rank5_score = 0
ap = 0
count = 0
for query_idx, (query, label) in enumerate(zip(concatenated_query_vectors, query_label)):
    count = query_idx + 1
    # query_imgs is in loader order (shuffle=False), so it lines up with query_idx.
    query_img = query_imgs[query_idx]
    output, gallery_imgs_idxs = search(query, k=10)

    r1 = rank1(label, output)
    rank1_score += r1
    rank5_score += rank5(label, output)
    ap += calc_map(label, output)

    if r1 and not visualize_correct:
        continue
    subdir = "correct" if r1 else "incorrect"
    save_path = os.path.join(visualization_path, subdir, str(query_idx) + ".png")
    visualize(query_img, gallery_imgs, gallery_imgs_idxs, save_path)

print("Correct: {}, Total: {}, Incorrect: {}".format(rank1_score, count, count-rank1_score))
num_queries = len(query_feature)
if num_queries == 0:
    print("No query images were evaluated; Rank1/Rank5/mAP are undefined.")
else:
    print("Rank1: %.3f, Rank5: %.3f, mAP: %.3f"%(rank1_score/num_queries,
                                                 rank5_score/num_queries,
                                                 ap/num_queries))

### Appendix

In [None]:
# query_img = query_imgs[1]
# output, gallery_imgs_idxs = search(concatenated_query_vectors[1], k=10)
# visualize(query_img, gallery_imgs, gallery_imgs_idxs, True)
# rank1(query_label[1], output) 
# rank5(query_label[1], output) 

In [None]:
# print(len(query_imgs), len(gallery_imgs))
# print(len(concatenated_query_vectors), len(concatenated_gallery_vectors))

In [None]:
# index = faiss.IndexFlatIP(10752)
# index.add(np.array([t.numpy() for t in concatenated_gallery_vectors]))

In [None]:
# index.ntotal

In [None]:
# encoded_query = concatenated_query_vectors[0].unsqueeze(dim=0).numpy()
# top_k = index.search(encoded_query, 10)

In [None]:
# top_k

In [None]:
# lbls = np.array(gallery_label)
# lbls.shape

In [None]:
# top_k[1][0].shape

In [None]:
# np.take(lbls, indices=top_k[1][0])

In [None]:
# gallery_imgs.shape