In [None]:
import gc
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

Thanks to the author of this SAKT implementation notebook:
https://www.kaggle.com/wangsg/a-self-attentive-model-for-knowledge-tracing/execution

In [None]:
# Interaction log (pickled DataFrame) and the question metadata table.
train_df=pd.read_pickle('../input/riiid-trainpkl/riiid_train.pkl.gzip')

# Expose the question key as `content_id`, the name used in the interaction log.
questions_df=(
    pd.read_csv('../input/riiid-test-answer-prediction/questions.csv')
    .rename(columns={'question_id': 'content_id'})
)

In [None]:
# Keep only rows whose content_type_id is False.
is_question=train_df['content_type_id']==False
train_df=train_df.loc[is_question, ['row_id', 'user_id', 'content_id', 'answered_correctly']]

# Attach each question's `part` (left join on the shared `content_id` column).
train_df=train_df.merge(questions_df[['content_id', 'part']], how='left')

# Cap every user at their first 100 rows.
train_df=train_df.groupby('user_id').head(100)

print(train_df.shape)
train_df.head()

In [None]:
# Force a garbage-collection pass; presumably meant to release the intermediate
# frames dropped by the filtering cell above — TODO confirm it is still needed.
gc.collect()

In [None]:
# Collapse the frame to one entry per user: a
# (content_ids, parts, answers) tuple of arrays in row order.
train_df=train_df.groupby('user_id').apply(
    lambda user_rows: (
        user_rows['content_id'].values,
        user_rows['part'].values,
        user_rows['answered_correctly'].values,
    )
)

train_df.head()

# Filter out users with 5 or fewer answered questions

In [None]:
# Length of each user's content-id array; keep users with more than 5 entries.
history_lengths=train_df.map(lambda sequences: len(sequences[0]))
train_df=train_df[history_lengths>5]
train_df.head()

# Parameters

In [None]:
# Vocabulary sizes for the embedding tables.
n_questions=questions_df.shape[0]              # assumes content_id values are 0..len-1 — TODO confirm
n_categories=questions_df['part'].nunique()+1  # number of distinct parts plus one extra index
n_responses=3                                  # response ids; 2 is the start-of-sequence token set by the dataset

batch_size=2048


# Dataset

In [None]:
class SAINTDataset(torch.utils.data.Dataset):
    """Fixed-length per-user sequences for the SAINT model.

    ``df`` is indexed by user id; each entry is a
    ``(content_ids, parts, answers)`` triple of equal-length arrays.
    Every item is zero-padded, or truncated to its first ``max_seq`` steps.
    """

    def __init__(self, df, max_seq=100):
        self.df=df
        self.max_seq=max_seq
        self.user_ids=list(df.index.values)

    def __len__(self):
        return len(self.user_ids)

    def __getitem__(self, idx):
        """Return ``(q, c, r, y, label_mask)``, each of length ``max_seq``.

        ``r`` is the answer sequence shifted right by one step with the
        start token ``2`` in position 0; ``y`` holds the unshifted answers;
        ``label_mask`` is True on the real (non-padding) steps.
        """
        questions, parts, answers=self.df[self.user_ids[idx]]
        seq_len=len(questions)
        used=min(seq_len, self.max_seq)  # number of steps that fit in the window

        questions=torch.as_tensor(questions, dtype=torch.long)
        parts=torch.as_tensor(parts, dtype=torch.long)
        answers=torch.as_tensor(answers, dtype=torch.long)

        q, c, r, y=(torch.zeros(self.max_seq, dtype=torch.long) for _ in range(4))
        label_mask=torch.zeros(self.max_seq, dtype=torch.bool)
        label_mask[:used]=True

        r[0]=2 #2-for the start of the sequence
        q[:used]=questions[:used]
        c[:used]=parts[:used]
        r[1:used]=answers[:used-1]
        y[:used]=answers[:used]

        return (q, c, r, y, label_mask)

In [None]:
# Hold out 20% of the users for validation. The fixed seed keeps the split
# (and therefore the reported validation loss) reproducible across re-runs.
train, val=train_test_split(train_df, test_size=0.2, random_state=42)

train_dataset=SAINTDataset(train)
train_dataloader=torch.utils.data.DataLoader(train_dataset, 
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True
                                            )
# Validation batches are served in a fixed order (no shuffling).
val_dataset=SAINTDataset(val)
val_dataloader=torch.utils.data.DataLoader(val_dataset, 
                                           batch_size=batch_size,
                                           num_workers=4,
                                           shuffle=False,
                                           pin_memory=True
                                          )

