In [122]:
import ast
import os

import numpy as np
import pandas as pd
import torch
from sklearn.metrics.pairwise import cosine_similarity


In [77]:
def fusion(global_features,local_features,order_first="local"):
    """Fuse global and local feature tensors by concatenating them along dim 1.

    After concatenation every size-1 dimension is squeezed away, so inputs
    shaped (B, C, 1, 1) produce a (B, 2C) result.

    Args:
        global_features: tensor of global descriptors; its first dimension is
            treated as the batch dimension.
        local_features: tensor of local descriptors with a matching layout.
        order_first: "local" places the local features first in the fused
            vector; any other value places the global features first.

    Returns:
        The concatenated tensor with singleton dimensions removed and the
        batch dimension preserved.
    """
    # Decide up front whether the batch dimension will need to be restored.
    single_item_batch = global_features.shape[0] == 1

    if order_first == "local":
        ordered = [local_features, global_features]
    else:  # global features go first
        ordered = [global_features, local_features]

    # squeeze() drops *all* size-1 dims: (B, 2C, 1, 1) -> (B, 2C) ...
    fused = torch.cat(ordered, dim=1).squeeze()

    # ... which also removes the batch dim when B == 1, so add it back: (2C,) -> (1, 2C)
    if single_item_batch:
        fused = fused.unsqueeze(0)
    return fused

In [11]:
# Random dummy feature batches shaped (20, 2048, 1, 1), used to sanity-check fusion()
global_matrix = torch.randn(20,2048,1,1)
local_matrix = torch.randn(20,2048,1,1)

In [12]:
# Shape check: two (20, 2048, 1, 1) inputs should fuse into (20, 4096)
fusion(global_matrix,local_matrix,order_first="global").size()

torch.Size([20, 4096])

In [197]:
# Load the retrieval results table; rank() below reads its "retrieved_indicies" column
df = pd.read_csv("data.csv")

In [55]:
# Preview the first rows of the retrieval table
df.head()

Unnamed: 0.1,Unnamed: 0,query_id,retrieved_ids,retrieved_indicies
0,0,x3vA7Bk0HNI6rGkDpDZQUQ,"['X9V1oGRaAEFjq5jufrklTQ', 'E7gcrCyitkguCnMzoE...","[3, 5130, 5131, 0, 7912, 5132, 8812, 9186, 1, ..."
1,1,U9Vj0IV4q1psciXpj51F_w,"['X9V1oGRaAEFjq5jufrklTQ', '22BOHMokEHyXf9LA8B...","[3, 4504, 1, 1815, 9186, 0, 2, 9181, 5131, 513..."
2,2,Eh1NwQjH4jbKcWqVJ4ZsJg,"['X9V1oGRaAEFjq5jufrklTQ', '_Eq8EgtwLGiMFc7VJd...","[3, 4, 2, 1815, 7604, 1, 0, 8810, 5, 9186, 6, ..."
3,3,1RKCGBAWsZbi5dj3vR2mlw,"['_Eq8EgtwLGiMFc7VJdb-YQ', 'X9V1oGRaAEFjq5jufr...","[4, 3, 5, 6, 0, 8815, 9186, 9181, 8798, 1, 513..."
4,4,LdiYwYkqgUfc1IYDu5ov9A,"['Z4MR4AHQufgsCwiBiqQ23A', 'eQ-8kVNfMZiexVcu_V...","[5, 6, 4, 8815, 5133, 7049, 5125, 8816, 8811, ..."


The rows of `df` are ordered by dataset: the first 6595 rows are CPH queries,
and the following 4525 rows are SF queries.

In [198]:
def _parse_indices(raw):
    """Turn a stored index list (e.g. the string "[3, 5130, 0]") into a Python list."""
    # literal_eval only accepts Python literals, so a malformed or malicious CSV
    # cell cannot execute arbitrary code the way eval() would.
    parsed = ast.literal_eval(raw) if isinstance(raw, str) else raw
    return list(parsed)


def _rerank_candidates(local_query, global_query, local_database, global_database,
                       query_row, retrieved_indices):
    """Order one query's retrieved database indices from most to least similar."""
    if not retrieved_indices:
        # Nothing was retrieved for this query, so there is nothing to re-rank.
        return []

    # Indexing a single row drops the leading dim; unsqueeze restores it so
    # fusion() sees a batch of one.
    query_fusion = fusion(global_query[query_row].unsqueeze(0),
                          local_query[query_row].unsqueeze(0))

    # Gather all candidates with one indexing op instead of growing a tensor
    # with torch.cat inside a loop. Global features come from the *global*
    # database tensor and local features from the *local* one.
    candidate_fusion = fusion(global_database[retrieved_indices],
                              local_database[retrieved_indices])

    # similarity has shape (1, top_k): one query against top_k candidates.
    similarity = cosine_similarity(query_fusion.detach().cpu().numpy(),
                                   candidate_fusion.detach().cpu().numpy())

    # Negate so the highest similarity sorts first; the stable sort keeps the
    # original retrieval order for ties.
    order = np.argsort(-similarity[0], kind="stable")
    return [retrieved_indices[i] for i in order]


