In [None]:
import re
from typing import List, Tuple, Set
import pandas as pd, numpy as np, torch, torch.nn as nn
from datasets import Dataset, DatasetDict
from collections import defaultdict
from transformers import (
    AutoTokenizer, AutoConfig,
    Trainer, TrainingArguments,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback
)
# ============================================================================
# GAZETTEER CREATION - Extract from your 5K training data
# ============================================================================

def build_gazetteers(tagged_df: pd.DataFrame, min_len: int = 1) -> dict:
    """Extract known entity surface forms per tag from tagged training data.

    Args:
        tagged_df: Token-level frame with "Tag" and "Token" columns.
        min_len: Minimum length (after lowercasing and stripping) a token must
            have to be kept. The default of 1 drops only empty strings; raise
            it to filter out short, noisy tokens.

    Returns:
        dict mapping every tag except "O" and empty tags to the set of
        lowercased, stripped tokens observed with that tag. Tags appear in
        first-appearance order.
    """
    tags = tagged_df["Tag"]
    # Entity rows only: skip the outside tag "O" and empty (falsy) tags.
    is_entity = tags.astype(bool) & (tags != "O")

    gazetteers = {}
    # A single groupby pass instead of re-filtering the frame once per tag.
    # sort=False keeps tags in first-appearance order, like Series.unique().
    for tag, tokens in tagged_df.loc[is_entity].groupby("Tag", sort=False)["Token"]:
        normalized = tokens.astype(str).str.lower().str.strip()
        gazetteers[tag] = {v for v in normalized if len(v) >= min_len}

    return gazetteers


# Tagged training tokens. keep_default_na=False keeps empty fields as ''
# (not NaN), so the tag filters below can treat them as "no tag".
tagged = pd.read_csv(
    "../data/Tagged_Titles_Train.tsv",
    na_values=None,
    keep_default_na=False,
    sep="\t",
)

# BIO label inventory: "O" gets id 0, followed by the sorted B-/I- labels
# of every entity tag (non-empty, not "O").
BASE = {t for t in tagged["Tag"].unique() if t and t != "O"}
label_list = ["O", *sorted(f"{p}-{t}" for t in BASE for p in ("B", "I"))]
label2id = dict(zip(label_list, range(len(label_list))))
id2label = {idx: name for name, idx in label2id.items()}

print(f"Created {len(label_list)} labels: {label_list[:10]}...")

# Per-tag sets of known surface forms, used by the weak labelers.
gazetteers = build_gazetteers(tagged)
print(f"Gazetteers created for {len(gazetteers)} tags")
for tag, values in gazetteers.items():
    print(f"  {tag}: {len(values)} entries")

Created 59 labels: ['O', 'B-Anwendung', 'B-Anzahl_Der_Einheiten', 'B-Besonderheiten', 'B-Breite', 'B-Bremsscheiben-Aussendurchmesser', 'B-Bremsscheibenart', 'B-Einbauposition', 'B-Farbe', 'B-Größe']...
Gazetteers created for 29 tags
  Kompatible_Fahrzeug_Marke: 173 entries
  Kompatibles_Fahrzeug_Modell: 1902 entries
  Herstellernummer: 1084 entries
  Produktart: 196 entries
  Im_Lieferumfang_Enthalten: 340 entries
  Hersteller: 155 entries
  Modell: 40 entries
  Einbauposition: 41 entries
  Bremsscheiben-Aussendurchmesser: 345 entries
  Bremsscheibenart: 45 entries
  Oe/Oem_Referenznummer(N): 228 entries
  Maßeinheit: 10 entries
  Anzahl_Der_Einheiten: 17 entries
  Kompatibles_Fahrzeug_Jahr: 183 entries
  Produktlinie: 4 entries
  Material: 5 entries
  Größe: 11 entries
  Länge: 3 entries
  Breite: 3 entries
  Besonderheiten: 31 entries
  Menge: 11 entries
  Farbe: 1 entries
  Stärke: 10 entries
  Anwendung: 14 entries
  Oberflächenbeschaffenheit: 2 entries
  SAE_Viskosität: 2 entries


