# Create Training Data by Label

This notebook extracts examples from existing datasets and organizes them by our 23 canonical labels.

**Goal**: Create 23 JSON files, one per label, each containing diverse examples with subtypes.

**Datasets used** (commercially safe):
- nvidia/Nemotron-PII (100k samples) - CC BY 4.0
- gretel-pii-masking-en-v1 (5k samples) - Apache 2.0
- gretel-finance-multilingual - Apache 2.0

In [1]:
import json
import os
import ast
from collections import defaultdict, Counter
import random

# Paths
# Both locations can be overridden with environment variables so the notebook
# can run on a machine other than the original author's. When the variables
# are unset, the original absolute paths are used unchanged.
DATA_DIR = os.environ.get(
    "PII_DATA_DIR",
    "/Users/sravan/Documents/Experiments/fintuning_PII/Data",
)
OUTPUT_DIR = os.environ.get(
    "PII_OUTPUT_DIR",
    "/Users/sravan/Documents/Experiments/fintuning_PII/Training_data",
)

# Create the output directory (and any missing parents) if it does not exist yet
os.makedirs(OUTPUT_DIR, exist_ok=True)

## 1. EXPLORE: Labels in Each Dataset

**RUN THIS SECTION FIRST** to see all unique labels in each dataset before creating mappings.

This helps you:
1. See what labels exist in each dataset
2. Find labels that need mapping
3. Update the LABEL_MAPPING dictionary accordingly

In [2]:
# ============================================================
# EXPLORE NVIDIA NEMOTRON-PII LABELS
# ============================================================
# Tallies every span label found in the NVIDIA file and keeps up to two
# truncated example strings per label for the table printed below.
print("=" * 70)
print("NVIDIA NEMOTRON-PII LABELS")
print("=" * 70)

nvidia_path = f"{DATA_DIR}/nvidia-nemotron-pii/test.json"
# Read as UTF-8 explicitly instead of relying on the platform default encoding.
with open(nvidia_path, 'r', encoding='utf-8') as f:
    nvidia_data = json.load(f)

nvidia_labels = Counter()            # label -> number of spans carrying it
nvidia_examples = defaultdict(list)  # label -> up to two example strings

for item in nvidia_data:
    text = item.get('text', '')
    spans = item.get('spans', [])
    if isinstance(spans, str):
        # Spans stored as a Python-literal string are parsed here; records
        # whose string cannot be parsed are skipped.
        try:
            spans = ast.literal_eval(spans)
        except Exception:
            # Deliberately not a bare `except:` so that interrupting the
            # kernel still stops this loop.
            continue

    for span in spans:
        if not isinstance(span, dict):
            continue
        label = span.get('label', '')
        nvidia_labels[label] += 1
        # Keep at most two examples per label, each cut to 50 characters
        if len(nvidia_examples[label]) < 2:
            start, end = span.get('start', 0), span.get('end', 0)
            entity_text = text[start:end] if start < end else ''
            nvidia_examples[label].append(entity_text[:50])

print(f"\nTotal unique labels: {len(nvidia_labels)}")
print(f"Total samples: {len(nvidia_data)}")
print("\nLabel                                Count    Examples")
print("-" * 70)
for label, count in nvidia_labels.most_common():
    examples = nvidia_examples[label][:2]
    # The joined examples are cut to 40 characters to keep the table narrow
    ex_str = ' | '.join(examples)[:40]
    print(f"{label:35s} {count:6d}    {ex_str}")

NVIDIA NEMOTRON-PII LABELS

Total unique labels: 55
Total samples: 100000

Label                                Count    Examples
----------------------------------------------------------------------
first_name                           84043    Brian | Sabrina
date                                 73867    20300615 | 20300615
last_name                            59596    King | Garcia
company_name                         54837    TransactFlow | Mercy Health Systems
email                                53930    garciah@outlook.com | ismaeljgiacchetto@
url                                  37847    https://secure.bankofamerica.com/legal-d
occupation                           36888    secondary school teacher | Access Contro
time                                 24506    00:33 | 15h45
phone_number                         23930    931-613-1082 | 805-427-4731
country                              23475    USA | USA
customer_id                          20502    CUS498372 | 4873259160
city     

