# Topology-Aware Learning Trajectory (TALT) Optimizer Demo

This notebook demonstrates how to use the improved TALT optimizer for training neural networks.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os

# Ensure TALT is installed/available
import sys
sys.path.append('..')

import talt

## 1. Setup

Let's set up our environment and parameters for the experiment.

In [None]:
# Seed every random number generator this notebook uses so that a fresh
# "Restart Kernel -> Run All" starts from the same state each time.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)  # NumPy's global RNG is independent of torch's
if torch.cuda.is_available():
    # Seed all visible GPUs, not only the currently selected one
    torch.cuda.manual_seed_all(seed)

# Run on the GPU when one is present, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Experiment parameters (read by the data-loading and training cells below)
dataset_name = 'cifar10'
batch_size = 128
epochs = 5
learning_rate = 0.01

## 2. Load Dataset

We'll use the CIFAR-10 dataset for this demonstration.

In [None]:
# Build the train/test loaders and fetch the dataset's shape metadata,
# timing the call with talt.Timer.
with talt.Timer("Loading dataset"):
    train_loader, test_loader, num_channels, image_size, num_classes = talt.get_loaders(
        dataset_name=dataset_name,
        batch_size=batch_size,
        data_dir='./data'
    )


def imshow(img, ax=None):
    """Display a single image tensor with matplotlib.

    Args:
        img: 3-D image tensor; its first axis is moved last before plotting.
        ax: Optional matplotlib Axes to draw into. When omitted, a new
            10x4-inch figure is created and its current Axes is used.
    """
    # assumes the loader normalised pixels with mean 0.5 / std 0.5 -- TODO confirm
    img = img / 2 + 0.5
    # detach()/cpu() keep this working for tensors that track gradients or live on a GPU
    npimg = img.detach().cpu().numpy()
    if ax is None:
        plt.figure(figsize=(10, 4))
        ax = plt.gca()
    ax.imshow(np.transpose(npimg, (1, 2, 0)))
    ax.axis('off')


# Grab one batch of training images
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Human-readable names indexed by label; presumably in the loader's CIFAR-10
# label order -- only meaningful while dataset_name == 'cifar10'.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Show up to 8 samples side by side (fewer if the batch is smaller)
num_samples = min(8, len(images))
fig, axes = plt.subplots(1, num_samples, figsize=(16, 2))
axes = np.atleast_1d(axes)  # subplots returns a bare Axes when there is one column

for i in range(num_samples):
    imshow(images[i], ax=axes[i])
    axes[i].set_title(classes[int(labels[i])])

plt.tight_layout()
plt.show()

## 3. Create Model

We'll use a simple CNN model provided by the TALT package.

In [None]:
# Instantiate the CNN from the dataset metadata and move it to the selected device
model = talt.SimpleCNN(
    num_channels=num_channels,
    image_size=image_size,
    num_classes=num_classes
).to(device)

# Report parameter counts
print("Model architecture:")
# Walk the parameters once, recording (element count, trainable?) for each
param_stats = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, is_trainable in param_stats if is_trainable)
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

## 4. Train with Standard SGD Optimizer

In [None]:
# Baseline run: the shared training helper with use_improved_talt disabled
print("\nTraining with standard SGD optimizer...")
sgd_run_kwargs = dict(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=epochs,
    learning_rate=learning_rate,
    use_improved_talt=False,
    device=device,
    visualization_dir="./visualizations",
    experiment_name="SGD_Example",
)
sgd_results = talt.train_and_evaluate_improved(**sgd_run_kwargs)

## 5. Train with Improved TALT Optimizer

In [None]:
# Re-seed torch before rebuilding the network so this run's initial weights
# (and presumably any torch-driven batch shuffling) do not depend on how much
# randomness the earlier cells consumed.
torch.manual_seed(seed)

# Fresh, untrained model for the TALT run
model = talt.SimpleCNN(
    num_channels=num_channels,
    image_size=image_size,
    num_classes=num_classes
).to(device)

# TALT-specific settings, forwarded unchanged to the training helper
talt_hparams = dict(
    projection_dim=32,
    update_interval=20,
    valley_strength=0.2,
    smoothing_factor=0.3,
)

# Train with improved TALT
print("\nTraining with improved TALT optimizer...")
talt_results = talt.train_and_evaluate_improved(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=epochs,
    learning_rate=learning_rate,
    use_improved_talt=True,
    device=device,
    **talt_hparams,
    visualization_dir="./visualizations",
    experiment_name="TALT_Example"
)

## 6. Compare Results

Let's compare the standard SGD optimizer with the improved TALT optimizer on training loss, test accuracy, and total training time.

In [None]:
# One figure, two panels: training loss (left) and test accuracy (right)
cmp_fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# (axes, results key, y-axis label, panel title) for each panel
cmp_panels = (
    (loss_ax, 'train_loss', 'Loss', 'Training Loss'),
    (acc_ax, 'test_acc', 'Accuracy (%)', 'Test Accuracy'),
)

for panel_ax, metric_key, y_label, panel_title in cmp_panels:
    # Same metric from both runs on shared axes
    panel_ax.plot(sgd_results[metric_key], label='SGD')
    panel_ax.plot(talt_results[metric_key], label='TALT')
    panel_ax.set_xlabel('Epochs')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

cmp_fig.tight_layout()
plt.show()

# Print final results
print("\nFinal Results:")
print(f"SGD Test Accuracy: {sgd_results['final_test_acc']:.2f}% (Time: {sgd_results['total_time']:.2f}s)")
print(f"TALT Test Accuracy: {talt_results['final_test_acc']:.2f}% (Time: {talt_results['total_time']:.2f}s)")

## 7. View Visualizations

The TALT package automatically generates visualizations. Let's take a look at them.

In [None]:
# Directory for additional, user-generated visualizations
custom_viz_dir = "./visualizations/custom"
# Create it before constructing the visualizer so the visualizer is never
# handed a path that does not exist yet.
os.makedirs(custom_viz_dir, exist_ok=True)

model_state_dict = model.state_dict()  # not read in this cell; kept for interactive use
visualizer = talt.ImprovedTALTVisualizer(output_dir=custom_viz_dir)

# You can add data from your optimizer here
# visualizer.add_data(optimizer._visualization_data)

# Loss-trajectory images expected from the two training runs
sgd_viz_path = "./visualizations/SGD_Example/loss_trajectory.png"
talt_viz_path = "./visualizations/TALT_Example/loss_trajectory.png"
trajectory_panels = (
    (sgd_viz_path, "SGD Loss Trajectory"),
    (talt_viz_path, "TALT Loss Trajectory"),
)

# Only draw the comparison when both images are present
if all(os.path.exists(viz_path) for viz_path, _ in trajectory_panels):
    viz_fig, viz_axes = plt.subplots(1, 2, figsize=(18, 8))

    for viz_ax, (viz_path, viz_title) in zip(viz_axes, trajectory_panels):
        viz_ax.imshow(plt.imread(viz_path))
        viz_ax.axis('off')
        viz_ax.set_title(viz_title)

    viz_fig.tight_layout()
    plt.show()
else:
    print("Visualization files not found. Please check if the training completed successfully.")

## 8. Conclusion

In this notebook, we demonstrated how to use the improved TALT optimizer and compared its performance with the standard SGD optimizer. The loss-trajectory visualizations let you compare how each optimizer converges; TALT is designed to detect valleys in the loss landscape to accelerate convergence.