# Forward-Only Token-Adaptive Learning (FOTAL) Demo

This notebook demonstrates the capabilities of the NoBackdrop model, which implements Forward-Only Token-Adaptive Learning (FOTAL). This novel approach enables efficient language model training without backpropagation, making it suitable for training on limited hardware such as an RTX 3050 GPU.

In [None]:
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from datasets import load_dataset

# Add project root to path
sys.path.append(os.path.abspath('..'))

from no_backdrop.model.hebbian_lm import HebbianLM
from no_backdrop.training.trainer import Trainer
from no_backdrop.training.data_utils import prepare_dataloaders

## 1. Creating a FOTAL Model

First, let's create a small FOTAL model (a `HebbianLM` with 4 layers and a hidden size of 256) and a `Trainer` to drive it. The sizes are deliberately small for demonstration purposes.

In [None]:
# Use the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# GPT-2's tokenizer defines no padding token, so reuse the EOS token for padding.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Special-token ids for the model; tokenizers without BOS/EOS fall back to CLS/SEP.
bos_id = tokenizer.bos_token_id
if bos_id is None:
    bos_id = tokenizer.cls_token_id
eos_id = tokenizer.eos_token_id
if eos_id is None:
    eos_id = tokenizer.sep_token_id

# A deliberately small FOTAL model for demonstration purposes.
model = HebbianLM(
    vocab_size=len(tokenizer),
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=8,
    window_size=64,
    dropout=0.1,
    max_position_embeddings=512,
    update_rate=0.01,
    use_fast_weights=True,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=bos_id,
    eos_token_id=eos_id,
)
model.to(device)

# The trainer logs every 10 steps, evaluates every 50 and checkpoints every 100.
trainer = Trainer(
    model=model,
    learning_rate=5e-5,
    weight_decay=0.01,
    device=device,
    log_interval=10,
    eval_interval=50,
    save_interval=100,
    checkpoint_dir="./checkpoints",
    use_wandb=False,
)

num_params = sum(param.numel() for param in model.parameters())
print(f"Model created with {num_params} parameters")

## 2. Training on a Small Dataset

Let's train our model on a small dataset to demonstrate the forward-only learning approach.

In [None]:
import itertools

# Load a small dataset (raw WikiText-2)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# How many non-empty texts to keep for the demonstration
num_train_texts = 100
num_eval_texts = 20


def _first_non_empty(texts, limit):
    """Return up to ``limit`` entries of ``texts`` that are not blank.

    The raw split contains whitespace-only rows. Filtering them out *before*
    truncating keeps the subset at the requested size instead of silently
    shrinking it (slicing first and filtering afterwards leaves fewer texts
    than asked for).
    """
    return list(itertools.islice((text for text in texts if text.strip()), limit))


train_texts = _first_non_empty(dataset["train"]["text"], num_train_texts)
eval_texts = _first_non_empty(dataset["validation"]["text"], num_eval_texts)

print(f"Training on {len(train_texts)} texts, evaluating on {len(eval_texts)} texts")

# Prepare dataloaders
train_dataloader, eval_dataloader = prepare_dataloaders(
    train_texts=train_texts,
    eval_texts=eval_texts,
    tokenizer=tokenizer,
    batch_size=2,  # Small batch size for demonstration
    max_length=128,  # Short sequences for demonstration
    stride=64,
    num_workers=0,  # Use 0 for Jupyter notebook
    pad_token_id=tokenizer.pad_token_id,
)

In [None]:
# Run a short training session. max_steps presumably stops the run after
# 50 steps even if the single epoch is not finished (confirm against Trainer).
run_settings = {
    "num_epochs": 1,
    "max_steps": 50,
}
history = trainer.train(
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    **run_settings,
)

## 3. Visualizing Training Progress

In [None]:
# One figure per metric: (train key, eval key, y-axis label, train label, eval label, title)
metric_panels = (
    ("train_loss", "eval_loss", "Loss",
     "Train Loss", "Eval Loss", "Training and Evaluation Loss"),
    ("train_perplexity", "eval_perplexity", "Perplexity",
     "Train Perplexity", "Eval Perplexity", "Training and Evaluation Perplexity"),
)

for train_key, eval_key, y_label, train_label, eval_label, title in metric_panels:
    eval_series = history[eval_key]
    # NOTE(review): the i-th evaluation is drawn at step i * eval_interval
    # (the first at step 0); confirm this matches when Trainer evaluates.
    eval_x = [idx * trainer.eval_interval for idx in range(len(eval_series))]

    plt.figure(figsize=(10, 6))
    plt.plot(history[train_key], label=train_label)
    plt.plot(eval_x, eval_series, label=eval_label, marker="o")
    plt.xlabel("Steps")
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

## 4. Text Generation

Let's generate some text with our trained model.

In [None]:
# Open-ended prompts for the briefly trained model
prompts = [
    "Once upon a time",
    "The meaning of life is",
    "In the future, artificial intelligence will",
]

# Sampling settings shared by every prompt; update_model=False asks
# generate_text not to adapt the model while generating.
sampling_kwargs = dict(
    tokenizer=tokenizer,
    max_length=50,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    update_model=False,
)

for prompt in prompts:
    generated_texts = trainer.generate_text(prompt=prompt, **sampling_kwargs)
    print(f"\nPrompt: {prompt}")
    print(f"Generated: {generated_texts[0]}")

## 5. Streaming Adaptation

One of the key features of FOTAL is its ability to adapt during inference. Let's demonstrate this capability.

In [None]:
prompt = "The capital of France is"


def _sample_and_report(text_prompt, prompt_line_prefix):
    """Generate from ``text_prompt`` without adapting the model and print the result.

    ``prompt_line_prefix`` is printed immediately before the prompt text.
    Returns the list produced by ``trainer.generate_text``.
    """
    samples = trainer.generate_text(
        prompt=text_prompt,
        tokenizer=tokenizer,
        max_length=30,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        update_model=False,
    )
    print(f"{prompt_line_prefix}{text_prompt}")
    print(f"Generated: {samples[0]}")
    return samples


# Baseline generation before the model has seen the facts below
print("Before adaptation:")
generated_texts = _sample_and_report(prompt, "Prompt: ")

# Feed the facts through the model once with update_model=True so it can adapt
adaptation_text = "The capital of France is Paris. The capital of Italy is Rome. The capital of Spain is Madrid."
print(f"\nAdapting model with: {adaptation_text}")

tokenized = tokenizer(adaptation_text, return_tensors="pt")
adaptation_ids = tokenized["input_ids"].to(device)
attention_mask = torch.ones_like(adaptation_ids)

# No autograd graph is needed: the adaptation happens inside the forward pass
with torch.no_grad():
    outputs = model(
        input_ids=adaptation_ids,
        attention_mask=attention_mask,
        update_model=True,
        compute_loss=False,
    )

# Same prompt again, then a prompt about one of the other adapted facts
print("\nAfter adaptation:")
generated_texts = _sample_and_report(prompt, "Prompt: ")

new_prompt = "The capital of Italy is"
generated_texts = _sample_and_report(new_prompt, "\nPrompt: ")

## 6. Single Batch Learning

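FOTAL models are intended to learn from a single batch. To test this, we create a fresh model, measure its loss and perplexity on a short text, run one training step on that same text, and then measure again.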

In [None]:
# Create a freshly initialised model for this demonstration
new_model = HebbianLM(
    vocab_size=len(tokenizer),
    hidden_size=256,
    num_hidden_layers=4,
    num_attention_heads=8,
    window_size=64,
    dropout=0.1,
    max_position_embeddings=512,
    update_rate=0.05,  # Higher update rate for faster learning
    use_fast_weights=True,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id,
    eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id,
)

# Move model to device
new_model.to(device)

# Create trainer
new_trainer = Trainer(
    model=new_model,
    learning_rate=5e-5,
    weight_decay=0.01,
    device=device,
)

# Define a text to learn
learning_text = """
The quick brown fox jumps over the lazy dog. 
Python is a high-level, interpreted programming language. 
Machine learning is a field of study that gives computers the ability to learn without being explicitly programmed.
"""

# Tokenize the text
tokenized = tokenizer(learning_text, return_tensors="pt")
input_ids = tokenized["input_ids"].to(device)
attention_mask = torch.ones_like(input_ids)

# Prepare batch
batch = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}


def _score_without_update(lm, ids, mask):
    """Compute ``lm``'s loss on ``ids`` in eval mode with adaptation disabled.

    Returns the model's outputs dict (``outputs["loss"]`` holds the loss).
    Leaves ``lm`` in eval mode.
    """
    lm.eval()
    with torch.no_grad():
        return lm(
            input_ids=ids,
            attention_mask=mask,
            update_model=False,
            compute_loss=True,
        )


# Measure initial perplexity
outputs = _score_without_update(new_model, input_ids, attention_mask)
initial_loss = outputs["loss"].item()
initial_perplexity = new_trainer._compute_perplexity(outputs["loss"])

print(f"Initial loss: {initial_loss:.4f}")
print(f"Initial perplexity: {initial_perplexity:.2f}")

# Train on a single batch. The measurement above left the model in eval mode;
# switch back to training mode first (the speed benchmark likewise calls
# model.train() explicitly before Trainer.train_step).
print("\nTraining on a single batch...")
new_model.train()
metrics = new_trainer.train_step(batch)

# Measure final perplexity (this also puts the model back in eval mode for generation)
outputs = _score_without_update(new_model, input_ids, attention_mask)
final_loss = outputs["loss"].item()
final_perplexity = new_trainer._compute_perplexity(outputs["loss"])

print(f"Final loss: {final_loss:.4f}")
print(f"Final perplexity: {final_perplexity:.2f}")
print(f"Improvement: {initial_perplexity - final_perplexity:.2f}")

# Generate text based on the learned content
prompts = [
    "The quick brown",
    "Python is a",
    "Machine learning is",
]

print("\nGenerating text based on learned content:")
for prompt in prompts:
    generated_texts = new_trainer.generate_text(
        prompt=prompt,
        tokenizer=tokenizer,
        max_length=30,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        update_model=False,
    )

    print(f"\nPrompt: {prompt}")
    print(f"Generated: {generated_texts[0]}")

## 7. Memory Usage Analysis

Let's compare the memory usage of our FOTAL model with that of a similarly sized GPT-2 baseline. The measurement relies on CUDA memory statistics, so this section requires a GPU.

In [None]:
# Create a baseline model for comparison
from transformers import GPT2Config, GPT2LMHeadModel

# A GPT-2 with the same width, depth, head count and context length as the FOTAL model
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=512,
    n_ctx=512,
    n_embd=256,
    n_layer=4,
    n_head=8,
)

baseline_model = GPT2LMHeadModel(config)
baseline_model.to(device)

# Random input shared by both measurements
input_ids = torch.randint(0, model.vocab_size, (1, 512), device=device)
attention_mask = torch.ones_like(input_ids)


def _peak_step_memory_mb(step_fn):
    """Run ``step_fn`` once and return the peak CUDA memory it added, in MB.

    The memory already allocated before the call is subtracted. Otherwise both
    models' resident weights (and tensors left over from earlier cells) would
    dominate both numbers and make the comparison meaningless.
    """
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    allocated_before = torch.cuda.memory_allocated()
    step_fn()
    torch.cuda.synchronize()
    return (torch.cuda.max_memory_allocated() - allocated_before) / (1024 ** 2)


def _no_backdrop_step():
    """One FOTAL training step: a forward pass with in-place model updates."""
    model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        update_model=True,
        compute_loss=True,
    )


def _baseline_step():
    """One backprop step up to the gradients: forward pass plus loss.backward().

    Optimizer state is not included, so this understates full AdamW training memory.
    """
    baseline_outputs = baseline_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=input_ids,
    )
    baseline_outputs.loss.backward()


if device.type != "cuda":
    # torch.cuda memory statistics (and synchronize) are unavailable without a GPU.
    print("CUDA is not available; skipping the memory comparison (it relies on torch.cuda memory statistics).")
else:
    no_backdrop_memory = _peak_step_memory_mb(_no_backdrop_step)
    baseline_memory = _peak_step_memory_mb(_baseline_step)
    # Release the gradient buffers created by the baseline measurement.
    baseline_model.zero_grad(set_to_none=True)

    # Calculate memory ratio
    memory_ratio = no_backdrop_memory / baseline_memory if baseline_memory > 0 else float('inf')

    print(f"NoBackdrop Memory Usage: {no_backdrop_memory:.2f} MB")
    print(f"Baseline Memory Usage: {baseline_memory:.2f} MB")
    print(f"Memory Ratio: {memory_ratio:.2f}x")

    # Plot memory usage
    plt.figure(figsize=(10, 6))
    plt.bar(["NoBackdrop", "Baseline"], [no_backdrop_memory, baseline_memory])
    plt.title("Memory Usage Comparison")
    plt.ylabel("Peak Memory During One Step (MB)")
    plt.grid(True, axis='y')
    plt.show()

## 8. Training Speed Analysis

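Let's compare the training speed of our FOTAL model with the same GPT-2 baseline trained by standard backpropagation (AdamW). Both models are timed for a fixed number of training steps on identical random batches.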

In [None]:
import time

# Random token batch shared by both benchmarks
batch_size = 4
seq_len = 128
input_ids = torch.randint(0, model.vocab_size, (batch_size, seq_len), device=device)
attention_mask = torch.ones_like(input_ids)

# Prepare batch
batch = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}

num_steps = 10
# Untimed iterations that absorb one-off start-up costs (e.g. lazy CUDA
# initialisation, allocator growth, AdamW creating its state on the first step).
num_warmup_steps = 2


def _sync():
    """Wait for queued GPU work so wall-clock timings are accurate; no-op on CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def _time_steps(step_fn):
    """Run warm-up steps, then return the seconds taken by ``num_steps`` calls of ``step_fn``."""
    for _ in range(num_warmup_steps):
        step_fn()
    _sync()
    start_time = time.perf_counter()
    for _ in range(num_steps):
        step_fn()
    _sync()
    return time.perf_counter() - start_time


# Benchmark NoBackdrop model (train_step is called after an explicit model.train())
model.train()
no_backdrop_time = _time_steps(lambda: trainer.train_step(batch))

# Benchmark baseline model with standard backpropagation + AdamW
baseline_model.train()
optimizer = torch.optim.AdamW(baseline_model.parameters(), lr=5e-5)
# Drop any gradients left over from earlier cells so they don't leak into the first step.
optimizer.zero_grad(set_to_none=True)


def _baseline_train_step():
    """One forward / backward / optimizer step of the GPT-2 baseline."""
    step_outputs = baseline_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=input_ids,
    )
    step_outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()


baseline_time = _time_steps(_baseline_train_step)

# Calculate speed ratio
speed_ratio = baseline_time / no_backdrop_time if no_backdrop_time > 0 else float('inf')

print(f"NoBackdrop Training Time: {no_backdrop_time:.4f} seconds")
print(f"Baseline Training Time: {baseline_time:.4f} seconds")
print(f"Speed Ratio: {speed_ratio:.2f}x")
print(f"NoBackdrop Steps per Second: {num_steps / no_backdrop_time:.2f}")
print(f"Baseline Steps per Second: {num_steps / baseline_time:.2f}")

# Plot training speed
plt.figure(figsize=(10, 6))
plt.bar(["NoBackdrop", "Baseline"], [num_steps / no_backdrop_time, num_steps / baseline_time])
plt.title("Training Speed Comparison")
plt.ylabel("Steps per Second")
plt.grid(True, axis='y')
plt.show()

## 9. Conclusion

In this notebook, we've demonstrated the key features of the NoBackdrop model with Forward-Only Token-Adaptive Learning (FOTAL):

1. **Efficient Training**: The model can be trained without backpropagation, making it suitable for limited hardware.
2. **Streaming Adaptation**: The model can adapt to new information during inference.
3. **Single Batch Learning**: The model can learn effectively from a single batch of data.
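4. **Memory Efficiency**: The model aims to use less memory than traditional backpropagation models; Section 7 measures this against a similarly sized GPT-2 baseline.
5. **Training Speed**: The model aims to train faster than traditional backpropagation models; Section 8 measures steps per second against the same baseline.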

These features make FOTAL a promising approach for efficient language model training and adaptation on limited hardware.