- Validate submission with TestIterator: https://www.kaggle.com/authman/lgbm-model-1?scriptVersionId=49016504
- History: Keep last 100 interactions per user in memory
- Build the test sequence on the fly from the current test_df and each user's history
- Pad on the fly... might have to cache # items to predict per user or simply look at the test_df...

# Setup

In [None]:
# Competition API module; it provides the environment used below.
import riiideducation
# Create the environment once and obtain the iterator over test batches.
# NOTE(review): presumably make_env() may only be called once per kernel session - confirm against the competition docs.
env = riiideducation.make_env()
iter_test = env.iter_test()

In [None]:
import numpy as np
import pandas as pd
import gc, joblib, random

from tqdm.auto import tqdm
from pathlib import Path
from typing import List
from collections import defaultdict
from bitarray import bitarray
from time import time


import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy

# Column name -> dtype mapping for the training CSV columns.
# NOTE(review): presumably meant for pd.read_csv(TRAIN_PATH, dtype=TRAIN_DTYPES);
# no reader is visible in this part of the notebook.
TRAIN_DTYPES = {
    # 'row_id': np.uint32,
    'timestamp': np.uint64,
    'user_id': np.uint32,
    'content_id': np.uint16,
    'content_type_id': np.uint8,
    'task_container_id': np.uint16,
    'user_answer': np.int8,
    'answered_correctly': np.int8,
    'prior_question_elapsed_time': np.float32,
    'prior_question_had_explanation': 'boolean'
}

# Locations of the competition input files.
DATA_DIR = Path('../input/riiid-test-answer-prediction')
TRAIN_PATH = DATA_DIR / 'train.csv'
QUESTIONS_PATH = DATA_DIR / 'questions.csv'
LECTURES_PATH = DATA_DIR / 'lectures.csv'

# Run-wide settings.
DEVICE = 'cuda'
BATCH_SIZE = 512
# Sequence length used by the attention masks and the positional embedding below.
WINDOW_SIZE = 128
# NOTE(review): presumably the padding index shared by the embeddings (they use padding_idx=0) - confirm.
PAD = 0
# Seed passed to seed_everything().
GREAT_SEED = 1337

def seed_everything(seed_value):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available the CUDA generators are seeded as well and the
    cuDNN autotuner (``benchmark``) is switched off. ``cudnn.deterministic``
    is explicitly left at ``False``, so cuDNN is not forced onto
    deterministic kernels.

    Args:
        seed_value: integer seed applied to every generator.
    """
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed_value)

    # Nothing GPU-specific to configure on a CPU-only machine.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False
    
# Seed every RNG once, as soon as this cell runs.
seed_everything(GREAT_SEED)
# Make sure automatic garbage collection is switched on.
gc.enable()

In [None]:
# NOTE: Not sure if we're going to make use of this..

# Number of leading entries of dataset_cols that are dropped before the
# remaining columns are numbered (every COL_* id is shifted down by this).
COL_OFFSET_CUT = 1

dataset_cols = [
    # Leading columns; with COL_OFFSET_CUT = 1 'timestamp' is cut and
    # 'user_id' becomes column 0.
    'timestamp',
    'user_id',
    'task_container_id',

    #############
    # Categorical / id-like inputs:
    'content_id',
    'part_id',
    'prior_question_elapsed_time',
    'prior_question_had_explanation',
    'incorrect_rank', # decoder feature that needs help..........
    'content_type_id',
    'bundle_id',

    # only these of the 4 have signal (0-index)
    'answer_ratio1',
    'answer_ratio2',
    'correct_streak_u',
    'incorrect_streak_u',
    'correct_streak_alltime_u',
    'incorrect_streak_alltime_u',

    # our session information + a lifetime field (elapsed time sum on platform—probably should be capped)
    'session_content_num_u',
    'session_duration_u',
    'session_ans_duration_u',
    'lifetime_ans_duration_u',

    # First block of continuous - timestamp intervals
    'lag_ts_u_recency',
    'last_ts_u_recency1',
    'last_ts_u_recency2',
    'last_ts_u_recency3',
    'last_correct_ts_u_recency',
    'last_incorrect_ts_u_recency',

    # Second block of continuous - how often user is right on global/part/session basis
    'correctness_u_recency',
    'part_correctness_u_recency',
    'session_correctness_u',

    # Have we been served this content before?
    'encountered',

    # Average score of most frequent asked questions
    'diagnostic_u_recency_1',
    'diagnostic_u_recency_2',
    'diagnostic_u_recency_3',
    'diagnostic_u_recency_4',
    'diagnostic_u_recency_5',
    'diagnostic_u_recency_6',

    'answered_correctly',
]


def _col(name):
    """Return the position of `name` in dataset_cols after the leading cut."""
    return dataset_cols.index(name) - COL_OFFSET_CUT


# Identifier-like columns
COL_USER_ID = _col('user_id')
COL_TASK_CONTAINER_ID = _col('task_container_id')
COL_CONTENT_ID = _col('content_id')
COL_PART_ID = _col('part_id')
COL_PRIOR_QUESTION_ELAPSED_TIME = _col('prior_question_elapsed_time')
COL_PRIOR_QUESTION_EXPLANATION = _col('prior_question_had_explanation')
COL_INCORRECT_RANK = _col('incorrect_rank')
COL_CONTENT_TYPE_ID = _col('content_type_id')
COL_BUNDLE_ID = _col('bundle_id')

# Answer ratios and streak counters
COL_ANSWER_RATIO1 = _col('answer_ratio1')
COL_ANSWER_RATIO2 = _col('answer_ratio2')
COL_CORRECT_STREAK_U = _col('correct_streak_u')
COL_INCORRECT_STREAK_U = _col('incorrect_streak_u')
COL_CORRECT_STREAK_ALLTIME_U = _col('correct_streak_alltime_u')
COL_INCORRECT_STREAK_ALLTIME_U = _col('incorrect_streak_alltime_u')

# Session / lifetime aggregates
COL_SESSION_CONTENT_NUM_U = _col('session_content_num_u')
COL_SESSION_DURATION_U = _col('session_duration_u')
COL_SESSION_CORRECTNESS_U = _col('session_correctness_u')
COL_SESSION_ANS_DURATION_U = _col('session_ans_duration_u')
COL_LIFETIME_ANS_DURATION_U = _col('lifetime_ans_duration_u')

# Timestamp-interval and correctness recency features
COL_LAG_TS_RECENCY = _col('lag_ts_u_recency')
COL_LAST_TS_RECENCY1 = _col('last_ts_u_recency1')
COL_LAST_TS_RECENCY2 = _col('last_ts_u_recency2')
COL_LAST_TS_RECENCY3 = _col('last_ts_u_recency3')
COL_CORRECTNESS_U_RECENCY = _col('correctness_u_recency')
COL_PART_CORRECTNESS_U_RECENCY = _col('part_correctness_u_recency')
COL_LAST_CORRECT_TS_U_RECENCY = _col('last_correct_ts_u_recency')
COL_LAST_INCORRECT_TS_U_RECENCY = _col('last_incorrect_ts_u_recency')

COL_ENCOUNTERED = _col('encountered')

# Diagnostic-question features
COL_DIAGNOSTIC_RECENCY_1 = _col('diagnostic_u_recency_1')
COL_DIAGNOSTIC_RECENCY_2 = _col('diagnostic_u_recency_2')
COL_DIAGNOSTIC_RECENCY_3 = _col('diagnostic_u_recency_3')
COL_DIAGNOSTIC_RECENCY_4 = _col('diagnostic_u_recency_4')
COL_DIAGNOSTIC_RECENCY_5 = _col('diagnostic_u_recency_5')
COL_DIAGNOSTIC_RECENCY_6 = _col('diagnostic_u_recency_6')

# Target column
COL_ANSWERED_CORRECTLY = _col('answered_correctly')


# Model

In [None]:
def count_parameters(model):
    """Print a table of trainable parameters and return their total count.

    The table is rendered with plain string formatting; the previous
    implementation relied on ``PrettyTable``, which this notebook never
    imports, so calling the function raised ``NameError``.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. an
            ``nn.Module``).

    Returns:
        int: number of elements across all parameters with
        ``requires_grad=True``.
    """
    # Frozen parameters are skipped, exactly as before.
    rows = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]
    total_params = sum(count for _, count in rows)

    headers = ("Modules", "Parameters")
    name_width = max([len(headers[0])] + [len(name) for name, _ in rows])
    count_width = max([len(headers[1])] + [len(str(count)) for _, count in rows])
    rule = "+-" + "-" * name_width + "-+-" + "-" * count_width + "-+"

    print(rule)
    print(f"| {headers[0]:<{name_width}} | {headers[1]:>{count_width}} |")
    print(rule)
    for name, count in rows:
        print(f"| {name:<{name_width}} | {count:>{count_width}} |")
    print(rule)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [None]:
def generate_mask(size, diagonal=1):
    """Return a (size, size) boolean matrix that is True on and above `diagonal`.

    With the default ``diagonal=1`` only the strictly upper triangle is True.
    """
    everything = torch.ones((size, size), dtype=torch.bool)
    return everything.triu(diagonal=diagonal)

def tasks_mask(tasks, seq_length, diagonal=1):
    """Build a (seq_length, seq_length) boolean mask for one sequence of task ids.

    An entry [i, j] is True when j lies on/above `diagonal` relative to i, or
    when positions i and j carry the same task id. The main diagonal is then
    cleared, so [i, i] is always False.

    Args:
        tasks: 1-D tensor of task container ids (expected length: seq_length).
        seq_length: side length of the square mask.
        diagonal: offset handed to generate_mask().
    """
    device = tasks.device
    upper = generate_mask(seq_length, diagonal=diagonal).to(device)

    # Broadcast the ids along rows and along columns, then compare.
    ones = torch.ones((seq_length, seq_length)).to(device)
    ids_by_column = ones * tasks.reshape(1, -1)
    ids_by_row = ones * tasks.reshape(-1, 1)
    same_task = ids_by_column == ids_by_row

    combined = upper | same_task
    return combined.fill_diagonal_(False)

def tasks_3d_mask(tasks, seq_length=WINDOW_SIZE, nhead=1, diagonal=1):
    """Build the per-sample, per-head attention mask for a batch of task ids.

    Args:
        tasks: iterable of 1-D task-id tensors, one per batch element
            (e.g. a (batch, seq_length) tensor).
        seq_length: side length of every individual mask.
        nhead: number of attention heads the mask is replicated for.
        diagonal: offset forwarded to tasks_mask().

    Returns:
        Boolean tensor of shape (batch * nhead, seq_length, seq_length) in
        which the ``nhead`` copies belonging to one sample are adjacent
        (row ``b * nhead + h`` holds the mask of sample ``b``).
    """
    #https://www.kaggle.com/c/riiid-test-answer-prediction/discussion/206620
    per_sample = torch.stack(
        [tasks_mask(t, seq_length, diagonal=diagonal) for t in tasks],
        dim=0,
    )
    # PyTorch's multi-head attention flattens (batch, head) batch-major, so
    # each sample's mask must be repeated in place. Concatenating whole-batch
    # copies (the previous behaviour) produced head-major order and paired
    # masks with the wrong samples whenever batch size > 1.
    # NOTE(review): checkpoints trained with the old ordering should be
    # re-validated after this change.
    return per_sample.repeat_interleave(nhead, dim=0)

In [None]:
class BartSinusoidalPositionalEmbedding(nn.Embedding):
    """Frozen sinusoidal position table stored as an embedding weight."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx=None):
        # One row more than num_positions is allocated; padding_idx is
        # accepted but not forwarded to nn.Embedding.
        super().__init__(num_positions + 1, embedding_dim)
        self.weight = self._init_weight(self.weight)
        # Index tensor, created on the first forward() call.
        self.positions = None

    @staticmethod
    def _init_weight(out: nn.Parameter):
        """
        Fill `out` in place with sinusoidal features (not interleaved): sines
        occupy the first half of each row, cosines the second half, and the
        whole table is scaled by 0.1. Gradients are disabled on `out`.
        """
        rows, width = out.shape
        angles = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / width) for j in range(width)] for pos in range(rows)]
        )
        out.requires_grad = False  # must precede the in-place writes below
        # First column of the cosine half; odd widths give sines the extra column.
        split = (width + 1) // 2
        out[:, :split] = torch.FloatTensor(np.sin(angles[:, 0::2]))
        out[:, split:] = torch.FloatTensor(np.cos(angles[:, 1::2]))

        # Keep only 10% of the raw amplitude.
        out *= 0.1

        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, start:int=0):
        """Return WINDOW_SIZE consecutive rows of the table beginning at `start`."""
        if self.positions is None:
            # Cached on whichever device the weight lives on at first use.
            self.positions = torch.arange(
                0, WINDOW_SIZE+1,
                dtype=torch.long,
                device=self.weight.device
            )

        table = super().forward(self.positions)
        return table[start:start+WINDOW_SIZE]

