# Bolt-GAN Example Notebook

This notebook demonstrates how to use Bolt-GAN to generate synthetic multivariate time series data.

## 📋 Contents
1. Setup and Installation
2. Load and Prepare Data
3. Train Bolt-GAN
4. Generate Synthetic Data
5. Evaluate Quality
6. Save Results

## 1. Setup and Installation

In [None]:
# Install dependencies (uncomment if needed)
# !pip install torch numpy matplotlib scikit-learn

import numpy as np
import matplotlib.pyplot as plt
from boltgan import BoltGANTrainer
import torch

# Report the runtime environment so readers know which backend is in use
torch_version = torch.__version__
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch_version}")
print(f"CUDA available: {cuda_ok}")

## 2. Load and Prepare Data

For this example, we'll create noisy sine-wave data that stands in for a real dataset.
In practice, replace this with your own time series data, shaped `[samples, timesteps, features]`.

In [None]:
# Create example data (or load your own)
def create_example_data(n_samples=1000, seq_len=144, n_features=5, seed=None):
    """Create example multivariate time series data.

    Every feature of every sample is a noisy sine wave evaluated on ``seq_len``
    evenly spaced points of ``[0, 4*pi]``. Feature ``j`` uses frequency
    ``(j + 1) * 0.5``, a random phase in ``[0, 2*pi)``, a random amplitude in
    ``[1.0, 1.5)`` and additive Gaussian noise with standard deviation 0.1.

    Args:
        n_samples: Number of independent sequences to generate.
        seq_len: Number of time steps per sequence.
        n_features: Number of features per time step.
        seed: Optional integer seed for reproducible output. With the default
            ``None`` the values are drawn from NumPy's global random state, so
            a preceding ``np.random.seed(...)`` call is still honoured.

    Returns:
        np.ndarray of shape ``(n_samples, seq_len, n_features)``.
    """
    # A seeded RandomState yields the same stream as np.random.seed(seed)
    # followed by the module-level functions, without touching global state.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # The time axis is identical for every sample, so build it once.
    t = np.linspace(0, 4 * np.pi, seq_len)

    # Preallocating keeps the shape well-defined even when n_samples == 0.
    data = np.empty((n_samples, seq_len, n_features))
    for i in range(n_samples):
        for j in range(n_features):
            freq = (j + 1) * 0.5
            # Draw order (phase, amplitude, noise) is kept stable on purpose
            # so seeded runs produce the same numbers as before.
            phase = rng.rand() * 2 * np.pi
            amplitude = 1.0 + rng.rand() * 0.5
            noise = rng.randn(seq_len) * 0.1
            data[i, :, j] = amplitude * np.sin(freq * t + phase) + noise
    return data

# Seed NumPy's global RNG so the example data is identical on every run
np.random.seed(42)

# Generate example data
data = create_example_data(n_samples=1000, seq_len=144, n_features=5)
print(f"Data shape: {data.shape}")
print(f"Data range: [{data.min():.3f}, {data.max():.3f}]")

# Visualize a sample. Iterate over however many features the data actually
# has, so this cell keeps working after you swap in your own dataset.
plt.figure(figsize=(12, 4))
for i in range(data.shape[2]):
    plt.plot(data[0, :, i], label=f'Feature {i+1}')
plt.title('Example Real Data Sample')
plt.xlabel('Time')
plt.ylabel('Value')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Optional: Load your own data
# data = np.load('your_data.npy')  # Shape: [samples, timesteps, features]

## 3. Train Bolt-GAN

In [None]:
# Collect the hyperparameters in one place; sequence length and feature
# count are read from the data so the model matches its shape.
trainer_config = dict(
    seq_len=data.shape[1],
    feature_dim=data.shape[2],
    latent_dim=100,
    lstm_hidden=128,       # Reduced for faster training
    lstm_layers=2,         # Reduced for faster training
    dropout=0.3,
    lr=0.0002,
    num_epochs=100,        # Reduced for demo (use 2500 for real training)
    feedback_interval=20,  # Reduced for demo (use 500 for real training)
    feedback_epochs=10,
    batch_size=32,
)

# Initialize trainer
trainer = BoltGANTrainer(**trainer_config)

# Train (this will take a while)
print("Training Bolt-GAN...")
g_losses, d_losses = trainer.train(data, save_dir='./demo_results')

In [None]:
# Plot training losses: one panel per network, described as data so both
# panels share the same formatting code.
loss_panels = [
    (g_losses, 'Generator Loss', 'Generator Training Loss', {}),
    (d_losses, 'Discriminator Loss', 'Discriminator Training Loss', {'color': 'orange'}),
]

loss_fig, loss_axes = plt.subplots(1, 2, figsize=(12, 4))
for loss_ax, (losses, label, title, line_kwargs) in zip(loss_axes, loss_panels):
    loss_ax.plot(losses, label=label, **line_kwargs)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title(title)
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

