# About this notebook

# Environment

In [None]:
import datetime

# Wall-clock time (naive, local) at which this cell ran.
kernel_start = datetime.datetime.now()

print(f'kernel_start: {kernel_start}')

In [None]:
import os

# The presence of `/kaggle/input` is used to decide whether we are running on Kaggle.
IS_KAGGLE = os.path.isdir('/kaggle/input')


# It is unclear why the model is not running on GPU when using tensorflow 2.3.1.
# If a GPU is detected, the upgrade to tensorflow 2.3.1 in a later cell is skipped.
# `!nvidia-smi` captures the command's output as a list of lines.
# NOTE(review): `gpu_info[0]` raises `IndexError` if the command prints nothing.
gpu_info = !nvidia-smi
if 'command not found' in gpu_info[0]:
    USE_GPU = False
elif 'GPU  Name ' in str(gpu_info):
    # Presumably matches the header row of the `nvidia-smi` table; depends on its exact spacing.
    USE_GPU = True
else:
    USE_GPU = False

print(f'IS_KAGGLE: {IS_KAGGLE}')        
print(f'USE_GPU: {USE_GPU}')

In [None]:
# `cloud-tpu-client` is removed in every environment before TensorFlow is (re)installed.
!pip uninstall -y cloud-tpu-client

if not IS_KAGGLE:

    # Outside Kaggle: pin TensorFlow and `tensorflow-gcs-config` to 2.3.0.
    !pip uninstall -y tensorflow
    !pip install --upgrade tensorflow==2.3.0
    !pip uninstall -y tensorflow-gcs-config
    !pip install --upgrade tensorflow-gcs-config==2.3.0
    

elif IS_KAGGLE and not USE_GPU:
    # Kaggle TPU
    # (this branch is also taken on a Kaggle kernel with neither TPU nor GPU)
    
    !pip uninstall -y tensorflow
    !pip install --upgrade tensorflow==2.3.1

# On Kaggle with a GPU neither branch runs, so the preinstalled TensorFlow is kept.
import tensorflow as tf
from tensorflow.python.tpu.client import client

print(f'TensorFlow: {tf.__version__}')
print(client.Client)

In [None]:
# Root directory under which every dataset used by this notebook is looked up.
if IS_KAGGLE:
    BASE_DIR = '/kaggle/input'
else:
    # NOTE: `BUCKET_DIR` and `BASE_DIR_QUOTED` are only defined outside Kaggle.
    BUCKET_DIR = 'gs://shieh-tpu/r3id'
    BASE_DIR = '/content/drive/My Drive/r3id'
    # Same path wrapped in double quotes, for use in shell commands (the path contains a space).
    BASE_DIR_QUOTED = '"/content/drive/My Drive/r3id"'

if not IS_KAGGLE:

    # Access GCP Bucket
    from google.colab import auth
    auth.authenticate_user()
    project_id = 'shieh-tpu'
    !gcloud config set project {project_id}
    !gsutil ls {BUCKET_DIR}

    # Access Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # When the Drive copy does not exist yet, create it from the bucket.
    if not os.path.isdir(BASE_DIR):

        !mkdir {BASE_DIR_QUOTED}
        !gsutil -m cp -r 'gs://shieh-tpu/r3id' "/content/drive/My Drive"

# List the dataset root as a quick visual check.
!ls -l '{BASE_DIR}'

In [None]:
# Detect hardware, return appropriate distribution strategy

try:
    # Ask the TPU runtime to run the same version as the installed TensorFlow.
    client.Client().configure_tpu_version(tf.__version__, restart_type='ifNeeded')
    # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    if IS_KAGGLE:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    else:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    print('Running on TPU ', tpu.master())
except (ValueError, KeyError):
    # ValueError: treated as "no TPU available".
    # KeyError: `COLAB_TPU_ADDR` is not set, i.e. the runtime has no TPU attached.
    # In both cases we fall back to the default strategy below.
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
# Use a newer version of huggingface's `tokenizers` (?? -> 0.9.2)    
# Use a newer version of huggingface's `transformers` (3.0.2 -> 3.4.0) 
# Every package below is installed from a wheel file stored under `BASE_DIR/r3id-packages`.
!pip uninstall -y datatable
!pip uninstall -y tokenizers
!pip uninstall -y transformers

if IS_KAGGLE:
    # Kaggle: cp37 wheels.
    !pip install '{BASE_DIR + '/r3id-packages/datatable-0.11.0-cp37-cp37m-manylinux2010_x86_64.whl'}'
    !pip install --upgrade '{BASE_DIR + '/r3id-packages/tokenizers-0.9.2-cp37-cp37m-manylinux1_x86_64.whl'}'

else:
    # Elsewhere: cp36 wheels.
    !pip install '{BASE_DIR + '/r3id-packages/datatable-0.11.0-cp36-cp36m-manylinux2010_x86_64.whl'}'
    !pip install --upgrade '{BASE_DIR + '/r3id-packages/tokenizers-0.9.2-cp36-cp36m-manylinux1_x86_64.whl'}'

# `transformers` ships as a pure-Python wheel, so one file serves both environments.
!pip install --upgrade '{BASE_DIR + '/r3id-packages/transformers-3.4.0-py3-none-any.whl'}'

# Packages

## Install

## Import

In [None]:
import os
import psutil
import h5py
import numpy as np
import pandas as pd
import datatable as dt
import json
import pickle
from collections import defaultdict
import random
import math
import datetime
import collections
os.environ['TF_DETERMINISTIC_OPS'] = '1'
if IS_KAGGLE and USE_GPU:
    import tensorflow_probability as tfp
from copy import deepcopy
import gc
gc.enable()


import transformers
from transformers import PretrainedConfig, DistilBertConfig
from transformers.modeling_tf_distilbert import TFFFN, TFTransformer
from transformers.modeling_tf_distilbert import TFMultiHeadSelfAttention as HFTFMultiHeadSelfAttention
from transformers.modeling_tf_distilbert import TFTransformerBlock as HFTFTransformerBlock
from transformers.modeling_tf_distilbert import TFFFN as HFTFFFN

from transformers.modeling_tf_utils import (
    TFPreTrainedModel,
    keras_serializable,
    shape_list,
)
if IS_KAGGLE:
    import riiideducation
    from kaggle_datasets import KaggleDatasets

# TPU

# Data

In [None]:
# Full training table, stored in datatable's `.jay` format.
train_dt_path_or_obj = f'{BASE_DIR}/r3id-trainjay/train.jay'

# Info files computed on the full training data.
unique_user_id_train_path = f'{BASE_DIR}/r3id-info/unique_user_id_train.json'
unique_question_id_train_path = f'{BASE_DIR}/r3id-info/unique_question_id_train.json'
unique_lecture_id_train_path = f'{BASE_DIR}/r3id-info/unique_lecture_id_train.json'
user_id_to_row_id_train_path = f'{BASE_DIR}/r3id-info/user_id_to_row_id_train.json'

# Info files computed on the split ("splitted") training data.
unique_user_id_splitted_train_path = f'{BASE_DIR}/r3id-info-shieh/unique_user_id_splitted_train.json'
unique_question_id_splitted_train_path = f'{BASE_DIR}/r3id-info-shieh/unique_question_id_splitted_train.json'
unique_lecture_id_splitted_train_path = f'{BASE_DIR}/r3id-info-shieh/unique_lecture_id_splitted_train.json'

n_contents_dict_path = f'{BASE_DIR}/r3id-info-shieh/n_contents_dict.json'

question_tags_info_path = f'{BASE_DIR}/r3id-info-shieh/question_tags_info.json'
lecture_tags_info_path = f'{BASE_DIR}/r3id-info-shieh/lecture_tags_info.json'

question_part_info_path = f'{BASE_DIR}/r3id-info-shieh/question_part_info.json'
lecture_part_info_path = f'{BASE_DIR}/r3id-info-shieh/lecture_part_info.json'

correct_answer_info_path = f'{BASE_DIR}/r3id-info-shieh/correct_answer_info.json'

question_history_path = f'{BASE_DIR}/r3id-info-shieh/question_history.json'
question_history_at_training_end_path = f'{BASE_DIR}/r3id-info-shieh/question_history_at_training_end.json'
single_question_history_at_training_end_optimized_path = f'{BASE_DIR}/r3id-info-shieh/single_question_history_at_training_end.json'

user_performance_hdf5_path = f'{BASE_DIR}/user-performance/user_performance.hdf5'

# ------------------------------------------------------------------------------------------
# this should be replaced with the corresponding files for validation
question_history_for_valid_path = f'{BASE_DIR}/r3id-info-shieh/question_history_fold_1.json'
question_history_at_training_end_for_valid_path = f'{BASE_DIR}/r3id-info-shieh/question_history_at_training_end_fold_1.json'
single_question_history_at_training_end_optimized_for_valid_path = f'{BASE_DIR}/r3id-info-shieh/single_question_history_at_training_end_fold_1.json'

# ------------------------------------------------------------------------------------------

# One entry per validation fold; index 0 is the file without a `_fold_N` suffix.
valid_info_paths = [
    f'{BASE_DIR}/r3id-info-shieh/valid_info.json',
    f'{BASE_DIR}/r3id-info-shieh/valid_info_fold_1.json',
    f'{BASE_DIR}/r3id-info-shieh/valid_info_fold_2.json',
    f'{BASE_DIR}/r3id-info-shieh/valid_info_fold_3.json',
]

# Same layout as `valid_info_paths`.
train_valid_split_indices_paths = [
        f'{BASE_DIR}/r3id-info-shieh/train_valid_split_indices.json',
        f'{BASE_DIR}/r3id-info-shieh/train_valid_split_indices_fold_1.json',
        f'{BASE_DIR}/r3id-info-shieh/train_valid_split_indices_fold_2.json',
        f'{BASE_DIR}/r3id-info-shieh/train_valid_split_indices_fold_3.json',
]

# train_tfrec_dir_local = f'{BASE_DIR}/ednet-tfrecords-sequential-more'
# valid_tfrec_dir_local = f'{BASE_DIR}/r3id-tfrecords-valid-more'

# Where the TFRecord files are read from depends on the hardware:
#   - no TPU: the dataset root `BASE_DIR`;
#   - Kaggle TPU: the GCS path reported by `KaggleDatasets`;
#   - other TPU: the project bucket `BUCKET_DIR`.
if tpu is None:
    
    TFREC_DIR = BASE_DIR
    train_tfrec_dir = f'{BASE_DIR}/ednet-tfrecords-sequential-more'
    valid_tfrec_dir = f'{BASE_DIR}/r3id-tfrecords-valid-more'

elif IS_KAGGLE:
    
    # NOTE(review): `TFREC_DIR` is not assigned in this branch - confirm nothing later relies on it.
    train_tfrec_dir = KaggleDatasets().get_gcs_path('ednet-tfrecords-sequential-more')
    valid_tfrec_dir = KaggleDatasets().get_gcs_path('r3id-tfrecords-valid-more')

    print(f'train_tfrec_dir from Kaggle = {train_tfrec_dir}\n')
    print(f'valid_tfrec_di from Kaggle = {valid_tfrec_dir}\n')

    # you can list the buckets
    !gsutil ls $train_tfrec_dir
    !gsutil ls $valid_tfrec_dir
        
else:
    
    TFREC_DIR = BUCKET_DIR
    train_tfrec_dir = f'{BUCKET_DIR}/ednet-tfrecords-sequential-more'
    valid_tfrec_dir = f'{BUCKET_DIR}/r3id-tfrecords-valid-more'

# Load the big file - train.dt - once
train_dt = dt.fread(train_dt_path_or_obj)

# Configuration

## Model / Training settings

In [None]:
# The three values below are only combined into `CKPT_DIR` at the end of this cell.
MODEL_TYPE = 'ed'
MODEL_SIZE = 'b-2'
MODEL_DESC = 'raw'

# Feature / architecture switches.
ACTIVATION = 'relu'
USE_PRE_CLASSIFIER = True
USE_SOFTMAX = False
USE_USER_ANSWER = False
USE_USER_ANSWER_LOSS = False
USE_CORRECT_ANSWER_FOR_ENCODER = False
USE_CORRECT_ANSWER_FOR_DECODER = False
USE_ABS_POS = False
USE_TASK_CONTAINER_POS = False
SHARE_POS_EMBEDDING = True
USE_TAGS = False
USE_PART = False
USE_PRIOR_EXPLANATION = False
USE_PRIOR_QUESTION_ELAPSED_TIME_INPUT = False
USE_LAG_TIME = False
USE_LAG_TIME_FOR_ENCODER = False
USE_USER_LEVEL_AGGREGATED_HISTORICAL_INFO = False
USE_PART_AGGREGATED_HISTORICAL_INFO = False
USE_CORRECT_ANSWER_AGGREGATED_HISTORICAL_INFO = False
USE_QUESTION_LEVEL_AGGREGATED_HISTORICAL_INFO = False
ALLOW_BUNDLE_ATTEN = False
GENERATIVE = False

# Index of the validation fold.
# Presumably indexes `valid_info_paths` / `train_valid_split_indices_paths` - TODO confirm.
VALID_FOLD = 0

WINDOW_SIZE = 128
LOSS_WEIGHT_WINDOW_SIZE = None
# Global batch size: 16 examples per replica, multiplied by 8 when no TPU is used.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
if not tpu:
    BATCH_SIZE = 8 * BATCH_SIZE
PRED_BATCH_SIZE = 256 * strategy.num_replicas_in_sync
# The per-replica batch sizes must differ from `WINDOW_SIZE`.
# NOTE(review): the reason for this constraint is not visible in this cell - confirm and document.
if tpu:
    assert BATCH_SIZE // strategy.num_replicas_in_sync != WINDOW_SIZE
assert PRED_BATCH_SIZE // strategy.num_replicas_in_sync != WINDOW_SIZE
N_EPOCHS = 2
STEPS_PER_CALL = 1000
MAX_N_CONTENTS_PER_USER_FOR_SAMPLING_PROB = 512

# Start and end learning rates are equal here.
LR = 1e-4
END_LR = 1e-4
WARMUP_STEPS = 5000

DETERMINISTIC = False

if DETERMINISTIC:
    
    # Fixed seed, no parallel reads / calls, and a shuffle buffer of 1.
    SEED = 2021
    N_PARALLEL_READS = None  # 1
    N_PARALLEL_CALLS = None  # 1
    SHUFFLE_BUFFER_SIZE = 1

else:
    
    SEED = None
    N_PARALLEL_READS = 16
    N_PARALLEL_CALLS = tf.data.experimental.AUTOTUNE
    SHUFFLE_BUFFER_SIZE = 4096

MAX_TRAIN_ITER_STEPS = None
MAX_VALID_ITER_STEPS = None

PRINTING_STEPS = 1000

# Checkpoint sub-directory derived from the model identity, e.g. `ed-b-2-raw/`.
CKPT_DIR = f'{MODEL_TYPE}-{MODEL_SIZE}-{MODEL_DESC}/'

## Running mode

In [None]:
# Floating point type used by this notebook and by Keras.
DTYPE = tf.float32
tf.keras.backend.set_floatx('float32')

# Which stages to run. These may be overridden below.
TRAIN = True
VALID = False
PRED = False

RESUME_TRAINING = False

N_FILES = 6
SUBMISSION = False

if IS_KAGGLE:
    # A file count other than 6 in the competition input directory is treated as a submission run.
    N_FILES = len(os.listdir('/kaggle/input/riiid-test-answer-prediction'))
    SUBMISSION = (N_FILES != 6)
else:
    PRED = False

# A submission run only predicts.
if SUBMISSION:
    
    TRAIN = False
    VALID = False
    PRED = True

if not TRAIN:
    RESUME_TRAINING = False

DEBUG = True
PROBE = False

# Leave as `None` to derive the checkpoint locations from the environment.
CKPT_TRAIN_PATH = None
CKPT_PRED_PATH = None
if CKPT_TRAIN_PATH is None:
    
    if IS_KAGGLE:
        CKPT_TRAIN_PATH = './'
    else:
        CKPT_TRAIN_PATH = f'{BUCKET_DIR}/r3id-ckpts/{CKPT_DIR}'
        
        if TRAIN and not RESUME_TRAINING:
            # `echo $?` prints the exit status of `gsutil stat`. Only a status of `1` lets the
            # assertion pass - presumably meaning nothing exists under the checkpoint path yet,
            # so a fresh run cannot overwrite earlier checkpoints.
            _state = !gsutil -q stat {CKPT_TRAIN_PATH}*; echo $?
            already_existed = 1 - int(_state[0])
            assert not already_existed       

if CKPT_PRED_PATH is None:
    
    if IS_KAGGLE:
        # On Kaggle the checkpoint is read from `BASE_DIR`.
        CKPT_PRED_PATH = f'{BASE_DIR}/{CKPT_DIR}'
    else:
        CKPT_PRED_PATH = f'{BUCKET_DIR}/r3id-ckpts/{CKPT_DIR}'

In [None]:
# Summary of the resolved running mode and checkpoint locations.
print(f'SUBMISSION: {SUBMISSION}')
print(f'TRAIN: {TRAIN}')
print(f'VALID: {VALID}')
print(f'PRED: {PRED}')

print('')

print(f'CKPT_TRAIN_PATH: {CKPT_TRAIN_PATH}')
print(f'CKPT_PRED_PATH: {CKPT_PRED_PATH}')

## Vocabulary settings

In [None]:
# Raw (pre-lookup) values of the special tokens. All are negative.
PAD_TOKEN = -2
START_TOKEN = -3
END_TOKEN = -4
MASK_TOKEN = -5

# Input ids of the special tokens.
PAD_ID = 0
START_ID = 1
END_ID = 2
MASK_ID = 3

# Raw values of the response (answered_correctly) tokens.
RESPONSE_LECTURE_TOKEN = -1
RESPONSE_FALSE_TOKEN = 0
RESPONSE_TRUE_TOKEN = 1

# Raw values of the answer (user_answer) tokens.
ANSWER_LECTURE_TOKEN = -1
ANSWER_0_TOKEN = 0
ANSWER_1_TOKEN = 1
ANSWER_2_TOKEN = 2
ANSWER_3_TOKEN = 3


# Ids of the response tokens. They start right after the 4 special ids.
RESPONSE_LECTURE_ID = 4
RESPONSE_FALSE_ID = 5
RESPONSE_TRUE_ID = 6

# Ids of the answer tokens. They start right after the 4 special ids.
ANSWER_LECTURE_ID = 4
ANSWER_0_ID = 5
ANSWER_1_ID = 6
ANSWER_2_ID = 7
ANSWER_3_ID = 8

# Presumably the label value marking positions that are not prediction targets - TODO confirm.
NON_TARGET_ID = -100

TAG_VOCAB_SIZE = 189  # Original: [0, 1, ..., 187]. Size is `189` because we will consider `-1` for special vocab.
PART_VOCAB_SIZE = 9  # Original: [1, 2, ..., 7]. We include `0` although this doesn't exist in the dataset. Size is `9` because we will consider `-1` for special vocab.
PRIOR_EXPLANATION_VOCAB_SIZE = 3  # Original: [0, 1]. Size is `3` because we will consider `-1` for special vocab.

# --------------------------------------------------

# Length of the tag list stored for each content.
N_TAGS_PER_CONTENT = 6

# --------------------------------------------------
# To be removed once decided not to use abs_pos
MAX_TRAIN_HISTORY_LEN = 17917
# This is only an assumption (including `0` for padding)
MAX_HISTORY_LEN = 20480
# --------------------------------------------------

# Maybe to probe this info.
# Actually, it is `9999`
MAX_TASK_CONTAINER_ID = 10000

# --------------------------------------------------

N_AGGREGATED_QUESTION_SCALING_FACTOR = 1000

# Probe Info

In [None]:
# Presumably the largest number of questions in one bundle at prediction time - TODO confirm by probing.
MAX_PRED_TIME_QUESTION_BUNDLE_LEN = 10

# Train Manager

## Dataset

### Tables mapping question / lectures / tags ids (etc.) to encoder / decoder input ids

In [None]:
def convert_valid_info(valid_info):
    """Turn the string keys of a decoded `valid_info` JSON object back into integers.

    The user ids (top-level keys) and the keys of each user's
    `bundle_info['block_starting_index_dict']` and
    `bundle_info['bundle_starting_index_dict']` are converted with `int`.

    Entries are removed from `valid_info` as they are converted, so the input
    dictionary is empty when this function returns.

    Args:
        valid_info: dict mapping user id (as `str`) to that user's validation info.

    Returns:
        A new dict mapping user id (as `int`) to the converted per-user info objects.
    """
    index_dict_names = ('block_starting_index_dict', 'bundle_starting_index_dict')

    converted = {}

    # Iterate over a snapshot of the keys because `valid_info` shrinks inside the loop.
    for raw_user_id in list(valid_info):

        info = valid_info[raw_user_id]
        bundle_info = info['bundle_info']

        for dict_name in index_dict_names:
            bundle_info[dict_name] = {int(pos): val for pos, val in bundle_info[dict_name].items()}

        converted[int(raw_user_id)] = info

        del valid_info[raw_user_id]

    return converted

def load_data(path):
    """Load one of the notebook's data files, chosen by the shape of `path`.

    Args:
        path: Path to a `.json` info file or to a `train.jay` datatable file.

    Returns:
        For `.json` files, the decoded JSON content, post-processed as follows:
          - `valid_info.json` / `valid_info_fold_*`: passed through `convert_valid_info`;
          - files keyed by integer ids (see `int_keyed_suffixes` below, plus any path
            containing `train_valid_split_indices`): top-level keys converted to `int`;
          - anything else (e.g. the question history files): returned as decoded.
        For `train.jay`, the result of `dt.fread(path)`.

    Raises:
        ValueError: If `path` ends with neither `.json` nor `train.jay`.
    """

    if path.endswith('.json'):

        with open(path, 'r', encoding='UTF-8') as fp:
            data = json.load(fp)

        if path.endswith('valid_info.json') or 'valid_info_fold_' in path:
            return convert_valid_info(data)

        # JSON object keys are always strings; these files are keyed by integer ids.
        int_keyed_suffixes = (
            'n_contents_dict.json',
            'user_id_to_row_id_train.json',
            'question_tags_info.json',
            'lecture_tags_info.json',
            'question_part_info.json',
            'lecture_part_info.json',
            'correct_answer_info.json',
        )
        if path.endswith(int_keyed_suffixes) or 'train_valid_split_indices' in path:
            return {int(k): v for k, v in data.items()}

        return data

    elif path.endswith('train.jay'):

        # Read the file that was asked for, not the module-level default path.
        return dt.fread(path)

    else:

        raise ValueError('The path provided is not the one needed to create an instance of `Valid_Manager`!')

In [None]:
# Question / lecture metadata shipped with the competition data.
# NOTE: `questioin_df` is misspelled; it is a module-level name, so it is left unchanged here.
questioin_df = pd.read_csv(f'{BASE_DIR}/riiid-test-answer-prediction/questions.csv')
lecture_df = pd.read_csv(f'{BASE_DIR}/riiid-test-answer-prediction/lectures.csv')

question_ids = questioin_df['question_id'].tolist()
lecture_ids = lecture_df['lecture_id'].tolist()

question_ids = sorted(question_ids)
lecture_ids = sorted(lecture_ids)

# These two checks cannot fail: both lists were sorted just above.
assert question_ids == sorted(question_ids)
assert lecture_ids == sorted(lecture_ids)

# Content vocabulary layout: the 4 special tokens, then all questions, then all lectures.
# The position of an entry in `content_vocab` is its input id.
special_vocab = [PAD_TOKEN, START_TOKEN, END_TOKEN, MASK_TOKEN]
question_vocab = question_ids
lecture_vocab = lecture_ids

content_vocab = special_vocab + question_vocab + lecture_vocab

# Map question ids to encoder's input ids
# Ids that are not in the table map to the largest int32 value.
input_ids = tf.range(len(special_vocab), len(special_vocab) + len(question_vocab))
initializer = tf.lookup.KeyValueTensorInitializer(tf.constant(question_ids, dtype=tf.int32), input_ids)
question_id_to_input_id_table = tf.lookup.StaticHashTable(
    initializer, default_value=tf.int32.limits[1], name=None
)

# Map lecture ids to encoder's input ids
# Ids that are not in the table map to the largest int32 value.
input_ids = tf.range(len(special_vocab) + len(question_vocab), len(content_vocab))
initializer = tf.lookup.KeyValueTensorInitializer(tf.constant(lecture_ids, dtype=tf.int32), input_ids)
lecture_id_to_input_id_table = tf.lookup.StaticHashTable(
    initializer, default_value=tf.int32.limits[1], name=None
)

# Response vocabulary, listed in id order so that each token's position agrees with
# `RESPONSE_LECTURE_ID` (4), `RESPONSE_FALSE_ID` (5) and `RESPONSE_TRUE_ID` (6).
response_vocab = [
    PAD_TOKEN, START_TOKEN, END_TOKEN, MASK_TOKEN,
    RESPONSE_LECTURE_TOKEN,
    RESPONSE_FALSE_TOKEN, RESPONSE_TRUE_TOKEN,
]

# Answer vocabulary, listed in id order (`ANSWER_LECTURE_ID` = 4, `ANSWER_0_ID` = 5, ..., `ANSWER_3_ID` = 8).
answer_vocab = [
    PAD_TOKEN, START_TOKEN, END_TOKEN, MASK_TOKEN,
    ANSWER_LECTURE_TOKEN,
    ANSWER_0_TOKEN, ANSWER_1_TOKEN, ANSWER_2_TOKEN, ANSWER_3_TOKEN
]

# Vocabulary sizes.
CONTENT_VOCAB_SIZE = len(content_vocab)
RESPONSE_VOCAB_SIZE = len(response_vocab)
ANSWER_VOCAB_SIZE = len(answer_vocab)

# Per-content metadata, keyed by integer question / lecture id (see `load_data`).
question_tags_info = load_data(question_tags_info_path)
lecture_tags_info = load_data(lecture_tags_info_path)
correct_answer_info = load_data(correct_answer_info_path)

question_part_info = load_data(question_part_info_path)
lecture_part_info = load_data(lecture_part_info_path)

# Lists indexed by content input id (i.e. by position in `content_vocab`).
tags_database = []
part_database = []
correct_answer_table = []

for idx in range(len(content_vocab)):
    if idx < len(special_vocab):
        # Special tokens: placeholder tags / part of `-1`; the "correct answer" is the special id itself.
        tags_database.append([-1] * N_TAGS_PER_CONTENT)
        part_database.append(-1)
        correct_answer_table.append(idx)
    elif idx < len(special_vocab) + len(question_vocab):
        # Questions: the raw correct answer is shifted by `ANSWER_0_ID` to get its answer id.
        index_in_questions_ids = idx - len(special_vocab)
        question_id = question_ids[index_in_questions_ids]
        tags_database.append(question_tags_info[question_id])
        part_database.append(question_part_info[question_id])
        correct_answer_table.append(correct_answer_info[question_id] + ANSWER_0_ID)
    elif idx < len(content_vocab):
        # Lectures: there is no correct answer, so the lecture answer id is stored.
        index_in_lecture_ids = idx - (len(special_vocab) + len(question_vocab))
        lecture_id = lecture_ids[index_in_lecture_ids]
        tags_database.append(lecture_tags_info[lecture_id])
        part_database.append(lecture_part_info[lecture_id])
        correct_answer_table.append(ANSWER_LECTURE_ID)

# Tensor versions of the lists above.
c_inputs_ids_to_tags = tf.constant(tags_database, dtype=tf.int32)
c_inputs_ids_to_part = tf.constant(part_database, dtype=tf.int32)
c_inputs_ids_to_correct_answer_id = tf.constant(correct_answer_table, dtype=tf.int32)

# --------------------------------------------------------------------------------------------------------------
# This is much faster!
# Plain-Python dict versions of the same mappings.
question_id_to_input_id_dict = {
    k: v for k, v in zip(question_ids, range(len(special_vocab), len(special_vocab) + len(question_vocab)))
}
c_inputs_ids_to_part_dict = {k: v for k, v in enumerate(part_database)}
c_inputs_ids_to_correct_answer_id_dict = {k: v for k, v in enumerate(correct_answer_table)}

In [None]:
# Report the vocabulary sizes computed above.
print(f'CONTENT_VOCAB_SIZE: {CONTENT_VOCAB_SIZE}')
print(f'RESPONSE_VOCAB_SIZE: {RESPONSE_VOCAB_SIZE}')
print(f'ANSWER_VOCAB_SIZE: {ANSWER_VOCAB_SIZE}')

### TFRecord files

In [None]:
if not IS_KAGGLE:

    # Names of the 41 training shards, ordered by their numeric shard index (for verification purposes).
    train_tfrec_fns = sorted(
        [f'EdNet-user-history-{shard}.tfrecord' for shard in range(41)],
        key=lambda fn: int(fn.replace('EdNet-user-history-', '').replace('.tfrecord', '')),
    )

    # Without a TPU the shards are joined with `os.path.join`; otherwise with a plain '/'.
    train_tfrec_paths = [
        os.path.join(train_tfrec_dir, fn) if (tpu is None or IS_KAGGLE) else train_tfrec_dir + f'/{fn}'
        for fn in train_tfrec_fns
    ]

    train_tfrec_paths

In [None]:
if not IS_KAGGLE:

    # Validation TFRecord files to read; uncomment a fold to include it.
    valid_tfrec_fns = [
        'EdNet-user-history-with-valid-block-pos.tfrecord',
        # 'EdNet-user-history-with-valid-block-pos-fold-1.tfrecord',
        # 'EdNet-user-history-with-valid-block-pos-fold-2.tfrecord',
        # 'EdNet-user-history-with-valid-block-pos-fold-3.tfrecord',
    ]

    # Without a TPU the files are joined with `os.path.join`; otherwise with a plain '/'.
    valid_tfrec_paths = [
        os.path.join(valid_tfrec_dir, fn) if tpu is None else valid_tfrec_dir + f'/{fn}'
        for fn in valid_tfrec_fns
    ]

    valid_tfrec_paths

### Load TFRecord files - tf.io.RaggedFeature

In [None]:
# --------------------------------------------------------------------------------
# For training dataset

# Feature spec of the training TFRecord files: one scalar `user_id` plus one
# variable-length sequence per interaction attribute.
train_raw_features = {
    'user_id': tf.io.FixedLenFeature([], dtype=tf.int64),
    'row_id': tf.io.RaggedFeature(value_key='row_id', dtype=tf.int64),
    'timestamp': tf.io.RaggedFeature(value_key='timestamp', dtype=tf.int64),
    'content_id': tf.io.RaggedFeature(value_key='content_id', dtype=tf.int64),
    'content_type_id': tf.io.RaggedFeature(value_key='content_type_id', dtype=tf.int64),
    'task_container_id': tf.io.RaggedFeature(value_key='task_container_id', dtype=tf.int64),
    'user_answer': tf.io.RaggedFeature(value_key='user_answer', dtype=tf.int64),
    'answered_correctly': tf.io.RaggedFeature(value_key='answered_correctly', dtype=tf.int64),
    'prior_question_elapsed_time': tf.io.RaggedFeature(value_key='prior_question_elapsed_time', dtype=tf.float32),
    'prior_question_had_explanation': tf.io.RaggedFeature(value_key='prior_question_had_explanation', dtype=tf.int64),
    # -------------------------------------------------------------------
    # extra information
    'n_prev_seen': tf.io.RaggedFeature(value_key='n_prev_seen', dtype=tf.int64),
    'n_prev_correctness': tf.io.RaggedFeature(value_key='n_prev_correctness', dtype=tf.int64),
    # -------------------------------------------------------------------

}


def _compute_lag_time(timestamp, task_container_id):
    """Compute this notebook's own version of lag time for every interaction.

    A "jump" is a position whose timestamp differs from the previous one. Jump sizes are
    accumulated per `task_container_id` and then read back for every position, so all
    positions that share a `task_container_id` receive the same value.

    Args:
        timestamp: 1-D `tf.int64` tensor of interaction timestamps.
        task_container_id: 1-D `tf.int32` tensor aligned with `timestamp`.

    Returns:
        A 1-D `tf.int64` tensor aligned with `timestamp`, never negative.
    """

    # Difference to the previous timestamp. The first position is compared against `1`, so a
    # history starting at timestamp `0` yields `-1` there (clamped to `0` further down).
    diff_t = timestamp - tf.concat([[1], timestamp[:-1]], axis=0)
    bundle_start_mask = tf.cast(diff_t != 0, dtype=tf.int32)
    # The jump values at bundle starting places
    jump_t = tf.boolean_mask(diff_t, bundle_start_mask > 0)
    # The task_container_id at bundle starting places
    task_container_id_at_jumps = tf.boolean_mask(task_container_id, bundle_start_mask > 0)
    # Maps task_container_id to jump values
    task_container_id_to_jump_t = tf.scatter_nd(
        indices=task_container_id_at_jumps[:, tf.newaxis], updates=jump_t, shape=tf.math.reduce_max(task_container_id_at_jumps + 1)[tf.newaxis], name=None
    )
    # `Our own version` of lag time.
    lag_time = tf.gather(params=task_container_id_to_jump_t, indices=task_container_id)
    # Make sure >= `0`
    lag_time = tf.math.maximum(tf.cast(0, dtype=tf.int64), lag_time)

    # set to 1 second if `lag_time == 0` but they are not in the same bundle
    _mask = tf.cast(tf.math.logical_and(lag_time == 0, timestamp > 0), dtype=tf.int64)
    lag_time = lag_time * (1 - _mask) + 1000 * _mask

    return lag_time


def _parse_user_history_example(example, features):
    """Parse one serialized user-history example with the given feature spec.

    Besides the parsed features, the following extra attributes are added:

        - seq_len: The length of the user interaction history in this example.
        - prev_seq_len: The length of interaction history in an example from which the current example is obtained.
            Here, it just equals `seq_len` since it has no source.
        - start: The starting index in the interaction history of an example from which the current example is
            obtained. Here, it is `0`.
        - end: The ending index in the interaction history of an example from which the current example is
            obtained. Here, it is `seq_len - 1`.
        - shifted_answered_correctly / shifted_user_answer: The sequence with `START_TOKEN` prepended
            (one element longer than the original at this moment).
        - pred_time_mask: A 1-D `tf.Tensor` of `0` and `1`, indicating if a place is in the prediction time.
            Here, all of them are `0`.
        - abs_pos / shifted_abs_pos: `0 .. seq_len - 1`, and the same sequence shifted right by one with
            `START_TOKEN` in front.
        - lag_time: See `_compute_lag_time`.

    Args:
        example: A scalar string tensor holding one serialized `tf.train.Example`.
        features: The feature spec (a dict) passed to `tf.io.parse_single_example`.

    Returns:
        A dict of tensors as described above.
    """

    _parsed = tf.io.parse_single_example(example, features)

    parsed = {}
    parsed['user_id'] = _parsed['user_id']
    parsed['seq_len'] = tf.reduce_sum(tf.ones_like(_parsed['row_id'], dtype=tf.int32))
    parsed['prev_seq_len'] = parsed['seq_len']
    parsed['start'] = tf.constant(0, dtype=tf.int32)
    parsed['end'] = parsed['seq_len'] - 1

    # `row_id`, `user_id` and `timestamp` stay int64; the elapsed time becomes `DTYPE`;
    # everything else is narrowed to int32.
    for k in features:

        data = _parsed[k]
        if k not in ['row_id', 'user_id', 'timestamp', 'prior_question_elapsed_time']:
            data = tf.cast(data, dtype=tf.int32)
        elif k == 'prior_question_elapsed_time':
            data = tf.cast(data, dtype=DTYPE)
        if k != 'user_id':
            parsed[k] = data

    # We need to use `START_TOKEN` rather than `PAD_TOKEN` here.
    # This makes `parsed['shifted_answered_correctly']` have 1 extra element at this moment.
    parsed['shifted_answered_correctly'] = tf.concat([[START_TOKEN], parsed['answered_correctly']], axis=0)

    # We need to use `START_TOKEN` rather than `PAD_TOKEN` here.
    # This makes `parsed['shifted_user_answer']` have 1 extra element at this moment.
    parsed['shifted_user_answer'] = tf.concat([[START_TOKEN], parsed['user_answer']], axis=0)

    # This should be all `0`.
    pad_mask = tf.cast(parsed['row_id'] == PAD_TOKEN, dtype=tf.int32)

    # No place is considered to be in the prediction time at this step.
    pred_time_mask = tf.zeros_like(parsed['timestamp'], dtype=tf.int32)

    # This should be all `0`.
    pred_time_mask = pred_time_mask * (1 - pad_mask) + (PAD_TOKEN) * pad_mask

    parsed['pred_time_mask'] = pred_time_mask

    parsed['abs_pos'] = tf.range(parsed['seq_len'], dtype=tf.int32)
    parsed['shifted_abs_pos'] = tf.concat([[START_TOKEN], parsed['abs_pos'][:-1]], axis=0)

    # ----------------------------------------------------------------------------------------------------
    # To get the lag time

    parsed['lag_time'] = _compute_lag_time(parsed['timestamp'], parsed['task_container_id'])

    return parsed


def parse_train_example(example):
    """Parse an example from the training tfrecord files.

    The tfrecord dataset is only used for training (excluding the part used for validation),
    so no place is in the prediction time: `pred_time_mask` is all `0`.

    See `_parse_user_history_example` for the attributes of the returned dict.
    """

    return _parse_user_history_example(example, train_raw_features)


# --------------------------------------------------------------------------------
# For validation dataset

# Same features as `train_raw_features`, plus the positions of the validation blocks.
train_features_with_valid_info = {
    **train_raw_features,
    # -------------------------------------------------------------------
    # validation information
    'n_valid_blocks': tf.io.FixedLenFeature([], dtype=tf.int64),
    'valid_blocks_start_pos': tf.io.RaggedFeature(value_key='valid_blocks_start_pos', dtype=tf.int64),
    'valid_blocks_end_pos': tf.io.RaggedFeature(value_key='valid_blocks_end_pos', dtype=tf.int64),
    # -------------------------------------------------------------------
}


def parse_train_example_with_valid_info(example):
    """Parse an example from the tfrecord files that carry validation information.

    The tfrecord dataset, in the raw format, contains training examples with validation information.
    At this step, no place is considered to be in the prediction time yet (`pred_time_mask` is all `0`).
    This information will be updated in a further dataset transformation.

    See `_parse_user_history_example` for the attributes of the returned dict; the validation
    features (`n_valid_blocks`, `valid_blocks_start_pos`, `valid_blocks_end_pos`) are returned as int32.
    """

    return _parse_user_history_example(example, train_features_with_valid_info)

In [None]:
# Smoke check: read the training shards sequentially and print the first parsed example.
if not IS_KAGGLE:

    train_raw_ds = tf.data.TFRecordDataset(train_tfrec_paths, num_parallel_reads=1)
    train_raw_ds = train_raw_ds.map(parse_train_example, num_parallel_calls=1, deterministic=True)
    for x in train_raw_ds.take(1):
        print(x)

#### check

In [None]:
# Smoke check: read the validation file(s) sequentially and print the first parsed example.
if not IS_KAGGLE:

    valid_raw_ds = tf.data.TFRecordDataset(valid_tfrec_paths, num_parallel_reads=1)
    valid_raw_ds = valid_raw_ds.map(parse_train_example_with_valid_info, num_parallel_calls=1, deterministic=True)
    for x in valid_raw_ds.take(1):
        # pass
        print(x)

### Split the dataset into training / validation parts

The helpers below split each user's interaction history into a training part and a validation part; only the training part is kept for the training dataset.

In [None]:
def convert_split_index_dict(split_index_dict):
    """Build a static lookup table from user id to train/validation split index.

    Args:
        split_index_dict: dict mapping user id to split index.

    Returns:
        A `tf.lookup.StaticHashTable` with int64 keys and int32 values. User ids
        that are not in the table map to the largest int32 value.
    """

    # `keys()` and `values()` iterate in the same order, so the two tensors stay aligned.
    keys = tf.constant(list(split_index_dict.keys()), dtype=tf.int64)
    values = tf.constant(list(split_index_dict.values()), dtype=tf.int32)

    return tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys, values),
        default_value=tf.int32.limits[1],
        name=None,
    )


def split_train_example(raw_example, split_index_table):
    """Keep only the training part of an original train example.

    The example is cut at the split index looked up for its `user_id`: the first `split_index`
    interactions form the training part (returned), the remainder is the validation part (dropped).

    Args:
        raw_example: A dict of tensors for one user. It must contain `user_id` and `seq_len`; every
            other entry is treated as a sequence and sliced along its first dimension.
        split_index_table: A lookup table mapping `user_id` to the split index, returning the
            largest `tf.int32` value for unknown users (e.g. built by `convert_split_index_dict`).

    Returns:
        A dict with `user_id`, `seq_len` (the split index), `prev_seq_len` (the original `seq_len`),
        `start` (always 0), `end` (`split_index - 1`) and all sliced sequences.
    """

    bookkeeping_keys = ['user_id', 'seq_len', 'prev_seq_len', 'start', 'end']

    orig_seq_len = raw_example['seq_len']

    split_index = split_index_table.lookup(raw_example['user_id'])
    tf.debugging.Assert(tf.reduce_all(split_index >= 0), [split_index])

    # The table's default value means `user_id` is not used for validation:
    # the whole sequence then belongs to the training part.
    if split_index == tf.int32.limits[1]:
        split_index = orig_seq_len

    example = {
        'user_id': raw_example['user_id'],
        'seq_len': split_index,
        'prev_seq_len': orig_seq_len,
        'start': tf.constant(0, dtype=tf.int32),
        'end': split_index - 1,
    }

    # Every remaining entry is a sequence: keep its first `split_index` elements.
    for key, value in raw_example.items():
        if key not in bookkeeping_keys:
            example[key] = value[:split_index]

    return example


def split_train_ds(train_raw_ds, split_indices, num_parallel_calls=None, deterministic=None):
    """Apply `split_train_example` to every example of a raw training dataset.

    Args:
        train_raw_ds: A `tf.data.Dataset` of raw (unbatched) training examples.
        split_indices: The lookup table passed to `split_train_example` as `split_index_table`.
        num_parallel_calls: Forwarded to `Dataset.map`.
        deterministic: Forwarded to `Dataset.map`.

    Returns:
        The mapped dataset, containing only the training part of each example.
    """

    def _keep_training_part(example):
        return split_train_example(example, split_indices)

    return train_raw_ds.map(
        _keep_training_part,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic,
    )

### Training dataset transformation - from tf.RaggedTensor to tf.Tensor

We sample random subsequences of user interactions (from the split training dataset) for training.

In [None]:
def extract_subseqs(seqs, ending_indices, window_size, seq_len):
    """Extract, from each sequence of a batch, a fixed-size window ending at a given index.

    For each example `i`, the returned window is
    `seqs[i][ending_indices[i] - window_size + 1 : ending_indices[i] + 1]`, i.e. the `window_size`
    elements whose last one sits at position `ending_indices[i]`. Positions that would fall before
    the start of the sequence get `PAD_TOKEN` as values, so a window always has `window_size`
    positions (with `ending_indices[i] == -1` it consists of padding only).

    Args:
        seqs: A `tf.RaggedTensor` with rank >= 2. The 1st dim is the batch dimension, and the 2nd
            dim is the temporal dimension. The temporal dimension is the ragged rank, i.e. the
            unique dimension which is ragged.

        ending_indices: A 1-D `tf.int32` tensor with shape = [batch_size]. The condition
            `-1 <= ending_indices[i] < seq_len[i]` must hold.

        window_size: A scalar `tf.int32` tensor, which should be positive.

        seq_len: A 1-D integer tensor with shape = [batch_size]. It is only used to derive the
            batch size and in the sanity check on `ending_indices`.

    Returns:
        sampled_subseqs: The extracted windows, shape = [batch_size, window_size, ...].
        prev_last_element: The element right before each window, i.e.
            `seqs[i][ending_indices[i] - window_size]`, or `PAD_TOKEN` if there is no such element.
            shape = [batch_size, ...].
    """

    # One extra leading position is fetched so that the element right before the window is available.
    extended_window_size = window_size + 1

    tf.debugging.Assert(tf.rank(seqs) >= 2, [tf.rank(seqs)])

    tf.debugging.Assert(tf.reduce_all(extended_window_size > 0), [extended_window_size])
    tf.debugging.Assert(tf.reduce_all(ending_indices >= -1), [ending_indices])

    batch_size = tf.reduce_sum(tf.ones_like(seq_len, dtype=tf.int32))

    tf.debugging.Assert(tf.reduce_all((ending_indices < seq_len)), [ending_indices, seq_len])

    # shape = [batch_size, 1]
    _ending_indices = ending_indices[:, tf.newaxis]

    # Shift by 1 because one padding element is inserted at the beginning of each sequence below.
    _ending_indices += 1

    # shape = [1, extended_window_size]
    ranges = tf.range(extended_window_size)[tf.newaxis, :]

    # shape = [batch_size, extended_window_size]
    indices = _ending_indices - ranges

    # reverse the sequence dimension to get the correct temporal direction
    # shape = [batch_size, extended_window_size]
    indices = tf.reverse(indices, axis=[1])

    # shape = [batch_size, extended_window_size, 2]
    indices_to_seqs = tf.stack([tf.broadcast_to(tf.range(batch_size)[:, tf.newaxis], shape=[batch_size, extended_window_size]), indices], axis=2)

    # Change negative indices to 0, which is the position of the inserted padding element.
    # shape = [batch_size, extended_window_size, 2]
    indices_to_ext_seqs = tf.math.maximum(indices_to_seqs, 0)

    # Need to rework
    invalid_values = tf.cast(tf.ones(shape=seqs.shape[2:]) * PAD_TOKEN, seqs.dtype)

    # shape = [batch_size, 1, ...]
    invalid_values = tf.repeat([[invalid_values]], repeats=batch_size, axis=0)

    # Same shape as `seqs`, but each sequence has 1 more element inserted at the beginning
    extended_seqs = tf.concat([invalid_values, seqs], axis=1)

    # selected
    # shape = [batch_size, extended_window_size, ...]
    extended_subseqs = tf.gather_nd(extended_seqs, indices=indices_to_ext_seqs)

    # The extra leading position is the element right before the window.
    # It must be read before that position is dropped.
    prev_last_element = extended_subseqs[:, 0]

    # take the last `window_size` part
    sampled_subseqs = extended_subseqs[:, 1:]

    return sampled_subseqs, prev_last_element


def random_ending_indices(seq_len, seed=None):
    """Draw one random ending index per example, in the range `-1 <= index < seq_len[i]`.

    Each index is obtained as `floor(-1 + u * (seq_len[i] + 1))` with `u` sampled uniformly
    from `[0, 1)`.

    Args:
        seq_len: A 1-D integer tensor with shape = [batch_size]
            (presumably `tf.int32`, as it is combined with the `tf.int32` result in the sanity check).
        seed: Optional seed forwarded to `tf.random.uniform`.

    Returns:
        A 1-D `tf.int32` tensor with shape = [batch_size].
    """

    batch_size = tf.reduce_sum(tf.ones_like(seq_len))

    # shape = [batch_size]
    ending_indices = -1.0 + tf.random.uniform(minval=0.0, maxval=1.0, dtype=DTYPE, shape=[batch_size], seed=seed) * tf.cast(seq_len + 1, dtype=DTYPE)
    ending_indices = tf.cast(tf.math.floor(ending_indices), dtype=tf.int32)

    # sanity check
    tf.debugging.Assert(tf.math.reduce_min(seq_len - ending_indices) >= 1, [seq_len, ending_indices])
    tf.debugging.Assert(tf.math.reduce_min(ending_indices) >= -1, [seq_len, ending_indices])

    return ending_indices
    
    
def random_subseqs(seqs, window_size, seq_len, seed=None):
    """Extract, from each sequence of a batch, a window ending at a randomly drawn index.

    The ending indices come from `random_ending_indices` and the windows from `extract_subseqs`.

    Args:
        seqs: A `tf.RaggedTensor` tensor with rank = 2. The 1st dim is the batch dimension,
            and the 2nd dim is the temporal dimension, which is the unique ragged dimension.

        window_size: A scalar `tf.int32` tensor, which should be positive.

        seq_len: A 1-D integer tensor with shape = [batch_size], forwarded to both helpers.

        seed: Optional seed forwarded to `random_ending_indices`.

    Returns:
        A tuple `(sampled_subseqs, ending_indices)`.
    """

    drawn_ending_indices = random_ending_indices(seed=seed, seq_len=seq_len)

    # Only the windows are needed; the second value returned by `extract_subseqs` is discarded.
    windows = extract_subseqs(seqs, drawn_ending_indices, window_size, seq_len=seq_len)[0]

    return windows, drawn_ending_indices

# --------------------------------------------------------------------------------
# For train / validation

def extract_subseqs_from_raw_batch(raw_batch, ending_indices, window_size):
    """Extract fixed-size windows from a raw batch of `tf.RaggedTensor` objects.

    The raw batch is
        - either a training batch (containing no validation part anymore),
        - or a validation batch (should be the ending part only).
    Every sequence entry is replaced by a window of length `window_size` ending at
    `ending_indices`, padded from the beginning with `PAD_TOKEN`.

    Args:
        raw_batch: A dict of batched tensors. `start` must be all zeros. `user_id` and `seq_len`
            are read directly; `prev_seq_len` is read when a sequence entry is windowed.
        ending_indices: A 1-D `tf.int32` tensor with shape = [batch_size].
        window_size: The window length.

    Returns:
        A dict with `user_id`, `seq_len` (number of non-padded positions in the window),
        `prev_seq_len` (the raw batch's `seq_len`), `start`, `end`, the windowed sequences, and the
        validation entries `n_valid_blocks`, `valid_block_idx`, `valid_start`, `valid_end` when present.
    """

    bookkeeping_keys = ('user_id', 'seq_len', 'prev_seq_len', 'start', 'end')
    # Validation entries copied through unchanged.
    copied_valid_keys = ('n_valid_blocks', 'valid_block_idx', 'valid_start', 'valid_end')
    # Validation entries that are neither windowed nor copied.
    dropped_valid_keys = ('valid_blocks_start_pos', 'valid_blocks_end_pos', 'valid_block_pos')

    raw_start = raw_batch['start']
    tf.debugging.Assert(tf.reduce_all(raw_start == tf.zeros_like(raw_start, dtype=tf.int32)), [raw_start])

    window_start = tf.math.maximum(ending_indices - window_size + 1, 0)

    batch = {
        'user_id': raw_batch['user_id'],
        'seq_len': (ending_indices - window_start) + 1,
        'prev_seq_len': raw_batch['seq_len'],
        'start': window_start,
        'end': ending_indices,
    }

    for key, value in raw_batch.items():
        if key in copied_valid_keys:
            batch[key] = value
        elif key not in bookkeeping_keys and key not in dropped_valid_keys:
            batch[key], _ = extract_subseqs(seqs=value, ending_indices=ending_indices, window_size=window_size, seq_len=raw_batch['prev_seq_len'])

    return batch

# --------------------------------------------------------------------------------

def random_subseqs_from_raw_batch(raw_batch, window_size, only_last, seed=None):
    """Choose an ending index per example and extract the corresponding windows from a raw batch.

    Args:
        raw_batch: A raw batch as accepted by `extract_subseqs_from_raw_batch`.
        window_size: The window length.
        only_last: If true, every window ends at the last position (`seq_len - 1`). Otherwise the
            ending index is drawn by `random_ending_indices`, raised to at least `window_size - 1`,
            and finally capped at `seq_len - 1`.
        seed: Optional seed forwarded to `random_ending_indices`.

    Returns:
        A tuple `(batch, ending_indices)`.
    """

    last_indices = raw_batch['seq_len'] - 1

    if only_last:
        ending_indices = last_indices
    else:
        drawn_indices = random_ending_indices(raw_batch['seq_len'], seed=seed)
        # Prefer full windows, but never point past the end of a short sequence.
        ending_indices = tf.math.minimum(tf.math.maximum(drawn_indices, window_size - 1), last_indices)

    batch = extract_subseqs_from_raw_batch(raw_batch, ending_indices, window_size)

    return batch, ending_indices


def random_subseqs_from_batched_raw_ds(batched_raw_ds, window_size, only_last=False, seed=None, num_parallel_calls=None, deterministic=None):
    """Apply `random_subseqs_from_raw_batch` to every batch of a batched raw dataset.

    Args:
        batched_raw_ds: A `tf.data.Dataset` of raw batches.
        window_size: Forwarded to `random_subseqs_from_raw_batch`.
        only_last: Forwarded to `random_subseqs_from_raw_batch`.
        seed: Forwarded to `random_subseqs_from_raw_batch`.
        num_parallel_calls: Forwarded to `Dataset.map`.
        deterministic: Forwarded to `Dataset.map`.

    Returns:
        The mapped dataset, whose elements are `(batch, ending_indices)` tuples.
    """

    def _sample_windows(raw_batch):
        return random_subseqs_from_raw_batch(raw_batch, window_size, only_last, seed)

    return batched_raw_ds.map(
        _sample_windows,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic,
    )

## Convert the dataset to model inputs

### Tables mapping question / lectures / tags ids (etc.) to encoder / decoder input ids

### Prepare encoder / decoder input ids and attention masks

#### special masks

In [None]:
def get_causal_attention_mask(nd, ns, dtype, only_before):
    """Build a causal mask of shape [nd, ns].

    Entry `[i, j]` is 1 iff `i >= j - ns + nd + only_before`, i.e. 1's in the lower triangle,
    counting from the lower right corner. With `only_before` false this is the same as
    tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
    With `only_before` true the boundary is shifted by one, so a position cannot see itself.

    Args:
        nd: The number of queries.
        ns: The number of keys.
        dtype: The dtype of the returned mask.
        only_before: A boolean (tensor); it is cast to `tf.int32` and added to the threshold.

    Returns:
        A tensor of shape [nd, ns] with dtype `dtype`, consisting of 0 and 1.
    """

    # Remark: In encoder-decoder case, the queries are the decoder features and the keys are the
    # encoder features.

    query_positions = tf.range(nd)[:, tf.newaxis]
    key_positions = tf.range(ns)[tf.newaxis, :]

    strictness = tf.cast(only_before, dtype=tf.int32)
    thresholds = (key_positions - ns + nd) + strictness

    visible = tf.math.greater_equal(query_positions, thresholds)

    return tf.cast(visible, dtype)

def get_attention_mask_from_timestamp_batch(timestamp_tensors, dtype, only_before):
    """Build an attention mask from timestamps: a query may attend to keys that are not later in time.

    Entry `[b, q, k]` is 1 iff `timestamp[b, q] >= timestamp[b, k] + only_before`.

    Args:
        timestamp_tensors: 2-D `tf.int64` tensor of shape = [batch_size, seq_len], representing a
            batch of sequences of non-decreasing timestamps. (`only_before` is cast to `tf.int64`
            and added to it, so the dtypes must match.)
        dtype: The dtype of the returned mask.
        only_before: A boolean (tensor). If true, keys with the same timestamp as the query are
            excluded as well.

    Returns:
        attention_mask: 3-D tensor of shape = [batch_size, query_len, key_len] with dtype `dtype`,
            consisting of 0 and 1. Here `query_len` and `key_len` are actually `seq_len`. It should
            be reshaped, when used to calculate attention scores, to
            [batch_size, nb_attn_head, query_len, key_len].
    """

    # Dynamic sizes, obtained by counting the elements of a single column / row.
    n_examples = tf.math.reduce_sum(tf.ones_like(timestamp_tensors[:, :1], dtype=tf.int32))
    n_positions = tf.math.reduce_sum(tf.ones_like(timestamp_tensors[:1, :], dtype=tf.int32))
    mask_shape = [n_examples, n_positions, n_positions]

    # Query timestamps, repeated along dim 2.
    query_times = tf.broadcast_to(tf.expand_dims(timestamp_tensors, axis=2), shape=mask_shape)

    # Key timestamps, repeated along dim 1.
    key_times = tf.broadcast_to(tf.expand_dims(timestamp_tensors, axis=1), shape=mask_shape)
    key_times = key_times + tf.cast(only_before, dtype=tf.int64)

    return tf.cast(tf.math.greater_equal(query_times, key_times), dtype)

##### check

In [None]:
if not IS_KAGGLE:

    # Print each mask twice: first the inclusive variant, then the `only_before` variant.
    for strict in (False, True):
        print(get_causal_attention_mask(3, 3, tf.int32, only_before=tf.constant(strict)))

    timestamp_tensors = tf.constant([[0, 0, 1, 2, 2, 2, 3, 3], [0, 1, 1, 2, 2, 3, 3, 3]], dtype=tf.int64)

    for strict in (False, True):
        print(get_attention_mask_from_timestamp_batch(timestamp_tensors, tf.int32, only_before=tf.constant(strict)))

#### Tensors required for model (input ids, targets, attention masks)

In [None]:
@tf.function
def get_attention_masks(batch, training, encoder_decoder, generative, allow_bundle_atten):
    """Build the attention masks used by the model from a windowed batch.

    Args:
        batch: A dict of tensors. The keys `content_type_id`, `pred_time_mask`, `content_id` and
            `timestamp` are read; each is indexed as a 2-D tensor of shape = [batch_size, seq_len].
        training: Numeric flag; only used through `(1 - training) * (1 - generative) == 1`
            (presumably 0 or 1 - TODO confirm against callers).
        encoder_decoder: Flag cast to `tf.bool`; selects the encoder-decoder variant of the masks.
        generative: Numeric flag, see `training`.
        allow_bundle_atten: Flag cast to `tf.bool`; if true, the content mask is based on
            timestamps instead of positions.

    Returns:
        A tuple `(c_mask, r_mask, r_c_mask, c_r_mask)` of `tf.int32` tensors consisting of 0 and 1,
        each of shape = [batch_size, seq_len, seq_len].
    """

    # 1 where the position is padding / a question (`content_type_id == 0`), respectively.
    pad_mask = tf.cast((batch['content_type_id'] == PAD_TOKEN), dtype=tf.int32)
    question_mask = tf.cast((batch['content_type_id'] == 0), dtype=tf.int32)

    # Replace `PAD_TOKEN` by `0`.
    pred_time_mask = batch['pred_time_mask'] * tf.cast(batch['pred_time_mask'] != PAD_TOKEN, dtype=tf.int32)

    content_ids = batch['content_id']

    # sanity check
    # NOTE: `_pad_mask` is only needed by the assertion below, which is currently disabled.
    _pad_mask = tf.cast((batch['timestamp'] == PAD_TOKEN), dtype=tf.int32)
#     tf.debugging.Assert(tf.reduce_all(pad_mask == _pad_mask), [pad_mask, _pad_mask])    

    # ----------------------------------------
    # General mask
    # `seq_len` below is actually the `window_size`.
    
    # Counted from the first example's row of `content_ids`.
    seq_len = tf.math.reduce_sum(tf.ones_like(content_ids[0, :], dtype=tf.int32))
    
    # Don't pay attention to [PAD]
    # shape = [batch_size, seq_len]
    non_pad_mask = 1 - pad_mask
    
    # Can only attend to the current and previous positions
    # shape = [seq_len, seq_len]
    causal_attention_mask = get_causal_attention_mask(nd=seq_len, ns=seq_len, dtype=tf.int32, only_before=tf.constant(False))
   
    # Can only attend to the current and previous timestamps
    # shape = [batch_size, seq_len, seq_len]
    timestamp_attention_mask = get_attention_mask_from_timestamp_batch(batch['timestamp'], dtype=tf.int32, only_before=tf.constant(False))

    # Can only attend to previous timestamps
    # shape = [batch_size, seq_len, seq_len]
    timestamp_attention_mask_only_before = get_attention_mask_from_timestamp_batch(batch['timestamp'], dtype=tf.int32, only_before=tf.constant(True))
        
    # ----------------------------------------
    # content self attention mask
    # shape = [batch_size, seq_len, seq_len]

    if tf.cast(allow_bundle_atten, dtype=tf.bool):
        # can see the contents in the current and previous bundle, not including [PAD]
        c_mask = non_pad_mask[:, tf.newaxis, :] * timestamp_attention_mask
    else:
        # can see the current and previous contents, not including [PAD]
        c_mask = non_pad_mask[:, tf.newaxis, :] * causal_attention_mask[tf.newaxis, :, :]
    
    # ----------------------------------------
    # response self attention mask
    # shape = [batch_size, seq_len, seq_len]    
    
    # for `encoder-only` model: same as `c_mask`
    r_mask = c_mask

    if tf.cast(encoder_decoder, dtype=tf.bool):
        
        # for `encoder-decoder` model: should be causal masking
        d_mask = non_pad_mask[:, tf.newaxis, :] * causal_attention_mask[tf.newaxis, :, :]

        # if neither `training` nor `generative`
        if (1 - training) * (1 - generative) == 1:

            pred_time_question_mask = pred_time_mask * question_mask
            
            # Hide the key positions that are questions in prediction time; the identity matrix
            # keeps the diagonal, so such a position can still attend to itself.
            d_mask = d_mask * tf.math.maximum((1 - pred_time_question_mask)[:, tf.newaxis, :], tf.eye(seq_len, dtype=tf.int32)[tf.newaxis, :, :])
            r_mask = d_mask
            c_mask = c_mask * tf.math.maximum((1 - pred_time_question_mask)[:, tf.newaxis, :], tf.eye(seq_len, dtype=tf.int32)[tf.newaxis, :, :])

        else:
            r_mask = d_mask

    # ----------------------------------------
    # response-to-content attention mask (same as the final `c_mask`)
    # shape = [batch_size, seq_len, seq_len]
    
    r_c_mask = c_mask

    # ----------------------------------------
    # Used for encoder-only models.
    # Can see only the responses in the previous bundles, not including [PAD]

    # shape = [batch_size, seq_len, seq_len]
    c_r_mask = non_pad_mask[:, tf.newaxis, :] * timestamp_attention_mask_only_before
    
    return c_mask, r_mask, r_c_mask, c_r_mask


def add_input_ids_and_targets(batch, training, generative, use_abs_pos):
    """Add input ids and targets for training
    """
        
    content_ids = batch['content_id']
    content_type_ids = batch['content_type_id']
    answered_correctly = batch['answered_correctly']
    user_answer = batch['user_answer']
    
    question_mask = tf.cast((content_type_ids == 0), dtype=tf.int32)
    lecture_mask = tf.cast((content_type_ids == 1), dtype=tf.int32)
    
    pad_mask = tf.cast((content_type_ids == PAD_TOKEN), dtype=tf.int32)
    _pad_mask = tf.cast((batch['timestamp'] == PAD_TOKEN), dtype=tf.int32)
#     tf.debugging.Assert(tf.reduce_all(pad_mask == _pad_mask), [pad_mask, _pad_mask])
    
    # Replace `PAD_TOKEN` by `0`.
    pred_time_mask = batch['pred_time_mask'] * tf.cast(batch['pred_time_mask'] != PAD_TOKEN, dtype=tf.int32)
    pred_time_question_mask = pred_time_mask * question_mask
    
    # The number of questions in prediction time
    # shape = [batch_size]
    n_questions_in_pred_time = tf.math.reduce_sum(pred_time_question_mask, axis=1)
    # sanity check
#     tf.debugging.Assert(tf.reduce_all(n_questions_in_pred_time >= 0), [n_questions_in_pred_time])      

    response_false_mask = tf.cast((answered_correctly == RESPONSE_FALSE_TOKEN), dtype=tf.int32)
    response_true_mask = tf.cast((answered_correctly == RESPONSE_TRUE_TOKEN), dtype=tf.int32)
    
    # ----------------------------------------

    _batch = {}
    for k in batch:
        _batch[k] = batch[k]

    # Here, we use global variables, which is not good. But this is for fast experiments!!!
    # Shifted by one, `PAD_TOKEN` --> `0`
    _batch['abs_pos_ids'] = tf.cast(tf.math.minimum(tf.math.maximum(0, _batch['abs_pos'] + 1), MAX_HISTORY_LEN), dtype=DTYPE) / tf.cast(MAX_HISTORY_LEN, dtype=DTYPE) * WINDOW_SIZE
    # Shifted by one, `PAD_TOKEN` --> `0`, `START_TOKEN` --> `0`.
    _batch['shifted_abs_pos_ids'] = tf.cast(tf.math.minimum(tf.math.maximum(0, _batch['shifted_abs_pos'] + 1), MAX_HISTORY_LEN - 1), dtype=DTYPE) / tf.cast(MAX_HISTORY_LEN, dtype=DTYPE) * WINDOW_SIZE

    _batch['pos_ids'] = tf.math.cumsum(1 - pad_mask, axis=1)
    _batch['shifted_pos_ids'] = tf.math.maximum(0, _batch['pos_ids'] - 1)

    # ----------------------------------------
    # content_input_ids        
    
    content_input_ids = question_mask * question_id_to_input_id_table.lookup(content_ids) + \
        lecture_mask * lecture_id_to_input_id_table.lookup(content_ids) + \
        pad_mask * PAD_ID
    
    _batch['c_input_ids'] = content_input_ids
    
    # ----------------------------------------
    # response_input_ids - only used for encoder-only models
        
    # For `RESPONSE_LECTURE_ID`, we need to multiply by `(1 - pred_time_question_mask) * lecture_mask` instead of just `lecture_mask`.
    # Reason: we might have lectures occur during the prediction time. And these should be assigned to `RESPONSE_MASK_ID`.
    # If we only multiply by `lecture_mask`, we will get `RESPONSE_MASK_ID + RESPONSE_LECTURE_ID` which gives OOV error for embedding.
    
    response_input_ids_masked = pad_mask * PAD_ID + pred_time_question_mask * MASK_ID + lecture_mask * RESPONSE_LECTURE_ID + (1 - pred_time_question_mask) * response_false_mask * RESPONSE_FALSE_ID + (1 - pred_time_question_mask) * response_true_mask * RESPONSE_TRUE_ID
    _batch['r_input_ids'] = response_input_ids_masked

    # ----------------------------------------
    # d_input_ids

    shifted_answered_correctly = batch['shifted_answered_correctly']

    shifted_pad_mask = tf.cast((shifted_answered_correctly == PAD_TOKEN), dtype=tf.int32)
    shifted_start_mask = tf.cast((shifted_answered_correctly == START_TOKEN), dtype=tf.int32)
    shifted_masking_mask = tf.cast((shifted_answered_correctly == MASK_TOKEN), dtype=tf.int32)
    shifted_response_lecture_mask = tf.cast((shifted_answered_correctly == RESPONSE_LECTURE_TOKEN), dtype=tf.int32)
    shifted_response_false_mask = tf.cast((shifted_answered_correctly == RESPONSE_FALSE_TOKEN), dtype=tf.int32)
    shifted_response_true_mask = tf.cast((shifted_answered_correctly == RESPONSE_TRUE_TOKEN), dtype=tf.int32)

    ### Fixed
    #decoder_input_ids = shifted_pad_mask * PAD_ID + shifted_start_mask * START_ID + shifted_response_lecture_mask * RESPONSE_LECTURE_ID + shifted_response_false_mask * RESPONSE_FALSE_ID + shifted_response_true_mask * RESPONSE_TRUE_ID
    decoder_input_ids = shifted_pad_mask * PAD_ID + shifted_start_mask * START_ID + shifted_masking_mask * MASK_ID + shifted_response_lecture_mask * RESPONSE_LECTURE_ID + shifted_response_false_mask * RESPONSE_FALSE_ID + shifted_response_true_mask * RESPONSE_TRUE_ID

    # ----------------------------------------
    # d_ans_input_ids

    shifted_user_answer = batch['shifted_user_answer']
    shifted_answer_0_mask = tf.cast((shifted_user_answer == ANSWER_0_TOKEN), dtype=tf.int32)
    shifted_answer_1_mask = tf.cast((shifted_user_answer == ANSWER_1_TOKEN), dtype=tf.int32)
    shifted_answer_2_mask = tf.cast((shifted_user_answer == ANSWER_2_TOKEN), dtype=tf.int32)
    shifted_answer_3_mask = tf.cast((shifted_user_answer == ANSWER_3_TOKEN), dtype=tf.int32)

    decoder_answer_input_ids = shifted_pad_mask * PAD_ID + shifted_start_mask * START_ID + shifted_masking_mask * MASK_ID + \
        shifted_response_lecture_mask * ANSWER_LECTURE_ID + \
        shifted_answer_0_mask * ANSWER_0_ID + shifted_answer_1_mask * ANSWER_1_ID + \
        shifted_answer_2_mask * ANSWER_2_ID + shifted_answer_3_mask * ANSWER_3_ID

    # ----------------------------------------
    # post processing of decoder inputs

    if not generative:

        # This either contains exactly one place with `1`, or all places are `0`.
        pred_time_question_start_mask = tf.cast(tf.math.cumsum(pred_time_question_mask, axis=1) == 1, dtype=tf.int32)
        
        # shape = [batch_size]
        pred_time_question_start_value_response = tf.math.reduce_sum(pred_time_question_start_mask * decoder_input_ids, axis=1)
        pred_time_question_start_value_answer = tf.math.reduce_sum(pred_time_question_start_mask * decoder_answer_input_ids, axis=1)

        # If at the starting of questions, we get `MASK_ID`, we change it to `RESPONSE_LECTURE_ID` / `ANSWER_LECTURE_ID`
        prev_lecture_mask = tf.cast(pred_time_question_start_value_response == MASK_ID, tf.int32)            
        
        pred_time_question_start_value_response = RESPONSE_LECTURE_ID * prev_lecture_mask + pred_time_question_start_value_response * (1 - prev_lecture_mask)
        pred_time_question_start_value_answer = ANSWER_LECTURE_ID * prev_lecture_mask + pred_time_question_start_value_answer * (1 - prev_lecture_mask)

        # All places in prediction time share the values at the prediction question starting places.
        decoder_input_ids = decoder_input_ids * (1 - pred_time_question_mask) + pred_time_question_mask * pred_time_question_start_value_response[:, tf.newaxis]
        decoder_answer_input_ids = decoder_answer_input_ids * (1 - pred_time_question_mask) + pred_time_question_mask * pred_time_question_start_value_answer[:, tf.newaxis]

        # If there are remaining MASK_ID, its previous place is in prediction time, and we are sure it is a lecture at that moment.
        decoder_input_ids = decoder_input_ids * tf.cast(decoder_input_ids != MASK_ID, dtype=tf.int32) + RESPONSE_LECTURE_ID * tf.cast(decoder_input_ids == MASK_ID, dtype=tf.int32)
        decoder_answer_input_ids = decoder_answer_input_ids * tf.cast(decoder_answer_input_ids != MASK_ID, dtype=tf.int32) + ANSWER_LECTURE_ID * tf.cast(decoder_answer_input_ids == MASK_ID, dtype=tf.int32)

   # ----------------------------------------
   # add to `_batch`

    _batch['d_input_ids'] = decoder_input_ids
    _batch['d_ans_input_ids'] = decoder_answer_input_ids

    # ----------------------------------------
    # post processing `pos_ids` and `shifted_pos_ids`
    # Once the real decoder is implemented, we need to fix this.

    if not generative:

        pos_ids = _batch['pos_ids']
        shifted_pos_ids = _batch['shifted_pos_ids']
        abs_pos_ids = _batch['abs_pos_ids']
        shifted_abs_pos_ids = _batch['shifted_abs_pos_ids']

        pos_ids_question_start_value = tf.math.reduce_sum(pred_time_question_start_mask * pos_ids, axis=1)
        shifted_pos_ids_question_start_value = tf.math.reduce_sum(pred_time_question_start_mask * shifted_pos_ids, axis=1)

        pos_ids = pos_ids * (1 - pred_time_question_mask) + pred_time_question_mask * pos_ids_question_start_value[:, tf.newaxis]
        shifted_pos_ids = shifted_pos_ids * (1 - pred_time_question_mask) + pred_time_question_mask * shifted_pos_ids_question_start_value[:, tf.newaxis]

        abs_pos_ids_question_start_value = tf.math.reduce_sum(tf.cast(pred_time_question_start_mask, dtype=DTYPE) * abs_pos_ids, axis=1)
        shifted_abs_pos_ids_question_start_value = tf.math.reduce_sum(tf.cast(pred_time_question_start_mask, dtype=DTYPE) * shifted_abs_pos_ids, axis=1)

        abs_pos_ids = abs_pos_ids * tf.cast(1 - pred_time_question_mask, dtype=DTYPE) + tf.cast(pred_time_question_mask, dtype=DTYPE) * abs_pos_ids_question_start_value[:, tf.newaxis]
        shifted_abs_pos_ids = shifted_abs_pos_ids * tf.cast(1 - pred_time_question_mask, dtype=DTYPE) + tf.cast(pred_time_question_mask, dtype=DTYPE) * shifted_abs_pos_ids_question_start_value[:, tf.newaxis]

        _batch['pos_ids'] = pos_ids
        _batch['shifted_pos_ids'] = shifted_pos_ids
        _batch['abs_pos_ids'] = abs_pos_ids
        _batch['shifted_abs_pos_ids'] = shifted_abs_pos_ids        

    # ----------------------------------------
    # tags

    # The original tags are added by `1`
    _batch['tag_ids'] = tf.gather(params=c_inputs_ids_to_tags, indices=_batch['c_input_ids']) + 1 
        
    # ----------------------------------------        
    # part

    # The original parts are added by `1`
    _batch['part_ids'] = tf.gather(params=c_inputs_ids_to_part, indices=_batch['c_input_ids']) + 1

    # ----------------------------------------
    # prior explanation ids - just `prior_question_had_explanation` added by 1.
    # For PAD, we get `-1` but changed to `0`.
    # only used along with `d_input_ids`.

    _batch['prior_explanation_ids'] = tf.math.maximum(0, _batch['prior_question_had_explanation'] + 1)

    # ----------------------------------------
    # `prior_question_elapsed_time` in seconds
    # scaled to [0, 1]

    _batch['prior_question_elapsed_time_input'] = tf.cast(_batch['prior_question_elapsed_time'], dtype=DTYPE) / 1000.0 / 300.0

    # ----------------------------------------
    # normalized `lag_time`

    lag_time = tf.cast(_batch['lag_time'], dtype=DTYPE)
    # nb. of hours
    lag_time = lag_time / 1000.0 / 3600.0
    # If `lag_time` > `72 hours` --> set it to `72 hours`.
    lag_time = tf.math.minimum(lag_time, 72.0)

    _batch['lag_time'] = lag_time

    # ----------------------------------------
    # use `task_container_ids` as positional information

    task_container_pos_ids = tf.cast(_batch['task_container_id'], dtype=DTYPE) / MAX_TASK_CONTAINER_ID * 10.0
    _batch['task_container_pos_ids'] = task_container_pos_ids

    # ----------------------------------------
    # answer correctness target
    
    # `-2` means padding, `-1` means lecture
    answer_mask = tf.cast(batch['answered_correctly'] > -1, dtype=tf.int32)
    
    # negated values become `NON_TARGET_ID (-100)`
    _batch['target'] = batch['answered_correctly'] * answer_mask + (NON_TARGET_ID) * (1 - answer_mask)

    # ----------------------------------------
    # answer target
        
    # negated values become `NON_TARGET_ID (-100)`
    _batch['answer_target'] = batch['user_answer'] * answer_mask + (NON_TARGET_ID) * (1 - answer_mask)
    
    # ----------------------------------------
    # correct_answer_id  
        
    # Unlike `tag_ids` or `part_ids`, we don't need to have `+ 1` because `c_inputs_ids_to_correct_answer_id` is built in a slightly different way.
    _batch['correct_answer_id'] = tf.gather(params=c_inputs_ids_to_correct_answer_id, indices=_batch['c_input_ids'])
    
    # ----------------------------------------
    # nb_pred_places

    targets = _batch['target']

    # `targets` are defined for all places (other than [PAD] and lectures).
    # However, during validation, unlike during training, we only focus on the places that are in prediction time (and being questions).
    # This should be used only in `train_step` and `valid_step`, `train` and `valid`, but not in `run_pred`.
    if training == 0:
        targets = targets * pred_time_mask + NON_TARGET_ID * (1 - pred_time_mask)

    pred_mask = targets != NON_TARGET_ID
    nb_pred_places = tf.math.reduce_sum(tf.cast(pred_mask, dtype=tf.int32))

    # shape = [batch_size], but it is a constant
    _batch['nb_pred_places'] = nb_pred_places * tf.ones_like(_batch['user_id'], dtype=tf.int32)

    # ----------------------------------------------------------------------------------------------------
    # To process aggregated historical information
    # Be careful, the aggreated information doesn't contain the current place.
    # shape = [batch_size, seq_len]

    scaling_factor = tf.constant(N_AGGREGATED_QUESTION_SCALING_FACTOR, dtype=tf.float32)
    
    _batch['n_questions_answered_scaled'] = tf.cast(_batch['n_questions_answered'], dtype=tf.float32) / scaling_factor
    _batch['n_lectures_watched_scaled'] = tf.cast(_batch['n_lectures_watched'], dtype=tf.float32) / scaling_factor    
    
    # set minimum to `1` to avoid division by `0` error.
    n_questions_answered = tf.cast(tf.math.maximum(_batch['n_questions_answered'], 1), dtype=tf.float32)
            
    _batch['answered_correctly_ratio'] = tf.cast(_batch['n_questions_answered_correctly'], dtype=tf.float32) / n_questions_answered
    
    # ----------------------------------------------------------------------------------------------------    
    # To process aggregated historical part information
    # Be careful, the aggreated information doesn't contain the current place.
    # shape = [batch_size, seq_len]

    part_count_scaled = []
    part_correctness_ratio = []
    
    for part_idx in range(2, 9):
        
        key = f'part_{part_idx}_count'
        _part_count = tf.cast(_batch[key], dtype=tf.float32)
        # shape = [batch_size, seq_len]
        _part_count_scaled = _part_count / scaling_factor
        
        part_count_scaled.append(_part_count_scaled)
        
        key_2 = f'part_{part_idx}_correctness_count'
        _part_correctness_count = tf.cast(_batch[key_2], dtype=tf.float32)
        # shape = [batch_size, seq_len]
        _part_correctness_ratio = _part_correctness_count / tf.math.maximum(_part_count, 1.0)
        
        part_correctness_ratio.append(_part_correctness_ratio)
        
        # We don't want to keep these keys
        del _batch[key]
        del _batch[key_2]
        
    part_count_scaled = tf.stack(part_count_scaled, axis=-1)
    part_correctness_ratio = tf.stack(part_correctness_ratio, axis=-1)
    
    # shape = [batch_size, seq_len, PART_VOCAB_SIZE - 2]
    _batch['part_count_scaled'] = part_count_scaled
    _batch['part_correctness_ratio'] = part_correctness_ratio

    # ----------------------------------------------------------------------------------------------------  
    # To process `current` aggregated historical part information (i.e. not per part level)

    # shape = [batch_size, seq_len, PART_VOCAB_SIZE]
    _part_count_scaled = tf.concat(
        [
            tf.constant(PAD_TOKEN, dtype=tf.float32) * tf.ones_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis] / scaling_factor,
            tf.zeros_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis] / scaling_factor,
            part_count_scaled
        ],
        axis=-1
    )

    # shape = [batch_size, seq_len, PART_VOCAB_SIZE]
    _part_correctness_ratio = tf.concat(
        [
            tf.constant(PAD_TOKEN, dtype=tf.float32) * tf.ones_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis],
            tf.zeros_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis],
            part_correctness_ratio
        ],
        axis=-1
    )

    current_part_count_scaled = tf.gather(params=_part_count_scaled, indices=_batch['part_ids'], batch_dims=2)
    current_part_correctness_ratio = tf.gather(params=_part_correctness_ratio, indices=_batch['part_ids'], batch_dims=2)

    # shape = [batch_size, seq_len]
    _batch['current_part_count_scaled'] = current_part_count_scaled
    _batch['current_part_correctness_ratio'] = current_part_correctness_ratio
    
    # ----------------------------------------------------------------------------------------------------      
    # To process aggregated historical correct answer information
    # Be careful, the aggreated information doesn't contain the current place.
    # shape = [batch_size, seq_len]

    correct_answer_count_scaled = []
    correct_answer_correctness_ratio = []
    
    for correct_answer_idx in range(ANSWER_0_ID, ANSWER_3_ID + 1):
        
        key = f'correct_answer_{correct_answer_idx - ANSWER_0_ID}_count'
        _correct_answer_count = tf.cast(_batch[key], dtype=tf.float32)
        # shape = [batch_size, seq_len]
        _correct_answer_count_scaled = _correct_answer_count / scaling_factor
        
        correct_answer_count_scaled.append(_correct_answer_count_scaled)
        
        key_2 = f'correct_answer_{correct_answer_idx - ANSWER_0_ID}_answered_correctly_count'
        _correct_answer_answered_correctly_count = tf.cast(_batch[key_2], dtype=tf.float32)
        # shape = [batch_size, seq_len]
        _correct_answer_answered_correctly_ratio = _correct_answer_answered_correctly_count / tf.math.maximum(_correct_answer_count, 1.0)
        
        correct_answer_correctness_ratio.append(_correct_answer_answered_correctly_ratio)
        
        # We don't want to keep these keys
        del _batch[key]
        del _batch[key_2]
        
    correct_answer_count_scaled = tf.stack(correct_answer_count_scaled, axis=-1)
    correct_answer_correctness_ratio = tf.stack(correct_answer_correctness_ratio, axis=-1)
    
    # shape = [batch_size, seq_len, 7]
    _batch['correct_answer_count_scaled'] = correct_answer_count_scaled
    _batch['correct_answer_correctness_ratio'] = correct_answer_correctness_ratio
    
    # ----------------------------------------------------------------------------------------------------     
    # To process `current` aggregated historical correct answer information (i.e. not per correct answer level)

    # shape = [batch_size, seq_len, ANSWER_VOCAB_SIZE - 3]
    _correct_answer_count_scaled = tf.concat(
        [
            tf.constant(PAD_TOKEN, dtype=tf.float32) * tf.ones_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis] / scaling_factor,
            tf.zeros_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis] / scaling_factor,  # For lectures
            correct_answer_count_scaled
        ],
        axis=-1
    )

    # shape = [batch_size, seq_len, ANSWER_VOCAB_SIZE - 3]
    _correct_answer_correctness_ratio = tf.concat(
        [
            tf.constant(PAD_TOKEN, dtype=tf.float32) * tf.ones_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis],
            tf.zeros_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis],  # For lectures
            correct_answer_correctness_ratio
        ],
        axis=-1
    )

    # `PAD_ID` gives index `0` and `ANSWER_LECTURE_ID` gives index `1`.
    # `ANSWER_START_ID`, `ANSWER_END_ID` and `ANSWER_MASK_ID` doesn't exist in `correct_answer_id`.
    _indices = tf.math.maximum(_batch['correct_answer_id'] - ANSWER_LECTURE_ID + 1, 0)

    current_correct_answer_count_scaled = tf.gather(params=_correct_answer_count_scaled, indices=_indices, batch_dims=2)
    current_correct_answer_correctness_ratio = tf.gather(params=_correct_answer_correctness_ratio, indices=_indices, batch_dims=2)

    # shape = [batch_size, seq_len]
    _batch['current_correct_answer_count_scaled'] = current_correct_answer_count_scaled
    _batch['current_correct_answer_correctness_ratio'] = current_correct_answer_correctness_ratio    

    # ----------------------------------------------------------------------------------------------------
        
    _batch['current_question_count_scaled'] = tf.cast(_batch['n_prev_seen'], dtype=tf.float32) / scaling_factor

    # set minimum to `1` to avoid division by `0` error.
    current_question_count = tf.cast(tf.math.maximum(_batch['n_prev_seen'], 1), dtype=tf.float32)
            
    _batch['current_question_correctness_ratio'] = tf.cast(_batch['n_prev_correctness'], dtype=tf.float32) / current_question_count

    # ----------------------------------------------------------------------------------------------------
    
    return _batch


def prepare_training_dataset(batched_ds, generative=False, use_abs_pos=False, num_parallel_calls=None, deterministic=None):
    """Map `add_input_ids_and_targets` over a batched dataset with the training flag set.

    Args:
        batched_ds: dataset whose elements are 2-tuples; only the first component
            (the batch) is used, the second one is ignored.
        generative: forwarded unchanged to `add_input_ids_and_targets`.
        use_abs_pos: forwarded unchanged to `add_input_ids_and_targets`.
        num_parallel_calls: forwarded to `batched_ds.map`.
        deterministic: forwarded to `batched_ds.map`.

    Returns:
        The dataset returned by `batched_ds.map(...)`.
    """
    # `1` is passed as the `training` argument of `add_input_ids_and_targets`.
    training_flag = tf.constant(1, dtype=tf.int32)

    def _to_model_inputs(batch, _ignored):
        return add_input_ids_and_targets(batch, training_flag, generative, use_abs_pos)

    return batched_ds.map(
        _to_model_inputs,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic,
    )

In [None]:
def get_input_signatures_2():
    """Build the `input_signature` list used for `add_input_ids_and_targets_2`.

    Returns:
        A one-element list holding a dict that maps every expected batch key to a
        `tf.TensorSpec`. Per-user entries have shape `[None]`; per-place entries have
        shape `[None, 128]`. Each spec is named after its key.
    """
    batch_size = None
    seq_len = 128

    # One value per user: (key, dtype).
    per_user_fields = [
        ('user_id', tf.int64),
        ('seq_len', tf.int32),
        ('prev_seq_len', tf.int32),
        ('start', tf.int32),
        ('end', tf.int32),
    ]

    # One value per place in the sequence: (key, dtype).
    per_place_fields = [
        ('row_id', tf.int64),
        ('timestamp', tf.int64),
        ('content_id', tf.int32),
        ('content_type_id', tf.int32),
        ('task_container_id', tf.int32),
        ('user_answer', tf.int32),
        ('shifted_user_answer', tf.int32),
        ('answered_correctly', tf.int32),
        ('shifted_answered_correctly', tf.int32),
        ('prior_question_elapsed_time', DTYPE),
        ('prior_question_had_explanation', tf.int32),
        ('pred_time_mask', tf.int32),
        ('abs_pos', tf.int32),
        ('shifted_abs_pos', tf.int32),
        ('lag_time', tf.int64),
        ('n_questions_answered', tf.int32),
        ('n_lectures_watched', tf.int32),
        ('n_questions_answered_correctly', tf.int32),
        ('n_prev_seen', tf.int32),
        ('n_prev_correctness', tf.int32),
    ]

    # Aggregated counters: parts `2..8`, then correct answers `0..3`.
    per_place_fields += [(f'part_{part_idx}_count', tf.int32) for part_idx in range(2, 9)]
    per_place_fields += [(f'part_{part_idx}_correctness_count', tf.int32) for part_idx in range(2, 9)]
    per_place_fields += [(f'correct_answer_{answer_idx}_count', tf.int32) for answer_idx in range(4)]
    per_place_fields += [(f'correct_answer_{answer_idx}_answered_correctly_count', tf.int32) for answer_idx in range(4)]

    input_signatures = {}

    for field_name, field_dtype in per_user_fields:
        input_signatures[field_name] = tf.TensorSpec(shape=[batch_size], dtype=field_dtype, name=field_name)

    for field_name, field_dtype in per_place_fields:
        input_signatures[field_name] = tf.TensorSpec(shape=[batch_size, seq_len], dtype=field_dtype, name=field_name)

    return [input_signatures]

In [None]:
@tf.function(
    input_signature=get_input_signatures_2()
)
def add_input_ids_and_targets_2(batch):
    """Build model input ids, targets and scaled history features for one batch.

    This is the graph-compiled variant with a fixed input signature (see
    `get_input_signatures_2`). The `training` flag is fixed to `0` inside the function,
    so `nb_pred_places` only counts places whose `pred_time_mask` is set, and the
    decoder inputs / position ids are always post-processed so that every
    prediction-time question shares the values found at the first prediction-time
    question of its row.

    Args:
        batch: dict of tensors matching `get_input_signatures_2()`.

    Returns:
        A new dict (the input dict itself is not modified) containing:

        - every input key, except the raw `part_{2..8}_count`,
          `part_{2..8}_correctness_count`, `correct_answer_{0..3}_count` and
          `correct_answer_{0..3}_answered_correctly_count` keys, which are removed;
        - `lag_time` overwritten by its value in hours, capped at `72.0`;
        - position features: `pos_ids`, `shifted_pos_ids`, `abs_pos_ids`,
          `shifted_abs_pos_ids`, `task_container_pos_ids`;
        - id features: `c_input_ids`, `r_input_ids`, `d_input_ids`, `d_ans_input_ids`,
          `tag_ids`, `part_ids`, `prior_explanation_ids`, `correct_answer_id`;
        - `prior_question_elapsed_time_input`;
        - targets: `target`, `answer_target`, `nb_pred_places`;
        - scaled history features: `n_questions_answered_scaled`,
          `n_lectures_watched_scaled`, `answered_correctly_ratio`, `part_count_scaled`,
          `part_correctness_ratio`, `current_part_count_scaled`,
          `current_part_correctness_ratio`, `correct_answer_count_scaled`,
          `correct_answer_correctness_ratio`, `current_correct_answer_count_scaled`,
          `current_correct_answer_correctness_ratio`, `current_question_count_scaled`,
          `current_question_correctness_ratio`.
    """

    # Fixed to `0`: see the `nb_pred_places` section below.
    training = tf.constant(0, dtype=tf.int32)

    content_ids = batch['content_id']
    content_type_ids = batch['content_type_id']
    answered_correctly = batch['answered_correctly']

    question_mask = tf.cast((content_type_ids == 0), dtype=tf.int32)
    lecture_mask = tf.cast((content_type_ids == 1), dtype=tf.int32)

    # Padding places are detected through `content_type_id`.
    # NOTE(review): presumably `batch['timestamp'] == PAD_TOKEN` marks the same places
    # (a disabled sanity check used to compare both) - confirm.
    pad_mask = tf.cast((content_type_ids == PAD_TOKEN), dtype=tf.int32)

    # Replace `PAD_TOKEN` by `0`.
    pred_time_mask = batch['pred_time_mask'] * tf.cast(batch['pred_time_mask'] != PAD_TOKEN, dtype=tf.int32)
    pred_time_question_mask = pred_time_mask * question_mask

    # The number of questions in prediction time
    # shape = [batch_size]
    # NOTE(review): not read anywhere below; kept from the original as a (disabled) sanity-check hook.
    n_questions_in_pred_time = tf.math.reduce_sum(pred_time_question_mask, axis=1)

    response_false_mask = tf.cast((answered_correctly == RESPONSE_FALSE_TOKEN), dtype=tf.int32)
    response_true_mask = tf.cast((answered_correctly == RESPONSE_TRUE_TOKEN), dtype=tf.int32)

    # ----------------------------------------
    # Work on a shallow copy so that the caller's dict is left untouched.

    _batch = {}
    for k in batch:
        _batch[k] = batch[k]

    # Here, we use global variables, which is not good. But this is for fast experiments!!!
    # Shifted by one, `PAD_TOKEN` --> `0`
    _batch['abs_pos_ids'] = tf.cast(tf.math.minimum(tf.math.maximum(0, _batch['abs_pos'] + 1), MAX_HISTORY_LEN), dtype=DTYPE) / tf.cast(MAX_HISTORY_LEN, dtype=DTYPE) * WINDOW_SIZE
    # Shifted by one, `PAD_TOKEN` --> `0`, `START_TOKEN` --> `0`.
    _batch['shifted_abs_pos_ids'] = tf.cast(tf.math.minimum(tf.math.maximum(0, _batch['shifted_abs_pos'] + 1), MAX_HISTORY_LEN - 1), dtype=DTYPE) / tf.cast(MAX_HISTORY_LEN, dtype=DTYPE) * WINDOW_SIZE

    # Running count of non-padding places along the sequence axis.
    _batch['pos_ids'] = tf.math.cumsum(1 - pad_mask, axis=1)
    _batch['shifted_pos_ids'] = tf.math.maximum(0, _batch['pos_ids'] - 1)

    # ----------------------------------------
    # content_input_ids

    content_input_ids = question_mask * question_id_to_input_id_table.lookup(content_ids) + \
        lecture_mask * lecture_id_to_input_id_table.lookup(content_ids) + \
        pad_mask * PAD_ID

    _batch['c_input_ids'] = content_input_ids

    # ----------------------------------------
    # response_input_ids - only used for encoder-only models

    # `pred_time_question_mask` already contains `question_mask`, so it is `0` at lecture places.
    # The `MASK_ID` term and the `RESPONSE_LECTURE_ID` term therefore never hit the same place,
    # and multiplying by the plain `lecture_mask` is enough.

    response_input_ids_masked = pad_mask * PAD_ID + pred_time_question_mask * MASK_ID + lecture_mask * RESPONSE_LECTURE_ID + (1 - pred_time_question_mask) * response_false_mask * RESPONSE_FALSE_ID + (1 - pred_time_question_mask) * response_true_mask * RESPONSE_TRUE_ID
    _batch['r_input_ids'] = response_input_ids_masked

    # ----------------------------------------
    # d_input_ids

    shifted_answered_correctly = batch['shifted_answered_correctly']

    shifted_pad_mask = tf.cast((shifted_answered_correctly == PAD_TOKEN), dtype=tf.int32)
    shifted_start_mask = tf.cast((shifted_answered_correctly == START_TOKEN), dtype=tf.int32)
    shifted_masking_mask = tf.cast((shifted_answered_correctly == MASK_TOKEN), dtype=tf.int32)
    shifted_response_lecture_mask = tf.cast((shifted_answered_correctly == RESPONSE_LECTURE_TOKEN), dtype=tf.int32)
    shifted_response_false_mask = tf.cast((shifted_answered_correctly == RESPONSE_FALSE_TOKEN), dtype=tf.int32)
    shifted_response_true_mask = tf.cast((shifted_answered_correctly == RESPONSE_TRUE_TOKEN), dtype=tf.int32)

    # The `shifted_masking_mask * MASK_ID` term is required: masked places must map to `MASK_ID`.
    decoder_input_ids = shifted_pad_mask * PAD_ID + shifted_start_mask * START_ID + shifted_masking_mask * MASK_ID + shifted_response_lecture_mask * RESPONSE_LECTURE_ID + shifted_response_false_mask * RESPONSE_FALSE_ID + shifted_response_true_mask * RESPONSE_TRUE_ID

    # ----------------------------------------
    # d_ans_input_ids

    shifted_user_answer = batch['shifted_user_answer']
    shifted_answer_0_mask = tf.cast((shifted_user_answer == ANSWER_0_TOKEN), dtype=tf.int32)
    shifted_answer_1_mask = tf.cast((shifted_user_answer == ANSWER_1_TOKEN), dtype=tf.int32)
    shifted_answer_2_mask = tf.cast((shifted_user_answer == ANSWER_2_TOKEN), dtype=tf.int32)
    shifted_answer_3_mask = tf.cast((shifted_user_answer == ANSWER_3_TOKEN), dtype=tf.int32)

    decoder_answer_input_ids = shifted_pad_mask * PAD_ID + shifted_start_mask * START_ID + shifted_masking_mask * MASK_ID + \
        shifted_response_lecture_mask * ANSWER_LECTURE_ID + \
        shifted_answer_0_mask * ANSWER_0_ID + shifted_answer_1_mask * ANSWER_1_ID + \
        shifted_answer_2_mask * ANSWER_2_ID + shifted_answer_3_mask * ANSWER_3_ID

    # ----------------------------------------
    # post processing of decoder inputs (always the non-generative behaviour)

    # This either contains exactly one place with `1` (the first prediction-time question of the row),
    # or all places are `0`.
    # The running count stays at `1` on every place that follows the first prediction-time question
    # until a second one shows up (e.g. a lecture or a [PAD] place in between). Multiplying by
    # `pred_time_question_mask` drops those places, so the `reduce_sum` reads below pick up
    # exactly one value instead of a sum over several places.
    pred_time_question_start_mask = tf.cast(tf.math.cumsum(pred_time_question_mask, axis=1) == 1, dtype=tf.int32) * pred_time_question_mask

    # shape = [batch_size]
    pred_time_question_start_value_response = tf.math.reduce_sum(pred_time_question_start_mask * decoder_input_ids, axis=1)
    pred_time_question_start_value_answer = tf.math.reduce_sum(pred_time_question_start_mask * decoder_answer_input_ids, axis=1)

    # If at the starting of questions, we get `MASK_ID`, we change it to `RESPONSE_LECTURE_ID` / `ANSWER_LECTURE_ID`
    prev_lecture_mask = tf.cast(pred_time_question_start_value_response == MASK_ID, tf.int32)

    pred_time_question_start_value_response = RESPONSE_LECTURE_ID * prev_lecture_mask + pred_time_question_start_value_response * (1 - prev_lecture_mask)
    pred_time_question_start_value_answer = ANSWER_LECTURE_ID * prev_lecture_mask + pred_time_question_start_value_answer * (1 - prev_lecture_mask)

    # All places in prediction time share the values at the prediction question starting places.
    decoder_input_ids = decoder_input_ids * (1 - pred_time_question_mask) + pred_time_question_mask * pred_time_question_start_value_response[:, tf.newaxis]
    decoder_answer_input_ids = decoder_answer_input_ids * (1 - pred_time_question_mask) + pred_time_question_mask * pred_time_question_start_value_answer[:, tf.newaxis]

    # If there are remaining MASK_ID, its previous place is in prediction time, and we are sure it is a lecture at that moment.
    decoder_input_ids = decoder_input_ids * tf.cast(decoder_input_ids != MASK_ID, dtype=tf.int32) + RESPONSE_LECTURE_ID * tf.cast(decoder_input_ids == MASK_ID, dtype=tf.int32)
    decoder_answer_input_ids = decoder_answer_input_ids * tf.cast(decoder_answer_input_ids != MASK_ID, dtype=tf.int32) + ANSWER_LECTURE_ID * tf.cast(decoder_answer_input_ids == MASK_ID, dtype=tf.int32)

    # ----------------------------------------
    # add to `_batch`

    _batch['d_input_ids'] = decoder_input_ids
    _batch['d_ans_input_ids'] = decoder_answer_input_ids

    # ----------------------------------------
    # post processing `pos_ids` and `shifted_pos_ids` (always the non-generative behaviour)
    # Once the real decoder is implemented, we need to fix this.

    pos_ids = _batch['pos_ids']
    shifted_pos_ids = _batch['shifted_pos_ids']
    abs_pos_ids = _batch['abs_pos_ids']
    shifted_abs_pos_ids = _batch['shifted_abs_pos_ids']

    # Position values at the first prediction-time question of each row; shape = [batch_size]
    pos_ids_question_start_value = tf.math.reduce_sum(pred_time_question_start_mask * pos_ids, axis=1)
    shifted_pos_ids_question_start_value = tf.math.reduce_sum(pred_time_question_start_mask * shifted_pos_ids, axis=1)

    # Every prediction-time question gets the position of the first one.
    pos_ids = pos_ids * (1 - pred_time_question_mask) + pred_time_question_mask * pos_ids_question_start_value[:, tf.newaxis]
    shifted_pos_ids = shifted_pos_ids * (1 - pred_time_question_mask) + pred_time_question_mask * shifted_pos_ids_question_start_value[:, tf.newaxis]

    # Same treatment for the (floating point) absolute positions.
    abs_pos_ids_question_start_value = tf.math.reduce_sum(tf.cast(pred_time_question_start_mask, dtype=DTYPE) * abs_pos_ids, axis=1)
    shifted_abs_pos_ids_question_start_value = tf.math.reduce_sum(tf.cast(pred_time_question_start_mask, dtype=DTYPE) * shifted_abs_pos_ids, axis=1)

    abs_pos_ids = abs_pos_ids * tf.cast(1 - pred_time_question_mask, dtype=DTYPE) + tf.cast(pred_time_question_mask, dtype=DTYPE) * abs_pos_ids_question_start_value[:, tf.newaxis]
    shifted_abs_pos_ids = shifted_abs_pos_ids * tf.cast(1 - pred_time_question_mask, dtype=DTYPE) + tf.cast(pred_time_question_mask, dtype=DTYPE) * shifted_abs_pos_ids_question_start_value[:, tf.newaxis]

    _batch['pos_ids'] = pos_ids
    _batch['shifted_pos_ids'] = shifted_pos_ids
    _batch['abs_pos_ids'] = abs_pos_ids
    _batch['shifted_abs_pos_ids'] = shifted_abs_pos_ids

    # ----------------------------------------
    # tags

    # The original tags are added by `1`
    _batch['tag_ids'] = tf.gather(params=c_inputs_ids_to_tags, indices=_batch['c_input_ids']) + 1

    # ----------------------------------------
    # part

    # The original parts are added by `1`
    _batch['part_ids'] = tf.gather(params=c_inputs_ids_to_part, indices=_batch['c_input_ids']) + 1

    # ----------------------------------------
    # prior explanation ids - just `prior_question_had_explanation` added by 1.
    # For PAD, we get `-1` but changed to `0`.
    # only used along with `d_input_ids`.

    _batch['prior_explanation_ids'] = tf.math.maximum(0, _batch['prior_question_had_explanation'] + 1)

    # ----------------------------------------
    # `prior_question_elapsed_time` divided by `1000.0` and then by `300.0`
    # NOTE(review): presumably milliseconds -> seconds, then scaled to [0, 1] assuming a 300 s cap - confirm.

    _batch['prior_question_elapsed_time_input'] = tf.cast(_batch['prior_question_elapsed_time'], dtype=DTYPE) / 1000.0 / 300.0

    # ----------------------------------------
    # normalized `lag_time`

    lag_time = tf.cast(_batch['lag_time'], dtype=DTYPE)
    # nb. of hours
    lag_time = lag_time / 1000.0 / 3600.0
    # If `lag_time` > `72 hours` --> set it to `72 hours`.
    lag_time = tf.math.minimum(lag_time, 72.0)

    # Overwrites the raw `lag_time` entry of the copy.
    _batch['lag_time'] = lag_time

    # ----------------------------------------
    # use `task_container_ids` as positional information

    task_container_pos_ids = tf.cast(_batch['task_container_id'], dtype=DTYPE) / MAX_TASK_CONTAINER_ID * 10.0
    _batch['task_container_pos_ids'] = task_container_pos_ids

    # ----------------------------------------
    # answer correctness target

    # `-2` means padding, `-1` means lecture
    answer_mask = tf.cast(batch['answered_correctly'] > -1, dtype=tf.int32)

    # negated values become `NON_TARGET_ID (-100)`
    _batch['target'] = batch['answered_correctly'] * answer_mask + (NON_TARGET_ID) * (1 - answer_mask)

    # ----------------------------------------
    # answer target

    # negated values become `NON_TARGET_ID (-100)`
    _batch['answer_target'] = batch['user_answer'] * answer_mask + (NON_TARGET_ID) * (1 - answer_mask)

    # ----------------------------------------
    # correct_answer_id

    # Unlike `tag_ids` or `part_ids`, we don't need to have `+ 1` because `c_inputs_ids_to_correct_answer_id` is built in a slightly different way.
    _batch['correct_answer_id'] = tf.gather(params=c_inputs_ids_to_correct_answer_id, indices=_batch['c_input_ids'])

    # ----------------------------------------
    # nb_pred_places

    targets = _batch['target']

    # `targets` are defined for all places (other than [PAD] and lectures).
    # However, during validation, unlike during training, we only focus on the places that are in prediction time (and being questions).
    # This should be used only in `train_step` and `valid_step`, `train` and `valid`, but not in `run_pred`.
    # `training` is the constant `0` in this function, so this restriction is always applied.
    if training == 0:
        targets = targets * pred_time_mask + NON_TARGET_ID * (1 - pred_time_mask)

    pred_mask = targets != NON_TARGET_ID
    nb_pred_places = tf.math.reduce_sum(tf.cast(pred_mask, dtype=tf.int32))

    # shape = [batch_size], but it is a constant
    _batch['nb_pred_places'] = nb_pred_places * tf.ones_like(_batch['user_id'], dtype=tf.int32)

    # ----------------------------------------------------------------------------------------------------
    # To process aggregated historical information
    # Be careful, the aggregated information doesn't contain the current place.
    # shape = [batch_size, seq_len]

    scaling_factor = tf.constant(N_AGGREGATED_QUESTION_SCALING_FACTOR, dtype=tf.float32)

    _batch['n_questions_answered_scaled'] = tf.cast(_batch['n_questions_answered'], dtype=tf.float32) / scaling_factor
    _batch['n_lectures_watched_scaled'] = tf.cast(_batch['n_lectures_watched'], dtype=tf.float32) / scaling_factor

    # set minimum to `1` to avoid division by `0` error.
    n_questions_answered = tf.cast(tf.math.maximum(_batch['n_questions_answered'], 1), dtype=tf.float32)

    _batch['answered_correctly_ratio'] = tf.cast(_batch['n_questions_answered_correctly'], dtype=tf.float32) / n_questions_answered

    # ----------------------------------------------------------------------------------------------------
    # To process aggregated historical part information
    # Be careful, the aggregated information doesn't contain the current place.
    # shape = [batch_size, seq_len]

    part_count_scaled = []
    part_correctness_ratio = []

    for part_idx in range(2, 9):

        key = f'part_{part_idx}_count'
        _part_count = tf.cast(_batch[key], dtype=tf.float32)
        # shape = [batch_size, seq_len]
        _part_count_scaled = _part_count / scaling_factor

        part_count_scaled.append(_part_count_scaled)

        key_2 = f'part_{part_idx}_correctness_count'
        _part_correctness_count = tf.cast(_batch[key_2], dtype=tf.float32)
        # shape = [batch_size, seq_len]
        # The denominator is floored at `1.0` to avoid a division by `0`.
        _part_correctness_ratio = _part_correctness_count / tf.math.maximum(_part_count, 1.0)

        part_correctness_ratio.append(_part_correctness_ratio)

        # We don't want to keep these keys
        del _batch[key]
        del _batch[key_2]

    part_count_scaled = tf.stack(part_count_scaled, axis=-1)
    part_correctness_ratio = tf.stack(part_correctness_ratio, axis=-1)

    # shape = [batch_size, seq_len, PART_VOCAB_SIZE - 2]
    _batch['part_count_scaled'] = part_count_scaled
    _batch['part_correctness_ratio'] = part_correctness_ratio

    # ----------------------------------------------------------------------------------------------------
    # To process `current` aggregated historical part information (i.e. not per part level)

    # Two extra leading slots are prepended so that `part_ids` can be used directly as the gather index:
    # slot `0` holds the `PAD_TOKEN` based filler, slot `1` holds `0`, the remaining slots hold parts `2..8`.

    # shape = [batch_size, seq_len, PART_VOCAB_SIZE]
    _part_count_scaled = tf.concat(
        [
            tf.constant(PAD_TOKEN, dtype=tf.float32) * tf.ones_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis] / scaling_factor,
            tf.zeros_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis] / scaling_factor,
            part_count_scaled
        ],
        axis=-1
    )

    # shape = [batch_size, seq_len, PART_VOCAB_SIZE]
    _part_correctness_ratio = tf.concat(
        [
            tf.constant(PAD_TOKEN, dtype=tf.float32) * tf.ones_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis],
            tf.zeros_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis],
            part_correctness_ratio
        ],
        axis=-1
    )

    # For every place, pick the slot that belongs to the part of the current content.
    current_part_count_scaled = tf.gather(params=_part_count_scaled, indices=_batch['part_ids'], batch_dims=2)
    current_part_correctness_ratio = tf.gather(params=_part_correctness_ratio, indices=_batch['part_ids'], batch_dims=2)

    # shape = [batch_size, seq_len]
    _batch['current_part_count_scaled'] = current_part_count_scaled
    _batch['current_part_correctness_ratio'] = current_part_correctness_ratio

    # ----------------------------------------------------------------------------------------------------
    # To process aggregated historical correct answer information
    # Be careful, the aggregated information doesn't contain the current place.
    # shape = [batch_size, seq_len]

    correct_answer_count_scaled = []
    correct_answer_correctness_ratio = []

    for correct_answer_idx in range(ANSWER_0_ID, ANSWER_3_ID + 1):

        key = f'correct_answer_{correct_answer_idx - ANSWER_0_ID}_count'
        _correct_answer_count = tf.cast(_batch[key], dtype=tf.float32)
        # shape = [batch_size, seq_len]
        _correct_answer_count_scaled = _correct_answer_count / scaling_factor

        correct_answer_count_scaled.append(_correct_answer_count_scaled)

        key_2 = f'correct_answer_{correct_answer_idx - ANSWER_0_ID}_answered_correctly_count'
        _correct_answer_answered_correctly_count = tf.cast(_batch[key_2], dtype=tf.float32)
        # shape = [batch_size, seq_len]
        # The denominator is floored at `1.0` to avoid a division by `0`.
        _correct_answer_answered_correctly_ratio = _correct_answer_answered_correctly_count / tf.math.maximum(_correct_answer_count, 1.0)

        correct_answer_correctness_ratio.append(_correct_answer_answered_correctly_ratio)

        # We don't want to keep these keys
        del _batch[key]
        del _batch[key_2]

    correct_answer_count_scaled = tf.stack(correct_answer_count_scaled, axis=-1)
    correct_answer_correctness_ratio = tf.stack(correct_answer_correctness_ratio, axis=-1)

    # shape = [batch_size, seq_len, 4] - one slot per loop iteration above
    # (the input signature only provides the keys for correct answers `0..3`).
    _batch['correct_answer_count_scaled'] = correct_answer_count_scaled
    _batch['correct_answer_correctness_ratio'] = correct_answer_correctness_ratio

    # ----------------------------------------------------------------------------------------------------
    # To process `current` aggregated historical correct answer information (i.e. not per correct answer level)

    # shape = [batch_size, seq_len, ANSWER_VOCAB_SIZE - 3]
    _correct_answer_count_scaled = tf.concat(
        [
            tf.constant(PAD_TOKEN, dtype=tf.float32) * tf.ones_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis] / scaling_factor,
            tf.zeros_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis] / scaling_factor,  # For lectures
            correct_answer_count_scaled
        ],
        axis=-1
    )

    # shape = [batch_size, seq_len, ANSWER_VOCAB_SIZE - 3]
    _correct_answer_correctness_ratio = tf.concat(
        [
            tf.constant(PAD_TOKEN, dtype=tf.float32) * tf.ones_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis],
            tf.zeros_like(_batch['content_id'], dtype=tf.float32)[:, :, tf.newaxis],  # For lectures
            correct_answer_correctness_ratio
        ],
        axis=-1
    )

    # `PAD_ID` gives index `0` and `ANSWER_LECTURE_ID` gives index `1`.
    # `ANSWER_START_ID`, `ANSWER_END_ID` and `ANSWER_MASK_ID` don't exist in `correct_answer_id`.
    _indices = tf.math.maximum(_batch['correct_answer_id'] - ANSWER_LECTURE_ID + 1, 0)

    current_correct_answer_count_scaled = tf.gather(params=_correct_answer_count_scaled, indices=_indices, batch_dims=2)
    current_correct_answer_correctness_ratio = tf.gather(params=_correct_answer_correctness_ratio, indices=_indices, batch_dims=2)

    # shape = [batch_size, seq_len]
    _batch['current_correct_answer_count_scaled'] = current_correct_answer_count_scaled
    _batch['current_correct_answer_correctness_ratio'] = current_correct_answer_correctness_ratio

    # ----------------------------------------------------------------------------------------------------
    # History of the current question itself (`n_prev_seen` / `n_prev_correctness`)

    _batch['current_question_count_scaled'] = tf.cast(_batch['n_prev_seen'], dtype=tf.float32) / scaling_factor

    # set minimum to `1` to avoid division by `0` error.
    current_question_count = tf.cast(tf.math.maximum(_batch['n_prev_seen'], 1), dtype=tf.float32)

    _batch['current_question_correctness_ratio'] = tf.cast(_batch['n_prev_correctness'], dtype=tf.float32) / current_question_count

    # ----------------------------------------------------------------------------------------------------

    return _batch

In [None]:
 def add_aggregated_historical_information(parsed):
    """Attach "seen so far" counters to one parsed (un-batched) example.

    For validation dataset, this must be performed after `remove_future_valid_blocks` is applied (so we have the actual `pred_time_mask`).
    For training dataset, the `pred_time_mask` is (and should be) all `0`, so every place is considered.

    Every counter excludes the current place: it starts at `0` and, at each place,
    holds the total over all earlier places. Question based counters skip the places
    whose `pred_time_mask` is set; the lecture counter does not.

    Keys written into `parsed` (each of shape `[seq_len]`):

        - `n_questions_answered`, `n_questions_answered_correctly`, `n_lectures_watched`
        - `part_{2..8}_count`, `part_{2..8}_correctness_count`
        - `correct_answer_{k}_count`, `correct_answer_{k}_answered_correctly_count`,
          with `k` running over `ANSWER_0_ID..ANSWER_3_ID` minus the `ANSWER_0_ID` offset.

    Returns:
        The same `parsed` dict, updated in place.
    """

    def _total_before(flags):
        # `0` for the first place, then the running sum of `flags` without its last element.
        running_sum = tf.math.cumsum(flags[:-1], axis=-1)
        return tf.concat([tf.constant([0], dtype=tf.int32), running_sum], axis=-1)

    # ----------------------------------------------------------------------------------------------------
    # Masks, all of shape = [seq_len]

    is_question = tf.cast(parsed['content_type_id'] == 0, dtype=tf.int32)
    is_lecture = tf.cast(parsed['content_type_id'] == 1, dtype=tf.int32)
    is_correct = tf.cast(parsed['answered_correctly'] == 1, dtype=tf.int32)
    content_ids = parsed['content_id']

    outside_pred_time = 1 - parsed['pred_time_mask']
    countable_question = is_question * outside_pred_time
    countable_correct_question = countable_question * is_correct

    # ----------------------------------------------------------------------------------------------------
    # Overall counters

    parsed['n_questions_answered'] = _total_before(countable_question)
    parsed['n_questions_answered_correctly'] = _total_before(countable_correct_question)
    parsed['n_lectures_watched'] = _total_before(is_lecture)

    # ----------------------------------------------------------------------------------------------------
    # Counters per part

    content_input_ids = (
        is_question * question_id_to_input_id_table.lookup(content_ids)
        + is_lecture * lecture_id_to_input_id_table.lookup(content_ids)
    )

    # The original parts are added by `1`
    part_ids = tf.gather(params=c_inputs_ids_to_part, indices=content_input_ids) + 1

    for part_idx in range(2, 9):

        # Only questions outside of prediction time are counted.
        in_part = tf.cast(part_ids == part_idx, dtype=tf.int32) * countable_question
        in_part_and_correct = in_part * countable_correct_question

        parsed[f'part_{part_idx}_count'] = _total_before(in_part)
        parsed[f'part_{part_idx}_correctness_count'] = _total_before(in_part_and_correct)

    # ----------------------------------------------------------------------------------------------------
    # Counters per correct answer

    # The ids carry an offset of `ANSWER_0_ID`.
    correct_answer_id = tf.gather(params=c_inputs_ids_to_correct_answer_id, indices=content_input_ids)

    for correct_answer_idx in range(ANSWER_0_ID, ANSWER_3_ID + 1):

        answer_no = correct_answer_idx - ANSWER_0_ID

        # Only questions outside of prediction time are counted.
        has_answer = tf.cast(correct_answer_id == correct_answer_idx, dtype=tf.int32) * countable_question
        has_answer_and_correct = has_answer * countable_correct_question

        parsed[f'correct_answer_{answer_no}_count'] = _total_before(has_answer)
        parsed[f'correct_answer_{answer_no}_answered_correctly_count'] = _total_before(has_answer_and_correct)

    return parsed

#### For validation dataset

In [None]:
def add_valid_block_info(valid_raw_example, num_parallel_calls=None, deterministic=None):
    """Expand one raw validation example into one example per validation block.

    Every element of the returned dataset is a copy of `valid_raw_example` with:

        - `valid_block_pos`: `valid_blocks_start_pos` and `valid_blocks_end_pos`
          stacked along axis 1 (assumes both are 1-D, giving one (start, end) row
          per block - TODO confirm against the parser).
        - the per-block entries written by `add_valid_block_idx()`.

    The example is repeated `n_valid_blocks` times and paired with the block
    indices `0 .. n_valid_blocks - 1` (as `tf.int32`).

    Args:
        valid_raw_example: dict of tensors; must contain `n_valid_blocks`,
            `valid_blocks_start_pos` and `valid_blocks_end_pos`.
        num_parallel_calls: forwarded to `tf.data.Dataset.map`.
        deterministic: forwarded to `tf.data.Dataset.map`.

    Returns:
        A `tf.data.Dataset` with one element per validation block.
    """

    # Shallow copy so that the extra key is not written into the caller's dict.
    example = {name: valid_raw_example[name] for name in valid_raw_example}
    example['valid_block_pos'] = tf.stack(
        [valid_raw_example['valid_blocks_start_pos'], valid_raw_example['valid_blocks_end_pos']],
        axis=1
    )

    block_count = tf.cast(example['n_valid_blocks'], dtype=tf.int64)

    repeated_examples = tf.data.Dataset.from_tensors(example).repeat(block_count)
    block_indices = tf.data.Dataset.range(block_count, output_type=tf.int32)

    # Zipping yields (example, block_idx) pairs, which `map` unpacks into the
    # two positional arguments of `add_valid_block_idx`.
    paired = tf.data.Dataset.zip((repeated_examples, block_indices))

    return paired.map(
        add_valid_block_idx,
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic
    )


def add_valid_block_idx(valid_example, valid_block_idx):
    """Tag an example with the validation block it stands for.

    Three entries are written into `valid_example` (the dict is modified in
    place and the same object is returned):

        - `valid_block_idx`: the index of the block among the user's validation blocks.
        - `valid_start`: element 0 of row `valid_block_idx` of `valid_block_pos`.
        - `valid_end`: element 1 of row `valid_block_idx` of `valid_block_pos`.

    Nothing is truncated here; `prepare_validation_dataset()` applies
    `remove_future_valid_blocks()` as a separate step afterwards.
    """

    # Row of `valid_block_pos` describing the selected block: (start, end).
    block_bounds = valid_example['valid_block_pos'][valid_block_idx]
    block_start, block_end = block_bounds[0], block_bounds[1]

    valid_example['valid_block_idx'] = valid_block_idx
    valid_example['valid_start'] = block_start
    valid_example['valid_end'] = block_end

    return valid_example


def remove_future_valid_blocks(valid_example):
    """Truncate a user's history right after the current validation block.

    Builds a new dict in which:

        - `seq_len` becomes `valid_end + 1` and the previous value is kept as `prev_seq_len`.
        - `start` is the int32 constant 0 and `end` is `valid_end`.
        - validation bookkeeping entries are copied unchanged.
        - `pred_time_mask` keeps its values before `valid_start` and is set to 1
          on the positions `valid_start .. valid_end`.
        - every other entry is sliced to its first `valid_end + 1` elements.

    Args:
        valid_example: dict of tensors produced by `add_valid_block_idx()`.

    Returns:
        A new dict; `valid_example` itself is not modified.
    """

    # Entries that are recomputed below and therefore never copied from the input.
    recomputed_keys = ['user_id', 'seq_len', 'prev_seq_len', 'start', 'end']
    # Entries holding a single value, or whose values do not need to be truncated.
    passthrough_keys = [
        'n_valid_blocks', 'valid_blocks_start_pos', 'valid_blocks_end_pos',
        'valid_block_pos', 'valid_block_idx', 'valid_start', 'valid_end'
    ]

    block_start = valid_example['valid_start']
    block_end = valid_example['valid_end']
    truncated_len = block_end + 1

    example = {
        'user_id': valid_example['user_id'],
        'seq_len': truncated_len,
        'prev_seq_len': valid_example['seq_len'],
        'start': tf.constant(0, dtype=tf.int32),
        'end': block_end,
    }

    for name, value in valid_example.items():

        if name in recomputed_keys:
            continue

        if name == 'pred_time_mask':
            # Keep the mask before the block and mark the whole current
            # validation block with ones; the result has length `valid_end + 1`.
            n_valid_interactions = block_end - block_start + 1
            example[name] = tf.concat(
                [value[:block_start], tf.ones(shape=[n_valid_interactions], dtype=tf.int32)],
                axis=0
            )
        elif name in passthrough_keys:
            example[name] = value
        else:
            # Per-interaction attribute: drop everything after the current block.
            example[name] = value[0:truncated_len]

    return example


def extract_ending_subseqs(raw_batch, window_size):
    """Extract one window per sequence, anchored at the sequence's last index.

    The anchor of each sequence is `seq_len - 1`; the extraction itself is
    delegated to `extract_subseqs_from_raw_batch()`.

    Args:
        raw_batch: batch dict containing at least `seq_len`.
        window_size: forwarded to `extract_subseqs_from_raw_batch()`.

    Returns:
        Whatever `extract_subseqs_from_raw_batch()` returns for these anchors.
    """

    return extract_subseqs_from_raw_batch(raw_batch, raw_batch['seq_len'] - 1, window_size)
    

def prepare_validation_dataset(valid_raw_ds, batch_size=3, window_size=5, generative=False, use_abs_pos=False, seed=None, num_parallel_calls=None, deterministic=None):
    """Turn the raw validation dataset into model-ready validation batches.

    Pipeline, in order:

        1. `add_valid_block_info`: one example per validation block.
        2. `remove_future_valid_blocks`: truncate each example after its block.
        3. `add_aggregated_historical_information`: must run after step 2.
        4. ragged batching of `batch_size` examples.
        5. `extract_ending_subseqs`: window of `window_size` anchored at the end of each sequence.
        6. `add_input_ids_and_targets` with `training` fixed to the int32 constant 0.

    Args:
        valid_raw_ds: dataset of parsed validation examples.
        batch_size: number of examples per ragged batch.
        window_size: forwarded to `extract_ending_subseqs`.
        generative: forwarded to `add_input_ids_and_targets`.
        use_abs_pos: forwarded to `add_input_ids_and_targets`.
        seed: accepted but not used by this function.
        num_parallel_calls: forwarded to every `map` call.
        deterministic: forwarded to every `map` call.

    Returns:
        The resulting `tf.data.Dataset`.

    Note:
        `drop_remainder` depends on the globals `IS_KAGGLE` and `tpu`, which must
        exist when this function is called.
    """

    per_block_ds = valid_raw_ds.flat_map(
        lambda valid_raw_example: add_valid_block_info(
            valid_raw_example, num_parallel_calls=num_parallel_calls, deterministic=deterministic
        )
    )
    per_block_ds = per_block_ds.map(
        remove_future_valid_blocks, num_parallel_calls=num_parallel_calls, deterministic=deterministic
    )

    # This must be performed after `remove_future_valid_blocks`.
    per_block_ds = per_block_ds.map(
        add_aggregated_historical_information, num_parallel_calls=num_parallel_calls, deterministic=deterministic
    )

    # Attributes have different lengths across examples, hence ragged batching.
    to_ragged_batches = tf.data.experimental.dense_to_ragged_batch(
        batch_size=batch_size, drop_remainder=(IS_KAGGLE and tpu is not None)
    )
    ragged_batches = per_block_ds.apply(to_ragged_batches)

    # Ragged batch -> windows anchored at the end of each sequence.
    windowed_batches = ragged_batches.map(
        lambda raw_batch: extract_ending_subseqs(raw_batch, window_size),
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic
    )

    training = tf.constant(0, dtype=tf.int32)

    return windowed_batches.map(
        lambda batch: add_input_ids_and_targets(batch, training, generative, use_abs_pos),
        num_parallel_calls=num_parallel_calls,
        deterministic=deterministic
    )

### check

In [None]:
# Smoke check of the training input pipeline (skipped on Kaggle).
# Every step runs with `num_parallel_calls=1` / `deterministic=True` where the
# called function accepts them. The names bound here (`train_ds`, ...) are
# reused by the following check cells.
if not IS_KAGGLE:

    # Read the serialized training examples and parse them.
    train_raw_ds = tf.data.TFRecordDataset(train_tfrec_paths, num_parallel_reads=1)
    train_raw_ds = train_raw_ds.map(parse_train_example, num_parallel_calls=1, deterministic=True)

    # Add aggregated historical information to every example.
    train_raw_ds = train_raw_ds.map(lambda example: add_aggregated_historical_information(example))

    # Get the split training dataset, using the first split-index file only.
    train_valid_split_indices = load_data(train_valid_split_indices_paths[0])
    train_valid_split_table = convert_split_index_dict(train_valid_split_indices)
    reduced_raw_ds = split_train_ds(train_raw_ds, train_valid_split_table, num_parallel_calls=1, deterministic=True)

    # Batch examples whose attributes have different lengths across examples - tf.RaggedTensor
    batched_raw_ds = reduced_raw_ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=5))

    # Batch - tf.Tensor: extract random subsequences of a fixed length
    batched_ds = random_subseqs_from_batched_raw_ds(batched_raw_ds, window_size=10, seed=1, num_parallel_calls=1, deterministic=True)

    # Add input ids and targets for training
    train_ds = prepare_training_dataset(batched_ds, use_abs_pos=False, num_parallel_calls=1, deterministic=True)

In [None]:
# Create an iterator over the first two training batches and pull the first one.
# NOTE(review): `next(it)` is nested inside the `if` block, so the notebook does
# not render its value; the batch is fetched and discarded. The next cell reads
# the following batch from the same iterator `it`.
if not IS_KAGGLE:

    it = iter(train_ds.take(2))
    next(it)

In [None]:
if not IS_KAGGLE:

    # A bare expression nested inside an `if` block is not auto-displayed by
    # the notebook (only a top-level trailing expression is), so the mask is
    # printed explicitly. This pulls the next batch from the iterator `it`
    # created in the previous cell.
    print(next(it)['pred_time_mask'])

In [None]:
# Smoke check of the validation input pipeline (skipped on Kaggle): print the
# shape of every attribute of one parsed raw example, then of one prepared
# validation batch. Binds the global `valid_ds` (and leaves `x` / `k` behind).
if not IS_KAGGLE:

    # Read and parse the validation records, including the validation-block info.
    valid_raw_ds = tf.data.TFRecordDataset([valid_tfrec_paths], num_parallel_reads=1)
    valid_raw_ds = valid_raw_ds.map(parse_train_example_with_valid_info, num_parallel_calls=1, deterministic=True)

    # Shapes of one raw (unbatched) validation example.
    for x in valid_raw_ds.take(1):
        # print(x)
        for k in x:
            print(f'{k} : shape = {x[k].shape}')
        print('--------')

    print('============================================')

    valid_ds = prepare_validation_dataset(valid_raw_ds, batch_size=5, window_size=10, use_abs_pos=False)

    # Shapes of one prepared validation batch.
    for x in valid_ds.take(1):
        # print(x)
        for k in x:
            print(f'{k} : shape = {x[k].shape}')

        print('--------')

In [None]:
# it = iter(valid_ds.take(2))
# next(it)

## Model definition

In [None]:
def get_initializer(seed):
    """Return a new Glorot-uniform initializer seeded with `seed`.

    Args:
        seed: seed passed to `tf.keras.initializers.GlorotUniform`.

    Returns:
        A `tf.keras.initializers.GlorotUniform` instance.
    """

    glorot_uniform = tf.keras.initializers.GlorotUniform(seed=seed)

    return glorot_uniform


def gelu(x):
    """
    Exact Gaussian Error Linear Unit: `x * Phi(x)`, with the standard normal CDF
    `Phi` evaluated through `erf`. This follows the original implementation in the
    Google BERT repository. OpenAI GPT uses the slightly different tanh
    approximation (see `gelu_new`). Paper: https://arxiv.org/abs/1606.08415

    Args:
        x: value convertible to a float tensor.
    Returns:
        `x` with the GELU activation applied, in the dtype of `x`.
    """
    x = tf.convert_to_tensor(x)
    # sqrt(2) in the dtype of `x`, so that the division does not mix dtypes.
    sqrt_two = tf.cast(tf.math.sqrt(2.0), dtype=x.dtype)
    gaussian_cdf = 0.5 * (1.0 + tf.math.erf(x / sqrt_two))

    return x * gaussian_cdf


def gelu_new(x):
    """
    Gaussian Error Linear Unit, tanh approximation:
    `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation
    Returns:
        `x` with the GELU activation applied.
    """
    x = tf.convert_to_tensor(x)
    pi = tf.cast(math.pi, x.dtype)
    cubic_coeff = tf.cast(0.044715, x.dtype)
    tanh_arg = tf.sqrt(2.0 / pi) * (x + cubic_coeff * tf.pow(x, 3))
    approx_cdf = 0.5 * (1.0 + tf.tanh(tanh_arg))

    return x * approx_cdf


def gelu_fast(x):
    """
    Gaussian Error Linear Unit, tanh approximation in factored form:
    `0.5 * x * (1 + tanh(sqrt(2 / pi) * x * (1 + 0.044715 * x^2)))`.

    This is algebraically the same function as `gelu_new`, with `sqrt(2 / pi)`
    replaced by the precomputed constant 0.7978845608.

    Args:
        x: value convertible to a float tensor.
    Returns:
        `x` with the GELU activation applied, in the dtype of `x`.
    """
    x = tf.convert_to_tensor(x)
    # sqrt(2 / pi); it scales the whole tanh argument.
    sqrt_2_over_pi = tf.cast(0.7978845608, x.dtype)
    # Coefficient of the cubic term of the approximation.
    cubic_coeff = tf.cast(0.044715, x.dtype)

    return 0.5 * x * (1.0 + tf.tanh(x * sqrt_2_over_pi * (1.0 + cubic_coeff * x * x)))


# Registry mapping an activation name to a callable, used by `get_tf_activation`.
# The three GELU variants defined above are wrapped in `tf.keras.layers.Activation`;
# the other entries are the Keras activation functions themselves.
# "swish" and "silu" are two names for the same function.
ACT2FN = {
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.activations.swish,
    "silu": tf.keras.activations.swish,
    "gelu_new": tf.keras.layers.Activation(gelu_new),
    "tanh": tf.keras.activations.tanh,
    "gelu_fast": tf.keras.layers.Activation(gelu_fast),
}


def get_tf_activation(activation_string):
    """Look up an activation callable by name in `ACT2FN`.

    Args:
        activation_string: one of the keys of `ACT2FN`.

    Returns:
        The callable registered under `activation_string`.

    Raises:
        KeyError: if `activation_string` is not a key of `ACT2FN`; the message
            lists the available names.
    """
    if activation_string not in ACT2FN:
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))

    return ACT2FN[activation_string]

In [None]:
class EdFormerConfig(PretrainedConfig):
    """Configuration object for the EdFormer model defined in this notebook.

    Extends `PretrainedConfig` and stores, as plain instance attributes, the
    vocabulary sizes, the feature switches (`use_*` flags) and the transformer
    hyper-parameters given to the constructor. Most defaults come from
    module-level constants defined elsewhere in the notebook; they are read
    once, when this class definition is executed.
    """

    def __init__(
        self,
        model_type,
        model_desc,
        model_size='none',
        content_vocab_size=CONTENT_VOCAB_SIZE,
        response_vocab_size=RESPONSE_VOCAB_SIZE,
        tag_vocab_size=TAG_VOCAB_SIZE,
        part_vocab_size=PART_VOCAB_SIZE,
        prior_explanation_vocab_size=PRIOR_EXPLANATION_VOCAB_SIZE,
        max_position_embeddings=WINDOW_SIZE + 1,
        sinusoidal_pos_embds=False,
        n_layers=4,
        n_heads=8,
        dim=512,
        hidden_dim=4 * 512,
        activation=ACTIVATION,        
        dropout=0.1,
        attention_dropout=0.1,
        seq2seq_dropout=0.1,
        initializer_range=0.02,
        seed=SEED,        
        pad_token_id=PAD_ID,
        use_user_answer=USE_USER_ANSWER,
        use_user_answer_loss=USE_USER_ANSWER_LOSS,
        use_correct_answer_for_encoder=USE_CORRECT_ANSWER_FOR_ENCODER,
        use_correct_answer_for_decoder=USE_CORRECT_ANSWER_FOR_DECODER,
        use_abs_pos=USE_ABS_POS,
        use_task_container_pos=USE_TASK_CONTAINER_POS,
        share_position_embeddings=SHARE_POS_EMBEDDING,
        use_tags=USE_TAGS,
        use_part=USE_PART,
        use_prior_explanation=USE_PRIOR_EXPLANATION,
        use_prior_question_elapsed_time_input=USE_PRIOR_QUESTION_ELAPSED_TIME_INPUT,
        use_lag_time=USE_LAG_TIME,
        use_lag_time_for_encoder=USE_LAG_TIME_FOR_ENCODER,
        use_user_level_aggregated_historical_info=USE_USER_LEVEL_AGGREGATED_HISTORICAL_INFO,
        use_part_aggregated_historical_info=USE_PART_AGGREGATED_HISTORICAL_INFO,
        use_correct_answer_aggregated_historical_info=USE_CORRECT_ANSWER_AGGREGATED_HISTORICAL_INFO,
        use_question_level_aggregated_historical_info=USE_QUESTION_LEVEL_AGGREGATED_HISTORICAL_INFO,
        allow_bundle_atten=ALLOW_BUNDLE_ATTEN,
        generative=GENERATIVE,
        use_pre_classifier=USE_PRE_CLASSIFIER,
        use_softmax=USE_SOFTMAX,
        **kwargs
    ):
        """Store every argument on the instance.

        `pad_token_id` and any extra `**kwargs` are forwarded to
        `PretrainedConfig.__init__`; every other argument is assigned to the
        attribute of the same name.

        `model_type` and `model_desc` are required; all other arguments have
        defaults.
        """
        super().__init__(**kwargs, pad_token_id=pad_token_id)
        
        # NOTE: `toJSON` emits attributes in assignment order (`sort_keys=False`),
        # so reordering the assignments below changes its output.

        # Model identification.
        self.model_type = model_type
        self.model_size = model_size
        self.model_desc = model_desc

        # Vocabulary sizes of the two main input streams.
        self.content_vocab_size = content_vocab_size
        self.response_vocab_size = response_vocab_size

        # Vocabulary sizes of the auxiliary categorical inputs.
        self.tag_vocab_size = tag_vocab_size
        self.part_vocab_size = part_vocab_size
        self.prior_explanation_vocab_size = prior_explanation_vocab_size

        # Feature switches.
        self.use_user_answer = use_user_answer
        self.use_correct_answer_for_encoder = use_correct_answer_for_encoder
        self.use_correct_answer_for_decoder = use_correct_answer_for_decoder
        self.use_user_answer_loss = use_user_answer_loss
        self.use_abs_pos = use_abs_pos
        self.use_task_container_pos = use_task_container_pos
        self.share_position_embeddings = share_position_embeddings
        self.use_tags = use_tags
        self.use_part = use_part
        self.use_prior_explanation = use_prior_explanation
        self.allow_bundle_atten = allow_bundle_atten
        self.generative = generative
        self.use_prior_question_elapsed_time_input = use_prior_question_elapsed_time_input
        self.use_part_aggregated_historical_info = use_part_aggregated_historical_info
        self.use_correct_answer_aggregated_historical_info=use_correct_answer_aggregated_historical_info
        self.use_question_level_aggregated_historical_info=use_question_level_aggregated_historical_info
        self.use_lag_time = use_lag_time
        self.use_lag_time_for_encoder = use_lag_time_for_encoder
        self.use_user_level_aggregated_historical_info = use_user_level_aggregated_historical_info
        self.use_pre_classifier = use_pre_classifier
        self.use_softmax = use_softmax

        # Position embedding settings.
        self.max_position_embeddings = max_position_embeddings
        self.sinusoidal_pos_embds = sinusoidal_pos_embds
        
        # Transformer size.
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dim = dim
        self.hidden_dim = hidden_dim

        # Name of the activation; resolved elsewhere through `get_tf_activation`.
        self.activation = activation

        # Dropout rates.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.seq2seq_dropout = seq2seq_dropout

        # Weight initialization.
        self.initializer_range = initializer_range
        self.seed = seed

    def toJSON(self):
        """Return the configuration as an indented JSON string.

        Objects that `json` cannot serialize natively (including this instance)
        are replaced by their `__dict__`. Keys are not sorted.
        """

        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=False, indent=4)

    def vocab_size(self, input_name):
        """Return the vocabulary size configured for the input stream `input_name`.

        Args:
            input_name: one of 'content', 'response', 'tag', 'part' or
                'prior_explanation'.

        Raises:
            ValueError: for any other `input_name`.
        """

        if input_name == 'content':
            return self.content_vocab_size
        elif input_name == 'response':
            return self.response_vocab_size
        elif input_name == 'tag':
            return self.tag_vocab_size
        elif input_name == 'part':
            return self.part_vocab_size            
        elif input_name == 'prior_explanation':
            return self.prior_explanation_vocab_size
        else:
            raise ValueError('input name not used for model')

    # The three read-only properties below are aliases of `dim`, `n_heads` and
    # `n_layers` (presumably the names Hugging Face utilities expect - verify).

    @property
    def hidden_size(self):
        """Alias of `dim`."""
        return self.dim

    @property
    def num_attention_heads(self):
        """Alias of `n_heads`."""
        return self.n_heads

    @property
    def num_hidden_layers(self):
        """Alias of `n_layers`."""
        return self.n_layers


class TFSharedEmbeddings(tf.keras.layers.Layer):
    """Embedding layer whose single weight matrix can also act as an output projection.

    The layer owns one matrix `weight` of shape `[vocab_size, hidden_size]`,
    created in `build` with the seeded Glorot-uniform initializer.

        - mode "embedding": rows of `weight` are gathered by id.
        - mode "linear": inputs are multiplied by the transpose of `weight`,
          producing one score per vocabulary entry.
    """

    def __init__(self, vocab_size, hidden_size, seed=None, **kwargs):
        """Remember the matrix dimensions and the initializer seed; the weight is created in `build`."""

        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.seed = seed

    def build(self, input_shape):
        """Create the `[vocab_size, hidden_size]` weight matrix."""

        weight_shape = [self.vocab_size, self.hidden_size]
        self.weight = self.add_weight(
            "weight", shape=weight_shape, initializer=get_initializer(seed=self.seed)
        )
        super().build(input_shape)

    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
        """Dispatch to the embedding lookup or to the linear projection.

        Raises:
            ValueError: if `mode` is neither "embedding" nor "linear".
        """

        if mode == "embedding":
            return self._embedding(inputs)
        if mode == "linear":
            return self._linear(inputs)

        raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Gather the rows of `weight` selected by `input_ids`."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """Project `inputs` (last axis of size `hidden_size`) onto the vocabulary.

        All leading axes are preserved; the last axis becomes `vocab_size`.
        """

        leading_dims = shape_list(inputs)[:-1]
        # Flatten all leading axes so that a single 2-D matmul can be used.
        flat_inputs = tf.reshape(inputs, [-1, self.hidden_size])
        flat_logits = tf.matmul(flat_inputs, self.weight, transpose_b=True)

        return tf.reshape(flat_logits, leading_dims + [self.vocab_size])


class TFEmbeddings(tf.keras.layers.Layer):
    """Input embedding layer for the 'content' stream or the 'response' stream.

    Several embeddings are summed, layer-normalized and passed through dropout:

        - 'content': token, position, mean tag, part and correct-answer
          embeddings, plus `dense_embeddings_encoder`.
        - 'response': token, second-token (`input_ids_2`), position and
          prior-explanation embeddings, plus `dense_embeddings_decoder`.

    The token embedding matrices are created in `build` as raw weights
    (`self.word_embeddings`, and `self.word_embeddings_2` for 'response').

    Note:
        The `use_*` flags copied from the config are stored on the instance but
        are not read anywhere in this class.
    """

    def __init__(
        self, config, input_name,
        position_embeddings_layer=None,       
        **kwargs
    ):
        """Create the sub-layers.

        Args:
            config: configuration exposing `vocab_size(input_name)`, `dim`, `seed`,
                `pad_token_id`, `activation`, the vocabulary sizes and the `use_*` flags.
            input_name: name of the input stream; `_embedding` supports
                'content' and 'response'.
            position_embeddings_layer: optional layer to use for position
                embeddings (e.g. to share it between two instances); a new
                `tf.keras.layers.Embedding` is created when it is None.
            **kwargs: forwarded to `tf.keras.layers.Layer`.
        """
    
        super().__init__(**kwargs)
        
        self.pad_id = config.pad_token_id
        self.input_name = input_name
        self.vocab_size = config.vocab_size(input_name)

        self.use_user_answer = config.use_user_answer
        self.use_correct_answer_for_encoder = config.use_correct_answer_for_encoder
        self.use_correct_answer_for_decoder = config.use_correct_answer_for_decoder
        self.use_abs_pos = config.use_abs_pos
        self.use_task_container_pos = config.use_task_container_pos
        self.use_tags = config.use_tags
        self.use_part = config.use_part
        self.use_prior_explanation = config.use_prior_explanation
        self.use_prior_question_elapsed_time_input = config.use_prior_question_elapsed_time_input
        self.use_lag_time = config.use_lag_time
        self.use_lag_time_for_encoder = config.use_lag_time_for_encoder
        self.use_user_level_aggregated_historical_info = config.use_user_level_aggregated_historical_info
        self.use_part_aggregated_historical_info = config.use_part_aggregated_historical_info
        self.use_correct_answer_aggregated_historical_info = config.use_correct_answer_aggregated_historical_info
        self.use_question_level_aggregated_historical_info = config.use_question_level_aggregated_historical_info

        self.dim = config.dim
        self.seed = config.seed
        
        assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
            config.activation
        )
        self.activation = get_tf_activation(config.activation)

        # NOTE: `build` replaces `word_embeddings` (always) and `word_embeddings_2`
        # (for 'response' only) with raw weight matrices, so these two layer
        # objects are placeholders. They are kept so that the set of objects
        # created by the constructor stays unchanged.
        self.word_embeddings = TFSharedEmbeddings(
            self.vocab_size, config.dim, seed=config.seed, name="word_embeddings"
        )  # padding_idx=0)
        
        self.word_embeddings_2 = TFSharedEmbeddings(
            ANSWER_VOCAB_SIZE, config.dim, seed=config.seed, name="word_embeddings_2"
        )  # padding_idx=0)         
        
        if position_embeddings_layer is None:

            self.position_embeddings = tf.keras.layers.Embedding(
                config.max_position_embeddings,
                config.dim,
                embeddings_initializer=get_initializer(config.seed),
                name="position_embeddings",
            )
        
        else:

            self.position_embeddings = position_embeddings_layer
           
        self.correct_answer_embeddings = tf.keras.layers.Embedding(
            ANSWER_VOCAB_SIZE,
            config.dim,
            embeddings_initializer=get_initializer(config.seed),
            name="correct_answer_embeddings",
        )

        self.tag_embeddings = tf.keras.layers.Embedding(
            config.tag_vocab_size,
            config.dim,
            embeddings_initializer=get_initializer(config.seed),
            name="tag_embeddings",
        )

        self.part_embeddings = tf.keras.layers.Embedding(
            config.part_vocab_size,
            config.dim,
            embeddings_initializer=get_initializer(config.seed),
            name="part_embeddings",
        )

        self.prior_explanation_embeddings = tf.keras.layers.Embedding(
            config.prior_explanation_vocab_size,
            config.dim,
            embeddings_initializer=get_initializer(config.seed),
            name="prior_explanation_embeddings",
        )
        
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm")
        # Fixed dropout rate; `config.dropout` is not consulted here.
        self.dropout = tf.keras.layers.Dropout(0.05)

    def build(self, input_shape):
        """Create the token embedding matrices as raw weights.

        `word_embeddings` has shape `[vocab_size, dim]`. For the 'response'
        stream, `word_embeddings_2` of shape `[ANSWER_VOCAB_SIZE, dim]` is
        created as well.
        """
        
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights with the seeded initializer
            # returned by `get_initializer`.
            self.word_embeddings = self.add_weight(
                "weight", shape=[self.vocab_size, self.dim], initializer=get_initializer(self.seed)
            )
        
        if self.input_name == 'response':
            
            with tf.name_scope("word_embeddings_2"):
                # Second token matrix, gathered with `input_ids_2` in `_embedding`.
                self.word_embeddings_2 = self.add_weight(
                    "weight", shape=[ANSWER_VOCAB_SIZE, self.dim], initializer=get_initializer(self.seed)
                )

        super().build(input_shape)

    def call(
        self,
        input_ids=None, input_ids_2=None, position_ids=None,
        tag_ids=None, part_ids=None,
        prior_explanation_ids=None,
        correct_answer_ids=None,
        dense_embeddings_encoder=None,
        dense_embeddings_decoder=None,
        mode="embedding", training=False
    ):
        """Compute input embeddings (mode "embedding") or vocabulary logits (mode "linear").

        In mode "linear", `input_ids` is interpreted as the hidden states to
        project and all other inputs are ignored.

        Raises:
            ValueError: if `mode` is neither "embedding" nor "linear".
        """
       
        if mode == "embedding":
            return self._embedding(
                input_ids, input_ids_2, position_ids,
                tag_ids, part_ids,
                prior_explanation_ids,
                correct_answer_ids,
                dense_embeddings_encoder,
                dense_embeddings_decoder,
                training=training
            )
        elif mode == "linear":
            return self._linear(input_ids)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(
        self,
        input_ids, input_ids_2, position_ids,
        tag_ids, part_ids,
        prior_explanation_ids,
        correct_answer_ids,
        dense_embeddings_encoder,
        dense_embeddings_decoder,
        training=False
    ):
        """Sum the embeddings of the configured stream, then apply LayerNorm and dropout.

        `dense_embeddings_encoder` / `dense_embeddings_decoder` are concatenated
        as they are along the stacking axis 0, next to rank-4 tensors, so they must
        already have a leading axis (assumes shape
        `[n_dense, batch_size, seq_len, dim]` - TODO confirm against the caller).

        Returns:
            Tensor of shape `[batch_size, seq_len, dim]`.

        Raises:
            ValueError: if `self.input_name` is neither 'content' nor 'response'.
        """

        inputs_embeds = tf.gather(self.word_embeddings, input_ids)

        position_embeddings = tf.cast(
            self.position_embeddings(position_ids), inputs_embeds.dtype
        )  # (bs, max_seq_length, dim)
        
        if self.input_name == 'content':
            
            tag_embeddings = tf.cast(
                self.tag_embeddings(tag_ids), inputs_embeds.dtype
            )  # (bs, seq_len, N_TAGS_PER_CONTENT, dim)
            
            # 1 for real tags, 0 for padding. shape = (bs, seq_len, N_TAGS_PER_CONTENT)
            tag_mask = tf.cast(tag_ids != self.pad_id, dtype=tf.int32)

            # Zero out the embeddings of padded tag slots.
            tag_embeddings = tag_embeddings * tf.cast(tag_mask, dtype=inputs_embeds.dtype)[:, :, :, tf.newaxis]
            
            # Number of real tags per position, clamped to at least 1 so that
            # positions without any tag do not divide by zero. shape = (bs, seq_len)
            nb_tags = tf.math.reduce_sum(tag_mask, axis=2)
            nb_tags = tf.cast(nb_tags, dtype=inputs_embeds.dtype)
            nb_tags = tf.math.maximum(nb_tags, tf.cast(1.0, dtype=inputs_embeds.dtype))

            # Mean of the real tag embeddings. shape = (bs, seq_len, dim)
            tag_embeddings = tf.math.reduce_sum(tag_embeddings, axis=2) / nb_tags[:, :, tf.newaxis]

            part_embeddings = tf.cast(
                self.part_embeddings(part_ids), inputs_embeds.dtype
            )  # (bs, seq_len, dim)            
            
            correct_answer_embeddings = tf.cast(self.correct_answer_embeddings(correct_answer_ids), inputs_embeds.dtype)
                        
            # shape = [n_embeddings, batch_size, seq, dim]
            concated_embeddings = tf.concat(
                [
                    inputs_embeds[tf.newaxis, :, :, :],
                    position_embeddings[tf.newaxis, :, :, :],
                    tag_embeddings[tf.newaxis, :, :, :],
                    part_embeddings[tf.newaxis, :, :, :],
                    correct_answer_embeddings[tf.newaxis, :, :, :],
                    dense_embeddings_encoder
                ],
                axis=0
            )            
            
        elif self.input_name == 'response':
            
            # Embedding of the second id stream (`input_ids_2`).
            inputs_embeds_2 = tf.gather(self.word_embeddings_2, input_ids_2)            
            
            prior_explanation_embeddings = tf.cast(
                self.prior_explanation_embeddings(prior_explanation_ids), inputs_embeds.dtype
            )  # (bs, seq_len, dim)
            
            # ----------------------------------------------------------------------------------------------------

            # shape = [n_embeddings, batch_size, seq, dim]
            concated_embeddings = tf.concat(
                [                    
                    inputs_embeds[tf.newaxis, :, :, :],
                    inputs_embeds_2[tf.newaxis, :, :, :],
                    position_embeddings[tf.newaxis, :, :, :],
                    prior_explanation_embeddings[tf.newaxis, :, :, :], 
                    dense_embeddings_decoder,
                ],
                axis=0
            )

        else:
            # Fail with an explicit message instead of an unbound-variable error.
            raise ValueError(
                "input_name {} is not supported for embedding; expected 'content' or 'response'.".format(
                    self.input_name
                )
            )

        # Sum over the stacking axis. shape = [batch_size, seq_len, dim]
        embeddings = tf.math.reduce_sum(concated_embeddings, axis=0)
        
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings, training=training)  # (bs, max_seq_length, dim)
        
        return embeddings

    def _linear(self, inputs):
        """Project hidden states onto the token vocabulary with the tied `word_embeddings` matrix.

        Mirrors `TFSharedEmbeddings._linear`: all leading axes of `inputs` are
        preserved and the last axis (of size `dim`) becomes `vocab_size`.

        Args:
            inputs: float tensor whose last axis has size `dim`.

        Returns:
            Tensor of logits with the leading axes of `inputs` and a last axis
            of size `vocab_size`.
        """

        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.dim])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])
    
    
class TFMultiHeadSelfAttention(HFTFMultiHeadSelfAttention):
    """Multi-head attention with seeded projection layers and a multiplicative mask.

    The four projection layers set up by the parent constructor are replaced by
    `Dense` layers initialized through `get_initializer(config.seed)`.
    `self.dim`, `self.n_heads` and `self.dropout` are expected to be provided by
    the parent class (not visible in this notebook section).
    """

    def __init__(self, config, **kwargs):
        """Build the parent layer, then install the seeded q/k/v/out projections."""

        super().__init__(config, **kwargs)

        def seeded_projection(layer_name):
            # Each projection gets its own initializer instance.
            return tf.keras.layers.Dense(
                config.dim,
                kernel_initializer=get_initializer(config.seed),
                name=layer_name
            )

        self.q_lin = seeded_projection("q_lin")
        self.k_lin = seeded_projection("k_lin")
        self.v_lin = seeded_projection("v_lin")
        self.out_lin = seeded_projection("out_lin")

    def call(self, query, key, value, mask, head_mask, output_attentions, training=False):
        """
        Parameters:
            query: tf.Tensor(bs, query_length, dim)
            key: tf.Tensor(bs, key_length, dim)
            value: tf.Tensor(bs, key_length, dim)
            mask: tf.Tensor(bs, query_length / 1, key_length); 1 marks positions
                that may be attended to, 0 marks masked positions.
            head_mask: optional tensor multiplied into the attention weights.
            output_attentions: whether to also return the attention weights.
        Returns:
            `(context,)`, or `(context, weights)` if `output_attentions` is true, with
            context: tf.Tensor(bs, query_length, dim) and
            weights: tf.Tensor(bs, n_heads, query_length, key_length).
        """
        bs, _, _ = shape_list(query)
        head_dim = tf.cast(tf.math.divide(self.dim, self.n_heads), dtype=tf.int32)

        def split_heads(x):
            """ (bs, length, dim) -> (bs, n_heads, length, head_dim) """
            per_head = tf.reshape(x, (bs, -1, self.n_heads, head_dim))
            return tf.transpose(per_head, perm=(0, 2, 1, 3))

        def merge_heads(x):
            """ (bs, n_heads, length, head_dim) -> (bs, length, dim) """
            heads_last = tf.transpose(x, perm=(0, 2, 1, 3))
            return tf.reshape(heads_last, (bs, -1, self.n_heads * head_dim))

        # Projections are applied in the order q, k, v.
        q = split_heads(self.q_lin(query))  # (bs, n_heads, q_length, head_dim)
        k = split_heads(self.k_lin(key))  # (bs, n_heads, k_length, head_dim)
        v = split_heads(self.v_lin(value))  # (bs, n_heads, k_length, head_dim)

        # Scale the queries by 1 / sqrt(head_dim), computed in `DTYPE`.
        q = tf.cast(q, dtype=DTYPE)
        q = tf.multiply(q, tf.math.rsqrt(tf.cast(head_dim, dtype=DTYPE)))
        k = tf.cast(k, dtype=q.dtype)
        scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, q_length, k_length)

        # (bs, qlen / 1, klen) -> (bs, 1, qlen / 1, klen), broadcast over the heads.
        attn_mask = tf.cast(mask[:, tf.newaxis, :, :], dtype=scores.dtype)

        # Push masked scores towards -infinity before the softmax.
        scores = scores - 1e30 * (1.0 - attn_mask)
        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
        weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
        # This makes things more numerically stable: masked positions get weight exactly 0.
        weights = weights * attn_mask

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = merge_heads(tf.matmul(weights, v))  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)

        return (context,)
    

class TFFFN(HFTFFFN):
    """Feed-forward sub-layer with seeded `Dense` projections.

    `lin1` (width `config.hidden_dim`) and `lin2` (width `config.dim`) are
    created with `get_initializer(config.seed)`, and the activation is resolved
    from `config.activation`. Everything else (e.g. the forward pass) is
    inherited from the parent class.
    """

    def __init__(self, config, **kwargs):
        """Initialize the parent with `config`, then install the seeded layers.

        NOTE(review): `config` is forwarded to the parent constructor, as every
        other Hugging Face-derived subclass in this notebook does. This assumes
        `HFTFFFN.__init__` has the signature `(config, **kwargs)` - confirm
        against the imported class.
        """

        super().__init__(config, **kwargs)

        self.lin1 = tf.keras.layers.Dense(
            config.hidden_dim,
            kernel_initializer=get_initializer(config.seed),
            name="lin1"
        )
        self.lin2 = tf.keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.seed),
            name="lin2"
        )

        self.activation = get_tf_activation(config.activation)


class TFTransformerBlock(HFTFTransformerBlock):
    """Transformer block that uses this notebook's seeded `TFMultiHeadSelfAttention`."""

    def __init__(self, config, **kwargs):
        """Build the parent block, then set `self.attention` to a layer named "attention"."""

        super().__init__(config, **kwargs)

        seeded_attention = TFMultiHeadSelfAttention(config, name="attention")
        self.attention = seeded_attention
 
    
class TFContentBlock(TFTransformerBlock):
    """Transformer block for the content stream; its attention layer is named "c_attention"."""

    def __init__(self, config, **kwargs):
        """Build the parent block, then set `self.attention` to a layer named "c_attention"."""

        super().__init__(config, **kwargs)

        content_attention = TFMultiHeadSelfAttention(config, name="c_attention")
        self.attention = content_attention
        

class TFResponseBlock(TFTransformerBlock):
    """Transformer block for the response stream.

    Three steps, each followed by a residual connection and a layer norm:

        1. self-attention over the response states (`self.attention`),
        2. attention from the response states to the content hidden states
           (`self.r_c_attentioin`),
        3. the feed-forward network (`self.ffn`).

    `self.sa_layer_norm`, `self.ffn` and `self.output_layer_norm` are expected
    to come from the parent class. The spelling `r_c_attentioin` is part of the
    layer name and is therefore kept.
    """

    def __init__(self, config, **kwargs):
        """Build the parent block, then add the response-specific attention layers."""

        super().__init__(config, **kwargs)

        # Self-attention over the responses.
        self.attention = TFMultiHeadSelfAttention(config, name="r_attention")

        # Response-to-content attention and its layer norm.
        self.r_c_attentioin = TFMultiHeadSelfAttention(config, name="r_c_attentioin")
        self.r_c_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="r_c_attn_layer_norm")

    def call(self, r, c_hidden, r_mask, r_c_mask, head_mask, output_attentions, training=False):  # removed: src_enc=None, src_len=None
        """Run the block.

        Args:
            r: response states, (bs, seq_length, dim).
            c_hidden: content hidden states, used as keys and values of the
                response-to-content attention.
            r_mask: attention mask of the self-attention.
            r_c_mask: attention mask of the response-to-content attention.
            head_mask: head mask passed to both attention layers.
            output_attentions: whether to also return both attention weights.

        Returns:
            `(output,)`, or `(r_weights, r_c_weights, output)` if
            `output_attentions` is true.
        """

        # 1. Self-attention. The layer returns (context,) or (context, weights).
        self_attn_result = self.attention(r, r, r, r_mask, head_mask, output_attentions, training=training)
        attended_r = self.sa_layer_norm(self_attn_result[0] + r)  # (bs, seq_length, dim)

        # 2. Attention from the responses to the content hidden states.
        cross_attn_result = self.r_c_attentioin(
            query=attended_r,
            key=c_hidden,
            value=c_hidden,
            mask=r_c_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training
        )
        cross_attended = self.r_c_attn_layer_norm(cross_attn_result[0] + attended_r)

        # 3. Feed Forward Network
        ffn_output = self.ffn(cross_attended, training=training)  # (bs, seq_length, dim)
        block_output = self.output_layer_norm(ffn_output + cross_attended)  # (bs, seq_length, dim)

        if output_attentions:
            # Weights are (bs, n_heads, seq_length, seq_length) each.
            return (self_attn_result[1], cross_attn_result[1], block_output)

        return (block_output,)
       

class TFCRBlock(TFTransformerBlock):
    """Transformer block whose attention layer is named "cr_attention"."""

    def __init__(self, config, **kwargs):
        """Build the parent block, then set `self.attention` to a layer named "cr_attention"."""

        super().__init__(config, **kwargs)

        cr_attention = TFMultiHeadSelfAttention(config, name="cr_attention")
        self.attention = cr_attention


class TFContentCoder(TFTransformer):
    """Stack of `config.n_layers` `TFContentBlock` layers for the content stream."""

    def __init__(self, config, **kwargs):
        """Build the parent, then set `self.layer` to blocks named "layer_._0", "layer_._1", ..."""

        super().__init__(config, **kwargs)

        content_blocks = []
        for layer_idx in range(config.n_layers):
            content_blocks.append(TFContentBlock(config, name="layer_._{}".format(layer_idx)))

        self.layer = content_blocks
    
    
class TFResponseCoder(TFTransformer):
    """`TFTransformer` whose layer stack is made of `TFResponseBlock` instances.

    Each block receives the running response hidden state together with the
    content hidden state `c_hidden` and two masks.
    """

    def __init__(self, config, **kwargs):
        """Build the parent transformer, then install `config.n_layers` response blocks."""
        super().__init__(config, **kwargs)

        response_blocks = []
        for layer_idx in range(config.n_layers):
            response_blocks.append(TFResponseBlock(config, name=f"layer_._{layer_idx}"))
        self.layer = response_blocks

    def call(self, r_embeds, c_hidden, r_mask, r_c_mask, head_mask, output_attentions, output_hidden_states, return_dict, training=False):
        # docstyle-ignore
        """Run the response blocks in order and collect the requested outputs.

        Args:
            r_embeds: initial hidden state fed to the first block.
            c_hidden: passed unchanged to every block as its second argument.
            r_mask: passed unchanged to every block.
            r_c_mask: passed unchanged to every block.
            head_mask: indexable; `head_mask[i]` is handed to block `i`.
            output_attentions: when truthy, each block must return three items
                and the first two are collected as a pair per layer.
            output_hidden_states: when truthy, the input of every block plus the
                final hidden state are collected.
            return_dict: when truthy a `TFBaseModelOutput` is returned, otherwise
                a tuple of the non-`None` results.
            training: forwarded to every block.

        Returns:
            `TFBaseModelOutput` or a tuple `(last_hidden_state[, hidden_states][, attentions])`.
        """
        collected_hidden = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        current = r_embeds
        for idx, block in enumerate(self.layer):
            # Record the state going *into* the block.
            if output_hidden_states:
                collected_hidden += (current,)

            block_outputs = block(current, c_hidden, r_mask, r_c_mask, head_mask[idx], output_attentions, training=training)
            # The new hidden state is always the last element of the block output.
            current = block_outputs[-1]

            if output_attentions:
                assert len(block_outputs) == 3
                first_weights, second_weights = block_outputs[0], block_outputs[1]
                collected_attentions += ((first_weights, second_weights),)
            else:
                assert len(block_outputs) == 1, f"Incorrect number of outputs {len(block_outputs)} instead of 1"

        # Add last layer
        if output_hidden_states:
            collected_hidden += (current,)

        if return_dict:
            return TFBaseModelOutput(
                last_hidden_state=current,
                hidden_states=collected_hidden,
                attentions=collected_attentions,
            )

        results = [current, collected_hidden, collected_attentions]
        return tuple(item for item in results if item is not None)
        
        
def process_mask(mask, input_shpae):
    """Return `mask` as a float32 tensor, defaulting to all ones.

    Args:
        mask: mask tensor, or `None` to request an all-ones mask.
        input_shpae: shape handed to `tf.ones` when `mask` is `None`
            (the parameter name is misspelled but kept for caller compatibility).

    Returns:
        `mask` (or the all-ones default) cast to `tf.float32`.
    """
    effective_mask = tf.ones(input_shpae) if mask is None else mask
    return tf.cast(effective_mask, dtype=tf.float32)


class TFEdFormerEncoderDecoder(tf.keras.layers.Layer):
    """Runs a `TFContentCoder` ("encoder") followed by a `TFResponseCoder` ("decoder")."""

    def __init__(self, config, **kwargs):
        """Create the encoder and decoder sub-layers from `config`."""
        super().__init__(**kwargs)

        self.encoder = TFContentCoder(config, name="encoder")  # Encoder
        self.decoder = TFResponseCoder(config, name="decoder")  # Decoder

    def call(self, c_embeds, r_embeds, d_embeds, c_mask, r_mask, r_c_mask, c_r_mask, head_mask, output_attentions, output_hidden_states, return_dict, training=False):
        """Encode the content embeddings, then decode against them.

        `r_embeds` and `c_r_mask` are accepted but not used by this method.

        Args:
            c_embeds: input of the encoder.
            r_embeds: unused here.
            d_embeds: input of the decoder.
            c_mask: mask given to the encoder.
            r_mask: first mask given to the decoder.
            r_c_mask: second mask given to the decoder.
            c_r_mask: unused here.
            head_mask: given to both encoder and decoder.
            output_attentions, output_hidden_states, return_dict, training:
                forwarded positionally to both sub-layers.

        Returns:
            Tuple `(decoder_last_hidden_state, encoder_outputs, decoder_outputs)`.
        """
        c_outputs = self.encoder(
            c_embeds, c_mask, head_mask,
            output_attentions, output_hidden_states, return_dict, training
        )
        # Tuple outputs carry the last hidden state first; dict-style outputs expose it by name.
        c_hidden = c_outputs.last_hidden_state if return_dict else c_outputs[0]

        r_outputs = self.decoder(
            d_embeds, c_hidden, r_mask, r_c_mask, head_mask,
            output_attentions, output_hidden_states, return_dict, training
        )
        r_hidden = r_outputs.last_hidden_state if return_dict else r_outputs[0]

        return (r_hidden, c_outputs, r_outputs)
    

class TFEdFormerMainLayer(tf.keras.layers.Layer):
    """Main EdFormer layer: feature dict in, encoder-decoder outputs out.

    `call` projects every dense feature through its own Dense + LayerNormalization
    pair, builds the content and response embeddings with `TFEmbeddings`, and
    feeds them to `TFEdFormerEncoderDecoder`.
    """
    
    config_class = EdFormerConfig

    def __init__(self, config, **kwargs):
        """Create all sub-layers from `config`.

        Args:
            config: configuration object; the attributes read here are
                copied onto `self` or passed to the sub-layers.
            **kwargs: forwarded to `tf.keras.layers.Layer.__init__`.
        """

        super().__init__(**kwargs)
        
        self.num_hidden_layers = config.num_hidden_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.return_dict = config.use_return_dict
        
        # Feature switches copied from the config. They are only stored here;
        # none of them is read again inside this class.
        self.use_user_answer = config.use_user_answer
        self.use_correct_answer_for_encoder = config.use_correct_answer_for_encoder
        self.use_correct_answer_for_decoder = config.use_correct_answer_for_decoder
        self.use_abs_pos = config.use_abs_pos
        self.use_task_container_pos = config.use_task_container_pos
        self.use_tags = config.use_tags
        self.use_part = config.use_part
        self.use_prior_explanation = config.use_prior_explanation
        self.use_prior_question_elapsed_time_input = config.use_prior_question_elapsed_time_input
        self.use_lag_time = config.use_lag_time
        self.use_lag_time_for_encoder = config.use_lag_time_for_encoder
        
        self.use_user_level_aggregated_historical_info = config.use_user_level_aggregated_historical_info
        self.use_part_aggregated_historical_info = config.use_part_aggregated_historical_info
        self.use_correct_answer_aggregated_historical_info = config.use_correct_answer_aggregated_historical_info
        self.use_question_level_aggregated_historical_info = config.use_question_level_aggregated_historical_info

        if config.share_position_embeddings:
            # All `TFEmbeddings` share a single `position_embeddings`.
            position_embeddings = tf.keras.layers.Embedding(
                config.max_position_embeddings,
                config.dim,
                embeddings_initializer=get_initializer(config.seed),
                name="position_embeddings",
            )
        else:
            # `None` is passed to both `TFEmbeddings` below.
            # NOTE(review): presumably each then creates its own position embeddings - confirm in `TFEmbeddings`.
            position_embeddings = None
            
        # NOTE(review): this attribute is not referenced anywhere else in this class
        # (the per-feature loop below uses 'prior_question_elapsed_time_input_embeddings').
        self.prior_question_elapsed_time_embeddings = tf.keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.seed),
            name="prior_question_elapsed_time_embeddings",
        )

        # NOTE(review): the per-feature loop below re-assigns `self.lag_time_embeddings`
        # via `setattr`, so this particular instance is replaced before it is ever called.
        self.lag_time_embeddings = tf.keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.seed),
            name="lag_time_embeddings",
        )            
        
        # Features whose projected output is appended to BOTH the encoder and the
        # decoder list in `call`; every other feature goes to the decoder list only.
        self.key_names_for_encoder = {
            'n_questions_answered_scaled',
            'n_lectures_watched_scaled',
            'part_count_scaled',
            'current_part_count_scaled',
            'correct_answer_count_scaled',
            'current_correct_answer_count_scaled',
            'current_question_count_scaled',
        }
        
        # Features that get a trailing axis (`x[:, :, tf.newaxis]`) in `call`
        # before being fed to their Dense layer.
        self.key_name_need_newaxis = {
            'n_questions_answered_scaled',
            'n_lectures_watched_scaled',
            'current_part_count_scaled',
            'current_correct_answer_count_scaled',
            'current_question_count_scaled',
            'prior_question_elapsed_time_input',
            'lag_time',
            'answered_correctly_ratio',
            'current_part_correctness_ratio',
            'current_correct_answer_correctness_ratio',
            'current_question_correctness_ratio',            
        }    
        
        # input key -> [attribute/layer name of the Dense, attribute/layer name of the LayerNormalization]
        self.dense_layer_name_mapping = {
            'n_questions_answered_scaled': ['n_questions_answered_embeddings', 'n_questions_answered_layer_norm'],
            'n_lectures_watched_scaled': ['n_lectures_watched_embeddings', 'n_lectures_watched_layer_norm'],
            'part_count_scaled': ['part_count_embeddings', 'part_count_layer_norm'],
            'current_part_count_scaled': ['current_part_count_embeddings', 'current_part_count_layer_norm'],
            'correct_answer_count_scaled': ['correct_answer_count_embeddings', 'correct_answer_count_layer_norm'],
            'current_correct_answer_count_scaled': ['current_correct_answer_count_embeddings', 'current_correct_answer_count_layer_norm'],
            'current_question_count_scaled': ['current_question_count_embeddings', 'current_question_count_layer_norm'],
            'prior_question_elapsed_time_input': ['prior_question_elapsed_time_input_embeddings', 'prior_question_elapsed_time_input_layer_norm'],
            'lag_time': ['lag_time_embeddings', 'lag_time_layer_norm'],
            'answered_correctly_ratio': ['answered_correctly_ratio_embeddings', 'answered_correctly_ratio_layer_norm'],
            'part_correctness_ratio': ['part_correctness_ratio_embeddings', 'part_correctness_ratio_layer_norm'],
            'current_part_correctness_ratio': ['current_part_correctness_ratio_embeddings', 'current_part_correctness_ratio_layer_norm'],   
            'correct_answer_correctness_ratio': ['correct_answer_correctness_ratio_embeddings', 'correct_answer_correctness_ratio_layer_norm'], 
            'current_correct_answer_correctness_ratio': ['current_correct_answer_correctness_ratio_embeddings', 'current_correct_answer_correctness_ratio_layer_norm'],   
            'current_question_correctness_ratio': ['current_question_correctness_ratio_embeddings', 'current_question_correctness_ratio_layer_norm'],     
        }

        # Same keys as `dense_layer_name_mapping`, as an ordered list.
        # NOTE(review): not read anywhere else in this class.
        self.dense_layer_keys = [
            'n_questions_answered_scaled',
            'n_lectures_watched_scaled',
            'part_count_scaled',
            'current_part_count_scaled',
            'correct_answer_count_scaled',
            'current_correct_answer_count_scaled',
            'current_question_count_scaled',
            'prior_question_elapsed_time_input',
            'lag_time',
            'answered_correctly_ratio',
            'part_correctness_ratio',
            'current_part_correctness_ratio',
            'correct_answer_correctness_ratio',
            'current_correct_answer_correctness_ratio',
            'current_question_correctness_ratio'  
        ]      
        
        # input key -> [Dense instance, LayerNormalization instance]; iterated by `call`.
        self.dense_layer_mapping = dict()        
        
        for key_name, layer_names in self.dense_layer_name_mapping.items():
        
            # The layers are attached with `setattr` so each one is also reachable
            # as an attribute of this layer under its own name.
            setattr(
                self,
                layer_names[0],
                tf.keras.layers.Dense(
                    config.dim,
                    kernel_initializer=get_initializer(config.seed),
                    name=layer_names[0],
                )    
            )
  
            setattr(
                self,
                layer_names[1],
                tf.keras.layers.LayerNormalization(epsilon=1e-12, name=layer_names[1])  
            )
    
            self.dense_layer_mapping[key_name] = [getattr(self, layer_names[0]), getattr(self, layer_names[1])]
    
        # --------------------------------------------------------------------------------
        # Activation and dropout applied to every dense feature projection in `call`.
        assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
            config.activation
        )
        self.activation = get_tf_activation(config.activation)    
        # Fixed dropout rate of 0.05; it is not taken from `config`.
        self.dropout = tf.keras.layers.Dropout(0.05)
        # --------------------------------------------------------------------------------
        
        self.content_embeddings = TFEmbeddings(
            config,
            input_name = 'content',
            position_embeddings_layer=position_embeddings,   
            name="content_embeddings"
        )
        
        self.response_embeddings = TFEmbeddings(
            config,
            input_name = 'response',
            position_embeddings_layer=position_embeddings,  
            name="response_embeddings"
        )
        
        self.coder = TFEdFormerEncoderDecoder(config, name='coder')
                
    def call(
        self,
        inputs,
        c_mask=None,
        r_mask=None,
        r_c_mask=None,
        c_r_mask=None,
        head_mask=None,    
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        """Embed the features in `inputs` and run the encoder-decoder.

        Args:
            inputs: mapping of feature name to tensor. 'pos_ids', 'shifted_pos_ids'
                and every key of `dense_layer_name_mapping` are read with `[]` and
                must be present. 'c_input_ids' is read with `.get` but indexed
                right after, so it is required in practice. Entries named
                'c_mask', 'r_mask', 'r_c_mask', 'c_r_mask', 'output_attentions',
                'output_hidden_states' and 'return_dict' take precedence over the
                keyword arguments of the same name.
            c_mask, r_mask, r_c_mask, c_r_mask: optional masks; a missing mask is
                replaced by ones of shape [batch_size, seq_len, seq_len].
            head_mask: ignored; it is overwritten with a list of `None`.
            output_attentions, output_hidden_states, return_dict: fall back to the
                values stored on the layer when `None`.
            training: forwarded to dropout, the embeddings and the coder.

        Returns:
            The value returned by `self.coder`.
        """

        c_input_ids = inputs.get('c_input_ids')
        # NOTE(review): `r_input_ids` is read but not used below.
        r_input_ids = inputs.get('r_input_ids')
        d_input_ids = inputs.get('d_input_ids')
        d_ans_input_ids = inputs.get('d_ans_input_ids')

        tag_ids = inputs.get('tag_ids')
        part_ids = inputs.get('part_ids')

        # --------------------------------------------------------------------------------

        # Batch size and sequence length obtained as tensors by counting the
        # elements along each axis of `c_input_ids`.
        batch_size = tf.math.reduce_sum(tf.ones_like(c_input_ids[:, 0], dtype=tf.int32))
        seq_len = tf.math.reduce_sum(tf.ones_like(c_input_ids[0, :], dtype=tf.int32))
    
        # positional information provided
        pos_ids = inputs['pos_ids']
        shifted_pos_ids = inputs['shifted_pos_ids']

        prior_explanation_ids = inputs.get('prior_explanation_ids')
        
        correct_answer_ids = inputs.get('correct_answer_id')

        # Values found in `inputs` win over the keyword arguments.
        c_mask = inputs.get('c_mask', c_mask)
        r_mask = inputs.get('r_mask', r_mask)
        r_c_mask = inputs.get('r_c_mask', r_c_mask)
        c_r_mask = inputs.get('c_r_mask', c_r_mask)                        
        
        output_attentions = inputs.get('output_attentions', output_attentions)
        output_hidden_states = inputs.get('output_hidden_states', output_hidden_states)
        return_dict = inputs.get('return_dict', return_dict)
             
        # Anything still `None` falls back to the defaults captured from the config.
        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.return_dict
    
        # Missing masks become all-ones; every mask is cast to float32.
        c_mask = process_mask(c_mask, [batch_size, seq_len, seq_len])
        r_mask = process_mask(r_mask, [batch_size, seq_len, seq_len])
        r_c_mask = process_mask(r_c_mask, [batch_size, seq_len, seq_len])
        c_r_mask = process_mask(c_r_mask, [batch_size, seq_len, seq_len])        
        
        # The `head_mask` argument is discarded: one `None` entry per hidden layer.
        head_mask = [None] * self.num_hidden_layers
            
        dense_embeddings_encoder = []
        dense_embeddings_decoder = []
        for key_name, layers in self.dense_layer_mapping.items():
            
            dense_layer = layers[0]
            layer_norm = layers[1]
            
            layer_input = inputs[key_name]
            
            if key_name in self.key_name_need_newaxis:
                # Add a trailing axis so the Dense layer sees a last dimension.
                layer_input = layer_input[:, :, tf.newaxis]

            # Residual form: layer_norm(dropout(activation(h)) + h) with h = dense(x).
            layer_output = dense_layer(layer_input)
            layer_output = layer_norm(
                self.dropout(
                    self.activation(
                        layer_output
                    ),
                    training=training
                ) + layer_output
            )
            
            if key_name in self.key_names_for_encoder:
                
                dense_embeddings_encoder.append(layer_output)

                # ------------------------------------------------------------
                # run again
                # Same layers and same input are evaluated a second time for the
                # decoder list, which calls dropout a second time.
                # NOTE(review): presumably so encoder and decoder get independent dropout masks - confirm.
                layer_output = dense_layer(layer_input)
                layer_output = layer_norm(
                    self.dropout(
                        self.activation(
                            layer_output
                        ),
                        training=training
                    ) + layer_output
                )
                dense_embeddings_decoder.append(layer_output)
                # ------------------------------------------------------------
            else:
                dense_embeddings_decoder.append(layer_output)
            
        # Stack the per-feature outputs along a new leading axis (axis 0 indexes the feature).
        dense_embeddings_encoder = tf.stack(dense_embeddings_encoder, axis=0)
        dense_embeddings_decoder= tf.stack(dense_embeddings_decoder, axis=0)
                
        # ----------------------------------------------------------------------------------------------------
        
        # Content embeddings use `c_input_ids` with the un-shifted positions.
        c_embedding_output = self.content_embeddings(
            input_ids=c_input_ids, input_ids_2=d_ans_input_ids, position_ids=pos_ids,
            tag_ids=tag_ids, part_ids=part_ids,
            prior_explanation_ids=prior_explanation_ids,
            correct_answer_ids=correct_answer_ids,
            dense_embeddings_encoder=dense_embeddings_encoder,
            dense_embeddings_decoder=dense_embeddings_decoder,
            training=training
        )  # (bs, seq_length, dim)
                
        # Response embeddings use `d_input_ids` with the shifted positions.
        d_embedding_output = self.response_embeddings(
            input_ids=d_input_ids, input_ids_2=d_ans_input_ids, position_ids=shifted_pos_ids,
            tag_ids=tag_ids, part_ids=part_ids,
            prior_explanation_ids=prior_explanation_ids,
            correct_answer_ids=correct_answer_ids,
            dense_embeddings_encoder=dense_embeddings_encoder,
            dense_embeddings_decoder=dense_embeddings_decoder,
            training=training
        )  # (bs, seq_length, dim)

        # The coder's `r_embeds` argument receives the same tensor as `d_embeds`.
        r_embedding_output = d_embedding_output

        outputs = self.coder(
            c_embedding_output,
            r_embedding_output,
            d_embedding_output,
            c_mask,
            r_mask,
            r_c_mask,
            c_r_mask,
            head_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
            training=training           
        )
        
        return outputs    
    
    
class TFEdFormerPreTrainedModel(TFPreTrainedModel):
    """
    Abstract base class for EdFormer models. It only binds the EdFormer
    configuration class and weight prefix; weight initialization and the
    pretrained download/load interface come from `TFPreTrainedModel`.
    """

    # Prefix under which the main layer is stored on concrete models.
    base_model_prefix = "edformer"
    # Configuration class associated with this model family.
    config_class = EdFormerConfig
    
    
class TFEdFormerModel(TFEdFormerPreTrainedModel):
    """Bare EdFormer model: a thin wrapper around `TFEdFormerMainLayer`."""

    def __init__(self, config, *inputs, **kwargs):
        """Create the main layer under the attribute and layer name "edformer"."""
        super().__init__(config, *inputs, **kwargs)

        main_layer = TFEdFormerMainLayer(config, name="edformer")  # Embeddings
        self.edformer = main_layer

    def call(self, inputs, **kwargs):
        """Forward `inputs` and all keyword arguments to the main layer and return its output."""
        return self.edformer(inputs, **kwargs)
    
    
class TFEdFormerAnswerPredictionModel(TFPreTrainedModel):
    """EdFormer main layer followed by two per-position classification heads.

    The first head (`classifier`) outputs 1 value per position, or 2 when
    `config.use_softmax` is set. The second head (`classifier_2`) always outputs
    4 values per position. When `config.use_pre_classifier` is set, each head is
    preceded by its own Dense -> activation -> dropout stage.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Create the main layer, the optional pre-classifiers and both heads."""
        super().__init__(config, *inputs, **kwargs)

        def make_dense(units, layer_name):
            # Every Dense in this model uses the same seeded initializer recipe.
            return tf.keras.layers.Dense(
                units,
                kernel_initializer=get_initializer(config.seed),
                name=layer_name,
            )

        self.use_pre_classifier = config.use_pre_classifier

        self.edformer = TFEdFormerMainLayer(config, name="edformer")  # Embeddings

        if self.use_pre_classifier:
            self.pre_classifier = make_dense(config.dim, "pre_classifier")
            self.pre_classifier_2 = make_dense(config.dim, "pre_classifier_2")

            assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
                config.activation
            )
            self.activation = get_tf_activation(config.activation)

            self.dropout = tf.keras.layers.Dropout(config.seq2seq_dropout)

        # Two outputs when a softmax is applied downstream, otherwise a single output.
        n_targets = 2 if config.use_softmax else 1

        self.classifier = make_dense(n_targets, "classifier")
        self.classifier_2 = make_dense(4, "classifier_2")

    def call(
        self,
        inputs=None,
        c_mask=None,
        r_mask=None,
        r_c_mask=None,
        c_r_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        """Run the main layer and apply both heads to its hidden states.

        Args:
            inputs: feature mapping handed to the main layer.
            c_mask, r_mask, r_c_mask, c_r_mask, head_mask: forwarded to the main layer.
            output_attentions, output_hidden_states: forwarded to the main layer.
            return_dict: falls back to `self.edformer.return_dict` when `None`.
            training: forwarded to the main layer and to dropout.

        Returns:
            Tuple `(logits, answer_logits, c_outputs, r_outputs)`.
        """
        if return_dict is None:
            return_dict = self.edformer.return_dict

        hidden_states, c_outputs, r_outputs = self.edformer(
            inputs,
            c_mask=c_mask,
            r_mask=r_mask,
            r_c_mask=r_c_mask,
            c_r_mask=c_r_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Both heads start from the same hidden states.
        head_input = hidden_states
        head_2_input = hidden_states

        if self.use_pre_classifier:
            # First head's stage runs completely before the second head's stage.
            projected = self.pre_classifier(head_input)  # (bs, seq_len, dim)
            head_input = self.dropout(self.activation(projected), training=training)

            projected_2 = self.pre_classifier_2(head_2_input)  # (bs, seq_len, dim)
            head_2_input = self.dropout(self.activation(projected_2), training=training)

        logits = self.classifier(head_input)  # (bs, seq_len, n_targets), `n_targets = 2` if using softmax.
        answer_logits = self.classifier_2(head_2_input)  # (bs, seq_len, 4)

        return (logits, answer_logits, c_outputs, r_outputs)

### check

In [None]:
# Smoke test: build a tiny model and run one forward pass on constant inputs
# (batch of 3 sequences of length 5) inside a `tf.function`.
# The `or True` makes the guard always pass, so the check also runs on Kaggle.
if not IS_KAGGLE or True:

    # Integer id features.
    content_input_ids = tf.constant(1, shape=[3, 5])
    response_input_ids = tf.constant(1, shape=[3, 5])
    pos_ids = tf.constant(1, shape=[3, 5])
    shifted_pos_ids = tf.constant(1, shape=[3, 5])
    d_input_ids = tf.constant(1, shape=[3, 5])
    d_ans_input_ids = tf.constant(1, shape=[3, 5])
    tag_ids = tf.constant(0, shape=[3, 5, N_TAGS_PER_CONTENT])
    part_ids = tf.constant(1, shape=[3, 5])
    prior_explanation_ids = tf.constant(1, shape=[3, 5])
    # Float features.
    prior_question_elapsed_time_input = tf.constant(1.0, shape=[3, 5])
    lag_time = tf.constant(1.0, shape=[3, 5])
    abs_pos_ids = tf.constant(1.0, shape=[3, 5])
    shifted_abs_pos_ids = tf.constant(1.0, shape=[3, 5])
    task_container_pos_ids = tf.constant(1.0, shape=[3, 5])
    correct_answer_id = tf.constant(1, shape=[3, 5])
    n_questions_answered_scaled = tf.constant(1.0, shape=[3, 5])
    n_lectures_watched_scaled = tf.constant(1.0, shape=[3, 5])
    answered_correctly_ratio = tf.constant(1.0, shape=[3, 5])
    # Features that carry an extra trailing axis.
    part_correctness_ratio = tf.constant(1.0, shape=[3, 5, PART_VOCAB_SIZE - 2])
    part_count_scaled = tf.constant(1.0, shape=[3, 5, PART_VOCAB_SIZE - 2])
    correct_answer_count_scaled = tf.constant(1.0, shape=[3, 5, ANSWER_3_ID - ANSWER_0_ID + 1])
    correct_answer_correctness_ratio = tf.constant(1.0, shape=[3, 5, ANSWER_3_ID - ANSWER_0_ID + 1])
    current_part_count_scaled = tf.constant(1.0, shape=[3, 5])
    current_part_correctness_ratio = tf.constant(1.0, shape=[3, 5])
    current_correct_answer_count_scaled = tf.constant(1.0, shape=[3, 5])
    current_correct_answer_correctness_ratio = tf.constant(1.0, shape=[3, 5])
    current_question_count_scaled = tf.constant(1.0, shape=[3, 5])
    current_question_correctness_ratio = tf.constant(1.0, shape=[3, 5])
    
    inputs = {
        'c_input_ids': content_input_ids,
        'r_input_ids': response_input_ids,
        'd_input_ids': d_input_ids,
        'd_ans_input_ids': d_ans_input_ids,
        'pos_ids': pos_ids,
        'shifted_pos_ids': shifted_pos_ids,
        'tag_ids': tag_ids,
        'part_ids': part_ids,
        'prior_explanation_ids': prior_explanation_ids,
        'prior_question_elapsed_time_input': prior_question_elapsed_time_input,
        'lag_time': lag_time,
        'abs_pos_ids': abs_pos_ids,
        'shifted_abs_pos_ids': shifted_abs_pos_ids,
        'task_container_pos_ids': task_container_pos_ids,
        'correct_answer_id': correct_answer_id,
        'n_questions_answered_scaled': n_questions_answered_scaled,
        'n_lectures_watched_scaled': n_lectures_watched_scaled,
        'answered_correctly_ratio': answered_correctly_ratio,
        'part_correctness_ratio': part_correctness_ratio,
        'part_count_scaled': part_count_scaled,
        'correct_answer_count_scaled': correct_answer_count_scaled,
        'correct_answer_correctness_ratio': correct_answer_correctness_ratio,
        'current_part_count_scaled': current_part_count_scaled,
        'current_part_correctness_ratio': current_part_correctness_ratio,
        'current_correct_answer_count_scaled': current_correct_answer_count_scaled,
        'current_correct_answer_correctness_ratio': current_correct_answer_correctness_ratio,
        'current_question_count_scaled': current_question_count_scaled,
        'current_question_correctness_ratio': current_question_correctness_ratio,
    }

    model_type = 'ed'
    # `use_part` is the attribute `TFEdFormerMainLayer` reads from the config;
    # this call previously passed the misspelled keyword `user_part`.
    config = EdFormerConfig(
        model_type=model_type, model_desc='dummy',
        share_position_embeddings=True, use_tags=True, use_part=True,
        use_prior_explanation=True, use_prior_question_elapsed_time_input=True,
        use_lag_time=True,
        use_lag_time_for_encoder=True,
        use_user_answer=True
    )
    predictor = TFEdFormerAnswerPredictionModel(config)

    @tf.function
    def foo(inputs):
        """Run one forward pass in graph mode and return only the first head's logits."""

        logits, logits_2, c_outputs, r_outputs = predictor(inputs=inputs, output_attentions=True, output_hidden_states=True, return_dict=False)
        return logits

    logits = foo(inputs)
    print(logits)

## Learning rate with warmup

In [None]:
class LinearLR(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to `lr`, followed by a linear decay that is floored at `end_lr`.

    The decay does not start from `lr` but from `min(lr, end_lr)` (see the
    "To avoid overfitting" note in `__call__`). Whenever `lr >= end_lr`, the
    rate after warmup is therefore constant at `end_lr`.
    """

    def __init__(self, total_steps, lr=1e-4, end_lr=1e-6, warmup_steps=WARMUP_STEPS):
        """Store the schedule parameters.

        Args:
            total_steps: total number of training steps; cast to `DTYPE`.
            lr: rate reached at the end of warmup.
            end_lr: lower bound of the rate after warmup.
            warmup_steps: number of warmup steps; cast to `DTYPE`.
        """
        self.lr = lr
        self.end_lr = end_lr
        self.total_steps = tf.cast(total_steps, dtype=DTYPE)
        self.warmup_steps = tf.cast(warmup_steps, dtype=DTYPE)

    def __call__(self, step):
        """Return the learning rate for `step`."""
        # 1.0 while still warming up, 0.0 afterwards; used to blend both phases.
        in_warmup = tf.cast(step < self.warmup_steps, dtype=DTYPE)
        after_warmup = 1 - in_warmup

        warmup_lr = in_warmup * self.lr * (step + 1) / self.warmup_steps

        # To avoid overfitting
        # _lr = self.lr
        decay_start = tf.math.minimum(self.lr, 1 * self.end_lr)

        # Guard against a zero-length decay phase.
        decay_span = tf.math.maximum(self.total_steps - self.warmup_steps, 1)
        steps_into_decay = step - self.warmup_steps + 1
        drop = (decay_start - self.end_lr) / decay_span * steps_into_decay

        decay_lr = after_warmup * (decay_start - drop)
        # Never go below `end_lr` once warmup is over.
        decay_lr = after_warmup * tf.math.maximum(self.end_lr, decay_lr)

        return warmup_lr + decay_lr

class NoamLR(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Noam-style schedule rescaled to peak at `lr`, floored at `end_lr` after warmup."""

    def __init__(self, hidden_size, lr, end_lr, warmup_steps):
        """Store the schedule parameters, each cast to `DTYPE`.

        Args:
            hidden_size: model width used in the Noam formula.
            lr: rate the schedule is rescaled to reach at the end of warmup.
            end_lr: lower bound applied once warmup is over.
            warmup_steps: number of warmup steps.
        """
        self.hidden_size = tf.cast(hidden_size, DTYPE)
        self.warmup_steps = tf.cast(warmup_steps, DTYPE)
        self.lr = tf.cast(lr, DTYPE)
        self.end_lr = tf.cast(end_lr, DTYPE)

    def __call__(self, step):
        """Return the learning rate for `step`."""
        inv_sqrt_hidden = self.hidden_size**-0.5

        # Factor that rescales the raw Noam curve so its warmup peak equals `self.lr`.
        scaling = self.lr / (inv_sqrt_hidden * self.warmup_steps**-0.5)

        one_based_step = step + 1
        noam_term = tf.math.minimum(
            one_based_step * self.warmup_steps**-1.5,
            one_based_step**-0.5,
        )
        raw_lr = scaling * inv_sqrt_hidden * noam_term

        # 1.0 during warmup, 0.0 afterwards.
        in_warmup = tf.cast(step < self.warmup_steps, dtype=DTYPE)

        # The `end_lr` floor only applies after warmup.
        floored_lr = tf.math.maximum(self.end_lr, raw_lr)

        return raw_lr * in_warmup + (1.0 - in_warmup) * floored_lr

In [None]:
def get_input_signatures(batch_size, seq_len, valid=False):
    """Build the `tf.TensorSpec` signature for one batch dictionary.

    Every spec is named after its dictionary key. The entry for
    'current_part_count_scaled' previously carried a garbled `name`; building
    the specs from a table makes key and name agree by construction.

    Args:
        batch_size: size of the leading (batch) dimension.
        seq_len: size of the second (sequence) dimension.
        valid: when true, the four validation-only per-example entries
            ('n_valid_blocks', 'valid_block_idx', 'valid_start', 'valid_end')
            are added.

    Returns:
        A one-element list holding the dict of key -> `tf.TensorSpec`.
    """
    per_example = [batch_size]
    per_step = [batch_size, seq_len]
    per_step_tags = [batch_size, seq_len, N_TAGS_PER_CONTENT]
    per_step_parts = [batch_size, seq_len, PART_VOCAB_SIZE - 2]
    per_step_answers = [batch_size, seq_len, ANSWER_3_ID - ANSWER_0_ID + 1]

    # (key, shape, dtype) in the original declaration order.
    fields = [
        ('user_id', per_example, tf.int64),
        ('seq_len', per_example, tf.int32),
        ('prev_seq_len', per_example, tf.int32),
        ('start', per_example, tf.int32),
        ('end', per_example, tf.int32),
        ('row_id', per_step, tf.int64),
        ('timestamp', per_step, tf.int64),
        ('content_id', per_step, tf.int32),
        ('content_type_id', per_step, tf.int32),
        ('task_container_id', per_step, tf.int32),
        ('user_answer', per_step, tf.int32),
        ('shifted_user_answer', per_step, tf.int32),
        ('answered_correctly', per_step, tf.int32),
        ('shifted_answered_correctly', per_step, tf.int32),
        ('prior_question_elapsed_time', per_step, DTYPE),
        ('prior_question_elapsed_time_input', per_step, DTYPE),
        ('prior_question_had_explanation', per_step, tf.int32),
        ('pred_time_mask', per_step, tf.int32),
        ('abs_pos', per_step, tf.int32),
        ('shifted_abs_pos', per_step, tf.int32),
        ('pos_ids', per_step, tf.int32),
        ('shifted_pos_ids', per_step, tf.int32),
        ('abs_pos_ids', per_step, DTYPE),
        ('shifted_abs_pos_ids', per_step, DTYPE),
        ('task_container_pos_ids', per_step, DTYPE),
        ('c_input_ids', per_step, tf.int32),
        ('r_input_ids', per_step, tf.int32),
        ('d_input_ids', per_step, tf.int32),
        ('d_ans_input_ids', per_step, tf.int32),
        ('tag_ids', per_step_tags, tf.int32),
        ('part_ids', per_step, tf.int32),
        ('prior_explanation_ids', per_step, tf.int32),
        ('lag_time', per_step, DTYPE),
        ('n_questions_answered', per_step, tf.int32),
        ('n_lectures_watched', per_step, tf.int32),
        ('n_questions_answered_correctly', per_step, tf.int32),
        ('n_questions_answered_scaled', per_step, DTYPE),
        ('n_lectures_watched_scaled', per_step, DTYPE),
        ('answered_correctly_ratio', per_step, DTYPE),
        ('part_count_scaled', per_step_parts, DTYPE),
        ('part_correctness_ratio', per_step_parts, DTYPE),
        ('current_part_count_scaled', per_step, DTYPE),
        ('current_part_correctness_ratio', per_step, DTYPE),
        ('correct_answer_count_scaled', per_step_answers, DTYPE),
        ('correct_answer_correctness_ratio', per_step_answers, DTYPE),
        ('current_correct_answer_count_scaled', per_step, DTYPE),
        ('current_correct_answer_correctness_ratio', per_step, DTYPE),
        ('n_prev_seen', per_step, tf.int32),
        ('n_prev_correctness', per_step, tf.int32),
        ('current_question_count_scaled', per_step, DTYPE),
        ('current_question_correctness_ratio', per_step, DTYPE),
        ('correct_answer_id', per_step, tf.int32),
        ('target', per_step, tf.int32),
        ('answer_target', per_step, tf.int32),
        ('nb_pred_places', per_example, tf.int32),
    ]

    if valid:
        # Validation-only bookkeeping, one value per example.
        # Not yet declared: 'valid_blocks_start_pos', 'valid_blocks_end_pos', 'valid_block_pos'.
        fields.extend([
            ('n_valid_blocks', per_example, tf.int32),
            ('valid_block_idx', per_example, tf.int32),
            ('valid_start', per_example, tf.int32),
            ('valid_end', per_example, tf.int32),
        ])

    input_signatures = {
        key: tf.TensorSpec(shape=shape, dtype=dtype, name=key)
        for key, shape, dtype in fields
    }

    return [input_signatures]

## Train Manager

In [None]:
class TrainConfig:
    """Training-run settings plus the step / call counts derived from them.

    The constructor stores its arguments verbatim, then uses `load_data` to read
    the per-user content counts (and, when `valid_fold` is one of 0-3, that fold's
    validation info) in order to derive how many steps and multi-step calls one
    epoch and one validation pass consist of.

    `lr`, `end_lr` and `warmup_steps` start out as `None`; they are assigned
    later by `Train_Manager.get_train_objs`.

    Note: `toJSON()` emits attributes in creation order, so the order of the
    assignments in `__init__` is part of the serialized format.
    """

    def __init__(
        self,
        ckpt_path=CKPT_TRAIN_PATH,
        window_size=WINDOW_SIZE,
        loss_weight_window_size=LOSS_WEIGHT_WINDOW_SIZE,
        n_epochs=N_EPOCHS,
        batch_size=BATCH_SIZE,
        pred_batch_size=PRED_BATCH_SIZE,
        shuffle_buf_size=SHUFFLE_BUFFER_SIZE,
        seed=SEED,
        deterministic=DETERMINISTIC,
        num_parallel_reads=N_PARALLEL_READS,
        num_parallel_calls=N_PARALLEL_CALLS,
        steps_per_call=STEPS_PER_CALL,
        max_n_contents_per_user_for_sampling_prob=MAX_N_CONTENTS_PER_USER_FOR_SAMPLING_PROB,
        valid_fold=VALID_FOLD
    ):

        # --- settings copied straight from the arguments ---
        self.ckpt_path = ckpt_path
        self.window_size = window_size
        self.loss_weight_window_size = loss_weight_window_size
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.pred_batch_size = pred_batch_size
        self.shuffle_buf_size = shuffle_buf_size
        self.seed = seed
        self.deterministic = deterministic
        self.num_parallel_reads = num_parallel_reads
        self.num_parallel_calls = num_parallel_calls
        self.steps_per_call = steps_per_call
        self.max_n_contents_per_user_for_sampling_prob = max_n_contents_per_user_for_sampling_prob
        self.valid_fold = valid_fold

        # --- training examples: one per started window of every user's history ---
        contents_per_user = load_data(n_contents_dict_path)
        n_training_examples = sum(
            math.ceil(n_contents * 1.0 / self.window_size)
            for n_contents in contents_per_user.values()
        )
        del contents_per_user

        # --- validation examples: total number of blocks in the selected fold ---
        n_valid_examples = 0
        if self.valid_fold in [0, 1, 2, 3]:
            fold_valid_info = load_data(valid_info_paths[self.valid_fold])
            n_valid_examples = sum(
                user_info['bundle_info']['n_blocks']
                for user_info in fold_valid_info.values()
            )
            del fold_valid_info

        # This is slightly higher than the actual number of examples used for training.
        self.n_training_examples = n_training_examples
        # The epoch-length multiplier is currently 1, i.e. epochs are not lengthened.
        raw_steps_per_epoch = 1 * self.n_training_examples // self.batch_size
        # Round the epoch down to a whole number of multi-step calls.
        self.n_training_calls_per_epoch = raw_steps_per_epoch // self.steps_per_call
        self.n_training_steps_per_epoch = self.n_training_calls_per_epoch * self.steps_per_call
        self.n_training_steps = self.n_epochs * self.n_training_steps_per_epoch

        self.n_valid_examples = n_valid_examples
        if IS_KAGGLE and tpu is not None:
            # On Kaggle with a TPU, count full prediction batches only.
            self.n_valid_examples -= self.n_valid_examples % self.pred_batch_size
        # Ceiling division: a trailing partial batch still costs one step.
        full_batches, leftover_examples = divmod(self.n_valid_examples, self.pred_batch_size)
        self.n_valid_steps = full_batches + int(leftover_examples > 0)
        # Same idea one level up: a trailing partial call still counts as a call.
        self.n_steps_in_last_valid_call = self.n_valid_steps % self.steps_per_call
        self.n_valid_calls = (
            self.n_valid_steps // self.steps_per_call
            + int(self.n_steps_in_last_valid_call > 0)
        )

        # Learning-rate settings are not known yet at construction time.
        self.lr = None
        self.end_lr = None
        self.warmup_steps = None

    def toJSON(self):
        """Return the instance attributes as an indented JSON string (creation order)."""

        def attributes_of(obj):
            # Fallback for objects `json` cannot encode natively: dump their attributes.
            return obj.__dict__

        return json.dumps(self, default=attributes_of, sort_keys=False, indent=4)

class Train_Manager:
    
    def __init__(self, config, train_config):
        """Store both configs, build the per-user sampling table and all datasets.

        Args:
            config: Model configuration object (read by the dataset builders).
            train_config: Training configuration; `valid_fold`,
                `max_n_contents_per_user_for_sampling_prob` and the dataset
                settings are read from it.
        """

        self.config = config
        self.train_config = train_config

        contents_per_user = load_data(n_contents_dict_path)

        has_valid_fold = self.train_config.valid_fold in [0, 1, 2, 3]
        if has_valid_fold:
            # Get the splitted training dataset.
            split_index_by_user = load_data(train_valid_split_indices_paths[self.train_config.valid_fold])
            self.valid_user_ids = list(split_index_by_user.keys())
        else:
            split_index_by_user = {}
            self.valid_user_ids = []

        user_ids = []
        train_lengths = []
        for user_id, n_contents_orig in contents_per_user.items():
            # Users listed in the split contribute only up to their split index;
            # everyone else contributes their whole history. (Without a fold the
            # split dict is empty, so the membership test is always False.)
            if user_id in split_index_by_user:
                n_contents_in_train = split_index_by_user[user_id]
            else:
                n_contents_in_train = n_contents_orig

            user_ids.append(user_id)
            # Floor the length at 32.
            train_lengths.append(max(n_contents_in_train, 32))

        user_id_tensor = tf.constant(user_ids, dtype=tf.int64)
        # Cap the length so that very long histories do not dominate the weights.
        capped_lengths = tf.math.minimum(
            tf.constant(train_lengths, dtype=tf.int32),
            self.train_config.max_n_contents_per_user_for_sampling_prob
        )
        length_total = tf.cast(tf.math.reduce_sum(capped_lengths), dtype=tf.float32)
        sample_prob_tensor = tf.cast(capped_lengths, dtype=tf.float32) / length_total
        # Rescale so the largest weight is 0.9. Use `0.9` for some more randomness
        sample_prob_tensor = 0.9 * sample_prob_tensor / tf.math.reduce_max(sample_prob_tensor)

        del user_ids
        del train_lengths
        del contents_per_user

        # user_id -> sampling weight; ids missing from the table map to 0.0.
        self.training_sample_prob_table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(user_id_tensor, sample_prob_tensor),
            default_value=0.0
        )

        # Distributed training datasets (built in this fixed order).
        self.train_ds = strategy.experimental_distribute_dataset(
            self.get_train_ds(from_valid=False)
        )
        self.train_ds_from_valid = strategy.experimental_distribute_dataset(
            self.get_train_ds(from_valid=True)
        )
        self.train_ds_from_valid_only_last = strategy.experimental_distribute_dataset(
            self.get_train_ds(from_valid=True, only_last=True)
        )

        # Distributed validation dataset, plus a non-distributed one for debugging.
        self.valid_ds = strategy.experimental_distribute_dataset(
            self.get_valid_ds(debug=False)
        )

        self.valid_ds_debug = self.get_valid_ds(debug=True)

    def toJSON(self):
        """Serialize the model config and the training config as one JSON object.

        Both configs are round-tripped through their own `toJSON()`. When a key
        exists in both, the training config's value wins but the key keeps the
        position it had in the model config.

        Returns:
            str: Indented JSON text, with non-ASCII characters left unescaped.
        """

        merged = json.loads(self.config.toJSON())
        merged.update(json.loads(self.train_config.toJSON()))

        return json.dumps(merged, sort_keys=False, ensure_ascii=False, indent=4)

    def get_train_ds(self, from_valid=False, only_last=False):
        """Build the infinite, batched training dataset.

        Pipeline: read TFRecords -> parse -> add aggregated history -> (when a
        validation fold is configured) apply the train/valid split -> repeat each
        example once per window of its sequence -> repeat forever, shuffle and
        ragged-batch -> cut fixed-length subsequences -> add model inputs and
        targets -> prefetch.

        Args:
            from_valid: Not read by this method.
            only_last: Overwritten with `False` below, so the caller's value has
                no effect.

        Returns:
            The prefetched dataset produced by `prepare_training_dataset`.
        """

        # NOTE(review): this unconditionally discards the caller's `only_last`
        # (and `from_valid` is unused), so all three datasets built in `__init__`
        # are configured identically. Confirm whether that is intended.
        only_last = False

        tc = self.train_config
        parallel_opts = dict(num_parallel_calls=tc.num_parallel_calls, deterministic=tc.deterministic)

        train_raw_ds = tf.data.TFRecordDataset(train_tfrec_paths, num_parallel_reads=tc.num_parallel_reads)
        train_raw_ds = train_raw_ds.map(parse_train_example, **parallel_opts)

        # add aggregated historical information
        train_raw_ds = train_raw_ds.map(add_aggregated_historical_information, **parallel_opts)

        if tc.valid_fold in [0, 1, 2, 3]:
            # Get the splitted training dataset.
            split_table = convert_split_index_dict(
                load_data(train_valid_split_indices_paths[tc.valid_fold])
            )
            reduced_raw_ds = split_train_ds(train_raw_ds, split_table, **parallel_opts)
        else:
            reduced_raw_ds = train_raw_ds

        def repeat_once_per_window(example):
            # ceil(seq_len / window_size) copies of the example.
            n_windows = tf.math.ceil(tf.cast(example['seq_len'], tf.float32) / tc.window_size)
            return tf.data.Dataset.from_tensors(example).repeat(tf.cast(n_windows, dtype=tf.int64))

        reduced_raw_ds_augmented = reduced_raw_ds.flat_map(repeat_once_per_window)

        # batch examples with attributes having different lengths across examples - tf.RaggedTensor
        endless_ds = reduced_raw_ds_augmented.repeat()
        shuffled_ds = endless_ds.shuffle(tc.shuffle_buf_size, seed=tc.seed)
        batched_raw_ds = shuffled_ds.apply(tf.data.experimental.dense_to_ragged_batch(tc.batch_size))

        # batch - tf.Tensor: Extract random subsequences of a fixed length
        batched_ds = random_subseqs_from_batched_raw_ds(
            batched_raw_ds,
            window_size=tc.window_size,
            only_last=only_last,
            seed=tc.seed,
            **parallel_opts
        )

        # Add input ids and targets for training
        train_ds = prepare_training_dataset(
            batched_ds,
            generative=self.config.generative,
            use_abs_pos=self.config.use_abs_pos,
            **parallel_opts
        )

        return train_ds.prefetch(8)

    def get_valid_ds(self, debug=False):
        """Build the batched validation dataset for the configured fold.

        Reads `valid_tfrec_paths`, parses each record, keeps only the examples
        whose `user_id` is in `self.valid_user_ids`, and hands the result to
        `prepare_validation_dataset`.

        Args:
            debug: Currently unused; kept so existing callers keep working.

        Returns:
            The prefetched validation dataset.
        """

        p = valid_tfrec_paths
        num_parallel_reads = self.train_config.num_parallel_reads
        num_parallel_calls = self.train_config.num_parallel_calls
        deterministic = self.train_config.deterministic

        # Membership table: validation user ids map to 1, every other id to 0.
        valid_user_id_tensor = tf.constant(self.valid_user_ids, dtype=tf.int64)
        repeat_tensor = tf.ones_like(valid_user_id_tensor, dtype=tf.int32)
        initializer = tf.lookup.KeyValueTensorInitializer(valid_user_id_tensor, repeat_tensor)
        valid_users_table = tf.lookup.StaticHashTable(
            initializer, default_value=tf.constant(0, dtype=tf.int32), name=None
        )

        valid_raw_ds = tf.data.TFRecordDataset(p, num_parallel_reads=num_parallel_reads)
        valid_raw_ds = valid_raw_ds.map(parse_train_example_with_valid_info, num_parallel_calls=num_parallel_calls, deterministic=deterministic)
        # don't include examples in other folds
        # `Dataset.filter` returns a new dataset, so the result must be re-bound;
        # previously it was dropped and the filter never took effect.
        valid_raw_ds = valid_raw_ds.filter(lambda x: tf.math.equal(valid_users_table.lookup(x['user_id']), 1))
        valid_ds = prepare_validation_dataset(valid_raw_ds, batch_size=self.train_config.pred_batch_size, window_size=self.train_config.window_size, generative=self.config.generative, use_abs_pos=self.config.use_abs_pos, num_parallel_calls=num_parallel_calls, deterministic=deterministic)
        valid_ds = valid_ds.prefetch(8)

        # Previously the method fell off the end and callers received `None`.
        return valid_ds

    def get_train_objs(self, lr=LR, end_lr=END_LR, warmup_steps=WARMUP_STEPS):
        """Create the model, optimizer, losses and metrics inside the strategy scope.

        The learning-rate settings are also recorded on `self.train_config` so
        that they appear in its serialized form.

        Args:
            lr: Learning rate passed to the `LinearLR` schedule.
            end_lr: End learning rate passed to the `LinearLR` schedule.
            warmup_steps: Warm-up step count passed to the `LinearLR` schedule.

        Returns:
            tuple: `(predictor, optimizer, loss_obj, loss_obj_answer, metrics)`,
            where `metrics` is `(loss_metric, acc_metric, auc_metric,
            loss_metric_answer, acc_metric_answer)`.
        """

        self.train_config.lr = lr
        # No trailing comma here: it would store the 1-tuple `(end_lr,)` instead of the scalar.
        self.train_config.end_lr = end_lr
        self.train_config.warmup_steps = warmup_steps

        with strategy.scope():

            predictor = TFEdFormerAnswerPredictionModel(self.config)

            ### _lr = NoamLR(hidden_size=self.config.dim, lr=lr, end_lr=end_lr, warmup_steps=warmup_steps)
            _lr = LinearLR(total_steps=self.train_config.n_training_steps, lr=lr, end_lr=end_lr, warmup_steps=warmup_steps)

            # optimizer = tf.keras.optimizers.Adam(learning_rate=_lr, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
            # The Adam hyper-parameters are passed as `tf.Variable`s rather than Python floats.
            optimizer = tf.keras.optimizers.Adam(
                _lr,
                beta_1=tf.Variable(0.9),
                beta_2=tf.Variable(0.999),
                epsilon=tf.Variable(1e-8)
            )
            optimizer.iterations  # this access will invoke optimizer._iterations method and create optimizer.iter attribute
            optimizer.decay = tf.Variable(0.0) # Adam.__init__ assumes ``decay`` is a float object, so this needs to be converted to tf.Variable **after** __init__ method.

            # Correctness loss: 2-way softmax or a single sigmoid logit, depending on
            # the model config. `Reduction.NONE` keeps one loss value per element so
            # the caller can do its own reduction.
            if self.config.use_softmax:
                loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
                    from_logits=True, reduction=tf.keras.losses.Reduction.NONE,
                    name='sparse_categorical_crossentropy_for_correctness'
                )
            else:
                loss_obj = tf.keras.losses.BinaryCrossentropy(
                    from_logits=True, label_smoothing=0, reduction=tf.keras.losses.Reduction.NONE,
                    name='binary_crossentropy_for_correctness'
                )

            # Loss for the answer-prediction head.
            loss_obj_answer = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE,
                name='sparse_categorical_crossentropy_for_answer'
            )

            # Metrics for the correctness head.
            loss_metric = tf.keras.metrics.Sum()
            acc_metric = tf.keras.metrics.BinaryAccuracy()
            auc_metric = tf.keras.metrics.AUC(num_thresholds=2000)

            # --------------------------------------------------

            # Metrics for the answer head.
            loss_metric_answer = tf.keras.metrics.Sum()
            acc_metric_answer = tf.keras.metrics.SparseCategoricalAccuracy()

            # --------------------------------------------------

            metrics = (loss_metric, acc_metric, auc_metric, loss_metric_answer, acc_metric_answer)

        return (predictor, optimizer,  loss_obj, loss_obj_answer, metrics)

    def train_valid(
            self, predictor, optimizer, ckpt_manager,
            loss_obj, loss_obj_answer, metrics,
            last_epoch=0,
            from_valid=False, valid_epochs=None, only_valid=False
        ):

        use_user_answer_loss = tf.cast(tf.constant(self.config.use_user_answer_loss), dtype=DTYPE)

        batch_size = self.train_config.batch_size
        seq_len = self.train_config.window_size

        encoder_decoder = tf.constant(self.config.model_type=='ed', dtype=tf.int32)
        generative = tf.constant(self.config.generative, dtype=tf.int32)
        allow_bundle_atten = tf.constant(self.config.allow_bundle_atten, dtype=tf.int32)

        loss_metric, acc_metric, auc_metric, loss_metric_answer, acc_metric_answer = metrics

        @tf.function(
            input_signature=get_input_signatures(int(batch_size / strategy.num_replicas_in_sync), seq_len, valid=False)
        )
        def train_step(train_batch):

            training = tf.constant(1, dtype=tf.int32)
        
            c_mask, r_mask, r_c_mask, c_r_mask = get_attention_masks(train_batch, training, encoder_decoder, generative, allow_bundle_atten)

            targets = train_batch['target']
            answer_targets = train_batch['answer_target']
            pred_mask = targets != NON_TARGET_ID
            pred_indices = tf.where(pred_mask)

            selected_targets = tf.gather_nd(targets, pred_indices)
            selected_answer_targets = tf.gather_nd(answer_targets, pred_indices)
            
            # Compute the gradients for a list of variables.
            with tf.GradientTape() as tape:

                # `logits`: shape = [batch_size, seq_len, n_targets], `n_targets = 2` if `use_softmax`, otherwise `1`.
                (logits, answer_logits, c_outputs, r_outputs) = predictor(
                    train_batch,
                    c_mask,
                    r_mask,
                    r_c_mask,
                    c_r_mask,
                    output_attentions=False, output_hidden_states=False, training=True
                )
                
                # shape = [n_selected_places, n_targets], `n_targets = 2` if `use_softmax`, otherwise `1`.
                selected_logits = tf.gather_nd(logits, pred_indices)
                
                # shape = [n_selected_places, 4]
                selected_answer_logits = tf.gather_nd(answer_logits, pred_indices)

                # From the doc, `tf.keras.losses.SparseCategoricalCrossentropy` should use `selected_targets`, but it seems ok to have the last dimension.
                # For `BinaryCrossentropy`, we need to use `selected_targets[:, tf.newaxis]` to have the 2nd dimension so the losses are not averaged.
                losses = loss_obj(selected_targets[:, tf.newaxis], selected_logits)
                
                answer_losses = loss_obj_answer(selected_answer_targets, selected_answer_logits)
                
                total_loss = tf.math.reduce_sum(losses)
                total_answer_loss = tf.math.reduce_sum(answer_losses)

                # `train_batch['nb_pred_places'][0]` is the total number of places used for calculating loss
                loss = (total_loss + total_answer_loss * use_user_answer_loss) / tf.cast(train_batch['nb_pred_places'][0], dtype=DTYPE)
                            
            grads = tape.gradient(loss, predictor.trainable_variables)

            # Process the gradients, for example cap them, etc.
            # capped_grads = [MyCapper(g) for g in grads]
            ### processed_grads = [process_gradient(g) for g in grads]
            ### processed_grads = grads

            # Ask the optimizer to apply the processed gradients.
            optimizer.apply_gradients(zip(grads, predictor.trainable_variables))
            
            if self.config.use_softmax:
                # shape = [batch_size, seq_len, 2]
                preds = tf.math.softmax(logits)
                # shape = [n_selected_places, 2]
                selected_preds = tf.math.softmax(selected_logits)
            else:
                # shape = [batch_size, seq_len, 1]
                preds = tf.math.sigmoid(logits)
                # shape = [n_selected_places, 1]
                selected_preds = tf.math.sigmoid(selected_logits)
            
            answer_preds = tf.math.softmax(answer_logits, axis=-1)
            selected_answer_preds = tf.math.softmax(selected_answer_logits, axis=-1)
            
            loss_metric.update_state(total_loss)
            
            # Use `selected_preds[:, -1:]` to get the probabilities for class `1`, with the 2nd dim
            acc_metric.update_state(selected_targets[:, tf.newaxis], selected_preds[:, -1:])
            # Use `selected_preds[:, -1]` to get the probabilities for class `1`, without the 2nd dim
            auc_metric.update_state(selected_targets, selected_preds[:, -1])

            loss_metric_answer.update_state(total_answer_loss)
            acc_metric_answer.update_state(selected_answer_targets, selected_answer_preds)          
            
        batch_size = None
        @tf.function(
            input_signature=get_input_signatures(batch_size, seq_len, valid=True)
        )
        def valid_step(valid_batch):

            training = tf.constant(0, dtype=tf.int32)

            c_mask, r_mask, r_c_mask, c_r_mask = get_attention_masks(valid_batch, training, encoder_decoder, generative, allow_bundle_atten)
            
            # if generative == 1:

            #     start_pos = self.train_config.window_size - MAX_PRED_TIME_QUESTION_BUNDLE_LEN
            #     _, logits = predictor.generate(
            #         valid_batch, start_pos=start_pos, window_size=self.train_config.window_size, dim=self.config.dim,
            #         c_mask=c_mask, r_mask=r_mask, r_c_mask=r_c_mask, c_r_mask=c_r_mask
            #     )

            # else:

            logits, answer_logits, _, _ = predictor(
                valid_batch,
                c_mask,
                r_mask,
                r_c_mask,
                c_r_mask,
                output_attentions=False, output_hidden_states=False, training=False
            )
            
            # `targets` are defined for all places (other than [PAD] and lectures).
            # However, during validation, unlike during training, we only focus on the places that are in prediction time (and being questions).
            pred_time_mask = valid_batch['pred_time_mask'] * tf.cast(valid_batch['pred_time_mask'] != PAD_TOKEN, dtype=tf.int32)
            targets = valid_batch['target']
            answer_targets = valid_batch['answer_target']

           # only select the last `MAX_PRED_TIME_QUESTION_BUNDLE_LEN`
            pred_time_mask = pred_time_mask[:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:]
            targets = targets[:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:]
            answer_targets = answer_targets[:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:]
            logits = logits[:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:, :]
            answer_logits = answer_logits[:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:, :]
            row_ids = valid_batch['row_id'][:, -MAX_PRED_TIME_QUESTION_BUNDLE_LEN:]
            user_ids = valid_batch['user_id'][:, tf.newaxis] * tf.ones_like(row_ids, dtype=tf.int64)

            targets = targets * pred_time_mask + NON_TARGET_ID * (1 - pred_time_mask)
            answer_targets = answer_targets * pred_time_mask + NON_TARGET_ID * (1 - pred_time_mask)

            pred_mask = targets != NON_TARGET_ID
            pred_indices = tf.where(pred_mask)

            selected_targets = tf.gather_nd(targets, pred_indices)
            selected_logits = tf.gather_nd(logits, pred_indices)

            if self.config.use_softmax:
                # shape = [batch_size, seq_len, 2]
                preds = tf.math.softmax(logits)
                # shape = [n_selected_places, 2]
                selected_preds = tf.math.softmax(selected_logits)
            else:
                # shape = [batch_size, seq_len, 1]
                preds = tf.math.sigmoid(logits)
                # shape = [n_selected_places, 1]
                selected_preds = tf.math.sigmoid(selected_logits)            

            answer_preds = tf.math.softmax(answer_logits, axis=-1)

            selected_answer_targets = tf.gather_nd(answer_targets, pred_indices)
            selected_answer_logits = tf.gather_nd(answer_logits, pred_indices)
            selected_answer_preds = tf.math.softmax(selected_answer_logits, axis=-1)

            # From the doc, `tf.keras.losses.SparseCategoricalCrossentropy` should use `selected_targets`, but it seems ok to have the last dimension.
            # For `BinaryCrossentropy`, we need to use `selected_targets[:, tf.newaxis]` to have the 2nd dimension so the losses are not averaged.
            losses = loss_obj(selected_targets[:, tf.newaxis], selected_logits)
            
            answer_losses = loss_obj_answer(selected_answer_targets, selected_answer_logits)

            total_loss = tf.math.reduce_sum(losses)
            total_answer_loss = tf.math.reduce_sum(answer_losses)

            # `train_batch['nb_pred_places'][0]` is the total number of places used for calculating loss
            loss = (total_loss + total_answer_loss * use_user_answer_loss) / tf.cast(tf.math.reduce_mean(valid_batch['nb_pred_places']), dtype=DTYPE)

            loss_metric.update_state(total_loss)
            
            # Use `selected_preds[:, -1:]` to get the probabilities for class `1`, with the 2nd dim
            acc_metric.update_state(selected_targets[:, tf.newaxis], selected_preds[:, -1:])
            # Use `selected_preds[:, -1]` to get the probabilities for class `1`, without the 2nd dim
            auc_metric.update_state(selected_targets, selected_preds[:, -1])

            loss_metric_answer.update_state(total_answer_loss)
            acc_metric_answer.update_state(selected_answer_targets, selected_answer_preds)   

            return targets, preds, row_ids, user_ids, pred_mask, answer_targets, answer_preds

        @tf.function
        def dist_train_step(dist_train_batch):
            
            strategy.run(train_step, args=(dist_train_batch,))

        @tf.function
        def dist_valid_step(dist_valid_batch):
            
            targets, preds, row_ids, user_ids, pred_mask, answer_targets, answer_preds = strategy.run(valid_step, args=(dist_valid_batch,))
 
            return targets, preds, row_ids, user_ids, pred_mask, answer_targets, answer_preds

        @tf.function
        def dist_train_multi_steps(dist_train_iter):

            total_nb_pred_places = tf.constant(0.0, dtype=DTYPE)
 
            for _ in tf.range(self.train_config.steps_per_call):

                dist_train_batch = next(dist_train_iter)
                dist_train_step(dist_train_batch)
                
                if type(dist_train_batch['nb_pred_places']) == tf.Tensor:
                    nb_pred_places = tf.math.reduce_mean(dist_train_batch['nb_pred_places'])
                else:  # type(dist_train_batch['nb_pred_places']) == tf.python.distribute.values.PerReplica
                       # `dist_train_batch['nb_pred_places']` is PerPlica --> Use `.values` attribute to access tensors.
                    nb_pred_places = tf.math.reduce_mean(tf.concat(dist_train_batch['nb_pred_places'].values, axis=0))

                total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)

            return total_nb_pred_places

        @tf.function
        def dist_valid_multi_steps(dist_valid_iter):

            total_nb_pred_places = tf.constant(0.0, dtype=DTYPE)

            t_arr_1 = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_2 = tf.TensorArray(DTYPE, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_3 = tf.TensorArray(tf.bool, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_4 = tf.TensorArray(tf.int64, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_5 = tf.TensorArray(tf.int64, size=0, dynamic_size=True, clear_after_read=False)

            t_arr_6 = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_7 = tf.TensorArray(DTYPE, size=0, dynamic_size=True, clear_after_read=False)

            for step_idx in tf.range(self.train_config.steps_per_call):

                dist_valid_batch = next(dist_valid_iter)
                _targets, _preds, _row_ids, _user_ids, _pred_mask, _answer_targets, _answer_preds = dist_valid_step(dist_valid_batch)
                
                if type(dist_valid_batch['nb_pred_places']) == tf.Tensor:
                    nb_pred_places = tf.math.reduce_mean(dist_valid_batch['nb_pred_places'])
                else:

                    _targets = tf.concat(_targets.values, axis=0)
                    _preds = tf.concat(_preds.values, axis=0)
                    _pred_mask = tf.concat(_pred_mask.values, axis=0)

                    _row_ids = tf.concat(_row_ids.values, axis=0)
                    _user_ids = tf.concat(_user_ids.values, axis=0)

                    _answer_targets = tf.concat(_answer_targets.values, axis=0)
                    _answer_preds = tf.concat(_answer_preds.values, axis=0)                    

                    # `dist_valid_batch['nb_pred_places']` is PerPlica --> Use `.values` attribute to access tensors.
                    nb_pred_places = tf.math.reduce_mean(tf.concat(dist_valid_batch['nb_pred_places'].values, axis=0))

                total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)
            
                t_arr_1 = t_arr_1.write(step_idx, _targets)
                t_arr_2 = t_arr_2.write(step_idx, _preds)
                t_arr_3 = t_arr_3.write(step_idx, _pred_mask)
                t_arr_4 = t_arr_4.write(step_idx, _row_ids)
                t_arr_5 = t_arr_5.write(step_idx, _user_ids)
                t_arr_6 = t_arr_6.write(step_idx, _answer_targets)
                t_arr_7 = t_arr_7.write(step_idx, _answer_preds)

            targets = t_arr_1.concat()
            preds = t_arr_2.concat()
            pred_mask = t_arr_3.concat()
            row_ids = t_arr_4.concat()
            user_ids = t_arr_5.concat()
            answer_targets = t_arr_6.concat()
            answer_preds = t_arr_7.concat()

            return targets, preds, row_ids, user_ids, pred_mask, total_nb_pred_places, answer_targets, answer_preds

        @tf.function
        def dist_valid_multi_steps_last_call(dist_valid_iter):

            total_nb_pred_places = tf.constant(0.0, dtype=DTYPE)

            t_arr_1 = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_2 = tf.TensorArray(DTYPE, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_3 = tf.TensorArray(tf.bool, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_4 = tf.TensorArray(tf.int64, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_5 = tf.TensorArray(tf.int64, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_6 = tf.TensorArray(tf.int32, size=0, dynamic_size=True, clear_after_read=False)
            t_arr_7 = tf.TensorArray(DTYPE, size=0, dynamic_size=True, clear_after_read=False)

            for step_idx in tf.range(self.train_config.n_steps_in_last_valid_call):

                dist_valid_batch = next(dist_valid_iter)
                _targets, _preds, _row_ids, _user_ids, _pred_mask, _answer_targets, _answer_preds = dist_valid_step(dist_valid_batch)
                
                if type(dist_valid_batch['nb_pred_places']) == tf.Tensor:
                    nb_pred_places = tf.math.reduce_mean(dist_valid_batch['nb_pred_places'])
                else:

                    _targets = tf.concat(_targets.values, axis=0)
                    _preds = tf.concat(_preds.values, axis=0)
                    _pred_mask = tf.concat(_pred_mask.values, axis=0)

                    _row_ids = tf.concat(_row_ids.values, axis=0)
                    _user_ids = tf.concat(_user_ids.values, axis=0)

                    _answer_targets = tf.concat(_answer_targets.values, axis=0)
                    _answer_preds = tf.concat(_answer_preds.values, axis=0)    

                    nb_pred_places = tf.math.reduce_mean(tf.concat(dist_valid_batch['nb_pred_places'].values, axis=0))

                total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)

                t_arr_1 = t_arr_1.write(step_idx, _targets)
                t_arr_2 = t_arr_2.write(step_idx, _preds)
                t_arr_3 = t_arr_3.write(step_idx, _pred_mask)
                t_arr_4 = t_arr_4.write(step_idx, _row_ids)
                t_arr_5 = t_arr_5.write(step_idx, _user_ids)
                t_arr_6 = t_arr_6.write(step_idx, _answer_targets)
                t_arr_7 = t_arr_7.write(step_idx, _answer_preds)

            targets = t_arr_1.concat()
            preds = t_arr_2.concat()
            pred_mask = t_arr_3.concat()
            row_ids = t_arr_4.concat()
            user_ids = t_arr_5.concat()
            answer_targets = t_arr_6.concat()
            answer_preds = t_arr_7.concat()

            return targets, preds, row_ids, user_ids, pred_mask, total_nb_pred_places, answer_targets, answer_preds

        n_training_steps_per_epoch = self.train_config.n_training_steps_per_epoch
        n_training_steps = self.train_config.n_epochs * n_training_steps_per_epoch

        if MAX_TRAIN_ITER_STEPS is not None:
            n_training_steps_per_epoch = MAX_TRAIN_ITER_STEPS
            n_training_steps = self.train_config.n_epochs * n_training_steps_per_epoch
           
        self.train_config.n_training_steps_per_epoch = n_training_steps_per_epoch
        self.train_config.n_training_steps = n_training_steps   

        with open('train_config.json', 'w', encoding='UTF-8') as fp:
            json.dump(json.loads(self.toJSON()), fp, ensure_ascii=False, indent=4)
        if not IS_KAGGLE:
            !gsutil cp -r './train_config.json' '{self.train_config.ckpt_path}'
     
        n_epochs = self.train_config.n_epochs
        n_training_examples = self.train_config.n_training_examples

        # TODO: Not use hard coded numbers
        if from_valid:
            n_epochs = 3
            n_training_examples = int(2380588 / self.train_config.window_size)
            n_training_steps_per_epoch = int(n_training_examples / self.train_config.batch_size)
            n_training_steps = n_epochs * n_training_steps_per_epoch

        print(f'n_epochs: {n_epochs}')
        print(f'n_training_examples: {n_training_examples}')
        print(f'n_training_steps_per_epoch: {n_training_steps_per_epoch}')
        print(f'n_training_steps: {n_training_steps}')
        print(f'n_valid_examples: {self.train_config.n_valid_examples}')
        print(f'n_valid_steps: {self.train_config.n_valid_steps}')        
        
        training_history = dict()

        def train(last_epoch=0, from_valid=False, valid_epochs=None):
            """Run the training loop, saving a checkpoint and a history file after every epoch.

            Args:
                last_epoch: last epoch already completed; training starts at `last_epoch + 1`.
                from_valid: if true, iterate `self.train_ds_from_valid` for 3 epochs instead of
                    iterating `self.train_ds` up to `self.train_config.n_epochs`.
                valid_epochs: epochs after which `valid(epoch)` is run. `None` means only after
                    epoch `self.train_config.n_epochs`.
            """

            start_epoch = last_epoch + 1

            end_epoch = self.train_config.n_epochs + 1
            train_ds = self.train_ds
            if from_valid:
                # Hard-coded: 3 extra epochs on the dataset built from the validation split.
                end_epoch = start_epoch + 3
                train_ds = self.train_ds_from_valid

            # A single iterator is created here and shared by all epochs.
            dist_train_iter = iter(train_ds)

            for epoch in range(start_epoch, end_epoch):

                training_history[epoch] = {
                    'loss': [],
                    'acc': [],
                    'auc': [],
                    'answer_loss': [],
                    'answer_acc': [],
                    'timing': []        
                }

                n_steps = 0
                # NOTE(review): `n_preds` and `n_last_preds` are never read in this function.
                n_preds = 0
                n_last_preds = 0

                # Running count of predicted places in this epoch, used to average the losses.
                total_nb_pred_places = 0.0

                start_epoch_t = datetime.datetime.now()

                for call_idx in range(self.train_config.n_training_calls_per_epoch):

                    start_t = datetime.datetime.now()

                    # One call runs `steps_per_call` training steps and returns the number of
                    # predicted places it covered.
                    nb_pred_places = dist_train_multi_steps(dist_train_iter)                    
                    n_steps += self.train_config.steps_per_call

                    total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)
                    
                    # The loss metrics are divided by the place count here, so they presumably
                    # accumulate sums rather than means -- TODO confirm against their definition.
                    avg_loss = (loss_metric.result() / total_nb_pred_places).numpy()
                    acc = acc_metric.result().numpy()
                    auc = auc_metric.result().numpy()

                    avg_answer_loss = (loss_metric_answer.result() / total_nb_pred_places).numpy()
                    answer_acc = acc_metric_answer.result().numpy()

                    # --------------------------------------------------  

                    # Record the running metrics about 100 times per epoch (every call if fewer).
                    if (call_idx + 1) % max(1, self.train_config.n_training_calls_per_epoch // 100) == 0:

                        training_history[epoch]['loss'].append(float(avg_loss))
                        training_history[epoch]['acc'].append(float(acc))
                        training_history[epoch]['auc'].append(float(auc))
                        training_history[epoch]['answer_loss'].append(float(avg_answer_loss))
                        training_history[epoch]['answer_acc'].append(float(answer_acc))

                    # --------------------------------------------------        

                    elapsed = (datetime.datetime.now() - start_t).total_seconds()

                    # Print progress about 10 times per epoch (every call if fewer).
                    if (call_idx + 1) % max(1, self.train_config.n_training_calls_per_epoch // 10) == 0:

                        print(f'epoch: {epoch} - step: {n_steps}')
                        print(f'timing per step: {elapsed / self.train_config.steps_per_call}')
                        print(f'loss: {avg_loss}')
                        print(f'acc: {acc}')
                        print(f'auc: {auc}')
                        print(f'answer_loss: {avg_answer_loss}')
                        print(f'answer_acc: {answer_acc}')
                        print(f'lr: {optimizer._decayed_lr(DTYPE)}')        
                        print('-' * 80)

                end_epoch_t = datetime.datetime.now()
                elapsed_epoch_t = (end_epoch_t - start_epoch_t).total_seconds()
                # Replaces the initial empty list with the epoch duration in seconds.
                training_history[epoch]['timing'] = elapsed_epoch_t

                # End-of-epoch values are those computed after the last call of the epoch.
                training_history[epoch]['train_loss'] = float(avg_loss)
                training_history[epoch]['train_acc'] = float(acc)
                training_history[epoch]['train_auc'] = float(auc)
                training_history[epoch]['train_answer_loss'] = float(avg_answer_loss)
                training_history[epoch]['train_answer_acc'] = float(answer_acc)

                if IS_KAGGLE and tpu:
                    # On a Kaggle TPU the checkpoint is written through the local host device.
                    # NOTE(review): the prefix is `directory + 'ckpt'` with no path separator;
                    # confirm `ckpt_manager.directory` ends with one.
                    ckpt_manager.checkpoint.save(
                        file_prefix=ckpt_manager.directory + 'ckpt', options=tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
                    )
                else:
                    ckpt_manager.save()

                print(f'epoch: {epoch} - train')
                print(f'train loss: {avg_loss}')
                print(f'train acc: {acc}')
                print(f'train auc: {auc}')
                print(f'train answer loss: {avg_answer_loss}')
                print(f'train answer acc: {answer_acc}')
                print('-' * 80)

                # Metrics are shared with `valid`, so they are cleared before anything else runs.
                loss_metric.reset_states()
                acc_metric.reset_states()
                auc_metric.reset_states()
                loss_metric_answer.reset_states()
                acc_metric_answer.reset_states()

                # --------------------------------------------------
                # saving

                # The whole history (all epochs so far) is dumped into a per-epoch file.
                fn = f'training_history-{epoch}.json'
                with open(fn, 'w', encoding='UTF-8') as fp:
                    json.dump(training_history, fp, ensure_ascii=False, indent=4)

                if not IS_KAGGLE:
                    # Outside Kaggle the file is copied to the checkpoint path and removed locally.
                    !gsutil cp -r './{fn}' '{self.train_config.ckpt_path}'
                    !rm -rf './{fn}'

                # --------------------------------------------------
                # validation

                if valid_epochs is None:
                    valid_epochs = [self.train_config.n_epochs]
                
                if epoch in valid_epochs:
                    valid(epoch)

        def valid(epoch):

            if epoch not in training_history:

                training_history[epoch] = {
                    'loss': [],
                    'acc': [],
                    'auc': [],
                    'answer_loss': [],
                    'answer_acc': [],
                    'timing': []        
                }                

            valid_user_ids = []
            valid_row_ids = []
            valid_targets = []
            valid_preds = []
            valid_answer_targets = []
            valid_answer_preds = []

            n_steps = 0
            total_nb_pred_places = 0.0

            dist_valid_iter = iter(self.valid_ds)

            start_valid = datetime.datetime.now()

            n_valid_calls = self.train_config.n_valid_calls
            if self.train_config.n_steps_in_last_valid_call > 0:
                n_valid_calls -= 1

            for call_idx in range(n_valid_calls):

                start_t = datetime.datetime.now()

                targets, preds, row_ids, user_ids, pred_mask, nb_pred_places, answer_targets, answer_preds = dist_valid_multi_steps(dist_valid_iter)            
                n_steps += self.train_config.steps_per_call

                total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)

                pred_indices = tf.where(pred_mask)

                selected_targets = tf.gather_nd(targets, pred_indices)
                
                # shape = [n_selected_places, n_targets], `n_targets = 2` if `use_softmax`, otherwise `1`.
                selected_preds = tf.gather_nd(preds, pred_indices)

                selected_answer_targets = tf.gather_nd(answer_targets, pred_indices)
                selected_answer_preds = tf.gather_nd(answer_preds, pred_indices)

                selected_row_ids = tf.gather_nd(row_ids, pred_indices)
                selected_user_ids = tf.gather_nd(user_ids, pred_indices)

                avg_loss = (loss_metric.result() / total_nb_pred_places).numpy()
                acc = acc_metric.result().numpy()
                auc = auc_metric.result().numpy()

                avg_answer_loss = (loss_metric_answer.result() / total_nb_pred_places).numpy()
                answer_acc = acc_metric_answer.result().numpy()

                elapsed = (datetime.datetime.now() - start_t).total_seconds()

                if call_idx % 1 == 0:

                    print(f'epoch: {epoch} - valid step: {n_steps}')
                    print(f'valid timing per steps: {elapsed / self.train_config.steps_per_call}')
                    print(f'valid loss: {avg_loss}')
                    print(f'valid acc: {acc}')
                    print(f'valid auc: {auc}')
                    print(f'valid answer loss: {avg_answer_loss}') 
                    print(f'valid answer acc: {answer_acc}')        
                    print('-' * 80)

                valid_row_ids.extend(selected_row_ids.numpy().tolist())
                valid_user_ids.extend(selected_user_ids.numpy().tolist())
                valid_targets.extend(selected_targets.numpy().tolist())
                
                # Use `selected_preds[:, -1]` to get the probabilities for class `1`, without the 2nd dim
                valid_preds.extend(selected_preds[:, -1].numpy().tolist())
                
                valid_answer_targets.extend(selected_answer_targets.numpy().tolist())
                valid_answer_preds.extend(selected_answer_preds.numpy().tolist())

            # ----------------------------------------------------------------------------------------------------
            # Last TPU call

            if self.train_config.n_steps_in_last_valid_call > 0:

                start_t = datetime.datetime.now()

                targets, preds, row_ids, user_ids, pred_mask, nb_pred_places, answer_targets, answer_preds = dist_valid_multi_steps_last_call(dist_valid_iter)
                n_steps += self.train_config.n_steps_in_last_valid_call

                total_nb_pred_places += tf.cast(nb_pred_places, dtype=DTYPE)

                pred_indices = tf.where(pred_mask)

                selected_targets = tf.gather_nd(targets, pred_indices)
                selected_preds = tf.gather_nd(preds, pred_indices)

                selected_answer_targets = tf.gather_nd(answer_targets, pred_indices)
                selected_answer_preds = tf.gather_nd(answer_preds, pred_indices)

                selected_row_ids = tf.gather_nd(row_ids, pred_indices)
                selected_user_ids = tf.gather_nd(user_ids, pred_indices)

                avg_loss = (loss_metric.result() / total_nb_pred_places).numpy()
                acc = acc_metric.result().numpy()
                auc = auc_metric.result().numpy()

                avg_answer_loss = (loss_metric_answer.result() / total_nb_pred_places).numpy()
                answer_acc = acc_metric_answer.result().numpy()

                elapsed = (datetime.datetime.now() - start_t).total_seconds()

                valid_row_ids.extend(selected_row_ids.numpy())
                valid_user_ids.extend(selected_user_ids.numpy())
                
                # Use `selected_preds[:, -1]` to get the probabilities for class `1`, without the 2nd dim
                valid_preds.extend(selected_preds[:, -1].numpy())

                valid_answer_targets.extend(selected_answer_targets.numpy())
                valid_answer_preds.extend(selected_answer_preds.numpy())

            # ----------------------------------------------------------------------------------------------------

            end_valid = datetime.datetime.now()
            elapsed_valid = (end_valid - start_valid).total_seconds()
            training_history[epoch]['valid_timing'] = elapsed_valid

            training_history[epoch]['valid_loss'] = float(avg_loss)
            training_history[epoch]['valid_acc'] = float(acc)
            training_history[epoch]['valid_auc'] = float(auc)
            training_history[epoch]['valid_answer_loss'] = float(avg_answer_loss)
            training_history[epoch]['valid_answer_acc'] = float(answer_acc)

            print(f'epoch: {epoch} - valid')
            print(f'valid loss: {avg_loss}')
            print(f'valid acc: {acc}')
            print(f'valid auc: {auc}')
            print(f'valid answer loss: {avg_answer_loss}') 
            print(f'valid answer acc: {answer_acc}')                     
            print('=' * 80)

            loss_metric.reset_states()
            acc_metric.reset_states()
            auc_metric.reset_states()
            loss_metric_answer.reset_states()
            acc_metric_answer.reset_states()
            
            # --------------------------------------------------

            valid_submission_dict = {
                'row_ids': valid_row_ids,
                'user_ids': valid_user_ids,
                'targets': valid_targets,
                'preds': valid_preds
            }

            if self.config.use_user_answer_loss:

                valid_answer_preds = np.array(valid_answer_preds)

                valid_submission_dict['answer_targets'] = valid_answer_targets
                valid_submission_dict['answer_preds_0'] = valid_answer_preds[:, 0]
                valid_submission_dict['answer_preds_1'] = valid_answer_preds[:, 1]
                valid_submission_dict['answer_preds_2'] = valid_answer_preds[:, 2]
                valid_submission_dict['answer_preds_3'] = valid_answer_preds[:, 3]


            # valid_submission = pd.DataFrame.from_dict(valid_submission_dict)
            # valid_submission.to_csv(f'valid_submission_epoch_{epoch}.csv', index=False)

            # if not IS_KAGGLE:
            #     !gsutil cp -r './valid_submission_epoch_{epoch}.csv' '{self.train_config.ckpt_path}'                   
            #     !rm -rf './valid_submission_epoch_{epoch}.csv'

            # --------------------------------------------------

            fn = f'training_history-{epoch}.json'
            if only_valid:
                fn = f'training_history-{epoch}-only-valid.json'
            with open(fn, 'w', encoding='UTF-8') as fp:
                json.dump(training_history, fp, ensure_ascii=False, indent=4)

            if not IS_KAGGLE:
                !gsutil cp -r './{fn}' '{self.train_config.ckpt_path}'
                !rm -rf './{fn}'

            # --------------------------------------------------

        if only_valid:
            valid(last_epoch)
        else:
            train(last_epoch, from_valid, valid_epochs)

#### check

In [None]:
if not IS_KAGGLE:
    
    # Smoke check (skipped on Kaggle): build a `Train_Manager` from a dummy model config
    # and a default `TrainConfig`. The following check cells reuse `train_manager`.
    config = EdFormerConfig(model_type='ed', model_desc='dummy')
    train_config = TrainConfig()
    train_manager = Train_Manager(config, train_config)

In [None]:
if not IS_KAGGLE:
    # Only evaluates the attribute: with IPython's default settings an expression nested
    # inside an `if` block is not echoed as cell output.
    train_manager.train_ds

In [None]:
if not IS_KAGGLE:
    # Only evaluates the attribute: with IPython's default settings an expression nested
    # inside an `if` block is not echoed as cell output.
    train_manager.valid_ds

In [None]:
if not IS_KAGGLE:
    # Only evaluates the attribute: with IPython's default settings an expression nested
    # inside an `if` block is not echoed as cell output.
    train_manager.valid_ds_debug

# Iter Manager

## Utilities

In [None]:
def convert_dt(_dt, pred=False, unique_question_id_train=None, unique_lecture_id_train=None):
    """
    Change column types and replace missing values, modifying `_dt` in place.

    If `pred` is true, the frame is also put into the format used for prediction:
    `answered_correctly` and `user_answer` are overwritten with `MASK_TOKEN`, and the
    `prior_group_answers_correct` / `prior_group_responses` columns are removed when present.

    Args:
        _dt: `datatable.Frame`, representing a block of the training dataset or a test batch given by `env.iter_test`.
        pred: whether `_dt` must be converted to the prediction format.
        unique_question_id_train: optional collection of question ids seen in training; only used when `DEBUG` is set.
        unique_lecture_id_train: optional collection of lecture ids seen in training; only used when `DEBUG` is set.

    Returns:
        None. `_dt` is updated in place.
    """

    # Missing values are encoded as -1 after the casts.
    _dt[dt.f.prior_question_elapsed_time] = dt.float32
    _dt[dt.f.prior_question_elapsed_time == None, 'prior_question_elapsed_time'] = -1.0
    _dt[dt.f.prior_question_had_explanation] = dt.int8
    _dt[dt.f.prior_question_had_explanation == None, 'prior_question_had_explanation'] = -1
    _dt[dt.f.content_type_id] = dt.int8

    if pred:

        _dt['answered_correctly'] = MASK_TOKEN
        _dt['user_answer'] = MASK_TOKEN

        # These columns only exist in batches coming from the test environment. Each one is
        # dropped independently, so a missing column neither blocks the other deletion
        # nor hides unrelated errors.
        for column in ('prior_group_answers_correct', 'prior_group_responses'):
            if column in _dt.names:
                del _dt[:, column]

    if DEBUG:

        if unique_question_id_train is not None:
            # All test questions must have been seen in training time.
            assert set(_dt[dt.f.content_type_id == 0, 'content_id'].to_list()[0]).issubset(unique_question_id_train)
        if unique_lecture_id_train is not None:
            # All test lectures must have been seen in training time.
            assert set(_dt[dt.f.content_type_id == 1, 'content_id'].to_list()[0]).issubset(unique_lecture_id_train)
            
            
def convert_df_to_dt(df, test=False, unique_question_id_train=None, unique_lecture_id_train=None):
    """Convert a `pandas.DataFrame` to `datatable.Frame` with some extra processing.

    Args:
        df: `pandas.DataFrame`, representing a block of the training dataset or a test batch given by `env.iter_test()`.
        test: whether `df` is a test batch. If so, the prior-group fields are read from its
            first row and the frame is converted to the prediction format by `convert_dt`.
        unique_question_id_train: optional collection of question ids seen in training; only used when `DEBUG` is set.
        unique_lecture_id_train: optional collection of lecture ids seen in training; only used when `DEBUG` is set.

    Returns:
        `(_dt, prior_group_answers_correct, prior_group_responses)` if `test` is true,
        otherwise just `_dt`.
    """

    def _parse_prior_group(raw):
        """Turn the textual cell of a prior-group column into a list (`[]` if missing)."""

        # A missing cell must be handled before `eval`, which rejects non-string input.
        if raw is None:
            return []

        # NOTE(review): `eval` executes arbitrary code. It is kept to preserve the current
        # parsing, but consider `ast.literal_eval` since the text comes from outside.
        parsed = eval(raw)

        return [] if parsed is None else parsed

    if DEBUG:

        _question_ids = set(df[df['content_type_id'] == 0]['content_id'].values.tolist())
        _lecture_ids = set(df[df['content_type_id'] == 1]['content_id'].values.tolist())

        if unique_question_id_train is not None:
            # All test questions must have been seen in training time.
            assert _question_ids.issubset(unique_question_id_train)

        if unique_lecture_id_train is not None:
            # All test lectures must have been seen in training time.
            assert _lecture_ids.issubset(unique_lecture_id_train)

    _dt = dt.Frame(df.astype({"prior_question_had_explanation": float}))

    if test:

        # Read both fields from the first row, before `convert_dt` removes the columns.
        prior_group_answers_correct = _parse_prior_group(_dt[0, 'prior_group_answers_correct'])
        prior_group_responses = _parse_prior_group(_dt[0, 'prior_group_responses'])

        if DEBUG:

            assert type(prior_group_answers_correct) == list
            assert type(prior_group_responses) == list

    convert_dt(_dt, pred=test, unique_question_id_train=unique_question_id_train, unique_lecture_id_train=unique_lecture_id_train)

    if test:
        return _dt, prior_group_answers_correct, prior_group_responses
    else:
        return _dt

## Pred Iter Manager

In [None]:
class Pred_Iter_Manager:
    """Common base of the managers that iterate over batches to predict.

    It keeps the number of batches produced so far and the optional id collections
    given at construction. Subclasses must implement `_step` and `iter_dataset`.
    """

    def __init__(self, unique_question_id_train=None, unique_lecture_id_train=None):
        """Store the optional id collections and start the batch counter at zero."""

        self._unique_question_id_train = unique_question_id_train
        self._unique_lecture_id_train = unique_lecture_id_train

        self.current_batch_no = 0

    def get_split_index(self, user_id):
        """Return the split index of `user_id`; this base implementation always returns `-1`."""

        # `-1` means that `user_id` is not in `valid_info.json`.
        # `valid_info.json` is not used for the test dataset: any `user_id` that is in the
        # original training dataset is a user the model has been trained for.
        not_in_valid_info = -1

        return not_in_valid_info

    def _step(self, n_users=None):
        """Produce the next batch. Must be implemented by subclasses.

        Implementations are expected to return
        `pred_df, pred_dt, prior_group_answers_correct, prior_group_responses, n_blocks_in_batch`.
        """

        raise NotImplementedError

    def iter_dataset(self):
        """Iterate over the batches of the dataset. Must be implemented by subclasses."""

        raise NotImplementedError

## Valid Iter Manager

### Function to sample sequence lengths

In [None]:
def sample_sequence_lengths(upper_bound, sample_size, expected_mean, seed):
    """Draw `sample_size` integer lengths, mostly from a normal distribution around `expected_mean`.

    Values are first drawn from `Normal(expected_mean, expected_mean / 2.5)`. Draws outside
    `[1, 2 * expected_mean]` are discarded, the rest are rounded to integers. Each discarded
    draw is replaced by a value picked uniformly from `1..upper_bound`, so the result always
    has `sample_size` elements (kept draws first, replacements last).

    NOTE(review): the kept draws are bounded by `2 * expected_mean`, not by `upper_bound`,
    so they can exceed `upper_bound` when `2 * expected_mean > upper_bound`.

    Args:
        upper_bound: largest length the uniform replacements can take.
        sample_size: number of lengths to return.
        expected_mean: mean of the normal distribution.
        seed: seed passed to both sampling calls.

    Returns:
        A 1-D `tf.int32` tensor of `sample_size` lengths.
    """

    seq_lengths = tf.range(1, upper_bound+1)

    N = sample_size
    mu = expected_mean
    sigma = mu / 2.5

    normal_dist = tfp.distributions.Normal(
        loc=mu, scale=sigma
    )

    # The drawn values are not guaranteed to be in `seq_lengths`.
    seq_len_sample = normal_dist.sample(sample_shape=[N], seed=seed)

    # Keep only the draws within [1, 2 * mu].
    mask = tf.math.logical_and(2 * mu >= seq_len_sample, seq_len_sample >= 1)
    seq_len_sample = tf.gather_nd(seq_len_sample, indices=tf.where(mask))

    seq_len_sample = tf.cast(tf.math.round(seq_len_sample), dtype=tf.int32)

    uniform_dist = tfp.distributions.Uniform(
        low=0, high=len(seq_lengths)
    )

    # For the discarded draws, pick positions in `seq_lengths` uniformly instead.
    selected_indices = tf.cast(uniform_dist.sample(sample_shape=[N - seq_len_sample.shape[0]], seed=seed), dtype=tf.int32)

    seq_len_sample_2 = tf.gather(seq_lengths, indices=selected_indices)

    return tf.concat([seq_len_sample, seq_len_sample_2], axis=0)

### Valid Iter Manager

In [None]:
class User_Interaction_Manager(collections.abc.Iterator):
    """Iterator over the interactions of one user, returned block by block.

    `user_valid_info` describes where the user's rows are located in `train_dt`
    (`row_start`, `row_end`, `seq_len`) and how they are grouped
    (`bundle_info['n_blocks']`, `bundle_info['block_starting_index_dict']`).
    Each call to `next()` returns the rows of the next block as a slice of `train_dt`.
    """

    def __init__(self, user_id, user_valid_info, train_dt):
        """Start at the first block of the user, with all blocks remaining."""

        self.user_id = user_id
        self.user_valid_info = user_valid_info
        self._train_dt = train_dt

        self._current_block_starting_index = 0
        self._n_remaining_blocks = self.n_blocks

    @property
    def n_blocks(self):
        """Total number of blocks of this user."""

        return self.user_valid_info['bundle_info']['n_blocks']

    @property
    def n_remaining_blocks(self):
        """Number of blocks that have not been returned yet."""

        return self._n_remaining_blocks

    def _get_next_block_starting_index(self):
        """Return the index, relative to the user's first row, at which the next block starts."""

        current = self._current_block_starting_index
        starting_index_dict = self.user_valid_info['bundle_info']['block_starting_index_dict']

        # Indices without an entry start a block of a single row.
        next_index = starting_index_dict[current] if current in starting_index_dict else current + 1

        if DEBUG:
            assert 0 <= next_index <= self.user_valid_info['seq_len']

        return next_index

    def _get_user_dt(self, next_block_starting_index):
        """Return the rows of the current block, which ends right before `next_block_starting_index`."""

        offset = self.user_valid_info['row_start']

        first_row = offset + self._current_block_starting_index
        last_row = offset + next_block_starting_index - 1

        if DEBUG:
            assert last_row <= self.user_valid_info['row_end']

        return self._get_user_dt_from_row_indices(first_row, last_row)

    def _get_user_dt_from_row_indices(self, row_start, row_end):
        """Return the rows `row_start..row_end` (both included) of the training frame."""

        return self._train_dt[row_start:(row_end + 1), :]

    def has_next(self):
        """Tell whether a block is left, i.e. some remain and the sequence end is not reached."""

        return (
            self._n_remaining_blocks > 0
            and self._current_block_starting_index < self.user_valid_info['seq_len']
        )

    def __next__(self):
        """Return the next block and advance; raise `StopIteration` when nothing is left."""

        if not self.has_next():
            raise StopIteration

        following_start = self._get_next_block_starting_index()
        user_dt = self._get_user_dt(following_start)

        # Advance the state to the following block.
        self._current_block_starting_index = following_start
        self._n_remaining_blocks -= 1

        if DEBUG:

            # sanity check: the blocks and the sequence must be exhausted together.
            if self._n_remaining_blocks <= 0:
                assert self._current_block_starting_index == self.user_valid_info['seq_len']
            if self._current_block_starting_index >= self.user_valid_info['seq_len']:
                assert self._n_remaining_blocks == 0

            assert user_dt.nrows > 0

        return user_dt
        

class Valid_Iter_Manager(Pred_Iter_Manager):
    """Replay the validation split as a stream of test-like batches.

    `valid_info` maps each validation user to its description (split index, row range and
    block layout). Every batch is made of one block from each of a few randomly selected
    users, converted to the prediction format. The true answers of a batch are handed
    out together with the next batch as the "prior group" values.
    """

    def __init__(self, valid_info, train_dt, unique_question_id_train=None, unique_lecture_id_train=None):
        """Create one `User_Interaction_Manager` per user and initialise the counters.

        Args:
            valid_info: dict mapping `user_id` to that user's validation info.
            train_dt: frame holding the rows referenced by `valid_info`.
            unique_question_id_train: optional collection of question ids seen in training.
            unique_lecture_id_train: optional collection of lecture ids seen in training.
        """

        super(Valid_Iter_Manager, self).__init__(unique_question_id_train, unique_lecture_id_train)

        self._valid_info = valid_info

        self._n_users = len(self._valid_info)
        self._n_blocks = sum(
            user_valid_info['bundle_info']['n_blocks'] for user_valid_info in self._valid_info.values()
        )

        self._remaining_users = set(self._valid_info.keys())
        self._n_remaining_blocks = self._n_blocks

        # True answers of the last batch produced, returned with the next batch.
        self._current_answered_correctly = None
        self._current_user_answer = None

        self._user_manager_dict = {
            user_id: User_Interaction_Manager(user_id, user_valid_info, train_dt)
            for user_id, user_valid_info in self._valid_info.items()
        }

        self.train_dt = train_dt

        self._split_indices = {user_id: valid_info_for_user['split_index'] for user_id, valid_info_for_user in self._valid_info.items()}

    @classmethod
    def create(cls, valid_info_path, train_dt_path_or_obj, unique_question_id_train_path=None, unique_lecture_id_train_path=None):
        """Build a manager from file paths (or an already loaded training frame).

        Args:
            valid_info_path: path of the validation info, or `None` for an empty one.
            train_dt_path_or_obj: path of the training frame, or the frame itself.
            unique_question_id_train_path: optional path of the question ids seen in training.
            unique_lecture_id_train_path: optional path of the lecture ids seen in training.
        """

        valid_info = {} if valid_info_path is None else load_data(valid_info_path)

        unique_question_id_train = None
        if unique_question_id_train_path is not None:
            unique_question_id_train = load_data(unique_question_id_train_path)

        unique_lecture_id_train = None
        if unique_lecture_id_train_path is not None:
            unique_lecture_id_train = load_data(unique_lecture_id_train_path)

        train_dt = train_dt_path_or_obj
        if isinstance(train_dt, str):
            train_dt = load_data(train_dt)

        # Passed by keyword so that giving only the lecture path cannot end up in the
        # question-id slot.
        return cls(
            valid_info,
            train_dt,
            unique_question_id_train=unique_question_id_train,
            unique_lecture_id_train=unique_lecture_id_train,
        )

    def get_split_index(self, user_id):
        """Return the split index of `user_id`, or `-1` if the user is not in the validation dataset."""

        # This means that no interaction for `user_id` is included in the validation dataset.
        if user_id not in self._split_indices:
            return -1

        if DEBUG:
            # `self._split_indices[user_id]` could never be `seq_len`.
            assert 0 <= self._split_indices[user_id] < self._valid_info[user_id]['orig_seq_len']

        # If `self._split_indices[user_id]` is `0`, the user with `user_id` is a user unseen in the splitted training dataset.
        return self._split_indices[user_id]

    def iter_dataset(self, max_steps=None):
        """Yield `(valid_df, valid_dt, prior_group_answers_correct, prior_group_responses)` batches.

        Iteration stops when every user is exhausted, or as soon as `current_batch_no`
        reaches `max_steps` when it is given.

        Raises:
            ValueError: if more blocks are produced than `valid_info` declares.
        """

        print('entering iter_dataset()')

        expected_mean_blocks_per_batch = 30.0
        expected_n_batches = int(self._n_blocks / expected_mean_blocks_per_batch) + 1

        # A fixed size of 30 blocks per batch is used; sampling the sizes with
        # `sample_sequence_lengths` (and shuffling them) is currently disabled.
        sampled_block_size = [30] * 100000
        sample_mean_blocks_per_batch = 30.0

        if DEBUG:
            assert len(sampled_block_size) >= expected_n_batches

        print(f'expected_mean_blocks_per_batch_in_pred: {expected_mean_blocks_per_batch}')
        print(f'sample_mean_blocks_per_batch_in_pred: {sample_mean_blocks_per_batch}')
        print(f'expected_n_batches_in_pred: {expected_n_batches}')

        total_blocks_yield = 0

        # Set only when the loop ends because no user is left (not on a `max_steps` stop).
        exhausted = False

        print('start looping ...')
        while True:

            if max_steps is not None and self.current_batch_no >= max_steps:
                break

            try:

                # Integer fallback: the value ends up as the `k` of `random.sample`.
                n_blocks_to_select = 30
                if self.current_batch_no < len(sampled_block_size):
                    n_blocks_to_select = sampled_block_size[self.current_batch_no]

                valid_df, valid_dt, prior_group_answers_correct, prior_group_responses, n_blocks_in_batch = self._step(n_blocks_to_select)

                if DEBUG:
                    assert n_blocks_in_batch > 0

                total_blocks_yield += n_blocks_in_batch

                # sanity check
                if total_blocks_yield > self._n_blocks:
                    raise ValueError('total_blocks_yield > self._n_blocks! The code has some bugs to fix!')

                if DEBUG:
                    assert total_blocks_yield == self._n_blocks - self._n_remaining_blocks

                self.current_batch_no += 1

                if self.current_batch_no % PRINTING_STEPS == 0:

                    print(f'current_batch_no: {self.current_batch_no}')
                    print(f'n_blocks processed in current batch: {n_blocks_in_batch}')
                    print(f'n_blocks processed: {total_blocks_yield}')
                    print(f'averaged n_blocks processed per batch: {total_blocks_yield / self.current_batch_no}')
                    print(f'n_users_completed: {self._n_users - len(self._remaining_users)}')

                yield valid_df, valid_dt, prior_group_answers_correct, prior_group_responses

            except StopIteration:
                exhausted = True
                break

        # sanity check: only meaningful when the dataset was fully consumed.
        if DEBUG and exhausted:
            assert len(self._user_manager_dict) == 0
            assert len(self._remaining_users) == 0
            assert self._n_remaining_blocks == 0

        print('finished looping.')

    def _step(self, n_users):
        """Produce the next batch from (at most) `n_users` randomly selected users."""

        return self._blocks_from_random_users(n_users)

    def _blocks_from_random_users(self, n_users):
        """Take the next block of up to `n_users` random remaining users and build a batch.

        Returns:
            `(valid_df, valid_dt, prior_group_answers_correct, prior_group_responses, n_blocks_selected)`
            where the two prior-group lists hold the true answers of the previous batch
            (empty for the first batch).

        Raises:
            StopIteration: if no user remains.
        """

        if DEBUG:
            assert len(self._remaining_users) >= 0

        if len(self._remaining_users) == 0:
            raise StopIteration

        n_users_to_select = min(int(n_users), len(self._remaining_users))

        # randomly select `n_users`. `random.sample` needs a sequence: sets are rejected
        # by recent Python versions.
        selected_users = random.sample(tuple(self._remaining_users), k=n_users_to_select)

        if DEBUG:
            assert len(selected_users) == n_users_to_select

        # For each selected user, get its 1st remained block (in `dt` format) and reset its states.
        user_dts = []
        for user_id in selected_users:

            user_manager = self._user_manager_dict[user_id]

            if user_manager.has_next():

                user_dts.append(next(user_manager))

                # A user is retired as soon as its last block has been taken.
                if not user_manager.has_next():
                    del self._user_manager_dict[user_id]
                    self._remaining_users.remove(user_id)

        # Combine dt Frames
        valid_dt = dt.rbind(user_dts)

        # Count the blocks actually taken rather than the users selected.
        n_blocks_selected = len(user_dts)
        self._n_remaining_blocks -= n_blocks_selected

        # --------------------------------------------------------------------------------

        prior_group_answers_correct = self._current_answered_correctly
        prior_group_responses = self._current_user_answer

        if prior_group_answers_correct is None:
            prior_group_answers_correct = []

        if prior_group_responses is None:
            prior_group_responses = []

        if DEBUG:

            assert type(prior_group_answers_correct) == list
            assert type(prior_group_responses) == list

        # --------------------------------------------------------------------------------

        # Remember the true answers before `convert_dt` masks them.
        self._current_answered_correctly = valid_dt[:, 'answered_correctly'].to_list()[0]
        self._current_user_answer = valid_dt[:, 'user_answer'].to_list()[0]

        # --------------------------------------------------------------------------------

        convert_dt(valid_dt, pred=True, unique_question_id_train=self._unique_question_id_train, unique_lecture_id_train=self._unique_lecture_id_train)

        valid_df = valid_dt.to_pandas()

        return valid_df, valid_dt, prior_group_answers_correct, prior_group_responses, n_blocks_selected

## Test Iter Manager

In [None]:
class Test_Iter_Manager(Pred_Iter_Manager):
    """Prediction iterator that pulls its batches from the `riiideducation` environment."""

    def __init__(self, unique_question_id_train=None, unique_lecture_id_train=None):
        """Initialise the base manager, then open the environment's test iterator.

        Args:
            unique_question_id_train: Passed through to `Pred_Iter_Manager.__init__`.
            unique_lecture_id_train: Passed through to `Pred_Iter_Manager.__init__`.
        """
        super().__init__(unique_question_id_train, unique_lecture_id_train)

        self.env = riiideducation.make_env()
        self._iter_test = iter(self.env.iter_test())

    @classmethod
    def create(cls, unique_question_id_train_path=None, unique_lecture_id_train_path=None):
        """Alternative constructor that reads the constructor arguments from files.

        Every path that is not `None` is read with `load_data`; the loaded objects
        are passed positionally to the constructor in (question, lecture) order.
        A `None` path is dropped rather than forwarded, so passing only the lecture
        path would hand its data to the first constructor parameter.
        """
        candidate_paths = (unique_question_id_train_path, unique_lecture_id_train_path)
        loaded = [load_data(path) for path in candidate_paths if path is not None]

        return cls(*loaded)

    def _has_next(self):
        # NOTE(review): this evaluates to the bound method object itself (always
        # truthy), not to a boolean state. Whether that is intended depends on how
        # the base class uses `_has_next` -- confirm before changing it.
        return self._has_next

    def iter_dataset(self, max_steps=None):
        """Yield test batches until the underlying test iterator is exhausted.

        Args:
            max_steps: Accepted but not used by this implementation.

        Yields:
            Tuples `(test_df, test_dt, prior_group_answers_correct, prior_group_responses)`.
        """
        print('entering iter_dataset()')

        blocks_so_far = 0

        print('start looping ...')
        while True:

            # `_step` raises `StopIteration` once `self._iter_test` has no more batches;
            # that is the only exit from this loop.
            try:
                test_df, test_dt, prior_correct, prior_responses, blocks_in_batch = self._step(n_users=None)
                assert blocks_in_batch > 0

                blocks_so_far += blocks_in_batch
                self.current_batch_no += 1

                # Periodic progress report.
                if self.current_batch_no % PRINTING_STEPS == 0:
                    print(f'current_batch_no: {self.current_batch_no}')
                    print(f'n_blocks processed in current batch: {blocks_in_batch}')
                    print(f'n_blocks processed: {blocks_so_far}')
                    print(f'averaged n_blocks processed per batch: {blocks_so_far / self.current_batch_no}')

                yield test_df, test_dt, prior_correct, prior_responses

            except StopIteration:
                break

        print('finished looping.')

    def _step(self, n_users=None):
        """Fetch the next test batch and convert it.

        Args:
            n_users: Accepted but not used by this implementation.

        Returns:
            `(test_df, test_dt, prior_group_answers_correct, prior_group_responses, n_blocks)`,
            where `n_blocks` is the number of distinct `user_id` values in `test_dt`.

        Raises:
            StopIteration: when `self._iter_test` is exhausted.
        """
        # The second element produced by the test iterator is not needed here.
        test_df, _ = next(self._iter_test)

        converted = convert_df_to_dt(
            test_df,
            test=True,
            unique_question_id_train=self._unique_question_id_train,
            unique_lecture_id_train=self._unique_lecture_id_train,
        )
        test_dt, prior_group_answers_correct, prior_group_responses = converted

        batch_user_ids = test_dt[:, 'user_id'].to_list()[0]
        n_blocks = len(set(batch_user_ids))

        return test_df, test_dt, prior_group_answers_correct, prior_group_responses, n_blocks

# Prediction Manager

In [None]:
def get_user_dt(user_id, _dt, user_id_to_row_id, split_index=-1):
    """Slice out of `_dt` the rows that belong to `user_id`.

    Args:
        user_id: Key into `user_id_to_row_id` (a missing key raises `KeyError`).
        _dt: `datatable.Frame`, indexed here as `_dt[start:stop, :]`.
        user_id_to_row_id: `dict` mapping a user id to `(first_row, last_row)`,
            both inclusive.
        split_index: `-1` (default) returns all of the user's rows; a value
            greater than `-1` returns only the first `split_index` rows.

    Returns:
        The row slice of `_dt` for the user.
    """
    first_row, last_row = user_id_to_row_id[user_id]

    # Exclusive upper bound of the slice.
    if split_index > -1:
        stop_row = first_row + split_index
    else:
        stop_row = last_row + 1

    return _dt[first_row:stop_row, :]


# Raw per-interaction columns, in the order they are stored on a record.
attrs = [
    'user_id', 'row_id', 'timestamp', 'content_id', 'content_type_id',
    'task_container_id', 'user_answer', 'answered_correctly',
    'prior_question_elapsed_time', 'prior_question_had_explanation',
]

# `user_id` first, then the sequence bookkeeping fields, the remaining raw
# columns, the prediction-time/shifted fields and finally `lag_time`.
extended_attrs = (
    ['user_id', 'seq_len', 'prev_seq_len', 'start', 'end']
    + [name for name in attrs if name != 'user_id']
    + ['pred_time_mask', 'abs_pos', 'shifted_abs_pos', 'shifted_answered_correctly', 'shifted_user_answer']
    + ['lag_time']
)

# Running counters, excluding 'n_prev_seen' and 'n_prev_correctness'.
aggregated_attrs_less = ['n_questions_answered', 'n_questions_answered_correctly', 'n_lectures_watched']

# One (count, correctness_count) pair for each part index 2..8.
for part_idx in range(2, 9):
    aggregated_attrs_less += [
        f'part_{part_idx}_count',
        f'part_{part_idx}_correctness_count',
    ]

# One (count, answered_correctly_count) pair per correct-answer id, re-based to start at 0.
for correct_answer_idx in range(ANSWER_0_ID, ANSWER_3_ID + 1):
    aggregated_attrs_less += [
        f'correct_answer_{correct_answer_idx - ANSWER_0_ID}_count',
        f'correct_answer_{correct_answer_idx - ANSWER_0_ID}_answered_correctly_count',
    ]

# Same counters plus 'n_prev_seen' and 'n_prev_correctness'.
aggregated_attrs = aggregated_attrs_less + ['n_prev_seen', 'n_prev_correctness']


# TensorFlow dtype per attribute: int64 for ids/timestamps, `DTYPE` for the
# elapsed time, int32 for everything else (including every aggregated counter).
attr_dtypes = {}
for k in extended_attrs:
    if k == 'prior_question_elapsed_time':
        attr_dtypes[k] = DTYPE
    elif k in ('row_id', 'user_id', 'timestamp', 'lag_time'):
        attr_dtypes[k] = tf.int64
    else:
        attr_dtypes[k] = tf.int32

for k in aggregated_attrs:
    attr_dtypes[k] = tf.int32

        
class User_Record:
    """
    Sequences (per attribute) of records for a single user.

    Every raw column listed in the module-level `attrs` (except `user_id`, which is
    a scalar) and every name in `aggregated_attrs` is stored as a Python list on
    the instance.
    """

    def __init__(self, user_dt=None, user_id=None):
        """Create a record from a frame, or an empty record for a known user id.

        The intended usage is that exactly one argument is `None`; this is not
        enforced. When `user_dt` is given it takes precedence and `user_id` is
        ignored.

        Args:
            user_dt: A `datatable.Frame` object for a single user. It is expected
                to be non-empty and to hold a single distinct `user_id` (not
                checked here; an empty frame raises `IndexError`).
            user_id: int. Used only when `user_dt` is `None`.
        """

        def column(name):
            # Without a frame every attribute starts as its own fresh empty list.
            return user_dt[:, name].to_list()[0] if user_dt is not None else []

        user_ids = column('user_id')
        self.user_id = user_ids[0] if user_dt is not None else user_id

        # Assigned in this fixed order so that `__dict__` (and therefore `toJSON`)
        # keeps a stable key order.
        for name in (
            'row_id', 'timestamp', 'content_id', 'content_type_id', 'task_container_id',
            'user_answer', 'answered_correctly',
            'prior_question_elapsed_time', 'prior_question_had_explanation',
        ):
            setattr(self, name, column(name))

        # The aggregated features are not derived here; they are expected to be
        # filled in externally.
        for aggregated_feature_name in aggregated_attrs:
            setattr(self, aggregated_feature_name, [])

    @staticmethod
    def _without_tail(values, n_tail):
        """Return a copy of `values` with its last `n_tail` items removed.

        Unlike `values[:-n_tail]`, this keeps the whole list when `n_tail` is 0
        (`values[:-0]` is `values[:0]`, i.e. an empty list).
        """
        return values[:max(len(values) - n_tail, 0)]

    def extend(self, other):
        """
        Append the content of another record `other` to the record `self`, in place.

        Both records are expected to belong to the same user and `other` is expected
        to continue `self` chronologically; neither is checked here.
        """
        for k in attrs + aggregated_attrs:
            if k != 'user_id':
                getattr(self, k).extend(getattr(other, k))

    def update_answer_results(self, prior_correctnesses, prior_answers, prior_user_info_for_post_update, question_history_at_training_end):
        """
        Fill in the answers and their correctnesses that were still unknown when the
        last test batch of this user was appended to the record.

        Args:
            prior_correctnesses: `list` with the correctness of each step of the
                user's previous batch, oldest first. It replaces the tail of
                `self.answered_correctly`.
            prior_answers: `list` with the user's answers for the same steps. It
                replaces the tail of `self.user_answer`.
            prior_user_info_for_post_update: `dict` prepared at prediction time.
                The keys listed by the original author are 'user_id', 'pred_record',
                'n_pred_time_steps' and 'aggregated_info_to_update'; only the last
                two are read here. 'aggregated_info_to_update' maps every name in
                `aggregated_attrs_less`, plus 'prev_part_key' and
                'prev_correct_answer_key', to per-step lists. It is modified in place.
            question_history_at_training_end: nested `dict`
                `{str(user_id): {str(question_id): [n_prev_seen, n_prev_correctness]}}`.
                The correctness counter of every correctly answered step is
                incremented in place.

        The lengths of `prior_correctnesses` and `prior_answers` are expected to
        equal 'n_pred_time_steps' (not checked). A zero-length update leaves the
        record unchanged.
        """

        n_pred_time_steps = prior_user_info_for_post_update['n_pred_time_steps']
        aggregated_info_to_update = prior_user_info_for_post_update['aggregated_info_to_update']

        # Replace the placeholder tail with the now-known results.
        self.answered_correctly = self._without_tail(self.answered_correctly, len(prior_correctnesses)) + prior_correctnesses
        self.user_answer = self._without_tail(self.user_answer, len(prior_answers)) + prior_answers

        # Update aggregated correctness information
        for idx in range(n_pred_time_steps):

            # Negative index of this step inside the record (the batch is the record's tail).
            index_in_pred_record = - (n_pred_time_steps - idx)

            # ------------------------------------------------------------------------
            # Update `n_prev_correctness` in `question_history_at_training_end`.
            # Only the correctness counter (slot 1) is touched here.

            if prior_correctnesses[idx] == 1:

                current_question_id_str = str(self.content_id[index_in_pred_record])
                question_history_at_training_end[str(self.user_id)][current_question_id_str][1] += 1

            # ------------------------------------------------------------------------
            # Update `aggregated_info_to_update`, which is used below to refresh the
            # correctness counters stored on `self`.

            # Nothing to propagate for the 1st step of the batch.
            if idx > 0:

                # Start from the (already corrected) values of the previous step ...
                for k in aggregated_attrs_less:
                    aggregated_info_to_update[k][idx] = aggregated_info_to_update[k][idx - 1]

                # ... and add the previous step's contribution if it was answered correctly.
                if prior_correctnesses[idx - 1] == 1:

                    aggregated_info_to_update['n_questions_answered_correctly'][idx] = aggregated_info_to_update['n_questions_answered_correctly'][idx - 1] + 1

                    # Names of the per-part / per-correct-answer counters of the previous
                    # step. They are expected to be set whenever that step was answered
                    # correctly (not checked).
                    prev_part_key = aggregated_info_to_update['prev_part_key'][idx]
                    prev_correct_answer_key = aggregated_info_to_update['prev_correct_answer_key'][idx]

                    aggregated_info_to_update[prev_part_key][idx] = aggregated_info_to_update[prev_part_key][idx - 1] + 1
                    aggregated_info_to_update[prev_correct_answer_key][idx] = aggregated_info_to_update[prev_correct_answer_key][idx - 1] + 1

        # ------------------------------------------------------------------------
        # Only the correctness-related counters are written back to the record.

        for k in aggregated_attrs_less:
            if k == 'n_questions_answered_correctly' or 'correctness_count' in k or 'answered_correctly_count' in k:
                setattr(self, k, self._without_tail(getattr(self, k), n_pred_time_steps) + aggregated_info_to_update[k])

    def toJSON(self):
        """Serialise the record (its instance `__dict__`) as an indented JSON string."""
        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=False, indent=4)

    def __str__(self):

        return self.toJSON()

    def __len__(self):
        """Number of interactions, measured by the length of `row_id`."""
        return len(self.row_id)


class Record_Buffer:
    """
    Dictionary-like container that stores one record per user and keeps it up to date.
    """

    def __init__(self, record_dict=None):
        """
        `record_dict`: optional `dict` mapping user ids to their records (`User_Record`).
        It is used as the backing store directly (not copied). A new empty `dict`
        is created when it is omitted.
        """
        self.buffer = {} if record_dict is None else record_dict

    def __contains__(self, x):
        """`x in buffer` tests the stored user ids."""
        return x in self.buffer

    def __getitem__(self, x):
        """Return the record stored under `x`; raise `KeyError(str(x))` if there is none."""
        if x in self.buffer:
            return self.buffer[x]

        raise KeyError(str(x))

    def __len__(self):
        """Number of users currently held."""
        return len(self.buffer)

    def __del__(self):
        # Drop the reference to the backing dict when the buffer object is finalised.
        del self.buffer

    def update(self, record):
        """
        Merge `record` into the buffer, keyed by `record.user_id`.

        An empty `User_Record` is created first when the user is not stored yet;
        the stored record is then extended in place with `record`.
        """
        user_key = record.user_id

        if user_key not in self.buffer:
            self.buffer[user_key] = User_Record(user_id=user_key)

        self.buffer[user_key].extend(record)

    def update_answer_results(self, user_id, prior_correctnesses, prior_answers, prior_user_info_for_post_update, question_history_at_training_end):
        """
        Forward the now-known answers/correctnesses of the previous test batch to the
        record of `user_id`. The user must already be stored (otherwise `KeyError`).
        """
        stored_record = self.buffer[user_id]
        stored_record.update_answer_results(
            prior_correctnesses,
            prior_answers,
            prior_user_info_for_post_update,
            question_history_at_training_end,
        )

    def toJSON(self):
        """Serialise the buffer as indented JSON; nested objects are dumped via their `__dict__`."""

        def as_plain_dict(obj):
            return obj.__dict__

        return json.dumps(self, default=as_plain_dict, sort_keys=False, indent=4)

    def __str__(self):
        return self.toJSON()
        
        
def combine_user_record(record_1, record_2, check=False):
    """
    Concatenate two `User_Record` objects into a new one.

    Args:
        record_1: The earlier record; its `user_id` is used for the result.
        record_2: The record appended after `record_1`. It is expected to belong to
            the same user (not checked).
        check: Currently has no effect (the row-id continuity check it used to
            enable is disabled).

    Returns:
        A new `User_Record` holding the content of `record_1` followed by the
        content of `record_2`. Neither argument is modified.
    """
    combined = User_Record(user_id=record_1.user_id)

    for part in (record_1, record_2):
        combined.extend(part)

    return combined
    

class Pred_Manager:
    """
    See `update()` for the description.
    """
    
    def __init__(
        self, config, train_config, pred_iter_manager, train_dt, user_id_to_row_id_train,
        question_history_at_training_end, single_question_history_at_training_end_optimized,
        max_train_buffer_size=30000, probe=False
    ):
        
        self.config = config
        self.train_config = train_config
        
        self.pred_iter_manager = pred_iter_manager
            
        self.train_dt = train_dt
        self.user_id_to_row_id_train = user_id_to_row_id_train

        self.question_history_at_training_end = question_history_at_training_end
        self.single_question_history_at_training_end_optimized = single_question_history_at_training_end_optimized

        self.train_record_buffer = Record_Buffer()
        self.pred_record_buffer = Record_Buffer()
        
        self.current_batch_row_ids = None        
        self.current_batch_users = None
        
        # Used to avoid memory error - not sure if it is necessary.
        self.max_train_buffer_size = max_train_buffer_size

        self.users_in_pred = set()
        self.common_users_in_train_and_pred = set()
        self.new_users_in_pred = set()
        
        self.n_blocks_in_pred = 0
        self.n_blocks_from_common_users_in_train_and_pred = 0
        self.n_blocks_from_new_users_in_pred = 0
        
        self.avg_n_blocks_in_pred = 0
        self.avg_n_blocks_from_common_users_in_train_and_pred = 0
        self.avg_n_blocks_from_new_users_in_pred = 0
        
        # The user ids in the original (before splitted into training and validation datasets)
        self.users_in_train = set(user_id_to_row_id_train.keys())
        
        # The user ids in the actual training dataset (i.e. excluded the validation dataset)
        self.users_trained = set()
        
        for user_id in self.users_in_train:
            
            split_index = self.pred_iter_manager.get_split_index(user_id)
            
            if split_index == -1 or split_index > 0:
                self.users_trained.add(user_id)
        
        self.probe = probe
        
        self.user_performance_hdf5_fp = h5py.File(user_performance_hdf5_path, "r")
        self.user_performance_hdf5 = self.user_performance_hdf5_fp['user_performance']
        
    @classmethod
    def create(cls, config, train_config, pred_iter_manager_class,
               train_dt_path_or_obj,               
               user_id_to_row_id_train_path,
               unique_question_id_train_path,
               unique_lecture_id_train_path,
               question_history_at_training_end_path,
               single_question_history_at_training_end_optimized_path,
               valid_info_path=None,
               max_train_buffer_size=30000,
               probe=False,
               debug=False
        ):
        
        if pred_iter_manager_class == Valid_Iter_Manager:
            
            data_paths = [valid_info_path, train_dt_path_or_obj]
            
        elif pred_iter_manager_class == Test_Iter_Manager:
            data_paths = []
            
        assert type(unique_question_id_train_path) == str and type(unique_lecture_id_train_path) == str
        data_paths.extend([unique_question_id_train_path, unique_lecture_id_train_path])      
        pred_iter_manager = pred_iter_manager_class.create(*data_paths)
        
        user_id_to_row_id_train = load_data(user_id_to_row_id_train_path)

        
        start_memeory = psutil.virtual_memory().available / 1024.0 / 1024.0
        print(f'start_memeory before loading question_history_at_training_end.json: {start_memeory}')

        question_history_at_training_end = load_data(question_history_at_training_end_path)

        end_memeory = psutil.virtual_memory().available / 1024.0 / 1024.0
        print(f'end_memeory after loading question_history_at_training_end.json: {end_memeory}')

        used_memory = start_memeory - end_memeory
        print(f'used_memory for loading question_history_at_training_end.json: {used_memory}')
        print('----------------------------------------')
        
        start_memeory = psutil.virtual_memory().available / 1024.0 / 1024.0
        print(f'start_memeory before loading single_question_history_at_training_end_optimized.json: {start_memeory}')
        
        single_question_history_at_training_end_optimized = load_data(single_question_history_at_training_end_optimized_path)
        
        end_memeory = psutil.virtual_memory().available / 1024.0 / 1024.0
        print(f'end_memeory after loading single_question_history_at_training_end_optimized.json: {end_memeory}')

        used_memory = start_memeory - end_memeory
        print(f'used_memory for loading single_question_history_at_training_end_optimized.json: {used_memory}')
        print('----------------------------------------')
        
        if pred_iter_manager_class == Valid_Iter_Manager:
            train_dt = pred_iter_manager.train_dt
        elif type(train_dt_path_or_obj) != str:
            train_dt = train_dt_path_or_obj
        else:
            train_dt = load_data(train_dt_path_or_obj)

        args = [
                config, train_config, pred_iter_manager, train_dt, user_id_to_row_id_train,
                question_history_at_training_end, single_question_history_at_training_end_optimized,
                max_train_buffer_size, probe
        ]
        
        return cls(*args)
        
#     def reset_train_record_buffer(self):
#         """
#         """
        
#         del self.train_record_buffer
#         self.train_record_buffer = Record_Buffer()
        
    
    def reset_train_record_buffer(self):
        """
        """
        
        user_ids = list(self.train_record_buffer.buffer.keys())
        for user_id in user_ids:
            del self.train_record_buffer.buffer[user_id]
        
        del self.train_record_buffer

        self.train_record_buffer = Record_Buffer()    
    
    def update_batch_users(self, user_ids):
        """
        Store the user ids in the current test batch.
        """
        
        self.current_batch_users = user_ids

    def update_answer_results(self, prev_batch_users, prior_group_answers_correct, prior_group_responses, prior_info_for_post_update):
        """
        When we get a new test batch, we also get the `prior_group_answers_correct` and `prior_group_responses`.
        We use these information to update the `answered_correctly` and `user_anser` fields of the records of
        the users in the previous test batch.
        """
        
#         if DEBUG:
        
#             # sanity check: try to make sure the answer results are for `prev_batch_users` by verifying their lengths.
#             assert len(prev_batch_users) == len(prior_group_answers_correct)
#             assert len(prior_group_answers_correct) == len(prior_group_responses)
        
        d1 = defaultdict(list)
        d2 = defaultdict(list)
        
        for user_id, prior_ans_correct, prior_ans in zip(prev_batch_users, prior_group_answers_correct, prior_group_responses):
            d1[user_id].append(prior_ans_correct)
            d2[user_id].append(prior_ans)
            
        for user_id in d1:
                        
            self.pred_record_buffer.update_answer_results(
                user_id,
                d1[user_id], d2[user_id],
                prior_info_for_post_update[user_id],
                self.question_history_at_training_end
            )
                    
    def update_aggregated_info_buffer(self, user_id, pred_record, record_to_watch, index_to_watch, n_pred_time_steps, pred_time_step_index, buffer):
                
        user_id_str = str(user_id)
        
        # ====================================================================================================
        # This should be done for all `current` positions - dealing with `n_prev_seen` and `n_prev_correctness`
                
        current_index_in_pred_record = - (n_pred_time_steps - pred_time_step_index)
        
        n_prev_seen, n_prev_correctness = None, None
        
        # OP need to be done only if the current position is a question
        if pred_record.content_type_id[current_index_in_pred_record] == 0:
            
            # current is a question
            current_question_id = pred_record.content_id[current_index_in_pred_record]
            current_question_id_str = str(current_question_id)
            
            # -----------------------------------------------------------------------------------------------
            # During the prediction time, we make sure all the questions' info about `n_prev_seen` and `n_prev_correctness` are
            # added to `question_history_at_training_end`.
            if user_id_str not in self.question_history_at_training_end:
                self.question_history_at_training_end[user_id_str] = dict()
            if current_question_id_str not in self.question_history_at_training_end[user_id_str]:
                # Here we use `[None, None]` as a placeholder. The `None` values should be replaced by some integers
                # after the computation is done below.
                # The idea is to make sure we have the necessary keys - so we don't have `KeyError` error.
                self.question_history_at_training_end[user_id_str][current_question_id_str] = [None, None]
            else:
                # We have the values for the current question id, that we can put into the buffer directly.
                n_prev_seen, n_prev_correctness = self.question_history_at_training_end[user_id_str][current_question_id_str]
                    
            # We can't find information in `question_history_at_training_end`,
            # so check with `single_question_history_at_training_end_optimized`.
            if n_prev_seen is None:
                
                ### assert n_prev_correctness is None
                
                if user_id_str in self.single_question_history_at_training_end_optimized:
                    
                    # Change the lists to sets on the fly.
                    # We don't do this op during the loading, because it takes too many memory.
                    # We only do this op for the users with question seen during the prediction.
                    if type(self.single_question_history_at_training_end_optimized[user_id_str]['correct']) == list:
                        self.single_question_history_at_training_end_optimized[user_id_str]['correct'] = set(self.single_question_history_at_training_end_optimized[user_id_str]['correct'])
                    if type(self.single_question_history_at_training_end_optimized[user_id_str]['incorrect']) == list:
                        self.single_question_history_at_training_end_optimized[user_id_str]['incorrect'] = set(self.single_question_history_at_training_end_optimized[user_id_str]['incorrect'])
                    
                    # We remove the information from `single_question_history_at_training_end_optimized`,
                    # because we will put the information to `question_history_at_training_end`.
                    if current_question_id_str in self.single_question_history_at_training_end_optimized[user_id_str]['correct']:
                        n_prev_seen, n_prev_correctness = [1, 1]
                        self.single_question_history_at_training_end_optimized[user_id_str]['correct'].remove(current_question_id_str)
                    elif current_question_id_str in self.single_question_history_at_training_end_optimized[user_id_str]['incorrect']:
                        n_prev_seen, n_prev_correctness = [1, 0]
                        self.single_question_history_at_training_end_optimized[user_id_str]['incorrect'].remove(current_question_id_str)
                        
            if n_prev_seen is None:
                
                ### assert n_prev_correctness is None
                n_prev_seen, n_prev_correctness = [0, 0]
                
            # Unlike `buffer`, `question_history_at_training_end` should be updated only if the current position is a question.
            # Update `question_history_at_training_end` by combining `n_prev_seen` and the current question information.
            self.question_history_at_training_end[user_id_str][current_question_id_str][0] = n_prev_seen + 1
            # We don't know yet the correctness. This should be update after.
            # The current value is not the correct value (it is the value in the state one time step before) 
            self.question_history_at_training_end[user_id_str][current_question_id_str][1] =  n_prev_correctness
            
            # -----------------------------------------------------------------------------------------------
            
        # The current position is a lecture. The vaules should be all `0`.
        if n_prev_seen is None:

            ### assert n_prev_correctness is None
            n_prev_seen, n_prev_correctness = [0, 0]            
            
        # This is the value to be used.
        # We need to use `append` because the current position doesn't exist in these two lists yet.            
        # Be careful, unlike the update for `question_history_at_training_end` above with `n_prev_seen + 1`,
        # we need to use `n_prev_seen` here!
        buffer['n_prev_seen'].append(n_prev_seen)
        buffer['n_prev_correctness'].append(n_prev_correctness)            
            
        # ====================================================================================================
        # For attributes in `aggregated_attrs` other than `n_prev_seen` and `n_prev_correctness`.
        # We need to look the values in the previous step in the history.
            
        # ------------------------------------------------------------
        # No previous history, we still need to assign value of `0` and put them to the lists.
        
        if record_to_watch is None:
            # `n_prev_seen` and `n_prev_correctness` is done already.
            for k in aggregated_attrs_less:             
                buffer[k].append(0)
            for k in ['prev_part_key', 'prev_correct_answer_key']:
                # We assign `None` value.
                # In post-update, these values shoudn't be used, otherwise some logic is wrong, and we need to debug.
                buffer[k].append(None)
                            
            return
        # ------------------------------------------------------------
        # There is previous hisotry, either in training, or in previous predictions.
        
        # Here we just copy the values (of aggregated-attrs, other than `n_prev_seen` and `n_prev_correctness`)
        # from the previous step in the history. We need to update these values later by watching the 
        # values of non-aggregated-attrs from the previous step in the history.
        
        # i.e. we are at the first interaction (of a particular user) in the current pred batch.
        if pred_time_step_index == 0:
            # `n_prev_seen` and `n_prev_correctness` is done already.
            for k in aggregated_attrs_less:
                # Here, we can use `-1`
                # If `record_to_watch` is `train_record`, `-1` is of course fine.
                # However, if `record_to_watch` is `pred_record` - for attributes in `aggregated_attrs`, the lists haven't been extended and updated,
                # therefore, `-1` allows us to get the previous value.
                buffer[k].append(getattr(record_to_watch, k)[-1])
        # i.e. we are not at the first interaction (of a particular user) in the current pred batch.
        # `buffer[k]` should already have values, and we are safe to use `-1` as index.
        # (for `n_prev_seen` and `n_prev_correctness`, we sholud no longer operate on them), because they must have been done above.
        else:
            for k in aggregated_attrs_less:
                # Use the previous values in `buffer[k]` - and to update later.
                buffer[k].append(buffer[k][-1])
                
        # Just placeholders - if the previous step is a question
        for k in ['prev_part_key', 'prev_correct_answer_key']:
            buffer[k].append(None)
            
        # ------------------------------------------------------------
        # We use `index_to_watch` to watch the previous values (of non aggregated-attributes) in `record_to_watch`.
        # We shouldn't watch the `aggregated-attributes` by using `index_to_watch`, because they don't exist yet!
        
        content_id = record_to_watch.content_id[index_to_watch]
        content_type_id = record_to_watch.content_type_id[index_to_watch]
    
        # calculated only for question
        if content_type_id == 0:

            content_input_id = question_id_to_input_id_dict[content_id]
            # This is in [2, 8]
            part_id = c_inputs_ids_to_part_dict[content_input_id] + 1
            # This is in [0, 3]
            correct_answer = c_inputs_ids_to_correct_answer_id_dict[content_input_id] - ANSWER_0_ID

            # need to be a string
            question_id = str(content_id)
            
            answered_correctly = (record_to_watch.answered_correctly[index_to_watch] == 1)

            # We can use `-1` as index here, because we already extend the list with the value from the previous step.
            # If the previous step is not a question, the values remain `None`.
            buffer['prev_part_key'][-1] = f'part_{part_id}_correctness_count'
            buffer['prev_correct_answer_key'][-1] = f'correct_answer_{correct_answer}_answered_correctly_count'

        is_lecture = int(content_type_id == 1)

        # ------------------------------------------------------------
        # We only update the involved aggregated-attributes

        buffer['n_lectures_watched'][-1] += is_lecture        

        # i.e. the prev step is a question
        if is_lecture == 0:
            
            buffer['n_questions_answered'][-1] += 1

            k = f'part_{part_id}_count'
            buffer[k][-1] += 1
            k = f'correct_answer_{correct_answer}_count'
            buffer[k][-1] += 1
            
            # Update the correctness if `pred_time_step_index = 0` and the previous correctness is `True`.
            # This should only occur for `pred_time_step_index = 0`.
            # Because the correctness information in the current pred batch is not available, and they have `MASK_TOKEN` as values.
            # `pred_record` will have the values at the first (current) prediction time question's place.
            if answered_correctly is True:
                
                ### assert pred_time_step_index == 0
                
                for k in [
                    f'n_questions_answered_correctly',
                    f'part_{part_id}_correctness_count',
                    f'correct_answer_{correct_answer}_answered_correctly_count'
                ]:
                
                    buffer[k][-1] += 1
                    
    def get_aggregated_info_to_update(self, user_id, train_record, pred_record, n_pred_time_steps):
        """Build the aggregated-attribute buffer for one user's interactions in the current pred batch.

        `update_aggregated_info_buffer` is called once per prediction time step, each time with
        the record / index of the interaction that immediately precedes that step.

        Args:
            user_id: The user the buffer is built for.
            train_record: The user's training-time record (may be empty).
            pred_record: The user's prediction-time record, already containing the
                `n_pred_time_steps` interactions of the current batch.
            n_pred_time_steps: Number of interactions of this user in the current batch.

        Returns:
            A `dict` mapping every name in `aggregated_attrs`, plus `prev_part_key` and
            `prev_correct_answer_key`, to a list filled by `update_aggregated_info_buffer`.
        """

        # Where to look for the interaction preceding the FIRST step of the current batch.
        if len(pred_record) > n_pred_time_steps:
            # Earlier prediction batches exist: the entry just before the current batch.
            initial_watch = (pred_record, -(n_pred_time_steps + 1))
        elif len(train_record) > 0:
            # First pred batch for this user: the last training interaction.
            initial_watch = (train_record, -1)
        else:
            # First pred batch and no training history at all.
            initial_watch = (None, None)

        # `aggregated_attrs` contains `n_prev_seen` and `n_prev_correctness`; the two extra keys
        # are carried along so that post-updating does not need to recompute part / correct answer.
        buffer_keys = aggregated_attrs + ['prev_part_key', 'prev_correct_answer_key']
        aggregated_info_buffer = dict((key, []) for key in buffer_keys)

        for step in range(n_pred_time_steps):
            if step == 0:
                watched_record, watched_index = initial_watch
            else:
                # Later steps always look at the previous interaction inside the current batch.
                watched_record = pred_record
                watched_index = -(n_pred_time_steps - step) - 1
            self.update_aggregated_info_buffer(user_id, pred_record, watched_record, watched_index, n_pred_time_steps, step, aggregated_info_buffer)

        return aggregated_info_buffer
        
    def update(self, pred_dt, prior_group_answers_correct, prior_group_responses, prior_info_for_post_update):
        """
        For a prediction batch `pred_df` (`pandas.DataFrame`) given by `env.iter_test`, this method performs:
            1. update the `answered_correctly` and `user_anser` information in the previous test batch
            2. update the user records (only in the test time) by appending the information in the current batch
            3. get the user records in the training time
            4. combine the user records in the training time and test time - so we have a full history and the current batch to predict
        """
            
        user_ids = pred_dt['user_id'].to_list()[0]
        prev_batch_users = self.current_batch_users
        
        if prev_batch_users is not None:
            ### assert prior_info_for_post_update is not None
            self.update_answer_results(prev_batch_users, prior_group_answers_correct, prior_group_responses, prior_info_for_post_update)
        self.update_batch_users(user_ids)
        
        if DEBUG:
            
            row_ids = pred_dt['row_id'].to_list()[0]
            prev_batch_row_ids = self.current_batch_row_ids
            
            if isinstance(self.pred_iter_manager, Test_Iter_Manager):
                debug_pred_batch_row_ids(row_ids, prev_batch_row_ids)
        
            self.current_batch_row_ids = row_ids
        
        # To `User_Record`. Here, each record contains exactly one user interaction.
        # But the same user might appear several times in the list.
        record_batch = [User_Record(user_dt=pred_dt[idx, :]) for idx in range(pred_dt.nrows)]
        
        if DEBUG:
            debug_record_batch(record_batch)
                        
        # update test buffer - add info about the new test batch
        # at this moment, for each attr in `aggregated_attrs`, each record is an empty list,
        # and the corresponding records in the buffer, for `aggregated_attrs`, are not updated
        for record in record_batch:
            self.pred_record_buffer.update(record)
        
        # ----------------------------------------------------
        
        self.users_in_pred.update(set(user_ids))

        n_new_users_in_batch = 0
        for x in set(user_ids):
            if x in self.users_trained:
                self.common_users_in_train_and_pred.add(x)
            else:
                self.new_users_in_pred.add(x)
                n_new_users_in_batch += 1
            
        # ----------------------------------------------------            
            
        nb_common_users = len(self.common_users_in_train_and_pred)
        nb_new_users = len(self.new_users_in_pred)        
        
        common_users_in_batch = set(user_ids).intersection(self.users_trained)
        new_users_in_batch = set(user_ids).difference(self.users_trained)
        
        self.n_blocks_in_pred += len(set(user_ids))
        self.n_blocks_from_common_users_in_train_and_pred += len(common_users_in_batch)
        self.n_blocks_from_new_users_in_pred += len(new_users_in_batch)
                
        self.avg_n_blocks_in_pred = self.n_blocks_in_pred / len(self.users_in_pred)
        
        if nb_common_users > 0:
            self.avg_n_blocks_from_common_users_in_train_and_pred = self.n_blocks_from_common_users_in_train_and_pred / nb_common_users       
        
        if nb_new_users > 0:
            self.avg_n_blocks_from_new_users_in_pred = self.n_blocks_from_new_users_in_pred / nb_new_users
      
        # ----------------------------------------------------    
    
#         if self.probe:
        
#             # ----------------------------------------------------
#             # Probe: The number of users in the test dataset.

#             # At most N_USERS users in the test dataset.       
#             # N_USERS = 20992  # False
#             N_USERS = 21056  # True

#             assert len(self.pred_record_buffer) <= N_USERS        

#             # ----------------------------------------------------
#             # Probe: The number of new users in the test dataset.

#             # N_NEW_USERS = 5904 # False
#             N_NEW_USERS = 5936 # True

#             assert nb_new_users <= N_NEW_USERS

        # ----------------------------------------------------
        
        # get the updated record from test buffer for each user in `user_ids`.
        # The same user might appear several times in the list, however, they get the same user history (test time) sequence.
        ### _pred_record_batch = [self.pred_record_buffer[x] for x in user_ids]
        
        # get the updated record from test buffer for each user in `user_ids` without duplication.
        # The same user can only appear at most 1 time in the list.
        _pred_record_batch = []
        n_pred_time_steps_dict = dict()
        
        for x in user_ids:
        
            if x not in n_pred_time_steps_dict:
                _pred_record_batch.append(self.pred_record_buffer[x])
                n_pred_time_steps_dict[x] = 1
            else:
                n_pred_time_steps_dict[x] += 1
                continue
                        
        user_ids_without_duplication = [x.user_id for x in _pred_record_batch]        
        n_pred_time_steps_batch = [n_pred_time_steps_dict[x.user_id] for x in _pred_record_batch]
                
        ### assert 0 <= min(n_pred_time_steps_batch) 
        ### assert max(n_pred_time_steps_batch) <= self.train_config.window_size
        
        # get the record from train_dt or train record buffer for each user in `user_ids`
        train_record_batch = [self.get_training_record(x) for x in user_ids_without_duplication]
    
        # obtain the full history (in training time + the previous batches in test time) and the current batch to predict
        check = isinstance(self.pred_iter_manager, Valid_Iter_Manager)    
    
        info_for_post_update = {}
        pred_record_batch = []
        for user_id, train_record, _pred_record, n_pred_time_steps in zip(
            user_ids_without_duplication,
            train_record_batch,
            _pred_record_batch,
            n_pred_time_steps_batch
        ):
                   
            # A dictionary of lists, key are in `aggregated_attrs` and values are all of length `n_pred_time_steps`
            aggregated_info_to_update = self.get_aggregated_info_to_update(user_id, train_record, _pred_record, n_pred_time_steps)
                            
            pred_record = combine_user_record(train_record, _pred_record, check=check)
            pred_record_batch.append(pred_record)
            
            info_for_post_update[user_id] = {
                'user_id': user_id,
                'pred_record': _pred_record,
                'n_pred_time_steps': n_pred_time_steps,
                'aggregated_info_to_update': aggregated_info_to_update
            }            
            
            for k in aggregated_attrs:
                
                ### assert len(getattr(_pred_record, k)) == len(_pred_record) - n_pred_time_steps
                
                # `pred_record` is used for model prediction
                if k in ['n_lectures_watched', 'n_prev_seen', 'n_prev_correctness']:
                    getattr(pred_record, k).extend(aggregated_info_to_update[k])
                else:
                    # only use the information before the current pred batch
                    getattr(pred_record, k).extend([aggregated_info_to_update[k][0]] * n_pred_time_steps)
                    
                # We need to update those correct information later!!!
                getattr(_pred_record, k).extend(aggregated_info_to_update[k])
                                     
#             for k in aggregated_attrs:

#                 assert len(getattr(_pred_record, k)) == len(_pred_record)   
#                 assert len(getattr(pred_record, k)) == len(pred_record)

#         if len(self.train_record_buffer) >= self.max_train_buffer_size:
#             self.reset_train_record_buffer()
            
        if len(self.train_record_buffer) >= 3000:
            self.reset_train_record_buffer()
            gc.collect()
    
        return pred_record_batch, n_pred_time_steps_batch, info_for_post_update

    def get_training_record(self, user_id):
        """
        Return the training-time history record of the user with `user_id`.

        Records are cached in `self.train_record_buffer`. On a cache miss the record is built from
        `self.train_dt`, extended with the aggregated attributes read from
        `self.user_performance_hdf5`, stored in the buffer and returned. A user without any usable
        training row gets an empty `User_Record(user_id=user_id)`, which is cached as well.
        """

        cache = self.train_record_buffer

        # Cache hit: the record has been built before.
        if user_id in cache:
            return cache[user_id]

        if user_id in self.user_id_to_row_id_train:

            # The user appeared in the original (unsplit) training dataset.
            split_index = self.pred_iter_manager.get_split_index(user_id)
            user_dt = get_user_dt(user_id, self.train_dt, self.user_id_to_row_id_train, split_index)

            # During the test time, a user seen in training must bring its whole training history.
            # During the validation time, it is possible that none of its (original) training
            # examples is used (i.e. the user is new for the validation).
            if isinstance(self.pred_iter_manager, Test_Iter_Manager):
                assert user_dt.nrows > 0

            if user_dt.nrows > 0:

                convert_dt(user_dt, pred=False)
                cache.update(User_Record(user_dt=user_dt))

                # From here on, work with the record stored in the buffer, not with the one
                # that was just constructed.
                stored_record = cache[user_id]

                row_ids = stored_record.row_id
                ### assert len(row_ids) > 0
                row_start, row_end = row_ids[0], row_ids[-1]
                seq_len = row_end - row_start + 1
                ### assert len(row_ids) == seq_len

                # Rows `row_start .. row_end` (inclusive), one column per name in `aggregated_attrs`.
                user_performance = self.user_performance_hdf5[row_start:row_start + seq_len, :]
                ### assert user_performance.shape == (seq_len, len(aggregated_attrs))

                for feature_idx, attr_name in enumerate(aggregated_attrs):
                    feature_values = user_performance[:, feature_idx].tolist()
                    getattr(stored_record, attr_name).extend(feature_values)
                    ### assert len(getattr(stored_record, attr_name)) == seq_len

                return stored_record

        else:
            # A user that never appeared in training can only show up at test time.
            assert isinstance(self.pred_iter_manager, Test_Iter_Manager)

        # Empty training history (new user in prediction time).
        cache.update(User_Record(user_id=user_id))

        return cache[user_id]

    def convert_to_pred_batch(self, pred_record_batch, n_pred_time_steps_batch):
        """Turn a batch of combined user records into a dict of tensors.

        Args:
            pred_record_batch: A list of user records (one per distinct user in the batch).
            n_pred_time_steps_batch: For each record, the number of trailing interactions that
                belong to the current prediction batch.

        Returns:
            A single `dict` of `tf.Tensor` objects, as built by `batch_record_to_tensor`
            (one entry per name in `extended_attrs + aggregated_attrs`).
        """

        batch_record = self.convert_record_batch(pred_record_batch, n_pred_time_steps_batch)

        # The input ids / targets are deliberately NOT added here: the former call
        #   add_input_ids_and_targets(batch, training, generative=self.config.generative, use_abs_pos=self.config.use_abs_pos)
        # is disabled in this method, so the raw tensors are returned as they are.
        return Pred_Manager.batch_record_to_tensor(batch_record)
            
    def record_to_dict(self, record, n_pred_time_steps):
                            
        record_as_dict = {}
        
        orig_seq_len = len(record.row_id)
        used_seq_len = orig_seq_len
        
        pad_seq = []
        if orig_seq_len < self.train_config.window_size:
            pad_seq = [PAD_TOKEN] * (self.train_config.window_size - orig_seq_len)
            
        if orig_seq_len > self.train_config.window_size:
            used_seq_len = self.train_config.window_size
        
        record_as_dict['user_id'] = record.user_id
    
        # adding information
        record_as_dict['seq_len'] = used_seq_len
        record_as_dict['prev_seq_len'] = orig_seq_len
        record_as_dict['start'] = max(orig_seq_len - self.train_config.window_size, 0)
        record_as_dict['end'] = orig_seq_len - 1
        
        # Be careful: we need this to compute `shifted_answered_correctly` correctly later.
        # Otherwise, we will get `START_TOKEN` before `PAD_TOKEN`.
        answered_correctly = getattr(record, 'answered_correctly')
        user_answer = getattr(record, 'user_answer')

        for k in attrs + aggregated_attrs:
            if k not in ['user_id']:
                record_as_dict[k] = (pad_seq + getattr(record, k))[- self.train_config.window_size:]
#                 assert len(record_as_dict[k]) == self.train_config.window_size
                    
        # We need to use `START_TOKEN` rather than `PAD_TOKEN` here.
        shifted_answered_correctly = [START_TOKEN] + answered_correctly[:-1]
        record_as_dict['shifted_answered_correctly'] = (pad_seq + shifted_answered_correctly)[- self.train_config.window_size:]
#         assert len(record_as_dict['shifted_answered_correctly']) == self.train_config.window_size

        shifted_user_answer = [START_TOKEN] + user_answer[:-1]
        record_as_dict['shifted_user_answer'] = (pad_seq + shifted_user_answer)[- self.train_config.window_size:]
#         assert len(record_as_dict['shifted_user_answer']) == self.train_config.window_size

#         assert self.train_config.window_size >= n_pred_time_steps
                
        pred_time_mask = [0] * orig_seq_len
        pred_time_mask = (pad_seq + pred_time_mask)[- self.train_config.window_size:]
        
        # Be careful: `n_pred_time_steps` might be `0`, and using `pred_time_mask[:- n_pred_time_steps]` gives wrong results.
        pred_time_mask = pred_time_mask[:self.train_config.window_size - n_pred_time_steps] + [1] * n_pred_time_steps
        
        record_as_dict['pred_time_mask'] = pred_time_mask
#         assert len(record_as_dict['pred_time_mask']) == self.train_config.window_size

        abs_pos = list(range(orig_seq_len))
        shifted_abs_pos = [START_TOKEN] + abs_pos[:-1]

        record_as_dict['abs_pos'] = (pad_seq + abs_pos)[- self.train_config.window_size:]
        record_as_dict['shifted_abs_pos'] = (pad_seq + shifted_abs_pos)[- self.train_config.window_size:]
        
#         assert len(record_as_dict['abs_pos']) == self.train_config.window_size
#         assert len(record_as_dict['shifted_abs_pos']) == self.train_config.window_size
        
        # --------------------------------------------------------------------------------
        # compute `lag time`
        
        timestamps = getattr(record, 'timestamp')
        task_container_ids = getattr(record, 'task_container_id')
        nb_actual_elts = min(self.train_config.window_size, len(timestamps))
        
        # try to get the (ending) subsequence of length `window_size + `MAX_PRED_TIME_QUESTION_BUNDLE_LEN`, but it might be shorter
        _start_index = max(0, len(timestamps) - (self.train_config.window_size + MAX_PRED_TIME_QUESTION_BUNDLE_LEN))
        # truncated: length = len(...) - `_start_index`
        timestamps = timestamps[_start_index:]
        task_container_ids = task_container_ids[_start_index:]
        
        # add leading `0`s if the lenght is not enough
        nb_elt_to_add = (self.train_config.window_size + MAX_PRED_TIME_QUESTION_BUNDLE_LEN) - len(timestamps)
        if nb_elt_to_add > 0:
            timestamps = [0] * nb_elt_to_add + timestamps
            task_container_ids = [PAD_TOKEN] * nb_elt_to_add + task_container_ids

#         assert len(timestamps) == self.train_config.window_size + MAX_PRED_TIME_QUESTION_BUNDLE_LEN
#         assert len(task_container_ids) == self.train_config.window_size + MAX_PRED_TIME_QUESTION_BUNDLE_LEN

        # loop
        lag_time = []
        current_lag_time = 0
        for _t, _prev_t, _bundle, _prev_bundle in zip(timestamps[1:], timestamps[:-1], task_container_ids[1:], task_container_ids[:-1]):
            _diff = _t - _prev_t
            if _diff == 0 and _bundle == _prev_bundle:
                # In the same bundle
                _diff = current_lag_time
            elif _diff == 0 and _bundle != _prev_bundle:
                # problematic data point - but we can use `0` and post-process it below
                pass
            else:
                # update the latest lag time
                # assert _diff > 0
                current_lag_time = _diff
            lag_time.append(_diff)
            
#         assert len(lag_time) == self.train_config.window_size + MAX_PRED_TIME_QUESTION_BUNDLE_LEN - 1
        
        # remove the leading part
        lag_time = lag_time[-nb_actual_elts:]
        timestamps = timestamps[1:][-nb_actual_elts:]
        
#         assert len(lag_time) == nb_actual_elts

        # set to 1 second if `lag_time == 0` but they are not in the same bunlde
        _mask = tf.cast(tf.math.logical_and(tf.constant(lag_time, dtype=tf.int64) == 0, tf.constant(timestamps, dtype=tf.int64) > 0), dtype=tf.int64) 
        lag_time = tf.constant(lag_time, dtype=tf.int64) * (1 - _mask) + tf.constant(1000, dtype=tf.int64) * _mask          
        lag_time = lag_time.numpy().tolist()
        
        lag_time = (pad_seq + lag_time)[- self.train_config.window_size:]
#         assert len(lag_time) == self.train_config.window_size
        
        record_as_dict['lag_time'] = lag_time
        
        # --------------------------------------------------------------------------------

        return record_as_dict
    
    def convert_record_batch(self, record_batch, n_pred_time_steps_batch):
        """Regroup a list of records into one `dict` of lists (a format change only).

        Args:
            record_batch: A list of `User_Record`.
            n_pred_time_steps_batch: A list aligned with `record_batch`; the value passed to
                `record_to_dict` as `n_pred_time_steps` for the corresponding record.

        Returns:
            batch_record: A `dict` with keys in `extended_attrs + aggregated_attrs`. For each
            key `k`, the value is the list of the per-record values of `k`, in batch order.
        """

        # time consuming: every record is converted on its own
        records_as_dicts = []
        for record, n_pred_time_steps in zip(record_batch, n_pred_time_steps_batch):
            records_as_dicts.append(Pred_Manager.record_to_dict(self, record, n_pred_time_steps))

        batch_record = {}
        for key in extended_attrs + aggregated_attrs:
            batch_record[key] = [as_dict[key] for as_dict in records_as_dicts]

        return batch_record


    @staticmethod
    def batch_record_to_tensor(batch_record):
        """Convert a `batch_record` obtained from `convert_record_batch` to `tf.Tensor`.
        """
        
        t = {
            k: tf.constant(batch_record[k], dtype=attr_dtypes[k]) for k in extended_attrs + aggregated_attrs
        }    
        
#         assert t['timestamp'].dtype == tf.int64
        
        return t

### Debug functions

In [None]:
def debug_pred_batch_row_ids(row_ids, prev_batch_row_ids):
    """Verify the properties of row ids in a single and across pred batch(es).

    An empty batch (current or previous) has nothing to compare across batches, so the
    cross-batch check is skipped for it instead of failing on an index lookup.

    Args:
        row_ids: The row ids (a list) in a batch during the pred time given by `env.iter_test()`
        prev_batch_row_ids: The row ids in the batch just before the batch of `row_ids` during
            the test time given by `env.iter_test()`, or `None` if there is no previous batch.

    Raises:
        AssertionError: If the row ids are not distinct, not sorted within the batch, or do not
            come strictly after the row ids of the previous batch.
    """

    # sanity check
    # row ids must be distinct
    assert len(set(row_ids)) == len(row_ids), 'row ids must be distinct within a batch'

    # row ids must be in sorted order in a batch
    assert row_ids == sorted(row_ids), 'row ids must be sorted within a batch'

    # row ids must be in sorted order across all batch during the pred time.
    has_prev_rows = prev_batch_row_ids is not None and len(prev_batch_row_ids) > 0
    if has_prev_rows and len(row_ids) > 0:
        assert row_ids[0] > prev_batch_row_ids[-1], 'row ids must increase across batches'


def debug_record_batch(record_batch):
    """Verify the properties of the question bundle for a single user in a pred batch.

    For every user that has at least one question in the batch, the questions must belong to
    exactly one bundle (`task_container_id`), must form one uninterrupted run of that user's
    records, and must be the last records of that user in the batch.

    Args:
        record_batch: A list. Each element (`User_Record`) should contain only one record in a single timestamp.

    Raises:
        AssertionError: If any of the properties above does not hold.
    """

    records_by_user = {}
    for record in record_batch:
        # each record here contains only one entry.
        assert len(record.row_id) == 1
        records_by_user.setdefault(record.user_id, []).append(record)

    for user_records in records_by_user.values():

        question_records = [r for r in user_records if r.content_type_id[0] == 0]

        # Users with lectures only have nothing to verify.
        if not question_records:
            continue

        # There must be exactly one question bundle.
        # This is `True`.
        bundle_ids = {r.task_container_id[0] for r in question_records}
        assert len(bundle_ids) == 1

        question_row_ids = [r.row_id[0] for r in question_records]
        first_row_id, last_row_id = question_row_ids[0], question_row_ids[-1]

        # The stronger property "the row ids of the bundle are continuous" is `False` in the data,
        # so it is not asserted: the row ids may jump.

        # This is `True`: no other record of the user sits between the first and the last question.
        spanned_row_ids = [r.row_id[0] for r in user_records if first_row_id <= r.row_id[0] <= last_row_id]
        assert question_row_ids == spanned_row_ids

        # The question bundle must be at the end of the sequence (for a single user) in a test batch.
        all_row_ids = [r.row_id[0] for r in user_records]
        assert question_row_ids == all_row_ids[-len(question_row_ids):]

## Prediction routine

In [None]:
# Module-level state of the prediction cell.
# NOTE(review): neither name is referenced in the visible part of `run_pred`; presumably
# `buffer` collects items and `max_buffer_size` caps how many are kept - confirm against the
# rest of the cell before relying on this.
max_buffer_size = 20
buffer = []


def run_pred(pred_manager, predictor, max_steps=None, ckpt_no=None):
    
    batch_size = None
    seq_len = pred_manager.train_config.window_size

    encoder_decoder = tf.constant(pred_manager.config.model_type=='ed', dtype=tf.int32)
    training = tf.constant(0, dtype=tf.int32)
    generative = tf.constant(pred_manager.config.generative, dtype=tf.int32)
    allow_bundle_atten = tf.constant(pred_manager.config.allow_bundle_atten, dtype=tf.int32)
    
    # Only used for validation
    valid_user_ids = []
    valid_row_ids = []
    valid_targets = []
    valid_preds = []

    if pred_manager.config.use_softmax:
        loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE,
            name='sparse_categorical_crossentropy_for_correctness'
        )
    else:
        loss_obj = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, label_smoothing=0, reduction=tf.keras.losses.Reduction.NONE,
            name='binary_crossentropy_for_correctness'
        )

    loss_metric = tf.keras.metrics.Mean()    
    acc_metric = tf.keras.metrics.BinaryAccuracy()
    auc_metric = tf.keras.metrics.AUC(num_thresholds=2000)    
    
#     @tf.function(
#         input_signature=get_input_signatures(batch_size, seq_len, valid=False)
#     )


    @tf.function(
        input_signature=get_input_signatures_2()
    )
    def predict(batch):

        pred_batch = add_input_ids_and_targets_2(batch)
        
        c_mask, r_mask, r_c_mask, c_r_mask = get_attention_masks(pred_batch, training, encoder_decoder, generative, allow_bundle_atten)

        # This is the mask where only the interactions in the current pred batch will be `1`.
        # shape = [batch_size, window_size]
        pred_time_mask = pred_batch['pred_time_mask'] * tf.cast(pred_batch['pred_time_mask'] != PAD_TOKEN, dtype=tf.int32)

        logits, answer_logits, c_outputs, r_outputs = predictor(
            pred_batch, c_mask, r_mask, r_c_mask, c_r_mask,
            output_attentions=False, output_hidden_states=False, training=False
        )
            
        # preds = tf.math.sigmoid(logits)

        if pred_manager.config.use_softmax:
            # shape = [batch_size, seq_len, 2]
            preds = tf.math.softmax(logits)
        else:
            # shape = [batch_size, seq_len, 1]
            preds = tf.math.sigmoid(logits)

        # The places where in the current prediction batch.
        pred_indices = tf.where(pred_time_mask == 1)

        # shape = [n_selected_places, n_targets], `n_targets = 2` if `use_softmax`, otherwise `1`.
        selected_logits = tf.gather_nd(logits, pred_indices)
        # shape = [n_selected_places, n_targets], `n_targets = 2` if `use_softmax`, otherwise `1`.
        selected_preds = tf.gather_nd(preds, pred_indices)

        selected_row_ids = tf.gather_nd(pred_batch['row_id'], pred_indices)
        
        return logits, preds, selected_logits, selected_preds, selected_row_ids
    
    # ----------------------------------------------------
    
    # ----------------------------------------------------
    
    start_memeory_global = psutil.virtual_memory().available / 1024.0 / 1024.0
    start = datetime.datetime.now()
    start_global = start    
    
    pred_history = []

    n_pred_batches = 0
    n_interactions = 0
    
    prior_info_for_post_update = None
    
    if isinstance(pred_manager.pred_iter_manager, Test_Iter_Manager):
        if max_steps is not None:
            print(f'For `Test_Iter_Manager`, {max_steps} shoulde be `None`, and we set it for you here!')
            max_steps = None
        
    total = 0.0
    for pred_df, pred_dt, prior_group_answers_correct, prior_group_responses in pred_manager.pred_iter_manager.iter_dataset(max_steps=max_steps):

        n_pred_batches += 1
    
        try:
            
            # ---------------------------------------------------- 
            
            if DEBUG:
                
                pass
            
#                 _question_ids = set(pred_df[pred_df['content_type_id'] == 0]['content_id'].values.tolist())
#                 _lecture_ids = set(pred_df[pred_df['content_type_id'] == 1]['content_id'].values.tolist())              
            
#                 # All pred questions must have been seen in training time.
#                 assert _question_ids.issubset(pred_manager.pred_iter_manager._unique_question_id_train)

#                 # All pred lectures must have been seen in training time.
#                 assert _lecture_ids.issubset(pred_manager.pred_iter_manager._unique_lecture_id_train)

            pred_record_batch, n_pred_time_steps_batch, info_for_post_update = pred_manager.update(
                pred_dt, prior_group_answers_correct, prior_group_responses, prior_info_for_post_update
            )
            elapsed = (datetime.datetime.now() - start).total_seconds()

            # Update
            prior_info_for_post_update = info_for_post_update
            
            pred_batch = pred_manager.convert_to_pred_batch(pred_record_batch, n_pred_time_steps_batch)
                                    
            logits, preds, selected_logits, selected_preds, selected_row_ids = predict(pred_batch)
                
            # ----------------------------------------------------            
            # Compute metrics
            
            # shape = [n_selected_places]
            predictions = selected_preds[:, -1].numpy().tolist()
            
            # sanity check
            assert len(predictions) == len(pred_df['content_type_id'])
            
            user_ids = pred_manager.current_batch_users
            n_interactions += len(user_ids)

            if isinstance(pred_manager.pred_iter_manager, Valid_Iter_Manager):

                row_ids = selected_row_ids.numpy().tolist()
                
                # Be careful, the `targets` here is not the same as in `pred_batch['target']`.
                # In `pred_batch`, we will get `-100` (i.e. `NON_TARGET_ID`) for the places not to be predicted (i.e. padding/lectures),
                # while in `targets`, we will get `-1`.
                targets = pred_manager.pred_iter_manager._current_answered_correctly
                # user_answers = pred_manager.pred_iter_manager._current_user_answer

                # sanity check
                assert len(set([len(user_ids), len(row_ids), len(targets), len(predictions)])) == 1
                
                targets_t = tf.constant(targets, dtype=tf.int32)
                # Be careful: Here we use `-1` rather than `NON_TARGET_ID`.
                pred_mask = targets_t != -1
                pred_indices = tf.where(pred_mask)

                selected_targets_for_loss = tf.gather_nd(targets_t, pred_indices)
                
                # shape = [n_selected_places_for_loss, n_targets], `n_targets = 2` if `use_softmax`, otherwise `1`.
                selected_logits_for_loss = tf.gather_nd(selected_logits, pred_indices)
                selected_preds_for_loss = tf.gather_nd(selected_preds, pred_indices)
                
                # From the doc, `tf.keras.losses.SparseCategoricalCrossentropy` should use `selected_targets_for_loss`, but it seems ok to have the last dimension.
                # For `BinaryCrossentropy`, we need to use `selected_targets_for_loss[:, tf.newaxis]` to have the 2nd dimension so the losses are not averaged.                
                losses = loss_obj(selected_targets_for_loss[:, tf.newaxis], selected_logits_for_loss)
                
                loss_metric.update_state(losses)
                # Use `selected_preds_for_loss[:, -1:]` to get the probabilities for class `1`, with the 2nd dim
                acc_metric.update_state(selected_targets_for_loss[:, tf.newaxis], selected_preds_for_loss[:, -1:])
                # Use `selected_preds_for_loss[:, -1]` to get the probabilities for class `1`, without the 2nd dim
                auc_metric.update_state(selected_targets_for_loss, selected_preds_for_loss[:, -1])

                valid_user_ids.extend(user_ids)
                valid_row_ids.extend(row_ids)
                valid_targets.extend(targets)
                valid_preds.extend(predictions)       

            elif isinstance(pred_manager.pred_iter_manager, Test_Iter_Manager):

                pred_df['answered_correctly'] = predictions
                pred_manager.pred_iter_manager.env.predict(pred_df.loc[pred_df['content_type_id'] == 0, ['row_id', 'answered_correctly']])

            else:

                 raise ValueError('`pred_manager.pred_iter_manager` must be an instance of either `Valid_Iter_Manager` or `Test_Iter_Manager`')
            
            # ----------------------------------------------------        
        
#             if len(buffer) < max_buffer_size:
#                 buffer.append({'pred_batch': pred_batch, 'logits': logits, 'preds': preds, 'selected_preds': selected_preds})
        
            # ----------------------------------------------------        
            
#             if pred_manager.probe:

#                 # ----------------------------------------------------
#                 # Probe: The number of pred batches

#                 # N_PRED_BATCHES = 57104 # False
#                 N_PRED_BATCHES = 57120 # True (something strange after API update)
#                 N_PRED_BATCHES = 60000 # True (need to check again)
                
#                 # Disable for now
#                 ### assert n_pred_batches <= N_PRED_BATCHES

#                 # ----------------------------------------------------
#                 # Probe: the number of blocks / the number of blocks given by new users

#                 if n_pred_batches > 57100:

#                     ### assert 64 <= pred_manager.avg_n_blocks_from_new_users # False
#                     assert 62 <= pred_manager.avg_n_blocks_from_new_users_in_pred # True

#                     ### assert 88 <= pred_manager.avg_n_blocks_from_common_users_in_train_and_pred # False
#                     assert 83 <= pred_manager.avg_n_blocks_from_common_users_in_train_and_pred # True
                    
            # ----------------------------------------------------

            if n_pred_batches % 100 == 0:
                
                end = datetime.datetime.now()
                elapsed = (end - start).total_seconds()
                elapsed_global = (end - start_global).total_seconds()
                avg_timing_per_batch = elapsed / PRINTING_STEPS
                avg_timing_per_batch_global = elapsed_global / n_pred_batches

                end_memeory = psutil.virtual_memory().available / 1024.0 / 1024.0
                used_memory_global = start_memeory_global - end_memeory                
                
                print(f'used memory: {used_memory_global}')
                
                print(f'pred_manager.n_blocks_in_pred: {pred_manager.n_blocks_in_pred}')
                print(f'pred_manager.n_blocks_from_common_users_in_train_and_pred: {pred_manager.n_blocks_from_common_users_in_train_and_pred}')
                print(f'pred_manager.n_blocks_from_new_users_in_pred: {pred_manager.n_blocks_from_new_users_in_pred}')

                print(f'pred_manager.avg_n_blocks_in_pred: {pred_manager.avg_n_blocks_in_pred}')
                print(f'pred_manager.avg_n_blocks_from_common_users_in_train_and_pred: {pred_manager.avg_n_blocks_from_common_users_in_train_and_pred}')
                print(f'pred_manager.avg_n_blocks_from_new_users_in_pred: {pred_manager.avg_n_blocks_from_new_users_in_pred}')
                
                print(f'number of user interactions processed in current batch: {len(user_ids)}')
                print(f'total number of user interactions processed: {n_interactions}')
                print(f'average number of user interactions per batch: {n_interactions / n_pred_batches}')
                
                print(f'avg_timing_per_batch: {avg_timing_per_batch}')
                print(f'avg_timing_per_batch_global: {avg_timing_per_batch_global}')
                
                if isinstance(pred_manager.pred_iter_manager, Valid_Iter_Manager):
                    
                    valid_loss = loss_metric.result().numpy()
                    valid_acc = acc_metric.result().numpy()
                    valid_auc = auc_metric.result().numpy()
                    
                    print(f'valid_loss: {valid_loss}')
                    print(f'valid_acc: {valid_acc}')
                    print(f'valid_auc: {valid_auc}')
                
                start = datetime.datetime.now()
                
                print('-' * 32)

            # ----------------------------------------------------        

            # Save a few buffer status to check things are expected.
            if n_pred_batches <= 4:
                pred_history.append(str(pred_manager.pred_record_buffer))

        except AssertionError as e:
            
            print('some assertions are wrong, breaking the loop')
            raise e

#     if DEBUG:
#         assert pred_manager.pred_iter_manager.current_batch_no == n_pred_batches
    
#     print(f'total n_pred_batches: {pred_manager.pred_iter_manager.current_batch_no}')

#     print(f'pred_manager.n_blocks_in_pred: {pred_manager.n_blocks_in_pred}')
#     print(f'pred_manager.n_blocks_from_common_users_in_train_and_pred: {pred_manager.n_blocks_from_common_users_in_train_and_pred}')
#     print(f'pred_manager.n_blocks_from_new_users_in_pred: {pred_manager.n_blocks_from_new_users_in_pred}')

#     print(f'pred_manager.avg_n_blocks_in_pred: {pred_manager.avg_n_blocks_in_pred}')
#     print(f'pred_manager.avg_n_blocks_from_common_users_in_train_and_pred: {pred_manager.avg_n_blocks_from_common_users_in_train_and_pred}')
#     print(f'pred_manager.avg_n_blocks_from_new_users_in_pred: {pred_manager.avg_n_blocks_from_new_users_in_pred}')

#     print(f'number of user interactions processed in current batch: {len(user_ids)}')
#     print(f'total number of user interactions processed: {n_interactions}')
#     print(f'average number of user interactions per batch: {n_interactions / n_pred_batches}')    
    
    if isinstance(pred_manager.pred_iter_manager, Valid_Iter_Manager):
        
        if MAX_VALID_ITER_STEPS is not None:
            assert not DEBUG
        else:
            print(f'pred_manager.pred_iter_manager: {pred_manager.pred_iter_manager._user_manager_dict}')
        
        end = datetime.datetime.now()
        elapsed_global = (end - start_global).total_seconds()
        avg_timing_per_batch_global = elapsed_global / n_pred_batches
        print(f'avg_timing_per_batch_global: {avg_timing_per_batch_global}')        
        
        valid_loss = float(loss_metric.result().numpy())
        valid_acc = float(acc_metric.result().numpy())
        valid_auc = float(auc_metric.result().numpy())

        print(f'valid_loss: {valid_loss}')
        print(f'valid_acc: {valid_acc}')
        print(f'valid_auc: {valid_auc}')
        
        valid_results = {
            'valid_loss': valid_loss,
            'valid_acc': valid_acc,
            'valid_auc': valid_auc
        }
        
        valid_submission = pd.DataFrame.from_dict(
            {
                'row_ids': valid_row_ids,
                'user_ids': valid_user_ids,
                'targets': valid_targets,
                'preds': valid_preds
            }
        )
        
        if ckpt_no is None:
            ckpt_no = 0
        
        valid_submission.to_csv(f'valid_submission_ckpt_{ckpt_no}.csv', index=False)
        with open(f'valid_results_ckpt_{ckpt_no}.json', 'w', encoding='UTF-8') as fp:
            json.dump(valid_results, fp, ensure_ascii=False, indent=4)

        if not IS_KAGGLE:
            !gsutil cp -r './valid_results_ckpt_{ckpt_no}.json' '{CKPT_PRED_PATH}'
            !gsutil cp -r './valid_submission_ckpt_{ckpt_no}.csv' '{CKPT_PRED_PATH}'

        del pred_manager.pred_iter_manager
        del pred_manager
        gc.collect()
            
    tf.keras.backend.clear_session()

In [None]:
 def run_dummy_inputs(predictor):
    """Call `predictor` once, inside `strategy.scope()`, on a constant dummy batch and return its output.

    Every feature is a constant tensor of shape `[3, 5]` (batch of 3, sequence of 5), with an extra last
    axis for the tag ids and the per-part / per-correct-answer statistics. `load_ckpt` runs this before
    restoring a checkpoint so that the model's weights are initialized.

    Args:
        predictor: The model; it is called as `predictor(inputs=<dict of feature tensors>)`.

    Returns:
        Whatever `predictor` returns for the dummy batch.
    """

    # NOTE(review): the `def` line keeps the one-space indent it has in the notebook cell. IPython strips a
    # cell's leading indent before compiling, plain Python would reject it -- consider removing that space.

    batch_shape = [3, 5]
    n_parts = PART_VOCAB_SIZE - 2
    n_answer_choices = ANSWER_3_ID - ANSWER_0_ID + 1

    def int_ones():
        return tf.constant(1, shape=batch_shape)

    def float_ones(*extra_dims):
        return tf.constant(1.0, shape=batch_shape + list(extra_dims))

    @tf.function
    def forward(inputs):
        return predictor(inputs=inputs)

    with strategy.scope():

        inputs = {
            'c_input_ids': int_ones(),
            'r_input_ids': int_ones(),
            'd_input_ids': int_ones(),
            'd_ans_input_ids': int_ones(),
            'pos_ids': int_ones(),
            'shifted_pos_ids': int_ones(),
            # The tag ids are the only feature filled with `0`.
            'tag_ids': tf.constant(0, shape=batch_shape + [N_TAGS_PER_CONTENT]),
            'part_ids': int_ones(),
            'prior_explanation_ids': int_ones(),
            'prior_question_elapsed_time_input': float_ones(),
            'lag_time': float_ones(),
            'abs_pos_ids': float_ones(),
            'shifted_abs_pos_ids': float_ones(),
            'task_container_pos_ids': float_ones(),
            'correct_answer_id': int_ones(),
            'n_questions_answered_scaled': float_ones(),
            'n_lectures_watched_scaled': float_ones(),
            'answered_correctly_ratio': float_ones(),
            'part_correctness_ratio': float_ones(n_parts),
            'part_count_scaled': float_ones(n_parts),
            'correct_answer_count_scaled': float_ones(n_answer_choices),
            'correct_answer_correctness_ratio': float_ones(n_answer_choices),
            'current_part_count_scaled': float_ones(),
            'current_part_correctness_ratio': float_ones(),
            'current_correct_answer_count_scaled': float_ones(),
            'current_correct_answer_correctness_ratio': float_ones(),
            'current_question_count_scaled': float_ones(),
            'current_question_correctness_ratio': float_ones(),
        }

        return forward(inputs)

In [None]:
def load_ckpt(predictor, optimizer, ckpt_path, ckpt_no=-1):
    """Initialize `predictor`'s weights and restore a checkpoint from `ckpt_path`.

    Args:
        predictor: The model to restore into. It is first called once on dummy inputs
            (`run_dummy_inputs`) so that its weights exist.
        optimizer: The optimizer tracked by the checkpoint. If `None`, a new Adam optimizer is created.
        ckpt_path: Directory containing checkpoints named `ckpt-<n>`.
        ckpt_no: The checkpoint number to restore; `-1` selects the latest checkpoint found in `ckpt_path`.

    Returns:
        A tuple `(ckpt_manager, loaded_ckpt_no)`. `loaded_ckpt_no` is `0` when nothing was restored
        (no checkpoint found, or a requested number that is not positive).

    Raises:
        AssertionError: If the requested checkpoint number is larger than the latest one available.
    """

    with strategy.scope():

        # ----------------------------------------------------------------------
        # Init model's weight if necessary

        run_dummy_inputs(predictor)

        # ----------------------------------------------------------------------

        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=predictor)
        ckpt_manager = tf.train.CheckpointManager(checkpoint, directory=ckpt_path, max_to_keep=None)

        # ----------------------------------------------------------------------
        # Number of the latest checkpoint, parsed from a path of the form `<dir>/ckpt-<n>`.
        # A path without a numeric `ckpt-<n>` suffix counts as "no checkpoint" instead of
        # crashing inside `int(...)`.

        last_ckpt_no = 0
        latest_checkpoint = ckpt_manager.latest_checkpoint
        if latest_checkpoint:
            _, separator, suffix = latest_checkpoint.rpartition('/ckpt-')
            if separator and suffix.isdecimal():
                last_ckpt_no = int(suffix)

        if last_ckpt_no > 0:
            print(f'Latest checkpoint found. Model trained for {last_ckpt_no} epochs.')
        else:
            print("No checkpoint found.")

        ckpt_no_to_load = last_ckpt_no if ckpt_no == -1 else ckpt_no

        assert ckpt_no_to_load <= last_ckpt_no, (
            f'ckpt-{ckpt_no_to_load} is requested, but the latest checkpoint is ckpt-{last_ckpt_no}'
        )

        ckpt_file = os.path.join(ckpt_path, f'ckpt-{ckpt_no_to_load}')

        # Load ckpt

        print(f'try to load {ckpt_file}')
        if ckpt_no_to_load > 0:
            checkpoint.restore(ckpt_file)
            print(f'ckpt-{ckpt_no_to_load} is restored.')
            loaded_ckpt_no = ckpt_no_to_load
        else:
            print(f'no ckpt is found and restored.')
            loaded_ckpt_no = 0

        # ----------------------------------------------------------------------

    return ckpt_manager, loaded_ckpt_no

### Train / Valid / Pred configurations

In [None]:
# Switch for quick local experiments; it should not be relied on in a real commit or submission.

DUMMY = True

# A submission run never uses the dummy settings, whatever the switch above says.
DUMMY = DUMMY and not SUBMISSION

In [None]:
# Only for a quick change, shouldn't use in real commit or submission
#
# When `DUMMY` is set, this cell (re)defines the model / training / run-mode globals consumed by the
# following cells. The statements are order dependent: later lines override flags set by earlier ones.

if DUMMY:

    # --------------------------------------------------
    # Model definition

    MODEL_TYPE = 'ed'
    MODEL_SIZE = 'b-2'

    # Our base config
    # MODEL_DESC = 'lag-user-corr-ans-enc-loss'

    # Addon
    MODEL_DESC = 'master-train-all'

    ACTIVATION = 'gelu'
    USE_PRE_CLASSIFIER = False
    USE_SOFTMAX = False
    USE_USER_ANSWER = True
    USE_USER_ANSWER_LOSS = True
    USE_CORRECT_ANSWER_FOR_ENCODER = True
    USE_CORRECT_ANSWER_FOR_DECODER = False
    USE_ABS_POS = False
    USE_TASK_CONTAINER_POS = False
    SHARE_POS_EMBEDDING = True
    USE_TAGS = True
    USE_PART = True
    USE_PRIOR_EXPLANATION = True
    USE_PRIOR_QUESTION_ELAPSED_TIME_INPUT = True
    USE_LAG_TIME = True
    USE_LAG_TIME_FOR_ENCODER=False
    USE_USER_LEVEL_AGGREGATED_HISTORICAL_INFO = True
    USE_PART_AGGREGATED_HISTORICAL_INFO = True
    USE_CORRECT_ANSWER_AGGREGATED_HISTORICAL_INFO = True
    USE_QUESTION_LEVEL_AGGREGATED_HISTORICAL_INFO = True
    ALLOW_BUNDLE_ATTEN = True
    GENERATIVE = False

    VALID_FOLD = 1
    
    WINDOW_SIZE = 128
    LOSS_WEIGHT_WINDOW_SIZE = None
    # 16 examples per replica; 8 times more when `tpu` is falsy.
    BATCH_SIZE = 16 * strategy.num_replicas_in_sync
    if not tpu:
        BATCH_SIZE = 8 * BATCH_SIZE
    PRED_BATCH_SIZE = 256 * strategy.num_replicas_in_sync
    # The per-replica batch sizes must differ from `WINDOW_SIZE`.
    # NOTE(review): presumably so that the batch axis and the sequence axis cannot be confused -- confirm.
    if tpu:
        assert BATCH_SIZE // strategy.num_replicas_in_sync != WINDOW_SIZE
    assert PRED_BATCH_SIZE // strategy.num_replicas_in_sync != WINDOW_SIZE
    N_EPOCHS = 100
    STEPS_PER_CALL = 1000
    MAX_N_CONTENTS_PER_USER_FOR_SAMPLING_PROB = 512

    # Learning-rate settings, passed to `train_manager.get_train_objs` in the training cell.
    LR = 2e-4
    END_LR = 2e-5
    WARMUP_STEPS = 80000

    DETERMINISTIC = False

    if DETERMINISTIC:

        # Fixed seed and a shuffle buffer of 1.
        SEED = 2021
        N_PARALLEL_READS = None  # 1
        N_PARALLEL_CALLS = None  # 1
        SHUFFLE_BUFFER_SIZE = 1

    else:
        
        SEED = None
        N_PARALLEL_READS = 16
        N_PARALLEL_CALLS = tf.data.experimental.AUTOTUNE
        SHUFFLE_BUFFER_SIZE = 65536

    MAX_TRAIN_ITER_STEPS = None
    MAX_VALID_ITER_STEPS = None

    PRINTING_STEPS = 1000

    # Sub-directory (below the checkpoint root) that identifies this model variant.
    CKPT_DIR = f'{MODEL_TYPE}-{MODEL_SIZE}-{MODEL_DESC}/'

    # --------------------------------------------------
    # Run mode

    TRAIN = False
    VALID = False
    PRED = True

    RESUME_TRAINING = True

    N_FILES = 6
    SUBMISSION = False

    if IS_KAGGLE:
        # On Kaggle, a file count other than 6 in the competition input directory marks a submission run.
        N_FILES = len(os.listdir('/kaggle/input/riiid-test-answer-prediction'))
        SUBMISSION = (N_FILES != 6)
    else:
        # Prediction is switched off outside of Kaggle.
        PRED = False

    if SUBMISSION:
        
        # A submission only predicts.
        TRAIN = False
        VALID = False
        PRED = True

    if not TRAIN:
        RESUME_TRAINING = False

    DEBUG = False
    PROBE = False

    # Both paths start as `None`, so the defaults below always apply.
    CKPT_TRAIN_PATH = None
    CKPT_PRED_PATH = None
    if CKPT_TRAIN_PATH is None:
        
        if IS_KAGGLE:
            CKPT_TRAIN_PATH = './'
        else:
            CKPT_TRAIN_PATH = f'{BUCKET_DIR}/r3id-ckpts/{CKPT_DIR}'
            
            if TRAIN and not RESUME_TRAINING:
                # Refuse to start a fresh training over an existing checkpoint directory.
                # `echo $?` appends the exit status of `gsutil stat` to the captured output; the first
                # captured line is read as that status, and status `0` is turned into `already_existed = 1`.
                _state = !gsutil -q stat {CKPT_TRAIN_PATH}*; echo $?
                already_existed = 1 - int(_state[0])
                assert not already_existed       

    if CKPT_PRED_PATH is None:
        
        if IS_KAGGLE:
            CKPT_PRED_PATH = f'{BASE_DIR}/{CKPT_DIR}'
        else:
            CKPT_PRED_PATH = f'{BUCKET_DIR}/r3id-ckpts/{CKPT_DIR}'

In [None]:
# Ratio between the feed-forward hidden size and the model dimension used by most presets.
FFN_FACTOR = 4

# `MODEL_SIZE` -> (n_layers, n_heads, dim, hidden_dim).
# The presets with `dim = 512` use only half of the usual feed-forward size.
_MODEL_SIZE_PRESETS = {
    # Only for debug
    'baby': (1, 2, 4, 4),

    't-0': (2, 4, 32, FFN_FACTOR * 32),
    't-1': (2, 4, 64, FFN_FACTOR * 64),
    't-2': (2, 4, 128, FFN_FACTOR * 128),
    't-3': (2, 4, 256, FFN_FACTOR * 256),

    's-0': (4, 4, 32, FFN_FACTOR * 32),
    's-1': (4, 4, 64, FFN_FACTOR * 64),
    's-2': (4, 4, 128, FFN_FACTOR * 128),
    's-3': (4, 4, 256, FFN_FACTOR * 256),

    'b-0': (4, 8, 64, FFN_FACTOR * 64),
    'b-1': (4, 8, 128, FFN_FACTOR * 128),
    'b-2': (4, 8, 256, FFN_FACTOR * 256),
    'b-3': (4, 8, 512, FFN_FACTOR * 512 // 2),
    'b-4': (4, 16, 512, FFN_FACTOR * 512 // 2),

    'm-0': (6, 8, 64, FFN_FACTOR * 64),
    'm-1': (6, 8, 128, FFN_FACTOR * 128),
    'm-2': (6, 8, 256, FFN_FACTOR * 256),
    'm-3': (6, 8, 512, FFN_FACTOR * 512 // 2),
    'm-4': (6, 16, 512, FFN_FACTOR * 512 // 2),
}

# Fail right here on an unknown size. Otherwise `n_layers` & co. would be undefined, or silently keep
# the values of a previous run of this cell.
if MODEL_SIZE not in _MODEL_SIZE_PRESETS:
    raise ValueError(
        f'Unknown `MODEL_SIZE`: {MODEL_SIZE!r}. It must be one of {sorted(_MODEL_SIZE_PRESETS)}.'
    )

n_layers, n_heads, dim, hidden_dim = _MODEL_SIZE_PRESETS[MODEL_SIZE]

# --------------------------------------------------

assert MODEL_TYPE is not None and MODEL_TYPE != ''
assert MODEL_DESC is not None and MODEL_DESC != ''

In [None]:
# Position `0` is reserved for padding, so the real positions run from `1` to `WINDOW_SIZE`.
max_position_embeddings = WINDOW_SIZE + 1

# Model configuration. All arguments are passed by keyword and grouped by theme.
config = EdFormerConfig(
    # Identification
    model_type=MODEL_TYPE,
    model_size=MODEL_SIZE,
    model_desc=MODEL_DESC,
    # Vocabularies
    content_vocab_size=CONTENT_VOCAB_SIZE,
    response_vocab_size=RESPONSE_VOCAB_SIZE,
    tag_vocab_size=TAG_VOCAB_SIZE,
    part_vocab_size=PART_VOCAB_SIZE,
    prior_explanation_vocab_size=PRIOR_EXPLANATION_VOCAB_SIZE,
    pad_token_id=PAD_ID,
    # Architecture
    n_layers=n_layers,
    n_heads=n_heads,
    dim=dim,
    hidden_dim=hidden_dim,
    activation=ACTIVATION,
    use_pre_classifier=USE_PRE_CLASSIFIER,
    use_softmax=USE_SOFTMAX,
    # Positions
    max_position_embeddings=max_position_embeddings,
    sinusoidal_pos_embds=False,
    share_position_embeddings=SHARE_POS_EMBEDDING,
    use_abs_pos=USE_ABS_POS,
    use_task_container_pos=USE_TASK_CONTAINER_POS,
    # Dropout / initialization
    dropout=0.1,
    attention_dropout=0.1,
    seq2seq_dropout=0.1,
    initializer_range=0.02,
    seed=SEED,
    # Input features
    use_tags=USE_TAGS,
    use_part=USE_PART,
    use_prior_explanation=USE_PRIOR_EXPLANATION,
    use_prior_question_elapsed_time_input=USE_PRIOR_QUESTION_ELAPSED_TIME_INPUT,
    use_lag_time=USE_LAG_TIME,
    use_lag_time_for_encoder=USE_LAG_TIME_FOR_ENCODER,
    use_user_answer=USE_USER_ANSWER,
    use_user_answer_loss=USE_USER_ANSWER_LOSS,
    use_correct_answer_for_encoder=USE_CORRECT_ANSWER_FOR_ENCODER,
    use_correct_answer_for_decoder=USE_CORRECT_ANSWER_FOR_DECODER,
    # Aggregated historical features
    use_user_level_aggregated_historical_info=USE_USER_LEVEL_AGGREGATED_HISTORICAL_INFO,
    use_part_aggregated_historical_info=USE_PART_AGGREGATED_HISTORICAL_INFO,
    use_correct_answer_aggregated_historical_info=USE_CORRECT_ANSWER_AGGREGATED_HISTORICAL_INFO,
    use_question_level_aggregated_historical_info=USE_QUESTION_LEVEL_AGGREGATED_HISTORICAL_INFO,
    # Attention behaviour
    allow_bundle_atten=ALLOW_BUNDLE_ATTEN,
    generative=GENERATIVE,
)

# Training / data pipeline configuration.
train_config = TrainConfig(
    # Checkpoints and schedule
    ckpt_path=CKPT_TRAIN_PATH,
    n_epochs=N_EPOCHS,
    steps_per_call=STEPS_PER_CALL,
    valid_fold=VALID_FOLD,
    # Windows and batches
    window_size=WINDOW_SIZE,
    loss_weight_window_size=LOSS_WEIGHT_WINDOW_SIZE,
    batch_size=BATCH_SIZE,
    pred_batch_size=PRED_BATCH_SIZE,
    max_n_contents_per_user_for_sampling_prob=MAX_N_CONTENTS_PER_USER_FOR_SAMPLING_PROB,
    # Input pipeline and reproducibility
    shuffle_buf_size=SHUFFLE_BUFFER_SIZE,
    num_parallel_reads=N_PARALLEL_READS,
    num_parallel_calls=N_PARALLEL_CALLS,
    seed=SEED,
    deterministic=DETERMINISTIC,
)

### Check that the tensor inputs are good

In [None]:
if TRAIN:

    # Options forwarded to `train_manager.train_valid`.
    from_valid = False
    only_valid = False

    # Checkpoint numbers to start from, one training run per entry.
    # NOTE(review): `load_ckpt` only picks the latest checkpoint for `ckpt_no=-1`. With `ckpts = [0]`,
    # a run with `RESUME_TRAINING` restores nothing and starts again from epoch 0 -- confirm this is intended.
    ckpts = [0]

    # Every 5th epoch up to 20, then every epoch from 21 to 60.
    valid_epochs = [5, 10, 15, 20, *range(21, 61)]

    train_manager = Train_Manager(config, train_config)
    predictor, optimizer, loss_obj, loss_obj_answer, metrics = train_manager.get_train_objs(
        lr=LR, end_lr=END_LR, warmup_steps=WARMUP_STEPS
    )

    # ----------------------------------------------------------------------------------------------------

    # One call on dummy inputs, so that the trainable variables listed below exist.
    run_dummy_inputs(predictor)

    print('-' * 40)
    print('trainable variables:')
    for v in predictor.trainable_variables:
        print(v.name)
    print('-' * 40)

    for ckpt_no in ckpts:

        # ----------------------------------------------------------------------------------------------------

        if not (RESUME_TRAINING or only_valid):

            # Start from scratch: only create the checkpoint manager, nothing is restored.
            with strategy.scope():

                loaded_ckpt_no = 0

                checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=predictor)
                ckpt_path = CKPT_TRAIN_PATH
                ckpt_manager = tf.train.CheckpointManager(checkpoint, directory=ckpt_path, max_to_keep=None)

        else:

            # reload ckpt
            ckpt_manager, loaded_ckpt_no = load_ckpt(
                predictor, optimizer,
                CKPT_TRAIN_PATH, ckpt_no=ckpt_no
            )

        train_manager.train_valid(
            predictor, optimizer, ckpt_manager,
            loss_obj, loss_obj_answer, metrics,
            last_epoch=loaded_ckpt_no,
            from_valid=from_valid,
            valid_epochs=valid_epochs,
            only_valid=only_valid
        )

In [None]:
# Echo the checkpoint directory used by the training cell as this cell's output.
CKPT_TRAIN_PATH

In [None]:
# List the contents of the current working directory.
!ls -l

In [None]:
# List the training checkpoint directory: `gsutil ls` when the path is a
# `gs://` bucket URL, plain `ls -l` for any other (local) path.
if CKPT_TRAIN_PATH.startswith('gs://'):
    !gsutil ls '{CKPT_TRAIN_PATH}'
else:
    !ls -l '{CKPT_TRAIN_PATH}'

In [None]:
# Build a new model instance; the validation and prediction cells below load
# checkpoint weights into it. Note this rebinds `predictor` unconditionally.
predictor = TFEdFormerAnswerPredictionModel(config)

# A usable checkpoint path is mandatory for submissions and for any
# non-dummy validation / prediction run.
if ((VALID or PRED) and not DUMMY) or SUBMISSION:
    assert CKPT_PRED_PATH not in (None, './')

In [None]:
# Echo the checkpoint path that validation / prediction will load from.
CKPT_PRED_PATH

In [None]:
if VALID:
    
    valid_results = {'ckpt_no': [], 'valid_loss': [], 'valid_acc': [], 'valid_auc': []}

    ckpts_to_run = [60]
                    
    print(f'ckpts_to_run: {ckpts_to_run}')
        
    for ckpt_no in ckpts_to_run:
            
        # reload later ckpt
        if ckpt_no > 0:

            ckpt_manager, loaded_ckpt_no = load_ckpt(
                predictor=predictor, optimizer=None,
                ckpt_path=CKPT_PRED_PATH, ckpt_no=ckpt_no,
            )
            assert loaded_ckpt_no > 0
        
        pred_iter_manager_class = Valid_Iter_Manager

        pred_manager = Pred_Manager.create(
            config=config,
            train_config=train_config,
            pred_iter_manager_class=pred_iter_manager_class,
            train_dt_path_or_obj=train_dt,               
            user_id_to_row_id_train_path=user_id_to_row_id_train_path,
            unique_question_id_train_path=unique_question_id_splitted_train_path,
            unique_lecture_id_train_path=unique_lecture_id_splitted_train_path,
            # ------------------------------------------------------------------------------------------
            # this should be replaced with the corresponding files for validation
#             question_history_at_training_end_path=question_history_at_training_end_path,
#             single_question_history_at_training_end_optimized_path=single_question_history_at_training_end_optimized_path,
          
            question_history_at_training_end_path=question_history_at_training_end_for_valid_path,
            single_question_history_at_training_end_optimized_path=single_question_history_at_training_end_optimized_for_valid_path,

            # ------------------------------------------------------------------------------------------
            valid_info_path=(valid_info_paths[VALID_FOLD] if VALID_FOLD in [0, 1, 2, 3] else None),
            max_train_buffer_size=30000,
            probe=False,
            debug=False
        )

        run_pred(pred_manager, predictor, max_steps=MAX_VALID_ITER_STEPS, ckpt_no=ckpt_no)
        
        with open(f'valid_results_ckpt_{ckpt_no}.json', 'r', encoding='UTF-8') as fp:
            
            valid_results_ckpt = json.load(fp)
            
            valid_results['ckpt_no'].append(ckpt_no)
            valid_results['valid_loss'].append(valid_results_ckpt['valid_loss'])
            valid_results['valid_acc'].append(valid_results_ckpt['valid_acc'])
            valid_results['valid_auc'].append(valid_results_ckpt['valid_auc'])
            
        print('=' * 80)
        
    with open(f'valid_results.json', 'w', encoding='UTF-8') as fp:
        json.dump(valid_results, fp, ensure_ascii=False, indent=True)

    if not IS_KAGGLE:
        !gsutil cp -r './valid_results.json' '{CKPT_PRED_PATH}'

In [None]:
# List the working directory; the validation cell above writes
# `valid_results.json` here when VALID is set.
!ls -l

In [None]:
# Default to None so the final expression still evaluates when VALID is off.
valid_df = None
if VALID:

    # Load the validation submission CSV for the last checkpoint in
    # `ckpts_to_run` (defined by the validation cell), then raise the pandas
    # row-display limit to 200. The option change is global and persists
    # for later cells.
    valid_df = pd.read_csv(f'valid_submission_ckpt_{ckpts_to_run[-1]}.csv')
    pd.set_option('display.max_rows', 200)

# Last expression: rendered as the cell output.
valid_df

In [None]:
# Start from an empty mapping so the cell still displays something when VALID is off.
valid_results = {}
if VALID:

    # Re-read the summary file written by the validation cell.
    with open('valid_results.json', encoding='UTF-8') as fp:
        valid_results = json.load(fp)

# Last expression: rendered as the cell output.
valid_results

In [None]:
# Echo the prediction checkpoint path again, just before the prediction cell.
CKPT_PRED_PATH

In [None]:
# Report how long the notebook has been running before prediction starts.
pred_start = datetime.datetime.now()
elapsed = (pred_start - kernel_start).total_seconds()
print(f'time before prediction: {elapsed} seconds')

if PRED:

    ckpt_no = 60

    # Restore the model weights only (no optimizer) from the prediction checkpoint path.
    ckpt_manager, loaded_ckpt_no = load_ckpt(
        predictor=predictor,
        optimizer=None,
        ckpt_path=CKPT_PRED_PATH,
        ckpt_no=ckpt_no,
    )

    # A submission run must have loaded a real (positive-numbered) checkpoint.
    assert not SUBMISSION or loaded_ckpt_no > 0

    pred_iter_manager_class = Test_Iter_Manager

    # Test-time prediction passes the plain training-end history files and no
    # validation info.
    pred_manager = Pred_Manager.create(
        config=config,
        train_config=train_config,
        pred_iter_manager_class=pred_iter_manager_class,
        train_dt_path_or_obj=train_dt,
        user_id_to_row_id_train_path=user_id_to_row_id_train_path,
        unique_question_id_train_path=unique_question_id_train_path,
        unique_lecture_id_train_path=unique_lecture_id_train_path,
        question_history_at_training_end_path=question_history_at_training_end_path,
        single_question_history_at_training_end_optimized_path=single_question_history_at_training_end_optimized_path,
        valid_info_path=None,
        max_train_buffer_size=30000,
        probe=PROBE,
        debug=False,
    )

    # max_steps=None: no step cap is passed to run_pred.
    run_pred(pred_manager, predictor, max_steps=None)

In [None]:
# List the working directory after the prediction cell.
!ls -l