Imports

In [None]:
# Cell 1: Imports
import sys
import os
import re
from transformers import AutoTokenizer
import pandas as pd
import numpy as np

utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'utils'))
if utils_path not in sys.path:
    sys.path.insert(0, utils_path)

from general_utils import load_data, prepare_all_samples
from bert_training_utils import (
    build_gold_lookup, 
    get_label_for_pair, 
    create_training_pairs,
    compute_class_weights, 
    downsample_classes, 
    upsample_classes, 
    handle_class_imbalance, 
    add_special_tokens, 
    tokenize_function
)
from bert_extractor_utils import preprocess_input, mark_entities_full_text

Test BERT Pre-Processing & Utility Functions

In [None]:
# Cell 2: Load Test Data
def load_test_data():
    """Test 1: Load the training CSV, prepare samples and print a summary.

    Reads ``../data/training_dataset.csv`` through ``load_data``, converts the
    result with ``prepare_all_samples`` and prints document/sample counts
    followed by the sizes of the first sample's entity, date and relation
    collections.

    Returns:
        The object returned by ``prepare_all_samples`` (indexed here as a
        sequence of dict-like samples).
    """
    documents = load_data("../data/training_dataset.csv")
    prepared = prepare_all_samples(documents)

    print("=== Test 1: Data Loading ===")
    print(f"Loaded {len(documents)} documents")
    print(f"Prepared {len(prepared)} samples")

    # Spot-check the structure of the first prepared sample only.
    first = prepared[0]
    print("\nFirst sample contents:")
    print(f"- Number of entities: {len(first['entities_list'])}")
    print(f"- Number of absolute dates: {len(first['dates'])}")
    print(f"- Number of relative dates: {len(first['relative_dates'])}")
    print(f"- Number of relations: {len(first['relations_json'])}")

    return prepared

# Run test
# The result is kept at notebook scope as `samples`; the later test cells
# reuse this same list instead of reloading the dataset.
samples = load_test_data()

In [None]:
# Cell 3: Test Entity and Date Separation (Updated)
def test_entity_date_separation(samples):
    """Test 2: Report entities that look like dates and list the date spans.

    Only ``samples[0]`` is inspected. An entity is flagged when its ``value``
    matches any of the numeric date regexes below. The first three absolute
    dates and every relative date are then printed with 30 characters of
    surrounding note text on each side.

    Args:
        samples: Sequence of sample dicts; the first one must provide
            ``entities_list``, ``dates``, ``relative_dates`` and ``note_text``.
    """
    print("=== Test 2: Entity-Date Separation ===")
    first = samples[0]

    def context_around(span):
        # Looked up lazily so the note text is only touched when printing.
        return first['note_text'][max(0, span['start']-30):span['end']+30]

    # More thorough date pattern check
    date_patterns = (
        r'\d{1,2}/\d{1,2}/\d{2,4}',  # dd/mm/yyyy or mm/dd/yyyy
        r'\d{1,2}-\d{1,2}-\d{2,4}',   # dd-mm-yyyy or mm-dd-yyyy
        r'\d{1,2}/\d{1,2}',           # dd/mm or mm/dd
        r'\d{4}',                      # yyyy
    )

    # Each entity is kept at most once, in its original list order.
    flagged = [
        entity
        for entity in first['entities_list']
        if any(re.search(regex, entity['value']) for regex in date_patterns)
    ]

    print("\nChecking entities list:")
    print(f"Total entities: {len(first['entities_list'])}")
    if not flagged:
        print("✓ No date patterns found in entities list")
    else:
        print("WARNING: Found potential dates in entities list:")
        for entity in flagged:
            print(f"- {entity['value']} (CUI: {entity.get('cui', 'N/A')})")
            print(f"  Context: {context_around(entity)}")

    # Verify dates are properly separated
    print("\nChecking date separation:")
    print("Absolute dates (first 3):")
    for absolute in first['dates'][:3]:
        print(f"- {absolute['value']} (Position: {absolute['start']}-{absolute['end']})")
        print(f"  Context: {context_around(absolute)}")

    print("\nRelative dates (all):")
    for relative in first['relative_dates']:
        print(f"- {relative['value']} (Position: {relative['start']}-{relative['end']})")
        print(f"  Context: {context_around(relative)}")

