In [1]:
import os
import json

from datasets import load_from_disk, Dataset
from transformers import AutoModelForTokenClassification
from transformers import pipeline
from nervaluate import Evaluator

from src.controllers import Controller
from src.NER.bert.bert_hf import BertTokenClassifierHF
from src.tools.general_tools import load_pickled_data
from src.tools.general_tools import get_filepath, get_folder_path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# model_dir = "../results/evaluation/bert/checkpoint-417"
# assert os.path.isdir(model_dir), f'Model not found at {model_dir}'

In [3]:
# All BERT dataset artefacts (the HF dataset plus two pickles) sit in one folder,
# so the folder is spelled out once and the file names are appended to it.
BERT_DATASET_DIR = "../results/dataset/bert"

dataset = load_from_disk(BERT_DATASET_DIR)
# NOTE(review): the .pkl files are presumably unpickled by load_pickled_data;
# pickle can execute arbitrary code, so only load files produced by this project.
label_names = list(load_pickled_data(f"{BERT_DATASET_DIR}/labels.pkl"))
eval_original_texts = load_pickled_data(f"{BERT_DATASET_DIR}/eval_original_text.pkl")
# Bare last expression: shows the DatasetDict summary as the cell output.
dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1670
    })
    eval: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 186
    })
    all: Dataset({
        features: ['ner_tags', 'original_text', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1856
    })
})

In [4]:
c = Controller('bert')

# Resolve both working directories first so the constructor call below stays short.
# The evaluation folder is the controller's base path with a "-3epochs" suffix.
data_dir: str = get_folder_path(c._dataset_base_path)
eval_dir = get_folder_path(c._evaluation_base_path + "-3epochs")

# Inference in this notebook is done on the CPU.
bert_model = BertTokenClassifierHF(
    dataset_base_path=data_dir,
    eval_base_path=eval_dir,
    device='cpu',
)
print(bert_model.output_dir)
# Switch the wrapper to its fine-tuned weights; in the recorded run this logged
# "Using latest checkpoint: checkpoint-834."
bert_model._use_finetuned_model()

2023-03-29 18:37:53.873 | INFO     | src.NER.bert.bert_hf:_get_latest_ckpt_path:235 - Using latest checkpoint: checkpoint-834.


/home/geoph/Repos/AgroknowNER/foodner/results/evaluation/bert-3epochs


## Use our evaluate method

In [6]:
# Run the project's own evaluation and unpack the overall and per-tag results.
# NOTE(review): the recorded run of this cell failed with
# "TypeError: forward() got an unexpected keyword argument 'offset_mapping'",
# so neither variable was assigned in that session. Presumably an
# `offset_mapping` entry reaches the model's forward() -- verify inside evaluate().
results, results_per_tag = bert_model.evaluate()

TypeError: forward() got an unexpected keyword argument 'offset_mapping'

## With best model loading

## Other Approaches and Tests

In [7]:
# Build a Hugging Face "ner" pipeline directly from a checkpoint directory.
# NOTE(review): this path is under results/evaluation/bert, whereas bert_model was
# configured with the "-3epochs" evaluation folder -- confirm both point at the
# checkpoint that is meant to be compared.
classifier = pipeline("ner", model="../results/evaluation/bert/checkpoint-834/")

In [8]:
# Run the pipeline on the first evaluation text and print the first three
# predicted entities followed by the raw text they were extracted from.
out = classifier(eval_original_texts[0])
print(out[:3], eval_original_texts[0])

[{'entity': 'B-SKIP', 'score': 0.63059723, 'index': 27, 'word': 'AN', 'start': 65, 'end': 67}, {'entity': 'I-SKIP', 'score': 0.75765026, 'index': 30, 'word': 'DR', 'start': 72, 'end': 74}, {'entity': 'B-SKIP', 'score': 0.80880636, 'index': 34, 'word': 'F', 'start': 81, 'end': 82}] Title 21: Food and Drugs PART 556-TOLERANCES FOR RESIDUES OF NEW ANIMAL DRUGS IN FOOD Subpart B-Specific Tolerances for Residues of New Animal Drugs $556.513 Piperazine. A tolerance of 0.1 part per million piperazine base is established for edible tissues of poultry and swine. [64 FR 23019, Apr. 29, 1999]