In [None]:
class LambdaLayer(nn.Module):
    """Module wrapper that applies a stored callable to its input."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return ``self.lambd(x)``."""
        fn = self.lambd
        return fn(x)

In [None]:
class ThinTransformerEncoderLayer(nn.Module):
    """Encoder layer without self-attention: LayerNorm -> feed-forward -> residual -> LayerNorm.

    ``nhead`` and ``activation`` are accepted so the constructor matches
    ``nn.TransformerEncoderLayer`` but are not used; the activation is
    always ReLU.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(ThinTransformerEncoderLayer, self).__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask = None, src_key_padding_mask = None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required), laid out as
                (seq, batch, feature).
            src_mask: accepted for interface compatibility; unused because
                this layer has no attention.
            src_key_padding_mask: optional (batch, seq) mask; positions where
                it is True are zeroed in the output. ``None`` (the default)
                leaves the output untouched.

        Returns:
            Tensor with the same shape as ``src``.
        """
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)

        # Mask out padded positions. Previously a missing mask crashed on
        # `None.transpose`, even though None is the declared default.
        if src_key_padding_mask is not None:
            padded = src_key_padding_mask.transpose(0, 1).unsqueeze(-1).to(torch.bool)
            src = src.masked_fill(padded, 0)
        return src

In [None]:
class SaintTransformerModel(nn.Module):
    """Encoder/decoder transformer that emits one score per sequence position.

    Three stacks are combined in forward():
      * ``transformer_encoder`` over the summed content-side embeddings
        (content type, part, bundle, exercise, position);
      * ``transformer_history``, a single layer with ``nhead * 2`` heads, over
        the correctness embeddings, masked by task container;
      * ``transformer_decoder`` whose target is the response-side embeddings
        plus the history output and whose memory is the encoder output.

    A final ``nn.Linear(d_model, 1)`` maps every decoder state to a scalar.
    """
    def __init__(
        self,
        d_model:int=32,
        nhead:int=2,
        num_users:int=393657+1,
        num_layers:int=2,
        num_exercises:int=13941+2,         # NOTE: 1+total number of QUESTIONS+LECTURES + Mask at the end
        num_answer_streaks:int=8+1,
        num_content_types:int=2,
        num_parts:int=7+1,                 # TOTAL
        num_bundles:int=9765+2,
        num_prior_question_elapsed_time:int=301+2,
        num_explanable:int=2+2,
        num_max_seq_len:int=WINDOW_SIZE,   
        num_encountered:int=2+1,
        dropout:float=0.1,
        emb_dropout:float=0.2,
        initrange:float = 0.02
    ):
        """Create all embeddings and the three transformer stacks.

        ``num_users``, ``num_answer_streaks`` and ``num_max_seq_len`` are
        accepted but not referenced in this constructor. Each ``num_*``
        argument that is used sets the row count of the matching embedding
        table; ``initrange`` is the standard deviation handed to
        init_weights().
        """
        super(SaintTransformerModel, self).__init__()

        self.d_model = d_model
        self.nhead = nhead
        # Shared dropout applied to each summed embedding group in forward().
        self.emb_dropout = nn.Dropout(p=emb_dropout)
        
