This notebook provides an in-depth comparison between classical keypoint-based descriptors and holistic deep learning-based descriptors for image retrieval tasks. The analysis aims to evaluate the trade-offs in performance, efficiency, and accuracy across these two types of feature extraction approaches.

Authors: Sanne Eeckhout & Kaj van Rijn

Course: Foundations of Image Retrieval

Date: November 2024

## Load Libraries

In [142]:
import json
import h5py
import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans
from skimage import feature
from skimage.color import rgb2gray
import os
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
from sklearn.metrics import precision_recall_curve
import json
import torch
import torch.nn as nn
from torchvision import models, transforms
import pandas as pd
import plotly.express as px
import kaleido
import time


## Experimental Setup

In [88]:


# Function to load images based on paths and directory
def load_images_from_paths(image_paths, image_directory):
    images = []
    for filename in tqdm(image_paths, desc="Loading images"):
        img_path = os.path.join(image_directory, filename)
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is not None:
            images.append(image)
    return images

# Initialize keypoint detectors from both skimage and OpenCV
def initialize_detectors():
    # skimage descriptors
    sift = feature.SIFT()  # SIFT detector
    orb = feature.ORB(n_keypoints=50)  # ORB detector
    brief = feature.BRIEF()  # BRIEF descriptor
    # daisy = feature.daisy()  # DAISY descriptor

    # OpenCV descriptors
    akaze = cv2.AKAZE_create()  # AKAZE
    brisk = cv2.BRISK_create()  # BRISK
    # surf = cv2.xfeatures2d.SURF_create()  # SURF (requires opencv-contrib)
    return {
        # "SURF": surf,
        "ORB": orb,
        "SIFT": sift,
        # "DAISY": daisy,
        # "BRIEF": brief,
        "Harris + BRIEF": brief,
        "AKAZE": akaze,
        "BRISK": brisk,
        # 
    }

# Function to extract descriptors using skimage or OpenCV
def extract_descriptors(image, detector, method_name):
    if method_name in ["SIFT", "ORB", "DAISY"]:
        detector.detect_and_extract(image)
        keypoints = detector.keypoints
        descriptors = detector.descriptors
    elif method_name == "Harris + BRIEF":
        # Detect Harris corners for BRIEF
        keypoints = feature.corner_peaks(feature.corner_harris(image), min_distance=5)
        detector.extract(image, keypoints)
        keypoints, descriptors = keypoints, detector.descriptors
    else:  # OpenCV descriptors
        keypoints, descriptors = detector.detectAndCompute(image, None)
    return keypoints, descriptors

# Bag of words function
def bag_of_words(centroids, img_descriptors):
    bow_vector = np.zeros(len(centroids))
    for descriptor in img_descriptors:
        distances = np.linalg.norm(centroids - descriptor, axis=1)
        closest_centroid = np.argmin(distances)
        bow_vector[closest_centroid] += 1
    return bow_vector

# Retrieve images function
def retrieve_images(map_bow_vectors, query_bow):
    # similarities = np.dot(map_bow_vectors, query_bow)
    # return np.argsort(-similarities)  # Return indices sorted by similarity (descending)
    n_map_bow_vectors = map_bow_vectors.shape[0]
    bow_distances = np.zeros(n_map_bow_vectors)
    most_similar = None  # use this to 
    
    # Compute the Euclidean distance between query_bow and each map BoW vector
    for i in range(n_map_bow_vectors):
        bow_distances[i] = np.linalg.norm(map_bow_vectors[i] - query_bow)
    
    # Get indices of map images sorted by smallest distance (most similar first)
    most_similar = np.argsort(bow_distances)
    
    return most_similar




# Experiment function
# def run_experiment_with_dataset():
# Load images
with open("data02/database/database_lite.json", "r") as f:
    m_idx = json.load(f)
    m_imgs = np.array(m_idx["im_paths"])
    m_loc = np.array(m_idx["loc"])

with open("data02/query/query_lite.json", "r") as f:
    q_idx = json.load(f)
    q_imgs = np.array(q_idx["im_paths"])
    q_loc = np.array(q_idx["loc"])

