I need to take a rest.
I will upload the training part and add comments later to explain the details.

Thanks to the authors of the great notebooks below. This notebook borrows many ideas from them.
1. https://www.kaggle.com/manikanthr5/riiid-sakt-model-training-public
2. https://www.kaggle.com/wangsg/a-self-attentive-model-for-knowledge-tracing
3. https://www.kaggle.com/leadbest/sakt-with-randomization-state-updates



In [None]:
import psutil
import joblib

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
# Competition inference API module (presumably only importable inside the
# Kaggle competition environment — verify before running elsewhere).
import riiideducation

# Create the competition environment; predictions are later submitted through
# `env.predict(...)`.
env = riiideducation.make_env()
# Iterator over test batches; each item is unpacked below as
# `(test_df, sample_prediction_df)`.
iter_test = env.iter_test()

In [None]:
"""
Version 6 adds three lag features and zero padding.

feature list:

1.content_id
2.answered_correctly
3.part
4.prior_question_elapsed_time
5.time_lag1
6.time_lag2
7.time_lag3
8.prior_question_had_explanation


"""

In [None]:
## cv 0.7993

In [None]:
## epoch - 0 train_loss - 0.4103 train_auc - 0.7990 val_loss - 0.5121 val_auc - 0.7993 time=641.96s

In [None]:
# Default `max_seq` for SAINTModel and TestDataset; also the cap on the stored
# per-user history length in the inference loop.
MAX_SEQ = 100
# Passed to SAINTModel as `n_part`; the part embedding has `n_part + 1` rows.
n_part = 7
# Passed to SAINTModel as `embed_dim` (embedding / transformer width).
D_MODEL = 256
# Number of encoder layers and of decoder layers given to nn.Transformer.
N_LAYER = 2
# Dropout rate used by FFN, SAINTModel and nn.Transformer.
DROPOUT = 0.1

In [None]:
def feature_time_lag(df, time_dict):
    """Add a ``time_lag`` column: time since the user's previous task container.

    Rows are processed in order. For every row the per-user state in
    ``time_dict`` is consulted and, where needed, updated in place:

    * user not seen before: lag is 0 and the state is initialised as
      ``(timestamp, task_container_id, -1)``;
    * same ``task_container_id`` as the stored one: the stored lag is reused
      and the state is left untouched;
    * different ``task_container_id``: lag is ``timestamp - stored timestamp``
      and the state becomes ``(timestamp, task_container_id, lag)``.

    Args:
        df: DataFrame with ``user_id``, ``timestamp`` and
            ``task_container_id`` columns. It is modified in place.
        time_dict: Mapping ``user_id -> (timestamp, task_container_id,
            lag_time)``. It is modified in place.

    Returns:
        The same ``df`` object, with an int64 ``time_lag`` column added.
    """
    lags = np.zeros(len(df), dtype=np.int64)
    rows = df[['user_id', 'timestamp', 'task_container_id']].values

    for pos, (user, stamp, container) in enumerate(rows):
        if user not in time_dict:
            # First sighting of this user: lags[pos] keeps its initial 0.
            # time_dict : timestamp, task_container_id, lag_time
            time_dict[user] = (stamp, container, -1)
            continue

        last_stamp, last_container, last_lag = time_dict[user]
        if container - last_container == 0:
            # Still inside the same task container: share its lag.
            lags[pos] = last_lag
        else:
            lags[pos] = stamp - last_stamp
            time_dict[user] = (stamp, container, lags[pos])

    df["time_lag"] = lags
    return df


