# Bio_ClinicalBERT Finetuning(raredis corpus) and Llama Accuracy check

### Preprocessing the data
    - Read text and annotation files
    - Tokenize text
    - Align with annotations
    - Generate BIO tags (labels)

In [6]:
import os
from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore")


def parse_ann_file(ann_path):
    """
    Parse annotation files.
    """
    entities = []
    with open(ann_path, 'r') as file:
        for line in file:
            parts = line.strip().split('\t')
            if parts[0].startswith("T"):
                label_and_span, text = parts[1], parts[2]
                label, span = label_and_span.split(" ", 1)
                if label in {"SIGN", "SYMPTOM"}:
                    ranges = span.split(";")
                    spans = [(int(start), int(end)) for start, end in (r.split() for r in ranges)]
                    entities.append((label, spans, text))
    return entities

def align_tokens_and_labels(text, entities, tokenizer):
    """
    Align tokens with BIO labels.
    """
    text = text.strip()
    tokenized = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
    tokens = tokenizer.tokenize(text, add_special_tokens=False)
    token_spans = tokenized["offset_mapping"]
    labels = ["O"] * len(tokens)

    for label, spans, entity_text in entities:
        for start, end in spans:
            for idx, (tok_start, tok_end) in enumerate(token_spans):
                if tok_start >= start and tok_end <= end:
                    if tok_start == start:
                        labels[idx] = "B-HPO"
                    else:
                        labels[idx] = "I-HPO"


    aligned_labels = []
    for token, label in zip(tokens, labels):
        subwords = tokenizer.tokenize(token)
        if len(subwords) == 1:
            aligned_labels.append(label)
        else:
            aligned_labels.append(label)
            aligned_labels.extend(["I-HPO" if label != "O" else "O"] * (len(subwords) - 1))

    return tokens, aligned_labels

def preprocess_data(folder_path, tokenizer):
    data = []
    for file in os.listdir(folder_path):
        if file.endswith(".txt"):
            txt_path = os.path.join(folder_path, file)
            ann_path = txt_path.replace(".txt", ".ann")

            with open(txt_path, 'r') as f:
                text = f.read()

            if os.path.exists(ann_path):
                entities = parse_ann_file(ann_path)
                tokens, labels = align_tokens_and_labels(text, entities, tokenizer)
                data.append((tokens, labels))
    return data


# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Preprocess data
train_data = preprocess_data("datasets/RareDis-v1/train", tokenizer)
dev_data = preprocess_data("datasets/RareDis-v1/dev", tokenizer)
test_data = preprocess_data("datasets/RareDis-v1/test", tokenizer)

### Custom Dataset

In [7]:
from torch.utils.data import Dataset
MAX_LEN = 512

class HPODataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.label_map = {"O": 0, "B-HPO": 1, "I-HPO": 2}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, labels = self.data[idx]
        encoded = self.tokenizer(tokens,
                                is_split_into_words=True,
                                padding="max_length",
                                truncation=True,
                                max_length=MAX_LEN,
                                return_tensors="pt",
                                )

        label_ids = [self.label_map[label] for label in labels]
        label_ids = label_ids[:MAX_LEN]
        label_ids += [0] * (MAX_LEN - len(label_ids))

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(label_ids, dtype=torch.long),
        }


train_dataset = HPODataset(train_data, tokenizer)
dev_dataset = HPODataset(dev_data, tokenizer)
test_dataset = HPODataset(test_data, tokenizer)

### Finetuning Step

In [14]:
import torch
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

save_path = "./saved_hpo_bert_raredis2"
model = AutoModelForTokenClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=len(train_dataset.label_map))
for param in model.parameters():
    if not param.is_contiguous():
        param.data = param.data.contiguous()

data_collator = DataCollatorForTokenClassification(tokenizer, padding=True)

training_args = TrainingArguments(
    output_dir=save_path,
    logging_strategy='epoch',
    eval_strategy='epoch',
    learning_rate=7e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)

