# <a id='toc1_'></a>[Neural Inverted Index for Fast and Effective Information Retrieval](#toc0_)

---

## <a id='toc1_1_'></a>[📚 Notebook Overview](#toc0_)

This notebook explores a novel [information retrieval (IR)](https://en.wikipedia.org/wiki/Information_retrieval) framework that utilizes a **differentiable function** to generate a **sorted list of document identifiers** in response to a given **query**.

The approach is called **Differentiable Search Index (DSI)**, and was originally proposed in the paper [Transformer Memory as a Differentiable Search Index](https://arxiv.org/pdf/2202.06991.pdf) by researchers at Google Research.

**DSI** aims to both encode all of the corpus information and perform retrieval within a single **Transformer language model**, instead of adopting the index-then-retrieve pipeline used in most modern IR systems.

The notebook presents the implemented solution, a **Sequence to Sequence Transformer** model `f` that, given a query `q` as input, returns a list of document IDs ranked by relevance to the query, and compares its performance with the traditional **TF-IDF** retrieval model, a **Word2Vec** model, and a **Siamese Network model with Triplet Loss**.

The proposed solution combines the **DSI** approach with the **Scheduled Sampling** technique for Transformers, inspired by the similar technique described in the paper [Scheduled Sampling for Transformers](https://arxiv.org/abs/1906.07651).

We evaluate the performance of the proposed models using the **Mean Average Precision (MAP)** and the **Recall at K** metrics computed on the **MS MARCO** dataset, and we compare the results with several baselines (**TF-IDF**, **Word2Vec**, **Siamese Network** and also other traditional **Transformer** approaches).

## <a id='toc1_2_'></a>[📝 Author](#toc0_)

**Valerio Di Stefano** - _"Sapienza" University of Rome_
<br/>
Email: [distefano.1898728@studenti.uniroma1.it](mailto:distefano.1898728@studenti.uniroma1.it)

## <a id='toc1_3_'></a>[🔗 External Links](#toc0_)

* **Main Paper**: [Transformer Memory as a Differentiable Search Index](https://arxiv.org/pdf/2202.06991.pdf)

  _Authors_: Yi Tay, Vinh Q. Tran, Mostafa Dehghani, Jianmo Ni, Dara Bahri, Harsh Mehta, Zhen Qin, Kai Hui, Zhe Zhao, Jai Gupta, Tal Schuster, William W. Cohen, Donald Metzler

* **Relevant Paper**: [Understanding Differential Search Index for Text Retrieval](https://arxiv.org/abs/2305.02073)

  _Authors_: Xiaoyang Chen, Yanjiang Liu, Ben He, Le Sun, Yingfei Sun

* **Relevant Paper**: [Scheduled Sampling for Transformers](https://arxiv.org/abs/1906.07651)

  _Authors_: Tsvetomila Mihaylova, André F. T. Martins

* **Project Repository**: [GitHub Repository](https://github.com/valeriodiste/deep_learning_project)


---


## <a id='toc1_4_'></a>[📌 Table of Contents](#toc0_)

**Table of contents**<a id='toc0_'></a>    

- [Neural Inverted Index for Fast and Effective Information Retrieval](#toc1_)    
  - [📚 Notebook Overview](#toc1_1_)    
  - [📝 Author](#toc1_2_)    
  - [🔗 External Links](#toc1_3_)    
  - [📌 Table of Contents](#toc1_4_)    

- [🚀 Getting Started](#toc2_)    
  - [Collect Source Files](#toc2_1_)    
  - [Install & Import Libraries](#toc2_2_)    
  - [Configuration, Hyperparameters and Constants](#toc2_3_)    
  - [Download Data & Resources](#toc2_4_)    

- [💾 Data Preparation](#toc3_)    
  - [Word2Vec Model Initialization](#toc3_2_)    
  - [Dictionaries Creation & Word2Vec Model Training](#toc3_3_)    
  - [Dictionaries Loading](#toc3_4_)    

- [🧾 TF-IDF Model](#toc4_)    

- [🔝 Word2Vec Model](#toc5_)    

- [👬 Siamese Network with Triplet Loss](#toc6_)    

- [🤖 Seq2Seq Transformer Model (DSI approach)](#toc7_)    
  - [Teacher Forcing Seq2Seq Transformer Model](#toc7_2_)    
  - [Autoregressive Seq2Seq Transformer Model](#toc7_3_)    
  - [Scheduled Sampling Seq2Seq Transformer Model](#toc7_4_)    

---



<a id="1"></a>
# <a id='toc2_'></a>[🚀 Getting Started](#toc0_)


First of all, we check whether the notebook is running on Google Colab or locally, and define the `RUNNING_IN_COLAB` constant used throughout the notebook.

In [1]:
# Check if running on colab or locally
try:
    from google.colab import files
    RUNNING_IN_COLAB = True
    print("Running on Google Colab.")
except ModuleNotFoundError:
    RUNNING_IN_COLAB = False
    print("Running locally.")

Running locally.



<a id="1_1"></a>
## <a id='toc2_1_'></a>[Collect Source Files](#toc0_)


#### <a id='toc2_1_1_1_'></a>[Clone Project's GitHub Repository](#toc0_)

We **clone the project's repository** from GitHub to access the source files for datasets, models, evaluation and utilities.


In [2]:
%%script echo skipping
# Clone the git repository from "https://github.com/valeriodiste/deep_learning_project" (for the source files)
!git clone https://github.com/valeriodiste/deep_learning_project.git

Couldn't find program: 'echo'


#### <a id='toc2_1_1_2_'></a>[Pull Latest Files Changes](#toc0_)

We also **pull the latest changes** from the repository and store them in the `./deep_learning_project` directory.


In [3]:
%%script echo skipping
# Change the working directory to the cloned repository
%cd /content/deep_learning_project
# Pull the latest changes from the repository
!git pull origin main
# Change the working directory to the parent directory
%cd ..

Couldn't find program: 'echo'


<a id="1_2"></a>
## <a id='toc2_2_'></a>[Install & Import Libraries](#toc0_)

#### <a id='toc2_2_1_1_'></a>[Install Libraries](#toc0_)

We **install all the necessary libraries** for this notebook.

- **`pytorch-lightning`**: A **lightweight PyTorch wrapper** for simplifying PyTorch code.
- **`ir_datasets`**: A Python library for accessing **information retrieval datasets** (used to load the **"MS MARCO" dataset**).
- **`wandb`**: The Python package for **Weights & Biases**, a tool for experiment tracking, dataset versioning, and project collaboration (used for **logging and visualization**).

In [4]:
%%script echo skipping
# Install the required packages
%%capture
%pip install pytorch-lightning
%pip install ir_datasets
%pip install wandb

Couldn't find program: 'echo'


#### <a id='toc2_2_1_2_'></a>[Import Modules](#toc0_)

We then **import the required modules**, including `PyTorch`, `PyTorch Lightning`, `IR Datasets` and `W&B`, plus other useful modules and libraries (`NLTK`, `Scikit Learn`, `Numpy`, `Pandas`, etc.).

In [5]:
# Import the standard libraries
import os
import json
import random
import logging
import math

# Import the PyTorch libraries and modules
import torch

# Import the PyTorch Lightning libraries and modules
import pytorch_lightning as pl

# Import the ir_datasets
import ir_datasets

# Import the W&B (Weights & Biases) library
import wandb
from wandb.sdk import wandb_run
from pytorch_lightning.loggers import WandbLogger

# Import the scikit-learn TF-IDF vectorizer
from sklearn.feature_extraction.text import TfidfVectorizer

# Import the NLTK (Natural Language Toolkit) library
import nltk

# Import the tqdm library (for the progress bars)
if not RUNNING_IN_COLAB:
    from tqdm import tqdm
else:
    from tqdm.notebook import tqdm



We also import our own **custom modules** (cloned from the repository) containing Python classes for **datasets**, **models**, **evaluation**, and **utilities**.

In [6]:
# Import the custom modules
if not RUNNING_IN_COLAB:
    # We are running locally (not on Google Colab, import modules from the "src" directory in the current directory)
    from src.scripts import models, datasets, training, evaluation
    from src.scripts.utils import (
        print_json, MODEL_TYPES, RANDOM_SEED, MODEL_CHECKPOINTS_FILES, get_preprocessed_text, print_model_evaluation_results
    )
else:
    # We are running on Google Colab (import modules from the pulled repository stored in the "deep_learning_project" directory)
    from deep_learning_project.src.scripts import models, datasets, training, evaluation
    from deep_learning_project.src.scripts.utils import (
        print_json, MODEL_TYPES, RANDOM_SEED, MODEL_CHECKPOINTS_FILES, get_preprocessed_text, print_model_evaluation_results
    )

<a id="1_3"></a>
## <a id='toc2_3_'></a>[Configuration, Hyperparameters and Constants](#toc0_)

#### <a id='toc2_3_1_1_'></a>[Random Seed](#toc0_)

We **seed the random number generators** for reproducibility.

In [7]:
# Set the random seeds for reproducibility
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
pl.seed_everything(RANDOM_SEED)

Seed set to 14


14

#### <a id='toc2_3_1_2_'></a>[Device Configuration](#toc0_)

We **set the device** to GPU if available, otherwise we use the CPU.

In [8]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device.type}")

Device: cpu


#### <a id='toc2_3_1_3_'></a>[Database Constants](#toc0_)

We **define the constants** used for the **database resources download** and the **dataset creation**.

In [9]:
# Define the max number of queries of the dataset (note that the MS MARCO dataset used contains 6980 queries, numbers higher than this will be ignored)
#   Set to -1 to use all the available queries in the MS MARCO dataset used
#   NOTE: this will also indirectly influence the number of documents in the final dataset, as only documents that are relevant to at least one of the selected queries will be kept
MAX_DATASET_QUERIES = 100

# Set the number of relevant documents associated to each query (when "scoreddocs" are used, a maximum value of 1_000 documents can be used)
#   Set to -1 to use all the available relevant documents for each query in the MS MARCO dataset used
#   NOTE: the actual number of relevant documents for some queries might be higher than this value, since the final documents dataset will include all documents
#       associated to at least one query, and some queries might be relevant to their own set of documents plus some documents relevant to other queries
#       (which will still be added to the list, thus exceeding the defined NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY in these cases)
NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY = 10

# Whether to remap doc IDs to new IDs (consecutive IDs starting from 0, one per document in the final documents dataset)
REMAP_DOC_IDS = True

# Defines whether to use the MS MARCO documents dataset (very heavy) or the MS MARCO passages dataset (smaller and faster to download and process)
USE_DOCUMENTS_DATASETS = False

#### <a id='toc2_3_1_4_'></a>[Models Hyperparameters](#toc0_)

We then **define the constants** representing the **hyperparameters** used for the **Word2Vec model**, the **Siamese Network model** and the **Seq2Seq transformer model**.

In [10]:
# Define the max length of the embeddings for both queries and documents for the Word2Vec model (embeddings will be padded or truncated to this length)
VECTOR_EMBEDDINGS_SIZE = 128

# Define the size of the output vector embeddings of the Siamese network model
SIAMESE_EMBEDDINGS_SIZE = 64

# Define the max number of tokens for both documents and queries for the Transformer model (token sequences will be padded or truncated to this length)
TRANSFORMER_DOCUMENT_MAX_TOKENS = 128
TRANSFORMER_QUERY_MAX_TOKENS = 32

# Define the size of the embeddings for the Encoders of the Seq2Seq Transformer model
TRANSFORMER_EMBEDDINGS_SIZE = 64
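
The "padded or truncated" behaviour mentioned in the comments above can be sketched as follows. This is only an illustration: the helper name `pad_or_truncate` and the padding value `0` are assumptions, and the project's own modules may handle padding differently.

```python
def pad_or_truncate(token_ids, max_length, pad_value=0):
    """Return a new list with exactly max_length entries."""
    # Keep at most max_length tokens
    trimmed = list(token_ids[:max_length])
    # Fill the remaining positions (if any) with the padding value
    return trimmed + [pad_value] * (max_length - len(trimmed))


short_sequence = pad_or_truncate([5, 6, 7], max_length=5)
long_sequence = pad_or_truncate([1, 2, 3, 4], max_length=2)
print(short_sequence, long_sequence)
```

With the constants defined above, a query would be brought to `TRANSFORMER_QUERY_MAX_TOKENS` entries and a document to `TRANSFORMER_DOCUMENT_MAX_TOKENS` entries.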

#### <a id='toc2_3_1_5_'></a>[Evaluation Constants](#toc0_)

We also define the constants used for the evaluation of the various models (i.e. to compute the **Mean Average Precision** and the **Recall at K**).

In [11]:
# Define the number of documents K to retrieve for each query and the number of queries N to calculate the mean average precision (MAP@K)
MAP_K = 10
MAP_N = 10

# Define the number of documents K to retrieve for each query to calculate the Recall@K metrics
RECALL_K = 1_000

# Whether to print the debug information during the MAP@K and Recall@K evaluation of the models
PRINT_EVALUATION_DEBUG = True

# Whether to evaluate the models (i.e. compute the MAP@K and Recall@K metrics for the trained models on the test datasets)
EVALUATE_MODELS = True
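
To make the two metrics concrete, here is a minimal pure-Python sketch of average precision at K, its mean over several queries, and recall at K. The function names are hypothetical and the normalization of the average precision (dividing by `min(#relevant, K)`) is one common convention, assumed here; the project's `evaluation` module may compute the metrics differently.

```python
def average_precision_at_k(ranked_doc_ids, relevant_doc_ids, k):
    """Sum of precision@i at every relevant rank i <= k, divided by min(#relevant, k)."""
    relevant = set(relevant_doc_ids)
    hits = 0
    precision_sum = 0.0
    for rank, doc_id in enumerate(ranked_doc_ids[:k], start=1):
        if doc_id in relevant:
            hits += 1
            precision_sum += hits / rank
    denominator = min(len(relevant), k)
    return precision_sum / denominator if denominator else 0.0


def mean_average_precision_at_k(rankings, relevant_lists, k):
    """Mean of AP@K over a list of queries (one ranking and one relevant list per query)."""
    scores = [average_precision_at_k(ranking, relevant, k)
              for ranking, relevant in zip(rankings, relevant_lists)]
    return sum(scores) / len(scores) if scores else 0.0


def recall_at_k(ranked_doc_ids, relevant_doc_ids, k):
    """Fraction of the relevant documents that appear in the top k results."""
    relevant = set(relevant_doc_ids)
    if not relevant:
        return 0.0
    return len(relevant & set(ranked_doc_ids[:k])) / len(relevant)


# Toy example: relevant documents are found at ranks 2 and 4
ranking = ["d3", "d1", "d7", "d2"]
relevant = ["d1", "d2"]
print(average_precision_at_k(ranking, relevant, k=4))
print(recall_at_k(ranking, relevant, k=2))
```

In the notebook's terms, `MAP_K` plays the role of `k` for the average precision, `MAP_N` is the number of queries averaged over, and `RECALL_K` is the `k` used for the recall.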

#### <a id='toc2_3_1_6_'></a>[Other Constants](#toc0_)

Finally, we define the constants that determine where to save data and models, along with the flags that enable or disable the dictionaries rebuild, the embeddings refresh and the loading of model checkpoints.

In [12]:

# Define the data folder, onto which the documents and queries dictionaries will be saved
DATA_FOLDER = "src/data" if not RUNNING_IN_COLAB else "/content/data"

# Define the path to save models
MODELS_FOLDER = "src/models" if not RUNNING_IN_COLAB else "/content/models"

# Force the rebuild of the documents and queries dictionaries (to re-save them to the JSON files)
FORCE_DICTIONARIES_REBUILD = False

# Refresh the embeddings of the documents and queries: if set to True, the embeddings are recomputed and saved to the JSON files (useful to change properties of the embeddings, e.g. VECTOR_EMBEDDINGS_SIZE, without having to rebuild the dictionaries)
REFRESH_EMBEDDINGS = False

# Whether to load model checkpoints (if they were already saved locally) or not
LOAD_MODELS_CHECKPOINTS = False

#### <a id='toc2_3_1_7_'></a>[Local Files Folder Creation](#toc0_)

We create the folders to store the data dictionaries and the model's checkpoints.

In [13]:
# Create folders if they do not exist
if not os.path.exists(DATA_FOLDER):
    print(f"Creating the data folder at '{DATA_FOLDER}'...")
    os.makedirs(DATA_FOLDER)
if not os.path.exists(MODELS_FOLDER):
    print(f"Creating the models folder at '{MODELS_FOLDER}'...")
    os.makedirs(MODELS_FOLDER)

#### <a id='toc2_3_1_8_'></a>[Weights & Biases Configuration](#toc0_)

We set the **Weights & Biases** API key to log the experiments.

**⚠️ Note**: Set the `WANDB_API_KEY` environment variable to your own W&B API key (or assign the key directly to the `WANDB_API_KEY` constant) to see logging results, or set the constant to an empty string to disable W&B logging (this won't plot training results).

In [14]:
# Define the WANDB_API_KEY (set to "" to disable W&B logging)
#   NOTE: the key is read from the environment so that no secret is stored in the notebook itself
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", "")

We configure the **Weights & Biases** logger and API to track the experiments and the models' performance.

In [15]:
# Define the wandb logger, api object, entity name and project name
wandb_logger = None
wandb_api = None
wandb_entity = None
wandb_project = None
# Check if a W&B api key is provided
if WANDB_API_KEY is None:
    print("No W&B API key provided, please provide a valid key to use the W&B API or set the WANDB_API_KEY variable to an empty string to disable logging")
    raise ValueError("No W&B API key provided.")
elif WANDB_API_KEY != "":
    # Login to the W&B (Weights & Biases) API
    wandb.login(key=WANDB_API_KEY, relogin=True)
    # Minimize the logging from the W&B (Weights & Biases) library
    os.environ["WANDB_SILENT"] = "true"
    logging.getLogger("wandb").setLevel(logging.ERROR)
    # Initialize the W&B (Weights & Biases) logger
    wandb_logger = WandbLogger(
        log_model="all", project="dl-dsi-project", name="NONE")
    # Initialize the W&B (Weights & Biases) API
    wandb_api = wandb.Api()
    # Get the W&B (Weights & Biases) entity name
    wandb_entity = wandb_logger.experiment.entity
    # Get the W&B (Weights & Biases) project name
    wandb_project = wandb_logger.experiment.project
    print("W&B API key provided, logging with W&B enabled.")
else:
    print("No W&B API key provided, logging with W&B disabled.")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\valer\.netrc


wandb: Currently logged in as: valeriodstfn. Use `wandb login --relogin` to force relogin


W&B API key provided, logging with W&B enabled.


<a id="1_4"></a>
## <a id='toc2_4_'></a>[Download Data & Resources](#toc0_)

#### <a id='toc2_4_1_1_'></a>[Download NLTK Resources](#toc0_)

We download the needed NLTK resources for text preprocessing.

In [16]:
# Download the needed NLTK resources
nltk.download('punkt')
nltk.download('stopwords')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\valer\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\valer\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

#### <a id='toc2_4_1_2_'></a>[Download MS MARCO Dataset](#toc0_)

We download the **MS MARCO** dataset's resources through the `ir_datasets` module (if needed).

The `USE_DOCUMENTS_DATASETS` flag is used to determine whether to download the "documents" version of the dataset or its "passages" version.

In [17]:
# Download the MS MARCO dataset (if needed and if the dictionaries need to be built/rebuilt)
if FORCE_DICTIONARIES_REBUILD or not os.path.exists(DATA_FOLDER + "/docs_dict.json") or not os.path.exists(DATA_FOLDER + "/queries_dict.json"):

    # Load the MS MARCO dataset
    dataset = None
    if USE_DOCUMENTS_DATASETS:
        # Load https://ir-datasets.com/msmarco-document.html#msmarco-document/dev
        dataset = ir_datasets.load("msmarco-document/dev")
    else:
        # Load https://ir-datasets.com/msmarco-passage.html#msmarco-passage/dev/small
        dataset = ir_datasets.load("msmarco-passage/dev/small")

    # Trigger the download of the datasets (if not already downloaded)
    next(dataset.docs_iter())
    next(dataset.queries_iter())
    next(dataset.qrels_iter())
    next(dataset.scoreddocs_iter())

    # Print the dataset structure (i.e. the column names)
    print_metadata = False
    if print_metadata:
        print("Docs Metadata:")
        print_json(dataset.docs_metadata(), 2)
        print("Queries Metadata:")
        print_json(dataset.queries_metadata(), 2)
        print("Qrels Metadata:")
        print_json(dataset.qrels_metadata(), 2)
        print("Scored Docs Metadata:")
        print_json(dataset.scoreddocs_metadata(), 2)

    # Print some samples of the dataset
    print_database_samples = False
    if print_database_samples:
        # Print a sample document
        print("\nSample Document:")
        print("  <doc_id, url, title, body>"
              if USE_DOCUMENTS_DATASETS
              else "  <doc_id, text>")
        doc = dataset.docs_iter().__next__()
        print_json(doc, 2)
        # Print a sample query
        print("\nSample Query:")
        print("  <query_id, text>")
        query = dataset.queries_iter().__next__()
        print_json(query, 2)
        # Print a sample qrel
        #   NOTE: the "relevance" and "iteration" fields are always 1 and "0" respectively, for all the qrels (qrels only contain relevant pairs of <query_id, doc_id>)
        print("\nSample Qrel:")
        print("  <query_id, doc_id, relevance, iteration>")
        qrel = dataset.qrels_iter().__next__()
        print_json(qrel, 2)
        # Print a sample scored doc
        print("\nSample Scored Doc:")
        print("  <query_id, doc_id, score>")
        scored_doc = dataset.scoreddocs_iter().__next__()
        print_json(scored_doc, 2)
else:
    # Print a message indicating that the dictionaries already exist and will be loaded
    print("No need to download the MS MARCO dataset.")
    print("Documents and queries dictionaries already exist and will be loaded from the JSON files.")

No need to download the MS MARCO dataset.
Documents and queries dictionaries already exist and will be loaded from the JSON files.


---



<a id="2"></a>
# <a id='toc3_'></a>[💾 Data Preparation](#toc0_)

#### <a id='toc3_1_1_1_'></a>[Data Variables and Constants](#toc0_)

We define the `docs_dict` and `queries_dict` dictionaries used to store the documents and queries data.

The `docs_dict` dictionary contains, for each document ID, the document's text and its Word2Vec embedding.

The `queries_dict` dictionary contains, for each query ID, the query's text, its Word2Vec embedding and also the list of document IDs for documents relevant to the query.

In [18]:
# Dictionaries to store the documents and queries (the main dataset)
docs_dict = {}
queries_dict = {}

# Auxiliary dictionary to map the column names to the corresponding index in the ir_datasets tuples
IR_DATASET_COLS = {
    "DOCS": {"id": 0, "url": 1, "title": 2, "body": 3} if USE_DOCUMENTS_DATASETS else {"id": 0, "text": 1},
    "QUERIES": {"id": 0, "text": 1},
    "QRELS": {"query_id": 0, "doc_id": 1},
    "SCORED_DOCS": {"query_id": 0, "doc_id": 1, "score": 2}
}
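
For illustration, the entries of the two dictionaries are shaped roughly as shown below. The IDs, texts and embedding values are invented, and the embeddings are assumed to be flat lists of floats (they must be JSON-serializable, since the dictionaries are saved to JSON files); real embeddings have `VECTOR_EMBEDDINGS_SIZE` entries rather than three.

```python
# Hypothetical entries: IDs, texts and embedding values are invented for illustration
example_docs_dict = {
    "0": {"text": "Rome is the capital of Italy.", "embedding": [0.12, -0.03, 0.40]},
    "1": {"text": "Paris is the capital of France.", "embedding": [0.10, 0.22, -0.15]},
}
example_queries_dict = {
    "q1": {
        "text": "what is the capital of italy",
        "embedding": [0.09, -0.01, 0.35],
        "relevant_docs": ["0"],
    },
}

# Sanity check: every relevant document ID must exist in the documents dictionary
dangling_ids = [
    doc_id
    for query in example_queries_dict.values()
    for doc_id in query["relevant_docs"]
    if doc_id not in example_docs_dict
]
print(dangling_ids)
```

A check like the last one can be handy after changing `MAX_DATASET_QUERIES` or `REMAP_DOC_IDS`, since both affect which document IDs end up in the final dataset.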

<a id="2_1"></a>
## <a id='toc3_2_'></a>[Word2Vec Model Initialization](#toc0_)

We initialize the `Word2Vec` model to compute the **vector embeddings** of the documents and queries.

This model is later **trained on the documents corpus** (using the `Gensim` library) to output vector embeddings of size `VECTOR_EMBEDDINGS_SIZE` for documents and queries.

This model is also used as a **baseline** for the evaluation of the final **Seq2Seq** transformer model (we compute the cosine similarity of the output embeddings between a query and the entire documents database to generate the top `K` most relevant documents).

In [19]:
# Initialize a Word2Vec model to encode the text
word2vec_model = models.Word2VecModel(
    embeddings_size=VECTOR_EMBEDDINGS_SIZE,
    words_window_size=10,
    min_word_frequency=0,
    learning_rate=0.025,
    max_epochs=5,
    save_path=MODELS_FOLDER + "/" +
    MODEL_CHECKPOINTS_FILES[MODEL_TYPES.WORD2VEC]
)


def load_or_train_word2vec_model(documents_corpus=None):
    '''
    Load the Word2Vec model from its checkpoint file if LOAD_MODELS_CHECKPOINTS is set and the load succeeds, otherwise train it.

    If documents_corpus is None or an empty list, a new documents corpus is built from the documents in the dataset.

    If a documents_corpus is provided (as a list of lists of strings representing the words of each document's text), it is used to train the Word2Vec model.
    '''
    loaded_checkpoint = False
    if LOAD_MODELS_CHECKPOINTS:
        loaded_checkpoint = word2vec_model.load()
    if not loaded_checkpoint:
        # Train the Word2Vec model on the documents corpus
        print("Training the Word2Vec model on the documents corpus...")
        if documents_corpus is None or len(documents_corpus) == 0:
            # Build the documents corpus
            documents_corpus = [get_preprocessed_text(
                docs_dict[doc_id]["text"]).split(" ") for doc_id in docs_dict]
        # Train the Word2Vec model
        word2vec_model.train(documents_corpus)
        print("Word2Vec model training completed.")
    else:
        print("Word2Vec model loaded from the checkpoint file.")
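
The baseline retrieval step described above (cosine similarity between a query embedding and every document embedding, keeping the top `K`) can be sketched with NumPy as follows. The function name and the toy vectors are assumptions made for illustration; the project's own evaluation code may be organized differently.

```python
import numpy as np


def top_k_by_cosine_similarity(query_embedding, doc_embeddings, doc_ids, k):
    """Return the IDs of the k documents most similar to the query."""
    query = np.asarray(query_embedding, dtype=float)
    docs = np.asarray(doc_embeddings, dtype=float)
    # Cosine similarity = dot product divided by the product of the norms
    norms = np.linalg.norm(docs, axis=1) * np.linalg.norm(query)
    similarities = (docs @ query) / np.maximum(norms, 1e-12)
    # Sort by decreasing similarity and keep the first k documents
    best_indices = np.argsort(-similarities)[:k]
    return [doc_ids[i] for i in best_indices]


# Toy 2-dimensional embeddings (real ones have VECTOR_EMBEDDINGS_SIZE entries)
toy_ids = ["a", "b", "c"]
toy_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
top_two = top_k_by_cosine_similarity([1.0, 0.0], toy_embeddings, toy_ids, k=2)
print(top_two)
```

The small constant in `np.maximum` only guards against division by zero for all-zero embeddings.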

<a id="2_2"></a>

## <a id='toc3_3_'></a>[Dictionaries Creation & Word2Vec Model Training](#toc0_)

We build the **documents** and **queries** dictionaries if needed (or if the `FORCE_DICTIONARIES_REBUILD` flag is set to `True`).

Documents and queries **vector embeddings** are also created, using the `Word2Vec` model trained on the documents' corpus.

If the `REMAP_DOC_IDS` flag is set to `True`, document IDs are also **remapped to new IDs** to avoid gaps in the dictionary.

We ultimately **save the dictionaries** into local files in the `DATA_FOLDER` directory.

In [20]:
# Check if the dictionaries need to be built/rebuilt
if FORCE_DICTIONARIES_REBUILD or not os.path.exists(DATA_FOLDER + "/docs_dict.json") or not os.path.exists(DATA_FOLDER + "/queries_dict.json"):

    print("The documents and queries dictionaries files do not exist, creating them...")

    # Build a queries dictionary, containing the query_id as key, and as values both the query text and a list of associated relevant documents (as doc_id) taken from the scored documents (list of the 1000 relevant documents to the query)
    number_of_queries = MAX_DATASET_QUERIES \
        if 0 < MAX_DATASET_QUERIES < dataset.queries_count() \
        else dataset.queries_count()
    use_scored_docs_for_relevant_documents = True
    for query in tqdm(dataset.queries_iter(), "Building the queries dictionary", number_of_queries):
        if len(queries_dict) >= number_of_queries:
            break
        query_id = query[IR_DATASET_COLS["QUERIES"]["id"]]
        query_text = query[IR_DATASET_COLS["QUERIES"]["text"]]
        queries_dict[query_id] = {
            "text": query_text,
            "embedding": None,
            "relevant_docs": []
        }
    # Add the relevant documents to the queries dictionary
    doc_ids_with_rel = set()
    # First, add the relevant document(s) using the qrels (to ensure the most relevant documents are added first)
    for qrel in tqdm(dataset.qrels_iter(), "Adding relevant documents to queries (using qrels)", dataset.qrels_count()):
        query_id = qrel[IR_DATASET_COLS["QRELS"]["query_id"]]
        if query_id not in queries_dict:
            continue
        doc_id = qrel[IR_DATASET_COLS["QRELS"]["doc_id"]]
        if NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY < 0 or len(queries_dict[query_id]["relevant_docs"]) < NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY:
            queries_dict[query_id]["relevant_docs"].append(doc_id)
            doc_ids_with_rel.add(doc_id)
    # Then, add the relevant documents using the scoreddocs (if needed)
    # NOTE: the scoreddocs list contains 1000 relevant documents to the query, unordered and without an associated relevance score (these results are less precise than the qrels)
    if use_scored_docs_for_relevant_documents:
        for scored_doc in tqdm(dataset.scoreddocs_iter(), "Adding relevant documents to queries (using scoreddocs)", dataset.scoreddocs_count()):
            query_id = scored_doc[IR_DATASET_COLS["SCORED_DOCS"]["query_id"]]
            if query_id not in queries_dict:
                continue
            doc_id = scored_doc[IR_DATASET_COLS["SCORED_DOCS"]["doc_id"]]
            if NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY < 0 or len(queries_dict[query_id]["relevant_docs"]) < NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY:
                queries_dict[query_id]["relevant_docs"].append(doc_id)
                doc_ids_with_rel.add(doc_id)
        # Fix the missing relevant documents from the queries dictionary (if the relevant documents list was reduced)
        if NUMBER_OF_RELEVANT_DOCUMENTS_PER_QUERY > 0:
            # Re-add to the relevant documents list of each query all the removed documents that will be added to the documents dataset (i.e. in the doc_ids_with_rel set)
            for scored_doc in tqdm(dataset.scoreddocs_iter(), "Fixing missing relevant documents from queries dictionary", dataset.scoreddocs_count()):
                query_id = scored_doc[IR_DATASET_COLS["SCORED_DOCS"]["query_id"]]
                if query_id not in queries_dict:
                    continue
                doc_id = scored_doc[IR_DATASET_COLS["SCORED_DOCS"]["doc_id"]]
                if doc_id in doc_ids_with_rel and doc_id not in queries_dict[query_id]["relevant_docs"]:
                    queries_dict[query_id]["relevant_docs"].append(doc_id)
    print(
        f"Total number of documents relevant to at least one query: {len(doc_ids_with_rel)}")

    # Initialize the corpus of documents (to be used to train the Word2Vec model)
    documents_corpus = []

    # Build a documents dictionary, containing the doc_id as key, and the attribute "text" containing the document text
    documents_count = 0
    documents_id_remapping = {}
    documents_id_remapping_inverse = {}
    for doc in tqdm(dataset.docs_iter(), "Building the documents dictionary", dataset.docs_count()):
        # Add the document and its text to the documents dictionary
        doc_id = doc[IR_DATASET_COLS["DOCS"]["id"]]
        if doc_id not in doc_ids_with_rel:
            continue
        doc_text = ""
        if USE_DOCUMENTS_DATASETS:
            doc_text = doc[IR_DATASET_COLS["DOCS"]["title"]] + \
                ".\n" + doc[IR_DATASET_COLS["DOCS"]["body"]]
        else:
            doc_text = doc[IR_DATASET_COLS["DOCS"]["text"]]
        docs_dict[doc_id] = {
            "text": doc_text,
            "embedding": None
        }
        # Compute the remapped doc_id (if needed)
        if REMAP_DOC_IDS:
            new_doc_id = str(documents_count)
            documents_id_remapping[doc_id] = new_doc_id
            documents_id_remapping_inverse[new_doc_id] = doc_id
        # Increment the documents count
        documents_count += 1
        # Add the document text to the corpus
        documents_corpus.append(get_preprocessed_text(doc_text).split(" "))

    # Load or train the Word2Vec model on the documents corpus
    load_or_train_word2vec_model(documents_corpus)

    # Iterate over documents in the dictionaries to compute the embeddings (and to remap the doc_ids, if needed)
    new_docs_dict = {}
    for doc_id in tqdm(docs_dict, "Computing document embeddings" + (" and remapping doc_ids" if REMAP_DOC_IDS else "")):
        # Compute the embedding of the document text
        docs_dict[doc_id]["embedding"] = \
            word2vec_model.get_embedding(docs_dict[doc_id]["text"])
        # Remap the doc_id (if needed)
        if REMAP_DOC_IDS:
            new_docs_dict[documents_id_remapping[doc_id]] = {
                "text": docs_dict[doc_id]["text"],
                "embedding": docs_dict[doc_id]["embedding"]
            }
    if REMAP_DOC_IDS:
        docs_dict = new_docs_dict
    # Iterate over queries in the dictionary to compute the embeddings (and to remap the relevant doc_ids, if needed)
    for query_id in tqdm(queries_dict, "Computing query embeddings" + (" and remapping relevant doc_ids" if REMAP_DOC_IDS else "")):
        # Compute the embedding of the query text
        queries_dict[query_id]["embedding"] = \
            word2vec_model.get_embedding(queries_dict[query_id]["text"])
        # Remap the relevant documents (if needed)
        if REMAP_DOC_IDS:
            current_relevant_docs = queries_dict[query_id]["relevant_docs"]
            queries_dict[query_id]["relevant_docs"] = [
                documents_id_remapping[doc_id] for doc_id in current_relevant_docs]

    # Print the total number of documents and queries
    print(f"Total number of documents (in built dict): {len(docs_dict)}")
    print(f"Total number of queries (in built dict): {len(queries_dict)}")

    # Save the 2 dictionaries to 2 JSON files in the "data" directory
    print("Saving the documents and queries dictionaries to the JSON files...")
    with open(DATA_FOLDER + "/docs_dict.json", "w") as docs_dict_file:
        json.dump(docs_dict, docs_dict_file, indent=2)
    with open(DATA_FOLDER + "/queries_dict.json", "w") as queries_dict_file:
        json.dump(queries_dict, queries_dict_file, indent=2)
    print("Created the documents and queries dictionaries and saved them to the files.\n")
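
The ID remapping performed when `REMAP_DOC_IDS` is `True` boils down to the following idea, shown here as a compact standalone sketch. The function name and the toy data are hypothetical, and the sketch leaves out the embedding computation done in the cell above.

```python
def remap_doc_ids(docs, queries):
    """Replace document IDs with consecutive string IDs ("0", "1", ...)."""
    # Assign the new IDs in the iteration order of the documents dictionary
    remapping = {old_id: str(new_id) for new_id, old_id in enumerate(docs)}
    new_docs = {remapping[old_id]: doc for old_id, doc in docs.items()}
    # Rewrite the relevant document lists of every query with the new IDs
    new_queries = {
        query_id: {**query, "relevant_docs": [remapping[d] for d in query["relevant_docs"]]}
        for query_id, query in queries.items()
    }
    return new_docs, new_queries, remapping


# Toy data with sparse, non-consecutive document IDs
toy_docs = {"7187158": {"text": "first"}, "42": {"text": "second"}}
toy_queries = {"q1": {"text": "a query", "relevant_docs": ["42"]}}
remapped_docs, remapped_queries, id_map = remap_doc_ids(toy_docs, toy_queries)
print(list(remapped_docs), remapped_queries["q1"]["relevant_docs"])
```

Compact, consecutive IDs keep the set of possible outputs small, which matters later when a model has to generate document IDs.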

<a id="2_3"></a>

## <a id='toc3_4_'></a>[Dictionaries Loading](#toc0_)

We load the `documents` and `queries` dictionaries from the local files in the `DATA_FOLDER` directory and save them to the corresponding dictionary variables.

In [21]:
# Load the documents and queries dictionaries from the JSON files
print("Loading the documents and queries dictionaries from the files...")
with open(DATA_FOLDER + "/docs_dict.json", "r") as docs_dict_file:
    docs_dict = json.load(docs_dict_file)
print(f"  Loaded {len(docs_dict)} documents")
with open(DATA_FOLDER + "/queries_dict.json", "r") as queries_dict_file:
    queries_dict = json.load(queries_dict_file)
print(f"  Loaded {len(queries_dict)} queries")

Loading the documents and queries dictionaries from the files...
  Loaded 979 documents
  Loaded 100 queries


---

<a id="3"></a>
# <a id='toc4_'></a>[🧾 TF-IDF Model](#toc0_)

#### <a id='toc4_1_1_1_'></a>[TF-IDF Model Initialization](#toc0_)

The first baseline used for the evaluation consists of a **TF-IDF** model (**no machine learning used**).

The model, a simple vectorizer built using the `Scikit Learn` library, computes the **TF-IDF scores** for each word and each document in the corpus and stores them in the `tf_idf_matrix` matrix.

In [22]:
# Document IDs
doc_ids = list(docs_dict.keys())
# Document texts
doc_texts = [docs_dict[doc_id]['text'].lower() for doc_id in doc_ids]
# Remove empty documents from the list (and their respective IDs)
doc_ids, doc_texts = zip(*[(doc_id, doc_text) for doc_id,
                           doc_text in zip(doc_ids, doc_texts) if len(doc_text) > 0])

# Get the TF-IDF vectorizer
tf_idf_vectorizer = TfidfVectorizer(stop_words='english')
# Fit the vectorizer on the document texts, computing the TF-IDF matrix (an [n_docs]x[vocab_size] matrix with the TF*IDF score value for each word in each document)
tf_idf_matrix = tf_idf_vectorizer.fit_transform(doc_texts)

#### <a id='toc4_1_1_2_'></a>[TF-IDF Model Evaluation](#toc0_)

We compute the **Mean Average Precision** (over `MAP_N` queries, each considering a precision at **K** defined by `MAP_K`) and the **Recall at K** (with **K** defined by `RECALL_K`) to **evaluate the TF-IDF model's performance**.

We then **print the results** of the evaluation of both metrics.
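
The two metrics can be sketched as follows. The `evaluation` module's code is not shown in this notebook, so the function names and the exact definitions below are assumptions. The MAP@K values printed in the outputs are consistent with averaging the per-query precision at K (for example, the ten TF-IDF values average to 0.93), so the sketch uses that definition rather than a rank-weighted average precision.

```python
def precision_at_k(ranked_doc_ids, relevant_doc_ids, k):
    """Fraction of the top-k retrieved documents that are relevant."""
    relevant = set(relevant_doc_ids)
    hits = sum(1 for doc_id in ranked_doc_ids[:k] if doc_id in relevant)
    return hits / k


def recall_at_k(ranked_doc_ids, relevant_doc_ids, k):
    """Fraction of the relevant documents found in the top-k results."""
    relevant = set(relevant_doc_ids)
    if not relevant:
        return 0.0
    hits = len(relevant.intersection(ranked_doc_ids[:k]))
    return hits / len(relevant)


def mean_precision_at_k(rankings, relevant_sets, k):
    """Average of precision@k over several queries."""
    values = [precision_at_k(ranking, relevant, k)
              for ranking, relevant in zip(rankings, relevant_sets)]
    return sum(values) / len(values)


# Hypothetical ranking returned for one query, and its relevant documents
ranking = ["d3", "d7", "d1", "d9"]
relevant = {"d3", "d1", "d5"}
p_at_2 = precision_at_k(ranking, relevant, k=2)  # one relevant document in the top 2
r_at_4 = recall_at_k(ranking, relevant, k=4)     # two of three relevant documents found
print(p_at_2, r_at_4)
```

With `K` as large as the corpus (as with `RECALL_K` in the outputs, where 1000 exceeds the 979 documents), any model that ranks every document reaches a recall of 1.0.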

In [23]:
if EVALUATE_MODELS:
    print("Evaluating the TF-IDF model...")
    tf_idf_map_k = evaluation.compute_mean_average_precision_at_k(
        MODEL_TYPES.TF_IDF, queries_dict, docs_dict,
        k_documents=MAP_K, n_queries=MAP_N,
        print_debug=PRINT_EVALUATION_DEBUG,
        # Keyword arguments for the TF-IDF model
        vectorizer=tf_idf_vectorizer, tfidf_matrix=tf_idf_matrix)
    # Evaluate the TF-IDF model (compute the Recall@K)
    tf_idf_recall_k = evaluation.compute_recall_at_k(
        MODEL_TYPES.TF_IDF, queries_dict, docs_dict,
        k_documents=RECALL_K,
        print_debug=PRINT_EVALUATION_DEBUG,
        # Keyword arguments for the TF-IDF model
        vectorizer=tf_idf_vectorizer, tfidf_matrix=tf_idf_matrix)
    # Print the evaluation results
    print_model_evaluation_results(tf_idf_map_k, tf_idf_recall_k)

Evaluating the TF-IDF model...
Evaluating TF-IDF model to compute MAP@K...
Evaluating TF-IDF model to compute Recall@K...
MAP@10 for the TF-IDF model:
  > 0.93
  Computed on 10 queries
  Single queries precision:
    Query 873886: 0.9
    Query 1051285: 0.9
    Query 1051530: 1.0
    Query 1051755: 1.0
    Query 2798: 1.0
    Query 1051108: 0.9
    Query 1288: 0.7
    Query 1049955: 0.9
    Query 1091234: 1.0
    Query 1049894: 1.0
Recall@1000 for the TF-IDF model:
  > 1.0
  Computed for query 263670


---

<a id="4"></a>
# <a id='toc5_'></a>[🔝 Word2Vec Model](#toc0_)

#### <a id='toc5_1_1_1_'></a>[Word2Vec Model Evaluation](#toc0_)

We evaluate the `Word2Vec` model initialized in section "[Word2Vec Model Initialization](#toc3_2_)" and trained in section "[Dictionaries Creation & Word2Vec Model Training](#toc3_3_)" (to generate document and query embeddings) by computing the **Mean Average Precision** (over `MAP_N` queries, each considering a precision at **K** of `MAP_K`) and the **Recall at K** (with **K** defined by `RECALL_K`).

Both metrics are computed by using the trained **Word2Vec model** to generate a vector embedding for the given queries and for all the documents in the corpus, and then calculating the **cosine similarity** between the query embedding and the document embeddings to select the top `K` most relevant documents.

We therefore employ an **index-then-retrieve** approach, which is significantly slower (in the document retrieval phase) than the approach taken for the final **Seq2Seq transformer** model.
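
The retrieval step of an index-then-retrieve pipeline can be sketched as below. This is an illustration, not the code of the `evaluation` module: the embeddings are hypothetical two-dimensional vectors, and `top_k_documents` is a made-up helper name. The point is that every query is compared against the embedding of every document in the corpus.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors (0.0 if either has zero norm)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v) if norm_u > 0 and norm_v > 0 else 0.0


def top_k_documents(query_embedding, doc_embeddings, k):
    """Score the query against every document and keep the k best document IDs."""
    scored = [(cosine_similarity(query_embedding, embedding), doc_id)
              for doc_id, embedding in doc_embeddings.items()]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [doc_id for _, doc_id in scored[:k]]


# Hypothetical precomputed "index": one embedding per document ID
doc_embeddings = {"0": [1.0, 0.0], "1": [0.0, 1.0], "2": [0.7, 0.7]}
query_embedding = [0.9, 0.1]
best_documents = top_k_documents(query_embedding, doc_embeddings, k=2)
print(best_documents)
```

The cost of answering one query grows linearly with the number of documents, which is the overhead the Seq2Seq transformer avoids by generating document IDs directly.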

If the model has not been trained yet (e.g. when the documents and queries dictionaries are loaded from local files instead of being generated at runtime), we also **train the model** on the documents corpus (or load it from its checkpoint file).

We then **print the results** of the evaluation of both metrics.

In [24]:
# Check if the word2vec model needs to be trained
if not word2vec_model.get_is_trained():
    # Train the Word2Vec model on the documents corpus or load it from the checkpoint file
    load_or_train_word2vec_model()

# Use just the Word2Vec model (with which the embeddings were computed) to compute the similarity scores between the queries and the documents
if EVALUATE_MODELS:
    print("Computing the similarity scores between the queries and the documents using the Word2Vec model...")
    word2vec_map_k = evaluation.compute_mean_average_precision_at_k(
        MODEL_TYPES.WORD2VEC, queries_dict, docs_dict,
        k_documents=MAP_K, n_queries=MAP_N,
        print_debug=PRINT_EVALUATION_DEBUG)
    word2vec_recall_k = evaluation.compute_recall_at_k(
        MODEL_TYPES.WORD2VEC, queries_dict, docs_dict,
        k_documents=RECALL_K,
        print_debug=PRINT_EVALUATION_DEBUG)
    print_model_evaluation_results(word2vec_map_k, word2vec_recall_k)

Training the Word2Vec model on the documents corpus...
Word2Vec model training completed.
Computing the similarity scores between the queries and the documents using the Word2Vec model...
Evaluating Word2Vec model to compute MAP@K...


Computing relevance scores for MAP@K for query 1/10: 100%|██████████| 979/979 [00:00<00:00, 15606.91it/s]
Computing relevance scores for MAP@K for query 2/10: 100%|██████████| 979/979 [00:00<00:00, 17184.38it/s]
Computing relevance scores for MAP@K for query 3/10: 100%|██████████| 979/979 [00:00<00:00, 16869.16it/s]
Computing relevance scores for MAP@K for query 4/10: 100%|██████████| 979/979 [00:00<00:00, 18473.87it/s]
Computing relevance scores for MAP@K for query 5/10: 100%|██████████| 979/979 [00:00<00:00, 14259.26it/s]
Computing relevance scores for MAP@K for query 6/10: 100%|██████████| 979/979 [00:00<00:00, 17101.07it/s]
Computing relevance scores for MAP@K for query 7/10: 100%|██████████| 979/979 [00:00<00:00, 17196.03it/s]
Computing relevance scores for MAP@K for query 8/10: 100%|██████████| 979/979 [00:00<00:00, 17217.01it/s]
Computing relevance scores for MAP@K for query 9/10: 100%|██████████| 979/979 [00:00<00:00, 16681.73it/s]
Computing relevance scores for MAP@K for query

Evaluating Word2Vec model to compute Recall@K...


Computing relevance scores for Recall@K...: 100%|██████████| 979/979 [00:00<00:00, 16965.55it/s]

MAP@10 for the Word2Vec model:
  > 0.53
  Computed on 10 queries
  Single queries precision:
    Query 1038859: 0.7
    Query 524699: 0.6
    Query 1051422: 0.6
    Query 264410: 0.3
    Query 811852: 0.2
    Query 1050857: 0.2
    Query 789332: 0.5
    Query 788484: 1.0
    Query 1051886: 0.8
    Query 786918: 0.4
Recall@1000 for the Word2Vec model:
  > 1.0
  Computed for query 787784





---

<a id="5"></a>
# <a id='toc6_'></a>[👬 Siamese Network with Triplet Loss](#toc0_)

#### <a id='toc6_1_1_1_'></a>[Siamese Network Model Initialization](#toc0_)

We initialize a **Siamese Network model with Triplet Loss** to act as a third baseline for the evaluation of the final **Seq2Seq transformer** model.

The **hyperparameters** of the model are defined in the `siamese_network_args` dictionary.

In [25]:
# SiameseNetwork model's args
siamese_network_args = {
    "input_size": VECTOR_EMBEDDINGS_SIZE,
    "output_size": SIAMESE_EMBEDDINGS_SIZE,
    "learning_rate": 0.001,
    "margin": 1.0,
    "dropout": 0.0,
    "activation_function": "ReLU"
}

# Create the Siamese Network model with Triplet Loss
siamese_network_model = models.SiameseNetwork(**siamese_network_args)

#### <a id='toc6_1_1_2_'></a>[Siamese Network Dataset Creation](#toc0_)

We create a **dataset** to then train and evaluate the **Siamese Network model** using the `SiameseNetworkDataset` class.

The dataset consists of **triplets** of **anchor**, **positive** and **negative** samples, where:
- **anchor** is a query ID;
- **positive** is a document ID of a document that is relevant to the corresponding query;
- **negative** is a document ID of a document that is **NOT** relevant to the corresponding query.

We also print a **dataset triplet** example to visualize the data.
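
For intuition on how such triplets are used, here is a small sketch of the triplet margin loss with Euclidean distance, following the formulation documented for PyTorch's `TripletMarginLoss`: `max(d(anchor, positive) - d(anchor, negative) + margin, 0)`. The two-dimensional vectors are hypothetical; in the notebook the inputs would be embeddings of the query and of the two documents.

```python
import math


def euclidean_distance(u, v):
    """Euclidean (L2) distance between two vectors of equal length."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Zero once the negative is at least `margin` farther away than the positive."""
    return max(
        euclidean_distance(anchor, positive)
        - euclidean_distance(anchor, negative)
        + margin,
        0.0,
    )


anchor = [0.0, 0.0]
# Positive close to the anchor, negative far away: the triplet is satisfied
loss_satisfied = triplet_margin_loss(anchor, [0.0, 1.0], [3.0, 4.0])
# Positive and negative at the same distance: the loss equals the margin
loss_at_margin = triplet_margin_loss(anchor, [0.0, 1.0], [1.0, 0.0])
print(loss_satisfied, loss_at_margin)
```

A loss value close to the margin therefore means that the anchor is, on average, about as far from the positive documents as from the negative ones.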

In [26]:
# Create the dataset for the Siamese Network model
#   The dataset will be a list of triplets (anchor_query, positive_document, negative_document)
siamese_triplets_dataset = datasets.SiameseNetworkDataset(
    queries_dict, docs_dict,
    dataset_file_path=DATA_FOLDER + "/siamese_triplets_dataset.json",
    force_dataset_rebuild=FORCE_DICTIONARIES_REBUILD
)

# Print the number of triplets in the dataset
print(
    f"Number of [query, document+, document-] triplets in the dataset: {len(siamese_triplets_dataset.triplets)}")

# Print an example of a triplet
print_triplet_example = True
if print_triplet_example:
    print("Example of a triplet:")
    triplet_example = siamese_triplets_dataset.triplets[0]
    print("  [query, document+, document-]: ", triplet_example)
    # Print the text of the query, the positive document and the negative document
    print("  Query text: ", queries_dict[triplet_example[0]]["text"])
    print("  Positive document text: ",
          docs_dict[triplet_example[1]]["text"])
    print("  Negative document text: ",
          docs_dict[triplet_example[2]]["text"])

Loading the Siamese Network's triplets data from src/data/siamese_triplets_dataset.json...
Loaded 1001 triplets from src/data/siamese_triplets_dataset.json
Number of [query, document+, document-] triplets in the dataset: 1001
Example of a triplet:
  [query, document+, document-]:  ['1048585', '759', '860']
  Query text:  what is paula deen's brother
  Positive document text:  Brother monochrome laser printers go beyond just a low initial purchase price; they offer a low cost per page with features such as: Learn More About Brother Black and White Laser Printers. To learn more about Brother monochrome laser printers and how they can help you, click through the products above. 1  ENERGY STAR® Qualified: : Brother black and white laser printers enter an energy-saving sleep mode after being inactive for a certain period of time, which you can customize.
  Negative document text:  Definition of streaming. : relating to or being the transfer of data (such as audio or video material) in a con

#### <a id='toc6_1_1_3_'></a>[Siamese Network Model Training](#toc0_)

We train the **Siamese Network model** using the `train_siamese` function of the custom `training` module.

At the end of training, if a `WANDB_API_KEY` was provided (and thus the **Weights & Biases logger** was used), we also load the Weights & Biases dashboard to **plot the training results**.

If the model was **already trained** and **saved to a checkpoint file** (whose name is listed in `MODEL_CHECKPOINTS_FILES`), and `LOAD_MODELS_CHECKPOINTS` is set to `True`, we **load the model** from the checkpoint file instead of training it again.

In [27]:
# Model's checkpoint file path
model_checkpoint_file = MODELS_FOLDER + "/" + \
    MODEL_CHECKPOINTS_FILES[MODEL_TYPES.SIAMESE_NETWORK]

# Train or load the Siamese Network model
if LOAD_MODELS_CHECKPOINTS and os.path.exists(model_checkpoint_file):
    # Load the Siamese Network model from the checkpoint file
    print(
        "A checkpoint file for the Siamese Network model exists, loading the model...")
    siamese_network_model = models.SiameseNetwork.load_from_checkpoint(
        model_checkpoint_file, **siamese_network_args)
    print("Checkpoint for the Siamese Network model loaded.")
else:
    # Create a new logger for the Siamese Network model
    siamese_wandb_logger = None
    if wandb_api is not None:
        if wandb_logger is not None:
            wandb_logger.experiment.finish(quiet=True)
        siamese_wandb_logger = WandbLogger(
            log_model="all", project=wandb_project, name="Siamese Network")
    # Train the Siamese Network model
    siamese_training_infos = training.train_siamese(
        siamese_dataset=siamese_triplets_dataset,
        siamese_model=siamese_network_model,
        max_epochs=5,
        batch_size=512,
        split_ratio=0.8,
        logger=siamese_wandb_logger,
        save_path=model_checkpoint_file
    )
    # Show the W&B run's dashboard
    if wandb_api is not None:
        print("Training results for the Siamese Network model:")
        run_id = siamese_training_infos["run_id"]
        run_object: wandb_run.Run = wandb_api.run(
            f"{wandb_entity}/{wandb_project}/{run_id}")
        run_object.display(height=600)

checkpoint_folder: src/models/
checkpoint_name: siamese_network.ckpt


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name         | Type              | Params
---------------------------------------------------
0 | model        | Sequential        | 74.2 K
1 | triplet_loss | TripletMarginLoss | 0     
---------------------------------------------------
74.2 K    Trainable params
0         Non-trainable params
74.2 K    Total params
0.297     Total estimated model params size (MB)

Epoch 0: 100%|██████████| 2/2 [00:00<00:00, 25.95it/s, v_num=zx4d]
Average training loss for epoch 0:  0.999636709690094
Average validation loss for epoch 0:  0.9991834163665771
Epoch 1: 100%|██████████| 2/2 [00:00<00:00, 26.65it/s, v_num=zx4d]
Average training loss for epoch 1:  0.9994148015975952
Average validation loss for epoch 1:  0.9986151456832886
Epoch 2: 100%|██████████| 2/2 [00:00<00:00, 32.71it/s, v_num=zx4d]
Average training loss for epoch 2:  0.9989949464797974
Average validation loss for epoch 2:  0.9978592991828918
Epoch 3: 100%|██████████| 2/2 [00:00<00:00, 29.59it/s, v_num=zx4d]
Average training loss for epoch 3:  0.9986307621002197
Average validation loss for epoch 3:  0.9968227744102478
Epoch 4: 100%|██████████| 2/2 [00:00<00:00, 21.44it/s, v_num=zx4d]
Average training loss for epoch 4:  0.997823178768158
Average validation loss for epoch 4:  0.9954085946083069
Epoch 4: 100%|██████████| 2/2 [00:00<00:00, 13.03it/s, v_num=zx4d]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 2/2 [00:00<00:00,  4.82it/s, v_num=zx4d]


Training results for the Siamese Network model:




#### <a id='toc6_1_1_4_'></a>[Siamese Network Model Evaluation](#toc0_)


We compute the **Mean Average Precision** (over `MAP_N` queries, each considering a precision at **K** of `MAP_K`) and the **Recall at K** (with **K** defined by `RECALL_K`) to **evaluate the Siamese Network model's performance**.

Once again, as for the Word2Vec model (see section "[Word2Vec Model Evaluation](#toc5_1_1_1_)"), both metrics are computed using the trained **Siamese Network model**. The model takes as input the vector embeddings of documents and queries computed by the **Word2Vec model** and generates an embedding of size `SIAMESE_EMBEDDINGS_SIZE` for the given queries and for all the documents in the corpus. We then calculate the **cosine similarity** between the query embedding and the document embeddings to select the top `K` most relevant documents.

We therefore employ the same **index-then-retrieve** approach used for the Word2Vec model.

We then **print the results** of the evaluation of both metrics.

In [28]:
if EVALUATE_MODELS:
    print("Evaluating the Siamese Network model...")
    siamese_net_map_k = evaluation.compute_mean_average_precision_at_k(
        MODEL_TYPES.SIAMESE_NETWORK, queries_dict, docs_dict,
        k_documents=MAP_K, n_queries=MAP_N,
        print_debug=PRINT_EVALUATION_DEBUG,
        # Keyword arguments for the Siamese Network model
        model=siamese_network_model)
    # Evaluate the Siamese Network model (compute the Recall@K)
    siamese_net_recall_k = evaluation.compute_recall_at_k(
        MODEL_TYPES.SIAMESE_NETWORK, queries_dict, docs_dict,
        k_documents=RECALL_K,
        print_debug=PRINT_EVALUATION_DEBUG,
        # Keyword arguments for the Siamese Network model
        model=siamese_network_model)
    # Print the evaluation results
    print_model_evaluation_results(siamese_net_map_k, siamese_net_recall_k)

Evaluating the Siamese Network model...
Evaluating Siamese Network with Triplet Loss model to compute MAP@K...


Computing relevance scores for MAP@K for query 1/10: 100%|██████████| 979/979 [00:00<00:00, 1929.00it/s]


  Precision at 10 for query 1/10: 0.0


Computing relevance scores for MAP@K for query 2/10: 100%|██████████| 979/979 [00:00<00:00, 2000.35it/s]


  Precision at 10 for query 2/10: 0.0


Computing relevance scores for MAP@K for query 3/10: 100%|██████████| 979/979 [00:00<00:00, 2005.66it/s]


  Precision at 10 for query 3/10: 0.1


Computing relevance scores for MAP@K for query 4/10: 100%|██████████| 979/979 [00:00<00:00, 1790.71it/s]


  Precision at 10 for query 4/10: 0.0


Computing relevance scores for MAP@K for query 5/10: 100%|██████████| 979/979 [00:00<00:00, 1921.72it/s]


  Precision at 10 for query 5/10: 0.0


Computing relevance scores for MAP@K for query 6/10: 100%|██████████| 979/979 [00:00<00:00, 1803.02it/s]


  Precision at 10 for query 6/10: 0.1


Computing relevance scores for MAP@K for query 7/10: 100%|██████████| 979/979 [00:00<00:00, 1970.83it/s]


  Precision at 10 for query 7/10: 0.0


Computing relevance scores for MAP@K for query 8/10: 100%|██████████| 979/979 [00:00<00:00, 1698.51it/s]


  Precision at 10 for query 8/10: 0.0


Computing relevance scores for MAP@K for query 9/10: 100%|██████████| 979/979 [00:00<00:00, 1112.06it/s]


  Precision at 10 for query 9/10: 0.0


Computing relevance scores for MAP@K for query 10/10: 100%|██████████| 979/979 [00:00<00:00, 1811.84it/s]


  Precision at 10 for query 10/10: 0.1
Evaluating Siamese Network with Triplet Loss model to compute Recall@K...


Computing relevance scores for Recall@K...: 100%|██████████| 979/979 [00:00<00:00, 1666.16it/s]

MAP@10 for the Siamese Network with Triplet Loss model:
  > 0.030000000000000006
  Computed on 10 queries
  Single queries precision:
    Query 1049774: 0.0
    Query 525868: 0.0
    Query 526013: 0.1
    Query 787784: 0.0
    Query 263889: 0.0
    Query 1051339: 0.1
    Query 1051095: 0.0
    Query 524835: 0.0
    Query 524848: 0.0
    Query 1051211: 0.1
Recall@1000 for the Siamese Network with Triplet Loss model:
  > 1.0
  Computed for query 789292





---

<a id="6"></a>
# <a id='toc7_'></a>[🤖 Seq2Seq Transformer Model (DSI approach)](#toc0_)

In this section we implement 3 possible versions of a **Seq2Seq transformer model** to act as the final model for the evaluation of the **Differentiable Search Index** approach.

The **Seq2Seq transformer models** are trained to generate a **sorted list of document IDs** in response to a given **query**.

The 3 transformer models, described in detail in the sub-sections below, are:

1. **Seq2Seq Transformer Model using _teacher forcing_**: A Seq2Seq transformer model trained using only the **teacher forcing** technique; no auto-regressive decoding is used during training.

   At inference time, instead, the model uses an **auto-regressive decoding** technique to generate the sorted list of document IDs.

2. **Seq2Seq Transformer Model using _auto-regressive decoding_**: A Seq2Seq transformer model trained using only the **auto-regressive decoding** technique; no teacher forcing is used during training.

   This model also uses the same **auto-regressive decoding** technique at inference time to generate the sorted list of document IDs.

3. **Seq2Seq Transformer Model using _scheduled sampling_**: A Seq2Seq transformer model trained using the **scheduled sampling** technique, which mixes teacher forcing and auto-regressive decoding: during training, each decoder input token is taken either from the ground truth or from the model's own predictions, with a probability controlled by the `scheduled_sampling_decay` hyperparameter.

   Once again, the model uses only the **auto-regressive decoding** technique at inference time to generate the sorted list of document IDs.
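
The mixing step behind scheduled sampling can be sketched as follows. This is an illustration under stated assumptions, not the `DSITransformer` code: the linear schedule in `teacher_forcing_probability` is hypothetical (the notebook only names the `scheduled_sampling_decay` hyperparameter), and `predictions` stands for tokens the model produced itself, for example in a first decoding pass.

```python
import random


def teacher_forcing_probability(step, decay):
    """Hypothetical linear schedule: start at 1.0 and decrease by `decay` per step."""
    return max(0.0, 1.0 - decay * step)


def mix_decoder_inputs(ground_truth, predictions, teacher_prob, rng):
    """Pick each decoder input token from the ground truth with probability
    `teacher_prob`, otherwise from the model's own predictions."""
    return [
        truth if rng.random() < teacher_prob else predicted
        for truth, predicted in zip(ground_truth, predictions)
    ]


rng = random.Random(0)
ground_truth = [7, 5, 9]  # digits of the target document ID
predictions = [7, 3, 9]   # digits predicted by the model (hypothetical)

# teacher_prob = 1.0 reduces to pure teacher forcing,
# teacher_prob = 0.0 feeds the model only its own predictions
only_truth = mix_decoder_inputs(ground_truth, predictions, 1.0, rng)
only_predictions = mix_decoder_inputs(ground_truth, predictions, 0.0, rng)
print(only_truth, only_predictions)
```

The two extreme values of `teacher_prob` correspond to the first two models in the list above, while intermediate values give the scheduled sampling behavior.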

In [29]:
# Compute the max length of the document IDs
if REMAP_DOC_IDS:
    # Doc IDs are remapped to a range [0, n_docs-1], so the max length depends on the number of documents
    doc_ids_max_length = int(math.floor(math.log10(len(docs_dict))) + 1)
else:
    # We calculate the max length of the doc IDs as the length of the longest doc ID
    doc_ids_max_length = max([len(doc_id) for doc_id in docs_dict])

# Number of output tokens for the encoded document IDs (the 10 digits [0-9] plus the special tokens, i.e. end of sequence, padding, start of sequence)
output_tokens = 10 + 3

#### <a id='toc7_1_1_1_'></a>[Transformer Datasets Creation](#toc0_)

We create the **datasets** to train and evaluate the **Seq2Seq transformer models** using the `Seq2SeqDataset` class.

Two different datasets are created:

- **Indexing Dataset**: A dataset to train the model for the **indexing task**, in which the model learns to generate document IDs starting from the **tokenized documents' text** as source sequences.

   Items of the dataset have the form **`(encoded_document, encoded_doc_id)`**, where `encoded_document` is the tokenized version of the document's text (i.e. a **vector of word token IDs** in the tokenizer's vocabulary), computed using the tokenizer of a pretrained **BERT** model, and `encoded_doc_id` is a tokenized version of the document's ID in the documents dictionary. The latter is computed using an ad-hoc tokenizer which maps each digit of the document ID to an index (equal to the digit itself) and adds a special padding token, a special start-of-sequence token, and a special end-of-sequence token.

- **Retrieval Dataset**: A dataset to train the model for the **retrieval task**, in which the model learns to generate document IDs starting from the **tokenized queries' text** as source sequences.

   Items of the dataset have the form **`(encoded_query, encoded_doc_id)`**, where `encoded_query` is the tokenized version of the query's text, computed using the same **BERT** tokenizer, and `encoded_doc_id` is the tokenized version of the document's ID in the documents dictionary, computed using the same ad-hoc tokenizer used for the **Indexing Dataset**.

Both datasets are shared among the 3 transformer models for training and evaluation.
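
The ad-hoc document ID encoding described above can be sketched as follows. The indices chosen for the three special tokens (10, 11, 12), the placement of the padding after the end-of-sequence token, and the helper names are assumptions made for illustration; the notebook only states that each digit maps to its own value and that the vocabulary has `10 + 3` output tokens.

```python
# Hypothetical indices for the special tokens (digits 0-9 map to themselves)
SOS_TOKEN, EOS_TOKEN, PAD_TOKEN = 10, 11, 12


def encode_doc_id(doc_id, max_length):
    """Encode a numeric document ID string as a fixed-length token list."""
    digits = [int(character) for character in doc_id]
    if len(digits) > max_length:
        raise ValueError(f"Document ID '{doc_id}' is longer than {max_length} digits")
    tokens = [SOS_TOKEN] + digits + [EOS_TOKEN]
    # Pad to a fixed length: max_length digits plus the start and end tokens
    tokens += [PAD_TOKEN] * (max_length + 2 - len(tokens))
    return tokens


def decode_doc_id(tokens):
    """Recover the document ID string, stopping at the end-of-sequence token."""
    digits = []
    for token in tokens:
        if token == EOS_TOKEN:
            break
        if token in (SOS_TOKEN, PAD_TOKEN):
            continue
        digits.append(str(token))
    return "".join(digits)


encoded = encode_doc_id("42", max_length=3)
decoded = decode_doc_id(encoded)
print(encoded, decoded)
```

With this layout every encoded ID has length `max_length + 2`, so the targets of a batch can be stacked into a single tensor.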

In [30]:
# Get the datasets for the transformer model (datasets are shared between the 3 transformer models)
transformer_indexing_dataset = datasets.TransformerIndexingDataset(
    documents=docs_dict,
    doc_id_max_length=doc_ids_max_length,
    doc_max_length=TRANSFORMER_DOCUMENT_MAX_TOKENS,
    dataset_file_path=DATA_FOLDER + "/transformer_indexing_dataset.json",
    force_dataset_rebuild=FORCE_DICTIONARIES_REBUILD)
transformer_retrieval_dataset = datasets.TransformerRetrievalDataset(
    documents=docs_dict, queries=queries_dict,
    doc_id_max_length=doc_ids_max_length,
    query_max_length=TRANSFORMER_QUERY_MAX_TOKENS,
    dataset_file_path=DATA_FOLDER + "/transformer_retrieval_dataset.json",
    force_dataset_rebuild=FORCE_DICTIONARIES_REBUILD)

Loading the Transformer Indexing Dataset from src/data/transformer_indexing_dataset.json...
Loaded 979 documents from src/data/transformer_indexing_dataset.json
Loading the Transformer Retrieval Dataset from src/data/transformer_retrieval_dataset.json...
Loaded 1006 encoded queries and document IDs from src/data/transformer_retrieval_dataset.json


We print some **examples** of the **Indexing Dataset** and the **Retrieval Dataset** to visualize the data.

In [31]:
# Print some examples of the Transformers datasets
print_dataset_examples = True
if print_dataset_examples:
    print("Example of a <encoded_doc, encoded_doc_id> pair:")
    encoded_doc, encoded_doc_id = transformer_indexing_dataset[random.randint(
        0, len(transformer_indexing_dataset) - 1)]
    print("  Encoded document:\n  ", encoded_doc)
    print("  Encoded document ID:\n  ", encoded_doc_id)
    print("Example of a <encoded_query, encoded_doc_id> pair:")
    encoded_query, encoded_relevant_doc_id = transformer_retrieval_dataset[random.randint(
        0, len(transformer_retrieval_dataset) - 1)]
    print("  Encoded query:\n  ", encoded_query)
    print("  Encoded relevant document ID:\n  ", encoded_relevant_doc_id)

Example of a <encoded_doc, encoded_doc_id> pair:
  Encoded document:
   tensor([  101, 16510,  4923,  3556,  3957, 18496,  2545,  2175,  1011,  3805,
         3198,  3477, 11780,  2769,  6165,  2920,  7511,  2482,  5414,  1529,
         2030,  2919,  4923,  3556,  2482,  5414,  2130,  2190,  2482,  5414,
         4923,  3556,  8146,  9699,  3037,  6165,  2210,  3020,  7957, 10940,
         2164, 14344,  2015,   102,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,

#### <a id='toc7_1_1_2_'></a>[Transformer Models Initialization & Training](#toc0_)

We create an auxiliary function to **initialize, train and evaluate** the **3 Seq2Seq transformer models**.

The function uses the same constants and hyperparameters for all versions of the Transformer model, with the exception of the parameters used for training.

Each of the 3 Transformer models uses a different training approach (as explained in the introduction of this section); after training, the metrics (loss, accuracy, etc.) of both training phases are plotted using the **Weights & Biases logger**.

Note that at inference time, in order to retrieve the top `K` most relevant documents, all the different Transformer models use the same **auto-regressive decoding** technique (generating document IDs' tokens one at a time, conditioning the generation of the next token on the previously generated tokens).
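
A greedy version of auto-regressive decoding can be sketched as follows. The sketch is not the notebook's decoding routine: `next_token_scores` is a hypothetical stand-in for the transformer's decoder, the special token indices are assumed, and the way the notebook obtains several candidate IDs for the top `K` (for instance with beam search) is not shown here.

```python
SOS_TOKEN, EOS_TOKEN = 10, 11  # hypothetical special token indices
OUTPUT_TOKENS = 13             # 10 digits plus 3 special tokens


def greedy_decode(next_token_scores, max_length):
    """Generate a document ID one token at a time, always taking the best token."""
    generated = [SOS_TOKEN]
    # At most max_length digits followed by the end-of-sequence token
    for _ in range(max_length + 1):
        scores = next_token_scores(generated)
        next_token = max(range(len(scores)), key=scores.__getitem__)
        generated.append(next_token)
        if next_token == EOS_TOKEN:
            break
    return generated


def toy_next_token_scores(prefix):
    """Toy stand-in for the model: always spells out the document ID 759."""
    target = [7, 5, 9, EOS_TOKEN]
    scores = [0.0] * OUTPUT_TOKENS
    scores[target[len(prefix) - 1]] = 1.0
    return scores


decoded_tokens = greedy_decode(toy_next_token_scores, max_length=3)
print(decoded_tokens)
```

Each step conditions on the tokens generated so far, so the number of model calls depends on the length of the document ID rather than on the size of the corpus.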

For each model, we first train for the **indexing task**, using the `TransformerIndexingDataset` dataset, then train for the **retrieval task**, using the `TransformerRetrievalDataset` dataset (defined above).

Before starting to train each model for the retrieval task, the retrieval dataset is split into a **training**, a **validation** and a **test** set: the latter is then used to evaluate the models' performance, i.e. to compute the **Mean Average Precision** and the **Recall at K**.

To **compute the evaluation metrics**, we use an approach similar to the one used for the **Word2Vec** and **Siamese Network** models, with one difference: while **training the model takes longer** than training the baselines, the **retrieval phase is significantly faster**, because the model directly outputs the document IDs relevant to the input query and therefore does not need to compute the cosine similarity between the query and every document in the corpus to find the top `K` most relevant documents.

In [32]:

def train_and_evaluate_dsi_transformer(transformer_type):
    ''' Auxiliary function to train (or load checkpoints), show training results, and evaluate the transformer model of the given type '''

    # args to pass to the dsi transformer model
    use_scheduled_sampling_decay = \
        transformer_type == models.DSITransformer.TRANSFORMER_TYPES.SCHEDULED_SAMPLING_TRANSFORMER
    dsi_transformer_args = {
        "tokens_in_vocabulary": transformer_indexing_dataset.tokenizer.vocab_size,
        "embeddings_size": TRANSFORMER_EMBEDDINGS_SIZE,
        "target_tokens": output_tokens,
        "transformer_heads": 4,
        "layers": 3,
        "dropout": 0.1,
        "learning_rate": 0.00075,
        "batch_size": 512,
        "transformer_type": transformer_type,
        "scheduled_sampling_decay": 0.015 if use_scheduled_sampling_decay else 0.0
    }

    # Initialize the transformer model of the requested type
    transformer_model = models.DSITransformer(
        **dsi_transformer_args)

    # Model's checkpoint path
    model_type_string = ""
    if transformer_type == models.DSITransformer.TRANSFORMER_TYPES.SCHEDULED_SAMPLING_TRANSFORMER:
        model_type_string = "scheduled_sampling"
    elif transformer_type == models.DSITransformer.TRANSFORMER_TYPES.AUTOREGRESSIVE_TRANSFORMER:
        model_type_string = "autoregressive"
    elif transformer_type == models.DSITransformer.TRANSFORMER_TYPES.TEACHER_FORCINIG_TRANSFORMER:
        model_type_string = "teacher_forcing"
    else:
        raise ValueError(
            f"Invalid transformer type: {transformer_type}")
    model_checkpoint_file = MODELS_FOLDER + "/" + \
        model_type_string + "_" + \
        MODEL_CHECKPOINTS_FILES[MODEL_TYPES.DSI_TRANSFORMER]

    # Train the model or load its saved checkpoint
    transformer_retrieval_test_set = None
    transformer_retrieval_test_set_file = DATA_FOLDER + \
        f"/{model_type_string}_transformer_retrieval_test_set.json"
    if LOAD_MODELS_CHECKPOINTS and os.path.exists(model_checkpoint_file):
        # Load the saved models checkpoint
        print("A checkpoint for the model exists, loading the saved model checkpoint...")
        transformer_model = models.DSITransformer.load_from_checkpoint(
            model_checkpoint_file, **dsi_transformer_args)
        print("Model checkpoint loaded.")
        # Load the transformer retrieval test set from the JSON file
        print("Loading the transformer retrieval test set from the JSON file...")
        # Use a separate name for the file handle so the path variable is not overwritten
        with open(transformer_retrieval_test_set_file, "r") as test_set_file:
            transformer_retrieval_test_set = json.load(test_set_file)
        print("Transformer retrieval test set loaded.")
    else:
        # Create 2 loggers for the transformer model (one for the indexing task and one for the retrieval task)
        transformer_loggers = None
        if wandb_api is not None:
            transformer_wandb_logger_indexing = WandbLogger(
                log_model="all", project=wandb_project, name=transformer_type + " (Indexing)")
            transformer_wandb_logger_retrieval = WandbLogger(
                log_model="all", project=wandb_project, name=transformer_type + " (Retrieval)")
            transformer_loggers = [transformer_wandb_logger_indexing,
                                   transformer_wandb_logger_retrieval]
        # Train the transformer model for the indexing task, then for the retrieval task
        transformer_training_infos = training.train_transformer(
            transformer_indexing_dataset=transformer_indexing_dataset,
            transformer_retrieval_dataset=transformer_retrieval_dataset,
            transformer_model=transformer_model,
            max_epochs_list=[2, 2],
            batch_size=transformer_model.hparams.batch_size,
            indexing_split_ratios=(0.8, 0.2),
            retrieval_split_ratios=(0.75, 0.175, 0.075),
            logger=transformer_loggers,
            save_path=model_checkpoint_file
        )
        # Show the wandb training run's dashboard
        if wandb_api is not None:
            print(
                f"Indexing training results for the {transformer_type} model:")
            indexing_run_id = transformer_training_infos["run_ids"]["indexing"]
            indexing_run_object: wandb_run.Run = wandb_api.run(
                f"{wandb_entity}/{wandb_project}/{indexing_run_id}")
            indexing_run_object.display(height=1000)
            print(
                f"Retrieval training results for the {transformer_type} model:")
            retrieval_run_id = transformer_training_infos["run_ids"]["retrieval"]
            retrieval_run_object: wandb_run.Run = wandb_api.run(
                f"{wandb_entity}/{wandb_project}/{retrieval_run_id}")
            retrieval_run_object.display(height=1000)
        # Save the generated transformer retrieval test set to the JSON file
        print("Saving the transformer retrieval test set to the JSON file...")
        retrieval_test_dataset = transformer_training_infos["retrieval"]["test"]
        transformer_retrieval_test_set = {
            "encoded_queries": [],
            "encoded_doc_ids": []
        }
        retrieval_test_dataset_length = len(retrieval_test_dataset)
        for i in range(retrieval_test_dataset_length):
            encoded_query, doc_id = retrieval_test_dataset[i]
            transformer_retrieval_test_set["encoded_queries"].append(
                encoded_query.tolist())
            transformer_retrieval_test_set["encoded_doc_ids"].append(
                doc_id.tolist())
        # Use a separate name for the file handle so the path variable is not overwritten
        with open(transformer_retrieval_test_set_file, "w") as json_file:
            json.dump(transformer_retrieval_test_set, json_file)

    # Evaluate the transformer model (for the retrieval task)
    if EVALUATE_MODELS:
        transformer_retrieval_map_k = evaluation.compute_mean_average_precision_at_k(
            MODEL_TYPES.DSI_TRANSFORMER, queries_dict, docs_dict,
            k_documents=MAP_K, n_queries=MAP_N,
            print_debug=PRINT_EVALUATION_DEBUG,
            # Keyword arguments for the Transformer model
            model=transformer_model, retrieval_dataset=transformer_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set)
        transformer_retrieval_recall_k = evaluation.compute_recall_at_k(
            MODEL_TYPES.DSI_TRANSFORMER, queries_dict, docs_dict,
            k_documents=RECALL_K,
            print_debug=PRINT_EVALUATION_DEBUG,
            # Keyword arguments for the Transformer model
            model=transformer_model, retrieval_dataset=transformer_retrieval_dataset, retrieval_test_set=transformer_retrieval_test_set)
        print_model_evaluation_results(transformer_retrieval_map_k,
                                       transformer_retrieval_recall_k)

    return transformer_model, transformer_retrieval_map_k, transformer_retrieval_recall_k

<a id="6_1"></a>

## <a id='toc7_2_'></a>[Teacher Forcing Seq2Seq Transformer Model](#toc0_)

The first version of the **Seq2Seq transformer model** is trained using only the **teacher forcing** technique: no auto-regressive decoding is used during training.

This means that, during training, the decoder is fed the **ground truth** document ID tokens as its input sequence, so the model learns to predict each token of the correct document ID given the query and the preceding ground truth tokens.
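The input/target alignment used by teacher forcing can be illustrated with a minimal sketch. It assumes that a document ID is encoded as a sequence of digit tokens and that `START` and `END` are special tokens with the IDs shown; these names and values are hypothetical and may differ from the notebook's `DSITransformer` implementation.

```python
# Hypothetical special token IDs (assumptions, not taken from the notebook's code)
START, END = 10, 11


def teacher_forcing_pairs(doc_id: str):
    """Build (decoder_input, target) for one document ID.

    The decoder input is the ground truth shifted right by one position,
    so at step t the model sees the true tokens before t and must predict token t.
    """
    digits = [int(ch) for ch in doc_id]
    decoder_input = [START] + digits
    target = digits + [END]
    return decoder_input, target


decoder_input, target = teacher_forcing_pairs("920")
for step, expected in enumerate(target):
    visible = decoder_input[: step + 1]
    print(f"step {step}: decoder sees {visible} and must predict {expected}")
```

Because every decoder input comes from the ground truth, all positions can be processed in a single forward pass, which is one reason teacher forcing tends to be faster per epoch than step-by-step generation.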

At inference time, the model uses the usual **auto-regressive decoding** technique to generate, at each step, a raw score (a logit rather than a probability, as no softmax is applied to the output of the Transformer model) for every possible document ID token.
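Skipping the softmax does not change which token is ranked first, because softmax is a monotonic function of its inputs. The sketch below checks this on made-up scores for a hypothetical 5-token vocabulary; the numbers are illustrative only and are not outputs of the notebook's model.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank(scores):
    """Token indices sorted from highest to lowest score."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


# Made-up raw scores for a 5-token vocabulary (assumption for illustration)
logits = [1.2, -0.3, 3.1, 0.0, 2.4]
probabilities = softmax(logits)

# Softmax preserves the ordering of the scores, so both rankings agree
print(rank(logits))
print(rank(probabilities))
```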

We **train the model** and then, if a `WANDB_API_KEY` was provided, we also load the **Weights & Biases** dashboard to **plot the training results** for both the indexing and retrieval tasks (in this order).

After training, we evaluate the **teacher forcing Seq2Seq transformer model** by computing the usual **Mean Average Precision** (over `MAP_N` queries, each considering a precision at **K** of `MAP_K`) and the **Recall at K** (with **K** defined by `RECALL_K`).

We then **print the results** of the evaluation of both metrics.
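For reference, the two metrics can be sketched as follows. This uses the standard textbook definitions, with average precision normalized by `min(number of relevant documents, k)`, which is one common convention; the notebook's `evaluation` module may compute them differently. The first example reuses the predicted and relevant IDs logged for query 1 in the output below, while the second uses made-up IDs.

```python
def average_precision_at_k(predicted, relevant, k):
    """Average of precision@i over the ranks i (up to k) where a new relevant document appears."""
    relevant_set = set(relevant)
    seen = set()
    hits = 0
    precision_sum = 0.0
    for position, doc_id in enumerate(predicted[:k], start=1):
        if doc_id in relevant_set and doc_id not in seen:
            hits += 1
            precision_sum += hits / position
        seen.add(doc_id)
    denominator = min(len(relevant_set), k)
    return precision_sum / denominator if denominator else 0.0


def recall_at_k(predicted, relevant, k):
    """Fraction of the relevant documents that appear in the top k predictions."""
    relevant_set = set(relevant)
    if not relevant_set:
        return 0.0
    return len(relevant_set & set(predicted[:k])) / len(relevant_set)


def mean_average_precision_at_k(runs, k):
    """Mean of AP@K over a list of (predicted, relevant) pairs, one per query."""
    return sum(average_precision_at_k(p, r, k) for p, r in runs) / len(runs)


# Query 1 from the logged output: no predicted ID is relevant
logged_predicted = ['45', '66', '242', '24', '978', '978', '978', '978', '978', '978']
logged_relevant = ['920', '105', '14', '204', '315', '411', '455', '499', '500', '609', '752']

# Made-up query where two of three relevant documents are found at ranks 1 and 3
toy_predicted = ['14', '45', '105', '978', '978']
toy_relevant = ['920', '105', '14']

print(average_precision_at_k(logged_predicted, logged_relevant, k=10))
print(average_precision_at_k(toy_predicted, toy_relevant, k=5))
print(recall_at_k(toy_predicted, toy_relevant, k=5))
```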

In [33]:
# Train and evaluate the transformer model using only teacher forcing
teacher_forcing_transformer, teacher_forcing_transformer_map_k, teacher_forcing_transformer_recall_k = \
    train_and_evaluate_dsi_transformer(
        models.DSITransformer.TRANSFORMER_TYPES.TEACHER_FORCINIG_TRANSFORMER)

checkpoint_folder: src/models/
checkpoint_name: teacher_forcing_transformer.ckpt
Training the model for the indexing task...


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



  | Name                 | Type               | Params
------------------------------------------------------------
0 | model                | Transformer        | 201 K 
1 | get_input_embedding  | Embedding          | 2.0 M 
2 | get_target_embedding | Embedding          | 832   
3 | positional_encoder   | PositionalEncoding | 0     
4 | output_layer         | Linear             | 845   
5 | cross_entropy_loss   | CrossEntropyLoss   | 0     
------------------------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params
8.628     Total estimated model params size (MB)
C:\Users\valer\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in th

Epoch 0: 100%|██████████| 2/2 [00:05<00:00,  0.36it/s, v_num=u43n]
Average training loss for epoch 0:  2.5790724754333496
Average validation loss for epoch 0:  2.378133773803711
Average training accuracy for epoch 0:  0.12852619588375092
Average validation accuracy for epoch 0:  0.2568807303905487
Epoch 1:  50%|█████     | 1/2 [00:03<00:03,  0.29it/s, v_num=u43n]

Epoch 1: 100%|██████████| 2/2 [00:05<00:00,  0.37it/s, v_num=u43n]
Average training loss for epoch 1:  2.3914284706115723
Average validation loss for epoch 1:  2.3206422328948975
Average training accuracy for epoch 1:  0.25645875930786133
Average validation accuracy for epoch 1:  0.2568807303905487
Epoch 1: 100%|██████████| 2/2 [00:08<00:00,  0.23it/s, v_num=u43n]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 2/2 [00:09<00:00,  0.22it/s, v_num=u43n]
Trained the model for the indexing task.


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Training the model for the retrieval task...



  | Name                 | Type               | Params
------------------------------------------------------------
0 | model                | Transformer        | 201 K 
1 | get_input_embedding  | Embedding          | 2.0 M 
2 | get_target_embedding | Embedding          | 832   
3 | positional_encoder   | PositionalEncoding | 0     
4 | output_layer         | Linear             | 845   
5 | cross_entropy_loss   | CrossEntropyLoss   | 0     
------------------------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params
8.628     Total estimated model params size (MB)
C:\Users\valer\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in th

Epoch 0: 100%|██████████| 2/2 [00:00<00:00,  2.47it/s, v_num=1jv8]
Average training loss for epoch 0:  2.3551197052001953
Average validation loss for epoch 0:  2.2969632148742676
Average training accuracy for epoch 0:  0.25449585914611816
Average validation accuracy for epoch 0:  0.2569343149662018
Epoch 1: 100%|██████████| 2/2 [00:00<00:00,  2.22it/s, v_num=1jv8]
Average training loss for epoch 1:  2.3072762489318848
Average validation loss for epoch 1:  2.2692878246307373
Average training accuracy for epoch 1:  0.25836020708084106
Average validation accuracy for epoch 1:  0.2569343149662018
Epoch 1: 100%|██████████| 2/2 [00:01<00:00,  1.54it/s, v_num=1jv8]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 2/2 [00:01<00:00,  1.16it/s, v_num=1jv8]
Trained the model for the retrieval task.


Indexing training results for the Teacher Forcing Transformer model:




Retrieval training results for the Teacher Forcing Transformer model:


Saving the transformer retrieval test set to the JSON file...
Evaluating DSI Transformer model to compute MAP@K...
Top 10 (predicted) document IDs for query 1/10:
  ['45', '66', '242', '24', '978', '978', '978', '978', '978', '978']
> Actual relevant document IDs for the query:
  ['920', '105', '14', '204', '315', '411', '455', '499', '500', '609', '752']
  Precision at 10 for query 1/10: 0.0
Top 10 (predicted) document IDs for query 2/10:
  ['26', '4', '7', '0', '978', '978', '978', '978', '978', '978']
> Actual relevant document IDs for the query:
  ['906', '109', '119', '144', '20', '359', '528', '585', '665', '683']
  Precision at 10 for query 2/10: 0.0
Top 10 (predicted) document IDs for query 3/10:
  ['476', '4', '646', '7', '978', '978', '978', '978', '978', '978']
> Actual relevant document IDs for the query:
  ['952', '352', '632', '707', '762', '764', '771', '778', '781', '789']
  Precision at 10 for query 3/10: 0.0
Top 10 (predicted) document IDs for query 4/10:
  ['4', '44'

<a id="6_2"></a>

## <a id='toc7_3_'></a>[Autoregressive Seq2Seq Transformer Model](#toc0_)

The second version of the **Seq2Seq transformer model** is trained using an **auto-regressive decoding** technique: no teacher forcing is used during training.

This means that, during training, the model learns to generate the correct document IDs by relying only on its own predictions, and not on the ground truth document IDs.

The same **auto-regressive decoding** technique is also used at inference time to generate, at each step, a raw score (a logit rather than a probability, as no softmax is applied to the output of the Transformer model) for every possible document ID token.
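A minimal sketch of greedy auto-regressive decoding is shown below. The `toy_scores` function is a hypothetical stand-in for the Transformer, and the `START` and `END` token IDs are assumptions. The sketch decodes a single document ID; producing a ranked list of several IDs would require something more, such as beam search, which is not shown here.

```python
# Hypothetical special token IDs (assumptions, not taken from the notebook's code)
START, END = 10, 11


def greedy_decode(next_token_scores, max_length=8):
    """Generate tokens one at a time, feeding each prediction back as input."""
    generated = [START]
    for _ in range(max_length):
        scores = next_token_scores(generated)
        # Pick the token with the highest raw score (no softmax needed for argmax)
        next_token = max(range(len(scores)), key=lambda i: scores[i])
        if next_token == END:
            break
        generated.append(next_token)
    return generated[1:]  # drop the START token


def toy_scores(prefix):
    """Stand-in for the model: deterministically spells out 4, 2, then END."""
    script = [4, 2, END]
    position = len(prefix) - 1
    wanted = script[min(position, len(script) - 1)]
    return [1.0 if token == wanted else 0.0 for token in range(12)]


print(greedy_decode(toy_scores))
```

Each step depends on the token produced by the previous one, so generation is sequential. This is consistent with the longer epoch times logged for this model compared with the teacher forcing variant.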

We **train the model** and then, if a `WANDB_API_KEY` was provided, we also load the **Weights & Biases** dashboard to **plot the training results** for both the indexing and retrieval tasks (in this order).

After training, we evaluate the **auto-regressive Seq2Seq transformer model** by computing the usual **Mean Average Precision** (over `MAP_N` queries, each considering a precision at **K** of `MAP_K`) and the **Recall at K** (with **K** defined by `RECALL_K`).

We then **print the results** of the evaluation of both metrics.

In [34]:
# Train and evaluate the transformer model using only an autoregressive approach
autoregressive_transformer, autoregressive_transformer_map_k, autoregressive_transformer_recall_k = \
    train_and_evaluate_dsi_transformer(
        models.DSITransformer.TRANSFORMER_TYPES.AUTOREGRESSIVE_TRANSFORMER)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


checkpoint_folder: src/models/
checkpoint_name: autoregressive_transformer.ckpt
Training the model for the indexing task...



  | Name                 | Type               | Params
------------------------------------------------------------
0 | model                | Transformer        | 201 K 
1 | get_input_embedding  | Embedding          | 2.0 M 
2 | get_target_embedding | Embedding          | 832   
3 | positional_encoder   | PositionalEncoding | 0     
4 | output_layer         | Linear             | 845   
5 | cross_entropy_loss   | CrossEntropyLoss   | 0     
------------------------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params
8.628     Total estimated model params size (MB)
C:\Users\valer\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in th

Epoch 0:   0%|          | 0/2 [00:00<?, ?it/s] 

Epoch 0: 100%|██████████| 2/2 [01:11<00:00,  0.03it/s, v_num=s40d]
Average training loss for epoch 0:  2.6046524047851562
Average validation loss for epoch 0:  2.3555748462677
Average training accuracy for epoch 0:  0.1287824809551239
Average validation accuracy for epoch 0:  0.25789472460746765
Epoch 1: 100%|██████████| 2/2 [00:51<00:00,  0.04it/s, v_num=s40d]
Average training loss for epoch 1:  2.357470750808716
Average validation loss for epoch 1:  2.283738374710083
Average training accuracy for epoch 1:  0.25692886114120483
Average validation accuracy for epoch 1:  0.25789472460746765
Epoch 1: 100%|██████████| 2/2 [00:55<00:00,  0.04it/s, v_num=s40d]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 2/2 [00:56<00:00,  0.04it/s, v_num=s40d]
Trained the model for the indexing task.


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Training the model for the retrieval task...



  | Name                 | Type               | Params
------------------------------------------------------------
0 | model                | Transformer        | 201 K 
1 | get_input_embedding  | Embedding          | 2.0 M 
2 | get_target_embedding | Embedding          | 832   
3 | positional_encoder   | PositionalEncoding | 0     
4 | output_layer         | Linear             | 845   
5 | cross_entropy_loss   | CrossEntropyLoss   | 0     
------------------------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params
8.628     Total estimated model params size (MB)
C:\Users\valer\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in th

Epoch 0: 100%|██████████| 2/2 [00:06<00:00,  0.32it/s, v_num=4zji]
Average training loss for epoch 0:  2.3218514919281006
Average validation loss for epoch 0:  2.238758087158203
Average training accuracy for epoch 0:  0.2557554841041565
Average validation accuracy for epoch 0:  0.25581395626068115
Epoch 1: 100%|██████████| 2/2 [00:05<00:00,  0.40it/s, v_num=4zji]
Average training loss for epoch 1:  2.2694759368896484
Average validation loss for epoch 1:  2.224186658859253
Average training accuracy for epoch 1:  0.26055797934532166
Average validation accuracy for epoch 1:  0.25581395626068115
Epoch 1: 100%|██████████| 2/2 [00:05<00:00,  0.36it/s, v_num=4zji]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 2/2 [00:06<00:00,  0.33it/s, v_num=4zji]
Trained the model for the retrieval task.


Indexing training results for the Autoregressive Transformer model:




Retrieval training results for the Autoregressive Transformer model:


Saving the transformer retrieval test set to the JSON file...
Evaluating DSI Transformer model to compute MAP@K...
Top 10 (predicted) document IDs for query 1/10:
  ['55', '537', '63', '5', '978', '978', '978', '978', '978', '978']
> Actual relevant document IDs for the query:
  ['968', '969', '86', '216', '372', '519', '536', '653', '674', '749']
  Precision at 10 for query 1/10: 0.0
Top 10 (predicted) document IDs for query 2/10:
  ['8', '7', '3', '333', '978', '978', '978', '978', '978', '978']
> Actual relevant document IDs for the query:
  ['891', '892', '895', '893', '891', '894', '890']
  Precision at 10 for query 2/10: 0.0
Top 10 (predicted) document IDs for query 3/10:
  ['37', '978', '83', '8', '978', '978', '978', '978', '978', '978']
> Actual relevant document IDs for the query:
  ['921', '404', '480', '50', '598', '6', '664', '706', '723', '746']
  Precision at 10 for query 3/10: 0.0
Top 10 (predicted) document IDs for query 4/10:
  ['8', '77', '6', '3', '978', '978', '978

<a id="6_3"></a>

## <a id='toc7_4_'></a>[Scheduled Sampling Seq2Seq Transformer Model](#toc0_)

The final version of the **Seq2Seq transformer model** is trained using the **scheduled sampling** technique, which mixes teacher forcing and auto-regressive decoding: during training, each decoder input token is taken either from the ground truth or from the model's own predictions, based on a certain probability. This probability of using the ground truth is initially set to 1.0 and then decays linearly after each training epoch, at a rate defined by the `scheduled_sampling_decay` hyperparameter.

During training, at each new token generation, the model decides whether to use the ground truth token (teacher forcing) or the previously generated token (autoregression) as input for the next token generation, based on the current probability.
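The per-token choice and the decaying probability can be sketched as follows. The function and parameter names are hypothetical, and the decay value of `0.015` is an assumption chosen to match the probabilities printed in the training log below (1.0 for epoch 0, 0.985 for epoch 1); the notebook's actual implementation may differ.

```python
import random


def sampling_probability(epoch, decay, initial=1.0):
    """Linear schedule: subtract `decay` once per epoch, clamped to [0, 1]."""
    return max(0.0, min(1.0, initial - decay * epoch))


def choose_decoder_inputs(ground_truth, predictions, teacher_forcing_prob, rng):
    """For each position, keep the true token with probability p,
    otherwise use the model's own prediction for that position."""
    return [
        truth if rng.random() < teacher_forcing_prob else predicted
        for truth, predicted in zip(ground_truth, predictions)
    ]


rng = random.Random(0)
ground_truth = [9, 2, 0]  # digits of a made-up document ID
predictions = [4, 4, 7]   # made-up model outputs for the same positions

for epoch in range(3):
    p = sampling_probability(epoch, decay=0.015)
    mixed = choose_decoder_inputs(ground_truth, predictions, p, rng)
    print(f"epoch {epoch}: p={p:.3f}, decoder inputs={mixed}")
```

With a probability of 1.0 the procedure reduces to pure teacher forcing, and with 0.0 it reduces to pure auto-regressive training.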

At inference time, only the **auto-regressive decoding** approach is used to generate, at each step, a raw score (a logit rather than a probability, as no softmax is applied to the output of the Transformer model) for every possible document ID token.

We **train the model** and then, if a `WANDB_API_KEY` was provided, we also load the **Weights & Biases** dashboard to **plot the training results** for both the indexing and retrieval tasks (in this order).

After training, we evaluate the **scheduled sampling Seq2Seq transformer model** by computing the usual **Mean Average Precision** (over `MAP_N` queries, each considering a precision at **K** of `MAP_K`) and the **Recall at K** (with **K** defined by `RECALL_K`).

We then **print the results** of the evaluation of both metrics.

In [35]:
# Train and evaluate the transformer model using scheduled sampling
scheduled_sampling_transformer, scheduled_sampling_transformer_map_k, scheduled_sampling_transformer_recall_k = \
    train_and_evaluate_dsi_transformer(
        models.DSITransformer.TRANSFORMER_TYPES.SCHEDULED_SAMPLING_TRANSFORMER)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


checkpoint_folder: src/models/
checkpoint_name: scheduled_sampling_transformer.ckpt
Training the model for the indexing task...



  | Name                 | Type               | Params
------------------------------------------------------------
0 | model                | Transformer        | 201 K 
1 | get_input_embedding  | Embedding          | 2.0 M 
2 | get_target_embedding | Embedding          | 832   
3 | positional_encoder   | PositionalEncoding | 0     
4 | output_layer         | Linear             | 845   
5 | cross_entropy_loss   | CrossEntropyLoss   | 0     
------------------------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params
8.628     Total estimated model params size (MB)
C:\Users\valer\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in th

Epoch 0: 100%|██████████| 2/2 [00:05<00:00,  0.37it/s, v_num=kjt0]
Scheduled sampling probability for epoch 0:  1.0
Average training loss for epoch 0:  2.6037509441375732
Average validation loss for epoch 0:  2.4007248878479004
Average training accuracy for epoch 0:  0.14360859990119934
Average validation accuracy for epoch 0:  0.25789472460746765
Epoch 1: 100%|██████████| 2/2 [00:06<00:00,  0.33it/s, v_num=kjt0]
Scheduled sampling probability for epoch 1:  0.985
Average training loss for epoch 1:  2.363131046295166
Average validation loss for epoch 1:  2.296553611755371
Average training accuracy for epoch 1:  0.25678831338882446
Average validation accuracy for epoch 1:  0.25789472460746765
Epoch 1: 100%|██████████| 2/2 [00:09<00:00,  0.21it/s, v_num=kjt0]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 2/2 [00:10<00:00,  0.20it/s, v_num=kjt0]
Trained the model for the indexing task.
