# SAINT
* With this notebook, we would like to send our special thanks to @Yih-Dar SHIEH for his/her dedicated explanations. We studied his/her notebook (https://www.kaggle.com/yihdarshieh/tpu-track-knowledge-states-of-1m-students) a lot and gained many insights. Please upvote his/her notebook first;
* Besides, the model is mostly copied from https://github.com/seewoo5/KT with some personal adjustments in the design of the masks. The masking scheme is also inspired by @Yih-Dar SHIEH's notebook;
* Add content_difficulty and user_correctness: content_difficulty is fed to the encoder, while user_correctness is fed to the decoder. Both are computed before splitting the data into training and validation sets;
* Continue training the model, resetting the initial learning rate to 1e-5; the model was trained twice before. The first stage was trained in https://www.kaggle.com/shinomoriaoshi/riiid-saint-randomization-inference/output?scriptVersionId=51035364, and the last two stages were trained on the QBlocks platform. The learning-rate schedule is: the first 30 epochs with an initial learning rate of 1e-3, the next 30 epochs with an initial learning rate of 1e-5, and the final 25 epochs with an initial learning rate of 1e-7. All three training stages use the Noam scheduler. The best CV after the training process is 0.7825;
* There are a lot of missing parts due to the limited time, one of them is lag time, which I believe could boost the result quite a lot (according to what I observed from the thread of SAINT benchmark);
* Because we know the SAINT model has much more potential than what we got here with the LB score, the main reason we share our work is that people can read and discuss what we are missing such that they can boost the model performance, we welcome all discussions :D.

* First setting

In [None]:
# One-off environment setup (package installs, Kaggle credentials, dataset
# downloads). The commands are wrapped in a raw string literal, so running
# this cell executes nothing; paste them into a cell without the quotes to
# actually run them.
# NOTE(review): the chmod line hard-codes /home/qblocks while the cp line
# uses ~ — confirm they point to the same directory on your machine.
r'''!pip install sklearn
!pip install datatable
!pip install seaborn
!pip install kaggle
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 /home/qblocks/.kaggle/kaggle.json
!kaggle datasets download -d shinomoriaoshi/riiid-chunking
!kaggle competitions download -c riiid-test-answer-prediction
!kaggle datasets download -d yihdarshieh/r3id-info-public
!unzip riiid-chunking.zip
!unzip riiid-test-answer-prediction.zip
!unzip r3id-info-public.zip'''

In [None]:
import gc, sys, os
import random, math
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import datatable as dt
import multiprocessing

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import seaborn as sns
sns.set()
DEFAULT_FIG_WIDTH = 20
sns.set_context("paper", font_scale = 1.2)

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Report the interpreter and core library versions used for this run
print('Python     : ' + sys.version.split('\n')[0])
print('Numpy      : ' + np.__version__)
print('Pandas     : ' + pd.__version__)
print('PyTorch    : ' + torch.__version__)

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # if IS_TPU == False else xm.xla_device()
print('Running on device: {}'.format(DEVICE))

In [None]:
def seed_everything(s):
    """Seed every random number generator used by the notebook with ``s``.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED`` and switches cuDNN to deterministic, non-benchmark mode.
    """
    os.environ['PYTHONHASHSEED'] = str(s)

    # Python, NumPy, PyTorch CPU and the current CUDA device
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(s)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Cover every visible GPU, not only the current one
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)

In [None]:
# --- Paths and run configuration ---
HOME =  "./"
DATA_HOME = r'../input/riiid-test-answer-prediction'
# NOTE(review): the name says "SAKT" although this notebook builds a SAINT model — confirm it is intended
MODEL_NAME = "SAKT-v1"
MODEL_PATH = HOME + MODEL_NAME
STAGE = "stage1"
MODEL_BEST = 'model_best.pt'
FOLD = 1
MAX_SEQ = 256
    
# --- Column names of the training data ---
CONTENT_TYPE_ID = "content_type_id"
CONTENT_ID = "content_id"
TARGET = "answered_correctly"
USER_ID = "user_id"
TASK_CONTAINER_ID = "task_container_id"
TIMESTAMP = "timestamp"
PART = "part"
ELAPSE = "prior_question_elapsed_time"
# Derived feature columns (not raw columns of train.csv as loaded in this notebook)
DIFF = 'content_difficulty'
CORR = "user_correctness"

## Load data

In [None]:
#%%time
# When True, the per-user sequences are read from pre-built pickle chunks in
# a later cell instead of being rebuilt from train.csv in this cell.
USE_PICKLE = True

dtype = {
    USER_ID: 'int32',
    CONTENT_ID: 'int16',
    CONTENT_TYPE_ID: 'bool',
    TARGET:'int8',
}
train_df = dt.fread(os.path.join(DATA_HOME, 'train.csv'), columns = set(dtype.keys())).to_pandas()
# Keep only the rows whose content_type_id is False
train_df = train_df[train_df[CONTENT_TYPE_ID] == False].reset_index(drop = True)
train_df.head()

# Overall user correctness
user_agg = train_df.groupby(USER_ID)[TARGET].agg(['sum', 'count']).astype('int16')
# Overall difficulty of questions
content_agg = train_df.groupby(CONTENT_ID)[TARGET].agg(['sum', 'count'])

# The frame must stay alive until the rebuild branch below has finished:
# deleting it before the branch made every access in it raise NameError.
if not USE_PICKLE:
    # NOTE(review): only the four columns listed in `dtype` are read above,
    # but this branch also uses TASK_CONTAINER_ID, PART, ELAPSE and CORR.
    # Those columns have to be loaded/derived before this branch can run.

    # How hard are questions in each content ID?
    train_df[DIFF] = train_df[CONTENT_ID].map(content_agg['sum'] / content_agg['count'])

    train_df.fillna(0, inplace = True)

    skills = train_df[CONTENT_ID].unique()
    n_skill = 13523 # len(skills)
    print("Number of skills", n_skill)

    elapse = train_df[ELAPSE].unique()
    n_elapse = 300 # max(elapse)
    print("Number of skills", n_elapse)

    # Users in order of first appearance; computed once instead of per loop iteration
    user_order = train_df[USER_ID].unique()
    n_users = len(user_order)

    # Half-open [start, end) ranges over `user_order`. Stepping by batch_size
    # avoids the empty trailing batch (and the resulting IndexError) when the
    # user count is an exact multiple of batch_size.
    batch_size = 100_000
    batches = [[start, min(start + batch_size, n_users)]
               for start in range(0, n_users, batch_size)]

    # First and last user id of every batch
    picked_user = [[user_order[batch[0]], user_order[batch[1] - 1]] for batch in batches]

    # Row index of the first row of the first user and the last row of the last user
    indx = []
    for batch in picked_user:
        idx_s = train_df.index[train_df[USER_ID] == batch[0]][0]
        idx_e = train_df.index[train_df[USER_ID] == batch[1]][-1]
        indx.append([idx_s, idx_e])

    # One pickle per batch: for every user a tuple of aligned per-interaction arrays
    for i, batch in enumerate(indx):
        sub_df = train_df.loc[batch[0]:batch[1]][[USER_ID, CONTENT_ID,
                                                  TASK_CONTAINER_ID,
                                                  PART, DIFF, ELAPSE,
                                                  CORR, TARGET]].groupby(USER_ID).apply(lambda r: (r[CONTENT_ID].values,
                                                                                                   r[TASK_CONTAINER_ID].values,
                                                                                                   r[PART].values,
                                                                                                   r[DIFF].values,
                                                                                                   r[ELAPSE].values,
                                                                                                   r[CORR].values,
                                                                                                   r[TARGET].values))
        sub_df.to_pickle(f'train_group_{i}.pkl')
        del sub_df

# Release the raw frame in both modes
del train_df; gc.collect()

In [None]:
# Read the four pickled chunks and concatenate them into a single object
train_group = pd.concat([
    pd.read_pickle(os.path.join(r'../input/riiid-chunking', f'train_group_{i}.pkl'))
    for i in range(4)
])

In [None]:
# Load two columns of questions.csv by position (0 and 3); the dtype mapping
# below names them question_id and part.
# NOTE(review): this path duplicates DATA_HOME — keep the two in sync.
questions_df = pd.read_csv(
    '../input/riiid-test-answer-prediction/questions.csv', 
    usecols = [0, 3], 
    dtype = {'question_id': 'int16', 'part': 'int8'}
)

# Training/Validation split strategy with r3id_info_public

* Define some special tokens

In [None]:
# Raw marker values placed inside sequences; they are negative, so they
# cannot collide with the response values 0 and 1
PAD_TOKEN = -1
START_TOKEN = -2
MASK_TOKEN = -3

