In [6]:
import os
import json
import sys
import importlib
import pandas as pd
from collections import defaultdict

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer

sys.path.append("../")
import easynlp

In [2]:
from easynlp.fewshot_learning.fewshot_dataset import FewshotMultiLayerBaseDataset

# Re-execute the module that defines the dataset class and rebind the class
# name to the freshly loaded definition, so source edits are picked up.
module = importlib.reload(
    importlib.import_module(FewshotMultiLayerBaseDataset.__module__)
)
FewshotMultiLayerBaseDataset = module.FewshotMultiLayerBaseDataset

In [5]:
# NOTE(review): these are machine-specific absolute Windows paths; adjust them
# for the local environment before running.

# Directory of the pretrained BERT checkpoint; it is used for the tokenizer
# here and handed to the dataset constructor further down.
model_dir = r"G:\code\pretrain_model_dir\pai-bert-base-zh"

# Data files of the hierarchical web-page classification task.
train_file = r"G:\dataset\text_classify\网页层次分类\train.csv"
test_file = r"G:\dataset\text_classify\网页层次分类\test.csv"
label_file = r"G:\dataset\text_classify\网页层次分类\label.json"

# The annotation is kept so editors can resolve the tokenizer's methods.
tokenizer: BertTokenizer = BertTokenizer.from_pretrained(model_dir)

In [22]:
# Parse the label file first.
with open(label_file, "r", encoding="utf-8") as f:
    # The ids stored in the JSON are ints; they are converted to strings here.
    label2id = {name: str(raw_id) for name, raw_id in json.load(f).items()}

# Reverse lookup: string id -> label name.
id2label = {str_id: name for name, str_id in label2id.items()}

# Open question: how should the labels of each layer be built? For now assume
# every layer is complete, i.e. every sample's label reaches the last layer.
label_enumerate_values = defaultdict(list)
label_desc = defaultdict(list)
for label_path, label_id in label2id.items():
    # The layer index is the number of ">" separators in the label path.
    depth = label_path.count(">")
    label_enumerate_values[depth].append(label_id)
    # Only the last path segment is used as the label text.
    label_desc[depth].append(label_path.rsplit(">", 1)[-1])

# The entries should still be in order, because they were appended left to
# right while iterating over label2id.
# Pad every description of a layer to that layer's maximum token length.
for idx in sorted(label_desc.keys()):
    cur_list = label_desc[idx]
    # Tokenize each description once and reuse the counts for both the
    # maximum and the per-item padding amount.
    token_counts = [len(tokenizer.tokenize(x)) for x in cur_list]
    cur_max_len = max(token_counts)
    print(cur_max_len)
    # Append literal "[PAD]" markers until every description reaches the
    # layer's maximum token count.
    label_desc[idx] = [
        x + "[PAD]" * (cur_max_len - n_tokens)
        for x, n_tokens in zip(cur_list, token_counts)
    ]

# Serialize the per-layer lists into single strings: items within a layer are
# joined with "," and the layers (in ascending index order) with "@@".
label_enumerate_values_new = "@@".join(
    ",".join(label_enumerate_values[layer])
    for layer in sorted(label_enumerate_values)
)

label_desc_new = "@@".join(
    ",".join(label_desc[layer])
    for layer in sorted(label_desc)
)

4
5


In [23]:
# Show the serialized label ids: layers are separated by "@@" and the ids
# inside one layer by ",".
label_enumerate_values_new

'0,7,13,21,31,37,42@@1,2,3,4,5,6,8,9,10,11,12,14,15,16,17,18,19,20,22,23,24,25,26,27,28,29,30,32,33,34,35,36,38,39,40,41,43,44,45,46,47'

In [24]:
# Show the serialized label descriptions (same "@@" / "," layout as the ids),
# including the literal "[PAD]" markers appended during padding.
label_desc_new

'休闲娱乐,电脑网络,商业经济,生活服务,教育文化,博客论坛,综合其他@@影视音乐[PAD],游戏动漫[PAD],图片小说[PAD],聊天交友[PAD],娱乐其他[PAD],休闲健身[PAD],商务门户[PAD],网站资源[PAD],软硬件通信,网络其他[PAD],网络营销[PAD],农林牧渔[PAD],工业制品[PAD],机械电子[PAD],建筑环境[PAD],法律金融[PAD],商务服务[PAD],交通物流[PAD],生活用品[PAD],餐饮美食[PAD],房产家居[PAD],旅游交通[PAD],百货购物[PAD],医疗保健[PAD],时尚美容[PAD],生活常识[PAD],生活其他[PAD],高校教育[PAD],人力资源[PAD],高级教育[PAD],文体艺术[PAD],文化其他[PAD],休闲娱乐[PAD],电脑网络[PAD],生活服务[PAD],博客其他[PAD],团体组织[PAD],综合网站[PAD],个人网站[PAD],新闻综合[PAD],其他[PAD][PAD][PAD]'

