# Focused Learning: Multi-Label BCE Loss in T-FREE

## Learning Objectives
1. **Understand the fundamental difference between single-label and multi-label classification**
2. **Explore why T-FREE requires multi-label loss instead of traditional cross-entropy**
3. **Implement and analyze the multi-label BCE loss mechanism**
4. **Visualize and compare the loss behavior for different word predictions**

## Paper Context

From the T-FREE paper (Deiseroth et al., 2025):

> "T-FREE inherently exploits morphological similarities and allows for strong compression of embedding layers... through sparse activation patterns over character triplets" (Section 1)

> "The backbone of the language model will remain free of subword tokenization as we directly embed each word in the input text with sparse activation patterns over hashed character triplets" (Page 1)

The key innovation here is that **each word activates multiple trigrams simultaneously**, requiring a loss function that can handle multiple positive labels per example.

## 1. Theoretical Foundation

### Traditional Single-Label Classification

In traditional language models with tokenizers:
- Each token maps to **exactly one** vocabulary ID
- Loss function: Cross-Entropy (CE)
- Mathematical formulation:

$$\mathcal{L}_{CE} = -\sum_{i=1}^{|V|} y_i \log(\hat{y}_i)$$

where $\hat{y}_i$ is the softmax probability assigned to token $i$, $y_i \in \{0, 1\}$, and $\sum_i y_i = 1$ (one-hot encoding)
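
As a minimal worked example of the formula above, the sketch below evaluates cross-entropy by hand in plain Python. The three-entry vocabulary and the logit values are made up for illustration. Because the target is one-hot, only the log-probability of the correct entry contributes.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, one_hot):
    """-sum_i y_i * log(y_hat_i), with y_hat = softmax(logits)."""
    probs = softmax(logits)
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs))

# Illustrative values, not taken from the paper
logits = [2.0, 0.5, -1.0]
one_hot = [1, 0, 0]

ce = cross_entropy(logits, one_hot)
# Only the target entry contributes, so CE equals -log(p_target)
p_target = softmax(logits)[0]
print(f"CE = {ce:.4f}, -log(p_target) = {-math.log(p_target):.4f}")
```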

### T-FREE Multi-Label Classification

In T-FREE:
- Each word activates **multiple trigrams**
- Loss function: Binary Cross-Entropy (BCE), applied independently to every vocabulary entry (multi-label)
- Mathematical formulation:

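$$\mathcal{L}_{BCE} = -\frac{1}{|V|}\sum_{i=1}^{|V|} [y_i \log(\sigma(z_i)) + (1-y_i) \log(1-\sigma(z_i))]$$

where $z_i$ is the raw logit for entry $i$, $\sigma$ is the sigmoid function, $y_i \in \{0, 1\}$ and $\sum_i y_i \geq 1$ (multi-hot encoding). Unlike the softmax probabilities $\hat{y}_i$ in the cross-entropy formula, the values $\sigma(z_i)$ are independent of one another and need not sum to 1.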
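
The formula can be checked term by term with a small sketch in plain Python. The logits and targets below are illustrative, and `multilabel_bce` is a hypothetical helper name. The second call shows that each entry is scored independently: flipping one label changes only that entry's term.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def multilabel_bce(logits, multi_hot):
    """Mean over entries of -[y log s(z) + (1 - y) log(1 - s(z))]."""
    terms = []
    for z, y in zip(logits, multi_hot):
        p = sigmoid(z)
        terms.append(-(y * math.log(p) + (1 - y) * math.log(1 - p)))
    return sum(terms) / len(terms)

# Illustrative values: two active entries out of five
logits = [3.0, -2.0, 1.5, -4.0, -1.0]
targets = [1, 0, 1, 0, 0]

loss = multilabel_bce(logits, targets)
print(f"BCE = {loss:.4f}")

# Flip only the last label; every other term stays the same
flipped = [1, 0, 1, 0, 1]
loss_flipped = multilabel_bce(logits, flipped)
print(f"BCE with last label flipped = {loss_flipped:.4f}")
```

For a single entry, switching the label from 0 to 1 changes its term by $-z$, so with a logit of $-1$ and five entries the mean rises by $1/5$.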

In [None]:
# Environment setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

## 2. Understanding Single-Label vs Multi-Label

Let's visualize the fundamental difference between these two approaches:

In [None]:
# Demonstrate single-label vs multi-label encoding
vocab_size = 10

# Single-label (traditional tokenizer)
single_label = torch.zeros(vocab_size)
single_label[3] = 1  # Only one position is active

# Multi-label (T-FREE)
multi_label = torch.zeros(vocab_size)
multi_label[[1, 3, 5, 7]] = 1  # Multiple positions are active

# Visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Single-label visualization
ax1.bar(range(vocab_size), single_label.numpy(), color='blue', alpha=0.7)
ax1.set_title('Single-Label Encoding (Traditional Tokenizer)', fontsize=14)
ax1.set_xlabel('Vocabulary Index')
ax1.set_ylabel('Activation')
ax1.set_ylim(-0.1, 1.1)
ax1.grid(axis='y', alpha=0.3)

# Multi-label visualization
ax2.bar(range(vocab_size), multi_label.numpy(), color='green', alpha=0.7)
ax2.set_title('Multi-Label Encoding (T-FREE)', fontsize=14)
ax2.set_xlabel('Vocabulary Index (Trigram Hash)')
ax2.set_ylabel('Activation')
ax2.set_ylim(-0.1, 1.1)
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Single-label: {int(single_label.sum().item())} active positions")
print(f"Multi-label: {int(multi_label.sum().item())} active positions")

## 3. Implementing Multi-Label BCE Loss

Now let's implement the multi-label BCE loss as used in T-FREE:

In [None]:
class MultiLabelBCELoss(nn.Module):
    """
    Multi-label Binary Cross-Entropy Loss for T-FREE.
    
    This loss function handles the case where multiple trigrams
    are active for a single word prediction.
    """
    
    def __init__(self, reduction='mean', label_smoothing=0.0):
        super().__init__()
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        
    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: Raw model outputs [batch_size, vocab_size]
            targets: Multi-hot encoded targets [batch_size, vocab_size]
        
        Returns:
            loss: Scalar loss value
        """
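        # Apply label smoothing if specified. Note: this adds eps / vocab_size to
        # every entry (categorical-style); binary smoothing often uses 0.5 * eps.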
        if self.label_smoothing > 0:
            targets = targets * (1 - self.label_smoothing) + \
                     self.label_smoothing / targets.size(-1)
        
        # Calculate BCE loss with logits
        loss = F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none'
        )
        
        # Apply reduction
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


# Compare with traditional cross-entropy
class SingleLabelCELoss(nn.Module):
    """Traditional Cross-Entropy Loss for comparison."""
    
    def __init__(self):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        
    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Collapse multi-hot to a single label (for comparison only): argmax keeps
        # the first active index and discards every other active trigram
        single_targets = targets.argmax(dim=-1)
        return self.ce_loss(logits, single_targets)
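
The `label_smoothing` branch of `MultiLabelBCELoss` above adds `eps / vocab_size` to every entry, which is the convention used for categorical cross-entropy. The sketch below compares it with a common binary alternative that pulls each label toward 0.5. Both function names are hypothetical, and neither convention is confirmed here as the paper's choice.

```python
def smooth_categorical_style(y, eps, vocab_size):
    """Smoothing as written in MultiLabelBCELoss: y * (1 - eps) + eps / |V|."""
    return y * (1 - eps) + eps / vocab_size

def smooth_binary_style(y, eps):
    """Common binary alternative (an assumption, not from the paper)."""
    return y * (1 - eps) + 0.5 * eps

eps, vocab_size = 0.1, 8000
for y in (1.0, 0.0):
    a = smooth_categorical_style(y, eps, vocab_size)
    b = smooth_binary_style(y, eps)
    print(f"y={y:.0f}: eps/|V| style -> {a:.7f}, 0.5*eps style -> {b:.7f}")
```

With a large vocabulary the first convention leaves inactive targets almost at 0 while lowering active targets to about `1 - eps`, so the two conventions treat positives and negatives quite differently.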

## 4. Loss Behavior Analysis

Let's analyze how the multi-label BCE loss behaves differently from traditional cross-entropy:

In [None]:
# Create synthetic data to demonstrate loss behavior
batch_size = 4
vocab_size = 100

# Generate random logits
logits = torch.randn(batch_size, vocab_size)

# Create multi-hot targets (T-FREE style)
targets = torch.zeros(batch_size, vocab_size)
for i in range(batch_size):
    # Each word activates 3-8 trigrams
    num_active = np.random.randint(3, 9)
    active_indices = np.random.choice(vocab_size, num_active, replace=False)
    targets[i, active_indices] = 1

# Calculate losses
bce_loss_fn = MultiLabelBCELoss()
ce_loss_fn = SingleLabelCELoss()

bce_loss = bce_loss_fn(logits, targets)
ce_loss = ce_loss_fn(logits, targets)

print(f"Multi-label BCE Loss: {bce_loss.item():.4f}")
print(f"Single-label CE Loss: {ce_loss.item():.4f}")

# Visualize the target distribution
plt.figure(figsize=(12, 8))
sns.heatmap(targets.numpy(), cmap='Blues', cbar_kws={'label': 'Activation'})
plt.title('Multi-Hot Target Encoding for T-FREE (4 words)', fontsize=14)
plt.xlabel('Trigram Hash Index')
plt.ylabel('Word Sample')
plt.show()
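
The two printed losses are not on the same scale, which is easy to see from their values at an uninformed starting point. As a rough sketch in plain Python (helper names are hypothetical): with all logits at zero, mean BCE is $\log 2$ whatever the targets are, while cross-entropy is $\log |V|$.

```python
import math

def bce_at_zero_logits(multi_hot):
    """Mean BCE when every logit is 0, i.e. sigmoid(0) = 0.5 everywhere."""
    p = 0.5
    terms = [-(y * math.log(p) + (1 - y) * math.log(1 - p)) for y in multi_hot]
    return sum(terms) / len(terms)

def ce_at_zero_logits(vocab_size):
    """CE when every logit is 0, i.e. softmax is uniform: -log(1 / |V|)."""
    return -math.log(1.0 / vocab_size)

vocab_size = 100
targets = [0] * vocab_size
for idx in (3, 17, 42, 88):     # arbitrary active positions
    targets[idx] = 1

bce_baseline = bce_at_zero_logits(targets)
ce_baseline = ce_at_zero_logits(vocab_size)
print(f"BCE baseline: {bce_baseline:.4f}  (log 2 = {math.log(2):.4f})")
print(f"CE baseline:  {ce_baseline:.4f}  (log |V| = {math.log(vocab_size):.4f})")
```

Comparing the raw BCE and CE numbers directly is therefore not meaningful; each should be read against its own baseline.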

## 5. Gradient Analysis

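One key difference between the two losses is which positions receive a gradient that raises their logit: multi-label BCE pushes up every active position, whereas cross-entropy pushes up only the single target position: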

In [None]:
class GradientAnalyzer:
    """Analyze gradient behavior for different loss functions."""
    
    def __init__(self, vocab_size: int = 100):
        self.vocab_size = vocab_size
        
    def compute_gradients(self, logits: torch.Tensor, targets: torch.Tensor, 
                         loss_type: str = 'bce') -> torch.Tensor:
        """Compute gradients with respect to logits."""
        logits = logits.clone().requires_grad_(True)
        
        if loss_type == 'bce':
            loss = F.binary_cross_entropy_with_logits(logits, targets)
        else:  # ce
            single_targets = targets.argmax(dim=-1)
            loss = F.cross_entropy(logits, single_targets)
            
        loss.backward()
        return logits.grad
    
    def visualize_gradient_distribution(self):
        """Compare gradient distributions between BCE and CE losses."""
        # Create sample data
        logits = torch.randn(1, self.vocab_size)
        
        # Multi-hot target with 5 active positions
        targets = torch.zeros(1, self.vocab_size)
        active_positions = [10, 25, 40, 55, 70]
        targets[0, active_positions] = 1
        
        # Compute gradients
        bce_grads = self.compute_gradients(logits, targets, 'bce')
        ce_grads = self.compute_gradients(logits.clone(), targets, 'ce')
        
        # Visualization
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
        
        # BCE gradients
        ax1.bar(range(self.vocab_size), bce_grads[0].numpy(), 
                color='green', alpha=0.7)
        ax1.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
        for pos in active_positions:
            ax1.axvline(x=pos, color='red', linestyle='--', alpha=0.5)
        ax1.set_title('Multi-Label BCE Gradients', fontsize=14)
        ax1.set_ylabel('Gradient (signed)')
        ax1.grid(axis='y', alpha=0.3)
        
        # CE gradients
        ax2.bar(range(self.vocab_size), ce_grads[0].numpy(), 
                color='blue', alpha=0.7)
        ax2.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
        ax2.axvline(x=active_positions[0], color='red', linestyle='--', alpha=0.5)
        ax2.set_title('Single-Label CE Gradients', fontsize=14)
        ax2.set_xlabel('Vocabulary Index')
        ax2.set_ylabel('Gradient (signed)')
        ax2.grid(axis='y', alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        return bce_grads, ce_grads

# Analyze gradients
analyzer = GradientAnalyzer()
bce_grads, ce_grads = analyzer.visualize_gradient_distribution()

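print("\nGradient Statistics:")
# Both losses give a non-zero gradient at every position, so count instead the
# positions with a negative gradient (the logits that gradient descent raises).
print(f"BCE - Positions pushed up: {(bce_grads < 0).sum().item()}")
print(f"CE - Positions pushed up: {(ce_grads < 0).sum().item()}")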
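
The gradient signs can also be derived by hand. For mean BCE the gradient with respect to logit $i$ is $(\sigma(z_i) - y_i)/N$, and for cross-entropy it is $\mathrm{softmax}(z)_i - \mathbb{1}[i = \text{target}]$. The sketch below applies both formulas in plain Python to illustrative values; the helper names are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def bce_grad(logits, multi_hot):
    """d(mean BCE)/dz_i = (sigmoid(z_i) - y_i) / N."""
    n = len(logits)
    return [(sigmoid(z) - y) / n for z, y in zip(logits, multi_hot)]

def ce_grad(logits, target_index):
    """d(CE)/dz_i = softmax(z)_i - 1[i == target]."""
    probs = softmax(logits)
    return [p - (1.0 if i == target_index else 0.0) for i, p in enumerate(probs)]

logits = [0.3, -1.2, 0.8, 0.0, -0.5, 1.1]   # illustrative values
multi_hot = [0, 1, 0, 1, 0, 1]               # three active entries
first_active = multi_hot.index(1)            # the index argmax would pick for CE

g_bce = bce_grad(logits, multi_hot)
g_ce = ce_grad(logits, first_active)

# A negative gradient means gradient descent raises that logit
pushed_up_bce = [i for i, g in enumerate(g_bce) if g < 0]
pushed_up_ce = [i for i, g in enumerate(g_ce) if g < 0]
print("BCE raises logits at:", pushed_up_bce)
print("CE raises logits at: ", pushed_up_ce)
```

BCE raises all three active entries, while cross-entropy raises only the first one and lowers the other two along with every inactive entry.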

## 6. Morphological Similarity and Loss

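A key insight from the T-FREE paper is that morphologically similar words share trigrams, so their multi-hot targets overlap. The cell below measures that overlap (Jaccard similarity) within three word families: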

In [None]:
class MorphologicalLossAnalyzer:
    """Analyze how multi-label BCE handles morphologically similar words."""
    
    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size
        
    def extract_trigrams(self, word: str) -> List[str]:
        """Extract trigrams from a word (T-FREE style)."""
        padded_word = f"_{word}_"
        trigrams = []
        for i in range(len(padded_word) - 2):
            trigrams.append(padded_word[i:i+3])
        return trigrams
    
    def trigrams_to_indices(self, trigrams: List[str]) -> List[int]:
        """Convert trigrams to vocabulary indices using hash."""
        indices = []
        for trigram in trigrams:
            # Simple hash function for demonstration (order-insensitive, so
            # anagram trigrams such as 'nni' and 'nin' collide)
            hash_value = sum(ord(c) for c in trigram)
            index = hash_value % self.vocab_size
            indices.append(index)
        return list(set(indices))  # Remove duplicates
    
    def create_target_vector(self, word: str) -> torch.Tensor:
        """Create multi-hot target vector for a word."""
        trigrams = self.extract_trigrams(word)
        indices = self.trigrams_to_indices(trigrams)
        
        target = torch.zeros(self.vocab_size)
        target[indices] = 1
        return target
    
    def analyze_morphological_similarity(self):
        """Plot pairwise trigram-target overlap within families of related words."""
        # Word families to analyze
        word_families = [
            ['run', 'runs', 'running', 'runner'],
            ['happy', 'happier', 'happiest', 'happiness'],
            ['compute', 'computer', 'computing', 'computation']
        ]
        
        fig, axes = plt.subplots(3, 1, figsize=(14, 12))
        
        for idx, family in enumerate(word_families):
            # Create target vectors
            targets = [self.create_target_vector(word) for word in family]
            
            # Jaccard overlap: |intersection| / |union| of active target indices
            overlap_matrix = torch.zeros(len(family), len(family))
            for i in range(len(family)):
                for j in range(len(family)):
                    overlap = (targets[i] * targets[j]).sum()
                    total = targets[i].sum() + targets[j].sum() - overlap
                    overlap_matrix[i, j] = overlap / total if total > 0 else 0
            
            # Visualize
            im = axes[idx].imshow(overlap_matrix.numpy(), cmap='YlOrRd', 
                                  vmin=0, vmax=1, aspect='auto')
            axes[idx].set_xticks(range(len(family)))
            axes[idx].set_yticks(range(len(family)))
            axes[idx].set_xticklabels(family, rotation=45)
            axes[idx].set_yticklabels(family)
            axes[idx].set_title(f'Trigram Overlap: {" → ".join(family)}', 
                               fontsize=12)
            
            # Add text annotations
            for i in range(len(family)):
                for j in range(len(family)):
                    axes[idx].text(j, i, f'{overlap_matrix[i, j]:.2f}',
                                  ha='center', va='center')
            
            # Colorbar
            plt.colorbar(im, ax=axes[idx], fraction=0.046, pad=0.04)
        
        plt.tight_layout()
        plt.show()
        
        return word_families

# Analyze morphological similarities
morph_analyzer = MorphologicalLossAnalyzer()
word_families = morph_analyzer.analyze_morphological_similarity()
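
The overlap matrices describe the targets only. To connect them to the loss, the sketch below scores one fixed prediction against three different targets. It is a simplified illustration in plain Python: trigram strings are used directly instead of hashed indices (to avoid the hash collisions of the demo hash above), the vocabulary contains only the trigrams of the three example words, and the logit magnitude of 4 is arbitrary.

```python
import math

def trigrams(word):
    """Character trigrams of the padded word, as in extract_trigrams above."""
    padded = f"_{word}_"
    return {padded[i:i + 3] for i in range(len(padded) - 2)}

def bce(logits, targets):
    """Mean BCE with logits, using softplus(z) - y * z for numerical stability."""
    total = 0.0
    for z, y in zip(logits, targets):
        total += max(z, 0.0) + math.log1p(math.exp(-abs(z))) - y * z
    return total / len(logits)

words = ["running", "runner", "table"]
vocab = sorted(set().union(*(trigrams(w) for w in words)))

def multi_hot(word):
    active = trigrams(word)
    return [1 if t in active else 0 for t in vocab]

# A confident prediction of exactly the trigrams of "running"
predicted = [4.0 if y else -4.0 for y in multi_hot("running")]

losses = {w: bce(predicted, multi_hot(w)) for w in words}
for w, loss in losses.items():
    print(f"target={w:8s} BCE={loss:.4f}")
```

The prediction is penalized less when the target is the related form "runner" than when it is the unrelated word "table", because the shared trigrams are already predicted correctly.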

## 7. Practical Implementation Considerations

Let's explore practical aspects of implementing multi-label BCE loss in a T-FREE model:

In [None]:
class TFreeLanguageHead(nn.Module):
    """
    T-FREE language modeling head with multi-label BCE loss.
    
    This is a simplified sketch of a head trained with the loss above,
    not the paper's exact implementation.
    """
    
    def __init__(self, hidden_size: int, vocab_size: int, 
                 label_smoothing: float = 0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.label_smoothing = label_smoothing
        
        # Projection layer
        self.proj = nn.Linear(hidden_size, vocab_size)
        
        # Loss function
        self.loss_fn = MultiLabelBCELoss(
            reduction='mean', 
            label_smoothing=label_smoothing
        )
        
    def forward(self, hidden_states: torch.Tensor, 
                targets: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
            targets: [batch_size, seq_len, vocab_size] multi-hot encoded
        
        Returns:
            Dictionary with logits and optional loss
        """
        # Project to vocabulary size
        logits = self.proj(hidden_states)
        
        output = {'logits': logits}
        
        if targets is not None:
            # Reshape for loss calculation
            batch_size, seq_len, vocab_size = logits.shape
            logits_flat = logits.view(-1, vocab_size)
            targets_flat = targets.view(-1, vocab_size)
            
            # Calculate loss
            loss = self.loss_fn(logits_flat, targets_flat)
            output['loss'] = loss
            
            # Calculate accuracy metrics
            with torch.no_grad():
                predictions = torch.sigmoid(logits_flat) > 0.5
                correct = (predictions == targets_flat).float()
                
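                # Per-position accuracy (dominated by inactive entries when
                # targets are sparse, so it stays high even for poor predictions)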
                accuracy = correct.mean()
                
                # Exact match accuracy (all trigrams correct)
                exact_match = (correct.sum(dim=1) == vocab_size).float().mean()
                
                output['accuracy'] = accuracy
                output['exact_match'] = exact_match
        
        return output


# Demonstrate usage
hidden_size = 768
vocab_size = 8000  # hashed-trigram vocabulary size used in this demo
batch_size = 2
seq_len = 10

# Create model head
lm_head = TFreeLanguageHead(hidden_size, vocab_size)

# Mock hidden states from transformer
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Create multi-hot targets
targets = torch.zeros(batch_size, seq_len, vocab_size)
for b in range(batch_size):
    for s in range(seq_len):
        # Each position activates 5-10 trigrams
        num_active = np.random.randint(5, 11)
        active_indices = np.random.choice(vocab_size, num_active, replace=False)
        targets[b, s, active_indices] = 1

# Forward pass
output = lm_head(hidden_states, targets)

print("T-FREE Language Head Output:")
print(f"Loss: {output['loss'].item():.4f}")
print(f"Accuracy: {output['accuracy'].item():.4f}")
print(f"Exact Match: {output['exact_match'].item():.4f}")
print(f"\nLogits shape: {output['logits'].shape}")
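
Per-position accuracy needs care with sparse targets. As a small sketch in plain Python (the active indices are arbitrary and `precision_recall` is a hypothetical helper), a degenerate predictor that never activates anything still reaches almost perfect accuracy, while precision and recall on the active entries expose the failure.

```python
def precision_recall(predicted, target):
    """Precision and recall over the active entries of two 0/1 lists."""
    tp = sum(1 for p, t in zip(predicted, target) if p and t)
    fp = sum(1 for p, t in zip(predicted, target) if p and not t)
    fn = sum(1 for p, t in zip(predicted, target) if not p and t)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

vocab_size = 8000
target = [0] * vocab_size
for idx in (5, 120, 999, 4000, 7999):   # five active trigram slots (arbitrary)
    target[idx] = 1

all_zero = [0] * vocab_size              # degenerate predictor: nothing active
accuracy = sum(p == t for p, t in zip(all_zero, target)) / vocab_size
precision, recall = precision_recall(all_zero, target)

print(f"accuracy={accuracy:.6f} precision={precision:.2f} recall={recall:.2f}")
```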

## 8. Loss Landscape Visualization

Let's visualize how the loss landscape differs between single-label and multi-label approaches:

In [None]:
def visualize_loss_landscape():
    """Compare loss landscapes for single vs multi-label approaches."""
    
    # Create a simple 2D parameter space
    param_range = np.linspace(-2, 2, 50)
    X, Y = np.meshgrid(param_range, param_range)
    
    # Fixed target (multi-hot)
    vocab_size = 10
    target = torch.zeros(1, vocab_size)
    target[0, [1, 3, 5, 7]] = 1
    
    # Calculate losses for different parameter values
    bce_losses = np.zeros_like(X)
    ce_losses = np.zeros_like(X)
    
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            # Create logits based on parameters
            logits = torch.zeros(1, vocab_size)
            logits[0, [1, 3, 5, 7]] = torch.tensor([X[i, j], Y[i, j], 
                                                    -X[i, j], -Y[i, j]])
            
            # BCE loss
            bce_loss = F.binary_cross_entropy_with_logits(logits, target)
            bce_losses[i, j] = bce_loss.item()
            
            # CE loss (using first active position as target)
            ce_loss = F.cross_entropy(logits, torch.tensor([1]))
            ce_losses[i, j] = ce_loss.item()
    
    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # BCE loss landscape
    contour1 = ax1.contourf(X, Y, bce_losses, levels=20, cmap='viridis')
    ax1.set_title('Multi-Label BCE Loss Landscape', fontsize=14)
    ax1.set_xlabel('Parameter 1')
    ax1.set_ylabel('Parameter 2')
    plt.colorbar(contour1, ax=ax1, label='Loss')
    
    # CE loss landscape
    contour2 = ax2.contourf(X, Y, ce_losses, levels=20, cmap='plasma')
    ax2.set_title('Single-Label CE Loss Landscape', fontsize=14)
    ax2.set_xlabel('Parameter 1')
    ax2.set_ylabel('Parameter 2')
    plt.colorbar(contour2, ax=ax2, label='Loss')
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print(f"BCE Loss - Min: {bce_losses.min():.4f}, Max: {bce_losses.max():.4f}")
    print(f"CE Loss - Min: {ce_losses.min():.4f}, Max: {ce_losses.max():.4f}")

visualize_loss_landscape()

## 9. Performance Implications

The paper reports a large parameter reduction in the embedding and head layers. Let's estimate the head-layer size for a 128k-token vocabulary versus an 8k-entry T-FREE vocabulary:

In [None]:
class PerformanceAnalyzer:
    """Compare head-layer parameter count and memory for two vocabulary sizes."""
    
    def compare_memory_usage(self):
        """Compare memory usage between approaches."""
        # Model configurations
        traditional_vocab = 128000  # approximately the Llama-3 vocabulary size
        tfree_vocab = 8000  # T-FREE vocabulary
        hidden_size = 4096
        
        # Calculate parameter counts
        traditional_params = traditional_vocab * hidden_size
        tfree_params = tfree_vocab * hidden_size
        
        # Memory in MB (assuming float32)
        traditional_memory = (traditional_params * 4) / (1024 ** 2)
        tfree_memory = (tfree_params * 4) / (1024 ** 2)
        
        # Visualization
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
        
        # Parameter comparison
        models = ['Traditional\n(128k vocab)', 'T-FREE\n(8k vocab)']
        params = [traditional_params / 1e6, tfree_params / 1e6]
        
        bars1 = ax1.bar(models, params, color=['blue', 'green'], alpha=0.7)
        ax1.set_ylabel('Parameters (Millions)')
        ax1.set_title('Head Layer Parameter Count', fontsize=14)
        
        # Add value labels
        for bar, param in zip(bars1, params):
            ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 5,
                    f'{param:.1f}M', ha='center', va='bottom')
        
        # Memory comparison
        memory = [traditional_memory, tfree_memory]
        bars2 = ax2.bar(models, memory, color=['blue', 'green'], alpha=0.7)
        ax2.set_ylabel('Memory (MB)')
        ax2.set_title('Head Layer Memory Usage', fontsize=14)
        
        # Add value labels
        for bar, mem in zip(bars2, memory):
            ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 20,
                    f'{mem:.1f} MB', ha='center', va='bottom')
        
        plt.tight_layout()
        plt.show()
        
        # Calculate reduction
        param_reduction = (1 - tfree_params / traditional_params) * 100
        print(f"\nParameter Reduction: {param_reduction:.1f}%")
        print(f"Memory Savings: {traditional_memory - tfree_memory:.1f} MB")
        
        return param_reduction

# Analyze performance
perf_analyzer = PerformanceAnalyzer()
reduction = perf_analyzer.compare_memory_usage()

## 10. Advanced Multi-Label Techniques

Let's explore general-purpose multi-label techniques (focal weighting and positive-class weighting) that could be applied to T-FREE:

In [None]:
class AdvancedMultiLabelBCE(nn.Module):
    """
    Advanced multi-label BCE with focal loss and class weighting.
    
    These techniques can help with imbalanced trigram distributions.
    """
    
    def __init__(self, alpha: float = 0.25, gamma: float = 2.0, 
                 pos_weight: torch.Tensor = None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.pos_weight = pos_weight
        
    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Focal loss for multi-label classification.
        
        Focal loss helps focus on hard examples by down-weighting
        easy examples.
        """
        # Calculate probabilities
        probs = torch.sigmoid(logits)
        
        # Basic BCE loss
        if self.pos_weight is not None:
            bce_loss = F.binary_cross_entropy_with_logits(
                logits, targets, pos_weight=self.pos_weight, reduction='none'
            )
        else:
            bce_loss = F.binary_cross_entropy_with_logits(
                logits, targets, reduction='none'
            )
        
        # Calculate focal weight
        pt = torch.where(targets == 1, probs, 1 - probs)
        focal_weight = (1 - pt) ** self.gamma
        
        # Apply alpha weighting
        if self.alpha is not None:
            alpha_t = torch.where(targets == 1, self.alpha, 1 - self.alpha)
            focal_weight = alpha_t * focal_weight
        
        # Combine
        focal_loss = focal_weight * bce_loss
        
        return focal_loss.mean()


# Demonstrate advanced loss behavior
def demonstrate_focal_loss():
    """Show how focal loss affects training dynamics."""
    vocab_size = 100
    batch_size = 4
    
    # Create imbalanced targets (some trigrams are rare)
    targets = torch.zeros(batch_size, vocab_size)
    # Common trigrams
    targets[:, :20] = torch.bernoulli(torch.ones(batch_size, 20) * 0.7)
    # Rare trigrams
    targets[:, 80:] = torch.bernoulli(torch.ones(batch_size, 20) * 0.1)
    
    # Random logits
    logits = torch.randn(batch_size, vocab_size)
    
    # Calculate different losses
    standard_bce = MultiLabelBCELoss()
    focal_bce = AdvancedMultiLabelBCE(alpha=0.25, gamma=2.0)
    
    standard_loss = standard_bce(logits, targets)
    focal_loss = focal_bce(logits, targets)
    
    print("Loss Comparison:")
    print(f"Standard BCE: {standard_loss.item():.4f}")
    print(f"Focal BCE: {focal_loss.item():.4f}")
    
    # Visualize per-position losses
    with torch.no_grad():
        standard_losses = F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none'
        ).mean(dim=0)
        
        probs = torch.sigmoid(logits)
        pt = torch.where(targets == 1, probs, 1 - probs)
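        # Match AdvancedMultiLabelBCE(alpha=0.25, gamma=2.0): alpha on positives,
        # 1 - alpha on negatives, so the plot shows the same loss as printed above
        alpha_t = 0.25 * targets + 0.75 * (1 - targets)
        focal_weight = alpha_t * (1 - pt) ** 2.0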
        focal_losses = (focal_weight * F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none'
        )).mean(dim=0)
    
    plt.figure(figsize=(14, 6))
    x = np.arange(vocab_size)
    
    plt.bar(x - 0.2, standard_losses.numpy(), width=0.4, 
            label='Standard BCE', alpha=0.7)
    plt.bar(x + 0.2, focal_losses.numpy(), width=0.4, 
            label='Focal BCE', alpha=0.7)
    
    plt.axvline(x=20, color='red', linestyle='--', alpha=0.5, 
                label='Common/Rare boundary')
    plt.axvline(x=80, color='red', linestyle='--', alpha=0.5)
    
    plt.xlabel('Vocabulary Index')
    plt.ylabel('Average Loss')
    plt.title('Per-Position Loss: Standard vs Focal BCE', fontsize=14)
    plt.legend()
    plt.grid(axis='y', alpha=0.3)
    plt.show()

demonstrate_focal_loss()
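
A useful sanity check for any focal variant is that it falls back to ordinary BCE when the focusing term is switched off. The sketch below re-implements the weighting in plain Python under that assumption; `focal_bce` is a hypothetical helper that mirrors the logic of `AdvancedMultiLabelBCE` for 0/1 targets, and the values are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def focal_bce(logits, targets, gamma=2.0, alpha=None):
    """Mean focal-weighted BCE over 0/1 targets (hypothetical helper)."""
    total = 0.0
    for z, y in zip(logits, targets):
        p = sigmoid(z)
        pt = p if y == 1 else 1 - p          # probability given to the true label
        weight = (1 - pt) ** gamma           # down-weights well-classified entries
        if alpha is not None:
            weight *= alpha if y == 1 else 1 - alpha
        total += weight * -math.log(pt)
    return total / len(logits)

logits = [2.5, -3.0, 0.2, -0.1]   # illustrative values
targets = [1, 0, 1, 0]

plain = focal_bce(logits, targets, gamma=0.0)   # gamma=0, no alpha: ordinary BCE
focal = focal_bce(logits, targets, gamma=2.0)
print(f"gamma=0: {plain:.4f}  gamma=2: {focal:.4f}")
```

Because the focal weight never exceeds 1, the focal value is always at most the plain BCE value, which is why the two printed losses in the cell above should not be compared as if they were on the same scale.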

## Summary and Key Insights

### 1. **Multi-Label Nature of T-FREE**
- Each word activates multiple trigrams simultaneously
- Binary Cross-Entropy (BCE) naturally handles this multi-hot encoding
- Traditional Cross-Entropy would lose information by forcing single-label prediction

### 2. **Gradient Distribution Benefits**
- BCE distributes gradients across all active trigrams
- This may lead to more stable training dynamics, although this notebook does not measure it
- Morphologically similar words share gradients through common trigrams

### 3. **Memory Efficiency**
- In this notebook's comparison, the vocabulary shrinks from 128k tokens to 8k hashed-trigram entries
- The paper reports a parameter reduction of >85% in embedding/head layers
- The paper attributes the retained performance to sparse representations

### 4. **Morphological Awareness**
- Words with similar spellings naturally share trigrams
- This creates implicit morphological understanding
- No need to learn separate embeddings for each word variant

### 5. **Practical Considerations**
- Label smoothing can help with overfitting
- Focal loss variants can handle imbalanced trigram distributions
- Exact match accuracy is a useful metric alongside per-position accuracy

## References

Deiseroth, B., Brack, M., Schramowski, P., Kersting, K., & Weinbach, S. (2025). T-FREE: Subword Tokenizer-Free Generative LLMs via Sparse Representations for Memory-Efficient Embeddings. *arXiv preprint arXiv:2406.19223v2*.