<center><img src="https://storage.googleapis.com/kaggle-competitions/kaggle/21651/logos/header.png?t=2020-09-09-03-03-31" width="1000"></center>
<br>
<center><h1>TPU - Track knowledge states of 1M+ students in the wild</h1></center>
<br>

### In this notebook, I demonstrate a TensorFlow pipeline for the competition [Riiid! Answer Correctness Prediction](https://www.kaggle.com/c/riiid-test-answer-prediction). Beyond the general knowledge about using TPUs (presented in another Kaggle notebook of mine, [Detailed guide to custom training with TPUs](https://www.kaggle.com/yihdarshieh/detailed-guide-to-custom-training-with-tpus)), you will see some particular techniques, including:

* Dealing with tf.RaggedTensor
* Advanced tf.data.Dataset manipulation
* Multiple steps per TPU call
* Accumulating the predictions from the validation dataset on TPU

The model definition is a modified copy of the [distilbert.py](https://github.com/huggingface/transformers/tree/master/src/transformers/models/distilbert) file from [Hugging Face's transformers library](https://github.com/huggingface/transformers). If you find the code quality or object naming in this part imperfect, that is due to my own modifications and does not reflect the high-quality work done by Hugging Face.

Two notable contributions to the model definition in this notebook:

* implementation of decoder
* implementation of auto-regressive prediction generation (only works on CPU / GPU).

I tried to make this notebook work on both Kaggle and Google Colab. A few minimal changes are still required if you want to run it on Google Colab.

# Table of Contents

0. [Tips and clarifications](#tips)
1. [Environment](#environment)
2. [Packages](#packages)
3. [TPU](#tpu)
4. [Data](#data)
5. [Configuration](#configuration)
  * [Model / Training](#model-training-settings)
  * [Running mode](#running-mode)
  * [Vocabulary](#vocabulary-settings)
6. [Train Manager](#train-manager)
  * [Dataset](#dataset)
    - [TFRecord files](#tfrecord-files)
    - [Load TFRecord files - tf.io.RaggedFeature](#load-tfrecord-files)
    - [Train / Valid split](#split)
    - [Transformation - from tf.RaggedTensor to tf.Tensor](#transformation)
  * [Model inputs](#inputs)
    - [Special masks](#special-masks)
    - [Input tensors](#input-tensors)
    - [For validation dataset](#for-valid)
  * [Model definition](#model-def)
  * [Training Manager](#training-manager)
7. [Train / Valid](#train-valid)
8. [Conclusion](#conclusion)  

# List of tips and clarifications<a id='tips'></a>

## TPU - Tips and remarks

1. Since the code in this notebook works better with TensorFlow 2.3, I upgrade to TF 2.3 manually. However, the TPU runtime is still at version 2.2, so use the line below to select the matching TPU version.
    
    ```
       Client().configure_tpu_version(tf.__version__, restart_type='ifNeeded')
    ```
    
2. For the loss calculation in training / validation, on each TPU replica we compute the per-example losses and divide their sum by the number of places where the target is not `NON_TARGET_ID` (which is `-100`), counted on the whole batch before it is distributed to the replicas. This is because the gradients are synchronized across the replicas by summing them before the optimizer updates the model parameters. Computing the loss on each replica this way makes the final gradient correspond to the average loss over the places with real targets in the whole batch. See my other notebook [Detecting contradiction and entailment in multilingual text using TPUs](https://www.kaggle.com/yihdarshieh/masked-my-dear-watson-mlm-with-tpu#MLM-loss-calculation) for more details. The code looks like the following (where `train_batch` is a batch received by a replica, but `train_batch['nb_pred_places']` is the pre-computed number of places on the whole batch before it is distributed to the replicas):
    
    ```
        # Need to have the 1st dimension to make the losses not averaged
        losses = loss_obj(selected_targets[:, tf.newaxis], selected_logits[:, tf.newaxis])
        total_loss = tf.math.reduce_sum(losses)

        # `train_batch['nb_pred_places'][0]` is the total number of places used for calculating loss across replicas
        loss = total_loss / tf.cast(train_batch['nb_pred_places'][0], dtype=DTYPE)    
    ```
        
3. For training and validation, we want to perform multiple steps in a single TPU call to speed up the computation and avoid communication overhead between the local VM and the XLA/TPU host. This is achieved by wrapping the dataset iteration in a `tf.range` for loop. See `dist_train_multi_steps` and `dist_valid_multi_steps` in [Training Manager](#training-manager-main).

4. For validation, the last call to the TPU might contain fewer steps. However, the TPU requires the `tf.range` loop to have a number of steps known at compile time. Therefore the following 2 methods are provided, even though they are almost identical. The latter is used for the last call to the TPU.

    * dist_valid_multi_steps
    * dist_valid_multi_steps_last_call
    
5. During validation, we want to collect the predictions and other information in order to save them to files for further investigation. However, validation runs in graph mode on TPU. In order to accumulate the results, we use `tf.TensorArray`. See the method `dist_valid_multi_steps` in [Training Manager](#training-manager-main).

6. For validation, ideally we don't want to drop the last batch, which might contain fewer examples. The code in this notebook gives an error on Kaggle TPU if we don't drop the last batch, while on Colab with TPU no error occurs when we keep it. So the following condition is used to make this notebook run on both platforms:
    
    ```
    valid_ds.apply(
        tf.data.experimental.dense_to_ragged_batch(
            batch_size=batch_size, drop_remainder=(IS_KAGGLE and tpu is not None)
        )
    ) 
    ```

7. While training on Kaggle with TPU, `tf.train.CheckpointManager` can't save the checkpoints locally and gives a `File system scheme '[local]' not implemented` error. The following code solves this issue:

    ```
    ckpt_manager.checkpoint.save(
        file_prefix=ckpt_manager.directory + 'ckpt',
        options=tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
    )
    ```
    
8. Although I made (a lot of) effort to implement auto-regressive generation, unfortunately it can't run on TPU. It gives an `XLA can't deduce compile time constant output shape for strided slice` error, which seems to me to be a limitation of XLA. However, it works on GPU. Therefore I need to comment out the block below in `valid_step` in [Training Manager](#training-manager-main), so XLA won't compile it at all. You can check the version [GPU Running](https://www.kaggle.com/yihdarshieh/r3id-tf-tpu/output?scriptVersionId=49027558), which shows it works with GPU.

    ```
        #             if generative == 1:

        #                 start_pos = self.train_config.window_size - MAX_PRED_TIME_QUESTION_BUNDLE_LEN
        #                 _, logits = predictor.generate(
        #                     valid_batch, start_pos=start_pos,
        #                     window_size=self.train_config.window_size,
        #                     dim=self.config.dim,
        #                     c_mask=c_mask, r_mask=r_mask, r_c_mask=r_c_mask, c_r_mask=c_r_mask
        #                 )

        #             else:
    ```
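
The loss normalization described in tip 2 can be checked with plain arithmetic. The sketch below is not TPU code: it uses made-up per-example losses and a hypothetical 2-replica split (the names `replica_losses` and `per_replica_loss` are for illustration only) to show that, when each replica divides its loss sum by the global number of prediction places, summing over replicas recovers the mean loss over the whole batch.

```python
import math

# Hypothetical per-example losses at places with real targets, as seen by 2 replicas.
replica_losses = [
    [0.2, 0.4, 0.6],  # replica 0 holds 3 prediction places
    [1.0],            # replica 1 holds only 1 prediction place
]

# Number of prediction places on the whole batch, computed before distribution.
nb_pred_places = sum(len(losses) for losses in replica_losses)

# What each replica computes: its own loss sum divided by the GLOBAL count.
per_replica_loss = [sum(losses) / nb_pred_places for losses in replica_losses]

# Gradients are summed across replicas, so the effective loss is the sum of these values.
synced_loss = sum(per_replica_loss)

# Reference value: the mean loss over all prediction places in the whole batch.
global_mean = sum(sum(losses) for losses in replica_losses) / nb_pred_places

# For comparison: averaging inside each replica first weights the places unevenly.
mean_of_means = sum(sum(l) / len(l) for l in replica_losses) / len(replica_losses)

print(synced_loss, global_mean, mean_of_means)
```

The first two printed values agree, while the third differs because the single place on replica 1 would count as much as the three places on replica 0 together.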

## Model / Input - Terminology and explanation

1. In `EdFormerConfig`, we have
    * model_type:
        - `'cr'`: encoder only
        - `'ed'`: encoder decoder
    * generative (only applies to the decoder at inference time):
        - `True`: use auto-regressive generation, i.e. use the previous prediction result as input to the current prediction step. This is the usual way for decoder prediction.
        - `False`: treat each of the questions in a bundle (for which we want to predict answer correctness) as if it were the only question in the bundle, and only use the history before that bundle.

1. In the method `add_input_ids_and_targets` in [Input tensors](#input-tensors), we compute (quite a lot of) tensors that are used in the model. Among them:
    * c_input_ids: The (encoded) ids for questions and lectures. This is used by the model encoder. Questions and lectures (along with 4 special tokens) form a single vocabulary, and their encoded ids (i.e. `c_input_ids`) are fed to a single embedding layer.
    * r_input_ids: The (encoded) answer correctness. The values won't be `0` and `1` anymore - the token `0` means `answered incorrectly`, but `0` in the encoded ids means padding. See [Vocabulary settings](#vocabulary-settings) for the mapping between tokens and their encoded ids. `r_input_ids` is used only if the model is of type `encoder only`, which is the case if we pass `model_type='cr'` to `EdFormerConfig.__init__()`.
    * d_input_ids: This is a shifted version of `r_input_ids` that gives the ids at the previous place in a user interaction history. This is used only if the model is of type `encoder decoder`, which is specified by passing `model_type='ed'` to `EdFormerConfig.__init__()`.
    
2. `r_input_ids` and `d_input_ids` will have different values depending on whether we are at training or inference (including validation) time, because during inference the answer correctness information is not available (at the time we are predicting for a particular place). Basically, the values not available at inference time are replaced by `MASK_ID`.

3. For the `encoder decoder` type, `d_input_ids` will have different values at inference time depending on whether `EdFormerConfig.generative` is `True` or `False`:

    * if `generative` is `False`, we copy the value of `d_input_ids` at the 1st question in the question bundle being predicted and use it as the value of `d_input_ids` for all the remaining places in the same bundle. This trick also applies to the positional information and attention masks - the objective is to treat each question in the question bundle under prediction as if it were the only question in the bundle.
    * if `generative` is `True`, we leave it as it is - it contains `MASK_ID` for all places in the question bundle to predict except the 1st question in that bundle.
    
4. Although we have `r_input_ids` and `d_input_ids`, the decoder self-attention mask and the decoder-to-encoder attention mask are called `r_mask` and `r_c_mask` rather than `d_mask` and `d_c_mask`.

5. Considering the number of options provided in this notebook, it is not easy to explain everything in detail - in particular while I am still in the competition and would like to improve my scores. I will try to answer questions if you have any.
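
To make the relationship between `r_input_ids` and `d_input_ids` concrete, here is a minimal plain-Python sketch. It is not the notebook's `add_input_ids_and_targets`: the helpers `encode_responses`, `shift_right` and `copy_first_in_bundle` are hypothetical, the id values mirror those listed in [Vocabulary settings](#vocabulary-settings), and the example assumes a single question bundle under prediction at the end of the history.

```python
START_ID, MASK_ID = 1, 3
# token -> id for responses: lecture (-1), answered incorrectly (0), answered correctly (1)
RESPONSE_TOKEN_TO_ID = {-1: 4, 0: 5, 1: 6}


def encode_responses(answered_correctly, pred_time_mask):
    """Encode response tokens; responses at prediction-time places are unknown, so mask them."""
    return [
        MASK_ID if in_pred else RESPONSE_TOKEN_TO_ID[token]
        for token, in_pred in zip(answered_correctly, pred_time_mask)
    ]


def shift_right(ids):
    """The decoder input at place t is the response id at place t - 1."""
    return [START_ID] + ids[:-1]


def copy_first_in_bundle(d_ids, pred_time_mask):
    """Non-generative case: reuse the decoder input of the 1st predicted question for the whole bundle."""
    first = pred_time_mask.index(1)
    return [d_ids[first] if in_pred else d for d, in_pred in zip(d_ids, pred_time_mask)]


answered_correctly = [-1, 1, 0, 1, 1]  # a lecture followed by 4 questions
pred_time_mask = [0, 0, 0, 1, 1]       # the last 2 questions form the bundle being predicted

r_input_ids = encode_responses(answered_correctly, pred_time_mask)
d_input_ids = shift_right(r_input_ids)  # what the generative case keeps as it is

print(r_input_ids)
print(d_input_ids)
print(copy_first_in_bundle(d_input_ids, pred_time_mask))
```

In the shifted sequence, only the 1st question of the predicted bundle still sees a real previous response; the remaining places see `MASK_ID` unless the non-generative copy is applied.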

# Environment<a id='environment'></a>

This section makes the notebook work on both Kaggle and Google Colab.

In order to use this notebook on Google Colab, you have to specify your GCP bucket name and your project name. Also, copy the directories/files specified in the [Data section](#data) to your GCP bucket.

## Important for GPU

In [None]:
# It is unclear why the model does not run on GPU with tensorflow 2.3.1.
# If a GPU is enabled, we skip the upgrade to tensorflow 2.3.1 later.

gpu_info = !nvidia-smi
if 'command not found' in gpu_info[0]:
    USE_GPU = False
elif 'GPU  Name ' in str(gpu_info):
    USE_GPU = True
else:
    USE_GPU = False
        
print(f'USE_GPU: {USE_GPU}')

In [None]:
import os

IS_KAGGLE = os.path.isdir('/kaggle/input')

if IS_KAGGLE:
    BASE_DIR = '/kaggle/input'
else:
    BUCKET_DIR = 'gs://[YOUR_GCP_BUCKET_NAME]/r3id'
    BASE_DIR = '/content/drive/My Drive/r3id'
    BASE_DIR_QUOTED = '"/content/drive/My Drive/r3id"'
    
if not IS_KAGGLE:

    # Access GCP Bucket
    from google.colab import auth
    auth.authenticate_user()
    project_id = '[YOUR_GCP_PROJECT_NAME]'
    !gcloud config set project {project_id}
    !gsutil ls {BUCKET_DIR}

    # Access Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    if not os.path.isdir(BASE_DIR):

        !mkdir {BASE_DIR_QUOTED}
        !gsutil -m cp -r 'gs://[YOUR_GCP_BUCKET_NAME]/r3id' "/content/drive/My Drive"

!ls -l '{BASE_DIR}'

# Packages<a id='packages'></a>

## Install

In [None]:
# Use a newer version of huggingface's `tokenizers` (?? -> 0.9.2)    
# Use a newer version of huggingface's `transformers` (3.0.2 -> 3.4.0) 
!pip uninstall -y datatable
!pip uninstall -y tokenizers
!pip uninstall -y transformers

if IS_KAGGLE:
    !pip install '{BASE_DIR + '/r3id-packages/datatable-0.11.0-cp37-cp37m-manylinux2010_x86_64.whl'}'
    !pip install --upgrade '{BASE_DIR + '/r3id-packages/tokenizers-0.9.2-cp37-cp37m-manylinux1_x86_64.whl'}'
    if not USE_GPU:
        !pip uninstall -y tensorflow
        !pip install --upgrade tensorflow==2.3.1
        !pip install --upgrade tensorflow-probability
else:
    !pip install '{BASE_DIR + '/r3id-packages/datatable-0.11.0-cp36-cp36m-manylinux2010_x86_64.whl'}'
    !pip install --upgrade '{BASE_DIR + '/r3id-packages/tokenizers-0.9.2-cp36-cp36m-manylinux1_x86_64.whl'}'

!pip install --upgrade '{BASE_DIR + '/r3id-packages/transformers-3.4.0-py3-none-any.whl'}'
!pip install cloud-tpu-client

## Import

In [None]:
import os
import pandas as pd
import datatable as dt
import json
from collections import defaultdict
import random
import math
import datetime
import collections
import tensorflow as tf
os.environ['TF_DETERMINISTIC_OPS'] = '1'
import tensorflow_probability as tfp
from copy import deepcopy
import gc
gc.enable()
from cloud_tpu_client import Client

import transformers
from transformers import PretrainedConfig, DistilBertConfig
from transformers.modeling_tf_distilbert import TFFFN, TFTransformer
from transformers.modeling_tf_distilbert import TFMultiHeadSelfAttention as HFTFMultiHeadSelfAttention
from transformers.modeling_tf_distilbert import TFTransformerBlock as HFTFTransformerBlock
from transformers.modeling_tf_distilbert import TFFFN as HFTFFFN

from transformers.modeling_tf_utils import (
    TFPreTrainedModel,
    TFSharedEmbeddings,
    keras_serializable,
    shape_list,
)
if IS_KAGGLE:
    import riiideducation
    from kaggle_datasets import KaggleDatasets

# TPU<a id='tpu'></a>

Set up the TPU.

Since we upgrade TensorFlow manually from the default version (2.2) to 2.3, we need to use the following code to configure the TPU version.

    Client().configure_tpu_version(tf.__version__, restart_type='ifNeeded')
    


In [None]:
# Detect hardware, return appropriate distribution strategy

try:
    if IS_KAGGLE:
        Client().configure_tpu_version(tf.__version__, restart_type='ifNeeded')
    # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

print("REPLICAS: ", strategy.num_replicas_in_sync)

# Data<a id='data'></a>

Paths to data files used for this notebook.

In [None]:
n_contents_dict_path = f'{BASE_DIR}/r3id-info-public/n_contents_dict.json'

question_tags_info_path = f'{BASE_DIR}/r3id-info-public/question_tags_info.json'
lecture_tags_info_path = f'{BASE_DIR}/r3id-info-public/lecture_tags_info.json'

question_part_info_path = f'{BASE_DIR}/r3id-info-public/question_part_info.json'
lecture_part_info_path = f'{BASE_DIR}/r3id-info-public/lecture_part_info.json'

valid_info_path = f'{BASE_DIR}/r3id-info-public/valid_info_fold_1.json'
train_valid_split_indices_path = f'{BASE_DIR}/r3id-info-public/train_valid_split_indices_fold_1.json'

train_tfrec_dir_local = f'{BASE_DIR}/ednet-tfrecords-sequential'
valid_tfrec_dir_local = f'{BASE_DIR}/r3id-tfrecords-valid-public'

if tpu is None:
    
    train_tfrec_dir = f'{BASE_DIR}/ednet-tfrecords-sequential'
    valid_tfrec_dir = f'{BASE_DIR}/r3id-tfrecords-valid-public'

elif IS_KAGGLE:
    
    train_tfrec_dir = KaggleDatasets().get_gcs_path('ednet-tfrecords-sequential')
    valid_tfrec_dir = KaggleDatasets().get_gcs_path('r3id-tfrecords-valid-public')

    print(f'train_tfrec_dir from Kaggle = {train_tfrec_dir}\n')
    print(f'valid_tfrec_dir from Kaggle = {valid_tfrec_dir}\n')

    # you can list the buckets
    !gsutil ls $train_tfrec_dir
    !gsutil ls $valid_tfrec_dir
        
else:
    
    train_tfrec_dir = f'{BUCKET_DIR}/ednet-tfrecords-sequential'
    valid_tfrec_dir = f'{BUCKET_DIR}/r3id-tfrecords-valid-public'

# Configuration<a id='configuration'></a>

## Model / Training settings<a id='model-training-settings'></a>

In [None]:
MODEL_TYPE = 'ed'
MODEL_SIZE = 'small'  # specify some model size parameters in the next cell
MODEL_DESC = 'edformer-tpu'

ACTIVATION = 'gelu'
USE_ABS_POS = False
SHARE_POS_EMBEDDING = True
USE_TAGS = True
USE_PART = True
USE_PRIOR_EXPLANATION = True
USE_PRIOR_QUESTION_ELAPSED_TIME_INPUT = True
ALLOW_BUNDLE_ATTEN = False
GENERATIVE = True

WINDOW_SIZE = 96
LOSS_WEIGHT_WINDOW_SIZE = None
BATCH_SIZE = 64 * strategy.num_replicas_in_sync
PRED_BATCH_SIZE = 256 * strategy.num_replicas_in_sync

assert BATCH_SIZE // strategy.num_replicas_in_sync != WINDOW_SIZE
assert PRED_BATCH_SIZE // strategy.num_replicas_in_sync != WINDOW_SIZE

N_EPOCHS = 6
STEPS_PER_CALL = 1000

LR = 1e-3
END_LR = 5e-4
WARMUP_STEPS = 4000

DETERMINISTIC = False

if DETERMINISTIC:
    
    SEED = 2021
    N_PARALLEL_READS = None  # 1
    N_PARALLEL_CALLS = None  # 1
    SHUFFLE_BUFFER_SIZE = 1

else:
    
    SEED = None
    N_PARALLEL_READS = 16
    N_PARALLEL_CALLS = tf.data.experimental.AUTOTUNE
    SHUFFLE_BUFFER_SIZE = 4096

MAX_TRAIN_ITER_STEPS = None
MAX_VALID_ITER_STEPS = None

PRINTING_STEPS = 1000

CKPT_DIR = f'{MODEL_TYPE}-{MODEL_SIZE}-win-{WINDOW_SIZE}-bs-{BATCH_SIZE}-{MODEL_DESC}/'

#### Specify some model size parameters and assign a name to it.

In [None]:
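if MODEL_SIZE == 'small':

    N_LAYERS = 2
    N_HEADS = 4
    DIM = 128
    HIDDEN_DIM = 4 * DIM

else:
    # Only the 'small' size is defined in this notebook; fail early on anything else.
    raise ValueError(f'Unknown MODEL_SIZE: {MODEL_SIZE}')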

## Running mode<a id='running-mode'></a>

Determine whether we want to run training, validation or prediction for commit / submission to this competition.

Here, validation actually refers to a validation pipeline I wrote for simulating the competition submission.
That part of the code is not presented here. For this notebook, only `TRAIN = True` will be used, and it contains the
validation which is run in the tf.data.Dataset way - which is much faster.

In [None]:
DTYPE = tf.float32
tf.keras.backend.set_floatx('float32')

TRAIN = True
VALID = False
PRED = False

RESUME_TRAINING = False

N_FILES = 6
SUBMISSION = False

if IS_KAGGLE:
    N_FILES = len(os.listdir('/kaggle/input/riiid-test-answer-prediction'))
    SUBMISSION = (N_FILES != 6)
else:
    PRED = False

if SUBMISSION:
    
    TRAIN = False
    VALID = False
    PRED = True

if not TRAIN:
    RESUME_TRAINING = False

DEBUG = True
PROBE = False

CKPT_TRAIN_PATH = None
CKPT_PRED_PATH = None
if CKPT_TRAIN_PATH is None:
    
    if IS_KAGGLE:
        CKPT_TRAIN_PATH = './'
    else:
        CKPT_TRAIN_PATH = f'{BUCKET_DIR}/r3id-ckpts/{CKPT_DIR}'
        
        if TRAIN and not RESUME_TRAINING:
            _state = !gsutil -q stat {CKPT_TRAIN_PATH}*; echo $?
            already_existed = 1 - int(_state[0])
            assert not already_existed       

if CKPT_PRED_PATH is None:
    
    if IS_KAGGLE:
        CKPT_PRED_PATH = f'{BASE_DIR}/{CKPT_DIR}'
    else:
        CKPT_PRED_PATH = f'{BUCKET_DIR}/r3id-ckpts/{CKPT_DIR}'

In [None]:
print(f'SUBMISSION: {SUBMISSION}')
print(f'TRAIN: {TRAIN}')
print(f'VALID: {VALID}')
print(f'PRED: {PRED}')

print('')

print(f'CKPT_TRAIN_PATH: {CKPT_TRAIN_PATH}')
print(f'CKPT_PRED_PATH: {CKPT_PRED_PATH}')

## Probed Info<a id='probed-info'></a>

In [None]:
MAX_TRAIN_HISTORY_LEN = 17917
MAX_PRED_TIME_QUESTION_BUNDLE_LEN = 10

## Vocabulary settings<a id='vocabulary-settings'></a>

The notions of `token` and `id` are different here. Think of an `id` as the encoding of a `token`. It might be confusing - for example, a question whose `content_id` is `100` (in the sense of token) will have the input id `104` in the model encoder, because we have 4 special tokens here. See [a later section](#token-id-mapping) for more details.
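
As a small illustration of the token / id distinction, the sketch below encodes question tokens by reserving the first 4 ids for the special tokens. The function `question_token_to_id` is a hypothetical helper, not part of the notebook, and the sketch deliberately leaves out how lecture tokens are placed in the shared vocabulary, since that is not described here.

```python
# Special tokens and their ids, using the same values as the settings in this section:
# PAD_TOKEN, START_TOKEN, END_TOKEN, MASK_TOKEN -> PAD_ID, START_ID, END_ID, MASK_ID
SPECIAL_TOKEN_TO_ID = {-2: 0, -3: 1, -4: 2, -5: 3}
N_SPECIAL_TOKENS = len(SPECIAL_TOKEN_TO_ID)


def question_token_to_id(token):
    """Map a special token to its reserved id, or shift a question `content_id` past the reserved ids."""
    if token in SPECIAL_TOKEN_TO_ID:
        return SPECIAL_TOKEN_TO_ID[token]
    return token + N_SPECIAL_TOKENS


print(question_token_to_id(100))  # question with content_id 100
print(question_token_to_id(-2))   # the padding token
```

Keeping the special tokens negative means they can never collide with a real `content_id`, which starts at `0`.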

In [None]:
PAD_TOKEN = -2
START_TOKEN = -3
END_TOKEN = -4
MASK_TOKEN = -5

PAD_ID = 0
START_ID = 1
END_ID = 2
MASK_ID = 3

RESPONSE_LECTURE_TOKEN = -1
RESPONSE_FALSE_TOKEN = 0
RESPONSE_TRUE_TOKEN = 1

RESPONSE_LECTURE_ID = 4
RESPONSE_FALSE_ID = 5
RESPONSE_TRUE_ID = 6

NON_TARGET_ID = -100

TAG_VOCAB_SIZE = 189  # including extra `PAD_TOKEN`.
PART_VOCAB_SIZE = 9  # including extra `PAD_TOKEN`.
PRIOR_EXPLANATION_VOCAB_SIZE = 3  # including extra `PAD_TOKEN`.

# --------------------------------------------------

N_TAGS_PER_CONTENT = 6

# --------------------------------------------------

# This is only an assumption (including `0` for padding)
MAX_HISTORY_LEN = 20480
# --------------------------------------------------

# Train Manager<a id='train-manager'></a>

## Dataset<a id='dataset'></a>

### TFRecord files<a id='tfrecord-files'></a>

In [None]:
train_tfrec_fns = os.listdir(train_tfrec_dir_local)
# Sort the files - for verification purposes
train_tfrec_fns = sorted(train_tfrec_fns, key=lambda x: int(x.replace('EdNet-user-history-', '').replace('.tfrecord', '')))
    
if tpu is None:
    train_tfrec_paths = [os.path.join(train_tfrec_dir, fn) for fn in train_tfrec_fns]
else:
    train_tfrec_paths = [train_tfrec_dir + f'/{fn}' for fn in train_tfrec_fns]

train_tfrec_paths

In [None]:
valid_tfrec_fns = os.listdir(valid_tfrec_dir_local)

if tpu is None:
    valid_tfrec_paths = [os.path.join(valid_tfrec_dir, fn) for fn in valid_tfrec_fns]
else:
    valid_tfrec_paths = [valid_tfrec_dir + f'/{fn}' for fn in valid_tfrec_fns]

valid_tfrec_paths

### Load TFRecord files - tf.io.RaggedFeature<a id='load-tfrecord-files'></a>

We are dealing with `tf.io.RaggedFeature` here because users have interaction histories of different lengths. When we later use `tf.data.experimental.dense_to_ragged_batch` to batch the examples, we will get `tf.RaggedTensor`.

We also inject some extra information, like
* the length of each sequence
* whether an interaction is in the prediction time (for which we want to predict the answer correctness)
* absolute position of an interaction in the history

The tfrecord file for the validation dataset actually contains the full history of a user (if it is selected to be in the validation dataset). But it has extra attributes, including
* the number of blocks in a user history selected to be used for validation
* the starting and ending (exclusive) position of each block selected

This extra information enables us to build the truncated user history for computing the validation score.

The notion of `block` is defined as:
    
    a sequence starting with a potential lecture (which must be the first element in the history, or whose previous interaction is a question) and subsequent lectures (if any), followed by a unique question bundle (if any).
    
For a given user, at any time, only 1 block can be in prediction time. 
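
The block definition above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code used to build the tfrecord files: `split_into_blocks` is a hypothetical helper, `content_type_id` is taken to be `0` for a question and `1` for a lecture, and questions in the same bundle are assumed to share a `task_container_id`. Blocks are returned as `(start, end)` pairs with an exclusive end.

```python
def split_into_blocks(content_type_ids, task_container_ids):
    """Return (start, end) index pairs, end exclusive, one pair per block."""
    blocks = []
    start = 0
    n = len(content_type_ids)
    for i in range(1, n):
        prev_is_question = content_type_ids[i - 1] == 0
        cur_is_lecture = content_type_ids[i] == 1
        # A question starts a new bundle when its container differs from the previous question's.
        new_bundle = (
            not cur_is_lecture
            and prev_is_question
            and task_container_ids[i] != task_container_ids[i - 1]
        )
        # A block can only end right after a question: lectures stay attached to what follows.
        if prev_is_question and (cur_is_lecture or new_bundle):
            blocks.append((start, i))
            start = i
    if n > 0:
        blocks.append((start, n))
    return blocks


# 2 lectures, a bundle of 2 questions, a single question, then a lecture and a question
content_type_ids = [1, 1, 0, 0, 0, 1, 0]
task_container_ids = [0, 1, 2, 2, 3, 4, 5]
print(split_into_blocks(content_type_ids, task_container_ids))
```

With this input, the leading lectures are grouped with the first bundle, the single question forms its own block, and the last lecture is grouped with the question that follows it.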

In [None]:
# --------------------------------------------------------------------------------
# For training dataset

train_raw_features = {
    'user_id': tf.io.FixedLenFeature([], dtype=tf.int64),
    'row_id': tf.io.RaggedFeature(value_key='row_id', dtype=tf.int64),
    'timestamp': tf.io.RaggedFeature(value_key='timestamp', dtype=tf.int64),
    'content_id': tf.io.RaggedFeature(value_key='content_id', dtype=tf.int64),
    'content_type_id': tf.io.RaggedFeature(value_key='content_type_id', dtype=tf.int64),
    'task_container_id': tf.io.RaggedFeature(value_key='task_container_id', dtype=tf.int64),
    'user_answer': tf.io.RaggedFeature(value_key='user_answer', dtype=tf.int64),
    'answered_correctly': tf.io.RaggedFeature(value_key='answered_correctly', dtype=tf.int64),
    'prior_question_elapsed_time': tf.io.RaggedFeature(value_key='prior_question_elapsed_time', dtype=tf.float32),
    'prior_question_had_explanation': tf.io.RaggedFeature(value_key='prior_question_had_explanation', dtype=tf.int64),
}


def parse_train_example(example):
    """Parse an example from the training tfrecord files.

    Add the following extra attributes:
 
        - seq_len: The length of the user interaction history for training, before the validation part is removed from it.
        - prev_seq_len: The length of the interaction history of the example from which the current example is obtained. Here, it just
            equals `seq_len` since it has no source.
        - start: The starting index in the interaction history of the example from which the current example is obtained. Here, it is `0`.
        - end: The ending index in the interaction history of the example from which the current example is obtained. Here, it is `seq_len - 1`.
        - pred_time_mask: A 1-D `tf.Tensor` of `0` and `1`, indicating if a place is in the prediction time. Here, all of them are `0`.          

    """

    _parsed = tf.io.parse_single_example(example, train_raw_features)
    
    parsed = {}
    parsed['user_id'] = _parsed['user_id']
    parsed['seq_len'] = tf.reduce_sum(tf.ones_like(_parsed['row_id'], dtype=tf.int32))
    parsed['prev_seq_len'] = parsed['seq_len']
    parsed['start'] = tf.constant(0, dtype=tf.int32)
    parsed['end'] = parsed['seq_len'] - 1
    
    for k in train_raw_features:
        
        data = _parsed[k]
        if k not in ['row_id', 'user_id', 'timestamp', 'prior_question_elapsed_time']:
            data = tf.cast(data, dtype=tf.int32)
        elif k == 'prior_question_elapsed_time':
            data = tf.cast(data, dtype=DTYPE)
        if k != 'user_id':
            parsed[k] = data
            
    # We need to use `START_TOKEN` rather than `PAD_TOKEN` here.
    parsed['shifted_answered_correctly'] = tf.concat([[START_TOKEN], parsed['answered_correctly']], axis=0)

    # This should be all `0`.
    pad_mask = tf.cast(parsed['row_id'] == PAD_TOKEN, dtype=tf.int32)

    # The tfrecord dataset is only used for training (excluding the part used for validation), so no place should be in the prediction time.
    pred_time_mask = tf.zeros_like(parsed['timestamp'], dtype=tf.int32)
    
    # This should be all `0`.
    pred_time_mask = pred_time_mask * (1 - pad_mask) + (PAD_TOKEN) * pad_mask
    
    parsed['pred_time_mask'] = pred_time_mask

    parsed['abs_pos'] = tf.range(parsed['seq_len'], dtype=tf.int32)
    parsed['shifted_abs_pos'] = tf.concat([[START_TOKEN], parsed['abs_pos'][:-1]], axis=0)
    
    return parsed


# --------------------------------------------------------------------------------
# For validation dataset

train_features_with_valid_info = {
    'user_id': tf.io.FixedLenFeature([], dtype=tf.int64),
    'row_id': tf.io.RaggedFeature(value_key='row_id', dtype=tf.int64),
    'timestamp': tf.io.RaggedFeature(value_key='timestamp', dtype=tf.int64),
    'content_id': tf.io.RaggedFeature(value_key='content_id', dtype=tf.int64),
    'content_type_id': tf.io.RaggedFeature(value_key='content_type_id', dtype=tf.int64),
    'task_container_id': tf.io.RaggedFeature(value_key='task_container_id', dtype=tf.int64),
    'user_answer': tf.io.RaggedFeature(value_key='user_answer', dtype=tf.int64),
    'answered_correctly': tf.io.RaggedFeature(value_key='answered_correctly', dtype=tf.int64),
    'prior_question_elapsed_time': tf.io.RaggedFeature(value_key='prior_question_elapsed_time', dtype=tf.float32),
    'prior_question_had_explanation': tf.io.RaggedFeature(value_key='prior_question_had_explanation', dtype=tf.int64),
    'n_valid_blocks': tf.io.FixedLenFeature([], dtype=tf.int64),
    'valid_blocks_start_pos': tf.io.RaggedFeature(value_key='valid_blocks_start_pos', dtype=tf.int64),
    'valid_blocks_end_pos': tf.io.RaggedFeature(value_key='valid_blocks_end_pos', dtype=tf.int64),
}


def parse_train_example_with_valid_info(example):

    _parsed = tf.io.parse_single_example(example, train_features_with_valid_info)
    
    parsed = {}
    parsed['user_id'] = _parsed['user_id']
    parsed['seq_len'] = tf.reduce_sum(tf.ones_like(_parsed['row_id'], dtype=tf.int32))
    parsed['prev_seq_len'] = parsed['seq_len']
    parsed['start'] = tf.constant(0, dtype=tf.int32)
    parsed['end'] = parsed['seq_len'] - 1
    
    for k in train_features_with_valid_info:
        
        data = _parsed[k]
        if k not in ['row_id', 'user_id', 'timestamp', 'prior_question_elapsed_time']:
            data = tf.cast(data, dtype=tf.int32)
        elif k == 'prior_question_elapsed_time':
            data = tf.cast(data, dtype=DTYPE)
        if k != 'user_id':
            parsed[k] = data

    # We need to use `START_TOKEN` rather than `PAD_TOKEN` here.
    parsed['shifted_answered_correctly'] = tf.concat([[START_TOKEN], parsed['answered_correctly']], axis=0)           

    # This should be all `0`.
    pad_mask = tf.cast(parsed['row_id'] == PAD_TOKEN, dtype=tf.int32)

    # The tfrecord dataset, in the raw format, contains training examples with validation information.
    # At this step, no place is considered to be in the prediction time yet.
    # This information will be updated in a further dataset transformation.
    pred_time_mask = tf.zeros_like(parsed['timestamp'], dtype=tf.int32)
    
    # This should be all `0`.
    pred_time_mask = pred_time_mask * (1 - pad_mask) + (PAD_TOKEN) * pad_mask
    
    parsed['pred_time_mask'] = pred_time_mask

    parsed['abs_pos'] = tf.range(parsed['seq_len'], dtype=tf.int32)
    parsed['shifted_abs_pos'] = tf.concat([[START_TOKEN], parsed['abs_pos'][:-1]], axis=0)
    
    return parsed

#### check

In [None]:
train_raw_ds = tf.data.TFRecordDataset(train_tfrec_paths, num_parallel_reads=1)
train_raw_ds = train_raw_ds.map(parse_train_example, num_parallel_calls=1, deterministic=True)
for x in train_raw_ds.take(1):
    print(x)

In [None]:
valid_raw_ds = tf.data.TFRecordDataset(valid_tfrec_paths, num_parallel_reads=1)
valid_raw_ds = valid_raw_ds.map(parse_train_example_with_valid_info, num_parallel_calls=1, deterministic=True)
for x in valid_raw_ds.take(1):
    print(x)

### Split the dataset into training / validation parts<a id='split'></a>

This is used to remove the interactions used for validation from a user's full interaction history.
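
The lookup-and-truncate logic of this split can be previewed with a plain-Python analogue. The sketch below is a simplification under assumptions: `split_history` and the example user ids are hypothetical, a dictionary stands in for the `tf.lookup.StaticHashTable`, and `NOT_FOUND` plays the role of the table's default value `tf.int32.limits[1]`.

```python
NOT_FOUND = 2**31 - 1  # largest int32 value, used as the "user not in the table" marker


def split_history(user_id, history, split_index_by_user):
    """Keep only the training part of a user's history, i.e. the interactions before the split index."""
    split_index = split_index_by_user.get(user_id, NOT_FOUND)
    # A user that is absent from the table is not used for validation: keep the full history.
    if split_index == NOT_FOUND:
        split_index = len(history)
    return history[:split_index]


# Hypothetical table: the validation part of user 115 starts at index 3.
split_index_by_user = {115: 3}

print(split_history(115, [10, 11, 12, 13, 14], split_index_by_user))
print(split_history(124, [20, 21], split_index_by_user))
```

Using the largest `int32` value as the default keeps the lookup result a valid integer while still being distinguishable from any realistic split index.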

In [None]:
def convert_split_index_dict(split_index_dict):

    user_ids = []
    split_indices = []
    for k, v in split_index_dict.items():
        user_ids.append(k)
        split_indices.append(v)

    user_id_tensor = tf.constant(user_ids, dtype=tf.int64)
    split_index_tensor = tf.constant(split_indices, dtype=tf.int32)

    initializer = tf.lookup.KeyValueTensorInitializer(user_id_tensor, split_index_tensor)

    split_index_table = tf.lookup.StaticHashTable(
        initializer, default_value=tf.int32.limits[1], name=None
    )

    return split_index_table


def split_train_example(raw_example, split_index_table):
    """Split an original train example into the actual training part and the validation part, and only return the training part.
    """
    
    user_id = raw_example['user_id']
    seq_len = raw_example['seq_len']
    
    split_index = split_index_table.lookup(user_id)
    tf.debugging.Assert(tf.reduce_all(split_index >= 0), [split_index])
    
    example = {}
    
    # `user_id` is not in `split_index_table` - not used for validation.
    # `split_index` becomes `seq_len` - i.e. all the interactions belong to training.
    if split_index == tf.int32.limits[1]:
        split_index = seq_len
    
    example['user_id'] = raw_example['user_id']
    example['seq_len'] = split_index
    example['prev_seq_len'] = raw_example['seq_len']
    example['start'] = tf.constant(0, dtype=tf.int32)
    example['end'] = split_index - 1    
        
    for k in raw_example:
        if k not in ['user_id', 'seq_len', 'prev_seq_len', 'start', 'end']:
            example[k] = raw_example[k][0:split_index]
    
    return example


def split_train_ds(train_raw_ds, split_indices, num_parallel_calls=None, deterministic=None):
    
    reduced_raw_ds = train_raw_ds.map(lambda example: split_train_example(example, split_indices), num_parallel_calls=num_parallel_calls, deterministic=deterministic)
    
    return reduced_raw_ds

### Training dataset transformation - from tf.RaggedTensor to tf.Tensor<a id='transformation'></a>

These are helper functions used for sampling random subsequences (of a fixed length) of user interactions (in the split training dataset) for training.
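
To make the windowing concrete, here is a minimal NumPy sketch for a single sequence. It is only a stand-in for the batched `tf.RaggedTensor` code in the next cell: the function name `window_ending_at` and the padding value `-2` are assumptions of this sketch, not names from the notebook.

```python
import numpy as np

PAD = -2  # assumed padding value, for this sketch only


def window_ending_at(seq, ending_index, window_size, pad_value=PAD):
    """Return the `window_size` elements of `seq` that end at `ending_index` (inclusive).

    Positions that fall before the start of `seq` are filled with `pad_value`,
    so the result always has length `window_size`.
    """
    seq = np.asarray(seq)
    positions = np.arange(ending_index - window_size + 1, ending_index + 1)
    out = np.full(window_size, pad_value, dtype=seq.dtype)
    valid = positions >= 0
    out[valid] = seq[positions[valid]]
    return out


seq = [5, 6, 7, 8, 9]
full_window = window_ending_at(seq, ending_index=4, window_size=3)
padded_window = window_ending_at(seq, ending_index=1, window_size=4)
all_pad_window = window_ending_at(seq, ending_index=-1, window_size=2)
print(full_window.tolist())     # [7, 8, 9]
print(padded_window.tolist())   # [-2, -2, 5, 6]
print(all_pad_window.tolist())  # [-2, -2]
```

An ending index of `-1` gives a window made only of padding, which matches the `-1 <= ending_indices[i]` condition allowed by `extract_subseqs`.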

In [None]:
def extract_subseqs(seqs, ending_indices, window_size, seq_len):

    """Let `seqs` be a `tf.RaggedTensor` with rank >= 2, where the 1st dimension is the batch dimension
       and the 2nd dimension is the temporal dimension, which is the unique ragged dimension. Let
       `ending_indices` be a 1-D tensor with the same batch dimension as `seqs`. The condition
       `-1 <= ending_indices[i] < len(seqs[i])` must hold.

       For each example `seqs[i]`, we extract the partial sequence
       `seqs[i][ending_indices[i] - window_size + 1 : ending_indices[i] + 1]`. Positions that fall before
       the start of `seqs[i]` get `PAD_TOKEN` as values.
    
    Args:
        seqs: A `tf.RaggedTensor` tensor with rank >= 2. The 1st dim is the batch dimension,
            and the 2nd dim is the temporal dimension.  The temporal dimension is the ragged rank,
            i.e. the unique dimension which is ragged.
        
        ending_indices: A 1-D `tf.int32` tensor with shape = [seqs.shape[0]]
        
        window_size: A scalar `tf.int32` tensor, which should be positive.

        seq_len: A 1-D integer tensor with shape = [seqs.shape[0]], the length of each sequence.
    """
    
    orig_window_size = window_size

    window_size += 1

    tf.debugging.Assert(tf.rank(seqs) >= 2, [tf.rank(seqs)])
        
    tf.debugging.Assert(tf.reduce_all(window_size > 0), [window_size])
    tf.debugging.Assert(tf.reduce_all(ending_indices >= -1), [ending_indices])
    
    batch_size = tf.reduce_sum(tf.ones_like(seq_len, dtype=tf.int32))
                        
    tf.debugging.Assert(tf.reduce_all((ending_indices < seq_len)), [ending_indices, seq_len])
    
    # shape = [batch_size, 1]
    _ending_indices = ending_indices[:, tf.newaxis]
    
    # shift by 1 because one padding element is prepended to each sequence below
    _ending_indices += 1

    # shape = [1, window_size]
    ranges = tf.range(window_size)[tf.newaxis, :]
    
    # shape = [batch_size, window_size]
    indices = _ending_indices - ranges
    
    # reverse the sequence dimension to get the correct temporal direction
    # shape = [batch_size, window_size]
    indices = tf.reverse(indices, axis=[1])
        
    # shape = [batch_size, window_size, 2]
    indices_to_seqs = tf.stack([tf.broadcast_to(tf.range(batch_size)[:, tf.newaxis], shape=[batch_size, window_size]), indices], axis=2)
    
    # Change negative indices to 0
    # shape = [batch_size, window_size, 2]
    indices_to_ext_seqs = tf.math.maximum(indices_to_seqs, 0)
    
    # Need to rework
    invalid_values = tf.cast(tf.ones(shape=seqs.shape[2:]) * PAD_TOKEN, seqs.dtype)
        
    # shape = [batch_size, 1, ...]
    invalid_values = tf.repeat([[invalid_values]], repeats=batch_size, axis=0)
        
    # Same as `seqs`, but each sequence has 1 more element (the padding value) inserted at the beginning
    extended_seqs = tf.concat([invalid_values, seqs], axis=1)
            
    # selected
    # shape = [batch_size, window_size, ...]
    sampled_subseqs = tf.gather_nd(extended_seqs, indices=indices_to_ext_seqs)

    # the extra leading element is the one right before the window
    prev_last_element = sampled_subseqs[:, 0]

    # take the last `orig_window_size` part
    sampled_subseqs = sampled_subseqs[:, 1:]
    
    return sampled_subseqs, prev_last_element


def random_ending_indices(seq_len, seed=None):

    batch_size = tf.reduce_sum(tf.ones_like(seq_len))
        
    max_seq_len = tf.cast(tf.math.reduce_max(seq_len), dtype=DTYPE)
        
    # shape = [batch_size]
    ending_indices = -1.0 + tf.random.uniform(minval=0.0, maxval=1.0, dtype=DTYPE, shape=[batch_size], seed=seed) * tf.cast(seq_len + 1, dtype=DTYPE)
    ending_indices = tf.cast(tf.math.floor(ending_indices), dtype=tf.int32)    

    # sanity check
    tf.debugging.Assert(tf.math.reduce_min(seq_len - ending_indices) >= 1, [seq_len, ending_indices])   
    tf.debugging.Assert(tf.math.reduce_min(ending_indices) >= -1, [seq_len, ending_indices])   
    
    return ending_indices
    
    
def random_subseqs(seqs, window_size, seq_len, seed=None):
    """Let `seqs` be a `tf.RaggedTensor` with rank >= 2, where the 1st dimension is the batch dimension
       and the 2nd dimension is the temporal dimension, which is the unique ragged dimension.
       
       For each example `seqs[i]`, we extract a random partial sequence of length <= window_size
    
    Args:
        seqs: A `tf.RaggedTensor` tensor with rank = 2. The 1st dim is the batch dimension,
            and the 2nd dim is the temporal dimension.  The temporal dimension is the ragged rank,
            i.e. the unique dimension which is ragged.
        
        window_size: A scalar `tf.int32` tensor, which should be positive.
    """

    ending_indices = random_ending_indices(seq_len=seq_len, seed=seed)
    sampled_subseqs, _ = extract_subseqs(seqs, ending_indices, window_size, seq_len=seq_len)
    
    return sampled_subseqs, ending_indices

# --------------------------------------------------------------------------------
# For train / validation

def extract_subseqs_from_raw_batch(raw_batch, ending_indices, window_size):
    """Extract subsequences from a training batch (containing no validation part anymore) consisting of `tf.RaggedTensor` objects.
    The subsequences will have the same length `window_size` by padding from the beginning with `PAD_TOKEN`.
    """    

    tf.debugging.Assert(tf.reduce_all(raw_batch['start'] == tf.zeros_like(raw_batch['start'], dtype=tf.int32)), [raw_batch['start']])
    
    batch = {}
    batch['user_id'] = raw_batch['user_id']

    start = tf.math.maximum(ending_indices - window_size + 1, 0)
    end = ending_indices
    seq_len = (end - start) + 1
    
    batch['seq_len'] = seq_len
    batch['prev_seq_len'] = raw_batch['seq_len']    
    batch['start'] = start
    batch['end'] = end
    
    for k in raw_batch:
        if k not in ['user_id', 'seq_len', 'prev_seq_len', 'start', 'end'] + ['n_valid_blocks', 'valid_blocks_start_pos', 'valid_blocks_end_pos', 'valid_block_pos', 'valid_block_idx', 'valid_start', 'valid_end']:
            
            subseqs, prev_last_element = extract_subseqs(seqs=raw_batch[k], ending_indices=ending_indices, window_size=window_size, seq_len=raw_batch['prev_seq_len'])
        
            batch[k] = subseqs

        elif k in ['n_valid_blocks', 'valid_block_idx', 'valid_start', 'valid_end']:
            batch[k] = raw_batch[k]

    return batch

# --------------------------------------------------------------------------------

def random_subseqs_from_raw_batch(raw_batch, window_size, only_last, seed=None):
        
    if only_last:
        ending_indices = raw_batch['seq_len'] - 1
    else:
        ending_indices = random_ending_indices(raw_batch['seq_len'], seed=seed)
        ending_indices = tf.math.maximum(ending_indices, window_size - 1)
        _mask = tf.cast(ending_indices >= raw_batch['seq_len'], tf.int32)
        ending_indices = ending_indices * (1 - _mask) + (raw_batch['seq_len'] - 1) * _mask

    batch = extract_subseqs_from_raw_batch(raw_batch, ending_indices, window_size)
    
    return batch, ending_indices


def random_subseqs_from_batched_raw_ds(batched_raw_ds, window_size, only_last=False, seed=None, num_parallel_calls=None, deterministic=None):
    
    batched_ds = batched_raw_ds.map(lambda raw_batch: random_subseqs_from_raw_batch(raw_batch, window_size, only_last, seed), num_parallel_calls=num_parallel_calls, deterministic=deterministic)
    
    return batched_ds

## Convert the dataset to model inputs<a id='inputs'></a>

#### helper functions

In [None]:
def convert_valid_info(valid_info):

    data = {}

    user_ids = list(valid_info.keys())

    for user_id in user_ids:

        user_valid_info = valid_info[user_id]

        user_valid_info['bundle_info']['block_starting_index_dict'] = {int(k): v for k, v in user_valid_info['bundle_info']['block_starting_index_dict'].items()}
        user_valid_info['bundle_info']['bundle_starting_index_dict'] = {int(k): v for k, v in user_valid_info['bundle_info']['bundle_starting_index_dict'].items()}
        data[int(user_id)] = user_valid_info

        del valid_info[user_id]

    return data

def load_data(path):

    if path.endswith('.json'):
        
        with open(path, 'r', encoding='UTF-8') as fp:
            data = json.load(fp)

        if path.endswith('valid_info_fold_1.json'):
            data = convert_valid_info(data)

        elif path.endswith('n_contents_dict.json'):
            data = {int(k): v for k, v in data.items()}
            
        elif path.endswith('train_valid_split_indices_fold_1.json'):            
            data = {int(k): v for k, v in data.items()}

        elif path.endswith('question_tags_info.json') or path.endswith('lecture_tags_info.json'):            
            data = {int(k): v for k, v in data.items()}            

        elif path.endswith('question_part_info.json') or path.endswith('lecture_part_info.json'):            
            data = {int(k): v for k, v in data.items()} 

        return data

### Build the tables used for mapping ids in the dataset (consider them as tokens) to model input ids<a id='token-id-mapping'></a>

We use `tf.lookup.StaticHashTable` to build the mappings (for questions / lectures) instead of a Python dictionary, so we can perform the transformation with `tf.data.Dataset`.

For tags and parts respectively, we just use a single tensor to store the information, and use `tf.gather` to extract the information.
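
The id layout can be sketched in plain Python as follows. The special tokens come first, then the questions, then the lectures, and the part information is stored as one row per input id. All ids and part values here are made up, and Python dicts / lists stand in for the `tf.lookup.StaticHashTable` and `tf.gather` used in the next cell.

```python
special_vocab = ["[PAD]", "[START]", "[END]", "[MASK]"]  # special ids come first
question_ids = [0, 1, 2]   # hypothetical question ids
lecture_ids = [89, 100]    # hypothetical lecture ids

n_special = len(special_vocab)

# Questions are placed right after the special tokens.
question_to_input_id = {
    q: n_special + i for i, q in enumerate(sorted(question_ids))
}
# Lectures are placed after the questions.
lecture_to_input_id = {
    lec: n_special + len(question_ids) + i for i, lec in enumerate(sorted(lecture_ids))
}

# One entry per input id; special tokens get the placeholder part -1.
part_by_input_id = [-1] * n_special + [1, 1, 5] + [2, 7]  # made-up parts

input_id = lecture_to_input_id[100]
print(question_to_input_id)        # {0: 4, 1: 5, 2: 6}
print(lecture_to_input_id)         # {89: 7, 100: 8}
print(part_by_input_id[input_id])  # 7
```

Because question ids and lecture ids can overlap, two separate mappings are needed; the content type decides which one is used.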

In [None]:
questioin_df = pd.read_csv(f'{BASE_DIR}/riiid-test-answer-prediction/questions.csv')
lecture_df = pd.read_csv(f'{BASE_DIR}/riiid-test-answer-prediction/lectures.csv')

question_ids = questioin_df['question_id'].tolist()
lecture_ids = lecture_df['lecture_id'].tolist()

question_ids = sorted(question_ids)
lecture_ids = sorted(lecture_ids)

assert question_ids == sorted(question_ids)
assert lecture_ids == sorted(lecture_ids)

special_vocab = [PAD_TOKEN, START_TOKEN, END_TOKEN, MASK_TOKEN]
question_vocab = question_ids
lecture_vocab = lecture_ids

content_vocab = special_vocab + question_vocab + lecture_vocab

# Map question ids to encoder's input ids
input_ids = tf.range(len(special_vocab), len(special_vocab) + len(question_vocab))
initializer = tf.lookup.KeyValueTensorInitializer(tf.constant(question_ids, dtype=tf.int32), input_ids)
question_id_to_input_id_table = tf.lookup.StaticHashTable(
    initializer, default_value=tf.int32.limits[1], name=None
)

# Map lecture ids to encoder's input ids
input_ids = tf.range(len(special_vocab) + len(question_vocab), len(content_vocab))
initializer = tf.lookup.KeyValueTensorInitializer(tf.constant(lecture_ids, dtype=tf.int32), input_ids)
lecture_id_to_input_id_table = tf.lookup.StaticHashTable(
    initializer, default_value=tf.int32.limits[1], name=None
)

response_vocab = [PAD_TOKEN, START_TOKEN, END_TOKEN, MASK_TOKEN, RESPONSE_LECTURE_TOKEN, RESPONSE_TRUE_ID, RESPONSE_FALSE_ID]

CONTENT_VOCAB_SIZE = len(content_vocab)
RESPONSE_VOCAB_SIZE = len(response_vocab)

question_tags_info = load_data(question_tags_info_path)
lecture_tags_info = load_data(lecture_tags_info_path)

question_part_info = load_data(question_part_info_path)
lecture_part_info = load_data(lecture_part_info_path)

tags_database = []
part_database = []

for idx in range(len(content_vocab)):
    if idx < len(special_vocab):
        tags_database.append([-1] * N_TAGS_PER_CONTENT)
        part_database.append(-1)
    elif idx < len(special_vocab) + len(question_vocab):
        index_in_questions_ids = idx - len(special_vocab)
        question_id = question_ids[index_in_questions_ids]
        tags_database.append(question_tags_info[question_id])
        part_database.append(question_part_info[question_id])
    elif idx < len(content_vocab):
        index_in_lecture_ids = idx - (len(special_vocab) + len(question_vocab))
        lecture_id = lecture_ids[index_in_lecture_ids]
        tags_database.append(lecture_tags_info[lecture_id])
        part_database.append(lecture_part_info[lecture_id])

# The tags and part here are not the ids for model inputs, but the ids defined in the questions.csv / lectures.csv
c_inputs_ids_to_tags = tf.constant(tags_database, dtype=tf.int32)
c_inputs_ids_to_part = tf.constant(part_database, dtype=tf.int32)

In [None]:
print(f'CONTENT_VOCAB_SIZE: {CONTENT_VOCAB_SIZE}')
print(f'RESPONSE_VOCAB_SIZE: {RESPONSE_VOCAB_SIZE}')

### Prepare encoder / decoder input ids and attention masks

### special masks<a id='special-masks'></a>

We provide the `causal attention mask` and the `attention mask that can attend to the current and previous timestamps` - remember that we can have a bundle of questions that share the same timestamp.
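
The difference between the two masks is easiest to see on a small example. The following NumPy sketch handles a single sequence (no batch dimension) and assumes integer timestamps; the function names `causal_mask` and `timestamp_mask` are chosen for this sketch and are not the notebook's TensorFlow functions.

```python
import numpy as np


def causal_mask(n, only_before=False):
    """mask[i, j] = 1 if position i may attend to position j (j <= i, or j < i)."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (i >= j + int(only_before)).astype(np.int32)


def timestamp_mask(timestamps, only_before=False):
    """mask[i, j] = 1 if timestamps[j] <= timestamps[i] (or strictly earlier)."""
    t = np.asarray(timestamps)
    return (t[:, None] >= t[None, :] + int(only_before)).astype(np.int32)


# Positions 1 and 2 share timestamp 5, i.e. they belong to the same bundle.
timestamps = [0, 5, 5, 9]

print(causal_mask(4).tolist())
# [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
print(timestamp_mask(timestamps).tolist())
# [[1, 0, 0, 0], [1, 1, 1, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
print(timestamp_mask(timestamps, only_before=True).tolist())
# [[0, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 1, 1, 0]]
```

With the timestamp mask, position 1 can see position 2 because both are in the same bundle; with `only_before=True`, neither can see the other, nor itself.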

In [None]:
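def get_causal_attention_mask(nd, ns, dtype, only_before):
    """
    1's in the lower triangle, counting from the lower right corner. Same as tf.linalg.band_part(tf.ones([nd, ns]),
    -1, ns-nd), but doesn't produce garbage on TPUs. If `only_before` is true, each query can only attend to
    strictly earlier positions, i.e. not to its own position.
    """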
    
    # Remark: Think `nd` as the number of queries and `ns` as the number of keys.
    # In encoder-decoder case, the queries are the decoder features and the keys are the encoder features.
    
    i = tf.range(nd)[:, tf.newaxis]  # repeat along dim 1
    j = tf.range(ns) # repeat along dim 0 
    m = i >= (j - ns + nd) + tf.cast(only_before, dtype=tf.int32)
    
    return tf.cast(m, dtype)

def get_attention_mask_from_timestamp_batch(timestamp_tensors, dtype, only_before):
    """
    Args:
        timestamp_tensors: 2-D tf.int64 tensor, representing a batch of sequences of non-decreasing timestamps.
        dtype: The dtype of the returned mask.
        only_before: A scalar boolean tensor. If true, a position can only attend to strictly earlier timestamps.
    
    Returns:
        attention_mask: 3-D tensor of type `dtype` and shape = [batch_size, query_len, key_len], consisting of 0 and 1.
            Here `query_len` and `key_len` are actually `seq_len`. It should be reshaped, when used to calculate 
            attention scores, to [batch_size, nb_attn_head, query_len, key_len].
    """
    
    t = timestamp_tensors
    
    batch_size = tf.math.reduce_sum(tf.ones_like(t[:, :1], dtype=tf.int32))
    seq_len = tf.math.reduce_sum(tf.ones_like(t[:1, :], dtype=tf.int32))
    
    x = tf.broadcast_to(t[:, :, tf.newaxis], shape=[batch_size, seq_len, seq_len]) # repeat along dim 2
    y = tf.broadcast_to(t[:, tf.newaxis, :], shape=[batch_size, seq_len, seq_len]) + tf.cast(only_before, dtype=tf.int64) # repeat along dim 1
    
    m =  x >= y
    
    return tf.cast(m, dtype)

##### check

In [None]:
print(get_causal_attention_mask(3, 3, tf.int32, only_before=tf.constant(False)))
print(get_causal_attention_mask(3, 3, tf.int32, only_before=tf.constant(True)))

timestamp_tensors = tf.constant([[0, 0, 1, 2, 2, 2, 3, 3], [0, 1, 1, 2, 2, 3, 3, 3]], dtype=tf.int64)

print(get_attention_mask_from_timestamp_batch(timestamp_tensors, tf.int32, only_before=tf.constant(False)))
print(get_attention_mask_from_timestamp_batch(timestamp_tensors, tf.int32, only_before=tf.constant(True)))

### Tensors required for model (input ids, targets, attention masks)<a id='input-tensors'></a>

To be continued ...

In [None]:
@tf.function
def get_attention_masks(batch, training, encoder_decoder, generative, allow_bundle_atten):

    pad_mask = tf.cast((batch['content_type_id'] == PAD_TOKEN), dtype=tf.int32)
    question_mask = tf.cast((batch['content_type_id'] == 0), dtype=tf.int32)

    # Replace `PAD_TOKEN` by `0`.
    pred_time_mask = batch['pred_time_mask'] * tf.cast(batch['pred_time_mask'] != PAD_TOKEN, dtype=tf.int32)

    content_ids = batch['content_id']

    # sanity check
    _pad_mask = tf.cast((batch['timestamp'] == PAD_TOKEN), dtype=tf.int32)
    tf.debugging.Assert(tf.reduce_all(pad_mask == _pad_mask), [pad_mask, _pad_mask])    

    # ----------------------------------------
    # General mask
    # `seq_len` below is actually the `window_size`.
    
    seq_len = tf.math.reduce_sum(tf.ones_like(content_ids[0, :], dtype=tf.int32))
    
    # Don't pay attention to [PAD]
    # shape = [batch_size, seq_len]
    non_pad_mask = 1 - pad_mask
    
    # Can only attend to the current and previous positions
    # shape = [seq_len, seq_len]
    causal_attention_mask = get_causal_attention_mask(nd=seq_len, ns=seq_len, dtype=tf.int32, only_before=tf.constant(False))
   
    # Can only attend to the previous positions
    # shape = [seq_len, seq_len]
    causal_attention_mask_only_before = get_causal_attention_mask(nd=seq_len, ns=seq_len, dtype=tf.int32, only_before=tf.constant(True))

    # Can only attend to the current and previous timestamps
    # shape = [batch_size, seq_len, seq_len]
    timestamp_attention_mask = get_attention_mask_from_timestamp_batch(batch['timestamp'], dtype=tf.int32, only_before=tf.constant(False))

    # Can only attend to previous timestamps
    # shape = [batch_size, seq_len, seq_len]
    timestamp_attention_mask_only_before = get_attention_mask_from_timestamp_batch(batch['timestamp'], dtype=tf.int32, only_before=tf.constant(True))
        
    # ----------------------------------------
    # content self attention mask
    # shape = [batch_size, seq_len, seq_len]

    if tf.cast(allow_bundle_atten, dtype=tf.bool):
        # can see the contents in the current and previous bundle, not including [PAD]
        c_mask = non_pad_mask[:, tf.newaxis, :] * timestamp_attention_mask
    else:
        # can see the current and previous contents, not including [PAD]
        c_mask = non_pad_mask[:, tf.newaxis, :] * causal_attention_mask[tf.newaxis, :, :]
    
    # ----------------------------------------
    # response self attention mask
    # shape = [batch_size, seq_len, seq_len]    
    
    # for `encoder-only` model: same as `c_mask`
    r_mask = c_mask

    if tf.cast(encoder_decoder, dtype=tf.bool):
        
        # for `encoder-decoder` model: should be causal masking
        d_mask = non_pad_mask[:, tf.newaxis, :] * causal_attention_mask[tf.newaxis, :, :]

        # if neither `training` nor `generative`
        if (1 - training) * (1 - generative) == 1:

            pred_time_question_mask = pred_time_mask * question_mask
            
            d_mask = d_mask * tf.math.maximum((1 - pred_time_question_mask)[:, tf.newaxis, :], tf.eye(seq_len, dtype=tf.int32)[tf.newaxis, :, :])
            r_mask = d_mask
            c_mask = c_mask * tf.math.maximum((1 - pred_time_question_mask)[:, tf.newaxis, :], tf.eye(seq_len, dtype=tf.int32)[tf.newaxis, :, :])

        else:
            r_mask = d_mask

    # ----------------------------------------
    # response-to-content attention mask
    # shape = [batch_size, seq_len, seq_len]
    
    r_c_mask = c_mask

    # ----------------------------------------
    # Used for encoder-only models.
    # Can see only the responses in the previous bundles, not including [PAD]

    # shape = [batch_size, seq_len, seq_len]
    c_r_mask = non_pad_mask[:, tf.newaxis, :] * timestamp_attention_mask_only_before
    
    return c_mask, r_mask, r_c_mask, c_r_mask


def add_input_ids_and_targets(batch, training, generative, use_abs_pos):
    """Add input ids and targets for training
    """
        
    content_ids = batch['content_id']
    content_type_ids = batch['content_type_id']
    answered_correctly = batch['answered_correctly']
    
    question_mask = tf.cast((content_type_ids == 0), dtype=tf.int32)
    lecture_mask = tf.cast((content_type_ids == 1), dtype=tf.int32)
    
    pad_mask = tf.cast((content_type_ids == PAD_TOKEN), dtype=tf.int32)
    _pad_mask = tf.cast((batch['timestamp'] == PAD_TOKEN), dtype=tf.int32)
    tf.debugging.Assert(tf.reduce_all(pad_mask == _pad_mask), [pad_mask, _pad_mask])
    
    # Replace `PAD_TOKEN` by `0`.
    pred_time_mask = batch['pred_time_mask'] * tf.cast(batch['pred_time_mask'] != PAD_TOKEN, dtype=tf.int32)
    pred_time_question_mask = pred_time_mask * question_mask
    
    # The number of questions in prediction time
    # shape = [batch_size]
    n_questions_in_pred_time = tf.math.reduce_sum(pred_time_question_mask, axis=1)
    # sanity check
    tf.debugging.Assert(tf.reduce_all(n_questions_in_pred_time >= 0), [n_questions_in_pred_time])      

    response_false_mask = tf.cast((answered_correctly == RESPONSE_FALSE_TOKEN), dtype=tf.int32)
    response_true_mask = tf.cast((answered_correctly == RESPONSE_TRUE_TOKEN), dtype=tf.int32)
    
    # ----------------------------------------

    _batch = {}
    for k in batch:
        _batch[k] = batch[k]

    if use_abs_pos:
        # Shifted by one, `PAD_TOKEN` --> `0`
        _batch['pos_ids'] = tf.math.maximum(0, _batch['abs_pos'] + 1)

        # Shifted by one, `PAD_TOKEN` --> `0`, `START_TOKEN` --> `0`.
        _batch['shifted_pos_ids'] = tf.math.maximum(0, _batch['shifted_abs_pos'] - 1)

    else:

        _batch['pos_ids'] = tf.math.cumsum(1 - pad_mask, axis=1)
        _batch['shifted_pos_ids'] = tf.math.maximum(0, _batch['pos_ids'] - 1)

    # ----------------------------------------
    # content_input_ids        
    
    content_input_ids = question_mask * question_id_to_input_id_table.lookup(content_ids) + \
        lecture_mask * lecture_id_to_input_id_table.lookup(content_ids) + \
        pad_mask * PAD_ID
    
    _batch['c_input_ids'] = content_input_ids
    
    # ----------------------------------------
    # response_input_ids - only used for encoder-only models
        
    # The masks in the sum below must be mutually exclusive, so that each position gets exactly one id.
    # Reason: if two masks overlapped, we would get a sum such as `MASK_ID + RESPONSE_LECTURE_ID`, which gives an OOV error for embedding.
    # `pred_time_question_mask` is already restricted to questions, so `(1 - pred_time_question_mask) * lecture_mask` equals `lecture_mask`:
    # lectures that occur during the prediction time keep `RESPONSE_LECTURE_ID`.
    
    response_input_ids_masked = pad_mask * PAD_ID + pred_time_question_mask * MASK_ID + lecture_mask * RESPONSE_LECTURE_ID + (1 - pred_time_question_mask) * response_false_mask * RESPONSE_FALSE_ID + (1 - pred_time_question_mask) * response_true_mask * RESPONSE_TRUE_ID
    _batch['r_input_ids'] = response_input_ids_masked

    # ----------------------------------------
    # d_input_ids

    shifted_answered_correctly = batch['shifted_answered_correctly']

    shifted_pad_mask = tf.cast((shifted_answered_correctly == PAD_TOKEN), dtype=tf.int32)
    shifted_start_mask = tf.cast((shifted_answered_correctly == START_TOKEN), dtype=tf.int32)

    shifted_masking_mask = tf.cast((shifted_answered_correctly == MASK_TOKEN), dtype=tf.int32)
    shifted_response_lecture_mask = tf.cast((shifted_answered_correctly == RESPONSE_LECTURE_TOKEN), dtype=tf.int32)
    shifted_response_false_mask = tf.cast((shifted_answered_correctly == RESPONSE_FALSE_TOKEN), dtype=tf.int32)
    shifted_response_true_mask = tf.cast((shifted_answered_correctly == RESPONSE_TRUE_TOKEN), dtype=tf.int32)

    decoder_input_ids = shifted_pad_mask * PAD_ID + shifted_start_mask * START_ID + shifted_masking_mask * MASK_ID + shifted_response_lecture_mask * RESPONSE_LECTURE_ID + shifted_response_false_mask * RESPONSE_FALSE_ID + shifted_response_true_mask * RESPONSE_TRUE_ID

    if not generative:

        pred_time_question_start_mask = tf.cast(tf.math.cumsum(pred_time_question_mask, axis=1) == 1, dtype=tf.int32)
        # shape = [batch_size]
        pred_time_question_start_value = tf.math.reduce_sum(pred_time_question_start_mask * decoder_input_ids, axis=1)

        # If at the starting of questions, we get `MASK_ID`, we change it to `RESPONSE_LECTURE_ID`
        prev_lecture_mask = tf.cast(pred_time_question_start_value == MASK_ID, tf.int32)            
        pred_time_question_start_value = RESPONSE_LECTURE_ID * prev_lecture_mask + pred_time_question_start_value * (1 - prev_lecture_mask) 

        # All places in prediction time share the values at the prediction question starting places.
        decoder_input_ids = decoder_input_ids * (1 - pred_time_question_mask) + pred_time_question_mask * pred_time_question_start_value[:, tf.newaxis]

        # If there are remaining MASK_ID, its previous place is in prediction time, and we are sure it is a lecture at that moment.
        decoder_input_ids = decoder_input_ids * tf.cast(decoder_input_ids != MASK_ID, dtype=tf.int32) + RESPONSE_LECTURE_ID * tf.cast(decoder_input_ids == MASK_ID, dtype=tf.int32)

    _batch['d_input_ids'] = decoder_input_ids

    # ----------------------------------------
    # post processing `pos_ids` and `shifted_pos_ids`
    # Once the real decoder is implemented, we need to fix this.

    if not generative:

        pos_ids = _batch['pos_ids']
        shifted_pos_ids = _batch['shifted_pos_ids']

        pos_ids_question_start_value = tf.math.reduce_sum(pred_time_question_start_mask * pos_ids, axis=1)
        shifted_pos_ids_question_start_value = tf.math.reduce_sum(pred_time_question_start_mask * shifted_pos_ids, axis=1)

        pos_ids = pos_ids * (1 - pred_time_question_mask) + pred_time_question_mask * pos_ids_question_start_value[:, tf.newaxis]
        shifted_pos_ids = shifted_pos_ids * (1 - pred_time_question_mask) + pred_time_question_mask * shifted_pos_ids_question_start_value[:, tf.newaxis]

        _batch['pos_ids'] = pos_ids
        _batch['shifted_pos_ids'] = shifted_pos_ids

    # ----------------------------------------
    # tags

    _batch['tag_ids'] = tf.gather(params=c_inputs_ids_to_tags, indices=_batch['c_input_ids']) + 1 
        
    # ----------------------------------------        
    # part

    _batch['part_ids'] = tf.gather(params=c_inputs_ids_to_part, indices=_batch['c_input_ids']) + 1

    # ----------------------------------------
    # prior explanation ids - just `prior_question_had_explanation` added by 1.
    # For PAD, we get `-1` but changed to `0`.
    # only used along with `d_input_ids`.

    _batch['prior_explanation_ids'] = tf.math.maximum(0, _batch['prior_question_had_explanation'] + 1)

    # ----------------------------------------
    # `prior_question_elapsed_time` in seconds

    _batch['prior_question_elapsed_time_input'] = tf.cast(_batch['prior_question_elapsed_time'], dtype=DTYPE) / 1000.0

    # ----------------------------------------
    # lag_time --> 


    # ----------------------------------------
    # target
    
    # `-2` means padding, `-1` means lecture
    answer_mask = tf.cast(batch['answered_correctly'] > -1, dtype=tf.int32)
    
    # negative values become `NON_TARGET_ID (-100)`
    _batch['target'] = batch['answered_correctly'] * answer_mask + (NON_TARGET_ID) * (1 - answer_mask)
    
    # ----------------------------------------
    # nb_pred_places

    targets = _batch['target']

    # `targets` are defined for all places (other than [PAD] and lectures).
    # However, during validation, unlike during training, we only focus on the places that are in prediction time (and being questions).
    # This should be used only in `train_step` and `valid_step`, `train` and `valid`, but not in `run_pred`.
    if training == 0:
        pred_time_mask = _batch['pred_time_mask']
        targets = targets * pred_time_mask + NON_TARGET_ID * (1 - pred_time_mask)

    pred_mask = targets != NON_TARGET_ID
    nb_pred_places = tf.math.reduce_sum(tf.cast(pred_mask, dtype=tf.int32))

    # shape = [batch_size], but it is a constant
    _batch['nb_pred_places'] = nb_pred_places * tf.ones_like(_batch['user_id'], dtype=tf.int32)

    return _batch


def prepare_training_dataset(batched_ds, generative=False, use_abs_pos=False, num_parallel_calls=None, deterministic=None):
    # Add input ids and targets for training
    
    training = tf.constant(1, dtype=tf.int32)

    return batched_ds.map(lambda batch, _: add_input_ids_and_targets(batch, training, generative, use_abs_pos), num_parallel_calls=num_parallel_calls, deterministic=deterministic)

### For validation dataset<a id='for-valid'></a>

Previously, each example in the validation dataset contains the full interaction history of the corresponding user. However, each example for validation should contain only a single block (see [Load TFRecord files](#load-tfrecord-files) for the definition of `block`). Therefore, we perform the following dataset transformation: for each example

1. repeat it `n_valid_blocks` times
2. for each repeated instance, add an index `n` to indicate that it is the `n-th` block appearing in the validation time
3. use the index calculated in 2. and the information stored in `'valid_blocks_start_pos'` and `'valid_blocks_end_pos'` to get the start / end indices in that user's interaction history.
4. use the end indices obtained in 3. to remove the interactions after the current validation block
5. use the start / end indices to prepare the examples in the same way as a training example
  * In particular, we modify the attribute `pred_time_mask` that indicates the places where interactions are in the validation.
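
The steps above can be sketched in plain Python for a single user. This is only an illustration: the function name `expand_validation_blocks` and the sample data are hypothetical, and the mask before the block is assumed to be all zeros (the notebook's version keeps whatever values were stored there).

```python
def expand_validation_blocks(history, block_starts, block_ends):
    """Return one validation example per validation block.

    Each example keeps the history up to the end of its block and marks
    the block's positions with 1 in `pred_time_mask`.
    """
    examples = []
    for block_idx, (start, end) in enumerate(zip(block_starts, block_ends)):
        examples.append({
            "valid_block_idx": block_idx,
            # drop every interaction after the current validation block
            "history": history[: end + 1],
            # assumed: nothing before the block is in prediction time
            "pred_time_mask": [0] * start + [1] * (end - start + 1),
        })
    return examples


# Hypothetical user with 6 interactions and two validation blocks: [2, 3] and [4, 5].
history = ["a", "b", "c", "d", "e", "f"]
examples = expand_validation_blocks(history, block_starts=[2, 4], block_ends=[3, 5])
for example in examples:
    print(example)
```

Each resulting example then looks like a training example whose last positions are the ones to predict, so the same subsequence extraction can be reused with the ending index set to the last position.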


In [None]:
def add_valid_block_info(valid_raw_example):
    """
        - Add `valid_block_pos`.
        - Repeat `n_valid_blocks` times, each with an index `valid_block_idx`.
        - Transform each repeated instance using `add_valid_block_idx`.
    """
    
    example = {}
    for k in valid_raw_example:
        example[k] = valid_raw_example[k]
    
    valid_block_pos = tf.stack([valid_raw_example['valid_blocks_start_pos'], valid_raw_example['valid_blocks_end_pos']], axis=1)
    example['valid_block_pos'] = valid_block_pos
    
    n_valid_blocks = example['n_valid_blocks']
    
    ds_1 = tf.data.Dataset.from_tensors(example).repeat(tf.cast(n_valid_blocks, dtype=tf.int64))
    ds_2 = tf.data.Dataset.range(tf.cast(n_valid_blocks, dtype=tf.int64), output_type=tf.int32)
    ds = tf.data.Dataset.zip((ds_1, ds_2))
    
    ds = ds.map(lambda ex, valid_block_idx: add_valid_block_idx(ex, valid_block_idx))

    return ds


def add_valid_block_idx(valid_example, valid_block_idx):
    """Add the following information:
    
        - `valid_block_idx`: The index of a block among all the blocks in a user's interaction history that are used for validation.
        - `valid_start`: The starting index of a validation block in a user's (before being split) training interaction history.
        - `valid_end`: The ending index of a validation block in a user's (before being split) training interaction history.
        
    `remove_future_valid_blocks()` is applied afterwards to remove the validation blocks after the current one.
        
    """
    
    valid_example['valid_block_idx'] = valid_block_idx
    valid_example['valid_start'] = valid_example['valid_block_pos'][valid_block_idx][0]
    valid_example['valid_end'] = valid_example['valid_block_pos'][valid_block_idx][1]
    
    return valid_example


def remove_future_valid_blocks(valid_example):
    """Remove the validation blocks in the full interaction history of a user after the current validation block.
    """
        
    example = {}
    
    valid_start = valid_example['valid_start']
    valid_end = valid_example['valid_end']
    
    example['user_id'] = valid_example['user_id']
    example['seq_len'] = valid_end + 1
    example['prev_seq_len'] = valid_example['seq_len']
    example['start'] = tf.constant(0, dtype=tf.int32)
    example['end'] = valid_end
        
    for k in valid_example:
        
        if k not in ['user_id', 'seq_len', 'prev_seq_len', 'start', 'end']:
            
            if k in ['n_valid_blocks', 'valid_blocks_start_pos', 'valid_blocks_end_pos', 'valid_block_pos', 'valid_block_idx', 'valid_start', 'valid_end']:
                # attributes with a single value, or whose values do not need to be truncated
                example[k] = valid_example[k]
            else:
                # sequence attributes: keep only the interactions up to the end of the current validation block
                example[k] = valid_example[k][0:valid_end + 1]
                
            if k == 'pred_time_mask':
                # Update `pred_time_mask` - assign `1` to the current validation places.
                
                n_valid_interactions = valid_end - valid_example['valid_start'] + 1
                example[k] = tf.concat([valid_example[k][:valid_start], tf.ones(shape=[n_valid_interactions], dtype=tf.int32)], axis=0)
    
    return example


def extract_ending_subseqs(raw_batch, window_size):
    
    ending_indices = raw_batch['seq_len'] - 1

    return extract_subseqs_from_raw_batch(raw_batch, ending_indices, window_size)
    

def prepare_validation_dataset(valid_raw_ds, batch_size=3, window_size=5, generative=False, use_abs_pos=False, seed=None, num_parallel_calls=None, deterministic=None):
    
    valid_ds = valid_raw_ds.flat_map(lambda valid_raw_example: add_valid_block_info(valid_raw_example))
    valid_ds = valid_ds.map(lambda example: remove_future_valid_blocks(example))
    
    # should be outside
    # batch examples with attributes having different lengths across examples - tf.RaggedTensor
    batched_valid_ds = valid_ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=batch_size, drop_remainder=(IS_KAGGLE and tpu is not None)))

    # batch - tf.Tensor: Extract fixed-length subsequences from the end of each sequence
    batched_valid_ds = batched_valid_ds.map(lambda raw_batch: extract_ending_subseqs(raw_batch, window_size), num_parallel_calls=num_parallel_calls, deterministic=deterministic)
    
    training = tf.constant(0, dtype=tf.int32)

    valid_ds = batched_valid_ds.map(lambda batch: add_input_ids_and_targets(batch, training, generative, use_abs_pos), num_parallel_calls=num_parallel_calls, deterministic=deterministic)
    
    return valid_ds
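
To make the validation preparation above more concrete, here is a minimal pure-Python sketch (not the TensorFlow pipeline itself) of what `add_valid_block_idx` and `remove_future_valid_blocks` do to a single user's history. The helper name `expand_validation_blocks` and the toy data are made up for illustration, and the sketch assumes the prediction mask is `0` on every position before the current block.

```python
def expand_validation_blocks(history, valid_block_pos):
    """Return one truncated example per validation block.

    `history` is a list of interactions and `valid_block_pos` a list of
    inclusive (start, end) index pairs. Both are toy stand-ins for the
    tensors used in the notebook.
    """
    examples = []
    for block_idx, (valid_start, valid_end) in enumerate(valid_block_pos):
        n_valid = valid_end - valid_start + 1
        examples.append({
            "valid_block_idx": block_idx,
            # Keep everything up to the end of the current block, drop the future.
            "interactions": history[: valid_end + 1],
            "seq_len": valid_end + 1,
            # Assumption: 0 for context positions, 1 for positions to predict.
            "pred_time_mask": [0] * valid_start + [1] * n_valid,
        })
    return examples


history = list("abcdefghij")   # 10 toy interactions
blocks = [(3, 4), (7, 8)]      # two validation blocks
examples = expand_validation_blocks(history, blocks)
for ex in examples:
    print(ex["seq_len"], "".join(ex["interactions"]), ex["pred_time_mask"])
```

Each validation block becomes its own example, so a later block still sees earlier blocks as context but never sees anything after its own end.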

### Check

In [None]:
train_raw_ds = tf.data.TFRecordDataset(train_tfrec_paths, num_parallel_reads=1)
train_raw_ds = train_raw_ds.map(parse_train_example, num_parallel_calls=1, deterministic=True)

# Get the split training dataset.
train_valid_split_indices = load_data(train_valid_split_indices_path)
train_valid_split_table = convert_split_index_dict(train_valid_split_indices)
reduced_raw_ds = split_train_ds(train_raw_ds, train_valid_split_table, num_parallel_calls=1, deterministic=True)

# batch examples with attributes having different lengths across examples - tf.RaggedTensor
batched_raw_ds = reduced_raw_ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=5))

# batch - tf.Tensor: Extract random subsequences of a fixed length
batched_ds = random_subseqs_from_batched_raw_ds(batched_raw_ds, window_size=10, seed=1, num_parallel_calls=1, deterministic=True)

# Add input ids and targets for training
train_ds = prepare_training_dataset(batched_ds, use_abs_pos=False, num_parallel_calls=1, deterministic=True)

In [None]:
it = iter(train_ds.take(2))
next(it)

In [None]:
valid_raw_ds = tf.data.TFRecordDataset([valid_tfrec_paths], num_parallel_reads=1)
valid_raw_ds = valid_raw_ds.map(parse_train_example_with_valid_info, num_parallel_calls=1, deterministic=True)
valid_ds = prepare_validation_dataset(valid_raw_ds, batch_size=5, window_size=10, use_abs_pos=False)

In [None]:
it = iter(valid_ds.take(2))
next(it)

## Model definition<a id='model-def'></a>

### Layer initializer and some activation functions

We use the Glorot uniform initializer (also called the Xavier uniform initializer), as in the papers [SAINT](https://arxiv.org/pdf/2002.07033.pdf) and [SAINT+](https://arxiv.org/pdf/2010.12042.pdf). To do so, we need to override the `__init__` methods of some layers, which usually also set the activation functions.
The activation functions are copied from the Hugging Face `transformers` library, [see this link](https://github.com/huggingface/transformers/blob/master/src/transformers/activations_tf.py).
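
As a rough illustration of what the Glorot uniform initializer does, the sketch below samples a weight matrix from a uniform distribution on `[-limit, limit]` with `limit = sqrt(6 / (fan_in + fan_out))`. It uses only the standard library; the function names are hypothetical, and the random values will not match what `tf.keras.initializers.GlorotUniform` produces for the same seed.

```python
import math
import random


def glorot_uniform_limit(fan_in, fan_out):
    """Bound of the uniform distribution used by Glorot / Xavier uniform."""
    return math.sqrt(6.0 / (fan_in + fan_out))


def glorot_uniform_matrix(fan_in, fan_out, seed=None):
    """Sample a (fan_in, fan_out) matrix as nested lists (illustration only)."""
    rng = random.Random(seed)
    limit = glorot_uniform_limit(fan_in, fan_out)
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)] for _ in range(fan_in)]


# Bound for a square 512 x 512 projection, e.g. a Dense layer with dim=512.
limit_512 = glorot_uniform_limit(512, 512)
weights = glorot_uniform_matrix(4, 3, seed=0)
print(limit_512)
```

The bound shrinks as the layer gets wider, which keeps the variance of activations roughly constant across layers.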

In [None]:
def get_initializer(seed):

    return tf.keras.initializers.GlorotUniform(seed=seed)


def gelu(x):
    """
    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
    initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
    https://arxiv.org/abs/1606.08415
    """
    x = tf.convert_to_tensor(x)
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.math.sqrt(2.0), dtype=x.dtype)))

    return x * cdf


def gelu_new(x):
    """
    Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation
    Returns:
        `x` with the GELU activation applied.
    """
    x = tf.convert_to_tensor(x)
    pi = tf.cast(math.pi, x.dtype)
    coeff = tf.cast(0.044715, x.dtype)
    cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))

    return x * cdf


def gelu_fast(x):
    x = tf.convert_to_tensor(x)
    coeff1 = tf.cast(0.044715, x.dtype)
    coeff2 = tf.cast(0.7978845608, x.dtype)  # sqrt(2 / pi)

    return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))


ACT2FN = {
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.activations.swish,
    "silu": tf.keras.activations.swish,
    "gelu_new": tf.keras.layers.Activation(gelu_new),
    "tanh": tf.keras.activations.tanh,
    "gelu_fast": tf.keras.layers.Activation(gelu_fast),
}


def get_tf_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
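
The tanh-based GELU variants are approximations of the erf-based definition. The scalar sketch below, written with the standard `math` module rather than TensorFlow, compares the exact form with the tanh approximation and with its factored form, where `0.7978845608` is `sqrt(2 / pi)` rounded. The function names are chosen for this illustration only.

```python
import math


def gelu_exact(x):
    """GELU defined through the Gaussian CDF (erf form)."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_tanh_factored(x):
    """Same approximation with x factored out and sqrt(2 / pi) hard-coded."""
    return 0.5 * x * (1.0 + math.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


grid = [v / 10.0 for v in range(-50, 51)]
max_gap = max(abs(gelu_exact(v) - gelu_tanh(v)) for v in grid)
max_factored_gap = max(abs(gelu_tanh(v) - gelu_tanh_factored(v)) for v in grid)
print(max_gap, max_factored_gap)
```

On this grid the approximation stays close to the exact form, and the factored form differs from the unfactored one only through the rounding of the constant.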

### Model Architecture

The implementation is intended to support:

- encoder-only architecture
- encoder-decoder architecture: without auto-regressive generation
- encoder-decoder architecture: with auto-regressive generation
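
In the encoder-only variant, `TFCRCoder.call` concatenates the content and response embeddings along the time axis and assembles the four attention masks into a single block mask `[[c_mask, c_r_mask], [r_c_mask, r_mask]]`. The sketch below shows that assembly for one example with nested lists (no batch dimension). The toy masks are assumptions for illustration: causal masks everywhere, except that a content position only sees strictly earlier responses.

```python
def hstack(left, right):
    """Concatenate two 2-D masks along the key axis."""
    return [l_row + r_row for l_row, r_row in zip(left, right)]


def build_block_mask(c_mask, c_r_mask, r_c_mask, r_mask):
    """Stack content-query rows on top of response-query rows."""
    top = hstack(c_mask, c_r_mask)      # content queries over [content, response] keys
    bottom = hstack(r_c_mask, r_mask)   # response queries over [content, response] keys
    return top + bottom


seq_len = 3
# Assumed toy masks: 1 means "may attend", 0 means "masked".
causal = [[1 if k <= q else 0 for k in range(seq_len)] for q in range(seq_len)]
strict = [[1 if k < q else 0 for k in range(seq_len)] for q in range(seq_len)]

attn_mask = build_block_mask(c_mask=causal, c_r_mask=strict, r_c_mask=causal, r_mask=causal)
for row in attn_mask:
    print(row)
```

The result has shape `(2 * seq_len, 2 * seq_len)`, matching the concatenated sequence fed to the transformer blocks.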

#### Building blocks

This cell contains the building blocks (layer definitions), including encoder / decoder. The final model used for training / inference is in the next cell.
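
One of these building blocks, the embedding layer, averages the embeddings of a question's tags while ignoring padded tag slots, and divides by at least one so that a question without tags yields a zero vector. A minimal list-based sketch of that masked mean follows; the function name, the toy table, and `pad_id=0` are assumptions made for this example.

```python
def mean_tag_embedding(tag_ids, tag_table, pad_id):
    """Average the embeddings of the non-padding tags of one question."""
    dim = len(tag_table[0])
    total = [0.0] * dim
    n_tags = 0
    for tag_id in tag_ids:
        if tag_id == pad_id:
            continue  # padded slots do not contribute
        n_tags += 1
        for j in range(dim):
            total[j] += tag_table[tag_id][j]
    denom = max(n_tags, 1)  # avoid dividing by zero when every slot is padding
    return [value / denom for value in total]


# Toy embedding table: row 0 is the (hypothetical) padding id.
tag_table = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]
two_tags = mean_tag_embedding([1, 2, 0], tag_table, pad_id=0)
no_tags = mean_tag_embedding([0, 0, 0], tag_table, pad_id=0)
print(two_tags, no_tags)
```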

This cell is quite large, so I hide it. You can click the `[Code]` button to see the implementation if you are interested.
As mentioned above, I copied the [distilbert.py](https://github.com/huggingface/transformers/tree/master/src/transformers/models/distilbert) file from [Hugging Face's transformers library](https://github.com/huggingface/transformers) and added (a lot of) personal modifications.

To keep the notebook from growing even larger, and due to time constraints, I removed the original docstrings of the methods and added new ones for only a few methods.
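
The attention layer in the cell below masks scores additively (subtracting a large constant where the mask is `0`), applies a softmax, and then multiplies the weights by the mask once more. The following sketch reproduces that sequence for a single query with plain Python lists; the function name and the toy scores are made up for illustration, and dropout and multiple heads are left out.

```python
import math


def masked_attention_weights(scores, mask, neg=1e30):
    """Softmax over `scores` where positions with mask == 0 are excluded."""
    # Additive masking: masked positions get a hugely negative score.
    masked = [s - neg * (1.0 - m) for s, m in zip(scores, mask)]
    # Standard max-subtraction before exponentiation.
    top = max(masked)
    exps = [math.exp(s - top) for s in masked]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Multiply by the mask again so that masked positions are exactly zero,
    # including the case where every position is masked.
    return [w * m for w, m in zip(weights, mask)]


partly_masked = masked_attention_weights([1.0, 2.0, 3.0], [1, 1, 0])
fully_masked = masked_attention_weights([1.0, 2.0, 3.0], [0, 0, 0])
print(partly_masked, fully_masked)
```

Without the final multiplication, a fully masked row would come out of the softmax as a uniform distribution rather than as zeros.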

In [None]:
class EdFormerConfig(PretrainedConfig):

    def __init__(
        self,
        model_type,
        model_desc,
        model_size='none',
        content_vocab_size=CONTENT_VOCAB_SIZE,
        response_vocab_size=RESPONSE_VOCAB_SIZE,
        tag_vocab_size=TAG_VOCAB_SIZE,
        part_vocab_size=PART_VOCAB_SIZE,
        prior_explanation_vocab_size=PRIOR_EXPLANATION_VOCAB_SIZE,
        max_position_embeddings=MAX_HISTORY_LEN,
        sinusoidal_pos_embds=False,
        n_layers=4,
        n_heads=8,
        dim=512,
        hidden_dim=4 * 512,
        activation=ACTIVATION,        
        dropout=0.1,
        attention_dropout=0.1,
        seq2seq_dropout=0.1,
        initializer_range=0.02,
        seed=SEED,        
        pad_token_id=PAD_ID,
        use_abs_pos=USE_ABS_POS,
        share_position_embeddings=SHARE_POS_EMBEDDING,
        use_tags=USE_TAGS,
        use_part=USE_PART,
        use_prior_explanation=USE_PRIOR_EXPLANATION,
        use_prior_question_elapsed_time_input=USE_PRIOR_QUESTION_ELAPSED_TIME_INPUT,
        allow_bundle_atten=ALLOW_BUNDLE_ATTEN,
        generative=GENERATIVE,
        **kwargs
    ):
        super().__init__(**kwargs, pad_token_id=pad_token_id)
        
        self.model_type = model_type
        self.model_size = model_size
        self.model_desc = model_desc

        self.content_vocab_size = content_vocab_size
        self.response_vocab_size = response_vocab_size

        self.tag_vocab_size = tag_vocab_size
        self.part_vocab_size = part_vocab_size
        self.prior_explanation_vocab_size = prior_explanation_vocab_size

        self.use_abs_pos = use_abs_pos
        self.share_position_embeddings = share_position_embeddings
        self.use_tags = use_tags
        self.use_part = use_part
        self.use_prior_explanation = use_prior_explanation
        self.allow_bundle_atten = allow_bundle_atten
        self.generative = generative
        self.use_prior_question_elapsed_time_input = use_prior_question_elapsed_time_input

        self.max_position_embeddings = max_position_embeddings
        self.sinusoidal_pos_embds = sinusoidal_pos_embds
        
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dim = dim
        self.hidden_dim = hidden_dim

        self.activation = activation

        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.seq2seq_dropout = seq2seq_dropout

        self.initializer_range = initializer_range
        self.seed = seed

    def toJSON(self):

        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=False, indent=4)

    def vocab_size(self, input_name):

        if input_name == 'content':
            return self.content_vocab_size
        elif input_name == 'response':
            return self.response_vocab_size
        elif input_name == 'tag':
            return self.tag_vocab_size
        elif input_name == 'part':
            return self.part_vocab_size            
        elif input_name == 'prior_explanation':
            return self.prior_explanation_vocab_size
        else:
            raise ValueError('input name not used for model')

    @property
    def hidden_size(self):
        return self.dim

    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers


class TFSharedEmbeddings(tf.keras.layers.Layer):

    def __init__(self, vocab_size, hidden_size, seed=None, **kwargs):
        
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.seed = seed

    def build(self, input_shape):

        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(seed=self.seed)
        )
        super().build(input_shape)

    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:

        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):

        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])


class TFEmbeddings(tf.keras.layers.Layer):

    def __init__(self, config, input_name, position_embeddings_layer=None, **kwargs):
        
        super().__init__(**kwargs)
        
        self.pad_id = config.pad_token_id
        self.input_name = input_name
        self.vocab_size = config.vocab_size(input_name)

        self.use_tags = config.use_tags
        self.use_part = config.use_part
        self.use_prior_explanation = config.use_prior_explanation
        self.use_prior_question_elapsed_time_input = config.use_prior_question_elapsed_time_input

        self.dim = config.dim
        self.seed = config.seed
        self.word_embeddings = TFSharedEmbeddings(
            self.vocab_size, config.dim, seed=config.seed, name="word_embeddings"
        )  # padding_idx=0)
        
        if position_embeddings_layer is None:

            self.position_embeddings = tf.keras.layers.Embedding(
                config.max_position_embeddings,
                config.dim,
                embeddings_initializer=get_initializer(config.seed),
                name="position_embeddings",
            )
        
        else:

            self.position_embeddings = position_embeddings_layer

        if self.input_name == 'content':

            if config.use_tags:

                self.tag_embeddings = tf.keras.layers.Embedding(
                    config.tag_vocab_size,
                    config.dim,
                    embeddings_initializer=get_initializer(config.seed),
                    name="tag_embeddings",
                )

            if config.use_part:

                self.part_embeddings = tf.keras.layers.Embedding(
                    config.part_vocab_size,
                    config.dim,
                    embeddings_initializer=get_initializer(config.seed),
                    name="part_embeddings",
                )

        elif self.input_name == 'response':

            if config.use_prior_explanation:

                self.prior_explanation_embeddings = tf.keras.layers.Embedding(
                    config.prior_explanation_vocab_size,
                    config.dim,
                    embeddings_initializer=get_initializer(config.seed),
                    name="prior_explanation_embeddings",
                )

            if config.use_prior_question_elapsed_time_input:

                self.prior_question_elapsed_time_embeddings = tf.keras.layers.Dense(
                    config.dim,
                    kernel_initializer=get_initializer(config.seed),
                    name="prior_question_elapsed_time_embeddings",
                )

                assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
                    config.activation
                )
                self.activation = get_tf_activation(config.activation)

        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.dropout)

    def build(self, input_shape):
        """Build shared word embedding layer """
        
        with tf.name_scope("word_embeddings"):
            # Create the weights and initialize them with the Glorot uniform
            # initializer returned by `get_initializer`.
            self.word_embeddings = self.add_weight(
                "weight", shape=[self.vocab_size, self.dim], initializer=get_initializer(self.seed)
            )
        
        super().build(input_shape)

    def call(self, input_ids=None, position_ids=None, tag_ids=None, part_ids=None, prior_explanation_ids=None, prior_question_elapsed_time_input=None, inputs_embeds=None, mode="embedding", training=False):
        
        if mode == "embedding":
            return self._embedding(input_ids, position_ids, tag_ids, part_ids, prior_explanation_ids, prior_question_elapsed_time_input, inputs_embeds, training=training)
        elif mode == "linear":
            return self._linear(input_ids)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids, position_ids, tag_ids, part_ids, prior_explanation_ids, prior_question_elapsed_time_input, inputs_embeds, training=False):
        """
        """
        
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            seq_length = shape_list(input_ids)[1]
        else:
            seq_length = shape_list(inputs_embeds)[1]

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)

        if position_ids is None:
            position_ids = tf.range(1, 1 + seq_length, dtype=tf.int32)[tf.newaxis, :]

        position_embeddings = tf.cast(
            self.position_embeddings(position_ids), inputs_embeds.dtype
        )  # (bs, max_seq_length, dim)

        tag_embeddings = tf.zeros_like(inputs_embeds, dtype=inputs_embeds.dtype)
        if self.use_tags and tag_ids is not None:
            tag_embeddings = tf.cast(
                self.tag_embeddings(tag_ids), inputs_embeds.dtype
            )  # (bs, seq_len, N_TAGS_PER_CONTENT, dim)
            
            # shape = (bs, seq_len, N_TAGS_PER_CONTENT)
            tag_mask = tf.cast(tag_ids != self.pad_id, dtype=tf.int32)

            tag_embeddings = tag_embeddings * tf.cast(tag_mask, dtype=inputs_embeds.dtype)[:, :, :, tf.newaxis]
            
            # shape = (bs, seq_len)
            nb_tags = tf.math.reduce_sum(tag_mask, axis=2)
            nb_tags = tf.cast(nb_tags, dtype=inputs_embeds.dtype)
            nb_tags = tf.math.maximum(nb_tags, tf.cast(1.0, dtype=inputs_embeds.dtype))

            tag_embeddings = tf.math.reduce_sum(tag_embeddings, axis=2) / nb_tags[:, :, tf.newaxis]

        part_embeddings = tf.zeros_like(inputs_embeds, dtype=inputs_embeds.dtype)
        if self.use_part and part_ids is not None:
            part_embeddings = tf.cast(
                self.part_embeddings(part_ids), inputs_embeds.dtype
            )  # (bs, seq_len, dim)

        prior_explanation_embeddings = tf.zeros_like(inputs_embeds, dtype=inputs_embeds.dtype)
        if self.use_prior_explanation and prior_explanation_ids is not None:
            prior_explanation_embeddings = tf.cast(
                self.prior_explanation_embeddings(prior_explanation_ids), inputs_embeds.dtype
            )  # (bs, seq_len, dim)

        prior_question_elapsed_time_embeddings = tf.zeros_like(inputs_embeds, dtype=inputs_embeds.dtype)
        if self.use_prior_question_elapsed_time_input and prior_question_elapsed_time_input is not None:
            prior_question_elapsed_time_embeddings = self.prior_question_elapsed_time_embeddings(prior_question_elapsed_time_input[:, :, tf.newaxis])
            prior_question_elapsed_time_embeddings = self.activation(prior_question_elapsed_time_embeddings)        

        embeddings = inputs_embeds + position_embeddings + tag_embeddings + part_embeddings + prior_explanation_embeddings + prior_question_elapsed_time_embeddings # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings, training=training)  # (bs, max_seq_length, dim)
        
        return embeddings

    def _linear(self, inputs):
        """
        Computes logits by running inputs through a linear layer
        Args:
            inputs: A float32 tensor with shape [batch_size, length, hidden_size]
        Returns:
            float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = shape_list(inputs)[0]
        length = shape_list(inputs)[1]

        x = tf.reshape(inputs, [-1, self.dim])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])


class TFMultiHeadSelfAttention(HFTFMultiHeadSelfAttention):
    
    def __init__(self, config, **kwargs):
        
        super().__init__(config, **kwargs)

        self.q_lin = tf.keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.seed),
            name="q_lin"
        )
        self.k_lin = tf.keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.seed),
            name="k_lin"
        )
        self.v_lin = tf.keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.seed),
            name="v_lin"
        )
        self.out_lin = tf.keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.seed),
            name="out_lin"
        )

    def call(self, query, key, value, mask, head_mask, output_attentions, training=False):
        """
        Parameters:
            query: tf.Tensor(bs, query_length, dim)
            key: tf.Tensor(bs, key_length, dim)
            value: tf.Tensor(bs, key_length, dim)
            mask: tf.Tensor(bs, query_length / 1, key_length)
        Returns:
            context: tf.Tensor(bs, query_length, dim) Contextualized layer.
            weights: tf.Tensor(bs, n_heads, query_length, key_length) Attention weights. Optional: only if
            `output_attentions=True`
        """
        bs, q_length, dim = shape_list(query)
        k_length = shape_list(key)[1]
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        # assert key.size() == value.size()
        dim_per_head = tf.math.divide(self.dim, self.n_heads)
        dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
        
        def shape(x):
            """ separate heads """
            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))

        def unshape(x):
            """ group heads """
            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
        q = tf.cast(q, dtype=DTYPE)
        q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=DTYPE)))
        k = tf.cast(k, dtype=q.dtype)
        scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, q_length, k_length)
        mask = mask[:, tf.newaxis, :, :]  # (bs, 1, qlen / 1, klen) --> (bs, n_heads, qlen, klen)
        # scores.masked_fill_(mask, -float('inf'))            # (bs, n_heads, q_length, k_length)

        mask = tf.cast(mask, dtype=scores.dtype)
        scores = scores - 1e30 * (1.0 - mask)
        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
        weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
        # This makes things more numerically stable.
        weights = weights * mask

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = tf.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)
    

class TFFFN(HFTFFFN):

    def __init__(self, config, **kwargs):
        
        super().__init__(config, **kwargs)

        self.lin1 = tf.keras.layers.Dense(
            config.hidden_dim,
            kernel_initializer=get_initializer(config.seed),
            name="lin1"
        )
        self.lin2 = tf.keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.seed),
            name="lin2"
        )

        self.activation = get_tf_activation(config.activation)


class TFTransformerBlock(HFTFTransformerBlock):
    
    def __init__(self, config, **kwargs):
        
        super().__init__(config, **kwargs)

        self.attention = TFMultiHeadSelfAttention(config, name="attention")
 
    
class TFContentBlock(TFTransformerBlock):
    
    def __init__(self, config, **kwargs):
        
        super().__init__(config, **kwargs)

        self.attention = TFMultiHeadSelfAttention(config, name="c_attention")
        

class TFResponseBlock(TFTransformerBlock):
    
    def __init__(self, config, **kwargs):
        
        super().__init__(config, **kwargs)
    
        self.attention = TFMultiHeadSelfAttention(config, name="r_attention")
        
        self.r_c_attentioin = TFMultiHeadSelfAttention(config, name="r_c_attentioin")
        self.r_c_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="r_c_attn_layer_norm")  
        
    
    def call(self, r, c_hidden, r_mask, r_c_mask, head_mask, output_attentions, training=False):  # removed: src_enc=None, src_len=None

        r_output = self.attention(r, r, r, r_mask, head_mask, output_attentions, training=training)
        if output_attentions:
            r_output, r_weights = r_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
            # assert type(sa_output) == tuple
            r_output = r_output[0]
        r_output = self.sa_layer_norm(r_output + r)  # (bs, seq_length, dim)

        r_c_output = self.r_c_attentioin(
            query=r_output,
            key=c_hidden,
            value=c_hidden,
            mask=r_c_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training
        )
        
        if output_attentions:
            r_c_output, r_c_weights = r_c_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
            # assert type(sa_output) == tuple
            r_c_output = r_c_output[0]        
        r_c_output = self.r_c_attn_layer_norm(r_c_output + r_output)
        
        # Feed Forward Network
        ffn_output = self.ffn(r_c_output, training=training)  # (bs, seq_length, dim)
        ffn_output = self.output_layer_norm(ffn_output + r_c_output)  # (bs, seq_length, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (r_weights, r_c_weights) + output
        return output
       
    def generate(self, current_pos, current_r, prev_c_hidden, prev_d_hidden, r_mask, r_c_mask, head_mask, output_attentions):

        r_mask = r_mask[:, current_pos:current_pos+1 ,:]
        r_c_mask = r_c_mask[:, current_pos:current_pos+1, :]

        d_hidden = tf.concat([prev_d_hidden[:, :current_pos, :], current_r, prev_d_hidden[:, current_pos+1:, :]], axis=1)
        r_output = self.attention(
            current_r, d_hidden, d_hidden,
            r_mask,
            head_mask,
            output_attentions,
            training=False
        )
        if output_attentions:
            r_output, r_weights = r_output  # (bs, 1, dim), (bs, n_heads, 1, seq_length)
        else:  
            r_output = r_output[0]
        r_output = self.sa_layer_norm(r_output + current_r)  # (bs, 1, dim)

        r_c_output = self.r_c_attentioin(
            query=r_output,
            key=prev_c_hidden,
            value=prev_c_hidden,
            mask=r_c_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=False
        )
        
        if output_attentions:
            r_c_output, r_c_weights = r_c_output  # (bs, 1, dim), (bs, n_heads, 1, seq_length)
        else:
            r_c_output = r_c_output[0]        
        r_c_output = self.r_c_attn_layer_norm(r_c_output + r_output)  # (bs, 1, dim)
        
        # Feed Forward Network
        ffn_output = self.ffn(r_c_output, training=False)  # (bs, 1, dim)
        ffn_output = self.output_layer_norm(ffn_output + r_c_output)  # (bs, 1, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (r_weights, r_c_weights) + output
        return output


class TFCRBlock(TFTransformerBlock):
               
    def __init__(self, config, **kwargs):
        
        super().__init__(config, **kwargs)

        self.attention = TFMultiHeadSelfAttention(config, name="cr_attention") 


class TFContentCoder(TFTransformer):
    
    def __init__(self, config, **kwargs):
        
        super().__init__(config, **kwargs)
        
        self.layer = [TFContentBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)]
    
    
class TFResponseCoder(TFTransformer):
    
    def __init__(self, config, **kwargs):
        
        super().__init__(config, **kwargs)        
        
        self.layer = [TFResponseBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)]

    def call(self, r_embeds, c_hidden, r_mask, r_c_mask, head_mask, output_attentions, output_hidden_states, return_dict, training=False):
        # docstyle-ignore
        """
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = r_embeds
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            layer_outputs = layer_module(hidden_state, c_hidden, r_mask, r_c_mask, head_mask[i], output_attentions, training=training)
            hidden_state = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 3
                r_attn = layer_outputs[0]
                r_c_attn = layer_outputs[1]
                all_attentions = all_attentions + ((r_attn, r_c_attn),)
            else:
                assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )        
        
    def generate(
        self,
        current_pos,
        current_d_embedding,
        prev_c_hidden,
        all_prev_d_hidden,
        r_mask,
        r_c_mask,
        head_mask,
        output_attentions
    ):

        all_hidden_states = ()
        all_attentions = () if output_attentions else None

        current_d_hidden = current_d_embedding
        for i, (layer_module, prev_d_hidden) in enumerate(zip(self.layer, all_prev_d_hidden[:-1])):

            all_hidden_states = all_hidden_states + (current_d_hidden,)

            layer_outputs = layer_module.generate(
                current_pos, current_d_hidden, prev_c_hidden, prev_d_hidden, r_mask, r_c_mask, head_mask[i], output_attentions
            )
            current_d_hidden = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 3
                r_attn = layer_outputs[0]
                r_c_attn = layer_outputs[1]
                all_attentions = all_attentions + ((r_attn, r_c_attn),)
            else:
                assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"

        # Add last layer
        all_hidden_states = all_hidden_states + (current_d_hidden,)

        return tuple(v for v in [current_d_hidden, all_hidden_states, all_attentions] if v is not None)        


class TFCRCoder(TFTransformer):
    
    def __init__(self, config, **kwargs):
        
        super().__init__(config, **kwargs)
        
        self.layer = [TFCRBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)]
        
    def call(self, c_embeds, r_embeds, c_mask, r_mask, c_r_mask, r_c_mask, head_mask, output_attentions, output_hidden_states, return_dict, training=False):

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None        
        
        # shape = [batch_size, 2 * seq_len, hidden_dim]
        embeds = tf.concat([c_embeds, r_embeds], axis=1)

        _c_mask = tf.concat([c_mask, c_r_mask], axis=2)
        _r_mask = tf.concat([r_c_mask, r_mask], axis=2)
        
        attn_mask = tf.concat([_c_mask, _r_mask], axis=1)
        
        hidden_state = embeds
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)
                
            layer_outputs = layer_module(hidden_state, attn_mask, head_mask[i], output_attentions, training=training)
            hidden_state = layer_outputs[-1]
            
            if output_attentions:
                assert len(layer_outputs) == 2
                attn = layer_outputs[0]
                all_attentions = all_attentions + (attn,)
            else:
                assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"            
            
        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)            
            
        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )  

        
class TFEdFormerEncoder(tf.keras.layers.Layer):
    
    def __init__(self, config, **kwargs):

        super().__init__(**kwargs)

        self.cr_encoder = TFCRCoder(config, name='cr_encoder')
    
    def call(
        self,
        c_embeds,
        r_embeds,
        d_embeds,
        c_mask,
        r_mask,      
        r_c_mask,
        c_r_mask,
        head_mask,
        output_attentions,
        output_hidden_states,
        return_dict,
        training=False,
    ):

        # currently, only empty tuple is returned
        c_all_hidden_states = () if output_hidden_states else None
        c_all_attentions = () if output_attentions else None    
        r_all_hidden_states = () if output_hidden_states else None
        r_all_attentions = () if output_attentions else None 

        cr_outputs = self.cr_encoder(
            c_embeds, r_embeds, c_mask, r_mask, c_r_mask, r_c_mask, head_mask,
            output_attentions, output_hidden_states, return_dict, training=training
        )
        
        if not return_dict:
            cr_hidden = cr_outputs[0]
        else:
            cr_hidden = cr_outputs.last_hidden_state
        
        c_seq_len = tf.math.reduce_sum(tf.ones_like(c_embeds[0, :, 0], dtype=tf.int32))
        r_seq_len = tf.math.reduce_sum(tf.ones_like(r_embeds[0, :, 0], dtype=tf.int32))
        
        c_hidden = cr_hidden[:, 0:c_seq_len, :]
        r_hidden = cr_hidden[:, c_seq_len:(c_seq_len + r_seq_len), :]

        # Adding all hidden states and attentions
        if not return_dict:
            c_outputs = tuple(v for v in [c_hidden, c_all_hidden_states, c_all_attentions] if v is not None)
            r_outputs = tuple(v for v in [r_hidden, r_all_hidden_states, r_all_attentions] if v is not None)
        else:
            c_outputs =TFBaseModelOutput(
                last_hidden_state=c_hidden, hidden_states=c_all_hidden_states, attentions=c_all_attentions
            )         
            r_outputs =TFBaseModelOutput(
                last_hidden_state=r_hidden, hidden_states=r_all_hidden_states, attentions=r_all_attentions
            )

        hidden_states = c_hidden

        return (hidden_states, c_outputs, r_outputs)


def process_mask(mask, input_shape):
    
    if mask is None:
        mask = tf.ones(input_shape)
    mask = tf.cast(mask, dtype=tf.float32)            
    
    return mask


class TFEdFormerEncoderDecoder(tf.keras.layers.Layer):
    
    def __init__(self, config, **kwargs):

        super().__init__(**kwargs)

        self.encoder = TFContentCoder(config, name="encoder")  # Encoder
        self.decoder = TFResponseCoder(config, name="decoder")  # Decoder
    
    def call(
        self,
        c_embeds,
        r_embeds,
        d_embeds,
        c_mask,
        r_mask,       
        r_c_mask,
        c_r_mask,
        head_mask,
        output_attentions,
        output_hidden_states,
        return_dict,
        training=False,
    ):

        c_outputs = self.encoder(
            c_embeds, c_mask, head_mask,
            output_attentions, output_hidden_states, return_dict, training
        )
        
        if not return_dict:
            c_hidden = c_outputs[0]
        else:
            c_hidden = c_outputs.last_hidden_state
        
        r_outputs = self.decoder(
            d_embeds, c_hidden, r_mask, r_c_mask, head_mask,
            output_attentions, output_hidden_states, return_dict, training
        )      

        if not return_dict:
            r_hidden = r_outputs[0]
        else:
            r_hidden = r_outputs.last_hidden_state

        hidden_states = r_hidden

        return (hidden_states, c_outputs, r_outputs)

    def generate(
        self,
        current_pos,
        current_d_embedding,
        prev_c_hidden,
        all_prev_d_hidden,
        r_mask,
        r_c_mask,
        head_mask,
        output_attentions
    ):

        return self.decoder.generate(current_pos, current_d_embedding, prev_c_hidden, all_prev_d_hidden, r_mask, r_c_mask, head_mask, output_attentions)


class TFEdFormerMainLayer(tf.keras.layers.Layer):
    
    config_class = EdFormerConfig

    def __init__(self, config, **kwargs):

        super().__init__(**kwargs)
        
        self.num_hidden_layers = config.num_hidden_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.return_dict = config.use_return_dict
        
        self.use_tags = config.use_tags
        self.use_part = config.use_part
        self.use_prior_explanation = config.use_prior_explanation
        self.use_prior_question_elapsed_time_input = config.use_prior_question_elapsed_time_input             

        if config.share_position_embeddings:
            # All `TFEmbeddings` share a single `position_embeddings`.
            position_embeddings = tf.keras.layers.Embedding(
                config.max_position_embeddings,
                config.dim,
                embeddings_initializer=get_initializer(config.seed),
                name="position_embeddings",
            )
        else:
            # Each `TFEmbeddings` will have its own `position_embeddings`.
            position_embeddings = None
              
        self.content_embeddings = TFEmbeddings(
            config,
            input_name = 'content',
            position_embeddings_layer=position_embeddings,
            name="content_embeddings"
        )
        
        self.response_embeddings = TFEmbeddings(
            config,
            input_name = 'response',
            position_embeddings_layer=position_embeddings,
            name="response_embeddings"
        )
        
        if config.model_type == 'cr':
            self.coder = TFEdFormerEncoder(config, name='coder')
        elif config.model_type == 'ed':
            self.coder = TFEdFormerEncoderDecoder(config, name='coder')
                
    def call(
        self,
        inputs,
        c_mask=None,
        r_mask=None,
        r_c_mask=None,
        c_r_mask=None,
        head_mask=None,    
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,        
    ):

        c_input_ids = inputs.get('c_input_ids')
        r_input_ids = inputs.get('r_input_ids')
        d_input_ids = inputs.get('d_input_ids')

        batch_size = tf.math.reduce_sum(tf.ones_like(c_input_ids[:, 0], dtype=tf.int32))
        c_seq_len = tf.math.reduce_sum(tf.ones_like(c_input_ids[0, :], dtype=tf.int32))
        r_seq_len = tf.math.reduce_sum(tf.ones_like(r_input_ids[0, :], dtype=tf.int32))
    
        # The simplest way to compute positions.
        pos_ids = tf.range(c_seq_len, dtype=tf.int32) + 1
        shifted_pos_ids = tf.concat([[0], pos_ids[:-1]], axis=0)

        # positional information provided
        pos_ids = inputs.get('pos_ids', pos_ids)
        shifted_pos_ids = inputs.get('shifted_pos_ids', shifted_pos_ids)

        if self.use_tags:
            tag_ids = inputs.get('tag_ids')
        else:
            tag_ids = None
        
        if self.use_part:
            part_ids = inputs.get('part_ids')
        else:
            part_ids = None

        if self.use_prior_explanation:
            prior_explanation_ids = inputs.get('prior_explanation_ids')
        else:
            prior_explanation_ids = None

        if self.use_prior_question_elapsed_time_input:
            prior_question_elapsed_time_input = inputs.get('prior_question_elapsed_time_input')
        else:
            prior_question_elapsed_time_input = None

        c_mask = inputs.get('c_mask', c_mask)
        r_mask = inputs.get('r_mask', r_mask)
        r_c_mask = inputs.get('r_c_mask', r_c_mask)
        c_r_mask = inputs.get('c_r_mask', c_r_mask)                        
        head_mask = inputs.get('head_mask', head_mask)
        
        output_attentions = inputs.get('output_attentions', output_attentions)
        output_hidden_states = inputs.get('output_hidden_states', output_hidden_states)
        return_dict = inputs.get('return_dict', return_dict)
             
        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.return_dict
    
        c_mask = process_mask(c_mask, [batch_size, c_seq_len, c_seq_len])
        r_mask = process_mask(r_mask, [batch_size, r_seq_len, r_seq_len])
        r_c_mask = process_mask(r_c_mask, [batch_size, r_seq_len, c_seq_len])
        c_r_mask = process_mask(c_r_mask, [batch_size, c_seq_len, r_seq_len])        
        
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            
        c_embedding_output = self.content_embeddings(input_ids=c_input_ids, position_ids=pos_ids, tag_ids=tag_ids, part_ids=part_ids, prior_explanation_ids=None, training=training)  # (bs, seq_length, dim)
        r_embedding_output = self.response_embeddings(input_ids=r_input_ids, position_ids=pos_ids, tag_ids=None, part_ids=None, prior_explanation_ids=None, training=training)  # (bs, seq_length, dim)
        d_embedding_output = self.response_embeddings(input_ids=d_input_ids, position_ids=shifted_pos_ids, tag_ids=None, part_ids=None, prior_explanation_ids=prior_explanation_ids, prior_question_elapsed_time_input=prior_question_elapsed_time_input, training=training)  # (bs, seq_length, dim)

        outputs = self.coder(
            c_embedding_output,
            r_embedding_output,
            d_embedding_output,
            c_mask,
            r_mask,
            r_c_mask,
            c_r_mask,
            head_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
            training=training           
        )
        
        return outputs

    def generate(
        self,
        inputs,
        current_pos,
        seq_len,
        updated_d_input_ids,  # (bs, seq_len)
        prev_c_hidden,
        all_prev_d_hidden,
        r_mask,
        r_c_mask,
        head_mask,
        output_attentions
    ):
        # The simplest way to compute positions.
        pos_ids = tf.range(seq_len, dtype=tf.int32) + 1
        shifted_pos_ids = tf.concat([[0], pos_ids[:-1]], axis=0)

        # positional information provided
        pos_ids = inputs.get('pos_ids', pos_ids)
        shifted_pos_ids = inputs.get('shifted_pos_ids', shifted_pos_ids)

        if self.use_prior_explanation:
            prior_explanation_ids = inputs.get('prior_explanation_ids')
        else:
            prior_explanation_ids = None

        if self.use_prior_question_elapsed_time_input:
            prior_question_elapsed_time_input = inputs.get('prior_question_elapsed_time_input')
        else:
            prior_question_elapsed_time_input = None

        current_d_embedding = self.response_embeddings(
            input_ids=updated_d_input_ids,
            position_ids=shifted_pos_ids,
            tag_ids=None,
            part_ids=None,
            prior_explanation_ids=prior_explanation_ids,
            prior_question_elapsed_time_input=prior_question_elapsed_time_input,
        )[:, current_pos:current_pos+1, :]

        outputs = self.coder.generate(
            current_pos,
            current_d_embedding,
            prev_c_hidden,
            all_prev_d_hidden,
            r_mask,
            r_c_mask,
            head_mask,
            output_attentions=output_attentions
        )
        
        return outputs    


class TFEdFormerPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EdFormerConfig
    base_model_prefix = "edformer"
    
    
class TFEdFormerModel(TFEdFormerPreTrainedModel):
    
    def __init__(self, config, *inputs, **kwargs):
        
        super().__init__(config, *inputs, **kwargs)
        
        self.edformer = TFEdFormerMainLayer(config, name="edformer")  # Main layer: embeddings + coder

    def call(self, inputs, **kwargs):
        
        outputs = self.edformer(
            inputs, 
            **kwargs
        )
        
        return outputs

#### The model

The model for answer correctness prediction.

In particular, it implements a `generate()` method for auto-regressive prediction generation, which relies on the `generate()` methods of the building-block layers defined in the previous code cell.

<center><a href="https://ibb.co/yf9jCt0"><img src="https://i.ibb.co/bW09VT1/Capture777.png" width="700" alt="Capture777" border="0" /></a></center>
<br>
<center><h5>auto-regressive prediction generation</h5></center>
<br>

* Credit for the picture: [Mariya Yao: Novel Methods For Text Generation Using Adversarial Learning & Autoencoders](https://www.topbots.com/ai-research-gan-vae-text-generation/)
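
To make the loop concrete, here is a minimal pure-Python sketch of greedy auto-regressive generation for a single sequence. It is not the notebook's implementation: `toy_prob` is a hypothetical stand-in for the decoder, the token id values are made up for illustration, and the handling of padding, start, and lecture tokens is omitted. What it shares with `generate()` is the control flow: predict at the current position, threshold at 0.5, and write the predicted response id into the decoder input at the next position only where the true response is unknown at prediction time.

```python
import math

# Hypothetical token ids for illustration; the notebook defines its own constants.
START_ID = 1
RESPONSE_FALSE_ID = 3
RESPONSE_TRUE_ID = RESPONSE_FALSE_ID + 1


def toy_prob(decoder_ids, pos):
    """Stand-in for the decoder: probability of a correct answer at `pos`,
    computed only from decoder inputs at positions <= pos (causal)."""
    visible = decoder_ids[: pos + 1]
    n_true = sum(1 for t in visible if t == RESPONSE_TRUE_ID)
    n_false = sum(1 for t in visible if t == RESPONSE_FALSE_ID)
    return 1.0 / (1.0 + math.exp(-(0.5 + n_true - n_false)))


def generate(decoder_ids, start_pos, pred_mask_shifted):
    """Greedy auto-regressive generation over one sequence.

    pred_mask_shifted[t] == 1 means the decoder input at position t is not
    known at prediction time and is replaced by the previous prediction.
    """
    ids = list(decoder_ids)
    preds = [None] * len(ids)
    for pos in range(start_pos, len(ids)):
        p = toy_prob(ids, pos)
        preds[pos] = p
        # Feed the thresholded prediction back as the next decoder input.
        if pos + 1 < len(ids) and pred_mask_shifted[pos + 1] == 1:
            ids[pos + 1] = RESPONSE_FALSE_ID + int(p >= 0.5)
    return preds, ids


preds, final_ids = generate(
    decoder_ids=[START_ID, RESPONSE_TRUE_ID, RESPONSE_FALSE_ID, 0, 0],
    start_pos=2,
    pred_mask_shifted=[0, 0, 0, 1, 1],
)
print(final_ids)  # [1, 4, 3, 4, 4]
```

Each step depends on the output of the previous one, so the positions from `start_pos` onward cannot be predicted in a single parallel pass.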

In [None]:
class TFEdFormerAnswerPredictionModel(TFPreTrainedModel):
    
    def __init__(self, config, *inputs, **kwargs):
        
        super().__init__(config, *inputs, **kwargs)
    
        self.edformer = TFEdFormerMainLayer(config, name="edformer")  # Main layer: embeddings + coder

        self.pre_classifier = tf.keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.seed),
            name="pre_classifier",
        )       

        assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
            config.activation
        )
        self.activation = get_tf_activation(config.activation)

        self.classifier = tf.keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.seed), name="classifier"
        )
        self.dropout = tf.keras.layers.Dropout(config.seq2seq_dropout)
        
    def call(
        self,
        inputs=None,
        c_mask=None,
        r_mask=None,
        r_c_mask=None,
        c_r_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):

        return_dict = return_dict if return_dict is not None else self.edformer.return_dict

        edformer_output = self.edformer(
            inputs,
            c_mask=c_mask,
            r_mask=r_mask,
            r_c_mask=r_c_mask,
            c_r_mask=c_r_mask,
            head_mask=head_mask,          
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        hidden_states, c_outputs, r_outputs = edformer_output

        o = self.pre_classifier(hidden_states)  # (bs, seq_len, dim)
        o = self.activation(o)
        o = self.dropout(o, training=training)
        
        logits = self.classifier(o)  # (bs, seq_len, 1)
        logits = logits[:, :, 0]

        return (logits, c_outputs, r_outputs)

    def generate(
        self,
        inputs,
        start_pos,
        window_size,
        dim,
        c_mask=None,
        r_mask=None,
        r_c_mask=None,
        c_r_mask=None,
        head_mask=None,
        output_attentions=None
    ):
        """
        Generate the predictions in an auto-regressive way.
        """

        question_mask = tf.cast((inputs['content_type_id'] == 0), dtype=tf.int32)
        pred_time_question_mask = inputs['pred_time_mask'] * question_mask
        pred_time_question_mask_shifted = tf.concat([tf.zeros_like(pred_time_question_mask[:, -1:], dtype=tf.int32), pred_time_question_mask[:, 0:-1]], axis=1)

        preds = tf.cast(inputs['target'], dtype=tf.float32)

        d_hidden, c_outputs, r_outputs = self.edformer(
            inputs,
            c_mask=c_mask,
            r_mask=r_mask,
            r_c_mask=r_c_mask,
            c_r_mask=c_r_mask,
            head_mask=head_mask,          
            output_attentions=output_attentions,
            output_hidden_states=True,
            training=False,
        )

        o = self.pre_classifier(d_hidden)  # (bs, seq_len, dim)
        o = self.activation(o)
        o = self.dropout(o, training=False)

        logits = self.classifier(o)[:, :, 0]  # (bs, seq_len)

        prev_c_hidden = c_outputs[0]
        all_prev_d_hidden = r_outputs[1]

        c_input_ids = inputs.get('c_input_ids')
        d_input_ids = inputs.get('d_input_ids')
        updated_d_input_ids = d_input_ids

        batch_size = tf.math.reduce_sum(tf.ones_like(c_input_ids[:, 0], dtype=tf.int32))
        c_seq_len = tf.math.reduce_sum(tf.ones_like(c_input_ids[0, :], dtype=tf.int32))
        d_seq_len = tf.math.reduce_sum(tf.ones_like(updated_d_input_ids[0, :], dtype=tf.int32))
  
        r_mask = process_mask(r_mask, [batch_size, d_seq_len, d_seq_len])
        r_c_mask = process_mask(r_c_mask, [batch_size, d_seq_len, c_seq_len])
 
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.edformer.num_hidden_layers     

        output_attentions = inputs.get('output_attentions', output_attentions)
        output_attentions = output_attentions if output_attentions is not None else self.edformer.output_attentions

        for current_pos in tf.range(start_pos, d_seq_len, dtype=tf.int32):

            _current_d_hidden, all_current_d_hidden = self.edformer.generate(
                inputs,
                current_pos,
                c_seq_len,
                updated_d_input_ids,
                prev_c_hidden,
                all_prev_d_hidden,
                r_mask,
                r_c_mask,
                head_mask,
                output_attentions=output_attentions
            )[0:2]

            all_prev_d_hidden = tuple([
                    tf.reshape(tf.concat(
                        [x[:, :current_pos, :], y, x[:, current_pos+1:d_seq_len, :]],
                        axis=1
                    ), shape=[batch_size, window_size, dim]) for x, y in zip(all_prev_d_hidden, all_current_d_hidden)
            ])

            # ------------------------------------------------------------------------------------------
            # The logit for the place at `current_pos`.

            o = self.pre_classifier(_current_d_hidden)  # (bs, 1, dim)
            o = self.activation(o)
            o = self.dropout(o, training=False)

            _logits = self.classifier(o)[:, :, 0]  # (bs, 1)
            _preds = tf.math.sigmoid(_logits)  # (bs, 1)

            logits = tf.concat([logits[:, :current_pos], _logits, logits[:, current_pos+1:d_seq_len]], axis=1)
            preds = tf.concat([preds[:, :current_pos], _preds, preds[:, current_pos+1:d_seq_len]], axis=1)

            logits = tf.reshape(logits, shape=[batch_size, window_size])
            preds = tf.reshape(preds, shape=[batch_size, window_size])

            # ------------------------------------------------------------------------------------------
            # update the d_input_ids at `current_pos + 1`

            predicted_answered_correctly = tf.cast(_preds >= 0.5, dtype=tf.int32)  # (bs, 1)

            if current_pos < d_seq_len - 1:
                
                mask = tf.cast(updated_d_input_ids != PAD_ID, dtype=tf.int32) * tf.cast(updated_d_input_ids != START_ID, dtype=tf.int32) * tf.cast(updated_d_input_ids != RESPONSE_LECTURE_ID, dtype=tf.int32)
                mask = mask * pred_time_question_mask_shifted
                mask = mask[:, current_pos + 1:current_pos + 2]

                predicted_next_response_id = (predicted_answered_correctly + RESPONSE_FALSE_ID) * mask + (1 - mask) * updated_d_input_ids[:, current_pos + 1:current_pos + 2]

                updated_d_input_ids = tf.concat([updated_d_input_ids[:, :current_pos+1], predicted_next_response_id[:, :1], updated_d_input_ids[:, current_pos+2:]], axis=1)

                updated_d_input_ids = tf.reshape(updated_d_input_ids, shape=[batch_size, window_size])
            # ------------------------------------------------------------------------------------------

        return preds, logits

### Check

In [None]:
content_input_ids = tf.constant(1, shape=[3, 5])
response_input_ids = tf.constant(1, shape=[3, 5])
decoder_input_ids = tf.constant(1, shape=[3, 5])
tag_ids = tf.constant(0, shape=[3, 5, N_TAGS_PER_CONTENT])
part_ids = tf.constant(1, shape=[3, 5])
prior_explanation_ids = tf.constant(1, shape=[3, 5])
prior_question_elapsed_time_inputs = tf.constant(1.0, shape=[3, 5])

inputs = {
    'c_input_ids': content_input_ids,
    'r_input_ids': response_input_ids,
    'd_input_ids': decoder_input_ids,
    'tag_ids': tag_ids,
    'part_ids': part_ids,
    'prior_explanation_ids': prior_explanation_ids,
    'prior_question_elapsed_time_input': prior_question_elapsed_time_inputs
}

for model_type in ['cr', 'ed']:
    config = EdFormerConfig(model_type=model_type, model_desc='dummy', share_position_embeddings=True, use_tags=True, use_part=True, use_prior_explanation=True, use_prior_question_elapsed_time_input=True)
    predictor = TFEdFormerAnswerPredictionModel(config)
    logits, c_outputs, r_outputs = predictor(inputs=inputs, output_attentions=True, output_hidden_states=True, return_dict=False)
    print(logits)

## Training Manager<a id='training-manager'></a>

### Learning rate with warmup

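Two learning rate schedules are provided, `LinearLR` and `NoamLR`, both with linear warmup. After warmup, `LinearLR` decays linearly to `end_lr`, while `NoamLR` decays in proportion to the inverse square root of the step and is floored at `end_lr`.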
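
As a quick sanity check, the sketch below restates both schedules for plain Python step counts, without TensorFlow. The function names and the default `warmup_steps=1000` are assumptions for illustration; the formulas follow the `LinearLR` and `NoamLR` classes in this section. In the Noam formula the hidden size cancels out, so the rate at the end of warmup equals `lr` exactly.

```python
def linear_lr(step, total_steps, lr=1e-4, end_lr=1e-6, warmup_steps=1000):
    """Linear warmup to `lr`, then linear decay, floored at `end_lr`."""
    if step < warmup_steps:
        return lr * (step + 1) / warmup_steps
    decay_steps = max(total_steps - warmup_steps, 1)
    decayed = lr - (lr - end_lr) / decay_steps * (step - warmup_steps + 1)
    return max(end_lr, decayed)


def noam_lr(step, lr, end_lr, warmup_steps):
    """Linear warmup to `lr`, then inverse-square-root decay, floored at `end_lr`."""
    # Equivalent to scaling * hidden_size**-0.5 * min(...) once hidden_size cancels.
    value = lr * min((step + 1) / warmup_steps, (warmup_steps / (step + 1)) ** 0.5)
    if step < warmup_steps:
        return value
    return max(end_lr, value)


for step in (0, 499, 999, 3999, 9999):
    print(step, linear_lr(step, total_steps=10000), noam_lr(step, 1e-4, 1e-6, 1000))
```

Both schedules reach `lr` at step `warmup_steps - 1`; they differ only in how quickly the rate falls afterwards.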

In [None]:
class LinearLR(tf.keras.optimizers.schedules.LearningRateSchedule):
    
    def __init__(self, total_steps, lr=1e-4, end_lr=1e-6, warmup_steps=WARMUP_STEPS):
        
        self.total_steps = tf.cast(total_steps, dtype=DTYPE)
        self.lr = lr
        self.end_lr = end_lr
        self.warmup_steps = tf.cast(warmup_steps, dtype=DTYPE)
        
    def __call__(self, step):
        
        is_warmup = tf.cast(step < self.warmup_steps, dtype=DTYPE)
        
        warmup_lr = is_warmup * self.lr * (step + 1) / self.warmup_steps 
        decay_lr = (1 - is_warmup) * (self.lr - (self.lr - self.end_lr) / tf.math.maximum(self.total_steps - self.warmup_steps, 1) * (step - self.warmup_steps + 1))
        decay_lr = (1 - is_warmup) * tf.math.maximum(self.end_lr, decay_lr)

        lr = warmup_lr + decay_lr

        return lr

class NoamLR(tf.keras.optimizers.schedules.LearningRateSchedule):

    def __init__(self, hidden_size, lr, end_lr, warmup_steps):

        self.lr = tf.cast(lr, DTYPE)
        self.end_lr = tf.cast(end_lr, DTYPE)
        self.warmup_steps = tf.cast(warmup_steps, DTYPE)
        self.hidden_size = tf.cast(hidden_size, DTYPE)

    def __call__(self, step):

        scaling = self.lr / (self.hidden_size**-0.5 * self.warmup_steps**-0.5)

        lr = scaling * self.hidden_size**-0.5 * tf.math.minimum((step + 1) * self.warmup_steps**-1.5, (step + 1)**-0.5)

        is_warmup = tf.cast(step < self.warmup_steps, dtype=DTYPE)

        lr = lr * is_warmup + (1.0 - is_warmup) * tf.math.maximum(self.end_lr, lr) 

        return lr

### Model input signatures

In [None]:
def get_input_signatures(batch_size, seq_len, valid=False):

    input_signatures = {
        'user_id': tf.TensorSpec(shape=[batch_size], dtype=tf.int64, name='user_id'),
        'seq_len': tf.TensorSpec(shape=[batch_size], dtype=tf.int32, name='seq_len'),
        'prev_seq_len': tf.TensorSpec(shape=[batch_size], dtype=tf.int32, name='prev_seq_len'),
        'start': tf.TensorSpec(shape=[batch_size], dtype=tf.int32, name='start'),
        'end': tf.TensorSpec(shape=[batch_size], dtype=tf.int32, name='end'),
        'row_id': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int64, name='row_id'),
        'timestamp': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int64, name='timestamp'),
        'content_id': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='content_id'),
        'content_type_id': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='content_type_id'),
        'task_container_id': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='task_container_id'),
        'user_answer': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='user_answer'),
        'answered_correctly': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='answered_correctly'),
        'shifted_answered_correctly': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='shifted_answered_correctly'),
        'prior_question_elapsed_time': tf.TensorSpec(shape=[batch_size, seq_len], dtype=DTYPE, name='prior_question_elapsed_time'),
        'prior_question_elapsed_time_input': tf.TensorSpec(shape=[batch_size, seq_len], dtype=DTYPE, name='prior_question_elapsed_time_input'),
        'prior_question_had_explanation': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='prior_question_had_explanation'),
        'pred_time_mask': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='pred_time_mask'),
        'abs_pos': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='abs_pos'),
        'shifted_abs_pos': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='shifted_abs_pos'),
        'pos_ids': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='pos_ids'),
        'shifted_pos_ids': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='shifted_pos_ids'),                                  
        'c_input_ids': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='c_input_ids'),
        'r_input_ids': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='r_input_ids'),
        'd_input_ids': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='d_input_ids'),
        'tag_ids': tf.TensorSpec(shape=[batch_size, seq_len, N_TAGS_PER_CONTENT], dtype=tf.int32, name='tag_ids'),
        'part_ids': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='part_ids'),                 
        'prior_explanation_ids': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='prior_explanation_ids'),                                      
        'target': tf.TensorSpec(shape=[batch_size, seq_len], dtype=tf.int32, name='target'),
        'nb_pred_places': tf.TensorSpec(shape=[batch_size], dtype=tf.int32, name='nb_pred_places')
    }

    if valid:

        input_signatures['n_valid_blocks'] = tf.TensorSpec(shape=[batch_size], dtype=tf.int32, name='n_valid_blocks')
        input_signatures['valid_block_idx'] = tf.TensorSpec(shape=[batch_size], dtype=tf.int32, name='valid_block_idx')
        input_signatures['valid_start'] = tf.TensorSpec(shape=[batch_size], dtype=tf.int32, name='valid_start')
        input_signatures['valid_end'] = tf.TensorSpec(shape=[batch_size], dtype=tf.int32, name='valid_end')

    return [input_signatures]

### Training Manager<a id='training-manager-main'></a>
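
`TrainConfig` derives the number of steps and TPU calls from the example counts. The sketch below reproduces that arithmetic with made-up numbers, assuming the same rounding rules as the class; the function names and the `drop_remainder` flag are hypothetical (in `TrainConfig`, the remainder is dropped only on Kaggle when a TPU is in use).

```python
def training_steps(n_examples, batch_size, steps_per_call, n_epochs):
    """Round the steps per epoch down to a whole number of TPU calls."""
    steps_per_epoch = n_examples // batch_size
    calls_per_epoch = steps_per_epoch // steps_per_call
    steps_per_epoch = calls_per_epoch * steps_per_call
    return calls_per_epoch, steps_per_epoch, n_epochs * steps_per_epoch


def validation_calls(n_valid_examples, pred_batch_size, steps_per_call, drop_remainder=False):
    """Round validation steps up, so a final partial batch or call is still run."""
    if drop_remainder:
        n_valid_examples -= n_valid_examples % pred_batch_size
    n_steps = n_valid_examples // pred_batch_size + int(n_valid_examples % pred_batch_size > 0)
    n_steps_in_last_call = n_steps % steps_per_call
    n_calls = n_steps // steps_per_call + int(n_steps_in_last_call > 0)
    return n_steps, n_steps_in_last_call, n_calls


print(training_steps(10_000, batch_size=64, steps_per_call=8, n_epochs=3))
print(validation_calls(1_000, pred_batch_size=64, steps_per_call=8))
print(validation_calls(1_000, pred_batch_size=64, steps_per_call=8, drop_remainder=True))
```

Training rounds down because the dataset repeats indefinitely, so dropping a few steps loses nothing. Validation rounds up so that every example is scored, and `n_steps_in_last_call` gives the number of steps in the final, possibly shorter, call (0 means the last call is full).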

In [None]:
class TrainConfig:

    def __init__(
        self,
        ckpt_path=CKPT_TRAIN_PATH,
        window_size=WINDOW_SIZE,
        loss_weight_window_size=LOSS_WEIGHT_WINDOW_SIZE,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        pred_batch_size=PRED_BATCH_SIZE,
        shuffle_buf_size=SHUFFLE_BUFFER_SIZE,
        seed=SEED,
        deterministic=DETERMINISTIC,
        num_parallel_reads=N_PARALLEL_READS,
        num_parallel_calls=N_PARALLEL_CALLS,
        steps_per_call=STEPS_PER_CALL
    ):

        self.ckpt_path = ckpt_path
        self.window_size = window_size
        self.loss_weight_window_size = loss_weight_window_size
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.pred_batch_size = pred_batch_size
        self.shuffle_buf_size = shuffle_buf_size
        self.seed=seed
        self.deterministic = deterministic
        self.num_parallel_reads = num_parallel_reads
        self.num_parallel_calls = num_parallel_calls
        self.steps_per_call = steps_per_call

        n_contents_dict = load_data(n_contents_dict_path)
        n_training_examples = 0
        for k, v in n_contents_dict.items():
            n_training_examples += math.ceil(v * 1.0 / self.window_size)
        del n_contents_dict

        valid_info = load_data(valid_info_path)
        n_valid_examples = 0
        for k, v in valid_info.items():
            n_valid_examples += v['bundle_info']['n_blocks']
        del valid_info

        # This is slightly higher than the actual number of examples used for training.
        self.n_training_examples = n_training_examples
        n_training_steps_per_epoch = self.n_training_examples // self.batch_size
        self.n_training_calls_per_epoch = n_training_steps_per_epoch // self.steps_per_call
        self.n_training_steps_per_epoch = self.n_training_calls_per_epoch * self.steps_per_call
        self.n_training_steps = self.n_epochs * self.n_training_steps_per_epoch

        self.n_valid_examples = n_valid_examples
        if IS_KAGGLE and tpu is not None:
            self.n_valid_examples -= self.n_valid_examples % self.pred_batch_size
        self.n_valid_steps = self.n_valid_examples // self.pred_batch_size + int(self.n_valid_examples % self.pred_batch_size > 0)
        self.n_steps_in_last_valid_call = self.n_valid_steps % self.steps_per_call
        self.n_valid_calls = self.n_valid_steps // self.steps_per_call + int(self.n_steps_in_last_valid_call > 0)

        self.lr = None
        self.end_lr = None
        self.warmup_steps = None

    def toJSON(self):

        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=False, indent=4)

class Train_Manager:
    
    def __init__(self, config, train_config):
        
        self.config = config
        self.train_config = train_config

        self.train_ds = strategy.experimental_distribute_dataset(self.get_train_ds())
        self.valid_ds = strategy.experimental_distribute_dataset(self.get_valid_ds())
                
    def toJSON(self):

        config = json.loads(self.config.toJSON())
        train_config = json.loads(self.train_config.toJSON())

        for k, v in train_config.items():
            config[k] = v

        return json.dumps(config, sort_keys=False, ensure_ascii=False, indent=4)

    def get_train_ds(self):

        train_raw_ds = tf.data.TFRecordDataset(train_tfrec_paths, num_parallel_reads=self.train_config.num_parallel_reads)
        train_raw_ds = train_raw_ds.map(parse_train_example, num_parallel_calls=self.train_config.num_parallel_calls, deterministic=self.train_config.deterministic)


        # Get the training part of the train/valid split.
        train_valid_split_indices = load_data(train_valid_split_indices_path)            
        train_valid_split_table = convert_split_index_dict(train_valid_split_indices)
        reduced_raw_ds = split_train_ds(train_raw_ds, train_valid_split_table, num_parallel_calls=self.train_config.num_parallel_calls, deterministic=self.train_config.deterministic)

        # Repeat each user ceil(seq_len / window_size) times, so that users with longer histories contribute more windows.
        # This is used to avoid overfitting on users with much shorter history sequences.
        reduced_raw_ds_augmented = reduced_raw_ds.flat_map(
            lambda x: tf.data.Dataset.from_tensors(x).repeat(
                tf.cast(
                    tf.math.ceil(tf.cast(x['seq_len'], tf.float32) / self.train_config.window_size), dtype=tf.int64
                )
            )
        )

        # batch examples with attributes having different lengths across examples - tf.RaggedTensor
        batched_raw_ds = reduced_raw_ds_augmented.repeat().shuffle(self.train_config.shuffle_buf_size, seed=self.train_config.seed).apply(tf.data.experimental.dense_to_ragged_batch(self.train_config.batch_size))
        
        # batch - tf.Tensor: Extract random subsequences of a fixed length
        batched_ds = random_subseqs_from_batched_raw_ds(batched_raw_ds, window_size=self.train_config.window_size, only_last=False, seed=self.train_config.seed, num_parallel_calls=self.train_config.num_parallel_calls, deterministic=self.train_config.deterministic)
        
        # Add input ids and targets for training        
        train_ds = prepare_training_dataset(batched_ds, generative=self.config.generative, use_abs_pos=self.config.use_abs_pos, num_parallel_calls=self.train_config.num_parallel_calls, deterministic=self.train_config.deterministic)
        
        train_ds = train_ds.prefetch(8)

        return train_ds

    def get_valid_ds(self):

        p = valid_tfrec_paths
        num_parallel_reads = self.train_config.num_parallel_reads
        num_parallel_calls = self.train_config.num_parallel_calls
        deterministic = self.train_config.deterministic

        valid_raw_ds = tf.data.TFRecordDataset(p, num_parallel_reads=num_parallel_reads)
        valid_raw_ds = valid_raw_ds.map(parse_train_example_with_valid_info, num_parallel_calls=num_parallel_calls, deterministic=deterministic)
        valid_ds = prepare_validation_dataset(valid_raw_ds, batch_size=self.train_config.pred_batch_size, window_size=self.train_config.window_size, generative=self.config.generative, use_abs_pos=self.config.use_abs_pos, num_parallel_calls=num_parallel_calls, deterministic=deterministic)
        
        valid_ds = valid_ds.prefetch(8)

        return valid_ds

    def get_train_objs(self, lr=LR, end_lr=END_LR, warmup_steps=WARMUP_STEPS):
        
        self.train_config.lr = lr
        self.train_config.end_lr = end_lr
        self.train_config.warmup_steps = warmup_steps

        with strategy.scope():

            predictor = TFEdFormerAnswerPredictionModel(self.config)

            _lr = NoamLR(hidden_size=self.config.dim, lr=lr, end_lr=end_lr, warmup_steps=warmup_steps)

            # optimizer = tf.keras.optimizers.Adam(learning_rate=_lr, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
            optimizer = tf.keras.optimizers.Adam(
                _lr,
                beta_1=tf.Variable(0.9),
                beta_2=tf.Variable(0.999),
                epsilon=tf.Variable(1e-8)
            )
            optimizer.iterations  # accessing this property creates the optimizer's `iter` variable (stored as `optimizer._iterations`)
            optimizer.decay = tf.Variable(0.0) # Adam.__init__ assumes ``decay`` is a float object, so this needs to be converted to tf.Variable **after** __init__ method.

            loss_obj = tf.keras.losses.BinaryCrossentropy(
                from_logits=True, label_smoothing=0, reduction=tf.keras.losses.Reduction.NONE,
                name='binary_crossentropy'
            )

            loss_metric = tf.keras.metrics.Sum()
            acc_metric = tf.keras.metrics.BinaryAccuracy()
            auc_metric = tf.keras.metrics.AUC(num_thresholds=2000)

            # --------------------------------------------------

            metrics = (loss_metric, acc_metric, auc_metric)
        
        return (predictor, optimizer, loss_obj, metrics)

    def train_valid(self, predictor, optimizer, ckpt_manager, loss_obj, metrics, last_epoch=0, valid_epochs=None, only_valid=False):

        batch_size = self.train_config.batch_size
        seq_len = self.train_config.window_size

        encoder_decoder = tf.constant(self.config.model_type=='ed', dtype=tf.int32)
        generative = tf.constant(self.config.generative, dtype=tf.int32)
        allow_bundle_atten = tf.constant(self.config.allow_bundle_atten, dtype=tf.int32)

        loss_metric, acc_metric, auc_metric = metrics

        @tf.function(
            input_signature=get_input_signatures(int(batch_size / strategy.num_replicas_in_sync), seq_len, valid=False)
        )
        def train_step(train_batch):

            training = tf.constant(1, dtype=tf.int32)
        
            c_mask, r_mask, r_c_mask, c_r_mask = get_attention_masks(train_batch, training, encoder_decoder, generative, allow_bundle_atten)

            targets = train_batch['target']
            pred_mask = targets != NON_TARGET_ID
            pred_indices = tf.where(pred_mask)

            selected_targets = tf.gather_nd(targets, pred_indices)
            
            # Compute the gradients for a list of variables.
            with tf.GradientTape() as tape:

                (logits, c_outputs, r_outputs) = predictor(
                    train_batch,
                    c_mask,
                    r_mask,
                    r_c_mask,
                    c_r_mask,
                    output_attentions=False, output_hidden_states=False, training=True
                )
                selected_logits = tf.gather_nd(logits, pred_indices)

                # Add a trailing axis so that the loss is returned per element instead of being averaged over the last axis
                losses = loss_obj(selected_targets[:, tf.newaxis], selected_logits[:, tf.newaxis])
                total_loss = tf.math.reduce_sum(losses)

                # `train_batch['nb_pred_places'][0]` is the total number of places used for calculating loss across replicas
                loss = total_loss / tf.cast(train_batch['nb_pred_places'][0], dtype=DTYPE)
                            
            grads = tape.gradient(loss, predictor.trainable_variables)

            # Ask the optimizer to apply the processed gradients.
            optimizer.apply_gradients(zip(grads, predictor.trainable_variables))
            
            preds = tf.math.sigmoid(logits)
            selected_preds = tf.math.sigmoid(selected_logits)
            
            loss_metric.update_state(total_loss)
            acc_metric.update_state(selected_targets[:, tf.newaxis], selected_preds[:, tf.newaxis])
            auc_metric.update_state(selected_targets, selected_preds)

        batch_size = None
        @tf.function(
            input_signature=get_input_signatures(batch_size, seq_len, valid=True)
        )
        def valid_step(valid_batch):

            training = tf.constant(0, dtype=tf.int32)

            c_mask, r_mask, r_c_mask, c_r_mask = get_attention_masks(valid_batch, training, encoder_decoder, generative, allow_bundle_atten)
            
            # This only works with CPU / GPU. On TPU, XLA can't compile the graph for this part.
#             if generative == 1:
            
#                start_pos = self.train_config.window_size - MAX_PRED_TIME_QUESTION_BUNDLE_LEN
#                _, logits = predictor.generate(
#                    valid_batch, start_pos=start_pos, window_size=self.train_config.window_size, dim=self.config.dim,
#                    c_mask=c_mask, r_mask=r_mask, r_c_mask=r_c_mask, c_r_mask=c_r_mask
#                )

#             else:

            logits, _, _ = predictor(
                valid_batch,
                c_mask,
                r_mask,
                r_c_mask,
                c_r_mask,
                output_attentions=False, output_hidden_states=False, training=False
            )
            
            # `targets` are defined for all places (other than [PAD] and lectures).
            # However, during validation, unlike during training, we only focus on the places that are in prediction time (and being questions).
            pred_time_mask = valid_batch['pred_time_mask']
            targets = valid_batch['target']

            # only select the last `MAX_PRED_TIME_QUESTION_BUNDLE_LEN` positions
            pred_time_mask = pred_time_mask[:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:]
            targets = targets[:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:]
            logits = logits[:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:]
            row_ids = valid_batch['row_id'][:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:]
            user_ids = valid_batch['user_id'][:, tf.newaxis] * tf.ones_like(row_ids, dtype=tf.int64)

            targets = targets * pred_time_mask + NON_TARGET_ID * (1 - pred_time_mask)

            preds = tf.math.sigmoid(logits)

            pred_mask = targets != NON_TARGET_ID
            pred_indices = tf.where(pred_mask)

            selected_targets = tf.gather_nd(targets, pred_indices)
            selected_logits = tf.gather_nd(logits, pred_indices)
            selected_preds = tf.math.sigmoid(selected_logits)

            # Add a trailing axis so that the loss is returned per element instead of being averaged over the last axis
            losses = loss_obj(selected_targets[:, tf.newaxis], selected_logits[:, tf.newaxis])
            total_loss = tf.math.reduce_sum(losses)

            # The mean of `valid_batch['nb_pred_places']` is the total number of places used for calculating loss
            loss = total_loss / tf.cast(tf.math.reduce_mean(valid_batch['nb_pred_places']), dtype=DTYPE)

            loss_metric.update_state(total_loss)
            acc_metric.update_state(selected_targets[:, tf.newaxis], selected_preds[:, tf.newaxis])
            auc_metric.update_state(selected_targets, selected_preds)

            return targets, preds, row_ids, user_ids, pred_mask

        @tf.function
        def dist_train_step(dist_train_batch):
            
            strategy.run(train_step, args=(dist_train_batch,))
            
        @tf.function
        def dist_valid_step(dist_valid_batch):
            
            targets, preds, row_ids, user_ids, pred_mask = strategy.run(valid_step, args=(dist_valid_batch,))
 
            return targets, preds, row_ids, user_ids, pred_mask

        @tf.function
        def dist_train_multi_steps(dist_train_iter):

            total_nb_pred_places = tf.constant(0.0, dtype=DTYPE)
 
            for _ in tf.range(self.train_config.steps_per_call):

                dist_train_batch = next(dist_train_iter)
                dist_train_step(dist_train_batch)
                
                if type(dist_train_batch['nb_pred_places']) == tf.Tensor:
                    nb_pred_places = tf.math.reduce_mean(dist_train_batch['nb_pred_places'])
                else:  # type(dist_train_batch['nb_pred_places']) == tf.python.distribute.values.PerReplica
                       # `dist_train_batch['nb_pred_places']` is PerReplica --> Use `.values` attribute to access tensors.
                    nb_pred_places = tf.math.reduce_mean(tf.concat(dist_train_batch['nb_pred_places'].values, axis=0))

                total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)

            return total_nb_pred_places

        @tf.function
        def dist_valid_multi_steps(dist_valid_iter):

            total_nb_pred_places = tf.constant(0.0, dtype=DTYPE)

            t_arr_1 = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_2 = tf.TensorArray(DTYPE, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_3 = tf.TensorArray(tf.bool, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_4 = tf.TensorArray(tf.int64, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_5 = tf.TensorArray(tf.int64, size=0, dynamic_size=True, clear_after_read=False)

            for step_idx in tf.range(self.train_config.steps_per_call):

                dist_valid_batch = next(dist_valid_iter)
                _targets, _preds, _row_ids, _user_ids, _pred_mask = dist_valid_step(dist_valid_batch)
                
                if type(dist_valid_batch['nb_pred_places']) == tf.Tensor:
                    nb_pred_places = tf.math.reduce_mean(dist_valid_batch['nb_pred_places'])
                else:

                    _targets = tf.concat(_targets.values, axis=0)
                    _preds = tf.concat(_preds.values, axis=0)
                    _pred_mask = tf.concat(_pred_mask.values, axis=0)

                    _row_ids = tf.concat(_row_ids.values, axis=0)
                    _user_ids = tf.concat(_user_ids.values, axis=0)

                    # `dist_valid_batch['nb_pred_places']` is PerReplica --> Use `.values` attribute to access tensors.
                    nb_pred_places = tf.math.reduce_mean(tf.concat(dist_valid_batch['nb_pred_places'].values, axis=0))

                t_arr_1 = t_arr_1.write(step_idx, _targets)
                t_arr_2 = t_arr_2.write(step_idx, _preds)
                t_arr_3 = t_arr_3.write(step_idx, _pred_mask)
                t_arr_4 = t_arr_4.write(step_idx, _row_ids)
                t_arr_5 = t_arr_5.write(step_idx, _user_ids)                    
                    
                total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)
            
            targets = t_arr_1.concat()
            preds = t_arr_2.concat()
            pred_mask = t_arr_3.concat()
            row_ids = t_arr_4.concat()
            user_ids = t_arr_5.concat()

            return targets, preds, row_ids, user_ids, pred_mask, total_nb_pred_places

        @tf.function
        def dist_valid_multi_steps_last_call(dist_valid_iter):

            total_nb_pred_places = tf.constant(0.0, dtype=DTYPE)

            t_arr_1 = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_2 = tf.TensorArray(DTYPE, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_3 = tf.TensorArray(tf.bool, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_4 = tf.TensorArray(tf.int64, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_5 = tf.TensorArray(tf.int64, size=0, dynamic_size=True, clear_after_read=False)

            for step_idx in tf.range(self.train_config.n_steps_in_last_valid_call):

                dist_valid_batch = next(dist_valid_iter)
                _targets, _preds, _row_ids, _user_ids, _pred_mask = dist_valid_step(dist_valid_batch)
                
                if type(dist_valid_batch['nb_pred_places']) == tf.Tensor:
                    nb_pred_places = tf.math.reduce_mean(dist_valid_batch['nb_pred_places'])
                else:

                    _targets = tf.concat(_targets.values, axis=0)
                    _preds = tf.concat(_preds.values, axis=0)
                    _pred_mask = tf.concat(_pred_mask.values, axis=0)

                    _row_ids = tf.concat(_row_ids.values, axis=0)
                    _user_ids = tf.concat(_user_ids.values, axis=0)

                    nb_pred_places = tf.math.reduce_mean(tf.concat(dist_valid_batch['nb_pred_places'].values, axis=0))

                t_arr_1 = t_arr_1.write(step_idx, _targets)
                t_arr_2 = t_arr_2.write(step_idx, _preds)
                t_arr_3 = t_arr_3.write(step_idx, _pred_mask)
                t_arr_4 = t_arr_4.write(step_idx, _row_ids)
                t_arr_5 = t_arr_5.write(step_idx, _user_ids)                    
                    
                total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)

            targets = t_arr_1.concat()
            preds = t_arr_2.concat()
            pred_mask = t_arr_3.concat()
            row_ids = t_arr_4.concat()
            user_ids = t_arr_5.concat()
            
            return targets, preds, row_ids, user_ids, pred_mask, total_nb_pred_places

        n_training_steps_per_epoch = self.train_config.n_training_steps_per_epoch
        n_training_steps = self.train_config.n_epochs * n_training_steps_per_epoch

        if MAX_TRAIN_ITER_STEPS is not None:
            n_training_steps_per_epoch = MAX_TRAIN_ITER_STEPS
            n_training_steps = self.train_config.n_epochs * n_training_steps_per_epoch
           
        self.train_config.n_training_steps_per_epoch = n_training_steps_per_epoch
        self.train_config.n_training_steps = n_training_steps   

        with open('train_config.json', 'w', encoding='UTF-8') as fp:
            json.dump(json.loads(self.toJSON()), fp, ensure_ascii=False, indent=4)
        if not IS_KAGGLE:
            !gsutil cp -r './train_config.json' '{self.train_config.ckpt_path}'
     
        n_epochs = self.train_config.n_epochs
        n_training_examples = self.train_config.n_training_examples

        print(f'n_epochs: {n_epochs}')
        print(f'n_training_examples: {n_training_examples}')
        print(f'n_training_steps_per_epoch: {n_training_steps_per_epoch}')
        print(f'n_training_steps: {n_training_steps}')
        print(f'n_valid_examples: {self.train_config.n_valid_examples}')
        print(f'n_valid_steps: {self.train_config.n_valid_steps}')        
        
        training_history = dict()

        def train(last_epoch=0, valid_epochs=None):

            start_epoch = last_epoch + 1

            end_epoch = self.train_config.n_epochs + 1
            train_ds = self.train_ds

            dist_train_iter = iter(train_ds)

            for epoch in range(start_epoch, end_epoch):

                training_history[epoch] = {
                    'loss': [],
                    'acc': [],
                    'auc': [],
                    'timing': []        
                }

                n_steps = 0
                n_preds = 0
                n_last_preds = 0

                total_nb_pred_places = 0.0

                start_epoch_t = datetime.datetime.now()

                for call_idx in range(self.train_config.n_training_calls_per_epoch):

                    start_t = datetime.datetime.now()

                    nb_pred_places = dist_train_multi_steps(dist_train_iter)                    
                    n_steps += self.train_config.steps_per_call

                    total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)
                    
                    avg_loss = (loss_metric.result() / total_nb_pred_places).numpy()
                    acc = acc_metric.result().numpy()
                    auc = auc_metric.result().numpy()

                    # --------------------------------------------------  

                    if (call_idx + 1) % max(1, self.train_config.n_training_calls_per_epoch // 100) == 0:

                        training_history[epoch]['loss'].append(float(avg_loss))
                        training_history[epoch]['acc'].append(float(acc))
                        training_history[epoch]['auc'].append(float(auc))

                    # --------------------------------------------------        

                    elapsed = (datetime.datetime.now() - start_t).total_seconds()

                    if (call_idx + 1) % max(1, self.train_config.n_training_calls_per_epoch // 10) == 0:

                        print(f'epoch: {epoch} - step: {n_steps}')
                        print(f'timing per step: {elapsed / self.train_config.steps_per_call}')
                        print(f'loss: {avg_loss}')
                        print(f'acc: {acc}')
                        print(f'auc: {auc}')
                        print(f'lr: {optimizer._decayed_lr(DTYPE)}')        
                        print('-' * 80)

                end_epoch_t = datetime.datetime.now()
                elapsed_epoch_t = (end_epoch_t - start_epoch_t).total_seconds()
                training_history[epoch]['timing'] = elapsed_epoch_t

                training_history[epoch]['train_loss'] = float(avg_loss)
                training_history[epoch]['train_acc'] = float(acc)
                training_history[epoch]['train_auc'] = float(auc)

                # In order to save the checkpoints.
                if IS_KAGGLE and tpu:
                    ckpt_manager.checkpoint.save(
                        file_prefix=ckpt_manager.directory + 'ckpt', options=tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
                    )
                else:
                    ckpt_manager.save()
                
                if not IS_KAGGLE and self.train_config.ckpt_path.startswith('gs://') and tpu is None:
                    !gsutil cp -r './checkpoint' '{self.train_config.ckpt_path}'
                    !gsutil cp -r './ckpt-{epoch}.index' '{self.train_config.ckpt_path}'
                    !gsutil cp -r './ckpt-{epoch}.data-00000-of-00001' '{self.train_config.ckpt_path}'
                    !rm -rf './ckpt-{epoch}.index'
                    !rm -rf './ckpt-{epoch}.data-00000-of-00001'   

                print(f'epoch: {epoch} - train')
                print(f'train loss: {avg_loss}')
                print(f'train acc: {acc}')
                print(f'train auc: {auc}')
                print('-' * 80)

                loss_metric.reset_states()
                acc_metric.reset_states()
                auc_metric.reset_states()

                # --------------------------------------------------
                # saving

                fn = f'training_history-{epoch}.json'
                with open(fn, 'w', encoding='UTF-8') as fp:
                    json.dump(training_history, fp, ensure_ascii=False, indent=4)

                if not IS_KAGGLE:
                    !gsutil cp -r './{fn}' '{self.train_config.ckpt_path}'
                    !rm -rf './{fn}'

                # --------------------------------------------------
                # validation

                if valid_epochs is None:
                    valid_epochs = [self.train_config.n_epochs]
                
                if epoch in valid_epochs:
                    valid(epoch)

        def valid(epoch):

            if epoch not in training_history:

                training_history[epoch] = {
                    'loss': [],
                    'acc': [],
                    'auc': [],
                    'timing': []        
                }                

            valid_user_ids = []
            valid_row_ids = []
            valid_targets = []
            valid_preds = []

            n_steps = 0
            total_nb_pred_places = 0.0

            dist_valid_iter = iter(self.valid_ds)

            start_valid = datetime.datetime.now()

            for call_idx in range(self.train_config.n_valid_calls):

                start_t = datetime.datetime.now()

                last_call = (self.train_config.n_steps_in_last_valid_call > 0) and (call_idx == self.train_config.n_valid_calls - 1)
                
                if last_call:
                    targets, preds, row_ids, user_ids, pred_mask, nb_pred_places = dist_valid_multi_steps_last_call(dist_valid_iter)
                    n_steps += self.train_config.n_steps_in_last_valid_call
                else:
                    targets, preds, row_ids, user_ids, pred_mask, nb_pred_places = dist_valid_multi_steps(dist_valid_iter)            
                    n_steps += self.train_config.steps_per_call
            
                total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)
                
                pred_indices = tf.where(pred_mask)

                selected_targets = tf.gather_nd(targets, pred_indices)
                selected_preds = tf.gather_nd(preds, pred_indices)

                selected_row_ids = tf.gather_nd(row_ids, pred_indices)
                selected_user_ids = tf.gather_nd(user_ids, pred_indices)

                avg_loss = (loss_metric.result() / total_nb_pred_places).numpy()
                acc = acc_metric.result().numpy()
                auc = auc_metric.result().numpy()

                elapsed = (datetime.datetime.now() - start_t).total_seconds()

                if call_idx % max(1, self.train_config.n_valid_calls // 10) == 0:

                    print(f'epoch: {epoch} - valid step: {n_steps}')
                    print(f'valid timing per step: {elapsed / self.train_config.steps_per_call}')
                    print(f'valid loss: {avg_loss}')
                    print(f'valid acc: {acc}')
                    print(f'valid auc: {auc}')
                    print('-' * 80)
                
                valid_row_ids.extend(selected_row_ids.numpy().tolist())
                valid_user_ids.extend(selected_user_ids.numpy().tolist())
                valid_targets.extend(selected_targets.numpy().tolist())
                valid_preds.extend(selected_preds.numpy().tolist())

            end_valid = datetime.datetime.now()
            elapsed_valid = (end_valid - start_valid).total_seconds()
            training_history[epoch]['valid_timing'] = elapsed_valid

            training_history[epoch]['valid_loss'] = float(avg_loss)
            training_history[epoch]['valid_acc'] = float(acc)
            training_history[epoch]['valid_auc'] = float(auc)

            print(f'epoch: {epoch} - valid')
            print(f'valid loss: {avg_loss}')
            print(f'valid acc: {acc}')
            print(f'valid auc: {auc}')
            print('=' * 80)

            loss_metric.reset_states()
            acc_metric.reset_states()
            auc_metric.reset_states()
            
            # --------------------------------------------------

            valid_submission = pd.DataFrame.from_dict(
                {
                    'row_ids': valid_row_ids,
                    'user_ids': valid_user_ids,
                    'targets': valid_targets,
                    'preds': valid_preds
                }
            )
            valid_submission.to_csv(f'valid_submission_epoch_{epoch}.csv', index=False)

            if not IS_KAGGLE:
                !gsutil cp -r './valid_submission_epoch_{epoch}.csv' '{self.train_config.ckpt_path}'                   
                !rm -rf './valid_submission_epoch_{epoch}.csv'

            # --------------------------------------------------

            fn = f'training_history-{epoch}.json'
            if only_valid:
                fn = f'training_history-{epoch}-only-valid.json'
            with open(fn, 'w', encoding='UTF-8') as fp:
                json.dump(training_history, fp, ensure_ascii=False, indent=4)

            if not IS_KAGGLE:
                !gsutil cp -r './{fn}' '{self.train_config.ckpt_path}'
                !rm -rf './{fn}'

            # --------------------------------------------------

        if only_valid:
            valid(last_epoch)
        else:
            train(last_epoch, valid_epochs)

# Train / Valid<a id='train-valid'></a>
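
To make the step bookkeeping used by the training and validation loops concrete, here is a minimal sketch in plain Python. It mirrors the arithmetic in `TrainConfig` for the case where no validation examples are dropped (the Kaggle TPU branch, which truncates the validation set to a multiple of `pred_batch_size`, is not modelled). The function name `count_steps` and the example numbers are hypothetical and only serve as an illustration.

```python
def count_steps(n_training_examples, n_valid_examples, batch_size,
                pred_batch_size, steps_per_call, n_epochs):
    """Illustrative re-implementation of the step/call arithmetic."""
    # Training: drop the partial batch, then drop the steps
    # that do not fill a complete multi-step call.
    steps_per_epoch = n_training_examples // batch_size
    calls_per_epoch = steps_per_epoch // steps_per_call
    steps_per_epoch = calls_per_epoch * steps_per_call

    # Validation: keep the partial batch and the partial last call,
    # because every validation example needs a prediction.
    n_valid_steps = n_valid_examples // pred_batch_size + int(n_valid_examples % pred_batch_size > 0)
    n_steps_in_last_valid_call = n_valid_steps % steps_per_call
    n_valid_calls = n_valid_steps // steps_per_call + int(n_steps_in_last_valid_call > 0)

    return {
        'n_training_calls_per_epoch': calls_per_epoch,
        'n_training_steps_per_epoch': steps_per_epoch,
        'n_training_steps': n_epochs * steps_per_epoch,
        'n_valid_steps': n_valid_steps,
        'n_steps_in_last_valid_call': n_steps_in_last_valid_call,
        'n_valid_calls': n_valid_calls,
    }


# Hypothetical numbers, chosen only to exercise both rounding directions.
stats = count_steps(
    n_training_examples=10_000, n_valid_examples=1_000,
    batch_size=128, pred_batch_size=256, steps_per_call=10, n_epochs=3,
)
print(stats)
```

With these numbers, training rounds down (78 possible steps become 7 calls of 10 steps), while validation rounds up (4 steps, all of them handled by a single shorter last call).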

## A method to make the model initialize its variables

In [None]:
def run_dummy_inputs(predictor):

    @tf.function
    def foo(inputs):

        logits, _, _ = predictor(inputs=inputs)
        return logits

    with strategy.scope():

        content_input_ids = tf.constant(1, shape=[3, 5])
        response_input_ids = tf.constant(1, shape=[3, 5])
        d_input_ids = tf.constant(1, shape=[3, 5])
        tag_ids = tf.constant(0, shape=[3, 5, N_TAGS_PER_CONTENT])
        part_ids = tf.constant(1, shape=[3, 5])
        prior_explanation_ids = tf.constant(1, shape=[3, 5])
        prior_question_elapsed_time_input = tf.constant(1.0, shape=[3, 5])
    
        inputs = {
            'c_input_ids': content_input_ids,
            'r_input_ids': response_input_ids,
            'd_input_ids': d_input_ids,
            'tag_ids': tag_ids,
            'part_ids': part_ids,
            'prior_explanation_ids': prior_explanation_ids,
            'prior_question_elapsed_time_input': prior_question_elapsed_time_input        
        }

        logits = foo(inputs)

## A method to load the checkpoints
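
The checkpoint selection logic can be looked at separately from TensorFlow. The sketch below is a plain Python illustration, assuming checkpoint prefixes of the form `<directory>/ckpt-<epoch>`. The helper names `parse_ckpt_no` and `select_ckpt_no` are hypothetical and do not exist in the notebook; they only show how the epoch number is recovered from a path and how `ckpt_no=-1` falls back to the latest checkpoint.

```python
def parse_ckpt_no(latest_checkpoint):
    """Return the epoch number encoded in a checkpoint prefix, or 0 if there is none."""
    if not latest_checkpoint:
        return 0
    parts = latest_checkpoint.split('/ckpt-')
    if len(parts) < 2:
        # The prefix does not contain '/ckpt-', so no epoch number can be read.
        return 0
    return int(parts[-1])


def select_ckpt_no(last_ckpt_no, ckpt_no=-1):
    """Pick the checkpoint to load: -1 means 'the latest one'."""
    ckpt_no_to_load = last_ckpt_no if ckpt_no == -1 else ckpt_no
    if ckpt_no_to_load > last_ckpt_no:
        raise ValueError(
            f'ckpt-{ckpt_no_to_load} requested, but only {last_ckpt_no} checkpoint(s) exist.'
        )
    return ckpt_no_to_load


# Hypothetical paths, for illustration only.
latest = parse_ckpt_no('./ckpt-6')
print(latest, select_ckpt_no(latest), select_ckpt_no(latest, ckpt_no=2))
```

A returned value of `0` means that nothing is restored and training starts from the first epoch.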

In [None]:
def load_ckpt(predictor, optimizer, ckpt_path, ckpt_no=-1):

    if ckpt_path.startswith('gs://') and tpu is None:

        assert ckpt_no >= 1

        f1 = f'{ckpt_path}checkpoint'
        f2 = f'{ckpt_path}ckpt-{ckpt_no}.index'
        f3 = f'{ckpt_path}ckpt-{ckpt_no}.data-00000-of-00001'

        !gsutil cp -r {f1} './'
        !gsutil cp -r {f2} './'
        !gsutil cp -r {f3} './'

        ckpt_path = './'

    with strategy.scope():

        # ----------------------------------------------------------------------
        # Init model's weight if necessary

        run_dummy_inputs(predictor)        

        # ----------------------------------------------------------------------

        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=predictor)
        ckpt_manager = tf.train.CheckpointManager(checkpoint, directory=ckpt_path, max_to_keep=None)

        last_ckpt_no = 0
        if ckpt_manager.latest_checkpoint:
            if len(ckpt_manager.latest_checkpoint.split('/ckpt-')) > 1:
                last_ckpt_no = int(ckpt_manager.latest_checkpoint.split('/ckpt-')[-1])
                
        if last_ckpt_no > 0:
            print(f'Latest checkpoint found. Model trained for {last_ckpt_no} epochs.')
        else:
            print("No checkpoint found.")

        if ckpt_no == -1:
            ckpt_no_to_load = last_ckpt_no
        else:
            ckpt_no_to_load = ckpt_no

        assert ckpt_no_to_load <= last_ckpt_no

        ckpt_file = os.path.join(ckpt_path, f'ckpt-{ckpt_no_to_load}')

        # ----------------------------------------------------------------------
        # Load ckpt

        print(f'try to load {ckpt_file}')
        if ckpt_no_to_load > 0:
            status = checkpoint.restore(ckpt_file)
            print(f'ckpt-{ckpt_no_to_load} is restored.')
            loaded_ckpt_no = ckpt_no_to_load
        else:
            print('no ckpt is found and restored.')
            loaded_ckpt_no = 0

    return ckpt_manager, loaded_ckpt_no

### Train / Valid / Pred configurations
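
Besides the input length, `window_size` also determines how often each user appears in the training dataset: `get_train_ds` repeats a user `ceil(seq_len / window_size)` times before random subsequences are drawn. The following plain Python sketch reproduces that count. The function name `n_repeats` and the example history lengths are hypothetical.

```python
import math


def n_repeats(seq_len, window_size):
    """Number of times a user with `seq_len` interactions is repeated."""
    return math.ceil(seq_len / window_size)


# Hypothetical history lengths with a window of 128 interactions.
for seq_len in (50, 128, 300, 1000):
    print(seq_len, n_repeats(seq_len, window_size=128))
```

A user whose history fits into one window is seen once per pass, while a user with 1000 interactions is seen 8 times, each time with a different random window.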

In [None]:
max_position_embeddings = MAX_HISTORY_LEN
if not USE_ABS_POS:
    # `0` is used for padding, the real position is from `1` to `WINDOW_SIZE`.
    max_position_embeddings = WINDOW_SIZE + 1

config = EdFormerConfig(
    model_type=MODEL_TYPE,
    model_desc=MODEL_DESC,
    model_size=MODEL_SIZE,
    content_vocab_size=CONTENT_VOCAB_SIZE,
    response_vocab_size=RESPONSE_VOCAB_SIZE,
    tag_vocab_size=TAG_VOCAB_SIZE,
    part_vocab_size=PART_VOCAB_SIZE,
    prior_explanation_vocab_size=PRIOR_EXPLANATION_VOCAB_SIZE,
    use_prior_question_elapsed_time_input=USE_PRIOR_QUESTION_ELAPSED_TIME_INPUT,
    max_position_embeddings=max_position_embeddings,
    sinusoidal_pos_embds=False,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    dim=DIM,
    hidden_dim=HIDDEN_DIM,
    activation=ACTIVATION,        
    dropout=0.1,
    attention_dropout=0.1,
    seq2seq_dropout=0.1,
    initializer_range=0.02,
    seed=SEED,
    pad_token_id=PAD_ID,
    use_abs_pos=USE_ABS_POS,
    share_position_embeddings=SHARE_POS_EMBEDDING,
    use_tags=USE_TAGS,
    use_part=USE_PART,
    use_prior_explanation=USE_PRIOR_EXPLANATION,
    allow_bundle_atten=ALLOW_BUNDLE_ATTEN,
    generative=GENERATIVE
)

train_config = TrainConfig(
    ckpt_path=CKPT_TRAIN_PATH,
    window_size=WINDOW_SIZE,
    loss_weight_window_size=LOSS_WEIGHT_WINDOW_SIZE,
    n_epochs=N_EPOCHS,
    shuffle_buf_size=SHUFFLE_BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    pred_batch_size=PRED_BATCH_SIZE,
    seed=SEED,
    deterministic=DETERMINISTIC,
    num_parallel_reads=N_PARALLEL_READS,
    num_parallel_calls=N_PARALLEL_CALLS,
    steps_per_call=STEPS_PER_CALL
)

### Now comes the training!

In [None]:
if TRAIN:
    
    only_valid = False
    ckpts = [0]
    valid_epochs = [3, 6]

    train_manager = Train_Manager(config, train_config)
    predictor, optimizer, loss_obj, metrics = train_manager.get_train_objs(lr=LR, end_lr=END_LR, warmup_steps=WARMUP_STEPS)
    
    # ----------------------------------------------------------------------------------------------------

    run_dummy_inputs(predictor)

    print('-' * 40)
    print('trainable variables:')
    for v in predictor.trainable_variables:
        print(v.name)
    print('-' * 40)

    for ckpt_no in ckpts:

        # ----------------------------------------------------------------------------------------------------

        if RESUME_TRAINING or only_valid:

            # reload ckpt
            ckpt_manager, loaded_ckpt_no = load_ckpt(predictor, optimizer, CKPT_TRAIN_PATH, ckpt_no=ckpt_no)

        else:

            with strategy.scope():

                loaded_ckpt_no = 0
                checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=predictor)
                ckpt_path = CKPT_TRAIN_PATH
                if ckpt_path.startswith('gs://') and tpu is None:
                    ckpt_path = './'
                ckpt_manager = tf.train.CheckpointManager(checkpoint, directory=ckpt_path, max_to_keep=None)

        train_manager.train_valid(predictor, optimizer, ckpt_manager, loss_obj, metrics, last_epoch=loaded_ckpt_no, valid_epochs=valid_epochs, only_valid=only_valid)
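
The checkpoint setup in the cell above can be tried in isolation. The following is a minimal sketch, not the notebook's actual code: `resolve_ckpt_dir` and its `on_tpu` flag are hypothetical names that mirror the `gs://` fallback, the bucket path is made up, and a single `tf.Variable` stands in for the model and optimizer state.

```python
import tempfile

import tensorflow as tf


def resolve_ckpt_dir(ckpt_path, on_tpu):
    """Fall back to the working directory when a GCS path is given without a TPU."""
    if ckpt_path.startswith("gs://") and not on_tpu:
        return "./"
    return ckpt_path


# Hypothetical bucket name, used only to show the fallback.
print(resolve_ckpt_dir("gs://some-bucket/ckpt", on_tpu=False))

# A single variable stands in for the model and optimizer state (assumption).
step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(step=step)

# max_to_keep=None keeps every checkpoint, as in the cell above.
ckpt_dir = tempfile.mkdtemp()
manager = tf.train.CheckpointManager(checkpoint, directory=ckpt_dir, max_to_keep=None)

step.assign(3)
saved_path = manager.save(checkpoint_number=3)  # prefix ends with "ckpt-3"

# Overwrite the value, then restore it from the saved checkpoint.
step.assign(0)
checkpoint.restore(saved_path).assert_consumed()
print(saved_path, int(step.numpy()))
```

Passing `checkpoint_number` explicitly keeps the file suffix aligned with the epoch number, which makes it easier to pick a specific checkpoint to reload later.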

In [None]:
CKPT_TRAIN_PATH

In [None]:
!ls -l

# Conclusion

1. We compared the usage of single/multiple steps per TPU call. From [Single step per TPU call](#https://www.kaggle.com/yihdarshieh/r3id-tf-tpu/output?scriptVersionId=49071414) and [1000 steps per TPU call](#https://www.kaggle.com/yihdarshieh/r3id-tf-tpu/output?scriptVersionId=49075674), we see that:

    timing per step (with batch sizes 128/256):
        * single step per TPU call:
              training: 178 ms
            validation:  93 ms
        * 1000 steps per TPU call:
              training:  32 ms
            validation:  65 ms

Running multiple steps per TPU call clearly speeds up training (from 178 ms to 32 ms per step, about 5.5x faster) and also speeds up validation somewhat (from 93 ms to 65 ms). Since `PRED_BATCH_SIZE = 256` is twice as large as `BATCH_SIZE = 128`, the training and validation timings with `1000 steps per TPU call` would be almost the same at equal batch size. It is not clear to me why the gain is smaller for validation than for training; the overhead of accumulating the predictions is a possible cause. This version uses larger batch sizes (512/2048), which should reduce the timing further (once normalized to the same batch size for comparison).
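
The arithmetic behind this comparison can be checked with a few lines of Python. This is only a sketch that reuses the timings and batch sizes quoted above; the helper names `speedup` and `ms_per_128_samples` are made up for illustration.

```python
# Per-step timings in milliseconds, copied from the table above.
timings_ms = {
    "single": {"training": 178, "validation": 93},
    "multi": {"training": 32, "validation": 65},
}
# Batch sizes used for those runs: BATCH_SIZE = 128, PRED_BATCH_SIZE = 256.
batch_sizes = {"training": 128, "validation": 256}


def speedup(phase):
    """Ratio of single-step to multi-step time per step for one phase."""
    return timings_ms["single"][phase] / timings_ms["multi"][phase]


def ms_per_128_samples(mode, phase):
    """Step time rescaled to a batch of 128 samples, assuming linear scaling."""
    return timings_ms[mode][phase] * 128 / batch_sizes[phase]


print(f"training speed-up:   {speedup('training'):.2f}x")
print(f"validation speed-up: {speedup('validation'):.2f}x")
print(f"multi-step training per 128 samples:   {ms_per_128_samples('multi', 'training'):.1f} ms")
print(f"multi-step validation per 128 samples: {ms_per_128_samples('multi', 'validation'):.1f} ms")
```

The linear rescaling to 128 samples is itself an assumption; it is what makes the multi-step training and validation timings look nearly identical.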

2. We see how to save checkpoints locally on Kaggle via `tf.train.CheckpointManager`.

3. We learn how to accumulate the validation predictions while running multiple steps per TPU call, using `tf.TensorArray`.

4. We demonstrate an input pipeline built on `tf.RaggedTensor`, with a series of dataset transformations that produce the training/validation dataset formats specific to this competition.

5. Several model options are provided, in particular:

   * encoder only
   * encoder decoder
       - during inference, treat each question in a bundle as if it were the only question in that bundle
       - auto-regressive prediction generation

6. The notebook works on both Kaggle and Colab (with minimal changes).
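
To make point 3 concrete, here is a minimal sketch of collecting per-step predictions with `tf.TensorArray` inside a single `tf.function` call. It is not the notebook's training code: it runs without any distribution strategy, the name `predict_n_steps` is hypothetical, and a mean over the last axis stands in for the model's forward pass.

```python
import tensorflow as tf


@tf.function
def predict_n_steps(batches):
    """Run several prediction steps in one traced call and gather the outputs.

    `batches` has shape [n_steps, batch_size, n_features]. The mean over the
    last axis is a hypothetical stand-in for a real model's forward pass.
    """
    n_steps = tf.shape(batches)[0]
    # One slot per step; each slot will hold a [batch_size] tensor.
    preds = tf.TensorArray(tf.float32, size=n_steps)
    for step in tf.range(n_steps):
        pred = tf.reduce_mean(batches[step], axis=-1)
        # write() returns a new TensorArray object that must be reassigned.
        preds = preds.write(step, pred)
    # Stack to [n_steps, batch_size], then flatten to one prediction vector.
    return tf.reshape(preds.stack(), [-1])


batches = tf.reshape(tf.range(12, dtype=tf.float32), [3, 2, 2])
all_preds = predict_n_steps(batches)
print(all_preds.numpy())
```

A plain Python list would not work here, because the loop over `tf.range` is traced into a graph loop; `tf.TensorArray` is the structure that can be carried across its iterations.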

Hope you enjoy it. Good luck in the competition!