In [None]:
import gc
import os
import time
import numpy as np
import pandas as pd
import riiideducation

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

# Model

In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    Both linear layers map ``emb_dim`` features to ``emb_dim`` features, so the
    output has the same shape as the input. Dropout probability is 0.2.
    """

    def __init__(self, emb_dim):
        super().__init__()

        # Attribute names are part of the checkpoint (state_dict) key space.
        self.fc1 = nn.Linear(emb_dim, emb_dim)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(emb_dim, emb_dim)
        self.dropout2 = nn.Dropout(0.2)

    def forward(self, x):
        """Apply the five stages in order to the last dimension of ``x``."""
        stages = (self.fc1, self.relu1, self.dropout1, self.fc2, self.dropout2)
        for stage in stages:
            x = stage(x)
        return x


class SAKT(nn.Module):
    """Single-block causal attention model over a sequence of question interactions.

    Each position is described by a question id, a part id and a response id.
    The attention query is built from question + part embeddings; keys/values
    additionally include positional and response embeddings. A causal mask
    prevents a position from attending to later positions. The model emits one
    raw logit per position (no sigmoid is applied here).

    Args:
        n_questions: size of the question embedding table.
        n_parts: number of parts; the table has ``n_parts + 1`` rows.
        n_responses: size of the response embedding table.
        device: device on which ``pos_idx`` is created and the default device
            for masks returned by :meth:`get_attention_mask`.
        emb_dim: embedding width; must equal ``model_dim``.
        model_dim: width of the attention / feed-forward block.
        num_heads: number of attention heads.
        max_seq: maximum supported sequence length.

    Raises:
        ValueError: if ``emb_dim != model_dim`` (the embeddings are fed straight
            into the attention layer, so the widths have to match).
    """

    def __init__(self,
                 n_questions,
                 n_parts,
                 n_responses,
                 device='cpu',
                 emb_dim=128,
                 model_dim=128,
                 num_heads=8,
                 max_seq=100):

        super(SAKT, self).__init__()

        if emb_dim != model_dim:
            raise ValueError(
                'emb_dim ({}) must equal model_dim ({})'.format(emb_dim, model_dim))

        # Kept as a plain attribute (not a buffer) so that existing checkpoints,
        # which have no ``pos_idx`` entry, still load with strict key matching.
        self.pos_idx = torch.arange(max_seq).to(device)
        self.n_questions = n_questions
        self.n_parts = n_parts
        self.n_responses = n_responses
        self.max_seq = max_seq
        self.device = device

        self.emb_dim = emb_dim
        self.model_dim = model_dim

        self.pos_embedding = nn.Embedding(max_seq, emb_dim)
        self.q_embedding = nn.Embedding(n_questions, emb_dim)
        self.p_embedding = nn.Embedding(n_parts + 1, emb_dim)
        self.r_embedding = nn.Embedding(n_responses, emb_dim)

        self.multihead_attn = nn.MultiheadAttention(model_dim, num_heads=num_heads, dropout=0.2)
        self.layernorm1 = nn.LayerNorm(model_dim)

        # NOTE: dropout1 / dropout2 are not used in forward(); they are kept so
        # the module's attribute set stays identical for existing callers.
        self.dropout1 = nn.Dropout(0.2)

        self.ffn = FFN(model_dim)
        self.layernorm2 = nn.LayerNorm(model_dim)

        self.dropout2 = nn.Dropout(0.2)
        self.out = nn.Linear(model_dim, 1)

    def get_attention_mask(self, s, device=None):
        """Return an ``(s, s)`` boolean causal mask.

        Entry ``[i, j]`` is True when ``j > i``, i.e. position ``i`` must not
        attend to the later position ``j``.

        Args:
            s: sequence length.
            device: device for the mask; defaults to ``self.device``.
        """
        target = self.device if device is None else device
        idx = torch.arange(s, device=target)
        # Row index i along dim 0, column index j along dim 1.
        return idx.unsqueeze(0) > idx.unsqueeze(1)

    def forward(self, q, p, r):
        """Compute per-position logits.

        Args:
            q, p, r: integer index tensors of shape ``(batch, seq)`` holding
                question, part and response ids; ``seq`` may be any length up
                to ``max_seq``.

        Returns:
            Tensor of shape ``(batch, seq)`` with raw logits.

        Raises:
            ValueError: if ``seq`` exceeds ``max_seq``.
        """
        seq_len = q.size(1)
        if seq_len > self.max_seq:
            raise ValueError(
                'sequence length {} exceeds max_seq {}'.format(seq_len, self.max_seq))

        # Only embed the positions actually present, on the input's device.
        positions = self.pos_idx[:seq_len].to(q.device)
        pos_embedd = self.pos_embedding(positions)
        q_embedd = self.q_embedding(q)
        p_embedd = self.p_embedding(p)
        r_embedd = self.r_embedding(r)

        query = q_embedd + p_embedd
        x = pos_embedd + q_embedd + p_embedd + r_embedd
        attn_mask = self.get_attention_mask(seq_len, device=q.device)

        # nn.MultiheadAttention (default layout) expects (seq, batch, dim).
        query = query.permute(1, 0, 2)
        x = x.permute(1, 0, 2)

        attn_output, attn_weights = self.multihead_attn(query, x, x, attn_mask=attn_mask)
        attn_output = self.layernorm1(query + attn_output)

        ffn_out = self.ffn(attn_output)
        y = self.layernorm2(attn_output + ffn_out)

        # Back to (batch, seq, dim), then project to one logit per position.
        y = y.permute(1, 0, 2)
        yout = self.out(y).squeeze(-1)
        return yout

In [None]:
%%time
df=pd.read_pickle('../input/riiid-trainpkl/riiid_train.pkl.gzip')
df=df[df.content_type_id==False][['user_id', 'content_id', 'answered_correctly']].copy()
questions_df=pd.read_csv('../input/riiid-test-answer-prediction/questions.csv')


questions_df.rename(columns={'question_id': 'content_id'}, inplace=True)
df=df.merge(questions_df[['content_id', 'part']])

gc.collect()

In [None]:
# Embedding-table sizes passed to SAKT. They are derived from the number of
# distinct ids seen in the training frame.
n_questions, n_parts = (df[column].nunique() for column in ('content_id', 'part'))
# TestDataset writes response id 2 as a placeholder at the first position,
# in addition to the answered_correctly values stored in the histories.
n_responses = 3

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
def _recent_history(user_rows):
    """Return the last 100 (content_id, part, answered_correctly) arrays of one user."""
    return (
        user_rows.content_id.values[-100:],
        user_rows.part.values[-100:],
        user_rows.answered_correctly.values[-100:],
    )


# One entry per user_id: a tuple of three arrays holding that user's most
# recent rows, in the order they appear in `df`.
group = df.groupby('user_id').apply(_recent_history)

group.head()
# The full training frame is no longer needed once the histories are built.
del df
gc.collect()

In [None]:
# Build the network on the selected device and restore the trained weights.
model = SAKT(n_questions, n_parts, n_responses, device=device).to(device)
# map_location remaps tensors saved from another device (e.g. a CUDA
# checkpoint) onto `device`, so loading also works on the CPU fallback.
model.load_state_dict(
    torch.load('../input/sakt-saint-randomstate-v1/sakt.pth', map_location=device))


In [None]:
# Create the competition environment object; the inference loop further down
# submits each batch's predictions through `env.predict`.
env = riiideducation.make_env()
# Iterable consumed by the inference loop; each item unpacks into
# (test_df, sample_prediction_df).
# NOTE(review): presumably this can only be iterated once per kernel session —
# confirm against the riiideducation API before re-running this cell.
iter_test = env.iter_test()


In [None]:
class TestDataset(torch.utils.data.Dataset):
    """Turn each test row plus its user's history into fixed-length model inputs.

    ``test_df`` must provide the columns ``content_id``, ``part`` and
    ``prev_seq``, where ``prev_seq`` is a tuple of three 1-D arrays
    ``(question ids, part ids, responses)`` describing the user's history.

    Each item is ``(q, p, r, label_mask)``:

    * ``q`` / ``p``: long tensors of length ``max_seq`` holding the history's
      question / part ids followed by the current row's ``content_id`` /
      ``part``; unused trailing slots are 0.
    * ``r``: long tensor of length ``max_seq`` with responses shifted one step
      to the right, so slot ``i`` holds the response to the question in slot
      ``i - 1``. When there is room, slot 0 holds the placeholder id 2.
    * ``label_mask``: int index of the slot that holds the current question,
      i.e. the position whose prediction the caller should read.

    Histories longer than ``max_seq`` are truncated to their most recent
    ``max_seq`` entries.
    """

    # Response id written to slot 0 when no earlier response is available.
    START_RESPONSE = 2

    def __init__(self, test_df, max_seq=100):
        self.test_df = test_df
        self.max_seq = max_seq

    def __len__(self):
        return len(self.test_df)

    def __getitem__(self, idx):
        row = self.test_df.iloc[idx]
        content_id = row.content_id
        part = row.part

        # Keep only the most recent max_seq history entries so that longer
        # histories cannot overflow the fixed-size tensors below.
        hist_q, hist_p, hist_r = (
            torch.tensor(np.asarray(values)[-self.max_seq:], dtype=torch.long)
            for values in row.prev_seq
        )
        hist_len = hist_q.numel()

        q = torch.zeros(self.max_seq, dtype=torch.long)
        p = torch.zeros(self.max_seq, dtype=torch.long)
        r = torch.zeros(self.max_seq, dtype=torch.long)

        if hist_len == self.max_seq:
            # Window is full: drop the oldest question to make room for the
            # current one. Responses already line up with the shifted slots.
            q[:-1] = hist_q[1:]
            p[:-1] = hist_p[1:]
            r[:] = hist_r
            label_mask = self.max_seq - 1
        else:
            if hist_len > 0:
                q[:hist_len] = hist_q
                p[:hist_len] = hist_p
                r[1:hist_len + 1] = hist_r
            r[0] = self.START_RESPONSE
            label_mask = hist_len

        # The current question always sits at the labelled position.
        q[label_mask] = content_id
        p[label_mask] = part
        return (q, p, r, label_mask)

In [None]:
def update_group(test_df, prev_test_df, max_seq=100):
    """Fold the previous batch's now-known answers into the global ``group``.

    The first row of ``test_df.prior_group_answers_correct`` holds a string
    with one value per row of the previous batch. Those values are attached to
    ``prev_test_df``, rows with ``content_type_id != 0`` are dropped, and each
    user's (content_id, part, answered_correctly) arrays are appended to that
    user's entry in the module-level ``group`` Series, keeping only the most
    recent ``max_seq`` items.

    Args:
        test_df: current batch; only ``prior_group_answers_correct`` is read.
        prev_test_df: previous batch, or ``None`` for the first batch (no-op).
            It is not modified.
        max_seq: maximum history length kept per user.

    Raises:
        ValueError / SyntaxError: if the answers string is not a plain Python
            literal, or (ValueError) if its length does not match
            ``prev_test_df``.
    """
    import ast

    if prev_test_df is None:
        return

    # The string comes from the competition API; parse it as a literal only.
    # (The previous implementation used eval(), which would execute arbitrary
    # expressions contained in that string.)
    prev_answered_correctly = ast.literal_eval(
        test_df.prior_group_answers_correct.values[0])

    # assign() returns a new frame, so the caller's frame is left untouched.
    prev_questions = prev_test_df.assign(answered_correctly=prev_answered_correctly)
    prev_questions = prev_questions[prev_questions.content_type_id == 0]
    if prev_questions.empty:
        return

    # groupby iterates users in sorted user_id order.
    for user_id, user_rows in prev_questions.groupby('user_id'):
        new_q = user_rows.content_id.values
        new_p = user_rows.part.values
        new_r = user_rows.answered_correctly.values

        if user_id in group.index:
            old_q, old_p, old_r = group[user_id]
            new_q = np.concatenate([old_q, new_q])
            new_p = np.concatenate([old_p, new_p])
            new_r = np.concatenate([old_r, new_r])

        # Keep only the most recent interactions.
        group[user_id] = (new_q[-max_seq:], new_p[-max_seq:], new_r[-max_seq:])

In [None]:
# Inference loop: for every batch served by the competition iterator, first
# fold the previous batch's answers into the per-user histories, then predict
# the current batch's question rows and submit them.
prev_test_df = None

# History used for users that have no entry in `group` yet.
EMPTY_HISTORY = (np.array([]), np.array([]), np.array([]))

# Inference only: switch dropout off once instead of on every batch.
model.eval()

for (test_df, sample_prediction_df) in iter_test:
    test_df = test_df[['row_id', 'user_id', 'content_id',
                       'content_type_id', 'prior_group_answers_correct']].merge(
        questions_df[['content_id', 'part']], how='left', on='content_id')
    # Content ids that are missing from questions_df fall back to part 5.
    test_df['part'] = test_df['part'].fillna(5)

    update_group(test_df, prev_test_df)
    prev_test_df = test_df.copy()

    # Work on an explicit copy of the question rows so the column assignments
    # below do not write into a slice of the merged frame.
    test_df = test_df[test_df.content_type_id == 0].copy()
    test_df['prev_seq'] = test_df.user_id.apply(
        lambda user_id: group[user_id] if user_id in group else EMPTY_HISTORY)
    test_dataset = TestDataset(test_df, max_seq=100)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=40960,
        shuffle=False,
        pin_memory=(device == 'cuda'),  # pinned memory only helps GPU transfers
        num_workers=4)

    y_answered = []

    with torch.no_grad():
        for (q, p, r, label_mask) in test_dataloader:
            logits = model(q.to(device), p.to(device), r.to(device))

            # Pick, for every sample, the logit at its labelled position in a
            # single indexing operation, then convert to probabilities.
            sample_idx = torch.arange(logits.size(0), device=logits.device)
            picked = logits[sample_idx, label_mask.to(logits.device)]
            y_answered.extend(torch.sigmoid(picked).cpu().tolist())

    test_df['answered_correctly'] = y_answered
    env.predict(test_df[['row_id', 'answered_correctly']])
    del test_df
    gc.collect()