In [1]:
import os
train_path = '../../raw_train_artifact'
test_path = '../../raw_test_artifact'
embedding_path = '../../embedding_artifact'
input_path = '../../input_artifact'
input_split_path = '../../input_artifact/input_split'
model_path = '../../model_artifact'
output_path = '../../output_artifact'

In [2]:
import sys
import gc
gc.enable()
import time
import re

import numpy as np
import pandas as pd
pd.set_option('display.max_columns',120)
pd.set_option('display.max_rows',2000)
pd.set_option('precision',5)
pd.set_option('float_format', '{:.5f}'.format)

import tqdm
import joblib
import json

from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, roc_auc_score
from gensim.models import Word2Vec
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
import logging

log_path = '[1.3]LSTM with Creative, Advertiser & Product Embedding Sequence.log'
    
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s %(levelname)-s: %(message)s', datefmt='%H:%M:%S')

fh = logging.FileHandler(log_path)
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
logger.addHandler(fh)

sh = logging.StreamHandler(sys.stdout)
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
logger.addHandler(sh)

logger.info(f'Restart notebook\n==========================\n{time.ctime()}\n==========================')

07:40:10 INFO: Restart notebook
Thu Jun  4 07:40:10 2020


In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info('Device in Use: {}'.format(DEVICE))
torch.cuda.empty_cache()
t = torch.cuda.get_device_properties(DEVICE).total_memory/1024**3
c = torch.cuda.memory_cached(DEVICE)/1024**3
a = torch.cuda.memory_allocated(DEVICE)/1024**3
logger.info('CUDA Memory: Total {:.2f} GB, Cached {:.2f} GB, Allocated {:.2f} GB'.format(t,c,a))

07:40:10 INFO: Device in Use: cuda
07:40:10 INFO: CUDA Memory: Total 8.00 GB, Cached 0.00 GB, Allocated 0.00 GB


## Data Loader

### General Utility

In [5]:
inp_embed_artifact = {
    'creative': {
        'embedding_artifact': r'C:\JupyterNotebook\Tencent-Ads-Algo-Comp-2020\embedding_artifact\creative_id_embed_s160_w64_cbow_38168zon',
        'train_file_prefix': 'train_creative_agg_user',
        'test_file_prefix': 'test_creative_agg_user'
    },
    'ad': {
        'embedding_artifact': r'C:\JupyterNotebook\Tencent-Ads-Algo-Comp-2020\embedding_artifact\ad_id_embed_s160_w64_cbow_ibfi8g78',
        'train_file_prefix': 'train_ad_agg_user',
        'test_file_prefix': 'test_ad_agg_user'
    },
    'advertiser': {
        'embedding_artifact': r'C:\JupyterNotebook\Tencent-Ads-Algo-Comp-2020\embedding_artifact\advertiser_id_embed_s128_w64_cbow_n4re8tds',
        'train_file_prefix': 'train_advertiser_agg_user',
        'test_file_prefix': 'test_advertiser_agg_user'
    },
    'product': {
        'embedding_artifact': r'C:\JupyterNotebook\Tencent-Ads-Algo-Comp-2020\embedding_artifact\product_id_embed_s128_w64_cbow_8yemmp45',
        'train_file_prefix': 'train_product_agg_user',
        'test_file_prefix': 'test_product_agg_user'
    }
}

def get_truth(split_id, logger=None):
    """
    Get user id and ground truth
    """
    start = time.time()
    
    truth_path = os.path.join(input_split_path, f'train_truth_{split_id}.npy')
    with open(truth_path, 'rb') as f:
        truth = np.load(f)
        
    inp_user = truth[:,0]
    out_age = torch.from_numpy(truth[:,1]).long()
    out_gender = torch.from_numpy(truth[:,2]).long()
    
    del truth
    _ = gc.collect()
    
    if logger: logger.info(f'Target output ready after {time.time()-start:.2f}s')
    return inp_user, out_age, out_gender