# Run test
test_entity_date_separation(samples)

In [None]:
# Cell 4: Test Relations
def test_relations(samples):
    """Test 3: Verify relation extraction and gold set creation.

    Works on ``samples[0]`` only. Prints the size of the lookup returned by
    ``build_gold_lookup`` and then lists the (entity, date) relations whose
    date text equals the ``value`` of an absolute date (first three shown) or
    of a relative date (all shown).

    Args:
        samples: Sequence of sample dicts; the first one must provide
            ``relations_json``, ``dates`` and ``relative_dates``.
    """
    print("=== Test 3: Relations Testing ===")
    sample = samples[0]

    # Test gold set creation
    gold_set = build_gold_lookup(sample['relations_json'])
    print(f"\nGold set size: {len(gold_set)}")

    # Build the membership sets once instead of re-creating a list of date
    # values for every relation inside the comprehension.
    absolute_values = {d['value'] for d in sample['dates']}
    relative_values = {rd['value'] for rd in sample['relative_dates']}
    relation_pairs = [(r['entity'], r['date']) for r in sample['relations_json']]

    # Test both absolute and relative date relations
    print("\nTesting absolute date relations:")
    abs_relations = [pair for pair in relation_pairs if pair[1] in absolute_values]
    print(f"Number of absolute date relations: {len(abs_relations)}")
    print("First 3 absolute date relations:")
    for entity, date in abs_relations[:3]:
        print(f"- {entity} -> {date}")

    print("\nTesting relative date relations:")
    rel_relations = [pair for pair in relation_pairs if pair[1] in relative_values]
    print(f"Number of relative date relations: {len(rel_relations)}")
    print("All relative date relations:")
    for entity, date in rel_relations:
        print(f"- {entity} -> {date}")

# Run test
test_relations(samples)

In [None]:
# Cell 5: Test Text Marking
def _report_marked_pairs(sample, pairs, date_spans):
    """Print the marked-up context for each (entity text, date text) pair.

    Args:
        sample: Sample dict; ``note_text`` and ``entities_list`` are read.
        pairs: Iterable of ``(entity_text, date_text)`` tuples to display.
        date_spans: List of date dicts (``value``/``start``/``end``) in which
            each ``date_text`` is looked up.
    """
    for entity_text, date_text in pairs:
        # Only the first span with a matching value is used, so repeated
        # mentions of the same text all resolve to the first list entry.
        entity = next(
            (e for e in sample['entities_list'] if e['value'] == entity_text), None
        )
        date = next((d for d in date_spans if d['value'] == date_text), None)

        print(f"\nPair: {entity_text} -> {date_text}")
        if entity is None or date is None:
            # A relation can name text that is absent from the span lists;
            # report it instead of aborting with a bare StopIteration.
            missing = "entity" if entity is None else "date"
            print(f"WARNING: no {missing} span found for this pair; skipping")
            continue

        print(f"Entity: {entity['value']} (Position: {entity['start']}-{entity['end']})")
        print(f"Date: {date['value']} (Position: {date['start']}-{date['end']})")

        marked = mark_entities_full_text(
            sample['note_text'],
            entity['start'], entity['end'],
            date['start'], date['end'],
            entity['value'], date['value']
        )

        # Show focused context: from the earlier span to the later one, with
        # 30 characters of padding, clamped to the marked text.
        start_pos = min(entity['start'], date['start'])
        end_pos = max(entity['end'], date['end'])
        context_start = max(0, start_pos - 30)
        context_end = min(len(marked), end_pos + 30)
        print(f"Marked text: ...{marked[context_start:context_end]}...")