with h5py.File("data02/london_lite_gt.h5", "r") as f:
    fovs = f["fov"][:]
    sim = f["sim"][:].astype(np.uint8)

database_images = load_images_from_paths(m_imgs, "data02/")
query_images = load_images_from_paths(q_imgs, "data02/")

detectors = initialize_detectors()


Loading images: 100%|██████████| 1000/1000 [00:00<00:00, 1562.10it/s]
Loading images: 100%|██████████| 500/500 [00:00<00:00, 1554.60it/s]


In [60]:

# Efficiently load retrieved images once
def load_retrieved_images(file_path):
    with open(file_path, "r") as file:
        return [eval(line.strip()) for line in file]

# Precision at k
def precision_at_k(relevant, retrieved, k):
    relevant_retrieved = [1 if r in relevant else 0 for r in retrieved[:k]]
    return sum(relevant_retrieved) / k

# Mean average precision (mAP)
def mean_average_precision(relevant, retrieved, k_list, precision_scores):
    """
    Calculate mean average precision (mAP) and print precision at k for ks in k_list.

    Parameters:
        relevant (list): List of relevant images (ground truth).
        retrieved (list): List of retrieved images in ranked order.
        k_list (list of int): List of k values for precision calculation (e.g., [1, 5, 10]).

    Returns:
        float: Mean Average Precision (mAP).
    """
    precisions = []
    
    # Calculate precision for each rank and store in precisions if relevant
    for k in range(1, len(retrieved) + 1):
        if retrieved[k - 1] in relevant:
            precision_at_k_value = precision_at_k(relevant, retrieved, k)
            precisions.append(precision_at_k_value)
            
            if k in k_list:
                # print(f"Precision@{k}: {precision_at_k_value:.4f}")
                precision_scores[k].append(precision_at_k_value)
    
    return np.mean(precisions) if precisions else 0, precision_scores

def mAP(name_detector, k_list):
    """
    Calculate mean Average Precision (mAP) for different k values.
    
    Parameters:
        name_detector (str): The name of the detector.
        query_images (list): List of query images.
        sim (numpy array): Similarity matrix.
        k_list (list of int): List of k values for precision calculation (e.g., [5] or [5, 10]).
    
    Returns:
        dict: A dictionary with k values as keys and lists of precision scores for each query as values.
    """
    print(f"\nUsing {name_detector} detector...")

    # Load all retrieved images once
    retrieved_images_all = load_retrieved_images(f"data/results_{name_detector}.txt")

    # Initialize precision scores dictionary for each k value
    precision_scores = {k: [] for k in k_list}
    map_scores = []

    for query_idx in tqdm(range(len(query_images)), desc=f"Calculating Precisions"):
        img = query_images[query_idx]
        relevant_images = np.where(sim[query_idx, :] == 1)[0]  # Get relevant images for this query
        retrieved_images = retrieved_images_all[query_idx]  # Get retrieved images for this query

        # Calculate mean average precision (mAP) for the query
        map_score, precision_scores = mean_average_precision(relevant_images, retrieved_images, k_list, precision_scores)
        # print(map_score)

        map_scores.append(map_score)

        # Optionally calculate precision at k for each k in k_list
        # for k in k_list:
        #     precision_k = precision_at_k(relevant_images, retrieved_images, k)
        #     precision_scores[k].append(precision_k)

    with open(f"data/mapscores_{name_detector}.json", "w") as file:
        json.dump(map_scores, file)

    # Calculate the mean precision for each k and the mean map score
    mean_precision_scores = {k: np.mean(precision_scores[k]) for k in k_list}

    for k, mean_precision in mean_precision_scores.items():
        print(f"Precision@{k}: {mean_precision:.4f}")
    mean_map = np.mean(map_scores)
    
    print(f"\n{name_detector} Mean Average Precision (mAP) over all queries: {mean_map:.4f}")
    # print("mapscore", map_scores)
    return map_scores

In [52]:
def load_map_scores(file_path):
    """Load map_scores from a file."""
    with open(file_path, "r") as file:
        line = file.readline().strip()
        # Convert the string representation of the list back to an actual list
        map_scores = eval(line)
    return map_scores