In [7]:
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification


def predict2(text: str, model_path: str, max_preview: int = 30):
    """Tag ``text`` with the fine-tuned model held by the global ``bert_model``.

    Prints a preview of ``(token, predicted_label)`` pairs (special tokens
    excluded) and returns the tokenizer and model that were used.

    Args:
        text: Raw text to tokenize and classify.
        model_path: Currently ignored. The tokenizer and model always come
            from ``bert_model``; the parameter is kept so existing calls keep
            working.
        max_preview: Number of leading ``(token, label)`` pairs to print.
            ``None`` prints all pairs. Defaults to 30.

    Returns:
        A ``(tokenizer, model)`` tuple taken from ``bert_model``.
    """
    tokenizer = bert_model.tokenizer
    # Truncate to the tokenizer's maximum length so over-long texts do not crash
    # inside the model; texts within the limit are encoded exactly as before.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)

    # Make sure the wrapper exposes the fine-tuned weights before reading .model.
    bert_model._use_finetuned_model()
    model = bert_model.model
    # no_grad() only disables gradient tracking; eval() is what disables dropout,
    # which is needed for deterministic predictions.
    model.eval()

    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=2)
    predicted_token_class = [model.config.id2label[class_id] for class_id in predictions[0].tolist()]

    # Recover the tokens from the very ids that were fed to the model, so tokens
    # and predictions are aligned by construction (also when truncation kicked in),
    # instead of tokenizing the text a second time.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    assert len(tokens) == len(predicted_token_class), (
        f"{len(tokens)} tokens != {len(predicted_token_class)} predictions"
    )

    # The first and last positions are the special tokens added around the
    # sequence; they are dropped from the preview.
    print(list(zip(tokens[1:-1], predicted_token_class[1:-1]))[:max_preview])
    return tokenizer, model

# Candidate checkpoint directories.
# NOTE(review): `model_dir_200` is not used in the visible cells and its name does
# not match the "checkpoint-50" path -- confirm which checkpoint was intended.
model_dir_200 = "../results/evaluation/bert/checkpoint-50"
model_dir_ALL = "../results/evaluation/bert/checkpoint-834"
# predict2 takes its tokenizer and model from bert_model and does not read its
# model_path argument, so the path passed here does not select the checkpoint.
tok, model = predict2(eval_original_texts[0], model_dir_ALL)

2023-03-23 15:59:34.374 | INFO     | src.NER.bert.bert_hf:_get_latest_ckpt_path:99 - Using latest checkpoint: checkpoint-834.


[('Title', 'O'), ('21', 'O'), (':', 'O'), ('Food', 'O'), ('and', 'O'), ('Drugs', 'O'), ('PA', 'O'), ('##RT', 'O'), ('55', 'O'), ('##6', 'O'), ('-', 'O'), ('TO', 'O'), ('##LE', 'O'), ('##RA', 'O'), ('##NC', 'O'), ('##ES', 'O'), ('F', 'O'), ('##OR', 'O'), ('R', 'O'), ('##ES', 'O'), ('##ID', 'O'), ('##UE', 'O'), ('##S', 'O'), ('OF', 'O'), ('NE', 'O'), ('##W', 'O'), ('AN', 'B-SKIP'), ('##IM', 'O'), ('##AL', 'O'), ('DR', 'I-SKIP')]


In [9]:
# Predicted label sequence for the first evaluation text, obtained through the
# project wrapper rather than the manual code path.
token_classes = bert_model.predict_token_classes(eval_original_texts[0])
print(token_classes)

2023-03-23 15:59:55.030 | INFO     | src.NER.bert.bert_hf:_get_latest_ckpt_path:99 - Using latest checkpoint: checkpoint-834.


