Imports

In [1]:
#imports
import sys
import os
from transformers import AutoTokenizer

utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'utils'))

if utils_path not in sys.path:
    sys.path.insert(0, utils_path)

from general_utils import load_data, prepare_all_samples
from bert_training_utils import build_gold_lookup, get_label_for_pair, create_training_pairs, compute_class_weights, balance_classes, handle_class_imbalance, add_special_tokens, tokenize_function
from bert_extractor_utils import preprocess_input, mark_entities_full_text

Test BERT Pre-Processing & Utility Functions

In [2]:
# Load sample data
df = load_data("../data/training_dataset.csv")  # adjust path if needed
samples = prepare_all_samples(df)
print(f"Loaded {len(samples)} samples")

Loaded 5 samples


In [3]:
# Get first sample and find a positive relation
sample = samples[0]
gold_map = build_gold_lookup(sample['relations_json'])

# Find entity and date that form a relation
for rel in sample['relations_json']:
    entity = next((e for e in sample['entities_list'] if e['value'] == rel['entity']), None)
    date = next((d for d in sample['dates'] if d['value'] == rel['date']), None)
    if entity and date:
        break

print("\nSample info:")
print(f"Note text (first 100 chars): {sample['note_text'][:100]}...")
print(f"Selected entity: {entity}")
print(f"Selected date: {date}")


Sample info:
Note text (first 100 chars): Ultrasound (30nd Jun 2024): no significant findings.imp: asthma

She denies any nausea, vomiting, or...
Selected entity: {'id': 308252, 'value': 'pituitary_adenoma', 'cui': '254956000', 'start': 410, 'end': 427}
Selected date: {'id': 308321, 'value': '12nd Sep 2024', 'start': 363, 'end': 376}


In [4]:
# Test build_gold_lookup
print("Testing build_gold_lookup...")
print(f"Gold map: {gold_map}")

Testing build_gold_lookup...
Gold map: {('pituitary_adenoma', '12nd Sep 2024'), ('rheumatoid_arthritis', "16 Sep'24"), ('headache', '23rd Oct 2024'), ('GERD', '17.12.24')}


In [5]:
# Test get_label_for_pair
print("Testing get_label_for_pair...")
label = get_label_for_pair(entity['value'], date['value'], gold_map)  # Use values instead of starts
print(f"Label for {entity['value']} + {date['value']}: {label}")

Testing get_label_for_pair...
Label for pituitary_adenoma + 12nd Sep 2024: relation


In [6]:
# Test mark_entities_full_text
print("Testing mark_entities_full_text...")
marked = mark_entities_full_text(
    sample['note_text'],
    entity['start'], entity['end'],
    date['start'], date['end'],
    entity['value'], date['value']
)
print(f"Marked text around entities:")
context_start = max(0, min(entity['start'], date['start']) - 50)
context_end = min(len(marked), max(entity['start'], date['end']) + 50)
print(marked[context_start:context_end])

Testing mark_entities_full_text...
Marked text around entities:
2nd Aug 2024): reveals asthma.imp: asthma

X-ray ([E2] 12nd Sep 2024 [/E2]): shows 3.1cm mass in brain.imp: [E1] pituitary_adenoma [/E1]

CLINIC VI


In [7]:
# Test preprocessing
print("Testing preprocess_input...")
preprocessed = preprocess_input(sample['note_text'], entity, date)
print("\nPreprocessed input around entities:")
context_start = max(0, min(entity['start'], date['start']) - 50)
context_end = min(len(preprocessed['marked_text']), max(entity['start'], date['end']) + 50)
print(preprocessed['marked_text'][context_start:context_end])

Testing preprocess_input...

Preprocessed input around entities:
2nd Aug 2024): reveals asthma.imp: asthma

X-ray ([E2] 12nd Sep 2024 [/E2]): shows 3.1cm mass in brain.imp: [E1] pituitary_adenoma [/E1]

CLINIC VI


In [8]:
# Test create_training_pairs
print("Testing create_training_pairs...")
df = create_training_pairs([sample])
print(f"\nCreated {len(df)} training pairs")
print("\nSample columns:", df.columns.tolist())
print("\nLabel distribution:")
print(df['label'].value_counts())
print("\nFirst positive example:")
pos = df[df['label'] == 1].iloc[0]
start = max(0, pos['ent1_start']-50)
end = min(len(pos['marked_text']), pos['ent1_end']+50)
print(f"Text snippet: ...{pos['marked_text'][start:end]}...")