def load_map_scores_json(file_path):
    with open(file_path, "r") as file:
        map_scores = json.load(file)
    return map_scores

### Keypoint Descriptors

In [89]:
detectors=initialize_detectors()

# Iterate through each descriptor
for name, detector in detectors.items():
    print(f"\nUsing {name} detector...")

    descriptors_list = []
    average_precisions = []

    # Start tracking time for the descriptor extraction process
    start_time = time.time()

    for image in tqdm(database_images, desc=f"Extracting {name} descriptors"):
        _, descriptors = extract_descriptors(image, detector, name)
        if descriptors is not None:
            descriptors_list.append(descriptors)

    # Track time for descriptor extraction
    extraction_time = time.time() - start_time
    print(f"Time taken to extract descriptors for {name}: {extraction_time:.4f} seconds")

    all_retrieved = []

    if descriptors_list:
        # Perform k-means clustering to find the centroids
        all_descriptors = np.vstack(descriptors_list)
        kmeans = KMeans(n_clusters=32, random_state=42).fit(all_descriptors)
        centroids = kmeans.cluster_centers_
        map_bow_vectors = np.array([bag_of_words(centroids, desc) for desc in descriptors_list])

        # Start tracking time for image retrieval
        retrieval_start_time = time.time()

        for query_idx in tqdm(range(len(query_images)), desc=f"Retrieving images from Query List"):
            img = query_images[query_idx]
            _, query_descriptors = extract_descriptors(img, detector, name)

            if query_descriptors is not None:
                query_bow = bag_of_words(centroids, query_descriptors)
                relevant_images = np.where(sim[query_idx, :] == 1)[0]
                retrieved_images = retrieve_images(map_bow_vectors, query_bow)
                
                # Save the retrieved images for later evaluation
                with open(f"temp/results_{name}.txt", "a") as file:
                    file.write(f"{retrieved_images.tolist()}\n")
        
        # Track time for image retrieval
        retrieval_time = time.time() - retrieval_start_time
        print(f"Time taken for retrieval with {name}: {retrieval_time:.4f} seconds")

    # You can also calculate and print the average precision here if needed
    # print(f"Mean Average Precision (mAP) over all queries: {np.mean(average_precisions):.4f}") 



Using ORB detector...


Extracting ORB descriptors: 100%|██████████| 1000/1000 [02:16<00:00,  7.31it/s]


Time taken to extract descriptors for ORB: 136.8114 seconds


Retrieving images from Query List: 100%|██████████| 500/500 [01:15<00:00,  6.60it/s]


Time taken for retrieval with ORB: 75.7888 seconds

Using SIFT detector...


Extracting SIFT descriptors: 100%|██████████| 1000/1000 [05:34<00:00,  2.99it/s]


Time taken to extract descriptors for SIFT: 334.9752 seconds


Retrieving images from Query List: 100%|██████████| 500/500 [03:03<00:00,  2.73it/s]


Time taken for retrieval with SIFT: 183.1241 seconds

Using Harris + BRIEF detector...


Extracting Harris + BRIEF descriptors: 100%|██████████| 1000/1000 [00:41<00:00, 23.91it/s]


Time taken to extract descriptors for Harris + BRIEF: 41.8193 seconds


Retrieving images from Query List: 100%|██████████| 500/500 [00:35<00:00, 14.27it/s]


Time taken for retrieval with Harris + BRIEF: 35.0477 seconds

Using AKAZE detector...


Extracting AKAZE descriptors: 100%|██████████| 1000/1000 [00:09<00:00, 108.97it/s]


Time taken to extract descriptors for AKAZE: 9.1785 seconds


Retrieving images from Query List: 100%|██████████| 500/500 [00:07<00:00, 67.86it/s]


Time taken for retrieval with AKAZE: 7.3698 seconds

Using BRISK detector...


Extracting BRISK descriptors: 100%|██████████| 1000/1000 [00:41<00:00, 23.96it/s]


Time taken to extract descriptors for BRISK: 41.7359 seconds


