In [None]:
# https://www.kaggle.com/c/riiid-test-answer-prediction/discussion/210276
# https://medium.com/inside-machine-learning/what-is-a-transformer-d07dd1fbec04

# Multihead vs Transformer?
# This notebook seems to indicate a Transformer consists of the encoder and decoder blocks:
# https://www.kaggle.com/m10515009/saint-is-all-you-need-training-private-0-801


In [None]:
import gc
import pandas as pd
import numpy as np
import sklearn.metrics
import tqdm

import matplotlib.pyplot

import torch

# Run configuration shared by the later cells.
settings = {
    'seq_len': 160,                # length every user sequence is padded/cut to
    'n_content_id': 13525,         # number of rows in the content_id embedding
    'batch_size': 100,
    'embed_dim': 200,
    'n_train_rows': 5 * 1000000,   # rows of train.csv that get read
    # Use the GPU when torch can see one, otherwise the CPU.
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}

# Compact column dtypes passed to pd.read_csv.
dtype = dict(
    timestamp='int64',
    user_id='int32',
    content_id='int16',
    content_type_id='int8',
    answered_correctly='int8',
)

# Read the first `n_train_rows` rows, keeping five columns selected by
# position (presumably the five columns named in `dtype` -- verify against
# the CSV header).
train_df = pd.read_csv('/kaggle/input/riiid-test-answer-prediction/train.csv'
                       ,usecols=[1, 2, 3, 4, 7]
                       ,dtype=dtype
                       ,nrows = settings['n_train_rows']
                      )

# Keep only questions (content_type_id is read as int8, so compare with 0)
train_df = train_df[train_df.content_type_id == 0]

# Arrange by timestamp. A stable sort is used so that rows sharing the same
# timestamp keep their original file order inside each user's sequence;
# the default quicksort gives no such guarantee.
train_df = train_df.sort_values(['timestamp'], ascending=True, kind='mergesort')\
                   .reset_index(drop = True)


# Group each user: one dict of aligned arrays per user_id
train_group = train_df[['user_id', 'content_id', 'answered_correctly']]\
            .groupby('user_id')\
            .apply(lambda r: {'content_id' : r['content_id'].values
                             ,'answered_correctly' : r['answered_correctly'].values
                            })

del train_df
gc.collect()


# Make validation set: hold out 10% of the users.
# The generator is seeded (settings['seed'] if present, else 0) so that the
# split is identical on every Restart & Run All.
split_rng = np.random.default_rng(settings.get('seed', 0))
val_idx = split_rng.choice(train_group.index, int(.1 * train_group.shape[0]), replace=False)
valid_group = train_group.loc[val_idx].copy()
train_group = train_group.drop(index=val_idx)

In [None]:
class riiid_dataset(torch.utils.data.Dataset):
    """Per-user interaction sequences, left-padded / truncated to a fixed length.

    Args:
        group: pandas Series indexed by user, each value a dict holding the
            aligned arrays 'content_id' and 'answered_correctly'.
        settings: dict providing 'seq_len', the length of every returned array.

    Users with fewer than 2 interactions are excluded. The Series passed in is
    not modified; a filtered view of it is stored on the dataset instead.
    """

    def __init__(self, group, settings):
        super(riiid_dataset, self).__init__()
        self.seq_len = settings['seq_len']

        # Take out people with only 1 interaction.
        # One boolean mask replaces deleting entries one at a time, which
        # mutated the caller's Series and rebuilt it on every deletion.
        keep = np.array([len(r['content_id']) >= 2 for r in group.values], dtype=bool)
        self.group = group[keep]

    def __len__(self):
        return(len(self.group))

    def _pad(self, np_array, out_size=None):
        """Return `np_array` as exactly `out_size` items (default: seq_len).

        Short arrays are left-padded with zeros; long arrays keep their LAST
        `out_size` items so the most recent interactions survive.
        """
        if out_size is None:
            out_size = self.seq_len
        n_pad = out_size - len(np_array)
        if n_pad > 0:
            return(np.concatenate((np.zeros(n_pad, dtype=np.int64), np_array)))
        return(np_array[len(np_array) - out_size:])

    def __getitem__(self, index):
        """Build the fixed-length arrays for the user at position `index`.

        Returns:
            dict of np.int64 arrays, each of length seq_len:
              'content_id'         -- question ids
              'answered_correctly' -- targets
              'prev_ac'            -- answered_correctly shifted right by one
                                      step (0 where there is no earlier answer)
        """
        # Get the relevant user row
        sample = self.group.iloc[index]

        # Get contents as np.int64s
        content_id = sample['content_id'].astype(np.int64)
        answered_correctly = sample['answered_correctly'].astype(np.int64)

        # Shift over the full history *before* truncating, so the first kept
        # position of a long sequence still sees its real previous answer.
        prev_ac = np.concatenate((np.zeros(1, dtype=np.int64), answered_correctly[:-1]))

        # Return
        return({
            'content_id' : self._pad(content_id)
            ,'answered_correctly' : self._pad(answered_correctly)
            ,'prev_ac' : self._pad(prev_ac)
        })
    