def get_embed_seq(split_id, embed_var, inp_user, max_seq=100, train=True, logger=None):
    """
    Get corresponding embedding sequence
    """
    global inp_embed_artifact, input_split_path
    assert embed_var in inp_embed_artifact
    
    start = time.time()
    embedding = Word2Vec.load(inp_embed_artifact[embed_var]['embedding_artifact'])
    if logger: logger.info(f'{embed_var.capitalize()} embedding artifact is loaded after {time.time()-start:.2f}s')
    start = time.time()
    file_prefix = inp_embed_artifact[embed_var]['train_file_prefix'] if train else inp_embed_artifact[embed_var]['test_file_prefix']
    raw_path = os.path.join(input_split_path, f'{file_prefix}_{split_id}.json')
    with open(raw_path, 'r') as f:
        raw = json.load(f)
    inp_seq = []
    for user in inp_user:
        inp_seq.append(torch.from_numpy(np.stack([embedding.wv[key] for key in raw[str(user)][:max_seq]], axis=0)).float())
    inp_last_idx = np.array([i.shape[0] for i in inp_seq])-1
    
    del embedding, raw
    _ = gc.collect()
    
    if logger: logger.info(f'{embed_var.capitalize()} embedding sequence ready after {time.time()-start:.2f}s')
    return inp_seq, inp_last_idx

In [6]:
def prepare_train(split_id, max_seq=100, logger=None):
    """
    Get ground truth, and embedding sequence for creative, product and advertiser
    """
    if logger: logger.info(f'Preparing Training Split-{split_id}')
        
    inp_user, out_age, out_gender = get_truth(split_id, logger=logger)
    inp_creative_seq, inp_last_idx = get_embed_seq(split_id, 'creative',inp_user, max_seq=max_seq, logger=logger)
    inp_advertiser_seq, _ = get_embed_seq(split_id, 'advertiser',inp_user, max_seq=max_seq, logger=logger)
    inp_product_seq, _ = get_embed_seq(split_id, 'product',inp_user, max_seq=max_seq, logger=logger)
    
    del inp_user
    _ = gc.collect()
    
    return out_age, out_gender, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx   

def prepare_test(split_id, max_seq=100, logger=None):
    global input_split_path
    if logger: logger.info(f'Preparing Training Split-{split_id}')
        
    idx_path = os.path.join(input_split_path, 'test_idx_shuffle.npy')
    with open(idx_path, 'rb') as f:
        test_idx = np.load(f)
    inp_user = test_idx[(split_id-1)*100000:split_id*100000]
    del test_idx
    _ = gc.collect()
    
    inp_creative_seq, inp_last_idx = get_embed_seq(split_id, 'creative',inp_user, max_seq=max_seq, train=False, logger=logger)
    inp_advertiser_seq, _ = get_embed_seq(split_id, 'advertiser',inp_user, max_seq=max_seq, train=False, logger=logger)
    inp_product_seq, _ = get_embed_seq(split_id, 'product',inp_user, max_seq=max_seq, train=False, logger=logger)
    
    _ = gc.collect()
    
    return inp_user, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx

## Model