Retrieving images from Query List: 100%|██████████| 500/500 [00:28<00:00, 17.64it/s]

Time taken for retrieval with BRISK: 28.3386 seconds





In [68]:
for name, detector in detectors.items():
    scores = mAP(name, [1,5,10,20])


Using ORB detector...


Calculating Precisions: 100%|██████████| 500/500 [00:05<00:00, 91.04it/s] 


Precision@1: 1.0000
Precision@5: 0.2000
Precision@10: 0.1200
Precision@20: 0.0583

ORB Mean Average Precision (mAP) over all queries: 0.0170

Using SIFT detector...


Calculating Precisions: 100%|██████████| 500/500 [00:04<00:00, 102.77it/s]


Precision@1: 1.0000
Precision@5: 0.2800
Precision@10: 0.1750
Precision@20: 0.0889

SIFT Mean Average Precision (mAP) over all queries: 0.0485

Using Harris + BRIEF detector...


Calculating Precisions: 100%|██████████| 500/500 [00:04<00:00, 111.80it/s]


Precision@1: 1.0000
Precision@5: 0.2769
Precision@10: 0.1667
Precision@20: 0.0909

Harris + BRIEF Mean Average Precision (mAP) over all queries: 0.0407

Using AKAZE detector...


Calculating Precisions: 100%|██████████| 500/500 [00:05<00:00, 92.41it/s] 


Precision@1: 1.0000
Precision@5: 0.2000
Precision@10: 0.1200
Precision@20: 0.0938

AKAZE Mean Average Precision (mAP) over all queries: 0.0285

Using BRISK detector...


Calculating Precisions: 100%|██████████| 500/500 [00:05<00:00, 90.54it/s] 

Precision@1: 1.0000
Precision@5: 0.2571
Precision@10: 0.1857
Precision@20: 0.1000

BRISK Mean Average Precision (mAP) over all queries: 0.0308





### Holistic Descriptors

In [99]:
def initialize_cnn_model(model_name="resnet50"):
    if model_name == "ResNet50":
        model = models.resnet50(pretrained=True)
    elif model_name == "VGG16":
        model = models.vgg16(pretrained=True)
    elif model_name == "MobileNet_V2":
        model = models.mobilenet_v2(pretrained=True)
    elif model_name == "EfficientNet_B0":
        model = models.efficientnet_b0(pretrained=True)
    else:
        raise ValueError("Model name not recognized. Choose 'resnet50', 'vgg16', 'mobilenet_v2', or 'efficientnet_b0'.")

    # Remove the classification layer
    model = nn.Sequential(*list(model.children())[:-1])
    model.eval()
    return model

# Function to extract descriptors using a CNN model
def extract_cnn_descriptors(image, model, transform):
    image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    image = cv2.resize(image, (224, 224))
    image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        features = model(image_tensor).flatten()
    return features.cpu().numpy()

# Retrieval and evaluation functions
def retrieve_images(map_vectors, query_vector):
    similarities = np.dot(map_vectors, query_vector)
    return np.argsort(-similarities)  # Descending order by similarity

# Run experiment
def run_cnn_experiment(database_images, query_images, model_name, sim):

    # Start tracking time for the descriptor extraction process
    start_time = time.time()

    # Initialize CNN model and transforms
    cnn_model = initialize_cnn_model(model_name=model_name)
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Extract CNN descriptors for database images
    cnn_features_list = []
    for image in tqdm(database_images, desc="Extracting CNN descriptors"):
        features = extract_cnn_descriptors(image, cnn_model, transform)
        cnn_features_list.append(features)

    # Track time for descriptor extraction
    extraction_time = time.time() - start_time
    print(f"Time taken to extract descriptors for {model_name}: {extraction_time:.4f} seconds")

    map_vectors = np.vstack(cnn_features_list)  # Combine features into a matrix
    average_precisions = []

    # File to store retrieved images for each query
    results_file_path = f"temp/results_{model_name}.txt"
    with open(results_file_path, "w") as file:
        # Start tracking time for image retrieval
        retrieval_start_time = time.time()

        for query_idx in tqdm(range(len(query_images)), desc=f"Retrieving images from Query List"):
            img = query_images[query_idx]
            query_features = extract_cnn_descriptors(img, cnn_model, transform)
            retrieved_images = retrieve_images(map_vectors, query_features)
            # relevant_images = np.where(sim[query_idx, :] == 1)[0]

            # Save retrieved images to file
            file.write(f"{retrieved_images.tolist()}\n")
        
        # Track time for image retrieval
        retrieval_time = time.time() - retrieval_start_time
        print(f"Time taken for retrieval with {model_name}: {retrieval_time:.4f} seconds")