In [3]:
# ============================================================================
# WEAK LABELING - Rule-based NER
# ============================================================================
from tqdm.auto import tqdm
class WeakNERLabeler:
    """Generate weak BIO NER labels from gazetteers and regex rules.

    Tokens are produced with ``str.split()``. Labeling runs in two phases:

    1. Gazetteer matching (exact, case-insensitive).
       - Single-token entries only fill tokens that are still 'O'.
       - Multi-token entries (those containing a space) are written as
         ``B-``/``I-`` spans unconditionally and may overwrite earlier labels.
    2. Regex patterns over the lowercased title. A match labels the token
       that contains the start of the chosen group with ``B-<tag>``, but only
       if that token is still 'O'.
    """

    def __init__(self, gazetteers: dict, category: int):
        """
        Args:
            gazetteers: Mapping tag -> iterable of entity strings. It is
                indexed once here, so later changes to the mapping are not
                seen by this instance.
            category: Stored as ``self.category``; not used by the rules.
        """
        self.gazetteers = gazetteers
        self.category = category

        # Per-tag lookup structures, kept in gazetteer (dict) order:
        # (tag, {single-token forms}, [(lowercased phrase, n_tokens), ...]).
        # label_text then needs one set lookup per token instead of scanning
        # every entity for every title.
        self._gazetteer_index = []
        for tag, entities in gazetteers.items():
            singles = set()
            phrases = []
            for entity in entities:
                entity_lower = entity.lower()
                if ' ' in entity_lower:
                    phrases.append((entity_lower, len(entity_lower.split())))
                else:
                    singles.add(entity_lower)
            self._gazetteer_index.append((tag, singles, phrases))

        # Automotive-specific patterns (German): tag -> [(regex, group index)].
        # The regexes run on the lowercased title.
        self.patterns = {
            'Anzahl_Der_Einheiten': [
                (r'\b(\d+)\s*(stück|stk|x|pcs|piece)\b', 1),
                (r'\b(\d+)er\s+set\b', 1),
            ],
            'Einbauposition': [
                (r'\b(vorne?|hinten?|links?|rechts?|va|ha|vl|vr|hl|hr)\b', 0),
            ],
            'Kompatible_Fahrzeug_Marke': [
                # Will use gazetteer primarily
            ],
            'Kompatibles_Fahrzeug_Modell': [
                # Common model patterns
                (r'\b([a-z]\d+)\b', 0),  # A3, X5, etc.
            ],
        }

    def label_text(self, text: str) -> List[Tuple[str, str]]:
        """
        Returns list of (token, tag) pairs for ``text``.
        Tags are in BIO format ('O', 'B-<tag>', 'I-<tag>').
        """
        tokens = text.split()
        labels = ['O'] * len(tokens)
        text_lower = text.lower()
        tokens_lower = [tok.lower() for tok in tokens]

        # 1. Gazetteer matching, tag by tag in gazetteer order.
        #    Within a tag, single-token hits only fill 'O' slots, while phrase
        #    hits overwrite unconditionally. Applying singles first and phrases
        #    second therefore gives the same result as any entity-by-entity
        #    interleaving.
        for tag, singles, phrases in self._gazetteer_index:
            for i, tok_lower in enumerate(tokens_lower):
                if labels[i] == 'O' and tok_lower in singles:
                    labels[i] = f'B-{tag}'
            for phrase, n_tokens in phrases:
                for i in range(len(tokens) - n_tokens + 1):
                    window = ' '.join(tokens[i:i + n_tokens]).lower()
                    if window == phrase:
                        labels[i] = f'B-{tag}'
                        for j in range(i + 1, i + n_tokens):
                            labels[j] = f'I-{tag}'

        # 2. Pattern matching.
        #    Character span of each token inside text_lower, taken from the
        #    text itself rather than summing len(tok) + 1. This keeps offsets
        #    right when the title has leading or repeated whitespace. \S+ runs
        #    and str.split() agree on token boundaries, and lower() does not
        #    add or remove whitespace, so token_spans[i] belongs to tokens[i].
        token_spans = [m.span() for m in re.finditer(r'\S+', text_lower)][:len(tokens)]
        for tag, patterns in self.patterns.items():
            for pattern, group_idx in patterns:
                for match in re.finditer(pattern, text_lower):
                    start_char = match.start(group_idx)
                    for i, (tok_start, tok_end) in enumerate(token_spans):
                        if tok_start <= start_char < tok_end:
                            if labels[i] == 'O':  # Don't override
                                labels[i] = f'B-{tag}'
                            break

        return list(zip(tokens, labels))

    def label_batch(self, texts: List[str]) -> List[List[Tuple[str, str]]]:
        """Label multiple texts; returns one ``label_text`` result per input."""
        return [self.label_text(text) for text in texts]


# ============================================================================
# APPLY WEAK LABELING TO UNLABELED DATA
# ============================================================================

from multiprocessing import Pool, cpu_count
from functools import partial

def label_single_example(row_data, gazetteers, label2id, valid_tags):
    """Weak-label one ``(idx, title, category)`` row with WeakNERLabeler.

    A module-level function so it can be sent to ``multiprocessing.Pool``.

    Args:
        row_data: Tuple ``(idx, text, category)``; ``idx`` is not used.
        gazetteers: Mapping tag -> set of entity strings for the labeler.
        label2id: Mapping BIO label -> integer id; must contain "O".
        valid_tags: Accepted for call-site compatibility; not used. Label
            validity is checked against ``label2id`` instead.

    Returns:
        ``{"tokens", "ner_tags", "Category"}`` dict. Returns None instead when
        the text is blank, when labeling raised, or when no token received a
        label present in ``label2id``.
    """
    _idx, text, category = row_data

    if not text.strip():
        return None

    labeler = WeakNERLabeler(gazetteers, category)

    try:
        token_label_pairs = labeler.label_text(text)
    except Exception:
        # Best-effort: a failing title is skipped instead of aborting the pool.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit still stop the worker.
        return None

    o_id = label2id['O']
    tokens = [tok for tok, _ in token_label_pairs]
    # Any label missing from label2id falls back to 'O'.
    label_ids = [label2id.get(lab, o_id) for _, lab in token_label_pairs]

    # Skip if no entities (after validation)
    if all(lid == o_id for lid in label_ids):
        return None

    return {
        "tokens": tokens,
        "ner_tags": label_ids,
        "Category": category
    }