special_token = [PAD_TOKEN, START_TOKEN, MASK_TOKEN]

# Integer ids the markers and responses are translated to
# (see old_response_to_index)
PAD_ID = 0
START_ID = 1
MASK_ID = 2
RESPONSE_FALSE_ID = 3
RESPONSE_TRUE_ID = 4

* Mapping for old responses

In [None]:
# All values a previous-response slot can hold: the three special tokens
# followed by the two real responses (0 and 1)
responses_total_tokens = np.concatenate((special_token, [0, 1]))
# Map every token value to its position in the list above (0..4).
# Despite the name, the keys are response tokens, not question ids.
questions_encoding_mapper = dict(zip(responses_total_tokens, np.arange(len(responses_total_tokens))))

In [None]:
# Import the pre-defined indexes
# The loaded dict is keyed by str(user id); each value is used below as the
# position at which that user's sequences are cut into train / validation.
import json
with open(os.path.join(r'../input/r3id-info-public', 'train_valid_split_indices_fold_1.json')) as json_file:
    splitting_index = json.load(json_file)

* Add the START_TOKEN

In [None]:
def add_start_token(x):
    """Return ``x`` with one extra sequence appended: its last sequence shifted right.

    The new sequence starts with ``START_TOKEN`` and is followed by all but
    the final element of ``x[-1]``, so it has the same length as ``x[-1]``.
    """
    last_sequence = x[-1]
    shifted = np.append([START_TOKEN], last_sequence[:-1])
    return (*x, shifted)
# Append the shifted sequence to every user's tuple, then display the result
train_group = train_group.apply(add_start_token)
train_group

In [None]:
# Extracting the train_df
def train_extraction(user):
    """Return the training part of ``user``'s sequences.

    Users missing from ``splitting_index`` keep their full tuple; users whose
    split index is 0 yield NaN (nothing to train on); everyone else gets each
    sequence truncated to the elements before the split index.
    """
    key = str(user)
    if key not in splitting_index:
        return train_group[user]

    cut = splitting_index[key]
    if cut == 0:
        return np.nan
    return tuple(seq[:cut] for seq in train_group[user])
        
# Apply
sub_train_ = train_group.copy()
for i in tqdm(train_group.index):
    sub_train_[i] = train_extraction(i)
sub_train_.dropna(inplace = True)

In [None]:
sub_train_[115]

In [None]:
# Extracting the valid_df
def valid_extraction(user):
    """Return the validation part of ``user``'s sequences.

    Users missing from ``splitting_index`` yield NaN (nothing to validate on);
    users whose split index is 0 keep their full tuple; everyone else gets each
    sequence reduced to the elements from the split index onwards.
    """
    key = str(user)
    if key not in splitting_index:
        return np.nan

    cut = splitting_index[key]
    if cut == 0:
        return train_group[user]
    return tuple(seq[cut:] for seq in train_group[user])

# Apply
valid_df = train_group.copy()
for i in tqdm(train_group.index):
    valid_df[i] = valid_extraction(i)
valid_df.dropna(inplace = True)

In [None]:
def valid_2_sequence(user, valid_seqs, idx, train_seqs = None):
    """Build the sequence that ends with the ``idx``-th block of a validation sequence.

    A block is a run of consecutive equal values in ``valid_seqs[1]``.
    Every sequence in ``valid_seqs`` is cut after the last element of block
    ``idx`` (0-based) and, when ``train_seqs`` is given, prefixed with the
    matching training sequence.

    Returns ``(sequences, old_res_mask)`` where ``old_res_mask`` is 1 for the
    training prefix and 0 for the validation part (all zeros without a prefix).
    ``user`` is accepted for interface symmetry and not used.
    """
    containers = valid_seqs[1]
    # 1-based block number of every validation step
    changes = [prev != cur for prev, cur in zip(containers, containers[1:])]
    block_ids = np.cumsum([True] + changes)

    # Exclusive end: one past the last step belonging to block `idx`
    last_step = np.where(block_ids == (idx + 1))[0][-1]
    idx_end = last_step + 1
    n_valid = len(valid_seqs[0][:idx_end])

    if train_seqs is None:
        sequences = tuple(np.array(valid_seq[:idx_end]) for valid_seq in valid_seqs)
        return sequences, np.zeros(n_valid)

    sequences = tuple(np.concatenate((train_seq, valid_seq[:idx_end]))
                      for train_seq, valid_seq in zip(train_seqs, valid_seqs))
    old_res_mask = np.array([1] * len(train_seqs[0]) + [0] * n_valid)
    return sequences, old_res_mask

* Check

In [None]:
user_val_seq, old_res_mask = valid_2_sequence(115, valid_df[115], 19, train_seqs = sub_train_[115])

In [None]:
user_val_seq

In [None]:
old_res_mask

In [None]:
def num_subsequence_each_user(user, valid_seqs):
    """Count the blocks (runs of consecutive equal values) in ``valid_seqs[1]``.

    ``user`` is not used. An empty sequence counts as one block.
    """
    containers = valid_seqs[1]
    # A new block starts at the first step and wherever the value changes
    starts_block = [True] + [prev != cur for prev, cur in zip(containers, containers[1:])]
    running_block_id = np.cumsum(starts_block)
    return running_block_id[-1]

* Check

In [None]:
num_subsequence_each_user(115, valid_df[115])

In [None]:
def data_len_calculator(valid_df):
    """Return the running total of per-user block counts, in index order of ``valid_df``."""
    blocks_per_user = []
    for user in valid_df.index:
        blocks_per_user.append(num_subsequence_each_user(user, valid_df[user]))
    return np.cumsum(blocks_per_user)

* Check

In [None]:
data_len_calculator(valid_df)

# Design some masking functions

In [None]:
def position(task, pad_include = True, bundle_ignore = True):
    """Compute position ids for a (possibly padded) sequence.

    task: sequence of values; with ``pad_include`` it must be a NumPy array
        because padding is located by comparing against ``PAD_TOKEN``.
    pad_include: when True, ``PAD_TOKEN`` entries are removed and replaced
        by that many leading position-0 entries.
    bundle_ignore: when True, non-pad steps are numbered 1, 2, 3, ...;
        when False, consecutive equal values share one position id.
    """
    if pad_include:
        num_padded = len(task[task == PAD_TOKEN])
        task = task[task != PAD_TOKEN]
    else:
        num_padded = 0
    leading = [0] * num_padded

    if bundle_ignore:
        # One position per step
        return np.array(leading + list(range(1, len(task) + 1)))

    # Position ids
    # The id only advances where the value differs from the previous step
    advances = [prev != cur for prev, cur in zip(task, task[1:])]
    return np.cumsum(leading + [True] + advances)

* Prediction mask

In [None]:
def prediction_mask(pos):
    """Mark the trailing run of equal position ids.

    Returns an array with 1 for every step in the final run of equal values
    of ``pos`` (the steps to predict) and 0 for all earlier steps.
    """
    backwards = pos[::-1]
    # Walking from the end: True wherever the id differs from the step after it
    differs = [False] + [earlier != later for earlier, later in zip(backwards[1:], backwards)]
    # Becomes 1 from the first change onwards (still in reversed order)
    left_last_run = np.sign(np.cumsum(differs))
    return 1 - left_last_run[::-1]

* Check

In [None]:
# Sanity check on an unpadded sequence, using element 1 of the sample
# tuple, with and without grouping equal consecutive values
task = user_val_seq[1]
print(f'Position ids: {position(task, pad_include = False, bundle_ignore = False)}')
print(f'Prediction_mask: {prediction_mask(position(task, pad_include = False, bundle_ignore = False))}')

print(f'Position ids with bundle ignore: {position(task, pad_include = False)}')
print(f'Prediction_mask with bundle ignore: {prediction_mask(position(task, pad_include = False))}')

In [None]:
# Same check with five padding entries in front; the literal -1 equals PAD_TOKEN
task = np.concatenate(([-1] * 5, user_val_seq[1]))
print(f'Position ids: {position(task, pad_include = True, bundle_ignore = False)}')
print(f'Prediction_mask: {prediction_mask(position(task, pad_include = True, bundle_ignore = False))}')

print(f'Position ids with bundle ignore: {position(task, pad_include = True)}')
print(f'Prediction_mask with bundle ignore: {prediction_mask(position(task, pad_include = True))}')

* Old responses to index

