In [1]:
import matplotlib
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
import os
import sys
import random
import pdb
import itertools
import numpy as np
import pandas as pd
import datatable as dt
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn import metrics

import pickle
import h5py
import joblib

# import seaborn as sns
# import lightgbm as lgb
# from lightgbm import LGBMClassifier


In [None]:
# Install the locally built `bert_src` wheel from the `latest_bert/dist` directory (shell command run from the notebook).
!pip install latest_bert/dist/bert_src-1.0-py3-none-any.whl

In [2]:
from latest_bert.src.bert_src.modeling_bert import RiidModel
from latest_bert.src.bert_src.configuration_bert import BertConfig

In [3]:
# Run on GPU index 7 whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
print(device)

cuda:7


In [4]:
class AverageMeter(object):
    """Running weighted mean that also remembers the most recent sample."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the running mean, the weighted sum and the weight total."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the running mean."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count

In [5]:
class Params:
    """Plain attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **entries):
        vars(self).update(entries)

    def update(self, **kargs):
        """Add or overwrite attributes from keyword arguments."""
        vars(self).update(kargs)

In [6]:
# Root folder holding the competition CSV files (relative to the notebook's working directory).
path = './kaggle/input/riiid-test-answer-prediction'
train_file = f'{path}/train.csv'
# Explicit column dtypes passed to pd.read_csv when train.csv is streamed in a later cell.
train_dtypes = {'row_id': 'int64',
              'timestamp': 'int64',
              'user_id': 'int32',
              'content_id': 'int16',
              'content_type_id': 'int8',
              'task_container_id': 'int16',
              'user_answer': 'int8',
              'answered_correctly': 'int8',
              'prior_question_elapsed_time': 'float32', 
              'prior_question_had_explanation': 'boolean',
             }
# The example test file and the two metadata tables are read fully into memory with inferred dtypes.
test_file = f'{path}/example_test.csv'
test_sample = pd.read_csv(test_file)
questions = pd.read_csv(f'{path}/questions.csv')
lectures = pd.read_csv(f'{path}/lectures.csv')
print('Question shapes:', questions.shape)
print('Lecture shapes:', lectures.shape)

Question shapes: (13523, 5)
Lecture shapes: (418, 4)


In [None]:
# Column names of train.csv. NOTE: only the commented-out read_csv variant below passes them;
# the active call does not use `colnames`.
colnames = ['row_id', 'timestamp', 'user_id', 'content_id', 'content_type_id', 'task_container_id', 'user_answer', 'answered_correctly', 'prior_question_elapsed_time', 'prior_question_had_explanation']
# chunks = pd.read_csv(train_file, chunksize=1e3, dtype=train_dtypes, header=None, names=colnames, index_col=False)
# Stream train.csv lazily with chunksize=1e5; `chunks` is iterated by the training loop in a later cell.
chunks = pd.read_csv(train_file, chunksize=1e5, dtype=train_dtypes)


In [None]:
# Collect every distinct tag id used by any question, shifted by +1.
# Questions whose `tags` cell is missing (renders as 'nan') contribute the id 0 instead.
seen_tags = set()
for raw_tags in questions.tags.values:
    tag_text = str(raw_tags)
    if tag_text.strip() != 'nan':
        seen_tags.update(int(token) + 1 for token in tag_text.split())
    else:
        seen_tags.add(0)
question_tags = list(seen_tags)

n_tags = len(question_tags)
# One more than the number of distinct `part` values.
n_parts = len(set(questions.part.unique())) + 1

print(f'n_tags {n_tags}, n_parts {n_parts}')

In [None]:
# Central configuration for the dataset, the model and the training loop.
# The dict is wrapped in a `Params` object so that values are read as attributes (e.g. `params.batch_size`).
params_dict = {
    'is_update_train': True,   # Trainer.train pushes the trained batch into the dataset's user history
    'is_update_valid': False,  # read by Trainer.val
    'load_state': True,        # LectureData restores saved state instead of building the id maps; overridden before training below
    'use_buffer': False,       # LectureData.test_batch appends labelled prior batches to a buffer when True
    'is_offline': False,
    'batch_norm': False,
    'is_test': False,          # when True, Trainer.train_step skips the AUC/accuracy bookkeeping
    'n_chunks': 200,
    'n_epoch': 10,
    'learning_rate': 5e-5,
    'batch_size': 64,
    'num_workers': 4,
    'cuda': torch.cuda.is_available(),
    'num_questions': questions.question_id.nunique(),
    'num_lectures': lectures.lecture_id.nunique(),
    'num_total_q_tags': n_tags,
    'max_len_tags': 6,
    'num_total_q_part': n_parts,
    'n_layers': 2,
    'dropout': 0.1,
    'hidden_size': 256,
    'max_position_embeddings': 256,
    'num_hidden_layers': 4,
    'num_attention_heads': 8,
    'intermediate_size': 1024,
    'max_task_size': 10002,    # LectureData uses 10001 as the task id of the leading slot of every sequence
    'lag_size': 720,           # lag-time buckets outside (0, lag_size) are mapped to 0 in LectureData.proc_traindata
    'ans_size': 3,
    'buffer_size_limit': 1e4,  # buffer row count above which LectureData.finetune_batch processes the buffer
    'max_seq_length': 100,     # padded sequence length produced by LectureData.__getitem__
    'extra_dir':'./save/bert_save',  # LectureData.load_state reads data2idx.h5 / user_dict.pickle from here
    'save_dir': './save/bert_save'   # location of model_latest.pth loaded in the training cell
}
params = Params(**params_dict)
print(params.__dict__)

