# fastner 
*fastner* is a Python package to finetune transformer-based models for the Named Entity Recognition task in a simple and fast way.


In [None]:
!pip install --upgrade fastner

In [2]:
import pandas as pd
from ast import literal_eval 

from transformers import TrainingArguments, EarlyStoppingCallback
from fastner import train_test

  metric = load_metric("seqeval")


Downloading builder script:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

## Check dataset and tagging scheme

**Dataset scheme**

The datasets given as input (train, validation, test) must have two columns named:

* *tokens*: contains the tokens of the several examples
* *tags*: contains the labels of the respective tokens


**Tagging scheme**

The labels of the dataset given as input must comply with the tagging scheme:

* IOB (Inside, Outside, Beginning), also known as BIO

Training set

In [3]:
conll2003_train = pd.read_csv('./dataset/conll2003/conll2003_train.csv', index_col=0, converters={'tags': literal_eval, 'tokens': literal_eval}).reset_index()

conll2003_train[:5]

Unnamed: 0,index,tokens,tags
0,0,"[EU, rejects, German, call, to, boycott, Briti...","[B-ORG, O, B-MISC, O, O, O, B-MISC, O, O]"
1,1,"[Peter, Blackburn]","[B-PER, I-PER]"
2,2,"[BRUSSELS, 1996-08-22]","[B-LOC, O]"
3,3,"[The, European, Commission, said, on, Thursday...","[O, B-ORG, I-ORG, O, O, O, O, O, O, B-MISC, O,..."
4,4,"[Germany, 's, representative, to, the, Europea...","[B-LOC, O, O, O, O, B-ORG, I-ORG, O, O, O, B-P..."


Validation set

In [4]:
conll2003_val = pd.read_csv('./dataset/conll2003/conll2003_val.csv', index_col=0, converters={'tags': literal_eval, 'tokens': literal_eval}).reset_index()

conll2003_val[:5]

Unnamed: 0,index,tokens,tags
0,0,"[CRICKET, -, LEICESTERSHIRE, TAKE, OVER, AT, T...","[O, O, B-ORG, O, O, O, O, O, O, O, O]"
1,1,"[LONDON, 1996-08-30]","[B-LOC, O]"
2,2,"[West, Indian, all-rounder, Phil, Simmons, too...","[B-MISC, I-MISC, O, B-PER, I-PER, O, O, O, O, ..."
3,3,"[Their, stay, on, top, ,, though, ,, may, be, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, B-ORG,..."
4,4,"[After, bowling, Somerset, out, for, 83, on, t...","[O, O, B-ORG, O, O, O, O, O, O, O, O, B-LOC, I..."


Test set

In [5]:
conll2003_test = pd.read_csv('./dataset/conll2003/conll2003_test.csv', index_col=0, converters={'tags': literal_eval, 'tokens': literal_eval}).reset_index()
conll2003_test[:5]

Unnamed: 0,index,tokens,tags
0,0,"[SOCCER, -, JAPAN, GET, LUCKY, WIN, ,, CHINA, ...","[O, O, B-LOC, O, O, O, O, B-PER, O, O, O, O]"
1,1,"[Nadim, Ladki]","[B-PER, I-PER]"
2,2,"[AL-AIN, ,, United, Arab, Emirates, 1996-12-06]","[B-LOC, O, B-LOC, I-LOC, I-LOC, O]"
3,3,"[Japan, began, the, defence, of, their, Asian,...","[B-LOC, O, O, O, O, O, B-MISC, I-MISC, O, O, O..."
4,4,"[But, China, saw, their, luck, desert, them, i...","[O, B-LOC, O, O, O, O, O, O, O, O, O, O, O, O,..."


## Let's finetune