In [None]:
def old_response_to_index(old_res, old_res_mask, use_mask = True):
    """Translate a sequence of previous responses into embedding ids.

    ``PAD_TOKEN``, ``START_TOKEN``, 0 and 1 become ``PAD_ID``, ``START_ID``,
    ``RESPONSE_FALSE_ID`` and ``RESPONSE_TRUE_ID``; any other value becomes 0.
    With ``use_mask`` every position where ``old_res_mask`` is 0 is replaced
    by ``MASK_ID``; without it ``old_res_mask`` is ignored.
    """
    masked_seq = old_res * old_res_mask if use_mask else old_res

    # Each token contributes its id at the positions where it occurs
    token_ids = (
        (PAD_TOKEN, PAD_ID),
        (START_TOKEN, START_ID),
        (0, RESPONSE_FALSE_ID),
        (1, RESPONSE_TRUE_ID),
    )
    encoded = sum((masked_seq == token).astype(int) * token_id
                  for token, token_id in token_ids)

    if not use_mask:
        return encoded
    # Hidden positions were zeroed above and would read as "false response";
    # clear them and write MASK_ID instead
    return encoded * old_res_mask + (1 - old_res_mask) * MASK_ID

* Check: validation

In [None]:
sample_seq = user_val_seq[-1]
old_response_to_index(sample_seq, old_res_mask)

In [None]:
user_val_seq[-1]

In [None]:
# Check with train dataset
# An all-ones mask hides nothing, so no position should come out as MASK_ID
sample_seq = sub_train_[115][-1]
old_res_mask = np.ones(len(sub_train_[115][-1]))

old_response_to_index(sample_seq, old_res_mask)

# Datasets

* Training Dataset

In [None]:
class RIIID_Train_Dataset(Dataset):
    """Training dataset: one item per user, cropped or left-padded to ``max_seq`` steps.

    group: mapping-like object with an ``index`` of user ids whose values are
        8-tuples of aligned sequences (question, task container, part,
        difficulty, elapsed time, user correctness, target, previous response).
    max_seq: fixed length of every returned sequence.

    Sequences longer than ``max_seq`` are cropped to a randomly positioned
    window; shorter ones are padded on the left.
    NOTE(review): the window is drawn from NumPy's global RNG; with
    ``num_workers > 0`` check that DataLoader workers do not share the same
    NumPy seed (see the PyTorch notes on randomness in multi-process loading).
    """
    def __init__(self, group, max_seq = 100):
        self.group = group
        self.max_seq = max_seq

        # Discard users with fewer than two interactions
        self.user_ids = [user_id for user_id in group.index
                         if len(group[user_id][0]) >= 2]

    def __len__(self):
        return len(self.user_ids)

    @staticmethod
    def _sample_window_start(seq_len, max_seq):
        """Draw a window start uniformly from 0..seq_len - max_seq (both inclusive)."""
        # randint's upper bound is exclusive; the +1 makes the final window
        # (the most recent max_seq interactions) reachable
        return np.random.randint(seq_len - max_seq + 1)

    def _fit(self, values, seq_len, start, fill, dtype):
        """Return ``values`` as exactly ``max_seq`` steps: a window from ``start`` or left-padded with ``fill``."""
        if seq_len >= self.max_seq:
            return values[start:(start + self.max_seq)]
        padded = np.full(self.max_seq, fill, dtype = dtype)
        padded[-seq_len:] = values
        return padded

    def __getitem__(self, idx):
        # Pick user
        user_id = self.user_ids[idx]

        # Unpack sequences
        ques_, task_, part_, diff_, elapse_, corr_, target_, old_res_ = self.group[user_id]

        # In training every previous response is visible and every step is predicted
        old_res_mask_ = np.ones(len(old_res_))
        pred_mask_ = old_res_mask_
        old_res_ = old_response_to_index(old_res_, old_res_mask_)

        seq_len = len(ques_)

        # For the training set, if the sequence is longer than max_seq,
        # randomly sample a sub-sequence; otherwise the window starts at 0
        if seq_len > self.max_seq:
            start = self._sample_window_start(seq_len, self.max_seq)
        else:
            start = 0

        # Questions and task containers are padded with -1 (shifted to 0 below),
        # every other sequence with 0
        ques = self._fit(ques_, seq_len, start, -1, int)
        task = self._fit(task_, seq_len, start, -1, int)
        part = self._fit(part_, seq_len, start, 0, int)
        diff = self._fit(diff_, seq_len, start, 0, float)
        elapse = self._fit(elapse_, seq_len, start, 0, float)
        corr = self._fit(corr_, seq_len, start, 0, float)
        target = self._fit(target_, seq_len, start, 0, int)
        old_res = self._fit(old_res_, seq_len, start, 0, int)
        pred_mask = self._fit(pred_mask_, seq_len, start, 0, int)

        # Block-level positions (equal consecutive task containers share an id)
        task_pos = position(task, bundle_ignore = False)

        # Step-level positions; the decoder sees them shifted right by one
        e_pos = position(task)
        d_pos = np.concatenate(([0], e_pos[:-1]))

        # Return the sequences
        return {
            'user': user_id,
            'content_id': torch.tensor(ques, dtype = torch.long) + 1,
            'task_container_id': torch.tensor(task, dtype = torch.long) + 1,
            'part_id': torch.tensor(part, dtype = torch.long),
            'diff_id': torch.tensor(diff, dtype = torch.float32),
            'prior_elapsed_time_id': torch.tensor(elapse, dtype = torch.float32),
            'user_correctness_id': torch.tensor(corr, dtype = torch.float32),
            'old_response_id': torch.tensor(old_res, dtype = torch.long),
            'encoder_position_id': torch.tensor(e_pos, dtype = torch.long),
            'decoder_position_id': torch.tensor(d_pos, dtype = torch.long),
            'task_position_id': torch.tensor(task_pos, dtype = torch.long),
            'prediction_mask': torch.tensor(pred_mask, dtype = torch.long),
            'target':  torch.tensor(target, dtype = torch.float32)
        }

* Validation dataset

In [None]:
class RIIID_Valid_Dataset(Dataset):
    """Validation dataset: one item per block of every user's validation sequence.

    valid: per-user validation tuples, indexed by user id.
    train: per-user training tuples; when a user is present, the training
        sequences are prepended as history.
    max_seq: fixed length of every returned sequence (the most recent
        ``max_seq`` steps are kept, shorter sequences are left-padded).
    """
    def __init__(self, valid, train, max_seq = 100):
        self.valid_df = valid
        self.train_df = train
        self.max_seq = max_seq
        self.users = list(self.valid_df.index)
        # len_valid[k] is the flat index of user k's first item; the last entry is the total
        self.len_valid = np.array([0] + data_len_calculator(self.valid_df).tolist())

    def __len__(self):
        return int(self.len_valid[-1])

    @staticmethod
    def _locate(len_valid, idx):
        """Map a flat item index to ``(user position, block index within that user)``.

        Raises IndexError when ``idx`` is outside ``0 .. len_valid[-1] - 1``.
        """
        if idx < 0 or idx >= len_valid[-1]:
            raise IndexError(f'index {idx} is out of range')
        # E.g. if the first user has 20 blocks, indices 0..19 all map to user position 0.
        # Binary search replaces two full boolean-mask scans per item.
        user_idx = int(np.searchsorted(len_valid, idx, side = 'right')) - 1
        return user_idx, int(idx - len_valid[user_idx])

    def _fit(self, values, seq_len, fill, dtype):
        """Return the last ``max_seq`` steps of ``values``, left-padded with ``fill`` if shorter."""
        if seq_len >= self.max_seq:
            return values[-self.max_seq:]
        padded = np.full(self.max_seq, fill, dtype = dtype)
        padded[-seq_len:] = values
        return padded

    def __getitem__(self, idx):
        user_idx, idx_in_user_batch = self._locate(self.len_valid, idx)
        user = self.users[user_idx]
        user_seq_valid = self.valid_df[user]
        if user in self.train_df.index:
            user_seq_train = self.train_df[user]
        else:
            user_seq_train = None

        # Unpack sequences
        (ques_, task_, part_, diff_, elapse_, corr_, target_, old_res_), old_res_mask_ = valid_2_sequence(user, user_seq_valid, idx_in_user_batch,
                                                                                                          train_seqs = user_seq_train)
        # Old response to index (the mask is not applied here)
        old_res_ = old_response_to_index(old_res_, old_res_mask_, use_mask = False)

        seq_len = len(ques_)

        # Questions and task containers are padded with -1 (shifted to 0 below),
        # every other sequence with 0
        ques = self._fit(ques_, seq_len, -1, int)
        task = self._fit(task_, seq_len, -1, int)
        part = self._fit(part_, seq_len, 0, int)
        diff = self._fit(diff_, seq_len, 0, float)
        elapse = self._fit(elapse_, seq_len, 0, float)
        corr = self._fit(corr_, seq_len, 0, float)
        target = self._fit(target_, seq_len, 0, int)
        old_res = self._fit(old_res_, seq_len, 0, int)

        # Prediction mask: only the steps of the final block are scored
        task_pos = position(task, pad_include = True, bundle_ignore = False)
        pred_mask = prediction_mask(task_pos)
        #pred_mask = (old_res == MASK_ID).astype(int)

        # Step-level positions; the decoder sees them shifted right by one
        e_pos = position(task)
        d_pos = np.concatenate(([0], e_pos[:-1]))

        return {
            'user': user,
            'content_id': torch.tensor(ques, dtype = torch.long) + 1,
            'task_container_id': torch.tensor(task, dtype = torch.long) + 1,
            'part_id': torch.tensor(part, dtype = torch.long),
            'diff_id': torch.tensor(diff, dtype = torch.float32),
            'prior_elapsed_time_id': torch.tensor(elapse, dtype = torch.float32),
            'user_correctness_id': torch.tensor(corr, dtype = torch.float32),
            'old_response_id': torch.tensor(old_res, dtype = torch.long),
            'encoder_position_id': torch.tensor(e_pos, dtype = torch.long),
            'decoder_position_id': torch.tensor(d_pos, dtype = torch.long),
            'task_position_id': torch.tensor(task_pos, dtype = torch.long),
            'prediction_mask': torch.tensor(pred_mask, dtype = torch.long),
            'target':  torch.tensor(target, dtype = torch.float32)
        }

