In [13]:
# Check if running on colab or locally
try:
	from google.colab import files
	RUNNING_IN_COLAB = True
	print("Running on Google Colab.")
except ModuleNotFoundError:
	RUNNING_IN_COLAB = False
	print("Running locally.")

Running locally.


In [14]:
%%script false
# Clone the git repository of the project for the source files
!git clone https://github.com/valeriodiste/computer_vision_project_dev.git

Couldn't find program: 'false'


In [15]:
%%script false
# Change the working directory to the cloned repository
# TO DO: change the directory to the correct one
%cd /content/computer_vision_project_dev
# Pull the latest changes from the repository
!git pull origin main
# Change the working directory to the parent directory
%cd ..

Couldn't find program: 'false'


In [16]:
%%script false
# Install the required packages
# %%capture
%pip install pytorch-lightning
%pip install pycocotools
%pip install wandb

Couldn't find program: 'false'


In [17]:
# Import the standard libraries
import os
import json
import random
import logging
import math

# Import the PyTorch libraries and modules
import torch

# Import the PyTorch Lightning libraries and modules
import pytorch_lightning as pl

# Import the coco library
from pycocotools.coco import COCO

# Import the W&B (Weights & Biases) library
# import wandb
# from wandb.sdk import wandb_run
# from pytorch_lightning.loggers import WandbLogger

# Other libraries
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
import cv2
import base64

# Import the tqdm library (for the progress bars)
if not RUNNING_IN_COLAB:
	from tqdm import tqdm
else:
	from tqdm.notebook import tqdm

In [18]:
# Import the custom modules
if not RUNNING_IN_COLAB:
	# We are running locally (not on Google Colab, import modules from the "src" directory in the current directory)
	from src.scripts import models, datasets, training, evaluation, utils	# type: ignore
	from src.scripts.utils import ( RANDOM_SEED, MODEL_CHECKPOINT_FILE )	# type: ignore
else:
	# We are running on Google Colab (import modules from the pulled repository stored in the project's directory)
	from computer_vision_project_dev.src.scripts import models, datasets, training, evaluation, utils	# type: ignore
	from computer_vision_project_dev.src.scripts.utils import ( RANDOM_SEED, MODEL_CHECKPOINT_FILE )	# type: ignore

AttributeError: module 'wandb.sdk' has no attribute 'lib'

In [None]:
# Set the random seeds for reproducibility
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
pl.seed_everything(RANDOM_SEED)

Seed set to 1


1

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device.type}")

Device: cpu


In [None]:
# Constant definitions

# ===== Training & Datasets constants =====================================================

# MS COCO dataset constants (use MS COCO 2014 dataset for image captioning)
COCO_DATA_YEAR = '2014'  	# '2014' or '2017'
COCO_DATA_TYPE = 'val'  # 'train' or 'val'
COCO_DATA_CAPTIONS_FILE = f"/annotations/captions_{COCO_DATA_TYPE}{COCO_DATA_YEAR}.json"	# Path of the annotations file inside the DATA_FOLDER
CODO_DATA_INSTANCES_FILE = f"/annotations/instances_{COCO_DATA_TYPE}{COCO_DATA_YEAR}.json"	# Path of the instances file inside the DATA_FOLDER

# Size of the image patches
IMAGE_PATCH_SIZE = 16
# Number of image patches per dimension (i.e. both vertically and horizontally, since images have a square aspect ratio)
IMAGE_PATCHES_PER_DIMENSION = 10	# 3x3 patches, 48x48 pixels images

# Total number of images to consider in the dataset (will be split into training, validation and test sets), set to -1 to use all the available images
NUMBER_OF_IMAGES_IN_DB = 100		# Was 1000
# Minimum number of captions for an image
MIN_IMAGE_CAPTIONS = 5
# If not enough square images are found, also accept images that have this max aspect difference (they will be cropped to a square aspect ratio later)
MAX_ASPECT_RATIO_TOLERANCE = 0.1 	# Accept images that are 10% wider than they are tall (or vice versa)
# Minimum number of images per class
MIN_IMAGES_PER_CLASS = 10	# Was 100

# Percentage of images, for each class, to use for the image retrieval dataset (the remaining images will be used for the indexing dataset, i.e. will be added in the images database)
IMAGE_RETRIEVAL_DB_PERCENTAGE = 0.8

# ===== Evaluation constants ==============================================================

# Define the number of images K to retrieve for each query and the number of queries N to calculate the mean average precision (MAP@K)
MAP_K = 10
MAP_N = 10

# Define the number of images K to retrieve for each query to calculate the Recall@K metrics
RECALL_K = 1_000

# Whether to print the debug information during the MAP@K and Recall@K evaluation of the models
PRINT_EVALUATION_DEBUG = True

# Whether to evaluate the models (i.e. compute the MAP@K and Recall@K metrics for the trained models on the test datasets)
EVALUATE_MODELS = True

# ===== MAIN CONSTANTS =====================================================================

# Whether to print examples of the images and captions during the dataset creation
PRINT_EXAMPLES = False

# Whether to load demo data from the "demo/" folder (set to False to build the dataset from the COCO dataset using the above constants or to load its existing version)
USE_DEMO_DATA = True

# Number of images in the demo dataset (if LOAD_DEMO_DATA is set to True)
DEMO_DATA_SIZE = 100	# Data will be found in the "demo/{DEMO_DATA_SIZE}/" folder

# Define the data folder, onto which the various dictionaries, lists and other data will be saved
DATA_FOLDER = "src/data" if not RUNNING_IN_COLAB else "/content/data"

# Define the path to save models
MODELS_FOLDER = "src/models" if not RUNNING_IN_COLAB else "/content/models"

