In [55]:
import os
import json
import sys
import importlib
import pandas as pd

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer

sys.path.append("../")
import easynlp

In [43]:
from easynlp.fewshot_learning.fewshot_dataset import FewshotMultiLayerBaseDataset
# 重载
module = importlib.import_module(FewshotMultiLayerBaseDataset.__module__)
importlib.reload(module)
FewshotMultiLayerBaseDataset = module.FewshotMultiLayerBaseDataset

In [4]:
train_file = r"G:\dataset\text_classify\网页层次分类\train.csv"
test_file = r"G:\dataset\text_classify\网页层次分类\test.csv"
label_file = r"G:\dataset\text_classify\网页层次分类\label.json"
model_dir = r"G:\code\pretrain_model_dir\pai-bert-base-zh"

In [44]:
# 先解析 label
with open(label_file, "r", encoding="utf-8") as f:
    # 这边 id 是 int
    label2id = json.load(f)
    label2id = {k: str(v) for k, v in label2id.items()}
    id2label = {v: k for k, v in label2id.items()}

label_enumerate_values = [[], []]
label_desc = [[], []]
for label, label_id in label2id.items():
    label_split = label.split(">")
    if len(label_split) == 1:
        label_enumerate_values[0].append(label_id)
        label_desc[0].append(label)
    elif len(label_split) == 2:
        label_enumerate_values[1].append(label_id)
        label_desc[1].append(label_split[1])
    else:
        raise ValueError(f"label split error, {label}")

label_enumerate_values = [",".join(x) for x in label_enumerate_values]
label_desc = [",".join(x) for x in label_desc]

label_enumerate_values = "@@".join(label_enumerate_values)
label_desc = "@@".join(label_desc)

In [16]:
label_enumerate_values

'0,7,13,21,31,37,42@@1,2,3,4,5,6,8,9,10,11,12,14,15,16,17,18,19,20,22,23,24,25,26,27,28,29,30,32,33,34,35,36,38,39,40,41,43,44,45,46,47'

In [17]:
label_desc

'休闲娱乐,电脑网络,商业经济,生活服务,教育文化,博客论坛,综合其他@@影视音乐,游戏动漫,图片小说,聊天交友,娱乐其他,休闲健身,商务门户,网站资源,软硬件通信,网络其他,网络营销,农林牧渔,工业制品,机械电子,建筑环境,法律金融,商务服务,交通物流,生活用品,餐饮美食,房产家居,旅游交通,百货购物,医疗保健,时尚美容,生活常识,生活其他,高校教育,人力资源,高级教育,文体艺术,文化其他,休闲娱乐,电脑网络,生活服务,博客其他,团体组织,综合网站,个人网站,新闻综合,其他'

In [45]:
dataset = FewshotMultiLayerBaseDataset(
    pretrained_model_name_or_path=model_dir,
    data_file=train_file,
    max_seq_length=64,
    first_sequence="text",
    input_schema="text:str:1,label0:str:1,label1:str:1",
    label_name="label0,label1",
    label_enumerate_values=label_enumerate_values,
    layer_num=2,
    user_defined_parameters={
        "app_parameters": {
            "pattern": "一条,label0,label1,的新闻,text",
            "label_desc": label_desc,
        }
    }
)

****G:\dataset\text_classify\网页层次分类\train.csv


In [46]:
dataset.masked_length

[4, 4]

In [48]:
dataset[0].keys()

dict_keys(['input_ids', 'attention_mask', 'token_type_ids', 'label_ids', 'mask_span_indices'])

In [50]:
print(dataset[0]["input_ids"])
print(dataset[0]["token_type_ids"])
print(dataset[0]["attention_mask"])
print(dataset[0]["label_ids"])
print(dataset[0]["mask_span_indices"])

[101, 671, 3340, 103, 103, 103, 103, 103, 103, 103, 103, 4638, 3173, 7319, 1952, 1814, 7391, 2501, 5285, 4970, 2443, 1773, 2356, 1920, 1814, 1344, 1952, 1814, 7391, 2501, 5285, 4970, 3300, 7361, 1062, 1385, 855, 754, 2769, 1744, 5401, 714, 2168, 7657, 510, 5307, 3845, 1355, 6809, 776, 3823, 1538, 1765, 1277, 8024, 3315, 1062, 1385, 3221, 671, 2157, 683, 689, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
[-100, -100, -100, 4495, 3833, 3302, 1218, 2791, 772, 2157, 2233, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100

In [51]:
tokenizer = BertTokenizer.from_pretrained(model_dir)

In [54]:
print(tokenizer.convert_ids_to_tokens(dataset[0]["input_ids"]))
print(tokenizer.convert_ids_to_tokens(dataset[0]["label_ids"]))

['[CLS]', '一', '条', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '的', '新', '闻', '奥', '城', '隐', '形', '纱', '窗', '廊', '坊', '市', '大', '城', '县', '奥', '城', '隐', '形', '纱', '窗', '有', '限', '公', '司', '位', '于', '我', '国', '美', '丽', '富', '饶', '、', '经', '济', '发', '达', '京', '津', '唐', '地', '区', '，', '本', '公', '司', '是', '一', '家', '专', '业', '[PAD]']
['[UNK]', '[UNK]', '[UNK]', '生', '活', '服', '务', '房', '产', '家', '居', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]']


In [59]:
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.batch_fn)

batch = next(iter(loader))
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'token_type_ids', 'label_ids', 'mask_span_indices'])

In [60]:
for key, val in batch.items():
    print(key, val.shape)

input_ids torch.Size([2, 64])
attention_mask torch.Size([2, 64])
token_type_ids torch.Size([2, 64])
label_ids torch.Size([2, 64])
mask_span_indices torch.Size([2, 8, 1])
