# 🧠 CSF Browser Training Notebook

**Contrastive Semantic Features (CSF) Extractor for Sign Language Generation**

## Key points
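- **35 Condition Labels** (expanded from 4): 34 conditions in 8 categories (Weather, Time, Health, Schedule, Mood, Social, Activity, Financial) plus `NONE`
- **18,885 Samples** (up from 6,293), split 90/10 into 16,996 training and 1,889 validation samples
- **4 Languages**: English, Vietnamese, Japanese, French
- **Custom byte-level BPE tokenizer** (8K vocabulary budget) for browser deployment
- **~5.9M-parameter model** (~23 MB of fp32 weights) for real-time inference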

---

## Architecture
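- **Encoder**: 4-layer pre-norm Transformer (hidden size 256, 4 attention heads, FFN size 1024, inputs padded/truncated to 64 tokens)
- **Tokenizer**: Custom byte-level BPE (`vocab_size=8000` target; 3,954 tokens were actually learned on this corpus)
- **Output**: 9 classification heads on the `[CLS]` representation, 73 classes in total (event, intent, time, condition, agent, object, location, purpose, modifier)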
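The nine heads are independent: each one scores only its own label set, and the final frame is read off with a per-slot argmax. A minimal sketch of that decoding step, using a hypothetical `decode_csf` helper and plain Python lists in place of logit tensors (the label lists are a subset of the notebook's `LABELS`):

```python
from typing import Dict, List

# Subset of the notebook's label inventory (index order matters: it defines class ids)
LABELS: Dict[str, List[str]] = {
    "event": ["GO", "STAY", "BUY", "WORK", "MEET", "EAT", "LEARN"],
    "agent": ["ME", "YOU", "HE", "SHE", "THEY"],
    "location": ["NONE", "HOME", "SCHOOL", "HOSPITAL", "OFFICE", "STORE"],
}

def decode_csf(logits: Dict[str, List[float]], labels: Dict[str, List[str]]) -> Dict[str, str]:
    """Hypothetical helper: pick the highest-scoring class for each slot independently."""
    frame = {}
    for slot, scores in logits.items():
        if len(scores) != len(labels[slot]):
            raise ValueError(f"{slot}: expected {len(labels[slot])} logits, got {len(scores)}")
        best = max(range(len(scores)), key=lambda i: scores[i])
        frame[slot] = labels[slot][best]
    return frame

example_logits = {
    "event": [0.1, 0.2, 0.0, 0.3, 2.5, 0.1, 0.0],  # index 4 -> MEET
    "agent": [0.0, 0.1, 0.2, 1.9, 0.3],            # index 3 -> SHE
    "location": [0.2, 0.1, 0.0, 0.3, 3.1, 0.4],    # index 4 -> OFFICE
}
print(decode_csf(example_logits, LABELS))
# {'event': 'MEET', 'agent': 'SHE', 'location': 'OFFICE'}
```

Because the heads do not share a softmax, the decoded frame can combine labels that never co-occur in the training data; any cross-slot consistency checks would have to happen after decoding.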

## 1️⃣ Setup & Installation

In [1]:
# Install required packages
!pip install -q torch transformers tokenizers onnx onnxruntime scikit-learn tqdm gdown

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

print("✅ Setup complete!")


Mounted at /content/drive
✅ Setup complete!


## 2️⃣ Configuration

In [2]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# ============================================================
# CONFIGURATION
# ============================================================

# Paths (the dataset is downloaded from Google Drive on first run)
DATA_PATH = "/content/csf_train_19k.jsonl"
GDRIVE_FILE_ID = "1OWILS9T9kybftmSmI1sopoy9P2oglIgc"

# Download the dataset if it is not already present
if not os.path.exists(DATA_PATH):
    import gdown
    print("📥 Downloading dataset from Google Drive...")
    gdown.download(id=GDRIVE_FILE_ID, output=DATA_PATH, quiet=False)
    print(f"✅ Downloaded to {DATA_PATH}")
OUTPUT_DIR = "/content/drive/MyDrive/csf_browser_v4"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model hyperparameters
CONFIG = {
    "vocab_size": 8000,
    "hidden_size": 256,
    "num_attention_heads": 4,
    "num_hidden_layers": 4,
    "intermediate_size": 1024,
    "max_position_embeddings": 128,
    "dropout": 0.1,
    "max_length": 64,
}

# Training hyperparameters
TRAIN_CONFIG = {
    "batch_size": 64,
    "learning_rate": 2e-4,
    "epochs": 15,
    "warmup_ratio": 0.1,
    "weight_decay": 0.01,
    "seed": 42,
}

# ============================================================
# LABELS - 35 CONDITIONS!
# ============================================================

LABELS = {
    "event": ["GO", "STAY", "BUY", "WORK", "MEET", "EAT", "LEARN"],  # 7
    "intent": ["NONE", "PLAN", "WANT", "DECIDE"],  # 4
    "time": ["NONE", "TODAY", "TOMORROW", "YESTERDAY", "NOW"],  # 5
    "condition": [
        "NONE",
        # Weather (5)
        "IF_RAIN", "IF_SUNNY", "IF_COLD", "IF_HOT", "IF_WINDY",
        # Time (5)
        "IF_LATE", "IF_EARLY", "IF_WEEKEND", "IF_NIGHT", "IF_MORNING",
        # Health (5)
        "IF_SICK", "IF_TIRED", "IF_HUNGRY", "IF_THIRSTY", "IF_FULL",
        # Schedule (4)
        "IF_BUSY", "IF_FREE", "IF_HOLIDAY", "IF_WORKING",
        # Mood (5)
        "IF_BORED", "IF_HAPPY", "IF_SAD", "IF_STRESSED", "IF_ANGRY",
        # Social (3)
        "IF_ALONE", "IF_WITH_FRIENDS", "IF_WITH_FAMILY",
        # Activity (5)
        "IF_FINISH_WORK", "IF_FINISH_SCHOOL", "IF_FINISH_EATING", "IF_WATCH_MOVIE", "IF_LISTEN_MUSIC",
        # Financial (2)
        "IF_HAVE_MONEY", "IF_NO_MONEY",
    ],  # 35 total
    "agent": ["ME", "YOU", "HE", "SHE", "THEY"],  # 5
    "object": ["NONE", "FOOD", "BOOK", "MEDICINE", "THING"],  # 5
    "location": ["NONE", "HOME", "SCHOOL", "HOSPITAL", "OFFICE", "STORE"],  # 6
    "purpose": ["NONE", "REST"],  # 2
    "modifier": ["NONE", "FAST", "SLOW", "ALONE"]  # 4
}

SLOT_NAMES = list(LABELS.keys())
NUM_CLASSES = {slot: len(labels) for slot, labels in LABELS.items()}
LABEL_TO_ID = {slot: {label: i for i, label in enumerate(labels)} for slot, labels in LABELS.items()}
ID_TO_LABEL = {slot: {i: label for i, label in enumerate(labels)} for slot, labels in LABELS.items()}

# Set seeds
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(TRAIN_CONFIG["seed"])

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n🖥️  Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

print(f"\n📊 Labels Summary:")
for slot, labels in LABELS.items():
    print(f"   {slot:12s}: {len(labels)} classes")
print(f"\n   Total output classes: {sum(NUM_CLASSES.values())}")

📥 Downloading dataset from Google Drive...
Downloading...
From: https://drive.google.com/uc?id=1OWILS9T9kybftmSmI1sopoy9P2oglIgc
To: /content/csf_train_19k.jsonl
100% 4.43M/4.43M [00:00<00:00, 18.7MB/s]
✅ Downloaded to /content/csf_train_19k.jsonl

🖥️  Device: cuda
   GPU: NVIDIA A100-SXM4-40GB

📊 Labels Summary:
   event       : 7 classes
   intent      : 4 classes
   time        : 5 classes
   condition   : 35 classes
   agent       : 5 classes
   object      : 5 classes
   location    : 6 classes
   purpose     : 2 classes
   modifier    : 4 classes

   Total output classes: 73


## 3️⃣ Load & Prepare Data

In [3]:
# ============================================================
# LOAD DATA
# ============================================================

print("📂 Loading data...")

# Download from Google Drive if the file is missing (e.g. when this cell is run on its own)
if not os.path.exists(DATA_PATH):
    print("📥 Downloading dataset from Google Drive...")
    import gdown
    gdown.download(id=GDRIVE_FILE_ID, output=DATA_PATH, quiet=False)

# One JSON object per line; skip blank lines (e.g. a trailing newline)
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

print(f"✅ Loaded {len(data)} samples")

# Show sample
print(f"\n📝 Sample entry:")
print(f"   Text: {data[0]['text']}")
print(f"   CSF:  {data[0]['csf']}")

# Statistics
from collections import Counter
conditions = Counter(s["csf"]["condition"] for s in data)
events = Counter(s["csf"]["event"] for s in data)

print(f"\n📊 Condition distribution (top 10):")
for cond, count in conditions.most_common(10):
    print(f"   {cond:20s}: {count:5d} ({100*count/len(data):.1f}%)")

print(f"\n📊 Event distribution:")
for event, count in events.most_common():
    print(f"   {event:10s}: {count:5d} ({100*count/len(data):.1f}%)")


📂 Loading data...
✅ Loaded 18885 samples

📝 Sample entry:
   Text: Cô ấy gặp văn phòng.
   CSF:  {'event': 'MEET', 'intent': 'NONE', 'time': 'NONE', 'condition': 'NONE', 'agent': 'SHE', 'object': 'NONE', 'location': 'OFFICE', 'purpose': 'NONE', 'modifier': 'NONE'}

📊 Condition distribution (top 10):
   NONE                :  4260 (22.6%)
   IF_RAIN             :   995 (5.3%)
   IF_SICK             :   699 (3.7%)
   IF_BUSY             :   658 (3.5%)
   IF_SUNNY            :   576 (3.1%)
   IF_WEEKEND          :   573 (3.0%)
   IF_FREE             :   528 (2.8%)
   IF_NIGHT            :   495 (2.6%)
   IF_TIRED            :   495 (2.6%)
   IF_HUNGRY           :   483 (2.6%)

📊 Event distribution:
   STAY      :  8223 (43.5%)
   GO        :  3409 (18.1%)
   EAT       :  3165 (16.8%)
   MEET      :  1231 (6.5%)
   WORK      :  1197 (6.3%)
   BUY       :  1187 (6.3%)
   LEARN     :   473 (2.5%)
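Both distributions are skewed (`STAY` covers 43.5% of events, `LEARN` only 2.5%; `NONE` covers 22.6% of conditions). The notebook trains with an unweighted `nn.CrossEntropyLoss`. If rare classes turn out to be under-predicted, one option is inverse-frequency ("balanced") class weights, w_c = N / (K · n_c). A minimal sketch with a hypothetical `balanced_class_weights` helper, fed with the event counts printed above:

```python
from collections import Counter
from typing import Dict, List

def balanced_class_weights(counts: Counter, classes: List[str]) -> Dict[str, float]:
    """Hypothetical helper: w_c = N / (K * n_c).

    Classes absent from the data are treated as n_c = 1 to avoid division by zero.
    """
    total = sum(counts[c] for c in classes)
    k = len(classes)
    return {c: total / (k * max(counts[c], 1)) for c in classes}

# Class order must match LABELS["event"] so weights line up with class ids
event_classes = ["GO", "STAY", "BUY", "WORK", "MEET", "EAT", "LEARN"]
event_counts = Counter({"STAY": 8223, "GO": 3409, "EAT": 3165, "MEET": 1231,
                        "WORK": 1197, "BUY": 1187, "LEARN": 473})

weights = balanced_class_weights(event_counts, event_classes)
for cls in event_classes:
    print(f"{cls:6s} {weights[cls]:.3f}")
```

In the notebook, the list `[weights[c] for c in LABELS["event"]]` could be wrapped in a float tensor and passed as `weight=` to a per-slot `nn.CrossEntropyLoss`; whether this helps here is untested.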


## 4️⃣ Train Custom BPE Tokenizer

In [4]:
# ============================================================
# TRAIN CUSTOM BPE TOKENIZER
# ============================================================

from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders

print("🔤 Training custom BPE tokenizer...")

# Extract all texts
texts = [s["text"] for s in data]

# Save texts temporarily for training
with open("/content/train_texts.txt", "w", encoding="utf-8") as f:
    for text in texts:
        f.write(text + "\n")

# Create BPE tokenizer
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))

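# Byte-level pre-tokenizer: maps raw UTF-8 bytes to printable symbols, then splits
# on whitespace and punctuation with a GPT-2-style regex
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)

# Trainer: seeding the vocabulary with all 256 byte symbols keeps characters that
# never appear in the training corpus from falling back to [UNK] at inference time
trainer = trainers.BpeTrainer(
    vocab_size=CONFIG["vocab_size"],
    special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
    min_frequency=2,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    show_progress=True,
)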

# Train
tokenizer.train(files=["/content/train_texts.txt"], trainer=trainer)

# Post-processor: add [CLS] and [SEP]
tokenizer.post_processor = processors.TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B [SEP]",
    special_tokens=[
        ("[CLS]", tokenizer.token_to_id("[CLS]")),
        ("[SEP]", tokenizer.token_to_id("[SEP]")),
    ],
)

# Decoder
tokenizer.decoder = decoders.ByteLevel()

# Enable padding
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]", length=CONFIG["max_length"])
tokenizer.enable_truncation(max_length=CONFIG["max_length"])

# Save tokenizer
TOKENIZER_PATH = f"{OUTPUT_DIR}/tokenizer.json"
tokenizer.save(TOKENIZER_PATH)

print(f"\n✅ Tokenizer trained!")
print(f"   Vocab size: {tokenizer.get_vocab_size()}")
print(f"   Saved to: {TOKENIZER_PATH}")

# Test tokenizer
test_texts = [
    "I go to school tomorrow.",
    "Nếu mưa thì tôi ở nhà.",
    "明日、学校に行く。",
    "Je travaille à l'hôpital.",
]

print(f"\n📝 Tokenization examples:")
for text in test_texts:
    enc = tokenizer.encode(text)
    # Padding is enabled above, so len(enc.ids) always equals max_length (64);
    # sum(enc.attention_mask) gives the number of real (non-[PAD]) tokens.
    print(f"   {text[:35]:35s} → {len(enc.ids)} tokens")

🔤 Training custom BPE tokenizer...

✅ Tokenizer trained!
   Vocab size: 3954
   Saved to: /content/drive/MyDrive/csf_browser_v4/tokenizer.json

📝 Tokenization examples:
   I go to school tomorrow.            → 64 tokens
   Nếu mưa thì tôi ở nhà.              → 64 tokens
   明日、学校に行く。                           → 64 tokens
   Je travaille à l'hôpital.           → 64 tokens
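Every example reports 64 tokens because the tokenizer pads to `max_length`, so these counts say nothing about sequence length. A more informative check is how long the real (unpadded) sequences are and how many sit at the 64-token limit, where truncation may have cut text. A minimal sketch with hypothetical helpers `real_length` and `truncation_report`, using toy masks in place of `tokenizer.encode(text).attention_mask`:

```python
from typing import List, Sequence

def real_length(attention_mask: Sequence[int]) -> int:
    """Number of non-padding positions (includes [CLS] and [SEP])."""
    return sum(attention_mask)

def truncation_report(masks: List[Sequence[int]], max_length: int) -> dict:
    """Summarize real sequence lengths; sequences at the limit may have been truncated."""
    lengths = [real_length(m) for m in masks]
    return {
        "max": max(lengths),
        "mean": sum(lengths) / len(lengths),
        "at_limit": sum(1 for n in lengths if n >= max_length),
    }

# Toy masks standing in for encodings padded to max_length=8
masks = [
    [1, 1, 1, 1, 1, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 0, 0, 0, 0, 0],
]
report = truncation_report(masks, max_length=8)
print(report)
```

On the real data this would be called as `truncation_report([tokenizer.encode(t).attention_mask for t in texts], CONFIG["max_length"])`; a non-trivial `at_limit` count would suggest raising `max_length`.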


## 5️⃣ Dataset & DataLoader

In [5]:
# ============================================================
# DATASET CLASS
# ============================================================

class CSFDataset(Dataset):
    def __init__(self, samples, tokenizer, label_to_id, max_length=64):
        self.samples = samples
        self.tokenizer = tokenizer
        self.label_to_id = label_to_id
        self.max_length = max_length
        self.slot_names = list(label_to_id.keys())

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = sample["text"]
        csf = sample["csf"]

        # Tokenize
        enc = self.tokenizer.encode(text)
        input_ids = enc.ids[:self.max_length]
        attention_mask = enc.attention_mask[:self.max_length]

        # Pad if needed
        pad_len = self.max_length - len(input_ids)
        if pad_len > 0:
            input_ids = input_ids + [0] * pad_len
            attention_mask = attention_mask + [0] * pad_len

        # Labels
        labels = {}
        for slot in self.slot_names:
            val = csf.get(slot, "NONE")
            if val is None:
                val = "NONE"
            # Unknown labels fall back to the slot's first class
            # (NONE for most slots, but GO for event and ME for agent)
            if val not in self.label_to_id[slot]:
                val = list(self.label_to_id[slot].keys())[0]
            labels[slot] = self.label_to_id[slot][val]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": {slot: torch.tensor(label, dtype=torch.long) for slot, label in labels.items()}
        }

# ============================================================
# CREATE DATALOADERS
# ============================================================

# Split data
train_data, val_data = train_test_split(data, test_size=0.1, random_state=TRAIN_CONFIG["seed"])

print(f"üìä Data split:")
print(f"   Train: {len(train_data)} samples")
print(f"   Val:   {len(val_data)} samples")

# Create datasets
train_dataset = CSFDataset(train_data, tokenizer, LABEL_TO_ID, CONFIG["max_length"])
val_dataset = CSFDataset(val_data, tokenizer, LABEL_TO_ID, CONFIG["max_length"])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=TRAIN_CONFIG["batch_size"], shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=TRAIN_CONFIG["batch_size"], shuffle=False, num_workers=2)

print(f"\n✅ DataLoaders created!")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")

📊 Data split:
   Train: 16996 samples
   Val:   1889 samples

✅ DataLoaders created!
   Train batches: 266
   Val batches: 30
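The split above is purely random. With 35 condition labels, some of them rare, a random 10% hold-out can leave a rare condition with very few validation examples, which makes its accuracy noisy. A minimal sketch of a per-label (stratified) split, assuming the notebook's sample format (`{"text": ..., "csf": {...}}`); `stratified_split` is a hypothetical helper:

```python
import random
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

def stratified_split(samples: List[dict], key: str, test_frac: float,
                     seed: int) -> Tuple[List[dict], List[dict]]:
    """Hold out about test_frac of each csf[key] group (hypothetical helper)."""
    groups: Dict[str, List[dict]] = defaultdict(list)
    for s in samples:
        groups[s["csf"][key]].append(s)
    rng = random.Random(seed)
    train, val = [], []
    for label in sorted(groups):           # fixed order -> reproducible split
        items = groups[label][:]
        rng.shuffle(items)
        n_val = round(len(items) * test_frac)
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    return train, val

toy = ([{"text": f"s{i}", "csf": {"condition": "NONE"}} for i in range(20)]
       + [{"text": f"r{i}", "csf": {"condition": "IF_RAIN"}} for i in range(10)])
train, val = stratified_split(toy, "condition", test_frac=0.1, seed=42)
print(len(train), len(val))
print(Counter(s["csf"]["condition"] for s in val))
```

scikit-learn offers the same idea in one line via `train_test_split(data, test_size=0.1, stratify=[s["csf"]["condition"] for s in data], random_state=42)`, which raises an error if any class has fewer than two samples.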


## 6️⃣ Model Architecture

In [6]:
# ============================================================
# MODEL ARCHITECTURE
# ============================================================

class CSFClassificationHead(nn.Module):
    """Classification head for each slot."""
    def __init__(self, hidden_size, num_classes, dropout=0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.dropout(x)
        x = torch.tanh(self.dense(x))
        x = self.dropout(x)
        return self.classifier(x)


class CSFExtractor(nn.Module):
    """CSF Extractor with Transformer encoder."""
    def __init__(self, config, num_classes):
        super().__init__()
        self.hidden_size = config["hidden_size"]

        # Embeddings
        self.word_embeddings = nn.Embedding(
            config["vocab_size"],
            config["hidden_size"],
            padding_idx=0
        )
        self.position_embeddings = nn.Embedding(
            config["max_position_embeddings"],
            config["hidden_size"]
        )
        self.layer_norm = nn.LayerNorm(config["hidden_size"], eps=1e-6)
        self.dropout = nn.Dropout(config["dropout"])

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config["hidden_size"],
            nhead=config["num_attention_heads"],
            dim_feedforward=config["intermediate_size"],
            dropout=config["dropout"],
            activation="gelu",
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config["num_hidden_layers"]
        )

        # Classification heads
        self.classification_heads = nn.ModuleDict({
            slot: CSFClassificationHead(
                config["hidden_size"],
                num_classes[slot],
                config["dropout"]
            ) for slot in num_classes.keys()
        })

    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape

        # Position IDs
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)

        # Embeddings
        embeddings = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        embeddings = self.dropout(self.layer_norm(embeddings))

        # Attention mask for transformer (True = ignore)
        mask = (attention_mask == 0) if attention_mask is not None else None

        # Encode
        encoder_output = self.encoder(embeddings, src_key_padding_mask=mask)

        # Use [CLS] token (index 0)
        cls_output = encoder_output[:, 0, :]

        # Classify each slot
        return {slot: head(cls_output) for slot, head in self.classification_heads.items()}


# ============================================================
# CREATE MODEL
# ============================================================

model = CSFExtractor(CONFIG, NUM_CLASSES).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n🧠 Model created!")
print(f"   Total parameters:     {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Estimated size:       {total_params * 4 / 1024 / 1024:.1f} MB")

# Show architecture
print(f"\n📐 Architecture:")
print(f"   Vocab size: {CONFIG['vocab_size']}")
print(f"   Hidden size: {CONFIG['hidden_size']}")
print(f"   Layers: {CONFIG['num_hidden_layers']}")
print(f"   Attention heads: {CONFIG['num_attention_heads']}")
print(f"   FFN size: {CONFIG['intermediate_size']}")


🧠 Model created!
   Total parameters:     5,851,209
   Trainable parameters: 5,851,209
   Estimated size:       22.3 MB

📐 Architecture:
   Vocab size: 8000
   Hidden size: 256
   Layers: 4
   Attention heads: 4
   FFN size: 1024
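The 5,851,209 figure can be reproduced by hand, which also shows where the parameters go. The sketch below assumes PyTorch's layer layout (a fused Q/K/V input projection plus an output projection in each attention block, two LayerNorms per encoder layer, and no final encoder norm, as in `CSFExtractor`). The helpers `transformer_layer_params` and `csf_extractor_params` are hypothetical:

```python
def transformer_layer_params(d: int, ffn: int) -> int:
    attn = 3 * d * d + 3 * d + d * d + d   # fused in_proj (Q, K, V) + out_proj, with biases
    ff = d * ffn + ffn + ffn * d + d       # linear1 + linear2, with biases
    norms = 2 * 2 * d                      # two LayerNorms (weight + bias each)
    return attn + ff + norms

def csf_extractor_params(vocab: int, d: int, layers: int, ffn: int,
                         max_pos: int, num_classes: dict) -> int:
    embeddings = vocab * d + max_pos * d + 2 * d                     # word + position + LayerNorm
    encoder = layers * transformer_layer_params(d, ffn)
    heads = sum(d * d + d + d * n + n for n in num_classes.values())  # dense + classifier per slot
    return embeddings + encoder + heads

NUM_CLASSES = {"event": 7, "intent": 4, "time": 5, "condition": 35, "agent": 5,
               "object": 5, "location": 6, "purpose": 2, "modifier": 4}

full = csf_extractor_params(8000, 256, 4, 1024, 128, NUM_CLASSES)
trimmed = csf_extractor_params(3954, 256, 4, 1024, 128, NUM_CLASSES)
print(f"{full:,} params ({full * 4 / 2**20:.1f} MiB fp32)")  # 5,851,209 params (22.3 MiB fp32)
print(f"{trimmed:,} params with vocab_size=3954")           # 4,815,433 params with vocab_size=3954
```

The word embedding is the largest single block (2,048,000 parameters). Since the tokenizer learned only 3,954 tokens, rows 3,954 to 7,999 of that table are never indexed; sizing `CONFIG["vocab_size"]` from `tokenizer.get_vocab_size()` after training would drop about 1.0M parameters (roughly 4 MB in fp32) from the exported model.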




## 7️⃣ Training Loop

In [7]:
# ============================================================
# TRAINING SETUP
# ============================================================

from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

# Optimizer
optimizer = AdamW(
    model.parameters(),
    lr=TRAIN_CONFIG["learning_rate"],
    weight_decay=TRAIN_CONFIG["weight_decay"]
)

# Scheduler
total_steps = len(train_loader) * TRAIN_CONFIG["epochs"]
scheduler = OneCycleLR(
    optimizer,
    max_lr=TRAIN_CONFIG["learning_rate"],
    total_steps=total_steps,
    pct_start=TRAIN_CONFIG["warmup_ratio"],
    anneal_strategy='cos'
)

# Loss function
criterion = nn.CrossEntropyLoss()

print(f"\n⚙️  Training setup:")
print(f"   Optimizer: AdamW")
print(f"   Learning rate: {TRAIN_CONFIG['learning_rate']}")
print(f"   Epochs: {TRAIN_CONFIG['epochs']}")
print(f"   Total steps: {total_steps}")


⚙️  Training setup:
   Optimizer: AdamW
   Learning rate: 0.0002
   Epochs: 15
   Total steps: 3990
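`total_steps` is 266 batches × 15 epochs = 3,990, where 266 = ⌈16,996 / 64⌉. Because `warmup_ratio` is passed as `pct_start` and the other `OneCycleLR` arguments are left at their defaults (`div_factor=25`, `final_div_factor=1e4`), the learning rate starts at 2e-4 / 25 = 8e-6, peaks at 2e-4 after the first 10% of steps, and ends at 8e-10 rather than zero. `OneCycleLR` also cycles AdamW's beta1 by default (`cycle_momentum=True`). The sketch below re-derives the default two-phase cosine schedule in plain Python; `one_cycle_lr` is a hypothetical helper that mirrors PyTorch's formula as we understand it, not the library code:

```python
import math

def one_cycle_lr(step: int, total_steps: int, max_lr: float, pct_start: float,
                 div_factor: float = 25.0, final_div_factor: float = 1e4) -> float:
    """Two-phase cosine one-cycle schedule (sketch of OneCycleLR's defaults)."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = float(pct_start * total_steps) - 1   # last step of the warm-up phase
    last = total_steps - 1
    if step <= warmup_end:
        start, end, pct = initial_lr, max_lr, step / warmup_end
    else:
        start, end, pct = max_lr, min_lr, (step - warmup_end) / (last - warmup_end)
    # Cosine interpolation from `start` (pct=0) to `end` (pct=1)
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

steps_per_epoch = math.ceil(16996 / 64)   # matches len(train_loader)
total = steps_per_epoch * 15
for s in (0, 398, 2000, total - 1):
    print(s, f"{one_cycle_lr(s, total, 2e-4, 0.1):.2e}")
```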


In [8]:
# ============================================================
# TRAINING LOOP
# ============================================================

def train_epoch(model, loader, optimizer, scheduler, criterion, device):
    model.train()
    total_loss = 0
    correct = {slot: 0 for slot in SLOT_NAMES}
    total = 0

    pbar = tqdm(loader, desc="Training")
    for batch in pbar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = {slot: batch["labels"][slot].to(device) for slot in SLOT_NAMES}

        optimizer.zero_grad()

        outputs = model(input_ids, attention_mask)

        loss = sum(criterion(outputs[slot], labels[slot]) for slot in SLOT_NAMES)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        total += input_ids.size(0)

        for slot in SLOT_NAMES:
            preds = outputs[slot].argmax(dim=-1)
            correct[slot] += (preds == labels[slot]).sum().item()

        avg_acc = sum(correct[s] for s in SLOT_NAMES) / (total * len(SLOT_NAMES)) * 100
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{avg_acc:.1f}%"})

    avg_loss = total_loss / len(loader)
    accuracies = {slot: correct[slot] / total * 100 for slot in SLOT_NAMES}
    return avg_loss, accuracies


def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = {slot: 0 for slot in SLOT_NAMES}
    total = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = {slot: batch["labels"][slot].to(device) for slot in SLOT_NAMES}

            outputs = model(input_ids, attention_mask)

            loss = sum(criterion(outputs[slot], labels[slot]) for slot in SLOT_NAMES)
            total_loss += loss.item()
            total += input_ids.size(0)

            for slot in SLOT_NAMES:
                preds = outputs[slot].argmax(dim=-1)
                correct[slot] += (preds == labels[slot]).sum().item()

    avg_loss = total_loss / len(loader)
    accuracies = {slot: correct[slot] / total * 100 for slot in SLOT_NAMES}
    return avg_loss, accuracies


# ============================================================
# RUN TRAINING
# ============================================================

print("\n" + "=" * 60)
print("🚀 STARTING TRAINING")
print("=" * 60)

best_val_acc = 0
history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

for epoch in range(TRAIN_CONFIG["epochs"]):
    print(f"\n📅 Epoch {epoch + 1}/{TRAIN_CONFIG['epochs']}")
    print("-" * 40)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, criterion, device)

    # Evaluate
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    # Mean of the per-slot accuracies (not exact-match accuracy over all 9 slots)
    avg_train_acc = sum(train_acc.values()) / len(train_acc)
    avg_val_acc = sum(val_acc.values()) / len(val_acc)

    # Save history
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(avg_train_acc)
    history["val_acc"].append(avg_val_acc)

    print(f"\n   Train Loss: {train_loss:.4f} | Train Acc: {avg_train_acc:.2f}%")
    print(f"   Val Loss:   {val_loss:.4f} | Val Acc:   {avg_val_acc:.2f}%")

    # Show per-slot accuracy
    print(f"\n   Slot Accuracies (Val):")
    for slot in SLOT_NAMES:
        print(f"      {slot:12s}: {val_acc[slot]:.1f}%")

    # Save best model
    if avg_val_acc > best_val_acc:
        best_val_acc = avg_val_acc
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/best_model.pt")
        print(f"\n   ✅ New best model saved! ({best_val_acc:.2f}%)")

print("\n" + "=" * 60)
print(f"🏆 Training complete! Best Val Accuracy: {best_val_acc:.2f}%")
print("=" * 60)


🚀 STARTING TRAINING

📅 Epoch 1/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:07<00:00, 35.10it/s, loss=6.7018, acc=68.5%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 76.50it/s]



   Train Loss: 8.9983 | Train Acc: 68.50%
   Val Loss:   6.5203 | Val Acc:   77.20%

   Slot Accuracies (Val):
      event       : 49.1%
      intent      : 94.7%
      time        : 89.8%
      condition   : 29.3%
      agent       : 85.8%
      object      : 96.6%
      location    : 62.7%
      purpose     : 91.1%
      modifier    : 95.7%

   ✅ New best model saved! (77.20%)

📅 Epoch 2/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 41.38it/s, loss=4.4123, acc=80.8%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 79.70it/s]



   Train Loss: 5.4513 | Train Acc: 80.79%
   Val Loss:   3.5924 | Val Acc:   87.87%

   Slot Accuracies (Val):
      event       : 74.1%
      intent      : 95.5%
      time        : 98.0%
      condition   : 56.9%
      agent       : 93.0%
      object      : 97.1%
      location    : 84.1%
      purpose     : 96.5%
      modifier    : 95.7%

   ✅ New best model saved! (87.87%)

📅 Epoch 3/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 41.28it/s, loss=2.8357, acc=89.2%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 86.14it/s]



   Train Loss: 3.1729 | Train Acc: 89.21%
   Val Loss:   1.8658 | Val Acc:   93.84%

   Slot Accuracies (Val):
      event       : 88.5%
      intent      : 96.3%
      time        : 99.3%
      condition   : 79.0%
      agent       : 95.3%
      object      : 97.1%
      location    : 92.8%
      purpose     : 99.3%
      modifier    : 97.0%

   ✅ New best model saved! (93.84%)

📅 Epoch 4/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 40.79it/s, loss=1.7037, acc=93.5%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 78.14it/s]



   Train Loss: 1.9497 | Train Acc: 93.46%
   Val Loss:   1.2018 | Val Acc:   96.15%

   Slot Accuracies (Val):
      event       : 93.0%
      intent      : 96.8%
      time        : 99.3%
      condition   : 90.3%
      agent       : 96.1%
      object      : 97.6%
      location    : 95.0%
      purpose     : 99.0%
      modifier    : 98.3%

   ✅ New best model saved! (96.15%)

📅 Epoch 5/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 41.18it/s, loss=1.0120, acc=95.6%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 78.17it/s]



   Train Loss: 1.3192 | Train Acc: 95.61%
   Val Loss:   0.7920 | Val Acc:   97.58%

   Slot Accuracies (Val):
      event       : 94.8%
      intent      : 97.7%
      time        : 99.4%
      condition   : 95.8%
      agent       : 97.1%
      object      : 98.6%
      location    : 96.2%
      purpose     : 99.7%
      modifier    : 98.9%

   ✅ New best model saved! (97.58%)

📅 Epoch 6/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 40.98it/s, loss=1.0391, acc=96.8%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 82.59it/s]



   Train Loss: 0.9502 | Train Acc: 96.84%
   Val Loss:   0.6250 | Val Acc:   98.08%

   Slot Accuracies (Val):
      event       : 95.8%
      intent      : 98.7%
      time        : 99.6%
      condition   : 97.9%
      agent       : 97.9%
      object      : 98.6%
      location    : 95.8%
      purpose     : 99.6%
      modifier    : 98.9%

   ✅ New best model saved! (98.08%)

📅 Epoch 7/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 41.98it/s, loss=0.6526, acc=97.7%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 73.00it/s]



   Train Loss: 0.7135 | Train Acc: 97.66%
   Val Loss:   0.4973 | Val Acc:   98.51%

   Slot Accuracies (Val):
      event       : 96.3%
      intent      : 98.7%
      time        : 99.5%
      condition   : 98.5%
      agent       : 98.4%
      object      : 99.0%
      location    : 97.6%
      purpose     : 99.6%
      modifier    : 99.0%

   ✅ New best model saved! (98.51%)

📅 Epoch 8/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 41.92it/s, loss=0.4298, acc=98.1%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 82.24it/s]



   Train Loss: 0.5806 | Train Acc: 98.08%
   Val Loss:   0.4351 | Val Acc:   98.58%

   Slot Accuracies (Val):
      event       : 96.6%
      intent      : 98.8%
      time        : 99.5%
      condition   : 98.9%
      agent       : 98.4%
      object      : 99.0%
      location    : 97.4%
      purpose     : 99.6%
      modifier    : 99.0%

   ✅ New best model saved! (98.58%)

📅 Epoch 9/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 42.18it/s, loss=0.4829, acc=98.4%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 85.96it/s]



   Train Loss: 0.4666 | Train Acc: 98.44%
   Val Loss:   0.4120 | Val Acc:   98.68%

   Slot Accuracies (Val):
      event       : 97.2%
      intent      : 98.6%
      time        : 99.5%
      condition   : 99.0%
      agent       : 98.5%
      object      : 99.3%
      location    : 97.2%
      purpose     : 99.6%
      modifier    : 99.2%

   ✅ New best model saved! (98.68%)

📅 Epoch 10/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 41.47it/s, loss=0.1863, acc=98.7%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 80.92it/s]



   Train Loss: 0.4006 | Train Acc: 98.66%
   Val Loss:   0.3716 | Val Acc:   98.87%

   Slot Accuracies (Val):
      event       : 97.3%
      intent      : 99.0%
      time        : 99.5%
      condition   : 99.2%
      agent       : 98.8%
      object      : 99.2%
      location    : 97.8%
      purpose     : 99.7%
      modifier    : 99.3%

   ✅ New best model saved! (98.87%)

📅 Epoch 11/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 40.45it/s, loss=0.6081, acc=98.9%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 79.91it/s]



   Train Loss: 0.3344 | Train Acc: 98.90%
   Val Loss:   0.3659 | Val Acc:   98.91%

   Slot Accuracies (Val):
      event       : 97.5%
      intent      : 99.2%
      time        : 99.6%
      condition   : 99.3%
      agent       : 98.7%
      object      : 99.2%
      location    : 97.7%
      purpose     : 99.6%
      modifier    : 99.3%

   ✅ New best model saved! (98.91%)

📅 Epoch 12/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 41.92it/s, loss=0.6581, acc=99.0%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 81.15it/s]



   Train Loss: 0.3028 | Train Acc: 99.01%
   Val Loss:   0.3505 | Val Acc:   98.98%

   Slot Accuracies (Val):
      event       : 97.6%
      intent      : 99.1%
      time        : 99.6%
      condition   : 99.4%
      agent       : 99.0%
      object      : 99.3%
      location    : 97.9%
      purpose     : 99.6%
      modifier    : 99.4%

   ✅ New best model saved! (98.98%)

📅 Epoch 13/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 40.80it/s, loss=0.2237, acc=99.1%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 78.17it/s]



   Train Loss: 0.2704 | Train Acc: 99.12%
   Val Loss:   0.3472 | Val Acc:   99.03%

   Slot Accuracies (Val):
      event       : 97.8%
      intent      : 99.2%
      time        : 99.6%
      condition   : 99.4%
      agent       : 99.0%
      object      : 99.2%
      location    : 97.9%
      purpose     : 99.7%
      modifier    : 99.5%

   ✅ New best model saved! (99.03%)

📅 Epoch 14/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 42.57it/s, loss=0.1528, acc=99.2%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 77.90it/s]



   Train Loss: 0.2565 | Train Acc: 99.16%
   Val Loss:   0.3446 | Val Acc:   99.00%

   Slot Accuracies (Val):
      event       : 97.8%
      intent      : 99.2%
      time        : 99.6%
      condition   : 99.4%
      agent       : 99.0%
      object      : 99.2%
      location    : 97.8%
      purpose     : 99.6%
      modifier    : 99.4%

📅 Epoch 15/15
----------------------------------------


Training: 100%|██████████| 266/266 [00:06<00:00, 41.66it/s, loss=0.1478, acc=99.2%]
Evaluating: 100%|██████████| 30/30 [00:00<00:00, 78.40it/s]


   Train Loss: 0.2552 | Train Acc: 99.17%
   Val Loss:   0.3434 | Val Acc:   99.01%

   Slot Accuracies (Val):
      event       : 97.7%
      intent      : 99.2%
      time        : 99.6%
      condition   : 99.3%
      agent       : 99.0%
      object      : 99.3%
      location    : 97.9%
      purpose     : 99.6%
      modifier    : 99.5%

🏆 Training complete! Best Val Accuracy: 99.03%
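The reported 99.03% is the mean of the nine per-slot accuracies, not the share of sentences whose whole CSF frame is correct. When errors fall on different slots of different sentences, exact-match accuracy can be noticeably lower than the per-slot mean. A minimal sketch that computes both from decoded frames (`slot_and_exact_match` is a hypothetical helper; the toy data is illustrative):

```python
from typing import Dict, List, Tuple

def slot_and_exact_match(preds: List[Dict[str, str]],
                         golds: List[Dict[str, str]]) -> Tuple[Dict[str, float], float, float]:
    """Return per-slot accuracy, their mean (as the notebook reports), and exact-match accuracy."""
    slots = list(golds[0])
    n = len(golds)
    per_slot = {s: sum(p[s] == g[s] for p, g in zip(preds, golds)) / n for s in slots}
    mean_slot = sum(per_slot.values()) / len(slots)
    exact = sum(all(p[s] == g[s] for s in slots) for p, g in zip(preds, golds)) / n
    return per_slot, mean_slot, exact

gold = {"event": "GO", "time": "TODAY", "agent": "ME"}
golds = [gold] * 4
preds = [
    gold,
    gold,
    {**gold, "event": "STAY"},  # one wrong slot
    {**gold, "agent": "YOU"},   # a different wrong slot
]
per_slot, mean_slot, exact = slot_and_exact_match(preds, golds)
print(per_slot)
print(f"mean per-slot: {mean_slot:.3f}, exact match: {exact:.3f}")
```

Inside `evaluate`, the same quantity could be accumulated per batch with something like `torch.stack([outputs[s].argmax(-1) == labels[s] for s in SLOT_NAMES]).all(dim=0).sum().item()`.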





## 8️⃣ Save Model

In [9]:
# ============================================================
# SAVE FINAL MODEL
# ============================================================

# Load best model
model.load_state_dict(torch.load(f"{OUTPUT_DIR}/best_model.pt"))
model.eval()

# Save PyTorch model
torch.save(model.state_dict(), f"{OUTPUT_DIR}/pytorch_model.bin")

# Save config
config_to_save = {
    **CONFIG,
    "num_classes": NUM_CLASSES,
    "slot_names": SLOT_NAMES,
}
with open(f"{OUTPUT_DIR}/config.json", "w") as f:
    json.dump(config_to_save, f, indent=2)

# Save labels
with open(f"{OUTPUT_DIR}/labels.json", "w") as f:
    json.dump(LABELS, f, indent=2)

print(f"\n✅ Model saved to: {OUTPUT_DIR}")
print("   - pytorch_model.bin")
print("   - config.json")
print("   - labels.json")
print("   - tokenizer.json")


✅ Model saved to: /content/drive/MyDrive/csf_browser_v4
   - pytorch_model.bin
   - config.json
   - labels.json
   - tokenizer.json
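On the consuming side (for example, a quick Python check before wiring up the browser runtime), the saved JSON files can be read back and turned into the id-to-label tables that `predict` uses as `ID_TO_LABEL`. A minimal sketch, assuming `labels.json` maps each slot name to a list of label strings whose position equals the class id; that layout and the helper name `load_label_maps` are assumptions, since the notebook does not show how `LABELS` or `ID_TO_LABEL` are built:

```python
import json
import os
import tempfile


def load_label_maps(output_dir: str) -> tuple[dict, dict[str, dict[int, str]]]:
    """Read config.json and labels.json and build {slot: {class_id: label}}.

    Assumed labels.json layout: {"condition": ["NONE", "IF_RAIN", ...], ...}.
    """
    with open(os.path.join(output_dir, "config.json"), encoding="utf-8") as f:
        config = json.load(f)
    with open(os.path.join(output_dir, "labels.json"), encoding="utf-8") as f:
        labels = json.load(f)

    # List position is treated as the class id (assumed layout).
    id_to_label = {slot: dict(enumerate(names)) for slot, names in labels.items()}

    # Each head's class count in config.json should match its label list.
    for slot, n_classes in config.get("num_classes", {}).items():
        n_labels = len(id_to_label.get(slot, {}))
        if n_classes != n_labels:
            raise ValueError(f"{slot}: config.json has {n_classes} classes, labels.json has {n_labels}")
    return config, id_to_label


# Tiny round trip with made-up files (not the real notebook artifacts).
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"num_classes": {"condition": 2}}, f)
    with open(os.path.join(tmp, "labels.json"), "w", encoding="utf-8") as f:
        json.dump({"condition": ["NONE", "IF_RAIN"]}, f)
    demo_config, demo_maps = load_label_maps(tmp)

print(demo_maps["condition"][1])  # IF_RAIN
```

Note that `json` turns integer dictionary keys into strings, so keeping labels as ordered lists avoids a `"1"` versus `1` mismatch when the tables are rebuilt.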


## 9️⃣ Export to ONNX

In [11]:
!pip install onnxscript

Collecting onnxscript
  Downloading onnxscript-0.5.7-py3-none-any.whl.metadata (13 kB)
Collecting onnx_ir<2,>=0.1.12 (from onnxscript)
  Downloading onnx_ir-0.1.13-py3-none-any.whl.metadata (3.2 kB)
Downloading onnxscript-0.5.7-py3-none-any.whl (693 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m693.4/693.4 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnx_ir-0.1.13-py3-none-any.whl (133 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m133.1/133.1 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx_ir, onnxscript
Successfully installed onnx_ir-0.1.13 onnxscript-0.5.7


In [12]:
# ============================================================
# EXPORT TO ONNX
# ============================================================

import onnx

print("üì¶ Exporting to ONNX...")

# Wrapper for ONNX export (returns tuple instead of dict)
class CSFExtractorONNX(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.model = base_model

    def forward(self, input_ids, attention_mask):
        logits = self.model(input_ids, attention_mask)
        return tuple(logits[slot] for slot in SLOT_NAMES)

# Create export model
model.eval()
export_model = CSFExtractorONNX(model)
export_model.eval()

# Dummy inputs
dummy_input_ids = torch.randint(0, CONFIG["vocab_size"], (1, CONFIG["max_length"])).to(device)
dummy_attention_mask = torch.ones(1, CONFIG["max_length"], dtype=torch.long).to(device)

# Export
ONNX_PATH = f"{OUTPUT_DIR}/model.onnx"

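with torch.no_grad():
    torch.onnx.export(
        export_model,
        (dummy_input_ids, dummy_attention_mask),
        ONNX_PATH,
        export_params=True,
        # The current exporter only implements opset >= 18; requesting 14 triggers
        # a down-conversion that fails (see the log below) and leaves the model at 18.
        opset_version=18,
        do_constant_folding=True,
        # Keep the weights inside model.onnx rather than a separate external-data
        # file, so a single file can be served to the browser
        # (`external_data` is accepted by recent PyTorch releases).
        external_data=False,
        input_names=["input_ids", "attention_mask"],
        output_names=[f"logits_{slot}" for slot in SLOT_NAMES],
        dynamic_axes={
            "input_ids": {0: "batch_size"},
            "attention_mask": {0: "batch_size"},
            **{f"logits_{slot}": {0: "batch_size"} for slot in SLOT_NAMES}
        }
    )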

# Verify
onnx_model = onnx.load(ONNX_PATH)
onnx.checker.check_model(onnx_model)

# Get size
onnx_size = os.path.getsize(ONNX_PATH) / 1024 / 1024

print("\n✅ ONNX model exported!")
print(f"   Path: {ONNX_PATH}")
print(f"   Size: {onnx_size:.2f} MB")

📦 Exporting to ONNX...


  torch.onnx.export(
W1230 15:39:08.002000 469 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 14 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features


[torch.onnx] Obtain model graph for `CSFExtractorONNX([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `CSFExtractorONNX([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 127, in call
    converted_proto = _c_api_utils.call_onnx_api(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 122, in _partial_convert_version
    return onnx.version_converter.convert_version(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnx/version_converter.py", line 39, in convert_version
    converted_model_str = C.convert_version(model_str, target_version)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /github/workspace/onnx/version_converter/adapters/no_previous_version.h:26: adapt: Assertion `

[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 13 of general pattern rewrite rules.

✅ ONNX model exported!
   Path: /content/drive/MyDrive/csf_browser_v4/model.onnx
   Size: 0.42 MB
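The reported 0.42 MB is much smaller than the ~23 MB quoted at the top of the notebook: with an 8,000-token vocabulary and 256 hidden units, the float32 embedding table alone is 8,000 × 256 ≈ 2M weights, about 8 MB. One likely explanation, not confirmed by the log, is that the exporter wrote the weights to a separate external-data file next to `model.onnx`; if so, that file has to be deployed alongside it. A minimal sketch that reports the graph file together with any sibling files sharing its name prefix; the `model.onnx.data` naming and the helper `onnx_artifact_sizes` are assumptions for illustration:

```python
import os
import tempfile


def onnx_artifact_sizes(onnx_path: str) -> dict[str, int]:
    """Return {filename: bytes} for the .onnx file and any sibling files whose
    names start with it (e.g. an assumed 'model.onnx.data' weights file)."""
    folder = os.path.dirname(onnx_path) or "."
    base = os.path.basename(onnx_path)
    sizes = {}
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        if name.startswith(base) and os.path.isfile(path):
            sizes[name] = os.path.getsize(path)
    return sizes


# Demo with dummy files standing in for the real export.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "model.onnx"), "wb") as f:
        f.write(b"\0" * 1000)   # stand-in for the graph
    with open(os.path.join(tmp, "model.onnx.data"), "wb") as f:
        f.write(b"\0" * 5000)   # stand-in for externalized weights
    demo_sizes = onnx_artifact_sizes(os.path.join(tmp, "model.onnx"))

for name, size in demo_sizes.items():
    print(f"{name:20s}: {size / 1024:.1f} KB")
print(f"{'TOTAL':20s}: {sum(demo_sizes.values()) / 1024:.1f} KB")
```

In the notebook this would be called as `onnx_artifact_sizes(ONNX_PATH)`. Recent PyTorch releases also accept `external_data=False` in `torch.onnx.export` to keep everything in one file.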


## 🔟 Test Inference

In [13]:
# ============================================================
# TEST ONNX INFERENCE
# ============================================================

import onnxruntime as ort
import time

print("üß™ Testing ONNX inference...")

# Load tokenizer
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file(f"{OUTPUT_DIR}/tokenizer.json")

# Load ONNX model
session = ort.InferenceSession(ONNX_PATH, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
print(f"   Provider: {session.get_providers()[0]}")

# Slot order used to linearize a CSF into a sign-language gloss (the `intent` head is not rendered)
GLOSS_ORDER = ["modifier", "time", "condition", "agent", "location", "object", "event", "purpose"]

def predict(text):
    enc = tokenizer.encode(text)
    input_ids = np.array([enc.ids], dtype=np.int64)
    attention_mask = np.array([enc.attention_mask], dtype=np.int64)

    outputs = session.run(None, {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    })

    csf = {}
    for i, slot in enumerate(SLOT_NAMES):
        pred_id = np.argmax(outputs[i])
        csf[slot] = ID_TO_LABEL[slot][pred_id]
    return csf

def to_gloss(csf):
    tokens = []
    for slot in GLOSS_ORDER:
        val = csf.get(slot)
        if val and val != "NONE" and not (slot == "agent" and val == "ME"):
            tokens.append(val)
    return " ".join(tokens)

# Test examples
test_examples = [
    # Basic
    "I go to school tomorrow.",
    "She stays at home.",

    # Weather conditions
    "If it rains, I stay home.",
    "If it's sunny, I go to the park.",

    # Mood conditions
    "If I'm bored, I watch Netflix.",
    "When I'm tired, I take a nap.",
    "If I'm hungry, I eat food.",

    # Time conditions
    "On the weekend, I sleep in.",
    "After work, I go home.",
    "After school, I play games.",

    # Schedule conditions
    "If I'm free, I meet friends.",
    "If I'm busy, I skip lunch.",

    # Financial conditions
    "If I have money, I go shopping.",
    "When I'm broke, I stay home.",

    # Multilingual
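    "Nếu mưa thì tôi ở nhà.",  # vi: "If it rains, I stay home."
    "Nếu đói thì tôi ăn.",  # vi: "If (I'm) hungry, I eat."
    "明日、学校に行きます。",  # ja: "Tomorrow, I go to school."
    "疲れたら、家で休みます。",  # ja: "If I get tired, I rest at home."
    "Je travaille à l'hôpital.",  # fr: "I work at the hospital."
    "Si je suis fatigué, je me repose.",  # fr: "If I'm tired, I rest."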
]

print("\n" + "=" * 70)
print("üìù INFERENCE EXAMPLES")
print("=" * 70)

for text in test_examples:
    csf = predict(text)
    gloss = to_gloss(csf)
    print(f"\nüó£Ô∏è  {text}")
    print(f"ü§ü {gloss}")
    print(f"   condition={csf['condition']}")

🧪 Testing ONNX inference...
   Provider: CPUExecutionProvider

📝 INFERENCE EXAMPLES

🗣️  I go to school tomorrow.
🤟 TOMORROW SCHOOL GO
   condition=NONE

🗣️  She stays at home.
🤟 SHE HOME STAY
   condition=NONE

🗣️  If it rains, I stay home.
🤟 IF_RAIN HOME STAY
   condition=IF_RAIN

🗣️  If it's sunny, I go to the park.
🤟 IF_SUNNY GO
   condition=IF_SUNNY

🗣️  If I'm bored, I watch Netflix.
🤟 IF_BORED HOME STAY
   condition=IF_BORED

🗣️  When I'm tired, I take a nap.
🤟 IF_TIRED HOME STAY REST
   condition=IF_TIRED

🗣️  If I'm hungry, I eat food.
🤟 IF_HUNGRY EAT
   condition=IF_HUNGRY

🗣️  On the weekend, I sleep in.
🤟 IF_WEEKEND HOME STAY
   condition=IF_WEEKEND

🗣️  After work, I go home.
🤟 IF_FINISH_WORK HOME GO
   condition=IF_FINISH_WORK

🗣️  After school, I play games.
🤟 IF_FINISH_SCHOOL HOME GO
   condition=IF_FINISH_SCHOOL

🗣️  If I'm free, I meet friends.
🤟 IF_FREE HOME STAY
   condition=IF_FREE

🗣️  If I'm busy, I skip lunch.
🤟 IF_BUSY OFFICE WORK
   condition=IF_BUSY

🗣️  If I have money, I go shopping.
🤟 IF_HAVE_MONEY STORE BUY
   condition=IF_HAVE_MONEY

🗣️  When I'm broke, I stay home.
🤟 IF_NO_MONEY HOME STAY
   condition=IF_NO_MONEY

🗣️  Nếu mưa thì tôi ở nhà.
🤟 IF_RAIN HOME STAY
   condition=IF_RAIN

🗣️  Nếu đói thì tôi ăn.
🤟 IF_HUNGRY EAT
   condition=IF_HUNGRY

🗣️  明日、学校に行きます。
🤟 TOMORROW STAY
   condition=NONE

🗣️  疲れたら、家で休みます。
🤟 IF_SICK HOME STAY REST
   condition=IF_SICK

🗣️  Je travaille à l'hôpital.
🤟 HOSPITAL WORK
   condition=NONE

🗣️  Si je suis fatigué, je me repose.
🤟 IF_TIRED HOME STAY REST
   condition=IF_TIRED
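`predict` decodes one sentence at a time and relies on `session.run` returning outputs in `SLOT_NAMES` order. Because the export names every output `logits_<slot>` and keeps the batch axis dynamic, the outputs can instead be matched to slots by name and decoded for a whole batch in one call. A minimal sketch with a hypothetical `decode_batch` helper: in the notebook the names would come from `[o.name for o in session.get_outputs()]`, and batching assumes every sentence is padded to `CONFIG["max_length"]`, since only the batch axis was exported as dynamic:

```python
import numpy as np


def decode_batch(output_names, outputs, id_to_label):
    """Map ONNX outputs named 'logits_<slot>' back to one CSF dict per sentence.

    output_names: output names, in the same order as `outputs`
    outputs:      arrays shaped (batch_size, num_classes)
    id_to_label:  {slot: {class_id: label}}
    """
    batch_size = outputs[0].shape[0]
    results = [{} for _ in range(batch_size)]
    for name, logits in zip(output_names, outputs):
        slot = name.removeprefix("logits_")
        pred_ids = logits.argmax(axis=-1)  # one class id per sentence
        for row, pred_id in enumerate(pred_ids):
            results[row][slot] = id_to_label[slot][int(pred_id)]
    return results


# Toy example: two sentences, two slots, made-up logits and labels.
names = ["logits_condition", "logits_event"]
fake_outputs = [
    np.array([[0.1, 2.0], [3.0, 0.2]]),
    np.array([[0.0, 1.0, 0.5], [0.9, 0.1, 0.0]]),
]
toy_labels = {
    "condition": {0: "NONE", 1: "IF_RAIN"},
    "event": {0: "GO", 1: "STAY", 2: "EAT"},
}
print(decode_batch(names, fake_outputs, toy_labels))
# [{'condition': 'IF_RAIN', 'event': 'STAY'}, {'condition': 'NONE', 'event': 'GO'}]
```

Matching by name means reordering `SLOT_NAMES` or the export's `output_names` cannot silently attach logits to the wrong slot.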


In [14]:
# ============================================================
# BENCHMARK
# ============================================================

print("\n" + "=" * 60)
print("⏱️  BENCHMARK")
print("=" * 60)

# Prepare input
test_text = "If I'm hungry, I eat food."
enc = tokenizer.encode(test_text)
input_feed = {
    "input_ids": np.array([enc.ids], dtype=np.int64),
    "attention_mask": np.array([enc.attention_mask], dtype=np.int64)
}

# Warmup
for _ in range(20):
    _ = session.run(None, input_feed)

# Benchmark
n_runs = 100
times = []
for _ in range(n_runs):
    start = time.perf_counter()
    _ = session.run(None, input_feed)
    times.append((time.perf_counter() - start) * 1000)

print(f"\nüìä Results ({n_runs} runs):")
print(f"   Mean:       {np.mean(times):.2f} ms")
print(f"   Std:        {np.std(times):.2f} ms")
print(f"   Min:        {np.min(times):.2f} ms")
print(f"   Max:        {np.max(times):.2f} ms")
print(f"   P50:        {np.percentile(times, 50):.2f} ms")
print(f"   P95:        {np.percentile(times, 95):.2f} ms")
print(f"   Throughput: {1000/np.mean(times):.0f} inferences/sec")


⏱️  BENCHMARK

📊 Results (100 runs):
   Mean:       3.02 ms
   Std:        0.08 ms
   Min:        2.94 ms
   Max:        3.53 ms
   P50:        3.00 ms
   P95:        3.11 ms
   Throughput: 331 inferences/sec
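The throughput line is simply the reciprocal of the mean single-request latency (1000 / 3.02 ms ≈ 331 per second). It measures sequential batch-size-1 calls on the CPU provider, not batched throughput. A small reusable sketch of the same statistics (`summarize_latencies` is a hypothetical helper name), using NumPy's default linear-interpolation percentiles as the benchmark cell does:

```python
import numpy as np


def summarize_latencies(times_ms):
    """Summary statistics for per-call latencies given in milliseconds."""
    t = np.asarray(times_ms, dtype=float)
    return {
        "mean_ms": float(t.mean()),
        "std_ms": float(t.std()),
        "p50_ms": float(np.percentile(t, 50)),
        "p95_ms": float(np.percentile(t, 95)),
        # Sequential calls: throughput is the reciprocal of the mean latency.
        "per_sec": 1000.0 / float(t.mean()),
    }


stats = summarize_latencies([1.0, 2.0, 3.0, 4.0])
for key, value in stats.items():
    print(f"{key:8s}: {value:.2f}")
```

Passing the notebook's `times` list would reproduce the numbers printed above.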


## 📋 Summary

In [15]:
# ============================================================
# FINAL SUMMARY
# ============================================================

print("\n" + "=" * 60)
print("üìã CSF BROWSER v4 - TRAINING SUMMARY")
print("=" * 60)

# File sizes
files = {
    "model.onnx": f"{OUTPUT_DIR}/model.onnx",
    "tokenizer.json": f"{OUTPUT_DIR}/tokenizer.json",
    "config.json": f"{OUTPUT_DIR}/config.json",
    "labels.json": f"{OUTPUT_DIR}/labels.json",
}

total_size = 0
print(f"\nüìÅ Output files:")
for name, path in files.items():
    if os.path.exists(path):
        size = os.path.getsize(path)
        total_size += size
        if size > 1024 * 1024:
            print(f"   {name:20s}: {size/1024/1024:.2f} MB")
        else:
            print(f"   {name:20s}: {size/1024:.1f} KB")

print(f"   {'─' * 30}")
print(f"   {'TOTAL':20s}: {total_size/1024/1024:.2f} MB")

print("\n📊 Model Stats:")
print(f"   Training samples: {len(train_data):,}")
print(f"   Validation samples: {len(val_data):,}")
print(f"   Best accuracy: {best_val_acc:.2f}%")
print(f"   Condition types: {len(LABELS['condition'])}")
print(f"   Total output classes: {sum(NUM_CLASSES.values())}")

print(f"\nüéØ Condition Categories:")
categories = {
    "Weather": ["IF_RAIN", "IF_SUNNY", "IF_COLD", "IF_HOT", "IF_WINDY"],
    "Time": ["IF_LATE", "IF_EARLY", "IF_WEEKEND", "IF_NIGHT", "IF_MORNING"],
    "Health": ["IF_SICK", "IF_TIRED", "IF_HUNGRY", "IF_THIRSTY", "IF_FULL"],
    "Schedule": ["IF_BUSY", "IF_FREE", "IF_HOLIDAY", "IF_WORKING"],
    "Mood": ["IF_BORED", "IF_HAPPY", "IF_SAD", "IF_STRESSED", "IF_ANGRY"],
    "Social": ["IF_ALONE", "IF_WITH_FRIENDS", "IF_WITH_FAMILY"],
    "Activity": ["IF_FINISH_WORK", "IF_FINISH_SCHOOL", "IF_FINISH_EATING", "IF_WATCH_MOVIE", "IF_LISTEN_MUSIC"],
    "Financial": ["IF_HAVE_MONEY", "IF_NO_MONEY"],
}
for cat, conds in categories.items():
    print(f"   {cat:12s}: {len(conds)} conditions")

print("\n✅ Training complete!")
print(f"   Output directory: {OUTPUT_DIR}")
print("=" * 60)


📋 CSF BROWSER v4 - TRAINING SUMMARY

📁 Output files:
   model.onnx          : 433.7 KB
   tokenizer.json      : 321.4 KB
   config.json         : 0.5 KB
   labels.json         : 1.2 KB
   ──────────────────────────────
   TOTAL               : 0.74 MB

📊 Model Stats:
   Training samples: 16,996
   Validation samples: 1,889
   Best accuracy: 99.03%
   Condition types: 35
   Total output classes: 73

🎯 Condition Categories:
   Weather     : 5 conditions
   Time        : 5 conditions
   Health      : 5 conditions
   Schedule    : 4 conditions
   Mood        : 5 conditions
   Social      : 3 conditions
   Activity    : 5 conditions
   Financial   : 2 conditions

✅ Training complete!
   Output directory: /content/drive/MyDrive/csf_browser_v4
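The eight hand-written categories list 5 + 5 + 5 + 4 + 5 + 3 + 5 + 2 = 34 conditions, while `LABELS['condition']` reports 35 classes, so one label (plausibly the `NONE` class seen in the inference output) belongs to no category. Because the category table is maintained by hand, a quick cross-check against the model's label list catches typos and drift. A minimal sketch, assuming `LABELS["condition"]` is an iterable of label strings; `check_categories` is a hypothetical helper and the toy data below is made up:

```python
def check_categories(categories, condition_labels, ignore=("NONE",)):
    """Compare a hand-written {category: [labels]} table with the model's labels.

    Returns (missing, uncategorized):
      missing:       listed in a category but not a model label (e.g. a typo)
      uncategorized: model labels (minus `ignore`) that no category mentions
    """
    listed = {label for labels in categories.values() for label in labels}
    known = set(condition_labels)
    missing = sorted(listed - known)
    uncategorized = sorted(known - listed - set(ignore))
    return missing, uncategorized


# Toy data: one typo in the table and one label left out of it.
toy_categories = {"Weather": ["IF_RAIN", "IF_SNOWY"], "Health": ["IF_SICK"]}
toy_labels = ["NONE", "IF_RAIN", "IF_SICK", "IF_TIRED"]
print(check_categories(toy_categories, toy_labels))
# (['IF_SNOWY'], ['IF_TIRED'])
```

In the notebook this would be run as `check_categories(categories, LABELS["condition"])`; two empty lists mean the table and the model agree.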