# Model

In [None]:
class FFN(nn.Module):
    """Width-preserving feed-forward block: Linear -> ReLU -> Dropout(0.2) -> Linear."""

    def __init__(self, units):
        super().__init__()
        self.fc1=nn.Linear(units, units)
        self.relu1=nn.ReLU()
        self.dropout1=nn.Dropout(0.2)
        self.fc2=nn.Linear(units, units)

    def forward(self, x):
        """Apply the block over the last dimension of ``x`` (size ``units``)."""
        hidden=self.dropout1(self.relu1(self.fc1(x)))
        return self.fc2(hidden)

In [None]:
class SAINTDecoder(nn.Module):
    """Decoder half of the SAINT model.

    Embeds the response stream, runs one causal self-attention stack over it,
    then one causal attention stack from that result onto the encoder output.

    Args:
        n_responses: size of the response vocabulary (including the start token).
        device: kept for backward compatibility; the attention mask is now
            created on the device of the incoming tensors instead.
        max_seq: accepted for interface compatibility; not used by the decoder.
        num_heads: number of attention heads in both attention layers.
        embedd_dim: response embedding width (must equal ``model_dim`` because
            the embeddings are fed straight into the attention layers).
        model_dim: width of the attention / feed-forward layers.
    """
    def __init__(self, 
                 n_responses,
                 device='cpu',
                 max_seq=100,
                 num_heads=4,
                 embedd_dim=128,
                 model_dim=128
                ):
        super().__init__()
        self.device=device
        self.r_embedding=nn.Embedding(n_responses, embedd_dim) #Adding the start token
        self.multihead_attn1=nn.MultiheadAttention(model_dim, num_heads, dropout=0.2)
        self.multihead_attn2=nn.MultiheadAttention(model_dim, num_heads, dropout=0.2)
        
        self.ffn1=FFN(model_dim)
        self.ffn2=FFN(model_dim)
        
        self.layer_norm11=nn.LayerNorm(model_dim)
        self.layer_norm12=nn.LayerNorm(model_dim)
        self.layer_norm21=nn.LayerNorm(model_dim)
        self.layer_norm22=nn.LayerNorm(model_dim)
        
        # Attribute name (including its spelling) is kept so existing code that
        # touches it keeps working.
        self.droupout=nn.Dropout(0.2)
    
    def model_stack_output(self, q, k, v, multihead_attn, ffn, layer_norm1, layer_norm2):
        """Run one attention + feed-forward stack with residuals and layer norms.

        ``q``, ``k`` and ``v`` are sequence-first tensors. The attention is
        causal: query step ``i`` may only attend to key steps ``<= i``.
        Returns a tensor with the same shape as ``q``.
        """
        # The mask MUST be boolean: nn.MultiheadAttention *adds* a float mask to
        # the attention logits, which would leave future steps visible.
        # True marks a (query, key) pair that is not allowed to attend.
        attn_mask=torch.ones(q.shape[0], k.shape[0], device=q.device).triu(diagonal=1).bool()
        
        attn_out, _=multihead_attn(q, k, v, attn_mask=attn_mask)
        q=layer_norm1(attn_out+q)
        
        ffn_out=ffn(q)
        q=layer_norm2(ffn_out+q)
        return q
    
    def forward(self, pos_r, y_encoder, r):
        """Decode the response stream against the encoder output.

        Args:
            pos_r: positional embeddings, added (with broadcasting) to the
                response embeddings.
            y_encoder: encoder output, batch-first.
            r: integer response ids, batch-first.

        Returns:
            Batch-first decoder features after dropout.
        """
        r_embedd=self.r_embedding(r)
        x=pos_r+r_embedd
        # nn.MultiheadAttention (default layout) expects sequence-first input.
        x=x.permute(1, 0, 2)
        y_encoder=y_encoder.permute(1, 0, 2)
        
        stack_out1=self.model_stack_output(x, x, x, self.multihead_attn1, self.ffn1, self.layer_norm11, self.layer_norm12)
        stack_out2=self.model_stack_output(stack_out1, y_encoder, y_encoder, self.multihead_attn2, self.ffn2, self.layer_norm21, self.layer_norm22)
        
        # Back to batch-first for the output head.
        stack_out2=stack_out2.permute(1, 0, 2)
        out=self.droupout(stack_out2)
        return out