#         # These two have double dropout, due to their cardinality:
#         self.user_embeddings = nn.Sequential(
#             nn.Embedding(num_users, 16, padding_idx=0),
#             #nn.Linear(16, d_model, bias=False),
#             #nn.Dropout(p=emb_dropout)
#         )

        # Every categorical input gets its own d_model-wide table; index 0 is the padding index.
        self.exercise_embeddings = nn.Embedding(num_exercises, d_model, padding_idx=0)
        self.bundle_embeddings = nn.Embedding(num_bundles, d_model, padding_idx=0)
        self.content_type_embeddings = nn.Embedding(num_content_types, d_model, padding_idx=0)
        # Fixed sinusoidal table sized for WINDOW_SIZE positions (the padding_idx argument is ignored by that class).
        self.sin_pos_embedder = BartSinusoidalPositionalEmbedding(num_positions=WINDOW_SIZE, embedding_dim=d_model, padding_idx=0) 
        self.part_embeddings = nn.Embedding(num_parts, d_model, padding_idx=0)
        self.prior_question_elapsed_time_embeddings = nn.Embedding(num_prior_question_elapsed_time, d_model, padding_idx=0) 
        self.prior_question_had_explanation_embeddings = nn.Embedding(num_explanable, d_model, padding_idx=0)
        # Table size is hard-coded (4+2) rather than taken from a constructor argument.
        self.correctness_embeddings = nn.Embedding(4+2, d_model, padding_idx=0)
        self.encountered_embeddings = nn.Embedding(num_encountered, d_model, padding_idx=0)

        # Projects the 25 concatenated continuous features (see forward) to d_model.
        self.cont_mlp = nn.Sequential(
            # LN ? BN
            nn.Linear(25, d_model//2),
            nn.Dropout(p=emb_dropout),
            nn.ReLU(inplace=True),
            nn.Linear(d_model//2, d_model, bias=False),
            nn.ReLU(inplace=True),
        )
        
        # Encoder over the content-side embeddings; the feed-forward width equals d_model.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model, nhead, dim_feedforward=d_model,#2048,
                dropout=dropout, activation='relu'
            ),
            num_layers
        )
        
        # Single-layer encoder over the correctness history, with twice as many heads.
        self.transformer_history = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model, nhead*2, dim_feedforward=d_model,#2048,
                dropout=dropout, activation='relu'
            ),
            1 # history only has 1 layer
        )
        
        # Decoder attending over the encoder output.
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, nhead, dim_feedforward=d_model,#2048,
                dropout=dropout, activation='relu'
            ),
            num_layers
        )
        
        # Per-position scalar head.
        self.decoder = nn.Linear(d_model,1)
        self.init_weights(initrange)

    def init_weights(self, initrange:float = 0.02) -> None:
        """Re-draw the embedding tables and the output head weight from N(0, initrange).

        NOTE(review): ``normal_`` overwrites every row, including the
        ``padding_idx`` row that ``nn.Embedding`` zero-initialises, so padding
        rows are non-zero after this call. The decoder bias keeps its default
        initialisation.
        """
        self.content_type_embeddings.weight.data.normal_(0, initrange)
        self.encountered_embeddings.weight.data.normal_(0, initrange)
        self.exercise_embeddings.weight.data.normal_(0, initrange)
        self.bundle_embeddings.weight.data.normal_(0, initrange)
        self.part_embeddings.weight.data.normal_(0, initrange)
        self.prior_question_elapsed_time_embeddings.weight.data.normal_(0, initrange)
        self.prior_question_had_explanation_embeddings.weight.data.normal_(0, initrange)
        self.correctness_embeddings.weight.data.normal_(0, initrange)
        #self.decoder.bias.data.zero_()
        self.decoder.weight.data.normal_(0, initrange)
        
    @torch.cuda.amp.autocast()
    def forward(
        self,
        user_id, content_id, content_type_id, part_id, encountered_id,
        prior_question_elapsed_time, prior_question_had_explanation,
        correctness_id, bundle_id, cont_streaks,
        continuous, session, diagnostic, padding_mask, task_container_id
    ):
        """Score every position of a batch of sequences.

        The id arguments are index tensors looked up in the embedding tables;
        ``cont_streaks``, ``diagnostic``, ``session`` and ``continuous`` are
        concatenated on the last axis and fed to ``cont_mlp`` (which expects
        25 features in total). ``padding_mask`` is passed as key-padding mask
        to all three stacks and ``task_container_id`` drives the history
        mask. ``user_id`` is accepted but unused.

        Returns:
            The ``decoder`` output with its first two axes swapped back.
            # assumes inputs are (batch, seq), giving (batch, seq, 1) - TODO confirm
        """
        # NOTE: batch_size is computed but not used below.
        batch_size = content_id.shape[0]
        seq_len = content_id.shape[-1]
        
        # Causal mask: True strictly above the diagonal.
        future_mask = torch.ones([seq_len, seq_len], device=task_container_id.device, dtype=torch.uint8)
        future_mask = future_mask.triu_(1).view(seq_len, seq_len).bool()
        # Per-sample mask replicated for the nhead*2 heads of transformer_history.
        # It is built for WINDOW_SIZE positions rather than for seq_len.
        history_mask = tasks_3d_mask(task_container_id, seq_length=WINDOW_SIZE, nhead=self.nhead*2)
        
        ##########
        # Content side: what is being asked. Position rows start at offset 1.
        embedded_src = self.emb_dropout(
            self.content_type_embeddings(content_type_id)
            + self.part_embeddings(part_id)
            + self.bundle_embeddings(bundle_id)
            + self.exercise_embeddings(content_id)
            + self.sin_pos_embedder(start=1)
        ).transpose(0, 1) # (S, N, E)
        
        # Response side: prior-question info, encounter flag and the projected
        # continuous features. Position rows start at offset 0 here.
        output_src = self.emb_dropout(
            self.prior_question_elapsed_time_embeddings(prior_question_elapsed_time)
            + self.prior_question_had_explanation_embeddings(prior_question_had_explanation)
            + self.encountered_embeddings(encountered_id)
            + self.cont_mlp(
                torch.cat([
                    cont_streaks,
                    diagnostic,
                    session,
                    continuous,
                ], dim=-1)
            )
            
            + self.sin_pos_embedder(start=0)
        ).transpose(0, 1) # (S, N, E)

        # Correctness history with position rows starting at offset 1.
        history_src = self.emb_dropout(
            self.correctness_embeddings(correctness_id)
            + self.sin_pos_embedder(start=1)
        ).transpose(0, 1)
        
        history = self.transformer_history(
            src=history_src,
            mask=history_mask,
            src_key_padding_mask=padding_mask,
        )
        # Replace NaNs in place, outside of autograd.
        # NOTE(review): presumably these come from rows whose keys are all masked - confirm.
        with torch.no_grad():
            history[torch.isnan(history)] = 0

        memory = self.transformer_encoder(
            src=embedded_src,
            mask=future_mask,
            src_key_padding_mask=padding_mask,
        )

        # The same causal mask is used for decoder self-attention and for
        # attention over the encoder memory.
        output = self.transformer_decoder(
            tgt = output_src + history,
            memory = memory,
            tgt_mask = future_mask,
            memory_mask = future_mask,
            tgt_key_padding_mask = padding_mask,
            memory_key_padding_mask = padding_mask,
        )
        
        return self.decoder(output).transpose(1, 0)

In [None]:
class SiameseNetwork(nn.Module):
    """Ensemble wrapper: runs every member network on the same inputs and averages.

    Each member output is raised to the learnable exponent ``power``
    (initialised to 1) before the mean over members is taken.
    """

    def __init__(self, networks):
        super().__init__()
        self.networks = nn.ModuleList(networks)
        self.power = nn.Parameter(torch.ones(1))

    @torch.cuda.amp.autocast()
    def forward(
        self,
        user_id, content_id, content_type_id, part_id, encountered_id,
        prior_question_elapsed_time, prior_question_had_explanation,
        correctness_id, bundle_id, cont_streaks,
        continuous, session, diagnostic, padding_mask, task_container_id
    ):
        """Forward the identical argument list to every member and combine the results."""
        shared_inputs = (
            user_id, content_id, content_type_id, part_id, encountered_id,
            prior_question_elapsed_time, prior_question_had_explanation,
            correctness_id, bundle_id, cont_streaks,
            continuous, session, diagnostic, padding_mask, task_container_id,
        )
        # Launch all members first, then collect the results in the same order.
        pending = [torch.jit.fork(member, *shared_inputs) for member in self.networks]
        stacked = torch.stack([torch.jit.wait(job) for job in pending])
        return torch.pow(stacked, self.power).mean(dim=0)

# Classes

In [None]:
class UserDF():
    """Per-user rolling history: one plain Python list per feature column.

    ``lensize`` is a one-element list holding a row counter that
    pop_if_needed() reads. All other attributes start as empty lists.
    """

    # Column buffers, in the order they are created and trimmed.
    _COLUMNS = (
        'user_id',
        'task_container_id',

        #############
        # Categorical / id-like inputs:
        'content_id',
        'part_id',
        'prior_question_elapsed_time',
        'prior_question_had_explanation',
        'incorrect_rank', # decoder feature that needs help..........
        'content_type_id',
        'bundle_id',

        # only these of the 4 have signal (0-index)
        'answer_ratio1',
        'answer_ratio2',
        'correct_streak_u',
        'incorrect_streak_u',
        'correct_streak_alltime_u',
        'incorrect_streak_alltime_u',

        # our session information + a lifetime field (elapsed time sum on platform—probably should be capped)
        'session_content_num_u',
        'session_duration_u',
        'session_ans_duration_u',
        'lifetime_ans_duration_u',

        # First block of continuous - timestamp intervals
        'lag_ts_u_recency',
        'last_ts_u_recency1',
        'last_ts_u_recency2',
        'last_ts_u_recency3',
        'last_correct_ts_u_recency',
        'last_incorrect_ts_u_recency',

        # Second block of continuous - how often user is right on global/part/session basis
        'correctness_u_recency',
        'part_correctness_u_recency',
        'session_correctness_u',

        # Have we been served this content before?
        'encountered',

        # Average score of most frequent asked questions
        'diagnostic_u_recency_1',
        'diagnostic_u_recency_2',
        'diagnostic_u_recency_3',
        'diagnostic_u_recency_4',
        'diagnostic_u_recency_5',
        'diagnostic_u_recency_6',

        'answered_correctly',
    )

    def pop_if_needed(self):
        """Drop the oldest row from every column once the counter has reached 128.

        Returns:
            True when the counter is still below 128 (nothing removed),
            False after the first element of every column was popped.
            ``lensize`` itself is not modified here.
        """
        if self.lensize[0] < 128:
            return True

        for column in self._COLUMNS:
            getattr(self, column).pop(0)
        return False

    def __init__(self):
        # TODO: We may need to add some variables to keep track of how many predictions
        # And where in test_df they are
        self.lensize = [0]

        # A fresh, independent list for every column.
        for column in self._COLUMNS:
            setattr(self, column, [])

