In [22]:
import os
from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore")

MAX_LEN=512

def parse_ann_file(ann_path):
    """
    Parse annotation files.
    """
    entities = []
    with open(ann_path, 'r') as file:
        for line in file:
            parts = line.strip().split('\t')
            if parts[0].startswith("T"):
                label_and_span, text = parts[1], parts[2]
                label, span = label_and_span.split(" ", 1)
                if label in {"SIGN", "SYMPTOM"}:
                    ranges = span.split(";")
                    spans = [(int(start), int(end)) for start, end in (r.split() for r in ranges)]
                    entities.append((label, spans, text))
    return entities

def align_tokens_and_labels(text, entities, tokenizer):
    """
    Align tokens with BIO labels.
    """
    text = text.strip()
    tokenized = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    tokens = tokenizer.tokenize(text, add_special_tokens=False)
    token_spans = tokenized["offset_mapping"]
    labels = ["O"] * len(tokens)

    for label, spans, entity_text in entities:
        for start, end in spans:
            for idx, (tok_start, tok_end) in enumerate(token_spans):
                if tok_start >= start and tok_end <= end:
                    if tok_start == start:
                        labels[idx] = "B-HPO"
                    else:
                        labels[idx] = "I-HPO"


    aligned_labels = []
    for token, label in zip(tokens, labels):
        subwords = tokenizer.tokenize(token)
        if len(subwords) == 1:
            aligned_labels.append(label)
        else:
            aligned_labels.append(label)
            aligned_labels.extend(["I-HPO" if label != "O" else "O"] * (len(subwords) - 1))

    return tokens, aligned_labels

def preprocess_data(folder_path, tokenizer):
    data = []
    for file in os.listdir(folder_path):
        if file.endswith(".txt"):
            txt_path = os.path.join(folder_path, file)
            ann_path = txt_path.replace(".txt", ".ann")

            with open(txt_path, 'r') as f:
                text = f.read()

            if os.path.exists(ann_path):
                entities = parse_ann_file(ann_path)
                tokens, labels = align_tokens_and_labels(text, entities, tokenizer)
                data.append((tokens, labels))
    return data


# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Preprocess data
train_data = preprocess_data("datasets/RareDis-v1/train", tokenizer)
dev_data = preprocess_data("datasets/RareDis-v1/dev", tokenizer)
test_data = preprocess_data("datasets/RareDis-v1/test", tokenizer)

In [23]:
from torch.utils.data import Dataset

class NERDataset(Dataset):
    def __init__(self, data, tokenizer, max_len=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_map = {"O": 0, "B-HPO": 1, "I-HPO": 2}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, labels = self.data[idx]
        encoded = self.tokenizer(
                                tokens,
                                is_split_into_words=True,
                                padding="max_length",
                                truncation=True,
                                max_length=self.max_len,
                                return_tensors="pt",
                                )

        label_ids = [self.label_map[label] for label in labels]
        label_ids = label_ids[: self.max_len]
        label_ids += [0] * (self.max_len - len(label_ids))

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(label_ids, dtype=torch.long),
        }

# Create dataset
train_dataset = NERDataset(train_data, tokenizer)
dev_dataset = NERDataset(dev_data, tokenizer)
test_dataset = NERDataset(test_data, tokenizer)

In [33]:
import torch
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=len(train_dataset.label_map))
for param in model.parameters():
    if not param.is_contiguous():
        param.data = param.data.contiguous()

data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)

training_args = TrainingArguments(
    output_dir="./ner_model",
    logging_strategy='epoch',
    eval_strategy='epoch',
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)

trainer.train()

Some weights of BertForTokenClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,0.3595,0.253984
2,0.1937,0.219017
3,0.1374,0.232527


TrainOutput(global_step=138, training_loss=0.23018282392750616, metrics={'train_runtime': 16.7516, 'train_samples_per_second': 130.555, 'train_steps_per_second': 8.238, 'total_flos': 142865293370112.0, 'train_loss': 0.23018282392750616, 'epoch': 3.0})

In [34]:
metrics = trainer.evaluate(test_dataset)
print(metrics)

trainer.save_model("./ner_model")
tokenizer.save_pretrained("./ner_model")

{'eval_loss': 0.22048313915729523, 'eval_runtime': 0.6824, 'eval_samples_per_second': 304.818, 'eval_steps_per_second': 19.051, 'epoch': 3.0}


('./ner_model/tokenizer_config.json',
 './ner_model/special_tokens_map.json',
 './ner_model/vocab.txt',
 './ner_model/added_tokens.json',
 './ner_model/tokenizer.json')