<a href="https://colab.research.google.com/github/sajjkavinda/biomedical-named-entity-recognition/blob/main/biomedical_named_entity_recognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Install dependencies
!pip install -q --upgrade transformers datasets==2.19.1 evaluate seqeval accelerate peft bioc

In [None]:
#Imports
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForTokenClassification
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers.training_args import IntervalStrategy
import evaluate
from peft import LoraConfig, get_peft_model, TaskType
from peft import PromptTuningConfig, PromptTuningInit
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import matplotlib.pyplot as plt

In [None]:
#Load BC5CDR dataset
# NOTE(review): trust_remote_code=True presumably lets the dataset repository's own
# loading code run locally - confirm it is still required and that the source is trusted.
dataset = load_dataset("tner/bc5cdr", trust_remote_code=True)

# Print the loaded dataset object, the train split's features and the first training example.
print(dataset, dataset["train"].features, dataset["train"][0], sep="\n")

In [None]:
#Label mapping
# The list index of each label must equal the integer tag id stored in tner/bc5cdr.
# Per the dataset card's label2id table the ids are:
#   O=0, B-Chemical=1, B-Disease=2, I-Disease=3, I-Chemical=4
# (I-Chemical comes last; it does not directly follow B-Chemical).
# NOTE(review): re-check against the dataset's label.json if the dataset revision changes.
label_list = [
    "O",
    "B-Chemical",
    "B-Disease",
    "I-Disease",
    "I-Chemical"
]

# Both directions of the mapping are handed to the model configs and used when scoring.
id2label = dict(enumerate(label_list))
label2id = {label: i for i, label in id2label.items()}

num_labels = len(label_list)

print(id2label)

In [None]:
#Tokenization
# Tokenizer loaded from the same checkpoint name that the model-loading cells below use.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")

