In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import gc
import random
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import os
import copy

import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

### Introduction and Credit

I had a lot of fun with this competition, and I'm disappointed it's over because I was making really good progress towards the end. My final submission got up to 0.798, but it finished scoring after the competition had ended. I'm sure I could have gotten this well above 0.8 if I had the time. There are still a number of ideas I didn't get the chance to implement. 

The following kernels and github were instrumental in building this model.

https://www.kaggle.com/wangsg/a-self-attentive-model-for-knowledge-tracing

https://www.kaggle.com/mpware/sakt-fork

https://github.com/arshadshk/SAINT-pytorch

## Load data

The poorly named train-df-saint-not-binned dataset contains a preprocessed Feather version of train_df. In that kernel, I calculate all the interesting features, then create artificial users for any user with more than `MAX_SEQ` (100) interactions. So if a user had 160 interactions, I split them into two users: one with the final 100 interactions and one with the first 60. This strategy seemed to work well when applied to my SAKT fork, so I kept it for my SAINT model. As a result, instead of 39000 users, the model trains on roughly 120000 "users". 

Given more time, I would have liked to experiment with other ways of creating sequences for the model.

In [None]:
%%time

train_df = pd.read_feather('../input/train-df-saint-not-binned-everything/train_not_binned')

# prior_question_elapsed_time (presumably milliseconds, hence /1000) is turned into
# whole-second integer ids so it can index nn.Embedding later; missing values -> 0.
# A single assigned chain replaces `col.fillna(..., inplace=True)`, which is chained
# assignment and does not modify train_df under pandas Copy-on-Write.
train_df['prior_question_elapsed_time'] = (
    train_df['prior_question_elapsed_time']
    .fillna(0)
    .div(1000)
    .round()
    .astype('int16')
)

train_df.info()

## Preprocess

In [None]:
TARGET = 'answered_correctly'
MAX_SEQ = 100
MODEL_BEST = 'model_best.pt'
BS = 1024

# Seed every RNG the notebook relies on (pandas sampling falls back to NumPy's global
# RNG; torch seeds weight init and DataLoader shuffling) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Embedding-table sizes. `.max()` returns a NumPy scalar of the column dtype (int16 for
# elapsed time), so cast to plain Python ints before the `+ 1` and before handing
# them to nn.Embedding.
NUM_QUESTIONS = len(train_df["content_id"].unique()) + 1
NUM_USERS = len(train_df['user_id'].unique())
NUM_LAG1S = int(train_df['lag1'].max()) + 1
NUM_LAG2S = int(train_df['lag2'].max()) + 1
NUM_LAG3S = int(train_df['lag3'].max()) + 1
ELAPSED_TIMES = int(train_df['prior_question_elapsed_time'].max()) + 1

In [None]:
%%time

def create_group(df):
    """Collect each user's interaction history into a list of per-feature arrays.

    Returns a Series indexed by user_id. Each value is a list of 7 arrays, in order:
    content_id, answered_correctly, lag1, lag2, lag3, part, prior_question_elapsed_time.
    Row order within a user follows the order of `df`.
    """
    feature_cols = ['content_id', 'answered_correctly', 'lag1', 'lag2', 'lag3',
                    'part', 'prior_question_elapsed_time']

    def _user_arrays(user_rows):
        return [user_rows[col].values for col in feature_cols]

    return df[['user_id'] + feature_cols].groupby('user_id').apply(_user_arrays)

train_group = create_group(train_df)

# Once the per-user sequences exist, the raw frame is dead weight; release it.
del train_df
gc.collect()

In [None]:
# Hold out 3% of the (artificial) users for validation. A fixed random_state keeps
# the split identical across kernel restarts.
# NOTE: this cell overwrites train_group, so re-running it shrinks the training set
# again; re-run the grouping cell first.
VALID_FRAC = 0.03
valid_group = train_group.sample(frac=VALID_FRAC, random_state=42)
train_group = train_group.drop(valid_group.index).reset_index(drop=True)
valid_group = valid_group.reset_index(drop=True)
train_group.shape, valid_group.shape