In [None]:
def split_data(train_part, n_tail=10):
    """Split a frame into (head, tail), where tail holds each user's last `n_tail` rows.

    Rows are matched back through their index labels, so the frame's index
    is expected to identify rows uniquely.

    Returns:
        (train, valid): `valid` is the per-user tail, `train` is every other row.
    """
    tail_rows = train_part.groupby('user_id').tail(n_tail)
    in_tail = train_part.index.isin(tail_rows.index)
    return train_part[~in_tail], tail_rows

In [None]:
# Timestamp of the same user's previous row (NaN for a user's first row).
test_sample['time_shift'] = test_sample.groupby('user_id')['timestamp'].shift()

# Gap to the previous row minus the time spent on the previous question, in buckets of 100.
# NOTE(review): LectureData.proc_traindata buckets the same quantity with // 500 and also caps it
# at params.lag_size; confirm which bucket width is intended here.
lag_bucket = (test_sample['timestamp'] - test_sample['time_shift'] - test_sample['prior_question_elapsed_time']) // 100

# Keep strictly positive buckets; everything else (zero, negative, NaN) becomes 0.
# The previous `x != np.nan` test was always True, so NaN was only excluded by the `x > 0` comparison;
# `where` expresses the same rule directly and without a per-row Python call.
test_sample['lag_time'] = lag_bucket.where(lag_bucket > 0, 0)


In [None]:
# Display the last 10 rows, including the derived `time_shift` and `lag_time` columns.
test_sample.tail(10)

