# Toy Problem Case Study: Tokenization Effects in VLA Training

This notebook reproduces the case study from the paper that demonstrates how naive tokenization schemes affect the training of autoregressive vision-language-action (VLA) policies at different sampling rates.

## Overview

The case study targets four findings from the paper. Throughout, H is the number of samples per sequence (the `sequence_length` passed to the generator):

1. **Naive binning tokenization** works well at low sampling rates (H=25-50).
2. **Performance degrades significantly** at high sampling rates (H=400-800).
3. **Marginal information content approaches zero** as the sampling rate increases.
4. **Models tend to copy the first action** at high sampling rates instead of learning meaningful patterns.

This demonstrates the need for better tokenization schemes like the DCT-based FAST tokenization proposed in the paper.


## Setup and Imports


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from typing import Dict, List

# Set up matplotlib for better plots
plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

# Import our modules
from cubic_spline_generator import CubicSplineGenerator
from binning_tokenizer import BinningTokenizer
from transformer_model import SimpleTransformer, count_parameters
from training import Trainer, run_experiment
from visualization import CaseStudyVisualizer

print("Setup complete!")


## 1. Demonstrate the Tokenization Issue

First, let's demonstrate the core tokenization issue described in the paper.
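
Before using the project modules, here is a minimal, self-contained sketch of the quantities involved. It assumes uniform binning over the observed value range and uses a sine wave as a stand-in for a smooth spline. The helpers `bin_tokenize` and `diff_stats` are hypothetical names for illustration and are not the `BinningTokenizer` API.

```python
import numpy as np

def bin_tokenize(values, num_bins=256):
    """Map values to integer bins spanning [values.min(), values.max()]."""
    lo, hi = values.min(), values.max()
    scaled = (values - lo) / (hi - lo)              # rescale to [0, 1]
    return np.minimum((scaled * num_bins).astype(int), num_bins - 1)

def diff_stats(tokens):
    """Return (zero-difference ratio, entropy in bits) of consecutive token differences."""
    diffs = np.diff(tokens)
    _, counts = np.unique(diffs, return_counts=True)
    probs = counts / counts.sum()
    entropy = -np.sum(probs * np.log2(probs))
    return float(np.mean(diffs == 0)), float(entropy)

stats = {}
for H in (25, 800):
    t = np.linspace(0.0, 1.0, H)
    signal = np.sin(2 * np.pi * t)                  # assumed stand-in for a smooth spline
    stats[H] = diff_stats(bin_tokenize(signal))
    print(f"H={H:3d}: zero ratio={stats[H][0]:.3f}, entropy={stats[H][1]:.3f} bits")
```

On this toy signal the zero-difference ratio rises and the entropy of token differences falls as H grows, which is the pattern the cell below measures on the spline data.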


In [None]:
# Initialize components
generator = CubicSplineGenerator(seed=42)
tokenizer = BinningTokenizer(num_bins=256)

# Test different sampling rates
sampling_rates = [25, 50, 100, 200, 400, 800]

print("Analyzing marginal information content:")
print("Sampling Rate | Entropy | Zero Diff Ratio | Unique Diffs")
print("-" * 60)

results = {}

for H in sampling_rates:
    # Generate data
    times, targets, conditioning = generator.generate_spline_data(
        num_sequences=100,
        sequence_length=H
    )
    
    # Fit tokenizer
    tokenizer.fit(targets)
    
    # Analyze marginal information
    analysis = tokenizer.analyze_marginal_information(targets, H)
    results[H] = analysis
    
    # Column widths match the header printed above
    print(f"{H:13d} | {analysis['entropy']:7.3f} | {analysis['zero_diff_ratio']:15.3f} | {analysis['unique_diffs']:12d}")


### Visualize the Tokenization Issue


In [None]:
# Create visualization
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Entropy vs Sampling Rate
axes[0].plot(sampling_rates, [results[H]['entropy'] for H in sampling_rates], 'bo-', linewidth=2, markersize=8)
axes[0].set_xlabel('Sampling Rate (H)')
axes[0].set_ylabel('Entropy of Token Differences')
axes[0].set_title('Marginal Information Content')
axes[0].grid(True, alpha=0.3)
axes[0].set_xscale('log')

# Plot 2: Zero Difference Ratio vs Sampling Rate
axes[1].plot(sampling_rates, [results[H]['zero_diff_ratio'] for H in sampling_rates], 'ro-', linewidth=2, markersize=8)
axes[1].set_xlabel('Sampling Rate (H)')
axes[1].set_ylabel('Ratio of Zero Differences')
axes[1].set_title('Token Redundancy')
axes[1].grid(True, alpha=0.3)
axes[1].set_xscale('log')

# Plot 3: Example spline at different sampling rates
# Generate a single example spline
times, targets, conditioning = generator.generate_spline_data(
    num_sequences=1,
    sequence_length=100  # Use medium sampling rate for visualization
)

axes[2].plot(times[0], targets[0], 'k-', linewidth=2, label='Cubic Spline')
axes[2].scatter(conditioning[0, :, 0], conditioning[0, :, 1], 
               color='red', s=100, zorder=5, label='Conditioning Points')
axes[2].set_xlabel('Time')
axes[2].set_ylabel('Value')
axes[2].set_title('Example Cubic Spline')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print key insights
print("\n" + "=" * 60)
print("KEY INSIGHTS")
print("=" * 60)
print("1. ENTROPY DECREASES with sampling rate:")
print(f"   H=25:  {results[25]['entropy']:.3f}")
print(f"   H=800: {results[800]['entropy']:.3f}")
print(f"   Reduction: {(1 - results[800]['entropy']/results[25]['entropy'])*100:.1f}%")

print("\n2. TOKEN REDUNDANCY INCREASES with sampling rate:")
print(f"   H=25:  {results[25]['zero_diff_ratio']:.3f} zero differences")
print(f"   H=800: {results[800]['zero_diff_ratio']:.3f} zero differences")
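# Report the absolute change: the ratio at H=25 can be 0, which makes a relative increase undefined
print(f"   Increase: {results[800]['zero_diff_ratio'] - results[25]['zero_diff_ratio']:+.3f} (absolute)")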

print("\n3. UNIQUE DIFFERENCES DECREASE with sampling rate:")
print(f"   H=25:  {results[25]['unique_diffs']} unique differences")
print(f"   H=800: {results[800]['unique_diffs']} unique differences")
print(f"   Reduction: {(1 - results[800]['unique_diffs']/results[25]['unique_diffs'])*100:.1f}%")


## 2. Run the Training Experiment

Now let's run the actual training experiment to see how the model performance degrades with sampling rate.


In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create results directory
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

# Initialize visualizer
visualizer = CaseStudyVisualizer(device)

print("Setup complete for training experiment!")


### Quick Test with Reduced Parameters

Let's start with a quick test using smaller parameters to demonstrate the effect.


In [None]:
# Run quick experiment
print("Running quick experiment with reduced parameters...")
print("Sampling rates: [25, 100, 400]")
print("Sequences: 200, Epochs: 20")

quick_results = run_experiment(
    sampling_rates=[25, 100, 400],
    num_sequences=200,  # Smaller for faster testing
    num_epochs=20,      # Fewer epochs for faster testing
    results_dir=results_dir
)

print("\n" + "=" * 60)
print("QUICK EXPERIMENT RESULTS")
print("=" * 60)
print("Sampling Rate (H) | MSE")
print("-" * 30)
for H in sorted(quick_results.keys()):
    print(f"{H:17d} | {quick_results[H]:.6f}")


### Visualize Results


In [None]:
# Plot the results
sampling_rates_quick = sorted(quick_results.keys())
mse_values = [quick_results[H] for H in sampling_rates_quick]

plt.figure(figsize=(10, 6))
plt.plot(sampling_rates_quick, mse_values, 'bo-', linewidth=2, markersize=8)
plt.xlabel('Sampling Rate (H)')
plt.ylabel('Mean Squared Error (MSE)')
plt.title('Effect of Sampling Rate on Prediction Performance\n(Naive Binning Tokenization)')
plt.grid(True, alpha=0.3)
plt.xscale('log')
plt.yscale('log')

# Add annotations
for H, mse in zip(sampling_rates_quick, mse_values):
    plt.annotate(f'{mse:.2e}', (H, mse), 
                textcoords="offset points", xytext=(0,10), ha='center')

plt.tight_layout()
plt.show()

print("\nExpected behavior:")
print("- H=25: Good performance (low MSE)")
print("- H=100: Moderate performance")
print("- H=400: Poor performance (high MSE)")
print("\nThis demonstrates the tokenization issue described in the paper!")


## 3. Tokenization Error Analysis

Let's also analyze the tokenization error itself at different sampling rates.
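
As a point of reference for the table this cell prints, the sketch below checks a standard property of uniform quantization: if values are spread roughly evenly within each bin, the round-trip MSE is about `width**2 / 12`, whatever the number of samples. It assumes uniform bins with reconstruction at bin centers and random sine waves as stand-in signals. `quantize_roundtrip` is a hypothetical helper, not part of `BinningTokenizer`.

```python
import numpy as np

def quantize_roundtrip(values, num_bins=256):
    """Tokenize into uniform bins, then reconstruct each value at its bin center."""
    lo, hi = values.min(), values.max()
    width = (hi - lo) / num_bins
    tokens = np.minimum(((values - lo) / width).astype(int), num_bins - 1)
    return lo + (tokens + 0.5) * width, width

rng = np.random.default_rng(0)
ratios = {}
for H in (25, 400):
    t = np.linspace(0.0, 1.0, H)
    # 200 smooth stand-in signals: sines with random amplitude and phase
    amp = rng.uniform(0.5, 1.0, size=(200, 1))
    phase = rng.uniform(0.0, 2 * np.pi, size=(200, 1))
    signals = amp * np.sin(2 * np.pi * t + phase)
    recon, width = quantize_roundtrip(signals)
    mse = np.mean((signals - recon) ** 2)
    ratios[H] = mse / (width ** 2 / 12)             # close to 1 when error is uniform within a bin
    print(f"H={H:3d}: MSE={mse:.3e}, MSE / (width^2 / 12) = {ratios[H]:.2f}")
```

If the measured tokenization error on the spline data differs strongly between sampling rates, the cause is more likely a change in the fitted value range than the sampling rate itself.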


In [None]:
print("TOKENIZATION ERROR ANALYSIS")
print("=" * 60)

sampling_rates_error = [25, 100, 400]

print("Sampling Rate | Tokenization MSE | Relative Error")
print("-" * 50)

tokenization_errors = []

for H in sampling_rates_error:
    # Generate data
    times, targets, conditioning = generator.generate_spline_data(
        num_sequences=100,
        sequence_length=H
    )
    
    # Fit tokenizer
    tokenizer.fit(targets)
    
    # Tokenize and detokenize
    tokens = tokenizer.tokenize(targets)
    reconstructed = tokenizer.detokenize(tokens)
    
    # Compute error
    mse = tokenizer.compute_tokenization_error(targets, reconstructed)
    
    # Compute relative error
    target_range = targets.max() - targets.min()
    relative_error = mse / (target_range ** 2)
    
    tokenization_errors.append(mse)
    
    # Scientific notation: with 256 bins the error is typically far below 0.01
    print(f"{H:13d} | {mse:16.2e} | {relative_error:14.2e}")

print("\nNote: with a fixed number of bins, per-sample tokenization error is set by the bin width,")
print("so it should stay roughly constant across sampling rates rather than grow with H.")


## 4. Full Experiment (Optional)

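If you want to run the full experiment with all sampling rates, run the full-experiment cell at the end of this section. Note that it takes significantly longer and may require more GPU memory. The trajectory plots below look for checkpoints named `model_H{H}.pth` in `results_dir`; sampling rates without a checkpoint show only the ground truth.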


### Visualize Trajectories for Different Frequencies

Let's visualize the actual prediction trajectories for different sampling rates to see how the model behavior changes.
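
To judge the MSE values shown in each panel, it helps to know what a fully collapsed prediction would score. The sketch below computes the MSE of a degenerate policy that repeats the first value of each sequence. It uses a sine wave as a stand-in target, and `hold_first_mse` is a hypothetical helper, not part of the project modules.

```python
import numpy as np

def hold_first_mse(targets):
    """MSE of a degenerate policy that repeats each sequence's first value."""
    held = np.repeat(targets[:, :1], targets.shape[1], axis=1)
    return float(np.mean((targets - held) ** 2))

t = np.linspace(0.0, 1.0, 400)
targets = np.sin(2 * np.pi * t)[None, :]            # stand-in for one spline, shape (1, H)
baseline = hold_first_mse(targets)
print(f"hold-first-value MSE: {baseline:.3f}")
```

Applying the same function to the `targets` array generated in the cell below gives a reference point: a model MSE close to that value suggests the prediction has flattened into a copy of the first action.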


In [None]:
# Visualize prediction trajectories for different sampling rates
print("Generating trajectory visualizations...")

# Create a comprehensive visualization showing prediction quality at different frequencies
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

# Test different sampling rates
test_sampling_rates = [25, 50, 100, 200, 400, 800]

for i, H in enumerate(test_sampling_rates):
    ax = axes[i]
    
    # Generate test data
    generator_test = CubicSplineGenerator(seed=123)  # Different seed for variety
    times, targets, conditioning = generator_test.generate_spline_data(
        num_sequences=1,
        sequence_length=H
    )
    
    # Try to load the trained model for this sampling rate
    model_path = os.path.join(results_dir, f"model_H{H}.pth")
    
    if os.path.exists(model_path):
        # Load model
        model = SimpleTransformer(
            vocab_size=256,
            d_model=128,
            nhead=8,
            num_layers=4,
            max_seq_len=H + 100
        )
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device)
        model.eval()
        
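        # Fit tokenizer on this test sequence.
        # NOTE: this assumes the bin range matches the one used during training;
        # if training fit its tokenizer on different data, reuse that tokenizer instead.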
        tokenizer_test = BinningTokenizer(num_bins=256)
        tokenizer_test.fit(targets)
        
        # Generate predictions
        with torch.no_grad():
            conditioning_tensor = torch.from_numpy(conditioning).float().to(device)
            predicted_tokens = model.generate(
                conditioning_tensor,
                max_length=H,
                temperature=0.0,  # Deterministic generation
                device=device
            )
        
        # Convert predictions back to continuous values
        predicted_values = tokenizer_test.detokenize(predicted_tokens.cpu().numpy())
        
        # Plot ground truth
        ax.plot(times[0], targets[0], 'k--', linewidth=2, label='Ground Truth', alpha=0.8)
        
        # Plot conditioning points
        ax.scatter(conditioning[0, :, 0], conditioning[0, :, 1], 
                  color='white', s=100, zorder=5, edgecolors='black', linewidth=2,
                  label='Conditioning Points')
        
        # Plot prediction
        ax.plot(times[0], predicted_values[0], 'r-', linewidth=2, label='Prediction', alpha=0.8)
        
        # Compute and display MSE
        mse = np.mean((targets[0] - predicted_values[0]) ** 2)
        ax.text(0.05, 0.95, f'MSE: {mse:.4f}', transform=ax.transAxes, 
               verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        
        ax.set_title(f'H = {H}')
        ax.grid(True, alpha=0.3)
        
        # Only add legend to first subplot
        if i == 0:
            ax.legend(loc='upper right')
    
    else:
        # If model doesn't exist, just show the ground truth
        ax.plot(times[0], targets[0], 'k--', linewidth=2, label='Ground Truth')
        ax.scatter(conditioning[0, :, 0], conditioning[0, :, 1], 
                  color='red', s=100, zorder=5, label='Conditioning Points')
        ax.set_title(f'H = {H} (No trained model)')
        ax.grid(True, alpha=0.3)
        ax.legend()

plt.suptitle('Prediction Quality at Different Sampling Rates\n(Showing the "Copy First Action" Problem)', fontsize=16)
plt.tight_layout()
plt.show()

print("\nKey observations:")
print("- Low H (25, 50): Model learns to interpolate the smooth curve")
print("- Medium H (100, 200): Some degradation but still reasonable")
print("- High H (400, 800): Model tends to copy the first action or produce poor predictions")
print("- This demonstrates the tokenization issue described in the paper!")


### Detailed Token-Level Analysis

Let's also examine the token-level behavior to understand why the model fails at high frequencies.
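
One way to see the effect before plotting is to measure how far the signal moves per step in units of bin widths. For a signal with value range R and total variation V, the average move is about (V / R) * num_bins / (H - 1), so it shrinks in proportion to 1 / (H - 1). The sketch below checks this on a sine wave, an assumed stand-in for a spline, with 256 uniform bins.

```python
import numpy as np

NUM_BINS = 256
mean_step = {}
for H in (25, 100, 400, 800):
    t = np.linspace(0.0, 1.0, H)
    signal = np.sin(2 * np.pi * t)                  # continuous signal: range 2, total variation 4
    bins = (signal - signal.min()) / (signal.max() - signal.min()) * NUM_BINS
    mean_step[H] = float(np.mean(np.abs(np.diff(bins))))   # movement per step, in bin widths
    predicted = (4 / 2) * NUM_BINS / (H - 1)        # (total variation / range) * bins / steps
    print(f"H={H:3d}: mean |step| = {mean_step[H]:6.2f} bins (estimate {predicted:6.2f})")
```

Once the average step drops below one bin width, many consecutive samples land in the same bin, so repeating the previous token becomes a strong default for an autoregressive model.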


In [None]:
# Analyze token-level behavior for different sampling rates
print("Analyzing token-level behavior...")

# Create a detailed analysis
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Test with a specific example
generator_analysis = CubicSplineGenerator(seed=42)
times, targets, conditioning = generator_analysis.generate_spline_data(
    num_sequences=1,
    sequence_length=100  # Use medium length for analysis
)

# Analyze different sampling rates by subsampling
sampling_rates_analysis = [25, 100, 400]

for idx, H in enumerate(sampling_rates_analysis):
    # Subsample the data to simulate different sampling rates
    if H < 100:
        # Downsample
        step = 100 // H
        times_sampled = times[0, ::step]
        targets_sampled = targets[0, ::step]
    else:
        # Upsample by interpolation
        times_sampled = np.linspace(times[0, 0], times[0, -1], H)
        targets_sampled = np.interp(times_sampled, times[0], targets[0])
    
    # Tokenize
    tokenizer_analysis = BinningTokenizer(num_bins=256)
    tokenizer_analysis.fit(targets_sampled.reshape(1, -1))
    tokens = tokenizer_analysis.tokenize(targets_sampled.reshape(1, -1))[0]
    
    # Plot 1: Original signal vs tokenized signal
    ax1 = axes[0, idx] if idx < 2 else axes[1, idx-2]
    
    ax1.plot(times_sampled, targets_sampled, 'b-', linewidth=2, label='Original Signal', alpha=0.7)
    ax1.plot(times_sampled, tokenizer_analysis.detokenize(tokens.reshape(1, -1))[0], 
             'r--', linewidth=1, label='Tokenized Signal', alpha=0.7)
    ax1.set_title(f'Sampling Rate H={H}')
    ax1.set_xlabel('Time')
    ax1.set_ylabel('Value')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Token differences (showing redundancy)
    if idx == 0:  # Only plot token differences for one example
        ax2 = axes[1, 1]
        token_diffs = np.diff(tokens)
        ax2.plot(token_diffs, 'g-', linewidth=1, alpha=0.7)
        ax2.set_title('Token Differences (H=25)')
        ax2.set_xlabel('Token Index')
        ax2.set_ylabel('Token Difference')
        ax2.grid(True, alpha=0.3)
        
        # Add statistics
        zero_ratio = np.mean(token_diffs == 0)
        unique_diffs = len(np.unique(token_diffs))
        ax2.text(0.05, 0.95, f'Zero ratio: {zero_ratio:.3f}\nUnique diffs: {unique_diffs}', 
                transform=ax2.transAxes, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))

plt.suptitle('Token-Level Analysis: Why High Sampling Rates Fail', fontsize=16)
plt.tight_layout()
plt.show()

# Print detailed analysis
print("\n" + "=" * 60)
print("TOKEN-LEVEL ANALYSIS")
print("=" * 60)

for H in sampling_rates_analysis:
    # Subsample data
    if H < 100:
        step = 100 // H
        targets_sampled = targets[0, ::step]
    else:
        times_sampled = np.linspace(times[0, 0], times[0, -1], H)
        targets_sampled = np.interp(times_sampled, times[0], targets[0])
    
    # Tokenize and analyze
    tokenizer_analysis = BinningTokenizer(num_bins=256)
    tokenizer_analysis.fit(targets_sampled.reshape(1, -1))
    tokens = tokenizer_analysis.tokenize(targets_sampled.reshape(1, -1))[0]
    token_diffs = np.diff(tokens)
    
    # Compute statistics
    zero_ratio = np.mean(token_diffs == 0)
    unique_diffs = len(np.unique(token_diffs))
    # Entropy (bits) of the token-difference distribution; the +255 offset keeps bincount inputs non-negative
    probs = np.bincount(token_diffs + 255) / len(token_diffs)
    probs = probs[probs > 0]
    entropy = -np.sum(probs * np.log2(probs))
    
    print(f"H={H:3d}: Zero ratio={zero_ratio:.3f}, Unique diffs={unique_diffs:3d}, Entropy={entropy:.3f}")

print("\nKey insights:")
print("- Higher sampling rates lead to more redundant tokens (higher zero ratio)")
print("- Fewer unique token differences means less information for learning")
print("- Lower entropy indicates less diversity in the token sequence")
print("- This explains why autoregressive models struggle at high frequencies!")


In [None]:
# Full experiment over all six sampling rates (slow; skip this cell if the quick run is enough)
print("Running full experiment...")
print("This may take a while and require significant GPU memory.")

full_results = run_experiment(
    sampling_rates=[25, 50, 100, 200, 400, 800],
    num_sequences=1000,
    num_epochs=100,
    results_dir=results_dir
)

print("\n" + "=" * 60)
print("FULL EXPERIMENT RESULTS")
print("=" * 60)
print("Sampling Rate (H) | MSE")
print("-" * 30)
for H in sorted(full_results.keys()):
    print(f"{H:17d} | {full_results[H]:.6f}")


## 5. Conclusion

The cell below summarizes the findings from the paper that this case study is designed to reproduce:


In [None]:
print("=" * 60)
print("CONCLUSION")
print("=" * 60)
print("This case study demonstrates why naive binning tokenization fails at high sampling rates:")
print()
print("1. MARGINAL INFORMATION PROBLEM:")
print("   - As sampling rate increases, consecutive tokens become highly correlated")
print("   - This reduces the marginal information content that autoregressive models rely on")
print()
print("2. TOKEN REDUNDANCY:")
print("   - At high frequencies, many consecutive tokens are identical or very similar")
print("   - This makes it hard for the model to learn meaningful patterns")
print()
print("3. COPY BEHAVIOR:")
print("   - Models trained at high sampling rates tend to simply copy the first action")
print("   - Instead of interpolating the smooth spline curve")
print()
print("4. NEED FOR BETTER TOKENIZATION:")
print("   - This case study motivates the development of better tokenization schemes")
print("   - Like the DCT-based FAST tokenization proposed in the paper")
print()
print("The results justify the paper's proposal for improved tokenization methods")
print("that maintain high information content across all sampling rates.")


## Next Steps

To complete the case study, you could:

1. **Implement the FAST tokenization** using DCT-based encoding
2. **Compare performance** between naive binning and FAST tokenization
3. **Run on real robot data** to validate the findings
4. **Experiment with different bin sizes** and tokenization schemes

The current implementation provides a solid foundation for exploring these extensions.
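
As a starting point for the first item, the sketch below shows why a DCT is a natural fit for smooth signals: most of the signal energy lands in a few low-frequency coefficients, regardless of H. This is not the paper's FAST tokenizer, only an illustration of the DCT step under simple assumptions. `dct2_ortho` is a hypothetical NumPy-only helper and the test signal is a stand-in for a spline.

```python
import numpy as np

def dct2_ortho(x):
    """Orthonormal DCT-II of a 1-D signal, written out with NumPy only."""
    n = len(x)
    k = np.arange(n)[:, None]                       # frequency index
    i = np.arange(n)[None, :]                       # sample index
    basis = np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    scale = np.full(n, np.sqrt(2.0 / n))
    scale[0] = np.sqrt(1.0 / n)
    return scale * (basis @ x)

energy_in_first_10 = {}
for H in (25, 800):
    t = np.linspace(0.0, 1.0, H)
    signal = np.sin(2 * np.pi * t) + 0.5 * t        # smooth stand-in for a spline
    energy = dct2_ortho(signal) ** 2
    energy_in_first_10[H] = float(energy[:10].sum() / energy.sum())
    print(f"H={H:3d}: first 10 DCT coefficients hold {energy_in_first_10[H]:.4f} of the energy")
```

Because the number of informative coefficients stays small as H grows, a tokenizer built on them would not suffer from the per-step redundancy that binning shows at high sampling rates.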
