# Контекст

- NER - классификация именнованных токенов
- bert модель
- conll2003 - dataset 

In [72]:
import numpy as np
import pandas as pd
import datasets
from transformers import BertTokenizerFast,\
                         AutoModelForTokenClassification,\
                         TrainingArguments,\
                         Trainer,\
                         DataCollatorForTokenClassification,\
                         pipeline
from sklearn.metrics import precision_recall_fscore_support,\
                            accuracy_score

# Чтение файла

In [50]:
conll2003 = datasets.load_dataset('conll2003')

In [51]:
conll2003['train'][0:2]

{'id': ['0', '1'],
 'tokens': [['EU',
   'rejects',
   'German',
   'call',
   'to',
   'boycott',
   'British',
   'lamb',
   '.'],
  ['Peter', 'Blackburn']],
 'pos_tags': [[22, 42, 16, 21, 35, 37, 16, 21, 7], [22, 22]],
 'chunk_tags': [[11, 21, 11, 12, 21, 22, 11, 12, 0], [11, 12]],
 'ner_tags': [[3, 0, 7, 0, 0, 0, 7, 0, 0], [1, 2]]}

In [52]:
conll2003['train'].features['ner_tags']

Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)

# Токенизация и выравнивание

In [53]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [54]:
def tokenization_align_labels(
        
        examples,  # dict {'tokens': ...., 'ner_tags: .... и др'}
        label_all_tokens=False

        ):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    labels = []
    for i, label in enumerate(examples["ner_tags"]): 
        word_ids = tokenized_inputs.word_ids(batch_index=i) # [None 1 2 3 4 5 6 None]
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Пример использования функции
print(conll2003['train'][0:2]['tokens'])
tokenization_align_labels(conll2003['train'][0:2], True)['labels'] # на выходе нормализованные индексы токенов

