## Preliminaries

In [None]:
import os
import sys
import math
import pickle
import psutil
import random

import json
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd

import riiideducation

In [None]:
# Seed the Python and PyTorch random number generators with the same value.
seed = 0
torch.random.manual_seed(seed)
random.seed(seed)

# Runtime resources: prefer the GPU when one is visible; the worker count
# (one per logical CPU) is handed to the DataLoader further down.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
n_workers = os.cpu_count()

# Input artefacts shipped with the Kaggle dataset.
model_path = '/kaggle/input/riiid-mydata/SAKT_08-12_06-24.pt'
states_path = '/kaggle/input/riiid-mydata/states.pickle'
tag_path = '/kaggle/input/riiid-mydata/tags.csv'
train_path = '/kaggle/input/riiid-mydata/train.pkl'
cfg_path = '/kaggle/input/riiid-mydata/cfg.json'

# Batch size, maximum sequence length and embedding width used at inference.
B, MAX_LEN, D_MODEL = 512, 256, 256

## SAKT Model

In [None]:
class FFN(nn.Module):
    """Two-layer feed-forward block that keeps the feature width unchanged.

    The input goes through Linear -> ReLU -> Dropout -> Linear, where both
    linear layers map ``d_model`` features to ``d_model`` features.
    """

    def __init__(self, d_model, dropout=0.0):
        """Create the layers.

        Args:
            d_model: size of the last dimension of the input and the output.
            dropout: drop probability applied after the activation.
        """
        super().__init__()
        # The attribute names double as state-dict keys, so they must stay
        # stable for saved checkpoints to keep loading.
        self.lr1 = nn.Linear(d_model, d_model)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the transformed tensor; the shape of ``x`` is preserved."""
        hidden = self.dropout(self.relu(self.lr1(x)))
        return self.lr2(hidden)


class SAKTModel(nn.Module):
    """Single-block attention model that scores the exercise at each position.

    Exercise embeddings act as attention queries, while the embedded past
    interactions (plus a learned position embedding) act as keys and values.
    The attended features go through a residual + LayerNorm, a feed-forward
    block with a second residual + LayerNorm, and a final linear layer that
    emits one logit per position.
    """

    def __init__(self, n_exercises, max_len=160, d_model=200, dropout=0.2):
        """Build the embeddings, attention layer and prediction head.

        Args:
            n_exercises: number of distinct exercises; the interaction table
                has ``2 * n_exercises + 1`` rows and the exercise table
                ``n_exercises + 1`` rows.
            max_len: number of rows in the position embedding table.
            d_model: embedding width (the attention layer uses 8 heads).
            dropout: drop probability shared by the attention and FFN layers.
        """
        super().__init__()
        self._n_exercises = n_exercises

        # Attribute names are state-dict keys; keep them (and their creation
        # order) untouched so existing checkpoints still load.
        self.position_embd = nn.Embedding(max_len, d_model)
        self.interaction_embd = nn.Embedding(2 * n_exercises + 1, d_model)
        self.exercise_embd = nn.Embedding(n_exercises + 1, d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads=8, dropout=dropout)
        self.ffn = FFN(d_model, dropout=dropout)
        self.predict = nn.Linear(d_model, 1)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    @property
    def n_exercises(self):
        """Number of exercises the model was built for (read-only)."""
        return self._n_exercises

    def forward(self, x, e, attn_mask=None):
        """Run the model.

        Args:
            x: interaction ids, indexed as (batch, seq_len).
            e: exercise ids with the same layout as ``x``.
            attn_mask: optional mask forwarded to the attention layer.

        Returns:
            A pair ``(logits, attn_weights)`` where ``logits`` has one value
            per position and ``attn_weights`` is whatever
            ``nn.MultiheadAttention`` returns as its second output.
        """
        positions = torch.arange(x.shape[1], device=x.device).unsqueeze(0)

        # The attention layer is sequence-first, hence (B, L, d) -> (L, B, d).
        memory = (self.interaction_embd(x) + self.position_embd(positions)).transpose(0, 1)
        queries = self.exercise_embd(e).transpose(0, 1)

        attended, attn_weights = self.attn(queries, memory, memory, attn_mask=attn_mask)

        # Residual on the queries, then back to batch-first: (L, B, d) -> (B, L, d).
        hidden = self.norm1(attended + queries).transpose(0, 1)
        hidden = self.norm2(hidden + self.ffn(hidden))
        logits = self.predict(hidden).squeeze(-1)
        return logits, attn_weights

## Testing Phase

In [None]:
class KaggleOnlineDataset(Dataset):
    """Dataset over the current test batch, joined with each user's stored history.

    ``update()`` must be called with every new test batch. It folds the labels
    of the previous batch into the per-user history, derives the ``lag`` and
    ``prior_elapsed`` features for the new batch, and makes the new rows
    indexable. Each item is a dict of fixed-length (``max_len``) arrays: the
    user's most recent history followed by the new row and zero padding.
    """

    def __init__(self, train_path, tag_path, states_path, n_exercises, cols, max_len):
        """Load the stored state.

        Args:
            train_path: pickle read with ``pd.read_pickle``; indexed by user id,
                each value being a tuple of per-column history arrays ordered
                like ``cols``.
            tag_path: CSV with one row per exercise (``exercise_id`` must equal
                ``0 .. n_exercises - 1`` in order).
            states_path: pickle holding the tuple
                ``(last_states, exercise_id_to_bundle, bundle_id_to_size)``.
            n_exercises: number of exercises.
            cols: names of the history columns, in storage order.
            max_len: length every returned sequence is padded/truncated to.
        """
        super().__init__()
        # NOTE(review): both pickle files are deserialised without validation;
        # only point these paths at trusted artefacts.
        self.df = pd.read_pickle(train_path)
        self.test_df = None
        self.was_exercise = None
        tag_df = pd.read_csv(tag_path, usecols=['exercise_id', 'bundle_id', 'part', 'correct_rate', 'frequency'])
        assert np.all(tag_df['exercise_id'].values == np.arange(n_exercises))
        self.parts, self.correct_rate, self.frequency = tag_df[['part', 'correct_rate', 'frequency']].values.T
        # Use a context manager so the file handle is closed deterministically.
        with open(states_path, 'rb') as states_file:
            self.lag_info = pickle.load(states_file)
        self.n_exercises = n_exercises
        self.cols = cols
        self.max_len = max_len

    def __len__(self):
        """Return the number of exercise rows in the current test batch."""
        assert self.test_df is not None, 'Please call update() first'
        return len(self.test_df)

    def __getitem__(self, idx):
        """Build the padded feature dict for row ``idx`` of the current batch.

        Returns:
            A dict with one ``max_len``-long array per entry of ``cols`` plus
            ``'part'`` and ``'correct_rate'``, and ``'valid_len'`` (a length-1
            array holding the number of non-padding positions).
        """
        new_observation = self.test_df.iloc[idx]
        user_id = new_observation['user_id']
        # columns missing from the test row (e.g. 'correct') are set to 0 temporarily
        new_data = {col: np.array([new_observation.get(col, 0)]) for col in self.cols}
        # retrieve old observations
        if user_id in self.df.index:
            old_items = self.df[user_id]
            old_len = min(len(old_items[0]), self.max_len - 1)
            # ``arr[-0:]`` is the whole array, so the empty case needs its own slice
            data = {
                key: np.append(old_item[-old_len:] if old_len else old_item[:0], new_data[key])
                for key, old_item in zip(self.cols, old_items)
            }
        else:
            old_len = 0
            data = new_data
        seq_len = old_len + 1
        # retrieve additional per-exercise features
        data['part'] = self.parts[data['exercise_id']]
        data['correct_rate'] = self.correct_rate[data['exercise_id']]
        # pad to max_len and set dtype
        dtype_map = {key: int for key in list(self.cols) + ['part']}
        dtype_map['correct_rate'] = float
        data = KaggleOnlineDataset._postpad_and_asdtype(data, self.max_len - seq_len, dtype_map)
        data['valid_len'] = np.array([seq_len], dtype=int)
        return data

    @staticmethod
    def _postpad_and_asdtype(data, pad, dtype_map):
        """Right-pad every 1-D array in ``data`` with ``pad`` zeros and cast it."""
        return {
            key: np.pad(item, [[0, pad]]).astype(dtype_map[key]) for key, item in data.items()
        }

    def update(self, test_df):
        """Register a new test batch and absorb the labels of the previous one.

        Args:
            test_df: raw batch from the competition iterator. Its first row's
                ``prior_group_answers_correct`` holds the labels of the
                previous batch as a list literal.

        Returns:
            The exercise-only rows of ``test_df`` with ``content_id`` renamed
            to ``exercise_id``, ``prior_question_elapsed_time`` renamed to
            ``prior_elapsed`` (converted to seconds), a new ``lag`` column
            (minutes) and a fresh 0-based index.
        """
        # History is only grown while memory usage stays below 90 %.
        if self.test_df is not None and psutil.virtual_memory().percent < 90:
            import ast

            # update df according to previous labels
            prev_df = self.test_df
            # NOTE: this string comes from outside the notebook, so it is parsed
            # with literal_eval (literals only) instead of eval.
            prev_answers = ast.literal_eval(test_df.iloc[0]['prior_group_answers_correct'])
            prev_df['correct'] = np.array(prev_answers)[self.was_exercise]
            self._append_history(prev_df)
            # update correct rate
            # self._update_correct_rate(prev_df)
        # process test_df
        is_exercise = (test_df['content_type_id'] == 0)
        test_df = test_df.loc[is_exercise].rename(
            columns={'content_id': 'exercise_id', 'prior_question_elapsed_time': 'prior_elapsed'}
        )
        # compute lag and convert ms -> min
        test_df['prior_elapsed'] = test_df['prior_elapsed'].fillna(0).astype(int)
        lag = self._compute_new_lag(test_df)
        # a positive lag shorter than one minute is bucketed as 1 rather than rounded to 0
        test_df['lag'] = np.where(
            np.logical_and(0 < lag, lag < 60 * 1000), 1, np.round(lag / (1000 * 60))
        ).astype(int)
        # as for prior_elapsed, convert ms -> s (positive sub-second values become 1)
        prior_elapsed = test_df['prior_elapsed'].values
        test_df['prior_elapsed'] = np.where(
            np.logical_and(0 < prior_elapsed, prior_elapsed < 1000), 1, np.round(prior_elapsed / 1000)
        ).astype(int)
        test_df = test_df.reset_index(drop=True)
        # save relevant information for the next call
        self.test_df = test_df
        self.was_exercise = is_exercise.values
        return test_df

    def _append_history(self, prev_df):
        """Append the labelled rows of ``prev_df`` to each user's stored history."""
        for user_id, user_rows in prev_df.groupby('user_id'):
            new_items = tuple(user_rows[col].values for col in self.cols)
            if user_id in self.df.index:
                # A user may contribute several rows per batch, so keep the last
                # max_len entries of the merged sequence (bounded to prevent OOM).
                new_items = tuple(
                    np.append(old_item, new_item)[-self.max_len:]
                    for old_item, new_item in zip(self.df[user_id], new_items)
                )
            self.df[user_id] = new_items

    def _update_correct_rate(self, prev_df):
        """Fold the labelled rows of ``prev_df`` into ``correct_rate``/``frequency``.

        Exercises that have never been seen (frequency 0) get a rate of 0.5.
        """
        exercise_ids = prev_df['exercise_id'].to_numpy(dtype=int)
        was_correct = prev_df['correct'].to_numpy(dtype=float)
        # per-exercise counts for this batch; exercises not in the batch get 0
        new_frequency = np.bincount(exercise_ids, minlength=self.n_exercises)
        new_correct = np.bincount(exercise_ids, weights=was_correct, minlength=self.n_exercises)
        n_correct = new_correct + np.round(self.correct_rate * self.frequency)
        self.frequency = self.frequency + new_frequency
        with np.errstate(divide='ignore', invalid='ignore'):
            correct_rate = n_correct / self.frequency
        self.correct_rate = np.where(np.isfinite(correct_rate), correct_rate, 0.5)

    def _compute_new_lag(self, df):
        """Return the per-row lag (same unit as ``timestamp``) and advance user state.

        For each row, the lag is the time since the user's last recorded
        bundle, minus ``prior_elapsed`` times the size of that bundle. Rows in
        the same bundle as the last recorded one, and a user's first row, get
        a lag of 0. Negative values are clipped to 0.
        """
        last_states, exercise_id_to_bundle, bundle_id_to_size = self.lag_info
        # compute_lag from the original implementation
        lag = np.zeros(len(df))
        for i, (user_id, curr_timestamp, curr_exercise_id, prior_elapsed) in enumerate(
            df[['user_id', 'timestamp', 'exercise_id', 'prior_elapsed']].values
        ):
            curr_bundle_id = exercise_id_to_bundle[curr_exercise_id]
            last_state = last_states.get(user_id, None)
            if last_state is None:
                last_states[user_id] = (curr_timestamp, curr_bundle_id)
                lag[i] = 0
            else:
                last_timestamp, last_bundle_id = last_state
                if curr_bundle_id == last_bundle_id:
                    # same bundle, do not update last_states
                    lag[i] = 0
                else:
                    last_states[user_id] = (curr_timestamp, curr_bundle_id)
                    elapsed_offset = bundle_id_to_size[last_bundle_id] * prior_elapsed
                    lag[i] = curr_timestamp - last_timestamp - elapsed_offset
        lag = np.clip(lag, a_min=0, a_max=None)
        self.lag_info = (last_states, exercise_id_to_bundle, bundle_id_to_size)
        return lag

