# Main contribution

This notebook builds on several great public SAKT kernels (see References below).

The main change is inspired by a comment from @mamasinkgs (currently 2nd place on the leaderboard): https://www.kaggle.com/c/riiid-test-answer-prediction/discussion/204801#1128342

> I can't share detailed information, but I think no difficult technique is required to get to 0.810. What are truly necessary are stable validation, careful implementation for dataloader/model/training loop, and a stable system that avoids bug in inference.

So I dug into the inference loop, using @its7171's `iter_env` simulation to check what actually happens. Because the test data are served in chronological order, a single `test_df` may contain only a few interactions of a given user, so the train loader has to be written to reproduce this situation. This is illustrated in the inference part.

So the main change is the train loader:
- The train loader now uses the `timestamp` information to cut the sequences. A timestamp threshold (`TIMESTAMP_GAP`) decides whether the last and second-to-last entries of a sequence could have arrived in the same test batch; if they could, the sequence is trimmed.
- Additionally, shifted copies of each sequence are added, so there are many more sequences per epoch. With this train loader the model reaches > 0.78 CV within a few epochs.
- Unfortunately, this training strategy does not fit in Kaggle's memory, so I trained the model locally.


Reference:
- https://www.kaggle.com/gilfernandes/riiid-self-attention-transformer
- https://www.kaggle.com/manikanthr5/riiid-sakt-model-training-public
- https://www.kaggle.com/leadbest/sakt-with-randomization-state-updates
- https://www.kaggle.com/wangsg/a-self-attentive-model-for-knowledge-tracing

In [None]:
import gc
import random
from tqdm.notebook import tqdm
import numpy as np 
import pandas as pd 
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import psutil
import seaborn as sns
sns.set()
import matplotlib.pyplot as plt

import pickle
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import datatable as dt

In [None]:
TRAIN = False  # True: run the training cells; False: only load pretrained weights

NUM_SKILLS = 13523  # number of problems (size of the content-id vocabulary)
# Lower-case alias used by the model-construction cells. Derived from
# NUM_SKILLS so the two can never drift apart.
n_skill = NUM_SKILLS
MAX_SEQ = 180  # length of the padded/truncated sequence window per sample
ACCEPTED_USER_CONTENT_SIZE = 5  # minimum sequence length kept (datasets' `min_seq`)
NUM_EMBED = 128  # embedding / attention width
# Default `recent_seq` of SAKTDataset (recent data in the training loop).
# NOTE: the training-dataset cell passes recent_seq=None, overriding this.
RECENT_SIZE = 10
# Minimum timestamp difference between the last two entries of a training
# sequence (same units as the `timestamp` column).
TIMESTAMP_GAP = 20_000
NUM_HEADS = 8  # attention heads
BATCH_SIZE = 64  # training batch size
VAL_BATCH_SIZE = 512  # validation batch size
TEST_SIZE = 0.05  # fraction of users held out for validation
DROPOUT = 0.1
SEED = 1127  # random_state for the train/validation split

# Data

In [None]:
%%time

train_dtypes = {
    'content_type_id': 'bool',
    'timestamp': 'int64',
    'user_id': 'int32', 
    'content_id': 'int16', 
    'answered_correctly': 'int8', 
    'prior_question_elapsed_time': 'float32', 
    'prior_question_had_explanation': 'bool'
}
target = 'answered_correctly'
train_df = dt.fread('../input/riiid-test-answer-prediction/train.csv', 
                   columns=set(train_dtypes.keys())).to_pandas().astype(train_dtypes)

# train_df = pd.read_parquet('../input/cv-strategy-in-the-kaggle-environment/cv3_train.parquet')
# train_df = train_df[train_dtypes.keys()]
# train_df = train_df.astype(train_dtypes)


train_df = train_df[train_df.content_type_id == False]
#arrange by timestamp
train_df = train_df.sort_values(['timestamp'], ascending=True).reset_index(drop = True)

# del train_df['timestamp']
del train_df['content_type_id']

### Pre-process
Group the interactions by user, now also keeping the `timestamp` of each interaction.

In [None]:
%%time
group = train_df[['user_id', 'content_id', 'answered_correctly','timestamp']]\
            .groupby('user_id')\
            .apply(lambda r: (r['content_id'].values, 
                              r['answered_correctly'].values,
                              r['timestamp'].values))

del train_df

In [None]:
# %%time
# with open('../input/riiid-sakt-baseline-user-group/group.pickle', 'rb') as f:
#     group = pickle.load(f)

In [None]:
# Each element of `group` is one user, so this holds out whole users.
group, val_group = train_test_split(
    group,
    test_size=TEST_SIZE,
    random_state=SEED,
)

## Main change: train loaders

