## Preliminaries

In [None]:
import os
import sys
import math
import pickle
import psutil
import random

import json
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd

import riiideducation

In [None]:
# Seed every RNG the notebook may touch so runs are reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)

# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to 0 (load in the main process) so DataLoader(num_workers=...) still works.
n_workers = os.cpu_count() or 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Inference artefacts.
cfg_path = '/kaggle/input/riiid-mydata/cfg.json'
train_path = '/kaggle/input/riiid-mydata/train.pkl'
tag_path = '/kaggle/input/riiid-mydata/tags.csv'
states_path = '/kaggle/input/riiid-mydata/states.pickle'
model_path = '/kaggle/input/riiid-mydata/aPFA_08-12_18-18.pt'

# Hyper-parameters; the model ones must match the checkpoint at model_path,
# since its state_dict is loaded into a model built from them.
B = 512  # DataLoader batch size
MAX_LEN = 128  # max history length per user, including the item being predicted
SEQ_LEN = 128  # chunk length fed to the model per forward pass
MAX_LAG = 30 * 7 * 24 * 60  # clamp for accumulated lag; lags are fed in minutes
N_LAYERS = 0  # number of encoder self-attention layers
N_HEADS = 1
D_MODEL = 256
IS_LITE = True  # lite variant: no Q/K/V projections, LayerNorms or FFNs

## aPFA Model

In [None]:
class FFN(nn.Module): 
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Both linear layers map ``d_model`` to ``d_model``.

    Args:
        d_model: input/output (and hidden) width.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, dropout=0.0): 
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable so saved
        # checkpoints keep loading.
        self.lr1 = nn.Linear(d_model, d_model)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x): 
        hidden = self.dropout(self.relu(self.lr1(x)))
        return self.lr2(hidden)


class AIKTMultiheadAttention(nn.Module): 
    """Multi-head scaled dot-product attention on sequence-first inputs.

    Inputs are laid out as (len, B, d_model). With ``use_proj=False`` the Q/K/V and
    output projections are identities and ``d_qkv`` is forced to
    ``d_model // n_heads``.

    Args:
        d_model: width of the query/key/value inputs.
        n_heads: number of attention heads.
        d_qkv: per-head width when ``use_proj`` is True; defaults to
            ``d_model // n_heads``.
        use_proj: whether to learn linear Q/K/V and output projections.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads=8, d_qkv=None, use_proj=True, dropout=0.1): 
        super().__init__()
        if d_qkv is None or not use_proj: 
            assert d_model % n_heads == 0
            d_qkv = d_model // n_heads
        d_inner = d_qkv * n_heads
        # NOTE(review): scaled by d_model rather than the per-head d_qkv. Kept as-is
        # because trained checkpoints depend on it; the two coincide for a single
        # head with the default d_qkv.
        self.scale = d_model ** (-0.5)
        if use_proj: 
            self.Q = nn.Linear(d_model, d_inner)
            self.K = nn.Linear(d_model, d_inner)
            self.V = nn.Linear(d_model, d_inner)
            self.out_proj = nn.Linear(d_inner, d_model)
        else: 
            self.Q = self.K = self.V = self.out_proj = lambda x: x
        # (len, B, d_inner) -> (len, B, n_heads, d_qkv)
        self.reshape_for_attn = lambda x: x.reshape(*x.shape[:-1], n_heads, d_qkv)
        # (len, B, n_heads, d_qkv) -> (len, B, d_inner)
        self.recover_from_attn = lambda x: x.reshape(*x.shape[:-2], d_inner)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, attn_mask=None, rel_pos_embd=None): 
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (q_len, B, d_model) tensor.
            key, value: (k_len, B, d_model) tensors.
            attn_mask: optional bool (q_len, k_len) tensor; True marks positions
                that must not be attended to.
            rel_pos_embd: optional additive bias broadcastable to the
                (q_len, k_len, B, n_heads) score tensor.

        Returns:
            ``(out, attn_weights)``: ``out`` is (q_len, B, d_model) and
            ``attn_weights`` is the head-averaged (B, q_len, k_len) attention.
        """
        # in-sample projection, then split heads (identities when use_proj=False).
        # NOTE(review): earlier versions skipped Q/K/V here; checkpoints trained that
        # way with use_proj=True will not reproduce their old outputs.
        query, key, value = (
            self.reshape_for_attn(self.Q(query)), 
            self.reshape_for_attn(self.K(key)), 
            self.reshape_for_attn(self.V(value))
        )
        # self-attention mechanism: scores are (q_len, k_len, B, n_heads)
        attn_scores = torch.einsum('ibnd,jbnd->ijbn', (query, key)) * self.scale
        if rel_pos_embd is not None: 
            attn_scores += rel_pos_embd
        if attn_mask is not None: 
            assert attn_mask.dtype == torch.bool, 'Only bool type is supported for masks.'
            assert attn_mask.ndim == 2, 'Only 2D attention mask is supported'
            assert attn_mask.shape == attn_scores.shape[:2], 'Incorrect mask shape: {}. Expect: {}'.format(attn_mask.shape, attn_scores.shape[:2])
            # turn True entries into -inf additive biases, broadcast over (B, n_heads)
            mask = torch.zeros_like(attn_mask, dtype=torch.float)
            mask.masked_fill_(attn_mask, float('-inf'))
            mask = mask.view(*mask.shape, 1, 1)
            attn_scores += mask
        # normalise over the key axis
        attn_weights = torch.softmax(attn_scores, dim=1)
        attn_weights = self.dropout(attn_weights)
        out = torch.einsum('ijbn,jbnd->ibnd', (attn_weights, value))
        out = self.recover_from_attn(out)
        # output layer
        out = self.out_proj(out)
        # average over heads: (q_len, k_len, B) -> (B, q_len, k_len)
        attn_weights = attn_weights.mean(-1).permute(-1, 0, 1)
        return out, attn_weights