#### ResNet50

In [97]:
run_cnn_experiment(database_images, query_images, model_name="ResNet50", sim=sim)

Extracting CNN descriptors: 100%|██████████| 1000/1000 [00:49<00:00, 20.26it/s]


Time taken to extract descriptors for ResNet50: 49.9926 seconds


Retrieving images from Query List: 100%|██████████| 500/500 [00:24<00:00, 20.51it/s]

Time taken for retrieval with ResNet50: 24.3794 seconds





In [61]:
map_scores_resnet = mAP("ResNet50", [1,5,10,20])


Using ResNet50 detector...


Calculating Precisions: 100%|██████████| 500/500 [00:04<00:00, 107.09it/s]

Precision@1: 1.0000
Precision@5: 0.4167
Precision@10: 0.1500
Precision@20: 0.0875

ResNet50 Mean Average Precision (mAP) over all queries: 0.0402





#### MobileNet

In [98]:
run_cnn_experiment(database_images, query_images, model_name="MobileNet_V2", sim=sim)

Extracting CNN descriptors: 100%|██████████| 1000/1000 [00:50<00:00, 19.71it/s]


Time taken to extract descriptors for MobileNet_V2: 50.8576 seconds


Retrieving images from Query List: 100%|██████████| 500/500 [00:27<00:00, 18.13it/s]

Time taken for retrieval with MobileNet_V2: 27.5862 seconds





In [62]:
map_scores_mobilenet = mAP("MobileNet_V2", [1,5,10,20])


Using MobileNet_V2 detector...


Calculating Precisions: 100%|██████████| 500/500 [00:03<00:00, 127.14it/s]

Precision@1: 1.0000
Precision@5: 0.4000
Precision@10: 0.2083
Precision@20: 0.1429

MobileNet_V2 Mean Average Precision (mAP) over all queries: 0.0875





#### EfficientNet

In [101]:
run_cnn_experiment(database_images, query_images, model_name="EfficientNet_B0", sim=sim)

Extracting CNN descriptors: 100%|██████████| 1000/1000 [03:26<00:00,  4.84it/s]


Time taken to extract descriptors for EfficientNet_B0: 206.9107 seconds


Retrieving images from Query List: 100%|██████████| 500/500 [01:44<00:00,  4.78it/s]

Time taken for retrieval with EfficientNet_B0: 104.6878 seconds





In [63]:
map_scores_efficientnet = mAP("EfficientNet_B0", [1,5,10,20])


Using EfficientNet_B0 detector...


Calculating Precisions: 100%|██████████| 500/500 [00:03<00:00, 140.94it/s]

Precision@1: 1.0000
Precision@5: 0.3130
Precision@10: 0.2000
Precision@20: 0.1222

EfficientNet_B0 Mean Average Precision (mAP) over all queries: 0.0959





#### VGG16

In [102]:
run_cnn_experiment(database_images, query_images, model_name="VGG16", sim=sim)


Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.

Extracting CNN descriptors: 100%|██████████| 1000/1000 [01:51<00:00,  9.00it/s]


Time taken to extract descriptors for VGG16: 112.6931 seconds


Retrieving images from Query List: 100%|██████████| 500/500 [00:54<00:00,  9.26it/s]

Time taken for retrieval with VGG16: 54.0054 seconds





In [64]:
map_scores_vgg = mAP("VGG16", [1,5,10,20])


Using VGG16 detector...


Calculating Precisions: 100%|██████████| 500/500 [00:03<00:00, 134.83it/s]