In [None]:
class LectureData(Dataset):
    """Per-user interaction sequences for the answer-correctness model.

    The dataset is re-pointed at a new set of users by `set_batch` (training),
    `val_batch` (validation) or `test_batch` (inference). `__len__` and
    `__getitem__` then serve one padded sequence per user of that batch.
    `user_dict` stores, per user, the most recent `params.max_seq_length`
    interactions so that later batches can be prefixed with that history.
    """

    def __init__(self, params, question_df=None, lecture_df=None, is_train=True):
        """
        Args:
            params: object exposing (as attributes) `load_state`, `extra_dir`,
                `num_total_q_tags`, `lag_size`, `max_seq_length`, `use_buffer`
                and `buffer_size_limit`.
            question_df: questions table with `question_id`, `part` and `tags`
                columns. It is modified in place by `proc_questiondata`.
            lecture_df: lectures table with a `lecture_id` column; only read
                when `params.load_state` is false.
            is_train: initial mode flag; every *_batch method overwrites it.
        """
        self.params = params
        self.is_train = is_train

        # List-valued per-user features that `_agg_dict` merges with stored history.
        # FIX: 'prior_question_elapsed_time' used to be listed twice (its history was
        # concatenated twice) and 'lag_time' was missing (its history was never merged,
        # so its length disagreed with every other feature).
        self.features = ['combined_id', 'content_type_id', 'task_container_id',
                         'answered_correctly', 'prior_question_elapsed_time', 'part', 'tags', 'lag_time']

        self.train_columns = ['user_id', 'content_id', 'content_type_id', 'task_container_id', 'timestamp',
                              'answered_correctly', 'prior_question_elapsed_time', 'row_id']

        self.prior_batch, self.current_batch, self.buffer_df, self.batch_user, self.question_df = None, None, None, None, None
        self.user_dict, self.question2idx, self.lecture2idx = None, None, None

        if params.load_state:
            self.load_state()
            # FIX: the saved state holds no question table, yet `proc_batch` needs
            # one for the part/tags merge. Use the table passed by the caller.
            if question_df is not None:
                self.question_df = self.proc_questiondata(question_df)
        else:
            self.init_info(question_df, lecture_df)

    def init_info(self, question_df, lecture_df):
        """Build the question/lecture id -> embedding index maps from the raw tables."""
        self.question_df = self.proc_questiondata(question_df)

        self.question_list = list(question_df['question_id'].unique())
        self.lecture_list = list(lecture_df['lecture_id'].unique())

        self.n_questions = len(self.question_list)
        self.n_lectures = len(self.lecture_list)

        # Index 0 is reserved for padding; lecture indices follow the question indices.
        self.question2idx = dict(zip(self.question_list, range(1, self.n_questions + 1)))
        self.lecture2idx = dict(zip(self.lecture_list,
                                    range(self.n_questions + 1, self.n_questions + 1 + self.n_lectures)))

    def load_state(self):
        """Restore the id maps and the per-user history saved under `params.extra_dir`."""
        print('Loading data state ...')
        # Context manager: the HDF5 handle is released even if a dataset is missing.
        with h5py.File(os.path.join(self.params.extra_dir, 'data2idx.h5'), 'r') as f:
            question_ids = f['question2idx'][:]
            lecture_ids = f['lecture2idx'][:]

        self.n_questions = len(question_ids)
        self.question2idx = dict(zip(question_ids, range(1, self.n_questions + 1)))

        self.n_lectures = len(lecture_ids)
        self.lecture2idx = dict(zip(lecture_ids,
                                    range(self.n_questions + 1, self.n_questions + 1 + self.n_lectures)))

        # SECURITY NOTE: pickle.load executes arbitrary code from the file; only
        # load a user_dict.pickle that this notebook produced itself.
        with open(os.path.join(self.params.extra_dir, 'user_dict.pickle'), 'rb') as handle:
            self.user_dict = pickle.load(handle)

    def proc_traindata(self, train_df):
        """Add the model features to an interaction frame (modified in place and returned).

        Adds `combined_id`, `time_shift` and `lag_time`; shifts `task_container_id`
        by +1 and converts `prior_question_elapsed_time` to whole multiples of 1000.

        Raises:
            KeyError: if a `content_id` is missing from `question2idx`.
        """
        # Plain assignment instead of `fillna(inplace=True)` on a column selection,
        # which is silently a no-op under pandas copy-on-write.
        train_df['prior_question_elapsed_time'] = train_df['prior_question_elapsed_time'].fillna(0)

        # A list comprehension keeps the strict KeyError of a dict lookup, avoids the
        # row-wise `apply`, and also works when the frame is empty.
        train_df['combined_id'] = [self.question2idx[cid] for cid in train_df['content_id']]

        train_df['task_container_id'] = train_df['task_container_id'] + 1

        # Timestamp of the same user's previous row (NaN on a user's first row).
        train_df['time_shift'] = train_df.groupby('user_id')['timestamp'].shift()
        lag = (train_df['timestamp'] - train_df['time_shift'] - train_df['prior_question_elapsed_time']) // 500
        # Keep buckets strictly inside (0, lag_size); NaN fails both comparisons and becomes 0.
        in_range = (lag > 0) & (lag < self.params.lag_size)
        train_df['lag_time'] = lag.where(in_range, 0).astype(int)

        train_df["prior_question_elapsed_time"] = train_df["prior_question_elapsed_time"] // 1000

        return train_df

    def proc_questiondata(self, question_df):
        """Replace the space-separated `tags` text by a multi-hot vector (modifies `question_df`).

        Tag ids are shifted by +1; a missing `tags` cell becomes the single id 0.
        """
        def _convert_vector(q_ids):
            blank_vector = np.zeros(self.params.num_total_q_tags).astype(int)
            blank_vector[q_ids] = 1
            return blank_vector

        # Assignment instead of chained `fillna(inplace=True)` (see proc_traindata).
        question_df['tags'] = question_df['tags'].fillna('-1')
        question_df['tags'] = question_df['tags'].map(lambda x: [int(s) + 1 for s in str(x).split()])
        question_df['tags'] = question_df['tags'].map(_convert_vector)

        return question_df

    def test_batch(self, batch_df):
        """Point the dataset at one inference batch.

        The first row of `batch_df` must carry, in `prior_group_answers_correct`,
        the textual list of labels for the previous batch.

        Returns:
            Number of users in the batch.
        """
        import ast  # local import: only this method needs it

        self.is_train = False

        # literal_eval parses the list literal without executing code, unlike eval().
        gt_prior_batch = ast.literal_eval(batch_df.iloc[0]["prior_group_answers_correct"])

        if self.current_batch is not None and len(gt_prior_batch) > 0:
            # The previous batch now has real labels: attach them.
            self.prior_batch = self.current_batch
            self.prior_batch['answered_correctly'] = gt_prior_batch
            self.prior_batch = self.prior_batch[self.train_columns]

            if self.params.use_buffer:
                self.buffer_df = pd.concat([self.buffer_df, self.prior_batch], axis=0, ignore_index=True)
        else:
            self.prior_batch = batch_df

        self.current_batch = batch_df
        # Placeholder label 2 until the real answers arrive with the next batch.
        self.current_batch['answered_correctly'] = 2

        test_user_dict = self.proc_batch(self.current_batch)
        test_user_dict = self.update_newdata(test_user_dict)  # prefix with stored history

        self.batch_user = list(test_user_dict.items())

        return len(self.batch_user)

    def finetune_batch(self, is_finetune=False):
        """Flush the labelled buffer into `user_dict` once it exceeds `params.buffer_size_limit`.

        Returns:
            `is_finetune` if the buffer was processed, otherwise False.
        """
        self.is_train = True

        # FIX: read the limit from self.params instead of the notebook-global `params`.
        if self.buffer_df is not None and len(self.buffer_df) > self.params.buffer_size_limit:

            print('--> Dataset activated finetune buffer')
            buff_user_dict = self.proc_batch(self.buffer_df)

            buff_update = self.update_newdata(buff_user_dict)
            self.user_dict.update(buff_update)

            if is_finetune:
                # Serve the buffered users as the next training batch.
                self.batch_user = list(buff_update.items())

            self.buffer_df = None

            # NOTE(review): sys.getsizeof measures only the dict's own hash table, not the
            # lists it references, so this threshold is unlikely to ever trigger.
            cur_size_in_gb = sys.getsizeof(self.user_dict) // (1024**3)
            if cur_size_in_gb > 12:
                print('--> Random removing user-entries')
                # FIX: random.choice(dict.keys()) raises TypeError on Python 3.
                n_drop = int(len(self.user_dict) * 0.01)
                for uid in random.sample(list(self.user_dict.keys()), n_drop):
                    self.user_dict.pop(uid)

            return is_finetune

        return False

    def proc_batch(self, batch_df):
        """Turn an interaction frame into `{user_id: {feature: list, ..., 'row_id': count}}`.

        Only question rows (`content_type_id == 0`) are kept. Outside training
        mode each kept row also gets a `position` (its rank among the kept rows).

        Raises:
            TypeError: if no question table is available for the part/tags merge.
        """
        if self.question_df is None:
            raise TypeError('LectureData.question_df is not set; pass question_df to the constructor')

        # Work on a copy so proc_traindata does not write into a slice of the caller's frame.
        batch_df = batch_df[batch_df.content_type_id == 0].copy()
        batch_df = self.proc_traindata(batch_df)

        batch_df = batch_df.merge(self.question_df, left_on='content_id', right_on='question_id', how='left')

        agg_spec = {
            'combined_id': list,
            'content_type_id': list,
            'task_container_id': list,
            'answered_correctly': list,
            'prior_question_elapsed_time': list,
            'tags': list,
            'part': list,
            'lag_time': list,
            'row_id': 'count',
        }
        if not self.is_train:
            # Used by Trainer.infer to write predictions back in row order.
            batch_df['position'] = range(batch_df.shape[0])
            agg_spec['position'] = list

        user_group = batch_df.groupby('user_id').agg(agg_spec)
        return user_group.to_dict('index')

    def set_batch(self, train_df):
        """Point the dataset at a training frame; returns (number of users, per-user dict)."""
        self.is_train = True

        train_user_dict = self.proc_batch(train_df)
        self.batch_user = list(train_user_dict.items())

        return len(self.batch_user), train_user_dict

    def val_batch(self, valid_df):
        """Point the dataset at a validation frame prefixed with history.

        The merged sequences are also stored in `user_dict`.

        Returns:
            Number of users in the batch.
        """
        self.is_train = False

        valid_user_dict = self.proc_batch(valid_df)

        valid_user_dict = self.update_newdata(valid_user_dict)
        self.user_dict.update(valid_user_dict)

        self.batch_user = list(valid_user_dict.items())

        return len(self.batch_user)

    def late_update(self, data_dict):
        """Merge `data_dict` with the stored history and save the result in `user_dict`."""
        data_update = self.update_newdata(data_dict)
        self.user_dict.update(data_update)

    def dummy_entry(self, user_id):
        """Return an entry with an empty list for every feature (`user_id` is not used)."""
        return {k: [] for k in self.features}

    def merge_train_valid(self, train_dict, valid_dict):
        """Pair each validation user with their training entry (or an empty entry)."""
        merge_dict = dict()
        for uid, udata in valid_dict.items():
            if uid in train_dict:
                merge_dict[uid] = [train_dict[uid], udata]
            else:
                merge_dict[uid] = [self.dummy_entry(uid), udata]

        return list(merge_dict.items())

    def _agg_dict(self, org_dict, new_dict):
        """Prefix every feature list of `new_dict` with `org_dict`'s history, keeping the last `max_seq_length` items."""
        for k in self.features:
            new_dict[k] = (org_dict[k].copy() + new_dict[k])[-self.params.max_seq_length:]

        return new_dict

    def update_newdata(self, new_user_dict):
        """Prefix the entries of `new_user_dict` with stored history (modifies and returns it).

        On the first call, when no history exists, `new_user_dict` itself becomes `user_dict`.
        """
        if self.user_dict is None:
            print('First init user_dict')
            self.user_dict = new_user_dict
            return new_user_dict

        common_users = list(set(new_user_dict.keys()) & set(self.user_dict.keys()))
        for uid in common_users:
            new_user_dict[uid] = self._agg_dict(self.user_dict[uid], new_user_dict[uid])

        return new_user_dict

    def padding(self, feature, padding_value=0, skip_first=False, first_value=0, max_len=64):
        """Right-pad (and left-truncate) `feature` to exactly `max_len` items.

        Unless `skip_first` is set, slot 0 is filled with `first_value` and only
        the last `max_len - 1` items of `feature` are kept after it.
        """
        if not skip_first:
            feature = feature[-(max_len - 1):]
            return [first_value] + feature + [padding_value] * (max_len - 1 - len(feature))
        else:
            feature = feature[-max_len:]
            return feature + [padding_value] * (max_len - len(feature))

    def __len__(self):
        return len(self.batch_user) if self.batch_user is not None else 0

    def __getitem__(self, index):
        """Return the padded tensors of one user.

        Training mode yields 10 tensors: combined_id, attention_mask,
        content_type_id, task_container_id, curr_ans, part, tags, prior_time,
        lag_time, target. Otherwise 12 tensors: the same first nine, then
        history_mask, position, target.
        """
        ins = self.batch_user[index][1]
        max_len = self.params.max_seq_length

        def _ids(key, **pad_kwargs):
            return np.array(self.padding(ins[key], max_len=max_len, **pad_kwargs)).astype(int)

        target = _ids('answered_correctly')

        # The answer channel is a constant 2 in every slot.
        curr_ans = np.array([2] * max_len).astype(int)

        # Slot 0 is the artificial leading slot and is masked out, as is the padding.
        enc_att_len = len(ins['answered_correctly'][-(max_len - 1):])
        attention_mask = np.array([0] + [1] * enc_att_len + [0] * (max_len - enc_att_len - 1)).astype(int)

        content_type_id = _ids('content_type_id')
        combined_id = _ids('combined_id')
        task_container_id = _ids('task_container_id', first_value=10001)
        prior_time_id = _ids('prior_question_elapsed_time', first_value=0)
        lag_time_id = _ids('lag_time', first_value=0)
        question_part_id = _ids('part', first_value=0)

        tag_pad = [0] * self.params.num_total_q_tags
        question_tag_id = _ids('tags', padding_value=tag_pad, first_value=tag_pad)

        # The first nine tensors are identical in both modes, in the order
        # Trainer.get_dict maps to field names.
        # FIX: the non-training tuple used to omit lag_time_id, so position 8
        # ('lag_time_ids') received history_mask at inference time.
        model_inputs = (torch.LongTensor(combined_id),
                        torch.LongTensor(attention_mask),
                        torch.LongTensor(content_type_id),
                        torch.LongTensor(task_container_id),
                        torch.LongTensor(curr_ans),
                        torch.LongTensor(question_part_id),
                        torch.FloatTensor(question_tag_id),
                        torch.LongTensor(prior_time_id),
                        torch.LongTensor(lag_time_id))

        if self.is_train:
            return model_inputs + (torch.LongTensor(target),)

        # Marks the slots of the rows that belong to the current batch
        # (the last `row_id` real slots); history slots stay 0.
        # NOTE(review): assumes row_id <= enc_att_len, i.e. fewer than max_seq_length new rows per user.
        history_mask = [0] * (enc_att_len - ins['row_id']) + [1] * ins['row_id']
        history_mask = [0] + history_mask + [0] * (max_len - len(history_mask) - 1)

        position = np.array(self.padding(ins['position'], skip_first=True, max_len=max_len, padding_value=-1)).astype(int)

        return model_inputs + (torch.LongTensor(history_mask),
                               torch.LongTensor(position),
                               torch.LongTensor(target))

