In [None]:
import numpy as np
import pandas as pd
import gc
import riiideducation
from collections import defaultdict
from tqdm.notebook import tqdm
import pickle

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

# Seed every RNG this notebook has imported, not only Python's `random`:
# torch's generator drives nn.Module weight initialisation and dropout, and
# numpy's global generator is used by any numpy-based sampling.
SEED = 34
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
%%time
# Load the pre-computed question lookup table from the attached input dataset.
# NOTE(review): unpickling can execute arbitrary code; presumably this file is
# the author's own artefact - confirm before pointing this at other data.
Q_LOOKUP_PATH = "../input/riiid-fe-part-i/q_lookup_df.pickle"
q_lookup_df = pd.read_pickle(Q_LOOKUP_PATH)
# Preview the first rows as the cell output.
q_lookup_df.head()

In [None]:
# ---- Sequence configuration ----
SEQ_LENGTH = 100
seq_prep = SEQ_LENGTH - 1

# ---- Categorical cardinalities read from the question lookup table ----
# n_question is the largest content_id itself, while n_bundle and n_qet are
# the largest id plus one; n_part is the number of distinct part values.
n_question = q_lookup_df["content_id"].max()
n_bundle = 1 + q_lookup_df["bundle_id"].max()
n_qet = 1 + q_lookup_df["q_encoded_tags"].max()
n_part = q_lookup_df["part"].nunique()

# ---- Fixed cardinalities ----
n_q_attempt = 5  # max value
lag_time_bins = 143
elapsed_time_bins = 27
n_answer = 2

print("lag_time_bins, elapsed_time_bins", lag_time_bins, elapsed_time_bins)
print(
    "number of questions, bundles, encoded_tags, parts, n_q_attempt, n_answer",
    n_question, n_bundle, n_qet, n_part, n_q_attempt, n_answer,
)

In [None]:
class FCN(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Linear -> Dropout.

    Both linear layers map ``embed_dim`` to ``embed_dim``, so the output has
    the same shape as the input.

    Args:
        embed_dim: size of the last dimension of the input tensor.
        dropout: dropout probability applied to the block output. Defaults to
            0.2, the value that used to be hard-coded, so existing callers
            that pass only ``embed_dim`` are unaffected.
    """

    def __init__(self, embed_dim, dropout=0.2):
        super(FCN, self).__init__()
        ###################################### Layers ######################################
        # Attribute names are kept as-is so saved state_dicts still load.
        self.linear1 = nn.Linear(embed_dim, embed_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward block to ``x`` (last dimension ``embed_dim``)."""
        hidden = self.relu(self.linear1(x))
        return self.dropout(self.linear2(hidden))

In [None]:
class SAINTEncoder(nn.Module):
    """Encoder block: pre-norm self-attention and a feed-forward block, each with a residual.

    Args:
        seq_len: stored on the instance; not used inside this block.
        embed_dim: feature size of the inputs and outputs.
        num_heads: number of attention heads.
        dropout: dropout probability passed to the attention layer.
        device: stored on the instance; not used inside this block.
    """

    def __init__(self, seq_len, embed_dim, num_heads, dropout, device="cpu"):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.device = device
        # ---- layers (creation order and attribute names kept stable) ----
        self.multi_att1 = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
        )
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.fcn = FCN(embed_dim)

    def forward(self, X, attn_mask):
        """Run the block on ``X`` of shape [bs, s_len, embed]; returns the same shape."""
        # The attention layer is fed sequence-first tensors:
        # [bs, s_len, embed] => [s_len, bs, embed]
        seq_first = X.permute(1, 0, 2)
        normed = self.layer_norm1(seq_first)
        attended, _ = self.multi_att1(normed, normed, normed, attn_mask=attn_mask)
        # Residual around attention, then back to [bs, s_len, embed]
        residual = (attended + seq_first).permute(1, 0, 2)
        # Pre-norm feed-forward with its own residual
        return self.fcn(self.layer_norm2(residual)) + residual