['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-SKIP', 'O', 'O', 'I-SKIP', 'O', 'O', 'O', 'B-SKIP', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-SKIP', 'I-SKIP', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Substance', 'I-Substance', 'I-Substance', 'O', 'O', 'O', 'O', 'B-Value', 'O', 'I-Value', 'B-Unit', 'I-Unit', 'I-Unit', 'B-Substance', 'I-Substance', 'I-Substance', 'I-Substance', 'O', 'O', 'O', 'B-Usage', 'I-Usage', 'I-Usage', 'I-Usage', 'I-Usage', 'I-Usage', 'I-Usage', 'I-Usage', 'I-Usage', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [24]:
# Turn every evaluation row of label ids into a row of tag strings.
# Ids that have no entry in id2label fall back to 'O'
# (presumably the ignore index used for special/padded positions -- TODO confirm).
eval_label_rows = bert_model.dataset['eval']['labels']
id_to_tag = bert_model.id2label
ground_truths = [
    [id_to_tag.get(label_id, 'O') for label_id in row_label_ids]
    for row_label_ids in eval_label_rows
]

In [31]:
# Side-by-side view of reference tags and predicted tags.
# NOTE(review): the reference row is index 41 while token_classes was predicted
# for eval_original_texts[0], and the recorded lengths differ (371 vs 107);
# zip() stops at the shorter sequence, so confirm the two are the same sample.
reference_tags = ground_truths[41]
print(len(reference_tags), len(token_classes))
paired_tags = list(zip(reference_tags, token_classes))
print(paired_tags)

371 107
[('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'B-SKIP'), ('O', 'O'), ('O', 'O'), ('O', 'I-SKIP'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'B-SKIP'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'B-SKIP'), ('O', 'I-SKIP'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'B-Substance'), ('O', 'I-Substance'), ('O', 'I-Substance'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'O'), ('O', 'B-Value'), ('O', 'O'), ('B-Substance', 'I-Value'), ('O', 'B-Unit'), ('O', 'I-Unit'), ('I-

In [21]:
# Inspect the raw encoding (input_ids, token_type_ids, ...) that the tokenizer
# returned by predict2 produces for the first evaluation text.
tok.batch_encode_plus([eval_original_texts[0]], return_tensors="pt")

{'input_ids': tensor([[  101, 11772,  1626,   131,  6702,  1105, 26500,  8544, 10460,  3731,
          1545,   118, 16972, 17516,  9664, 15517,  9919,   143,  9565,   155,
          9919,  9949, 24846,  1708, 11345, 26546,  2924, 23096, 13371, 12507,
         22219,  2591, 13472, 15969,   143,  2346, 15609, 12859, 17482,  1204,
           139,   118,   156, 27934,  1706,  2879,  3923,  1116,  1111, 11336,
          5053, 21405,  1116,  1104,  1203, 10854, 26500,   109,  3731,  1545,
           119,  4062,  1495, 12558, 19888,  1673,   119,   138, 15745,  1104,
           121,   119,   122,  1226,  1679,  1550,  9415, 15265,  2042,  2259,
          1110,  1628,  1111, 24525, 14749,  1104,   185,  6094, 21001,  1105,
           188, 17679,   119,   164,  3324,   143,  2069, 11866, 16382,   117,
         23844,   119,  1853,   117,  1729,   166,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [28]:
# Manual forward pass on the first evaluation text, keeping the intermediate
# tensors (inputs, logits, predictions) around for inspection in later cells.
inputs = tok(eval_original_texts[0], return_tensors="pt")
with torch.no_grad():
    model_output = model(**inputs)
    logits = model_output.logits
# Highest-scoring class id per token position (class scores are on dim 2).
predictions = logits.argmax(dim=2)
# Map the class ids of the single sequence in the batch to label names.
predicted_token_class = [
    model.config.id2label[class_id] for class_id in predictions[0].tolist()
]

In [37]:
# Sanity check: logits, input_ids and predictions should agree on the batch and
# sequence dimensions; logits carries one extra trailing dimension of class scores.
logits.shape, inputs['input_ids'].shape, predictions.shape

(torch.Size([1, 107, 19]), torch.Size([1, 107]), torch.Size([1, 107]))