def create_weak_ner_dataset_parallel(
    unlabeled_df: pd.DataFrame,
    gazetteers: dict,
    label2id: dict,
    sample_size: int = 150000,
    n_workers: int = None
) -> Dataset:
    """Weak-label a random sample of unlabeled titles in parallel.

    Args:
        unlabeled_df: Frame with a "Title" column and optionally "Category".
            A missing column falls back to "" / 1 respectively.
        gazetteers: Mapping tag -> entity strings, passed to each worker.
        label2id: BIO label -> id mapping used to encode the weak labels.
        sample_size: Maximum number of rows sampled (random_state=42).
        n_workers: Pool size; defaults to ``cpu_count() - 2`` (at least 1).

    Returns:
        ``datasets.Dataset`` with "tokens", "ner_tags" and "Category" columns,
        holding one row per title for which ``label_single_example`` returned
        a result.
    """
    df = unlabeled_df.sample(n=min(sample_size, len(unlabeled_df)), random_state=42)

    # Picklable (idx, title, category) tuples. Column access is much cheaper
    # than DataFrame.iterrows() and keeps the same defaults and conversions.
    titles = df["Title"] if "Title" in df.columns else [""] * len(df)
    categories = df["Category"] if "Category" in df.columns else [1] * len(df)
    row_data = [
        (idx, str(title), int(category))
        for idx, title, category in zip(df.index, titles, categories)
    ]

    if n_workers is None:
        n_workers = max(1, cpu_count() - 2)  # Leave 2 cores free

    print(f"Generating weak labels using {n_workers} workers...")

    label_func = partial(
        label_single_example,
        gazetteers=gazetteers,
        label2id=label2id,
        valid_tags=BASE,
    )

    with Pool(n_workers) as pool:
        results = list(tqdm(
            pool.imap(label_func, row_data, chunksize=100),
            total=len(row_data),
            desc="Weak labeling (parallel)"
        ))

    # Workers return None for skipped titles.
    weak_examples = [r for r in results if r is not None]

    skipped = len(results) - len(weak_examples)
    # Guard against an empty input frame (would otherwise divide by zero).
    coverage = len(weak_examples) / len(df) * 100 if len(df) else 0.0
    print(f"✓ Created {len(weak_examples)} weakly-labeled examples")
    print(f"✗ Skipped {skipped} examples")
    print(f"  Coverage: {coverage:.1f}%")

    return Dataset.from_list(weak_examples)


# Unlabeled listing titles; keep_default_na=False keeps empty fields as ''.
df_unsup = pd.read_csv(
    "../data/Listing_Titles.tsv",
    na_values=None,
    keep_default_na=False,
    sep="\t",
)

# Weak-label a 200k-title sample with the basic labeler.
# NOTE(review): the following cell rebuilds `weak_ds` with the improved
# labeler and overwrites this result.
weak_ds = create_weak_ner_dataset_parallel(
    df_unsup, gazetteers, label2id,
    n_workers=16,  # Adjust based on your CPU
    sample_size=200000,
)



Generating weak labels using 16 workers...


Weak labeling (parallel):   0%|          | 0/200000 [00:00<?, ?it/s]

✓ Created 199229 weakly-labeled examples
✗ Skipped 771 examples
  Coverage: 99.6%


