Inference notebook for [this kernel](https://www.kaggle.com/jihunlorenzopark/saint-with-tags-lsi).

In [None]:
# Flags for debug
# TEST_MODE: when True, the test iterator is drained into a list up front
# (submitting constant 0.5 predictions) and the inference loop writes its
# intermediate frames and model inputs to CSV files.
TEST_MODE = False
# MOCK_MODE: when True, the sqlite-backed PredictEnv is used instead of the
# riiideducation environment, and the extra imports it needs are loaded.
MOCK_MODE = False
# Not referenced anywhere in the visible part of this notebook.
SKIP_ADD_FEATURE = False

# LOCAL: the riiideducation package is imported only when this is False.
LOCAL = False

# Directory holding the "{i}groupv1.pickle" history shards read by load_group().
SAINT_PICKLE_PATH = "../input/saint-final"
# Checkpoint whose state dict is loaded into the SAINT model.
SAINT_MODEL_PATH = "../input/saint-final/saintv113.pth"

In [None]:
if MOCK_MODE:
    import pandas as pd
    from pathlib import Path
    import sqlite3
    import riiideducation
    from sklearn.metrics import roc_auc_score
    from tqdm.notebook import tqdm

In [None]:
import psutil
import joblib
import pandas as pd
import numpy as np
import gc
from sklearn.metrics import roc_auc_score
from collections import defaultdict
from tqdm import tqdm
import lightgbm as lgb
import matplotlib.pyplot as plt
import seaborn as sns

import pickle
import random
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

if LOCAL is False:
    import riiideducation

In [None]:
def load_group(pickle_dir=None, n_parts=10):
    """Load the sharded history pickles and join them into a single object.

    Args:
        pickle_dir: Directory containing the ``{i}groupv1.pickle`` shards.
            Defaults to ``SAINT_PICKLE_PATH``.
        n_parts: Number of shards; files ``0 .. n_parts - 1`` are read.

    Returns:
        The shards concatenated with ``pd.concat`` in shard order, the single
        shard itself when ``n_parts`` is 1, or ``None`` when ``n_parts`` is 0.
    """
    if pickle_dir is None:
        pickle_dir = SAINT_PICKLE_PATH

    # NOTE(review): pickle.load executes code embedded in the file; only point
    # this at trusted datasets.
    parts = []
    for i in range(n_parts):
        with open(f"{pickle_dir}/{i}groupv1.pickle", "rb") as f:
            parts.append(pickle.load(f))

    if not parts:
        return None
    if len(parts) == 1:
        return parts[0]

    # One concat over all shards instead of re-copying the growing result
    # after every shard.
    group = pd.concat(parts)
    del parts
    gc.collect()
    return group

In [None]:
# Per-user interaction history: looked up by user_id, each entry unpacking to
# (content_id, answered_correctly, lag, elapsed_time, part, lsi_topic).
# It is read by TestDataset and extended in place by inference().
group = load_group()

In [None]:
# Random seed
SEED = 123


# Function to seed everything
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Hash randomisation of the running interpreter is fixed at start-up, so
    # this only takes effect in processes spawned afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # torch was previously left unseeded even though it is the main library
    # used by this notebook.
    torch.manual_seed(seed)


seed_everything(SEED)

# SAINT Model

In [None]:
# Sequence length fed to the model; histories are padded/truncated to this.
MAX_SEQ = 100
# Sizes of the categorical vocabularies; each embedding table below is created
# with a small offset on top of these (see Encoder / Decoder).
n_skill = 13523
n_part = 7
n_et = 300
n_lt = 1441
n_lsi = 128
# Default dropout probability of Encoder, Decoder and SAINTModel.
DROPOUT = 0.1
# Embedding width passed to SAINTModel by create_model().
EMBED_SIZE = 256
# Not referenced in the visible part of this notebook.
BATCH_SIZE = 256

def future_mask(seq_length):
    """Return a square boolean tensor that is True strictly above the diagonal.

    Entry ``[i, j]`` is True exactly when ``j > i``.
    """
    upper_triangle = np.triu(np.ones((seq_length, seq_length), dtype=bool), k=1)
    return torch.from_numpy(upper_triangle)

class FFN(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear -> Dropout(0.2).

    The hidden width is ``forward_expansion * state_size``; the output width
    equals the input width. ``bn_size`` is accepted but not used.
    """

    def __init__(self, state_size = 200, forward_expansion = 1, bn_size=MAX_SEQ - 1):
        super().__init__()
        self.state_size = state_size
        hidden_size = forward_expansion * state_size

        self.lr1 = nn.Linear(state_size, hidden_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(hidden_size, state_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        hidden = self.relu(self.lr1(x))
        return self.dropout(self.lr2(hidden))

class Encoder(nn.Module):
    """Embeds the exercise-side inputs and sums them into one sequence.

    The output is dropout(layer_norm(exercise + position + part + lsi_topic)).
    ``forward_expansion``, ``num_layers`` and ``heads`` are accepted but not
    used by this module.
    """

    def __init__(self, n_skill, n_pt=7, n_lsi=n_lsi, max_seq=100, embed_dim=128, dropout = DROPOUT, forward_expansion = 1, num_layers=1, heads = 8):
        super().__init__()
        self.n_skill = n_skill
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(n_skill + 1, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq, embed_dim)
        self.dropout = nn.Dropout(dropout)

        self.n_pt = n_pt
        self.pt_embedding = nn.Embedding(n_pt + 1, embed_dim)
        self.n_lsi = n_lsi
        self.lsi_embedding = nn.Embedding(n_lsi + 1, embed_dim)
        self.layer_normal = nn.LayerNorm(embed_dim)

    def forward(self, x, question_ids, pt_x, lsi_x):
        # question_ids is part of the call signature but is not read here.
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        summed = (
            self.embedding(x)
            + self.pos_embedding(positions)
            + self.pt_embedding(pt_x)
            + self.lsi_embedding(lsi_x)
        )
        return self.dropout(self.layer_normal(summed))

    
class Decoder(nn.Module):
    """Embeds the response-side inputs and sums them into one sequence.

    The output is dropout(layer_norm(response + position + elapsed + lag)).
    ``forward_expansion``, ``num_layers`` and ``heads`` are accepted but not
    used by this module.
    """

    def __init__(self, n_et=n_et, n_lt=n_lt, max_seq=100, embed_dim=128, dropout = DROPOUT, forward_expansion = 1, num_layers=1, heads = 8):
        super().__init__()
        self.embed_dim = embed_dim
        self.pos_embedding = nn.Embedding(max_seq, embed_dim)
        self.dropout = nn.Dropout(dropout)

        self.n_response = 2
        self.n_et = n_et
        self.n_lt = n_lt
        self.response_embedding = nn.Embedding(self.n_response + 2, embed_dim)
        self.et_embedding = nn.Embedding(self.n_et + 2, embed_dim)
        self.lt_embedding = nn.Embedding(self.n_lt + 2, embed_dim)
        self.layer_normal = nn.LayerNorm(embed_dim)

    def forward(self, c, et, lt):
        positions = torch.arange(c.size(1), device=c.device).unsqueeze(0)
        summed = (
            self.response_embedding(c)
            + self.pos_embedding(positions)
            + self.et_embedding(et)
            + self.lt_embedding(lt)
        )
        return self.dropout(self.layer_normal(summed))
    
class SAINTModel(nn.Module):
    """Encoder/decoder embeddings feeding an ``nn.Transformer`` and a logit head.

    ``forward`` returns one raw logit per sequence position (no sigmoid).
    """

    def __init__(self, n_skill, n_pt=7, n_lsi=n_lsi, n_et=n_et, n_lt=n_lt, max_seq=100, embed_dim=128, dropout = DROPOUT, forward_expansion = 1, enc_layers=3, dec_layers=3, heads = 8):
        super().__init__()
        self.encoder = Encoder(n_skill, n_pt, n_lsi, max_seq, embed_dim, dropout, forward_expansion, num_layers=enc_layers)
        self.decoder = Decoder(n_et, n_lt, max_seq, embed_dim, dropout, forward_expansion, num_layers=dec_layers)
        self.transformer = torch.nn.Transformer(
            d_model=embed_dim,
            nhead=heads,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            dim_feedforward=embed_dim * forward_expansion,
            dropout=dropout,
        )

        self.ffn = FFN(embed_dim, forward_expansion = forward_expansion)
        self.pred = nn.Linear(embed_dim, 1)
        self.layer_normal = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, question_ids, pt_x, lsi_x, c, et, lt):
        # The transformer is fed sequence-first tensors, hence the permutes.
        src = self.encoder(x, question_ids, pt_x, lsi_x).permute(1, 0, 2)
        tgt = self.decoder(c, et, lt).permute(1, 0, 2)

        # The same future mask is used for encoder and decoder self-attention.
        causal = future_mask(src.size(0)).to(src.device)
        attended = self.transformer(src, tgt, src_mask=causal, tgt_mask=causal)
        # [s_len, bs, embed] => [bs, s_len, embed]
        attended = self.layer_normal(attended).permute(1, 0, 2)

        # Residual feed-forward block; it reuses the same LayerNorm instance.
        out = self.ffn(attended)
        out = self.dropout(self.layer_normal(out + attended))
        return self.pred(out).squeeze(-1)
    
class TestDataset(Dataset):
    """Builds one fixed-length model input per row of ``test_df``.

    Args:
        samples: Per-user history looked up by user id (``samples.index`` /
            ``samples[user_id]``); each entry unpacks to the six sequences
            (content_id, answered_correctly, lag, elapsed_time, part, lsi_topic).
        test_df: Rows to predict; must provide ``user_id``, ``content_id``,
            ``part`` and ``lsi_topic``.
        n_skill, n_et, n_lt, n_pt: Stored on the instance; not used here.
        max_seq: Length of every returned sequence.
    """

    def __init__(self, samples, test_df, n_skill=n_skill, n_et=n_et, n_lt=n_lt, n_pt=7, max_seq=MAX_SEQ):
        super().__init__()
        self.samples = samples
        self.user_ids = list(test_df["user_id"].unique())
        self.test_df = test_df
        self.n_skill = n_skill
        self.n_et = n_et
        self.n_lt = n_lt
        self.n_pt = n_pt
        self.max_seq = max_seq

    def __len__(self):
        return self.test_df.shape[0]

    def _right_align(self, values):
        """Return an int array of length ``max_seq``: zeros, then the last values."""
        padded = np.zeros(self.max_seq, dtype=int)
        tail = np.asarray(values)[-self.max_seq:]
        # An empty history must be skipped: padded[-0:] addresses the whole
        # array and assigning an empty array to it raises a broadcast error.
        if len(tail):
            padded[-len(tail):] = tail
        return padded

    def __getitem__(self, index):
        """Return ``(x, target_id, pt_x, lsi_x, c, et, lt)`` for row ``index``.

        ``x``, ``pt_x`` and ``lsi_x`` are the stored history shifted left by
        one step with the current row's value appended. ``target_id`` is the
        unshifted content history. ``c``, ``et`` and ``lt`` hold the stored
        responses, elapsed times and lags, each offset by +1 so that 0 is
        only ever padding.
        """
        test_info = self.test_df.iloc[index]
        user_id = test_info["user_id"]

        if user_id in self.samples.index:
            content_id, answered_correctly, lag, elapsed_time, part, lsi_topic = (
                np.asarray(seq) for seq in self.samples[user_id]
            )
        else:
            # Unknown user: every history sequence is pure padding.
            content_id = answered_correctly = lag = elapsed_time = part = lsi_topic = (
                np.zeros(0, dtype=int)
            )

        content_id_seq = self._right_align(content_id)
        part_seq = self._right_align(part)
        lsi_topic_seq = self._right_align(lsi_topic)

        x = np.append(content_id_seq[1:], test_info["content_id"])
        pt_x = np.append(part_seq[1:], test_info["part"])
        lsi_x = np.append(lsi_topic_seq[1:], test_info["lsi_topic"])
        c = self._right_align(answered_correctly + 1)
        et = self._right_align(elapsed_time + 1)
        lt = self._right_align(lag + 1)

        return x, content_id_seq, pt_x, lsi_x, c, et, lt

In [None]:
# Main changes are possibility of forward expansion and stacking of encoding layers
def create_model():
    """Build a freshly initialised SAINTModel with this notebook's settings."""
    hyper_params = dict(
        n_pt=7,
        n_lsi=n_lsi,
        n_et=n_et,
        n_lt=n_lt,
        max_seq=MAX_SEQ,
        embed_dim=EMBED_SIZE,
        forward_expansion=1,
        enc_layers=2,
        dec_layers=2,
        heads=8,
        dropout=0.1,
    )
    return SAINTModel(n_skill, **hyper_params)

# Load SAINT Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

saint_model = create_model()
# Load the checkpoint onto the CPU regardless of where it was saved: the model
# is still on the CPU at this point and is moved to `device` right after.
# This replaces a bare `except:` retry that hid the real error and read the
# file twice.
saint_model.load_state_dict(torch.load(SAINT_MODEL_PATH, map_location="cpu"))
saint_model.to(device)
saint_model.eval()

In [None]:
class PredictEnv:
    """Local stand-in for the competition environment, backed by sqlite.

    Records of the selected folds are copied into an in-memory database and
    handed out in batches by ``iter_test``; predictions submitted through
    ``predict`` are stored and scored by ``get_predictions``.
    """

    def __init__(self, folds_path, folds):
        """Open an in-memory database and load the requested folds.

        Args:
            folds_path: Path of the sqlite database holding a ``train`` table.
            folds: Iterable of fold numbers to replay.
        """
        self.conn = sqlite3.connect(':memory:')
        self.c = self.conn.cursor()
        self.setup_folds(folds_path, folds)

    def setup_folds(self, folds_path, folds):
        """(Re)create the ``b_records`` and ``b_users`` working tables.

        ``b_records`` holds the rows of the chosen folds. ``b_users`` holds,
        per user, the next task container to serve (initialised to one before
        the first) and the last one available.
        """
        # NOTE(review): the path and fold list are interpolated into the SQL
        # text; that is fine for values written in this notebook but must not
        # be fed untrusted input.
        fold_list = ', '.join(map(str, folds))
        self.c.executescript(f"""
            ATTACH DATABASE "{folds_path}" AS folds_db;

            DROP TABLE IF EXISTS b_records;

            CREATE TABLE b_records AS
            SELECT row_id, timestamp, user_id, content_id, content_type_id, task_container_id, prior_question_elapsed_time,
                prior_question_had_explanation, answered_correctly, user_answer
            FROM folds_db.train
            WHERE fold in ({fold_list})
            ORDER BY user_id, task_container_id, row_id;

            CREATE INDEX user_id_task_container_id_index ON b_records (user_id, task_container_id);

            DROP TABLE IF EXISTS b_users;

            CREATE TABLE b_users AS
            SELECT user_id, MIN(task_container_id) - 1 task_container_id_next, MAX(task_container_id) task_container_id_max
            FROM b_records
            GROUP BY user_id
                ORDER BY user_id, task_container_id_next;

            CREATE UNIQUE INDEX user_id_index ON b_users (user_id);

            ALTER TABLE b_users
                ADD COLUMN group_num INTEGER;

        """).fetchone()

        self.group_num = 0
        self.records_remaining = self.c.execute('SELECT COUNT(*) FROM b_records').fetchone()[0]
        self.df_users = pd.read_sql('SELECT * FROM b_users', self.conn)

    def iter_test(self):
        """Yield ``(batch_df, sample_prediction_df)`` until all records are served.

        Each batch carries, in its first row only, the answers and responses
        of the previously yielded batch as bracketed strings. The sample
        prediction frame has ``row_id`` and a constant 0.5 in
        ``answered_correctly``. Both frames are indexed by ``group_num``.
        """
        next_correct = '[]'
        next_responses = '[]'

        while self.records_remaining:
            # Pick a random number of users that still have containers left.
            # The INSERT always conflicts on the unique user_id index, so the
            # DO UPDATE branch advances their next container and stamps them
            # with the current group number.
            self.c.execute(f"""
                INSERT INTO b_users (user_id)
                SELECT user_id
                FROM b_users
                WHERE task_container_id_next <= task_container_id_max
                LIMIT 1 + ABS(RANDOM() % 40) + ABS(RANDOM() % 1000) * (ABS(RANDOM() % 100) < 5)
                ON CONFLICT (user_id) DO UPDATE SET
                    task_container_id_next = task_container_id_next + 1,
                    group_num = {self.group_num};
            """).fetchone()

            self.conn.commit()

            df_b = pd.read_sql(f"""
                SELECT r.*
                FROM b_records r
                JOIN b_users u
                ON group_num = {self.group_num}
                    AND r.user_id = u.user_id
                    AND r.task_container_id = u.task_container_id_next
            """, self.conn)

            if len(df_b):
                df_b['group_num'] = self.group_num
                df_b['prior_group_answers_correct'] = None
                df_b.at[0, 'prior_group_answers_correct'] = next_correct

                df_b['prior_group_responses'] = None
                df_b.at[0, 'prior_group_responses'] = next_responses

                # Remember this batch's labels for the next batch, then hide
                # them from the consumer.
                next_correct = f'[{(", ").join(df_b.answered_correctly.astype(str))}]'
                next_responses = f'[{(", ").join(df_b.user_answer.astype(str))}]'
                del df_b['answered_correctly']
                del df_b['user_answer']

                df_b = df_b.set_index('group_num')

                df_p = df_b[['row_id']].copy()
                df_p['answered_correctly'] = 0.5

                self.records_remaining -= len(df_b)

                yield df_b, df_p

            self.group_num += 1

    def predict(self, df_pred):
        """Append a batch of predictions to the ``predictions`` table.

        Args:
            df_pred: Frame with ``row_id`` and ``answered_correctly`` columns.

        Raises:
            RuntimeError: If any ``answered_correctly`` value equals -1.
        """
        if (df_pred.answered_correctly == -1).any():
            # Same exception type the former bare `raise` produced, but with a
            # message that says what is wrong.
            raise RuntimeError(
                "predictions contain answered_correctly == -1; "
                "every submitted row needs a real prediction"
            )
        df_pred.reset_index().to_sql('predictions', self.conn, if_exists='append', index=False)

    def get_predictions(self):
        """Join stored predictions with the true labels and print the ROC AUC.

        Returns:
            DataFrame with ``row_id``, ``y_true`` and ``y_pred``; the score is
            also kept on ``self.score``.
        """
        df_preds = pd.read_sql("""
            SELECT p.row_id, b.answered_correctly y_true, p.answered_correctly y_pred
            FROM predictions p
            JOIN b_records b
            ON p.row_id = b.row_id
        """, self.conn)

        self.score = roc_auc_score(df_preds.y_true, df_preds.y_pred)

        print(f'ROC AUC Score: {self.score:0.4f}')

        return df_preds

In [None]:
# Choose the prediction environment: the local sqlite replay in MOCK_MODE,
# the riiideducation environment otherwise.
if MOCK_MODE:
    FOLDS = Path('../input/riiid-folds/riiid.db')
    env = PredictEnv(FOLDS, [0, 1])
    iter_test = env.iter_test()

else:
    env = riiideducation.make_env()
    iter_test = env.iter_test()
    set_predict = env.predict


# In TEST_MODE the iterator is consumed once up front, answering every question
# row with 0.5, and the collected batches are replayed from a list afterwards.
if TEST_MODE and not isinstance(iter_test, list):
    list_df = []
    for itr, (df_test, sample_prediction_df) in enumerate(iter_test):
        df_test.loc[:, 'answered_correctly'] = 0.5
        list_df.append((df_test, None))
        env.predict(df_test.loc[df_test['content_type_id'] == 0, ['row_id', 'answered_correctly']])
    iter_test = list_df
    print("TEST_MODE MODE ENABLED")
else:
    print("TEST_MODE MODE DISABLED")

In [None]:
# Name of the column that holds the labels / the predictions.
TARGET = "answered_correctly"

In [None]:
QUESTION_FEATURES = ["part", "question_id", "lsi_topic"]
question_file = "../input/question-features-0102/questions.pickle"
questions_df = pd.read_pickle(question_file)[QUESTION_FEATURES]
# Re-encode lsi_topic as consecutive integers starting at 1, numbered in order
# of first appearance; missing topics are treated as the single value -1
# before numbering.
lsi_filled = questions_df["lsi_topic"].fillna(-1)
lsi_codes = {topic: code for code, topic in enumerate(lsi_filled.unique())}
questions_df["lsi_topic"] = lsi_filled.map(lsi_codes) + 1

In [None]:
import warnings

# Silence every warning for the rest of the notebook.
warnings.filterwarnings(action="ignore")

In [None]:
# Per-user cache of the most recent row: a mapping from user_id to a dict of
# that row's fields. inference() reads it to compute lags across batches and
# overwrites the entries of users it has just seen.
last_row_file = "../input/saint-final/last_row_states.pickle"
with open(last_row_file, "rb") as f:
    last_row_states = pickle.load(f)
    
# Using time series api that simulates production predictions
def _label_previous_batch(previous_test_df, test_df, TARGET):
    """Attach the revealed answers to the previous batch and derive its lag.

    Args:
        previous_test_df: The previous batch (questions and lectures).
        test_df: The current batch; its first row carries the previous
            batch's answers in ``prior_group_answers_correct``.
        TARGET: Name of the label column to fill.

    Returns:
        The question rows of the previous batch with ``TARGET``, the bucketed
        ``prior_question_elapsed_time``, ``current_container_size`` and
        ``lag`` columns. Rows taken from ``last_row_states`` are used for the
        lag computation only and are removed again before returning.
    """
    # NOTE(review): eval() executes whatever text the environment supplies in
    # this column. If it is always a plain list literal, ast.literal_eval or
    # json.loads would be the safer parser -- confirm before swapping.
    previous_test_df[TARGET] = eval(test_df["prior_group_answers_correct"].iloc[0])
    previous_test_df = previous_test_df[previous_test_df.content_type_id == False].copy()
    previous_test_df["user_to_remove"] = False
    previous_test_df["prior_question_elapsed_time"] = (
        previous_test_df["prior_question_elapsed_time"] // 100
    ).clip(0, 300)
    previous_test_df['current_container_size'] = (
        previous_test_df[['user_id', 'task_container_id']]
        .groupby(['user_id', 'task_container_id'])['task_container_id']
        .transform('size')
    )

    if TEST_MODE: previous_test_df.to_csv("./previous_test_df1.csv")

    # Prepend each known user's cached last row so that shift(1) below can see
    # across the batch boundary.
    common_users = set(previous_test_df["user_id"].unique()).intersection(set(last_row_states.keys()))
    last_records = pd.DataFrame([
        {**last_row_states[user_id], 'user_id': user_id} for user_id in common_users
    ])
    if len(last_records) != 0:
        previous_test_df = pd.concat([last_records, previous_test_df]).reset_index(drop=True)
        previous_test_df = previous_test_df.sort_values(['user_id', 'timestamp'], ascending=True).reset_index(drop=True)

    # Previous row's timestamp / container size per user, then broadcast the
    # first value of each task container to all of its rows.
    previous_test_df['last_timestamp'] = (
        previous_test_df[['user_id', 'timestamp']]
        .groupby(['user_id'])['timestamp']
        .shift(1, fill_value=0)
    )
    previous_test_df['last_timestamp'] = (
        previous_test_df[['user_id', 'task_container_id', 'last_timestamp']]
        .groupby(['user_id', 'task_container_id'])['last_timestamp']
        .transform('first')
    )
    previous_test_df['last_task_container_size'] = (
        previous_test_df[['user_id', 'current_container_size']]
        .groupby(['user_id'])['current_container_size']
        .shift(1, fill_value=0)
    )
    previous_test_df['last_task_container_size'] = (
        previous_test_df[['user_id', 'task_container_id', 'last_task_container_size']]
        .groupby(['user_id', 'task_container_id'])['last_task_container_size']
        .transform('first')
    )

    if TEST_MODE: previous_test_df.to_csv("./previous_test_df2.csv")
    previous_test_df['lag'] = (
        previous_test_df['timestamp']
        - previous_test_df['last_timestamp']
        - (previous_test_df['prior_question_elapsed_time'] * previous_test_df['last_task_container_size'])
    )
    previous_test_df["lag"] = (previous_test_df["lag"] // (100 * 60)).clip(0, 1440)
    if TEST_MODE: previous_test_df.to_csv("./previous_test_df3.csv")
    if TEST_MODE: previous_test_df.to_csv("./previous_test_df4.csv")
    previous_test_df = previous_test_df[previous_test_df["user_to_remove"] != True]
    if TEST_MODE: previous_test_df.to_csv("./previous_test_df5.csv")
    return previous_test_df


def _extend_user_histories(previous_test_df):
    """Append the labelled rows to each user's entry in the global ``group``."""
    history_columns = [
        'content_id',
        'answered_correctly',
        'lag',
        'prior_question_elapsed_time',
        'part',
        'lsi_topic',
    ]
    prev_group = previous_test_df[['user_id'] + history_columns].groupby('user_id').apply(
        lambda r: tuple(r[column].values for column in history_columns)
    )
    for prev_user_id in prev_group.index:
        new_parts = prev_group[prev_user_id]
        if prev_user_id in group.index:
            old_parts = group[prev_user_id]
            # Keep only the most recent MAX_SEQ interactions per user.
            group[prev_user_id] = tuple(
                np.append(old, new)[-MAX_SEQ:] for old, new in zip(old_parts, new_parts)
            )
        else:
            group[prev_user_id] = tuple(new_parts)


def _cache_last_rows(previous_test_df):
    """Store each user's last labelled row in the global ``last_row_states``."""
    users_to_cache = previous_test_df.groupby("user_id").last()
    cached_columns = [
        "timestamp",
        "content_id",
        "content_type_id",
        "task_container_id",
        "prior_question_elapsed_time",
        "prior_question_had_explanation",
        # Read from the per-user frame. Taking it from the per-row frame paired
        # user i with the container size of row i instead of the user's own
        # last row.
        "current_container_size",
    ]
    per_column_values = (users_to_cache[column].values for column in cached_columns)
    for user_id, *values in zip(users_to_cache.index, *per_column_values):
        row = {"user_id": user_id}
        row.update(zip(cached_columns, values))
        # Marks the row as a helper to be dropped after the lag computation.
        row["user_to_remove"] = True
        last_row_states[user_id] = row


def _predict_batch(test_df, saint_model):
    """Return the predicted probability for every row of ``test_df``."""
    test_dataset = TestDataset(group, test_df)
    test_dataloader = DataLoader(test_dataset, batch_size=51200, shuffle=False)

    outs = []
    for item in test_dataloader:
        x, target_id, pt_x, lsi_x, c, et, lt = (tensor.to(device).long() for tensor in item)
        if TEST_MODE:
            # Dump the first sample's inputs; moved to the CPU first so the
            # dump also works when running on a GPU.
            pd.DataFrame({
                "x": x[0].detach().cpu().numpy(),
                "target_id": target_id[0].detach().cpu().numpy(),
                "pt_x": pt_x[0].detach().cpu().numpy(),
                "lsi_x": lsi_x[0].detach().cpu().numpy(),
                "c": c[0].detach().cpu().numpy(),
                "et": et[0].detach().cpu().numpy(),
                "lt": lt[0].detach().cpu().numpy(),
            }).to_csv("input.csv")

        with torch.no_grad():
            output = saint_model(x, target_id, pt_x, lsi_x, c, et, lt)
        # Only the last position corresponds to the row being predicted.
        outs.extend(torch.sigmoid(output)[:, -1].view(-1).data.cpu().numpy())

    return np.array(outs)


def inference(iter_test, TARGET, saint_model, questions_df):
    """Run the prediction loop over the batches yielded by ``iter_test``.

    For every batch: the answers revealed for the previous batch are folded
    into the global per-user state (``group`` and ``last_row_states``), the
    question features are merged in, the model predicts every question row,
    and the predictions are submitted through ``set_predict`` unless
    TEST_MODE or MOCK_MODE is set.

    Args:
        iter_test: Iterable of ``(test_df, sample_prediction_df)`` pairs.
        TARGET: Name of the label / prediction column.
        saint_model: Model called as
            ``saint_model(x, target_id, pt_x, lsi_x, c, et, lt)``.
        questions_df: Question features, merged on ``question_id``.
    """
    previous_test_df = None

    for test_df, _sample_prediction_df in tqdm(iter_test):
        if previous_test_df is not None:
            answered_df = _label_previous_batch(previous_test_df, test_df, TARGET)
            # Update SAINT
            _extend_user_histories(answered_df)
            _cache_last_rows(answered_df)

        test_df = pd.merge(test_df, questions_df[QUESTION_FEATURES], left_on='content_id', right_on='question_id', how='left')
        test_df['prior_question_had_explanation'] = test_df.prior_question_had_explanation.fillna(False).astype('int8')
        # Plain assignment: an in-place fillna on a selected column does not
        # write back to the frame under pandas copy-on-write.
        test_df['prior_question_elapsed_time'] = test_df['prior_question_elapsed_time'].fillna(0)

        previous_test_df = test_df.copy()
        test_df = test_df[test_df['content_type_id'] == 0].reset_index(drop=True)

        # SAINT inference
        test_df[TARGET] = _predict_batch(test_df, saint_model)
        if not TEST_MODE and not MOCK_MODE:
            set_predict(test_df[['row_id', TARGET]])

    print('Job Done')


inference(iter_test, TARGET, saint_model, questions_df)