# T-FREE: Subword Tokenizer-Free Generative LLMs via Sparse Representations for Memory-Efficient Embeddings

## Paper Information
- **Title:** T-FREE: Subword Tokenizer-Free Generative LLMs via Sparse Representations for Memory-Efficient Embeddings
- **Authors:** Björn Deiseroth, Manuel Brack, Patrick Schramowski, Kristian Kersting, Samuel Weinbach
- **Affiliations:** Aleph Alpha @ IPAI, Technical University Darmstadt, Hessian Center for AI (hessian.AI), German Research Center for AI (DFKI)
- **Paper Link:** [arXiv:2406.19223v2](https://arxiv.org/abs/2406.19223v2)
- **GitHub:** https://github.com/Aleph-Alpha/trigrams

## Paper Summary

T-FREE proposes a paradigm shift in how Large Language Models (LLMs) embed and decode text. Unlike traditional subword tokenizers (like BPE or Unigram) that suffer from computational overhead, ineffective vocabulary use, and large embedding layers, T-FREE directly embeds words through sparse activation patterns over character triplets (trigrams).

Key innovations:
- **No reference corpus needed:** T-FREE doesn't require tokenizer training on a specific corpus
- **Morphological similarity exploitation:** Similar words share embedding components through overlapping trigrams
- **Parameter reduction:** Achieves >85% reduction in embedding layer parameters
- **Cross-lingual performance:** Shows significant improvements in transfer learning across languages
- **Memory efficiency:** Vocabulary size reduced by 56% compared to standard tokenizers

The paper addresses three fundamental flaws of traditional tokenizers:
1. **F1:** Large vocabularies leading to massive embedding/head layers
2. **F2:** Duplicate tokens differing only in capitalization or whitespace
3. **F3:** Training data overfitting and poor cross-lingual performance
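Here is a minimal sketch that makes F2 concrete. It groups the entries of a subword vocabulary that differ only in capitalization or in a leading-space marker. It assumes a GPT-2-style byte-level BPE convention, where a leading space is rendered as `Ġ`. The toy vocabulary is hypothetical, and other tokenizers use different markers (SentencePiece, for example, uses `▁`).

```python
from collections import defaultdict


def find_near_duplicates(vocab, space_marker="Ġ"):
    """Group entries that differ only in case or a leading space marker.

    `space_marker` assumes the GPT-2 byte-level BPE convention; adjust it
    for other tokenizers.
    """
    groups = defaultdict(list)
    for token in vocab:
        # Normalize: drop the leading-space marker, then lowercase
        key = token[len(space_marker):] if token.startswith(space_marker) else token
        groups[key.lower()].append(token)
    # Keep only normalized forms that are backed by more than one entry
    return {key: toks for key, toks in groups.items() if len(toks) > 1}


toy_vocab = ["hello", "Hello", "Ġhello", "ĠHello", "world", "Ġworld", "the"]
dupes = find_near_duplicates(toy_vocab)
redundant = sum(len(toks) - 1 for toks in dupes.values())
print(dupes)      # {'hello': ['hello', 'Hello', 'Ġhello', 'ĠHello'], 'world': ['world', 'Ġworld']}
print(redundant)  # 4 entries (and 4 embedding rows) that T-FREE would not need
```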

## Environment Setup

Let's set up the necessary environment for implementing T-FREE. We'll use PyTorch for the neural network components and implement the trigram-based encoding system.

In [None]:
# Install necessary libraries
!pip install torch numpy scipy transformers langchain langchain-community deepeval
!pip install matplotlib seaborn tqdm xxhash

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import xxhash
from typing import List, Tuple, Set, Dict
import re
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Core T-FREE Implementation

### Step 1: Word Splitting

Following the paper (Section 3.1, Step 1), we rigorously split text by digits and non-alphanumeric characters. Each word is represented with prefixed and suffixed whitespace.
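The `apply_whitespace_rules` method below infers whitespace from token types, so the original spacing cannot be recovered from its output. The following sketch is an illustration only and rests on three assumptions. First, one space follows every split by default. Second, `<non-whitespace>` overrides that default. Third, each additional space is emitted as an explicit `<whitespace>` token. This is our reading of how the two special tokens could be used, not a verified reproduction of the paper's rules. The sketch also assumes plain spaces and no leading whitespace.

```python
import re

# Special token names mirror those in TFreeTokenizer below
WS, NO_WS = "<whitespace>", "<non-whitespace>"


def split_with_default_whitespace(text):
    """Split into words, single digits and single symbols with explicit spacing.

    Assumed convention: one space after each split is implicit; NO_WS marks
    a split with no following space, and each extra space becomes WS.
    """
    tokens = []
    # Letters-only words | one digit | one non-word, non-space symbol; then trailing spaces
    for piece, spaces in re.findall(r"([^\W\d]+|\d|[^\w ])( *)", text):
        tokens.append(piece)
        if spaces == "":
            tokens.append(NO_WS)                    # override the implicit space
        else:
            tokens.extend([WS] * (len(spaces) - 1))  # first space stays implicit
    return tokens


def join_tokens(tokens):
    """Invert split_with_default_whitespace."""
    out = []
    for tok in tokens:
        if tok == NO_WS:
            out.pop()                # remove the implicit space added after the last piece
        elif tok == WS:
            out.append(" ")
        else:
            out.extend([tok, " "])   # piece followed by its implicit space
    return "".join(out)


text = "Hello  world! T-FREE is 85% more efficient."
tokens = split_with_default_whitespace(text)
print(tokens[:6])                 # ['Hello', '<whitespace>', 'world', '<non-whitespace>', '!', 'T']
print(join_tokens(tokens) == text)  # True: the spacing survives the round trip
```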

In [None]:
class TFreeTokenizer:
    """T-FREE Tokenizer implementation based on the paper.
    
    As described in Section 3.1 of the paper:
    'First, we rigorously split the text by digits and non-alphanumeric characters.'
    """
    
    def __init__(self):
        # Special tokens as mentioned in the paper
        self.special_tokens = {
            '<whitespace>': 0,
            '<non-whitespace>': 1,
            '<unk>': 2
        }
        
        # Digits are handled separately (Section 3.1)
        self.digits = {str(i): i + 3 for i in range(10)}
        
    def split_text(self, text: str) -> List[str]:
        """Split text following T-FREE rules.
        
        From the paper: 'The resulting splits, therefore, contain entire words,
        digits, or special characters.'
        """
        # Pattern to split by non-alphanumeric characters and digits
        pattern = r'(\W|\d)'
        tokens = re.split(pattern, text)
        
        # Filter empty tokens and process
        processed_tokens = []
        for token in tokens:
            if token and not token.isspace():
                if token.isdigit():
                    # Each digit is separate (Section 3.1)
                    processed_tokens.extend(list(token))
                else:
                    processed_tokens.append(token)
                    
        return processed_tokens
    
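    def apply_whitespace_rules(self, tokens: List[str]) -> List[str]:
        """Apply whitespace rules as described in the paper.
        
        From Section 3.1: 'Generally, we prefer to add no whitespace after 
        a digit embedding and similarly no whitespace before punctuation.'
        Simplification: spacing is inferred from token types rather than read
        from the original text, so the original spacing is not recoverable.
        """
        # Set lookups (a substring test on a string would also match multi-char tokens)
        closing = {'.', ',', '!', '?', ';', ':', ')', ']', '}', "'", '"'}
        opening = {'(', '[', '{', "'", '"'}
        
        processed = []
        for i, token in enumerate(tokens):
            processed.append(token)
            
            # Decide whether a whitespace token follows this token
            if i < len(tokens) - 1:
                next_token = tokens[i + 1]
                
                # No whitespace after digits, before closing punctuation,
                # or after opening brackets/quotes
                if not (token.isdigit() or next_token in closing or token in opening):
                    processed.append('<whitespace>')
                        
        return processed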

# Test the tokenizer
tokenizer = TFreeTokenizer()
test_text = "Hello world! T-FREE is 85% more efficient."
tokens = tokenizer.split_text(test_text)
print(f"Original text: {test_text}")
print(f"Tokens: {tokens}")
tokens_with_ws = tokenizer.apply_whitespace_rules(tokens)
print(f"With whitespace rules: {tokens_with_ws}")

### Step 2: Trigram Encoding

Following Section 3.1 Step 2, we encode each word into character triplets (trigrams) using convolutions of size three.
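The paper describes "convolutions of size three and byte-wise stride". The `extract_trigrams` method below instead slides over Python characters (Unicode code points). The two agree for ASCII words and differ as soon as a character needs more than one UTF-8 byte. A small sketch of both windowings, using the same `_word_` padding, shows the difference. Which variant the reference implementation uses is not shown in this notebook.

```python
def char_trigrams(word):
    """Windows of 3 code points, stride 1, over the padded word."""
    padded = f"_{word}_"
    return [padded[i:i + 3] for i in range(len(padded) - 2)]


def byte_trigrams(word):
    """Windows of 3 UTF-8 bytes, stride of one byte, over the padded word."""
    padded = f"_{word}_".encode("utf-8")
    return [padded[i:i + 3] for i in range(len(padded) - 2)]


for w in ["Hello", "über"]:
    print(w, len(char_trigrams(w)), len(byte_trigrams(w)))
# Hello 5 5
# über 4 5   ('ü' is two bytes in UTF-8)
```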

In [None]:
class TrigramEncoder:
    """Encode words into trigrams as described in Section 3.1 Step 2.
    
    'Specifically, we apply convolutions of size three and byte-wise stride 
    to each word. This operation yields a set of character triplets, 
    which we refer to as "trigrams".'
    """
    
    def __init__(self, vocab_size: int = 16384, num_hashes: int = 8, lowercase_ratio: float = 0.5):
        self.vocab_size = vocab_size
        self.num_hashes = num_hashes  # m in the paper
        self.lowercase_ratio = lowercase_ratio  # k/m ratio
        self.num_lowercase = int(num_hashes * lowercase_ratio)
        
    def extract_trigrams(self, word: str) -> List[str]:
        """Extract trigrams from a word.
        
        Example from paper: 'Hello' -> {_He, Hel, ell, llo, lo_}
        """
        # Add whitespace prefix and suffix as per paper
        padded_word = f"_{word}_"
        
        trigrams = []
        for i in range(len(padded_word) - 2):
            trigrams.append(padded_word[i:i+3])
            
        return trigrams
    
    def hash_trigram(self, trigram: str, hash_idx: int) -> int:
        """Hash a trigram to get its vocabulary index.
        
        Uses xxhash (xxh32) as a fast, well-distributed hash; the hash index is
        mixed into the input so each of the m hashes acts as a separate function.
        """
        # Create unique hash by combining trigram and hash index
        hash_input = f"{trigram}_{hash_idx}".encode('utf-8')
        hash_value = xxhash.xxh32(hash_input).intdigest()
        
        # Map to vocabulary index using modulo
        return hash_value % self.vocab_size
    
    def encode_word(self, word: str) -> Tuple[List[int], int]:
        """Encode a word into sparse activation pattern.
        
        Returns tuple of (active_indices, num_activations)
        From paper: 'Overall, we obtain n·m total activations for any single word.'
        """
        trigrams = self.extract_trigrams(word)
        active_indices = []
        
        for trigram in trigrams:
            # Calculate m hashes for each trigram
            for hash_idx in range(self.num_hashes):
                # Use lowercase for first k hashes (Section 3.1)
                if hash_idx < self.num_lowercase:
                    trigram_to_hash = trigram.lower()
                else:
                    trigram_to_hash = trigram
                    
                vocab_idx = self.hash_trigram(trigram_to_hash, hash_idx)
                active_indices.append(vocab_idx)
                
        return active_indices, len(trigrams) * self.num_hashes

# Test trigram encoding
encoder = TrigramEncoder(vocab_size=8192, num_hashes=4)
test_words = ["Hello", "hello", "world", "words"]

for word in test_words:
    trigrams = encoder.extract_trigrams(word)
    indices, num_activations = encoder.encode_word(word)
    print(f"\nWord: '{word}'")
    print(f"Trigrams: {trigrams}")
    print(f"Active indices: {indices[:10]}... (total: {len(indices)})")
    print(f"Number of activations: {num_activations}")
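All m hashes of all n trigrams are reduced modulo v, so two activations of the same word can land on the same row. The sketch below gives a back-of-the-envelope estimate of how many distinct rows a word actually touches. It assumes the hashes behave like independent uniform draws, which idealises xxhash modulo v.

```python
def expected_distinct(num_activations: int, vocab_size: int) -> float:
    """Expected number of distinct rows hit by `num_activations` uniform hashes.

    Assumes every activation lands independently and uniformly in one of
    `vocab_size` rows (an idealisation of xxhash modulo v).
    """
    # Probability that a given row is missed by every activation
    miss = (1 - 1 / vocab_size) ** num_activations
    return vocab_size * (1 - miss)


n, m, v = 5, 4, 8192  # "Hello": 5 trigrams, 4 hashes, 8192 rows (as in the test above)
d = expected_distinct(n * m, v)
print(f"{d:.3f} distinct rows expected out of {n * m}")  # 19.977 distinct rows expected out of 20
```

For short words and v = 8192, collisions within a word are rare. Note that `encode_word` keeps duplicate indices, so a collided row is counted twice in the summed embedding.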

### Step 3: T-FREE Embedding Layer

Implementing the embedding aggregation as described in Section 3.1 Step 3.
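Summing the n·m selected rows is the same as multiplying a sparse count vector x (where x_i is the number of activations that hit row i) with the embedding matrix, e = xᵀW. The pure-Python sketch below checks this equivalence on a hypothetical 4 × 2 matrix. In PyTorch, the same reduction can also be written with `nn.EmbeddingBag(..., mode='sum')`, although the layer below uses `nn.Embedding` plus an explicit sum.

```python
def embed_by_lookup(weight, indices):
    """Sum the rows of `weight` selected by `indices` (duplicates counted twice)."""
    hidden = len(weight[0])
    out = [0.0] * hidden
    for idx in indices:
        for j in range(hidden):
            out[j] += weight[idx][j]
    return out


def embed_by_matmul(weight, indices):
    """Compute x^T W, where x[i] counts how often row i is active."""
    counts = [0] * len(weight)
    for idx in indices:
        counts[idx] += 1
    hidden = len(weight[0])
    return [sum(counts[i] * weight[i][j] for i in range(len(weight))) for j in range(hidden)]


W = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0], [0.5, 0.5]]  # toy v = 4, h = 2 embedding matrix
active = [2, 0, 2]  # toy activation pattern with a hash collision on row 2
print(embed_by_lookup(W, active))  # [7.0, 2.0]
print(embed_by_matmul(W, active))  # [7.0, 2.0]
```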

In [None]:
class TFreeEmbedding(nn.Module):
    """T-FREE embedding layer implementation.
    
    From Section 3.1 Step 3: 'Similar to classic embedding approaches T-FREE 
    also utilizes an embedding matrix of dimension v with hidden size h. 
    However, we do not have a fixed vocabulary, whose size dictates v.'
    """
    
    def __init__(self, vocab_size: int, hidden_size: int, num_hashes: int = 8, 
                 lowercase_ratio: float = 0.5):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.encoder = TrigramEncoder(vocab_size, num_hashes, lowercase_ratio)
        
        # Embedding matrix (v x h)
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        
        # Initialize embeddings
        nn.init.xavier_uniform_(self.embeddings.weight)
        
    def forward(self, words: List[str]) -> torch.Tensor:
        """Embed a list of words using T-FREE.
        
        From paper: 'Lastly, we sum all n·m embedding entries to produce 
        the final one embedding corresponding to a word.'
        """
        batch_embeddings = []
        
        for word in words:
            # Get active indices for the word
            active_indices, _ = self.encoder.encode_word(word)
            
            # Convert to tensor
            indices_tensor = torch.tensor(active_indices, device=self.embeddings.weight.device)
            
            # Lookup embeddings and sum
            word_embeddings = self.embeddings(indices_tensor)
            word_embedding = word_embeddings.sum(dim=0)
            
            batch_embeddings.append(word_embedding)
            
        return torch.stack(batch_embeddings)
    
    def visualize_overlap(self, words: List[str]):
        """Visualize the activation pattern overlap between words."""
        patterns = {}
        for word in words:
            indices, _ = self.encoder.encode_word(word)
            patterns[word] = set(indices)
            
        # Calculate overlap matrix
        overlap_matrix = np.zeros((len(words), len(words)))
        for i, w1 in enumerate(words):
            for j, w2 in enumerate(words):
                if patterns[w1] and patterns[w2]:
                    overlap = len(patterns[w1] & patterns[w2]) / len(patterns[w1] | patterns[w2])
                    overlap_matrix[i, j] = overlap
                    
        # Visualize
        plt.figure(figsize=(8, 6))
        sns.heatmap(overlap_matrix, annot=True, fmt='.2f', 
                    xticklabels=words, yticklabels=words, cmap='YlOrRd')
        plt.title('T-FREE Activation Pattern Overlap')
        plt.tight_layout()
        plt.show()

# Test T-FREE embedding
embedding_layer = TFreeEmbedding(vocab_size=8192, hidden_size=512, num_hashes=4)
test_words = ["Hello", "hello", "Hell", "World", "word", "words"]
embeddings = embedding_layer(test_words)
print(f"Embeddings shape: {embeddings.shape}")

# Visualize overlap between similar words
embedding_layer.visualize_overlap(test_words)

## T-FREE Language Model Head

Implementing the multi-label prediction head as described in Section 3.2.
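The multi-label BCE of Section 3.2 treats each of the v rows as an independent binary label. The stdlib sketch below is numerically stable: it rewrites log(1 - σ(z)) as log σ(-z), which avoids log(0) in the same spirit as PyTorch's `F.binary_cross_entropy_with_logits`. The logits and the active set are made up for illustration.

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def ml_bce(logits, active, mean=True):
    """Multi-label BCE over all vocabulary rows.

    `active` holds the rows hit by the next word's trigram hashes (y_j = 1);
    every other row is a negative (y_j = 0). Uses log(1 - sigmoid(z)) = log_sigmoid(-z).
    """
    total = 0.0
    for j, z in enumerate(logits):
        total -= log_sigmoid(z) if j in active else log_sigmoid(-z)
    return total / len(logits) if mean else total


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


logits = [2.0, -1.0, 0.5, -3.0]  # made-up logits for a v = 4 toy vocabulary
active = {0, 2}                  # rows hit by a hypothetical next word
naive = -sum(
    math.log(sigmoid(z)) if j in active else math.log(1.0 - sigmoid(z))
    for j, z in enumerate(logits)
) / len(logits)
stable = ml_bce(logits, active)
print(math.isclose(stable, naive, rel_tol=1e-12))  # True
print(ml_bce([200.0, -200.0], {1}))                # 200.0; the naive form would hit log(0)
```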

In [None]:
class TFreeLMHead(nn.Module):
    """T-FREE language model head for multi-label prediction.
    
    From Section 3.2: 'In particular, we change the target loss function 
    from classic single-label binary cross-entropy (BCE) to a multi-label (ML) 
    BCE over all n·m activations of the next word targets.'
    """
    
    def __init__(self, hidden_size: int, vocab_size: int, encoder: TrigramEncoder):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.encoder = encoder
        
        # Linear projection to vocabulary size
        self.output_projection = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits."""
        return self.output_projection(hidden_states)
    
    def compute_loss(self, logits: torch.Tensor, target_words: List[str]) -> torch.Tensor:
        """Compute multi-label BCE loss.
        
        Implements the loss function from Section 3.2,
        L_ML_BCE = -Σ_j [y_j log(ŷ_j) + (1-y_j) log(1-ŷ_j)] with ŷ = σ(logits),
        averaged (rather than summed) over batch and vocabulary entries.
        """
        batch_size = logits.shape[0]
        
        # Binary target matrix: y_j = 1 for every row activated by the target word
        targets = torch.zeros(batch_size, self.vocab_size, device=logits.device)
        for i, word in enumerate(target_words):
            active_indices, _ = self.encoder.encode_word(word)
            targets[i, active_indices] = 1.0
        
        # Fused sigmoid + BCE is numerically stable (no log(0) for saturated logits)
        return F.binary_cross_entropy_with_logits(logits, targets)
    
    def decode_next_word(self, logits: torch.Tensor, dictionary: List[str], 
                         temperature: float = 1.0) -> Tuple[str, List[Tuple[str, float]]]:
        """Decode the next word using dictionary lookup.
        
        From Figure 2 and Section 3.2: 'The element-wise sigmoid values 
        of the output of the last hidden layer, σ(h), is multiplied with 
        this pattern matrix using standard dot product.'
        Returns the best word and the top-5 (word, score) pairs.
        """
        # Sigmoid of temperature-scaled logits for a single position (1 x v or v)
        probs = torch.sigmoid(logits / temperature).squeeze()
        
        # Score every dictionary word (patterns are recomputed on each call here)
        word_scores = []
        for word in dictionary:
            active_indices, _ = self.encoder.encode_word(word)
            
            # Mean activation = dot product with the word's pattern, normalized by its size
            word_prob = probs[active_indices].mean().item()
            word_scores.append((word, word_prob))
            
        # Sort by score, highest first
        word_scores.sort(key=lambda x: x[1], reverse=True)
        
        return word_scores[0][0], word_scores[:5]

# Test the LM head
lm_head = TFreeLMHead(hidden_size=512, vocab_size=8192, encoder=encoder)

# Simulate hidden states
batch_size = 2
hidden_states = torch.randn(batch_size, 512)
logits = lm_head(hidden_states)
print(f"Logits shape: {logits.shape}")

# Test loss computation
target_words = ["world", "hello"]
loss = lm_head.compute_loss(logits, target_words)
print(f"Loss: {loss.item():.4f}")

# Test decoding
dictionary = ["hello", "world", "the", "a", "is", "are", "Hello", "World"]
next_word, top5 = lm_head.decode_next_word(logits[0:1], dictionary)
print(f"\nPredicted next word: '{next_word}'")
print(f"Top 5 predictions: {top5}")
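The `decode_next_word` method re-encodes every dictionary word on each call. The patterns do not depend on the hidden state, so they can be computed once, as the pattern matrix of Figure 2, and reused at every decoding step. The minimal sketch below stores sparse rows. `toy_encode` is a hypothetical stand-in for `TrigramEncoder.encode_word`. Normalizing each row by its pattern size is an assumption chosen to mirror the `.mean()` above, here applied to the de-duplicated pattern.

```python
import math


def build_pattern_matrix(dictionary, encode):
    """Sparse, row-normalized patterns: row[w][j] = 1/|pattern(w)| for active j.

    Computed once per dictionary and reused for every decoding step.
    """
    rows = []
    for word in dictionary:
        active = set(encode(word))
        rows.append({j: 1.0 / len(active) for j in active})
    return rows


def score_words(rows, logits):
    """Score each word as the dot product of its pattern row with sigmoid(logits)."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [sum(w * probs[j] for j, w in row.items()) for row in rows]


def toy_encode(word):
    """Hypothetical encoder: one activation per character, 8 rows."""
    return [ord(c) % 8 for c in word]


dictionary = ["ab", "cd"]
rows = build_pattern_matrix(dictionary, toy_encode)  # built once
logits = [0.0, 5.0, 5.0, -5.0, -5.0, 0.0, 0.0, 0.0]   # made-up head output, v = 8
scores = score_words(rows, logits)
best = dictionary[max(range(len(scores)), key=scores.__getitem__)]
print(best, [round(s, 4) for s in scores])  # ab [0.9933, 0.0067]
```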

## Complete T-FREE Model

Now let's integrate all components into a complete T-FREE language model.

In [None]:
class TFreeLanguageModel(nn.Module):
    """Complete T-FREE language model implementation."""
    
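    def __init__(self, vocab_size: int = 8192, hidden_size: int = 512, 
                 num_layers: int = 6, num_heads: int = 8, 
                 num_hashes: int = 4, lowercase_ratio: float = 0.5,
                 max_positions: int = 512):
        super().__init__()
        
        # T-FREE components
        self.tokenizer = TFreeTokenizer()
        self.embedding = TFreeEmbedding(vocab_size, hidden_size, num_hashes, lowercase_ratio)
        
        # Learned absolute position embeddings: without them the transformer
        # is order-agnostic and sees the input as a bag of words
        self.max_positions = max_positions
        self.position_embedding = nn.Embedding(max_positions, hidden_size)
        
        # Transformer layers (simplified; run with a causal mask in forward)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=hidden_size * 4,
                batch_first=True
            ),
            num_layers=num_layers
        )
        
        # T-FREE LM head
        self.lm_head = TFreeLMHead(hidden_size, vocab_size, self.embedding.encoder)
        
    def forward(self, text: str) -> Tuple[torch.Tensor, List[str]]:
        """Forward pass; returns (logits of shape 1 x seq_len x vocab_size, tokens).
        
        Logits are None if the text contains no embeddable tokens.
        """
        # Tokenize
        tokens = self.tokenizer.split_text(text)
        tokens = self.tokenizer.apply_whitespace_rules(tokens)
        
        # Filter out special tokens for embedding
        word_tokens = [t for t in tokens if t not in ['<whitespace>', '<non-whitespace>']]
        
        if not word_tokens:
            return None, tokens
        
        # Keep only the most recent words if the input exceeds max_positions
        word_tokens = word_tokens[-self.max_positions:]
        seq_len = len(word_tokens)
            
        # Embed words (adding a batch dimension) and add positional information
        embeddings = self.embedding(word_tokens).unsqueeze(0)
        positions = torch.arange(seq_len, device=embeddings.device)
        embeddings = embeddings + self.position_embedding(positions)
        
        # Causal mask: position i may only attend to positions <= i
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=embeddings.device),
            diagonal=1
        )
        hidden_states = self.transformer(embeddings, mask=causal_mask)
        
        # Get logits from LM head
        logits = self.lm_head(hidden_states)
        
        return logits, tokens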
    
    @torch.no_grad()
    def generate(self, prompt: str, max_length: int = 50, temperature: float = 1.0) -> str:
        """Generate up to `max_length` words greedily (no gradients are tracked)."""
        # Simple dictionary for demo (in practice, this would be much larger)
        dictionary = ["the", "a", "is", "are", "was", "were", "hello", "world",
                     "language", "model", "T-FREE", "efficient", "sparse", "embedding",
                     "trigram", "tokenizer", "free", "method", "paper", "shows"]
        
        generated_text = prompt
        
        for _ in range(max_length):
            # Get model predictions
            logits, _ = self.forward(generated_text)
            
            if logits is None:
                break
                
            # Decode next word
            last_logits = logits[0, -1:, :]  # Get last position
            next_word, _ = self.lm_head.decode_next_word(last_logits, dictionary, temperature)
            
            # Add to generated text
            generated_text += " " + next_word
            
        return generated_text

# Create and test the model
model = TFreeLanguageModel(vocab_size=8192, hidden_size=256, num_layers=2)
print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

# Test forward pass
test_text = "T-FREE is a new"
logits, tokens = model.forward(test_text)
print(f"\nInput: '{test_text}'")
print(f"Tokens: {tokens}")
print(f"Output shape: {logits.shape if logits is not None else 'None'}")

# Test generation (with random weights, output will be random)
generated = model.generate("T-FREE is", max_length=10)
print(f"\nGenerated text: '{generated}'")

## Comparison with Traditional Tokenizers

Let's implement a comparison to demonstrate T-FREE's advantages over traditional tokenizers.

In [None]:
from transformers import AutoTokenizer

def compare_tokenizers(text: str, tfree_encoder: TrigramEncoder):
    """Compare T-FREE with traditional tokenizers."""
    # Load some popular tokenizers
    tokenizers = {
        'GPT-2': AutoTokenizer.from_pretrained('gpt2'),
        'BERT': AutoTokenizer.from_pretrained('bert-base-uncased'),
    }
    
    results = {}
    
    # Traditional tokenizers
    for name, tokenizer in tokenizers.items():
        tokens = tokenizer.tokenize(text)
        results[name] = {
            'tokens': tokens,
            'num_tokens': len(tokens),
            'vocab_size': tokenizer.vocab_size
        }
    
    # T-FREE
    tfree_tokenizer = TFreeTokenizer()
    words = tfree_tokenizer.split_text(text)
    total_activations = 0
    for word in words:
        if word.isalnum():
            _, num_act = tfree_encoder.encode_word(word)
            total_activations += num_act
    
    results['T-FREE'] = {
        'tokens': words,
        'num_tokens': len(words),
        'total_activations': total_activations,
        'vocab_size': tfree_encoder.vocab_size
    }
    
    return results

# Compare on sample texts
test_texts = [
    "Hello world!",
    "T-FREE reduces parameters by 85%.",
    "The quick brown fox jumps over the lazy dog.",
    "Parameter-efficient fine-tuning is important for LLMs."
]

encoder = TrigramEncoder(vocab_size=8192, num_hashes=4)

for text in test_texts:
    print(f"\nText: '{text}'")
    print("=" * 50)
    
    results = compare_tokenizers(text, encoder)
    
    for method, data in results.items():
        print(f"\n{method}:")
        print(f"  Tokens: {data['tokens']}")
        print(f"  Number of tokens: {data['num_tokens']}")
        print(f"  Vocabulary size: {data['vocab_size']:,}")
        if 'total_activations' in data:
            print(f"  Total activations: {data['total_activations']}")

## Analyzing Morphological Similarity Exploitation

Let's visualize how T-FREE exploits morphological similarities between words.
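The hashed overlaps in these heatmaps are driven by shared trigrams, plus the case-insensitive lowercase hashes and occasional hash collisions. The small sketch below computes the exact trigram Jaccard similarity before any hashing. It uses the same `_word_` padding as `extract_trigrams`.

```python
def trigrams(word):
    """Set of padded character trigrams, as in extract_trigrams."""
    padded = f"_{word}_"
    return {padded[i:i + 3] for i in range(len(padded) - 2)}


def trigram_jaccard(a, b):
    """|shared trigrams| / |all trigrams| for two words."""
    ta, tb = trigrams(a), trigrams(b)
    return len(ta & tb) / len(ta | tb)


print(sorted(trigrams("run") & trigrams("runs")))  # ['_ru', 'run']
print(trigram_jaccard("run", "runs"))              # 0.4
print(trigram_jaccard("run", "unhappy"))           # 0.0
```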

In [None]:
def analyze_morphological_similarity():
    """Analyze how T-FREE handles morphologically similar words."""
    encoder = TrigramEncoder(vocab_size=8192, num_hashes=4)
    
    # Groups of morphologically similar words
    word_groups = [
        ["run", "runs", "running", "runner"],
        ["happy", "happiness", "happily", "unhappy"],
        ["compute", "computer", "computation", "computational"],
        ["token", "tokens", "tokenize", "tokenizer"]
    ]
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    axes = axes.flatten()
    
    for idx, group in enumerate(word_groups):
        # Calculate activation patterns
        patterns = {}
        for word in group:
            indices, _ = encoder.encode_word(word)
            patterns[word] = set(indices)
        
        # Calculate overlap matrix
        n = len(group)
        overlap_matrix = np.zeros((n, n))
        
        for i, w1 in enumerate(group):
            for j, w2 in enumerate(group):
                if i == j:
                    overlap_matrix[i, j] = 1.0
                else:
                    intersection = len(patterns[w1] & patterns[w2])
                    union = len(patterns[w1] | patterns[w2])
                    overlap_matrix[i, j] = intersection / union if union > 0 else 0
        
        # Visualize
        sns.heatmap(overlap_matrix, annot=True, fmt='.2f', 
                    xticklabels=group, yticklabels=group, 
                    cmap='YlOrRd', ax=axes[idx], vmin=0, vmax=1)
        axes[idx].set_title(f'Morphological Similarity: {group[0]} family')
    
    plt.tight_layout()
    plt.show()
    
    # Analyze duplicate handling
    print("\nDuplicate Token Analysis (F2 from paper):")
    print("=" * 50)
    
    # Examples of duplicate tokens in traditional tokenizers
    duplicate_examples = [
        ("hello", "Hello"),
        ("world", "World"),
        ("the", "The"),
        (" world", "world"),  # with/without leading space
    ]
    
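    splitter = TFreeTokenizer()
    for w1, w2 in duplicate_examples:
        # Split first, as the full pipeline does: this strips the leading
        # space from " world", so both variants map to the same word
        p1 = set(encoder.encode_word(splitter.split_text(w1)[0])[0])
        p2 = set(encoder.encode_word(splitter.split_text(w2)[0])[0])
        overlap = len(p1 & p2) / len(p1 | p2) if p1 | p2 else 0
        
        print(f"'{w1}' vs '{w2}': {overlap:.2%} overlap")
        print("  Subword tokenizers: typically 2 separate vocabulary entries and embeddings")
        print(f"  T-FREE: Shares {len(p1 & p2)} activation indices")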
        print()

analyze_morphological_similarity()

## Cross-Lingual Performance Analysis

Demonstrating T-FREE's language-agnostic properties (addressing F3 from the paper).
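Note what the statistics below can and cannot show. With `_word_` padding, a word of L characters yields exactly L trigrams. Activations per word are therefore m·L, and the two bar charts differ only by the constant factor m. The paper's cross-lingual results come from training experiments, which this encoding-level demo does not reproduce. A quick check of the m·L relation, assuming m = 8 as in the cell below:

```python
def activations_per_word(word: str, num_hashes: int) -> int:
    """n * m, where n = len(word): padding to len + 2 leaves len(word) windows of size 3."""
    padded = f"_{word}_"
    n = len(padded) - 2
    return n * num_hashes


m = 8
counts = [activations_per_word(w, m) for w in "Der schnelle braune Fuchs".split()]
print(counts)                     # [24, 64, 48, 40]
print(sum(counts) / len(counts))  # 44.0 = m * 5.5 (mean word length)
```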

In [None]:
def analyze_cross_lingual_performance():
    """Analyze T-FREE's performance across different languages."""
    encoder = TrigramEncoder(vocab_size=16384, num_hashes=8)
    
    # Sample texts in different languages
    multilingual_texts = {
        'English': "The quick brown fox jumps over the lazy dog",
        'German': "Der schnelle braune Fuchs springt über den faulen Hund",
        'Spanish': "El rápido zorro marrón salta sobre el perro perezoso",
        'French': "Le renard brun rapide saute par-dessus le chien paresseux"
    }
    
    # Analyze encoding efficiency
    results = {}
    for lang, text in multilingual_texts.items():
        words = text.split()
        total_activations = 0
        total_trigrams = 0
        
        for word in words:
            indices, num_act = encoder.encode_word(word)
            total_activations += num_act
            total_trigrams += len(encoder.extract_trigrams(word))
        
        results[lang] = {
            'words': len(words),
            'avg_activations_per_word': total_activations / len(words),
            'avg_trigrams_per_word': total_trigrams / len(words),
            'text_length': len(text)
        }
    
    # Visualize results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    languages = list(results.keys())
    avg_activations = [results[lang]['avg_activations_per_word'] for lang in languages]
    avg_trigrams = [results[lang]['avg_trigrams_per_word'] for lang in languages]
    
    # Bar plot for average activations
    ax1.bar(languages, avg_activations, color='skyblue')
    ax1.set_ylabel('Average Activations per Word')
    ax1.set_title('T-FREE Encoding Consistency Across Languages')
    ax1.set_ylim(0, max(avg_activations) * 1.2)
    
    # Bar plot for average trigrams
    ax2.bar(languages, avg_trigrams, color='lightcoral')
    ax2.set_ylabel('Average Trigrams per Word')
    ax2.set_title('Trigram Distribution Across Languages')
    ax2.set_ylim(0, max(avg_trigrams) * 1.2)
    
    plt.tight_layout()
    plt.show()
    
    # Print detailed statistics
    print("\nCross-Lingual Encoding Statistics:")
    print("=" * 60)
    for lang, stats in results.items():
        print(f"\n{lang}:")
        print(f"  Words: {stats['words']}")
        print(f"  Avg activations/word: {stats['avg_activations_per_word']:.1f}")
        print(f"  Avg trigrams/word: {stats['avg_trigrams_per_word']:.1f}")
        print(f"  Text length: {stats['text_length']} characters")
    
    print("\nKey insight: activations per word equal m × (characters per word), so they")
    print("track word length in every language; no language-specific vocabulary is trained.")

analyze_cross_lingual_performance()

## Parameter Reduction Analysis

Let's calculate and visualize the parameter reduction achieved by T-FREE.
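The embedding and the head each hold v·h weights, so the hidden size cancels out of the relative saving, 1 - v_TFREE / v_trad. This is why each column of the heatmap below is constant. The short sketch below works through the arithmetic. It takes "64k" as 2^16 = 65,536 and ignores the head's bias vector.

```python
def embed_head_params(vocab_size: int, hidden_size: int, tied: bool = False) -> int:
    """Weights in the input embedding plus LM head (bias ignored)."""
    matrices = 1 if tied else 2
    return matrices * vocab_size * hidden_size


def reduction(v_new: int, v_old: int) -> float:
    """Relative saving; hidden_size cancels, leaving only the vocabulary ratio."""
    return 1 - v_new / v_old


for h in (768, 4096):
    r = 1 - embed_head_params(8192, h) / embed_head_params(65536, h)
    print(h, f"{r:.1%}")               # 87.5% for every hidden size
print(f"{reduction(8192, 32000):.1%}")  # 74.4% against a 32k vocabulary
```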

In [None]:
def analyze_parameter_reduction():
    """Analyze parameter reduction in T-FREE vs traditional embeddings."""
    
    # Model configurations
    hidden_sizes = [256, 512, 768, 1024, 2048, 4096]
    
    # Representative traditional vocabulary sizes
    traditional_vocab_sizes = {
        'Small (32k)': 32000,
        'Medium (64k)': 64000,
        'Large (128k)': 128000,
        'XLarge (256k)': 256000
    }
    
    # T-FREE vocabulary size (can be much smaller)
    tfree_vocab_size = 8192  # 8k rows: 87.5% fewer than a 64k (2^16) vocabulary
    
    results = []
    
    for hidden_size in hidden_sizes:
        for vocab_name, trad_vocab_size in traditional_vocab_sizes.items():
            # Traditional embedding parameters
            trad_embed_params = trad_vocab_size * hidden_size
            trad_lm_head_params = trad_vocab_size * hidden_size
            trad_total = trad_embed_params + trad_lm_head_params
            
            # T-FREE parameters
            tfree_embed_params = tfree_vocab_size * hidden_size
            tfree_lm_head_params = tfree_vocab_size * hidden_size
            tfree_total = tfree_embed_params + tfree_lm_head_params
            
            # Calculate reduction
            reduction_pct = (1 - tfree_total / trad_total) * 100
            
            results.append({
                'hidden_size': hidden_size,
                'vocab_type': vocab_name,
                'traditional_params': trad_total,
                'tfree_params': tfree_total,
                'reduction_pct': reduction_pct,
                'params_saved': trad_total - tfree_total
            })
    
    # Create visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Plot 1: Parameter count comparison
    for vocab_name in traditional_vocab_sizes.keys():
        data = [r for r in results if r['vocab_type'] == vocab_name]
        hidden_sizes_plot = [r['hidden_size'] for r in data]
        trad_params = [r['traditional_params'] / 1e9 for r in data]  # Convert to billions
        ax1.plot(hidden_sizes_plot, trad_params, 'o-', label=f'Traditional {vocab_name}', linewidth=2)
    
    # T-FREE uses one vocabulary size for every configuration, so plot it once
    tfree_params = [2 * tfree_vocab_size * h / 1e9 for h in hidden_sizes]
    ax1.plot(hidden_sizes, tfree_params, 's--', color='black', label='T-FREE (8k)', linewidth=2)
    
    ax1.set_xlabel('Hidden Size')
    ax1.set_ylabel('Parameters (Billions)')
    ax1.set_title('Embedding + LM Head Parameters: Traditional vs T-FREE')
    ax1.set_yscale('log')
    ax1.grid(True, which="both", ls="-", alpha=0.2)
    ax1.legend()
    
    # Plot 2: Reduction percentage heatmap
    reduction_matrix = np.zeros((len(hidden_sizes), len(traditional_vocab_sizes)))
    for i, hidden_size in enumerate(hidden_sizes):
        for j, vocab_name in enumerate(traditional_vocab_sizes.keys()):
            data = [r for r in results if r['hidden_size'] == hidden_size and r['vocab_type'] == vocab_name]
            reduction_matrix[i, j] = data[0]['reduction_pct']
    
    sns.heatmap(reduction_matrix, annot=True, fmt='.1f', cmap='YlOrRd',
                xticklabels=list(traditional_vocab_sizes.keys()),
                yticklabels=hidden_sizes, ax=ax2)
    ax2.set_xlabel('Traditional Vocabulary Size')
    ax2.set_ylabel('Hidden Size')
    ax2.set_title('Parameter Reduction Percentage with T-FREE')
    
    plt.tight_layout()
    plt.show()
    
    # Print example savings
    print("\nParameter Savings Examples:")
    print("=" * 60)
    
    example_configs = [
        ('Command-R+', 12288, 256000),  # From paper
        ('GPT-2', 768, 50257),
        ('BERT', 768, 30522),
        ('Mistral', 4096, 32000)
    ]
    
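    for model_name, hidden_size, vocab_size in example_configs:
        # Assumes untied input embedding and LM head (2 matrices); GPT-2 and
        # BERT tie them, so their actual count is half of this figure
        trad_params = 2 * vocab_size * hidden_size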
        tfree_params = 2 * tfree_vocab_size * hidden_size
        saved = trad_params - tfree_params
        reduction = (1 - tfree_params / trad_params) * 100
        
        print(f"\n{model_name}:")
        print(f"  Traditional: {trad_params / 1e9:.2f}B parameters")
        print(f"  T-FREE: {tfree_params / 1e9:.2f}B parameters")
        print(f"  Saved: {saved / 1e9:.2f}B parameters ({reduction:.1f}% reduction)")

analyze_parameter_reduction()

## Training and Evaluation Framework

Let's implement a training framework using LangChain for data processing and DeepEval for evaluation.
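The `train_step` method below predicts only the last word of each text, whereas the evaluation loop iterates over every prefix. The minimal sketch below shows that prefix/target pairing on whitespace-split words; the example sentence is made up. ML-BCE scores v independent binary labels rather than a normalized distribution over words, so exponentiating its mean is not a perplexity in the usual sense.

```python
def next_word_pairs(text: str):
    """All (prefix, target) pairs for next-word prediction on one text."""
    words = text.split()
    return [(" ".join(words[:i]), words[i]) for i in range(1, len(words))]


for prefix, target in next_word_pairs("T-FREE embeds words sparsely"):
    print(f"{prefix!r} -> {target!r}")
# 'T-FREE' -> 'embeds'
# 'T-FREE embeds' -> 'words'
# 'T-FREE embeds words' -> 'sparsely'
```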

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
import deepeval
# DeepEval has no built-in perplexity metric; perplexity is computed manually below
from deepeval.test_case import LLMTestCase

class TFreeTrainer:
    """Training framework for T-FREE models with LangChain integration."""
    
    def __init__(self, model: TFreeLanguageModel, learning_rate: float = 1e-4):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=512,
            chunk_overlap=50
        )
        
    def prepare_training_data(self, texts: List[str]) -> List[Document]:
        """Prepare training data using LangChain."""
        documents = []
        for text in texts:
            # Split text into chunks
            chunks = self.text_splitter.split_text(text)
            for chunk in chunks:
                documents.append(Document(page_content=chunk))
        return documents
    
    def train_step(self, batch_texts: List[str]) -> float:
        """One optimizer step on a batch, predicting the last word of each text."""
        self.model.train()
        losses = []
        
        for text in batch_texts:
            # Split into input prefix and target (the final word)
            words = text.split()
            if len(words) < 2:
                continue
                
            input_text = ' '.join(words[:-1])
            target_word = words[-1]
            
            # Forward pass
            logits, _ = self.model(input_text)
            if logits is None:
                continue
                
            # ML-BCE loss on the prediction at the last position
            last_logits = logits[0, -1:, :]
            losses.append(self.model.lm_head.compute_loss(last_logits, [target_word]))
        
        if not losses:
            return 0.0
        
        # Average over the valid examples and take a single optimizer step
        loss = torch.stack(losses).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.item()
    
    def evaluate_with_deepeval(self, test_texts: List[str]) -> Dict[str, float]:
        """Build DeepEval test cases from generated continuations and compute a perplexity-style score from the LM-head loss."""
        self.model.eval()
        
        # Prepare test cases
        test_cases = []
        for text in test_texts:
            # Generate continuation
            generated = self.model.generate(text, max_length=20)
            
            test_case = LLMTestCase(
                input=text,
                actual_output=generated,
                expected_output=None  # For generation tasks
            )
            test_cases.append(test_case)
        
        # Perplexity-style score: exp of the mean per-word loss from lm_head.compute_loss
        total_loss = 0
        total_tokens = 0
        
        with torch.no_grad():
            for text in test_texts:
                words = text.split()
                if len(words) < 2:
                    continue
                    
                for i in range(1, len(words)):
                    input_text = ' '.join(words[:i])
                    target_word = words[i]
                    
                    logits, _ = self.model(input_text)
                    if logits is None:
                        continue
                        
                    last_logits = logits[0, -1:, :]
                    loss = self.model.lm_head.compute_loss(last_logits, [target_word])
                    
                    total_loss += loss.item()
                    total_tokens += 1
        
        perplexity = np.exp(total_loss / total_tokens) if total_tokens > 0 else float('inf')
        
        return {
            'perplexity': perplexity,
            'avg_loss': total_loss / total_tokens if total_tokens > 0 else 0,
            'num_test_cases': len(test_cases)
        }

# Demo training
model = TFreeLanguageModel(vocab_size=8192, hidden_size=256, num_layers=2)
trainer = TFreeTrainer(model)

# Sample training data
training_texts = [
    "T-FREE is a tokenizer-free approach for language models.",
    "It uses sparse representations based on character trigrams.",
    "This method reduces embedding parameters by more than 85 percent.",
    "The approach shows improved cross-lingual transfer learning."
]

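# Prepare data: chunk the raw texts with LangChain, then train on the chunk contents
documents = trainer.prepare_training_data(training_texts)
train_chunks = [doc.page_content for doc in documents]
print(f"Prepared {len(documents)} training documents")

# Run a few training steps on the toy corpus
print("\nTraining for 5 steps...")
for step in range(5):
    loss = trainer.train_step(train_chunks)
    print(f"Step {step + 1}, Loss: {loss:.4f}")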

# Evaluate
print("\nEvaluation:")
test_texts = [
    "T-FREE is a",
    "The method uses",
    "Sparse representations"
]

metrics = trainer.evaluate_with_deepeval(test_texts)
print(f"Perplexity: {metrics['perplexity']:.2f}")
print(f"Average Loss: {metrics['avg_loss']:.4f}")
print(f"Test Cases: {metrics['num_test_cases']}")
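`evaluate_with_deepeval` averages the per-word losses first and only then exponentiates. The order matters, because the exponential of a mean is not the mean of exponentials. The standalone sketch below isolates that computation; the function name is hypothetical. The result is a true perplexity only if `compute_loss` returns the negative log-likelihood of the target word. If `compute_loss` instead scores sparse trigram activations (for example, with a multi-label binary cross-entropy), the number is a perplexity-style proxy.

```python
import math
from typing import Sequence


def perplexity_from_losses(per_word_losses: Sequence[float]) -> float:
    """exp(mean loss), returning inf when there is no data (mirrors evaluate_with_deepeval)."""
    if not per_word_losses:
        return float("inf")
    return math.exp(sum(per_word_losses) / len(per_word_losses))


# Two words: one predicted with probability 1/4, one with probability 1
losses = [math.log(4), 0.0]
print(perplexity_from_losses(losses))  # ≈ 2.0 (geometric mean of 4 and 1)

# Averaging per-word perplexities instead would overstate it
print(sum(math.exp(loss) for loss in losses) / len(losses))  # ≈ 2.5
```

A model that spreads probability uniformly over N candidate words has a per-word NLL of log N, so its perplexity is N. This gives a quick sanity check for the scale of the numbers printed above.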

## Template for Personal Research

Here's a template for applying T-FREE to your own research:

In [None]:
# Research Template: Implementing T-FREE for Your Application

class CustomTFreeModel:
    """Template for implementing T-FREE in your research."""
    
    def __init__(self, config):
        """
        Configuration options to explore:
        - vocab_size: Try different sizes (4k, 8k, 16k) based on your needs
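        - num_hashes: Active rows per trigram; more hashes = better disambiguation but denser patterns (parameter count is fixed by vocab_size)
        - lowercase_ratio: Share of hashes computed on the lowercased trigram; higher ratio = more overlap between case variants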
        - hidden_size: Match this to your downstream task requirements
        """
        self.config = config
        # Initialize your T-FREE model here
        
    def adapt_for_domain(self, domain_texts):
        """
        Domain adaptation suggestions:
        1. Analyze trigram distributions in your domain
        2. Adjust hash functions for domain-specific patterns
        3. Consider domain-specific whitespace rules
        """
        pass
    
    def optimize_for_language(self, language):
        """
        Language-specific optimizations:
        1. Adjust trigram extraction for logographic scripts (e.g., Chinese)
        2. Modify whitespace rules for languages without spaces
        3. Consider longer n-grams for morphologically rich languages
        """
        pass
    
    def benchmark_against_baseline(self, baseline_tokenizer):
        """
        Benchmarking checklist:
        1. Compare vocabulary sizes and parameter counts
        2. Measure encoding/decoding speed
        3. Evaluate cross-lingual performance
        4. Test on out-of-vocabulary words
        5. Analyze morphological generalization
        """
        pass

# Research directions to explore:
print("Research Directions for T-FREE:")
print("1. Hybrid approaches: Combine T-FREE with byte-level fallback")
print("2. Dynamic vocabulary: Adapt dictionary during inference")
print("3. Hierarchical decoding: Group words by morphological families")
print("4. Multi-modal extension: Apply T-FREE principles to other modalities")
print("5. Compression techniques: Further reduce embedding size with quantization")
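The sketch below shows how the template's `num_hashes` and `lowercase_ratio` knobs interact when a word is turned into a set of active embedding rows. It is not the paper's or this notebook's exact scheme. The space padding, the SHA-256 bucket hash, and the rule that the first `round(num_hashes * lowercase_ratio)` hashes see the lowercased trigram are all assumptions for illustration. Only `vocab_size` fixes the number of embedding rows; `num_hashes` changes how many of those rows each trigram activates.

```python
import hashlib
from typing import List, Set


def word_trigrams(word: str) -> List[str]:
    """Character trigrams of a word padded with one space on each side (assumed convention)."""
    padded = f" {word} "
    return [padded[i:i + 3] for i in range(len(padded) - 2)]


def _bucket(text: str, seed: int, vocab_size: int) -> int:
    # Deterministic hash (SHA-256), so indices are stable across runs, unlike Python's salted hash()
    digest = hashlib.sha256(f"{seed}|{text}".encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little") % vocab_size


def sparse_pattern(word: str, vocab_size: int = 8192, num_hashes: int = 8,
                   lowercase_ratio: float = 0.5) -> Set[int]:
    """Hypothetical: indices of the embedding rows activated by `word`."""
    num_lower = round(num_hashes * lowercase_ratio)
    active = set()
    for tri in word_trigrams(word):
        for k in range(num_hashes):
            # The first `num_lower` hashes see the lowercased trigram, so case variants share them
            source = tri.lower() if k < num_lower else tri
            active.add(_bucket(source, k, vocab_size))
    return active


house_upper = sparse_pattern("House")
house_lower = sparse_pattern("house")
print(len(house_upper & house_lower) > 0)  # True: shared trigrams and lowercase hashes overlap
print(sparse_pattern("House", lowercase_ratio=1.0) == sparse_pattern("house", lowercase_ratio=1.0))  # True
```

With `lowercase_ratio=1.0`, capitalization is erased entirely. With `0.0`, "House" and "house" overlap only through trigrams that are already identical ("ous", "use", "se "). Intermediate values trade case sensitivity against sharing.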

## Conclusion

This notebook implements the core concepts of T-FREE. The paper reports the following benefits:

1. **Tokenizer-free approach**: Words are embedded directly through their character trigrams
2. **Parameter efficiency**: >85% reduction in embedding layer parameters
3. **Morphological similarity**: Words with overlapping trigrams automatically share embedding components
4. **Cross-lingual robustness**: Improved transfer learning across languages
5. **No training corpus needed**: No tokenizer has to be trained on a reference corpus

According to the paper, T-FREE addresses the three fundamental flaws (F1-F3) of traditional tokenizers while maintaining competitive downstream performance. This paradigm shift opens new possibilities for more efficient and adaptable language models.