Testing create_training_pairs...

Created 384 training pairs

Sample columns: ['text', 'marked_text', 'ent1_start', 'ent1_end', 'ent2_start', 'ent2_end', 'label', 'patient_id', 'note_id', 'distance']

Label distribution:
label
0    379
1      5
Name: count, dtype: int64

First positive example:
Text snippet: ... ([E2] 16 Sep'24 [/E2]): nausea/vomiting worsening confirmed [E1] rheumatoid_arthritis [/E1] switch to aspirin

Past med...


In [9]:
# Test class weight computation
print("Testing compute_class_weights...")
weights = compute_class_weights(df, num_labels=2)
print(f"\nClass weights: {weights}")
print("(Should be ~1 on average, higher for minority class)")

# Verify weights are working as expected
counts = df['label'].value_counts()
print("\nClass distribution:")
for label, count in counts.items():
    weight = weights[int(label)]
    print(f"Class {label}: count={count}, weight={weight:.3f}")

Testing compute_class_weights...

Class weights: tensor([0.0260, 1.9740])
(Should be ~1 on average, higher for minority class)

Class distribution:
Class 0: count=379, weight=0.026
Class 1: count=5, weight=1.974


In [10]:
# Test downsampling
print("Testing balance_classes...")
balanced_df = balance_classes(df, ratio=1.0)
print("\nBefore balancing:")
print(df['label'].value_counts())
print("\nAfter balancing:")
print(balanced_df['label'].value_counts())

Testing balance_classes...

Before balancing:
label
0    379
1      5
Name: count, dtype: int64

After balancing:
label
0    5
1    5
Name: count, dtype: int64


In [11]:
# Test both imbalance handling methods
print("Testing handle_class_imbalance...")

print("\nMethod: weighted")
weighted_df, w_weights = handle_class_imbalance(df, method='weighted')
print(f"Returned weights: {w_weights}")
print("Class distribution (unchanged):")
print(weighted_df['label'].value_counts())

print("\nMethod: downsample")
down_df, d_weights = handle_class_imbalance(df, method='downsample')
print(f"Returned weights: {d_weights}")
print("Class distribution (balanced):")
print(down_df['label'].value_counts())

Testing handle_class_imbalance...

Method: weighted
Returned weights: tensor([0.0260, 1.9740])
Class distribution (unchanged):
label
0    379
1      5
Name: count, dtype: int64

Method: downsample
Returned weights: None
Class distribution (balanced):
label
0    5
1    5
Name: count, dtype: int64


In [12]:
# Test add_special_tokens and tokenize_function
print("Testing tokenizer functions...")

# Load base tokenizer
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
print("\nBefore adding special tokens:")
print(f"Vocab size: {len(tokenizer)}")
print("Special tokens:", tokenizer.all_special_tokens)

# Add special tokens
tokenizer = add_special_tokens(tokenizer)
print("\nAfter adding special tokens:")
print(f"Vocab size: {len(tokenizer)}")
print("Special tokens:", tokenizer.all_special_tokens)

# Verify special tokens work
example = {"marked_text": "[E1] asthma [/E1] was diagnosed on [E2] 2024-01-01 [/E2]"}
encoded = tokenize_function(example, tokenizer, max_length=32)
decoded = tokenizer.decode(encoded['input_ids'])
print("\nTokenization roundtrip:")
print("Original:", example['marked_text'])
print("Decoded:", decoded)

# Check if special tokens are preserved
for token in ["[E1]", "[/E1]", "[E2]", "[/E2]"]:
    if token not in decoded:
        print(f"WARNING: {token} was lost in tokenization!")

Testing tokenizer functions...

Before adding special tokens:
Vocab size: 28996
Special tokens: ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']

After adding special tokens:
Vocab size: 29000
Special tokens: ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]', '[E1]', '[/E1]', '[E2]', '[/E2]']

Tokenization roundtrip:
Original: [E1] asthma [/E1] was diagnosed on [E2] 2024-01-01 [/E2]
Decoded: [CLS] [E1] asthma [/E1] was diagnosed on [E2] 2024 - 01 - 01 [/E2] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
