In [None]:
import os
import gc
import time
import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from pytorch_lightning.metrics.functional.classification import auroc

## Preprocessing

In [None]:
def add_question_data(
    data,
    questions=None,
    questions_path="../input/riiid-test-answer-prediction/questions.csv",
):
    """Attach the question ``part`` to every row of ``data``.

    Args:
        data: DataFrame with a ``content_id`` column holding question ids.
        questions: optional pre-loaded questions table with ``question_id`` and
            ``part`` columns. When omitted it is read from ``questions_path``,
            which lets callers load the CSV once and reuse it.
        questions_path: CSV file read when ``questions`` is not supplied.

    Returns:
        A new DataFrame: ``data`` plus an int64 ``part`` column. The input frame
        is not modified (assigning onto a column-subset slice would otherwise
        raise pandas' SettingWithCopyWarning).

    Raises:
        ValueError: if some ``content_id`` has no matching ``question_id``
            (the lookup yields NaN, which cannot be cast to int64).
    """
    if questions is None:
        questions = pd.read_csv(questions_path, usecols=["question_id", "part"])
    part_by_question = dict(zip(questions["question_id"], questions["part"]))
    parts = data["content_id"].map(part_by_question).astype(np.int64)
    return data.assign(part=parts)


def preprocessing(data, n_sample=100_000_000):
    """Collapse raw interaction rows into one sequence triple per user.

    Only the last ``n_sample`` rows of ``data`` are used, and of those only
    question rows (``content_type_id == 0``). Users with 5 or fewer answered
    questions are discarded.

    Returns:
        A Series indexed by ``user_id``. Each value is the tuple
        ``(content_id array, part array, answered_correctly array)``.
    """
    # NOTE(review): sequence order is the incoming row order within each user;
    # this assumes rows are already chronological — no timestamp is loaded to sort by.
    recent = data.tail(n_sample)
    answered = recent[recent["content_type_id"] == 0][
        ["row_id", "user_id", "content_id", "answered_correctly"]
    ]
    answered = add_question_data(answered)

    def _user_arrays(rows):
        return (
            rows["content_id"].values,
            rows["part"].values,
            rows["answered_correctly"].values,
        )

    per_user = answered.groupby("user_id").apply(_user_arrays)
    n_answered = per_user.apply(lambda arrays: arrays[0].shape[0])
    return per_user[n_answered > 5]

In [None]:
train_path = "../input/cv-strategy-in-the-kaggle-environment/cv5_train.parquet"
valid_path = "../input/cv-strategy-in-the-kaggle-environment/cv5_valid.parquet"
use_cols = ["row_id", "user_id", "content_id", "content_type_id", "answered_correctly"]

# Number of rows handed to `preprocessing`, which keeps the last `n_sample` rows.
# Set DEBUG = True for a quick smoke-test run on a small slice of each file.
DEBUG = False
train_rows = 500_000 if DEBUG else 50_000_000
valid_rows = 100_000 if DEBUG else 5_000_000

train = pd.read_parquet(train_path, columns=use_cols)
train = preprocessing(train, n_sample=train_rows)

valid = pd.read_parquet(valid_path, columns=use_cols)
valid = preprocessing(valid, n_sample=valid_rows)

In [None]:
# Number of users kept after preprocessing, and a preview of each split.
print("Train shape: ", train.shape[0])
display(train.head())

print("Valid shape: ", valid.shape[0])
display(valid.head())

## Dataset