In [None]:
class SAINTDecoder(nn.Module):
    """Decoder block: pre-norm self-attention, attention over the encoder output, feed-forward.

    Each of the three sub-layers is wrapped in a residual connection, and the
    same ``attn_mask`` is applied to both attention layers.

    Args:
        seq_len: stored on the instance; not used inside this block.
        embed_dim: feature size of the inputs and outputs.
        num_heads: number of attention heads for both attention layers.
        dropout: dropout probability passed to both attention layers.
        device: stored on the instance; not used inside this block.
    """

    def __init__(self, seq_len, embed_dim, num_heads, dropout, device="cpu"):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.device = device
        # ---- layers (creation order and attribute names kept stable) ----
        self.multi_att1 = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
        )
        self.multi_att2 = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout
        )
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.layer_norm3 = nn.LayerNorm(embed_dim)
        self.layer_norm4 = nn.LayerNorm(embed_dim)
        self.fcn = FCN(embed_dim)

    def forward(self, X, encoder_output, attn_mask):
        """Decode ``X`` against ``encoder_output`` (both [bs, s_len, embed]); returns [bs, s_len, embed]."""
        # Both streams go sequence-first: [bs, s_len, embed] => [s_len, bs, embed]
        memory = encoder_output.permute(1, 0, 2)
        target = X.permute(1, 0, 2)

        # Sub-layer 1: masked self-attention over the decoder stream
        normed_target = self.layer_norm1(target)
        self_attended, _ = self.multi_att1(
            normed_target, normed_target, normed_target, attn_mask=attn_mask
        )
        after_self = self_attended + target

        # Sub-layer 2: queries from the decoder stream, keys/values from the encoder
        queries = self.layer_norm2(after_self)
        normed_memory = self.layer_norm3(memory)
        cross_attended, _ = self.multi_att2(
            queries, normed_memory, normed_memory, attn_mask=attn_mask
        )
        # Residual, then back to [bs, s_len, embed]
        after_cross = (cross_attended + after_self).permute(1, 0, 2)

        # Sub-layer 3: pre-norm feed-forward with its own residual
        return self.fcn(self.layer_norm4(after_cross)) + after_cross

