# Sakt baseline model

This is a baseline inference template for the SAKT model, following modifications from @mpware and @leadbest. All irrelevant imports and lines have been removed to save memory (if CUDA runs out of memory, your submission will fail).

I updated some utility functions so that the model configuration is selected automatically. If your model file is named following this pattern:
> f'head_{n_head}_embed_{embed_dim}_seq_{max_seq}_auc_{val_auc}.pt' 

then simply plug in your model file on Kaggle and you are good to go.

Before submission, the notebook also verifies the AUC score on the cv3 file generated by @marisakamozz, which is an improvement over @its7171's great kernel.

Reference:
* https://www.kaggle.com/leadbest/sakt-with-randomization-state-updates
* https://www.kaggle.com/wangsg/a-self-attentive-model-for-knowledge-tracing
* https://www.kaggle.com/leadbest/sakt-self-attentive-knowledge-tracing-submitter
* https://www.kaggle.com/its7171/cv-strategy
* https://www.kaggle.com/mpware/sakt-fork


In [None]:
import gc
import psutil
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import random
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

### Insert your model file below; it can be a kernel output or a dataset
Change `VALID` to `False` if you want to skip the validation phase.

In [None]:
# Global settings shared by the cells below.

# Seed Python's RNG; SAKTDataset draws from `random` when sampling windows.
random.seed(42)

# Model / sequence dimensions.
n_skill = 13523       # sizes the embedding tables built in SAKTModel
MAX_SEQ = 150         # default window length used by the model and datasets

# Validation settings.
VALID = True          # the validation cells only run when this is truthy
VAL_BATCH_SIZE = 2048
WORKERS = 4           # not referenced by the cells visible in this notebook

# Checkpoint to load; the file name encodes head / embed / seq sizes, which
# load_sakt_model parses back out.
# model_file = '../input/riiid-models/sakt_head_8_embed_128_seq_150_auc_0.7584.pt'
model_file = (
    '../input/riiid-models/'
    'sakt_layer_1_head_8_embed_128_seq_150_auc_0.7605.pt'
)

## Load data and preprocess

In [None]:
%%time
TRAIN_DTYPES = {'timestamp':'int64', 
         'user_id':'int32' ,
         'content_id':'int16',
         'content_type_id':'int8',
         'answered_correctly':'int8'}
TRAIN_COLS = TRAIN_DTYPES.keys()

train_df = pd.read_csv('/kaggle/input/riiid-test-answer-prediction/train.csv', 
                       usecols=[1, 2, 3, 4, 7], dtype=TRAIN_DTYPES)
# train_df = pd.read_parquet('../input/cv-strategy-in-the-kaggle-environment/cv3_train.parquet')
train_df = train_df[TRAIN_COLS].astype(TRAIN_DTYPES)

train_df = train_df[train_df["content_type_id"] == False]

train_df = train_df.sort_values(['timestamp'], ascending=True).reset_index(drop = True)

In [None]:
# Per-user interaction history: a Series indexed by user_id whose values are
# (content_id array, answered_correctly array) tuples, in train_df row order.
group = (
    train_df[['user_id', 'content_id', 'answered_correctly']]
    .groupby('user_id')
    .apply(lambda rows: (rows['content_id'].values,
                         rows['answered_correctly'].values))
)

# The flat table is not used after this point; drop it and collect garbage
# to free memory before the model is loaded.
del train_df
gc.collect();

## Load model

Change: `num_heads` is now an argument of the model's initializer.

