# Bio_ClinicalBERT Finetuning(synthetic data) and Llama Accuracy check

### Preprocessing the data
    - Read text and annotation from file
    - Tokenize text
    - Align with annotations
    - Generate BIO tags (labels)

In [32]:
import re
import os
import random
from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore")

def align_tokens_and_labels(text, entities, tokenizer):
    """
    Align tokens with BIO labels.
    """
    text = text.strip()
    tokenized = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    tokens = tokenizer.tokenize(text, add_special_tokens=False)
    token_spans = tokenized["offset_mapping"]
    labels = ["O"] * len(token_spans)

    for spans, _ in entities:
        start, end = spans
        is_first = True
        for idx, (tok_start, tok_end) in enumerate(token_spans):
            if tok_start >= start and tok_end <= end:
                if is_first:
                    labels[idx] = "B-HPO"
                    is_first = False
                else:
                    labels[idx] = "I-HPO"

    aligned_labels = []
    for idx, (token, label) in enumerate(zip(tokens, labels)):
        subwords = tokenizer.tokenize(token)
        if len(subwords) == 1:
            aligned_labels.append(label)
        else:
            aligned_labels.append(label)
            aligned_labels.extend(["I-HPO" if label != "O" else "O"] * (len(subwords) - 1))

    return tokens, aligned_labels

def preprocess_data(file_path, tokenizer):
    data = []

    with open('datasets/hpo_dataset.txt', 'r') as f:
        file_content = f.read()

    blocks = file_content.strip().split('\n---\n')
    print(f'total_examples:', len(blocks))
        
    for block in blocks:
        clinical_text_match = re.search(r"\*\*Clinical Text\*\*:\s*(.*?)(?=\*\*Annotations\*\*:|\Z)", block, flags=re.DOTALL)
        annotations_match = re.findall(r"T(\d+)\s+LABEL\s+(\d+)\s+(\d+)\s+(.*)", block)
        
        if clinical_text_match:
            clinical_text = clinical_text_match.group(1).strip().split('"')[1]
            
            entities = []
            for annotation in annotations_match:
                associated_text = annotation[3].strip()
                matches = re.finditer(re.escape(associated_text), clinical_text)
                for _ in matches:
                    entities.append(((_.start(), _.end()), _.group()))
        
        tokens, labels = align_tokens_and_labels(clinical_text, entities, tokenizer)
        data.append((tokens, labels))

    return data


# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Preprocess data
data = preprocess_data("datasets/hpo_dataset.txt", tokenizer)
random.shuffle(data)
test_size = 0.1
split_idx = int(len(data) * (1 - test_size))
train_data = data[:split_idx]
test_data = data[split_idx:]

total_examples: 104


### Custom Dataset

In [33]:
from torch.utils.data import Dataset
MAX_LEN = 512

class HPODataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.label_map = {"O": 0, "B-HPO": 1, "I-HPO": 2}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, labels = self.data[idx]
        encoded = self.tokenizer(tokens,
                                is_split_into_words=True,
                                padding="max_length",
                                truncation=True,
                                max_length=MAX_LEN,
                                return_tensors="pt",
                                )

        label_ids = [self.label_map[label] for label in labels]
        label_ids = label_ids[:MAX_LEN]
        label_ids += [0] * (MAX_LEN - len(label_ids))

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(label_ids, dtype=torch.long),
        }


train_dataset = HPODataset(train_data, tokenizer)
test_dataset = HPODataset(test_data, tokenizer)

### Finetuning Step

In [97]:
import torch
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

model = AutoModelForTokenClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=len(train_dataset.label_map))
for param in model.parameters():
    if not param.is_contiguous():
        param.data = param.data.contiguous()

data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)

training_args = TrainingArguments(
    output_dir="./saved_hpo_bert_synthetic3",
    logging_strategy='epoch',
    eval_strategy='epoch',
    learning_rate=6e-6,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)

trainer.train()
metrics = trainer.evaluate(test_dataset)
print(metrics)