Precision@1: 1.0000
Precision@5: 0.5520
Precision@10: 0.3714
Precision@20: 0.1571

VGG16 Mean Average Precision (mAP) over all queries: 0.0938





## Visualisation

In [91]:
# Data for the descriptors and their performance metrics
data = {
    'Descriptor': [
        'ORB', 'SIFT', 'Harris+BRIEF', 'AKAZE', 'BRISK', 
        'ResNet50', 'VGG16', 'MobileNetV2', 'EfficientNetB0'
    ],
    'Precision@5': [0.2000, 0.2800, 0.2769, 0.2000, 0.2571, 0.4167, 0.5520, 0.4000, 0.3130],
    'Precision@10': [0.1200, 0.1750, 0.1667, 0.1200, 0.1857, 0.1500, 0.3714, 0.2083, 0.2000],
    'Precision@20': [0.0583, 0.0889, 0.0909, 0.0938, 0.1000, 0.0875, 0.1571, 0.1429, 0.1222],
    'mAP': [0.0170, 0.0485, 0.0407, 0.0285, 0.0308, 0.0402, 0.0938, 0.0875, 0.0959]
}

# Convert the data into a DataFrame
df = pd.DataFrame(data)

# Add a new column to classify descriptors as 'Classical' or 'Holistic'
df['Type'] = ['Classical', 'Classical', 'Classical', 'Classical', 'Classical', 
              'Holistic', 'Holistic', 'Holistic', 'Holistic']

# Melt the DataFrame to structure it for Plotly Express
df_melted = df.melt(id_vars=['Descriptor', 'Type'], var_name='Metric', value_name='Value')

# Create a grouped bar plot using Plotly Express
fig = px.bar(
    df_melted,
    x='Descriptor',
    y='Value',
    color='Metric',
    barmode='group',
    title='Performance Metrics for Descriptors',
    labels={'Value': 'Score', 'Descriptor': 'Descriptor'},
    text='Value',
    facet_col='Type',  # Separate Classical and Holistic descriptors
    # color_discrete_sequence=px.colors.qualitative.Set1  # Use distinct colors
)

# Update layout for better readability
fig.update_layout(
    template='plotly_white',
    yaxis=dict(title='Score', range=[0, 0.6]),
    legend_title='Metric',
    xaxis={'categoryorder': 'total descending'},
)

# Display the plot
fig.show()


In [141]:
# Data for the descriptors and their performance metrics
# Data for the descriptors
data_classical = {
    'Descriptor': ['ORB', 'SIFT', 'Harris+BRIEF', 'AKAZE', 'BRISK'],
    'Precision@5': [0.2000, 0.2800, 0.2769, 0.2000, 0.2571],
    'Precision@10': [0.1200, 0.1750, 0.1667, 0.1200, 0.1857],
    'Precision@20': [0.0583, 0.0889, 0.0909, 0.0938, 0.1000],
    'mAP': [0.0170, 0.0485, 0.0407, 0.0285, 0.0308]
}

data_holistic = {
    'Descriptor': ['ResNet50', 'VGG16', 'MobileNetV2', 'EfficientNetB0'],
    'Precision@5': [0.4167, 0.5520, 0.4000, 0.3130],
    'Precision@10': [0.1500, 0.3714, 0.2083, 0.2000],
    'Precision@20': [0.0875, 0.1571, 0.1429, 0.1222],
    'mAP': [0.0402, 0.0938, 0.0875, 0.0959]
}
data=data_holistic

# Convert the data into a DataFrame
df = pd.DataFrame(data)

# Add a new column to classify descriptors as 'Classical' or 'Holistic'
# df['Type'] = ['Classical', 'Classical', 'Classical', 'Classical', 'Classical', 
#               'Holistic', 'Holistic', 'Holistic', 'Holistic']

# Melt the DataFrame to structure it for Plotly Express
df_melted = df.melt(id_vars=['Descriptor'], var_name='Metric', value_name='Value')