def rank(local_features_cph_query,global_features_cph_query,
         local_features_sf_query,global_features_sf_query,
         local_features_cph_database,global_features_cph_database,
         local_features_sf_database,global_features_sf_database,df,
         cph_len=6595):
    """Re-rank each query's retrieved candidates by fused-feature cosine similarity.

    The first ``cph_len`` rows of ``df`` are treated as CPH queries and all
    remaining rows as SF queries; SF rows index the SF tensors starting at 0.
    For every row, the query's fused (global + local) feature is compared with
    the fused features of the database entries listed in the row's
    "retrieved_indicies" cell, and those indices are re-ordered from most to
    least similar.

    Args:
        local_features_cph_query, global_features_cph_query: CPH query
            features, indexed by the row's position in ``df``.
        local_features_sf_query, global_features_sf_query: SF query features,
            indexed by ``position - cph_len``.
        local_features_cph_database, global_features_cph_database: CPH
            database features, indexed by the retrieved indices.
        local_features_sf_database, global_features_sf_database: SF database
            features, indexed by the retrieved indices.
        df: DataFrame with a "retrieved_indicies" column holding each row's
            candidate list (as its string representation).
        cph_len: number of leading rows of ``df`` that belong to CPH.

    Returns:
        The same ``df`` object, modified in place, with a "re_ranked" column
        containing ``str(list_of_indices)`` for every row.

    Raises:
        IndexError: if a row position or retrieved index is out of range for
            the corresponding feature tensor.
    """
    re_ranked = []
    for position, raw_indices in enumerate(df["retrieved_indicies"]):
        if position < cph_len:
            query_row = position
            local_query, global_query = local_features_cph_query, global_features_cph_query
            local_database, global_database = (local_features_cph_database,
                                               global_features_cph_database)
        else:  # SF rows follow the CPH rows
            query_row = position - cph_len
            local_query, global_query = local_features_sf_query, global_features_sf_query
            local_database, global_database = (local_features_sf_database,
                                               global_features_sf_database)

        ranked_indices = _rerank_candidates(local_query, global_query,
                                            local_database, global_database,
                                            query_row, _parse_indices(raw_indices))
        re_ranked.append(str(ranked_indices))

    # Assign the column in one go: rows are matched by position, so this works
    # for any index, not only the default RangeIndex.
    df["re_ranked"] = re_ranked
    return df

In [203]:
# Simulation: random stand-in features so rank() can be exercised end to end.
# NOTE(review): no random seed is set in the visible cells, so the values
# (and therefore the re_ranked output) will differ between runs.


# Query features: 6595 CPH rows and 4525 SF rows, 2048 dims each
local_features_cph_query = torch.randn(6595,2048)
global_features_cph_query = torch.randn(6595,2048)

local_features_sf_query = torch.randn(4525,2048)
global_features_sf_query = torch.randn(4525,2048)

# Database features: 20000 rows per dataset, looked up by the retrieved indices
local_features_cph_database = torch.randn(20000,2048)
global_features_cph_database = torch.randn(20000,2048)

local_features_sf_database = torch.randn(20000,2048)
global_features_sf_database = torch.randn(20000,2048)

In [200]:
# Re-rank every query's retrieved candidates. rank() writes the "re_ranked"
# column into df itself and returns that same object, so new_df and df refer
# to one DataFrame.
new_df = rank(local_features_cph_query, global_features_cph_query,
     local_features_sf_query,global_features_sf_query,
     local_features_cph_database,global_features_cph_database,
     local_features_sf_database,global_features_sf_database,df)

In [201]:
# Preview the result, including the new re_ranked column
new_df.head()

Unnamed: 0.1,Unnamed: 0,query_id,retrieved_ids,retrieved_indicies,re_ranked
0,0,x3vA7Bk0HNI6rGkDpDZQUQ,"['X9V1oGRaAEFjq5jufrklTQ', 'E7gcrCyitkguCnMzoE...","[3, 5130, 5131, 0, 7912, 5132, 8812, 9186, 1, ...","[1760, 4505, 8250, 7896, 8251, 9186, 7530, 181..."
1,1,U9Vj0IV4q1psciXpj51F_w,"['X9V1oGRaAEFjq5jufrklTQ', '22BOHMokEHyXf9LA8B...","[3, 4504, 1, 1815, 9186, 0, 2, 9181, 5131, 513...","[4, 4504, 7608, 1816, 7896, 9521, 7912, 0, 697..."
2,2,Eh1NwQjH4jbKcWqVJ4ZsJg,"['X9V1oGRaAEFjq5jufrklTQ', '_Eq8EgtwLGiMFc7VJd...","[3, 4, 2, 1815, 7604, 1, 0, 8810, 5, 9186, 6, ...","[4, 1815, 7041, 4504, 1, 6, 5131, 2, 20, 9181,..."
3,3,1RKCGBAWsZbi5dj3vR2mlw,"['_Eq8EgtwLGiMFc7VJdb-YQ', 'X9V1oGRaAEFjq5jufr...","[4, 3, 5, 6, 0, 8815, 9186, 9181, 8798, 1, 513...","[530, 1815, 5, 1816, 3618, 4, 9186, 7913, 6206..."
4,4,LdiYwYkqgUfc1IYDu5ov9A,"['Z4MR4AHQufgsCwiBiqQ23A', 'eQ-8kVNfMZiexVcu_V...","[5, 6, 4, 8815, 5133, 7049, 5125, 8816, 8811, ...","[8816, 5125, 1227, 7050, 5135, 11926, 7049, 6,..."


In [202]:
# Count missing values per column; a non-zero re_ranked count would mean some rows were not filled in
new_df.isna().sum()

Unnamed: 0            0
query_id              0
retrieved_ids         0
retrieved_indicies    0
re_ranked             0
dtype: int64