In [37]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount and print the full path of every file found,
# so the datasets available to this session are visible in the cell output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [38]:
import os
import sys
import torch
from tqdm import tqdm
from transformers import BertPreTrainedModel, BertModel
from transformers.models.bert.modeling_bert import BertOnlyMLMHead
from transformers import get_linear_schedule_with_warmup, AdamW
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from nltk.corpus import stopwords
from collections import defaultdict

In [39]:
from pydantic import BaseModel
from typing import Type

class GolbalOption(BaseModel):
    """Global configuration for the LOTClass notebook.

    Static settings (paths, file names, hyper-parameters) carry defaults; the
    runtime fields at the bottom are assigned after the object is created.

    Every field has an explicit type annotation. pydantic v1 inferred the
    field type of a bare ``name = value`` attribute, but pydantic v2 refuses
    to build a model that contains un-annotated attributes.
    """
    # Directory for locally cached tensors and model checkpoints
    local_model_dir: str = 'local_model'
    # Directory that holds the dataset
    dataset_dir: str = '/kaggle/input/agnews/'
    # File with the label names (one class per line)
    label_names_file: str = 'label_names.txt'
    # Training corpus and its encoded cache file
    train_file: str = 'train.txt'
    train_load_file: str = 'train.pt'
    # Test corpus, its labels, and its encoded cache file
    test_file: str = 'test.txt'
    test_label_file: str = 'test_labels.txt'
    test_load_file: str = 'test.pt'
    # Cache file for the data used to build the category vocabulary
    label_name_load_file: str = 'label_name_data.pt'
    # Cache file for the category vocabulary itself
    category_vocab_load_file: str = 'category_vocab.pt'
    # Cache files for the MCP task (training data and trained weights)
    mcp_train_load_file: str = 'train_data.pt'
    mcp_load_file: str = 'mcp_model.pt'
    # Final model checkpoint and output file
    final_model: str = 'final_model.pt'
    out_file: str = 'out.txt'
    # Training hyper-parameters
    eval_batch_size: int = 128
    train_batch_size: int = 32
    top_pred_num: int = 50
    category_vocab_size: int = 100
    match_threshold: int = 20
    max_len: int = 200
    update_interval: int = 50
    accum_steps: int = 4
    mcp_epochs: int = 5
    self_train_epochs: int = 1
    early_stop: bool = True

    bert_model: str = 'bert-base-uncased'

    # Runtime settings, assigned after instantiation.
    # ``None`` defaults are kept exactly as before; they are placeholders that
    # the setup cell overwrites.
    device: str = None
    # Required at construction time (no default).
    tokenizer: Type
    vocab: dict = None
    vocab_size: int = None
    mask_id: int = None
    inv_vocab: dict = None
    label_name_dict: dict = None
    label2class: dict = None
    num_class: int = None

    

In [40]:
def read_label_names(opt):
    """Read the label-name file and build the class lookup tables.

    Each line of the file describes one class; the zero-based line number is
    the class id and the whitespace-separated words on the line (lower-cased)
    are the label names of that class.

    Args:
        opt: options object providing ``dataset_dir`` and ``label_names_file``.

    Returns:
        tuple: ``(label_name_dict, label2class)`` where ``label_name_dict``
        maps class id -> list of label words and ``label2class`` maps each
        label word -> class id.

    Raises:
        AssertionError: if the same word is used as a label name more than once.
    """
    label_path = os.path.join(opt.dataset_dir, opt.label_names_file)
    # The context manager closes the handle even if parsing fails; UTF-8 matches
    # how the corpus files are read elsewhere in this notebook.
    with open(label_path, encoding="utf-8") as label_name_file:
        label_names = label_name_file.readlines()
    # {0: [word1, word2, ...], 1: [word1, word2, ...], ...}
    label_name_dict = {
        class_idx: [word.lower() for word in category_words.strip().split()]
        for class_idx, category_words in enumerate(label_names)
    }
    print(f"每个类别使用的标签名称分别是: {label_name_dict}")
    # Reverse mapping: label word -> class id
    label2class = {}
    for class_idx, words in label_name_dict.items():
        for word in words:
            # A label word must identify exactly one class
            assert word not in label2class, f"\"{word}\" 作为标签名，被应用在多分类任务中"
            label2class[word] = class_idx
    return label_name_dict, label2class

In [41]:
from transformers import BertTokenizerFast