# Create a grouped bar plot using Plotly Express
fig = px.bar(
    df_melted,
    x='Descriptor',
    y='Value',
    color='Metric',
    barmode='group',
    # title='Performance Metrics for Descriptors',
    labels={'Value': 'Score', 'Descriptor': 'Descriptor'},
    text='Value',
    # facet_col='Type',  # Separate Classical and Holistic descriptors
    # color_discrete_sequence=px.colors.qualitative.Set1  # Use distinct colors
)

# Update layout for better readability
fig.update_layout(
    template='plotly_white',
    yaxis=dict(title='Score', range=[0, 0.6]),
    legend_title='Metric',
    xaxis={'categoryorder': 'total descending'},
    width=800  # Set width of the figure to half
)

# Display the plot
fig.show()


In [132]:
# Data for the descriptors, their total time, and mAP
data = {
    'Descriptor': [
        'ORB', 'SIFT', 'Harris+BRIEF', 'AKAZE', 'BRISK',
        'ResNet50', 'VGG16', 'MobileNetV2', 'EfficientNetB0'
    ],
    'Total Time (s)': [
        212.6002, 518.0993, 76.867, 16.5483, 70.0745, 
        74.372, 166.6985, 78.4438, 311.5985
    ],
    'mAP': [
        0.0170, 0.0485, 0.0407, 0.0285, 0.0308, 
        0.0402, 0.0938, 0.0875, 0.0959
    ],
    'Type': [
        'Classical', 'Classical', 'Classical', 'Classical', 'Classical', 
        'Holistic', 'Holistic', 'Holistic', 'Holistic'
    ]
}

# Convert the data into a DataFrame
df = pd.DataFrame(data)

# Create a scatter plot using Plotly Express
fig = px.scatter(
    df,
    x='Total Time (s)',
    y='mAP',
    color='Type',  # Different colors for classical vs holistic
    # title='Total Time vs mAP for Descriptors',
    labels={'Total Time (s)': 'Total Time (s)', 'mAP': 'Mean Average Precision (mAP)', 'Descriptor': 'Descriptor'},
    hover_data=['Descriptor'],  # Show the descriptor name on hover
    text='Descriptor',  # Display the descriptor name at each point
    width=800  # Set width of the figure to half
)

# Adjust the text position for ResNet50
fig.update_traces(
    textposition='top right',  # Default position for other points
    textfont=dict(size=10)
)

# For ResNet50, set the text position to 'bottom left'
print(fig.data)
# fig.data[5].update(textposition='bottom left')  # ResNet50 is at index 5
fig.data[1].update(textposition='middle left')#, selector=dict(text="ResNet50"))

# Update layout for better readability
fig.update_layout(
    template='plotly_white',
    xaxis=dict(title='Total Time (s)'),
    yaxis=dict(title='Mean Average Precision (mAP)', range=[0, 0.1]),
    legend_title='Descriptor Type'
)

# Display the plot
fig.show()
# fig.write_image(f"documentation/time_vs_map.png", scale=2)


(Scatter({
    'customdata': array([['ORB'],
                         ['SIFT'],
                         ['Harris+BRIEF'],
                         ['AKAZE'],
                         ['BRISK']], dtype=object),
    'hovertemplate': ('Type=Classical<br>Total Time (' ... '{customdata[0]}<extra></extra>'),
    'legendgroup': 'Classical',
    'marker': {'color': '#636efa', 'symbol': 'circle'},
    'mode': 'markers+text',
    'name': 'Classical',
    'orientation': 'v',
    'showlegend': True,
    'text': array(['ORB', 'SIFT', 'Harris+BRIEF', 'AKAZE', 'BRISK'], dtype=object),
    'textfont': {'size': 10},
    'textposition': 'top right',
    'x': array([212.6002, 518.0993,  76.867 ,  16.5483,  70.0745]),
    'xaxis': 'x',
    'y': array([0.017 , 0.0485, 0.0407, 0.0285, 0.0308]),
    'yaxis': 'y'
}), Scatter({
    'customdata': array([['ResNet50'],
                         ['VGG16'],
                         ['MobileNetV2'],
                         ['EfficientNetB0']], dtype=object),
    'h