In [3]:
# ============================================================
# EXPLORE GRETEL PII-MASKING LABELS
# ============================================================
# Tallies every entity type found in the Gretel file and keeps up to two
# truncated example strings per type for the table printed below.
print("=" * 70)
print("GRETEL PII-MASKING-EN-V1 LABELS")
print("=" * 70)

gretel_path = f"{DATA_DIR}/gretel-pii-masking-en-v1/test.json"
# Read as UTF-8 explicitly instead of relying on the platform default encoding.
with open(gretel_path, 'r', encoding='utf-8') as f:
    gretel_data = json.load(f)

gretel_labels = Counter()            # type -> number of entities carrying it
gretel_examples = defaultdict(list)  # type -> up to two example strings

for item in gretel_data:
    entities_raw = item.get('entities', '[]')
    if isinstance(entities_raw, str):
        # Entities stored as a Python-literal string are parsed here; records
        # whose string cannot be parsed are skipped.
        try:
            entities_raw = ast.literal_eval(entities_raw)
        except Exception:
            # Deliberately not a bare `except:` so that interrupting the
            # kernel still stops this loop.
            continue

    for ent in entities_raw:
        # Same guard as the NVIDIA cell: ignore anything that is not a dict
        if not isinstance(ent, dict):
            continue
        entity_text = ent.get('entity', '')
        types = ent.get('types', [])
        # One entity may carry several types; each one is counted
        for t in types:
            gretel_labels[t] += 1
            if len(gretel_examples[t]) < 2:
                gretel_examples[t].append(entity_text[:50])

print(f"\nTotal unique labels: {len(gretel_labels)}")
print(f"Total samples: {len(gretel_data)}")
print("\nLabel                                Count    Examples")
print("-" * 70)
for label, count in gretel_labels.most_common():
    examples = gretel_examples[label][:2]
    # The joined examples are cut to 40 characters to keep the table narrow
    ex_str = ' | '.join(examples)[:40]
    print(f"{label:35s} {count:6d}    {ex_str}")

GRETEL PII-MASKING-EN-V1 LABELS

Total unique labels: 42
Total samples: 5000

Label                                Count    Examples
----------------------------------------------------------------------
medical_record_number                 2658    MRN-293104 | MED25315002
date_of_birth                         2331    1960-11-14 | 1975-04-21
ssn                                   1661    433-42-5929 | ZZ736903T
first_name                            1172    Ekanta | Louise
date                                  1157    1989.12.22 | 1997-01-06
last_name                             1057    Purohit | Tripathi
email                                 1049    veronicawood@example.org | xwilliams@exa
customer_id                           1033    CID-996335 | D-870175-E
employee_id                           1005    EMP730359 | B5890579
name                                   980    Zaitra Sarma | Heather Johnson
phone_number                           904    +1-869-341-9301x7005 | 280.900.0632x9032


In [4]:
# ============================================================
# EXPLORE GRETEL FINANCE MULTILINGUAL LABELS
# ============================================================
# Tallies every entity type found in the Gretel finance file and keeps up to
# two truncated example strings per type for the table printed below.
print("=" * 70)
print("GRETEL FINANCE MULTILINGUAL LABELS")
print("=" * 70)

gretel_fin_path = f"{DATA_DIR}/gretel-finance-multilingual/test.json"
# Read as UTF-8 explicitly: the dataset is multilingual, so the platform
# default encoding must not be relied on.
with open(gretel_fin_path, 'r', encoding='utf-8') as f:
    gretel_fin_data = json.load(f)

gretel_fin_labels = Counter()            # type -> number of entities carrying it
gretel_fin_examples = defaultdict(list)  # type -> up to two example strings

