In [1]:
import os
import json
import sys

# Make modules located one directory up importable from this notebook.
sys.path.extend(["../"])

1. 定义任务


In [3]:
from openprompt.data_utils import InputExample

classes = ["negative", "positive"]  # There are two classes in Sentiment Analysis, one for negative and one for positive
dataset = [  # For simplicity, there's only two examples
    # text_a is the input text of the data, some other datasets may have multiple input sentences in one example.
    InputExample(
        guid=0,
        text_a="Albert Einstein was one of the greatest intellects of his time.",
    ),
    InputExample(
        guid=1,
        text_a="The film was badly made.",
    ),
]


2. 加载预训练模型


In [4]:
from openprompt.plms import load_plm

# Load the pretrained "bert-base-cased" checkpoint. This returns the model, its
# tokenizer, the model config, and the tokenizer-wrapper class that
# PromptDataLoader needs further down.
plm, tokenizer, model_config, WrapperClass = load_plm(
    "bert",
    "bert-base-cased",
)


Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

3. 定义模板


In [5]:
from openprompt.prompts import ManualTemplate

# Hand-written template: the example's text_a, then the literal " It was", then the mask slot.
promptTemplate = ManualTemplate(tokenizer=tokenizer, text='{"placeholder":"text_a"} It was {"mask"}')


In [10]:
# Preview how the template wraps the first example. The displayed result lists the
# template pieces (each tagged with loss_ids / shortenable_ids), followed by the
# example's metadata (its guid).
promptTemplate.wrap_one_example(dataset[0])

[[{'text': 'Albert Einstein was one of the greatest intellects of his time.',
   'loss_ids': 0,
   'shortenable_ids': 1},
  {'text': ' It was', 'loss_ids': 0, 'shortenable_ids': 0},
  {'text': '<mask>', 'loss_ids': 1, 'shortenable_ids': 0}],
 {'guid': 0}]

4. 定义标签映射


In [6]:
from openprompt.prompts import ManualVerbalizer

# Map each class to one or more label words that the masked LM may predict at the
# mask position; "positive" is represented by three words, "negative" by one.
promptVerbalizer = ManualVerbalizer(
    tokenizer=tokenizer,
    classes=classes,
    label_words=dict(
        negative=["bad"],
        positive=["good", "wonderful", "great"],
    ),
)


5. 定义提示模型


In [7]:
from openprompt import PromptForClassification

# Combine the pretrained LM, the template and the verbalizer into one classifier.
promptModel = PromptForClassification(plm=plm, template=promptTemplate, verbalizer=promptVerbalizer)


6. 定义数据加载器


In [18]:
from openprompt import PromptDataLoader

# Run every example through the template and tokenizer wrapper. Sequences are padded
# to max_seq_length (32) tokens, as the inspection output below shows.
data_loader = PromptDataLoader(
    template=promptTemplate,
    dataset=dataset,
    tokenizer_wrapper_class=WrapperClass,
    tokenizer=tokenizer,
    max_seq_length=32,
)


tokenizing: 2it [00:00, 1961.33it/s]


In [21]:
from pprint import pprint

# Fetch ONE batch and reuse it for everything below. Calling next(iter(data_loader))
# twice builds two independent iterators, so the printed batch and the inspected
# batch could differ (e.g. if the loader shuffles).
example = next(iter(data_loader))
pprint(example)

# Show every populated field with its shape. Fields that are not tensors have no
# .shape; they print None instead of raising AttributeError.
for key, val in example.items():
    print(key, getattr(val, "shape", None), val)

# Decode the first sequence back into tokens to see where the template put the mask.
print(tokenizer.convert_ids_to_tokens(example["input_ids"][0]))