In [None]:
def truncate_and_prepare_masks(items, valid_len, need_pad_mask=True, need_attn_mask=True):
    """Cut a batch down to its longest valid sequence and build attention masks.

    Args:
        items: list of batch-first tensors (or ``None`` placeholders) that are
            sliced along dimension 1.
        valid_len: tensor of per-sample valid lengths; it is broadcast against
            a row of positions, e.g. shape (B, 1) as the caller passes it.
        need_pad_mask: build the padding mask when True.
        need_attn_mask: build the causal attention mask when True.

    Returns:
        ``(out, pad_mask, attn_mask)``: the truncated items (``None`` entries
        stay ``None``), a boolean mask that is True at positions at or beyond
        each sample's valid length, and a boolean square mask that is True
        strictly above the diagonal. A mask that was not requested is ``None``.
    """
    longest = valid_len.max()
    device = longest.device
    n_steps = int(longest)

    # Drop the columns that are padding for every sample in the batch.
    out = []
    for item in items:
        out.append(item[:, :n_steps] if item is not None else None)

    pad_mask = None
    if need_pad_mask:
        pad_mask = torch.arange(n_steps, device=device) >= valid_len

    # Query and key lengths are assumed equal, hence the square mask.
    attn_mask = None
    if need_attn_mask:
        attn_mask = torch.ones(n_steps, n_steps).triu(diagonal=1).to(device, torch.bool)

    return out, pad_mask, attn_mask