In [None]:
class SAINTEncoder(nn.Module):
    """Encoder half of the SAINT model.

    Sums question, category and positional embeddings and passes the result
    through two causal self-attention + feed-forward stacks.

    Args:
        n_questions: size of the question-id vocabulary.
        n_categories: size of the category (part) vocabulary.
        device: kept for backward compatibility; the attention mask is now
            created on the device of the incoming tensors instead.
        max_seq: accepted for interface compatibility; not used by the encoder.
        num_heads: number of attention heads in both attention layers.
        embedd_dim: embedding width (must equal ``model_dim`` because the
            embeddings are fed straight into the attention layers).
        model_dim: width of the attention / feed-forward layers.
    """
    def __init__(self, 
                 n_questions, n_categories,
                 device='cpu',
                 max_seq=100,
                 num_heads=4,
                 embedd_dim=128,
                 model_dim=128
                ):
        
        super().__init__()
        self.device=device
        self.q_embedding=nn.Embedding(n_questions, embedd_dim)
        self.c_embedding=nn.Embedding(n_categories, embedd_dim)
        
        self.multihead_attn1=nn.MultiheadAttention(model_dim, num_heads, dropout=0.2)
        self.ffn1=FFN(model_dim)
        
        self.multihead_attn2=nn.MultiheadAttention(model_dim, num_heads, dropout=0.2)
        self.ffn2=FFN(model_dim)
        
        self.layer_norm11=nn.LayerNorm(model_dim) # layer norm for the attention_layer1
        self.layer_norm12=nn.LayerNorm(model_dim)
        
        self.layer_norm21=nn.LayerNorm(model_dim) # layer norm for attention_layer2
        self.layer_norm22=nn.LayerNorm(model_dim)
    
    def model_stack_output(self, x, multihead_attn, ffn, layer_norm1, layer_norm2):
        """Run one causal self-attention + feed-forward stack on sequence-first ``x``.

        Step ``i`` may only attend to steps ``<= i``. Returns a tensor with
        the same shape as ``x``.
        """
        seq_len=x.shape[0]
        # The mask MUST be boolean: nn.MultiheadAttention *adds* a float mask to
        # the attention logits, which would leave future steps visible.
        # True marks a (query, key) pair that is not allowed to attend.
        attn_mask=torch.ones(seq_len, seq_len, device=x.device).triu(diagonal=1).bool()
        attn_out, _=multihead_attn(x, x, x, attn_mask=attn_mask)
        x=layer_norm1(attn_out+x)
        
        ffn_out=ffn(x)
        x=layer_norm2(ffn_out+x)
        return x
        
        
    def forward(self, pos_q, q, c):
        """Encode the exercise stream.

        Args:
            pos_q: positional embeddings, added (with broadcasting) to the
                question and category embeddings.
            q: integer question ids, batch-first.
            c: integer category ids, same shape as ``q``.

        Returns:
            Batch-first encoder features.
        """
        q_embedd=self.q_embedding(q)
        c_embedd=self.c_embedding(c)
        x=q_embedd+c_embedd+pos_q
        
        
        x=x.permute(1, 0, 2) # [seq_len, batch_size, emb_dim]
        # Stack-1
        stack_out1=self.model_stack_output(x, 
                                           self.multihead_attn1, 
                                           self.ffn1,
                                           self.layer_norm11, self.layer_norm12)
        #Stack-2
        stack_out2=self.model_stack_output(stack_out1, 
                                           self.multihead_attn2, 
                                           self.ffn2,
                                           self.layer_norm21, self.layer_norm22)
        # Back to batch-first for the decoder.
        stack_out2=stack_out2.permute(1, 0, 2)
        return stack_out2

