In [None]:
# Detect whether the notebook is executing on Google Colab: the "google.colab" package is only importable there
try:
	from google.colab import files
except ModuleNotFoundError:
	RUNNING_IN_COLAB = False
else:
	RUNNING_IN_COLAB = True
print("Running on Google Colab." if RUNNING_IN_COLAB else "Running locally.")

In [None]:
# Clone the git repository of the project for the source files
!git clone https://github.com/valeriodiste/computer_vision_project_dev.git

In [None]:
# Change the working directory to the cloned repository
# TO DO: change the directory to the correct one
%cd /content/computer_vision_project_dev
# Pull the latest changes from the repository
!git pull origin main
# Change the working directory to the parent directory
%cd ..

In [None]:
# Install the required packages (using %pip so that they are installed in the kernel's own environment)
# NOTE: the "%%capture" cell magic (to hide the installation output) has no effect unless it is the very first line
# of the cell, so the line below is only a commented-out reminder.
# NOTE(review): package versions are not pinned, so later re-runs may install different versions; torch, scikit-image,
# opencv and matplotlib (imported later) are not installed here and are assumed to be already available (e.g. on Colab).
# %%capture
%pip install pytorch-lightning
%pip install pycocotools
%pip install wandb

In [None]:
# Import the standard libraries
import os
import json
import random
import logging
import math

# Import the PyTorch libraries and modules
import torch

# Import the PyTorch Lightning libraries and modules
import pytorch_lightning as pl

# Import the coco library
from pycocotools.coco import COCO

# Import the W&B (Weights & Biases) library
# import wandb
# from wandb.sdk import wandb_run
# from pytorch_lightning.loggers import WandbLogger

# Other libraries
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
import cv2
import base64

# Import the tqdm library (for the progress bars)
if not RUNNING_IN_COLAB:
	from tqdm import tqdm
else:
	from tqdm.notebook import tqdm

In [None]:
# Import the custom modules
if not RUNNING_IN_COLAB:
	# We are running locally (not on Google Colab, import modules from the "src" directory in the current directory)
	from src.scripts import models, datasets, training, evaluation, utils	# type: ignore
	from src.scripts.utils import ( RANDOM_SEED, MODEL_CHECKPOINT_FILE )	# type: ignore
else:
	# We are running on Google Colab (import modules from the pulled repository stored in the project's directory)
	from computer_vision_project_dev.src.scripts import models, datasets, training, evaluation, utils	# type: ignore
	from computer_vision_project_dev.src.scripts.utils import ( RANDOM_SEED, MODEL_CHECKPOINT_FILE )	# type: ignore

In [None]:
# Make every run reproducible: seed PyTorch and Python's "random" module explicitly, then let
# Lightning's "seed_everything" (which also seeds NumPy) apply the same seed globally
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
pl.seed_everything(RANDOM_SEED)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
	device = torch.device("cuda")
else:
	device = torch.device("cpu")
print(f"Device: {device.type}")

In [None]:
# Constant definitions

# ===== Training & Datasets constants =====================================================

# MS COCO dataset constants (use MS COCO 2014 dataset for image captioning)
COCO_DATA_YEAR = '2014'		# '2014' or '2017'
COCO_DATA_TYPE = 'val'		# 'train' or 'val'
COCO_DATA_CAPTIONS_FILE = f"/annotations/captions_{COCO_DATA_TYPE}{COCO_DATA_YEAR}.json"	# Path of the captions annotations file inside the DATA_FOLDER
COCO_DATA_INSTANCES_FILE = f"/annotations/instances_{COCO_DATA_TYPE}{COCO_DATA_YEAR}.json"	# Path of the instances (object detection) annotations file inside the DATA_FOLDER
# Alias kept for backward compatibility with the previous (misspelled) name, which is still used by later cells
CODO_DATA_INSTANCES_FILE = COCO_DATA_INSTANCES_FILE

# Size (in pixels) of the side of each square image patch
IMAGE_PATCH_SIZE = 16
# Number of image patches per dimension (i.e. both vertically and horizontally, since images have a square aspect ratio)
IMAGE_PATCHES_PER_DIMENSION = 10	# 10x10 patches, i.e. (10 * IMAGE_PATCH_SIZE) = 160x160 pixels images