In [None]:
class SAKTDataset(Dataset):
    """Training dataset of SAKT sequence windows built from per-user histories.

    Besides cutting each user's history into ``max_seq``-long chunks, it
    (1) trims the tail of every (sub)sequence while its last two entries are less
        than ``gap`` apart in ``timestamp`` (such entries could arrive in the same
        ``test_df`` batch at inference time), adding extra variants for entries
        that share a timestamp, and
    (2) for users with at least ``2 * recent_seq`` entries, repeats the whole
        procedure on the history with its last ``i`` entries removed ("shifted"
        sequences, keys ending in ``_s``), multiplying the samples per epoch.

    ``self.samples`` maps a string key to a ``(content_id, answered_correctly,
    timestamp)`` tuple of arrays; ``self.user_ids`` lists the keys in insertion
    order and defines ``__len__`` / ``__getitem__`` indexing.
    """
    def __init__(self, group, n_skill, 
                        max_seq=MAX_SEQ, 
                        min_seq=ACCEPTED_USER_CONTENT_SIZE,
                        recent_seq=RECENT_SIZE,
                        gap=TIMESTAMP_GAP):
        """Build every sample eagerly (prints progress every 10000 users).

        Args:
            group: iterated via ``group.index``; ``group[user_id]`` must unpack into
                ``(content_id, answered_correctly, timestamp)`` equal-length arrays
                (presumably time-ordered, as built by the groupby cell).
            n_skill: number of distinct content ids; added to a content id in ``x``
                when the answer was correct.
            max_seq: chunk length and window length returned by ``__getitem__``.
            min_seq: histories shorter than this get no unshifted samples; also
                bounds how far a tail may be trimmed.
            recent_seq: shifted copies are made for ``i`` in ``1 .. recent_seq - 1``,
                for users with at least ``2 * recent_seq`` entries. ``None`` means
                ``max_seq + 1``.
            gap: minimum ``timestamp`` difference required between the last two
                entries of a sample.
        """
        super(SAKTDataset, self).__init__()
        self.samples = {}  # key -> (content_id, answered_correctly, timestamp)
        self.n_skill = n_skill
        self.max_seq = max_seq
        self.min_seq = min_seq
        self.recent_seq = recent_seq
        self.gap = gap
        
        self.user_ids = []  # sample keys, in insertion order
        for i, user_id in enumerate(group.index):
            if i%10000==0: print(f"{i} users processed")
            
            content_id, answered_correctly, timestamp = group[user_id]

            if len(content_id) >= self.min_seq:

                if len(content_id) > self.max_seq:
                    '''
                    Case 1: longer than max_seq, break into sub-seqs.
                    '''
                    total_questions = len(content_id)
                    num_seq_user = total_questions // self.max_seq
                    for seq in range(num_seq_user):
                        index = f"{user_id}_{seq}"
                        self.user_ids.append(index)
                        start = seq * self.max_seq
                        end = (seq + 1) * self.max_seq
                        '''
                        New contribution #2:
                        timestamp[end-1] the original last entry's timestamp
                        the difference with previous entry should be bigger than a threshold
                        otherwise they may be in the same test_df batch
                        '''
                        # Walk `end` back while the last two kept entries are less
                        # than `gap` apart, recording positions whose timestamp equals
                        # the previous one. The loop stops once end - start < min_seq,
                        # so a chunk can keep as few as min_seq - 1 entries.
                        idx_same_bundle = []
                        while timestamp[end-1] - timestamp[end-2] < self.gap and end-start >= self.min_seq:
                            if timestamp[end-1] == timestamp[end-2]:
                                idx_same_bundle.append(end-1)
                            end -= 1     
                        self.samples[index] = (content_id[start:end], 
                                               answered_correctly[start:end],
                                               timestamp[start:end]
                                               )
                        
                        if idx_same_bundle and end-1-start >= self.min_seq: 
                            # seeing multiple questions at a time, this list is not empty
                            # Extra samples: the trimmed chunk with its final entry
                            # (end-1) replaced by each recorded equal-timestamp entry.
                            for j, idx in enumerate(idx_same_bundle):
                                index = f"{user_id}_{seq}_{j}"
                                self.user_ids.append(index)
                                self.samples[index] = (np.r_[content_id[start:end-1], content_id[idx]], 
                                                       np.r_[answered_correctly[start:end-1], answered_correctly[idx]],
                                                       np.r_[timestamp[start:end-1], timestamp[idx]]
                                                       )
                    '''
                    left-over sequence
                    '''
                    # NOTE(review): `end` is the *trimmed* end of the last chunk, so the
                    # entries trimmed from that chunk start this left-over sequence
                    # (SAKTValDataset uses the untrimmed boundary instead). If that makes
                    # the left-over longer than max_seq, the check below drops it
                    # entirely — confirm this is intended.
                    content_id_last = content_id[end:]
                    answered_correctly_last = answered_correctly[end:]
                    timestamp_last = timestamp[end:]
                    end = len(content_id_last)

                    if end >= self.min_seq and end <= self.max_seq:
                        # Same tail trimming / equal-timestamp variants as above.
                        idx_same_bundle = []
                        while timestamp_last[end-1] - timestamp_last[end-2] < self.gap and end >= self.min_seq:
                            if timestamp_last[end-1] == timestamp_last[end-2]:
                                idx_same_bundle.append(end-1)
                            end -= 1

                        index = f"{user_id}_{num_seq_user + 1}"
                        self.user_ids.append(index)
                        self.samples[index] = (content_id_last[:end], 
                                               answered_correctly_last[:end],
                                               timestamp_last[:end]
                                               )

                        if idx_same_bundle and end >= self.min_seq: 
                            # seeing multiple questions at a time, this list is not empty
                            for j, idx in enumerate(idx_same_bundle):
                                index = f"{user_id}_{num_seq_user + 1}_{j}"
                                self.user_ids.append(index)
                                self.samples[index] = (np.r_[content_id_last[:end-1], content_id_last[idx]], 
                                                       np.r_[answered_correctly_last[:end-1], answered_correctly_last[idx]],
                                                       np.r_[timestamp_last[:end-1], timestamp_last[idx]]
                                                       )
                else: # len(content_id) <= self.max_seq
                    '''
                    Case 2: shorter than max_seq, keep them all
                    '''
                    # Whole history as one sample, with the same tail trimming and
                    # equal-timestamp variants as Case 1.
                    index = f'{user_id}'
                    end = len(timestamp)
                    idx_same_bundle = []
                    # last time stamp diff should be bigger than a threshold
                    while timestamp[end-1] - timestamp[end-2] < self.gap and end >= self.min_seq:
                        if timestamp[end-1] == timestamp[end-2]:
                                idx_same_bundle.append(end-1)
                        end -= 1
                    self.user_ids.append(index)
                    self.samples[index] = (content_id[:end], 
                                           answered_correctly[:end],
                                           timestamp[:end],
                                           )

                    if idx_same_bundle and end >= self.min_seq: 
                        # seeing multiple questions at a time, this list is not empty
                        for j, idx in enumerate(idx_same_bundle):
                            index = f"{user_id}_{j}"
                            self.user_ids.append(index)
                            self.samples[index] = (np.r_[content_id[:end-1], content_id[idx]], 
                                                   np.r_[answered_correctly[:end-1], answered_correctly[idx]],
                                                   np.r_[timestamp[:end-1], timestamp[idx]]
                                                   )
            '''
            New contribution #1
            Adding a shifted sequence, now train_loader has much more sequences per epoch
            '''
            # None -> max_seq + 1 (resolved once, on the first user processed).
            if self.recent_seq is None: self.recent_seq = self.max_seq + 1
                
            if len(content_id) >= 2*self.recent_seq: #
                # NOTE: `i` shadows the enumerate counter above; that counter is
                # only used for the progress print at the top of each iteration.
                for i in range(1, self.recent_seq): # adding a shifted sequence
                    '''
                    Shifting cases:
                    generating much much more sequences by shifting the last few entries
                    '''
                    # The history without its last i entries goes through the same
                    # Case 1 / Case 2 logic as above, with keys ending in "_s".
                    content_id_shift = content_id[:-i]
                    answered_correctly_shift = answered_correctly[:-i]
                    timestamp_shift = timestamp[:-i]
                    if len(content_id_shift) >= self.min_seq:
                        if len(content_id_shift) > self.max_seq:
                            '''
                            Case S 1: shifted seq greater than max_seq, break into pieces
                            '''
                            total_questions_2 = len(content_id_shift)
                            num_seq_user = total_questions_2 // self.max_seq

                            for seq in range(num_seq_user):

                                index = f"{user_id}_{seq}_{i}_s"
                                self.user_ids.append(index)
                                start = seq * self.max_seq
                                end = (seq + 1) * self.max_seq

                                idx_same_bundle = []
                                while timestamp_shift[end-1] - timestamp_shift[end-2] < self.gap and end-start >= self.min_seq:
                                    if timestamp_shift[end-1] == timestamp_shift[end-2]:
                                        idx_same_bundle.append(end-1)
                                    end -= 1

                                self.samples[index] = (content_id_shift[start:end], 
                                                       answered_correctly_shift[start:end],
                                                       timestamp_shift[start:end]
                                                       )

                                if idx_same_bundle and end-1-start >= self.min_seq: 
                                    # seeing multiple questions at a time, this list is not empty
                                    for j, idx in enumerate(idx_same_bundle):
                                        index = f"{user_id}_{seq}_{i}_s_{j}"
                                        self.user_ids.append(index)
                                        self.samples[index] = (np.r_[content_id_shift[start:end-1], content_id_shift[idx]], 
                                        np.r_[answered_correctly_shift[start:end-1], answered_correctly_shift[idx]],
                                        np.r_[timestamp_shift[start:end-1], timestamp_shift[idx]]
                                        )
                            '''
                            left-over sequence
                            '''
                            # NOTE(review): as in Case 1, starts at the trimmed `end`.
                            content_id_last = content_id_shift[end:]
                            answered_correctly_last = answered_correctly_shift[end:]
                            timestamp_last = timestamp_shift[end:]
                            end = len(content_id_last)              
                            if end >= self.min_seq and end <= self.max_seq:
                                idx_same_bundle = []
                                while timestamp_last[end-1] - timestamp_last[end-2] < self.gap and end >= self.min_seq:
                                    if timestamp_last[end-1] == timestamp_last[end-2]:
                                        idx_same_bundle.append(end-1)
                                    end -= 1

                                index = f"{user_id}_{num_seq_user + 1}_{i}_s"
                                self.user_ids.append(index)
                                self.samples[index] = (content_id_last[:end], 
                                                       answered_correctly_last[:end],
                                                       timestamp_last[:end]
                                                       )

                                if idx_same_bundle and end >= self.min_seq: 
                                    # seeing multiple questions at a time, this list is not empty
                                    for j, idx in enumerate(idx_same_bundle):
                                        index = f"{user_id}_{num_seq_user + 1}_{i}_s_{j}"
                                        self.user_ids.append(index)
                                        self.samples[index] = (np.r_[content_id_last[:end-1], content_id_last[idx]], 
                                        np.r_[answered_correctly_last[:end-1], answered_correctly_last[idx]],
                                        np.r_[timestamp_last[:end-1], timestamp_last[idx]]
                                        )
                        else: #len(content_id_shift) <= self.max_seq
                            '''
                            Case S 2: shifted seq less than or equal to max_seq
                            '''
                            index = f'{user_id}_{i}_s'
                            end = len(timestamp_shift)
                            idx_same_bundle = []
                            # last time stamp diff should be bigger than a threshold
                            while timestamp_shift[end-1] - timestamp_shift[end-2] < self.gap and end >= self.min_seq:
                                if timestamp_shift[end-1] == timestamp_shift[end-2]:
                                        idx_same_bundle.append(end-1)
                                end -= 1
                            self.user_ids.append(index)
                            self.samples[index] = (content_id_shift[:end], 
                                                   answered_correctly_shift[:end],
                                                   timestamp_shift[:end]
                                                   )

                            if idx_same_bundle and end >= self.min_seq: 
                                # seeing multiple questions at a time, this list is not empty
                                for j, idx in enumerate(idx_same_bundle):
                                    index = f"{user_id}_{i}_s_{j}"
                                    self.user_ids.append(index)
                                    self.samples[index] = (np.r_[content_id_shift[:end-1], content_id_shift[idx]], 
                                                           np.r_[answered_correctly_shift[:end-1], answered_correctly_shift[idx]],
                                                           np.r_[timestamp_shift[:end-1], timestamp_shift[idx]]
                                                           )

    def __len__(self):
        """Number of samples (keys in ``self.user_ids``)."""
        return len(self.user_ids)

    def __getitem__(self, index):
        """Return ``(x, target_id, label, timestamp)``, each of length ``max_seq - 1``.

        The sample is left-padded with zeros, or cut to its last ``max_seq``
        entries. ``x`` holds positions ``0 .. max_seq-2`` of the content ids, with
        ``n_skill`` added where the answer was correct. ``target_id``, ``label`` and
        ``timestamp`` hold positions ``1 .. max_seq-1``.
        """
        user_id = self.user_ids[index]
        # content_id, answered_correctly = self.samples[user_id]
        content_id, answered_correctly, timestamp = self.samples[user_id]
        seq_len = len(content_id)
        
        # Zero-filled buffers: shorter samples end up left-padded with 0.
        content_id_seq = np.zeros(self.max_seq, dtype=int)
        answered_correctly_seq = np.zeros(self.max_seq, dtype=int)
        timestamp_seq = np.zeros(self.max_seq, dtype=int)

        if seq_len >= self.max_seq:
            content_id_seq[:] = content_id[-self.max_seq:]
            answered_correctly_seq[:] = answered_correctly[-self.max_seq:]
            timestamp_seq[:] = timestamp[-self.max_seq:]
        else:
            content_id_seq[-seq_len:] = content_id
            answered_correctly_seq[-seq_len:] = answered_correctly
            timestamp_seq[-seq_len:] = timestamp
            
        target_id = content_id_seq[1:] # question including the current one
        label = answered_correctly_seq[1:] # answers including the current
        timestamp = timestamp_seq[1:] # timestamp including the current

        x = content_id_seq[:-1].copy() # question till the previous one
        # encoded answers till the previous one as the past correctly answering seq
        x += (answered_correctly_seq[:-1] == 1) * self.n_skill
        
        return x, target_id, label, timestamp
    
    