In [None]:
class FFN(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear -> Dropout(0.2).

    Both linear layers map ``state_size`` features to ``state_size`` features,
    so the output has the same shape as the input.
    """

    def __init__(self, state_size=200):
        super().__init__()
        self.state_size = state_size

        # Attribute names double as state_dict keys of saved checkpoints.
        self.lr1 = nn.Linear(state_size, state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(state_size, state_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Transform the last dimension of ``x``; dropout is applied last."""
        hidden = self.relu(self.lr1(x))
        return self.dropout(self.lr2(hidden))

def future_mask(seq_length):
    """Return a square bool tensor that is True strictly above the diagonal.

    Entry ``[i, j]`` is True exactly when ``j > i``; SAKTModel passes the
    result as ``attn_mask`` so a position cannot attend to later positions.
    """
    # np.tri marks the lower triangle including the diagonal; inverting it
    # leaves only the strictly-upper ("future") entries set.
    visible = np.tri(seq_length, dtype=bool)
    return torch.from_numpy(~visible)


class SAKTModel(nn.Module):
    """Single-layer self-attentive knowledge tracing (SAKT) network.

    Embedded past interactions serve as attention keys and values, embedded
    target question ids serve as queries, and a linear head emits one logit
    per sequence position.

    Args:
        n_skill: sizes the embedding tables (``2 * n_skill + 1`` interaction
            ids, ``n_skill + 1`` question ids).
        max_seq: the position table holds ``max_seq - 1`` entries, so inputs
            may be at most that long.
        embed_dim: embedding / attention width.
        num_heads: number of attention heads.
    """

    def __init__(self, n_skill, max_seq=MAX_SEQ, embed_dim=128, num_heads=8):
        super().__init__()
        self.n_skill = n_skill
        self.embed_dim = embed_dim

        # Attribute names and table sizes define the checkpoint layout read
        # by load_state_dict, so they must stay exactly as they are.
        self.embedding = nn.Embedding(2 * n_skill + 1, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq - 1, embed_dim)
        self.e_embedding = nn.Embedding(n_skill + 1, embed_dim)

        self.multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=0.2)

        self.dropout = nn.Dropout(0.2)
        self.layer_normal = nn.LayerNorm(embed_dim)

        self.ffn = FFN(embed_dim)
        self.pred = nn.Linear(embed_dim, 1)

    def forward(self, x, question_ids):
        """Return ``(logits, attention_weights)``.

        Args:
            x: integer interaction ids, indexed as [batch, seq_len].
            question_ids: integer target question ids, same layout as ``x``.

        Returns:
            Logits with the trailing singleton dimension squeezed away, and
            the weights returned by the attention layer.
        """
        device = x.device
        positions = torch.arange(x.size(1)).unsqueeze(0).to(device)

        interactions = self.embedding(x) + self.pos_embedding(positions)
        queries = self.e_embedding(question_ids)

        # The attention layer is fed sequence-first tensors:
        # [bs, s_len, embed] => [s_len, bs, embed]
        interactions = interactions.permute(1, 0, 2)
        queries = queries.permute(1, 0, 2)

        causal = future_mask(interactions.size(0)).to(device)
        attended, att_weight = self.multi_att(queries, interactions, interactions, attn_mask=causal)

        # Residual connection + layer norm, then back to batch-first:
        # [s_len, bs, embed] => [bs, s_len, embed]
        attended = self.layer_normal(attended + queries).permute(1, 0, 2)

        # Feed-forward sub-layer with its own residual; the same LayerNorm
        # instance is reused for both residual connections.
        refined = self.layer_normal(self.ffn(attended) + attended)
        logits = self.pred(refined)

        return logits.squeeze(-1), att_weight

In [None]:
def parse_model_config(model_file):
    """Read ``num_heads``, ``embed_dim`` and ``max_seq`` from a checkpoint path.

    The path must contain ``head_<int>``, ``embed_<int>`` and ``seq_<int>``
    tokens, each starting at the beginning of the path, after an underscore,
    or after a path separator -- so a file named ``head_8_embed_128_...``
    inside any directory is accepted -- and each followed by ``_``, ``.``,
    a path separator, or the end of the path. When a token occurs more than
    once, the last occurrence (closest to the file name) is used.

    Args:
        model_file: path or file name of the checkpoint.

    Returns:
        dict with integer values for ``num_heads``, ``embed_dim``, ``max_seq``.

    Raises:
        ValueError: if one of the three tokens is missing from the path.
    """
    import re

    config = {}
    for token, kwarg in (('head', 'num_heads'),
                         ('embed', 'embed_dim'),
                         ('seq', 'max_seq')):
        found = re.findall(rf'(?:^|[_/\\]){token}_(\d+)(?=[_./\\]|$)', model_file)
        if not found:
            raise ValueError(
                f"cannot find '{token}_<int>' in model file name {model_file!r}")
        config[kwarg] = int(found[-1])
    return config