In [None]:
class UserFeatures():
    """Mutable per-user running state from which the model features are derived.

    Every field is a list (or a bitarray) so that callers can receive a
    reference through get_features() and update it in place.
    """

    def __init__(self):
        # Everything needs to be a list so we can pass references..
        # Idx=0 current ts tcid, Idx=1 start idx of this TCID, Idx=2 counter of items in TCID
        # Idx=0 current ts tcid, Idx=1 counter, Idx=2 is [idx,idx,idx]
        self.cached_task_container_id_u = [np.nan]

        # These two features are only used to compute the lag feature.
        # Index: 0 = current, 1=previous
        self.task_container_num_questions_u = [0,0]  # Only updated on question_id's!
        # Index: 0 = ts, 1=per-bundle1, 1 = per-bundle2,
        self.last_question_ts_u_recency2 = [np.nan,np.nan,np.nan]

        # Idx0=ts, Idx1=CorrectTS, Idx2=CountTS, Idx3=CorrectBundleNoLeak, Idx4=CountBundleBundleNoLeak
        # Idx5=CumAnsTimeSessionNoLeak, Idx6=CumAnsTimePlatformNoLeak
        self.session_u = [np.nan, 2,3, 2,3, 0,0]

        # Index: 0=ts, 1 = per-bundle1, 2 = per-bundle2, 3 = per-bundle3, ...
        self.last_ts_u = [np.nan,np.nan,np.nan,np.nan]

        # Index: 0 = per-ts, 1 = per-bundle
        self.last_correct_ts_u = [0,0] #[np.nan,np.nan]

        # Two slots, presumably laid out like last_correct_ts_u
        # (0 = per-ts, 1 = per-bundle) - TODO confirm.
        self.last_incorrect_ts_u = [0,0] #[np.nan,np.nan]

        # Index: 0,1 = per-ts, 2,3 = per-bundle. 0=correct qs, 1=total qs
        # Prior of 2 correct out of 3 (~66.7%), presumably the mean target.
        self.correctness_u = [2,3,2,3]

        # 7*4, [7parts][0=correct_ts,1=count_ts,2=correct_bundle,3=count_bundle]
        # One independent counter list per part. `7 * [[...]]` would alias a
        # single inner list seven times, so an in-place update for one part
        # would show up in all parts.
        # NOTE(review): instances restored from an existing pickle keep
        # whatever sharing they were saved with.
        self.part_correctness_u = [[2,3,2,3] for _ in range(7)]

        # diagnostic_content_id_order_map
        # diagnostic_content_id_mean_map
        # diagnostic_content_id_mean_map_values
        self.diagnostic_u_1 = [
            # Target Grouping rather than association grouping for now:
            #26  6911  0.818782
            #33  7901  0.824779

            # First two indices are 'update indicators', last two are default values that are set at end of each TCID
            np.nan, np.nan, 0.818782, 0.824779,

            # Last value is default mean
            0.8217805
        ]

        self.diagnostic_u_2 = [
            # 30  7219  0.594133
            # 25  6910  0.602537
            # 4   2066  0.614043
            # 21  6879  0.619330
            # 3   2065  0.621588
            # 1   1279  0.635821
            np.nan, np.nan, np.nan, np.nan, np.nan, np.nan,
            0.594133, 0.602537, 0.614043, 0.619330, 0.621588, 0.635821,

            0.6145753333333334
        ]

        self.diagnostic_u_3 = [
            # 12  3365  0.532004
            # 16  4697  0.545093
            # 5   2594  0.557091
            np.nan, np.nan, np.nan,
            0.532004, 0.545093, 0.557091,

            0.5447293333333334
        ]

        self.diagnostic_u_4 = [
            # 20  6878  0.430956
            # 27  6912  0.431397
            # 15  4493  0.440798
            # 7   2596  0.468018
            # 29  7218  0.485773
            np.nan, np.nan, np.nan, np.nan, np.nan,
            0.430956, 0.431397, 0.440798, 0.468018, 0.485773,

            0.4513884
        ]

        self.diagnostic_u_5 = [
            # 24  6909  0.390869
            # 32  7877  0.393220
            # 22  6880  0.404521
            # 28  7217  0.405293
            # 6   2595  0.419734
            np.nan, np.nan, np.nan, np.nan, np.nan,
            0.390869, 0.393220, 0.404521, 0.405293, 0.419734,
            0.4027274
        ]

        self.diagnostic_u_6 = [
            # 10  2949  0.211211
            # 9   2948  0.226543
            # 14  4121  0.230986
            # 23  6881  0.252001
            # 18  6174  0.255422
            # 31  7220  0.262637
            # 11  3364  0.268132
            np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan,
            0.211211, 0.226543, 0.230986, 0.252001, 0.255422, 0.262637, 0.268132,
            0.24384742857142858,
        ]

        #user_answer_streak_map
        self.answer_ratio = [
            # Idx0 = last user answer, Idx1 = streak count
            # We don't have to limit to bundles since we have access to this at read time
            1,1,  # UA=1 counter, UA=1 bundle
            1,1,  # UA=2 counter, UA=2 bundle
            4,4,  # total q asked
        ]

        self.streak_u = [
            0,0,0, # correct ts, bundle, all-time bundle
            0,0,0, # incorrect ts, bundle, all-time bundle
        ]

        # We don't need to update @ Bundle because we'll only see a question once in a TCID
        # One bit per content id, all cleared initially.
        self.content_encounter = bitarray(13550+425, endian='little')
        self.content_encounter.setall(0)

    def get_features(self):
        """Return references to all state containers as one 18-element tuple.

        The objects are returned as-is (not copied), so in-place changes made
        by the caller are visible on this instance.
        """
        return (
            self.cached_task_container_id_u,
            self.task_container_num_questions_u,
            self.last_question_ts_u_recency2,
            self.session_u,
            self.last_ts_u,
            self.last_correct_ts_u,
            self.last_incorrect_ts_u,
            self.correctness_u,
            self.part_correctness_u,
            self.diagnostic_u_1,
            self.diagnostic_u_2,
            self.diagnostic_u_3,
            self.diagnostic_u_4,
            self.diagnostic_u_5,
            self.diagnostic_u_6,
            self.answer_ratio,
            self.streak_u,
            self.content_encounter,
        )

In [None]:
# For pickling
def newUserFeatures():
    """Module-level, zero-argument factory returning a fresh UserFeatures."""
    fresh = UserFeatures()
    return fresh
def newUserDF():
    """Module-level, zero-argument factory returning a fresh UserDF."""
    fresh = UserDF()
    return fresh

# Load up all mappings

In [None]:
# Unpack the id mappings and the per-feature (mean, std) pairs stored in one joblib file.
# NOTE(review): joblib.load unpickles arbitrary objects - only load files from a trusted source.
# NOTE: embedded_user_ids_map and question_incorrect_ranks unpacked here are
# overwritten by the dedicated loads further down.
(
    embedded_user_ids_map, mask_user_ids, part_ids_map, lecture_ids_map, bundle_id_map,
    question_incorrect_ranks,
    answer_ratio1_mean, answer_ratio1_std,
    answer_ratio2_mean, answer_ratio2_std,
    correct_streak_u_mean, correct_streak_u_std,
    incorrect_streak_u_mean, incorrect_streak_u_std,
    correct_streak_alltime_u_mean, correct_streak_alltime_u_std,
    incorrect_streak_alltime_u_mean, incorrect_streak_alltime_u_std,
    diagnostic_u_recency_1_mean, diagnostic_u_recency_1_std,
    diagnostic_u_recency_2_mean, diagnostic_u_recency_2_std,
    diagnostic_u_recency_3_mean, diagnostic_u_recency_3_std,
    diagnostic_u_recency_4_mean, diagnostic_u_recency_4_std,
    diagnostic_u_recency_5_mean, diagnostic_u_recency_5_std,
    diagnostic_u_recency_6_mean, diagnostic_u_recency_6_std,
    session_content_num_u_mean, session_content_num_u_std,
    session_duration_u_mean, session_duration_u_std,
    session_correctness_u_mean, session_correctness_u_std,
    session_ans_duration_u_mean, session_ans_duration_u_std,
    lifetime_ans_duration_u_mean, lifetime_ans_duration_u_std,
    correctness_u_recency_mean, correctness_u_recency_std,
    part_correctness_u_recency_mean, part_correctness_u_recency_std,
    lag_ts_u_recency_mean, lag_ts_u_recency_std,
    last_ts_u_recency1_mean, last_ts_u_recency1_std,
    last_ts_u_recency2_mean, last_ts_u_recency2_std,
    last_ts_u_recency3_mean, last_ts_u_recency3_std,
    last_incorrect_ts_u_recency_mean, last_incorrect_ts_u_recency_std,
    last_correct_ts_u_recency_mean, last_correct_ts_u_recency_std,
) = joblib.load('../input/uthman-riiid/mappings.jlib')

# Build this out:
# Cached per-user feature state.
# NOTE(review): presumably keyed by user id and holding UserFeatures objects - confirm.
user_features = joblib.load('../input/uthman-riiid/user_features_full.jlib')

# Also load up the cached last 128 records of each user....
user_dfs = joblib.load('../input/uthman-riiid/user_dfs_full.jlib')

# NOTE: We aren't embedding user_ids, even though I think it'd be a great idea...
# TODO: This needs to include validation data.....
# NOTE: User IDs are already mapped.... but we add 1 so that we can use 0 for unknown users
embedded_user_ids_map = joblib.load('../input/uthman-riiid/mapped_embedded_user_ids.jlib')
embedded_user_ids_map = {i:v+1 for i,v in embedded_user_ids_map.items()}

# Replaces the question_incorrect_ranks value unpacked from mappings.jlib above.
question_incorrect_ranks = joblib.load(f'../input/uthman-riiid/question_incorrect_ranks.pkl')

# add_user_feats_without_update

