# Import everything now...

Possible improvements:  
- Recompute lagtime per task_container_id (rather than per row)

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import copy
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import QuantileTransformer
import gc

# Configure constants

In [None]:
class config:
    """Notebook-wide constants: device choice, model size and data locations."""

    # Prefer the GPU when one is visible to torch, otherwise run on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sequence / model geometry.
    MAX_SEQ = 100
    EMBED_DIMS = 512
    HEADS = 8
    NUM_ENCODER = 2  # 4
    NUM_DECODER = 2  # 4

    # Loader settings and input file.
    BATCH_SIZE = 256  # 64
    TRAIN_FILE = "../input/riiid-test-answer-prediction/train.csv"

    # Vocabulary sizes handed to the exercise and part embedding tables.
    TOTAL_EXE = 13523
    TOTAL_CAT = 8

# Dataset

In [None]:
class RiiidDataset(Dataset):
    """Fixed-length training windows cut from per-user interaction histories.

    Args:
        samples: Mapping-like object exposing ``.index`` and ``samples[key]``
            (e.g. a pandas Series). Each value is a 5-tuple of equally long
            sequences: (exercise ids, answers, elapsed times, lag times, parts).
        max_seq: Length of every returned sequence.
        start_token: Value placed at position 0 of the right-shifted
            elapsed-time and lag-time inputs.
        min_seq: Windows must contain strictly more than this many
            interactions to be kept.

    Each history is cut into consecutive, non-overlapping windows of at most
    ``max_seq`` interactions. Windows shorter than ``max_seq`` are left-padded
    with zeros when they are fetched.
    """

    def __init__(self, samples, max_seq, start_token=0, min_seq=10):
        super().__init__()
        self.samples = samples
        self.max_seq = max_seq
        self.start_token = start_token
        self.min_seq = min_seq
        self.data = []
        for key in self.samples.index:
            features = self.samples[key]
            total = len(features[0])
            # Step by max_seq so windows do not overlap; a history of exactly
            # max_seq items yields one full window.
            for start in range(0, total, max_seq):
                window = tuple(f[start:start + max_seq] for f in features)
                # Same minimum-length rule for whole short histories and for
                # the trailing remainder of a long one.
                if len(window[0]) > min_seq:
                    self.data.append(window)

    def __len__(self):
        return len(self.data)

    def _left_pad(self, values):
        """Return ``values`` right-aligned in a zero-filled int array of length max_seq."""
        padded = np.zeros(self.max_seq, dtype=int)
        tail = np.asarray(values)[-self.max_seq:]
        if len(tail):
            padded[-len(tail):] = tail
        return padded

    @staticmethod
    def _shift_right(values, fill):
        """Return ``values`` shifted one step right, with ``fill`` at index 0."""
        shifted = np.empty_like(values)
        shifted[0] = fill
        shifted[1:] = values[:-1]
        return shifted

    def __getitem__(self, idx):
        """Return ``(inputs, shifted_answers, labels)`` for window ``idx``.

        ``inputs`` is a dict with keys ``input_ids``, ``input_ela_time``,
        ``input_lag_time`` and ``input_part``; every array has length
        ``max_seq``. ``labels`` are the padded answers and ``shifted_answers``
        are the same values moved one step right with a leading 0.
        """
        question_ids, answers, ela_time, l_time, part = self.data[idx]

        # Missing leading positions are padded with zeros (values sit at the end).
        exe_ids = self._left_pad(question_ids)
        ans = self._left_pad(answers)
        elapsed_time = self._left_pad(ela_time)
        lagtime = self._left_pad(l_time)
        exe_part = self._left_pad(part)

        # Time features are shifted right by one step and start with start_token.
        inputs = {
            "input_ids": exe_ids,
            "input_ela_time": self._shift_right(elapsed_time, self.start_token),
            "input_lag_time": self._shift_right(lagtime, self.start_token),
            "input_part": exe_part,
        }
        # Previous answers fed as input; position 0 holds a 0 start value.
        shifted_answers = self._shift_right(ans, 0)
        return inputs, shifted_answers, ans


# SAINT+ model

In [None]:
class FFN(nn.Module):
    """Two same-width linear layers with dropout and ReLU in between."""

    def __init__(self, in_feat):
        super().__init__()
        self.linear1 = nn.Linear(in_feat, in_feat)
        self.linear2 = nn.Linear(in_feat, in_feat)
        self.drop = nn.Dropout(0.2)

    def forward(self, x):
        """Apply linear1 -> dropout -> ReLU -> linear2 to ``x``."""
        hidden = self.linear1(x)
        hidden = self.drop(hidden)
        hidden = F.relu(hidden)
        return self.linear2(hidden)