# Training data: shuffled, loaded by 4 worker processes, incomplete final
# batch dropped.
train_dataset = riiid_dataset(group = train_group
                              ,settings = settings
                              )
train_dataloader = torch.utils.data.DataLoader(train_dataset
                                                ,batch_size = settings['batch_size']
                                                ,drop_last = True
                                                ,shuffle = True
                                                ,num_workers = 4
                                               )

# Validation data: kept in order, and the final partial batch is kept too so
# that every validation user is scored (drop_last=True silently discarded up
# to batch_size-1 users and produced no batches at all for small splits).
valid_dataset = riiid_dataset(group = valid_group
                             ,settings = settings
                             )
valid_dataloader = torch.utils.data.DataLoader(valid_dataset
                                               ,batch_size = settings['batch_size']
                                               ,drop_last = False
                                              )

In [None]:
class encoder(torch.nn.Module):
    """Embeds question ids and applies one causally-masked self-attention block.

    Args:
        settings: dict providing 'embed_dim', 'n_content_id', 'seq_len' and
            'device'. 'seq_len' sizes the position embedding, so inputs may be
            at most that long.
    """
    def __init__(self, settings):
        super(encoder, self).__init__()
        self.embed_dim = settings['embed_dim']
        self.n_content_id = settings['n_content_id']
        self.seq_len = settings['seq_len']
        self.device = settings['device']

        self.cid_embedding = torch.nn.Embedding(self.n_content_id, self.embed_dim)
        self.pos_embedding = torch.nn.Embedding(self.seq_len, self.embed_dim)
        self.multi_att = torch.nn.MultiheadAttention(embed_dim = self.embed_dim
                                                     ,num_heads = 8
                                                     ,dropout = 0.2)

        self.lin_1 = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.relu = torch.nn.ReLU()
        self.lin_2 = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.dropout = torch.nn.Dropout(0.2)

    def forward(self, batch):
        """Encode a batch of question-id sequences.

        Args:
            batch: mapping whose 'content_id' entry is an integer tensor of
                shape [bs, s_len].

        Returns:
            Tensor of shape [s_len, bs, embed]. The output is deliberately
            left sequence-first, which is the layout the attention module
            consumes.
        """
        content_id = batch['content_id']
        # Follow the input's device and length instead of the configured ones,
        # so shorter sequences and relocated models both work.
        device = content_id.device
        s_len = content_id.shape[1]

        # Content embedding
        x = self.cid_embedding(content_id)

        # Position embedding
        pos_id = torch.arange(s_len, device=device)[None, :]
        pos_x = self.pos_embedding(pos_id)

        # Add embeddings and permute
        x = x + pos_x
        x = x.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]

        # MultiHead Attention with a causal mask: True above the diagonal
        # marks the later positions a query is not allowed to attend to.
        attn_mask = torch.triu(torch.ones(s_len, s_len, device=device), diagonal=1).bool()
        attn_output, _ = self.multi_att(x, x, x, attn_mask = attn_mask)
        x = x + attn_output

        # Feed forward
        x = self.lin_1(x)
        x = self.relu(x)
        x = self.lin_2(x)
        x = self.dropout(x)

        # Return
        return(x)
        
class decoder(torch.nn.Module):
    """Embeds previous answers, then applies causal self- and cross-attention.

    Args:
        settings: dict providing 'embed_dim', 'seq_len' and 'device'.
            'seq_len' sizes the position embedding, so inputs may be at most
            that long.
    """
    def __init__(self, settings):
        super(decoder, self).__init__()
        self.embed_dim = settings['embed_dim']
        self.seq_len = settings['seq_len']
        self.device = settings['device']

        self.prev_ac_embedding = torch.nn.Embedding(10, self.embed_dim)
        self.pos_embedding = torch.nn.Embedding(self.seq_len, self.embed_dim)
        self.multi_att_1 = torch.nn.MultiheadAttention(embed_dim = self.embed_dim
                                                     ,num_heads = 8
                                                     ,dropout = 0.2)
        self.multi_att_2 = torch.nn.MultiheadAttention(embed_dim = self.embed_dim
                                                      ,num_heads = 8
                                                      ,dropout = 0.2)

        self.lin_1 = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.relu = torch.nn.ReLU()
        self.lin_2 = torch.nn.Linear(self.embed_dim, self.embed_dim)
        self.dropout = torch.nn.Dropout(0.2)

    @staticmethod
    def _causal_mask(n_query, n_key, device):
        """Boolean [n_query, n_key] mask, True where key position > query position."""
        return torch.triu(torch.ones(n_query, n_key, device=device), diagonal=1).bool()

    def forward(self, batch, x):
        """Decode previous answers against the encoder output.

        Args:
            batch: mapping whose 'prev_ac' entry is an integer tensor of
                shape [bs, s_len].
            x: encoder output, sequence-first: [s_len, bs, embed].

        Returns:
            Tensor of shape [bs, s_len, embed].
        """
        prev_ac = batch['prev_ac']
        # Follow the input's device and length instead of the configured ones.
        device = prev_ac.device
        s_len = prev_ac.shape[1]

        # Previous answered_correctly embedding
        y = self.prev_ac_embedding(prev_ac)

        # Position embedding
        pos_id = torch.arange(s_len, device=device)[None, :]
        pos_y = self.pos_embedding(pos_id)

        # Add embeddings and permute
        y = y + pos_y
        y = y.permute(1, 0, 2) # y: [bs, s_len, embed] => [s_len, bs, embed]

        # MultiHead Attention 1: causal self-attention over previous answers
        self_mask = self._causal_mask(s_len, s_len, device)
        attn_output_1, _ = self.multi_att_1(y, y, y, attn_mask = self_mask)
        y = y + attn_output_1

        # MultiHead Attention 2: causal cross-attention onto the encoder
        # output; the mask is sized [query_len, key_len].
        cross_mask = self._causal_mask(s_len, x.shape[0], device)
        attn_output_2, _ = self.multi_att_2(y, x, x, attn_mask = cross_mask) # query, key, value
        y = y + attn_output_2

        # Permute back to [batch_size, seq_len, embed_dim]
        y = y.permute(1, 0, 2)

        # Feed forward
        y = self.lin_1(y)
        y = self.relu(y)
        y = self.lin_2(y)
        y = self.dropout(y)

        # Return
        return(y)
        