for item in gretel_fin_data:
    entities_raw = item.get('entities', '[]')
    if isinstance(entities_raw, str):
        # Entities stored as a Python-literal string are parsed here; records
        # whose string cannot be parsed are skipped.
        try:
            entities_raw = ast.literal_eval(entities_raw)
        except Exception:
            # Deliberately not a bare `except:` so that interrupting the
            # kernel still stops this loop.
            continue

    for ent in entities_raw:
        # Same guard as the NVIDIA cell: ignore anything that is not a dict
        if not isinstance(ent, dict):
            continue
        entity_text = ent.get('entity', '')
        types = ent.get('types', [])
        for t in types:
            gretel_fin_labels[t] += 1
            if len(gretel_fin_examples[t]) < 2:
                gretel_fin_examples[t].append(entity_text[:50])

print(f"\nTotal unique labels: {len(gretel_fin_labels)}")
print(f"Total samples: {len(gretel_fin_data)}")
print("\nLabel                                Count    Examples")
print("-" * 70)
for label, count in gretel_fin_labels.most_common():
    examples = gretel_fin_examples[label][:2]
    # The joined examples are cut to 40 characters to keep the table narrow
    ex_str = ' | '.join(examples)[:40]
    print(f"{label:35s} {count:6d}    {ex_str}")

# An empty table for a non-empty file means the 'entities'/'types' fields read
# above are not where this dataset keeps its annotations. Show the fields that
# are actually present so the right ones can be picked.
# NOTE(review): presumably this dataset uses a different schema than
# gretel-pii-masking-en-v1 -- confirm against the dataset card.
if gretel_fin_data and not gretel_fin_labels:
    first_record = gretel_fin_data[0]
    if isinstance(first_record, dict):
        available_fields = sorted(first_record)
    else:
        available_fields = type(first_record).__name__
    print("\nWARNING: no labels were found under 'entities' / 'types'.")
    print(f"Fields present in the first record: {available_fields}")

GRETEL FINANCE MULTILINGUAL LABELS

Total unique labels: 0
Total samples: 5594

Label                                Count    Examples
----------------------------------------------------------------------


In [5]:
# ============================================================
# SUMMARY: ALL UNIQUE LABELS ACROSS DATASETS
# ============================================================
# Prints one row per source label with its count in each of the three
# datasets (0 where the label does not occur).
banner = "=" * 70
rule = "-" * 70

print(banner)
print("SUMMARY: ALL UNIQUE LABELS (for mapping)")
print(banner)

# Union of the label names seen by the three exploration cells
all_source_labels = set().union(nvidia_labels, gretel_labels, gretel_fin_labels)
print(f"\nTotal unique labels across all datasets: {len(all_source_labels)}")

print("\n" + rule)
print("Label                               NVIDIA   GRETEL   GRETEL-FIN")
print(rule)

per_dataset_counts = (nvidia_labels, gretel_labels, gretel_fin_labels)
for source_label in sorted(all_source_labels):
    nv_count, gr_count, gf_count = (
        counts.get(source_label, 0) for counts in per_dataset_counts
    )
    print(f"{source_label:35s} {nv_count:6d}   {gr_count:6d}   {gf_count:6d}")

SUMMARY: ALL UNIQUE LABELS (for mapping)

Total unique labels across all datasets: 59

----------------------------------------------------------------------
Label                               NVIDIA   GRETEL   GRETEL-FIN
----------------------------------------------------------------------
account_number                       16693      141        0
address                                  0      563        0
age                                   7167        0        0
api_key                               4667       60        0
bank_routing_number                   8354      257        0
biometric_identifier                 11379      137        0
blood_type                            5539        0        0
certificate_license_number            3002      124        0
city                                 18347      128        0
company_name                         54837      185        0
coordinate                            7659       85        0
country                            

---

## 2. Define Label Mapping

**UPDATE THE MAPPING BELOW** based on the labels you discovered above.

Map dataset-specific labels → (canonical_label, subtype)