In [None]:
class Multi_linear_layer(nn.Module):
    """Stack of linear layers with an optional activation between them.

    The stack always has an input layer and an output layer; `n_layers - 2`
    hidden-to-hidden layers are inserted between them (none when `n_layers <= 2`).
    The activation is applied after every layer except the last.

    Args:
        n_layers: requested total number of linear layers.
        input_size: feature size of the input.
        hidden_size: feature size between layers.
        output_size: feature size of the output.
        activation: name of a function in `torch.nn.functional` (e.g. 'relu'),
            a callable, or None for no activation.
    """
    def __init__(self,
                 n_layers,
                 input_size,
                 hidden_size,
                 output_size,
                 activation=None):
        super(Multi_linear_layer, self).__init__()
        self.linears = nn.ModuleList()
        self.linears.append(nn.Linear(input_size, hidden_size))
        for _ in range(1, n_layers - 1):
            self.linears.append(nn.Linear(hidden_size, hidden_size))
        self.linears.append(nn.Linear(hidden_size, output_size))

        # FIX: the default `activation=None` used to crash in getattr(F, None).
        if activation is None:
            self.activation = None
        elif isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

    def forward(self, x):
        """Apply all layers; the last one is left without activation."""
        for linear in self.linears[:-1]:
            x = linear(x)
            if self.activation is not None:
                x = self.activation(x)
        return self.linears[-1](x)