# Folder containing the demo images and captions
DEMO_FOLDER = f"src/demo/{DEMO_DATA_SIZE}/" if not RUNNING_IN_COLAB else f"/content/computer_vision_project_dev/src/demo/{DEMO_DATA_SIZE}/"

# Force the creation of the "image_db" images list, the JSON files for the datasets, ecc...
FORCE_DICTIONARIES_CREATION = True		# Set to false to try to load the dictionaries from the DATA_FOLDER if they exist

# Whether to load model checkpoints (if they were already saved locally) or not
LOAD_MODELS_CHECKPOINTS = True

In [None]:
# Define the WANDB_API_KEY (set to "" to disable W&B logging)
# NOTE: leaving the WANDB_API_KEY to a value of None will throw an error
WANDB_API_KEY = ""

In [None]:
# Define the wandb logger, api object, entity name and project name
wandb_logger = None
wandb_api = None
wandb_entity = None
wandb_project = None
# Check if a W&B api key is provided
if WANDB_API_KEY == None:
	print("No W&B API key provided, please provide a valid key to use the W&B API or set the WANDB_API_KEY variable to an empty string to disable logging")
	raise ValueError("No W&B API key provided...")
elif WANDB_API_KEY != "":
	# Login to the W&B (Weights & Biases) API
	wandb.login(key=WANDB_API_KEY, relogin=True)
	# Minimize the logging from the W&B (Weights & Biases) library
	os.environ["WANDB_SILENT"] = "true"
	logging.getLogger("wandb").setLevel(logging.ERROR)
	# Initialize the W&B (Weights & Biases) loggger
	wandb_logger = WandbLogger(
		log_model="all", project="cv-dsi-project", name="- SEPARATOR -")
	# Initialize the W&B (Weights & Biases) API
	wandb_api = wandb.Api()
	# Get the W&B (Weights & Biases) entity name
	wandb_entity = wandb_logger.experiment.entity
	# Get the W&B (Weights & Biases) project name
	wandb_project = wandb_logger.experiment.project
	# Finish the "separator" experiment
	wandb_logger.experiment.finish(quiet=True)
	print("W&B API key provided, logging with W&B enabled.")
else:
	print("No W&B API key provided, logging with W&B disabled.")

No W&B API key provided, logging with W&B disabled.


In [None]:
# Create folders if they do not exist
if not os.path.exists(DATA_FOLDER):
	print(f"Creating the data folder at '{DATA_FOLDER}'...")
	os.makedirs(DATA_FOLDER)
if not os.path.exists(MODELS_FOLDER):
	print(f"Creating the models folder at '{MODELS_FOLDER}'...")
	os.makedirs(MODELS_FOLDER)

In [None]:
# Check if the annotation file for the COCO dataset exists, if it does not exist, download it
!cd {DATA_FOLDER} && wget -nc http://images.cocodataset.org/annotations/annotations_trainval{COCO_DATA_YEAR}.zip
!cd {DATA_FOLDER} && unzip -n annotations_trainval{COCO_DATA_YEAR}.zip

"wget" non � riconosciuto come comando interno o esterno,
 un programma eseguibile o un file batch.
"unzip" non � riconosciuto come comando interno o esterno,
 un programma eseguibile o un file batch.


In [None]:
# Build the dataset or load the demo dataset
coco_captions = None
coco_instances = None

if not USE_DEMO_DATA:

	# Initialize the COCO api for captioning
	coco_captions = COCO(f"{DATA_FOLDER}{COCO_DATA_CAPTIONS_FILE}")
	# Initialize the COCO api for object detection
	coco_instances = COCO(f"{DATA_FOLDER}{CODO_DATA_INSTANCES_FILE}")

	# Show the COCO dataset info for the captioning task
	print("\nCOCO captioning dataset infos:")
	coco_captions.info()

	# Show the information for the captioning task
	print("\nCOCO captioning task infos:")
	coco_caps = coco_captions.dataset['annotations']
	print("Number of images: ", len(coco_captions.getImgIds()))
	print("Number of captions: ", len(coco_caps))
	print("Number of average captions per image: ", len(coco_caps) / len(coco_captions.getImgIds()))

	# Show the COCO dataset info for the object detection task
	print("\nCOCO object detection dataset infos:")
	coco_instances.info()

	# Show the information for the object detection task
	print("\nCOCO object detection task infos:")
	coco_objs = coco_instances.dataset['annotations']
	print("Number of images: ", len(coco_instances.getImgIds()))
	print("Number of objects: ", len(coco_objs))
	print("Number of categories: ", len(coco_instances.cats))
	print("Categories:")
	utils.print_json(coco_instances.cats)


In [None]:
# Print some examples from the MS COCO dataset