class EncoderEmbedding(nn.Module):
    """Encoder-side embedding: in SAINT the encoder only uses exercise information.

    Args:
        n_exercises: Size of the exercise-id embedding table.
        n_parts: Size of the part-id embedding table.
        n_dims: Embedding width.
        seq_len: Maximum sequence length (size of the position table).
    """

    def __init__(self, n_exercises, n_parts, n_dims, seq_len):
        super().__init__()
        self.n_dims = n_dims
        self.seq_len = seq_len
        # Embeddings for the exercise information (exercise id, part id, position).
        self.exercise_embed = nn.Embedding(n_exercises, n_dims)
        self.part_embed = nn.Embedding(n_parts, n_dims)
        self.position_embed = nn.Embedding(seq_len, n_dims)

    def forward(self, exercises, parts):
        """Return the sum of exercise, part and position embeddings.

        ``exercises`` and ``parts`` are integer index tensors whose last
        dimension is the sequence axis (at most ``seq_len`` long).
        """
        e = self.exercise_embed(exercises)
        c = self.part_embed(parts)
        # Build the position indices on the same device as the input instead of
        # a global device, and for the length actually received.
        positions = torch.arange(exercises.size(-1), device=exercises.device).unsqueeze(0)
        p = self.position_embed(positions)
        return p + c + e

class DecoderEmbedding(nn.Module):
    """Decoder-side embedding: in SAINT the decoder only uses response information.

    lagtime: categorical embedding (one bucket per rounded lag value).
    elapsed_time: continuous embedding (a linear projection of the scalar).

    Args:
        n_responses: Size of the response embedding table.
        n_lags: Size of the lag-time embedding table.
        n_dims: Embedding width.
        seq_len: Maximum sequence length (size of the position table).
    """

    # Largest lag bucket used after dividing the raw lag by 1000.
    MAX_LAG_BUCKET = 300

    def __init__(self, n_responses, n_lags, n_dims, seq_len):
        super().__init__()
        self.n_dims = n_dims
        self.seq_len = seq_len
        # Responses are a sequence of correct/incorrect values, so the caller
        # passes n_responses = 3 (2 classes + 1).
        self.response_embed = nn.Embedding(n_responses, n_dims)
        self.ela_time_embed = nn.Linear(1, n_dims)  # continuous embedding
        self.lagtime_embed = nn.Embedding(n_lags, n_dims)
        self.position_embed = nn.Embedding(seq_len, n_dims)

    def forward(self, responses, elaps_times, lagtimes):
        """Return the sum of response, elapsed-time, lag-time and position embeddings.

        ``responses`` and ``lagtimes`` are integer tensors whose last dimension
        is the sequence axis; ``elaps_times`` is a float tensor with a trailing
        dimension of size 1.
        """
        e = self.response_embed(responses)
        el = self.ela_time_embed(elaps_times)
        # Divide by 1000, round, and clamp into the valid bucket range. The
        # lower bound keeps negative lags from producing an invalid index.
        top_bucket = min(self.MAX_LAG_BUCKET, self.lagtime_embed.num_embeddings - 1)
        lag_bucket = torch.round(torch.true_divide(lagtimes, 1000))
        lag_bucket = torch.clamp(lag_bucket, min=0, max=top_bucket).long()
        la = self.lagtime_embed(lag_bucket)
        # Position indices live on the input's device rather than a global one.
        positions = torch.arange(responses.size(-1), device=responses.device).unsqueeze(0)
        p = self.position_embed(positions)
        return p + e + el + la

# Final Model with Trainer

In [None]:
def future_mask(seq_length):
    """Return a (seq_length, seq_length) bool tensor that is True strictly above the diagonal."""
    all_ones = torch.ones(seq_length, seq_length)
    # diagonal=1 keeps only entries where the column index exceeds the row index.
    return torch.triu(all_ones, diagonal=1).bool()