In [None]:
class RiidClassifier(nn.Module):
    """RiidModel encoder followed by a per-position two-class softmax head."""

    def __init__(self, params):
        """Build the head and the encoder from the attributes of `params`."""
        super(RiidClassifier, self).__init__()

        # The head is created before the encoder (same construction order as before).
        self.logits = nn.Linear(params.hidden_size, 2)

        encoder_settings = dict(
            num_hidden_layers=params.num_hidden_layers,
            max_position_embeddings=params.max_position_embeddings,
            num_attention_heads=params.num_attention_heads,
            intermediate_size=params.intermediate_size,
            question_size=params.num_questions,
            lecture_size=params.num_lectures,
            task_size=params.max_task_size,
            ans_size=params.ans_size,
            part_size=params.num_total_q_part,
            tag_size=params.num_total_q_tags,
            lag_size=params.lag_size,
            hidden_size=params.hidden_size,
            hidden_dropout_prob=params.dropout,
            output_attentions=False,
            gradient_checkpointing=False,
        )
        self.enc_model = RiidModel(BertConfig(**encoder_settings))

    def forward(self, x_enc):
        """Run the encoder on the keyword inputs in `x_enc` and classify every position.

        Returns:
            Softmax over the last dimension of the head output, i.e. class
            probabilities (not logits) with a final dimension of size 2.
        """
        # The encoder returns (hidden states, pooled output); only the hidden states are used.
        hidden_states, _pooled = self.enc_model(**x_enc, return_dict=False)
        return F.softmax(self.logits(hidden_states), dim=-1)
        

In [None]:
# Stand-alone instance, used below only to count parameters; Trainer builds its own RiidClassifier.
model = RiidClassifier(params)

In [None]:
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Number of trainable parameters of the classifier built above.
count_parameters(model)