[['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], ['Peter', 'Blackburn']]


[[-100, 3, 0, 7, 0, 0, 0, 7, 0, 0, -100], [-100, 1, 2, -100]]

In [65]:
tokenized_datasets = conll2003.map(tokenization_align_labels, batched = True)

# Скачивание модели

In [56]:
model = AutoModelForTokenClassification.from_pretrained(
    'bert-base-uncased',
    num_labels = 9 # кол-во имен собственных в данных
)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Training args

In [57]:
training_args = TrainingArguments(

        output_dir='training/model_points',
        do_train=True,
        do_eval=True,
        num_train_epochs=3,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=64,
        warmup_steps=100,
        weight_decay=0.01,
        logging_strategy='steps',
        logging_dir='training/logs',
        logging_steps=50,
        evaluation_strategy="steps",
        save_steps = 500,
        fp16=True,
        load_best_model_at_end=True,
        report_to = 'wandb',
        learning_rate = 0.00002

)

In [58]:
data_collator = DataCollatorForTokenClassification(tokenizer)

# Метрики

In [69]:
metric = datasets.load_metric('seqeval')
label_list = conll2003["train"].features["ner_tags"].feature.names 
label_list

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

In [70]:
def compute_metrics(p):
    predictions, labels = p 
    predictions = np.argmax(predictions, axis=2) 

    true_predictions = [ 
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 
        for prediction, label in zip(predictions, labels) 
                        ] 
    
    true_labels = [ 
      [label_list[l] for (p, l) in zip(prediction, label) if l != -100] 
       for prediction, label in zip(predictions, labels) 
                  ] 
    
    results = metric.compute(predictions=true_predictions, references=true_labels)


    return { 
                "precision": results["overall_precision"], 
                "recall": results["overall_recall"], 
                "f1": results["overall_f1"], 
                "accuracy": results["overall_accuracy"], 
           } 

# Trainer

In [71]:
trainer = Trainer(

        model = model,
        args = training_args,
        train_dataset = tokenized_datasets['train'],
        eval_dataset = tokenized_datasets['validation'],
        compute_metrics = compute_metrics,
        tokenizer = tokenizer,
        data_collator = data_collator

 )

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


  0%|          | 0/1317 [00:00<?, ?it/s]

{'loss': 0.1273, 'grad_norm': 1.7592236995697021, 'learning_rate': 1e-05, 'epoch': 0.11}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.11439711600542068, 'eval_precision': 0.8167369445345443, 'eval_recall': 0.8475260854930999, 'eval_f1': 0.8318467129170796, 'eval_accuracy': 0.9708734083563724, 'eval_runtime': 2.9826, 'eval_samples_per_second': 1089.636, 'eval_steps_per_second': 17.099, 'epoch': 0.11}
{'loss': 0.083, 'grad_norm': 1.4185106754302979, 'learning_rate': 2e-05, 'epoch': 0.23}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.08942238986492157, 'eval_precision': 0.8516528252727569, 'eval_recall': 0.8801750252440256, 'eval_f1': 0.8656790532152611, 'eval_accuracy': 0.9769479381644017, 'eval_runtime': 2.9965, 'eval_samples_per_second': 1084.584, 'eval_steps_per_second': 17.02, 'epoch': 0.23}
{'loss': 0.0982, 'grad_norm': 2.1030335426330566, 'learning_rate': 1.9178307313064914e-05, 'epoch': 0.34}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.07006510347127914, 'eval_precision': 0.8789611238157465, 'eval_recall': 0.905587344328509, 'eval_f1': 0.8920755968169762, 'eval_accuracy': 0.9803940656516491, 'eval_runtime': 2.9762, 'eval_samples_per_second': 1091.982, 'eval_steps_per_second': 17.136, 'epoch': 0.34}
{'loss': 0.0788, 'grad_norm': 1.5318435430526733, 'learning_rate': 1.835661462612983e-05, 'epoch': 0.46}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.06393799930810928, 'eval_precision': 0.8849730788056779, 'eval_recall': 0.9128239649949512, 'eval_f1': 0.8986827934719577, 'eval_accuracy': 0.9814259569331413, 'eval_runtime': 2.9759, 'eval_samples_per_second': 1092.091, 'eval_steps_per_second': 17.137, 'epoch': 0.46}
{'loss': 0.08, 'grad_norm': 1.9071953296661377, 'learning_rate': 1.753492193919474e-05, 'epoch': 0.57}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.05609885975718498, 'eval_precision': 0.9065435965056865, 'eval_recall': 0.9256142712891282, 'eval_f1': 0.9159796819052377, 'eval_accuracy': 0.9842685253689498, 'eval_runtime': 2.9818, 'eval_samples_per_second': 1089.95, 'eval_steps_per_second': 17.104, 'epoch': 0.57}
{'loss': 0.0683, 'grad_norm': 1.886765718460083, 'learning_rate': 1.6713229252259656e-05, 'epoch': 0.68}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.054820895195007324, 'eval_precision': 0.90753255315642, 'eval_recall': 0.9266240323123528, 'eval_f1': 0.9169789324673162, 'eval_accuracy': 0.9838207234920758, 'eval_runtime': 2.9997, 'eval_samples_per_second': 1083.436, 'eval_steps_per_second': 17.002, 'epoch': 0.68}
{'loss': 0.067, 'grad_norm': 1.2380889654159546, 'learning_rate': 1.5891536565324572e-05, 'epoch': 0.8}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.05572396144270897, 'eval_precision': 0.9067629678266579, 'eval_recall': 0.9296533153820262, 'eval_f1': 0.9180654811367791, 'eval_accuracy': 0.9841906467816673, 'eval_runtime': 3.0093, 'eval_samples_per_second': 1079.974, 'eval_steps_per_second': 16.947, 'epoch': 0.8}
{'loss': 0.0656, 'grad_norm': 1.930906057357788, 'learning_rate': 1.5069843878389482e-05, 'epoch': 0.91}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.051449473947286606, 'eval_precision': 0.9083632799738648, 'eval_recall': 0.935880175025244, 'eval_f1': 0.9219164456233423, 'eval_accuracy': 0.9856898095868541, 'eval_runtime': 3.0338, 'eval_samples_per_second': 1071.278, 'eval_steps_per_second': 16.811, 'epoch': 0.91}
{'loss': 0.0542, 'grad_norm': 1.703126311302185, 'learning_rate': 1.4248151191454397e-05, 'epoch': 1.03}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.05002352595329285, 'eval_precision': 0.9219917012448133, 'eval_recall': 0.9348704140020195, 'eval_f1': 0.9283863959221191, 'eval_accuracy': 0.9862154900510105, 'eval_runtime': 2.987, 'eval_samples_per_second': 1088.031, 'eval_steps_per_second': 17.074, 'epoch': 1.03}
{'loss': 0.0358, 'grad_norm': 1.490326166152954, 'learning_rate': 1.3426458504519311e-05, 'epoch': 1.14}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.05245550721883774, 'eval_precision': 0.9175223583968202, 'eval_recall': 0.9323460114439582, 'eval_f1': 0.9248747913188649, 'eval_accuracy': 0.9860792025232662, 'eval_runtime': 2.9815, 'eval_samples_per_second': 1090.073, 'eval_steps_per_second': 17.106, 'epoch': 1.14}
{'loss': 0.0431, 'grad_norm': 1.1926366090774536, 'learning_rate': 1.2604765817584223e-05, 'epoch': 1.25}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04916873946785927, 'eval_precision': 0.9272273105745212, 'eval_recall': 0.9370582295523393, 'eval_f1': 0.932116849418264, 'eval_accuracy': 0.9867995794556287, 'eval_runtime': 3.1028, 'eval_samples_per_second': 1047.437, 'eval_steps_per_second': 16.437, 'epoch': 1.25}
{'loss': 0.0379, 'grad_norm': 2.537076234817505, 'learning_rate': 1.1783073130649139e-05, 'epoch': 1.37}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.0478006973862648, 'eval_precision': 0.9202433810228581, 'eval_recall': 0.9417704476607203, 'eval_f1': 0.9308824752557597, 'eval_accuracy': 0.9871110938047584, 'eval_runtime': 3.009, 'eval_samples_per_second': 1080.096, 'eval_steps_per_second': 16.949, 'epoch': 1.37}
{'loss': 0.0386, 'grad_norm': 0.8048515319824219, 'learning_rate': 1.0961380443714052e-05, 'epoch': 1.48}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.0464089997112751, 'eval_precision': 0.927668772863319, 'eval_recall': 0.9389094580949175, 'eval_f1': 0.9332552693208431, 'eval_accuracy': 0.9868969276897317, 'eval_runtime': 2.9949, 'eval_samples_per_second': 1085.188, 'eval_steps_per_second': 17.029, 'epoch': 1.48}
{'loss': 0.0424, 'grad_norm': 0.9645754098892212, 'learning_rate': 1.0139687756778966e-05, 'epoch': 1.59}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04622066393494606, 'eval_precision': 0.9256376283537595, 'eval_recall': 0.9405923931336251, 'eval_f1': 0.9330550918196995, 'eval_accuracy': 0.9869942759238347, 'eval_runtime': 2.9843, 'eval_samples_per_second': 1089.025, 'eval_steps_per_second': 17.089, 'epoch': 1.59}
{'loss': 0.0347, 'grad_norm': 1.3958693742752075, 'learning_rate': 9.31799506984388e-06, 'epoch': 1.71}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04604792594909668, 'eval_precision': 0.929865381419312, 'eval_recall': 0.9416021541568496, 'eval_f1': 0.9356969646291495, 'eval_accuracy': 0.9875978349752735, 'eval_runtime': 3.0024, 'eval_samples_per_second': 1082.449, 'eval_steps_per_second': 16.986, 'epoch': 1.71}
{'loss': 0.036, 'grad_norm': 1.7382843494415283, 'learning_rate': 8.496302382908793e-06, 'epoch': 1.82}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04637272655963898, 'eval_precision': 0.9260912698412699, 'eval_recall': 0.942611915180074, 'eval_f1': 0.934278565471226, 'eval_accuracy': 0.9875588956816324, 'eval_runtime': 3.2996, 'eval_samples_per_second': 984.957, 'eval_steps_per_second': 15.456, 'epoch': 1.82}
{'loss': 0.0389, 'grad_norm': 0.9077371954917908, 'learning_rate': 7.674609695973705e-06, 'epoch': 1.94}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.045242197811603546, 'eval_precision': 0.9276391871799108, 'eval_recall': 0.9449680242342645, 'eval_f1': 0.9362234264276781, 'eval_accuracy': 0.987675713562556, 'eval_runtime': 3.225, 'eval_samples_per_second': 1007.746, 'eval_steps_per_second': 15.814, 'epoch': 1.94}
{'loss': 0.0291, 'grad_norm': 0.41188639402389526, 'learning_rate': 6.85291700903862e-06, 'epoch': 2.05}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04656890779733658, 'eval_precision': 0.9285242654341367, 'eval_recall': 0.946650959272972, 'eval_f1': 0.9375, 'eval_accuracy': 0.987675713562556, 'eval_runtime': 3.0822, 'eval_samples_per_second': 1054.43, 'eval_steps_per_second': 16.546, 'epoch': 2.05}
{'loss': 0.0237, 'grad_norm': 2.099060297012329, 'learning_rate': 6.0312243221035336e-06, 'epoch': 2.16}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04782210290431976, 'eval_precision': 0.929316338354577, 'eval_recall': 0.9447997307303938, 'eval_f1': 0.9369940749394976, 'eval_accuracy': 0.9879093493244032, 'eval_runtime': 3.0203, 'eval_samples_per_second': 1076.063, 'eval_steps_per_second': 16.886, 'epoch': 2.16}
{'loss': 0.0257, 'grad_norm': 1.496081829071045, 'learning_rate': 5.209531635168448e-06, 'epoch': 2.28}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.046769093722105026, 'eval_precision': 0.93082384018491, 'eval_recall': 0.9488387748232918, 'eval_f1': 0.939744978748229, 'eval_accuracy': 0.9880066975585063, 'eval_runtime': 3.0429, 'eval_samples_per_second': 1068.075, 'eval_steps_per_second': 16.761, 'epoch': 2.28}
{'loss': 0.0236, 'grad_norm': 1.307188630104065, 'learning_rate': 4.387838948233361e-06, 'epoch': 2.39}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.0468832291662693, 'eval_precision': 0.9318859794497846, 'eval_recall': 0.9463143722652305, 'eval_f1': 0.9390447561790247, 'eval_accuracy': 0.9881235154394299, 'eval_runtime': 3.1468, 'eval_samples_per_second': 1032.804, 'eval_steps_per_second': 16.207, 'epoch': 2.39}
{'loss': 0.0229, 'grad_norm': 0.8871119022369385, 'learning_rate': 3.566146261298275e-06, 'epoch': 2.51}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04612216353416443, 'eval_precision': 0.9355322338830585, 'eval_recall': 0.9451363177381353, 'eval_f1': 0.9403097530347426, 'eval_accuracy': 0.9885323780226627, 'eval_runtime': 3.0333, 'eval_samples_per_second': 1071.439, 'eval_steps_per_second': 16.813, 'epoch': 2.51}
{'loss': 0.0224, 'grad_norm': 0.594257116317749, 'learning_rate': 2.7444535743631883e-06, 'epoch': 2.62}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.045229483395814896, 'eval_precision': 0.9325600662800332, 'eval_recall': 0.9471558397845843, 'eval_f1': 0.939801285797779, 'eval_accuracy': 0.9882598029671742, 'eval_runtime': 3.0877, 'eval_samples_per_second': 1052.558, 'eval_steps_per_second': 16.517, 'epoch': 2.62}
{'loss': 0.0188, 'grad_norm': 0.45527184009552, 'learning_rate': 1.922760887428102e-06, 'epoch': 2.73}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04545798897743225, 'eval_precision': 0.9330238726790451, 'eval_recall': 0.9471558397845843, 'eval_f1': 0.9400367462836143, 'eval_accuracy': 0.9882403333203535, 'eval_runtime': 3.107, 'eval_samples_per_second': 1046.011, 'eval_steps_per_second': 16.414, 'epoch': 2.73}
{'loss': 0.0269, 'grad_norm': 1.5482826232910156, 'learning_rate': 1.1010682004930157e-06, 'epoch': 2.85}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04520042613148689, 'eval_precision': 0.9315362989912354, 'eval_recall': 0.947997307303938, 'eval_f1': 0.9396947201601468, 'eval_accuracy': 0.9882403333203535, 'eval_runtime': 3.0305, 'eval_samples_per_second': 1072.443, 'eval_steps_per_second': 16.829, 'epoch': 2.85}
{'loss': 0.0208, 'grad_norm': 0.2904956042766571, 'learning_rate': 2.7937551355792936e-07, 'epoch': 2.96}


  0%|          | 0/51 [00:00<?, ?it/s]

