In [1]:
import matplotlib
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
import os
import sys
import random
import pdb
import itertools
import numpy as np
import pandas as pd
import datatable as dt
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn import metrics

import pickle
import h5py
import joblib

# import seaborn as sns
# import lightgbm as lgb
# from lightgbm import LGBMClassifier


In [2]:
# !pip install latest_bert/dist/bert_src-1.0-py3-none-any.whl

In [3]:
from latest_bert.src.bert_src.modeling_bert import RiidModel
from latest_bert.src.bert_src.configuration_bert import BertConfig

In [4]:
if torch.cuda.is_available(): 
    device = torch.device("cuda:7")
else:
    device = torch.device("cpu")
print(device)

cuda:7


In [5]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [6]:
class Params:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        
    def update(self, **kargs):
        self.__dict__.update(kargs)

In [7]:
path = './kaggle/input/riiid-test-answer-prediction'
train_file = f'{path}/train.csv'
train_dtypes = {'row_id': 'int64',
              'timestamp': 'int64',
              'user_id': 'int32',
              'content_id': 'int16',
              'content_type_id': 'int8',
              'task_container_id': 'int16',
              'user_answer': 'int8',
              'answered_correctly': 'int8',
              'prior_question_elapsed_time': 'float32', 
              'prior_question_had_explanation': 'boolean',
             }
test_file = f'{path}/example_test.csv'
test_sample = pd.read_csv(test_file)
questions = pd.read_csv(f'{path}/questions.csv')
lectures = pd.read_csv(f'{path}/lectures.csv')
print('Question shapes:', questions.shape)
print('Lecture shapes:', lectures.shape)

Question shapes: (13523, 5)
Lecture shapes: (418, 4)


In [8]:
colnames = ['row_id', 'timestamp', 'user_id', 'content_id', 'content_type_id', 'task_container_id', 'user_answer', 'answered_correctly', 'prior_question_elapsed_time', 'prior_question_had_explanation']
# chunks = pd.read_csv(train_file, chunksize=1e3, dtype=train_dtypes, header=None, names=colnames, index_col=False)
chunks = pd.read_csv(train_file, chunksize=1e5, dtype=train_dtypes)

In [9]:
question_tags = list(map(lambda x: map(lambda v: int(v) + 1, str(x).split()) if str(x).strip() != 'nan' else [0], questions.tags.values))
question_tags = list(set(itertools.chain(*question_tags)))

n_tags = len(question_tags)
n_parts = len(set(questions.part.unique()))

print(f'n_tags {n_tags}, n_parts {n_parts}')

n_tags 189, n_parts 7


In [10]:
params_dict = {
    'load_state': False,
    'use_buffer': True,
    'is_offline': False,
    'batch_norm': False,
    'is_test': False,
    'n_chunks': 300,
    'n_epoch': 10,
    'learning_rate': 1e-4,
    'batch_size': 64,
    'num_workers': 4,
    'cuda': torch.cuda.is_available(),
    'num_questions': questions.question_id.nunique(),
    'num_lectures': lectures.lecture_id.nunique(),
    'num_total_q_tags': n_tags,
    'num_total_q_part': n_parts,
    'n_layers': 2,
    'dropout': 0.1,
    'hidden_size': 256,
    'max_position_embeddings': 64,
    'num_hidden_layers': 4,
    'num_attention_heads': 4,
    'intermediate_size': 512,
    'max_task_size': 10002,
    'ans_size': 4,
    'buffer_size_limit': 1e4,
    'max_seq_length': 32,
    'extra_dir':'/kaggle/input/bert-src',
    'save_dir': './save/bert_save'
}
params = Params(**params_dict)
print(params.__dict__)

{'load_state': False, 'use_buffer': True, 'is_offline': False, 'batch_norm': False, 'is_test': False, 'n_chunks': 300, 'n_epoch': 10, 'learning_rate': 0.0001, 'batch_size': 64, 'num_workers': 4, 'cuda': True, 'num_questions': 13523, 'num_lectures': 418, 'num_total_q_tags': 189, 'num_total_q_part': 7, 'n_layers': 2, 'dropout': 0.1, 'hidden_size': 256, 'max_position_embeddings': 64, 'num_hidden_layers': 4, 'num_attention_heads': 4, 'intermediate_size': 512, 'max_task_size': 10002, 'ans_size': 4, 'buffer_size_limit': 10000.0, 'max_seq_length': 32, 'extra_dir': '/kaggle/input/bert-src', 'save_dir': './save/bert_save'}


In [11]:
def split_data(train_part, n_tail=10):
    valid = train_part.groupby('user_id').tail(n_tail)
    train = train_part[~train_part.index.isin(valid.index)]
    return train, valid

In [12]:
class LectureData(Dataset):
    
    def __init__(self, params, question_df=None, lecture_df=None, is_train=True):
        # read init-data
        self.params = params
        self.is_train = is_train
        
        self.features = ['combined_id', 'content_type_id', 'task_container_id', 
                         'answered_correctly','prior_question_elapsed_time']
        
        self.train_columns = ['user_id', 'content_id', 'content_type_id', 'task_container_id',
                              'answered_correctly', 'prior_question_elapsed_time', 'prior_question_had_explanation']
        
        self.prior_batch, self.current_batch, self.buffer_df, self.batch_user = None, None, None, None
        self.user_dict, self.question2idx, self.lecture2idx = None, None, None
        
        # Restore all data
        if params.load_state:
            self.load_state()
        else:
            self.init_info(question_df, lecture_df)

    def init_info(self, question_df, lecture_df):
        
        self.question_list = list(question_df['question_id'].unique())
        self.lecture_list = list(lecture_df['lecture_id'].unique())

        self.n_questions = len(self.question_list)
        self.n_lectures = len(self.lecture_list)
        
        self.question2idx = dict(zip(self.question_list, range(1, self.n_questions + 1))) # 0 for padding
        self.lecture2idx = dict(zip(self.lecture_list, range(self.n_questions + 1, self.n_questions + 1 + self.n_lectures)))
        
    def load_state(self):  
        
        f = h5py.File(os.path.join(self.params.extra_dir, 'data2idx.h5'), 'r')
        
        self.question2idx = f['question2idx'][:]
        self.n_questions = len(self.question2idx)
        self.question2idx = dict(zip(self.question2idx, range(1, self.n_questions + 1)))
        
        self.lecture2idx = f['lecture2idx'][:]
        self.n_lectures = len(self.lecture2idx)
        self.lecture2idx = dict(zip(self.lecture2idx, range(self.n_questions + 1, self.n_questions + 1 + self.n_lectures)))
        f.close()
        
        with open(os.path.join(self.params.extra_dir, 'user_dict.pickle'), 'rb') as handle:
            self.user_dict = pickle.load(handle)


    def proc_traindata(self, train_df):
        
#         train_df['prior_question_elapsed_time'].fillna(0., inplace=True)
#         train_df['prior_question_had_explanation'].fillna(False, inplace=True)
#         train_df['prior_question_had_explanation'] = train_df['prior_question_had_explanation'].astype(int)        
#         train_df['prior_question_elapsed_time'] = train_df['prior_question_elapsed_time'].map(lambda x: np.log(x + 1.))

        train_df['combined_id'] = train_df.apply(lambda x: self.question2idx[x['content_id']], axis=1) 
#                                                  if x['content_type_id']==0 
#                                                  else self.lecture2idx[x['content_id']], axis=1)
#         train_df['task_container_id'].fillna(-1, inplace=True)
        train_df['task_container_id'] = train_df['task_container_id'].map(lambda x: x + 1)

        return train_df
        
            
    def test_batch(self, batch_df):
        
        self.is_train = False
        
        # fetch prior labels
        gt_prior_batch = eval(batch_df.iloc[0]["prior_group_answers_correct"])
        
        # HERE stop updating first for 1st submission
        if self.current_batch is not None and len(gt_prior_batch) > 0:
            # save prior-batch with labels
            self.prior_batch = self.current_batch
            
            # Assign label to prev-batch
            self.prior_batch['answered_correctly'] = gt_prior_batch
            self.prior_batch = self.prior_batch[self.train_columns]
            
            # add to buffer-df
            if self.params.use_buffer:
                self.buffer_df = pd.concat([self.buffer_df, self.prior_batch], axis=0, ignore_index=True)
             
        else:
            self.prior_batch = batch_df
        
        self.current_batch = batch_df
        # create dummy labels "2", later will be replaced
        self.current_batch['answered_correctly'] = 2
        
        test_user_dict = self.proc_batch(self.current_batch)
        self.batch_user = self.merge_train_valid(self.user_dict, test_user_dict)
        
        return len(self.batch_user)        
            
    def finetune_batch(self, is_finetune=False):
        
        self.is_train = True
        
        if self.buffer_df is not None and len(self.buffer_df) > params.buffer_size_limit:

            print('--> Dataset activated finetune buffer')
            buff_user_dict = self.proc_batch(self.buffer_df)
            
            if is_finetune:
                # create data for continue training
                self.batch_user = self.merge_train_valid(self.user_dict, buff_user_dict)
            
            # update user_dict later on
            buff_update = self.update_newdata(buff_user_dict)
            self.user_dict.update(buff_update)
            
            self.buffer_df = None
            
            # check user_dict size
            cur_size_in_gb = sys.getsizeof(self.user_dict) // (1024**3)
            if cur_size_in_gb > 12:
                print('--> Random removing user-entries')
                for i in range(int(len(self.user_dict) * 0.01)):
                    self.user_dict.pop(random.choice(self.user_dict.keys()))
                   
            return is_finetune
        
        return False
    
    def proc_batch(self, batch_df):
        
        batch_df = batch_df[batch_df.content_type_id==0]
        batch_df = self.proc_traindata(batch_df)
