# <a id='toc1_'></a>[Vision Transformer Memory as a Differentiable Search Index for Image Retrieval](#toc0_)

---

## <a id='toc1_1_'></a>[📚 Notebook Overview](#toc0_)

This notebook explores a novel [information retrieval (IR)](https://en.wikipedia.org/wiki/Information_retrieval) framework applied to image retrieval that utilizes a **differentiable function** to generate a **sorted list of image identifiers** in response to a given **image query**.

The approach is called **Differentiable Search Index (DSI)**, and was originally proposed in the paper [Transformer Memory as a Differentiable Search Index](https://arxiv.org/pdf/2202.06991.pdf) by researchers at Google Research.

In its original formulation, **DSI** aims to encode all the information about the document corpus and to perform retrieval within a single **Transformer language model**, instead of adopting the index-then-retrieve pipeline used in most modern IR systems.

This notebook applies DSI to an image retrieval task. The implemented solution is a **Sequence to Sequence Vision Transformer** (ViT) model `f` that, given an image query `q` as input, returns a list of image IDs ranked by relevance to `q`; its performance is compared with a traditional "index-then-retrieve" approach based on a **Bag of Visual Words (BoVW)** baseline model.
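In the DSI paper, one of the studied identifier representations is a plain string of digits that the decoder generates token by token. A minimal sketch of such an encoding, assuming integer image IDs written as fixed-width digit strings and hypothetical `BOS`/`EOS` token indices placed after the ten digit tokens (the repository's `models` module may tokenize IDs differently):

```python
# Hypothetical special-token indices, placed after the ten digit tokens 0-9
BOS, EOS, PAD = 10, 11, 12

def id_to_tokens(image_id: int, num_digits: int) -> list[int]:
    """Encode an image ID as [BOS, d1, ..., dn, EOS] using zero-padded digits."""
    digits = str(image_id).zfill(num_digits)
    if len(digits) > num_digits:
        raise ValueError(f"ID {image_id} does not fit in {num_digits} digits")
    return [BOS] + [int(d) for d in digits] + [EOS]

def tokens_to_id(tokens: list[int]) -> int:
    """Decode a generated token sequence back to an integer ID (stops at EOS, skips special tokens)."""
    digits = []
    for token in tokens:
        if token == EOS:
            break
        if 0 <= token <= 9:
            digits.append(str(token))
    return int("".join(digits)) if digits else -1
```

With `NUMBER_OF_IMAGES_IN_DB = 384`, IDs need 3 digits, so image 42 becomes `[10, 0, 4, 2, 11]`; a fixed width keeps every target sequence the same length, which simplifies batching.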

We evaluate the proposed models using the **Indexing Accuracy**, **Mean Average Precision (MAP)** and **Recall at K** metrics, computed on multiple variants of the **ImageNet** and **MS COCO** datasets, and compare the results obtained by several ViT variants and configurations with the aforementioned **BoVW** baseline.
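The BoVW baseline follows the index-then-retrieve pattern: local descriptors are quantized against a visual vocabulary, each image is indexed as a histogram of visual words, and queries are ranked by histogram similarity. A minimal NumPy sketch of that pipeline, assuming the local descriptors (e.g. SIFT or ORB) have already been extracted; the plain k-means and cosine ranking are illustrative choices, not necessarily what the repository's `models` module implements:

```python
import numpy as np

def build_vocabulary(descriptors: np.ndarray, k: int, iters: int = 10, seed: int = 0) -> np.ndarray:
    """Learn k visual words with a few iterations of Lloyd's k-means."""
    rng = np.random.default_rng(seed)
    centers = descriptors[rng.choice(len(descriptors), size=k, replace=False)].astype(float)
    for _ in range(iters):
        # Assign each descriptor to its nearest center (squared Euclidean distance)
        dists = ((descriptors[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Move each center to the mean of its assigned descriptors (keep it if it has none)
        for c in range(k):
            members = descriptors[labels == c]
            if len(members) > 0:
                centers[c] = members.mean(axis=0)
    return centers

def bovw_histogram(descriptors: np.ndarray, vocabulary: np.ndarray) -> np.ndarray:
    """L2-normalized histogram of visual-word occurrences for one image."""
    dists = ((descriptors[:, None, :] - vocabulary[None, :, :]) ** 2).sum(axis=2)
    hist = np.bincount(dists.argmin(axis=1), minlength=len(vocabulary)).astype(float)
    norm = np.linalg.norm(hist)
    return hist / norm if norm > 0 else hist

def retrieve(query_hist: np.ndarray, index_hists: np.ndarray, top_k: int) -> list:
    """Rank indexed images by cosine similarity (histograms are already unit-norm)."""
    scores = index_hists @ query_hist
    return np.argsort(-scores)[:top_k].tolist()

# Usage on toy descriptors: even-indexed images come from one distribution, odd-indexed from another
rng = np.random.default_rng(0)
toy_images = [rng.normal(loc=(i % 2) * 5.0, size=(50, 8)) for i in range(6)]
vocab = build_vocabulary(np.vstack(toy_images), k=4)
index = np.stack([bovw_histogram(d, vocab) for d in toy_images])
print(retrieve(index[0], index, top_k=3))
```

Indexing (building `index`) and retrieval (`retrieve`) are separate steps here, which is exactly the split that DSI folds into a single model.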

## <a id='toc1_2_'></a>[📝 Author](#toc0_)

**Valerio Di Stefano** - _"Sapienza" University of Rome_
<br/>
Email: [distefano.1898728@studenti.uniroma1.it](mailto:distefano.1898728@studenti.uniroma1.it)

## <a id='toc1_3_'></a>[🔗 External Links](#toc0_)

* **Main Related Work**: [Transformer Memory as a Differentiable Search Index](https://arxiv.org/pdf/2202.06991.pdf)

  _Authors_: Yi Tay, Vinh Q. Tran, Mostafa Dehghani, Jianmo Ni, Dara Bahri, Harsh Mehta, Zhen Qin, Kai Hui, Zhe Zhao, Jai Gupta, Tal Schuster, William W. Cohen, Donald Metzler
  
* **Project Repository**: [GitHub Repository](https://github.com/valeriodiste/computer_vision_project)



---

## <a id='toc1_4_'></a>[📌 Table of Contents](#toc0_)

**Table of contents**<a id='toc0_'></a>    
- [📄 Vision Transformer Memory as a Differentiable Search Index for Image Retrieval](#toc1_)    
  - [📚 Notebook Overview](#toc1_1_)    
  - [📝 Author](#toc1_2_)    
  - [🔗 External Links](#toc1_3_)    
  - [📌 Table of Contents](#toc1_4_)    
- [🚀 Getting Started](#toc2_)    
  - [Collect Source Files](#toc2_1_)    
  - [Define Dataset](#toc2_2_)    
  - [Install & Import Libraries](#toc2_3_)    
  - [Configuration, Hyperparameters and Constants](#toc2_4_)    
- [💾 Data Preparation](#toc3_)    
  - [Download Data & Resources](#toc3_1_)    
  - [Database Creation](#toc3_2_)    
  - [Indexing & Retrieval Datasets](#toc3_3_)    
- [🛍️ Bag of Visual Words Model](#toc4_)    
    - [BoVW Model Initialization](#toc4_1_1_)    
- [🤖 Vision Transformer Model (DSI approach)](#toc5_)    
    - [ViT Model Initialization](#toc5_1_1_)    
- [📈 Evaluation](#toc6_)    
  - [BoVW Model Evaluation](#toc6_1_)    
  - [ViT Model Evaluation](#toc6_2_)    

---


# <a id='toc2_'></a>[🚀 Getting Started](#toc0_)


First, we check whether the notebook is running on Google Colab or locally, defining the `RUNNING_IN_COLAB` constant used throughout the notebook.

In [None]:
# Check if running on colab or locally
try:
	from google.colab import files
	RUNNING_IN_COLAB = True
	print("Running on Google Colab.")
except ModuleNotFoundError:
	RUNNING_IN_COLAB = False
	print("Running locally.")


## <a id='toc2_1_'></a>[Collect Source Files](#toc0_)


#### <a id='toc2_1_1_1_'></a>[Clone Project's GitHub Repository](#toc0_)

We **clone the project's repository** from GitHub to access the source files for datasets, models, evaluation and utilities.


In [None]:
# Clone the git repository of the project for the source files
!git clone https://github.com/valeriodiste/computer_vision_project.git

#### <a id='toc2_1_1_2_'></a>[Pull Latest Files Changes](#toc0_)

We also **pull the latest changes** from the repository into the `./computer_vision_project` directory (the `%cd` paths below assume Colab's `/content` working directory).


In [None]:
# Change the working directory to the cloned repository
%cd /content/computer_vision_project
# Pull the latest changes from the repository
!git pull origin main
# Change the working directory to the parent directory
%cd ..

## <a id='toc2_2_'></a>[Define Dataset](#toc0_)

We **define the main dataset types** available:

- **`COCO`**: Use the **MS COCO** dataset for image segmentation and object detection tasks.
- **`IMAGENET_FULL`**: Use the full **ImageNet** dataset for image classification tasks.
- **`IMAGENET_REDUCED`**: Use a small subset of the **ImageNet** dataset for image classification tasks, with fewer images and fixed square images from 160x160 to 512x512 pixels.
- **`IMAGENET_TINY`**: Use the **Tiny ImageNet** dataset, a 200-class subset of **ImageNet** with images cropped and resized to 64x64 pixels, for image classification tasks.

In [None]:
# Define the possible datasets
class DatasetType:
	COCO = "coco"							# COCO dataset (using the COCO API for the image segmentation dataset)
	IMAGENET_FULL = "imagenet_full"			# Full ImageNet dataset (heavy dataset, requires a HuggingFace token, not recommended)
	IMAGENET_REDUCED = "imagenet_reduced"	# Reduced ImageNet dataset (fewer images, fixed square image sizes from 160x160 to 512x512, no token required)
	IMAGENET_TINY = "imagenet_tiny"			# Tiny ImageNet dataset (200-class subset of ImageNet with 64x64 images, no token required)

We **define the dataset to use** for the project.

In [None]:
# Define the dataset to use
DATASET = DatasetType.IMAGENET_REDUCED
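The plain class above works as a set of string constants, but a typo in `DATASET` (e.g. `"imagenet"`) would not match any branch of the dataset-loading cell and would only fail later. A hedged alternative uses a `str`-mixin `Enum` (named `DatasetKind` here, a hypothetical name chosen to avoid clashing with `DatasetType`), so that members still compare equal to plain strings while invalid names fail immediately:

```python
from enum import Enum

class DatasetKind(str, Enum):
    """Same values as DatasetType; the str mixin keeps comparisons like kind == "coco" working."""
    COCO = "coco"
    IMAGENET_FULL = "imagenet_full"
    IMAGENET_REDUCED = "imagenet_reduced"
    IMAGENET_TINY = "imagenet_tiny"

def parse_dataset(name: str) -> DatasetKind:
    """Convert a dataset name to a DatasetKind, raising a descriptive error for unknown names."""
    try:
        return DatasetKind(name)
    except ValueError:
        valid = ", ".join(kind.value for kind in DatasetKind)
        raise ValueError(f"Unknown dataset '{name}', expected one of: {valid}") from None
```

For example, `parse_dataset("imagenet_reduced")` returns `DatasetKind.IMAGENET_REDUCED`, while `parse_dataset("imagenet")` raises a `ValueError` listing the valid names.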

## <a id='toc2_3_'></a>[Install & Import Libraries](#toc0_)

#### <a id='toc2_3_1_1_'></a>[Install Libraries](#toc0_)

We **install all the necessary libraries** for this notebook.

- **`pytorch-lightning`**: A **lightweight PyTorch wrapper** for simplifying PyTorch code.
- **`wandb`**: The Python package for **Weights & Biases**, a tool for experiment tracking, dataset versioning, and project collaboration (used for **logging and visualization**).
- **`pycocotools`**: A Python API for the **MS COCO dataset**.
- **`datasets`**: A library for easily **loading and preprocessing datasets** from the Hugging Face Hub.

In [None]:
%%capture
# Install the required packages (the %%capture cell magic must be on the first line of the cell)
%pip install pytorch-lightning
%pip install wandb

# Install the packages required for the dataset
if DATASET == DatasetType.COCO:
	%pip install pycocotools
else:
	%pip install datasets

#### <a id='toc2_3_1_2_'></a>[Import Modules](#toc0_)

We then **import the required modules**, including `PyTorch`, `PyTorch Lightning`, `pycocotools`/`datasets` and `W&B`, plus other useful modules and libraries (`NumPy`, `cv2`, etc.).

In [None]:
# Import the standard libraries
import os
import json
import random
import logging
import math

# Import the PyTorch and Lightning libraries and modules
import torch
import pytorch_lightning as pl

# Import the COCO or HuggingFace datasets module
if DATASET == DatasetType.COCO:
	from pycocotools.coco import COCO
else:
	import datasets as hf_datasets

# Import the W&B (Weights & Biases) library
import wandb
from wandb.sdk import wandb_run
from pytorch_lightning.loggers import WandbLogger

# Import other libraries
import numpy as np
import cv2
import matplotlib.pyplot as plt
from skimage import io
import base64

# Import the tqdm library (for the progress bars) 
if not RUNNING_IN_COLAB:
	from tqdm import tqdm
else:
	from tqdm.notebook import tqdm

We also import our own **custom modules** (cloned from the repository) containing Python classes for **datasets**, **models**, **evaluation**, and **utilities**.

In [None]:
# Import the custom modules
if not RUNNING_IN_COLAB:
	# We are running locally (not on Google Colab, import modules from the "src" directory in the current directory)
	from src.scripts import models, datasets, training, evaluation, utils	# type: ignore
	from src.scripts.utils import ( RANDOM_SEED, MODEL_CHECKPOINT_FILE )	# type: ignore
else:
	# We are running on Google Colab (import modules from the pulled repository stored in the project's directory)
	from computer_vision_project.src.scripts import models, datasets, training, evaluation, utils	# type: ignore
	from computer_vision_project.src.scripts.utils import ( RANDOM_SEED, MODEL_CHECKPOINT_FILE )	# type: ignore

## <a id='toc2_4_'></a>[Configuration, Hyperparameters and Constants](#toc0_)

#### <a id='toc2_4_1_1_'></a>[Random Seed](#toc0_)

We **seed the random number generators** for reproducibility.

In [None]:
# Set the random seeds for reproducibility
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
pl.seed_everything(RANDOM_SEED)

#### <a id='toc2_4_1_2_'></a>[Device Configuration](#toc0_)

We **set the device** to GPU if available, otherwise we use the CPU.

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device.type}")

#### <a id='toc2_4_1_3_'></a>[Database Constants](#toc0_)

We **define the constants** used for the **database resources download** and the **dataset creation**.

In [None]:
# Total number of images to consider in the dataset (will be split into training, validation and test sets), set to -1 to use all the available images
NUMBER_OF_IMAGES_IN_DB = 384

# Minimum number of images per class
# NOTE: For the "ImageNet Reduced" dataset, make sure to set MIN_IMAGES_PER_CLASS >= NUMBER_OF_IMAGES_IN_DB // 10
MIN_IMAGES_PER_CLASS = 64

# Percentage of images, for each class, to use for the image retrieval dataset (the remaining images will be used for the indexing dataset, i.e. will be added in the images database)
IMAGE_RETRIEVAL_DB_PERCENTAGE = 0.75

# Whether to shuffle the images before splitting them into training, validation and test sets (to make their IDs random)
SHUFFLE_DB_IMAGES = False

# MS COCO dataset constants
# NOTE: Use only if "DATASET" is set to "COCO"
COCO_DATA_YEAR = '2014'		# '2014' or '2017'
COCO_DATA_TYPE = 'val'		# 'train' or 'val'
COCO_DATA_CAPTIONS_FILE = f"/annotations/captions_{COCO_DATA_TYPE}{COCO_DATA_YEAR}.json"	# Path of the annotations file inside the DATA_FOLDER
CODO_DATA_INSTANCES_FILE = f"/annotations/instances_{COCO_DATA_TYPE}{COCO_DATA_YEAR}.json"	# Path of the instances file inside the DATA_FOLDER
MIN_COCO_IMAGE_CAPTIONS = 5		# Minimum number of captions for an image to be added in the dataset

# ImageNet dataset constants
# NOTE: Use only if "DATASET" is set to "imagenet_full", "imagenet_reduced" or "imagenet_tiny"
IMAGENET_DATA_FOLDER = "data/imagenet"	# Path of the ImageNet data folder, relative to the current working directory (not inside the DATA_FOLDER)
IMAGENET_DATA_CLASSES_FILE = f"{IMAGENET_DATA_FOLDER}/classes.py"	# Path of the classes file inside the IMAGENET_DATA_FOLDER

# Number of image patches per dimension (i.e. both vertically and horizontally, since images have a square aspect ratio)
# NOTE: Only works for COCO and full ImageNet datasets (not for ImageNet Tiny or ImageNet Reduced, which always have 64x64 and 160x160 images respectively, thus a number of patches per dimension that depends on the chosen IMAGE_PATCH_SIZE)
IMAGE_PATCHES_PER_DIMENSION = 16

# If not enough square images are found, also accept images that have this max aspect difference (they will be cropped to a square aspect ratio later)
# NOTE: Only works for COCO and full ImageNet datasets (not for ImageNet Tiny or ImageNet Reduced, which always have 64x64 and 160x160 images respectively, thus a fixed aspect ratio of 1)
MAX_ASPECT_RATIO_TOLERANCE = 0.1 	# Accept images that are 10% wider than they are tall (or vice versa)
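The `IMAGE_RETRIEVAL_DB_PERCENTAGE` and `SHUFFLE_DB_IMAGES` constants describe a per-class split. A minimal sketch of such a split, following the comment above (the first 75% of each class goes to the retrieval set, the rest to the indexing set); `split_per_class` is a hypothetical helper that assumes one label per image, and the repository's `datasets` module may implement the split differently:

```python
import random
from collections import defaultdict

def split_per_class(image_ids, labels, retrieval_fraction, shuffle=False, seed=0):
    """Split image IDs class by class into (retrieval_ids, indexing_ids)."""
    # Group the image IDs by their class label
    by_class = defaultdict(list)
    for image_id, label in zip(image_ids, labels):
        by_class[label].append(image_id)
    rng = random.Random(seed)
    retrieval_ids, indexing_ids = [], []
    for label in sorted(by_class):
        ids = by_class[label]
        if shuffle:
            rng.shuffle(ids)
        # The first `retrieval_fraction` of each class goes to retrieval, the rest to indexing
        cut = int(len(ids) * retrieval_fraction)
        retrieval_ids.extend(ids[:cut])
        indexing_ids.extend(ids[cut:])
    return retrieval_ids, indexing_ids
```

For 384 images spread evenly over 6 classes (64 each), this yields `int(64 * 0.75) = 48` retrieval images per class, i.e. 288 retrieval and 96 indexing images.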


#### <a id='toc2_4_1_4_'></a>[Models Hyperparameters](#toc0_)

We then **define the constants** representing the **hyperparameters** used for the **Vision Transformer** model.

In [None]:
# Size of the image patches
IMAGE_PATCH_SIZE = 16

# Dimensionality of the feature vectors given as input to the Vision Transformer
TRANSFORMER_EMBEDDINGS_SIZE = 128

# Dimensionality of the hidden layers in the feed-forward networks within the Transformer
TRANSFORMER_FNN_HIDDEN_SIZE = 256

# Number of heads to use in the Multi-Head Attention block
# NOTE: must be a divisor of "TRANSFORMER_EMBEDDINGS_SIZE"
TRANSFORMER_NUM_HEADS = 4

# Number of layers to use in the Transformer (i.e. the number of Multi-Head Attention blocks and Feed-Forward networks)
TRANSFORMER_NUM_LAYERS = 6

# Number of epochs to train the Transformer model for the indexing and retrieval tasks
TRANSFORMER_INDEXING_TRAINING_EPOCHS = 750
TRANSFORMER_RETRIEVAL_TRAINING_EPOCHS = 75

# Batch size for training the Transformer model
TRANSFORMER_BATCH_SIZE = 32

# Dropout rate to apply in the feed-forward network and on the input encoding
TRANSFORMER_DROPOUT = 0.1

# Learning rate for the optimizer
TRANSFORMER_LEARNING_RATE = 0.0005

# Whether to use learned positional encodings in the Transformer model (instead of the standard sinusoidal positional encodings)
LEARN_POSITIONAL_ENCODINGS = False
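The patch and head hyperparameters determine the Transformer's sequence length and per-head width. A small sketch of that arithmetic, assuming non-overlapping square patches over RGB images and not counting any extra class or ID token (`vit_shapes` is a hypothetical helper):

```python
def vit_shapes(image_size: int, patch_size: int, embed_dim: int, num_heads: int) -> dict:
    """Derive patch-grid and attention-head sizes for a square RGB image."""
    if image_size % patch_size != 0:
        raise ValueError("image size must be a multiple of the patch size")
    if embed_dim % num_heads != 0:
        raise ValueError("embedding size must be divisible by the number of heads")
    patches_per_dim = image_size // patch_size
    return {
        "patches_per_dim": patches_per_dim,
        "num_patches": patches_per_dim ** 2,         # sequence length fed to the encoder
        "patch_dim": patch_size * patch_size * 3,    # values in one flattened RGB patch
        "head_dim": embed_dim // num_heads,          # width of each attention head
    }
```

For COCO and full ImageNet images resized to `16 * 16 = 256` pixels per side, `vit_shapes(256, 16, 128, 4)` gives 256 patches of 768 values each and 32-dimensional heads; 64x64 Tiny ImageNet images give only 16 patches, and 160x160 images give 100.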


#### <a id='toc2_4_1_5_'></a>[Evaluation Constants](#toc0_)

We also define the constants used for the evaluation of the various models (i.e. to compute the **Indexing Accuracy**, **Mean Average Precision** and the **Recall at K**).

In [None]:
# Define the number of IDs to retrieve for each query image to calculate the IA@K results
IA_K = 5

# Define the number of images K to retrieve for each query image and the number of queries N to calculate the mean average precision (MAP@K)
MAP_K = 10
MAP_N = 10

# Define the number of images K to retrieve for each query image and the number of queries N to calculate the Recall@K results
RECALL_K = 100
RECALL_N = 10

# Whether to evaluate the models (i.e. compute the IA@K, MAP@K and Recall@K metrics for the trained models on the test datasets)
EVALUATE_MODELS = True

# Whether to print the debug information during the IA@K, MAP@K and Recall@K evaluation of the models
PRINT_EVALUATION_DEBUG = False
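For reference, common definitions of the three metrics are sketched below. This is an assumption about how the `evaluation` module computes them: Indexing Accuracy@K is taken as the fraction of indexed images whose own ID appears among the top-K generated IDs, AP@K divides the summed precision at each relevant hit by `min(|relevant|, K)` (other normalizations exist), and Recall@K is the fraction of relevant images found in the top K. Ranked lists are assumed to contain no duplicate IDs.

```python
def hit_at_k(ranked_ids, true_id, k):
    """1.0 if the image's own ID is among the top-k predictions (averaged, this gives IA@K)."""
    return 1.0 if true_id in ranked_ids[:k] else 0.0

def average_precision_at_k(ranked_ids, relevant_ids, k):
    """Sum of precision@rank at each relevant hit, normalized by min(|relevant|, k)."""
    relevant = set(relevant_ids)
    if not relevant:
        return 0.0
    hits, precision_sum = 0, 0.0
    for rank, image_id in enumerate(ranked_ids[:k], start=1):
        if image_id in relevant:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / min(len(relevant), k)

def recall_at_k(ranked_ids, relevant_ids, k):
    """Fraction of the relevant images that appear in the top-k predictions."""
    relevant = set(relevant_ids)
    if not relevant:
        return 0.0
    return len(relevant.intersection(ranked_ids[:k])) / len(relevant)

def mean_over_queries(metric, runs, k):
    """Average a metric over (ranked_ids, target) pairs, e.g. the first MAP_N queries."""
    return sum(metric(ranked, target, k) for ranked, target in runs) / len(runs)
```

For instance, with ranking `[3, 7, 1, 9]`, relevant set `{3, 9, 5}` and `K = 4`, hits occur at ranks 1 and 4, so AP@4 = (1/1 + 2/4) / 3 = 0.5 and Recall@4 = 2/3.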

#### Demo Data Constants

We define constants to test the model on demo data, which saves the time needed to create the datasets and the database.

To use demo data, set the `USE_DEMO_DATA` constant to `True` and set the `DEMO_DATA_SIZE` constant to the desired demo database size.

**❗NOTE**: Using demo data requires said data to already be available in the "**./src/demo/{`DEMO_DATA_SIZE`}**" directory.<br/>
If no demo data is available, create it by first running the notebook with `USE_DEMO_DATA` set to `False` for the desired demo data size, then copy the created dataset and database files into the "**./src/demo/{`DEMO_DATA_SIZE`}**" directory and set `USE_DEMO_DATA` to `True` (for the defined `DEMO_DATA_SIZE`).

In [None]:
# NOTE: Requires a demo data folder named "src/demo/{DEMO_DATA_SIZE}/", containing the demo data, to be created first in the project's directory

# Whether to load demo data from the "src/demo/" folder (set to False to build the dataset from the chosen DATASET using the above constants, or to load its existing version)
USE_DEMO_DATA = False

# Number of images in the demo dataset (used if USE_DEMO_DATA is set to True)
DEMO_DATA_SIZE = 100

# Folder containing the demo images and captions
DEMO_FOLDER = f"src/demo/{DEMO_DATA_SIZE}/" if not RUNNING_IN_COLAB else f"/content/computer_vision_project/src/demo/{DEMO_DATA_SIZE}/"

#### <a id='toc2_4_1_6_'></a>[Other Constants](#toc0_)

Finally, we define the constants that determine whether to print examples from the loaded datasets, where to save data and models, whether to force the database rebuild, whether to load model checkpoints, etc.

In [None]:

# Whether to print examples of the images and captions during the dataset creation
PRINT_EXAMPLES = True

# Define the data folder, in which the various dictionaries, lists and other data will be saved
DATA_FOLDER = "src/data" if not RUNNING_IN_COLAB else "/content/data"

# Define the path to save models
MODELS_FOLDER = "src/models" if not RUNNING_IN_COLAB else "/content/models"

# Force the creation of the "image_db" images list, the JSON files for the datasets, etc.
FORCE_DICTIONARIES_CREATION = False		# Set to false to try to load the dictionaries from the DATA_FOLDER if they exist

# Whether to load model checkpoints (if they were already saved locally) or not
LOAD_MODELS_CHECKPOINTS = True

# Whether to save PyTorch-specific datasets to the DATA_FOLDER (may lead to slower performance, higher RAM/disk usage, and running out of available memory while trying to save the datasets)
SAVE_PYTORCH_DATASETS = False

#### <a id='toc2_4_1_7_'></a>[HuggingFace Tokens Configuration](#toc0_)

We set the **HuggingFace** API token to use for loading datasets from the HuggingFace Hub.

**⚠️ Note**: Set the `HF_TOKEN` constant to your own HuggingFace API token only when using the **`IMAGENET_FULL`** dataset, as it requires authentication to download the dataset (unlike the **`IMAGENET_REDUCED`** and **`IMAGENET_TINY`** datasets, which can be downloaded without a HuggingFace API token).

In [None]:
# Define the HuggingFace API Access Token (for the ImageNet full dataset)
# NOTE: Only needed in case "DATASET" is set to "imagenet_full"
HF_TOKEN = None

#### <a id='toc2_4_1_8_'></a>[Weights & Biases Configuration](#toc0_)

We set the **Weights & Biases** API key to log the experiments.

**⚠️ Note**: Copy and paste your own W&B API key into the `WANDB_API_KEY` constant to see logging results, or set the constant to an empty string to disable W&B logging (this won't plot training losses and accuracies over time).

In [None]:
# Define the WANDB_API_KEY (set to "" to disable W&B logging)
# NOTE: leaving WANDB_API_KEY set to None will raise an error
WANDB_API_KEY = None
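Pasting keys into the notebook risks committing them to the repository. A hedged alternative is to read them from environment variables; `read_secret` is a hypothetical helper, and both `wandb` and `huggingface_hub` also pick up the `WANDB_API_KEY` and `HF_TOKEN` environment variables on their own:

```python
import os

def read_secret(name, default=None):
    """Return the value of the environment variable `name`, or `default` if it is unset or empty."""
    value = os.environ.get(name)
    return value if value else default
```

For example, `WANDB_API_KEY = read_secret("WANDB_API_KEY", default="")` falls back to disabling W&B logging instead of raising, and `HF_TOKEN = read_secret("HF_TOKEN")` keeps the token out of the notebook.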

We configure the **Weights & Biases** logger and API to track the experiments and the model's performances.

In [None]:
# Define the project name to use for the W&B logging
WANDB_PROJECT_NAME = "vit-dsi-project"

# Define the wandb logger, api object, entity name and project name
wandb_logger = None
wandb_api = None
wandb_entity = None
wandb_project = None
# Check if a W&B api key is provided
if WANDB_API_KEY is None:
	print("No W&B API key provided, please provide a valid key to use the W&B API or set the WANDB_API_KEY variable to an empty string to disable logging")
	raise ValueError("No W&B API key provided...")
elif WANDB_API_KEY != "":
	# Login to the W&B (Weights & Biases) API
	wandb.login(key=WANDB_API_KEY, relogin=True)
	# Minimize the logging from the W&B (Weights & Biases) library
	os.environ["WANDB_SILENT"] = "true"
	logging.getLogger("wandb").setLevel(logging.ERROR)
	# Initialize the W&B (Weights & Biases) logger
	wandb_logger = WandbLogger(log_model="all", project=WANDB_PROJECT_NAME, name="- SEPARATOR -")
	# Initialize the W&B (Weights & Biases) API
	wandb_api = wandb.Api()
	# Get the W&B (Weights & Biases) entity name
	wandb_entity = wandb_logger.experiment.entity
	# Get the W&B (Weights & Biases) project name
	wandb_project = wandb_logger.experiment.project
	# Finish the "separator" experiment
	wandb_logger.experiment.finish(quiet=True)
	print("W&B API key provided, logging with W&B enabled.")
else:
	print("No W&B API key provided, logging with W&B disabled.")

#### <a id='toc2_4_1_9_'></a>[Local Files Folder Creation](#toc0_)

We create the folders that store the data dictionaries and the models' checkpoints.

In [None]:
# Create folders if they do not exist
if not os.path.exists(DATA_FOLDER):
	print(f"Creating the data folder at '{DATA_FOLDER}'...")
	os.makedirs(DATA_FOLDER)
if not os.path.exists(MODELS_FOLDER):
	print(f"Creating the models folder at '{MODELS_FOLDER}'...")
	os.makedirs(MODELS_FOLDER)
if DATASET != DatasetType.COCO:
	if not os.path.exists(IMAGENET_DATA_FOLDER):
		print(f"Creating the ImageNet data folder at '{IMAGENET_DATA_FOLDER}'...")
		os.makedirs(IMAGENET_DATA_FOLDER)
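The check-then-create pattern above can be condensed with `os.makedirs(..., exist_ok=True)`, which also creates missing parent folders and does not fail if the folder appears between the check and the call. A minimal sketch (`ensure_dirs` is a hypothetical helper):

```python
import os

def ensure_dirs(*paths):
    """Create each directory (and any missing parents); existing directories are left untouched."""
    for path in paths:
        os.makedirs(path, exist_ok=True)
```

For example, `ensure_dirs(DATA_FOLDER, MODELS_FOLDER)`, adding `IMAGENET_DATA_FOLDER` when `DATASET` is not COCO.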

---



# <a id='toc3_'></a>[💾 Data Preparation](#toc0_)

## <a id='toc3_1_'></a>[Download Data & Resources](#toc0_)

#### <a id='toc3_1_1_1_'></a>[Download Datasets Associated Files and Libraries](#toc0_)

Download all the files and resources needed for the project based on the chosen `DATASET` type.

In [None]:
# Download necessary files and libraries for the datasets
if DATASET == DatasetType.COCO:
	# Check if the annotation file for the COCO dataset exists, if it does not exist, download it
	!cd {DATA_FOLDER} && wget -nc http://images.cocodataset.org/annotations/annotations_trainval{COCO_DATA_YEAR}.zip
	!cd {DATA_FOLDER} && unzip -n annotations_trainval{COCO_DATA_YEAR}.zip
else:
	# Download the classes.py file from https://huggingface.co/datasets/zh-plus/tiny-imagenet/resolve/main/classes.py
	!cd {IMAGENET_DATA_FOLDER} && wget -nc https://huggingface.co/datasets/zh-plus/tiny-imagenet/resolve/main/classes.py
	from data.imagenet import classes as imagenet_classes
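The `wget -nc` calls skip files that already exist. Where shell escapes are unavailable (e.g. when running the cells as a plain Python script), a comparable standard-library sketch is shown below; `download_if_missing` is a hypothetical helper and, unlike `wget`, it performs no retries:

```python
import os
import urllib.request

def download_if_missing(url, dest_folder):
    """Download `url` into `dest_folder` unless a file with the same name already exists (like `wget -nc`)."""
    os.makedirs(dest_folder, exist_ok=True)
    # Use the last path component of the URL (without any query string) as the file name
    filename = os.path.basename(url.split("?")[0])
    dest_path = os.path.join(dest_folder, filename)
    if not os.path.exists(dest_path):
        urllib.request.urlretrieve(url, dest_path)
    return dest_path
```

For example, `download_if_missing("https://huggingface.co/datasets/zh-plus/tiny-imagenet/resolve/main/classes.py", IMAGENET_DATA_FOLDER)` fetches `classes.py` only on the first run.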

#### <a id='toc3_1_1_2_'></a>[Datasets Download](#toc0_)

We download the chosen **ImageNet** or **MS COCO** dataset based on the `DATASET` type and store the corresponding data in the `coco_captions`, `coco_instances` and `imagenet_data` object variables (only if `USE_DEMO_DATA` is set to `False`).

We then print general information about the downloaded dataset.

In [None]:
# Define the variables to store the MS COCO or ImageNet datasets
coco_captions = None
coco_instances = None
imagenet_data = None
# Build the chosen datasets or load the demo dataset
if not USE_DEMO_DATA:
	if DATASET == DatasetType.COCO:
		# Initialize the COCO api for captioning
		coco_captions = COCO(f"{DATA_FOLDER}{COCO_DATA_CAPTIONS_FILE}")
		# Initialize the COCO api for object detection
		coco_instances = COCO(f"{DATA_FOLDER}{CODO_DATA_INSTANCES_FILE}")
		# Show the COCO dataset info for the captioning task
		print("\nCOCO captioning dataset infos:")
		coco_captions.info()
		# Show the information for the captioning task
		print("\nCOCO captioning task infos:")
		coco_caps = coco_captions.dataset['annotations']
		print("Number of images: ", len(coco_captions.getImgIds()))
		print("Number of captions: ", len(coco_caps))
		print("Average number of captions per image: ", len(coco_caps) / len(coco_captions.getImgIds()))
		# Show the COCO dataset info for the object detection task
		print("\nCOCO object detection dataset infos:")
		coco_instances.info()
		# Show the information for the object detection task
		print("\nCOCO object detection task infos:")
		coco_objs = coco_instances.dataset['annotations']
		print("Number of images: ", len(coco_instances.getImgIds()))
		print("Number of objects: ", len(coco_objs))
		print("Number of categories: ", len(coco_instances.cats))
		print("Categories:")
		utils.print_json(coco_instances.cats)
	elif DATASET == DatasetType.IMAGENET_TINY:
		# Get the ImageNet dataset using huggingface's tiny-imagenet-200 dataset at https://huggingface.co/datasets/zh-plus/tiny-imagenet
		imagenet_data = hf_datasets.load_dataset("zh-plus/tiny-imagenet")
		# Show the ImageNet dataset info
		print("\nImageNet Tiny dataset infos:")
		print(imagenet_data)
	elif DATASET == DatasetType.IMAGENET_FULL:
		# Print an error message if the HF_TOKEN is not set
		if HF_TOKEN == None or HF_TOKEN == "":
			print("Please set the \"HF_TOKEN\" constant to a valid HuggingFace API token to download the ImageNet full dataset or change dataset type using the \"DATASET\" constant.")
			raise ValueError("No HuggingFace API token provided...")
		# Get the ImageNet dataset using huggingface's imagenet-1k dataset at https://huggingface.co/datasets/ILSVRC/imagenet-1k
		imagenet_data = hf_datasets.load_dataset("ILSVRC/imagenet-1k", token=HF_TOKEN, trust_remote_code=True)
		# Show the ImageNet dataset info
		print("\nImageNet Full dataset infos:")
		print(imagenet_data)
	elif DATASET == DatasetType.IMAGENET_REDUCED:
		# Get the ImageNet dataset using huggingface's imagenette dataset at https://huggingface.co/datasets/frgfm/imagenette
		imagenet_data = hf_datasets.load_dataset("frgfm/imagenette", "160px")
		# Show the ImageNet dataset info
		print("\nImageNet Reduced dataset infos:")
		print(imagenet_data)

#### <a id='toc3_1_1_3_'></a>[Datasets Examples](#toc0_)

We print some examples of the images and their corresponding labels from the downloaded dataset (if the `PRINT_EXAMPLES` flag is set to `True`).

In [None]:
# Print some examples from the dataset
if not USE_DEMO_DATA and PRINT_EXAMPLES:

	if DATASET == DatasetType.COCO:
		# Print the first image object example
		example_image_index = 0
		print("\nImage object example: ")
		image_example = coco_captions.loadImgs(coco_caps[example_image_index]['image_id'])[0]
		utils.print_json(image_example, 2)
		# Print the actual image file
		print("\nActual image of the example (size: " + str(image_example['width']) + "x" + str(image_example['height']) + "):")
		url = image_example['coco_url']
		image = io.imread(url)
		plt.axis('off')
		plt.imshow(image)
		plt.show()
		# Compute the maximum image size allowed by the model (patch size times patches per dimension)
		image_max_size = IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION
		# Crop the image to a square aspect ratio if it is not already square
		downscaled_image = image
		if image_example['width'] > image_example['height']:
			# Image is wider than tall, crop the sides
			crop_width = (image_example['width'] - image_example['height']) // 2
			downscaled_image = image[:, crop_width:crop_width+image_example['height']]
		elif image_example['height'] > image_example['width']:
			# Image is taller than wide, crop the top and bottom
			crop_height = (image_example['height'] - image_example['width']) // 2
			downscaled_image = image[crop_height:crop_height+image_example['width'], :]
		# Downscale the image to the maximum allowed size
		downscaled_image = cv2.resize(downscaled_image, (image_max_size, image_max_size))
		print("\nDownscaled & cropped image of the example (size: " + str(image_max_size) + "x" + str(image_max_size) + "):")
		plt.axis('off')
		plt.imshow(downscaled_image)
		plt.show()
		# Print the captions for the given image
		print("\nCaption examples for the given image: ")
		captions_for_image = coco_captions.loadAnns(coco_captions.getAnnIds(imgIds=image_example['id']))
		for j, caption in enumerate(captions_for_image, start=1):
			print(str(j) + ") " + caption['caption'].strip())
		# Print the captioning object example
		print("\nFirst caption object example:")
		utils.print_json(captions_for_image[0], 2)
		# Print information about the object detection task for the given image
		print("\nObject detection examples for the given image:")
		# Get the object detection annotations for the given image
		annotations_for_image = coco_instances.loadAnns(coco_instances.getAnnIds(imgIds=image_example['id']))
		print("List of the " + str(len(annotations_for_image)) + " object detection annotations for the given image (obtained using the 'coco_instances.loadAnns(image_annotation_id)' function):")
		for j, annotation in enumerate(annotations_for_image, start=1):
			print("\n> Annotation " + str(j) + ":")
			# Print the annotation object
			utils.print_json(annotation, 2, truncate_large_lists=10)
	else:
		# Print the first image object example
		example_image_index = 0
		print("\nImage object example: ")
		image_example = imagenet_data['train'][example_image_index]
		print("{")
		print("  'image':", image_example['image'])
		print("  'label':", image_example['label'])
		print("}")
		# Show the label name
		print("\nLabel of the example:")
		label = image_example['label']
		label_id = imagenet_data['train'].features['label'].int2str(label)
		print("Label index:", label)
		print("Label ID:", label_id)
		print("Label name:", get_imagenet_class_name(label_id))
		# Show the PIL image of the example
		print("\nActual image of the example:")
		image = image_example['image']
		plt.axis('off')
		plt.imshow(image)
		plt.show()
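The COCO example above crops the central square before resizing, and image selection accepts images whose aspect ratio satisfies `abs(width / height - 1) <= tolerance`. A NumPy-only sketch of both steps (hypothetical helper names; the resize is left to `cv2.resize` as in the cell above):

```python
import numpy as np

def within_aspect_tolerance(width: int, height: int, tolerance: float) -> bool:
    """True if width / height deviates from 1 by at most `tolerance`."""
    return abs(width / height - 1) <= tolerance

def center_crop_square(image: np.ndarray) -> np.ndarray:
    """Crop an (H, W) or (H, W, C) array to its central min(H, W) x min(H, W) square."""
    height, width = image.shape[:2]
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    return image[top:top + side, left:left + side]
```

With `MAX_ASPECT_RATIO_TOLERANCE = 0.1`, a 640x600 image is accepted (ratio about 1.067) while a 640x480 image (ratio about 1.333) is rejected; a 640x480 image would be cropped to its central 480x480 region.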


## <a id='toc3_2_'></a>[Database Creation](#toc0_)

#### <a id='toc3_2_1_1_'></a>[Image Database Creation](#toc0_)

We create the images database from the downloaded **ImageNet** or **MS COCO** dataset and store it in the `images` list, where each object represents an image with its associated information (ID, class labels, image data, captions, etc.).

Alternatively, if the `USE_DEMO_DATA` flag is set to `True`, we load the **images demo data** from the `DEMO_FOLDER` directory.
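Two fields of each image object deserve a closer look. For COCO, `image_classes` is built by summing the area and counting the instances of each category and sorting by total area; `image_data` holds the pixels as a Base64 string. A hedged sketch of both, where `aggregate_classes`, `encode_image` and `decode_image` are hypothetical names, the `names` mapping stands in for `coco_instances.cats`, and the Base64 round-trip assumes the raw `uint8` buffer plus its shape is stored (the repository may instead encode a compressed JPEG/PNG):

```python
import base64
from collections import defaultdict

import numpy as np

def aggregate_classes(annotations, names):
    """Group instance annotations by category, summing area and counting instances, largest area first."""
    totals = defaultdict(lambda: {"class_area": 0.0, "class_count": 0})
    for ann in annotations:
        entry = totals[ann["category_id"]]
        entry["class_area"] += ann["area"]
        entry["class_count"] += 1
    classes = [{"class_id": cid, "class_name": names[cid], **stats} for cid, stats in totals.items()]
    return sorted(classes, key=lambda c: c["class_area"], reverse=True)

def encode_image(image: np.ndarray) -> dict:
    """Serialize a uint8 image array into a JSON-friendly dict with a Base64 payload."""
    buffer = np.ascontiguousarray(image, dtype=np.uint8).tobytes()
    return {"shape": list(image.shape), "data": base64.b64encode(buffer).decode("ascii")}

def decode_image(record: dict) -> np.ndarray:
    """Inverse of encode_image: rebuild the uint8 array from its Base64 payload and shape."""
    raw = base64.b64decode(record["data"])
    return np.frombuffer(raw, dtype=np.uint8).reshape(record["shape"])
```

Storing the shape next to the payload is what makes the raw-buffer variant reversible; a compressed format would carry its own dimensions instead.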

In [None]:
# Build a dataset of images for the training of the Vision Transformer model

# List of image objects used for the training of the Vision Transformer model
images = []

# Flag used to detect whether the images_db list has been loaded or not
loaded_images_db = False

# Function that returns the list containing the images for the training of the Vision Transformer model
def get_coco_images_db(number_of_images, process_images=True):
	'''
		Builds a list of images for the training of the Vision Transformer model.

		Parameters:
			number_of_images (int): The number of images to include in the dataset (search is stopped when the number of images is reached), use -1 to include all available images
			process_images (bool): Whether to process the images (i.e. retrieve the actual image data, crop the images and compute their Base64 encodings to add to the list)
	'''
	# Structure of the images
	images_list_object = {
		"image_id": "",			# ID of the image (as found in the COCO dataset)
		"image_url": "",		# URL of the image
		"image_width": 0,		# The original image width
		"image_height": 0,		# The original image height
		"image_captions": [],	# List of captions for the image
		"image_classes": [		# List of classes for the image (i.e. detected objects, in the order of area size)
			{
				"class_id": 0,		# ID of the class (as found in the COCO dataset)
				"class_name": "",	# Name of the class
				"class_area": 0,	# Sum of the area of each instance of the class in the image
				"class_count": 0	# Number of instances of the class in the image
			}
		],	
		"image_data": ""		# Base64 string of the image
	}
	# Get the image ids
	img_ids = coco_captions.getImgIds()
	# Get the images
	images = []
	# Function that returns a list of images with the given aspect ratio tolerance
	def select_images_list(image_aspect_ratio_tolerance):
		# Get the images
		for img_id in img_ids:
			# Get the image object
			img_obj = coco_captions.loadImgs(img_id)[0]
			# Check if the size of the image is square or within the aspect ratio tolerance
			image_aspect_ratio = img_obj['width'] / img_obj['height']
			if abs(image_aspect_ratio - 1) > image_aspect_ratio_tolerance:
				continue
			# Check if the image is already in the images list
			if any(img['image_id'] == img_obj['id'] for img in images):
				continue
			# Get the image url
			img_url = img_obj['coco_url']
			# Get the captions for the image
			img_captions = []
			captions = coco_captions.loadAnns(coco_captions.getAnnIds(imgIds=img_obj['id']))
			for caption in captions:
				caption_text = caption['caption'].strip()
				if len(caption_text) > 1:
					img_captions.append(caption_text)
			# Discard the image if the number of captions is less than the minimum
			if len(img_captions) < MIN_COCO_IMAGE_CAPTIONS:
				continue
			# Discard the image if it has no classes
			classes = coco_instances.loadAnns(coco_instances.getAnnIds(imgIds=img_obj['id']))
			if len(classes) == 0:
				continue
			# Create a classes object with the fields: "class_id", "class_name", "class_area", "class_count"
			classes_obj = {}
			for class_obj in classes:
				class_id = class_obj['category_id']
				class_name = coco_instances.cats[class_id]['name']
				class_area = class_obj['area']
				if class_id not in classes_obj:
					classes_obj[class_id] = {
						"class_id": class_id,
						"class_name": class_name,
						"class_area": class_area,
						"class_count": 1,
					}
				else:
					classes_obj[class_id]['class_area'] += class_area
					classes_obj[class_id]['class_count'] += 1
			# Convert the classes object to a list
			classes_obj = list(classes_obj.values())
			# Sort the classes by area size
			classes_obj = sorted(classes_obj, key=lambda x: x['class_area'], reverse=True)
			# Add the image to the images list
			images_list_object = {
				"image_id": img_obj['id'],
				"image_url": img_url,
				"image_width": img_obj['width'],
				"image_height": img_obj['height'],
				"image_captions": img_captions,
				"image_classes": classes_obj,
				"image_data": None # Will be filled later
			}
			images.append(images_list_object)
			# Break if the number of images is reached
			if number_of_images >= 1 and len(images) >= number_of_images:
				break
		# Return the images list
		return images
	print("Selecting images with a square aspect ratio...")
	# Get the images that have a square aspect ratio first
	images = select_images_list(0)
	# Get the remaining images with the given aspect ratio tolerance
	if len(images) < number_of_images or number_of_images == -1:
		square_aspect_ratio_images = len(images)
		print("> Found " + str(square_aspect_ratio_images) + (" / " + str(number_of_images) if number_of_images > 0 else "" ) + " images with a square aspect ratio, looking for the remaining images...")
		print("Looking for remaining images with an aspect ratio within a tolerance of " + str(round(MAX_ASPECT_RATIO_TOLERANCE*100)) + "% (either a " + str(1 + MAX_ASPECT_RATIO_TOLERANCE) + " aspect ratio or a " + str(1 - MAX_ASPECT_RATIO_TOLERANCE) + " aspect ratio)...")
		images = select_images_list(MAX_ASPECT_RATIO_TOLERANCE)
		non_square_aspect_ratio_images = len(images) - square_aspect_ratio_images
		# Print the number of images found
		print("> Found " + str(non_square_aspect_ratio_images) + (" / " + str(number_of_images) if number_of_images > 0 else "" )  + " more images with an aspect ratio within a tolerance of " + str(round(MAX_ASPECT_RATIO_TOLERANCE*100)) + "%.")
	else:
		print("> Found " + str(len(images)) + (" / " + str(number_of_images) if number_of_images > 0 else "" ) + " images with a square aspect ratio.")
	# Print a message based on the number of images found
	if len(images) < number_of_images and number_of_images != -1:
		print("WARNING: Could not find enough images with the required aspect ratio tolerance, only " + str(len(images)) + " / " + str(number_of_images) + " images found.")
	else:
		print("DONE: Found all " + str(len(images)) + (" / " + str(number_of_images) if number_of_images > 0 else "" ) + " images with the required aspect ratio tolerance.")
	# Get all the image data
	if process_images:
		for img in tqdm(images, desc="Processing images data (computing BASE64 images encoding)..."):
			img["image_data"] = utils.get_image_data_as_base64(img['image_url'], IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION)
	# Return the images list
	return images

# Function that returns the list containing the images for the training of the Vision Transformer model
def get_imagenet_tiny_or_reduced_images_db(number_of_images, process_images=True):
	'''
		Builds a list of images for the training of the Vision Transformer model.

		Parameters:
			number_of_images (int): The number of images to include in the dataset (search is stopped when the number of images is reached), use -1 to include all available images
			process_images (bool): Whether to process the images (i.e. retrieve the actual image data, crop the images and compute their base64 encodings to add to the list)
	'''
	# Structure of the images
	images_list_object = {
		"image_id": "",			# ID of the image, i.e. the index of the image in the original Hugging Face dataset split
		"image_label_name": "",	# Name of the label of the image
		"image_label_index": 0,	# Index of the label of the image (as found in the Hugging Face ImageNet dataset, i.e. the index of the class inside the dataset)
		"image_label_id": "",	# ID of the label of the image (i.e. a string of the form "nX...X", with X being digits, identifying the corresponding class in the original ImageNet database)
		"image_data": ""		# Base64 string of the image
	}
	# Get the images
	images = []
	# Get the image ids
	img_ids = range(len(imagenet_data['train']))
	# Reference to the label ids
	label_ids = imagenet_data['train'].features['label']
	# Get the images, taking all of the images in the database and enriching them with the corresponding label names and ID
	for img_id in img_ids:
		# Get the image object
		img_obj = imagenet_data['train'][img_id]
		# Get the image PIL object
		img_pil_obj = img_obj['image']
		# Convert the PIL object to a base64 string
		img_data = None
		if process_images:
			image_size = -1
			if DATASET == DatasetType.IMAGENET_TINY:
				image_size = 64
			elif DATASET == DatasetType.IMAGENET_REDUCED:
				image_size = 160
			img_data = utils.get_image_data_as_base64(img_pil_obj, image_size)
		# Get the label of the image
		img_label = img_obj['label']
		# Get the label ID
		img_label_id = label_ids.int2str(img_label)
		# Get the label name
		img_label_name = get_imagenet_class_name(img_label_id)
		# Add the image to the images list
		images_list_object = {
			"image_id": img_id,
			"image_label_name": img_label_name,
			"image_label_index": img_label,
			"image_label_id": img_label_id,
			"image_data": img_data
		}
		images.append(images_list_object)
		# Break if the number of images is reached
		if number_of_images >= 1 and len(images) >= number_of_images:
			break
	# Return the images list
	return images

# Function that returns the list containing the images for the training of the Vision Transformer model
def get_imagenet_full_images_db(number_of_images, process_images=True):
	'''
		Builds a list of images for the training of the Vision Transformer model.

		Parameters:
			number_of_images (int): The number of images to include in the dataset (search is stopped when the number of images is reached), use -1 to include all available images
			process_images (bool): Whether to process the images (i.e. retrieve the actual image data, crop the images and compute their base64 encodings to add to the list)
	'''
	# Structure of the images
	images_list_object = {
		"image_id": "",			# ID of the image, i.e. the index of the image in the original Hugging Face dataset split
		"image_label_name": "",	# Name of the label of the image
		"image_label_index": 0,	# Index of the label of the image (as found in the Hugging Face ImageNet dataset, i.e. the index of the class inside the dataset)
		"image_label_id": "",	# ID of the label of the image (i.e. a string of the form "nX...X", with X being digits, identifying the corresponding class in the original ImageNet database)
		"image_width": 0,		# Original width of the image
		"image_height": 0,		# Original height of the image
		"image_data": ""		# Base64 string of the image
	}

	# Get the images
	images_to_return = []
	# Get the image ids
	img_ids = range(len(imagenet_data['train']))
	# Reference to the label ids
	label_ids = imagenet_data['train'].features['label']
	# Function that returns a list of images with the given aspect ratio tolerance
	def select_images_list(image_aspect_ratio_tolerance):
		# Get the images
		for img_id in img_ids:
			# Get the image object
			img_obj = imagenet_data['train'][img_id]
			# Get the image PIL object
			img_pil_obj = img_obj['image']
			# Get the image width and height from the PIL object
			img_width, img_height = img_pil_obj.size
			# Check if the size of the image is square or within the aspect ratio tolerance
			image_aspect_ratio = img_width / img_height
			if abs(image_aspect_ratio - 1) > image_aspect_ratio_tolerance:
				continue
			# Check if the image is already in the images list
			if any(img['image_id'] == img_id for img in images_to_return):
				continue
			# Convert the PIL object to a base64 string
			img_data = None
			if process_images:
				img_data = utils.get_image_data_as_base64(img_pil_obj, IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION)
			# Get the label of the image
			img_label = img_obj['label']
			# Get the label ID
			img_label_id = label_ids.int2str(img_label)
			# Get the label name
			img_label_name = get_imagenet_class_name(img_label_id)
			# Add the image to the images list
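			images_list_object = {
				"image_id": img_id,
				"image_label_name": img_label_name,
				"image_label_index": img_label,
				"image_label_id": img_label_id,
				# Original size of the image (read later when displaying the example image)
				"image_width": img_width,
				"image_height": img_height,
				"image_data": img_data
			}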
			images_to_return.append(images_list_object)
			# Break if the number of images is reached
			if number_of_images >= 1 and len(images_to_return) >= number_of_images:
				break
		# Return the images list
		return images_to_return
	print("Selecting images with a square aspect ratio...")
	# Get the images that already have a square aspect ratio first
	images_to_return = select_images_list(0)
	# Get the remaining images with the given aspect ratio tolerance
	if len(images_to_return) < number_of_images or number_of_images == -1:
		square_aspect_ratio_images = len(images_to_return)
		print("> Found " + str(square_aspect_ratio_images) + (" / " + str(number_of_images) if number_of_images > 0 else "" ) + " images with a square aspect ratio, looking for the remaining images...")
		print("Looking for remaining images with an aspect ratio within a tolerance of " + str(round(MAX_ASPECT_RATIO_TOLERANCE*100)) + "% (either a " + str(1 + MAX_ASPECT_RATIO_TOLERANCE) + " aspect ratio or a " + str(1 - MAX_ASPECT_RATIO_TOLERANCE) + " aspect ratio)...")
		images_to_return = select_images_list(MAX_ASPECT_RATIO_TOLERANCE)
		non_square_aspect_ratio_images = len(images_to_return) - square_aspect_ratio_images
		# Print the number of images found
		print("> Found " + str(non_square_aspect_ratio_images) + (" / " + str(number_of_images) if number_of_images > 0 else "" )  + " more images with an aspect ratio within a tolerance of " + str(round(MAX_ASPECT_RATIO_TOLERANCE*100)) + "%.")
	else:
		print("> Found " + str(len(images_to_return)) + (" / " + str(number_of_images) if number_of_images > 0 else "" ) + " images with a square aspect ratio.")
	# Print a message based on the number of images found
	if len(images_to_return) < number_of_images and number_of_images != -1:
		print("WARNING: Could not find enough images with the required aspect ratio tolerance, only " + str(len(images_to_return)) + " / " + str(number_of_images) + " images found.")
	else:
		print("DONE: Found all " + str(len(images_to_return)) + (" / " + str(number_of_images) if number_of_images > 0 else "" ) + " images with the required aspect ratio tolerance.")
	# Return the images list
	return images_to_return

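# Flag used to detect whether the images list has been loaded from a file or not
loaded_images_db = False
if not USE_DEMO_DATA: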
	# Check if the images list should be rebuilt or loaded
	images_db_file = os.path.join(DATA_FOLDER, "images_db.json")
	if os.path.exists(images_db_file) and not FORCE_DICTIONARIES_CREATION:
		with open(images_db_file, 'r') as f:
			images = json.load(f)
		loaded_images_db = True
		print("Loaded the images list from the file: ", images_db_file)
	else:
		if DATASET == DatasetType.COCO:
			# Initialize the images list
			images = get_coco_images_db(-1, False)
		elif DATASET == DatasetType.IMAGENET_TINY or DATASET == DatasetType.IMAGENET_REDUCED:
			# Initialize the images list
			images = get_imagenet_tiny_or_reduced_images_db(-1, False)
		elif DATASET == DatasetType.IMAGENET_FULL:
			# Initialize the images list
			images = get_imagenet_full_images_db(-1, False)
		# Save the images list to a JSON file
		images_db_file = os.path.join(DATA_FOLDER, "images_db.json")
		print("Saving the images list to the file: ", images_db_file)
		with open(images_db_file, 'w') as f:
			json.dump(images, f)

	# Print the final number of images in the dataset
	print("\nNumber of loaded images in the dataset: " + str(len(images)) + ("/" + str(NUMBER_OF_IMAGES_IN_DB) if NUMBER_OF_IMAGES_IN_DB != -1 else ""))
	
else:
	# Load the images_db.json file from the demo folder
	images_db_file = os.path.join(DEMO_FOLDER, "images_db.json")
	if os.path.exists(images_db_file):
		with open(images_db_file, 'r') as f:
			images = json.load(f)
		print("Loaded the images list from the file: ", images_db_file)
	else:
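		print("ERROR: Could not load the demo images list from the file: ", images_db_file)
		raise FileNotFoundError("Could not load the demo images list from the file: " + images_db_file)

# Name used by the following cells to refer to the images list (later replaced by the final, filtered list)
images_to_return = images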
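The two-pass selection implemented by the `select_images_list` helpers (square images first, then images whose width/height ratio is within `MAX_ASPECT_RATIO_TOLERANCE` of 1) can be summarized with the minimal sketch below. It is an illustration under simplifying assumptions, not the notebook's code: `select_by_aspect_ratio` is a hypothetical helper working on plain `(width, height)` pairs, and it tracks already-selected images with a `set` instead of scanning the selected list.

```python
from typing import List, Tuple


def select_by_aspect_ratio(sizes: List[Tuple[int, int]], max_images: int, tolerance: float) -> List[int]:
    """Return image indexes, preferring square images, then those within the tolerance.

    `sizes` is a list of (width, height) pairs; `max_images` = -1 means "no limit".
    """
    selected: List[int] = []
    seen = set()  # O(1) "already selected" check (hypothetical helper, not the notebook's code)

    def one_pass(pass_tolerance: float) -> None:
        for idx, (width, height) in enumerate(sizes):
            # Stop as soon as the requested number of images is reached
            if max_images >= 1 and len(selected) >= max_images:
                return
            if idx in seen:
                continue
            # Skip images whose aspect ratio is too far from 1 for this pass
            if abs(width / height - 1) > pass_tolerance:
                continue
            selected.append(idx)
            seen.add(idx)

    one_pass(0.0)  # first pass: square images only
    if max_images == -1 or len(selected) < max_images:
        one_pass(tolerance)  # second pass: images within the aspect ratio tolerance
    return selected
```

For example, with a tolerance of `0.25`, a 120x100 image (ratio 1.2) is only picked in the second pass, after all the square images.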

#### <a id='toc3_2_1_2_'></a>[Classes Definition](#toc0_)

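We populate the `classes` dictionary, mapping each class to the IDs (i.e. the indexes in the `images` list) of the images in which it appears, using the images loaded from the **ImageNet** or **MS COCO** dataset (depending on `DATASET`). Classes with fewer than `MIN_IMAGES_PER_CLASS` images are discarded, and the remaining ones are sorted by number of images.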

Alternatively, if the `USE_DEMO_DATA` flag is set to `True`, we load the **classes demo data** from the `DEMO_FOLDER` directory.
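The construction of the `classes` dictionary is an inverted index from class IDs to image indexes, followed by a size filter and a sort. A minimal sketch, assuming each image is reduced to the list of its distinct class IDs (a single label for ImageNet, possibly several for MS COCO); `build_classes_dict` is a hypothetical helper, not part of the notebook:

```python
from typing import Dict, List


def build_classes_dict(image_labels: List[List[int]], min_images_per_class: int) -> Dict[int, List[int]]:
    """Map each class ID to the indexes of the images it appears in.

    `image_labels[i]` holds the distinct class IDs of image i. Classes with fewer than
    `min_images_per_class` images are dropped; the rest are sorted by size, largest first.
    """
    classes: Dict[int, List[int]] = {}
    for image_index, labels in enumerate(image_labels):
        for class_id in labels:
            classes.setdefault(class_id, []).append(image_index)
    # Keep only the classes with enough images
    kept = {k: v for k, v in classes.items() if len(v) >= min_images_per_class}
    # sorted() is stable, so classes of equal size keep their first-seen order
    return dict(sorted(kept.items(), key=lambda item: len(item[1]), reverse=True))
```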

In [None]:
# Create a "classes" dictionary with the classes found in the dataset and, for each of them, a list of the images in which they appear

# Initialize the classes dictionary
classes = {}

# Flag used to detect whether the classes dictionary has been loaded or not
loaded_classes = False

if not USE_DEMO_DATA:
	# Function to get the classes dictionary from the images
	def get_classes_dict():
		# Initialize the classes list
		classes = {}
		if DATASET == DatasetType.COCO:
			# Get the classes from the images
			for i in tqdm(range(len(images)), desc="Processing images for classes..."):
				img = images[i]
				for class_obj in img['image_classes']:
					# Get the class id
					class_id = class_obj['class_id']
					# Add the class to the classes list if it does not exist
					if class_id not in classes.keys():
						classes[class_id] = []
					# Add the image index to the class list
					classes[class_id].append(i)
		else:
			# Get the classes from the images
			for i in tqdm(range(len(images)), desc="Processing images for classes..."):
				img = images[i]
				# Get the class id
				class_id = img['image_label_index']
				# Add the class to the classes list if it does not exist
				if class_id not in classes.keys():
					classes[class_id] = []
				# Add the image index to the class list
				classes[class_id].append(i)
		print("Created the classes list from the images with " + str(len(classes)) + " classes.")
		# Discard the classes with fewer than the minimum number of images
		total_classes_count = len(classes)
		classes = {k: v for k, v in classes.items() if len(v) >= MIN_IMAGES_PER_CLASS}
		# Sort classes by the number of images (largest first)
		classes = {k: v for k, v in sorted(classes.items(), key=lambda item: len(item[1]), reverse=True)}
		print("Discarded the classes with fewer than " + str(MIN_IMAGES_PER_CLASS) + " images: " + str(len(classes)) + " / " + str(total_classes_count) + " classes remaining.")
		# Return the classes list
		return classes

	# Get the classes dictionary if it already exists, otherwise create it
	classes_file = os.path.join(DATA_FOLDER, "classes.json")
	if os.path.exists(classes_file) and not FORCE_DICTIONARIES_CREATION:
		with open(classes_file, 'r') as f:
			classes = json.load(f)
		if len(classes) > 0:
			loaded_classes = True
			print("Loaded the classes dictionary from the file: ", classes_file)
	if not loaded_classes:
		print("Creating the classes dictionary from the images...")
		classes = get_classes_dict()
		# Save the classes dictionary to a JSON file
		print("Saving the classes dictionary to the file: ", classes_file)
		with open(classes_file, 'w') as f:
			json.dump(classes, f)
else:
	# Load the classes.json file from the demo folder
	classes_file = os.path.join(DEMO_FOLDER, "classes.json")
	if os.path.exists(classes_file):
		with open(classes_file, 'r') as f:
			classes = json.load(f)
		print("Loaded the classes dictionary from the file: ", classes_file)
	else:
		print("ERROR: Could not load the demo classes dictionary from the file: ", classes_file)
		raise FileNotFoundError("Could not load the demo classes dictionary from the file: " + classes_file)
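Note that JSON object keys are always strings: a `classes` dictionary saved to `classes.json` with integer keys is loaded back with keys such as `"3"` instead of `3`, so it no longer matches a freshly built dictionary. A minimal sketch of a round trip that restores the integer keys, assuming all class keys are integers (as ImageNet label indexes and MS COCO category IDs are); `save_and_reload_classes` is a hypothetical helper:

```python
import json
from typing import Dict, List


def save_and_reload_classes(classes: Dict[int, List[int]]) -> Dict[int, List[int]]:
    """Round-trip a classes dictionary through JSON, converting the keys back to int."""
    text = json.dumps(classes)   # integer keys are written as JSON strings
    reloaded = json.loads(text)  # ...and come back as Python str keys
    return {int(class_id): image_ids for class_id, image_ids in reloaded.items()}
```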

We print the `classes` dictionary, showing each class with its image IDs (each list is truncated to its first 10 entries).

In [None]:
# Print the classes dictionary
if PRINT_EXAMPLES:
	print("\nClasses dictionary:")
	utils.print_json(classes, 2, truncate_large_lists=10)

#### <a id='toc3_2_1_3_'></a>[Database Finalization](#toc0_)

We update the `images` list to only contain the designated `NUMBER_OF_IMAGES_IN_DB` images, making sure we include at least `MIN_IMAGES_PER_CLASS` images for each class.
If `SHUFFLE_DB_IMAGES` is set to `True`, we shuffle the images in the database so that they are assigned random IDs, and remap the image IDs stored in the `classes` dictionary accordingly.
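Shuffling changes every image ID, so each index stored in `classes` must be mapped from the image's old position to its new one. A minimal sketch, assuming a seeded `random.Random` generator for reproducibility (the notebook uses the global `random` module); `shuffle_db` is a hypothetical helper. The old-to-new mapping is the inverse of the shuffled permutation, built once as a dictionary rather than looked up with `list.index` for every image:

```python
import random
from typing import Dict, List, Tuple


def shuffle_db(images: List[str], classes: Dict[int, List[int]], seed: int = 0) -> Tuple[List[str], Dict[int, List[int]]]:
    """Shuffle the images list and remap the image IDs stored in `classes` accordingly."""
    rng = random.Random(seed)  # seeded generator (an assumption, for reproducibility)
    order = list(range(len(images)))
    rng.shuffle(order)  # after shuffling: order[new_id] == old_id
    shuffled = [images[old_id] for old_id in order]
    # Inverse permutation (old_id -> new_id), built once in O(n)
    new_id_of = {old_id: new_id for new_id, old_id in enumerate(order)}
    remapped = {class_id: [new_id_of[i] for i in ids] for class_id, ids in classes.items()}
    return shuffled, remapped
```

Whatever the seed, every remapped ID still points to the same image it pointed to before the shuffle.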

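We also restrict the `classes` dictionary to its first `math.ceil(NUMBER_OF_IMAGES_IN_DB / MIN_IMAGES_PER_CLASS)` classes (i.e. the largest ones), which is the number of classes needed to reach `NUMBER_OF_IMAGES_IN_DB` images with `MIN_IMAGES_PER_CLASS` images each.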
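For the single-label ImageNet variants, the class budget and the per-class selection can be sketched as follows. This is a simplified sketch, assuming `classes` is already sorted by size (largest first); `select_images_per_class` is a hypothetical helper that returns the selected original image indexes. As in the notebook, the last class contributes all of its selected images, so the total can exceed `max_images` when it is not a multiple of `min_images_per_class`:

```python
import math
from typing import Dict, List


def select_images_per_class(classes: Dict[int, List[int]], max_images: int, min_images_per_class: int) -> List[int]:
    """Keep the first `min_images_per_class` images of the largest classes, up to `max_images` images."""
    # Number of classes needed to reach max_images with min_images_per_class images each
    classes_count = math.ceil(max_images / min_images_per_class)
    selected: List[int] = []
    for class_id in list(classes)[:classes_count]:
        selected.extend(classes[class_id][:min_images_per_class])
        # Stop once the budget is reached (the last class may overshoot it)
        if len(selected) >= max_images:
            break
    return selected
```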

We then compute the **BASE64** encoding of each image in the `images` list and store it as a string in the corresponding JSON file in the `DATA_FOLDER` directory.

We finally print information about the final number of images and classes included in the database.

In [None]:
# Update the images in the images DB to finally only include the images that have the classes in the classes list, with MIN_IMAGES_PER_CLASS images per class, and to populate the images list with the base64 encoding of the images

if not USE_DEMO_DATA and not loaded_images_db and not loaded_classes:

	# Function to update the images list to only include the images that have the classes in the classes list
	def update_images_db_based_on_classes(max_images):
		# Number of classes to maintain the designated number of images
		classes_count = math.ceil(max_images / MIN_IMAGES_PER_CLASS)
		# Initialize the new images list
		new_images_db = []
		images_db_ids_map = {}
		# Get the classes to maintain the designated number of images
		classes_to_maintain = list(classes.keys())[:classes_count]
		# Create a new classes list with the classes found in the new images list
		new_classes = {}
		# Get the images to maintain the designated number of images
		if DATASET == DatasetType.COCO:
			# Select the first "max_images" images that have the classes in the classes list
			for i in tqdm(range(len(images_to_return)), desc="Processing images for classes..."):
				img = images_to_return[i]
				# Check if the image has any of the classes to maintain
				maintain_image = any(class_obj['class_id'] in classes_to_maintain for class_obj in img['image_classes'])
				if maintain_image:
					new_images_db.append(img)
					images_db_ids_map[i] = len(new_images_db) - 1
				# Break if the number of images is reached
				if len(new_images_db) >= max_images:
					break
		else:
			# Select the first images of each class to maintain until we reach the maximum number of images "max_images"
			images_per_class = MIN_IMAGES_PER_CLASS
			for class_id in tqdm(classes_to_maintain, desc="Processing images for classes..."):
				# Get the images of the class
				images_of_class = classes[class_id]
				# Get the number of images to maintain for the class
				images_to_maintain = min(images_per_class, len(images_of_class))
				# Add the images to the new images list
				for i in range(images_to_maintain):
					img = images_to_return[images_of_class[i]]
					new_images_db.append(img)
					images_db_ids_map[images_of_class[i]] = len(new_images_db) - 1
				# Break if the number of images is reached
				if len(new_images_db) >= max_images:
					break
		# Get the classes from the classes to maintain
		for class_id in classes_to_maintain:
			# Keep only the images that are in the new images list, remapped to their indexes in it
			new_classes[class_id] = [images_db_ids_map[i] for i in classes[class_id] if i in images_db_ids_map]
		# Sort the classes by the number of images
		new_classes = { k: v for k, v in sorted(new_classes.items(), key=lambda item: len(item[1]), reverse=True) }
		# Shuffle the images in the database, if needed, and make the classes list consistent with the new images list "IDs", i.e. the index of the image in the new images list
		if SHUFFLE_DB_IMAGES:
			# Get a list of all IDs of images in the new images list, i.e. a list of indexes to then use for shuffling
			images_ids = list(range(len(new_images_db)))
			# Shuffle the images IDs
			random.shuffle(images_ids)
			# Shuffle the images list
			shuffled_images_db = [new_images_db[i] for i in images_ids]
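			# Update the classes list to use the new image IDs (inverse permutation: old index -> new index)
			new_image_ids = {old_id: new_id for new_id, old_id in enumerate(images_ids)}
			for class_id in new_classes.keys():
				new_classes[class_id] = [new_image_ids[i] for i in new_classes[class_id]]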
			# Update the new images list
			new_images_db = shuffled_images_db
		# Return the new images list and the new classes list
		return new_images_db, new_classes

	# Update the images list to only include the images that have the classes in the classes list
	if not loaded_images_db or not loaded_classes:
		max_images = NUMBER_OF_IMAGES_IN_DB if NUMBER_OF_IMAGES_IN_DB != -1 else len(images)
		print("\nUpdating the images list to only include the " + str(len(classes)) + " classes with at least " + str(MIN_IMAGES_PER_CLASS) + " images, not exceeding " + str(max_images) + " images...")
		images_to_return, classes = update_images_db_based_on_classes(max_images)
		print("DONE: Updated the images list, now containing " + str(len(images_to_return)) + " images.")
		print("> Final number of classes in the dataset: " + str(len(classes)))
		# Update the images list to include the base64 encoding of the images
		print("Computing the BASE64 images encoding for the images list...")
		if DATASET == DatasetType.COCO:
			max_image_size = IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION
			for img in tqdm(images_to_return, desc="Processing images data (computing BASE64 images encoding)..."):
				img["image_data"] = utils.get_image_data_as_base64(img['image_url'], max_image_size)
		else:
			# Target image size (in pixels) for each ImageNet variant
			max_image_size = -1
			if DATASET == DatasetType.IMAGENET_TINY:
				max_image_size = 64
			elif DATASET == DatasetType.IMAGENET_FULL:
				max_image_size = IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION
			elif DATASET == DatasetType.IMAGENET_REDUCED:
				max_image_size = 160
			for img in tqdm(images_to_return, desc="Processing images data (computing BASE64 images encoding)..."):
				# Skip the images whose data has already been computed
				if img["image_data"] is not None:
					continue
				# "image_id" is the index of the image in the original dataset split
				pil_img_obj = imagenet_data['train'][img["image_id"]]['image']
				img["image_data"] = utils.get_image_data_as_base64(pil_img_obj, max_image_size)
		print("DONE: Computed the BASE64 images encoding for the images list.")
		# Save the updated images list to a JSON file
		images_db_file = os.path.join(DATA_FOLDER, "images_db.json")
		print("Saving the updated images list to the file: ", images_db_file)
		with open(images_db_file, 'w') as f:
			json.dump(images_to_return, f)

actual_image_patches_per_dimension = -1
if DATASET == DatasetType.COCO or DATASET == DatasetType.IMAGENET_FULL:
	actual_image_patches_per_dimension = IMAGE_PATCHES_PER_DIMENSION
elif DATASET == DatasetType.IMAGENET_TINY:
	actual_image_patches_per_dimension = 64 // IMAGE_PATCH_SIZE
elif DATASET == DatasetType.IMAGENET_REDUCED:
	actual_image_patches_per_dimension = 160 // IMAGE_PATCH_SIZE

# Print the final number of images in the dataset
print("\nNumber of loaded images in the dataset: " + str(len(images_to_return)))
print("Number of classes in the dataset: " + str(len(classes)))	
print("\nImage patch size: " + str(IMAGE_PATCH_SIZE))
print("Number of image patches per dimension: " + str(actual_image_patches_per_dimension))


We print the final `classes` dictionary and `images` list (if the `PRINT_EXAMPLES` flag is set to `True`).

We also print an example image and its representation from the point of view of the **Vision Transformer** model (i.e. divided into **patches** of size `IMAGE_PATCH_SIZE`).

In [None]:
if PRINT_EXAMPLES:
	# Print the classes list
	if DATASET == DatasetType.COCO:
		total_classes = -1 if USE_DEMO_DATA else len(coco_instances.cats)
		print("\nClasses list sorted by number of images (" + str(len(classes)) + " classes out of " + (str(total_classes) if total_classes != -1 else "???") + " total MS COCO classes):")
		utils.print_json(classes, 2, truncate_large_lists=10)
	else:
		total_classes = -1 if USE_DEMO_DATA else imagenet_data['train'].features['label'].num_classes
		print("\nClasses list sorted by number of images (" + str(len(classes)) + " classes out of " + (str(total_classes) if total_classes != -1 else "???") + " total ImageNet classes):")
		utils.print_json(classes, 2, truncate_large_lists=10)
	# Print the final images list
	print("\nAll images list:")
	utils.print_json(images_to_return, 2, truncate_large_lists=5)
	# Print all the distinct classes that actually appear in the final images list
	if DATASET == DatasetType.COCO:
		actual_db_classes = list(set([class_obj['class_name'] for img in images_to_return for class_obj in img['image_classes']]))
	else:
		actual_db_classes = list(set([image["image_label_name"] for image in images_to_return]))
	print("\nAll classes in DB list (" + str(len(actual_db_classes)) + " classes found, " + str(len(classes)) + " classes in the classes dictionary):")
	print(actual_db_classes)
	# Print an example image object (the last one in the images list)
	example_image_index = -1
	print("\nImage object example: ")
	utils.print_json(images_to_return[example_image_index], 2)
	# Print the actual image file
	image_b64_string = None
	if images_to_return[example_image_index]['image_data'] is not None:
		image_b64_string = images_to_return[example_image_index]['image_data']
	else:
		# The second argument of get_image_data_as_base64 is the image size in pixels (as in the previous cells)
		example_image_size = actual_image_patches_per_dimension * IMAGE_PATCH_SIZE
		if DATASET == DatasetType.COCO:
			image_b64_string = utils.get_image_data_as_base64(images_to_return[example_image_index]['image_url'], example_image_size)
		else:
			# "image_id" is the index of the image in the original dataset split
			imagenet_image_id = images_to_return[example_image_index]['image_id']
			image_b64_string = utils.get_image_data_as_base64(imagenet_data['train'][imagenet_image_id]['image'], example_image_size)
	image = utils.get_image_from_b64_string(image_b64_string)
	if DATASET == DatasetType.COCO or DATASET == DatasetType.IMAGENET_FULL:
		image_width = images_to_return[example_image_index]['image_width']
		image_height = images_to_return[example_image_index]['image_height']
		print("\nActual image of the example (original size: " + str(image_width) + "x" + str(image_height) + " | downsampled size: " + str(image.shape[1]) + "x" + str(image.shape[0]) + "):")
	elif DATASET == DatasetType.IMAGENET_TINY:
		print("\nActual image of the example (original size: 64x64 | downsampled size: " + str(image.shape[1]) + "x" + str(image.shape[0]) + "):")
	elif DATASET == DatasetType.IMAGENET_REDUCED:
		print("\nActual image of the example (original size: 160x160 | downsampled size: " + str(image.shape[1]) + "x" + str(image.shape[0]) + "):")
	plt.axis('off')
	plt.imshow(image)
	plt.show()
	# Print how the Transformer model sees the image
	print("\nHow the Transformer model sees the image (downsampled size: " + str(image.shape[1]) + "x" + str(image.shape[0]) + "):")
	# Divide the image into smaller images representing the patches, then show the separated patches
	image_patches = []
	for i in range(actual_image_patches_per_dimension):
		for j in range(actual_image_patches_per_dimension):
			image_patch = image[i*IMAGE_PATCH_SIZE:(i+1)*IMAGE_PATCH_SIZE, j*IMAGE_PATCH_SIZE:(j+1)*IMAGE_PATCH_SIZE]
			image_patches.append(image_patch)
	# Show the image patches in a grid
	if actual_image_patches_per_dimension > 1:
		fig, axs = plt.subplots(actual_image_patches_per_dimension, actual_image_patches_per_dimension, figsize=(10, 10))
		for i in range(actual_image_patches_per_dimension):
			for j in range(actual_image_patches_per_dimension):
				axs[i, j].axis('off')
				axs[i, j].imshow(image_patches[i*actual_image_patches_per_dimension+j])
	else:
		plt.axis('off')
		plt.imshow(image_patches[0])
	plt.show()
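The nested loops in the cell above cut the image into `actual_image_patches_per_dimension` x `actual_image_patches_per_dimension` patches of `IMAGE_PATCH_SIZE` x `IMAGE_PATCH_SIZE` pixels, in row-major order. As a sketch, the same split can be vectorized with NumPy `reshape`/`transpose`. This is a swapped-in technique rather than the notebook's code, `image_to_patches` is a hypothetical helper, and it assumes a `(height, width, channels)` array, which matches how `image.shape` is used above:

```python
import numpy as np


def image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into non-overlapping patch_size x patch_size patches.

    Returns an array of shape (num_patches, patch_size, patch_size, C) in row-major
    order, i.e. the same order as the nested loops used to display the patch grid.
    """
    height, width, channels = image.shape
    rows, cols = height // patch_size, width // patch_size
    # Drop any border pixels that do not fill a whole patch
    image = image[: rows * patch_size, : cols * patch_size]
    # (rows, patch, cols, patch, C) -> (rows, cols, patch, patch, C)
    patches = image.reshape(rows, patch_size, cols, patch_size, channels).transpose(0, 2, 1, 3, 4)
    return patches.reshape(rows * cols, patch_size, patch_size, channels)
```

Flattening each patch into a vector of `IMAGE_PATCH_SIZE * IMAGE_PATCH_SIZE * C` values gives the sequence of patch vectors that a ViT linearly projects into its input tokens.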

## <a id='toc3_3_'></a>[Indexing & Retrieval Datasets](#toc0_)

#### <a id='toc3_3_1_1_'></a>[Training Datasets Creation](#toc0_)

We split the original `images` list into separate datasets for the **indexing** and **retrieval** tasks.

Two different datasets are created:

- **Indexing Dataset**: A dataset to train the ViT model for the **indexing task**, in which the model learns to generate image IDs starting from **images' patches** as source sequences.

   Items of the dataset have the form **`(db_image, image_id)`** where `db_image` is an image included in our indexed database, and `image_id` is the ID of the image in the indexed database (either "clustered" or "random" based on the `SHUFFLE_DB_IMAGES` flag).

   The indexed database images are stored in an **`images_db_indexing`** list variable, containing image objects corresponding to the images in the indexed database.

- **Retrieval Dataset**: A dataset to train the ViT model for the **retrieval task**, in which the model learns to generate image IDs starting from **image queries' patches** as source sequences.

   Items of the dataset have the form **`(query_image, image_id)`** where `query_image` is an image that is NOT included in our indexed database, and `image_id` is the ID of an image in the indexed database that is considered relevant to the query image.

   The query images are stored in an **`images_db_retrieval`** dictionary variable, where the keys are the image IDs of queries found in the `images` list, and the values are lists of relevant image IDs from the indexed database.
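The split described above can be sketched on plain image IDs. In this minimal sketch (`split_indexing_retrieval` is a hypothetical helper), the first `1 - retrieval_fraction` share of each class (mirroring `IMAGE_RETRIEVAL_DB_PERCENTAGE`) is indexed, and every remaining image becomes a query whose relevant images are the indexed images of its class, expressed as indexing-database IDs. Images already indexed through another class (possible with multi-label MS COCO images) are not added twice:

```python
from typing import Dict, List, Tuple


def split_indexing_retrieval(classes: Dict[int, List[int]], retrieval_fraction: float) -> Tuple[List[int], Dict[int, List[int]]]:
    """Split each class into indexed images and query images.

    Returns (indexed, queries): `indexed` lists original image IDs in indexing order, so the
    position of an image in it is its indexing-DB ID; `queries` maps each query image ID to
    the indexing-DB IDs of the indexed images of its class (its relevant images).
    """
    indexed: List[int] = []
    position: Dict[int, int] = {}  # original image ID -> indexing-DB ID
    queries: Dict[int, List[int]] = {}
    for image_ids in classes.values():
        n_indexed = int(len(image_ids) * (1 - retrieval_fraction))
        for img_id in image_ids[:n_indexed]:
            if img_id not in position:  # index each image only once
                position[img_id] = len(indexed)
                indexed.append(img_id)
        relevant = [position[i] for i in image_ids[:n_indexed]]
        for img_id in image_ids[n_indexed:]:
            queries[img_id] = list(relevant)  # independent copy for each query
    return indexed, queries
```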

In [None]:
# Split the images list into a list for the indexing dataset (i.e. images in the database) and a dictionary for the image retrieval dataset (i.e. query images, each mapped to its similar images in the DB)

# List of images for the indexing and image retrieval datasets
images_db_indexing = []	# List of images for the indexing dataset
images_db_retrieval = {} # Dictionary containing image IDs of images NOT in the indexing dataset as keys and the list of similar images in the indexing dataset as values

# Dictionary for mapping image IDs in the images_db (i.e. indexes of its items) into the IDs of the images that are only in the indexing dataset (i.e. their indexes in the images_db_indexing list)
map_images_db_to_indexing_db = {}
# Dictionary for the reverse mapping (from the indexing dataset to the original dataset)
map_indexing_db_to_images_db = {}

# Check if the indexing and image retrieval datasets should be rebuilt or loaded
can_load_dictionaries = False
if not USE_DEMO_DATA:
	images_db_indexing_file = os.path.join(DATA_FOLDER, "images_db_indexing.json")
	images_db_image_retrieval_file = os.path.join(DATA_FOLDER, "images_db_image_retrieval.json")
	if os.path.exists(images_db_indexing_file) and os.path.exists(images_db_image_retrieval_file) and not FORCE_DICTIONARIES_CREATION:
		can_load_dictionaries = True

# Build or load the indexing and image retrieval datasets
if not USE_DEMO_DATA:
	if not can_load_dictionaries:
		# Create the indexing and image retrieval datasets from the images list
		for class_id_or_index in classes.keys():
			class_obj = classes[class_id_or_index]
			indexing_number = int(len(class_obj) * (1 - IMAGE_RETRIEVAL_DB_PERCENTAGE))
			predicted_similar_images = []
			for j in range(len(class_obj)):
				is_in_db = j < indexing_number
				img_id = class_obj[j]
				# Get the image object
				img = images_to_return[img_id]
				if is_in_db:
					# Add the image to the indexing dataset only once (in MS COCO the same image can belong to multiple classes)
					if img_id not in map_images_db_to_indexing_db:
						# Add the image to the indexing dataset as an image object (containing various image infos)
						images_db_indexing.append(img)
						# Store the remapping of the image IDs for the image retrieval dataset (i.e. the index of the image in the indexing dataset)
						#	NOTE: this is done so that we can consider ONLY images in our indexing datasets as our final images in the dataset
						map_images_db_to_indexing_db[img_id] = len(images_db_indexing) - 1
					# Add the image to the similar images list
					predicted_similar_images.append(img_id)
				else:
					# Add the image to the image retrieval dataset (as the index of the image in the indexing dataset, not the original image ID)
					images_db_retrieval[img_id] = [ map_images_db_to_indexing_db[i] for i in predicted_similar_images ]
		# Store the reverse mapping of the image IDs for the image retrieval dataset
		map_indexing_db_to_images_db = { v: k for k, v in map_images_db_to_indexing_db.items() }
		# Save the indexing and image retrieval datasets to JSON files, along with both the remapped image IDs dictionaries
		images_db_indexing_file = os.path.join(DATA_FOLDER, "images_db_indexing.json")
		print("Saving the images list for the indexing dataset to the file: ", images_db_indexing_file)
		with open(images_db_indexing_file, 'w') as f:
			json.dump(images_db_indexing, f)
		images_db_image_retrieval_file = os.path.join(DATA_FOLDER, "images_db_image_retrieval.json")
		print("Saving the images list for the image retrieval dataset to the file: ", images_db_image_retrieval_file)
		with open(images_db_image_retrieval_file, 'w') as f:
			json.dump(images_db_retrieval, f)
		remapped_indexing_image_ids_file = os.path.join(DATA_FOLDER, "remapped_indexing_image_ids.json")
		print("Saving the remapped image IDs for the image retrieval dataset to the file: ", remapped_indexing_image_ids_file)
		with open(remapped_indexing_image_ids_file, 'w') as f:
			json.dump(map_images_db_to_indexing_db, f)
		remapped_indexing_image_ids_reverse_file = os.path.join(DATA_FOLDER, "remapped_indexing_image_ids_reverse.json")
		print("Saving the reverse remapped image IDs for the image retrieval dataset to the file: ", remapped_indexing_image_ids_reverse_file)
		with open(remapped_indexing_image_ids_reverse_file, 'w') as f:
			json.dump(map_indexing_db_to_images_db, f)
		print("DONE: Created the images list for the indexing and image retrieval datasets.")
	else:
		# Load the indexing and image retrieval datasets from the JSON files
		with open(images_db_indexing_file, 'r') as f:
			images_db_indexing = json.load(f)
		print("Loaded " + str(len(images_db_indexing)) + " images for the indexing dataset from the file: ", images_db_indexing_file)
		with open(images_db_image_retrieval_file, 'r') as f:
			images_db_retrieval = json.load(f)
		print("Loaded " + str(len(images_db_retrieval)) + " images for the image retrieval dataset from the file: ", images_db_image_retrieval_file)
		# Load the remapped image IDs for the image retrieval dataset
		remapped_indexing_image_ids_file = os.path.join(DATA_FOLDER, "remapped_indexing_image_ids.json")
		with open(remapped_indexing_image_ids_file, 'r') as f:
			map_images_db_to_indexing_db = json.load(f)
		print("Loaded the remapped image IDs for the image retrieval dataset from the file: ", remapped_indexing_image_ids_file)
		remapped_indexing_image_ids_reverse_file = os.path.join(DATA_FOLDER, "remapped_indexing_image_ids_reverse.json")
		with open(remapped_indexing_image_ids_reverse_file, 'r') as f:
			map_indexing_db_to_images_db = json.load(f)
		print("Loaded the reverse remapped image IDs for the image retrieval dataset from the file: ", remapped_indexing_image_ids_reverse_file)
else:
	# Load the images_db_indexing.json and images_db_image_retrieval.json files from the demo folder
	images_db_indexing_file = os.path.join(DEMO_FOLDER, "images_db_indexing.json")
	images_db_image_retrieval_file = os.path.join(DEMO_FOLDER, "images_db_image_retrieval.json")
	if os.path.exists(images_db_indexing_file):
		with open(images_db_indexing_file, 'r') as f:
			images_db_indexing = json.load(f)
		print("Loaded " + str(len(images_db_indexing)) + " images for the indexing dataset from the file: ", images_db_indexing_file)
	else:
		print("ERROR: Could not load the demo images list for the indexing dataset from the file: ", images_db_indexing_file)
		raise FileNotFoundError("Could not load the demo images list for the indexing dataset from the file: " + images_db_indexing_file)
	if os.path.exists(images_db_image_retrieval_file):
		with open(images_db_image_retrieval_file, 'r') as f:
			images_db_retrieval = json.load(f)
		print("Loaded " + str(len(images_db_retrieval)) + " images for the image retrieval dataset from the file: ", images_db_image_retrieval_file)
	else:
		print("ERROR: Could not load the demo images list for the image retrieval dataset from the file: ", images_db_image_retrieval_file)
		raise FileNotFoundError("Could not load the demo images list for the image retrieval dataset from the file: " + images_db_image_retrieval_file)

# Compute the max length of the image IDs (the ID of an image is its index in the "images_db_indexing" list)
max_image_id_length = len(str(len(images_db_indexing)))
# Number of output tokens for the encoded image IDs (the 10 digits [0-9] plus the 3 special tokens, i.e. end of sequence, padding, start of sequence)
output_tokens = 10 + 3

# Print the final number of images in the datasets
print("\nNumber of images in the indexing dataset: " + str(len(images_db_indexing)))
print("Image IDs max length: " + str(max_image_id_length))
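
The per-class split performed above can be illustrated on toy data. The following is a minimal, standard-library-only sketch, not the notebook's code: the function name `split_classes` is hypothetical, `classes` is assumed to map each class to a list of integer image IDs, and image objects are replaced by their IDs. It also shows why later cells check whether dictionary keys are integers or strings: JSON object keys are always strings, so integer keys come back as strings after a save/load round trip.

```python
import json


def split_classes(classes, retrieval_fraction):
    """Hypothetical per-class split into an indexing DB and a retrieval (query) DB.

    Returns:
      indexing_db:   original image IDs, in the order they were indexed
      retrieval_db:  query image ID -> indexing-DB positions of the indexed images of its class
      to_indexing:   original image ID -> position in indexing_db
      from_indexing: position in indexing_db -> original image ID
    """
    indexing_db, retrieval_db, to_indexing = [], {}, {}
    for image_ids in classes.values():
        # The first (1 - retrieval_fraction) share of each class is indexed
        n_indexed = int(len(image_ids) * (1 - retrieval_fraction))
        class_positions = []
        for img_id in image_ids[:n_indexed]:
            indexing_db.append(img_id)
            to_indexing[img_id] = len(indexing_db) - 1
            class_positions.append(to_indexing[img_id])
        # The remaining images become queries; their relevant results are the indexed images of the same class
        for img_id in image_ids[n_indexed:]:
            retrieval_db[img_id] = list(class_positions)
    from_indexing = {pos: img_id for img_id, pos in to_indexing.items()}
    return indexing_db, retrieval_db, to_indexing, from_indexing


classes = {"cat": [0, 1, 2, 3], "dog": [4, 5, 6, 7]}
indexing_db, retrieval_db, to_indexing, from_indexing = split_classes(classes, retrieval_fraction=0.25)
print(indexing_db)   # [0, 1, 2, 4, 5, 6]
print(retrieval_db)  # {3: [0, 1, 2], 7: [3, 4, 5]}
# Same rule as the notebook: number of digits of the indexing DB size
max_image_id_length = len(str(len(indexing_db)))
# After a JSON round trip, integer keys become strings
print(json.loads(json.dumps(retrieval_db)))  # {'3': [0, 1, 2], '7': [3, 4, 5]}
```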


#### <a id='toc3_3_1_2_'></a>[Transformer Datasets Creation](#toc0_)

Starting from the `images_db_indexing` and `images_db_retrieval` datasets, we create the **PyTorch Datasets** to be used for training the **Vision Transformer** model:

- **`TransformerIndexingDataset`**: Items of the dataset have the form **`(db_image, image_id)`**, where `db_image` is a 3D RGB tensor representing an image included in our indexed database, and `image_id` is the encoded representation of the image's ID: a "Begin of Sequence" token, followed by the sequence of digits of the image ID, padded to the given `max_image_id_length` using the "Padding" token, and terminated by an "End of Sequence" token (each token being represented as a vector).

- **`TransformerRetrievalDataset`**: Items of the dataset have the form **`(query_image, image_id)`** where `query_image` is a 3D RGB tensor representing an image that is NOT included in our indexed database, and `image_id` is the encoded representation of the ID of an image in the indexed database that is considered relevant to the query image (encoded with the same approach as before).
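
The image ID encoding described above can be sketched as follows. These are hypothetical helpers, not the notebook's `datasets` code: they assume the token values set in the model configuration further down (start = 10, padding = 11, end = 12), padding placed between the digits and the end token as described, and plain integer tokens instead of the vector representation used by the datasets.

```python
START_TOKEN, PADDING_TOKEN, END_TOKEN = 10, 11, 12


def encode_image_id(image_id, max_length):
    """Encode an integer image ID as [START, digits..., PADDING..., END]."""
    digits = [int(d) for d in str(image_id)]
    if len(digits) > max_length:
        raise ValueError(f"image ID {image_id} has more than {max_length} digits")
    padding = [PADDING_TOKEN] * (max_length - len(digits))
    return [START_TOKEN] + digits + padding + [END_TOKEN]


def decode_image_id(tokens):
    """Recover the integer image ID, ignoring the special tokens."""
    digits = [str(t) for t in tokens if 0 <= t <= 9]
    return int("".join(digits)) if digits else None


encoded = encode_image_id(42, max_length=4)
print(encoded)                   # [10, 4, 2, 11, 11, 12]
print(decode_image_id(encoded))  # 42
```

Every encoded ID has length `max_length + 2`, so all targets in a batch share the same shape without any extra padding.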

In [None]:
# Build the indexing and image retrieval datasets to be used for training the Vision Transformer model

# Paths of the files in which the PyTorch datasets will be stored or from which they will be loaded
transformer_indexing_dataset_file = None
transformer_image_retrieval_dataset_file = None
if SAVE_PYTORCH_DATASETS:
	datasets_folder = DATA_FOLDER if not USE_DEMO_DATA else DEMO_FOLDER
	transformer_indexing_dataset_file = os.path.join(datasets_folder, "transformer_indexing_dataset.json")
	transformer_image_retrieval_dataset_file = os.path.join(datasets_folder, "transformer_image_retrieval_dataset.json")

# Build the Transformer Indexing Dataset for training the Vision Transformer
transformer_indexing_dataset = datasets.TransformerIndexingDataset(
	images=images_db_indexing,
	patch_size=IMAGE_PATCH_SIZE,
	img_patches=actual_image_patches_per_dimension,
	img_id_max_length=max_image_id_length,
	dataset_file_path=transformer_indexing_dataset_file,
	force_dataset_rebuild=FORCE_DICTIONARIES_CREATION and not USE_DEMO_DATA
)

# Build the Transformer Image Retrieval Dataset for training the Vision Transformer
transformer_image_retrieval_dataset = datasets.TransformerRetrievalDataset(
	all_images=images_to_return,
	similar_images=images_db_retrieval,
	patch_size=IMAGE_PATCH_SIZE,
	img_patches=actual_image_patches_per_dimension,
	img_id_max_length=max_image_id_length,
	dataset_file_path=transformer_image_retrieval_dataset_file,
	force_dataset_rebuild=FORCE_DICTIONARIES_CREATION and not USE_DEMO_DATA,
	images_db_to_indexing_db=map_images_db_to_indexing_db,
	indexing_db_to_images_db=map_indexing_db_to_images_db
)

We print some **examples** of the **Indexing Dataset** and the **Retrieval Dataset** to visualize the data (if the `PRINT_EXAMPLES` flag is set to `True`).

In [None]:
if PRINT_EXAMPLES:
	# Print a random example from the Transformer Indexing Dataset
	example_index = random.randint(0, len(transformer_indexing_dataset)-1)
	print("Example from the Transformer Indexing Dataset:")
	print("<encoded_image, encoded_image_id> tuple:")
	print(transformer_indexing_dataset[example_index])
	# Print a random example from the Transformer Image Retrieval Dataset
	example_index = random.randint(0, len(transformer_image_retrieval_dataset)-1)
	print("\nExample from the Transformer Image Retrieval Dataset:")
	print("<encoded_image, encoded_similar_image_id> tuple:")
	print(transformer_image_retrieval_dataset[example_index])

# <a id='toc4_'></a>[🛍️ Bag of Visual Words Model](#toc0_)

### <a id='toc4_1_1_'></a>[BoVW Model Initialization](#toc0_)

We initialize our baseline **Bag of Visual Words** model, based on the **SIFT** detector and the **K-Means** clustering algorithm.<br/>
Unlike the **Vision Transformer** model, the **BoVW** model requires no gradient-based training: it relies on a fixed pipeline of feature extraction, unsupervised clustering, and histogram generation.

The steps to initialize the BoVW model are as follows:

1. **Feature Extraction**: We extract the **SIFT** features from all the images in the indexed database, represented as a list of **keypoints** with their corresponding vector **descriptors**;

2. **Clustering**: We cluster the extracted features using the **K-Means** algorithm (we use `K`=150) to create a **visual words vocabulary** of `K` visual words (K-Means centroids);

3. **Histogram Generation**: We generate the **histograms** for each image in the indexed database, representing the frequency of SIFT features assigned to each visual word in the vocabulary (BoVW representation of the indexed database);

4. **Database Indexing**: We store the BoVW histograms in a dictionary, where the keys are the image IDs and the values are the corresponding BoVW representations (histograms).

At **inference** time, given an image query `Q`:

1. We extract the **SIFT** features from the query image `Q` as a list of **keypoints** with their corresponding vector **descriptors**;

2. We assign each feature to the closest visual word in the vocabulary, generating a histogram representing the frequency of SIFT features assigned to each visual word (BoVW representation of the query image `Q`);

3. We iterate over all the BoVW representations of the indexed database images, computing the **cosine similarity** between the BoVW representation of the query image `Q` and each indexed database image (**index-then-retrieve** image retrieval pipeline);

4. We return the **sorted list of image IDs** based on the computed cosine similarities.

**NOTE:** The **BoVW** model uses grayscale images, so we convert the RGB images to grayscale before extracting the SIFT features.
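
The indexing and retrieval steps above can be sketched with the standard library only. This is a simplified illustration, not the notebook's `models.BoVW` class: the K-Means step is assumed to have already produced the `vocabulary` of centroids, SIFT extraction is skipped, and toy 2-D descriptors stand in for the 128-dimensional SIFT descriptors (with `K` = 2 instead of 150).

```python
import math


def nearest_word(descriptor, vocabulary):
    """Index of the closest visual word (squared Euclidean distance)."""
    return min(
        range(len(vocabulary)),
        key=lambda k: sum((d - c) ** 2 for d, c in zip(descriptor, vocabulary[k])),
    )


def bovw_histogram(descriptors, vocabulary):
    """Count how many descriptors are assigned to each visual word."""
    hist = [0] * len(vocabulary)
    for desc in descriptors:
        hist[nearest_word(desc, vocabulary)] += 1
    return hist


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def retrieve(query_descriptors, index, vocabulary, k):
    """Rank indexed image IDs by cosine similarity to the query histogram."""
    query_hist = bovw_histogram(query_descriptors, vocabulary)
    scores = {img_id: cosine_similarity(query_hist, hist) for img_id, hist in index.items()}
    return sorted(scores, key=scores.get, reverse=True)[:k]


vocabulary = [[0.0, 0.0], [10.0, 10.0]]  # K = 2 visual words (K-Means centroids)
database = {
    0: [[0.0, 1.0], [1.0, 0.0]],
    1: [[9.0, 9.0], [10.0, 11.0]],
    2: [[0.0, 0.0], [10.0, 10.0]],
}
# Database indexing: image ID -> BoVW histogram
index = {img_id: bovw_histogram(descs, vocabulary) for img_id, descs in database.items()}
print(index)  # {0: [2, 0], 1: [0, 2], 2: [1, 1]}
print(retrieve([[1.0, 1.0], [0.5, 0.0], [9.0, 10.0]], index, vocabulary, k=3))  # [2, 0, 1]
```

The query histogram here is `[2, 1]`, which is closest in direction to image 2's `[1, 1]`; the linear scan over `index` is what makes this an index-then-retrieve pipeline.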


In [None]:
# Create a BoVW model for the image retrieval dataset
print("\nCreating a BoVW model for the image retrieval dataset...")
BOVW_model = models.BoVW(
	all_images=images_to_return,
	indexed_images=images_db_indexing,
	kmeans_clusters=150
)
print("DONE: Created the BoVW model for the image retrieval dataset.")

# <a id='toc5_'></a>[🤖 Vision Transformer Model (DSI approach)](#toc0_)

### <a id='toc5_1_1_'></a>[ViT Model Initialization](#toc0_)

We initialize the **Vision Transformer** model for the **Differentiable Search Index (DSI)** approach with the given hyperparameters and configuration (set by the various constants defined in the **constants** section of this Notebook).

The model is based on the **Sequence to Sequence** architecture, where the **encoder** processes the image patches and the **decoder** generates the image IDs.

Images are divided into **patches** of size `IMAGE_PATCH_SIZE`, each represented as a **3D RGB tensor**.
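
As an illustration of patch splitting, the following hypothetical helper (not the notebook's code) cuts an `H x W x C` image, stored as nested lists, into non-overlapping square patches in row-major order and flattens each patch into a vector, as is common in ViT implementations before the linear patch embedding. It assumes `H` and `W` are multiples of the patch size `P`; the number of patches is then `(H / P) * (W / P)`, which for a square image equals the `num_patches` value (patches per dimension squared) set in the model configuration.

```python
def image_to_patches(image, patch_size):
    """Split an H x W x C image (nested lists) into flattened patch_size x patch_size patches."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be multiples of patch_size")
    patches = []
    for top in range(0, height, patch_size):        # patch rows
        for left in range(0, width, patch_size):    # patch columns
            patch = [
                value
                for row in image[top:top + patch_size]
                for pixel in row[left:left + patch_size]
                for value in pixel
            ]
            patches.append(patch)
    return patches


# 4 x 4 RGB toy image where each pixel stores [row, column, 0]
image = [[[r, c, 0] for c in range(4)] for r in range(4)]
patches = image_to_patches(image, patch_size=2)
print(len(patches), len(patches[0]))  # 4 12
print(patches[1])  # [0, 2, 0, 0, 3, 0, 1, 2, 0, 1, 3, 0]
```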

Image IDs are represented as a sequence of digits, starting from a **"Begin of Sequence"** token, followed by the sequence of digits representing the image ID, padded to the given `max_image_id_length` using the corresponding **"Padding"** token, and ending with an **"End of Sequence"** token: each of these tokens is represented as a vector.

The model is trained using:
- The **TransformerIndexingDataset** to first learn to output image IDs given indexed images as input using "**Masked Attention**";

- The **TransformerRetrievalDataset** to learn to output image IDs given query images as input using the knowledge acquired during the indexing task.
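
A common way to train a sequence-to-sequence decoder with masked attention is teacher forcing: the decoder receives the target sequence shifted right, and a causal mask prevents each position from attending to later tokens. The sketch below illustrates both ideas on an encoded image ID; it is a generic assumption about how such training works, not a description of the internals of `DSI_VisionTransformer`, and it reuses the token values 10/11/12 (start/padding/end) from the model configuration.

```python
def teacher_forcing_pair(encoded_id):
    """Decoder input (all tokens but the last) and target (all tokens but the first)."""
    return encoded_id[:-1], encoded_id[1:]


def causal_mask(length):
    """mask[i][j] is True when position i may attend to position j, i.e. j <= i."""
    return [[j <= i for j in range(length)] for i in range(length)]


# [START, 4, 2, PADDING, END] encodes the image ID 42 with a max length of 3 digits
decoder_input, target = teacher_forcing_pair([10, 4, 2, 11, 12])
print(decoder_input)  # [10, 4, 2, 11]
print(target)         # [4, 2, 11, 12]
for row in causal_mask(3):
    print(row)
# [True, False, False]
# [True, True, False]
# [True, True, True]
```

At position `i`, the decoder predicts `target[i]` having seen only `decoder_input[:i+1]`, which is exactly the information available during auto-regressive generation.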

At **inference** time, given an image query `Q`, the model generates the image ID **auto-regressively**, digit by digit: starting from the patches of `Q` followed by the **Begin of Sequence** token, at each step it returns the **logits** (unnormalized log-probabilities) of the possible output tokens, and the selected token is appended to the sequence before the next step.
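
The generation loop can be sketched as greedy decoding. Here `next_token_logits` is a hypothetical stand-in for one decoder step of the trained model (given the tokens generated so far for a fixed query image, it returns 13 logits); the notebook's `generate_top_k_image_ids` returns `K` IDs and may rank candidates differently (e.g. with beam search), so this sketch only illustrates the single-best case.

```python
import math

START_TOKEN, PADDING_TOKEN, END_TOKEN = 10, 11, 12


def softmax(logits):
    """Turn logits (unnormalized log-probabilities) into probabilities."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_decode(next_token_logits, max_id_length):
    """Generate one image ID token by token, always picking the most likely token."""
    tokens = [START_TOKEN]
    for _ in range(max_id_length + 1):  # at most max_id_length digits/paddings, then END
        logits = next_token_logits(tokens)
        token = max(range(len(logits)), key=lambda t: logits[t])
        tokens.append(token)
        if token == END_TOKEN:
            break
    digits = [str(t) for t in tokens[1:] if t <= 9]
    return int("".join(digits)) if digits else None


# Toy "model" that always favors the sequence [START, 4, 2, END]
target = [START_TOKEN, 4, 2, END_TOKEN]


def fake_next_token_logits(prefix):
    logits = [0.0] * 13
    logits[target[len(prefix)]] = 5.0
    return logits


print(greedy_decode(fake_next_token_logits, max_id_length=3))  # 42
```

Taking the argmax of the logits is equivalent to taking the argmax of `softmax(logits)`, since softmax is monotonic; the probabilities are only needed when scores of different candidate IDs must be compared.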

If a `WANDB_API_KEY` is provided, the model also logs the training and validation losses and accuracies to the **Weights & Biases** dashboard and shows the graphs of the training and validation losses and accuracies over time at the end of the training process, for both the indexing and retrieval tasks.

In [None]:
def train_transformer_model():
	''' Auxiliary function to train the Vision Transformer model (or load its checkpoint), show its training results, and return it along with the retrieval test set '''
	
	dsi_transformer_args = {
		# Dimensionality of the input feature vectors to the Transformer (i.e. the size of the embeddings)
		"embed_dim": TRANSFORMER_EMBEDDINGS_SIZE, 
		# Dimensionality of the hidden layer in the feed-forward networks within the Transformer
		"hidden_dim": TRANSFORMER_FNN_HIDDEN_SIZE, 
		# Number of channels of the input images (e.g. 3 for RGB, 1 for grayscale, etc.)
		"num_channels": 3,	
		# Number of heads to use in the Multi-Head Attention block (should be a divisor of the embed_dim)
		"num_heads": TRANSFORMER_NUM_HEADS,	
		# Number of layers to use in the Transformer
		"num_layers": TRANSFORMER_NUM_LAYERS,
		# Size of each batch
		"batch_size": TRANSFORMER_BATCH_SIZE,
		# Number of classes to predict, i.e. the size of the output vocabulary for the image ID tokens: the 10 digits [0-9] plus the 3 special tokens (13 in total)
		"num_classes": output_tokens,
		# Size of the image patches
		"patch_size": IMAGE_PATCH_SIZE,
		# Maximum number of patches an image can have
		"num_patches": actual_image_patches_per_dimension * actual_image_patches_per_dimension,
		# Maximum length of the image IDs
		"img_id_max_length": max_image_id_length,
		# Special tokens for the image IDs
		"img_id_start_token": 10,
		"img_id_end_token": 12,
		"img_id_padding_token": 11,
		# Whether to learn the positional encodings
		"learn_positional_encodings": LEARN_POSITIONAL_ENCODINGS,
		# Dropout to apply in the feed-forward network and on the input encoding
		"dropout": TRANSFORMER_DROPOUT,
		# Learning rate for the optimizer
		"learning_rate": TRANSFORMER_LEARNING_RATE,
		# Other parameters (useful for keeping track of the model's training variations in the logging files and in the W&B dashboard)
		"dataset_size": NUMBER_OF_IMAGES_IN_DB,
		"images_per_class": MIN_IMAGES_PER_CLASS,
		"retrieval_db_percentage": IMAGE_RETRIEVAL_DB_PERCENTAGE,
		"shuffle_db_images": SHUFFLE_DB_IMAGES,
		"patches_per_dimension": actual_image_patches_per_dimension,
		"max_aspect_ratio_tolerance": MAX_ASPECT_RATIO_TOLERANCE
	}

	# Initialize transformer model
	transformer_model = models.DSI_VisionTransformer(**dsi_transformer_args)
	transformer_model.to(device)

	# Model's checkpoint file
	model_checkpoint_file = MODELS_FOLDER + "/" + "ViT_" + MODEL_CHECKPOINT_FILE

	# Train the model or load its saved checkpoint
	transformer_retrieval_test_set = None
	transformer_retrieval_test_set_file = DATA_FOLDER + f"/ViT_transformer_retrieval_test_set.json"
	if LOAD_MODELS_CHECKPOINTS and os.path.exists(model_checkpoint_file):
		# Load the saved models checkpoint
		print("A checkpoint for the model exists, loading the saved model checkpoint...")
		transformer_model = models.DSI_VisionTransformer.load_from_checkpoint(model_checkpoint_file, **dsi_transformer_args)
		print("Model checkpoint loaded.")
		# Load the transformer retrieval test set from the JSON file
		print("Loading the transformer retrieval test set from the JSON file...")
		with open(transformer_retrieval_test_set_file, "r") as f:
			transformer_retrieval_test_set = json.load(f)
		print("Transformer retrieval test set loaded.")
	else:
		# Create 2 loggers for the transformer model (one for the indexing task and one for the retrieval task)
		transformer_loggers = None
		if wandb_api is not None:
			transformer_wandb_logger_indexing = WandbLogger(log_model="all", project=wandb_project, name="ViT (Indexing)")
			transformer_wandb_logger_retrieval = WandbLogger(log_model="all", project=wandb_project, name="ViT (Retrieval)")
			transformer_loggers = [transformer_wandb_logger_indexing, transformer_wandb_logger_retrieval]
		# Train the transformer model on the indexing task first, then on the retrieval task
		transformer_training_infos = training.train_transformer(
			transformer_indexing_dataset=transformer_indexing_dataset,
			transformer_retrieval_dataset=transformer_image_retrieval_dataset,
			transformer_model=transformer_model,
			max_epochs_list=[TRANSFORMER_INDEXING_TRAINING_EPOCHS, TRANSFORMER_RETRIEVAL_TRAINING_EPOCHS],
			batch_size=transformer_model.hparams.batch_size,
			indexing_split_ratios=(1.0, 0.0),
			retrieval_split_ratios=(0.9, 0.05, 0.05),
			logger=transformer_loggers,
			save_path=model_checkpoint_file
		)
		# Show the wandb training run's dashboard
		if wandb_api is not None:
			indexing_run_id = transformer_training_infos["run_ids"]["indexing"]
			if indexing_run_id is not None:
				print(f"Indexing training results for the ViT model:")
				indexing_run_object: wandb_run.Run = wandb_api.run(f"{wandb_entity}/{wandb_project}/{indexing_run_id}")
				indexing_run_object.display(height=1000)
			retrieval_run_id = transformer_training_infos["run_ids"]["retrieval"]
			if retrieval_run_id is not None:
				print(f"Retrieval training results for the ViT model:")
				retrieval_run_object: wandb_run.Run = wandb_api.run(f"{wandb_entity}/{wandb_project}/{retrieval_run_id}")
				retrieval_run_object.display(height=1000)
		# Save the generated transformer retrieval test set to the JSON file
		print("Saving the transformer retrieval test set to the JSON file...")
		retrieval_test_dataset = transformer_training_infos["retrieval"]["test"]
		transformer_retrieval_test_set = {
			"images": [],
			"relevant_ids": []
		}
		retrieval_test_dataset_length = len(retrieval_test_dataset)
		for i in range(retrieval_test_dataset_length):
			encoded_img, img_id = retrieval_test_dataset[i]
			transformer_retrieval_test_set["images"].append(encoded_img.tolist())
			transformer_retrieval_test_set["relevant_ids"].append(img_id.tolist())
		with open(transformer_retrieval_test_set_file, "w") as f:
			json.dump(transformer_retrieval_test_set, f)

	return transformer_model, transformer_retrieval_test_set

# Train the vision transformer model
vision_transformer, transformer_retrieval_test_set = train_transformer_model()

# <a id='toc6_'></a>[📈 Evaluation](#toc0_)

## <a id='toc6_1_'></a>[BoVW Model Evaluation](#toc0_)

We evaluate the baseline **Bag of Visual Words** model by computing the **Indexing Accuracy**, **Mean Average Precision (MAP)** and **Recall at K** metrics on the images of our test dataset.

We then print the **evaluation results** for the BoVW model.

**❗Note**: The baseline BoVW model is evaluated using the **index-then-retrieve** pipeline. For this reason, its **Indexing Accuracy** is not computed: the BoVW model does not generate image IDs directly, but ranks the indexed database images by the cosine similarity between their BoVW representations and that of the query image, returning a **sorted list of image IDs**.
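
For reference, the two ranking metrics can be sketched as follows for a single query. These are common textbook definitions given as an assumption: the notebook's `evaluation` module may normalize Average Precision differently (here it is divided by `min(K, number of relevant images)`). MAP@K is the mean of the per-query AP@K values.

```python
def average_precision_at_k(ranked_ids, relevant_ids, k):
    """Mean of precision@i over the ranks i <= k at which a relevant image is retrieved."""
    relevant = set(relevant_ids)
    hits, precision_sum = 0, 0.0
    for rank, img_id in enumerate(ranked_ids[:k], start=1):
        if img_id in relevant:
            hits += 1
            precision_sum += hits / rank
    denominator = min(k, len(relevant))
    return precision_sum / denominator if denominator else 0.0


def recall_at_k(ranked_ids, relevant_ids, k):
    """Fraction of the relevant images that appear among the top-k results."""
    relevant = set(relevant_ids)
    return len(relevant.intersection(ranked_ids[:k])) / len(relevant) if relevant else 0.0


def mean_average_precision_at_k(rankings, relevant_lists, k):
    """MAP@K: average of AP@K over all queries."""
    scores = [average_precision_at_k(r, rel, k) for r, rel in zip(rankings, relevant_lists)]
    return sum(scores) / len(scores) if scores else 0.0


ranked = [3, 7, 1, 9, 5]
relevant = [3, 1, 5]
print(round(average_precision_at_k(ranked, relevant, k=5), 4))  # 0.7556
print(recall_at_k(ranked, relevant, k=5))                         # 1.0
print(round(recall_at_k(ranked, relevant, k=2), 4))              # 0.3333
```

In the example, the relevant images are found at ranks 1, 3 and 5, so AP@5 = (1/1 + 2/3 + 3/5) / 3 ≈ 0.7556, while Recall@5 is 1.0 because all three are retrieved.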


In [None]:
# Evaluate the BoVW model and print its evaluation results
if EVALUATE_MODELS:
	print("\nEvaluating the BoVW model...")
	bovw_indexing_accuracy = None	# BoVW does not provide DB indexing
	bovw_retrieval_map_k = evaluation.compute_mean_average_precision_at_k(
		images_to_return, classes,
		k_results=MAP_K, n_images=MAP_N,
		print_debug=PRINT_EVALUATION_DEBUG,
		model=BOVW_model, retrieval_dataset=transformer_image_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set
	)
	bovw_retrieval_recall_k = evaluation.compute_recall_at_k(
		images_to_return, classes,
		k_results=RECALL_K, n_images=RECALL_N,
		print_debug=PRINT_EVALUATION_DEBUG,
		model=BOVW_model, retrieval_dataset=transformer_image_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set
	)
	utils.print_model_evaluation_results(bovw_indexing_accuracy, bovw_retrieval_map_k, bovw_retrieval_recall_k)

We print some **examples** of the **retrieved images** for a given query image using the BoVW model (if the `PRINT_EXAMPLES` flag is set to `True`).

In [None]:
# Test the BoVW model by providing, as input, some random images from the image retrieval dataset
num_tests = 3
num_results = 5
if PRINT_EXAMPLES:
	# Get the images to test the BoVW model
	images_to_test = []
	for i in range(num_tests):
		# Get a random image from the image retrieval dataset
		random_image_id = random.choice(list(images_db_retrieval.keys()))
		# Get the image object (the keys are strings if the retrieval dataset was loaded from a JSON file)
		random_image = images_to_return[int(random_image_id)]
		# Add the image to the list of images to test
		images_to_test.append({
			"id": random_image_id,
			"img_obj": random_image
		})
	# Test the BoVW model
	for i in range(len(images_to_test)):
		# Get the image to test
		test_img_infos = images_to_test[i]
		test_img_id = test_img_infos["id"]
		test_img_obj = test_img_infos["img_obj"]
		# Actual similar images
		actual_similar_images = images_db_retrieval[test_img_id]
		# Get the predictions from the BoVW model
		predicted_similar_images = BOVW_model.get_similar_images(test_img_obj, num_results)
		# Print information about the results
		print("\nImage " + str(i+1) + " to test (ID: " + str(test_img_id) + ") | Class: \"" + test_img_obj['image_label_name'] + "\"")
		compact_print = True
		if compact_print:
			print("> Actual similar images:\n  ",end="")
			print([img_obj["image_id"] for img_obj in [images_db_indexing[img_id] for img_id in actual_similar_images]])
			print("> Predicted similar images (BoVW model)\n  ",end="")
			print([img_obj["image_id"] for img_obj in predicted_similar_images])
		else:
			print("> Actual similar images:")
			for j in range(len(actual_similar_images)):
				similar_image_obj = images_db_indexing[actual_similar_images[j]]
				image_index = map_indexing_db_to_images_db[actual_similar_images[j]]
				print("  - Image #", j, " (ID: ", image_index, ") | Class: \"", similar_image_obj['image_label_name'],"\"", sep="")
			print("> Predicted similar images (BoVW model):")
			for j in range(len(predicted_similar_images)):
				similar_image_obj = predicted_similar_images[j]
				print("  - Image #", j, " (ImageNet ID: ", similar_image_obj["image_id"], ") | Class: \"", similar_image_obj['label'],"\"", sep="")
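		# Compare the same kind of IDs (the "image_id" field) for the actual and the predicted images
		actual_similar_image_ids = set(images_db_indexing[img_id]["image_id"] for img_id in actual_similar_images)
		predicted_similar_image_ids = set(img["image_id"] for img in predicted_similar_images)
		print("> Correct predictions: ", len(actual_similar_image_ids.intersection(predicted_similar_image_ids)), " / ", num_results, sep="")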

## <a id='toc6_2_'></a>[ViT Model Evaluation](#toc0_)

We evaluate the **Vision Transformer** model for the **Differentiable Search Index (DSI)** approach by computing the **Indexing Accuracy**, **Mean Average Precision (MAP)** and **Recall at K** metrics on the images of our test dataset.

We then print the **evaluation results** for the ViT model.
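
The **Indexing Accuracy** measures how well the model has memorized the indexed database. One plausible definition, given here as an assumption since the code of `evaluation.compute_indexing_accuracy` is not shown, is the fraction of indexed images for which the model, given the image itself, ranks its own ID among its top `K` generated IDs (`K` playing the role of `IA_K`). In the sketch, `generate_top_k` is a hypothetical stand-in for the model's top-K generation.

```python
def indexing_accuracy_at_k(indexed_items, generate_top_k, k):
    """Fraction of (image, image_id) pairs whose own ID is among the top-k generated IDs."""
    if not indexed_items:
        return 0.0
    hits = sum(1 for image, image_id in indexed_items if image_id in generate_top_k(image, k)[:k])
    return hits / len(indexed_items)


# Toy model: "images" are strings and the model returns a fixed ranking for each of them
toy_predictions = {"img_a": [0, 5, 2], "img_b": [4, 1, 3], "img_c": [9, 8, 7]}


def toy_generate_top_k(image, k):
    return toy_predictions[image][:k]


indexed = [("img_a", 0), ("img_b", 1), ("img_c", 2)]
print(round(indexing_accuracy_at_k(indexed, toy_generate_top_k, k=1), 4))  # 0.3333
print(round(indexing_accuracy_at_k(indexed, toy_generate_top_k, k=2), 4))  # 0.6667
```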


In [None]:
# Evaluate the transformer model and print its evaluation results
if EVALUATE_MODELS:
	vision_transformer.eval()
	vision_transformer.to(device)
	transformer_indexing_accuracy = evaluation.compute_indexing_accuracy(
		transformer_indexing_dataset, transformer_image_retrieval_dataset,
		model=vision_transformer, 
		k_results=IA_K,
		print_debug=PRINT_EVALUATION_DEBUG
	)
	transformer_retrieval_map_k = evaluation.compute_mean_average_precision_at_k(
		images_to_return, classes,
		k_results=MAP_K, n_images=MAP_N,
		print_debug=PRINT_EVALUATION_DEBUG,
		model=vision_transformer, retrieval_dataset=transformer_image_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set
	)
	transformer_retrieval_recall_k = evaluation.compute_recall_at_k(
		images_to_return, classes,
		k_results=RECALL_K, n_images=RECALL_N,
		print_debug=PRINT_EVALUATION_DEBUG,
		model=vision_transformer, retrieval_dataset=transformer_image_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set
	)
	# Print the evaluation results of the Transformer model
	utils.print_model_evaluation_results(transformer_indexing_accuracy, transformer_retrieval_map_k, transformer_retrieval_recall_k)

We print some **examples** of the **retrieved images** for a given query image using the ViT model (if the `PRINT_EXAMPLES` flag is set to `True`).

In [None]:
# Test the trained transformer model by generating the top k image IDs for some images in the retrieval dataset
if vision_transformer is not None:
	# Total images to test
	tests = 3
	# Set the model to evaluation mode
	vision_transformer.eval()
	vision_transformer.to(device)
	# Test the model
	if tests >= 1:
		for i in range(tests):
			# Get the image to test (as a tensor of the encoded image)
			test_img_retrieval_db_index = i * (len(transformer_image_retrieval_dataset) // tests)
			test_image = transformer_image_retrieval_dataset[test_img_retrieval_db_index][0]
			test_img_db_index = transformer_image_retrieval_dataset.original_ids[test_img_retrieval_db_index]
			print("\nImage ID to test: ", test_img_retrieval_db_index , " (Original ID: ", test_img_db_index, ")",sep="")
			utils.print_json(images_to_return[test_img_db_index])
			similar_image_obj_id = -1
			if isinstance(list(transformer_image_retrieval_dataset.similar_images.keys())[0], int):
				similar_image_obj_id = test_img_db_index
			elif isinstance(list(transformer_image_retrieval_dataset.similar_images.keys())[0], str):
				similar_image_obj_id = str(test_img_db_index)
			relevant_image_indexing_ids = transformer_image_retrieval_dataset.similar_images[similar_image_obj_id]
			relevant_image_db_ids = [transformer_image_retrieval_dataset.indexing_db_to_images_db[i] for i in relevant_image_indexing_ids]
			relevant_image_ids_as_int = [int(i) if isinstance(i, str) and i.isdigit() else i for i in relevant_image_db_ids]
			# Set the number of results to generate K
			results_to_generate = 5
			# Get the top k image IDs for the current test image
			top_k_image_ids = vision_transformer.generate_top_k_image_ids(test_image, results_to_generate, transformer_image_retrieval_dataset)
			predicted_remapped_image_ids = [transformer_image_retrieval_dataset.indexing_db_to_images_db[int(i)] if int(i) in transformer_image_retrieval_dataset.indexing_db_to_images_db.keys() else str(i) for i in top_k_image_ids]
			predicted_image_ids_as_int = [int(i) if isinstance(i, str) and i.isdigit() else i for i in predicted_remapped_image_ids]
			# Display the top k image IDs and the actual relevant image IDs
			print("> Top " + str(results_to_generate) + " image IDs for image #" + str(test_img_db_index) + ":")
			print(predicted_remapped_image_ids)
			print("> Actual relevant image IDs (" +  str(len(relevant_image_db_ids)) + "):")
			print(relevant_image_db_ids)
			# Display the infos of relevant images
			print("> Predicted related images:")
			for j in predicted_image_ids_as_int:
				if isinstance(j, int):
					print("  - Image #", j, " (ImageNet ID: ", images_to_return[j]["image_id"], ") | Class: \"", images_to_return[j]['image_label_name'],"\"", sep="")
				else:
					print("  - Image \"", j, "\" (original image index) | NOT in DB", sep="")
			print("> Actual related images:")
			for j in relevant_image_ids_as_int:
				print("  - Image #", j, " (ImageNet ID: ", images_to_return[j]["image_id"], ") | Class: \"", images_to_return[j]['image_label_name'],"\"", sep="")
			# Visualize the image to test and the top k images
			test_image = utils.get_image_from_b64_string(images_to_return[test_img_db_index]['image_data'])
			top_k_images = []
			for j in predicted_remapped_image_ids:
				# Convert the index into an integer, if it is a numeric string
				if isinstance(j, str) and j.isdigit():
					j = int(j)
				# Only integer indexes refer to an image in the dataset; other IDs are kept as-is and shown as "NOT in DB"
				predicted_img_data = images_to_return[j]['image_data'] if isinstance(j, int) else None
				if predicted_img_data is not None:
					top_k_images.append(utils.get_image_from_b64_string(predicted_img_data))
				else:
					top_k_images.append(j)
			num_of_correct_predictions = len(set(predicted_image_ids_as_int).intersection(relevant_image_ids_as_int))
			print("> Correct predicted image IDs: ", num_of_correct_predictions, " / ", results_to_generate, sep="")
			# Display the image to test and the top k images
			showActualImages = True
			if showActualImages:
				fig, axs = plt.subplots(1, results_to_generate+1, figsize=(15, 5))
				axs[0].imshow(test_image)
				axs[0].axis('off')
				axs[0].set_title("Image to Test")
				for k in range(results_to_generate):
					if isinstance(top_k_images[k], np.ndarray):
						axs[k+1].imshow(top_k_images[k])
						axs[k+1].axis('off')
						axs[k+1].set_title("Predicted image #" + str(k+1))
					else:
						axs[k+1].axis('off')
						axs[k+1].set_title("Image #" + str(top_k_images[k]) + "\nNOT in DB")
				plt.show()
			# Display all the actual relevant images from the original ImageNet dataset, in a single row
			if len(relevant_image_ids_as_int) > 0:
				# squeeze=False keeps axs two-dimensional even when there is a single relevant image
				fig, axs = plt.subplots(1, len(relevant_image_ids_as_int), figsize=(15, 5), squeeze=False)
				for j in range(len(relevant_image_ids_as_int)):
					imagenet_image_id = images_to_return[relevant_image_ids_as_int[j]]["image_id"]
					pil_image = imagenet_data['train'][imagenet_image_id]['image']
					axs[0][j].axis('off')
					axs[0][j].imshow(pil_image)
					axs[0][j].set_title("Related #" + str(j+1))
				plt.show()