In [25]:
# Prompt settings handed to the dataset. In the recorded sample output below,
# the label0/label1 slots of the pattern show up as [MASK] tokens between the
# literal pattern text and the input text.
app_parameters = {
    "pattern": "一条,label0,label1,的新闻,text",
    "label_desc": label_desc_new,
}

# Build the two-layer few-shot dataset from the training file.
# NOTE(review): the schema entries look like "column:type:length" -- confirm
# against the easynlp dataset implementation.
dataset = FewshotMultiLayerBaseDataset(
    pretrained_model_name_or_path=model_dir,
    data_file=train_file,
    max_seq_length=64,
    first_sequence="text",
    input_schema="text:str:1,label0:str:1,label1:str:1",
    label_name="label0,label1",
    label_enumerate_values=label_enumerate_values_new,
    layer_num=2,
    user_defined_parameters={"app_parameters": app_parameters},
)

****G:\dataset\text_classify\网页层次分类\train.csv


In [26]:
# Inspect the dataset's masked_length attribute; the recorded output [4, 5]
# matches the per-layer maximum description lengths printed while padding.
# NOTE(review): presumably the number of [MASK] slots per label layer --
# confirm in FewshotMultiLayerBaseDataset.
dataset.masked_length

[4, 5]

In [27]:
# List the fields of the first encoded sample.
dataset[0].keys()

dict_keys(['input_ids', 'attention_mask', 'token_type_ids', 'label_ids', 'mask_span_indices'])

In [28]:
# Dump every field of the first encoded sample, one field per line of output.
for field in (
    "input_ids",
    "token_type_ids",
    "attention_mask",
    "label_ids",
    "mask_span_indices",
):
    print(dataset[0][field])

[101, 671, 3340, 103, 103, 103, 103, 103, 103, 103, 103, 103, 4638, 3173, 7319, 1952, 1814, 7391, 2501, 5285, 4970, 2443, 1773, 2356, 1920, 1814, 1344, 1952, 1814, 7391, 2501, 5285, 4970, 3300, 7361, 1062, 1385, 855, 754, 2769, 1744, 5401, 714, 2168, 7657, 510, 5307, 3845, 1355, 6809, 776, 3823, 1538, 1765, 1277, 8024, 3315, 1062, 1385, 3221, 671, 2157, 683, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
[-100, -100, -100, 4495, 3833, 3302, 1218, 2791, 772, 2157, 2233, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -

In [29]:
# Decode the first sample's input ids and label ids back into tokens.
# In the recorded output, positions whose label id is -100 are rendered as
# '[UNK]', so only the span that lines up with the [MASK] tokens is meaningful.
for field in ("input_ids", "label_ids"):
    print(tokenizer.convert_ids_to_tokens(dataset[0][field]))

['[CLS]', '一', '条', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '的', '新', '闻', '奥', '城', '隐', '形', '纱', '窗', '廊', '坊', '市', '大', '城', '县', '奥', '城', '隐', '形', '纱', '窗', '有', '限', '公', '司', '位', '于', '我', '国', '美', '丽', '富', '饶', '、', '经', '济', '发', '达', '京', '津', '唐', '地', '区', '，', '本', '公', '司', '是', '一', '家', '专', '[PAD]']
['[UNK]', '[UNK]', '[UNK]', '生', '活', '服', '务', '房', '产', '家', '居', '[PAD]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]']


In [30]:
# Batch the dataset with its own collate function. The shuffle is driven by a
# dedicated, seeded generator so that the batch inspected below is the same on
# every "Restart & Run All" instead of depending on the global RNG state.
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=dataset.batch_fn,
    generator=torch.Generator().manual_seed(0),
)

# Pull a single batch and show which fields it carries.
batch = next(iter(loader))
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'token_type_ids', 'label_ids', 'mask_span_indices'])

In [31]:
# Report the shape of every field in the collated batch.
for field_name in batch:
    print(field_name, batch[field_name].shape)

input_ids torch.Size([2, 64])
attention_mask torch.Size([2, 64])
token_type_ids torch.Size([2, 64])
label_ids torch.Size([2, 64])
mask_span_indices torch.Size([2, 9, 1])