In [None]:
class SAINTDataset(torch.utils.data.Dataset):
    """Per-user interaction sequences padded/truncated to ``max_seq`` for SAINT.

    ``df`` is indexed by user id. Each value is a ``(content_ids, parts,
    answered_correctly)`` triple of equal-length arrays, as returned by
    ``preprocessing``. Sequences shorter than ``max_seq`` are left-padded with 0.
    Longer ones are truncated to their first ``max_seq`` interactions.
    """

    # Decoder input for the first position, where there is no previous response
    # (real responses are 0/1).
    START_TOKEN = 2

    def __init__(self, df, max_seq=100):
        self.df = df
        self.max_seq = max_seq
        self.user_ids = list(df.index.values)

    def __len__(self):
        return len(self.user_ids)

    def __getitem__(self, idx):
        """Return ``(q, c, r, y, src_mask, label_mask)`` for the ``idx``-th user.

        Every tensor has length ``max_seq``:
            q: content ids (0 at padded positions).
            c: part ids (0 at padded positions).
            r: decoder input — ``START_TOKEN`` followed by the previous responses,
               i.e. the targets shifted right by one.
            y: target responses.
            src_mask / label_mask: True at padded positions, False at real ones.
        """
        user_id = self.user_ids[idx]
        q_, c_, r_ = self.df[user_id]

        n_kept = min(len(q_), self.max_seq)
        pad = self.max_seq - n_kept  # padding is on the left

        q = torch.zeros(self.max_seq, dtype=int)
        c = torch.zeros(self.max_seq, dtype=int)
        r = torch.zeros(self.max_seq, dtype=int)
        y = torch.zeros(self.max_seq, dtype=int)
        src_mask = torch.ones(self.max_seq, dtype=bool)
        label_mask = torch.ones(self.max_seq, dtype=bool)

        if n_kept > 0:
            responses = torch.as_tensor(r_[:n_kept], dtype=int)
            q[pad:] = torch.as_tensor(q_[:n_kept], dtype=int)
            c[pad:] = torch.as_tensor(c_[:n_kept], dtype=int)
            y[pad:] = responses
            # The start token applies to truncated sequences too; previously the
            # first slot of a truncated sequence was left as 0 (a "wrong" answer).
            r[pad] = self.START_TOKEN
            r[pad + 1:] = responses[:-1]
            src_mask[pad:] = False
            label_mask[pad:] = False

        return (q, c, r, y, src_mask, label_mask)

In [None]:
# Settings shared by both loaders; only shuffling differs between train and valid.
batch_size = 128
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

train_dataset = SAINTDataset(train)
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)

val_dataset = SAINTDataset(valid)
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

## Define Model (SAINT)

In [None]:
class EncoderEmbedding(nn.Module):
    """Encoder input: sum of learned position, content (question) and part embeddings.

    Args:
        n_content: number of content ids (rows of the content table).
        n_part: number of part ids (rows of the part table).
        n_dims: embedding width.
        seq_len: maximum sequence length (rows of the position table).
        device: kept for backward compatibility only. Positions are created on
            the device of the inputs, so the module keeps working after ``.to(...)``.
    """

    def __init__(self, n_content, n_part, n_dims, seq_len, device):
        super(EncoderEmbedding, self).__init__()
        self.n_dims = n_dims
        self.seq_len = seq_len
        self.device = device

        self.position_embed = nn.Embedding(seq_len, n_dims)
        self.content_embed = nn.Embedding(n_content, n_dims)
        self.part_embed = nn.Embedding(n_part, n_dims)

        # Small-variance init for all three tables.
        for embedding in (self.position_embed, self.content_embed, self.part_embed):
            torch.nn.init.normal_(embedding.weight, mean=0.0, std=0.01)

    def forward(self, content_id, part_id):
        """Embed ``(batch, length)`` id tensors into ``(batch, length, n_dims)``.

        ``length`` may be anything up to ``seq_len``.
        """
        positions = torch.arange(content_id.size(1), device=content_id.device).unsqueeze(0)
        return (
            self.position_embed(positions)
            + self.content_embed(content_id)
            + self.part_embed(part_id)
        )

class DecoderEmbedding(nn.Module):
    """Decoder input: sum of learned position and response embeddings.

    Args:
        n_response: number of response tokens (rows of the response table).
        n_dims: embedding width.
        seq_len: maximum sequence length (rows of the position table).
        device: kept for backward compatibility only. Positions are created on
            the device of the inputs, so the module keeps working after ``.to(...)``.
    """

    def __init__(self, n_response, n_dims, seq_len, device):
        super(DecoderEmbedding, self).__init__()
        self.n_dims = n_dims
        self.seq_len = seq_len
        self.device = device

        self.position_embed = nn.Embedding(seq_len, n_dims)
        self.response_embed = nn.Embedding(n_response, n_dims)

        # Small-variance init for both tables.
        for embedding in (self.position_embed, self.response_embed):
            torch.nn.init.normal_(embedding.weight, mean=0.0, std=0.01)

    def forward(self, response):
        """Embed a ``(batch, length)`` response tensor into ``(batch, length, n_dims)``.

        ``length`` may be anything up to ``seq_len``.
        """
        positions = torch.arange(response.size(1), device=response.device).unsqueeze(0)
        return self.position_embed(positions) + self.response_embed(response)