In [4]:
# ============================================================================
# IMPROVED WEAK LABELING - Context-Aware Rules
# ============================================================================
from multiprocessing import Pool, cpu_count
from functools import partial
from tqdm.auto import tqdm
class ImprovedWeakNERLabeler:
    """
    Rule-based weak BIO labeler for German automotive listing titles.

    Compared with the basic labeler it adds:
    1. Context-aware Marke/Hersteller disambiguation ("für X" vs "ORIGINAL X").
    2. More regex patterns (counts, positions, dimensions, model codes).
    3. Post-processing for number + unit pairs and a few context rules.

    ``label_text`` runs four phases in order. Phases 1-3 generally skip tokens
    that already carry a label, so earlier phases take precedence. The
    exception is multi-token gazetteer matches, which overwrite the tokens
    after their first one. Phase 4 may relabel some tokens.

    NOTE(review): the 'Durchmesser' and 'Zähnezahl' pattern tags, and the
    'B-Durchmesser' context rule, do not appear in the tag list printed by
    the first cell (that list may be incomplete). If they are not in the
    caller's label2id, they end up as 'O'. Before that happens, they still
    block later rules: e.g. the 'Breite' "<n> mm breit" pattern runs after
    'Durchmesser' and finds the number already labeled. Confirm whether these
    rules should target existing tags instead.
    """

    def __init__(self, gazetteers: dict, category: int):
        """
        Args:
            gazetteers: Mapping tag -> iterable of entity strings. Entries
                containing a space are matched as multi-token phrases.
            category: Stored as ``self.category``; not used by the rules.
        """
        self.gazetteers = gazetteers
        self.category = category

        # Context indicators for disambiguation.
        # NOTE(review): no method of this class reads context_rules; the
        # same word lists are hard-coded in label_text and
        # _resolve_tag_with_context.
        self.context_rules = {
            'Kompatible_Fahrzeug_Marke': {
                'before': ['für', 'passend', 'geeignet', 'kompatibel'],
                'avoid_before': ['original', 'oem', 'genuine']
            },
            'Hersteller': {
                'before': ['original', 'oem', 'genuine', 'von'],
                'after': ['marke', 'qualität', 'hersteller']
            }
        }

        # Enhanced automotive patterns: tag -> [(regex, group index)].
        # Applied to the lowercased title with re.IGNORECASE, so the [A-Z]
        # classes also match lowercase letters.
        self.patterns = {
            'Anzahl_Der_Einheiten': [
                (r'\b(\d+)\s*(stück|stk|x|pcs|piece|teilig)\b', 1),
                (r'\b(\d+)er[\s-]*(set|satz)\b', 1),
                (r'\b(\d+)[\s-]*teilig\b', 1),
            ],
            'Einbauposition': [
                (r'\b(vorne?|hinten?|links?|rechts?|va|ha|vl|vr|hl|hr)\b', 0),
                (r'\b(vorder|hinter|vorn|hint)\w*\b', 0),
                (r'\b(front|rear|left|right)\b', 0),
            ],
            'Durchmesser': [
                (r'\b(ø|durchmesser|dm\.?|diameter)\s*(\d+)\s*(mm|cm)?\b', 2),
                (r'\b(\d+)\s*mm\b', 1),
            ],
            'Breite': [
                (r'\bbreite\s*(\d+)\s*(mm|cm)?\b', 1),
                (r'\b(\d+)\s*mm\s*breit\b', 1),
            ],
            'Zähnezahl': [
                (r'\b(\d+)\s*zähne?\b', 1),
                (r'\b(\d+)[\s-]*teeth\b', 1),
            ],
            'Kompatibles_Fahrzeug_Modell': [
                # Model patterns: A3, X5, Golf 7, etc.
                (r'\b([A-Z]\d+)\b', 0),  # A3, X5, E46
                (r'\b([A-Z]{1,3}[\s-]\d{1,3})\b', 0),  # E-46, C 220
            ],
        }

        # Common vehicle brands (compared against lowercased tokens)
        self.known_brands = {
            'bmw', 'audi', 'vw', 'volkswagen', 'mercedes', 'benz', 'opel',
            'ford', 'renault', 'peugeot', 'citroen', 'fiat', 'seat', 'skoda',
            'porsche', 'volvo', 'saab', 'toyota', 'nissan', 'honda', 'mazda',
            'hyundai', 'kia', 'chevrolet', 'chrysler', 'jeep', 'land', 'rover',
            'mini', 'alfa', 'romeo', 'lancia', 'subaru', 'suzuki', 'mitsubishi',
            'dacia', 'jaguar'
        }

        # Common part manufacturers (compared against lowercased tokens)
        self.known_manufacturers = {
            'bosch', 'ate', 'brembo', 'zimmermann', 'febi', 'lemförder',
            'sachs', 'bilstein', 'corteco', 'mahle', 'mann', 'hella',
            'valeo', 'continental', 'trw', 'skf', 'fag', 'snr', 'ina',
            'dayco', 'gates', 'contitech', 'meyle', 'optimal', 'ruville',
            'swag', 'topran', 'trucktec', 'vemo', 'pierburg', 'elring'
        }

    def _contextual_tagging(self, tokens, labels):
        """Apply final context rules; mutates and returns ``labels``.

        - "für <known brand>" -> B-Kompatible_Fahrzeug_Marke (only if 'O').
        - "original <X>"      -> B-Hersteller (if 'O' or currently a
          ...Kompatible_Fahrzeug_Marke label). This applies to any X, not
          only known brands.
        - "<digits> mm|cm"    -> B-Durchmesser on the number (only if 'O').
        """
        for i in range(len(tokens)):
            tok_lower = tokens[i].lower()

            # "für X" pattern → X is Marke
            if i > 0 and tokens[i-1].lower() == 'für':
                if tok_lower in self.known_brands and labels[i] == 'O':
                    labels[i] = 'B-Kompatible_Fahrzeug_Marke'

            # "ORIGINAL X" pattern → X is Hersteller
            if i > 0 and tokens[i-1].lower() == 'original':
                if labels[i] == 'O' or labels[i].endswith('Kompatible_Fahrzeug_Marke'):
                    labels[i] = 'B-Hersteller'

            # "X mm" pattern → X is dimension
            if i < len(tokens) - 1 and tokens[i+1].lower() in ['mm', 'cm']:
                if tokens[i].isdigit() and labels[i] == 'O':
                    labels[i] = 'B-Durchmesser'

        return labels

    def label_text(self, text: str) -> List[Tuple[str, str]]:
        """
        Returns list of (token, tag) pairs with context-aware labeling.

        Tokens come from ``text.split()``; tags are 'O', 'B-<tag>' or 'I-<tag>'.
        """
        tokens = text.split()
        labels = ['O'] * len(tokens)
        text_lower = text.lower()

        # PHASE 1: Gazetteer matching with context-based Marke/Hersteller resolution
        for tag, entities in self.gazetteers.items():
            for entity in entities:
                entity_lower = entity.lower()

                # Multi-token entities
                if ' ' in entity_lower:
                    entity_tokens = entity_lower.split()
                    for i in range(len(tokens) - len(entity_tokens) + 1):
                        window = ' '.join(tokens[i:i+len(entity_tokens)]).lower()
                        if window == entity_lower:
                            # Check context for disambiguation
                            resolved_tag = self._resolve_tag_with_context(
                                tag, i, tokens, entity_lower
                            )

                            # Only the first token is checked; the I- tokens are overwritten.
                            if labels[i] == 'O':
                                labels[i] = f'B-{resolved_tag}'
                                for j in range(i+1, i+len(entity_tokens)):
                                    labels[j] = f'I-{resolved_tag}'
                else:
                    # Single token
                    for i, tok in enumerate(tokens):
                        if tok.lower() == entity_lower and labels[i] == 'O':
                            resolved_tag = self._resolve_tag_with_context(
                                tag, i, tokens, entity_lower
                            )
                            labels[i] = f'B-{resolved_tag}'

        # PHASE 2: Brand/Manufacturer heuristics for still-unlabeled tokens
        for i, tok in enumerate(tokens):
            tok_lower = tok.lower().strip('.,;:')

            if labels[i] == 'O':  # Only if not already labeled
                # Check if it's a known brand
                if tok_lower in self.known_brands:
                    # Context check: für BMW = Marke, ORIGINAL BMW = Hersteller
                    if i > 0:
                        prev = tokens[i-1].lower()
                        if prev in ['für', 'passend', 'geeignet', 'kompatibel']:
                            labels[i] = 'B-Kompatible_Fahrzeug_Marke'
                        elif prev in ['original', 'oem', 'genuine']:
                            labels[i] = 'B-Hersteller'
                        else:
                            # Default: assume vehicle brand
                            labels[i] = 'B-Kompatible_Fahrzeug_Marke'
                    else:
                        labels[i] = 'B-Kompatible_Fahrzeug_Marke'

                # Check if it's a known manufacturer
                elif tok_lower in self.known_manufacturers:
                    labels[i] = 'B-Hersteller'

        # PHASE 3: Pattern matching.
        # Character span of each token inside text_lower, taken from the text
        # itself rather than summing len(tok) + 1. This keeps offsets right
        # when the title has leading or repeated whitespace. \S+ runs and
        # str.split() agree on token boundaries, and lower() does not add or
        # remove whitespace, so token_spans[i] belongs to tokens[i].
        token_spans = [m.span() for m in re.finditer(r'\S+', text_lower)][:len(tokens)]
        for tag, patterns in self.patterns.items():
            for pattern, group_idx in patterns:
                for match in re.finditer(pattern, text_lower, re.IGNORECASE):
                    start_char = match.start(group_idx)
                    for i, (tok_start, tok_end) in enumerate(token_spans):
                        if tok_start <= start_char < tok_end:
                            if labels[i] == 'O':  # Don't override
                                labels[i] = f'B-{tag}'
                            break

        # PHASE 4: Number+unit continuation, then final context rules
        labels = self._fix_compound_entities(tokens, labels)
        labels = self._contextual_tagging(tokens, labels)
        return list(zip(tokens, labels))

    def _resolve_tag_with_context(
        self,
        original_tag: str,
        position: int,
        tokens: List[str],
        entity: str
    ) -> str:
        """
        Disambiguate Marke vs Hersteller for known brand names using neighbours.

        Returns ``original_tag`` unchanged unless it is one of those two tags
        AND ``entity`` is in ``known_brands``. In that case, the neighbouring
        tokens decide; with no decisive context it falls back to
        'Kompatible_Fahrzeug_Marke'.
        """
        # Only disambiguate between Marke and Hersteller
        if original_tag not in ['Kompatible_Fahrzeug_Marke', 'Hersteller']:
            return original_tag

        # Check if entity is a known brand
        if entity.lower() not in self.known_brands:
            return original_tag

        # Look at previous token
        if position > 0:
            prev = tokens[position - 1].lower().strip('.,;:')

            # "für BMW" → Marke
            if prev in ['für', 'passend', 'geeignet', 'kompatibel', 'fits', 'fit']:
                return 'Kompatible_Fahrzeug_Marke'

            # "ORIGINAL BMW" → Hersteller
            if prev in ['original', 'oem', 'genuine', 'von']:
                return 'Hersteller'

        # Look at next token
        if position < len(tokens) - 1:
            next_tok = tokens[position + 1].lower().strip('.,;:')

            # "BMW Qualität" → Hersteller
            if next_tok in ['qualität', 'original', 'teil', 'hersteller']:
                return 'Hersteller'

        # Default when no decisive context: Marke
        return 'Kompatible_Fahrzeug_Marke'

    def _fix_compound_entities(
        self,
        tokens: List[str],
        labels: List[str]
    ) -> List[str]:
        """
        Extend a labeled number onto a following unit token.

        E.g. "259" (B-X) followed by "mm" ('O') makes "mm" I-X. Returns a new
        list; ``labels`` itself is not modified.
        """
        fixed_labels = labels.copy()

        for i in range(len(tokens) - 1):
            current = tokens[i]
            next_tok = tokens[i + 1].lower()

            # If current is a number and next is a unit
            if current.isdigit() and next_tok in ['mm', 'cm', 'm', 'kg', 'g', 'stk', 'x']:
                # If current has a tag and next doesn't
                if labels[i].startswith('B-') and labels[i+1] == 'O':
                    tag = labels[i][2:]  # Remove B- prefix
                    fixed_labels[i+1] = f'I-{tag}'

        return fixed_labels