In [None]:
def add_user_feats_without_update(df, user_features):
    """Append per-row user-history features to ``df`` without consuming labels.

    For every row the cached state of the row's user is fetched with
    ``user_features[user_id].get_features()``. When the row opens a new
    ``task_container_id`` the per-timestamp slots of that state are rolled into
    the per-bundle slots *in place* (so the state objects ARE mutated by this
    alignment step); the features are then read from the per-bundle slots.
    Answer-dependent counters are not touched here -- see ``update_user_feats``.

    Args:
        df: DataFrame that provides at least the columns ``user_id``,
            ``content_id``, ``task_container_id``, ``content_type_id``,
            ``part_id``, ``timestamp`` and ``prior_question_elapsed_time_cont``.
        user_features: mapping indexed by ``user_id`` whose values expose
            ``get_features()`` returning the 18 mutable state arrays unpacked
            below (presumably a defaultdict of state objects -- TODO confirm).

    Returns:
        ``df`` concatenated column-wise with the log-scaled, standardised
        feature columns, aligned on ``df``'s own index.
    """
    n_rows = df.shape[0]

    # One output buffer per feature; dtypes are kept exactly as before.
    session_content_num_u       = np.zeros(n_rows, dtype=np.float32)
    session_duration_u          = np.zeros(n_rows, dtype=np.float64)
    session_correctness_u       = np.zeros(n_rows, dtype=np.float32)
    session_ans_duration_u      = np.zeros(n_rows, dtype=np.float64)
    lifetime_ans_duration_u     = np.zeros(n_rows, dtype=np.float64)
    lag_ts_u_recency            = np.zeros(n_rows, dtype=np.float64)
    last_ts_u_recency1          = np.zeros(n_rows, dtype=np.float64)
    last_ts_u_recency2          = np.zeros(n_rows, dtype=np.float64)
    last_ts_u_recency3          = np.zeros(n_rows, dtype=np.float64)
    last_incorrect_ts_u_recency = np.zeros(n_rows, dtype=np.float64)
    last_correct_ts_u_recency   = np.zeros(n_rows, dtype=np.float64)
    correctness_u_recency       = np.zeros(n_rows, dtype=np.float32)
    part_correctness_u_recency  = np.zeros(n_rows, dtype=np.float32)
    diagnostic_u_recency        = np.zeros((n_rows, 6), dtype=np.float32)
    answer_ratio1               = np.zeros(n_rows, dtype=np.float32)
    answer_ratio2               = np.zeros(n_rows, dtype=np.float32)
    correct_streak_u            = np.zeros(n_rows, dtype=np.uint16)
    incorrect_streak_u          = np.zeros(n_rows, dtype=np.uint16)
    correct_streak_alltime_u    = np.zeros(n_rows, dtype=np.uint16)
    incorrect_streak_alltime_u  = np.zeros(n_rows, dtype=np.uint16)
    encountered                 = np.zeros(n_rows, dtype=np.uint8)

    # Ideally, we can pull less
    feature_cols = [
        'user_id', 'content_id', 'task_container_id', 'content_type_id',
        'part_id', 'timestamp', 'prior_question_elapsed_time_cont',
    ]
    for idx, (
        _, user_id, content_id, task_container_id, content_type_id,
        part_id, timestamp, prior_question_elapsed_time_cont,
    ) in enumerate(tqdm(df[feature_cols].itertuples())):
        # Adjustments
        part_id -= 1  # So we can index at 0
        (
            cached_task_container_id_u,
            task_container_num_questions_u,
            last_question_ts_u_recency2,
            session_u,
            last_ts_u,
            last_correct_ts_u,
            last_incorrect_ts_u,
            correctness_u,
            part_correctness_u,
            diagnostic_u_1,
            diagnostic_u_2,
            diagnostic_u_3,
            diagnostic_u_4,
            diagnostic_u_5,
            diagnostic_u_6,
            answer_ratio,
            streak_u,
            content_encounter,
        ) = user_features[user_id].get_features()
        part_correctness_u__part = part_correctness_u[part_id]

        # (array, width): slots [0, width) hold the current bundle's answers,
        # slots [width, -1) the rolled-over answers, and slot -1 their mean.
        diagnostics = (
            (diagnostic_u_1, 2),
            (diagnostic_u_2, 6),
            (diagnostic_u_3, 3),
            (diagnostic_u_4, 5),
            (diagnostic_u_5, 5),
            (diagnostic_u_6, 7),
        )

        # Step 1) Bundle Alignment operation - in kernel submission, this will run each time
        if cached_task_container_id_u[0] != task_container_id:
            # Initialize:
            cached_task_container_id_u[0] = task_container_id

            # Index: 0 = per-ts, 1 = per-bundle
            last_correct_ts_u[1]   = last_correct_ts_u[0]
            last_incorrect_ts_u[1] = last_incorrect_ts_u[0]

            # Index: 0,1 = per-ts, 2,3 = per-bundle. 0=correct qs, 1=total qs
            correctness_u[2] = correctness_u[0]
            correctness_u[3] = correctness_u[1]

            # Dedicated loop variable: the original reused `part_id` here and
            # clobbered the row's own part index.
            for p in range(7):
                # 7*4, [7parts][0=correct_ts,1=count_ts,2=correct_bundle,3=count_bundle]
                part_correctness_u[p][2] = part_correctness_u[p][0]
                part_correctness_u[p][3] = part_correctness_u[p][1]

            # Index: 0=current-ts, 1 = per-bundle1, 2 = per-bundle2, 3 = per-bundle3
            # Order is important
            last_ts_u[3] = last_ts_u[2]
            last_ts_u[2] = last_ts_u[1]
            last_ts_u[1] = last_ts_u[0]
            last_ts_u[0] = timestamp

            # Streak
            streak_u[1] = streak_u[0]  # Update bundle
            streak_u[4] = streak_u[3]  # Update bundle
            if streak_u[2] < streak_u[1]: streak_u[2] = streak_u[1]  # Update all-time
            if streak_u[5] < streak_u[4]: streak_u[5] = streak_u[4]  # Update all-time

            # Answer Ratios
            answer_ratio[1] = answer_ratio[0]
            answer_ratio[3] = answer_ratio[2]
            answer_ratio[5] = answer_ratio[4]

            if task_container_id < 40:
                # Diagnostic Updates: move every answered slot into its
                # rolled-over position, clear it, then refresh the mean.
                for diag, width in diagnostics:
                    for i in range(width):
                        if np.isnan(diag[i]): continue
                        diag[i + width] = diag[i]
                        diag[i] = np.nan
                    # Plain mean on purpose: stays NaN until every slot is filled.
                    diag[-1] = np.mean(diag[width:-1])
                # End diagnostic features

            # Lag features
            if content_type_id == 0:
                # We only reset for questions; lectures will ffill()
                task_container_num_questions_u[1] = task_container_num_questions_u[0]
                task_container_num_questions_u[0] = 0

                # Same calculation as last_ts_u_recency2
                last_question_ts_u_recency2[2] = last_question_ts_u_recency2[1]
                last_question_ts_u_recency2[1] = last_question_ts_u_recency2[0]
                last_question_ts_u_recency2[0] = timestamp

                if not np.isnan(prior_question_elapsed_time_cont):
                    answered_ms = prior_question_elapsed_time_cont * task_container_num_questions_u[1]
                    session_u[5] += answered_ms
                    session_u[6] += answered_ms
            else:
                # Entering a lecture:
                diff = timestamp - last_question_ts_u_recency2[0]
                last_question_ts_u_recency2[2] += diff
                last_question_ts_u_recency2[1] += diff
                last_question_ts_u_recency2[0] = timestamp
                # End lag features

            # Copy the session stuff over
            session_u[3] = session_u[1]
            session_u[4] = session_u[2]

        # Bake values - note, none of these change per task_container_id,
        # even though we're recalculating them all every iteration......
        lag_ts_u_recency[idx] = (
            last_question_ts_u_recency2[1] - last_question_ts_u_recency2[2]
            - prior_question_elapsed_time_cont * task_container_num_questions_u[1]
        )
        last_ts_u_recency1[idx]          = timestamp - last_ts_u[1]
        last_ts_u_recency2[idx]          = last_ts_u[1] - last_ts_u[2]
        last_ts_u_recency3[idx]          = last_ts_u[2] - last_ts_u[3]
        last_correct_ts_u_recency[idx]   = timestamp - last_correct_ts_u[1]
        last_incorrect_ts_u_recency[idx] = timestamp - last_incorrect_ts_u[1]
        correctness_u_recency[idx]       = correctness_u[2] / correctness_u[3]
        part_correctness_u_recency[idx]  = part_correctness_u__part[2] / part_correctness_u__part[3]
        # Original comment read "we start at 6...."; update_user_feats seeds the
        # session count at 3 -- TODO confirm which offset is intended.
        session_content_num_u[idx]       = session_u[4] - 2
        session_duration_u[idx]          = 0 if np.isnan(session_u[0]) else timestamp - session_u[0]
        session_correctness_u[idx]       = session_u[3] / session_u[4]
        session_ans_duration_u[idx]      = session_u[5]
        lifetime_ans_duration_u[idx]     = session_u[6]
        for k, (diag, _) in enumerate(diagnostics):
            diagnostic_u_recency[idx, k] = diag[-1]
        answer_ratio1[idx]               = answer_ratio[1] / answer_ratio[5]
        answer_ratio2[idx]               = answer_ratio[3] / answer_ratio[5]
        correct_streak_u[idx]            = streak_u[1]
        incorrect_streak_u[idx]          = streak_u[4]
        correct_streak_alltime_u[idx]    = streak_u[2]
        incorrect_streak_alltime_u[idx]  = streak_u[5]
        # 0 is reserved for padding: 1 = first encounter, 2 = seen before.
        # Read-only here; the flag itself is set in update_user_feats.
        encountered[idx]                 = 2 if content_encounter[content_id] == 1 else 1

    # Build on df's index so the column-wise concat below aligns row-for-row
    # even when df does not carry a default RangeIndex.
    user_feats_df = pd.DataFrame({
        'lag_ts_u_recency': np.log1p(lag_ts_u_recency.clip(0,86400000) / 1000),
        'last_ts_u_recency1': np.log1p(last_ts_u_recency1.clip(650,1209600000)),
        'last_ts_u_recency2': np.log1p(last_ts_u_recency2.clip(650,1209600000)),
        'last_ts_u_recency3': np.log1p(last_ts_u_recency3.clip(650,1209600000)),
        'last_correct_ts_u_recency': np.log1p(last_correct_ts_u_recency.clip(650,1209600000)),
        'last_incorrect_ts_u_recency': np.log1p(last_incorrect_ts_u_recency.clip(650,1209600000)),
        'correctness_u_recency': correctness_u_recency,
        'part_correctness_u_recency': part_correctness_u_recency,
        'session_content_num_u': np.log1p(session_content_num_u),
        'session_duration_u': np.log1p(session_duration_u.clip(4100,18000000)),
        'session_correctness_u': session_correctness_u,
        'session_ans_duration_u': np.log1p(session_ans_duration_u.clip(4100,18000000)),
        'lifetime_ans_duration_u': np.log1p(lifetime_ans_duration_u.clip(4100,None)),
        'diagnostic_u_recency_1': diagnostic_u_recency[:, 0],
        'diagnostic_u_recency_2': diagnostic_u_recency[:, 1],
        'diagnostic_u_recency_3': diagnostic_u_recency[:, 2],
        'diagnostic_u_recency_4': diagnostic_u_recency[:, 3],
        'diagnostic_u_recency_5': diagnostic_u_recency[:, 4],
        'diagnostic_u_recency_6': diagnostic_u_recency[:, 5],
        'answer_ratio1': answer_ratio1,
        'answer_ratio2': answer_ratio2,
        'correct_streak_u': np.log1p(correct_streak_u),
        'incorrect_streak_u': np.log1p(incorrect_streak_u),
        'correct_streak_alltime_u': np.log1p(correct_streak_alltime_u),
        'incorrect_streak_alltime_u': np.log1p(incorrect_streak_alltime_u),
        'encountered': encountered,
    }, index=df.index)

    # NOTE(review): the correctness columns are NaN-filled with the raw dataset
    # mean *after* standardisation (0 would be the standardised mean). Kept
    # as-is because it presumably matches the training pipeline -- confirm.
    mean_correct = 0.653417715747257  # mean answered correctly across dset

    # (column, mean, std, NaN fill value); `encountered` is categorical and is
    # deliberately left out.
    standardise = (
        ('lag_ts_u_recency',            lag_ts_u_recency_mean,            lag_ts_u_recency_std,            0),
        ('last_ts_u_recency1',          last_ts_u_recency1_mean,          last_ts_u_recency1_std,          0),
        ('last_ts_u_recency2',          last_ts_u_recency2_mean,          last_ts_u_recency2_std,          0),
        ('last_ts_u_recency3',          last_ts_u_recency3_mean,          last_ts_u_recency3_std,          0),
        ('last_incorrect_ts_u_recency', last_incorrect_ts_u_recency_mean, last_incorrect_ts_u_recency_std, 0),
        ('last_correct_ts_u_recency',   last_correct_ts_u_recency_mean,   last_correct_ts_u_recency_std,   0),
        ('correctness_u_recency',       correctness_u_recency_mean,       correctness_u_recency_std,       mean_correct),
        ('part_correctness_u_recency',  part_correctness_u_recency_mean,  part_correctness_u_recency_std,  mean_correct),
        ('session_content_num_u',       session_content_num_u_mean,       session_content_num_u_std,       0),
        ('session_duration_u',          session_duration_u_mean,          session_duration_u_std,          0),
        ('session_correctness_u',       session_correctness_u_mean,       session_correctness_u_std,       mean_correct),
        ('session_ans_duration_u',      session_ans_duration_u_mean,      session_ans_duration_u_std,      0),
        ('lifetime_ans_duration_u',     lifetime_ans_duration_u_mean,     lifetime_ans_duration_u_std,     0),
        ('diagnostic_u_recency_1',      diagnostic_u_recency_1_mean,      diagnostic_u_recency_1_std,      0),
        ('diagnostic_u_recency_2',      diagnostic_u_recency_2_mean,      diagnostic_u_recency_2_std,      0),
        ('diagnostic_u_recency_3',      diagnostic_u_recency_3_mean,      diagnostic_u_recency_3_std,      0),
        ('diagnostic_u_recency_4',      diagnostic_u_recency_4_mean,      diagnostic_u_recency_4_std,      0),
        ('diagnostic_u_recency_5',      diagnostic_u_recency_5_mean,      diagnostic_u_recency_5_std,      0),
        ('diagnostic_u_recency_6',      diagnostic_u_recency_6_mean,      diagnostic_u_recency_6_std,      0),
        ('answer_ratio1',               answer_ratio1_mean,               answer_ratio1_std,               0),
        ('answer_ratio2',               answer_ratio2_mean,               answer_ratio2_std,               0),
        ('correct_streak_u',            correct_streak_u_mean,            correct_streak_u_std,            0),
        ('incorrect_streak_u',          incorrect_streak_u_mean,          incorrect_streak_u_std,          0),
        ('correct_streak_alltime_u',    correct_streak_alltime_u_mean,    correct_streak_alltime_u_std,    0),
        ('incorrect_streak_alltime_u',  incorrect_streak_alltime_u_mean,  incorrect_streak_alltime_u_std,  0),
    )
    for col, mean, std, nan_fill in standardise:
        user_feats_df[col] = ((user_feats_df[col] - mean) / std).astype(np.float32)
        user_feats_df.loc[user_feats_df[col].isna(), col] = nan_fill

    # TODO: Let's do this outside?
    df = pd.concat([df, user_feats_df], axis=1)
    return df