#         batch_df['prev_answered_correctly'] = batch_df['answered_correctly'].shift().fillna(2)

        if not self.is_train:
            batch_df['position'] = range(batch_df.shape[0])
            user_group = batch_df.groupby('user_id').agg({
                'combined_id': lambda x: list(x),
                'content_type_id': lambda x: list(x),
                'task_container_id': lambda x: list(x),
                'answered_correctly': lambda x: list(x),
#                 'prior_question_elapsed_time': lambda x: list(x),
                'position': lambda x: list(x) # only for testing
            })
            
        else:
            user_group = batch_df.groupby('user_id').agg({
                'combined_id': lambda x: list(x),
                'content_type_id': lambda x: list(x),
                'task_container_id': lambda x: list(x),
                'answered_correctly': lambda x: list(x),
#                 'prior_question_elapsed_time': lambda x: list(x),
            })
            
        batch_user_dict = user_group.to_dict('index')
        return batch_user_dict
    
    def set_batch(self, train_df, valid_df, is_update_valid=True):
        
        self.is_train = True
        
        train_user_dict = self.proc_batch(train_df)
        valid_user_dict = self.proc_batch(valid_df)
        # set data for training
        train_user_dict = self.update_newdata(train_user_dict)
        
        # update storage data
        self.user_dict.update(train_user_dict)
        if is_update_valid:
            valid_update = self.update_newdata(valid_user_dict)
            self.user_dict.update(valid_update)  
        
        self.batch_user = self.merge_train_valid(train_user_dict, valid_user_dict)
        
        return len(self.batch_user)
    
    def dummy_entry(self, user_id):
        entry = {k:[] for k in self.features}
        return entry
    
    def merge_train_valid(self, train_dict, valid_dict):
        
        merge_dict = dict()
        for uid, udata in valid_dict.items():
            if uid in train_dict:
                merge_dict[uid] = [train_dict[uid], udata]
            else:
                merge_dict[uid] = [self.dummy_entry(uid), udata]
        
        return list(merge_dict.items())
    
    def _agg_dict(self, org_dict, new_dict):

        for k, v in new_dict.items():
            new_value = (org_dict[k] + v)[-self.params.max_seq_length:]
            new_dict[k] = new_value
            
        return new_dict
        
    def update_newdata(self, new_user_dict):
        
        if self.user_dict is None:
            self.user_dict = new_user_dict
            return new_user_dict
        
        common_users = list(set(new_user_dict.keys()) & set(self.user_dict.keys()))        
        for uid in common_users:
            new_user_dict[uid] = self._agg_dict(self.user_dict[uid], new_user_dict[uid])
        
        return new_user_dict

    def padding(self, feature, padding_value=0, skip_first=False, first_value=0, max_len=64):
        
        feature = feature[-(max_len-1):]
        if not skip_first:
            return [first_value] + feature + [padding_value] * (max_len - 1 - len(feature)) 
        else:
            return feature + [padding_value] * (max_len - len(feature)) 
    
    def __len__(self):
        return len(self.batch_user) if self.batch_user is not None else 0
    
    def __getitem__(self, index):
        
        ins = self.batch_user[index][1]        
        target = np.array(self.padding(ins[1]['answered_correctly'], skip_first=True, max_len=self.params.max_seq_length)).astype(int)

        curr_ans = np.array([self.padding(ins[0]['answered_correctly'], first_value=3, max_len=self.params.max_seq_length), 
                             np.insert(target[:-1], 0, 2)]).astype(int)
        
        enc_att_len = len(ins[0]['answered_correctly'][-(self.params.max_seq_length-1):]) + 1
        enc_attention_mask = [1] * enc_att_len + [0] * (self.params.max_seq_length - enc_att_len)
        
        dec_att_len = len(ins[1]['answered_correctly'][-self.params.max_seq_length:])
        dec_attention_mask = [1] * dec_att_len + [0] * (self.params.max_seq_length - dec_att_len)
        
        attention_mask = np.array([enc_attention_mask, dec_attention_mask]).astype(int)
        
        content_type_id = np.array([self.padding(ins[0]['content_type_id'], max_len=self.params.max_seq_length), 
                                    self.padding(ins[1]['content_type_id'], skip_first=True, max_len=self.params.max_seq_length)]).astype(int)
        
        combined_id = np.array([self.padding(ins[0]['combined_id'], first_value=13940, max_len=self.params.max_seq_length), 
                                self.padding(ins[1]['combined_id'], skip_first=True, max_len=self.params.max_seq_length)]).astype(int)
        
        task_container_id = np.array([self.padding(ins[0]['task_container_id'], first_value=10001, max_len=self.params.max_seq_length), 
                                      self.padding(ins[1]['task_container_id'], skip_first=True, max_len=self.params.max_seq_length)]).astype(int)
        
        if self.is_train:
            return (torch.LongTensor(combined_id),
                    torch.LongTensor(attention_mask),
                    torch.LongTensor(content_type_id), 
                    torch.LongTensor(task_container_id),
                    torch.LongTensor(curr_ans),
                    torch.FloatTensor(target)) 
        else:
            position = np.array(self.padding(ins[1]['position'], skip_first=True, max_len=self.params.max_seq_length, padding_value=-1)).astype(int)
            
            return (torch.LongTensor(combined_id),
                    torch.LongTensor(attention_mask),
                    torch.LongTensor(content_type_id), 
                    torch.LongTensor(task_container_id),
                    torch.LongTensor(curr_ans),
                    torch.LongTensor(position),
                    torch.FloatTensor(target)) 

In [13]:
class Multi_linear_layer(nn.Module):
    def __init__(self,
                 n_layers,
                 input_size,
                 hidden_size,
                 output_size,
                 activation=None):
        super(Multi_linear_layer, self).__init__()
        self.linears = nn.ModuleList()
        self.linears.append(nn.Linear(input_size, hidden_size))
        for _ in range(1, n_layers - 1):
            self.linears.append(nn.Linear(hidden_size, hidden_size))
        self.linears.append(nn.Linear(hidden_size, output_size))
        self.activation = getattr(F, activation)

    def forward(self, x):
        for linear in self.linears[:-1]:
            x = self.activation(linear(x))
        linear = self.linears[-1]
        x = linear(x)
        return x

In [14]:
class RiidClassifier(nn.Module):
    def __init__(self, params):
        super(RiidClassifier, self).__init__()
        
        n_layers = params.n_layers
        hidden_size = params.hidden_size
        
        self.logits = Multi_linear_layer(n_layers, hidden_size, hidden_size,
                                           1, activation='relu')
        
        enc_bert_config = BertConfig(num_hidden_layers=params.num_hidden_layers,
                                      max_position_embeddings=params.max_position_embeddings,
                                      num_attention_heads=params.num_attention_heads,
                                      intermediate_size=params.intermediate_size,
                                      question_size=params.num_questions,
                                      lecture_size=params.num_lectures,
                                      task_size=params.max_task_size,
                                      ans_size=params.ans_size,
                                      hidden_size=params.hidden_size,
                                      output_attentions=False)
        
        dec_bert_config = BertConfig(num_hidden_layers=params.num_hidden_layers,
                                      max_position_embeddings=params.max_position_embeddings,
                                      num_attention_heads=params.num_attention_heads,
                                      intermediate_size=params.intermediate_size,
                                      question_size=params.num_questions,
                                      lecture_size=params.num_lectures,
                                      task_size=params.max_task_size,
                                      ans_size=params.ans_size,
                                      hidden_size=params.hidden_size,
                                      output_attentions=False,
                                      is_decoder=True,
                                      add_cross_attention=True)
        
        self.enc_model = RiidModel(enc_bert_config)
        self.dec_model = RiidModel(dec_bert_config)
        
    def forward(self, x_enc, x_dec):
        
        enc_hidden, pooled_output = self.enc_model(**x_enc, return_dict=False) # [b, len, hsize], [b, hsize]
#         pdb.set_trace()
        x_dec.update({'encoder_hidden_states': enc_hidden,
                      'encoder_attention_mask': x_enc['attention_mask']
        })
        dec_hidden, _ = self.dec_model(**x_dec, return_dict=False)
        
        cls_logits = self.logits(dec_hidden)
        cls_outputs = torch.sigmoid(cls_logits)
        
        return cls_outputs # [b, len, 2]
        