def label_single_example(row_data, gazetteers, label2id, known_brands, known_manufacturers):
    """Weak-label one ``(idx, title, category)`` row with ImprovedWeakNERLabeler.

    A module-level function so it can be sent to ``multiprocessing.Pool``.
    NOTE(review): this redefines the earlier function of the same name, which
    had a ``valid_tags`` parameter instead of the two brand lists.

    Args:
        row_data: Tuple ``(idx, text, category)``; ``idx`` is not used.
        gazetteers: Mapping tag -> set of entity strings for the labeler.
        label2id: Mapping BIO label -> integer id; must contain "O".
        known_brands: Extra lowercase vehicle-brand words. They are ADDED to
            the labeler's built-in ``known_brands``; may be empty or None.
        known_manufacturers: Extra lowercase part-manufacturer words, ADDED
            to the labeler's built-in ``known_manufacturers``; may be empty
            or None.

    Returns:
        ``{"tokens", "ner_tags", "Category"}`` dict. Returns None instead when
        the text is blank, when labeling raised, or when no token received a
        label present in ``label2id``.
    """
    _idx, text, category = row_data

    if not text.strip():
        return None

    labeler = ImprovedWeakNERLabeler(gazetteers, category)
    # Merge rather than replace, so the labeler's larger built-in lists are
    # never shrunk by a shorter caller-supplied list.
    if known_brands:
        labeler.known_brands = labeler.known_brands | set(known_brands)
    if known_manufacturers:
        labeler.known_manufacturers = labeler.known_manufacturers | set(known_manufacturers)

    try:
        token_label_pairs = labeler.label_text(text)
    except Exception:
        # Best-effort: a failing title is skipped instead of aborting the pool.
        return None

    o_id = label2id['O']
    tokens = [tok for tok, _ in token_label_pairs]
    # Labels missing from label2id (e.g. rule-only tags) fall back to 'O'.
    label_ids = [label2id.get(lab, o_id) for _, lab in token_label_pairs]

    # Skip if no entities (after validation)
    if all(lid == o_id for lid in label_ids):
        return None

    return {
        "tokens": tokens,
        "ner_tags": label_ids,
        "Category": category
    }