In [None]:
# Source labels grouped under the canonical label they collapse into:
#   canonical_label -> {source_label: subtype}
# To add a mapping, put the source label (lowercase, underscores) in the
# group of the canonical label it belongs to.
_SOURCES_BY_CANONICAL = {
    "full name": {
        "first_name": "person_name",
        "last_name": "person_name",
        "name": "person_name",
        "person": "person_name",
        "patient_name": "patient_name",
        "doctor_name": "doctor_name",
    },
    "date": {
        "date": "general_date",
        "date_of_birth": "date_of_birth",
        "date_time": "datetime",
        "time": "time",
        "expiration_date": "expiration_date",
    },
    "address": {
        "address": "general_address",
        "street_address": "street_address",
        "city": "city",
        "state": "state",
        "country": "country",
        "postcode": "postal_code",
        "coordinate": "coordinates",
    },
    "phone number": {
        "phone_number": "general_phone",
        "mobile_phone_number": "mobile_phone",
    },
    "fax number": {
        "fax_number": "fax",
    },
    "email address": {
        "email": "email",
        "email_address": "email",
    },
    "social security number": {
        "ssn": "ssn",
        "social_security_number": "ssn",
    },
    "credit card number": {
        "credit_card_number": "credit_card",
        "credit_debit_card": "credit_debit_card",
    },
    "bank account number": {
        "account_number": "account_number",
        "bank_routing_number": "routing_number",
    },
    "amount": {
        "amount": "general_amount",
        "transaction_amount": "transaction_amount",
        "salary": "salary",
    },
    "credit score": {
        "credit_score": "credit_score",
    },
    "iban": {
        "iban": "iban",
        "swift_bic": "swift_code",
    },
    "tax identification number": {
        "tax_id": "tax_id",
        "tax_identification_number": "tax_id",
    },
    "driver's license number": {
        "driver_license": "drivers_license",
        "drivers_license_number": "drivers_license",
    },
    "passport number": {
        "passport_number": "passport",
        "passport": "passport",
    },
    "identification number": {
        "national_id": "national_id",
        "identity_card_number": "identity_card",
        "employee_id": "employee_id",
        "customer_id": "customer_id",
        "unique_identifier": "unique_id",
        "certificate_license_number": "certificate_number",
        "medical_record_number": "medical_record_number",
    },
    "insurance number": {
        "health_insurance_id_number": "health_insurance",
        "health_plan_beneficiary_number": "health_plan",
        "insurance_number": "insurance",
    },
    "ip address": {
        "ip_address": "ip",
        "ipv4": "ipv4",
        "ipv6": "ipv6",
    },
    "username": {
        "username": "username",
        "user_name": "username",
        "user_id": "user_id",
    },
    "organization": {
        "organization": "organization",
        "company_name": "company",
        "hospital": "hospital",
    },
    "medical condition": {
        "medical_condition": "condition",
        "diagnosis": "diagnosis",
        "disease": "disease",
    },
    "medical treatment": {
        "medical_treatment": "treatment",
        "procedure": "procedure",
        "therapy": "therapy",
    },
    "medication": {
        "medication": "medication",
        "drug_name": "drug",
        "prescription": "prescription",
    },
}

# Flat lookup used by the rest of the notebook:
#   source_label -> (canonical_label, subtype)
LABEL_MAPPING = {
    source_label: (canonical_label, subtype)
    for canonical_label, sources in _SOURCES_BY_CANONICAL.items()
    for source_label, subtype in sources.items()
}

# Our 23 canonical labels, in reporting order
CANONICAL_LABELS = [
    "full name",
    "date",
    "address",
    "phone number",
    "fax number",
    "email address",
    "social security number",
    "credit card number",
    "bank account number",
    "amount",
    "credit score",
    "iban",
    "tax identification number",
    "driver's license number",
    "passport number",
    "identification number",
    "insurance number",
    "ip address",
    "username",
    "organization",
    "medical condition",
    "medical treatment",
    "medication",
]

print(f"Defined {len(LABEL_MAPPING)} source label mappings")
print(f"Target: {len(CANONICAL_LABELS)} canonical labels")

