# Test on Balanced Dataset

Test the model on balanced dataset (100 samples per label, max 300 tokens)

In [1]:
import json
import pandas as pd
from gliner import GLiNER
from tqdm import tqdm
from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


## Load Balanced Test Dataset

In [2]:
# Load balanced test dataset (only 50 samples for quick testing)
with open('../../Data/combined_testdata/balanced_test_100_per_class_27_labels_filtered_300tok.json', 'r') as f:
    all_samples = json.load(f)
    samples = all_samples[:50]  # Take only first 50 samples

print(f"Loaded {len(samples)} samples (out of {len(all_samples)} total)")
print(f"\nExample sample:")
print(f"Text: {samples[0]['text'][:100]}...")
print(f"Entities: {samples[0]['entities'][:2]}")  # Show first 2 entities

Loaded 50 samples (out of 1771 total)

Example sample:
Text: Survey Date: January 26th, 1985 
City: Oak Grove 
How often do you encounter the following stressors...
Entities: [{'entity': 'January 26th, 1985', 'types': ['date'], 'start': 13, 'end': 31, 'original_type': 'dateofbirth', 'canonical_type': 'dateofbirth'}, {'entity': '660-03-8442', 'types': ['tax identification number'], 'start': 137, 'end': 148, 'original_type': 'taxnum', 'canonical_type': 'taxnum'}]


## Dataset Statistics

In [3]:
# Count labels in dataset
label_counts = Counter()
for sample in samples:
    for entity in sample['entities']:
        # Use the first type from the types list
        label = entity['types'][0] if entity['types'] else 'unknown'
        label_counts[label] += 1

# Show distribution
label_df = pd.DataFrame([
    {'Label': label, 'Count': count}
    for label, count in label_counts.most_common()
])

print(f"Unique labels: {len(label_df)}")
print(f"Total entities: {label_df['Count'].sum()}")
print(f"\nLabel distribution:")
label_df

Unique labels: 22
Total entities: 204

Label distribution:


Unnamed: 0,Label,Count
0,date,41
1,full name,38
2,bank account number,13
3,organization,12
4,phone number,11
5,email address,11
6,username,10
7,social security number,8
8,address,8
9,bank account balance,7


## Load Model

In [4]:
# Load the finetuned model
model = GLiNER.from_pretrained("../../finetuned_gliner")
print("Model loaded successfully")

Model loaded successfully


## Define Labels

In [5]:
# Get unique labels from dataset
labels = list(label_counts.keys())
print(f"Using {len(labels)} labels from the dataset:")
for label in sorted(labels):
    print(f"  - {label}")

Using 22 labels from the dataset:
  - address
  - bank account balance
  - bank account number
  - credit card number
  - date
  - drivers license number
  - email address
  - fax number
  - full name
  - health insurance id number
  - iban
  - identity card number
  - insurance plan number
  - ip address
  - medical condition
  - medication
  - organization
  - passport number
  - phone number
  - social security number
  - tax identification number
  - username


## Run Predictions (Full Dataset)

In [6]:
# Run predictions on all samples
results = []

for sample in tqdm(samples, desc="Evaluating"):
    text = sample['text']
    ground_truth = sample['entities']
    
    # Predict
    predictions = model.predict_entities(text, labels, threshold=0.3)
    
    results.append({
        'text': text,
        'ground_truth': ground_truth,
        'predictions': predictions
    })

print(f"\nCompleted predictions on {len(results)} samples")

Evaluating:   0%|          | 0/50 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Evaluating: 100%|██████████| 50/50 [00:06<00:00,  7.34it/s]


Completed predictions on 50 samples





## Calculate Metrics

In [7]:
# Count correct predictions
total_ground_truth = 0
total_predictions = 0
correct_predictions = 0

for result in results:
    text = result['text']
    gt = result['ground_truth']
    pred = result['predictions']
    
    total_ground_truth += len(gt)
    total_predictions += len(pred)
    
    # Match predictions with ground truth
    for p in pred:
        for g in gt:
            # Match if same span and label
            g_label = g['types'][0] if g['types'] else 'unknown'
            if (p['start'] == g['start'] and 
                p['end'] == g['end'] and 
                p['label'].lower() == g_label.lower()):
                correct_predictions += 1
                break

# Calculate metrics
precision = correct_predictions / total_predictions if total_predictions > 0 else 0
recall = correct_predictions / total_ground_truth if total_ground_truth > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