class SAINTModel(nn.Module):
    """SAINT-style encoder/decoder transformer for answer-correctness prediction.

    The encoder reads content/part ids. The decoder reads the shifted response
    sequence. The model returns one logit per position, which the training code
    scores with ``BCEWithLogitsLoss`` against ``answered_correctly``.

    Args:
        n_questions: rows of the content embedding.
        n_categories: rows of the part embedding.
        n_responses: rows of the response embedding.
        device: forwarded to the embedding modules.
        max_seq: maximum sequence length (position tables and causal mask size).
        d_model: transformer width.
        encoder_dim, decoder_dim: accepted for backward compatibility; unused.
        num_heads: number of attention heads.
    """

    def __init__(
        self,
        n_questions,
        n_categories,
        n_responses,
        device="cpu",
        max_seq=100,
        d_model=128,
        encoder_dim=64,
        decoder_dim=64,
        num_heads=2,
    ):
        super().__init__()
        self.device = device
        self.mask = self.generate_square_subsequent_mask(max_seq)
        self.encoder_embedding = EncoderEmbedding(
            n_content=n_questions,
            n_part=n_categories,
            n_dims=d_model,
            seq_len=max_seq,
            device=device,
        )
        self.decoder_embedding = DecoderEmbedding(
            n_response=n_responses,
            n_dims=d_model,
            seq_len=max_seq,
            device=device,
        )

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=4,
            num_decoder_layers=4,
            dim_feedforward=1024,
            dropout=0.1,
            activation="relu",
        )
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(d_model, d_model),
            torch.nn.ReLU(),
            torch.nn.Linear(d_model, d_model),
            torch.nn.Dropout(p=0.1),
        )
        self.layer_norm = torch.nn.LayerNorm(d_model)
        self.fc1 = nn.Linear(in_features=d_model, out_features=1)

    def forward(self, q, c, r, src_pad_mask, tgt_pad_mask):
        """Return ``(batch, length)`` logits for ``(batch, length)`` inputs.

        ``src_pad_mask`` / ``tgt_pad_mask`` are accepted but not applied (see below).
        """
        length = q.size(1)
        causal = self.mask[:length, :length].to(q.device)

        # nn.Transformer (default layout) expects (length, batch, d_model).
        enc = self.encoder_embedding(content_id=q, part_id=c).transpose(0, 1)
        dec = self.decoder_embedding(response=r).transpose(0, 1)
        x = self.transformer(
            enc,
            dec,
            src_mask=causal,
            tgt_mask=causal,
            # Without a memory mask, decoder step t attends to encoder outputs for
            # later questions, leaking future information into the prediction.
            memory_mask=causal,
            # Key-padding masks are not passed. With left padding plus causal masks,
            # padded query rows would have every key masked (NaN softmax) —
            # presumably why these were disabled.
            # src_key_padding_mask=src_pad_mask,
            # tgt_key_padding_mask=tgt_pad_mask,
        )
        x = self.ffn(x) + x
        x = self.fc1(self.layer_norm(x))
        return x.transpose(0, 1).squeeze(-1)

    def generate_square_subsequent_mask(self, sz):
        """Additive causal mask: 0.0 on/below the diagonal, -inf above it."""
        return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)


## Train Model

In [None]:
# (q, c, r, y, src_mask, label_mask) = next(iter(train_dataloader))

In [None]:
# Define parameters.
n_questions = 13523  # rows of the content_id embedding; presumably one per question id — TODO confirm
n_categories = 8     # rows of the part embedding (SAINTDataset pads with part 0)
n_responses = 3      # responses 0/1 plus the start token 2 used by SAINTDataset

# Seed torch so weight init and DataLoader shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Train model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SAINTModel(n_questions, n_categories, n_responses, device=device)
# Move the model before constructing the optimizer, as torch.optim recommends.
model.to(device)

optimizer = torch.optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()
criterion.to(device)

In [None]:
def _masked_outputs(batch):
    """Run ``model`` on one batch and keep only the non-padded positions.

    ``batch`` is the ``(q, c, r, y, src_mask, label_mask)`` tuple yielded by a
    ``SAINTDataset`` DataLoader. Every tensor is moved to ``device`` first.
    Returns flat float tensors ``(logits, targets)`` over positions where
    ``label_mask`` is False.
    """
    q, c, r, y, src_mask, label_mask = (t.to(device) for t in batch)
    logits = model(q, c, r, src_mask, label_mask)
    keep = torch.logical_not(label_mask)
    return (
        torch.masked_select(logits, keep).float(),
        torch.masked_select(y, keep).float(),
    )