def test_text_marking(samples):
    """Test 4: Verify text marking for both absolute and relative dates.

    Works on ``samples[0]`` only. Relations are split by whether their date
    text equals the ``value`` of an absolute or a relative date; the first
    three absolute pairs and all relative pairs are passed through
    ``mark_entities_full_text`` and a window of the result is printed.

    Args:
        samples: Sequence of sample dicts; the first one must provide
            ``relations_json``, ``dates``, ``relative_dates``,
            ``entities_list`` and ``note_text``.
    """
    print("=== Test 4: Text Marking ===")
    sample = samples[0]

    # Build the membership sets once rather than once per relation.
    absolute_values = {d['value'] for d in sample['dates']}
    relative_values = {rd['value'] for rd in sample['relative_dates']}

    # Get all relations
    abs_relations = [(r['entity'], r['date']) for r in sample['relations_json']
                     if r['date'] in absolute_values]
    rel_relations = [(r['entity'], r['date']) for r in sample['relations_json']
                     if r['date'] in relative_values]

    print(f"\nFound {len(abs_relations)} absolute date relations")
    print(f"Found {len(rel_relations)} relative date relations")

    # Test absolute date marking (first 3 examples)
    print("\nTesting absolute date marking:")
    _report_marked_pairs(sample, abs_relations[:3], sample['dates'])

    # Test relative date marking (all examples)
    print("\nTesting relative date marking:")
    _report_marked_pairs(sample, rel_relations, sample['relative_dates'])

# Run test
test_text_marking(samples)

In [None]:
# Cell 6: Test Training Pair Creation (Updated)
def test_training_pairs(samples):
    """Test 5: Inspect the training pairs generated for the first sample.

    Calls ``create_training_pairs`` on ``[samples[0]]``, prints the overall
    label distribution, then splits the pairs by whether the text between
    ``[E2]`` and ``[/E2]`` equals an absolute or a relative date value and
    prints per-group counts plus one positive example from each group.

    Args:
        samples: Sequence of sample dicts; the first one must provide
            ``dates`` and ``relative_dates``.
    """
    print("=== Test 5: Training Pair Creation ===")

    # Create pairs from first sample only
    first = samples[0]
    pairs = create_training_pairs([first])

    print("\nOverall pair statistics:")
    print(f"Total pairs created: {len(pairs)}")
    print("\nOverall label distribution:")
    print(pairs['label'].value_counts())

    # Get the actual dates used in each pair
    relative_values = {rd['value'] for rd in first['relative_dates']}
    absolute_values = {d['value'] for d in first['dates']}

    # Filter based on the date being used in the pair, not just text content
    def get_date_from_text(text):
        # Extract the text between [E2] and [/E2]
        found = re.search(r'\[E2\](.*?)\[/E2\]', text)
        if not found:
            return None
        return found.group(1).strip()

    pairs['date_used'] = pairs['marked_text'].apply(get_date_from_text)
    rel_pairs = pairs[pairs['date_used'].isin(relative_values)]
    abs_pairs = pairs[pairs['date_used'].isin(absolute_values)]

    print("\nAbsolute date pairs:")
    print(f"Total absolute date pairs: {len(abs_pairs)}")
    print("Label distribution:")
    print(abs_pairs['label'].value_counts())

    print("\nRelative date pairs:")
    print(f"Total relative date pairs: {len(rel_pairs)}")
    print("Label distribution:")
    print(rel_pairs['label'].value_counts())

    # Show examples of both types
    example_sections = (
        ("\nExample positive absolute date pair:", abs_pairs),
        ("\nExample positive relative date pair:", rel_pairs),
    )
    for heading, subset in example_sections:
        print(heading)
        positives = subset[subset['label'] == 1]
        if len(positives) > 0:
            row = positives.iloc[0]
            marked = row['marked_text']
            # Window of 30 characters either side of the ent1 offsets.
            lo = max(0, row['ent1_start']-30)
            hi = min(len(marked), row['ent1_end']+30)
            print(f"Text: ...{marked[lo:hi]}...")

# Run test
test_training_pairs(samples)