In [7]:
class LSTM_Extraction_Layer(nn.Module):
    """
    LSTM feature extration layer
    - Layer 1: BiLSTM + Dropout + Layernorm
    - Layer 2: LSTM with Residual Connection + Dropout + Layernorm
    - Layer 3: LSTM + Batchnorm + ReLU + Dropout
    """
    def __init__(self, embed_size, lstm_hidden_size, rnn_dropout=0.2, mlp_dropout=0.4, **kwargs):
        super(LSTM_Extraction_Layer, self).__init__(**kwargs)
        self.embed_size = embed_size
        self.lstm_hidden_size = lstm_hidden_size
        self.rnn_dropout = rnn_dropout
        self.mlp_dropout = mlp_dropout
        
        self.bi_lstm = nn.LSTM(input_size=embed_size, hidden_size=lstm_hidden_size, bias=True, bidirectional=True)
        self.rnn_dropout_1 = nn.Dropout(p=rnn_dropout)
        self.layernorm_1 = nn.LayerNorm(2*lstm_hidden_size)
        self.lstm_1 = nn.LSTM(input_size=2*lstm_hidden_size, hidden_size=2*lstm_hidden_size)
        self.rnn_dropout_2 = nn.Dropout(p=rnn_dropout)
        self.layernorm_2 = nn.LayerNorm(2*lstm_hidden_size)
        self.lstm_2 = nn.LSTM(input_size=2*lstm_hidden_size, hidden_size=2*lstm_hidden_size)
        self.batchnorm = nn.BatchNorm1d(2*lstm_hidden_size)
        self.mlp_dropout = nn.Dropout(p=mlp_dropout)
        
    def forward(self, inp_embed, inp_last_idx):
        bilstm_out, _ = self.bi_lstm(inp_embed.permute(1,0,2))                            # (max_seq_length, batch_size, embed_size) -> (max_seq_length, batch_size, 2*lstm_hidden_size)
        bilstm_out = self.layernorm_1(self.rnn_dropout_1(bilstm_out))                     # (max_seq_length, batch_size, 2*lstm_hidden_size)
        lstm_out, _ = self.lstm_1(bilstm_out)                                             # (max_seq_length, batch_size, 2*lstm_hidden_size)
        lstm_out = self.rnn_dropout_2(lstm_out)                                           # (max_seq_length, batch_size, 2*lstm_hidden_size)
        lstm_out = self.layernorm_2(lstm_out+bilstm_out)                                  # (max_seq_length, batch_size, 2*lstm_hidden_size)
        lstm_out, _ = self.lstm_2(lstm_out)                                               # (max_seq_length, batch_size, 2*lstm_hidden_size)
        lstm_out = lstm_out.permute(1,0,2)[np.arange(len(inp_last_idx)), inp_last_idx,:]  # (batch_size, 2*lstm_hidden_size)
        lstm_out = self.mlp_dropout(F.relu(self.batchnorm(lstm_out)))                     # (batch_size, 2*lstm_hidden_size)
        return lstm_out
    
class MLP_Classification_Layer(nn.Module):
    """
    Multilayer Perception Classification Layer
    - Layer 1: Linear + Batchnorm + ReLU + Dropout
    - Layer 2: Linear + Batchnorm + ReLU + Dropout
    - Layer 3: Linear
    """
    def __init__(self, inp_size, out_size, dropout=0.4, **kwargs):
        super(MLP_Classification_Layer, self).__init__(**kwargs)
        self.inp_size = inp_size
        self.out_size = out_size
        self.dropout = dropout
        
        self.mlp_1 = nn.Linear(inp_size, 1024)
        self.batchnorm_1 = nn.BatchNorm1d(1024)
        self.mlp_dropout_1 = nn.Dropout(p=dropout)
        self.mlp_2 = nn.Linear(1024, 512)
        self.batchnorm_2 = nn.BatchNorm1d(512)
        self.mlp_dropout_2 = nn.Dropout(p=dropout)
        self.mlp_3 = nn.Linear(512, out_size)
        
    def forward(self, inp):
        mlp_out = self.mlp_1(inp)                                                         # (batch_size, 1024)
        mlp_out = self.mlp_dropout_1(F.relu(self.batchnorm_1(mlp_out)))                   # (batch_size, 1024)
        mlp_out = self.mlp_2(mlp_out)                                                     # (batch_size, 512)
        mlp_out = self.mlp_dropout_2(F.relu(self.batchnorm_2(mlp_out)))                   # (batch_size, 512)
        mlp_out = self.mlp_3(mlp_out)                                                     # (batch_size, out_size)
        return mlp_out   
    