# update_user_feats

In [None]:
# content_id -> (diagnostic group index 0..5, slot inside that group) for the
# questions tracked by the diagnostic features.
_DIAGNOSTIC_SLOTS = {
    6911: (0, 0), 7901: (0, 1),
    7219: (1, 0), 6910: (1, 1), 2066: (1, 2), 6879: (1, 3), 2065: (1, 4), 1279: (1, 5),
    3365: (2, 0), 4697: (2, 1), 2594: (2, 2),
    6878: (3, 0), 6912: (3, 1), 4493: (3, 2), 2596: (3, 3), 7218: (3, 4),
    6909: (4, 0), 7877: (4, 1), 6880: (4, 2), 7217: (4, 3), 2595: (4, 4),
    2949: (5, 0), 2948: (5, 1), 4121: (5, 2), 6881: (5, 3), 6174: (5, 4),
    7220: (5, 5), 3364: (5, 6),
}


def update_user_feats(df, user_features):
    """Fold answered rows into the cached per-user state, in place.

    Only the per-timestamp slots are written here; they are rolled into the
    per-bundle slots by ``add_user_feats_without_update`` when the user's next
    task container shows up.

    Args:
        df: DataFrame that provides at least the columns ``user_id``,
            ``content_id``, ``task_container_id``, ``content_type_id``,
            ``part_id``, ``timestamp``, ``answered_correctly`` and
            ``user_answer``.
        user_features: mapping indexed by ``user_id`` whose values expose
            ``get_features()`` returning the 18 mutable state arrays unpacked
            below.

    Returns:
        None. All effects are mutations of the arrays held by ``user_features``.
    """
    # Ideally, we can pull less
    label_cols = [
        'user_id', 'content_id', 'task_container_id', 'content_type_id',
        'part_id', 'timestamp', 'answered_correctly', 'user_answer',
    ]
    for (
        _, user_id, content_id, task_container_id, content_type_id,
        part_id, timestamp, answered_correctly, user_answer,
    ) in tqdm(df[label_cols].itertuples()):
        # Adjustments
        part_id -= 1  # So we can index at 0
        (
            cached_task_container_id_u,
            task_container_num_questions_u,
            last_question_ts_u_recency2,
            session_u,
            last_ts_u,
            last_correct_ts_u,
            last_incorrect_ts_u,
            correctness_u,
            part_correctness_u,
            diagnostic_u_1,
            diagnostic_u_2,
            diagnostic_u_3,
            diagnostic_u_4,
            diagnostic_u_5,
            diagnostic_u_6,
            answer_ratio,
            streak_u,
            content_encounter,
        ) = user_features[user_id].get_features()
        part_correctness_u__part = part_correctness_u[part_id]

        # Step 0) Handle session stuff: a gap of more than 30 minutes (or no
        # previous timestamp at all) starts a new session.
        # NOTE(review): add_user_feats_without_update writes last_ts_u[0] =
        # timestamp during bundle alignment; if it has already run for these
        # rows the gap seen here is 0 -- confirm the intended call order.
        if np.isnan(last_ts_u[0]) or timestamp - last_ts_u[0] > 3600000 // 2:
            # Reset the session
            session_u[0] = timestamp
            session_u[1] = 2 # correct
            session_u[2] = 3 # count
            session_u[3] = 2 # correct
            session_u[4] = 3 # count
            session_u[5] = 0 # CumAnsTimeSessionNoLeak

        # Remember that this content has now been served to the user. The
        # original also wrote into an `encountered` array that does not exist
        # in this scope (NameError); the per-row feature is produced by
        # add_user_feats_without_update, so only the flag is maintained here.
        content_encounter[content_id] = 1

        if content_type_id == 0:
            # Lag: Increase for each question we see
            task_container_num_questions_u[0] += 1
            answer_ratio[5] += 1

            # Answer Selection Ratio for questions only
            if user_answer == 1:
                answer_ratio[0] += 1
            elif user_answer == 2:
                answer_ratio[2] += 1

        # Content Counters
        # NOTE(review): these run for every row, and any answered_correctly
        # other than 0 takes the "correct" branch below -- confirm that this is
        # intended for non-question rows.
        correctness_u[1] += 1
        part_correctness_u__part[1] += 1
        session_u[2] += 1

        if answered_correctly == 0:
            last_incorrect_ts_u[0] = timestamp
            streak_u[0] = 0  # reset correct counter
            streak_u[3] += 1 # increment incorrect counter
        else:
            last_correct_ts_u[0] = timestamp
            streak_u[0] += 1 # increment correct counter
            streak_u[3] = 0  # reset incorrect counter

            # Correct Content Counters
            correctness_u[0] += 1
            part_correctness_u__part[0] += 1
            session_u[1] += 1

        # Diagnostic Features
        if task_container_id < 40:
            slot = _DIAGNOSTIC_SLOTS.get(content_id)
            if slot is not None:
                group, position = slot
                diagnostics = (
                    diagnostic_u_1, diagnostic_u_2, diagnostic_u_3,
                    diagnostic_u_4, diagnostic_u_5, diagnostic_u_6,
                )
                diagnostics[group][position] = answered_correctly

# update_user_df