{'eval_loss': 0.04488206282258034, 'eval_precision': 0.9347717842323652, 'eval_recall': 0.9478290138000673, 'eval_f1': 0.9412551182418317, 'eval_accuracy': 0.9885518476694832, 'eval_runtime': 3.0067, 'eval_samples_per_second': 1080.926, 'eval_steps_per_second': 16.962, 'epoch': 2.96}
{'train_runtime': 183.21, 'train_samples_per_second': 229.917, 'train_steps_per_second': 7.188, 'train_loss': 0.04753323887095093, 'epoch': 3.0}


TrainOutput(global_step=1317, training_loss=0.04753323887095093, metrics={'train_runtime': 183.21, 'train_samples_per_second': 229.917, 'train_steps_per_second': 7.188, 'train_loss': 0.04753323887095093, 'epoch': 3.0})

# Проверка

In [74]:
q=[trainer.evaluate(eval_dataset = tokenized_datasets[data]) for data in ['train', 'validation', 'test']]
pd.DataFrame(q, index=["train","val", 'test']).iloc[:,:5]

  0%|          | 0/220 [00:00<?, ?it/s]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|          | 0/54 [00:00<?, ?it/s]

Unnamed: 0,eval_loss,eval_precision,eval_recall,eval_f1,eval_accuracy
train,0.016475,0.972156,0.977659,0.9749,0.995747
val,0.046769,0.930824,0.948839,0.939745,0.988007
test,0.101807,0.885808,0.910588,0.898027,0.978723


