In [8]:
import os
import json
import pandas as pd

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

import sys
sys.path.append("../")
import easynlp
from easynlp.appzoo import FewshotSequenceClassificationDataset
# from easynlp.fewshot_learning import FewshotBaseDataset

In [2]:
tokenizer = AutoTokenizer.from_pretrained(r"G:\code\pretrain_model_dir\bert-base-chinese")

In [3]:
tokenizer("a中国A")

{'input_ids': [101, 143, 704, 1744, 100, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1]}

In [4]:
tokenizer.convert_ids_to_tokens([101, 143, 704, 1744, 100, 102])

['[CLS]', 'a', '中', '国', '[UNK]', '[SEP]']

In [22]:
dataset = FewshotSequenceClassificationDataset(
    pretrained_model_name_or_path= r"G:\code\pretrain_model_dir\pai-bert-base-zh",
    data_file="./tmp/fewshot_data/train.csv",
    max_seq_length=32,
    first_sequence="text",
    input_schema="text:str:1,label:str:1",
    label_name="label",
    label_enumerate_values="Positive,Negative",
    user_defined_parameters={
        "app_parameters": {
            "label_desc": "好的,差的",
            "pattern": "text,是一条商品,label,评",
        },
    },
)

****./tmp/fewshot_data/train.csv


In [23]:
dataset[0].keys()

dict_keys(['input_ids', 'attention_mask', 'token_type_ids', 'label_ids', 'mask_span_indices'])

In [24]:
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.batch_fn)

In [25]:
inputs = next(iter(loader))

for key, val in inputs.items():
    print(key, val.shape)

input_ids torch.Size([2, 32])
attention_mask torch.Size([2, 32])
token_type_ids torch.Size([2, 32])
label_ids torch.Size([2, 32])
mask_span_indices torch.Size([2, 2, 1])


In [27]:
inputs["mask_span_indices"], inputs["mask_span_indices"].shape

(tensor([[[28],
          [29]],
 
         [[28],
          [29]]]),
 torch.Size([2, 2, 1]))

In [18]:
inputs["label_ids"]

tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, 1962, -100, -100],
        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, 1962, -100, -100]])

In [20]:
tokenizer.convert_ids_to_tokens(1962)

'好'