# Total number of images to consider in the dataset (will be split into training, validation and test sets)
NUMBER_OF_IMAGES_IN_DB = 10		# Was 2000
# Minimum number of captions for an image
MIN_IMAGE_CAPTIONS = 5
# If not enough square images are found, also accept images that have this max aspect difference (they will be cropped to a square aspect ratio later)
MAX_ASPECT_RATIO_TOLERANCE = 0.1	# Accept images that are up to 10% wider than they are tall (or vice versa)

# ===== Evaluation constants ==============================================================

# Define the number of images K to retrieve for each query and the number of queries N to calculate the mean average precision (MAP@K)
MAP_K = 10
MAP_N = 10

# Define the number of images K to retrieve for each query to calculate the Recall@K metrics
RECALL_K = 1_000

# Whether to print the debug information during the MAP@K and Recall@K evaluation of the models
PRINT_EVALUATION_DEBUG = True

# Whether to evaluate the models (i.e. compute the MAP@K and Recall@K metrics for the trained models on the test datasets)
EVALUATE_MODELS = True

# ===== MAIN CONSTANTS =====================================================================

# Define the data folder, onto which the various dictionaries, lists and other data will be saved
DATA_FOLDER = "src/data" if not RUNNING_IN_COLAB else "/content/data"

# Define the path to save models
MODELS_FOLDER = "src/models" if not RUNNING_IN_COLAB else "/content/models"

# Force the creation of the "image_db" images list, the JSON files for the datasets, etc.
FORCE_DICTIONARIES_CREATION = True		# Set to False to load the dictionaries from the DATA_FOLDER if they exist

# Whether to load model checkpoints (if they were already saved locally) or not
LOAD_MODELS_CHECKPOINTS = True

In [None]:
# Define the W&B API key.
# The key is read from the "WANDB_API_KEY" environment variable so that it never has to be hardcoded (and accidentally
# shared/committed) in the notebook; when the variable is not set, the key is "" and W&B logging is disabled.
# To force-disable logging even when the variable is set, assign WANDB_API_KEY = "" here.
# NOTE: setting WANDB_API_KEY to None will throw an error in the W&B setup cell
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", "")

In [None]:
# Define the wandb logger, api object, entity name and project name (they all stay None when W&B logging is disabled)
wandb_logger = None
wandb_api = None
wandb_entity = None
wandb_project = None
# Check if a W&B api key is provided
if WANDB_API_KEY is None:
	print("No W&B API key provided, please provide a valid key to use the W&B API or set the WANDB_API_KEY variable to an empty string to disable logging")
	raise ValueError("No W&B API key provided...")
elif WANDB_API_KEY != "":
	# Import the W&B (Weights & Biases) modules here, so that they are only required when logging is enabled
	import wandb
	from pytorch_lightning.loggers import WandbLogger
	# Minimize the logging from the W&B (Weights & Biases) library (done before any W&B call, so the login is silenced too)
	os.environ["WANDB_SILENT"] = "true"
	logging.getLogger("wandb").setLevel(logging.ERROR)
	# Login to the W&B (Weights & Biases) API
	wandb.login(key=WANDB_API_KEY, relogin=True)
	# Initialize the W&B (Weights & Biases) logger
	wandb_logger = WandbLogger(
		log_model="all", project="cv-dsi-project", name="- SEPARATOR -")
	# Initialize the W&B (Weights & Biases) API
	wandb_api = wandb.Api()
	# Get the W&B (Weights & Biases) entity name
	wandb_entity = wandb_logger.experiment.entity
	# Get the W&B (Weights & Biases) project name
	wandb_project = wandb_logger.experiment.project
	# Finish the "separator" experiment (it was only started to retrieve the entity and project names)
	wandb_logger.experiment.finish(quiet=True)
	print("W&B API key provided, logging with W&B enabled.")
else:
	print("No W&B API key provided, logging with W&B disabled.")

In [None]:
# Make sure the data and models folders exist (folders that are already there are left untouched)
for folder_path, creation_message in (
	(DATA_FOLDER, f"Creating the data folder at '{DATA_FOLDER}'..."),
	(MODELS_FOLDER, f"Creating the models folder at '{MODELS_FOLDER}'..."),
):
	if not os.path.exists(folder_path):
		print(creation_message)
		os.makedirs(folder_path)