class SAKTValDataset(Dataset):
    '''
    Validation dataset (no timestamp-based tail trimming).

    Each user's history is cut into consecutive ``max_seq``-long chunks, plus a
    left-over chunk if it has at least ``min_seq`` entries. Histories of at most
    ``max_seq`` entries are kept whole, and users shorter than ``min_seq`` are
    skipped. Users with at least ``2 * recent_seq`` entries also contribute
    "shifted" copies with their last ``i`` entries removed, for ``i`` in
    ``1 .. recent_seq - 1``. ``recent_seq=None`` means ``max_seq + 1``.

    ``group[user_id]`` may be a ``(content_id, answered_correctly)`` pair or a
    ``(content_id, answered_correctly, timestamp)`` triple; timestamps are ignored.
    '''
    def __init__(self, group, n_skill, 
                        max_seq=MAX_SEQ, 
                        min_seq=ACCEPTED_USER_CONTENT_SIZE,
                        recent_seq=None):
        super(SAKTValDataset, self).__init__()
        self.samples = {}  # key -> (content_id, answered_correctly)
        self.n_skill = n_skill
        self.max_seq = max_seq
        self.min_seq = min_seq
        # Loop-invariant default, resolved once instead of on every user.
        self.recent_seq = self.max_seq + 1 if recent_seq is None else recent_seq

        self.user_ids = []  # sample keys, in insertion order
        for user_id in group.index:
            # Take the first two fields explicitly. The former bare ``except:``
            # fallback also silently swallowed unrelated errors.
            content_id, answered_correctly = group[user_id][:2]
            if len(content_id) >= self.min_seq:
                self._add_chunks(user_id, '', content_id, answered_correctly)
            if len(content_id) >= 2 * self.recent_seq:
                # Shifted copies: the history minus its last i entries.
                for i in range(1, self.recent_seq):
                    content_id_truncated_end = content_id[:-i]
                    if len(content_id_truncated_end) >= self.min_seq:
                        self._add_chunks(user_id, f'_{i}_2',
                                         content_id_truncated_end,
                                         answered_correctly[:-i])

    def _add_sample(self, index, content_id, answered_correctly):
        """Register one sample under the key ``index``."""
        self.user_ids.append(index)
        self.samples[index] = (content_id, answered_correctly)

    def _add_chunks(self, user_id, suffix, content_id, answered_correctly):
        """Split one history into samples whose keys end with ``suffix``.

        Histories of at most ``max_seq`` entries are stored whole under
        ``f"{user_id}{suffix}"``. Longer ones become full ``max_seq`` chunks keyed
        ``f"{user_id}_{seq}{suffix}"``. The remainder, if it has at least
        ``min_seq`` entries, is keyed ``f"{user_id}_{n_chunks + 1}{suffix}"``.
        """
        if len(content_id) <= self.max_seq:
            self._add_sample(f'{user_id}{suffix}', content_id, answered_correctly)
            return
        n_chunks = len(content_id) // self.max_seq
        for seq in range(n_chunks):
            start = seq * self.max_seq
            end = start + self.max_seq
            self._add_sample(f"{user_id}_{seq}{suffix}",
                             content_id[start:end], answered_correctly[start:end])
        end = n_chunks * self.max_seq
        if len(content_id[end:]) >= self.min_seq:
            self._add_sample(f"{user_id}_{n_chunks + 1}{suffix}",
                             content_id[end:], answered_correctly[end:])


    def __len__(self):
        """Number of samples."""
        return len(self.user_ids)

    def __getitem__(self, index):
        """Return ``(x, target_id, label)``, each of length ``max_seq - 1``.

        Same encoding as SAKTDataset.__getitem__, but without timestamps: zero
        left-padding, ``x`` = previous content ids (+ ``n_skill`` if answered
        correctly), ``target_id`` / ``label`` = the next question and its answer.
        """
        user_id = self.user_ids[index]
        content_id, answered_correctly = self.samples[user_id]
        seq_len = len(content_id)
        
        content_id_seq = np.zeros(self.max_seq, dtype=int)
        answered_correctly_seq = np.zeros(self.max_seq, dtype=int)

        if seq_len >= self.max_seq:
            content_id_seq[:] = content_id[-self.max_seq:]
            answered_correctly_seq[:] = answered_correctly[-self.max_seq:]
        else:
            content_id_seq[-seq_len:] = content_id
            answered_correctly_seq[-seq_len:] = answered_correctly
            
        target_id = content_id_seq[1:] # question including the current one
        label = answered_correctly_seq[1:]
        
        x = content_id_seq[:-1].copy() # question till the previous one
        # encoded answers till the previous one
        x += (answered_correctly_seq[:-1] == 1) * self.n_skill
        
        return x, target_id, label