loss_fig.tight_layout()
plt.show()

## 4. Generate Synthetic Data

In [None]:
# Generate synthetic samples
num_synthetic = 100
synthetic_data = trainer.generate(num_samples=num_synthetic)

print(f"Generated data shape: {synthetic_data.shape}")
print(f"Generated data range: [{synthetic_data.min():.3f}, {synthetic_data.max():.3f}]")

# Visualize synthetic samples. The feature count comes from the generated
# array rather than a hard-coded 5, so this works for any feature_dim.
plt.figure(figsize=(12, 4))
for i in range(synthetic_data.shape[2]):
    plt.plot(synthetic_data[0, :, i], label=f'Feature {i+1}')
plt.title('Example Synthetic Data Sample')
plt.xlabel('Time')
plt.ylabel('Value')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Compare real vs synthetic (overlay plot)
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

# Never index past the end of either array when fewer samples than panels exist
n_panels = min(len(axes), data.shape[0], synthetic_data.shape[0])

for idx in range(n_panels):
    ax = axes[idx]

    # Plot real (solid); keep the line handles so the legend can point at them
    real_lines = [
        ax.plot(data[idx, :, i], alpha=0.6, linewidth=1)[0]
        for i in range(data.shape[2])
    ]

    # Plot synthetic (dashed)
    synth_lines = [
        ax.plot(synthetic_data[idx, :, i], alpha=0.6, linewidth=1, linestyle='--')[0]
        for i in range(synthetic_data.shape[2])
    ]

    ax.set_title(f'Sample {idx+1}')
    ax.grid(True, alpha=0.3)

    if idx == 0 and real_lines and synth_lines:
        # Pass explicit handles: with labels alone, matplotlib would attach
        # both labels to the first two lines drawn, which are both real.
        ax.legend(
            [real_lines[0], synth_lines[0]],
            ['Real (solid)', 'Synthetic (dashed)'],
        )

# Hide any panels that had no sample to show
for unused_ax in axes[n_panels:]:
    unused_ax.set_visible(False)

plt.suptitle('Real vs Synthetic Data Comparison', fontsize=14)
plt.tight_layout()
plt.show()

## 5. Evaluate Quality

In [None]:
# Per-feature mean/std of real vs synthetic data, pooled over all samples
# and time steps
print("Statistical Comparison:")
print("=" * 50)

for feature_idx in range(data.shape[2]):
    real_values = data[:, :, feature_idx]
    synth_values = synthetic_data[:, :, feature_idx]

    real_mean, real_std = real_values.mean(), real_values.std()
    synth_mean, synth_std = synth_values.mean(), synth_values.std()

    mean_gap = abs(real_mean - synth_mean)
    std_gap = abs(real_std - synth_std)

    print(f"\nFeature {feature_idx + 1}:")
    print(f"  Real:      mean={real_mean:.3f}, std={real_std:.3f}")
    print(f"  Synthetic: mean={synth_mean:.3f}, std={synth_std:.3f}")
    print(f"  Difference: {mean_gap:.3f} (mean), {std_gap:.3f} (std)")

In [None]:
# Distribution comparison
from sklearn.manifold import TSNE

# Cap the number of embedded points for speed, but never ask for more rows
# than exist (reshape would fail on smaller datasets).
n_real = min(200, data.shape[0])
n_synth = min(100, synthetic_data.shape[0])

# Flatten each sequence to one vector for t-SNE
real_flat = data[:n_real].reshape(n_real, -1)
synth_flat = synthetic_data[:n_synth].reshape(n_synth, -1)

# Combine; label 0 = real, 1 = synthetic
combined = np.vstack([real_flat, synth_flat])
labels = np.array([0] * n_real + [1] * n_synth)

# t-SNE (perplexity must stay below the number of embedded points)
print("Computing t-SNE (this may take a minute)...")
perplexity = min(30, len(combined) - 1)
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
embedded = tsne.fit_transform(combined)

# Plot
plt.figure(figsize=(10, 8))
plt.scatter(embedded[labels==0, 0], embedded[labels==0, 1],
           c='red', alpha=0.5, label='Real', s=30)
plt.scatter(embedded[labels==1, 0], embedded[labels==1, 1],
           c='blue', alpha=0.5, label='Synthetic', s=30)
plt.legend()
plt.title('t-SNE: Real vs Synthetic Data')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.grid(True, alpha=0.3)
plt.show()

print("\n✅ If red and blue points overlap well, the synthetic data quality is good!")

## 6. Save Results

In [None]:
# Save synthetic data to the current working directory
np.save('synthetic_data.npy', synthetic_data)
print("Saved synthetic_data.npy")

# The check mark was previously emitted as mis-encoded bytes
print("\n✅ Example complete!")
print("\nNext steps:")
print("1. Train on your own dataset")
print("2. Increase epochs to 2500 for better quality")
print("3. Evaluate with downstream ML tasks")
print("4. Compare with baseline methods")