In [None]:
# Download the COCO train/val annotations archive for COCO_DATA_YEAR into DATA_FOLDER, then extract it there.
# "wget -nc" (no-clobber) skips the download when the archive is already present, and "unzip -n" never overwrites
# already-extracted files, so re-running this cell does not re-download or overwrite anything.
# The archive is expected to contain the "annotations/" folder referenced by COCO_DATA_CAPTIONS_FILE.
!cd {DATA_FOLDER} && wget -nc http://images.cocodataset.org/annotations/annotations_trainval{COCO_DATA_YEAR}.zip
!cd {DATA_FOLDER} && unzip -n annotations_trainval{COCO_DATA_YEAR}.zip

In [None]:
# Load the MS COCO annotations through the COCO API: one object for the captions, one for the object instances
coco_captions = COCO(f"{DATA_FOLDER}{COCO_DATA_CAPTIONS_FILE}")
coco_instances = COCO(f"{DATA_FOLDER}{CODO_DATA_INSTANCES_FILE}")

# --- Captioning task: dataset metadata, then image / caption counts ---
print("\nCOCO captioning dataset infos:")
coco_captions.info()
print("\nCOCO captioning task infos:")
coco_caps = coco_captions.dataset['annotations']
caption_image_count = len(coco_captions.getImgIds())
print("Number of images: ", caption_image_count)
print("Number of captions: ", len(coco_caps))
print("Number of average captions per image: ", len(coco_caps) / caption_image_count)

# --- Object detection task: dataset metadata, then image / object / category counts and the category list ---
print("\nCOCO object detection dataset infos:")
coco_instances.info()
print("\nCOCO object detection task infos:")
coco_objs = coco_instances.dataset['annotations']
print("Number of images: ", len(coco_instances.getImgIds()))
print("Number of objects: ", len(coco_objs))
print("Number of categories: ", len(coco_instances.cats))
print("Categories:")
utils.print_json(coco_instances.cats)


In [None]:
# Print some examples from the MS COCO dataset

# Print the image object referenced by the first caption annotation of the dataset
example_image_index = 0
print("\nImage object example: ")
image_example = coco_captions.loadImgs(coco_caps[example_image_index]['image_id'])[0]
utils.print_json(image_example, 2)

# Print the actual image file
print("\nActual image of the example (size: " + str(image_example['width']) + "x" + str(image_example['height']) + "):")
url = image_example['coco_url']
image = io.imread(url)
plt.axis('off')
plt.imshow(image)
plt.show()

# Downscale the image to the maximum allowed size in the model
image_max_size = IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION
# Crop the image to a centered square if it is not already square
downscaled_image = image
if image_example['width'] > image_example['height']:
	# Image is wider than tall, crop the sides
	crop_width = (image_example['width'] - image_example['height']) // 2
	downscaled_image = image[:, crop_width:crop_width+image_example['height']]
elif image_example['height'] > image_example['width']:
	# Image is taller than wide, crop the top and bottom
	crop_height = (image_example['height'] - image_example['width']) // 2
	downscaled_image = image[crop_height:crop_height+image_example['width'], :]
# Downscale the image to the maximum allowed size
downscaled_image = cv2.resize(downscaled_image, (image_max_size, image_max_size))
print("\nDownscaled & cropped image of the example (size: " + str(image_max_size) + "x" + str(image_max_size) + "):")
plt.axis('off')
plt.imshow(downscaled_image)
plt.show()

# Print the captions for the given image
print("\nCaption examples for the given image: ")
captions_for_image = coco_captions.loadAnns(coco_captions.getAnnIds(imgIds=image_example['id']))
for i, caption in enumerate(captions_for_image):
	print(str(i+1) + ") " + caption['caption'].strip())

# Print the captioning object example
print("\nFirst caption object example:")
utils.print_json(captions_for_image[0], 2)

# Print information about the object detection task for the given image
print("\nObject detection examples for the given image:")
# Get the object detection annotations for the given image
annotations_for_image = coco_instances.loadAnns(coco_instances.getAnnIds(imgIds=image_example['id']))
print("List of the " + str(len(annotations_for_image)) + " object detection annotations for the given image (obtained using the 'coco_instances.loadAnns(image_annotation_id)' function):")
# Maximum number of coordinates shown for each segmentation polygon
truncation_length = 10
for i, annotation in enumerate(annotations_for_image):
	print("\n> Annotation " + str(i+1) + ":")
	# "loadAnns" returns the annotation dicts stored inside "coco_instances", so truncate a shallow copy instead:
	# editing them in place would corrupt the loaded dataset (and truncate again on every re-run of this cell)
	annotation = dict(annotation)
	segmentation = annotation['segmentation']
	# Polygon segmentations are lists of coordinate lists; crowd annotations store an RLE dict instead,
	# which is printed as it is (indexing it by position would raise a KeyError)
	if isinstance(segmentation, list):
		annotation['segmentation'] = [
			polygon[:truncation_length] + ["..."] + ["[truncated to " + str(truncation_length) + " out of " + str(len(polygon)) + " elements]"]
			if len(polygon) > truncation_length else polygon
			for polygon in segmentation
		]
	# Print the annotation object
	utils.print_json(annotation, 2)