In [None]:
# Main model for training 
class PlusSAINTModule(pl.LightningModule):
    """Main SAINT+ model for training: embeddings -> nn.Transformer -> FFN -> logit per position."""

    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()
        self.transformer = nn.Transformer(nhead=config.HEADS,
                                          d_model=config.EMBED_DIMS,
                                          num_encoder_layers=config.NUM_ENCODER,
                                          num_decoder_layers=config.NUM_DECODER,
                                          dropout=0.2)
        self.encoder_embedding = EncoderEmbedding(n_exercises=config.TOTAL_EXE,
                                                  n_parts=config.TOTAL_CAT,
                                                  n_dims=config.EMBED_DIMS,
                                                  seq_len=config.MAX_SEQ)
        self.decoder_embedding = DecoderEmbedding(n_responses=3,
                                                  n_lags=301,
                                                  n_dims=config.EMBED_DIMS,
                                                  seq_len=config.MAX_SEQ)
        self.layer_normal = nn.LayerNorm(config.EMBED_DIMS)
        self.ffn = FFN(config.EMBED_DIMS)
        self.pred = nn.Linear(config.EMBED_DIMS, 1)

    def forward(self, x, y):
        """Return one logit per sequence position.

        Args:
            x: Dict with ``input_ids``, ``input_part``, ``input_ela_time`` and
                ``input_lag_time`` tensors.
            y: Right-shifted previous answers fed to the decoder.

        Returns:
            Tensor of logits with the trailing size-1 dimension removed.
        """
        # Follow the device this module currently lives on.
        device = self.device
        ids = x["input_ids"].long().to(device)
        part = x['input_part'].long().to(device)
        enc = self.encoder_embedding(exercises=ids, parts=part)

        ela_time = x['input_ela_time'].unsqueeze(-1).float().to(device)
        lag_time = x['input_lag_time'].long().to(device)
        responses = y.long().to(device)
        dec = self.decoder_embedding(responses=responses, elaps_times=ela_time, lagtimes=lag_time)

        enc = enc.permute(1, 0, 2)  # [bs, s_len, embed] => [s_len, bs, embed]
        dec = dec.permute(1, 0, 2)
        # The same upper-triangular mask blocks attention to later positions in
        # the encoder, the decoder and the encoder-decoder attention.
        mask = future_mask(enc.size(0)).to(device)
        att_output = self.transformer(enc, dec, src_mask=mask, tgt_mask=mask, memory_mask=mask)
        att_output = self.layer_normal(att_output)
        att_output = att_output.permute(1, 0, 2)  # [s_len, bs, embed] => [bs, s_len, embed]

        x = self.ffn(att_output)
        x = self.layer_normal(x + att_output)

        out = self.pred(x)
        # Drop only the trailing size-1 dimension so a batch of one keeps its batch axis.
        return out.squeeze(-1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), 0.0001)

    def _masked_step(self, batch):
        """Run the model on ``batch`` and return (loss, probabilities, labels) over non-padding positions."""
        inputs, shifted_answers, labels = batch
        # Positions whose exercise id is 0 are treated as padding.
        # NOTE(review): this also hides a genuine exercise id of 0, if the data contains one.
        target_mask = (inputs["input_ids"] != 0)
        logits = self(inputs, shifted_answers)
        logits = torch.masked_select(logits, target_mask)
        labels = torch.masked_select(labels, target_mask)
        # Compute the loss on real positions only, so padding (label 0) does
        # not contribute to the gradient.
        loss = self.loss(logits.float(), labels.float())
        return loss, torch.sigmoid(logits), labels

    def training_step(self, batch, batch_ids):
        loss, probs, labels = self._masked_step(batch)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return {"loss": loss, "outs": probs.detach(), "labels": labels}

    def validation_step(self, batch, batch_ids):
        loss, probs, labels = self._masked_step(batch)
        self.log("val_loss", loss, on_step=True, prog_bar=True)
        return {"outs": probs, "labels": labels}

    def validation_epoch_end(self, validation_ouput):
        """Concatenate all validation predictions and print the epoch ROC AUC."""
        out = torch.cat([i["outs"] for i in validation_ouput]).view(-1)
        labels = torch.cat([i["labels"] for i in validation_ouput]).view(-1)
        auc = roc_auc_score(labels.cpu().detach().numpy(), out.cpu().detach().numpy())
        self.print("val auc", auc)

# Dataloader