In [None]:
%%time
train_dataset = SAKTDataset(group, 
                            n_skill=NUM_SKILLS, 
                            max_seq=MAX_SEQ,
                            min_seq=ACCEPTED_USER_CONTENT_SIZE, 
                            recent_seq=None,
                            gap=TIMESTAMP_GAP)
train_dataloader = DataLoader(train_dataset, 
                        batch_size=BATCH_SIZE, 
                        num_workers=4,
                        shuffle=True, 
                        drop_last=True)

In [None]:
%%time
val_dataset = SAKTValDataset(val_group, n_skill=NUM_SKILLS, max_seq=MAX_SEQ)
val_dataloader = DataLoader(val_dataset, 
                        batch_size=VAL_BATCH_SIZE, 
                            num_workers=4,
                        shuffle=False)

In [None]:
# Number of batches per epoch: train, validation.
print(f"{len(train_dataloader)} {len(val_dataloader)}")

In [None]:
# Peek at one training batch to check the shapes of x, target_id and label.
sample_batch = next(iter(train_dataloader))
tuple(part.shape for part in sample_batch[:3])

### Define model

In [None]:
class FFN(nn.Module):
    """Feed-forward sub-layer: Linear -> ReLU -> BatchNorm1d -> Linear -> Dropout.

    ``state_size`` is the width of the last input dimension; TransformerBlock
    passes the embedding size. BatchNorm1d normalizes dimension 1 of its input.
    TransformerBlock feeds (batch, seq, embed) tensors, so ``bn_size`` must equal
    the sequence length (``MAX_SEQ - 1`` by default).
    """
    def __init__(self, state_size = MAX_SEQ, 
                    forward_expansion = 1, 
                    bn_size=MAX_SEQ - 1, 
                    dropout=DROPOUT):
        super(FFN, self).__init__()
        hidden_size = forward_expansion * state_size
        self.state_size = state_size
        # Attribute names (lr1, bn, lr2) are state_dict keys of saved weights.
        self.lr1 = nn.Linear(state_size, hidden_size)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(bn_size)
        self.lr2 = nn.Linear(hidden_size, state_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.bn(self.relu(self.lr1(x)))
        return self.dropout(self.lr2(hidden))

def future_mask(seq_length):
    """Return a (seq_length, seq_length) bool tensor that is True strictly above the diagonal.

    Used as ``attn_mask`` for ``nn.MultiheadAttention``, where True marks a
    position that may not be attended to. Query position t therefore only sees
    key positions <= t.
    """
    return torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()



class TransformerBlock(nn.Module):
    """One SAKT block: masked multi-head attention + residual LayerNorm, then the
    FFN sub-layer + residual LayerNorm.

    ``forward`` takes ``(value, key, query)`` positionally and hands them to
    ``nn.MultiheadAttention(query, key, value)``. So the argument named ``value``
    is the attention *query* and the one named ``query`` is the attention *value*.
    Encoder calls ``layer(e, x, x)``: question embeddings ``e`` attend over the
    interaction embeddings ``x``. Inputs are sequence-first (seq, batch, embed);
    the returned tensor is batch-first (batch, seq, embed).
    """
    def __init__(self, embed_dim, 
                    heads = NUM_HEADS, 
                    dropout = DROPOUT, 
                    forward_expansion = 1):
        super(TransformerBlock, self).__init__()
        # Attribute names are state_dict keys; keep them stable.
        self.multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=heads,
                                               dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_normal = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim, forward_expansion=forward_expansion, dropout=dropout)
        self.layer_normal_2 = nn.LayerNorm(embed_dim)

    def forward(self, value, key, query, att_mask):
        # Keywords make the positional swap explicit (see class docstring).
        attended, att_weight = self.multi_att(query=value, key=key, value=query,
                                              attn_mask=att_mask)
        # Residual over the attention query, i.e. the `value` argument.
        attended = self.dropout(self.layer_normal(attended + value))
        # (seq, batch, embed) -> (batch, seq, embed) for the FFN / BatchNorm.
        attended = attended.permute(1, 0, 2)
        out = self.dropout(self.layer_normal_2(self.ffn(attended) + attended))
        return out.squeeze(-1), att_weight
    