In [None]:
# Build a dataset of images for the training of the Vision Transformer model

# Function that returns the list containing the images for the training of the Vision Transformer model
def get_images_db(number_of_images):
	"""
	Select up to `number_of_images` images of the COCO captioning dataset and return them as a list of dicts.

	The selection makes up to two passes over the COCO image IDs. The first pass takes only perfectly square images.
	If not enough were found, the second pass also takes images whose aspect ratio differs from 1 by at most
	MAX_ASPECT_RATIO_TOLERANCE. An image is kept only if it has at least MIN_IMAGE_CAPTIONS captions longer than
	one character and at least one object detection annotation. Finally, each selected image is downloaded,
	center-cropped to a square, resized to (IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION) pixels per side and
	stored as a base64-encoded JPEG string.

	Structure of each returned element:
		{
			"image_id": ...,			# ID of the image (as found in the COCO dataset)
			"image_url": "",			# URL of the image (the "coco_url" field of the COCO image object)
			"image_width": 0,			# The original image width
			"image_height": 0,			# The original image height
			"image_captions": [],		# List of (stripped) captions for the image
			"image_classes": [			# List of classes for the image (i.e. detected objects, sorted by decreasing area)
				{
					"class_id": 0,						# ID of the class (as found in the COCO dataset)
					"class_name": "",					# Name of the class
					"class_area": 0,					# "area" field of the COCO annotation
					"class_bounding_box": [0, 0, 0, 0]	# "bbox" field of the COCO annotation, i.e. "[x, y, width, height]" in pixels of the original image (NOT normalized)
				}
			],
			"image_data": ""			# Base64 string of the cropped & downscaled JPEG image
		}

	Args:
		number_of_images: maximum number of images to select. If fewer images satisfy the constraints, a warning is
			printed and fewer images are returned; if it is <= 0, no image is selected.

	Returns:
		The list of selected image objects.
	"""
	# Get the image ids
	img_ids = coco_captions.getImgIds()
	# Randomly shuffle the image ids
	# random.shuffle(img_ids)
	# Selected images (shared by both selection passes) and the set of their IDs, used for O(1) duplicate checks
	images = []
	selected_image_ids = set()

	# Function that returns an image's base64 string from the image url
	def get_image_data(image_url):
		# Load the image
		image = io.imread(image_url)
		# Downscale the image to the maximum allowed size in the model
		image_max_size = IMAGE_PATCH_SIZE * IMAGE_PATCHES_PER_DIMENSION
		# Crop the image to a centered square if it is not already square
		downscaled_image = image
		if image.shape[1] > image.shape[0]:
			# Image is wider than tall, crop the sides
			crop_width = (image.shape[1] - image.shape[0]) // 2
			downscaled_image = image[:, crop_width:crop_width+image.shape[0]]
		elif image.shape[0] > image.shape[1]:
			# Image is taller than wide, crop the top and bottom
			crop_height = (image.shape[0] - image.shape[1]) // 2
			downscaled_image = image[crop_height:crop_height+image.shape[1], :]
		# Downscale the image to the maximum allowed size
		downscaled_image = cv2.resize(downscaled_image, (image_max_size, image_max_size))
		# Convert the image to a base64 string of its JPEG encoding
		# NOTE(review): skimage's "io.imread" returns RGB data while OpenCV's "imencode" treats 3-channel arrays as BGR;
		# the decoding side must use the matching convention to get the right colors back — TODO confirm
		image_base64 = base64.b64encode(cv2.imencode('.jpg', downscaled_image)[1]).decode()
		# Return the base64 string of the image
		return image_base64

	# Function that appends to "images" the valid images within the given aspect ratio tolerance
	# (until "number_of_images" images are selected) and returns "images"
	def select_images_list(image_aspect_ratio_tolerance):
		for img_id in img_ids:
			# Stop as soon as enough images were selected (checked before selecting, so nothing is added when number_of_images <= 0)
			if len(images) >= number_of_images:
				break
			# Get the image object
			img_obj = coco_captions.loadImgs(img_id)[0]
			# Check if the size of the image is square or within the aspect ratio tolerance
			image_aspect_ratio = img_obj['width'] / img_obj['height']
			if abs(image_aspect_ratio - 1) > image_aspect_ratio_tolerance:
				continue
			# Skip images that were already selected (e.g. square images selected by the first pass)
			if img_obj['id'] in selected_image_ids:
				continue
			# Get the image url
			img_url = img_obj['coco_url']
			# Get the (non-trivial) captions for the image
			img_captions = []
			captions = coco_captions.loadAnns(coco_captions.getAnnIds(imgIds=img_obj['id']))
			for caption in captions:
				caption_text = caption['caption'].strip()
				if len(caption_text) > 1:
					img_captions.append(caption_text)
			# Discard the image if the number of captions is less than the minimum
			if len(img_captions) < MIN_IMAGE_CAPTIONS:
				continue
			# Discard the image if it has no classes
			classes = coco_instances.loadAnns(coco_instances.getAnnIds(imgIds=img_obj['id']))
			if len(classes) == 0:
				continue
			# Create a classes object with the fields: "class_id", "class_name", "class_area", "class_bounding_box"
			classes_obj = []
			for class_obj in classes:
				classes_obj.append({
					"class_id": class_obj['category_id'],
					"class_name": coco_instances.cats[class_obj['category_id']]['name'],
					"class_area": class_obj['area'],
					"class_bounding_box": class_obj['bbox']
				})
			# Sort the classes by area size (largest first)
			classes_obj = sorted(classes_obj, key=lambda x: x['class_area'], reverse=True)
			# Add the image to the images list
			images_list_object = {
				"image_id": img_obj['id'],
				"image_url": img_url,
				"image_width": img_obj['width'],
				"image_height": img_obj['height'],
				"image_captions": img_captions,
				"image_classes": classes_obj,
				"image_data": None # Will be filled later
			}
			images.append(images_list_object)
			selected_image_ids.add(img_obj['id'])
		# Return the images list
		return images

	print("Selecting images with a square aspect ratio...")
	# Get the images that have a square aspect ratio first
	select_images_list(0)
	# Get the remaining images with the given aspect ratio tolerance
	if len(images) < number_of_images:
		square_aspect_ratio_images = len(images)
		print("> Found " + str(square_aspect_ratio_images) + " / " + str(number_of_images) + " images with a square aspect ratio, looking for the remaining images...")
		print("Looking for remaining images with an aspect ratio within a tolerance of " + str(round(MAX_ASPECT_RATIO_TOLERANCE*100)) + "% (either a " + str(1 + MAX_ASPECT_RATIO_TOLERANCE) + " aspect ratio or a " + str(1 - MAX_ASPECT_RATIO_TOLERANCE) + " aspect ratio)...")
		select_images_list(MAX_ASPECT_RATIO_TOLERANCE)
		non_square_aspect_ratio_images = len(images) - square_aspect_ratio_images
		# Print the number of images found
		print("> Found " + str(non_square_aspect_ratio_images) + " / " + str(number_of_images) + " more images with an aspect ratio within a tolerance of " + str(round(MAX_ASPECT_RATIO_TOLERANCE*100)) + "%.")
	else:
		print("> Found " + str(number_of_images) + " / " + str(number_of_images) + " images with a square aspect ratio.")
	# Print a message based on the number of images found
	if len(images) < number_of_images:
		print("WARNING: Could not find enough images with the required aspect ratio tolerance, only " + str(len(images)) + " / " + str(number_of_images) + " images found.")
	else:
		print("DONE: Found all " + str(number_of_images) + " / " + str(number_of_images) + " images with the required aspect ratio tolerance.")
	# Get all the image data (download, crop, downscale and encode each selected image)
	for img in tqdm(images, desc="Processing images data (computing BASE64 images encoding)..."):
		img["image_data"] = get_image_data(img['image_url'])
	# Return the images list
	return images

