In [1]:
import os
import torch
import torch.nn as nn
from transformers import BartForSequenceClassification
from config import get_config
import numpy as np
from tqdm import tqdm
from keyword_process import MyDataset, DataLoader, paired_collate_fn
from log import Logger
import warnings
warnings.filterwarnings('ignore')
from tensorboardX import SummaryWriter

In [2]:
class Trainer:
    def __init__(self, config):
        self.config = config
        self.logger = Logger(log_path=config.save_path).logger
        self.model = BartForSequenceClassification.from_pretrained(
            config.base_params, num_labels=9,  problem_type='single_label_classification').to(config.device)
        self.scheduler = ''
        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=config.lr)
        self.criterion = nn.CrossEntropyLoss().to(config.device)
        self.writer = SummaryWriter(config.tensorboard_path)

    def _update(self, inputs, mode):
        inputs = [i.to(self.config.device)if not isinstance(i, list) else i for i in inputs ]
        enc_input_padded, enc_mask, labels = inputs
        labels = torch.Tensor(labels).long().to(self.config.device)
        out = self.model(enc_input_padded, enc_mask, labels=labels)
        outputs, loss = out.logits, out.loss
        num_corr, num = self._cal_preformance(outputs, labels)
        if mode == 'train':
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        return loss.item(), num_corr, num

    def _cal_preformance(self, logits, labels):
        # mask = mask.unsqueeze(dim=-1).repeat(1, 1, logits.shape[-1])
        # logits = logits.masked_select(mask).view(-1, logits.shape[-1])
        # logits = logits.reshape(-1)
        # labels = torch.Tensor(labels).to(self.config.device)
        # loss = self.model(logits, labels=labels).loss
        # loss = self.criterion(logits, labels)
        num_corr = 0
#         for i in range(logits.shape[0]):
#             if ((logits[0] > 0.5) == labels[i]).all():
#                 num_corr += 1
        _, indices = logits.max(dim=1)
        num_corr = indices.eq(labels).sum().item()
        return num_corr, logits.shape[0]

    def _run_epoch(self, dataset, mode):
        if mode not in ['train', 'eval']:
            raise Exception("you must select 'train' or 'eval' as a value of mode!")
        total_corr = 0  # 预测正确总个数
        total_labs = 0  # 总标签数
        total_loss = 0  # 总损失
        total_num = 0  # 样本总个数
        if mode == 'train':
            self.model.train()
        else:
            self.model.eval()
        for idx, inputs in tqdm(enumerate(dataset), total=len(dataset)):
            # if idx == 29: continue
            # print(idx)

            loss, num_corr, num = self._update(inputs, mode)

            total_corr += num_corr
            total_labs += num
            total_loss += loss
            total_num += 1

            # if idx == 33: exit()
        avg_loss = round(total_loss/(total_num), 5)  # 平均损失
        accura = round(total_corr/(total_labs), 4)  # 准确率
        # f1 = round(2*recall*accura/(recall+accura+1e-5), 2)  # f1
        return avg_loss, accura

    def train(self, num_eopch, train_dataset, val_dataset, save_path, checkpoint_path=False):
        start_epoch = 0
        if checkpoint_path:
            checkpoint = torch.load(checkpoint_path, map_location=self.config['device'])
            start_epoch = checkpoint['epoch']
            assert start_epoch < num_eopch
            self.model.load_state_dict(checkpoint['params'])

        t_avg_loss, t_accura = self._run_epoch(train_dataset, 'eval')
        e_avg_loss, e_accura = self._run_epoch(val_dataset, 'eval')
        self.writer.add_scalars('show', {'train-loss': t_avg_loss,
                                               'eval-loss': e_avg_loss,
                                               'train-acc': t_accura,
                                               'eval-acc': e_accura}
                                       , 1)

        self.logger.info(f'-Train  loss:{t_avg_loss}   accuracy:{t_accura * 100}%')
        self.logger.info(f'-Eval   loss:{e_avg_loss}   accuracy:{e_accura * 100}%')
        best_params, save_epoch = 0, 0
        for epoch in range(start_epoch, num_eopch):
            self.logger.info(f'Epoch {epoch+1}/{num_eopch}:')
            t_avg_loss, t_accura = self._run_epoch(train_dataset, 'train')
            e_avg_loss, e_accura = self._run_epoch(val_dataset, 'eval')
            self.writer.add_scalars('show', {'train-loss': t_avg_loss,
                                               'eval-loss': e_avg_loss,
                                               'train-acc': t_accura,
                                               'eval-acc': e_accura}
                                       , epoch+2)
            self.logger.info(f'-Train  loss:{t_avg_loss}   accuracy:{t_accura*100}%')
            self.logger.info(f'-Eval   loss:{e_avg_loss}   accuracy:{e_accura*100}%')
            if abs(t_accura-e_accura) <= 0.1 and t_accura >= 0.93 and e_accura >= 0.93:
                if e_accura>best_params:
                    best_params = e_accura
                    model_state_dict = self.model.state_dict()
                    checkpoint = {
                        'params': model_state_dict,
                        'configs': self.config,
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'epoch': epoch}
                    if not os.path.exists(save_path):
                        os.mkdir(save_path)
                    if os.path.exists(save_path + '/' + str(save_epoch) + '.pth'):
                        os.remove(save_path + '/' + str(save_epoch) + '.pth')
                        self.logger.info(f'the model saved in: {save_path}/{save_epoch}.pth  will be revomed')
                    file_name = f'{save_path}/{epoch}.pth'
                    self.logger.info(f'the model will be saved in: {file_name}')
                    torch.save(checkpoint, file_name)
                    save_epoch = epoch

    def test(self, test_dataset, checkpoint_path=False):
        if checkpoint_path:
            checkpoint = torch.load(checkpoint_path, map_location=self.config.device)
            self.model.load_state_dict(checkpoint['params'])
        avg_loss, accura = self._run_epoch(test_dataset, 'eval')
        self.logger.info(f'-test   loss:{avg_loss}   accuracy:{accura * 100}%')