def train_epoch():
    """Run one optimisation pass over ``train_dataloader``.

    No checkpoints are written here; the epoch loop saves the best model.

    Returns:
        ``(mean batch loss, AUROC over all non-padded training positions)``.
    """
    preds, labels, train_loss = [], [], []
    model.train()

    for batch in train_dataloader:
        optimizer.zero_grad()
        yout, y = _masked_outputs(batch)

        loss_ = criterion(yout, y)
        loss_.backward()
        optimizer.step()
        train_loss.append(loss_.item())

        # detach(): a stored tensor still attached to the graph would keep the
        # batch's whole autograd graph alive until the end of the epoch.
        preds.append(yout.detach().cpu())
        labels.append(y.cpu())

    pred = torch.cat(preds, dim=0)
    label = torch.cat(labels, dim=0)
    return np.mean(train_loss), auroc(pred, label)


def val_epoch():
    """Score the model's *current* weights on ``val_dataloader``.

    No checkpoint is reloaded, so validation neither depends on files written
    during training nor rewinds the weights being trained.

    Returns:
        ``(mean batch loss, AUROC over all non-padded validation positions)``.
    """
    preds, labels, val_loss = [], [], []
    model.eval()

    with torch.no_grad():
        for batch in val_dataloader:
            yout, y = _masked_outputs(batch)
            val_loss.append(criterion(yout, y).item())
            preds.append(yout.cpu())
            labels.append(y.cpu())

    pred = torch.cat(preds, dim=0)
    label = torch.cat(labels, dim=0)
    return np.mean(val_loss), auroc(pred, label)

In [None]:
num_epochs = 30
best_score = None
for epoch in range(num_epochs):
    epoch_start = time.time()

    train_loss, train_auc = train_epoch()
    val_loss, val_auc = val_epoch()

    epoch_minutes = (time.time() - epoch_start) / 60
    print(
        f"Epoch:{epoch} | Train Loss:{train_loss:.6f} | Train AUC:{train_auc:.6f} | "
        f"Val Loss:{val_loss:.6f} | Val AUC:{val_auc:.6f} | Time:{epoch_minutes:.2f}min"
    )

    # Keep a checkpoint whenever validation loss improves; `best_epoch` names the
    # file the evaluation section reloads.
    if best_score is None or val_loss < best_score:
        best_score = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), "saint{}.pth".format(epoch))
    gc.collect()

## Evaluation

In [None]:
# Hold-out evaluation split, read with the same columns as train/valid.
# NOTE(review): this is fold 1's validation file, while training used cv5_train;
# confirm the two do not share rows/users, otherwise the AUC below is optimistic.
eval_path = "../input/cv-strategy-in-the-kaggle-environment/cv1_valid.parquet"
use_cols = ["row_id", "user_id", "content_id", "content_type_id", "answered_correctly"]
eval_data = preprocessing(pd.read_parquet(eval_path, columns=use_cols))

In [None]:
# Number of users kept after preprocessing, and a preview of the eval split.
print("Eval shape: ", eval_data.shape[0])
display(eval_data.head())

In [None]:
# Evaluation loader: same batching as training, fixed order.
batch_size = 128
eval_dataset = SAINTDataset(eval_data)
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
)

In [None]:
# Rebuild the model, load the best checkpoint from training, and score the eval split.
model = SAINTModel(n_questions, n_categories, n_responses, device=device)
model.load_state_dict(torch.load(f"saint{best_epoch}.pth"))
model.to(device)
model.eval()

preds, labels = [], []
with torch.no_grad():
    for batch in eval_dataloader:
        q, c, r, y, src_mask, label_mask = (t.to(device) for t in batch)
        logits = model(q, c, r, src_mask, label_mask)
        # Score only real (non-padded) positions.
        keep = torch.logical_not(label_mask)
        preds.append(torch.masked_select(logits, keep).float().view(-1))
        labels.append(torch.masked_select(y, keep).float().view(-1))

pred = torch.cat(preds, dim=0)
label = torch.cat(labels, dim=0)
auc = auroc(pred, label)

In [None]:
# `auc` may be a 0-dim tensor; float() prints a plain number instead of a tensor repr.
print(f"Evaluation AUC is {float(auc):.6f}")