class Multi_Seq_LSTM_Classifier(nn.Module):
    """
    Use separate LSTM extractor to handle different sequences, concat them and feed backto multilayer perception classifier.
    """
    def __init__(self, embed_size, lstm_hidden_size, out_size, rnn_dropout=0.2, mlp_dropout=0.4, **kwargs):
        super(Multi_Seq_LSTM_Classifier, self).__init__(**kwargs)
        assert isinstance(embed_size, list) and isinstance(lstm_hidden_size, list) and len(embed_size)==len(lstm_hidden_size)
        
        self.embed_size = embed_size
        self.lstm_hidden_size = lstm_hidden_size
        self.out_size = out_size
        self.rnn_dropout = rnn_dropout
        self.mlp_dropout = mlp_dropout
        
        self.n_extraction = len(embed_size)
        self.mlp_inp_size = sum(map(lambda x:2*x, lstm_hidden_size))
        
        for index, (e_size, h_size) in enumerate(zip(embed_size, lstm_hidden_size)):
            setattr(self, f'extraction_layer_{index}', LSTM_Extraction_Layer(e_size, h_size, rnn_dropout=rnn_dropout, mlp_dropout=mlp_dropout))
        self.classification_layer = MLP_Classification_Layer(self.mlp_inp_size, out_size, dropout=mlp_dropout)
        
    def forward(self, *args):
        assert len(args)==self.n_extraction+1
        
        extract_buffer = [getattr(self, f'extraction_layer_{index}')(inp_embed, args[-1]) for index, inp_embed in enumerate(args[:-1])]
        out = torch.cat(extract_buffer, 1)
        out = self.classification_layer(out)
        
        return out

## Train Age Model

In [8]:
EPOCHES = 5
BATCH_SIZE = 256
div, mod = divmod(90000, BATCH_SIZE)
N_BATCH = div + min(mod, 1)