# Path of the JSON file in which the images list is cached
images_db_file = os.path.join(DATA_FOLDER, "images_db.json")

# List of image objects used for the training of the Vision Transformer model
images_db = []
# The images list is rebuilt unless a cached list with the expected number of images can be reused
create_images_db = True
if not FORCE_DICTIONARIES_CREATION and os.path.exists(images_db_file):
	with open(images_db_file, 'r') as cached_file:
		images_db = json.load(cached_file)
	if len(images_db) == NUMBER_OF_IMAGES_IN_DB:
		create_images_db = False
		print("Loaded the images list from the file: ", images_db_file)

# (Re)build the images list from the COCO dataset and cache it to the JSON file
# (when FORCE_DICTIONARIES_CREATION is set, the cache is never read, so "create_images_db" is still True here)
if create_images_db:
	images_db = get_images_db(NUMBER_OF_IMAGES_IN_DB)
	print("Saving the images list to the file: ", images_db_file)
	with open(images_db_file, 'w') as cached_file:
		json.dump(images_db, cached_file)

# Print the final number of images in the dataset
print("\nNumber of loaded images in the dataset: " + str(len(images_db)) + "/" + str(NUMBER_OF_IMAGES_IN_DB))

In [None]:
# Print the last image object of the images list as an example
example_image_index = -1
# Fail early with a clear message instead of an IndexError if no image was selected/loaded
assert len(images_db) > 0, "The images list is empty: there is no image example to show."
print("Image object example: ")
utils.print_json(images_db[example_image_index], 2)