def load_sakt_model(model_file, device='cuda'):
    """Build a SAKTModel configured from the file name and load its weights.

    Args:
        model_file: checkpoint path; see ``parse_model_config`` for the
            naming convention the hyper-parameters are read from.
        device: device the model is moved to and the weights are mapped onto.

    Returns:
        The SAKTModel with the checkpoint's state dict loaded. The model is
        not switched to eval mode here.

    Raises:
        ValueError: if the hyper-parameters cannot be parsed from the name.
    """
    # n_skill is the notebook-level global defined in the settings cell.
    model = SAKTModel(n_skill=n_skill, **parse_model_config(model_file))
    model = model.to(device)
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model.load_state_dict(torch.load(model_file, map_location=device))

    return model

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyper-parameters are parsed from the checkpoint file name.
model = load_sakt_model(model_file, device=device)

# load_sakt_model already moves the model to `device`; this repeats the move.
model.to(device)
print(model)

## Validation

In [None]:
class SAKTDataset(Dataset):
    """Fixed-length SAKT sequences built from per-user interaction histories.

    Args:
        group: object with an ``.index`` of user ids that returns a
            ``(question_ids, answers)`` pair of arrays for ``group[user_id]``
            (the notebook passes a pandas Series).
        n_skill: offset added to a question id when it was answered correctly.
        subset: ``"train"`` enables random window sampling / truncation; any
            other value always yields the user's most recent window, so
            evaluation is deterministic and scores the last interaction.
        max_seq: length of the padded window; returned arrays have
            ``max_seq - 1`` elements.
    """

    def __init__(self, group, n_skill, subset="train", max_seq=MAX_SEQ):
        super().__init__()
        self.max_seq = max_seq
        self.n_skill = n_skill  # 13523
        self.samples = group
        self.subset = subset

        # Keep only users with at least 2 interactions: one is needed as the
        # input and the following one as the prediction target.
        self.user_ids = [
            user_id for user_id in group.index if len(group[user_id][0]) >= 2
        ]

    def __len__(self):
        return len(self.user_ids)

    def __getitem__(self, index):
        """Return ``(x, target_id, label)`` for the user at ``index``.

        ``x`` encodes interactions ``0 .. max_seq-2`` of the window as
        ``question_id + n_skill * (answer == 1)``; ``target_id`` and ``label``
        are the question ids and answers of interactions ``1 .. max_seq-1``.
        Histories shorter than ``max_seq`` are left-padded with zeros.
        """
        user_id = self.user_ids[index]
        q_, qa_ = self.samples[user_id]  # full history of this user
        seq_len = len(q_)
        training = self.subset == "train"

        q = np.zeros(self.max_seq, dtype=int)
        qa = np.zeros(self.max_seq, dtype=int)

        if seq_len >= self.max_seq:
            # Training: 90% of the time take a random max_seq-long window,
            # otherwise (and always outside training) the most recent one.
            if training and random.random() > 0.1:
                start = random.randint(0, seq_len - self.max_seq)
                end = start + self.max_seq
                q[:] = q_[start:end]
                qa[:] = qa_[start:end]
            else:
                q[:] = q_[-self.max_seq:]
                qa[:] = qa_[-self.max_seq:]
        else:
            # Training: 90% of the time keep only a random-length prefix (at
            # least 2 interactions). This augmentation is restricted to the
            # training subset, mirroring the branch above; previously it was
            # also applied during validation, which made the validation
            # metrics depend on the RNG state.
            if training and random.random() > 0.1:
                seq_len = random.randint(2, seq_len)
            # Right-align the kept interactions; the left part stays zero.
            q[-seq_len:] = q_[:seq_len]
            qa[-seq_len:] = qa_[:seq_len]

        target_id = q[1:]  # questions to predict: drop the first item
        label = qa[1:]     # their answers

        x = q[:-1].copy()  # preceding interactions: drop the last item
        x += (qa[:-1] == 1) * self.n_skill  # y = et + rt x E

        return x, target_id, label
    
    