def train_age(model, loss_fn, optimizer, device, checkpoint_dir, checkpoint_prefix, logger=None, epoch_start=0):
    global EPOCHES, BATCH_SIZE, N_BATCH, TEST_SIZE
    if not os.path.isdir(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    
    for epoch in range(1+epoch_start, EPOCHES+1+epoch_start):
        if logger: 
            logger.info('=========================')
            logger.info(f'Processing Epoch {epoch}/{EPOCHES+epoch_start}')
            logger.info('=========================')
            
        train_file = [1,2,3,4,5,6,7,8,9]
        test_file = [10]
            
        train_running_loss, train_n_batch = 0, 0
        pred_y, true_y = [], []
        for index, split_id in enumerate(train_file, start=1):
            out_age, out_gender, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx = prepare_train(split_id)
            model.train()
            
            for batch_index in range(N_BATCH):
                x1 = torch.nn.utils.rnn.pad_sequence(inp_creative_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x2 = torch.nn.utils.rnn.pad_sequence(inp_advertiser_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x3 = torch.nn.utils.rnn.pad_sequence(inp_product_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x4 = inp_last_idx[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE]
                y = out_age[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE].to(device)
                optimizer.zero_grad()
                yp = F.softmax(model(x1, x2, x3, x4), 1)
                loss = loss_fn(yp, y)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
                optimizer.step()
                
                train_running_loss += loss.item()
                train_n_batch += 1
                
                del x1, x2, x3, x4, y, yp
                _ = gc.collect()
                torch.cuda.empty_cache()
            
            if logger:
                logger.info(f'Epoch {epoch}/{EPOCHES+epoch_start} - Training Split {index}/{len(train_file)} Done - Train Loss: {train_running_loss/train_n_batch:.6f}')
            
            del out_age, out_gender, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx
            _ = gc.collect()
            torch.cuda.empty_cache()   
        
        model.eval()
        test_running_loss, test_n_batch = 0, 0
        true_y, pred_y = [], []
        
        for index, split_id in enumerate(test_file, start=1):
            out_age, out_gender, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx = prepare_train(split_id)
            for batch_index in range(N_BATCH):
                x1 = torch.nn.utils.rnn.pad_sequence(inp_creative_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x2 = torch.nn.utils.rnn.pad_sequence(inp_advertiser_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x3 = torch.nn.utils.rnn.pad_sequence(inp_product_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x4 = inp_last_idx[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE]
                y = out_age[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE].to(device)
                yp = F.softmax(model(x1, x2, x3, x4), 1)
                loss = loss_fn(yp, y)
            
                test_running_loss += loss.item()
                test_n_batch += 1
            
                pred_y.extend(list(yp.cpu().detach().numpy()))
                true_y.extend(list(y.cpu().detach().numpy()))
            
                del x1, x2, x3, x4, y, yp
                _ = gc.collect()
                torch.cuda.empty_cache()
            
            del out_age, out_gender, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx
            _ = gc.collect()
            torch.cuda.empty_cache()
        
        pred = np.argmax(np.array(pred_y), 1)
        true = np.array(true_y).reshape((-1,))
        acc_score = accuracy_score(true, pred)
        
        if logger:
            logger.info(f'Epoch {epoch}/{EPOCHES+epoch_start} Done - Test Loss: {test_running_loss/test_n_batch:.6f}, Test Accuracy: {acc_score:.6f}')
            
        ck_file_name = f'{checkpoint_prefix}_{epoch}.pth'
        ck_file_path = os.path.join(checkpoint_dir, ck_file_name)
        
        torch.save(model.state_dict(), ck_file_path)

In [9]:
model = Multi_Seq_LSTM_Classifier([160, 128, 128], [256, 256, 256], 10).to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
device = DEVICE
optimizer = torch.optim.Adam(model.parameters())
checkpoint_dir = os.path.join(model_path, 'Multi_Seq_LSTM_Classifier_Creative_Advertiser_Product_Age')
checkpoint_prefix = 'Multi_Seq_LSTM_Classifier_Creative_Advertiser_Product_Age'

train_age(model, loss_fn, optimizer, device, checkpoint_dir, checkpoint_prefix, logger=logger)

07:40:16 INFO: Processing Epoch 1/5
07:46:12 INFO: Epoch 1/5 - Training Split 1/9 Done - Train Loss: 2.137750
07:52:21 INFO: Epoch 1/5 - Training Split 2/9 Done - Train Loss: 2.123506
07:58:28 INFO: Epoch 1/5 - Training Split 3/9 Done - Train Loss: 2.114756
08:04:33 INFO: Epoch 1/5 - Training Split 4/9 Done - Train Loss: 2.108503
08:10:39 INFO: Epoch 1/5 - Training Split 5/9 Done - Train Loss: 2.103345
08:16:41 INFO: Epoch 1/5 - Training Split 6/9 Done - Train Loss: 2.099431
08:22:41 INFO: Epoch 1/5 - Training Split 7/9 Done - Train Loss: 2.095377
08:28:42 INFO: Epoch 1/5 - Training Split 8/9 Done - Train Loss: 2.092019
08:34:43 INFO: Epoch 1/5 - Training Split 9/9 Done - Train Loss: 2.089040
08:37:49 INFO: Epoch 1/5 Done - Test Loss: 2.065752, Test Accuracy: 0.385378
08:37:49 INFO: Processing Epoch 2/5
08:43:45 INFO: Epoch 2/5 - Training Split 1/9 Done - Train Loss: 2.063969
08:49:47 INFO: Epoch 2/5 - Training Split 2/9 Done - Train Loss: 2.061952
08:55:51 INFO: Epoch 2/5 - Training S

In [None]:
model = Multi_Seq_LSTM_Classifier([160, 128, 128], [256, 256, 256], 10)
checkpoint_dir = os.path.join(model_path, 'Multi_Seq_LSTM_Classifier_Creative_Advertiser_Product_Age')
checkpoint_prefix = 'Multi_Seq_LSTM_Classifier_Creative_Advertiser_Product_Age'
model.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'{checkpoint_prefix}_5.pth')))

model = model.to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
device = DEVICE
optimizer = torch.optim.Adam(model.parameters())

train_age(model, loss_fn, optimizer, device, checkpoint_dir, checkpoint_prefix, logger=logger, epoch_start=5)

12:45:19 INFO: Processing Epoch 6/10
12:51:17 INFO: Epoch 6/10 - Training Split 1/9 Done - Train Loss: 2.026355
12:57:26 INFO: Epoch 6/10 - Training Split 2/9 Done - Train Loss: 2.026364
13:03:32 INFO: Epoch 6/10 - Training Split 3/9 Done - Train Loss: 2.025312
13:09:42 INFO: Epoch 6/10 - Training Split 4/9 Done - Train Loss: 2.025212
13:15:48 INFO: Epoch 6/10 - Training Split 5/9 Done - Train Loss: 2.024471
13:21:54 INFO: Epoch 6/10 - Training Split 6/9 Done - Train Loss: 2.024291
13:28:01 INFO: Epoch 6/10 - Training Split 7/9 Done - Train Loss: 2.023931
13:34:06 INFO: Epoch 6/10 - Training Split 8/9 Done - Train Loss: 2.023714
13:40:12 INFO: Epoch 6/10 - Training Split 9/9 Done - Train Loss: 2.023405
13:43:24 INFO: Epoch 6/10 Done - Test Loss: 2.027633, Test Accuracy: 0.425144
13:43:24 INFO: Processing Epoch 7/10
13:49:21 INFO: Epoch 7/10 - Training Split 1/9 Done - Train Loss: 2.021880
13:55:30 INFO: Epoch 7/10 - Training Split 2/9 Done - Train Loss: 2.021543
14:01:34 INFO: Epoch 7/

## Code Backup

In [None]:
EPOCHES = 5
BATCH_SIZE = 512
N_BATCH = 90000//BATCH_SIZE
TEST_SIZE = 90000%BATCH_SIZE

def train_age(model, loss_fn, optimizer, device, checkpoint_dir, checkpoint_prefix, logger=None, epoch_start=0):
    global EPOCHES, BATCH_SIZE, N_BATCH, TEST_SIZE
    if not os.path.isdir(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    
    for epoch in range(1+epoch_start, EPOCHES+1+epoch_start):
        if logger: 
            logger.info('=========================')
            logger.info(f'Processing Epoch {epoch}/{EPOCHES+epoch_start}')
            logger.info('=========================')
            
        train_file = [1,2,3,4,5,6,7,8,9]
        test_file = [10]
            
        train_running_loss, train_n_batch = 0, 0
        pred_y, true_y = [], []
        for index, split_id in enumerate(train_file, start=1):
            out_age, out_gender, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx = prepare_train(split_id)
            train_creative_seq, test_creative_seq = inp_creative_seq[:-TEST_SIZE], inp_creative_seq[-TEST_SIZE:]
            train_advertiser_seq, test_advertiser_seq = inp_advertiser_seq[:-TEST_SIZE], inp_advertiser_seq[-TEST_SIZE:]
            train_product_seq, test_product_seq = inp_product_seq[:-TEST_SIZE], inp_product_seq[-TEST_SIZE:]
            train_last_idx, test_last_idx = inp_last_idx[:-TEST_SIZE], inp_last_idx[-TEST_SIZE:]
            train_age, test_age = out_age[:-TEST_SIZE], out_age[-TEST_SIZE:]
            
            model.train()
            
            for batch_index in range(N_BATCH):
                x1 = torch.nn.utils.rnn.pad_sequence(train_creative_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x2 = torch.nn.utils.rnn.pad_sequence(train_advertiser_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x3 = torch.nn.utils.rnn.pad_sequence(train_product_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x4 = train_last_idx[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE]
                y = train_age[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE].to(device)
                optimizer.zero_grad()
                yp = F.softmax(model(x1, x2, x3, x4), 1)
                loss = loss_fn(yp, y)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=100)
                optimizer.step()
                
                train_running_loss += loss.item()
                train_n_batch += 1
                
                del x1, x2, x3, x4, y, yp
                _ = gc.collect()
                torch.cuda.empty_cache()
            
            model.eval()
            
            x1 = torch.nn.utils.rnn.pad_sequence(test_creative_seq, batch_first=True, padding_value=0).to(device)
            x2 = torch.nn.utils.rnn.pad_sequence(test_advertiser_seq, batch_first=True, padding_value=0).to(device)
            x3 = torch.nn.utils.rnn.pad_sequence(test_product_seq, batch_first=True, padding_value=0).to(device)
            x4 = test_last_idx
            y = test_age.to(device)
            yp = F.softmax(model(x1, x2, x3, x4), 1)
            loss = loss_fn(yp, y)
            
            pred_y.extend(list(yp.cpu().detach().numpy()))
            true_y.extend(list(y.cpu().detach().numpy()))
            
            del x1, x2, x3, x4, y, yp
            _ = gc.collect()
            torch.cuda.empty_cache()
            
            pred = np.argmax(np.array(pred_y), 1)
            true = np.array(true_y).reshape((-1,))
            acc_score = accuracy_score(true, pred)
            
            if logger:
                logger.info(f'Epoch {epoch}/{EPOCHES+epoch_start} - Training Split {index}/{len(train_file)} Done - Train Loss: {train_running_loss/train_n_batch:.6f}, Val Loss: {loss.item():.6f}, Val Accuracy: {acc_score:.6f}')
            
            del out_age, out_gender, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx
            del train_creative_seq, test_creative_seq, train_advertiser_seq, test_advertiser_seq, train_product_seq, test_product_seq, train_last_idx, test_last_idx, train_age, test_age
            _ = gc.collect()
            torch.cuda.empty_cache()   
        
        model.eval()
        test_running_loss, test_n_batch = 0, 0
        true_y, pred_y = [], []
        
        for index, split_id in enumerate(test_file, start=1):
            out_age, out_gender, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx = prepare_train(split_id)
            for batch_index in range(N_BATCH+1):
                x1 = torch.nn.utils.rnn.pad_sequence(inp_creative_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x2 = torch.nn.utils.rnn.pad_sequence(inp_advertiser_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x3 = torch.nn.utils.rnn.pad_sequence(inp_product_seq[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE], batch_first=True, padding_value=0).to(device)
                x4 = inp_last_idx[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE]
                y = out_age[batch_index*BATCH_SIZE:(batch_index+1)*BATCH_SIZE].to(device)
                yp = F.softmax(model(x1, x2, x3, x4), 1)
                loss = loss_fn(yp, y)
            
                test_running_loss += loss.item()
                test_n_batch += 1
            
                pred_y.extend(list(yp.cpu().detach().numpy()))
                true_y.extend(list(y.cpu().detach().numpy()))
            
                del x1, x2, x3, x4, y, yp
                _ = gc.collect()
                torch.cuda.empty_cache()
            
            del out_age, out_gender, inp_creative_seq, inp_advertiser_seq, inp_product_seq, inp_last_idx
            _ = gc.collect()
            torch.cuda.empty_cache()
        
        pred = np.argmax(np.array(pred_y), 1)
        true = np.array(true_y).reshape((-1,))
        acc_score = accuracy_score(true, pred)
        
        if logger:
            logger.info(f'Epoch {epoch}/{EPOCHES+epoch_start} Done - Test Loss: {test_running_loss/test_n_batch:.6f}, Test Accuracy: {acc_score:.6f}')
            
        ck_file_name = f'{checkpoint_prefix}_{epoch}.pth'
        ck_file_path = os.path.join(checkpoint_dir, ck_file_name)
        
        torch.save(model.state_dict(), ck_file_path)