The Jupyter notebooks for this article are in the [Chapter 4 code base](https://github.com/datawhalechina/learn-nlp-with-transformers/tree/main/docs/%E7%AF%87%E7%AB%A04-%E4%BD%BF%E7%94%A8Transformers%E8%A7%A3%E5%86%B3NLP%E4%BB%BB%E5%8A%A1).

If you are opening this notebook in Google Colab, you may need to install the 🤗 Transformers and 🤗 Datasets libraries, plus `seqeval` for evaluation. Run the following command to install them.

In [None]:
!pip install datasets transformers seqeval

If you are opening this notebook locally, please make sure you have installed the above dependencies. You can also find the multi-GPU distributed training version of this notebook [here](https://github.com/huggingface/transformers/tree/master/examples/token-classification).

The model structure involved in this section is basically the same as the BERT in the previous chapter. What you need to learn additionally are the data processing methods and model training methods for specific tasks.

# *Sequence labeling (token-level classification problem)*

Sequence labeling can be seen as a token-level classification problem: we predict a label for each token in the text. In this notebook, we will show how to use a transformer model from [🤗 Transformers](https://github.com/huggingface/transformers) for token-level classification. The figure below shows a named-entity recognition (NER) task.

![Widget inference representing the NER task](https://github.com/huggingface/notebooks/blob/master/examples/images/token_classification.png?raw=1)

The most common token-level classification tasks:

- NER (named-entity recognition) identifies the entities in a text (person names, organization names, location names...).
- POS (part-of-speech tagging) tags tokens according to their grammatical role (noun, verb, adjective, etc.).
- Chunking puts tokens of the same phrase together.

For the above tasks, we will show how to load the dataset with the 🤗 Datasets library and fine-tune the pre-trained model with the `Trainer` interface in 🤗 Transformers.

This notebook can in principle use any transformer model from the [model hub](https://huggingface.co/models) to solve any token-level classification task, provided the pre-trained model has a token classification head on top (such as the `BertForTokenClassification` mentioned in the previous chapter). Because the preprocessing in this notebook relies on features of the fast tokenizers, the model also needs a fast tokenizer; see [this table](https://huggingface.co/transformers/index.html#bigtable).

If your task is different, only minor changes are likely needed to adapt this notebook. You should also adjust the batch size used for fine-tuning to fit your GPU memory and avoid out-of-memory errors.

In [None]:
task = "ner"  # must be "ner", "pos", or "chunk"
model_checkpoint = "distilbert-base-uncased"
batch_size = 16

## Download Data

We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to load the data and the corresponding metric. Each takes a single function call: `load_dataset` for the data and `load_metric` for the metric.

In [None]:
from datasets import load_dataset, load_metric

The examples in this notebook use the [CoNLL 2003 dataset](https://www.aclweb.org/anthology/W03-0419.pdf). This notebook should work for any token classification dataset in the 🤗 Datasets library. If you are using your own dataset from a json/csv file, check the [datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files) to learn how to load it. Custom datasets may require adjusting the column names used in this notebook.

In [None]:
datasets = load_dataset("conll2003")

The `datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict) data structure. For the training set, validation set, and test set, just use the corresponding key (train, validation, test) to get the corresponding data.

In [None]:
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})

Whether in the training, validation, or test set, each example contains a column called `tokens` (the text, already split into words) and three label columns (`pos_tags`, `chunk_tags`, `ner_tags`) that hold the annotations for those tokens, one column per task.

Given a split key (train, validation, or test) and an index, you can view a single example.

In [None]:
datasets["train"][0]

{'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'id': '0',
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.']}

All data labels have been encoded into integers and can be used directly by the pre-trained transformer model. The actual categories corresponding to the encodings of these integers are stored in `features`.

In [None]:
datasets["train"].features["ner_tags"]

Sequence(feature=ClassLabel(num_classes=9, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], names_file=None, id=None), length=-1, id=None)

Taking NER as an example, 0 corresponds to the label "O", 1 corresponds to "B-PER", and so on. "O" means the token does not belong to any entity. This dataset has 4 entity categories (PER, ORG, LOC, MISC), each with a B- prefix (first token of an entity) and an I- prefix (subsequent token inside an entity).

- 'PER' for person
- 'ORG' for organization
- 'LOC' for location
- 'MISC' for miscellaneous
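
To make the B-/I- scheme concrete, here is a minimal sketch that groups the tokens of the first training example into entities. The helper `bio_to_entities` and the variable `ner_label_names` are illustrative names, not part of 🤗 Datasets; the label names, tokens, and tag ids are copied from the outputs shown above.

```python
# Label names copied from the ClassLabel feature shown above.
ner_label_names = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]

def bio_to_entities(tokens, tag_ids):
    """Group tokens into (entity_type, text) pairs by following the B-/I- prefixes."""
    entities = []
    current_type, current_tokens = None, []
    for token, tag_id in zip(tokens, tag_ids):
        tag = ner_label_names[tag_id]
        if tag.startswith("B-"):
            # A B- tag always starts a new entity, closing any open one.
            if current_tokens:
                entities.append((current_type, " ".join(current_tokens)))
            current_type, current_tokens = tag[2:], [token]
        elif tag.startswith("I-") and current_type == tag[2:]:
            # An I- tag of the same type continues the open entity.
            current_tokens.append(token)
        else:
            # "O" (or a stray I- tag) closes any open entity.
            if current_tokens:
                entities.append((current_type, " ".join(current_tokens)))
            current_type, current_tokens = None, []
    if current_tokens:
        entities.append((current_type, " ".join(current_tokens)))
    return entities

tokens = ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."]
ner_tags = [3, 0, 7, 0, 0, 0, 7, 0, 0]
entities = bio_to_entities(tokens, ner_tags)
print(entities)  # [('ORG', 'EU'), ('MISC', 'German'), ('MISC', 'British')]
```

Dropping stray I- tags is a simplification chosen for this sketch; evaluation tools may handle them differently.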

In [None]:
label_list = datasets["train"].features[f"{task}_tags"].feature.names
label_list

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

To further understand what the data looks like, the following function will randomly select a few examples from the dataset and display them.

In [None]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [None]:
show_random_elements(datasets["train"])

Unnamed: 0,id,tokens,pos_tags,chunk_tags,ner_tags
0,2227,"[Result, of, a, French, first, division, match, on, Friday, .]","[NN, IN, DT, JJ, JJ, NN, NN, IN, NNP, .]","[B-NP, B-PP, B-NP, I-NP, I-NP, I-NP, I-NP, B-PP, B-NP, O]","[O, O, O, B-MISC, O, O, O, O, O, O]"
1,2615,"[Mid-tier, golds, up, in, heavy, trading, .]","[NN, NNS, IN, IN, JJ, NN, .]","[B-NP, I-NP, B-PP, B-PP, B-NP, I-NP, O]","[O, O, O, O, O, O, O]"
2,10256,"[Neagle, (, 14-6, ), beat, the, Braves, for, the, third, time, this, season, ,, allowing, two, runs, and, six, hits, in, eight, innings, .]","[NNP, (, CD, ), VB, DT, NNPS, IN, DT, JJ, NN, DT, NN, ,, VBG, CD, NNS, CC, CD, NNS, IN, CD, NN, .]","[B-NP, O, B-NP, O, B-VP, B-NP, I-NP, B-PP, B-NP, I-NP, I-NP, B-NP, I-NP, O, B-VP, B-NP, I-NP, O, B-NP, I-NP, B-PP, B-NP, I-NP, O]","[B-PER, O, O, O, O, O, B-ORG, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]"
3,10720,"[Hansa, Rostock, 4, 1, 2, 1, 5, 4, 5]","[NNP, NNP, CD, CD, CD, CD, CD, CD, CD]","[B-NP, I-NP, I-NP, I-NP, I-NP, I-NP, I-NP, I-NP, I-NP]","[B-ORG, I-ORG, O, O, O, O, O, O, O]"
4,7125,"[MONTREAL, 70, 59, .543, 11]","[NNP, CD, CD, CD, CD]","[B-NP, I-NP, I-NP, I-NP, I-NP]","[B-ORG, O, O, O, O]"
5,3316,"[Softbank, Corp, said, on, Friday, that, it, would, procure, $, 900, million, through, the, foreign, exchange, market, by, September, 5, as, part, of, its, acquisition, of, U.S., firm, ,, Kingston, Technology, Co, .]","[NNP, NNP, VBD, IN, NNP, IN, PRP, MD, NN, $, CD, CD, IN, DT, JJ, NN, NN, IN, NNP, CD, IN, NN, IN, PRP$, NN, IN, NNP, NN, ,, NNP, NNP, NNP, .]","[B-NP, I-NP, B-VP, B-PP, B-NP, B-SBAR, B-NP, B-VP, B-NP, I-NP, I-NP, I-NP, B-PP, B-NP, I-NP, I-NP, I-NP, B-PP, B-NP, I-NP, B-PP, B-NP, B-PP, B-NP, I-NP, B-PP, B-NP, I-NP, O, B-NP, I-NP, I-NP, O]","[B-ORG, I-ORG, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, B-LOC, O, O, B-ORG, I-ORG, I-ORG, O]"
6,3923,"[Ghent, 3, Aalst, 2]","[NN, CD, NNP, CD]","[B-NP, I-NP, I-NP, I-NP]","[B-ORG, O, B-ORG, O]"
7,2776,"[The, separatists, ,, who, swept, into, Grozny, on, August, 6, ,, still, control, large, areas, of, the, centre, of, town, ,, and, Russian, soldiers, are, based, at, checkpoints, on, the, approach, roads, .]","[DT, NNS, ,, WP, VBD, IN, NNP, IN, NNP, CD, ,, RB, VBP, JJ, NNS, IN, DT, NN, IN, NN, ,, CC, JJ, NNS, VBP, VBN, IN, NNS, IN, DT, NN, NNS, .]","[B-NP, I-NP, O, B-NP, B-VP, B-PP, B-NP, B-PP, B-NP, I-NP, O, B-ADVP, B-VP, B-NP, I-NP, B-PP, B-NP, I-NP, B-PP, B-NP, O, O, B-NP, I-NP, B-VP, I-VP, B-PP, B-NP, B-PP, B-NP, I-NP, I-NP, O]","[O, O, O, O, O, O, B-LOC, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, B-MISC, O, O, O, O, O, O, O, O, O, O]"
8,1178,"[Doctor, Masserigne, Ndiaye, said, medical, staff, were, overwhelmed, with, work, ., ""]","[NNP, NNP, NNP, VBD, JJ, NN, VBD, VBN, IN, NN, ., ""]","[B-NP, I-NP, I-NP, B-VP, B-NP, I-NP, B-VP, I-VP, B-PP, B-NP, O, O]","[O, B-PER, I-PER, O, O, O, O, O, O, O, O, O]"
9,10988,"[Reuters, historical, calendar, -, September, 4, .]","[NNP, JJ, NN, :, NNP, CD, .]","[B-NP, I-NP, I-NP, O, B-NP, I-NP, O]","[B-ORG, O, O, O, O, O, O]"


## Preprocessing data

Before feeding the data into the model, we need to preprocess it. The preprocessing tool is called a `Tokenizer`. The `Tokenizer` first splits the input into tokens, then converts the tokens into the token IDs expected by the pre-trained model, and finally produces the input format the model requires.

To do this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which ensures:

- We get the tokenizer that matches the pre-trained model.
- We download the vocabulary that was used when pre-training the specified model checkpoint.

The downloaded vocabulary is cached, so it is not downloaded again the next time.

In [None]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Note: The following code requires the tokenizer to be a `transformers.PreTrainedTokenizerFast`, because preprocessing relies on features that only the fast tokenizers provide (such as multi-threaded tokenization and the `word_ids` mapping from subtokens to words).

Almost every model has a corresponding fast tokenizer. The [Model Tokenizer Correspondence Table](https://huggingface.co/transformers/index.html#bigtable) lists the tokenizer features available for each pre-trained model.

In [None]:
import transformers
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

Check the [big table of models here](https://huggingface.co/transformers/index.html#bigtable) to see if the model has a fast tokenizer.

The tokenizer can preprocess a single text or a pair of texts. The data obtained after tokenizer preprocessing meets the input format of the pre-trained model.

In [None]:
tokenizer("Hello, this is one sentence!")

{'input_ids': [101, 7592, 1010, 2023, 2003, 2028, 6251, 999, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [None]:
tokenizer(["Hello", ",", "this", "is", "one", "sentence", "split", "into", "words", "."], is_split_into_words=True)

{'input_ids': [101, 7592, 1010, 2023, 2003, 2028, 6251, 3975, 2046, 2616, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

Note that transformer pre-trained models usually use subwords during pre-training. If our text input has been segmented into words, these words will be further segmented by our tokenizer. For example:

In [None]:
example = datasets["train"][4]
print(example["tokens"])

['Germany', "'s", 'representative', 'to', 'the', 'European', 'Union', "'s", 'veterinary', 'committee', 'Werner', 'Zwingmann', 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.']


In [None]:
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
print(tokens)

['[CLS]', 'germany', "'", 's', 'representative', 'to', 'the', 'european', 'union', "'", 's', 'veterinary', 'committee', 'werner', 'z', '##wing', '##mann', 'said', 'on', 'wednesday', 'consumers', 'should', 'buy', 'sheep', '##me', '##at', 'from', 'countries', 'other', 'than', 'britain', 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.', '[SEP]']


The words "Zwingmann" and "sheepmeat" are each split into 3 subtokens.
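
The `##` prefixes in the output come from WordPiece-style subword splitting. As a rough illustration only, the sketch below implements greedy longest-match-first splitting over a tiny hand-written vocabulary. The function `wordpiece_split` and `toy_vocab` are hypothetical; the real tokenizer uses the full vocabulary of the checkpoint and also handles lowercasing and punctuation.

```python
def wordpiece_split(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of one lowercased word."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk_token]  # no piece matches: the whole word is unknown
        pieces.append(piece)
        start = end
    return pieces

# Toy vocabulary (an assumption made for this example).
toy_vocab = {"z", "##wing", "##mann", "sheep", "##me", "##at", "said"}
zwingmann_pieces = wordpiece_split("zwingmann", toy_vocab)
sheepmeat_pieces = wordpiece_split("sheepmeat", toy_vocab)
print(zwingmann_pieces)  # ['z', '##wing', '##mann']
print(sheepmeat_pieces)  # ['sheep', '##me', '##at']
print(wordpiece_split("said", toy_vocab))  # ['said']
```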

Since the data is annotated at the word level and words are split into subtokens, we need to align the labels with the subtokens. In addition, the input format of the pre-trained model requires special tokens such as `[CLS]` and `[SEP]`, which have no label of their own. As a result, the label list and the token list have different lengths:

In [None]:
len(example[f"{task}_tags"]), len(tokenized_input["input_ids"])

(31, 39)

The object returned by the tokenizer has a `word_ids` method that can help us with this.

In [None]:
print(tokenized_input.word_ids())

[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 10, 11, 11, 11, 12, 13, 14, 15, 16, 17, 18, 18, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, None]


We can see that `word_ids` maps each subtoken position to a word index. For example, position 1 corresponds to word 0, and positions 2 and 3 both correspond to word 1. Special tokens map to `None`. With this list, we can align subtokens with words and labels.
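
As a small illustration of this mapping, the sketch below groups subtokens by their word index. The `subtokens` and `word_ids` lists are hand-written in the shape the tokenizer returns (they are not copied from the output above), so the snippet runs without downloading anything.

```python
from collections import defaultdict

# Hand-written example: three words, the second one split into three subtokens.
subtokens = ["[CLS]", "werner", "z", "##wing", "##mann", "said", "[SEP]"]
word_ids = [None, 0, 1, 1, 1, 2, None]

pieces_per_word = defaultdict(list)
for subtoken, word_id in zip(subtokens, word_ids):
    if word_id is not None:  # special tokens belong to no word
        pieces_per_word[word_id].append(subtoken)

grouped = dict(pieces_per_word)
print(grouped)  # {0: ['werner'], 1: ['z', '##wing', '##mann'], 2: ['said']}
```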

In [None]:
word_ids = tokenized_input.word_ids()
aligned_labels = [-100 if i is None else example[f"{task}_tags"][i] for i in word_ids]
print(len(aligned_labels), len(tokenized_input["input_ids"]))

39 39


We usually set the label of special tokens to -100. PyTorch's cross-entropy loss ignores targets equal to -100 by default, so these positions do not contribute to the loss.
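
To see what ignoring a position means for the loss, here is a plain-Python sketch of a mean cross-entropy that skips every position labeled -100. It imitates the behaviour of an ignore index and is not the library implementation; `mean_cross_entropy` and the toy logits are made up for illustration.

```python
import math

def mean_cross_entropy(logits, labels, ignore_index=-100):
    """Average negative log-likelihood over positions whose label is not ignore_index."""
    total, counted = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked position: contributes nothing to the loss
        log_norm = math.log(sum(math.exp(s) for s in scores))
        total += log_norm - scores[label]
        counted += 1
    return total / counted if counted else 0.0

logits = [[9.0, 0.0], [0.0, 0.0], [0.0, 9.0]]
labels_all = [1, 0, 1]        # the first position is predicted badly
labels_masked = [-100, 0, 1]  # same labels, but the first position is ignored
loss_all = mean_cross_entropy(logits, labels_all)
loss_masked = mean_cross_entropy(logits, labels_masked)
print(loss_masked < loss_all)  # True
```

The masked loss is averaged over the two remaining positions only, so the badly predicted first position has no effect on it.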

We have two ways to align labels:
- Give every subtoken of a word that word's label.
- Give only the first subtoken of a word that word's label, and assign -100 to the other subtokens.

We provide both methods; switch between them with the `label_all_tokens` flag (`True` selects the first method).

In [None]:
label_all_tokens = True
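
The difference between the two strategies is easiest to see on one short example. The sketch below applies both to hand-written `word_ids` and word-level labels; `align_labels` is an illustrative single-example helper, and the inputs are assumptions rather than tokenizer output.

```python
def align_labels(word_ids, word_labels, label_all_tokens):
    """Expand word-level labels to subtoken-level labels for one example."""
    aligned, previous = [], None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(-100)  # special token
        elif word_id != previous:
            aligned.append(word_labels[word_id])  # first subtoken of a word
        else:
            # Later subtokens: repeat the label or mask them out.
            aligned.append(word_labels[word_id] if label_all_tokens else -100)
        previous = word_id
    return aligned

word_ids = [None, 0, 1, 1, 1, 2, None]
word_labels = [1, 2, 0]  # B-PER, I-PER, O
labels_every_subtoken = align_labels(word_ids, word_labels, label_all_tokens=True)
labels_first_subtoken = align_labels(word_ids, word_labels, label_all_tokens=False)
print(labels_every_subtoken)  # [-100, 1, 2, 2, 2, 0, -100]
print(labels_first_subtoken)  # [-100, 1, 2, -100, -100, 0, -100]
```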

Finally, we put everything together into our preprocessing function. The `is_split_into_words=True` argument was explained above.

In [None]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples[f"{task}_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

The preprocessing function above expects a batch of examples (a dict of lists, as returned by slicing the dataset). For each field it returns a list with one entry per example.

In [None]:
tokenize_and_align_labels(datasets['train'][:5])

{'input_ids': [[101, 7327, 19164, 2446, 2655, 2000, 17757, 2329, 12559, 1012, 102], [101, 2848, 13934, 102], [101, 9371, 2727, 1011, 5511, 1011, 2570, 102], [101, 1996, 2647, 3222, 2056, 2006, 9432, 2009, 18335, 2007, 2446, 6040, 2000, 10390, 2000, 18454, 2078, 2329, 12559, 2127, 6529, 5646, 3251, 5506, 11190, 4295, 2064, 2022, 11860, 2000, 8351, 1012, 102], [101, 2762, 1005, 1055, 4387, 2000, 1996, 2647, 2586, 1005, 1055, 15651, 2837, 14121, 1062, 9328, 5804, 2056, 2006, 9317, 10390, 2323, 4965, 8351, 4168, 4017, 2013, 3032, 2060, 2084, 3725, 2127, 1996, 4045, 6040, 2001, 24509, 1012, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'labels': [[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100], [-100, 1, 2, -100], [-100, 5, 0, 

Next, we preprocess all examples in `datasets` by using the `map` function to apply the preprocessing function `tokenize_and_align_labels` to every split.

In [None]:
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

Even better, the results are automatically cached so they are not recomputed the next time. The 🤗 Datasets library detects whether the arguments passed to `map` have changed: if nothing changed, it reuses the cached data; if something changed, it reprocesses the data. Be aware that the cache can get in the way when something changed that the library cannot detect. In that case, pass `load_from_cache_file=False` to `map` to ignore the cache and recompute. In addition, `batched=True` makes `map` pass a batch of examples to the preprocessing function at a time, which lets the fast tokenizer use multiple threads to process the inputs in parallel.
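
The shape of the data seen by a batched function can be sketched in plain Python. The helper `map_batched` below is a toy stand-in written for this explanation, not the 🤗 Datasets implementation (it does no caching and no parallelism), and `count_tokens` is a hypothetical preprocessing function.

```python
def map_batched(rows, fn, batch_size=1000):
    """Toy batched map: fn receives a dict of column name -> list of values."""
    out_rows = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # Convert a list of rows into a dict of columns.
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        result = fn(batch)
        # Turn the returned dict of lists back into one row per example.
        for i in range(len(chunk)):
            out_rows.append({key: values[i] for key, values in result.items()})
    return out_rows

def count_tokens(batch):
    # Works on whole columns at once, like tokenize_and_align_labels does.
    return {"tokens": batch["tokens"], "n_tokens": [len(t) for t in batch["tokens"]]}

rows = [{"tokens": ["EU", "rejects"]}, {"tokens": ["Peter", "Blackburn"]}, {"tokens": ["BRUSSELS"]}]
processed = map_batched(rows, count_tokens, batch_size=2)
token_counts = [row["n_tokens"] for row in processed]
print(token_counts)  # [2, 2, 1]
```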

## Fine-tune the pre-trained model

Now that the data is ready, we need to download and load our pre-trained model and then fine-tune it. Since we are doing a token classification task, we need a model class that can solve this task, so we use `AutoModelForTokenClassification`. As with the tokenizer, the `from_pretrained` method downloads and loads the model, and it caches the model so that we don't download it repeatedly.

In [None]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN t

Since we are fine-tuning for token classification but loaded a pre-trained language model, the warning tells us that some weights did not match when loading: the pre-training head of the language model was thrown away, and a token classification head was randomly initialized.

To build a `Trainer`, we need three more things. The most important is [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments), which holds all the settings that define the training process.

In [None]:
args = TrainingArguments(
    f"test-{task}",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

The `evaluation_strategy = "epoch"` argument above tells the training code to evaluate on the validation set once per epoch.

The `batch_size` was defined earlier in this notebook.

Finally, we need a data collator to feed our processed input to the model.

In [None]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)
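
Examples in a batch have different lengths, so the collator has to pad them to a common length. The sketch below shows the idea in plain Python under stated assumptions: `pad_batch` is a hypothetical helper, padding is added on the right, `input_ids` are padded with 0 (the `[PAD]` id of this checkpoint's vocabulary), and labels are padded with -100 so padded positions are ignored by the loss. The real collator also converts the result to tensors.

```python
def pad_batch(features, pad_token_id=0, label_pad_id=-100):
    """Right-pad every field to the length of the longest example in the batch."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * pad)  # 0 = do not attend
        batch["labels"].append(f["labels"] + [label_pad_id] * pad)       # -100 = ignored in the loss
    return batch

features = [
    {"input_ids": [101, 2848, 13934, 102], "attention_mask": [1, 1, 1, 1], "labels": [-100, 1, 2, -100]},
    {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1], "labels": [-100, 0, -100]},  # toy example
]
padded = pad_batch(features)
print(padded["input_ids"])       # [[101, 2848, 13934, 102], [101, 7592, 102, 0]]
print(padded["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
print(padded["labels"])          # [[-100, 1, 2, -100], [-100, 0, -100, -100]]
```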

The last thing left to set up the `Trainer` is to define the evaluation method. We use the [`seqeval`](https://github.com/chakki-works/seqeval) metric to complete the evaluation. Before sending the model predictions to the evaluation, we will also do some data post-processing:

In [None]:
metric = load_metric("seqeval")

The metric takes a list of predictions and a list of labels as input:

In [None]:
labels = [label_list[i] for i in example[f"{task}_tags"]]
metric.compute(predictions=[labels], references=[labels])

{'LOC': {'f1': 1.0, 'number': 2, 'precision': 1.0, 'recall': 1.0},
 'ORG': {'f1': 1.0, 'number': 1, 'precision': 1.0, 'recall': 1.0},
 'PER': {'f1': 1.0, 'number': 1, 'precision': 1.0, 'recall': 1.0},
 'overall_accuracy': 1.0,
 'overall_f1': 1.0,
 'overall_precision': 1.0,
 'overall_recall': 1.0}
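
Precision, recall, and F1 here are computed over whole entities rather than single tokens. The sketch below is a simplified approximation of that idea, not the `seqeval` implementation: an entity counts as correct only if its type and both boundaries match exactly, and stray I- tags are dropped (which may differ from how `seqeval` treats them). The helper names are illustrative.

```python
def extract_spans(tags):
    """Return a set of (type, start, end) spans from one BIO tag sequence."""
    spans, start, current = set(), None, None
    for i, tag in enumerate(tags + ["O"]):  # the sentinel closes a trailing entity
        continues = tag.startswith("I-") and current == tag[2:]
        if current is not None and not continues:
            spans.add((current, start, i))
            current = None
        if tag.startswith("B-"):
            start, current = i, tag[2:]
    return spans

def entity_prf(reference, prediction):
    """Exact-match entity precision, recall, and F1 for one sentence."""
    ref, pred = extract_spans(reference), extract_spans(prediction)
    tp = len(ref & pred)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(ref) if ref else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1

reference = ["B-PER", "I-PER", "O", "B-LOC"]
prediction = ["B-PER", "O", "O", "B-LOC"]  # the person entity is cut short
precision, recall, f1 = entity_prf(reference, prediction)
print(precision, recall, f1)  # 0.5 0.5 0.5
```

In this toy case three of the four tokens are tagged correctly, yet only one of the two entities is fully correct, which is why entity-level scores are usually lower than token accuracy.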

Do some post-processing on the model predictions:
- Select the index with the highest predicted score for each token
- Convert the index to a label
- Ignore positions labeled -100

The following function combines the above steps.

In [None]:
import numpy as np

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

We report the overall precision/recall/f1 across all categories, so the precision/recall/f1 of each individual category is discarded here.
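
To follow the post-processing steps on concrete numbers, here is a toy walk-through for a single sentence in plain Python. The shortened label set, the scores, and the helper `decode` are made up for illustration; the real function works on NumPy arrays covering the whole evaluation set.

```python
toy_label_names = ["O", "B-PER", "I-PER"]  # shortened, hypothetical label set

def decode(logits, label_ids):
    """Argmax each position, map ids to names, and drop positions labeled -100."""
    preds, refs = [], []
    for scores, label_id in zip(logits, label_ids):
        if label_id == -100:
            continue  # special token or masked subtoken
        best = max(range(len(scores)), key=scores.__getitem__)  # index of the highest score
        preds.append(toy_label_names[best])
        refs.append(toy_label_names[label_id])
    return preds, refs

logits = [
    [2.0, 0.1, 0.1],  # [CLS], ignored
    [0.2, 3.0, 0.5],  # highest score at index 1
    [0.1, 0.4, 2.5],  # highest score at index 2
    [1.5, 0.2, 0.9],  # highest score at index 0
    [2.0, 0.1, 0.1],  # [SEP], ignored
]
label_ids = [-100, 1, 2, 2, -100]
preds, refs = decode(logits, label_ids)
print(preds)  # ['B-PER', 'I-PER', 'O']
print(refs)   # ['B-PER', 'I-PER', 'I-PER']
```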

Pass the data, model, and arguments to `Trainer`:

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Call the `train` method to start training:

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.237721,0.068198,0.903148,0.921132,0.912051,0.979713
2,0.05316,0.059337,0.927697,0.93299,0.930336,0.983113
3,0.02985,0.059346,0.929267,0.939143,0.934179,0.984257


TrainOutput(global_step=2634, training_loss=0.08569671253227518)
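
The `global_step` value can be checked by hand. The arithmetic below assumes training on a single device, no gradient accumulation, and that the last, smaller batch of each epoch is kept; under those assumptions it reproduces the number reported above.

```python
import math

train_examples = 14041  # num_rows of the train split shown earlier
batch_size = 16
num_epochs = 3

steps_per_epoch = math.ceil(train_examples / batch_size)  # the last batch holds the 9 leftover examples
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 878 2634
```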

We can call the `evaluate` method to run evaluation again. By default it uses the validation set passed to `Trainer`, and it also accepts another dataset.

In [None]:
trainer.evaluate()

{'eval_loss': 0.05934586375951767,
 'eval_precision': 0.9292672127518264,
 'eval_recall': 0.9391430808815304,
 'eval_f1': 0.9341790463472988,
 'eval_accuracy': 0.9842565968195466,
 'epoch': 3.0}

If we want the precision/recall/f1 of each individual category, we can feed the predictions into the same metric and keep the full result:

In [None]:
predictions, labels, _ = trainer.predict(tokenized_datasets["validation"])
predictions = np.argmax(predictions, axis=2)

# Remove ignored index (special tokens)
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

{'LOC': {'precision': 0.949718574108818,
  'recall': 0.966768525592055,
  'f1': 0.9581677077418134,
  'number': 2618},
 'MISC': {'precision': 0.8132387706855791,
  'recall': 0.8383428107229894,
  'f1': 0.8255999999999999,
  'number': 1231},
 'ORG': {'precision': 0.9055232558139535,
  'recall': 0.9090466926070039,
  'f1': 0.9072815533980583,
  'number': 2056},
 'PER': {'precision': 0.9759552042160737,
  'recall': 0.9765985497692815,
  'f1': 0.9762767710049424,
  'number': 3034},
 'overall_precision': 0.9292672127518264,
 'overall_recall': 0.9391430808815304,
 'overall_f1': 0.9341790463472988,
 'overall_accuracy': 0.9842565968195466}
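
The overall scores are micro-averaged: entity counts are pooled across categories before dividing, so frequent categories weigh more. As a consistency check, the sketch below recovers integer counts from the per-category results above and recomputes the overall numbers. Recovering counts by rounding is an approach chosen for this illustration, not how the metric is computed internally.

```python
per_class = {
    "LOC": {"precision": 0.949718574108818, "recall": 0.966768525592055, "number": 2618},
    "MISC": {"precision": 0.8132387706855791, "recall": 0.8383428107229894, "number": 1231},
    "ORG": {"precision": 0.9055232558139535, "recall": 0.9090466926070039, "number": 2056},
    "PER": {"precision": 0.9759552042160737, "recall": 0.9765985497692815, "number": 3034},
}

true_positives = predicted = gold = 0
for scores in per_class.values():
    tp = round(scores["recall"] * scores["number"])  # correctly predicted entities
    true_positives += tp
    predicted += round(tp / scores["precision"])     # entities the model predicted
    gold += scores["number"]                         # entities in the reference labels

micro_precision = true_positives / predicted
micro_recall = true_positives / gold
micro_f1 = 2 * micro_precision * micro_recall / (micro_precision + micro_recall)
print(true_positives, predicted, gold)  # 8395 9034 8939
print(micro_precision, micro_recall, micro_f1)
```

The three printed scores agree with `overall_precision`, `overall_recall`, and `overall_f1` above up to floating-point rounding.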

Finally, don’t forget to upload your model to [🤗 Model Hub](https://huggingface.co/models) (click [here](https://huggingface.co/transformers/model_sharing.html) to see how to upload). Then you can use your uploaded model directly by using the model name, just like at the beginning of this notebook.