In [None]:
class Trainer(object):
    """Owns the RiidClassifier, its optimizer and the train / validate / infer loops.

    The trainer relies on the notebook-level `device` and on a dataset object
    providing `set_batch`, `val_batch`, `test_batch`, `finetune_batch` and
    `late_update` (see LectureData).
    """

    def __init__(self, dataset, params):
        self.params = params
        self.dataset = dataset
        self.model = RiidClassifier(params)

        if params.cuda:
            print('Moving model to gpus ...')
            self.model.to(device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=params.learning_rate)
        self.criteria = nn.CrossEntropyLoss(reduction='none').to(device)

    def get_dict(self, batch_data, is_train=True):
        """Split one collated batch into model keyword inputs and the trailing extras.

        The first nine entries of `batch_data` are mapped, in order, to the
        field names below. With `is_train` the last entry is returned as well;
        otherwise the second-to-last and third-to-last entries are returned
        (position and history mask in the tuples produced by LectureData).

        Returns:
            (enc_data, last) if `is_train`, else (enc_data, second_to_last, third_to_last).
        """
        field_names = ['input_ids', 'attention_mask', 'token_type_ids',
                       'task_ids', 'curr_ans', 'part_ids', 'tag_ids',
                       'prior_time_ids', 'lag_time_ids']
        enc_data = {name: batch_data[i] for i, name in enumerate(field_names)}

        if is_train:
            return enc_data, batch_data[-1]

        history_mask = batch_data[-3]
        placeholder = batch_data[-2]
        return enc_data, placeholder, history_mask

    def _num_workers(self):
        """Loader worker count from params (4 when params does not define it)."""
        return getattr(self.params, 'num_workers', 4)

    def get_dataloader(self, batch_size, shuffle=True):
        """Build a DataLoader over the dataset's current batch of users."""
        # FIX: honour params.num_workers instead of a hard-coded 4.
        return DataLoader(self.dataset,
                          batch_size=batch_size,
                          drop_last=False,
                          num_workers=self._num_workers(),
                          shuffle=shuffle)

    def save_model(self, save_path):
        """Write the model weights to `<save_path>/model_latest.pth`."""
        torch.save(self.model.state_dict(), f'{save_path}/model_latest.pth')
        print('Finished saving model!')

    def save_data(self, save_path):
        """Write the id lists (HDF5) and the per-user history (pickle) under `save_path`."""
        # Context manager: the HDF5 file is closed even if a dataset write fails.
        with h5py.File(f'{save_path}/data2idx.h5', 'w') as f:
            f.create_dataset('question2idx', data=list(self.dataset.question2idx.keys()))
            f.create_dataset('lecture2idx', data=list(self.dataset.lecture2idx.keys()))

        with open(f'{save_path}/user_dict.pickle', 'wb') as handle:
            pickle.dump(self.dataset.user_dict, handle, protocol=3)

        print('Finished saving data!')

    def load_model(self, model_path):
        """Load weights saved by `save_model` into the model."""
        # FIX: map the checkpoint to the notebook's `device` instead of a fixed 'cuda:0';
        # on a CPU-only machine `device` is the CPU, which covers the former fallback branch.
        checkpoint = torch.load(model_path, map_location=device)
        self.model.load_state_dict(checkpoint)

    def infer(self, databatch, is_valid=False):
        """Predict P(correct) for every question row of `databatch`, in row order.

        Returns:
            An empty list when the batch has no question rows, otherwise a
            float array with one probability per question row.
        """
        self.model.eval()

        num_test = len(databatch[databatch['content_type_id'] == 0])
        if num_test == 0:
            return []
        if not is_valid:
            test_size = self.dataset.test_batch(databatch)
        else:
            test_size = self.dataset.val_batch(databatch)

        batch_size = min(test_size, self.params.batch_size)
        test_loader = DataLoader(self.dataset,
                                 batch_size=batch_size,
                                 drop_last=False,
                                 num_workers=self._num_workers(),
                                 shuffle=False)

        # -1 marks rows that have not received a prediction yet.
        all_outputs = np.ones(num_test) * -1.
        for test_batch in test_loader:
            input_batch = tuple(feature.to(device) if self.params.cuda else feature
                                for feature in test_batch)

            enc_batch, position_batch, history_mask = self.get_dict(input_batch, is_train=False)

            # FIX: no autograd graph is needed for inference.
            with torch.no_grad():
                output_batch = self.model(enc_batch)[:, :, 1]

            output_batch = output_batch.detach().cpu().numpy().reshape(-1, 1)
            position_batch = position_batch.cpu().numpy().reshape(-1, 1)
            history_mask = history_mask.cpu().numpy().reshape(-1, 1)

            # Keep only the slots of the current rows and their row positions (-1 is padding).
            output_batch = output_batch[history_mask == 1]
            position_batch = position_batch[position_batch != -1]
            assert len(output_batch) == len(position_batch)

            all_outputs[position_batch] = output_batch

        assert np.sum((all_outputs == -1).astype(int)) == 0

        return all_outputs

    def incre_update(self, stat_batch, train_batch, val_batch):
        """Train on `train_batch`, then report the AUC on `val_batch` (`stat_batch` is not used)."""
        self.train(train_batch)

        val_outputs = self.infer(val_batch, is_valid=True)
        val_target = val_batch[val_batch.content_type_id == 0]['answered_correctly'].values

        assert len(val_outputs) == len(val_target)

        auc = metrics.roc_auc_score(val_target, val_outputs)
        print(f'>>> [VALID] AUC: {auc}')

    def finetune_batch(self):
        """Run one training pass over the dataset's labelled buffer, if it is ready."""
        is_finetune = self.dataset.finetune_batch(is_finetune=True)
        if is_finetune:
            batch_size = min(self.params.batch_size, len(self.dataset))

            self.model.train()
            train_loader = self.get_dataloader(batch_size)
            self.train_step(train_loader, print_step=100, msg='Finetune')

    def _masked_metrics(self, probs, target, mask):
        """Accuracy and AUC over the slots where `mask == 1`; (None, None) if undefined."""
        probs = probs[mask == 1]
        target = target[mask == 1]
        try:
            acc = metrics.accuracy_score(target, (probs >= 0.5).astype(int))
            auc = metrics.roc_auc_score(target, probs)
        except ValueError:
            # e.g. only one class present in this batch: AUC is undefined.
            return None, None, target.shape[0]
        return acc, auc, target.shape[0]

    def train_step(self, train_loader, print_step=200, msg='Train'):
        """One optimisation pass over `train_loader`, printing loss/AUC/accuracy every `print_step` batches."""
        self.model.train()
        train_auc = AverageMeter()
        for i, databatch in enumerate(train_loader):

            input_batch = tuple(feature.to(device) if self.params.cuda else feature
                                for feature in databatch)

            enc_batch, b_target = self.get_dict(input_batch)

            output_batch = self.model(enc_batch)
            attention_mask = enc_batch['attention_mask']

            # FIX: the model returns softmax probabilities, while CrossEntropyLoss applies
            # log-softmax itself. Feeding log-probabilities makes the criterion compute the
            # plain negative log-likelihood instead of a softmax of a softmax.
            log_probs = torch.log(output_batch.clamp_min(1e-12))
            token_loss = self.criteria(log_probs.reshape(-1, 2), b_target.reshape(-1))
            # Average over attended slots only; clamp guards an all-padding batch.
            loss = torch.sum(token_loss * attention_mask.reshape(-1)) / torch.sum(attention_mask).clamp(min=1)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if not self.params.is_test:
                probs = output_batch.detach().cpu().numpy()[:, :, 1].reshape(-1)
                target = b_target.cpu().numpy().reshape(-1)
                mask = attention_mask.cpu().numpy().reshape(-1)

                acc, auc, n_items = self._masked_metrics(probs, target, mask)
                if auc is not None:
                    train_auc.update(auc, n_items)

                if print_step > 0 and i % print_step == 0:
                    print(f'+++ [{msg}] Loss: {loss.item()} AUC: {auc} ACC: {acc}')

        if not self.params.is_test:
            print(f'>>> [{msg}] TOTAL AUC: {train_auc.avg}')

    def val_step(self, val_loader, print_step=300, msg='Valid'):
        """Evaluate AUC/accuracy over `val_loader` without updating the model."""
        self.model.eval()
        val_auc = AverageMeter()
        for i, databatch in enumerate(val_loader):

            input_batch = tuple(feature.to(device) if self.params.cuda else feature
                                for feature in databatch)

            enc_batch, b_target = self.get_dict(input_batch)

            # FIX: evaluation does not need gradients.
            with torch.no_grad():
                output_batch = self.model(enc_batch)

            probs = output_batch.detach().cpu().numpy()[:, :, 1].reshape(-1)
            target = b_target.cpu().numpy().reshape(-1)
            mask = enc_batch['attention_mask'].cpu().numpy().reshape(-1)

            acc, auc, n_items = self._masked_metrics(probs, target, mask)
            if auc is not None:
                val_auc.update(auc, n_items)

            if print_step > 0 and i % print_step == 0:
                print(f'- [{msg}] AUC: {auc} ACC: {acc}')

        print(f'>>> [{msg}] TOTAL AUC: {val_auc.avg}')

    def val(self, val_batch):
        """Evaluate on `val_batch` through the dataset's validation mode."""
        # FIX: val_batch returns a single count (it used to be unpacked into two names).
        data_size = self.dataset.val_batch(val_batch)
        print(f'[Valid] num_users: {data_size}')

        val_loader = self.get_dataloader(self.params.batch_size, shuffle=False)
        self.val_step(val_loader, print_step=100)

        if self.params.is_update_valid:
            # dataset.val_batch has already merged these sequences into the stored
            # history; merging them again would duplicate the rows.
            print('--> push valid to storage')

    def train(self, data_batch):
        """Train for `params.n_epoch` passes on `data_batch`, then optionally store it as history."""
        train_size, train_update = self.dataset.set_batch(data_batch)
        print(f'[Train] Num-cases: {train_size}')

        train_loader = self.get_dataloader(self.params.batch_size)
        for epoch in range(self.params.n_epoch):
            print(f'Epoch: {epoch}')
            self.train_step(train_loader, print_step=10)

        if self.params.is_update_train:
            print('--> push train to storage')
            self.dataset.late_update(train_update)
            

