# 🧠 Symbolic Transformer Training

Train a tiny transformer to predict next symbols in First-Order Logic formulas.

**What this does:**
- Generates synthetic FOL training data
- Trains a small transformer model (566K - 19.6M parameters)
- Learns syntax rules like: `∀` → must be followed by `VAR`
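
As a rough illustration of the next-symbol task, the sketch below turns one tokenized formula into (context, target) pairs. The token names `FORALL`, `VAR`, `PRED` and `LPAREN` appear elsewhere in this notebook; `RPAREN`, the exact tokenization of the formula, and the helper `next_symbol_samples` are assumptions made for illustration, not the repository's actual data pipeline.

```python
def next_symbol_samples(tokens):
    """Return one (context, target) pair per position after the first token."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

# Hypothetical token sequence for the formula  ∀x₁ P₃(x₁)
formula = ["FORALL", "VAR", "1", "PRED", "3", "LPAREN", "VAR", "1", "RPAREN"]

samples = next_symbol_samples(formula)
for context, target in samples:
    # Each row is one training example: predict `target` from `context`.
    print(f"{' '.join(context):<40} -> {target}")
```

Under this scheme a formula of n tokens yields n - 1 samples, so the sample count grows much faster than the formula count.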

**Quick Start:** Run cells 1-4 in order. On an A100, training takes roughly 30-180 s per epoch depending on model size (see Quick Reference), and longer with large datasets.

---

## 1️⃣ Setup Environment
Clone the repository and install dependencies.

In [1]:
#@title 1. Setup Environment { display-mode: "form" }
#@markdown Run this cell first to set up the environment.

import os

# Check GPU
print("üîç Checking GPU availability...")
!nvidia-smi --query-gpu=name,memory.total --format=csv,noheader 2>/dev/null || echo "‚ö†Ô∏è No GPU detected - training will be slower"

# Clone repository
print("\nüì¶ Cloning Symbolic-Transformers repository...")
%cd /content
if not os.path.exists('Symbolic-Transformers'):
    !git clone -q https://github.com/tripptytrip/Symbolic-Transformers.git
    print("‚úì Repository cloned")
else:
    !cd Symbolic-Transformers && git stash && git pull -q
    print("‚úì Repository updated")

%cd /content/Symbolic-Transformers

# Install dependencies
print("\nüìö Installing dependencies...")
!pip -q install numpy scipy pandas tqdm rich tensorboard
print("‚úì Dependencies installed")

# Verify vocabulary
print("\nüî§ Verifying vocabulary...")
!python -c "from utils.vocabulary import Vocabulary; v = Vocabulary('unified_vocabulary.json'); print(f'‚úì Vocabulary loaded: {v.vocab_size} tokens')"

print("\n" + "="*50)
print("‚úÖ Setup complete! Proceed to Step 2.")
print("="*50)

🔍 Checking GPU availability...
NVIDIA A100-SXM4-80GB, 81920 MiB

📦 Cloning Symbolic-Transformers repository...
/content
✓ Repository cloned
/content/Symbolic-Transformers

📚 Installing dependencies...
✓ Dependencies installed

🔤 Verifying vocabulary...
✓ Vocabulary loaded: 662 tokens
  - Numerals: 0-624
  - Symbols: 625-661
  - Compositional: ['VAR', 'CONST', 'PRED', 'FUNC', 'SORT']
✓ Vocabulary loaded: 662 tokens

✅ Setup complete! Proceed to Step 2.


## 2️⃣ Configure Training
Adjust the settings below, then run the cell to apply them.

In [5]:
#@title 2. Training Configuration { display-mode: "form" }

#@markdown ### üèóÔ∏è Model Size
model_size = "tiny" #@param ["tiny", "small", "base"]
#@markdown - **tiny**: 566K params (~2.2MB) - Fast training, good for experiments
#@markdown - **small**: 3.5M params (~14MB) - Better accuracy, moderate training time
#@markdown - **base**: 19.6M params (~78MB) - Best accuracy, longest training

#@markdown ---
#@markdown ### üìä Dataset Size
num_train_formulas = 50000 #@param {type:"slider", min:1000, max:50000, step:1000}
#@markdown Number of unique FOL formulas to generate for training.
#@markdown - 1000-3000: Quick experiments
#@markdown - 5000-10000: Standard training
#@markdown - 20000-50000: Large-scale training (recommended for small/base models)

num_val_formulas = 5000 #@param {type:"slider", min:100, max:5000, step:100}
#@markdown Number of formulas for validation (typically 10-20% of training).

#@markdown ---
#@markdown ### üß¨ Data Generator
use_advanced_generator = True #@param {type:"boolean"}
#@markdown **Advanced generator** adds:
#@markdown - Function symbols: `P(f(x), g(y))` instead of just `P(x, y)`
#@markdown - Fixed signatures: `PRED_5` is *always* arity 2 (model learns consistency)
#@markdown - Horn clauses: `(A ∧ B ∧ C) → D` (common logic programming pattern)
#@markdown - Vacuous quantification: `∀x P(y)` (tests scope understanding)

#@markdown ---
#@markdown ### ⚙️ Training Parameters
num_epochs = 500 #@param {type:"slider", min:10, max:500, step:10}
#@markdown Number of training epochs.
#@markdown - 10-30: Quick experiments
#@markdown - 50-100: Standard training
#@markdown - 100-200: Train to convergence (watch for overfitting!)

batch_size = 32 #@param [32, 64, 128, 256] {type:"raw"}
#@markdown Larger batches = faster training but more memory.

#@markdown ---
#@markdown ### üíæ Resume from Checkpoint
resume_training = False #@param {type:"boolean"}
#@markdown Resume from the last saved checkpoint.

# Store configuration
config = {
    'model_size': model_size,
    'num_train_formulas': num_train_formulas,
    'num_val_formulas': num_val_formulas,
    'num_test_formulas': max(100, num_val_formulas // 2),
    'num_epochs': num_epochs,
    'batch_size': batch_size,
    'resume': resume_training,
    'use_advanced_generator': use_advanced_generator
}

# Display configuration summary
print("="*50)
print("üìã TRAINING CONFIGURATION")
print("="*50)

model_params = {'tiny': '566K', 'small': '3.5M', 'base': '19.6M'}
model_size_mb = {'tiny': '~2.2MB', 'small': '~14MB', 'base': '~78MB'}

print(f"\nüèóÔ∏è  Model:     {model_size} ({model_params[model_size]} parameters, {model_size_mb[model_size]})")
print(f"üìä Dataset:   {num_train_formulas} train / {num_val_formulas} val formulas")
print(f"üß¨ Generator: {'Advanced (functions, fixed signatures, Horn clauses)' if use_advanced_generator else 'Basic'}")
print(f"‚öôÔ∏è  Training:  {num_epochs} epochs, batch size {batch_size}")
print(f"üîÑ Resume:    {'Yes' if resume_training else 'No (fresh start)'}")

# Estimate training time
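# Rough guess at samples per formula. The advanced-generator run logged in Step 3
# produced about 34 (1,682,302 samples from 50,000 formulas), so this estimate is approximate.
samples_per_formula = 15 if use_advanced_generator else 14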
total_samples = num_train_formulas * samples_per_formula
batches_per_epoch = total_samples // batch_size
time_per_batch = {'tiny': 0.008, 'small': 0.020, 'base': 0.040}  # seconds on A100
est_epoch_time = batches_per_epoch * time_per_batch[model_size]
est_total_time = est_epoch_time * num_epochs / 60

print(f"\n⏱️  Estimated training time: ~{est_total_time:.0f} minutes on A100")
print(f"   ({est_epoch_time:.0f}s per epoch, {batches_per_epoch} batches)")

# Recommendations
if model_size in ['small', 'base'] and num_train_formulas < 10000:
    print(f"\n💡 TIP: {model_size} model benefits from more data.")
    print(f"   Consider increasing to 20000+ formulas.")

if num_epochs > 100 and not use_advanced_generator:
    print(f"\n⚠️  WARNING: High epochs ({num_epochs}) with basic generator.")
    print(f"   Risk of overfitting! Consider:")
    print(f"   - Enabling advanced generator for richer data")
    print(f"   - Or reducing epochs to 50-100")

print("\n" + "="*50)
print("✅ Configuration saved! Proceed to Step 3.")
print("="*50)

📋 TRAINING CONFIGURATION

🏗️  Model:     tiny (566K parameters, ~2.2MB)
📊 Dataset:   50000 train / 5000 val formulas
🧬 Generator: Advanced (functions, fixed signatures, Horn clauses)
⚙️  Training:  500 epochs, batch size 32
🔄 Resume:    No (fresh start)

⏱️  Estimated training time: ~1562 minutes on A100
   (187s per epoch, 23437 batches)

✅ Configuration saved! Proceed to Step 3.
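
The estimate printed above assumes 15 samples per formula and the configured batch size of 32, whereas the Step 4 log reports 1,682,302 training samples and a batch size of 64. A minimal sketch of redoing the estimate from those logged figures follows. The per-batch timings are the notebook's own A100 guesses (and probably shift with batch size), and `estimate` is a hypothetical helper rather than part of the repository.

```python
import math

# Seconds per batch on an A100, copied from the configuration cell.
SECONDS_PER_BATCH = {"tiny": 0.008, "small": 0.020, "base": 0.040}

def estimate(num_samples, batch_size, model_size, num_epochs):
    """Return (batches per epoch, seconds per epoch, total minutes)."""
    # Ceiling division keeps the final partial batch; this matches the
    # batch count that the Step 4 training log reports for this run.
    batches = math.ceil(num_samples / batch_size)
    epoch_s = batches * SECONDS_PER_BATCH[model_size]
    return batches, epoch_s, epoch_s * num_epochs / 60

# Figures taken from the Step 4 log: 1,682,302 samples, batch size 64.
batches, epoch_s, total_min = estimate(1_682_302, 64, "tiny", 500)
print(f"{batches} batches, ~{epoch_s:.0f}s per epoch, ~{total_min:.0f} minutes total")
```

Treat the result as an order-of-magnitude guide; timing a few hundred batches of the real run gives a better per-batch figure.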


## 3️⃣ Generate Training Data
Create synthetic First-Order Logic formulas for training.

In [3]:
#@title 3. Generate Training Data { display-mode: "form" }
#@markdown This generates synthetic FOL formulas.
#@markdown
#@markdown **Basic generator examples:**
#@markdown - `∀x₁ (P₅(x₁) → Q₂(x₁))`
#@markdown - `∃x₀ ∃x₁ (R₃(x₀, x₁) ∧ P₁(x₀))`
#@markdown
#@markdown **Advanced generator adds:**
#@markdown - `∀x P(f(x), g(x, y))` (function symbols)
#@markdown - `(A ∧ B ∧ C) → D` (Horn clauses)
#@markdown - Fixed predicate arities across all formulas

import os
os.chdir('/content/Symbolic-Transformers')

print("üîÑ Generating training data...")
print(f"   Train: {config['num_train_formulas']} formulas")
print(f"   Val:   {config['num_val_formulas']} formulas")
print(f"   Test:  {config['num_test_formulas']} formulas")
print(f"   Generator: {'Advanced' if config['use_advanced_generator'] else 'Basic'}")
print()

import sys
sys.path.insert(0, '/content/Symbolic-Transformers')

if config['use_advanced_generator']:
    # Use advanced generator with functions, fixed signatures, Horn clauses
    from data.advanced_generator import generate_advanced_training_data

    generate_advanced_training_data(
        vocab_path="unified_vocabulary.json",
        output_dir="datasets/fol_next_symbol",
        n_train=config['num_train_formulas'],
        n_val=config['num_val_formulas'],
        n_test=config['num_test_formulas'],
    )
else:
    # Use basic generator
    from data.dataset_generator import generate_training_data

    generate_training_data(
        vocab_path="unified_vocabulary.json",
        output_dir="datasets/fol_next_symbol",
        n_train=config['num_train_formulas'],
        n_val=config['num_val_formulas'],
        n_test=config['num_test_formulas'],
    )

# Show dataset stats
print("\nüìÅ Dataset files:")
!ls -lh datasets/fol_next_symbol/*.json

print("\n" + "="*50)
print("‚úÖ Data generated! Proceed to Step 4.")
print("="*50)

🔄 Generating training data...
   Train: 50000 formulas
   Val:   5000 formulas
   Test:  2500 formulas
   Generator: Advanced

ADVANCED FOL DATASET GENERATION
✓ Vocabulary loaded: 662 tokens
  - Numerals: 0-624
  - Symbols: 625-661
  - Compositional: ['VAR', 'CONST', 'PRED', 'FUNC', 'SORT']

ℹ️ Fixed Signatures (consistent across all formulas):
   Predicates: {0: 2, 1: 1, 2: 1, 3: 1, 4: 2, 5: 2, 6: 2, 7: 1, 8: 2, 9: 1}
   Functions:  {0: 1, 1: 1, 2: 1, 3: 1}

Generating train set (50000 formulas)...
  Complexity 1: 10000 formulas
  Complexity 2: 20000 formulas
  Complexity 3: 15000 formulas
  Complexity 4: 5000 formulas

Generating val set (5000 formulas)...
  Complexity 1: 1000 formulas
  Complexity 2: 2000 formulas
  Complexity 3: 1500 formulas
  Complexity 4: 500 formulas

Generating test set (2500 formulas)...
  Complexity 1: 500 formulas
  Complexity 2: 1000 formulas
  Complexity 3: 750 formulas
  Complexity 4: 250 formulas
✓ Saved 1682302 samples to datasets/fol_next_s

## 4️⃣ Train the Model
Start training! Watch the loss decrease over epochs.

In [None]:
#@title 4. Train Model { display-mode: "form" }
#@markdown Training will begin with the configuration from Step 2.
#@markdown
#@markdown **What to watch for:**
#@markdown - Loss should decrease over epochs
#@markdown - Val loss < 1.0 is good progress
#@markdown - Val loss < 0.85 is excellent
#@markdown - `[BEST]` indicates a new best checkpoint was saved
#@markdown - **Stop if val loss starts rising** (overfitting)

import os
os.chdir('/content/Symbolic-Transformers')

print("="*60)
print("üöÄ STARTING TRAINING")
print("="*60)
print(f"Model: {config['model_size']} | Epochs: {config['num_epochs']}")
if config['resume']:
    print("Resuming from last checkpoint...")
print("="*60 + "\n")

# Build training command
# Note: the batch size chosen in Step 2 is not passed to train.py here,
# so training uses train.py's default batch size of 64.
cmd = f"python training/train.py --model-size {config['model_size']} --num-epochs {config['num_epochs']}"
if config['resume']:
    cmd += " --resume"

# Run training
!{cmd}

print("\n" + "="*60)
print("✅ TRAINING COMPLETE!")
print("="*60)
print("\n📁 Saved checkpoints:")
!ls -lh checkpoints/*.pt 2>/dev/null | tail -5
print("\nBest model saved to: checkpoints/best_model.pt")

🚀 STARTING TRAINING
Model: tiny | Epochs: 500

✓ Vocabulary loaded: 662 tokens
  - Numerals: 0-624
  - Symbols: 625-661
  - Compositional: ['VAR', 'CONST', 'PRED', 'FUNC', 'SORT']

✓ Loaded vocabulary: 662 tokens

Loading datasets...
✓ Loaded 1682302 samples from datasets/fol_next_symbol/train.json
✓ Loaded 166031 samples from datasets/fol_next_symbol/val.json
✓ Train batches: 26286
✓ Val batches: 2595

Creating model...
✓ Created tiny model with 566,934 parameters
✓ Using device: cuda
✓ GPU: NVIDIA A100-SXM4-80GB
✓ VRAM: 85.2 GB

TRAINING START
Model: tiny
Vocab size: 662
Batch size: 64
Epochs: 500
Device: cuda

Epoch 1/500
  Batch 100/26286 | Loss: 6.3634 | LR: 5.00e-06
  Batch 200/26286 | Loss: 6.2427 | LR: 1.00e-05
  Batch 300/26286 | Loss: 6.0659 | LR: 1.50e-05
  Batch 400/26286 | Loss: 5.8739 | LR: 2.00e-05
  Batch 500/26286 | Loss: 5.6750 | LR: 2.50e-05
  Batch 600/26286 | Loss: 5.4612 | LR: 3.00e-05
  Batch 700/26286 | Loss: 5.2400 | LR: 3.50e-05
  Batch 

## 5️⃣ Evaluate Model (Optional)
Run evaluation on the test set to see accuracy metrics.

In [None]:
#@title 5. Evaluate Model { display-mode: "form" }
#@markdown Run evaluation on the test set.

import os
os.chdir('/content/Symbolic-Transformers')

print("üìä Evaluating model on test set...\n")

!python evaluate_model.py \
    --checkpoint checkpoints/best_model.pt \
    --test-data datasets/fol_next_symbol/test.json

📊 Evaluating model on test set...


EVALUATING TRAINED MODEL

Loading vocabulary...
✓ Vocabulary size: 662

Loading model...
✓ Created tiny model with 566,934 parameters
✓ Loaded model from epoch 33
✓ Best val loss: 0.9493

Loading test data...
✓ Test samples: 82146

ACCURACY METRICS

Computing top-1 accuracy...
✓ Top-1 Accuracy: 26.41%
  (26.4% chance of exact next symbol)

Computing top-5 accuracy...


## 6️⃣ Download Trained Model (Optional)
Download the trained model checkpoint to your local machine.

In [None]:
#@title 6. Download Model { display-mode: "form" }
#@markdown Download the best model checkpoint.

from google.colab import files
import os

checkpoint_path = '/content/Symbolic-Transformers/checkpoints/best_model.pt'

if os.path.exists(checkpoint_path):
    print(f"üì¶ Preparing download...")
    file_size = os.path.getsize(checkpoint_path) / (1024 * 1024)
    print(f"   File: best_model.pt ({file_size:.1f} MB)")
    print(f"   Model: {config['model_size']}")
    print()
    files.download(checkpoint_path)
else:
    print("❌ No checkpoint found. Run training first (Step 4).")

## 7️⃣ Interactive Demo (Optional)
Run the model on a few sample token sequences and inspect its top predictions. Edit the `predict_next` calls to try your own inputs.

In [None]:
#@title 7. Quick Demo { display-mode: "form" }
#@markdown See the model's predictions for a sample input.

import torch
import sys
sys.path.insert(0, '/content/Symbolic-Transformers')

from utils.vocabulary import Vocabulary
from models.transformer import SymbolicTransformer, get_model_config

# Load model
print("üîÑ Loading model...")
vocab = Vocabulary('unified_vocabulary.json')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

checkpoint = torch.load('checkpoints/best_model.pt', map_location=device, weights_only=False)
model_config = get_model_config(checkpoint['config']['model_size'], vocab.vocab_size)
model = SymbolicTransformer(**model_config).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✓ Loaded {checkpoint['config']['model_size']} model from epoch {checkpoint['epoch']}")
print(f"  Val loss: {checkpoint['best_val_loss']:.4f}\n")

# Demo predictions
def predict_next(tokens, top_k=5):
    """Predict next token given a sequence."""
    token_ids = [vocab.encode_label(t) if t in vocab.label_to_id else int(t) for t in tokens]
    x = torch.tensor([token_ids], device=device)

    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits[0, -1], dim=-1)
        top_probs, top_ids = torch.topk(probs, top_k)

    print(f"Input: {' '.join(tokens)}")
    print(f"\nTop {top_k} predictions:")
    for i, (prob, tid) in enumerate(zip(top_probs, top_ids)):
        label = vocab.decode_id(tid.item())
        bar = '█' * int(prob * 30)
        print(f"  {i+1}. {label:12s} {prob*100:5.1f}% {bar}")
    print()

# Show example predictions
print("="*50)
print("üìä EXAMPLE PREDICTIONS")
print("="*50 + "\n")

# After FORALL, should predict VAR with high confidence
predict_next(['FORALL'])

# After FORALL VAR, should predict a numeral
predict_next(['FORALL', 'VAR'])

# After a complete variable binding
predict_next(['FORALL', 'VAR', '1'])

# After predicate (should predict LPAREN)
predict_next(['PRED', '3'])

print("\nüí° The model learned FOL syntax rules!")

---

## 📖 Quick Reference

### Model Sizes
| Size | Parameters | File Size | Time/Epoch (A100) | Capacity |
|------|-----------|-----------|-------------------|----------|
| tiny | 566K | ~2.2MB | ~30s | Good for <10K formulas |
| small | 3.5M | ~14MB | ~90s | Good for 10-50K formulas |
| base | 19.6M | ~78MB | ~180s | Good for 50K+ formulas |
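
The file sizes in the table are roughly what the parameter counts imply if every weight is stored as a 4-byte float32. The sketch below checks that arithmetic; it assumes float32 weights and ignores anything else a checkpoint may contain (such as optimizer state), so real `.pt` files can be larger. `float32_size_mb` is a hypothetical helper.

```python
def float32_size_mb(num_params, bytes_per_param=4):
    """Approximate size in MB (10**6 bytes) of the weights alone."""
    return num_params * bytes_per_param / 1e6

# The tiny count is the exact figure from the training log; the other
# two are the rounded values from the table above.
for name, params in [("tiny", 566_934), ("small", 3_500_000), ("base", 19_600_000)]:
    print(f"{name:5s} {params:>12,d} params  ~{float32_size_mb(params):.1f} MB")
```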

### Data Generator Comparison
| Feature | Basic | Advanced |
|---------|-------|----------|
| Predicates `P(x, y)` | ✓ | ✓ |
| Functions `f(x)` | ✗ | ✓ |
| Nested terms `P(f(g(x)))` | ✗ | ✓ |
| Fixed arities | ✗ | ✓ |
| Horn clauses | ✗ | ✓ |
| Vacuous quantification | ✗ | ✓ |

### Recommended Configurations

**Quick Test (5-10 min):**
- Model: tiny
- Data: 1000 formulas (basic)
- Epochs: 20

**Standard Training (30-60 min):**
- Model: small
- Data: 10000 formulas (advanced)
- Epochs: 50-100

**Best Results (2+ hours):**
- Model: base
- Data: 30000+ formulas (advanced)
- Epochs: 100-200 (watch val loss!)
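
One way to apply these presets without moving the sliders is to build the Step 2 `config` dict in code. The sketch below uses the same keys as that dict. `PRESETS` and `build_config` are hypothetical names, the epoch and formula counts are the low end of each recommended range, and the 10% validation split and batch size of 64 are assumptions.

```python
# Hypothetical presets mirroring the three recommendations above.
PRESETS = {
    "quick_test": {"model_size": "tiny", "num_train_formulas": 1000,
                   "use_advanced_generator": False, "num_epochs": 20},
    "standard": {"model_size": "small", "num_train_formulas": 10000,
                 "use_advanced_generator": True, "num_epochs": 50},
    "best_results": {"model_size": "base", "num_train_formulas": 30000,
                     "use_advanced_generator": True, "num_epochs": 100},
}

def build_config(preset, val_fraction=0.1, batch_size=64, resume=False):
    """Return a dict with the same keys as the Step 2 `config`."""
    base = PRESETS[preset]
    # Validation size as a fraction of training size, never below 100.
    num_val = max(100, round(base["num_train_formulas"] * val_fraction))
    return {
        **base,
        "num_val_formulas": num_val,
        # Same rule as the configuration cell: half of val, at least 100.
        "num_test_formulas": max(100, num_val // 2),
        "batch_size": batch_size,
        "resume": resume,
    }

config_example = build_config("standard")
print(config_example)
```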

### Interpreting Results
- **Val Loss > 1.5**: Model is still learning basic patterns
- **Val Loss 1.0-1.5**: Making progress, learning syntax rules
- **Val Loss 0.85-1.0**: Good, model captures most FOL structure
- **Val Loss < 0.85**: Excellent, approaching optimal
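
To get a feel for these thresholds, it can help to convert loss to perplexity, the effective number of equally likely next tokens. The sketch below assumes the reported loss is a mean cross-entropy in nats, as PyTorch's `CrossEntropyLoss` computes by default; that is an assumption about `train.py`, although the first logged loss (about 6.36) sits close to ln 662 ≈ 6.50, which is what uniform guessing over the 662-token vocabulary would give.

```python
import math

VOCAB_SIZE = 662  # vocabulary size reported in the setup log

def perplexity(loss_nats):
    """Convert a mean cross-entropy loss in nats to perplexity."""
    return math.exp(loss_nats)

# Loss of a model that guesses uniformly over the whole vocabulary.
uniform_loss = math.log(VOCAB_SIZE)

for label, loss in [
    ("uniform guess", uniform_loss),
    ("threshold 1.5", 1.5),
    ("threshold 1.0", 1.0),
    ("threshold 0.85", 0.85),
]:
    print(f"{label:15s} loss={loss:.3f}  perplexity={perplexity(loss):.2f}")
```

A floor well above zero is expected for this task: many positions admit several valid continuations (any variable index, for instance), so even a model with perfect syntax cannot predict the exact next symbol every time.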

### ⚠️ Overfitting Warning Signs
- Train loss keeps dropping but val loss stops improving
- Val loss starts **increasing** while train loss decreases
- Gap between train and val loss > 0.1

**If overfitting:**
1. Stop training and use the checkpoint with lowest val loss
2. Generate more training data
3. Enable advanced generator for richer patterns
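
The warning signs above can be checked mechanically from per-epoch loss histories. The following is a minimal sketch, not part of the repository: `overfitting_signals` is a hypothetical helper, the `patience` parameter is an added assumption for deciding when val loss has "stopped improving", and the loss values in the usage example are made up for illustration.

```python
def overfitting_signals(train_losses, val_losses, gap_threshold=0.1, patience=3):
    """Evaluate the three warning signs on per-epoch loss histories."""
    if len(train_losses) != len(val_losses) or not val_losses:
        raise ValueError("need two non-empty histories of equal length")

    # Index of the epoch with the lowest validation loss so far.
    best_idx = min(range(len(val_losses)), key=val_losses.__getitem__)
    epochs_since_best = len(val_losses) - 1 - best_idx

    train_dropping = len(train_losses) >= 2 and train_losses[-1] < train_losses[-2]
    val_rising = len(val_losses) >= 2 and val_losses[-1] > val_losses[-2]

    return {
        "best_epoch": best_idx + 1,  # 1-based, like the training log
        "val_stalled": train_dropping and epochs_since_best >= patience,
        "val_rising": train_dropping and val_rising,
        "gap_too_large": (val_losses[-1] - train_losses[-1]) > gap_threshold,
    }

# Illustrative (made-up) histories: val loss bottoms out at epoch 4.
train = [1.40, 1.10, 0.95, 0.88, 0.82, 0.78]
val = [1.45, 1.15, 1.00, 0.97, 0.98, 0.99]
signals = overfitting_signals(train, val, patience=2)
print(signals)
```

If any flag is set, `best_epoch` indicates which checkpoint to fall back to, in line with step 1 above.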

### Files Created
- `checkpoints/best_model.pt` - Best performing model
- `checkpoints/checkpoint_epoch_N.pt` - Periodic checkpoints
- `datasets/fol_next_symbol/` - Training data