# Create Per-Label Datasets (v2 - Preserves All Entities)

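Fixed version that keeps every mapped entity in each sample, with samples organized by primary label.

Each sample retains its `privacy_mask` with all entities whose labels map to one of the 24 simplified labels; entities with unmapped labels are dropped.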

In [1]:
import json
import os
import random
from collections import defaultdict, Counter
from pathlib import Path

# Paths
# The defaults are the original author's local directories. Set PII_UNIFIED_DIR /
# PII_OUTPUT_DIR in the environment to run the notebook elsewhere without editing
# this cell.
UNIFIED_DIR = Path(os.environ.get(
    "PII_UNIFIED_DIR",
    "/Users/sravan/Documents/Experiments/fintuning_PII/Data/additional_datasets/unified",
))
OUTPUT_DIR = Path(os.environ.get(
    "PII_OUTPUT_DIR",
    "/Users/sravan/Documents/Experiments/fintuning_PII/Data/training_by_label",
))
# parents=True: create missing intermediate directories instead of raising
# FileNotFoundError when the parent of OUTPUT_DIR does not exist yet.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Upper bound on the number of samples written to each per-label file.
MAX_SAMPLES_PER_LABEL = 10000

## 1. Label Mapping (Same as before)

In [2]:
# The 24 target labels, in the order used for reporting and for writing files.
SIMPLIFIED_24_LABELS = [
    # Personal
    "date",
    "full name",
    "username",
    # Government / tax identifiers
    "social security number",
    "tax identification number",
    "passport number",
    "driver's license number",
    "identification number",
    # Contact
    "phone number",
    "address",
    "email address",
    "ip address",
    "fax number",
    # Financial
    "credit card number",
    "credit score",
    "bank account number",
    "amount",
    "iban",
    "insurance number",
    # Medical
    "medical condition",
    "medication",
    "medical treatment",
    # Other
    "organization",
    "url",
]

# Comprehensive label mapping (lowercase keys)
LABEL_MAP = {
    # DATE
    "date": "date", "dob": "date", "date of birth": "date", "date_of_birth": "date",
    "dateofbirth": "date", "birthday": "date", "birth_date": "date", "birth date": "date",
    "datum": "date", "fecha": "date", "data": "date", "date_time": "date",
    "datetime": "date", "time": "date", "expiry date": "date", "expiry_date": "date",
    
    # FULL NAME
    "full name": "full name", "fullname": "full name", "full_name": "full name",
    "name": "full name", "person": "full name", "person name": "full name",
    "per": "full name", "firstname": "full name", "first_name": "full name",
    "first name": "full name", "lastname": "full name", "last_name": "full name",
    "last name": "full name", "givenname": "full name", "given_name": "full name",
    "surname": "full name", "middlename": "full name", "middle_name": "full name",
    "prefix": "full name", "suffix": "full name", "title": "full name",
    "nom": "full name", "nome": "full name", "nombre": "full name",
    "όνομα": "full name", "osebno ime": "full name", "persoon": "full name",
    
    # USERNAME
    "username": "username", "user_name": "username", "user name": "username",
    "userid": "username", "user_id": "username", "user id": "username",
    "login": "username", "handle": "username", "nickname": "username",
    "password": "username",
    
    # SSN
    "social security number": "social security number",
    "social_security_number": "social security number",
    "ssn": "social security number", "socialsecuritynumber": "social security number",
    "us_ssn": "social security number",
    
    # TAX ID
    "tax identification number": "tax identification number",
    "tax_identification_number": "tax identification number",
    "tax id": "tax identification number", "tax_id": "tax identification number",
    "taxid": "tax identification number", "tin": "tax identification number",
    "us_itin": "tax identification number", "itin": "tax identification number",
    "ein": "tax identification number", "vat number": "tax identification number",
    
    # PASSPORT
    "passport number": "passport number", "passport_number": "passport number",
    "passportnumber": "passport number", "passport": "passport number",
    
    # DRIVER'S LICENSE
    "driver's license number": "driver's license number",
    "drivers license number": "driver's license number",
    "drivers_license_number": "driver's license number",
    "driver_license": "driver's license number",
    "driverslicense": "driver's license number",
    "driving license": "driver's license number",
    
    # IDENTIFICATION NUMBER
    "identification number": "identification number",
    "identification_number": "identification number",
    "id number": "identification number", "id_number": "identification number",
    "idnumber": "identification number", "national id": "identification number",
    "national_id": "identification number", "identity card number": "identification number",
    "identity_card_number": "identification number", "idcardnum": "identification number",
    "student id": "identification number", "student_id": "identification number",
    "student id number": "identification number", "employee id": "identification number",
    "employee_id": "identification number", "customer id": "identification number",
    "customer_id": "identification number", "medical record number": "identification number",
    "medical_record_number": "identification number", "mrn": "identification number",
    "birth certificate number": "identification number",
    "license plate": "identification number", "license_plate": "identification number",
    
    # PHONE NUMBER
    "phone number": "phone number", "phone_number": "phone number",
    "phonenumber": "phone number", "phone": "phone number",
    "telephone": "phone number", "telephonenum": "phone number",
    "telephone number": "phone number", "telephone_number": "phone number",
    "mobile": "phone number", "mobile phone": "phone number",
    "mobile_phone": "phone number", "mobile phone number": "phone number",
    "mobile_phone_number": "phone number", "cell phone": "phone number",
    "telefonska številka": "phone number", "stacionarna številka": "phone number",
    
    # ADDRESS
    "address": "address", "street address": "address", "street_address": "address",
    "streetaddress": "address", "home address": "address",
    "city": "address", "state": "address", "country": "address",
    "zipcode": "address", "zip_code": "address", "zip code": "address",
    "postal code": "address", "postal_code": "address", "postalcode": "address",
    "county": "address", "buildingnumber": "address", "building_number": "address",
    "buildingnum": "address", "street": "address", "street_name": "address",
    "loc": "address", "location": "address", "geo": "address",
    "coordinates": "address", "latitude": "address", "longitude": "address",
    "adresse": "address", "naslov": "address",
    
    # EMAIL
    "email address": "email address", "email_address": "email address",
    "emailaddress": "email address", "email": "email address",
    "e-mail": "email address", "mail": "email address",
    
    # IP ADDRESS
    "ip address": "ip address", "ip_address": "ip address",
    "ipaddress": "ip address", "ip": "ip address",
    "ipv4": "ip address", "ipv6": "ip address",
    "mac address": "ip address", "mac_address": "ip address",
    
    # FAX
    "fax number": "fax number", "fax_number": "fax number",
    "faxnumber": "fax number", "fax": "fax number",
    
    # CREDIT CARD
    "credit card number": "credit card number",
    "credit_card_number": "credit card number",
    "creditcardnumber": "credit card number",
    "credit card": "credit card number", "credit_card": "credit card number",
    "creditcard": "credit card number", "card number": "credit card number",
    "card_number": "credit card number", "debit card": "credit card number",
    "cvv": "credit card number", "cvc": "credit card number",
    
    # CREDIT SCORE
    "credit score": "credit score", "credit_score": "credit score",
    "creditscore": "credit score", "fico score": "credit score",
    
    # BANK ACCOUNT
    "bank account number": "bank account number",
    "bank_account_number": "bank account number",
    "bankaccountnumber": "bank account number",
    "bank account": "bank account number", "bank_account": "bank account number",
    "account number": "bank account number", "account_number": "bank account number",
    "accountnum": "bank account number", "routing number": "bank account number",
    "routing_number": "bank account number",
    "swift": "bank account number", "swift code": "bank account number",
    "swift_code": "bank account number", "swift_bic_code": "bank account number",
    "bic": "bank account number",
    
    # AMOUNT
    "amount": "amount", "bank account balance": "amount",
    "bank_account_balance": "amount", "balance": "amount",
    "transaction amount": "amount", "transaction_amount": "amount",
    "salary": "amount", "income": "amount", "price": "amount",
    "cost": "amount", "payment": "amount", "financial": "amount", "money": "amount",
    
    # IBAN
    "iban": "iban", "international bank account number": "iban",
    
    # INSURANCE
    "insurance number": "insurance number", "insurance_number": "insurance number",
    "health insurance number": "insurance number",
    "health_insurance_number": "insurance number",
    "health insurance id": "insurance number",
    "health insurance id number": "insurance number",
    "national health insurance number": "insurance number",
    "insurance plan number": "insurance number",
    "policy number": "insurance number", "policy_number": "insurance number",
    
    # MEDICAL CONDITION
    "medical condition": "medical condition", "medical_condition": "medical condition",
    "condition": "medical condition", "diagnosis": "medical condition",
    "disease": "medical condition", "illness": "medical condition",
    "disorder": "medical condition", "syndrome": "medical condition",
    
    # MEDICATION
    "medication": "medication", "medicine": "medication",
    "drug": "medication", "prescription": "medication", "drug name": "medication",
    
    # MEDICAL TREATMENT
    "medical treatment": "medical treatment", "medical_treatment": "medical treatment",
    "treatment": "medical treatment", "procedure": "medical treatment",
    "surgery": "medical treatment", "therapy": "medical treatment",
    
    # ORGANIZATION
    "organization": "organization", "organisation": "organization",
    "org": "organization", "company": "organization",
    "company name": "organization", "company_name": "organization",
    "employer": "organization", "business": "organization",
    "corporation": "organization", "institution": "organization",
    "bank name": "organization", "hospital": "organization",
    "school": "organization", "university": "organization",
    "organizacija": "organization", "klinika": "organization",
    
    # URL
    "url": "url", "website": "url", "web address": "url",
    "link": "url", "uri": "url", "domain": "url",
}

def normalize_label(label, label_map=None):
    """Normalize a raw entity label to one of the 24 simplified labels.

    Matching is case-insensitive and ignores surrounding whitespace; the
    cleaned label must match a key of the mapping exactly.

    Args:
        label: Raw label taken from a dataset entity. Non-string values
            (e.g. a JSON null) are treated as unmapped instead of raising.
        label_map: Optional mapping of lowercase raw label -> simplified
            label. Defaults to the module-level ``LABEL_MAP``.

    Returns:
        The simplified label, or ``None`` when the label is not a string or
        has no entry in the mapping.
    """
    if not isinstance(label, str):
        return None
    if label_map is None:
        # Resolved at call time so the default always reflects the current table.
        label_map = LABEL_MAP
    return label_map.get(label.lower().strip(), None)

# Report the size of the mapping table and of the target label set.
for caption, collection in (
    ("Label mappings", LABEL_MAP),
    ("Target labels", SIMPLIFIED_24_LABELS),
):
    print("{}: {}".format(caption, len(collection)))

Label mappings: 259
Target labels: 24


## 2. Load All Unified Datasets

In [3]:
# Load every "*_unified.json" file under UNIFIED_DIR and concatenate the samples.
# Each file is assumed to contain a JSON list of sample dicts.
all_data = []

unified_files = sorted(UNIFIED_DIR.glob("*_unified.json"))
if not unified_files:
    # Fail here with a clear message rather than continuing with an empty dataset.
    raise FileNotFoundError(f"No *_unified.json files found in {UNIFIED_DIR}")

for filepath in unified_files:
    # Explicit UTF-8: the data contains non-ASCII text and the platform default
    # encoding is not guaranteed to be UTF-8.
    with open(filepath, encoding="utf-8") as f:
        data = json.load(f)
    all_data.extend(data)
    print(f"Loaded {filepath.name}: {len(data):,} samples")

print(f"\nTotal: {len(all_data):,} samples")

Loaded ai4privacy_200k_unified.json: 209,261 samples
Loaded ai4privacy_400k_unified.json: 50,000 samples
Loaded beki_privy_unified.json: 100,951 samples
Loaded e3jsi_unified.json: 2,971 samples
Loaded gliner_pii_unified.json: 3,764 samples
Loaded gretel_finance_unified.json: 5,594 samples
Loaded gretel_pii_en_unified.json: 5,000 samples
Loaded nvidia_nemotron_unified.json: 100,000 samples
Loaded urchade_unified.json: 19,635 samples

Total: 497,176 samples


## 3. Normalize Labels in All Samples (Keep ALL mapped entities)

In [4]:
# Process each sample: normalize all entity labels.
# Entities whose label has no mapping are dropped from the sample (and counted in
# `unmapped_labels`); a sample is kept when at least one entity remains.
# NOTE: a kept sample can therefore still contain spans that the source dataset
# marked as PII but that no longer appear in its `privacy_mask`.
normalized_data = []
unmapped_labels = Counter()
total_entities = 0
mapped_entities = 0

for item in all_data:
    source_text = item.get('source_text', '')
    language = item.get('language', 'en')
    source = item.get('source', 'unknown')
    original_entities = item.get('privacy_mask', [])

    # Skip samples without text or without any annotated entity.
    if not original_entities or not source_text:
        continue

    # Normalize ALL entities in this sample
    normalized_entities = []
    for entity in original_entities:
        total_entities += 1
        # A missing or null label is treated as '' so it is counted as unmapped
        # instead of raising inside normalize_label.
        original_label = entity.get('label') or ''
        normalized_label = normalize_label(original_label)

        if normalized_label:
            mapped_entities += 1
            normalized_entities.append({
                'label': normalized_label,
                'start': entity.get('start', 0),
                'end': entity.get('end', 0),
                'value': entity.get('value', ''),
                'original_label': original_label  # Keep for reference
            })
        else:
            unmapped_labels[original_label] += 1

    # Only keep samples that have at least one normalized entity
    if normalized_entities:
        normalized_data.append({
            'source_text': source_text,
            'language': language,
            'source': source,
            'privacy_mask': normalized_entities
        })

# Guard the percentage: with no entities at all the division would raise.
mapped_pct = mapped_entities / total_entities * 100 if total_entities else 0.0

print(f"Samples with normalized entities: {len(normalized_data):,}")
print(f"Total entities: {total_entities:,}")
print(f"Mapped entities: {mapped_entities:,} ({mapped_pct:.1f}%)")
print(f"Unmapped labels: {len(unmapped_labels):,} unique")

Samples with normalized entities: 434,827
Total entities: 1,992,394
Mapped entities: 1,435,662 (72.1%)
Unmapped labels: 8,142 unique


## 4. Show Unmapped Labels

In [5]:
# List the most frequent raw labels that had no entry in the label mapping.
print("Top 30 unmapped labels:")
print("=" * 50)
top_unmapped = unmapped_labels.most_common(30)
for raw_label, occurrences in top_unmapped:
    print("  {:40s} {:>6,}".format(raw_label, occurrences))

Top 30 unmapped labels:
  occupation                               37,099
  AGE                                      15,733
  NRP                                      13,789
  SEX                                      13,528
  JOBTYPE                                  13,433
  CURRENCYSYMBOL                           13,147
  credit_debit_card                        12,867
  GENDER                                   12,847
  JOBTITLE                                 12,828
  JOBAREA                                  12,681
  ACCOUNTNAME                              12,533
  ACCOUNTNUMBER                            12,473
  COMPANYNAME                              12,167
  SECONDARYADDRESS                         12,008
  BITCOINADDRESS                           11,682
  biometric_identifier                     11,520
  employment_status                        11,018
  MASKEDNUMBER                             10,893
  health_plan_beneficiary_number           10,599
  USERAGENT               

## 5. Group Samples by Primary Label

Each sample goes to the file for the label of its FIRST normalized entity (the first entry in its normalized `privacy_mask` list).
The sample still keeps ALL of its normalized entities.

In [6]:
# Group samples by the label of their FIRST entity
samples_by_label = defaultdict(list)

for sample in normalized_data:
    entities = sample.get('privacy_mask', [])
    if entities:
        primary_label = entities[0]['label']
        samples_by_label[primary_label].append(sample)

print("Samples per primary label:")
print("="*60)
for label in SIMPLIFIED_24_LABELS:
    count = len(samples_by_label[label])
    if count > 0:
        # Count average entities per sample
        avg_entities = sum(len(s['privacy_mask']) for s in samples_by_label[label]) / count
        print(f"  {label:30s} {count:>8,} samples  (avg {avg_entities:.1f} entities/sample)")

Samples per primary label:
  date                             45,393 samples  (avg 4.0 entities/sample)
  full name                       185,837 samples  (avg 3.4 entities/sample)
  username                         23,414 samples  (avg 2.4 entities/sample)
  social security number            5,603 samples  (avg 2.0 entities/sample)
  tax identification number         2,379 samples  (avg 1.3 entities/sample)
  passport number                     147 samples  (avg 1.3 entities/sample)
  driver's license number              70 samples  (avg 1.0 entities/sample)
  identification number            11,942 samples  (avg 4.3 entities/sample)
  phone number                      4,992 samples  (avg 2.0 entities/sample)
  address                          68,582 samples  (avg 2.7 entities/sample)
  email address                     8,167 samples  (avg 2.1 entities/sample)
  ip address                       17,227 samples  (avg 1.7 entities/sample)
  fax number                           92 samples

## 6. Save Per-Label Files (Max 10k, Keep All Entities)

In [7]:
# Write one JSON file per label with at most MAX_SAMPLES_PER_LABEL randomly
# chosen samples. Every sample keeps its full normalized privacy_mask.
print("Saving per-label datasets...")
print("="*70)

# Dedicated, seeded RNG so Restart & Run All selects the same samples every time.
SHUFFLE_SEED = 42
rng = random.Random(SHUFFLE_SEED)

total_saved = 0

for label in SIMPLIFIED_24_LABELS:
    samples = samples_by_label[label]

    if not samples:
        print(f"  {label:30s} SKIPPED (no samples)")
        continue

    # Shuffle a copy and limit: leaves samples_by_label untouched, so re-running
    # this cell does not reorder state produced by an earlier cell.
    shuffled = list(samples)
    rng.shuffle(shuffled)
    selected = shuffled[:MAX_SAMPLES_PER_LABEL]

    # Calculate stats
    total_entities = sum(len(s['privacy_mask']) for s in selected)
    avg_entities = total_entities / len(selected)

    # File name: label without apostrophes, spaces replaced by underscores.
    filename = label.replace("'", "").replace(" ", "_") + ".json"
    output_path = OUTPUT_DIR / filename

    # Save with pretty formatting
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(selected, f, indent=2)

    total_saved += len(selected)

    print(f"  {label:30s} {len(selected):>6,} samples, {total_entities:>8,} entities ({avg_entities:.1f}/sample)")

print(f"\n{'='*70}")
print(f"Total saved: {total_saved:,} samples")
print(f"Output directory: {OUTPUT_DIR}")

Saving per-label datasets...
  date                           10,000 samples,   39,388 entities (3.9/sample)
  full name                      10,000 samples,   33,803 entities (3.4/sample)
  username                       10,000 samples,   24,114 entities (2.4/sample)
  social security number          5,603 samples,   10,965 entities (2.0/sample)
  tax identification number       2,379 samples,    3,130 entities (1.3/sample)
  passport number                   147 samples,      195 entities (1.3/sample)
  driver's license number            70 samples,       71 entities (1.0/sample)
  identification number          10,000 samples,   43,123 entities (4.3/sample)
  phone number                    4,992 samples,    9,919 entities (2.0/sample)
  address                        10,000 samples,   26,443 entities (2.6/sample)
  email address                   8,167 samples,   17,134 entities (2.1/sample)
  ip address                     10,000 samples,   17,172 entities (1.7/sample)
  fax numbe

## 7. Verify a Sample

In [8]:
# Spot-check one saved file: show the first multi-entity sample among the
# first 50 records, with every entity it carries.
test_file = OUTPUT_DIR / "full_name.json"
if test_file.exists():
    with open(test_file) as f:
        test_data = json.load(f)

    # First record (within the first 50) that has more than one entity, if any.
    multi_entity = next(
        (candidate for candidate in test_data[:50] if len(candidate['privacy_mask']) > 1),
        None,
    )
    if multi_entity is not None:
        mask = multi_entity['privacy_mask']
        print("Sample with multiple entities:")
        print("="*60)
        print("Text: {}...".format(multi_entity['source_text'][:200]))
        print("\nEntities ({}):".format(len(mask)))
        for ent in mask:
            print("  - {:25s} '{}' [{}:{}]".format(
                ent['label'], ent['value'][:30], ent['start'], ent['end']
            ))

Sample with multiple entities:
Text: | **Field**                | **Value**                                       |
|---------------------------|-------------------------------------------------|
| **Voter's Full Name**     | Sovanna Chh...

Entities (6):
  - full name                 'Sovanna' [189:196]
  - full name                 'Chhum' [197:202]
  - date                      '1947-03-26' [267:277]
  - address                   'Street 271, Phnom Penh' [346:368]
  - address                   'Kampong Cham' [425:437]
  - address                   'Kampong Cham' [505:517]


## 8. Create Combined File

In [9]:
# Create combined file with all labels.
# Only the per-label files are merged: file names are derived from the 24 labels
# (apostrophes removed, spaces -> underscores, ".json" appended), so any other
# JSON file that happens to sit in OUTPUT_DIR is ignored.
expected_files = {
    label.replace("'", "").replace(" ", "_") + ".json"
    for label in SIMPLIFIED_24_LABELS
}

combined = []
for filepath in sorted(OUTPUT_DIR.glob("*.json")):
    if filepath.name not in expected_files:
        continue
    with open(filepath, encoding="utf-8") as f:
        combined.extend(json.load(f))

# The combined file is written next to OUTPUT_DIR (in its parent), not inside it.
combined_path = OUTPUT_DIR.parent / "combined_training_24labels.json"
with open(combined_path, 'w', encoding="utf-8") as f:
    json.dump(combined, f, indent=2)

size_mb = os.path.getsize(combined_path) / 1024 / 1024
total_entities = sum(len(s['privacy_mask']) for s in combined)

print(f"Combined file: {combined_path.name}")
print(f"  Samples: {len(combined):,}")
print(f"  Entities: {total_entities:,}")
print(f"  Size: {size_mb:.1f} MB")

Combined file: combined_training_24labels.json
  Samples: 125,591
  Entities: 371,475
  Size: 125.7 MB