#Align labels
def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and build one label per produced token.

    Args:
        examples: batch mapping with "tokens" (a list of word lists) and
            "tags" (a list of per-word integer tag lists).

    Returns:
        The tokenizer output with an added "labels" entry. For every example,
        the first token of each word carries that word's tag; tokens that map
        to no word, and the remaining pieces of a word, are set to -100.
    """
    encoded = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True
    )

    aligned = []
    for batch_idx, word_tags in enumerate(examples["tags"]):
        row = []
        last_word = None
        for current_word in encoded.word_ids(batch_index=batch_idx):
            # Only the first piece of a real word keeps the word's tag.
            starts_word = current_word is not None and current_word != last_word
            row.append(word_tags[current_word] if starts_word else -100)
            last_word = current_word
        aligned.append(row)

    encoded["labels"] = aligned
    return encoded

# Apply the alignment function in batches. The columns listed for the train split are
# removed from the result, leaving what tokenize_and_align_labels returns.
tokenized_ds = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=dataset["train"].column_names
)

# Collator used by the baseline Trainer. The name `data_collator` is assigned again
# in the prompt tuning cell further down.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
#Evaluation metrics
metric = evaluate.load("seqeval")

def compute_metrics(p):
    """Score predictions with the seqeval metric loaded above.

    Args:
        p: a (predictions, labels) pair. ``predictions`` holds per-token class
            scores with the class axis last; ``labels`` holds integer tag ids,
            where -100 marks positions to skip.

    Returns:
        dict with "precision", "recall", "f1" and "accuracy", taken from the
        metric's ``overall_*`` entries.
    """
    scores, gold_ids = p
    pred_ids = scores.argmax(axis=-1)

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(pred_ids, gold_ids):
        pred_tags = []
        gold_tags = []
        for pred_id, gold_id in zip(pred_row, gold_row):
            # Positions labelled -100 take no part in scoring.
            if gold_id == -100:
                continue
            pred_tags.append(id2label[pred_id])
            gold_tags.append(id2label[gold_id])
        true_predictions.append(pred_tags)
        true_labels.append(gold_tags)

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"]
    }

In [None]:
#Baseline model (Without fine-tuning)
# Token-classification model loaded from the BioBERT checkpoint and configured with the
# notebook's label mapping. Nothing is trained in this cell; the model is only evaluated.
# NOTE(review): the classification head is presumably newly initialised at load time,
# so these scores are a floor rather than a tuned result - confirm.
baseline_model = AutoModelForTokenClassification.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
)

#Training configurations
# Only evaluation settings are supplied because train() is never called on this Trainer.
training_args_baseline = TrainingArguments(
    output_dir="./baseline_results",
    per_device_eval_batch_size=8,
    report_to="none"
)

# No train_dataset is given; the test split is the evaluation set.
trainer_baseline = Trainer(
    model=baseline_model,
    args=training_args_baseline,
    eval_dataset=tokenized_ds["test"],
    compute_metrics=compute_metrics,
    data_collator=data_collator
)

# Kept in a variable because the comparison cell at the end reads baseline_results.
baseline_results = trainer_baseline.evaluate()
baseline_results

In [None]:
#Prompt tuning config
# The virtual tokens are initialised from a text prompt. The prompt names both entity
# types in the label set (Chemical and Disease) rather than diseases only.
# num_layers / token_dim / num_attention_heads are given explicitly.
# NOTE(review): these are presumably the BERT-base dimensions of the checkpoint - confirm.
prompt_config = PromptTuningConfig(
    task_type=TaskType.TOKEN_CLS,
    num_virtual_tokens=20,
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Recognize chemical and disease entities in biomedical text.",
    tokenizer_name_or_path="dmis-lab/biobert-base-cased-v1.1",
    num_layers=12,
    token_dim=768,
    num_attention_heads=12
)

#Custom data collator for prompt tuning
class DataCollatorForPromptTuning(DataCollatorForTokenClassification):
    """Token-classification collator that widens ``labels`` for the virtual prompt tokens.

    The parent collator builds the batch; this subclass then prepends
    ``num_virtual_tokens`` columns filled with -100 to ``batch["labels"]``.
    NOTE(review): presumably this keeps the label length equal to the model
    output length once the virtual tokens are prepended - confirm.

    Args:
        tokenizer: forwarded to the parent collator.
        num_virtual_tokens: number of -100 columns to prepend to the labels.
        **kwargs: forwarded unchanged to the parent collator.
    """

    def __init__(self, tokenizer, num_virtual_tokens, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.num_virtual_tokens = num_virtual_tokens

    def __call__(self, features):
        """Collate ``features`` with the parent, then left-extend the labels with -100."""
        batch = super().__call__(features)
        token_labels = batch["labels"]
        # new_full creates the filler with the same dtype and device as the labels.
        prompt_labels = token_labels.new_full(
            (token_labels.size(0), self.num_virtual_tokens), -100
        )
        batch["labels"] = torch.cat((prompt_labels, token_labels), dim=1)
        return batch

# NOTE: this rebinds `data_collator`, which the baseline cell also reads. Re-running the
# baseline cell after this one would pick up the prompt-tuning collator.
data_collator = DataCollatorForPromptTuning(
    tokenizer=tokenizer,
    num_virtual_tokens=prompt_config.num_virtual_tokens,
    label_pad_token_id=-100
)

#Apply prompt tuning
# Wrap a freshly loaded checkpoint instead of `baseline_model`, as the LoRA cell does.
# The baseline object is then left untouched, and re-running this cell never wraps a
# model that has already been wrapped.
prompt_base_model = AutoModelForTokenClassification.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
)
model = get_peft_model(prompt_base_model, prompt_config)
model.print_trainable_parameters()

#Training arguments
# NOTE(review): the output directory name says "ncbi" although the dataset loaded in
# this notebook is tner/bc5cdr. It is kept as-is so existing checkpoints stay where they are.
training_args = TrainingArguments(
    output_dir="./ncbi_prompt_tuning",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=50,
    save_strategy="epoch",
    eval_strategy="epoch",
    report_to="none"
)

#Evaluation metrics
def compute_metrics(pred):
    """Entity-level seqeval metrics for the PEFT runs, comparable with the baseline.

    The baseline Trainer is scored with seqeval. This function uses the same
    metric and the same result keys, so the values compared at the end of the
    notebook measure the same quantity for every method. (A token-level macro
    average from scikit-learn is a different quantity.)

    Args:
        pred: object exposing ``predictions`` (per-token class scores, class
            axis last) and ``label_ids`` (integer tag ids, where -100 marks
            positions to ignore). Both must cover the same token positions.

    Returns:
        dict with "precision", "recall", "f1" and "accuracy". If no position
        carries a real label, returns 0.0 for precision/recall/f1 and 1.0 for
        accuracy.
    """
    label_ids = pred.label_ids
    pred_ids = pred.predictions.argmax(-1)

    true_predictions = []
    true_labels = []
    for pred_row, label_row in zip(pred_ids, label_ids):
        # Drop ignored positions (special tokens, sub-word pieces, padding, virtual tokens).
        kept = [(p, l) for p, l in zip(pred_row, label_row) if l != -100]
        true_predictions.append([id2label[int(p)] for p, _ in kept])
        true_labels.append([id2label[int(l)] for _, l in kept])

    # Nothing to score: return the fallback values without calling the metric.
    if not any(true_labels):
        return {
            "precision": 0.0,
            "recall": 0.0,
            "f1": 0.0,
            "accuracy": 1.0
        }

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"]
    }

# Trainer for the prompt-tuned model. The test split is the evaluation set both for the
# per-epoch evaluation requested in training_args and for the final evaluate() call.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()
# Kept in a variable because the comparison cell at the end reads prompt_results.
prompt_results = trainer.evaluate()
prompt_results

In [None]:
#LoRA implementation
# Adapter settings for token classification.
# NOTE(review): target_modules presumably selects the attention query and value
# projections of the BERT encoder - confirm against the model's module names.
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=16,
    lora_alpha=32,
    target_modules=["query", "value"],
    lora_dropout=0.05,
    bias="none"
)

# A separate copy of the checkpoint is loaded for LoRA, so this run does not share
# a model object with the other experiments.
lora_model = AutoModelForTokenClassification.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
)

# The name lora_model is rebound to the PEFT-wrapped model.
lora_model = get_peft_model(lora_model, lora_config)
lora_model.print_trainable_parameters()

# Plain token-classification collator under its own name (no virtual-token label padding).
data_collator_lora = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-100)

#LoRA training configurations
# Mixed precision is enabled only when a CUDA device is present, so the cell also runs
# on a CPU-only runtime instead of failing while the arguments are built.
training_args_lora = TrainingArguments(
    output_dir="./lora_results",
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    report_to="none",
    fp16=torch.cuda.is_available()
)

# Trainer for the LoRA model. `compute_metrics` resolves to whichever definition was
# executed most recently in the notebook. The test split is the evaluation set for both
# the per-epoch evaluation and the final evaluate() call.
trainer_lora = Trainer(
    model=lora_model,
    args=training_args_lora,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    data_collator=data_collator_lora,
    compute_metrics=compute_metrics
)

trainer_lora.train()
# Kept in a variable because the comparison cell at the end reads lora_results.
lora_results = trainer_lora.evaluate()
lora_results


In [None]:
#Scores to compare
# Only the bounded 0-1 metrics are plotted. The results dicts also carry entries such as
# eval_loss and timing values, which are not scores and do not fit the y-axis used below.
score_keys = ["eval_precision", "eval_recall", "eval_f1", "eval_accuracy"]
results_by_method = {
    "Baseline": baseline_results,
    "Prompt Tuning": prompt_results,
    "LoRA PEFT": lora_results,
}

#Keep only the metrics reported by every run so the lookups below cannot raise KeyError
metrics_list = [
    k for k in score_keys
    if all(k in results for results in results_by_method.values())
]

#Extract scores
baseline_scores = [baseline_results[m] for m in metrics_list]
prompt_scores   = [prompt_results[m] for m in metrics_list]
lora_scores     = [lora_results[m] for m in metrics_list]

#Clean metric names
metric_labels = [m.replace("eval_", "").capitalize() for m in metrics_list]

x = np.arange(len(metric_labels))
width = 0.25

#Display comparison
# One group of three bars per metric: baseline left, prompt tuning centre, LoRA right.
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(x - width, baseline_scores, width, label="Baseline")
ax.bar(x, prompt_scores, width, label="Prompt Tuning")
ax.bar(x + width, lora_scores, width, label="LoRA PEFT")
ax.set_xticks(x)
ax.set_xticklabels(metric_labels)
ax.set_ylabel("Score")
ax.set_ylim(0, 1.05)
ax.set_title("tner/bc5cdr NER: Baseline vs Prompt Tuning vs LoRA PEFT")
ax.legend()
ax.grid(axis="y", linestyle="--", alpha=0.6)
fig.tight_layout()
plt.show()
