In [1]:
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from pathlib import Path
import numpy as np
from torch.utils.data import DataLoader
from collections import Counter
import gc
import math
from ark_nlp.dataset.base._sentence_classification_dataset import SentenceClassificationDataset
from ark_nlp.factory.loss_function.focal_loss import FocalLoss
from transformers import BertConfig, BertModel
from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer
from model.nezha.configuration_nezha import NeZhaConfig
from model.nezha.modeling_nezha import NeZhaModel, NeZhaForSequenceClassification
from tokenizer import BertSpanTokenizer
from utils import WarmupLinearSchedule, seed_everything, get_default_bert_optimizer
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from task import Task
from tqdm import tqdm
from argparse import ArgumentParser
from model.model import BertForSequenceClassification, BertEnsambleForSequenceClassification, BertBiLSTMForSequenceClassification
from data_process import text_enchance
import pandas as pd
import torch
import os
import warnings
from torch import nn
warnings.filterwarnings("ignore")
seed_everything(42)


In [2]:
class FeatureExtract(nn.Module):
    def __init__(self, config, bert):
        super().__init__()

        self.num_labels = config.num_labels
        self.config = config

        self.bert = bert

        self.bilstm = nn.LSTM(input_size=config.hidden_size,
                              hidden_size=config.hidden_size // 2,
                              bidirectional=True,
                              dropout=0.1,
                              batch_first=True)

        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        **kargs
    ):

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        sequence_output = outputs[0]

        output, (h_n, c_n) = self.bilstm(sequence_output)

        pooled_output = output[:, 0]

        return pooled_output

In [4]:
class Args:
    def __init__(self):
        self.data_path = '../data/a_dataset/query_data.csv'
        self.goods_data_path = '../data/a_dataset/goods_data.csv'
        self.model_name_or_path = '../pretrain_model/uer_large/'
        self.model_type = 'uer_denoise_bilstm'
        self.checkpoint = './checkpoint/'
        self.test_file = '../data/a_dataset/test_a.csv'
        self.max_seq_len = 64
        self.batch_size = 16
        self.num_workers = 0
        self.fold = 5
        self.seed = 42
        self.predict_model = ''
        self.device = 'cuda:0'


def build_model_and_tokenizer(args, num_labels):
    tokenizer = BertSpanTokenizer(vocab=args.model_name_or_path,
                                  max_seq_len=args.max_seq_len)
    config = BertConfig.from_pretrained(args.model_name_or_path,
                                        num_labels=num_labels)
    bert = BertModel(config=config)
    dl_module = FeatureExtract(config, bert)
    return tokenizer, dl_module