trainer.train()

# Evaluate and save
metrics = trainer.evaluate(test_dataset)
print(metrics)

trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,0.2464,0.148825
2,0.1161,0.134647
3,0.0755,0.131026
4,0.0528,0.146272
5,0.0392,0.158709


{'eval_loss': 0.16992920637130737, 'eval_runtime': 1.6127, 'eval_samples_per_second': 128.978, 'eval_steps_per_second': 8.061, 'epoch': 5.0}


('./saved_hpo_bert_raredis2/tokenizer_config.json',
 './saved_hpo_bert_raredis2/special_tokens_map.json',
 './saved_hpo_bert_raredis2/vocab.txt',
 './saved_hpo_bert_raredis2/added_tokens.json',
 './saved_hpo_bert_raredis2/tokenizer.json')

### Extraction Code and Accuracy test

In [1]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForCausalLM
import torch

model_name = '../generative_models/saved_llama_3.2_3B_ins'
llama_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
llama_tokenizer = AutoTokenizer.from_pretrained(model_name)
disclamer = "Disclamer:\nThe information provided is for educational purposes and should not replace professional medical advice. Individuals should consult healthcare professionals or local health authorities for personalized guidance."

bert_model_path = './saved_hpo_bert_raredis2'
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_path)
bert_model = AutoModelForTokenClassification.from_pretrained(bert_model_path)
bert_model.eval()
labels = {0: "O", 1: "B-HPO", 2:"I-HPO"}

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
def get_phenotypes(text):
    """
    Extract HPO terms from text.
    
    Arg:
    - text (str): Input text for NER.
    
    Returns:
    - List of recognized HPO terms.
    """

    inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    with torch.no_grad():
        outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
    
    predictions = torch.argmax(logits, dim=-1)

    tokens = bert_tokenizer.convert_ids_to_tokens(input_ids[0])
    pred_labels = predictions[0].tolist()
    
    recognized_entities = []
    current_entity = []
    
    for token, label_id in zip(tokens, pred_labels):
        label = labels.get(label_id, "O")

        if label == "B-HPO":
            if current_entity:
                recognized_entities.append(" ".join(current_entity))
            current_entity = [token]
        elif label == "I-HPO":
            if current_entity:
                current_entity.append(token)
        else:
            if current_entity:
                recognized_entities.append(" ".join(current_entity))
                current_entity = []

    if current_entity:
        recognized_entities.append(" ".join(current_entity))
    
    phenotypes = [" ".join(e.replace(" ##", "").replace("##", "") for e in entity.split()) for entity in recognized_entities]
    
    return phenotypes


def get_diagnosis(text, with_bert=False):
    
    if with_bert:
        text = get_phenotypes(text)

    messages = [
        {"role": "system", "content": "You are an expert and experienced from the healthcare and biomedical domain with extensive medical knowledge and practical experience. Diagnose the condition based on given text. Also I'm just using it for my project so be consistent every time."},
        {"role": "user", "content": f"Based on this text give me the disease name only. TEXT: {text}"},
    ]

    tokenized_message = llama_tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", return_dict=True)
    response_token_ids = llama_model.generate(tokenized_message['input_ids'].cuda(),
                                              attention_mask=tokenized_message['attention_mask'].cuda(),
                                              max_new_tokens=128, 
                                              pad_token_id = llama_tokenizer.eos_token_id
                                             )
    generated_tokens =response_token_ids[:, len(tokenized_message['input_ids'][0]):]
    diagnosis = llama_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    return ' '.join(diagnosis.split()[1:])


def compare(y, y1, y2, y3):
    
    yl = len(y)
    
    if yl <= len(y1):
        if y in y1:
            return True
    else:
        if y1 in y:
            return True

    if yl <= len(y2):
        if y in y2:
            return True
    else:
        if y2 in y:
            return True
        
    if yl <= len(y3):
        if y in y3:
            return True
    else:
        if y3 in y:
            return True
    
    return False