In [None]:
class FFN(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear -> Dropout.

    Input and output share the same last dimension, ``state_size``.
    """

    def __init__(self, state_size=200):
        super().__init__()
        self.state_size = state_size

        # Attribute names are part of the checkpoint's state_dict keys.
        self.lr1 = nn.Linear(state_size, state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(state_size, state_size)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        """Apply the block to ``x`` (last dimension must be ``state_size``)."""
        hidden = self.relu(self.lr1(x))
        return self.dropout(self.lr2(hidden))

def future_mask(seq_length):
    """Return a ``[seq_length, seq_length]`` bool tensor, True above the diagonal.

    Entry ``(i, j)`` is True exactly when ``j > i``, i.e. when column ``j``
    lies in the future of row ``i``.
    """
    positions = np.arange(seq_length)
    is_future = positions[None, :] > positions[:, None]
    return torch.from_numpy(is_future)


class SAINTModel(nn.Module):
    """Encoder-decoder transformer for knowledge tracing (SAINT-style).

    The encoder input is the sum of the exercise, part, position and
    "had explanation" embeddings. The decoder input is the sum of the
    position, response, elapsed-time and lag-time embeddings. One logit is
    produced per sequence position.

    Args:
        n_skill: The exercise embedding has ``n_skill + 1`` rows.
        n_part: The part embedding has ``n_part + 1`` rows.
        max_seq: The position embedding has ``max_seq - 1`` rows, so input
            sequences may be at most ``max_seq - 1`` long.
        embed_dim: Width of every embedding and of the transformer.
        elapsed_time_cat_flag: If true, elapsed time and lag time are
            bucketised and looked up in embedding tables (elapsed seconds,
            lag seconds / minutes / days). If false, both are projected with
            a bias-free linear layer instead.
    """

    def __init__(self, n_skill, n_part, max_seq=MAX_SEQ, embed_dim= 128, elapsed_time_cat_flag = True):
        super(SAINTModel, self).__init__()

        self.n_skill = n_skill
        self.embed_dim = embed_dim
        self.n_cat = n_part
        self.elapsed_time_cat_flag = elapsed_time_cat_flag

        # NOTE: attribute names below are state_dict keys of the pretrained
        # checkpoint; do not rename them.
        self.e_embedding = nn.Embedding(self.n_skill+1, embed_dim) ## exercise
        self.c_embedding = nn.Embedding(self.n_cat+1, embed_dim) ## category
        self.pos_embedding = nn.Embedding(max_seq-1, embed_dim) ## position
        self.res_embedding = nn.Embedding(2+1, embed_dim) ## response

        if self.elapsed_time_cat_flag:
            self.elapsed_time_embedding = nn.Embedding(300+1, embed_dim) ## elapsed time (capped at 300 seconds)
            self.lag_embedding1 = nn.Embedding(300+1, embed_dim) ## lag time1 for 300 seconds
            self.lag_embedding2 = nn.Embedding(1440+1, embed_dim) ## lag time2 for 1440 minutes
            self.lag_embedding3 = nn.Embedding(365+1, embed_dim) ## lag time3 for 365 days
        else:
            self.elapsed_time_embedding = nn.Linear(1, embed_dim, bias=False) ## elapsed time
            self.lag_embedding = nn.Linear(1, embed_dim, bias=False) ## lag time

        self.exp_embedding = nn.Embedding(2+1, embed_dim) ## user had explain

        self.transformer = nn.Transformer(nhead=8, d_model = embed_dim, num_encoder_layers= N_LAYER, num_decoder_layers= N_LAYER, dropout = DROPOUT)

        self.dropout = nn.Dropout(DROPOUT)
        self.layer_normal = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim)
        self.pred = nn.Linear(embed_dim, 1)

    @staticmethod
    def _bucketize(values, divisor, cap):
        """Scale ``values`` by ``1/divisor``, round, and cap at ``cap``.

        Returns:
            ``(scaled, bucket)``: the rounded, uncapped float tensor (so the
            caller can chain a coarser unit from it) and the capped tensor
            cast to long for use as an embedding index.
        """
        scaled = torch.round(torch.true_divide(values, divisor))
        limit = torch.tensor(float(cap)).to(scaled.device)
        bucket = torch.where(scaled.float() <= cap, scaled, limit).long()
        return scaled, bucket

    def forward(self, question, part, response, elapsed_time, lag_time, exp):
        """Compute one logit per position.

        All arguments are ``[batch, seq_len]`` tensors with
        ``seq_len <= max_seq - 1``. ``question``, ``part``, ``response`` and
        ``exp`` are integer embedding indices. ``elapsed_time`` and
        ``lag_time`` are divided by 1000 before bucketing in the categorical
        mode (the tables are sized for seconds / minutes / days).

        Returns:
            ``[batch, seq_len]`` tensor of logits (no sigmoid applied).
        """
        device = question.device
        seq_len = question.size(1)

        ## embedding layer
        question = self.e_embedding(question)
        part = self.c_embedding(part)
        pos_id = self.pos_embedding(torch.arange(seq_len).unsqueeze(0).to(device))
        res = self.res_embedding(response)
        exp = self.exp_embedding(exp)

        if self.elapsed_time_cat_flag:
            ## elapsed time: -> seconds, capped at 300
            _, elapsed_bucket = self._bucketize(elapsed_time, 1000, 300)
            elapsed_time = self.elapsed_time_embedding(elapsed_bucket)

            ## lag time at three resolutions; each coarser unit is derived
            ## from the rounded (uncapped) finer one.
            lag_seconds, lag_bucket1 = self._bucketize(lag_time, 1000, 300)
            lag_minutes, lag_bucket2 = self._bucketize(lag_seconds, 60, 1440)
            _, lag_bucket3 = self._bucketize(lag_minutes, 1440, 365)

            lag_terms = [
                self.lag_embedding1(lag_bucket1),
                self.lag_embedding2(lag_bucket2),
                self.lag_embedding3(lag_bucket3),
            ]
        else:
            # Continuous mode: project each scalar with a linear layer.
            # nn.Linear needs a floating-point input and acts on the last
            # dimension, so [batch, seq_len] -> [batch, seq_len, 1] works for
            # any sequence length. (Previously this branch hard-coded
            # MAX_SEQ-1 and then crashed because lag_time1/2/3 were never
            # defined here.)
            elapsed_time = self.elapsed_time_embedding(elapsed_time.float().unsqueeze(-1))
            lag_terms = [self.lag_embedding(lag_time.float().unsqueeze(-1))]

        enc = question + part + pos_id + exp
        # Lag terms are added one by one, left to right, so the categorical
        # path keeps its original floating-point summation order.
        dec = pos_id + res + elapsed_time
        for lag_term in lag_terms:
            dec = dec + lag_term

        enc = enc.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]
        dec = dec.permute(1, 0, 2)
        # True above the diagonal; the same mask is used for encoder
        # self-attention, decoder self-attention and encoder-decoder
        # attention.
        mask = future_mask(enc.size(0)).to(device)

        att_output = self.transformer(enc, dec, src_mask=mask, tgt_mask=mask, memory_mask = mask)
        att_output = self.layer_normal(att_output)
        att_output = att_output.permute(1, 0, 2) # att_output: [s_len, bs, embed] => [bs, s_len, embed]

        # Residual feed-forward head; the same LayerNorm module is reused.
        x = self.ffn(att_output)
        x = self.layer_normal(x + att_output)
        x = self.pred(x)

        return x.squeeze(-1)

## Load Pretrained Models

In [None]:
# Passed to SAINTModel / TestDataset; the exercise embedding has n_skill + 1 rows.
n_skill = 13523
# Per-user interaction history, indexed by user_id. Each value is unpacked below as
# (content_id, answered_correctly, part, prior_question_elapsed_time, time_lag,
#  prior_question_had_explanation) arrays. Presumably a pandas Series (it exposes
# `.index`) — verify against the training notebook.
group = joblib.load("../input/saint-plus-data-new/group_20210102.pkl.zip")

In [None]:
# Notebook cell output: display the loaded per-user history for inspection.
group

In [None]:
# Question metadata; only the `question_id` and `part` columns are used (merged into each test batch).
questions_df = pd.read_csv('/kaggle/input/riiid-test-answer-prediction/questions.csv')

In [None]:
# Per-user lag state consumed and updated by feature_time_lag:
# user_id -> (timestamp, task_container_id, lag_time).
time_dict = joblib.load("../input/saint-plus-data-new/time_dict.pkl.zip")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SAINTModel(n_skill, n_part, embed_dim= D_MODEL)
# `map_location=device` remaps the checkpoint tensors onto whichever device is
# available, so a checkpoint saved on GPU also loads on a CPU-only machine.
# This replaces a bare `except:` retry, which would also have swallowed
# unrelated failures (missing file, state_dict key mismatch, KeyboardInterrupt).
model.load_state_dict(torch.load("../input/saint-plus-model/saint_plus_model_20210108_v1.pt", map_location=device))
model.to(device)
# Inference only: switch off dropout.
model.eval()

## Inference

In [None]:
class TestDataset(Dataset):
    """Builds one model input per row of a test batch.

    For each row, the user's stored history (if any) is right-aligned into a
    window of ``max_seq`` steps, the current question is appended, and the
    streams are offset so that the response at each position belongs to the
    question one step earlier.

    Args:
        samples: Per-user history indexed by user_id; each value unpacks to
            six arrays (content_id, answered_correctly, part,
            prior_question_elapsed_time, time_lag,
            prior_question_had_explanation).
        test_df: Rows to predict. Must provide ``user_id``, ``content_id``,
            ``part``, ``prior_question_elapsed_time``, ``time_lag`` and
            ``prior_question_had_explanation``.
        n_skills: Stored as ``n_skill``; not otherwise used here.
        max_seq: Window length of the history before the current row is
            appended.
    """

    def __init__(self, samples, test_df, n_skills, max_seq=MAX_SEQ):
        super().__init__()
        self.samples = samples
        self.test_df = test_df
        self.user_ids = list(test_df["user_id"].unique())
        self.n_skill = n_skills
        self.max_seq = max_seq

    def __len__(self):
        return self.test_df.shape[0]

    def _window(self, values, seq_len):
        """Keep the last ``max_seq`` items, or left-pad with zeros up to it."""
        if seq_len >= self.max_seq:
            return values[-self.max_seq:]
        padded = np.zeros(self.max_seq, dtype=int)
        padded[-seq_len:] = values
        return padded

    def __getitem__(self, index):
        """Return ``(exercise, part, response, elap, lag, pri_exp)`` arrays.

        Every array has ``max_seq - 1`` entries. Question ids, responses and
        the explanation flag are shifted by +1 so that 0 is the padding value.
        """
        row = self.test_df.iloc[index]
        user_id = row["user_id"]
        target_id = row["content_id"]
        current_part = row["part"]
        current_elap = row["prior_question_elapsed_time"]
        current_lag = row["time_lag"]
        current_exp = row["prior_question_had_explanation"]

        if user_id in self.samples.index:
            hist_q, hist_ans, hist_part, hist_elap, hist_lag, hist_exp = self.samples[user_id]
            seq_len = len(hist_q)

            ## for zero padding
            shifted = (hist_q + 1, hist_ans + 1, hist_part, hist_elap, hist_lag, hist_exp + 1)
            q, res, p, pri_elap, lag, pri_exp = [
                self._window(stream, seq_len) for stream in shifted
            ]
        else:
            # Unknown user: the whole history window is padding.
            q, res, p, pri_elap, lag, pri_exp = [
                np.zeros(self.max_seq, dtype=int) for _ in range(6)
            ]

        # Drop the two oldest steps and append the current row, so the
        # exercise-side streams end at the question being predicted...
        exercise = np.append(q[2:], [target_id+1])
        part = np.append(p[2:], [current_part])
        elap = np.append(pri_elap[2:], [current_elap])
        lag = np.append(lag[2:], [current_lag])
        pri_exp = np.append(pri_exp[2:], [current_exp+1])

        # ...while responses lag one step behind the exercises.
        response = res[1:]

        return  exercise, part, response, elap, lag, pri_exp

In [None]:
import ast

# Previous batch, kept until its labels arrive with the next batch.
prev_test_df = None

# Columns stored per user in `group`, in tuple order.
history_cols = ['content_id', 'answered_correctly', 'part',
                'prior_question_elapsed_time', 'time_lag', 'prior_question_had_explanation']

for (test_df, sample_prediction_df) in iter_test:

    ## 1. Fold the previous batch into the per-user history.
    ## Skipped when RAM usage is at or above 90%.
    if prev_test_df is not None and psutil.virtual_memory().percent < 90:
        # NOTE(security): this string comes from the competition API and used
        # to be passed to eval(). ast.literal_eval parses the same list
        # literal without executing arbitrary code.
        prev_test_df['answered_correctly'] = ast.literal_eval(test_df['prior_group_answers_correct'].iloc[0])
        # .copy() so the column added by feature_time_lag lands on an
        # independent frame rather than on a slice.
        prev_test_df = prev_test_df[prev_test_df.content_type_id == False].copy()

        ## lag time
        prev_test_df = feature_time_lag(prev_test_df, time_dict)

        prev_group = prev_test_df[['user_id'] + history_cols].groupby('user_id').apply(
            lambda r: tuple(r[col].values for col in history_cols))

        for prev_user_id in prev_group.index:
            new_history = prev_group[prev_user_id]
            if prev_user_id in group.index:
                # Append the new interactions and keep the latest MAX_SEQ.
                group[prev_user_id] = tuple(
                    np.append(old, new)[-MAX_SEQ:]
                    for old, new in zip(group[prev_user_id], new_history)
                )
            else:
                group[prev_user_id] = tuple(new_history)

    ## 2. Prepare the current batch.
    ## elapsed time
    test_df.prior_question_elapsed_time = test_df.prior_question_elapsed_time.fillna(0)

    ## prior_question_had_explanation
    test_df['prior_question_had_explanation'] = test_df['prior_question_had_explanation'].fillna(value = False).astype(int)

    test_df = test_df.merge(questions_df[["question_id","part"]], how = "left",left_on = 'content_id', right_on = 'question_id')

    # Saved before lectures are dropped: the next batch's answer list is
    # assigned to this frame as a whole.
    prev_test_df = test_df.copy()

    ## drop lecture
    test_df = test_df[test_df.content_type_id == False].copy()

    ## lag time
    test_df = feature_time_lag(test_df, time_dict)

    ## 3. Predict.
    test_dataset = TestDataset(group, test_df, n_skill)
    test_dataloader = DataLoader(test_dataset, batch_size=51200, shuffle=False)

    outs = []

    for item in test_dataloader:
        exercise, part, response, elapsed_time, lag_time, pri_exp = (
            tensor.to(device).long() for tensor in item
        )

        with torch.no_grad():
            output = model(exercise, part, response, elapsed_time, lag_time, pri_exp)
            # Only the last position corresponds to the row being predicted.
            outs.extend(torch.sigmoid(output)[:, -1].view(-1).cpu().numpy())

    test_df['answered_correctly'] = outs
    env.predict(test_df.loc[test_df['content_type_id'] == 0, ['row_id', 'answered_correctly']])