In [3]:
config = get_config()
# 定义Summary_Writer
print(config)
train_dataset = MyDataset(config.data_path, mode='train')
eval_dataset =MyDataset(config.data_path, mode='test')
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size, collate_fn=paired_collate_fn)
eval_loader = DataLoader(eval_dataset, batch_size=config.batch_size, collate_fn=paired_collate_fn)
trainer = Trainer(config)
# trainer.train(config.num_epoch, train_loader, eval_loader, config.save_path,)
#                   checkpoint_path='./data/ckpts/fine_tuned/30.pth')
# trainer.test(eval_loader, checkpoint_path='./ckpts/finetune2/7.pth')


Namespace(base_params='model_path', batch_size=32, data_path='data/营业厅数字人项目事项v1.0-人工匹配规则枚举.xlsx', device='cuda:0', lr=1e-05, max_len=512, num_epoch=300, save_path='./ckpts/finetune_keyword', seed=1, split_rate=0.8, tensorboard_path='./ckpts/finetune_keyword_tensorboard')


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'BertTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'BertTokenizer'.
Some weights of BartForSequenceClassification were not initialized from the model checkpoint at model_path and are newly initialized: ['classification_head.out_proj.bias', 'classification_head.out_proj.weight', 'classification_head.dense.bias', 'classification_head.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
train = [i[0] for i in train_dataset]
test = [i[0] for i in eval_dataset]
test

['换套餐',
 '多少分钟',
 '更 套餐',
 '通话时间',
 '啥套餐',
 '优惠',
 '套餐办理',
 '哪个套餐',
 '月租',
 '超出费用',
 '超出 资费',
 '通话多少',
 '分钟数',
 '超出收费',
 '超出 资费',
 '超出 收费',
 '什么流量包',
 '免流量',
 '限速吗',
 '行行行',
 '麻烦你了',
 '我要',
 '记一个',
 '哦那你说',
 '哎没问题',
 '好的我晓得了',
 '怎样操作',
 '没什么问题',
 '啊也行',
 '行的吧',
 '用',
 '好的好的好的',
 '哦没问题',
 '哦可以啊',
 '行没有问题',
 '你说吧',
 '哦好的',
 '嗯好',
 '好的呀',
 '好哦',
 '嗯行可以的',
 '嗯是的',
 '讲吧讲吧',
 '哎好的',
 '额好啊',
 '可以',
 '额可以',
 '办',
 '讲讲看',
 '嗯没问题',
 '对是',
 '嗯',
 '可以没有问题',
 '可以啊',
 '行行好的',
 '哦哦行啊',
 '可以啊没问题',
 '哎对',
 '哦哦哦',
 '说吧',
 '要用',
 '哦好',
 '使用',
 '嗯可以',
 '哦好好',
 '行呗',
 '木有兴趣',
 '没这想法',
 '我没需求',
 '不很想',
 '还没想法',
 '不办',
 '用不上',
 '不想',
 '不同意',
 '好的不用了',
 '不必要',
 '没这方面想法',
 '还是算了吧',
 '不做这个',
 '不会考虑',
 '这个不合适我',
 '那我不要',
 '没想过',
 '不去',
 '没这个需求',
 '不弄了',
 '没计划',
 '没想法',
 '不可以',
 '啊我不想',
 '没有什么需求',
 '没有这个想法',
 '真的没兴趣',
 '没有需求',
 '太麻烦不要',
 '没这方面需要',
 '下次在合作',
 '都不用',
 '我不想了解',
 '不给',
 '现在不考虑',
 '没做了',
 '不了',
 '不搞了',
 '不用来',
 '没有想法',
 '现在不想做',
 '我没需要',
 '没这个需要',
 '没什么需求',
 '再一遍',
 '哇',
 '再说一遍可以吗',
 '没听见',
 '听得不清晰',


In [5]:
from tqdm import tqdm
count = 0
for i in tqdm(test):
    if i in train : 
        count+= 1
count

100%|██████████| 238/238 [00:00<00:00, 122604.32it/s]


160