# RIIID - SAKT Transformer Model Training

Public Leaderboard Score: **0.774**

#### If you like this kernel or are forking it, please consider upvoting it and the kernels I copied from (see acknowledgements). It helps them reach more people.

- **Inference Notebook**: https://www.kaggle.com/manikanthr5/riiid-sakt-model-inference-public
- **Pretrained Dataset**: https://www.kaggle.com/manikanthr5/riiid-sakt-model-dataset-public

**Possible Improvements**:
- All the data in this notebook is used for training, so create separate train and validation datasets for cross-validation. Note: for me, this degraded the LB score.
- Some other textbook ideas you could try:
 - Using Label Smoothing
 - Using Learning Rate Schedulers (ex: [check this kernel](https://www.kaggle.com/scaomath/riiid-sakt-train-with-a-warm-up-scheduler))
 - Increase the max sequence length and/or embedding dimension (ex: [check this kernel](https://www.kaggle.com/gilfernandes/riiid-self-attention-transformer))
 - Add more attention layers
 - Change the sampling strategy (most important)

In [1]:
import gc
import psutil
import joblib
import random
import time
from tqdm import tqdm

import numpy as np
import pandas as pd
import datatable as dt

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
# Seed every random number generator this notebook imports so that runs are
# as repeatable as possible.
seed_value = 42
random.seed(seed_value)  # Python's built-in RNG
np.random.seed(seed_value)  # NumPy's global RNG
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value) # gpu vars
# NOTE: the cuDNN determinism switches below are intentionally left disabled,
# as in the original notebook; enable them if bit-exact GPU runs are required.
# torch.backends.cudnn.deterministic = True  #needed
# torch.backends.cudnn.benchmark = False

In [3]:
# Number of users (in group-index order) assigned to the training split; the
# remaining users form the validation split.
TRAIN_SAMPLES = 320000
# Sequence length passed as `max_seq` to the datasets and the model.
MAX_SEQ = 180
# NOTE(review): not referenced anywhere in the visible notebook — possibly a leftover.
MIN_SAMPLES = 5
# Embedding width passed to SAKTModel as `embed_dim`.
EMBED_DIM = 128
# Dropout probability; also used as the default in the model classes.
DROPOUT_RATE = 0.2
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 2e-3
# Peak learning rate (`max_lr`) for the OneCycleLR scheduler.
MAX_LEARNING_RATE = 2e-3
# Upper bound on the number of training epochs (the loop may stop early).
EPOCHS = 30
# Batch size used by both the training and the validation DataLoader.
TRAIN_BATCH_SIZE = 64
# Minimum number of interactions a user (or a trailing chunk) needs to become a sample.
ACCEPTED_USER_CONTENT_SIZE = 4

In [4]:
# `os` is not among the notebook's top-level imports, so import it here;
# without it the makedirs call below raises NameError.
import os

# Timestamped output directory so every run keeps its own artifacts.
folder = "model/"  + time.strftime("%Y%m%d-%H%M%S")
os.makedirs(folder, exist_ok=True)

## Load data

In [5]:
%%time

# Column name -> compact dtype. Only these columns are read from the CSV.
dtypes = {'timestamp': 'int64', 'user_id': 'int32' ,'content_id': 'int16','content_type_id': 'int8','answered_correctly':'int8'}
train_df = dt.fread('./riiid-test-answer-prediction/train.csv', columns=set(dtypes.keys())).to_pandas()
# Cast one column at a time (rather than the whole frame at once).
for col, dtype in dtypes.items():
    train_df[col] = train_df[col].astype(dtype)
# Keep only the rows whose content_type_id is 0/False.
train_df = train_df[train_df.content_type_id == False]
# Use a stable sort: the default algorithm may reorder rows that share the same
# timestamp, which would scramble the file order of a user's simultaneous
# interactions before they are grouped into sequences.
train_df = train_df.sort_values(['timestamp'], ascending=True, kind='mergesort')
train_df.reset_index(drop=True, inplace=True)

CPU times: user 57.1 s, sys: 5.75 s, total: 1min 2s
Wall time: 22.6 s


## Preprocess

In [6]:
# Distinct content ids present in the training frame; saved next to the other
# run artifacts.
skills = train_df["content_id"].unique()
n_skill = len(skills)
joblib.dump(skills, folder + "/skills.pkl.zip")
print("number skills", n_skill)

number skills 13523


In [7]:
def _user_history(rows):
    """Return one user's (content_id array, answered_correctly array) pair."""
    return rows['content_id'].values, rows['answered_correctly'].values


# One entry per user_id, holding that user's interactions in frame order.
per_user = train_df[['user_id', 'content_id', 'answered_correctly']].groupby('user_id')
group = per_user.apply(_user_history)

joblib.dump(group, folder + "/group.pkl.zip")
# The raw frame is no longer needed once the per-user histories exist.
del train_df
gc.collect()

0

In [8]:
# Split users positionally: the first TRAIN_SAMPLES entries of `group` train the
# model, the rest validate it. The groupby index is unique, so slicing by
# position selects exactly the rows the former list + `isin` lookups did,
# without building two Python lists and scanning the index twice.
train_group = group.iloc[:TRAIN_SAMPLES]
valid_group = group.iloc[TRAIN_SAMPLES:]
del group
print(len(train_group), len(valid_group))

320000 73656


In [9]:
class SAKTDataset(Dataset):
    """Fixed-length training samples cut from per-user interaction histories.

    ``group`` maps a user id to a ``(content_id, answered_correctly)`` pair of
    equal-length sequences. Histories shorter than the module-level
    ``ACCEPTED_USER_CONTENT_SIZE`` are dropped. Histories longer than
    ``max_seq`` are cut into consecutive ``max_seq``-sized chunks, plus a
    trailing remainder when that remainder is itself long enough.

    Each item is ``(x, target_id, label)``, all of length ``max_seq - 1``:
    ``x`` is the previous content id, offset by ``n_skill`` when it was
    answered correctly; ``target_id`` and ``label`` are the next content id and
    its correctness. Short samples are left-padded with zeros.

    NOTE(review): zero is the padding value; if content id 0 is also a real
    question it cannot be told apart from padding — confirm.
    """

    def __init__(self, group, n_skill, max_seq=100):
        super().__init__()
        self.n_skill = n_skill
        self.max_seq = max_seq
        self.samples = {}
        self.user_ids = []

        for user_id, (questions, answers) in group.items():
            n_questions = len(questions)
            if n_questions < ACCEPTED_USER_CONTENT_SIZE:
                continue

            if n_questions <= self.max_seq:
                # Fits in a single window: keyed by the bare user id.
                self._register(f'{user_id}', questions, answers)
                continue

            # Full windows first ...
            n_full = n_questions // self.max_seq
            for chunk in range(n_full):
                lo = chunk * self.max_seq
                hi = lo + self.max_seq
                self._register(f"{user_id}_{chunk}", questions[lo:hi], answers[lo:hi])

            # ... then the remainder, when it is long enough to be useful.
            tail_start = n_full * self.max_seq
            if n_questions - tail_start >= ACCEPTED_USER_CONTENT_SIZE:
                self._register(f"{user_id}_{n_full + 1}", questions[tail_start:], answers[tail_start:])

    def _register(self, key, questions, answers):
        """Store one sample under ``key`` and record the key's position."""
        self.user_ids.append(key)
        self.samples[key] = (questions, answers)

    def __len__(self):
        return len(self.user_ids)

    def __getitem__(self, index):
        questions, answers = self.samples[self.user_ids[index]]
        n = len(questions)

        q_padded = np.zeros(self.max_seq, dtype=int)
        a_padded = np.zeros(self.max_seq, dtype=int)
        if n < self.max_seq:
            # Right-align the history; the leading slots stay zero.
            q_padded[-n:] = questions
            a_padded[-n:] = answers
        else:
            q_padded[:] = questions[-self.max_seq:]
            a_padded[:] = answers[-self.max_seq:]

        # Past interaction = content id, shifted by n_skill when answered correctly.
        x = q_padded[:-1].copy()
        x += (a_padded[:-1] == 1) * self.n_skill

        return x, q_padded[1:], a_padded[1:]

In [10]:
# Training split: shuffled every epoch. The source group is dropped as soon as
# the dataset has been built from it.
train_dataset = SAKTDataset(train_group, n_skill, max_seq=MAX_SEQ)
del train_group
train_dataloader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=8,
)

# Validation split: same batch size, fixed order.
valid_dataset = SAKTDataset(valid_group, n_skill, max_seq=MAX_SEQ)
del valid_group
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=False,
    num_workers=8,
)

## Define model

In [11]:
class FFN(nn.Module):
    """Feed-forward block: Linear -> ReLU -> BatchNorm1d -> Linear -> Dropout.

    Args:
        state_size: input and output feature width.
        forward_expansion: multiplier for the hidden width.
        bn_size: number of features given to ``BatchNorm1d``. For 3-D input
            BatchNorm1d normalises over dimension 1, so this must equal
            ``x.shape[1]``.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, state_size=200, forward_expansion=1, bn_size=MAX_SEQ-1, dropout=0.2):
        super().__init__()
        hidden_size = forward_expansion * state_size
        self.state_size = state_size

        # Submodules are created in a fixed order so that parameter
        # initialisation draws from the RNG in the same sequence every time.
        self.lr1 = nn.Linear(state_size, hidden_size)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(bn_size)
        self.lr2 = nn.Linear(hidden_size, state_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.lr1(x)
        hidden = self.relu(hidden)
        hidden = self.bn(hidden)
        projected = self.lr2(hidden)
        return self.dropout(projected)

def future_mask(seq_length):
    """Return a (seq_length, seq_length) bool tensor, True strictly above the diagonal.

    Entry ``[i, j]`` is True exactly when ``j > i``.
    """
    upper = np.triu(np.ones((seq_length, seq_length), dtype=bool), k=1)
    return torch.from_numpy(upper)

class TransformerBlock(nn.Module):
    """Multi-head attention followed by a feed-forward block, each with residual + LayerNorm.

    NOTE: ``forward`` hands its arguments to ``nn.MultiheadAttention``
    positionally, whose order is (query, key, value). The argument named
    ``value`` here therefore acts as the attention *query* (and as the residual
    input), while ``query`` acts as the attention *value*.
    """

    def __init__(self, embed_dim, heads=8, dropout=DROPOUT_RATE, forward_expansion=1):
        super().__init__()
        self.multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=heads, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_normal = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim, forward_expansion=forward_expansion, dropout=dropout)
        self.layer_normal_2 = nn.LayerNorm(embed_dim)

    def forward(self, value, key, query, att_mask):
        """Return ``(output, attention_weights)``; the output is batch-first."""
        attended, att_weight = self.multi_att(value, key, query, attn_mask=att_mask)

        # Residual connection around attention, then norm and dropout.
        residual = self.layer_normal(attended + value)
        residual = self.dropout(residual)

        # [s_len, bs, embed] => [bs, s_len, embed]
        batch_first = residual.permute(1, 0, 2)

        # Residual connection around the feed-forward block.
        ffn_out = self.ffn(batch_first)
        out = self.dropout(self.layer_normal_2(ffn_out + batch_first))
        return out.squeeze(-1), att_weight
    
class Encoder(nn.Module):
    """Embeds interactions and target questions and runs them through the attention stack.

    Args:
        n_skill: number of distinct content ids.
        max_seq: maximum sequence length; inputs have ``max_seq - 1`` positions.
        embed_dim: embedding width.
        dropout: dropout probability, applied here and in every layer.
        forward_expansion: hidden-width multiplier of each layer's FFN.
        num_layers: number of stacked ``TransformerBlock`` layers.
        heads: number of attention heads per layer.

    ``heads`` and ``dropout`` are forwarded to the layers; previously they were
    accepted but silently replaced by ``TransformerBlock``'s own defaults.
    """

    def __init__(self, n_skill, max_seq=100, embed_dim=128, dropout=DROPOUT_RATE, forward_expansion=1, num_layers=1, heads = 8):
        super(Encoder, self).__init__()
        self.n_skill, self.embed_dim = n_skill, embed_dim
        # Interaction ids cover content ids plus the n_skill-offset "answered correctly" range.
        self.embedding = nn.Embedding(2 * n_skill + 1, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq - 1, embed_dim)
        self.e_embedding = nn.Embedding(n_skill+1, embed_dim)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, heads=heads, dropout=dropout, forward_expansion=forward_expansion)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, question_ids):
        """Return ``(encoded, att_weight)``.

        ``encoded`` is batch-first; ``att_weight`` comes from the last layer and
        is ``None`` when the encoder has no layers.
        """
        device = x.device
        x = self.embedding(x)
        pos_id = torch.arange(x.size(1)).unsqueeze(0).to(device)
        pos_x = self.pos_embedding(pos_id)
        x = self.dropout(x + pos_x)
        x = x.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]
        e = self.e_embedding(question_ids)
        e = e.permute(1, 0, 2)

        # The causal mask depends only on the sequence length: build it once.
        att_mask = future_mask(e.size(0)).to(device)
        att_weight = None
        for layer in self.layers:
            x, att_weight = layer(e, x, x, att_mask=att_mask)
            # Layers return batch-first output; go back to sequence-first for the next one.
            x = x.permute(1, 0, 2)
        x = x.permute(1, 0, 2)
        return x, att_weight

class SAKTModel(nn.Module):
    """Encoder plus a linear head producing one logit per sequence position.

    Args:
        n_skill: number of distinct content ids.
        max_seq: maximum sequence length.
        embed_dim: embedding width.
        dropout: dropout probability.
        forward_expansion: hidden-width multiplier of the FFN.
        enc_layers: number of encoder layers.
        heads: number of attention heads; now forwarded to the encoder instead
            of being silently dropped.
    """

    def __init__(self, n_skill, max_seq=100, embed_dim=128, dropout=DROPOUT_RATE, forward_expansion = 1, enc_layers=1, heads = 8):
        super(SAKTModel, self).__init__()
        self.encoder = Encoder(n_skill, max_seq, embed_dim, dropout, forward_expansion, num_layers=enc_layers, heads=heads)
        self.pred = nn.Linear(embed_dim, 1)

    def forward(self, x, question_ids):
        """Return ``(logits, att_weight)``; the trailing size-1 dim of the logits is squeezed."""
        x, att_weight = self.encoder(x, question_ids)
        x = self.pred(x)
        return x.squeeze(-1), att_weight

In [12]:
def train_fn(model, dataloader, optimizer, scheduler, criterion, device="cpu"):
    """Train ``model`` for one pass over ``dataloader``.

    The optimizer and the scheduler are both stepped once per batch.

    Returns:
        ``(loss, acc, auc)``: mean batch loss, plus accuracy and ROC-AUC
        computed only on positions whose ``target_id`` is non-zero.

    NOTE(review): the loss itself is computed over every position, including
    those with ``target_id == 0`` that the metrics exclude — confirm intended.
    """
    model.train()

    batch_losses = []
    all_labels, all_logits = [], []
    correct_count, seen_count = 0, 0

    for batch in dataloader:
        x = batch[0].to(device).long()
        target_id = batch[1].to(device).long()
        label = batch[2].to(device).float()

        optimizer.zero_grad()
        output, _ = model(x, target_id)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        scheduler.step()
        batch_losses.append(loss.item())

        # Metrics only look at positions with a non-zero target id.
        keep = target_id != 0
        kept_logits = torch.masked_select(output, keep)
        kept_labels = torch.masked_select(label, keep)
        predicted = (torch.sigmoid(kept_logits) >= 0.5).long()

        correct_count += (predicted == kept_labels).sum().item()
        seen_count += len(kept_labels)

        all_labels.extend(kept_labels.view(-1).data.cpu().numpy())
        all_logits.extend(kept_logits.view(-1).data.cpu().numpy())

    acc = correct_count / seen_count
    auc = roc_auc_score(all_labels, all_logits)
    mean_loss = np.mean(batch_losses)

    return mean_loss, acc, auc

In [13]:
def valid_fn(model, dataloader, criterion, device="cpu"):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Returns:
        ``(loss, acc, auc)``: mean batch loss, plus accuracy and ROC-AUC
        computed only on positions whose ``target_id`` is non-zero.
    """
    model.eval()

    valid_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    # Evaluation never calls backward(), so skip autograd bookkeeping: without
    # no_grad() every forward pass builds a graph that is thrown away unused.
    with torch.no_grad():
        for item in dataloader:
            x = item[0].to(device).long()
            target_id = item[1].to(device).long()
            label = item[2].to(device).float()
            target_mask = (target_id != 0)

            output, _, = model(x, target_id)
            loss = criterion(output, label)
            valid_loss.append(loss.item())

            # Metrics only look at positions with a non-zero target id.
            output = torch.masked_select(output, target_mask)
            label = torch.masked_select(label, target_mask)
            pred = (torch.sigmoid(output) >= 0.5).long()

            num_corrects += (pred == label).sum().item()
            num_total += len(label)

            labels.extend(label.view(-1).cpu().numpy())
            outs.extend(output.view(-1).cpu().numpy())

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(valid_loss)

    return loss, acc, auc

In [14]:
# Prefer the GPU whenever one is visible.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

model = SAKTModel(
    n_skill,
    max_seq=MAX_SEQ,
    embed_dim=EMBED_DIM,
    dropout=DROPOUT_RATE,
    enc_layers=1,
)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

# One scheduler step per batch over the whole run.
steps_per_epoch = len(train_dataloader)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LEARNING_RATE,
    steps_per_epoch=steps_per_epoch,
    epochs=EPOCHS,
)