In [75]:
text = 'Moscow never sleeps. But Dima sleeps'

from transformers import pipeline
nlp = pipeline("ner", model = model, tokenizer=tokenizer)

ner_results = nlp(text)
print(ner_results)

[{'entity': 'LABEL_5', 'score': 0.99672866, 'index': 1, 'word': 'moscow', 'start': 0, 'end': 6}, {'entity': 'LABEL_0', 'score': 0.99969566, 'index': 2, 'word': 'never', 'start': 7, 'end': 12}, {'entity': 'LABEL_0', 'score': 0.9996524, 'index': 3, 'word': 'sleeps', 'start': 13, 'end': 19}, {'entity': 'LABEL_0', 'score': 0.99969065, 'index': 4, 'word': '.', 'start': 19, 'end': 20}, {'entity': 'LABEL_0', 'score': 0.9996772, 'index': 5, 'word': 'but', 'start': 21, 'end': 24}, {'entity': 'LABEL_1', 'score': 0.997707, 'index': 6, 'word': 'dim', 'start': 25, 'end': 28}, {'entity': 'LABEL_1', 'score': 0.9791077, 'index': 7, 'word': '##a', 'start': 28, 'end': 29}, {'entity': 'LABEL_0', 'score': 0.99958986, 'index': 8, 'word': 'sleeps', 'start': 30, 'end': 36}]


In [89]:
print(ner_results[0])
print(ner_results[5])
print(label_list)

{'entity': 'LABEL_5', 'score': 0.99672866, 'index': 1, 'word': 'moscow', 'start': 0, 'end': 6}
{'entity': 'LABEL_1', 'score': 0.997707, 'index': 6, 'word': 'dim', 'start': 25, 'end': 28}
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