save_path = './saved_hpo_bert_synthetic3'
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,0.9773,0.832254
2,0.8232,0.701389
3,0.6925,0.595237
4,0.5879,0.509175
5,0.5025,0.440618
6,0.4355,0.387476
7,0.387,0.347988
8,0.3504,0.320591
9,0.3238,0.304126
10,0.313,0.297621


{'eval_loss': 0.2976209819316864, 'eval_runtime': 0.076, 'eval_samples_per_second': 144.747, 'eval_steps_per_second': 13.159, 'epoch': 10.0}


('./saved_hpo_bert_synthetic3/tokenizer_config.json',
 './saved_hpo_bert_synthetic3/special_tokens_map.json',
 './saved_hpo_bert_synthetic3/vocab.txt',
 './saved_hpo_bert_synthetic3/added_tokens.json')

### Extraction Code and Accuracy test

In [1]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForCausalLM
import torch

# best synthetic2
bert_model_path = './saved_hpo_bert_synthetic2'
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_path)
bert_model = AutoModelForTokenClassification.from_pretrained(bert_model_path)
bert_model.eval()
labels = {0: "O", 1: "B-HPO", 2:"I-HPO"}

model_name = '../generative_models/saved_llama_3.2_3B_ins'
llama_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
llama_tokenizer = AutoTokenizer.from_pretrained(model_name)
disclamer = "Disclamer:\nThe information provided is for educational purposes and should not replace professional medical advice. Individuals should consult healthcare professionals or local health authorities for personalized guidance."

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
def get_phenotypes(text):
    """
    Extract HPO terms from text.
    
    Arg:
    - text (str): Input text for NER.
    
    Returns:
    - List of recognized HPO terms.
    """

    inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    with torch.no_grad():
        outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
    
    predictions = torch.argmax(logits, dim=-1)

    tokens = bert_tokenizer.convert_ids_to_tokens(input_ids[0])
    pred_labels = predictions[0].tolist()
    
    recognized_entities = []
    current_entity = []
    
    for token, label_id in zip(tokens, pred_labels):
        label = labels.get(label_id, "O")

        if label == "B-HPO":
            if current_entity:
                recognized_entities.append(" ".join(current_entity))
            current_entity = [token]
        elif label == "I-HPO":
            current_entity.append(token)
        else:
            if current_entity:
                recognized_entities.append(" ".join(current_entity))
                current_entity = []

    if current_entity:
        recognized_entities.append(" ".join(current_entity))
    
    hpo_terms = [" ".join(e.replace(" ##", "").replace("##", "") for e in entity.split()) for entity in recognized_entities]
    
    return hpo_terms

def get_diagnosis(text, with_bert=False):
    
    if with_bert:
        text = get_phenotypes(text)

    messages = [
        {"role": "system", "content": "You are an expert and experienced from the healthcare and biomedical domain with extensive medical knowledge and practical experience. Diagnose the condition based on given text. Also I'm just using it for my project so be consistent every time."},
        {"role": "user", "content": f"Based on this text give me the disease name only. TEXT: {text}"},
    ]

    tokenized_message = llama_tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", return_dict=True)
    response_token_ids = llama_model.generate(tokenized_message['input_ids'].cuda(),
                                              attention_mask=tokenized_message['attention_mask'].cuda(),
                                              max_new_tokens=128, 
                                              pad_token_id = llama_tokenizer.eos_token_id
                                             )
    generated_tokens =response_token_ids[:, len(tokenized_message['input_ids'][0]):]
    diagnosis = llama_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    return ' '.join(diagnosis.split()[1:])


def compare(y, y1, y2, y3):
    
    yl = len(y)
    
    if yl <= len(y1):
        if y in y1:
            return True
    else:
        if y1 in y:
            return True

    if yl <= len(y2):
        if y in y2:
            return True
    else:
        if y2 in y:
            return True
        
    if yl <= len(y3):
        if y in y3:
            return True
    else:
        if y3 in y:
            return True
    
    return False