if not USE_DEMO_DATA and PRINT_EXAMPLES:
	# Print the first image object example
	example_image_index = 0
	print("\nImage object example: ")
	image_example = coco_captions.loadImgs(coco_caps[example_image_index]['image_id'])[0]
	utils.print_json(image_example, 2)

	# Print the actual image file
	print("\nActual image of the example (size: " + str(image_example['width']) + "x" + str(image_example['height']) + "):")
	url = image_example['coco_url']
	image = io.imread(url)
	plt.axis('off')
	plt.imshow(image)
	plt.show()

	# Downscale the image to the maximum allowed size in the model
	image_max_size = IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION
	# Crop the image to a square aspect ratio if it is not already square
	downscaled_image = image
	if image_example['width'] > image_example['height']:
		# Image is wider than tall, crop the sides
		crop_width = (image_example['width'] - image_example['height']) // 2
		downscaled_image = image[:, crop_width:crop_width+image_example['height']]
	elif image_example['height'] > image_example['width']:
		# Image is taller than wide, crop the top and bottom
		crop_height = (image_example['height'] - image_example['width']) // 2
		downscaled_image = image[crop_height:crop_height+image_example['width'], :]
	# Downscale the image to the maximum allowed size
	downscaled_image = cv2.resize(downscaled_image, (image_max_size, image_max_size))
	print("\nDownscaled & cropped image of the example (size: " + str(image_max_size) + "x" + str(image_max_size) + "):")
	plt.axis('off')
	plt.imshow(downscaled_image)
	plt.show()

	# Print the captions for the given image
	print("\nCaption examples for the given image: ")
	captions_for_image = coco_captions.loadAnns(coco_captions.getAnnIds(imgIds=image_example['id']))
	for caption, i in zip(captions_for_image, range(len(captions_for_image))):
		print(str(i+1) + ") " + caption['caption'].strip())

	# Print the captioning object example
	print("\nFirst caption object example:")
	utils.print_json(captions_for_image[0], 2)

	# Print information about the object detection task for the given image
	print("\nObject detection examples for the given image:")
	# Get the object detection annotations for the given image
	annotations_for_image = coco_instances.loadAnns(coco_instances.getAnnIds(imgIds=image_example['id']))
	print("List of the " + str(len(annotations_for_image)) + " object detection annotations for the given image (obtained using the 'coco_instances.loadAnns(image_annotation_id)' function):")
	for annotation, i in zip(annotations_for_image, range(len(annotations_for_image))):
		print("\n> Annotation " + str(i+1) + ":")
		# Print the annotation object
		utils.print_json(annotation, 2, truncate_large_lists=10)

In [None]:
# Build a dataset of images for the training of the Vision Transformer model

# Function that returns the list containing the images for the training of the Vision Transformer model
def get_images_db(number_of_images, process_images=True):
	'''
		Builds a list of images for the training of the Vision Transformer model.

		Parameters:
			number_of_images (int): The number of images to include in the dataset (search is stopped when the number of images is reached), use -1 to include all available images
			process_images (bool): Whether to process the images (i.e. retrieve actual image data, crop images and compupte their base64 encodings to add to the list)
	'''
	# Structure of the images
	images_list_object = {
		"image_id": "",			# ID of the image (as found in the COCO dataset)
		"image_url": "",		# URL of the image
		"image_width": 0,		# The original image width
		"image_height": 0,		# The original image height
		"image_captions": [],	# List of captions for the image
		"image_classes": [		# List of classes for the image (i.e. detected objects, in the order of area size)
			{
				"class_id": 0,		# ID of the class (as found in the COCO dataset)
				"class_name": "",	# Name of the class
				"class_area": 0,	# Sum of the area of each instance of the class in the image
				"class_count": 0	# Number of instances of the class in the image
			}
		],	
		"image_data": ""		# Base64 string of the image
	}
	# Get the image ids
	img_ids = coco_captions.getImgIds()
	# Get the images
	images = []
	
	# Function that returns a list of images with the given aspect ratio tolerance
	def select_images_list(image_aspect_ratio_tolerance):
		# Get the images
		for img_id in img_ids:
			# Get the image object
			img_obj = coco_captions.loadImgs(img_id)[0]
			# Check if the size of the image is square or within the aspect ratio tolerance
			image_aspect_ratio = img_obj['width'] / img_obj['height']
			if abs(image_aspect_ratio - 1) > image_aspect_ratio_tolerance:
				continue
			# Check if the image is already in the images list
			if any(img['image_id'] == img_obj['id'] for img in images):
				continue
			# Get the image url
			img_url = img_obj['coco_url']
			# Get the captions for the image
			img_captions = []
			captions = coco_captions.loadAnns(coco_captions.getAnnIds(imgIds=img_obj['id']))
			for caption in captions:
				caption_text = caption['caption'].strip()
				if len(caption_text) > 1:
					img_captions.append(caption_text)
			# Discard the image if the number of captions is less than the minimum
			if len(img_captions) < MIN_IMAGE_CAPTIONS:
				continue
			# Discard the image if it has no classes
			classes = coco_instances.loadAnns(coco_instances.getAnnIds(imgIds=img_obj['id']))
			if len(classes) == 0:
				continue
			# Create a classes object with the fields: "class_id", "class_name", "class_area"
			classes_obj = {}
			for class_obj in classes:
				class_id = class_obj['category_id']
				class_name = coco_instances.cats[class_id]['name']
				class_area = class_obj['area']
				if class_id not in classes_obj:
					classes_obj[class_id] = {
						"class_id": class_id,
						"class_name": class_name,
						"class_area": class_area,
						"class_count": 1,
					}
				else:
					classes_obj[class_id]['class_area'] += class_area
					classes_obj[class_id]['class_count'] += 1
			# Convert the classes object to a list
			classes_obj = list(classes_obj.values())
			# Sort the classes by area size
			classes_obj = sorted(classes_obj, key=lambda x: x['class_area'], reverse=True)
			# Add the image to the images list
			images_list_object = {
				"image_id": img_obj['id'],
				"image_url": img_url,
				"image_width": img_obj['width'],
				"image_height": img_obj['height'],
				"image_captions": img_captions,
				"image_classes": classes_obj,
				"image_data": None # Will be filled later
			}
			images.append(images_list_object)
			# Break if the number of images is reached
			if number_of_images >= 1 and len(images) >= number_of_images:
				break
		# Return the images list
		return images
	print("Selecting images with a square aspect ratio...")
	# Get the images that have a square aspect ratio first
	images = select_images_list(0)
	# Get the remaining images with the given aspect ratio tolerance
	if len(images) < number_of_images or number_of_images == -1:
		square_aspect_ratio_images = len(images)
		print("> Found " + str(square_aspect_ratio_images) + (" / " + str(number_of_images) if number_of_images > 0 else "" ) + " images with a square aspect ratio, looking for the remaining images...")
		print("Looking for remaining images with an aspect ratio within a tolerance of " + str(round(MAX_ASPECT_RATIO_TOLERANCE*100)) + "% (either a " + str(1 + MAX_ASPECT_RATIO_TOLERANCE) + " aspect ratio or a " + str(1 - MAX_ASPECT_RATIO_TOLERANCE) + " aspect ratio)...")
		images = select_images_list(MAX_ASPECT_RATIO_TOLERANCE)
		non_square_aspect_ratio_images = len(images) - square_aspect_ratio_images
		# Print the number of images found
		print("> Found " + str(non_square_aspect_ratio_images) + (" / " + str(number_of_images) if number_of_images > 0 else "" )  + " more images with an aspect ratio within a tolerance of " + str(round(MAX_ASPECT_RATIO_TOLERANCE*100)) + "%.")
	else:
		print("> Found " + str(len(images)) + (" / " + str(number_of_images) if number_of_images > 0 else "" ) + " images with a square aspect ratio.")
	# Print a message based on the number of images found
	if len(images) < number_of_images and number_of_images != -1:
		print("WARNING: Could not find enough images with the required aspect ratio tolerance, only " + str(len(images)) + " / " + str(number_of_images) + " images found.")
	else:
		print("DONE: Found all " + str(len(images)) + (" / " + str(number_of_images) if number_of_images > 0 else "" ) + " images with the required aspect ratio tolerance.")
	# Get all the image data
	if process_images:
		for img in tqdm(images, desc="Processing images data (computing BASE64 images encoding)..."):
			img["image_data"] = utils.get_image_data_as_base64(img['image_url'], IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION)
	# Return the images list
	return images