def extract_cv(args):
    data_df = pd.read_csv(args.data_path)
    goods_df = pd.read_csv(args.goods_data_path)
    data_df = pd.concat([data_df, goods_df])
    data_df['label'] = data_df['label'].apply(lambda x: str(x))

    data_df['text'] = data_df['text'].apply(lambda x: text_enchance(x))
    data_df = data_df.drop(data_df[(data_df['text'] == '')].index)

    test_data_df = pd.read_csv(args.test_file)
    test_data_df['label'] = 1
    test_data_df['label'] = test_data_df['label'].apply(lambda x: str(x))
    test_data_df['text'] = test_data_df['text'].apply(
        lambda x: text_enchance(x))
    test_data_df.loc[(test_data_df['text'] == ''), 'text'] = '比赛占位字符'

    test_dataset = SentenceClassificationDataset(
        test_data_df, categories=sorted(data_df['label'].unique()))
    tokenizer, _ = build_model_and_tokenizer(
        args, len(test_dataset.cat2id))

    test_dataset.convert_to_ids(tokenizer)
    test_generator = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.num_workers)

    def extract(generator):
        feature_vector = []

        with torch.no_grad():
            for inputs in tqdm(generator):
                inputs['input_ids'] = inputs['input_ids'].to(
                    torch.device(args.device))
                inputs['attention_mask'] = inputs['attention_mask'].to(
                    torch.device(args.device))
                inputs['token_type_ids'] = inputs['token_type_ids'].to(
                    torch.device(args.device))
                inputs['label_ids'] = inputs['label_ids'].to(
                    torch.device(args.device))

                outputs = model(**inputs)
                feature_vector.append(outputs.cpu().numpy())
        feature_vector = np.vstack(feature_vector)
        return feature_vector

    kfold = StratifiedKFold(
        n_splits=args.fold, shuffle=True, random_state=args.seed)
    args.checkpoint = os.path.join(args.checkpoint, args.model_type)
    model_type = args.model_type

    train_data = []
    dev_data = []
    test_data = []

    for fold, (train_idx, dev_idx) in enumerate(kfold.split(data_df, data_df['label'])):
        print(f'========== {fold + 1} ==========')
        args.model_type = f'{model_type}-{fold + 1}'
        args.predict_model = os.path.join(
            args.checkpoint, args.model_type, 'best_model.pth')

        train_data_df, dev_data_df = data_df.iloc[train_idx], data_df.iloc[dev_idx]
        train_dataset = SentenceClassificationDataset(
            train_data_df, categories=sorted(train_data_df['label'].unique()))
        dev_dataset = SentenceClassificationDataset(
            dev_data_df, categories=train_dataset.categories)

        tokenizer, model = build_model_and_tokenizer(
            args, len(train_dataset.cat2id))

        train_dataset.convert_to_ids(tokenizer)
        dev_dataset.convert_to_ids(tokenizer)

        train_generator = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            pin_memory=True,
            num_workers=args.num_workers)

        dev_generator = DataLoader(
            dev_dataset,
            batch_size=args.batch_size,
            pin_memory=True,
            num_workers=args.num_workers)

        model.load_state_dict(torch.load(args.predict_model), strict=False)
        model.to(torch.device(args.device))
        print('抽取train')
        train_data.append([extract(train_generator), fold])
        print('抽取val')
        dev_data.append([extract(dev_generator), fold])
        print('抽取test')
        test_data.append([extract(test_generator), fold])

        del model, tokenizer
        gc.collect()
        torch.cuda.empty_cache()
    return train_data, dev_data, test_data


args = Args()

train_data, dev_data, test_data = extract_cv(args)


抽取train


100%|██████████| 1971/1971 [02:25<00:00, 13.50it/s]


抽取val


100%|██████████| 493/493 [00:36<00:00, 13.63it/s]


抽取test


100%|██████████| 313/313 [00:22<00:00, 13.63it/s]


抽取train


100%|██████████| 1971/1971 [02:24<00:00, 13.69it/s]


抽取val


100%|██████████| 493/493 [00:35<00:00, 13.71it/s]


抽取test


100%|██████████| 313/313 [00:22<00:00, 13.68it/s]


抽取train


100%|██████████| 1971/1971 [02:24<00:00, 13.67it/s]


抽取val


100%|██████████| 493/493 [00:36<00:00, 13.64it/s]


抽取test


100%|██████████| 313/313 [00:22<00:00, 13.64it/s]


抽取train


100%|██████████| 1971/1971 [02:24<00:00, 13.66it/s]


抽取val


100%|██████████| 493/493 [00:36<00:00, 13.62it/s]


抽取test


100%|██████████| 313/313 [00:22<00:00, 13.61it/s]


抽取train


100%|██████████| 1971/1971 [02:24<00:00, 13.64it/s]


抽取val


100%|██████████| 493/493 [00:36<00:00, 13.66it/s]


抽取test


100%|██████████| 313/313 [00:22<00:00, 13.64it/s]


In [6]:
train = pd.DataFrame(train_data, columns=['vector', 'fold'])

In [7]:
train

Unnamed: 0,vector,fold
0,"[[-0.58633494, 0.02467776, -0.7460927, -0.6725...",0
1,"[[-0.29947793, 0.28890938, 0.30287006, -0.6648...",1
2,"[[-0.66576487, 0.32492572, 0.021158926, -0.672...",2
3,"[[0.0023959056, 1.0785404e-05, -0.608786, -0.0...",3
4,"[[0.30451664, 0.14475796, -0.75497985, 0.01594...",4