In [None]:
class SAINTModel(nn.Module):
    """Encoder/decoder transformer producing one logit per sequence position.

    The encoder stream sums position, question, bundle, part, attempt-count,
    lag-time and encoded-tag embeddings plus a linear projection of a scalar
    feature (``h_means``). The decoder stream sums position, answer,
    elapsed-time and previous-step bundle embeddings. Two encoder blocks and
    two decoder blocks share one upper-triangular attention mask.

    Args:
        seq_len: sequence length every input must have along dim 1.
        n_question, n_answer, n_bundle, n_part, n_q_attempt, n_qet,
        lag_time_bins, elapsed_time_bins: largest index expected for each
            categorical feature; every embedding table gets one extra row.
        embed_dim: width of all embeddings and transformer blocks.
        n_head: number of attention heads in every block.
        dropout: dropout probability passed to the attention layers.
        device: stored as ``self.device`` and forwarded to the sub-blocks for
            backward compatibility. ``forward`` no longer relies on it: helper
            tensors are created on the device the parameters actually live on.
    """

    def __init__(self, seq_len=SEQ_LENGTH, n_question=n_question, n_answer=n_answer, 
                 n_bundle=n_bundle, n_part=n_part, n_q_attempt=n_q_attempt, n_qet=n_qet,
                 lag_time_bins=lag_time_bins, elapsed_time_bins=elapsed_time_bins,
                 embed_dim=256, n_head=8, dropout=0.1, device="cpu"):
        #
        super(SAINTModel, self).__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.device = device
        self.n_head = n_head
        # One table serves both streams: the encoder indexes 1..seq_len and
        # the decoder indexes 0..seq_len-1, hence seq_len + 1 rows.
        self.pos_embedding = nn.Embedding(seq_len + 1, embed_dim)
        # known
        self.question_embedding = nn.Embedding(n_question + 1, embed_dim)
        self.part_embedding = nn.Embedding(n_part + 1, embed_dim)
        self.bundle_embedding = nn.Embedding(n_bundle + 1, embed_dim)
        self.q_attempt_embedding = nn.Embedding(n_q_attempt + 1, embed_dim)
        self.lag_embedding = nn.Embedding(lag_time_bins + 1, embed_dim)
        self.qet_embedding = nn.Embedding(n_qet + 1, embed_dim)
        self.h_mean_embedding = nn.Linear(1, embed_dim, bias=False)
        # future
        # 0, 1 values, 2 for padding and mask
        self.answer_embedding = nn.Embedding(n_answer + 1, embed_dim)
        self.elapsed_embedding = nn.Embedding(elapsed_time_bins + 1, embed_dim)        
        #
        ###################################### Layers ######################################
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        # Encoders
        self.encoder1 = SAINTEncoder(seq_len, embed_dim, num_heads=self.n_head, dropout=dropout, device=device)
        self.encoder2 = SAINTEncoder(seq_len, embed_dim, num_heads=self.n_head, dropout=dropout, device=device)
        # Decoders
        self.decoder1 = SAINTDecoder(seq_len, embed_dim, num_heads=self.n_head, dropout=dropout, device=device)
        self.decoder2 = SAINTDecoder(seq_len, embed_dim, num_heads=self.n_head, dropout=dropout, device=device)
        #
        self.pred = nn.Linear(embed_dim, 1)
    
    def forward(self, x):
        """Compute per-position logits.

        Args:
            x: indexable collection of ten tensors, in this order:
                0 task ids (accepted for positional compatibility, unused),
                1 question ids, 2 bundle ids, 3 parts, 4 attempt counts,
                5 lag-time bins, 6 encoded tags, 7 elapsed-time bins,
                8 answers (all integer index tensors), 9 ``h_means`` (float).
                Index tensors are expected to be (batch, seq_len) so that they
                broadcast against the (1, seq_len) position embeddings.

        Returns:
            Tensor of shape (batch, seq_len) holding raw, pre-sigmoid scores.
        """
        # Longs (x[0], the task ids, is intentionally not read)
        question_ids = x[1]
        bundle_ids = x[2]
        parts = x[3]
        q_attempts = x[4]
        lag_times = x[5]
        qets = x[6]
        elapsed_times = x[7]
        answers = x[8]
        # Floats
        h_means = x[9]
        # Build helper tensors where the parameters actually are, so the model
        # keeps working after `.to(...)` / `.cuda()` even when the `device`
        # constructor argument was left at its default or has gone stale.
        device = self.pos_embedding.weight.device
        #
        pos1 = torch.arange(1, self.seq_len + 1, device=device).unsqueeze(0)
        ab_pos1_em = self.pos_embedding(pos1)
        #
        q_em = self.question_embedding(question_ids)
        bundle_em = self.bundle_embedding(bundle_ids)
        part_em = self.part_embedding(parts)
        q_att_em = self.q_attempt_embedding(q_attempts)
        lag_em = self.lag_embedding(lag_times)
        qet_em = self.qet_embedding(qets)
        # reshape (rather than view) also accepts non-contiguous inputs
        h_mean_em = self.h_mean_embedding(h_means.reshape(-1, self.seq_len, 1))
        known_feats = ab_pos1_em + q_em + bundle_em + part_em +\
                        q_att_em + lag_em + qet_em + h_mean_em
        # Decoder positions run 0..seq_len-1, one behind the encoder's.
        pos2 = torch.arange(0, self.seq_len, device=device).unsqueeze(0)
        ab_pos2_em = self.pos_embedding(pos2)
        # Bundle ids shifted one step to the right along the sequence axis;
        # the first slot has no predecessor and is set to index 0.
        answer_bundle_ids = torch.roll(bundle_ids.detach().clone(), 1, dims=1)
        answer_bundle_ids[:, 0] = 0
        answer_bundle_em = self.bundle_embedding(answer_bundle_ids)
        aq_em = self.answer_embedding(answers)
        elapsed_em = self.elapsed_embedding(elapsed_times)
        future_feats = ab_pos2_em + aq_em + elapsed_em + answer_bundle_em
        # True strictly above the diagonal: position i may not attend to j > i.
        attn_mask = torch.triu(torch.ones((self.seq_len, self.seq_len), device=device), diagonal=1).bool()
        #
        encoder_output = self.encoder1(known_feats, attn_mask)
        encoder_output = self.encoder2(encoder_output, attn_mask)
        #
        output = self.decoder1(future_feats, encoder_output, attn_mask)
        output = self.decoder2(output, encoder_output, attn_mask)
        #
        output = self.layer_norm1(output)
        output = self.pred(output)
        return output.squeeze(-1)

In [None]:
# Use the GPU when one is available, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Model and loss
model = SAINTModel(device=device)
criterion = nn.BCEWithLogitsLoss()

# Adam at 1e-3; the scheduler multiplies the rate by 0.1 every 20 scheduler steps.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=20)

# Move the parameters (and the loss module) onto the selected device.
model.to(device)
criterion.to(device)

In [None]:
################################ INIT ################################
# Smoke test: push an all-zero batch of 3 sequences through the model and
# show the shape of what comes back.
tc, q, b, p, qatt, lag, qet, elapsed, aq, h_mean = (
    torch.zeros((3, SEQ_LENGTH), device=device) for _ in range(10)
)
output = model(
    # Longs
    [feat.long() for feat in (tc, q, b, p, qatt, lag, qet, elapsed, aq)]
    # Floats
    + [h_mean.float()]
)
output.shape