model.to(device)
criterion.to(device)

BCEWithLogitsLoss()

In [15]:
# Early stopping: give up after `max_steps` consecutive epochs without a new
# best validation AUC. The best weights are checkpointed as they appear.
best_auc = 0
max_steps = 3
step = 0
for epoch in tqdm(range(EPOCHS)):
    epoch_no = epoch + 1

    train_metrics = train_fn(model, train_dataloader, optimizer, scheduler, criterion, device)
    print("[epoch - {}/{}] [train: - {:.3f}] [acc - {:.4f}] [auc - {:.4f}]".format(epoch_no, EPOCHS, *train_metrics))

    loss, acc, auc = valid_fn(model, valid_dataloader, criterion, device)
    print("[epoch - {}/{}] [valid: - {:.3f}] [acc - {:.4f}] [auc - {:.4f}]\n".format(epoch_no, EPOCHS, loss, acc, auc))

    improved = auc > best_auc
    if not improved:
        step += 1
        if step >= max_steps:
            break
        continue

    best_auc = auc
    step = 0
    torch.save(model.state_dict(), folder + "/sakt_model.pt")

# Always keep the weights of the last epoch that ran as well.
torch.save(model.state_dict(), folder + "/sakt_model_final.pt")

  0%|          | 0/30 [00:00<?, ?it/s][epoch - 1/30] [train: - 0.406] [acc - 0.6827] [auc - 0.6651]
  3%|▎         | 1/30 [04:14<2:03:00, 254.49s/it][epoch - 1/30] [valid: - 0.376] [acc - 0.7132] [auc - 0.7401]