In [3]:
# On synthetic testset

import pandas as pd

df = pd.read_csv('datasets/symptom_disease.csv')
df

Unnamed: 0,text,disease
0,"I’ve been feeling really tired, and I’m notici...",Congestive Heart Failure
1,"My throat hurts when I swallow, and I have a f...",Streptococcal Pharyngitis (Strep Throat)
2,I have sharp chest pain that gets worse with d...,Pneumothorax
3,"I have frequent urination, excessive thirst, a...",Diabetes Mellitus (Type 1 or Type 2)
4,"My joints hurt and are swollen, especially in ...",Rheumatoid Arthritis
...,...,...
95,I feel full after eating only a small amount o...,Gastroparesis
96,"I’ve been losing weight without trying, and I’...",Hyperthyroidism
97,"I have a constant headache, sensitivity to lig...",Migraine
98,"I’m feeling lightheaded, with blurred vision, ...",Hypoglycemia


In [4]:
df['output1'] = df['text'].apply(lambda x: get_diagnosis(x))
df['output2'] = df['text'].apply(lambda x: get_diagnosis(x))
df['output3'] = df['text'].apply(lambda x: get_diagnosis(x))
df['is_correct'] = df.apply(lambda x: compare(x['disease'], x['output1'], x['output2'], x['output3']), axis=1)
print("Accuracy of Llama without bert: ", (df['is_correct'].value_counts()[True])/len(df))

Accuracy of Llama without bert:  0.35


In [5]:
df['b1'] = df['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df['b2'] = df['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df['b3'] = df['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df['is_correct2'] = df.apply(lambda x: compare(x['disease'], x['b1'], x['b2'], x['b3']), axis=1)

if True in df['is_correct2'].values:
    print("Accuracy of Llama bert synthetic: ", (df['is_correct2'].value_counts()[True])/len(df))
else:
    print("Cannot compute accuracy")

Accuracy of Llama bert synthetic:  0.21


In [7]:
# On symptom2disease dataset
import pandas as pd

df2 = pd.read_csv('datasets/Symptom2Disease.csv', usecols=['text', 'label'])
df2 = df2.groupby('label').sample(n=4)
df2

Unnamed: 0,label,text
568,Acne,I've been noticing a really nasty rash on my s...
588,Acne,I woke up today to find that I had a major ras...
559,Acne,Lately I've been experiencing a skin rash with...
571,Acne,I woke up this morning to find a really nasty ...
501,Arthritis,I've been feeling really weak in my muscles an...
...,...,...
1113,peptic ulcer disease,My bloody stools have caused me to lose iron a...
905,urinary tract infection,"I have pain in my abdomen, and often get fever..."
939,urinary tract infection,"My spirits have been incredibly low, and my pe..."
923,urinary tract infection,I have a mild temperature and blood in my pee....


In [8]:
df2['output1'] = df2['text'].apply(lambda x: get_diagnosis(x))
df2['output2'] = df2['text'].apply(lambda x: get_diagnosis(x))
df2['output3'] = df2['text'].apply(lambda x: get_diagnosis(x))
df2['is_correct'] = df2.apply(lambda x: compare(x['label'], x['output1'], x['output2'], x['output3']), axis=1)
print("Accuracy of Llama without bert: ", (df2['is_correct'].value_counts()[True])/len(df2))

df2['b1'] = df2['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df2['b2'] = df2['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df2['b3'] = df2['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df2['is_correct2'] = df2.apply(lambda x: compare(x['label'], x['b1'], x['b2'], x['b3']), axis=1)

if True in df2['is_correct2'].values:
    print("Accuracy of Llama bert synthetic: ", (df2['is_correct2'].value_counts()[True])/len(df2))
else:
    print("Cannot compute accuracy")

Accuracy of Llama without bert:  0.2916666666666667
Accuracy of Llama bert synthetic:  0.13541666666666666