In [None]:
# ============================================================
# CHECK: UNMAPPED LABELS
# Run this after defining LABEL_MAPPING to find missing mappings
# ============================================================
print("=" * 70)
print("UNMAPPED LABELS (labels in datasets but NOT in LABEL_MAPPING)")
print("=" * 70)

unmapped = []
for label in all_source_labels:
    # Normalize exactly the way the loader functions do (lowercase, spaces to
    # underscores). Any extra normalization here would report a label as
    # mapped even though the loaders would fail to look it up and drop it.
    normalized = label.lower().replace(' ', '_')
    if normalized not in LABEL_MAPPING:
        total = sum(
            counts.get(label, 0)
            for counts in (nvidia_labels, gretel_labels, gretel_fin_labels)
        )
        unmapped.append((label, total))

# Highest count first; ties are broken by label name so the listing does not
# depend on set iteration order.
unmapped.sort(key=lambda pair: (-pair[1], pair[0]))

if unmapped:
    print(f"\nFound {len(unmapped)} unmapped labels:")
    print("\nLabel                               Total Count   Action Needed")
    print("-" * 70)
    for label, count in unmapped:
        print(f"{label:35s} {count:6d}        # TODO: add mapping")
else:
    print("\n✓ All labels are mapped!")

print("\n" + "=" * 70)
print("To add a mapping, use format:")
print('  "source_label": ("canonical_label", "subtype"),')
print("=" * 70)

## 2b. Load Datasets

In [None]:
def load_nvidia_data():
    """Load the nvidia/Nemotron-PII file and keep the spans that have a mapping.

    Reads ``{DATA_DIR}/nvidia-nemotron-pii/test.json``. Each record's ``spans``
    may be a list or a Python-literal string; span labels are lowercased and
    spaces replaced with underscores before being looked up in LABEL_MAPPING.
    Spans without a mapping, and records left with no mapped span, are omitted.

    Returns:
        list[dict]: one ``{'text', 'entities', 'source'}`` dict per kept
        record, with ``source`` set to ``'nvidia-nemotron'``. Every entity has
        the keys ``text``, ``label`` (canonical), ``subtype``,
        ``original_label``, ``start`` and ``end``.
    """
    path = f"{DATA_DIR}/nvidia-nemotron-pii/test.json"
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    samples = []
    unparsed = 0  # records whose 'spans' string could not be parsed
    for item in data:
        text = item.get('text', '')
        spans = item.get('spans', [])

        # Parse spans if string
        if isinstance(spans, str):
            try:
                spans = ast.literal_eval(spans)
            except Exception:
                # Not a bare `except:` so a kernel interrupt still stops the
                # load; the record is skipped and counted instead.
                unparsed += 1
                continue

        entities = []
        for span in spans:
            if not isinstance(span, dict):
                continue

            label = span.get('label', '').lower().replace(' ', '_')
            if label not in LABEL_MAPPING:
                continue

            start = span.get('start', 0)
            end = span.get('end', 0)
            # Prefer slicing by offsets; fall back to the span's own 'text'
            # field when the offsets do not describe a non-empty range.
            entity_text = text[start:end] if start < end else span.get('text', '')

            canonical, subtype = LABEL_MAPPING[label]
            entities.append({
                'text': entity_text,
                'label': canonical,
                'subtype': subtype,
                'original_label': label,
                'start': start,
                'end': end
            })

        if entities:
            samples.append({
                'text': text,
                'entities': entities,
                'source': 'nvidia-nemotron'
            })

    print(f"Loaded {len(samples)} samples from nvidia-nemotron")
    if unparsed:
        print(f"Skipped {unparsed} records whose 'spans' could not be parsed")
    return samples

nvidia_samples = load_nvidia_data()