# List of image objects used for the training of the Vision Transformer model
images_db = []

if not USE_DEMO_DATA:
	# Check if the images list should be rebuilt or loaded
	images_db_file = os.path.join(DATA_FOLDER, "images_db.json")
	if os.path.exists(images_db_file) and not FORCE_DICTIONARIES_CREATION:
		with open(images_db_file, 'r') as f:
			images_db = json.load(f)
		print("Loaded the images list from the file: ", images_db_file)
	else:
		# Initialize the images list
		images_db = get_images_db(-1, False)
		# Save the images list to a JSON file
		images_db_file = os.path.join(DATA_FOLDER, "images_db.json")
		print("Saving the images list to the file: ", images_db_file)
		with open(images_db_file, 'w') as f:
			json.dump(images_db, f)

	# Print the final number of images in the dataset
	print("\nNumber of loaded images in the dataset: " + str(len(images_db)) + ("/" + str(NUMBER_OF_IMAGES_IN_DB) if NUMBER_OF_IMAGES_IN_DB != -1 else ""))
else:
	# Load the imaged_db.json file from the demo folder
	images_db_file = os.path.join(DEMO_FOLDER, "images_db.json")
	if os.path.exists(images_db_file):
		with open(images_db_file, 'r') as f:
			images_db = json.load(f)
		print("Loaded the images list from the file: ", images_db_file)
	else:
		print("ERROR: Could not load the demo images list from the file: ", images_db_file)
		raise FileNotFoundError("Could not load the demo images list from the file: " + images_db_file)

Loaded the images list from the file:  src/demo/100/images_db.json


In [None]:
# Create a "classes" dictionary with the classes found in the dataset and, for each of them, a list of the images in which they appear

# initialize the classes dictionary
classes = {}

if not USE_DEMO_DATA:
	# Function to get the classes dictionary from the images
	def get_classes_dict():
		# Initialize the classes list
		classes = {}
		# Get the classes from the images
		for i in tqdm(range(len(images_db)), desc="Processing images for classes..."):
			img = images_db[i]
			for class_obj in img['image_classes']:
				# Get the class id
				class_id = class_obj['class_id']
				# Add the class to the classes list if it does not exist
				if class_id not in classes.keys():
					classes[class_id] = []
				# Add the image index to the class list
				classes[class_id].append(i)
		print("Created the classes list from the images with " + str(len(classes)) + " classes.")
		# Discard the classes with less than the minimum number of images
		classes = {k: v for k, v in classes.items() if len(v) >= MIN_IMAGES_PER_CLASS}
		# Sort classes by the number of images
		classes = {k: v for k, v in sorted(classes.items(), key=lambda item: len(item[1]), reverse=True)}
		print("Discarded the classes with less than " + str(MIN_IMAGES_PER_CLASS) + " images: " + str(len(classes)) + " / " + str(len(classes.keys()) + len(classes)) + " classes remaining.")
		# Return the classes list
		return classes

	# Get the classes dictionary if it already exists, otherwise create it
	classes_file = os.path.join(DATA_FOLDER, "classes.json")
	if os.path.exists(classes_file) and not FORCE_DICTIONARIES_CREATION:
		with open(classes_file, 'r') as f:
			classes = json.load(f)
		if len(classes) > 0:
			print("Loaded the classes dictionary from the file: ", classes_file)
	else:
		print("Creating the classes dictionary from the images...")
		classes = get_classes_dict()
		# Save the classes dictionary to a JSON file
		print("Saving the classes dictionary to the file: ", classes_file)
		with open(classes_file, 'w') as f:
			json.dump(classes, f)
else:
	# Load the classes.json file from the demo folder
	classes_file = os.path.join(DEMO_FOLDER, "classes.json")
	if os.path.exists(classes_file):
		with open(classes_file, 'r') as f:
			classes = json.load(f)
		print("Loaded the classes dictionary from the file: ", classes_file)
	else:
		print("ERROR: Could not load the demo classes dictionary from the file: ", classes_file)
		raise FileNotFoundError("Could not load the demo classes dictionary from the file: " + classes_file)

