# SAINT+ model

This kernel is a Pytorch implementation of:
https://arxiv.org/pdf/2010.12042v1.pdf

It includes:
* Features: only exercises and answers
* Transformer with both Encoder/Decoder.
* Both triangular mask (for lookahead) and padding mask applied
* BCEWithLogitsLoss

In [1]:
# General
import os, sys, random, gc, math, glob, time, pathlib
import numpy as np
import pandas as pd
import io, timeit, pickle, psutil
from tqdm.notebook import tqdm
from datetime import datetime, timedelta
import re, shutil
import psutil
import warnings

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import GridSearchCV, StratifiedKFold, TimeSeriesSplit, KFold, GroupKFold, ShuffleSplit
from sklearn import metrics
from collections import OrderedDict, defaultdict
import warnings

# Plotting
from matplotlib import cm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
DEFAULT_FIG_WIDTH = 20
sns.set_context("paper", font_scale=1.2) 

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
from torch.optim import Adam, SGD, AdamW

# Display
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 4000)
pd.options.display.float_format = '{:,.2f}'.format

In [2]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [3]:
# FOLDERS & FILES!
FOLDER_FEATHER = "/content/drive/My Drive/kaggle-riiid/feather-files"
PREPROCESS_FILE = "/content/drive/My\ Drive/Colab\ Notebooks/riiid-pytorch-preprocessing.ipynb"

In [4]:
FEATURES = ["content_id", "answered_correctly"]

In [5]:
print('Python     : ' + sys.version.split('\n')[0])
print('Numpy      : ' + np.__version__)
print('Pandas     : ' + pd.__version__)
print('PyTorch    : ' + torch.__version__)

Python     : 3.6.9 (default, Oct  8 2020, 12:12:24) 
Numpy      : 1.18.5
Pandas     : 1.1.5
PyTorch    : 1.7.0+cu101


In [6]:
def seed_everything(s):
    random.seed(s)
    os.environ['PYTHONHASHSEED'] = str(s)
    np.random.seed(s)
    # Torch
    torch.manual_seed(s)
    torch.cuda.manual_seed(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)

seed = 2020
seed_everything(seed)

In [7]:
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(DEVICE))

Running on device: cuda:0


In [8]:
class MaskedBCEWithLogitsLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_ce = nn.BCEWithLogitsLoss()

    def forward(self, input, target, mask=None):
        # Flatten to (BS*seq_len)
        input_ = input.reshape(-1)
        target_ = target.reshape(-1)
        
        if mask is not None:
            mask_ = mask.reshape(-1)
            input_ = input_[mask_ == False]
            target_ = target_[mask_ == False]

        return self.loss_ce(input_, target_)

In [9]:
class raw_conf:
    # Data
    seq_len = 100 
    embedding_dim = 256 # embed_dim must be divisible by num_heads
    exercices_id_size = 13523
    response_size = 2

    # Model
    nhead = 8
    num_encoder_layers = 2
    num_decoder_layers = 2
    dim_feedforward = 1028
    dropout = 0.1
    position_encoding_enabled = True

    # Loss function
    loss = MaskedBCEWithLogitsLoss()

    # Optimizer
    optimizer = "Noam" # "Adam"
    warm_up_step_count = 10000 # 4000 
    warm_up_scale = 1.5 # 1.5
    lr = 0.0003

    # Training params
    WORKERS = psutil.cpu_count()
    BATCH_SIZE = 128
    ITERATIONS_LOGS = 50
    EPOCHS = 100

    pin_memory = True

    if torch.cuda.is_available():
        map_location=lambda storage, loc: storage.cuda()
    else:
        map_location='cpu'

In [10]:
# Config file
conf = raw_conf()

# Prepare data

In [11]:
def to_series(df):
    df.set_index(["user_id"], inplace=True)
    df = df[FEATURES]
    return df

In [12]:
train = pd.read_feather(FOLDER_FEATHER + "/train.feather")

In [13]:
train = train[["user_id", "content_id", "content_type_id", "answered_correctly"]]
train = train.loc[train["content_type_id"]==False]
train.head()

Unnamed: 0,user_id,content_id,content_type_id,answered_correctly
0,115,5692,False,1
1,115,5716,False,1
2,115,128,False,1
3,115,7860,False,1
4,115,7922,False,1


In [14]:
users = train["user_id"].unique()