[epoch - 2/30] [train: - 0.376] [acc - 0.7107] [auc - 0.7359]
  7%|▋         | 2/30 [08:16<1:57:00, 250.73s/it][epoch - 2/30] [valid: - 0.368] [acc - 0.7204] [auc - 0.7554]

[epoch - 3/30] [train: - 0.369] [acc - 0.7178] [auc - 0.7505]
 10%|█         | 3/30 [12:18<1:51:38, 248.11s/it][epoch - 3/30] [valid: - 0.367] [acc - 0.7222] [auc - 0.7590]

[epoch - 4/30] [train: - 0.366] [acc - 0.7207] [auc - 0.7556]
 13%|█▎        | 4/30 [16:20<1:46:46, 246.42s/it][epoch - 4/30] [valid: - 0.366] [acc - 0.7235] [auc - 0.7604]

[epoch - 5/30] [train: - 0.365] [acc - 0.7218] [auc - 0.7576]
 17%|█▋        | 5/30 [20:22<1:42:06, 245.07s/it][epoch - 5/30] [valid: - 0.365] [acc - 0.7240] [auc - 0.7609]

[epoch - 6/30] [train: - 0.364] [acc - 0.7227] [auc - 0.7589]
 20%|██        | 6/30 [24:24<1

In [16]:
# Drop the notebook-level names of both datasets now that training is over.
# NOTE(review): the DataLoaders were built from these datasets and presumably
# still reference them, so this alone may not release the memory — confirm.
del train_dataset
del valid_dataset

## Test

In [None]:
class TestDataset(Dataset):
    """Builds one model input per row of a test batch.

    Args:
        samples: per-user history, indexed by user id, each value a
            ``(content_id, answered_correctly)`` pair of sequences.
        test_df: rows to predict; must provide ``user_id`` and ``content_id``.
        skills: known content ids; only their count is used (as ``n_skill``).
        max_seq: history window length.

    Each item is ``(x, questions)``, both of length ``max_seq - 1``: ``x`` is
    the user's recent interactions (content id, offset by ``n_skill`` when
    answered correctly) and ``questions`` is the matching next-question ids,
    ending with the row's own ``content_id``.
    """

    def __init__(self, samples, test_df, skills, max_seq=MAX_SEQ):
        super(TestDataset, self).__init__()
        self.samples = samples
        self.user_ids = list(test_df["user_id"].unique())
        self.test_df = test_df
        self.skills = skills
        self.n_skill = len(skills)
        self.max_seq = max_seq

    def __len__(self):
        return self.test_df.shape[0]

    def __getitem__(self, index):
        test_info = self.test_df.iloc[index]

        user_id = test_info["user_id"]
        target_id = test_info["content_id"]

        # Fixed-size integer buffers; users without history keep all zeros.
        q = np.zeros(self.max_seq, dtype=int)
        qa = np.zeros(self.max_seq, dtype=int)

        if user_id in self.samples.index:
            q_, qa_ = self.samples[user_id]
            q_ = np.asarray(q_)[-self.max_seq:]
            qa_ = np.asarray(qa_)[-self.max_seq:]
            seq_len = len(q_)

            # Always copy INTO the buffers instead of rebinding q/qa to the
            # stored arrays: every item then has the same dtype whatever the
            # stored history uses, and the n_skill offset below cannot
            # overflow a narrow integer type. An empty history is skipped,
            # because ``q[-0:]`` would address the whole buffer.
            if seq_len > 0:
                q[-seq_len:] = q_
                qa[-seq_len:] = qa_

        x = q[1:].copy()
        x += (qa[1:] == 1) * self.n_skill

        questions = np.append(q[2:], [target_id])

        return x, questions

In [None]:
# Competition API module — presumably only importable inside the Kaggle
# competition environment.
import riiideducation

# Create the competition environment and the iterator over its test batches.
env = riiideducation.make_env()
iter_test = env.iter_test()

In [None]:
# Reload the per-user histories from where this notebook saved them: the dump
# went to `folder + "/group.pkl.zip"`, not to the current working directory.
group = joblib.load(folder + "/group.pkl.zip")

In [None]:
model.eval()
prev_test_df = None

for (test_df, sample_prediction_df) in tqdm(iter_test):
    # Each batch carries the answers to the previous batch. Fold them into the
    # per-user history first, unless memory is nearly exhausted.
    if (prev_test_df is not None) and (psutil.virtual_memory().percent < 90):
        # NOTE(security): eval() executes whatever expression the string holds.
        # It is kept because the value comes from the competition API; use
        # ast.literal_eval if this loop is ever fed data from another source.
        prev_test_df['answered_correctly'] = eval(test_df['prior_group_answers_correct'].iloc[0])
        prev_test_df = prev_test_df[prev_test_df.content_type_id == False]

        prev_group = prev_test_df[['user_id', 'content_id', 'answered_correctly']].groupby('user_id').apply(lambda r: (
            r['content_id'].values,
            r['answered_correctly'].values))
        for prev_user_id in prev_group.index:
            if prev_user_id in group.index:
                # Known user: append and keep only the most recent MAX_SEQ interactions.
                group[prev_user_id] = (
                    np.append(group[prev_user_id][0], prev_group[prev_user_id][0])[-MAX_SEQ:],
                    np.append(group[prev_user_id][1], prev_group[prev_user_id][1])[-MAX_SEQ:]
                )

            else:
                # New user: start the history with this batch.
                group[prev_user_id] = (
                    prev_group[prev_user_id][0],
                    prev_group[prev_user_id][1]
                )

    prev_test_df = test_df.copy()

    # Work on an explicit copy so that adding the prediction column below does
    # not write into a filtered view of the API's frame.
    test_df = test_df[test_df.content_type_id == False].copy()
    test_dataset = TestDataset(group, test_df, skills)
    test_dataloader = DataLoader(test_dataset, batch_size=51200, shuffle=False)

    outs = []

    for item in tqdm(test_dataloader):
        x = item[0].to(device).long()
        target_id = item[1].to(device).long()

        with torch.no_grad():
            output, att_weight = model(x, target_id)
        # Only the last position is the prediction for the row being asked about.
        outs.extend(torch.sigmoid(output)[:, -1].view(-1).data.cpu().numpy())

    test_df['answered_correctly'] = outs
    env.predict(test_df.loc[test_df['content_type_id'] == 0, ['row_id', 'answered_correctly']])

## Attention Weights Visualization

In [None]:
# Convert the attention weights of the last inference batch into a NumPy array
# for plotting: detach from autograd, move to host memory, then convert.
att_weight = att_weight.detach()
att_weight = att_weight.cpu().numpy()

In [None]:
# Plotting libraries for the attention heatmaps below.
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default theme to all following figures.
sns.set()
# Jupyter magic: render figures inline in the notebook.
%matplotlib inline

In [None]:
# Heatmap of the attention weights of sample 0 of the last inference batch.
try:
    attention_weights = att_weight[0]

    fig, ax = plt.subplots(figsize=(18, 16))
    # Hide only the strictly-upper triangle: those are the future positions the
    # causal mask (k=1) removes. The diagonal holds real weights, so keep it.
    mask = np.triu(np.ones_like(attention_weights, dtype=bool), k=1)

    # Generate a custom diverging colormap
    cmap = sns.diverging_palette(230, 20, as_cmap=True)

    # Draw the heatmap with the mask and correct aspect ratio
    sns.heatmap(attention_weights, mask=mask, cmap=cmap, vmax=.3,
                square=True, linewidths=.5, cbar_kws={"shrink": .5})
    plt.title("Attention Weight")

    plt.show()
except Exception as exc:
    # The plot is best-effort (the batch may hold fewer samples), but say why
    # it was skipped instead of failing silently.
    print(f"Skipping attention plot for sample 0: {exc!r}")

In [None]:
# Heatmap of the attention weights of sample 5 of the last inference batch.
try:
    attention_weights = att_weight[5]

    fig, ax = plt.subplots(figsize=(18, 16))
    # Hide only the strictly-upper triangle: those are the future positions the
    # causal mask (k=1) removes. The diagonal holds real weights, so keep it.
    mask = np.triu(np.ones_like(attention_weights, dtype=bool), k=1)

    # Generate a custom diverging colormap
    cmap = sns.diverging_palette(230, 20, as_cmap=True)

    # Draw the heatmap with the mask and correct aspect ratio
    sns.heatmap(attention_weights, mask=mask, cmap=cmap, vmax=.3,
                square=True, linewidths=.5, cbar_kws={"shrink": .5})
    plt.title("Attention Weight")

    plt.show()
except Exception as exc:
    # The plot is best-effort (the batch may hold fewer samples), but say why
    # it was skipped instead of failing silently.
    print(f"Skipping attention plot for sample 5: {exc!r}")

In [None]:
# Heatmap of the attention weights of sample 10 of the last inference batch.
try:
    attention_weights = att_weight[10]

    fig, ax = plt.subplots(figsize=(18, 16))
    # Hide only the strictly-upper triangle: those are the future positions the
    # causal mask (k=1) removes. The diagonal holds real weights, so keep it.
    mask = np.triu(np.ones_like(attention_weights, dtype=bool), k=1)

    # Generate a custom diverging colormap
    cmap = sns.diverging_palette(230, 20, as_cmap=True)

    # Draw the heatmap with the mask and correct aspect ratio
    sns.heatmap(attention_weights, mask=mask, cmap=cmap, vmax=.3,
                square=True, linewidths=.5, cbar_kws={"shrink": .5})
    plt.title("Attention Weight")

    plt.show()
except Exception as exc:
    # The plot is best-effort (the batch may hold fewer samples), but say why
    # it was skipped instead of failing silently.
    print(f"Skipping attention plot for sample 10: {exc!r}")

In [None]:
# Heatmap of the attention weights of sample 15 of the last inference batch.
try:
    attention_weights = att_weight[15]

    fig, ax = plt.subplots(figsize=(18, 16))
    # Hide only the strictly-upper triangle: those are the future positions the
    # causal mask (k=1) removes. The diagonal holds real weights, so keep it.
    mask = np.triu(np.ones_like(attention_weights, dtype=bool), k=1)

    # Generate a custom diverging colormap
    cmap = sns.diverging_palette(230, 20, as_cmap=True)

    # Draw the heatmap with the mask and correct aspect ratio
    sns.heatmap(attention_weights, mask=mask, cmap=cmap, vmax=.3,
                square=True, linewidths=.5, cbar_kws={"shrink": .5})
    plt.title("Attention Weight")

    plt.show()
except Exception as exc:
    # The plot is best-effort (the batch may hold fewer samples), but say why
    # it was skipped instead of failing silently.
    print(f"Skipping attention plot for sample 15: {exc!r}")

In [None]:
# Heatmap of the attention weights of sample 20 of the last inference batch.
try:
    attention_weights = att_weight[20]

    fig, ax = plt.subplots(figsize=(18, 16))
    # Hide only the strictly-upper triangle: those are the future positions the
    # causal mask (k=1) removes. The diagonal holds real weights, so keep it.
    mask = np.triu(np.ones_like(attention_weights, dtype=bool), k=1)

    # Generate a custom diverging colormap
    cmap = sns.diverging_palette(230, 20, as_cmap=True)

    # Draw the heatmap with the mask and correct aspect ratio
    sns.heatmap(attention_weights, mask=mask, cmap=cmap, vmax=.3,
                square=True, linewidths=.5, cbar_kws={"shrink": .5})
    plt.title("Attention Weight")

    plt.show()
except Exception as exc:
    # The plot is best-effort (the batch may hold fewer samples), but say why
    # it was skipped instead of failing silently.
    print(f"Skipping attention plot for sample 20: {exc!r}")

---