class Encoder(nn.Module):
    """SAKT encoder: embeds past interactions and target questions, then applies
    ``num_layers`` causally-masked TransformerBlocks.

    Embedding tables:
      * ``embedding``: interaction ids in ``[0, 2 * n_skill]`` (the datasets add
        ``n_skill`` to a content id when the answer was correct);
      * ``pos_embedding``: positions ``0 .. max_seq - 2``;
      * ``e_embedding``: target question ids in ``[0, n_skill]``.
    """
    def __init__(self, n_skill, max_seq=MAX_SEQ, 
                 embed_dim=NUM_EMBED, 
                 dropout = DROPOUT, 
                 forward_expansion = 1, 
                 num_layers=1, 
                 heads = NUM_HEADS):
        super(Encoder, self).__init__()
        self.n_skill, self.embed_dim = n_skill, embed_dim
        self.embedding = nn.Embedding(2 * n_skill + 1, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq - 1, embed_dim)
        self.e_embedding = nn.Embedding(n_skill+1, embed_dim)
        # `dropout` is now forwarded to the blocks. They previously always used
        # the DROPOUT default, whatever was passed here.
        self.layers = nn.ModuleList([TransformerBlock(embed_dim, heads=heads, dropout=dropout,
                forward_expansion = forward_expansion) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, question_ids):
        """Encode a batch.

        Args:
            x: (batch, seq) interaction ids; seq must be <= max_seq - 1 (size of
                the positional table).
            question_ids: (batch, seq) target question ids.

        Returns:
            ``(out, att_weight)``. ``out`` is (batch, seq, embed_dim) and
            ``att_weight`` holds the last layer's attention weights
            (``None`` when there are no layers).
        """
        device = x.device
        seq_len = x.size(1)
        x = self.embedding(x)
        pos_x = self.pos_embedding(torch.arange(seq_len, device=device).unsqueeze(0))
        x = self.dropout(x + pos_x)
        x = x.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]
        e = self.e_embedding(question_ids).permute(1, 0, 2)
        # The causal mask depends only on the sequence length: build it once,
        # not once per layer.
        att_mask = future_mask(e.size(0)).to(device)
        att_weight = None
        for layer in self.layers:
            x, att_weight = layer(e, x, x, att_mask=att_mask)
            x = x.permute(1, 0, 2)  # back to [s_len, bs, embed] for the next layer
        x = x.permute(1, 0, 2)  # [s_len, bs, embed] => [bs, s_len, embed]
        return x, att_weight

class SAKTModel(nn.Module):
    """Self-Attentive Knowledge Tracing model: an Encoder followed by a linear
    head that emits one logit per position (sigmoid gives P(correct))."""
    def __init__(self, 
                n_skill, 
                max_seq=MAX_SEQ, 
                embed_dim=NUM_EMBED, 
                dropout = DROPOUT, 
                forward_expansion = 1, 
                enc_layers=1, 
                heads = NUM_HEADS):
        super(SAKTModel, self).__init__()
        # `encoder` / `pred` are state_dict key prefixes of the saved weights.
        self.encoder = Encoder(n_skill, max_seq, embed_dim, dropout, forward_expansion,
                               num_layers=enc_layers, heads=heads)
        self.pred = nn.Linear(embed_dim, 1)

    def forward(self, x, question_ids):
        """Return ``(logits, att_weight)``; ``logits`` is the encoder output
        projected to one value per position, with the trailing dim squeezed."""
        encoded, att_weight = self.encoder(x, question_ids)
        logits = self.pred(encoded).squeeze(-1)
        return logits, att_weight