In [None]:
def load_gretel_data():
    """Load the gretel-pii-masking file and keep the entities that have a mapping.

    Reads ``{DATA_DIR}/gretel-pii-masking-en-v1/test.json``. Each record's
    ``entities`` may be a list or a Python-literal string. An entity yields one
    output entry per mapped type in its ``types`` list; type names are
    lowercased and spaces replaced with underscores before the LABEL_MAPPING
    lookup. Records left with no mapped entity are omitted.

    Offsets are recovered with ``str.find``, so they point at the FIRST
    occurrence of the entity text. An entity whose text does not occur in the
    record text is kept with ``start == -1`` and ``end == 0``.

    Returns:
        list[dict]: one ``{'text', 'entities', 'source'}`` dict per kept
        record, with ``source`` set to ``'gretel'``. Every entity has the keys
        ``text``, ``label`` (canonical), ``subtype``, ``original_label``,
        ``start`` and ``end``.
    """
    path = f"{DATA_DIR}/gretel-pii-masking-en-v1/test.json"
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    samples = []
    unparsed = 0  # records whose 'entities' string could not be parsed
    for item in data:
        text = item.get('text', '')
        entities_raw = item.get('entities', '[]')

        # Parse entities if string
        if isinstance(entities_raw, str):
            try:
                entities_raw = ast.literal_eval(entities_raw)
            except Exception:
                # Not a bare `except:` so a kernel interrupt still stops the
                # load; the record is skipped and counted instead.
                unparsed += 1
                continue

        entities = []
        for ent in entities_raw:
            # Same guard the NVIDIA loader applies to its spans
            if not isinstance(ent, dict):
                continue
            entity_text = ent.get('entity', '')
            # `or []` also covers an explicit null in the data
            types = ent.get('types') or []

            for t in types:
                label = t.lower().replace(' ', '_')
                if label not in LABEL_MAPPING:
                    continue
                canonical, subtype = LABEL_MAPPING[label]

                # Find position in text (first occurrence; -1 / 0 when absent)
                start = text.find(entity_text)
                end = start + len(entity_text) if start >= 0 else 0

                entities.append({
                    'text': entity_text,
                    'label': canonical,
                    'subtype': subtype,
                    'original_label': label,
                    'start': start,
                    'end': end
                })

        if entities:
            samples.append({
                'text': text,
                'entities': entities,
                'source': 'gretel'
            })

    print(f"Loaded {len(samples)} samples from gretel")
    if unparsed:
        print(f"Skipped {unparsed} records whose 'entities' could not be parsed")
    return samples

gretel_samples = load_gretel_data()

In [None]:
def load_gretel_finance():
    """Load the gretel-finance-multilingual file and keep the mapped entities.

    Reads ``{DATA_DIR}/gretel-finance-multilingual/test.json``. Each record's
    ``entities`` may be a list or a Python-literal string. An entity yields one
    output entry per mapped type in its ``types`` list; type names are
    lowercased and spaces replaced with underscores before the LABEL_MAPPING
    lookup. Records left with no mapped entity are omitted.

    Offsets are recovered with ``str.find``, so they point at the FIRST
    occurrence of the entity text. An entity whose text does not occur in the
    record text is kept with ``start == -1`` and ``end == 0``.

    Returns:
        list[dict]: one ``{'text', 'entities', 'source'}`` dict per kept
        record, with ``source`` set to ``'gretel-finance'``. Every entity has
        the keys ``text``, ``label`` (canonical), ``subtype``,
        ``original_label``, ``start`` and ``end``.
    """
    path = f"{DATA_DIR}/gretel-finance-multilingual/test.json"
    # Read as UTF-8 explicitly: the dataset is multilingual, so the platform
    # default encoding must not be relied on.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    samples = []
    unparsed = 0  # records whose 'entities' string could not be parsed
    for item in data:
        text = item.get('text', '')
        entities_raw = item.get('entities', '[]')

        if isinstance(entities_raw, str):
            try:
                entities_raw = ast.literal_eval(entities_raw)
            except Exception:
                # Not a bare `except:` so a kernel interrupt still stops the
                # load; the record is skipped and counted instead.
                unparsed += 1
                continue

        entities = []
        for ent in entities_raw:
            # Same guard the NVIDIA loader applies to its spans
            if not isinstance(ent, dict):
                continue
            entity_text = ent.get('entity', '')
            # `or []` also covers an explicit null in the data
            types = ent.get('types') or []

            for t in types:
                label = t.lower().replace(' ', '_')
                if label not in LABEL_MAPPING:
                    continue
                canonical, subtype = LABEL_MAPPING[label]

                # First occurrence of the entity text; -1 / 0 when absent
                start = text.find(entity_text)
                end = start + len(entity_text) if start >= 0 else 0

                entities.append({
                    'text': entity_text,
                    'label': canonical,
                    'subtype': subtype,
                    'original_label': label,
                    'start': start,
                    'end': end
                })

        if entities:
            samples.append({
                'text': text,
                'entities': entities,
                'source': 'gretel-finance'
            })

    print(f"Loaded {len(samples)} samples from gretel-finance")
    if unparsed:
        print(f"Skipped {unparsed} records whose 'entities' could not be parsed")
    if data and not samples:
        # Nothing usable came out of a non-empty file: show the fields that
        # are actually present so the reader can see what to read instead.
        # NOTE(review): presumably this dataset does not use the
        # 'text' / 'entities' / 'types' layout -- confirm against its card.
        first_record = data[0]
        if isinstance(first_record, dict):
            available_fields = sorted(first_record)
        else:
            available_fields = type(first_record).__name__
        print("WARNING: no mapped entities found under 'entities' / 'types'.")
        print(f"Fields present in the first record: {available_fields}")
    return samples

