In [57]:
import json
import os
import warnings

import numpy as np
import pandas as pd

from IPython.display import display

In [58]:
# ---- Paths (relative to the current working directory) ----
DATA_ROOT = "generalized_contrastive_loss"

# MSLS models directory.
MSLS_MODELS = os.path.join(DATA_ROOT, "Models", "MSLS")

# Results directory for the MSLS "val" evaluation run.
RESULTS_ROOT = os.path.join(DATA_ROOT, "results", "MSLS", "val")
PREDICTION_PATH = os.path.join(RESULTS_ROOT, "MSLS_resnext_GeM_480_GCL_predictions.txt")
DISTANCE_PREDICTION_PATH = os.path.join(RESULTS_ROOT, "MSLS_resnext_GeM_480_GCL_predictions.txt.npy")
# File-name template for per-city result arrays: .format(city=..., file_type=...).
RESULTS_NPY_TEXT = "MSLS_resnext_GeM_480_GCL_{city}_{file_type}.npy"

# MSLS dataset layout: split directories, then per-city directories of the val split.
DATASET_ROOT = os.path.join(DATA_ROOT, "msls")
DATASET_TEST, DATASET_VAL = (os.path.join(DATASET_ROOT, split) for split in ("test", "train_val"))
DATASET_VAL_SF, DATASET_VAL_CPH = (os.path.join(DATASET_VAL, city) for city in ("sf", "cph"))

In [59]:
# Load the CPH result arrays: map features, query features and the
# query-by-map distance matrix written next to the predictions file.
map_feature_data_cph = np.load(os.path.join(RESULTS_ROOT, RESULTS_NPY_TEXT.format(city="cph", file_type="mapfeats")))
print(f"Map Feature CPH: {map_feature_data_cph.shape}")
query_feature_data_cph = np.load(os.path.join(RESULTS_ROOT, RESULTS_NPY_TEXT.format(city="cph", file_type="queryfeats")))
print(f"Query Feature CPH: {query_feature_data_cph.shape}")
distances_cph = np.load(os.path.join(RESULTS_ROOT, RESULTS_NPY_TEXT.format(city="cph", file_type="distances")))
print(f"Distances CPH: {distances_cph.shape}")

# Sanity checks: one distance row per query and one column per map image,
# and query/map descriptors share the same dimensionality.
assert distances_cph.shape == (query_feature_data_cph.shape[0], map_feature_data_cph.shape[0]), \
    "CPH distances shape does not match (n_queries, n_map_images)"
assert query_feature_data_cph.shape[1] == map_feature_data_cph.shape[1], \
    "CPH query and map feature dimensions differ"

Map Feature CPH: (12601, 2048)
Query Feature CPH: (6595, 2048)
Distances CPH: (6595, 12601)


In [60]:
# Load the SF result arrays: map features, query features and the
# query-by-map distance matrix written next to the predictions file.
map_feature_data_sf = np.load(os.path.join(RESULTS_ROOT, RESULTS_NPY_TEXT.format(city="sf", file_type="mapfeats")))
print(f"Map Feature SF: {map_feature_data_sf.shape}")
query_feature_data_sf = np.load(os.path.join(RESULTS_ROOT, RESULTS_NPY_TEXT.format(city="sf", file_type="queryfeats")))
print(f"Query Feature SF: {query_feature_data_sf.shape}")
distances_sf = np.load(os.path.join(RESULTS_ROOT, RESULTS_NPY_TEXT.format(city="sf", file_type="distances")))
print(f"Distances SF: {distances_sf.shape}")

# Sanity checks: one distance row per query and one column per map image,
# and query/map descriptors share the same dimensionality.
assert distances_sf.shape == (query_feature_data_sf.shape[0], map_feature_data_sf.shape[0]), \
    "SF distances shape does not match (n_queries, n_map_images)"
assert query_feature_data_sf.shape[1] == map_feature_data_sf.shape[1], \
    "SF query and map feature dimensions differ"

Map Feature SF: (6315, 2048)
Query Feature SF: (4525, 2048)
Distances SF: (4525, 6315)