Loaded the classes dictionary from the file:  src/demo/100/classes.json


In [None]:
# Update the images in the images DB to finally only include the images that have the classes in the classes list, with MIN_IMAGES_PER_CLASS images per class, and to populate the images list with the base64 encoding of the images

if not USE_DEMO_DATA:
	# Function to update the images list to only include the images that have the classes in the classes list
	def update_images_db_based_on_classes(max_images):
		# Number of classes to maintain the designated number of images
		classes_count = math.ceil(max_images / MIN_IMAGES_PER_CLASS)
		# Initialize the new images list
		new_images_db = []
		images_db_ids_map = {}
		# Get the classes to maintain the designated number of images
		classes_to_maintain = list(classes.keys())[:classes_count]
		# Create a new classes list with the classes found in the new images list
		new_classes = {}
		# Get the images to maintain the designated number of images
		for i in tqdm(range(len(images_db)), desc="Processing images for classes..."):
			img = images_db[i]
			# Check if the image has any of the classes to maintain
			if any(class_obj['class_id'] in classes_to_maintain for class_obj in img['image_classes']):
				new_images_db.append(img)
				images_db_ids_map[i] = len(new_images_db) - 1
			# Break if the number of images is reached
			if len(new_images_db) >= max_images:
				break
		# Get the classes from the classes to maintain
		for class_id in classes_to_maintain:
			# Remove any image index that is not in the new images list
			# new_classes[class_id] = [i for i in classes[class_id] if i < len(new_images_db)]
			new_classes[class_id] = [images_db_ids_map[i] for i in classes[class_id] if i in images_db_ids_map]
		# Sort the classes by the number of images
		new_classes = { k: v for k, v in sorted(new_classes.items(), key=lambda item: len(item[1]), reverse=True) }
		# Return the new images list
		return new_images_db, new_classes

	# Update the images list to only include the images that have the classes in the classes list
	max_images = NUMBER_OF_IMAGES_IN_DB if NUMBER_OF_IMAGES_IN_DB != -1 else len(images_db)
	print("\nUpdating the images list to only include the " + str(len(classes)) + " classes with at least " + str(MIN_IMAGES_PER_CLASS) + " images, not exceeding " + str(max_images) + " images...")
	images_db, classes = update_images_db_based_on_classes(max_images)
	print("DONE: Updated the images list, now containing " + str(len(images_db)) + " images.")
	print("> Final number of classes in the dataset: " + str(len(classes)))

	# Update the images list to include the base64 encoding of the images
	print("Computing the BASE64 images encoding for the images list...")
	for img in tqdm(images_db, desc="Processing images data (computing BASE64 images encoding)..."):
		img["image_data"] = utils.get_image_data_as_base64(img['image_url'], IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION)
	print("DONE: Computed the BASE64 images encoding for the images list.")

	# Save the updated images list to a JSON file
	images_db_file = os.path.join(DATA_FOLDER, "images_db.json")
	print("Saving the updated images list to the file: ", images_db_file)
	with open(images_db_file, 'w') as f:
		json.dump(images_db, f)

# Print the final number of images in the dataset
print("\nNumber of loaded images in the dataset: " + str(len(images_db)))
print("Number of classes in the dataset: " + str(len(classes)))	



Number of loaded images in the dataset: 100
Number of classes in the dataset: 77


In [None]:
if PRINT_EXAMPLES:

	# Print the first image object example
	example_image_index = -1
	print("Image object example: ")
	utils.print_json(images_db[example_image_index], 2)

	# Print the actual image file
	image_b64_string = images_db[example_image_index]['image_data'] if images_db[example_image_index]['image_data'] != None else utils.get_image_data_as_base64(images_db[example_image_index]['image_url'], IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION)
	image = utils.get_image_from_b64_string(image_b64_string)
	print("\nActual image of the example (original size: " + str(images_db[example_image_index]['image_width']) + "x" + str(images_db[example_image_index]['image_height']) + " | downsampled size: " + str(image.shape[1]) + "x" + str(image.shape[0]) + "):")
	plt.axis('off')
	plt.imshow(image)
	plt.show()

	# Print how the Transformer model sees the image
	print("\nHow the Transformer model sees the image (downsampled size: " + str(image.shape[1]) + "x" + str(image.shape[0]) + "):")
	# Divide the image into smaller images representing the patches
	image_patches = []
	for i in range(0, image.shape[0], IMAGE_PATCH_SIZE):
		for j in range(0, image.shape[1], IMAGE_PATCH_SIZE):
			image_patch = image[i:i+IMAGE_PATCH_SIZE, j:j+IMAGE_PATCH_SIZE]
			image_patches.append(image_patch)
	# Display the image patches
	fig, axs = plt.subplots(IMAGE_PATCHES_PER_DIMENSION, IMAGE_PATCHES_PER_DIMENSION, figsize=(10, 10))
	for i in range(IMAGE_PATCHES_PER_DIMENSION):
		for j in range(IMAGE_PATCHES_PER_DIMENSION):
			axs[i, j].imshow(image_patches[i*IMAGE_PATCHES_PER_DIMENSION+j])
			axs[i, j].axis('off')
	plt.show()

In [None]:
# Print the classes list
total_classes = -1 if USE_DEMO_DATA else len(coco_instances.cats)
print("\nClasses list sorted by number of images (" + str(len(classes)) + " classes out of " + (str(total_classes) if total_classes != -1 else "???") + " total MS COCO classes):")
utils.print_json(classes, 2, truncate_large_lists=10)