{"input_ids": [[101, 3986, 16127, 1108, 1141, 1104, 1103, 4459, 1107, 7854, 18465, 1116, 1104, 1117, 1159, 119, 1135, 1108, 103, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], "inputs_embeds": null, "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], "token_type_ids": null, "label": null, "decoder_input_ids": null, "decoder_inputs_embeds": null, "soft_token_ids": null, "past_key_values": null, "loss_ids": [[-100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], "guid": [0], "tgt_text": null, "encoded_tgt_text": null, "input_ids_len": null}

input_ids torch.Size([1, 32]) tensor([[  101,  3986, 16127,  1108,  1141,  1104,  1103,  4459,  1107,  7854,
         18465,  1116,  1104,  1117,  1159,   119,  1135,  1108,   103,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]])
attention_mask torch.Size([1, 32]) tensor([[1, 1, 

7. 训练和推理 (本例只做零样本推理, 没有训练步骤)


In [9]:
import torch

# Zero-shot inference: classify with the prompted pretrained MLM, no fine-tuning.
promptModel.eval()  # put the model in evaluation mode
with torch.no_grad():
    for batch in data_loader:
        logits = promptModel(batch)
        # Assumes logits are (batch_size, num_classes); one predicted index per example.
        preds = torch.argmax(logits, dim=-1)
        # Indexing the list with the tensor itself (classes[preds]) only works for a
        # single-element batch, so convert to plain ints and map each one.
        for pred in preds.tolist():
            print(classes[pred])
# For this dataset the predictions are expected to be 1, 0, i.e. 'positive', 'negative'.

positive
negative


# 定义模板

TODO: 模板的文档里没有介绍 `placeholder`, 但上面的示例用到了它. 从 `wrap_one_example` 的输出看, `{"placeholder":"text_a"}` 会被替换为样本的 `text_a` 文本.

模板是由多个 dict 组成的, key 可以是

- meta: 原始输入数据中的字段 (如下例中的 `sentence`, `title`), 也可以是其他关键信息
- mask: 需要模型预测的位置, 包装后对应 `<mask>` token
- soft: soft token. TODO: 确认具体含义, 是否为可训练 (可修改) 的 prompt?
- text: 纯文本, 也可以直接写出而不用 dict 包装 (见下方示例)

```
# 情感分类
{"meta": "sentence"}. It is {"mask"}.

# 新闻分类
A {"mask"} news : {"meta": "title"} {"meta": "description"}

# 纯文本可以直接写
{"meta": "sentence"} {"text": "In this sentence,"} {"meta": "entity"} {"text": "is a"} {"mask"},
{"meta": "sentence"}. In this sentence, {"meta": "entity"} is a {"mask"},

# soft 字段
{"meta": "premise"} {"meta": "hypothesis"} {"soft": "Does the first sentence entails the second?"} {"mask"} {"soft"}.
{"soft": None, "duplicate": 10000} {"meta": "text"} {"mask"}
{"soft": None, "duplicate": 10000, "same": True}

# 支持后处理
{"meta": 'context', "post_processing": lambda s: s.rstrip(string.punctuation)}. {"soft": "It was"} {"mask"}
{"text": "This sentence is", "post_processing": "mlp"} {"soft": None, "post_processing": "mlp"}
```


# 定义标签映射

标签映射 (verbalizer) 把每个类别映射到一个或多个标签词, 例如使用 `ManualVerbalizer`:

```python
ManualVerbalizer(
    label_words=label_words,
    ...
)
```

`label_words` 支持多种格式, 如列表, 嵌套列表, 字典.
字典最直观, 一个类别可以对应多个标签词.

```python
{
    "person-scholar": ["scholar", "scientist"],
    "building-library": ["library"],
    "building-hotel": ["hotel"],
    "location-road/railway/highway/transit": ["road", "railway", "highway", "transit"]
}
```

也可以从文件中加载.

```python
ManualVerbalizer(...).from_file(path=file_path)
```

使用文件的一个优势是可以定义多组标签映射, 使用时通过 `choice` 参数选择其中一组.

```python
ManualVerbalizer(...).from_file(path=file_path, choice=0)
```