class riiid_model(torch.nn.Module):
    """Encoder-decoder model producing one logit per sequence position.

    Args:
        settings: dict providing 'embed_dim', 'seq_len' and 'device', plus
            whatever the encoder and decoder read from it.
    """
    def __init__(self, settings):
        super(riiid_model, self).__init__()
        self.embed_dim = settings['embed_dim']
        self.seq_len = settings['seq_len']
        # Previously read settings['seq_len'] by mistake.
        self.device = settings['device']
        self.encoder = encoder(settings=settings)
        self.decoder = decoder(settings=settings)
        # Projects each embed_dim-sized position vector to a single logit.
        self.emb_to_seq = torch.nn.Linear(self.embed_dim, 1)

    def forward(self, batch):
        """Return logits of shape [bs, s_len] for `batch`.

        `batch` is passed unchanged to both the encoder and the decoder.
        """
        x = self.encoder(batch)
        y = self.decoder(batch, x)
        y = self.emb_to_seq(y)
        # Drop the trailing size-1 dimension left by the projection.
        y = y[:,:,0]
        return(y)

        
# History of per-batch training AUC values, appended to by the training loop
all_auc = list()

# Build the network and place it on the configured device
model = riiid_model(settings).to(settings['device'])

# Adam over every model parameter; binary cross-entropy taken on raw logits
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = torch.nn.BCEWithLogitsLoss()

# Move the criterion to the same device
criterion.to(settings['device'])


In [None]:
# Train for 5 passes over the training loader, recording one AUC per batch.
# Explicitly select training mode so dropout is active even if the model was
# switched to eval mode by an earlier cell run.
model.train()
for _ in range(5):
    tbar = tqdm.tqdm(train_dataloader)
    for batch in tbar:
        # Move every tensor of the batch to the model's device
        for k in batch.keys():
            batch[k] = batch[k].to(settings['device'])
        optimizer.zero_grad()
        pred = model(batch)
        loss = criterion(pred, batch['answered_correctly'].float())
        loss.backward()
        optimizer.step()

        # For now, do AUC on only the last prediction
        t = batch['answered_correctly'][:, -1].detach().to('cpu').numpy()
        p = pred[:, -1].detach().to('cpu').numpy()

        # ROC AUC is undefined when the batch's labels are all one class;
        # skip such batches instead of letting the metric abort training.
        if np.unique(t).size > 1:
            auc = sklearn.metrics.roc_auc_score(t, p)
            all_auc.append(auc)

In [None]:
# Mean AUC over the most recent 200 recorded training batches
print(np.mean(all_auc[-200:]))

# AUC of every recorded training batch, in order
matplotlib.pyplot.plot(all_auc)
matplotlib.pyplot.show()

In [None]:
# Validation: score the last position of every validation sequence.
val_ac_parts = []
val_pred_parts = []

# Evaluate with dropout disabled and without building autograd graphs, then
# restore whichever mode the model was in so later training is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for batch in valid_dataloader:
            for k in batch.keys():
                batch[k] = batch[k].to(settings['device'])
            pred = model(batch)

            # For now, do AUC on only the last prediction
            t = batch['answered_correctly'][:, -1].detach().to('cpu').numpy()
            p = pred[:, -1].detach().to('cpu').numpy()

            val_ac_parts.append(t)
            val_pred_parts.append(p)
finally:
    model.train(was_training)

# Concatenate once; the leading empty float array keeps the float64 dtype
# and makes the call valid even when the loader yielded no batches.
val_ac = np.concatenate([np.array([])] + val_ac_parts)
val_pred = np.concatenate([np.array([])] + val_pred_parts)

sklearn.metrics.roc_auc_score(val_ac, val_pred)