In [15]:
class Trainer(object):
    
    def __init__(self, dataset, params):
        
        self.params = params
        self.dataset = dataset
        self.model = RiidClassifier(params)
        
        if params.cuda:
            print('Moving model to gpus ...')
            self.model.to(device)
            
        self.optimizer = optim.Adam(self.model.parameters(), lr=params.learning_rate)
        self.criteria = nn.BCELoss(reduction='none')
                
    def get_dict(self, batch_data, is_train=True):
        
        field_names = ['input_ids', 'attention_mask', 'token_type_ids', 'task_ids', 'curr_ans']
        
        enc_data, dec_data = {}, {}
        for i, n in enumerate(field_names):
            enc_entry, dec_entry = batch_data[i].split(1, dim=1)
            enc_data.update({n: enc_entry.squeeze(1)})
            dec_data.update({n: dec_entry.squeeze(1)})
            
        if is_train:
#             target = batch_data[-1].split(1, dim=1)[1].squeeze(1) # fetch valid side only
            placeholder = batch_data[-1]
        else:
            placeholder = batch_data[-2]

        return enc_data, dec_data, placeholder
        
    def get_dataloader(self, batch_size, shuffle=True):
        
        data_loader = DataLoader(self.dataset, 
                                 batch_size=batch_size, 
                                 drop_last=False,
                                 num_workers=4,
                                 shuffle=shuffle)
        return data_loader

    def save_model(self, save_path):
        # save model as .pt or .pth file
        torch.save(self.model.state_dict(), f'{save_path}/model_latest.pth')
        print('Finished saving model!')
        
    def save_data(self, save_path):
        
        f = h5py.File(f'{save_path}/data2idx.h5', 'w')
        f.create_dataset('question2idx', data=list(self.dataset.question2idx.keys()))
        f.create_dataset('lecture2idx', data=list(self.dataset.lecture2idx.keys()))
        f.close()
        
        with open(f'{save_path}/user_dict.pickle', 'wb') as handle:
            pickle.dump(self.dataset.user_dict, handle, protocol=3)
            
        print('Finished saving data!')

    def load_model(self, model_path):
        
        if torch.cuda.is_available():
            checkpoint = torch.load(model_path, map_location='cuda:0')
        else:
            # this helps avoid errors when loading single-GPU-trained weights onto CPU-model
            checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
            
        self.model.load_state_dict(checkpoint)
        
    def infer(self, databatch):
        
        self.model.eval()
        
        # All test-cases
        num_test = len(databatch[databatch['content_type_id'] == 0])
        if num_test == 0:
            return []

        test_size = self.dataset.test_batch(databatch)
#         print(f'[Infer] New cases: {test_size}')
        
        batch_size = min(test_size, self.params.batch_size)
        test_loader = DataLoader(self.dataset, 
                                 batch_size=batch_size, 
                                 drop_last=False,
                                 num_workers=4,
                                 shuffle=False)
        test_outputs = []
        for i, test_batch in enumerate(test_loader):
            input_batch = ()
            for feature in test_batch:
                input_batch += (feature.to(device) if self.params.cuda else feature, )
                
            enc_batch, dec_batch, position_batch = self.get_dict(input_batch, is_train=False)

            output_batch = self.model(enc_batch, dec_batch).squeeze(-1).detach().cpu().numpy().reshape(-1, 1)
            
            attention_mask = dec_batch['attention_mask'].cpu().numpy().reshape(-1, 1)
            
            output_batch = output_batch[attention_mask==1]
            position_batch = position_batch.cpu().numpy().reshape(-1, 1)
            position_batch = position_batch[attention_mask==1]
            
            tmp = sorted(zip(output_batch, position_batch), key = lambda t: t[1])
            sorted_output = list(zip(*tmp))[0]
            test_outputs.append(sorted_output)
        
        test_outputs = np.concatenate(test_outputs, axis=0)
#         assert num_test == len(test_outputs), f'{num_test} # {len(test_outputs)}'
        
        return test_outputs
    
    def incre_update(self, stat_batch, train_batch, val_batch):
        
        self.train(stat_batch, train_batch)
        self.val(train_batch, val_batch)
                    
    def finetune_batch(self):
        # update batch-data for finetuning
        is_finetune = self.dataset.finetune_batch(is_finetune=False)
        if is_finetune:
            batch_size = min(self.params.batch_size, len(self.dataset))
            
            # start finetune model
            self.model.train()
            train_loader = self.get_dataloader(batch_size)
            self.train_step(train_loader, print_step=100, msg='Finetune')

    def train_step(self, train_loader, print_step=200, msg='Train'):
        
        self.model.train()
        train_auc = AverageMeter()
        for i, databatch in enumerate(train_loader):
            
            # Move to device
            input_batch = ()
            for feature in databatch:
                input_batch += (feature.to(device) if self.params.cuda else feature, )
        
            # Convert to dict-type
            enc_batch, dec_batch, b_target = self.get_dict(input_batch)
            
            # FW model
            output_batch = self.model(enc_batch, dec_batch).squeeze(-1)    
            
            with torch.no_grad():
                attention_mask = dec_batch['attention_mask']

#             assert output_batch.shape == b_target.shape
            
            loss = torch.sum(self.criteria(output_batch, b_target) * attention_mask) / torch.sum(attention_mask)
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
                        
            if not self.params.is_test:
                if self.params.cuda:
                    b_target = b_target.cpu().numpy()
                    output_batch = output_batch.detach().cpu().numpy()
                else:
                    b_target = b_target.numpy()
                    output_batch = output_batch.detach().numpy()
                    
                output_batch, b_target = output_batch.reshape(-1, 1), b_target.reshape(-1, 1)
                attention_mask = attention_mask.cpu().numpy().reshape(-1, 1)
                output_batch = output_batch[attention_mask==1]
                b_target = b_target[attention_mask==1]
                try:
                    acc = metrics.accuracy_score(b_target, (output_batch >= 0.5).astype(int))
                    auc = metrics.roc_auc_score(b_target, output_batch)
                    train_auc.update(auc, b_target.shape[0])
                    
                except:
                    auc, acc = None, None
                    pass

                if print_step > 0 and i % print_step == 0:
                    print(f'+++ [{msg}] Loss: {loss.item()} AUC: {auc} ACC: {acc}')
                    
        if not self.params.is_test:
            print(f'>>> [{msg}] TOTAL AUC: {train_auc.avg}')
            
    def val_step(self, val_loader, print_step=300, msg='Valid'):
        
        self.model.eval()
        val_auc = AverageMeter()
        for i, databatch in enumerate(val_loader):
            
            # Move to device
            input_batch = ()
            for feature in databatch:
                input_batch += (feature.to(device) if self.params.cuda else feature, )
            
            # Convert to dict-type
            enc_batch, dec_batch, b_target = self.get_dict(input_batch)
            
            # FW model
            output_batch = self.model(enc_batch, dec_batch).squeeze(-1)
            
            with torch.no_grad():
                attention_mask = dec_batch['attention_mask']
                attention_mask = attention_mask.cpu().numpy().reshape(-1, 1)
            
            if self.params.cuda:
                b_target = b_target.cpu().numpy()
                output_batch = output_batch.detach().cpu().numpy()
            else:
                b_target = b_target.numpy()
                output_batch = output_batch.detach().numpy()
                
            output_batch, b_target = output_batch.reshape(-1, 1), b_target.reshape(-1, 1)
            output_batch = output_batch[attention_mask==1]
            b_target = b_target[attention_mask==1]
            
            try:
                acc = metrics.accuracy_score(b_target, (output_batch >= 0.5).astype(int))
                auc = metrics.roc_auc_score(b_target, output_batch)
                val_auc.update(auc, b_target.shape[0])
            except:
                auc, acc = None, None
                pass
            
            if print_step > 0 and i % print_step == 0:
                print(f'- [{msg}] AUC: {auc} ACC: {acc}')
        
        print(f'>>> [{msg}] TOTAL AUC: {val_auc.avg}')
    
    def val(self, stat_batch, val_batch):
        
        data_size = self.dataset.set_batch(stat_batch, val_batch)
        print(f'[Valid] num_users: {data_size}')
        
        val_loader = self.get_dataloader(self.params.batch_size, shuffle=False)
        self.val_step(val_loader, print_step=100)

    def train(self, stat_batch, data_batch):
        
        train_size = self.dataset.set_batch(stat_batch, data_batch)     
        print(f'[Train] Num-cases: {train_size}')
#         pdb.set_trace()
        
        train_loader = self.get_dataloader(self.params.batch_size)
        for epoch in range(self.params.n_epoch):
            print(f'Epoch: {epoch}')
            self.train_step(train_loader, print_step=100)    
            

# Training all batches