print(f"Results:")
print(f"="*50)
print(f"Ground Truth Entities: {total_ground_truth}")
print(f"Predicted Entities: {total_predictions}")
print(f"Correct Predictions: {correct_predictions}")
print(f"\nPrecision: {precision:.2%}")
print(f"Recall: {recall:.2%}")
print(f"F1 Score: {f1:.2%}")

Results:
Ground Truth Entities: 204
Predicted Entities: 168
Correct Predictions: 82

Precision: 48.81%
Recall: 40.20%
F1 Score: 44.09%


## Per-Label Performance

In [8]:
# Calculate per-label metrics
label_stats = {}

for label in labels:
    tp = 0  # True positives
    fp = 0  # False positives
    fn = 0  # False negatives
    
    for result in results:
        text = result['text']
        gt = result['ground_truth']
        pred = result['predictions']
        
        # Ground truth for this label
        gt_for_label = [(g['start'], g['end']) for g in gt 
                        if (g['types'][0] if g['types'] else 'unknown').lower() == label.lower()]
        
        # Predictions for this label
        pred_for_label = [(p['start'], p['end']) for p in pred 
                          if p['label'].lower() == label.lower()]
        
        # Count matches
        for pred_span in pred_for_label:
            if pred_span in gt_for_label:
                tp += 1
            else:
                fp += 1
        
        # Count missed
        for gt_span in gt_for_label:
            if gt_span not in pred_for_label:
                fn += 1
    
    # Calculate metrics
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = 2 * (prec * rec) / (prec + rec) if (prec + rec) > 0 else 0
    
    label_stats[label] = {
        'Precision': prec,
        'Recall': rec,
        'F1': f1_score,
        'Support': tp + fn
    }

# Convert to DataFrame
perf_df = pd.DataFrame(label_stats).T
perf_df = perf_df.sort_values('F1', ascending=False)

# Format percentages
perf_df['Precision'] = perf_df['Precision'].apply(lambda x: f"{x:.1%}")
perf_df['Recall'] = perf_df['Recall'].apply(lambda x: f"{x:.1%}")
perf_df['F1'] = perf_df['F1'].apply(lambda x: f"{x:.1%}")

print("\nPer-Label Performance:")
perf_df


Per-Label Performance:


Unnamed: 0,Precision,Recall,F1,Support
fax number,100.0%,100.0%,100.0%,4.0
credit card number,83.3%,83.3%,83.3%,6.0
social security number,66.7%,75.0%,70.6%,8.0
bank account number,80.0%,61.5%,69.6%,13.0
email address,63.6%,63.6%,63.6%,11.0
date,68.8%,53.7%,60.3%,41.0
tax identification number,100.0%,40.0%,57.1%,5.0
phone number,60.0%,54.5%,57.1%,11.0
organization,36.0%,75.0%,48.6%,12.0
address,44.4%,50.0%,47.1%,8.0


## Show Example Predictions

In [9]:
# Show 3 examples
for i, result in enumerate(results[:3]):
    text = result['text']
    gt = result['ground_truth']
    pred = result['predictions']
    
    print(f"\n{'='*80}")
    print(f"Example {i+1}")
    print(f"{'='*80}")
    print(f"Text: {text[:200]}...")
    
    print(f"\nGround Truth ({len(gt)} entities):")
    for entity in gt:
        label = entity['types'][0] if entity['types'] else 'unknown'
        print(f"  - {entity['entity']:30s} → {label}")
    
    print(f"\nPredictions ({len(pred)} entities):")
    for p in pred:
        print(f"  - {p['text']:30s} → {p['label']} (score: {p['score']:.2f})")


Example 1
Text: Survey Date: January 26th, 1985 
City: Oak Grove 
How often do you encounter the following stressors? 
- Taxes and paperwork: Tax number 660-03-8442 
- Financial management: Credit Card Number 6290812...

Ground Truth (3 entities):
  - January 26th, 1985             → date
  - 660-03-8442                    → tax identification number
  - 6290812888615710               → credit card number

Predictions (4 entities):
  - January 26th, 1985             → date (score: 1.00)
  - Oak Grove                      → address (score: 0.96)
  - 660-03-8442                    → tax identification number (score: 1.00)
  - 6290812888615710               → credit card number (score: 1.00)

Example 2
Text: **Life Insurance Claim Authorization Form**

**Policyholder Information:**
- Policyholder Name: Maria Orta
- Policy Number:

**Insured's Personal Details:**
- First Name: Maria
- Last Name: Orta
- Dat...

Ground Truth (10 entities):
  - Maria                          → full name
  - 