In [None]:
# Cell 7: Test Class Balancing
def test_class_balancing(samples):
    """Test 6: Run each class-balancing method and print its outcome.

    Builds the full pair table with ``create_training_pairs`` and calls
    ``handle_class_imbalance`` with ``method='weighted'``, ``'downsample'``
    and ``'upsample'`` in that order, printing the returned class weights or
    the resulting label counts.

    Args:
        samples: Samples forwarded unchanged to ``create_training_pairs``.
    """
    print("=== Test 6: Class Balancing ===")

    # Create full training set
    pairs = create_training_pairs(samples)
    print("\nOriginal class distribution:")
    print(pairs['label'].value_counts())

    # Test weighted balancing (only the second returned item is reported)
    _, class_weights = handle_class_imbalance(pairs, method='weighted')
    print("\nWeighted balancing:")
    print(f"Class weights: {class_weights}")

    # Test downsampling, then upsampling
    resampling_runs = (
        ('downsample', "\nDownsampling results:"),
        ('upsample', "\nUpsampling results:"),
    )
    for method_name, heading in resampling_runs:
        resampled, _ = handle_class_imbalance(pairs, method=method_name)
        print(heading)
        print(resampled['label'].value_counts())

# Run test
test_class_balancing(samples)

In [None]:
# Cell 8: Test Tokenization
def test_tokenization():
    """Test 7: Check that the entity markers survive tokenization.

    Loads the ``google/bert_uncased_L-2_H-128_A-2`` tokenizer, reports its
    vocabulary size and special tokens before and after
    ``add_special_tokens``, then encodes two marked example sentences with
    ``tokenize_function`` (``max_length=32``), decodes them and warns about
    any of the four markers missing from the decoded string.
    """
    print("=== Test 7: Tokenization ===")

    def describe(tok):
        print(f"Vocabulary size: {len(tok)}")
        print(f"Special tokens: {tok.all_special_tokens}")

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
    print("\nBefore adding special tokens:")
    describe(tokenizer)

    # Add special tokens
    tokenizer = add_special_tokens(tokenizer)
    print("\nAfter adding special tokens:")
    describe(tokenizer)

    # Test tokenization with both absolute and relative dates
    examples = [
        {"marked_text": "[E1]patient[/E1] seen on [E2]2023-01-01[/E2]"},
        {"marked_text": "[E1]symptoms[/E1] started [E2]last week[/E2]"}
    ]
    markers = ("[E1]", "[/E1]", "[E2]", "[/E2]")

    print("\nTokenization tests:")
    for number, example in enumerate(examples, start=1):
        encoded = tokenize_function(example, tokenizer, max_length=32)
        round_trip = tokenizer.decode(encoded['input_ids'])
        print(f"\nExample {number}:")
        print(f"Original: {example['marked_text']}")
        print(f"Decoded:  {round_trip}")

        # Verify markers are preserved
        for marker in (m for m in markers if m not in round_trip):
            print(f"WARNING: {marker} was lost in tokenization!")

# Run test
test_tokenization()

In [None]:
# Cell 9: Test Token Length and Distance Analysis
def _print_distance_summary(heading, count_label, values):
    """Print the count and, when non-empty, mean/min/max of *values*."""
    print(f"\n{heading}")
    print(f"{count_label}: {len(values)}")
    if values:
        print(f"Mean distance: {sum(values)/len(values):.1f} chars")
        print(f"Min distance: {min(values)} chars")
        print(f"Max distance: {max(values)} chars")


def _marker_window(text, padding=30):
    """Return the part of *text* covering both marked spans plus padding.

    The window runs from the earlier opening marker to the end of the later
    closing marker, so it stays non-empty whether the entity or the date
    comes first. If the markers cannot be found, the head of the text is
    returned instead.
    """
    opens = [pos for pos in (text.find('[E1]'), text.find('[E2]')) if pos != -1]
    closes = [text.find(tag) + len(tag) for tag in ('[/E1]', '[/E2]') if tag in text]
    if not opens or not closes:
        return text[:2 * padding]
    start = max(0, min(opens) - padding)
    end = min(len(text), max(closes) + padding)
    return text[start:end]


