In [None]:
from transformers import BertTokenizer, BertPreTrainedModel, AdamW, BertConfig, BertModel
from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup

from keras.preprocessing.sequence import pad_sequences
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

import torch
import time, random
import numpy as np
import torch.nn as nn
import datetime
from scipy import sparse
from apex import amp

In [None]:
# Training split: three pre-serialised objects loaded from the working directory.
# They are used below as (input ids, attention masks, labels); the step that
# produced the files is not part of this notebook.
train_inputs, train_masks, train_labels = (
    torch.load(path)
    for path in (
        './ipc_train_input256.pt',
        './ipc_train_mask256.pt',
        './train_label.pt',
    )
)

# Validation split. NOTE: it is bound to the `test_*` names, which the final
# evaluation cell later overwrites with the real test split.
test_inputs, test_masks, test_labels = (
    torch.load(path)
    for path in (
        './valid_input.pt',
        './valid_mask.pt',
        './valid_labels.pt',
    )
)

In [None]:
batch_size = 64

# Training batches are drawn in a fresh random order on every pass.
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_data_loader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# Validation batches are read in a fixed order. The training loop averages its
# metrics per batch, so shuffling here would change the composition of the
# (shorter) last batch between passes and make the validation numbers that
# drive best-model selection non-repeatable.
test_data = TensorDataset(test_inputs, test_masks, test_labels)
test_sampler = SequentialSampler(test_data)
test_data_loader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

In [None]:
# Index of the CUDA device used by every cell below.
GPU_NUM = 1
device = torch.device('cuda', GPU_NUM)
# Also make it the current CUDA device, so bare `.cuda()` calls in later cells
# place tensors/modules on this same GPU.
torch.cuda.set_device(device)

In [None]:
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT encoder followed by dropout and a linear multi-label head.

    ``forward`` returns raw logits (one per label, no sigmoid applied); the
    loss is computed by the caller.
    """

    def __init__(self, config):
        """Build the encoder, dropout and classifier from ``config``.

        Reads ``config.num_labels`` and ``config.hidden_size``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        # Dropout rate is fixed at 0.5 here rather than read from the config.
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Compute classification logits for a batch.

        All encoder arguments are forwarded to the underlying ``BertModel``
        (previously everything except ``input_ids`` and ``attention_mask`` was
        silently dropped). ``labels`` is accepted for signature compatibility
        only and is not used: no loss is computed here.

        Returns:
            Logits tensor produced by the linear head from the pooled output.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        # Index instead of tuple-unpacking: position 1 is the pooled output
        # whether the encoder returns a plain tuple or a ModelOutput object.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits

    def freeze_bert_encoder(self):
        """Disable gradients for every parameter of the BERT encoder."""
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        """Re-enable gradients for every parameter of the BERT encoder."""
        for param in self.bert.parameters():
            param.requires_grad = True

In [None]:
# Start from the pretrained `bert-base-uncased` weights with a 1383-way label
# head, then move the model to the current CUDA device.
model = BertForMultiLabelSequenceClassification.from_pretrained(
    'bert-base-uncased',
    cache_dir=None,
    num_labels=1383,
).cuda()

In [None]:
EPOCHS = 20

optimizer = AdamW(model.parameters(), lr=5e-5, correct_bias=False, eps=1e-8, weight_decay=0.01)

# Wrap model and optimizer for apex mixed-precision training (opt level O1).
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# One scheduler step is taken per training batch, so the schedule spans every
# batch of every epoch.
total_step = len(train_data_loader) * EPOCHS

# Linear warm-up over the first 10% of the steps, then linear decay.
# Integer division: the scheduler documents `num_warmup_steps` as an int.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=total_step // 10,
    num_training_steps=total_step
)

# Multi-label objective: independent sigmoid + binary cross-entropy per label.
loss_fn = torch.nn.BCEWithLogitsLoss().to(device)

In [None]:
def format_time(elapsed):
    """Format a duration in seconds as an ``h:mm:ss`` string.

    Args:
        elapsed: duration in seconds (int or float).

    Returns:
        ``str(datetime.timedelta)`` of the duration rounded to whole seconds,
        e.g. ``'1:01:01'`` (durations of a day or more gain a ``'N day, '``
        prefix).
    """
    # Round to the nearest whole second so no fractional part is printed.
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

def predict_step(model, batch, k:int, loss_fn):
    """Evaluate one batch and return its top-``k`` predictions and loss.

    The model is switched to eval mode and run without gradient tracking.
    Tensors are moved to the module-level ``device`` before the forward pass.

    Args:
        model: module called as
            ``model(input_ids, token_type_ids=None, attention_mask=input_mask)``
            and returning a logits tensor.
        batch: iterable of three tensors
            ``(input_ids, input_mask, input_labels)``.
        k: number of highest-scoring labels to keep per row. (This argument
            used to be ignored in favour of a hard-coded 5.)
        loss_fn: callable invoked as ``loss_fn(logits, input_labels.float())``.

    Returns:
        Tuple ``(scores, labels, loss)``, all on the CPU: the sigmoid of the
        top-``k`` logits, the indices of those logits, and the batch loss.
    """
    model.eval()
    input_ids, input_mask, input_labels = tuple(t.to(device) for t in batch)
    with torch.no_grad():
        logits = model(input_ids, token_type_ids=None, attention_mask=input_mask)
        top_logits, top_indices = torch.topk(logits, k)
        loss = loss_fn(logits, input_labels.float())
    return torch.sigmoid(top_logits).cpu(), top_indices.cpu(), loss.cpu()