# Print the actual image file
image = utils.get_image_from_db_object(images_db[example_image_index])
print("\nActual image of the example (original size: " + str(images_db[example_image_index]['image_width']) + "x" + str(images_db[example_image_index]['image_height']) + " | downsampled size: " + str(image.shape[1]) + "x" + str(image.shape[0]) + "):")
plt.axis('off')
plt.imshow(image)
plt.show()

# Print how the Transformer model sees the image
print("\nHow the Transformer model sees the image (downsampled size: " + str(image.shape[1]) + "x" + str(image.shape[0]) + "):")
# Divide the image into IMAGE_PATCH_SIZE x IMAGE_PATCH_SIZE patches, listed row by row
image_patches = [
	image[i:i+IMAGE_PATCH_SIZE, j:j+IMAGE_PATCH_SIZE]
	for i in range(0, image.shape[0], IMAGE_PATCH_SIZE)
	for j in range(0, image.shape[1], IMAGE_PATCH_SIZE)
]
# Number of patches in each row of "image_patches" (computed from the actual image width, used to index the list)
patches_per_row = len(range(0, image.shape[1], IMAGE_PATCH_SIZE))
# Display the image patches in a grid ("squeeze=False" keeps "axs" 2D even with a single patch per dimension)
fig, axs = plt.subplots(IMAGE_PATCHES_PER_DIMENSION, IMAGE_PATCHES_PER_DIMENSION, figsize=(10, 10), squeeze=False)
for i in range(IMAGE_PATCHES_PER_DIMENSION):
	for j in range(IMAGE_PATCHES_PER_DIMENSION):
		axs[i, j].imshow(image_patches[i*patches_per_row+j])
		axs[i, j].axis('off')
plt.show()

In [None]:
# The position of each image inside "images_db" is used as its image ID, so IDs go from 0 to len(images_db) - 1.
# NOTE(review): the value passed below as "img_id_max_length" is the NUMBER of images, not the number of digits of the
# largest ID (that would be len(str(len(images_db) - 1))); confirm which one datasets.TransformerIndexingDataset expects
max_image_id_length = len(images_db)

# Size of the output vocabulary for the encoded image IDs: the 10 digits [0-9] plus 3 special tokens
# (end of sequence, padding, start of sequence)
output_tokens = 10 + 3

# Build the Transformer indexing dataset used to train the vision transformer
# (the JSON file path and the force-rebuild flag are forwarded to the dataset class)
transformer_indexing_dataset = datasets.TransformerIndexingDataset(
	images=images_db,
	img_patches=IMAGE_PATCHES_PER_DIMENSION,
	patch_size=IMAGE_PATCH_SIZE,
	img_id_max_length=max_image_id_length,
	force_dataset_rebuild=FORCE_DICTIONARIES_CREATION,
	dataset_file_path=os.path.join(DATA_FOLDER, "transformer_indexing_dataset.json"),
)

# Show the first <encoded_image, encoded_image_id> example of the dataset
example_index = 0
print("Example from the Transformer Indexing Dataset:")
print("<encoded_image, encoded_image_id> tuple:")
print(transformer_indexing_dataset[example_index])
