# Fine-tune BERT for tissue and cell type NER
Fine tune existing bioNER models for tissue and cell type identification in abstracts. Find optimal parameters and save the final model. 

In [11]:
from datasets import Dataset, Features, Sequence, Value, ClassLabel
from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer
import evaluate
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import precision_recall_fscore_support
import random
from iob_functions import *

random.seed(6002)

In [12]:
BASE_DIR = "../data/"

training_f = process_tab_delim_iob(BASE_DIR + 'tags/fulltext_iob/fulltext_tissues_train.iob')
training_a = process_tab_delim_iob(BASE_DIR + 'tags/abstract_iob/abstract_tissues_train.iob')
training = {'sentences': training_f['sentences'] + training_a['sentences'], 'tags': training_f['tags'] + training_a['tags']}

valid_f = process_tab_delim_iob(BASE_DIR + 'tags/fulltext_iob/fulltext_tissues_validation.iob')
valid_a = process_tab_delim_iob(BASE_DIR + 'tags/abstract_iob/abstract_tissues_validation.iob')
validation = {'sentences': valid_f['sentences'] + valid_a['sentences'], 'tags': valid_f['tags'] + valid_a['tags']}

test_f = process_tab_delim_iob(BASE_DIR + 'tags/fulltext_iob/fulltext_tissues_test.iob')
test_a = process_tab_delim_iob(BASE_DIR + 'tags/abstract_iob/abstract_tissues_test.iob')
test = {'sentences': test_f['sentences'] + test_a['sentences'], 'tags': test_f['tags'] + test_a['tags']}

In [13]:
features = Features({"tokens": Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
                     "tags": Sequence(feature=ClassLabel(names=["O", "B-CELL_TYPE", "I-CELL_TYPE", "B-TISSUE", "I-TISSUE"]))})

In [14]:
tag_stats(validation)

{'O': 80750,
 'B-CELL_TYPE': 339,
 'I-CELL_TYPE': 213,
 'B-TISSUE': 415,
 'I-TISSUE': 88}

In [15]:
tag_stats(training)

{'O': 484660,
 'B-CELL_TYPE': 3611,
 'I-CELL_TYPE': 2654,
 'B-TISSUE': 3369,
 'I-TISSUE': 559}

In [16]:
tag_stats(test)

{'O': 134071,
 'B-CELL_TYPE': 1383,
 'B-TISSUE': 985,
 'I-CELL_TYPE': 917,
 'I-TISSUE': 160}

In [17]:
training_ds = Dataset.from_dict({"tokens": training['sentences'], "tags": training['tags']}, features=features)
validation_ds = Dataset.from_dict({"tokens": validation['sentences'], "tags": validation['tags']}, features=features)
test_ds = Dataset.from_dict({"tokens": test['sentences'], "tags": test['tags']}, features=features)

In [18]:
all_tags = training_ds.features["tags"].feature
tag_list = training_ds.features["tags"].feature.names
id2tag = {idx: tag for idx, tag in enumerate(all_tags.names)}
tag2id = {tag: idx for idx, tag in enumerate(all_tags.names)}

In [19]:
# fine-tune the best performing model
m = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'
MAX_LENGTH = 256