To finetune our transofrmer-based model we need to:
* define the *TrainingArguments* (arguments for the training see [hugginface documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments))
* define the *train_test* arguments of fastner (see [fastner documentation](https://pypi.org/project/fastner/))

In [6]:
args  = TrainingArguments(
    num_train_epochs = 5,
    per_device_train_batch_size = 32,
    per_device_eval_batch_size = 8,
    overwrite_output_dir=True,
    output_dir= "./models",  
    save_total_limit=1,
    evaluation_strategy="epoch",
    logging_strategy = "epoch",
    save_strategy = "epoch",
    fp16=True,   
    warmup_steps = 1000,
    load_best_model_at_end= True,
    metric_for_best_model = 'eval_loss'
    )

train_results, eval_results, test_results, trainer = train_test(
    training_set = conll2003_train, # you can use also the path ./dataset/conll2003/conll2003_train.csv without load the dataset
    validation_set = conll2003_val, 
    test_set=conll2003_test,
    train_args = args, 
    model_name='distilbert-base-uncased', 
    max_len=128, 
    loss='CE', 
    callbacks= [EarlyStoppingCallback(early_stopping_patience=3)], 
    device=0 )

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/043235d6088ecd3dd5fb5ca3592b6913fd516027/config.json
Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "id2label": {
    "0": "O",
    "1": "B-LOC",
    "2": "B-MISC",
    "3": "B-ORG",
    "4": "B-PER",
    "5": "I-LOC",
    "6": "I-MISC",
    "7": "I-ORG",
    "8": "I-PER"
  },
  "initializer_range": 0.02,
  "label2id": {
    "B-LOC": 1,
    "B-MISC": 2,
    "B-ORG": 3,
    "B-PER": 4,
    "I-LOC": 5,
    "I-MISC": 6,
    "I-ORG": 7,
    "I-PER": 8,
    "O": 0
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.22.2

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

loading weights file pytorch_model.bin from cache at /root/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/043235d6088ecd3dd5fb5ca3592b6913fd516027/pytorch_model.bin
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenC


TRAINING...



Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForTokenClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN t

{'loss': 0.5533, 'learning_rate': 2.1850000000000003e-05, 'epoch': 1.0}


Saving model checkpoint to ./models/checkpoint-439
Configuration saved in ./models/checkpoint-439/config.json


{'eval_loss': 0.06205374747514725, 'eval_LOC_precision': 0.8819409704852427, 'eval_LOC_recall': 0.95971692977681, 'eval_LOC_f1': 0.9191866527632951, 'eval_LOC_number': 1837, 'eval_MISC_precision': 0.8026004728132388, 'eval_MISC_recall': 0.7364425162689805, 'eval_MISC_f1': 0.7680995475113123, 'eval_MISC_number': 922, 'eval_ORG_precision': 0.8433823529411765, 'eval_ORG_recall': 0.8553318419090231, 'eval_ORG_f1': 0.8493150684931507, 'eval_ORG_number': 1341, 'eval_PER_precision': 0.967274678111588, 'eval_PER_recall': 0.9804241435562806, 'eval_PER_f1': 0.9738050229543613, 'eval_PER_number': 1839, 'eval_overall_precision': 0.8884494974460372, 'eval_overall_recall': 0.9078969523488802, 'eval_overall_f1': 0.8980679546968687, 'eval_overall_accuracy': 0.9828847051501478, 'eval_runtime': 11.5064, 'eval_samples_per_second': 282.451, 'eval_steps_per_second': 35.372, 'epoch': 1.0}


Model weights saved in ./models/checkpoint-439/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 3250
  Batch size = 8
The following columns in the evaluation set don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: token_type_ids. If token_type_ids are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.


{'loss': 0.0512, 'learning_rate': 4.38e-05, 'epoch': 2.0}


Saving model checkpoint to ./models/checkpoint-878
Configuration saved in ./models/checkpoint-878/config.json


{'eval_loss': 0.04518156126141548, 'eval_LOC_precision': 0.9382585751978891, 'eval_LOC_recall': 0.9678824169842134, 'eval_LOC_f1': 0.9528403001071811, 'eval_LOC_number': 1837, 'eval_MISC_precision': 0.7929240374609782, 'eval_MISC_recall': 0.8264642082429501, 'eval_MISC_f1': 0.8093467870419544, 'eval_MISC_number': 922, 'eval_ORG_precision': 0.8927469135802469, 'eval_ORG_recall': 0.8627889634601044, 'eval_ORG_f1': 0.8775123246113008, 'eval_ORG_number': 1341, 'eval_PER_precision': 0.9756493506493507, 'eval_PER_recall': 0.9804241435562806, 'eval_PER_f1': 0.9780309194467046, 'eval_PER_number': 1839, 'eval_overall_precision': 0.9166666666666666, 'eval_overall_recall': 0.9260818319582421, 'eval_overall_f1': 0.9213501968339057, 'eval_overall_accuracy': 0.9869128503016786, 'eval_runtime': 11.3257, 'eval_samples_per_second': 286.958, 'eval_steps_per_second': 35.936, 'epoch': 2.0}


Model weights saved in ./models/checkpoint-878/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 3250
  Batch size = 8
The following columns in the evaluation set don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: token_type_ids. If token_type_ids are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.


{'loss': 0.0271, 'learning_rate': 3.6820083682008375e-05, 'epoch': 3.0}


Saving model checkpoint to ./models/checkpoint-1317
Configuration saved in ./models/checkpoint-1317/config.json


{'eval_loss': 0.04322735592722893, 'eval_LOC_precision': 0.9299065420560748, 'eval_LOC_recall': 0.974959172563963, 'eval_LOC_f1': 0.9519000797236249, 'eval_LOC_number': 1837, 'eval_MISC_precision': 0.876243093922652, 'eval_MISC_recall': 0.8600867678958786, 'eval_MISC_f1': 0.8680897646414889, 'eval_MISC_number': 922, 'eval_ORG_precision': 0.8976674191121143, 'eval_ORG_recall': 0.889634601043997, 'eval_ORG_f1': 0.8936329588014981, 'eval_ORG_number': 1341, 'eval_PER_precision': 0.9761517615176152, 'eval_PER_recall': 0.9793365959760739, 'eval_PER_f1': 0.9777415852334419, 'eval_PER_number': 1839, 'eval_overall_precision': 0.9288925895087428, 'eval_overall_recall': 0.9392153561205591, 'eval_overall_f1': 0.934025452109846, 'eval_overall_accuracy': 0.9882440399011116, 'eval_runtime': 11.3389, 'eval_samples_per_second': 286.623, 'eval_steps_per_second': 35.894, 'epoch': 3.0}


Model weights saved in ./models/checkpoint-1317/pytorch_model.bin
Deleting older checkpoint [models/checkpoint-439] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 3250
  Batch size = 8
The following columns in the evaluation set don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: token_type_ids. If token_type_ids are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.


{'loss': 0.0131, 'learning_rate': 1.8451882845188285e-05, 'epoch': 4.0}


Saving model checkpoint to ./models/checkpoint-1756
Configuration saved in ./models/checkpoint-1756/config.json


{'eval_loss': 0.0419057197868824, 'eval_LOC_precision': 0.9614758545849159, 'eval_LOC_recall': 0.9646162221012521, 'eval_LOC_f1': 0.9630434782608697, 'eval_LOC_number': 1837, 'eval_MISC_precision': 0.847265221878225, 'eval_MISC_recall': 0.8904555314533622, 'eval_MISC_f1': 0.8683236382866207, 'eval_MISC_number': 922, 'eval_ORG_precision': 0.904029304029304, 'eval_ORG_recall': 0.9202087994034303, 'eval_ORG_f1': 0.9120473022912048, 'eval_ORG_number': 1341, 'eval_PER_precision': 0.9772357723577236, 'eval_PER_recall': 0.9804241435562806, 'eval_PER_f1': 0.9788273615635179, 'eval_PER_number': 1839, 'eval_overall_precision': 0.9349053470607771, 'eval_overall_recall': 0.947971038895437, 'eval_overall_f1': 0.9413928601287518, 'eval_overall_accuracy': 0.9897826876199367, 'eval_runtime': 11.1992, 'eval_samples_per_second': 290.198, 'eval_steps_per_second': 36.342, 'epoch': 4.0}


Model weights saved in ./models/checkpoint-1756/pytorch_model.bin
Deleting older checkpoint [models/checkpoint-878] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 3250
  Batch size = 8
The following columns in the evaluation set don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: token_type_ids. If token_type_ids are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.


{'loss': 0.0067, 'learning_rate': 8.368200836820084e-08, 'epoch': 5.0}


Saving model checkpoint to ./models/checkpoint-2195
Configuration saved in ./models/checkpoint-2195/config.json


{'eval_loss': 0.044513680040836334, 'eval_LOC_precision': 0.961017866811045, 'eval_LOC_recall': 0.9662493195427327, 'eval_LOC_f1': 0.9636264929424538, 'eval_LOC_number': 1837, 'eval_MISC_precision': 0.878755364806867, 'eval_MISC_recall': 0.8882863340563991, 'eval_MISC_f1': 0.883495145631068, 'eval_MISC_number': 922, 'eval_ORG_precision': 0.9096296296296297, 'eval_ORG_recall': 0.9157345264727815, 'eval_ORG_f1': 0.9126718691936083, 'eval_ORG_number': 1341, 'eval_PER_precision': 0.9799240368963646, 'eval_PER_recall': 0.9820554649265906, 'eval_PER_f1': 0.9809885931558937, 'eval_PER_number': 1839, 'eval_overall_precision': 0.942397856664434, 'eval_overall_recall': 0.9476342818656339, 'eval_overall_f1': 0.9450088153807404, 'eval_overall_accuracy': 0.9898172639731687, 'eval_runtime': 11.0912, 'eval_samples_per_second': 293.024, 'eval_steps_per_second': 36.696, 'epoch': 5.0}


Model weights saved in ./models/checkpoint-2195/pytorch_model.bin
Deleting older checkpoint [models/checkpoint-1317] due to args.save_total_limit


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from ./models/checkpoint-1756 (score: 0.0419057197868824).
***** Running Evaluation *****
  Num examples = 3250
  Batch size = 8
The following columns in the evaluation set don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: token_type_ids. If token_type_ids are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.


{'train_runtime': 444.3873, 'train_samples_per_second': 157.982, 'train_steps_per_second': 4.939, 'train_loss': 0.13027186328566429, 'epoch': 5.0}


***** Running Prediction *****
  Num examples = 3453
  Batch size = 8
The following columns in the test set don't have a corresponding argument in `DistilBertForTokenClassification.forward` and have been ignored: token_type_ids. If token_type_ids are not expected by `DistilBertForTokenClassification.forward`,  you can safely ignore this message.


{'eval_loss': 0.0419057197868824, 'eval_LOC_precision': 0.9614758545849159, 'eval_LOC_recall': 0.9646162221012521, 'eval_LOC_f1': 0.9630434782608697, 'eval_LOC_number': 1837, 'eval_MISC_precision': 0.847265221878225, 'eval_MISC_recall': 0.8904555314533622, 'eval_MISC_f1': 0.8683236382866207, 'eval_MISC_number': 922, 'eval_ORG_precision': 0.904029304029304, 'eval_ORG_recall': 0.9202087994034303, 'eval_ORG_f1': 0.9120473022912048, 'eval_ORG_number': 1341, 'eval_PER_precision': 0.9772357723577236, 'eval_PER_recall': 0.9804241435562806, 'eval_PER_f1': 0.9788273615635179, 'eval_PER_number': 1839, 'eval_overall_precision': 0.9349053470607771, 'eval_overall_recall': 0.947971038895437, 'eval_overall_f1': 0.9413928601287518, 'eval_overall_accuracy': 0.9897826876199367, 'eval_runtime': 11.3038, 'eval_samples_per_second': 287.513, 'eval_steps_per_second': 36.006, 'epoch': 5.0}
TEST...
END TEST...


## Check the results

In [7]:
train_results

{'train_runtime': 444.3873,
 'train_samples_per_second': 157.982,
 'train_steps_per_second': 4.939,
 'train_loss': 0.13027186328566429,
 'epoch': 5.0}

In [8]:
eval_results

{'eval_loss': 0.0419057197868824,
 'eval_LOC_precision': 0.9614758545849159,
 'eval_LOC_recall': 0.9646162221012521,
 'eval_LOC_f1': 0.9630434782608697,
 'eval_LOC_number': 1837,
 'eval_MISC_precision': 0.847265221878225,
 'eval_MISC_recall': 0.8904555314533622,
 'eval_MISC_f1': 0.8683236382866207,
 'eval_MISC_number': 922,
 'eval_ORG_precision': 0.904029304029304,
 'eval_ORG_recall': 0.9202087994034303,
 'eval_ORG_f1': 0.9120473022912048,
 'eval_ORG_number': 1341,
 'eval_PER_precision': 0.9772357723577236,
 'eval_PER_recall': 0.9804241435562806,
 'eval_PER_f1': 0.9788273615635179,
 'eval_PER_number': 1839,
 'eval_overall_precision': 0.9349053470607771,
 'eval_overall_recall': 0.947971038895437,
 'eval_overall_f1': 0.9413928601287518,
 'eval_overall_accuracy': 0.9897826876199367,
 'eval_runtime': 11.3038,
 'eval_samples_per_second': 287.513,
 'eval_steps_per_second': 36.006,
 'epoch': 5.0}

In [9]:
test_results

{'test_loss': 0.10277716815471649,
 'test_LOC_precision': 0.9061032863849765,
 'test_LOC_recall': 0.9262147570485902,
 'test_LOC_f1': 0.9160486502521507,
 'test_LOC_number': 1667,
 'test_MISC_precision': 0.7425997425997426,
 'test_MISC_recall': 0.8219373219373219,
 'test_MISC_f1': 0.7802569303583502,
 'test_MISC_number': 702,
 'test_ORG_precision': 0.8478642480983031,
 'test_ORG_recall': 0.872366044551475,
 'test_ORG_f1': 0.8599406528189911,
 'test_ORG_number': 1661,
 'test_PER_precision': 0.972396486825596,
 'test_PER_recall': 0.9591584158415841,
 'test_PER_f1': 0.9657320872274143,
 'test_PER_number': 1616,
 'test_overall_precision': 0.8852005532503457,
 'test_overall_recall': 0.9068366985476444,
 'test_overall_f1': 0.8958880139982501,
 'test_overall_accuracy': 0.9810804222684743,
 'test_runtime': 10.8545,
 'test_samples_per_second': 318.117,
 'test_steps_per_second': 39.799}

## Load the model and do inference

In [33]:
import numpy as np
from transformers import DistilBertForTokenClassification, DistilBertTokenizer
import torch

**Load the model**

In [11]:
model = DistilBertForTokenClassification.from_pretrained('./models/checkpoint-2195')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', do_lower_case=True)

loading configuration file ./models/checkpoint-2195/config.json
Model config DistilBertConfig {
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForTokenClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "id2label": {
    "0": "O",
    "1": "B-LOC",
    "2": "B-MISC",
    "3": "B-ORG",
    "4": "B-PER",
    "5": "I-LOC",
    "6": "I-MISC",
    "7": "I-ORG",
    "8": "I-PER"
  },
  "initializer_range": 0.02,
  "label2id": {
    "B-LOC": 1,
    "B-MISC": 2,
    "B-ORG": 3,
    "B-PER": 4,
    "I-LOC": 5,
    "I-MISC": 6,
    "I-ORG": 7,
    "I-PER": 8,
    "O": 0
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float32",
  "transformers_version": "4.22.2",
  "vocab_size": 30522
}



**Inference**

In [74]:
text = ["Apple CEO Tim Cook introduces the new iPhone"]
encoding = tokenizer(text, padding=True, return_tensors="pt")
predicted = model(**encoding)
predicted_id = torch.argmax(predicted.logits, -1)
predicted_id

tensor([[0, 3, 0, 4, 8, 0, 0, 0, 2, 0]])

In [78]:
 #Convert id to label and remove the first and the last tokens (101 and 102),
 #are tokens used by the tokenizer to delimit the sentence
 [model.config.id2label[k] for k in predicted_id.tolist()[0]][1:-1]

['B-ORG', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'B-MISC']