In [65]:
# Read the predictions text file: column 0 is the query image key and the
# remaining columns are the keys of the database images retrieved for it.
predictions_raw = pd.read_fwf(PREDICTION_PATH, header=None)

# read_fwf pads rows with fewer fields with NaN; drop those so each list
# holds only real image keys.
retrieved_id_lists = [
    [key for key in row if pd.notna(key)]
    for row in predictions_raw.drop(columns=0).values.tolist()
]

# Build a fresh frame instead of adding columns to a column-slice of another
# frame (which pandas may flag with SettingWithCopyWarning).
df = pd.DataFrame({
    'query_id': predictions_raw[0],
    'retrieved_ids': pd.Series(retrieved_id_lists, index=predictions_raw.index, dtype=object),
})

# Placeholder filled with database row indices by the next cell. The column
# name's spelling is kept as-is because other cells refer to it.
df['retrieved_indicies'] = pd.Series([[] for _ in range(len(df))], index=df.index, dtype=object)
display(df)

Unnamed: 0,query_id,retrieved_ids,retrieved_indicies
0,x3vA7Bk0HNI6rGkDpDZQUQ,"[X9V1oGRaAEFjq5jufrklTQ, E7gcrCyitkguCnMzoEwm0...",[]
1,U9Vj0IV4q1psciXpj51F_w,"[X9V1oGRaAEFjq5jufrklTQ, HU9GEfLAB9pm5RmjW4MLh...",[]
2,Eh1NwQjH4jbKcWqVJ4ZsJg,"[X9V1oGRaAEFjq5jufrklTQ, _Eq8EgtwLGiMFc7VJdb-Y...",[]
3,1RKCGBAWsZbi5dj3vR2mlw,"[_Eq8EgtwLGiMFc7VJdb-YQ, X9V1oGRaAEFjq5jufrklT...",[]
4,LdiYwYkqgUfc1IYDu5ov9A,"[Z4MR4AHQufgsCwiBiqQ23A, _Eq8EgtwLGiMFc7VJdb-Y...",[]
...,...,...,...
11115,BdY6m-mEqcF_EAgWXhNhMQ,"[c4u-9p_hYbMqZMV9M9NJOw, HWs14OYXLRbvsVJ7sHJX0...",[]
11116,MISH40wohJPWQg2AQL91ww,"[c4u-9p_hYbMqZMV9M9NJOw, L8wuXgwUX0RV_CZ9cDni7...",[]
11117,8YnuOWWGLvmT1S405w3ydw,"[L8wuXgwUX0RV_CZ9cDni7A, oaavK1hzB6S7lK3d4zNpU...",[]
11118,yqa5OwamR2zj8KomdqC3FQ,"[qOFK3KN5VcWMzBde01gPiA, S-MaMTx8Gq1AW144nSerd...",[]


In [66]:
# Translate each row's retrieved database image keys into positions within
# that city's database.json `im_paths` list.
# NOTE(review): assumes the predictions file lists every CPH query first and
# then every SF query, in query.json order — TODO confirm.