In [None]:
def train_epoch(model, train_iterator, optim, criterion, device="cpu"):
    """Train ``model`` for one pass over ``train_iterator``.

    Each item is indexed as ``item[0]`` (feature tensors: all but the last are
    cast to long, the last to float), ``item[1]`` (labels) and ``item[2]``
    (mask selecting the positions that contribute to the loss).

    The loss is computed over the masked positions; accuracy and AUC are
    tracked on the last sequence position only, with AUC scored on the raw
    model outputs.

    Returns:
        (mean batch loss, last-position accuracy, last-position AUC)
    """
    model.train()

    batch_losses = []
    seen_labels = []
    seen_scores = []
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_iterator)
    for item in progress:
        features = item[0]
        inputs = [feat.to(device).long() for feat in features[:-1]]
        inputs += [feat.to(device).float() for feat in features[-1:]]
        label = item[1].to(device).float()
        mask = item[2].to(device).bool()

        # One optimisation step on the masked positions
        optim.zero_grad()
        scores = model(inputs)
        loss = criterion(scores[mask], label[mask])
        loss.backward()
        optim.step()
        loss_value = loss.detach().item()
        batch_losses.append(loss_value)

        # Running metrics use only the final position of every sequence
        last_scores = scores[:, -1]
        last_labels = label[:, -1]
        hard_pred = (torch.sigmoid(last_scores) >= 0.5).long()
        n_correct += (hard_pred == last_labels).sum().item()
        n_seen += len(last_labels)

        seen_labels.extend(last_labels.detach().cpu().numpy())
        seen_scores.extend(last_scores.detach().cpu().numpy())

        progress.set_description('loss - {:.4f}, acc - {:.4f}'.format(loss_value, n_correct / n_seen))

    acc = n_correct / n_seen
    auc = roc_auc_score(seen_labels, seen_scores)
    loss = np.mean(batch_losses)
    return loss, acc, auc

In [None]:
def valid_epoch(model, val_dataset, device="cpu", batch_size=1024):
    """Evaluate ``model`` on an in-memory validation set.

    Args:
        model: module called with a list of tensors, returning scores whose
            last column (``[:, -1]``) is the prediction that gets evaluated.
        val_dataset: pair ``(features, labels)``. ``features`` is a sequence
            of numpy arrays sliced along axis 0; all but the last are cast to
            long and the last to float. ``labels`` is a numpy array aligned
            with axis 0 of the features.
        device: device the batches are moved to.
        batch_size: number of rows per forward pass. Defaults to 1024, the
            value that used to be hard-coded, so existing callers are
            unaffected; lower it when evaluation runs out of memory.

    Returns:
        (accuracy, AUC), with predictions thresholded at 0.5 after a sigmoid.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    model.eval()
    #
    features, org_labels = val_dataset
    total_epoch = int(np.ceil(len(features[0]) / batch_size))
    #
    labels = []
    preds = []
    result = 0
    total = 0
    #
    progress = tqdm(range(total_epoch))
    for i in progress:
        s = i * batch_size
        e = s + batch_size
        #
        inputs = [torch.from_numpy(inp[s:e]).to(device).long() for inp in features[:-1]] +\
                    [torch.from_numpy(inp[s:e]).to(device).float() for inp in features[-1:]]
        label = torch.from_numpy(org_labels[s:e]).to(device).float()
        # Nothing here needs gradients, so keep the whole batch under no_grad.
        with torch.no_grad():
            output = model(inputs)
            # Only the last sequence position is scored.
            output = torch.sigmoid(output[:, -1])
            pred = (output >= 0.5).long()
        #
        result += (pred == label).sum().item()
        total += len(label)
        #
        labels.extend(label.detach().cpu().numpy())
        preds.extend(output.detach().cpu().numpy())
        #
        progress.set_description('acc - {:.4f}'.format(result / total))
    #
    acc = result / total
    auc = roc_auc_score(labels, preds)
    return acc, auc

In [None]:
def save_model(path, model, optimizer, scheduler, epoch, loss):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict holding the state dicts of ``model``,
    ``optimizer`` and ``scheduler``, the given ``loss``, and ``epoch + 1``.
    """
    checkpoint = {}
    # Stored as epoch + 1 because the epoch passed in has already completed.
    checkpoint['epoch'] = epoch + 1
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    checkpoint['loss'] = loss
    torch.save(checkpoint, path)

In [None]:
"""
epochs = 25
for epoch in range(epochs):
    loss, acc, auc = train_epoch(model, dataloader, optimizer, criterion, device)
    print("{}/{} train_loss - {:.3f} train_acc - {:.4f} train_auc - {:.4f}".format(epoch+1, epochs, loss, acc, auc))
    acc, auc = valid_epoch(model, val_dataset, device)
    print("{}/{} val_acc - {:.4f} val_auc - {:.4f}".format(epoch+1, epochs, acc, auc))
    if epoch >= 15:
        save_model("SAINT_model_{}.pt".format(epoch), model, optimizer, scheduler, epoch, loss)
    scheduler.step()
    print("lr:", scheduler.get_last_lr())
"""