In [1]:
import os
import json
import sys
from collections import defaultdict
import pandas as pd
from tqdm import tqdm

# Put the local PaddleNLP checkout first on sys.path so it takes precedence over any pip-installed paddlenlp.
sys.path.insert(0, r"G:\code\github\PaddleNLP")
from paddlenlp import Taskflow

In [2]:
data_dir = r"G:\dataset\text_classify\tnews\paddlenlp"

# label.txt holds one candidate label (a Chinese category name) per line; the
# list becomes the Taskflow `schema`. Blank lines are skipped so a stray empty
# line cannot inject an empty "" label into the schema.
choices = []
with open(os.path.join(data_dir, "label.txt"), "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if line:
            choices.append(line)
print(len(choices))

15


In [3]:
# A label id is also needed: the submission file is written with label ids
# (e.g. "100"), while the model predicts Chinese label names.
raw_dir = r"G:\dataset\text_classify\tnews\raw"

# label_index2en2zh.json is read as JSON Lines (one object per line); each
# object must provide "label" (the id) and "label_zh" (the Chinese name).
label_list = []
zh2label_id = dict()
with open(os.path.join(raw_dir, "label_index2en2zh.json"), "r", encoding="utf-8") as f:
    for raw in f:
        if not raw.strip():
            continue  # json.loads("") would raise on a blank/trailing line
        record = json.loads(raw)
        label_list.append(record)
        zh2label_id[record["label_zh"]] = record["label"]

# Fail fast: every candidate label given to the model must map back to an id,
# otherwise the prediction cell would only hit a KeyError after full inference.
missing = [c for c in choices if c not in zh2label_id]
if missing:
    raise ValueError(f"labels in label.txt without an id mapping: {missing}")

In [4]:
# Sanity check: Chinese label name -> label id mapping built in the previous cell.
zh2label_id

{'故事': '100',
 '文化': '101',
 '娱乐': '102',
 '体育': '103',
 '财经': '104',
 '房产': '106',
 '汽车': '107',
 '教育': '108',
 '科技': '109',
 '军事': '110',
 '旅游': '112',
 '国际': '113',
 '股票': '114',
 '农业': '115',
 '电竞': '116'}

In [6]:
# Local UTC checkpoint directory (loaded via `task_path`).
model_dir = r"G:\code\github\PaddleNLP\outputs\tnews\plm"

# Options for the zero-shot (UTC) text classification pipeline:
# `schema` is the candidate label list read from label.txt, and
# `single_label=True` asks for one label per input.
utc_options = dict(
    model="utc-base",
    schema=choices,
    task_path=model_dir,
    precision="fp32",
    single_label=True,
    batch_size=32,
)
task = Taskflow("zero_shot_text_classification", **utc_options)
task_instance = task.task_instance

# Smoke test on a single headline before the full test-set run.
smoke_text = "加长3.4米，玛莎拉蒂Ghibli奇特改装，内饰极尽奢华"
print(task(smoke_text))

[32m[2023-07-16 22:54:17,540] [    INFO][0m - Downloading vocab.txt from https://paddlenlp.bj.bcebos.com/taskflow/zero_shot_text_classification/utc-base/vocab.txt[0m


  0%|          | 0.00/182k [00:00<?, ?B/s]

[32m[2023-07-16 22:54:17,967] [    INFO][0m - Downloading special_tokens_map.json from https://paddlenlp.bj.bcebos.com/taskflow/zero_shot_text_classification/utc-base/special_tokens_map.json[0m


  0%|          | 0.00/112 [00:00<?, ?B/s]

[32m[2023-07-16 22:54:18,158] [    INFO][0m - Downloading tokenizer_config.json from https://paddlenlp.bj.bcebos.com/taskflow/zero_shot_text_classification/utc-base/tokenizer_config.json[0m


  0%|          | 0.00/198 [00:00<?, ?B/s]

[32m[2023-07-16 22:54:18,369] [    INFO][0m - We are using <class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'> to load 'G:\code\github\PaddleNLP\outputs\tnews\plm'.[0m
[32m[2023-07-16 22:54:18,385] [    INFO][0m - Assigning ['[O-MASK]'] to the additional_special_tokens key of the tokenizer[0m
[32m[2023-07-16 22:54:19,122] [    INFO][0m - loading configuration file G:\code\github\PaddleNLP\outputs\tnews\plm\config.json[0m
[32m[2023-07-16 22:54:19,929] [    INFO][0m - All model checkpoint weights were used when initializing UTC.
[0m
[32m[2023-07-16 22:54:19,930] [    INFO][0m - All the weights of UTC were initialized from the model checkpoint at G:\code\github\PaddleNLP\outputs\tnews\plm.
If your task is similar to the task the model of the checkpoint was trained on, you can already use UTC for predictions without further training.[0m
[32m[2023-07-16 22:54:19,931] [    INFO][0m - Converting to the inference model cost a little time.[0m
[32m[2023-07-16 22:54

[{'text_a': '加长3.4米，玛莎拉蒂Ghibli奇特改装，内饰极尽奢华', 'predictions': [{'label': '汽车', 'score': 0.9951342415991393}]}]


In [8]:
input_file = os.path.join(data_dir, "test.txt")
output_file = os.path.normpath(os.path.join(data_dir, "../submit/tnewsf_predict.json"))

# Load every test record first (JSON Lines with "id", "text_a", ...). The
# submission file is opened only after inference succeeds, so a bad input line
# or a failed prediction no longer leaves a truncated/empty file behind.
# `line` keeps the last parsed record, which a later cell inspects.
data_list = []
with open(input_file, "r", encoding="utf-8") as fr:
    for raw_line in fr:
        if not raw_line.strip():
            continue  # tolerate blank / trailing lines
        line = json.loads(raw_line)
        data_list.append(line)

result = task(data_list)
# zip() below would silently drop records if the counts ever disagreed.
if len(result) != len(data_list):
    raise ValueError(f"expected {len(data_list)} predictions, got {len(result)}")

# The ../submit directory may not exist yet.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
    for data, item in zip(data_list, result):
        index_id = data["id"]
        # Take the first prediction (single_label=True; the smoke test above
        # returned exactly one) and map its Chinese name back to the label id.
        predict_label = item["predictions"][0]["label"]
        label_id = zh2label_id[predict_label]

        f.write(json.dumps({"id": index_id, "label": label_id}, ensure_ascii=False) + "\n")


In [12]:
# Inspect the last test record explicitly instead of relying on the `line`
# variable leaked from the prediction loop (hidden kernel state).
data_list[-1]

{'id': '0',
 'text_a': '加长3.4米，玛莎拉蒂Ghibli奇特改装，内饰极尽奢华',
 'text_b': '',
 'question': '',
 'choices': ['故事',
  '文化',
  '娱乐',
  '体育',
  '财经',
  '房产',
  '汽车',
  '教育',
  '科技',
  '军事',
  '旅游',
  '国际',
  '股票',
  '农业',
  '电竞']}