In [None]:
# Prefer the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
def get_num_params(model):
    """Count the trainable parameters (``requires_grad=True``) of ``model``."""
    return sum(np.prod(param.size()) for param in model.parameters() if param.requires_grad)

model = SAKTModel(
    n_skill,
    max_seq=MAX_SEQ,
    embed_dim=NUM_EMBED,
    heads=NUM_HEADS,
    dropout=DROPOUT,
)

# Report the size of the freshly built model.
n_params = get_num_params(model)
print(f"Current model has {n_params} parameters.")

In [None]:
# Sanity check: logits for the sample batch should be (BATCH_SIZE, MAX_SEQ - 1).
# NOTE: the model is in train mode here, so BatchNorm running stats get updated.
model(*sample_batch[:2])[0].size()

### Training

In [None]:
# Checkpoint file that do_train() writes the model's state_dict to.
MODEL_PATH = '/kaggle/working/sakt.pth'

In [None]:
def load_from_item(item):
    """Move a DataLoader batch to ``device`` and build the padding mask.

    Returns ``(x, target_id, label, target_mask)``: ``x`` and ``target_id`` as
    int64, ``label`` as float32 (for BCEWithLogitsLoss), and ``target_mask``
    True where ``target_id != 0``. Extra batch entries (e.g. the training
    timestamps) are ignored.

    NOTE(review): 0 is also the padding value, so if 0 is a real content id
    those positions are masked out too — confirm.
    """
    inputs, targets, answers = item[0], item[1], item[2]
    x = inputs.to(device).long()
    target_id = targets.to(device).long()
    label = answers.to(device).float()
    return x, target_id, label, target_id != 0

def update_stats(train_loss, loss, 
                 output, label, num_corrects, num_total, 
                 labels, outs):
    """Accumulate metrics for one batch.

    Appends ``loss.item()`` to ``train_loss`` and the flattened labels / raw
    logits to ``labels`` / ``outs``; these lists are mutated in place. Returns
    the updated ``(num_corrects, num_total)``, counting a prediction as correct
    when ``sigmoid(output) >= 0.5`` matches ``label``.
    """
    train_loss.append(loss.item())
    predicted = (torch.sigmoid(output) >= 0.5).long()
    labels.extend(label.view(-1).data.cpu().numpy())
    outs.extend(output.view(-1).data.cpu().numpy())
    return num_corrects + (predicted == label).sum().item(), num_total + len(label)

def train_epoch(model, dataloader, optim, criterion, scheduler, device="cpu"):
    """Run one training epoch and return the mean per-batch training loss.

    The scheduler is stepped once per batch, which matches the per-batch
    ``steps_per_epoch`` set up for OneCycleLR in do_train. Batches are moved to
    the global device by ``load_from_item``; ``device`` is kept only for
    interface compatibility.

    Returns:
        float: mean batch loss over the epoch (``nan`` for an empty dataloader).
    """
    model.train()

    # Running sum: avoids re-averaging an ever-growing list every 10 steps.
    loss_sum = 0.0
    n_batches = 0

    tbar = tqdm(dataloader)
    for k, item in enumerate(tbar):
        x, target_id, label, target_mask = load_from_item(item)

        optim.zero_grad()
        output, _ = model(x, target_id)

        # Only real (non-padding) positions contribute to the loss.
        output = torch.masked_select(output, target_mask)
        label = torch.masked_select(label, target_mask)

        loss = criterion(output, label)

        loss.backward()
        optim.step()
        scheduler.step()

        loss_sum += loss.item()
        n_batches += 1
        if k % 10 == 0:
            # tqdm already advances once per iteration. The former extra
            # `tbar.update(10)` made the bar overshoot the real progress.
            tbar.set_description('Train loss - {:.4f}'.format(loss_sum / n_batches))
    return loss_sum / n_batches if n_batches else float('nan')

def val_epoch(model, val_iterator, criterion, device="cpu"):
    """Evaluate ``model`` on ``val_iterator`` and return ``(loss, acc, auc)``.

    Padding positions (``target_id == 0``) are excluded. ``loss`` is the mean of
    the per-batch losses, ``acc`` uses a 0.5 threshold on ``sigmoid(logits)`` and
    ``auc`` is ROC-AUC over all scored positions. ``device`` is unused; batches
    are moved by ``load_from_item``.
    """
    model.eval()

    batch_losses, all_labels, all_logits = [], [], []
    n_correct = n_seen = 0

    for item in tqdm(val_iterator):
        x, target_id, label, target_mask = load_from_item(item)

        with torch.no_grad():
            logits, _ = model(x, target_id)

        logits = torch.masked_select(logits, target_mask)
        label = torch.masked_select(label, target_mask)

        batch_loss = criterion(logits, label)
        n_correct, n_seen = update_stats(batch_losses, batch_loss, logits, label,
                                         n_correct, n_seen, all_labels, all_logits)

    acc = n_correct / n_seen
    auc = roc_auc_score(all_labels, all_logits)
    return np.average(batch_losses), acc, auc


In [None]:
def do_train(lr=1e-3,
             epochs=10,
             best_auc=0.0):
    """Train the module-level ``model`` and checkpoint it whenever val AUC improves.

    Uses the module-level ``model``, ``train_dataloader``, ``val_dataloader``,
    ``device`` and ``MODEL_PATH``. A fresh Adam optimizer and OneCycleLR
    schedule are created on every call.

    Args:
        lr: ``max_lr`` for the OneCycleLR schedule.
        epochs: Number of epochs to run.
        best_auc: Validation AUC an epoch must exceed before the model is
            saved. Pass the value returned by an earlier call so a later
            training stage cannot overwrite a better checkpoint.

    Returns:
        The best validation AUC seen, or ``best_auc`` if no epoch beat it.
    """
    # Move the model before building the optimizer, as the torch.optim docs
    # recommend.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    criterion.to(device)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                    max_lr=lr,
                                                    steps_per_epoch=len(train_dataloader),
                                                    epochs=epochs)
    for epoch in range(epochs):
        train_epoch(model, train_dataloader, optimizer, criterion, scheduler, device)
        val_loss, val_acc, val_auc = val_epoch(model, val_dataloader, criterion, device)
        print(f"epoch - {epoch + 1} val_loss - {val_loss:.3f} acc - {val_acc:.3f} auc - {val_auc:.3f}")
        if best_auc < val_auc:
            print(f'epoch - {epoch + 1} best model with val auc: {val_auc}')
            best_auc = val_auc
            # Save only on improvement. Previously the save ran every epoch, so
            # the file always held the last epoch rather than the best one.
            torch.save(model.state_dict(), MODEL_PATH)
    return best_auc