def create_weak_ner_dataset_parallel(
    unlabeled_df: pd.DataFrame,
    gazetteers: dict,
    label2id: dict,
    sample_size: int = 150000,
    n_workers: int = None
) -> Dataset:
    """Weak-label a random sample of unlabeled titles in parallel.

    NOTE(review): this redefines the function of the same name from the
    previous cell. The per-row worker is whichever ``label_single_example``
    is defined at call time; the one defined just above uses
    ImprovedWeakNERLabeler.

    Args:
        unlabeled_df: Frame with a "Title" column and optionally "Category".
            A missing column falls back to "" / 1 respectively.
        gazetteers: Mapping tag -> entity strings, passed to each worker.
        label2id: BIO label -> id mapping used to encode the weak labels.
        sample_size: Maximum number of rows sampled (random_state=42).
        n_workers: Pool size; defaults to ``cpu_count() - 2`` (at least 1).

    Returns:
        ``datasets.Dataset`` with "tokens", "ner_tags" and "Category" columns,
        holding one row per title for which the worker returned a result.
    """
    df = unlabeled_df.sample(n=min(sample_size, len(unlabeled_df)), random_state=42)

    # Picklable (idx, title, category) tuples. Column access is much cheaper
    # than DataFrame.iterrows() and keeps the same defaults and conversions.
    titles = df["Title"] if "Title" in df.columns else [""] * len(df)
    categories = df["Category"] if "Category" in df.columns else [1] * len(df)
    row_data = [
        (idx, str(title), int(category))
        for idx, title, category in zip(df.index, titles, categories)
    ]

    if n_workers is None:
        n_workers = max(1, cpu_count() - 2)  # Leave 2 cores free

    print(f"Generating weak labels using {n_workers} workers...")

    # NOTE(review): both sets are subsets of ImprovedWeakNERLabeler's
    # built-in lists; their effect depends on label_single_example.
    known_brands = {
        'bmw', 'audi', 'vw', 'volkswagen', 'mercedes', 'benz', 'opel',
        'ford', 'renault', 'peugeot', 'citroen', 'fiat', 'seat', 'skoda'
    }

    known_manufacturers = {
        'bosch', 'ate', 'brembo', 'zimmermann', 'febi', 'lemförder',
        'sachs', 'bilstein', 'corteco', 'mahle', 'mann', 'hella'
    }

    label_func = partial(
        label_single_example,
        gazetteers=gazetteers,
        label2id=label2id,
        known_brands=known_brands,
        known_manufacturers=known_manufacturers
    )

    with Pool(n_workers) as pool:
        results = list(tqdm(
            pool.imap(label_func, row_data, chunksize=100),
            total=len(row_data),
            desc="Weak labeling (improved)"
        ))

    # Workers return None for skipped titles.
    weak_examples = [r for r in results if r is not None]

    skipped = len(results) - len(weak_examples)
    # Guard against an empty input frame (would otherwise divide by zero).
    coverage = len(weak_examples) / len(df) * 100 if len(df) else 0.0
    print(f"✓ Created {len(weak_examples)} weakly-labeled examples")
    print(f"✗ Skipped {skipped} examples")
    print(f"  Coverage: {coverage:.1f}%")

    return Dataset.from_list(weak_examples)