In [3]:
import pandas as pd

df = pd.read_csv('datasets/symptom_disease.csv')
df

Unnamed: 0,text,disease
0,"I’ve been feeling really tired, and I’m notici...",Congestive Heart Failure
1,"My throat hurts when I swallow, and I have a f...",Streptococcal Pharyngitis (Strep Throat)
2,I have sharp chest pain that gets worse with d...,Pneumothorax
3,"I have frequent urination, excessive thirst, a...",Diabetes Mellitus (Type 1 or Type 2)
4,"My joints hurt and are swollen, especially in ...",Rheumatoid Arthritis
...,...,...
95,I feel full after eating only a small amount o...,Gastroparesis
96,"I’ve been losing weight without trying, and I’...",Hyperthyroidism
97,"I have a constant headache, sensitivity to lig...",Migraine
98,"I’m feeling lightheaded, with blurred vision, ...",Hypoglycemia


In [9]:
df['output1'] = df['text'].apply(lambda x: get_diagnosis(x))
df['output2'] = df['text'].apply(lambda x: get_diagnosis(x))
df['output3'] = df['text'].apply(lambda x: get_diagnosis(x))
df['is_correct'] = df.apply(lambda x: compare(x['disease'], x['output1'], x['output2'], x['output3']), axis=1)

print("Accuracy of Llama without bert: ", (df['is_correct'].value_counts()[True])/len(df))

Accuracy of Llama without bert:  0.4


In [15]:
df['b1'] = df['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df['b2'] = df['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df['b3'] = df['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df['is_correct2'] = df.apply(lambda x: compare(x['disease'], x['b1'], x['b2'], x['b3']), axis=1)

if True in df['is_correct2'].values:
    print("Accuracy of Llama bert raredis: ", (df['is_correct2'].value_counts()[True])/len(df))
else:
    print("Cannot compute accuracy")

Cannot compute accuracy


In [4]:
# On symptom2disease dataset
import pandas as pd

df2 = pd.read_csv('datasets/Symptom2Disease.csv', usecols=['text', 'label'])
df2 = df2.groupby('label').sample(n=4)
df2

Unnamed: 0,label,text
586,Acne,I've recently been battling a pretty itchy ras...
582,Acne,"When I awoke this morning, I realised that I h..."
593,Acne,I discovered a huge rash on my skin yesterday....
592,Acne,"Yesterday, I noticed an enormous rash all over..."
535,Arthritis,"My muscles have been feeling feeble recently, ..."
...,...,...
1107,peptic ulcer disease,I have difficulty sleeping due to abdominal pa...
942,urinary tract infection,I noticed blood in my urinating. I occasionall...
940,urinary tract infection,"I have to go to the bathroom a lot, but genera..."
934,urinary tract infection,"My lower abdomen hurts, and when I urinate, it..."


In [5]:
df2['output1'] = df2['text'].apply(lambda x: get_diagnosis(x))
df2['output2'] = df2['text'].apply(lambda x: get_diagnosis(x))
df2['output3'] = df2['text'].apply(lambda x: get_diagnosis(x))
df2['is_correct'] = df2.apply(lambda x: compare(x['label'], x['output1'], x['output2'], x['output3']), axis=1)
print("Accuracy of Llama without bert: ", (df2['is_correct'].value_counts()[True])/len(df2))

df2['b1'] = df2['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df2['b2'] = df2['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df2['b3'] = df2['text'].apply(lambda x: get_diagnosis(x, with_bert=True))
df2['is_correct2'] = df2.apply(lambda x: compare(x['label'], x['b1'], x['b2'], x['b3']), axis=1)

if True in df2['is_correct2'].values:
    print("Accuracy of Llama bert synthetic: ", (df2['is_correct2'].value_counts()[True])/len(df2))
else:
    print("Cannot compute accuracy")

Accuracy of Llama without bert:  0.3229166666666667
Accuracy of Llama bert synthetic:  0.010416666666666666