# Select 90% for train, 10% for validation
users_train, users_val = users[:int(len(users)*0.9)], users[int(len(users)*0.9):]

train_pd = train.loc[train["user_id"].isin(users_train)]
valid_pd = train.loc[train["user_id"].isin(users_val)]

In [15]:
train_pd = to_series(train_pd)
print(train_pd.shape)
train_pd.head(1)

(89336475, 2)


Unnamed: 0_level_0,content_id,answered_correctly
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1
115,5692,1


In [16]:
valid_pd = to_series(valid_pd)
print(valid_pd.shape)
valid_pd.head(1)

(9934825, 2)


Unnamed: 0_level_0,content_id,answered_correctly
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1933656070,6366,1


In [17]:
print(train_pd.index.nunique(), valid_pd.index.nunique())

354290 39366


# Optimizer

In [18]:
class NoamOpt:
    "Optim wrapper that implements rate."

    def __init__(self, model_size, factor, warmup, optimizer, warmup_scale=1.0):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.warmup_scale = warmup_scale
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
               (self.model_size ** (-0.5) *
                min(step ** (-0.5) / self.warmup_scale, step * self.warmup ** (-1.5)))

class NoamOptimizer:
    def __init__(self, model, lr, model_size, warmup, warmup_scale):
        self._adam = torch.optim.Adam(model.parameters(), lr=lr)
        self._opt = NoamOpt(model_size=model_size, factor=1, warmup=warmup, optimizer=self._adam, warmup_scale=warmup_scale)

    def zero_grad(self):
        self._opt.zero_grad()

    def step(self):
        self._opt.step()
    
    def get_last_lr(self):
        return self._opt._rate

# Model