In [None]:
class SAINTDataset(Dataset):
    """Per-user sequences for SAINT, truncated or left-padded to `max_seq`.

    Args:
        user_sequences: object whose `.index` lists the user keys and which is
            indexable by those keys (e.g. the Series built by create_group). Each
            entry holds exactly 7 arrays, in order: question_id, answered_correctly,
            lag1, lag2, lag3, part, elapsed_time.
        num_questions: stored on the instance; not used by this class.
        subset: 'train' or 'valid'; stored on the instance; not used by this class.
        max_seq: output length. Longer histories keep their last `max_seq` steps;
            shorter ones are left-padded with zeros.
        min_seq: users with fewer than `min_seq` interactions are skipped.

    Each item is (q, r, qa, l1, l2, l3, p, el): int arrays of length `max_seq`.
    `r` is `qa` shifted right by one step (r[0] == 0), i.e. the previous response.
    """
    def __init__(self, user_sequences, num_questions, subset='train', max_seq=100, min_seq=10):
        super(SAINTDataset, self).__init__()
        self.max_seq = max_seq
        self.num_questions = num_questions
        self.user_sequences = user_sequences
        self.subset = subset

        # Keep only users with a long enough history.
        self.user_ids = [
            user_id for user_id in user_sequences.index
            if len(user_sequences[user_id][0]) >= min_seq
        ]

    def __len__(self):
        return len(self.user_ids)

    def _left_pad(self, values):
        """Return the last `max_seq` entries of `values`, left-padded with zeros."""
        padded = np.zeros(self.max_seq, dtype=int)
        tail = values[-self.max_seq:]
        if len(tail):
            padded[-len(tail):] = tail
        return padded

    def __getitem__(self, index):
        user_id = self.user_ids[index]
        # Every feature goes through the same pad/truncate step. The previous
        # hand-written version forgot to copy `part` for short sequences, leaving it
        # all zeros. Unpacking into 7 names also checks the entry's arity.
        q, qa, l1, l2, l3, p, el = (
            self._left_pad(feature) for feature in self.user_sequences[user_id]
        )

        # Decoder input: the response to the previous question (0 at the first step).
        r = np.zeros(self.max_seq, dtype=int)
        r[1:] = qa[:-1]

        return q, r, qa, l1, l2, l3, p, el