* Check

In [None]:
# Smoke test of the training pipeline with a tiny window (max_seq = 10) and batch size 5
train_dataset = RIIID_Train_Dataset(sub_train_, 10)
train_dataloader = DataLoader(train_dataset, batch_size = 5, shuffle = False, num_workers = 4)

In [None]:
next(iter(train_dataloader))

In [None]:
# Smoke test of the validation pipeline; sub_train_ supplies each user's history
valid_dataset = RIIID_Valid_Dataset(valid_df, sub_train_, max_seq = 10)
valid_dataloader = DataLoader(valid_dataset, batch_size = 5, shuffle = False, num_workers = 0)

In [None]:
next(iter(valid_dataloader))

# Model

* Model Utils

In [None]:
import copy

class ScheduledOptim():
    """Wrap an optimizer and set its learning rate before every step.

    The rate is ``d_model ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)``,
    where ``step`` counts calls to ``step_and_update_lr`` starting at 1.
    """
    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        """Refresh the learning rate, then step the wrapped optimizer."""
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        """Return the smaller of the decay term and the linear warm-up term."""
        decay_term = np.power(self.n_current_steps, -0.5)
        warmup_term = np.power(self.n_warmup_steps, -1.5) * self.n_current_steps
        return np.min([decay_term, warmup_term])

    def _update_learning_rate(self):
        """Advance the step counter and write the new rate into every param group."""
        self.n_current_steps += 1
        new_lr = self.init_lr * self._get_lr_scale()

        for group in self._optimizer.param_groups:
            group['lr'] = new_lr


class NoamOpt:
    """Optimizer wrapper that applies a warm-up / inverse-square-root learning rate.

    The rate at step ``s`` is
    ``factor * model_size ** -0.5 * min(s ** -0.5, s * warmup ** -1.5)``.

    model_size: model width used to scale the rate.
    factor: constant multiplier of the rate.
    warmup: number of steps over which the rate grows linearly.
    optimizer: the wrapped optimizer whose ``lr`` is overwritten on every step.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self.optimizer.zero_grad()

    def step(self):
        """Advance the step counter, set the new rate and step the wrapped optimizer."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate for ``step`` (default: the current step).

        Returns 0.0 for a non-positive step: the warm-up term is 0 there, and
        evaluating ``0 ** -0.5`` would raise ZeroDivisionError (which happened
        when ``rate()`` was called before the first ``step()``).
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * \
               (self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))

class NoamOptimizer:
    """Adam optimizer driven by a ``NoamOpt`` schedule (factor fixed to 1).

    ``lr`` is only Adam's initial value; ``NoamOpt.step`` overwrites it on
    every update.
    """
    def __init__(self, model, lr, model_size, warmup):
        adam = torch.optim.Adam(model.parameters(), lr=lr)
        self._adam = adam
        self._opt = NoamOpt(
            model_size=model_size, factor=1, warmup=warmup, optimizer=adam)

    def step(self, loss):
        """Run one full update: clear gradients, backpropagate ``loss``, step."""
        scheduler = self._opt
        scheduler.zero_grad()
        loss.backward()
        scheduler.step()

# For Transformer-based models
def get_pad_mask(seq, pad_idx):
    """Return a boolean mask that is True where ``seq`` differs from ``pad_idx``.

    A singleton axis is inserted before the last one, so a ``(B, L)`` input
    gives a ``(B, 1, L)`` mask.
    """
    is_real = seq != pad_idx
    return is_real.unsqueeze(-2)

def get_subsequent_mask_3d(seq, only_before = True):
    """Build a ``(B, L, L)`` mask from position ids ``seq`` of shape ``(B, L)``.

    Entry ``[b, i, j]`` is True when ``seq[b, i] >= seq[b, j] + int(only_before)``:
    with ``only_before`` row ``i`` sees only columns with a strictly smaller
    id, without it columns with an equal id are visible too.
    """
    batch_size, seq_len = seq.shape
    target_shape = (batch_size, seq_len, seq_len)
    row_ids = seq.unsqueeze(-1).expand(*target_shape)
    col_ids = seq.unsqueeze(1).expand(*target_shape)
    return row_ids >= col_ids + int(only_before)

def get_subsequent_mask(seq):
    """Return a ``(1, L, L)`` lower-triangular boolean mask for a ``(B, L)`` input.

    Row ``i`` is True for columns ``0..i``, i.e. later steps are hidden.
    """
    _, length = seq.size()
    all_ones = torch.ones((1, length, length), device=seq.device)
    strictly_upper = torch.triu(all_ones, diagonal=1)
    return (1 - strictly_upper).bool()

def get_masks(seq, pad_idx, only_before = True):
    """Return ``(encoder_mask, decoder_mask, encoder_decoder_mask)`` for position ids ``seq``.

    All masks hide columns equal to ``pad_idx``. The encoder mask additionally
    applies the id-based rule of ``get_subsequent_mask_3d``; the decoder mask
    applies the step-based lower-triangular rule. The encoder-decoder mask is
    the encoder mask itself.
    """
    not_pad = get_pad_mask(seq, pad_idx)
    encoder_mask = not_pad & get_subsequent_mask_3d(seq, only_before = only_before)
    decoder_mask = not_pad & get_subsequent_mask(seq)
    return encoder_mask, decoder_mask, encoder_mask

def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)

* Check

In [None]:
# Two sample position-id rows: leading zeros play the role of padding,
# repeated ids belong to the same group
pos = torch.tensor([[0,0,0,1,2,2,2,3,4,5],
                    [0,0,0,0,1,2,3,3,4,4]])

get_subsequent_mask_3d(pos)

In [None]:
get_pad_mask(pos, 0)

In [None]:
get_masks(pos, pad_idx = 0)

* Building-block layers

In [None]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Scores are ``query @ key^T / sqrt(d_k)`` with ``d_k`` the last dimension
    of ``query``. Positions where ``mask == 0`` get a score of -1e9 before the
    softmax over the last axis. ``dropout``, when given, is applied to the
    attention weights.

    Returns ``(weights @ value, weights)``.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights


