# RIIID - SAKT Transformer Model Inference

In [1]:
import psutil
import joblib

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
# Competition API module (not a general-purpose package; presumably supplied
# by the Kaggle competition runtime -- TODO confirm).
import riiideducation

# Create the competition environment once, then obtain the iterator that the
# inference loop below consumes as (test_df, sample_prediction_df) pairs.
env = riiideducation.make_env()
iter_test = env.iter_test()

In [3]:
# Hyper-parameters. Only MAX_SEQ, EMBED_DIM and DROPOUT_RATE are read by the
# code visible in this notebook; the others are presumably carried over from
# the training notebook and are unused here.
TRAIN_SAMPLES = 320000
# Length of the per-user history window. The model consumes MAX_SEQ - 1
# positions (see pos_embedding and the FFN BatchNorm1d default size).
MAX_SEQ = 180
MIN_SAMPLES = 5
# Width of every embedding / attention vector passed to SAKTModel.
EMBED_DIM = 128
# Default dropout probability for the model's layers (inactive in eval mode).
DROPOUT_RATE = 0.2
LEARNING_RATE = 2e-3
MAX_LEARNING_RATE = 2e-3
EPOCHS = 10
TRAIN_BATCH_SIZE = 64
ACCEPTED_USER_CONTENT_SIZE = 4