# Re-read the unlabeled titles so this cell does not depend on the previous
# one; keep_default_na=False keeps empty fields as ''.
df_unsup = pd.read_csv(
    "../data/Listing_Titles.tsv",
    na_values=None,
    keep_default_na=False,
    sep="\t",
)

# Rebuild `weak_ds` (replacing the basic labeler's output) from a
# 200k-title sample.
weak_ds = create_weak_ner_dataset_parallel(
    df_unsup, gazetteers, label2id,
    n_workers=16,  # Adjust based on your CPU
    sample_size=200000,
)

Generating weak labels using 16 workers...


Weak labeling (improved):   0%|          | 0/200000 [00:00<?, ?it/s]

✓ Created 199284 weakly-labeled examples
✗ Skipped 716 examples
  Coverage: 99.6%


In [5]:
# Hold out 5% of the weakly-labeled titles for validation (seeded for reproducibility).
weak_splits = weak_ds.train_test_split(seed=42, test_size=0.05)
weak_splits

DatasetDict({
    train: Dataset({
        features: ['tokens', 'ner_tags', 'Category'],
        num_rows: 189319
    })
    test: Dataset({
        features: ['tokens', 'ner_tags', 'Category'],
        num_rows: 9965
    })
})

In [6]:
# Name the two splits of the DatasetDict: "test" serves as the weak validation set.
weak_train, weak_val = weak_splits["train"], weak_splits["test"]

print(f"Weak train: {len(weak_train)}, Weak val: {len(weak_val)}")

Weak train: 189319, Weak val: 9965


In [7]:
# ============================================================================
# WEAK NER PRE-TRAINING (Replaces MLM Step 1)
# ============================================================================

from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedModel, AutoModel

# Backbone checkpoint; a fast tokenizer is required for word_ids() label alignment.
base_model = "microsoft/deberta-v3-large"
tok = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=base_model, use_fast=True)


def tok_fn_weak(batch, tokenizer=None, max_length: int = 256):
    """Tokenize pre-split titles and align word-level NER ids to sub-tokens.

    Meant for ``Dataset.map(..., batched=True)``. Only the first sub-token of
    each word keeps that word's label id. Continuation sub-tokens and special
    tokens (word id None) get -100, which CrossEntropyLoss ignores by default.

    Args:
        batch: Batched columns "tokens" (list of word lists), "ner_tags"
            (one label id per word) and "Category".
        tokenizer: Fast tokenizer exposing ``word_ids``; defaults to the
            module-level ``tok``.
        max_length: Truncation length in sub-tokens.

    Returns:
        The tokenizer's encoding with added "labels" and "category_id" keys.
    """
    tokenizer = tok if tokenizer is None else tokenizer
    enc = tokenizer(
        batch["tokens"],
        is_split_into_words=True,
        truncation=True,
        max_length=max_length,
    )

    all_labels = []
    for i in range(len(enc["input_ids"])):
        gold = batch["ner_tags"][i]
        seq = []
        prev = None
        for wid in enc.word_ids(batch_index=i):
            if wid is not None and wid != prev:
                # First sub-token of a word carries the word's label.
                seq.append(gold[wid])
                prev = wid
            else:
                # Special tokens and word continuations are ignored by the loss.
                seq.append(-100)
        all_labels.append(seq)

    enc["labels"] = all_labels
    enc["category_id"] = batch["Category"]
    return enc


# Tokenize + label-align both weak splits. The raw word/tag columns are
# dropped because tok_fn_weak produces input_ids / labels / category_id.
tok_weak_train = weak_train.map(
    tok_fn_weak,
    batched=True,
    num_proc=16,
    remove_columns=["tokens", "ner_tags", "Category"],
)