# Training all batches

In [None]:
def gen_mini_chunks(data_chunk, csize=int(1e4)):
    """Yield ``data_chunk`` in timestamp order, ``csize`` rows at a time.

    Args:
        data_chunk: DataFrame with a ``timestamp`` column.
        csize: maximum number of rows per yielded slice (the last slice may
            be shorter).

    Yields:
        Consecutive positional slices of the timestamp-sorted frame.
    """
    ordered = data_chunk.sort_values(by='timestamp')
    offsets = range(0, int(ordered.shape[0]), csize)
    yield from (ordered.iloc[lo:lo + csize] for lo in offsets)

In [None]:
def gen_test(test_df):
    """Split ``test_df`` into consecutive groups and yield them one by one.

    A row whose ``prior_group_answers_correct`` value is non-null marks the
    first row of a group; the group runs up to (but not including) the next
    such row, and the last group runs to the end of the frame. Rows located
    before the first marker row belong to no group and are not yielded.

    Boundaries are computed as *positions*, so the result does not depend on
    the frame's index labels, and a frame with a single marker row (or none
    at all) is handled instead of failing on an unbound loop variable.

    Args:
        test_df: DataFrame with a ``prior_group_answers_correct`` column.

    Yields:
        One DataFrame slice per group, in original row order. Nothing is
        yielded when the frame contains no marker row.
    """
    is_group_start = test_df['prior_group_answers_correct'].notnull().values
    starts = np.flatnonzero(is_group_start).tolist()
    if not starts:
        return

    # Every group ends where the next one begins; the last ends with the frame.
    ends = starts[1:] + [len(test_df)]
    for lo, hi in zip(starts, ends):
        yield test_df.iloc[lo:hi]