def test_token_lengths_and_distances(samples=None):
    """Test 8: Analyze both document lengths and distances between entities/dates.

    Builds the training pairs for ``samples[0]``, tokenizes every
    ``marked_text`` without truncation to report length statistics, then
    measures, for pairs labelled 1, the character distance between the
    ``[E1]`` and ``[E2]`` opening markers. Distances are additionally split
    into "relative" and "absolute" by a keyword heuristic on the text between
    ``[E2]`` and ``[/E2]``. The closest and furthest pairs are printed last.

    Args:
        samples: Optional sequence of sample dicts. When omitted, the
            notebook-level ``samples`` variable is used, which keeps the
            original zero-argument call working.

    Raises:
        NameError: If ``samples`` is omitted and no notebook-level
            ``samples`` exists.
    """
    print("=== Test 8: Token Length and Distance Analysis ===")

    if samples is None:
        try:
            samples = globals()['samples']
        except KeyError:
            raise NameError("name 'samples' is not defined") from None

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
    tokenizer = add_special_tokens(tokenizer)

    # Get first sample's pairs
    sample = samples[0]
    df = create_training_pairs([sample])
    if df.empty:
        # Nothing to measure; the statistics below would divide by zero.
        print("\nNo training pairs were created; skipping analysis.")
        return

    # 1. Document Length Analysis
    print("\nDocument Length Analysis:")
    token_lengths = [
        len(tokenizer(text, truncation=False)['input_ids'])
        for text in df['marked_text']
    ]

    print(f"Mean doc length: {sum(token_lengths)/len(token_lengths):.1f} tokens")
    print(f"Max doc length: {max(token_lengths)} tokens")
    print(f"Number of docs > 256 tokens: {sum(l > 256 for l in token_lengths)}")
    print(f"Number of docs > 512 tokens: {sum(l > 512 for l in token_lengths)}")

    # 2. Distance Analysis for Positive Pairs
    print("\nDistance Analysis for True Relations:")

    # Keyword heuristic for relative dates; substring match on lowercased text.
    relative_patterns = ('last', 'ago', 'today', 'month', 'year', 'week',
                         'previous', 'next', 'current')

    measured = []       # (distance, marked_text), in positive-pair order
    abs_distances = []
    rel_distances = []
    skipped = 0

    for text in df[df['label'] == 1]['marked_text']:
        e1_pos = text.find('[E1]')
        e2_pos = text.find('[E2]')
        if e1_pos == -1 or e2_pos == -1:
            # find() returns -1 for a missing marker; a distance computed
            # from it would be meaningless, so leave the pair out.
            skipped += 1
            continue

        distance = abs(e2_pos - e1_pos)
        measured.append((distance, text))

        # Extract the date text between [E2] and [/E2]
        date_start = e2_pos + len('[E2]')
        date_end = text.find('[/E2]', date_start)
        date_text = text[date_start:date_end] if date_end != -1 else ""

        lowered = date_text.lower()
        if any(pattern in lowered for pattern in relative_patterns):
            rel_distances.append(distance)
        else:
            abs_distances.append(distance)

    if skipped:
        print(f"WARNING: skipped {skipped} positive pairs without [E1]/[E2] markers")

    distances = [distance for distance, _ in measured]
    _print_distance_summary("All Positive Pairs:", "Total positive pairs", distances)
    _print_distance_summary("Absolute Date Pairs:", "Total pairs", abs_distances)
    _print_distance_summary("Relative Date Pairs:", "Total pairs", rel_distances)

    # Show examples of closest and furthest pairs
    print("\nExample Pairs:")
    if not measured:
        print("No positive pairs available to show.")
        return

    # min()/max() return the first extreme item, i.e. the same pair that
    # list.index() on the extreme distance would select.
    closest_distance, closest_text = min(measured, key=lambda item: item[0])
    furthest_distance, furthest_text = max(measured, key=lambda item: item[0])

    print("\nClosest Pair:")
    print(f"Distance: {closest_distance} chars")
    print(f"Text: ...{_marker_window(closest_text)}...")

    print("\nFurthest Pair:")
    print(f"Distance: {furthest_distance} chars")
    print(f"Text: ...{_marker_window(furthest_text)}...")

# Run test
test_token_lengths_and_distances()