In [None]:
# Load the run configuration; the context manager closes the file handle.
with open(cfg_path, 'r') as cfg_file:
    cfg = json.load(cfg_file)

# Restore the trained model. map_location lets a checkpoint that was saved on
# a GPU load when `device` falls back to the CPU.
model = SAKTModel(cfg['n_exercises'], max_len=MAX_LEN, d_model=D_MODEL).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
testset = KaggleOnlineDataset(train_path, tag_path, states_path, cfg['n_exercises'], cfg['cols'], MAX_LEN)

env = riiideducation.make_env()
iter_test = env.iter_test()
for test_df, _ in iter_test:
    # Absorb the previous batch's labels and expose the new exercise rows.
    test_df = testset.update(test_df)
    testloader = DataLoader(testset, batch_size=B, shuffle=False, num_workers=n_workers, drop_last=False)
    batch_probs = []
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        for data in testloader:
            valid_len = data['valid_len'].to(device, torch.long)
            # Shift exercises and answers right by one step so that position t
            # only sees interactions before t; a correct answer is encoded by
            # offsetting the exercise id by n_exercises.
            interactions = (
                F.pad(data['exercise_id'][:, :-1], (1, 0))
                + F.pad(data['correct'][:, :-1], (1, 0)) * cfg['n_exercises']
            )
            inputs, _, attn_mask = truncate_and_prepare_masks(
                [
                    interactions.to(device, torch.long),
                    data['exercise_id'].to(device, torch.long)
                ],
                valid_len,
                need_pad_mask=False
            )
            out, _ = model(*inputs, attn_mask=attn_mask)
            # Keep only the logit of the last valid position, i.e. the new row.
            out = torch.gather(out, 1, valid_len - 1).squeeze(-1)
            batch_probs.append(torch.sigmoid(out).cpu().numpy())
    # A batch without exercise rows yields no predictions at all.
    outs = np.concatenate(batch_probs) if batch_probs else np.array([], dtype='float32')
    test_df['answered_correctly'] = outs
    env.predict(test_df[['row_id', 'answered_correctly']])