In [16]:
def gen_mini_chunks(data_chunk, csize=int(1e4)):
    data_chunk = data_chunk.sort_values(by ='timestamp')
    for i in range(0, int(data_chunk.shape[0]), csize):
        yield data_chunk.iloc[i:i+csize]

In [17]:
def gen_test(test_df):
    indice = list(test_df.loc[~test_df['prior_group_answers_correct'].isnull()].index) 
    for i in range(len(indice)-1):
        yield test_df.iloc[indice[i]:indice[i+1]]
        
    yield test_df.iloc[indice[i+1]:]

In [18]:
torch.cuda.empty_cache()
params.load_state = False
params.n_epoch = 2
params.n_chunks = 300
# save_part = None
if params.load_state:
    print(f'LOADING ALL DATASET ...')
    mydata = LectureData(params, questions.copy(), lectures.copy()) 
    trainer = Trainer(mydata, params)
        
    print('LOADING PRETRAINED MODEL ...')
    trainer.load_model(f'{params.extra_dir}/model_latest.pth')
    
else:
    print(f'STARTING TRAINING FROM SCRATCH ...')
    mydata = None
    trainer = None
    
    if trainer is None:
        mydata = LectureData(params, questions.copy(), lectures.copy())  
        trainer = Trainer(mydata, params)
        
    start = time.time()
    for n, train_part in enumerate(chunks):
        print(f'\n***Training chunk-{n}:')
        rest_part, valid_part = split_data(train_part, n_tail=6)
        stat_part, train_part = split_data(rest_part, n_tail=32) # 12: goood

        print(f'Stat: {stat_part.shape} / Train: {train_part.shape} / Valid: {valid_part.shape}')
        trainer.incre_update(stat_part, train_part, valid_part)

        if True:
            test_generator = gen_test(test_sample)
            pred = None
            for test_part in test_generator: 
                try:
                    prior_correct = eval(test_part['prior_group_answers_correct'].iloc[0])
                    prior_correct = [a for a in prior_correct if a != -1]
                    if pred is not None:
                        assert len(prior_correct) == len(pred)
                        auc = metrics.roc_auc_score(prior_correct, pred)
                        print(f'>>> [TESTING] AUC: {auc}')
                except:
                    prior_correct = []
                s1 = time.time()
                pred = trainer.infer(test_part)
#                 trainer.finetune_batch()
                print(f'Testing batch: {time.time() - s1}')

        if n >= params.n_chunks:
            break

        print(f'Batch-Time elapsed: {time.time() - start}')

    print(f'Saving all model + data ...')
    trainer.save_model(params.save_dir)
    trainer.save_data(params.save_dir)

    print(f'Training finished in {time.time() - start} seconds')
    

STARTING TRAINING FROM SCRATCH ...
Moving model to gpus ...

***Training chunk-0:
Stat: (88843, 10) / Train: (9071, 10) / Valid: (2086, 10)
[Train] Num-cases: 347
Epoch: 0
+++ [Train] Loss: 0.7020116448402405 AUC: 0.46160601073731344 ACC: 0.45126146788990823
>>> [Train] TOTAL AUC: 0.562517929056271
Epoch: 1
+++ [Train] Loss: 0.6498209238052368 AUC: 0.6760590141016964 ACC: 0.6371629542790153
>>> [Train] TOTAL AUC: 0.6755117609226325
[Valid] num_users: 349
- [Valid] AUC: 0.705767817932674 ACC: 0.6472081218274112
>>> [Valid] TOTAL AUC: 0.6920052786260207
Testing batch: 0.3643836975097656
>>> [TESTING] AUC: 0.4935064935064935
Testing batch: 0.3928215503692627
>>> [TESTING] AUC: 0.2556818181818182
Testing batch: 0.3653285503387451
>>> [TESTING] AUC: 0.7916666666666666
Testing batch: 0.37448883056640625
Batch-Time elapsed: 16.071417093276978

***Training chunk-1:
Stat: (88070, 10) / Train: (9708, 10) / Valid: (2222, 10)
[Train] Num-cases: 370
Epoch: 0
+++ [Train] Loss: 0.6511670351028442 AUC

>>> [Train] TOTAL AUC: 0.736604313497288
[Valid] num_users: 414
- [Valid] AUC: 0.762933614526959 ACC: 0.7027300303336703
>>> [Valid] TOTAL AUC: 0.7569078788446756
Testing batch: 0.37944746017456055
>>> [TESTING] AUC: 0.7532467532467533
Testing batch: 0.37693309783935547
>>> [TESTING] AUC: 0.6306818181818182
Testing batch: 0.3818378448486328
>>> [TESTING] AUC: 0.6527777777777777
Testing batch: 0.3742399215698242
Batch-Time elapsed: 155.0754017829895

***Training chunk-11:
Stat: (89428, 10) / Train: (8598, 10) / Valid: (1974, 10)
[Train] Num-cases: 329
Epoch: 0
+++ [Train] Loss: 0.6028352975845337 AUC: 0.7253943173555567 ACC: 0.6731334509112287
>>> [Train] TOTAL AUC: 0.7102826807739776
Epoch: 1
+++ [Train] Loss: 0.6073315143585205 AUC: 0.7104588356488571 ACC: 0.6619964973730298
>>> [Train] TOTAL AUC: 0.732847612758734
[Valid] num_users: 329
- [Valid] AUC: 0.7567186818252917 ACC: 0.6813576494427558
>>> [Valid] TOTAL AUC: 0.7543629193818787
Testing batch: 0.37093687057495117
>>> [TESTING] 

Testing batch: 0.38223838806152344
Batch-Time elapsed: 294.4828886985779

***Training chunk-21:
Stat: (86822, 10) / Train: (10693, 10) / Valid: (2485, 10)
[Train] Num-cases: 414
Epoch: 0
+++ [Train] Loss: 0.5758416056632996 AUC: 0.7532454182330827 ACC: 0.7045454545454546
>>> [Train] TOTAL AUC: 0.7186105017410466
Epoch: 1
+++ [Train] Loss: 0.5964434146881104 AUC: 0.7248324592074591 ACC: 0.6730196545562835
>>> [Train] TOTAL AUC: 0.7411033144129002
[Valid] num_users: 415
- [Valid] AUC: 0.7916196306515644 ACC: 0.7229696063776782
>>> [Valid] TOTAL AUC: 0.7688153859892708
Testing batch: 0.3811309337615967
>>> [TESTING] AUC: 0.7922077922077921
Testing batch: 0.3775825500488281
>>> [TESTING] AUC: 0.8863636363636364
Testing batch: 0.4499046802520752
>>> [TESTING] AUC: 0.6875
Testing batch: 0.39772915840148926
Batch-Time elapsed: 308.7396037578583

***Training chunk-22:
Stat: (84361, 10) / Train: (12650, 10) / Valid: (2989, 10)
[Train] Num-cases: 498
Epoch: 0
+++ [Train] Loss: 0.6133988499641418

>>> [Train] TOTAL AUC: 0.7508021853110447
[Valid] num_users: 365
- [Valid] AUC: 0.7890885390885392 ACC: 0.7040609137055838
>>> [Valid] TOTAL AUC: 0.7727048490580494
Testing batch: 0.39023566246032715
>>> [TESTING] AUC: 0.8311688311688311
Testing batch: 0.40128397941589355
>>> [TESTING] AUC: 0.7954545454545454
Testing batch: 0.41686296463012695
>>> [TESTING] AUC: 0.7638888888888888
Testing batch: 0.3916144371032715
Batch-Time elapsed: 449.9091446399689

