# Neural Inverted Index for Fast and Effective Information Retrieval

---

## 📚 Notebook Overview

This notebook explores a novel [information retrieval (IR)](https://en.wikipedia.org/wiki/Information_retrieval) framework that utilizes a **differentiable function** to generate a **sorted list of document identifiers** in response to a given **query**.

The approach is called **Differentiable Search Index (DSI)**, and was originally proposed in the paper [Transformer Memory as a Differentiable Search Index](https://arxiv.org/pdf/2202.06991.pdf) by researchers at Google Research.

**DSI** aims to encode all of the information in the document corpus and to perform retrieval within a single **Transformer language model**, instead of adopting the index-then-retrieve pipeline used in most modern IR systems.

The notebook presents the implemented solution, a **Sequence to Sequence transformer** model `f` that, given a query `q` as input, returns a list of document IDs ranked by relevance to the query, and compares its performance with the traditional **TF-IDF** retrieval model, a **Word2Vec** model, and a **Siamese Network model with Triplet Loss**.

The proposed solution combines the **DSI** approach with the **Scheduled Sampling** technique for Transformers, inspired by the similar technique described in the paper [Scheduled Sampling for Transformers](https://arxiv.org/abs/1906.07651).

We evaluate the performance of the proposed models using the **Mean Average Precision (MAP)** and the **Recall at K** metrics computed on the **MS MARCO** dataset, and we compare the results with several baselines (**TF-IDF**, **Word2Vec**, **Siamese Network** and also other traditional **Transformer** approaches).

## 📝 Author

**Valerio Di Stefano** - _"Sapienza" University of Rome_
<br/>
Email: [distefano.1898728@studenti.uniroma1.it](mailto:distefano.1898728@studenti.uniroma1.it)

## 🔗 External Links

* **Main Paper**: [Transformer Memory as a Differentiable Search Index](https://arxiv.org/pdf/2202.06991.pdf)

  _Authors_: Yi Tay, Vinh Q. Tran, Mostafa Dehghani, Jianmo Ni, Dara Bahri, Harsh Mehta, Zhen Qin, Kai Hui, Zhe Zhao, Jai Gupta, Tal Schuster, William W. Cohen, Donald Metzler

* **Relevant Paper**: [Understanding Differential Search Index for Text Retrieval](https://arxiv.org/abs/2305.02073)

  _Authors_: Xiaoyang Chen, Yanjiang Liu, Ben He, Le Sun, Yingfei Sun

* **Relevant Paper**: [Scheduled Sampling for Transformers](https://arxiv.org/abs/1906.07651)

  _Authors_: Tsvetomila Mihaylova, André F. T. Martins

* **Project Repository**: [GitHub Repository](https://github.com/valeriodiste/deep_learning_project)


---


## 📌 Table of Contents

To do...


---



<a id="1"></a>
# 🚀 Getting Started


First of all, we check whether the notebook is running on Google Colab or locally, defining the `RUNNING_IN_COLAB` constant used throughout the notebook.

In [None]:
# Check if running on colab or locally
try:
	from google.colab import files
	RUNNING_IN_COLAB = True
	print("Running on Google Colab.")
except ModuleNotFoundError:
	RUNNING_IN_COLAB = False
	print("Running locally.")


<a id="1_1"></a>
## Collect Source Files


#### Clone Project's GitHub Repository

We **clone the project's repository** from GitHub to access the source files for datasets, models, evaluation and utilities.


In [None]:
# Clone the git repository of the project for the source files
!git clone https://github.com/valeriodiste/computer_vision_project_dev.git


#### Pull Latest Files Changes

We also **pull the latest changes** from the repository into the cloned `./computer_vision_project_dev` directory.


In [None]:
# Change the working directory to the cloned repository
# TO DO: change the directory to the correct one
%cd /content/computer_vision_project_dev
# Pull the latest changes from the repository
!git pull origin main
# Change the working directory to the parent directory
%cd ..

<a id="1_2"></a>
## Install & Import Libraries

#### Install Libraries

We **install all the necessary libraries** for this notebook.

- **`pytorch-lightning`**: A **lightweight PyTorch wrapper** for simplifying PyTorch code.
- **`ir_datasets`**: A Python library for accessing **information retrieval datasets** (used to load the **"MS MARCO" dataset**).
- **`pycocotools`**: The Python API for the **COCO dataset** (used to load and inspect the COCO images and annotations).
- **`wandb`**: The Python package for **Weights & Biases**, a tool for experiment tracking, dataset versioning, and project collaboration (used for **logging and visualization**).

In [None]:
%%capture
# Install the required packages (the %%capture cell magic must be the first line of the cell)
%pip install pytorch-lightning
%pip install ir_datasets
%pip install pycocotools
%pip install wandb

#### Import Modules

We then **import the required modules**, including `PyTorch`, `PyTorch Lightning`, the `pycocotools` COCO API and `W&B`, plus other useful modules and libraries (`NumPy`, `scikit-image`, `Matplotlib`, `OpenCV`, `tqdm`, etc.).

In [None]:

# Import the standard libraries
import os
import json
import random
import logging
import math

# Import the PyTorch libraries and modules
import torch

# Import the PyTorch Lightning libraries and modules
import pytorch_lightning as pl

# Import the coco library
from pycocotools.coco import COCO

# Import the W&B (Weights & Biases) library (needed by the W&B configuration cell when an API key is provided)
import wandb
from wandb.sdk import wandb_run
from pytorch_lightning.loggers import WandbLogger

# Other libraries
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import cv2

# Import the tqdm library (for the progress bars)
if not RUNNING_IN_COLAB:
	from tqdm import tqdm
else:
	from tqdm.notebook import tqdm



We also import our own **custom modules** (cloned from the repository) containing Python classes for **datasets**, **models**, **evaluation**, and **utilities**.

In [None]:
# Import the custom modules
if not RUNNING_IN_COLAB:
	# We are running locally (not on Google Colab, import modules from the "src" directory in the current directory)
	from src.scripts import models, datasets, training, evaluation
	from src.scripts.utils import (
		print_json, RANDOM_SEED, print_model_evaluation_results, MODEL_CHECKPOINT_FILE
	)
else:
	# We are running on Google Colab (import modules from the pulled repository stored in the project's directory)
	from computer_vision_project_dev.src.scripts import models, datasets, training, evaluation
	from computer_vision_project_dev.src.scripts.utils import (
		print_json, RANDOM_SEED, print_model_evaluation_results, MODEL_CHECKPOINT_FILE
	)

<a id="1_3"></a>
## Configuration, Hyperparameters and Constants

#### Random Seed

We **seed the random number generators** for reproducibility.

In [None]:
# Set the random seeds for reproducibility
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
pl.seed_everything(RANDOM_SEED)

#### Device Configuration

We **set the device** to GPU if available, otherwise we use the CPU.

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device.type}")

#### Database Constants

We **define the constants** used for the **database resources download** and the **dataset creation**.

In [None]:
# TO DO...

#### Models Hyperparameters

We then **define the constants** representing the **hyperparameters** used for the **Word2Vec model**, the **Siamese Network model** and the **Seq2Seq transformer model**.

In [None]:
# Define the max length of the tokenized queries and documents for the Transformer model (token sequences will be padded or truncated to this length)
TRANSFORMER_DOCUMENT_MAX_TOKENS = 64
TRANSFORMER_QUERY_MAX_TOKENS = 32

# Define the size of the embeddings for the Encoders of the Seq2Seq Transformer model
TRANSFORMER_EMBEDDINGS_SIZE = 128

#### Evaluation Constants

We also define the constants used for the evaluation of the various models (i.e. to compute the **Mean Average Precision** and the **Recall at K**).

In [None]:
# Define the number of documents K to retrieve for each query and the number of queries N to calculate the mean average precision (MAP@K)
MAP_K = 10
MAP_N = 10

# Define the number of documents K to retrieve for each query to calculate the Recall@K metrics
RECALL_K = 1_000

# Whether to print the debug information during the MAP@K and Recall@K evaluation of the models
PRINT_EVALUATION_DEBUG = True

# Whether to evaluate the models (i.e. compute the MAP@K and Recall@K metrics for the trained models on the test datasets)
EVALUATE_MODELS = True
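To make the two metrics concrete, here is a minimal sketch of MAP@K and Recall@K in plain Python. It assumes the common definition of AP@K that averages precision@i over the ranks i ≤ K holding a relevant document, normalized by `min(|relevant|, K)`; the repository's `evaluation` module may normalize differently. The function names are hypothetical, and `k` plays the role of `MAP_K` or `RECALL_K`.

```python
from typing import List, Sequence, Set


def average_precision_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """AP@K: mean of precision@i over the ranks i <= k that hold a relevant document."""
    if not relevant:
        return 0.0
    hits = 0
    precision_sum = 0.0
    for rank, doc_id in enumerate(retrieved[:k], start=1):
        if doc_id in relevant:
            hits += 1
            precision_sum += hits / rank  # precision at this rank
    # Normalize by the best achievable number of hits within the top K (assumed convention)
    return precision_sum / min(len(relevant), k)


def mean_average_precision_at_k(all_retrieved: List[Sequence[str]],
                                all_relevant: List[Set[str]], k: int) -> float:
    """MAP@K: AP@K averaged over all evaluated queries."""
    scores = [average_precision_at_k(r, rel, k) for r, rel in zip(all_retrieved, all_relevant)]
    return sum(scores) / len(scores) if scores else 0.0


def recall_at_k(retrieved: Sequence[str], relevant: Set[str], k: int) -> float:
    """Recall@K: fraction of the relevant documents that appear in the top K results."""
    if not relevant:
        return 0.0
    return len(set(retrieved[:k]) & relevant) / len(relevant)
```

For example, with `retrieved = ["3", "7", "1", "9"]` and `relevant = {"7", "9"}`, the hits at ranks 2 and 4 each contribute a precision of 0.5, giving AP@10 = 0.5, while Recall@2 = 0.5 and Recall@4 = 1.0.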

#### Other Constants

Finally, we define the constants that determine where data and models are saved, together with the flags that control database rebuilding and the loading of model checkpoints.

In [None]:

# Define the data folder, onto which the documents and queries dictionaries will be saved
DATA_FOLDER = "src/data" if not RUNNING_IN_COLAB else "/content/data"

# Define the path to save models
MODELS_FOLDER = "src/models" if not RUNNING_IN_COLAB else "/content/models"

# Whether to load model checkpoints (if they were already saved locally) or not
LOAD_MODELS_CHECKPOINTS = True

#### Local Files Folder Creation

We create the folders that store the data dictionaries and the model checkpoints.

In [None]:
# Create folders if they do not exist
if not os.path.exists(DATA_FOLDER):
	print(f"Creating the data folder at '{DATA_FOLDER}'...")
	os.makedirs(DATA_FOLDER)
if not os.path.exists(MODELS_FOLDER):
	print(f"Creating the models folder at '{MODELS_FOLDER}'...")
	os.makedirs(MODELS_FOLDER)

#### Weights & Biases Configuration

We set the **Weights & Biases** API key to log the experiments.

**⚠️ Note**: Copy and paste your own W&B API key into the `WANDB_API_KEY` constant to see logging results, or set the constant to an empty string to disable W&B logging (this won't plot training losses and accuracies over time).

In [None]:
# Define the WANDB_API_KEY (set to "" to disable W&B logging)
# NOTE: setting WANDB_API_KEY to None will raise an error
WANDB_API_KEY = ""

We configure the **Weights & Biases** logger and API to track the experiments and the model's performance.

In [None]:
# Define the wandb logger, api object, entity name and project name
wandb_logger = None
wandb_api = None
wandb_entity = None
wandb_project = None
# Check if a W&B api key is provided
if WANDB_API_KEY is None:
	print("No W&B API key provided, please provide a valid key to use the W&B API or set the WANDB_API_KEY variable to an empty string to disable logging")
	raise ValueError("No W&B API key provided.")
elif WANDB_API_KEY != "":
	# Login to the W&B (Weights & Biases) API
	wandb.login(key=WANDB_API_KEY, relogin=True)
	# Minimize the logging from the W&B (Weights & Biases) library
	os.environ["WANDB_SILENT"] = "true"
	logging.getLogger("wandb").setLevel(logging.ERROR)
	# Initialize the W&B (Weights & Biases) logger
	wandb_logger = WandbLogger(
		log_model="all", project="cv-dsi-project", name="- SEPARATOR -")
	# Initialize the W&B (Weights & Biases) API
	wandb_api = wandb.Api()
	# Get the W&B (Weights & Biases) entity name
	wandb_entity = wandb_logger.experiment.entity
	# Get the W&B (Weights & Biases) project name
	wandb_project = wandb_logger.experiment.project
	# Finish the "separator" experiment
	wandb_logger.experiment.finish(quiet=True)
	print("W&B API key provided, logging with W&B enabled.")
else:
	print("No W&B API key provided, logging with W&B disabled.")

<a id="1_4"></a>
## Download Data & Resources

#### Download Datasets

TO DO...

In [None]:

# Download the pytorch vision datasets
# Define the data folder for the datasets
DATA_FOLDER = "src/data" if not RUNNING_IN_COLAB else "/content/data"

# Define the datasets to download
# NOTE: use a dataset for image captioning (the only one in torchvision is COCO) or for image classification (e.g. CIFAR-10, CIFAR-100, etc.):
# - Image captioning:		https://pytorch.org/vision/main/datasets.html#image-captioning
# - Image classification:	https://pytorch.org/vision/main/datasets.html#image-classification
DATASET = "coco"  # "coco", "cifar10", "cifar100", etc...


# Example of using the COCO dataset in Python (without the torchvision COCO dataset class)
# > https://www.kaggle.com/code/visheshvats/image-caption-generation-on-coco-dataset-ipynb
# Official pycocotools notebook demo file here:
# > https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoDemo.ipynb


'''
# Download the MS MARCO dataset (if needed and if the dictionaries need to be built/rebuilt)
if FORCE_DICTIONARIES_REBUILD or not os.path.exists(DATA_FOLDER + "/docs_dict.json") or not os.path.exists(DATA_FOLDER + "/queries_dict.json"):

	# Load the MS MARCO dataset
	dataset = None
	if USE_DOCUMENTS_DATASETS:
		# Load https://ir-datasets.com/msmarco-passage.html#msmarco-document/dev
		dataset = ir_datasets.load("msmarco-document/dev")
	else:
		# Load https://ir-datasets.com/msmarco-passage.html#msmarco-passage/dev/small
		dataset = ir_datasets.load("msmarco-passage/dev/small")

	# Triggers the download of the datasets (if not already downloaded)
	dataset.docs_iter().__next__()
	dataset.queries_iter().__next__()
	dataset.qrels_iter().__next__()
	dataset.scoreddocs_iter().__next__()

	# Print the dataset structure (i.e. the column names)
	print_metadata = False
	if print_metadata:
		print("Docs Metadata:")
		print_json(dataset.docs_metadata(), 2)
		print("Queries Metadata:")
		print_json(dataset.queries_metadata(), 2)
		print("Qrels Metadata:")
		print_json(dataset.qrels_metadata(), 2)
		print("Scored Docs Metadata:")
		print_json(dataset.scoreddocs_metadata(), 2)

	# Print some samples of the dataset
	print_database_samples = False
	if print_database_samples:
		# Print a sample document
		print("\nSample Document:")
		print("  <doc_id, url, title, body>"
			  if USE_DOCUMENTS_DATASETS
			  else "  <doc_id, text>")
		doc = dataset.docs_iter().__next__()
		print_json(doc, 2)
		# Print a sample query
		print("\nSample Query:")
		print("  <query_id, text>")
		query = dataset.queries_iter().__next__()
		print_json(query, 2)
		# Print a sample qrel
		#   NOTE: the "relevance" and "iteration" fields are always 1 and "0" respectively, for all the qrels (qrels only contain relevant pairs of <query_id, doc_id>)
		print("\nSample Qrel:")
		print("  <query_id, doc_id, relevance, iteration>")
		qrel = dataset.qrels_iter().__next__()
		print_json(qrel, 2)
		# Print a sample scored doc
		print("\nSample Scored Doc:")
		print("  <query_id, doc_id, score>")
		scored_doc = dataset.scoreddocs_iter().__next__()
		print_json(scored_doc, 2)
else:
	# Print a message indicating that the dictionaries already exist and will be loaded
	print("No need to download the MS MARCO dataset.")
	print("Documents and queries dictionaries already exist and will be loaded from the JSON files.")
'''
	

---



<a id="2"></a>
# 💾 Data Preparation

#### Data Variables and Constants

We define the `docs_dict` and `queries_dict` dictionaries used to store the documents and queries data.

The `docs_dict` dictionary contains, for each document ID, the document's text and its Word2Vec embedding.

The `queries_dict` dictionary contains, for each query ID, the query's text, its Word2Vec embedding and also the list of document IDs for documents relevant to the query.

In [None]:
# Dictionaries to store the documents and queries (the main dataset)
docs_dict = {}
queries_dict = {}

# Auxiliary dictionary to map the column names to the corresponding index in the ir_datasets tuples
IR_DATASET_COLS = {
	"DOCS": {"id": 0, "url": 1, "title": 2, "body": 3} if USE_DOCUMENTS_DATASETS else {"id": 0, "text": 1},
	"QUERIES": {"id": 0, "text": 1},
	"QRELS": {"query_id": 0, "doc_id": 1},
	"SCORED_DOCS": {"query_id": 0, "doc_id": 1, "score": 2}
}

<a id="2_1"></a>
## Word2Vec Model Initialization

We initialize the `Word2Vec` model to compute the **vector embeddings** of the documents and queries.

This model is later **trained on the documents corpus** (using the `Gensim` library) to output vector embeddings of size `VECTOR_EMBEDDINGS_SIZE` for documents and queries.

This model is also used as a **baseline** for the evaluation of the final **Seq2Seq** transformer model: we compute the cosine similarity between the embedding of a query and the embeddings of all documents in the database to retrieve the top `K` most relevant documents.

In [None]:
# Initialize a Word2Vec model to encode the text
word2vec_model = models.Word2VecModel(
	embeddings_size=VECTOR_EMBEDDINGS_SIZE,
	words_window_size=10,
	min_word_frequency=0,
	learning_rate=0.025,
	max_epochs=20,
	save_path=MODELS_FOLDER + "/" +
	MODEL_CHECKPOINTS_FILES[MODEL_TYPES.WORD2VEC]
)


def load_or_train_word2vec_model(documents_corpus=None):
	'''
	Train the Word2Vec model if the checkpoint file does not exist, otherwise load the model from the checkpoint file.

	If documents_corpus is None or an empty list, a new documents corpus is built from the documents in the dataset.

	If documents_corpus is provided (as a list of lists of strings, i.e. the words of each document's text), it is used to train the Word2Vec model.
	'''
	loaded_checkpoint = False
	if LOAD_MODELS_CHECKPOINTS:
		loaded_checkpoint = word2vec_model.load()
	if not loaded_checkpoint:
		# Train the Word2Vec model on the documents corpus
		print("Training the Word2Vec model on the documents corpus...")
		if documents_corpus is None or len(documents_corpus) == 0:
			# Build the documents corpus
			documents_corpus = [get_preprocessed_text(
				docs_dict[doc_id]["text"]).split(" ") for doc_id in docs_dict]
		# Train the Word2Vec model
		word2vec_model.train(documents_corpus)
		print("Word2Vec model training completed.")
	else:
		print("Word2Vec model loaded from the checkpoint file.")
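The baseline retrieval step described above (cosine similarity between a query embedding and every document embedding, keeping the top `K`) can be sketched with NumPy as follows. This is an illustrative sketch, not the repository's implementation: it assumes the embeddings are equal-length lists of floats, as stored in `docs_dict[doc_id]["embedding"]`, and the function name is hypothetical.

```python
import numpy as np


def top_k_by_cosine_similarity(query_embedding, doc_embeddings, doc_ids, k):
    """Return the IDs of the k documents whose embeddings are most cosine-similar to the query."""
    query = np.asarray(query_embedding, dtype=np.float64)
    docs = np.asarray(doc_embeddings, dtype=np.float64)  # shape: (n_docs, embedding_size)
    # Normalize to unit length (the epsilon avoids division by zero for all-zero vectors)
    eps = 1e-12
    query = query / (np.linalg.norm(query) + eps)
    docs = docs / (np.linalg.norm(docs, axis=1, keepdims=True) + eps)
    # A single matrix-vector product gives the cosine similarity with every document
    similarities = docs @ query
    k = min(k, len(doc_ids))
    # Sort by decreasing similarity ("stable" keeps the original order among ties)
    top_indices = np.argsort(-similarities, kind="stable")[:k]
    return [doc_ids[i] for i in top_indices]
```

With the dictionaries built in this notebook, the inputs would be `doc_ids = list(docs_dict.keys())` and `doc_embeddings = [docs_dict[d]["embedding"] for d in doc_ids]`. Scoring the whole corpus for every query is exactly the cost that the DSI approach avoids at retrieval time.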

<a id="2_2"></a>

## Dictionaries Creation & Word2Vec Model Training

We build the **documents** and **queries** dictionaries if needed (or if the `FORCE_DICTIONARIES_REBUILD` flag is set to `True`).

Documents and queries **vector embeddings** are also created, using the `Word2Vec` model trained on the documents' corpus.

If the `REMAP_DOC_IDS` flag is set to `True`, document IDs are also **remapped to consecutive integer IDs** (`"0"`, `"1"`, ..., `"n-1"`) to avoid gaps in the ID space.

We ultimately **save the dictionaries** into local files in the `DATA_FOLDER` directory.

In [None]:
# Check if the dictionaries need to be built/rebuilt
if FORCE_DICTIONARIES_REBUILD or not os.path.exists(DATA_FOLDER + "/docs_dict.json") or not os.path.exists(DATA_FOLDER + "/queries_dict.json"):

	print("The documents and queries dictionaries files do not exist, creating them...")

	# Build a queries dictionary, containing the query_id as key, and as values both the query text and a list of associated relevant documents (as doc_id) taken from the scored documents (list of the 1000 relevant documents to the query)
	number_of_queries = MAX_DATASET_QUERIES if 0 < MAX_DATASET_QUERIES < dataset.queries_count() else dataset.queries_count()
	use_scored_docs_for_relevant_documents = True
	for query in tqdm(dataset.queries_iter(), "Building the queries dictionary", number_of_queries):
		if len(queries_dict) >= number_of_queries:
			break
		query_id = query[IR_DATASET_COLS["QUERIES"]["id"]]
		query_text = query[IR_DATASET_COLS["QUERIES"]["text"]]
		queries_dict[query_id] = {
			"text": query_text,
			"embedding": None,
			"relevant_docs": []
		}
	# Add the relevant documents to the queries dictionary
	doc_ids_with_rel = set()
	# First, add the relevant document(s) using the qrels (to ensure the most relevant documents are added first)
	for qrel in tqdm(dataset.qrels_iter(), "Adding relevant documents to queries (using qrels)", dataset.qrels_count()):
		query_id = qrel[IR_DATASET_COLS["QRELS"]["query_id"]]
		if query_id not in queries_dict:
			continue
		doc_id = qrel[IR_DATASET_COLS["QRELS"]["doc_id"]]
		if NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY < 0 or len(queries_dict[query_id]["relevant_docs"]) < NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY:
			if doc_id not in queries_dict[query_id]["relevant_docs"]:
				queries_dict[query_id]["relevant_docs"].append(doc_id)
				doc_ids_with_rel.add(doc_id)
	# Then, add the relevant documents using the scoreddocs (if needed)
	# NOTE: the scoreddocs list contains 1000 relevant documents to the query, unordered and without an associated relevance score (these results are less precise than the qrels)
	if use_scored_docs_for_relevant_documents:
		for scored_doc in tqdm(dataset.scoreddocs_iter(), "Adding relevant documents to queries (using scoreddocs)", dataset.scoreddocs_count()):
			query_id = scored_doc[IR_DATASET_COLS["SCORED_DOCS"]["query_id"]]
			if query_id not in queries_dict:
				continue
			doc_id = scored_doc[IR_DATASET_COLS["SCORED_DOCS"]["doc_id"]]
			if NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY < 0 or len(queries_dict[query_id]["relevant_docs"]) < NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY:
				if doc_id not in queries_dict[query_id]["relevant_docs"]:
					queries_dict[query_id]["relevant_docs"].append(doc_id)
					doc_ids_with_rel.add(doc_id)
		# Fix the missing relevant documents from the queries dictionary (if the relevant documents list was reduced)
		if NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY > 0:
			# Re-add to the relevant documents list of each query all the removed documents that will be added to the documents dataset (i.e. in the doc_ids_with_rel set)
			for scored_doc in tqdm(dataset.scoreddocs_iter(), "Fixing missing relevant documents from queries dictionary", dataset.scoreddocs_count()):
				query_id = scored_doc[IR_DATASET_COLS["SCORED_DOCS"]["query_id"]]
				if query_id not in queries_dict:
					continue
				doc_id = scored_doc[IR_DATASET_COLS["SCORED_DOCS"]["doc_id"]]
				if doc_id in doc_ids_with_rel and doc_id not in queries_dict[query_id]["relevant_docs"]:
					queries_dict[query_id]["relevant_docs"].append(doc_id)
	print(
		f"Total number of documents relevant to at least one query: {len(doc_ids_with_rel)}")

	# Initialize the corpus of documents (to be used to train the Word2Vec model)
	documents_corpus = []

	# Build a documents dictionary, containing the doc_id as key, and the attribute "text" containing the document text
	documents_count = 0
	documents_id_remapping = {}
	documents_id_remapping_inverse = {}
	for doc in tqdm(dataset.docs_iter(), "Building the documents dictionary", dataset.docs_count()):
		# Add the document and its text to the documents dictionary
		doc_id = doc[IR_DATASET_COLS["DOCS"]["id"]]
		if doc_id not in doc_ids_with_rel:
			continue
		doc_text = ""
		if USE_DOCUMENTS_DATASETS:
			doc_text = doc[IR_DATASET_COLS["DOCS"]["title"]] + ".\n" + doc[IR_DATASET_COLS["DOCS"]["body"]]
		else:
			doc_text = doc[IR_DATASET_COLS["DOCS"]["text"]]
		docs_dict[doc_id] = {
			"text": doc_text,
			"embedding": None
		}
		# Compute the remapped doc_id (if needed)
		if REMAP_DOC_IDS:
			new_doc_id = str(documents_count)
			documents_id_remapping[doc_id] = new_doc_id
			documents_id_remapping_inverse[new_doc_id] = doc_id
		# Increment the documents count
		documents_count += 1
		# Add the document text to the corpus
		documents_corpus.append(get_preprocessed_text(doc_text).split(" "))

	# Load or train the Word2Vec model on the documents corpus
	load_or_train_word2vec_model(documents_corpus)

	# Scramble the documents and queries dictionaries
	if REMAP_DOC_IDS:
		print("Scrambling the documents and queries dictionaries...")
		docs_dict = dict(random.sample(docs_dict.items(), len(docs_dict)))
		queries_dict = dict(random.sample(queries_dict.items(), len(queries_dict)))
		print("Scrambled the documents and queries dictionaries.")

	# Iterate over documents in the dictionaries to compute the embeddings (and to eventually remap the doc_ids)
	new_docs_dict = {}
	for doc_id in tqdm(docs_dict, "Computing document embeddings" + (" and remapping doc_ids" if REMAP_DOC_IDS else "")):
		# Compute the embedding of the document text
		docs_dict[doc_id]["embedding"] = word2vec_model.get_embedding((docs_dict[doc_id]["text"]))
		# Remap the doc_id (if needed)
		if REMAP_DOC_IDS:
			new_docs_dict[documents_id_remapping[doc_id]] = {
				"text": docs_dict[doc_id]["text"],
				"embedding": docs_dict[doc_id]["embedding"]
			}
	if REMAP_DOC_IDS:
		docs_dict = new_docs_dict
	# Iterate over queries in the dictionary to compute the embeddings (and to eventually remap the relevant doc_ids)
	for query_id in tqdm(queries_dict, "Computing query embeddings" + (" and remapping relevant doc_ids" if REMAP_DOC_IDS else "")):
		# Compute the embedding of the query text
		queries_dict[query_id]["embedding"] = word2vec_model.get_embedding(queries_dict[query_id]["text"])
		# Remap the relevant documents (if needed)
		if REMAP_DOC_IDS:
			current_relevant_docs = queries_dict[query_id]["relevant_docs"]
			queries_dict[query_id]["relevant_docs"] = [
				documents_id_remapping[doc_id] for doc_id in current_relevant_docs]

	# Print the total number of documents and queries
	print(f"Total number of documents (in built dict): {len(docs_dict)}")
	print(f"Total number of queries (in built dict): {len(queries_dict)}")

	# Save the 2 dictionaries to 2 JSON files in the "data" directory
	print("Saving the documents and queries dictionaries to the JSON files...")
	with open(DATA_FOLDER + "/docs_dict.json", "w") as docs_dict_file:
		json.dump(docs_dict, docs_dict_file, indent=2)
	with open(DATA_FOLDER + "/queries_dict.json", "w") as queries_dict_file:
		json.dump(queries_dict, queries_dict_file, indent=2)
	print("Created the documents and queries dictionaries and saved them to the files.\n")
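The `REMAP_DOC_IDS` logic in the cell above boils down to building a forward and an inverse map between the original IDs and consecutive string IDs, then rewriting each query's relevant-document list. A minimal standalone sketch of that step follows; the helper names are hypothetical and it assumes the original IDs are unique.

```python
from typing import Dict, Iterable, List, Tuple


def build_doc_id_remapping(original_ids: Iterable[str]) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Assign consecutive string IDs "0", "1", ... in iteration order (original IDs assumed unique)."""
    forward: Dict[str, str] = {}
    inverse: Dict[str, str] = {}
    for new_index, original_id in enumerate(original_ids):
        new_id = str(new_index)
        forward[original_id] = new_id   # original ID -> compact ID
        inverse[new_id] = original_id   # compact ID -> original ID
    return forward, inverse


def remap_relevant_docs(relevant_docs: List[str], forward: Dict[str, str]) -> List[str]:
    """Rewrite a query's list of relevant documents using the compact IDs."""
    return [forward[doc_id] for doc_id in relevant_docs]
```

Compact IDs keep the decoder's target sequences short: with `n` documents, every ID needs at most `len(str(n - 1))` digits, regardless of how long the original dataset IDs are.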

<a id="2_3"></a>

## Dictionaries Loading

We load the `documents` and `queries` dictionaries from the local files in the `DATA_FOLDER` directory into the corresponding dictionary variables.

In [None]:
# Load the documents and queries dictionaries from the JSON files
print("Loading the documents and queries dictionaries from the files...")
with open(DATA_FOLDER + "/docs_dict.json", "r") as docs_dict_file:
	docs_dict = json.load(docs_dict_file)
print(f"  Loaded {len(docs_dict)} documents")
with open(DATA_FOLDER + "/queries_dict.json", "r") as queries_dict_file:
	queries_dict = json.load(queries_dict_file)
print(f"  Loaded {len(queries_dict)} queries")

---

<a id="6"></a>
# 🤖 Vision Transformer Model Training

In this section we implement 3 possible versions of a **Seq2Seq transformer model** to act as the final model for the evaluation of the **Differentiable Search Index** approach.

The **Seq2Seq transformer models** are trained to generate a **sorted list of document IDs** in response to a given **query**.

The 3 transformer models, described in detail in the sub-sections below, are:

1. **Seq2Seq Transformer Model using _teacher forcing_**: A Seq2Seq transformer model trained using only the **teacher forcing** technique; no auto-regressive decoding is used during training.

   At inference time, instead, the model uses an **auto-regressive decoding** technique to generate the sorted list of document IDs.

2. **Seq2Seq Transformer Model using _auto-regressive decoding_**: A Seq2Seq transformer model trained using only the **auto-regressive decoding** technique; no teacher forcing is used during training.

   This model also uses the same **auto-regressive decoding** technique at inference time to generate the sorted list of document IDs.

3. **Seq2Seq Transformer Model using _scheduled sampling_**: A Seq2Seq transformer model trained using the **scheduled sampling** technique, which trains the model on a mix of teacher forcing and auto-regressive decoding: during training, each decoder input token is taken either from the ground truth or from the model's own predictions, with a probability controlled by the `scheduled_sampling_decay` hyperparameter.

   Once again, the model uses only the **auto-regressive decoding** technique at inference time to generate the sorted list of document IDs.
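The core of scheduled sampling is the per-token choice between the ground-truth token and the model's own prediction. The sketch below shows only that choice. It assumes an exponential schedule in which the probability of feeding the ground truth is `decay ** epoch`; this is one plausible reading of `scheduled_sampling_decay`, not necessarily how the repository uses it. In the Mihaylova and Martins formulation, the predictions come from a first decoder pass run with teacher forcing, and the mixed sequence is fed to a second pass. The function names are hypothetical.

```python
import random
from typing import List, Sequence


def teacher_forcing_probability(epoch: int, decay: float) -> float:
    """Probability of feeding the ground-truth token (assumed exponential decay schedule)."""
    if not 0.0 < decay <= 1.0:
        raise ValueError("decay must be in (0, 1]")
    return decay ** epoch


def mix_decoder_inputs(gold_tokens: Sequence[int], predicted_tokens: Sequence[int],
                       p_gold: float, rng: random.Random) -> List[int]:
    """Per position, keep the gold token with probability p_gold, otherwise use the model's prediction."""
    if len(gold_tokens) != len(predicted_tokens):
        raise ValueError("gold and predicted sequences must have the same length")
    return [gold if rng.random() < p_gold else predicted
            for gold, predicted in zip(gold_tokens, predicted_tokens)]
```

The two extremes recover the other training modes: `p_gold = 1.0` is pure teacher forcing, while `p_gold = 0.0` always feeds back the model's own predictions, as auto-regressive training does. The schedule moves gradually from the first regime toward the second as training progresses.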

In [None]:
# Compute the max length of the document IDs
if REMAP_DOC_IDS:
	# Doc IDs are remapped to the range [0, n_docs-1], so the max length is the number of digits of the largest ID (n_docs - 1)
	doc_ids_max_length = len(str(max(len(docs_dict) - 1, 0)))
else:
	# We calculate the max length of the doc IDs as the length of the longest doc ID
	doc_ids_max_length = max([len(doc_id) for doc_id in docs_dict])

# Number of output tokens for the encoded document IDs (the 10 digits [0-9] plus the special tokens, i.e. end of sequence, padding, start of sequence)
output_tokens = 10 + 3

#### Transformer Datasets Creation

We create the **datasets** to train and evaluate the **Seq2Seq transformer models** using the `Seq2SeqDataset` class.

Two different datasets are created:

- **Indexing Dataset**: A dataset to train the model for the **indexing task**, in which the model learns to generate document IDs starting from **documents' text embeddings** as source sequences.

   Items of the dataset have the form **`(encoded_document, encoded_doc_id)`**. Here, `encoded_document` is the tokenized version of the document's text (i.e. a **vector of word token IDs** in the tokenizer's vocabulary), computed using a pretrained **BERT** tokenizer, while `encoded_doc_id` is the tokenized version of the document's ID in the documents dictionary. The latter is computed using an ad-hoc tokenizer that maps each digit of the document ID to an index (equal to the digit itself) and adds a special padding token, a special start-of-sequence token and a special end-of-sequence token.

- **Retrieval Dataset**: A dataset to train the model for the **retrieval task**, in which the model learns to generate document IDs starting from **queries' text embeddings** as source sequences.

   Items of the dataset have the form **`(encoded_query, encoded_doc_id)`**, where `encoded_query` is the tokenized version of the query's text, computed using the same **BERT** tokenizer, and `encoded_doc_id` is the tokenized version of the relevant document's ID in the documents dictionary, computed using the same ad-hoc tokenizer used for the **Indexing Dataset**.

Both datasets are shared among the 3 transformer models for training and evaluation.

In [None]:
# Get the datasets for the transformer model (datasets are shared between the 3 transformer models)
transformer_indexing_dataset = datasets.TransformerIndexingDataset(
	documents=docs_dict,
	doc_id_max_length=doc_ids_max_length,
	doc_max_length=TRANSFORMER_DOCUMENT_MAX_TOKENS,
	dataset_file_path=DATA_FOLDER + "/transformer_indexing_dataset.json",
	force_dataset_rebuild=FORCE_DICTIONARIES_REBUILD)
transformer_retrieval_dataset = datasets.TransformerRetrievalDataset(
	documents=docs_dict, queries=queries_dict,
	doc_id_max_length=doc_ids_max_length,
	query_max_length=TRANSFORMER_QUERY_MAX_TOKENS,
	dataset_file_path=DATA_FOLDER + "/transformer_retrieval_dataset.json",
	force_dataset_rebuild=FORCE_DICTIONARIES_REBUILD)
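The ad-hoc document ID tokenizer described above can be sketched as follows: each digit maps to the token with the same index, and three special tokens bring the vocabulary to `10 + 3` entries, matching `output_tokens`. The concrete indices `PAD_TOKEN = 10`, `SOS_TOKEN = 11` and `EOS_TOKEN = 12` are hypothetical, since the actual assignment lives in the repository's `datasets` module. The sketch also includes the digit-count rule used for `doc_ids_max_length` when IDs are remapped.

```python
from typing import List

# Hypothetical special-token indices (the repository's actual assignment may differ)
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = 10, 11, 12
OUTPUT_TOKENS = 10 + 3  # digits 0-9 plus the 3 special tokens
DIGITS = "0123456789"


def doc_id_max_length(n_docs: int) -> int:
    """Number of digits needed for remapped IDs in the range [0, n_docs - 1]."""
    return len(str(max(n_docs - 1, 0)))


def encode_doc_id(doc_id: str, max_length: int) -> List[int]:
    """SOS + one token per digit (token index == digit) + EOS, right-padded with PAD."""
    if not doc_id or any(c not in DIGITS for c in doc_id):
        raise ValueError(f"doc ID must be a non-empty string of ASCII digits: {doc_id!r}")
    if len(doc_id) > max_length:
        raise ValueError(f"doc ID {doc_id!r} is longer than max_length={max_length}")
    tokens = [SOS_TOKEN] + [int(c) for c in doc_id] + [EOS_TOKEN]
    # Every encoded ID has the same length: max_length digits plus SOS and EOS
    return tokens + [PAD_TOKEN] * (max_length + 2 - len(tokens))


def decode_doc_id(tokens: List[int]) -> str:
    """Inverse of encode_doc_id: read digits until EOS or PAD."""
    digits = []
    for token in tokens:
        if token == SOS_TOKEN:
            continue
        if token in (EOS_TOKEN, PAD_TOKEN):
            break
        digits.append(str(token))
    return "".join(digits)
```

For instance, with 1,000 remapped documents the largest ID is `"999"`, so three digit positions suffice and `"42"` encodes to `[11, 4, 2, 12, 10]`.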

We print some **examples** of the **Indexing Dataset** and the **Retrieval Dataset** to visualize the data.

In [None]:
# Print some examples of the Transformers datasets
print_dataset_examples = True
if print_dataset_examples:
	print("Example of an <encoded_doc, encoded_doc_id> pair:")
	encoded_doc, encoded_doc_id = transformer_indexing_dataset[random.randint(
		0, len(transformer_indexing_dataset) - 1)]
	print("  Encoded document:\n  ", encoded_doc)
	print("  Encoded document ID:\n  ", encoded_doc_id)
	print("Example of an <encoded_query, encoded_doc_id> pair:")
	encoded_query, encoded_relevant_doc_id = transformer_retrieval_dataset[random.randint(
		0, len(transformer_retrieval_dataset) - 1)]
	print("  Encoded query:\n  ", encoded_query)
	print("  Encoded relevant document ID:\n  ", encoded_relevant_doc_id)

#### Transformer Models Initialization & Training

We create an auxiliary function to **initialize, train and evaluate** the **3 Seq2Seq transformer models**.

The function uses the same constants and hyperparameters for all versions of the Transformer model, with the exception of the training parameters.

Each of the 3 Transformer models uses a different training approach (as explained in the introduction of this section); for both training phases, the training metrics (loss, accuracy, etc.) are then plotted using the **Weights & Biases logger**.

Note that at inference time, to retrieve the top `K` most relevant documents, all three Transformer models use the same **auto-regressive decoding** technique: the tokens of a document ID are generated one at a time, and the generation of each token is conditioned on the previously generated ones.
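Because each decoding step only yields scores for the next token, some search procedure is needed to obtain a *ranked list* of `K` document IDs. One common choice is beam search. The minimal, framework-free sketch below assumes a hypothetical `next_token_logits(prefix)` callable standing in for one decoder step, fixed-length document IDs, and no constraint restricting the output to IDs that actually exist in the corpus. The decoding strategy actually implemented in `models.DSITransformer` may differ.

```python
import math
from typing import Callable, List, Sequence, Tuple

Beam = Tuple[Tuple[int, ...], float]  # (generated token IDs, cumulative log-probability)


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw logits into log-probabilities in a numerically stable way."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def beam_search_doc_ids(
    next_token_logits: Callable[[Tuple[int, ...]], Sequence[float]],
    id_length: int,
    k: int,
) -> List[Beam]:
    """Return the k highest-scoring token sequences of length `id_length`.

    `next_token_logits` is a hypothetical stand-in for one decoder step: given the
    tokens generated so far, it returns one logit per token in the vocabulary.
    """
    beams: List[Beam] = [((), 0.0)]
    for _ in range(id_length):
        candidates: List[Beam] = []
        for prefix, score in beams:
            # Raw logits carry an arbitrary per-step offset, so turn them into
            # log-probabilities before summing them along the sequence
            for token, log_p in enumerate(log_softmax(next_token_logits(prefix))):
                candidates.append((prefix + (token,), score + log_p))
        # Keep only the k best partial document IDs for the next step
        candidates.sort(key=lambda beam: beam[1], reverse=True)
        beams = candidates[:k]
    return beams


def toy_logits(prefix: Tuple[int, ...]) -> List[float]:
    """Hypothetical 4-token vocabulary in which token 1, then token 2, are preferred."""
    return [0.0, 3.0, 2.0, 0.0]


for doc_id_tokens, log_prob in beam_search_doc_ids(toy_logits, id_length=2, k=3):
    print(doc_id_tokens, round(log_prob, 3))
```

Keeping only `k` hypotheses per step means at most `id_length × k` decoder calls per query, a cost that does not depend on how many documents the corpus contains.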

For each model, we first train for the **indexing task**, using the `TransformerIndexingDataset` dataset, then train for the **retrieval task**, using the `TransformerRetrievalDataset` dataset (defined above).

Before training each model for the retrieval task, the retrieval dataset is split into a **training**, **validation** and **test** set (90%, 5% and 5% of the samples, as set by `retrieval_split_ratios`). The test set is then used to evaluate the models' performance, i.e. to compute the **Mean Average Precision** and the **Recall at K**.
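A minimal sketch of a shuffled 90/5/5 split, assuming `retrieval_split_ratios=(0.9, 0.05, 0.05)` is read in (train, validation, test) order. `split_dataset` is a hypothetical helper; `training.train_transformer` may split the data differently (for instance with PyTorch's `torch.utils.data.random_split`).

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_dataset(
    samples: Sequence[T],
    ratios: Tuple[float, float, float] = (0.9, 0.05, 0.05),
    seed: int = 42,
) -> Tuple[List[T], List[T], List[T]]:
    """Shuffle `samples` and split them into train, validation and test lists."""
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("split ratios must sum to 1")
    indices = list(range(len(samples)))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    n_train = int(ratios[0] * len(samples))
    n_val = int(ratios[1] * len(samples))
    train = [samples[i] for i in indices[:n_train]]
    val = [samples[i] for i in indices[n_train:n_train + n_val]]
    test = [samples[i] for i in indices[n_train + n_val:]]  # rounding remainder goes to test
    return train, val, test


train, val, test = split_dataset(list(range(1000)))
print(len(train), len(val), len(test))  # 900 50 50
```

Fixing the seed matters here because the test set is saved to JSON and reused whenever a model checkpoint is loaded, so the evaluation must always see the same held-out queries.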

To **compute the evaluation metrics**, we use an approach similar to the one used for the **Word2Vec** and **Siamese Network** models, with one important difference: although **training takes longer** than for those baseline models, the **retrieval phase is significantly faster**. The model directly outputs the IDs of the documents relevant to the input query, so there is no need to compute the cosine similarity between the query and every document in the corpus to find the top `K` most relevant documents.
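For contrast, the retrieval step of the embedding-based baselines can be sketched as a brute-force scan: every document vector in the corpus is scored against the query vector before the top `K` can be selected. This is an illustrative sketch with hypothetical helpers (`cosine`, `top_k_by_cosine`) and toy vectors, not the notebook's `evaluation` code.

```python
import heapq
import math
from typing import Dict, List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two dense vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k_by_cosine(query_vec: Sequence[float], doc_vecs: Dict[str, Sequence[float]], k: int) -> List[str]:
    """Return the IDs of the k documents most similar to the query."""
    # One similarity computation per document: the cost grows linearly with corpus size
    return heapq.nlargest(k, doc_vecs, key=lambda doc_id: cosine(query_vec, doc_vecs[doc_id]))


docs = {"d1": [1.0, 0.0], "d2": [0.0, 1.0], "d3": [1.0, 1.0]}
print(top_k_by_cosine([1.0, 0.1], docs, k=2))  # ['d1', 'd3']
```

The DSI decoder replaces this per-document loop with a fixed number of decoding steps over the document ID vocabulary. The price is paid at training time instead, when the model has to learn the mapping from documents to their IDs during the indexing phase.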

In [None]:
def train_and_evaluate_dsi_transformer(transformer_type):
	''' Auxiliary function to train (or load checkpoints), show training results, and evaluate the transformer model of the given type '''

	# args to pass to the dsi transformer model
	use_scheduled_sampling_decay = transformer_type == models.DSITransformer.TRANSFORMER_TYPES.SCHEDULED_SAMPLING_TRANSFORMER
	dsi_transformer_args = {
		"tokens_in_vocabulary": transformer_indexing_dataset.tokenizer.vocab_size,
		"embeddings_size": TRANSFORMER_EMBEDDINGS_SIZE,
		"target_tokens": output_tokens,
		"transformer_heads": 4,
		"layers": 3,
		"dropout": 0.2,
		"learning_rate": 0.001,
		"batch_size": 512,
		"transformer_type": transformer_type,
		"scheduled_sampling_decay": 0.01 if use_scheduled_sampling_decay else 0.0
	}

	# Initialize the transformer model of the requested type
	transformer_model = models.DSITransformer(
		**dsi_transformer_args)

	# Model's checkpoint path
	model_type_string = ""
	if transformer_type == models.DSITransformer.TRANSFORMER_TYPES.SCHEDULED_SAMPLING_TRANSFORMER:
		model_type_string = "scheduled_sampling"
	elif transformer_type == models.DSITransformer.TRANSFORMER_TYPES.AUTOREGRESSIVE_TRANSFORMER:
		model_type_string = "autoregressive"
	elif transformer_type == models.DSITransformer.TRANSFORMER_TYPES.TEACHER_FORCINIG_TRANSFORMER:
		model_type_string = "teacher_forcing"
	else:
		raise ValueError(
			f"Invalid transformer type: {transformer_type}")
	model_checkpoint_file = MODELS_FOLDER + "/" + model_type_string + "_" + MODEL_CHECKPOINT_FILE

	# Train the model or load its saved checkpoint
	transformer_retrieval_test_set = None
	transformer_retrieval_test_set_file = DATA_FOLDER + f"/{model_type_string}_transformer_retrieval_test_set.json"
	if LOAD_MODELS_CHECKPOINTS and os.path.exists(model_checkpoint_file):
		# Load the saved models checkpoint
		print("A checkpoint for the model exists, loading the saved model checkpoint...")
		transformer_model = models.DSITransformer.load_from_checkpoint(
			model_checkpoint_file, **dsi_transformer_args)
		print("Model checkpoint loaded.")
		# Load the transformer retrieval test set from the JSON file
		print("Loading the transformer retrieval test set from the JSON file...")
		with open(transformer_retrieval_test_set_file, "r") as test_set_file:
			transformer_retrieval_test_set = json.load(test_set_file)
		print("Transformer retrieval test set loaded.")
	else:
		# Create 2 loggers for the transformer model (one for the indexing task and one for the retrieval task)
		transformer_loggers = None
		if wandb_api is not None:
			transformer_wandb_logger_indexing = WandbLogger(
				log_model="all", project=wandb_project, name=transformer_type + " (Indexing)")
			transformer_wandb_logger_retrieval = WandbLogger(
				log_model="all", project=wandb_project, name=transformer_type + " (Retrieval)")
			transformer_loggers = [transformer_wandb_logger_indexing,
								   transformer_wandb_logger_retrieval]
		# Train the transformer model, first for the indexing task and then for the retrieval task
		transformer_training_infos = training.train_transformer(
			transformer_indexing_dataset=transformer_indexing_dataset,
			transformer_retrieval_dataset=transformer_retrieval_dataset,
			transformer_model=transformer_model,
			max_epochs_list=[250, 150],
			batch_size=transformer_model.hparams.batch_size,
			indexing_split_ratios=(1.0, 0.0),
			retrieval_split_ratios=(0.9, 0.05, 0.05),
			logger=transformer_loggers,
			save_path=model_checkpoint_file
		)
		# Show the wandb training run's dashboard
		if wandb_api is not None:
			indexing_run_id = transformer_training_infos["run_ids"]["indexing"]
			if indexing_run_id is not None:
				print(f"Indexing training results for the {transformer_type} model:")
				indexing_run_object: wandb_run.Run = wandb_api.run(
					f"{wandb_entity}/{wandb_project}/{indexing_run_id}")
				indexing_run_object.display(height=1000)
			retrieval_run_id = transformer_training_infos["run_ids"]["retrieval"]
			if retrieval_run_id is not None:
				print(f"Retrieval training results for the {transformer_type} model:")
				retrieval_run_object: wandb_run.Run = wandb_api.run(
					f"{wandb_entity}/{wandb_project}/{retrieval_run_id}")
				retrieval_run_object.display(height=1000)
		# Save the generated transformer retrieval test set to the JSON file
		print("Saving the transformer retrieval test set to the JSON file...")
		retrieval_test_dataset = transformer_training_infos["retrieval"]["test"]
		transformer_retrieval_test_set = {
			"encoded_queries": [],
			"encoded_doc_ids": []
		}
		for i in range(len(retrieval_test_dataset)):
			encoded_query, doc_id = retrieval_test_dataset[i]
			transformer_retrieval_test_set["encoded_queries"].append(
				encoded_query.tolist())
			transformer_retrieval_test_set["encoded_doc_ids"].append(
				doc_id.tolist())
		with open(transformer_retrieval_test_set_file, "w") as test_set_file:
			json.dump(transformer_retrieval_test_set, test_set_file)

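	# Evaluate the transformer model (for the retrieval task)
	# Default to None so that the return statement still works when evaluation is disabled
	transformer_retrieval_map_k, transformer_retrieval_recall_k = None, None
	if EVALUATE_MODELS: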
		transformer_retrieval_map_k = evaluation.compute_mean_average_precision_at_k(
			MODEL_TYPES.DSI_TRANSFORMER, queries_dict, docs_dict,
			k_documents=MAP_K, n_queries=MAP_N,
			print_debug=PRINT_EVALUATION_DEBUG,
			# Keyword arguments for the Transformer model
			model=transformer_model, retrieval_dataset=transformer_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set)
		transformer_retrieval_recall_k = evaluation.compute_recall_at_k(
			MODEL_TYPES.DSI_TRANSFORMER, queries_dict, docs_dict,
			k_documents=RECALL_K,
			print_debug=PRINT_EVALUATION_DEBUG,
			# Keyword arguments for the Transformer model
			model=transformer_model, retrieval_dataset=transformer_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set)
		print_model_evaluation_results(transformer_retrieval_map_k,
									   transformer_retrieval_recall_k)

	return transformer_model, transformer_retrieval_map_k, transformer_retrieval_recall_k

<a id="6_1"></a>

## Teacher Forcing Seq2Seq Transformer Model

The first version of the **Seq2Seq transformer model** is trained using only the **teacher forcing** technique: no auto-regressive decoding is used during training.

This means that, during training, the decoder is fed the **ground truth** document ID tokens as input, so the model is trained to predict each token of the correct document ID given the input (the document during indexing, the query during retrieval) and the ground-truth tokens that precede it.

At inference time, the model uses the usual **auto-regressive decoding** technique to generate token logits (rather than probabilities, since no softmax is applied to the output of the Transformer model) for all possible document ID tokens.
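Teacher forcing can be pictured as feeding the decoder the ground-truth document ID shifted right by one position, so that every prediction is conditioned on the correct prefix regardless of what the model predicted earlier. A minimal sketch, assuming a hypothetical `start_token` that opens every decoder input (the notebook's tokenizer may use a different convention):

```python
from typing import List, Tuple


def teacher_forcing_pairs(target_ids: List[int], start_token: int) -> List[Tuple[List[int], int]]:
    """Pair each ground-truth prefix fed to the decoder with the token it must predict."""
    decoder_input = [start_token] + target_ids[:-1]  # targets shifted right by one position
    return [(decoder_input[: t + 1], target_ids[t]) for t in range(len(target_ids))]


for prefix, target in teacher_forcing_pairs([4, 2, 7], start_token=0):
    print(prefix, "->", target)
# [0] -> 4
# [0, 4] -> 2
# [0, 4, 2] -> 7
```

In an actual Transformer all positions are processed in a single forward pass with a causal attention mask rather than an explicit loop; the pairs above only make the conditioning explicit.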

We **train the model** and then, if a `WANDB_API_KEY` was provided, we also load the **Weights & Biases** dashboard to **plot the training results** for both the indexing and retrieval tasks (in this order).

After training, we evaluate the Transformer model by computing the usual **Mean Average Precision** (over `MAP_N` queries, considering the top `MAP_K` retrieved documents for each query) and the **Recall at K** (with **K** defined by `RECALL_K`) for the **teacher forcing Seq2Seq transformer model**.
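For reference, one common way to compute these metrics is sketched below: `k` plays the role of `MAP_K` (or `RECALL_K` for recall), and the mean is taken over the evaluated queries (`MAP_N` in the notebook). The helpers are hypothetical stand-ins for `evaluation.compute_mean_average_precision_at_k` and `evaluation.compute_recall_at_k`, whose exact normalization may differ; here average precision is divided by `min(k, number of relevant documents)`.

```python
from typing import List, Sequence, Set


def average_precision_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """Sum of the precision values at each rank (up to k) holding a relevant document,
    divided by min(k, number of relevant documents)."""
    hits, precision_sum = 0, 0.0
    for rank, doc_id in enumerate(ranked[:k], start=1):
        if doc_id in relevant:
            hits += 1
            precision_sum += hits / rank  # precision at this cut-off
    denominator = min(k, len(relevant))
    return precision_sum / denominator if denominator else 0.0


def mean_average_precision_at_k(rankings: List[Sequence[str]], relevants: List[Set[str]], k: int) -> float:
    """Mean of the per-query average precision values."""
    scores = [average_precision_at_k(r, rel, k) for r, rel in zip(rankings, relevants)]
    return sum(scores) / len(scores) if scores else 0.0


def recall_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    """Fraction of the relevant documents that appear among the top k results."""
    if not relevant:
        return 0.0
    return len(set(ranked[:k]) & relevant) / len(relevant)


rankings = [["d3", "d1", "d7"], ["d2", "d5", "d9"]]
relevants = [{"d1"}, {"d2", "d9"}]
print(mean_average_precision_at_k(rankings, relevants, k=3))  # (0.5 + 5/6) / 2, about 0.667
print(recall_at_k(rankings[0], relevants[0], k=1))  # 0.0
```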

We then **print the results** of the evaluation of both metrics.

In [None]:
# Train and evaluate the transformer model using only teacher forcing
teacher_forcing_transformer, teacher_forcing_transformer_map_k, teacher_forcing_transformer_recall_k = \
	train_and_evaluate_dsi_transformer(models.DSITransformer.TRANSFORMER_TYPES.TEACHER_FORCINIG_TRANSFORMER)

<a id="6_2"></a>

## Autoregressive Seq2Seq Transformer Model

The second version of the **Seq2Seq transformer model** is trained using **auto-regressive decoding**: no teacher forcing is used during training.

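This means that, during training, the decoder input at each step is the model's own previously predicted token rather than the ground-truth one, so the model learns to generate the correct document IDs while relying only on its own predictions (the ground-truth document IDs are still used as targets for the loss).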

The same **auto-regressive decoding** technique is also used at inference time to generate token logits (rather than probabilities, since no softmax is applied to the output of the Transformer model) for all possible document ID tokens.

We **train the model** and then, if a `WANDB_API_KEY` was provided, we also load the **Weights & Biases** dashboard to **plot the training results** for both the indexing and retrieval tasks (in this order).

After training, we evaluate the Transformer model by computing the usual **Mean Average Precision** (over `MAP_N` queries, considering the top `MAP_K` retrieved documents for each query) and the **Recall at K** (with **K** defined by `RECALL_K`) for the **autoregressive Seq2Seq transformer model**.

We then **print the results** of the evaluation of both metrics.

In [None]:
# Train and evaluate the transformer model using only an autoregressive approach
autoregressive_transformer, autoregressive_transformer_map_k, autoregressive_transformer_recall_k = \
	train_and_evaluate_dsi_transformer(models.DSITransformer.TRANSFORMER_TYPES.AUTOREGRESSIVE_TRANSFORMER)

<a id="6_3"></a>

## Scheduled Sampling Seq2Seq Transformer Model

The final version of the **Seq2Seq transformer model** is trained using the **scheduled sampling** technique, which mixes teacher forcing and auto-regressive decoding during training: each decoder input token is taken either from the ground truth or from the model's own predictions, according to a certain probability. This probability is initially set to 1.0 and then decreases linearly after each training epoch, by the amount defined by the `scheduled_sampling_decay` hyperparameter (0.01 in this notebook).

During training, at each new token generation, a random draw based on the current probability decides whether the ground-truth token (teacher forcing) or the previously generated token (auto-regression) is used as input for the next token generation.
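The two mechanisms described above (the decaying probability and the per-token random choice) can be sketched together as follows. This is a framework-free illustration, not the code in `models.DSITransformer`: it assumes the decay is subtracted once per completed epoch and clipped at zero, and it uses a hypothetical `predict_next` callable that returns the model's most likely next token. Under that assumption, a decay of 0.01 would phase out teacher forcing entirely after 100 epochs.

```python
import random
from typing import Callable, List, Sequence, Tuple


def teacher_forcing_probability(epoch: int, decay: float = 0.01, start: float = 1.0) -> float:
    """Probability of feeding a ground-truth token after `epoch` completed epochs (linear decay)."""
    return max(0.0, start - decay * epoch)


def scheduled_sampling_inputs(
    ground_truth: Sequence[int],
    predict_next: Callable[[Tuple[int, ...]], int],
    p_teacher: float,
    start_token: int,
    rng: random.Random,
) -> Tuple[List[int], List[int]]:
    """Build decoder inputs token by token; return (decoder_inputs, predictions).

    A loss would compare `predictions` (logits, in a real model) with `ground_truth`.
    """
    decoder_inputs = [start_token]
    predictions: List[int] = []
    for t in range(len(ground_truth)):
        prediction = predict_next(tuple(decoder_inputs))  # model's guess for position t
        predictions.append(prediction)
        if t + 1 < len(ground_truth):
            # Next input: ground truth with probability p_teacher, else the model's own guess
            use_ground_truth = rng.random() < p_teacher
            decoder_inputs.append(ground_truth[t] if use_ground_truth else prediction)
    return decoder_inputs, predictions


def always_nine(prefix: Tuple[int, ...]) -> int:
    """Hypothetical model that always predicts token 9."""
    return 9


rng = random.Random(0)
# p_teacher = 1.0 reproduces teacher forcing; p_teacher = 0.0 is purely autoregressive training
print(scheduled_sampling_inputs([4, 2, 7], always_nine, 1.0, 0, rng))  # ([0, 4, 2], [9, 9, 9])
print(scheduled_sampling_inputs([4, 2, 7], always_nine, 0.0, 0, rng))  # ([0, 9, 9], [9, 9, 9])
print(teacher_forcing_probability(50))  # 0.5
```

The two extremes of `p_teacher` correspond to the two previous models, which is why scheduled sampling can be seen as a gradual transition from the teacher forcing setup to the autoregressive one.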

At inference time, only the **auto-regressive decoding** approach is used to generate token logits (rather than probabilities, since no softmax is applied to the output of the Transformer model) for all possible document ID tokens.

We **train the model** and then, if a `WANDB_API_KEY` was provided, we also load the **Weights & Biases** dashboard to **plot the training results** for both the indexing and retrieval tasks (in this order).

After training, we evaluate the Transformer model by computing the usual **Mean Average Precision** (over `MAP_N` queries, considering the top `MAP_K` retrieved documents for each query) and the **Recall at K** (with **K** defined by `RECALL_K`) for the **scheduled sampling Seq2Seq transformer model**.

We then **print the results** of the evaluation of both metrics.

In [None]:
# Train and evaluate the transformer model using scheduled sampling
scheduled_sampling_transformer, scheduled_sampling_transformer_map_k, scheduled_sampling_transformer_recall_k = \
	train_and_evaluate_dsi_transformer(models.DSITransformer.TRANSFORMER_TYPES.SCHEDULED_SAMPLING_TRANSFORMER)