def get_p_5(predict, target, top):
    """Precision@``top`` for one batch.

    Args:
        predict: per-row sequences of predicted label indices, best first.
        target: 2-D 0/1 label matrix with a ``shape`` attribute
            (rows = samples, columns = labels).
        top: number of leading predictions per row to score.

    Returns:
        Number of (row, label) pairs that are both among the first ``top``
        predictions and set in ``target``, divided by ``top * number_of_rows``.
    """
    num_labels = target.shape[1]

    def _indicator_row(ranked_labels):
        # 0/1 row with a 1 at each of the first `top` predicted label indices.
        row = [0] * num_labels
        for label in ranked_labels[:top]:
            row[label] = 1
        return row

    predicted = np.array([_indicator_row(ranked) for ranked in predict])
    actual = np.array(target)
    hits = np.sum(predicted * actual)
    return hits / (top * actual.shape[0])

def get_ndcg_5(predict, target, top):
    """Mean nDCG@``top`` for one batch.

    For each row, DCG sums ``target[row, predict[row, i]] / log2(i + 2)`` over
    the first ``top`` ranks and is normalised by the ideal DCG, i.e. the same
    discount sum over ``min(row label total, top)`` ranks. Rows whose targets
    sum to zero contribute 0.

    Args:
        predict: 2-D integer array-like (list, ndarray or CPU tensor) of
            predicted label indices, best first, with at least ``top`` columns.
        target: 2-D array-like label matrix (rows = samples, columns = labels);
            integer, boolean or float values are accepted.
        top: rank cut-off.

    Returns:
        Average of the per-row nDCG values; ``0.0`` for an empty batch.

    Raises:
        ValueError: if ``predict`` has fewer than ``top`` columns.
    """
    target = np.asarray(target)
    if target.shape[0] == 0:
        return 0.0

    ranked = np.asarray(predict)[:, :top]
    if ranked.shape[1] < top:
        raise ValueError(
            "predict has {} columns but top={}".format(ranked.shape[1], top)
        )

    # Positional discount 1 / log2(rank + 1) for ranks 1..top.
    discounts = 1.0 / np.log2(np.arange(top) + 2)

    # gains[r, i] is the target value of the label predicted at rank i of row r.
    gains = np.take_along_axis(target, ranked, axis=1)
    dcg = (gains * discounts).sum(axis=1)

    # Ideal DCG places min(#relevant, top) relevant labels at the top ranks.
    # Rows with nothing relevant have dcg == 0, so any positive divisor gives 0;
    # clamp to 1 to keep the lookup index valid.
    n_relevant = np.minimum(target.sum(axis=1), top).astype(np.int64)
    ideal_dcg = discounts.cumsum()[np.maximum(n_relevant, 1) - 1]

    return np.average(dcg / ideal_dcg)

In [None]:
import time, random
import numpy as np
import logzero
logzero.setup_default_logger(logfile='./BERT_XMTC.log')
from logzero import logger
import os
import tqdm
import torch.nn.functional as F

# Seed every RNG in play so that a run can be repeated.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

EVAL_EVERY_STEPS = 1000      # run a mid-epoch validation pass every N training steps
EARLY_STOP_PATIENCE = 25     # consecutive non-improving validations before stopping
BEST_MODEL_DIR = './save_model'


def _evaluate(model, data_loader, loss_fn):
    """Run one pass over ``data_loader`` and return per-batch averages.

    Returns:
        ``(loss, p1, p3, p5)``: mean validation loss and mean precision@1/3/5,
        each averaged over the batches of the loader.

    Raises:
        ValueError: if the loader yields no batches.
    """
    loss_sum, n_batches = 0.0, 0
    p1_sum, p3_sum, p5_sum = 0.0, 0.0, 0.0
    for eval_batch in data_loader:
        _, labels, val_loss = predict_step(model, eval_batch, 5, loss_fn)
        targets = eval_batch[2].cpu()
        loss_sum += val_loss.item()
        n_batches += 1

        p1_sum += get_p_5(labels, targets, 1)
        p3_sum += get_p_5(labels, targets, 3)
        p5_sum += get_p_5(labels, targets, 5)

    if n_batches == 0:
        raise ValueError("validation data loader produced no batches")
    return loss_sum / n_batches, p1_sum / n_batches, p3_sum / n_batches, p5_sum / n_batches


def _save_best_model(model, path):
    """Write the model with ``save_pretrained``, creating ``path`` if needed."""
    os.makedirs(path, exist_ok=True)
    model.save_pretrained(path)