## Define model

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward layer: Linear -> ReLU -> Linear, width preserved.

    Args:
        dim: input, hidden and output feature size.
    """
    def __init__(self, dim=128):
        super().__init__()
        self.layer1 = nn.Linear(dim, dim)
        self.layer2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.layer1(x))
        return self.layer2(hidden)

    
def future_mask(seq_length):
    """Boolean (seq_length, seq_length) attention mask that is True strictly above
    the diagonal, so position i may only attend to positions j <= i."""
    ones = torch.ones(seq_length, seq_length)
    return ones.triu(diagonal=1).bool()


class Encoder(nn.Module):
    """One SAINT encoder block over the exercise (question) sequence.

    On the first block, question ids and part ids are embedded and summed. Later
    blocks receive the previous block's output instead. Every block adds a learned
    positional embedding, applies LayerNorm and causal self-attention with a
    residual connection, then a residual feed-forward layer.

    Args:
        n_in: question vocabulary size (rows of `e_embed`).
        seq_len: maximum sequence length (rows of the positional table).
        embed_dim: model width.
        nheads: number of attention heads.
    """
    def __init__(self, n_in, seq_len=100, embed_dim=128, nheads=4):
        super().__init__()
        self.seq_len = seq_len

        # NOTE(review): 10 rows assumes part ids are < 10; confirm against the data.
        self.part_embed = nn.Embedding(10, embed_dim)

        self.e_embed = nn.Embedding(n_in, embed_dim)
        self.e_pos_embed = nn.Embedding(seq_len, embed_dim)
        self.e_norm = nn.LayerNorm(embed_dim)

        self.e_multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=nheads, dropout=0.2)
        self.m_norm = nn.LayerNorm(embed_dim)
        self.ffn = FFN(embed_dim)

    def forward(self, e, p, first_block=True):
        """Args:
            e: question ids (bs, n) on the first block, else hidden states (bs, n, embed_dim).
            p: part ids (bs, n); only used on the first block.
            first_block: whether `e` and `p` are raw ids that still need embedding.

        Returns:
            Hidden states of shape (bs, n, embed_dim).
        """
        if first_block:
            e = self.e_embed(e) + self.part_embed(p)

        # Derive length and device from the input rather than self.seq_len and the
        # notebook-global `device`. n may be shorter than seq_len, but not longer.
        n = e.size(1)
        pos = torch.arange(n, device=e.device).unsqueeze(0)
        e = self.e_norm(e + self.e_pos_embed(pos))

        # nn.MultiheadAttention (batch_first=False) expects (n, bs, embed_dim).
        e = e.permute(1, 0, 2)
        att_mask = future_mask(n).to(e.device)
        att_out, _ = self.e_multi_att(e, e, e, attn_mask=att_mask)
        m = (e + att_out).permute(1, 0, 2)

        return m + self.ffn(self.m_norm(m))
    
class Decoder(nn.Module):
    """One SAINT decoder block over the response sequence.

    On the first block, response ids plus lag1/lag2/lag3/elapsed-time ids are embedded
    and summed. Later blocks receive the previous block's output. Every block adds a
    positional embedding, runs causal self-attention, then causal cross-attention
    over the encoder output, then a residual feed-forward layer.

    Args:
        n_in: response vocabulary size (rows of `r_embed`).
        seq_len: maximum sequence length (rows of the positional table).
        embed_dim: model width.
        nheads: number of heads for both attention layers.

    The lag / elapsed-time table sizes come from the module-level constants
    NUM_LAG1S, NUM_LAG2S, NUM_LAG3S and ELAPSED_TIMES.
    """
    def __init__(self, n_in, seq_len=100, embed_dim=128, nheads=4):
        super().__init__()
        self.seq_len = seq_len

        self.r_embed = nn.Embedding(n_in, embed_dim)
        self.r_pos_embed = nn.Embedding(seq_len, embed_dim)
        # Not used in forward(); kept so state_dict keys match previously saved checkpoints.
        self.r_norm = nn.LayerNorm(embed_dim)

        self.l1_embed = nn.Embedding(NUM_LAG1S, embed_dim)
        self.l2_embed = nn.Embedding(NUM_LAG2S, embed_dim)
        self.l3_embed = nn.Embedding(NUM_LAG3S, embed_dim)
        self.el_t_embed = nn.Embedding(ELAPSED_TIMES, embed_dim)

        # Honor `nheads`; the heads were previously hard-coded to 4.
        self.r_multi_att1 = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=nheads, dropout=0.2)
        self.r_multi_att2 = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=nheads, dropout=0.2)
        self.ffn = FFN(embed_dim)

        self.r_norm1 = nn.LayerNorm(embed_dim)
        self.r_norm2 = nn.LayerNorm(embed_dim)
        self.r_norm3 = nn.LayerNorm(embed_dim)

    def forward(self, r, o, l1, l2, l3, el, first_block=True):
        """Args:
            r: response ids (bs, n) on the first block, else hidden states (bs, n, embed_dim).
            o: encoder output of shape (bs, n, embed_dim).
            l1, l2, l3, el: lag / elapsed-time ids (bs, n); only used on the first block.
            first_block: whether `r` and the side features are raw ids that need embedding.

        Returns:
            Hidden states of shape (bs, n, embed_dim).
        """
        if first_block:
            r = (self.r_embed(r)
                 + self.l1_embed(l1)
                 + self.l2_embed(l2)
                 + self.l3_embed(l3)
                 + self.el_t_embed(el))

        # Derive length and device from the input instead of self.seq_len and the
        # notebook-global `device`.
        n = r.size(1)
        pos = torch.arange(n, device=r.device).unsqueeze(0)
        r = self.r_norm1(r + self.r_pos_embed(pos))
        r = r.permute(1, 0, 2)  # (bs, n, d) -> (n, bs, d) for nn.MultiheadAttention

        mask = future_mask(n).to(r.device)
        att_out1, _ = self.r_multi_att1(r, r, r, attn_mask=mask)
        m1 = r + att_out1

        # Cross-attention: response position i attends to encoder outputs j <= i.
        o = self.r_norm2(o.permute(1, 0, 2))
        att_out2, _ = self.r_multi_att2(m1, o, o, attn_mask=mask)

        m2 = self.r_norm3((att_out2 + m1).permute(1, 0, 2))
        return m2 + self.ffn(m2)

# This is an altered version from https://github.com/arshadshk/SAINT-pytorch
def get_clones(module, N):
    """Return an nn.ModuleList holding N independent deep copies of `module`."""
    copies = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)

class SAINT(nn.Module):
    """SAINT model: stacked Encoder blocks over exercises, Decoder blocks over responses.

    Args:
        dim_model: model width.
        num_en, num_de: number of encoder / decoder blocks.
        heads_en, heads_de: attention heads passed to Encoder / Decoder.
        total_ex: exercise (question) vocabulary size.
        total_in: response vocabulary size.
        seq_len: maximum sequence length (positional table size).

    forward() returns per-position probabilities (sigmoid applied) of shape (bs, n).
    """
    def __init__(self, dim_model, num_en, num_de, heads_en, total_ex, total_in, heads_de, seq_len):
        super().__init__()

        self.num_en = num_en
        self.num_de = num_de

        encoder_proto = Encoder(n_in=total_ex, seq_len=seq_len, embed_dim=dim_model, nheads=heads_en)
        self.encoder = get_clones(encoder_proto, num_en)
        decoder_proto = Decoder(n_in=total_in, seq_len=seq_len, embed_dim=dim_model, nheads=heads_de)
        self.decoder = get_clones(decoder_proto, num_de)

        self.out = nn.Linear(in_features=dim_model, out_features=1)

    def forward(self, in_ex, in_in, l1, l2, l3, p, el):
        # Only the first block of each stack embeds raw ids; later blocks refine hidden states.
        for depth, enc_block in enumerate(self.encoder):
            in_ex = enc_block(in_ex, p, first_block=(depth == 0))

        for depth, dec_block in enumerate(self.decoder):
            in_in = dec_block(in_in, in_ex, l1, l2, l3, el, first_block=(depth == 0))

        probs = torch.sigmoid(self.out(in_in))
        return probs.squeeze(-1)

In [None]:
# train_iterator is our DataLoader. criterion must take probabilities (e.g. nn.BCELoss),
# because SAINT.forward already applies a sigmoid.
def train_epoch(model, train_iterator, optim, criterion, device="cpu"):
    """Run one training epoch.

    Args:
        model: called as model(e, r, l1, l2, l3, p, el); returns per-position
            probabilities of shape (bs, n).
        train_iterator: yields batches laid out as
            (q, r, answered_correctly, lag1, lag2, lag3, part, elapsed_time).
        optim: optimizer over the model's parameters.
        criterion: loss taking (probabilities, 0/1 targets).
        device: device each batch is moved to.

    Returns:
        (mean batch loss, accuracy, ROC AUC). Accuracy and AUC use only the last
        position of each sequence.
    """
    model.train()

    train_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    tbar = tqdm(train_iterator)
    for item in tbar:
        e = item[0].to(device).long()
        r = item[1].to(device).long()
        label = item[2].to(device).float()
        l1 = item[3].to(device).long()
        l2 = item[4].to(device).long()
        l3 = item[5].to(device).long()
        p = item[6].to(device).long()
        el = item[7].to(device).long()

        optim.zero_grad()
        output = model(e, r, l1, l2, l3, p, el)
        # Targets must be the raw 0/1 labels. Passing torch.sigmoid(label) would make
        # the targets 0.5 / 0.731 and teach the model miscalibrated probabilities.
        # NOTE(review): the loss also covers the zero left-padding positions; consider masking.
        loss = criterion(output, label)
        loss.backward()
        optim.step()

        loss_value = loss.item()
        train_loss.append(loss_value)

        # Metrics are computed on the final position only.
        last_out = output[:, -1].detach()
        last_label = label[:, -1]
        pred = (last_out >= 0.5).long()

        num_corrects += (pred == last_label).sum().item()
        num_total += len(last_label)

        labels.extend(last_label.cpu().numpy())
        outs.extend(last_out.cpu().numpy())

        tbar.set_description('loss - {:.4f}'.format(loss_value))

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(train_loss)

    return loss, acc, auc

In [None]:
# https://www.kaggle.com/mpware/sakt-fork
def valid_epoch(model, valid_iterator, criterion, device="cpu"):
    """Evaluate the model over one pass of `valid_iterator` without updating it.

    Args:
        model: called as model(e, r, l1, l2, l3, p, el); returns per-position
            probabilities of shape (bs, n).
        valid_iterator: yields batches laid out as
            (q, r, answered_correctly, lag1, lag2, lag3, part, elapsed_time).
        criterion: loss taking (probabilities, 0/1 targets), e.g. nn.BCELoss.
        device: device each batch is moved to.

    Returns:
        (mean batch loss, accuracy, ROC AUC). Accuracy and AUC use only the last
        position of each sequence.
    """
    model.eval()

    valid_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    with torch.no_grad():
        for item in valid_iterator:
            e = item[0].to(device).long()
            r = item[1].to(device).long()
            label = item[2].to(device).float()
            l1 = item[3].to(device).long()
            l2 = item[4].to(device).long()
            l3 = item[5].to(device).long()
            p = item[6].to(device).long()
            el = item[7].to(device).long()

            output = model(e, r, l1, l2, l3, p, el)
            # Raw 0/1 targets (not torch.sigmoid(label), which yields 0.5 / 0.731).
            loss = criterion(output, label)
            valid_loss.append(loss.item())

            last_out = output[:, -1]  # (bs,)
            last_label = label[:, -1]
            pred = (last_out >= 0.5).long()

            num_corrects += (pred == last_label).sum().item()
            num_total += len(last_label)

            labels.extend(last_label.cpu().numpy())
            outs.extend(last_out.cpu().numpy())

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.mean(valid_loss)

    return loss, acc, auc


In [None]:
gc.collect()

# SAINTDataset skips users with fewer than min_seq (default 10) interactions.
train_dataset = SAINTDataset(train_group, NUM_QUESTIONS, max_seq=MAX_SEQ)
valid_dataset = SAINTDataset(valid_group, NUM_QUESTIONS, max_seq=MAX_SEQ, subset='valid')

train_dataloader = DataLoader(train_dataset, batch_size=BS, shuffle=True, num_workers=8)
valid_dataloader = DataLoader(valid_dataset, batch_size=BS, shuffle=False, num_workers=8)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SAINT(
    dim_model=128,
    num_en=2,
    num_de=2,
    heads_en=4,
    heads_de=4,
    total_ex=NUM_QUESTIONS,
    total_in=2,       # response ids: answered_correctly values (presumably 0/1); 0 also pads
    seq_len=MAX_SEQ,  # must equal the datasets' max_seq (positional tables have this many rows)
)
# Move the model before constructing the optimizer, as the torch.optim docs recommend.
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# SAINT.forward already applies a sigmoid, so the loss takes probabilities (BCELoss),
# not logits (BCEWithLogitsLoss).
criterion = nn.BCELoss()
criterion.to(device)

In [None]:
gc.collect()
epochs = 30
history = []
auc_max = -np.inf

for epoch in range(1, epochs+1):
    train_loss, train_acc, train_auc = train_epoch(model, train_dataloader, optimizer, criterion, device)
    print(f'Epoch {epoch}, train_loss: {train_loss:5f}, train_acc: {train_acc:5f}, train_auc: {train_auc:5f}')
    valid_loss, valid_acc, valid_auc = valid_epoch(model, valid_dataloader, criterion, device)
    print(f'Epoch {epoch}, valid_loss: {valid_loss:5f}, valid_acc: {valid_acc:5f}, valid_auc: {valid_auc:5f}')

    lr = optimizer.param_groups[0]['lr']
    # Record losses too, so learning curves can be plotted from `history` afterwards.
    history.append({
        "epoch": epoch,
        "lr": lr,
        "train_loss": train_loss,
        "train_auc": train_auc,
        "train_acc": train_acc,
        "valid_loss": valid_loss,
        "valid_auc": valid_auc,
        "valid_acc": valid_acc,
    })

    # Checkpoint whenever validation AUC improves.
    if valid_auc > auc_max:
        print("Epoch#%s, valid loss %.4f, Metric loss improved from %.4f to %.4f, saving model ..." % (epoch, valid_loss, auc_max, valid_auc))
        auc_max = valid_auc
        torch.save(model.state_dict(), MODEL_BEST)
    

### Conclusion

Although I'm happy with how the competition went, I still think I could have improved on my model quite a bit. 

I had issues using a continuous representation of the prior_question_elapsed_time feature, so in the end I kept the categorical version (elapsed time rounded to whole seconds and looked up in an embedding table). The SAINT+ paper found that this feature worked marginally better with a continuous representation, so I would have liked to get that working.

I'm also sure that my lag features could have been engineered better. Several questions were given together in a bundle, so they share the same timestamp. A simple lag such as $\text{lag} = t_n - t_{n-1}$ therefore produces many lag = 0 values, which give the model little signal. It also feeds the model sequences that can change depending on how the group was formed, so I would have liked to experiment with this a bit.

Finally, there was still a gap between my LB score and my training scores. So I'm sure there were still some issues to resolve on that front.

If anyone has any comments or questions, I'd love to hear them. Thanks for looking at my kernel!