In [20]:
# from https://huggingface.co/docs/transformers/tasks/token_classification
def tokenize_and_align_labels(data, tknzr, max_length=50):
    tokenized_inputs = tknzr(data['tokens'], truncation=True, is_split_into_words=True, max_length=max_length)

    labels = []
    for i, label in enumerate(data['tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [21]:
from sklearn.metrics import precision_recall_fscore_support

seqeval = evaluate.load("seqeval")

def flatten(l):
    return [item for sublist in l for item in sublist]

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [tag_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [tag_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)

    f1res = precision_recall_fscore_support(flatten(true_labels), flatten(true_predictions), labels=all_tags.names)

    df = list(zip(all_tags.names, f1res[2], f1res[0], f1res[1]))
    df = pd.DataFrame(df, columns = ['Level', 'F1-Score', 'Precision', 'Recall'])   
    print(df)

    return results

In [22]:
# Model Training

tokenizer = AutoTokenizer.from_pretrained(m)
# dynamically pad sentences to longest length in batch for efficiency
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

train_tokenized = training_ds.map(tokenize_and_align_labels, batched=True, fn_kwargs={'tknzr': tokenizer, 'max_length': MAX_LENGTH})
val_tokenized = validation_ds.map(tokenize_and_align_labels, batched=True, fn_kwargs={'tknzr': tokenizer, 'max_length': MAX_LENGTH})
test_tokenized = test_ds.map(tokenize_and_align_labels, batched=True, fn_kwargs={'tknzr': tokenizer, 'max_length': MAX_LENGTH})

model = AutoModelForTokenClassification.from_pretrained(
    m, num_labels=5, id2label=id2tag, label2id=tag2id
)

training_args = TrainingArguments(
    output_dir="model/" + m,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

test_preds = trainer.predict(test_tokenized)

Map:   0%|          | 0/20596 [00:00<?, ? examples/s]

Map:   0%|          | 0/3499 [00:00<?, ? examples/s]

Map:   0%|          | 0/5753 [00:00<?, ? examples/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Cell Type,Tissue,Overall Precision,Overall Recall,Overall F1,Overall Accuracy
1,0.0148,0.0095,"{'precision': 0.7203389830508474, 'recall': 0.7434402332361516, 'f1': 0.7317073170731706, 'number': 343}","{'precision': 0.7926565874730022, 'recall': 0.8822115384615384, 'f1': 0.8350398179749716, 'number': 416}",0.761322,0.819499,0.78934,0.996135
2,0.0093,0.010256,"{'precision': 0.736, 'recall': 0.8046647230320699, 'f1': 0.7688022284122563, 'number': 343}","{'precision': 0.8108108108108109, 'recall': 0.8653846153846154, 'f1': 0.8372093023255813, 'number': 416}",0.776557,0.837945,0.806084,0.996563
3,0.0061,0.011933,"{'precision': 0.7225433526011561, 'recall': 0.7288629737609329, 'f1': 0.725689404934688, 'number': 343}","{'precision': 0.8310185185185185, 'recall': 0.8629807692307693, 'f1': 0.8466981132075472, 'number': 416}",0.782776,0.802372,0.792453,0.996441
4,0.0026,0.013901,"{'precision': 0.7207446808510638, 'recall': 0.7900874635568513, 'f1': 0.7538247566063978, 'number': 343}","{'precision': 0.8356164383561644, 'recall': 0.8798076923076923, 'f1': 0.8571428571428572, 'number': 416}",0.782555,0.839262,0.809917,0.996625
5,0.001,0.016297,"{'precision': 0.7384196185286104, 'recall': 0.7900874635568513, 'f1': 0.7633802816901408, 'number': 343}","{'precision': 0.8048245614035088, 'recall': 0.8822115384615384, 'f1': 0.8417431192660549, 'number': 416}",0.775213,0.84058,0.806574,0.996539


         Level  F1-Score  Precision    Recall
0            O  0.998407   0.998674  0.998142
1  B-CELL_TYPE  0.765832   0.764706  0.766962
2  I-CELL_TYPE  0.815348   0.833333  0.798122
3     B-TISSUE  0.873563   0.835165  0.915663
4     I-TISSUE  0.844920   0.797980  0.897727




         Level  F1-Score  Precision    Recall
0            O  0.998637   0.998909  0.998365
1  B-CELL_TYPE  0.795455   0.767123  0.825959
2  I-CELL_TYPE  0.839161   0.833333  0.845070
3     B-TISSUE  0.879813   0.852941  0.908434
4     I-TISSUE  0.841463   0.907895  0.784091




         Level  F1-Score  Precision    Recall
0            O  0.998631   0.998514  0.998749
1  B-CELL_TYPE  0.775811   0.775811  0.775811
2  I-CELL_TYPE  0.802083   0.900585  0.723005
3     B-TISSUE  0.880473   0.865116  0.896386
4     I-TISSUE  0.826087   0.791667  0.863636




         Level  F1-Score  Precision    Recall
0            O  0.998680   0.998983  0.998377
1  B-CELL_TYPE  0.796610   0.764228  0.831858
2  I-CELL_TYPE  0.831354   0.841346  0.821596
3     B-TISSUE  0.883117   0.865741  0.901205
4     I-TISSUE  0.863388   0.831579  0.897727




         Level  F1-Score  Precision    Recall
0            O  0.998661   0.998996  0.998327
1  B-CELL_TYPE  0.791966   0.770950  0.814159
2  I-CELL_TYPE  0.839329   0.857843  0.821596
3     B-TISSUE  0.870968   0.834437  0.910843
4     I-TISSUE  0.857143   0.829787  0.886364




         Level  F1-Score  Precision    Recall
0            O  0.997209   0.998533  0.995888
1  B-CELL_TYPE  0.833112   0.770603  0.906657
2  I-CELL_TYPE  0.881210   0.871795  0.890830
3     B-TISSUE  0.886905   0.867119  0.907614
4     I-TISSUE  0.772603   0.687805  0.881250


In [23]:
test_preds.metrics

{'test_loss': 0.017371343448758125,
 'test_CELL_TYPE': {'precision': 0.7363034316676701,
  'recall': 0.8811239193083573,
  'f1': 0.8022302394227615,
  'number': 1388},
 'test_TISSUE': {'precision': 0.8164435946462715,
  'recall': 0.867005076142132,
  'f1': 0.8409650418513048,
  'number': 985},
 'test_overall_precision': 0.7672700406353897,
 'test_overall_recall': 0.8752633796881585,
 'test_overall_f1': 0.8177165354330709,
 'test_overall_accuracy': 0.9935248239334148,
 'test_runtime': 22.8285,
 'test_samples_per_second': 252.009,
 'test_steps_per_second': 31.539}

In [24]:
# save the final model
import os
if not os.path.exists("../models"):
    os.makedirs("../models")

trainer.save_model("../models/BiomedNLP-PubMedBERT-base-uncased-abstract-fine-tuned-for-tissue-celltype")

### test cases
try some test cases to see how the model performs in practice.

In [25]:
from transformers import pipeline

tissue_cell_classifier = pipeline(
    "token-classification", model="../models/BiomedNLP-PubMedBERT-base-uncased-abstract-fine-tuned-for-tissue-celltype", aggregation_strategy="simple"
)
tissue_cell_classifier("Heart, brain, and lung are all samples that we acquired for this analysis.")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[{'entity_group': 'TISSUE',
  'score': 0.98907346,
  'word': 'heart',
  'start': 0,
  'end': 5},
 {'entity_group': 'TISSUE',
  'score': 0.9912765,
  'word': 'brain',
  'start': 7,
  'end': 12},
 {'entity_group': 'TISSUE',
  'score': 0.9949338,
  'word': 'lung',
  'start': 18,
  'end': 22}]

In [26]:
tissue_cell_classifier("Renal cell carcinoma and breast cancer have relatively good prognosis.")

[{'entity_group': 'CELL_TYPE',
  'score': 0.88980865,
  'word': 'renal cell',
  'start': 0,
  'end': 10},
 {'entity_group': 'TISSUE',
  'score': 0.99476445,
  'word': 'breast',
  'start': 25,
  'end': 31}]

In [27]:
tissue_cell_classifier("Realistic scRNA-seq Generation with Automatic Cell-Type identification using Introspective Variational Autoencoders.")

[]

In [28]:
tissue_cell_classifier("We took blood from the left arm.")

[{'entity_group': 'TISSUE',
  'score': 0.9656411,
  'word': 'blood',
  'start': 8,
  'end': 13},
 {'entity_group': 'TISSUE',
  'score': 0.77455556,
  'word': 'arm',
  'start': 28,
  'end': 31}]

In [29]:
tissue_cell_classifier("This arm of the study included 1000 participants.")

[]