In [None]:
# Per-user history attributes, in the same order as the columns selected by
# dataset_cols[COL_OFFSET_CUT:]; rows are matched to them by position.
_USER_DF_FIELDS = (
    'user_id', 'task_container_id', 'content_id', 'part_id', 'prior_question_elapsed_time',
    'prior_question_had_explanation',
    'incorrect_rank',  # decoder feature that needs help..........
    'content_type_id', 'bundle_id',
    # only these of the 4 have signal (0-index)
    'answer_ratio1', 'answer_ratio2',
    'correct_streak_u', 'incorrect_streak_u', 'correct_streak_alltime_u', 'incorrect_streak_alltime_u',
    # our session information + a lifetime field (elapsed time sum on platform—probably should be capped)
    'session_content_num_u', 'session_duration_u', 'session_ans_duration_u', 'lifetime_ans_duration_u',
    # First block of continuous - timestamp intervals
    'lag_ts_u_recency', 'last_ts_u_recency1', 'last_ts_u_recency2', 'last_ts_u_recency3',
    'last_correct_ts_u_recency', 'last_incorrect_ts_u_recency',
    # Second block of continuous - how often user is right on global/part/session basis
    'correctness_u_recency', 'part_correctness_u_recency', 'session_correctness_u',
    # Have we been served this content before?
    'encountered',
    # Average score of most frequent asked questions
    'diagnostic_u_recency_1', 'diagnostic_u_recency_2', 'diagnostic_u_recency_3',
    'diagnostic_u_recency_4', 'diagnostic_u_recency_5', 'diagnostic_u_recency_6',
    'answered_correctly',
)


def update_user_df(df, user_dfs):
    """Append each row of ``df`` to its user's history and collect the batch.

    Args:
        df: DataFrame providing the columns ``dataset_cols[COL_OFFSET_CUT:]``;
            the first of them is used as the user id.
        user_dfs: mapping indexed by user id whose values hold one list-like
            history attribute per name in ``_USER_DF_FIELDS`` plus
            ``pop_if_needed()`` and ``lensize``.

    Returns:
        A tuple with one list per field in ``_USER_DF_FIELDS``; element ``i`` of
        each list is the history of that field for the user of row ``i``.

    Raises:
        ValueError: if the selected columns do not match ``_USER_DF_FIELDS``
            one-to-one.
    """
    cols = dataset_cols[COL_OFFSET_CUT:]
    if len(cols) != len(_USER_DF_FIELDS):
        raise ValueError(
            f'expected {len(_USER_DF_FIELDS)} dataset columns, got {len(cols)}'
        )

    # We build the batch ourselves: one output list per field.
    batch = tuple([] for _ in _USER_DF_FIELDS)

    for row in tqdm(df[cols].itertuples()):
        # TODO: We may have to cache some variables to record:
        # 1) How many items we're predicting in for this users TCID
        # 2) Where in test_df these values should map to (what line = idx??)
        values = row[1:]  # row[0] is the DataFrame index

        # Find or create the user df:
        state = user_dfs[values[0]]
        # NOTE(review): lensize is bumped only when pop_if_needed() reports
        # True; kept exactly as written -- confirm against the state class.
        if state.pop_if_needed():
            state.lensize[0] += 1

        # Record the row in the user's history, then add that history to our
        # BATCH, which will be fed into a data loader to do padding.
        # NOTE(review): the history object itself is appended, not a copy, so
        # every row of the same user shares one list that also reflects that
        # user's later rows.
        for field, value, collected in zip(_USER_DF_FIELDS, values, batch):
            history = getattr(state, field)
            history.append(value)
            collected.append(history)

    return batch

# Dataset

In [None]:
class Riiid(torch.utils.data.Dataset):
    """Inference dataset yielding one right-padded feature sequence per sample.

    ``update_user_df(test_df, user_dfs)`` returns 36 parallel lists; each is
    stored as an attribute and indexed with the sample index in
    ``__getitem__``. Every indexed entry is concatenated with a padding list,
    so entries are presumably plain Python lists of per-step values -- TODO
    confirm against ``update_user_df``.

    Relies on the module-level names ``update_user_df``, ``user_dfs``,
    ``WINDOW_SIZE`` and ``PAD``.
    """

    # Attribute names stacked (in this order) into the float feature matrix
    # that `collate_fn` receives as `fe_stuff`.
    _FE_COLUMNS = (
        # cont_streaks:
        'answer_ratio1',
        'answer_ratio2',
        'correct_streak_u',
        'incorrect_streak_u',
        'correct_streak_alltime_u',
        'incorrect_streak_alltime_u',

        # diagnostic:
        'diagnostic_u_recency_1',
        'diagnostic_u_recency_2',
        'diagnostic_u_recency_3',
        'diagnostic_u_recency_4',
        'diagnostic_u_recency_5',
        'diagnostic_u_recency_6',

        # session:
        'session_content_num_u',
        'session_duration_u',
        'session_ans_duration_u',
        'lifetime_ans_duration_u',

        # continuous:
        'lag_ts_u_recency',
        'last_ts_u_recency1',
        'last_ts_u_recency2',
        'last_ts_u_recency3',
        'last_correct_ts_u_recency',
        'last_incorrect_ts_u_recency',
        'correctness_u_recency',
        'part_correctness_u_recency',
        'session_correctness_u',
    )

    def __init__(self, test_df):
        """Build all per-sample sequences for ``test_df`` via ``update_user_df``."""
        # `user_dfs` is only read here, so no `global` declaration is needed.
        (
            self.user_id, self.task_container_id, self.content_id, self.part_id, self.prior_question_elapsed_time,
            self.prior_question_had_explanation, self.incorrect_rank, self.content_type_id, self.bundle_id,
            self.answer_ratio1, self.answer_ratio2, self.correct_streak_u, self.incorrect_streak_u, self.correct_streak_alltime_u,
            self.incorrect_streak_alltime_u, self.session_content_num_u, self.session_duration_u, self.session_ans_duration_u,
            self.lifetime_ans_duration_u, self.lag_ts_u_recency, self.last_ts_u_recency1, self.last_ts_u_recency2, self.last_ts_u_recency3,
            self.last_correct_ts_u_recency, self.last_incorrect_ts_u_recency, self.correctness_u_recency,
            self.part_correctness_u_recency, self.session_correctness_u, self.encountered, self.diagnostic_u_recency_1,
            self.diagnostic_u_recency_2, self.diagnostic_u_recency_3, self.diagnostic_u_recency_4, self.diagnostic_u_recency_5,
            self.diagnostic_u_recency_6, self.answered_correctly
        ) = update_user_df(test_df, user_dfs)

    def __len__(self):
        """Return the number of samples (entries in ``self.user_id``)."""
        return len(self.user_id)

    def __getitem__(self, idx):
        """Return sample ``idx`` as the 12-tuple that ``collate_fn`` unpacks.

        Order: user_id, content_id, content_type_id, part_id, encountered,
        prior_question_elapsed_time, prior_question_had_explanation,
        shifted incorrect_rank, bundle_id, float feature matrix of shape
        ``(WINDOW_SIZE, 25)``, padding mask (1 marks a padded step),
        task_container_id.
        """
        row_len = len(self.user_id[idx])
        # NOTE(review): assumes row_len <= WINDOW_SIZE. A longer history gives
        # an empty padding list and an over-long sample -- confirm that
        # `update_user_df` truncates.
        pad_len = WINDOW_SIZE - row_len
        padding = [PAD] * pad_len

        def padded_ints(column):
            # Select this sample's sequence, right-pad it, cast to int.
            return np.array(column[idx] + padding).astype(int)

        # Every feature row must be this sample's own sequence: index with
        # `idx` before padding (the unindexed attribute is the list of ALL
        # samples, which cannot be stacked into a rectangular matrix).
        fe_stuff = np.vstack(
            [getattr(self, name)[idx] + padding for name in self._FE_COLUMNS]
        ).astype(np.float32).T

        return (
            self.user_id[idx] + padding,
            padded_ints(self.content_id),
            padded_ints(self.content_type_id),
            padded_ints(self.part_id),
            padded_ints(self.encountered),
            padded_ints(self.prior_question_elapsed_time),
            padded_ints(self.prior_question_had_explanation),

            # The incorrect rank of a step is only known after its prediction,
            # so the sequence is shifted right by one: a leading 1, then this
            # sample's ranks minus the last one.
            # TODO: the feature-creation process does not store this value yet;
            # confirm that the leading 1 is the right filler.
            np.array([1] + self.incorrect_rank[idx][:-1] + padding).astype(int),

            padded_ints(self.bundle_id),

            fe_stuff,

            # padding mask:
            np.array([0] * row_len + [1] * pad_len).astype(np.uint8),

            # task container id:
            padded_ints(self.task_container_id),
        )

In [None]:
def collate_fn(batch):
    """Turn a list of 12-field samples into 12 batched tensors.

    The fields are transposed into per-field columns. The float feature
    column becomes a FloatTensor, the padding mask a BoolTensor, and every
    other column a LongTensor. The output keeps the input field order.
    """
    (
        user_id, content_id, content_type_id, part_id, encountered_id,
        prior_question_elapsed_time, prior_question_had_explanation,
        correctness_id, bundle_id, fe_stuff, padding_mask, task_container_id
    ) = zip(*batch)

    # The nine leading columns are all integer ids / categories.
    leading_long = [
        torch.LongTensor(column)
        for column in (
            user_id, content_id, content_type_id, part_id, encountered_id,
            prior_question_elapsed_time, prior_question_had_explanation,
            correctness_id, bundle_id,
        )
    ]

    features = torch.FloatTensor(fe_stuff)
    mask = torch.BoolTensor(padding_mask)
    containers = torch.LongTensor(task_container_id)

    # remember the order
    return (*leading_long, features, mask, containers)

# Inference

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

# creating the model
model = SaintTransformerModel(
    d_model=128,    # minimally
    nhead=4,        # leave it
    num_layers=2,   # maybe increase it 2-4
    dropout=0.1,   # can be increased....
    emb_dropout=0.1,
    initrange=0.02
)