class MultiHeadedAttention(nn.Module):
    """Multi-head attention with ``h`` heads over a ``d_model``-wide input.

    ``d_model`` must be divisible by ``h``; each head works on
    ``d_model // h`` features. Four bias-free linear layers are used: the
    first three project query, key and value, the last one mixes the heads.
    """
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # Value and key heads share the same width
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model, bias=False), 4) # Q, K, V, last
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``; the last weights are kept in ``self.attn``."""
        if mask is not None:
            # Extra axis so that one mask broadcasts over all heads
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        def split_heads(projection, x):
            # (batch, steps, d_model) -> (batch, h, steps, d_k)
            return projection(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

        query = split_heads(self.linears[0], query)
        key = split_heads(self.linears[1], key)
        value = split_heads(self.linears[2], value)

        context, self.attn = attention(query, key, value, mask=mask,
                                       dropout=self.dropout)

        # (batch, h, steps, d_k) -> (batch, steps, h * d_k), then the output projection
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](merged)


class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Maps ``d_model`` features to ``d_ff`` and back to ``d_model``.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
    
class SAINTLayer_encoder(nn.Module):
    """
    Single encoder block of SAINT: self-attention followed by a feed-forward
    network, each wrapped in dropout, a residual connection and layer norm.
    The feed-forward network uses ``hidden_dim`` as its inner width.
    """
    def __init__(self, hidden_dim, num_head, dropout):
        super().__init__()
        self._self_attn = MultiHeadedAttention(num_head, hidden_dim, dropout)
        self._ffn = PositionwiseFeedForward(hidden_dim, hidden_dim, dropout)
        self._layernorms = clones(nn.LayerNorm(hidden_dim, eps=1e-6), 2)
        self._dropout = nn.Dropout(dropout)

    def forward(self, src, mask = None):
        """
        src: input embeddings, used as query, key and value
        mask: optional attention mask passed to the self-attention
        """
        # self-attention sub-layer
        attended = self._self_attn(query=src, key=src, value=src, mask=mask)
        src = self._layernorms[0](src + self._dropout(attended))
        # feed-forward sub-layer
        transformed = self._ffn(src)
        return self._layernorms[1](src + self._dropout(transformed))
    
class SAINTLayer_decoder(nn.Module):
    """
    Single decoder block of SAINT: self-attention, attention over the encoder
    output and a feed-forward network, each wrapped in dropout, a residual
    connection and layer norm.
    """
    def __init__(self, hidden_dim, num_head, dropout):
        super().__init__()
        self._self_attn_decoder = MultiHeadedAttention(num_head, hidden_dim, dropout)
        self._self_attn_encoder_decoder = MultiHeadedAttention(num_head, hidden_dim, dropout)
        self._ffn = PositionwiseFeedForward(hidden_dim, hidden_dim, dropout)
        self._layernorms = clones(nn.LayerNorm(hidden_dim, eps=1e-6), 3)
        self._dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory, encoder_decoder_mask = None, decoder_mask = None):
        """
        tgt: decoder input embeddings
        memory: encoder output, used as key and value of the second attention
        encoder_decoder_mask: mask for the attention over ``memory``
        decoder_mask: mask for the decoder self-attention
        """
        norm_self, norm_cross, norm_ffn = self._layernorms

        # decoder self-attention sub-layer
        attended = self._self_attn_decoder(query=tgt, key=tgt, value=tgt, mask=decoder_mask)
        tgt = norm_self(tgt + self._dropout(attended))
        # encoder-decoder attention sub-layer
        crossed = self._self_attn_encoder_decoder(query=tgt, key=memory, value=memory, mask=encoder_decoder_mask)
        tgt = norm_cross(tgt + self._dropout(crossed))
        # feed-forward sub-layer
        transformed = self._ffn(tgt)
        return norm_ffn(tgt + self._dropout(transformed))

* Main model

In [None]:
class SAINT_plus(nn.Module):
    """
    Encoder-decoder transformer used for the predictions of this notebook.

    The encoder input is the sum of the question, task container, part, position
    and content-difficulty embeddings; the decoder input is the sum of the previous
    response, elapsed time, position and user-correctness embeddings.
    """
    def __init__(self, question_num, task_num, max_seq = 100, d_model = 128, nhead = 8, dropout = 0.1, num_layers = 2):
        """
        question_num: number of distinct questions (the embedding gets one extra row)
        task_num: number of distinct task containers (the embedding gets one extra row)
        max_seq: maximum sequence length (the positional embedding gets one extra row)
        d_model: width of every embedding and of the transformer blocks
        nhead: number of attention heads
        dropout: dropout rate handed to every block
        num_layers: number of encoder blocks and of decoder blocks
        """
        super().__init__()
        self.question_num = question_num
        # Previously this line overwrote self.question_num with task_num
        self.task_num = task_num
        self.max_seq = max_seq
        self.d_model = d_model
        self.nhead = nhead
        self.dropout = dropout
        
        # Construct embedding layers
        # For categorical features
        self._positional_embedding = nn.Embedding(self.max_seq + 1, d_model)
        self._question_embedding = nn.Embedding(question_num + 1, d_model)
        self._task_embedding = nn.Embedding(task_num + 1, d_model)
        self._part_embedding = nn.Embedding(9, d_model)
        self._target_embedding = nn.Embedding(len(responses_total_tokens), d_model)
        
        # For continuous features: a scalar is projected to d_model by a linear layer
        self._diff_embedding = nn.Linear(1, d_model)
        self._elapse_embedding = nn.Linear(1, d_model)
        self._corr_embedding = nn.Linear(1, d_model)
        
        # Blocks
        self.encoder_layers = clones(SAINTLayer_encoder(d_model, nhead, dropout), num_layers)
        self.decoder_layers = clones(SAINTLayer_decoder(d_model, nhead, dropout), num_layers)

        # Prediction layer: one value per position
        self._prediction = nn.Linear(d_model, 1)
        
    def forward(self, question, task, part, diff, elapse, corr, target, e_pos, d_pos, task_pos):
        """
        question, task, part, target: index tensors looked up in the embeddings
        diff, elapse, corr: continuous features, one scalar per position
        e_pos, d_pos: position indices of the encoder and of the decoder input
        task_pos: positions handed to get_masks to build the attention masks

        Returns one value per position (last axis squeezed); the training and
        inference code of this file applies a sigmoid to it.
        """
        device = question.device
        # Embed the continuous features first (a trailing axis of size 1 is added
        # so that the linear layers can project each scalar)
        diff = self._diff_embedding(diff.unsqueeze(-1))
        elapse = self._elapse_embedding(elapse.unsqueeze(-1))
        corr = self._corr_embedding(corr.unsqueeze(-1))
        
        # Masks (get_masks is defined elsewhere in the file)
        encoder_mask, decoder_mask, encoder_decoder_mask = get_masks(task_pos, 0, only_before = False)
        encoder_mask = encoder_mask.to(device)
        decoder_mask = decoder_mask.to(device)
        encoder_decoder_mask = encoder_decoder_mask.to(device)
        
        # Embed positions (encoder and decoder share one positional table)
        e_pos = self._positional_embedding(e_pos)
        d_pos = self._positional_embedding(d_pos)
        # Embed question
        question = self._question_embedding(question)
        # Embed task
        task = self._task_embedding(task)
        # Embed part
        part = self._part_embedding(part)
        # Embed previous response
        target = self._target_embedding(target)
        
        # Aggregate
        # Encoder is the information of questions themselves
        encoder = question + task + part + e_pos + diff
        # Decoder is the information of answers themselves
        decoder = target + elapse + d_pos + corr
        
        # Feed into the transformer
        # Encoder
        for layer in self.encoder_layers:
            encoder = layer(encoder, mask = encoder_mask)
        
        # Decoder: every block attends over the final encoder output
        for layer in self.decoder_layers:
            decoder = layer(decoder, encoder, encoder_decoder_mask = encoder_decoder_mask, 
                            decoder_mask = decoder_mask)
            
        output = self._prediction(decoder)
        return output.squeeze(-1)

# Training

* Training utils

In [None]:
def train_epoch(model, train_iterator, criterion, optim, device = 'cpu'):
    """
    Train the model for one pass over train_iterator.

    model: network called with the ten input tensors of a batch
    train_iterator: iterable of batches (dicts of tensors)
    criterion: loss applied to the selected outputs and labels
    optim: optimizer wrapper whose step() receives the loss
           (presumably it runs the backward pass and the update - confirm against NoamOptimizer)
    device: device the batch tensors are moved to

    Returns (mean batch loss, accuracy, ROC AUC), all measured on the positions
    selected by the prediction mask.
    """
    # Batch keys in the positional order expected by the model
    input_keys = ('content_id', 'task_container_id', 'part_id', 'diff_id',
                  'prior_elapsed_time_id', 'user_correctness_id', 'old_response_id',
                  'encoder_position_id', 'decoder_position_id', 'task_position_id')

    model.train()
    
    train_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    tbar = tqdm(train_iterator)
    for item in tbar:
        inputs = [item[key].to(device) for key in input_keys]
        mask = item['prediction_mask'].to(device).bool()
        label = item['target'].to(device)
        
        with torch.set_grad_enabled(True):
            output = model(*inputs)
            # Keep only the positions that have to be predicted
            output = torch.masked_select(output, mask)
            label = torch.masked_select(label, mask)
            # Forward loss
            loss = criterion(output, label)
            # Optimization
            optim.step(loss)
            
        batch_loss = loss.item()
        train_loss.append(batch_loss)
        
        # Metrics need no gradient: compute the probabilities once on a detached tensor
        prob = torch.sigmoid(output.detach())
        pred = (prob >= 0.5).long()
        
        num_corrects += (pred == label).sum().item()
        num_total += mask.sum().item()

        batch_labels = label.view(-1).cpu().numpy()
        batch_probs = prob.view(-1).cpu().numpy()
        labels.extend(batch_labels)
        outs.extend(batch_probs)

        # The batch AUC only feeds the progress bar; a batch holding a single class
        # must not abort the epoch
        try:
            batch_auc = roc_auc_score(batch_labels, batch_probs)
        except ValueError:
            batch_auc = float('nan')
                           
        tbar.set_description('loss - {:.4f} || auc - {:.4f}'.format(batch_loss, batch_auc))
    
    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(train_loss)

    return loss, acc, auc

In [None]:
def valid_epoch(model, valid_iterator, criterion, device = 'cpu'):
    """
    Evaluate the model on one pass over valid_iterator, without gradients.

    model: network called with the ten input tensors of a batch
    valid_iterator: iterable of batches (dicts of tensors)
    criterion: loss applied to the selected outputs and labels
    device: device the batch tensors are moved to

    Returns (mean batch loss, accuracy, ROC AUC), all measured on the positions
    selected by the prediction mask.
    """
    # Batch keys in the positional order expected by the model
    input_keys = ('content_id', 'task_container_id', 'part_id', 'diff_id',
                  'prior_elapsed_time_id', 'user_correctness_id', 'old_response_id',
                  'encoder_position_id', 'decoder_position_id', 'task_position_id')

    model.eval()

    valid_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []
    
    tbar = tqdm(valid_iterator)
    for item in tbar:
        inputs = [item[key].to(device) for key in input_keys]
        mask = item['prediction_mask'].to(device).bool()
        label = item['target'].to(device)
        
        with torch.no_grad():
            output = model(*inputs)
            # Keep only the positions that have to be predicted
            output = torch.masked_select(output, mask)
            label = torch.masked_select(label, mask)
            # Forward loss
            loss = criterion(output, label)
            # Probabilities are computed once and reused by every metric below
            prob = torch.sigmoid(output)
        
        batch_loss = loss.item()
        valid_loss.append(batch_loss)
        
        pred = (prob >= 0.5).long()
        
        num_corrects += (pred == label).sum().item()
        num_total += mask.sum().item()

        batch_labels = label.view(-1).cpu().numpy()
        batch_probs = prob.view(-1).cpu().numpy()
        labels.extend(batch_labels)
        outs.extend(batch_probs)

        # The batch AUC only feeds the progress bar; a batch holding a single class
        # must not abort the evaluation
        try:
            batch_auc = roc_auc_score(batch_labels, batch_probs)
        except ValueError:
            batch_auc = float('nan')

        tbar.set_description('loss - {:.4f} || auc - {:.4f}'.format(batch_loss, batch_auc))

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(valid_loss)
    
    return loss, acc, auc

In [None]:
class train_config:
    # Plain namespace of settings; it is read through the class, never instantiated

    # Run mode: 'train' runs the training cell, and it also decides which checkpoint is loaded
    MODE = 'inference'

    # Optimisation
    EPOCHS = 30
    LR = 1e-5
    WARM_UP = 4000                  # handed to NoamOptimizer
    BATCH_SIZE = 128
    TRAIN_RATIO = 0.97              # fraction of sequences sampled for each epoch
    WORKERS = multiprocessing.cpu_count()
    CRITERION = nn.BCEWithLogitsLoss().to(DEVICE)
    METRIC_ = 'max'                 # 'max' or 'min': how the best epoch is picked for the plot

    # Model
    N_QUES = 13523
    D_MODEL = 128
    N_HEADS = 8
    N_LAYERS = 4
    DROP = 0.1

    # map_location used when the checkpoint is loaded
    MAP_LOCATION = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else 'cpu'

In [None]:
if train_config.MODE == 'train':
    # Pre-setting
    auc_max = -np.inf       # best validation AUC seen so far
    history = []            # one record per finished epoch
    es = 0                  # number of epochs since the last improvement
    
    # Define model
    model = SAINT_plus(train_config.N_QUES, 10000, max_seq = MAX_SEQ, d_model = train_config.D_MODEL, 
                       nhead = train_config.N_HEADS, dropout = train_config.DROP, num_layers = train_config.N_LAYERS).to(DEVICE)
    # Define optimizer and scheduler
    optimizer = NoamOptimizer(model, train_config.LR, train_config.D_MODEL, train_config.WARM_UP)
    
    # Define validation iterator, first, define the auxiliary validation dataset
    valid_dataset = RIIID_Valid_Dataset(valid_df.sample(frac = 0.5, random_state = 2020, replace = False), sub_train_, max_seq = MAX_SEQ)
    valid_dataloader = DataLoader(valid_dataset, batch_size = train_config.BATCH_SIZE * 4, 
                                  shuffle = False, num_workers = train_config.WORKERS)
    
    # Sampling weights: the length of the first stored sequence, clipped to [5, 500].
    # They do not change between epochs, so they are computed once
    weights = sub_train_.apply(lambda x: max(min(500, len(x[0])), 5))
    
    for epoch in range(1, train_config.EPOCHS + 1):
        # Random sampling (with replacement, seeded by the epoch number)
        train_df_ = sub_train_.sample(frac = train_config.TRAIN_RATIO, weights = weights, replace = True,
                                      random_state = epoch)
        train_dataset = RIIID_Train_Dataset(train_df_.reset_index(drop = True), MAX_SEQ)
        train_dataloader = DataLoader(train_dataset, batch_size = train_config.BATCH_SIZE, 
                                  shuffle = True, num_workers = train_config.WORKERS)
        # Training
        train_loss, train_acc, train_auc = train_epoch(model, train_dataloader, train_config.CRITERION, optimizer, device = DEVICE)
        print("\nEpoch#{}, train_loss - {:.4f} acc - {:.4f} auc - {:.4f}".format(epoch, train_loss, train_acc, train_auc))
        # Validation
        valid_loss, valid_acc, valid_auc = valid_epoch(model, valid_dataloader, train_config.CRITERION, device = DEVICE)
        print("Epoch#{}, valid_loss - {:.4f} acc - {:.4f} auc - {:.4f}".format(epoch, valid_loss, valid_acc, valid_auc))
        
        # The current learning rate
        lr = optimizer._adam.param_groups[0]['lr']
        # Store training history
        history.append({"epoch": epoch, "lr": lr,
                        "train_auc": train_auc, "train_acc": train_acc,
                        "valid_auc": valid_auc, "valid_acc": valid_acc})
        
        # Keep the checkpoint with the best validation AUC
        if valid_auc > auc_max:
            print("Epoch#%s, valid loss %.4f, Metric loss improved from %.4f to %.4f, saving model ..." % (epoch, valid_loss, auc_max, valid_auc))
            auc_max = valid_auc
            torch.save(model.state_dict(), MODEL_BEST)
            es = 0
        else:
            es += 1
        
        # Early stopping
        if es > 20:
            break
            
    # The first epoch is left out of the plot, so at least two records are needed
    # (with a single record the frame would be empty and set_index would fail)
    if len(history) > 1:
        metric = 'auc'
        # Plot training history
        history_pd = pd.DataFrame(history[1:]).set_index("epoch")
        train_history_pd = history_pd[[c for c in history_pd.columns if "train_" in c]]
        valid_history_pd = history_pd[[c for c in history_pd.columns if "valid_" in c]]
        lr_history_pd = history_pd[[c for c in history_pd.columns if "lr" in c]]
        fig, ax = plt.subplots(1,2, figsize = (DEFAULT_FIG_WIDTH, 6))

        # idxmin / idxmax return the index label, i.e. the real epoch number,
        # so the titles report the epoch and not a row position
        train_metric = train_history_pd["train_%s" % metric]
        valid_metric = valid_history_pd["valid_%s" % metric]
        t_epoch = train_metric.idxmin() if train_config.METRIC_ == "min" else train_metric.idxmax()
        v_epoch = valid_metric.idxmin() if train_config.METRIC_ == "min" else valid_metric.idxmax()

        train_history_pd.plot(kind = "line", ax = ax[0], title = "Epoch: %d, Train: %.3f" % (t_epoch, train_metric.loc[t_epoch]))
        lr_history_pd.plot(kind = "line", ax = ax[0], secondary_y = True)
        valid_history_pd.plot(kind = "line", ax = ax[1], title = "Epoch: %d, Valid: %.3f" % (v_epoch, valid_metric.loc[v_epoch]))
        lr_history_pd.plot(kind = "line", ax = ax[1], secondary_y = True)
        plt.savefig("train.png", bbox_inches = 'tight')
        plt.show()

# Inference

* Test dataset

In [None]:
class RIIID_Test_Dataset(Dataset):
    """
    Dataset over one test batch. One item is one "bundle": a run of consecutive
    rows of test_df that belong to the same user. A user that shows up again
    later in the batch starts a new bundle.

    group: per-user history, indexed by user id; each entry unpacks into
           (questions, tasks, parts, difficulties, elapsed times, user correctness,
           targets, old responses)
    test_df: rows to predict
    max_seq: length every returned sequence is cut or left-padded to
    """
    def __init__(self, group, test_df, max_seq = 100):
        self.group = group
        self.test_df = test_df
        self.max_seq = max_seq
        # Bundle number of every row of test_df
        self.pos = self._bundle_positions(test_df[USER_ID].values)

    @staticmethod
    def _bundle_positions(user_ids):
        """Number the runs of consecutive identical user ids, starting at 0."""
        user_ids = np.asarray(user_ids)
        if len(user_ids) == 0:
            # No row means no bundle (the dataset then has length 0)
            return np.zeros(0, dtype = int)
        # A new bundle starts wherever the user differs from the previous row
        changed = user_ids[1:] != user_ids[:-1]
        return np.concatenate(([0], np.cumsum(changed)))

    def _left_pad(self, values, fill, dtype):
        """Keep the last max_seq entries of values and left-pad with fill up to max_seq."""
        padded = np.full(self.max_seq, fill, dtype = dtype)
        tail = np.asarray(values)[-self.max_seq:]
        if len(tail):
            padded[-len(tail):] = tail
        return padded
        
    def __len__(self):
        return len(np.unique(self.pos))
    
    def __getitem__(self, idx):
        """
        Build the model inputs of bundle idx: the stored history of the user (if any)
        followed by the new rows, cut / padded to max_seq. The prediction mask is 1
        on the new rows only.
        """
        # Determine bundle
        user_bundle = self.test_df.iloc[self.pos == idx]
        user_id = user_bundle[USER_ID].unique()[0]
        
        # Extract the new sequences
        new_ques = user_bundle[CONTENT_ID].values
        new_task = user_bundle[TASK_CONTAINER_ID].values
        new_part = user_bundle[PART].values
        new_diff = user_bundle[DIFF].values
        new_elapse = user_bundle[ELAPSE].values
        new_corr = user_bundle[CORR].values
        new_seq_len = len(new_ques)
        
        if user_id in self.group.index:
            # Known user: put the new rows behind the stored history
            ques_, task_, part_, diff_, elapse_, corr_, target_, old_res_ = self.group[user_id]
            ques_seq = np.append(ques_, new_ques)
            task_seq = np.append(task_, new_task)
            part_seq = np.append(part_, new_part)
            diff_seq = np.append(diff_, new_diff)
            elapse_seq = np.append(elapse_, new_elapse)
            corr_seq = np.append(corr_, new_corr)
            
            # For old responses, the last known target is copied to every row of the bundle
            old_res_seq = np.append(old_res_, target_[-1] * np.ones(new_seq_len))
            old_res_seq = old_response_to_index(old_res_seq, np.ones(len(old_res_seq)), use_mask = False)
        else:
            # New user: the sequences consist of the bundle only
            ques_seq, task_seq, part_seq = new_ques, new_task, new_part
            diff_seq, elapse_seq, corr_seq = new_diff, new_elapse, new_corr
            # START_ID for the first row, MASK_ID for every further row of the bundle
            old_res_seq = np.concatenate(([START_ID], np.full(new_seq_len - 1, MASK_ID, dtype = int)))
        
        # Cut / pad to max_seq; questions and tasks are padded with -1 (shifted to 0
        # below), everything else with 0
        ques = self._left_pad(ques_seq, -1, int)
        task = self._left_pad(task_seq, -1, int)
        part = self._left_pad(part_seq, 0, int)
        diff = self._left_pad(diff_seq, 0, float)
        elapse = self._left_pad(elapse_seq, 0, float)
        corr = self._left_pad(corr_seq, 0, float)
        old_res = self._left_pad(old_res_seq, 0, int)
        
        # Only the new rows are predicted
        pred_mask = np.zeros(self.max_seq, dtype = int)
        pred_mask[-new_seq_len:] = np.ones(new_seq_len)
            
        # Position
        task_pos = position(task, pad_include = True, bundle_ignore = False)
        e_pos = position(task)
        # Decoder positions are the encoder positions shifted right by one
        d_pos = np.concatenate(([0], e_pos[:-1]))
        
        return {
            'user': user_id, 
            'content_id': torch.tensor(ques, dtype = torch.long) + 1, 
            'task_container_id': torch.tensor(task, dtype = torch.long) + 1, 
            'part_id': torch.tensor(part, dtype = torch.long), 
            'diff_id': torch.tensor(diff, dtype = torch.float32), 
            'prior_elapsed_time_id': torch.tensor(elapse, dtype = torch.float32), 
            'user_correctness_id': torch.tensor(corr, dtype = torch.float32), 
            'old_response_id': torch.tensor(old_res, dtype = torch.long), 
            'encoder_position_id': torch.tensor(e_pos, dtype = torch.long), 
            'decoder_position_id': torch.tensor(d_pos, dtype = torch.long), 
            'task_position_id': torch.tensor(task_pos, dtype = torch.long), 
            'prediction_mask': torch.tensor(pred_mask, dtype = torch.long)
        }

# Load the model

In [None]:
# Build the network with the same hyper-parameters as the training cell
model = SAINT_plus(train_config.N_QUES, 10000, max_seq = MAX_SEQ, d_model = train_config.D_MODEL, 
                   nhead = train_config.N_HEADS, dropout = train_config.DROP, num_layers = train_config.N_LAYERS).to(DEVICE)

# Pick the checkpoint: the one just trained, or the published weights
if train_config.MODE == 'train':
    resume_path = MODEL_BEST
elif train_config.MODE == 'inference':
    resume_path = r'../input/riiid-saint-best-models/model_best_user_correctness_included_v7.pt'
else:
    # Fail with a clear message instead of a NameError on resume_path below
    raise ValueError("train_config.MODE must be 'train' or 'inference', got %r" % (train_config.MODE,))
    
model.load_state_dict(torch.load(resume_path, map_location = train_config.MAP_LOCATION))

model.to(DEVICE)
# Evaluation mode for inference
model.eval()

In [None]:
# When True, the test batches come from a local validation pickle served by
# Iter_Valid; when False, they come from the Kaggle riiideducation API
EMULATION = False

if EMULATION:
    # Validation rows that are replayed as test batches
    target_val = pd.read_pickle('../input/riiid-cross-validation-files/cv1_valid.pickle')

In [None]:
class Iter_Valid(object):
    """
    Iterator that replays a validation frame in batches of (df, sample_df).

    A batch holds at most max_user distinct users and at most one task container
    of questions per user. The answers of a batch are written, as list-like
    strings, into the first row of the next batch ('prior_group_responses' and
    'prior_group_answers_correct'); every other row keeps "[]".

    df: validation rows; needs the columns row_id, user_id, task_container_id,
        content_type_id, user_answer and answered_correctly
    max_user: maximum number of distinct users per batch
    """
    def __init__(self, df, max_user=1000):
        df = df.reset_index(drop=True)
        self.df = df
        self.user_answer = df['user_answer'].astype(str).values
        self.answered_correctly = df['answered_correctly'].astype(str).values
        df['prior_group_responses'] = "[]"
        df['prior_group_answers_correct'] = "[]"
        # Prediction template: question rows only. It keeps the row labels of df,
        # which fix_df relies on. assign() builds a new frame, so no column is set
        # on a filtered view.
        self.sample_df = df.loc[df['content_type_id'] == 0, ['row_id']].assign(answered_correctly=0)
        self.len = len(df)
        self.user_id = df.user_id.values
        self.task_container_id = df.task_container_id.values
        self.content_type_id = df.content_type_id.values
        self.max_user = max_user
        self.current = 0                        # index of the next row to serve
        self.pre_user_answer_list = []          # answers of the previous batch
        self.pre_answered_correctly_list = []   # correctness of the previous batch

    def __iter__(self):
        return self
    
    def fix_df(self, user_answer_list, answered_correctly_list, pre_start):
        """
        Cut the batch [pre_start, self.current) out of the frame, attach the answers
        of the previous batch to its first row and remember the given lists for the
        next call. Returns (df, sample_df).
        """
        df = self.df[pre_start:self.current].copy()
        # sample_df contains question rows only, so it has fewer rows than df and a
        # positional slice would pick the wrong rows as soon as a lecture precedes
        # the batch. Select by row label instead (.loc slices include both ends).
        sample_df = self.sample_df.loc[pre_start:self.current - 1].copy()
        df.loc[pre_start,'prior_group_responses'] = '[' + ",".join(self.pre_user_answer_list) + ']'
        df.loc[pre_start,'prior_group_answers_correct'] = '[' + ",".join(self.pre_answered_correctly_list) + ']'
        self.pre_user_answer_list = user_answer_list
        self.pre_answered_correctly_list = answered_correctly_list
        return df, sample_df

    def _take_row(self, user_answer_list, answered_correctly_list):
        """Add the current row to the batch under construction and advance."""
        user_answer_list.append(self.user_answer[self.current])
        answered_correctly_list.append(self.answered_correctly[self.current])
        self.current += 1

    def __next__(self):
        added_user = set()
        pre_start = self.current
        pre_added_user = -1
        pre_task_container_id = -1

        user_answer_list = []
        answered_correctly_list = []
        while self.current < self.len:
            crr_user_id = self.user_id[self.current]
            crr_task_container_id = self.task_container_id[self.current]
            crr_content_type_id = self.content_type_id[self.current]
            if crr_content_type_id == 1:
                # no more than one task_container_id of "questions" from any single user
                # so we only care for content_type_id == 0 to break loop
                self._take_row(user_answer_list, answered_correctly_list)
                continue
            if crr_user_id in added_user and ((crr_user_id != pre_added_user) or (crr_task_container_id != pre_task_container_id)):
                # known user (not the previous user, or a different task container)
                return self.fix_df(user_answer_list, answered_correctly_list, pre_start)
            if len(added_user) == self.max_user:
                # The batch is full: only the running task container may still be completed
                if crr_user_id == pre_added_user and crr_task_container_id == pre_task_container_id:
                    self._take_row(user_answer_list, answered_correctly_list)
                    continue
                else:
                    return self.fix_df(user_answer_list, answered_correctly_list, pre_start)
            added_user.add(crr_user_id)
            pre_added_user = crr_user_id
            pre_task_container_id = crr_task_container_id
            self._take_row(user_answer_list, answered_correctly_list)
        if pre_start < self.current:
            # Last, possibly partial, batch
            return self.fix_df(user_answer_list, answered_correctly_list, pre_start)
        else:
            raise StopIteration()

In [None]:
# Choose where the test batches come from and where the predictions go
if not EMULATION:
    # Competition API
    import riiideducation
    env = riiideducation.make_env()
    iter_test = env.iter_test()
    set_predict = env.predict
else:
    # Local emulation: the predictions are collected in a list
    iter_test = Iter_Valid(target_val, max_user = 1000)
    predicted = []
    set_predict = predicted.append

# The per-user history is updated during inference, so a copy of the training
# history is used
test_train_group = train_group.copy()
# Previous batch; its targets arrive together with the next batch
prev_test_df = None

In [None]:
import psutil
from collections import defaultdict

# Running totals used for the user-correctness and content-difficulty features.
# They are defaultdicts, so an unseen user or question starts at 0.
user_sum_dict, user_count_dict = (
    user_agg[column].astype('int32').to_dict(defaultdict(int)) for column in ('sum', 'count')
)

content_sum_dict, content_count_dict = (
    content_agg[column].astype('int32').to_dict(defaultdict(int)) for column in ('sum', 'count')
)

In [None]:
import ast

# Batch keys in the positional order expected by the model
model_input_keys = ('content_id', 'task_container_id', 'part_id', 'diff_id',
                    'prior_elapsed_time_id', 'user_correctness_id', 'old_response_id',
                    'encoder_position_id', 'decoder_position_id', 'task_position_id')

for test_df, sample_prediction_df in tqdm(iter_test):
    print('*' * 50)
    # Step 1: the targets of the previous batch arrive with this batch; fold them
    # into the running statistics and the per-user history. The update is skipped
    # when the memory usage reaches 90 %.
    if prev_test_df is not None and psutil.virtual_memory().percent < 90:
        print(psutil.virtual_memory().percent)
        # The API delivers the answers as the text of a list; literal_eval parses
        # it without executing arbitrary code (the original used eval)
        prev_test_df[TARGET] = ast.literal_eval(test_df['prior_group_answers_correct'].iloc[0])
        prev_test_df = prev_test_df[prev_test_df.content_type_id == 0].reset_index(drop = True)
        
        # Update the running tables (answers and correct answers per user / question)
        prev_users = prev_test_df[USER_ID].values
        for content_id, user_id, answered_correctly in zip(prev_test_df[CONTENT_ID].values, prev_users,
                                                           prev_test_df[TARGET].values):
            user_sum_dict[user_id] += answered_correctly
            user_count_dict[user_id] += 1
            content_sum_dict[content_id] += answered_correctly
            content_count_dict[content_id] += 1
        
        # A batch made of lectures only leaves no question row and no history to update
        if len(prev_test_df) > 0:
            # Because there will be cases that a user appears multiple times in the test_df,
            # but not in a same bundle, we can not use the groupby() method: rows are
            # grouped into runs of consecutive identical users instead
            pos = np.concatenate(([0], np.cumsum(prev_users[1:] != prev_users[:-1])))
            
            # Update the per-user history
            for bundle in np.unique(pos):
                sub_user_df = prev_test_df[pos == bundle]
                user = sub_user_df[USER_ID].unique()[0]
                
                prev_group_target = sub_user_df[TARGET].values
                new_fields = (sub_user_df[CONTENT_ID].values,
                              sub_user_df[TASK_CONTAINER_ID].values,
                              sub_user_df[PART].values,
                              sub_user_df[DIFF].values,
                              sub_user_df[ELAPSE].values,
                              sub_user_df[CORR].values,
                              prev_group_target)
                
                if user in test_train_group.index:
                    stored = test_train_group[user]
                    # Old responses are the targets shifted by one, starting from the
                    # last stored target
                    prev_group_old_res = np.concatenate(([stored[6][-1]], prev_group_target))[:-1]
                    updated = tuple(np.append(old, new)
                                    for old, new in zip(stored, new_fields + (prev_group_old_res,)))
                else:
                    # If the user is new, create a new sequence
                    prev_group_old_res = np.array([1] * len(new_fields[0]))
                    updated = new_fields + (prev_group_old_res,)
                
                # Keep the last MAX_SEQ interactions only
                if len(updated[0]) > MAX_SEQ:
                    updated = tuple(seq[-MAX_SEQ:] for seq in updated)
                
                test_train_group[user] = updated
                
    # Step 2: features of the current batch
    # Merge with question data
    test_df = pd.merge(test_df, questions_df, left_on = CONTENT_ID, right_on = 'question_id', how = 'left')
    
    # Look up the running totals; int32 matches the dtype of the counters
    # (the original int16 buffers had far less headroom)
    batch_users = test_df[USER_ID].values
    batch_contents = test_df[CONTENT_ID].values
    user_sum = np.array([user_sum_dict[u] for u in batch_users], dtype = np.int32)
    user_count = np.array([user_count_dict[u] for u in batch_users], dtype = np.int32)
    content_sum = np.array([content_sum_dict[c] for c in batch_contents], dtype = np.int32)
    content_count = np.array([content_count_dict[c] for c in batch_contents], dtype = np.int32)
    
    # Unseen users / questions give 0 / 0 = NaN, which becomes 0 in fillna below
    with np.errstate(divide = 'ignore', invalid = 'ignore'):
        test_df[DIFF] = content_sum / content_count
        test_df[CORR] = user_sum / user_count
    # Rescale the elapsed time
    test_df[ELAPSE] = test_df[ELAPSE] / 300000 * 100
    
    # Replace every missing value by 0
    test_df.fillna(0, inplace = True)
    
    # Copy the test_df to prev_test_df (lectures included, because the targets of
    # the next batch cover them as well)
    prev_test_df = test_df.copy()
    
    # Drop lecture observations
    test_df = test_df[test_df.content_type_id == 0].reset_index(drop = True)
    
    # Step 3: predict
    outs = []
    # A batch without any question has nothing to predict
    if len(test_df) > 0:
        test_dataset = RIIID_Test_Dataset(test_train_group, test_df)
        test_dataloader = DataLoader(test_dataset, batch_size = train_config.BATCH_SIZE, shuffle = False, drop_last = False)

        for item in test_dataloader:
            inputs = [item[key].to(DEVICE) for key in model_input_keys]
            mask = item['prediction_mask'].to(DEVICE)

            with torch.no_grad():
                output = model(*inputs)
                # Keep the positions of the new rows only
                output = torch.masked_select(output, mask.bool())
                outs.extend(torch.sigmoid(output).view(-1).cpu().numpy())
    
    test_df['answered_correctly'] = outs
    set_predict(test_df.loc[test_df['content_type_id'] == 0, ['row_id', 'answered_correctly']])