In [3]:
import os
import json
import sys
import importlib
import pandas as pd
from collections import defaultdict

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

sys.path.append("../")
import easynlp

In [20]:
from easynlp.fewshot_learning.fewshot_predictor import PromptMultiLayerPredictor
from easynlp.fewshot_learning.fewshot_application import FewshotMultiLayerClassification
# Reload the module that defines PromptMultiLayerPredictor and rebind the
# notebook-level names to the classes of the reloaded module. Presumably this is
# here so that edits to the library source are picked up without restarting the
# kernel. Only the predictor module is reloaded; FewshotMultiLayerClassification
# keeps whatever version was imported above.
module = importlib.import_module(PromptMultiLayerPredictor.__module__)
importlib.reload(module)
PromptMultiLayerPredictor = module.PromptMultiLayerPredictor
FewshotPyModelPredictor = module.FewshotPyModelPredictor

In [4]:
def parse_label(label_file, model_dir):
    """Build the per-layer label id / label description strings from a label file.

    The label file is a UTF-8 JSON object mapping a hierarchical label path to an
    id. Path components are separated by ">", so a label with N components belongs
    to layer N - 1 and is described by its last component only.

    Within each layer, every description is right-padded with the literal string
    "[PAD]" so that all descriptions of that layer have the same token count
    according to the tokenizer loaded from ``model_dir``.

    Args:
        label_file: Path to the JSON label file.
        model_dir: Directory (or identifier) passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(label_enumerate_values, label_desc)`` tuple of strings. In both
        strings the layers are joined with "@@" (in ascending layer order) and
        the entries inside one layer are joined with ",". The first string holds
        the ids (as text), the second the padded descriptions.

    Note:
        Label names containing "," or "@@" would corrupt the joined strings; the
        function does not check for this. Duplicate leaf names inside one layer
        are also not disambiguated (open TODO carried over from the original).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Parse the labels first. The ids are stored as ints in the file and are
    # converted to strings here so they can be joined below.
    with open(label_file, "r", encoding="utf-8") as f:
        label2id = {k: str(v) for k, v in json.load(f).items()}

    # Current assumption: every layer is complete, i.e. each sample carries a
    # label for every layer down to the last one.
    ids_by_layer = defaultdict(list)
    names_by_layer = defaultdict(list)
    for label, label_id in label2id.items():
        parts = label.split(">")
        # A label is registered only in the layer of its last component.
        # TODO: how should identical leaf names be handled?
        layer = len(parts) - 1
        ids_by_layer[layer].append(label_id)
        names_by_layer[layer].append(parts[-1])

    # Both dicts receive exactly the same keys in the loop above.
    layers = sorted(names_by_layer)

    # Entries inside a layer stay in file order because they were appended in
    # iteration order. Pad every description up to the layer's maximum length.
    padded_by_layer = {}
    for layer in layers:
        names = names_by_layer[layer]
        # Tokenize each description exactly once and reuse the counts for both
        # the maximum and the padding amount.
        token_counts = [len(tokenizer.tokenize(name)) for name in names]
        max_len = max(token_counts)
        print(f"layer_{layer}, max_label_len: {max_len}")
        padded_by_layer[layer] = [
            name + "[PAD]" * (max_len - count)
            for name, count in zip(names, token_counts)
        ]

    label_enumerate_values_new = "@@".join(
        ",".join(ids_by_layer[layer]) for layer in layers
    )
    label_desc_new = "@@".join(
        ",".join(padded_by_layer[layer]) for layer in layers
    )

    return label_enumerate_values_new, label_desc_new


In [19]:
# Run configuration.
# NOTE(review): both paths are absolute paths on the author's machine; adjust
# them before re-running this notebook elsewhere.
model_dir = r"G:\code\github\EasyNLP\demo\tmp\fewshot_multi_layer"
label_file = r"G:\dataset\text_classify\网页层次分类\label.json"

# Column layout of the input, the label columns and the prompt pattern.
input_schema = "text:str:1,label0:str:1,label1:str:1"
label_name = "label0,label1"
pattern = "一条,label0,label1,的新闻,text"

# Per-layer id string and padded description string for the predictor.
label_enumerate_values, label_desc = parse_label(label_file, model_dir)

# Forward and reverse lookup between label path and (stringified) id.
with open(label_file, "r", encoding="utf-8") as f:
    label2id = {name: str(idx) for name, idx in json.load(f).items()}
id2label = {idx: name for name, idx in label2id.items()}

layer_0, max_label_len: 4
layer_1, max_label_len: 5


In [21]:
# Build the predictor from the model in `model_dir`. The prompt pattern and the
# padded label descriptions are handed over as application parameters; the input
# text is read from the "text" field and there is no second sequence.
predictor = PromptMultiLayerPredictor(
    model_dir=model_dir,
    model_cls=FewshotMultiLayerClassification,
    user_defined_parameters=dict(
        app_parameters=dict(pattern=pattern, label_desc=label_desc),
    ),
    first_sequence="text",
    second_sequence=None,
    label_name=label_name,
    sequence_length=128,
)

embedding size: 21128


In [33]:
# Run the predictor on one sample sentence and print every "predictions*" entry
# of the result together with the label path that its value maps to.
text = "荔枝壳 专注于云计算、大数据等IT技术，以及最新动态资讯"
result = predictor.run({"text": text})
for key, val in result.items():
    # The result also holds other entries; only the prediction ones are shown.
    if key.startswith("predictions"):
        # id2label is keyed by the stringified ids from the label file.
        print(key, val, id2label[val])

predictions_0 0 休闲娱乐
predictions_1 1 休闲娱乐>影视音乐


In [26]:
# Run preprocessing and the forward pass separately to inspect the raw logits.
# The recorded output is torch.Size([1, 128, 21128]): the 128 matches the
# sequence_length given to the predictor and 21128 matches the "embedding size"
# it printed when it was built (presumably the vocabulary size; confirm).
inputs = predictor.preprocess({"text": "华孚时尚股份有限公司 华孚时尚股份有限公司"})
predictor.predict(inputs)["logits"].shape

torch.Size([1, 128, 21128])