# Create the options object. ``tokenizer`` has no default, so the tokenizer
# class is passed here and replaced by the loaded instance below.
opt= GolbalOption(
    tokenizer=BertTokenizerFast
)
# Training device
opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Tokenizer instance for the configured BERT checkpoint
opt.tokenizer = BertTokenizerFast.from_pretrained(opt.bert_model, model_max_length=opt.max_len)
# Vocabulary: token -> id
opt.vocab = opt.tokenizer.get_vocab()
# Vocabulary size
opt.vocab_size = len(opt.vocab)
# Id of the [MASK] token
opt.mask_id = opt.vocab[opt.tokenizer.mask_token]
# Inverse vocabulary: id -> token
opt.inv_vocab = {k:v for v, k in opt.vocab.items()}
# Label information read from the label-name file
opt.label_name_dict, opt.label2class = read_label_names(opt)
# Number of classes
opt.num_class = len(opt.label_name_dict)

# Make sure the local cache directory exists. ``exist_ok=True`` replaces the
# separate existence check, so there is no gap between checking and creating.
os.makedirs(opt.local_model_dir, exist_ok=True)

# print(opt.device)
# print(opt.tokenizer('have fun!'))
# print(len(opt.vocab))
# print(opt.num_class)

#### LOTClass模型

In [42]:
class LOTClassModel(BertPreTrainedModel):
    """BERT encoder with two heads: a frozen MLM head and a classification head.

    ``forward`` selects the head through ``pred_mode``:
    ``"mlm"`` runs the masked-language-model head, ``"classification"`` runs
    dense -> tanh -> dropout -> linear on the encoder's hidden states.
    """

    def __init__(self, config):
        """Build the encoder and both heads from a BERT ``config``."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()
        # MLM head is not trained
        for param in self.cls.parameters():
            param.requires_grad = False

    def forward(self, input_ids, pred_mode, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None):
        """Encode the input and apply the head chosen by ``pred_mode``.

        Args:
            input_ids: token ids passed to the BERT encoder.
            pred_mode: ``"classification"`` or ``"mlm"``.
            attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds: forwarded unchanged to the encoder.

        Returns:
            Logits produced by the selected head for the encoder's first output.

        Raises:
            ValueError: if ``pred_mode`` is neither ``"classification"`` nor ``"mlm"``.
        """
        bert_outputs = self.bert(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask,
                                 inputs_embeds=inputs_embeds)
        last_hidden_states = bert_outputs[0]
        if pred_mode == "classification":
            trans_states = self.dense(last_hidden_states)
            trans_states = self.activation(trans_states)
            trans_states = self.dropout(trans_states)
            logits = self.classifier(trans_states)
        elif pred_mode == "mlm":
            logits = self.cls(last_hidden_states)
        else:
            # Raise a regular exception instead of calling sys.exit(): a bad
            # argument should surface as a traceback, not as SystemExit.
            raise ValueError(f"Wrong pred_mode: {pred_mode!r}")
        return logits

#### 创建分类词汇表

In [43]:
def create_input_dataset(opt, text_file, loader_name, label_file=None):
    """Encode a text corpus into model input tensors, with an on-disk cache.

    If ``loader_name`` already exists under ``opt.local_model_dir`` it is loaded
    and returned. Otherwise every line of ``text_file`` is tokenized (padded and
    truncated to ``opt.max_len``), optional integer labels are read from
    ``label_file``, and the result is saved to the cache file.

    Args:
        opt: options object providing ``local_model_dir``, ``dataset_dir``,
            ``tokenizer`` and ``max_len``.
        text_file: corpus file name inside ``opt.dataset_dir`` (one document per line).
        loader_name: cache file name inside ``opt.local_model_dir``.
        label_file: optional label file name inside ``opt.dataset_dir``
            (one integer per line).

    Returns:
        dict with ``"input_ids"`` and ``"attention_masks"``, plus ``"labels"``
        when ``label_file`` is given.
    """
    loader_file = os.path.join(opt.local_model_dir, loader_name)
    if os.path.exists(loader_file):
        print(f"从 {loader_file} 文件中加载编码后的模型输入张量")
        data = torch.load(loader_file)
    else:
        corpus_path = os.path.join(opt.dataset_dir, text_file)
        print(f"从{corpus_path}文件中读取语料")
        # Context managers make sure the handles are closed after reading
        with open(corpus_path, encoding="utf-8") as corpus:
            docs = [doc.strip() for doc in corpus.readlines()]
        print(f"转换文本为tensor集合")
        encoded_dict = opt.tokenizer(docs, add_special_tokens=True, max_length=opt.max_len, padding='max_length',
                                  return_attention_mask=True, truncation=True, return_tensors='pt')
        input_ids = encoded_dict['input_ids']
        attention_masks = encoded_dict['attention_mask']

        print(f"编码后的文本语料存入文件:{loader_file}")
        data = {"input_ids": input_ids, "attention_masks": attention_masks}
        if label_file is not None:
            label_path = os.path.join(opt.dataset_dir, label_file)
            print(f"从 {label_path} 文件中读取Label")
            with open(label_path, encoding="utf-8") as truth:
                labels = [int(label.strip()) for label in truth.readlines()]
            data["labels"] = torch.tensor(labels)
        torch.save(data, loader_file)
    return data

def create_label_name_dataset(opt, text_file, loader_name):
    """Build (or load) the tensors used to construct the category vocabulary.

    If the cache file exists it is loaded. Otherwise the corpus is read, the
    documents that contain a label name are selected and encoded through
    ``label_name_occurrence``, and the result is cached.

    Args:
        opt: options object providing ``local_model_dir`` and ``dataset_dir``
            (and whatever ``label_name_occurrence`` reads from it).
        text_file: corpus file name inside ``opt.dataset_dir``.
        loader_name: cache file name inside ``opt.local_model_dir``.

    Returns:
        dict with ``"input_ids"``, ``"attention_masks"`` and ``"labels"``.

    Raises:
        AssertionError: if no document in the corpus contains a label name.
    """
    loader_file = os.path.join(opt.local_model_dir, loader_name)
    if os.path.exists(loader_file):
        print(f"从 {loader_file} 文件中加载包含标签名的语料张量")
        label_name_data = torch.load(loader_file)
    else:
        corpus_path = os.path.join(opt.dataset_dir, text_file)
        print(f"从 {corpus_path} 文件中读取语料")
        # Close the corpus file as soon as its lines are in memory
        with open(corpus_path, encoding="utf-8") as corpus:
            docs = [doc.strip() for doc in corpus.readlines()]
        print("检索包含类别词汇的语料")
        input_ids_with_label_name, \
        attention_masks_with_label_name, \
        label_name_idx = label_name_occurrence(opt, docs)

        assert len(input_ids_with_label_name) > 0, "语料中没有发现匹配的标签名!"
        label_name_data = {
            "input_ids": input_ids_with_label_name,
            "attention_masks": attention_masks_with_label_name,
            "labels": label_name_idx
        }
        # Persist the result so later runs can skip the corpus scan
        print(f"包含标签名的语料张量存入文件 {loader_file}")
        torch.save(label_name_data, loader_file)
    return label_name_data

def label_name_occurrence(opt, docs):
    """Select and encode the documents that contain at least one label name.

    Args:
        opt: options object providing ``tokenizer`` and ``max_len``.
        docs: iterable of raw document strings.

    Returns:
        tuple ``(input_ids, attention_masks, label_name_idx)``. When no document
        matches, all three are long tensors of shape ``(0, opt.max_len)``.
    """
    # Keep only the documents for which a label name was found
    candidates = (label_name_in_doc(opt, doc) for doc in tqdm(docs, desc='文档检索'))
    matches = [found for found in candidates if found is not None]

    if not matches:
        # Nothing matched: hand back three empty placeholders
        return (torch.ones(0, opt.max_len, dtype=torch.long),
                torch.ones(0, opt.max_len, dtype=torch.long),
                torch.ones(0, opt.max_len, dtype=torch.long))

    texts = [text for text, _ in matches]
    encoded = opt.tokenizer(texts, add_special_tokens=True, max_length=opt.max_len,
                            padding='max_length', return_attention_mask=True, truncation=True, return_tensors='pt')
    # One row of label positions per matching document
    positions = torch.stack([label_pos for _, label_pos in matches], dim=0)
    return encoded['input_ids'], encoded['attention_mask'], positions

def label_name_in_doc(opt, doc):
    """Locate label names in one document and record their positions.

    The document is tokenized, word pieces are merged back into whole words,
    and every word found in ``opt.label2class`` gets its class id written at
    its position in a ``-1``-filled tensor of length ``opt.max_len``. A label
    name that is missing from the tokenizer vocabulary is replaced by the mask
    token; words equal to the unknown token are dropped from the rebuilt text.

    Returns:
        ``(rebuilt_text, label_positions)`` if any label name was found,
        otherwise ``None``.
    """
    tokens = opt.tokenizer.tokenize(doc)
    final_pos = len(tokens) - 1
    positions = torch.full((opt.max_len,), -1, dtype=torch.long)
    rebuilt = []
    pieces = []
    cursor = 1  # position 0 is taken by [CLS]
    for pos, piece in enumerate(tokens):
        pieces.append(piece[2:] if piece.startswith("##") else piece)
        # The last usable slot is reserved for [SEP]
        if cursor >= opt.max_len - 1:
            break
        # Keep collecting while the next token continues the current word
        if pos != final_pos and tokens[pos + 1].startswith("##"):
            continue
        word = ''.join(pieces)
        if word in opt.label2class:
            positions[cursor] = opt.label2class[word]
            # Out-of-vocabulary label names are swapped for the mask token
            if word not in opt.tokenizer.get_vocab():
                pieces = [opt.tokenizer.mask_token]
        surface = ''.join(pieces)
        if surface != opt.tokenizer.unk_token:
            cursor += len(pieces)
            rebuilt.append(surface)
        pieces = []
    if not (positions >= 0).any():
        return None
    return ' '.join(rebuilt), positions

def make_dataloader(data_dict, batch_size):
    """Wrap an encoded data dictionary in a shuffling ``DataLoader``.

    The dataset holds ``input_ids`` and ``attention_masks``, followed by
    ``labels`` when that key is present in ``data_dict``.
    """
    fields = ["input_ids", "attention_masks"]
    if "labels" in data_dict:
        fields.append("labels")
    tensors = [data_dict[field] for field in fields]
    return DataLoader(TensorDataset(*tensors), batch_size=batch_size, shuffle=True)

In [44]:
def create_category_vocabulary(opt, model, label_name_data, loader_name, top_pred_num=50, category_vocab_size=100):
    """Build (or load) the category vocabulary.

    For every position that holds a label name, the MLM head's ``top_pred_num``
    highest-scoring token ids are counted under that position's class. The
    counts are then reduced by ``filter_keywords`` and the result is cached.

    Args:
        opt: options object (``local_model_dir``, ``eval_batch_size``,
            ``num_class``, ``device``, ``inv_vocab``).
        model: model called with ``pred_mode="mlm"``.
        label_name_data: dict with ``input_ids``, ``attention_masks`` and
            ``labels`` (class id at label-name positions, negative elsewhere).
        loader_name: cache file name inside ``opt.local_model_dir``.
        top_pred_num: number of top MLM predictions counted per position.
        category_vocab_size: forwarded to ``filter_keywords``.

    Returns:
        dict mapping class id -> collection of token ids.
    """
    loader_file = os.path.join(opt.local_model_dir, loader_name)
    if os.path.exists(loader_file):
        print(f"从 {loader_file} 文件中加载分类词汇表")
        category_vocab = torch.load(loader_file)
    else:
        print("构建分类词汇表")

        model.eval()
        # Batches are (input_ids, attention_masks, labels)
        label_name_dataset_loader = make_dataloader(label_name_data, opt.eval_batch_size)
        # Per-class token frequency: {0: {}, 1: {}, ...}
        category_words_freq = {i: defaultdict(float) for i in range(opt.num_class)}

        for batch in tqdm(label_name_dataset_loader):
            with torch.no_grad():
                input_ids = batch[0].to(opt.device)
                input_mask = batch[1].to(opt.device)
                label_pos = batch[2].to(opt.device)
                match_idx = label_pos >= 0
                # MLM inference
                predictions = model(input_ids,
                                    pred_mode="mlm",
                                    token_type_ids=None,
                                    attention_mask=input_mask)
                # Top predictions at the label-name positions only
                _, sorted_res = torch.topk(predictions[match_idx], top_pred_num, dim=-1)
                # Class id of each of those positions
                label_idx = label_pos[match_idx]
                # Convert once per batch instead of calling .item() for every
                # element, which would synchronise with the device each time.
                for class_id, word_ids in zip(label_idx.tolist(), sorted_res.tolist()):
                    class_freq = category_words_freq[class_id]
                    for word_id in word_ids:
                        class_freq[word_id] += 1

        # Drop stop words and words shared between classes.
        # Result structure: {category_id: [token_id, ...], ...}
        category_vocab = filter_keywords(opt, category_words_freq, category_vocab_size)

        # Cache the vocabulary
        torch.save(category_vocab, loader_file)
    for i, cat_vocab in category_vocab.items():
        print(f"Class {i} category vocabulary: {[opt.inv_vocab[w] for w in cat_vocab]}\n")

    return category_vocab

def filter_keywords(opt, category_words_freq, category_vocab_size=100):
    """Reduce per-class token counts to a filtered category vocabulary.

    For each class the ``category_vocab_size`` most frequent token ids are
    kept. From those, a token is removed when it is not purely alphabetic, has
    length 1, is an English stop word, or appears in the kept set of more than
    one class. Tokens that are label names of the class are always kept.

    Returns:
        dict mapping class id -> numpy array of token ids.
    """
    # Top-N token counts per class, and the classes each token shows up in
    top_counts = {}
    owners = defaultdict(list)
    for class_id, freq in category_words_freq.items():
        ranked = sorted(freq.items(), key=lambda pair: pair[1], reverse=True)
        kept = dict(ranked[:category_vocab_size])
        top_counts[class_id] = kept
        for token_id in kept:
            owners[token_id].append(class_id)

    # Tokens kept for more than one class
    shared_tokens = [token_id for token_id, classes in owners.items() if len(classes) > 1]

    category_vocab = {class_id: np.array(list(kept.keys())) for class_id, kept in top_counts.items()}

    # nltk stopwords
    stopwords_vocab = stopwords.words('english')
    for class_id, token_ids in category_vocab.items():
        drop_positions = []
        for position, token_id in enumerate(token_ids):
            token = opt.inv_vocab[token_id]
            is_label_name = token in opt.label_name_dict[class_id]
            if not is_label_name and (
                not token.isalpha()
                or len(token) == 1
                or token in stopwords_vocab
                or token_id in shared_tokens
            ):
                drop_positions.append(position)
        category_vocab[class_id] = np.delete(category_vocab[class_id], drop_positions)

    return category_vocab

#### Masked Category Prediction (MCP)

In [45]:
# mask分类预测
def mcp(opt, model, train_data, category_vocab, top_pred_num=50, match_threshold=20):
    """Run (or restore) masked category prediction training.

    If the MCP checkpoint is missing, the training data is prepared and the
    model is trained, which writes the checkpoint. In both cases the
    checkpoint is then loaded into ``model``, which is returned.
    """
    checkpoint = os.path.join(opt.local_model_dir, opt.mcp_load_file)
    if not os.path.exists(checkpoint):
        # No checkpoint yet: build the MCP data set and train on it
        mcp_data = prepare_mcp(opt, model, train_data, category_vocab, top_pred_num, match_threshold)
        print(f"\n通过masked进行类别预测的模型训练")
        mcp_train(opt, model, mcp_data)
    else:
        print(f"\n从 {checkpoint} 文件中加载通过masked进行类别预测训练的模型")
    model.load_state_dict(torch.load(checkpoint))
    return model

# masked 分类预测
def mcp_train(opt, model, mcp_data):
    """Train the classification head on masked category prediction data.

    Positions with a non-negative label are replaced by the mask token and the
    model is trained to predict their class with cross-entropy. Gradients are
    accumulated over ``opt.accum_steps`` batches. The trained weights are
    saved to ``opt.mcp_load_file`` under ``opt.local_model_dir``.
    """
    loader = make_dataloader(mcp_data, opt.train_batch_size)
    batches_per_epoch = len(loader)
    total_steps = batches_per_epoch * opt.mcp_epochs / opt.accum_steps
    criterion = nn.CrossEntropyLoss()
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = AdamW(trainable_params, lr=2e-5, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)

    for epoch in range(opt.mcp_epochs):
        model.train()
        running_loss = 0
        print(f"Epoch {epoch+1}:")
        model.zero_grad()
        for step, batch in enumerate(tqdm(loader), start=1):
            input_ids = batch[0].to(opt.device)
            input_mask = batch[1].to(opt.device)
            token_labels = batch[2].to(opt.device)
            mask_pos = token_labels >= 0
            targets = token_labels[mask_pos]
            # Hide the category-indicative tokens from the model
            input_ids[mask_pos] = opt.mask_id
            logits = model(input_ids,
                           pred_mode="classification",
                           token_type_ids=None,
                           attention_mask=input_mask)
            # Only the masked positions contribute to the loss
            masked_logits = logits[mask_pos]
            loss = criterion(masked_logits.view(-1, opt.num_class), targets.view(-1)) / opt.accum_steps
            running_loss += loss.item()
            loss.backward()
            if step % opt.accum_steps == 0:
                # Clip the gradient norm to 1.0 before the update
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
        avg_train_loss = torch.tensor([running_loss / batches_per_epoch * opt.accum_steps]).to(opt.device)

        print(f"Average training loss: {avg_train_loss.mean().item()}")

    checkpoint = os.path.join(opt.local_model_dir, opt.mcp_load_file)
    torch.save(model.state_dict(), checkpoint)
    
# 为自监督的masked分类预测做准备
def prepare_mcp(opt, model, train_data, category_vocab, top_pred_num=50, match_threshold=20):
    """Build (or load) the training data for masked category prediction.

    For every token the MLM head's ``top_pred_num`` predictions are compared
    with each class's category vocabulary. A non-padding token whose
    predictions share more than ``match_threshold`` entries with a class
    vocabulary is labelled with that class; documents with at least one such
    token are collected. The result is cached in ``opt.mcp_train_load_file``.

    Args:
        opt: options object (``local_model_dir``, ``mcp_train_load_file``,
            ``eval_batch_size``, ``device``).
        model: model called with ``pred_mode="mlm"``.
        train_data: dict accepted by ``make_dataloader``.
        category_vocab: dict mapping class id -> iterable of token ids.
        top_pred_num: number of top MLM predictions inspected per token.
        match_threshold: minimum overlap (exclusive) with a class vocabulary.

    Returns:
        dict with ``input_ids``, ``attention_masks``, ``labels`` and
        ``category_doc_num`` (documents collected per class).

    Raises:
        AssertionError: if any class in ``category_vocab`` has 10 or fewer
            collected documents.
    """
    loader_file = os.path.join(opt.local_model_dir, opt.mcp_train_load_file)
    if os.path.exists(loader_file):
        print(f"Loading masked category prediction data from {loader_file}")
        mcp_data = torch.load(loader_file)
    else:
        print("准备自监督的masked分类预测")
        model.eval()
        train_dataset_loader = make_dataloader(train_data, opt.eval_batch_size)
        all_input_ids = []
        all_mask_label = []
        all_input_mask = []
        category_doc_num = defaultdict(int)

        for batch in tqdm(train_dataset_loader):
            with torch.no_grad():
                input_ids = batch[0].to(opt.device)
                input_mask = batch[1].to(opt.device)
                predictions = model(input_ids,
                                    pred_mode="mlm",
                                    token_type_ids=None,
                                    attention_mask=input_mask)
                # Top predicted vocabulary ids for every token
                _, sorted_res = torch.topk(predictions, top_pred_num, dim=-1)
                for i, cat_vocab in category_vocab.items():
                    # True where a top prediction belongs to this class's vocabulary
                    match_idx = torch.zeros_like(sorted_res).bool()
                    for word_id in cat_vocab:
                        match_idx = (sorted_res == word_id) | match_idx
                    # Number of vocabulary hits per token
                    match_count = torch.sum(match_idx.int(), dim=-1)
                    # Category-indicative tokens: enough hits and not padding
                    valid_idx = (match_count > match_threshold) & (input_mask > 0)
                    # Documents with at least one such token
                    valid_doc = torch.sum(valid_idx, dim=-1) > 0
                    if valid_doc.any():
                        # -1 everywhere, class id at the indicative positions
                        mask_label = -1 * torch.ones_like(input_ids)
                        mask_label[valid_idx] = i
                        all_input_ids.append(input_ids[valid_doc].cpu())
                        all_mask_label.append(mask_label[valid_doc].cpu())
                        all_input_mask.append(input_mask[valid_doc].cpu())
                        # Count the documents collected for this class
                        category_doc_num[i] += valid_doc.int().sum().item()
        all_input_ids = torch.cat(all_input_ids, dim=0)
        all_mask_label = torch.cat(all_mask_label, dim=0)
        all_input_mask = torch.cat(all_input_mask, dim=0)

        mcp_data = {
            "input_ids": all_input_ids,
            "attention_masks": all_input_mask,
            "labels": all_mask_label,
            "category_doc_num": category_doc_num
        }
        # Cache the prepared data
        torch.save(mcp_data, loader_file)

    # Always read the counts from mcp_data: the local ``category_doc_num`` only
    # exists when the data was built in this call, not when it came from cache.
    doc_counts = mcp_data["category_doc_num"]
    print(f"为每个类别找到的带有类别指示性术语的文档数量为: {doc_counts}")
    # Check every class of the vocabulary, so that a class for which no
    # document at all was collected is reported instead of silently skipped.
    for i in category_vocab:
        found = doc_counts.get(i, 0)
        assert found > 10, f"为类别{i}找到的带有类别指示性词汇({found})的文档太少;" \
                "尝试向训练语料库添加更多未标记文档(推荐)或减少`-- match_threshold `参数的值(不推荐)"
    print(f"共有{len(mcp_data['input_ids'])}个文档带有类别指示性词汇。")

    return mcp_data

#### Self-training

In [46]:
def self_train(opt, model, train_data, test_data=None):
    """Self-train the classifier on unlabeled data, or restore the final model.

    If the final checkpoint already exists, training is skipped and the
    checkpoint is loaded into ``model`` so the in-memory weights match the
    file (the same way ``mcp`` restores its checkpoint). Otherwise the
    training data is shuffled and processed in rounds: each round builds soft
    targets with ``prepare_self_train_data`` and trains on them with
    ``self_train_batches``. The final weights are saved to ``opt.final_model``.

    Args:
        opt: options object with the training hyper-parameters and paths.
        model: model to train; modified in place.
        train_data: dict with ``input_ids`` and ``attention_masks``.
        test_data: optional dict accepted by ``make_dataloader``; when given,
            it is passed on for evaluation after every round.
    """
    loader_file = os.path.join(opt.local_model_dir, opt.final_model)
    if os.path.exists(loader_file):
        print(f"\n发现 {loader_file} 中的最终模型存盘, 跳过self_training")
        # Without this load the caller's model would keep its pre-self-training
        # weights although the message says the final model is being used.
        model.load_state_dict(torch.load(loader_file, map_location=opt.device))
    else:
        # Random permutation of the training documents
        rand_idx = torch.randperm(len(train_data["input_ids"]))
        train_data = {
            "input_ids": train_data["input_ids"][rand_idx],
            "attention_masks": train_data["attention_masks"][rand_idx]
        }
        print(f"\nStart self-training.")

        # Build the test loader only when test data was supplied
        test_dataset_loader = make_dataloader(test_data, opt.eval_batch_size) if test_data is not None else None
        # Number of optimizer steps
        total_steps = int(len(train_data["input_ids"]) * opt.self_train_epochs / (opt.train_batch_size * opt.accum_steps))
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-6, eps=1e-8)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)
        idx = 0
        if opt.early_stop:
            agree_count = 0
        for _ in range(int(total_steps / opt.update_interval)):
            # Soft targets for the next slice of the training data
            self_train_dict, idx, agree = prepare_self_train_data(opt, model, train_data, idx)
            # Stop early once predictions matched the updated targets three
            # rounds in a row
            if opt.early_stop:
                if 1 - agree < 1e-3:
                    agree_count += 1
                else:
                    agree_count = 0
                if agree_count >= 3:
                    break
            self_train_dataset_loader = make_dataloader(self_train_dict, opt.train_batch_size)
            self_train_batches(opt, model, self_train_dataset_loader, optimizer, scheduler, test_dataset_loader)

        print(f"保存训练好的模型到 {loader_file} 文件")
        torch.save(model.state_dict(), loader_file)

def prepare_self_train_data(opt, model, train_data, idx):
    """Select the next slice of training data and compute its soft targets.

    A window of ``train_batch_size * update_interval * accum_steps`` documents
    (capped at the data set size) is taken starting at ``idx``, wrapping around
    to the beginning when it runs past the end. The model's class
    probabilities for the window are sharpened into a target distribution.

    Returns:
        tuple ``(self_train_dict, next_idx, agree)`` where ``self_train_dict``
        holds ``input_ids``, ``attention_masks`` and ``labels`` (the target
        distribution), ``next_idx`` is the start of the following window, and
        ``agree`` is the fraction of documents whose predicted class equals the
        target's class.
    """
    corpus_size = len(train_data["input_ids"])
    # Window size: batch_size * update_interval * accum_steps
    target_num = min(opt.train_batch_size * opt.update_interval * opt.accum_steps, corpus_size)
    window_end = idx + target_num
    if window_end < corpus_size:
        select_idx = torch.arange(idx, window_end)
    else:
        # Wrap around: tail of the data followed by the head
        select_idx = torch.cat((torch.arange(idx, corpus_size),
                                torch.arange(window_end - corpus_size)))
    assert len(select_idx) == target_num
    idx = (idx + len(select_idx)) % corpus_size

    window = {key: train_data[key][select_idx] for key in ("input_ids", "attention_masks")}
    window_loader = make_dataloader(window, opt.eval_batch_size)
    # Class probabilities taken from the first token position
    input_ids, input_mask, all_preds = inference(opt, model, window_loader, return_type="data")

    # soft labeling
    weight = all_preds**2 / torch.sum(all_preds, dim=0)
    target_dist = (weight.t() / torch.sum(weight, dim=1)).t()
    all_target_pred = target_dist.argmax(dim=-1)

    same_class = all_preds.argmax(dim=-1) == all_target_pred
    agree = same_class.int().sum().item() / len(all_target_pred)

    self_train_dict = {
        "input_ids": input_ids,
        "attention_masks": input_mask,
        "labels": target_dist
    }
    return self_train_dict, idx, agree
    
def self_train_batches(opt, model, self_train_loader, optimizer, scheduler, test_dataset_loader):
    """Train the model for one round on batches that carry target distributions.

    The KL divergence between the log-softmax of the first-position logits and
    the target distribution is minimised, with gradients accumulated over
    ``opt.accum_steps`` batches. The learning rate and average loss are
    printed; when ``test_dataset_loader`` is given, test accuracy is printed too.

    Args:
        opt: options object (``device``, ``num_class``, ``accum_steps``).
        model: model to train; modified in place.
        self_train_loader: loader yielding ``(input_ids, attention_mask, target_dist)``.
        optimizer: optimizer stepped after each accumulation window.
        scheduler: learning-rate scheduler stepped together with the optimizer.
        test_dataset_loader: optional loader passed to ``inference`` for accuracy.
    """
    st_loss = nn.KLDivLoss(reduction='batchmean')
    model.train()
    total_train_loss = 0
    wrap_train_dataset_loader = tqdm(self_train_loader,desc='self training')
    model.zero_grad()

    for j, batch in enumerate(wrap_train_dataset_loader):
        input_ids = batch[0].to(opt.device)
        input_mask = batch[1].to(opt.device)
        target_dist = batch[2].to(opt.device)
        logits = model(input_ids,
                        pred_mode="classification",
                        token_type_ids=None,
                        attention_mask=input_mask)
        # Keep only the first position's logits
        logits = logits[:, 0, :]
        # KLDivLoss expects log-probabilities as its input
        preds = nn.LogSoftmax(dim=-1)(logits)
        loss = st_loss(preds.view(-1, opt.num_class), target_dist.view(-1, opt.num_class)) / opt.accum_steps
        total_train_loss += loss.item()
        loss.backward()
        if (j+1) % opt.accum_steps == 0:
            # Clip the norm of the gradients to 1.0.
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            model.zero_grad()
    if test_dataset_loader is not None:
        # inference() already returns a scalar tensor; read it directly rather
        # than re-wrapping it in torch.tensor(), which warns on tensor input.
        acc = inference(opt, model, test_dataset_loader, return_type="acc").item()

    avg_train_loss = torch.tensor([total_train_loss / len(wrap_train_dataset_loader) * opt.accum_steps]).to(opt.device)

    print(f"lr: {optimizer.param_groups[0]['lr']:.4g}")
    print(f"Average training loss: {avg_train_loss.mean().item()}")
    if test_dataset_loader is not None:
        print(f"Test acc: {acc}")
        
def inference(opt, model, dataset_loader, return_type):
    """Run the classifier over a loader and return what ``return_type`` asks for.

    ``return_type`` selects the result:
      * ``"data"`` -> ``(input_ids, attention_masks, class_probabilities)``
      * ``"acc"``  -> accuracy against the labels in the third batch element,
        as a tensor on ``opt.device``
      * ``"pred"`` -> predicted class ids
    """
    want_data = return_type == "data"
    want_acc = return_type == "acc"
    want_pred = return_type == "pred"

    ids_parts, mask_parts, prob_parts = [], [], []
    pred_parts, truth_parts = [], []
    model.eval()

    for batch in tqdm(dataset_loader,desc=f'{return_type} processing'):
        with torch.no_grad():
            input_ids = batch[0].to(opt.device)
            input_mask = batch[1].to(opt.device)
            token_logits = model(input_ids,
                                 pred_mode="classification",
                                 token_type_ids=None,
                                 attention_mask=input_mask)
            # The first position's logits are used for document classification
            doc_logits = token_logits[:,0,:]
            if want_data:
                ids_parts.append(input_ids)
                mask_parts.append(input_mask)
                prob_parts.append(nn.Softmax(dim=-1)(doc_logits))
            elif want_acc:
                truth = batch[2]
                pred_parts.append(torch.argmax(doc_logits, dim=-1).cpu())
                truth_parts.append(truth)
            elif want_pred:
                pred_parts.append(torch.argmax(doc_logits, dim=-1).cpu())

    if want_data:
        return (torch.cat(ids_parts, dim=0),
                torch.cat(mask_parts, dim=0),
                torch.cat(prob_parts, dim=0))
    if want_acc:
        predicted = torch.cat(pred_parts, dim=0)
        expected = torch.cat(truth_parts, dim=0)
        accuracy = (predicted == expected).float().sum() / len(expected)
        return accuracy.to(opt.device)
    if want_pred:
        return torch.cat(pred_parts, dim=0)

#### 模型训练

In [47]:
# Create the model from the pretrained BERT checkpoint, with one output per class
model = LOTClassModel.from_pretrained(
    opt.bert_model,
    output_attentions=False,
    output_hidden_states=False,
    num_labels=opt.num_class
)
model.to(opt.device)

# Step 1: build the category vocabulary from the documents that contain label names
label_name_data = create_label_name_dataset(opt, opt.train_file, opt.label_name_load_file)
category_vocab = create_category_vocabulary(
    opt = opt, 
    model = model, 
    label_name_data = label_name_data, 
    loader_name = opt.category_vocab_load_file, 
    top_pred_num=opt.top_pred_num, 
    category_vocab_size=opt.category_vocab_size
)

# Step 2: masked category prediction training on the encoded training corpus.
# mcp() returns the same model object it was given, with the MCP checkpoint loaded.
input_data = create_input_dataset(opt, opt.train_file, opt.train_load_file)
mcp_model = mcp(
    opt=opt, 
    model=model, 
    train_data=input_data, 
    category_vocab = category_vocab,
    top_pred_num=opt.top_pred_num, 
    match_threshold=opt.match_threshold
)

# Step 3: self-training; the labeled test set is passed along for evaluation
test_data = create_input_dataset(opt, opt.test_file, opt.test_load_file, opt.test_label_file)
self_train(opt, mcp_model, input_data, test_data)

In [48]:

loader_file = os.path.join(opt.local_model_dir, opt.final_model)
# self_train() already writes this checkpoint when it trains. Saving again
# unconditionally would overwrite an existing final model with whatever weights
# are currently in ``model``, so only save when the file is missing.
if os.path.exists(loader_file):
    print(f"{loader_file} 已存在, 跳过重复保存")
else:
    print(f"保存训练好的模型到 {loader_file} 文件")
    torch.save(model.state_dict(), loader_file)