def _image_key(path):
    """Return the image key of an MSLS image path: its file name without extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _load_query_paths(city_dir):
    """Return the `im_paths` list stored in `<city_dir>/query.json`."""
    with open(os.path.join(city_dir, "query.json"), "r") as f:
        return json.load(f)['im_paths']


def _load_database_index(city_dir):
    """Load `<city_dir>/database.json` and index its images by key.

    Returns:
        (database, key_to_index): the parsed json and a dict mapping each image
        key to its position in `database['im_paths']`. A repeated key keeps its
        first position, the same answer `list.index` would give.
    """
    with open(os.path.join(city_dir, "database.json"), "r") as f:
        database = json.load(f)
    key_to_index = {}
    for position, path in enumerate(database['im_paths']):
        key_to_index.setdefault(_image_key(path), position)
    return database, key_to_index


def _to_indices(keys, key_to_index):
    """Map image keys to database positions; raises KeyError on an unknown key."""
    return [key_to_index[key] for key in keys]


query_cph = _load_query_paths(DATASET_VAL_CPH)
query_sf = _load_query_paths(DATASET_VAL_SF)
cph_len = len(query_cph)
sf_len = len(query_sf)

print(f'CPH length: {cph_len}')
print(f'SF length: {sf_len}')

# The CPH/SF split below is purely positional, so flag a row-count mismatch.
if len(df) != cph_len + sf_len:
    warnings.warn(
        f"predictions have {len(df)} rows but the query.json files list "
        f"{cph_len} + {sf_len} = {cph_len + sf_len} queries; "
        "the positional CPH/SF split may be misaligned"
    )

database_cph, database_index_cph = _load_database_index(DATASET_VAL_CPH)
database_sf, database_index_sf = _load_database_index(DATASET_VAL_SF)

# Rebuild the whole column rather than extending the placeholder lists in
# place, so re-running this cell cannot append duplicate indices. Dict lookups
# replace list.index, which scanned the full database for every key.
retrieved_indices = []
for position, keys in enumerate(df['retrieved_ids']):
    if position < cph_len:
        retrieved_indices.append(_to_indices(keys, database_index_cph))
    elif position < cph_len + sf_len:
        retrieved_indices.append(_to_indices(keys, database_index_sf))
    else:
        retrieved_indices.append([])  # beyond both city ranges: left empty
df['retrieved_indicies'] = pd.Series(retrieved_indices, index=df.index, dtype=object)

display(df)


CPH length: 6595
SF length: 4525


Unnamed: 0,query_id,retrieved_ids,retrieved_indicies
0,x3vA7Bk0HNI6rGkDpDZQUQ,"[X9V1oGRaAEFjq5jufrklTQ, E7gcrCyitkguCnMzoEwm0...","[3, 5130, 5131, 0, 7912, 5132, 9186, 4505, 881..."
1,U9Vj0IV4q1psciXpj51F_w,"[X9V1oGRaAEFjq5jufrklTQ, HU9GEfLAB9pm5RmjW4MLh...","[3, 1, 4504, 1815, 9186, 0, 5131, 2, 5130, 918..."
2,Eh1NwQjH4jbKcWqVJ4ZsJg,"[X9V1oGRaAEFjq5jufrklTQ, _Eq8EgtwLGiMFc7VJdb-Y...","[3, 4, 2, 1815, 7604, 1, 0, 5, 8810, 6, 8815, ..."
3,1RKCGBAWsZbi5dj3vR2mlw,"[_Eq8EgtwLGiMFc7VJdb-YQ, X9V1oGRaAEFjq5jufrklT...","[4, 3, 6, 5, 0, 8815, 9186, 8798, 9181, 1828, ..."
4,LdiYwYkqgUfc1IYDu5ov9A,"[Z4MR4AHQufgsCwiBiqQ23A, _Eq8EgtwLGiMFc7VJdb-Y...","[5, 4, 6, 8815, 5133, 8814, 7049, 8816, 5125, ..."
...,...,...,...
11115,BdY6m-mEqcF_EAgWXhNhMQ,"[c4u-9p_hYbMqZMV9M9NJOw, HWs14OYXLRbvsVJ7sHJX0...","[6312, 6309, 6313, 6306, 6310, 6192, 6311, 540..."
11116,MISH40wohJPWQg2AQL91ww,"[c4u-9p_hYbMqZMV9M9NJOw, L8wuXgwUX0RV_CZ9cDni7...","[6312, 6313, 6180, 6192, 6193, 6170, 6309, 620..."
11117,8YnuOWWGLvmT1S405w3ydw,"[L8wuXgwUX0RV_CZ9cDni7A, oaavK1hzB6S7lK3d4zNpU...","[6313, 6180, 6170, 6171, 6193, 6201, 6312, 619..."
11118,yqa5OwamR2zj8KomdqC3FQ,"[qOFK3KN5VcWMzBde01gPiA, S-MaMTx8Gq1AW144nSerd...","[6314, 286, 285, 1194, 5161, 1191, 5170, 5182,..."