***Training chunk-32:
Stat: (87574, 10) / Train: (10074, 10) / Valid: (2352, 10)
[Train] Num-cases: 392
Epoch: 0
+++ [Train] Loss: 0.6068423986434937 AUC: 0.722708592670907 ACC: 0.6790123456790124
>>> [Train] TOTAL AUC: 0.729061787390769
Epoch: 1
+++ [Train] Loss: 0.5807669162750244 AUC: 0.7534713162243527 ACC: 0.6858485026423957
>>> [Train] TOTAL AUC: 0.7484532803632139
[Valid] num_users: 390
- [Valid] AUC: 0.7607869996614496 ACC: 0.6915650406504065
>>> [Valid] TOTAL AUC: 0.7678423473043146
Testing batch: 0.39177489280700684
>>> [TESTIN

Testing batch: 0.3872988224029541
Batch-Time elapsed: 590.4084117412567

***Training chunk-42:
Stat: (87233, 10) / Train: (10355, 10) / Valid: (2412, 10)
[Train] Num-cases: 402
Epoch: 0
+++ [Train] Loss: 0.597180962562561 AUC: 0.7228230117261805 ACC: 0.6787878787878788
>>> [Train] TOTAL AUC: 0.7412244457842715
Epoch: 1
+++ [Train] Loss: 0.5815079808235168 AUC: 0.7554864644180641 ACC: 0.6860174781523096
>>> [Train] TOTAL AUC: 0.7584137525089999
[Valid] num_users: 402
- [Valid] AUC: 0.7630574532839048 ACC: 0.6867588932806324
>>> [Valid] TOTAL AUC: 0.7753949811781142
Testing batch: 0.3799772262573242
>>> [TESTING] AUC: 0.7402597402597403
Testing batch: 0.3751490116119385
>>> [TESTING] AUC: 0.8636363636363636
Testing batch: 0.3943040370941162
>>> [TESTING] AUC: 0.8819444444444444
Testing batch: 0.3844153881072998
Batch-Time elapsed: 604.3855581283569

***Training chunk-43:
Stat: (88765, 10) / Train: (9117, 10) / Valid: (2118, 10)
[Train] Num-cases: 353
Epoch: 0
+++ [Train] Loss: 0.59601771

>>> [Train] TOTAL AUC: 0.7593084892630241
[Valid] num_users: 252
- [Valid] AUC: 0.7679306399532176 ACC: 0.6976861167002012
>>> [Valid] TOTAL AUC: 0.782149327428398
Testing batch: 0.37874555587768555
>>> [TESTING] AUC: 0.8441558441558441
Testing batch: 0.37993693351745605
>>> [TESTING] AUC: 0.8522727272727272
Testing batch: 0.37721991539001465
>>> [TESTING] AUC: 0.8680555555555556
Testing batch: 0.38573527336120605
Batch-Time elapsed: 744.7782769203186

***Training chunk-53:
Stat: (85290, 10) / Train: (11878, 10) / Valid: (2832, 10)
[Train] Num-cases: 472
Epoch: 0
+++ [Train] Loss: 0.6052042841911316 AUC: 0.7285509415139044 ACC: 0.6696375519904931
>>> [Train] TOTAL AUC: 0.7374392063023517
Epoch: 1
+++ [Train] Loss: 0.5818547010421753 AUC: 0.7513713430851063 ACC: 0.6916666666666667
>>> [Train] TOTAL AUC: 0.7566528081006243
[Valid] num_users: 472
- [Valid] AUC: 0.7617459305299595 ACC: 0.6831337325349301
>>> [Valid] TOTAL AUC: 0.7760466895922463
Testing batch: 0.38932180404663086
>>> [TEST

Testing batch: 0.39031004905700684
Batch-Time elapsed: 886.741774559021

***Training chunk-63:
Stat: (82756, 10) / Train: (13956, 10) / Valid: (3288, 10)
[Train] Num-cases: 548
Epoch: 0
+++ [Train] Loss: 0.6113173961639404 AUC: 0.7185387844126859 ACC: 0.6710608913998745
>>> [Train] TOTAL AUC: 0.7477690926871449
Epoch: 1
+++ [Train] Loss: 0.5856362581253052 AUC: 0.7585112539401104 ACC: 0.6895910780669146
>>> [Train] TOTAL AUC: 0.7605171324878653
[Valid] num_users: 548
- [Valid] AUC: 0.7684454378363517 ACC: 0.6987891019172553
>>> [Valid] TOTAL AUC: 0.7797550668720286
Testing batch: 0.3856515884399414
>>> [TESTING] AUC: 0.8181818181818182
Testing batch: 0.38456130027770996
>>> [TESTING] AUC: 0.8920454545454546
Testing batch: 0.3891563415527344
>>> [TESTING] AUC: 0.8263888888888888
Testing batch: 0.3847992420196533
Batch-Time elapsed: 901.540678024292

***Training chunk-64:
Stat: (84988, 10) / Train: (12246, 10) / Valid: (2766, 10)
[Train] Num-cases: 461
Epoch: 0
+++ [Train] Loss: 0.578493

>>> [Train] TOTAL AUC: 0.7597275598285229
[Valid] num_users: 419
- [Valid] AUC: 0.7797561581967722 ACC: 0.7122521606507372
>>> [Valid] TOTAL AUC: 0.7792584258778804
Testing batch: 0.406447172164917
>>> [TESTING] AUC: 0.7922077922077922
Testing batch: 0.3880350589752197
>>> [TESTING] AUC: 0.9375
Testing batch: 0.40320444107055664
>>> [TESTING] AUC: 0.7847222222222222
Testing batch: 0.38845324516296387
Batch-Time elapsed: 1042.6963291168213

***Training chunk-74:
Stat: (87568, 10) / Train: (10080, 10) / Valid: (2352, 10)
[Train] Num-cases: 392
Epoch: 0
+++ [Train] Loss: 0.5848092436790466 AUC: 0.7553149436629347 ACC: 0.6831200487507617
>>> [Train] TOTAL AUC: 0.7511299708515276
Epoch: 1
+++ [Train] Loss: 0.5756884217262268 AUC: 0.7634988167605392 ACC: 0.6975425330812854
>>> [Train] TOTAL AUC: 0.7658351844432645
[Valid] num_users: 392
- [Valid] AUC: 0.7890147893080982 ACC: 0.7256944444444444
>>> [Valid] TOTAL AUC: 0.7836026936575478
Testing batch: 0.4007692337036133
>>> [TESTING] AUC: 0.76

Testing batch: 0.39267587661743164
Batch-Time elapsed: 1202.3078167438507

***Training chunk-84:
Stat: (87617, 10) / Train: (10103, 10) / Valid: (2280, 10)
[Train] Num-cases: 380
Epoch: 0
+++ [Train] Loss: 0.5825344324111938 AUC: 0.7534276895462291 ACC: 0.6875383670963782
>>> [Train] TOTAL AUC: 0.7486203214448098
Epoch: 1
+++ [Train] Loss: 0.579613447189331 AUC: 0.7492304986368833 ACC: 0.6828839390386869
>>> [Train] TOTAL AUC: 0.7611917595607379
[Valid] num_users: 380
- [Valid] AUC: 0.7803967798541287 ACC: 0.7045685279187818
>>> [Valid] TOTAL AUC: 0.7809376887772516
Testing batch: 0.37940025329589844
>>> [TESTING] AUC: 0.7272727272727273
Testing batch: 0.39269065856933594
>>> [TESTING] AUC: 0.8636363636363636
Testing batch: 0.3928978443145752
>>> [TESTING] AUC: 0.7569444444444444
Testing batch: 0.3945322036743164
Batch-Time elapsed: 1216.014904975891

***Training chunk-85:
Stat: (87861, 10) / Train: (9859, 10) / Valid: (2280, 10)
[Train] Num-cases: 380
Epoch: 0
+++ [Train] Loss: 0.5829

>>> [Train] TOTAL AUC: 0.7700144209663035
[Valid] num_users: 460
- [Valid] AUC: 0.7967418708119609 ACC: 0.7217741935483871
>>> [Valid] TOTAL AUC: 0.7871791076565523
Testing batch: 0.3878297805786133
>>> [TESTING] AUC: 0.7662337662337663
Testing batch: 0.40147829055786133
>>> [TESTING] AUC: 0.9090909090909092
Testing batch: 0.4097597599029541
>>> [TESTING] AUC: 0.7777777777777777
Testing batch: 0.42467427253723145
Batch-Time elapsed: 1358.4735136032104

***Training chunk-95:
Stat: (88342, 10) / Train: (9432, 10) / Valid: (2226, 10)
[Train] Num-cases: 370
Epoch: 0
+++ [Train] Loss: 0.5890103578567505 AUC: 0.7502259953621822 ACC: 0.6955475330926595
>>> [Train] TOTAL AUC: 0.7620745555762448
Epoch: 1
+++ [Train] Loss: 0.550069272518158 AUC: 0.7777950937950938 ACC: 0.7189349112426036
>>> [Train] TOTAL AUC: 0.774781432297713
[Valid] num_users: 371
- [Valid] AUC: 0.7853239228669859 ACC: 0.712938711367208
>>> [Valid] TOTAL AUC: 0.7890107713246935
Testing batch: 0.3847534656524658
>>> [TESTING] 

Testing batch: 0.40615320205688477
Batch-Time elapsed: 1499.9115114212036

***Training chunk-105:
Stat: (88202, 10) / Train: (9566, 10) / Valid: (2232, 10)
[Train] Num-cases: 372
Epoch: 0
+++ [Train] Loss: 0.5931742191314697 AUC: 0.7304809078809613 ACC: 0.6860986547085202
>>> [Train] TOTAL AUC: 0.75040560366595
Epoch: 1
+++ [Train] Loss: 0.5719624757766724 AUC: 0.762685629547301 ACC: 0.696895922093731
>>> [Train] TOTAL AUC: 0.762225911955433
[Valid] num_users: 372
- [Valid] AUC: 0.7922839122269987 ACC: 0.7190954773869347
>>> [Valid] TOTAL AUC: 0.7869676399006043
Testing batch: 0.3863029479980469
>>> [TESTING] AUC: 0.7922077922077922
Testing batch: 0.39269447326660156
>>> [TESTING] AUC: 0.9204545454545454
Testing batch: 0.3964197635650635
>>> [TESTING] AUC: 0.7708333333333334
Testing batch: 0.3879368305206299
Batch-Time elapsed: 1514.0204510688782

***Training chunk-106:
Stat: (84712, 10) / Train: (12419, 10) / Valid: (2869, 10)
[Train] Num-cases: 478
Epoch: 0
+++ [Train] Loss: 0.570175

>>> [Train] TOTAL AUC: 0.7640040890553068
[Valid] num_users: 321
- [Valid] AUC: 0.8002395160962936 ACC: 0.7245720040281974
>>> [Valid] TOTAL AUC: 0.7844784279868404
Testing batch: 0.3956937789916992
>>> [TESTING] AUC: 0.7922077922077922
Testing batch: 0.3859443664550781
>>> [TESTING] AUC: 0.9034090909090908
Testing batch: 0.4024503231048584
>>> [TESTING] AUC: 0.8194444444444444
Testing batch: 0.4315907955169678
Batch-Time elapsed: 1655.4999556541443

***Training chunk-116:
Stat: (86402, 10) / Train: (11024, 10) / Valid: (2574, 10)
[Train] Num-cases: 429
Epoch: 0
+++ [Train] Loss: 0.5722572207450867 AUC: 0.7670755018482003 ACC: 0.7084352078239609
>>> [Train] TOTAL AUC: 0.7466725997795642
Epoch: 1
+++ [Train] Loss: 0.5748321413993835 AUC: 0.7654934954751131 ACC: 0.6983372921615202
>>> [Train] TOTAL AUC: 0.760767739682953
[Valid] num_users: 429
- [Valid] AUC: 0.7858940864434225 ACC: 0.7229696063776782
>>> [Valid] TOTAL AUC: 0.7790025720376603
Testing batch: 0.39144206047058105
>>> [TESTIN

Testing batch: 0.39685654640197754
>>> [TESTING] AUC: 0.8194444444444444
Testing batch: 0.3981006145477295
Batch-Time elapsed: 1798.1445481777191

***Training chunk-126:
Stat: (89469, 10) / Train: (8530, 10) / Valid: (2001, 10)
[Train] Num-cases: 333
Epoch: 0
+++ [Train] Loss: 0.5861421227455139 AUC: 0.7436542669584245 ACC: 0.6811868686868687
>>> [Train] TOTAL AUC: 0.7595660131065852
Epoch: 1
+++ [Train] Loss: 0.5755113363265991 AUC: 0.7591451040081177 ACC: 0.6904761904761905
>>> [Train] TOTAL AUC: 0.7702854202845824
[Valid] num_users: 334
- [Valid] AUC: 0.7658638677465521 ACC: 0.6970912738214644
>>> [Valid] TOTAL AUC: 0.7893167324476896
Testing batch: 0.39657115936279297
>>> [TESTING] AUC: 0.8441558441558441
Testing batch: 0.38904237747192383
>>> [TESTING] AUC: 0.9147727272727273
Testing batch: 0.4028890132904053
>>> [TESTING] AUC: 0.8055555555555556
Testing batch: 0.4062168598175049
Batch-Time elapsed: 1812.0564529895782

***Training chunk-127:
Stat: (86578, 10) / Train: (10932, 10) 

>>> [Train] TOTAL AUC: 0.7509343670503772
Epoch: 1
+++ [Train] Loss: 0.5688116550445557 AUC: 0.7580812471024572 ACC: 0.6754020250148898
>>> [Train] TOTAL AUC: 0.7629222111109867
[Valid] num_users: 437
- [Valid] AUC: 0.7830167155002642 ACC: 0.7101090188305252
>>> [Valid] TOTAL AUC: 0.7792482860482342
Testing batch: 0.39669179916381836
>>> [TESTING] AUC: 0.7532467532467533
Testing batch: 0.403353214263916
>>> [TESTING] AUC: 0.9090909090909091
Testing batch: 0.4081764221191406
>>> [TESTING] AUC: 0.8125
Testing batch: 0.41173243522644043
Batch-Time elapsed: 1953.9057915210724

***Training chunk-137:
Stat: (85628, 10) / Train: (11684, 10) / Valid: (2688, 10)
[Train] Num-cases: 448
Epoch: 0
+++ [Train] Loss: 0.570618212223053 AUC: 0.7602544235064801 ACC: 0.6873857404021938
>>> [Train] TOTAL AUC: 0.7628498088064498
Epoch: 1
+++ [Train] Loss: 0.5722423195838928 AUC: 0.7655236976607609 ACC: 0.6958608278344331
>>> [Train] TOTAL AUC: 0.7712668216499864
[Valid] num_users: 448
- [Valid] AUC: 0.7891

Testing batch: 0.40747499465942383
>>> [TESTING] AUC: 0.8977272727272727
Testing batch: 0.4304020404815674
>>> [TESTING] AUC: 0.7777777777777777
Testing batch: 0.4282090663909912
Batch-Time elapsed: 2095.075714111328

***Training chunk-147:
Stat: (87184, 10) / Train: (10416, 10) / Valid: (2400, 10)
[Train] Num-cases: 400
Epoch: 0
+++ [Train] Loss: 0.5637072324752808 AUC: 0.767876059322034 ACC: 0.6939582156973462
>>> [Train] TOTAL AUC: 0.7610211961699025
Epoch: 1
+++ [Train] Loss: 0.5489200353622437 AUC: 0.7844527565457797 ACC: 0.7092113184828417
>>> [Train] TOTAL AUC: 0.7740596294216131
[Valid] num_users: 400
- [Valid] AUC: 0.8054484408399679 ACC: 0.7454453441295547
>>> [Valid] TOTAL AUC: 0.7909127001008562
Testing batch: 0.3861839771270752
>>> [TESTING] AUC: 0.7662337662337663
Testing batch: 0.3992304801940918
>>> [TESTING] AUC: 0.8806818181818182
Testing batch: 0.3987710475921631
>>> [TESTING] AUC: 0.736111111111111
Testing batch: 0.40271472930908203
Batch-Time elapsed: 2109.18556857

>>> [Train] TOTAL AUC: 0.746307084729399
Epoch: 1
+++ [Train] Loss: 0.5663903951644897 AUC: 0.7609801403869201 ACC: 0.6867313915857605
>>> [Train] TOTAL AUC: 0.7563694318967448
[Valid] num_users: 250
- [Valid] AUC: 0.7921768900758588 ACC: 0.7226095617529881
>>> [Valid] TOTAL AUC: 0.7775512851389303
Testing batch: 0.39293861389160156
>>> [TESTING] AUC: 0.7272727272727273
Testing batch: 0.3930940628051758
>>> [TESTING] AUC: 0.8579545454545454
Testing batch: 0.3911724090576172
>>> [TESTING] AUC: 0.7152777777777778
Testing batch: 0.4011106491088867
Batch-Time elapsed: 2249.402574777603

***Training chunk-158:
Stat: (87571, 10) / Train: (10095, 10) / Valid: (2334, 10)
[Train] Num-cases: 389
Epoch: 0
+++ [Train] Loss: 0.5360907912254333 AUC: 0.8014185554046279 ACC: 0.7306763285024155
>>> [Train] TOTAL AUC: 0.7659068543007672
Epoch: 1
+++ [Train] Loss: 0.5703204274177551 AUC: 0.7693199148590675 ACC: 0.7002915451895044
>>> [Train] TOTAL AUC: 0.7770255371991159
[Valid] num_users: 389
- [Valid] 

Testing batch: 0.3930323123931885
>>> [TESTING] AUC: 0.9375
Testing batch: 0.4005568027496338
>>> [TESTING] AUC: 0.7638888888888888
Testing batch: 0.4105801582336426
Batch-Time elapsed: 2392.4981615543365

***Training chunk-168:
Stat: (88811, 10) / Train: (9119, 10) / Valid: (2070, 10)
[Train] Num-cases: 344
Epoch: 0
+++ [Train] Loss: 0.5578223466873169 AUC: 0.7751601854312458 ACC: 0.6977163461538461
>>> [Train] TOTAL AUC: 0.7595994541838139
Epoch: 1
+++ [Train] Loss: 0.5612187385559082 AUC: 0.7817876056708661 ACC: 0.706604324956166
>>> [Train] TOTAL AUC: 0.7717342989553022
[Valid] num_users: 345
- [Valid] AUC: 0.7856399409739301 ACC: 0.7163366336633663
>>> [Valid] TOTAL AUC: 0.7878464489183327
Testing batch: 0.4009544849395752
>>> [TESTING] AUC: 0.8571428571428571
Testing batch: 0.4045243263244629
>>> [TESTING] AUC: 0.9090909090909092
Testing batch: 0.402219295501709
>>> [TESTING] AUC: 0.75
Testing batch: 0.41854071617126465
Batch-Time elapsed: 2406.6268968582153

***Training chunk-16

>>> [Train] TOTAL AUC: 0.7597940355086116
Epoch: 1
+++ [Train] Loss: 0.5650681257247925 AUC: 0.7699745896029675 ACC: 0.7031963470319634
>>> [Train] TOTAL AUC: 0.7680391871424904
[Valid] num_users: 258
- [Valid] AUC: 0.8026381187102074 ACC: 0.7240161453077699
>>> [Valid] TOTAL AUC: 0.7865200159598644
Testing batch: 0.4248011112213135
>>> [TESTING] AUC: 0.8831168831168831
Testing batch: 0.4014549255371094
>>> [TESTING] AUC: 0.875
Testing batch: 0.40038466453552246
>>> [TESTING] AUC: 0.7847222222222222
Testing batch: 0.41551828384399414
Batch-Time elapsed: 2549.8712565898895

***Training chunk-179:
Stat: (88655, 10) / Train: (9163, 10) / Valid: (2182, 10)
[Train] Num-cases: 362
Epoch: 0
+++ [Train] Loss: 0.5872427225112915 AUC: 0.7357098266224001 ACC: 0.6762127410870836
>>> [Train] TOTAL AUC: 0.7557462558977442
Epoch: 1
+++ [Train] Loss: 0.5574899911880493 AUC: 0.7802402446245061 ACC: 0.7043147208121827
>>> [Train] TOTAL AUC: 0.7655971071110953
[Valid] num_users: 364
- [Valid] AUC: 0.8019

Testing batch: 0.3956427574157715
>>> [TESTING] AUC: 0.8920454545454546
Testing batch: 0.4084813594818115
>>> [TESTING] AUC: 0.8055555555555556
Testing batch: 0.4024360179901123
Batch-Time elapsed: 2693.261823654175

***Training chunk-189:
Stat: (86401, 10) / Train: (11079, 10) / Valid: (2520, 10)
[Train] Num-cases: 420
Epoch: 0
+++ [Train] Loss: 0.5878809094429016 AUC: 0.7426301924912238 ACC: 0.6725455614344503
>>> [Train] TOTAL AUC: 0.7547746875080943
Epoch: 1
+++ [Train] Loss: 0.5588675141334534 AUC: 0.7772674027958639 ACC: 0.7153240460327075
>>> [Train] TOTAL AUC: 0.7644522002246941
[Valid] num_users: 419
- [Valid] AUC: 0.772142807596058 ACC: 0.712152420185376
>>> [Valid] TOTAL AUC: 0.7805701057796748
Testing batch: 0.40617823600769043
>>> [TESTING] AUC: 0.8831168831168831
Testing batch: 0.4148592948913574
>>> [TESTING] AUC: 0.8920454545454545
Testing batch: 0.40354323387145996
>>> [TESTING] AUC: 0.8125
Testing batch: 0.4102318286895752
Batch-Time elapsed: 2707.8817834854126

***Tr

>>> [Train] TOTAL AUC: 0.7598108035501582
Epoch: 1
+++ [Train] Loss: 0.5579698085784912 AUC: 0.7710144849154337 ACC: 0.6951779563719862
>>> [Train] TOTAL AUC: 0.7673300500811304
[Valid] num_users: 450
- [Valid] AUC: 0.7832328531431035 ACC: 0.7146389713155292
>>> [Valid] TOTAL AUC: 0.7835827232408337
Testing batch: 0.4090921878814697
>>> [TESTING] AUC: 0.8181818181818182
Testing batch: 0.40552639961242676
>>> [TESTING] AUC: 0.8863636363636364
Testing batch: 0.42717695236206055
>>> [TESTING] AUC: 0.7847222222222223
Testing batch: 0.42653346061706543
Batch-Time elapsed: 2850.4848058223724

***Training chunk-200:
Stat: (87157, 10) / Train: (10419, 10) / Valid: (2424, 10)
[Train] Num-cases: 404
Epoch: 0
+++ [Train] Loss: 0.6072356104850769 AUC: 0.7250105587231724 ACC: 0.6604105571847507
>>> [Train] TOTAL AUC: 0.7605655757356942
Epoch: 1
+++ [Train] Loss: 0.5561763644218445 AUC: 0.7805058655507949 ACC: 0.7105108631826189
>>> [Train] TOTAL AUC: 0.7711813183844509
[Valid] num_users: 404
- [Val

Testing batch: 0.3993813991546631
>>> [TESTING] AUC: 0.8863636363636362
Testing batch: 0.4059720039367676
>>> [TESTING] AUC: 0.8125
Testing batch: 0.40358853340148926
Batch-Time elapsed: 2993.936304807663

***Training chunk-210:
Stat: (85496, 10) / Train: (11810, 10) / Valid: (2694, 10)
[Train] Num-cases: 449
Epoch: 0
+++ [Train] Loss: 0.5688941478729248 AUC: 0.7660183642941258 ACC: 0.705208929593589
>>> [Train] TOTAL AUC: 0.750076571477912
Epoch: 1
+++ [Train] Loss: 0.5749055743217468 AUC: 0.7667095986008224 ACC: 0.6904619076184763
>>> [Train] TOTAL AUC: 0.7601020267830203
[Valid] num_users: 449
- [Valid] AUC: 0.7763995705641227 ACC: 0.7123152709359606
>>> [Valid] TOTAL AUC: 0.7837728586162542
Testing batch: 0.41195034980773926
>>> [TESTING] AUC: 0.8441558441558441
Testing batch: 0.42121434211730957
>>> [TESTING] AUC: 0.8806818181818181
Testing batch: 0.44910168647766113
>>> [TESTING] AUC: 0.8055555555555556
Testing batch: 0.42201662063598633
Batch-Time elapsed: 3008.452961921692

***

+++ [Train] Loss: 0.5984736084938049 AUC: 0.7355375937723345 ACC: 0.6662902315076228
>>> [Train] TOTAL AUC: 0.7573488194489262
Epoch: 1
+++ [Train] Loss: 0.5713520646095276 AUC: 0.7582127083867474 ACC: 0.6897374701670644
>>> [Train] TOTAL AUC: 0.7659402874907096
[Valid] num_users: 401
- [Valid] AUC: 0.8016065111578051 ACC: 0.724155320221886
>>> [Valid] TOTAL AUC: 0.7874975814739207
Testing batch: 0.44228696823120117
>>> [TESTING] AUC: 0.7922077922077922
Testing batch: 0.4058539867401123
>>> [TESTING] AUC: 0.9090909090909092
Testing batch: 0.40851783752441406
>>> [TESTING] AUC: 0.7847222222222222
Testing batch: 0.4200704097747803
Batch-Time elapsed: 3152.3467292785645

***Training chunk-221:
Stat: (85524, 10) / Train: (11806, 10) / Valid: (2670, 10)
[Train] Num-cases: 445
Epoch: 0
+++ [Train] Loss: 0.5568141937255859 AUC: 0.774207563564925 ACC: 0.6999445368829729
>>> [Train] TOTAL AUC: 0.7647427442765227
Epoch: 1
+++ [Train] Loss: 0.5646235346794128 AUC: 0.7726372977032238 ACC: 0.704493

Testing batch: 0.41061902046203613
>>> [TESTING] AUC: 0.8441558441558441
Testing batch: 0.4307866096496582
>>> [TESTING] AUC: 0.8693181818181819
Testing batch: 0.4156010150909424
>>> [TESTING] AUC: 0.798611111111111
Testing batch: 0.4103107452392578
Batch-Time elapsed: 3295.2702355384827

***Training chunk-231:
Stat: (86988, 10) / Train: (10588, 10) / Valid: (2424, 10)
[Train] Num-cases: 404
Epoch: 0
+++ [Train] Loss: 0.5803285837173462 AUC: 0.7546855203170395 ACC: 0.6927339901477833
>>> [Train] TOTAL AUC: 0.7654353928329374
Epoch: 1
+++ [Train] Loss: 0.5439155101776123 AUC: 0.7824558512537418 ACC: 0.720360824742268
>>> [Train] TOTAL AUC: 0.7771905505175907
[Valid] num_users: 404
- [Valid] AUC: 0.7938566654732069 ACC: 0.7116716122650841
>>> [Valid] TOTAL AUC: 0.7958995047362186
Testing batch: 0.4097867012023926
>>> [TESTING] AUC: 0.8441558441558441
Testing batch: 0.39575910568237305
>>> [TESTING] AUC: 0.8806818181818182
Testing batch: 0.414287805557251
>>> [TESTING] AUC: 0.819444444444

[Train] Num-cases: 513
Epoch: 0
+++ [Train] Loss: 0.5825219750404358 AUC: 0.7487028694370483 ACC: 0.6903614457831325
>>> [Train] TOTAL AUC: 0.7599401090120935
Epoch: 1
+++ [Train] Loss: 0.5636582970619202 AUC: 0.7761759357332896 ACC: 0.704131227217497
>>> [Train] TOTAL AUC: 0.7675458244749154
[Valid] num_users: 514
- [Valid] AUC: 0.7761326502645838 ACC: 0.6925051334702259
>>> [Valid] TOTAL AUC: 0.7872993145944258
Testing batch: 0.41669321060180664
>>> [TESTING] AUC: 0.8311688311688311
Testing batch: 0.4061288833618164
>>> [TESTING] AUC: 0.8693181818181818
Testing batch: 0.4233736991882324
>>> [TESTING] AUC: 0.8680555555555555
Testing batch: 0.3931849002838135
Batch-Time elapsed: 3452.139795780182

***Training chunk-242:
Stat: (88504, 10) / Train: (9348, 10) / Valid: (2148, 10)
[Train] Num-cases: 358
Epoch: 0
+++ [Train] Loss: 0.5789346694946289 AUC: 0.7530690069477094 ACC: 0.6769779892920881
>>> [Train] TOTAL AUC: 0.7491548588078984
Epoch: 1
+++ [Train] Loss: 0.5483832955360413 AUC: 0.

Testing batch: 0.4240405559539795
>>> [TESTING] AUC: 0.8311688311688311
Testing batch: 0.4114806652069092
>>> [TESTING] AUC: 0.8068181818181819
Testing batch: 0.4245004653930664
>>> [TESTING] AUC: 0.7569444444444444
Testing batch: 0.41785478591918945
Batch-Time elapsed: 3595.3056671619415

***Training chunk-252:
Stat: (88033, 10) / Train: (9711, 10) / Valid: (2256, 10)
[Train] Num-cases: 376
Epoch: 0
+++ [Train] Loss: 0.5809093117713928 AUC: 0.7515410146989093 ACC: 0.6800720288115246
>>> [Train] TOTAL AUC: 0.7589131455362929
Epoch: 1
+++ [Train] Loss: 0.567691445350647 AUC: 0.7582137741633818 ACC: 0.6988602279544092
>>> [Train] TOTAL AUC: 0.7700179447846768
[Valid] num_users: 376
- [Valid] AUC: 0.7689032753494638 ACC: 0.686491935483871
>>> [Valid] TOTAL AUC: 0.7843604904154957
Testing batch: 0.4262216091156006
>>> [TESTING] AUC: 0.8051948051948051
Testing batch: 0.42804431915283203
>>> [TESTING] AUC: 0.8238636363636364
Testing batch: 0.41745710372924805
>>> [TESTING] AUC: 0.75694444444

[Train] Num-cases: 361
Epoch: 0
+++ [Train] Loss: 0.5802611708641052 AUC: 0.7546427892134071 ACC: 0.6941580756013745
>>> [Train] TOTAL AUC: 0.7604243876637062
Epoch: 1
+++ [Train] Loss: 0.5561210513114929 AUC: 0.7776627531863134 ACC: 0.7177879133409351
>>> [Train] TOTAL AUC: 0.7685203639849952
[Valid] num_users: 360
- [Valid] AUC: 0.8024035483261365 ACC: 0.7247247247247247
>>> [Valid] TOTAL AUC: 0.7877980992618375
Testing batch: 0.4050135612487793
>>> [TESTING] AUC: 0.8181818181818182
Testing batch: 0.40810418128967285
>>> [TESTING] AUC: 0.8352272727272727
Testing batch: 0.4064669609069824
>>> [TESTING] AUC: 0.7847222222222222
Testing batch: 0.41416430473327637
Batch-Time elapsed: 3752.5634365081787

***Training chunk-263:
Stat: (92001, 10) / Train: (6531, 10) / Valid: (1468, 10)
[Train] Num-cases: 244
Epoch: 0
+++ [Train] Loss: 0.5842896103858948 AUC: 0.7488952716235495 ACC: 0.6707865168539325
>>> [Train] TOTAL AUC: 0.7572325347004565
Epoch: 1
+++ [Train] Loss: 0.5467100739479065 AUC:

- [Valid] AUC: 0.801351089690404 ACC: 0.728
>>> [Valid] TOTAL AUC: 0.7770360273326532
Testing batch: 0.39849042892456055
>>> [TESTING] AUC: 0.8571428571428571
Testing batch: 0.41720080375671387
>>> [TESTING] AUC: 0.8352272727272727
Testing batch: 0.41251373291015625
>>> [TESTING] AUC: 0.8194444444444444
Testing batch: 0.41515183448791504
Batch-Time elapsed: 3894.677062034607

***Training chunk-273:
Stat: (86963, 10) / Train: (10535, 10) / Valid: (2502, 10)
[Train] Num-cases: 417
Epoch: 0
+++ [Train] Loss: 0.574279248714447 AUC: 0.7673031639777482 ACC: 0.6987542468856173
>>> [Train] TOTAL AUC: 0.7559656678413518
Epoch: 1
+++ [Train] Loss: 0.5683701038360596 AUC: 0.7685109116953152 ACC: 0.7117414248021108
>>> [Train] TOTAL AUC: 0.7664408290087276
[Valid] num_users: 417
- [Valid] AUC: 0.7975798175141514 ACC: 0.7237237237237237
>>> [Valid] TOTAL AUC: 0.7820480783106366
Testing batch: 0.4133107662200928
>>> [TESTING] AUC: 0.8571428571428571
Testing batch: 0.411970853805542
>>> [TESTING] AUC

[Train] Num-cases: 400
Epoch: 0
+++ [Train] Loss: 0.565194308757782 AUC: 0.7702247356051704 ACC: 0.6951807228915663
>>> [Train] TOTAL AUC: 0.7586416524436007
Epoch: 1
+++ [Train] Loss: 0.5811862349510193 AUC: 0.7530010145539106 ACC: 0.6809029896278218
>>> [Train] TOTAL AUC: 0.768604752781841
[Valid] num_users: 400
- [Valid] AUC: 0.7999671587719432 ACC: 0.7240684793554885
>>> [Valid] TOTAL AUC: 0.7860094237885955
Testing batch: 0.42450928688049316
>>> [TESTING] AUC: 0.8181818181818182
Testing batch: 0.4104292392730713
>>> [TESTING] AUC: 0.8295454545454546
Testing batch: 0.42370009422302246
>>> [TESTING] AUC: 0.8333333333333333
Testing batch: 0.4088113307952881
Batch-Time elapsed: 4054.866697072983

***Training chunk-284:
Stat: (88879, 10) / Train: (9032, 10) / Valid: (2089, 10)
[Train] Num-cases: 348
Epoch: 0
+++ [Train] Loss: 0.5595952272415161 AUC: 0.774618931423501 ACC: 0.6902857142857143
>>> [Train] TOTAL AUC: 0.7669711510717099
Epoch: 1
+++ [Train] Loss: 0.5685923099517822 AUC: 0.7

Testing batch: 0.41197872161865234
>>> [TESTING] AUC: 0.8441558441558441
Testing batch: 0.4233665466308594
>>> [TESTING] AUC: 0.8068181818181819
Testing batch: 0.4188830852508545
>>> [TESTING] AUC: 0.8472222222222221
Testing batch: 0.42116808891296387
Batch-Time elapsed: 4199.0185441970825

***Training chunk-294:
Stat: (87197, 10) / Train: (10349, 10) / Valid: (2454, 10)
[Train] Num-cases: 409
Epoch: 0
+++ [Train] Loss: 0.5644815564155579 AUC: 0.7661604182566097 ACC: 0.702803738317757
>>> [Train] TOTAL AUC: 0.7614181773114479
Epoch: 1
+++ [Train] Loss: 0.5789635181427002 AUC: 0.7603427845958528 ACC: 0.6864608076009501
>>> [Train] TOTAL AUC: 0.7707418895925567
[Valid] num_users: 409
- [Valid] AUC: 0.7741256533483005 ACC: 0.6969072164948453
>>> [Valid] TOTAL AUC: 0.7792375611107697
Testing batch: 0.4075775146484375
>>> [TESTING] AUC: 0.8181818181818182
Testing batch: 0.4126121997833252
>>> [TESTING] AUC: 0.8011363636363636
Testing batch: 0.4198024272918701
>>> [TESTING] AUC: 0.875
Testin

In [20]:
with open(f'./save/bert_save/user_dict.pickle', 'wb') as handle:
    pickle.dump(trainer.dataset.user_dict, handle, protocol=3)

In [None]:
import riiideducation
# You can only call make_env() once, so don't lose it!
env = riiideducation.make_env()

# You can only iterate through a result from `env.iter_test()` once
# so be careful not to lose it once you start iterating.
iter_test = env.iter_test()

In [None]:
################################
# Submission
################################

print(f'Start testing ....')
for (test_df, sample_prediction_df) in iter_test:
    # do prediction
    pred = trainer.infer(test_df)

    # fill all value first
    test_df['answered_correctly'] = 0.5
    
    # Only fill question-type
    test_df.loc[test_df['content_type_id'] == 0, 'answered_correctly'] = pred

    # submit prediction    
    env.predict(test_df.loc[test_df['content_type_id'] == 0, ['row_id', 'answered_correctly']])
    
    trainer.finetune_batch()