gretel_finance_samples = load_gretel_finance()

In [None]:
# Concatenate the three loaded sample lists into one (NVIDIA first, then
# Gretel, then Gretel finance); the input lists are left untouched.
all_samples = [
    *nvidia_samples,
    *gretel_samples,
    *gretel_finance_samples,
]
print(f"\nTotal samples: {len(all_samples)}")

## 3. Organize by Label

In [None]:
# Index every entity under its canonical label. A record with several
# entities contributes one row per entity, each carrying the full text.
samples_by_label = defaultdict(list)

for record in all_samples:
    for ent in record['entities']:
        samples_by_label[ent['label']].append(
            dict(text=record['text'], entity=ent, source=record['source'])
        )

# Report how many rows each canonical label received
print("Samples per label:")
print("=" * 50)
for canonical in CANONICAL_LABELS:
    print(f"{len(samples_by_label[canonical]):6d}  {canonical}")

In [None]:
# For every canonical label that has rows, list its subtypes from most to
# least frequent; labels without rows are omitted from the listing.
print("\nSubtype distribution per label:")
print("=" * 60)

for canonical in CANONICAL_LABELS:
    subtype_counts = Counter(
        row['entity']['subtype'] for row in samples_by_label[canonical]
    )
    if not subtype_counts:
        continue

    print(f"\n{canonical}:")
    for subtype_name, occurrences in subtype_counts.most_common():
        print(f"  {occurrences:5d}  {subtype_name}")

## 4. Create JSON Files per Label

In [None]:
def create_label_file(label, samples, max_per_subtype=200):
    """Build the sample list for one label, capped per subtype.

    Rows are bucketed by ``row['entity']['subtype']``; each bucket is shuffled
    with the module-level ``random`` state and at most ``max_per_subtype`` rows
    are taken from it. Every selected row becomes a dict holding the full
    text, a single-entity list and the source name.

    Returns:
        tuple: ``(selected_rows, buckets)`` where ``buckets`` maps each subtype
        to ALL of its rows (shuffled, not capped).
    """
    # Bucket the rows by subtype, keeping first-seen order of the subtypes
    buckets = {}
    for row in samples:
        buckets.setdefault(row['entity']['subtype'], []).append(row)

    # Shuffle each bucket and take the first `max_per_subtype` rows from it
    chosen = []
    for subtype_name, bucket in buckets.items():
        random.shuffle(bucket)
        chosen.extend(
            {
                'text': row['text'],
                'entities': [{
                    'text': row['entity']['text'],
                    'label': label,
                    'subtype': subtype_name,
                    'start': row['entity']['start'],
                    'end': row['entity']['end']
                }],
                'source': row['source']
            }
            for row in bucket[:max_per_subtype]
        )

    return chosen, buckets