Classes list sorted by number of images (77 classes out of ??? total MS COCO classes):
    "1": [
      0,
      3,
      6,
      7,
      8,
      9,
      13,
      14,
      17,
      18,
      "...",
      "(Truncated to 10 out of 1109 elements)"
    ],
    "67": [
      19,
      28,
      29,
      45,
      47,
      55,
      56,
      58,
      59,
      66,
      "...",
      "(Truncated to 10 out of 337 elements)"
    ],
    "47": [
      16,
      23,
      28,
      46,
      55,
      56,
      66,
      67,
      68,
      69,
      "...",
      "(Truncated to 10 out of 235 elements)"
    ],
    "51": [
      11,
      28,
      29,
      55,
      67,
      68,
      69,
      100,
      101,
      102,
      "...",
      "(Truncated to 10 out of 215 elements)"
    ],
    "62": [
      16,
      28,
      29,
      40,
      45,
      53,
      103,
      104,
      106,
      124,
      "...",
      "(Truncated to 10 out of 195 elements)"
    ],
    "3": [
      0,
 

In [None]:
# Split the images list into a list for the indexing dataset (i.e. images in the database) and a list for the image retrieval dataset (i.e. similar images to retrieve images in the DB)

# List of images for the indexing and image retrieval datasets
images_db_indexing = []	# List of images for the indexing dataset
images_db_image_retrieval = {} # Dictionary containing image IDs of images NOT in the indexing dataset as keys and the list of similar images in the indexing dataset as values

if not USE_DEMO_DATA:
	# Create the indexing and image retrieval datasets from the images list
	for class_id in classes.keys():
		class_obj = classes[class_id]
		indexing_number = int(len(class_obj) * (1 - IMAGE_RETRIEVAL_DB_PERCENTAGE))
		similar_images = []
		remap_image_ids = {}
		for i in range(len(class_obj)):
			is_in_db = i < indexing_number
			img_id = class_obj[i]
			# Get the image object
			img = images_db[img_id]
			if is_in_db:
				# Add the image to the indexing dataset
				images_db_indexing.append(img)
				# Add the image to the similar images list
				similar_images.append(img_id)
				# Store the remapping of the image IDs
				remap_image_ids[img_id] = len(images_db_indexing) - 1
			else:
				# Add the image to the image retrieval dataset
				images_db_image_retrieval[img_id] = [ remap_image_ids[i] for i in similar_images ]
	# Save the indexing and image retrieval datasets to JSON files
	images_db_indexing_file = os.path.join(DATA_FOLDER, "images_db_indexing.json")
	print("Saving the images list for the indexing dataset to the file: ", images_db_indexing_file)
	with open(images_db_indexing_file, 'w') as f:
		json.dump(images_db_indexing, f)
	images_db_image_retrieval_file = os.path.join(DATA_FOLDER, "images_db_image_retrieval.json")
	print("Saving the images list for the image retrieval dataset to the file: ", images_db_image_retrieval_file)
	with open(images_db_image_retrieval_file, 'w') as f:
		json.dump(images_db_image_retrieval, f)
else:
	# Load the images_db_indexing.json and images_db_image_retrieval.json files from the demo folder
	images_db_indexing_file = os.path.join(DEMO_FOLDER, "images_db_indexing.json")
	images_db_image_retrieval_file = os.path.join(DEMO_FOLDER, "images_db_image_retrieval.json")
	if os.path.exists(images_db_indexing_file):
		with open(images_db_indexing_file, 'r') as f:
			images_db_indexing = json.load(f)
		print("Loaded the images list for the indexing dataset from the file: ", images_db_indexing_file)
	else:
		print("ERROR: Could not load the demo images list for the indexing dataset from the file: ", images_db_indexing_file)
		raise FileNotFoundError("Could not load the demo images list for the indexing dataset from the file: " + images_db_indexing_file)
	if os.path.exists(images_db_image_retrieval_file):
		with open(images_db_image_retrieval_file, 'r') as f:
			images_db_image_retrieval = json.load(f)
		print("Loaded the images list for the image retrieval dataset from the file: ", images_db_image_retrieval_file)
	else:
		print("ERROR: Could not load the demo images list for the image retrieval dataset from the file: ", images_db_image_retrieval_file)
		raise FileNotFoundError("Could not load the demo images list for the image retrieval dataset from the file: " + images_db_image_retrieval_file)

# Compute the max length of the image IDS (we consider the index of the image in the "images_db" as the image ID)
max_image_id_length = len(str(len(images_db_indexing)))
# Number of output tokens for the encoded image IDs (the 10 digits [0-9] plus the 3 special tokens, i.e. end of sequence, padding, start of sequence)
output_tokens = 10 + 3

# Print the final number of images in the datasets
print("\nNumber of images in the indexing dataset: " + str(len(images_db_indexing)))
print("Image IDs max length: " + str(max_image_id_length))


Loaded the images list for the indexing dataset from the file:  src/demo/100/images_db_indexing.json
Loaded the images list for the image retrieval dataset from the file:  src/demo/100/images_db_image_retrieval.json

Number of images in the indexing dataset: 30
Image IDs max length: 2


In [None]:
# Paths of the file in which the PyTorch datasets will be stored or from which they will be loaded
transformer_indexing_dataset_file = (DATA_FOLDER + "/" if not USE_DEMO_DATA else DEMO_FOLDER) + "transformer_indexing_dataset.json"
transformer_image_retrieval_dataset_file = (DATA_FOLDER + "/" if not USE_DEMO_DATA else DEMO_FOLDER) + "transformer_image_retrieval_dataset.json"

# Build the Transformer Indexing Database for training the vision transformer
transformer_indexing_dataset = datasets.TransformerIndexingDataset(
	images=images_db_indexing,
	patch_size=IMAGE_PATCH_SIZE,
	img_patches=IMAGE_PATCHES_PER_DIMENSION,
	img_id_max_length=max_image_id_length,
	dataset_file_path=transformer_indexing_dataset_file,
	force_dataset_rebuild=FORCE_DICTIONARIES_CREATION and not USE_DEMO_DATA
)

# Build the Transformer Image Retrieval Database for training the vision transformer
transformer_image_retrieval_dataset = datasets.TransformerImageRetrievalDataset(
	all_images=images_db,
	similar_images=images_db_image_retrieval,
	patch_size=IMAGE_PATCH_SIZE,
	img_patches=IMAGE_PATCHES_PER_DIMENSION,
	img_id_max_length=max_image_id_length,
	dataset_file_path=transformer_image_retrieval_dataset_file,
	force_dataset_rebuild=FORCE_DICTIONARIES_CREATION and not USE_DEMO_DATA
)

Loading the Vision Transformer Indexing Dataset from src/demo/100/transformer_indexing_dataset.json...
Loaded 30 images from src/demo/100/transformer_indexing_dataset.json
Loading the Vision Transformer Indexing Dataset from src/demo/100/transformer_image_retrieval_dataset.json...
Loaded 514 images from src/demo/100/transformer_image_retrieval_dataset.json


In [None]:
if PRINT_EXAMPLES or True:
	# Print the first example from the Transformer Indexing Dataset
	example_index = random.randint(0, len(transformer_indexing_dataset)-1)
	print("Example from the Transformer Indexing Dataset:")
	print("<encoded_image, encoded_image_id> tuple:")
	print(transformer_indexing_dataset[example_index])

	# Print the first example from the Transformer Image Retrieval Dataset
	example_index = random.randint(0, len(transformer_image_retrieval_dataset)-1)
	print("\nExample from the Transformer Image Retrieval Dataset:")
	print("<encoded_image, encoded_similar_image_id> tuple:")
	print(transformer_image_retrieval_dataset[example_index])

In [None]:
TRANSFORMER_EMBEDDINGS_SIZE = 128

TRANSFORMER_INDEXING_TRAINING_EPOCHS = 250
TRANSFORMER_RETRIEVAL_TRAINING_EPOCHS = 150

def train_and_evaluate_transformer():
	''' Auxiliary function to train (or load checkpoints), show training results, and evaluate the transformer model of the given type '''
	
	dsi_transformer_args = {
		# Dimensionality of the input feature vectors to the Transformer (i.e. the size of the embeddings)
		"embed_dim": TRANSFORMER_EMBEDDINGS_SIZE, 
		# Dimensionality of the hidden layer in the feed-forward networks within the Transformer
		"hidden_dim": 256, 
		# Number of channels of the input images (e.g. 3 for RGB, 1 for grayscale, ecc...)
		"num_channels": 3,	
		# Number of heads to use in the Multi-Head Attention block
		"num_heads": 4,	
		# Number of layers to use in the Transformer
		"num_layers": 3,
		# Size of each batch
		"batch_size": 32,
		# Number of classes to predict (in my case, since I give an image with, concatenated, the N digits of the image ID, the num_classes is the number of possible digits of the image IDs, hence 10+3, including the special tokens)
		"num_classes": output_tokens,
		# Size of the image patches
		"patch_size": IMAGE_PATCH_SIZE,
		# Maximum number of patches an image can have
		"num_patches": IMAGE_PATCHES_PER_DIMENSION * IMAGE_PATCHES_PER_DIMENSION,
		# Maximum length of the image IDs
		"img_id_max_length": max_image_id_length,
		# Special tokens for the image IDs
		"img_id_start_token": 10,
		"img_id_end_token": 12,
		"img_id_padding_token": 11,
		# Dropout to apply in the feed-forward network and on the input encoding
		"dropout": 0.2,
		# Learning rate for the optimizer
		"learning_rate": 0.001,
	}

	# Initialize transformer model (using scheduled sampling)
	transformer_model = models.DSI_VisionTransformer(**dsi_transformer_args)

	# Model's type string
	model_type_string = "DSI_VisionTransformer"

	# Model's checkpoint file
	model_checkpoint_file = MODELS_FOLDER + "/" + model_type_string + "_" + MODEL_CHECKPOINT_FILE

	# Train the model or load its saved checkpoint
	transformer_retrieval_test_set = None
	transformer_retrieval_test_set_file = DATA_FOLDER + f"/{model_type_string}_transformer_retrieval_test_set.json"
	if LOAD_MODELS_CHECKPOINTS and os.path.exists(model_checkpoint_file):
		# Load the saved models checkpoint
		print("A checkpoint for the model exist, loading the saved model checkpoint...")
		transformer_model = models.DSI_VisionTransformer.load_from_checkpoint(model_checkpoint_file, **dsi_transformer_args)
		print("Model checkpoint loaded.")
		# Load the transformer retrieval test set from the JSON file
		print("Loading the transformer retrieval test set from the JSON file...")
		with open(transformer_retrieval_test_set_file, "r") as transformer_retrieval_test_set_file:
			transformer_retrieval_test_set = json.load(transformer_retrieval_test_set_file)
		print("Transformer retrieval test set loaded.")
	else:
		# Create 2 loggers for the transformer model (one for the indexing task and one for the retrieval task)
		transformer_loggers = None
		if wandb_api is not None:
			transformer_wandb_logger_indexing = WandbLogger(log_model="all", project=wandb_project, name=model_type_string + " (Indexing)")
			transformer_wandb_logger_retrieval = WandbLogger(log_model="all", project=wandb_project, name=model_type_string + " (Retrieval)")
			transformer_loggers = [transformer_wandb_logger_indexing, transformer_wandb_logger_retrieval]
		# Train the transformer model (with scheduled sampling) for the indexing task
		transformer_training_infos = training.train_transformer(
			transformer_indexing_dataset=transformer_indexing_dataset,
			transformer_retrieval_dataset=transformer_image_retrieval_dataset,
			transformer_model=transformer_model,
			max_epochs_list=[TRANSFORMER_INDEXING_TRAINING_EPOCHS, TRANSFORMER_RETRIEVAL_TRAINING_EPOCHS],
			batch_size=transformer_model.hparams.batch_size,
			indexing_split_ratios=(1.0, 0.0),
			retrieval_split_ratios=(0.9, 0.05, 0.05),
			logger=transformer_loggers,
			save_path=model_checkpoint_file
		)
		# Show the wandb training run's dashboard
		if wandb_api is not None:
			indexing_run_id = transformer_training_infos["run_ids"]["indexing"]
			if indexing_run_id is not None:
				print(f"Indexing training results for the {model_type_string} model:")
				indexing_run_object: wandb_run.Run = wandb_api.run(f"{wandb_entity}/{wandb_project}/{indexing_run_id}")
				indexing_run_object.display(height=1000)
			retrieval_run_id = transformer_training_infos["run_ids"]["retrieval"]
			if retrieval_run_id is not None:
				print(f"Retrieval training results for the {model_type_string} model:")
				retrieval_run_object: wandb_run.Run = wandb_api.run(f"{wandb_entity}/{wandb_project}/{retrieval_run_id}")
				retrieval_run_object.display(height=1000)
		# Save the generated transformer retrieval test set to the JSON file
		print("Saving the transformer retrieval test set to the JSON file...")
		retrieval_test_dataset = transformer_training_infos["retrieval"]["test"]
		transformer_retrieval_test_set = {
			"encoded_queries": [],
			"encoded_doc_ids": []
		}
		retrieval_test_dataset_length = retrieval_test_dataset.__len__()
		for i in range(retrieval_test_dataset_length):
			encoded_query, doc_id = retrieval_test_dataset.__getitem__(i)
			transformer_retrieval_test_set["encoded_queries"].append(encoded_query.tolist())
			transformer_retrieval_test_set["encoded_doc_ids"].append(doc_id.tolist())
		with open(transformer_retrieval_test_set_file, "w") as transformer_retrieval_test_set_file:
			json.dump(transformer_retrieval_test_set, transformer_retrieval_test_set_file)

	'''
	# Evaluate the transformer model (for the retrieval task)
	if EVALUATE_MODELS:
		transformer_retrieval_map_k = evaluation.compute_mean_average_precision_at_k(
			MODEL_TYPES.DSI_TRANSFORMER, queries_dict, docs_dict,
			k_documents=MAP_K, n_queries=MAP_N,
			print_debug=PRINT_EVALUATION_DEBUG,
			# Keyword arguments for the Transformer model
			model=transformer_model, retrieval_dataset=transformer_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set
		)
		transformer_retrieval_recall_k = evaluation.compute_recall_at_k(
			MODEL_TYPES.DSI_TRANSFORMER, queries_dict, docs_dict,
			k_documents=RECALL_K,
			print_debug=PRINT_EVALUATION_DEBUG,
			# Keyword arguments for the Transformer model
			model=transformer_model, retrieval_dataset=transformer_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set
		)
		print_model_evaluation_results(transformer_retrieval_map_k, transformer_retrieval_recall_k)
	''' 

	# return transformer_model, transformer_retrieval_map_k, transformer_retrieval_recall_k
	return transformer_model, 0, 0

In [None]:
# Train and evaluate the vision transformer model
teacher_forcing_transformer, teacher_forcing_transformer_map_k, teacher_forcing_transformer_recall_k = train_and_evaluate_transformer()

checkpoint_folder: src/models/
checkpoint_name: DSI_VisionTransformer_transformer.ckpt
Training the model for the indexing task...


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type    | Params
----------------------------------
0 | model | DSI_ViT | 512 K 
----------------------------------
512 K     Trainable params
0         Non-trainable params
512 K     Total params
2.050     Total estimated model params size (MB)
C:\Users\valer\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
C:\Users\valer\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pytorch_lightning\loops\fit_loop.py:298: The number of

Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] 