#  model = SiameseNetwork([
#     SaintTransformerModel(
#         d_model=128,    # minimally
#         nhead=4,        # leave it
#         num_layers=2,   # maybe increase it 2-4
#         dropout=0.1,    # can be increased....
#         emb_dropout=0.15,
#         initrange=0.02
#     ),
#     SaintTransformerModel(
#         d_model=128,    # minimally
#         nhead=4,        # leave it
#         num_layers=2,   # maybe increase it 2-4
#         dropout=0.15,   # can be increased....
#         emb_dropout=0.1,
#         initrange=0.02
#     )
# ])
    
# Load up:

# TODO: This can be spruced up for siamese network:
# map_location keeps the load working on the CPU fallback above even when the
# checkpoint's tensors were saved from a CUDA device.
state = torch.load('../input/uthman-riiid/best-loss-leanmodel_1.pth', map_location=DEVICE)

# Strip a leading 'module.' from checkpoint keys (presumably left by an
# nn.DataParallel wrapper at training time -- confirm). Keys without the
# prefix are kept untouched: blindly chopping 7 characters would mangle them,
# and with strict=False below the mismatch would go unnoticed.
#     (optional) drop keys containing 'user_embeddings' here
module_prefix = 'module.'
state = {
    (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
    for key, value in state.items()
}
# strict=False: the printed result lists any missing / unexpected keys.
print(model.load_state_dict(state, strict=False))


model.to(DEVICE)
model.eval()

In [None]:
# Main inference loop: for every batch handed out by the competition iterator,
# (1) fold the previous batch's revealed answers into the user features,
# (2) pre-process the new batch, (3) build model inputs, (4) submit predictions.
previous_test_df = None
for counta, (test_df, sample_prediction_df) in enumerate(tqdm(iter_test)):
    if counta % 255 == 0:
        # Keep things fresh
        torch.cuda.empty_cache()
        gc.collect()
        
    if previous_test_df is not None:
        # Update answered_correctly and user_answer
        # The first row of the current batch holds the previous batch's
        # outcomes as stringified lists.
        # NOTE(security): eval() executes arbitrary expressions. It is kept
        # because the strings come from the competition API; switch to
        # ast.literal_eval if this input can ever be untrusted.
        answered_correctly = np.array(eval(test_df["prior_group_answers_correct"].iloc[0]))
        user_answer = np.array(eval(test_df["prior_group_responses"].iloc[0]))
        user_answer[user_answer==-1] = 0               # Patch Lectures
        answered_correctly[answered_correctly==-1] = 1 # Patch Lectures
        previous_test_df['answered_correctly'] = answered_correctly
        previous_test_df['user_answer'] = user_answer
        update_user_feats(previous_test_df, user_features)
    
        # TODO: Update the DATASET with last 128 historical values for this user
    
    
    ############################################################################
    # Step 1) Make updates to the DF per our pre-processing and any joins we need:
    test_df.content_id += 1
    # NaN -> -2, then shift by +2 so every value is a non-negative uint8 code.
    test_df.prior_question_had_explanation = (test_df.prior_question_had_explanation.astype(np.float16).fillna(-2) + 2).astype(np.uint8)
    mask_lectures = (test_df.content_type_id == 1)
    test_df.loc[mask_lectures, 'content_id'] = test_df[mask_lectures].content_id.map(lecture_ids_map)
    test_df.loc[mask_lectures, 'answered_correctly'] = 1
    test_df['part_id'] = test_df.content_id.map(part_ids_map).astype(np.uint8)
    test_df['bundle_id'] = test_df.content_id.map(bundle_id_map).fillna(0).astype(np.uint16)
    
    # TODO: To create incorrect rank, we also need the user_answer.....
    # pd.merge never modifies its inputs; it returns a new frame. The result
    # has to be re-bound, otherwise `incorrect_rank` never reaches test_df.
    # NOTE(review): the result gets a fresh RangeIndex, and the join needs a
    # 'user_answer' column on test_df -- see the TODO above.
    test_df = pd.merge(test_df, question_incorrect_ranks, how='left', on=['user_answer','content_id'])
    test_df.loc[test_df.incorrect_rank.isna(), 'incorrect_rank'] = 2 # = "correct", for fillna
    test_df.incorrect_rank = test_df.incorrect_rank.astype(np.uint8)
    
    
    # Keep the raw value in a separate column before bucketing it.
    test_df['prior_question_elapsed_time_cont'] = test_df.prior_question_elapsed_time
    test_df.prior_question_elapsed_time //= 1000
    # NaN -> -2, then shift by +2 (same scheme as prior_question_had_explanation).
    test_df.loc[test_df.prior_question_elapsed_time.isna(), 'prior_question_elapsed_time'] = -2
    test_df.prior_question_elapsed_time += 2
    # We are NOT mapping!!! We created a df that did NOT map since we dont embed users
    #df.user_id = df.user_id.map(embedded_user_ids_map).astype(np.uint32)
    
    # Step 2) Build out features and cache those that are necessary
    # Cache the df so that we can make updates after prediction
    test_df = add_user_feats_without_update(test_df, user_features)
    previous_test_df = test_df[
        # Minimal features used in `update_user_feats`
        ['user_id', 'content_id', 'task_container_id', 'content_type_id', 'part_id', 'timestamp', 'answered_correctly', 'user_answer']
    ].copy()

    # Step 3) Create a data loader! using test_df + user_dfs consisting of last 128 records for this user
    # NOTE(review): these 36 lists are not consumed below yet; `Riiid(test_df)`
    # performs the same `update_user_df` call in its constructor.
    (
        _user_id, _task_container_id, _content_id, _part_id, _prior_question_elapsed_time,
        _prior_question_had_explanation, _incorrect_rank, _content_type_id, _bundle_id,
        _answer_ratio1, _answer_ratio2, _correct_streak_u, _incorrect_streak_u, _correct_streak_alltime_u,
        _incorrect_streak_alltime_u, _session_content_num_u, _session_duration_u, _session_ans_duration_u,
        _lifetime_ans_duration_u, _lag_ts_u_recency, _last_ts_u_recency1, _last_ts_u_recency2, _last_ts_u_recency3,
        _last_correct_ts_u_recency, _last_incorrect_ts_u_recency, _correctness_u_recency,
        _part_correctness_u_recency, _session_correctness_u, _encountered, _diagnostic_u_recency_1,
        _diagnostic_u_recency_2, _diagnostic_u_recency_3, _diagnostic_u_recency_4, _diagnostic_u_recency_5,
        _diagnostic_u_recency_6, _answered_correctly
    ) = update_user_df(test_df, user_dfs)
    

    # TODO: On the fly padding in the DataLoader
        
    # TODO: Predict using the model
    # TODO: Store predictions in the right location on test_df
    # And in the order they arrived; not in the order we created the dataset (by user, with this TCID last)
    # NOTE(review): placeholder -- a torch nn.Module has no `predict`; unless
    # SaintTransformerModel defines one, this still needs the DataLoader-based
    # inference described in the TODOs above.
    test_df[TARGET] =  model.predict(test_df[FEATS])
    
    set_predict(
        # We only predict questions
        test_df.loc[
            test_df.content_type_id == 0,
            ['row_id', TARGET]
        ]
    )

# Gym

In [None]:
def save(path, model, optimizer, best_loss, epoch, step_scheduler=None, scheduler=None, lean=False):
    """Write a checkpoint to ``path`` with ``torch.save``.

    The model is switched to eval mode first and is left in that mode.

    lean=True  -> only ``model.state_dict()`` is written.
    lean=False -> a dict with 'model_state_dict', 'optimizer_state_dict',
                  'best_loss' and 'epoch', plus 'scheduler_state_dict' /
                  'step_scheduler_state_dict' for each scheduler supplied.
    """
    model.eval()

    if not lean:
        payload = dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            best_loss=best_loss,
            epoch=epoch,
        )
        # Optional extras, stored only for the schedulers actually passed in.
        optional_entries = (
            ('scheduler_state_dict', scheduler),
            ('step_scheduler_state_dict', step_scheduler),
        )
        for entry_name, sched in optional_entries:
            if sched is not None:
                payload[entry_name] = sched.state_dict()
    else:
        payload = model.state_dict()

    torch.save(payload, path)

def load(path, model, optimizer, step_scheduler=None, scheduler=None, map_location=None):
    """Restore a full (non-lean) checkpoint written by ``save``.

    Args:
        path: file name or file-like object accepted by ``torch.load``.
        model: module whose weights are restored from 'model_state_dict'.
        optimizer: optimizer restored from 'optimizer_state_dict'.
        step_scheduler: optional scheduler, restored from
            'step_scheduler_state_dict' when that entry exists.
        scheduler: optional scheduler, restored from 'scheduler_state_dict'
            when that entry exists.
        map_location: forwarded to ``torch.load``; the default ``None`` keeps
            the previous behavior.

    Returns:
        ``(model, optimizer, best_loss, epoch, step_scheduler, scheduler)``.

    Raises:
        KeyError: if the checkpoint lacks the model/optimizer/best_loss/epoch
            entries (e.g. a ``lean=True`` checkpoint, which is a bare
            state_dict).
    """
    checkpoint = torch.load(path, map_location=map_location)

    # `save` stores scheduler state under '<name>_state_dict'. Restore each
    # scheduler on its own, and only when both the object and its saved state
    # are present. The old check looked for a 'step_scheduler' key that is
    # never written, so scheduler state was silently never restored.
    for entry_name, sched in (
        ('step_scheduler_state_dict', step_scheduler),
        ('scheduler_state_dict', scheduler),
    ):
        if sched is not None and entry_name in checkpoint:
            sched.load_state_dict(checkpoint[entry_name])

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return (
        model, optimizer,
        checkpoint['best_loss'], checkpoint['epoch'],
        step_scheduler, scheduler
    )

# Fin~