In [None]:
def get_train_df():
    """Load the training CSV, attach the part feature and build per-row features.

    Reads ``config.TRAIN_FILE`` and the pickled part column, derives
    ``lagtime`` (difference between a row's timestamp and the same user's
    previous timestamp), removes lecture rows and drops the helper columns.

    Returns:
        DataFrame with user_id, content_id, answered_correctly,
        prior_question_elapsed_time, lagtime and the column(s) supplied by the
        part pickle.
    """
    # To avoid leakage this should be handled per task_container_id, but that
    # is skipped here for simplicity.
    dtypes = {
        'timestamp': 'int64',
        'user_id': 'int32',
        'content_id': 'int16',
        'answered_correctly': 'int8',
        'content_type_id': 'int8',
        'prior_question_elapsed_time': 'float32'
    }
    # The decoder embedding divides lagtime by 1000 and caps the result at 300,
    # so the cap is applied here in raw timestamp units (300 * 1000).
    max_lag = 300 * 1000

    print("loading csv.....")
    train_df = pd.read_csv(config.TRAIN_FILE, usecols=[1, 2, 3, 4, 7, 8], dtype=dtypes)

    print('concat part feature')
    part_df = pd.read_pickle('../input/riiid-train-part/train_df_part.pkl')
    train_df = pd.concat([train_df, part_df], axis=1)
    del part_df
    gc.collect()

    print("shape of dataframe :", train_df.shape)

    # Assign the result back instead of a chained in-place fillna.
    # NOTE(review): the 3600 divisor is kept from the original; its unit intent is unclear.
    elapsed = train_df['prior_question_elapsed_time'].fillna(0) / 3600
    train_df['prior_question_elapsed_time'] = elapsed.astype(np.int64)

    print("make feature: lagtime")
    previous_ts = train_df.groupby(['user_id'])['timestamp'].shift()
    lagtime = train_df['timestamp'] - previous_ts
    # A user's first row has no previous timestamp -> 0. Clip only the lagtime
    # values; negative differences do occur and are floored at 0.
    lagtime = lagtime.fillna(0).clip(lower=0, upper=max_lag)
    train_df['lagtime'] = lagtime.astype('int32')
    gc.collect()

    print('delete lecture rows')
    train_df = train_df[train_df.content_type_id == 0].reset_index(drop=True)

    print('del feature: timestamp, content_type_id')
    train_df.drop(['timestamp', 'content_type_id'], axis=1, inplace=True)
    gc.collect()

    n_skills = train_df.content_id.nunique()
    print("no. of skills :", n_skills)
    print("shape after exlusion:", train_df.shape)
    return train_df

# Build the preprocessed training frame once at notebook scope.
train_df = get_train_df()

In [None]:
def get_group(train_df):
    """Collapse the frame into one row per user holding that user's feature arrays.

    Returns a Series indexed by user_id; each value is the tuple
    (content_id, answered_correctly, prior_question_elapsed_time, lagtime, part)
    of value arrays.
    """
    print("Grouping users...")

    def _to_arrays(rows):
        # One array per feature, in the order the dataset unpacks them.
        return (rows.content_id.values,
                rows.answered_correctly.values,
                rows.prior_question_elapsed_time.values,
                rows.lagtime.values,
                rows.part.values)

    wanted = ["user_id", "content_id", "answered_correctly",
              "prior_question_elapsed_time", "lagtime", "part"]
    per_user = train_df[wanted].groupby("user_id")
    group = per_user.apply(_to_arrays)
    # Only the local reference is released here; the caller's frame is untouched.
    del train_df
    gc.collect()
    return group

# One entry per user: a tuple of feature arrays, consumed by the dataset below.
group = get_group(train_df)

In [None]:
def get_dataloaders(group):
    """Split the per-user groups 80/20 and wrap each part in a DataLoader.

    Returns:
        (train_loader, val_loader); only the training loader shuffles.
    """
    print("splitting")
    train, val = train_test_split(group, test_size=0.2)
    print("train size: ", train.shape, "validation size: ", val.shape)

    train_set = RiiidDataset(train, max_seq=config.MAX_SEQ)
    val_set = RiiidDataset(val, max_seq=config.MAX_SEQ)

    # Settings shared by both loaders.
    common = dict(batch_size=config.BATCH_SIZE, num_workers=8)
    train_loader = DataLoader(train_set, shuffle=True, **common)
    val_loader = DataLoader(val_set, shuffle=False, **common)

    # Release the local names; the loaders keep their own references.
    del train_set, val_set
    gc.collect()
    return train_loader, val_loader
# Build the train/validation loaders from the per-user groups.
train_loader, val_loader = get_dataloaders(group) 

In [None]:
# Remove the notebook-level names for the raw frame and the grouped series
# (presumably to free memory before training; the loaders are kept).
del train_df, group

In [None]:
# Pull a single batch from the training loader as a quick sanity check.
next(iter(train_loader))

# Training

In [None]:
# Instantiate the SAINT+ Lightning module and fit it with the loaders built above.
saint_plus = PlusSAINTModule()
# NOTE(review): gpus=-1 presumably requests every visible GPU and may fail on a
# CPU-only machine; the gpus / progress_bar_refresh_rate / train_dataloader
# keyword names depend on the installed Lightning version — confirm.
trainer = pl.Trainer(gpus=-1,max_epochs=10,progress_bar_refresh_rate=21) 
trainer.fit(model=saint_plus,
            train_dataloader=train_loader, 
            val_dataloaders = [val_loader,]) 
# Write a checkpoint named "model.pt" to the current working directory.
trainer.save_checkpoint("model.pt") 