# Bio_ClinicalBERT Finetuning

### Preprocessing the data
    - Read text and annotation files
    - Tokenize text
    - Align with annotations
    - Generate BIO tags (labels)

In [1]:
import os
from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore")


def parse_ann_file(ann_path):
    """
    Parse annotation files.
    """
    entities = []
    with open(ann_path, 'r') as file:
        for line in file:
            parts = line.strip().split('\t')
            if parts[0].startswith("T"):
                label_and_span, text = parts[1], parts[2]
                label, span = label_and_span.split(" ", 1)
                if label in {"SIGN", "SYMPTOM"}:
                    ranges = span.split(";")
                    spans = [(int(start), int(end)) for start, end in (r.split() for r in ranges)]
                    entities.append((label, spans, text))
    return entities

def align_tokens_and_labels(text, entities, tokenizer):
    """
    Align tokens with BIO labels.
    """
    text = text.strip()
    tokenized = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    tokens = tokenizer.tokenize(text, add_special_tokens=False)
    token_spans = tokenized["offset_mapping"]
    labels = ["O"] * len(tokens)

    for label, spans, entity_text in entities:
        for start, end in spans:
            for idx, (tok_start, tok_end) in enumerate(token_spans):
                if tok_start >= start and tok_end <= end:
                    if tok_start == start:
                        labels[idx] = "B-HPO"
                    else:
                        labels[idx] = "I-HPO"


    aligned_labels = []
    for token, label in zip(tokens, labels):
        subwords = tokenizer.tokenize(token)
        if len(subwords) == 1:
            aligned_labels.append(label)
        else:
            aligned_labels.append(label)
            aligned_labels.extend(["I-HPO" if label != "O" else "O"] * (len(subwords) - 1))

    return tokens, aligned_labels

def preprocess_data(folder_path, tokenizer):
    data = []
    for file in os.listdir(folder_path):
        if file.endswith(".txt"):
            txt_path = os.path.join(folder_path, file)
            ann_path = txt_path.replace(".txt", ".ann")

            with open(txt_path, 'r') as f:
                text = f.read()

            if os.path.exists(ann_path):
                entities = parse_ann_file(ann_path)
                tokens, labels = align_tokens_and_labels(text, entities, tokenizer)
                data.append((tokens, labels))
    return data


# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Preprocess data
train_data = preprocess_data("datasets/RareDis-v1/train", tokenizer)
dev_data = preprocess_data("datasets/RareDis-v1/dev", tokenizer)
test_data = preprocess_data("datasets/RareDis-v1/test", tokenizer)

### Custom Dataset

In [2]:
from torch.utils.data import Dataset
MAX_LEN = 512

class HPODataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.label_map = {"O": 0, "B-HPO": 1, "I-HPO": 2}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, labels = self.data[idx]
        encoded = self.tokenizer(tokens,
                                is_split_into_words=True,
                                padding="max_length",
                                truncation=True,
                                max_length=MAX_LEN,
                                return_tensors="pt",
                                )

        label_ids = [self.label_map[label] for label in labels]
        label_ids = label_ids[:MAX_LEN]
        label_ids += [0] * (MAX_LEN - len(label_ids))

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(label_ids, dtype=torch.long),
        }


train_dataset = HPODataset(train_data, tokenizer)
dev_dataset = HPODataset(dev_data, tokenizer)
test_dataset = HPODataset(test_data, tokenizer)

### Finetuning Step

In [5]:
import torch
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

model = AutoModelForTokenClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=len(train_dataset.label_map))
for param in model.parameters():
    if not param.is_contiguous():
        param.data = param.data.contiguous()

data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)

training_args = TrainingArguments(
    output_dir="./ner_model",
    logging_strategy='epoch',
    eval_strategy='epoch',
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)

trainer.train()

Some weights of BertForTokenClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,0.2835,0.171244
2,0.1485,0.13822
3,0.1041,0.139376


TrainOutput(global_step=138, training_loss=0.17871321802553924, metrics={'train_runtime': 50.4691, 'train_samples_per_second': 43.333, 'train_steps_per_second': 2.734, 'total_flos': 571461173480448.0, 'train_loss': 0.17871321802553924, 'epoch': 3.0})

In [6]:
# Evaluate and save
metrics = trainer.evaluate(test_dataset)
print(metrics)

trainer.save_model("./saved_hpo_bert")
tokenizer.save_pretrained("./saved_hpo_bert")

{'eval_loss': 0.14030082523822784, 'eval_runtime': 1.689, 'eval_samples_per_second': 123.146, 'eval_steps_per_second': 7.697, 'epoch': 3.0}


('./saved_hpo_bert/tokenizer_config.json',
 './saved_hpo_bert/special_tokens_map.json',
 './saved_hpo_bert/vocab.txt',
 './saved_hpo_bert/added_tokens.json',
 './saved_hpo_bert/tokenizer.json')

In [14]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from torch.nn.functional import softmax
import torch

model_path = './saved_hpo_bert'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForTokenClassification.from_pretrained(model_path)
model.eval()
labels = {0: "O", 1: "B-HPO", 2:"I-HPO"}

def get_hpo_terms(text):
    """
    Extract HPO terms from text.
    
    Arg:
    - text (str): Input text for NER.
    
    Returns:
    - List of recognized HPO terms.
    """

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
    
    predictions = torch.argmax(logits, dim=-1)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    pred_labels = predictions[0].tolist()
    
    recognized_entities = []
    current_entity = []
    
    for token, label_id in zip(tokens, pred_labels):
        label = labels.get(label_id, "O")

        if label == "B-HPO":
            if current_entity:
                recognized_entities.append(" ".join(current_entity))
            current_entity = [token]
        elif label == "I-HPO":
            if current_entity:
                current_entity.append(token)
        else:
            if current_entity:
                recognized_entities.append(" ".join(current_entity))
                current_entity = []

    if current_entity:
        recognized_entities.append(" ".join(current_entity))
    
    hpo_terms = [" ".join(e.replace(" ##", "").replace("##", "") for e in entity.split()) for entity in recognized_entities]
    
    return hpo_terms

In [24]:
import re
text = """Acanthocheilonemiasis is a rare tropical infectious disease caused by a parasite known as Acanthocheilonema perstans, which belongs to a group of parasitic diseases known as filarial diseases (nematode). This parasite is found, for the most part, in Africa. Symptoms of infection may include red, itchy skin (pruritis), abdominal and chest pain, muscular pain (myalgia), and areas of localized swelling (edema). In addition, the liver and spleen may become abnormally enlarged (hepatosplenomegaly). Laboratory testing may also reveal abnormally elevated levels of certain specialized white blood cells (eosinophilia). The parasite is transmitted through the bite of small flies (A. coliroides). Acanthocheilonemiasis is a rare infectious disease caused by long “thread-like” worms, Acanthocheilonema perstans, also known as Dipetalonema perstans. The disease is transmitted by a small black insect (midge), called A. Cailicoides. Acanthocheilonema perstans, the parasite that causes Acanthocheilonemiasis is common in central Africa and in some areas of South America.  This disorder affects males and females in equal numbers. Acanthocheilonemiasis is treated by means of the administration of antifilarial drugs, some of which are newer than others. Ivermectin or diethyl-carbamazine (DEC) are frequently prescribed. Occasionally, surgery may be required to remove large adult worms. Mild cases of acanthocheilonemiasis do not require treatment.
"""
text = re.sub(r'[^a-zA-Z\s]', '', text.strip())
print(get_hpo_terms(text))

['include red it chy skin p', 'and chest pain muscular pain my', 'and', 'become abnormal']