In [19]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model) # torch.Size([max_len, d_model])
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # torch.Size([max_len, 1]) # 0,1,2,3,4,...max_len-1
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # torch.Size([d_model/2])
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1) # torch.Size([max_len, 1, d_model])
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [20]:
# Model
class RIIIDModel(nn.Module):
    def __init__(self, cfg, verbose=False):
        super().__init__()
        self.response_size = cfg.response_size
        self.seq_len = cfg.seq_len
        self.embedding_dim = cfg.embedding_dim

        # Position embedding
        self.pos_encoder1 = None
        self.pos_encoder2 = None

        # Exercices embeddings
        self.exercices_id_embedding = nn.Embedding(cfg.exercices_id_size, self.embedding_dim)

        # Response embeddings
        self.response_embedding = nn.Embedding(cfg.response_size + 2, self.embedding_dim) # +1 to include start token

        # Position encoder (relative or absolute position of the tokens in the sequence)
        if cfg.position_encoding_enabled is True:
            self.pos_encoder1 = PositionalEncoding(self.embedding_dim, cfg.dropout)
            self.pos_encoder2 = self.pos_encoder1 # PositionalEncoding(input_features_dim, cfg.dropout)

        # Transformer with default encoder/decoder        
        self.transformer = nn.Transformer(d_model=self.embedding_dim, 
                                          nhead=cfg.nhead, 
                                          num_encoder_layers=cfg.num_encoder_layers,
                                          num_decoder_layers=cfg.num_decoder_layers, 
                                          dim_feedforward=cfg.dim_feedforward, 
                                          dropout=cfg.dropout, 
                                          activation='relu', custom_encoder = None, custom_decoder = None)
        
        # Decoder
        self.fc = nn.Linear(self.embedding_dim, 1)

        self.init_weights() # Xavier

    def init_weights(self):
        # Xavier uniform initialization
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
   
    def generate_mask(self, size, diagonal=1):        
        return torch.triu(torch.ones(size, size)==1, diagonal=diagonal)
        # return self.generate_square_subsequent_mask(size)        

    def forward(self, data, src_mask=None, tgt_mask=None, mem_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Each input is (BS, seq_len)

        # Exercises
        x = data["content_id"].long()
        x = self.exercices_id_embedding(x) # (BS, seq_len, embedding_dim)
        x = self.pos_encoder1(x) # (BS, seq_len, embedding_dim)
        x = x.transpose(1,0) # (seq_len, BS, embedding_dim)


        # Response
        y = data["answered_correctly"].long() # BS x SEQ_LEN
        # y = torch.roll(y, shifts=(0, 1, 0), dims=(0, 1, 0)) # Shift right the sequence
        y = torch.roll(y, shifts=(0, 1), dims=(0, 1))
        y[:,0] = self.response_size + 1 # Start token (3)

        y = self.response_embedding(y) # (BS, seq_len, embedding_dim)
        y = self.pos_encoder2(y)
        y = y.transpose(1,0) # (seq_len, BS, embedding_dim)

        out_transformer = self.transformer(src=x, tgt=y, src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=mem_mask, 
                                         src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask) # (seq_len, BS, embedding_dim)
        out_transformer = out_transformer.transpose(1,0) # (BS, seq_len, embedding_dim)
     
        output = self.fc(out_transformer)
        output = output.squeeze(dim=2)
        return output

In [21]:
def build_model(cfg, device):
    model = RIIIDModel(cfg)
    model = model.to(device)
    
    optimizer = NoamOptimizer(model=model, lr=cfg.lr, model_size=cfg.embedding_dim, warmup=cfg.warm_up_step_count, warmup_scale=cfg.warm_up_scale)

    # Loss
    loss = cfg.loss
    loss = loss.to(device)

    return model, loss, optimizer

In [22]:
def accuracy(y_true, y_pred, mask=None):
    y_pred = np.where(y_pred > 0.5, 1, 0) # convert to binary

    # Flatten to (cumulated_BS * seq_len)
    if mask is not None:
        mask = mask.reshape(-1)
        y_true = y_true.reshape(-1)[mask == False]
        y_pred = y_pred.reshape(-1)[mask == False]        
        score = (y_true == y_pred).sum()
        score = score/len(y_pred)
    else:
        y_true = y_true.reshape(-1)
        y_pred = y_pred.reshape(-1)
        score = (y_true == y_pred).sum()
        score = score/len(y_pred)      
    return score


def auc(y_true, y_pred, mask=None):
    # Flatten to (cumulated_BS * seq_len)
    # print(y_true, y_pred, mask)
    if mask is not None:
        mask = mask.reshape(-1)
        y_true = y_true.reshape(-1)[mask == False]
        y_pred = y_pred.reshape(-1)[mask == False]        
        score = metrics.roc_auc_score(y_true, y_pred)
    else:
        y_true = y_true.reshape(-1)
        y_pred = y_pred.reshape(-1)
        score = metrics.roc_auc_score(y_true, y_pred)

    return score 

# Riiid dataset

In [23]:
# Dataset/DataLoader:
class RIIIDDataset(Dataset):
    def __init__(self, df, conf, subset="train"):
        super().__init__()
        self.df = df
        self.users = self.df.index.unique()
        self.conf = conf  
        self.subset = subset

        # Selection for each subset
        if subset == 'train':
            self.get_offset = np.random.randint
        elif subset == 'valid':
            self.get_offset = lambda x: 0
        elif subset == 'test':
            self.get_offset = lambda x: 0            
        else:
            raise RuntimeError("Unknown subset")
   
    def __len__(self):
        return len(self.users)
       
    def get_sample(self, row):        
        # build slice_pd with (features, seq_len)
        slice_pd = None
        mask = np.zeros((self.conf.seq_len), dtype=bool) # All False            

        series_len = len(row)
        start_index, stop_index = 0, series_len
        # build slice_pd with (features, seq_len)
        if series_len < self.conf.seq_len:
            # Extract chunk with valid data
            tmp = row.iloc[start_index:stop_index, :]
            # Not enough data => use pad_size
            pad_size = self.conf.seq_len-tmp.shape[0]

            # Pad with numpy - switching to right padding!
            right_pad = np.pad(tmp.values, ((0, pad_size), (0, 0)), 'constant', constant_values=2)
            slice_pd = pd.DataFrame(right_pad, columns=tmp.columns)
            
            # Mask     
            mask = np.concatenate([np.zeros((slice_pd.shape[0]-pad_size), dtype=bool), np.ones((pad_size), dtype=bool)])
        
        # Select another random slice
        elif series_len > self.conf.seq_len:
            random_start_index = self.get_offset(series_len - self.conf.seq_len)
            slice_pd = row.iloc[random_start_index:random_start_index + self.conf.seq_len, :]
        
        else: # exactly equal to series_len
            slice_pd = row.iloc[start_index:stop_index, :]

        sample = {}
        for input in FEATURES:
            if input in slice_pd.columns:
                # Pick feature!
                series_data = slice_pd[input].values
                
                if input == "answered_correctly": # TARGET!
                    labels = torch.from_numpy(series_data).long()
                    sample["labels"] = labels
                item = torch.from_numpy(series_data)
                sample[input] = item

        sample["mask"] = mask
        return sample

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # For train/valid we need series per user
        user_id = self.users[idx]
        row = self.df.loc[user_id:user_id]
        # print(row)
        sample =  self.get_sample(row)
        return sample

In [24]:
valid_dataset = RIIIDDataset(valid_pd, raw_conf, subset="valid")
a = iter(valid_dataset)

In [25]:
next(a)

{'answered_correctly': tensor([1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,
         1, 0, 1, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2], dtype=torch.int16),
 'content_id': tensor([ 6366,  3665,   296,  4002,  5355,  5613,  5987,  4476,  5589,  5374,
          9175,  5555,  5668,  3981, 11421,  5692,  4108,   689,   951,   733,
         10688,   788,   204,  1257,   294,   474,   277,  5613,   277,   951,
          5374,  3981,  5668,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,  

In [26]:
# Formatting for tqdm
def format_logs(logs):
    str_logs = ['{} - {:.4}'.format(k, v) for k, v in logs.items()]
    s = ', '.join(str_logs)
    return s

In [27]:
# Train loop
def train_loop_fn(batches, model, optimizer, criterion, device, verbose=True):
    model.train()
    count, running_loss = 0, 0.0
    all_probs, all_targets = None, None
    all_padding_mask = None
    src_mask = None

    scores_str, scores = {}, {}
    with tqdm(batches, file=sys.stdout, disable=not(verbose)) as iterator:
        for x, batch in enumerate(iterator, 1):
            try:
                for k, v in batch.items(): # data to GPU
                    batch[k] = v.to(device)
                
                # reset gradients!
                optimizer.zero_grad()

                # Get data and mask
                data, labels, padding_mask = batch, batch["labels"], batch["mask"]
                src_mask = model.generate_mask(conf.seq_len).to(device)
                
                # Predict output
                output = model(data, src_mask=src_mask, tgt_mask=src_mask, mem_mask=src_mask, 
                               src_key_padding_mask= padding_mask, tgt_key_padding_mask=padding_mask, memory_key_padding_mask=padding_mask) # forward pass
                
                
                # Loss steps
                loss = criterion(output, labels.float(), mask=padding_mask)
                loss.backward()
                optimizer.step()
                     
                if (conf.ITERATIONS_LOGS > 0) and (x % conf.ITERATIONS_LOGS == 0):
                    
                    # Running loss
                    loss_value = loss.item()
                    if ~np.isnan(loss_value): running_loss += loss_value
                    else: print("Warning: NaN loss") 

                    # Labels predictions
                    probs = torch.sigmoid(output).detach().cpu().numpy()
                    targets = labels.detach().cpu().numpy()
                    padding_mask = padding_mask.cpu().numpy()

                    # Concatenate for all batches
                    all_probs = np.concatenate([all_probs, probs], axis=0) if all_probs is not None else probs
                    all_targets = np.concatenate([all_targets, targets], axis=0) if all_targets is not None else targets 
                    all_padding_mask = np.concatenate([all_padding_mask, padding_mask], axis=0) if all_padding_mask is not None else padding_mask  
                    
                    count += 1

                    scores_str["train_auc"] = auc(all_targets, all_probs, mask = all_padding_mask)
                    scores_str["train_accuracy"] = accuracy(all_targets, all_probs, mask = all_padding_mask)
                    scores_str["train_loss"] = (running_loss / count)

                    iterator.set_postfix_str(format_logs(scores_str))

            except Exception as ex:
                print("Training batch error:", ex)
        
    return scores_str

In [28]:
# Valid loop
def valid_loop_fn(batches, model, criterion, device, verbose=True):
    model.eval()
    count, running_loss = 0, 0.0
    all_probs, all_targets = None, None
    all_padding_mask = None
    src_mask = None

    scores_str, scores = {}, {}
    with tqdm(batches, file=sys.stdout, disable=not(verbose)) as iterator:
        for x, batch in enumerate(iterator, 1):
            try:
                with torch.no_grad():
                    for k, v in batch.items(): # data to GPU
                        batch[k] = v.to(device)

                    # Get data and mask
                    data, labels, padding_mask = batch, batch["labels"], batch["mask"]
                    src_mask = model.generate_mask(conf.seq_len).to(device)

                    # Predict output
                    output = model(data, src_mask=src_mask, tgt_mask=src_mask, mem_mask=src_mask, 
                                  src_key_padding_mask= padding_mask, tgt_key_padding_mask=padding_mask, memory_key_padding_mask=padding_mask) # forward pass
                    
                    # Loss steps
                    padding_mask = padding_mask
                    loss = criterion(output, labels.float(), mask=padding_mask)  

                    if (conf.ITERATIONS_LOGS > 0) and (x % conf.ITERATIONS_LOGS == 0):
                        
                        # Running loss
                        loss_value = loss.item()
                        if ~np.isnan(loss_value): running_loss += loss_value
                        else: print("Warning: NaN loss") 

                        # Labels predictions
                        probs = torch.sigmoid(output).detach().cpu().numpy()
                        targets = labels.detach().cpu().numpy()
                        padding_mask = padding_mask.cpu().numpy()

                        # Concatenate for all batches
                        all_probs = np.concatenate([all_probs, probs], axis=0) if all_probs is not None else probs
                        all_targets = np.concatenate([all_targets, targets], axis=0) if all_targets is not None else targets 
                        all_padding_mask = np.concatenate([all_padding_mask, padding_mask], axis=0) if all_padding_mask is not None else padding_mask  
                        
                        count += 1

                        scores_str["valid_auc"] = auc(all_targets, all_probs, mask = all_padding_mask)
                        scores_str["valid_accuracy"] = accuracy(all_targets, all_probs, mask = all_padding_mask)
                        scores_str["valid_loss"] = (running_loss / count)

                        iterator.set_postfix_str(format_logs(scores_str))

            except Exception as ex:
                print("Valid batch error:", ex)
        
    return scores_str

In [29]:
def run_stage(X_train, X_valid, device):
    
    # Datasets
    train_dataset = RIIIDDataset(X_train, conf, subset="train")
    valid_dataset = RIIIDDataset(X_valid, conf, subset="valid")

    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=conf.BATCH_SIZE, num_workers=conf.WORKERS, drop_last = False, pin_memory=conf.pin_memory, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=conf.BATCH_SIZE, shuffle=False, num_workers=conf.WORKERS, drop_last = False, pin_memory=conf.pin_memory)


    print("Device:", device, "workers:", conf.WORKERS,  "batch size:", conf.BATCH_SIZE,
          "train dataset:", len(train_dataset), "valid dataset:", len(valid_dataset))
    
    # Build model
    model, criterion, optimizer = build_model(conf, device)
    print(model)

    best_auc = 0
    for epoch in tqdm(range(1, conf.EPOCHS + 1)):
      
        # Train loop
        train_scores = train_loop_fn(train_loader, model, optimizer, criterion, device)

        # Validation loop
        valid_scores = valid_loop_fn(valid_loader, model, criterion, device)

        if valid_scores["valid_auc"] > best_auc:
            best_auc = valid_scores["valid_auc"]
            torch.save(model.state_dict(), "best_model.pth")

In [30]:
_ = run_stage(train_pd, valid_pd, DEVICE)

Device: cuda:0 workers: 4 batch size: 128 train dataset: 354290 valid dataset: 39366
RIIIDModel(
  (exercices_id_embedding): Embedding(13523, 256)
  (response_embedding): Embedding(4, 256)
  (pos_encoder1): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (pos_encoder2): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=1028, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=1028, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2768.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=308.0), HTML(value='')))





## Add output to Drive

In [31]:
OUTPUT_DRIVE = "/content/drive/My Drive/kaggle-riiid/subs"

In [32]:
# # https://stackoverflow.com/questions/15034151/copy-directory-contents-into-a-directory-with-python
# from distutils.dir_util import copy_tree
# copy_tree(OUTPUT_FOLDER, os.path.join(OUTPUT_DRIVE, OUTPUT_FOLDER))
shutil.move("best_model.pth", OUTPUT_DRIVE)

'/content/drive/My Drive/kaggle-riiid/subs/best_model.pth'

# Old code

In [33]:
# x = x * math.sqrt(self.embedding_dim) # Counterbalance attention normalization that divide by math.sqrt(self.embedding_dim)
# y = y * math.sqrt(self.embedding_dim) # Counterbalance attention normalization that divide by math.sqrt(self.embedding_dim)

In [34]:
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf],
    #         [0., 0., 0., 0.]])