# Write one JSON file per canonical label that has samples
random.seed(42)  # fixed seed so the per-subtype sampling is repeatable

for canonical in CANONICAL_LABELS:
    label_rows = samples_by_label[canonical]

    # Labels without any rows only get a warning, no file
    if not label_rows:
        print(f"WARNING: No samples for '{canonical}'")
        continue

    picked, buckets = create_label_file(canonical, label_rows)

    # File name: spaces become underscores, apostrophes are dropped
    out_name = canonical.replace(' ', '_').replace("'", "") + '.json'
    out_path = os.path.join(OUTPUT_DIR, out_name)

    # 'subtypes' counts every row available per subtype (before the
    # per-subtype cap), so the counts can add up to more than 'total_samples'.
    payload = {
        'label': canonical,
        'total_samples': len(picked),
        'subtypes': {name: len(rows) for name, rows in buckets.items()},
        'samples': picked
    }

    with open(out_path, 'w') as fh:
        json.dump(payload, fh, indent=2)

    print(f"Created {out_name}: {len(picked)} samples, {len(buckets)} subtypes")

## 5. Summary Statistics

In [None]:
# Load and summarize all created per-label files
print("\nFinal Dataset Summary")
print("=" * 70)

total_samples = 0
for filename in sorted(os.listdir(OUTPUT_DIR)):
    if not filename.endswith('.json') or filename == 'label_subtypes_schema.json':
        continue
    # Files starting with "_" are aggregates, not per-label files. In
    # particular `_combined_training.json` (written by the next section) holds
    # a JSON list, so on a second run of the notebook `data['label']` below
    # would raise TypeError if it were not skipped here.
    if filename.startswith('_'):
        continue

    filepath = os.path.join(OUTPUT_DIR, filename)
    with open(filepath, 'r') as f:
        data = json.load(f)

    # Ignore any other JSON file that does not have the per-label layout
    if not isinstance(data, dict) or not {'label', 'total_samples', 'subtypes'} <= data.keys():
        continue

    label = data['label']
    count = data['total_samples']
    subtypes = data['subtypes']
    total_samples += count

    print(f"{label:30s} {count:5d} samples  ({len(subtypes)} subtypes)")

print("=" * 70)
print(f"{'TOTAL':30s} {total_samples:5d} samples")

## 6. Create Combined Training File

In [None]:
# Combine all per-label files into one shuffled training file
all_training_samples = []

for filename in sorted(os.listdir(OUTPUT_DIR)):
    if not filename.endswith('.json') or filename == 'label_subtypes_schema.json':
        continue
    # Files starting with "_" are aggregates, not per-label files. This skips
    # the `_combined_training.json` written below by a previous run: it holds
    # a JSON list, so `data['samples']` would raise TypeError on a re-run.
    if filename.startswith('_'):
        continue

    filepath = os.path.join(OUTPUT_DIR, filename)
    with open(filepath, 'r') as f:
        data = json.load(f)

    # Ignore any other JSON file that does not have the per-label layout
    if not isinstance(data, dict) or 'samples' not in data:
        continue

    all_training_samples.extend(data['samples'])

# Shuffle in place (uses the module-level random state)
random.shuffle(all_training_samples)

# Save combined file; the leading underscore marks it as an aggregate
combined_path = os.path.join(OUTPUT_DIR, '_combined_training.json')
with open(combined_path, 'w') as f:
    json.dump(all_training_samples, f, indent=2)

print(f"Created combined training file: {len(all_training_samples)} samples")
print(f"Saved to: {combined_path}")

In [None]:
# Preview the first three combined samples: the first 200 characters of the
# text (always followed by "..."), the entity list and the source name.
print("\nExample samples:")
print("=" * 70)

for position, example in enumerate(all_training_samples[:3], start=1):
    print(f"\nSample {position}:")
    preview = example['text'][:200]
    print(f"Text: {preview}...")
    print(f"Entities: {example['entities']}")
    print(f"Source: {example['source']}")