tok_weak_val = weak_val.map(
    tok_fn_weak,
    batched=True,
    num_proc=16,
    remove_columns=["tokens", "ner_tags", "Category"],
)


# ============================================================================
# SIMPLE TOKEN CLASSIFICATION MODEL (No CRF yet)
# ============================================================================

class WeakNERModel(PreTrainedModel):
    """Encoder plus a linear token-classification head for weak pre-training.

    There is no CRF: each token is classified independently with
    cross-entropy. Labels of -100 are ignored by the loss (the
    CrossEntropyLoss default).
    """
    # NOTE(review): AutoConfig is a factory class, not a PretrainedConfig
    # subclass. Direct construction works, but from_pretrained() on this class
    # may not resolve the config properly — confirm before relying on it.
    config_class = AutoConfig

    def __init__(self, config, num_labels=None, base_model_name=None):
        """
        Args:
            config: Backbone config. ``hidden_size`` sizes the head;
                ``num_labels`` is used when the argument below is omitted.
            num_labels: Number of output labels; defaults to ``config.num_labels``.
            base_model_name: If given, load pretrained encoder weights from
                this name/path; otherwise build the encoder from ``config``
                with fresh weights.
        """
        super().__init__(config)
        # Previously a None default reached nn.Linear and crashed; fall back
        # to the label count stored on the config instead.
        if num_labels is None:
            num_labels = config.num_labels
        self.num_labels = num_labels

        # Load base encoder
        if base_model_name:
            self.encoder = AutoModel.from_pretrained(base_model_name)
        else:
            self.encoder = AutoModel.from_config(config)

        # Simple classifier (no CRF for weak training)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

        # Initialize classifier only; the encoder keeps its own initialisation
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Return ``{"loss", "logits"}``; loss is None when ``labels`` is None.

        Extra batch keys (e.g. "category_id" from the tokenization step) land
        in ``**kwargs`` and are not used.
        """
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        logits = self.classifier(self.dropout(sequence_output))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Flatten (batch, seq) so every token is one classification sample.
            loss = loss_fct(
                logits.view(-1, self.num_labels),
                labels.view(-1)
            )

        return {"loss": loss, "logits": logits}


# Backbone config carries the label count; the encoder weights themselves
# are loaded from base_model inside WeakNERModel.
cfg = AutoConfig.from_pretrained(base_model, num_labels=len(label_list))

weak_model = WeakNERModel(
    cfg,
    base_model_name=base_model,
    num_labels=len(label_list),
).to("cuda")


# Dynamic padding per batch (rounded up to a multiple of 8); "labels" are
# padded too so they line up with input_ids.
# NOTE(review): with padding=True (pad to longest) max_length is most likely
# ignored by the collator — confirm before treating it as a length cap.
collator = DataCollatorForTokenClassification(
    tokenizer=tok,
    pad_to_multiple_of=8,
    max_length=256,
    padding=True,
)


# ============================================================================
# WEAK TRAINING ARGUMENTS (Conservative)
# ============================================================================

weak_args = TrainingArguments(
    output_dir="../models/deberta-improved-weak-ner-mk-2",
    seed=42,
    report_to="none",

    # Schedule: a few epochs are enough for weak labels; the LR sits between
    # typical fine-tuning and from-scratch values.
    num_train_epochs=3,
    learning_rate=3e-5,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    optim="adamw_torch_fused",

    # Batching / precision: effective train batch = 32 x 4 accumulation steps per device.
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=48,
    bf16=True,
    fp16=False,
    gradient_checkpointing=False,

    # Data loading
    dataloader_num_workers=16,
    dataloader_pin_memory=True,

    # Eval / checkpointing: save and eval on the same 500-step grid so the
    # lowest eval_loss checkpoint can be restored at the end.
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    logging_steps=100,
)


weak_trainer = Trainer(
    model=weak_model,
    args=weak_args,
    train_dataset=tok_weak_train,
    eval_dataset=tok_weak_val,
    data_collator=collator,
    processing_class=tok,
)

# Train on weak labels
print("Starting weak NER pre-training...")
weak_trainer.train()

# Save the encoder only (the weak classifier head is discarded) and the
# tokenizer. weak_args.output_dir is used so the directory is defined in one
# place. With load_best_model_at_end=True, weak_model should hold the
# lowest-eval_loss checkpoint's weights at this point.
weak_model.encoder.save_pretrained(weak_args.output_dir)
tok.save_pretrained(weak_args.output_dir)
print("Weak NER encoder saved!")

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/580 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



Map (num_proc=16):   0%|          | 0/189319 [00:00<?, ? examples/s]

Map (num_proc=16):   0%|          | 0/9965 [00:00<?, ? examples/s]

pytorch_model.bin:   0%|          | 0.00/874M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/874M [00:00<?, ?B/s]

Starting weak NER pre-training...




Step,Training Loss,Validation Loss
500,0.2628,0.053387
1000,0.1091,0.024563
1500,0.0716,0.015878
2000,0.0464,0.011402
2500,0.0392,0.008833
3000,0.026,0.007253
3500,0.0206,0.007079
4000,0.0181,0.005948




Weak NER encoder saved!