In [35]:
# Figure!

# fig, ax = plt.subplots(1,2,figsize=(20,4))
# d = sns.distplot(train_pd.value_counts(["user_id"]).reset_index()[0], ax=ax[0], label="train", hist=False)
# d.set_xlabel("Interactions")
# d = sns.distplot(valid_pd.value_counts(["user_id"]).reset_index()[0], ax=ax[1], label="valid", hist=False)
# d.set_xlabel("Interactions")
# d = plt.legend()

In [36]:
# # Load train
# train_pd = pd.read_pickle(DATA_HOME + "cv%d_train.pickle" % FOLD)[COLS]
# print(train_pd.shape, train_pd[USER_ID].nunique())

In [37]:
# # train_outliers = drop_outliers(train_pd, MIN_INTERACTIONS_TRAIN, MAX_INTERACTIONS_TRAIN)
# # print("Dropping", len(train_outliers), "users from train with less than", MIN_INTERACTIONS_TRAIN, "interactions or more than", MAX_INTERACTIONS_TRAIN)
# # train_pd = train_pd[~train_pd[USER_ID].isin(train_outliers)].reset_index(drop=True)
# # print(train_pd.shape, train_pd[USER_ID].nunique())

# # Split Train dataset to fit memory (SUBSETS=None leads to OOM on Kaggle)
# if SUBSETS is not None:
#     sub_kf = GroupKFold(n_splits = SUBSETS)
#     for trn_idx, val_idx in sub_kf.split(train_pd, groups = train_pd[USER_ID]):
#         train_pd = train_pd.iloc[trn_idx].reset_index(drop=True)
#         #train_pd = train_pd.iloc[val_idx].reset_index(drop=True)
#         break
#     print("Train:", train_pd.shape)
#     del trn_idx, val_idx
#     gc.collect()