In [None]:
# Training driver: build the dataset/trainer, then for every chunk update the
# trainer incrementally and replay the sample test set to report AUC.
torch.cuda.empty_cache()
params.load_state = False
params.n_epoch = 1
params.n_chunks = 1000

if params.load_state:
    print('LOADING ALL DATASET ...')
else:
    print('STARTING TRAINING FROM SCRATCH ...')

# The dataset and trainer are constructed the same way in both modes; only
# the optional checkpoint restore below differs.
mydata = LectureData(params, questions.copy(), lectures.copy())
trainer = Trainer(mydata, params)

if params.load_state:
    print('LOADING PRETRAINED MODEL ...')
    trainer.load_model(f'{params.save_dir}/model_latest.pth')

start = time.time()
for n, train_part in enumerate(chunks):
    print(f'\n***Training chunk-{n}:')

    # Two successive split_data calls carve the chunk into stat / train /
    # valid parts. NOTE(review): presumably n_tail is the number of trailing
    # rows per user that go to the second part -- confirm against split_data.
    # An earlier note in this cell marked n_tail=32 as a good value for the
    # second split.
    rest_part, valid_part = split_data(train_part, n_tail=6)
    stat_part, train_part = split_data(rest_part, n_tail=100)

    valid_part = valid_part.sort_values(by='timestamp')
    train_part = train_part.sort_values(by='timestamp')

    print(f'\nStat: {stat_part.shape} / Train: {train_part.shape} / Valid: {valid_part.shape}')
    trainer.incre_update(stat_part, train_part, valid_part)

    # Replay the sample test set group by group. The first row of each group
    # carries the labels of the *previous* group, so the predictions made in
    # one iteration are scored at the start of the next one.
    test_generator = gen_test(test_sample)
    pred = None
    for test_part in test_generator:
        if pred is not None:
            try:
                # NOTE(review): eval() executes arbitrary code. It is kept
                # because test_sample is a local file; prefer
                # ast.literal_eval if this column can come from an untrusted
                # source.
                prior_correct = eval(test_part['prior_group_answers_correct'].iloc[0])
                # -1 entries are dropped before scoring (presumably
                # non-question rows -- confirm).
                prior_correct = [a for a in prior_correct if a != -1]
                if len(prior_correct) != len(pred):
                    raise ValueError(
                        f'{len(prior_correct)} labels vs {len(pred)} predictions')
                auc = metrics.roc_auc_score(prior_correct, pred)
                print(f'>>> [TESTING] AUC: {auc}')
            except Exception as err:
                # Scoring stays best-effort (e.g. a group with one class has
                # no AUC), but the reason is reported and KeyboardInterrupt
                # is no longer swallowed.
                print(f'>>> [TESTING] AUC skipped: {err!r}')

        s1 = time.time()
        pred = trainer.infer(test_part)
        print(f'Testing batch: {time.time() - s1}')

    # NOTE(review): the check runs after chunk n was processed, so up to
    # params.n_chunks + 1 chunks are trained -- confirm this is intended.
    if n >= params.n_chunks:
        break

    print(f'Time elapsed: {time.time() - start}')

print('Saving all model + data ...')
trainer.save_model(params.save_dir)
trainer.save_data(params.save_dir)

print(f'Training finished in {time.time() - start} seconds')
    

In [None]:
# Set up the competition environment and the one-shot test iterator used by
# the submission loop.
# NOTE(review): `riiideducation` is imported here rather than with the other
# imports at the top -- presumably because the module only exists inside the
# Kaggle competition runtime; confirm before moving it.
import riiideducation
# You can only call make_env() once, so don't lose it!
env = riiideducation.make_env()

# You can only iterate through a result from `env.iter_test()` once
# so be careful not to lose it once you start iterating.
iter_test = env.iter_test()

In [None]:
################################
# Submission
################################
# For every test group: predict, write the predictions into the question rows
# (content_type_id == 0), submit those rows, then let the trainer fine-tune.

print('Start testing ....')
for test_df, sample_prediction_df in iter_test:
    pred = trainer.infer(test_df)

    # Default every row to 0.5, then overwrite the question rows only.
    test_df['answered_correctly'] = 0.5
    question_rows = test_df['content_type_id'] == 0
    test_df.loc[question_rows, 'answered_correctly'] = pred

    # Only question rows are submitted.
    submission = test_df.loc[question_rows, ['row_id', 'answered_correctly']]
    env.predict(submission)

    trainer.finetune_batch()