In [None]:
if TRAIN:
    # Two-stage schedule: 3 epochs peaking at lr=1e-3, then 2 epochs peaking
    # at lr=1e-4. Each do_train call builds its own optimizer and scheduler.
    # LR and EPOCHS stay module-level names, ending at the last stage's values.
    for LR, EPOCHS in ((1e-3, 3), (1e-4, 2)):
        do_train(lr=LR, epochs=EPOCHS)

# Inference

In [None]:
MODEL_PATH = '../input/riiid-models/sakt_seq_180_auc_0.7836.pt'

# Rebuild the architecture from the module-level constants (load_state_dict
# is strict by default, so they must match the checkpoint), load the weights
# deserialized onto CPU, then move to `device` and switch to eval mode.
model = SAKTModel(n_skill, max_seq=MAX_SEQ, embed_dim=NUM_EMBED,
                  heads=NUM_HEADS, dropout=DROPOUT)
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
model.to(device)
model.eval();

In [None]:
class TestDataset(Dataset):
    """Inference dataset producing one sample per row of ``test_df``.

    Each row's user history (from ``samples``) is truncated to the last
    ``max_seq`` entries, right-aligned in a zero-padded buffer, and turned into
    the model's interaction / question inputs.
    """

    def __init__(self, samples, test_df, n_skill, max_seq=100):
        """
        Args:
            samples: Mapping indexed by ``user_id`` (presumably a pandas
                Series) whose values are ``(content_ids, answered_correctly)``
                array pairs.
            test_df: DataFrame with at least ``user_id`` and ``content_id``
                columns; one sample per row.
            n_skill: Offset added to a content id when the answer was correct,
                so correct and incorrect interactions get distinct tokens.
            max_seq: Length of the padded history window.
        """
        super().__init__()
        self.samples, self.user_ids, self.test_df = samples, list(test_df["user_id"].unique()), test_df
        self.n_skill, self.max_seq = n_skill, max_seq

    def __len__(self):
        return self.test_df.shape[0]

    def __getitem__(self, index):
        """Return ``(x, questions)`` for positional row ``index`` of ``test_df``.

        ``x`` holds the last ``max_seq - 1`` history entries encoded as
        ``content_id + n_skill * (answered_correctly == 1)``. ``questions``
        holds the question following each of those entries, ending with this
        row's own ``content_id`` (the target). Both have length ``max_seq - 1``.
        """
        # Read scalars column-wise: `iloc[index]` on a frame mixing int and
        # float columns yields a float row, which would make `questions` float.
        user_id = self.test_df['user_id'].iat[index]
        target_id = self.test_df['content_id'].iat[index]

        content_id_seq = np.zeros(self.max_seq, dtype=int)
        answered_correctly_seq = np.zeros(self.max_seq, dtype=int)

        if user_id in self.samples.index:
            content_id, answered_correctly = self.samples[user_id]

            # Keep at most the last max_seq entries and copy them into our own
            # int buffers, so both branches yield the same dtype and never alias
            # the stored history.
            content_id = np.asarray(content_id)[-self.max_seq:]
            answered_correctly = np.asarray(answered_correctly)[-self.max_seq:]
            seq_len = len(content_id)

            # An empty history must be skipped: `buf[-0:]` is the whole buffer,
            # so assigning zero elements to it would raise a broadcast error.
            if seq_len:
                content_id_seq[-seq_len:] = content_id
                answered_correctly_seq[-seq_len:] = answered_correctly

        # .copy() so the in-place add does not modify content_id_seq, which is
        # still needed for `questions`.
        x = content_id_seq[1:].copy()
        x += (answered_correctly_seq[1:] == 1) * self.n_skill

        questions = np.append(content_id_seq[2:], [target_id])

        return x, questions

In [None]:
import riiideducation

# Previous batch, kept so its rows can be labelled from the next batch's
# `prior_group_answers_correct`; None until the first batch has been seen.
prev_test_df = None

# Competition API: iter_test yields (test_df, sample_prediction_df) batches.
env = riiideducation.make_env()
iter_test = env.iter_test()

The example test set has 4 batches of `test_df`. Aside from the placeholder user `275030867`, 3 other users appear in every batch.

In [None]:
# Hard-coded user ids of the 4 example `test_df` batches. The placeholder user
# 275030867 is the first entry of every batch.
batch_1_user_id = [ 275030867,  554169193, 1720860329,  288641214, 1728340777,
       1364159702, 1521618396, 1317245193, 1700555100,  998511398,
       1422853669, 1096784725,  385471210, 1202386221, 2018567473]
batch_2_user_id = [ 275030867, 1233875513,  891955351, 1981166446, 1637273633,
       2030979309,  319060572,  288641214,   98059812,  674533997,
        555691277, 1317245193, 1202386221,  775113212, 1219481379]
batch_3_user_id = [ 275030867, 1521618396, 1148874033,  554169193, 1281335472,
        998511398, 2002570769,  706626847, 1422853669, 1357500007,
       2018567473, 1720860329,  674533997, 1202386221,  891955351,
       1317245193,  385471210,  555691277,  288641214, 1364159702,
       1599808246,   98059812, 1728340777]
batch_4_user_id = [ 275030867, 1305988022, 1310228392, 1637273633,  674533997,
       2093197291, 1202386221, 1468996389,  555691277, 1838324752,
       2103436554,  311890082, 1817433235,  998511398, 1422853669,
        554169193, 1317245193, 1900527744,    7792299,  288641214,
       2018567473]

In [None]:
# Users present in all four example batches.
set(batch_1_user_id).intersection(batch_2_user_id, batch_3_user_id, batch_4_user_id)

## First iteration

Let us look at the user `288641214`

In [None]:
import ast

# Pull the next (test_df, sample_prediction_df) batch from the competition API.
test_df, sample_prediction_df = next(iter_test)