In [None]:
class SAINTModel(nn.Module):
    """Full SAINT model: shared positional embedding, encoder, decoder, linear head.

    Args:
        n_questions: size of the question-id vocabulary.
        n_categories: size of the category (part) vocabulary.
        n_responses: size of the response vocabulary (including the start token).
        device: forwarded to the encoder and decoder and stored on the model.
        max_seq: longest supported sequence (size of the positional table).
        embedd_dim: embedding width, used by both encoder and decoder.
        encoder_dim: width of the encoder attention layers.
        decoder_dim: width of the decoder attention layers and of the head input.
        num_heads: number of attention heads.
    """
    def __init__(self, 
                 n_questions, n_categories, n_responses,
                 device='cpu',
                 max_seq=100, 
                 embedd_dim=128, 
                 encoder_dim=128,  
                 decoder_dim=128,
                 num_heads=4):
        
        super().__init__()
        self.max_seq=max_seq
        self.device=device
        self.pos_embedding=nn.Embedding(max_seq, embedd_dim)
        # embedd_dim is forwarded explicitly; previously the encoder silently
        # fell back to its own default and broke for any non-default value.
        self.encoder=SAINTEncoder(n_questions,
                                  n_categories,
                                  device=self.device,
                                  max_seq=max_seq, 
                                  num_heads=num_heads,
                                  embedd_dim=embedd_dim,
                                  model_dim=encoder_dim)
        
        self.decoder=SAINTDecoder(n_responses, 
                                  device=self.device,
                                  max_seq=max_seq,
                                  num_heads=num_heads,
                                  embedd_dim=embedd_dim,
                                  model_dim=decoder_dim
                                 )
        self.out=nn.Linear(decoder_dim, 1)
        
    def forward(self, q, c, r):
        """Return one logit per sequence step.

        Args:
            q: integer question ids, shape ``(batch, seq_len)``.
            c: integer category ids, same shape as ``q``.
            r: integer (shifted) response ids, same shape as ``q``.

        Returns:
            Tensor of shape ``(batch, seq_len)`` holding raw logits.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq``.
        """
        seq_len=q.shape[1]
        if seq_len > self.max_seq:
            raise ValueError(
                "sequence length {} exceeds max_seq={}".format(seq_len, self.max_seq))
        # Positions follow the actual input length and device, so shorter
        # batches and a model moved with .to() after construction both work.
        pos_ids=torch.arange(seq_len, device=q.device)
        pos_x=self.pos_embedding(pos_ids)
        y_encoder=self.encoder(pos_x, q, c)
        y_decoder=self.decoder(pos_x, y_encoder, r)
        y_out=self.out(y_decoder)
        return y_out.squeeze(-1)

# Train Model

In [None]:
# Train on the GPU when one is available.
device='cuda' if torch.cuda.is_available() else 'cpu'

# Model and loss live on `device`; Adam runs with its default settings.
model=SAINTModel(n_questions, n_categories, n_responses, device=device).to(device)
criterion=nn.BCEWithLogitsLoss().to(device)
optimizer=torch.optim.Adam(model.parameters())

num_epochs=3

In [None]:
def train_epoch():
    """Run one optimisation pass over ``train_dataloader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``. Returns the mean of the per-batch losses.
    """
    model.train()
    batch_losses=[]

    for batch in train_dataloader:
        q, c, r, y, label_mask=(tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        # Score only the real (non-padding) steps selected by label_mask.
        logits=torch.masked_select(model(q, c, r), label_mask).float()
        targets=torch.masked_select(y, label_mask).float()

        loss_=criterion(logits, targets)
        loss_.backward()
        optimizer.step()
        batch_losses.append(loss_.item())

    return np.mean(batch_losses)

def val_epoch():
    """Evaluate the model on ``val_dataloader`` without updating weights.

    Uses the notebook-level ``model``, ``criterion`` and ``device``.
    Returns the mean of the per-batch losses.
    """
    val_loss=[]
    model.eval()
    
    with torch.no_grad():
        # Iterate the held-out loader (the training loader was used here before,
        # which made the "validation" loss a second pass over the training data).
        for (q, c, r, y, label_mask) in val_dataloader:
            q=q.to(device)
            c=c.to(device)
            r=r.to(device)
            # Targets and mask must sit on the same device as the model output,
            # otherwise masked_select fails when running on CUDA.
            y=y.to(device)
            label_mask=label_mask.to(device)

            yout=model(q, c, r)

            # Score only the real (non-padding) steps selected by label_mask.
            y=torch.masked_select(y, label_mask).float()
            yout=torch.masked_select(yout, label_mask).float()
            
            loss_=criterion(yout, y)    
            val_loss.append(loss_.item())
            
    return np.mean(val_loss)

In [None]:
# Train for `num_epochs` epochs, checkpointing whenever the validation loss improves.
best_score=None
for i in range(num_epochs):
    epoch_start=time.time()
    train_loss, val_loss=train_epoch(), val_epoch()
    epoch_end=time.time()

    # Elapsed wall-clock time is reported in minutes.
    print('Time To Run Epoch:{}'.format( (epoch_end - epoch_start)/60) )
    print("Epoch:{} | Train Loss: {:.4f} | Val Loss:{:.4f}".format(i, train_loss, val_loss))

    # First epoch, or a strictly lower validation loss: save a checkpoint
    # named after the epoch index.
    if best_score is None or val_loss<best_score:
        best_score=val_loss
        torch.save(model.state_dict(), "saint{}.pth".format(i))
    gc.collect()