In [1]:
import transformers, datasets, accelerate
import pprint
import typing
import torch

# §7.1 实体识别

In [2]:
# Download the "conll2003" dataset; the printed repr lists its splits and columns.
# NOTE(review): trust_remote_code=True presumably allows the dataset's own
# loading script to execute locally — confirm the source is trusted.

raw_datasets: datasets.DatasetDict = typing.cast(
    datasets.DatasetDict,
    datasets.load_dataset("conll2003", trust_remote_code=True),
)

pprint.pprint(raw_datasets)

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})


In [16]:
# Read the tag-id -> tag-name mapping from the `ner_tags` feature of the train
# split (`features` is an attribute, not a method).

ner_tags_map: list[str] = raw_datasets["train"].features["ner_tags"].feature.names

pprint.pprint({"ner_tags_map": ner_tags_map})

# First training example: its raw tag ids, its words, and the tag ids decoded
# to names — one per output line.
print(
    raw_datasets["train"][0]["ner_tags"],
    raw_datasets["train"][0]["tokens"],
    [ner_tags_map[tag_id] for tag_id in raw_datasets["train"][0]["ner_tags"]],
    sep="\n",
)

{'ner_tags_map': ['O',
                  'B-PER',
                  'I-PER',
                  'B-ORG',
                  'I-ORG',
                  'B-LOC',
                  'I-LOC',
                  'B-MISC',
                  'I-MISC']}
[3, 0, 7, 0, 0, 0, 7, 0, 0]
['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']


In [18]:
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")

# Tokenize the same pre-split word list with both settings of the flag:
#   is_split_into_words=True  -> use when the input is a list[str] of words
#   is_split_into_words=False -> use when the input is a plain str
# With False the printed result covers only the first word, so the list is
# presumably treated as a batch of separate texts and .tokens() shows entry 0.
pprint.pprint({
    key: tokenizer(
        raw_datasets["train"][0]["tokens"],
        is_split_into_words=flag,
    ).tokens()
    for key, flag in (
        ("is_split_into_words=True", True),
        ("is_split_into_words=False", False),
    )
})



{'is_split_into_words=False': ['[CLS]', 'EU', '[SEP]'],
 'is_split_into_words=True': ['[CLS]',
                              'EU',
                              'rejects',
                              'German',
                              'call',
                              'to',
                              'boycott',
                              'British',
                              'la',
                              '##mb',
                              '.',
                              '[SEP]']}


In [32]:
def align_labels_with_tokens(labels: list[int], word_ids: list[int]) -> list[int]:
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            current_word = word_id
            new_labels.append(-100 if word_id is None else labels[word_id])
        elif word_id is None:
            new_labels.append(-100)
        else:
            new_labels.append(labels[word_id] + (1 if labels[word_id] % 2 == 1 else 0)) # type: ignore
    return new_labels

# Walk through the alignment on the first training example.
# The tokens are words (strings), not ints.
example_tokens: list[str] = raw_datasets["train"][0]["tokens"]
example_labels: list[int] = raw_datasets["train"][0]["ner_tags"]
# One entry per produced token: the source word index, or None.
example_word_ids: list[int | None] = tokenizer(
    example_tokens,
    is_split_into_words=True
).word_ids()

pprint.pprint({
    "example_tokens": example_tokens,
    "example_labels": example_labels,
    "example_word_ids": example_word_ids,
    "example_final_alignment_result": align_labels_with_tokens(
        example_labels, example_word_ids # type: ignore
    )
})


{'example_final_alignment_result': [-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, 0, -100],
 'example_labels': [3, 0, 7, 0, 0, 0, 7, 0, 0],
 'example_tokens': ['EU',
                    'rejects',
                    'German',
                    'call',
                    'to',
                    'boycott',
                    'British',
                    'lamb',
                    '.'],
 'example_word_ids': [None, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, None]}


In [33]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and attach token-level labels.

    Reads ``examples["tokens"]`` (one word list per sentence) and
    ``examples["ner_tags"]`` (one label list per sentence), tokenizes the
    words with truncation enabled, and stores the labels aligned by
    ``align_labels_with_tokens`` under the ``"labels"`` key of the tokenizer
    output, which is then returned.
    """
    encoded = tokenizer(
        examples["tokens"], is_split_into_words=True, truncation=True
    )
    # word_ids(idx) gives the token -> word mapping of the idx-th sentence.
    encoded["labels"] = [
        align_labels_with_tokens(sentence_tags, encoded.word_ids(idx))
        for idx, sentence_tags in enumerate(examples["ner_tags"])
    ]
    return encoded

# Apply the tokenize-and-align step to every split in batches, dropping all
# of the original columns so only the function's output remains.
tokenized_datasets = raw_datasets.map(
    function=tokenize_and_align_labels,
    remove_columns=raw_datasets["train"].column_names,
    batched=True,
)

Map:   0%|          | 0/14041 [00:00<?, ? examples/s]

Map:   0%|          | 0/3250 [00:00<?, ? examples/s]

Map:   0%|          | 0/3453 [00:00<?, ? examples/s]