if prev_test_df is not None:
    # The new batch reveals the true answers of the previous batch: one entry
    # per row of prev_test_df (lectures included, since it was copied before
    # filtering). ast.literal_eval only accepts Python literals, unlike eval.
    prev_test_df['answered_correctly'] = ast.literal_eval(
        test_df['prior_group_answers_correct'].iloc[0])
    prev_test_df = prev_test_df[prev_test_df['content_type_id'] == 0]
    prev_group = (prev_test_df[['user_id', 'content_id', 'answered_correctly']]
                  .groupby('user_id')
                  .apply(lambda r: (r['content_id'].values,
                                    r['answered_correctly'].values)))
    # Append each user's newly revealed interactions to their history in
    # `group`, keeping only the last MAX_SEQ entries.
    for prev_user_id, (new_content, new_answered) in prev_group.items():
        if prev_user_id in group.index:
            new_content = np.append(group[prev_user_id][0], new_content)
            new_answered = np.append(group[prev_user_id][1], new_answered)
        group[prev_user_id] = (new_content[-MAX_SEQ:], new_answered[-MAX_SEQ:])

# Keep the raw batch (lectures included) so the next batch can label it.
prev_test_df = test_df.copy()
# .copy() so the column assignment below writes to an independent frame, not
# a slice of the batch (avoids SettingWithCopyWarning).
test_df = test_df[test_df['content_type_id'] == 0].copy()
# Constant 0.5 placeholder: this walkthrough does not run the model here.
test_df['answered_correctly'] = 0.5
env.predict(test_df[['row_id', 'answered_correctly']])

The last 10 questions in this user's sequence are `[12098,   901,  1240,   873, 12075,   592,   276,   429, 11912, 13262]`.

In [None]:
# Rows of the current batch for user 288641214.
test_df.loc[test_df['user_id'] == 288641214]

## Second iteration
In the first `test_df`, this user's `timestamp` is 62798072960. In the second iteration, only the end of this user's sequence is updated: the `timestamp` is 62798100988, a lag of 28028. For the model to do well, the train loader therefore needs to be able to generate this type of sequence. A new question is appended to the end: `[  901,  1240,   873, 12075,   592,   276,   429, 11912, 13262,  5418]`

In [None]:
import ast

# Pull the next (test_df, sample_prediction_df) batch from the competition API.
test_df, sample_prediction_df = next(iter_test)

if prev_test_df is not None:
    # The new batch reveals the true answers of the previous batch: one entry
    # per row of prev_test_df (lectures included, since it was copied before
    # filtering). ast.literal_eval only accepts Python literals, unlike eval.
    prev_test_df['answered_correctly'] = ast.literal_eval(
        test_df['prior_group_answers_correct'].iloc[0])
    prev_test_df = prev_test_df[prev_test_df['content_type_id'] == 0]
    prev_group = (prev_test_df[['user_id', 'content_id', 'answered_correctly']]
                  .groupby('user_id')
                  .apply(lambda r: (r['content_id'].values,
                                    r['answered_correctly'].values)))
    # Append each user's newly revealed interactions to their history in
    # `group`, keeping only the last MAX_SEQ entries.
    for prev_user_id, (new_content, new_answered) in prev_group.items():
        if prev_user_id in group.index:
            new_content = np.append(group[prev_user_id][0], new_content)
            new_answered = np.append(group[prev_user_id][1], new_answered)
        group[prev_user_id] = (new_content[-MAX_SEQ:], new_answered[-MAX_SEQ:])

# Keep the raw batch (lectures included) so the next batch can label it.
prev_test_df = test_df.copy()
# .copy() so the column assignment below writes to an independent frame, not
# a slice of the batch (avoids SettingWithCopyWarning).
test_df = test_df[test_df['content_type_id'] == 0].copy()
# Constant 0.5 placeholder: this walkthrough does not run the model here.
test_df['answered_correctly'] = 0.5
env.predict(test_df[['row_id', 'answered_correctly']])

In [None]:
# prev_group holds the true answers revealed via `prior_group_answers_correct`,
# even though the first batch was predicted with a 0.5 placeholder.
prev_group.loc[288641214]

In [None]:
# Rows of the current batch for user 288641214.
test_df.loc[test_df['user_id'] == 288641214]

## Third batch
Again, this user's `timestamp` is now 34626 greater than in the second `test_df`. The question sequence is updated to `[ 1240,   873, 12075,   592,   276,   429, 11912, 13262,  5418,  5620]`, and in the fourth batch it is updated to `[  873, 12075,   592,   276,   429, 11912, 13262,  5418,  5620,  9077]`. The train loader I wrote generates more sequences like these by shifting each sequence slightly.

At prediction time only the last entry is used, so the loss function could use decaying weights that weigh later predictions more. However, I did not get any CV improvement with this method.

In [None]:
import ast

# Pull the next (test_df, sample_prediction_df) batch from the competition API.
test_df, sample_prediction_df = next(iter_test)

if prev_test_df is not None:
    # The new batch reveals the true answers of the previous batch: one entry
    # per row of prev_test_df (lectures included, since it was copied before
    # filtering). ast.literal_eval only accepts Python literals, unlike eval.
    prev_test_df['answered_correctly'] = ast.literal_eval(
        test_df['prior_group_answers_correct'].iloc[0])
    prev_test_df = prev_test_df[prev_test_df['content_type_id'] == 0]
    prev_group = (prev_test_df[['user_id', 'content_id', 'answered_correctly']]
                  .groupby('user_id')
                  .apply(lambda r: (r['content_id'].values,
                                    r['answered_correctly'].values)))
    # Append each user's newly revealed interactions to their history in
    # `group`, keeping only the last MAX_SEQ entries.
    for prev_user_id, (new_content, new_answered) in prev_group.items():
        if prev_user_id in group.index:
            new_content = np.append(group[prev_user_id][0], new_content)
            new_answered = np.append(group[prev_user_id][1], new_answered)
        group[prev_user_id] = (new_content[-MAX_SEQ:], new_answered[-MAX_SEQ:])

# Keep the raw batch (lectures included) so the next batch can label it.
prev_test_df = test_df.copy()
# .copy() so the column assignment below writes to an independent frame, not
# a slice of the batch (avoids SettingWithCopyWarning).
test_df = test_df[test_df['content_type_id'] == 0].copy()
# Constant 0.5 placeholder: this walkthrough does not run the model here.
test_df['answered_correctly'] = 0.5
env.predict(test_df[['row_id', 'answered_correctly']])

In [None]:
# Rows of the current batch for user 288641214.
test_df.loc[test_df['user_id'] == 288641214]