B=30, C=3, H=160, W=160, N=2
input.shape: torch.Size([30, 3, 160, 160])
target.shape: torch.Size([30, 6])
output.shape: torch.Size([30, 1, 13])
generated_target.shape: torch.Size([30, 1])
input.shape (i=1):torch.Size([30, 3, 160, 160])
current_target.shape (i=1):torch.Size([30, 1])
ids.shape: torch.Size([30, 1])
M=1, N=2
imgs.shape (processed): torch.Size([30, 100, 128])
ids.shape (processed)): torch.Size([30, 1, 128])
x.shape (1): torch.Size([30, 101, 128])
x.shape (2): torch.Size([30, 102, 128])
padding_mask.shape: torch.Size([30, 102])
input.shape: torch.Size([102, 30, 128])
padding_mask.shape: torch.Size([30, 102])
attention_mask.shape: torch.Size([102, 102])
input.shape: torch.Size([102, 30, 128])
padding_mask.shape: torch.Size([30, 102])
attention_mask.shape: torch.Size([102, 102])
input.shape: torch.Size([102, 30, 128])
padding_mask.shape: torch.Size([30, 102])
attention_mask.shape: torch.Size([102, 102])
x.shape (3): torch.Size([30, 102, 128])
cls.shape: torch.Size([30, 128])
o

ValueError: Expected input batch_size (30) to match target batch_size (150).