def valid_epoch(model, valid_iterator, criterion, device="cuda"):
    """Evaluate ``model`` on ``valid_iterator``.

    The loss is computed over every sequence position of each batch, while
    accuracy and AUC only use the prediction at the last position.

    Args:
        model: callable returning ``(logits, attention_weights)``.
        valid_iterator: iterable of ``(x, target_id, label)`` batches.
        criterion: loss applied to ``(logits, label)``.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(mean batch loss, accuracy, AUC)``.
    """
    model.eval()

    batch_losses = []
    correct = 0
    seen = 0
    all_labels = []
    all_probs = []

    with tqdm(total=len(valid_iterator)) as pbar:
        for batch in valid_iterator:
            x = batch[0].to(device).long()
            target_id = batch[1].to(device).long()
            label = batch[2].to(device).float()

            with torch.no_grad():
                logits, _ = model(x, target_id)
            batch_losses.append(criterion(logits, label).item())

            # Only the final position of each sequence is scored.
            last_prob = torch.sigmoid(logits[:, -1])
            last_label = label[:, -1]
            last_pred = (last_prob >= 0.5).long()

            correct += (last_pred == last_label).sum().item()
            seen += len(last_label)

            all_labels.extend(last_label.view(-1).data.cpu().numpy())
            all_probs.extend(last_prob.view(-1).data.cpu().numpy())

            pbar.set_description(f'val loss batch - {batch_losses[-1]:.4f}')
            pbar.update(1)

    acc = correct / seen
    auc = roc_auc_score(all_labels, all_probs)
    loss = np.mean(batch_losses)

    return loss, acc, auc

In [None]:
if VALID:
    # Load the cv3 validation split, reduced to the training columns.
    valid_df = pd.read_parquet(
        '../input/cv-strategy-in-the-kaggle-environment/cv3_valid.parquet'
    )[TRAIN_COLS]

    # Same row filter as applied to the training frame.
    valid_df = valid_df[valid_df["content_type_id"] == False]

    # Per-user (content_id array, answered_correctly array) tuples.
    valid_group = (
        valid_df[['user_id', 'content_id', 'answered_correctly']]
        .groupby('user_id')
        .apply(lambda rows: (rows['content_id'].values,
                             rows['answered_correctly'].values))
    )

    # The flat frame is no longer needed once the histories are built.
    del valid_df
    gc.collect();

In [None]:
if VALID:
    # Loss on raw logits; moved to the same device as the model.
    criterion = nn.BCEWithLogitsLoss()
    criterion.to(device)

    # Any subset other than "train" is passed so the dataset does not take
    # its training-only random-window path for long histories.
    valid_dataset = SAKTDataset(valid_group, n_skill, subset="valid")
    val_loader = DataLoader(
        valid_dataset,
        batch_size=VAL_BATCH_SIZE,
        shuffle=False,
    )

    val_loss, val_acc, val_auc = valid_epoch(model, val_loader, criterion, device=device)
    print(f"Valid: loss - {val_loss:.2f} acc - {val_acc:.4f} auc - {val_auc:.4f}")

## Test

`skill` is no longer passed to this dataset, as only the number of embeddings is needed.

In [None]:
class TestDataset(Dataset):
    """One model input per row of a test batch.

    Args:
        samples: object with an ``.index`` of user ids that returns a
            ``(question_ids, answers)`` pair of arrays for ``samples[user_id]``
            (the notebook passes the per-user history Series).
        test_df: frame with at least ``user_id`` and ``content_id`` columns;
            each row is one question to predict.
        n_skill: offset added to a question id when it was answered correctly.
        max_seq: length of the padded history window; returned arrays have
            ``max_seq - 1`` elements.
    """

    def __init__(self, samples, test_df, n_skill, max_seq=MAX_SEQ):
        super().__init__()
        self.samples = samples
        self.user_ids = list(test_df["user_id"].unique())
        self.test_df = test_df
        self.n_skill = n_skill
        self.max_seq = max_seq

    def __len__(self):
        return self.test_df.shape[0]

    def __getitem__(self, index):
        """Return ``(x, questions)`` for test row ``index``.

        ``x`` encodes the last ``max_seq - 1`` known interactions of the
        row's user as ``question_id + n_skill * (answer == 1)``;
        ``questions`` holds the matching target question ids and ends with
        the row's own ``content_id``. Users without history get all-zero
        padding.
        """
        test_info = self.test_df.iloc[index]

        user_id = test_info["user_id"]
        target_id = test_info["content_id"]

        q = np.zeros(self.max_seq, dtype=int)
        qa = np.zeros(self.max_seq, dtype=int)

        if user_id in self.samples.index:
            q_, qa_ = self.samples[user_id]

            # Copy the most recent interactions into the int buffers instead
            # of rebinding q/qa to the stored arrays: the result then has the
            # same dtype whatever the history length, and the n_skill offset
            # below is never added in the history's narrower dtype.
            recent = min(len(q_), self.max_seq)
            if recent:
                q[-recent:] = q_[-recent:]
                qa[-recent:] = qa_[-recent:]

        # Inputs: interactions 1 .. max_seq-1 of the window.
        x = q[1:].copy()
        x += (qa[1:] == 1) * self.n_skill

        # Targets: each input's following question; the final target is the
        # question of this test row.
        questions = np.append(q[2:], [target_id])

        return x, questions