model.zero_grad()
# best_p3 is the best mean precision@3 seen so far; check_step counts
# consecutive validations that did not improve on it.
best_p3, check_step, best_loss = 0, 0, 1.0

for epoch_i in range(0, EPOCHS):
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, EPOCHS))
    print('Training...')

    logger.info("")
    logger.info('======== Epoch {:} / {:} ========'.format(epoch_i + 1, EPOCHS))
    logger.info('Training...')

    t0 = time.time()

    total_loss, steps_run = 0.0, 0
    stop_training = False

    for step, batch in enumerate(train_data_loader):
        # predict_step() leaves the model in eval mode, so re-enable training
        # mode on every step.
        model.train()
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, input_labels = batch
        logit = model(input_ids, token_type_ids=None, attention_mask=input_mask)
        loss = loss_fn(logit, input_labels.float())
        optimizer.zero_grad()
        # Scale the loss before backward so apex AMP can handle fp16 gradients.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        step_loss = loss.item()
        total_loss += step_loss
        steps_run += 1

        # Drop references to this step's tensors before any validation pass.
        del loss, logit, batch, input_ids, input_mask, input_labels

        if step % EVAL_EVERY_STEPS == 0 and not step == 0:
            valid_loss, p1, p3, p5 = _evaluate(model, test_data_loader, loss_fn)

            if best_p3 < p3:
                _save_best_model(model, BEST_MODEL_DIR)
                best_p3 = p3
                best_loss = valid_loss
                logger.info("create best model")
                check_step = 0
            else:
                check_step += 1

            progress = "{:>2} in {:>6}     train loss: {:.5f}     valid loss: {:.5f}     p1 : {:.5f}     p3 : {:.5f}     p5 : {:.5f}     check_step : {:>2}".format(epoch_i, train_data_loader.batch_size*step, step_loss, valid_loss, p1, p3, p5, check_step)
            print(progress)
            logger.info(progress)

            if check_step >= EARLY_STOP_PATIENCE:
                # Leave the step loop AND the epoch loop (see below).
                stop_training = True
                break

    # Average over the steps that actually ran (an early stop may cut the epoch short).
    avg_train_loss = total_loss / max(steps_run, 1)
    print("Average training loss: {0:.5f}".format(avg_train_loss))
    print("Training epcoh took: {:}".format(format_time(time.time() - t0)))
    logger.info("")
    logger.info("Average training loss: {0:.5f}".format(avg_train_loss))
    logger.info("Training epcoh took: {:}".format(format_time(time.time() - t0)))

    if stop_training:
        break

    # End-of-epoch validation.
    valid_loss, p1, p3, p5 = _evaluate(model, test_data_loader, loss_fn)

    if best_p3 < p3:
        _save_best_model(model, BEST_MODEL_DIR)
        best_loss = valid_loss
        best_p3 = p3
        logger.info("")
        logger.info("create best model")
        check_step = 0
    else:
        check_step += 1

    summary = "{:>2}     valid loss: {:.5f}     p1 : {:.5f}     p3 : {:.5f}     p5 : {:.5f}     check_step : {:>2}".format(epoch_i, valid_loss, p1, p3, p5, check_step)
    print(summary)
    logger.info("")
    logger.info(summary)

    if check_step >= EARLY_STOP_PATIENCE:
        break

In [None]:
# Reload the checkpoint stored under ./save_model and move it to the current
# CUDA device for the final evaluation.
model = BertForMultiLabelSequenceClassification.from_pretrained(
    './save_model',
    cache_dir=None,
    num_labels=1383,
).cuda()

In [None]:
from tqdm import tqdm

batch_size = 1024

loss_fn = nn.BCEWithLogitsLoss().to(device)
import torch.nn.functional as F

# Held-out test split (replaces the validation tensors bound to these names earlier).
test_inputs = torch.load('./test_input.pt')
test_masks = torch.load('./test_mask.pt')
test_labels = torch.load('./test_labels.pt')

test_data = TensorDataset(test_inputs, test_masks, test_labels)
# Fixed order: evaluation does not need shuffling and should be repeatable.
test_sampler = SequentialSampler(test_data)
valid_data_loader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

# get_ndcg_5 returns a per-batch mean, so weight each batch by its size and
# divide by the sample count; otherwise the (smaller) last batch would count
# as much as a full one.
ndcg1, ndcg3, ndcg5, n_samples = 0.0, 0.0, 0.0, 0
for batch in tqdm(valid_data_loader):
    labels = predict_step(model, batch, 5, loss_fn)[1]
    targets = batch[2]
    n_batch = targets.shape[0]

    ndcg1 += get_ndcg_5(labels, targets, 1) * n_batch
    ndcg3 += get_ndcg_5(labels, targets, 3) * n_batch
    ndcg5 += get_ndcg_5(labels, targets, 5) * n_batch
    n_samples += n_batch

# These are nDCG@k values (they were previously printed under p1/p3/p5 labels).
print("nDCG@1 : {:.5f}     nDCG@3 : {:.5f}     nDCG@5 : {:.5f}".format(ndcg1/n_samples, ndcg3/n_samples, ndcg5/n_samples))