# # Prepare dataset
# train_pd = cleanup(train_pd, elapsed_time_seconds_mean=elapsed_time_seconds_mean, lag_minutes_mean=lag_minutes_mean)

In [38]:
# train_pd = to_series(train_pd)
# print(train_pd.shape)
# train_pd.head()

In [39]:
# # Load question data
# def load_questions():
#     # questions_pd = pd.read_csv(TRAIN_FILE_QUESTIONS)
#     questions_pd = pd.read_feather(TRAIN_FILE_QUESTIONS)
#     questions_pd[PART] = questions_pd[PART].astype(np.int8)
#     questions_pd[BUNDLE_ID] = questions_pd[BUNDLE_ID].astype(np.int32)
#     tag = questions_pd[TAGS].str.split(" ", n = 10, expand = True) 
    
#     tag.columns = ['tags1','tags2','tags3','tags4','tags5','tags6']
#     questions_pd =  pd.concat([questions_pd,tag], axis=1)
#     questions_pd['tags1'] = questions_pd['tags1'].astype(np.float32).astype('Int16')
#     questions_pd.rename(columns={"question_id": CONTENT_ID}, inplace=True)
    
#     for col in ["tags1", "tags2", "tags3", "tags4", "tags5", "tags6"]:
#         if col in questions_pd.columns:
#             questions_pd[col] = questions_pd[col].astype('category').cat.codes
#     questions_pd.drop(columns=[TAGS, CORRECT_ANSWER, BUNDLE_ID, "tags1", "tags2", "tags3", "tags4", "tags5", "tags6"], inplace=True)
#     return questions_pd

# questions_pd = load_questions()
# questions_pd.head()

In [40]:
# step = np.arange(1, 12000)
# lr = conf.embedding_dim**(-0.5) * np.minimum(step**(-0.5)/(conf.warm_up_scale), step*(conf.warm_up_step_count**(-1.5)))
# d = plt.plot(lr)
# d = plt.title(max(lr))

In [41]:
    # # Load weights
    # if (encoder_weights is not None) and (os.path.exists(encoder_weights)):
    #     print("Load weights before optimizer from: %s" % encoder_weights)
    #     model.load_state_dict(torch.load(encoder_weights, map_location=cfg.map_location), strict=False)
    #     if freeze_backbone is True:
    #         print("Freeze backbone")
    #         model.freeze_backbone()