In [4]:
class FFN(nn.Module):
    """Feed-forward sub-layer: Linear -> ReLU -> BatchNorm1d -> Linear -> Dropout.

    Args:
        state_size: size of the last dimension of the input and of the output.
        forward_expansion: multiplier for the hidden width between the two
            linear layers.
        bn_size: ``num_features`` of the BatchNorm1d layer. BatchNorm1d treats
            dimension 1 of its input as the feature axis, so dimension 1 of
            the tensor passed to ``forward`` must equal this value.
        dropout: dropout probability applied to the output.

    The attribute names (``lr1``, ``bn``, ``lr2``, ...) are also the keys of
    the saved state dict, so they must not be renamed.
    """

    def __init__(self, state_size=200, forward_expansion=1, bn_size=MAX_SEQ-1, dropout=0.2):
        super().__init__()
        hidden_size = forward_expansion * state_size
        self.state_size = state_size

        self.lr1 = nn.Linear(state_size, hidden_size)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(bn_size)
        self.lr2 = nn.Linear(hidden_size, state_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the sub-layer to ``x`` and return a tensor of the same shape."""
        hidden = self.lr1(x)
        hidden = self.relu(hidden)
        hidden = self.bn(hidden)
        projected = self.lr2(hidden)
        return self.dropout(projected)


def future_mask(seq_length):
    """Return a (seq_length, seq_length) boolean tensor that is True strictly above the diagonal.

    Used as an attention mask: entry [i, j] is True when j > i, i.e. when
    position j lies in the future of position i.
    """
    upper = np.triu(np.ones((seq_length, seq_length)), k=1)
    return torch.from_numpy(upper.astype(bool))


class TransformerBlock(nn.Module):
    """Multi-head attention followed by the FFN, each with residual + LayerNorm + dropout.

    Args:
        embed_dim: width of the attention / feed-forward vectors.
        heads: number of attention heads.
        dropout: dropout probability used inside attention, after each
            LayerNorm and inside the FFN.
        forward_expansion: hidden-width multiplier handed to the FFN.

    The FFN is built with its default ``bn_size``, so the sequence length fed
    to this block has to match that default.
    """

    def __init__(self, embed_dim, heads=8, dropout=DROPOUT_RATE, forward_expansion=1):
        super().__init__()
        self.multi_att = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=heads, dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.layer_normal = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim, forward_expansion=forward_expansion, dropout=dropout)
        self.layer_normal_2 = nn.LayerNorm(embed_dim)

    def forward(self, value, key, query, att_mask):
        """Run the block on sequence-first tensors and return ``(output, attention_weights)``.

        NOTE: the arguments are forwarded positionally to
        ``nn.MultiheadAttention``, whose order is (query, key, value). The
        parameter named ``value`` therefore acts as the attention *query* (and
        as the residual input), and the one named ``query`` acts as the
        attention *value*. The names are kept because callers pass arguments
        positionally.
        """
        attended, att_weight = self.multi_att(value, key, query, attn_mask=att_mask)

        # First residual connection, taken around the attention query input.
        residual = self.layer_normal(attended + value)
        residual = self.dropout(residual)
        # [s_len, bs, embed] => [bs, s_len, embed]
        residual = residual.permute(1, 0, 2)

        # Second residual connection, around the feed-forward sub-layer.
        ff_out = self.ffn(residual)
        out = self.layer_normal_2(ff_out + residual)
        out = self.dropout(out)
        return out.squeeze(-1), att_weight
    

class Encoder(nn.Module):
    """Embeds the interaction history and target questions and runs the attention stack.

    Args:
        n_skill: number of distinct questions. Interaction ids are expected in
            ``[0, 2 * n_skill]`` and question ids in ``[0, n_skill]``.
        max_seq: history window length; the model sees ``max_seq - 1`` positions.
        embed_dim: width of all embeddings.
        part_dim: width of the part embedding.
        dropout: dropout probability (now also forwarded to every block).
        forward_expansion: FFN hidden-width multiplier.
        num_layers: number of stacked TransformerBlocks.
        heads: number of attention heads (now forwarded to every block;
            previously it was accepted but silently ignored).

    Caveat: each TransformerBlock builds its FFN with the default
    ``bn_size`` (derived from the module-level MAX_SEQ), so ``max_seq`` must
    agree with MAX_SEQ for the forward pass to work.
    """

    def __init__(self, n_skill, max_seq=100, embed_dim=128, part_dim=7, dropout=DROPOUT_RATE, forward_expansion=1, num_layers=1, heads = 8):
        super(Encoder, self).__init__()
        self.n_skill, self.embed_dim = n_skill, embed_dim
        self.embedding = nn.Embedding(2 * n_skill + 1, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq - 1, embed_dim)
        self.e_embedding = nn.Embedding(n_skill + 1, embed_dim)
        # part_embedding / e_part_embedding are not used by forward() (see the
        # note there) but must stay registered, with these exact shapes, so
        # that previously saved state dicts still load.
        self.part_embedding = nn.Embedding(n_skill + 1, part_dim)            # TOEIC has 7 parts
        self.e_part_embedding = nn.Linear(embed_dim + part_dim, embed_dim)

        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, heads=heads, dropout=dropout, forward_expansion=forward_expansion)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, question_ids, parts):
        """Encode one batch.

        Args:
            x: interaction ids, indexed as [bs, s_len].
            question_ids: target question ids, indexed as [bs, s_len].
            parts: part ids; accepted for interface compatibility but unused.

        Returns:
            ``(encoded, att_weight)`` where ``encoded`` is [bs, s_len, embed]
            and ``att_weight`` comes from the last block (``None`` when
            ``num_layers`` is 0).
        """
        device = x.device
        x = self.embedding(x)

        pos_id = torch.arange(x.size(1)).unsqueeze(0).to(device)
        pos_x = self.pos_embedding(pos_id)

        x = self.dropout(x + pos_x)
        x = x.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]

        e = self.e_embedding(question_ids)
        e = e.permute(1, 0, 2)

        # NOTE(review): the original code also computed
        # e_part_embedding(cat(e, part_embedding(parts))) here and then
        # discarded the result -- attention has always been queried with `e`
        # alone. The dead computation is dropped; wiring the part features in
        # would change the model's outputs and presumably needs retraining.

        # The causal mask depends only on the sequence length: build it once.
        att_mask = future_mask(e.size(0)).to(device)
        att_weight = None
        for layer in self.layers:
            x, att_weight = layer(e, x, x, att_mask=att_mask)
            # Blocks return batch-first; go back to sequence-first for the next one.
            x = x.permute(1, 0, 2)
        x = x.permute(1, 0, 2)
        return x, att_weight


class SAKTModel(nn.Module):
    """Encoder plus a linear head that emits one logit per sequence position.

    Args:
        n_skill: number of distinct questions.
        max_seq: history window length passed to the encoder.
        embed_dim: embedding width.
        part_dim: width of the part embedding.
        dropout: dropout probability.
        forward_expansion: FFN hidden-width multiplier.
        enc_layers: number of encoder blocks.
        heads: number of attention heads. It is now forwarded to the encoder;
            previously the argument was accepted but dropped.
    """

    def __init__(self, n_skill, max_seq=100, embed_dim=128, part_dim=7, dropout=DROPOUT_RATE, forward_expansion = 1, enc_layers=1, heads = 8):
        super(SAKTModel, self).__init__()
        self.encoder = Encoder(n_skill, max_seq, embed_dim, part_dim, dropout, forward_expansion, num_layers=enc_layers, heads=heads)
        self.pred = nn.Linear(embed_dim, 1)

    def forward(self, x, question_ids, parts):
        """Return ``(logits, att_weight)``.

        ``logits`` has the trailing size-1 dimension of the linear head
        removed; no sigmoid is applied here.
        """
        x, att_weight = self.encoder(x, question_ids, parts)
        x = self.pred(x)
        return x.squeeze(-1), att_weight

## Load Pretrained Models

In [5]:
# Artefacts saved alongside the pretrained model. joblib.load unpickles its
# input, so only files from a trusted source should be loaded this way.
skills = joblib.load("/kaggle/input/sakt-transformer/skills.pkl.zip")
# Number of distinct questions; sizes the model's embedding tables.
n_skill = len(skills)
# Per-user history. The inference loop uses it as a pandas Series indexed by
# user_id whose values are (content_ids, answered_correctly) tuples.
group = joblib.load("/kaggle/input/sakt-transformer/group.pkl.zip")

In [6]:
## Question csv
# Keep only the columns named in question_dtypes and cast each one to its
# declared dtype.
question_dtypes = {'question_id': 'int16', 'part': 'int16'}
question_df = pd.read_csv('/kaggle/input/riiid-test-answer-prediction/questions.csv')
question_df = question_df[list(question_dtypes)].astype(question_dtypes)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SAKTModel(n_skill, max_seq=MAX_SEQ, embed_dim=EMBED_DIM, dropout=DROPOUT_RATE)
# map_location remaps the checkpoint's tensors onto whichever device is
# available, so one load call covers both the GPU and the CPU-only case. This
# replaces a bare `except:` that retried the load on any error whatsoever
# (including KeyboardInterrupt) and could mask unrelated failures.
state_dict = torch.load("/kaggle/input/sakt-transformer/sakt_model_final.pt", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Inference only: switch dropout / batch-norm layers to evaluation behaviour.
model.eval()

SAKTModel(
  (encoder): Encoder(
    (embedding): Embedding(27047, 128)
    (pos_embedding): Embedding(179, 128)
    (e_embedding): Embedding(13524, 128)
    (part_embedding): Embedding(13524, 7)
    (e_part_embedding): Linear(in_features=135, out_features=128, bias=True)
    (layers): ModuleList(
      (0): TransformerBlock(
        (multi_att): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=128, out_features=128, bias=True)
        )
        (dropout): Dropout(p=0.2, inplace=False)
        (layer_normal): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (ffn): FFN(
          (lr1): Linear(in_features=128, out_features=128, bias=True)
          (relu): ReLU()
          (bn): BatchNorm1d(179, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lr2): Linear(in_features=128, out_features=128, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (layer_normal_2): LayerNorm((128,), eps=1e-05, elementwise_

## Inference

In [8]:
class TestDataset(Dataset):
    """Builds the model inputs for every row of one test batch.

    Args:
        samples: pandas Series indexed by user id whose values are
            ``(content_ids, answered_correctly)`` pairs of equal-length arrays.
        test_df: rows to score; must provide 'user_id' and 'content_id'.
        question_df: question metadata with a 'part' column, looked up by
            label using content ids.
        n_skill: number of distinct questions; added to a content id to encode
            "answered correctly". (Previously this argument was ignored in
            favour of a module-level global.)
        max_seq: history window length. Every returned sequence has
            ``max_seq - 1`` entries.
    """

    def __init__(self, samples, test_df, question_df, n_skill, max_seq=100):
        super().__init__()
        self.samples = samples
        self.user_ids = list(test_df["user_id"].unique())
        self.test_df = test_df
        self.question_df = question_df
        self.n_skill = n_skill
        self.max_seq = max_seq

    def __len__(self):
        return len(self.test_df)

    def __getitem__(self, index):
        """Return ``(x, questions, parts_id)`` for row ``index`` of ``test_df``.

        * ``x``: past interactions, ``content_id + n_skill * (answer == 1)``.
        * ``questions``: the content ids shifted one step ahead of ``x``, ending
          with the row's own content id (the question to predict).
        * ``parts_id``: the parts matching ``questions``.

        Histories shorter than the window are left-padded with zeros; all
        three arrays are int64.
        """
        row = self.test_df.iloc[index]
        user_id = row['user_id']
        target_id = row['content_id']
        target_part = self.question_df['part'][target_id]

        # Fixed int64 buffers so every sample in a batch has the same dtype,
        # whatever dtype the stored history arrays use.
        content_id_seq = np.zeros(self.max_seq, dtype=np.int64)
        answered_correctly_seq = np.zeros(self.max_seq, dtype=np.int64)
        parts_seq = np.zeros(self.max_seq, dtype=np.int64)

        if user_id in self.samples.index:
            content_id, answered_correctly = self.samples[user_id]
            # Only the most recent max_seq interactions fit in the window.
            content_id = np.asarray(content_id)[-self.max_seq:]
            answered_correctly = np.asarray(answered_correctly)[-self.max_seq:]
            seq_len = len(content_id)

            # Guard the empty history: buf[-0:] addresses the whole buffer and
            # assigning an empty array to it raises a broadcast error.
            if seq_len > 0:
                parts = self.question_df.loc[content_id, 'part'].to_numpy()
                content_id_seq[-seq_len:] = content_id
                answered_correctly_seq[-seq_len:] = answered_correctly
                parts_seq[-seq_len:] = parts

        x = content_id_seq[1:].copy()
        x += (answered_correctly_seq[1:] == 1) * self.n_skill

        questions = np.append(content_id_seq[2:], [target_id])
        parts_id = np.append(parts_seq[2:], [target_part])

        return x, questions, parts_id

In [9]:
prev_test_df = None

for (test_df, sample_prediction_df) in tqdm(iter_test):

    # The answers to the previous batch arrive with the current one. Fold them
    # into the per-user history before predicting. The update is skipped while
    # memory usage is at or above 90%, in which case those answers are never
    # added to the history.
    if prev_test_df is not None and psutil.virtual_memory().percent < 90:
        print(psutil.virtual_memory().percent)

        # NOTE(security): eval() executes arbitrary code. The string comes from
        # the competition API here; for any less trusted source, parse it with
        # ast.literal_eval or json.loads instead.
        prev_test_df['answered_correctly'] = eval(test_df['prior_group_answers_correct'].iloc[0])
        prev_test_df = prev_test_df[prev_test_df.content_type_id == 0]
        prev_group = prev_test_df[['user_id', 'content_id', 'answered_correctly']].groupby('user_id').apply(lambda r: (
            r['content_id'].values,
            r['answered_correctly'].values))

        for prev_user_id in prev_group.index:
            new_content, new_answers = prev_group[prev_user_id]
            if prev_user_id in group.index:
                old_content, old_answers = group[prev_user_id]
                new_content = np.append(old_content, new_content)
                new_answers = np.append(old_answers, new_answers)
            # Keep at most the MAX_SEQ most recent interactions per user.
            group[prev_user_id] = (new_content[-MAX_SEQ:], new_answers[-MAX_SEQ:])

    prev_test_df = test_df.copy()
    # Copy the filtered rows so that adding the prediction column below does
    # not write into a slice of the original frame.
    test_df = test_df[test_df.content_type_id == 0].copy()

    if test_df.empty:
        # Nothing to score in this batch, and DataLoader rejects batch_size=0.
        # Submit an empty prediction frame instead (assumes env.predict accepts
        # an empty frame -- TODO confirm against the competition API).
        test_df['answered_correctly'] = np.empty(0, dtype=np.float32)
        env.predict(test_df[['row_id', 'answered_correctly']])
        continue

    test_dataset = TestDataset(group, test_df, question_df, n_skill, max_seq=MAX_SEQ)
    test_dataloader = DataLoader(test_dataset, batch_size=len(test_df), shuffle=False)

    # batch_size == len(test_df), so the first batch is the whole frame.
    x, target_id, parts = (t.to(device).long() for t in next(iter(test_dataloader)))

    with torch.no_grad():
        output, _ = model(x, target_id, parts)
        # Only the last position corresponds to the question being predicted.
        output = torch.sigmoid(output)[:, -1]

    test_df['answered_correctly'] = output.cpu().numpy()
    env.predict(test_df[['row_id', 'answered_correctly']])

1it [00:00,  4.14it/s]

179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
14.7


2it [00:00,  4.20it/s]

179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
14.8
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179


4it [00:00,  5.14it/s]

15.0
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179
179 179 179


4it [00:01,  2.81it/s]


All credit goes to this popular notebook, https://www.kaggle.com/leadbest/sakt-with-randomization-state-updates, which is a modification of https://www.kaggle.com/wangsg/a-self-attentive-model-for-knowledge-tracing. Please show some support for these original kernels.

---