class InducedMultiheadAttention(nn.Module): 
    """Attention whose query is a single learned vector per head.

    Each key position gets one score per head (independent of the query position);
    the scores are then tiled over a query axis of length k_len so that per-pair
    biases (``rel_pos_embd``) and a 2-D mask can be added.

    Args:
        d_model: width of the key/value inputs.
        n_heads: number of attention heads.
        d_qkv: per-head width when ``use_proj`` is True; defaults to
            ``d_model // n_heads``.
        use_proj: whether to learn linear K/V and output projections.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, d_qkv=None, use_proj=False, dropout=0.1): 
        super().__init__()
        if not use_proj or d_qkv is None: 
            assert d_model % n_heads == 0
            d_qkv = d_model // n_heads
        d_inner = n_heads * d_qkv
        self.scale = d_model ** (-0.5)
        # learned per-head query of shape (n_heads, d_qkv)
        self.Q = nn.Parameter(torch.randn(n_heads, d_qkv))
        if not use_proj: 
            identity = lambda t: t
            self.K = self.V = self.out_proj = identity
        else: 
            self.K = nn.Linear(d_model, d_inner)
            self.V = nn.Linear(d_model, d_inner)
            self.out_proj = nn.Linear(d_inner, d_model)
        # split heads: (len, B, d_inner) -> (len, B, n_heads, d_qkv)
        self.reshape_for_attn = lambda t: t.reshape(*t.shape[:-1], n_heads, d_qkv)
        # merge heads: (len, B, n_heads, d_qkv) -> (len, B, d_inner)
        self.recover_from_attn = lambda t: t.reshape(*t.shape[:-2], d_inner)
        self.dropout = nn.Dropout(dropout)

    def forward(self, key, value, attn_mask=None, rel_pos_embd=None): 
        """Attend with the learned query to ``key``/``value``.

        Args:
            key, value: (k_len, B, d_model) tensors.
            attn_mask: optional bool (k_len, k_len) tensor; True = blocked.
            rel_pos_embd: optional additive bias broadcastable to the
                (k_len, k_len, B, n_heads) score tensor.

        Returns:
            ``(out, attn_weights)``: ``out`` is (k_len, B, d_model) and
            ``attn_weights`` is the head-averaged (B, k_len, k_len) attention.
        """
        keys = self.reshape_for_attn(self.K(key))
        values = self.reshape_for_attn(self.V(value))
        # one score per key position and head: (k_len, B, n_heads)
        scores = torch.einsum('nd,jbnd->jbn', (self.Q, keys)) * self.scale
        # tile over a query axis of the same length: (k_len, k_len, B, n_heads)
        n_pos = scores.shape[0]
        scores = scores.unsqueeze(0).repeat(n_pos, 1, 1, 1)
        if rel_pos_embd is not None: 
            scores += rel_pos_embd
        if attn_mask is not None: 
            assert attn_mask.dtype == torch.bool, 'Only bool type is supported for masks.'
            assert attn_mask.ndim == 2, 'Only 2D attention mask is supported'
            assert attn_mask.shape == scores.shape[:2], 'Incorrect mask shape: {}. Expect: {}'.format(attn_mask.shape, scores.shape[:2])
            additive = torch.zeros_like(attn_mask, dtype=torch.float).masked_fill_(attn_mask, float('-inf'))
            scores += additive.view(*additive.shape, 1, 1)
        weights = self.dropout(torch.softmax(scores, dim=1))
        merged = self.recover_from_attn(torch.einsum('ijbn,jbnd->ibnd', (weights, values)))
        return self.out_proj(merged), weights.mean(-1).permute(-1, 0, 1)


class APFAModel(nn.Module): 
    """aPFA knowledge-tracing model: exercise encoder + induced-attention decoder.

    The decoder attends over (memory, previous interactions) with a per-head bias
    looked up from a bucketised relative lag, and predicts each position's logit
    as a dot product with the (tied) embedding of the exercise being answered.

    Args:
        n_exercises: vocabulary size of exercise ids.
        max_lag: clamp for the accumulated lag between positions.
        n_layers: number of encoder self-attention layers (may be 0).
        d_model: model width.
        n_heads: number of attention heads.
        is_lite: drop projections, LayerNorms and FFNs everywhere.
        dropout: dropout probability.
    """

    def __init__(self, n_exercises, max_lag, n_layers, d_model, n_heads=8, is_lite=False, dropout=0.1): 
        super().__init__()
        self.exercise_embd = nn.Embedding(n_exercises, d_model)
        self.correct_embd = nn.Embedding(2, d_model)
        self.max_lag = max_lag
        # half of the buckets hold small lags exactly, the other half are log-spaced
        self.n_lag_buckets = 2 * math.ceil(math.log(max_lag))
        self.rel_lag_embd = nn.Embedding(self.n_lag_buckets, n_heads)

        self.is_lite = is_lite
        # each layer: [self-attention, LayerNorm, FFN, LayerNorm] (None entries in lite mode)
        self.enc = nn.ModuleList([
            nn.ModuleList([
                AIKTMultiheadAttention(d_model, n_heads, use_proj=not is_lite, dropout=dropout), 
                None if is_lite else nn.LayerNorm(d_model), 
                None if is_lite else FFN(d_model, dropout=dropout), 
                None if is_lite else nn.LayerNorm(d_model)
            ]) for _ in range(n_layers)
        ])
        self.init_mem = nn.Parameter(torch.randn(d_model))
        self.dec = InducedMultiheadAttention(d_model, n_heads, use_proj=not is_lite, dropout=dropout)
        self.ln1 = None if is_lite else nn.LayerNorm(d_model)
        self.ffn = None if is_lite else FFN(d_model, dropout=dropout)
        self.ln2 = None if is_lite else nn.LayerNorm(d_model)

        self.predict = nn.Linear(d_model, n_exercises)
        self.predict.weight = self.exercise_embd.weight # tie weights
        self.dropout = nn.Dropout(dropout)

    def lag_to_bucket(self, lag_time): 
        """Map per-position lags to relative-lag bucket ids.

        Args:
            lag_time: (B, L) integer tensor of lags between consecutive interactions.

        Returns:
            (q_len, k_len, B) long tensor with values in ``[0, n_lag_buckets)``.
        """
        n_exact = self.n_lag_buckets // 2
        acc_lag_time = torch.cumsum(lag_time, dim=-1).unsqueeze(-1)
        # rel[b, i, j] = time from interaction j to interaction i, clamped to [0, max_lag]
        rel_lag_time = torch.clamp(
            acc_lag_time - acc_lag_time.transpose(-1, -2), min=0, max=self.max_lag
        )
        rel_lag_time = torch.cat(
            [rel_lag_time[:, :, :1], rel_lag_time[:, :, :-1]], dim=-1
        ) # right shift by 1 along the dimension of k_len
        # Only lags >= n_exact use log-spaced buckets. Clamp before the log so short
        # lags don't yield log(0) = -inf, whose cast to long is undefined; torch.where
        # discards those entries anyway.
        long_lag_time = torch.clamp(rel_lag_time, min=n_exact)
        buckets_for_long_lag = n_exact - 1 + torch.ceil(
            torch.log(long_lag_time / n_exact) / math.log(self.max_lag / n_exact) * (self.n_lag_buckets - n_exact)
        )
        buckets = torch.where(rel_lag_time < n_exact, rel_lag_time, buckets_for_long_lag.long())
        # At rel_lag_time == max_lag the log ratio is 1 mathematically, but float
        # rounding can push ceil() one past the last bucket and overflow the embedding.
        buckets = torch.clamp(buckets, max=self.n_lag_buckets - 1)
        return buckets.permute(1, 2, 0) # (q_len, k_len, B)

    def forward(self, e, c, lt, mem=None, attn_mask=None): 
        """Predict a correctness logit for every exercise in ``e``.

        Args:
            e: (B, L) exercise ids.
            c: (B, L) correctness labels in {0, 1}.
            lt: (B, L) integer lags (see ``lag_to_bucket``).
            mem: optional (B, 1, d_model) tensor prepended to the shifted decoder
                input; defaults to the learned ``init_mem``.
            attn_mask: optional bool (L, L) mask, True = blocked.

        Returns:
            ``(logits, mem, (enc_attn_weights, dec_attn_weights))`` where logits
            is (B, L) and mem is the (B, 1, d_model) output at the last position.
        """
        # encoder
        src = self.exercise_embd(e)
        src = src.transpose(0, 1) # (B, L, d) -> (L, B, d)
        enc_attn_weights = []
        for self_attn, ln1, ffn, ln2 in self.enc: 
            out, attn_weights = self_attn(src, src, src, attn_mask=attn_mask)
            if self.is_lite: 
                src = self.dropout(out)
            else: 
                src = ln1(src + self.dropout(out))
                out = ffn(src)
                src = ln2(src + self.dropout(out))
            enc_attn_weights.append(attn_weights)
        src = src.transpose(0, 1) # (L, B, d) -> (B, L, d)
        # decoder: interaction i is exercise + correctness; shift right by one so
        # position i only sees mem and interactions before i
        src = src + self.correct_embd(c)
        if mem is None: 
            mem = self.init_mem.view(1, 1, -1).repeat(src.shape[0], 1, 1) # (B, 1, d)
        src = torch.cat([mem, src[:, :-1, :]], dim=1)
        src = src.transpose(0, 1) # (B, L, d) -> (L, B, d)
        rel_pos_embd = self.rel_lag_embd(self.lag_to_bucket(lt))
        out, dec_attn_weights = self.dec(src, src, attn_mask=attn_mask, rel_pos_embd=rel_pos_embd)
        tgt = self.dropout(out)
        if not self.is_lite: 
            tgt = self.ln1(tgt)
            out = self.ffn(tgt)
            tgt = self.ln2(tgt + self.dropout(out))
        tgt = tgt.transpose(0, 1) # (L, B, d) -> (B, L, d)
        # NOTE(review): with post-padded batches the last position may be padding.
        mem = tgt[:, -1:, :]
        # prediction: dot product with the tied embedding of each target exercise
        weight = self.predict.weight[e]
        bias = self.predict.bias[e]
        logits = torch.einsum('bid,bid->bi', (tgt, weight)) + bias
        return logits, mem, (enc_attn_weights, dec_attn_weights)

## Testing Phase

In [None]:
class KaggleOnlineDataset(Dataset): 
    """Per-question inference dataset for the Kaggle time-series API.

    ``self.df`` holds each user's interaction history, looked up by user_id and
    stored as a tuple of per-column arrays in ``cols`` order (presumably a pandas
    Series loaded from ``train_path`` — confirm). After ``update()``,
    ``__getitem__`` returns a user's history with the new question appended,
    post-padded to ``max_len``.
    """

    def __init__(self, train_path, tag_path, states_path, n_exercises, cols, max_len): 
        super().__init__()
        self.df = pd.read_pickle(train_path)
        self.test_df = None
        tag_df = pd.read_csv(tag_path, usecols=['exercise_id', 'bundle_id', 'part', 'correct_rate', 'frequency'])
        # per-exercise arrays below are indexed by exercise_id
        assert np.all(tag_df['exercise_id'].values == np.arange(n_exercises))
        self.parts, self.correct_rate, self.frequency = tag_df[['part', 'correct_rate', 'frequency']].values.T
        # (last_states, exercise_id_to_bundle, bundle_id_to_size); see _compute_new_lag.
        # NOTE: pickle can execute arbitrary code — only load trusted files.
        with open(states_path, 'rb') as states_file: 
            self.lag_info = pickle.load(states_file)
        self.n_exercises = n_exercises
        self.cols = cols
        self.max_len = max_len

    def __len__(self): 
        assert self.test_df is not None, 'Please call update() first'
        return len(self.test_df)

    def __getitem__(self, idx): 
        """Return the padded history + current question for test row ``idx``."""
        new_observation = self.test_df.iloc[idx]
        # 'correct' is set to 0 temporarily
        user_id = new_observation['user_id']
        new_data = {col: np.array([new_observation.get(col, 0)]) for col in self.cols}
        # retrieve old observations
        if user_id in self.df.index: 
            old_items = self.df[user_id]
            old_len = min(len(old_items[0]), self.max_len - 1)
            # explicit start index: old_item[-0:] would return the whole history
            data = {
                key: np.append(old_item[len(old_item) - old_len:], new_data[key])
                for key, old_item in zip(self.cols, old_items)
            }
        else: 
            old_len = 0
            data = new_data
        seq_len = old_len + 1
        # retrieve additional features
        data['part'] = self.parts[data['exercise_id']]
        data['correct_rate'] = self.correct_rate[data['exercise_id']]
        # pad to max_len and set dtype
        dtype_map = {key: int for key in self.cols + ['part']}
        dtype_map['correct_rate'] = float
        data = KaggleOnlineDataset._postpad_and_asdtype(data, self.max_len - seq_len, dtype_map)
        data['valid_len'] = np.array([seq_len], dtype=int)
        return data
    
    @staticmethod
    def _postpad_and_asdtype(data, pad, dtype_map): 
        """Zero-pad every 1-D array in ``data`` at the end by ``pad`` and cast it."""
        return {
            key: np.pad(item, [[0, pad]]).astype(dtype_map[key]) for key, item in data.items()
        }
    
    def update(self, test_df): 
        """Ingest a new test group and return its preprocessed exercise rows.

        If a previous group exists (and RAM usage is below 90%), its now-known
        answers are first appended to the per-user history.
        """
        import ast  # safe literal parsing of the answers column

        if self.test_df is not None and psutil.virtual_memory().percent < 90: 
            # update df according to previous labels
            prev_df = self.test_df
            # The first row of this group lists the previous group's answers for all
            # of its rows (lectures included); was_exercise keeps the question rows.
            # ast.literal_eval accepts only literals, unlike eval on feed data.
            prior_answers = ast.literal_eval(test_df.iloc[0]['prior_group_answers_correct'])
            prev_df['correct'] = np.array(prior_answers)[self.was_exercise]
            for user_id, user_rows in prev_df.groupby('user_id'): 
                new_items = tuple(user_rows[col].values for col in self.cols)
                if user_id in self.df.index: 
                    # keep only the latest max_len interactions to prevent OOM;
                    # correct even when a user answered several questions in the group
                    self.df[user_id] = tuple(
                        np.append(old_item, new_item)[-self.max_len:]
                        for old_item, new_item in zip(self.df[user_id], new_items)
                    )
                else: 
                    self.df[user_id] = new_items # create a new row
            # update correct rate
            # self._update_correct_rate(prev_df)
        # process test_df
        is_exercise = (test_df['content_type_id'] == 0)
        test_df = test_df[is_exercise]
        test_df = test_df.rename(columns={'content_id': 'exercise_id', 'prior_question_elapsed_time': 'prior_elapsed'})
        # compute lag and convert ms -> min
        test_df['prior_elapsed'] = test_df['prior_elapsed'].fillna(0).astype(int)
        lag = self._compute_new_lag(test_df)
        test_df['lag'] = np.where(
            np.logical_and(0 < lag, lag < 60 * 1000), 1, np.round(lag / (1000 * 60))
        ).astype(int)
        # as for prior_elapsed, convert ms -> s
        prior_elapsed = test_df['prior_elapsed'].values
        test_df['prior_elapsed'] = np.where(
            np.logical_and(0 < prior_elapsed, prior_elapsed < 1000), 1, np.round(prior_elapsed / 1000)
        ).astype(int)
        test_df.reset_index(drop=True, inplace=True)
        # save relevant information
        self.test_df = test_df
        self.was_exercise = is_exercise.values
        return test_df
        
    def _update_correct_rate(self, prev_df): 
        """Fold ``prev_df``'s answers into the per-exercise frequency/correct rate.

        Exercises with zero total frequency get a correct rate of 0.5.
        """
        exercise_ids = prev_df['exercise_id'].values.astype(int)
        # per-exercise counts over this group only (zero for unseen exercises)
        n_new_correct = np.bincount(
            exercise_ids, weights=prev_df['correct'].values, minlength=self.n_exercises
        )
        n_new_seen = np.bincount(exercise_ids, minlength=self.n_exercises)
        n_correct = n_new_correct + np.round(self.correct_rate * self.frequency)
        self.frequency = self.frequency + n_new_seen
        with np.errstate(divide='ignore', invalid='ignore'): 
            correct_rate = n_correct / self.frequency
        self.correct_rate = np.where(np.isfinite(correct_rate), correct_rate, 0.5)
        
    def _compute_new_lag(self, df): 
        """Return the per-row lag in ms since the user's previous bundle.

        The lag is ``timestamp - last_timestamp - bundle_size(last_bundle) *
        prior_elapsed``, clipped at 0. Rows of a user's first bundle, or in the
        same bundle as the user's last state, get 0. Mutates the ``last_states``
        dict inside ``self.lag_info`` in place.
        """
        last_states, exercise_id_to_bundle, bundle_id_to_size = self.lag_info
        # compute_lag from the original implementation
        lag = np.zeros(len(df))
        for i, (user_id, curr_timestamp, curr_exercise_id, prior_elapsed) in enumerate(
            df[['user_id', 'timestamp', 'exercise_id', 'prior_elapsed']].values
        ): 
            curr_bundle_id = exercise_id_to_bundle[curr_exercise_id]
            last_state = last_states.get(user_id, None)
            if last_state is None: 
                last_states[user_id] = (curr_timestamp, curr_bundle_id)
                lag[i] = 0
            else: 
                last_timestamp, last_bundle_id = last_state
                if curr_bundle_id == last_bundle_id: 
                    # same bundle, do not update last_states
                    lag[i] = 0
                else: 
                    last_states[user_id] = (curr_timestamp, curr_bundle_id)
                    elapsed_offset = bundle_id_to_size[last_bundle_id] * prior_elapsed
                    lag[i] = curr_timestamp - last_timestamp - elapsed_offset
        lag = np.clip(lag, a_min=0, a_max=None)
        self.lag_info = (last_states, exercise_id_to_bundle, bundle_id_to_size)
        return lag

In [None]:
def truncate_and_prepare_masks(items, valid_len, need_pad_mask=True, need_attn_mask=True): 
    """Trim a padded batch to its longest valid sequence and build attention masks.

    Args:
        items: list of tensors indexed as ``item[:, :L]`` (batch first), or None
            entries, which are passed through unchanged.
        valid_len: integer tensor of valid lengths per sample; the caller passes
            shape (B, 1), so ``pad_mask`` broadcasts to (B, max_len).
        need_pad_mask: whether to build ``pad_mask``.
        need_attn_mask: whether to build ``attn_mask``.

    Returns:
        ``(out, pad_mask, attn_mask)``: the truncated items; a bool mask that is
        True at padded positions (or None); and a bool (max_len, max_len) causal
        mask that is True strictly above the diagonal (or None). Both masks are
        built directly on ``valid_len``'s device.
    """
    device = valid_len.device
    max_len = int(valid_len.max())
    # truncate at the max_len for each sample
    out = [None if item is None else item[:, :max_len] for item in items]
    positions = torch.arange(max_len, device=device)
    # pad to the same length for batch-ification
    pad_mask = positions >= valid_len if need_pad_mask else None
    # assume q_len = k_len in attention: block key j for query i when j > i
    attn_mask = positions.unsqueeze(0) > positions.unsqueeze(1) if need_attn_mask else None
    return out, pad_mask, attn_mask

In [None]:
# Build the model from the saved config and checkpoint, then stream test groups
# from the Kaggle API, predicting P(correct) for every question row.
with open(cfg_path, 'r') as cfg_file: 
    cfg = json.load(cfg_file)
model = APFAModel(cfg['n_exercises'], MAX_LAG, N_LAYERS, D_MODEL, n_heads=N_HEADS, is_lite=IS_LITE).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
testset = KaggleOnlineDataset(train_path, tag_path, states_path, cfg['n_exercises'], cfg['cols'], MAX_LEN)

env = riiideducation.make_env()
iter_test = env.iter_test()
for test_df, _ in iter_test: 
    test_df = testset.update(test_df)
    testloader = DataLoader(testset, batch_size=B, shuffle=False, num_workers=n_workers, drop_last=False)
    batch_probs = []
    # inference only: skip autograd bookkeeping
    with torch.no_grad(): 
        for data in testloader: 
            valid_len = data['valid_len'].to(device, torch.long)
            (*inputs, labels), _, attn_mask = truncate_and_prepare_masks(
                [
                    data['exercise_id'].to(device, torch.long), 
                    data['correct'].to(device, torch.long), 
                    data['lag'].to(device, torch.long), 
                    data['correct'].to(device, torch.float), 
                ], 
                valid_len, 
                need_pad_mask=False
            )

            max_len = valid_len.max().item()
            mem = None
            out = torch.empty_like(labels)
            # feed the sequence in SEQ_LEN chunks, carrying the memory across chunks
            for start in range(0, max_len, SEQ_LEN): 
                end = min(start + SEQ_LEN, max_len)
                inputs_i = tuple(item[:, start:end] for item in inputs)
                attn_mask_i = attn_mask[start:end, start:end]
                out_i, mem, _ = model(*inputs_i, mem=mem, attn_mask=attn_mask_i)
                out[:, start:end] = out_i
            # the prediction of interest is at each sample's last valid position
            out = torch.gather(out, 1, valid_len - 1).squeeze(-1)
            batch_probs.append(torch.sigmoid(out).cpu().numpy())
    # one concatenate instead of repeated np.append (which copies every time)
    outs = np.concatenate(batch_probs) if batch_probs else np.array([], dtype='float32')
    test_df['answered_correctly'] = outs
    env.predict(test_df[['row_id', 'answered_correctly']])