In [None]:
import riiideducation

# Competition environment object; predictions are submitted through
# `env.predict` in the inference loop below.
env = riiideducation.make_env()
# Iterator consumed by the inference loop; it is unpacked there as
# (test_df, sample_prediction_df) pairs.
iter_test = env.iter_test()

In [None]:
import ast

model.eval()

# Previous test batch; its answers only arrive together with the next batch.
prev_test_df = None

for (test_df, sample_prediction_df) in iter_test:
    # Fold the previous batch into the per-user histories. The update is
    # skipped when RAM usage reaches 95% so the kernel does not run out of
    # memory; predictions then use the histories collected so far.
    if prev_test_df is not None and psutil.virtual_memory().percent < 95:
        print(psutil.virtual_memory().percent)
        # The first row of the new batch carries the previous batch's answers
        # as the text of a Python list. ast.literal_eval only accepts
        # literals, so this externally supplied text is parsed without being
        # executed (this call used to be a plain `eval`).
        prev_test_df['answered_correctly'] = ast.literal_eval(
            test_df['prior_group_answers_correct'].iloc[0])
        prev_test_df = prev_test_df[prev_test_df.content_type_id == False]
        prev_group = prev_test_df[['user_id', 'content_id', 'answered_correctly']].groupby('user_id').apply(lambda r: (
            r['content_id'].values,
            r['answered_correctly'].values))
        for prev_user_id in prev_group.index:
            prev_group_content, prev_group_ac = prev_group[prev_user_id]
            if prev_user_id in group.index:
                # Known user: append the new interactions to the history.
                old_content, old_ac = group[prev_user_id]
                new_group_content = np.append(old_content, prev_group_content)
                new_group_ac = np.append(old_ac, prev_group_ac)
            else:
                # First time this user is seen.
                new_group_content = prev_group_content
                new_group_ac = prev_group_ac
            # Keep at most the MAX_SEQ most recent interactions per user.
            group[prev_user_id] = (new_group_content[-MAX_SEQ:],
                                   new_group_ac[-MAX_SEQ:])

    prev_test_df = test_df.copy()

    # Predict only rows with content_type_id == 0. The explicit copy makes
    # the filtered frame independent, so adding the prediction column below
    # is not a chained assignment on a slice.
    test_df = test_df[test_df.content_type_id == 0].copy()

    test_dataset = TestDataset(group, test_df, n_skill)
    test_dataloader = DataLoader(test_dataset, batch_size=51200, shuffle=False)

    outs = []

    for item in test_dataloader:
        x = item[0].to(device).long()
        target_id = item[1].to(device).long()

        with torch.no_grad():
            output, _ = model(x, target_id)

        # The last sequence position corresponds to the row being predicted.
        output = torch.sigmoid(output[:, -1])
        outs.extend(output.view(-1).cpu().numpy())

    test_df['answered_correctly'] = outs

    env.predict(test_df.loc[test_df['content_type_id'] == 0, ['row_id', 'answered_correctly']])

## Sanity check
If the distribution of predicted probabilities is not skewed toward 1, your model is probably problematic.

In [None]:
# Apply seaborn's default plot styling.
sns.set()
# Read back the submission file from the working directory.
sub = pd.read_csv('../working/submission.csv')
# Histogram of the predicted probabilities; the trailing semicolon suppresses
# the notebook's text echo of the returned object.
sub['answered_correctly'].hist();